diff --git a/.gitattributes b/.gitattributes index ac80daab6b..39e1717ed6 100644 --- a/.gitattributes +++ b/.gitattributes @@ -8,6 +8,21 @@ docs/reference/api/*.md linguist-generated=true docs/reference/cli/*.md linguist-generated=true coderd/apidoc/swagger.json linguist-generated=true coderd/database/dump.sql linguist-generated=true + +# Database codegen (sqlc) +coderd/database/queries.sql.go linguist-generated=true +coderd/database/models.go linguist-generated=true +coderd/database/querier.go linguist-generated=true + +# Database codegen (gomock) +coderd/database/dbmock/dbmock.go linguist-generated=true + +# Database codegen (dbgen) +coderd/database/dbmetrics/querymetrics.go linguist-generated=true +coderd/database/unique_constraint.go linguist-generated=true +coderd/database/foreign_key_constraint.go linguist-generated=true +coderd/database/check_constraint.go linguist-generated=true + peerbroker/proto/*.go linguist-generated=true provisionerd/proto/*.go linguist-generated=true provisionerd/proto/version.go linguist-generated=false diff --git a/.github/actions/go-test-failure-report/action.yaml b/.github/actions/go-test-failure-report/action.yaml new file mode 100644 index 0000000000..b793ce114f --- /dev/null +++ b/.github/actions/go-test-failure-report/action.yaml @@ -0,0 +1,76 @@ +name: "Go Test Failure Report" +description: "Publish Go test failure summaries and upload failure artifacts" + +inputs: + json-file: + description: "Path to the gotestsum JSON file. Use default for RUNNER_TEMP/go-test.json." + required: false + default: "default" + failures-file: + description: "Path to write newline-delimited failure details. Use default for RUNNER_TEMP/go-test-failures.ndjson." + required: false + default: "default" + artifact-name: + description: "Artifact name for uploaded failure details" + required: true + retention-days: + description: "Artifact retention in days" + required: false + default: "7" + max-output-bytes: + description: "Maximum bytes to include in the markdown summary" + required: false + default: "16384" + max-failures: + description: "Maximum failures to include in the summary output" + required: false + default: "50" + +runs: + using: "composite" + steps: + - name: Resolve Go test report paths + id: paths + shell: bash + env: + JSON_FILE: ${{ inputs.json-file }} + FAILURES_FILE: ${{ inputs.failures-file }} + run: | + set -euo pipefail + json_file="$JSON_FILE" + if [[ "$json_file" == "default" ]]; then + json_file="${RUNNER_TEMP}/go-test.json" + fi + failures_file="$FAILURES_FILE" + if [[ "$failures_file" == "default" ]]; then + failures_file="${RUNNER_TEMP}/go-test-failures.ndjson" + fi + { + echo "json-file=${json_file}" + echo "failures-file=${failures_file}" + } >> "$GITHUB_OUTPUT" + + - name: Publish Go test failure summary + shell: bash + env: + JSON_FILE: ${{ steps.paths.outputs.json-file }} + FAILURES_FILE: ${{ steps.paths.outputs.failures-file }} + MAX_OUTPUT_BYTES: ${{ inputs.max-output-bytes }} + MAX_FAILURES: ${{ inputs.max-failures }} + run: | + set -euo pipefail + go run ./scripts/gotestsummary \ + --jsonfile "${JSON_FILE}" \ + --markdown-out - \ + --failures-out "${FAILURES_FILE}" \ + --max-output-bytes "${MAX_OUTPUT_BYTES}" \ + --max-failures "${MAX_FAILURES}" \ + >> "$GITHUB_STEP_SUMMARY" + + - name: Upload Go test failures + if: ${{ always() }} + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 + with: + name: ${{ inputs.artifact-name }} + path: ${{ steps.paths.outputs.failures-file }} + retention-days: ${{ inputs.retention-days }} diff --git a/.github/actions/setup-tf/action.yaml b/.github/actions/setup-tf/action.yaml index abcf9d7a22..22c7253050 100644 --- a/.github/actions/setup-tf/action.yaml +++ b/.github/actions/setup-tf/action.yaml @@ -7,5 +7,5 @@ runs: - name: Install Terraform uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # v3.1.2 with: - terraform_version: 1.15.2 + terraform_version: 1.15.5 terraform_wrapper: false diff --git a/.github/actions/test-go-pg/action.yaml b/.github/actions/test-go-pg/action.yaml index ad409cd700..fb33ba649f 100644 --- a/.github/actions/test-go-pg/action.yaml +++ b/.github/actions/test-go-pg/action.yaml @@ -26,6 +26,18 @@ inputs: description: "Packages to test (default: ./...)" required: false default: "./..." + run-regex: + description: "Go test name regex passed via RUN" + required: false + default: "" + test-shuffle: + description: "Go test shuffle mode passed via TEST_SHUFFLE" + required: false + default: "" + gotestsum-json-file: + description: "Optional Linux path for gotestsum --jsonfile output. Use default for RUNNER_TEMP/go-test.json." + required: false + default: "" embedded-pg-path: description: "Path for embedded postgres data (Windows/macOS only)" required: false @@ -61,8 +73,11 @@ runs: TEST_NUM_PARALLEL_PACKAGES: ${{ inputs.test-parallelism-packages }} TEST_NUM_PARALLEL_TESTS: ${{ inputs.test-parallelism-tests }} TEST_COUNT: ${{ inputs.test-count }} + RUN: ${{ inputs.run-regex }} + TEST_SHUFFLE: ${{ inputs.test-shuffle }} TEST_PACKAGES: ${{ inputs.test-packages }} RACE_DETECTION: ${{ inputs.race-detection }} + GOTESTSUM_JSONFILE_INPUT: ${{ inputs.gotestsum-json-file }} TS_DEBUG_DISCO: "true" TS_DEBUG_DERP: "true" LC_CTYPE: "en_US.UTF-8" @@ -70,6 +85,18 @@ runs: run: | set -euo pipefail + # gotestsum natively reads GOTESTSUM_JSONFILE; set it directly instead + # of writing a PATH shim. "default" is the historical + # ${RUNNER_TEMP}/go-test.json location consumed by + # ./.github/actions/go-test-failure-report. + if [[ -n "${GOTESTSUM_JSONFILE_INPUT}" ]]; then + if [[ "${GOTESTSUM_JSONFILE_INPUT}" == "default" ]]; then + export GOTESTSUM_JSONFILE="${RUNNER_TEMP}/go-test.json" + else + export GOTESTSUM_JSONFILE="${GOTESTSUM_JSONFILE_INPUT}" + fi + fi + if [[ ${RACE_DETECTION} == true ]]; then make test-race else diff --git a/.github/workflows/cherry-pick.yaml b/.github/workflows/cherry-pick.yaml index 98abd79382..8528f7b703 100644 --- a/.github/workflows/cherry-pick.yaml +++ b/.github/workflows/cherry-pick.yaml @@ -154,4 +154,5 @@ jobs: if [ "$CONFLICT" = true ]; then COMMENT="${COMMENT} (⚠️ conflicts need manual resolution)" fi - gh pr comment "$PR_NUMBER" --body "$COMMENT" + # Don't fail the job if commenting fails (e.g. the original PR is locked). + gh pr comment "$PR_NUMBER" --body "$COMMENT" || echo "::warning::Failed to comment on #${PR_NUMBER} (PR may be locked)." diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index cdea8e28db..78d6aba61a 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -6,6 +6,13 @@ on: - main - release/* + # GitHub Actions does not reliably trigger push-based CI when a new + # branch is created at a commit that already has a workflow run (e.g. + # from main). The create event fires separately and ensures CI runs + # on newly cut release branches. Non-release branch creations are + # filtered out by the changes job condition. + create: + pull_request: workflow_dispatch: @@ -21,6 +28,13 @@ concurrency: jobs: changes: runs-on: ubuntu-latest + # For create events, only run on release branches to avoid + # triggering CI for every feature branch creation. + if: | + github.event_name != 'create' || ( + github.event.ref_type == 'branch' && + startsWith(github.event.ref, 'release/') + ) outputs: docs-only: ${{ steps.filter.outputs.docs_count == steps.filter.outputs.all_count }} docs: ${{ steps.filter.outputs.docs }} @@ -505,24 +519,6 @@ jobs: source scripts/normalize_path.sh normalize_path_with_symlinks "$RUNNER_TEMP/sym" "$(dirname "$(which terraform)")" - - name: Configure Go test JSON capture - if: runner.os == 'Linux' - shell: bash - run: | - set -euo pipefail - bin_dir="${RUNNER_TEMP}/go-test-json-bin" - mkdir -p "$bin_dir" - - real_gotestsum="$(command -v gotestsum)" - real_gotestsum_quoted="$(printf '%q' "$real_gotestsum")" - printf '%s\n' \ - '#!/usr/bin/env bash' \ - 'set -euo pipefail' \ - "exec ${real_gotestsum_quoted} --jsonfile \"\${RUNNER_TEMP}/go-test.json\" \"\$@\"" \ - > "${bin_dir}/gotestsum" - chmod +x "${bin_dir}/gotestsum" - echo "$bin_dir" >> "$GITHUB_PATH" - - name: Setup RAM disk for Embedded Postgres (Windows) if: runner.os == 'Windows' shell: bash @@ -560,6 +556,7 @@ jobs: # By default, run tests with cache for improved speed (possibly at the expense of correctness). # On main, run tests without cache for the inverse. test-count: ${{ github.ref == 'refs/heads/main' && '1' || '' }} + gotestsum-json-file: default - name: Test with PostgreSQL Database (macOS) if: runner.os == 'macOS' @@ -599,24 +596,11 @@ jobs: embedded-pg-path: "R:/temp/embedded-pg" embedded-pg-cache: ${{ steps.embedded-pg-cache.outputs.embedded-pg-cache }} - - name: Publish Go test failure summary - if: failure() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork - run: | - go run ./scripts/gotestsummary \ - --jsonfile "${RUNNER_TEMP}/go-test.json" \ - --markdown-out - \ - --failures-out "${RUNNER_TEMP}/go-test-failures.ndjson" \ - --max-output-bytes 16384 \ - --max-failures 50 \ - >> "$GITHUB_STEP_SUMMARY" - - - name: Upload Go test failures - if: failure() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 + - name: Publish Go test failure report + if: failure() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && (github.event_name != 'pull_request' || !github.event.pull_request.head.repo.fork) + uses: ./.github/actions/go-test-failure-report with: - name: go-test-failures-${{ github.job }}-${{ github.sha }} - path: ${{ runner.temp }}/go-test-failures.ndjson - retention-days: 7 + artifact-name: go-test-failures-${{ github.job }}-${{ github.sha }} - name: Upload failed test db dumps uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 @@ -688,24 +672,6 @@ jobs: source scripts/normalize_path.sh normalize_path_with_symlinks "$RUNNER_TEMP/sym" "$(dirname "$(which terraform)")" - - name: Configure Go test JSON capture - if: runner.os == 'Linux' - shell: bash - run: | - set -euo pipefail - bin_dir="${RUNNER_TEMP}/go-test-json-bin" - mkdir -p "$bin_dir" - - real_gotestsum="$(command -v gotestsum)" - real_gotestsum_quoted="$(printf '%q' "$real_gotestsum")" - printf '%s\n' \ - '#!/usr/bin/env bash' \ - 'set -euo pipefail' \ - "exec ${real_gotestsum_quoted} --jsonfile \"\${RUNNER_TEMP}/go-test.json\" \"\$@\"" \ - > "${bin_dir}/gotestsum" - chmod +x "${bin_dir}/gotestsum" - echo "$bin_dir" >> "$GITHUB_PATH" - - name: Test with PostgreSQL Database uses: ./.github/actions/test-go-pg with: @@ -716,25 +682,13 @@ jobs: # By default, run tests with cache for improved speed (possibly at the expense of correctness). # On main, run tests without cache for the inverse. test-count: ${{ github.ref == 'refs/heads/main' && '1' || '' }} + gotestsum-json-file: default - - name: Publish Go test failure summary - if: failure() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork - run: | - go run ./scripts/gotestsummary \ - --jsonfile "${RUNNER_TEMP}/go-test.json" \ - --markdown-out - \ - --failures-out "${RUNNER_TEMP}/go-test-failures.ndjson" \ - --max-output-bytes 16384 \ - --max-failures 50 \ - >> "$GITHUB_STEP_SUMMARY" - - - name: Upload Go test failures - if: failure() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 + - name: Publish Go test failure report + if: failure() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && (github.event_name != 'pull_request' || !github.event.pull_request.head.repo.fork) + uses: ./.github/actions/go-test-failure-report with: - name: go-test-failures-${{ github.job }}-${{ github.sha }} - path: ${{ runner.temp }}/go-test-failures.ndjson - retention-days: 7 + artifact-name: go-test-failures-${{ github.job }}-${{ github.sha }} - name: Upload Test Cache uses: ./.github/actions/test-cache/upload @@ -793,24 +747,6 @@ jobs: # c.f. discussion on https://github.com/coder/coder/pull/15106 # Our Linux runners have 16 cores, but we reduce parallelism since race detection adds a lot of overhead. # We aim to have parallelism match CPU count (4*4=16) to avoid making flakes worse. - - name: Configure Go test JSON capture - if: runner.os == 'Linux' - shell: bash - run: | - set -euo pipefail - bin_dir="${RUNNER_TEMP}/go-test-json-bin" - mkdir -p "$bin_dir" - - real_gotestsum="$(command -v gotestsum)" - real_gotestsum_quoted="$(printf '%q' "$real_gotestsum")" - printf '%s\n' \ - '#!/usr/bin/env bash' \ - 'set -euo pipefail' \ - "exec ${real_gotestsum_quoted} --jsonfile \"\${RUNNER_TEMP}/go-test.json\" \"\$@\"" \ - > "${bin_dir}/gotestsum" - chmod +x "${bin_dir}/gotestsum" - echo "$bin_dir" >> "$GITHUB_PATH" - - name: Run Tests uses: ./.github/actions/test-go-pg with: @@ -818,25 +754,13 @@ jobs: test-parallelism-packages: "4" test-parallelism-tests: "4" race-detection: "true" + gotestsum-json-file: default - - name: Publish Go test failure summary - if: failure() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork - run: | - go run ./scripts/gotestsummary \ - --jsonfile "${RUNNER_TEMP}/go-test.json" \ - --markdown-out - \ - --failures-out "${RUNNER_TEMP}/go-test-failures.ndjson" \ - --max-output-bytes 16384 \ - --max-failures 50 \ - >> "$GITHUB_STEP_SUMMARY" - - - name: Upload Go test failures - if: failure() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 + - name: Publish Go test failure report + if: failure() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && (github.event_name != 'pull_request' || !github.event.pull_request.head.repo.fork) + uses: ./.github/actions/go-test-failure-report with: - name: go-test-failures-${{ github.job }}-${{ github.sha }} - path: ${{ runner.temp }}/go-test-failures.ndjson - retention-days: 7 + artifact-name: go-test-failures-${{ github.job }}-${{ github.sha }} - name: Upload Test Cache uses: ./.github/actions/test-cache/upload diff --git a/.github/workflows/doc-check.yaml b/.github/workflows/doc-check.yaml index f29fd555cc..c692b7e2a8 100644 --- a/.github/workflows/doc-check.yaml +++ b/.github/workflows/doc-check.yaml @@ -213,7 +213,7 @@ jobs: - name: Run doc-check via Coder Agent Chat if: steps.check-secrets.outputs.skip != 'true' - uses: coder/agents-chat-action@f0b975f503d3ff3e4478517baae290d4d01a2c7e # v0 + uses: coder/agents-chat-action@b3fc81d7dae5006dd124e98ef6fada1a36cdd86e # v0.3.0 with: coder-url: ${{ secrets.DOC_CHECK_CODER_URL }} coder-token: ${{ secrets.DOC_CHECK_CODER_SESSION_TOKEN }} diff --git a/.github/workflows/docs-preview.yaml b/.github/workflows/docs-preview.yaml index c585a61acd..8f00114e65 100644 --- a/.github/workflows/docs-preview.yaml +++ b/.github/workflows/docs-preview.yaml @@ -1,5 +1,5 @@ # This workflow posts a docs preview link as a PR comment whenever a -# pull request that touches files under docs/ is opened. The preview +# pull request that touches docs/ is opened or updated. The preview # is served by coder.com's branch-preview feature at /docs/@. # # The link deep-links to the first added/modified/renamed Markdown file @@ -7,8 +7,12 @@ # Branch names are URL-encoded so that names containing slashes or # other special characters produce working links. # -# If the PR only deletes Markdown files (or only changes non-Markdown -# files such as images or manifest.json), no comment is posted. +# On subsequent pushes (synchronize) the existing comment is updated +# rather than creating a duplicate. If a previous push had a Markdown +# file but the current push has none, the stale comment is deleted so +# readers don't follow a dead deep-link. If the PR only deletes +# Markdown files (or only changes non-Markdown files such as images or +# manifest.json), no comment is posted. name: docs-preview @@ -16,9 +20,15 @@ on: pull_request: types: - opened + - synchronize + - reopened paths: - "docs/**" +concurrency: + group: docs-preview-${{ github.event.pull_request.number }} + cancel-in-progress: true + permissions: contents: read @@ -35,6 +45,22 @@ jobs: PR_NUMBER: ${{ github.event.pull_request.number }} REPO: ${{ github.repository }} run: | + # Marker embedded in the comment body so we can find this + # workflow's own comments later. Keep this in one place so + # later refactors don't drift between the body construction + # and the jq selectors used to find existing comments. + DOCS_PREVIEW_MARKER='' + + # Returns IDs of github-actions[bot] comments on the PR whose + # body contains DOCS_PREVIEW_MARKER. Used by both the stale- + # comment-cleanup branch (when this push has no Markdown + # changes) and the upsert branch below. + list_docs_preview_comments() { + gh api --paginate \ + "repos/${REPO}/issues/${PR_NUMBER}/comments" \ + --jq ".[] | select(.user.login == \"github-actions[bot]\") | select(.body | contains(\"${DOCS_PREVIEW_MARKER}\")) | .id" + } + # Fetch the list of non-deleted files from the PR. This is # intentionally not piped into grep so that a gh-api failure # (network, auth, rate-limit) propagates immediately instead @@ -51,7 +77,38 @@ jobs: | head -n 1) || true if [ -z "$first_doc" ]; then - echo "No added/modified Markdown files under docs/, skipping preview comment." + echo "No added/modified Markdown files under docs/ on this push." + + # Now that the workflow fires on synchronize, this branch + # is reachable on pushes that drop all Markdown while still + # touching docs/ (e.g. a push that removes the file an + # earlier push had previewed but adds a new image). The + # previous preview comment now points at a deleted page; + # delete it so readers don't follow a dead deep-link. + # + # Intentionally decoupled from head so that a gh-api failure + # propagates here instead of being swallowed by `|| true`. In + # this branch the workflow has no preview link to post anyway + # (no Markdown in the push), so a transient list failure is a + # cosmetic miss; log and exit cleanly rather than red-checking + # every docs-touching PR during a comments-endpoint hiccup. + # The next push will retry the cleanup. The upsert path below + # uses strict propagation by contrast, because silent failure + # there would create duplicate comments. + stale_comment_ids=$(list_docs_preview_comments) || { + echo "Could not list preview comments; skipping cleanup." + exit 0 + } + stale_id=$(printf '%s\n' "$stale_comment_ids" | head -n 1) || true + + if [ -n "$stale_id" ]; then + if gh api --method DELETE \ + "repos/${REPO}/issues/comments/${stale_id}"; then + echo "Deleted stale docs preview comment (id=${stale_id})." + else + echo "Failed to delete stale docs preview comment (id=${stale_id}); leaving in place." + fi + fi exit 0 fi @@ -97,9 +154,37 @@ jobs: url="${url}/${page_path}" fi - gh pr comment "${PR_NUMBER}" \ - --repo "${REPO}" \ - --body "## Docs preview + # The literal backticks around ${first_doc} are escaped so + # they survive the double-quoted string as Markdown inline + # code; ${url} and ${first_doc} expand normally. + comment_body="## Docs preview [:book: View docs preview](${url}) for \`${first_doc}\` - " + ${DOCS_PREVIEW_MARKER}" + + # Upsert: update the existing docs-preview comment if one + # exists, otherwise create a new one. This prevents duplicate + # preview comments on every push to the PR. + # + # Intentionally not piped into head so that a gh-api failure + # (network, auth, rate-limit) propagates immediately instead + # of being swallowed by `|| true`. + all_comment_ids=$(list_docs_preview_comments) + existing_id=$(printf '%s\n' "$all_comment_ids" | head -n 1) || true + + if [ -n "$existing_id" ]; then + if ! gh api --method PATCH \ + "repos/${REPO}/issues/comments/${existing_id}" \ + --field body="$comment_body"; then + echo "PATCH failed (comment may have been deleted); creating a new comment." + existing_id="" + else + echo "Updated existing docs preview comment (id=${existing_id})." + fi + fi + if [ -z "$existing_id" ]; then + gh pr comment "${PR_NUMBER}" \ + --repo "${REPO}" \ + --body "$comment_body" + echo "Created new docs preview comment." + fi diff --git a/.github/workflows/flake-go.yaml b/.github/workflows/flake-go.yaml new file mode 100644 index 0000000000..1c7eb96dd0 --- /dev/null +++ b/.github/workflows/flake-go.yaml @@ -0,0 +1,82 @@ +name: flake-go + +on: + pull_request: + workflow_dispatch: + inputs: + base_sha: + description: "Base commit to diff against. Defaults to merge-base against origin/main." + required: false + type: string + head_sha: + description: "Head commit to analyze. Defaults to the checked out HEAD." + required: false + type: string + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + flake_go: + name: Flake Check + runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }} + timeout-minutes: 20 + steps: + - name: Harden Runner + uses: step-security/harden-runner@f808768d1510423e83855289c910610ca9b43176 # v2.17.0 + with: + egress-policy: audit + + - name: Checkout + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + repository: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name || github.repository }} + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.event.inputs.head_sha || github.sha }} + fetch-depth: 0 + persist-credentials: false + + - name: Setup Go + uses: ./.github/actions/setup-go + + - name: Install whichtests + shell: bash + run: ./.github/scripts/retry.sh -- go install github.com/coder/whichtests@ec33bab1ec04cd86beb7a61a069db4463dba63f5 + + - name: Select changed tests + id: selector + shell: bash + run: | + set -euo pipefail + whichtests \ + --repo-root . \ + --github-actions \ + --coalesce \ + --out-matrix "$RUNNER_TEMP/flake-matrix.json" + + - name: Setup Terraform + if: ${{ fromJSON(steps.selector.outputs.matrix).include[0] != null }} + uses: ./.github/actions/setup-tf + + - name: Run targeted Go flake checks + id: flake_check + if: ${{ fromJSON(steps.selector.outputs.matrix).include[0] != null }} + uses: ./.github/actions/test-go-pg + with: + postgres-version: "13" + test-parallelism-packages: "4" + test-parallelism-tests: "16" + test-count: "25" + test-packages: ${{ fromJSON(steps.selector.outputs.matrix).include[0].package }} + run-regex: ${{ fromJSON(steps.selector.outputs.matrix).include[0].run_regex }} + test-shuffle: "on" + gotestsum-json-file: default + + - name: Publish Go test failure report + if: failure() && steps.flake_check.outcome == 'failure' && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && (github.event_name != 'pull_request' || !github.event.pull_request.head.repo.fork) + uses: ./.github/actions/go-test-failure-report + with: + artifact-name: go-test-failures-${{ github.job }}-${{ github.sha }} diff --git a/.github/workflows/weekly-docs.yaml b/.github/workflows/weekly-docs.yaml index 0d34ef1f43..505c8522de 100644 --- a/.github/workflows/weekly-docs.yaml +++ b/.github/workflows/weekly-docs.yaml @@ -55,9 +55,14 @@ jobs: mkdir -p "$(pnpm store path --silent)" - name: Check Markdown links - uses: umbrelladocs/action-linkspector@37c85bcde51b30bf929936502bac6bfb7e8f0a4d # v1.4.1 + uses: umbrelladocs/action-linkspector@036f295d12b67b0c4b445bc83db0538afb78db69 # v1.5.2 id: markdown-link-check # checks all markdown files from /docs including all subfolders + env: + # Use the runner-provided Chrome instead of letting linkspector's + # puppeteer download a specific version that may not match the + # runner's puppeteer cache. See: https://github.com/UmbrellaDocs/action-linkspector/issues/62 + PUPPETEER_EXECUTABLE_PATH: /usr/bin/google-chrome with: reporter: github-pr-review config_file: ".github/.linkspector.yml" diff --git a/Makefile b/Makefile index 3d92ff40df..b58f80eb68 100644 --- a/Makefile +++ b/Makefile @@ -1446,8 +1446,16 @@ ifdef TEST_SHORT GOTEST_FLAGS += -short endif +# RUN is single-quoted for the shell so regex metacharacters survive make. +# Embedded single quotes are not supported; whichtests only emits RUN values +# built from ASCII test names so generated regexes stay within this contract. ifdef RUN -GOTEST_FLAGS += -run $(RUN) +GOTEST_FLAGS += -run '$(RUN)' +endif + +# TEST_SHUFFLE values must be off, on, or an integer seed. +ifdef TEST_SHUFFLE +GOTEST_FLAGS += -shuffle=$(TEST_SHUFFLE) endif ifdef TEST_CPUPROFILE diff --git a/agent/agentfiles/files.go b/agent/agentfiles/files.go index 4f92a8f7c9..1ee83e7371 100644 --- a/agent/agentfiles/files.go +++ b/agent/agentfiles/files.go @@ -387,17 +387,17 @@ func (api *API) HandleEditFiles(rw http.ResponseWriter, r *http.Request) { return } - // Duplicate entries both read the same file and race to write; - // the first entry's edits are silently lost. Resolve symlinks - // before comparing so two paths that alias the same real file - // (e.g. one via a symlink, one direct) don't slip past as - // distinct keys. prepareFileEdit resolves the path again for - // its own use; the double lstat cost is cheap compared to the - // data-loss risk of silent aliasing. + // Merge duplicate entries that refer to the same literal path + // so callers don't have to pre-coalesce. Two different paths + // that resolve to the same real file via symlinks are still + // rejected: silently merging edits the caller addressed to + // different paths would hide accidental aliasing. type seenEntry struct { caller string + index int // position in merged slice } seenPaths := make(map[string]seenEntry, len(req.Files)) + var merged []workspacesdk.FileEdits for _, f := range req.Files { // On resolve error, use the raw path; phase 1 surfaces // the error with its proper status code. @@ -406,17 +406,22 @@ func (api *API) HandleEditFiles(rw http.ResponseWriter, r *http.Request) { key = resolved } if prev, dup := seenPaths[key]; dup { - msg := fmt.Sprintf("duplicate file path %q: combine edits into a single entry's \"edits\" list", f.Path) - if prev.caller != f.Path { - msg = fmt.Sprintf("duplicate file path %q aliases %q (same real file): combine edits into a single entry's \"edits\" list", f.Path, prev.caller) + // Same literal path: merge edits. + if filepath.Clean(prev.caller) == filepath.Clean(f.Path) { + merged[prev.index].Edits = append(merged[prev.index].Edits, f.Edits...) + continue } + // Different paths, same real file (symlink alias). + msg := fmt.Sprintf("duplicate file path %q aliases %q (same real file): combine edits into a single entry's \"edits\" list", f.Path, prev.caller) httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: msg, }) return } - seenPaths[key] = seenEntry{caller: f.Path} + seenPaths[key] = seenEntry{caller: f.Path, index: len(merged)} + merged = append(merged, f) } + req.Files = merged // Phase 1: compute all edits in memory. If any file fails // (bad path, search miss, permission error), bail before diff --git a/agent/agentfiles/files_test.go b/agent/agentfiles/files_test.go index 8f8572b74c..8fcdaba810 100644 --- a/agent/agentfiles/files_test.go +++ b/agent/agentfiles/files_test.go @@ -2622,11 +2622,10 @@ func TestFuzzyReplace_Rejects(t *testing.T) { } } -// TestEditFiles_DuplicatePath_Rejects pins that duplicate paths in -// one request are rejected with 400 and the file on disk is -// unchanged. The pre-fix behavior silently dropped the first -// entry's edits while reporting success (last write wins). -func TestEditFiles_DuplicatePath_Rejects(t *testing.T) { +// TestEditFiles_DuplicatePath_Merges verifies that duplicate paths in +// one request are merged: edits from all entries for the same path are +// concatenated and applied in order. +func TestEditFiles_DuplicatePath_Merges(t *testing.T) { t.Parallel() tmpdir := os.TempDir() @@ -2637,10 +2636,12 @@ func TestEditFiles_DuplicatePath_Rejects(t *testing.T) { original := "one\ntwo\nthree\n" require.NoError(t, afero.WriteFile(fs, path, []byte(original), 0o644)) + // Entry 2 searches for the output of entry 1, proving edits + // are applied in the order they appear across entries. req := workspacesdk.FileEditRequest{ Files: []workspacesdk.FileEdits{ - {Path: path, Edits: []workspacesdk.FileEdit{{Search: "one", Replace: "ONE"}}}, - {Path: path, Edits: []workspacesdk.FileEdit{{Search: "three", Replace: "THREE"}}}, + {Path: path, Edits: []workspacesdk.FileEdit{{Search: "one", Replace: "CHANGED"}}}, + {Path: path, Edits: []workspacesdk.FileEdit{{Search: "CHANGED", Replace: "FINAL"}}}, }, } @@ -2653,15 +2654,49 @@ func TestEditFiles_DuplicatePath_Rejects(t *testing.T) { r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/edit-files", buf) api.Routes().ServeHTTP(w, r) - require.Equal(t, http.StatusBadRequest, w.Code, "body: %s", w.Body.String()) - got := &codersdk.Error{} - require.NoError(t, json.NewDecoder(w.Body).Decode(got)) - require.ErrorContains(t, got, "duplicate file path") + require.Equal(t, http.StatusOK, w.Code, "body: %s", w.Body.String()) - // File on disk must be untouched: no partial edits. data, err := afero.ReadFile(fs, path) require.NoError(t, err) - require.Equal(t, original, string(data)) + require.Equal(t, "FINAL\ntwo\nthree\n", string(data)) +} + +// TestEditFiles_DuplicatePath_NonCanonicalMerges verifies that +// non-canonical paths normalizing to the same file are merged, +// not rejected as symlink aliases. +func TestEditFiles_DuplicatePath_NonCanonicalMerges(t *testing.T) { + t.Parallel() + + tmpdir := os.TempDir() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + fs := afero.NewMemMapFs() + api := agentfiles.NewAPI(logger, fs, nil) + canonical := filepath.Join(tmpdir, "noncanon") + nonCanonical := canonical[:len(tmpdir)] + "/./noncanon" + original := "one\ntwo\nthree\n" + require.NoError(t, afero.WriteFile(fs, canonical, []byte(original), 0o644)) + + req := workspacesdk.FileEditRequest{ + Files: []workspacesdk.FileEdits{ + {Path: canonical, Edits: []workspacesdk.FileEdit{{Search: "one", Replace: "ONE"}}}, + {Path: nonCanonical, Edits: []workspacesdk.FileEdit{{Search: "three", Replace: "THREE"}}}, + }, + } + + ctx := testutil.Context(t, testutil.WaitShort) + buf := bytes.NewBuffer(nil) + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(false) + require.NoError(t, enc.Encode(req)) + w := httptest.NewRecorder() + r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/edit-files", buf) + api.Routes().ServeHTTP(w, r) + + require.Equal(t, http.StatusOK, w.Code, "body: %s", w.Body.String()) + + data, err := afero.ReadFile(fs, canonical) + require.NoError(t, err) + require.Equal(t, "ONE\ntwo\nTHREE\n", string(data)) } // TestEditFiles_DuplicatePath_SymlinkAliasRejects pins that two diff --git a/aibridge/api.go b/aibridge/api.go index 809d452fe9..34dce84ef8 100644 --- a/aibridge/api.go +++ b/aibridge/api.go @@ -57,6 +57,14 @@ func NewCopilotProvider(cfg config.Copilot) provider.Provider { return provider.NewCopilot(cfg) } +// NewDisabledProviderStub returns a Provider that reports Enabled() == +// false and has no-op implementations for all other methods. Use this +// instead of constructing a concrete provider for disabled rows so that +// adding a new provider type does not require updating a switch here. +func NewDisabledProviderStub(name, providerType string) provider.Provider { + return provider.NewDisabledStub(name, providerType) +} + func NewMetrics(reg prometheus.Registerer) *metrics.Metrics { return metrics.NewMetrics(reg) } diff --git a/aibridge/bridge.go b/aibridge/bridge.go index f604d0a38a..daf103fb10 100644 --- a/aibridge/bridge.go +++ b/aibridge/bridge.go @@ -20,6 +20,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/aibridge/circuitbreaker" aibcontext "github.com/coder/coder/v2/aibridge/context" + "github.com/coder/coder/v2/aibridge/intercept" "github.com/coder/coder/v2/aibridge/mcp" "github.com/coder/coder/v2/aibridge/metrics" "github.com/coder/coder/v2/aibridge/provider" @@ -30,6 +31,11 @@ import ( const ( // The duration after which an async recording will be aborted. recordingTimeout = time.Second * 5 + + // ErrorCodeProviderDisabled is the code written in the response + // body when a request targets a configured-but-disabled provider. + // Paired with HTTP 503. + ErrorCodeProviderDisabled = "provider_disabled" ) // RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs; @@ -96,6 +102,14 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re mux := http.NewServeMux() for _, prov := range providers { + // Disabled providers serve a 503 sentinel on every path under + // "//". Bound to the bare name (not RoutePrefix) so paths + // outside the provider's normal "/v1" subtree are also caught. + if !prov.Enabled() { + prefix := fmt.Sprintf("/%s/", prov.Name()) + mux.HandleFunc(prefix, disabledProviderHandler(prov.Name(), logger)) + continue + } // Create per-provider circuit breaker if configured cfg := prov.CircuitBreakerConfig() providerName := prov.Name() @@ -170,6 +184,20 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re }, nil } +// disabledProviderHandler returns 503 with a body containing +// [ErrorCodeProviderDisabled] and the provider name for every request +// targeting name. +func disabledProviderHandler(name string, logger slog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + logger.Debug(r.Context(), "refusing request for disabled ai provider", + slog.F("provider", name), + slog.F("path", r.URL.Path), + slog.F("method", r.Method), + ) + http.Error(w, fmt.Sprintf("%s: AI provider %q is disabled", ErrorCodeProviderDisabled, name), http.StatusServiceUnavailable) + } +} + // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] rec. // If cbs is non-nil, circuit breaker protection is applied per endpoint/model tuple. @@ -248,11 +276,18 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC slog.F("user_agent", r.UserAgent()), slog.F("streaming", interceptor.Streaming()), slog.F("credential_kind", string(cred.Kind)), - slog.F("credential_hint", cred.Hint), - slog.F("credential_length", cred.Length), ) - log.Debug(ctx, "interception started") + // Log BYOK credentials. Centralized credentials are set by + // the key failover loop. + credLogFields := []slog.Field{} + if cred.Kind == intercept.CredentialKindBYOK { + credLogFields = append(credLogFields, + slog.F("credential_hint", cred.Hint), + slog.F("credential_length", cred.Length), + ) + } + log.Debug(ctx, "interception started", credLogFields...) if m != nil { m.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Add(1) defer func() { @@ -261,22 +296,30 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC } // Process request with circuit breaker protection if configured - if err := cbs.Execute(route, interceptor.Model(), w, func(rw http.ResponseWriter) error { + execErr := cbs.Execute(route, interceptor.Model(), w, func(rw http.ResponseWriter) error { return interceptor.ProcessRequest(rw, r) - }); err != nil { + }) + // For centralized, the hint now reflects the last attempted + // key from the failover loop. + credHint := interceptor.Credential().Hint + credLen := interceptor.Credential().Length + if execErr != nil { if m != nil { m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID, string(client)).Add(1) } - span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) - log.Warn(ctx, "interception failed", slog.Error(err)) + span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", execErr)) + log.Warn(ctx, "interception failed", slog.Error(execErr), slog.F("credential_hint", credHint), slog.F("credential_length", credLen)) } else { if m != nil { m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID, string(client)).Add(1) } - log.Debug(ctx, "interception ended") + log.Debug(ctx, "interception ended", slog.F("credential_hint", credHint), slog.F("credential_length", credLen)) } - _ = asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String()}) + _ = asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ + ID: interceptor.ID().String(), + CredentialHint: credHint, + }) // Ensure all recording have completed before completing request. asyncRecorder.Wait() diff --git a/aibridge/bridge_test.go b/aibridge/bridge_test.go index f2657ab80f..93beb82de9 100644 --- a/aibridge/bridge_test.go +++ b/aibridge/bridge_test.go @@ -205,3 +205,58 @@ func TestPassthroughRoutesForProviders(t *testing.T) { }) } } + +// TestDisabledProviderHandler asserts that requests to a disabled +// provider return a 503 with an ErrorCodeProviderDisabled body and +// that a sibling enabled provider keeps routing normally. +func TestDisabledProviderHandler(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("upstream-reached")) + })) + t.Cleanup(upstream.Close) + + enabled := aibridge.NewOpenAIProvider(config.OpenAI{Name: "enabled-openai", BaseURL: upstream.URL}) + disabled := aibridge.NewDisabledProviderStub("disabled-openai", "openai") + bridge, err := aibridge.NewRequestBridge( + t.Context(), + []provider.Provider{enabled, disabled}, + nil, nil, logger, nil, bridgeTestTracer, + ) + require.NoError(t, err) + + for _, tc := range []struct { + name string + path string + }{ + {name: "Bridged", path: "/disabled-openai/v1/chat/completions"}, + {name: "Passthrough", path: "/disabled-openai/v1/models"}, + {name: "Unknown", path: "/disabled-openai/anything/else"}, + } { + t.Run("DisabledProviderReturnsSentinel/"+tc.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodPost, tc.path, nil) + resp := httptest.NewRecorder() + bridge.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusServiceUnavailable, resp.Code) + assert.Contains(t, resp.Body.String(), aibridge.ErrorCodeProviderDisabled) + assert.Contains(t, resp.Body.String(), "disabled-openai") + }) + } + + t.Run("EnabledProviderUnaffected", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "/enabled-openai/v1/models", nil) + resp := httptest.NewRecorder() + bridge.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, "upstream-reached", resp.Body.String()) + }) +} diff --git a/aibridge/intercept/chatcompletions/blocking.go b/aibridge/intercept/chatcompletions/blocking.go index 95d065ce5b..fa1511f660 100644 --- a/aibridge/intercept/chatcompletions/blocking.go +++ b/aibridge/intercept/chatcompletions/blocking.go @@ -291,15 +291,16 @@ func (i *BlockingInterception) newChatCompletionWithKey(ctx context.Context, svc // 401/403. Errors that aren't key-specific don't trigger // failover and are returned to the caller. func (i *BlockingInterception) newChatCompletionWithKeyFailover(ctx context.Context, svc openai.ChatCompletionService, opts []option.RequestOption) (*openai.ChatCompletion, error) { - // TODO(ssncferreira): update the interception's credential - // hint with the actually-used key (the successful key on - // success, the last tried key on failure) in the upstack PR. walker := i.cfg.KeyPool.Walker() for { key, keyPoolErr := walker.Next() if keyPoolErr != nil { return nil, keyPoolErr } + // Record the key in use so the hint reflects the last attempted key. + i.credential = intercept.NewCredentialInfo(intercept.CredentialKindCentralized, key.Value()) + i.logger.Debug(ctx, "using centralized api key", + slog.F("credential_hint", i.Credential().Hint), slog.F("credential_length", i.Credential().Length)) requestOpts := append([]option.RequestOption{}, opts...) requestOpts = append(requestOpts, diff --git a/aibridge/intercept/chatcompletions/blocking_internal_test.go b/aibridge/intercept/chatcompletions/blocking_internal_test.go index 3b3a917a54..2b9afaadea 100644 --- a/aibridge/intercept/chatcompletions/blocking_internal_test.go +++ b/aibridge/intercept/chatcompletions/blocking_internal_test.go @@ -72,31 +72,35 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { expectedRetryAfter string // Expected key states after the request, by index in keys. expectedKeyStates []keypool.KeyState + // Expected credential hint after ProcessRequest: last + // attempted key for centralized, user key from initial request for BYOK. + expectedCredentialHint string }{ { // Given: 1 valid key returning 200. // Then: 1 request, 200 response, key remains valid. name: "single_valid_key", - keys: []string{"k0"}, + keys: []string{"k0-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusOK, body: successBody}, + "k0-long-key": {statusCode: http.StatusOK, body: successBody}, }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + expectedRequestCount: 1, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: 2 keys; key-0 returns 429, key-1 returns 200. // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. name: "failover_after_429", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "5"}, body: rateLimitBody, }, - "k1": {statusCode: http.StatusOK, body: successBody}, + "k1-long-key": {statusCode: http.StatusOK, body: successBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusOK, @@ -104,15 +108,16 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 401, key-1 returns 200. // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. name: "failover_after_401", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1": {statusCode: http.StatusOK, body: successBody}, + "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1-long-key": {statusCode: http.StatusOK, body: successBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusOK, @@ -120,15 +125,16 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 403, key-1 returns 200. // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. name: "failover_after_403", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusForbidden, body: authErrorBody}, - "k1": {statusCode: http.StatusOK, body: successBody}, + "k0-long-key": {statusCode: http.StatusForbidden, body: authErrorBody}, + "k1-long-key": {statusCode: http.StatusOK, body: successBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusOK, @@ -136,25 +142,26 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 3 keys; all return 429 with cooldowns 5s, 3s, 10s. // Then: 3 requests, 429 response with smallest Retry-After, // all keys temporary. name: "all_keys_rate_limited", - keys: []string{"k0", "k1", "k2"}, + keys: []string{"k0-long-key", "k1-long-key", "k2-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "5"}, body: rateLimitBody, }, - "k1": { + "k1-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "3"}, body: rateLimitBody, }, - "k2": { + "k2-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "10"}, body: rateLimitBody, @@ -168,15 +175,16 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateTemporary, }, + expectedCredentialHint: utils.MaskSecret("k2-long-key"), }, { // Given: 2 keys; both return 401. // Then: 2 requests, 502 api_error response, both keys permanent. name: "all_keys_unauthorized", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusBadGateway, @@ -184,14 +192,15 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStatePermanent, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 500. // Then: 1 request, 500 response, both keys remain valid. name: "server_error_no_failover", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, + "k0-long-key": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, }, expectedRequestCount: 1, expectedStatusCode: http.StatusInternalServerError, @@ -199,6 +208,7 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { keypool.KeyStateValid, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: BYOK with a single key returning 429. @@ -219,9 +229,10 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { body: rateLimitBody, }, }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "5", + expectedRequestCount: 1, + expectedStatusCode: http.StatusTooManyRequests, + expectedRetryAfter: "5", + expectedCredentialHint: utils.MaskSecret("user-byok"), }, } @@ -252,6 +263,7 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { cfg := config.OpenAI{BaseURL: upstream.URL + "/"} var pool *keypool.Pool + credInfo := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") if len(tc.keys) > 0 { var err error pool, err = keypool.New(tc.keys, quartz.NewMock(t)) @@ -259,6 +271,7 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { cfg.KeyPool = pool } else if tc.byokKey != "" { cfg.Key = tc.byokKey + credInfo = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, tc.byokKey) } interceptor := NewBlockingInterceptor( @@ -269,7 +282,7 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { http.Header{}, "Authorization", otel.Tracer("blocking_test"), - intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), + credInfo, ) interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) @@ -288,6 +301,7 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { if pool != nil { assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") } + assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") }) } } @@ -309,6 +323,9 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { expectedSeenKeys []string expectedStatusCode int expectedKeyStates []keypool.KeyState + // Expected credential hint after ProcessRequest: hint of the + // last attempted key across all agentic-loop iterations. + expectedCredentialHint string }{ { // Given: 2 keys; both upstream calls succeed on key-0. @@ -319,12 +336,13 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { {statusCode: http.StatusOK, body: textCompleteBody}, }, expectedRequestCount: 2, - expectedSeenKeys: []string{"k0", "k0"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key"}, expectedStatusCode: http.StatusOK, expectedKeyStates: []keypool.KeyState{ keypool.KeyStateValid, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: 2 keys; key-0 succeeds initially, then 429s @@ -342,12 +360,13 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { {statusCode: http.StatusOK, body: textCompleteBody}, }, expectedRequestCount: 3, - expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, expectedStatusCode: http.StatusOK, expectedKeyStates: []keypool.KeyState{ keypool.KeyStateTemporary, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 succeeds initially, then both @@ -369,12 +388,13 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { }, }, expectedRequestCount: 3, - expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, expectedStatusCode: http.StatusTooManyRequests, expectedKeyStates: []keypool.KeyState{ keypool.KeyStateTemporary, keypool.KeyStateTemporary, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, } @@ -409,7 +429,7 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { })) t.Cleanup(upstream.Close) - pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t)) + pool, err := keypool.New([]string{"k0-long-key", "k1-long-key"}, quartz.NewMock(t)) require.NoError(t, err) cfg := config.OpenAI{ @@ -459,6 +479,7 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { defer seenKeysMu.Unlock() assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys") assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") + assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") }) } } diff --git a/aibridge/intercept/chatcompletions/streaming.go b/aibridge/intercept/chatcompletions/streaming.go index 581ab49d03..e20a2a801d 100644 --- a/aibridge/intercept/chatcompletions/streaming.go +++ b/aibridge/intercept/chatcompletions/streaming.go @@ -164,6 +164,11 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re break } currentKey = key + // Record the key in use so the hint reflects the last attempted key. + i.credential = intercept.NewCredentialInfo(intercept.CredentialKindCentralized, key.Value()) + logger.Debug(ctx, "using centralized api key", + slog.F("credential_hint", i.Credential().Hint), slog.F("credential_length", i.Credential().Length)) + opts = append(opts, option.WithAPIKey(key.Value()), // Disable SDK retries because the failover diff --git a/aibridge/intercept/chatcompletions/streaming_internal_test.go b/aibridge/intercept/chatcompletions/streaming_internal_test.go index 82c58f9bc1..9561c0948a 100644 --- a/aibridge/intercept/chatcompletions/streaming_internal_test.go +++ b/aibridge/intercept/chatcompletions/streaming_internal_test.go @@ -144,36 +144,40 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { expectedRetryAfter string // Expected key states after the request, by index in keys. expectedKeyStates []keypool.KeyState + // Expected credential hint after ProcessRequest: last + // attempted key for centralized, user key from initial request for BYOK. + expectedCredentialHint string }{ { // Given: 1 valid key returning a successful stream. // Then: 1 request, 200 response, key remains valid. name: "single_valid_key", - keys: []string{"k0"}, + keys: []string{"k0-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusOK, headers: map[string]string{"Content-Type": "text/event-stream"}, body: streamingSuccessBody, }, }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + expectedRequestCount: 1, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: 2 keys; key-0 returns 429 pre-stream, key-1 // streams successfully. // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. name: "failover_after_429", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "5"}, body: rateLimitBody, }, - "k1": { + "k1-long-key": { statusCode: http.StatusOK, headers: map[string]string{"Content-Type": "text/event-stream"}, body: streamingSuccessBody, @@ -185,16 +189,17 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 401 pre-stream, key-1 // streams successfully. // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. name: "failover_after_401", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1": { + "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1-long-key": { statusCode: http.StatusOK, headers: map[string]string{"Content-Type": "text/event-stream"}, body: streamingSuccessBody, @@ -206,15 +211,16 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 403 pre-stream, key-1 streams. // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. name: "failover_after_403", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusForbidden, body: authErrorBody}, - "k1": { + "k0-long-key": {statusCode: http.StatusForbidden, body: authErrorBody}, + "k1-long-key": { statusCode: http.StatusOK, headers: map[string]string{"Content-Type": "text/event-stream"}, body: streamingSuccessBody, @@ -226,6 +232,7 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 3 keys; all return 429 pre-stream with @@ -233,19 +240,19 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { // Then: 3 requests, 429 response with smallest // Retry-After, all keys temporary. name: "all_keys_rate_limited", - keys: []string{"k0", "k1", "k2"}, + keys: []string{"k0-long-key", "k1-long-key", "k2-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "5"}, body: rateLimitBody, }, - "k1": { + "k1-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "3"}, body: rateLimitBody, }, - "k2": { + "k2-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "10"}, body: rateLimitBody, @@ -259,15 +266,16 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateTemporary, }, + expectedCredentialHint: utils.MaskSecret("k2-long-key"), }, { // Given: 2 keys; both return 401 pre-stream. // Then: 2 requests, 502 api_error response, both keys permanent. name: "all_keys_unauthorized", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusBadGateway, @@ -275,14 +283,15 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStatePermanent, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 500 pre-stream. // Then: 1 request, 500 response, both keys remain valid. name: "server_error_no_failover", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, + "k0-long-key": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, }, expectedRequestCount: 1, expectedStatusCode: http.StatusInternalServerError, @@ -290,6 +299,7 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { keypool.KeyStateValid, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: BYOK with a single key returning 429. @@ -310,9 +320,10 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { body: rateLimitBody, }, }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "5", + expectedRequestCount: 1, + expectedStatusCode: http.StatusTooManyRequests, + expectedRetryAfter: "5", + expectedCredentialHint: utils.MaskSecret("user-byok"), }, } @@ -342,6 +353,7 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { cfg := config.OpenAI{BaseURL: upstream.URL + "/"} var pool *keypool.Pool + credInfo := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") if len(tc.keys) > 0 { var err error pool, err = keypool.New(tc.keys, quartz.NewMock(t)) @@ -349,6 +361,7 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { cfg.KeyPool = pool } else if tc.byokKey != "" { cfg.Key = tc.byokKey + credInfo = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, tc.byokKey) } interceptor := NewStreamingInterceptor( @@ -359,7 +372,7 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { http.Header{}, "Authorization", otel.Tracer("streaming_test"), - intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), + credInfo, ) interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) @@ -378,6 +391,7 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { if pool != nil { assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") } + assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") }) } } @@ -435,6 +449,9 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { // error (e.g. all keys exhausted). expectedErr bool expectedKeyStates []keypool.KeyState + // Expected credential hint after ProcessRequest: hint of the + // last attempted key across all agentic-loop iterations. + expectedCredentialHint string }{ { // Given: 2 keys; both upstream calls succeed on key-0. @@ -445,13 +462,14 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, }, expectedRequestCount: 2, - expectedSeenKeys: []string{"k0", "k0"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key"}, expectedBodyContains: "done", expectErrorAsSSEEvent: false, expectedKeyStates: []keypool.KeyState{ keypool.KeyStateValid, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: 2 keys; key-0 succeeds initially, then 429s @@ -469,13 +487,14 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, }, expectedRequestCount: 3, - expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, expectedBodyContains: "done", expectErrorAsSSEEvent: false, expectedKeyStates: []keypool.KeyState{ keypool.KeyStateTemporary, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 succeeds initially, then both @@ -497,7 +516,7 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { }, }, expectedRequestCount: 3, - expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, expectedBodyContains: "all configured keys are rate-limited", expectErrorAsSSEEvent: true, expectedErr: true, @@ -505,6 +524,7 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateTemporary, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, } @@ -538,7 +558,7 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { })) t.Cleanup(upstream.Close) - pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t)) + pool, err := keypool.New([]string{"k0-long-key", "k1-long-key"}, quartz.NewMock(t)) require.NoError(t, err) cfg := config.OpenAI{ @@ -596,6 +616,7 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { defer seenKeysMu.Unlock() assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys") assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") + assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") }) } } diff --git a/aibridge/intercept/messages/base.go b/aibridge/intercept/messages/base.go index e35e2a9726..1f1f49e744 100644 --- a/aibridge/intercept/messages/base.go +++ b/aibridge/intercept/messages/base.go @@ -329,6 +329,12 @@ func (*interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconf } var out []option.RequestOption + out = append(out, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + if ua := req.Header.Get("User-Agent"); ua != "" { + req.Header.Set("User-Agent", ua+" sdk-ua-app-id/APN_1.1%2Fpc_cdfmjwn8i6u8l9fwz8h82e4w3%24") + } + return next(req) + })) out = append(out, bedrock.WithConfig(awsCfg)) // If a custom base URL is set, override the default endpoint constructed by the bedrock middleware. diff --git a/aibridge/intercept/messages/blocking.go b/aibridge/intercept/messages/blocking.go index e91f80feb9..bf74885b2b 100644 --- a/aibridge/intercept/messages/blocking.go +++ b/aibridge/intercept/messages/blocking.go @@ -367,15 +367,16 @@ func (i *BlockingInterception) newMessageWithKey(ctx context.Context, svc anthro // Errors that aren't key-specific don't trigger failover and // are returned to the caller. func (i *BlockingInterception) newMessageWithKeyFailover(ctx context.Context, svc anthropic.MessageService) (*anthropic.Message, error) { - // TODO(ssncferreira): update the interception's credential - // hint with the actually-used key (the successful key on - // success, the last tried key on failure) in the upstack PR. walker := i.cfg.KeyPool.Walker() for { key, keyPoolErr := walker.Next() if keyPoolErr != nil { return nil, keyPoolErr } + // Record the key in use so the hint reflects the last attempted key. + i.credential = intercept.NewCredentialInfo(intercept.CredentialKindCentralized, key.Value()) + i.logger.Debug(ctx, "using centralized api key", + slog.F("credential_hint", i.Credential().Hint), slog.F("credential_length", i.Credential().Length)) msg, err := i.newMessageWithKey(ctx, svc, option.WithAPIKey(key.Value()), diff --git a/aibridge/intercept/messages/blocking_internal_test.go b/aibridge/intercept/messages/blocking_internal_test.go index 857d425fe3..9b3f0d447b 100644 --- a/aibridge/intercept/messages/blocking_internal_test.go +++ b/aibridge/intercept/messages/blocking_internal_test.go @@ -19,6 +19,7 @@ import ( "github.com/coder/coder/v2/aibridge/internal/testutil" "github.com/coder/coder/v2/aibridge/keypool" "github.com/coder/coder/v2/aibridge/mcp" + "github.com/coder/coder/v2/aibridge/utils" "github.com/coder/quartz" ) @@ -54,31 +55,35 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { expectedRetryAfter string // Expected key states after the request, by index in keys. expectedKeyStates []keypool.KeyState + // Expected credential hint after ProcessRequest: last + // attempted key for centralized, user key from initial request for BYOK. + expectedCredentialHint string }{ { // Given: 1 valid key returning 200. // Then: 1 request, 200 response, key remains valid. name: "single_valid_key", - keys: []string{"k0"}, + keys: []string{"k0-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusOK, body: successBody}, + "k0-long-key": {statusCode: http.StatusOK, body: successBody}, }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + expectedRequestCount: 1, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: 2 keys; key-0 returns 429, key-1 returns 200. // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. name: "failover_after_429", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "5"}, body: rateLimitBody, }, - "k1": {statusCode: http.StatusOK, body: successBody}, + "k1-long-key": {statusCode: http.StatusOK, body: successBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusOK, @@ -86,15 +91,16 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 401, key-1 returns 200. // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. name: "failover_after_401", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1": {statusCode: http.StatusOK, body: successBody}, + "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1-long-key": {statusCode: http.StatusOK, body: successBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusOK, @@ -102,15 +108,16 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 403, key-1 returns 200. // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. name: "failover_after_403", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusForbidden, body: authErrorBody}, - "k1": {statusCode: http.StatusOK, body: successBody}, + "k0-long-key": {statusCode: http.StatusForbidden, body: authErrorBody}, + "k1-long-key": {statusCode: http.StatusOK, body: successBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusOK, @@ -118,25 +125,26 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 3 keys; all return 429 with cooldowns 5s, 3s, 10s. // Then: 3 requests, 429 response with smallest Retry-After, // all keys temporary. name: "all_keys_rate_limited", - keys: []string{"k0", "k1", "k2"}, + keys: []string{"k0-long-key", "k1-long-key", "k2-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "5"}, body: rateLimitBody, }, - "k1": { + "k1-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "3"}, body: rateLimitBody, }, - "k2": { + "k2-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "10"}, body: rateLimitBody, @@ -150,15 +158,16 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateTemporary, }, + expectedCredentialHint: utils.MaskSecret("k2-long-key"), }, { // Given: 2 keys; both return 401. // Then: 2 requests, 502 api_error response, both keys permanent. name: "all_keys_unauthorized", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusBadGateway, @@ -166,14 +175,15 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStatePermanent, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 500. // Then: 1 request, 500 response, both keys remain valid. name: "server_error_no_failover", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, + "k0-long-key": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, }, expectedRequestCount: 1, expectedStatusCode: http.StatusInternalServerError, @@ -181,6 +191,7 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { keypool.KeyStateValid, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: BYOK with a single key returning 429. @@ -201,9 +212,10 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { body: rateLimitBody, }, }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "5", + expectedRequestCount: 1, + expectedStatusCode: http.StatusTooManyRequests, + expectedRetryAfter: "5", + expectedCredentialHint: utils.MaskSecret("user-byok"), }, } @@ -234,6 +246,7 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { cfg := config.Anthropic{BaseURL: upstream.URL + "/"} var pool *keypool.Pool + credInfo := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") if len(tc.keys) > 0 { var err error pool, err = keypool.New(tc.keys, quartz.NewMock(t)) @@ -241,6 +254,7 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { cfg.KeyPool = pool } else if tc.byokKey != "" { cfg.Key = tc.byokKey + credInfo = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, tc.byokKey) } payload, err := NewRequestPayload([]byte(requestBody)) @@ -255,7 +269,7 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { http.Header{}, "X-Api-Key", otel.Tracer("blocking_test"), - intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), + credInfo, ) interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) @@ -271,6 +285,7 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") + assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") if pool != nil { assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") } @@ -296,6 +311,9 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { expectedStatusCode int expectedRetryAfter string expectedKeyStates []keypool.KeyState + // Expected credential hint after ProcessRequest: hint of the + // last attempted key across all agentic-loop iterations. + expectedCredentialHint string }{ { // Given: 2 keys; both upstream calls succeed on key-0. @@ -306,12 +324,13 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { {statusCode: http.StatusOK, body: successBody}, }, expectedRequestCount: 2, - expectedSeenKeys: []string{"k0", "k0"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key"}, expectedStatusCode: http.StatusOK, expectedKeyStates: []keypool.KeyState{ keypool.KeyStateValid, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: 2 keys; key-0 succeeds initially, then 429s @@ -329,12 +348,13 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { {statusCode: http.StatusOK, body: successBody}, }, expectedRequestCount: 3, - expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, expectedStatusCode: http.StatusOK, expectedKeyStates: []keypool.KeyState{ keypool.KeyStateTemporary, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 succeeds initially, then both @@ -356,13 +376,14 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { }, }, expectedRequestCount: 3, - expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, expectedStatusCode: http.StatusTooManyRequests, expectedRetryAfter: "3", expectedKeyStates: []keypool.KeyState{ keypool.KeyStateTemporary, keypool.KeyStateTemporary, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, } @@ -397,7 +418,7 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { })) t.Cleanup(upstream.Close) - pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t)) + pool, err := keypool.New([]string{"k0-long-key", "k1-long-key"}, quartz.NewMock(t)) require.NoError(t, err) cfg := config.Anthropic{ @@ -447,6 +468,7 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") + assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") seenKeysMu.Lock() defer seenKeysMu.Unlock() diff --git a/aibridge/intercept/messages/streaming.go b/aibridge/intercept/messages/streaming.go index 475f32c99c..47c49528a9 100644 --- a/aibridge/intercept/messages/streaming.go +++ b/aibridge/intercept/messages/streaming.go @@ -195,6 +195,11 @@ newStream: break } currentKey = key + // Record the key in use so the hint reflects the last attempted key. + i.credential = intercept.NewCredentialInfo(intercept.CredentialKindCentralized, key.Value()) + logger.Debug(ctx, "using centralized api key", + slog.F("credential_hint", i.Credential().Hint), slog.F("credential_length", i.Credential().Length)) + streamOpts = append(streamOpts, option.WithAPIKey(key.Value()), // Disable SDK retries because the failover diff --git a/aibridge/intercept/messages/streaming_internal_test.go b/aibridge/intercept/messages/streaming_internal_test.go index 97f48d4cc3..5fc7da00df 100644 --- a/aibridge/intercept/messages/streaming_internal_test.go +++ b/aibridge/intercept/messages/streaming_internal_test.go @@ -21,6 +21,7 @@ import ( "github.com/coder/coder/v2/aibridge/internal/testutil" "github.com/coder/coder/v2/aibridge/keypool" "github.com/coder/coder/v2/aibridge/mcp" + "github.com/coder/coder/v2/aibridge/utils" "github.com/coder/quartz" ) @@ -60,36 +61,40 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { expectedRetryAfter string // Expected key states after the request, by index in keys. expectedKeyStates []keypool.KeyState + // Expected credential hint after ProcessRequest: last + // attempted key for centralized, user key from initial request for BYOK. + expectedCredentialHint string }{ { // Given: 1 valid key returning a successful stream. // Then: 1 request, 200 response, key remains valid. name: "single_valid_key", - keys: []string{"k0"}, + keys: []string{"k0-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusOK, headers: map[string]string{"Content-Type": "text/event-stream"}, body: streamingSuccessBody, }, }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + expectedRequestCount: 1, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: 2 keys; key-0 returns 429 pre-stream, key-1 // streams successfully. // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. name: "failover_after_429", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "5"}, body: rateLimitBody, }, - "k1": { + "k1-long-key": { statusCode: http.StatusOK, headers: map[string]string{"Content-Type": "text/event-stream"}, body: streamingSuccessBody, @@ -101,16 +106,17 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 401 pre-stream, key-1 // streams successfully. // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. name: "failover_after_401", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1": { + "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1-long-key": { statusCode: http.StatusOK, headers: map[string]string{"Content-Type": "text/event-stream"}, body: streamingSuccessBody, @@ -122,15 +128,16 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 403 pre-stream, key-1 streams. // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. name: "failover_after_403", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusForbidden, body: authErrorBody}, - "k1": { + "k0-long-key": {statusCode: http.StatusForbidden, body: authErrorBody}, + "k1-long-key": { statusCode: http.StatusOK, headers: map[string]string{"Content-Type": "text/event-stream"}, body: streamingSuccessBody, @@ -142,6 +149,7 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 3 keys; all return 429 pre-stream with @@ -149,19 +157,19 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { // Then: 3 requests, 429 response with smallest // Retry-After, all keys temporary. name: "all_keys_rate_limited", - keys: []string{"k0", "k1", "k2"}, + keys: []string{"k0-long-key", "k1-long-key", "k2-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "5"}, body: rateLimitBody, }, - "k1": { + "k1-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "3"}, body: rateLimitBody, }, - "k2": { + "k2-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "10"}, body: rateLimitBody, @@ -175,15 +183,16 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateTemporary, }, + expectedCredentialHint: utils.MaskSecret("k2-long-key"), }, { // Given: 2 keys; both return 401 pre-stream. // Then: 2 requests, 502 api_error response, both keys permanent. name: "all_keys_unauthorized", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusBadGateway, @@ -191,14 +200,15 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStatePermanent, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 500 pre-stream. // Then: 1 request, 500 response, both keys remain valid. name: "server_error_no_failover", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, + "k0-long-key": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, }, expectedRequestCount: 1, expectedStatusCode: http.StatusInternalServerError, @@ -206,6 +216,7 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { keypool.KeyStateValid, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: BYOK with a single key returning 429. @@ -226,9 +237,10 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { body: rateLimitBody, }, }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "5", + expectedRequestCount: 1, + expectedStatusCode: http.StatusTooManyRequests, + expectedRetryAfter: "5", + expectedCredentialHint: utils.MaskSecret("user-byok"), }, } @@ -258,6 +270,7 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { cfg := config.Anthropic{BaseURL: upstream.URL + "/"} var pool *keypool.Pool + credInfo := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") if len(tc.keys) > 0 { var err error pool, err = keypool.New(tc.keys, quartz.NewMock(t)) @@ -265,6 +278,7 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { cfg.KeyPool = pool } else if tc.byokKey != "" { cfg.Key = tc.byokKey + credInfo = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, tc.byokKey) } payload, err := NewRequestPayload([]byte(requestBody)) @@ -279,7 +293,7 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { http.Header{}, "X-Api-Key", otel.Tracer("streaming_test"), - intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), + credInfo, ) interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) @@ -301,6 +315,7 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { if pool != nil { assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") } + assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") }) } } @@ -387,6 +402,9 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { // error (e.g. all keys exhausted). expectedErr bool expectedKeyStates []keypool.KeyState + // Expected credential hint after ProcessRequest: hint of the + // last attempted key across all agentic-loop iterations. + expectedCredentialHint string }{ { // Given: 2 keys; both upstream calls succeed on key-0. @@ -397,13 +415,14 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, }, expectedRequestCount: 2, - expectedSeenKeys: []string{"k0", "k0"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key"}, expectedBodyContains: "done", expectErrorAsSSEEvent: false, expectedKeyStates: []keypool.KeyState{ keypool.KeyStateValid, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: 2 keys; key-0 succeeds initially, then 429s @@ -421,13 +440,14 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, }, expectedRequestCount: 3, - expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, expectedBodyContains: "done", expectErrorAsSSEEvent: false, expectedKeyStates: []keypool.KeyState{ keypool.KeyStateTemporary, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 succeeds initially, then both @@ -453,7 +473,7 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { }, }, expectedRequestCount: 3, - expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, expectedBodyContains: "all configured keys are rate-limited", expectErrorAsSSEEvent: true, expectedErr: true, @@ -461,6 +481,7 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateTemporary, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, } @@ -494,7 +515,7 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { })) t.Cleanup(upstream.Close) - pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t)) + pool, err := keypool.New([]string{"k0-long-key", "k1-long-key"}, quartz.NewMock(t)) require.NoError(t, err) cfg := config.Anthropic{ @@ -553,6 +574,7 @@ func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { defer seenKeysMu.Unlock() assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys") assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") + assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") }) } } diff --git a/aibridge/intercept/responses/blocking.go b/aibridge/intercept/responses/blocking.go index 9726b6f750..892dc1e71d 100644 --- a/aibridge/intercept/responses/blocking.go +++ b/aibridge/intercept/responses/blocking.go @@ -171,15 +171,16 @@ func (i *BlockingResponsesInterceptor) newResponseWithKey(ctx context.Context, s // Errors that aren't key-specific don't trigger failover and // are returned to the caller. func (i *BlockingResponsesInterceptor) newResponseWithKeyFailover(ctx context.Context, srv responses.ResponseService, opts []option.RequestOption) (*responses.Response, error) { - // TODO(ssncferreira): update the interception's credential - // hint with the actually-used key (the successful key on - // success, the last tried key on failure) in the upstack PR. walker := i.cfg.KeyPool.Walker() for { key, keyPoolErr := walker.Next() if keyPoolErr != nil { return nil, keyPoolErr } + // Record the key in use so the hint reflects the last attempted key. + i.credential = intercept.NewCredentialInfo(intercept.CredentialKindCentralized, key.Value()) + i.logger.Debug(ctx, "using centralized api key", + slog.F("credential_hint", i.Credential().Hint), slog.F("credential_length", i.Credential().Length)) requestOpts := append([]option.RequestOption{}, opts...) requestOpts = append(requestOpts, diff --git a/aibridge/intercept/responses/blocking_internal_test.go b/aibridge/intercept/responses/blocking_internal_test.go index 678c2ce0f3..94acf0deef 100644 --- a/aibridge/intercept/responses/blocking_internal_test.go +++ b/aibridge/intercept/responses/blocking_internal_test.go @@ -58,31 +58,35 @@ func TestBlockingResponsesInterceptor_KeyFailover(t *testing.T) { expectedRetryAfter string // Expected key states after the request, by index in keys. expectedKeyStates []keypool.KeyState + // Expected credential hint after ProcessRequest: last + // attempted key for centralized, user key from initial request for BYOK. + expectedCredentialHint string }{ { // Given: 1 valid key returning 200. // Then: 1 request, 200 response, key remains valid. name: "single_valid_key", - keys: []string{"k0"}, + keys: []string{"k0-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusOK, body: successBody}, + "k0-long-key": {statusCode: http.StatusOK, body: successBody}, }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + expectedRequestCount: 1, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: 2 keys; key-0 returns 429, key-1 returns 200. // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. name: "failover_after_429", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "5"}, body: rateLimitBody, }, - "k1": {statusCode: http.StatusOK, body: successBody}, + "k1-long-key": {statusCode: http.StatusOK, body: successBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusOK, @@ -90,15 +94,16 @@ func TestBlockingResponsesInterceptor_KeyFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 401, key-1 returns 200. // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. name: "failover_after_401", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1": {statusCode: http.StatusOK, body: successBody}, + "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1-long-key": {statusCode: http.StatusOK, body: successBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusOK, @@ -106,15 +111,16 @@ func TestBlockingResponsesInterceptor_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 403, key-1 returns 200. // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. name: "failover_after_403", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusForbidden, body: authErrorBody}, - "k1": {statusCode: http.StatusOK, body: successBody}, + "k0-long-key": {statusCode: http.StatusForbidden, body: authErrorBody}, + "k1-long-key": {statusCode: http.StatusOK, body: successBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusOK, @@ -122,25 +128,26 @@ func TestBlockingResponsesInterceptor_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 3 keys; all return 429 with cooldowns 5s, 3s, 10s. // Then: 3 requests, 429 response with smallest Retry-After, // all keys temporary. name: "all_keys_rate_limited", - keys: []string{"k0", "k1", "k2"}, + keys: []string{"k0-long-key", "k1-long-key", "k2-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "5"}, body: rateLimitBody, }, - "k1": { + "k1-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "3"}, body: rateLimitBody, }, - "k2": { + "k2-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "10"}, body: rateLimitBody, @@ -154,15 +161,16 @@ func TestBlockingResponsesInterceptor_KeyFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateTemporary, }, + expectedCredentialHint: utils.MaskSecret("k2-long-key"), }, { // Given: 2 keys; both return 401. // Then: 2 requests, 502 api_error response, both keys permanent. name: "all_keys_unauthorized", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusBadGateway, @@ -170,14 +178,15 @@ func TestBlockingResponsesInterceptor_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStatePermanent, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 500. // Then: 1 request, 500 response, both keys remain valid. name: "server_error_no_failover", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, + "k0-long-key": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, }, expectedRequestCount: 1, expectedStatusCode: http.StatusInternalServerError, @@ -185,6 +194,7 @@ func TestBlockingResponsesInterceptor_KeyFailover(t *testing.T) { keypool.KeyStateValid, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: BYOK with a single key returning 429. @@ -204,8 +214,9 @@ func TestBlockingResponsesInterceptor_KeyFailover(t *testing.T) { body: rateLimitBody, }, }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusTooManyRequests, + expectedRequestCount: 1, + expectedStatusCode: http.StatusTooManyRequests, + expectedCredentialHint: utils.MaskSecret("user-byok"), }, } @@ -235,6 +246,7 @@ func TestBlockingResponsesInterceptor_KeyFailover(t *testing.T) { t.Cleanup(upstream.Close) cfg := config.OpenAI{BaseURL: upstream.URL + "/"} + credInfo := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") var pool *keypool.Pool if len(tc.keys) > 0 { var err error @@ -243,6 +255,7 @@ func TestBlockingResponsesInterceptor_KeyFailover(t *testing.T) { cfg.KeyPool = pool } else if tc.byokKey != "" { cfg.Key = tc.byokKey + credInfo = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, tc.byokKey) } payload, err := NewRequestPayload([]byte(requestBody)) @@ -256,7 +269,7 @@ func TestBlockingResponsesInterceptor_KeyFailover(t *testing.T) { http.Header{}, "Authorization", otel.Tracer("blocking_test"), - intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), + credInfo, ) interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) @@ -272,6 +285,7 @@ func TestBlockingResponsesInterceptor_KeyFailover(t *testing.T) { assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") + assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") if pool != nil { assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") } @@ -296,6 +310,9 @@ func TestBlockingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { expectedSeenKeys []string expectedStatusCode int expectedKeyStates []keypool.KeyState + // Expected credential hint after ProcessRequest: hint of the + // last attempted key across all agentic-loop iterations. + expectedCredentialHint string }{ { // Given: 2 keys; both upstream calls succeed on key-0. @@ -306,12 +323,13 @@ func TestBlockingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { {statusCode: http.StatusOK, body: textCompleteBody}, }, expectedRequestCount: 2, - expectedSeenKeys: []string{"k0", "k0"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key"}, expectedStatusCode: http.StatusOK, expectedKeyStates: []keypool.KeyState{ keypool.KeyStateValid, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: 2 keys; key-0 succeeds initially, then 429s @@ -329,12 +347,13 @@ func TestBlockingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { {statusCode: http.StatusOK, body: textCompleteBody}, }, expectedRequestCount: 3, - expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, expectedStatusCode: http.StatusOK, expectedKeyStates: []keypool.KeyState{ keypool.KeyStateTemporary, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 succeeds initially, then both @@ -356,12 +375,13 @@ func TestBlockingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { }, }, expectedRequestCount: 3, - expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, expectedStatusCode: http.StatusTooManyRequests, expectedKeyStates: []keypool.KeyState{ keypool.KeyStateTemporary, keypool.KeyStateTemporary, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, } @@ -396,7 +416,7 @@ func TestBlockingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { })) t.Cleanup(upstream.Close) - pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t)) + pool, err := keypool.New([]string{"k0-long-key", "k1-long-key"}, quartz.NewMock(t)) require.NoError(t, err) cfg := config.OpenAI{ @@ -444,6 +464,7 @@ func TestBlockingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") + assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") seenKeysMu.Lock() defer seenKeysMu.Unlock() diff --git a/aibridge/intercept/responses/streaming.go b/aibridge/intercept/responses/streaming.go index 2140c5e6c8..3b38b7a7e6 100644 --- a/aibridge/intercept/responses/streaming.go +++ b/aibridge/intercept/responses/streaming.go @@ -144,6 +144,11 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r return xerrors.Errorf("key pool exhausted: %w", keyPoolErr) } currentKey = key + // Record the key in use so the hint reflects the last attempted key. + i.credential = intercept.NewCredentialInfo(intercept.CredentialKindCentralized, key.Value()) + i.logger.Debug(ctx, "using centralized api key", + slog.F("credential_hint", i.Credential().Hint), slog.F("credential_length", i.Credential().Length)) + opts = append(opts, option.WithAPIKey(key.Value()), // Disable SDK retries because the failover diff --git a/aibridge/intercept/responses/streaming_internal_test.go b/aibridge/intercept/responses/streaming_internal_test.go index 3226147cbd..4f20d76c17 100644 --- a/aibridge/intercept/responses/streaming_internal_test.go +++ b/aibridge/intercept/responses/streaming_internal_test.go @@ -51,36 +51,40 @@ func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { expectedRetryAfter string // Expected key states after the request, by index in keys. expectedKeyStates []keypool.KeyState + // Expected credential hint after ProcessRequest: last + // attempted key for centralized, user key from initial request for BYOK. + expectedCredentialHint string }{ { // Given: 1 valid key returning a successful stream. // Then: 1 request, 200 response, key remains valid. name: "single_valid_key", - keys: []string{"k0"}, + keys: []string{"k0-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusOK, headers: map[string]string{"Content-Type": "text/event-stream"}, body: streamingSuccessBody, }, }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + expectedRequestCount: 1, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: 2 keys; key-0 returns 429 pre-stream, key-1 // streams successfully. // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. name: "failover_after_429", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "5"}, body: rateLimitBody, }, - "k1": { + "k1-long-key": { statusCode: http.StatusOK, headers: map[string]string{"Content-Type": "text/event-stream"}, body: streamingSuccessBody, @@ -92,16 +96,17 @@ func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 401 pre-stream, key-1 // streams successfully. // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. name: "failover_after_401", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1": { + "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1-long-key": { statusCode: http.StatusOK, headers: map[string]string{"Content-Type": "text/event-stream"}, body: streamingSuccessBody, @@ -113,15 +118,16 @@ func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 403 pre-stream, key-1 streams. // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. name: "failover_after_403", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusForbidden, body: authErrorBody}, - "k1": { + "k0-long-key": {statusCode: http.StatusForbidden, body: authErrorBody}, + "k1-long-key": { statusCode: http.StatusOK, headers: map[string]string{"Content-Type": "text/event-stream"}, body: streamingSuccessBody, @@ -133,6 +139,7 @@ func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 3 keys; all return 429 pre-stream with @@ -140,19 +147,19 @@ func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { // Then: 3 requests, 429 response with smallest // Retry-After, all keys temporary. name: "all_keys_rate_limited", - keys: []string{"k0", "k1", "k2"}, + keys: []string{"k0-long-key", "k1-long-key", "k2-long-key"}, responses: map[string]upstreamResponse{ - "k0": { + "k0-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "5"}, body: rateLimitBody, }, - "k1": { + "k1-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "3"}, body: rateLimitBody, }, - "k2": { + "k2-long-key": { statusCode: http.StatusTooManyRequests, headers: map[string]string{"Retry-After": "10"}, body: rateLimitBody, @@ -166,15 +173,16 @@ func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { keypool.KeyStateTemporary, keypool.KeyStateTemporary, }, + expectedCredentialHint: utils.MaskSecret("k2-long-key"), }, { // Given: 2 keys; both return 401 pre-stream. // Then: 2 requests, 502 api_error response, both keys permanent. name: "all_keys_unauthorized", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, }, expectedRequestCount: 2, expectedStatusCode: http.StatusBadGateway, @@ -182,14 +190,15 @@ func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { keypool.KeyStatePermanent, keypool.KeyStatePermanent, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 returns 500 pre-stream. // Then: 1 request, 500 response, both keys remain valid. name: "server_error_no_failover", - keys: []string{"k0", "k1"}, + keys: []string{"k0-long-key", "k1-long-key"}, responses: map[string]upstreamResponse{ - "k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, + "k0-long-key": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, }, expectedRequestCount: 1, expectedStatusCode: http.StatusInternalServerError, @@ -197,6 +206,7 @@ func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { keypool.KeyStateValid, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: BYOK with a single key returning 429. @@ -216,8 +226,9 @@ func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { body: rateLimitBody, }, }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusTooManyRequests, + expectedRequestCount: 1, + expectedStatusCode: http.StatusTooManyRequests, + expectedCredentialHint: utils.MaskSecret("user-byok"), }, } @@ -246,6 +257,7 @@ func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { t.Cleanup(upstream.Close) cfg := config.OpenAI{BaseURL: upstream.URL + "/"} + credInfo := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") var pool *keypool.Pool if len(tc.keys) > 0 { var err error @@ -254,6 +266,7 @@ func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { cfg.KeyPool = pool } else if tc.byokKey != "" { cfg.Key = tc.byokKey + credInfo = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, tc.byokKey) } payload, err := NewRequestPayload([]byte(streamingRequestBody)) @@ -267,7 +280,7 @@ func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { http.Header{}, "Authorization", otel.Tracer("streaming_test"), - intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), + credInfo, ) interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) @@ -283,6 +296,7 @@ func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") + assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") if pool != nil { assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") } @@ -339,6 +353,9 @@ func TestStreamingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { // error (e.g. all keys exhausted). expectedErr bool expectedKeyStates []keypool.KeyState + // Expected credential hint after ProcessRequest: hint of the + // last attempted key across all agentic-loop iterations. + expectedCredentialHint string }{ { // Given: 2 keys; both upstream calls succeed on key-0. @@ -349,12 +366,13 @@ func TestStreamingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, }, expectedRequestCount: 2, - expectedSeenKeys: []string{"k0", "k0"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key"}, expectedBodyContains: "done", expectedKeyStates: []keypool.KeyState{ keypool.KeyStateValid, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k0-long-key"), }, { // Given: 2 keys; key-0 succeeds initially, then 429s @@ -372,12 +390,13 @@ func TestStreamingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, }, expectedRequestCount: 3, - expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, expectedBodyContains: "done", expectedKeyStates: []keypool.KeyState{ keypool.KeyStateTemporary, keypool.KeyStateValid, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, { // Given: 2 keys; key-0 succeeds initially, then both @@ -399,13 +418,14 @@ func TestStreamingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { }, }, expectedRequestCount: 3, - expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, expectedBodyContains: "all configured keys are rate-limited", expectedErr: true, expectedKeyStates: []keypool.KeyState{ keypool.KeyStateTemporary, keypool.KeyStateTemporary, }, + expectedCredentialHint: utils.MaskSecret("k1-long-key"), }, } @@ -439,7 +459,7 @@ func TestStreamingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { })) t.Cleanup(upstream.Close) - pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t)) + pool, err := keypool.New([]string{"k0-long-key", "k1-long-key"}, quartz.NewMock(t)) require.NoError(t, err) cfg := config.OpenAI{ @@ -489,6 +509,7 @@ func TestStreamingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") body := w.Body.String() assert.Contains(t, body, tc.expectedBodyContains, "response body") + assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") seenKeysMu.Lock() defer seenKeysMu.Unlock() diff --git a/aibridge/internal/integrationtest/bridge_internal_test.go b/aibridge/internal/integrationtest/bridge_internal_test.go index cfb40599d3..595d7159c6 100644 --- a/aibridge/internal/integrationtest/bridge_internal_test.go +++ b/aibridge/internal/integrationtest/bridge_internal_test.go @@ -131,6 +131,13 @@ func TestAnthropicMessages(t *testing.T) { require.Len(t, promptUsages, 1) assert.Equal(t, "read the foo file", promptUsages[0].Prompt) + // Verify PRM attribution is NOT present on non-Bedrock Anthropic requests. + received := upstream.receivedRequests() + require.Len(t, received, 1) + ua := received[0].Header.Get("User-Agent") + assert.NotContains(t, ua, "sdk-ua-app-id", + "PRM attribution should not be present on non-Bedrock requests") + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } @@ -327,6 +334,11 @@ func TestAWSBedrockIntegration(t *testing.T) { require.False(t, gjson.GetBytes(received[0].Body, "model").Exists(), "model should be stripped from body") require.False(t, gjson.GetBytes(received[0].Body, "stream").Exists(), "stream should be stripped from body") + // Verify PRM attribution is appended to the User-Agent header. + ua := received[0].Header.Get("User-Agent") + require.Contains(t, ua, "sdk-ua-app-id/APN_1.1%2Fpc_cdfmjwn8i6u8l9fwz8h82e4w3%24", + "expected AWS PRM attribution in User-Agent header") + interceptions := bridgeServer.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1) require.Equal(t, interceptions[0].Model, bedrockCfg.Model) diff --git a/aibridge/internal/testutil/mockprovider.go b/aibridge/internal/testutil/mockprovider.go index 0fd85d2863..e5015cd870 100644 --- a/aibridge/internal/testutil/mockprovider.go +++ b/aibridge/internal/testutil/mockprovider.go @@ -15,6 +15,7 @@ import ( type MockProvider struct { NameStr string URL string + Disabled bool Bridged []string Passthrough []string InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) @@ -22,6 +23,7 @@ type MockProvider struct { func (m *MockProvider) Type() string { return m.NameStr } func (m *MockProvider) Name() string { return m.NameStr } +func (m *MockProvider) Enabled() bool { return !m.Disabled } func (m *MockProvider) BaseURL() string { return m.URL } func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) } func (m *MockProvider) BridgedRoutes() []string { return m.Bridged } diff --git a/aibridge/keypool/keymark.go b/aibridge/keypool/keymark.go index 9b00bb400a..9dfedb3e44 100644 --- a/aibridge/keypool/keymark.go +++ b/aibridge/keypool/keymark.go @@ -5,7 +5,6 @@ import ( "net/http" "cdr.dev/slog/v3" - "github.com/coder/coder/v2/aibridge/utils" ) // MarkKeyOnStatus marks key based on a key-specific HTTP @@ -32,7 +31,7 @@ func MarkKeyOnStatus( if key.MarkTemporary(cooldown) { logger.Info(ctx, "key marked temporary", slog.F("provider", providerName), - slog.F("api_key_hint", utils.MaskSecret(key.Value())), + slog.F("api_key_hint", key.Hint()), slog.F("status", statusCode), slog.F("cooldown", cooldown)) } @@ -41,7 +40,7 @@ func MarkKeyOnStatus( if key.MarkPermanent() { logger.Warn(ctx, "key marked permanent", slog.F("provider", providerName), - slog.F("api_key_hint", utils.MaskSecret(key.Value())), + slog.F("api_key_hint", key.Hint()), slog.F("status", statusCode)) } return true diff --git a/aibridge/keypool/keypool.go b/aibridge/keypool/keypool.go index 55d1712a93..e28ae78325 100644 --- a/aibridge/keypool/keypool.go +++ b/aibridge/keypool/keypool.go @@ -7,6 +7,7 @@ import ( "golang.org/x/xerrors" + "github.com/coder/coder/v2/aibridge/utils" "github.com/coder/quartz" ) @@ -116,6 +117,12 @@ func (k *Key) Value() string { return k.value } +// Hint returns a masked, identifiable fragment of the key, suitable +// for logs and persisted records. +func (k *Key) Hint() string { + return utils.MaskSecret(k.value) +} + // State returns the current state of the key, derived from its // permanent flag and cooldown deadline. func (k *Key) State() KeyState { diff --git a/aibridge/provider/anthropic.go b/aibridge/provider/anthropic.go index eb50a3b296..d053cce903 100644 --- a/aibridge/provider/anthropic.go +++ b/aibridge/provider/anthropic.go @@ -95,6 +95,8 @@ func (p *Anthropic) Name() string { return p.cfg.Name } +func (*Anthropic) Enabled() bool { return true } + func (p *Anthropic) RoutePrefix() string { return fmt.Sprintf("/%s", p.Name()) } @@ -168,15 +170,10 @@ func (p *Anthropic) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tr authHeaderName = "Authorization" credKind = intercept.CredentialKindBYOK credSecret = token - } else if cfg.KeyPool != nil { - // Centralized: use the first key as a placeholder hint. - // TODO(ssncferreira): record the actually-used key in - // the interception record to reflect failover. - if key, keyPoolErr := cfg.KeyPool.Walker().Next(); keyPoolErr == nil { - credSecret = key.Value() - } } - + // Centralized leaves credSecret empty: the hint is set by the + // failover loop on each key attempt and persisted at + // end-of-interception. cred := intercept.NewCredentialInfo(credKind, credSecret) var interceptor intercept.Interceptor diff --git a/aibridge/provider/anthropic_internal_test.go b/aibridge/provider/anthropic_internal_test.go index b3d89556a8..815a83ba03 100644 --- a/aibridge/provider/anthropic_internal_test.go +++ b/aibridge/provider/anthropic_internal_test.go @@ -257,7 +257,9 @@ func TestAnthropic_CreateInterceptor_BYOK(t *testing.T) { setHeaders: map[string]string{}, wantXApiKey: "test-key", wantCredentialKind: intercept.CredentialKindCentralized, - wantCredentialHint: "t...y", + // Centralized hint is empty at CreateInterceptor; set + // by the key failover loop during ProcessRequest. + wantCredentialHint: "", }, { name: "Messages_BYOK_BearerToken_And_APIKey", diff --git a/aibridge/provider/copilot.go b/aibridge/provider/copilot.go index 1186e8b253..fd317aadab 100644 --- a/aibridge/provider/copilot.go +++ b/aibridge/provider/copilot.go @@ -78,6 +78,8 @@ func (p *Copilot) Name() string { return p.cfg.Name } +func (*Copilot) Enabled() bool { return true } + func (p *Copilot) BaseURL() string { return p.cfg.BaseURL } diff --git a/aibridge/provider/disabled.go b/aibridge/provider/disabled.go new file mode 100644 index 0000000000..95384b4952 --- /dev/null +++ b/aibridge/provider/disabled.go @@ -0,0 +1,47 @@ +package provider + +import ( + "fmt" + "net/http" + + "go.opentelemetry.io/otel/trace" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/aibridge/config" + "github.com/coder/coder/v2/aibridge/intercept" + "github.com/coder/coder/v2/aibridge/keypool" +) + +// DisabledStub is a Provider placeholder for a configured-but-disabled +// provider. Only Name and Enabled return meaningful values; all other +// methods return empty/nil so the stub never influences routing. +type DisabledStub struct { + name string + providerType string +} + +// NewDisabledStub returns a Provider stub that reports Enabled() == false. +// The type string is preserved so callers can distinguish provider families. +func NewDisabledStub(name, providerType string) *DisabledStub { + return &DisabledStub{name: name, providerType: providerType} +} + +func (d *DisabledStub) Type() string { return d.providerType } +func (d *DisabledStub) Name() string { return d.name } +func (*DisabledStub) Enabled() bool { return false } +func (*DisabledStub) BaseURL() string { return "" } +func (d *DisabledStub) RoutePrefix() string { + return fmt.Sprintf("/%s", d.name) +} +func (*DisabledStub) BridgedRoutes() []string { return nil } +func (*DisabledStub) PassthroughRoutes() []string { return nil } +func (*DisabledStub) AuthHeader() string { return "" } +func (*DisabledStub) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig { + return keypool.KeyFailoverConfig{} +} +func (*DisabledStub) CircuitBreakerConfig() *config.CircuitBreaker { return nil } +func (*DisabledStub) APIDumpDir() string { return "" } +func (*DisabledStub) CreateInterceptor(_ http.ResponseWriter, _ *http.Request, _ trace.Tracer) (intercept.Interceptor, error) { + //nolint:nilnil // disabled providers never reach the interceptor. + return nil, nil +} diff --git a/aibridge/provider/openai.go b/aibridge/provider/openai.go index 177ae03409..88020b7eb2 100644 --- a/aibridge/provider/openai.go +++ b/aibridge/provider/openai.go @@ -84,6 +84,8 @@ func (p *OpenAI) Name() string { return p.cfg.Name } +func (*OpenAI) Enabled() bool { return true } + func (p *OpenAI) RoutePrefix() string { // Route prefix includes version to match default OpenAI base URL. // More detailed explanation: https://github.com/coder/aibridge/pull/174#discussion_r2782320152 @@ -141,14 +143,10 @@ func (p *OpenAI) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trace cfg.KeyPool = nil credKind = intercept.CredentialKindBYOK credSecret = token - } else if cfg.KeyPool != nil { - // Centralized: use the first key as a placeholder hint. - // TODO(ssncferreira): record the actually-used key in - // the interception record to reflect failover. - if key, keyPoolErr := cfg.KeyPool.Walker().Next(); keyPoolErr == nil { - credSecret = key.Value() - } } + // Centralized leaves credSecret empty: the hint is set by the + // failover loop on each key attempt and persisted at + // end-of-interception. cred := intercept.NewCredentialInfo(credKind, credSecret) path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix()) diff --git a/aibridge/provider/openai_internal_test.go b/aibridge/provider/openai_internal_test.go index e1afcc872c..1922d22c30 100644 --- a/aibridge/provider/openai_internal_test.go +++ b/aibridge/provider/openai_internal_test.go @@ -229,7 +229,9 @@ func TestOpenAI_CreateInterceptor(t *testing.T) { setHeaders: map[string]string{}, wantAuthorization: "Bearer centralized-key", wantCredentialKind: intercept.CredentialKindCentralized, - wantCredentialHint: "ce...ey", + // Centralized hint is empty at CreateInterceptor; set + // by the key failover loop during ProcessRequest. + wantCredentialHint: "", }, { name: "Responses_BYOK", @@ -249,7 +251,9 @@ func TestOpenAI_CreateInterceptor(t *testing.T) { setHeaders: map[string]string{}, wantAuthorization: "Bearer centralized-key", wantCredentialKind: intercept.CredentialKindCentralized, - wantCredentialHint: "ce...ey", + // Centralized hint is empty at CreateInterceptor; set + // by the key failover loop during ProcessRequest. + wantCredentialHint: "", }, // X-Api-Key should not appear in production since clients use Authorization, // but ensure it is stripped if it does arrive. diff --git a/aibridge/provider/provider.go b/aibridge/provider/provider.go index 7520333b53..6f21d7290d 100644 --- a/aibridge/provider/provider.go +++ b/aibridge/provider/provider.go @@ -53,6 +53,8 @@ type Provider interface { // Name returns the provider instance name. // Defaults to Type() when not explicitly configured. Name() string + // Enabled reports whether the provider should serve requests. + Enabled() bool // BaseURL defines the base URL endpoint for this provider's API. BaseURL() string diff --git a/aibridge/recorder/types.go b/aibridge/recorder/types.go index cd541eebd4..faa5713900 100644 --- a/aibridge/recorder/types.go +++ b/aibridge/recorder/types.go @@ -39,13 +39,20 @@ type InterceptionRecord struct { Client string UserAgent string CorrelatingToolCallID *string - CredentialKind string - CredentialHint string + // CredentialKind is always set: either BYOK or centralized. + CredentialKind string + // CredentialHint is only set for BYOK, where the key is known + // from the request. Centralized uses key failover, so the hint + // can only be determined at end-of-interception. + CredentialHint string } type InterceptionRecordEnded struct { ID string EndedAt time.Time + // CredentialHint is the hint observed at end-of-interception. + // Only applied to the DB row for centralized; ignored for BYOK. + CredentialHint string } type TokenUsageRecord struct { diff --git a/cli/aibridged.go b/cli/aibridged.go index 50e2c35d5c..a890488a10 100644 --- a/cli/aibridged.go +++ b/cli/aibridged.go @@ -4,6 +4,7 @@ package cli import ( "context" + "slices" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" @@ -37,6 +38,7 @@ func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider, cfg reg := prometheus.WrapRegistererWithPrefix("coder_aibridged_", coderAPI.PrometheusRegistry) metrics := aibridge.NewMetrics(reg) + providerMetrics := aibridged.NewMetrics(reg) tracer := coderAPI.TracerProvider.Tracer(tracing.TracerName) // Create pool for reusable stateful [aibridge.RequestBridge] instances (one per user). @@ -50,10 +52,11 @@ func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider, cfg // derives from env config and serves as a fallback if the database // load fails inside the reloader. reloader := &poolDBReloader{ - pool: pool, - db: coderAPI.Database, - cfg: cfg, - logger: logger.Named("provider-loader"), + pool: pool, + db: coderAPI.Database, + cfg: cfg, + logger: logger.Named("provider-loader"), + metrics: providerMetrics, } unsubscribe, err := aibridged.SubscribeProviderReload(ctx, coderAPI.Pubsub, reloader, logger.Named("provider-reload")) if err != nil { @@ -78,14 +81,16 @@ func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider, cfg // the live provider set from the database and forwarding it to the // pool. type poolDBReloader struct { - pool *aibridged.CachedBridgePool - db database.Store - cfg codersdk.AIBridgeConfig - logger slog.Logger + pool *aibridged.CachedBridgePool + db database.Store + cfg codersdk.AIBridgeConfig + logger slog.Logger + metrics *aibridged.Metrics } func (r *poolDBReloader) Reload(ctx context.Context) error { - providers, err := BuildProviders(ctx, r.db, r.cfg, r.logger) + r.metrics.RecordReloadAttempt() + providers, outcomes, err := BuildProviders(ctx, r.db, r.cfg, r.logger) if err != nil { // Keep the previous snapshot in place: dropping all providers // because the DB read failed would compound the visible failure @@ -93,19 +98,23 @@ func (r *poolDBReloader) Reload(ctx context.Context) error { return xerrors.Errorf("load ai providers from database: %w", err) } r.pool.ReplaceProviders(providers) + r.metrics.RecordReloadSuccess(outcomes) return nil } -// BuildProviders loads every enabled ai_providers row, attaches its -// keys, and constructs the equivalent [aibridge.Provider] instances. -// The database is the single source of truth for runtime provider -// configuration. +// BuildProviders loads all ai_providers rows (enabled and disabled), +// attaches keys to enabled rows, and constructs the equivalent +// [aibridge.Provider] instances. The database is the single source of +// truth for runtime provider configuration. +// +// Disabled rows produce a Provider stub with Enabled() == false so the +// bridge can answer requests targeting them with a 503 sentinel. // // Per-provider construction errors are logged and the offending row is // excluded from the returned snapshot; only a failure of the DB query // itself is propagated. This keeps a single misconfigured row from // taking the whole daemon down. -func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridgeConfig, logger slog.Logger) ([]aibridge.Provider, error) { +func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridgeConfig, logger slog.Logger) ([]aibridge.Provider, []aibridged.ProviderOutcome, error) { //nolint:gocritic // AsAIBridged has a minimal permission set for this purpose. authCtx := dbauthz.AsAIBridged(ctx) @@ -117,7 +126,7 @@ func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridg err := db.InTx(func(tx database.Store) error { var err error rows, err = tx.GetAIProviders(authCtx, database.GetAIProvidersParams{ - IncludeDisabled: false, + IncludeDisabled: true, }) if err != nil { return xerrors.Errorf("load ai providers: %w", err) @@ -129,9 +138,15 @@ func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridg // Load keys only for the enabled providers to avoid materializing // secrets for disabled rows. - ids := make([]uuid.UUID, len(rows)) - for i, r := range rows { - ids[i] = r.ID + ids := make([]uuid.UUID, 0, len(rows)) + for _, r := range rows { + if !r.Enabled { + continue + } + ids = append(ids, r.ID) + } + if len(ids) == 0 { + return nil } keyRows, err := tx.GetAIProviderKeysByProviderIDs(authCtx, ids) if err != nil { @@ -143,13 +158,25 @@ func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridg return nil }, &database.TxOptions{ReadOnly: true, TxIdentifier: "build_ai_providers"}) if err != nil { - return nil, err + return nil, nil, err } - out := make([]aibridge.Provider, 0, len(rows)) + providers := make([]aibridge.Provider, 0, len(rows)) + outcomes := make([]aibridged.ProviderOutcome, 0, len(rows)) + enabledCount := 0 for _, row := range rows { + outcome := aibridged.ProviderOutcome{ + Name: row.Name, + Type: string(row.Type), + } + if row.Enabled { + enabledCount++ + } prov, err := buildAIProviderFromRow(row, keysByProvider[row.ID], cfg) if err != nil { + outcome.Status = aibridged.ProviderStatusError + outcome.Err = err + outcomes = append(outcomes, outcome) logger.Error(ctx, "skipping misconfigured ai provider", slog.F("provider_id", row.ID), slog.F("provider_name", row.Name), @@ -158,23 +185,36 @@ func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridg ) continue } - out = append(out, prov) + if row.Enabled { + outcome.Status = aibridged.ProviderStatusEnabled + } else { + outcome.Status = aibridged.ProviderStatusDisabled + } + outcomes = append(outcomes, outcome) + providers = append(providers, prov) } - if len(rows) > 0 && len(out) == 0 { - logger.Warn(ctx, "all enabled ai providers failed to build; daemon will start with zero providers") + if enabledCount > 0 && !slices.ContainsFunc(providers, func(p aibridge.Provider) bool { return p.Enabled() }) { + logger.Warn(ctx, "all enabled ai providers failed to build; only disabled providers remain") } - return out, nil + return providers, outcomes, nil } // buildAIProviderFromRow decodes the settings blob and constructs the // appropriate [aibridge.Provider] for a single ai_providers row. +// Disabled rows return a Provider stub carrying only Name and +// Disabled: true; settings decode, key loading, and credential checks +// are skipped because the provider will never call upstream. func buildAIProviderFromRow( row database.AIProvider, keys []database.AIProviderKey, cfg codersdk.AIBridgeConfig, ) (aibridge.Provider, error) { + if !row.Enabled { + return disabledProviderFromRow(row) + } + settings, err := db2sdk.AIProviderSettings(row.Settings) if err != nil { return nil, xerrors.Errorf("decode settings: %w", err) @@ -184,17 +224,28 @@ func buildAIProviderFromRow( sendActorHeaders := cfg.SendActorHeaders.Value() dumpDir := cfg.APIDumpDir.Value() + // aibridge currently has native support for OpenAI and Anthropic + // only. The other ai_provider_type values (azure, google, + // openai-compat, openrouter, vercel) route through the OpenAI + // provider because chatd configures them against their + // OpenAI-compatible endpoints. Bedrock routes through the Anthropic + // provider with a Bedrock discriminator in Settings. switch row.Type { - case database.AiProviderTypeOpenai: + case database.AiProviderTypeOpenai, + database.AiProviderTypeAzure, + database.AiProviderTypeGoogle, + database.AiProviderTypeOpenaiCompat, + database.AiProviderTypeOpenrouter, + database.AiProviderTypeVercel: if len(keys) == 0 && !cfg.AllowBYOK.Value() { - return nil, xerrors.New("openai provider has no api keys configured and BYOK is not enabled") + return nil, xerrors.Errorf("%s provider has no api keys configured and BYOK is not enabled", row.Type) } var pool *keypool.Pool if len(keys) > 0 { var err error pool, err = buildAIProviderKeyPool(keys) if err != nil { - return nil, xerrors.Errorf("openai key pool: %w", err) + return nil, xerrors.Errorf("%s key pool: %w", row.Type, err) } } return aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{ @@ -206,8 +257,15 @@ func buildAIProviderFromRow( SendActorHeaders: sendActorHeaders, }), nil - case database.AiProviderTypeAnthropic: + case database.AiProviderTypeAnthropic, database.AiProviderTypeBedrock: bedrock := bedrockConfigFromRow(row, settings) + // A row typed 'bedrock' authenticates exclusively via settings; + // without populated Bedrock credentials it cannot make upstream + // calls, so refuse rather than falling back to an unsigned + // Anthropic client. + if row.Type == database.AiProviderTypeBedrock && bedrock == nil { + return nil, xerrors.New("bedrock provider has no bedrock credentials configured") + } // Bedrock-backed Anthropic authenticates via AWS credentials in // the settings blob, not the api_keys table. A bearer-token // Anthropic without any key cannot make upstream calls. @@ -246,6 +304,14 @@ func buildAIProviderFromRow( } } +// disabledProviderFromRow builds a Provider stub for a disabled row. +// Using provider.DisabledStub rather than a concrete provider avoids +// duplicating the row.Type switch and ensures that a new AiProviderType +// value is automatically handled without requiring a matching case here. +func disabledProviderFromRow(row database.AIProvider) (aibridge.Provider, error) { + return aibridge.NewDisabledProviderStub(row.Name, string(row.Type)), nil +} + // buildAIProviderKeyPool builds a [keypool.Pool]. Callers must check // len(keys) > 0 first; keypool.New rejects empty input. func buildAIProviderKeyPool(keys []database.AIProviderKey) (*keypool.Pool, error) { diff --git a/cli/aibridged_internal_test.go b/cli/aibridged_internal_test.go index 6aa0608de8..e82a228c67 100644 --- a/cli/aibridged_internal_test.go +++ b/cli/aibridged_internal_test.go @@ -13,6 +13,7 @@ import ( "github.com/coder/coder/v2/aibridge" "github.com/coder/coder/v2/coderd" agplaibridge "github.com/coder/coder/v2/coderd/aibridge" + "github.com/coder/coder/v2/coderd/aibridged" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" @@ -35,7 +36,8 @@ func buildFromEnv(t *testing.T, cfg codersdk.AIBridgeConfig) ([]aibridge.Provide if err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, logger); err != nil { return nil, err } - return BuildProviders(ctx, db, cfg, logger) + providers, _, err := BuildProviders(ctx, db, cfg, logger) + return providers, err } func TestBuildProviders(t *testing.T) { @@ -323,28 +325,35 @@ func TestBuildProvidersSkipsBadRows(t *testing.T) { Settings: sql.NullString{String: "not-json", Valid: true}, }) - providers, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger) + providers, outcomes, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger) require.NoError(t, err) assert.Empty(t, providers) + require.Len(t, outcomes, 1) + assert.Equal(t, "anthropic-broken", outcomes[0].Name) + assert.Equal(t, aibridged.ProviderStatusError, outcomes[0].Status) + assert.Error(t, outcomes[0].Err) }) - t.Run("UnsupportedType", func(t *testing.T) { + t.Run("EnabledButNoKeys", func(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - // Azure is a valid DB-level provider type but has no runtime - // builder yet; it must hit the default branch and be skipped. + // Azure routes through the OpenAI-family builder, which rejects + // rows without keys when BYOK is disabled. The row must be + // classified as error and excluded from the snapshot. dbgen.AIProvider(t, db, database.AIProvider{ Type: database.AiProviderTypeAzure, Name: "azure-openai", BaseUrl: "https://example.openai.azure.com/", }) - providers, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger) + providers, outcomes, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger) require.NoError(t, err) assert.Empty(t, providers) + require.Len(t, outcomes, 1) + assert.Equal(t, aibridged.ProviderStatusError, outcomes[0].Status) }) t.Run("BadRowDoesNotBlockGoodRow", func(t *testing.T) { @@ -369,10 +378,75 @@ func TestBuildProvidersSkipsBadRows(t *testing.T) { APIKey: "sk-good", }) - providers, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger) + providers, outcomes, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger) require.NoError(t, err) require.Len(t, providers, 1) assert.Equal(t, "openai-good", providers[0].Name()) + require.Len(t, outcomes, 2) + byName := map[string]aibridged.ProviderOutcome{} + for _, o := range outcomes { + byName[o.Name] = o + } + assert.Equal(t, aibridged.ProviderStatusError, byName["anthropic-broken"].Status) + assert.Equal(t, aibridged.ProviderStatusEnabled, byName["openai-good"].Status) + }) + + t.Run("DisabledRowClassifiedAsDisabled", func(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + row database.AIProvider + }{ + { + name: "OpenAI", + row: database.AIProvider{ + Type: database.AiProviderTypeOpenai, + Name: "openai-off", + BaseUrl: "https://api.openai.com/", + }, + }, + { + // Anthropic and Bedrock have stricter credential checks + // than the OpenAI family; the disabled short-circuit + // must reach them too. No keys, no bedrock settings. + name: "Anthropic", + row: database.AIProvider{ + Type: database.AiProviderTypeAnthropic, + Name: "anthropic-off", + BaseUrl: "https://api.anthropic.com/", + }, + }, + { + name: "Bedrock", + row: database.AIProvider{ + Type: database.AiProviderTypeBedrock, + Name: "bedrock-off", + BaseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com/", + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil) + + dbgen.AIProvider(t, db, tc.row, func(p *database.InsertAIProviderParams) { + p.Enabled = false + }) + + providers, outcomes, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger) + require.NoError(t, err) + require.Len(t, providers, 1, "disabled providers stay in the snapshot so the bridge can serve a 503 sentinel") + assert.Equal(t, tc.row.Name, providers[0].Name()) + assert.False(t, providers[0].Enabled()) + require.Len(t, outcomes, 1) + assert.Equal(t, tc.row.Name, outcomes[0].Name) + assert.Equal(t, aibridged.ProviderStatusDisabled, outcomes[0].Status) + assert.NoError(t, outcomes[0].Err) + }) + } }) } diff --git a/cli/clitest/clitest_test.go b/cli/clitest/clitest_test.go index c214981387..d683af8d34 100644 --- a/cli/clitest/clitest_test.go +++ b/cli/clitest/clitest_test.go @@ -7,8 +7,8 @@ import ( "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/coderd/coderdtest" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" ) func TestMain(m *testing.M) { @@ -17,11 +17,12 @@ func TestMain(m *testing.M) { func TestCli(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) clitest.CreateTemplateVersionSource(t, nil) client := coderdtest.New(t, nil) i, config := clitest.New(t) clitest.SetupConfig(t, client, config) - pty := ptytest.New(t).Attach(i) + stdout := expecter.NewAttachedToInvocation(t, i) clitest.Start(t, i) - pty.ExpectMatch("coder") + stdout.ExpectMatchContext(ctx, "coder") } diff --git a/cli/cliui/externalauth_test.go b/cli/cliui/externalauth_test.go index 1482aacc2d..3a7359a485 100644 --- a/cli/cliui/externalauth_test.go +++ b/cli/cliui/externalauth_test.go @@ -10,8 +10,8 @@ import ( "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" "github.com/coder/serpent" ) @@ -21,7 +21,6 @@ func TestExternalAuth(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - ptty := ptytest.New(t) cmd := &serpent.Command{ Handler: func(inv *serpent.Invocation) error { var fetched atomic.Bool @@ -42,16 +41,16 @@ func TestExternalAuth(t *testing.T) { } inv := cmd.Invoke().WithContext(ctx) + stdout := expecter.NewAttachedToInvocation(t, inv) - ptty.Attach(inv) done := make(chan struct{}) go func() { defer close(done) err := inv.Run() assert.NoError(t, err) }() - ptty.ExpectMatchContext(ctx, "You must authenticate with") - ptty.ExpectMatchContext(ctx, "https://example.com/gitauth/github") - ptty.ExpectMatchContext(ctx, "Successfully authenticated with GitHub") + stdout.ExpectMatchContext(ctx, "You must authenticate with") + stdout.ExpectMatchContext(ctx, "https://example.com/gitauth/github") + stdout.ExpectMatchContext(ctx, "Successfully authenticated with GitHub") <-done } diff --git a/cli/cliui/provisionerjob_test.go b/cli/cliui/provisionerjob_test.go index 304e0608b8..b2ad8eb293 100644 --- a/cli/cliui/provisionerjob_test.go +++ b/cli/cliui/provisionerjob_test.go @@ -16,8 +16,8 @@ import ( "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" "github.com/coder/serpent" ) @@ -48,12 +48,12 @@ func TestProvisionerJob(t *testing.T) { test.JobMutex.Unlock() }) testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) { - test.PTY.ExpectMatch(cliui.ProvisioningStateQueued) + test.Stdout.ExpectMatchContext(ctx, cliui.ProvisioningStateQueued) test.Next <- struct{}{} - test.PTY.ExpectMatch(cliui.ProvisioningStateQueued) - test.PTY.ExpectMatch(cliui.ProvisioningStateRunning) + test.Stdout.ExpectMatchContext(ctx, cliui.ProvisioningStateQueued) + test.Stdout.ExpectMatchContext(ctx, cliui.ProvisioningStateRunning) test.Next <- struct{}{} - test.PTY.ExpectMatch(cliui.ProvisioningStateRunning) + test.Stdout.ExpectMatchContext(ctx, cliui.ProvisioningStateRunning) return true }, testutil.IntervalFast) }) @@ -85,12 +85,12 @@ func TestProvisionerJob(t *testing.T) { test.JobMutex.Unlock() }) testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) { - test.PTY.ExpectMatch(cliui.ProvisioningStateQueued) + test.Stdout.ExpectMatchContext(ctx, cliui.ProvisioningStateQueued) test.Next <- struct{}{} - test.PTY.ExpectMatch(cliui.ProvisioningStateQueued) - test.PTY.ExpectMatch("Something") + test.Stdout.ExpectMatchContext(ctx, cliui.ProvisioningStateQueued) + test.Stdout.ExpectMatchContext(ctx, "Something") test.Next <- struct{}{} - test.PTY.ExpectMatch("Something") + test.Stdout.ExpectMatchContext(ctx, "Something") return true }, testutil.IntervalFast) }) @@ -151,12 +151,12 @@ func TestProvisionerJob(t *testing.T) { test.JobMutex.Unlock() }) testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) { - test.PTY.ExpectRegexMatch(tc.expected) + test.Stdout.ExpectRegexMatchContext(ctx, tc.expected) test.Next <- struct{}{} - test.PTY.ExpectMatch(cliui.ProvisioningStateQueued) // step completed - test.PTY.ExpectMatch(cliui.ProvisioningStateRunning) + test.Stdout.ExpectMatchContext(ctx, cliui.ProvisioningStateQueued) // step completed + test.Stdout.ExpectMatchContext(ctx, cliui.ProvisioningStateRunning) test.Next <- struct{}{} - test.PTY.ExpectMatch(cliui.ProvisioningStateRunning) + test.Stdout.ExpectMatchContext(ctx, cliui.ProvisioningStateRunning) return true }, testutil.IntervalFast) }) @@ -193,11 +193,11 @@ func TestProvisionerJob(t *testing.T) { test.JobMutex.Unlock() }) testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) { - test.PTY.ExpectMatch(cliui.ProvisioningStateQueued) + test.Stdout.ExpectMatchContext(ctx, cliui.ProvisioningStateQueued) test.Next <- struct{}{} - test.PTY.ExpectMatch("Gracefully canceling") + test.Stdout.ExpectMatchContext(ctx, "Gracefully canceling") test.Next <- struct{}{} - test.PTY.ExpectMatch(cliui.ProvisioningStateQueued) + test.Stdout.ExpectMatchContext(ctx, cliui.ProvisioningStateQueued) return true }, testutil.IntervalFast) }) @@ -208,7 +208,7 @@ type provisionerJobTest struct { Job *codersdk.ProvisionerJob JobMutex *sync.Mutex Logs chan codersdk.ProvisionerJobLog - PTY *ptytest.PTY + Stdout *expecter.Expecter } func newProvisionerJob(t *testing.T) provisionerJobTest { @@ -240,8 +240,7 @@ func newProvisionerJob(t *testing.T) provisionerJobTest { } inv := cmd.Invoke() - ptty := ptytest.New(t) - ptty.Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) done := make(chan struct{}) go func() { defer close(done) @@ -258,7 +257,7 @@ func newProvisionerJob(t *testing.T) provisionerJobTest { Job: job, JobMutex: &jobLock, Logs: logs, - PTY: ptty, + Stdout: stdout, } } diff --git a/cli/cliui/select_test.go b/cli/cliui/select_test.go index 55ab81f50f..d532ff19eb 100644 --- a/cli/cliui/select_test.go +++ b/cli/cliui/select_test.go @@ -8,7 +8,6 @@ import ( "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/serpent" ) @@ -16,10 +15,9 @@ func TestSelect(t *testing.T) { t.Parallel() t.Run("Select", func(t *testing.T) { t.Parallel() - ptty := ptytest.New(t) msgChan := make(chan string) go func() { - resp, err := newSelect(ptty, cliui.SelectOptions{ + resp, err := newSelect(cliui.SelectOptions{ Options: []string{"First", "Second"}, }) assert.NoError(t, err) @@ -29,7 +27,7 @@ func TestSelect(t *testing.T) { }) } -func newSelect(ptty *ptytest.PTY, opts cliui.SelectOptions) (string, error) { +func newSelect(opts cliui.SelectOptions) (string, error) { value := "" cmd := &serpent.Command{ Handler: func(inv *serpent.Invocation) error { @@ -39,7 +37,6 @@ func newSelect(ptty *ptytest.PTY, opts cliui.SelectOptions) (string, error) { }, } inv := cmd.Invoke() - ptty.Attach(inv) return value, inv.Run() } @@ -47,10 +44,10 @@ func TestRichSelect(t *testing.T) { t.Parallel() t.Run("RichSelect", func(t *testing.T) { t.Parallel() - ptty := ptytest.New(t) + msgChan := make(chan string) go func() { - resp, err := newRichSelect(ptty, cliui.RichSelectOptions{ + resp, err := newRichSelect(cliui.RichSelectOptions{ Options: []codersdk.TemplateVersionParameterOption{ {Name: "A-Name", Value: "A-Value", Description: "A-Description."}, {Name: "B-Name", Value: "B-Value", Description: "B-Description."}, @@ -63,7 +60,7 @@ func TestRichSelect(t *testing.T) { }) } -func newRichSelect(ptty *ptytest.PTY, opts cliui.RichSelectOptions) (string, error) { +func newRichSelect(opts cliui.RichSelectOptions) (string, error) { value := "" cmd := &serpent.Command{ Handler: func(inv *serpent.Invocation) error { @@ -75,7 +72,6 @@ func newRichSelect(ptty *ptytest.PTY, opts cliui.RichSelectOptions) (string, err }, } inv := cmd.Invoke() - ptty.Attach(inv) return value, inv.Run() } @@ -181,11 +177,10 @@ func TestMultiSelect(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - ptty := ptytest.New(t) msgChan := make(chan []string) go func() { - resp, err := newMultiSelect(ptty, tt.items, tt.allowCustom) + resp, err := newMultiSelect(tt.items, tt.allowCustom) assert.NoError(t, err) msgChan <- resp }() @@ -195,7 +190,7 @@ func TestMultiSelect(t *testing.T) { } } -func newMultiSelect(pty *ptytest.PTY, items []string, custom bool) ([]string, error) { +func newMultiSelect(items []string, custom bool) ([]string, error) { var values []string cmd := &serpent.Command{ Handler: func(inv *serpent.Invocation) error { @@ -211,6 +206,5 @@ func newMultiSelect(pty *ptytest.PTY, items []string, custom bool) ([]string, er }, } inv := cmd.Invoke() - pty.Attach(inv) return values, inv.Run() } diff --git a/cli/configssh_test.go b/cli/configssh_test.go index 7e42bfe81a..61588e4fb9 100644 --- a/cli/configssh_test.go +++ b/cli/configssh_test.go @@ -24,8 +24,8 @@ import ( "github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" ) func sshConfigFileName(t *testing.T) (sshConfig string) { @@ -64,6 +64,8 @@ func TestConfigSSH(t *testing.T) { t.Skip("See coder/internal#117") } + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) const hostname = "test-coder." const expectedKey = "ConnectionAttempts" const removeKey = "ConnectTimeout" @@ -131,9 +133,8 @@ func TestConfigSSH(t *testing.T) { "--ssh-config-file", sshConfigFile, "--skip-proxy-command") clitest.SetupConfig(t, member, root) - pty := ptytest.New(t) - inv.Stdin = pty.Input() - inv.Stdout = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) waiter := clitest.StartWithWaiter(t, inv) @@ -143,8 +144,8 @@ func TestConfigSSH(t *testing.T) { {match: "Continue?", write: "yes"}, } for _, m := range matches { - pty.ExpectMatch(m.match) - pty.WriteLine(m.write) + stdout.ExpectMatchContext(ctx, m.match) + stdin.WriteLine(m.write) } waiter.RequireSuccess() @@ -157,10 +158,8 @@ func TestConfigSSH(t *testing.T) { home := filepath.Dir(filepath.Dir(sshConfigFile)) // #nosec sshCmd := exec.Command("ssh", "-F", sshConfigFile, hostname+r.Workspace.Name, "echo", "test") - pty = ptytest.New(t) // Set HOME because coder config is included from ~/.ssh/coder. sshCmd.Env = append(sshCmd.Env, fmt.Sprintf("HOME=%s", home)) - inv.Stderr = pty.Output() data, err := sshCmd.Output() require.NoError(t, err) require.Equal(t, "test", strings.TrimSpace(string(data))) @@ -693,6 +692,8 @@ func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client, db := coderdtest.NewWithDatabase(t, nil) user := coderdtest.CreateFirstUser(t, client) @@ -718,8 +719,8 @@ func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) { //nolint:gocritic // This has always ran with the admin user. clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - pty.Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) done := tGo(t, func() { err := inv.Run() if !tt.wantErr { @@ -730,8 +731,8 @@ func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) { }) for _, m := range tt.matches { - pty.ExpectMatch(m.match) - pty.WriteLine(m.write) + stdout.ExpectMatchContext(ctx, m.match) + stdin.WriteLine(m.write) } <-done diff --git a/cli/create_test.go b/cli/create_test.go index 670f785791..043148d178 100644 --- a/cli/create_test.go +++ b/cli/create_test.go @@ -20,8 +20,8 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisioner/echo" "github.com/coder/coder/v2/provisionersdk/proto" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" ) func TestCreateDynamic(t *testing.T) { @@ -74,14 +74,14 @@ func TestCreateDynamic(t *testing.T) { } inv, root := clitest.New(t, args...) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) doneChan := make(chan error) go func() { doneChan <- inv.Run() }() - pty.ExpectMatchContext(ctx, "has been created") + stdout.ExpectMatchContext(ctx, "has been created") err := testutil.RequireReceive(ctx, t, doneChan) require.NoError(t, err) @@ -103,14 +103,14 @@ func TestCreateDynamic(t *testing.T) { } inv, root = clitest.New(t, args...) clitest.SetupConfig(t, member, root) - pty = ptytest.New(t).Attach(inv) + stdout = expecter.NewAttachedToInvocation(t, inv) doneChan = make(chan error) go func() { doneChan <- inv.Run() }() - pty.ExpectMatchContext(ctx, "has been created") + stdout.ExpectMatchContext(ctx, "has been created") err = testutil.RequireReceive(ctx, t, doneChan) require.NoError(t, err) @@ -129,7 +129,8 @@ func TestCreateDynamic(t *testing.T) { // When enable_region=true, the region parameter becomes required and CLI should prompt. t.Run("PromptForConditionalParam", func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitLong) + ctx := testutil.Context(t, time.Hour) + logger := testutil.Logger(t) template, _ := coderdtest.DynamicParameterTemplate(t, owner, first.OrganizationID, coderdtest.DynamicParameterTemplateParams{ MainTF: conditionalParamTF, @@ -143,7 +144,8 @@ func TestCreateDynamic(t *testing.T) { } inv, root := clitest.New(t, args...) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) doneChan := make(chan error) go func() { @@ -151,14 +153,14 @@ func TestCreateDynamic(t *testing.T) { }() // CLI should prompt for the region parameter since enable_region=true - pty.ExpectMatchContext(ctx, "region") - pty.WriteLine("eu-west") + stdout.ExpectMatchContext(ctx, "region") + stdin.WriteLine("eu-west") // Confirm creation - pty.ExpectMatchContext(ctx, "Confirm create?") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "Confirm create?") + stdin.WriteLine("yes") - pty.ExpectMatchContext(ctx, "has been created") + stdout.ExpectMatchContext(ctx, "has been created") err := <-doneChan require.NoError(t, err) @@ -305,14 +307,14 @@ func TestCreateDynamic(t *testing.T) { "-y", ) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) doneChan := make(chan error) go func() { doneChan <- inv.Run() }() - pty.ExpectMatchContext(ctx, "has been created") + stdout.ExpectMatchContext(ctx, "has been created") err = <-doneChan require.NoError(t, err, "slider=8 should succeed when max_slider=10") @@ -331,6 +333,8 @@ func TestCreate(t *testing.T) { t.Parallel() t.Run("Create", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) @@ -348,7 +352,8 @@ func TestCreate(t *testing.T) { inv, root := clitest.New(t, args...) clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() @@ -363,9 +368,9 @@ func TestCreate(t *testing.T) { {match: "Confirm create", write: "yes"}, } for _, m := range matches { - pty.ExpectMatch(m.match) + stdout.ExpectMatchContext(ctx, m.match) if len(m.write) > 0 { - pty.WriteLine(m.write) + stdin.WriteLine(m.write) } } <-doneChan @@ -385,6 +390,8 @@ func TestCreate(t *testing.T) { t.Run("CreateForOtherUser", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, completeWithAgent()) @@ -403,7 +410,8 @@ func TestCreate(t *testing.T) { //nolint:gocritic // Creating a workspace for another user requires owner permissions. clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() @@ -418,9 +426,9 @@ func TestCreate(t *testing.T) { {match: "Confirm create", write: "yes"}, } for _, m := range matches { - pty.ExpectMatch(m.match) + stdout.ExpectMatchContext(ctx, m.match) if len(m.write) > 0 { - pty.WriteLine(m.write) + stdin.WriteLine(m.write) } } <-doneChan @@ -439,6 +447,8 @@ func TestCreate(t *testing.T) { t.Run("CreateWithSpecificTemplateVersion", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) @@ -467,7 +477,8 @@ func TestCreate(t *testing.T) { inv, root := clitest.New(t, args...) clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() @@ -482,9 +493,9 @@ func TestCreate(t *testing.T) { {match: "Confirm create", write: "yes"}, } for _, m := range matches { - pty.ExpectMatch(m.match) + stdout.ExpectMatchContext(ctx, m.match) if len(m.write) > 0 { - pty.WriteLine(m.write) + stdin.WriteLine(m.write) } } <-doneChan @@ -506,6 +517,8 @@ func TestCreate(t *testing.T) { t.Run("InheritStopAfterFromTemplate", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) @@ -522,7 +535,8 @@ func TestCreate(t *testing.T) { } inv, root := clitest.New(t, args...) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) waiter := clitest.StartWithWaiter(t, inv) matches := []struct { match string @@ -533,9 +547,9 @@ func TestCreate(t *testing.T) { {match: "Confirm create", write: "yes"}, } for _, m := range matches { - pty.ExpectMatch(m.match) + stdout.ExpectMatchContext(ctx, m.match) if len(m.write) > 0 { - pty.WriteLine(m.write) + stdin.WriteLine(m.write) } } waiter.RequireSuccess() @@ -570,6 +584,8 @@ func TestCreate(t *testing.T) { t.Run("FromNothing", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) @@ -579,7 +595,8 @@ func TestCreate(t *testing.T) { inv, root := clitest.New(t, "create", "") clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() @@ -592,8 +609,8 @@ func TestCreate(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) - pty.WriteLine(value) + stdout.ExpectMatchContext(ctx, match) + stdin.WriteLine(value) } <-doneChan @@ -621,14 +638,14 @@ func TestCreate(t *testing.T) { ) clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) go func() { defer close(doneChan) err := inv.Run() assert.NoError(t, err) }() - pty.ExpectMatchContext(ctx, "building in the background") + stdout.ExpectMatchContext(ctx, "building in the background") _ = testutil.TryReceive(ctx, t, doneChan) // Verify workspace was actually created. @@ -658,14 +675,14 @@ func TestCreate(t *testing.T) { ) clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) go func() { defer close(doneChan) err := inv.Run() assert.NoError(t, err) }() - pty.ExpectMatchContext(ctx, "building in the background") + stdout.ExpectMatchContext(ctx, "building in the background") _ = testutil.TryReceive(ctx, t, doneChan) // Verify workspace was created and parameters were applied. @@ -706,14 +723,14 @@ func TestCreate(t *testing.T) { ) clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) go func() { defer close(doneChan) err := inv.Run() assert.NoError(t, err) }() - pty.ExpectMatchContext(ctx, "building in the background") + stdout.ExpectMatchContext(ctx, "building in the background") _ = testutil.TryReceive(ctx, t, doneChan) ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{}) @@ -801,7 +818,7 @@ func TestCreateWithRichParameters(t *testing.T) { setup func() []string // handlePty optionally runs after the command is started. It should handle // all expected prompts from the pty. - handlePty func(pty *ptytest.PTY) + handlePty func(ctx context.Context, stdout *expecter.Expecter, stdin *testutil.Writer) // postRun runs after the command has finished but before the workspace is // verified. It must return the workspace name to check (used for the copy // workspace tests). @@ -818,15 +835,15 @@ func TestCreateWithRichParameters(t *testing.T) { }{ { name: "ValuesFromPrompt", - handlePty: func(pty *ptytest.PTY) { + handlePty: func(ctx context.Context, stdout *expecter.Expecter, stdin *testutil.Writer) { // Enter the value for each parameter as prompted. for _, param := range params { - pty.ExpectMatch(param.name) - pty.WriteLine(param.value) + stdout.ExpectMatchContext(ctx, param.name) + stdin.WriteLine(param.value) } // Confirm the creation. - pty.ExpectMatch("Confirm create?") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "Confirm create?") + stdin.WriteLine("yes") }, }, { @@ -839,16 +856,16 @@ func TestCreateWithRichParameters(t *testing.T) { } return args }, - handlePty: func(pty *ptytest.PTY) { + handlePty: func(ctx context.Context, stdout *expecter.Expecter, stdin *testutil.Writer) { // Simply accept the defaults. for _, param := range params { - pty.ExpectMatch(param.name) - pty.ExpectMatch(`Enter a value (default: "` + param.value + `")`) - pty.WriteLine("") + stdout.ExpectMatchContext(ctx, param.name) + stdout.ExpectMatchContext(ctx, `Enter a value (default: "`+param.value+`")`) + stdin.WriteLine("") } // Confirm the creation. - pty.ExpectMatch("Confirm create?") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "Confirm create?") + stdin.WriteLine("yes") }, }, { @@ -865,10 +882,10 @@ func TestCreateWithRichParameters(t *testing.T) { return []string{"--rich-parameter-file", parameterFile.Name()} }, - handlePty: func(pty *ptytest.PTY) { + handlePty: func(ctx context.Context, stdout *expecter.Expecter, stdin *testutil.Writer) { // No prompts, we only need to confirm. - pty.ExpectMatch("Confirm create?") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "Confirm create?") + stdin.WriteLine("yes") }, }, { @@ -881,10 +898,10 @@ func TestCreateWithRichParameters(t *testing.T) { } return args }, - handlePty: func(pty *ptytest.PTY) { + handlePty: func(ctx context.Context, stdout *expecter.Expecter, stdin *testutil.Writer) { // No prompts, we only need to confirm. - pty.ExpectMatch("Confirm create?") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "Confirm create?") + stdin.WriteLine("yes") }, }, { @@ -920,9 +937,6 @@ func TestCreateWithRichParameters(t *testing.T) { postRun: func(t *testing.T, tctx testContext) string { inv, root := clitest.New(t, "create", "--copy-parameters-from", tctx.workspaceName, "other-workspace", "-y") clitest.SetupConfig(t, tctx.member, root) - pty := ptytest.New(t).Attach(inv) - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() err := inv.Run() require.NoError(t, err, "failed to create a workspace based on the source workspace") return "other-workspace" @@ -952,9 +966,6 @@ func TestCreateWithRichParameters(t *testing.T) { // Then create the copy. It should use the old template version. inv, root := clitest.New(t, "create", "--copy-parameters-from", tctx.workspaceName, "other-workspace", "-y") clitest.SetupConfig(t, tctx.member, root) - pty := ptytest.New(t).Attach(inv) - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() err := inv.Run() require.NoError(t, err, "failed to create a workspace based on the source workspace") return "other-workspace" @@ -962,16 +973,16 @@ func TestCreateWithRichParameters(t *testing.T) { }, { name: "ValuesFromTemplateDefaults", - handlePty: func(pty *ptytest.PTY) { + handlePty: func(ctx context.Context, stdout *expecter.Expecter, stdin *testutil.Writer) { // Simply accept the defaults. for _, param := range params { - pty.ExpectMatch(param.name) - pty.ExpectMatch(`Enter a value (default: "` + param.value + `")`) - pty.WriteLine("") + stdout.ExpectMatchContext(ctx, param.name) + stdout.ExpectMatchContext(ctx, `Enter a value (default: "`+param.value+`")`) + stdin.WriteLine("") } // Confirm the creation. - pty.ExpectMatch("Confirm create?") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "Confirm create?") + stdin.WriteLine("yes") }, withDefaults: true, }, @@ -980,14 +991,14 @@ func TestCreateWithRichParameters(t *testing.T) { setup: func() []string { return []string{"--use-parameter-defaults"} }, - handlePty: func(pty *ptytest.PTY) { + handlePty: func(ctx context.Context, stdout *expecter.Expecter, stdin *testutil.Writer) { // Default values should get printed. for _, param := range params { - pty.ExpectMatch(fmt.Sprintf("%s: '%s'", param.name, param.value)) + stdout.ExpectMatchContext(ctx, fmt.Sprintf("%s: '%s'", param.name, param.value)) } // No prompts, we only need to confirm. - pty.ExpectMatch("Confirm create?") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "Confirm create?") + stdin.WriteLine("yes") }, withDefaults: true, }, @@ -1001,14 +1012,14 @@ func TestCreateWithRichParameters(t *testing.T) { } return args }, - handlePty: func(pty *ptytest.PTY) { + handlePty: func(ctx context.Context, stdout *expecter.Expecter, stdin *testutil.Writer) { // Default values should get printed. for _, param := range params { - pty.ExpectMatch(fmt.Sprintf("%s: '%s'", param.name, param.value)) + stdout.ExpectMatchContext(ctx, fmt.Sprintf("%s: '%s'", param.name, param.value)) } // No prompts, we only need to confirm. - pty.ExpectMatch("Confirm create?") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "Confirm create?") + stdin.WriteLine("yes") }, }, { @@ -1031,14 +1042,14 @@ cli_param: from file`) "--parameter", "cli_param=from cli", } }, - handlePty: func(pty *ptytest.PTY) { + handlePty: func(ctx context.Context, stdout *expecter.Expecter, stdin *testutil.Writer) { // Should get prompted for the input param since it has no default. - pty.ExpectMatch("input_param") - pty.WriteLine("from input") + stdout.ExpectMatchContext(ctx, "input_param") + stdin.WriteLine("from input") // Confirm the creation. - pty.ExpectMatch("Confirm create?") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "Confirm create?") + stdin.WriteLine("yes") }, withDefaults: true, inputParameters: []param{ @@ -1082,6 +1093,8 @@ cli_param: from file`) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) parameters := params if len(tt.inputParameters) > 0 { @@ -1122,14 +1135,15 @@ cli_param: from file`) inv, root := clitest.New(t, args...) clitest.SetupConfig(t, member, root) doneChan := make(chan error) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { doneChan <- inv.Run() }() // The test may do something with the pty. if tt.handlePty != nil { - tt.handlePty(pty) + tt.handlePty(ctx, stdout, stdin) } // Wait for the command to exit. @@ -1235,6 +1249,7 @@ func TestCreateWithPreset(t *testing.T) { // the CLI uses the specified preset instead of the default t.Run("PresetFlag", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) @@ -1263,17 +1278,15 @@ func TestCreateWithPreset(t *testing.T) { workspaceName := "my-workspace" inv, root := clitest.New(t, "create", workspaceName, "--template", template.Name, "-y", "--preset", preset.Name) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) err := inv.Run() require.NoError(t, err) // Should: display the selected preset as well as its parameters presetName := fmt.Sprintf("Preset '%s' applied:", preset.Name) - pty.ExpectMatch(presetName) - pty.ExpectMatch(fmt.Sprintf("%s: '%s'", firstParameterName, secondOptionalParameterValue)) - pty.ExpectMatch(fmt.Sprintf("%s: '%s'", thirdParameterName, thirdParameterValue)) + stdout.ExpectMatchContext(ctx, presetName) + stdout.ExpectMatchContext(ctx, fmt.Sprintf("%s: '%s'", firstParameterName, secondOptionalParameterValue)) + stdout.ExpectMatchContext(ctx, fmt.Sprintf("%s: '%s'", thirdParameterName, thirdParameterValue)) // Verify if the new workspace uses expected parameters. ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) @@ -1312,6 +1325,7 @@ func TestCreateWithPreset(t *testing.T) { // the CLI automatically uses the default preset to create the workspace t.Run("DefaultPreset", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) @@ -1340,22 +1354,17 @@ func TestCreateWithPreset(t *testing.T) { workspaceName := "my-workspace" inv, root := clitest.New(t, "create", workspaceName, "--template", template.Name, "-y") clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) err := inv.Run() require.NoError(t, err) // Should: display the default preset as well as its parameters presetName := fmt.Sprintf("Preset '%s' (default) applied:", defaultPreset.Name) - pty.ExpectMatch(presetName) - pty.ExpectMatch(fmt.Sprintf("%s: '%s'", firstParameterName, secondOptionalParameterValue)) - pty.ExpectMatch(fmt.Sprintf("%s: '%s'", thirdParameterName, thirdParameterValue)) + stdout.ExpectMatchContext(ctx, presetName) + stdout.ExpectMatchContext(ctx, fmt.Sprintf("%s: '%s'", firstParameterName, secondOptionalParameterValue)) + stdout.ExpectMatchContext(ctx, fmt.Sprintf("%s: '%s'", thirdParameterName, thirdParameterValue)) // Verify if the new workspace uses expected parameters. - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - tvPresets, err := client.TemplateVersionPresets(ctx, version.ID) require.NoError(t, err) require.Len(t, tvPresets, 2) @@ -1389,12 +1398,14 @@ func TestCreateWithPreset(t *testing.T) { // the CLI prompts the user to select a preset. t.Run("NoDefaultPresetPromptUser", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) - // Given: a template and a template version with two presets + // Given: a template and a template version with a single, non-default preset. preset := proto.Preset{ Name: "preset-test", Description: "Preset Test.", @@ -1414,7 +1425,8 @@ func TestCreateWithPreset(t *testing.T) { "--parameter", fmt.Sprintf("%s=%s", thirdParameterName, thirdParameterValue)) clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() @@ -1422,18 +1434,16 @@ func TestCreateWithPreset(t *testing.T) { }() // Should: prompt the user for the preset - pty.ExpectMatch("Select a preset below:") - pty.WriteLine("\n") - pty.ExpectMatch("Preset 'preset-test' applied") - pty.ExpectMatch("Confirm create?") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "Select a preset below:") + // We don't actually have to respond to the selector, since we hardcode the cliui.Select to return the + // first option in test scenarios (c.f. cliui/select.go) + stdout.ExpectMatchContext(ctx, "Preset 'preset-test' applied") + stdout.ExpectMatchContext(ctx, "Confirm create?") + stdin.WriteLine("yes") <-doneChan // Verify if the new workspace uses expected parameters. - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - tvPresets, err := client.TemplateVersionPresets(ctx, version.ID) require.NoError(t, err) require.Len(t, tvPresets, 1) @@ -1460,6 +1470,7 @@ func TestCreateWithPreset(t *testing.T) { // with workspace creation without applying any preset. t.Run("TemplateVersionWithoutPresets", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) @@ -1476,17 +1487,12 @@ func TestCreateWithPreset(t *testing.T) { "--parameter", fmt.Sprintf("%s=%s", firstParameterName, firstOptionalParameterValue), "--parameter", fmt.Sprintf("%s=%s", thirdParameterName, thirdParameterValue)) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) err := inv.Run() require.NoError(t, err) - pty.ExpectMatch("No preset applied.") + stdout.ExpectMatchContext(ctx, "No preset applied.") // Verify if the new workspace uses expected parameters. - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - workspaces, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{ Name: workspaceName, }) @@ -1509,6 +1515,7 @@ func TestCreateWithPreset(t *testing.T) { // The workspace should be created without using any preset-defined parameters. t.Run("PresetFlagNone", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) @@ -1533,17 +1540,12 @@ func TestCreateWithPreset(t *testing.T) { "--parameter", fmt.Sprintf("%s=%s", firstParameterName, firstOptionalParameterValue), "--parameter", fmt.Sprintf("%s=%s", thirdParameterName, thirdParameterValue)) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) err := inv.Run() require.NoError(t, err) - pty.ExpectMatch("No preset applied.") + stdout.ExpectMatchContext(ctx, "No preset applied.") // Verify that the new workspace doesn't use the preset parameters. - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - tvPresets, err := client.TemplateVersionPresets(ctx, version.ID) require.NoError(t, err) require.Len(t, tvPresets, 1) @@ -1591,9 +1593,6 @@ func TestCreateWithPreset(t *testing.T) { workspaceName := "my-workspace" inv, root := clitest.New(t, "create", workspaceName, "--template", template.Name, "-y", "--preset", "invalid-preset") clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() err := inv.Run() // Should: fail with an error indicating the preset was not found @@ -1610,6 +1609,7 @@ func TestCreateWithPreset(t *testing.T) { // - and the value of parameter B from the parameter flag. t.Run("PresetOverridesParameterFlagValues", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) @@ -1633,21 +1633,16 @@ func TestCreateWithPreset(t *testing.T) { "--parameter", fmt.Sprintf("%s=%s", firstParameterName, firstOptionalParameterValue), "--parameter", fmt.Sprintf("%s=%s", thirdParameterName, thirdParameterValue)) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) err := inv.Run() require.NoError(t, err) // Should: display the selected preset as well as its parameter presetName := fmt.Sprintf("Preset '%s' applied:", preset.Name) - pty.ExpectMatch(presetName) - pty.ExpectMatch(fmt.Sprintf("%s: '%s'", firstParameterName, secondOptionalParameterValue)) + stdout.ExpectMatchContext(ctx, presetName) + stdout.ExpectMatchContext(ctx, fmt.Sprintf("%s: '%s'", firstParameterName, secondOptionalParameterValue)) // Verify if the new workspace uses expected parameters. - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - tvPresets, err := client.TemplateVersionPresets(ctx, version.ID) require.NoError(t, err) require.Len(t, tvPresets, 1) @@ -1679,6 +1674,7 @@ func TestCreateWithPreset(t *testing.T) { // - and the value of parameter B from the file. t.Run("PresetOverridesParameterFileValues", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) @@ -1707,21 +1703,16 @@ func TestCreateWithPreset(t *testing.T) { "--preset", preset.Name, "--rich-parameter-file", parameterFile.Name()) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) err := inv.Run() require.NoError(t, err) // Should: display the selected preset as well as its parameter presetName := fmt.Sprintf("Preset '%s' applied:", preset.Name) - pty.ExpectMatch(presetName) - pty.ExpectMatch(fmt.Sprintf("%s: '%s'", firstParameterName, secondOptionalParameterValue)) + stdout.ExpectMatchContext(ctx, presetName) + stdout.ExpectMatchContext(ctx, fmt.Sprintf("%s: '%s'", firstParameterName, secondOptionalParameterValue)) // Verify if the new workspace uses expected parameters. - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - tvPresets, err := client.TemplateVersionPresets(ctx, version.ID) require.NoError(t, err) require.Len(t, tvPresets, 1) @@ -1748,7 +1739,8 @@ func TestCreateWithPreset(t *testing.T) { // the CLI prompts the user for input to fill in the missing parameters. t.Run("PromptsForMissingParametersWhenPresetIsIncomplete", func(t *testing.T) { t.Parallel() - + ctx := testutil.Context(t, testutil.WaitMedium) + logger := testutil.Logger(t) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) @@ -1769,7 +1761,8 @@ func TestCreateWithPreset(t *testing.T) { inv, root := clitest.New(t, "create", workspaceName, "--template", template.Name, "--preset", preset.Name) clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() @@ -1778,21 +1771,18 @@ func TestCreateWithPreset(t *testing.T) { // Should: display the selected preset as well as its parameters presetName := fmt.Sprintf("Preset '%s' applied:", preset.Name) - pty.ExpectMatch(presetName) - pty.ExpectMatch(fmt.Sprintf("%s: '%s'", firstParameterName, secondOptionalParameterValue)) + stdout.ExpectMatchContext(ctx, presetName) + stdout.ExpectMatchContext(ctx, fmt.Sprintf("%s: '%s'", firstParameterName, secondOptionalParameterValue)) // Should: prompt for the missing parameter - pty.ExpectMatch(thirdParameterDescription) - pty.WriteLine(thirdParameterValue) - pty.ExpectMatch("Confirm create?") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, thirdParameterDescription) + stdin.WriteLine(thirdParameterValue) + stdout.ExpectMatchContext(ctx, "Confirm create?") + stdin.WriteLine("yes") <-doneChan // Verify if the new workspace uses expected parameters. - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - tvPresets, err := client.TemplateVersionPresets(ctx, version.ID) require.NoError(t, err) require.Len(t, tvPresets, 1) @@ -1857,7 +1847,8 @@ func TestCreateValidateRichParameters(t *testing.T) { t.Run("ValidateString", func(t *testing.T) { t.Parallel() - + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) @@ -1869,7 +1860,8 @@ func TestCreateValidateRichParameters(t *testing.T) { inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() @@ -1885,9 +1877,9 @@ func TestCreateValidateRichParameters(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) + stdout.ExpectMatchContext(ctx, match) if value != "" { - pty.WriteLine(value) + stdin.WriteLine(value) } } <-doneChan @@ -1895,6 +1887,8 @@ func TestCreateValidateRichParameters(t *testing.T) { t.Run("ValidateNumber", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) @@ -1907,7 +1901,8 @@ func TestCreateValidateRichParameters(t *testing.T) { inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() @@ -1923,9 +1918,9 @@ func TestCreateValidateRichParameters(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) + stdout.ExpectMatchContext(ctx, match) if value != "" { - pty.WriteLine(value) + stdin.WriteLine(value) } } <-doneChan @@ -1933,6 +1928,8 @@ func TestCreateValidateRichParameters(t *testing.T) { t.Run("ValidateNumber_CustomError", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) @@ -1945,7 +1942,8 @@ func TestCreateValidateRichParameters(t *testing.T) { inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() @@ -1961,9 +1959,9 @@ func TestCreateValidateRichParameters(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) + stdout.ExpectMatchContext(ctx, match) if value != "" { - pty.WriteLine(value) + stdin.WriteLine(value) } } <-doneChan @@ -1971,6 +1969,8 @@ func TestCreateValidateRichParameters(t *testing.T) { t.Run("ValidateBool", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) @@ -1983,7 +1983,8 @@ func TestCreateValidateRichParameters(t *testing.T) { inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() @@ -1999,9 +2000,9 @@ func TestCreateValidateRichParameters(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) + stdout.ExpectMatchContext(ctx, match) if value != "" { - pty.WriteLine(value) + stdin.WriteLine(value) } } <-doneChan @@ -2018,15 +2019,18 @@ func TestCreateValidateRichParameters(t *testing.T) { template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID) t.Run("Prompt", func(t *testing.T) { + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) inv, root := clitest.New(t, "create", "my-workspace-1", "--template", template.Name) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) clitest.Start(t, inv) - pty.ExpectMatch(listOfStringsParameterName) - pty.ExpectMatch("aaa, bbb, ccc") - pty.ExpectMatch("Confirm create?") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, listOfStringsParameterName) + stdout.ExpectMatchContext(ctx, "aaa, bbb, ccc") + stdout.ExpectMatchContext(ctx, "Confirm create?") + stdin.WriteLine("yes") }) t.Run("Default", func(t *testing.T) { @@ -2049,6 +2053,8 @@ func TestCreateValidateRichParameters(t *testing.T) { t.Run("ValidateListOfStrings_YAMLFile", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) @@ -2066,8 +2072,8 @@ func TestCreateValidateRichParameters(t *testing.T) { - fff`) inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name, "--rich-parameter-file", parameterFile.Name()) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) - + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) clitest.Start(t, inv) matches := []string{ @@ -2076,9 +2082,9 @@ func TestCreateValidateRichParameters(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) + stdout.ExpectMatchContext(ctx, match) if value != "" { - pty.WriteLine(value) + stdin.WriteLine(value) } } }) @@ -2086,6 +2092,8 @@ func TestCreateValidateRichParameters(t *testing.T) { func TestCreateWithGitAuth(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) echoResponses := &echo.Responses{ Parse: echo.ParseComplete, ProvisionInit: echo.InitComplete, @@ -2120,13 +2128,14 @@ func TestCreateWithGitAuth(t *testing.T) { inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) clitest.Start(t, inv) - pty.ExpectMatch("You must authenticate with GitHub to create a workspace") + stdout.ExpectMatchContext(ctx, "You must authenticate with GitHub to create a workspace") resp := coderdtest.RequestExternalAuthCallback(t, "github", member) _ = resp.Body.Close() require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - pty.ExpectMatch("Confirm create?") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "Confirm create?") + stdin.WriteLine("yes") } diff --git a/cli/delete_test.go b/cli/delete_test.go index 909166876d..c8dff9646a 100644 --- a/cli/delete_test.go +++ b/cli/delete_test.go @@ -22,8 +22,8 @@ import ( "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" "github.com/coder/quartz" ) @@ -31,6 +31,7 @@ func TestDelete(t *testing.T) { t.Parallel() t.Run("WithParameter", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) @@ -42,7 +43,7 @@ func TestDelete(t *testing.T) { inv, root := clitest.New(t, "delete", workspace.Name, "-y") clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) go func() { defer close(doneChan) err := inv.Run() @@ -51,7 +52,7 @@ func TestDelete(t *testing.T) { assert.ErrorIs(t, err, io.EOF) } }() - pty.ExpectMatch("has been deleted") + stdout.ExpectMatchContext(ctx, "has been deleted") <-doneChan }) @@ -71,8 +72,7 @@ func TestDelete(t *testing.T) { clitest.SetupConfig(t, templateAdmin, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) go func() { defer close(doneChan) err := inv.WithContext(ctx).Run() @@ -81,7 +81,7 @@ func TestDelete(t *testing.T) { assert.ErrorIs(t, err, io.EOF) } }() - pty.ExpectMatch("has been deleted") + stdout.ExpectMatchContext(ctx, "has been deleted") testutil.TryReceive(ctx, t, doneChan) _, err := client.Workspace(ctx, workspace.ID) @@ -117,8 +117,7 @@ func TestDelete(t *testing.T) { //nolint:gocritic // Deleting orphaned workspaces requires an admin. clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) go func() { defer close(doneChan) err := inv.Run() @@ -127,7 +126,7 @@ func TestDelete(t *testing.T) { assert.ErrorIs(t, err, io.EOF) } }() - pty.ExpectMatch("has been deleted") + stdout.ExpectMatchContext(ctx, "has been deleted") <-doneChan }) @@ -146,11 +145,12 @@ func TestDelete(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + ctx := testutil.Context(t, testutil.WaitMedium) inv, root := clitest.New(t, "delete", user.Username+"/"+workspace.Name, "-y") //nolint:gocritic // This requires an admin. clitest.SetupConfig(t, adminClient, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) go func() { defer close(doneChan) err := inv.Run() @@ -160,7 +160,7 @@ func TestDelete(t *testing.T) { } }() - pty.ExpectMatch("has been deleted") + stdout.ExpectMatchContext(ctx, "has been deleted") <-doneChan workspace, err = client.Workspace(context.Background(), workspace.ID) @@ -207,7 +207,7 @@ func TestDelete(t *testing.T) { // Then: the workspace deletion should warn about no provisioners inv, root := clitest.New(t, "delete", workspace.Name, "-y") - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) clitest.SetupConfig(t, templateAdmin, root) doneChan := make(chan struct{}) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) @@ -216,7 +216,7 @@ func TestDelete(t *testing.T) { defer close(doneChan) _ = inv.WithContext(ctx).Run() }() - pty.ExpectMatch("there are no provisioners that accept the required tags") + stdout.ExpectMatchContext(ctx, "there are no provisioners that accept the required tags") cancel() <-doneChan }) @@ -311,7 +311,7 @@ func TestDelete(t *testing.T) { inv, root := clitest.New(t, "delete", workspaceOwner+"/"+workspace.Name, "-y") clitest.SetupConfig(t, runClient, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) var runErr error go func() { defer close(doneChan) @@ -324,7 +324,7 @@ func TestDelete(t *testing.T) { require.Error(t, runErr) require.Contains(t, runErr.Error(), expectedErr) } else { - pty.ExpectMatch("has been deleted") + stdout.ExpectMatchContext(ctx, "has been deleted") <-doneChan // When running with the race detector on, we sometimes get an EOF. diff --git a/cli/exp_rpty_test.go b/cli/exp_rpty_test.go index eb29190c6f..72548188ea 100644 --- a/cli/exp_rpty_test.go +++ b/cli/exp_rpty_test.go @@ -15,8 +15,8 @@ import ( "github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/coderd/coderdtest" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" ) func TestExpRpty(t *testing.T) { @@ -28,7 +28,7 @@ func TestExpRpty(t *testing.T) { client, workspace, agentToken := setupWorkspaceForAgent(t) inv, root := clitest.New(t, "exp", "rpty", workspace.Name) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t).Attach(inv) + stdin := testutil.NewWriterAttachedToInvocation(t, testutil.Logger(t), inv) ctx := testutil.Context(t, testutil.WaitLong) @@ -40,7 +40,7 @@ func TestExpRpty(t *testing.T) { assert.NoError(t, err) }) - pty.WriteLine("exit") + stdin.WriteLine("exit") <-cmdDone }) @@ -51,7 +51,7 @@ func TestExpRpty(t *testing.T) { randStr := uuid.NewString() inv, root := clitest.New(t, "exp", "rpty", workspace.Name, "echo", randStr) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) ctx := testutil.Context(t, testutil.WaitLong) @@ -63,7 +63,7 @@ func TestExpRpty(t *testing.T) { assert.NoError(t, err) }) - pty.ExpectMatch(randStr) + stdout.ExpectMatchContext(ctx, randStr) <-cmdDone }) @@ -86,6 +86,7 @@ func TestExpRpty(t *testing.T) { t.Skip("Skipping test on non-Linux platform") } + logger := testutil.Logger(t) wantLabel := "coder.devcontainers.TestExpRpty.Container" client, workspace, agentToken := setupWorkspaceForAgent(t) @@ -124,7 +125,8 @@ func TestExpRpty(t *testing.T) { inv, root := clitest.New(t, "exp", "rpty", workspace.Name, "-c", ct.Container.ID) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) ctx := testutil.Context(t, testutil.WaitLong) cmdDone := tGo(t, func() { @@ -132,10 +134,10 @@ func TestExpRpty(t *testing.T) { assert.NoError(t, err) }) - pty.ExpectMatchContext(ctx, " #") - pty.WriteLine("hostname") - pty.ExpectMatchContext(ctx, ct.Container.Config.Hostname) - pty.WriteLine("exit") + stdout.ExpectMatchContext(ctx, " #") + stdin.WriteLine("hostname") + stdout.ExpectMatchContext(ctx, ct.Container.Config.Hostname) + stdin.WriteLine("exit") <-cmdDone }) } diff --git a/cli/list_test.go b/cli/list_test.go index 8cdde03072..201188ad1e 100644 --- a/cli/list_test.go +++ b/cli/list_test.go @@ -15,8 +15,8 @@ import ( "github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" ) func TestList(t *testing.T) { @@ -34,7 +34,7 @@ func TestList(t *testing.T) { inv, root := clitest.New(t, "ls") clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancelFunc() @@ -44,8 +44,8 @@ func TestList(t *testing.T) { assert.NoError(t, errC) close(done) }() - pty.ExpectMatch(r.Workspace.Name) - pty.ExpectMatch("Started") + stdout.ExpectMatchContext(ctx, r.Workspace.Name) + stdout.ExpectMatchContext(ctx, "Started") cancelFunc() <-done }) diff --git a/cli/login_test.go b/cli/login_test.go index 6d6e54eb6e..5768a68127 100644 --- a/cli/login_test.go +++ b/cli/login_test.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "runtime" "testing" "github.com/stretchr/testify/assert" @@ -15,8 +14,8 @@ import ( "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" "github.com/coder/pretty" ) @@ -74,13 +73,16 @@ func TestLogin(t *testing.T) { t.Run("InitialUserTTY", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, nil) // The --force-tty flag is required on Windows, because the `isatty` library does not // accurately detect Windows ptys when they are not attached to a process: // https://github.com/mattn/go-isatty/issues/59 doneChan := make(chan struct{}) root, _ := clitest.New(t, "login", "--force-tty", client.URL.String()) - pty := ptytest.New(t).Attach(root) + stdout := expecter.NewAttachedToInvocation(t, root) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), root) + ctx := testutil.Context(t, testutil.WaitMedium) go func() { defer close(doneChan) err := root.Run() @@ -105,12 +107,11 @@ func TestLogin(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) - pty.WriteLine(value) + stdout.ExpectMatchContext(ctx, match) + stdin.WriteLine(value) } - pty.ExpectMatch("Welcome to Coder") + stdout.ExpectMatchContext(ctx, "Welcome to Coder") <-doneChan - ctx := testutil.Context(t, testutil.WaitShort) resp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{ Email: coderdtest.FirstUserParams.Email, Password: coderdtest.FirstUserParams.Password, @@ -126,13 +127,16 @@ func TestLogin(t *testing.T) { t.Run("InitialUserTTYWithNoTrial", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, nil) // The --force-tty flag is required on Windows, because the `isatty` library does not // accurately detect Windows ptys when they are not attached to a process: // https://github.com/mattn/go-isatty/issues/59 doneChan := make(chan struct{}) root, _ := clitest.New(t, "login", "--force-tty", client.URL.String()) - pty := ptytest.New(t).Attach(root) + stdout := expecter.NewAttachedToInvocation(t, root) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), root) + ctx := testutil.Context(t, testutil.WaitMedium) go func() { defer close(doneChan) err := root.Run() @@ -151,12 +155,11 @@ func TestLogin(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) - pty.WriteLine(value) + stdout.ExpectMatchContext(ctx, match) + stdin.WriteLine(value) } - pty.ExpectMatch("Welcome to Coder") + stdout.ExpectMatchContext(ctx, "Welcome to Coder") <-doneChan - ctx := testutil.Context(t, testutil.WaitShort) resp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{ Email: coderdtest.FirstUserParams.Email, Password: coderdtest.FirstUserParams.Password, @@ -172,13 +175,16 @@ func TestLogin(t *testing.T) { t.Run("InitialUserTTYNameOptional", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, nil) // The --force-tty flag is required on Windows, because the `isatty` library does not // accurately detect Windows ptys when they are not attached to a process: // https://github.com/mattn/go-isatty/issues/59 doneChan := make(chan struct{}) root, _ := clitest.New(t, "login", "--force-tty", client.URL.String()) - pty := ptytest.New(t).Attach(root) + stdout := expecter.NewAttachedToInvocation(t, root) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), root) + ctx := testutil.Context(t, testutil.WaitMedium) go func() { defer close(doneChan) err := root.Run() @@ -203,12 +209,11 @@ func TestLogin(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) - pty.WriteLine(value) + stdout.ExpectMatchContext(ctx, match) + stdin.WriteLine(value) } - pty.ExpectMatch("Welcome to Coder") + stdout.ExpectMatchContext(ctx, "Welcome to Coder") <-doneChan - ctx := testutil.Context(t, testutil.WaitShort) resp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{ Email: coderdtest.FirstUserParams.Email, Password: coderdtest.FirstUserParams.Password, @@ -224,16 +229,19 @@ func TestLogin(t *testing.T) { t.Run("InitialUserTTYFlag", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, nil) // The --force-tty flag is required on Windows, because the `isatty` library does not // accurately detect Windows ptys when they are not attached to a process: // https://github.com/mattn/go-isatty/issues/59 inv, _ := clitest.New(t, "--url", client.URL.String(), "login", "--force-tty") - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) + ctx := testutil.Context(t, testutil.WaitMedium) clitest.Start(t, inv) - pty.ExpectMatch(fmt.Sprintf("Attempting to authenticate with flag URL: '%s'", client.URL.String())) + stdout.ExpectMatchContext(ctx, fmt.Sprintf("Attempting to authenticate with flag URL: '%s'", client.URL.String())) matches := []string{ "first user?", "yes", "username", coderdtest.FirstUserParams.Username, @@ -252,11 +260,10 @@ func TestLogin(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) - pty.WriteLine(value) + stdout.ExpectMatchContext(ctx, match) + stdin.WriteLine(value) } - pty.ExpectMatch("Welcome to Coder") - ctx := testutil.Context(t, testutil.WaitShort) + stdout.ExpectMatchContext(ctx, "Welcome to Coder") resp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{ Email: coderdtest.FirstUserParams.Email, Password: coderdtest.FirstUserParams.Password, @@ -272,6 +279,7 @@ func TestLogin(t *testing.T) { t.Run("InitialUserFlags", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, nil) inv, _ := clitest.New( t, "login", client.URL.String(), @@ -281,22 +289,23 @@ func TestLogin(t *testing.T) { "--first-user-password", coderdtest.FirstUserParams.Password, "--first-user-trial", ) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) + ctx := testutil.Context(t, testutil.WaitMedium) w := clitest.StartWithWaiter(t, inv) - pty.ExpectMatch("firstName") - pty.WriteLine(coderdtest.TrialUserParams.FirstName) - pty.ExpectMatch("lastName") - pty.WriteLine(coderdtest.TrialUserParams.LastName) - pty.ExpectMatch("phoneNumber") - pty.WriteLine(coderdtest.TrialUserParams.PhoneNumber) - pty.ExpectMatch("jobTitle") - pty.WriteLine(coderdtest.TrialUserParams.JobTitle) - pty.ExpectMatch("companyName") - pty.WriteLine(coderdtest.TrialUserParams.CompanyName) + stdout.ExpectMatchContext(ctx, "firstName") + stdin.WriteLine(coderdtest.TrialUserParams.FirstName) + stdout.ExpectMatchContext(ctx, "lastName") + stdin.WriteLine(coderdtest.TrialUserParams.LastName) + stdout.ExpectMatchContext(ctx, "phoneNumber") + stdin.WriteLine(coderdtest.TrialUserParams.PhoneNumber) + stdout.ExpectMatchContext(ctx, "jobTitle") + stdin.WriteLine(coderdtest.TrialUserParams.JobTitle) + stdout.ExpectMatchContext(ctx, "companyName") + stdin.WriteLine(coderdtest.TrialUserParams.CompanyName) // `developers` and `country` `cliui.Select` automatically selects the first option during tests. - pty.ExpectMatch("Welcome to Coder") + stdout.ExpectMatchContext(ctx, "Welcome to Coder") w.RequireSuccess() - ctx := testutil.Context(t, testutil.WaitShort) resp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{ Email: coderdtest.FirstUserParams.Email, Password: coderdtest.FirstUserParams.Password, @@ -312,6 +321,7 @@ func TestLogin(t *testing.T) { t.Run("InitialUserFlagsNameOptional", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, nil) inv, _ := clitest.New( t, "login", client.URL.String(), @@ -320,22 +330,23 @@ func TestLogin(t *testing.T) { "--first-user-password", coderdtest.FirstUserParams.Password, "--first-user-trial", ) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) + ctx := testutil.Context(t, testutil.WaitMedium) w := clitest.StartWithWaiter(t, inv) - pty.ExpectMatch("firstName") - pty.WriteLine(coderdtest.TrialUserParams.FirstName) - pty.ExpectMatch("lastName") - pty.WriteLine(coderdtest.TrialUserParams.LastName) - pty.ExpectMatch("phoneNumber") - pty.WriteLine(coderdtest.TrialUserParams.PhoneNumber) - pty.ExpectMatch("jobTitle") - pty.WriteLine(coderdtest.TrialUserParams.JobTitle) - pty.ExpectMatch("companyName") - pty.WriteLine(coderdtest.TrialUserParams.CompanyName) + stdout.ExpectMatchContext(ctx, "firstName") + stdin.WriteLine(coderdtest.TrialUserParams.FirstName) + stdout.ExpectMatchContext(ctx, "lastName") + stdin.WriteLine(coderdtest.TrialUserParams.LastName) + stdout.ExpectMatchContext(ctx, "phoneNumber") + stdin.WriteLine(coderdtest.TrialUserParams.PhoneNumber) + stdout.ExpectMatchContext(ctx, "jobTitle") + stdin.WriteLine(coderdtest.TrialUserParams.JobTitle) + stdout.ExpectMatchContext(ctx, "companyName") + stdin.WriteLine(coderdtest.TrialUserParams.CompanyName) // `developers` and `country` `cliui.Select` automatically selects the first option during tests. - pty.ExpectMatch("Welcome to Coder") + stdout.ExpectMatchContext(ctx, "Welcome to Coder") w.RequireSuccess() - ctx := testutil.Context(t, testutil.WaitShort) resp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{ Email: coderdtest.FirstUserParams.Email, Password: coderdtest.FirstUserParams.Password, @@ -351,6 +362,7 @@ func TestLogin(t *testing.T) { t.Run("InitialUserTTYConfirmPasswordFailAndReprompt", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() client := coderdtest.New(t, nil) @@ -359,7 +371,8 @@ func TestLogin(t *testing.T) { // https://github.com/mattn/go-isatty/issues/59 doneChan := make(chan struct{}) root, _ := clitest.New(t, "login", "--force-tty", client.URL.String()) - pty := ptytest.New(t).Attach(root) + stdout := expecter.NewAttachedToInvocation(t, root) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), root) go func() { defer close(doneChan) err := root.WithContext(ctx).Run() @@ -377,59 +390,60 @@ func TestLogin(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) - pty.WriteLine(value) + stdout.ExpectMatchContext(ctx, match) + stdin.WriteLine(value) } // Validate that we reprompt for matching passwords. - pty.ExpectMatch("Passwords do not match") - pty.ExpectMatch("Enter a " + pretty.Sprint(cliui.DefaultStyles.Field, "password")) - pty.WriteLine(coderdtest.FirstUserParams.Password) - pty.ExpectMatch("Confirm") - pty.WriteLine(coderdtest.FirstUserParams.Password) - pty.ExpectMatch("trial") - pty.WriteLine("yes") - pty.ExpectMatch("firstName") - pty.WriteLine(coderdtest.TrialUserParams.FirstName) - pty.ExpectMatch("lastName") - pty.WriteLine(coderdtest.TrialUserParams.LastName) - pty.ExpectMatch("phoneNumber") - pty.WriteLine(coderdtest.TrialUserParams.PhoneNumber) - pty.ExpectMatch("jobTitle") - pty.WriteLine(coderdtest.TrialUserParams.JobTitle) - pty.ExpectMatch("companyName") - pty.WriteLine(coderdtest.TrialUserParams.CompanyName) - pty.ExpectMatch("Welcome to Coder") + stdout.ExpectMatchContext(ctx, "Passwords do not match") + stdout.ExpectMatchContext(ctx, "Enter a "+pretty.Sprint(cliui.DefaultStyles.Field, "password")) + stdin.WriteLine(coderdtest.FirstUserParams.Password) + stdout.ExpectMatchContext(ctx, "Confirm") + stdin.WriteLine(coderdtest.FirstUserParams.Password) + stdout.ExpectMatchContext(ctx, "trial") + stdin.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "firstName") + stdin.WriteLine(coderdtest.TrialUserParams.FirstName) + stdout.ExpectMatchContext(ctx, "lastName") + stdin.WriteLine(coderdtest.TrialUserParams.LastName) + stdout.ExpectMatchContext(ctx, "phoneNumber") + stdin.WriteLine(coderdtest.TrialUserParams.PhoneNumber) + stdout.ExpectMatchContext(ctx, "jobTitle") + stdin.WriteLine(coderdtest.TrialUserParams.JobTitle) + stdout.ExpectMatchContext(ctx, "companyName") + stdin.WriteLine(coderdtest.TrialUserParams.CompanyName) + stdout.ExpectMatchContext(ctx, "Welcome to Coder") <-doneChan }) t.Run("ExistingUserValidTokenTTY", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, nil) coderdtest.CreateFirstUser(t, client) + ctx := testutil.Context(t, testutil.WaitMedium) doneChan := make(chan struct{}) root, _ := clitest.New(t, "login", "--force-tty", client.URL.String(), "--no-open") - pty := ptytest.New(t).Attach(root) + stdout := expecter.NewAttachedToInvocation(t, root) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), root) go func() { defer close(doneChan) err := root.Run() assert.NoError(t, err) }() - pty.ExpectMatch(fmt.Sprintf("Attempting to authenticate with argument URL: '%s'", client.URL.String())) - pty.ExpectMatch("Paste your token here:") - pty.WriteLine(client.SessionToken()) - if runtime.GOOS != "windows" { - // For some reason, the match does not show up on Windows. - pty.ExpectMatch(client.SessionToken()) - } - pty.ExpectMatch("Welcome to Coder") + stdout.ExpectMatchContext(ctx, fmt.Sprintf("Attempting to authenticate with argument URL: '%s'", client.URL.String())) + stdout.ExpectMatchContext(ctx, "Paste your token here:") + stdin.WriteLine(client.SessionToken()) + stdout.ExpectMatchContext(ctx, "Welcome to Coder") <-doneChan }) t.Run("ExistingUserURLSavedInConfig", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, nil) url := client.URL.String() coderdtest.CreateFirstUser(t, client) @@ -438,21 +452,24 @@ func TestLogin(t *testing.T) { clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() assert.NoError(t, err) }() - pty.ExpectMatch(fmt.Sprintf("Attempting to authenticate with config URL: '%s'", url)) - pty.ExpectMatch("Paste your token here:") - pty.WriteLine(client.SessionToken()) + stdout.ExpectMatchContext(ctx, fmt.Sprintf("Attempting to authenticate with config URL: '%s'", url)) + stdout.ExpectMatchContext(ctx, "Paste your token here:") + stdin.WriteLine(client.SessionToken()) <-doneChan }) t.Run("ExistingUserURLSavedInEnv", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) client := coderdtest.New(t, nil) url := client.URL.String() coderdtest.CreateFirstUser(t, client) @@ -461,21 +478,23 @@ func TestLogin(t *testing.T) { inv.Environ.Set("CODER_URL", url) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() assert.NoError(t, err) }() - pty.ExpectMatch(fmt.Sprintf("Attempting to authenticate with environment URL: '%s'", url)) - pty.ExpectMatch("Paste your token here:") - pty.WriteLine(client.SessionToken()) + stdout.ExpectMatchContext(ctx, fmt.Sprintf("Attempting to authenticate with environment URL: '%s'", url)) + stdout.ExpectMatchContext(ctx, "Paste your token here:") + stdin.WriteLine(client.SessionToken()) <-doneChan }) t.Run("ExistingUserInvalidTokenTTY", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, nil) coderdtest.CreateFirstUser(t, client) @@ -483,7 +502,8 @@ func TestLogin(t *testing.T) { defer cancelFunc() doneChan := make(chan struct{}) root, _ := clitest.New(t, "login", client.URL.String(), "--no-open") - pty := ptytest.New(t).Attach(root) + stdout := expecter.NewAttachedToInvocation(t, root) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), root) go func() { defer close(doneChan) err := root.WithContext(ctx).Run() @@ -491,13 +511,9 @@ func TestLogin(t *testing.T) { assert.Error(t, err) }() - pty.ExpectMatch("Paste your token here:") - pty.WriteLine("an-invalid-token") - if runtime.GOOS != "windows" { - // For some reason, the match does not show up on Windows. - pty.ExpectMatch("an-invalid-token") - } - pty.ExpectMatch("That's not a valid token!") + stdout.ExpectMatchContext(ctx, "Paste your token here:") + stdin.WriteLine("an-invalid-token") + stdout.ExpectMatchContext(ctx, "That's not a valid token!") cancelFunc() <-doneChan }) @@ -582,12 +598,12 @@ func TestLoginToken(t *testing.T) { inv, root := clitest.New(t, "login", "token", "--url", client.URL.String()) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) ctx := testutil.Context(t, testutil.WaitShort) err := inv.WithContext(ctx).Run() require.NoError(t, err) - pty.ExpectMatch(client.SessionToken()) + stdout.ExpectMatchContext(ctx, client.SessionToken()) }) t.Run("NoTokenStored", func(t *testing.T) { diff --git a/cli/organization_test.go b/cli/organization_test.go index 8c4997f4ae..ab5751b513 100644 --- a/cli/organization_test.go +++ b/cli/organization_test.go @@ -17,7 +17,8 @@ import ( "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/pty/ptytest" + "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" "github.com/coder/pretty" ) @@ -29,6 +30,7 @@ func TestCurrentOrganization(t *testing.T) { // 2. The user is connecting to an older Coder instance. t.Run("no-default", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) orgID := uuid.New() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -49,13 +51,13 @@ func TestCurrentOrganization(t *testing.T) { client := codersdk.New(must(url.Parse(srv.URL))) inv, root := clitest.New(t, "organizations", "show", "selected") clitest.SetupConfig(t, client, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) errC := make(chan error) go func() { errC <- inv.Run() }() require.NoError(t, <-errC) - pty.ExpectMatch(orgID.String()) + stdout.ExpectMatchContext(ctx, orgID.String()) }) } @@ -140,6 +142,8 @@ func TestOrganizationDelete(t *testing.T) { t.Run("Prompted", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) + ctx := testutil.Context(t, testutil.WaitMedium) orgID := uuid.New() var deleteCalled atomic.Bool @@ -167,15 +171,16 @@ func TestOrganizationDelete(t *testing.T) { client := codersdk.New(must(url.Parse(server.URL))) inv, root := clitest.New(t, "organizations", "delete", "my-org") clitest.SetupConfig(t, client, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) execDone := make(chan error) go func() { execDone <- inv.Run() }() - pty.ExpectMatch(fmt.Sprintf("Delete organization %s?", pretty.Sprint(cliui.DefaultStyles.Code, "my-org"))) - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, fmt.Sprintf("Delete organization %s?", pretty.Sprint(cliui.DefaultStyles.Code, "my-org"))) + stdin.WriteLine("yes") require.NoError(t, <-execDone) require.True(t, deleteCalled.Load(), "expected delete request") diff --git a/cli/portforward_test.go b/cli/portforward_test.go index 91c13efabe..d0cfeeb8fb 100644 --- a/cli/portforward_test.go +++ b/cli/portforward_test.go @@ -25,8 +25,8 @@ import ( "github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" ) func TestPortForward_None(t *testing.T) { @@ -160,10 +160,7 @@ func TestPortForward(t *testing.T) { // the "local" listener. inv, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t) - inv.Stdin = pty.Input() - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) iNet := testutil.NewInProcNet() inv.Net = iNet @@ -175,7 +172,7 @@ func TestPortForward(t *testing.T) { t.Logf("command complete; err=%s", err.Error()) errC <- err }() - pty.ExpectMatchContext(ctx, "Ready!") + stdout.ExpectMatchContext(ctx, "Ready!") // Open two connections simultaneously and test them out of // sync. @@ -216,10 +213,7 @@ func TestPortForward(t *testing.T) { // the "local" listeners. inv, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag1, flag2) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t) - inv.Stdin = pty.Input() - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) iNet := testutil.NewInProcNet() inv.Net = iNet @@ -229,7 +223,7 @@ func TestPortForward(t *testing.T) { go func() { errC <- inv.WithContext(ctx).Run() }() - pty.ExpectMatchContext(ctx, "Ready!") + stdout.ExpectMatchContext(ctx, "Ready!") // Open a connection to both listener 1 and 2 simultaneously and // then test them out of order. @@ -277,8 +271,7 @@ func TestPortForward(t *testing.T) { // the "local" listeners. inv, root := clitest.New(t, append([]string{"-v", "port-forward", workspace.Name}, flags...)...) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) iNet := testutil.NewInProcNet() inv.Net = iNet @@ -288,7 +281,7 @@ func TestPortForward(t *testing.T) { go func() { errC <- inv.WithContext(ctx).Run() }() - pty.ExpectMatchContext(ctx, "Ready!") + stdout.ExpectMatchContext(ctx, "Ready!") // Open connections to all items in the "dial" array. var ( @@ -338,10 +331,7 @@ func TestPortForward(t *testing.T) { // the "local" listener. inv, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag) clitest.SetupConfig(t, member, root) - pty := ptytest.New(t) - inv.Stdin = pty.Input() - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) iNet := testutil.NewInProcNet() inv.Net = iNet @@ -359,7 +349,7 @@ func TestPortForward(t *testing.T) { t.Logf("command complete; err=%s", err.Error()) errC <- err }() - pty.ExpectMatchContext(ctx, "Ready!") + stdout.ExpectMatchContext(ctx, "Ready!") // Test IPv4 still works dialCtx, dialCtxCancel := context.WithTimeout(ctx, testutil.WaitShort) diff --git a/cli/rename_test.go b/cli/rename_test.go index 31d14e5e08..e9aa8d480d 100644 --- a/cli/rename_test.go +++ b/cli/rename_test.go @@ -8,12 +8,13 @@ import ( "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/coderd/coderdtest" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" ) func TestRename(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true, AllowWorkspaceRenames: true}) owner := coderdtest.CreateFirstUser(t, client) @@ -30,13 +31,13 @@ func TestRename(t *testing.T) { want := coderdtest.RandomUsername(t) inv, root := clitest.New(t, "rename", workspace.Name, want, "--yes") clitest.SetupConfig(t, member, root) - pty := ptytest.New(t) - pty.Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) clitest.Start(t, inv) - pty.ExpectMatch("confirm rename:") - pty.WriteLine(workspace.Name) - pty.ExpectMatch("renamed to") + stdout.ExpectMatchContext(ctx, "confirm rename:") + stdin.WriteLine(workspace.Name) + stdout.ExpectMatchContext(ctx, "renamed to") ws, err := client.Workspace(ctx, workspace.ID) assert.NoError(t, err) diff --git a/cli/restart_test.go b/cli/restart_test.go index a8cd7ee5f3..3506d313a2 100644 --- a/cli/restart_test.go +++ b/cli/restart_test.go @@ -1,7 +1,6 @@ package cli_test import ( - "context" "fmt" "testing" @@ -14,8 +13,8 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisioner/echo" "github.com/coder/coder/v2/provisionersdk/proto" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" ) func TestRestart(t *testing.T) { @@ -49,15 +48,15 @@ func TestRestart(t *testing.T) { inv, root := clitest.New(t, "restart", workspace.Name, "--yes") clitest.SetupConfig(t, member, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) done := make(chan error, 1) go func() { done <- inv.WithContext(ctx).Run() }() - pty.ExpectMatch("Stopping workspace") - pty.ExpectMatch("Starting workspace") - pty.ExpectMatch("workspace has been restarted") + stdout.ExpectMatchContext(ctx, "Stopping workspace") + stdout.ExpectMatchContext(ctx, "Starting workspace") + stdout.ExpectMatchContext(ctx, "workspace has been restarted") err := <-done require.NoError(t, err, "execute failed") @@ -66,6 +65,7 @@ func TestRestart(t *testing.T) { t.Run("PromptEphemeralParameters", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) @@ -84,13 +84,15 @@ func TestRestart(t *testing.T) { inv, root := clitest.New(t, "restart", workspace.Name, "--prompt-ephemeral-parameters") clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() assert.NoError(t, err) }() + ctx := testutil.Context(t, testutil.WaitShort) matches := []string{ ephemeralParameterDescription, ephemeralParameterValue, "Restart workspace?", "yes", @@ -101,18 +103,15 @@ func TestRestart(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) + stdout.ExpectMatchContext(ctx, match) if value != "" { - pty.WriteLine(value) + stdin.WriteLine(value) } } <-doneChan // Verify if build option is set - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - workspace, err := client.WorkspaceByOwnerAndName(ctx, memberUser.ID.String(), workspace.Name, codersdk.WorkspaceOptions{}) require.NoError(t, err) actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID) @@ -126,6 +125,7 @@ func TestRestart(t *testing.T) { t.Run("EphemeralParameterFlags", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) @@ -143,13 +143,15 @@ func TestRestart(t *testing.T) { "--ephemeral-parameter", fmt.Sprintf("%s=%s", ephemeralParameterName, ephemeralParameterValue)) clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() assert.NoError(t, err) }() + ctx := testutil.Context(t, testutil.WaitShort) matches := []string{ "Restart workspace?", "yes", "Stopping workspace", "", @@ -159,18 +161,15 @@ func TestRestart(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) + stdout.ExpectMatchContext(ctx, match) if value != "" { - pty.WriteLine(value) + stdin.WriteLine(value) } } <-doneChan // Verify if build option is set - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - workspace, err := client.WorkspaceByOwnerAndName(ctx, memberUser.ID.String(), workspace.Name, codersdk.WorkspaceOptions{}) require.NoError(t, err) actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID) @@ -184,6 +183,7 @@ func TestRestart(t *testing.T) { t.Run("with deprecated build-options flag", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) @@ -202,13 +202,15 @@ func TestRestart(t *testing.T) { inv, root := clitest.New(t, "restart", workspace.Name, "--build-options") clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() assert.NoError(t, err) }() + ctx := testutil.Context(t, testutil.WaitShort) matches := []string{ ephemeralParameterDescription, ephemeralParameterValue, "Restart workspace?", "yes", @@ -219,18 +221,15 @@ func TestRestart(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) + stdout.ExpectMatchContext(ctx, match) if value != "" { - pty.WriteLine(value) + stdin.WriteLine(value) } } <-doneChan // Verify if build option is set - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - workspace, err := client.WorkspaceByOwnerAndName(ctx, memberUser.ID.String(), workspace.Name, codersdk.WorkspaceOptions{}) require.NoError(t, err) actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID) @@ -244,6 +243,7 @@ func TestRestart(t *testing.T) { t.Run("with deprecated build-option flag", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) @@ -261,13 +261,15 @@ func TestRestart(t *testing.T) { "--build-option", fmt.Sprintf("%s=%s", ephemeralParameterName, ephemeralParameterValue)) clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() assert.NoError(t, err) }() + ctx := testutil.Context(t, testutil.WaitShort) matches := []string{ "Restart workspace?", "yes", "Stopping workspace", "", @@ -277,18 +279,15 @@ func TestRestart(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - pty.ExpectMatch(match) + stdout.ExpectMatchContext(ctx, match) if value != "" { - pty.WriteLine(value) + stdin.WriteLine(value) } } <-doneChan // Verify if build option is set - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - workspace, err := client.WorkspaceByOwnerAndName(ctx, memberUser.ID.String(), workspace.Name, codersdk.WorkspaceOptions{}) require.NoError(t, err) actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID) @@ -349,20 +348,18 @@ func TestRestartWithParameters(t *testing.T) { inv, root := clitest.New(t, "restart", workspace.Name, "-y") clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) go func() { defer close(doneChan) err := inv.Run() assert.NoError(t, err) }() + ctx := testutil.Context(t, testutil.WaitShort) - pty.ExpectMatch("workspace has been restarted") + stdout.ExpectMatchContext(ctx, "workspace has been restarted") <-doneChan // Verify if immutable parameter is set - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - workspace, err := client.WorkspaceByOwnerAndName(ctx, workspace.OwnerName, workspace.Name, codersdk.WorkspaceOptions{}) require.NoError(t, err) actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID) @@ -376,6 +373,7 @@ func TestRestartWithParameters(t *testing.T) { t.Run("AlwaysPrompt", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) // Create the workspace client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) owner := coderdtest.CreateFirstUser(t, client) @@ -396,24 +394,23 @@ func TestRestartWithParameters(t *testing.T) { inv, root := clitest.New(t, "restart", workspace.Name, "-y", "--always-prompt") clitest.SetupConfig(t, member, root) doneChan := make(chan struct{}) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) go func() { defer close(doneChan) err := inv.Run() assert.NoError(t, err) }() + ctx := testutil.Context(t, testutil.WaitShort) // We should be prompted for the parameters again. newValue := "xyz" - pty.ExpectMatch(mutableParameterName) - pty.WriteLine(newValue) - pty.ExpectMatch("workspace has been restarted") + stdout.ExpectMatchContext(ctx, mutableParameterName) + stdin.WriteLine(newValue) + stdout.ExpectMatchContext(ctx, "workspace has been restarted") <-doneChan // Verify that the updated values are persisted. - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - workspace, err := client.WorkspaceByOwnerAndName(ctx, workspace.OwnerName, workspace.Name, codersdk.WorkspaceOptions{}) require.NoError(t, err) actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID) diff --git a/cli/schedule_test.go b/cli/schedule_test.go index ed9c5b1743..c9f61345a1 100644 --- a/cli/schedule_test.go +++ b/cli/schedule_test.go @@ -19,8 +19,8 @@ import ( "github.com/coder/coder/v2/coderd/schedule/cron" "github.com/coder/coder/v2/coderd/util/tz" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" ) // setupTestSchedule creates 4 workspaces: @@ -97,20 +97,21 @@ func TestScheduleShow(t *testing.T) { inv, root := clitest.New(t, "schedule", "show") //nolint:gocritic // Testing that owner user sees all clitest.SetupConfig(t, ownerClient, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + ctx := testutil.Context(t, testutil.WaitShort) require.NoError(t, inv.Run()) // Then: they should see their own workspaces. // 1st workspace: a-owner-ws1 has both autostart and autostop enabled. - pty.ExpectMatch(ws[0].OwnerName + "/" + ws[0].Name) - pty.ExpectMatch(sched.Humanize()) - pty.ExpectMatch(sched.Next(now).In(loc).Format(time.RFC3339)) - pty.ExpectMatch("8h") - pty.ExpectMatch(ws[0].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, ws[0].OwnerName+"/"+ws[0].Name) + stdout.ExpectMatchContext(ctx, sched.Humanize()) + stdout.ExpectMatchContext(ctx, sched.Next(now).In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, "8h") + stdout.ExpectMatchContext(ctx, ws[0].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) // 2nd workspace: b-owner-ws2 has only autostart enabled. - pty.ExpectMatch(ws[1].OwnerName + "/" + ws[1].Name) - pty.ExpectMatch(sched.Humanize()) - pty.ExpectMatch(sched.Next(now).In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, ws[1].OwnerName+"/"+ws[1].Name) + stdout.ExpectMatchContext(ctx, sched.Humanize()) + stdout.ExpectMatchContext(ctx, sched.Next(now).In(loc).Format(time.RFC3339)) }) t.Run("OwnerAll", func(t *testing.T) { @@ -118,26 +119,27 @@ func TestScheduleShow(t *testing.T) { inv, root := clitest.New(t, "schedule", "show", "--all") //nolint:gocritic // Testing that owner user sees all clitest.SetupConfig(t, ownerClient, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + ctx := testutil.Context(t, testutil.WaitShort) require.NoError(t, inv.Run()) // Then: they should see all workspaces // 1st workspace: a-owner-ws1 has both autostart and autostop enabled. - pty.ExpectMatch(ws[0].OwnerName + "/" + ws[0].Name) - pty.ExpectMatch(sched.Humanize()) - pty.ExpectMatch(sched.Next(now).In(loc).Format(time.RFC3339)) - pty.ExpectMatch("8h") - pty.ExpectMatch(ws[0].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, ws[0].OwnerName+"/"+ws[0].Name) + stdout.ExpectMatchContext(ctx, sched.Humanize()) + stdout.ExpectMatchContext(ctx, sched.Next(now).In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, "8h") + stdout.ExpectMatchContext(ctx, ws[0].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) // 2nd workspace: b-owner-ws2 has only autostart enabled. - pty.ExpectMatch(ws[1].OwnerName + "/" + ws[1].Name) - pty.ExpectMatch(sched.Humanize()) - pty.ExpectMatch(sched.Next(now).In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, ws[1].OwnerName+"/"+ws[1].Name) + stdout.ExpectMatchContext(ctx, sched.Humanize()) + stdout.ExpectMatchContext(ctx, sched.Next(now).In(loc).Format(time.RFC3339)) // 3rd workspace: c-member-ws3 has only autostop enabled. - pty.ExpectMatch(ws[2].OwnerName + "/" + ws[2].Name) - pty.ExpectMatch("8h") - pty.ExpectMatch(ws[2].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, ws[2].OwnerName+"/"+ws[2].Name) + stdout.ExpectMatchContext(ctx, "8h") + stdout.ExpectMatchContext(ctx, ws[2].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) // 4th workspace: d-member-ws4 has neither autostart nor autostop enabled. - pty.ExpectMatch(ws[3].OwnerName + "/" + ws[3].Name) + stdout.ExpectMatchContext(ctx, ws[3].OwnerName+"/"+ws[3].Name) }) t.Run("OwnerSearchByName", func(t *testing.T) { @@ -145,14 +147,15 @@ func TestScheduleShow(t *testing.T) { inv, root := clitest.New(t, "schedule", "show", "--search", "name:"+ws[1].Name) //nolint:gocritic // Testing that owner user sees all clitest.SetupConfig(t, ownerClient, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + ctx := testutil.Context(t, testutil.WaitShort) require.NoError(t, inv.Run()) // Then: they should see workspaces matching that query // 2nd workspace: b-owner-ws2 has only autostart enabled. - pty.ExpectMatch(ws[1].OwnerName + "/" + ws[1].Name) - pty.ExpectMatch(sched.Humanize()) - pty.ExpectMatch(sched.Next(now).In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, ws[1].OwnerName+"/"+ws[1].Name) + stdout.ExpectMatchContext(ctx, sched.Humanize()) + stdout.ExpectMatchContext(ctx, sched.Next(now).In(loc).Format(time.RFC3339)) }) t.Run("OwnerOneArg", func(t *testing.T) { @@ -160,37 +163,39 @@ func TestScheduleShow(t *testing.T) { inv, root := clitest.New(t, "schedule", "show", ws[2].OwnerName+"/"+ws[2].Name) //nolint:gocritic // Testing that owner user sees all clitest.SetupConfig(t, ownerClient, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + ctx := testutil.Context(t, testutil.WaitShort) require.NoError(t, inv.Run()) // Then: they should see that workspace // 3rd workspace: c-member-ws3 has only autostop enabled. - pty.ExpectMatch(ws[2].OwnerName + "/" + ws[2].Name) - pty.ExpectMatch("8h") - pty.ExpectMatch(ws[2].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, ws[2].OwnerName+"/"+ws[2].Name) + stdout.ExpectMatchContext(ctx, "8h") + stdout.ExpectMatchContext(ctx, ws[2].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) }) t.Run("MemberNoArgs", func(t *testing.T) { // When: a member specifies no args inv, root := clitest.New(t, "schedule", "show") clitest.SetupConfig(t, memberClient, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + ctx := testutil.Context(t, testutil.WaitShort) require.NoError(t, inv.Run()) // Then: they should see their own workspaces // 1st workspace: c-member-ws3 has only autostop enabled. - pty.ExpectMatch(ws[2].OwnerName + "/" + ws[2].Name) - pty.ExpectMatch("8h") - pty.ExpectMatch(ws[2].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, ws[2].OwnerName+"/"+ws[2].Name) + stdout.ExpectMatchContext(ctx, "8h") + stdout.ExpectMatchContext(ctx, ws[2].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) // 2nd workspace: d-member-ws4 has neither autostart nor autostop enabled. - pty.ExpectMatch(ws[3].OwnerName + "/" + ws[3].Name) + stdout.ExpectMatchContext(ctx, ws[3].OwnerName+"/"+ws[3].Name) }) t.Run("MemberAll", func(t *testing.T) { // When: a member lists all workspaces inv, root := clitest.New(t, "schedule", "show", "--all") clitest.SetupConfig(t, memberClient, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) ctx := testutil.Context(t, testutil.WaitShort) errC := make(chan error) go func() { @@ -200,11 +205,11 @@ func TestScheduleShow(t *testing.T) { // Then: they should only see their own // 1st workspace: c-member-ws3 has only autostop enabled. - pty.ExpectMatch(ws[2].OwnerName + "/" + ws[2].Name) - pty.ExpectMatch("8h") - pty.ExpectMatch(ws[2].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, ws[2].OwnerName+"/"+ws[2].Name) + stdout.ExpectMatchContext(ctx, "8h") + stdout.ExpectMatchContext(ctx, ws[2].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) // 2nd workspace: d-member-ws4 has neither autostart nor autostop enabled. - pty.ExpectMatch(ws[3].OwnerName + "/" + ws[3].Name) + stdout.ExpectMatchContext(ctx, ws[3].OwnerName+"/"+ws[3].Name) }) t.Run("JSON", func(t *testing.T) { @@ -276,13 +281,14 @@ func TestScheduleModify(t *testing.T) { ) //nolint:gocritic // this workspace is not owned by the same user clitest.SetupConfig(t, ownerClient, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + ctx := testutil.Context(t, testutil.WaitShort) require.NoError(t, inv.Run()) // Then: the updated schedule should be shown - pty.ExpectMatch(ws[3].OwnerName + "/" + ws[3].Name) - pty.ExpectMatch(sched.Humanize()) - pty.ExpectMatch(sched.Next(now).In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, ws[3].OwnerName+"/"+ws[3].Name) + stdout.ExpectMatchContext(ctx, sched.Humanize()) + stdout.ExpectMatchContext(ctx, sched.Next(now).In(loc).Format(time.RFC3339)) }) t.Run("SetStop", func(t *testing.T) { @@ -292,13 +298,14 @@ func TestScheduleModify(t *testing.T) { ) //nolint:gocritic // this workspace is not owned by the same user clitest.SetupConfig(t, ownerClient, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + ctx := testutil.Context(t, testutil.WaitShort) require.NoError(t, inv.Run()) // Then: the updated schedule should be shown - pty.ExpectMatch(ws[2].OwnerName + "/" + ws[2].Name) - pty.ExpectMatch("8h30m") - pty.ExpectMatch(ws[2].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, ws[2].OwnerName+"/"+ws[2].Name) + stdout.ExpectMatchContext(ctx, "8h30m") + stdout.ExpectMatchContext(ctx, ws[2].LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)) }) t.Run("UnsetStart", func(t *testing.T) { @@ -308,11 +315,12 @@ func TestScheduleModify(t *testing.T) { ) //nolint:gocritic // this workspace is owned by owner clitest.SetupConfig(t, ownerClient, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + ctx := testutil.Context(t, testutil.WaitShort) require.NoError(t, inv.Run()) // Then: the updated schedule should be shown - pty.ExpectMatch(ws[1].OwnerName + "/" + ws[1].Name) + stdout.ExpectMatchContext(ctx, ws[1].OwnerName+"/"+ws[1].Name) }) t.Run("UnsetStop", func(t *testing.T) { @@ -322,11 +330,12 @@ func TestScheduleModify(t *testing.T) { ) //nolint:gocritic // this workspace is owned by owner clitest.SetupConfig(t, ownerClient, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + ctx := testutil.Context(t, testutil.WaitShort) require.NoError(t, inv.Run()) // Then: the updated schedule should be shown - pty.ExpectMatch(ws[0].OwnerName + "/" + ws[0].Name) + stdout.ExpectMatchContext(ctx, ws[0].OwnerName+"/"+ws[0].Name) }) } @@ -359,7 +368,8 @@ func TestScheduleOverride(t *testing.T) { ) clitest.SetupConfig(t, ownerClient, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + ctx := testutil.Context(t, testutil.WaitShort) require.NoError(t, inv.Run()) // Fetch the workspace to get the actual deadline set by the @@ -376,11 +386,11 @@ func TestScheduleOverride(t *testing.T) { expectedDeadline := updated.LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339) // Then: the updated schedule should be shown - pty.ExpectMatch(ws[0].OwnerName + "/" + ws[0].Name) - pty.ExpectMatch(sched.Humanize()) - pty.ExpectMatch(sched.Next(now).In(loc).Format(time.RFC3339)) - pty.ExpectMatch("8h") - pty.ExpectMatch(expectedDeadline) + stdout.ExpectMatchContext(ctx, ws[0].OwnerName+"/"+ws[0].Name) + stdout.ExpectMatchContext(ctx, sched.Humanize()) + stdout.ExpectMatchContext(ctx, sched.Next(now).In(loc).Format(time.RFC3339)) + stdout.ExpectMatchContext(ctx, "8h") + stdout.ExpectMatchContext(ctx, expectedDeadline) }) } } @@ -422,13 +432,14 @@ func TestScheduleStart_TemplateAutostartRequirement(t *testing.T) { "schedule", "start", workspace.Name, "9:30AM", "Mon-Fri", ) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + ctx := testutil.Context(t, testutil.WaitShort) require.NoError(t, inv.Run()) // Then: warning should be shown // In AGPL, this will show all days (enterprise feature defaults to all days allowed) - pty.ExpectMatch("Warning") - pty.ExpectMatch("may only autostart") + stdout.ExpectMatchContext(ctx, "Warning") + stdout.ExpectMatchContext(ctx, "may only autostart") }) t.Run("NoWarningWhenManual", func(t *testing.T) { diff --git a/cli/secret_test.go b/cli/secret_test.go index 3cbb6b89b8..06224d45c6 100644 --- a/cli/secret_test.go +++ b/cli/secret_test.go @@ -14,8 +14,8 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" ) func TestSecretCreate(t *testing.T) { @@ -501,6 +501,7 @@ func TestSecretDelete(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) @@ -516,12 +517,13 @@ func TestSecretDelete(t *testing.T) { ctx := testutil.Context(t, testutil.WaitMedium) inv = inv.WithContext(ctx) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) waiter := clitest.StartWithWaiter(t, inv) - pty.ExpectMatchContext(ctx, "Delete secret") - pty.ExpectMatchContext(ctx, "service-token") - pty.WriteLine("yes") - pty.ExpectMatchContext(ctx, "Deleted secret") + stdout.ExpectMatchContext(ctx, "Delete secret") + stdout.ExpectMatchContext(ctx, "service-token") + stdin.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "Deleted secret") require.NoError(t, waiter.Wait()) @@ -566,6 +568,7 @@ func TestSecretDelete(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() + logger := testutil.Logger(t) client := coderdtest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) @@ -574,11 +577,12 @@ func TestSecretDelete(t *testing.T) { ctx := testutil.Context(t, testutil.WaitMedium) inv = inv.WithContext(ctx) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) waiter := clitest.StartWithWaiter(t, inv) - pty.ExpectMatchContext(ctx, "Delete secret") - pty.ExpectMatchContext(ctx, "missing-secret") - pty.WriteLine("yes") + stdout.ExpectMatchContext(ctx, "Delete secret") + stdout.ExpectMatchContext(ctx, "missing-secret") + stdin.WriteLine("yes") err := waiter.Wait() require.ErrorContains(t, err, `delete secret "missing-secret"`) diff --git a/cli/server.go b/cli/server.go index 3e3dd0c643..b2fa89fd3b 100644 --- a/cli/server.go +++ b/cli/server.go @@ -56,7 +56,6 @@ import ( "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/sloghuman" - "github.com/coder/coder/v2/aibridge" "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/cli/clilog" "github.com/coder/coder/v2/cli/cliui" @@ -1042,7 +1041,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. // unconditionally when the bridge feature is enabled by config so // chatd can use it regardless of license entitlement. if vals.AI.BridgeConfig.Enabled.Value() { - aibridgeProviders, err := BuildProviders(aibridgeInitCtx, options.Database, vals.AI.BridgeConfig, logger.Named("aibridge.providers")) + aibridgeProviders, _, err := BuildProviders(aibridgeInitCtx, options.Database, vals.AI.BridgeConfig, logger.Named("aibridge.providers")) if err != nil { return xerrors.Errorf("build AI providers: %w", err) } @@ -3008,11 +3007,10 @@ func ReadAIProvidersFromEnv(logger slog.Logger, environ []string) ([]codersdk.AI return nil, xerrors.Errorf("provider %d: TYPE is required", i) } - switch p.Type { - case aibridge.ProviderOpenAI, aibridge.ProviderAnthropic, aibridge.ProviderCopilot: - default: - return nil, xerrors.Errorf("provider %d: unknown TYPE %q (must be %s, %s, or %s)", - i, p.Type, aibridge.ProviderOpenAI, aibridge.ProviderAnthropic, aibridge.ProviderCopilot) + providerType := database.AIProviderType(p.Type) + if !providerType.Valid() { + return nil, xerrors.Errorf("provider %d: unknown TYPE %q (must be one of: %v)", + i, p.Type, database.AllAIProviderTypeValues()) } var bedrockKey, bedrockSecret string @@ -3028,21 +3026,36 @@ func ReadAIProvidersFromEnv(logger slog.Logger, environ []string) ([]codersdk.AI ) isBedrock := codersdk.IsBedrockConfigured(p.BedrockBaseURL, settings) - if p.Type != aibridge.ProviderAnthropic && isBedrock { - return nil, xerrors.Errorf("provider %d (%s): BEDROCK_* fields are only supported with TYPE %q", - i, p.Type, aibridge.ProviderAnthropic) + // BEDROCK_* fields are accepted on anthropic (mutually exclusive + // with KEYS) and required on bedrock. Any other TYPE rejecting + // them prevents silently-ignored credentials. + isBedrockType := providerType == database.AiProviderTypeBedrock + isAnthropicType := providerType == database.AiProviderTypeAnthropic + if !isAnthropicType && !isBedrockType && isBedrock { + return nil, xerrors.Errorf("provider %d (%s): BEDROCK_* fields are only supported with TYPE %q or %q", + i, p.Type, database.AiProviderTypeAnthropic, database.AiProviderTypeBedrock) } - if p.Type == aibridge.ProviderCopilot && len(p.Keys) > 0 { + if isBedrockType && !isBedrock { + return nil, xerrors.Errorf("provider %d (%s): TYPE %q requires BEDROCK_* fields to be configured", + i, p.Type, database.AiProviderTypeBedrock) + } + + if isBedrockType && len(p.Keys) > 0 { + return nil, xerrors.Errorf("provider %d (%s): KEY/KEYS are not supported for TYPE %q (use BEDROCK_* fields)", + i, p.Type, database.AiProviderTypeBedrock) + } + + if providerType == database.AiProviderTypeCopilot && len(p.Keys) > 0 { return nil, xerrors.Errorf("provider %d (%s): KEY/KEYS are not supported for TYPE %q", - i, p.Type, aibridge.ProviderCopilot) + i, p.Type, database.AiProviderTypeCopilot) } // An Anthropic provider authenticates either via a bearer // token (KEYS) or via Bedrock (BEDROCK_*), not both. Surface // the conflict here so misconfigured deployments fail before // any DB work happens at server startup. - if p.Type == aibridge.ProviderAnthropic && len(p.Keys) > 0 && isBedrock { + if isAnthropicType && len(p.Keys) > 0 && isBedrock { return nil, xerrors.Errorf("provider %d (%s): KEY/KEYS and BEDROCK_* fields are mutually exclusive", i, p.Type) } diff --git a/cli/server_aibridge_internal_test.go b/cli/server_aibridge_internal_test.go index 1797f1c7ed..a91e5b51d2 100644 --- a/cli/server_aibridge_internal_test.go +++ b/cli/server_aibridge_internal_test.go @@ -1,6 +1,8 @@ package cli import ( + "database/sql" + "encoding/json" "fmt" "testing" @@ -11,6 +13,7 @@ import ( "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/coder/v2/aibridge" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" "github.com/coder/serpent" @@ -362,6 +365,40 @@ func TestReadAIProvidersFromEnv(t *testing.T) { }, errContains: "cannot mix CODER_AIBRIDGE_PROVIDER_* and CODER_AI_GATEWAY_PROVIDER_* environment variables", }, + { + name: "BedrockTypeHappyPath", + env: []string{ + "CODER_AIBRIDGE_PROVIDER_0_TYPE=bedrock", + "CODER_AIBRIDGE_PROVIDER_0_NAME=bedrock-prod", + "CODER_AIBRIDGE_PROVIDER_0_BEDROCK_REGION=us-east-1", + "CODER_AIBRIDGE_PROVIDER_0_BEDROCK_ACCESS_KEY=AKID", + "CODER_AIBRIDGE_PROVIDER_0_BEDROCK_ACCESS_KEY_SECRET=secret", + }, + expected: []codersdk.AIProviderConfig{ + { + Type: string(database.AiProviderTypeBedrock), + Name: "bedrock-prod", + BedrockRegion: "us-east-1", + BedrockAccessKeys: []string{"AKID"}, + BedrockAccessKeySecrets: []string{"secret"}, + }, + }, + }, + { + name: "BedrockTypeWithoutBedrockFields", + env: []string{"CODER_AIBRIDGE_PROVIDER_0_TYPE=bedrock", "CODER_AIBRIDGE_PROVIDER_0_NAME=bedrock-prod"}, + errContains: "requires BEDROCK_* fields to be configured", + }, + { + name: "BedrockTypeRejectsAPIKeys", + env: []string{ + "CODER_AIBRIDGE_PROVIDER_0_TYPE=bedrock", + "CODER_AIBRIDGE_PROVIDER_0_NAME=bedrock-prod", + "CODER_AIBRIDGE_PROVIDER_0_BEDROCK_REGION=us-east-1", + "CODER_AIBRIDGE_PROVIDER_0_KEY=sk-should-fail", + }, + errContains: "KEY/KEYS are not supported for TYPE", + }, { name: "BedrockKeysTooMany", env: []string{ @@ -544,32 +581,106 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) { const dumpDir = "/tmp/coder-aibridge-dumps" tests := []struct { - name string - row database.AIProvider + name string + row database.AIProvider + expectedType string }{ { name: "OpenAI", row: database.AIProvider{ + Enabled: true, Type: database.AiProviderTypeOpenai, Name: "openai", BaseUrl: "https://api.openai.com/", }, + expectedType: aibridge.ProviderOpenAI, }, { name: "Anthropic", row: database.AIProvider{ + Enabled: true, Type: database.AiProviderTypeAnthropic, Name: "anthropic", BaseUrl: "https://api.anthropic.com/", }, + expectedType: aibridge.ProviderAnthropic, }, { name: "Copilot", row: database.AIProvider{ + Enabled: true, Type: database.AiProviderTypeCopilot, Name: "copilot", BaseUrl: "https://api.githubcopilot.com/", }, + expectedType: aibridge.ProviderCopilot, + }, + { + name: "Azure", + row: database.AIProvider{ + Enabled: true, + Type: database.AiProviderTypeAzure, + Name: "azure", + BaseUrl: "https://example.openai.azure.com/", + }, + expectedType: aibridge.ProviderOpenAI, + }, + { + name: "Google", + row: database.AIProvider{ + Enabled: true, + Type: database.AiProviderTypeGoogle, + Name: "google", + BaseUrl: "https://generativelanguage.googleapis.com/v1beta/openai/", + }, + expectedType: aibridge.ProviderOpenAI, + }, + { + name: "OpenAICompat", + row: database.AIProvider{ + Enabled: true, + Type: database.AiProviderTypeOpenaiCompat, + Name: "openai-compat", + BaseUrl: "https://compat.example.com/v1/", + }, + expectedType: aibridge.ProviderOpenAI, + }, + { + name: "OpenRouter", + row: database.AIProvider{ + Enabled: true, + Type: database.AiProviderTypeOpenrouter, + Name: "openrouter", + BaseUrl: "https://openrouter.ai/api/v1/", + }, + expectedType: aibridge.ProviderOpenAI, + }, + { + name: "Vercel", + row: database.AIProvider{ + Enabled: true, + Type: database.AiProviderTypeVercel, + Name: "vercel", + BaseUrl: "https://api.v0.dev/v1/", + }, + expectedType: aibridge.ProviderOpenAI, + }, + { + name: "Bedrock", + row: database.AIProvider{ + Enabled: true, + Type: database.AiProviderTypeBedrock, + Name: "bedrock", + BaseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com/", + Settings: mustMarshalSettings(codersdk.AIProviderSettings{ + Bedrock: &codersdk.AIProviderBedrockSettings{ + Region: "us-east-1", + AccessKey: ptr.Ref("AKID"), + AccessKeySecret: ptr.Ref("secret"), + }, + }), + }, + expectedType: aibridge.ProviderAnthropic, }, } @@ -583,6 +694,30 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) { }) require.NoError(t, err) assert.Equal(t, dumpDir, provider.APIDumpDir()) + assert.Equal(t, tt.expectedType, provider.Type()) }) } } + +func TestBuildAIProviderFromRowBedrockWithoutSettings(t *testing.T) { + t.Parallel() + + _, err := buildAIProviderFromRow(database.AIProvider{ + Enabled: true, + Type: database.AiProviderTypeBedrock, + Name: "bedrock-no-settings", + BaseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com/", + }, nil, codersdk.AIBridgeConfig{ + AllowBYOK: serpent.Bool(true), + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "bedrock provider has no bedrock credentials configured") +} + +func mustMarshalSettings(s codersdk.AIProviderSettings) sql.NullString { + data, err := json.Marshal(s) + if err != nil { + panic(err) + } + return sql.NullString{String: string(data), Valid: true} +} diff --git a/cli/server_createadminuser_test.go b/cli/server_createadminuser_test.go index c0883a2d27..d0eef5f72d 100644 --- a/cli/server_createadminuser_test.go +++ b/cli/server_createadminuser_test.go @@ -21,6 +21,7 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" ) //nolint:paralleltest, tparallel @@ -128,19 +129,17 @@ func TestServerCreateAdminUser(t *testing.T) { "--email", email, "--password", password, ) - pty := ptytest.New(t) - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) clitest.Start(t, inv) - pty.ExpectMatchContext(ctx, "Creating user...") - pty.ExpectMatchContext(ctx, "Generating user SSH key...") - pty.ExpectMatchContext(ctx, fmt.Sprintf("Adding user to organization %q (%s) as admin...", org1Name, org1ID.String())) - pty.ExpectMatchContext(ctx, fmt.Sprintf("Adding user to organization %q (%s) as admin...", org2Name, org2ID.String())) - pty.ExpectMatchContext(ctx, "User created successfully.") - pty.ExpectMatchContext(ctx, username) - pty.ExpectMatchContext(ctx, email) - pty.ExpectMatchContext(ctx, "****") + stdout.ExpectMatchContext(ctx, "Creating user...") + stdout.ExpectMatchContext(ctx, "Generating user SSH key...") + stdout.ExpectMatchContext(ctx, fmt.Sprintf("Adding user to organization %q (%s) as admin...", org1Name, org1ID.String())) + stdout.ExpectMatchContext(ctx, fmt.Sprintf("Adding user to organization %q (%s) as admin...", org2Name, org2ID.String())) + stdout.ExpectMatchContext(ctx, "User created successfully.") + stdout.ExpectMatchContext(ctx, username) + stdout.ExpectMatchContext(ctx, email) + stdout.ExpectMatchContext(ctx, "****") verifyUser(t, connectionURL, username, email, password) }) @@ -184,6 +183,7 @@ func TestServerCreateAdminUser(t *testing.T) { // Skip on non-Linux because it spawns a PostgreSQL instance. t.SkipNow() } + logger := testutil.Logger(t) connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) @@ -195,23 +195,24 @@ func TestServerCreateAdminUser(t *testing.T) { "--postgres-url", connectionURL, "--ssh-keygen-algorithm", "ed25519", ) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) + stdin := testutil.NewWriterAttachedToInvocation(t, logger.Named("stdin"), inv) clitest.Start(t, inv) - pty.ExpectMatchContext(ctx, "Username") - pty.WriteLine(username) - pty.ExpectMatchContext(ctx, "Email") - pty.WriteLine(email) - pty.ExpectMatchContext(ctx, "Password") - pty.WriteLine(password) - pty.ExpectMatchContext(ctx, "Confirm password") - pty.WriteLine(password) + stdout.ExpectMatchContext(ctx, "Username") + stdin.WriteLine(username) + stdout.ExpectMatchContext(ctx, "Email") + stdin.WriteLine(email) + stdout.ExpectMatchContext(ctx, "Password") + stdin.WriteLine(password) + stdout.ExpectMatchContext(ctx, "Confirm password") + stdin.WriteLine(password) - pty.ExpectMatchContext(ctx, "User created successfully.") - pty.ExpectMatchContext(ctx, username) - pty.ExpectMatchContext(ctx, email) - pty.ExpectMatchContext(ctx, "****") + stdout.ExpectMatchContext(ctx, "User created successfully.") + stdout.ExpectMatchContext(ctx, username) + stdout.ExpectMatchContext(ctx, email) + stdout.ExpectMatchContext(ctx, "****") verifyUser(t, connectionURL, username, email, password) }) diff --git a/cli/server_test.go b/cli/server_test.go index 89e0ba7048..6776e84424 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -59,6 +59,7 @@ import ( "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" "github.com/coder/serpent" ) @@ -229,7 +230,7 @@ func TestServer(t *testing.T) { "--access-url", "http://example.com", "--ephemeral", ) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) // Embedded postgres takes a while to fire up. const superDuperLong = testutil.WaitSuperLong * 3 @@ -240,7 +241,7 @@ func TestServer(t *testing.T) { }() matchCh1 := make(chan string, 1) go func() { - matchCh1 <- pty.ExpectMatchContext(ctx, "Using an ephemeral deployment directory") + matchCh1 <- stdout.ExpectMatchContext(ctx, "Using an ephemeral deployment directory") }() select { case err := <-errCh: @@ -248,7 +249,7 @@ func TestServer(t *testing.T) { case <-matchCh1: // OK! } - rootDirLine := pty.ReadLine(ctx) + rootDirLine := stdout.ReadLine(ctx) rootDir := strings.TrimPrefix(rootDirLine, "Using an ephemeral deployment directory") rootDir = strings.TrimSpace(rootDir) rootDir = strings.TrimPrefix(rootDir, "(") @@ -259,7 +260,7 @@ func TestServer(t *testing.T) { matchCh2 := make(chan string, 1) go func() { // The "View the Web UI" log is a decent indicator that the server was successfully started. - matchCh2 <- pty.ExpectMatchContext(ctx, "View the Web UI") + matchCh2 <- stdout.ExpectMatchContext(ctx, "View the Web UI") }() select { case err := <-errCh: @@ -276,24 +277,23 @@ func TestServer(t *testing.T) { t.Run("BuiltinPostgresURL", func(t *testing.T) { t.Parallel() root, _ := clitest.New(t, "server", "postgres-builtin-url") - pty := ptytest.New(t) - root.Stdout = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, root) + ctx := testutil.Context(t, testutil.WaitShort) err := root.Run() require.NoError(t, err) - pty.ExpectMatch("psql") + stdout.ExpectMatchContext(ctx, "psql") }) t.Run("BuiltinPostgresURLRaw", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url") - pty := ptytest.New(t) - root.Stdout = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, root) err := root.WithContext(ctx).Run() require.NoError(t, err) - got := pty.ReadLine(ctx) + got := stdout.ReadLine(ctx) if !strings.HasPrefix(got, "postgres://") { t.Fatalf("expected postgres URL to start with \"postgres://\", got %q", got) } @@ -506,6 +506,7 @@ func TestServer(t *testing.T) { // reachable. t.Run("LocalAccessURL", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) inv, cfg := clitest.New(t, "server", dbArg(t), @@ -513,7 +514,7 @@ func TestServer(t *testing.T) { "--access-url", "http://localhost:3000/", "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) // Since we end the test after seeing the log lines about the access url, we could cancel the test before // our initial interactions with PostgreSQL are complete. So, ignore errors of that type for this test. startIgnoringPostgresQueryCancel(t, inv) @@ -521,9 +522,9 @@ func TestServer(t *testing.T) { // Just wait for startup _ = waitAccessURL(t, cfg) - pty.ExpectMatch("this may cause unexpected problems when creating workspaces") - pty.ExpectMatch("View the Web UI:") - pty.ExpectMatch("http://localhost:3000/") + stdout.ExpectMatchContext(ctx, "this may cause unexpected problems when creating workspaces") + stdout.ExpectMatchContext(ctx, "View the Web UI:") + stdout.ExpectMatchContext(ctx, "http://localhost:3000/") }) // Validate that an https scheme is prepended to a remote access URL @@ -531,6 +532,7 @@ func TestServer(t *testing.T) { t.Run("RemoteAccessURL", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) inv, cfg := clitest.New(t, "server", dbArg(t), @@ -538,7 +540,7 @@ func TestServer(t *testing.T) { "--access-url", "https://foobarbaz.mydomain", "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) // Since we end the test after seeing the log lines about the access url, we could cancel the test before // our initial interactions with PostgreSQL are complete. So, ignore errors of that type for this test. @@ -547,13 +549,14 @@ func TestServer(t *testing.T) { // Just wait for startup _ = waitAccessURL(t, cfg) - pty.ExpectMatch("this may cause unexpected problems when creating workspaces") - pty.ExpectMatch("View the Web UI:") - pty.ExpectMatch("https://foobarbaz.mydomain") + stdout.ExpectMatchContext(ctx, "this may cause unexpected problems when creating workspaces") + stdout.ExpectMatchContext(ctx, "View the Web UI:") + stdout.ExpectMatchContext(ctx, "https://foobarbaz.mydomain") }) t.Run("NoWarningWithRemoteAccessURL", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) inv, cfg := clitest.New(t, "server", dbArg(t), @@ -561,7 +564,7 @@ func TestServer(t *testing.T) { "--access-url", "https://google.com", "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) // Since we end the test after seeing the log lines about the access url, we could cancel the test before // our initial interactions with PostgreSQL are complete. So, ignore errors of that type for this test. startIgnoringPostgresQueryCancel(t, inv) @@ -569,8 +572,8 @@ func TestServer(t *testing.T) { // Just wait for startup _ = waitAccessURL(t, cfg) - pty.ExpectMatch("View the Web UI:") - pty.ExpectMatch("https://google.com") + stdout.ExpectMatchContext(ctx, "View the Web UI:") + stdout.ExpectMatchContext(ctx, "https://google.com") }) t.Run("NoSchemeAccessURL", func(t *testing.T) { @@ -735,8 +738,6 @@ func TestServer(t *testing.T) { "--tls-key-file", key2Path, "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t) - root.Stdout = pty.Output() clitest.Start(t, root.WithContext(ctx)) accessURL := waitAccessURL(t, cfg) @@ -814,18 +815,18 @@ func TestServer(t *testing.T) { "--tls-key-file", keyPath, "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) clitest.Start(t, inv) // We can't use waitAccessURL as it will only return the HTTP URL. const httpLinePrefix = "Started HTTP listener at" - pty.ExpectMatch(httpLinePrefix) - httpLine := pty.ReadLine(ctx) + stdout.ExpectMatchContext(ctx, httpLinePrefix) + httpLine := stdout.ReadLine(ctx) httpAddr := strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix)) require.NotEmpty(t, httpAddr) const tlsLinePrefix = "Started TLS/HTTPS listener at " - pty.ExpectMatch(tlsLinePrefix) - tlsLine := pty.ReadLine(ctx) + stdout.ExpectMatchContext(ctx, tlsLinePrefix) + tlsLine := stdout.ReadLine(ctx) tlsAddr := strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix)) require.NotEmpty(t, tlsAddr) @@ -951,8 +952,7 @@ func TestServer(t *testing.T) { } inv, _ := clitest.New(t, flags...) - pty := ptytest.New(t) - pty.Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) clitest.Start(t, inv) @@ -963,15 +963,15 @@ func TestServer(t *testing.T) { // We can't use waitAccessURL as it will only return the HTTP URL. if c.httpListener { const httpLinePrefix = "Started HTTP listener at" - pty.ExpectMatch(httpLinePrefix) - httpLine := pty.ReadLine(ctx) + stdout.ExpectMatchContext(ctx, httpLinePrefix) + httpLine := stdout.ReadLine(ctx) httpAddr = strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix)) require.NotEmpty(t, httpAddr) } if c.tlsListener { const tlsLinePrefix = "Started TLS/HTTPS listener at" - pty.ExpectMatch(tlsLinePrefix) - tlsLine := pty.ReadLine(ctx) + stdout.ExpectMatchContext(ctx, tlsLinePrefix) + tlsLine := stdout.ReadLine(ctx) tlsAddr = strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix)) require.NotEmpty(t, tlsAddr) } @@ -1041,6 +1041,7 @@ func TestServer(t *testing.T) { t.Run("CanListenUnspecifiedv4", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) inv, _ := clitest.New(t, "server", dbArg(t), @@ -1048,18 +1049,19 @@ func TestServer(t *testing.T) { "--access-url", "http://example.com", ) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) // Since we end the test after seeing the log lines about the HTTP listener, we could cancel the test before // our initial interactions with PostgreSQL are complete. So, ignore errors of that type for this test. startIgnoringPostgresQueryCancel(t, inv) - pty.ExpectMatch("Started HTTP listener") - pty.ExpectMatch("http://0.0.0.0:") + stdout.ExpectMatchContext(ctx, "Started HTTP listener") + stdout.ExpectMatchContext(ctx, "http://0.0.0.0:") }) t.Run("CanListenUnspecifiedv6", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) inv, _ := clitest.New(t, "server", dbArg(t), @@ -1067,13 +1069,13 @@ func TestServer(t *testing.T) { "--access-url", "http://example.com", ) - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) // Since we end the test after seeing the log lines about the HTTP listener, we could cancel the test before // our initial interactions with PostgreSQL are complete. So, ignore errors of that type for this test. startIgnoringPostgresQueryCancel(t, inv) - pty.ExpectMatch("Started HTTP listener at") - pty.ExpectMatch("http://[::]:") + stdout.ExpectMatchContext(ctx, "Started HTTP listener at") + stdout.ExpectMatchContext(ctx, "http://[::]:") }) t.Run("NoAddress", func(t *testing.T) { @@ -1128,12 +1130,10 @@ func TestServer(t *testing.T) { "--access-url", "http://example.com", "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t) - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) clitest.Start(t, inv.WithContext(ctx)) - pty.ExpectMatch("is deprecated") + stdout.ExpectMatchContext(ctx, "is deprecated") accessURL := waitAccessURL(t, cfg) require.Equal(t, "http", accessURL.Scheme) @@ -1158,12 +1158,10 @@ func TestServer(t *testing.T) { "--tls-key-file", keyPath, "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t) - root.Stdout = pty.Output() - root.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, root) clitest.Start(t, root.WithContext(ctx)) - pty.ExpectMatch("is deprecated") + stdout.ExpectMatchContext(ctx, "is deprecated") accessURL := waitAccessURL(t, cfg) require.Equal(t, "https", accessURL.Scheme) @@ -1259,15 +1257,13 @@ func TestServer(t *testing.T) { "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t) - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) clitest.Start(t, inv) // Wait until we see the prometheus address in the logs. addrMatchExpr := `http server listening\s+addr=(\S+)\s+name=prometheus` - lineMatch := pty.ExpectRegexMatchContext(ctx, addrMatchExpr) + lineMatch := stdout.ExpectRegexMatchContext(ctx, addrMatchExpr) promAddr := regexp.MustCompile(addrMatchExpr).FindStringSubmatch(lineMatch)[1] testutil.Eventually(ctx, t, func(ctx context.Context) bool { @@ -1322,15 +1318,13 @@ func TestServer(t *testing.T) { "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t) - inv.Stdout = pty.Output() - inv.Stderr = pty.Output() + stdout := expecter.NewAttachedToInvocation(t, inv) clitest.Start(t, inv) // Wait until we see the prometheus address in the logs. addrMatchExpr := `http server listening\s+addr=(\S+)\s+name=prometheus` - lineMatch := pty.ExpectRegexMatchContext(ctx, addrMatchExpr) + lineMatch := stdout.ExpectRegexMatchContext(ctx, addrMatchExpr) promAddr := regexp.MustCompile(addrMatchExpr).FindStringSubmatch(lineMatch)[1] testutil.Eventually(ctx, t, func(ctx context.Context) bool { @@ -1751,7 +1745,6 @@ func TestServer(t *testing.T) { inv, cfg := clitest.New(t, args..., ) - ptytest.New(t).Attach(inv) inv = inv.WithContext(ctx) w := clitest.StartWithWaiter(t, inv) gotURL := waitAccessURL(t, cfg) @@ -2019,15 +2012,15 @@ func TestServer_Logging_NoParallel(t *testing.T) { "--provisioner-types=echo", "--log-stackdriver", fi, ) - // Attach pty so we get debug output from the command if this test + // Attach expecter so we get debug output from the command if this test // fails. - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) startIgnoringPostgresQueryCancel(t, inv.WithContext(ctx)) // Wait for server to listen on HTTP, this is a good // starting point for expecting logs. - _ = pty.ExpectMatchContext(ctx, "Started HTTP listener at") + _ = stdout.ExpectMatchContext(ctx, "Started HTTP listener at") loggingWaitFile(t, fi, testutil.WaitSuperLong) }) @@ -2056,15 +2049,15 @@ func TestServer_Logging_NoParallel(t *testing.T) { "--log-json", fi2, "--log-stackdriver", fi3, ) - // Attach pty so we get debug output from the command if this test + // Attach expecter so we get debug output from the command if this test // fails. - pty := ptytest.New(t).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) startIgnoringPostgresQueryCancel(t, inv) // Wait for server to listen on HTTP, this is a good // starting point for expecting logs. - _ = pty.ExpectMatchContext(ctx, "Started HTTP listener at") + _ = stdout.ExpectMatchContext(ctx, "Started HTTP listener at") loggingWaitFile(t, fi1, testutil.WaitSuperLong) loggingWaitFile(t, fi2, testutil.WaitSuperLong) @@ -2258,7 +2251,7 @@ func TestServer_GracefulShutdown(t *testing.T) { return ctx, stopFunc }) serverErr := make(chan error, 1) - pty := ptytest.New(t).Attach(root) + stdout := expecter.NewAttachedToInvocation(t, root) go func() { serverErr <- root.WithContext(ctx).Run() }() @@ -2266,7 +2259,7 @@ func TestServer_GracefulShutdown(t *testing.T) { // It's fair to assume `stopFunc` isn't nil here, because the server // has started and access URL is propagated. stopFunc() - pty.ExpectMatch("waiting for provisioner jobs to complete") + stdout.ExpectMatchContext(ctx, "waiting for provisioner jobs to complete") err := <-serverErr require.NoError(t, err) } @@ -2501,19 +2494,19 @@ func TestServer_TelemetryDisabled_FinalReport(t *testing.T) { inv.Logger = inv.Logger.Named(opts.name) errChan := make(chan error, 1) - pty := ptytest.New(t).Named(opts.name).Attach(inv) + stdout := expecter.NewAttachedToInvocation(t, inv) go func() { errChan <- inv.WithContext(ctx).Run() // close the pty here so that we can start tearing down resources. This test creates multiple servers with // associated ptys. There is a `t.Cleanup()` that does this, but it waits until the whole test is complete. - _ = pty.Close() + stdout.Close("invocation complete") }() if opts.waitForSnapshot { - pty.ExpectMatchContext(testutil.Context(t, testutil.WaitLong), "submitted snapshot") + stdout.ExpectMatchContext(testutil.Context(t, testutil.WaitLong), "submitted snapshot") } if opts.waitForTelemetryDisabledCheck { - pty.ExpectMatchContext(testutil.Context(t, testutil.WaitLong), "finished telemetry status check") + stdout.ExpectMatchContext(testutil.Context(t, testutil.WaitLong), "finished telemetry status check") } return errChan, cancelFunc } diff --git a/cli/task_send_test.go b/cli/task_send_test.go index e545da80d1..1590bcab29 100644 --- a/cli/task_send_test.go +++ b/cli/task_send_test.go @@ -237,7 +237,10 @@ func Test_TaskSend(t *testing.T) { t.Parallel() // Given: An initializing task (workspace running, no agent - // connected). + // connected). Close the agent, pause, then resume so the + // workspace is started but no agent is connected. The + // command enters waitForTaskIdle directly (initializing + // path), where we verify it handles an external pause. setupCtx := testutil.Context(t, testutil.WaitLong) setup := setupCLITaskTest(setupCtx, t, nil) @@ -245,8 +248,13 @@ func Test_TaskSend(t *testing.T) { pauseTask(setupCtx, t, setup.userClient, setup.task) resumeTask(setupCtx, t, setup.userClient, setup.task) + // Set up mock clock and traps before starting the command. + mClock := quartz.NewMock(t) + tickTrap := mClock.Trap().NewTicker("task_send", "poll") + resetTrap := mClock.Trap().TickerReset("task_send", "poll") + // When: We attempt to send input to the initializing task. - inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input") + inv, root := clitest.NewWithClock(t, mClock, "task", "send", setup.task.Name, "some task input") clitest.SetupConfig(t, setup.userClient, root) ctx := testutil.Context(t, testutil.WaitLong) @@ -259,11 +267,34 @@ func Test_TaskSend(t *testing.T) { // of waitForTaskIdle. pty.ExpectMatchContext(ctx, "Waiting for task to become idle") - // Pause the task while waitForTaskIdle is polling. Since - // no agent is connected, the task stays initializing until - // we pause it, at which point the status becomes paused. + // Wait for ticker creation and release it. + tickCall := tickTrap.MustWait(ctx) + tickCall.MustRelease(ctx) + tickTrap.Close() + + // Fire the first poll. The goroutine calls ticker.Reset + // which the trap catches, freezing the goroutine BEFORE + // client.TaskByID runs. Release it so the first poll + // sees 'initializing' and continues. + mClock.Advance(time.Nanosecond).MustWait(ctx) + resetCall := resetTrap.MustWait(ctx) + resetCall.MustRelease(ctx) + + // Fire the second poll. The goroutine is again frozen at + // ticker.Reset by the trap. + mClock.Advance(5 * time.Second).MustWait(ctx) + resetCall = resetTrap.MustWait(ctx) + + // While the goroutine is frozen (before client.TaskByID), + // pause the task. The stop build completes, so the DB has + // (stop, succeeded) = 'paused'. pauseTask(ctx, t, setup.userClient, setup.task) + // Release the trap. The goroutine unfreezes and + // client.TaskByID deterministically sees 'paused'. + resetCall.MustRelease(ctx) + resetTrap.Close() + // Then: The command should fail because the task was paused. err := w.Wait() require.Error(t, err) @@ -303,23 +334,31 @@ func Test_TaskSend(t *testing.T) { tickCall.MustRelease(ctx) tickTrap.Close() - // Fire the immediate first poll (time.Nanosecond initial interval). + // Fire the first poll. The goroutine calls ticker.Reset + // which the trap catches, freezing the goroutine BEFORE + // client.TaskByID runs. Release it so the first poll + // sees "working" and continues. mClock.Advance(time.Nanosecond).MustWait(ctx) - - // Wait for Reset (confirms first poll completed and saw "working"). resetCall := resetTrap.MustWait(ctx) resetCall.MustRelease(ctx) - resetTrap.Close() - // Transition the app back to idle so waitForTaskIdle proceeds. + // Fire the second poll. The goroutine is again frozen + // at ticker.Reset by the trap. + mClock.Advance(5 * time.Second).MustWait(ctx) + resetCall = resetTrap.MustWait(ctx) + + // While the goroutine is frozen (before client.TaskByID), + // transition the app to idle. require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ AppSlug: "task-sidebar", State: codersdk.WorkspaceAppStatusStateIdle, Message: "ready", })) - // Fire second poll at the regular 5s interval. - mClock.Advance(5 * time.Second).MustWait(ctx) + // Release the trap. The goroutine unfreezes and + // client.TaskByID deterministically sees "idle". + resetCall.MustRelease(ctx) + resetTrap.Close() // Then: The command should complete successfully. require.NoError(t, w.Wait()) diff --git a/coderd/ai_providers.go b/coderd/ai_providers.go index dd8e4c00d3..78ca50ecc9 100644 --- a/coderd/ai_providers.go +++ b/coderd/ai_providers.go @@ -320,10 +320,13 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) { if req.Settings != nil { existing = mergeAIProviderSettings(existing, *req.Settings) } - // Bedrock settings are only meaningful for anthropic-typed - // providers; rejecting the mismatch keeps a misconfiguration - // from sitting silently in the encrypted blob. - if existing.Bedrock != nil && old.Type != database.AiProviderTypeAnthropic { + // Bedrock settings are only meaningful for anthropic- or + // bedrock-typed providers; rejecting the mismatch keeps a + // misconfiguration from sitting silently in the encrypted + // blob. + if existing.Bedrock != nil && + old.Type != database.AiProviderTypeAnthropic && + old.Type != database.AiProviderTypeBedrock { return errAIProviderBedrockTypeMismatch } settings, err := encodeAIProviderSettings(existing) @@ -382,7 +385,7 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) { } if errors.Is(err, errAIProviderBedrockTypeMismatch) { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Bedrock settings are only valid for type=anthropic.", + Message: "Bedrock settings are only valid for type=anthropic or type=bedrock.", }) return } @@ -482,9 +485,9 @@ var errBedrockRejectsAPIKeys = xerrors.New("bedrock providers do not accept api_ // errAIProviderBedrockTypeMismatch is the sentinel returned from // inside the update transaction when the post-merge settings carry a -// Bedrock block but the provider is not anthropic-typed; the outer -// handler translates it into a 400. -var errAIProviderBedrockTypeMismatch = xerrors.New("bedrock settings are only valid for type=anthropic") +// Bedrock block but the provider is not anthropic- or bedrock-typed; +// the outer handler translates it into a 400. +var errAIProviderBedrockTypeMismatch = xerrors.New("bedrock settings are only valid for type=anthropic or type=bedrock") // errAIProviderInvalidName is returned from lookupAIProvider when the // idOrName parameter is neither a UUID nor a syntactically-valid name. diff --git a/coderd/ai_providers_migrate.go b/coderd/ai_providers_migrate.go index 98cfba2226..055877ecce 100644 --- a/coderd/ai_providers_migrate.go +++ b/coderd/ai_providers_migrate.go @@ -116,10 +116,21 @@ func SeedAIProvidersFromEnv( if err != nil { return xerrors.Errorf("decode existing settings for %q: %w", dp.Name, err) } + // Load existing bearer keys so the canonical hash + // includes credentials for comparison. + existingKeyRows, err := tx.GetAIProviderKeysByProviderID(sysCtx, existing.ID) + if err != nil { + return xerrors.Errorf("load existing keys for %q: %w", dp.Name, err) + } + existingKeys := make([]string, 0, len(existingKeyRows)) + for _, k := range existingKeyRows { + existingKeys = append(existingKeys, k.APIKey) + } existingDP := desiredAIProvider{ Type: existing.Type, BaseURL: existing.BaseUrl, Bedrock: existingSettings.Bedrock, + Keys: existingKeys, } existingHash := computeProviderHash(existingDP.canonical()) if existingHash == dp.Hash { @@ -196,18 +207,15 @@ func SeedAIProvidersFromEnv( // canonicalAIProvider is the shape we hash to detect drift between the // configured environment and the row stored in the database. The fields // we hash are exactly the operator-controllable inputs that affect -// runtime behavior. Credentials are intentionally NOT part of the hash -// so operators can rotate them via the API without forcing a server -// restart. This applies to both bearer API keys (stored in -// ai_provider_keys) and to Bedrock access key/secret pairs (stored in -// the settings blob because Bedrock authenticates via settings rather -// than a bearer token). +// runtime behavior, including credentials. +// // Model and SmallFastModel are excluded: they're tunables, and their // serpent defaults shift across releases. type canonicalAIProvider struct { Type string `json:"type"` BaseURL string `json:"base_url"` BedrockRegion string `json:"bedrock_region"` + KeysHash string `json:"keys_hash"` } // desiredAIProvider is a normalized provider description sourced from @@ -235,9 +243,39 @@ func (d desiredAIProvider) canonical() canonicalAIProvider { if d.Bedrock != nil { c.BedrockRegion = d.Bedrock.Region } + c.KeysHash = computeKeysHash(d.Keys, d.Bedrock) return c } +// computeKeysHash produces a deterministic hash over the bearer API +// keys and, for Bedrock providers, the access key and secret. +func computeKeysHash(bearerKeys []string, bedrock *codersdk.AIProviderBedrockSettings) string { + // Collect all credential material in a deterministic order. + // Bearer keys are sorted so reordering in env vars does not + // trigger a false-positive drift. + sorted := make([]string, len(bearerKeys)) + copy(sorted, bearerKeys) + slices.Sort(sorted) + + h := sha256.New() + for _, k := range sorted { + _, _ = h.Write([]byte(k)) + // Separator so "ab"+"c" != "a"+"bc". + _, _ = h.Write([]byte{0}) + } + if bedrock != nil { + if bedrock.AccessKey != nil { + _, _ = h.Write([]byte(*bedrock.AccessKey)) + } + _, _ = h.Write([]byte{0}) + if bedrock.AccessKeySecret != nil { + _, _ = h.Write([]byte(*bedrock.AccessKeySecret)) + } + _, _ = h.Write([]byte{0}) + } + return hex.EncodeToString(h.Sum(nil)) +} + func computeProviderHash(c canonicalAIProvider) string { // json.Marshal is deterministic for structs because field order is // fixed by the struct definition. @@ -327,28 +365,23 @@ func providersFromEnv(ctx context.Context, cfg codersdk.AIBridgeConfig, logger s dp := desiredAIProvider{ Name: name, } - switch p.Type { - case aibridge.ProviderOpenAI: - dp.Type = database.AiProviderTypeOpenai - case aibridge.ProviderAnthropic: - dp.Type = database.AiProviderTypeAnthropic - case aibridge.ProviderCopilot: - dp.Type = database.AiProviderTypeCopilot - default: + providerType := database.AIProviderType(p.Type) + if !providerType.Valid() { logger.Warn(ctx, "skipping indexed AI provider with unsupported type", slog.F("name", name), slog.F("type", p.Type), ) continue } + dp.Type = providerType dp.BaseURL = p.BaseURL - // Bedrock fields only apply to Anthropic. Detection goes - // through AIProviderBedrockSettings.IsConfigured() so the - // legacy and indexed paths agree on what counts as a Bedrock - // provider. + // Bedrock fields apply to Anthropic and the dedicated Bedrock + // type. Detection goes through + // AIProviderBedrockSettings.IsConfigured() so the legacy and + // indexed paths agree on what counts as a Bedrock provider. isBedrock := false - if dp.Type == database.AiProviderTypeAnthropic { + if dp.Type == database.AiProviderTypeAnthropic || dp.Type == database.AiProviderTypeBedrock { var accessKey, accessKeySecret string if len(p.BedrockAccessKeys) > 0 { accessKey = p.BedrockAccessKeys[0] diff --git a/coderd/ai_providers_migrate_test.go b/coderd/ai_providers_migrate_test.go index 87f5dd0764..89165002b0 100644 --- a/coderd/ai_providers_migrate_test.go +++ b/coderd/ai_providers_migrate_test.go @@ -91,21 +91,23 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { } require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) - // Changing the API key alone does NOT count as drift: keys - // live in a separate table and operators rotate them via the - // API. Only changes to non-credential provider-level fields - // (base_url, type, Bedrock region/model) trip the drift check. + // Changing the API key counts as drift: keys are included + // in the canonical hash so operators notice when env-var + // credential changes are ignored by an existing provider. cfg.LegacyOpenAI.Key = serpent.String("sk-rotated") - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) - - // Changing the base URL is real drift. - cfg.LegacyOpenAI.BaseURL = serpent.String("https://api.openai.com/v2") err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t)) require.Error(t, err) require.Contains(t, err.Error(), "differs from the current environment configuration") + + // Changing the base URL is also real drift. + cfg.LegacyOpenAI.Key = serpent.String("sk-original") + cfg.LegacyOpenAI.BaseURL = serpent.String("https://api.openai.com/v2") + err = coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t)) + require.Error(t, err) + require.Contains(t, err.Error(), "differs from the current environment configuration") }) - t.Run("BedrockCredentialRotationIsNotDrift", func(t *testing.T) { + t.Run("BedrockCredentialChangeIsDrift", func(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) @@ -120,17 +122,20 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { } require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) - // Rotating the Bedrock access key and secret in env must NOT - // trip the drift check: they're credentials, equivalent to - // bearer API keys, and operators rotate them via the API. + // Rotating the Bedrock access key in env trips the drift + // check so operators know the change did not take effect. cfg.LegacyBedrock.AccessKey = serpent.String("AKIA-rotated") cfg.LegacyBedrock.AccessKeySecret = serpent.String("secret-rotated") - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) + err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t)) + require.Error(t, err) + require.Contains(t, err.Error(), "differs from the current environment configuration") // Changing the Bedrock region (a non-credential field) is - // real drift. + // also real drift. + cfg.LegacyBedrock.AccessKey = serpent.String("AKIA-original") + cfg.LegacyBedrock.AccessKeySecret = serpent.String("secret-original") cfg.LegacyBedrock.Region = serpent.String("us-west-2") - err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t)) + err = coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t)) require.Error(t, err) require.Contains(t, err.Error(), "differs from the current environment configuration") }) @@ -293,6 +298,57 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { require.Equal(t, "sk-ant-1", anKeys[0].APIKey) }) + t.Run("IndexedProvidersKeyDriftWithMultipleKeysAndProviders", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + + cfg := codersdk.AIBridgeConfig{ + Providers: []codersdk.AIProviderConfig{ + { + Type: "openai", + Name: "primary-openai", + BaseURL: "https://api.openai.com/v1", + Keys: []string{"sk-openai-1", "sk-openai-2"}, + }, + { + Type: "anthropic", + Name: "primary-anthropic", + BaseURL: "https://api.anthropic.com/", + Keys: []string{"sk-ant-1", "sk-ant-2"}, + }, + }, + } + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) + + // Reordering keys must not count as drift. The canonical hash + // sorts keys before hashing, so equivalent key sets remain + // stable across restarts. + cfg.Providers[0].Keys = []string{"sk-openai-2", "sk-openai-1"} + cfg.Providers[1].Keys = []string{"sk-ant-2", "sk-ant-1"} + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) + + // Changing one key on one provider must block startup even + // when multiple providers are configured. + cfg.Providers[1].Keys = []string{"sk-ant-2", "sk-ant-rotated"} + err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t)) + require.Error(t, err) + require.Contains(t, err.Error(), "differs from the current environment configuration") + require.Contains(t, err.Error(), `"primary-anthropic"`) + + oa, err := db.GetAIProviderByName(ctx, "primary-openai") + require.NoError(t, err) + oaKeys, err := db.GetAIProviderKeysByProviderID(ctx, oa.ID) + require.NoError(t, err) + require.ElementsMatch(t, []string{"sk-openai-1", "sk-openai-2"}, []string{oaKeys[0].APIKey, oaKeys[1].APIKey}) + + an, err := db.GetAIProviderByName(ctx, "primary-anthropic") + require.NoError(t, err) + anKeys, err := db.GetAIProviderKeysByProviderID(ctx, an.ID) + require.NoError(t, err) + require.ElementsMatch(t, []string{"sk-ant-1", "sk-ant-2"}, []string{anKeys[0].APIKey, anKeys[1].APIKey}) + }) + t.Run("BedrockIndexedProviderHasNoKeys", func(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) @@ -371,14 +427,15 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - // vercel is a valid ai_provider_type DB value but the aibridge - // runtime has no constructor for it, so the seed switch falls - // into the default branch and skips the row. + // A TYPE that isn't part of the ai_provider_type enum falls + // into the default branch and the row is skipped rather than + // rejected, so deployments don't fail to start over a single + // typo'd provider. cfg := codersdk.AIBridgeConfig{ Providers: []codersdk.AIProviderConfig{ { - Type: "vercel", - Name: "vercel-instance", + Type: "not-a-real-provider", + Name: "ghost", BaseURL: "https://example.com", }, { @@ -423,7 +480,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { require.Empty(t, all, "expected no active rows after soft-delete + re-seed") }) - t.Run("ExistingKeysArePreserved", func(t *testing.T) { + t.Run("ExistingKeysBlockOnDrift", func(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) @@ -439,15 +496,17 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { row, err := db.GetAIProviderByName(ctx, "openai") require.NoError(t, err) - // Operator rotates the env key. The seed must not duplicate - // keys on a row that already exists; the new key is only - // installed via the API/CRUD layer in this flow. + // Operator rotates the env key. The seed now blocks startup + // because the keys differ, alerting the operator. cfg.LegacyOpenAI.Key = serpent.String("sk-rotated") - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) + err = coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t)) + require.Error(t, err) + require.Contains(t, err.Error(), "differs from the current environment configuration") + // The original key is still in the database. keys, err := db.GetAIProviderKeysByProviderID(ctx, row.ID) require.NoError(t, err) - require.Len(t, keys, 1, "env reseed must not duplicate keys on existing rows") + require.Len(t, keys, 1) require.Equal(t, "sk-original", keys[0].APIKey) }) @@ -481,6 +540,40 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { require.Len(t, all, 1, "duplicate indexed entries with matching hash must produce a single row") }) + t.Run("IndexedDuplicateNameMatchingHashDedupesReorderedKeys", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + + // Key order should not affect the canonical hash. Reordered + // duplicates under the same name should still dedupe. + cfg := codersdk.AIBridgeConfig{ + Providers: []codersdk.AIProviderConfig{ + { + Type: "openai", + Name: "shared", + BaseURL: "https://api.openai.com/v1", + Keys: []string{"sk-1", "sk-2"}, + }, + { + Type: "openai", + Name: "shared", + BaseURL: "https://api.openai.com/v1", + Keys: []string{"sk-2", "sk-1"}, + }, + }, + } + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) + + all, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{}) + require.NoError(t, err) + require.Len(t, all, 1) + keys, err := db.GetAIProviderKeysByProviderID(ctx, all[0].ID) + require.NoError(t, err) + require.Len(t, keys, 2) + require.ElementsMatch(t, []string{"sk-1", "sk-2"}, []string{keys[0].APIKey, keys[1].APIKey}) + }) + t.Run("IndexedDuplicateNameMismatchingHashFails", func(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) diff --git a/coderd/aibridged/metrics.go b/coderd/aibridged/metrics.go new file mode 100644 index 0000000000..b06a9c067c --- /dev/null +++ b/coderd/aibridged/metrics.go @@ -0,0 +1,94 @@ +package aibridged + +import ( + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// Metrics is the prometheus surface for aibridged provider reloads. +type Metrics struct { + registerer prometheus.Registerer + + // ProviderInfo is one series per configured provider; value is + // always 1 and the status label carries the alertable signal. + // Labels: provider_name, provider_type, status. + ProviderInfo *prometheus.GaugeVec + + // ProvidersLastReloadTimestampSeconds is the unix timestamp of the + // last reload attempt, success or failure. + ProvidersLastReloadTimestampSeconds prometheus.Gauge + + // ProvidersLastReloadSuccessTimestampSeconds is the unix timestamp + // of the last reload that successfully refreshed the pool. A gap + // against ProvidersLastReloadTimestampSeconds means the loop is + // firing but the refresh function is failing. + ProvidersLastReloadSuccessTimestampSeconds prometheus.Gauge +} + +// NewMetrics registers the provider metrics against reg. +func NewMetrics(reg prometheus.Registerer) *Metrics { + factory := promauto.With(reg) + + return &Metrics{ + registerer: reg, + + ProviderInfo: factory.NewGaugeVec(prometheus.GaugeOpts{ + Name: "provider_info", + Help: "One series per configured AI provider. Value is always 1; the status label (enabled, disabled, error) carries the alertable signal.", + }, []string{"provider_name", "provider_type", "status"}), + + ProvidersLastReloadTimestampSeconds: factory.NewGauge(prometheus.GaugeOpts{ + Name: "providers_last_reload_timestamp_seconds", + Help: "Unix timestamp of the last provider reload attempt, success or failure.", + }), + + ProvidersLastReloadSuccessTimestampSeconds: factory.NewGauge(prometheus.GaugeOpts{ + Name: "providers_last_reload_success_timestamp_seconds", + Help: "Unix timestamp of the last provider reload that successfully refreshed the pool. A gap against coder_aibridged_providers_last_reload_timestamp_seconds means the loop is firing but the refresh function is failing.", + }), + } +} + +// Unregister removes the provider metrics from the registerer. +func (m *Metrics) Unregister() { + if m == nil { + return + } + m.registerer.Unregister(m.ProviderInfo) + m.registerer.Unregister(m.ProvidersLastReloadTimestampSeconds) + m.registerer.Unregister(m.ProvidersLastReloadSuccessTimestampSeconds) +} + +// RecordReloadAttempt stamps the attempt-time gauge at the start of a +// reload. A reload that hangs mid-flight is detected by watching the +// gap between this gauge and ProvidersLastReloadSuccessTimestampSeconds. +func (m *Metrics) RecordReloadAttempt() { + if m == nil { + return + } + m.ProvidersLastReloadTimestampSeconds.Set(float64(time.Now().Unix())) +} + +// RecordReloadSuccess rewrites the ProviderInfo GaugeVec from the +// outcomes and stamps the success-time gauge. Reset clears series for +// providers that have left the configuration so they don't linger as +// stale. +func (m *Metrics) RecordReloadSuccess(outcomes []ProviderOutcome) { + if m == nil { + return + } + WriteProviderInfoSnapshot(m.ProviderInfo, outcomes) + m.ProvidersLastReloadSuccessTimestampSeconds.Set(float64(time.Now().Unix())) +} + +// WriteProviderInfoSnapshot Resets info and writes one series per +// outcome. Both aibridged and aibridgeproxyd use this so the +// provider_info recording contract stays in one place. +func WriteProviderInfoSnapshot(info *prometheus.GaugeVec, outcomes []ProviderOutcome) { + info.Reset() + for _, o := range outcomes { + info.WithLabelValues(o.Name, o.Type, string(o.Status)).Set(1) + } +} diff --git a/coderd/aibridged/metrics_test.go b/coderd/aibridged/metrics_test.go new file mode 100644 index 0000000000..008c79dd34 --- /dev/null +++ b/coderd/aibridged/metrics_test.go @@ -0,0 +1,84 @@ +package aibridged_test + +import ( + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + promtest "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/aibridged" +) + +// TestMetricsRecordReloadSuccess covers the provider_info GaugeVec +// surface: every reload pass rewrites the series for the current +// outcomes and the Reset on each pass drops stale series. +func TestMetricsRecordReloadSuccess(t *testing.T) { + t.Parallel() + + reg := prometheus.NewRegistry() + m := aibridged.NewMetrics(reg) + + outcomes := []aibridged.ProviderOutcome{ + {Name: "alpha", Type: "openai", Status: aibridged.ProviderStatusEnabled}, + {Name: "beta", Type: "anthropic", Status: aibridged.ProviderStatusDisabled}, + {Name: "gamma", Type: "openai", Status: aibridged.ProviderStatusError, Err: xerrors.New("bad config")}, + } + + before := time.Now().Unix() + m.RecordReloadAttempt() + m.RecordReloadSuccess(outcomes) + after := time.Now().Unix() + + assert.Equal(t, 1.0, promtest.ToFloat64(m.ProviderInfo.WithLabelValues("alpha", "openai", "enabled"))) + assert.Equal(t, 1.0, promtest.ToFloat64(m.ProviderInfo.WithLabelValues("beta", "anthropic", "disabled"))) + assert.Equal(t, 1.0, promtest.ToFloat64(m.ProviderInfo.WithLabelValues("gamma", "openai", "error"))) + + attemptTS := int64(promtest.ToFloat64(m.ProvidersLastReloadTimestampSeconds)) + successTS := int64(promtest.ToFloat64(m.ProvidersLastReloadSuccessTimestampSeconds)) + assert.GreaterOrEqual(t, attemptTS, before) + assert.LessOrEqual(t, attemptTS, after) + assert.GreaterOrEqual(t, successTS, before) + assert.LessOrEqual(t, successTS, after) +} + +// TestMetricsResetsStaleProviderSeries verifies that providers removed +// from the outcome set between reloads do not leave behind stale +// series. +func TestMetricsResetsStaleProviderSeries(t *testing.T) { + t.Parallel() + + reg := prometheus.NewRegistry() + m := aibridged.NewMetrics(reg) + + m.RecordReloadSuccess([]aibridged.ProviderOutcome{ + {Name: "alpha", Type: "openai", Status: aibridged.ProviderStatusEnabled}, + {Name: "beta", Type: "anthropic", Status: aibridged.ProviderStatusEnabled}, + }) + require.Equal(t, 2, promtest.CollectAndCount(m.ProviderInfo)) + + m.RecordReloadSuccess([]aibridged.ProviderOutcome{ + {Name: "alpha", Type: "openai", Status: aibridged.ProviderStatusEnabled}, + }) + + assert.Equal(t, 1, promtest.CollectAndCount(m.ProviderInfo), + "beta should have been Reset out of the GaugeVec") + assert.Equal(t, 1.0, promtest.ToFloat64(m.ProviderInfo.WithLabelValues("alpha", "openai", "enabled"))) +} + +// TestMetricsNilSafe asserts the helpers tolerate a nil receiver so +// callers can pass `nil` to disable metric updates without guarding +// every call site. +func TestMetricsNilSafe(t *testing.T) { + t.Parallel() + + var m *aibridged.Metrics + require.NotPanics(t, func() { + m.RecordReloadAttempt() + m.RecordReloadSuccess(nil) + m.Unregister() + }) +} diff --git a/coderd/aibridged/pool.go b/coderd/aibridged/pool.go index 3b7e60955c..b86cefe00a 100644 --- a/coderd/aibridged/pool.go +++ b/coderd/aibridged/pool.go @@ -30,7 +30,9 @@ const ( type Pooler interface { Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error) // ReplaceProviders swaps the providers used to construct future - // RequestBridge instances and clears the cache. + // RequestBridge instances and clears the cache. Disabled providers + // must be included; the bridge serves a 503 sentinel on their + // routes. ReplaceProviders(providers []aibridge.Provider) Shutdown(ctx context.Context) error } @@ -53,7 +55,8 @@ var _ Pooler = &CachedBridgePool{} type CachedBridgePool struct { cache *ristretto.Cache[string, *aibridge.RequestBridge] - // providers is the live provider set used by new RequestBridge instances. + // providers is the live provider set used by new RequestBridge + // instances. Includes disabled providers. providers atomic.Pointer[[]aibridge.Provider] providerVersion atomic.Int64 logger slog.Logger diff --git a/coderd/aibridged/proto/aibridged.pb.go b/coderd/aibridged/proto/aibridged.pb.go index c364aeda40..17fef851ea 100644 --- a/coderd/aibridged/proto/aibridged.pb.go +++ b/coderd/aibridged/proto/aibridged.pb.go @@ -216,8 +216,9 @@ type RecordInterceptionEndedRequest struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // UUID. - EndedAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=ended_at,json=endedAt,proto3" json:"ended_at,omitempty"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // UUID. + EndedAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=ended_at,json=endedAt,proto3" json:"ended_at,omitempty"` + CredentialHint string `protobuf:"bytes,3,opt,name=credential_hint,json=credentialHint,proto3" json:"credential_hint,omitempty"` } func (x *RecordInterceptionEndedRequest) Reset() { @@ -266,6 +267,13 @@ func (x *RecordInterceptionEndedRequest) GetEndedAt() *timestamppb.Timestamp { return nil } +func (x *RecordInterceptionEndedRequest) GetCredentialHint() string { + if x != nil { + return x.CredentialHint + } + return "" +} + type RecordInterceptionEndedResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1295,249 +1303,252 @@ var file_coderd_aibridged_proto_aibridged_proto_rawDesc = []byte{ 0x42, 0x14, 0x0a, 0x12, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x22, 0x1c, 0x0a, 0x1a, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x67, 0x0a, 0x1e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, 0x6e, - 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x65, 0x64, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x35, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x65, 0x64, 0x5f, - 0x61, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, - 0x74, 0x61, 0x6d, 0x70, 0x52, 0x07, 0x65, 0x6e, 0x64, 0x65, 0x64, 0x41, 0x74, 0x22, 0x21, 0x0a, - 0x1f, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, - 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x65, 0x64, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0xe9, 0x03, 0x0a, 0x17, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6b, 0x65, 0x6e, - 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x0f, - 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, - 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, 0x15, 0x0a, 0x06, 0x6d, 0x73, 0x67, 0x5f, 0x69, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6d, 0x73, 0x67, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, - 0x69, 0x6e, 0x70, 0x75, 0x74, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x0b, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, - 0x23, 0x0a, 0x0d, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0c, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x54, 0x6f, - 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x48, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, - 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x90, 0x01, 0x0a, 0x1e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, + 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x65, 0x64, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x35, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x65, 0x64, + 0x5f, 0x61, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, + 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x07, 0x65, 0x6e, 0x64, 0x65, 0x64, 0x41, 0x74, 0x12, 0x27, + 0x0a, 0x0f, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x68, 0x69, 0x6e, + 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, + 0x69, 0x61, 0x6c, 0x48, 0x69, 0x6e, 0x74, 0x22, 0x21, 0x0a, 0x1f, 0x52, 0x65, 0x63, 0x6f, 0x72, + 0x64, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, + 0x65, 0x64, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xe9, 0x03, 0x0a, 0x17, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, - 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x39, - 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, - 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x35, 0x0a, 0x17, 0x63, 0x61, 0x63, - 0x68, 0x65, 0x5f, 0x72, 0x65, 0x61, 0x64, 0x5f, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x5f, 0x74, 0x6f, - 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, 0x52, 0x14, 0x63, 0x61, 0x63, 0x68, - 0x65, 0x52, 0x65, 0x61, 0x64, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, - 0x12, 0x37, 0x0a, 0x18, 0x63, 0x61, 0x63, 0x68, 0x65, 0x5f, 0x77, 0x72, 0x69, 0x74, 0x65, 0x5f, - 0x69, 0x6e, 0x70, 0x75, 0x74, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x15, 0x63, 0x61, 0x63, 0x68, 0x65, 0x57, 0x72, 0x69, 0x74, 0x65, 0x49, 0x6e, - 0x70, 0x75, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x1a, 0x51, 0x0a, 0x0d, 0x4d, 0x65, 0x74, - 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, - 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x2a, 0x0a, 0x05, - 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, - 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, - 0x79, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x1a, 0x0a, 0x18, - 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x55, 0x73, 0x61, 0x67, 0x65, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xcb, 0x02, 0x0a, 0x18, 0x52, 0x65, 0x63, - 0x6f, 0x72, 0x64, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x0f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, - 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, - 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, 0x15, - 0x0a, 0x06, 0x6d, 0x73, 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x6d, 0x73, 0x67, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x12, 0x49, 0x0a, - 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x2d, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x50, 0x72, - 0x6f, 0x6d, 0x70, 0x74, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x0f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, + 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, + 0x15, 0x0a, 0x06, 0x6d, 0x73, 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x6d, 0x73, 0x67, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x5f, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x69, 0x6e, + 0x70, 0x75, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x6f, 0x75, 0x74, + 0x70, 0x75, 0x74, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x0c, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x48, + 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x2c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x39, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x61, - 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, + 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x64, 0x41, 0x74, 0x1a, 0x51, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, + 0x64, 0x41, 0x74, 0x12, 0x35, 0x0a, 0x17, 0x63, 0x61, 0x63, 0x68, 0x65, 0x5f, 0x72, 0x65, 0x61, + 0x64, 0x5f, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x14, 0x63, 0x61, 0x63, 0x68, 0x65, 0x52, 0x65, 0x61, 0x64, 0x49, + 0x6e, 0x70, 0x75, 0x74, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x37, 0x0a, 0x18, 0x63, 0x61, + 0x63, 0x68, 0x65, 0x5f, 0x77, 0x72, 0x69, 0x74, 0x65, 0x5f, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x5f, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x01, 0x28, 0x03, 0x52, 0x15, 0x63, 0x61, + 0x63, 0x68, 0x65, 0x57, 0x72, 0x69, 0x74, 0x65, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x54, 0x6f, 0x6b, + 0x65, 0x6e, 0x73, 0x1a, 0x51, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x2a, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x05, 0x76, 0x61, 0x6c, - 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x1b, 0x0a, 0x19, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, - 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x8f, 0x04, 0x0a, 0x16, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, - 0x6f, 0x6c, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, - 0x0a, 0x0f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, - 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, - 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, 0x15, 0x0a, 0x06, 0x6d, 0x73, 0x67, 0x5f, 0x69, - 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6d, 0x73, 0x67, 0x49, 0x64, 0x12, 0x22, - 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x48, 0x00, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x55, 0x72, 0x6c, 0x88, - 0x01, 0x01, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x6f, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x74, 0x6f, 0x6f, 0x6c, 0x12, 0x14, 0x0a, 0x05, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x12, 0x1a, 0x0a, 0x08, - 0x69, 0x6e, 0x6a, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, - 0x69, 0x6e, 0x6a, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x2e, 0x0a, 0x10, 0x69, 0x6e, 0x76, 0x6f, - 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x48, 0x01, 0x52, 0x0f, 0x69, 0x6e, 0x76, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x45, 0x72, 0x72, 0x6f, 0x72, 0x88, 0x01, 0x01, 0x12, 0x47, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, - 0x64, 0x61, 0x74, 0x61, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2b, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6f, 0x6c, 0x55, 0x73, 0x61, - 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, - 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, - 0x61, 0x12, 0x39, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, - 0x09, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, - 0x70, 0x52, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x20, 0x0a, 0x0c, - 0x74, 0x6f, 0x6f, 0x6c, 0x5f, 0x63, 0x61, 0x6c, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x0a, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0a, 0x74, 0x6f, 0x6f, 0x6c, 0x43, 0x61, 0x6c, 0x6c, 0x49, 0x64, 0x1a, 0x51, - 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, - 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, - 0x79, 0x12, 0x2a, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, - 0x01, 0x42, 0x0d, 0x0a, 0x0b, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x75, 0x72, 0x6c, - 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x69, 0x6e, 0x76, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, - 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x19, 0x0a, 0x17, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, - 0x6f, 0x6f, 0x6c, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0xb8, 0x02, 0x0a, 0x19, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, - 0x54, 0x68, 0x6f, 0x75, 0x67, 0x68, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, - 0x0a, 0x0f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, - 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, - 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, - 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, - 0x74, 0x12, 0x4a, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x2e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, - 0x72, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x54, 0x68, 0x6f, 0x75, 0x67, 0x68, 0x74, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, - 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x39, 0x0a, - 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63, - 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x1a, 0x51, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, + 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x1a, 0x0a, 0x18, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, + 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0xcb, 0x02, 0x0a, 0x18, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x50, 0x72, 0x6f, + 0x6d, 0x70, 0x74, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x27, 0x0a, 0x0f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x5f, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, + 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, 0x15, 0x0a, 0x06, 0x6d, 0x73, 0x67, 0x5f, + 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6d, 0x73, 0x67, 0x49, 0x64, 0x12, + 0x16, 0x0a, 0x06, 0x70, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x06, 0x70, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x12, 0x49, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, + 0x61, 0x74, 0x61, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2d, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x55, 0x73, + 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, + 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, + 0x74, 0x61, 0x12, 0x39, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, + 0x6d, 0x70, 0x52, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x1a, 0x51, 0x0a, + 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, + 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, + 0x12, 0x2a, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, + 0x22, 0x1b, 0x0a, 0x19, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, + 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x8f, 0x04, + 0x0a, 0x16, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6f, 0x6c, 0x55, 0x73, 0x61, 0x67, + 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x0f, 0x69, 0x6e, 0x74, 0x65, + 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x49, + 0x64, 0x12, 0x15, 0x0a, 0x06, 0x6d, 0x73, 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x6d, 0x73, 0x67, 0x49, 0x64, 0x12, 0x22, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x09, + 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x55, 0x72, 0x6c, 0x88, 0x01, 0x01, 0x12, 0x12, 0x0a, 0x04, + 0x74, 0x6f, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x6f, 0x6f, 0x6c, + 0x12, 0x14, 0x0a, 0x05, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x69, 0x6e, 0x6a, 0x65, 0x63, 0x74, + 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x69, 0x6e, 0x6a, 0x65, 0x63, 0x74, + 0x65, 0x64, 0x12, 0x2e, 0x0a, 0x10, 0x69, 0x6e, 0x76, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x48, 0x01, 0x52, 0x0f, + 0x69, 0x6e, 0x76, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x88, + 0x01, 0x01, 0x12, 0x47, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x08, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, + 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6f, 0x6c, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, + 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x39, 0x0a, 0x0a, 0x63, + 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63, 0x72, 0x65, + 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x20, 0x0a, 0x0c, 0x74, 0x6f, 0x6f, 0x6c, 0x5f, 0x63, + 0x61, 0x6c, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x74, 0x6f, + 0x6f, 0x6c, 0x43, 0x61, 0x6c, 0x6c, 0x49, 0x64, 0x1a, 0x51, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x2a, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, - 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x1c, 0x0a, 0x1a, 0x52, + 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x0d, 0x0a, 0x0b, 0x5f, + 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x75, 0x72, 0x6c, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x69, + 0x6e, 0x76, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, + 0x19, 0x0a, 0x17, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6f, 0x6c, 0x55, 0x73, 0x61, + 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xb8, 0x02, 0x0a, 0x19, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x54, 0x68, 0x6f, 0x75, 0x67, 0x68, - 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x35, 0x0a, 0x1a, 0x47, 0x65, 0x74, - 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x73, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, - 0x22, 0xb2, 0x01, 0x0a, 0x1b, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x40, 0x0a, 0x10, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x5f, 0x6d, 0x63, 0x70, 0x5f, 0x63, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x2e, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x0e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x4d, 0x63, 0x70, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x51, 0x0a, 0x19, 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x5f, 0x61, - 0x75, 0x74, 0x68, 0x5f, 0x6d, 0x63, 0x70, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x4d, 0x43, - 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x16, 0x65, - 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x63, 0x70, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x73, 0x22, 0x85, 0x01, 0x0a, 0x0f, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x28, 0x0a, 0x10, 0x74, - 0x6f, 0x6f, 0x6c, 0x5f, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x5f, 0x72, 0x65, 0x67, 0x65, 0x78, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6f, 0x6c, 0x41, 0x6c, 0x6c, 0x6f, 0x77, - 0x52, 0x65, 0x67, 0x65, 0x78, 0x12, 0x26, 0x0a, 0x0f, 0x74, 0x6f, 0x6f, 0x6c, 0x5f, 0x64, 0x65, - 0x6e, 0x79, 0x5f, 0x72, 0x65, 0x67, 0x65, 0x78, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, - 0x74, 0x6f, 0x6f, 0x6c, 0x44, 0x65, 0x6e, 0x79, 0x52, 0x65, 0x67, 0x65, 0x78, 0x22, 0x72, 0x0a, - 0x24, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x63, 0x63, - 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x42, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x31, - 0x0a, 0x15, 0x6d, 0x63, 0x70, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x63, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x12, 0x6d, - 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x49, 0x64, - 0x73, 0x22, 0xda, 0x02, 0x0a, 0x25, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x42, 0x61, - 0x74, 0x63, 0x68, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x63, 0x0a, 0x0d, 0x61, - 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x3e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x43, - 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, - 0x65, 0x6e, 0x73, 0x42, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x45, 0x6e, 0x74, - 0x72, 0x79, 0x52, 0x0c, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, - 0x12, 0x50, 0x0a, 0x06, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x38, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, - 0x73, 0x42, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x45, - 0x72, 0x72, 0x6f, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x06, 0x65, 0x72, 0x72, 0x6f, - 0x72, 0x73, 0x1a, 0x3f, 0x0a, 0x11, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, - 0x6e, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, - 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, - 0x02, 0x38, 0x01, 0x1a, 0x39, 0x0a, 0x0b, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x45, 0x6e, 0x74, - 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x3e, - 0x0a, 0x13, 0x49, 0x73, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x15, 0x0a, 0x06, 0x6b, 0x65, 0x79, 0x5f, 0x69, - 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6b, 0x65, 0x79, 0x49, 0x64, 0x22, 0x6b, - 0x0a, 0x14, 0x49, 0x73, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x19, 0x0a, 0x08, 0x6f, 0x77, 0x6e, 0x65, 0x72, 0x5f, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6f, 0x77, 0x6e, 0x65, 0x72, 0x49, - 0x64, 0x12, 0x1c, 0x0a, 0x0a, 0x61, 0x70, 0x69, 0x5f, 0x6b, 0x65, 0x79, 0x5f, 0x69, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x70, 0x69, 0x4b, 0x65, 0x79, 0x49, 0x64, 0x12, - 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x32, 0xa9, 0x04, 0x0a, 0x08, - 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x12, 0x59, 0x0a, 0x12, 0x52, 0x65, 0x63, 0x6f, - 0x72, 0x64, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, 0x6e, 0x74, - 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x21, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, - 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x68, 0x0a, 0x17, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, 0x6e, 0x74, - 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x65, 0x64, 0x12, 0x25, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, 0x6e, 0x74, - 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x65, 0x64, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x26, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, - 0x45, 0x6e, 0x64, 0x65, 0x64, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x53, 0x0a, - 0x10, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x55, 0x73, 0x61, 0x67, - 0x65, 0x12, 0x1e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, - 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x1f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, - 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x56, 0x0a, 0x11, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x50, 0x72, 0x6f, 0x6d, - 0x70, 0x74, 0x55, 0x73, 0x61, 0x67, 0x65, 0x12, 0x1f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, - 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x55, 0x73, 0x61, 0x67, - 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x55, 0x73, 0x61, - 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x50, 0x0a, 0x0f, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6f, 0x6c, 0x55, 0x73, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6f, 0x6c, - 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6f, 0x6c, 0x55, - 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x59, 0x0a, 0x12, - 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x54, 0x68, 0x6f, 0x75, 0x67, - 0x68, 0x74, 0x12, 0x20, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, - 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x54, 0x68, 0x6f, 0x75, 0x67, 0x68, 0x74, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, - 0x6f, 0x72, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x54, 0x68, 0x6f, 0x75, 0x67, 0x68, 0x74, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0xeb, 0x01, 0x0a, 0x0f, 0x4d, 0x43, 0x50, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x75, 0x72, 0x61, 0x74, 0x6f, 0x72, 0x12, 0x5c, 0x0a, 0x13, 0x47, + 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x0f, 0x69, 0x6e, 0x74, 0x65, + 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x49, + 0x64, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x12, 0x4a, 0x0a, 0x08, 0x6d, + 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2e, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x4d, 0x6f, 0x64, 0x65, + 0x6c, 0x54, 0x68, 0x6f, 0x75, 0x67, 0x68, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, + 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, + 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x39, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, + 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, + 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, + 0x41, 0x74, 0x1a, 0x51, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, + 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x2a, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x1c, 0x0a, 0x1a, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x4d, + 0x6f, 0x64, 0x65, 0x6c, 0x54, 0x68, 0x6f, 0x75, 0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x35, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x22, 0xb2, 0x01, 0x0a, 0x1b, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x73, 0x12, 0x21, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x43, - 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, - 0x74, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x7a, 0x0a, 0x1d, 0x47, 0x65, 0x74, - 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, - 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x42, 0x61, 0x74, 0x63, 0x68, 0x12, 0x2b, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, - 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x42, 0x61, 0x74, 0x63, 0x68, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, + 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x40, 0x0a, 0x10, 0x63, 0x6f, + 0x64, 0x65, 0x72, 0x5f, 0x6d, 0x63, 0x70, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x4d, 0x43, 0x50, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x63, 0x6f, + 0x64, 0x65, 0x72, 0x4d, 0x63, 0x70, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x51, 0x0a, 0x19, + 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6d, 0x63, + 0x70, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x16, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x16, 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, + 0x6c, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x63, 0x70, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x73, 0x22, + 0x85, 0x01, 0x0a, 0x0f, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x02, 0x69, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x28, 0x0a, 0x10, 0x74, 0x6f, 0x6f, 0x6c, 0x5f, 0x61, 0x6c, + 0x6c, 0x6f, 0x77, 0x5f, 0x72, 0x65, 0x67, 0x65, 0x78, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0e, 0x74, 0x6f, 0x6f, 0x6c, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x67, 0x65, 0x78, 0x12, + 0x26, 0x0a, 0x0f, 0x74, 0x6f, 0x6f, 0x6c, 0x5f, 0x64, 0x65, 0x6e, 0x79, 0x5f, 0x72, 0x65, 0x67, + 0x65, 0x78, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x74, 0x6f, 0x6f, 0x6c, 0x44, 0x65, + 0x6e, 0x79, 0x52, 0x65, 0x67, 0x65, 0x78, 0x22, 0x72, 0x0a, 0x24, 0x47, 0x65, 0x74, 0x4d, 0x43, + 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, + 0x65, 0x6e, 0x73, 0x42, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x31, 0x0a, 0x15, 0x6d, 0x63, 0x70, 0x5f, + 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x5f, 0x69, 0x64, + 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x12, 0x6d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x49, 0x64, 0x73, 0x22, 0xda, 0x02, 0x0a, 0x25, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x42, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x55, 0x0a, 0x0a, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x65, 0x72, 0x12, 0x47, 0x0a, 0x0c, 0x49, 0x73, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x65, 0x64, 0x12, 0x1a, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x49, 0x73, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x1b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x49, 0x73, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 0x65, 0x64, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x32, 0x5a, 0x30, - 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, - 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, 0x76, 0x32, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x64, - 0x2f, 0x61, 0x69, 0x62, 0x72, 0x69, 0x64, 0x67, 0x65, 0x64, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x63, 0x0a, 0x0d, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x3e, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x42, 0x61, 0x74, + 0x63, 0x68, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0c, 0x61, 0x63, + 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x50, 0x0a, 0x06, 0x65, 0x72, + 0x72, 0x6f, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x38, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, + 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x42, 0x61, 0x74, 0x63, 0x68, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x45, + 0x6e, 0x74, 0x72, 0x79, 0x52, 0x06, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x1a, 0x3f, 0x0a, 0x11, + 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x45, 0x6e, 0x74, 0x72, + 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x1a, 0x39, 0x0a, + 0x0b, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, + 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, + 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x3e, 0x0a, 0x13, 0x49, 0x73, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, + 0x79, 0x12, 0x15, 0x0a, 0x06, 0x6b, 0x65, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x6b, 0x65, 0x79, 0x49, 0x64, 0x22, 0x6b, 0x0a, 0x14, 0x49, 0x73, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x19, 0x0a, 0x08, 0x6f, 0x77, 0x6e, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x6f, 0x77, 0x6e, 0x65, 0x72, 0x49, 0x64, 0x12, 0x1c, 0x0a, 0x0a, 0x61, + 0x70, 0x69, 0x5f, 0x6b, 0x65, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x08, 0x61, 0x70, 0x69, 0x4b, 0x65, 0x79, 0x49, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, + 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, + 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x32, 0xa9, 0x04, 0x0a, 0x08, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, + 0x65, 0x72, 0x12, 0x59, 0x0a, 0x12, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, 0x6e, 0x74, 0x65, + 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x68, 0x0a, + 0x17, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x65, 0x64, 0x12, 0x25, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x65, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x26, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x49, 0x6e, + 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x65, 0x64, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x53, 0x0a, 0x10, 0x52, 0x65, 0x63, 0x6f, 0x72, + 0x64, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x55, 0x73, 0x61, 0x67, 0x65, 0x12, 0x1e, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x55, + 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x55, + 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x56, 0x0a, 0x11, + 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x55, 0x73, 0x61, 0x67, + 0x65, 0x12, 0x1f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, + 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, + 0x64, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x50, 0x0a, 0x0f, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, + 0x6f, 0x6c, 0x55, 0x73, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, + 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6f, 0x6c, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x54, 0x6f, 0x6f, 0x6c, 0x55, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x59, 0x0a, 0x12, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, + 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x54, 0x68, 0x6f, 0x75, 0x67, 0x68, 0x74, 0x12, 0x20, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, + 0x54, 0x68, 0x6f, 0x75, 0x67, 0x68, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x4d, 0x6f, 0x64, + 0x65, 0x6c, 0x54, 0x68, 0x6f, 0x75, 0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x32, 0xeb, 0x01, 0x0a, 0x0f, 0x4d, 0x43, 0x50, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x75, + 0x72, 0x61, 0x74, 0x6f, 0x72, 0x12, 0x5c, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x73, 0x12, 0x21, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x22, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x7a, 0x0a, 0x1d, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x42, + 0x61, 0x74, 0x63, 0x68, 0x12, 0x2b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, + 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x42, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x2c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x43, 0x50, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x73, 0x42, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, + 0x55, 0x0a, 0x0a, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x72, 0x12, 0x47, 0x0a, + 0x0c, 0x49, 0x73, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x12, 0x1a, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x49, 0x73, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, + 0x65, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x2e, 0x49, 0x73, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, + 0x2f, 0x76, 0x32, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x64, 0x2f, 0x61, 0x69, 0x62, 0x72, 0x69, + 0x64, 0x67, 0x65, 0x64, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( diff --git a/coderd/aibridged/proto/aibridged.proto b/coderd/aibridged/proto/aibridged.proto index cd61411516..08fceee676 100644 --- a/coderd/aibridged/proto/aibridged.proto +++ b/coderd/aibridged/proto/aibridged.proto @@ -58,6 +58,7 @@ message RecordInterceptionResponse {} message RecordInterceptionEndedRequest { string id = 1; // UUID. google.protobuf.Timestamp ended_at = 2; + string credential_hint = 3; } message RecordInterceptionEndedResponse {} diff --git a/coderd/aibridged/provider.go b/coderd/aibridged/provider.go new file mode 100644 index 0000000000..9d2faa030b --- /dev/null +++ b/coderd/aibridged/provider.go @@ -0,0 +1,28 @@ +package aibridged + +// ProviderStatus is the lifecycle state of a configured AI provider. +type ProviderStatus string + +const ( + // ProviderStatusEnabled indicates the provider is configured and + // valid, and is included in the active pool snapshot. + ProviderStatusEnabled ProviderStatus = "enabled" + // ProviderStatusDisabled indicates the provider is configured but + // intentionally turned off by an operator. + ProviderStatusDisabled ProviderStatus = "disabled" + // ProviderStatusError indicates the provider is configured but + // cannot be constructed (missing keys, unsupported type, malformed + // settings). + ProviderStatusError ProviderStatus = "error" +) + +// ProviderOutcome classifies one ai_providers row, including disabled +// rows (which the pool keeps as 503 stubs) and errored rows (which the +// pool excludes). Err is populated only when Status == ProviderStatusError; +// the build error is already logged at the call site. +type ProviderOutcome struct { + Name string + Type string + Status ProviderStatus + Err error +} diff --git a/coderd/aibridged/translator.go b/coderd/aibridged/translator.go index 2769ef0d89..6d251df0fe 100644 --- a/coderd/aibridged/translator.go +++ b/coderd/aibridged/translator.go @@ -45,8 +45,9 @@ func (t *recorderTranslation) RecordInterception(ctx context.Context, req *aibri func (t *recorderTranslation) RecordInterceptionEnded(ctx context.Context, req *aibridge.InterceptionRecordEnded) error { _, err := t.client.RecordInterceptionEnded(ctx, &proto.RecordInterceptionEndedRequest{ - Id: req.ID, - EndedAt: timestamppb.New(req.EndedAt), + Id: req.ID, + EndedAt: timestamppb.New(req.EndedAt), + CredentialHint: req.CredentialHint, }) return err } diff --git a/coderd/aibridgedserver/aibridgedserver.go b/coderd/aibridgedserver/aibridgedserver.go index c593b18f79..8dbaa10bfa 100644 --- a/coderd/aibridgedserver/aibridgedserver.go +++ b/coderd/aibridgedserver/aibridgedserver.go @@ -222,8 +222,9 @@ func (s *Server) RecordInterceptionEnded(ctx context.Context, in *proto.RecordIn } _, err = s.store.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{ - ID: intcID, - EndedAt: in.EndedAt.AsTime(), + ID: intcID, + EndedAt: in.EndedAt.AsTime(), + CredentialHint: in.CredentialHint, }) if err != nil { return nil, xerrors.Errorf("end interception: %w", err) diff --git a/coderd/aibridgedserver/aibridgedserver_test.go b/coderd/aibridgedserver/aibridgedserver_test.go index eb2f413e1e..9aeb082069 100644 --- a/coderd/aibridgedserver/aibridgedserver_test.go +++ b/coderd/aibridgedserver/aibridgedserver_test.go @@ -944,23 +944,26 @@ func TestRecordInterceptionEnded(t *testing.T) { { name: "ok", request: &proto.RecordInterceptionEndedRequest{ - Id: uuid.UUID{1}.String(), - EndedAt: timestamppb.Now(), + Id: uuid.UUID{1}.String(), + EndedAt: timestamppb.Now(), + CredentialHint: "sk-a...efgh", }, setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionEndedRequest) { interceptionID, err := uuid.Parse(req.GetId()) assert.NoError(t, err, "parse interception UUID") db.EXPECT().UpdateAIBridgeInterceptionEnded(gomock.Any(), database.UpdateAIBridgeInterceptionEndedParams{ - ID: interceptionID, - EndedAt: req.EndedAt.AsTime(), + ID: interceptionID, + EndedAt: req.EndedAt.AsTime(), + CredentialHint: req.CredentialHint, }).Return(database.AIBridgeInterception{ - ID: interceptionID, - InitiatorID: uuid.UUID{2}, - Provider: "prov", - Model: "mod", - StartedAt: time.Now(), - EndedAt: sql.NullTime{Time: req.EndedAt.AsTime(), Valid: true}, + ID: interceptionID, + InitiatorID: uuid.UUID{2}, + Provider: "prov", + Model: "mod", + StartedAt: time.Now(), + EndedAt: sql.NullTime{Time: req.EndedAt.AsTime(), Valid: true}, + CredentialHint: req.CredentialHint, }, nil) }, }, diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index ae5715bd78..1bbb216b1a 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -9171,6 +9171,110 @@ const docTemplate = `{ ] } }, + "/api/v2/users/{user}/ai/budget": { + "get": { + "produces": [ + "application/json" + ], + "tags": [ + "Enterprise" + ], + "summary": "Get user AI budget override", + "operationId": "get-user-ai-budget-override", + "parameters": [ + { + "type": "string", + "description": "User ID, username, or me", + "name": "user", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.UserAIBudgetOverride" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + }, + "put": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Enterprise" + ], + "summary": "Upsert user AI budget override", + "operationId": "upsert-user-ai-budget-override", + "parameters": [ + { + "type": "string", + "description": "User ID, username, or me", + "name": "user", + "in": "path", + "required": true + }, + { + "description": "Upsert user AI budget override request", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/codersdk.UpsertUserAIBudgetOverrideRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.UserAIBudgetOverride" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + }, + "delete": { + "tags": [ + "Enterprise" + ], + "summary": "Delete user AI budget override", + "operationId": "delete-user-ai-budget-override", + "parameters": [ + { + "type": "string", + "description": "User ID, username, or me", + "name": "user", + "in": "path", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + } + }, "/api/v2/users/{user}/appearance": { "get": { "produces": [ @@ -13961,7 +14065,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/coderd.SCIMUser" + "$ref": "#/definitions/legacyscim.SCIMUser" } } ], @@ -13969,7 +14073,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/coderd.SCIMUser" + "$ref": "#/definitions/legacyscim.SCIMUser" } } }, @@ -14035,7 +14139,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/coderd.SCIMUser" + "$ref": "#/definitions/legacyscim.SCIMUser" } } ], @@ -14077,7 +14181,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/coderd.SCIMUser" + "$ref": "#/definitions/legacyscim.SCIMUser" } } ], @@ -14288,71 +14392,6 @@ const docTemplate = `{ "ReinitializeReasonPrebuildClaimed" ] }, - "coderd.SCIMUser": { - "type": "object", - "properties": { - "active": { - "description": "Active is a ptr to prevent the empty value from being interpreted as false.", - "type": "boolean" - }, - "emails": { - "type": "array", - "items": { - "type": "object", - "properties": { - "display": { - "type": "string" - }, - "primary": { - "type": "boolean" - }, - "type": { - "type": "string" - }, - "value": { - "type": "string", - "format": "email" - } - } - } - }, - "groups": { - "type": "array", - "items": {} - }, - "id": { - "type": "string" - }, - "meta": { - "type": "object", - "properties": { - "resourceType": { - "type": "string" - } - } - }, - "name": { - "type": "object", - "properties": { - "familyName": { - "type": "string" - }, - "givenName": { - "type": "string" - } - } - }, - "schemas": { - "type": "array", - "items": { - "type": "string" - } - }, - "userName": { - "type": "string" - } - } - }, "coderd.cspViolation": { "type": "object", "properties": { @@ -15071,7 +15110,7 @@ const docTemplate = `{ "type": "string" }, "type": { - "description": "Type is the provider type: \"openai\", \"anthropic\", or \"copilot\".", + "description": "Type is the provider type. Valid values are: \"openai\",\n\"anthropic\", \"azure\", \"bedrock\", \"google\", \"openai-compat\",\n\"openrouter\", \"vercel\", \"copilot\".", "type": "string" } } @@ -15264,6 +15303,10 @@ const docTemplate = `{ "audit_log:*", "audit_log:create", "audit_log:read", + "boundary_log:*", + "boundary_log:create", + "boundary_log:delete", + "boundary_log:read", "boundary_usage:*", "boundary_usage:delete", "boundary_usage:read", @@ -15490,6 +15533,10 @@ const docTemplate = `{ "APIKeyScopeAuditLogAll", "APIKeyScopeAuditLogCreate", "APIKeyScopeAuditLogRead", + "APIKeyScopeBoundaryLogAll", + "APIKeyScopeBoundaryLogCreate", + "APIKeyScopeBoundaryLogDelete", + "APIKeyScopeBoundaryLogRead", "APIKeyScopeBoundaryUsageAll", "APIKeyScopeBoundaryUsageDelete", "APIKeyScopeBoundaryUsageRead", @@ -16554,7 +16601,9 @@ const docTemplate = `{ "startup_timeout", "auth", "config", - "usage_limit" + "usage_limit", + "missing_key", + "provider_disabled" ], "x-enum-varnames": [ "ChatErrorKindGeneric", @@ -16564,7 +16613,9 @@ const docTemplate = `{ "ChatErrorKindStartupTimeout", "ChatErrorKindAuth", "ChatErrorKindConfig", - "ChatErrorKindUsageLimit" + "ChatErrorKindUsageLimit", + "ChatErrorKindMissingKey", + "ChatErrorKindProviderDisabled" ] }, "codersdk.ChatFileMetadata": { @@ -18789,6 +18840,9 @@ const docTemplate = `{ "scim_api_key": { "type": "string" }, + "scim_use_legacy": { + "type": "boolean" + }, "session_lifetime": { "$ref": "#/definitions/codersdk.SessionLifetime" }, @@ -22283,6 +22337,7 @@ const docTemplate = `{ "assign_org_role", "assign_role", "audit_log", + "boundary_log", "boundary_usage", "chat", "connection_log", @@ -22333,6 +22388,7 @@ const docTemplate = `{ "ResourceAssignOrgRole", "ResourceAssignRole", "ResourceAuditLog", + "ResourceBoundaryLog", "ResourceBoundaryUsage", "ResourceChat", "ResourceConnectionLog", @@ -24711,6 +24767,23 @@ const docTemplate = `{ } } }, + "codersdk.UpsertUserAIBudgetOverrideRequest": { + "type": "object", + "required": [ + "group_id" + ], + "properties": { + "group_id": { + "description": "GroupID is the group the user's spend is attributed to. The user must\nbe a member of this group.", + "type": "string", + "format": "uuid" + }, + "spend_limit_micros": { + "type": "integer", + "minimum": 0 + } + } + }, "codersdk.UpsertWorkspaceAgentPortShareRequest": { "type": "object", "properties": { @@ -24865,6 +24938,30 @@ const docTemplate = `{ } } }, + "codersdk.UserAIBudgetOverride": { + "type": "object", + "properties": { + "created_at": { + "type": "string", + "format": "date-time" + }, + "group_id": { + "type": "string", + "format": "uuid" + }, + "spend_limit_micros": { + "type": "integer" + }, + "updated_at": { + "type": "string", + "format": "date-time" + }, + "user_id": { + "type": "string", + "format": "uuid" + } + } + }, "codersdk.UserActivity": { "type": "object", "properties": { @@ -27406,6 +27503,71 @@ const docTemplate = `{ "key.NodePublic": { "type": "object" }, + "legacyscim.SCIMUser": { + "type": "object", + "properties": { + "active": { + "description": "Active is a ptr to prevent the empty value from being interpreted as false.", + "type": "boolean" + }, + "emails": { + "type": "array", + "items": { + "type": "object", + "properties": { + "display": { + "type": "string" + }, + "primary": { + "type": "boolean" + }, + "type": { + "type": "string" + }, + "value": { + "type": "string", + "format": "email" + } + } + } + }, + "groups": { + "type": "array", + "items": {} + }, + "id": { + "type": "string" + }, + "meta": { + "type": "object", + "properties": { + "resourceType": { + "type": "string" + } + } + }, + "name": { + "type": "object", + "properties": { + "familyName": { + "type": "string" + }, + "givenName": { + "type": "string" + } + } + }, + "schemas": { + "type": "array", + "items": { + "type": "string" + } + }, + "userName": { + "type": "string" + } + } + }, "netcheck.Report": { "type": "object", "properties": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 21ee879158..6f7224e972 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -8132,6 +8132,98 @@ ] } }, + "/api/v2/users/{user}/ai/budget": { + "get": { + "produces": ["application/json"], + "tags": ["Enterprise"], + "summary": "Get user AI budget override", + "operationId": "get-user-ai-budget-override", + "parameters": [ + { + "type": "string", + "description": "User ID, username, or me", + "name": "user", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.UserAIBudgetOverride" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + }, + "put": { + "consumes": ["application/json"], + "produces": ["application/json"], + "tags": ["Enterprise"], + "summary": "Upsert user AI budget override", + "operationId": "upsert-user-ai-budget-override", + "parameters": [ + { + "type": "string", + "description": "User ID, username, or me", + "name": "user", + "in": "path", + "required": true + }, + { + "description": "Upsert user AI budget override request", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/codersdk.UpsertUserAIBudgetOverrideRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.UserAIBudgetOverride" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + }, + "delete": { + "tags": ["Enterprise"], + "summary": "Delete user AI budget override", + "operationId": "delete-user-ai-budget-override", + "parameters": [ + { + "type": "string", + "description": "User ID, username, or me", + "name": "user", + "in": "path", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + } + }, "/api/v2/users/{user}/appearance": { "get": { "produces": ["application/json"], @@ -12389,7 +12481,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/coderd.SCIMUser" + "$ref": "#/definitions/legacyscim.SCIMUser" } } ], @@ -12397,7 +12489,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/coderd.SCIMUser" + "$ref": "#/definitions/legacyscim.SCIMUser" } } }, @@ -12455,7 +12547,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/coderd.SCIMUser" + "$ref": "#/definitions/legacyscim.SCIMUser" } } ], @@ -12493,7 +12585,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/coderd.SCIMUser" + "$ref": "#/definitions/legacyscim.SCIMUser" } } ], @@ -12692,71 +12784,6 @@ "enum": ["prebuild_claimed"], "x-enum-varnames": ["ReinitializeReasonPrebuildClaimed"] }, - "coderd.SCIMUser": { - "type": "object", - "properties": { - "active": { - "description": "Active is a ptr to prevent the empty value from being interpreted as false.", - "type": "boolean" - }, - "emails": { - "type": "array", - "items": { - "type": "object", - "properties": { - "display": { - "type": "string" - }, - "primary": { - "type": "boolean" - }, - "type": { - "type": "string" - }, - "value": { - "type": "string", - "format": "email" - } - } - } - }, - "groups": { - "type": "array", - "items": {} - }, - "id": { - "type": "string" - }, - "meta": { - "type": "object", - "properties": { - "resourceType": { - "type": "string" - } - } - }, - "name": { - "type": "object", - "properties": { - "familyName": { - "type": "string" - }, - "givenName": { - "type": "string" - } - } - }, - "schemas": { - "type": "array", - "items": { - "type": "string" - } - }, - "userName": { - "type": "string" - } - } - }, "coderd.cspViolation": { "type": "object", "properties": { @@ -13475,7 +13502,7 @@ "type": "string" }, "type": { - "description": "Type is the provider type: \"openai\", \"anthropic\", or \"copilot\".", + "description": "Type is the provider type. Valid values are: \"openai\",\n\"anthropic\", \"azure\", \"bedrock\", \"google\", \"openai-compat\",\n\"openrouter\", \"vercel\", \"copilot\".", "type": "string" } } @@ -13660,6 +13687,10 @@ "audit_log:*", "audit_log:create", "audit_log:read", + "boundary_log:*", + "boundary_log:create", + "boundary_log:delete", + "boundary_log:read", "boundary_usage:*", "boundary_usage:delete", "boundary_usage:read", @@ -13886,6 +13917,10 @@ "APIKeyScopeAuditLogAll", "APIKeyScopeAuditLogCreate", "APIKeyScopeAuditLogRead", + "APIKeyScopeBoundaryLogAll", + "APIKeyScopeBoundaryLogCreate", + "APIKeyScopeBoundaryLogDelete", + "APIKeyScopeBoundaryLogRead", "APIKeyScopeBoundaryUsageAll", "APIKeyScopeBoundaryUsageDelete", "APIKeyScopeBoundaryUsageRead", @@ -14904,7 +14939,9 @@ "startup_timeout", "auth", "config", - "usage_limit" + "usage_limit", + "missing_key", + "provider_disabled" ], "x-enum-varnames": [ "ChatErrorKindGeneric", @@ -14914,7 +14951,9 @@ "ChatErrorKindStartupTimeout", "ChatErrorKindAuth", "ChatErrorKindConfig", - "ChatErrorKindUsageLimit" + "ChatErrorKindUsageLimit", + "ChatErrorKindMissingKey", + "ChatErrorKindProviderDisabled" ] }, "codersdk.ChatFileMetadata": { @@ -17060,6 +17099,9 @@ "scim_api_key": { "type": "string" }, + "scim_use_legacy": { + "type": "boolean" + }, "session_lifetime": { "$ref": "#/definitions/codersdk.SessionLifetime" }, @@ -20426,6 +20468,7 @@ "assign_org_role", "assign_role", "audit_log", + "boundary_log", "boundary_usage", "chat", "connection_log", @@ -20476,6 +20519,7 @@ "ResourceAssignOrgRole", "ResourceAssignRole", "ResourceAuditLog", + "ResourceBoundaryLog", "ResourceBoundaryUsage", "ResourceChat", "ResourceConnectionLog", @@ -22744,6 +22788,21 @@ } } }, + "codersdk.UpsertUserAIBudgetOverrideRequest": { + "type": "object", + "required": ["group_id"], + "properties": { + "group_id": { + "description": "GroupID is the group the user's spend is attributed to. The user must\nbe a member of this group.", + "type": "string", + "format": "uuid" + }, + "spend_limit_micros": { + "type": "integer", + "minimum": 0 + } + } + }, "codersdk.UpsertWorkspaceAgentPortShareRequest": { "type": "object", "properties": { @@ -22877,6 +22936,30 @@ } } }, + "codersdk.UserAIBudgetOverride": { + "type": "object", + "properties": { + "created_at": { + "type": "string", + "format": "date-time" + }, + "group_id": { + "type": "string", + "format": "uuid" + }, + "spend_limit_micros": { + "type": "integer" + }, + "updated_at": { + "type": "string", + "format": "date-time" + }, + "user_id": { + "type": "string", + "format": "uuid" + } + } + }, "codersdk.UserActivity": { "type": "object", "properties": { @@ -25271,6 +25354,71 @@ "key.NodePublic": { "type": "object" }, + "legacyscim.SCIMUser": { + "type": "object", + "properties": { + "active": { + "description": "Active is a ptr to prevent the empty value from being interpreted as false.", + "type": "boolean" + }, + "emails": { + "type": "array", + "items": { + "type": "object", + "properties": { + "display": { + "type": "string" + }, + "primary": { + "type": "boolean" + }, + "type": { + "type": "string" + }, + "value": { + "type": "string", + "format": "email" + } + } + } + }, + "groups": { + "type": "array", + "items": {} + }, + "id": { + "type": "string" + }, + "meta": { + "type": "object", + "properties": { + "resourceType": { + "type": "string" + } + } + }, + "name": { + "type": "object", + "properties": { + "familyName": { + "type": "string" + }, + "givenName": { + "type": "string" + } + } + }, + "schemas": { + "type": "array", + "items": { + "type": "string" + } + }, + "userName": { + "type": "string" + } + } + }, "netcheck.Report": { "type": "object", "properties": { diff --git a/coderd/audit.go b/coderd/audit.go index 661019d063..a58168567f 100644 --- a/coderd/audit.go +++ b/coderd/audit.go @@ -303,6 +303,12 @@ func auditLogDescription(alog database.GetAuditLogsOffsetRow) string { _, _ = b.WriteString("{user} ") } + // Chat write operations get semantic descriptions derived from the diff. + if desc, ok := chatAuditLogDescription(alog); ok { + _, _ = b.WriteString(desc) + return b.String() + } + switch { case alog.AuditLog.StatusCode == int32(http.StatusSeeOther): _, _ = b.WriteString("was redirected attempting to ") @@ -345,6 +351,56 @@ func auditLogDescription(alog database.GetAuditLogsOffsetRow) string { return b.String() } +// chatAuditLogDescription returns a description for successful chat write +// operations based on the diff contents. It returns false for non-chat +// resources, non-write actions, or error/redirect status codes, letting +// the caller fall through to the generic description. +func chatAuditLogDescription(alog database.GetAuditLogsOffsetRow) (string, bool) { + if alog.AuditLog.ResourceType != database.ResourceTypeChat || + alog.AuditLog.Action != database.AuditActionWrite || + alog.AuditLog.StatusCode >= 400 || + alog.AuditLog.StatusCode == int32(http.StatusSeeOther) { + return "", false + } + + var diff codersdk.AuditDiff + if err := json.Unmarshal(alog.AuditLog.Diff, &diff); err != nil { + return "", false + } + + // Single "archived" field: archive or unarchive. + if len(diff) == 1 { + if field, ok := diff["archived"]; ok { + oldVal, oldOK := field.Old.(bool) + newVal, newOK := field.New.(bool) + if oldOK && newOK { + if !oldVal && newVal { + return "archived chat {target}", true + } + if oldVal && !newVal { + return "unarchived chat {target}", true + } + } + } + } + + // All fields are ACL changes: sharing update. + if len(diff) > 0 { + aclOnly := true + for field := range diff { + if field != "user_acl" && field != "group_acl" { + aclOnly = false + break + } + } + if aclOnly { + return "updated sharing for chat {target}", true + } + } + + return "", false +} + func (api *API) auditLogIsResourceDeleted(ctx context.Context, alog database.GetAuditLogsOffsetRow) bool { switch alog.AuditLog.ResourceType { case database.ResourceTypeTemplate: diff --git a/coderd/audit_internal_test.go b/coderd/audit_internal_test.go index cc7fddf3e0..640690cff9 100644 --- a/coderd/audit_internal_test.go +++ b/coderd/audit_internal_test.go @@ -3,6 +3,7 @@ package coderd import ( "context" "database/sql" + "encoding/json" "testing" "github.com/google/uuid" @@ -14,6 +15,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/codersdk" ) func TestAuditLogIsResourceDeleted(t *testing.T) { @@ -111,6 +113,91 @@ func TestAuditLogDescription(t *testing.T) { }, want: "{user} deleted the git ssh key", }, + { + name: "chat_archived", + alog: chatAuditLogRow(t, codersdk.AuditDiff{ + "archived": {Old: false, New: true}, + }), + want: "{user} archived chat {target}", + }, + { + name: "chat_unarchived", + alog: chatAuditLogRow(t, codersdk.AuditDiff{ + "archived": {Old: true, New: false}, + }), + want: "{user} unarchived chat {target}", + }, + { + name: "chat_sharing_user_acl", + alog: chatAuditLogRow(t, codersdk.AuditDiff{ + "user_acl": {Old: map[string]any{}, New: map[string]any{"user-1": map[string]any{"permissions": []string{"read"}}}}, + }), + want: "{user} updated sharing for chat {target}", + }, + { + name: "chat_sharing_group_acl", + alog: chatAuditLogRow(t, codersdk.AuditDiff{ + "group_acl": {Old: map[string]any{}, New: map[string]any{"group-1": map[string]any{"permissions": []string{"read"}}}}, + }), + want: "{user} updated sharing for chat {target}", + }, + { + name: "chat_sharing_both_acls", + alog: chatAuditLogRow(t, codersdk.AuditDiff{ + "user_acl": {Old: map[string]any{}, New: map[string]any{"user-1": map[string]any{"permissions": []string{"read"}}}}, + "group_acl": {Old: map[string]any{}, New: map[string]any{"group-1": map[string]any{"permissions": []string{"read"}}}}, + }), + want: "{user} updated sharing for chat {target}", + }, + { + name: "chat_mixed_diff_falls_through", + alog: chatAuditLogRow(t, codersdk.AuditDiff{ + "archived": {Old: false, New: true}, + "pin_order": {Old: 1, New: 0}, + }), + want: "{user} updated chat {target}", + }, + { + name: "chat_acl_with_extra_field_falls_through", + alog: chatAuditLogRow(t, codersdk.AuditDiff{ + "user_acl": {Old: map[string]any{}, New: map[string]any{}}, + "pin_order": {Old: 1, New: 0}, + }), + want: "{user} updated chat {target}", + }, + { + name: "chat_failed_write_no_override", + alog: func() database.GetAuditLogsOffsetRow { + row := chatAuditLogRow(t, codersdk.AuditDiff{ + "archived": {Old: false, New: true}, + }) + row.AuditLog.StatusCode = 400 + return row + }(), + want: "{user} unsuccessfully attempted to write chat {target}", + }, + { + name: "chat_redirect_no_override", + alog: func() database.GetAuditLogsOffsetRow { + row := chatAuditLogRow(t, codersdk.AuditDiff{ + "archived": {Old: false, New: true}, + }) + row.AuditLog.StatusCode = 303 + return row + }(), + want: "{user} was redirected attempting to write chat {target}", + }, + { + name: "chat_non_write_action_no_override", + alog: func() database.GetAuditLogsOffsetRow { + row := chatAuditLogRow(t, codersdk.AuditDiff{ + "user_acl": {Old: map[string]any{}, New: map[string]any{"user-1": map[string]any{"permissions": []string{"read"}}}}, + }) + row.AuditLog.Action = database.AuditActionCreate + return row + }(), + want: "{user} created chat {target}", + }, } // nolint: paralleltest // no longer need to reinitialize loop vars in go 1.22 for _, tc := range testCases { @@ -121,3 +208,19 @@ func TestAuditLogDescription(t *testing.T) { }) } } + +// chatAuditLogRow builds a GetAuditLogsOffsetRow for a successful chat write +// with the given diff, suitable for testing auditLogDescription. +func chatAuditLogRow(t *testing.T, diff codersdk.AuditDiff) database.GetAuditLogsOffsetRow { + t.Helper() + rawDiff, err := json.Marshal(diff) + require.NoError(t, err) + return database.GetAuditLogsOffsetRow{ + AuditLog: database.AuditLog{ + Action: database.AuditActionWrite, + StatusCode: 200, + ResourceType: database.ResourceTypeChat, + Diff: rawDiff, + }, + } +} diff --git a/coderd/autobuild/lifecycle_executor.go b/coderd/autobuild/lifecycle_executor.go index 84fff375e0..5a141ce8cf 100644 --- a/coderd/autobuild/lifecycle_executor.go +++ b/coderd/autobuild/lifecycle_executor.go @@ -422,6 +422,23 @@ func (e *Executor) runOnce(t time.Time) Stats { Isolation: sql.LevelRepeatableRead, TxIdentifier: "lifecycle", }) + // A concurrent build (e.g. from the API or another lifecycle + // executor) may have already inserted a build with the same + // number. This is a benign race; the other actor's build + // will take effect. Clear the error so downstream checks + // (audit, notification, stats) treat this as a no-op. + if database.IsUniqueViolation(err, database.UniqueWorkspaceBuildsWorkspaceIDBuildNumberKey) { + log.Info(e.ctx, "skipping workspace: concurrent build already inserted", slog.Error(err)) + err = nil + // Reset notification flags set before builder.Build. + // The build was rolled back, so this executor did not + // perform the transition. The concurrent actor handles + // both the build and any notifications. Without these + // resets, downstream code would send duplicate or + // incorrect notifications. + didAutoUpdate = false + shouldNotifyTaskPause = false + } if auditLog != nil { // If the transition didn't succeed then updating the workspace // to indicate dormant didn't either. diff --git a/coderd/autobuild/lifecycle_executor_test.go b/coderd/autobuild/lifecycle_executor_test.go index 345647977d..89805429b9 100644 --- a/coderd/autobuild/lifecycle_executor_test.go +++ b/coderd/autobuild/lifecycle_executor_test.go @@ -4,10 +4,12 @@ import ( "context" "database/sql" "errors" + "sync/atomic" "testing" "time" "github.com/google/uuid" + "github.com/lib/pq" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -160,6 +162,92 @@ func TestMultipleLifecycleExecutors(t *testing.T) { assert.Equal(t, database.WorkspaceTransitionStart, stats.Transitions[workspace.ID]) } +// uniqueViolationStore wraps a database.Store and injects a unique violation +// error from InsertWorkspaceBuild after a configurable number of successful +// calls. This simulates a concurrent build race (e.g. an API-driven start +// racing with the lifecycle executor autostart). +type uniqueViolationStore struct { + database.Store + insertCount *atomic.Int32 // pointer: shared across InTx copies + failAfterN int32 +} + +func newUniqueViolationStore(db database.Store, failAfterN int32) *uniqueViolationStore { + return &uniqueViolationStore{ + Store: db, + insertCount: &atomic.Int32{}, + failAfterN: failAfterN, + } +} + +func (s *uniqueViolationStore) InTx(fn func(database.Store) error, opts *database.TxOptions) error { + return s.Store.InTx(func(tx database.Store) error { + return fn(&uniqueViolationStore{ + Store: tx, + insertCount: s.insertCount, // shared pointer + failAfterN: s.failAfterN, + }) + }, opts) +} + +func (s *uniqueViolationStore) InsertWorkspaceBuild(ctx context.Context, arg database.InsertWorkspaceBuildParams) error { + n := s.insertCount.Add(1) + if n > s.failAfterN { + return &pq.Error{ + Code: pq.ErrorCode("23505"), + Constraint: string(database.UniqueWorkspaceBuildsWorkspaceIDBuildNumberKey), + Message: `duplicate key value violates unique constraint "workspace_builds_workspace_id_build_number_key"`, + } + } + return s.Store.InsertWorkspaceBuild(ctx, arg) +} + +func TestExecutorBuildNumberRaceIsHandled(t *testing.T) { + t.Parallel() + + // The lifecycle executor must handle a unique-violation from + // InsertWorkspaceBuild gracefully. This error occurs when a concurrent + // actor (API handler, another executor, prebuilds reconciler) inserts a + // build with the same number before the executor's INSERT lands. + // + // We inject the error via a store wrapper. The first two + // InsertWorkspaceBuild calls succeed (setup builds), then the third + // (the lifecycle executor's autostart build) gets a unique violation. + + realDB, ps := dbtestutil.NewDB(t) + wrappedDB := newUniqueViolationStore(realDB, 2) // Allow builds 1 (start) and 2 (stop); fail build 3 (autostart) + + var ( + sched, _ = cron.Weekly("CRON_TZ=UTC 0 * * * *") + tickCh = make(chan time.Time) + statsCh = make(chan autobuild.Stats) + client = coderdtest.New(t, &coderdtest.Options{ + IncludeProvisionerDaemon: true, + AutobuildTicker: tickCh, + AutobuildStats: statsCh, + Database: wrappedDB, + Pubsub: ps, + }) + workspace = mustProvisionWorkspace(t, client, func(cwr *codersdk.CreateWorkspaceRequest) { + cwr.AutostartSchedule = ptr.Ref(sched.String()) + }) + ) + + workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop) + + p, err := coderdtest.GetProvisionerForTags(realDB, time.Now(), workspace.OrganizationID, nil) + require.NoError(t, err) + next := sched.Next(workspace.LatestBuild.CreatedAt) + coderdtest.UpdateProvisionerLastSeenAt(t, realDB, p.ID, next) + + tickCh <- next + stats := <-statsCh + + // The lifecycle executor should treat the unique violation as a benign + // race, not as a hard error. + assert.Empty(t, stats.Errors, "lifecycle executor should not report unique-violation as error") +} + func TestExecutorAutostartTemplateUpdated(t *testing.T) { t.Parallel() diff --git a/coderd/database/check_constraint.go b/coderd/database/check_constraint.go index 5682341ef9..1c20622e58 100644 --- a/coderd/database/check_constraint.go +++ b/coderd/database/check_constraint.go @@ -44,6 +44,7 @@ const ( CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events + CheckUserAiBudgetOverridesSpendLimitMicrosCheck CheckConstraint = "user_ai_budget_overrides_spend_limit_micros_check" // user_ai_budget_overrides CheckUserAiProviderKeysAPIKeyCheck CheckConstraint = "user_ai_provider_keys_api_key_check" // user_ai_provider_keys CheckUserSkillsContentSize CheckConstraint = "user_skills_content_size" // user_skills CheckUserSkillsDescriptionSize CheckConstraint = "user_skills_description_size" // user_skills diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index 36b081b86b..bc93df7cd3 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -1509,6 +1509,16 @@ func GroupAIBudget(b database.GroupAiBudget) codersdk.GroupAIBudget { } } +func UserAIBudgetOverride(o database.UserAiBudgetOverride) codersdk.UserAIBudgetOverride { + return codersdk.UserAIBudgetOverride{ + UserID: o.UserID, + GroupID: o.GroupID, + SpendLimitMicros: o.SpendLimitMicros, + CreatedAt: o.CreatedAt, + UpdatedAt: o.UpdatedAt, + } +} + func InvalidatedPresets(invalidatedPresets []database.UpdatePresetsLastInvalidatedAtRow) []codersdk.InvalidatedPreset { var presets []codersdk.InvalidatedPreset for _, p := range invalidatedPresets { diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 226f7f8ca7..ca6e0203c6 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -651,6 +651,8 @@ var ( rbac.ResourceAibridgeInterception.Type: {policy.ActionDelete}, // Chat auto-archive sets archived=true on inactive chats. rbac.ResourceChat.Type: {policy.ActionRead, policy.ActionUpdate}, + // Purge old boundary logs past the retention period. + rbac.ResourceBoundaryLog.Type: {policy.ActionDelete}, }), User: []rbac.Permission{}, ByOrgID: map[string]rbac.OrgPermissions{}, @@ -742,6 +744,29 @@ var ( }), Scope: rbac.ScopeAll, }.WithCachedASTValue() + + subjectSCIM = rbac.Subject{ + Type: rbac.SubjectTypeSCIMProvisioner, + FriendlyName: "SCIM Provisioner", + ID: uuid.Nil.String(), + Roles: rbac.Roles([]rbac.Role{ + { + Identifier: rbac.RoleIdentifier{Name: "scim"}, + DisplayName: "SCIM", + Site: rbac.Permissions(map[string][]policy.Action{ + rbac.ResourceSystem.Type: {policy.ActionRead}, // Required for idp config reads, this should be fixed + rbac.ResourceAssignRole.Type: rbac.ResourceAssignRole.AvailableActions(), + rbac.ResourceAssignOrgRole.Type: rbac.ResourceAssignOrgRole.AvailableActions(), + rbac.ResourceUser.Type: {policy.ActionCreate, policy.ActionUpdate, policy.ActionRead, policy.ActionUpdatePersonal}, + rbac.ResourceOrganization.Type: {policy.ActionRead}, + rbac.ResourceOrganizationMember.Type: {policy.ActionRead, policy.ActionCreate, policy.ActionUpdate}, + }), + User: []rbac.Permission{}, + ByOrgID: map[string]rbac.OrgPermissions{}, + }, + }), + Scope: rbac.ScopeAll, + }.WithCachedASTValue() ) // AsProvisionerd returns a context with an actor that has permissions required @@ -872,6 +897,12 @@ func AsAIProviderMetadataReader(ctx context.Context) context.Context { return As(ctx, subjectAIProviderMetadataReader) } +// AsSCIMProvisioner returns a context with an actor that has permissions required for +// handling the /scim/v2 routes and provisioning users via SCIM. +func AsSCIMProvisioner(ctx context.Context) context.Context { + return As(ctx, subjectSCIM) +} + var AsRemoveActor = rbac.Subject{ ID: "remove-actor", } @@ -2162,9 +2193,8 @@ func (q *querier) DeleteOldAuditLogs(ctx context.Context, arg database.DeleteOld return q.db.DeleteOldAuditLogs(ctx, arg) } -// TODO (PR #24810): Replace rbac.ResourceSystem with dedicated boundary_log resource type. func (q *querier) DeleteOldBoundaryLogs(ctx context.Context, arg database.DeleteOldBoundaryLogsParams) (int64, error) { - if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil { + if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceBoundaryLog); err != nil { return 0, err } return q.db.DeleteOldBoundaryLogs(ctx, arg) @@ -2293,6 +2323,32 @@ func (q *querier) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) return q.db.DeleteTask(ctx, arg) } +func (q *querier) DeleteUserAIBudgetOverride(ctx context.Context, userID uuid.UUID) (database.UserAiBudgetOverride, error) { + // Removing a user's AI budget override affects both the user (clearing + // their per-user spend cap) and the group it was attributed to. + u, err := q.db.GetUserByID(ctx, userID) + if err != nil { + return database.UserAiBudgetOverride{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, u); err != nil { + return database.UserAiBudgetOverride{}, err + } + // Fetch the existing override to learn which group it attributes spend to, + // so we can authorize the caller against that group as well. + userOverride, err := q.db.GetUserAIBudgetOverride(ctx, userID) + if err != nil { + return database.UserAiBudgetOverride{}, err + } + g, err := q.db.GetGroupByID(ctx, userOverride.GroupID) + if err != nil { + return database.UserAiBudgetOverride{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, g); err != nil { + return database.UserAiBudgetOverride{}, err + } + return q.db.DeleteUserAIBudgetOverride(ctx, userID) +} + func (q *querier) DeleteUserAIProviderKey(ctx context.Context, arg database.DeleteUserAIProviderKeyParams) error { u, err := q.db.GetUserByID(ctx, arg.UserID) if err != nil { @@ -2751,17 +2807,15 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI return q.db.GetAuthorizationUserRoles(ctx, userID) } -// TODO (PR #24810): Replace rbac.ResourceAuditLog with dedicated boundary_log resource type. func (q *querier) GetBoundaryLogByID(ctx context.Context, id uuid.UUID) (database.BoundaryLog, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAuditLog); err != nil { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceBoundaryLog); err != nil { return database.BoundaryLog{}, err } return q.db.GetBoundaryLogByID(ctx, id) } -// TODO (PR #24810): Replace rbac.ResourceAuditLog with dedicated boundary_log resource type. func (q *querier) GetBoundarySessionByID(ctx context.Context, id uuid.UUID) (database.BoundarySession, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAuditLog); err != nil { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceBoundaryLog); err != nil { return database.BoundarySession{}, err } return q.db.GetBoundarySessionByID(ctx, id) @@ -4508,6 +4562,13 @@ func (q *querier) GetUnexpiredLicenses(ctx context.Context) ([]database.License, return q.db.GetUnexpiredLicenses(ctx) } +func (q *querier) GetUserAIBudgetOverride(ctx context.Context, userID uuid.UUID) (database.UserAiBudgetOverride, error) { + if _, err := q.GetUserByID(ctx, userID); err != nil { // AuthZ check + return database.UserAiBudgetOverride{}, err + } + return q.db.GetUserAIBudgetOverride(ctx, userID) +} + func (q *querier) GetUserAIProviderKeyByProviderID(ctx context.Context, arg database.GetUserAIProviderKeyByProviderIDParams) (database.UserAiProviderKey, error) { u, err := q.db.GetUserByID(ctx, arg.UserID) if err != nil { @@ -4659,7 +4720,8 @@ func (q *querier) GetUserCodeDiffDisplayMode(ctx context.Context, userID uuid.UU } func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { + // If you can read every user, then you can read the count of users. + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceUser); err != nil { return 0, err } return q.db.GetUserCount(ctx, includeSystem) @@ -5451,14 +5513,29 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg) } -// TODO (PR #24810): Replace rbac.ResourceAuditLog with dedicated boundary_log resource type. -func (q *querier) InsertBoundaryLog(ctx context.Context, arg database.InsertBoundaryLogParams) (database.BoundaryLog, error) { - return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertBoundaryLog)(ctx, arg) +func (q *querier) InsertBoundaryLogs(ctx context.Context, arg database.InsertBoundaryLogsParams) ([]database.BoundaryLog, error) { + session, err := q.db.GetBoundarySessionByID(ctx, arg.SessionID) + if err != nil { + return nil, xerrors.Errorf("get boundary session for owner: %w", err) + } + if err := q.authorizeContext(ctx, policy.ActionCreate, + rbac.ResourceBoundaryLog.WithOwner(session.OwnerID.UUID.String())); err != nil { + return nil, err + } + return q.db.InsertBoundaryLogs(ctx, arg) } -// TODO (PR #24810): Replace rbac.ResourceAuditLog with dedicated boundary_log resource type. func (q *querier) InsertBoundarySession(ctx context.Context, arg database.InsertBoundarySessionParams) (database.BoundarySession, error) { - return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertBoundarySession)(ctx, arg) + row, err := q.db.GetWorkspaceAgentAndWorkspaceByID(ctx, arg.WorkspaceAgentID) + if err != nil { + return database.BoundarySession{}, xerrors.Errorf("get workspace for boundary session owner: %w", err) + } + arg.OwnerID = uuid.NullUUID{UUID: row.WorkspaceTable.OwnerID, Valid: true} + if err := q.authorizeContext(ctx, policy.ActionCreate, + rbac.ResourceBoundaryLog.WithOwner(arg.OwnerID.UUID.String())); err != nil { + return database.BoundarySession{}, err + } + return q.db.InsertBoundarySession(ctx, arg) } func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) { @@ -6174,9 +6251,8 @@ func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs) } -// TODO (PR #24810): Replace rbac.ResourceAuditLog with dedicated boundary_log resource type. func (q *querier) ListBoundaryLogsBySessionID(ctx context.Context, arg database.ListBoundaryLogsBySessionIDParams) ([]database.BoundaryLog, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAuditLog); err != nil { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceBoundaryLog); err != nil { return nil, err } return q.db.ListBoundaryLogsBySessionID(ctx, arg) @@ -8312,6 +8388,26 @@ func (q *querier) UpsertTemplateUsageStats(ctx context.Context) error { return q.db.UpsertTemplateUsageStats(ctx) } +func (q *querier) UpsertUserAIBudgetOverride(ctx context.Context, arg database.UpsertUserAIBudgetOverrideParams) (database.UserAiBudgetOverride, error) { + // Setting a user's AI budget override affects both the user (their + // per-user spend cap) and the group (spend attribution). + u, err := q.db.GetUserByID(ctx, arg.UserID) + if err != nil { + return database.UserAiBudgetOverride{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, u); err != nil { + return database.UserAiBudgetOverride{}, err + } + g, err := q.db.GetGroupByID(ctx, arg.GroupID) + if err != nil { + return database.UserAiBudgetOverride{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, g); err != nil { + return database.UserAiBudgetOverride{}, err + } + return q.db.UpsertUserAIBudgetOverride(ctx, arg) +} + func (q *querier) UpsertUserAIProviderKey(ctx context.Context, arg database.UpsertUserAIProviderKeyParams) (database.UserAiProviderKey, error) { u, err := q.db.GetUserByID(ctx, arg.UserID) if err != nil { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index d2d3facb28..63b650adda 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -440,35 +440,55 @@ func (s *MethodTestSuite) TestAuditLogs() { })) } -// TODO (PR #24810): These RBAC assertions use placeholder resource types. -// They will be updated when the dedicated boundary_log resource type is added. func (s *MethodTestSuite) TestBoundaryLogs() { - s.Run("InsertBoundarySession", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { - arg := database.InsertBoundarySessionParams{} - dbm.EXPECT().InsertBoundarySession(gomock.Any(), arg).Return(database.BoundarySession{}, nil).AnyTimes() - check.Args(arg).Asserts(rbac.ResourceAuditLog, policy.ActionCreate) + s.Run("InsertBoundarySession", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + aww := testutil.Fake(s.T(), faker, database.GetWorkspaceAgentAndWorkspaceByIDRow{}) + arg := database.InsertBoundarySessionParams{ + WorkspaceAgentID: aww.WorkspaceAgent.ID, + } + dbm.EXPECT().GetWorkspaceAgentAndWorkspaceByID(gomock.Any(), aww.WorkspaceAgent.ID).Return(aww, nil).AnyTimes() + expectedArg := database.InsertBoundarySessionParams{ + WorkspaceAgentID: aww.WorkspaceAgent.ID, + OwnerID: uuid.NullUUID{UUID: aww.WorkspaceTable.OwnerID, Valid: true}, + } + dbm.EXPECT().InsertBoundarySession(gomock.Any(), expectedArg).Return(database.BoundarySession{}, nil).AnyTimes() + check.Args(arg).Asserts( + rbac.ResourceBoundaryLog.WithOwner(aww.WorkspaceTable.OwnerID.String()), policy.ActionCreate, + ) })) s.Run("GetBoundarySessionByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().GetBoundarySessionByID(gomock.Any(), uuid.Nil).Return(database.BoundarySession{}, nil).AnyTimes() - check.Args(uuid.Nil).Asserts(rbac.ResourceAuditLog, policy.ActionRead) + check.Args(uuid.Nil).Asserts(rbac.ResourceBoundaryLog, policy.ActionRead) })) - s.Run("InsertBoundaryLog", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { - arg := database.InsertBoundaryLogParams{} - dbm.EXPECT().InsertBoundaryLog(gomock.Any(), arg).Return(database.BoundaryLog{}, nil).AnyTimes() - check.Args(arg).Asserts(rbac.ResourceAuditLog, policy.ActionCreate) + s.Run("InsertBoundaryLogs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + ownerID := uuid.New() + sessionID := uuid.New() + session := database.BoundarySession{ + ID: sessionID, + OwnerID: uuid.NullUUID{UUID: ownerID, Valid: true}, + } + arg := database.InsertBoundaryLogsParams{ + SessionID: sessionID, + ID: []uuid.UUID{uuid.New(), uuid.New()}, + } + dbm.EXPECT().GetBoundarySessionByID(gomock.Any(), sessionID).Return(session, nil).AnyTimes() + dbm.EXPECT().InsertBoundaryLogs(gomock.Any(), arg).Return([]database.BoundaryLog{}, nil).AnyTimes() + check.Args(arg).Asserts( + rbac.ResourceBoundaryLog.WithOwner(ownerID.String()), policy.ActionCreate, + ) })) s.Run("GetBoundaryLogByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().GetBoundaryLogByID(gomock.Any(), uuid.Nil).Return(database.BoundaryLog{}, nil).AnyTimes() - check.Args(uuid.Nil).Asserts(rbac.ResourceAuditLog, policy.ActionRead) + check.Args(uuid.Nil).Asserts(rbac.ResourceBoundaryLog, policy.ActionRead) })) s.Run("ListBoundaryLogsBySessionID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { arg := database.ListBoundaryLogsBySessionIDParams{} dbm.EXPECT().ListBoundaryLogsBySessionID(gomock.Any(), arg).Return([]database.BoundaryLog{}, nil).AnyTimes() - check.Args(arg).Asserts(rbac.ResourceAuditLog, policy.ActionRead) + check.Args(arg).Asserts(rbac.ResourceBoundaryLog, policy.ActionRead) })) s.Run("DeleteOldBoundaryLogs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().DeleteOldBoundaryLogs(gomock.Any(), database.DeleteOldBoundaryLogsParams{}).Return(int64(0), nil).AnyTimes() - check.Args(database.DeleteOldBoundaryLogsParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete) + check.Args(database.DeleteOldBoundaryLogsParams{}).Asserts(rbac.ResourceBoundaryLog, policy.ActionDelete) })) } @@ -4587,7 +4607,7 @@ func (s *MethodTestSuite) TestSystemFunctions() { })) s.Run("GetUserCount", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().GetUserCount(gomock.Any(), false).Return(int64(0), nil).AnyTimes() - check.Args(false).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(int64(0)) + check.Args(false).Asserts(rbac.ResourceUser, policy.ActionRead).Returns(int64(0)) })) s.Run("GetTemplates", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().GetTemplates(gomock.Any()).Return([]database.Template{}, nil).AnyTimes() @@ -6464,6 +6484,36 @@ func (s *MethodTestSuite) TestAIBridge() { check.Args(g.ID).Asserts(g, policy.ActionUpdate).Returns(b) })) + s.Run("GetUserAIBudgetOverride", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + user := testutil.Fake(s.T(), faker, database.User{}) + override := testutil.Fake(s.T(), faker, database.UserAiBudgetOverride{UserID: user.ID}) + dbm.EXPECT().GetUserByID(gomock.Any(), user.ID).Return(user, nil).AnyTimes() + dbm.EXPECT().GetUserAIBudgetOverride(gomock.Any(), user.ID).Return(override, nil).AnyTimes() + check.Args(user.ID).Asserts(user, policy.ActionRead).Returns(override) + })) + + s.Run("UpsertUserAIBudgetOverride", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + user := testutil.Fake(s.T(), faker, database.User{}) + group := testutil.Fake(s.T(), faker, database.Group{}) + override := testutil.Fake(s.T(), faker, database.UserAiBudgetOverride{UserID: user.ID, GroupID: group.ID}) + arg := database.UpsertUserAIBudgetOverrideParams{UserID: user.ID, GroupID: group.ID, SpendLimitMicros: override.SpendLimitMicros} + dbm.EXPECT().GetUserByID(gomock.Any(), user.ID).Return(user, nil).AnyTimes() + dbm.EXPECT().GetGroupByID(gomock.Any(), group.ID).Return(group, nil).AnyTimes() + dbm.EXPECT().UpsertUserAIBudgetOverride(gomock.Any(), arg).Return(override, nil).AnyTimes() + check.Args(arg).Asserts(user, policy.ActionUpdate, group, policy.ActionUpdate).Returns(override) + })) + + s.Run("DeleteUserAIBudgetOverride", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + user := testutil.Fake(s.T(), faker, database.User{}) + group := testutil.Fake(s.T(), faker, database.Group{}) + override := testutil.Fake(s.T(), faker, database.UserAiBudgetOverride{UserID: user.ID, GroupID: group.ID}) + dbm.EXPECT().GetUserByID(gomock.Any(), user.ID).Return(user, nil).AnyTimes() + dbm.EXPECT().GetUserAIBudgetOverride(gomock.Any(), user.ID).Return(override, nil).AnyTimes() + dbm.EXPECT().GetGroupByID(gomock.Any(), group.ID).Return(group, nil).AnyTimes() + dbm.EXPECT().DeleteUserAIBudgetOverride(gomock.Any(), user.ID).Return(override, nil).AnyTimes() + check.Args(user.ID).Asserts(user, policy.ActionUpdate, group, policy.ActionUpdate).Returns(override) + })) + s.Run("GetAIProviderByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { provider := testutil.Fake(s.T(), faker, database.AIProvider{}) dbm.EXPECT().GetAIProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes() diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 834bad6274..416a2b7257 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -458,6 +458,7 @@ func BoundarySession(t testing.TB, db database.Store, seed database.BoundarySess session, err := db.InsertBoundarySession(genCtx, database.InsertBoundarySessionParams{ ID: takeFirst(seed.ID, uuid.New()), WorkspaceAgentID: takeFirst(seed.WorkspaceAgentID, uuid.New()), + OwnerID: takeFirst(seed.OwnerID, uuid.NullUUID{UUID: uuid.New(), Valid: true}), ConfinedProcessName: takeFirst(seed.ConfinedProcessName, "claude-code"), StartedAt: takeFirst(seed.StartedAt, dbtime.Now()), UpdatedAt: takeFirst(seed.UpdatedAt, dbtime.Now()), @@ -466,20 +467,52 @@ func BoundarySession(t testing.TB, db database.Store, seed database.BoundarySess return session } -func BoundaryLog(t testing.TB, db database.Store, seed database.BoundaryLog) database.BoundaryLog { - log, err := db.InsertBoundaryLog(genCtx, database.InsertBoundaryLogParams{ - ID: takeFirst(seed.ID, uuid.New()), - SessionID: seed.SessionID, - SequenceNumber: takeFirst(seed.SequenceNumber, 0), - CapturedAt: takeFirst(seed.CapturedAt, dbtime.Now()), - CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()), - Proto: takeFirst(seed.Proto, "http"), - Method: takeFirst(seed.Method, "GET"), - Detail: takeFirst(seed.Detail, "https://example.com"), - MatchedRule: seed.MatchedRule, +func BoundaryLogs(t testing.TB, db database.Store, seed []database.BoundaryLog) []database.BoundaryLog { + ids := make([]uuid.UUID, 0, len(seed)) + sessionID := seed[0].SessionID + sequenceNumbers := make([]int32, 0, len(seed)) + capturedAt := make([]time.Time, 0, len(seed)) + createdAt := make([]time.Time, 0, len(seed)) + protos := make([]string, 0, len(seed)) + method := make([]string, 0, len(seed)) + detail := make([]string, 0, len(seed)) + matchedRule := make([]string, 0, len(seed)) + for _, log := range seed { + log = takeFirstBoundaryLog(log) + ids = append(ids, log.ID) + sequenceNumbers = append(sequenceNumbers, log.SequenceNumber) + capturedAt = append(capturedAt, log.CapturedAt) + createdAt = append(createdAt, log.CreatedAt) + protos = append(protos, log.Proto) + method = append(method, log.Method) + detail = append(detail, log.Detail) + matchedRule = append(matchedRule, log.MatchedRule.String) + } + logs, err := db.InsertBoundaryLogs(genCtx, database.InsertBoundaryLogsParams{ + ID: ids, + SessionID: sessionID, + SequenceNumber: sequenceNumbers, + CapturedAt: capturedAt, + CreatedAt: createdAt, + Proto: protos, + Method: method, + Detail: detail, + MatchedRule: matchedRule, }) - require.NoError(t, err, "insert boundary log") - return log + require.NoError(t, err, "insert boundary logs") + return logs +} + +func takeFirstBoundaryLog(seed database.BoundaryLog) database.BoundaryLog { + seed.ID = takeFirst(seed.ID, uuid.New()) + seed.SessionID = takeFirst(seed.SessionID, uuid.New()) + seed.SequenceNumber = takeFirst(seed.SequenceNumber, 0) + seed.CapturedAt = takeFirst(seed.CapturedAt, dbtime.Now()) + seed.CreatedAt = takeFirst(seed.CreatedAt, dbtime.Now()) + seed.Proto = takeFirst(seed.Proto, "http") + seed.Method = takeFirst(seed.Method, "GET") + seed.Detail = takeFirst(seed.Detail, "https://example.com") + return seed } func Template(t testing.TB, db database.Store, seed database.Template) database.Template { @@ -1969,8 +2002,9 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA }) if endedAt != nil { interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{ - ID: interception.ID, - EndedAt: *endedAt, + ID: interception.ID, + EndedAt: *endedAt, + CredentialHint: takeFirst(seed.CredentialHint, ""), }) require.NoError(t, err, "insert aibridge interception") } diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 304d9c73ec..765bf0abc9 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -793,6 +793,14 @@ func (m queryMetricsStore) DeleteTask(ctx context.Context, arg database.DeleteTa return r0, r1 } +func (m queryMetricsStore) DeleteUserAIBudgetOverride(ctx context.Context, userID uuid.UUID) (database.UserAiBudgetOverride, error) { + start := time.Now() + r0, r1 := m.s.DeleteUserAIBudgetOverride(ctx, userID) + m.queryLatencies.WithLabelValues("DeleteUserAIBudgetOverride").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserAIBudgetOverride").Inc() + return r0, r1 +} + func (m queryMetricsStore) DeleteUserAIProviderKey(ctx context.Context, arg database.DeleteUserAIProviderKeyParams) error { start := time.Now() r0 := m.s.DeleteUserAIProviderKey(ctx, arg) @@ -2905,6 +2913,14 @@ func (m queryMetricsStore) GetUnexpiredLicenses(ctx context.Context) ([]database return r0, r1 } +func (m queryMetricsStore) GetUserAIBudgetOverride(ctx context.Context, userID uuid.UUID) (database.UserAiBudgetOverride, error) { + start := time.Now() + r0, r1 := m.s.GetUserAIBudgetOverride(ctx, userID) + m.queryLatencies.WithLabelValues("GetUserAIBudgetOverride").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserAIBudgetOverride").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetUserAIProviderKeyByProviderID(ctx context.Context, arg database.GetUserAIProviderKeyByProviderIDParams) (database.UserAiProviderKey, error) { start := time.Now() r0, r1 := m.s.GetUserAIProviderKeyByProviderID(ctx, arg) @@ -3753,11 +3769,11 @@ func (m queryMetricsStore) InsertAuditLog(ctx context.Context, arg database.Inse return r0, r1 } -func (m queryMetricsStore) InsertBoundaryLog(ctx context.Context, arg database.InsertBoundaryLogParams) (database.BoundaryLog, error) { +func (m queryMetricsStore) InsertBoundaryLogs(ctx context.Context, arg database.InsertBoundaryLogsParams) ([]database.BoundaryLog, error) { start := time.Now() - r0, r1 := m.s.InsertBoundaryLog(ctx, arg) - m.queryLatencies.WithLabelValues("InsertBoundaryLog").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertBoundaryLog").Inc() + r0, r1 := m.s.InsertBoundaryLogs(ctx, arg) + m.queryLatencies.WithLabelValues("InsertBoundaryLogs").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertBoundaryLogs").Inc() return r0, r1 } @@ -6057,6 +6073,14 @@ func (m queryMetricsStore) UpsertTemplateUsageStats(ctx context.Context) error { return r0 } +func (m queryMetricsStore) UpsertUserAIBudgetOverride(ctx context.Context, arg database.UpsertUserAIBudgetOverrideParams) (database.UserAiBudgetOverride, error) { + start := time.Now() + r0, r1 := m.s.UpsertUserAIBudgetOverride(ctx, arg) + m.queryLatencies.WithLabelValues("UpsertUserAIBudgetOverride").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserAIBudgetOverride").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpsertUserAIProviderKey(ctx context.Context, arg database.UpsertUserAIProviderKeyParams) (database.UserAiProviderKey, error) { start := time.Now() r0, r1 := m.s.UpsertUserAIProviderKey(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 637594ae2c..081ca4462d 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1349,6 +1349,21 @@ func (mr *MockStoreMockRecorder) DeleteTask(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTask", reflect.TypeOf((*MockStore)(nil).DeleteTask), ctx, arg) } +// DeleteUserAIBudgetOverride mocks base method. +func (m *MockStore) DeleteUserAIBudgetOverride(ctx context.Context, userID uuid.UUID) (database.UserAiBudgetOverride, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUserAIBudgetOverride", ctx, userID) + ret0, _ := ret[0].(database.UserAiBudgetOverride) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteUserAIBudgetOverride indicates an expected call of DeleteUserAIBudgetOverride. +func (mr *MockStoreMockRecorder) DeleteUserAIBudgetOverride(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserAIBudgetOverride", reflect.TypeOf((*MockStore)(nil).DeleteUserAIBudgetOverride), ctx, userID) +} + // DeleteUserAIProviderKey mocks base method. func (m *MockStore) DeleteUserAIProviderKey(ctx context.Context, arg database.DeleteUserAIProviderKeyParams) error { m.ctrl.T.Helper() @@ -5445,6 +5460,21 @@ func (mr *MockStoreMockRecorder) GetUnexpiredLicenses(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUnexpiredLicenses", reflect.TypeOf((*MockStore)(nil).GetUnexpiredLicenses), ctx) } +// GetUserAIBudgetOverride mocks base method. +func (m *MockStore) GetUserAIBudgetOverride(ctx context.Context, userID uuid.UUID) (database.UserAiBudgetOverride, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserAIBudgetOverride", ctx, userID) + ret0, _ := ret[0].(database.UserAiBudgetOverride) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserAIBudgetOverride indicates an expected call of GetUserAIBudgetOverride. +func (mr *MockStoreMockRecorder) GetUserAIBudgetOverride(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserAIBudgetOverride", reflect.TypeOf((*MockStore)(nil).GetUserAIBudgetOverride), ctx, userID) +} + // GetUserAIProviderKeyByProviderID mocks base method. func (m *MockStore) GetUserAIProviderKeyByProviderID(ctx context.Context, arg database.GetUserAIProviderKeyByProviderIDParams) (database.UserAiProviderKey, error) { m.ctrl.T.Helper() @@ -7049,19 +7079,19 @@ func (mr *MockStoreMockRecorder) InsertAuditLog(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAuditLog", reflect.TypeOf((*MockStore)(nil).InsertAuditLog), ctx, arg) } -// InsertBoundaryLog mocks base method. -func (m *MockStore) InsertBoundaryLog(ctx context.Context, arg database.InsertBoundaryLogParams) (database.BoundaryLog, error) { +// InsertBoundaryLogs mocks base method. +func (m *MockStore) InsertBoundaryLogs(ctx context.Context, arg database.InsertBoundaryLogsParams) ([]database.BoundaryLog, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InsertBoundaryLog", ctx, arg) - ret0, _ := ret[0].(database.BoundaryLog) + ret := m.ctrl.Call(m, "InsertBoundaryLogs", ctx, arg) + ret0, _ := ret[0].([]database.BoundaryLog) ret1, _ := ret[1].(error) return ret0, ret1 } -// InsertBoundaryLog indicates an expected call of InsertBoundaryLog. -func (mr *MockStoreMockRecorder) InsertBoundaryLog(ctx, arg any) *gomock.Call { +// InsertBoundaryLogs indicates an expected call of InsertBoundaryLogs. +func (mr *MockStoreMockRecorder) InsertBoundaryLogs(ctx, arg any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertBoundaryLog", reflect.TypeOf((*MockStore)(nil).InsertBoundaryLog), ctx, arg) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertBoundaryLogs", reflect.TypeOf((*MockStore)(nil).InsertBoundaryLogs), ctx, arg) } // InsertBoundarySession mocks base method. @@ -11359,6 +11389,21 @@ func (mr *MockStoreMockRecorder) UpsertTemplateUsageStats(ctx any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTemplateUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertTemplateUsageStats), ctx) } +// UpsertUserAIBudgetOverride mocks base method. +func (m *MockStore) UpsertUserAIBudgetOverride(ctx context.Context, arg database.UpsertUserAIBudgetOverrideParams) (database.UserAiBudgetOverride, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertUserAIBudgetOverride", ctx, arg) + ret0, _ := ret[0].(database.UserAiBudgetOverride) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertUserAIBudgetOverride indicates an expected call of UpsertUserAIBudgetOverride. +func (mr *MockStoreMockRecorder) UpsertUserAIBudgetOverride(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserAIBudgetOverride", reflect.TypeOf((*MockStore)(nil).UpsertUserAIBudgetOverride), ctx, arg) +} + // UpsertUserAIProviderKey mocks base method. func (m *MockStore) UpsertUserAIProviderKey(ctx context.Context, arg database.UpsertUserAIProviderKeyParams) (database.UserAiProviderKey, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 9c2df36cab..82aa376d34 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -249,7 +249,11 @@ CREATE TYPE api_key_scope AS ENUM ( 'user_skill:read', 'user_skill:update', 'user_skill:delete', - 'user_skill:*' + 'user_skill:*', + 'boundary_log:*', + 'boundary_log:create', + 'boundary_log:delete', + 'boundary_log:read' ); CREATE TYPE app_sharing_level AS ENUM ( @@ -837,6 +841,42 @@ BEGIN END; $$; +CREATE FUNCTION delete_user_ai_budget_overrides_on_group_member_delete() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + DELETE FROM user_ai_budget_overrides + WHERE user_id = OLD.user_id AND group_id = OLD.group_id; + RETURN OLD; +END; +$$; + +CREATE FUNCTION delete_user_ai_budget_overrides_on_org_member_delete() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + DELETE FROM user_ai_budget_overrides + WHERE user_id = OLD.user_id AND group_id = OLD.organization_id; + RETURN OLD; +END; +$$; + +CREATE FUNCTION enforce_user_ai_budget_override_membership() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 FROM group_members_expanded + WHERE user_id = NEW.user_id AND group_id = NEW.group_id + ) THEN + RAISE EXCEPTION 'user % is not a member of group %', NEW.user_id, NEW.group_id + USING ERRCODE = 'check_violation', + CONSTRAINT = 'user_ai_budget_overrides_must_be_group_member'; + END IF; + RETURN NEW; +END; +$$; + CREATE FUNCTION enforce_user_secrets_per_user_limits() RETURNS trigger LANGUAGE plpgsql AS $$ @@ -1216,6 +1256,17 @@ BEGIN END; $$; +CREATE FUNCTION remove_mcp_server_config_id_from_chats() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + UPDATE chats + SET mcp_server_ids = array_remove(mcp_server_ids, OLD.id) + WHERE OLD.id = ANY(mcp_server_ids); + RETURN OLD; +END; +$$; + CREATE FUNCTION remove_organization_member_role() RETURNS trigger LANGUAGE plpgsql AS $$ @@ -1474,7 +1525,8 @@ CREATE TABLE boundary_sessions ( workspace_agent_id uuid NOT NULL, confined_process_name text NOT NULL, started_at timestamp with time zone NOT NULL, - updated_at timestamp with time zone NOT NULL + updated_at timestamp with time zone NOT NULL, + owner_id uuid ); COMMENT ON TABLE boundary_sessions IS 'Boundary session metadata. Each row represents a single invocation of a Boundary process wrapping a confined agent.'; @@ -1489,6 +1541,8 @@ COMMENT ON COLUMN boundary_sessions.started_at IS 'Time when the first log for t COMMENT ON COLUMN boundary_sessions.updated_at IS 'Time when the session was last updated.'; +COMMENT ON COLUMN boundary_sessions.owner_id IS 'The ID of the user who owns the workspace. NULL if the user has been deleted.'; + CREATE TABLE boundary_usage_stats ( replica_id uuid NOT NULL, unique_workspaces_count bigint DEFAULT 0 NOT NULL, @@ -3119,6 +3173,17 @@ COMMENT ON TABLE usage_events_daily IS 'usage_events_daily is a daily rollup of COMMENT ON COLUMN usage_events_daily.day IS 'The date of the summed usage events, always in UTC.'; +CREATE TABLE user_ai_budget_overrides ( + user_id uuid NOT NULL, + group_id uuid NOT NULL, + spend_limit_micros bigint NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL, + CONSTRAINT user_ai_budget_overrides_spend_limit_micros_check CHECK ((spend_limit_micros >= 0)) +); + +COMMENT ON TABLE user_ai_budget_overrides IS 'Per-user AI spend override that supersedes group budget resolution.'; + CREATE TABLE user_ai_provider_keys ( id uuid DEFAULT gen_random_uuid() NOT NULL, user_id uuid NOT NULL, @@ -3968,6 +4033,9 @@ ALTER TABLE ONLY usage_events_daily ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id); +ALTER TABLE ONLY user_ai_budget_overrides + ADD CONSTRAINT user_ai_budget_overrides_pkey PRIMARY KEY (user_id); + ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_pkey PRIMARY KEY (id); @@ -4435,6 +4503,10 @@ CREATE TRIGGER inhibit_enqueue_if_disabled BEFORE INSERT ON notification_message CREATE TRIGGER protect_deleting_organizations BEFORE UPDATE ON organizations FOR EACH ROW WHEN (((new.deleted = true) AND (old.deleted = false))) EXECUTE FUNCTION protect_deleting_organizations(); +CREATE TRIGGER remove_chat_mcp_server_config_id BEFORE DELETE ON mcp_server_configs FOR EACH ROW EXECUTE FUNCTION remove_mcp_server_config_id_from_chats(); + +COMMENT ON TRIGGER remove_chat_mcp_server_config_id ON mcp_server_configs IS 'When an MCP server config is deleted, this trigger removes its ID from all chats.'; + CREATE TRIGGER remove_organization_member_custom_role BEFORE DELETE ON custom_roles FOR EACH ROW EXECUTE FUNCTION remove_organization_member_role(); COMMENT ON TRIGGER remove_organization_member_custom_role ON custom_roles IS 'When a custom_role is deleted, this trigger removes the role from all organization members.'; @@ -4445,6 +4517,12 @@ CREATE TRIGGER trigger_delete_group_members_on_org_member_delete BEFORE DELETE O CREATE TRIGGER trigger_delete_oauth2_provider_app_token AFTER DELETE ON oauth2_provider_app_tokens FOR EACH ROW EXECUTE FUNCTION delete_deleted_oauth2_provider_app_token_api_key(); +CREATE TRIGGER trigger_delete_user_ai_budget_overrides_on_group_member_delete BEFORE DELETE ON group_members FOR EACH ROW EXECUTE FUNCTION delete_user_ai_budget_overrides_on_group_member_delete(); + +CREATE TRIGGER trigger_delete_user_ai_budget_overrides_on_org_member_delete BEFORE DELETE ON organization_members FOR EACH ROW EXECUTE FUNCTION delete_user_ai_budget_overrides_on_org_member_delete(); + +CREATE TRIGGER trigger_enforce_user_ai_budget_override_membership BEFORE INSERT OR UPDATE ON user_ai_budget_overrides FOR EACH ROW EXECUTE FUNCTION enforce_user_ai_budget_override_membership(); + CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted(); CREATE TRIGGER trigger_insert_organization_system_roles AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_organization_system_roles(); @@ -4494,6 +4572,9 @@ ALTER TABLE ONLY api_keys ALTER TABLE ONLY boundary_logs ADD CONSTRAINT boundary_logs_session_id_fkey FOREIGN KEY (session_id) REFERENCES boundary_sessions(id) ON DELETE CASCADE; +ALTER TABLE ONLY boundary_sessions + ADD CONSTRAINT boundary_sessions_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE SET NULL; + ALTER TABLE ONLY boundary_sessions ADD CONSTRAINT boundary_sessions_workspace_agent_id_fkey FOREIGN KEY (workspace_agent_id) REFERENCES workspace_agents(id); @@ -4767,6 +4848,12 @@ ALTER TABLE ONLY templates ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; +ALTER TABLE ONLY user_ai_budget_overrides + ADD CONSTRAINT user_ai_budget_overrides_group_id_fkey FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE; + +ALTER TABLE ONLY user_ai_budget_overrides + ADD CONSTRAINT user_ai_budget_overrides_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; + ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id) ON DELETE CASCADE; diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 624f3229b6..5eeb24587a 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -13,6 +13,7 @@ const ( ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id); ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyBoundaryLogsSessionID ForeignKeyConstraint = "boundary_logs_session_id_fkey" // ALTER TABLE ONLY boundary_logs ADD CONSTRAINT boundary_logs_session_id_fkey FOREIGN KEY (session_id) REFERENCES boundary_sessions(id) ON DELETE CASCADE; + ForeignKeyBoundarySessionsOwnerID ForeignKeyConstraint = "boundary_sessions_owner_id_fkey" // ALTER TABLE ONLY boundary_sessions ADD CONSTRAINT boundary_sessions_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE SET NULL; ForeignKeyBoundarySessionsWorkspaceAgentID ForeignKeyConstraint = "boundary_sessions_workspace_agent_id_fkey" // ALTER TABLE ONLY boundary_sessions ADD CONSTRAINT boundary_sessions_workspace_agent_id_fkey FOREIGN KEY (workspace_agent_id) REFERENCES workspace_agents(id); ForeignKeyChatDebugRunsChatID ForeignKeyConstraint = "chat_debug_runs_chat_id_fkey" // ALTER TABLE ONLY chat_debug_runs ADD CONSTRAINT chat_debug_runs_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; ForeignKeyChatDebugStepsChatID ForeignKeyConstraint = "chat_debug_steps_chat_id_fkey" // ALTER TABLE ONLY chat_debug_steps ADD CONSTRAINT chat_debug_steps_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; @@ -104,6 +105,8 @@ const ( ForeignKeyTemplateVersionsTemplateID ForeignKeyConstraint = "template_versions_template_id_fkey" // ALTER TABLE ONLY template_versions ADD CONSTRAINT template_versions_template_id_fkey FOREIGN KEY (template_id) REFERENCES templates(id) ON DELETE CASCADE; ForeignKeyTemplatesCreatedBy ForeignKeyConstraint = "templates_created_by_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE RESTRICT; ForeignKeyTemplatesOrganizationID ForeignKeyConstraint = "templates_organization_id_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; + ForeignKeyUserAiBudgetOverridesGroupID ForeignKeyConstraint = "user_ai_budget_overrides_group_id_fkey" // ALTER TABLE ONLY user_ai_budget_overrides ADD CONSTRAINT user_ai_budget_overrides_group_id_fkey FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE; + ForeignKeyUserAiBudgetOverridesUserID ForeignKeyConstraint = "user_ai_budget_overrides_user_id_fkey" // ALTER TABLE ONLY user_ai_budget_overrides ADD CONSTRAINT user_ai_budget_overrides_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyUserAiProviderKeysAiProviderID ForeignKeyConstraint = "user_ai_provider_keys_ai_provider_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id) ON DELETE CASCADE; ForeignKeyUserAiProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_ai_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyUserAiProviderKeysUserID ForeignKeyConstraint = "user_ai_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; diff --git a/coderd/database/migrations/000510_cleanup_chats_mcp_server_ids_on_delete.down.sql b/coderd/database/migrations/000510_cleanup_chats_mcp_server_ids_on_delete.down.sql new file mode 100644 index 0000000000..15c10e19e6 --- /dev/null +++ b/coderd/database/migrations/000510_cleanup_chats_mcp_server_ids_on_delete.down.sql @@ -0,0 +1,2 @@ +DROP TRIGGER IF EXISTS remove_chat_mcp_server_config_id ON mcp_server_configs; +DROP FUNCTION IF EXISTS remove_mcp_server_config_id_from_chats; diff --git a/coderd/database/migrations/000510_cleanup_chats_mcp_server_ids_on_delete.up.sql b/coderd/database/migrations/000510_cleanup_chats_mcp_server_ids_on_delete.up.sql new file mode 100644 index 0000000000..5366328b3c --- /dev/null +++ b/coderd/database/migrations/000510_cleanup_chats_mcp_server_ids_on_delete.up.sql @@ -0,0 +1,41 @@ +-- Remove already-stale MCP server references before future deletes are +-- handled by the trigger below. +UPDATE chats +SET mcp_server_ids = ( + SELECT COALESCE(array_agg(ids.mcp_server_id ORDER BY ids.position), '{}'::uuid[]) + FROM unnest(chats.mcp_server_ids) WITH ORDINALITY AS ids(mcp_server_id, position) + WHERE EXISTS ( + SELECT 1 + FROM mcp_server_configs + WHERE mcp_server_configs.id = ids.mcp_server_id + ) +) +WHERE EXISTS ( + SELECT 1 + FROM unnest(chats.mcp_server_ids) AS ids(mcp_server_id) + WHERE NOT EXISTS ( + SELECT 1 + FROM mcp_server_configs + WHERE mcp_server_configs.id = ids.mcp_server_id + ) +); + +CREATE OR REPLACE FUNCTION remove_mcp_server_config_id_from_chats() + RETURNS TRIGGER AS +$$ +BEGIN + UPDATE chats + SET mcp_server_ids = array_remove(mcp_server_ids, OLD.id) + WHERE OLD.id = ANY(mcp_server_ids); + RETURN OLD; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER remove_chat_mcp_server_config_id + BEFORE DELETE ON mcp_server_configs FOR EACH ROW + EXECUTE PROCEDURE remove_mcp_server_config_id_from_chats(); + +COMMENT ON TRIGGER + remove_chat_mcp_server_config_id + ON mcp_server_configs IS + 'When an MCP server config is deleted, this trigger removes its ID from all chats.'; diff --git a/coderd/database/migrations/000511_boundary_log_scopes.down.sql b/coderd/database/migrations/000511_boundary_log_scopes.down.sql new file mode 100644 index 0000000000..5a1baaa20c --- /dev/null +++ b/coderd/database/migrations/000511_boundary_log_scopes.down.sql @@ -0,0 +1 @@ +-- No-op for boundary_log scopes: keep enum values to avoid dependency churn. diff --git a/coderd/database/migrations/000511_boundary_log_scopes.up.sql b/coderd/database/migrations/000511_boundary_log_scopes.up.sql new file mode 100644 index 0000000000..12ec141591 --- /dev/null +++ b/coderd/database/migrations/000511_boundary_log_scopes.up.sql @@ -0,0 +1,5 @@ +-- Add boundary_log scopes for RBAC. +ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'boundary_log:*'; +ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'boundary_log:create'; +ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'boundary_log:delete'; +ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'boundary_log:read'; diff --git a/coderd/database/migrations/000512_boundary_session_owner.down.sql b/coderd/database/migrations/000512_boundary_session_owner.down.sql new file mode 100644 index 0000000000..3429fee351 --- /dev/null +++ b/coderd/database/migrations/000512_boundary_session_owner.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE boundary_sessions DROP CONSTRAINT IF EXISTS boundary_sessions_owner_id_fkey; +ALTER TABLE boundary_sessions DROP COLUMN IF EXISTS owner_id; diff --git a/coderd/database/migrations/000512_boundary_session_owner.up.sql b/coderd/database/migrations/000512_boundary_session_owner.up.sql new file mode 100644 index 0000000000..d97140df57 --- /dev/null +++ b/coderd/database/migrations/000512_boundary_session_owner.up.sql @@ -0,0 +1,28 @@ +-- Add owner_id to boundary_sessions to avoid expensive JOINs when +-- deriving the workspace owner for RBAC checks during log insertion. +ALTER TABLE boundary_sessions ADD COLUMN owner_id uuid; + +COMMENT ON COLUMN boundary_sessions.owner_id IS 'The ID of the user who owns the workspace. NULL if the user has been deleted.'; + +-- Backfill owner_id from the workspace agent -> workspace -> owner chain. +-- Soft-deleted agents and workspaces are included so that their audit +-- data is preserved. +UPDATE boundary_sessions bs +SET owner_id = w.owner_id +FROM workspace_agents wa +JOIN workspace_resources wr ON wa.resource_id = wr.id +JOIN provisioner_jobs pj ON wr.job_id = pj.id +JOIN workspace_builds wb ON pj.id = wb.job_id +JOIN workspaces w ON wb.workspace_id = w.id +WHERE wa.id = bs.workspace_agent_id + AND pj.type = 'workspace_build'; + +-- Delete any sessions that could not be backfilled (orphaned data +-- with no resolvable workspace agent or workspace build chain). +DELETE FROM boundary_sessions WHERE owner_id IS NULL; + +-- Add FK constraint. SET NULL preserves audit data when a user is +-- hard-deleted; the session and its logs survive with a NULL owner. +ALTER TABLE boundary_sessions + ADD CONSTRAINT boundary_sessions_owner_id_fkey + FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE SET NULL; diff --git a/coderd/database/migrations/000513_user_ai_budget_overrides.down.sql b/coderd/database/migrations/000513_user_ai_budget_overrides.down.sql new file mode 100644 index 0000000000..1a1a8e2160 --- /dev/null +++ b/coderd/database/migrations/000513_user_ai_budget_overrides.down.sql @@ -0,0 +1,7 @@ +DROP TRIGGER IF EXISTS trigger_delete_user_ai_budget_overrides_on_org_member_delete ON organization_members; +DROP FUNCTION IF EXISTS delete_user_ai_budget_overrides_on_org_member_delete; +DROP TRIGGER IF EXISTS trigger_delete_user_ai_budget_overrides_on_group_member_delete ON group_members; +DROP FUNCTION IF EXISTS delete_user_ai_budget_overrides_on_group_member_delete; +DROP TRIGGER IF EXISTS trigger_enforce_user_ai_budget_override_membership ON user_ai_budget_overrides; +DROP FUNCTION IF EXISTS enforce_user_ai_budget_override_membership; +DROP TABLE IF EXISTS user_ai_budget_overrides CASCADE; diff --git a/coderd/database/migrations/000513_user_ai_budget_overrides.up.sql b/coderd/database/migrations/000513_user_ai_budget_overrides.up.sql new file mode 100644 index 0000000000..b1ab1cd9d2 --- /dev/null +++ b/coderd/database/migrations/000513_user_ai_budget_overrides.up.sql @@ -0,0 +1,76 @@ +CREATE TABLE user_ai_budget_overrides ( + user_id UUID PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE, + group_id UUID NOT NULL REFERENCES groups(id) ON DELETE CASCADE, + -- Spend limit applied to the user, in micro-units (1 unit = 1,000,000). + spend_limit_micros BIGINT NOT NULL CHECK (spend_limit_micros >= 0), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + -- The membership invariant (user must be a member of the attributed + -- group, including when that group is "Everyone") would naturally be + -- a composite FK to group_members_expanded, but PostgreSQL does not + -- allow FKs to views. It's enforced instead by a write-time trigger + -- on this table and removal-time triggers on the underlying + -- membership tables. +); + +COMMENT ON TABLE user_ai_budget_overrides IS 'Per-user AI spend override that supersedes group budget resolution.'; + +-- Write-time membership check. Reads from group_members_expanded so +-- the "Everyone" group (whose membership lives in organization_members) +-- is correctly handled. Raises check_violation with a constraint name +-- so callers can match it via database.IsCheckViolation in Go. +CREATE FUNCTION enforce_user_ai_budget_override_membership() RETURNS TRIGGER + LANGUAGE plpgsql +AS $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 FROM group_members_expanded + WHERE user_id = NEW.user_id AND group_id = NEW.group_id + ) THEN + RAISE EXCEPTION 'user % is not a member of group %', NEW.user_id, NEW.group_id + USING ERRCODE = 'check_violation', + CONSTRAINT = 'user_ai_budget_overrides_must_be_group_member'; + END IF; + RETURN NEW; +END; +$$; + +CREATE TRIGGER trigger_enforce_user_ai_budget_override_membership + BEFORE INSERT OR UPDATE ON user_ai_budget_overrides + FOR EACH ROW +EXECUTE PROCEDURE enforce_user_ai_budget_override_membership(); + +-- When a user is removed from a regular group (any group except +-- "Everyone"), delete any override attributed to that group. +CREATE FUNCTION delete_user_ai_budget_overrides_on_group_member_delete() RETURNS TRIGGER + LANGUAGE plpgsql +AS $$ +BEGIN + DELETE FROM user_ai_budget_overrides + WHERE user_id = OLD.user_id AND group_id = OLD.group_id; + RETURN OLD; +END; +$$; + +CREATE TRIGGER trigger_delete_user_ai_budget_overrides_on_group_member_delete + BEFORE DELETE ON group_members + FOR EACH ROW +EXECUTE PROCEDURE delete_user_ai_budget_overrides_on_group_member_delete(); + +-- When a user is removed from an organization, delete any override +-- attributed to that organization's "Everyone" group (which has +-- id == organization_id). +CREATE FUNCTION delete_user_ai_budget_overrides_on_org_member_delete() RETURNS TRIGGER + LANGUAGE plpgsql +AS $$ +BEGIN + DELETE FROM user_ai_budget_overrides + WHERE user_id = OLD.user_id AND group_id = OLD.organization_id; + RETURN OLD; +END; +$$; + +CREATE TRIGGER trigger_delete_user_ai_budget_overrides_on_org_member_delete + BEFORE DELETE ON organization_members + FOR EACH ROW +EXECUTE PROCEDURE delete_user_ai_budget_overrides_on_org_member_delete(); diff --git a/coderd/database/migrations/testdata/fixtures/000512_boundary_session_owner.up.sql b/coderd/database/migrations/testdata/fixtures/000512_boundary_session_owner.up.sql new file mode 100644 index 0000000000..d1942bd5a5 --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000512_boundary_session_owner.up.sql @@ -0,0 +1,42 @@ +-- Re-insert boundary session and log fixture data after migration 000511 +-- deletes orphaned rows (the original fixture's workspace_agent links to a +-- template_version_import job, not a workspace_build, so the backfill +-- cannot resolve the owner). + +INSERT INTO boundary_sessions ( + id, + workspace_agent_id, + confined_process_name, + started_at, + updated_at, + owner_id +) VALUES ( + 'a1b2c3d4-e5f6-4890-abcd-ef1234567890', + '45e89705-e09d-4850-bcec-f9a937f5d78d', + 'claude-code', + '2026-04-01 10:00:00+00', + '2026-04-01 10:00:00+00', + '30095c71-380b-457a-8995-97b8ee6e5307' +); + +INSERT INTO boundary_logs ( + id, + session_id, + sequence_number, + captured_at, + created_at, + proto, + method, + detail, + matched_rule +) VALUES ( + 'b2c3d4e5-f6a7-4901-bcde-f12345678901', + 'a1b2c3d4-e5f6-4890-abcd-ef1234567890', + 0, + '2026-04-01 10:00:01+00', + '2026-04-01 10:00:00+00', + 'http', + 'GET', + 'https://api.anthropic.com/v1/messages', + 'domain=api.anthropic.com' +); diff --git a/coderd/database/migrations/testdata/fixtures/000513_user_ai_budget_overrides.up.sql b/coderd/database/migrations/testdata/fixtures/000513_user_ai_budget_overrides.up.sql new file mode 100644 index 0000000000..787b808b7d --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000513_user_ai_budget_overrides.up.sql @@ -0,0 +1,15 @@ +-- Seed a group_members row so the override below references a real +-- membership. +INSERT INTO group_members ( + user_id, + group_id +) VALUES + ('30095c71-380b-457a-8995-97b8ee6e5307', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1') +ON CONFLICT DO NOTHING; + +INSERT INTO user_ai_budget_overrides ( + user_id, + group_id, + spend_limit_micros +) VALUES + ('30095c71-380b-457a-8995-97b8ee6e5307', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', 500000000); diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index bab9759762..62eb12a1d2 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -1003,3 +1003,10 @@ type UpsertConnectionLogParams struct { func (r GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow) RBACObject() rbac.Object { return r.WorkspaceTable.RBACObject() } + +func (s BoundarySession) RBACObject() rbac.Object { + if s.OwnerID.Valid { + return rbac.ResourceBoundaryLog.WithOwner(s.OwnerID.UUID.String()) + } + return rbac.ResourceBoundaryLog +} diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 73f973e15c..972a104201 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -413,6 +413,8 @@ func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, arg.AfterID, arg.Search, arg.Name, + arg.ExactUsername, + arg.ExactEmail, pq.Array(arg.Status), pq.Array(arg.RbacRole), arg.LastSeenBefore, diff --git a/coderd/database/models.go b/coderd/database/models.go index 940904385a..ebfaa7a051 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -320,6 +320,10 @@ const ( ApiKeyScopeUserSkillUpdate APIKeyScope = "user_skill:update" ApiKeyScopeUserSkillDelete APIKeyScope = "user_skill:delete" ApiKeyScopeUserSkill APIKeyScope = "user_skill:*" + ApiKeyScopeBoundaryLog APIKeyScope = "boundary_log:*" + ApiKeyScopeBoundaryLogCreate APIKeyScope = "boundary_log:create" + ApiKeyScopeBoundaryLogDelete APIKeyScope = "boundary_log:delete" + ApiKeyScopeBoundaryLogRead APIKeyScope = "boundary_log:read" ) func (e *APIKeyScope) Scan(src interface{}) error { @@ -580,7 +584,11 @@ func (e APIKeyScope) Valid() bool { ApiKeyScopeUserSkillRead, ApiKeyScopeUserSkillUpdate, ApiKeyScopeUserSkillDelete, - ApiKeyScopeUserSkill: + ApiKeyScopeUserSkill, + ApiKeyScopeBoundaryLog, + ApiKeyScopeBoundaryLogCreate, + ApiKeyScopeBoundaryLogDelete, + ApiKeyScopeBoundaryLogRead: return true } return false @@ -810,6 +818,10 @@ func AllAPIKeyScopeValues() []APIKeyScope { ApiKeyScopeUserSkillUpdate, ApiKeyScopeUserSkillDelete, ApiKeyScopeUserSkill, + ApiKeyScopeBoundaryLog, + ApiKeyScopeBoundaryLogCreate, + ApiKeyScopeBoundaryLogDelete, + ApiKeyScopeBoundaryLogRead, } } @@ -4543,6 +4555,8 @@ type BoundarySession struct { StartedAt time.Time `db:"started_at" json:"started_at"` // Time when the session was last updated. UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + // The ID of the user who owns the workspace. NULL if the user has been deleted. + OwnerID uuid.NullUUID `db:"owner_id" json:"owner_id"` } // Per-replica boundary usage statistics for telemetry aggregation. @@ -5716,6 +5730,15 @@ type User struct { ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"` } +// Per-user AI spend override that supersedes group budget resolution. +type UserAiBudgetOverride struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + GroupID uuid.UUID `db:"group_id" json:"group_id"` + SpendLimitMicros int64 `db:"spend_limit_micros" json:"spend_limit_micros"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + // User-owned API keys associated with AI providers. These keys are used only when BYOK is enabled. type UserAiProviderKey struct { ID uuid.UUID `db:"id" json:"id"` diff --git a/coderd/database/pubsub/pubsub.go b/coderd/database/pubsub/pubsub.go index 86f7217b16..97d289e223 100644 --- a/coderd/database/pubsub/pubsub.go +++ b/coderd/database/pubsub/pubsub.go @@ -56,14 +56,14 @@ type msgOrErr struct { err error } -// msgQueue implements a fixed length queue with the ability to replace elements +// MsgQueue implements a fixed length queue with the ability to replace elements // after they are queued (but before they are dequeued). // // The purpose of this data structure is to build something that works a bit // like a golang channel, but if the queue is full, then we can replace the // last element with an error so that the subscriber can get notified that some // messages were dropped, all without blocking. -type msgQueue struct { +type MsgQueue struct { ctx context.Context cond *sync.Cond q [BufferSize]msgOrErr @@ -74,11 +74,11 @@ type msgQueue struct { le ListenerWithErr } -func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue { +func NewMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *MsgQueue { if l == nil && le == nil { panic("l or le must be non-nil") } - q := &msgQueue{ + q := &MsgQueue{ ctx: ctx, cond: sync.NewCond(&sync.Mutex{}), l: l, @@ -88,7 +88,7 @@ func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue return q } -func (q *msgQueue) run() { +func (q *MsgQueue) run() { for { // wait until there is something on the queue or we are closed q.cond.L.Lock() @@ -125,7 +125,7 @@ func (q *msgQueue) run() { } } -func (q *msgQueue) enqueue(msg []byte) { +func (q *MsgQueue) Enqueue(msg []byte) { q.cond.L.Lock() defer q.cond.L.Unlock() @@ -149,15 +149,15 @@ func (q *msgQueue) enqueue(msg []byte) { q.cond.Broadcast() } -func (q *msgQueue) close() { +func (q *MsgQueue) Close() { q.cond.L.Lock() defer q.cond.L.Unlock() defer q.cond.Broadcast() q.closed = true } -// dropped records an error in the queue that messages might have been dropped -func (q *msgQueue) dropped() { +// Dropped records an error in the queue that messages might have been Dropped +func (q *MsgQueue) Dropped() { q.cond.L.Lock() defer q.cond.L.Unlock() @@ -195,7 +195,7 @@ func (l pqListenerShim) NotifyChan() <-chan *pq.Notification { } type queueSet struct { - m map[*msgQueue]struct{} + m map[*MsgQueue]struct{} // unlistenInProgress will be non-nil if another goroutine is unlistening for the event this // queueSet corresponds to. If non-nil, that goroutine will close the channel when it is done. unlistenInProgress chan struct{} @@ -203,7 +203,7 @@ type queueSet struct { func newQueueSet() *queueSet { return &queueSet{ - m: make(map[*msgQueue]struct{}), + m: make(map[*MsgQueue]struct{}), } } @@ -243,19 +243,19 @@ const BufferSize = 2048 // Subscribe calls the listener when an event matching the name is received. func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { - return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil)) + return p.subscribeQueue(event, NewMsgQueue(context.Background(), listener, nil)) } func (p *PGPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) { - return p.subscribeQueue(event, newMsgQueue(context.Background(), nil, listener)) + return p.subscribeQueue(event, NewMsgQueue(context.Background(), nil, listener)) } -func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) { +func (p *PGPubsub) subscribeQueue(event string, newQ *MsgQueue) (cancel func(), err error) { defer func() { if err != nil { // if we hit an error, we need to close the queue so we don't // leak its goroutine. - newQ.close() + newQ.Close() p.subscribesTotal.WithLabelValues("false").Inc() } else { p.subscribesTotal.WithLabelValues("true").Inc() @@ -325,7 +325,7 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), func() { p.qMu.Lock() defer p.qMu.Unlock() - newQ.close() + newQ.Close() qSet, ok := p.queues[event] if !ok { p.logger.Critical(context.Background(), "event was removed before cancel", slog.F("event", event)) @@ -436,7 +436,7 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) { } extra := []byte(notif.Extra) for q := range qSet.m { - q.enqueue(extra) + q.Enqueue(extra) } } @@ -445,7 +445,7 @@ func (p *PGPubsub) recordReconnect() { defer p.qMu.Unlock() for _, qSet := range p.queues { for q := range qSet.m { - q.dropped() + q.Dropped() } } } diff --git a/coderd/database/pubsub/pubsub_internal_test.go b/coderd/database/pubsub/pubsub_internal_test.go index 0f699b4e4d..0c51d7a8e8 100644 --- a/coderd/database/pubsub/pubsub_internal_test.go +++ b/coderd/database/pubsub/pubsub_internal_test.go @@ -13,135 +13,6 @@ import ( "github.com/coder/coder/v2/testutil" ) -func Test_msgQueue_ListenerWithError(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - m := make(chan string) - e := make(chan error) - uut := newMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) { - m <- string(msg) - e <- err - }) - defer uut.close() - - // We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5. - // PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned - // when we wrap around the end of the circular buffer. This tests that we correctly handle - // the wrapping and aren't dequeueing misaligned data. - cycles := (BufferSize / 5) * 2 // almost twice around the ring - for j := 0; j < cycles; j++ { - for i := 0; i < 4; i++ { - uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i))) - } - uut.dropped() - for i := 0; i < 4; i++ { - select { - case <-ctx.Done(): - t.Fatal("timed out") - case msg := <-m: - require.Equal(t, fmt.Sprintf("%d%d", j, i), msg) - } - select { - case <-ctx.Done(): - t.Fatal("timed out") - case err := <-e: - require.NoError(t, err) - } - } - select { - case <-ctx.Done(): - t.Fatal("timed out") - case msg := <-m: - require.Equal(t, "", msg) - } - select { - case <-ctx.Done(): - t.Fatal("timed out") - case err := <-e: - require.ErrorIs(t, err, ErrDroppedMessages) - } - } -} - -func Test_msgQueue_Listener(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - m := make(chan string) - uut := newMsgQueue(ctx, func(ctx context.Context, msg []byte) { - m <- string(msg) - }, nil) - defer uut.close() - - // We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5. - // PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned - // when we wrap around the end of the circular buffer. This tests that we correctly handle - // the wrapping and aren't dequeueing misaligned data. - cycles := (BufferSize / 5) * 2 // almost twice around the ring - for j := 0; j < cycles; j++ { - for i := 0; i < 4; i++ { - uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i))) - } - uut.dropped() - for i := 0; i < 4; i++ { - select { - case <-ctx.Done(): - t.Fatal("timed out") - case msg := <-m: - require.Equal(t, fmt.Sprintf("%d%d", j, i), msg) - } - } - // Listener skips over errors, so we only read out the 4 real messages. - } -} - -func Test_msgQueue_Full(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - - firstDequeue := make(chan struct{}) - allowRead := make(chan struct{}) - n := 0 - errors := make(chan error) - uut := newMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) { - if n == 0 { - close(firstDequeue) - } - <-allowRead - if err == nil { - require.Equal(t, fmt.Sprintf("%d", n), string(msg)) - n++ - return - } - errors <- err - }) - defer uut.close() - - // we send 2 more than the capacity. One extra because the call to the ListenerFunc blocks - // but only after we've dequeued a message, and then another extra because we want to exceed - // the capacity, not just reach it. - for i := 0; i < BufferSize+2; i++ { - uut.enqueue([]byte(fmt.Sprintf("%d", i))) - // ensure the first dequeue has happened before proceeding, so that this function isn't racing - // against the goroutine that dequeues items. - <-firstDequeue - } - close(allowRead) - - select { - case <-ctx.Done(): - t.Fatal("timed out") - case err := <-errors: - require.ErrorIs(t, err, ErrDroppedMessages) - } - // Ok, so we sent 2 more than capacity, but we only read the capacity, that's because the last - // message we send doesn't get queued, AND, it bumps a message out of the queue to make room - // for the error, so we read 2 less than we sent. - require.Equal(t, BufferSize, n) -} - func TestPubSub_DoesntBlockNotify(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) diff --git a/coderd/database/pubsub/pubsub_test.go b/coderd/database/pubsub/pubsub_test.go index 066b9ce59a..3dbfa92f52 100644 --- a/coderd/database/pubsub/pubsub_test.go +++ b/coderd/database/pubsub/pubsub_test.go @@ -3,6 +3,7 @@ package pubsub_test import ( "context" "database/sql" + "fmt" "testing" "time" @@ -201,3 +202,132 @@ func TestPGPubsubDriver(t *testing.T) { } }, testutil.IntervalMedium, "subscriber did not receive message after reconnect") } + +func Test_MsgQueue_ListenerWithError(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + m := make(chan string) + e := make(chan error) + uut := pubsub.NewMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) { + m <- string(msg) + e <- err + }) + defer uut.Close() + + // We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5. + // PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned + // when we wrap around the end of the circular buffer. This tests that we correctly handle + // the wrapping and aren't dequeueing misaligned data. + cycles := (pubsub.BufferSize / 5) * 2 // almost twice around the ring + for j := 0; j < cycles; j++ { + for i := 0; i < 4; i++ { + uut.Enqueue([]byte(fmt.Sprintf("%d%d", j, i))) + } + uut.Dropped() + for i := 0; i < 4; i++ { + select { + case <-ctx.Done(): + t.Fatal("timed out") + case msg := <-m: + require.Equal(t, fmt.Sprintf("%d%d", j, i), msg) + } + select { + case <-ctx.Done(): + t.Fatal("timed out") + case err := <-e: + require.NoError(t, err) + } + } + select { + case <-ctx.Done(): + t.Fatal("timed out") + case msg := <-m: + require.Equal(t, "", msg) + } + select { + case <-ctx.Done(): + t.Fatal("timed out") + case err := <-e: + require.ErrorIs(t, err, pubsub.ErrDroppedMessages) + } + } +} + +func Test_MsgQueue_Listener(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + m := make(chan string) + uut := pubsub.NewMsgQueue(ctx, func(ctx context.Context, msg []byte) { + m <- string(msg) + }, nil) + defer uut.Close() + + // We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5. + // PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned + // when we wrap around the end of the circular buffer. This tests that we correctly handle + // the wrapping and aren't dequeueing misaligned data. + cycles := (pubsub.BufferSize / 5) * 2 // almost twice around the ring + for j := 0; j < cycles; j++ { + for i := 0; i < 4; i++ { + uut.Enqueue([]byte(fmt.Sprintf("%d%d", j, i))) + } + uut.Dropped() + for i := 0; i < 4; i++ { + select { + case <-ctx.Done(): + t.Fatal("timed out") + case msg := <-m: + require.Equal(t, fmt.Sprintf("%d%d", j, i), msg) + } + } + // Listener skips over errors, so we only read out the 4 real messages. + } +} + +func Test_MsgQueue_Full(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + firstDequeue := make(chan struct{}) + allowRead := make(chan struct{}) + n := 0 + errors := make(chan error) + uut := pubsub.NewMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) { + if n == 0 { + close(firstDequeue) + } + <-allowRead + if err == nil { + require.Equal(t, fmt.Sprintf("%d", n), string(msg)) + n++ + return + } + errors <- err + }) + defer uut.Close() + + // we send 2 more than the capacity. One extra because the call to the ListenerFunc blocks + // but only after we've dequeued a message, and then another extra because we want to exceed + // the capacity, not just reach it. + for i := 0; i < pubsub.BufferSize+2; i++ { + uut.Enqueue([]byte(fmt.Sprintf("%d", i))) + // ensure the first dequeue has happened before proceeding, so that this function isn't racing + // against the goroutine that dequeues items. + <-firstDequeue + } + close(allowRead) + + select { + case <-ctx.Done(): + t.Fatal("timed out") + case err := <-errors: + require.ErrorIs(t, err, pubsub.ErrDroppedMessages) + } + // Ok, so we sent 2 more than capacity, but we only read the capacity, that's because the last + // message we send doesn't get queued, AND, it bumps a message out of the queue to make room + // for the error, so we read 2 less than we sent. + require.Equal(t, pubsub.BufferSize, n) +} diff --git a/coderd/database/querier.go b/coderd/database/querier.go index ef6078b6c3..12592fe61e 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -198,6 +198,7 @@ type sqlcQuerier interface { DeleteTailnetPeer(ctx context.Context, arg DeleteTailnetPeerParams) (DeleteTailnetPeerRow, error) DeleteTailnetTunnel(ctx context.Context, arg DeleteTailnetTunnelParams) (DeleteTailnetTunnelRow, error) DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error) + DeleteUserAIBudgetOverride(ctx context.Context, userID uuid.UUID) (UserAiBudgetOverride, error) DeleteUserAIProviderKey(ctx context.Context, arg DeleteUserAIProviderKeyParams) error DeleteUserAIProviderKeysByProviderID(ctx context.Context, aiProviderID uuid.UUID) error DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error @@ -738,6 +739,7 @@ type sqlcQuerier interface { // inclusive. GetTotalUsageDCManagedAgentsV1(ctx context.Context, arg GetTotalUsageDCManagedAgentsV1Params) (int64, error) GetUnexpiredLicenses(ctx context.Context) ([]License, error) + GetUserAIBudgetOverride(ctx context.Context, userID uuid.UUID) (UserAiBudgetOverride, error) GetUserAIProviderKeyByProviderID(ctx context.Context, arg GetUserAIProviderKeyByProviderIDParams) (UserAiProviderKey, error) // GetUserAIProviderKeys is used by dbcrypt key rotation. Request paths should use // user-scoped lookups instead of this bulk accessor. @@ -921,7 +923,7 @@ type sqlcQuerier interface { // every member of the org. InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error) - InsertBoundaryLog(ctx context.Context, arg InsertBoundaryLogParams) (BoundaryLog, error) + InsertBoundaryLogs(ctx context.Context, arg InsertBoundaryLogsParams) ([]BoundaryLog, error) InsertBoundarySession(ctx context.Context, arg InsertBoundarySessionParams) (BoundarySession, error) InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error) // updated_at is the retention clock used by DeleteOldChatDebugRuns. @@ -1408,6 +1410,7 @@ type sqlcQuerier interface { // used to store the data, and the minutes are summed for each user and template // combination. The result is stored in the template_usage_stats table. UpsertTemplateUsageStats(ctx context.Context) error + UpsertUserAIBudgetOverride(ctx context.Context, arg UpsertUserAIBudgetOverrideParams) (UserAiBudgetOverride, error) // UpsertUserAIProviderKey preserves the original id and created_at when the // user/provider pair already exists. On conflict, callers provide id and // created_at for the insert path only. diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index f181e2e94b..cefe6a866e 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -9921,8 +9921,9 @@ func TestUpdateAIBridgeInterceptionEnded(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) got, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{ - ID: uuid.New(), - EndedAt: time.Now(), + ID: uuid.New(), + EndedAt: time.Now(), + CredentialHint: "sk-a...efgh", }) require.ErrorContains(t, err, "no rows in result set") require.EqualValues(t, database.AIBridgeInterception{}, got) @@ -9957,18 +9958,21 @@ func TestUpdateAIBridgeInterceptionEnded(t *testing.T) { endedAt := time.Now() // Mark first interception as done updated, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{ - ID: intc0.ID, - EndedAt: endedAt, + ID: intc0.ID, + EndedAt: endedAt, + CredentialHint: "sk-a...efgh", }) require.NoError(t, err) require.EqualValues(t, updated.ID, intc0.ID) require.True(t, updated.EndedAt.Valid) require.WithinDuration(t, endedAt, updated.EndedAt.Time, 5*time.Second) + require.Equal(t, "sk-a...efgh", updated.CredentialHint) // Updating first interception again should fail updated, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{ - ID: intc0.ID, - EndedAt: endedAt.Add(time.Hour), + ID: intc0.ID, + EndedAt: endedAt.Add(time.Hour), + CredentialHint: "sk-a...efgh", }) require.ErrorIs(t, err, sql.ErrNoRows) @@ -9979,6 +9983,52 @@ func TestUpdateAIBridgeInterceptionEnded(t *testing.T) { require.False(t, got.EndedAt.Valid) } }) + + t.Run("CentralizedHintUpdated", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + user := dbgen.User(t, db, database.User{}) + intc, err := db.InsertAIBridgeInterception(ctx, database.InsertAIBridgeInterceptionParams{ + ID: uuid.New(), + InitiatorID: user.ID, + Metadata: json.RawMessage("{}"), + CredentialKind: database.CredentialKindCentralized, + CredentialHint: "", + }) + require.NoError(t, err) + + updated, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{ + ID: intc.ID, + EndedAt: time.Now(), + CredentialHint: "sk-a...efgh", + }) + require.NoError(t, err) + require.Equal(t, "sk-a...efgh", updated.CredentialHint) + }) + + t.Run("BYOKHintPreserved", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + user := dbgen.User(t, db, database.User{}) + intc, err := db.InsertAIBridgeInterception(ctx, database.InsertAIBridgeInterceptionParams{ + ID: uuid.New(), + InitiatorID: user.ID, + Metadata: json.RawMessage("{}"), + CredentialKind: database.CredentialKindByok, + CredentialHint: "sk-u...byok", + }) + require.NoError(t, err) + + updated, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{ + ID: intc.ID, + EndedAt: time.Now(), + CredentialHint: "sk-a...efgh", + }) + require.NoError(t, err) + require.Equal(t, "sk-u...byok", updated.CredentialHint) + }) } func TestDeleteExpiredAPIKeys(t *testing.T) { diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 00c06ece56..65e066f25b 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2389,20 +2389,28 @@ func (q *sqlQuerier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Contex const updateAIBridgeInterceptionEnded = `-- name: UpdateAIBridgeInterceptionEnded :one UPDATE aibridge_interceptions - SET ended_at = $1::timestamptz + SET ended_at = $1::timestamptz, + -- BYOK records its hint at the start of the interception. + -- Centralized uses key failover, so its hint is only known + -- at end-of-interception. + credential_hint = CASE + WHEN credential_kind = 'centralized' THEN $2::text + ELSE credential_hint + END WHERE - id = $2::uuid + id = $3::uuid AND ended_at IS NULL RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id, provider_name, credential_kind, credential_hint ` type UpdateAIBridgeInterceptionEndedParams struct { - EndedAt time.Time `db:"ended_at" json:"ended_at"` - ID uuid.UUID `db:"id" json:"id"` + EndedAt time.Time `db:"ended_at" json:"ended_at"` + CredentialHint string `db:"credential_hint" json:"credential_hint"` + ID uuid.UUID `db:"id" json:"id"` } func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error) { - row := q.db.QueryRowContext(ctx, updateAIBridgeInterceptionEnded, arg.EndedAt, arg.ID) + row := q.db.QueryRowContext(ctx, updateAIBridgeInterceptionEnded, arg.EndedAt, arg.CredentialHint, arg.ID) var i AIBridgeInterception err := row.Scan( &i.ID, @@ -2441,6 +2449,23 @@ func (q *sqlQuerier) DeleteGroupAIBudget(ctx context.Context, groupID uuid.UUID) return i, err } +const deleteUserAIBudgetOverride = `-- name: DeleteUserAIBudgetOverride :one +DELETE FROM user_ai_budget_overrides WHERE user_id = $1 RETURNING user_id, group_id, spend_limit_micros, created_at, updated_at +` + +func (q *sqlQuerier) DeleteUserAIBudgetOverride(ctx context.Context, userID uuid.UUID) (UserAiBudgetOverride, error) { + row := q.db.QueryRowContext(ctx, deleteUserAIBudgetOverride, userID) + var i UserAiBudgetOverride + err := row.Scan( + &i.UserID, + &i.GroupID, + &i.SpendLimitMicros, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getAIModelPriceByProviderModel = `-- name: GetAIModelPriceByProviderModel :one SELECT provider, model, input_price, output_price, cache_read_price, cache_write_price, created_at, updated_at FROM ai_model_prices @@ -2486,6 +2511,25 @@ func (q *sqlQuerier) GetGroupAIBudget(ctx context.Context, groupID uuid.UUID) (G return i, err } +const getUserAIBudgetOverride = `-- name: GetUserAIBudgetOverride :one +SELECT user_id, group_id, spend_limit_micros, created_at, updated_at +FROM user_ai_budget_overrides +WHERE user_id = $1 +` + +func (q *sqlQuerier) GetUserAIBudgetOverride(ctx context.Context, userID uuid.UUID) (UserAiBudgetOverride, error) { + row := q.db.QueryRowContext(ctx, getUserAIBudgetOverride, userID) + var i UserAiBudgetOverride + err := row.Scan( + &i.UserID, + &i.GroupID, + &i.SpendLimitMicros, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const upsertAIModelPrices = `-- name: UpsertAIModelPrices :exec INSERT INTO ai_model_prices ( provider, model, input_price, output_price, cache_read_price, cache_write_price @@ -2540,6 +2584,35 @@ func (q *sqlQuerier) UpsertGroupAIBudget(ctx context.Context, arg UpsertGroupAIB return i, err } +const upsertUserAIBudgetOverride = `-- name: UpsertUserAIBudgetOverride :one +INSERT INTO user_ai_budget_overrides (user_id, group_id, spend_limit_micros) +VALUES ($1, $2, $3) +ON CONFLICT (user_id) DO UPDATE SET + group_id = EXCLUDED.group_id, + spend_limit_micros = EXCLUDED.spend_limit_micros, + updated_at = NOW() +RETURNING user_id, group_id, spend_limit_micros, created_at, updated_at +` + +type UpsertUserAIBudgetOverrideParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + GroupID uuid.UUID `db:"group_id" json:"group_id"` + SpendLimitMicros int64 `db:"spend_limit_micros" json:"spend_limit_micros"` +} + +func (q *sqlQuerier) UpsertUserAIBudgetOverride(ctx context.Context, arg UpsertUserAIBudgetOverrideParams) (UserAiBudgetOverride, error) { + row := q.db.QueryRowContext(ctx, upsertUserAIBudgetOverride, arg.UserID, arg.GroupID, arg.SpendLimitMicros) + var i UserAiBudgetOverride + err := row.Scan( + &i.UserID, + &i.GroupID, + &i.SpendLimitMicros, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getActiveAISeatCount = `-- name: GetActiveAISeatCount :one SELECT COUNT(*) @@ -3627,7 +3700,7 @@ func (q *sqlQuerier) GetBoundaryLogByID(ctx context.Context, id uuid.UUID) (Boun } const getBoundarySessionByID = `-- name: GetBoundarySessionByID :one -SELECT id, workspace_agent_id, confined_process_name, started_at, updated_at FROM boundary_sessions WHERE id = $1 +SELECT id, workspace_agent_id, confined_process_name, started_at, updated_at, owner_id FROM boundary_sessions WHERE id = $1 ` func (q *sqlQuerier) GetBoundarySessionByID(ctx context.Context, id uuid.UUID) (BoundarySession, error) { @@ -3639,11 +3712,12 @@ func (q *sqlQuerier) GetBoundarySessionByID(ctx context.Context, id uuid.UUID) ( &i.ConfinedProcessName, &i.StartedAt, &i.UpdatedAt, + &i.OwnerID, ) return i, err } -const insertBoundaryLog = `-- name: InsertBoundaryLog :one +const insertBoundaryLogs = `-- name: InsertBoundaryLogs :many INSERT INTO boundary_logs ( id, session_id, @@ -3654,62 +3728,80 @@ INSERT INTO boundary_logs ( method, detail, matched_rule -) VALUES ( - $1, - $2, - $3, - $4, - $5, - $6, - $7, - $8, - $9 -) RETURNING id, session_id, sequence_number, captured_at, created_at, proto, method, detail, matched_rule +) +SELECT + unnest($1 :: uuid[]), + $2 :: uuid, + unnest($3 :: int[]), + unnest($4 :: timestamptz[]), + unnest($5 :: timestamptz[]), + unnest($6 :: text[]), + unnest($7 :: text[]), + unnest($8 :: text[]), + unnest($9 :: text[]) +RETURNING id, session_id, sequence_number, captured_at, created_at, proto, method, detail, matched_rule ` -type InsertBoundaryLogParams struct { - ID uuid.UUID `db:"id" json:"id"` - SessionID uuid.UUID `db:"session_id" json:"session_id"` - SequenceNumber int32 `db:"sequence_number" json:"sequence_number"` - CapturedAt time.Time `db:"captured_at" json:"captured_at"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - Proto string `db:"proto" json:"proto"` - Method string `db:"method" json:"method"` - Detail string `db:"detail" json:"detail"` - MatchedRule sql.NullString `db:"matched_rule" json:"matched_rule"` +type InsertBoundaryLogsParams struct { + ID []uuid.UUID `db:"id" json:"id"` + SessionID uuid.UUID `db:"session_id" json:"session_id"` + SequenceNumber []int32 `db:"sequence_number" json:"sequence_number"` + CapturedAt []time.Time `db:"captured_at" json:"captured_at"` + CreatedAt []time.Time `db:"created_at" json:"created_at"` + Proto []string `db:"proto" json:"proto"` + Method []string `db:"method" json:"method"` + Detail []string `db:"detail" json:"detail"` + MatchedRule []string `db:"matched_rule" json:"matched_rule"` } -func (q *sqlQuerier) InsertBoundaryLog(ctx context.Context, arg InsertBoundaryLogParams) (BoundaryLog, error) { - row := q.db.QueryRowContext(ctx, insertBoundaryLog, - arg.ID, +func (q *sqlQuerier) InsertBoundaryLogs(ctx context.Context, arg InsertBoundaryLogsParams) ([]BoundaryLog, error) { + rows, err := q.db.QueryContext(ctx, insertBoundaryLogs, + pq.Array(arg.ID), arg.SessionID, - arg.SequenceNumber, - arg.CapturedAt, - arg.CreatedAt, - arg.Proto, - arg.Method, - arg.Detail, - arg.MatchedRule, + pq.Array(arg.SequenceNumber), + pq.Array(arg.CapturedAt), + pq.Array(arg.CreatedAt), + pq.Array(arg.Proto), + pq.Array(arg.Method), + pq.Array(arg.Detail), + pq.Array(arg.MatchedRule), ) - var i BoundaryLog - err := row.Scan( - &i.ID, - &i.SessionID, - &i.SequenceNumber, - &i.CapturedAt, - &i.CreatedAt, - &i.Proto, - &i.Method, - &i.Detail, - &i.MatchedRule, - ) - return i, err + if err != nil { + return nil, err + } + defer rows.Close() + var items []BoundaryLog + for rows.Next() { + var i BoundaryLog + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.SequenceNumber, + &i.CapturedAt, + &i.CreatedAt, + &i.Proto, + &i.Method, + &i.Detail, + &i.MatchedRule, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil } const insertBoundarySession = `-- name: InsertBoundarySession :one INSERT INTO boundary_sessions ( id, workspace_agent_id, + owner_id, confined_process_name, started_at, updated_at @@ -3718,22 +3810,25 @@ INSERT INTO boundary_sessions ( $2, $3, $4, - $5 -) RETURNING id, workspace_agent_id, confined_process_name, started_at, updated_at + $5, + $6 +) RETURNING id, workspace_agent_id, confined_process_name, started_at, updated_at, owner_id ` type InsertBoundarySessionParams struct { - ID uuid.UUID `db:"id" json:"id"` - WorkspaceAgentID uuid.UUID `db:"workspace_agent_id" json:"workspace_agent_id"` - ConfinedProcessName string `db:"confined_process_name" json:"confined_process_name"` - StartedAt time.Time `db:"started_at" json:"started_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + ID uuid.UUID `db:"id" json:"id"` + WorkspaceAgentID uuid.UUID `db:"workspace_agent_id" json:"workspace_agent_id"` + OwnerID uuid.NullUUID `db:"owner_id" json:"owner_id"` + ConfinedProcessName string `db:"confined_process_name" json:"confined_process_name"` + StartedAt time.Time `db:"started_at" json:"started_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` } func (q *sqlQuerier) InsertBoundarySession(ctx context.Context, arg InsertBoundarySessionParams) (BoundarySession, error) { row := q.db.QueryRowContext(ctx, insertBoundarySession, arg.ID, arg.WorkspaceAgentID, + arg.OwnerID, arg.ConfinedProcessName, arg.StartedAt, arg.UpdatedAt, @@ -3745,6 +3840,7 @@ func (q *sqlQuerier) InsertBoundarySession(ctx context.Context, arg InsertBounda &i.ConfinedProcessName, &i.StartedAt, &i.UpdatedAt, + &i.OwnerID, ) return i, err } @@ -28052,65 +28148,77 @@ WHERE name ILIKE concat('%', $3, '%') ELSE true END + -- Filter by exact username + AND CASE + WHEN $4 :: text != '' THEN + lower(username) = lower($4) + ELSE true + END + -- Filter by exact email + AND CASE + WHEN $5 :: text != '' THEN + lower(email) = lower($5) + ELSE true + END -- Filter by status AND CASE -- @status needs to be a text because it can be empty, If it was -- user_status enum, it would not. - WHEN cardinality($4 :: user_status[]) > 0 THEN - status = ANY($4 :: user_status[]) + WHEN cardinality($6 :: user_status[]) > 0 THEN + status = ANY($6 :: user_status[]) ELSE true END -- Filter by rbac_roles AND CASE -- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as -- everyone is a member. - WHEN cardinality($5 :: text[]) > 0 AND 'member' != ANY($5 :: text[]) THEN - rbac_roles && $5 :: text[] + WHEN cardinality($7 :: text[]) > 0 AND 'member' != ANY($7 :: text[]) THEN + rbac_roles && $7 :: text[] ELSE true END -- Filter by last_seen - AND CASE - WHEN $6 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN - last_seen_at <= $6 - ELSE true - END - AND CASE - WHEN $7 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN - last_seen_at >= $7 - ELSE true - END - -- Filter by created_at AND CASE WHEN $8 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN - created_at <= $8 + last_seen_at <= $8 ELSE true END AND CASE WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN - created_at >= $9 + last_seen_at >= $9 + ELSE true + END + -- Filter by created_at + AND CASE + WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + created_at <= $10 + ELSE true + END + AND CASE + WHEN $11 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + created_at >= $11 ELSE true END -- Filter by system type AND CASE - WHEN $10::bool THEN TRUE + WHEN $12::bool THEN TRUE ELSE is_system = false END -- Filter by github.com user ID AND CASE - WHEN $11 :: bigint != 0 THEN - github_com_user_id = $11 + WHEN $13 :: bigint != 0 THEN + github_com_user_id = $13 ELSE true END -- Filter by login_type AND CASE - WHEN cardinality($12 :: login_type[]) > 0 THEN - login_type = ANY($12 :: login_type[]) + WHEN cardinality($14 :: login_type[]) > 0 THEN + login_type = ANY($14 :: login_type[]) ELSE true END -- Filter by service account. AND CASE - WHEN $13 :: boolean IS NOT NULL THEN - is_service_account = $13 :: boolean + WHEN $15 :: boolean IS NOT NULL THEN + is_service_account = $15 :: boolean ELSE true END -- End of filters @@ -28119,16 +28227,18 @@ WHERE -- @authorize_filter ORDER BY -- Deterministic and consistent ordering of all users. This is to ensure consistent pagination. - LOWER(username) ASC OFFSET $14 + LOWER(username) ASC OFFSET $16 LIMIT -- A null limit means "no limit", so 0 means return all - NULLIF($15 :: int, 0) + NULLIF($17 :: int, 0) ` type GetUsersParams struct { AfterID uuid.UUID `db:"after_id" json:"after_id"` Search string `db:"search" json:"search"` Name string `db:"name" json:"name"` + ExactUsername string `db:"exact_username" json:"exact_username"` + ExactEmail string `db:"exact_email" json:"exact_email"` Status []UserStatus `db:"status" json:"status"` RbacRole []string `db:"rbac_role" json:"rbac_role"` LastSeenBefore time.Time `db:"last_seen_before" json:"last_seen_before"` @@ -28173,6 +28283,8 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]GetUse arg.AfterID, arg.Search, arg.Name, + arg.ExactUsername, + arg.ExactEmail, pq.Array(arg.Status), pq.Array(arg.RbacRole), arg.LastSeenBefore, diff --git a/coderd/database/queries/aibridge.sql b/coderd/database/queries/aibridge.sql index 7756c7086b..a1b49d25cd 100644 --- a/coderd/database/queries/aibridge.sql +++ b/coderd/database/queries/aibridge.sql @@ -8,7 +8,14 @@ RETURNING *; -- name: UpdateAIBridgeInterceptionEnded :one UPDATE aibridge_interceptions - SET ended_at = @ended_at::timestamptz + SET ended_at = @ended_at::timestamptz, + -- BYOK records its hint at the start of the interception. + -- Centralized uses key failover, so its hint is only known + -- at end-of-interception. + credential_hint = CASE + WHEN credential_kind = 'centralized' THEN @credential_hint::text + ELSE credential_hint + END WHERE id = @id::uuid AND ended_at IS NULL diff --git a/coderd/database/queries/aicostcontrol.sql b/coderd/database/queries/aicostcontrol.sql index 6740b2568c..188ec7357e 100644 --- a/coderd/database/queries/aicostcontrol.sql +++ b/coderd/database/queries/aicostcontrol.sql @@ -40,3 +40,20 @@ RETURNING *; -- name: DeleteGroupAIBudget :one DELETE FROM group_ai_budgets WHERE group_id = @group_id RETURNING *; + +-- name: GetUserAIBudgetOverride :one +SELECT * +FROM user_ai_budget_overrides +WHERE user_id = @user_id; + +-- name: UpsertUserAIBudgetOverride :one +INSERT INTO user_ai_budget_overrides (user_id, group_id, spend_limit_micros) +VALUES (@user_id, @group_id, @spend_limit_micros) +ON CONFLICT (user_id) DO UPDATE SET + group_id = EXCLUDED.group_id, + spend_limit_micros = EXCLUDED.spend_limit_micros, + updated_at = NOW() +RETURNING *; + +-- name: DeleteUserAIBudgetOverride :one +DELETE FROM user_ai_budget_overrides WHERE user_id = @user_id RETURNING *; diff --git a/coderd/database/queries/boundarylogs.sql b/coderd/database/queries/boundarylogs.sql index d8c35fd7eb..3abeb618a5 100644 --- a/coderd/database/queries/boundarylogs.sql +++ b/coderd/database/queries/boundarylogs.sql @@ -2,12 +2,14 @@ INSERT INTO boundary_sessions ( id, workspace_agent_id, + owner_id, confined_process_name, started_at, updated_at ) VALUES ( @id, @workspace_agent_id, + @owner_id, @confined_process_name, @started_at, @updated_at @@ -16,7 +18,7 @@ INSERT INTO boundary_sessions ( -- name: GetBoundarySessionByID :one SELECT * FROM boundary_sessions WHERE id = @id; --- name: InsertBoundaryLog :one +-- name: InsertBoundaryLogs :many INSERT INTO boundary_logs ( id, session_id, @@ -27,17 +29,18 @@ INSERT INTO boundary_logs ( method, detail, matched_rule -) VALUES ( - @id, - @session_id, - @sequence_number, - @captured_at, - @created_at, - @proto, - @method, - @detail, - @matched_rule -) RETURNING *; +) +SELECT + unnest(@id :: uuid[]), + @session_id :: uuid, + unnest(@sequence_number :: int[]), + unnest(@captured_at :: timestamptz[]), + unnest(@created_at :: timestamptz[]), + unnest(@proto :: text[]), + unnest(@method :: text[]), + unnest(@detail :: text[]), + unnest(@matched_rule :: text[]) +RETURNING *; -- name: GetBoundaryLogByID :one SELECT * FROM boundary_logs WHERE id = @id; diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 03f403e145..7bbd2dd0c9 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -486,6 +486,18 @@ WHERE name ILIKE concat('%', @name, '%') ELSE true END + -- Filter by exact username + AND CASE + WHEN @exact_username :: text != '' THEN + lower(username) = lower(@exact_username) + ELSE true + END + -- Filter by exact email + AND CASE + WHEN @exact_email :: text != '' THEN + lower(email) = lower(@exact_email) + ELSE true + END -- Filter by status AND CASE -- @status needs to be a text because it can be empty, If it was diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index 8ef517a9cb..3d5e5dabcf 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -97,6 +97,7 @@ const ( UniqueTemplatesPkey UniqueConstraint = "templates_pkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_pkey PRIMARY KEY (id); UniqueUsageEventsDailyPkey UniqueConstraint = "usage_events_daily_pkey" // ALTER TABLE ONLY usage_events_daily ADD CONSTRAINT usage_events_daily_pkey PRIMARY KEY (day, event_type); UniqueUsageEventsPkey UniqueConstraint = "usage_events_pkey" // ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id); + UniqueUserAiBudgetOverridesPkey UniqueConstraint = "user_ai_budget_overrides_pkey" // ALTER TABLE ONLY user_ai_budget_overrides ADD CONSTRAINT user_ai_budget_overrides_pkey PRIMARY KEY (user_id); UniqueUserAiProviderKeysPkey UniqueConstraint = "user_ai_provider_keys_pkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_pkey PRIMARY KEY (id); UniqueUserAiProviderKeysUserIDAiProviderIDKey UniqueConstraint = "user_ai_provider_keys_user_id_ai_provider_id_key" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_ai_provider_id_key UNIQUE (user_id, ai_provider_id); UniqueUserConfigsPkey UniqueConstraint = "user_configs_pkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key); diff --git a/coderd/mcp_test.go b/coderd/mcp_test.go index add730960f..dde85f12e7 100644 --- a/coderd/mcp_test.go +++ b/coderd/mcp_test.go @@ -1396,10 +1396,11 @@ func TestChatWithMCPServerIDs(t *testing.T) { // Create the chat model config required for creating a chat. _ = createChatModelConfigForMCP(t, expClient) - // Create an enabled MCP server config. - mcpConfig := createMCPServerConfig(t, client, "chat-mcp-server", true) + // Create enabled MCP server configs. + mcpConfigA := createMCPServerConfig(t, client, "chat-mcp-server-a", true) + mcpConfigB := createMCPServerConfig(t, client, "chat-mcp-server-b", true) - // Create a chat referencing the MCP server. + // Create a chat referencing the MCP servers. chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ @@ -1408,16 +1409,24 @@ func TestChatWithMCPServerIDs(t *testing.T) { Text: "hello with mcp server", }, }, - MCPServerIDs: []uuid.UUID{mcpConfig.ID}, + MCPServerIDs: []uuid.UUID{mcpConfigA.ID, mcpConfigB.ID}, }) require.NoError(t, err) require.NotEqual(t, uuid.Nil, chat.ID) - require.Contains(t, chat.MCPServerIDs, mcpConfig.ID) + require.ElementsMatch(t, []uuid.UUID{mcpConfigA.ID, mcpConfigB.ID}, chat.MCPServerIDs) // Fetch the chat and verify the MCP server IDs persist. fetched, err := expClient.GetChat(ctx, chat.ID) require.NoError(t, err) - require.Contains(t, fetched.MCPServerIDs, mcpConfig.ID) + require.ElementsMatch(t, []uuid.UUID{mcpConfigA.ID, mcpConfigB.ID}, fetched.MCPServerIDs) + + err = client.DeleteMCPServerConfig(ctx, mcpConfigA.ID) + require.NoError(t, err) + + fetched, err = expClient.GetChat(ctx, chat.ID) + require.NoError(t, err) + require.NotContains(t, fetched.MCPServerIDs, mcpConfigA.ID) + require.Contains(t, fetched.MCPServerIDs, mcpConfigB.ID) } func createChatModelConfigForMCP(t testing.TB, client *codersdk.ExperimentalClient) codersdk.ChatModelConfig { diff --git a/coderd/notifications/manager.go b/coderd/notifications/manager.go index f65fc3ff7f..4d44563fce 100644 --- a/coderd/notifications/manager.go +++ b/coderd/notifications/manager.go @@ -237,9 +237,7 @@ func (m *Manager) BufferedUpdatesCount() (success int, failure int) { // syncUpdates updates messages in the store based on the given successful and failed message dispatch results. func (m *Manager) syncUpdates(ctx context.Context) { // Ensure we update the metrics to reflect the current state after each invocation. - defer func() { - m.metrics.PendingUpdates.Set(float64(len(m.success) + len(m.failure))) - }() + defer m.metrics.pendingUpdatesGauge.set(func() int { return len(m.success) + len(m.failure) }) select { case <-ctx.Done(): @@ -250,7 +248,7 @@ func (m *Manager) syncUpdates(ctx context.Context) { nSuccess := len(m.success) nFailure := len(m.failure) - m.metrics.PendingUpdates.Set(float64(nSuccess + nFailure)) + m.metrics.pendingUpdatesGauge.set(func() int { return len(m.success) + len(m.failure) }) // Nothing to do. if nSuccess+nFailure == 0 { diff --git a/coderd/notifications/metrics.go b/coderd/notifications/metrics.go index 204bc260c7..69a262bb47 100644 --- a/coderd/notifications/metrics.go +++ b/coderd/notifications/metrics.go @@ -3,6 +3,7 @@ package notifications import ( "fmt" "strings" + "sync" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -17,8 +18,28 @@ type Metrics struct { InflightDispatches *prometheus.GaugeVec DispatcherSendSeconds *prometheus.HistogramVec - PendingUpdates prometheus.Gauge + PendingUpdates prometheus.Collector SyncedUpdates prometheus.Counter + + pendingUpdatesGauge *pendingUpdatesGauge +} + +// pendingUpdatesGauge serializes count evaluation with the gauge write, +// preventing stale snapshots when concurrent goroutines race to update +// the metric. +type pendingUpdatesGauge struct { + gauge prometheus.Gauge + mu sync.Mutex +} + +// set evaluates count under the lock and writes the result to the gauge. +// count is a function, not a value, so the channel length is read atomically +// with the write; passing a pre-evaluated int would reintroduce the race. +func (g *pendingUpdatesGauge) set(count func() int) { + g.mu.Lock() + defer g.mu.Unlock() + + g.gauge.Set(float64(count())) } const ( @@ -35,6 +56,11 @@ const ( ) func NewMetrics(reg prometheus.Registerer) *Metrics { + pendingUpdates := promauto.With(reg).NewGauge(prometheus.GaugeOpts{ + Name: "pending_updates", Namespace: ns, Subsystem: subsystem, + Help: "The number of dispatch attempt results waiting to be flushed to the store.", + }) + return &Metrics{ DispatchAttempts: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Name: "dispatch_attempts_total", Namespace: ns, Subsystem: subsystem, @@ -68,10 +94,10 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { }, []string{LabelMethod}), // Currently no requirement to discriminate between success and failure updates which are pending. - PendingUpdates: promauto.With(reg).NewGauge(prometheus.GaugeOpts{ - Name: "pending_updates", Namespace: ns, Subsystem: subsystem, - Help: "The number of dispatch attempt results waiting to be flushed to the store.", - }), + PendingUpdates: pendingUpdates, + pendingUpdatesGauge: &pendingUpdatesGauge{ + gauge: pendingUpdates, + }, SyncedUpdates: promauto.With(reg).NewCounter(prometheus.CounterOpts{ Name: "synced_updates_total", Namespace: ns, Subsystem: subsystem, Help: "The number of dispatch attempt results flushed to the store.", diff --git a/coderd/notifications/metrics_internal_test.go b/coderd/notifications/metrics_internal_test.go new file mode 100644 index 0000000000..04360dc221 --- /dev/null +++ b/coderd/notifications/metrics_internal_test.go @@ -0,0 +1,85 @@ +package notifications + +import ( + "sync" + "testing" + + "github.com/prometheus/client_golang/prometheus" + promtest "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/testutil" +) + +func TestMetricsSetPendingUpdatesSerializesGaugeWrites(t *testing.T) { + t.Parallel() + + realGauge := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "test_pending_updates", + Help: "test pending updates gauge", + }) + blockingGauge := &pendingUpdatesBlockingGauge{ + Gauge: realGauge, + blockValue: 3, + entered: make(chan struct{}), + release: make(chan struct{}), + } + metrics := &Metrics{ + PendingUpdates: blockingGauge, + pendingUpdatesGauge: &pendingUpdatesGauge{gauge: blockingGauge}, + } + + success := make(chan dispatchResult, 4) + failure := make(chan dispatchResult, 4) + success <- dispatchResult{} + success <- dispatchResult{} + + firstDone := make(chan struct{}) + go func() { + defer close(firstDone) + failure <- dispatchResult{} + // The first writer observes total=3 and blocks inside Set(3) + // while still holding the pendingUpdatesGauge mutex. + metrics.pendingUpdatesGauge.set(func() int { return len(success) + len(failure) }) + }() + + testutil.TryReceive(testutil.Context(t, testutil.WaitShort), t, blockingGauge.entered) + + // The main goroutine raises the real total to 4 before a second + // writer queues behind the locked gauge. + success <- dispatchResult{} + + secondDone := make(chan struct{}) + go func() { + defer close(secondDone) + // This count must be evaluated after release, while holding the + // mutex, so the final gauge value cannot regress to 3. + metrics.pendingUpdatesGauge.set(func() int { return len(success) + len(failure) }) + }() + + close(blockingGauge.release) + testutil.TryReceive(testutil.Context(t, testutil.WaitShort), t, firstDone) + testutil.TryReceive(testutil.Context(t, testutil.WaitShort), t, secondDone) + + require.Equal(t, 4, len(success)+len(failure)) + require.EqualValues(t, 4, promtest.ToFloat64(metrics.PendingUpdates)) +} + +type pendingUpdatesBlockingGauge struct { + prometheus.Gauge + + blockValue float64 + entered chan struct{} + release chan struct{} + once sync.Once +} + +func (g *pendingUpdatesBlockingGauge) Set(value float64) { + if value == g.blockValue { + g.once.Do(func() { + close(g.entered) + <-g.release + }) + } + g.Gauge.Set(value) +} diff --git a/coderd/notifications/metrics_test.go b/coderd/notifications/metrics_test.go index 5562ded86e..3a2d7fbc34 100644 --- a/coderd/notifications/metrics_test.go +++ b/coderd/notifications/metrics_test.go @@ -276,17 +276,24 @@ func TestPendingUpdatesMetric(t *testing.T) { mClock.Advance(cfg.FetchInterval.Value()).MustWait(ctx) // THEN: - // handler has dispatched the given notifications. - func() { + // Both handlers have dispatched the given notifications, and their + // results are pending in the metrics. + require.EventuallyWithT(t, func(ct *assert.CollectT) { handler.mu.RLock() + inboxHandler.mu.RLock() defer handler.mu.RUnlock() + defer inboxHandler.mu.RUnlock() - require.Len(t, handler.succeeded, 1) - require.Len(t, handler.failed, 1) - }() + assert.Len(ct, handler.succeeded, 1) + assert.Len(ct, handler.failed, 1) + assert.Len(ct, inboxHandler.succeeded, 1) + assert.Len(ct, inboxHandler.failed, 1) - // Both handler calls should be pending in the metrics. - require.EqualValues(t, 4, promtest.ToFloat64(metrics.PendingUpdates)) + success, failure := mgr.BufferedUpdatesCount() + assert.Equal(ct, 2, success) + assert.Equal(ct, 2, failure) + assert.EqualValues(ct, 4, promtest.ToFloat64(metrics.PendingUpdates)) + }, testutil.WaitShort, testutil.IntervalFast) // THEN: // Trigger syncing updates diff --git a/coderd/notifications/notifier.go b/coderd/notifications/notifier.go index 391c7c9bdb..9c7284c019 100644 --- a/coderd/notifications/notifier.go +++ b/coderd/notifications/notifier.go @@ -172,6 +172,7 @@ func (n *notifier) process(ctx context.Context, success chan<- dispatchResult, f // If a notification template has been disabled by the user after a notification was enqueued, mark it as inhibited if msg.Disabled { failure <- n.newInhibitedDispatch(msg) + n.metrics.pendingUpdatesGauge.set(func() int { return len(success) + len(failure) }) continue } @@ -184,7 +185,7 @@ func (n *notifier) process(ctx context.Context, success chan<- dispatchResult, f n.log.Error(ctx, "dispatcher construction failed", slog.F("msg_id", msg.ID), slog.Error(err)) } failure <- n.newFailedDispatch(msg, err, xerrors.Is(err, decorateHelpersError{})) - n.metrics.PendingUpdates.Set(float64(len(success) + len(failure))) + n.metrics.pendingUpdatesGauge.set(func() int { return len(success) + len(failure) }) continue } @@ -316,7 +317,7 @@ func (n *notifier) deliver(ctx context.Context, msg database.AcquireNotification logger.Debug(ctx, "message dispatch succeeded") } } - n.metrics.PendingUpdates.Set(float64(len(success) + len(failure))) + n.metrics.pendingUpdatesGauge.set(func() int { return len(success) + len(failure) }) return nil } diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index 5a52ecc0a1..1c66333aef 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -1588,7 +1588,10 @@ func (s *server) DownloadFile(request *proto.FileRequest, stream proto.DRPCProvi return fail(xerrors.Errorf("unsupported file upload type: %s", request.UploadType)) } - upload, chunks := sdkproto.BytesToDataUpload(sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, file.Data) + upload, chunks, err := sdkproto.BytesToDataUpload(sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, file.Data) + if err != nil { + return fail(xerrors.Errorf("prepare file upload: %w", err)) + } err = stream.Send(&sdkproto.FileUpload{ Type: &sdkproto.FileUpload_DataUpload{DataUpload: upload}, diff --git a/coderd/provisionerdserver/upload_file_test.go b/coderd/provisionerdserver/upload_file_test.go index d041bb9f98..f235095742 100644 --- a/coderd/provisionerdserver/upload_file_test.go +++ b/coderd/provisionerdserver/upload_file_test.go @@ -48,7 +48,8 @@ func TestUploadFileLargeModuleFiles(t *testing.T) { require.NoError(t, err) // Convert to upload format - upload, chunks := sdkproto.BytesToDataUpload(sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, moduleData) + upload, chunks, err := sdkproto.BytesToDataUpload(sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, moduleData) + require.NoError(t, err) stream := newMockUploadStream(upload, chunks...) @@ -93,7 +94,8 @@ func TestUploadFileErrorScenarios(t *testing.T) { _, err := crand.Read(moduleData) require.NoError(t, err) - upload, chunks := sdkproto.BytesToDataUpload(sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, moduleData) + upload, chunks, err := sdkproto.BytesToDataUpload(sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, moduleData) + require.NoError(t, err) t.Run("chunk_before_upload", func(t *testing.T) { t.Parallel() diff --git a/coderd/rbac/authz.go b/coderd/rbac/authz.go index 5d40e59cc6..4b253bc10d 100644 --- a/coderd/rbac/authz.go +++ b/coderd/rbac/authz.go @@ -85,6 +85,7 @@ const ( SubjectTypeWorkspaceBuilder SubjectType = "workspace_builder" SubjectTypeChatd SubjectType = "chatd" SubjectTypeAIProviderMetadataReader SubjectType = "ai_provider_metadata_reader" + SubjectTypeSCIMProvisioner SubjectType = "scim_provisioner" ) const ( diff --git a/coderd/rbac/object_gen.go b/coderd/rbac/object_gen.go index 340221f611..824cf92fdd 100644 --- a/coderd/rbac/object_gen.go +++ b/coderd/rbac/object_gen.go @@ -89,6 +89,15 @@ var ( Type: "audit_log", } + // ResourceBoundaryLog + // Valid Actions + // - "ActionCreate" :: create boundary log records + // - "ActionDelete" :: delete boundary logs + // - "ActionRead" :: read boundary logs and session metadata + ResourceBoundaryLog = Object{ + Type: "boundary_log", + } + // ResourceBoundaryUsage // Valid Actions // - "ActionDelete" :: delete boundary usage statistics @@ -478,6 +487,7 @@ func AllResources() []Objecter { ResourceAssignOrgRole, ResourceAssignRole, ResourceAuditLog, + ResourceBoundaryLog, ResourceBoundaryUsage, ResourceChat, ResourceConnectionLog, diff --git a/coderd/rbac/policy/policy.go b/coderd/rbac/policy/policy.go index 7d7a42110d..f2b17927bd 100644 --- a/coderd/rbac/policy/policy.go +++ b/coderd/rbac/policy/policy.go @@ -422,6 +422,13 @@ var RBACPermissions = map[string]PermissionDefinition{ ActionRead: "read AI seat state", }, }, + "boundary_log": { + Actions: map[Action]ActionDefinition{ + ActionCreate: "create boundary log records", + ActionRead: "read boundary logs and session metadata", + ActionDelete: "delete boundary logs", + }, + }, "boundary_usage": { Actions: map[Action]ActionDefinition{ ActionRead: "read boundary usage statistics", diff --git a/coderd/rbac/roles.go b/coderd/rbac/roles.go index cbaf49f9c0..1b19947ea6 100644 --- a/coderd/rbac/roles.go +++ b/coderd/rbac/roles.go @@ -303,7 +303,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) { // Workspace is specifically handled based on the opts.NoOwnerWorkspaceExec. // Owners can inspect and delete personal skills for operability and // abuse handling, but cannot create or edit user-authored instructions. - allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUserSecret, ResourceUserSkill, ResourceUsageEvent, ResourceBoundaryUsage, ResourceAiSeat), + allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUserSecret, ResourceUserSkill, ResourceUsageEvent, ResourceBoundaryUsage, ResourceBoundaryLog, ResourceAiSeat), // This adds back in the Workspace permissions. Permissions(map[string][]policy.Action{ ResourceWorkspace.Type: ownerWorkspaceActions, @@ -313,6 +313,9 @@ func ReloadBuiltinRoles(opts *RoleOptions) { // Explicitly setting PrebuiltWorkspace permissions for clarity. // Note: even without PrebuiltWorkspace permissions, access is still granted via Workspace permissions. ResourcePrebuiltWorkspace.Type: {policy.ActionUpdate, policy.ActionDelete}, + // Owners can read all boundary logs. Delete is reserved for + // DBPurge only. Create is user-scoped (inherited from member). + ResourceBoundaryLog.Type: {policy.ActionRead}, })..., ), User: []Permission{}, @@ -332,7 +335,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) { denyPermissions..., ), User: append( - allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceBoundaryUsage, ResourceAibridgeInterception, ResourceChat, ResourceAiSeat), + allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceBoundaryUsage, ResourceBoundaryLog, ResourceAibridgeInterception, ResourceChat, ResourceAiSeat), Permissions(map[string][]policy.Action{ // Users cannot do create/update/delete on themselves, but they // can read their own details. @@ -342,6 +345,11 @@ func ReloadBuiltinRoles(opts *RoleOptions) { // Members can create and update AI Bridge interceptions but // cannot read them back. ResourceAibridgeInterception.Type: {policy.ActionCreate, policy.ActionUpdate}, + // Workspace agents create boundary logs under their owner's + // identity. Create is user-scoped so agents can only write + // logs owned by their workspace owner. + // Read: owners and auditors. Delete: DBPurge only. + ResourceBoundaryLog.Type: {policy.ActionCreate}, })..., ), ByOrgID: map[string]OrgPermissions{}, @@ -366,6 +374,8 @@ func ReloadBuiltinRoles(opts *RoleOptions) { ResourceDeploymentConfig.Type: {policy.ActionRead}, // Allow auditors to query AI Bridge interceptions. ResourceAibridgeInterception.Type: {policy.ActionRead}, + // Allow auditors to read boundary logs. + ResourceBoundaryLog.Type: {policy.ActionRead}, }), User: []Permission{}, ByOrgID: map[string]OrgPermissions{}, @@ -465,7 +475,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) { // Org admins should not have workspace exec perms. organizationID.String(): { Org: append( - allPermsExcept(ResourceWorkspace, ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceAssignRole, ResourceUserSecret, ResourceBoundaryUsage, ResourceAiSeat), + allPermsExcept(ResourceWorkspace, ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceAssignRole, ResourceUserSecret, ResourceBoundaryUsage, ResourceBoundaryLog, ResourceAiSeat), Permissions(map[string][]policy.Action{ ResourceWorkspace.Type: slice.Omit(ResourceWorkspace.AvailableActions(), policy.ActionApplicationConnect, policy.ActionSSH), ResourceWorkspaceDormant.Type: {policy.ActionRead, policy.ActionDelete, policy.ActionCreate, policy.ActionUpdate, policy.ActionWorkspaceStop, policy.ActionCreateAgent, policy.ActionDeleteAgent, policy.ActionUpdateAgent}, @@ -1052,6 +1062,7 @@ func OrgMemberPermissions(org OrgSettings) OrgRolePermissions { ResourcePrebuiltWorkspace, ResourceUser, ResourceOrganizationMember, + ResourceBoundaryLog, ResourceAibridgeInterception, // Chat access requires the agents-access role. ResourceChat, @@ -1137,6 +1148,7 @@ func OrgServiceAccountPermissions(org OrgSettings) OrgRolePermissions { ResourcePrebuiltWorkspace, ResourceUser, ResourceOrganizationMember, + ResourceBoundaryLog, ResourceAibridgeInterception, // Chat access requires the agents-access role. ResourceChat, diff --git a/coderd/rbac/roles_test.go b/coderd/rbac/roles_test.go index 0170d308e0..0ac992fc86 100644 --- a/coderd/rbac/roles_test.go +++ b/coderd/rbac/roles_test.go @@ -1229,6 +1229,75 @@ func TestRolePermissions(t *testing.T) { false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin}, }, }, + { + // Boundary logs: members can create logs they own (user-scoped). + // memberMe and agentsAccessUser have ID == currentUser, so they + // match the resource owner. Other subjects have different IDs. + Name: "BoundaryLogCreate", + Actions: []policy.Action{policy.ActionCreate}, + Resource: rbac.ResourceBoundaryLog.WithOwner(currentUser.String()), + AuthorizeMap: map[bool][]hasAuthSubjects{ + true: {memberMe, agentsAccessUser}, + false: { + owner, + orgAdmin, otherOrgAdmin, + orgAuditor, otherOrgAuditor, auditor, + templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin, + userAdmin, orgUserAdmin, otherOrgUserAdmin, + }, + }, + }, + { + // Cross-user isolation: no subject can create boundary logs + // owned by a different user. The resource owner is a random + // UUID that does not match any test subject's ID. + Name: "BoundaryLogCreateOther", + Actions: []policy.Action{policy.ActionCreate}, + Resource: rbac.ResourceBoundaryLog.WithOwner(uuid.New().String()), + AuthorizeMap: map[bool][]hasAuthSubjects{ + true: {}, + false: { + owner, memberMe, agentsAccessUser, + orgAdmin, otherOrgAdmin, + orgAuditor, otherOrgAuditor, auditor, + templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin, + userAdmin, orgUserAdmin, otherOrgUserAdmin, + }, + }, + }, + { + // Boundary logs: only DBPurge can delete. No human role + // has delete; DBPurge is a system subject outside this matrix. + Name: "BoundaryLogDelete", + Actions: []policy.Action{policy.ActionDelete}, + Resource: rbac.ResourceBoundaryLog, + AuthorizeMap: map[bool][]hasAuthSubjects{ + true: {}, + false: { + owner, memberMe, agentsAccessUser, + orgAdmin, otherOrgAdmin, + orgAuditor, otherOrgAuditor, auditor, + templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin, + userAdmin, orgUserAdmin, otherOrgUserAdmin, + }, + }, + }, + { + // Boundary logs: owner and auditor get read. + Name: "BoundaryLogRead", + Actions: []policy.Action{policy.ActionRead}, + Resource: rbac.ResourceBoundaryLog, + AuthorizeMap: map[bool][]hasAuthSubjects{ + true: {owner, auditor}, + false: { + memberMe, agentsAccessUser, + orgAdmin, otherOrgAdmin, + orgAuditor, otherOrgAuditor, + templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin, + userAdmin, orgUserAdmin, otherOrgUserAdmin, + }, + }, + }, { Name: "ChatUsageCRU", Actions: []policy.Action{policy.ActionCreate, policy.ActionRead, policy.ActionUpdate}, @@ -1471,3 +1540,121 @@ func TestChangeSet(t *testing.T) { }) } } + +// TestWorkspaceAgentScopeBoundaryLog verifies that a real workspace agent +// scope (not ScopeAll) can create boundary logs for its own owner but +// cannot create them for other users, and cannot read or delete them. +func TestWorkspaceAgentScopeBoundaryLog(t *testing.T) { + t.Parallel() + + auth := rbac.NewStrictAuthorizer(prometheus.NewRegistry()) + + ownerID := uuid.New() + otherOwnerID := uuid.New() + workspaceID := uuid.New() + templateID := uuid.New() + versionID := uuid.New() + + agentScope := rbac.WorkspaceAgentScope(rbac.WorkspaceAgentScopeParams{ + WorkspaceID: workspaceID, + OwnerID: ownerID, + TemplateID: templateID, + VersionID: versionID, + }) + + memberRole, err := rbac.RoleByName(rbac.RoleMember()) + require.NoError(t, err) + + agent := rbac.Subject{ + ID: ownerID.String(), + Roles: rbac.Roles{memberRole}, + Scope: agentScope, + }.WithCachedASTValue() + + // Agent can create boundary logs for its own owner. + err = auth.Authorize(context.Background(), agent, policy.ActionCreate, + rbac.ResourceBoundaryLog.WithOwner(ownerID.String())) + require.NoError(t, err, "agent should create boundary logs for own owner") + + // Agent cannot create boundary logs for a different owner. + err = auth.Authorize(context.Background(), agent, policy.ActionCreate, + rbac.ResourceBoundaryLog.WithOwner(otherOwnerID.String())) + require.Error(t, err, "agent must not create boundary logs for other owner") + + // Agent cannot read boundary logs (even its own owner's). + err = auth.Authorize(context.Background(), agent, policy.ActionRead, + rbac.ResourceBoundaryLog.WithOwner(ownerID.String())) + require.Error(t, err, "agent must not read boundary logs") + + // Agent cannot delete boundary logs (even its own owner's). + err = auth.Authorize(context.Background(), agent, policy.ActionDelete, + rbac.ResourceBoundaryLog.WithOwner(ownerID.String())) + require.Error(t, err, "agent must not delete boundary logs") + + // When the workspace owner is a site admin, the agent scope + // wildcard for boundary_log combined with the owner role's site-level + // read grant means the agent CAN read all boundary logs. This is an + // accepted consequence of the wildcard scope needed for creation. + ownerRole, err := rbac.RoleByName(rbac.RoleOwner()) + require.NoError(t, err) + + adminAgent := rbac.Subject{ + ID: ownerID.String(), + Roles: rbac.Roles{memberRole, ownerRole}, + Scope: agentScope, + }.WithCachedASTValue() + + // Admin-owned agent CAN read boundary logs due to site-level owner + // role + wildcard scope. + err = auth.Authorize(context.Background(), adminAgent, policy.ActionRead, + rbac.ResourceBoundaryLog.WithOwner(otherOwnerID.String())) + require.NoError(t, err, "admin agent inherits site-level read via owner role") + + // Admin-owned agent still cannot create boundary logs for another owner + // because member-level create is user-scoped (subject.id must match owner). + err = auth.Authorize(context.Background(), adminAgent, policy.ActionCreate, + rbac.ResourceBoundaryLog.WithOwner(otherOwnerID.String())) + require.Error(t, err, "admin agent must not create boundary logs for other owner") +} + +// TestDBPurgeBoundaryLogDelete verifies that the DBPurge system subject +// can delete boundary logs but cannot create or read them. +func TestDBPurgeBoundaryLogDelete(t *testing.T) { + t.Parallel() + + auth := rbac.NewStrictAuthorizer(prometheus.NewRegistry()) + + // Build the DBPurge subject the same way dbauthz does. + dbPurge := rbac.Subject{ + Type: rbac.SubjectTypeDBPurge, + FriendlyName: "DB Purge", + ID: uuid.Nil.String(), + Roles: rbac.Roles([]rbac.Role{ + { + Identifier: rbac.RoleIdentifier{Name: "dbpurge"}, + DisplayName: "DB Purge Daemon", + Site: rbac.Permissions(map[string][]policy.Action{ + rbac.ResourceBoundaryLog.Type: {policy.ActionDelete}, + }), + User: []rbac.Permission{}, + ByOrgID: map[string]rbac.OrgPermissions{}, + }, + }), + Scope: rbac.ScopeAll, + }.WithCachedASTValue() + + // DBPurge can delete boundary logs. + err := auth.Authorize(context.Background(), dbPurge, policy.ActionDelete, + rbac.ResourceBoundaryLog) + require.NoError(t, err, "DBPurge should delete boundary logs") + + // DBPurge cannot create boundary logs. + err = auth.Authorize(context.Background(), dbPurge, policy.ActionCreate, + rbac.ResourceBoundaryLog.WithOwner(uuid.New().String())) + require.Error(t, err, "DBPurge must not create boundary logs") + + // DBPurge cannot read boundary logs. + err = auth.Authorize(context.Background(), dbPurge, policy.ActionRead, + rbac.ResourceBoundaryLog) + require.Error(t, err, "DBPurge must not read boundary logs") +} diff --git a/coderd/rbac/scopes.go b/coderd/rbac/scopes.go index 17e3990c31..7cbec46d74 100644 --- a/coderd/rbac/scopes.go +++ b/coderd/rbac/scopes.go @@ -65,6 +65,11 @@ func WorkspaceAgentScope(params WorkspaceAgentScopeParams) Scope { {Type: ResourceTemplate.Type, ID: params.TemplateID.String()}, {Type: ResourceTemplate.Type, ID: params.VersionID.String()}, {Type: ResourceUser.Type, ID: params.OwnerID.String()}, + // No pre-existing ID for new records; wildcard is required. + // Owner-scoped create (user-level) limits agents to their own + // logs. Adding site-level actions to the member role would + // bypass this and grant deployment-wide access. + {Type: ResourceBoundaryLog.Type, ID: policy.WildcardSymbol}, }, extraAllowList...), } } diff --git a/coderd/rbac/scopes_constants_gen.go b/coderd/rbac/scopes_constants_gen.go index c12cba430a..b664a4371a 100644 --- a/coderd/rbac/scopes_constants_gen.go +++ b/coderd/rbac/scopes_constants_gen.go @@ -33,6 +33,9 @@ const ( ScopeAssignRoleUnassign ScopeName = "assign_role:unassign" ScopeAuditLogCreate ScopeName = "audit_log:create" ScopeAuditLogRead ScopeName = "audit_log:read" + ScopeBoundaryLogCreate ScopeName = "boundary_log:create" + ScopeBoundaryLogDelete ScopeName = "boundary_log:delete" + ScopeBoundaryLogRead ScopeName = "boundary_log:read" ScopeBoundaryUsageDelete ScopeName = "boundary_usage:delete" ScopeBoundaryUsageRead ScopeName = "boundary_usage:read" ScopeBoundaryUsageUpdate ScopeName = "boundary_usage:update" @@ -210,6 +213,9 @@ func (e ScopeName) Valid() bool { ScopeAssignRoleUnassign, ScopeAuditLogCreate, ScopeAuditLogRead, + ScopeBoundaryLogCreate, + ScopeBoundaryLogDelete, + ScopeBoundaryLogRead, ScopeBoundaryUsageDelete, ScopeBoundaryUsageRead, ScopeBoundaryUsageUpdate, @@ -388,6 +394,9 @@ func AllScopeNameValues() []ScopeName { ScopeAssignRoleUnassign, ScopeAuditLogCreate, ScopeAuditLogRead, + ScopeBoundaryLogCreate, + ScopeBoundaryLogDelete, + ScopeBoundaryLogRead, ScopeBoundaryUsageDelete, ScopeBoundaryUsageRead, ScopeBoundaryUsageUpdate, diff --git a/coderd/users.go b/coderd/users.go index 4245e6766c..585360630e 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -1336,7 +1336,7 @@ func (api *API) userPreferenceSettings(rw http.ResponseWriter, r *http.Request) httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserPreferenceSettings{ TaskNotificationAlertDismissed: taskAlertDismissed, ThinkingDisplayMode: sanitizeThinkingDisplayMode(thinkingMode), - ShellToolDisplayMode: sanitizeAgentDisplayMode(shellToolMode), + ShellToolDisplayMode: sanitizeShellToolDisplayMode(shellToolMode), CodeDiffDisplayMode: sanitizeAgentDisplayMode(codeDiffMode), AgentChatSendShortcut: sanitizeAgentChatSendShortcut(agentChatSendShortcut), }) @@ -1446,13 +1446,13 @@ func (api *API) putUserPreferenceSettings(rw http.ResponseWriter, r *http.Reques if err != nil { return newUserPreferenceSettingsAPIError("Internal error updating shell tool display mode.", err) } - settings.ShellToolDisplayMode = sanitizeAgentDisplayMode(updated) + settings.ShellToolDisplayMode = sanitizeShellToolDisplayMode(updated) } else { stored, err := tx.GetUserShellToolDisplayMode(ctx, user.ID) if err != nil && !errors.Is(err, sql.ErrNoRows) { return newUserPreferenceSettingsAPIError("Error reading shell tool display mode.", err) } - settings.ShellToolDisplayMode = sanitizeAgentDisplayMode(stored) + settings.ShellToolDisplayMode = sanitizeShellToolDisplayMode(stored) } if params.CodeDiffDisplayMode != "" { @@ -1545,12 +1545,20 @@ func sanitizeThinkingDisplayMode(raw string) codersdk.ThinkingDisplayMode { return codersdk.ThinkingDisplayModeAuto } +func sanitizeShellToolDisplayMode(raw string) codersdk.AgentDisplayMode { + mode := sanitizeAgentDisplayMode(raw) + if mode == "" { + return codersdk.AgentDisplayModeAlwaysCollapsed + } + return mode +} + func sanitizeAgentDisplayMode(raw string) codersdk.AgentDisplayMode { mode := codersdk.AgentDisplayMode(raw) if slices.Contains(codersdk.ValidAgentDisplayModes, mode) { return mode } - return codersdk.AgentDisplayModeAuto + return "" } func sanitizeAgentChatSendShortcut(raw string) codersdk.AgentChatSendShortcut { diff --git a/coderd/users_test.go b/coderd/users_test.go index 87ac9b6db5..29da1887a4 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -2423,7 +2423,7 @@ func TestAgentDisplayModePreferences(t *testing.T) { require.Equal(t, field, sdkErr.Validations[0].Field) } - t.Run("defaults to auto", func(t *testing.T) { + t.Run("defaults shell tools to always collapsed", func(t *testing.T) { t.Parallel() client, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID) @@ -2433,8 +2433,8 @@ func TestAgentDisplayModePreferences(t *testing.T) { settings, err := client.GetUserPreferenceSettings(ctx, codersdk.Me) require.NoError(t, err) - require.Equal(t, codersdk.AgentDisplayModeAuto, settings.ShellToolDisplayMode) - require.Equal(t, codersdk.AgentDisplayModeAuto, settings.CodeDiffDisplayMode) + require.Equal(t, codersdk.AgentDisplayModeAlwaysCollapsed, settings.ShellToolDisplayMode) + require.Empty(t, settings.CodeDiffDisplayMode) }) t.Run("round-trips shell tool display mode", func(t *testing.T) { diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index b3074740b1..074ca687b1 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -1096,6 +1096,11 @@ func (api *API) workspaceAgentRecreateDevcontainer(rw http.ResponseWriter, r *ht ctx := r.Context() waws := httpmw.WorkspaceAgentAndWorkspaceParam(r) + if !api.Authorize(r, policy.ActionUpdate, waws.WorkspaceTable) { + httpapi.Forbidden(rw) + return + } + devcontainer := chi.URLParam(r, "devcontainer") if devcontainer == "" { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ @@ -2562,9 +2567,9 @@ func (api *API) workspaceAgentAddChatContext(rw http.ResponseWriter, r *http.Req if locked.OwnerID != workspace.OwnerID { return errChatDoesNotBelongToWorkspaceOwner } - if _, err := tx.InsertChatMessages(sysCtx, chatd.BuildSingleChatMessageInsertParams( + if _, err := tx.InsertChatMessages(sysCtx, chatd.BuildSingleUserChatMessageInsertParams( chat.ID, - database.ChatMessageRoleUser, + "", // Agent-initiated context injection has no caller API key. content, database.ChatMessageVisibilityBoth, locked.LastModelConfigID, diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 0fff131f5a..9b36d11c27 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -1876,6 +1876,51 @@ func TestWorkspaceAgentRecreateDevcontainer(t *testing.T) { }) } +func TestWorkspaceAgentRecreateDevcontainerAuthorization(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + role func(uuid.UUID) rbac.RoleIdentifier + }{ + { + name: "TemplateAdmin", + role: func(uuid.UUID) rbac.RoleIdentifier { + return rbac.RoleTemplateAdmin() + }, + }, + { + name: "OrgTemplateAdmin", + role: rbac.ScopedRoleOrgTemplateAdmin, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitMedium) + client, db = coderdtest.NewWithDatabase(t, nil) + admin = coderdtest.CreateFirstUser(t, client) + _, workspaceOwner = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID) + templateAdminClient, _ = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID, tc.role(admin.OrganizationID)) + workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: admin.OrganizationID, + OwnerID: workspaceOwner.ID, + }).WithAgent(func(agents []*proto.Agent) []*proto.Agent { + return agents + }).Do() + ) + + _, err := templateAdminClient.WorkspaceAgentRecreateDevcontainer(ctx, workspace.Agents[0].ID, uuid.NewString()) + require.Error(t, err) + + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusForbidden, sdkErr.StatusCode()) + }) + } +} + func TestWorkspaceAgentDeleteDevcontainer(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index b94d9dc7a4..483ff0f183 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -1629,7 +1629,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C return xerrors.Errorf("marshal initial user content: %w", err) } - msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by append[User]ChatMessage. ChatID: insertedChat.ID, } @@ -1673,13 +1673,15 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C chatprompt.CurrentContentVersion, )) - appendChatMessage(&msgParams, newChatMessage( - database.ChatMessageRoleUser, + userMsg := newUserChatMessage( + opts.APIKeyID, userContent, database.ChatMessageVisibilityBoth, opts.ModelConfigID, chatprompt.CurrentContentVersion, - ).withCreatedBy(opts.OwnerID).withAPIKeyID(opts.APIKeyID)) + ) + userMsg = userMsg.withCreatedBy(opts.OwnerID) + appendUserChatMessage(&msgParams, userMsg) _, err = tx.InsertChatMessages(ctx, msgParams) if err != nil { @@ -2111,16 +2113,18 @@ func (p *Server) EditMessage( // InsertChatMessages CTE updates chats.last_model_config_id // when the new message's model differs, so the assistant turn // that follows picks up the new selection. - msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage. ChatID: opts.ChatID, } - appendChatMessage(&msgParams, newChatMessage( - database.ChatMessageRoleUser, + editUserMsg := newUserChatMessage( + opts.APIKeyID, content, editedMsg.Visibility, messageModelConfigID, chatprompt.CurrentContentVersion, - ).withCreatedBy(opts.CreatedBy).withAPIKeyID(opts.APIKeyID)) + ) + editUserMsg = editUserMsg.withCreatedBy(opts.CreatedBy) + appendUserChatMessage(&msgParams, editUserMsg) newMessages, err := insertChatMessageWithStore(ctx, tx, msgParams) if err != nil { return xerrors.Errorf("insert replacement message: %w", err) @@ -3899,19 +3903,16 @@ func insertChatMessageWithStore( return messages, nil } -// chatMessage describes a single message to insert as part of a batch. -// Use newChatMessage to create one, then chain builder methods for -// optional fields. For nullable UUID fields (ModelConfigID, CreatedBy), -// use uuid.Nil to represent NULL — the SQL uses NULLIF to convert zero -// UUIDs to NULL. For nullable int64 fields, use 0 to represent NULL — -// the SQL uses NULLIF to convert zeros to NULL. +// chatMessage is the base message type for batch inserts. Use directly +// only for non-user messages; for user messages, use userChatMessage. +// For nullable UUID fields (ModelConfigID, CreatedBy), use uuid.Nil to +// represent NULL. For nullable int64 fields, use 0 to represent NULL. type chatMessage struct { role database.ChatMessageRole content pqtype.NullRawMessage visibility database.ChatMessageVisibility modelConfigID uuid.UUID createdBy uuid.UUID - apiKeyID string contentVersion int16 compressed bool inputTokens int64 @@ -3926,6 +3927,23 @@ type chatMessage struct { providerResponseID string } +// userChatMessage wraps chatMessage with a required apiKeyID so that +// omitting it for user messages is a compile error, not a silent data bug. +type userChatMessage struct { + chatMessage + apiKeyID string +} + +func (m userChatMessage) withCreatedBy(id uuid.UUID) userChatMessage { + m.chatMessage = m.chatMessage.withCreatedBy(id) + return m +} + +func (m userChatMessage) withCompressed() userChatMessage { + m.chatMessage = m.chatMessage.withCompressed() + return m +} + func newChatMessage( role database.ChatMessageRole, content pqtype.NullRawMessage, @@ -3942,13 +3960,29 @@ func newChatMessage( } } -func (m chatMessage) withCreatedBy(id uuid.UUID) chatMessage { - m.createdBy = id - return m +// newUserChatMessage creates a user message. apiKeyID is required so +// that forgetting it is a compile error rather than a silent data bug. +func newUserChatMessage( + apiKeyID string, + content pqtype.NullRawMessage, + visibility database.ChatMessageVisibility, + modelConfigID uuid.UUID, + contentVersion int16, +) userChatMessage { + return userChatMessage{ + chatMessage: newChatMessage( + database.ChatMessageRoleUser, + content, + visibility, + modelConfigID, + contentVersion, + ), + apiKeyID: apiKeyID, + } } -func (m chatMessage) withAPIKeyID(id string) chatMessage { - m.apiKeyID = id +func (m chatMessage) withCreatedBy(id uuid.UUID) chatMessage { + m.createdBy = id return m } @@ -3990,13 +4024,16 @@ func (m chatMessage) withProviderResponseID(id string) chatMessage { return m } -// appendChatMessage appends a single message to the batch insert params. -func appendChatMessage( +// appendMessageFields writes all chatMessage fields into the batch insert +// params. apiKeyID is explicit so non-user messages always get "" while +// user messages carry the caller's key for AI Gateway routing. +func appendMessageFields( params *database.InsertChatMessagesParams, msg chatMessage, + apiKeyID string, ) { params.CreatedBy = append(params.CreatedBy, msg.createdBy) - params.APIKeyID = append(params.APIKeyID, msg.apiKeyID) + params.APIKeyID = append(params.APIKeyID, apiKeyID) params.ModelConfigID = append(params.ModelConfigID, msg.modelConfigID) params.Role = append(params.Role, msg.role) params.Content = append(params.Content, string(msg.content.RawMessage)) @@ -4015,25 +4052,44 @@ func appendChatMessage( params.ProviderResponseID = append(params.ProviderResponseID, msg.providerResponseID) } -// BuildSingleChatMessageInsertParams creates batch insert params for one -// message using the shared chat message builder. -func BuildSingleChatMessageInsertParams( +// appendChatMessage appends a non-user message to the batch insert params. +func appendChatMessage( + params *database.InsertChatMessagesParams, + msg chatMessage, +) { + if msg.role == database.ChatMessageRoleUser { + panic("developer error: use appendUserChatMessage for user-role messages") + } + appendMessageFields(params, msg, "") +} + +// appendUserChatMessage inserts a user message with its apiKeyID preserved. +func appendUserChatMessage( + params *database.InsertChatMessagesParams, + msg userChatMessage, +) { + appendMessageFields(params, msg.chatMessage, msg.apiKeyID) +} + +// BuildSingleUserChatMessageInsertParams creates batch insert params for +// one user message, requiring an apiKeyID for AI Gateway attribution. +func BuildSingleUserChatMessageInsertParams( chatID uuid.UUID, - role database.ChatMessageRole, + apiKeyID string, content pqtype.NullRawMessage, visibility database.ChatMessageVisibility, modelConfigID uuid.UUID, contentVersion int16, createdBy uuid.UUID, ) database.InsertChatMessagesParams { - params := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + params := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage. ChatID: chatID, } - msg := newChatMessage(role, content, visibility, modelConfigID, contentVersion) + msg := newUserChatMessage(apiKeyID, content, visibility, modelConfigID, contentVersion) if createdBy != uuid.Nil { msg = msg.withCreatedBy(createdBy) } - appendChatMessage(¶ms, msg) + appendUserChatMessage(¶ms, msg) return params } @@ -4048,16 +4104,18 @@ func insertUserMessageAndSetPending( createdBy uuid.UUID, apiKeyID string, ) (database.ChatMessage, database.Chat, error) { - msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage. ChatID: lockedChat.ID, } - appendChatMessage(&msgParams, newChatMessage( - database.ChatMessageRoleUser, + insertUserMsg := newUserChatMessage( + apiKeyID, content, database.ChatMessageVisibilityBoth, modelConfigID, chatprompt.CurrentContentVersion, - ).withCreatedBy(createdBy).withAPIKeyID(apiKeyID)) + ) + insertUserMsg = insertUserMsg.withCreatedBy(createdBy) + appendUserChatMessage(&msgParams, insertUserMsg) messages, err := insertChatMessageWithStore(ctx, store, msgParams) if err != nil { return database.ChatMessage{}, database.Chat{}, err @@ -5870,11 +5928,11 @@ func (p *Server) tryAutoPromoteQueuedMessage( return nil, nil, false, xerrors.New("popped queued message out of order") } - msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage. ChatID: chat.ID, } - appendChatMessage(&msgParams, newChatMessage( - database.ChatMessageRoleUser, + queuedUserMsg := newUserChatMessage( + nextQueued.APIKeyID.String, pqtype.NullRawMessage{ RawMessage: nextQueued.Content, Valid: len(nextQueued.Content) > 0, @@ -5882,7 +5940,9 @@ func (p *Server) tryAutoPromoteQueuedMessage( database.ChatMessageVisibilityBoth, effectiveModelConfigID, chatprompt.CurrentContentVersion, - ).withCreatedBy(chat.OwnerID).withAPIKeyID(nextQueued.APIKeyID.String)) + ) + queuedUserMsg = queuedUserMsg.withCreatedBy(chat.OwnerID) + appendUserChatMessage(&msgParams, queuedUserMsg) msgs, err := insertChatMessageWithStore(ctx, tx, msgParams) if err != nil { return nil, nil, false, xerrors.Errorf("insert promoted message: %w", err) @@ -8459,18 +8519,21 @@ func (p *Server) persistChatContextSummary( var insertedMessages []database.ChatMessage txErr := p.db.InTx(func(tx database.Store) error { - summaryParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + summaryParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by append[User]ChatMessage. ChatID: chatID, } // Hidden summary user message (not published to subscribers). - appendChatMessage(&summaryParams, newChatMessage( - database.ChatMessageRoleUser, + summaryAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx) + summaryUserMsg := newUserChatMessage( + summaryAPIKeyID, systemContent, database.ChatMessageVisibilityModel, modelConfigID, chatprompt.CurrentContentVersion, - ).withCompressed()) + ) + summaryUserMsg = summaryUserMsg.withCompressed() + appendUserChatMessage(&summaryParams, summaryUserMsg) // Assistant tool-call message. appendChatMessage(&summaryParams, newChatMessage( @@ -9001,11 +9064,12 @@ func (p *Server) persistInstructionFiles( if err != nil { return "", nil, nil } - msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + contextAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx) + msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage. ChatID: chat.ID, } - appendChatMessage(&msgParams, newChatMessage( - database.ChatMessageRoleUser, + appendUserChatMessage(&msgParams, newUserChatMessage( + contextAPIKeyID, content, database.ChatMessageVisibilityBoth, modelConfigID, @@ -9024,11 +9088,12 @@ func (p *Server) persistInstructionFiles( return "", nil, xerrors.Errorf("marshal context-file parts: %w", err) } - msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + contextAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx) + msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage. ChatID: chat.ID, } - appendChatMessage(&msgParams, newChatMessage( - database.ChatMessageRoleUser, + appendUserChatMessage(&msgParams, newUserChatMessage( + contextAPIKeyID, content, database.ChatMessageVisibilityBoth, modelConfigID, diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index b28c0cdb0c..69aebbfe30 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -19,9 +19,12 @@ import ( "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtestutil" dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/rbac" @@ -1339,6 +1342,8 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) { t.Parallel() ctx := context.Background() + testAPIKeyID := uuid.NewString() + ctx = aibridge.WithDelegatedAPIKeyID(ctx, testAPIKeyID) ctrl := gomock.NewController(t) db := dbmock.NewMockStore(ctrl) @@ -1366,7 +1371,18 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) { gomock.Any(), agentID, ).Return(workspaceAgent, nil).Times(1) - db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Cond(func(x any) bool { + params, ok := x.(database.InsertChatMessagesParams) + if !ok { + return false + } + for i, role := range params.Role { + if role == database.ChatMessageRoleUser && params.APIKeyID[i] != testAPIKeyID { + return false + } + } + return true + })).Return(nil, nil).AnyTimes() db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), gomock.Cond(func(x any) bool { arg, ok := x.(database.UpdateChatLastInjectedContextParams) @@ -6616,3 +6632,61 @@ func TestPrimeWorkspaceMCPCache_ExitsOnContextCancel(t *testing.T) { _, ok := server.workspaceMCPToolsCache.Load(chat.ID) require.False(t, ok, "primer must not cache anything when canceled") } + +func TestPersistChatContextSummarySetsAPIKeyID(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := context.Background() + + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + modelConfig := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{}) + chat := dbgen.Chat(t, db, database.Chat{ + OwnerID: user.ID, + OrganizationID: org.ID, + LastModelConfigID: modelConfig.ID, + }) + apiKey, _ := dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + }) + + ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID) + + server := &Server{db: db} + + err := server.persistChatContextSummary( + ctx, + chat.ID, + modelConfig.ID, + "tool-call-id-1", + chatloop.CompactionResult{ + SystemSummary: "summarized context", + SummaryReport: "context was summarized", + ThresholdPercent: 70, + UsagePercent: 85.0, + ContextTokens: 8500, + ContextLimit: 10000, + }, + ) + require.NoError(t, err) + + msgs, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID) + require.NoError(t, err) + + // GetChatMessagesForPromptByChatID uses a compaction boundary CTE + // that selects compressed=true, visibility='model'. Only the user + // summary qualifies; the assistant (visibility=user) and tool + // result (visibility=both) are excluded by the CTE filter. + require.NotEmpty(t, msgs) + + var foundUserSummary bool + for _, msg := range msgs { + if msg.Role == database.ChatMessageRoleUser { + foundUserSummary = true + require.True(t, msg.APIKeyID.Valid, "summary user message must have APIKeyID set") + require.Equal(t, apiKey.ID, msg.APIKeyID.String, "summary user message APIKeyID must match") + } + } + require.True(t, foundUserSummary, "expected to find compressed user summary message") +} diff --git a/coderd/x/chatd/chaterror/classify.go b/coderd/x/chatd/chaterror/classify.go index 73e50f083b..4bf28efd4f 100644 --- a/coderd/x/chatd/chaterror/classify.go +++ b/coderd/x/chatd/chaterror/classify.go @@ -195,6 +195,7 @@ func Classify(err error) ClassifiedError { } retryableHTTP2StreamReset, hasHTTP2StreamReset := classifyHTTP2StreamReset(err) + providerDisabledMatch := containsAny(lower, providerDisabledPatterns...) deadline := errors.Is(err, context.DeadlineExceeded) || strings.Contains(lower, "context deadline exceeded") overloadedMatch := statusCode == 529 || containsAny(lower, overloadedPatterns...) usageLimitMatch := containsAny(lower, usageLimitPatterns...) @@ -221,6 +222,8 @@ func Classify(err error) ClassifiedError { // over whatever HTTP status code the provider happened to use. // Strong auth still stays above config because bad credentials are // the root cause when both signals appear. + // Provider-disabled must precede timeout because disabled providers + // return 503, which matches the timeout rule. rules := []struct { match bool kind codersdk.ChatErrorKind @@ -251,6 +254,11 @@ func Classify(err error) ClassifiedError { kind: codersdk.ChatErrorKindRateLimit, retryable: true, }, + { + match: providerDisabledMatch, + kind: codersdk.ChatErrorKindProviderDisabled, + retryable: false, + }, { match: timeoutMatch && !configMatch, kind: codersdk.ChatErrorKindTimeout, diff --git a/coderd/x/chatd/chaterror/classify_test.go b/coderd/x/chatd/chaterror/classify_test.go index 457704bd5f..0e2e008bb8 100644 --- a/coderd/x/chatd/chaterror/classify_test.go +++ b/coderd/x/chatd/chaterror/classify_test.go @@ -2,6 +2,7 @@ package chaterror_test import ( "context" + "fmt" "io" "net/http" "strings" @@ -218,6 +219,85 @@ func TestClassify(t *testing.T) { StatusCode: 0, }, }, + // The next cases model the error that fantasy produces + // when aibridge's disabledProviderHandler returns a 503 + // plain-text sentinel. Fantasy sets Title from the HTTP + // status text and Message from the response body (including + // the trailing newline written by http.Error). + { + name: "ProviderDisabled503ClassifiesAsProviderDisabled", + err: &fantasy.ProviderError{ + Title: fantasy.ErrorTitleForStatusCode(http.StatusServiceUnavailable), + Message: fmt.Sprintf("%s: AI provider %q is disabled\n", codersdk.ChatErrorKindProviderDisabled, "openai"), + StatusCode: http.StatusServiceUnavailable, + }, + want: chaterror.ClassifiedError{ + Message: "The OpenAI provider has been disabled. Contact your Coder administrator.", + Detail: fmt.Sprintf("%s: AI provider %q is disabled", codersdk.ChatErrorKindProviderDisabled, "openai"), + Kind: codersdk.ChatErrorKindProviderDisabled, + Provider: "openai", + Retryable: false, + StatusCode: 503, + }, + }, + { + name: "ProviderDisabled503UnknownProvider", + err: &fantasy.ProviderError{ + Title: fantasy.ErrorTitleForStatusCode(http.StatusServiceUnavailable), + Message: fmt.Sprintf("%s: AI provider %q is disabled\n", codersdk.ChatErrorKindProviderDisabled, "mycustomprovider"), + StatusCode: http.StatusServiceUnavailable, + }, + want: chaterror.ClassifiedError{ + Message: "The AI provider has been disabled. Contact your Coder administrator.", + Detail: fmt.Sprintf("%s: AI provider %q is disabled", codersdk.ChatErrorKindProviderDisabled, "mycustomprovider"), + Kind: codersdk.ChatErrorKindProviderDisabled, + Provider: "", + Retryable: false, + StatusCode: 503, + }, + }, + { + name: "ProviderDisabledPlainErrorString", + err: xerrors.New(fmt.Sprintf("%s: AI provider %q is disabled", codersdk.ChatErrorKindProviderDisabled, "anthropic")), + want: chaterror.ClassifiedError{ + Message: "The Anthropic provider has been disabled. Contact your Coder administrator.", + Kind: codersdk.ChatErrorKindProviderDisabled, + Provider: "anthropic", + Retryable: false, + StatusCode: 0, + }, + }, + { + name: "ProviderDisabledBeatsTimeout503", + err: &fantasy.ProviderError{ + Title: fantasy.ErrorTitleForStatusCode(http.StatusServiceUnavailable), + Message: fmt.Sprintf("%s: AI provider %q is disabled\n", codersdk.ChatErrorKindProviderDisabled, "google"), + StatusCode: http.StatusServiceUnavailable, + }, + want: chaterror.ClassifiedError{ + Message: "The Google provider has been disabled. Contact your Coder administrator.", + Detail: fmt.Sprintf("%s: AI provider %q is disabled", codersdk.ChatErrorKindProviderDisabled, "google"), + Kind: codersdk.ChatErrorKindProviderDisabled, + Provider: "google", + Retryable: false, + StatusCode: 503, + }, + }, + { + name: "Generic503StillClassifiesAsTimeout", + err: &fantasy.ProviderError{ + Message: "service unavailable", + StatusCode: 503, + }, + want: chaterror.ClassifiedError{ + Message: "The AI provider is temporarily unavailable.", + Detail: "service unavailable", + Kind: codersdk.ChatErrorKindTimeout, + Provider: "", + Retryable: true, + StatusCode: 503, + }, + }, } for _, tt := range tests { @@ -363,6 +443,7 @@ func TestClassify_PatternCoverage(t *testing.T) { {name: "OperationInterruptedLiteral", err: "operation interrupted", wantKind: codersdk.ChatErrorKindGeneric, wantRetry: false}, {name: "Status408", err: "status 408", wantKind: codersdk.ChatErrorKindTimeout, wantRetry: true}, {name: "Status500", err: "status 500", wantKind: codersdk.ChatErrorKindGeneric, wantRetry: true}, + {name: "ProviderDisabledLiteral", err: "provider_disabled", wantKind: codersdk.ChatErrorKindProviderDisabled, wantRetry: false}, } for _, tt := range tests { @@ -1158,6 +1239,28 @@ func TestClassify_ChainBrokenSurvivesWithClassification(t *testing.T) { " can detect it after re-classification") } +func TestClassify_MissingKeyPreClassified(t *testing.T) { + t.Parallel() + + raw := xerrors.New("AI Gateway routing requires the active turn API key ID") + wrapped := chaterror.WithClassification(raw, chaterror.ClassifiedError{ + Kind: codersdk.ChatErrorKindMissingKey, + Retryable: false, + Detail: "If this error persists after resending, please report it as a bug.", + }) + + classified := chaterror.Classify(wrapped) + require.Equal(t, codersdk.ChatErrorKindMissingKey, classified.Kind) + require.False(t, classified.Retryable) + require.Equal(t, "If this error persists after resending, please report it as a bug.", classified.Detail) + require.Equal(t, + "This conversation was started with an API key that is no longer available."+ + " Send your message again to continue.", + classified.Message, + "Message should be filled by terminalMessage when not set explicitly", + ) +} + func testProviderError( message string, statusCode int, diff --git a/coderd/x/chatd/chaterror/message.go b/coderd/x/chatd/chaterror/message.go index a551078349..fef3ba78fa 100644 --- a/coderd/x/chatd/chaterror/message.go +++ b/coderd/x/chatd/chaterror/message.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + stringutil "github.com/coder/coder/v2/coderd/util/strings" "github.com/coder/coder/v2/codersdk" ) @@ -16,56 +17,58 @@ func terminalMessage(classified ClassifiedError) string { subject := providerSubject(classified.Provider) switch classified.Kind { case codersdk.ChatErrorKindOverloaded: - return fmt.Sprintf("%s is temporarily overloaded.", subject) + return stringutil.Capitalize(fmt.Sprintf("%s is temporarily overloaded.", subject)) case codersdk.ChatErrorKindRateLimit: - return fmt.Sprintf("%s is rate limiting requests.", subject) + return stringutil.Capitalize(fmt.Sprintf("%s is rate limiting requests.", subject)) case codersdk.ChatErrorKindTimeout: if !classified.Retryable && classified.StatusCode == 0 { return "The request timed out before it completed." } - return fmt.Sprintf("%s is temporarily unavailable.", subject) + return stringutil.Capitalize(fmt.Sprintf("%s is temporarily unavailable.", subject)) case codersdk.ChatErrorKindStartupTimeout: - return fmt.Sprintf( + return stringutil.Capitalize(fmt.Sprintf( "%s did not start responding in time.", subject, - ) + )) case codersdk.ChatErrorKindUsageLimit: - displayName := providerDisplayName(classified.Provider) - if displayName == "" { - displayName = "the AI provider" - } - return fmt.Sprintf( + return stringutil.Capitalize(fmt.Sprintf( "The usage quota for %s has been exceeded."+ " Check the billing and quota settings for the provider account.", - displayName, - ) + subject, + )) case codersdk.ChatErrorKindAuth: - displayName := providerDisplayName(classified.Provider) - if displayName == "" { - displayName = "the AI provider" - } return fmt.Sprintf( "Authentication with %s failed."+ " Check the API key and permissions.", - displayName, - ) - - case codersdk.ChatErrorKindConfig: - return fmt.Sprintf( - "%s rejected the model configuration."+ - " Check the selected model and provider settings.", subject, ) + case codersdk.ChatErrorKindConfig: + return stringutil.Capitalize(fmt.Sprintf( + "%s rejected the model configuration."+ + " Check the selected model and provider settings.", + subject, + )) + + case codersdk.ChatErrorKindMissingKey: + return "This conversation was started with an API key that is no longer available." + + " Send your message again to continue." + case codersdk.ChatErrorKindProviderDisabled: + displayName := providerDisplayName(classified.Provider) + return fmt.Sprintf( + "The %s provider has been disabled."+ + " Contact your Coder administrator.", + displayName, + ) default: if !classified.Retryable && classified.StatusCode == 0 { return "The chat request failed unexpectedly." } - return fmt.Sprintf("%s returned an unexpected error.", subject) + return stringutil.Capitalize(fmt.Sprintf("%s returned an unexpected error.", subject)) } } @@ -81,39 +84,43 @@ func retryMessage(classified ClassifiedError) string { subject := providerSubject(classified.Provider) switch classified.Kind { case codersdk.ChatErrorKindOverloaded: - return fmt.Sprintf("%s is temporarily overloaded.", subject) + return stringutil.Capitalize(fmt.Sprintf("%s is temporarily overloaded.", subject)) case codersdk.ChatErrorKindRateLimit: - return fmt.Sprintf("%s is rate limiting requests.", subject) + return stringutil.Capitalize(fmt.Sprintf("%s is rate limiting requests.", subject)) case codersdk.ChatErrorKindTimeout: - return fmt.Sprintf("%s is temporarily unavailable.", subject) + return stringutil.Capitalize(fmt.Sprintf("%s is temporarily unavailable.", subject)) case codersdk.ChatErrorKindStartupTimeout: - return fmt.Sprintf( + return stringutil.Capitalize(fmt.Sprintf( "%s did not start responding in time.", subject, - ) + )) case codersdk.ChatErrorKindAuth: - displayName := providerDisplayName(classified.Provider) - if displayName == "" { - displayName = "the AI provider" - } return fmt.Sprintf( - "Authentication with %s failed.", displayName, + "Authentication with %s failed.", subject, ) case codersdk.ChatErrorKindConfig: - return fmt.Sprintf( + return stringutil.Capitalize(fmt.Sprintf( "%s rejected the model configuration.", subject, + )) + case codersdk.ChatErrorKindMissingKey: + return "The API key for this conversation is no longer available." + case codersdk.ChatErrorKindProviderDisabled: + displayName := providerDisplayName(classified.Provider) + return fmt.Sprintf( + "The %s provider has been disabled by an administrator.", + displayName, ) default: - return fmt.Sprintf( + return stringutil.Capitalize(fmt.Sprintf( "%s returned an unexpected error.", subject, - ) + )) } } func providerSubject(provider string) string { - if displayName := providerDisplayName(provider); displayName != "" { + if displayName := providerDisplayName(provider); displayName != "AI" && displayName != "" { return displayName } - return "The AI provider" + return "the AI provider" } func providerDisplayName(provider string) string { @@ -135,7 +142,7 @@ func providerDisplayName(provider string) string { case "vercel": return "Vercel AI Gateway" default: - return "" + return "AI" } } diff --git a/coderd/x/chatd/chaterror/message_test.go b/coderd/x/chatd/chaterror/message_test.go index 87cb375cbc..94bf14bd13 100644 --- a/coderd/x/chatd/chaterror/message_test.go +++ b/coderd/x/chatd/chaterror/message_test.go @@ -90,6 +90,13 @@ func TestTerminalMessage(t *testing.T) { retryable: false, want: "The usage quota for the AI provider has been exceeded. Check the billing and quota settings for the provider account.", }, + { + name: "MissingKey", + kind: codersdk.ChatErrorKindMissingKey, + provider: "", + retryable: false, + want: "This conversation was started with an API key that is no longer available. Send your message again to continue.", + }, } for _, tt := range tests { diff --git a/coderd/x/chatd/chaterror/signals.go b/coderd/x/chatd/chaterror/signals.go index ebe6ff939b..8dad919127 100644 --- a/coderd/x/chatd/chaterror/signals.go +++ b/coderd/x/chatd/chaterror/signals.go @@ -4,6 +4,8 @@ import ( "regexp" "strconv" "strings" + + "github.com/coder/coder/v2/aibridge" ) type providerHint struct { @@ -83,6 +85,7 @@ var ( } genericRetryablePatterns = []string{"server error", "internal server error"} interruptedPatterns = []string{"chat interrupted", "request interrupted", "operation interrupted"} + providerDisabledPatterns = []string{aibridge.ErrorCodeProviderDisabled} ) func extractStatusCode(lower string) int { diff --git a/coderd/x/chatd/chatloop/chatloop.go b/coderd/x/chatd/chatloop/chatloop.go index c67e2ee6d0..7a81dc4d6e 100644 --- a/coderd/x/chatd/chatloop/chatloop.go +++ b/coderd/x/chatd/chatloop/chatloop.go @@ -39,10 +39,11 @@ const ( // prevents infinite compaction loops when the model keeps // hitting the context limit after summarization. maxCompactionRetries = 3 - // defaultStartupTimeout bounds how long an individual - // model attempt may spend starting to respond before + // defaultStreamSilenceTimeout bounds how long an individual + // model attempt may go without receiving a stream part before // the attempt is canceled and retried. - defaultStartupTimeout = 60 * time.Second + defaultStreamSilenceTimeout = 10 * time.Minute + streamSilenceGuardTimerTag = "streamSilenceGuard" ) var ( @@ -53,8 +54,8 @@ var ( // the run should terminate cleanly after persistence. ErrStopAfterTool = xerrors.New("stop after tool") - errStartupTimeout = xerrors.New( - "chat response did not start before the startup timeout", + errStreamSilenceTimeout = xerrors.New( + "chat stream was silent for longer than the configured timeout", ) ) @@ -114,14 +115,14 @@ type RunOptions struct { Messages []fantasy.Message Tools []fantasy.AgentTool MaxSteps int - // StartupTimeout bounds how long each model attempt may - // spend opening the provider stream and waiting for its - // first stream part before the attempt is canceled and - // retried. Zero uses the production default. - StartupTimeout time.Duration - // Clock creates startup guard timers. In production use a - // real clock; tests can inject quartz.NewMock(t) to make - // startup timeout behavior deterministic. + // StreamSilenceTimeout bounds how long each model attempt + // may go without receiving a stream part before the + // attempt is canceled and retried. Zero uses the + // production default. + StreamSilenceTimeout time.Duration + // Clock creates stream silence guard timers. In production + // use a real clock; tests can inject quartz.NewMock(t) to + // make timeout behavior deterministic. Clock quartz.Clock ActiveTools []string @@ -364,8 +365,8 @@ func Run(ctx context.Context, opts RunOptions) error { if opts.MaxSteps <= 0 { opts.MaxSteps = 1 } - if opts.StartupTimeout <= 0 { - opts.StartupTimeout = defaultStartupTimeout + if opts.StreamSilenceTimeout <= 0 { + opts.StreamSilenceTimeout = defaultStreamSilenceTimeout } if opts.Clock == nil { opts.Clock = quartz.NewReal() @@ -468,7 +469,7 @@ func Run(ctx context.Context, opts RunOptions) error { provider, modelName, opts.Clock, - opts.StartupTimeout, + opts.StreamSilenceTimeout, func(attemptCtx context.Context) (fantasy.StreamResponse, error) { return opts.Model.Stream(attemptCtx, call) }, @@ -782,9 +783,9 @@ func prepareMessagesForRequest( return canonical, prompt, nil } -// guardedAttempt owns an attempt-scoped context and startup guard +// guardedAttempt owns an attempt-scoped context and silence guard // around a provider stream. release is idempotent and frees the -// attempt-scoped timer/context. finish canonicalizes startup timeout +// attempt-scoped timer/context. finish canonicalizes silence timeout // errors before the retry loop classifies them. type guardedAttempt struct { ctx context.Context @@ -793,47 +794,77 @@ type guardedAttempt struct { finish func(error) error } -// startupGuard arbitrates whether an attempt times out during -// stream startup. Exactly one outcome wins: the timer cancels -// the attempt, or the first-part path disarms the timer. -type startupGuard struct { - timer *quartz.Timer - cancel context.CancelCauseFunc - once sync.Once +// streamSilenceGuard arbitrates whether an attempt times out while +// waiting for the next stream part. Exactly one outcome wins: the +// timer cancels the attempt, or release disarms the timer. +type streamSilenceGuard struct { + mu sync.Mutex + timer *quartz.Timer + cancel context.CancelCauseFunc + timeout time.Duration + settled bool } -func newStartupGuard( +func newStreamSilenceGuard( clock quartz.Clock, timeout time.Duration, cancel context.CancelCauseFunc, -) *startupGuard { - guard := &startupGuard{cancel: cancel} - guard.timer = clock.AfterFunc(timeout, guard.onTimeout, "startupGuard") +) *streamSilenceGuard { + guard := &streamSilenceGuard{ + cancel: cancel, + timeout: timeout, + } + guard.timer = clock.AfterFunc( + timeout, + guard.onTimeout, + streamSilenceGuardTimerTag, + ) return guard } -func (g *startupGuard) onTimeout() { - g.once.Do(func() { - g.cancel(errStartupTimeout) - }) +func (g *streamSilenceGuard) settle() bool { + g.mu.Lock() + defer g.mu.Unlock() + if g.settled { + return false + } + g.settled = true + return true } -func (g *startupGuard) Disarm() { - g.once.Do(func() { - g.timer.Stop() - }) +func (g *streamSilenceGuard) onTimeout() { + if !g.settle() { + return + } + g.cancel(errStreamSilenceTimeout) } -func classifyStartupTimeout( +func (g *streamSilenceGuard) Reset() { + g.mu.Lock() + defer g.mu.Unlock() + if g.settled { + return + } + g.timer.Reset(g.timeout, streamSilenceGuardTimerTag) +} + +func (g *streamSilenceGuard) Disarm() { + if !g.settle() { + return + } + g.timer.Stop() +} + +func classifyStreamSilenceTimeout( attemptCtx context.Context, provider string, err error, ) error { - if !errors.Is(context.Cause(attemptCtx), errStartupTimeout) { + if !errors.Is(context.Cause(attemptCtx), errStreamSilenceTimeout) { return err } if err == nil { - err = errStartupTimeout + err = errStreamSilenceTimeout } return chaterror.WithClassification(err, chaterror.ClassifiedError{ Kind: codersdk.ChatErrorKindStartupTimeout, @@ -851,7 +882,7 @@ func guardedStream( metrics *Metrics, ) (guardedAttempt, error) { attemptCtx, cancelAttempt := context.WithCancelCause(parent) - guard := newStartupGuard(clock, timeout, cancelAttempt) + guard := newStreamSilenceGuard(clock, timeout, cancelAttempt) var releaseOnce sync.Once release := func() { releaseOnce.Do(func() { @@ -863,7 +894,7 @@ func guardedStream( streamStart := clock.Now() stream, err := openStream(attemptCtx) if err != nil { - err = classifyStartupTimeout(attemptCtx, provider, err) + err = classifyStreamSilenceTimeout(attemptCtx, provider, err) release() return guardedAttempt{}, err } @@ -877,7 +908,7 @@ func guardedStream( ctx: attemptCtx, stream: fantasy.StreamResponse(func(yield func(fantasy.StreamPart) bool) { for part := range stream { - guard.Disarm() + guard.Reset() recordTTFT() if !yield(part) { return @@ -886,7 +917,7 @@ func guardedStream( }), release: release, finish: func(err error) error { - return classifyStartupTimeout(attemptCtx, provider, err) + return classifyStreamSilenceTimeout(attemptCtx, provider, err) }, }, nil } diff --git a/coderd/x/chatd/chatloop/chatloop_run_internal_test.go b/coderd/x/chatd/chatloop/chatloop_run_internal_test.go index a7a2079d79..64b1d8f97c 100644 --- a/coderd/x/chatd/chatloop/chatloop_run_internal_test.go +++ b/coderd/x/chatd/chatloop/chatloop_run_internal_test.go @@ -581,13 +581,13 @@ func TestRun_OnRetryEnrichesProvider(t *testing.T) { ) } -func TestStartupGuard_DisarmAndFireRace(t *testing.T) { +func TestStreamSilenceGuard_DisarmAndFireRace(t *testing.T) { t.Parallel() for range 128 { var cancels atomic.Int32 - guard := newStartupGuard(quartz.NewReal(), time.Hour, func(err error) { - if errors.Is(err, errStartupTimeout) { + guard := newStreamSilenceGuard(quartz.NewReal(), time.Hour, func(err error) { + if errors.Is(err, errStreamSilenceTimeout) { cancels.Add(1) } }) @@ -618,17 +618,17 @@ func TestStartupGuard_DisarmAndFireRace(t *testing.T) { } } -func TestStartupGuard_DisarmPreservesPermanentError(t *testing.T) { +func TestStreamSilenceGuard_DisarmPreservesPermanentError(t *testing.T) { t.Parallel() attemptCtx, cancelAttempt := context.WithCancelCause(context.Background()) defer cancelAttempt(nil) - guard := newStartupGuard(quartz.NewReal(), time.Hour, cancelAttempt) + guard := newStreamSilenceGuard(quartz.NewReal(), time.Hour, cancelAttempt) guard.Disarm() guard.onTimeout() - classified := chaterror.Classify(classifyStartupTimeout( + classified := chaterror.Classify(classifyStreamSilenceTimeout( attemptCtx, "openai", xerrors.New("invalid model"), @@ -638,10 +638,10 @@ func TestStartupGuard_DisarmPreservesPermanentError(t *testing.T) { require.Nil(t, context.Cause(attemptCtx)) } -func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) { +func TestRun_RetriesSilenceTimeoutWhileOpeningStream(t *testing.T) { t.Parallel() - const startupTimeout = 5 * time.Millisecond + const silenceTimeout = 5 * time.Millisecond ctx, cancel := context.WithTimeout( context.Background(), @@ -650,7 +650,7 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) { defer cancel() mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc("startupGuard") + trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) defer trap.Close() attempts := 0 @@ -675,10 +675,10 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) { done := make(chan error, 1) go func() { done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StartupTimeout: startupTimeout, - Clock: mClock, + Model: model, + MaxSteps: 1, + StreamSilenceTimeout: silenceTimeout, + Clock: mClock, PersistStep: func(_ context.Context, _ PersistedStep) error { return nil }, @@ -694,7 +694,7 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) { }() trap.MustWait(ctx).MustRelease(ctx) - mClock.Advance(startupTimeout).MustWait(ctx) + mClock.Advance(silenceTimeout).MustWait(ctx) trap.MustWait(ctx).MustRelease(ctx) require.NoError(t, awaitRunResult(ctx, t, done)) @@ -710,9 +710,9 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) { ) select { case cause := <-attemptCause: - require.ErrorIs(t, cause, errStartupTimeout) + require.ErrorIs(t, cause, errStreamSilenceTimeout) case <-ctx.Done(): - t.Fatal("timed out waiting for startup timeout cause") + t.Fatal("timed out waiting for silence timeout cause") } } @@ -728,7 +728,7 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) { t.Run(provider, func(t *testing.T) { t.Parallel() - const startupTimeout = 5 * time.Millisecond + const silenceTimeout = 5 * time.Millisecond ctx, cancel := context.WithTimeout( context.Background(), @@ -737,7 +737,7 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) { defer cancel() mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc("startupGuard") + trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) defer trap.Close() attempts := 0 @@ -763,10 +763,10 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) { done := make(chan error, 1) go func() { done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StartupTimeout: startupTimeout, - Clock: mClock, + Model: model, + MaxSteps: 1, + StreamSilenceTimeout: silenceTimeout, + Clock: mClock, PersistStep: func(_ context.Context, _ PersistedStep) error { return nil }, @@ -795,10 +795,10 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) { } } -func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) { +func TestRun_RetriesSilenceTimeoutBeforeFirstPart(t *testing.T) { t.Parallel() - const startupTimeout = 5 * time.Millisecond + const silenceTimeout = 5 * time.Millisecond ctx, cancel := context.WithTimeout( context.Background(), @@ -807,7 +807,7 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) { defer cancel() mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc("startupGuard") + trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) defer trap.Close() attempts := 0 @@ -837,10 +837,10 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) { done := make(chan error, 1) go func() { done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StartupTimeout: startupTimeout, - Clock: mClock, + Model: model, + MaxSteps: 1, + StreamSilenceTimeout: silenceTimeout, + Clock: mClock, PersistStep: func(_ context.Context, _ PersistedStep) error { return nil }, @@ -856,7 +856,7 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) { }() trap.MustWait(ctx).MustRelease(ctx) - mClock.Advance(startupTimeout).MustWait(ctx) + mClock.Advance(silenceTimeout).MustWait(ctx) trap.MustWait(ctx).MustRelease(ctx) require.NoError(t, awaitRunResult(ctx, t, done)) @@ -872,16 +872,16 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) { ) select { case cause := <-attemptCause: - require.ErrorIs(t, cause, errStartupTimeout) + require.ErrorIs(t, cause, errStreamSilenceTimeout) case <-ctx.Done(): - t.Fatal("timed out waiting for startup timeout cause") + t.Fatal("timed out waiting for silence timeout cause") } } -func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) { +func TestRun_StreamPartsResetSilenceTimeout(t *testing.T) { t.Parallel() - const startupTimeout = 5 * time.Millisecond + const silenceTimeout = 5 * time.Millisecond ctx, cancel := context.WithTimeout( context.Background(), @@ -890,12 +890,17 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) { defer cancel() mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc("startupGuard") + armTrap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) + defer armTrap.Close() + resetTrap := mClock.Trap().TimerReset(streamSilenceGuardTimerTag) + defer resetTrap.Close() attempts := 0 retried := false firstPartYielded := make(chan struct{}, 1) - continueStream := make(chan struct{}) + secondPartYielded := make(chan struct{}, 1) + continueToSecond := make(chan struct{}) + continueToFinish := make(chan struct{}) model := &chattest.FakeModel{ ProviderName: "openai", StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { @@ -910,7 +915,29 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) { } select { - case <-continueStream: + case <-continueToSecond: + case <-ctx.Done(): + _ = yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: ctx.Err(), + }) + return + } + + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextDelta, + ID: "text-1", + Delta: "done", + }) { + return + } + select { + case secondPartYielded <- struct{}{}: + default: + } + + select { + case <-continueToFinish: case <-ctx.Done(): _ = yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, @@ -920,7 +947,6 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) { } parts := []fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, } @@ -936,10 +962,10 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) { done := make(chan error, 1) go func() { done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StartupTimeout: startupTimeout, - Clock: mClock, + Model: model, + MaxSteps: 1, + StreamSilenceTimeout: silenceTimeout, + Clock: mClock, PersistStep: func(_ context.Context, _ PersistedStep) error { return nil }, @@ -954,23 +980,130 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) { }) }() - trap.MustWait(ctx).MustRelease(ctx) - trap.Close() - + armTrap.MustWait(ctx).MustRelease(ctx) + resetTrap.MustWait(ctx).MustRelease(ctx) select { case <-firstPartYielded: case <-ctx.Done(): t.Fatal("timed out waiting for first stream part") } - mClock.Advance(startupTimeout).MustWait(ctx) - close(continueStream) + mClock.Advance(silenceTimeout / 2).MustWait(ctx) + close(continueToSecond) + resetTrap.MustWait(ctx).MustRelease(ctx) + select { + case <-secondPartYielded: + case <-ctx.Done(): + t.Fatal("timed out waiting for second stream part") + } + + mClock.Advance(silenceTimeout / 2).MustWait(ctx) + close(continueToFinish) + resetTrap.MustWait(ctx).MustRelease(ctx) + resetTrap.MustWait(ctx).MustRelease(ctx) require.NoError(t, awaitRunResult(ctx, t, done)) require.Equal(t, 1, attempts) require.False(t, retried) } +func TestRun_RetriesSilenceTimeoutBetweenParts(t *testing.T) { + t.Parallel() + + const silenceTimeout = 5 * time.Millisecond + + ctx, cancel := context.WithTimeout( + context.Background(), + testutil.WaitLong, + ) + defer cancel() + + mClock := quartz.NewMock(t) + armTrap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) + defer armTrap.Close() + resetTrap := mClock.Trap().TimerReset(streamSilenceGuardTimerTag) + defer resetTrap.Close() + + attempts := 0 + firstPartYielded := make(chan struct{}, 1) + attemptCause := make(chan error, 1) + var retries []chatretry.ClassifiedError + model := &chattest.FakeModel{ + ProviderName: "openai", + StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + attempts++ + if attempts == 1 { + return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) { + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}) { + return + } + select { + case firstPartYielded <- struct{}{}: + default: + } + + <-ctx.Done() + attemptCause <- context.Cause(ctx) + _ = yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: ctx.Err(), + }) + }), nil + } + return streamFromParts([]fantasy.StreamPart{{ + Type: fantasy.StreamPartTypeFinish, + FinishReason: fantasy.FinishReasonStop, + }}), nil + }, + } + + done := make(chan error, 1) + go func() { + done <- Run(context.Background(), RunOptions{ + Model: model, + MaxSteps: 1, + StreamSilenceTimeout: silenceTimeout, + Clock: mClock, + PersistStep: func(_ context.Context, _ PersistedStep) error { + return nil + }, + OnRetry: func( + _ int, + _ error, + classified chatretry.ClassifiedError, + _ time.Duration, + ) { + retries = append(retries, classified) + }, + }) + }() + + armTrap.MustWait(ctx).MustRelease(ctx) + resetTrap.MustWait(ctx).MustRelease(ctx) + select { + case <-firstPartYielded: + case <-ctx.Done(): + t.Fatal("timed out waiting for first stream part") + } + + mClock.Advance(silenceTimeout).MustWait(ctx) + armTrap.MustWait(ctx).MustRelease(ctx) + resetTrap.MustWait(ctx).MustRelease(ctx) + + require.NoError(t, awaitRunResult(ctx, t, done)) + require.Equal(t, 2, attempts) + require.Len(t, retries, 1) + require.Equal(t, codersdk.ChatErrorKindStartupTimeout, retries[0].Kind) + require.True(t, retries[0].Retryable) + require.Equal(t, "openai", retries[0].Provider) + select { + case cause := <-attemptCause: + require.ErrorIs(t, cause, errStreamSilenceTimeout) + case <-ctx.Done(): + t.Fatal("timed out waiting for silence timeout cause") + } +} + func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) { t.Parallel() @@ -1014,10 +1147,10 @@ func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) { t.Fatal("expected Run to panic") } -func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) { +func TestRun_RetriesSilenceTimeoutWhenStreamStaysSilent(t *testing.T) { t.Parallel() - const startupTimeout = 5 * time.Millisecond + const silenceTimeout = 5 * time.Millisecond ctx, cancel := context.WithTimeout( context.Background(), @@ -1026,7 +1159,7 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) { defer cancel() mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc("startupGuard") + trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) defer trap.Close() attempts := 0 @@ -1052,10 +1185,10 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) { done := make(chan error, 1) go func() { done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StartupTimeout: startupTimeout, - Clock: mClock, + Model: model, + MaxSteps: 1, + StreamSilenceTimeout: silenceTimeout, + Clock: mClock, PersistStep: func(_ context.Context, _ PersistedStep) error { return nil }, @@ -1071,7 +1204,7 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) { }() trap.MustWait(ctx).MustRelease(ctx) - mClock.Advance(startupTimeout).MustWait(ctx) + mClock.Advance(silenceTimeout).MustWait(ctx) trap.MustWait(ctx).MustRelease(ctx) require.NoError(t, awaitRunResult(ctx, t, done)) @@ -1087,9 +1220,9 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) { ) select { case cause := <-attemptCause: - require.ErrorIs(t, cause, errStartupTimeout) + require.ErrorIs(t, cause, errStreamSilenceTimeout) case <-ctx.Done(): - t.Fatal("timed out waiting for startup timeout cause") + t.Fatal("timed out waiting for silence timeout cause") } } diff --git a/coderd/x/chatd/chatloop/compaction.go b/coderd/x/chatd/chatloop/compaction.go index 503eff51bc..b267f17e2a 100644 --- a/coderd/x/chatd/chatloop/compaction.go +++ b/coderd/x/chatd/chatloop/compaction.go @@ -35,14 +35,22 @@ const ( "- Key decisions made and their rationale\n" + "- Concrete technical details: file paths, function names, " + "commands, APIs, and configurations\n" + - "- Errors encountered and how they were resolved\n" + + "- Errors encountered and how they were resolved. Keep error " + + "notes specific: name the file, the error, and the fix. Do not " + + "generalize from a specific failure to a blanket tool-avoidance " + + "rule (e.g. \"tool X is unreliable\" or \"always use Y instead " + + "of Z\")\n" + "- Current state of the work: what is DONE, what is IN PROGRESS, " + "and what REMAINS to be done\n" + "- The specific action the assistant was performing or about to " + "perform when this summary was triggered\n\n" + "Be dense and factual. Every sentence should convey essential " + "context for continuation. Do not include pleasantries or " + - "conversational filler." + "conversational filler. For content that can be reproduced " + + "(repo files, command output, API responses), reference how to " + + "obtain it (file path, command, URL) rather than inlining the " + + "full content. Include brief inline summaries when the content " + + "itself would exceed a few lines." defaultCompactionSystemSummaryPrefix = "The following is a summary of " + "the earlier conversation. The assistant was actively working when " + "the context was compacted. Continue the work described below:" diff --git a/coderd/x/chatd/chatloop/metrics_test.go b/coderd/x/chatd/chatloop/metrics_test.go index 7aa3885750..c0c86deacc 100644 --- a/coderd/x/chatd/chatloop/metrics_test.go +++ b/coderd/x/chatd/chatloop/metrics_test.go @@ -296,6 +296,7 @@ func TestRecordStreamRetry(t *testing.T) { {name: "startup_timeout", kind: codersdk.ChatErrorKindStartupTimeout}, {name: "auth", kind: codersdk.ChatErrorKindAuth}, {name: "config", kind: codersdk.ChatErrorKindConfig}, + {name: "missing_key", kind: codersdk.ChatErrorKindMissingKey}, {name: "generic", kind: codersdk.ChatErrorKindGeneric}, {name: "chain_broken", kind: codersdk.ChatErrorKindGeneric, chainBroken: true}, } diff --git a/coderd/x/chatd/chatprovider/chatprovider.go b/coderd/x/chatd/chatprovider/chatprovider.go index 768aa5774e..fec0840b08 100644 --- a/coderd/x/chatd/chatprovider/chatprovider.go +++ b/coderd/x/chatd/chatprovider/chatprovider.go @@ -1243,6 +1243,7 @@ func ModelFromConfig( } providerClient, err = fantasyopenai.New(options...) case fantasyopenaicompat.Name: + httpClient = withOpenAICompatRequestPatches(httpClient, baseURL, modelID) options := []fantasyopenaicompat.Option{ fantasyopenaicompat.WithAPIKey(apiKey), fantasyopenaicompat.WithUserAgent(userAgent), diff --git a/coderd/x/chatd/chatprovider/openai_compat_patches.go b/coderd/x/chatd/chatprovider/openai_compat_patches.go new file mode 100644 index 0000000000..beac165bb5 --- /dev/null +++ b/coderd/x/chatd/chatprovider/openai_compat_patches.go @@ -0,0 +1,237 @@ +package chatprovider + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/url" + "strings" +) + +// OpenAI-compatible providers share an API shape but differ in the exact JSON +// they accept. These patches adjust Fantasy's serialized request body at the +// transport boundary so higher-level generation code can stay provider agnostic. +// +// googleOpenAICompatDummyThoughtSignature is Google's documented last-resort +// bypass for callers that cannot preserve a real Gemini thought signature. +// See https://ai.google.dev/gemini-api/docs/thought-signatures. +const googleOpenAICompatDummyThoughtSignature = "skip_thought_signature_validator" + +func withOpenAICompatRequestPatches( + client *http.Client, + baseURL string, + modelID string, +) *http.Client { + if client == nil { + client = &http.Client{} + } else { + clone := *client + client = &clone + } + client.Transport = &openAICompatRequestPatchTransport{ + Base: client.Transport, + BaseURL: baseURL, + ModelID: modelID, + } + return client +} + +type openAICompatRequestPatchTransport struct { + Base http.RoundTripper + // BaseURL is the configured provider base URL, used to detect direct Gemini endpoints. + BaseURL string + // ModelID is the configured model ID, used to detect Gemini routes through Coder AI Bridge. + ModelID string +} + +func (t *openAICompatRequestPatchTransport) RoundTrip(req *http.Request) (*http.Response, error) { + base := t.base() + if !shouldPatchOpenAICompatRequest(req) { + return base.RoundTrip(req) + } + + body, err := io.ReadAll(req.Body) + closeErr := req.Body.Close() + if err != nil { + return nil, err + } + if closeErr != nil { + return nil, closeErr + } + + patched := patchOpenAICompatChatCompletionsBody(body, t.BaseURL, t.ModelID) + patchedReq := req.Clone(req.Context()) + patchedReq.Body = io.NopCloser(bytes.NewReader(patched)) + patchedReq.ContentLength = int64(len(patched)) + patchedReq.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(patched)), nil + } + + return base.RoundTrip(patchedReq) +} + +func (t *openAICompatRequestPatchTransport) base() http.RoundTripper { + if t.Base != nil { + return t.Base + } + return http.DefaultTransport +} + +func shouldPatchOpenAICompatRequest(req *http.Request) bool { + return req != nil && + req.Method == http.MethodPost && + req.Body != nil && + strings.HasSuffix(req.URL.Path, "/chat/completions") +} + +func patchOpenAICompatChatCompletionsBody(body []byte, baseURL string, modelID string) []byte { + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return body + } + + changed := rewriteOpenAICompatSingleToolChoice(payload) + if shouldAddGoogleOpenAICompatThoughtSignatures(baseURL, modelID) { + changed = addGoogleOpenAICompatThoughtSignatures(payload) || changed + } + if !changed { + return body + } + + patched, err := json.Marshal(payload) + if err != nil { + return body + } + return patched +} + +// rewriteOpenAICompatSingleToolChoice replaces a single named tool choice with +// "required" because some compatible endpoints reject the named object form. +func rewriteOpenAICompatSingleToolChoice(payload map[string]any) bool { + tools, ok := payload["tools"].([]any) + if !ok || len(tools) != 1 { + return false + } + tool, ok := tools[0].(map[string]any) + if !ok { + return false + } + function, ok := tool["function"].(map[string]any) + if !ok { + return false + } + toolName, _ := function["name"].(string) + if toolName == "" { + return false + } + + toolChoice, ok := payload["tool_choice"].(map[string]any) + if !ok { + return false + } + if toolType, _ := toolChoice["type"].(string); toolType != "function" { + return false + } + choiceFunction, ok := toolChoice["function"].(map[string]any) + if !ok { + return false + } + choiceName, _ := choiceFunction["name"].(string) + if choiceName != toolName { + return false + } + + payload["tool_choice"] = "required" + return true +} + +// shouldAddGoogleOpenAICompatThoughtSignatures detects direct Gemini OpenAI +// endpoints and Coder AI Bridge Gemini routes. Other gateways, such as Vercel, +// keep their own provider-specific compatibility behavior. +func shouldAddGoogleOpenAICompatThoughtSignatures(baseURL string, modelID string) bool { + parsed, err := url.Parse(baseURL) + if err != nil { + return false + } + host := strings.ToLower(parsed.Hostname()) + path := strings.ToLower(parsed.EscapedPath()) + if host == "generativelanguage.googleapis.com" && strings.Contains(path, "/openai") { + return true + } + return host == "coder-aibridge" && isGeminiModelID(modelID) +} + +func isGeminiModelID(modelID string) bool { + modelID = strings.ToLower(strings.TrimSpace(modelID)) + return strings.HasPrefix(modelID, "gemini-") || strings.Contains(modelID, "/gemini-") +} + +// addGoogleOpenAICompatThoughtSignatures adds a dummy thought signature to the +// first tool call on each assistant tool-call message in the latest user turn. +// Gemini validates tool-call history with thought signatures, but +// OpenAI-compatible serialization can drop the original provider metadata. +func addGoogleOpenAICompatThoughtSignatures(payload map[string]any) bool { + messages, ok := payload["messages"].([]any) + if !ok { + return false + } + + currentTurnStart := -1 + for i, raw := range messages { + message, ok := raw.(map[string]any) + if !ok { + continue + } + if role, _ := message["role"].(string); role == "user" { + currentTurnStart = i + } + } + + if currentTurnStart == -1 { + return false + } + + changed := false + for _, raw := range messages[currentTurnStart+1:] { + message, ok := raw.(map[string]any) + if !ok || !isOpenAICompatAssistantRole(message["role"]) { + continue + } + toolCalls, ok := message["tool_calls"].([]any) + if !ok || len(toolCalls) == 0 { + continue + } + firstToolCall, ok := toolCalls[0].(map[string]any) + if !ok { + continue + } + if ensureGoogleOpenAICompatThoughtSignature(firstToolCall) { + changed = true + } + } + return changed +} + +func isOpenAICompatAssistantRole(role any) bool { + roleValue, _ := role.(string) + return roleValue == "assistant" || roleValue == "model" +} + +func ensureGoogleOpenAICompatThoughtSignature(toolCall map[string]any) bool { + extraContent, _ := toolCall["extra_content"].(map[string]any) + google, _ := extraContent["google"].(map[string]any) + if signature, _ := google["thought_signature"].(string); signature != "" { + return false + } + if extraContent == nil { + extraContent = map[string]any{} + toolCall["extra_content"] = extraContent + } + if google == nil { + google = map[string]any{} + extraContent["google"] = google + } + google["thought_signature"] = googleOpenAICompatDummyThoughtSignature + return true +} diff --git a/coderd/x/chatd/chatprovider/openai_compat_patches_internal_test.go b/coderd/x/chatd/chatprovider/openai_compat_patches_internal_test.go new file mode 100644 index 0000000000..eace6c4173 --- /dev/null +++ b/coderd/x/chatd/chatprovider/openai_compat_patches_internal_test.go @@ -0,0 +1,156 @@ +//nolint:testpackage // These tests cover unexported request-patch guards. +package chatprovider + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPatchOpenAICompatChatCompletionsBody_Guards(t *testing.T) { + t.Parallel() + + t.Run("leaves multi tool specific choice unchanged", func(t *testing.T) { + t.Parallel() + + payload := map[string]any{ + "tools": []any{ + functionTool("first_tool"), + functionTool("second_tool"), + }, + "tool_choice": map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "first_tool", + }, + }, + } + + patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "http://example.com/v1", "test-model") + body := decodeJSONMap(t, patched) + toolChoice, ok := body["tool_choice"].(map[string]any) + require.True(t, ok) + function, ok := toolChoice["function"].(map[string]any) + require.True(t, ok) + require.Equal(t, "first_tool", function["name"]) + }) + + t.Run("leaves string tool choice unchanged", func(t *testing.T) { + t.Parallel() + + payload := map[string]any{ + "tools": []any{functionTool("first_tool")}, + "tool_choice": "auto", + } + + patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "http://example.com/v1", "test-model") + body := decodeJSONMap(t, patched) + require.Equal(t, "auto", body["tool_choice"]) + }) + + t.Run("leaves Gemini assistant history without a user turn unchanged", func(t *testing.T) { + t.Parallel() + + payload := map[string]any{ + "messages": []any{ + map[string]any{ + "role": "assistant", + "tool_calls": []any{ + functionToolCall("call_without_user", "history_tool"), + }, + }, + }, + } + + patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "https://generativelanguage.googleapis.com/v1beta/openai/", "gemini-3.5-flash") + body := decodeJSONMap(t, patched) + messages := body["messages"].([]any) + require.Empty(t, googleThoughtSignature(t, messages[0], 0)) + }) + + t.Run("preserves existing Gemini thought signature", func(t *testing.T) { + t.Parallel() + + payload := map[string]any{ + "messages": []any{ + map[string]any{"role": "user", "content": "current turn"}, + map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": "call_with_signature", + "type": "function", + "function": map[string]any{ + "name": "signed_tool", + "arguments": `{}`, + }, + "extra_content": map[string]any{ + "google": map[string]any{ + "thought_signature": "real-signature", + }, + }, + }, + }, + }, + }, + } + + patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "https://generativelanguage.googleapis.com/v1beta/openai/", "gemini-3.5-flash") + body := decodeJSONMap(t, patched) + messages := body["messages"].([]any) + require.Equal(t, "real-signature", googleThoughtSignature(t, messages[1], 0)) + }) +} + +func functionTool(name string) map[string]any { + return map[string]any{ + "type": "function", + "function": map[string]any{ + "name": name, + }, + } +} + +func functionToolCall(id string, name string) map[string]any { + return map[string]any{ + "id": id, + "type": "function", + "function": map[string]any{ + "name": name, + "arguments": `{}`, + }, + } +} + +func mustJSON(t *testing.T, payload map[string]any) []byte { + t.Helper() + + body, err := json.Marshal(payload) + require.NoError(t, err) + return body +} + +func decodeJSONMap(t *testing.T, body []byte) map[string]any { + t.Helper() + + var payload map[string]any + require.NoError(t, json.Unmarshal(body, &payload)) + return payload +} + +func googleThoughtSignature(t *testing.T, rawMessage any, toolCallIndex int) string { + t.Helper() + + message, ok := rawMessage.(map[string]any) + require.True(t, ok) + toolCalls, ok := message["tool_calls"].([]any) + require.True(t, ok) + require.Greater(t, len(toolCalls), toolCallIndex) + toolCall, ok := toolCalls[toolCallIndex].(map[string]any) + require.True(t, ok) + extraContent, _ := toolCall["extra_content"].(map[string]any) + google, _ := extraContent["google"].(map[string]any) + signature, _ := google["thought_signature"].(string) + return signature +} diff --git a/coderd/x/chatd/chatprovider/openai_compat_patches_test.go b/coderd/x/chatd/chatprovider/openai_compat_patches_test.go new file mode 100644 index 0000000000..c6042c0c63 --- /dev/null +++ b/coderd/x/chatd/chatprovider/openai_compat_patches_test.go @@ -0,0 +1,186 @@ +package chatprovider_test + +import ( + "encoding/json" + "io" + "net/http" + "strings" + "testing" + + "charm.land/fantasy" + fantasyopenaicompat "charm.land/fantasy/providers/openaicompat" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" +) + +const dummyThoughtSignature = "skip_thought_signature_validator" + +func TestModelFromConfig_GeminiOpenAICompatThoughtSignatures(t *testing.T) { + t.Parallel() + + t.Run("Gemini endpoint receives current turn thought signature", func(t *testing.T) { + t.Parallel() + + body := generateOpenAICompatRequest(t, "https://generativelanguage.googleapis.com/v1beta/openai/", "gemini-3.5-flash") + messages := body["messages"].([]any) + + require.Empty(t, thoughtSignature(t, messages[1], 0)) + require.Equal(t, dummyThoughtSignature, thoughtSignature(t, messages[4], 0)) + require.Empty(t, thoughtSignature(t, messages[4], 1)) + require.Equal(t, dummyThoughtSignature, thoughtSignature(t, messages[6], 0)) + }) + + t.Run("Coder AI Bridge Gemini route receives current turn thought signature", func(t *testing.T) { + t.Parallel() + + body := generateOpenAICompatRequest(t, "http://coder-aibridge/v1", "gemini-3.5-flash") + messages := body["messages"].([]any) + + require.Equal(t, dummyThoughtSignature, thoughtSignature(t, messages[4], 0)) + }) + + t.Run("Vercel OpenAI-compatible Gemini route is unchanged", func(t *testing.T) { + t.Parallel() + + body := generateOpenAICompatRequest(t, "https://gateway.vercel.ai/v1", "google/gemini-3.5-flash") + messages := body["messages"].([]any) + + require.Empty(t, thoughtSignature(t, messages[4], 0)) + }) +} + +func generateOpenAICompatRequest(t *testing.T, baseURL string, modelID string) map[string]any { + t.Helper() + + transport := &captureChatCompletionTransport{} + model, err := chatprovider.ModelFromConfig( + fantasyopenaicompat.Name, + modelID, + chatprovider.ProviderAPIKeys{ + ByProvider: map[string]string{ + fantasyopenaicompat.Name: "test-key", + }, + BaseURLByProvider: map[string]string{ + fantasyopenaicompat.Name: baseURL, + }, + }, + chatprovider.UserAgent(), + nil, + &http.Client{Transport: transport}, + ) + require.NoError(t, err) + + _, err = model.Generate(t.Context(), fantasy.Call{ + Prompt: geminiOpenAICompatToolPrompt(), + }) + require.NoError(t, err) + require.NotNil(t, transport.body) + return transport.body +} + +type captureChatCompletionTransport struct { + body map[string]any +} + +func (ct *captureChatCompletionTransport) RoundTrip(req *http.Request) (*http.Response, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + _ = req.Body.Close() + if strings.HasSuffix(req.URL.Path, "/chat/completions") { + ct.body = map[string]any{} + if err := json.Unmarshal(body, &ct.body); err != nil { + return nil, err + } + } + + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + }, + Body: io.NopCloser(strings.NewReader(`{ + "id":"chatcmpl-test", + "object":"chat.completion", + "created":0, + "model":"gemini-3.5-flash", + "choices":[{"index":0,"message":{"role":"assistant","content":"done"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2} + }`)), + }, nil +} + +func geminiOpenAICompatToolPrompt() []fantasy.Message { + return []fantasy.Message{ + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "previous turn"}, + }, + }, + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.ToolCallPart{ToolCallID: "previous-call", ToolName: "previous_tool", Input: `{}`}, + }, + }, + { + Role: fantasy.MessageRoleTool, + Content: []fantasy.MessagePart{ + fantasy.ToolResultPart{ + ToolCallID: "previous-call", + Output: fantasy.ToolResultOutputContentText{Text: `{}`}, + }, + }, + }, + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "current turn"}, + }, + }, + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.ToolCallPart{ToolCallID: "current-call-a", ToolName: "first_tool", Input: `{}`}, + fantasy.ToolCallPart{ToolCallID: "current-call-b", ToolName: "parallel_tool", Input: `{}`}, + }, + }, + { + Role: fantasy.MessageRoleTool, + Content: []fantasy.MessagePart{ + fantasy.ToolResultPart{ + ToolCallID: "current-call-a", + Output: fantasy.ToolResultOutputContentText{Text: `{}`}, + }, + }, + }, + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.ToolCallPart{ + ToolCallID: "current-call-c", + ToolName: "second_step_tool", + Input: `{}`, + }, + }, + }, + } +} + +func thoughtSignature(t *testing.T, rawMessage any, toolCallIndex int) string { + t.Helper() + message, ok := rawMessage.(map[string]any) + require.True(t, ok) + toolCalls, ok := message["tool_calls"].([]any) + require.True(t, ok) + require.Greater(t, len(toolCalls), toolCallIndex) + toolCall, ok := toolCalls[toolCallIndex].(map[string]any) + require.True(t, ok) + extraContent, _ := toolCall["extra_content"].(map[string]any) + google, _ := extraContent["google"].(map[string]any) + signature, _ := google["thought_signature"].(string) + return signature +} diff --git a/coderd/x/chatd/model_routing_aibridge.go b/coderd/x/chatd/model_routing_aibridge.go index 5db1a16e53..07e8fd66b0 100644 --- a/coderd/x/chatd/model_routing_aibridge.go +++ b/coderd/x/chatd/model_routing_aibridge.go @@ -16,7 +16,9 @@ import ( "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" + "github.com/coder/coder/v2/coderd/x/chatd/chaterror" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/codersdk" ) const ( @@ -98,7 +100,14 @@ func (p *Server) newAIGatewayModel( return nil, xerrors.New("AI Gateway routing requires an AI provider name") } if opts.ActiveAPIKeyID == "" { - return nil, xerrors.New("AI Gateway routing requires the active turn API key ID") + return nil, chaterror.WithClassification( + xerrors.New("AI Gateway routing requires the active turn API key ID"), + chaterror.ClassifiedError{ + Kind: codersdk.ChatErrorKindMissingKey, + Retryable: false, + Detail: "If this error persists after resending, please report it as a bug.", + }, + ) } factoryPtr := p.aibridgeTransportFactory diff --git a/coderd/x/chatd/model_routing_internal_test.go b/coderd/x/chatd/model_routing_internal_test.go index 786365d9fb..0d2f317204 100644 --- a/coderd/x/chatd/model_routing_internal_test.go +++ b/coderd/x/chatd/model_routing_internal_test.go @@ -20,8 +20,10 @@ import ( "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/x/chatd/chaterror" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" "github.com/coder/coder/v2/coderd/x/chatd/chattool" + "github.com/coder/coder/v2/codersdk" ) type aibridgeTestFactory struct { @@ -530,6 +532,11 @@ func TestAIBridgeRoutingFailClosed(t *testing.T) { } _, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(aiProvider), modelBuildOptions{}) require.ErrorContains(t, err, "active turn API key ID") + + classified := chaterror.Classify(err) + require.Equal(t, codersdk.ChatErrorKindMissingKey, classified.Kind, + "production path must return a pre-classified missing_key error") + require.False(t, classified.Retryable) }) t.Run("StaticModel", func(t *testing.T) { diff --git a/coderd/x/chatd/quickgen_internal_test.go b/coderd/x/chatd/quickgen_internal_test.go index 09fc8001ab..0e46ccc0f7 100644 --- a/coderd/x/chatd/quickgen_internal_test.go +++ b/coderd/x/chatd/quickgen_internal_test.go @@ -3,11 +3,14 @@ package chatd import ( "context" "encoding/json" + "net/http" + "net/http/httptest" "strings" "testing" "time" "charm.land/fantasy" + fantasyopenaicompat "charm.land/fantasy/providers/openaicompat" "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/require" @@ -667,6 +670,100 @@ func TestFallbackTurnStatusLabel(t *testing.T) { } } +func TestGenerateStructuredTitleWithUsage_OpenAICompatibleRequiredToolChoice(t *testing.T) { + t.Parallel() + + server, requests := newOpenAICompatStructuredOutputServer(t, "propose_title", `{"title":"Failed workspace logs"}`) + model := openAICompatTestModel(t, server.URL) + + title, _, err := generateStructuredTitleWithUsage( + t.Context(), + model, + titleGenerationPrompt, + "summarize failed workspace build logs", + ) + require.NoError(t, err) + require.Equal(t, "Failed workspace logs", title) + + body := testutil.TryReceive(t.Context(), t, requests) + require.Equal(t, "required", body["tool_choice"]) +} + +func newOpenAICompatStructuredOutputServer( + t *testing.T, + toolName string, + arguments string, +) (*httptest.Server, <-chan map[string]any) { + t.Helper() + + requests := make(chan map[string]any, 10) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + requests <- body + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "chatcmpl-structured-output", + "object": "chat.completion", + "created": time.Now().Unix(), + "model": "anthropic/claude-4-5-sonnet", + "choices": []map[string]any{ + { + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": "", + "tool_calls": []map[string]any{ + { + "id": "call_structured_output", + "type": "function", + "function": map[string]any{ + "name": toolName, + "arguments": arguments, + }, + }, + }, + }, + "finish_reason": "tool_calls", + }, + }, + "usage": map[string]any{ + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + }) + })) + t.Cleanup(server.Close) + return server, requests +} + +func openAICompatTestModel(t *testing.T, baseURL string) fantasy.LanguageModel { + t.Helper() + + model, err := chatprovider.ModelFromConfig( + fantasyopenaicompat.Name, + "anthropic/claude-4-5-sonnet", + chatprovider.ProviderAPIKeys{ + ByProvider: map[string]string{ + fantasyopenaicompat.Name: "test-key", + }, + BaseURLByProvider: map[string]string{ + fantasyopenaicompat.Name: baseURL, + }, + }, + chatprovider.UserAgent(), + nil, + nil, + ) + require.NoError(t, err) + return model +} + func TestGenerateStructuredTurnStatusLabel(t *testing.T) { t.Parallel() @@ -682,11 +779,26 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) { }, } - label, err := generateStructuredTurnStatusLabel(context.Background(), model, turnStatusLabelPrompt, "done") + label, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, "done") require.NoError(t, err) require.Equal(t, "Submitted PR", label) }) + t.Run("sends required tool_choice to openai-compatible provider", func(t *testing.T) { + t.Parallel() + + server, requests := newOpenAICompatStructuredOutputServer(t, "propose_turn_status_label", `{"label":"Submitted PR"}`) + model := openAICompatTestModel(t, server.URL) + + label, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, "done") + require.NoError(t, err) + require.Equal(t, "Submitted PR", label) + require.Len(t, requests, 1) + + body := testutil.TryReceive(t.Context(), t, requests) + require.Equal(t, "required", body["tool_choice"]) + }) + t.Run("rejects narrative label", func(t *testing.T) { t.Parallel() @@ -698,7 +810,7 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) { }, } - _, err := generateStructuredTurnStatusLabel(context.Background(), model, turnStatusLabelPrompt, "done") + _, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, "done") require.ErrorContains(t, err, "generated turn status label was invalid") }) @@ -706,7 +818,7 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) { t.Parallel() model := &chattest.FakeModel{} - _, err := generateStructuredTurnStatusLabel(context.Background(), model, turnStatusLabelPrompt, " ") + _, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, " ") require.ErrorContains(t, err, "turn status label input was empty") }) } diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 8984ac86cd..450397416b 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -1090,16 +1090,18 @@ func (p *Server) createChildSubagentChatWithOptions( return xerrors.Errorf("update child injected context: %w", err) } - userParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + userParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage. ChatID: insertedChat.ID, } - appendChatMessage(&userParams, newChatMessage( - database.ChatMessageRoleUser, + childUserMsg := newUserChatMessage( + childAPIKeyID, userContent, database.ChatMessageVisibilityBoth, modelConfigID, chatprompt.CurrentContentVersion, - ).withCreatedBy(parent.OwnerID).withAPIKeyID(childAPIKeyID)) + ) + childUserMsg = childUserMsg.withCreatedBy(parent.OwnerID) + appendUserChatMessage(&userParams, childUserMsg) if _, err := tx.InsertChatMessages(ctx, userParams); err != nil { return xerrors.Errorf("insert initial child user message: %w", err) } @@ -1176,16 +1178,27 @@ func copyParentContextMessages( return nil, xerrors.Errorf("marshal filtered context parts: %w", err) } - msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by append[User]ChatMessage. ChatID: child.ID, } - appendChatMessage(&msgParams, newChatMessage( - copiedRole, - filteredContent, - copiedVisibility, - child.LastModelConfigID, - copiedVersion, - )) + if copiedRole == database.ChatMessageRoleUser { + copiedAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx) + appendUserChatMessage(&msgParams, newUserChatMessage( + copiedAPIKeyID, + filteredContent, + copiedVisibility, + child.LastModelConfigID, + copiedVersion, + )) + } else { + appendChatMessage(&msgParams, newChatMessage( + copiedRole, + filteredContent, + copiedVisibility, + child.LastModelConfigID, + copiedVersion, + )) + } if _, err := store.InsertChatMessages(ctx, msgParams); err != nil { return nil, xerrors.Errorf("insert context message: %w", err) } diff --git a/coderd/x/chatd/subagent_context_internal_test.go b/coderd/x/chatd/subagent_context_internal_test.go index dc60e3330f..5ccab312d6 100644 --- a/coderd/x/chatd/subagent_context_internal_test.go +++ b/coderd/x/chatd/subagent_context_internal_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" @@ -446,10 +447,33 @@ func TestCreateChildSubagentChatUpdatesInheritedLastInjectedContext(t *testing.T ctx := chatdTestContext(t) parentChat := createParentChatWithInheritedContext(ctx, t, db, server) + // Set a delegated API key so that copied user-role context messages + // are stamped with api_key_id, preserving AI Gateway routing. + apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: parentChat.OwnerID}) + ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID) + child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "") require.NoError(t, err) assertChildInheritedContext(ctx, t, db, child.ID, "inspect bindings") + + // Verify that all user-role messages in the child chat carry + // api_key_id so activeTurnAPIKeyIDFromMessages resolves correctly. + childMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: child.ID, + AfterID: 0, + }) + require.NoError(t, err) + var userMsgCount int + for _, msg := range childMessages { + if msg.Role != database.ChatMessageRoleUser { + continue + } + userMsgCount++ + require.True(t, msg.APIKeyID.Valid, "child user message (id=%d) should have api_key_id set", msg.ID) + require.Equal(t, apiKey.ID, msg.APIKeyID.String, "child user message (id=%d) api_key_id mismatch", msg.ID) + } + require.Greater(t, userMsgCount, 0, "expected at least one user-role message in child chat") } func TestSpawnComputerUseAgentInheritsContext(t *testing.T) { diff --git a/coderd/x/nats/cluster.go b/coderd/x/nats/cluster.go new file mode 100644 index 0000000000..7b0fd1ab80 --- /dev/null +++ b/coderd/x/nats/cluster.go @@ -0,0 +1,148 @@ +package nats + +import ( + "net" + "net/url" + "slices" + "strconv" + "strings" + + "golang.org/x/xerrors" +) + +// SetPeerAddresses replaces the configured NATS cluster peer routes. +func (p *Pubsub) SetPeerAddresses(addresses []string) error { + p.clusterMu.Lock() + defer p.clusterMu.Unlock() + + if p.ctx.Err() != nil { + return errClosed + } + if !p.clustered { + return xerrors.New("nats pubsub was not started with clustering enabled") + } + + routes, err := parsePeerAddresses(addresses) + if err != nil { + return err + } + + self := &url.URL{Scheme: "nats", Host: p.ns.ClusterAddr().String()} + routes = filterSelfRoutes(routes, self) + routes = sortRouteURLs(routes) + + if sortedURLsEqual(p.currentRoutes, routes) { + return nil + } + + newOpts := p.serverOpts.Clone() + newOpts.Routes = cloneRouteURLs(routes) + if err := p.ns.ReloadOptions(newOpts); err != nil { + return xerrors.Errorf("reload nats peer addresses: %w", err) + } + p.serverOpts = newOpts.Clone() + p.currentRoutes = cloneRouteURLs(routes) + return nil +} + +func parsePeerAddresses(addresses []string) ([]*url.URL, error) { + routesByAddress := make(map[string]*url.URL, len(addresses)) + for i, address := range addresses { + trimmed := strings.TrimSpace(address) + if trimmed == "" { + return nil, xerrors.Errorf("peer address %d is empty", i) + } + + normalizedHost, err := normalizeHostPort(trimmed) + if err != nil { + return nil, err + } + + routesByAddress[normalizedHost] = &url.URL{ + Scheme: "nats", + Host: normalizedHost, + } + } + + routes := make([]*url.URL, 0, len(routesByAddress)) + for _, route := range routesByAddress { + routes = append(routes, route) + } + return routes, nil +} + +func filterSelfRoutes(routes []*url.URL, self *url.URL) []*url.URL { + filtered := make([]*url.URL, 0, len(routes)) + for _, route := range routes { + if route.String() == self.String() { + continue + } + filtered = append(filtered, route) + } + return filtered +} + +func normalizeHostPort(address string) (string, error) { + route, err := url.Parse(address) + if err != nil { + return "", xerrors.Errorf("parse peer address %q: %w", address, err) + } + if route.User != nil { + return "", xerrors.Errorf("peer address %q must not include userinfo", address) + } + if route.Path != "" || route.RawQuery != "" || route.Fragment != "" { + return "", xerrors.Errorf("peer address %q must not include path, query, or fragment", address) + } + + host, port, err := net.SplitHostPort(route.Host) + if err != nil { + return "", xerrors.Errorf("split %q host port: %w", address, err) + } + if host == "" || port == "" { + return "", xerrors.Errorf("%q must include host and port", address) + } + + portNumber, err := strconv.Atoi(port) + if err != nil { + return "", xerrors.Errorf("parse %q port: %w", address, err) + } + if portNumber <= 0 || portNumber > 65535 { + return "", xerrors.Errorf("peer address %q must include a valid port", address) + } + return net.JoinHostPort(host, strconv.Itoa(portNumber)), nil +} + +func sortRouteURLs(routes []*url.URL) []*url.URL { + slices.SortFunc(routes, func(a, b *url.URL) int { + return strings.Compare(a.String(), b.String()) + }) + return routes +} + +// sortedURLsEqual assumes sorted slices. +func sortedURLsEqual(a, b []*url.URL) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].String() != b[i].String() { + return false + } + } + return true +} + +func cloneRouteURLs(routes []*url.URL) []*url.URL { + if routes == nil { + return nil + } + clones := make([]*url.URL, len(routes)) + for i, route := range routes { + if route == nil { + continue + } + clone := *route + clones[i] = &clone + } + return clones +} diff --git a/coderd/x/nats/cluster_internal_test.go b/coderd/x/nats/cluster_internal_test.go new file mode 100644 index 0000000000..eadf2e561f --- /dev/null +++ b/coderd/x/nats/cluster_internal_test.go @@ -0,0 +1,134 @@ +package nats + +import ( + "errors" + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_parsePeerAddresses(t *testing.T) { + t.Parallel() + + t.Run("Valid", func(t *testing.T) { + t.Parallel() + routes, err := parsePeerAddresses([]string{ + "whatever://127.0.0.1:4222 ", + "http://[::1]:7222", + "nats://example.com:6222", + }) + require.NoError(t, err) + require.ElementsMatch(t, []string{ + "nats://127.0.0.1:4222", + "nats://[::1]:7222", + "nats://example.com:6222", + }, routeStrings(routes)) + }) + + t.Run("Empty", func(t *testing.T) { + t.Parallel() + routes, err := parsePeerAddresses(nil) + require.NoError(t, err) + require.Empty(t, routes) + }) + + t.Run("Dedupes", func(t *testing.T) { + t.Parallel() + routes, err := parsePeerAddresses([]string{ + "nats://b.example:6222", + "nats://a.example:6222", + "nats://b.example:6222", + }) + require.NoError(t, err) + require.ElementsMatch(t, []string{ + "nats://a.example:6222", + "nats://b.example:6222", + }, routeStrings(routes)) + }) + + t.Run("Invalid", func(t *testing.T) { + t.Parallel() + for _, address := range []string{ + "", + " ", + "127.0.0.1:4222", + "127.0.0.1", + ":4222", + "127.0.0.1:0", + "127.0.0.1:bad", + "nats://127.0.0.1", + "nats://:4222", + "nats://127.0.0.1:0", + "nats://127.0.0.1:bad", + "nats://user@127.0.0.1:4222", + "nats://127.0.0.1:4222/path", + "nats://127.0.0.1:4222?x=1", + "nats://127.0.0.1:4222#frag", + } { + t.Run(address, func(t *testing.T) { + t.Parallel() + _, err := parsePeerAddresses([]string{address}) + require.Error(t, err) + }) + } + }) +} + +func Test_filterSelfRoutes(t *testing.T) { + t.Parallel() + + routes, err := parsePeerAddresses([]string{ + "nats://b.example:6222", + "http://self.example:6222", + }) + require.NoError(t, err) + + routes = filterSelfRoutes(routes, &url.URL{Scheme: "nats", Host: "self.example:6222"}) + require.Equal(t, []string{"nats://b.example:6222"}, routeStrings(routes)) +} + +// Cluster tests bind free ports and reload shared route state. +func TestPubsub_SetPeerAddresses(t *testing.T) { + t.Parallel() + t.Run("OK", func(t *testing.T) { + t.Parallel() + a := newTestPubsub(t, clusterTestOptions(t)) + b := newTestPubsub(t, clusterTestOptions(t)) + c := newTestPubsub(t, clusterTestOptions(t)) + + addrB := clusterRouteAddress(t, b) + addrC := clusterRouteAddress(t, c) + require.NoError(t, a.SetPeerAddresses([]string{addrC, addrB})) + requireRoutesEqual(t, a.currentRoutes, addrB, addrC) + + require.NoError(t, a.SetPeerAddresses([]string{addrB, addrC})) + requireRoutesEqual(t, a.currentRoutes, addrB, addrC) + + require.NoError(t, a.SetPeerAddresses(nil)) + require.Empty(t, a.currentRoutes) + require.Empty(t, a.serverOpts.Routes) + }) + + t.Run("StandaloneConfigError", func(t *testing.T) { + t.Parallel() + ps := newTestPubsub(t, defaultTestOptions()) + err := ps.SetPeerAddresses(nil) + require.ErrorContains(t, err, "not started with clustering enabled") + }) + + t.Run("Closed", func(t *testing.T) { + t.Parallel() + ps := newTestPubsub(t, clusterTestOptions(t)) + require.NoError(t, ps.Close()) + err := ps.SetPeerAddresses(nil) + require.True(t, errors.Is(err, errClosed), "got %v", err) + }) + + t.Run("DropsSelfRoute", func(t *testing.T) { + t.Parallel() + ps := newTestPubsub(t, clusterTestOptions(t)) + require.NoError(t, ps.SetPeerAddresses([]string{clusterRouteAddress(t, ps)})) + require.Empty(t, ps.currentRoutes) + }) +} diff --git a/coderd/x/nats/pubsub.go b/coderd/x/nats/pubsub.go new file mode 100644 index 0000000000..a41247ed09 --- /dev/null +++ b/coderd/x/nats/pubsub.go @@ -0,0 +1,693 @@ +package nats + +import ( + "context" + "errors" + "fmt" + "hash/fnv" + "net/url" + "sync" + "time" + + natsserver "github.com/nats-io/nats-server/v2/server" + natsgo "github.com/nats-io/nats.go" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database/pubsub" +) + +// DefaultMaxPending is the per-client outbound pending byte budget. +const DefaultMaxPending int64 = 128 << 20 + +const ( + defaultClusterName = "coder" + defaultClusterPort = 6222 + defaultRoutePoolSize = 3 +) + +var errClosed = xerrors.New("nats pubsub closed") + +// PendingLimits configures per-subscription NATS pending limits set +// via SetPendingLimits on each *natsgo.Subscription. +type PendingLimits struct { + // Msgs is the per-subscription pending message limit. Positive + // values also set each local listener queue capacity. + // Zero uses the package default. Negative disables this limit. + Msgs int + + // Bytes is the per-subscription pending byte limit. + // Zero uses the package default. Negative disables this limit. + Bytes int +} + +// Options configures the embedded NATS Pubsub. +type Options struct { + // MaxPayload is the NATS max payload. Zero means server default. + MaxPayload int32 + + // MaxPending is the per-client outbound pending byte budget on the + // embedded server. Zero or negative means the package default, + // 128 MiB. + MaxPending int64 + + // PendingLimits configures per-subscription NATS pending limits. + // Positive Msgs also sets local listener queue capacity. + // Zero fields use package defaults: Msgs -1 and Bytes 512 MiB. + PendingLimits PendingLimits + + // ReconnectWait controls client reconnect delay. Zero keeps the + // NATS default. + ReconnectWait time.Duration + + // InProcess, when true, uses nats.InProcessServer instead of TCP + // loopback. Intended for benchmarks and tests. + InProcess bool + + // PublishConns is the number of publisher connections. Each Publish + // is routed by a stable hash of the subject. Zero or negative means 1. + PublishConns int + + // SubscribeConns is the number of subscriber connections. Each + // shared subscription is pinned to one connection by a stable hash + // of its subject. Zero or negative means 1. + SubscribeConns int + + // ClusterHost is the embedded NATS route listener host. Empty means + // all interfaces when cluster mode is enabled. + ClusterHost string + + // ClusterPort is the embedded NATS route listener port. Zero means + // 6222 when cluster mode is enabled. + ClusterPort int + + // RoutePoolSize is the NATS route pool size. Zero means the package + // default when cluster mode is enabled. + RoutePoolSize int + + // disableCluster is intended only for testing. Since we cannot reload a server + // with a cluster host/port after initialization, we start all production servers + // with clustering enabled. + disableCluster bool +} + +// Pubsub is an embedded NATS-backed implementation of pubsub.Pubsub. +// +// Each Pubsub owns one embedded server, a pool of publisher +// *natsgo.Conns (Options.PublishConns) and a pool of subscriber +// *natsgo.Conns (Options.SubscribeConns). Publishes and shared +// subscriptions are pinned to a connection by a stable hash of the +// subject, so same-subject traffic preserves per-subject ordering and +// every local subscriber for a subject coalesces onto one underlying +// *natsgo.Subscription. +type Pubsub struct { + mu sync.Mutex + + logger slog.Logger + opts Options + + ns *natsserver.Server + // publishPool and subscribePool are immutable after construction so + // the hot path can index without holding p.mu. + publishPool []*natsgo.Conn + subscribePool []*natsgo.Conn + + // subscriptions coalesces concurrent local subscribers on the + // same subject onto a single underlying *natsgo.Subscription. + subscriptions map[string]*natsSub + closeOnce sync.Once + + // ctx is canceled by Close while holding p.mu so subscriber state + // cleanup observes the canceled context. + ctx context.Context + cancel context.CancelFunc + + clusterMu sync.Mutex + clustered bool + serverOpts *natsserver.Options + currentRoutes []*url.URL +} + +// natsSub maps to one underlying *natsgo.Subscription. The first +// local subscriber creates it; later local subscribers attach to it. +// When the last local subscriber detaches, the NATS subscription is +// unsubscribed. +type natsSub struct { + // sub is set before this natsSub is published in Pubsub.subscriptions + // and is immutable after that. + sub *natsgo.Subscription + + // mu guards localSubs. + mu sync.Mutex + // localSubs are the local subscribers attached to this NATS subscription. + localSubs map[*localSub]struct{} + + // dropMu keeps async error accounting independent from listener fan-out. + dropMu sync.Mutex + // lastDropped is the cumulative NATS dropped count last reported locally. + lastDropped uint64 +} + +// localSub is the local handle returned by Subscribe / +// SubscribeWithErr. Each local subscriber gets its own bounded inbox +// and dispatcher goroutine so one slow listener cannot block peers on +// the same subject. +type localSub struct { + cancelOnce sync.Once + + ctx context.Context + + event string + listener pubsub.ListenerWithErr + + // queue is the per-listener data fan-out inbox. The shared NATS + // callback enqueues non-blockingly; on overflow the message is + // dropped and a drop signal is raised. + queue chan []byte + // dropSignal is a size-1 buffered channel that coalesces drop + // notifications from local overflow and NATS slow-consumer + // broadcasts onto a single pending wake. + dropSignal chan struct{} + cancel context.CancelFunc +} + +// Compile-time assertion that *Pubsub satisfies the pubsub.Pubsub interface. +var _ pubsub.Pubsub = (*Pubsub)(nil) + +// newPubsub allocates a *Pubsub with initialized maps and cancel ctx. +func newPubsub(ctx context.Context, logger slog.Logger, opts Options) *Pubsub { + ctx, cancel := context.WithCancel(ctx) + return &Pubsub{ + logger: logger, + opts: opts, + subscriptions: make(map[string]*natsSub), + ctx: ctx, + cancel: cancel, + } +} + +// defaultPendingLimits returns the effective per-subscription pending +// limits applied at Subscribe time. +func defaultPendingLimits(in PendingLimits) PendingLimits { + out := in + if out.Msgs == 0 { + out.Msgs = -1 + } + if out.Bytes == 0 { + out.Bytes = 512 * 1024 * 1024 + } + return out +} + +// buildConnHandlers returns the connHandlers stack installed on every +// owned connection. Handlers close over p so slow-consumer routing +// keeps working. +func (p *Pubsub) buildConnHandlers() connHandlers { + return connHandlers{ + disconnectErr: func(conn *natsgo.Conn, err error) { + if err != nil { + p.logger.Warn(p.ctx, "nats client disconnected", slog.Error(err)) + } + p.signalSubscribersDroppedForConn(conn) + }, + reconnect: func(_ *natsgo.Conn) { + p.logger.Info(p.ctx, "nats client reconnected") + }, + closed: func(_ *natsgo.Conn) { + p.logger.Debug(p.ctx, "nats client closed") + }, + errH: func(_ *natsgo.Conn, sub *natsgo.Subscription, err error) { + if err != nil && errors.Is(err, natsgo.ErrSlowConsumer) { + p.handleAsyncError(sub, err) + return + } + if err != nil { + p.logger.Warn(p.ctx, "nats async error", slog.Error(err)) + } + }, + } +} + +// New creates an embedded NATS Pubsub. The returned *Pubsub owns the +// embedded server and the publisher and subscriber connection pools. +// Close shuts down all owned resources. +func New(ctx context.Context, logger slog.Logger, opts Options) (*Pubsub, error) { + sopts, err := buildServerOptions(opts) + if err != nil { + return nil, err + } + + ns, err := startEmbeddedServer(sopts) + if err != nil { + return nil, err + } + + logger.Info(context.Background(), "embedded nats server started", + slog.F("client_url", ns.ClientURL()), + ) + + p := newPubsub(ctx, logger, opts) + p.ns = ns + p.clustered = !opts.disableCluster + p.serverOpts = sopts.Clone() + p.currentRoutes = cloneRouteURLs(sopts.Routes) + handlers := p.buildConnHandlers() + + publishPool, err := newConnPool(ns, opts, handlers, opts.PublishConns, "coder-pubsub-pub") + if err != nil { + p.cancel() + ns.Shutdown() + ns.WaitForShutdown() + return nil, err + } + subscribePool, err := newConnPool(ns, opts, handlers, opts.SubscribeConns, "coder-pubsub-sub") + if err != nil { + p.cancel() + for _, c := range publishPool { + c.Close() + } + ns.Shutdown() + ns.WaitForShutdown() + return nil, err + } + p.publishPool = publishPool + p.subscribePool = subscribePool + go func() { + <-p.ctx.Done() + _ = p.Close() + }() + return p, nil +} + +func newConnPool(ns *natsserver.Server, opts Options, handlers connHandlers, count int, clientName string) ([]*natsgo.Conn, error) { + if count <= 0 { + count = 1 + } + pool := make([]*natsgo.Conn, 0, count) + for i := 0; i < count; i++ { + // Suffix names when the pool has more than one entry so server + // logs can distinguish connections. + name := clientName + if count > 1 { + name = fmt.Sprintf("%s-%d", clientName, i) + } + nc, err := connectClient(ns, opts, handlers, name) + if err != nil { + for _, c := range pool { + c.Close() + } + return nil, xerrors.Errorf("dial conn: %w", err) + } + pool = append(pool, nc) + } + return pool, nil +} + +// Publish publishes a message under the given event name. The +// publisher connection is selected by a stable hash of the subject so +// same-subject publishes preserve per-subject ordering. +func (p *Pubsub) Publish(event string, message []byte) error { + if p.ctx.Err() != nil { + return errClosed + } + + if err := pickConn(p.publishPool, event).Publish(event, message); err != nil { + return xerrors.Errorf("publish: %w", err) + } + return nil +} + +// Flush blocks until every publisher connection has flushed buffered +// publishes to the embedded server. Returns the first error +// encountered; remaining connections are still flushed. +func (p *Pubsub) Flush() error { + if p.ctx.Err() != nil { + return errClosed + } + + var firstErr error + for i, nc := range p.publishPool { + if err := nc.Flush(); err != nil && firstErr == nil { + firstErr = xerrors.Errorf("flush pub conn %d: %w", i, err) + } + } + return firstErr +} + +// Subscribe subscribes a Listener to the given event name. Errors +// such as ErrDroppedMessages are silently ignored, mirroring the +// legacy pubsub Listener semantics. +func (p *Pubsub) Subscribe(event string, listener pubsub.Listener) (cancel func(), err error) { + return p.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) { + if err != nil { + return + } + listener(ctx, msg) + }) +} + +// SubscribeWithErr subscribes a ListenerWithErr to the given event +// name. The listener also receives error deliveries such as +// pubsub.ErrDroppedMessages. Multiple local subscribers on the same +// event share a single underlying *natsgo.Subscription with +// per-listener bounded inboxes so a slow listener cannot block its +// peers. +func (p *Pubsub) SubscribeWithErr(event string, listener pubsub.ListenerWithErr) (cancel func(), err error) { + s, err := p.addSubscriber(event, listener) + if err != nil { + return nil, err + } + + cancelFn := func() { + s.close() + p.unsubscribeLocal(s) + } + return cancelFn, nil +} + +// listenerQueueSize returns the per-listener inbox capacity. A +// positive PendingLimits.Msgs sets the cap (giving callers a knob to +// trigger local-overflow drops since coalescing makes NATS-level +// slow-consumer signals rare). Otherwise the default is used. +func listenerQueueSize(in PendingLimits) int { + if in.Msgs > 0 { + return in.Msgs + } + return defaultListenerQueueSize +} + +const defaultListenerQueueSize = 1024 + +// addSubscriber creates a local subscriber and attaches it to the natsSub +// for event. New natsSub entries are published only after NATS setup succeeds. +func (p *Pubsub) addSubscriber(event string, listener pubsub.ListenerWithErr) (*localSub, error) { + ctx, cancel := context.WithCancel(p.ctx) + s := &localSub{ + ctx: ctx, + cancel: cancel, + event: event, + listener: listener, + queue: make(chan []byte, listenerQueueSize(p.opts.PendingLimits)), + dropSignal: make(chan struct{}, 1), + } + s.init() + + cleanupSub, err := func() (*natsgo.Subscription, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.ctx.Err() != nil { + return nil, errClosed + } + + nsub, ok := p.subscriptions[event] + if ok { + nsub.mu.Lock() + nsub.localSubs[s] = struct{}{} + nsub.mu.Unlock() + return nsub.sub, nil + } + + nsub = &natsSub{ + localSubs: map[*localSub]struct{}{ + s: {}, + }, + } + + subConn := pickConn(p.subscribePool, event) + natsSubscription, err := subConn.Subscribe(event, nsub.handleMessage) + if err != nil { + return nil, xerrors.Errorf("subscribe: %w", err) + } + nsub.sub = natsSubscription + + // Flush the SUB to the server so a publish issued immediately + // after Subscribe returns cannot race ahead of registration. + if err := subConn.Flush(); err != nil { + return natsSubscription, xerrors.Errorf("flush subscribe: %w", err) + } + limits := defaultPendingLimits(p.opts.PendingLimits) + if err := natsSubscription.SetPendingLimits(limits.Msgs, limits.Bytes); err != nil { + return natsSubscription, xerrors.Errorf("set pending limits: %w", err) + } + + p.subscriptions[event] = nsub + return natsSubscription, nil + }() + if err != nil { + s.close() + if cleanupSub != nil { + if unsubscribeErr := cleanupSub.Unsubscribe(); unsubscribeErr != nil { + err = errors.Join(err, xerrors.Errorf("unsubscribe: %w", unsubscribeErr)) + } + } + return nil, err + } + return s, nil +} + +// unsubscribeLocal removes s from its natsSub. If s was the last +// listener, it also removes and unsubscribes the underlying NATS +// subscription. +func (p *Pubsub) unsubscribeLocal(s *localSub) { + natsSub := func() *natsgo.Subscription { + p.mu.Lock() + defer p.mu.Unlock() + + nsub := p.subscriptions[s.event] + if nsub == nil { + return nil + } + + nsub.mu.Lock() + defer nsub.mu.Unlock() + if _, tracked := nsub.localSubs[s]; !tracked { + return nil + } + delete(nsub.localSubs, s) + if len(nsub.localSubs) > 0 { + return nil + } + // Last listener: remove the nsub entry so a new Subscribe to this + // subject creates a fresh underlying subscription. + delete(p.subscriptions, s.event) + return nsub.sub + }() + if natsSub != nil { + _ = natsSub.Unsubscribe() + } +} + +// handleMessage handles messages for the shared subscription. Each +// enqueue is non-blocking and does not call user code, so one slow +// listener cannot stall the NATS delivery goroutine. +// +// Zero-copy fan-out: the same msg.Data slice is delivered to every +// local listener without cloning. Listeners on a coalesced subject MUST +// treat the delivered bytes as immutable. +func (nsub *natsSub) handleMessage(msg *natsgo.Msg) { + nsub.mu.Lock() + defer nsub.mu.Unlock() + + for s := range nsub.localSubs { + s.enqueue(msg.Data) + } +} + +// init starts the per-listener delivery goroutine. +func (s *localSub) init() { + go func() { + for { + select { + case <-s.ctx.Done(): + return + case data := <-s.queue: + s.listener(s.ctx, data, nil) + case <-s.dropSignal: + s.listener(s.ctx, nil, pubsub.ErrDroppedMessages) + } + } + }() +} + +// close cancels local delivery without waiting for callbacks. +func (s *localSub) close() { + s.cancelOnce.Do(func() { + if s.cancel != nil { + s.cancel() + } + }) +} + +// enqueue non-blockingly sends data onto s.queue. On overflow it drops the +// message and raises a drop signal so pubsub.ErrDroppedMessages is surfaced. +// If s is canceled the message is silently dropped. +func (s *localSub) enqueue(data []byte) { + select { + case s.queue <- data: + default: + s.signalDrop() + } +} + +// signalDrop pushes onto dropSignal without blocking. Multiple drops +// between dispatcher dequeues coalesce into a single pending signal, so +// the listener observes one ErrDroppedMessages per drop wave. +func (s *localSub) signalDrop() { + select { + case s.dropSignal <- struct{}{}: + default: + } +} + +// signalSubscribersDroppedForConn signals local subscribers assigned to conn. +func (p *Pubsub) signalSubscribersDroppedForConn(conn *natsgo.Conn) { + if conn == nil || len(p.subscribePool) == 0 { + return + } + + p.mu.Lock() + subs := make([]*localSub, 0) + for event, nsub := range p.subscriptions { + if pickConn(p.subscribePool, event) != conn { + continue + } + nsub.mu.Lock() + for s := range nsub.localSubs { + subs = append(subs, s) + } + nsub.mu.Unlock() + } + p.mu.Unlock() + + for _, s := range subs { + s.signalDrop() + } +} + +// handleAsyncError routes async error callbacks. Only slow-consumer +// errors trigger drop accounting. +func (p *Pubsub) handleAsyncError(sub *natsgo.Subscription, err error) { + if sub == nil || !errors.Is(err, natsgo.ErrSlowConsumer) { + return + } + p.mu.Lock() + var nsub *natsSub + for _, candidate := range p.subscriptions { + if candidate.sub == sub { + nsub = candidate + break + } + } + p.mu.Unlock() + if nsub == nil { + return + } + p.handleSlowSubscriber(nsub) +} + +// handleSlowSubscriber broadcasts pubsub.ErrDroppedMessages to every +// local listener on nsub when NATS reports a new drop delta. The +// slow-consumer signal is per-subscription and cannot be narrowed to a +// single local listener. +func (p *Pubsub) handleSlowSubscriber(nsub *natsSub) { + nsub.dropMu.Lock() + dropped, err := nsub.sub.Dropped() + if err != nil { + nsub.dropMu.Unlock() + p.logger.Warn(p.ctx, "nats: query dropped count", slog.Error(err)) + return + } + if dropped < 0 { + nsub.dropMu.Unlock() + p.logger.Warn(p.ctx, "nats: negative dropped count") + return + } + // Dropped is cumulative per subscription; signal only new drops. + droppedCount := uint64(dropped) + if droppedCount < nsub.lastDropped { + nsub.lastDropped = droppedCount + nsub.dropMu.Unlock() + return + } + if droppedCount == nsub.lastDropped { + nsub.dropMu.Unlock() + return + } + nsub.lastDropped = droppedCount + nsub.dropMu.Unlock() + + nsub.mu.Lock() + defer nsub.mu.Unlock() + + for s := range nsub.localSubs { + s.signalDrop() + } +} + +// Close stops local delivery and shuts down the Pubsub. It is idempotent. +// Close does not drain queued listener messages. +func (p *Pubsub) Close() error { + p.closeOnce.Do(func() { + p.mu.Lock() + // Cancel while holding p.mu so subscriber state cleanup below + // observes the canceled context. + p.cancel() + var subs []*localSub + shareds := make([]*natsSub, 0, len(p.subscriptions)) + for _, ss := range p.subscriptions { + shareds = append(shareds, ss) + ss.mu.Lock() + for s := range ss.localSubs { + subs = append(subs, s) + delete(ss.localSubs, s) + } + ss.mu.Unlock() + } + clear(p.subscriptions) + p.mu.Unlock() + + // Unsubscribe shared subscriptions before closing connections. + for _, ss := range shareds { + if ss.sub != nil { + _ = ss.sub.Unsubscribe() + } + } + + // Signal per-listener goroutines without waiting for callbacks. + for _, s := range subs { + s.close() + } + + for _, nc := range p.subscribePool { + if nc != nil { + nc.Close() + } + } + for _, nc := range p.publishPool { + if nc != nil { + nc.Close() + } + } + + if p.ns != nil { + p.ns.Shutdown() + p.ns.WaitForShutdown() + } + }) + return nil +} + +// pickConn returns the connection assigned to subject. Selection uses +// a stable FNV-1a hash so same-subject traffic always targets the same +// connection within a process; pools are immutable after construction +// so the lookup is lock-free. +func pickConn(pool []*natsgo.Conn, subject string) *natsgo.Conn { + if len(pool) == 1 { + return pool[0] + } + h := fnv.New32a() + _, _ = h.Write([]byte(subject)) + n := uint32(len(pool)) //nolint:gosec // pool size bounded by Options.{Publish,Subscribe}Conns + return pool[h.Sum32()%n] +} diff --git a/coderd/x/nats/pubsub_internal_test.go b/coderd/x/nats/pubsub_internal_test.go new file mode 100644 index 0000000000..3b5263654e --- /dev/null +++ b/coderd/x/nats/pubsub_internal_test.go @@ -0,0 +1,492 @@ +package nats + +import ( + "context" + "errors" + "fmt" + "net/url" + "sync" + "sync/atomic" + "testing" + "time" + + natsserver "github.com/nats-io/nats-server/v2/server" + natsgo "github.com/nats-io/nats.go" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/testutil" +) + +func Test_defaultPendingLimits(t *testing.T) { + t.Parallel() + + const defaultBytes = 512 * 1024 * 1024 + testCases := []struct { + name string + in PendingLimits + want PendingLimits + }{ + { + name: "AllZero", + in: PendingLimits{}, + want: PendingLimits{Msgs: -1, Bytes: defaultBytes}, + }, + { + name: "MsgsOnly", + in: PendingLimits{Msgs: 8}, + want: PendingLimits{Msgs: 8, Bytes: defaultBytes}, + }, + { + name: "BytesOnly", + in: PendingLimits{Bytes: 1024}, + want: PendingLimits{Msgs: -1, Bytes: 1024}, + }, + { + name: "NegativeMsgs", + in: PendingLimits{Msgs: -2}, + want: PendingLimits{Msgs: -2, Bytes: defaultBytes}, + }, + { + name: "NegativeBytes", + in: PendingLimits{Bytes: -2}, + want: PendingLimits{Msgs: -1, Bytes: -2}, + }, + { + name: "NegativeBoth", + in: PendingLimits{Msgs: -2, Bytes: -3}, + want: PendingLimits{Msgs: -2, Bytes: -3}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tc.want, defaultPendingLimits(tc.in)) + }) + } +} + +func Test_pickConn(t *testing.T) { + t.Parallel() + + t.Run("DifferentSubjects", func(t *testing.T) { + t.Parallel() + var a, b natsgo.Conn + pool := []*natsgo.Conn{&a, &b} + + require.NotSame(t, pickConn(pool, "a"), pickConn(pool, "b")) + }) +} + +func subjectForConn(t *testing.T, pool []*natsgo.Conn, conn *natsgo.Conn, prefix string) string { + t.Helper() + + for i := range 10_000 { + subject := fmt.Sprintf("%s_%d", prefix, i) + if pickConn(pool, subject) == conn { + return subject + } + } + require.FailNow(t, "no subject matched requested connection") + return "" +} + +func Test_New(t *testing.T) { + t.Parallel() + + t.Run("ConnectionCount", func(t *testing.T) { + t.Parallel() + ps := newTestPubsub(t, defaultTestOptions()) + t.Cleanup(func() { _ = ps.Close() }) + + const n = 50 + cancels := make([]func(), 0, n) + for i := range n { + c, err := ps.Subscribe(fmt.Sprintf("cc_evt_%d", i), func(_ context.Context, _ []byte) {}) + require.NoError(t, err) + cancels = append(cancels, c) + } + t.Cleanup(func() { + for _, c := range cancels { + c() + } + }) + + require.Equal(t, 2, ps.ns.NumClients(), + "expected exactly 2 client connections (pubConn + subConn), got %d", ps.ns.NumClients()) + require.Len(t, ps.publishPool, 1, "default PublishConns must be 1") + require.Len(t, ps.subscribePool, 1, "default SubscribeConns must be 1") + require.NotSame(t, ps.publishPool[0], ps.subscribePool[0], "pubConn and subConn must be distinct") + }) +} + +func Test_SubscribeWithErr(t *testing.T) { + t.Parallel() + + t.Run("SameSubjectSharesSubscription", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil) + ctx := testutil.Context(t, testutil.WaitShort) + ps, err := New(ctx, logger, defaultTestOptions()) + require.NoError(t, err) + t.Cleanup(func() { _ = ps.Close() }) + + cancelA, err := ps.Subscribe("coalesce_evt", func(context.Context, []byte) {}) + require.NoError(t, err) + t.Cleanup(cancelA) + cancelB, err := ps.Subscribe("coalesce_evt", func(context.Context, []byte) {}) + require.NoError(t, err) + t.Cleanup(cancelB) + + ps.mu.Lock() + defer ps.mu.Unlock() + require.Len(t, ps.subscriptions, 1) + }) +} + +func Test_Pubsub_buildConnHandlers(t *testing.T) { + t.Parallel() + + t.Run("DisconnectSignalsDropsForMatchingSubscriberConn", func(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil) + ctx := testutil.Context(t, testutil.WaitShort) + ps := newPubsub(ctx, logger, defaultTestOptions()) + + var subConnA, subConnB, pubConn natsgo.Conn + ps.subscribePool = []*natsgo.Conn{&subConnA, &subConnB} + matchingEvent := subjectForConn(t, ps.subscribePool, &subConnA, "disconnect_match") + otherEvent := subjectForConn(t, ps.subscribePool, &subConnB, "disconnect_other") + + newLocal := func(event string) *localSub { + return &localSub{ + event: event, + dropSignal: make(chan struct{}, 1), + } + } + + matchingSub := newLocal(matchingEvent) + otherSub := newLocal(otherEvent) + ps.subscriptions[matchingSub.event] = &natsSub{localSubs: map[*localSub]struct{}{matchingSub: {}}} + ps.subscriptions[otherSub.event] = &natsSub{localSubs: map[*localSub]struct{}{otherSub: {}}} + + handlers := ps.buildConnHandlers() + handlers.disconnectErr(&subConnA, xerrors.New("disconnect")) + + select { + case <-matchingSub.dropSignal: + default: + require.Fail(t, "matching subscriber did not receive drop signal") + } + select { + case <-otherSub.dropSignal: + require.Fail(t, "non-matching subscriber received drop signal") + default: + } + + handlers.disconnectErr(&pubConn, xerrors.New("publisher disconnect")) + select { + case <-otherSub.dropSignal: + require.Fail(t, "publisher connection disconnect signaled subscriber") + default: + } + }) +} + +func Test_localSub_init(t *testing.T) { + t.Parallel() + + t.Run("SerializesCallbacks", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + dataStarted := make(chan struct{}) + dropDelivered := make(chan struct{}) + release := make(chan struct{}) + var dataOnce sync.Once + var dropOnce sync.Once + var releaseOnce sync.Once + var active atomic.Int64 + var concurrent atomic.Bool + + s := &localSub{ + ctx: ctx, + cancel: func() {}, + listener: func(_ context.Context, _ []byte, ferr error) { + if active.Add(1) != 1 { + concurrent.Store(true) + } + defer active.Add(-1) + + if errors.Is(ferr, pubsub.ErrDroppedMessages) { + dropOnce.Do(func() { close(dropDelivered) }) + return + } + + dataOnce.Do(func() { close(dataStarted) }) + <-release + }, + queue: make(chan []byte, 1), + dropSignal: make(chan struct{}, 1), + } + s.init() + t.Cleanup(func() { + releaseOnce.Do(func() { close(release) }) + s.close() + }) + + s.enqueue([]byte("data")) + require.Eventually(t, func() bool { + select { + case <-dataStarted: + return true + default: + return false + } + }, testutil.WaitShort, testutil.IntervalFast) + + s.signalDrop() + require.Never(t, func() bool { + select { + case <-dropDelivered: + return true + default: + return false + } + }, testutil.IntervalMedium, testutil.IntervalFast, + "drop callback must wait for the blocked data callback") + require.False(t, concurrent.Load(), "listener callback ran concurrently") + + releaseOnce.Do(func() { close(release) }) + require.Eventually(t, func() bool { + select { + case <-dropDelivered: + return true + default: + return false + } + }, testutil.WaitShort, testutil.IntervalFast) + require.False(t, concurrent.Load(), "listener callback ran concurrently") + }) + + t.Run("CrossSubjectListenerIsolation", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil) + ctx := testutil.Context(t, testutil.WaitLong) + ps, err := New(ctx, logger, defaultTestOptions()) + require.NoError(t, err) + t.Cleanup(func() { _ = ps.Close() }) + + release := make(chan struct{}) + var releaseOnce sync.Once + var slowDrops atomic.Int64 + var slowBlocked atomic.Bool + slowCancel, err := ps.SubscribeWithErr("iso_slow", func(_ context.Context, _ []byte, ferr error) { + if ferr != nil && errors.Is(ferr, pubsub.ErrDroppedMessages) { + slowDrops.Add(1) + return + } + if slowBlocked.CompareAndSwap(false, true) { + <-release + } + }) + require.NoError(t, err) + defer slowCancel() + + var fastCount atomic.Int64 + fastCancel, err := ps.Subscribe("iso_fast", func(_ context.Context, _ []byte) { + fastCount.Add(1) + }) + require.NoError(t, err) + defer fastCancel() + defer releaseOnce.Do(func() { close(release) }) + + total := defaultListenerQueueSize + 256 + payload := make([]byte, 4*1024) + for range total { + require.NoError(t, ps.Publish("iso_slow", payload)) + require.NoError(t, ps.Publish("iso_fast", []byte("ping"))) + } + require.NoError(t, ps.Flush()) + + require.Eventually(t, func() bool { + return fastCount.Load() >= int64(total) + }, testutil.WaitLong, testutil.IntervalFast) + require.Zero(t, slowDrops.Load(), + "drop callback must wait for the blocked data callback") + releaseOnce.Do(func() { close(release) }) + require.Eventually(t, func() bool { + return slowDrops.Load() >= 1 + }, testutil.WaitLong, testutil.IntervalFast, + "slow subscriber must receive at least one ErrDroppedMessages signal") + + require.GreaterOrEqual(t, fastCount.Load(), int64(total), + "fast subscriber must keep receiving despite slow peer on shared subConn") + require.Len(t, ps.subscribePool, 1) + require.False(t, ps.subscribePool[0].IsClosed(), "subConn must not be closed by slow consumer") + require.True(t, ps.subscribePool[0].IsConnected(), "subConn must stay connected") + require.Equal(t, 2, ps.ns.NumClients(), "slow consumer must not disconnect subConn") + }) +} + +func TestPubsubCluster(t *testing.T) { + t.Parallel() + + // OK verifies that SetPeerAddresses changes the active cluster topology. + // A starts connected to B, then C is added and receives both global and + // C-only messages. B is then removed from A's peers, while C continues to + // receive global and C-only messages. + t.Run("OK", func(t *testing.T) { + t.Parallel() + + a := newTestPubsub(t, clusterTestOptions(t)) + b := newTestPubsub(t, clusterTestOptions(t)) + c := newTestPubsub(t, clusterTestOptions(t)) + + addrB := clusterRouteAddress(t, b) + addrC := clusterRouteAddress(t, c) + + require.NoError(t, a.SetPeerAddresses([]string{addrB})) + requireRoutesEqual(t, a.currentRoutes, addrB) + + globalEvent := "global" + bGlobal := make(chan []byte, 8) + cancelBGlobal, err := b.Subscribe(globalEvent, func(_ context.Context, msg []byte) { + bGlobal <- msg + }) + require.NoError(t, err) + defer cancelBGlobal() + + waitForRouteSubscription(t, a, globalEvent) + publishAndFlush(t, a, globalEvent, "from-a-to-b") + require.Equal(t, "from-a-to-b", string(receiveMessage(t, bGlobal))) + + // Add C's subscriptions before adding C as an extra peer to A. + cGlobal := make(chan []byte, 8) + cancelCGlobal, err := c.Subscribe(globalEvent, func(_ context.Context, msg []byte) { + cGlobal <- msg + }) + require.NoError(t, err) + defer cancelCGlobal() + + cSubject := "c-only-subscriber" + cUnique := make(chan []byte, 8) + cancelCUnique, err := c.Subscribe(cSubject, func(_ context.Context, msg []byte) { + cUnique <- msg + }) + require.NoError(t, err) + defer cancelCUnique() + + // Add C to A's peer list. B and C should both receive global messages, + // while the C-only subject should route only to C. + require.NoError(t, a.SetPeerAddresses([]string{addrC, addrB})) + requireRoutesEqual(t, a.currentRoutes, addrB, addrC) + + waitForRouteSubscription(t, a, globalEvent) + waitForRouteSubscription(t, a, cSubject) + + publishAndFlush(t, a, globalEvent, "new-global-msg") + require.Equal(t, "new-global-msg", string(receiveMessage(t, bGlobal))) + require.Equal(t, "new-global-msg", string(receiveMessage(t, cGlobal))) + + publishAndFlush(t, a, cSubject, "c-unique-msg") + require.Equal(t, "c-unique-msg", string(receiveMessage(t, cUnique))) + + // Remove B from A's peer list. Only C should receive the next messages. + require.NoError(t, a.SetPeerAddresses([]string{addrC})) + requireRoutesEqual(t, a.currentRoutes, addrC) + + publishAndFlush(t, a, globalEvent, "no-b-peer") + require.Equal(t, "no-b-peer", string(receiveMessage(t, cGlobal))) + + publishAndFlush(t, a, cSubject, "c-messages-still-work") + require.Equal(t, "c-messages-still-work", string(receiveMessage(t, cUnique))) + }) +} + +func defaultTestOptions() Options { + return Options{disableCluster: true} +} + +func clusterTestOptions(t *testing.T) Options { + t.Helper() + return Options{ + ClusterHost: "127.0.0.1", + ClusterPort: natsserver.RANDOM_PORT, + disableCluster: false, + } +} + +func newTestPubsub(t *testing.T, opts Options) *Pubsub { + t.Helper() + logger := slogtest.Make(t, nil) + ctx := testutil.Context(t, testutil.WaitLong) + ps, err := New(ctx, logger, opts) + require.NoError(t, err) + t.Cleanup(func() { + _ = ps.Close() + }) + return ps +} + +func clusterRouteAddress(t *testing.T, ps *Pubsub) string { + t.Helper() + addr := ps.ns.ClusterAddr() + require.NotNil(t, addr) + return "nats://" + addr.String() +} + +func waitForRouteSubscription(t *testing.T, ps *Pubsub, subject string) { + t.Helper() + require.Eventually(t, func() bool { + routes, err := ps.ns.Routez(&natsserver.RoutezOptions{Subscriptions: true}) + if err != nil { + return false + } + for _, route := range routes.Routes { + for _, sub := range route.Subs { + if sub == subject { + return true + } + } + } + return false + }, testutil.WaitShort, testutil.IntervalFast) +} + +func publishAndFlush(t *testing.T, ps *Pubsub, event, message string) { + t.Helper() + require.NoError(t, ps.Publish(event, []byte(message))) + require.NoError(t, ps.Flush()) +} + +func receiveMessage(t *testing.T, got <-chan []byte) []byte { + t.Helper() + select { + case msg := <-got: + return msg + case <-time.After(testutil.WaitShort): + t.Fatal("timed out waiting for message") + return nil + } +} + +func requireRoutesEqual(t *testing.T, routes []*url.URL, addresses ...string) { + t.Helper() + want, err := parsePeerAddresses(addresses) + require.NoError(t, err) + want = sortRouteURLs(want) + require.True(t, sortedURLsEqual(want, routes), "want %v, got %v", routeStrings(want), routeStrings(routes)) +} + +func routeStrings(routes []*url.URL) []string { + strings := make([]string, 0, len(routes)) + for _, route := range routes { + strings = append(strings, route.String()) + } + return strings +} diff --git a/coderd/x/nats/pubsub_test.go b/coderd/x/nats/pubsub_test.go new file mode 100644 index 0000000000..7b65228b7a --- /dev/null +++ b/coderd/x/nats/pubsub_test.go @@ -0,0 +1,204 @@ +package nats_test + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + natsserver "github.com/nats-io/nats-server/v2/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/coderd/x/nats" + "github.com/coder/coder/v2/testutil" +) + +func newPubsub(t *testing.T, opts nats.Options) *nats.Pubsub { + t.Helper() + + if opts.ClusterPort == 0 { + opts.ClusterPort = natsserver.RANDOM_PORT + } + + logger := slogtest.Make(t, nil) + ctx := testutil.Context(t, testutil.WaitLong) + ps, err := nats.New(ctx, logger, opts) + require.NoError(t, err) + t.Cleanup(func() { + _ = ps.Close() + }) + return ps +} + +func TestPubsub(t *testing.T) { + t.Parallel() + + t.Run("RoundTrip", func(t *testing.T) { + t.Parallel() + ps := newPubsub(t, nats.Options{}) + + got := make(chan []byte, 1) + cancel, err := ps.Subscribe("test_event", func(_ context.Context, msg []byte) { + got <- msg + }) + require.NoError(t, err) + defer cancel() + + require.NoError(t, ps.Publish("test_event", []byte("hello"))) + + select { + case msg := <-got: + assert.Equal(t, "hello", string(msg)) + case <-time.After(testutil.WaitShort): + t.Fatal("timed out waiting for message") + } + }) + + t.Run("SubscribeWithErrNormalMessage", func(t *testing.T) { + t.Parallel() + ps := newPubsub(t, nats.Options{}) + + got := make(chan []byte, 1) + cancel, err := ps.SubscribeWithErr("evt", func(_ context.Context, msg []byte, err error) { + assert.NoError(t, err) + got <- msg + }) + require.NoError(t, err) + defer cancel() + + require.NoError(t, ps.Publish("evt", []byte("payload"))) + + select { + case msg := <-got: + assert.Equal(t, "payload", string(msg)) + case <-time.After(testutil.WaitShort): + t.Fatal("timed out waiting for message") + } + }) + + t.Run("EchoDefault", func(t *testing.T) { + t.Parallel() + ps := newPubsub(t, nats.Options{}) + + got := make(chan []byte, 1) + cancel, err := ps.Subscribe("echo_evt", func(_ context.Context, msg []byte) { + got <- msg + }) + require.NoError(t, err) + defer cancel() + + require.NoError(t, ps.Publish("echo_evt", []byte("data"))) + + select { + case msg := <-got: + assert.Equal(t, "data", string(msg)) + case <-time.After(testutil.WaitShort): + t.Fatal("default should echo own messages") + } + }) + + t.Run("Ordering", func(t *testing.T) { + t.Parallel() + ps := newPubsub(t, nats.Options{}) + + const n = 100 + got := make(chan []byte, n) + cancel, err := ps.Subscribe("ord_evt", func(_ context.Context, msg []byte) { + got <- msg + }) + require.NoError(t, err) + defer cancel() + + for i := 0; i < n; i++ { + require.NoError(t, ps.Publish("ord_evt", []byte(fmt.Sprintf("%d", i)))) + } + + deadline := time.After(testutil.WaitLong) + for i := 0; i < n; i++ { + select { + case msg := <-got: + assert.Equal(t, fmt.Sprintf("%d", i), string(msg)) + case <-deadline: + t.Fatalf("timed out at message %d/%d", i, n) + } + } + }) + + t.Run("CloseIdempotent", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + ps, err := nats.New(ctx, logger, nats.Options{}) + require.NoError(t, err) + + var first, second error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + first = ps.Close() + }() + wg.Wait() + second = ps.Close() + assert.NoError(t, first) + assert.NoError(t, second) + }) + + t.Run("SubscribeWithErrReceivesDropError", func(t *testing.T) { + t.Parallel() + ps := newPubsub(t, nats.Options{ + PendingLimits: nats.PendingLimits{Msgs: 1, Bytes: 1024 * 1024}, + }) + + const event = "slow_evt_sync" + started := make(chan struct{}) + release := make(chan struct{}) + dropped := make(chan error, 1) + var startedOnce sync.Once + var releaseOnce sync.Once + defer releaseOnce.Do(func() { close(release) }) + + cancel, err := ps.SubscribeWithErr(event, func(_ context.Context, _ []byte, err error) { + if err != nil { + select { + case dropped <- err: + default: + } + return + } + startedOnce.Do(func() { + close(started) + <-release + }) + }) + require.NoError(t, err) + defer cancel() + + require.NoError(t, ps.Publish(event, []byte("first"))) + require.NoError(t, ps.Flush()) + select { + case <-started: + case <-time.After(testutil.WaitShort): + t.Fatal("timed out waiting for first callback") + } + + for i := 0; i < 8; i++ { + require.NoError(t, ps.Publish(event, []byte("burst"))) + } + require.NoError(t, ps.Flush()) + releaseOnce.Do(func() { close(release) }) + + select { + case err := <-dropped: + assert.ErrorIs(t, err, pubsub.ErrDroppedMessages) + case <-time.After(testutil.WaitLong): + t.Fatal("timed out waiting for drop error") + } + }) +} diff --git a/coderd/x/nats/server.go b/coderd/x/nats/server.go new file mode 100644 index 0000000000..6013c44feb --- /dev/null +++ b/coderd/x/nats/server.go @@ -0,0 +1,119 @@ +package nats + +import ( + "time" + + natsserver "github.com/nats-io/nats-server/v2/server" + natsgo "github.com/nats-io/nats.go" + "golang.org/x/xerrors" +) + +const readyTimeout = 10 * time.Second + +// buildServerOptions constructs the embedded NATS server options. The +// server runs with a loopback random client listener and an optional +// cluster route listener. +func buildServerOptions(opts Options) (*natsserver.Options, error) { + maxPayload := opts.MaxPayload + if maxPayload == 0 { + maxPayload = natsserver.MAX_PAYLOAD_SIZE + } + maxPending := opts.MaxPending + if maxPending <= 0 { + maxPending = DefaultMaxPending + } + + sopts := &natsserver.Options{ + JetStream: false, + MaxPayload: maxPayload, + MaxPending: maxPending, + NoLog: true, + NoSigs: true, + } + + sopts.DontListen = false + sopts.Host = "127.0.0.1" + sopts.Port = natsserver.RANDOM_PORT + + if !opts.disableCluster { + clusterHost := opts.ClusterHost + if clusterHost == "" { + clusterHost = natsserver.DEFAULT_HOST + } + clusterPort := opts.ClusterPort + if clusterPort == 0 { + clusterPort = defaultClusterPort + } + routePoolSize := opts.RoutePoolSize + if routePoolSize == 0 { + routePoolSize = defaultRoutePoolSize + } + + sopts.Cluster = natsserver.ClusterOpts{ + Name: defaultClusterName, + Host: clusterHost, + Port: clusterPort, + PoolSize: routePoolSize, + } + } + + return sopts, nil +} + +// startEmbeddedServer starts an in-process NATS server. +func startEmbeddedServer(opts *natsserver.Options) (*natsserver.Server, error) { + ns, err := natsserver.NewServer(opts) + if err != nil { + return nil, xerrors.Errorf("new embedded nats server: %w", err) + } + go ns.Start() + if !ns.ReadyForConnections(readyTimeout) { + ns.Shutdown() + ns.WaitForShutdown() + return nil, xerrors.Errorf("embedded nats server not ready within %s", readyTimeout) + } + return ns, nil +} + +type connHandlers struct { + disconnectErr natsgo.ConnErrHandler + reconnect natsgo.ConnHandler + closed natsgo.ConnHandler + errH natsgo.ErrHandler +} + +// connectClient dials the embedded server's client listener over TCP +// loopback (or net.Pipe when opts.InProcess is true) and returns the +// resulting *natsgo.Conn. connName identifies the connection in server +// logs. +func connectClient(ns *natsserver.Server, opts Options, handlers connHandlers, connName string) (*natsgo.Conn, error) { + connOpts := []natsgo.Option{ + natsgo.Name(connName), + } + if opts.ReconnectWait > 0 { + connOpts = append(connOpts, natsgo.ReconnectWait(opts.ReconnectWait)) + } + if handlers.disconnectErr != nil { + connOpts = append(connOpts, natsgo.DisconnectErrHandler(handlers.disconnectErr)) + } + if handlers.reconnect != nil { + connOpts = append(connOpts, natsgo.ReconnectHandler(handlers.reconnect)) + } + if handlers.closed != nil { + connOpts = append(connOpts, natsgo.ClosedHandler(handlers.closed)) + } + if handlers.errH != nil { + connOpts = append(connOpts, natsgo.ErrorHandler(handlers.errH)) + } + clientURL := ns.ClientURL() + if opts.InProcess { + // InProcessServer overrides URL dialing with a net.Pipe; the + // URL argument is ignored but must still be syntactically valid. + connOpts = append(connOpts, natsgo.InProcessServer(ns)) + } + nc, err := natsgo.Connect(clientURL, connOpts...) + if err != nil { + return nil, xerrors.Errorf("connect client: %w", err) + } + return nc, nil +} diff --git a/codersdk/aibridge.go b/codersdk/aibridge.go index 4e49176171..d04359acb3 100644 --- a/codersdk/aibridge.go +++ b/codersdk/aibridge.go @@ -175,9 +175,12 @@ type AIBridgeListSessionsFilter struct { Initiator string `json:"initiator,omitempty"` StartedBefore time.Time `json:"started_before,omitempty" format:"date-time"` StartedAfter time.Time `json:"started_after,omitempty" format:"date-time"` - // Provider matches the provider type column (openai, anthropic, - // copilot). Retained for backward compatibility; new clients should - // prefer ProviderName, which scopes to a specific configured row. + // Provider matches the runtime provider type column (openai, + // anthropic, copilot). The runtime type collapses the configured + // ai_provider_type: azure, google, openai-compat, openrouter, and + // vercel route through openai; bedrock routes through anthropic. + // Retained for backward compatibility; new clients should prefer + // ProviderName, which scopes to a specific configured row. Provider string `json:"provider,omitempty"` ProviderName string `json:"provider_name,omitempty"` Model string `json:"model,omitempty"` @@ -202,9 +205,12 @@ type AIBridgeListInterceptionsFilter struct { Initiator string `json:"initiator,omitempty"` StartedBefore time.Time `json:"started_before,omitempty" format:"date-time"` StartedAfter time.Time `json:"started_after,omitempty" format:"date-time"` - // Provider matches the provider type column (openai, anthropic, - // copilot). Retained for backward compatibility; new clients should - // prefer ProviderName, which scopes to a specific configured row. + // Provider matches the runtime provider type column (openai, + // anthropic, copilot). The runtime type collapses the configured + // ai_provider_type: azure, google, openai-compat, openrouter, and + // vercel route through openai; bedrock routes through anthropic. + // Retained for backward compatibility; new clients should prefer + // ProviderName, which scopes to a specific configured row. Provider string `json:"provider,omitempty"` ProviderName string `json:"provider_name,omitempty"` Model string `json:"model,omitempty"` @@ -429,3 +435,71 @@ func (c *Client) DeleteGroupAIBudget(ctx context.Context, group uuid.UUID) error } return nil } + +type UserAIBudgetOverride struct { + UserID uuid.UUID `json:"user_id" format:"uuid"` + GroupID uuid.UUID `json:"group_id" format:"uuid"` + SpendLimitMicros int64 `json:"spend_limit_micros"` + CreatedAt time.Time `json:"created_at" format:"date-time"` + UpdatedAt time.Time `json:"updated_at" format:"date-time"` +} + +type UpsertUserAIBudgetOverrideRequest struct { + // GroupID is the group the user's spend is attributed to. The user must + // be a member of this group. + GroupID uuid.UUID `json:"group_id" format:"uuid" validate:"required"` + SpendLimitMicros int64 `json:"spend_limit_micros" validate:"gte=0"` +} + +// UserAIBudgetOverride returns the AI spend budget override configured for the given user. +func (c *Client) UserAIBudgetOverride(ctx context.Context, user uuid.UUID) (UserAIBudgetOverride, error) { + res, err := c.Request(ctx, http.MethodGet, + fmt.Sprintf("/api/v2/users/%s/ai/budget", user.String()), + nil, + ) + if err != nil { + return UserAIBudgetOverride{}, xerrors.Errorf("make request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return UserAIBudgetOverride{}, ReadBodyAsError(res) + } + var resp UserAIBudgetOverride + return resp, json.NewDecoder(res.Body).Decode(&resp) +} + +// UpsertUserAIBudgetOverride creates or updates the AI spend budget override for the given user. +func (c *Client) UpsertUserAIBudgetOverride(ctx context.Context, user uuid.UUID, req UpsertUserAIBudgetOverrideRequest) (UserAIBudgetOverride, error) { + res, err := c.Request(ctx, http.MethodPut, + fmt.Sprintf("/api/v2/users/%s/ai/budget", user.String()), + req, + ) + if err != nil { + return UserAIBudgetOverride{}, xerrors.Errorf("make request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return UserAIBudgetOverride{}, ReadBodyAsError(res) + } + var resp UserAIBudgetOverride + return resp, json.NewDecoder(res.Body).Decode(&resp) +} + +// DeleteUserAIBudgetOverride removes the AI spend budget override for the given user. +func (c *Client) DeleteUserAIBudgetOverride(ctx context.Context, user uuid.UUID) error { + res, err := c.Request(ctx, http.MethodDelete, + fmt.Sprintf("/api/v2/users/%s/ai/budget", user.String()), + nil, + ) + if err != nil { + return xerrors.Errorf("make request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusNoContent { + return ReadBodyAsError(res) + } + return nil +} diff --git a/codersdk/aiproviders.go b/codersdk/aiproviders.go index 3b47598118..fe34b9c03c 100644 --- a/codersdk/aiproviders.go +++ b/codersdk/aiproviders.go @@ -224,10 +224,24 @@ func (req CreateAIProviderRequest) Validate() []ValidationError { validations = append(validations, validateAIProviderName(req.Name)...) validations = append(validations, validateRequiredAIProviderBaseURL(req.BaseURL)...) validations = append(validations, validateAIProviderAPIKeys(req.APIKeys)...) - if req.Settings.Bedrock != nil && req.Type != AIProviderTypeAnthropic { + if req.Settings.Bedrock != nil && + req.Type != AIProviderTypeAnthropic && + req.Type != AIProviderTypeBedrock { validations = append(validations, ValidationError{ Field: "settings", - Detail: "bedrock settings are only valid for type=anthropic", + Detail: "bedrock settings are only valid for type=anthropic or type=bedrock", + }) + } + if req.Type == AIProviderTypeBedrock && (req.Settings.Bedrock == nil || !req.Settings.Bedrock.IsConfigured()) { + validations = append(validations, ValidationError{ + Field: "settings", + Detail: "type=bedrock requires bedrock settings", + }) + } + if req.Type == AIProviderTypeBedrock && len(req.APIKeys) > 0 { + validations = append(validations, ValidationError{ + Field: "api_keys", + Detail: "type=bedrock does not accept api_keys", }) } return validations diff --git a/codersdk/apikey_scopes_gen.go b/codersdk/apikey_scopes_gen.go index 7bad39ccc2..4e4fb8d803 100644 --- a/codersdk/apikey_scopes_gen.go +++ b/codersdk/apikey_scopes_gen.go @@ -40,6 +40,10 @@ const ( APIKeyScopeAuditLogAll APIKeyScope = "audit_log:*" APIKeyScopeAuditLogCreate APIKeyScope = "audit_log:create" APIKeyScopeAuditLogRead APIKeyScope = "audit_log:read" + APIKeyScopeBoundaryLogAll APIKeyScope = "boundary_log:*" + APIKeyScopeBoundaryLogCreate APIKeyScope = "boundary_log:create" + APIKeyScopeBoundaryLogDelete APIKeyScope = "boundary_log:delete" + APIKeyScopeBoundaryLogRead APIKeyScope = "boundary_log:read" APIKeyScopeBoundaryUsageAll APIKeyScope = "boundary_usage:*" APIKeyScopeBoundaryUsageDelete APIKeyScope = "boundary_usage:delete" APIKeyScopeBoundaryUsageRead APIKeyScope = "boundary_usage:read" diff --git a/codersdk/chats.go b/codersdk/chats.go index 665ace7aa8..bcf235f590 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -1525,14 +1525,16 @@ type ChatStreamStatus struct { type ChatErrorKind string const ( - ChatErrorKindGeneric ChatErrorKind = "generic" - ChatErrorKindOverloaded ChatErrorKind = "overloaded" - ChatErrorKindRateLimit ChatErrorKind = "rate_limit" - ChatErrorKindTimeout ChatErrorKind = "timeout" - ChatErrorKindStartupTimeout ChatErrorKind = "startup_timeout" - ChatErrorKindAuth ChatErrorKind = "auth" - ChatErrorKindConfig ChatErrorKind = "config" - ChatErrorKindUsageLimit ChatErrorKind = "usage_limit" + ChatErrorKindGeneric ChatErrorKind = "generic" + ChatErrorKindOverloaded ChatErrorKind = "overloaded" + ChatErrorKindRateLimit ChatErrorKind = "rate_limit" + ChatErrorKindTimeout ChatErrorKind = "timeout" + ChatErrorKindStartupTimeout ChatErrorKind = "startup_timeout" + ChatErrorKindAuth ChatErrorKind = "auth" + ChatErrorKindConfig ChatErrorKind = "config" + ChatErrorKindUsageLimit ChatErrorKind = "usage_limit" + ChatErrorKindMissingKey ChatErrorKind = "missing_key" + ChatErrorKindProviderDisabled ChatErrorKind = "provider_disabled" ) // AllChatErrorKinds contains every ChatErrorKind value. @@ -1546,6 +1548,8 @@ var AllChatErrorKinds = []ChatErrorKind{ ChatErrorKindAuth, ChatErrorKindConfig, ChatErrorKindUsageLimit, + ChatErrorKindMissingKey, + ChatErrorKindProviderDisabled, } // ChatError represents a terminal chat error in persisted chat state or the diff --git a/codersdk/deployment.go b/codersdk/deployment.go index f9cdb8a8fc..3fb36c587f 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -638,6 +638,7 @@ type DeploymentValues struct { AgentFallbackTroubleshootingURL serpent.URL `json:"agent_fallback_troubleshooting_url,omitempty" typescript:",notnull"` BrowserOnly serpent.Bool `json:"browser_only,omitempty" typescript:",notnull"` SCIMAPIKey serpent.String `json:"scim_api_key,omitempty" typescript:",notnull"` + UseLegacySCIM serpent.Bool `json:"scim_use_legacy,omitempty" typescript:",notnull"` ExternalTokenEncryptionKeys serpent.StringArray `json:"external_token_encryption_keys,omitempty" typescript:",notnull"` Provisioner ProvisionerConfig `json:"provisioner,omitempty" typescript:",notnull"` RateLimit RateLimitConfig `json:"rate_limit,omitempty" typescript:",notnull"` @@ -3447,6 +3448,18 @@ func (c *DeploymentValues) Options() serpent.OptionSet { Annotations: serpent.Annotations{}.Mark(annotationEnterpriseKey, "true").Mark(annotationSecretKey, "true"), Value: &c.SCIMAPIKey, }, + { + Name: "SCIM Use Legacy", + // The legacy SCIM is a weird mix of SCIM 1.0 and SCIM 2.0 + Description: "Use the legacy SCIM implementation instead of the SCIM 2.0 handler. This is provided for backward compatibility for existing users.", + Flag: "scim-use-legacy", + Env: "CODER_SCIM_USE_LEGACY", + Hidden: true, + // TODO: When SCIM 2.0 has been tested more, flip this to false to default to the new scim + Default: "true", + Annotations: serpent.Annotations{}.Mark(annotationEnterpriseKey, "true"), + Value: &c.UseLegacySCIM, + }, { Name: "External Token Encryption Keys", Description: "Encrypt OIDC and Git authentication tokens with AES-256-GCM in the database. The value must be a comma-separated list of base64-encoded keys. Each key, when base64-decoded, must be exactly 32 bytes in length. The first key will be used to encrypt new values. Subsequent keys will be used as a fallback when decrypting. During normal operation it is recommended to only set one key unless you are in the process of rotating keys with the `coder server dbcrypt rotate` command.", @@ -4684,7 +4697,9 @@ type AIBridgeBedrockConfig struct { // CODER_AIBRIDGE_PROVIDER__ is also accepted as a deprecated alias. // This follows the same indexed pattern as ExternalAuthConfig. type AIProviderConfig struct { - // Type is the provider type: "openai", "anthropic", or "copilot". + // Type is the provider type. Valid values are: "openai", + // "anthropic", "azure", "bedrock", "google", "openai-compat", + // "openrouter", "vercel", "copilot". Type string `json:"type"` // Name is the unique instance identifier used for routing. // Defaults to Type if not provided. diff --git a/codersdk/rbacresources_gen.go b/codersdk/rbacresources_gen.go index 11b6488182..75b1e82421 100644 --- a/codersdk/rbacresources_gen.go +++ b/codersdk/rbacresources_gen.go @@ -13,6 +13,7 @@ const ( ResourceAssignOrgRole RBACResource = "assign_org_role" ResourceAssignRole RBACResource = "assign_role" ResourceAuditLog RBACResource = "audit_log" + ResourceBoundaryLog RBACResource = "boundary_log" ResourceBoundaryUsage RBACResource = "boundary_usage" ResourceChat RBACResource = "chat" ResourceConnectionLog RBACResource = "connection_log" @@ -89,6 +90,7 @@ var RBACResourceActions = map[RBACResource][]RBACAction{ ResourceAssignOrgRole: {ActionAssign, ActionCreate, ActionDelete, ActionRead, ActionUnassign, ActionUpdate}, ResourceAssignRole: {ActionAssign, ActionRead, ActionUnassign}, ResourceAuditLog: {ActionCreate, ActionRead}, + ResourceBoundaryLog: {ActionCreate, ActionDelete, ActionRead}, ResourceBoundaryUsage: {ActionDelete, ActionRead, ActionUpdate}, ResourceChat: {ActionCreate, ActionDelete, ActionRead, ActionShare, ActionUpdate}, ResourceConnectionLog: {ActionRead, ActionUpdate}, diff --git a/docs/admin/integrations/oauth2-provider.md b/docs/admin/integrations/oauth2-provider.md index 910a6c31b4..7476b58681 100644 --- a/docs/admin/integrations/oauth2-provider.md +++ b/docs/admin/integrations/oauth2-provider.md @@ -239,7 +239,7 @@ eval $(./setup-test-app.sh) ./cleanup-test-app.sh ``` -For more details on testing, see the [OAuth2 test scripts README](../../../scripts/oauth2/README.md). +For more details on testing, see the [OAuth2 test scripts README](https://github.com/coder/coder/blob/main/scripts/oauth2/README.md). ## Common Issues diff --git a/docs/admin/integrations/prometheus.md b/docs/admin/integrations/prometheus.md index 210f22d040..acaf3e0641 100644 --- a/docs/admin/integrations/prometheus.md +++ b/docs/admin/integrations/prometheus.md @@ -120,11 +120,17 @@ deployment. They will always be available from the agent. | `coder_aibridged_non_injected_tool_selections_total` | counter | The number of times an AI model selected a tool to be invoked by the client. | `model` `name` `provider` | | `coder_aibridged_passthrough_total` | counter | The count of requests which were not intercepted but passed through to the upstream. | `method` `provider` `route` | | `coder_aibridged_prompts_total` | counter | The number of prompts issued by users (initiators). | `initiator_id` `model` `provider` | +| `coder_aibridged_provider_info` | gauge | One series per configured AI provider. Value is always 1; the status label (enabled, disabled, error) carries the alertable signal. | `provider_name` `provider_type` `status` | +| `coder_aibridged_providers_last_reload_success_timestamp_seconds` | gauge | Unix timestamp of the last provider reload that successfully refreshed the pool. A gap against coder_aibridged_providers_last_reload_timestamp_seconds means the loop is firing but the refresh function is failing. | | +| `coder_aibridged_providers_last_reload_timestamp_seconds` | gauge | Unix timestamp of the last provider reload attempt, success or failure. | | | `coder_aibridged_tokens_total` | counter | The number of tokens used by intercepted requests. | `initiator_id` `model` `provider` `type` | | `coder_aibridgeproxyd_connect_sessions_total` | counter | Total number of CONNECT sessions established. | `type` | | `coder_aibridgeproxyd_inflight_mitm_requests` | gauge | Number of MITM requests currently being processed. | `provider` | | `coder_aibridgeproxyd_mitm_requests_total` | counter | Total number of MITM requests handled by the proxy. | `provider` | | `coder_aibridgeproxyd_mitm_responses_total` | counter | Total number of MITM responses by HTTP status code class. | `code` `provider` | +| `coder_aibridgeproxyd_provider_info` | gauge | One series per configured AI provider. Value is always 1; the status label (enabled, disabled, error) carries the alertable signal. | `provider_name` `provider_type` `status` | +| `coder_aibridgeproxyd_providers_last_reload_success_timestamp_seconds` | gauge | Unix timestamp of the last provider reload that successfully refreshed the router. A gap against coder_aibridgeproxyd_providers_last_reload_timestamp_seconds means the loop is firing but the refresh function is failing. | | +| `coder_aibridgeproxyd_providers_last_reload_timestamp_seconds` | gauge | Unix timestamp of the last provider reload attempt, success or failure. | | | `coder_derp_server_accepts_total` | counter | Total DERP connections accepted. | | | `coder_derp_server_average_queue_duration_ms` | gauge | Average queue duration in milliseconds. | | | `coder_derp_server_bytes_received_total` | counter | Total bytes received. | | diff --git a/docs/admin/monitoring/logs.md b/docs/admin/monitoring/logs.md index 8b9f5e747d..7e4c27154c 100644 --- a/docs/admin/monitoring/logs.md +++ b/docs/admin/monitoring/logs.md @@ -19,6 +19,11 @@ machine/VM. the[`CODER_LOG_FILTER`](../../reference/cli/server.md#-l---log-filter) server config. Using `.*` will result in the `DEBUG` log level being used. +> [!NOTE] +> To disable human-readable logging, set `--log-human` (or +> `CODER_LOGGING_HUMAN`) to `/dev/null`. An empty string does not disable +> logging. + Events such as server errors, audit logs, user activities, and SSO & OpenID Connect logs are all captured in the `coderd` logs. diff --git a/docs/admin/security/0001_user_apikeys_invalidation.md b/docs/admin/security/0001_user_apikeys_invalidation.md deleted file mode 100644 index 203a891766..0000000000 --- a/docs/admin/security/0001_user_apikeys_invalidation.md +++ /dev/null @@ -1,89 +0,0 @@ -# API Tokens of deleted users not invalidated - ---- - -## Summary - -Coder identified an issue in -[https://github.com/coder/coder](https://github.com/coder/coder) where API -tokens belonging to a deleted user were not invalidated. A deleted user in -possession of a valid and non-expired API token is still able to use the above -token with their full suite of capabilities. - -## Impact: HIGH - -If exploited, an attacker could perform any action that the deleted user was -authorized to perform. - -## Exploitability: HIGH - -The CLI writes the API key to `~/.coderv2/session` by default, so any deleted -user who previously logged in via the Coder CLI has the potential to exploit -this. Note that there is a time window for exploitation; API tokens have a -maximum lifetime after which they are no longer valid. - -The issue only affects users who were active (not suspended) at the time they -were deleted. Users who were first suspended and later deleted cannot exploit -this issue. - -## Affected Versions - -All versions of Coder between v0.8.15 and v0.22.2 (inclusive) are affected. - -All customers are advised to upgrade to -[v0.23.0](https://github.com/coder/coder/releases/tag/v0.23.0) as soon as -possible. - -## Details - -Coder incorrectly failed to invalidate API keys belonging to a user when they -were deleted. When authenticating a user via their API key, Coder incorrectly -failed to check whether the API key corresponds to a deleted user. - -## Indications of Compromise - -> [!TIP] -> Automated remediation steps in the upgrade purge all affected API keys. -> Either perform the following query before upgrade or run it on a backup of -> your database from before the upgrade. - -Execute the following SQL query: - -```sql -SELECT - users.email, - users.updated_at, - api_keys.id, - api_keys.last_used -FROM - users -LEFT JOIN - api_keys -ON - api_keys.user_id = users.id -WHERE - users.deleted -AND - api_keys.last_used > users.updated_at -; -``` - -If the output is similar to the below, then you are not affected: - -```sql ------ -(0 rows) -``` - -Otherwise, the following information will be reported: - -- User email -- Time the user was last modified (i.e. deleted) -- User API key ID -- Time the affected API key was last used - -> [!TIP] -> If your license includes the -> [Audit Logs](https://coder.com/docs/admin/audit-logs#filtering-logs) feature, -> you can then query all actions performed by the above users by using the -> filter `email:$USER_EMAIL`. diff --git a/docs/admin/security/index.md b/docs/admin/security/index.md index 37028093f8..f6684519e8 100644 --- a/docs/admin/security/index.md +++ b/docs/admin/security/index.md @@ -11,17 +11,6 @@ For other security tips, visit our guide to > If you discover a vulnerability in Coder, please do not hesitate to report it > to us by following the [security policy](https://github.com/coder/coder/blob/main/SECURITY.md). -From time to time, Coder employees or other community members may discover -vulnerabilities in the product. - -If a vulnerability requires an immediate upgrade to mitigate a potential -security risk, we will add it to the below table. - -Click on the description links to view more details about each specific -vulnerability. - ---- - -| Description | Severity | Fix | Vulnerable Versions | -|-----------------------------------------------------------------------------------------------------------------------------------------------|----------|----------------------------------------------------------------|---------------------| -| [API tokens of deleted users not invalidated](https://github.com/coder/coder/blob/main/docs/admin/security/0001_user_apikeys_invalidation.md) | HIGH | [v0.23.0](https://github.com/coder/coder/releases/tag/v0.23.0) | v0.8.25 - v0.22.2 | +Security advisories are published on the +[GitHub Security Advisories](https://github.com/coder/coder/security/advisories) +page. diff --git a/docs/admin/templates/extending-templates/web-ides.md b/docs/admin/templates/extending-templates/web-ides.md index 4240dfe552..dae3fc593b 100644 --- a/docs/admin/templates/extending-templates/web-ides.md +++ b/docs/admin/templates/extending-templates/web-ides.md @@ -55,7 +55,7 @@ resource "coder_agent" "main" { For advanced use, we recommend installing code-server in your VM snapshot or container image. Here's a Dockerfile which leverages some special -[code-server features](https://coder.com/docs/code-server/): +[code-server features](https://coder.com/docs/code-server): ```Dockerfile FROM codercom/enterprise-base:ubuntu diff --git a/docs/ai-coder/ai-gateway/clients/coder-agents.md b/docs/ai-coder/ai-gateway/clients/coder-agents.md index f5187cce58..de0fcad927 100644 --- a/docs/ai-coder/ai-gateway/clients/coder-agents.md +++ b/docs/ai-coder/ai-gateway/clients/coder-agents.md @@ -164,6 +164,13 @@ key is a valid Coder token. one [model](../../agents/models.md#add-a-model) to the provider after saving the Base URL. Providers without an enabled model are hidden from developers. +- **"Chat interrupted" error when resuming a conversation.** + This occurs when the API key that was used to start a chat turn is no + longer available. Common causes: upgrading from a version before + `api_key_id` tracking was introduced, or deleting an API key while a + chat is active. The error is self-healing: send your message again and + the new message will use your current API key. If the error persists + after resending, this indicates a bug. Please report it. ## Known limitations diff --git a/docs/ai-coder/ai-governance.md b/docs/ai-coder/ai-governance.md index 0c8f7b609a..1581a972c8 100644 --- a/docs/ai-coder/ai-governance.md +++ b/docs/ai-coder/ai-governance.md @@ -7,7 +7,9 @@ development environments. As adoption grows, many enterprises also need observability, management, and policy controls to support secure and auditable AI rollouts. -The AI Governance Add-On is a per-user license that can be added to Premium seats. Each user with the add-on gets access to a set of features +The AI Governance Add-On is a separate, per-user license for Premium customers. +It is not included with a Premium subscription and must be purchased separately. +Each user with the add-on gets access to a set of features that help organizations safely roll out AI tooling at scale: - [AI Gateway](./ai-gateway/index.md): LLM gateway to audit AI sessions, central @@ -15,9 +17,13 @@ that help organizations safely roll out AI tooling at scale: - [Agent Firewall](./agent-firewall/index.md): Process-level firewalls for agents, restricting which domains can be accessed by AI agents +> [!NOTE] +> As of Coder v2.32, the AI Governance Add-On is required to use AI Gateway and Agent Firewall. +> Deployments without the add-on cannot access these features. + ## Who should use the AI Governance Add-On -The AI Governance Add-On is for teams that want to extend that platform to +The AI Governance Add-On is for teams that want to extend the Coder platform to support AI-powered IDEs and coding agents in a controlled, observable way. It's a good fit if you're: @@ -77,10 +83,6 @@ rates, and usage patterns to inform decisions about AI strategy. Starting with Coder v2.30 (February 2026), AI Gateway and Agent Firewall are generally available as part of the AI Governance Add-On. -As of Coder v2.32, the AI Governance Add-On is required to use AI Gateway and -Agent Firewall. Deployments without the add-on will not be able to access -these features. - To learn more about enabling the AI Governance Add-On, pricing, or trial options, reach out to your [Coder account team](https://coder.com/contact/sales). diff --git a/docs/install/cloud/azure-vm.md b/docs/install/cloud/azure-vm.md index 2ab41bc53a..6cc2163105 100644 --- a/docs/install/cloud/azure-vm.md +++ b/docs/install/cloud/azure-vm.md @@ -56,7 +56,7 @@ as a system service. For this instance, we will run Coder as a system service, however you can run Coder a multitude of different ways. You can learn more about those -[here](https://coder.com/docs/coder-oss/latest/install). +[here](https://coder.com/docs/install). In the Azure VM instance, run the following command to install Coder diff --git a/docs/install/docker.md b/docs/install/docker.md index 63bc5cd7b9..31a7628c7a 100644 --- a/docs/install/docker.md +++ b/docs/install/docker.md @@ -8,11 +8,16 @@ You can install and run Coder using the official Docker images published on - Docker. See the [official installation documentation](https://docs.docker.com/install/). -- A Linux machine. For macOS devices, start Coder using the - [standalone binary](./cli.md). +- A Linux host. - 2 CPU cores and 4 GB memory free on your machine. +> [!IMPORTANT] +> This guide is for **Linux** hosts only. The `getent` and `--group-add` +> Docker socket patterns used below are Linux-specific and do not translate +> cleanly to macOS Docker runtimes. For macOS, install Coder using the +> [standalone binary](./cli.md) instead. +
## Install Coder via `docker compose` diff --git a/docs/install/releases/feature-stages.md b/docs/install/releases/feature-stages.md index c43e3a3fea..8cbe79b94a 100644 --- a/docs/install/releases/feature-stages.md +++ b/docs/install/releases/feature-stages.md @@ -129,7 +129,7 @@ For support, consult our knowledgeable and growing community on already. Customers with a valid Coder license, can submit a support request or contact your [account team](https://coder.com/contact). -We intend [Coder documentation](../../README.md) to be the +We intend [Coder documentation](../../about/contributing/documentation.md) to be the [single source of truth](https://en.wikipedia.org/wiki/Single_source_of_truth) and all features should have some form of complete documentation that outlines how to use or implement a feature. If you discover an error or if you have a diff --git a/docs/manifest.json b/docs/manifest.json index b130dd6912..cbbeceefbe 100644 --- a/docs/manifest.json +++ b/docs/manifest.json @@ -2388,6 +2388,11 @@ "description": "Prints the list of users.", "path": "reference/cli/users_list.md" }, + { + "title": "users oidc-claims", + "description": "Display the OIDC claims for the authenticated user.", + "path": "reference/cli/users_oidc-claims.md" + }, { "title": "users show", "description": "Show a single user. Use 'me' to indicate the currently authenticated user.", diff --git a/docs/reference/api/chats.md b/docs/reference/api/chats.md index 758f6641f5..f475d8482d 100644 --- a/docs/reference/api/chats.md +++ b/docs/reference/api/chats.md @@ -292,13 +292,13 @@ Status Code **200** #### Enumerated Values -| Property | Value(s) | -|---------------|--------------------------------------------------------------------------------------------------------------| -| `client_type` | `api`, `ui` | -| `kind` | `auth`, `config`, `generic`, `overloaded`, `rate_limit`, `startup_timeout`, `timeout`, `usage_limit` | -| `type` | `context-file`, `file`, `file-reference`, `reasoning`, `skill`, `source`, `text`, `tool-call`, `tool-result` | -| `plan_mode` | `plan` | -| `status` | `completed`, `error`, `paused`, `pending`, `requires_action`, `running`, `waiting` | +| Property | Value(s) | +|---------------|------------------------------------------------------------------------------------------------------------------------------------------| +| `client_type` | `api`, `ui` | +| `kind` | `auth`, `config`, `generic`, `missing_key`, `overloaded`, `provider_disabled`, `rate_limit`, `startup_timeout`, `timeout`, `usage_limit` | +| `type` | `context-file`, `file`, `file-reference`, `reasoning`, `skill`, `source`, `text`, `tool-call`, `tool-result` | +| `plan_mode` | `plan` | +| `status` | `completed`, `error`, `paused`, `pending`, `requires_action`, `running`, `waiting` | To perform this operation, you must be authenticated. [Learn more](authentication.md). diff --git a/docs/reference/api/enterprise.md b/docs/reference/api/enterprise.md index def5c219a2..ed1ce268e7 100644 --- a/docs/reference/api/enterprise.md +++ b/docs/reference/api/enterprise.md @@ -3418,6 +3418,125 @@ curl -X POST http://coder-server:8080/api/v2/templates/{template}/prebuilds/inva To perform this operation, you must be authenticated. [Learn more](authentication.md). +## Get user AI budget override + +### Code samples + +```shell +# Example request using curl +curl -X GET http://coder-server:8080/api/v2/users/{user}/ai/budget \ + -H 'Accept: application/json' \ + -H 'Coder-Session-Token: API_KEY' +``` + +`GET /api/v2/users/{user}/ai/budget` + +### Parameters + +| Name | In | Type | Required | Description | +|--------|------|--------|----------|--------------------------| +| `user` | path | string | true | User ID, username, or me | + +### Example responses + +> 200 Response + +```json +{ + "created_at": "2019-08-24T14:15:22Z", + "group_id": "306db4e0-7449-4501-b76f-075576fe2d8f", + "spend_limit_micros": 0, + "updated_at": "2019-08-24T14:15:22Z", + "user_id": "a169451c-8525-4352-b8ca-070dd449a1a5" +} +``` + +### Responses + +| Status | Meaning | Description | Schema | +|--------|---------------------------------------------------------|-------------|--------------------------------------------------------------------------| +| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.UserAIBudgetOverride](schemas.md#codersdkuseraibudgetoverride) | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). + +## Upsert user AI budget override + +### Code samples + +```shell +# Example request using curl +curl -X PUT http://coder-server:8080/api/v2/users/{user}/ai/budget \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json' \ + -H 'Coder-Session-Token: API_KEY' +``` + +`PUT /api/v2/users/{user}/ai/budget` + +> Body parameter + +```json +{ + "group_id": "306db4e0-7449-4501-b76f-075576fe2d8f", + "spend_limit_micros": 0 +} +``` + +### Parameters + +| Name | In | Type | Required | Description | +|--------|------|----------------------------------------------------------------------------------------------------|----------|----------------------------------------| +| `user` | path | string | true | User ID, username, or me | +| `body` | body | [codersdk.UpsertUserAIBudgetOverrideRequest](schemas.md#codersdkupsertuseraibudgetoverriderequest) | true | Upsert user AI budget override request | + +### Example responses + +> 200 Response + +```json +{ + "created_at": "2019-08-24T14:15:22Z", + "group_id": "306db4e0-7449-4501-b76f-075576fe2d8f", + "spend_limit_micros": 0, + "updated_at": "2019-08-24T14:15:22Z", + "user_id": "a169451c-8525-4352-b8ca-070dd449a1a5" +} +``` + +### Responses + +| Status | Meaning | Description | Schema | +|--------|---------------------------------------------------------|-------------|--------------------------------------------------------------------------| +| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.UserAIBudgetOverride](schemas.md#codersdkuseraibudgetoverride) | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). + +## Delete user AI budget override + +### Code samples + +```shell +# Example request using curl +curl -X DELETE http://coder-server:8080/api/v2/users/{user}/ai/budget \ + -H 'Coder-Session-Token: API_KEY' +``` + +`DELETE /api/v2/users/{user}/ai/budget` + +### Parameters + +| Name | In | Type | Required | Description | +|--------|------|--------|----------|--------------------------| +| `user` | path | string | true | User ID, username, or me | + +### Responses + +| Status | Meaning | Description | Schema | +|--------|-----------------------------------------------------------------|-------------|--------| +| 204 | [No Content](https://tools.ietf.org/html/rfc7231#section-6.3.5) | No Content | | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). + ## Get user quiet hours schedule ### Code samples @@ -4520,9 +4639,9 @@ curl -X POST http://coder-server:8080/scim/v2/Users \ ### Parameters -| Name | In | Type | Required | Description | -|--------|------|----------------------------------------------|----------|-------------| -| `body` | body | [coderd.SCIMUser](schemas.md#coderdscimuser) | true | New user | +| Name | In | Type | Required | Description | +|--------|------|------------------------------------------------------|----------|-------------| +| `body` | body | [legacyscim.SCIMUser](schemas.md#legacyscimscimuser) | true | New user | ### Example responses @@ -4559,9 +4678,9 @@ curl -X POST http://coder-server:8080/scim/v2/Users \ ### Responses -| Status | Meaning | Description | Schema | -|--------|---------------------------------------------------------|-------------|----------------------------------------------| -| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [coderd.SCIMUser](schemas.md#coderdscimuser) | +| Status | Meaning | Description | Schema | +|--------|---------------------------------------------------------|-------------|------------------------------------------------------| +| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [legacyscim.SCIMUser](schemas.md#legacyscimscimuser) | To perform this operation, you must be authenticated. [Learn more](authentication.md). @@ -4638,10 +4757,10 @@ curl -X PUT http://coder-server:8080/scim/v2/Users/{id} \ ### Parameters -| Name | In | Type | Required | Description | -|--------|------|----------------------------------------------|----------|----------------------| -| `id` | path | string(uuid) | true | User ID | -| `body` | body | [coderd.SCIMUser](schemas.md#coderdscimuser) | true | Replace user request | +| Name | In | Type | Required | Description | +|--------|------|------------------------------------------------------|----------|----------------------| +| `id` | path | string(uuid) | true | User ID | +| `body` | body | [legacyscim.SCIMUser](schemas.md#legacyscimscimuser) | true | Replace user request | ### Example responses @@ -4730,10 +4849,10 @@ curl -X PATCH http://coder-server:8080/scim/v2/Users/{id} \ ### Parameters -| Name | In | Type | Required | Description | -|--------|------|----------------------------------------------|----------|---------------------| -| `id` | path | string(uuid) | true | User ID | -| `body` | body | [coderd.SCIMUser](schemas.md#coderdscimuser) | true | Update user request | +| Name | In | Type | Required | Description | +|--------|------|------------------------------------------------------|----------|---------------------| +| `id` | path | string(uuid) | true | User ID | +| `body` | body | [legacyscim.SCIMUser](schemas.md#legacyscimscimuser) | true | Update user request | ### Example responses diff --git a/docs/reference/api/general.md b/docs/reference/api/general.md index 02dbfe4135..98812f55ae 100644 --- a/docs/reference/api/general.md +++ b/docs/reference/api/general.md @@ -538,6 +538,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \ "workspace_agent_logs": 0 }, "scim_api_key": "string", + "scim_use_legacy": true, "session_lifetime": { "default_duration": 0, "default_token_lifetime": 0, diff --git a/docs/reference/api/members.md b/docs/reference/api/members.md index 1556ced557..fae805d3a7 100644 --- a/docs/reference/api/members.md +++ b/docs/reference/api/members.md @@ -193,10 +193,10 @@ Status Code **200** #### Enumerated Values -| Property | Value(s) | -|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` | -| `resource_type` | `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | +| Property | Value(s) | +|-----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` | +| `resource_type` | `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | To perform this operation, you must be authenticated. [Learn more](authentication.md). @@ -326,10 +326,10 @@ Status Code **200** #### Enumerated Values -| Property | Value(s) | -|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` | -| `resource_type` | `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | +| Property | Value(s) | +|-----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` | +| `resource_type` | `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | To perform this operation, you must be authenticated. [Learn more](authentication.md). @@ -459,10 +459,10 @@ Status Code **200** #### Enumerated Values -| Property | Value(s) | -|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` | -| `resource_type` | `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | +| Property | Value(s) | +|-----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` | +| `resource_type` | `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | To perform this operation, you must be authenticated. [Learn more](authentication.md). @@ -554,10 +554,10 @@ Status Code **200** #### Enumerated Values -| Property | Value(s) | -|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` | -| `resource_type` | `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | +| Property | Value(s) | +|-----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` | +| `resource_type` | `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | To perform this operation, you must be authenticated. [Learn more](authentication.md). @@ -960,9 +960,9 @@ Status Code **200** #### Enumerated Values -| Property | Value(s) | -|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` | -| `resource_type` | `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | +| Property | Value(s) | +|-----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` | +| `resource_type` | `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | To perform this operation, you must be authenticated. [Learn more](authentication.md). diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index aaee58512c..ea8f19c4bf 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -220,57 +220,6 @@ |--------------------| | `prebuild_claimed` | -## coderd.SCIMUser - -```json -{ - "active": true, - "emails": [ - { - "display": "string", - "primary": true, - "type": "string", - "value": "user@example.com" - } - ], - "groups": [ - null - ], - "id": "string", - "meta": { - "resourceType": "string" - }, - "name": { - "familyName": "string", - "givenName": "string" - }, - "schemas": [ - "string" - ], - "userName": "string" -} -``` - -### Properties - -| Name | Type | Required | Restrictions | Description | -|------------------|--------------------|----------|--------------|-----------------------------------------------------------------------------| -| `active` | boolean | false | | Active is a ptr to prevent the empty value from being interpreted as false. | -| `emails` | array of object | false | | | -| `» display` | string | false | | | -| `» primary` | boolean | false | | | -| `» type` | string | false | | | -| `» value` | string | false | | | -| `groups` | array of undefined | false | | | -| `id` | string | false | | | -| `meta` | object | false | | | -| `» resourceType` | string | false | | | -| `name` | object | false | | | -| `» familyName` | string | false | | | -| `» givenName` | string | false | | | -| `schemas` | array of string | false | | | -| `userName` | string | false | | | - ## coderd.cspViolation ```json @@ -1352,14 +1301,14 @@ ### Properties -| Name | Type | Required | Restrictions | Description | -|----------------------------|--------|----------|--------------|--------------------------------------------------------------------------------------------| -| `base_url` | string | false | | Base URL is the base URL of the upstream provider API. | -| `bedrock_model` | string | false | | | -| `bedrock_region` | string | false | | | -| `bedrock_small_fast_model` | string | false | | | -| `name` | string | false | | Name is the unique instance identifier used for routing. Defaults to Type if not provided. | -| `type` | string | false | | Type is the provider type: "openai", "anthropic", or "copilot". | +| Name | Type | Required | Restrictions | Description | +|----------------------------|--------|----------|--------------|-------------------------------------------------------------------------------------------------------------------------------------------------------| +| `base_url` | string | false | | Base URL is the base URL of the upstream provider API. | +| `bedrock_model` | string | false | | | +| `bedrock_region` | string | false | | | +| `bedrock_small_fast_model` | string | false | | | +| `name` | string | false | | Name is the unique instance identifier used for routing. Defaults to Type if not provided. | +| `type` | string | false | | Type is the provider type. Valid values are: "openai", "anthropic", "azure", "bedrock", "google", "openai-compat", "openrouter", "vercel", "copilot". | ## codersdk.AIProviderKey @@ -1495,9 +1444,9 @@ None #### Enumerated Values -| Value(s) | -|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `ai_model_price:*`, `ai_model_price:read`, `ai_model_price:update`, `ai_provider:*`, `ai_provider:create`, `ai_provider:delete`, `ai_provider:read`, `ai_provider:update`, `ai_seat:*`, `ai_seat:create`, `ai_seat:read`, `aibridge_interception:*`, `aibridge_interception:create`, `aibridge_interception:read`, `aibridge_interception:update`, `all`, `api_key:*`, `api_key:create`, `api_key:delete`, `api_key:read`, `api_key:update`, `application_connect`, `assign_org_role:*`, `assign_org_role:assign`, `assign_org_role:create`, `assign_org_role:delete`, `assign_org_role:read`, `assign_org_role:unassign`, `assign_org_role:update`, `assign_role:*`, `assign_role:assign`, `assign_role:read`, `assign_role:unassign`, `audit_log:*`, `audit_log:create`, `audit_log:read`, `boundary_usage:*`, `boundary_usage:delete`, `boundary_usage:read`, `boundary_usage:update`, `chat:*`, `chat:create`, `chat:delete`, `chat:read`, `chat:share`, `chat:update`, `coder:all`, `coder:apikeys.manage_self`, `coder:application_connect`, `coder:templates.author`, `coder:templates.build`, `coder:workspaces.access`, `coder:workspaces.create`, `coder:workspaces.delete`, `coder:workspaces.operate`, `connection_log:*`, `connection_log:read`, `connection_log:update`, `crypto_key:*`, `crypto_key:create`, `crypto_key:delete`, `crypto_key:read`, `crypto_key:update`, `debug_info:*`, `debug_info:read`, `deployment_config:*`, `deployment_config:read`, `deployment_config:update`, `deployment_stats:*`, `deployment_stats:read`, `file:*`, `file:create`, `file:read`, `group:*`, `group:create`, `group:delete`, `group:read`, `group:update`, `group_member:*`, `group_member:read`, `idpsync_settings:*`, `idpsync_settings:read`, `idpsync_settings:update`, `inbox_notification:*`, `inbox_notification:create`, `inbox_notification:read`, `inbox_notification:update`, `license:*`, `license:create`, `license:delete`, `license:read`, `notification_message:*`, `notification_message:create`, `notification_message:delete`, `notification_message:read`, `notification_message:update`, `notification_preference:*`, `notification_preference:read`, `notification_preference:update`, `notification_template:*`, `notification_template:read`, `notification_template:update`, `oauth2_app:*`, `oauth2_app:create`, `oauth2_app:delete`, `oauth2_app:read`, `oauth2_app:update`, `oauth2_app_code_token:*`, `oauth2_app_code_token:create`, `oauth2_app_code_token:delete`, `oauth2_app_code_token:read`, `oauth2_app_secret:*`, `oauth2_app_secret:create`, `oauth2_app_secret:delete`, `oauth2_app_secret:read`, `oauth2_app_secret:update`, `organization:*`, `organization:create`, `organization:delete`, `organization:read`, `organization:update`, `organization_member:*`, `organization_member:create`, `organization_member:delete`, `organization_member:read`, `organization_member:update`, `prebuilt_workspace:*`, `prebuilt_workspace:delete`, `prebuilt_workspace:update`, `provisioner_daemon:*`, `provisioner_daemon:create`, `provisioner_daemon:delete`, `provisioner_daemon:read`, `provisioner_daemon:update`, `provisioner_jobs:*`, `provisioner_jobs:create`, `provisioner_jobs:read`, `provisioner_jobs:update`, `replicas:*`, `replicas:read`, `system:*`, `system:create`, `system:delete`, `system:read`, `system:update`, `tailnet_coordinator:*`, `tailnet_coordinator:create`, `tailnet_coordinator:delete`, `tailnet_coordinator:read`, `tailnet_coordinator:update`, `task:*`, `task:create`, `task:delete`, `task:read`, `task:update`, `template:*`, `template:create`, `template:delete`, `template:read`, `template:update`, `template:use`, `template:view_insights`, `usage_event:*`, `usage_event:create`, `usage_event:read`, `usage_event:update`, `user:*`, `user:create`, `user:delete`, `user:read`, `user:read_personal`, `user:update`, `user:update_personal`, `user_secret:*`, `user_secret:create`, `user_secret:delete`, `user_secret:read`, `user_secret:update`, `user_skill:*`, `user_skill:create`, `user_skill:delete`, `user_skill:read`, `user_skill:update`, `webpush_subscription:*`, `webpush_subscription:create`, `webpush_subscription:delete`, `webpush_subscription:read`, `workspace:*`, `workspace:application_connect`, `workspace:create`, `workspace:create_agent`, `workspace:delete`, `workspace:delete_agent`, `workspace:read`, `workspace:share`, `workspace:ssh`, `workspace:start`, `workspace:stop`, `workspace:update`, `workspace:update_agent`, `workspace_agent_devcontainers:*`, `workspace_agent_devcontainers:create`, `workspace_agent_resource_monitor:*`, `workspace_agent_resource_monitor:create`, `workspace_agent_resource_monitor:read`, `workspace_agent_resource_monitor:update`, `workspace_dormant:*`, `workspace_dormant:application_connect`, `workspace_dormant:create`, `workspace_dormant:create_agent`, `workspace_dormant:delete`, `workspace_dormant:delete_agent`, `workspace_dormant:read`, `workspace_dormant:share`, `workspace_dormant:ssh`, `workspace_dormant:start`, `workspace_dormant:stop`, `workspace_dormant:update`, `workspace_dormant:update_agent`, `workspace_proxy:*`, `workspace_proxy:create`, `workspace_proxy:delete`, `workspace_proxy:read`, `workspace_proxy:update` | +| Value(s) | +|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `ai_model_price:*`, `ai_model_price:read`, `ai_model_price:update`, `ai_provider:*`, `ai_provider:create`, `ai_provider:delete`, `ai_provider:read`, `ai_provider:update`, `ai_seat:*`, `ai_seat:create`, `ai_seat:read`, `aibridge_interception:*`, `aibridge_interception:create`, `aibridge_interception:read`, `aibridge_interception:update`, `all`, `api_key:*`, `api_key:create`, `api_key:delete`, `api_key:read`, `api_key:update`, `application_connect`, `assign_org_role:*`, `assign_org_role:assign`, `assign_org_role:create`, `assign_org_role:delete`, `assign_org_role:read`, `assign_org_role:unassign`, `assign_org_role:update`, `assign_role:*`, `assign_role:assign`, `assign_role:read`, `assign_role:unassign`, `audit_log:*`, `audit_log:create`, `audit_log:read`, `boundary_log:*`, `boundary_log:create`, `boundary_log:delete`, `boundary_log:read`, `boundary_usage:*`, `boundary_usage:delete`, `boundary_usage:read`, `boundary_usage:update`, `chat:*`, `chat:create`, `chat:delete`, `chat:read`, `chat:share`, `chat:update`, `coder:all`, `coder:apikeys.manage_self`, `coder:application_connect`, `coder:templates.author`, `coder:templates.build`, `coder:workspaces.access`, `coder:workspaces.create`, `coder:workspaces.delete`, `coder:workspaces.operate`, `connection_log:*`, `connection_log:read`, `connection_log:update`, `crypto_key:*`, `crypto_key:create`, `crypto_key:delete`, `crypto_key:read`, `crypto_key:update`, `debug_info:*`, `debug_info:read`, `deployment_config:*`, `deployment_config:read`, `deployment_config:update`, `deployment_stats:*`, `deployment_stats:read`, `file:*`, `file:create`, `file:read`, `group:*`, `group:create`, `group:delete`, `group:read`, `group:update`, `group_member:*`, `group_member:read`, `idpsync_settings:*`, `idpsync_settings:read`, `idpsync_settings:update`, `inbox_notification:*`, `inbox_notification:create`, `inbox_notification:read`, `inbox_notification:update`, `license:*`, `license:create`, `license:delete`, `license:read`, `notification_message:*`, `notification_message:create`, `notification_message:delete`, `notification_message:read`, `notification_message:update`, `notification_preference:*`, `notification_preference:read`, `notification_preference:update`, `notification_template:*`, `notification_template:read`, `notification_template:update`, `oauth2_app:*`, `oauth2_app:create`, `oauth2_app:delete`, `oauth2_app:read`, `oauth2_app:update`, `oauth2_app_code_token:*`, `oauth2_app_code_token:create`, `oauth2_app_code_token:delete`, `oauth2_app_code_token:read`, `oauth2_app_secret:*`, `oauth2_app_secret:create`, `oauth2_app_secret:delete`, `oauth2_app_secret:read`, `oauth2_app_secret:update`, `organization:*`, `organization:create`, `organization:delete`, `organization:read`, `organization:update`, `organization_member:*`, `organization_member:create`, `organization_member:delete`, `organization_member:read`, `organization_member:update`, `prebuilt_workspace:*`, `prebuilt_workspace:delete`, `prebuilt_workspace:update`, `provisioner_daemon:*`, `provisioner_daemon:create`, `provisioner_daemon:delete`, `provisioner_daemon:read`, `provisioner_daemon:update`, `provisioner_jobs:*`, `provisioner_jobs:create`, `provisioner_jobs:read`, `provisioner_jobs:update`, `replicas:*`, `replicas:read`, `system:*`, `system:create`, `system:delete`, `system:read`, `system:update`, `tailnet_coordinator:*`, `tailnet_coordinator:create`, `tailnet_coordinator:delete`, `tailnet_coordinator:read`, `tailnet_coordinator:update`, `task:*`, `task:create`, `task:delete`, `task:read`, `task:update`, `template:*`, `template:create`, `template:delete`, `template:read`, `template:update`, `template:use`, `template:view_insights`, `usage_event:*`, `usage_event:create`, `usage_event:read`, `usage_event:update`, `user:*`, `user:create`, `user:delete`, `user:read`, `user:read_personal`, `user:update`, `user:update_personal`, `user_secret:*`, `user_secret:create`, `user_secret:delete`, `user_secret:read`, `user_secret:update`, `user_skill:*`, `user_skill:create`, `user_skill:delete`, `user_skill:read`, `user_skill:update`, `webpush_subscription:*`, `webpush_subscription:create`, `webpush_subscription:delete`, `webpush_subscription:read`, `workspace:*`, `workspace:application_connect`, `workspace:create`, `workspace:create_agent`, `workspace:delete`, `workspace:delete_agent`, `workspace:read`, `workspace:share`, `workspace:ssh`, `workspace:start`, `workspace:stop`, `workspace:update`, `workspace:update_agent`, `workspace_agent_devcontainers:*`, `workspace_agent_devcontainers:create`, `workspace_agent_resource_monitor:*`, `workspace_agent_resource_monitor:create`, `workspace_agent_resource_monitor:read`, `workspace_agent_resource_monitor:update`, `workspace_dormant:*`, `workspace_dormant:application_connect`, `workspace_dormant:create`, `workspace_dormant:create_agent`, `workspace_dormant:delete`, `workspace_dormant:delete_agent`, `workspace_dormant:read`, `workspace_dormant:share`, `workspace_dormant:ssh`, `workspace_dormant:start`, `workspace_dormant:stop`, `workspace_dormant:update`, `workspace_dormant:update_agent`, `workspace_proxy:*`, `workspace_proxy:create`, `workspace_proxy:delete`, `workspace_proxy:read`, `workspace_proxy:update` | ## codersdk.AddLicenseRequest @@ -2732,9 +2681,9 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in #### Enumerated Values -| Value(s) | -|------------------------------------------------------------------------------------------------------| -| `auth`, `config`, `generic`, `overloaded`, `rate_limit`, `startup_timeout`, `timeout`, `usage_limit` | +| Value(s) | +|------------------------------------------------------------------------------------------------------------------------------------------| +| `auth`, `config`, `generic`, `missing_key`, `overloaded`, `provider_disabled`, `rate_limit`, `startup_timeout`, `timeout`, `usage_limit` | ## codersdk.ChatFileMetadata @@ -6058,6 +6007,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o "workspace_agent_logs": 0 }, "scim_api_key": "string", + "scim_use_legacy": true, "session_lifetime": { "default_duration": 0, "default_token_lifetime": 0, @@ -6657,6 +6607,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o "workspace_agent_logs": 0 }, "scim_api_key": "string", + "scim_use_legacy": true, "session_lifetime": { "default_duration": 0, "default_token_lifetime": 0, @@ -6817,6 +6768,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o | `redirect_to_access_url` | boolean | false | | | | `retention` | [codersdk.RetentionConfig](#codersdkretentionconfig) | false | | | | `scim_api_key` | string | false | | | +| `scim_use_legacy` | boolean | false | | | | `session_lifetime` | [codersdk.SessionLifetime](#codersdksessionlifetime) | false | | | | `ssh_keygen_algorithm` | string | false | | | | `stats_collection` | [codersdk.StatsCollectionConfig](#codersdkstatscollectionconfig) | false | | | @@ -10866,9 +10818,9 @@ Only certain features set these fields: - FeatureManagedAgentLimit| #### Enumerated Values -| Value(s) | -|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | +| Value(s) | +|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | ## codersdk.RateLimitConfig @@ -13649,6 +13601,22 @@ If the schedule is empty, the user will be updated to use the default schedule.| |----------------------|---------|----------|--------------|-------------| | `spend_limit_micros` | integer | false | | | +## codersdk.UpsertUserAIBudgetOverrideRequest + +```json +{ + "group_id": "306db4e0-7449-4501-b76f-075576fe2d8f", + "spend_limit_micros": 0 +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|----------------------|---------|----------|--------------|---------------------------------------------------------------------------------------------------| +| `group_id` | string | true | | Group ID is the group the user's spend is attributed to. The user must be a member of this group. | +| `spend_limit_micros` | integer | false | | | + ## codersdk.UpsertWorkspaceAgentPortShareRequest ```json @@ -13778,6 +13746,28 @@ If the schedule is empty, the user will be updated to use the default schedule.| |----------|-----------------------| | `status` | `active`, `suspended` | +## codersdk.UserAIBudgetOverride + +```json +{ + "created_at": "2019-08-24T14:15:22Z", + "group_id": "306db4e0-7449-4501-b76f-075576fe2d8f", + "spend_limit_micros": 0, + "updated_at": "2019-08-24T14:15:22Z", + "user_id": "a169451c-8525-4352-b8ca-070dd449a1a5" +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|----------------------|---------|----------|--------------|-------------| +| `created_at` | string | false | | | +| `group_id` | string | false | | | +| `spend_limit_micros` | integer | false | | | +| `updated_at` | string | false | | | +| `user_id` | string | false | | | + ## codersdk.UserActivity ```json @@ -17915,6 +17905,57 @@ Zero means unspecified. There might be a limit, but the client need not try to r None +## legacyscim.SCIMUser + +```json +{ + "active": true, + "emails": [ + { + "display": "string", + "primary": true, + "type": "string", + "value": "user@example.com" + } + ], + "groups": [ + null + ], + "id": "string", + "meta": { + "resourceType": "string" + }, + "name": { + "familyName": "string", + "givenName": "string" + }, + "schemas": [ + "string" + ], + "userName": "string" +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|------------------|--------------------|----------|--------------|-----------------------------------------------------------------------------| +| `active` | boolean | false | | Active is a ptr to prevent the empty value from being interpreted as false. | +| `emails` | array of object | false | | | +| `» display` | string | false | | | +| `» primary` | boolean | false | | | +| `» type` | string | false | | | +| `» value` | string | false | | | +| `groups` | array of undefined | false | | | +| `id` | string | false | | | +| `meta` | object | false | | | +| `» resourceType` | string | false | | | +| `name` | object | false | | | +| `» familyName` | string | false | | | +| `» givenName` | string | false | | | +| `schemas` | array of string | false | | | +| `userName` | string | false | | | + ## netcheck.Report ```json diff --git a/docs/reference/api/users.md b/docs/reference/api/users.md index 376a415031..0bedde7b0c 100644 --- a/docs/reference/api/users.md +++ b/docs/reference/api/users.md @@ -865,11 +865,11 @@ Status Code **200** #### Enumerated Values -| Property | Value(s) | -|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `type` | `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | -| `login_type` | `github`, `oidc`, `password`, `token` | -| `scope` | `all`, `application_connect` | +| Property | Value(s) | +|--------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `type` | `*`, `ai_model_price`, `ai_provider`, `ai_seat`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `user_skill`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` | +| `login_type` | `github`, `oidc`, `password`, `token` | +| `scope` | `all`, `application_connect` | To perform this operation, you must be authenticated. [Learn more](authentication.md). diff --git a/docs/user-guides/workspace-access/index.md b/docs/user-guides/workspace-access/index.md index da72459cbb..ee1bd9aa5c 100644 --- a/docs/user-guides/workspace-access/index.md +++ b/docs/user-guides/workspace-access/index.md @@ -132,7 +132,7 @@ on connecting your JetBrains IDEs. [code-server](https://github.com/coder/code-server) is our supported method of running VS Code in the web browser. Learn more about [what makes code-server different from VS Code web](./code-server.md) or visit the -[documentation for code-server](https://coder.com/docs/code-server/latest). +[documentation for code-server](https://coder.com/docs/code-server). ![code-server in a workspace](../../images/code-server-ide.png) diff --git a/dogfood/coder/main.tf b/dogfood/coder/main.tf index aad65c886f..ad576e543f 100644 --- a/dogfood/coder/main.tf +++ b/dogfood/coder/main.tf @@ -277,7 +277,6 @@ data "coder_external_auth" "github" { data "coder_workspace" "me" {} data "coder_workspace_owner" "me" {} -data "coder_task" "me" {} data "coder_workspace_tags" "tags" { tags = { "cluster" : "dogfood-v2" @@ -991,10 +990,6 @@ resource "coder_metadata" "container_info" { key = "region" value = data.coder_parameter.region.option[index(data.coder_parameter.region.option.*.value, data.coder_parameter.region.value)].name } - item { - key = "ai_task" - value = data.coder_task.me.enabled ? "yes" : "no" - } } resource "coder_script" "boundary_config_setup" { diff --git a/dogfood/vscode-coder/main.tf b/dogfood/vscode-coder/main.tf index eece70b548..791136979f 100644 --- a/dogfood/vscode-coder/main.tf +++ b/dogfood/vscode-coder/main.tf @@ -204,7 +204,6 @@ data "coder_external_auth" "github" { data "coder_workspace" "me" {} data "coder_workspace_owner" "me" {} -data "coder_task" "me" {} data "coder_workspace_tags" "tags" { tags = { "cluster" : "dogfood-v2" @@ -541,99 +540,28 @@ resource "coder_metadata" "container_info" { key = "region" value = data.coder_parameter.region.option[index(data.coder_parameter.region.option.*.value, data.coder_parameter.region.value)].name } - item { - key = "ai_task" - value = data.coder_task.me.enabled ? "yes" : "no" - } -} - -# --- AI task support --- - -locals { - claude_system_prompt = <<-EOT - -- Framing -- - You are a helpful coding assistant working on the coder/vscode-coder - VS Code extension. Aim to autonomously investigate and solve issues - the user gives you and test your work, whenever possible. - - Avoid shortcuts like mocking tests. When you get stuck, you can ask - the user but opt for autonomy. - - -- Tool Selection -- - - Built-in tools for everything: - (file operations, git commands, builds & installs, one-off shell commands) - - -- Testing -- - Integration tests launch a real VS Code instance and require a - virtual framebuffer. Run them headlessly with: - xvfb-run -a pnpm test:integration - This matches how CI runs them. Unit tests do not need xvfb-run: - pnpm test - - -- Workflow -- - When starting new work: - 1. If given a GitHub issue URL, use the `gh` CLI to read the full - issue details with `gh issue view `. - 2. Create a feature branch for the work using a descriptive name - based on the issue or task. - Example: `git checkout -b fix/issue-123-ssh-retry` - 3. Proceed with implementation following the AGENTS.md guidelines. - - -- Context -- - This is the coder/vscode-coder VS Code extension. It is a real-world - production extension used by developers to connect to Coder workspaces. - Be sure to read AGENTS.md before making any changes. - EOT } module "claude-code" { - count = data.coder_task.me.enabled ? data.coder_workspace.me.start_count : 0 - source = "dev.registry.coder.com/coder/claude-code/coder" - version = "4.9.2" - enable_boundary = true - agent_id = coder_agent.dev.id - workdir = local.repo_dir - claude_code_version = "latest" - model = "opus" - order = 999 - claude_api_key = data.coder_parameter.use_ai_bridge.value ? data.coder_workspace_owner.me.session_token : var.anthropic_api_key - agentapi_version = "latest" - system_prompt = local.claude_system_prompt - ai_prompt = data.coder_task.me.prompt + count = data.coder_workspace.me.start_count + source = "dev.registry.coder.com/coder/claude-code/coder" + version = "5.2.0" + enable_ai_gateway = data.coder_parameter.use_ai_bridge.value + anthropic_api_key = data.coder_parameter.use_ai_bridge.value ? "" : var.anthropic_api_key + agent_id = coder_agent.dev.id + workdir = local.repo_dir } -resource "coder_ai_task" "task" { - count = data.coder_task.me.enabled ? data.coder_workspace.me.start_count : 0 - app_id = module.claude-code[count.index].task_app_id -} - -resource "coder_app" "watch" { - count = data.coder_task.me.enabled ? data.coder_workspace.me.start_count : 0 +resource "coder_app" "claude" { agent_id = coder_agent.dev.id - slug = "watch" - display_name = "pnpm watch" - icon = "${data.coder_workspace.me.access_url}/icon/code.svg" - command = "screen -x pnpm_watch" - share = "authenticated" - open_in = "tab" - order = 0 -} - -resource "coder_script" "watch" { - count = data.coder_task.me.enabled ? data.coder_workspace.me.start_count : 0 - display_name = "pnpm watch" - agent_id = coder_agent.dev.id - run_on_start = true - start_blocks_login = false - icon = "${data.coder_workspace.me.access_url}/icon/code.svg" - script = <<-EOT - #!/usr/bin/env bash - set -eux -o pipefail - - trap 'coder exp sync complete pnpm-watch' EXIT - coder exp sync want pnpm-watch install-deps - coder exp sync start pnpm-watch - - cd "${local.repo_dir}" && screen -dmS pnpm_watch /bin/sh -c 'while true; do pnpm watch; echo "pnpm watch exited with code $? restarting in 10s"; sleep 10; done' + slug = "claude" + display_name = "Claude Code" + icon = "/icon/claude.svg" + open_in = "slim-window" + command = <<-EOT + #!/bin/bash + set -e + cd "${local.repo_dir}" + exec tmux new-session -A -s claude claude EOT } diff --git a/enterprise/aibridgeproxyd/aibridgeproxyd.go b/enterprise/aibridgeproxyd/aibridgeproxyd.go index 19d05ca511..cfcb2071c4 100644 --- a/enterprise/aibridgeproxyd/aibridgeproxyd.go +++ b/enterprise/aibridgeproxyd/aibridgeproxyd.go @@ -31,18 +31,6 @@ import ( agplaibridge "github.com/coder/coder/v2/coderd/aibridge" ) -// ProviderRoute is the routing entry for a single AI provider: the -// instance name (the routing key) and the upstream base URL (the -// source of the MITM allowlist host). -type ProviderRoute struct { - Name string - BaseURL string -} - -// RefreshProvidersFunc returns the live provider set used by Reload to -// rebuild the proxy's routing snapshot. -type RefreshProvidersFunc func(ctx context.Context) ([]ProviderRoute, error) - // Known AI provider hosts. const ( HostAnthropic = "api.anthropic.com" @@ -161,7 +149,7 @@ type Server struct { // providerRouter keeps CONNECT matching and provider lookup in sync. type providerRouter struct { - mitmHosts []string // host:port allowlist for the goproxy condition. + mitmHosts []string // host:port set the goproxy condition matches against. nameByHost map[string]string // lowercase hostname -> provider name. } @@ -218,15 +206,8 @@ type Options struct { // CertStore is an optional certificate cache for MITM. If nil, a default // cache is created. Exposed for testing. CertStore goproxy.CertStorage - // DomainAllowlist seeds the boot-time MITM allowlist. Production - // callers should leave this empty and rely on RefreshProviders; - // tests use it to skip the refresh round-trip. - DomainAllowlist []string - // AIBridgeProviderFromHost seeds the boot-time host -> provider - // name mapping. Required iff DomainAllowlist is non-empty. - AIBridgeProviderFromHost func(host string) string // UpstreamProxy is the URL of an upstream HTTP proxy to chain tunneled - // (non-allowlisted) requests through. If empty, tunneled requests connect + // (non-provider-host) requests through. If empty, tunneled requests connect // directly to their destinations. // Format: http://[user:pass@]host:port or https://[user:pass@]host:port UpstreamProxy string @@ -249,7 +230,7 @@ type Options struct { // If nil, metrics will not be recorded. Metrics *Metrics // RefreshProviders, when set, is invoked by Server.Reload to fetch - // the live provider snapshot used to derive the MITM allowlist and + // the live provider snapshot used to derive the MITM host set and // host -> provider-name routing. Nil disables hot-reload. RefreshProviders RefreshProvidersFunc } @@ -296,14 +277,6 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) allowedPorts = []string{"80", "443"} } - // Build the boot-time router from DomainAllowlist + the lookup fn. - // Both empty is fine: the server fails closed (no MITM until - // Reload populates the router from the database). - bootRouter, err := buildBootRouter(opts.DomainAllowlist, opts.AIBridgeProviderFromHost, allowedPorts) - if err != nil { - return nil, err - } - // Parse configured exceptions to the blocked IP ranges. allowedPrivateRanges := make([]net.IPNet, 0, len(opts.AllowedPrivateCIDRs)) for _, cidr := range opts.AllowedPrivateCIDRs { @@ -352,13 +325,13 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) newDumper: opts.NewDumper, metrics: opts.Metrics, } - // Seed the boot-time router from the constructor inputs so the - // proxy can serve immediately. Reload may swap this snapshot at any - // point after construction. - srv.providerRouter.Store(bootRouter) + // Start with an empty router; the first Reload populates it from + // the configured provider source. The proxy fails closed (no MITM) + // until that happens. + srv.providerRouter.Store(emptyProviderRouter) - // Configure upstream proxy for tunneled (non-allowlisted) CONNECT requests. - // Allowlisted domains are MITM'd and forwarded to aibridge directly, + // Configure upstream proxy for tunneled (non-provider-host) CONNECT requests. + // Provider-host domains are MITM'd and forwarded to aibridge directly, // bypassing the upstream proxy. if opts.UpstreamProxy != "" { upstreamURL, err := url.Parse(opts.UpstreamProxy) @@ -443,7 +416,7 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) // Reject CONNECT requests to non-standard ports. proxy.OnRequest().HandleConnectFunc(srv.portMiddleware(allowedPorts)) - // Apply MITM with authentication only to allowlisted hosts. The host + // Apply MITM with authentication only to provider hosts. The host // list is loaded from the atomic router on every CONNECT so a // Reload while inflight requests are in progress takes effect on // the next CONNECT without touching the already-MITM'd ones. @@ -452,9 +425,9 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) srv.authMiddleware, ) - // Tunnel CONNECT requests for non-allowlisted domains directly to their destination. + // Tunnel CONNECT requests for non-provider-host domains directly to their destination. // goproxy calls handlers in registration order: this must come after the MITM handler - // so it only handles requests that weren't matched by the allowlist. + // so it only handles requests that weren't matched as provider hosts. proxy.OnRequest().HandleConnectFunc(srv.tunneledMiddleware) // Handle decrypted requests: route to aibridged for known AI providers, or tunnel to original destination. @@ -495,7 +468,6 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) slog.F("listen_addr", listener.Addr().String()), slog.F("tls_listener_enabled", srv.tlsEnabled), slog.F("coder_access_url", coderAccessURL.String()), - slog.F("domain_allowlist", bootRouter.mitmHosts), slog.F("upstream_proxy", opts.UpstreamProxy), slog.F("allowed_private_cidrs", opts.AllowedPrivateCIDRs), slog.F("api_dump_enabled", opts.NewDumper != nil), @@ -810,7 +782,7 @@ func newProxyAuthRequiredResponse(req *http.Request) *http.Response { } } -// tunneledMiddleware is a CONNECT middleware that handles tunneled (non-allowlisted) +// tunneledMiddleware is a CONNECT middleware that handles tunneled (non-provider-host) // connections. These connections are not MITM'd and are tunneled directly to their // destination. This middleware records metrics for tunneled CONNECT sessions. func (s *Server) tunneledMiddleware(host string, _ *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) { @@ -946,16 +918,28 @@ func (s *Server) handleRequest(req *http.Request, ctx *goproxy.ProxyCtx) (*http. return req, resp } - if reqCtx.Provider == "" { - // A concurrent Reload can remove the provider after CONNECT - // authentication. The request is MITM'd (decrypted), but without a - // mapping there is no known route to aibridge. Log and forward - // to the original destination as a fallback. - s.logger.Warn(s.ctx, "decrypted request has no provider mapping, passing through", + // Re-validate the CONNECT-time provider against the live router. + // A long-lived CONNECT tunnel can outlive a provider being disabled, + // removed, or renamed: the captured reqCtx.Provider is stale, but + // subsequent decrypted requests would still route to aibridged if we + // trusted it. Look up the provider for the current request's host + // and pass through if the mapping is gone or has changed. + host := req.URL.Hostname() + if host == "" { + host = req.Host + if h, _, splitErr := net.SplitHostPort(host); splitErr == nil { + host = h + } + } + liveProvider := s.loadProviderRouter().providerFromHost(host) + if liveProvider == "" || liveProvider != reqCtx.Provider { + s.logger.Warn(s.ctx, "provider mapping changed or removed since CONNECT, passing through", slog.F("connect_id", reqCtx.ConnectSessionID.String()), slog.F("host", req.Host), slog.F("method", req.Method), slog.F("path", originalPath), + slog.F("connect_provider", reqCtx.Provider), + slog.F("live_provider", liveProvider), ) return req, nil } @@ -1053,8 +1037,13 @@ func injectBYOKHeaderIfNeeded(header http.Header, coderToken string) { } // handleResponse handles responses received from aibridged. -// This is only called for MITM'd requests (allowlisted domains routed through aibridged). -// Tunneled requests (non-allowlisted domains) bypass this handler entirely. +// This is called for every MITM'd request, including the pass-through +// path where handleRequest re-validated the CONNECT-time provider and +// forwarded the request to the original upstream instead of aibridged. +// Pass-through responses are identified by reqCtx.RequestID == uuid.Nil +// (set only when handleRequest routes to aibridged) and are skipped here +// to avoid mislabeled logs and corrupting MITM metrics. +// Tunneled requests (non-provider-host domains) bypass this handler entirely. func (s *Server) handleResponse(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response { if resp == nil { return nil @@ -1077,6 +1066,14 @@ func (s *Server) handleResponse(resp *http.Response, ctx *goproxy.ProxyCtx) *htt slog.F("status", resp.StatusCode), ) + // Pass-through responses (handleRequest returned without routing to + // aibridged) come from the real upstream. The aibridged-specific log + // and metrics do not apply; the pass-through itself is already logged + // in handleRequest. + if requestID == uuid.Nil { + return resp + } + switch { case resp.StatusCode >= http.StatusInternalServerError: logger.Error(s.ctx, "received error response from aibridged") diff --git a/enterprise/aibridgeproxyd/aibridgeproxyd_test.go b/enterprise/aibridgeproxyd/aibridgeproxyd_test.go index 6b843d8b14..50224aa98c 100644 --- a/enterprise/aibridgeproxyd/aibridgeproxyd_test.go +++ b/enterprise/aibridgeproxyd/aibridgeproxyd_test.go @@ -3,6 +3,7 @@ package aibridgeproxyd_test import ( "bufio" "bytes" + "context" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -34,6 +35,7 @@ import ( "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/coder/v2/aibridge" agplaibridge "github.com/coder/coder/v2/coderd/aibridge" + "github.com/coder/coder/v2/coderd/aibridged" "github.com/coder/coder/v2/enterprise/aibridgeproxyd" "github.com/coder/coder/v2/testutil" ) @@ -145,20 +147,19 @@ func generateListenerCert(t *testing.T) (certFile, keyFile string) { } type testProxyConfig struct { - listenAddr string - tlsCertFile string - tlsKeyFile string - coderAccessURL string - allowedPorts []string - certStore *aibridgeproxyd.CertCache - domainAllowlist []string - aibridgeProviderFromHost func(string) string - upstreamProxy string - upstreamProxyCA string - allowedPrivateCIDRs []string - newDumper func(string, string) aibridgeproxyd.RoundTripDumper - metrics *aibridgeproxyd.Metrics - refreshProviders aibridgeproxyd.RefreshProvidersFunc + listenAddr string + tlsCertFile string + tlsKeyFile string + coderAccessURL string + allowedPorts []string + certStore *aibridgeproxyd.CertCache + providers []aibridgeproxyd.ReloadedProvider + upstreamProxy string + upstreamProxyCA string + allowedPrivateCIDRs []string + newDumper func(string, string) aibridgeproxyd.RoundTripDumper + metrics *aibridgeproxyd.Metrics + refreshProviders aibridgeproxyd.RefreshProvidersFunc } type testProxyOption func(*testProxyConfig) @@ -181,15 +182,43 @@ func withCertStore(store *aibridgeproxyd.CertCache) testProxyOption { } } -func withDomainAllowlist(domains ...string) testProxyOption { +// withProviders configures the proxy with the given classified provider +// set. The reload helper synthesizes a RefreshProvidersFunc and the +// router is populated synchronously during newTestProxy before the +// server begins serving. +func withProviders(providers ...aibridgeproxyd.ReloadedProvider) testProxyOption { return func(cfg *testProxyConfig) { - cfg.domainAllowlist = domains + cfg.providers = providers } } -func withAIBridgeProviderFromHost(fn func(string) string) testProxyOption { +// withProviderHosts is a convenience that builds enabled +// ReloadedProvider entries from each host, looking up the well-known +// provider name via testProviderFromHost and falling back to +// "test-provider" for hosts without a well-known mapping. Equivalent +// to passing each entry individually to withProviders. +func withProviderHosts(hosts ...string) testProxyOption { return func(cfg *testProxyConfig) { - cfg.aibridgeProviderFromHost = fn + providers := make([]aibridgeproxyd.ReloadedProvider, 0, len(hosts)) + for _, h := range hosts { + name := testProviderFromHost(h) + if name == "" { + name = "test-provider" + } + host, _, splitErr := net.SplitHostPort(h) + if splitErr != nil { + host = h + } + providers = append(providers, aibridgeproxyd.ReloadedProvider{ + ProviderOutcome: aibridged.ProviderOutcome{ + Name: name, + Type: "openai", + Status: aibridged.ProviderStatusEnabled, + }, + Host: strings.ToLower(host), + }) + } + cfg.providers = providers } } @@ -264,39 +293,48 @@ func newTestProxy(t *testing.T, opts ...testProxyOption) *aibridgeproxyd.Server t.Helper() cfg := &testProxyConfig{ - listenAddr: "127.0.0.1:0", - coderAccessURL: "http://localhost:3000", - domainAllowlist: []string{"127.0.0.1", "localhost"}, + listenAddr: "127.0.0.1:0", + coderAccessURL: "http://localhost:3000", // Allow 127.0.0.1 by default so test servers, which always listen on // loopback, are reachable. Tests that verify IP blocking override this. allowedPrivateCIDRs: []string{"127.0.0.1/32"}, - aibridgeProviderFromHost: func(host string) string { - return "test-provider" + providers: []aibridgeproxyd.ReloadedProvider{ + {ProviderOutcome: aibridged.ProviderOutcome{Name: "test-provider", Type: "openai", Status: aibridged.ProviderStatusEnabled}, Host: "127.0.0.1"}, + {ProviderOutcome: aibridged.ProviderOutcome{Name: "test-provider", Type: "openai", Status: aibridged.ProviderStatusEnabled}, Host: "localhost"}, }, } for _, opt := range opts { opt(cfg) } + // If the test did not supply a RefreshProviders, synthesize one + // that returns the configured providers verbatim. This populates + // the router synchronously below, mirroring how production starts + // up after the first reload completes. + if cfg.refreshProviders == nil { + providers := cfg.providers + cfg.refreshProviders = func(context.Context) (aibridgeproxyd.ProviderReload, error) { + return aibridgeproxyd.ProviderReload{Providers: providers}, nil + } + } + mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) aibridgeOpts := aibridgeproxyd.Options{ - ListenAddr: cfg.listenAddr, - TLSCertFile: cfg.tlsCertFile, - TLSKeyFile: cfg.tlsKeyFile, - CoderAccessURL: cfg.coderAccessURL, - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - AllowedPorts: cfg.allowedPorts, - DomainAllowlist: cfg.domainAllowlist, - AIBridgeProviderFromHost: cfg.aibridgeProviderFromHost, - UpstreamProxy: cfg.upstreamProxy, - UpstreamProxyCA: cfg.upstreamProxyCA, - AllowedPrivateCIDRs: cfg.allowedPrivateCIDRs, - NewDumper: cfg.newDumper, - Metrics: cfg.metrics, - RefreshProviders: cfg.refreshProviders, + ListenAddr: cfg.listenAddr, + TLSCertFile: cfg.tlsCertFile, + TLSKeyFile: cfg.tlsKeyFile, + CoderAccessURL: cfg.coderAccessURL, + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + AllowedPorts: cfg.allowedPorts, + UpstreamProxy: cfg.upstreamProxy, + UpstreamProxyCA: cfg.upstreamProxyCA, + AllowedPrivateCIDRs: cfg.allowedPrivateCIDRs, + NewDumper: cfg.newDumper, + Metrics: cfg.metrics, + RefreshProviders: cfg.refreshProviders, } if cfg.certStore != nil { aibridgeOpts.CertStore = cfg.certStore @@ -306,6 +344,10 @@ func newTestProxy(t *testing.T, opts ...testProxyOption) *aibridgeproxyd.Server require.NoError(t, err) t.Cleanup(func() { _ = srv.Close() }) + // Populate the router before the server starts handling traffic. + // Production performs the first reload during boot via pubsub. + require.NoError(t, srv.Reload(t.Context())) + // Wait for the proxy server to be ready. proxyAddr := srv.Addr() require.NotEmpty(t, proxyAddr) @@ -444,10 +486,9 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.Error(t, err) require.Contains(t, err.Error(), "listen address is required") @@ -460,11 +501,10 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, + ListenAddr: "", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.Error(t, err) require.Contains(t, err.Error(), "listen address is required") @@ -477,12 +517,11 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - TLSCertFile: "cert.pem", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, + ListenAddr: "127.0.0.1:0", + TLSCertFile: "cert.pem", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.Error(t, err) require.Contains(t, err.Error(), "tls cert file and tls key file must both be set") @@ -495,12 +534,11 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - TLSKeyFile: "key.pem", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, + ListenAddr: "127.0.0.1:0", + TLSKeyFile: "key.pem", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.Error(t, err) require.Contains(t, err.Error(), "tls cert file and tls key file must both be set") @@ -513,14 +551,12 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - TLSCertFile: "/nonexistent/cert.pem", - TLSKeyFile: "/nonexistent/key.pem", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, + ListenAddr: "127.0.0.1:0", + TLSCertFile: "/nonexistent/cert.pem", + TLSKeyFile: "/nonexistent/key.pem", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.Error(t, err) require.Contains(t, err.Error(), "load listener TLS certificate") @@ -533,10 +569,9 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, + ListenAddr: "127.0.0.1:0", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.Error(t, err) require.Contains(t, err.Error(), "coder access URL is required") @@ -549,11 +584,10 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: " ", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, + ListenAddr: "127.0.0.1:0", + CoderAccessURL: " ", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.Error(t, err) require.Contains(t, err.Error(), "coder access URL is required") @@ -566,11 +600,10 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "://invalid", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "://invalid", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.Error(t, err) require.Contains(t, err.Error(), "invalid coder access URL") @@ -583,12 +616,10 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic}, - AIBridgeProviderFromHost: testProviderFromHost, + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.NoError(t, err) require.Equal(t, "localhost", srv.CoderAccessURL().Hostname()) @@ -602,12 +633,10 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "https://localhost", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic}, - AIBridgeProviderFromHost: testProviderFromHost, + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "https://localhost", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.NoError(t, err) require.Equal(t, "localhost", srv.CoderAccessURL().Hostname()) @@ -621,12 +650,10 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic}, - AIBridgeProviderFromHost: testProviderFromHost, + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.NoError(t, err) require.Equal(t, "localhost", srv.CoderAccessURL().Hostname()) @@ -639,10 +666,9 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: ":0", - CoderAccessURL: "http://localhost:3000", - MITMKeyFile: "key.pem", - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, + ListenAddr: ":0", + CoderAccessURL: "http://localhost:3000", + MITMKeyFile: "key.pem", }) require.Error(t, err) require.Contains(t, err.Error(), "cert file and key file are required") @@ -654,10 +680,9 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: ":0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: "cert.pem", - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, + ListenAddr: ":0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: "cert.pem", }) require.Error(t, err) require.Contains(t, err.Error(), "cert file and key file are required") @@ -669,104 +694,15 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: ":0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: "/nonexistent/cert.pem", - MITMKeyFile: "/nonexistent/key.pem", - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, + ListenAddr: ":0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: "/nonexistent/cert.pem", + MITMKeyFile: "/nonexistent/key.pem", }) require.Error(t, err) require.Contains(t, err.Error(), "failed to load MITM certificate") }) - t.Run("MissingDomainAllowlist", func(t *testing.T) { - t.Parallel() - - mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t) - logger := slogtest.Make(t, nil) - - srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: ":0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - AIBridgeProviderFromHost: testProviderFromHost, - }) - require.NoError(t, err) - t.Cleanup(func() { _ = srv.Close() }) - }) - - t.Run("EmptyDomainAllowlist", func(t *testing.T) { - t.Parallel() - - mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t) - logger := slogtest.Make(t, nil) - - srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: ":0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{""}, - AIBridgeProviderFromHost: testProviderFromHost, - }) - require.NoError(t, err) - t.Cleanup(func() { _ = srv.Close() }) - }) - - t.Run("InvalidDomainAllowlist", func(t *testing.T) { - t.Parallel() - - mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t) - logger := slogtest.Make(t, nil) - - _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{"[invalid:domain"}, - }) - require.Error(t, err) - require.Contains(t, err.Error(), "invalid domain") - }) - - t.Run("DomainWithNonAllowedPort", func(t *testing.T) { - t.Parallel() - - mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t) - logger := slogtest.Make(t, nil) - - _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{"api.anthropic.com:8443"}, - }) - require.Error(t, err) - require.Contains(t, err.Error(), "invalid port in domain") - }) - - t.Run("AllowlistWithoutProviderMapping", func(t *testing.T) { - t.Parallel() - - mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t) - logger := slogtest.Make(t, nil) - - _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{"unknown.example.com"}, - AIBridgeProviderFromHost: testProviderFromHost, - }) - require.Error(t, err) - require.Contains(t, err.Error(), `domain "unknown.example.com" is in allowlist but has no provider mapping`) - }) - t.Run("InvalidUpstreamProxy", func(t *testing.T) { t.Parallel() @@ -774,13 +710,11 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, - UpstreamProxy: "://invalid-url", + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + UpstreamProxy: "://invalid-url", }) require.Error(t, err) require.Contains(t, err.Error(), "invalid upstream proxy URL") @@ -793,14 +727,12 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, - UpstreamProxy: "https://proxy.example.com:8080", - UpstreamProxyCA: "/nonexistent/ca.pem", + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + UpstreamProxy: "https://proxy.example.com:8080", + UpstreamProxyCA: "/nonexistent/ca.pem", }) require.Error(t, err) require.Contains(t, err.Error(), "failed to read upstream proxy CA certificate") @@ -813,13 +745,11 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, - UpstreamProxy: "http://:@proxy.example.com:8080", + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + UpstreamProxy: "http://:@proxy.example.com:8080", }) require.Error(t, err) require.Contains(t, err.Error(), "invalid credentials: both username and password are empty") @@ -832,13 +762,11 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, - AllowedPrivateCIDRs: []string{"not-a-cidr"}, + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + AllowedPrivateCIDRs: []string{"not-a-cidr"}, }) require.Error(t, err) require.Contains(t, err.Error(), "invalid allowed private CIDR") @@ -851,12 +779,10 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.NoError(t, err) require.NotNil(t, srv) @@ -870,14 +796,12 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - TLSCertFile: listenerCertFile, - TLSKeyFile: listenerKeyFile, - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, + ListenAddr: "127.0.0.1:0", + TLSCertFile: listenerCertFile, + TLSKeyFile: listenerKeyFile, + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.NoError(t, err) require.NotNil(t, srv) @@ -890,13 +814,11 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, - UpstreamProxy: "http://proxy.example.com:8080", + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + UpstreamProxy: "http://proxy.example.com:8080", }) require.NoError(t, err) require.NotNil(t, srv) @@ -910,14 +832,12 @@ func TestNew(t *testing.T) { // Use the shared MITM certificate as the upstream proxy CA (it's a valid PEM cert) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, - UpstreamProxy: "https://proxy.example.com:8080", - UpstreamProxyCA: mitmCertFile, + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + UpstreamProxy: "https://proxy.example.com:8080", + UpstreamProxyCA: mitmCertFile, }) require.NoError(t, err) require.NotNil(t, srv) @@ -930,13 +850,11 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, - UpstreamProxy: "http://proxyuser:proxypass@proxy.example.com:8080", + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + UpstreamProxy: "http://proxyuser:proxypass@proxy.example.com:8080", }) require.NoError(t, err) require.NotNil(t, srv) @@ -949,13 +867,11 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, - UpstreamProxy: "http://proxyuser:@proxy.example.com:8080", + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + UpstreamProxy: "http://proxyuser:@proxy.example.com:8080", }) require.NoError(t, err) require.NotNil(t, srv) @@ -969,13 +885,11 @@ func TestNew(t *testing.T) { // Username only (no colon) should also succeed (password is optional) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, - UpstreamProxy: "http://proxyuser@proxy.example.com:8080", + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + UpstreamProxy: "http://proxyuser@proxy.example.com:8080", }) require.NoError(t, err) require.NotNil(t, srv) @@ -988,13 +902,11 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, - UpstreamProxy: "http://:proxypass@proxy.example.com:8080", + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + UpstreamProxy: "http://:proxypass@proxy.example.com:8080", }) require.NoError(t, err) require.NotNil(t, srv) @@ -1011,13 +923,11 @@ func TestNew(t *testing.T) { metrics := aibridgeproxyd.NewMetrics(reg) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, - Metrics: metrics, + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + Metrics: metrics, }) require.NoError(t, err) require.NotNil(t, srv) @@ -1030,13 +940,11 @@ func TestNew(t *testing.T) { logger := slogtest.Make(t, nil) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, - AllowedPrivateCIDRs: []string{"127.0.0.1/32"}, + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + AllowedPrivateCIDRs: []string{"127.0.0.1/32"}, }) require.NoError(t, err) require.NotNil(t, srv) @@ -1053,12 +961,10 @@ func TestClose(t *testing.T) { logger := slogtest.Make(t, nil) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, }) require.NoError(t, err) @@ -1081,13 +987,11 @@ func TestClose(t *testing.T) { metrics := aibridgeproxyd.NewMetrics(reg) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - AIBridgeProviderFromHost: testProviderFromHost, - Metrics: metrics, + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + Metrics: metrics, }) require.NoError(t, err) @@ -1110,19 +1014,19 @@ func TestProxy_CertCaching(t *testing.T) { t.Parallel() tests := []struct { - name string - domainAllowlist []string - tunneled bool + name string + providerHosts []string + tunneled bool }{ { - name: "AllowlistedDomainCached", - domainAllowlist: nil, // will use targetURL.Hostname() - tunneled: false, + name: "ProviderHostCached", + providerHosts: nil, // will use targetURL.Hostname() + tunneled: false, }, { - name: "NonAllowlistedDomainNotCached", - domainAllowlist: []string{"other.example.com"}, - tunneled: true, + name: "NonProviderHostNotCached", + providerHosts: []string{"other.example.com"}, + tunneled: true, }, } @@ -1135,7 +1039,7 @@ func TestProxy_CertCaching(t *testing.T) { w.WriteHeader(http.StatusOK) }) - // Create a mock aibridged server for allowlisted (MITM'd) requests. + // Create a mock aibridged server for provider-host (MITM'd) requests. aibridgedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) @@ -1144,10 +1048,10 @@ func TestProxy_CertCaching(t *testing.T) { // Create a cert cache so we can inspect it after the request. certCache := aibridgeproxyd.NewCertCache() - // Configure domain allowlist. - domainAllowlist := tt.domainAllowlist - if domainAllowlist == nil { - domainAllowlist = []string{targetURL.Hostname()} + // Configure provider hosts. + providerHosts := tt.providerHosts + if providerHosts == nil { + providerHosts = []string{targetURL.Hostname()} } // Start the proxy server with the certificate cache. @@ -1155,7 +1059,7 @@ func TestProxy_CertCaching(t *testing.T) { withCoderAccessURL(aibridgedServer.URL), withAllowedPorts(targetURL.Port()), withCertStore(certCache), - withDomainAllowlist(domainAllowlist...), + withProviderHosts(providerHosts...), ) // Build the cert pool for the client to trust: @@ -1189,7 +1093,7 @@ func TestProxy_CertCaching(t *testing.T) { if tt.tunneled { // Certificate should NOT have been cached since request was tunneled. - require.Equal(t, 1, genCalls, "certificate should NOT have been cached for non-allowlisted domain") + require.Equal(t, 1, genCalls, "certificate should NOT have been cached for non-provider-host") } else { // Certificate should have been cached during MITM. require.Equal(t, 0, genCalls, "certificate should have been cached during request") @@ -1233,7 +1137,7 @@ func TestProxy_PortValidation(t *testing.T) { _, _ = w.Write([]byte("hello from target")) }) - // Create a mock aibridged server for allowlisted (MITM'd) requests. + // Create a mock aibridged server for provider-host (MITM'd) requests. aibridgedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("hello from aibridged")) @@ -1244,7 +1148,7 @@ func TestProxy_PortValidation(t *testing.T) { srv := newTestProxy(t, withCoderAccessURL(aibridgedServer.URL), withAllowedPorts(tt.allowedPorts(targetURL)...), - withDomainAllowlist(targetURL.Hostname()), + withProviderHosts(targetURL.Hostname()), ) // Make a request through the proxy to the target server. @@ -1309,7 +1213,7 @@ func TestProxy_Authentication(t *testing.T) { _, _ = w.Write([]byte("hello from target")) }) - // Create a mock aibridged server for allowlisted (MITM'd) requests. + // Create a mock aibridged server for provider-host (MITM'd) requests. aibridgedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("hello from aibridged")) @@ -1320,7 +1224,7 @@ func TestProxy_Authentication(t *testing.T) { srv := newTestProxy(t, withCoderAccessURL(aibridgedServer.URL), withAllowedPorts(targetURL.Port()), - withDomainAllowlist(targetURL.Hostname()), + withProviderHosts(targetURL.Hostname()), ) if tt.expectSuccess { @@ -1365,18 +1269,18 @@ func TestProxy_MITM(t *testing.T) { t.Parallel() tests := []struct { - name string - domainAllowlist []string - allowedPorts []string - buildTargetURL func(tunneledURL *url.URL) (string, error) - tunneled bool - expectedPath string - provider string + name string + providerHosts []string + allowedPorts []string + buildTargetURL func(tunneledURL *url.URL) (string, error) + tunneled bool + expectedPath string + provider string }{ { - name: "MitmdAnthropic", - domainAllowlist: []string{aibridgeproxyd.HostAnthropic}, - allowedPorts: []string{"443"}, + name: "MitmdAnthropic", + providerHosts: []string{aibridgeproxyd.HostAnthropic}, + allowedPorts: []string{"443"}, buildTargetURL: func(_ *url.URL) (string, error) { return "https://api.anthropic.com/v1/messages", nil }, @@ -1384,9 +1288,9 @@ func TestProxy_MITM(t *testing.T) { provider: "anthropic", }, { - name: "MitmdAnthropicNonDefaultPort", - domainAllowlist: []string{aibridgeproxyd.HostAnthropic}, - allowedPorts: []string{"8443"}, + name: "MitmdAnthropicNonDefaultPort", + providerHosts: []string{aibridgeproxyd.HostAnthropic}, + allowedPorts: []string{"8443"}, buildTargetURL: func(_ *url.URL) (string, error) { return "https://api.anthropic.com:8443/v1/messages", nil }, @@ -1394,9 +1298,9 @@ func TestProxy_MITM(t *testing.T) { provider: "anthropic", }, { - name: "MitmdOpenAI", - domainAllowlist: []string{aibridgeproxyd.HostOpenAI}, - allowedPorts: []string{"443"}, + name: "MitmdOpenAI", + providerHosts: []string{aibridgeproxyd.HostOpenAI}, + allowedPorts: []string{"443"}, buildTargetURL: func(_ *url.URL) (string, error) { return "https://api.openai.com/v1/chat/completions", nil }, @@ -1404,9 +1308,9 @@ func TestProxy_MITM(t *testing.T) { provider: "openai", }, { - name: "MitmdOpenAINonDefaultPort", - domainAllowlist: []string{aibridgeproxyd.HostOpenAI}, - allowedPorts: []string{"8443"}, + name: "MitmdOpenAINonDefaultPort", + providerHosts: []string{aibridgeproxyd.HostOpenAI}, + allowedPorts: []string{"8443"}, buildTargetURL: func(_ *url.URL) (string, error) { return "https://api.openai.com:8443/v1/chat/completions", nil }, @@ -1414,9 +1318,9 @@ func TestProxy_MITM(t *testing.T) { provider: "openai", }, { - name: "TunneledUnknownHost", - domainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, - allowedPorts: nil, // will use tunneledURL.Port() + name: "TunneledUnknownHost", + providerHosts: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI}, + allowedPorts: nil, // will use tunneledURL.Port() buildTargetURL: func(tunneledURL *url.URL) (string, error) { return url.JoinPath(tunneledURL.String(), "/some/path") }, @@ -1458,18 +1362,17 @@ func TestProxy_MITM(t *testing.T) { allowedPorts = []string{tunneledURL.Port()} } - // Configure domain allowlist. - domainAllowlist := tt.domainAllowlist - if domainAllowlist == nil { - domainAllowlist = []string{tunneledURL.Hostname()} + // Configure provider hosts. + providerHosts := tt.providerHosts + if providerHosts == nil { + providerHosts = []string{tunneledURL.Hostname()} } // Start the proxy server pointing to our mock aibridged. srv := newTestProxy(t, withCoderAccessURL(aibridgedServer.URL), withAllowedPorts(allowedPorts...), - withDomainAllowlist(domainAllowlist...), - withAIBridgeProviderFromHost(testProviderFromHost), + withProviderHosts(providerHosts...), withMetrics(metrics), ) @@ -1607,8 +1510,7 @@ func TestProxy_MITM_BYOKInjection(t *testing.T) { srv := newTestProxy(t, withCoderAccessURL(aibridgedServer.URL), - withDomainAllowlist(aibridgeproxyd.HostCopilot), - withAIBridgeProviderFromHost(testProviderFromHost), + withProviderHosts(aibridgeproxyd.HostCopilot), ) certPool := getProxyCertPool(t) @@ -1687,8 +1589,8 @@ func TestListenerTLS(t *testing.T) { withAllowedPorts(targetURL.Port()), ) if tt.tunneled { - // Use a domain allowlist that excludes the target server so requests are tunneled. - proxyOpts = append(proxyOpts, withDomainAllowlist(aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI)) + // Configure provider hosts that exclude the target server so requests are tunneled. + proxyOpts = append(proxyOpts, withProviderHosts(aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI)) } srv := newTestProxy(t, proxyOpts...) @@ -1791,14 +1693,10 @@ func TestServeCACert_CompoundPEM(t *testing.T) { logger := slogtest.Make(t, nil) srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: "127.0.0.1:0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: compoundCertFile, - MITMKeyFile: keyFile, - DomainAllowlist: []string{"127.0.0.1", "localhost"}, - AIBridgeProviderFromHost: func(host string) string { - return "test-provider" - }, + ListenAddr: "127.0.0.1:0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: compoundCertFile, + MITMKeyFile: keyFile, }) require.NoError(t, err) t.Cleanup(func() { _ = srv.Close() }) @@ -1849,8 +1747,8 @@ func TestUpstreamProxy(t *testing.T) { name string // tunneled determines whether the request should be tunneled through // the upstream proxy (true) or MITM'd by aiproxy (false). - // When true, the target domain is NOT in the allowlist. - // When false, the target domain IS in the allowlist. + // When true, the target domain has no configured provider. + // When false, the target domain has a configured provider. tunneled bool // upstreamProxyTLS determines whether the upstream proxy uses TLS. // When true, aiproxy must be configured with the upstream proxy's CA. @@ -1865,7 +1763,7 @@ func TestUpstreamProxy(t *testing.T) { upstreamProxyAuth string }{ { - name: "NonAllowlistedDomain_TunneledToHTTPUpstreamProxy", + name: "NonProviderHost_TunneledToHTTPUpstreamProxy", tunneled: true, upstreamProxyTLS: false, buildTargetURL: func(finalDestinationURL *url.URL) string { @@ -1873,7 +1771,7 @@ func TestUpstreamProxy(t *testing.T) { }, }, { - name: "NonAllowlistedDomain_TunneledToHTTPSUpstreamProxy", + name: "NonProviderHost_TunneledToHTTPSUpstreamProxy", tunneled: true, upstreamProxyTLS: true, buildTargetURL: func(finalDestinationURL *url.URL) string { @@ -1881,7 +1779,7 @@ func TestUpstreamProxy(t *testing.T) { }, }, { - name: "NonAllowlistedDomain_TunneledToHTTPUpstreamProxyWithAuth", + name: "NonProviderHost_TunneledToHTTPUpstreamProxyWithAuth", tunneled: true, upstreamProxyTLS: false, upstreamProxyAuth: "proxyuser:proxypass", @@ -1890,7 +1788,7 @@ func TestUpstreamProxy(t *testing.T) { }, }, { - name: "NonAllowlistedDomain_TunneledToHTTPUpstreamProxyWithUsernameOnly", + name: "NonProviderHost_TunneledToHTTPUpstreamProxyWithUsernameOnly", tunneled: true, upstreamProxyTLS: false, upstreamProxyAuth: "proxyuser", @@ -1899,7 +1797,7 @@ func TestUpstreamProxy(t *testing.T) { }, }, { - name: "NonAllowlistedDomain_TunneledToHTTPUpstreamProxyWithUsernameAndColon", + name: "NonProviderHost_TunneledToHTTPUpstreamProxyWithUsernameAndColon", tunneled: true, upstreamProxyTLS: false, upstreamProxyAuth: "proxyuser:", @@ -1908,7 +1806,7 @@ func TestUpstreamProxy(t *testing.T) { }, }, { - name: "NonAllowlistedDomain_TunneledToHTTPUpstreamProxyWithTokenAuth", + name: "NonProviderHost_TunneledToHTTPUpstreamProxyWithTokenAuth", tunneled: true, upstreamProxyTLS: false, upstreamProxyAuth: ":proxypass", @@ -1917,7 +1815,7 @@ func TestUpstreamProxy(t *testing.T) { }, }, { - name: "AllowlistedDomain_MITMByAIProxy", + name: "ProviderHost_MITMByAIProxy", tunneled: false, upstreamProxyTLS: false, buildTargetURL: func(_ *url.URL) string { @@ -2057,10 +1955,10 @@ func TestUpstreamProxy(t *testing.T) { parsedTargetURL, err := url.Parse(targetURL) require.NoError(t, err) - // Configure allowlist based on test case: - // - For tunneled requests, api.anthropic.com is in allowlist, but we target a different host. - // - For MITM, api.anthropic.com must be in the allowlist. - domainAllowlist := []string{aibridgeproxyd.HostAnthropic} + // Configure provider hosts based on test case: + // - For tunneled requests, api.anthropic.com has a configured provider, but we target a different host. + // - For MITM, api.anthropic.com must have a configured provider. + providerHosts := []string{aibridgeproxyd.HostAnthropic} // Build upstream proxy URL with optional auth credentials. upstreamProxyURLStr := upstreamProxy.URL @@ -2073,10 +1971,9 @@ func TestUpstreamProxy(t *testing.T) { // Create aiproxy with upstream proxy configured. proxyOpts := []testProxyOption{ withCoderAccessURL(aibridgeServer.URL), - withDomainAllowlist(domainAllowlist...), + withProviderHosts(providerHosts...), withUpstreamProxy(upstreamProxyURLStr), withAllowedPorts("80", "443", parsedTargetURL.Port()), - withAIBridgeProviderFromHost(testProviderFromHost), } if upstreamProxyCAFile != "" { proxyOpts = append(proxyOpts, withUpstreamProxyCA(upstreamProxyCAFile)) @@ -2114,7 +2011,7 @@ func TestUpstreamProxy(t *testing.T) { // Verify the request flow based on test case. if tt.tunneled { require.True(t, upstreamProxyCONNECTReceived, - "upstream proxy should receive CONNECT for non-allowlisted domain") + "upstream proxy should receive CONNECT for non-provider-host") require.Equal(t, finalDestinationURL.Host, upstreamProxyCONNECTHost, "upstream proxy should receive CONNECT to correct host") require.True(t, finalDestinationReceived, @@ -2124,12 +2021,12 @@ func TestUpstreamProxy(t *testing.T) { require.Equal(t, requestBody, finalDestinationBody, "final destination should receive the exact request body") require.False(t, aibridgeReceived, - "aibridge should NOT receive request for non-allowlisted domain") + "aibridge should NOT receive request for non-provider-host") require.Empty(t, aibridgeAuthz, "tunneled requests should not reach aibridge") } else { require.False(t, upstreamProxyCONNECTReceived, - "upstream proxy should NOT receive CONNECT for allowlisted domain") + "upstream proxy should NOT receive CONNECT for provider host") require.True(t, aibridgeReceived, "aibridge should receive the MITM'd request") require.Equal(t, tt.expectedAIBridgePath, aibridgePath, @@ -2141,7 +2038,7 @@ func TestUpstreamProxy(t *testing.T) { require.Equal(t, requestBody, aibridgeBody, "aibridge should receive the exact request body") require.False(t, finalDestinationReceived, - "final destination should NOT receive request for allowlisted domain") + "final destination should NOT receive request for provider host") } // Verify upstream proxy authentication if configured. @@ -2155,7 +2052,7 @@ func TestUpstreamProxy(t *testing.T) { } // TestProxy_MITM_CustomProvider verifies that a non-builtin provider -// (e.g. OpenRouter) whose domain is added to the allowlist is correctly +// (e.g. OpenRouter) whose domain is registered as a provider host is correctly // MITM'd and routed through the proxy to the bridge endpoint. func TestProxy_MITM_CustomProvider(t *testing.T) { t.Parallel() @@ -2177,16 +2074,18 @@ func TestProxy_MITM_CustomProvider(t *testing.T) { })) t.Cleanup(aibridgedServer.Close) - // Wire the custom domain and provider mapping directly, as the - // real daemon would after calling domainsFromProviders. + // Wire the custom domain and provider mapping directly via + // withProviders, equivalent to the snapshot the daemon's Reload + // builds from classified providers in production. srv := newTestProxy(t, withCoderAccessURL(aibridgedServer.URL), - withDomainAllowlist(openrouterDomain), - withAIBridgeProviderFromHost(func(host string) string { - if host == openrouterDomain { - return openrouterProvider - } - return "" + withProviders(aibridgeproxyd.ReloadedProvider{ + ProviderOutcome: aibridged.ProviderOutcome{ + Name: openrouterProvider, + Type: "openai", + Status: aibridged.ProviderStatusEnabled, + }, + Host: openrouterDomain, }), ) @@ -2307,10 +2206,10 @@ func TestProxy_PrivateIPBlocking(t *testing.T) { // Build the CONNECT target using the configured hostname. connectTarget := fmt.Sprintf("%s:%s", tt.targetHostname, targetURL.Port()) - // Use a domain allowlist that excludes the target so CONNECT requests + // Configure provider hosts that exclude the target so CONNECT requests // go through the tunnel path rather than being MITM'd. opts := []testProxyOption{ - withDomainAllowlist(aibridgeproxyd.HostAnthropic), + withProviderHosts(aibridgeproxyd.HostAnthropic), withAllowedPorts(targetURL.Port()), } @@ -2395,8 +2294,7 @@ func TestProxy_APIDump(t *testing.T) { srv := newTestProxy(t, withCoderAccessURL(aibridgedServer.URL), withAllowedPorts("443"), - withDomainAllowlist(aibridgeproxyd.HostAnthropic), - withAIBridgeProviderFromHost(testProviderFromHost), + withProviderHosts(aibridgeproxyd.HostAnthropic), withNewDumper(func(provider, requestID string) aibridgeproxyd.RoundTripDumper { dumpedProvider = provider dumpedRequestID = requestID @@ -2443,8 +2341,7 @@ func TestProxy_APIDump_ErrorsDoNotAffectProxy(t *testing.T) { srv := newTestProxy(t, withCoderAccessURL(aibridgedServer.URL), withAllowedPorts("443"), - withDomainAllowlist(aibridgeproxyd.HostAnthropic), - withAIBridgeProviderFromHost(testProviderFromHost), + withProviderHosts(aibridgeproxyd.HostAnthropic), withNewDumper(func(_, _ string) aibridgeproxyd.RoundTripDumper { return &failingDumper{} }), diff --git a/enterprise/aibridgeproxyd/metrics.go b/enterprise/aibridgeproxyd/metrics.go index 55a1fa4177..ccfd334aa7 100644 --- a/enterprise/aibridgeproxyd/metrics.go +++ b/enterprise/aibridgeproxyd/metrics.go @@ -30,6 +30,21 @@ type Metrics struct { // Labels: code (HTTP status code), provider // Cardinality is bounded: ~100 used status codes x few providers. MITMResponsesTotal *prometheus.CounterVec + + // ProviderInfo is one series per configured provider; value is + // always 1 and the status label carries the alertable signal. + // Labels: provider_name, provider_type, status. + ProviderInfo *prometheus.GaugeVec + + // ProvidersLastReloadTimestampSeconds is the unix timestamp of the + // last reload attempt, success or failure. + ProvidersLastReloadTimestampSeconds prometheus.Gauge + + // ProvidersLastReloadSuccessTimestampSeconds is the unix timestamp + // of the last reload that successfully refreshed the router. A gap + // against ProvidersLastReloadTimestampSeconds means the loop is + // firing but the refresh function is failing. + ProvidersLastReloadSuccessTimestampSeconds prometheus.Gauge } // NewMetrics creates and registers all metrics for aibridgeproxyd. @@ -58,6 +73,21 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { Name: "mitm_responses_total", Help: "Total number of MITM responses by HTTP status code class.", }, []string{"code", "provider"}), + + ProviderInfo: factory.NewGaugeVec(prometheus.GaugeOpts{ + Name: "provider_info", + Help: "One series per configured AI provider. Value is always 1; the status label (enabled, disabled, error) carries the alertable signal.", + }, []string{"provider_name", "provider_type", "status"}), + + ProvidersLastReloadTimestampSeconds: factory.NewGauge(prometheus.GaugeOpts{ + Name: "providers_last_reload_timestamp_seconds", + Help: "Unix timestamp of the last provider reload attempt, success or failure.", + }), + + ProvidersLastReloadSuccessTimestampSeconds: factory.NewGauge(prometheus.GaugeOpts{ + Name: "providers_last_reload_success_timestamp_seconds", + Help: "Unix timestamp of the last provider reload that successfully refreshed the router. A gap against coder_aibridgeproxyd_providers_last_reload_timestamp_seconds means the loop is firing but the refresh function is failing.", + }), } } @@ -67,4 +97,7 @@ func (m *Metrics) Unregister() { m.registerer.Unregister(m.MITMRequestsTotal) m.registerer.Unregister(m.InflightMITMRequests) m.registerer.Unregister(m.MITMResponsesTotal) + m.registerer.Unregister(m.ProviderInfo) + m.registerer.Unregister(m.ProvidersLastReloadTimestampSeconds) + m.registerer.Unregister(m.ProvidersLastReloadSuccessTimestampSeconds) } diff --git a/enterprise/aibridgeproxyd/metrics_internal_test.go b/enterprise/aibridgeproxyd/metrics_internal_test.go new file mode 100644 index 0000000000..6ebefbd56b --- /dev/null +++ b/enterprise/aibridgeproxyd/metrics_internal_test.go @@ -0,0 +1,135 @@ +package aibridgeproxyd + +import ( + "context" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + promtest "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/aibridged" + "github.com/coder/coder/v2/testutil" +) + +// TestReloadUpdatesProviderMetrics covers the provider_info GaugeVec +// surface: every reload pass rewrites the series for the current +// snapshot, including disabled and errored rows; the Reset on each +// reload drops series for providers that have left the configuration. +func TestReloadUpdatesProviderMetrics(t *testing.T) { + t.Parallel() + + reg := prometheus.NewRegistry() + metrics := NewMetrics(reg) + + reload := ProviderReload{Providers: []ReloadedProvider{ + {ProviderOutcome: aibridged.ProviderOutcome{Name: "alpha", Type: "openai", Status: aibridged.ProviderStatusEnabled}, Host: "alpha.example.com"}, + {ProviderOutcome: aibridged.ProviderOutcome{Name: "beta", Type: "anthropic", Status: aibridged.ProviderStatusDisabled}}, + {ProviderOutcome: aibridged.ProviderOutcome{Name: "gamma", Type: "openai", Status: aibridged.ProviderStatusError, Err: xerrors.New("bad config")}}, + }} + + ctx := testutil.Context(t, testutil.WaitShort) + srv := &Server{ + ctx: ctx, + logger: slogtest.Make(t, nil), + allowedPorts: []string{"443"}, + metrics: metrics, + refreshProviders: func(context.Context) (ProviderReload, error) { + return reload, nil + }, + } + srv.providerRouter.Store(emptyProviderRouter) + + before := time.Now().Unix() + require.NoError(t, srv.Reload(ctx)) + after := time.Now().Unix() + + assert.Equal(t, 1.0, promtest.ToFloat64(metrics.ProviderInfo.WithLabelValues("alpha", "openai", "enabled"))) + assert.Equal(t, 1.0, promtest.ToFloat64(metrics.ProviderInfo.WithLabelValues("beta", "anthropic", "disabled"))) + assert.Equal(t, 1.0, promtest.ToFloat64(metrics.ProviderInfo.WithLabelValues("gamma", "openai", "error"))) + + attemptTS := int64(promtest.ToFloat64(metrics.ProvidersLastReloadTimestampSeconds)) + successTS := int64(promtest.ToFloat64(metrics.ProvidersLastReloadSuccessTimestampSeconds)) + assert.GreaterOrEqual(t, attemptTS, before) + assert.LessOrEqual(t, attemptTS, after) + assert.GreaterOrEqual(t, successTS, before) + assert.LessOrEqual(t, successTS, after) +} + +// TestReloadResetsStaleProviderSeries verifies that providers removed +// between reloads do not leave behind stale series. Without Reset, a +// removed provider's last-seen value would persist for 5+ minutes and +// could fire alerts despite the provider no longer being configured. +func TestReloadResetsStaleProviderSeries(t *testing.T) { + t.Parallel() + + reg := prometheus.NewRegistry() + metrics := NewMetrics(reg) + + current := ProviderReload{Providers: []ReloadedProvider{ + {ProviderOutcome: aibridged.ProviderOutcome{Name: "alpha", Type: "openai", Status: aibridged.ProviderStatusEnabled}, Host: "alpha.example.com"}, + {ProviderOutcome: aibridged.ProviderOutcome{Name: "beta", Type: "anthropic", Status: aibridged.ProviderStatusEnabled}, Host: "beta.example.com"}, + }} + + ctx := testutil.Context(t, testutil.WaitShort) + srv := &Server{ + ctx: ctx, + logger: slogtest.Make(t, nil), + allowedPorts: []string{"443"}, + metrics: metrics, + refreshProviders: func(context.Context) (ProviderReload, error) { + return current, nil + }, + } + srv.providerRouter.Store(emptyProviderRouter) + + require.NoError(t, srv.Reload(ctx)) + require.Equal(t, 2, promtest.CollectAndCount(metrics.ProviderInfo)) + + current = ProviderReload{Providers: []ReloadedProvider{ + {ProviderOutcome: aibridged.ProviderOutcome{Name: "alpha", Type: "openai", Status: aibridged.ProviderStatusEnabled}, Host: "alpha.example.com"}, + }} + require.NoError(t, srv.Reload(ctx)) + + assert.Equal(t, 1, promtest.CollectAndCount(metrics.ProviderInfo), + "beta should have been Reset out of the GaugeVec") + assert.Equal(t, 1.0, promtest.ToFloat64(metrics.ProviderInfo.WithLabelValues("alpha", "openai", "enabled"))) +} + +// TestReloadAttemptTimestampUpdatesOnFailure asserts the attempt-time +// gauge advances even when the refresh function fails, while the +// success-time gauge does not. +func TestReloadAttemptTimestampUpdatesOnFailure(t *testing.T) { + t.Parallel() + + reg := prometheus.NewRegistry() + metrics := NewMetrics(reg) + refreshErr := xerrors.New("simulated failure") + + ctx := testutil.Context(t, testutil.WaitShort) + srv := &Server{ + ctx: ctx, + logger: slogtest.Make(t, nil), + allowedPorts: []string{"443"}, + metrics: metrics, + refreshProviders: func(context.Context) (ProviderReload, error) { + return ProviderReload{}, refreshErr + }, + } + srv.providerRouter.Store(emptyProviderRouter) + + before := time.Now().Unix() + err := srv.Reload(ctx) + require.ErrorIs(t, err, refreshErr) + after := time.Now().Unix() + + attemptTS := int64(promtest.ToFloat64(metrics.ProvidersLastReloadTimestampSeconds)) + successTS := int64(promtest.ToFloat64(metrics.ProvidersLastReloadSuccessTimestampSeconds)) + assert.GreaterOrEqual(t, attemptTS, before) + assert.LessOrEqual(t, attemptTS, after) + assert.Equal(t, int64(0), successTS, "success timestamp must not advance on failure") +} diff --git a/enterprise/aibridgeproxyd/reload.go b/enterprise/aibridgeproxyd/reload.go index d235a9be3b..04b1f5438b 100644 --- a/enterprise/aibridgeproxyd/reload.go +++ b/enterprise/aibridgeproxyd/reload.go @@ -3,37 +3,94 @@ package aibridgeproxyd import ( "context" "net/http" - "net/url" "slices" "strings" + "time" "github.com/elazarl/goproxy" "golang.org/x/xerrors" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/aibridged" ) +// ReloadedProvider is the classification of one ai_providers row. +// Host is the routable hostname; it's populated only when the embedded +// outcome's Status == aibridged.ProviderStatusEnabled. +type ReloadedProvider struct { + aibridged.ProviderOutcome + Host string +} + +// ProviderReload is the result of a single refresh pass: every +// configured provider with its classification. +type ProviderReload struct { + Providers []ReloadedProvider +} + +// RefreshProvidersFunc returns the live provider classification used +// by Reload to rebuild the proxy's routing snapshot. +type RefreshProvidersFunc func(ctx context.Context) (ProviderReload, error) + // Reload refreshes proxy routing from the configured provider source. // A refresh failure leaves the previous snapshot in place. func (s *Server) Reload(ctx context.Context) error { if s.refreshProviders == nil { return nil } - providers, err := s.refreshProviders(ctx) + s.recordReloadAttempt() + reload, err := s.refreshProviders(ctx) if err != nil { return xerrors.Errorf("refresh ai providers for proxy routing: %w", err) } - router, err := buildProviderRouter(ctx, s.logger, providers, s.allowedPorts) + router, err := buildProviderRouter(reload, s.allowedPorts) if err != nil { - return xerrors.Errorf("build provider router (provider_count=%d): %w", len(providers), err) + return xerrors.Errorf("build provider router (provider_count=%d): %w", len(reload.Providers), err) } s.providerRouter.Store(router) + for _, p := range reload.Providers { + if p.Status == aibridged.ProviderStatusError { + s.logger.Warn(s.ctx, "provider excluded from routing", + slog.F("provider", p.Name), + slog.Error(p.Err), + ) + } + } + s.recordReloadSuccess(reload) s.logger.Debug(s.ctx, "aibridgeproxyd router reloaded", + slog.F("provider_count", len(reload.Providers)), slog.F("mitm_host_count", len(router.mitmHosts)), + slog.F("mitm_hosts", router.mitmHosts), ) return nil } +// recordReloadAttempt stamps the attempt-time gauge at the start of a +// Reload. A reload that hangs mid-flight is detected by watching the +// gap between this gauge and ProvidersLastReloadSuccessTimestampSeconds. +func (s *Server) recordReloadAttempt() { + if s.metrics == nil { + return + } + s.metrics.ProvidersLastReloadTimestampSeconds.Set(float64(time.Now().Unix())) +} + +// recordReloadSuccess rewrites the provider_info GaugeVec from the +// classified reload and stamps the success-time gauge. Reset clears +// series for providers that have left the configuration so they don't +// linger as stale. +func (s *Server) recordReloadSuccess(reload ProviderReload) { + if s.metrics == nil { + return + } + outcomes := make([]aibridged.ProviderOutcome, len(reload.Providers)) + for i, p := range reload.Providers { + outcomes[i] = p.ProviderOutcome + } + aibridged.WriteProviderInfoSnapshot(s.metrics.ProviderInfo, outcomes) + s.metrics.ProvidersLastReloadSuccessTimestampSeconds.Set(float64(time.Now().Unix())) +} + func (s *Server) loadProviderRouter() *providerRouter { if p := s.providerRouter.Load(); p != nil { return p @@ -42,7 +99,7 @@ func (s *Server) loadProviderRouter() *providerRouter { } // mitmHostsCondition returns a goproxy ReqConditionFunc that reads the -// allowlist from the atomic router on every match. Using a closure +// MITM host set from the atomic router on every match. Using a closure // instead of goproxy.ReqHostIs(...) lets Reload affect every later // CONNECT without re-registering handlers. func (s *Server) mitmHostsCondition() goproxy.ReqConditionFunc { @@ -54,35 +111,24 @@ func (s *Server) mitmHostsCondition() goproxy.ReqConditionFunc { } } -// buildProviderRouter constructs a router snapshot from a refreshed -// provider list. First provider wins on duplicate hostnames. -func buildProviderRouter(ctx context.Context, logger slog.Logger, providers []ProviderRoute, allowedPorts []string) (*providerRouter, error) { - nameByHost := make(map[string]string, len(providers)) - var domains []string - for _, p := range providers { - if p.BaseURL == "" { - logger.Warn(ctx, "skipping ai provider without base url", - slog.F("provider_name", p.Name), - ) +// buildProviderRouter constructs a router snapshot from a classified +// provider reload. Only providers with Status == +// aibridged.ProviderStatusEnabled are included in the active routing +// tables; the refresh function is responsible for classifying disabled +// and errored rows. First entry wins on duplicate hostnames as a +// defense-in-depth measure even though the refresh function should +// mark duplicates as errors. +func buildProviderRouter(reload ProviderReload, allowedPorts []string) (*providerRouter, error) { + nameByHost := make(map[string]string, len(reload.Providers)) + domains := make([]string, 0, len(reload.Providers)) + for _, p := range reload.Providers { + if p.Status != aibridged.ProviderStatusEnabled { continue } - u, err := url.Parse(p.BaseURL) - if err != nil { - logger.Warn(ctx, "skipping ai provider with invalid base url", - slog.F("provider_name", p.Name), - slog.F("base_url", p.BaseURL), - slog.Error(err), - ) + host := strings.ToLower(p.Host) + if host == "" { continue } - if u.Hostname() == "" { - logger.Warn(ctx, "skipping ai provider base url without hostname", - slog.F("provider_name", p.Name), - slog.F("base_url", p.BaseURL), - ) - continue - } - host := strings.ToLower(u.Hostname()) if _, exists := nameByHost[host]; exists { continue } @@ -95,30 +141,3 @@ func buildProviderRouter(ctx context.Context, logger slog.Logger, providers []Pr } return &providerRouter{mitmHosts: mitmHosts, nameByHost: nameByHost}, nil } - -// buildBootRouter seeds the providerRouter from the boot-time inputs. -// The lookup function is consulted only for hosts in the allowlist; a -// nil function with an empty allowlist is fine and yields an empty -// router (the proxy fails closed until Reload populates it). -func buildBootRouter(domainAllowlist []string, providerFromHost func(string) string, allowedPorts []string) (*providerRouter, error) { - mitmHosts, err := convertDomainsToHosts(domainAllowlist, allowedPorts) - if err != nil { - return nil, xerrors.Errorf("invalid domain allowlist: %w", err) - } - nameByHost := make(map[string]string, len(domainAllowlist)) - for _, domain := range domainAllowlist { - domain = strings.TrimSpace(strings.ToLower(domain)) - if domain == "" { - continue - } - var name string - if providerFromHost != nil { - name = providerFromHost(domain) - } - if name == "" { - return nil, xerrors.Errorf("domain %q is in allowlist but has no provider mapping", domain) - } - nameByHost[domain] = name - } - return &providerRouter{mitmHosts: mitmHosts, nameByHost: nameByHost}, nil -} diff --git a/enterprise/aibridgeproxyd/reload_internal_test.go b/enterprise/aibridgeproxyd/reload_internal_test.go index 7a8b0f9fae..5ccba37ec7 100644 --- a/enterprise/aibridgeproxyd/reload_internal_test.go +++ b/enterprise/aibridgeproxyd/reload_internal_test.go @@ -9,20 +9,32 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/aibridged" "github.com/coder/coder/v2/testutil" ) +func enabledProvider(name, host string) ReloadedProvider { + return ReloadedProvider{ + ProviderOutcome: aibridged.ProviderOutcome{ + Name: name, + Type: "openai", + Status: aibridged.ProviderStatusEnabled, + }, + Host: host, + } +} + func TestServerReloadSwapsProviderRouter(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) - providers := []ProviderRoute{{Name: "old", BaseURL: "https://old.example.com/"}} + reload := ProviderReload{Providers: []ReloadedProvider{enabledProvider("old", "old.example.com")}} srv := &Server{ ctx: ctx, logger: slogtest.Make(t, nil), allowedPorts: []string{"443"}, - refreshProviders: func(context.Context) ([]ProviderRoute, error) { - return providers, nil + refreshProviders: func(context.Context) (ProviderReload, error) { + return reload, nil }, } srv.providerRouter.Store(emptyProviderRouter) @@ -31,7 +43,7 @@ func TestServerReloadSwapsProviderRouter(t *testing.T) { assert.Equal(t, "old", srv.loadProviderRouter().providerFromHost("old.example.com")) assert.Empty(t, srv.loadProviderRouter().providerFromHost("new.example.com")) - providers = []ProviderRoute{{Name: "new", BaseURL: "https://new.example.com/"}} + reload = ProviderReload{Providers: []ReloadedProvider{enabledProvider("new", "new.example.com")}} require.NoError(t, srv.Reload(ctx)) router := srv.loadProviderRouter() @@ -45,17 +57,17 @@ func TestServerReloadPreservesProviderRouterOnRefreshError(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) refreshErr := xerrors.New("refresh failed") - providers := []ProviderRoute{{Name: "old", BaseURL: "https://old.example.com/"}} + reload := ProviderReload{Providers: []ReloadedProvider{enabledProvider("old", "old.example.com")}} failRefresh := false srv := &Server{ ctx: ctx, logger: slogtest.Make(t, nil), allowedPorts: []string{"443"}, - refreshProviders: func(context.Context) ([]ProviderRoute, error) { + refreshProviders: func(context.Context) (ProviderReload, error) { if failRefresh { - return nil, refreshErr + return ProviderReload{}, refreshErr } - return providers, nil + return reload, nil }, } srv.providerRouter.Store(emptyProviderRouter) @@ -73,75 +85,84 @@ func TestServerReloadPreservesProviderRouterOnRefreshError(t *testing.T) { assert.Equal(t, []string{"old.example.com:443"}, after.mitmHosts) } -// TestBuildProviderRouter covers the host-and-routing derivation that -// Reload feeds into the providerRouter. +// TestBuildProviderRouter covers the host-and-routing derivation from +// the classified provider reload. func TestBuildProviderRouter(t *testing.T) { t.Parallel() - t.Run("ExtractsHostnames", func(t *testing.T) { + t.Run("IncludesEnabledOnly", func(t *testing.T) { t.Parallel() - providers := []ProviderRoute{ - {Name: "openai", BaseURL: "https://api.openai.com/v1/"}, - {Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, - {Name: "custom", BaseURL: "https://custom-llm.example.com:8443/api"}, - } + reload := ProviderReload{Providers: []ReloadedProvider{ + enabledProvider("openai", "api.openai.com"), + enabledProvider("anthropic", "api.anthropic.com"), + enabledProvider("custom", "custom-llm.example.com"), + // Host is populated on the non-enabled rows so the Status + // guard, not the empty-host guard, is what excludes them. + {ProviderOutcome: aibridged.ProviderOutcome{Name: "off", Type: "openai", Status: aibridged.ProviderStatusDisabled}, Host: "disabled.example.com"}, + {ProviderOutcome: aibridged.ProviderOutcome{Name: "bad", Type: "openai", Status: aibridged.ProviderStatusError, Err: xerrors.New("nope")}, Host: "errored.example.com"}, + }} - router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"}) + router, err := buildProviderRouter(reload, []string{"443"}) require.NoError(t, err) assert.Equal(t, "openai", router.providerFromHost("api.openai.com")) assert.Equal(t, "anthropic", router.providerFromHost("api.anthropic.com")) assert.Equal(t, "custom", router.providerFromHost("custom-llm.example.com")) assert.Empty(t, router.providerFromHost("unknown.com")) + assert.Empty(t, router.providerFromHost("disabled.example.com"), + "disabled provider must not be routable even with a populated Host") + assert.Empty(t, router.providerFromHost("errored.example.com"), + "errored provider must not be routable even with a populated Host") assert.Contains(t, router.mitmHosts, "api.openai.com:443") assert.Contains(t, router.mitmHosts, "api.anthropic.com:443") - }) - - t.Run("DeduplicatesSameHost", func(t *testing.T) { - t.Parallel() - - providers := []ProviderRoute{ - {Name: "first", BaseURL: "https://api.example.com/v1"}, - {Name: "second", BaseURL: "https://api.example.com/v2"}, - } - - router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"}) - require.NoError(t, err) - - // First provider wins on duplicate host. - assert.Equal(t, "first", router.providerFromHost("api.example.com")) + assert.Len(t, router.mitmHosts, 3) }) t.Run("CaseInsensitive", func(t *testing.T) { t.Parallel() - providers := []ProviderRoute{ - {Name: "provider", BaseURL: "https://API.Example.COM/v1"}, - } + reload := ProviderReload{Providers: []ReloadedProvider{ + {ProviderOutcome: aibridged.ProviderOutcome{Name: "provider", Type: "openai", Status: aibridged.ProviderStatusEnabled}, Host: "API.Example.COM"}, + }} - router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"}) + router, err := buildProviderRouter(reload, []string{"443"}) require.NoError(t, err) assert.Equal(t, "provider", router.providerFromHost("API.Example.COM")) assert.Equal(t, "provider", router.providerFromHost("api.example.com")) }) - t.Run("SkipsEmptyOrMalformedBaseURL", func(t *testing.T) { + t.Run("DefensiveDeduplicatesSameHost", func(t *testing.T) { t.Parallel() - providers := []ProviderRoute{ - {Name: "no-url"}, - {Name: "scheme-only", BaseURL: "https://"}, - {Name: "good", BaseURL: "https://api.good.example.com/"}, - } + // Refresh function should mark the duplicate as ProviderStatusError; + // buildProviderRouter is defensive and tolerates an enabled duplicate + // by giving the first entry the host (first wins). + reload := ProviderReload{Providers: []ReloadedProvider{ + enabledProvider("first", "api.example.com"), + enabledProvider("second", "api.example.com"), + }} - router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"}) + router, err := buildProviderRouter(reload, []string{"443"}) + require.NoError(t, err) + + assert.Equal(t, "first", router.providerFromHost("api.example.com")) + }) + + t.Run("SkipsRowsWithEmptyHost", func(t *testing.T) { + t.Parallel() + + reload := ProviderReload{Providers: []ReloadedProvider{ + {ProviderOutcome: aibridged.ProviderOutcome{Name: "no-host", Type: "openai", Status: aibridged.ProviderStatusEnabled}}, + enabledProvider("good", "api.good.example.com"), + }} + + router, err := buildProviderRouter(reload, []string{"443"}) require.NoError(t, err) assert.Equal(t, "good", router.providerFromHost("api.good.example.com")) - assert.Empty(t, router.providerFromHost("scheme-only")) assert.Equal(t, []string{"api.good.example.com:443"}, router.mitmHosts) }) } diff --git a/enterprise/aibridgeproxyd/reload_test.go b/enterprise/aibridgeproxyd/reload_test.go index d51aa5bc98..bfc90338d4 100644 --- a/enterprise/aibridgeproxyd/reload_test.go +++ b/enterprise/aibridgeproxyd/reload_test.go @@ -5,15 +5,19 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "slices" "strings" "sync" "testing" + "github.com/prometheus/client_golang/prometheus" + promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/xerrors" + "github.com/coder/coder/v2/coderd/aibridged" "github.com/coder/coder/v2/enterprise/aibridgeproxyd" "github.com/coder/coder/v2/testutil" ) @@ -27,6 +31,7 @@ type reloadTestHarness struct { client *http.Client bridged *httptest.Server recorder *aibridgedRecorder + metrics *aibridgeproxyd.Metrics } // aibridgedRecorder captures the path of the last request received by @@ -55,16 +60,24 @@ func (r *aibridgedRecorder) reset() { r.path = "" } -// providerStore is a mutable [aibridgeproxyd.RefreshProvidersFunc] -// backing for integration tests. set / setErr mutate the snapshot -// returned by the next Reload, mimicking CRUD against the database. +// rawProvider is a (name, base URL) pair representing what the database +// holds before classification, mirroring the ai_providers row shape +// that the production refresh function classifies. +type rawProvider struct { + name string + baseURL string +} + +// providerStore is a mutable RefreshProvidersFunc backing for +// integration tests. set / setErr mutate the snapshot returned by the +// next Reload, mimicking CRUD against the database. type providerStore struct { mu sync.Mutex - providers []aibridgeproxyd.ProviderRoute + providers []rawProvider err error } -func (s *providerStore) set(providers []aibridgeproxyd.ProviderRoute) { +func (s *providerStore) set(providers []rawProvider) { s.mu.Lock() defer s.mu.Unlock() s.providers = providers @@ -77,20 +90,61 @@ func (s *providerStore) setErr(err error) { s.err = err } -func (s *providerStore) refresh(context.Context) ([]aibridgeproxyd.ProviderRoute, error) { +func (s *providerStore) refresh(context.Context) (aibridgeproxyd.ProviderReload, error) { s.mu.Lock() defer s.mu.Unlock() if s.err != nil { - return nil, s.err + return aibridgeproxyd.ProviderReload{}, s.err } - // Return a copy so callers can't mutate our internal snapshot. - return slices.Clone(s.providers), nil + providers := slices.Clone(s.providers) + reload := aibridgeproxyd.ProviderReload{ + Providers: make([]aibridgeproxyd.ReloadedProvider, 0, len(providers)), + } + seenHost := make(map[string]string, len(providers)) + for _, p := range providers { + reload.Providers = append(reload.Providers, classifyRaw(p, seenHost)) + } + return reload, nil } -// newReloadTestHarness boots a proxy with an empty boot allowlist and a -// store-backed RefreshProviders. Production wiring is identical: the -// daemon constructs the proxy without a static allowlist and lets -// Reload populate the router from the database. +// classifyRaw mirrors the production classifier in enterprise/cli so +// the reload tests exercise the same validation rules end-to-end. +func classifyRaw(p rawProvider, seenHost map[string]string) aibridgeproxyd.ReloadedProvider { + out := aibridgeproxyd.ReloadedProvider{ + ProviderOutcome: aibridged.ProviderOutcome{Name: p.name, Type: "openai"}, + } + if strings.TrimSpace(p.baseURL) == "" { + out.Status = aibridged.ProviderStatusError + out.Err = xerrors.New("base url is empty") + return out + } + u, err := url.Parse(p.baseURL) + if err != nil { + out.Status = aibridged.ProviderStatusError + out.Err = xerrors.Errorf("invalid base url %q: %w", p.baseURL, err) + return out + } + host := strings.ToLower(u.Hostname()) + if host == "" { + out.Status = aibridged.ProviderStatusError + out.Err = xerrors.Errorf("base url %q has no hostname", p.baseURL) + return out + } + if claimedBy, taken := seenHost[host]; taken { + out.Status = aibridged.ProviderStatusError + out.Err = xerrors.Errorf("hostname %q already claimed by provider %q", host, claimedBy) + return out + } + seenHost[host] = p.name + out.Host = host + out.Status = aibridged.ProviderStatusEnabled + return out +} + +// newReloadTestHarness boots a proxy with an empty initial router and +// a store-backed RefreshProviders. Production wiring is identical: the +// daemon constructs the proxy without preconfigured provider hosts and +// lets Reload populate the router from the database. func newReloadTestHarness(t *testing.T) *reloadTestHarness { t.Helper() @@ -103,14 +157,12 @@ func newReloadTestHarness(t *testing.T) *reloadTestHarness { t.Cleanup(bridged.Close) store := &providerStore{} + metrics := aibridgeproxyd.NewMetrics(prometheus.NewRegistry()) srv := newTestProxy(t, withCoderAccessURL(bridged.URL), withAllowedPorts("443"), - // Empty boot allowlist: the router must be populated by Reload, - // matching the production daemon's behavior. - withDomainAllowlist(), - withAIBridgeProviderFromHost(nil), withRefreshProviders(store.refresh), + withMetrics(metrics), ) certPool := getProxyCertPool(t) @@ -125,6 +177,7 @@ func newReloadTestHarness(t *testing.T) *reloadTestHarness { return &reloadTestHarness{ srv: srv, store: store, + metrics: metrics, client: client, bridged: bridged, recorder: recorder, @@ -192,16 +245,147 @@ func (h *reloadTestHarness) expectNotRouted(t *testing.T, targetURL string) { "aibridged must not be reached for non-routed host %s", targetURL) } +// expectProviderStatus asserts the provider_info series for (name, +// status) is present with value 1. +func (h *reloadTestHarness) expectProviderStatus(t *testing.T, name, status string) { + t.Helper() + assert.Equal(t, 1.0, promtest.ToFloat64(h.metrics.ProviderInfo.WithLabelValues(name, "openai", status)), + "expected provider_info{provider_name=%q, status=%q} == 1", name, status) +} + +// expectProviderAbsent asserts no series exists for the provider name +// in any status. This verifies the GaugeVec.Reset on each reload +// clears stale entries. +func (h *reloadTestHarness) expectProviderAbsent(t *testing.T, name string) { + t.Helper() + for _, status := range []string{"enabled", "disabled", "error"} { + assert.Equal(t, 0.0, promtest.ToFloat64(h.metrics.ProviderInfo.WithLabelValues(name, "openai", status)), + "expected no provider_info series for %q, found status %q", name, status) + } +} + +// TestProxy_StaleTunnelStopsRoutingAfterProviderChange is the +// regression test for a bug where a long-lived CONNECT tunnel that was +// established while a provider was enabled kept routing decrypted +// requests to aibridged after the provider was disabled or renamed. The +// fix re-validates the CONNECT-time provider against the live router on +// every decrypted request and covers both shapes of stale mapping: +// +// - ProviderDisabled: liveProvider == "" (host no longer MITM'd). +// - ProviderRenamed: liveProvider != reqCtx.Provider (host MITM'd, but +// under a new provider name). +func TestProxy_StaleTunnelStopsRoutingAfterProviderChange(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + // applyChange mutates the store to simulate the provider change + // after the initial routed request succeeds. + applyChange func(*providerStore) + // changeDescription is appended to the second-request assertion + // message so a failure points at the exercised branch. + changeDescription string + }{ + { + name: "ProviderDisabled", + applyChange: func(s *providerStore) { s.set(nil) }, + changeDescription: "after alpha was disabled", + }, + { + name: "ProviderRenamed", + applyChange: func(s *providerStore) { + // Same host, new provider name: the live router still + // MITMs alpha.invalid, but as "alpha-v2". The stale + // CONNECT-time name "alpha" no longer matches. + s.set([]rawProvider{ + {name: "alpha-v2", baseURL: "https://alpha.invalid/v1"}, + }) + }, + changeDescription: "after alpha was renamed to alpha-v2", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + recorder := &aibridgedRecorder{} + bridged := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recorder.record(r.URL.Path) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("aibridged")) + })) + t.Cleanup(bridged.Close) + + store := &providerStore{} + store.set([]rawProvider{ + {name: "alpha", baseURL: "https://alpha.invalid/v1"}, + }) + + // newTestProxy seeds the router from the store via the + // initial Reload, so the first CONNECT is MITM'd as alpha. + srv := newTestProxy(t, + withCoderAccessURL(bridged.URL), + withAllowedPorts("443"), + withRefreshProviders(store.refresh), + ) + + certPool := getProxyCertPool(t) + client := newProxyClient(t, srv, makeProxyAuthHeader("coder-token"), certPool, false) + // Keep-alives are required: the regression exists only when a + // subsequent request reuses the original CONNECT tunnel. A fresh + // CONNECT would correctly observe the post-reload router. + transport := client.Transport.(*http.Transport) + transport.DisableKeepAlives = false + transport.MaxConnsPerHost = 1 + transport.MaxIdleConnsPerHost = 1 + + sendThroughTunnel := func(path string) (status int, err error) { + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitShort) + defer cancel() + req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, "https://alpha.invalid"+path, strings.NewReader(`{}`)) + require.NoError(t, reqErr) + req.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req) + if err != nil { + return 0, err + } + defer resp.Body.Close() + _, _ = io.Copy(io.Discard, resp.Body) + return resp.StatusCode, nil + } + + // First request: alpha is enabled, the proxy MITMs and routes to + // aibridged under the alpha namespace. + recorder.reset() + status, err := sendThroughTunnel("/v1/messages") + require.NoError(t, err) + require.Equal(t, http.StatusOK, status) + require.Equal(t, "/api/v2/aibridge/alpha/v1/messages", recorder.load(), + "first request must be routed to aibridged while alpha is enabled") + + // Apply the provider change and reload. The atomic router swap + // takes effect immediately, but the client's connection (and + // the proxy's hijacked tunnel) remain open. + tc.applyChange(store) + require.NoError(t, srv.Reload(t.Context())) + + // Second request on the same tunnel: aibridged must NOT see it. + // The connection is hijacked so the request reaches the proxy's + // handleRequest with the stale CONNECT-time provider; the fix + // re-validates against the live router and passes through to + // the original upstream (alpha.invalid, which fails DNS). + recorder.reset() + _, _ = sendThroughTunnel("/v1/should-not-route") + require.Empty(t, recorder.load(), + "%s, aibridged must not receive the request even on a reused tunnel", tc.changeDescription) + }) + } +} + // TestProxy_HotReloadRoutingCRUD drives the proxy through a CRUD-style // sequence of provider changes and asserts on routing after each -// Reload via real HTTPS requests. Each sub-test mutates the store and -// validates that: -// - newly created providers are MITM'd to aibridged with the right -// /api/v2/aibridge// -// - renamed providers route under the new name -// - providers whose BaseURL host changes route the new host and stop -// MITM'ing the old host -// - deleted providers stop being MITM'd; aibridged sees nothing +// Reload via real HTTPS requests. // // Hostnames are .invalid (RFC 2606) so a request that escapes the MITM // path fails fast via DNS rather than reaching a real upstream. @@ -210,53 +394,62 @@ func TestProxy_HotReloadRoutingCRUD(t *testing.T) { h := newReloadTestHarness(t) - // InitialEmptyRouter: no Reload has been called and the boot - // allowlist is empty, so any host falls through to the tunneled + // InitialEmptyRouter: no Reload has been called and no provider + // hosts are configured, so any host falls through to the tunneled // middleware. h.expectNotRouted(t, "https://alpha.invalid/v1/messages") // CreateProvider. - h.store.set([]aibridgeproxyd.ProviderRoute{ - {Name: "alpha", BaseURL: "https://alpha.invalid/v1"}, + h.store.set([]rawProvider{ + {name: "alpha", baseURL: "https://alpha.invalid/v1"}, }) require.NoError(t, h.srv.Reload(t.Context())) h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages") + h.expectProviderStatus(t, "alpha", "enabled") // UpdateProviderName: the same BaseURL with a new name must route - // under the new name on the next Reload. - h.store.set([]aibridgeproxyd.ProviderRoute{ - {Name: "alpha-v2", BaseURL: "https://alpha.invalid/v1"}, + // under the new name on the next Reload. The renamed provider must + // not leave a stale alpha series behind. + h.store.set([]rawProvider{ + {name: "alpha-v2", baseURL: "https://alpha.invalid/v1"}, }) require.NoError(t, h.srv.Reload(t.Context())) h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha-v2/v1/messages") + h.expectProviderStatus(t, "alpha-v2", "enabled") + h.expectProviderAbsent(t, "alpha") // UpdateProviderBaseURLHost: moving the provider to a new host must // start MITM'ing the new host and stop MITM'ing the old one. - h.store.set([]aibridgeproxyd.ProviderRoute{ - {Name: "alpha-v2", BaseURL: "https://alpha-new.invalid/v1"}, + h.store.set([]rawProvider{ + {name: "alpha-v2", baseURL: "https://alpha-new.invalid/v1"}, }) require.NoError(t, h.srv.Reload(t.Context())) h.expectRoutedTo(t, "https://alpha-new.invalid/v1/messages", "/api/v2/aibridge/alpha-v2/v1/messages") h.expectNotRouted(t, "https://alpha.invalid/v1/messages") + h.expectProviderStatus(t, "alpha-v2", "enabled") // AddSecondProvider: a second provider added in the same Reload must // route independently from the first. - h.store.set([]aibridgeproxyd.ProviderRoute{ - {Name: "alpha-v2", BaseURL: "https://alpha-new.invalid/v1"}, - {Name: "beta", BaseURL: "https://beta.invalid/v1"}, + h.store.set([]rawProvider{ + {name: "alpha-v2", baseURL: "https://alpha-new.invalid/v1"}, + {name: "beta", baseURL: "https://beta.invalid/v1"}, }) require.NoError(t, h.srv.Reload(t.Context())) h.expectRoutedTo(t, "https://alpha-new.invalid/v1/messages", "/api/v2/aibridge/alpha-v2/v1/messages") h.expectRoutedTo(t, "https://beta.invalid/v1/chat/completions", "/api/v2/aibridge/beta/v1/chat/completions") + h.expectProviderStatus(t, "alpha-v2", "enabled") + h.expectProviderStatus(t, "beta", "enabled") // DeleteOneProvider: removing alpha must keep beta routed and stop - // routing alpha. - h.store.set([]aibridgeproxyd.ProviderRoute{ - {Name: "beta", BaseURL: "https://beta.invalid/v1"}, + // routing alpha. The deleted name disappears from provider_info. + h.store.set([]rawProvider{ + {name: "beta", baseURL: "https://beta.invalid/v1"}, }) require.NoError(t, h.srv.Reload(t.Context())) h.expectRoutedTo(t, "https://beta.invalid/v1/chat/completions", "/api/v2/aibridge/beta/v1/chat/completions") h.expectNotRouted(t, "https://alpha-new.invalid/v1/messages") + h.expectProviderStatus(t, "beta", "enabled") + h.expectProviderAbsent(t, "alpha-v2") // DeleteAllProviders: an empty Reload must collapse the router to // the fail-closed state with no host MITM'd. @@ -264,15 +457,21 @@ func TestProxy_HotReloadRoutingCRUD(t *testing.T) { require.NoError(t, h.srv.Reload(t.Context())) h.expectNotRouted(t, "https://beta.invalid/v1/chat/completions") h.expectNotRouted(t, "https://alpha-new.invalid/v1/messages") + h.expectProviderAbsent(t, "beta") // RecreateAfterDelete: reintroducing a previously-deleted provider // must route again without restart, confirming the swap is // symmetric. - h.store.set([]aibridgeproxyd.ProviderRoute{ - {Name: "alpha", BaseURL: "https://alpha.invalid/v1"}, + h.store.set([]rawProvider{ + {name: "alpha", baseURL: "https://alpha.invalid/v1"}, }) require.NoError(t, h.srv.Reload(t.Context())) h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages") + h.expectProviderStatus(t, "alpha", "enabled") + + // Both timestamp gauges must have advanced through this sequence. + assert.Positive(t, promtest.ToFloat64(h.metrics.ProvidersLastReloadTimestampSeconds)) + assert.Positive(t, promtest.ToFloat64(h.metrics.ProvidersLastReloadSuccessTimestampSeconds)) } // TestProxy_HotReloadRoutingInvalidProviders covers the resilience @@ -288,15 +487,17 @@ func TestProxy_HotReloadRoutingInvalidProviders(t *testing.T) { h := newReloadTestHarness(t) // One valid provider and one with an empty BaseURL. The empty - // entry must be silently dropped; the valid one must still - // route. - h.store.set([]aibridgeproxyd.ProviderRoute{ - {Name: "no-url"}, - {Name: "valid", BaseURL: "https://valid.invalid/v1"}, + // entry must be classified as error and excluded from routing; + // the valid one must still route. + h.store.set([]rawProvider{ + {name: "no-url"}, + {name: "valid", baseURL: "https://valid.invalid/v1"}, }) require.NoError(t, h.srv.Reload(t.Context())) h.expectRoutedTo(t, "https://valid.invalid/v1/messages", "/api/v2/aibridge/valid/v1/messages") + h.expectProviderStatus(t, "no-url", "error") + h.expectProviderStatus(t, "valid", "enabled") }) t.Run("MalformedBaseURLSkipped", func(t *testing.T) { @@ -304,31 +505,36 @@ func TestProxy_HotReloadRoutingInvalidProviders(t *testing.T) { h := newReloadTestHarness(t) // A BaseURL that fails url.Parse and one whose Hostname() is - // empty must both be dropped. Mixed with a valid entry, only - // the valid one routes. - h.store.set([]aibridgeproxyd.ProviderRoute{ - {Name: "malformed", BaseURL: "://not-a-url"}, - {Name: "no-host", BaseURL: "https://"}, - {Name: "valid", BaseURL: "https://valid.invalid/v1"}, + // empty must both be classified as error. Mixed with a valid + // entry, only the valid one routes. + h.store.set([]rawProvider{ + {name: "malformed", baseURL: "://not-a-url"}, + {name: "no-host", baseURL: "https://"}, + {name: "valid", baseURL: "https://valid.invalid/v1"}, }) require.NoError(t, h.srv.Reload(t.Context())) h.expectRoutedTo(t, "https://valid.invalid/v1/messages", "/api/v2/aibridge/valid/v1/messages") + h.expectProviderStatus(t, "malformed", "error") + h.expectProviderStatus(t, "no-host", "error") + h.expectProviderStatus(t, "valid", "enabled") }) t.Run("DuplicateHostFirstWins", func(t *testing.T) { t.Parallel() h := newReloadTestHarness(t) - // Two providers with the same BaseURL host: the first one wins, - // matching buildProviderRouter's documented contract. - h.store.set([]aibridgeproxyd.ProviderRoute{ - {Name: "first", BaseURL: "https://shared.invalid/v1"}, - {Name: "second", BaseURL: "https://shared.invalid/v2"}, + // Two providers with the same BaseURL host: the second is + // classified as error and excluded; the first routes. + h.store.set([]rawProvider{ + {name: "first", baseURL: "https://shared.invalid/v1"}, + {name: "second", baseURL: "https://shared.invalid/v2"}, }) require.NoError(t, h.srv.Reload(t.Context())) h.expectRoutedTo(t, "https://shared.invalid/v1/messages", "/api/v2/aibridge/first/v1/messages") + h.expectProviderStatus(t, "first", "enabled") + h.expectProviderStatus(t, "second", "error") }) t.Run("AllInvalidYieldsEmptyRouter", func(t *testing.T) { @@ -337,10 +543,10 @@ func TestProxy_HotReloadRoutingInvalidProviders(t *testing.T) { h := newReloadTestHarness(t) // When every provider is invalid, the router contains no // entries and the proxy fails closed: no host is MITM'd. - h.store.set([]aibridgeproxyd.ProviderRoute{ - {Name: "no-url"}, - {Name: "malformed", BaseURL: "://not-a-url"}, - {Name: "no-host", BaseURL: "https://"}, + h.store.set([]rawProvider{ + {name: "no-url"}, + {name: "malformed", baseURL: "://not-a-url"}, + {name: "no-host", baseURL: "https://"}, }) require.NoError(t, h.srv.Reload(t.Context())) @@ -352,15 +558,15 @@ func TestProxy_HotReloadRoutingInvalidProviders(t *testing.T) { h := newReloadTestHarness(t) // Seed a valid snapshot so we have something to preserve. - h.store.set([]aibridgeproxyd.ProviderRoute{ - {Name: "alpha", BaseURL: "https://alpha.invalid/v1"}, + h.store.set([]rawProvider{ + {name: "alpha", baseURL: "https://alpha.invalid/v1"}, }) require.NoError(t, h.srv.Reload(t.Context())) h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages") // A refresh error must NOT clear the router: dropping the - // allowlist on every transient DB hiccup would amplify the - // fault into a denial of service. + // provider host set on every transient DB hiccup would + // amplify the fault into a denial of service. h.store.setErr(xerrors.New("simulated db failure")) err := h.srv.Reload(t.Context()) require.Error(t, err) @@ -369,8 +575,8 @@ func TestProxy_HotReloadRoutingInvalidProviders(t *testing.T) { // Recovery: once the store returns providers again, the next // Reload applies the new snapshot. - h.store.set([]aibridgeproxyd.ProviderRoute{ - {Name: "beta", BaseURL: "https://beta.invalid/v1"}, + h.store.set([]rawProvider{ + {name: "beta", baseURL: "https://beta.invalid/v1"}, }) require.NoError(t, h.srv.Reload(t.Context())) h.expectRoutedTo(t, "https://beta.invalid/v1/messages", "/api/v2/aibridge/beta/v1/messages") diff --git a/enterprise/cli/aibridgeproxyd.go b/enterprise/cli/aibridgeproxyd.go index 0f7ba976a5..08641f5769 100644 --- a/enterprise/cli/aibridgeproxyd.go +++ b/enterprise/cli/aibridgeproxyd.go @@ -5,7 +5,9 @@ package cli import ( "context" "io" + "net/url" "path/filepath" + "strings" "github.com/prometheus/client_golang/prometheus" "golang.org/x/xerrors" @@ -86,19 +88,69 @@ func newAIBridgeProxyDaemon(coderAPI *coderd.API) (io.Closer, error) { }, nil } +// refreshProxyProviders classifies every ai_providers row as enabled, +// disabled, or error so the proxy router and any observers see the full +// configured set. Disabled rows are excluded from routing; errored rows +// are excluded from routing and surface their failure reason for +// metrics and logs. func refreshProxyProviders(db database.Store) aibridgeproxyd.RefreshProvidersFunc { - return func(ctx context.Context) ([]aibridgeproxyd.ProviderRoute, error) { + return func(ctx context.Context) (aibridgeproxyd.ProviderReload, error) { //nolint:gocritic // AsAIProviderMetadataReader is the correct subject for routing-only access. rows, err := db.GetAIProviders(dbauthz.AsAIProviderMetadataReader(ctx), database.GetAIProvidersParams{ - IncludeDisabled: false, + IncludeDisabled: true, }) if err != nil { - return nil, xerrors.Errorf("load ai providers: %w", err) + return aibridgeproxyd.ProviderReload{}, xerrors.Errorf("load ai providers: %w", err) } - out := make([]aibridgeproxyd.ProviderRoute, 0, len(rows)) + reload := aibridgeproxyd.ProviderReload{ + Providers: make([]aibridgeproxyd.ReloadedProvider, 0, len(rows)), + } + seenHost := make(map[string]string, len(rows)) for _, row := range rows { - out = append(out, aibridgeproxyd.ProviderRoute{Name: row.Name, BaseURL: row.BaseUrl}) + reload.Providers = append(reload.Providers, classifyProviderRow(row, seenHost)) } - return out, nil + return reload, nil } } + +// classifyProviderRow evaluates a single ai_providers row for routing. +// seenHost is mutated to track the first provider that claimed each +// hostname so later duplicates can be flagged as errors. +func classifyProviderRow(row database.AIProvider, seenHost map[string]string) aibridgeproxyd.ReloadedProvider { + out := aibridgeproxyd.ReloadedProvider{ + ProviderOutcome: aibridged.ProviderOutcome{ + Name: row.Name, + Type: string(row.Type), + }, + } + if !row.Enabled { + out.Status = aibridged.ProviderStatusDisabled + return out + } + if strings.TrimSpace(row.BaseUrl) == "" { + out.Status = aibridged.ProviderStatusError + out.Err = xerrors.New("base url is empty") + return out + } + u, err := url.Parse(row.BaseUrl) + if err != nil { + out.Status = aibridged.ProviderStatusError + out.Err = xerrors.Errorf("invalid base url %q: %w", row.BaseUrl, err) + return out + } + host := strings.ToLower(u.Hostname()) + if host == "" { + out.Status = aibridged.ProviderStatusError + out.Err = xerrors.Errorf("base url %q has no hostname", row.BaseUrl) + return out + } + if claimedBy, taken := seenHost[host]; taken { + out.Status = aibridged.ProviderStatusError + out.Err = xerrors.Errorf("hostname %q already claimed by provider %q", host, claimedBy) + return out + } + seenHost[host] = row.Name + out.Host = host + out.Status = aibridged.ProviderStatusEnabled + return out +} diff --git a/enterprise/cli/aibridgeproxyd_internal_test.go b/enterprise/cli/aibridgeproxyd_internal_test.go new file mode 100644 index 0000000000..2c8520878b --- /dev/null +++ b/enterprise/cli/aibridgeproxyd_internal_test.go @@ -0,0 +1,105 @@ +//go:build !slim + +package cli + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/coder/coder/v2/coderd/aibridged" + "github.com/coder/coder/v2/coderd/database" +) + +// TestClassifyProviderRow covers every branch of the classifier so the +// disabled, error, and enabled paths are exercised through the +// production code instead of relying on classifyRaw, the test mirror in +// reload_test.go. +func TestClassifyProviderRow(t *testing.T) { + t.Parallel() + + enabledRow := func(name, baseURL string) database.AIProvider { + return database.AIProvider{ + Name: name, + Type: database.AiProviderTypeOpenai, + Enabled: true, + BaseUrl: baseURL, + } + } + + t.Run("Enabled", func(t *testing.T) { + t.Parallel() + + seen := map[string]string{} + got := classifyProviderRow(enabledRow("openai", "https://api.openai.com/v1"), seen) + assert.Equal(t, "openai", got.Name) + assert.Equal(t, string(database.AiProviderTypeOpenai), got.Type) + assert.Equal(t, aibridged.ProviderStatusEnabled, got.Status) + assert.Equal(t, "api.openai.com", got.Host) + assert.NoError(t, got.Err) + assert.Equal(t, "openai", seen["api.openai.com"]) + }) + + t.Run("DisabledRow", func(t *testing.T) { + t.Parallel() + + seen := map[string]string{} + row := enabledRow("off", "https://api.off.example.com/v1") + row.Enabled = false + got := classifyProviderRow(row, seen) + assert.Equal(t, aibridged.ProviderStatusDisabled, got.Status) + assert.Empty(t, got.Host, "disabled provider must not claim a host") + assert.NoError(t, got.Err) + assert.Empty(t, seen, "disabled provider must not occupy a host slot") + }) + + t.Run("EmptyBaseURL", func(t *testing.T) { + t.Parallel() + + seen := map[string]string{} + got := classifyProviderRow(enabledRow("no-url", " "), seen) + assert.Equal(t, aibridged.ProviderStatusError, got.Status) + assert.Empty(t, got.Host) + assert.ErrorContains(t, got.Err, "base url is empty") + }) + + t.Run("MalformedBaseURL", func(t *testing.T) { + t.Parallel() + + seen := map[string]string{} + got := classifyProviderRow(enabledRow("bad", "://not-a-url"), seen) + assert.Equal(t, aibridged.ProviderStatusError, got.Status) + assert.ErrorContains(t, got.Err, "invalid base url") + }) + + t.Run("BaseURLWithoutHostname", func(t *testing.T) { + t.Parallel() + + seen := map[string]string{} + got := classifyProviderRow(enabledRow("no-host", "https://"), seen) + assert.Equal(t, aibridged.ProviderStatusError, got.Status) + assert.ErrorContains(t, got.Err, "no hostname") + }) + + t.Run("DuplicateHostnameFirstWins", func(t *testing.T) { + t.Parallel() + + seen := map[string]string{} + first := classifyProviderRow(enabledRow("first", "https://shared.example.com/v1"), seen) + assert.Equal(t, aibridged.ProviderStatusEnabled, first.Status) + + second := classifyProviderRow(enabledRow("second", "https://shared.example.com/v2"), seen) + assert.Equal(t, aibridged.ProviderStatusError, second.Status) + assert.ErrorContains(t, second.Err, "already claimed by provider \"first\"") + assert.Equal(t, "first", seen["shared.example.com"], "first wins must not be overwritten") + }) + + t.Run("HostnameLowercased", func(t *testing.T) { + t.Parallel() + + seen := map[string]string{} + got := classifyProviderRow(enabledRow("mixed", "https://API.Example.COM/v1"), seen) + assert.Equal(t, aibridged.ProviderStatusEnabled, got.Status) + assert.Equal(t, "api.example.com", got.Host) + }) +} diff --git a/enterprise/cli/exp_scaletest_agentfake.go b/enterprise/cli/exp_scaletest_agentfake.go index a6c2e88649..cbfca70897 100644 --- a/enterprise/cli/exp_scaletest_agentfake.go +++ b/enterprise/cli/exp_scaletest_agentfake.go @@ -68,7 +68,7 @@ func (r *RootCmd) scaletestAgentFake() *serpent.Command { } logger := inv.Logger - mgr := agentfake.NewManager(client, logger, agentfake.ManagerOptions{ + mgr := agentfake.NewManager(client.URL, client, logger, agentfake.ManagerOptions{ Template: template, Owner: owner, }) diff --git a/enterprise/cli/server.go b/enterprise/cli/server.go index c77a03a0b4..37febd028b 100644 --- a/enterprise/cli/server.go +++ b/enterprise/cli/server.go @@ -95,6 +95,7 @@ func (r *RootCmd) Server(_ func()) *serpent.Command { ConnectionLogging: true, BrowserOnly: options.DeploymentValues.BrowserOnly.Value(), SCIMAPIKey: []byte(options.DeploymentValues.SCIMAPIKey.Value()), + UseLegacySCIM: options.DeploymentValues.UseLegacySCIM.Value(), RBAC: true, DERPServerRelayAddress: options.DeploymentValues.DERP.Server.RelayURL.String(), DERPServerRegionID: int(options.DeploymentValues.DERP.Server.RegionID.Value()), diff --git a/enterprise/coderd/aibridge.go b/enterprise/coderd/aibridge.go index 9773dba352..8a220760de 100644 --- a/enterprise/coderd/aibridge.go +++ b/enterprise/coderd/aibridge.go @@ -43,6 +43,11 @@ const ( // reference a valid resource in the expected scope. var errInvalidCursor = xerrors.New("invalid pagination cursor") +// This name is raised by a trigger function with USING CONSTRAINT. +// It is not a table CHECK constraint, so dbgen does not emit it in +// check_constraint.go. +const userAIBudgetOverridesMustBeGroupMemberConstraint database.CheckConstraint = "user_ai_budget_overrides_must_be_group_member" + // aibridgeHandler handles all aibridged-related endpoints. func aibridgeHandler(api *API, middlewares ...func(http.Handler) http.Handler) func(r chi.Router) { // Build the overload protection middleware chain for the aibridged handler. @@ -821,3 +826,116 @@ func (api *API) deleteGroupAIBudget(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusNoContent) } + +// @Summary Get user AI budget override +// @ID get-user-ai-budget-override +// @Security CoderSessionToken +// @Produce json +// @Tags Enterprise +// @Param user path string true "User ID, username, or me" +// @Success 200 {object} codersdk.UserAIBudgetOverride +// @Router /api/v2/users/{user}/ai/budget [get] +func (api *API) userAIBudgetOverride(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + user := httpmw.UserParam(r) + + override, err := api.Database.GetUserAIBudgetOverride(ctx, user.ID) + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + if err != nil { + api.Logger.Error(ctx, "get user AI budget override", slog.Error(err)) + httpapi.InternalServerError(rw, err) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserAIBudgetOverride(override)) +} + +// @Summary Upsert user AI budget override +// @ID upsert-user-ai-budget-override +// @Security CoderSessionToken +// @Accept json +// @Produce json +// @Tags Enterprise +// @Param user path string true "User ID, username, or me" +// @Param request body codersdk.UpsertUserAIBudgetOverrideRequest true "Upsert user AI budget override request" +// @Success 200 {object} codersdk.UserAIBudgetOverride +// @Router /api/v2/users/{user}/ai/budget [put] +func (api *API) upsertUserAIBudgetOverride(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + user := httpmw.UserParam(r) + + var req codersdk.UpsertUserAIBudgetOverrideRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + // Look up the group first so a missing or forbidden group_id returns + // 404, distinct from the 400 "not a member" case handled below. + if _, err := api.Database.GetGroupByID(ctx, req.GroupID); err != nil { + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + api.Logger.Error(ctx, "get group for user AI budget override", slog.Error(err)) + httpapi.InternalServerError(rw, err) + return + } + + override, err := api.Database.UpsertUserAIBudgetOverride(ctx, database.UpsertUserAIBudgetOverrideParams{ + UserID: user.ID, + GroupID: req.GroupID, + SpendLimitMicros: req.SpendLimitMicros, + }) + // A trigger enforces that the user must be a member of the attributed + // group; it raises check_violation with this constraint name. Map + // the violation to a structured 400. + if database.IsCheckViolation(err, userAIBudgetOverridesMustBeGroupMemberConstraint) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "User is not a member of the referenced group.", + Validations: []codersdk.ValidationError{{ + Field: "group_id", + Detail: "user must be a member of this group", + }}, + }) + return + } + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + if err != nil { + api.Logger.Error(ctx, "upsert user AI budget override", slog.Error(err)) + httpapi.InternalServerError(rw, err) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserAIBudgetOverride(override)) +} + +// @Summary Delete user AI budget override +// @ID delete-user-ai-budget-override +// @Security CoderSessionToken +// @Tags Enterprise +// @Param user path string true "User ID, username, or me" +// @Success 204 +// @Router /api/v2/users/{user}/ai/budget [delete] +func (api *API) deleteUserAIBudgetOverride(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + user := httpmw.UserParam(r) + + _, err := api.Database.DeleteUserAIBudgetOverride(ctx, user.ID) + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + if err != nil { + api.Logger.Error(ctx, "delete user AI budget override", slog.Error(err)) + httpapi.InternalServerError(rw, err) + return + } + + rw.WriteHeader(http.StatusNoContent) +} diff --git a/enterprise/coderd/aibridge_reload_test.go b/enterprise/coderd/aibridge_reload_test.go index 6aee3afa91..e3370c8f7d 100644 --- a/enterprise/coderd/aibridge_reload_test.go +++ b/enterprise/coderd/aibridge_reload_test.go @@ -9,6 +9,8 @@ import ( "sync/atomic" "testing" + "github.com/prometheus/client_golang/prometheus" + promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" @@ -54,7 +56,7 @@ func newMockUpstream(t *testing.T, name string) *mockUpstream { // the supplied API and subscribes it to ai_providers change events. // This mirrors what cli/server.go does in production so /api/v2/aibridge // requests dispatch through the real pool and reloader. -func startTestAIBridgeDaemon(t *testing.T, api *coderd.API) { +func startTestAIBridgeDaemon(t *testing.T, api *coderd.API) *aibridged.Metrics { t.Helper() ctx := context.Background() @@ -62,14 +64,15 @@ func startTestAIBridgeDaemon(t *testing.T, api *coderd.API) { cfg := api.DeploymentValues.AI.BridgeConfig tracer := otel.Tracer("aibridge-reload-test") - providers, err := cli.BuildProviders(ctx, api.Database, cfg, logger) + providers, _, err := cli.BuildProviders(ctx, api.Database, cfg, logger) require.NoError(t, err) pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger.Named("pool"), nil, tracer) require.NoError(t, err) t.Cleanup(func() { _ = pool.Shutdown(context.Background()) }) - reloader := &testPoolReloader{pool: pool, db: api.Database, cfg: cfg, logger: logger.Named("reloader")} + metrics := aibridged.NewMetrics(prometheus.NewRegistry()) + reloader := &testPoolReloader{pool: pool, db: api.Database, cfg: cfg, logger: logger.Named("reloader"), metrics: metrics} unsubscribe, err := aibridged.SubscribeProviderReload(ctx, api.Pubsub, reloader, logger.Named("subscriber")) require.NoError(t, err) t.Cleanup(unsubscribe) @@ -81,21 +84,25 @@ func startTestAIBridgeDaemon(t *testing.T, api *coderd.API) { t.Cleanup(func() { _ = srv.Close() }) api.RegisterInMemoryAIBridgedHTTPHandler(srv) + return metrics } type testPoolReloader struct { - pool *aibridged.CachedBridgePool - db database.Store - cfg codersdk.AIBridgeConfig - logger slog.Logger + pool *aibridged.CachedBridgePool + db database.Store + cfg codersdk.AIBridgeConfig + logger slog.Logger + metrics *aibridged.Metrics } func (r *testPoolReloader) Reload(ctx context.Context) error { - providers, err := cli.BuildProviders(ctx, r.db, r.cfg, r.logger) + defer r.metrics.RecordReloadAttempt() + providers, outcomes, err := cli.BuildProviders(ctx, r.db, r.cfg, r.logger) if err != nil { return err } r.pool.ReplaceProviders(providers) + r.metrics.RecordReloadSuccess(outcomes) return nil } @@ -124,7 +131,34 @@ func TestAIBridgeProviderHotReload(t *testing.T) { }, }) - startTestAIBridgeDaemon(t, api.AGPL) + metrics := startTestAIBridgeDaemon(t, api.AGPL) + + // requireProviderStatus polls until the provider_info series for + // (name, status) settles to value 1. Reloads happen via pubsub, so + // the assertion has to be eventual. + requireProviderStatus := func(t *testing.T, name, status string) { + t.Helper() + require.Eventuallyf(t, func() bool { + return promtest.ToFloat64(metrics.ProviderInfo.WithLabelValues(name, "openai", status)) == 1 + }, testutil.WaitShort, testutil.IntervalFast, + "expected provider_info{provider_name=%q, status=%q} == 1", name, status) + } + + // requireProviderAbsent polls until no series exists for the + // provider name in any status. After a delete the Reset on the + // next reload must clear all previous status labels for the name. + requireProviderAbsent := func(t *testing.T, name string) { + t.Helper() + require.Eventuallyf(t, func() bool { + for _, status := range []string{"enabled", "disabled", "error"} { + if promtest.ToFloat64(metrics.ProviderInfo.WithLabelValues(name, "openai", status)) != 0 { + return false + } + } + return true + }, testutil.WaitShort, testutil.IntervalFast, + "expected provider_info series for %q to be cleared after delete", name) + } ctx := testutil.Context(t, testutil.WaitLong) @@ -177,6 +211,18 @@ func TestAIBridgeProviderHotReload(t *testing.T) { "expected provider %q to stop routing", providerName) } + // requireDisabledSentinel polls until the provider name yields a + // 503 with the provider_disabled body, indicating the disabled + // handler is wired up for the row. + requireDisabledSentinel := func(t *testing.T, providerName string) { + t.Helper() + require.Eventuallyf(t, func() bool { + status, _ := sendRequest(providerName) + return status == http.StatusServiceUnavailable + }, testutil.WaitShort, testutil.IntervalFast, + "expected provider %q to serve the disabled sentinel", providerName) + } + // 1. Create: provider points at upstream A. created, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ Type: codersdk.AIProviderTypeOpenAI, @@ -188,6 +234,7 @@ func TestAIBridgeProviderHotReload(t *testing.T) { require.NoError(t, err) require.Equal(t, "primary", created.Name) requireRoutesTo(t, "primary", upstreamA) + requireProviderStatus(t, "primary", "enabled") // 2. Update BaseURL: same name, now points at upstream B. newBaseURL := upstreamB.server.URL @@ -196,15 +243,17 @@ func TestAIBridgeProviderHotReload(t *testing.T) { }) require.NoError(t, err) requireRoutesTo(t, "primary", upstreamB) + requireProviderStatus(t, "primary", "enabled") - // 3. Disable: the provider drops out of the snapshot, requests - // stop reaching any upstream. + // 3. Disable: requests stop reaching upstream and the bridge + // answers with the 503 sentinel. The metric flips to "disabled". disabled := false _, err = client.UpdateAIProvider(ctx, "primary", codersdk.UpdateAIProviderRequest{ Enabled: &disabled, }) require.NoError(t, err) - requireRoutingGone(t, "primary") + requireDisabledSentinel(t, "primary") + requireProviderStatus(t, "primary", "disabled") // 4. Re-enable: routing comes back at the most recent BaseURL. enabled := true @@ -213,6 +262,7 @@ func TestAIBridgeProviderHotReload(t *testing.T) { }) require.NoError(t, err) requireRoutesTo(t, "primary", upstreamB) + requireProviderStatus(t, "primary", "enabled") // 5. Add a second provider; both names must route independently. _, err = client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ @@ -225,9 +275,19 @@ func TestAIBridgeProviderHotReload(t *testing.T) { require.NoError(t, err) requireRoutesTo(t, "primary", upstreamB) requireRoutesTo(t, "secondary", upstreamA) + requireProviderStatus(t, "primary", "enabled") + requireProviderStatus(t, "secondary", "enabled") - // 6. Delete primary: only secondary remains routable. + // 6. Delete primary: only secondary remains routable. The + // provider_info series for primary disappears entirely on the + // next reload's Reset. require.NoError(t, client.DeleteAIProvider(ctx, "primary")) requireRoutingGone(t, "primary") requireRoutesTo(t, "secondary", upstreamA) + requireProviderAbsent(t, "primary") + requireProviderStatus(t, "secondary", "enabled") + + // Both timestamp gauges must have advanced during this test. + assert.Positive(t, promtest.ToFloat64(metrics.ProvidersLastReloadTimestampSeconds)) + assert.Positive(t, promtest.ToFloat64(metrics.ProvidersLastReloadSuccessTimestampSeconds)) } diff --git a/enterprise/coderd/aibridge_test.go b/enterprise/coderd/aibridge_test.go index 158f682842..1faadd1f53 100644 --- a/enterprise/coderd/aibridge_test.go +++ b/enterprise/coderd/aibridge_test.go @@ -2871,6 +2871,447 @@ func TestGroupAIBudget(t *testing.T) { }) } +func TestUserAIBudgetOverride(t *testing.T) { + t.Parallel() + + t.Run("Upsert/CreatesAndUpdates", func(t *testing.T) { + t.Parallel() + + adminClient, targetUser, group := setupUserAIBudgetOverrideTest(t) + ctx := testutil.Context(t, testutil.WaitLong) + + // First upsert creates the override. + newOverride, err := adminClient.UpsertUserAIBudgetOverride(ctx, targetUser.ID, codersdk.UpsertUserAIBudgetOverrideRequest{ + GroupID: group.ID, + SpendLimitMicros: 500_000_000, + }) + require.NoError(t, err) + require.Equal(t, targetUser.ID, newOverride.UserID) + require.Equal(t, group.ID, newOverride.GroupID) + require.EqualValues(t, 500_000_000, newOverride.SpendLimitMicros) + + // Second upsert updates the existing override. + updatedOverride, err := adminClient.UpsertUserAIBudgetOverride(ctx, targetUser.ID, codersdk.UpsertUserAIBudgetOverrideRequest{ + GroupID: group.ID, + SpendLimitMicros: 1_000_000_000, + }) + require.NoError(t, err) + require.EqualValues(t, 1_000_000_000, updatedOverride.SpendLimitMicros) + + // GET returns the latest value. + currentOverride, err := adminClient.UserAIBudgetOverride(ctx, targetUser.ID) + require.NoError(t, err) + require.EqualValues(t, 1_000_000_000, currentOverride.SpendLimitMicros) + }) + + t.Run("Upsert/ReassignsGroup", func(t *testing.T) { + t.Parallel() + + adminClient, targetUser, groupA := setupUserAIBudgetOverrideTest(t) + ctx := testutil.Context(t, testutil.WaitLong) + + // First upsert: attribute spend to groupA. + _, err := adminClient.UpsertUserAIBudgetOverride(ctx, targetUser.ID, codersdk.UpsertUserAIBudgetOverrideRequest{ + GroupID: groupA.ID, + SpendLimitMicros: 500_000_000, + }) + require.NoError(t, err) + + // Create groupB in the same org and add the target user. + groupB, err := adminClient.CreateGroup(ctx, targetUser.OrganizationIDs[0], codersdk.CreateGroupRequest{ + Name: "reassign-test-group-b", + }) + require.NoError(t, err) + _, err = adminClient.PatchGroup(ctx, groupB.ID, codersdk.PatchGroupRequest{ + AddUsers: []string{targetUser.ID.String()}, + }) + require.NoError(t, err) + + // Reassign the override's attribution to groupB. + updated, err := adminClient.UpsertUserAIBudgetOverride(ctx, targetUser.ID, codersdk.UpsertUserAIBudgetOverrideRequest{ + GroupID: groupB.ID, + SpendLimitMicros: 500_000_000, + }) + require.NoError(t, err) + require.Equal(t, groupB.ID, updated.GroupID, "upsert should change attributed group") + + // GET reflects the new group. + got, err := adminClient.UserAIBudgetOverride(ctx, targetUser.ID) + require.NoError(t, err) + require.Equal(t, groupB.ID, got.GroupID, "GET should reflect new group") + }) + + t.Run("Upsert/EveryoneGroup", func(t *testing.T) { + t.Parallel() + + adminClient, targetUser, _ := setupUserAIBudgetOverrideTest(t) + ctx := testutil.Context(t, testutil.WaitLong) + + // The Everyone group has id == organization_id, and the target user + // is implicitly a member via organization_members rather than + // group_members. The membership trigger queries + // group_members_expanded (a UNION of both tables), so this case + // exercises the organization_members branch. + everyoneGroupID := targetUser.OrganizationIDs[0] + + override, err := adminClient.UpsertUserAIBudgetOverride(ctx, targetUser.ID, codersdk.UpsertUserAIBudgetOverrideRequest{ + GroupID: everyoneGroupID, + SpendLimitMicros: 500_000_000, + }) + require.NoError(t, err, "should be able to attribute override to Everyone group") + require.Equal(t, targetUser.ID, override.UserID) + require.Equal(t, everyoneGroupID, override.GroupID) + require.EqualValues(t, 500_000_000, override.SpendLimitMicros) + }) + + t.Run("Upsert/AcceptsZeroSpendLimit", func(t *testing.T) { + t.Parallel() + + adminClient, targetUser, group := setupUserAIBudgetOverrideTest(t) + ctx := testutil.Context(t, testutil.WaitLong) + + // 0 is a valid value: it blocks all spend for the user. + override, err := adminClient.UpsertUserAIBudgetOverride(ctx, targetUser.ID, codersdk.UpsertUserAIBudgetOverrideRequest{ + GroupID: group.ID, + SpendLimitMicros: 0, + }) + require.NoError(t, err) + require.EqualValues(t, 0, override.SpendLimitMicros) + }) + + t.Run("Upsert/RejectsNegativeSpend", func(t *testing.T) { + t.Parallel() + + adminClient, targetUser, group := setupUserAIBudgetOverrideTest(t) + ctx := testutil.Context(t, testutil.WaitLong) + + _, err := adminClient.UpsertUserAIBudgetOverride(ctx, targetUser.ID, codersdk.UpsertUserAIBudgetOverrideRequest{ + GroupID: group.ID, + SpendLimitMicros: -1, + }) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + }) + + t.Run("Upsert/RejectsUnknownGroup", func(t *testing.T) { + t.Parallel() + + adminClient, targetUser, _ := setupUserAIBudgetOverrideTest(t) + ctx := testutil.Context(t, testutil.WaitLong) + + // A group_id that doesn't exist (or that the caller can't see) + // is rejected by the visibility check before the membership check. + _, err := adminClient.UpsertUserAIBudgetOverride(ctx, targetUser.ID, codersdk.UpsertUserAIBudgetOverrideRequest{ + GroupID: uuid.New(), + SpendLimitMicros: 500_000_000, + }) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + }) + + t.Run("Upsert/RejectsNonMemberGroup", func(t *testing.T) { + t.Parallel() + + adminClient, targetUser, _ := setupUserAIBudgetOverrideTest(t) + ctx := testutil.Context(t, testutil.WaitLong) + + // Create a second group the target is NOT a member of. + outsiderGroup, err := adminClient.CreateGroup(ctx, targetUser.OrganizationIDs[0], codersdk.CreateGroupRequest{ + Name: "outsider-group", + }) + require.NoError(t, err) + + _, err = adminClient.UpsertUserAIBudgetOverride(ctx, targetUser.ID, codersdk.UpsertUserAIBudgetOverrideRequest{ + GroupID: outsiderGroup.ID, + SpendLimitMicros: 500_000_000, + }) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + }) + + t.Run("Get/AbsentReturns404", func(t *testing.T) { + t.Parallel() + + adminClient, targetUser, _ := setupUserAIBudgetOverrideTest(t) + ctx := testutil.Context(t, testutil.WaitLong) + + _, err := adminClient.UserAIBudgetOverride(ctx, targetUser.ID) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + }) + + t.Run("Get/UnknownUserReturns404", func(t *testing.T) { + t.Parallel() + + adminClient, _, _ := setupUserAIBudgetOverrideTest(t) + ctx := testutil.Context(t, testutil.WaitLong) + + _, err := adminClient.UserAIBudgetOverride(ctx, uuid.New()) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + }) + + t.Run("Delete/RoundTrip", func(t *testing.T) { + t.Parallel() + + adminClient, targetUser, group := setupUserAIBudgetOverrideTest(t) + ctx := testutil.Context(t, testutil.WaitLong) + + _, err := adminClient.UpsertUserAIBudgetOverride(ctx, targetUser.ID, codersdk.UpsertUserAIBudgetOverrideRequest{ + GroupID: group.ID, + SpendLimitMicros: 500_000_000, + }) + require.NoError(t, err) + + require.NoError(t, adminClient.DeleteUserAIBudgetOverride(ctx, targetUser.ID)) + + _, err = adminClient.UserAIBudgetOverride(ctx, targetUser.ID) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + }) + + t.Run("Delete/AbsentReturns404", func(t *testing.T) { + t.Parallel() + + adminClient, targetUser, _ := setupUserAIBudgetOverrideTest(t) + ctx := testutil.Context(t, testutil.WaitLong) + + err := adminClient.DeleteUserAIBudgetOverride(ctx, targetUser.ID) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + }) +} + +// TestUserAIBudgetOverrideRoleAccess verifies the authz matrix for the roles +// expected to interact with user budget overrides: +// +// - Owner / UserAdmin: full CRUD. +// - OrgAdmin / OrgUserAdmin: read-only. Writes require ActionUpdate on the +// User resource (site-scoped), which neither role has. +// +//nolint:tparallel // Subtests run sequentially: they share the same deployment and group, and parallel PatchGroup calls on the same group race. +func TestUserAIBudgetOverrideRoleAccess(t *testing.T) { + t.Parallel() + + dv := coderdtest.DeploymentValues(t) + dv.AI.BridgeConfig.Enabled = serpent.Bool(true) + ownerClient, owner := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{DeploymentValues: dv}, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureTemplateRBAC: 1, + codersdk.FeatureAIBridge: 1, + }, + }, + }) + userAdminClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleUserAdmin()) + orgAdminClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.ScopedRoleOrgAdmin(owner.OrganizationID)) + orgUserAdminClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.ScopedRoleOrgUserAdmin(owner.OrganizationID)) + + setupCtx := testutil.Context(t, testutil.WaitLong) + group, err := userAdminClient.CreateGroup(setupCtx, owner.OrganizationID, codersdk.CreateGroupRequest{ + Name: "role-access-group", + }) + require.NoError(t, err) + + cases := []struct { + Name string + Client *codersdk.Client + CanWrite bool + }{ + {Name: "Owner", Client: ownerClient, CanWrite: true}, + {Name: "UserAdmin", Client: userAdminClient, CanWrite: true}, + {Name: "OrgAdmin", Client: orgAdminClient, CanWrite: false}, + {Name: "OrgUserAdmin", Client: orgUserAdminClient, CanWrite: false}, + } + + //nolint:paralleltest // Subtests run sequentially: they share the same deployment and group, and parallel PatchGroup calls on the same group race. + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + + // Each case gets a fresh target user. + _, targetUser := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID) + _, err := userAdminClient.PatchGroup(ctx, group.ID, codersdk.PatchGroupRequest{ + AddUsers: []string{targetUser.ID.String()}, + }) + require.NoError(t, err) + + upsertReq := codersdk.UpsertUserAIBudgetOverrideRequest{ + GroupID: group.ID, + SpendLimitMicros: 500_000_000, + } + + if tc.CanWrite { + // Full CRUD lifecycle. + override, err := tc.Client.UpsertUserAIBudgetOverride(ctx, targetUser.ID, upsertReq) + require.NoError(t, err, "PUT") + require.Equal(t, group.ID, override.GroupID) + + got, err := tc.Client.UserAIBudgetOverride(ctx, targetUser.ID) + require.NoError(t, err, "GET") + require.EqualValues(t, 500_000_000, got.SpendLimitMicros) + + err = tc.Client.DeleteUserAIBudgetOverride(ctx, targetUser.ID) + require.NoError(t, err, "DELETE") + } else { + // PUT rejected. + _, err := tc.Client.UpsertUserAIBudgetOverride(ctx, targetUser.ID, upsertReq) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode(), "PUT") + + // Seed a row via UserAdmin so we can verify read access still works. + _, err = userAdminClient.UpsertUserAIBudgetOverride(ctx, targetUser.ID, upsertReq) + require.NoError(t, err) + + // GET still works (all roles have ActionRead on User). + got, err := tc.Client.UserAIBudgetOverride(ctx, targetUser.ID) + require.NoError(t, err, "GET") + require.EqualValues(t, 500_000_000, got.SpendLimitMicros) + + // DELETE rejected. + err = tc.Client.DeleteUserAIBudgetOverride(ctx, targetUser.ID) + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode(), "DELETE") + } + }) + } +} + +// TestUserAIBudgetOverrideDeletedOnMembershipRemoval verifies that a per-user +// override is deleted automatically when the user loses membership in the +// attributed group. Two paths are exercised: +// +// - RegularGroup: membership stored in group_members; removed via +// PatchGroup with RemoveUsers. +// - EveryoneGroup: membership stored in organization_members; removed +// via DeleteOrganizationMember. +func TestUserAIBudgetOverrideDeletedOnMembershipRemoval(t *testing.T) { + t.Parallel() + + dv := coderdtest.DeploymentValues(t) + dv.AI.BridgeConfig.Enabled = serpent.Bool(true) + ownerClient, owner := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{DeploymentValues: dv}, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureTemplateRBAC: 1, + codersdk.FeatureAIBridge: 1, + }, + }, + }) + adminClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleUserAdmin()) + + // "Regular group" means any group except "Everyone". + t.Run("RegularGroup", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + + _, targetUser := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID) + + group, err := adminClient.CreateGroup(ctx, owner.OrganizationID, codersdk.CreateGroupRequest{ + Name: "cascade-regular-group", + }) + require.NoError(t, err) + + _, err = adminClient.PatchGroup(ctx, group.ID, codersdk.PatchGroupRequest{ + AddUsers: []string{targetUser.ID.String()}, + }) + require.NoError(t, err) + + _, err = adminClient.UpsertUserAIBudgetOverride(ctx, targetUser.ID, codersdk.UpsertUserAIBudgetOverrideRequest{ + GroupID: group.ID, + SpendLimitMicros: 500_000_000, + }) + require.NoError(t, err, "set override") + + // Sanity-check the override exists. + _, err = adminClient.UserAIBudgetOverride(ctx, targetUser.ID) + require.NoError(t, err, "override should exist before removal") + + _, err = adminClient.PatchGroup(ctx, group.ID, codersdk.PatchGroupRequest{ + RemoveUsers: []string{targetUser.ID.String()}, + }) + require.NoError(t, err, "remove user from group") + + _, err = adminClient.UserAIBudgetOverride(ctx, targetUser.ID) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode(), + "override should be deleted after user is removed from the attributed group") + }) + + t.Run("EveryoneGroup", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + + _, targetUser := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID) + + // The Everyone group has id == organization_id. + everyoneGroupID := owner.OrganizationID + + _, err := adminClient.UpsertUserAIBudgetOverride(ctx, targetUser.ID, codersdk.UpsertUserAIBudgetOverrideRequest{ + GroupID: everyoneGroupID, + SpendLimitMicros: 500_000_000, + }) + require.NoError(t, err, "set override") + + // Sanity-check the override exists. + _, err = adminClient.UserAIBudgetOverride(ctx, targetUser.ID) + require.NoError(t, err, "override should exist before removal") + + err = adminClient.DeleteOrganizationMember(ctx, owner.OrganizationID, targetUser.ID.String()) + require.NoError(t, err, "remove user from organization") + + _, err = adminClient.UserAIBudgetOverride(ctx, targetUser.ID) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode(), + "override should be deleted after user is removed from the organization") + }) +} + +// setupUserAIBudgetOverrideTest returns an Admin client, a target user, and a +// group the target user is a member of. +func setupUserAIBudgetOverrideTest(t *testing.T) (adminClient *codersdk.Client, targetUser codersdk.User, group codersdk.Group) { + t.Helper() + + dv := coderdtest.DeploymentValues(t) + dv.AI.BridgeConfig.Enabled = serpent.Bool(true) + ownerClient, owner := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{DeploymentValues: dv}, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureTemplateRBAC: 1, + codersdk.FeatureAIBridge: 1, + }, + }, + }) + adminClient, _ = coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleUserAdmin()) + _, targetUser = coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID) + + ctx := testutil.Context(t, testutil.WaitLong) + g, err := adminClient.CreateGroup(ctx, owner.OrganizationID, codersdk.CreateGroupRequest{ + Name: "override-test-group", + }) + require.NoError(t, err) + g, err = adminClient.PatchGroup(ctx, g.ID, codersdk.PatchGroupRequest{ + AddUsers: []string{targetUser.ID.String()}, + }) + require.NoError(t, err) + return adminClient, targetUser, g +} + // setupGroupAIBudgetTest returns an Admin client along with a newly created group inside it. func setupGroupAIBudgetTest(t *testing.T) (adminClient *codersdk.Client, group codersdk.Group) { t.Helper() diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 82fdd4d27f..33732eea3d 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -596,6 +596,17 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { r.Get("/", api.userQuietHoursSchedule) r.Put("/", api.putUserQuietHoursSchedule) }) + r.Route("/users/{user}/ai/budget", func(r chi.Router) { + // AI cost controls are a paid feature (AI Governance add-on). + r.Use( + api.RequireFeatureMW(codersdk.FeatureAIBridge), + apiKeyMiddleware, + httpmw.ExtractUserParam(options.Database), + ) + r.Get("/", api.userAIBudgetOverride) + r.Put("/", api.upsertUserAIBudgetOverride) + r.Delete("/", api.deleteUserAIBudgetOverride) + }) r.Route("/prebuilds", func(r chi.Router) { r.Use( apiKeyMiddleware, @@ -622,40 +633,12 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { }) }) - if len(options.SCIMAPIKey) != 0 { - api.AGPL.RootHandler.Route("/scim/v2", func(r chi.Router) { - r.Use( - api.RequireFeatureMW(codersdk.FeatureSCIM), - ) - r.Get("/ServiceProviderConfig", api.scimServiceProviderConfig) - r.Post("/Users", api.scimPostUser) - r.Route("/Users", func(r chi.Router) { - r.Get("/", api.scimGetUsers) - r.Post("/", api.scimPostUser) - r.Get("/{id}", api.scimGetUser) - r.Patch("/{id}", api.scimPatchUser) - r.Put("/{id}", api.scimPutUser) - }) - r.NotFound(func(w http.ResponseWriter, r *http.Request) { - u := r.URL.String() - httpapi.Write(r.Context(), w, http.StatusNotFound, codersdk.Response{ - Message: fmt.Sprintf("SCIM endpoint %s not found", u), - Detail: "This endpoint is not implemented. If it is correct and required, please contact support.", - }) - }) - }) - } else { - // Show a helpful 404 error. Because this is not under the /api/v2 routes, - // the frontend is the fallback. A html page is not a helpful error for - // a SCIM provider. This JSON has a call to action that __may__ resolve - // the issue. - // Using Mount to cover all subroute possibilities. - api.AGPL.RootHandler.Mount("/scim/v2", http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - httpapi.Write(r.Context(), w, http.StatusNotFound, codersdk.Response{ - Message: "SCIM is disabled, please contact your administrator if you believe this is an error", - Detail: "SCIM endpoints are disabled if no SCIM is configured. Configure 'CODER_SCIM_AUTH_HEADER' to enable.", - }) - }))) + var mountScimError error + api.AGPL.RootHandler.Route("/scim", func(r chi.Router) { + mountScimError = api.mountScimRoute(options, r) + }) + if mountScimError != nil { + return nil, xerrors.Errorf("mount scim routes: %w", mountScimError) } // We always want to run the replica manager even if we don't have DERP @@ -754,6 +737,11 @@ type Options struct { // Whether to block non-browser connections. BrowserOnly bool SCIMAPIKey []byte + // UseLegacySCIM opts into the legacy SCIM handler implementation + // (imulab/go-scim based). This is provided for backward compatibility + // during the transition to the new elimity-com/scim implementation. + // It will be removed in a future release. + UseLegacySCIM bool ExternalTokenEncryption []dbcrypt.Cipher diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go index a7efd1b302..1115ba1211 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest.go +++ b/enterprise/coderd/coderdenttest/coderdenttest.go @@ -67,6 +67,7 @@ type Options struct { BrowserOnly bool EntitlementsUpdateInterval time.Duration SCIMAPIKey []byte + UseLegacySCIM bool UserWorkspaceQuota int ProxyHealthInterval time.Duration LicenseOptions *LicenseOptions @@ -108,6 +109,7 @@ func NewWithAPI(t *testing.T, options *Options) ( AuditLogging: options.AuditLogging, BrowserOnly: options.BrowserOnly, SCIMAPIKey: options.SCIMAPIKey, + UseLegacySCIM: options.UseLegacySCIM, DERPServerRelayAddress: serverURL.String(), DERPServerRegionID: int(oop.DeploymentValues.DERP.Server.RegionID.Value()), ReplicaSyncUpdateInterval: options.ReplicaSyncUpdateInterval, diff --git a/enterprise/coderd/scim.go b/enterprise/coderd/legacyscim/legacyscim.go similarity index 65% rename from enterprise/coderd/scim.go rename to enterprise/coderd/legacyscim/legacyscim.go index 5d0b248abd..942a78dd83 100644 --- a/enterprise/coderd/scim.go +++ b/enterprise/coderd/legacyscim/legacyscim.go @@ -1,4 +1,14 @@ -package coderd +// Package legacyscim preserves the old imulab/go-scim based SCIM handler. +// It was added in May 2026 to keep an opt-out path available during the +// rollout of the new SCIM 2.0 implementation in +// enterprise/coderd/scim. Once that implementation has run in production +// for a while and the CODER_SCIM_USE_LEGACY default is flipped, remove +// this package in its entirety. +// +// Enabled via the UseLegacySCIM option. +// +// Deprecated: Use the enterprise/coderd/scim package instead. +package legacyscim import ( "bytes" @@ -6,6 +16,8 @@ import ( "database/sql" "encoding/json" "net/http" + "net/url" + "sync/atomic" "time" "github.com/go-chi/chi/v5" @@ -16,17 +28,64 @@ import ( "github.com/imulab/go-scim/pkg/v2/spec" "golang.org/x/xerrors" + "cdr.dev/slog/v3" agpl "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/enterprise/coderd/scim" ) -func (api *API) scimVerifyAuthHeader(r *http.Request) bool { +// LegacyServer is the old SCIM handler implementation, kept for backward +// compatibility. It uses the imulab/go-scim library and custom JSON handling. +type LegacyServer struct { + Logger slog.Logger + Database database.Store + IDPSync idpsync.IDPSync + AGPL *agpl.API + AccessURL *url.URL + SCIMAPIKey []byte + Auditor *atomic.Pointer[audit.Auditor] +} + +// Handler returns an http.Handler that serves the legacy SCIM endpoints. +// It should be mounted at /scim/v2. +func (s *LegacyServer) Handler() http.Handler { + r := chi.NewRouter() + r.Get("/ServiceProviderConfig", s.scimServiceProviderConfig) + r.Post("/Users", s.scimPostUser) + r.Route("/Users", func(r chi.Router) { + r.Get("/", s.scimGetUsers) + r.Post("/", s.scimPostUser) + r.Get("/{id}", s.scimGetUser) + r.Patch("/{id}", s.scimPatchUser) + r.Put("/{id}", s.scimPutUser) + }) + r.NotFound(func(w http.ResponseWriter, r *http.Request) { + u := r.URL.String() + httpapi.Write(r.Context(), w, http.StatusNotFound, codersdk.Response{ + Message: "SCIM endpoint not found: " + u, + Detail: "This endpoint is not implemented. If it is correct and required, please contact support.", + }) + }) + return r +} + +// AuthMiddleware verifies the SCIM Bearer token. +func (s *LegacyServer) AuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if !s.scimVerifyAuthHeader(r) { + scimUnauthorized(rw) + return + } + next.ServeHTTP(rw, r) + }) +} + +func (s *LegacyServer) scimVerifyAuthHeader(r *http.Request) bool { bearer := []byte("bearer ") hdr := []byte(r.Header.Get("Authorization")) @@ -35,11 +94,11 @@ func (api *API) scimVerifyAuthHeader(r *http.Request) bool { hdr = hdr[len(bearer):] } - return len(api.SCIMAPIKey) != 0 && subtle.ConstantTimeCompare(hdr, api.SCIMAPIKey) == 1 + return len(s.SCIMAPIKey) != 0 && subtle.ConstantTimeCompare(hdr, s.SCIMAPIKey) == 1 } func scimUnauthorized(rw http.ResponseWriter) { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusUnauthorized, "invalidAuthorization", xerrors.New("invalid authorization"))) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusUnauthorized, "invalidAuthorization", xerrors.New("invalid authorization"))) } // scimServiceProviderConfig returns a static SCIM service provider configuration. @@ -50,7 +109,7 @@ func scimUnauthorized(rw http.ResponseWriter) { // @Tags Enterprise // @Success 200 // @Router /scim/v2/ServiceProviderConfig [get] -func (api *API) scimServiceProviderConfig(rw http.ResponseWriter, _ *http.Request) { +func (s *LegacyServer) scimServiceProviderConfig(rw http.ResponseWriter, _ *http.Request) { // No auth needed to query this endpoint. rw.Header().Set("Content-Type", spec.ApplicationScimJson) @@ -60,35 +119,35 @@ func (api *API) scimServiceProviderConfig(rw http.ResponseWriter, _ *http.Reques // Increment this time if you make any changes to the provider config. providerUpdated := time.Date(2024, 10, 25, 17, 0, 0, 0, time.UTC) var location string - locURL, err := api.AccessURL.Parse("/scim/v2/ServiceProviderConfig") + locURL, err := s.AccessURL.Parse("/scim/v2/ServiceProviderConfig") if err == nil { location = locURL.String() } enc := json.NewEncoder(rw) enc.SetEscapeHTML(true) - _ = enc.Encode(scim.ServiceProviderConfig{ + _ = enc.Encode(ServiceProviderConfig{ Schemas: []string{"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"}, DocURI: "https://coder.com/docs/admin/users/oidc-auth#scim", - Patch: scim.Supported{ + Patch: Supported{ Supported: true, }, - Bulk: scim.BulkSupported{ + Bulk: BulkSupported{ Supported: false, }, - Filter: scim.FilterSupported{ + Filter: FilterSupported{ Supported: false, }, - ChangePassword: scim.Supported{ + ChangePassword: Supported{ Supported: false, }, - Sort: scim.Supported{ + Sort: Supported{ Supported: false, }, - ETag: scim.Supported{ + ETag: Supported{ Supported: false, }, - AuthSchemes: []scim.AuthenticationScheme{ + AuthSchemes: []AuthenticationScheme{ { Type: "oauthbearertoken", Name: "HTTP Header Authentication", @@ -96,7 +155,7 @@ func (api *API) scimServiceProviderConfig(rw http.ResponseWriter, _ *http.Reques DocURI: "https://coder.com/docs/admin/users/oidc-auth#scim", }, }, - Meta: scim.ServiceProviderMeta{ + Meta: ServiceProviderMeta{ Created: providerUpdated, LastModified: providerUpdated, Location: location, @@ -118,8 +177,8 @@ func (api *API) scimServiceProviderConfig(rw http.ResponseWriter, _ *http.Reques // @Router /scim/v2/Users [get] // //nolint:revive -func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) { - if !api.scimVerifyAuthHeader(r) { +func (s *LegacyServer) scimGetUsers(rw http.ResponseWriter, r *http.Request) { + if !s.scimVerifyAuthHeader(r) { scimUnauthorized(rw) return } @@ -146,13 +205,13 @@ func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) { // @Router /scim/v2/Users/{id} [get] // //nolint:revive -func (api *API) scimGetUser(rw http.ResponseWriter, r *http.Request) { - if !api.scimVerifyAuthHeader(r) { +func (s *LegacyServer) scimGetUser(rw http.ResponseWriter, r *http.Request) { + if !s.scimVerifyAuthHeader(r) { scimUnauthorized(rw) return } - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("endpoint will always return 404"))) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("endpoint will always return 404"))) } // We currently use our own struct instead of using the SCIM package. This was @@ -193,20 +252,20 @@ var SCIMAuditAdditionalFields = map[string]string{ // @Security Authorization // @Produce json // @Tags Enterprise -// @Param request body coderd.SCIMUser true "New user" -// @Success 200 {object} coderd.SCIMUser +// @Param request body legacyscim.SCIMUser true "New user" +// @Success 200 {object} legacyscim.SCIMUser // @Router /scim/v2/Users [post] -func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { +func (s *LegacyServer) scimPostUser(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - if !api.scimVerifyAuthHeader(r) { + if !s.scimVerifyAuthHeader(r) { scimUnauthorized(rw) return } - auditor := *api.AGPL.Auditor.Load() + auditor := *s.Auditor.Load() aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{ Audit: auditor, - Log: api.Logger, + Log: s.Logger, Request: r, Action: database.AuditActionCreate, AdditionalFields: SCIMAuditAdditionalFields, @@ -216,12 +275,12 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { var sUser SCIMUser err := json.NewDecoder(r.Body).Decode(&sUser) if err != nil { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err)) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusBadRequest, "invalidRequest", err)) return } if sUser.Active == nil { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", xerrors.New("active field is required"))) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusBadRequest, "invalidRequest", xerrors.New("active field is required"))) return } @@ -234,12 +293,12 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { } if email == "" { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidEmail", xerrors.New("no primary email provided"))) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusBadRequest, "invalidEmail", xerrors.New("no primary email provided"))) return } //nolint:gocritic - dbUser, err := api.Database.GetUserByEmailOrUsername(dbauthz.AsSystemRestricted(ctx), database.GetUserByEmailOrUsernameParams{ + dbUser, err := s.Database.GetUserByEmailOrUsername(dbauthz.AsSystemRestricted(ctx), database.GetUserByEmailOrUsernameParams{ Email: email, Username: sUser.UserName, }) @@ -253,7 +312,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { if *sUser.Active && dbUser.Status == database.UserStatusSuspended { //nolint:gocritic - newUser, err := api.Database.UpdateUserStatus(dbauthz.AsSystemRestricted(r.Context()), database.UpdateUserStatusParams{ + newUser, err := s.Database.UpdateUserStatus(dbauthz.AsSystemRestricted(r.Context()), database.UpdateUserStatusParams{ ID: dbUser.ID, // The user will get transitioned to Active after logging in. Status: database.UserStatusDormant, @@ -295,23 +354,23 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { // This is to preserve single org deployment behavior. organizations := []uuid.UUID{} //nolint:gocritic // SCIM operations are a system user - orgSync, err := api.IDPSync.OrganizationSyncSettings(dbauthz.AsSystemRestricted(ctx), api.Database) + orgSync, err := s.IDPSync.OrganizationSyncSettings(dbauthz.AsSystemRestricted(ctx), s.Database) if err != nil { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get organization sync settings: %w", err))) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get organization sync settings: %w", err))) return } if orgSync.AssignDefault { //nolint:gocritic // SCIM operations are a system user - defaultOrganization, err := api.Database.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) + defaultOrganization, err := s.Database.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) if err != nil { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get default organization: %w", err))) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get default organization: %w", err))) return } organizations = append(organizations, defaultOrganization.ID) } //nolint:gocritic // needed for SCIM - dbUser, err = api.AGPL.CreateUser(dbauthz.AsSystemRestricted(ctx), api.Database, agpl.CreateUserRequest{ + dbUser, err = s.AGPL.CreateUser(dbauthz.AsSystemRestricted(ctx), s.Database, agpl.CreateUserRequest{ CreateUserRequestWithOrgs: codersdk.CreateUserRequestWithOrgs{ Username: sUser.UserName, Email: email, @@ -322,7 +381,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { SkipNotifications: true, }) if err != nil { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to create user: %w", err))) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to create user: %w", err))) return } aReq.New = dbUser @@ -342,20 +401,20 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { // @Produce application/scim+json // @Tags Enterprise // @Param id path string true "User ID" format(uuid) -// @Param request body coderd.SCIMUser true "Update user request" +// @Param request body legacyscim.SCIMUser true "Update user request" // @Success 200 {object} codersdk.User // @Router /scim/v2/Users/{id} [patch] -func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { +func (s *LegacyServer) scimPatchUser(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - if !api.scimVerifyAuthHeader(r) { + if !s.scimVerifyAuthHeader(r) { scimUnauthorized(rw) return } - auditor := *api.AGPL.Auditor.Load() + auditor := *s.Auditor.Load() aReq, commitAudit := audit.InitRequestWithCancel[database.User](rw, &audit.RequestParams{ Audit: auditor, - Log: api.Logger, + Log: s.Logger, Request: r, Action: database.AuditActionWrite, }) @@ -367,19 +426,19 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { var sUser SCIMUser err := json.NewDecoder(r.Body).Decode(&sUser) if err != nil { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err)) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusBadRequest, "invalidRequest", err)) return } sUser.ID = id uid, err := uuid.Parse(id) if err != nil { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidId", xerrors.Errorf("id must be a uuid: %w", err))) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusBadRequest, "invalidId", xerrors.Errorf("id must be a uuid: %w", err))) return } //nolint:gocritic // needed for SCIM - dbUser, err := api.Database.GetUserByID(dbauthz.AsSystemRestricted(ctx), uid) + dbUser, err := s.Database.GetUserByID(dbauthz.AsSystemRestricted(ctx), uid) if err != nil { _ = handlerutil.WriteError(rw, err) // internal error return @@ -388,14 +447,14 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { aReq.UserID = dbUser.ID if sUser.Active == nil { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", xerrors.New("active field is required"))) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusBadRequest, "invalidRequest", xerrors.New("active field is required"))) return } newStatus := scimUserStatus(dbUser, *sUser.Active) if dbUser.Status != newStatus { //nolint:gocritic // needed for SCIM - userNew, err := api.Database.UpdateUserStatus(dbauthz.AsSystemRestricted(r.Context()), database.UpdateUserStatusParams{ + userNew, err := s.Database.UpdateUserStatus(dbauthz.AsSystemRestricted(r.Context()), database.UpdateUserStatusParams{ ID: dbUser.ID, Status: newStatus, UpdatedAt: dbtime.Now(), @@ -426,20 +485,20 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { // @Produce application/scim+json // @Tags Enterprise // @Param id path string true "User ID" format(uuid) -// @Param request body coderd.SCIMUser true "Replace user request" +// @Param request body legacyscim.SCIMUser true "Replace user request" // @Success 200 {object} codersdk.User // @Router /scim/v2/Users/{id} [put] -func (api *API) scimPutUser(rw http.ResponseWriter, r *http.Request) { +func (s *LegacyServer) scimPutUser(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - if !api.scimVerifyAuthHeader(r) { + if !s.scimVerifyAuthHeader(r) { scimUnauthorized(rw) return } - auditor := *api.AGPL.Auditor.Load() + auditor := *s.Auditor.Load() aReq, commitAudit := audit.InitRequestWithCancel[database.User](rw, &audit.RequestParams{ Audit: auditor, - Log: api.Logger, + Log: s.Logger, Request: r, Action: database.AuditActionWrite, }) @@ -451,23 +510,23 @@ func (api *API) scimPutUser(rw http.ResponseWriter, r *http.Request) { var sUser SCIMUser err := json.NewDecoder(r.Body).Decode(&sUser) if err != nil { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err)) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusBadRequest, "invalidRequest", err)) return } sUser.ID = id if sUser.Active == nil { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", xerrors.New("active field is required"))) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusBadRequest, "invalidRequest", xerrors.New("active field is required"))) return } uid, err := uuid.Parse(id) if err != nil { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidId", xerrors.Errorf("id must be a uuid: %w", err))) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusBadRequest, "invalidId", xerrors.Errorf("id must be a uuid: %w", err))) return } //nolint:gocritic // needed for SCIM - dbUser, err := api.Database.GetUserByID(dbauthz.AsSystemRestricted(ctx), uid) + dbUser, err := s.Database.GetUserByID(dbauthz.AsSystemRestricted(ctx), uid) if err != nil { _ = handlerutil.WriteError(rw, err) // internal error return @@ -484,14 +543,14 @@ func (api *API) scimPutUser(rw http.ResponseWriter, r *http.Request) { // TODO: Currently ignoring a lot of the SCIM fields. Coder's SCIM implementation // is very basic and only supports active status changes. if immutabilityViolation(dbUser.Username, sUser.UserName) { - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "mutability", xerrors.Errorf("username is currently an immutable field, and cannot be changed. Current: %s, New: %s", dbUser.Username, sUser.UserName))) + _ = handlerutil.WriteError(rw, NewHTTPError(http.StatusBadRequest, "mutability", xerrors.Errorf("username is currently an immutable field, and cannot be changed. Current: %s, New: %s", dbUser.Username, sUser.UserName))) return } newStatus := scimUserStatus(dbUser, *sUser.Active) if dbUser.Status != newStatus { //nolint:gocritic // needed for SCIM - userNew, err := api.Database.UpdateUserStatus(dbauthz.AsSystemRestricted(r.Context()), database.UpdateUserStatusParams{ + userNew, err := s.Database.UpdateUserStatus(dbauthz.AsSystemRestricted(r.Context()), database.UpdateUserStatusParams{ ID: dbUser.ID, Status: newStatus, UpdatedAt: dbtime.Now(), diff --git a/enterprise/coderd/scim/scimtypes.go b/enterprise/coderd/legacyscim/scimtypes.go similarity index 99% rename from enterprise/coderd/scim/scimtypes.go rename to enterprise/coderd/legacyscim/scimtypes.go index 39e022aa24..c96044befb 100644 --- a/enterprise/coderd/scim/scimtypes.go +++ b/enterprise/coderd/legacyscim/scimtypes.go @@ -1,4 +1,4 @@ -package scim +package legacyscim import ( "encoding/json" diff --git a/enterprise/coderd/scim/expression.go b/enterprise/coderd/scim/expression.go new file mode 100644 index 0000000000..516f6d325f --- /dev/null +++ b/enterprise/coderd/scim/expression.go @@ -0,0 +1,39 @@ +package scim + +import ( + "github.com/scim2/filter-parser/v2" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" +) + +// userQuery only supports queries of a singular attribute expression. +// Everything else is rejected. Okta just uses username. +// Eg: username eq "alice" +func userQuery(expr filter.Expression) (database.GetUsersParams, error) { + if expr == nil { + return database.GetUsersParams{}, nil + } + + attrExpr, ok := expr.(*filter.AttributeExpression) + if !ok { + return database.GetUsersParams{}, xerrors.Errorf("expected attribute expression") + } + + attrValue, ok := attrExpr.CompareValue.(string) + if !ok { + return database.GetUsersParams{}, xerrors.Errorf("expected string compare value") + } + + var getUsers database.GetUsersParams + switch attrExpr.AttributePath.AttributeName { + case "userName": + getUsers.ExactUsername = attrValue + case "email": + getUsers.ExactEmail = attrValue + default: + return database.GetUsersParams{}, xerrors.Errorf("unsupported filter attribute: %s", attrExpr.AttributePath.AttributeName) + } + + return getUsers, nil +} diff --git a/enterprise/coderd/scim/scim.go b/enterprise/coderd/scim/scim.go new file mode 100644 index 0000000000..2ef19c1b19 --- /dev/null +++ b/enterprise/coderd/scim/scim.go @@ -0,0 +1,138 @@ +package scim + +import ( + "bytes" + "crypto/subtle" + "encoding/json" + "net/http" + "sync/atomic" + + "github.com/elimity-com/scim" + scimErrors "github.com/elimity-com/scim/errors" + "github.com/elimity-com/scim/optional" + "github.com/elimity-com/scim/schema" + + "cdr.dev/slog/v3" + agpl "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/idpsync" +) + +// Handler wraps the elimity-com/scim library's Server to implement +// SCIM 2.0 endpoints. The library auto-serves /Schemas, /ResourceTypes, +// and /ServiceProviderConfig from schema definitions. +type Handler struct { + opts *Options + srv *scim.Server +} + +// Options holds all the dependencies needed by SCIM resource handlers. +type Options struct { + DB database.Store + Auditor *atomic.Pointer[audit.Auditor] + IDPSync idpsync.IDPSync + Logger slog.Logger + + // AGPL is needed for CreateUser. + AGPL *agpl.API + + // SCIMAPIKey is the bearer token used to authenticate SCIM requests. + SCIMAPIKey []byte +} + +func New(opts *Options) (*Handler, error) { + userHandler := &ResourceUser{ + store: opts.DB, + opts: opts, + } + + args := &scim.ServerArgs{ + ServiceProviderConfig: &scim.ServiceProviderConfig{ + DocumentationURI: optional.NewString("https://coder.com/docs/admin/users/oidc-auth#scim"), + AuthenticationSchemes: []scim.AuthenticationScheme{ + { + Type: scim.AuthenticationTypeOauthBearerToken, + Name: "HTTP Header Authentication", + Description: "Authentication scheme using the Authorization header with the shared token", + // TODO: Add documentation links for these specific docs once they exist. + SpecURI: optional.String{}, + DocumentationURI: optional.String{}, + Primary: true, + }, + }, + MaxResults: 0, + // SupportFiltering is set to false, as all filtering operations are not + // supported. A minimal filtering syntax is supported because Okta seems to + // ignore this field and attempt to filter anyway. + SupportFiltering: false, + SupportPatch: true, + }, + ResourceTypes: []scim.ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Description: optional.NewString("User Account"), + Endpoint: "/Users", + Schema: schema.CoreUserSchema(), + Handler: userHandler, + SchemaExtensions: nil, + }, + }, + } + + srv, err := scim.NewServer(args) + if err != nil { + return nil, err + } + + return &Handler{ + opts: opts, + srv: &srv, + }, nil +} + +func (s *Handler) authMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if !s.verifyAuthHeader(r) { + scimUnauthorized(rw) + return + } + + // All authenticated requests are treated as coming from the SCIM provisioner + //nolint:gocritic // auth header authenticates as this identity + ctx := dbauthz.AsSCIMProvisioner(r.Context()) + r = r.WithContext(ctx) + + next.ServeHTTP(rw, r) + }) +} + +func (s *Handler) Handler() http.Handler { + return s.authMiddleware(s.srv) +} + +func (s *Handler) verifyAuthHeader(r *http.Request) bool { + bearer := []byte("bearer ") + hdr := []byte(r.Header.Get("Authorization")) + + // Case-insensitive comparison of the "Bearer " prefix. + if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(bytes.ToLower(hdr[:len(bearer)]), bearer) == 1 { + hdr = hdr[len(bearer):] + } + + return len(s.opts.SCIMAPIKey) != 0 && subtle.ConstantTimeCompare(hdr, s.opts.SCIMAPIKey) == 1 +} + +func scimUnauthorized(rw http.ResponseWriter) { + rw.Header().Set("Content-Type", "application/scim+json") + rw.WriteHeader(http.StatusUnauthorized) + // scim error spec: + // https://datatracker.ietf.org/doc/html/rfc7644#section-3.12 + _ = json.NewEncoder(rw).Encode(scimErrors.ScimError{ + ScimType: "", // No scimType exists for unauthorized errors. + Detail: "invalid authorization", + Status: http.StatusUnauthorized, + }) +} diff --git a/enterprise/coderd/scim/users.go b/enterprise/coderd/scim/users.go new file mode 100644 index 0000000000..57d7436b71 --- /dev/null +++ b/enterprise/coderd/scim/users.go @@ -0,0 +1,588 @@ +package scim + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/elimity-com/scim" + scimErrors "github.com/elimity-com/scim/errors" + "github.com/elimity-com/scim/optional" + "github.com/google/uuid" + "golang.org/x/xerrors" + + agpl "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/codersdk" +) + +var _ scim.ResourceHandler = (*ResourceUser)(nil) + +// auditUser emits an audit log for a SCIM operation. This uses +// BackgroundAudit instead of InitRequest because the elimity-com/scim +// library owns the http.ResponseWriter and does not expose it to +// resource handlers. +func (ru *ResourceUser) auditUser(ctx context.Context, r *http.Request, action database.AuditAction, old, changed database.User) { + raw, _ := json.Marshal(map[string]string{ + "automatic_actor": "coder", + "automatic_subsystem": "scim", + }) + auditor := *ru.opts.Auditor.Load() + + // This is a best effort + // TODO: Check X-Forwarded-For and others for proxied requests + ip := r.RemoteAddr + + audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.User]{ + Audit: auditor, + Log: ru.opts.Logger, + UserID: uuid.Nil, // SCIM provisioner, not a real user + Action: action, + Old: old, + New: changed, + IP: ip, + UserAgent: r.UserAgent(), + AdditionalFields: raw, + Status: http.StatusOK, + }) +} + +type ResourceUser struct { + store database.Store + opts *Options +} + +// Create implements scim.ResourceHandler. Creates a new Coder user from +// SCIM attributes, or returns the existing user if a duplicate is found. +func (ru *ResourceUser) Create(r *http.Request, attributes scim.ResourceAttributes) (scim.Resource, error) { + ctx := r.Context() + + // Extract fields from the SCIM attributes. + // Do our best to match what the OIDC signup flow also does. + username, _ := attributeAsString(attributes, "userName") + email := primaryEmail(attributes) + if email == "" { + // email is required + return scim.Resource{}, scimErrors.ScimErrorBadRequest("no primary email provided") + } + + // This comes from userOIDC + // TODO: Ideally this code would be shared between the two places. + usernameValidErr := codersdk.NameValid(username) + if usernameValidErr != nil { + if username == "" { + username = email + } + username = codersdk.UsernameFrom(username) + } + + // TODO: OIDC has optional configuration like `EmailDomain` to reject emails outside a specific domain. + // We should consider whether we want to support that for SCIM as well, and if so, apply that validation here. + + active := true + if a, ok := attribute(attributes, "active"); ok { + v, err := booleanValue(a) + if err != nil { + return scim.Resource{}, scimErrors.ScimErrorBadRequest( + fmt.Sprintf("invalid boolean value for 'active' field: %v", a)) + } + active = v + } + + // Check for existing user by email or username. + dbUser, err := ru.store.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{ + Email: email, + Username: username, + }) + if err == nil { + // SCIM spec says to return a StatusConflict if the user already exists. + // However, Coder never deletes a user. So suspended **is** deleted. + // If the user is not suspended, we return a conflict. + if dbUser.Status != database.UserStatusSuspended { + return scim.Resource{}, scimErrors.ScimError{ + ScimType: scimErrors.ScimTypeUniqueness, + Detail: fmt.Sprintf("user already exists with email %q or username %q", email, username), + Status: http.StatusConflict, + } + } + + // If the user is suspended, then they might be deleted on the SCIM side. + // We can just update their status and return the user as they exist. + status := scimUserStatus(dbUser, &active) + dbUser, err = ru.updateUserStatus(ctx, r, dbUser, status) + if err != nil { + return scim.Resource{}, err + } + return userResource(dbUser), nil + } + + if !xerrors.Is(err, sql.ErrNoRows) { + // Internal DB errors should be returned. + // ErrNoRows is expected if the user does not exist. + return scim.Resource{}, err + } + + // OIDC login runs org, group, and role sync. SCIM does not have (or not yet) these + // claims. We only need to sync the default organization if that is enabled. + // + // When the user eventually logs in via OIDC, the regular sync will run. + // However, since org sync can be disabled. We need to assign the default org if + // that is how we are configured. + organizations := []uuid.UUID{} + orgSync, err := ru.opts.IDPSync.OrganizationSyncSettings(ctx, ru.store) + if err != nil { + return scim.Resource{}, xerrors.Errorf("get organization sync settings: %w", err) + } + if orgSync.AssignDefault { + // Technically, we could just always assign this. When they eventually log in, + // the org would be removed if necessary. But to avoid confusion of the user + // being in the org before they log in, we apply some intelligence to this guess + // of "Do they belong in the default org". + defaultOrganization, err := ru.store.GetDefaultOrganization(ctx) + if err != nil { + return scim.Resource{}, xerrors.Errorf("get default organization: %w", err) + } + organizations = append(organizations, defaultOrganization.ID) + } + + // CreateUser does InsertOrganizationMember internally, and InsertUser + // implicitly assigns the member role at site scope. The SCIM provisioner + // role cannot assign either, so escalate to a system context for this + // specific call, matching the legacy SCIM handler. + //nolint:gocritic // SCIM bearer token authenticates as the SCIM provisioner; user creation needs broader rights to assign default roles. + dbUser, err = ru.opts.AGPL.CreateUser(dbauthz.AsSystemRestricted(ctx), ru.store, agpl.CreateUserRequest{ + CreateUserRequestWithOrgs: codersdk.CreateUserRequestWithOrgs{ + Username: username, + Email: email, + OrganizationIDs: organizations, + }, + LoginType: database.LoginTypeOIDC, + // Do not send notifications to user admins; SCIM may call this + // sequentially for many users. + // TODO: Maybe we should spam them anyway? + SkipNotifications: true, + }) + if err != nil { + return scim.Resource{}, xerrors.Errorf("create user: %w", err) + } + + ru.auditUser(ctx, r, database.AuditActionCreate, database.User{}, dbUser) + return userResource(dbUser), nil +} + +// Get implements scim.ResourceHandler. Returns a single user by ID. +func (ru *ResourceUser) Get(r *http.Request, idStr string) (scim.Resource, error) { + ctx := r.Context() + usr, err := ru.user(ctx, idStr) + if err != nil { + return scim.Resource{}, err + } + + return userResource(usr), nil +} + +// GetAll implements scim.ResourceHandler. Returns a paginated list of users. +func (ru *ResourceUser) GetAll(r *http.Request, params scim.ListRequestParams) (scim.Page, error) { + ctx := r.Context() + + var qry database.GetUsersParams + if params.FilterValidator != nil { + var err error + qry, err = userQuery(params.FilterValidator.GetFilter()) + if err != nil { + return scim.Page{}, scimErrors.ScimErrorBadRequest(fmt.Sprintf("invalid filter: %v", err)) + } + } + + qry.LimitOpt = int32(params.Count) //nolint:gosec + qry.OffsetOpt = int32(params.StartIndex - 1) //nolint:gosec + + if qry.LimitOpt < 0 { + qry.LimitOpt = 100 + } + + users, err := ru.store.GetUsers(ctx, qry) + if err != nil { + return scim.Page{}, err + } + + totalCount := int64(len(users)) + if len(users) == int(qry.LimitOpt) { + // If the limit is not reached, that is the count + // TODO: If there is a query and the limit is reached, this is inaccurate. + totalCount, err = ru.store.GetUserCount(ctx, false) + if err != nil { + return scim.Page{}, err + } + } + + resources := make([]scim.Resource, 0, len(users)) + for _, u := range users { + resources = append(resources, userResourceFromGetUsersRow(u)) + } + + return scim.Page{ + TotalResults: int(totalCount), + Resources: resources, + }, nil +} + +// Replace implements scim.ResourceHandler (PUT). Replaces user attributes. +// Currently only supports changing the active status per existing behavior. +func (ru *ResourceUser) Replace(r *http.Request, idStr string, attributes scim.ResourceAttributes) (scim.Resource, error) { + ctx := r.Context() + + dbUser, err := ru.user(ctx, idStr) + if err != nil { + return scim.Resource{}, err + } + + // All of our fields except for active are immutable. + if !attributeEqual(dbUser.Username, attributes, "userName") { + return scim.Resource{}, scimErrors.ScimErrorBadRequest(fmt.Sprintf("changing the 'userName' field is not supported (current value: %q)", dbUser.Username)) + } + + // TODO: Check if the primary email has changed. If it has, should we do something? + + activeInterface, ok := attribute(attributes, "active") + if !ok { + return scim.Resource{}, scimErrors.ScimErrorBadRequest("missing required 'active' field") + } + + active, err := booleanValue(activeInterface) + if err != nil { + return scim.Resource{}, scimErrors.ScimErrorBadRequest(fmt.Sprintf("invalid boolean value for 'active' field: %v", activeInterface)) + } + + newStatus := scimUserStatus(dbUser, &active) + dbUser, err = ru.updateUserStatus(ctx, r, dbUser, newStatus) + if err != nil { + return scim.Resource{}, err + } + + return userResource(dbUser), nil +} + +// Delete implements scim.ResourceHandler. Suspends the user (Coder does +// not hard-delete users). +func (ru *ResourceUser) Delete(r *http.Request, idStr string) error { + ctx := r.Context() + + dbUser, err := ru.user(ctx, idStr) + if err != nil { + return err + } + + _, err = ru.updateUserStatus(ctx, r, dbUser, database.UserStatusSuspended) + if err != nil { + return err + } + + return nil +} + +// Patch implements scim.ResourceHandler. Updates user attributes based on +// SCIM PatchOp operations. Currently, supports changing the active status. +func (ru *ResourceUser) Patch(r *http.Request, idStr string, operations []scim.PatchOperation) (scim.Resource, error) { + ctx := r.Context() + + uid, err := uuid.Parse(idStr) + if err != nil { + return scim.Resource{}, badUUID(idStr, err) + } + + dbUser, err := ru.store.GetUserByID(ctx, uid) + if err != nil { + if xerrors.Is(err, sql.ErrNoRows) { + return scim.Resource{}, scimErrors.ScimErrorResourceNotFound(idStr) + } + return scim.Resource{}, err + } + + // Process operations. Currently, we only handle the "active" attribute. + var activeSet *bool + for _, op := range operations { + switch op.Op { + case "add": + // TODO: Currently we do not support the adding of attributes. + case "remove": + // TODO: If the path is unspecified, we should fail with the status code 400. + // Today, we only accept the 'active' field and silently drop the rest. + if op.Path != nil && strings.EqualFold(op.Path.String(), "active") { + activeSet = ptr.Ref(false) + } + case "replace": + // TODO: Honor mutability rules of fields like `userName` and `email`. + // Should scim be able to change those fields? + + // SCIM PATCH replace can come in two forms: + // 1. Path set: {"op":"replace","path":"active","value":false} + // 2. No path, value is a map: {"op":"replace","value":{"active":false}} + if op.Path != nil && strings.EqualFold(op.Path.String(), "active") { + v, err := booleanValue(op.Value) + if err != nil { + return scim.Resource{}, scimErrors.ScimErrorBadRequest(fmt.Sprintf("invalid boolean value for 'active' field: %v", op.Value)) + } + activeSet = &v + } else if m, ok := op.Value.(map[string]interface{}); ok { + if actV, ok := attribute(m, "active"); ok { + v, err := booleanValue(actV) + if err != nil { + return scim.Resource{}, scimErrors.ScimErrorBadRequest(fmt.Sprintf("invalid boolean value for 'active' field: %v", actV)) + } + activeSet = &v + } + } + default: + } + } + + newStatus := scimUserStatus(dbUser, activeSet) + dbUser, err = ru.updateUserStatus(ctx, r, dbUser, newStatus) + if err != nil { + return scim.Resource{}, err + } + + return userResource(dbUser), nil +} + +func (ru *ResourceUser) user(ctx context.Context, idStr string) (database.User, error) { + id, err := uuid.Parse(idStr) + if err != nil { + return database.User{}, badUUID(idStr, err) + } + + usr, err := ru.store.GetUserByID(ctx, id) + if err != nil { + if xerrors.Is(err, sql.ErrNoRows) { + return database.User{}, scimErrors.ScimErrorResourceNotFound(idStr) + } + return database.User{}, err + } + + return usr, nil +} + +// updateUserStatus is a no-op if the status did not change. +func (ru *ResourceUser) updateUserStatus(ctx context.Context, r *http.Request, u database.User, status database.UserStatus) (database.User, error) { + if u.Status == status { + return u, nil + } + newUser, err := ru.store.UpdateUserStatus(ctx, database.UpdateUserStatusParams{ + ID: u.ID, Status: status, UpdatedAt: dbtime.Now(), UserIsSeen: false, + }) + if err != nil { + return database.User{}, err + } + ru.auditUser(ctx, r, database.AuditActionWrite, u, newUser) + return newUser, nil +} + +// scimUserStatus maps the SCIM "active" boolean to Coder's internal user status. +// It preserves the active/dormant distinction: active users stay active, +// dormant or suspended users become dormant when re-activated (they become +// active after their next login). +// +//nolint:revive // active is not a control flag +func scimUserStatus(user database.User, active *bool) database.UserStatus { + if active == nil { + return user.Status + } + + if !(*active) { + // SCIM "active: false" means the user should be suspended + return database.UserStatusSuspended + } + + switch user.Status { + case database.UserStatusActive: + // Active users stay active + return database.UserStatusActive + case database.UserStatusDormant, database.UserStatusSuspended: + // Dormant or suspended users become dormant when re-activated + // The user can then become active by doing something in the product. + return database.UserStatusDormant + default: + return database.UserStatusDormant + } +} + +// userResource converts a database.User into a SCIM Resource. +func userResource(u database.User) scim.Resource { + return scim.Resource{ + ID: u.ID.String(), + ExternalID: optional.String{}, + Attributes: scim.ResourceAttributes{ + "userName": u.Username, + "name": map[string]interface{}{ + "formatted": u.Name, + }, + "emails": []map[string]interface{}{ + { + "primary": true, + "value": u.Email, + }, + }, + "active": u.Status == database.UserStatusActive || + u.Status == database.UserStatusDormant, + }, + Meta: scim.Meta{ + Created: &u.CreatedAt, + LastModified: &u.UpdatedAt, + }, + } +} + +// userResourceFromGetUsersRow converts a database.GetUsersRow into a SCIM Resource. +func userResourceFromGetUsersRow(u database.GetUsersRow) scim.Resource { + return scim.Resource{ + ID: u.ID.String(), + ExternalID: optional.String{}, + Attributes: scim.ResourceAttributes{ + "userName": u.Username, + "name": map[string]interface{}{ + "formatted": u.Name, + }, + "emails": []map[string]interface{}{ + { + "primary": true, + "value": u.Email, + }, + }, + "active": u.Status == database.UserStatusActive || + u.Status == database.UserStatusDormant, + }, + Meta: scim.Meta{ + Created: &u.CreatedAt, + LastModified: &u.UpdatedAt, + }, + } +} + +func attributeAsBool(attrs scim.ResourceAttributes, key string) (value bool, exists bool) { + val, ok := attribute(attrs, key) + if !ok { + return false, false + } + + switch v := val.(type) { + case string: + pv, err := strconv.ParseBool(v) + return pv, err == nil + case bool: + return v, true + default: + return false, false + } +} + +func attributeAsString(attrs scim.ResourceAttributes, key string) (string, bool) { + val, ok := attribute(attrs, key) + if !ok { + return "", false + } + + switch v := val.(type) { + case string: + return v, true + case bool: + return strconv.FormatBool(v), true + default: + return "", false + } +} + +func attribute(attrs scim.ResourceAttributes, key string) (interface{}, bool) { + // attribute names are case-insensitive per SCIM spec + val, ok := attrs[key] + if ok { + return val, true + } + + // This is terrible, but we need to iterate the map to find the key in a case-insensitive way. + // The scim Spec says attribute names are case-insensitive. + for k, v := range attrs { + if k == key { + return v, true + } + if len(k) == len(key) && strings.EqualFold(k, key) { + return v, true + } + } + + return nil, false +} + +// badUUID returns a 404 not-found error for non-UUID identifiers. +// SCIM clients may send arbitrary strings as IDs; returning 404 +// (rather than 400) signals that no resource matches. +func badUUID(idStr string, _ error) scimErrors.ScimError { + return scimErrors.ScimError{ + Detail: fmt.Sprintf("%q is not a valid uuid; resource not found", idStr), + Status: http.StatusNotFound, + } +} + +func booleanValue(v interface{}) (bool, error) { + switch b := v.(type) { + case bool: + return b, nil + case string: + return strconv.ParseBool(b) + default: + return false, xerrors.Errorf("expected boolean or string value, got %T", v) + } +} + +func attributeEqual[T comparable](existing T, attrs scim.ResourceAttributes, key string) bool { + found, ok := attribute(attrs, key) + if !ok { + return true // No change if the attribute is not present in the request + } + + sameType, ok := found.(T) + if !ok { + return false // Type mismatch, consider it a change + } + + return existing == sameType +} + +// primaryEmail extracts the primary email from SCIM resource attributes. +func primaryEmail(attributes scim.ResourceAttributes) string { + emailsRaw, ok := attribute(attributes, "emails") + if !ok { + return "" + } + + emails, ok := emailsRaw.([]interface{}) + if !ok { + return "" + } + + var fallback string + for _, e := range emails { + emailMap, ok := e.(map[string]interface{}) + if !ok { + continue + } + val, ok := attributeAsString(emailMap, "value") + if !ok { + continue + } + if primary, _ := attributeAsBool(emailMap, "primary"); primary { + return val + } + fallback = val + } + + return fallback +} diff --git a/enterprise/coderd/scim/users_internal_test.go b/enterprise/coderd/scim/users_internal_test.go new file mode 100644 index 0000000000..b95e0a361f --- /dev/null +++ b/enterprise/coderd/scim/users_internal_test.go @@ -0,0 +1,760 @@ +package scim + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/elimity-com/scim" + scimErrors "github.com/elimity-com/scim/errors" + "github.com/google/uuid" + filter "github.com/scim2/filter-parser/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtestutil" +) + +// setupSCIM creates a ResourceUser backed by a real database for testing. +// The returned mock auditor can be inspected for emitted audit logs. +func setupSCIM(t *testing.T) (*ResourceUser, database.Store, *audit.MockAuditor) { + t.Helper() + + db, _ := dbtestutil.NewDB(t) + mockAudit := audit.NewMock() + auditorPtr := atomic.Pointer[audit.Auditor]{} + var a audit.Auditor = mockAudit + auditorPtr.Store(&a) + + ru := &ResourceUser{ + store: db, + opts: &Options{ + DB: db, + Auditor: &auditorPtr, + Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug), + }, + } + return ru, db, mockAudit +} + +// scimRequest builds an *http.Request with scim provisioner context, +// simulating the auth context that the SCIM middleware normally sets. +func scimRequest(t *testing.T) *http.Request { + t.Helper() + ctx := dbauthz.AsSCIMProvisioner(context.Background()) + return httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) +} + +// seedUser creates a user in the database for testing. +func seedUser(t *testing.T, db database.Store, opts database.User) database.User { + t.Helper() + return dbgen.User(t, db, opts) +} + +// setupSCIMMock creates a ResourceUser backed by a gomock store for tests +// that only need to verify call patterns (e.g. audit emission) without +// real SQL. +func setupSCIMMock(t *testing.T) (*ResourceUser, *dbmock.MockStore, *audit.MockAuditor) { + t.Helper() + + ctrl := gomock.NewController(t) + mockStore := dbmock.NewMockStore(ctrl) + mockAudit := audit.NewMock() + auditorPtr := atomic.Pointer[audit.Auditor]{} + var a audit.Auditor = mockAudit + auditorPtr.Store(&a) + + ru := &ResourceUser{ + store: mockStore, + opts: &Options{ + DB: mockStore, + Auditor: &auditorPtr, + Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug), + }, + } + return ru, mockStore, mockAudit +} + +// --- Pure function tests (no DB) --- + +func TestScimUserStatus(t *testing.T) { + t.Parallel() + + boolPtr := func(b bool) *bool { return &b } + + tests := []struct { + name string + status database.UserStatus + active *bool + expected database.UserStatus + }{ + {"active+true=active", database.UserStatusActive, boolPtr(true), database.UserStatusActive}, + {"active+false=suspended", database.UserStatusActive, boolPtr(false), database.UserStatusSuspended}, + {"suspended+true=dormant", database.UserStatusSuspended, boolPtr(true), database.UserStatusDormant}, + {"suspended+false=suspended", database.UserStatusSuspended, boolPtr(false), database.UserStatusSuspended}, + {"dormant+true=dormant", database.UserStatusDormant, boolPtr(true), database.UserStatusDormant}, + {"dormant+false=suspended", database.UserStatusDormant, boolPtr(false), database.UserStatusSuspended}, + {"active+nil=active", database.UserStatusActive, nil, database.UserStatusActive}, + {"suspended+nil=suspended", database.UserStatusSuspended, nil, database.UserStatusSuspended}, + {"dormant+nil=dormant", database.UserStatusDormant, nil, database.UserStatusDormant}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + user := database.User{Status: tt.status} + got := scimUserStatus(user, tt.active) + assert.Equal(t, tt.expected, got) + }) + } +} + +func TestPrimaryEmail(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + attrs scim.ResourceAttributes + expected string + }{ + { + name: "primary email", + attrs: scim.ResourceAttributes{ + "emails": []interface{}{ + map[string]interface{}{"value": "a@b.com", "primary": true}, + }, + }, + expected: "a@b.com", + }, + { + name: "fallback to first when no primary", + attrs: scim.ResourceAttributes{ + "emails": []interface{}{ + map[string]interface{}{"value": "first@b.com"}, + }, + }, + expected: "first@b.com", + }, + { + name: "picks primary over first", + attrs: scim.ResourceAttributes{ + "emails": []interface{}{ + map[string]interface{}{"value": "first@b.com"}, + map[string]interface{}{"value": "primary@b.com", "primary": true}, + }, + }, + expected: "primary@b.com", + }, + { + name: "polluted", + attrs: scim.ResourceAttributes{ + "emails": []interface{}{ + // Try and cause a panic + "not-a-map", + true, + 7, + map[int]interface{}{ + 1: "bad", + }, + map[string]interface{}{ + "value": 123, // value is not a string + }, + map[string]interface{}{}, + map[string]interface{}{"value": "first@b.com"}, + map[string]interface{}{"value": "primary@b.com", "primary": true}, + }, + }, + expected: "primary@b.com", + }, + { + name: "no emails key", + attrs: scim.ResourceAttributes{}, + expected: "", + }, + { + name: "empty emails", + attrs: scim.ResourceAttributes{"emails": []interface{}{}}, + expected: "", + }, + { + name: "wrong type", + attrs: scim.ResourceAttributes{"emails": "not-a-list"}, + expected: "", + }, + { + name: "case-insensitive top-level key", + attrs: scim.ResourceAttributes{ + "Emails": []interface{}{ + map[string]interface{}{"value": "a@b.com", "primary": true}, + }, + }, + expected: "a@b.com", + }, + { + name: "case-insensitive inner keys", + attrs: scim.ResourceAttributes{ + "emails": []interface{}{ + map[string]interface{}{"Value": "a@b.com", "Primary": true}, + }, + }, + expected: "a@b.com", + }, + { + name: "all caps keys", + attrs: scim.ResourceAttributes{ + "EMAILS": []interface{}{ + map[string]interface{}{"VALUE": "a@b.com", "PRIMARY": true}, + }, + }, + expected: "a@b.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := primaryEmail(tt.attrs) + assert.Equal(t, tt.expected, got) + }) + } +} + +func TestBooleanValue(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input interface{} + want bool + wantErr bool + }{ + {"bool true", true, true, false}, + {"bool false", false, false, false}, + {"string true", "true", true, false}, + {"string false", "false", false, false}, + {"string True", "True", true, false}, + {"int", 42, false, true}, + {"nil", nil, false, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := booleanValue(tt.input) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestAttribute(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + attrs scim.ResourceAttributes + key string + wantVal interface{} + wantOK bool + }{ + {"exact match", scim.ResourceAttributes{"active": true}, "active", true, true}, + {"capital first", scim.ResourceAttributes{"active": true}, "Active", true, true}, + {"all caps", scim.ResourceAttributes{"active": true}, "ACTIVE", true, true}, + {"camelCase key", scim.ResourceAttributes{"userName": "alice"}, "username", "alice", true}, + {"camelCase swapped", scim.ResourceAttributes{"username": "alice"}, "userName", "alice", true}, + {"missing key", scim.ResourceAttributes{"active": true}, "missing", nil, false}, + {"empty map", scim.ResourceAttributes{}, "active", nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + val, ok := attribute(tt.attrs, tt.key) + assert.Equal(t, tt.wantOK, ok) + assert.Equal(t, tt.wantVal, val) + }) + } +} + +func TestAttributeAsBool(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + attrs scim.ResourceAttributes + key string + want bool + wantOK bool + }{ + {"exact key bool", scim.ResourceAttributes{"active": true}, "active", true, true}, + {"mixed case bool", scim.ResourceAttributes{"active": false}, "Active", false, true}, + {"all caps bool", scim.ResourceAttributes{"active": true}, "ACTIVE", true, true}, + {"mixed case string true", scim.ResourceAttributes{"active": "true"}, "Active", true, true}, + {"mixed case string false", scim.ResourceAttributes{"active": "false"}, "ACTIVE", false, true}, + {"missing key", scim.ResourceAttributes{}, "active", false, false}, + {"non-convertible", scim.ResourceAttributes{"active": 42}, "active", false, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, ok := attributeAsBool(tt.attrs, tt.key) + assert.Equal(t, tt.wantOK, ok) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestAttributeAsString(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + attrs scim.ResourceAttributes + key string + want string + wantOK bool + }{ + {"exact key string", scim.ResourceAttributes{"userName": "alice"}, "userName", "alice", true}, + {"mixed case string", scim.ResourceAttributes{"userName": "alice"}, "UserName", "alice", true}, + {"lower case lookup", scim.ResourceAttributes{"userName": "alice"}, "username", "alice", true}, + {"bool to string", scim.ResourceAttributes{"active": true}, "active", "true", true}, + {"mixed case bool to string", scim.ResourceAttributes{"active": false}, "Active", "false", true}, + {"missing key", scim.ResourceAttributes{}, "userName", "", false}, + {"non-convertible", scim.ResourceAttributes{"count": 42}, "count", "", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, ok := attributeAsString(tt.attrs, tt.key) + assert.Equal(t, tt.wantOK, ok) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestAttributeEqual(t *testing.T) { + t.Parallel() + + t.Run("exact match same value", func(t *testing.T) { + t.Parallel() + attrs := scim.ResourceAttributes{"userName": "alice"} + assert.True(t, attributeEqual("alice", attrs, "userName")) + }) + + t.Run("mixed case same value", func(t *testing.T) { + t.Parallel() + attrs := scim.ResourceAttributes{"userName": "alice"} + assert.True(t, attributeEqual("alice", attrs, "UserName")) + }) + + t.Run("mixed case different value", func(t *testing.T) { + t.Parallel() + attrs := scim.ResourceAttributes{"userName": "bob"} + assert.False(t, attributeEqual("alice", attrs, "USERNAME")) + }) + + t.Run("missing key means no change", func(t *testing.T) { + t.Parallel() + attrs := scim.ResourceAttributes{} + assert.True(t, attributeEqual("alice", attrs, "userName")) + }) + + t.Run("type mismatch", func(t *testing.T) { + t.Parallel() + attrs := scim.ResourceAttributes{"userName": 42} + assert.False(t, attributeEqual("alice", attrs, "userName")) + }) +} + +// --- Handler tests (with DB) --- + +func TestResourceUser_CaseInsensitive(t *testing.T) { + t.Parallel() + + ru, db, _ := setupSCIM(t) + + // Seed an active user. + user := seedUser(t, db, database.User{ + Status: database.UserStatusActive, + LoginType: database.LoginTypeOIDC, + }) + + r := scimRequest(t) + + // Replace with "Active" (capital A) instead of "active". + res, err := ru.Replace(r, user.ID.String(), scim.ResourceAttributes{ + "userName": user.Username, + "Active": false, + }) + require.NoError(t, err) + assert.Equal(t, false, res.Attributes["active"]) + + // Confirm suspended via Get. + res, err = ru.Get(r, user.ID.String()) + require.NoError(t, err) + assert.Equal(t, false, res.Attributes["active"]) + + // Patch back with map-style replace using "Active" key. + res, err = ru.Patch(r, user.ID.String(), []scim.PatchOperation{ + {Op: "replace", Value: map[string]interface{}{"Active": true}}, + }) + require.NoError(t, err) + assert.Equal(t, true, res.Attributes["active"]) + + // Confirm reactivated via Get. + res, err = ru.Get(r, user.ID.String()) + require.NoError(t, err) + assert.Equal(t, true, res.Attributes["active"]) +} + +func TestResourceUser_Create(t *testing.T) { + t.Parallel() + + // Coder does not hard-delete users. A SCIM Delete suspends the user, so + // when an IdP later re-creates the same user, the handler should match + // them by email/username and reactivate the existing row instead of + // returning 409 Conflict. See commit b3e6e0aa06. + + t.Run("duplicate-active-conflict", func(t *testing.T) { + t.Parallel() + ru, db, _ := setupSCIM(t) + + existing := seedUser(t, db, database.User{ + Status: database.UserStatusActive, + LoginType: database.LoginTypeOIDC, + }) + + _, err := ru.Create(scimRequest(t), scim.ResourceAttributes{ + "userName": existing.Username, + "emails": []interface{}{ + map[string]interface{}{"value": existing.Email, "primary": true}, + }, + "active": true, + }) + require.Error(t, err) + var scimErr scimErrors.ScimError + require.ErrorAs(t, err, &scimErr) + assert.Equal(t, http.StatusConflict, scimErr.Status) + }) + + t.Run("suspended-user-reactivates", func(t *testing.T) { + t.Parallel() + ru, db, mockAudit := setupSCIM(t) + + existing := seedUser(t, db, database.User{ + Status: database.UserStatusSuspended, + LoginType: database.LoginTypeOIDC, + }) + + res, err := ru.Create(scimRequest(t), scim.ResourceAttributes{ + "userName": existing.Username, + "emails": []interface{}{ + map[string]interface{}{"value": existing.Email, "primary": true}, + }, + "active": true, + }) + require.NoError(t, err) + assert.Equal(t, existing.ID.String(), res.ID, "response should reference the existing user, not a new one") + + // The SCIM response must reflect the post-update state so the IdP + // sees active=true after the recreate. + assert.Equal(t, true, res.Attributes["active"], "response should report the reactivated state") + + // Suspended + active=true reactivates to Dormant (not Active) per scimUserStatus. + got, err := db.GetUserByID(dbauthz.AsSCIMProvisioner(context.Background()), existing.ID) + require.NoError(t, err) + assert.Equal(t, database.UserStatusDormant, got.Status, "suspended user should be marked dormant on recreate") + + // Reactivation should emit one audit log for the status change. + assert.Len(t, mockAudit.AuditLogs(), 1) + }) + + t.Run("suspended-user-stays-suspended-when-active-false", func(t *testing.T) { + t.Parallel() + ru, db, mockAudit := setupSCIM(t) + + existing := seedUser(t, db, database.User{ + Status: database.UserStatusSuspended, + LoginType: database.LoginTypeOIDC, + }) + + res, err := ru.Create(scimRequest(t), scim.ResourceAttributes{ + "userName": existing.Username, + "emails": []interface{}{ + map[string]interface{}{"value": existing.Email, "primary": true}, + }, + "active": false, + }) + require.NoError(t, err) + assert.Equal(t, existing.ID.String(), res.ID) + assert.Equal(t, false, res.Attributes["active"]) + + got, err := db.GetUserByID(dbauthz.AsSCIMProvisioner(context.Background()), existing.ID) + require.NoError(t, err) + assert.Equal(t, database.UserStatusSuspended, got.Status) + + // No status change → no audit log. + assert.Empty(t, mockAudit.AuditLogs()) + }) +} + +func TestResourceUser_Lifecycle(t *testing.T) { + t.Parallel() + + ru, db, _ := setupSCIM(t) + + // Seed an active user. + user := seedUser(t, db, database.User{ + Status: database.UserStatusActive, + LoginType: database.LoginTypeOIDC, + }) + + r := scimRequest(t) + + // Step 1: Get the user. Verify fields match. + res, err := ru.Get(r, user.ID.String()) + require.NoError(t, err) + assert.Equal(t, user.ID.String(), res.ID) + assert.Equal(t, user.Username, res.Attributes["userName"]) + assert.Equal(t, true, res.Attributes["active"]) + + // Step 2: Replace with active=false → suspended. + res, err = ru.Replace(r, user.ID.String(), scim.ResourceAttributes{ + "userName": user.Username, + "active": false, + }) + require.NoError(t, err) + assert.Equal(t, false, res.Attributes["active"]) + + // Step 3: Get → confirm inactive. + res, err = ru.Get(r, user.ID.String()) + require.NoError(t, err) + assert.Equal(t, false, res.Attributes["active"]) + + // Step 4: Patch active=true → dormant (shown as active in SCIM). + res, err = ru.Patch(r, user.ID.String(), []scim.PatchOperation{ + {Op: "replace", Path: mustPath("active"), Value: true}, + }) + require.NoError(t, err) + assert.Equal(t, true, res.Attributes["active"]) + + // Step 5: Get → confirm active again. + res, err = ru.Get(r, user.ID.String()) + require.NoError(t, err) + assert.Equal(t, true, res.Attributes["active"]) + + // Step 6: Delete → suspended. + err = ru.Delete(r, user.ID.String()) + require.NoError(t, err) + + // Step 7: Get → confirm inactive after delete. + res, err = ru.Get(r, user.ID.String()) + require.NoError(t, err) + assert.Equal(t, false, res.Attributes["active"]) +} + +func TestResourceUser_GetAll(t *testing.T) { + t.Parallel() + + ru, db, _ := setupSCIM(t) + + // Seed 3 users. + for i := 0; i < 3; i++ { + seedUser(t, db, database.User{ + LoginType: database.LoginTypeOIDC, + }) + } + + r := scimRequest(t) + + // Get all with large count. + page, err := ru.GetAll(r, scim.ListRequestParams{Count: 100, StartIndex: 1}) + require.NoError(t, err) + assert.GreaterOrEqual(t, page.TotalResults, 3) + assert.GreaterOrEqual(t, len(page.Resources), 3) + + // Paginate: startIndex=2, count=1. + page, err = ru.GetAll(r, scim.ListRequestParams{Count: 1, StartIndex: 2}) + require.NoError(t, err) + assert.Len(t, page.Resources, 1) + assert.GreaterOrEqual(t, page.TotalResults, 3) +} + +func TestResourceUser_Errors(t *testing.T) { + t.Parallel() + + ru, _, _ := setupSCIM(t) + r := scimRequest(t) + missingUUID := uuid.New().String() + + tests := []struct { + name string + run func() error + wantStatus int + }{ + { + name: "Get/non-UUID", + run: func() error { _, err := ru.Get(r, "not-a-uuid"); return err }, + wantStatus: http.StatusNotFound, + }, + { + name: "Get/missing", + run: func() error { _, err := ru.Get(r, missingUUID); return err }, + wantStatus: http.StatusNotFound, + }, + { + name: "Replace/non-UUID", + run: func() error { _, err := ru.Replace(r, "bad", scim.ResourceAttributes{}); return err }, + wantStatus: http.StatusNotFound, + }, + { + name: "Replace/missing", + run: func() error { _, err := ru.Replace(r, missingUUID, scim.ResourceAttributes{}); return err }, + wantStatus: http.StatusNotFound, + }, + { + name: "Replace/immutable-userName", + run: func() error { + // Need a real user for this test. + user := seedUser(t, ru.store, database.User{LoginType: database.LoginTypeOIDC}) + _, err := ru.Replace(r, user.ID.String(), scim.ResourceAttributes{ + "userName": "different-name", + }) + return err + }, + wantStatus: http.StatusBadRequest, + }, + { + name: "Patch/non-UUID", + run: func() error { _, err := ru.Patch(r, "bad", nil); return err }, + wantStatus: http.StatusNotFound, + }, + { + name: "Patch/missing", + run: func() error { _, err := ru.Patch(r, missingUUID, nil); return err }, + wantStatus: http.StatusNotFound, + }, + { + name: "Delete/non-UUID", + run: func() error { return ru.Delete(r, "bad") }, + wantStatus: http.StatusNotFound, + }, + { + name: "Delete/missing", + run: func() error { return ru.Delete(r, missingUUID) }, + wantStatus: http.StatusNotFound, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tt.run() + require.Error(t, err) + var scimErr scimErrors.ScimError + require.ErrorAs(t, err, &scimErr) + assert.Equal(t, tt.wantStatus, scimErr.Status) + }) + } +} + +func TestResourceUser_AuditLogs(t *testing.T) { + t.Parallel() + + // These tests use dbmock instead of a real database because they only + // verify audit emission logic (does an audit log fire when status + // changes?), not SQL correctness. The handlers call just GetUserByID + // and UpdateUserStatus, both trivially mockable. + + makeUser := func(status database.UserStatus) (database.User, database.User) { + id := uuid.New() + user := database.User{ + ID: id, + Username: "testuser", + Status: status, + LoginType: database.LoginTypeOIDC, + } + suspended := user + suspended.Status = database.UserStatusSuspended + return user, suspended + } + + t.Run("Replace/status-change-emits-audit", func(t *testing.T) { + t.Parallel() + ru, mockStore, mockAudit := setupSCIMMock(t) + activeUser, suspendedUser := makeUser(database.UserStatusActive) + + mockStore.EXPECT().GetUserByID(gomock.Any(), activeUser.ID).Return(activeUser, nil) + mockStore.EXPECT().UpdateUserStatus(gomock.Any(), gomock.Any()).Return(suspendedUser, nil) + + _, err := ru.Replace(scimRequest(t), activeUser.ID.String(), scim.ResourceAttributes{ + "userName": activeUser.Username, + "active": false, + }) + require.NoError(t, err) + assert.Len(t, mockAudit.AuditLogs(), 1) + }) + + t.Run("Replace/no-change-skips-audit", func(t *testing.T) { + t.Parallel() + ru, mockStore, mockAudit := setupSCIMMock(t) + activeUser, _ := makeUser(database.UserStatusActive) + + mockStore.EXPECT().GetUserByID(gomock.Any(), activeUser.ID).Return(activeUser, nil) + // No UpdateUserStatus expected: active=true on an already active user is a no-op. + + _, err := ru.Replace(scimRequest(t), activeUser.ID.String(), scim.ResourceAttributes{ + "userName": activeUser.Username, + "active": true, + }) + require.NoError(t, err) + assert.Empty(t, mockAudit.AuditLogs()) + }) + + t.Run("Delete/active-user-emits-audit", func(t *testing.T) { + t.Parallel() + ru, mockStore, mockAudit := setupSCIMMock(t) + activeUser, suspendedUser := makeUser(database.UserStatusActive) + + mockStore.EXPECT().GetUserByID(gomock.Any(), activeUser.ID).Return(activeUser, nil) + mockStore.EXPECT().UpdateUserStatus(gomock.Any(), gomock.Any()).Return(suspendedUser, nil) + + err := ru.Delete(scimRequest(t), activeUser.ID.String()) + require.NoError(t, err) + assert.Len(t, mockAudit.AuditLogs(), 1) + }) + + t.Run("Delete/suspended-user-skips-audit", func(t *testing.T) { + t.Parallel() + ru, mockStore, mockAudit := setupSCIMMock(t) + _, suspendedUser := makeUser(database.UserStatusSuspended) + + mockStore.EXPECT().GetUserByID(gomock.Any(), suspendedUser.ID).Return(suspendedUser, nil) + // No UpdateUserStatus expected: already suspended. + + err := ru.Delete(scimRequest(t), suspendedUser.ID.String()) + require.NoError(t, err) + assert.Empty(t, mockAudit.AuditLogs()) + }) +} + +// mustPath parses a SCIM attribute path string into a *filter.Path +// for use in PatchOperation test data. +func mustPath(attr string) *filter.Path { + p, err := filter.ParsePath([]byte(attr)) + if err != nil { + panic(fmt.Sprintf("mustPath(%q): %v", attr, err)) + } + return &p +} diff --git a/enterprise/coderd/scim_test.go b/enterprise/coderd/scim_test.go index e33c49e2a4..0aeb61d8e0 100644 --- a/enterprise/coderd/scim_test.go +++ b/enterprise/coderd/scim_test.go @@ -4,13 +4,10 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" "net/http/httptest" "testing" - "github.com/golang-jwt/jwt/v4" - "github.com/google/uuid" "github.com/imulab/go-scim/pkg/v2/handlerutil" "github.com/imulab/go-scim/pkg/v2/spec" "github.com/stretchr/testify/assert" @@ -19,25 +16,22 @@ import ( "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" - "github.com/coder/coder/v2/coderd/coderdtest/oidctest" - "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/notifications/notificationstest" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/cryptorand" - "github.com/coder/coder/v2/enterprise/coderd" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" + "github.com/coder/coder/v2/enterprise/coderd/legacyscim" "github.com/coder/coder/v2/enterprise/coderd/license" - "github.com/coder/coder/v2/enterprise/coderd/scim" "github.com/coder/coder/v2/testutil" ) //nolint:revive -func makeScimUser(t testing.TB) coderd.SCIMUser { +func makeScimUser(t testing.TB) legacyscim.SCIMUser { rstr, err := cryptorand.String(10) require.NoError(t, err) - return coderd.SCIMUser{ + return legacyscim.SCIMUser{ UserName: rstr, Name: struct { GivenName string `json:"givenName"` @@ -64,807 +58,651 @@ func setScimAuth(key []byte) func(*http.Request) { } } -func setScimAuthBearer(key []byte) func(*http.Request) { - return func(r *http.Request) { - // Do strange casing to ensure it's case-insensitive - r.Header.Set("Authorization", "beAreR "+string(key)) - } -} - +// TestLegacyScim tests the legacy SCIM handler (imulab/go-scim based). +// This is a reduced set of integration tests verifying HTTP routing, auth, +// and core CRUD. Detailed handler logic is covered by the unit tests in +// enterprise/coderd/scim/scimusers_test.go. +// //nolint:gocritic // SCIM authenticates via a special header and bypasses internal RBAC. -func TestScim(t *testing.T) { +func TestLegacyScim(t *testing.T) { t.Parallel() + t.Run("disabled", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + SCIMAPIKey: []byte("hi"), + UseLegacySCIM: true, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{codersdk.FeatureSCIM: 0}, + }, + }) + + res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{}) + require.NoError(t, err) + defer res.Body.Close() + assert.Equal(t, http.StatusForbidden, res.StatusCode) + }) + + t.Run("noAuth", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + SCIMAPIKey: []byte("hi"), + UseLegacySCIM: true, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{codersdk.FeatureSCIM: 1}, + }, + }) + + res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{}) + require.NoError(t, err) + defer res.Body.Close() + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) + t.Run("postUser", func(t *testing.T) { t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() - t.Run("disabled", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - SCIMAPIKey: []byte("hi"), - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 0, - }, + scimAPIKey := []byte("hi") + mockAudit := audit.NewMock() + notifyEnq := ¬ificationstest.FakeEnqueuer{} + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Auditor: mockAudit, + NotificationsEnqueuer: notifyEnq, + }, + SCIMAPIKey: scimAPIKey, + UseLegacySCIM: true, + AuditLogging: true, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{ + codersdk.FeatureSCIM: 1, + codersdk.FeatureAuditLog: 1, + codersdk.FeatureMultipleOrganizations: 1, }, - }) - - res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{}) - require.NoError(t, err) - defer res.Body.Close() - assert.Equal(t, http.StatusForbidden, res.StatusCode) + }, }) - t.Run("noAuth", func(t *testing.T) { - t.Parallel() + sUser := makeScimUser(t) + res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + defer res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() + var createdUser legacyscim.SCIMUser + err = json.NewDecoder(res.Body).Decode(&createdUser) + require.NoError(t, err) + assert.NotEmpty(t, createdUser.ID) + assert.Equal(t, sUser.UserName, createdUser.UserName) - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - SCIMAPIKey: []byte("hi"), - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - }, + // Verify user exists. + userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: createdUser.UserName}) + require.NoError(t, err) + require.Len(t, userRes.Users, 1) + assert.Equal(t, codersdk.LoginTypeOIDC, userRes.Users[0].LoginType) + + // Verify audit log. + require.True(t, len(mockAudit.AuditLogs()) > 0) + + // Verify no user admin notification (SCIM skips notifications). + require.Empty(t, notifyEnq.Sent()) + }) + + t.Run("Duplicate", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + scimAPIKey := []byte("hi") + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + SCIMAPIKey: scimAPIKey, + UseLegacySCIM: true, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{ + codersdk.FeatureSCIM: 1, + codersdk.FeatureMultipleOrganizations: 1, }, - }) - - res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{}) - require.NoError(t, err) - defer res.Body.Close() - assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + }, }) - t.Run("OK", func(t *testing.T) { - t.Parallel() + sUser := makeScimUser(t) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - // given - scimAPIKey := []byte("hi") - mockAudit := audit.NewMock() - notifyEnq := ¬ificationstest.FakeEnqueuer{} - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - Auditor: mockAudit, - NotificationsEnqueuer: notifyEnq, - }, - SCIMAPIKey: scimAPIKey, - AuditLogging: true, - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - codersdk.FeatureAuditLog: 1, - }, - }, - }) - mockAudit.ResetLogs() - - // verify scim is enabled - res, err := client.Request(ctx, http.MethodGet, "/scim/v2/ServiceProviderConfig", nil) - require.NoError(t, err) - defer res.Body.Close() - require.Equal(t, http.StatusOK, res.StatusCode) - - // when - sUser := makeScimUser(t) - res, err = client.Request(ctx, http.MethodPost, "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - defer res.Body.Close() - require.Equal(t, http.StatusOK, res.StatusCode) - - // then - // Expect audit logs - aLogs := mockAudit.AuditLogs() - require.Len(t, aLogs, 1) - af := map[string]string{} - err = json.Unmarshal([]byte(aLogs[0].AdditionalFields), &af) - require.NoError(t, err) - assert.Equal(t, coderd.SCIMAuditAdditionalFields, af) - assert.Equal(t, database.AuditActionCreate, aLogs[0].Action) - - // Expect users exposed over API - userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.Emails[0].Value}) - require.NoError(t, err) - require.Len(t, userRes.Users, 1) - assert.Equal(t, sUser.Emails[0].Value, userRes.Users[0].Email) - assert.Equal(t, sUser.UserName, userRes.Users[0].Username) - assert.Len(t, userRes.Users[0].OrganizationIDs, 1) - - // Expect zero notifications (SkipNotifications = true) - require.Empty(t, notifyEnq.Sent()) - }) - - t.Run("OK_Bearer", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - // given - scimAPIKey := []byte("hi") - mockAudit := audit.NewMock() - notifyEnq := ¬ificationstest.FakeEnqueuer{} - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - Auditor: mockAudit, - NotificationsEnqueuer: notifyEnq, - }, - SCIMAPIKey: scimAPIKey, - AuditLogging: true, - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - codersdk.FeatureAuditLog: 1, - }, - }, - }) - mockAudit.ResetLogs() - - // when - sUser := makeScimUser(t) - res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuthBearer(scimAPIKey)) - require.NoError(t, err) - defer res.Body.Close() - require.Equal(t, http.StatusOK, res.StatusCode) - - // then - // Expect audit logs - aLogs := mockAudit.AuditLogs() - require.Len(t, aLogs, 1) - af := map[string]string{} - err = json.Unmarshal([]byte(aLogs[0].AdditionalFields), &af) - require.NoError(t, err) - assert.Equal(t, coderd.SCIMAuditAdditionalFields, af) - assert.Equal(t, database.AuditActionCreate, aLogs[0].Action) - - // Expect users exposed over API - userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.Emails[0].Value}) - require.NoError(t, err) - require.Len(t, userRes.Users, 1) - assert.Equal(t, sUser.Emails[0].Value, userRes.Users[0].Email) - assert.Equal(t, sUser.UserName, userRes.Users[0].Username) - assert.Len(t, userRes.Users[0].OrganizationIDs, 1) - - // Expect zero notifications (SkipNotifications = true) - require.Empty(t, notifyEnq.Sent()) - }) - - t.Run("OKNoDefault", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - // given - scimAPIKey := []byte("hi") - mockAudit := audit.NewMock() - notifyEnq := ¬ificationstest.FakeEnqueuer{} - dv := coderdtest.DeploymentValues(t) - dv.OIDC.OrganizationAssignDefault = false - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - Auditor: mockAudit, - NotificationsEnqueuer: notifyEnq, - DeploymentValues: dv, - }, - SCIMAPIKey: scimAPIKey, - AuditLogging: true, - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - codersdk.FeatureAuditLog: 1, - }, - }, - }) - mockAudit.ResetLogs() - - // when - sUser := makeScimUser(t) + // Create same user 3 times. + for i := 0; i < 3; i++ { res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) require.NoError(t, err) - defer res.Body.Close() - require.Equal(t, http.StatusOK, res.StatusCode) - - // then - // Expect audit logs - aLogs := mockAudit.AuditLogs() - require.Len(t, aLogs, 1) - af := map[string]string{} - err = json.Unmarshal([]byte(aLogs[0].AdditionalFields), &af) - require.NoError(t, err) - assert.Equal(t, coderd.SCIMAuditAdditionalFields, af) - assert.Equal(t, database.AuditActionCreate, aLogs[0].Action) - - // Expect users exposed over API - userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.Emails[0].Value}) - require.NoError(t, err) - require.Len(t, userRes.Users, 1) - assert.Equal(t, sUser.Emails[0].Value, userRes.Users[0].Email) - assert.Equal(t, sUser.UserName, userRes.Users[0].Username) - assert.Len(t, userRes.Users[0].OrganizationIDs, 0) - - // Expect zero notifications (SkipNotifications = true) - require.Empty(t, notifyEnq.Sent()) - }) - - t.Run("Duplicate", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - scimAPIKey := []byte("hi") - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - SCIMAPIKey: scimAPIKey, - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - }, - }, - }) - - sUser := makeScimUser(t) - for i := 0; i < 3; i++ { - res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - _ = res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - } - - userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.Emails[0].Value}) - require.NoError(t, err) - require.Len(t, userRes.Users, 1) - - assert.Equal(t, sUser.Emails[0].Value, userRes.Users[0].Email) - assert.Equal(t, sUser.UserName, userRes.Users[0].Username) - }) - - t.Run("Unsuspend", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - scimAPIKey := []byte("hi") - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - SCIMAPIKey: scimAPIKey, - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - }, - }, - }) - - sUser := makeScimUser(t) - res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - defer res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - err = json.NewDecoder(res.Body).Decode(&sUser) - require.NoError(t, err) - - sUser.Active = ptr.Ref(false) - res, err = client.Request(ctx, "PATCH", "/scim/v2/Users/"+sUser.ID, sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - _, _ = io.Copy(io.Discard, res.Body) _ = res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) + } - sUser.Active = ptr.Ref(true) - res, err = client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - _, _ = io.Copy(io.Discard, res.Body) - _ = res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - - userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.Emails[0].Value}) - require.NoError(t, err) - require.Len(t, userRes.Users, 1) - - assert.Equal(t, sUser.Emails[0].Value, userRes.Users[0].Email) - assert.Equal(t, sUser.UserName, userRes.Users[0].Username) - assert.Equal(t, codersdk.UserStatusDormant, userRes.Users[0].Status) - }) - - t.Run("DomainStrips", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - scimAPIKey := []byte("hi") - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - SCIMAPIKey: scimAPIKey, - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - }, - }, - }) - - sUser := makeScimUser(t) - sUser.UserName = sUser.UserName + "@coder.com" - res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - _, _ = io.Copy(io.Discard, res.Body) - _ = res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - - userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.Emails[0].Value}) - require.NoError(t, err) - require.Len(t, userRes.Users, 1) - - assert.Equal(t, sUser.Emails[0].Value, userRes.Users[0].Email) - // Username should be the same as the given name. They all use the - // same string before we modified it above. - assert.Equal(t, sUser.Name.GivenName, userRes.Users[0].Username) - }) + // Only 1 user should exist. + userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.UserName}) + require.NoError(t, err) + require.Len(t, userRes.Users, 1) }) t.Run("patchUser", func(t *testing.T) { t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() - t.Run("disabled", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - SCIMAPIKey: []byte("hi"), - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 0, - }, + scimAPIKey := []byte("hi") + mockAudit := audit.NewMock() + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{Auditor: mockAudit}, + SCIMAPIKey: scimAPIKey, + UseLegacySCIM: true, + AuditLogging: true, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{ + codersdk.FeatureSCIM: 1, + codersdk.FeatureAuditLog: 1, + codersdk.FeatureMultipleOrganizations: 1, }, - }) - - res, err := client.Request(ctx, "PATCH", "/scim/v2/Users/bob", struct{}{}) - require.NoError(t, err) - _, _ = io.Copy(io.Discard, res.Body) - _ = res.Body.Close() - assert.Equal(t, http.StatusForbidden, res.StatusCode) + }, }) - t.Run("noAuth", func(t *testing.T) { - t.Parallel() + // Create user first. + sUser := makeScimUser(t) + res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() + var createdUser legacyscim.SCIMUser + err = json.NewDecoder(res.Body).Decode(&createdUser) + require.NoError(t, err) - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - SCIMAPIKey: []byte("hi"), - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - }, - }, - }) + // Suspend via PATCH. + mockAudit.ResetLogs() + sUser.Active = ptr.Ref(false) + res, err = client.Request(ctx, "PATCH", "/scim/v2/Users/"+createdUser.ID, sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + defer res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) - res, err := client.Request(ctx, "PATCH", "/scim/v2/Users/bob", struct{}{}) - require.NoError(t, err) - _, _ = io.Copy(io.Discard, res.Body) - _ = res.Body.Close() - assert.Equal(t, http.StatusUnauthorized, res.StatusCode) - }) - - t.Run("OK", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - scimAPIKey := []byte("hi") - mockAudit := audit.NewMock() - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{Auditor: mockAudit}, - SCIMAPIKey: scimAPIKey, - AuditLogging: true, - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - codersdk.FeatureAuditLog: 1, - }, - }, - }) - mockAudit.ResetLogs() - - sUser := makeScimUser(t) - res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - defer res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - mockAudit.ResetLogs() - - err = json.NewDecoder(res.Body).Decode(&sUser) - require.NoError(t, err) - - sUser.Active = ptr.Ref(false) - - res, err = client.Request(ctx, "PATCH", "/scim/v2/Users/"+sUser.ID, sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - _, _ = io.Copy(io.Discard, res.Body) - _ = res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - - aLogs := mockAudit.AuditLogs() - require.Len(t, aLogs, 1) - assert.Equal(t, database.AuditActionWrite, aLogs[0].Action) - - userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.Emails[0].Value}) - require.NoError(t, err) - require.Len(t, userRes.Users, 1) - assert.Equal(t, codersdk.UserStatusSuspended, userRes.Users[0].Status) - }) - - // Create a user via SCIM, which starts as dormant. - // Log in as the user, making them active. - // Then patch the user again and the user should still be active. - t.Run("ActiveIsActive", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - scimAPIKey := []byte("hi") - - mockAudit := audit.NewMock() - fake := oidctest.NewFakeIDP(t, oidctest.WithServing()) - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - Auditor: mockAudit, - OIDCConfig: fake.OIDCConfig(t, []string{}), - }, - SCIMAPIKey: scimAPIKey, - AuditLogging: true, - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - codersdk.FeatureAuditLog: 1, - }, - }, - }) - mockAudit.ResetLogs() - - // User is dormant on create - sUser := makeScimUser(t) - res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - defer res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - - err = json.NewDecoder(res.Body).Decode(&sUser) - require.NoError(t, err) - - // Check the audit log - aLogs := mockAudit.AuditLogs() - require.Len(t, aLogs, 1) - assert.Equal(t, database.AuditActionCreate, aLogs[0].Action) - - // Verify the user is dormant - scimUser, err := client.User(ctx, sUser.UserName) - require.NoError(t, err) - require.Equal(t, codersdk.UserStatusDormant, scimUser.Status, "user starts as dormant") - - // Log in as the user, making them active - //nolint:bodyclose - scimUserClient, _ := fake.Login(t, client, jwt.MapClaims{ - "email": sUser.Emails[0].Value, - "sub": uuid.NewString(), - }) - scimUser, err = scimUserClient.User(ctx, codersdk.Me) - require.NoError(t, err) - require.Equal(t, codersdk.UserStatusActive, scimUser.Status, "user should now be active") - - // Patch the user - mockAudit.ResetLogs() - res, err = client.Request(ctx, "PATCH", "/scim/v2/Users/"+sUser.ID, sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - _, _ = io.Copy(io.Discard, res.Body) - _ = res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - - // Should be no audit logs since there is no diff - aLogs = mockAudit.AuditLogs() - require.Len(t, aLogs, 0) - - // Verify the user is still active. - scimUser, err = client.User(ctx, sUser.UserName) - require.NoError(t, err) - require.Equal(t, codersdk.UserStatusActive, scimUser.Status, "user is still active") - }) + // Verify suspended. + userRes, err := client.User(ctx, createdUser.ID) + require.NoError(t, err) + assert.Equal(t, codersdk.UserStatusSuspended, userRes.Status) }) t.Run("putUser", func(t *testing.T) { t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() - t.Run("disabled", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - SCIMAPIKey: []byte("hi"), - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 0, - }, + scimAPIKey := []byte("hi") + mockAudit := audit.NewMock() + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{Auditor: mockAudit}, + SCIMAPIKey: scimAPIKey, + UseLegacySCIM: true, + AuditLogging: true, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{ + codersdk.FeatureSCIM: 1, + codersdk.FeatureAuditLog: 1, + codersdk.FeatureMultipleOrganizations: 1, }, - }) - - res, err := client.Request(ctx, http.MethodPut, "/scim/v2/Users/bob", struct{}{}) - require.NoError(t, err) - _, _ = io.Copy(io.Discard, res.Body) - _ = res.Body.Close() - assert.Equal(t, http.StatusForbidden, res.StatusCode) + }, }) - t.Run("noAuth", func(t *testing.T) { - t.Parallel() + // Create user first. + sUser := makeScimUser(t) + res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() + var createdUser legacyscim.SCIMUser + err = json.NewDecoder(res.Body).Decode(&createdUser) + require.NoError(t, err) - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - SCIMAPIKey: []byte("hi"), - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - }, - }, - }) + // Suspend via PUT. + mockAudit.ResetLogs() + sUser.Active = ptr.Ref(false) + res, err = client.Request(ctx, "PUT", "/scim/v2/Users/"+createdUser.ID, sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + defer res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) - res, err := client.Request(ctx, http.MethodPut, "/scim/v2/Users/bob", struct{}{}) - require.NoError(t, err) - _, _ = io.Copy(io.Discard, res.Body) - _ = res.Body.Close() - assert.Equal(t, http.StatusUnauthorized, res.StatusCode) - }) - - t.Run("MissingActiveField", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - scimAPIKey := []byte("hi") - mockAudit := audit.NewMock() - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{Auditor: mockAudit}, - SCIMAPIKey: scimAPIKey, - AuditLogging: true, - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - codersdk.FeatureAuditLog: 1, - }, - }, - }) - mockAudit.ResetLogs() - - sUser := makeScimUser(t) - res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - defer res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - mockAudit.ResetLogs() - - err = json.NewDecoder(res.Body).Decode(&sUser) - require.NoError(t, err) - - sUser.Active = nil - - res, err = client.Request(ctx, http.MethodPut, "/scim/v2/Users/"+sUser.ID, sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - defer res.Body.Close() - assert.Equal(t, http.StatusBadRequest, res.StatusCode) - - data, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.Contains(t, string(data), "active field is required") - mockAudit.ResetLogs() - }) - - t.Run("ImmutabilityViolation", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - scimAPIKey := []byte("hi") - mockAudit := audit.NewMock() - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{Auditor: mockAudit}, - SCIMAPIKey: scimAPIKey, - AuditLogging: true, - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - codersdk.FeatureAuditLog: 1, - }, - }, - }) - mockAudit.ResetLogs() - - sUser := makeScimUser(t) - res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - defer res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - mockAudit.ResetLogs() - - err = json.NewDecoder(res.Body).Decode(&sUser) - require.NoError(t, err) - - sUser.UserName += "changed" - - res, err = client.Request(ctx, http.MethodPut, "/scim/v2/Users/"+sUser.ID, sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - defer res.Body.Close() - assert.Equal(t, http.StatusBadRequest, res.StatusCode) - mockAudit.ResetLogs() - - data, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.Contains(t, string(data), "mutability") - require.NoError(t, err) - }) - - t.Run("OK", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - scimAPIKey := []byte("hi") - mockAudit := audit.NewMock() - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{Auditor: mockAudit}, - SCIMAPIKey: scimAPIKey, - AuditLogging: true, - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - codersdk.FeatureAuditLog: 1, - }, - }, - }) - mockAudit.ResetLogs() - - sUser := makeScimUser(t) - res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - defer res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - mockAudit.ResetLogs() - - err = json.NewDecoder(res.Body).Decode(&sUser) - require.NoError(t, err) - - sUser.Active = ptr.Ref(false) - - res, err = client.Request(ctx, http.MethodPatch, "/scim/v2/Users/"+sUser.ID, sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - _, _ = io.Copy(io.Discard, res.Body) - _ = res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - - aLogs := mockAudit.AuditLogs() - require.Len(t, aLogs, 1) - assert.Equal(t, database.AuditActionWrite, aLogs[0].Action) - - userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.Emails[0].Value}) - require.NoError(t, err) - require.Len(t, userRes.Users, 1) - assert.Equal(t, codersdk.UserStatusSuspended, userRes.Users[0].Status) - }) - - // Create a user via SCIM, which starts as dormant. - // Log in as the user, making them active. - // Then patch the user again and the user should still be active. - t.Run("ActiveIsActive", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - scimAPIKey := []byte("hi") - - mockAudit := audit.NewMock() - fake := oidctest.NewFakeIDP(t, oidctest.WithServing()) - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - Auditor: mockAudit, - OIDCConfig: fake.OIDCConfig(t, []string{}), - }, - SCIMAPIKey: scimAPIKey, - AuditLogging: true, - LicenseOptions: &coderdenttest.LicenseOptions{ - AccountID: "coolin", - Features: license.Features{ - codersdk.FeatureSCIM: 1, - codersdk.FeatureAuditLog: 1, - }, - }, - }) - mockAudit.ResetLogs() - - // User is dormant on create - sUser := makeScimUser(t) - res, err := client.Request(ctx, http.MethodPost, "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - defer res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - - err = json.NewDecoder(res.Body).Decode(&sUser) - require.NoError(t, err) - - // Check the audit log - aLogs := mockAudit.AuditLogs() - require.Len(t, aLogs, 1) - assert.Equal(t, database.AuditActionCreate, aLogs[0].Action) - - // Verify the user is dormant - scimUser, err := client.User(ctx, sUser.UserName) - require.NoError(t, err) - require.Equal(t, codersdk.UserStatusDormant, scimUser.Status, "user starts as dormant") - - // Log in as the user, making them active - //nolint:bodyclose - scimUserClient, _ := fake.Login(t, client, jwt.MapClaims{ - "email": sUser.Emails[0].Value, - "sub": uuid.NewString(), - }) - scimUser, err = scimUserClient.User(ctx, codersdk.Me) - require.NoError(t, err) - require.Equal(t, codersdk.UserStatusActive, scimUser.Status, "user should now be active") - - // Patch the user - mockAudit.ResetLogs() - res, err = client.Request(ctx, http.MethodPut, "/scim/v2/Users/"+sUser.ID, sUser, setScimAuth(scimAPIKey)) - require.NoError(t, err) - _, _ = io.Copy(io.Discard, res.Body) - _ = res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - - // Should be no audit logs since there is no diff - aLogs = mockAudit.AuditLogs() - require.Len(t, aLogs, 0) - - // Verify the user is still active. - scimUser, err = client.User(ctx, sUser.UserName) - require.NoError(t, err) - require.Equal(t, codersdk.UserStatusActive, scimUser.Status, "user is still active") - }) + // Verify suspended. + userRes, err := client.User(ctx, createdUser.ID) + require.NoError(t, err) + assert.Equal(t, codersdk.UserStatusSuspended, userRes.Status) }) } -func TestScimError(t *testing.T) { +// scim2User is a minimal struct for decoding SCIM 2.0 user responses +// returned by the elimity-com/scim library. +type scim2User struct { + ID string `json:"id"` + UserName string `json:"userName"` + Active bool `json:"active"` +} + +// scim2UserBody is the request body for SCIM 2.0 POST/PUT calls. +// Unlike the legacy handler, the elimity-com/scim library validates the +// "schemas" attribute against the core User schema URI and rejects bodies +// that omit it. +type scim2UserBody struct { + Schemas []string `json:"schemas"` + UserName string `json:"userName"` + Name struct { + GivenName string `json:"givenName"` + FamilyName string `json:"familyName"` + } `json:"name"` + Emails []struct { + Primary bool `json:"primary"` + Value string `json:"value"` + } `json:"emails"` + Active *bool `json:"active,omitempty"` +} + +func makeScim2User(t testing.TB) scim2UserBody { + rstr, err := cryptorand.String(10) + require.NoError(t, err) + + b := scim2UserBody{ + Schemas: []string{"urn:ietf:params:scim:schemas:core:2.0:User"}, + UserName: rstr, + Active: ptr.Ref(true), + } + b.Name.GivenName = rstr + b.Name.FamilyName = rstr + b.Emails = []struct { + Primary bool `json:"primary"` + Value string `json:"value"` + }{{Primary: true, Value: fmt.Sprintf("%s@coder.com", rstr)}} + return b +} + +// TestScim exercises the SCIM 2.0 handler through real HTTP routes. It +// mirrors TestLegacyScim's structure (disabled/noAuth/post/patch/put) and +// adds coverage for behavior unique to the v2 implementation: discovery +// endpoints, 409 Conflict on duplicate active users, suspended-user +// reactivation, GET by id, and DELETE. +// +//nolint:gocritic // SCIM authenticates via a special header and bypasses internal RBAC. +func TestScim(t *testing.T) { + t.Parallel() + + t.Run("disabled", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + SCIMAPIKey: []byte("hi"), + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{codersdk.FeatureSCIM: 0}, + }, + }) + + res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{}) + require.NoError(t, err) + defer res.Body.Close() + assert.Equal(t, http.StatusForbidden, res.StatusCode) + }) + + t.Run("noAuth", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + SCIMAPIKey: []byte("hi"), + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{codersdk.FeatureSCIM: 1}, + }, + }) + + res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{}) + require.NoError(t, err) + defer res.Body.Close() + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) + + t.Run("discovery", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + scimAPIKey := []byte("hi") + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + SCIMAPIKey: scimAPIKey, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{codersdk.FeatureSCIM: 1}, + }, + }) + + for _, path := range []string{ + "/scim/v2/ServiceProviderConfig", + "/scim/v2/ResourceTypes", + "/scim/v2/Schemas", + } { + res, err := client.Request(ctx, "GET", path, nil, setScimAuth(scimAPIKey)) + require.NoError(t, err) + _ = res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode, "discovery endpoint %s", path) + } + }) + + t.Run("postUser", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + scimAPIKey := []byte("hi") + mockAudit := audit.NewMock() + notifyEnq := ¬ificationstest.FakeEnqueuer{} + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Auditor: mockAudit, + NotificationsEnqueuer: notifyEnq, + }, + SCIMAPIKey: scimAPIKey, + AuditLogging: true, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{ + codersdk.FeatureSCIM: 1, + codersdk.FeatureAuditLog: 1, + codersdk.FeatureMultipleOrganizations: 1, + }, + }, + }) + + sUser := makeScim2User(t) + res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusCreated, res.StatusCode) + + var created scim2User + require.NoError(t, json.NewDecoder(res.Body).Decode(&created)) + assert.NotEmpty(t, created.ID) + assert.Equal(t, sUser.UserName, created.UserName) + assert.True(t, created.Active) + + // Verify user exists. + userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: created.UserName}) + require.NoError(t, err) + require.Len(t, userRes.Users, 1) + assert.Equal(t, codersdk.LoginTypeOIDC, userRes.Users[0].LoginType) + + // Verify audit log. + require.True(t, len(mockAudit.AuditLogs()) > 0) + + // Verify no user admin notification (SCIM skips notifications). + require.Empty(t, notifyEnq.Sent()) + }) + + t.Run("postUserConflict", func(t *testing.T) { + // SCIM 2.0 returns 409 Conflict on duplicate active user, unlike the + // legacy handler which returned 200 with the existing user. + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + scimAPIKey := []byte("hi") + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + SCIMAPIKey: scimAPIKey, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{ + codersdk.FeatureSCIM: 1, + codersdk.FeatureMultipleOrganizations: 1, + }, + }, + }) + + sUser := makeScim2User(t) + res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + _ = res.Body.Close() + require.Equal(t, http.StatusCreated, res.StatusCode) + + res, err = client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + _ = res.Body.Close() + assert.Equal(t, http.StatusConflict, res.StatusCode) + + userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.UserName}) + require.NoError(t, err) + require.Len(t, userRes.Users, 1) + }) + + t.Run("postUserReactivatesSuspended", func(t *testing.T) { + // When the SCIM client deletes a user (which only suspends in Coder), + // posting the same user again should reactivate the existing row. + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + scimAPIKey := []byte("hi") + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + SCIMAPIKey: scimAPIKey, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{ + codersdk.FeatureSCIM: 1, + codersdk.FeatureMultipleOrganizations: 1, + }, + }, + }) + + sUser := makeScim2User(t) + res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + var created scim2User + require.NoError(t, json.NewDecoder(res.Body).Decode(&created)) + _ = res.Body.Close() + require.Equal(t, http.StatusCreated, res.StatusCode) + require.NotEmpty(t, created.ID) + + // Delete (suspends) the user. + res, err = client.Request(ctx, "DELETE", "/scim/v2/Users/"+created.ID, nil, setScimAuth(scimAPIKey)) + require.NoError(t, err) + _ = res.Body.Close() + assert.Equal(t, http.StatusNoContent, res.StatusCode) + + userRes, err := client.User(ctx, created.ID) + require.NoError(t, err) + assert.Equal(t, codersdk.UserStatusSuspended, userRes.Status) + + // Re-create. The handler should reactivate the existing row. + res, err = client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + var recreated scim2User + require.NoError(t, json.NewDecoder(res.Body).Decode(&recreated)) + _ = res.Body.Close() + require.Equal(t, http.StatusCreated, res.StatusCode) + assert.Equal(t, created.ID, recreated.ID, "recreate should reactivate the existing row, not create a new one") + assert.True(t, recreated.Active, "recreated user should be active in the SCIM response") + + // The DB user moves from suspended → dormant on reactivate; the SCIM + // response reports both Active and Dormant as active=true. + userRes, err = client.User(ctx, created.ID) + require.NoError(t, err) + assert.Equal(t, codersdk.UserStatusDormant, userRes.Status) + }) + + t.Run("getUser", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + scimAPIKey := []byte("hi") + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + SCIMAPIKey: scimAPIKey, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{ + codersdk.FeatureSCIM: 1, + codersdk.FeatureMultipleOrganizations: 1, + }, + }, + }) + + sUser := makeScim2User(t) + res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + var created scim2User + require.NoError(t, json.NewDecoder(res.Body).Decode(&created)) + _ = res.Body.Close() + require.Equal(t, http.StatusCreated, res.StatusCode) + + res, err = client.Request(ctx, "GET", "/scim/v2/Users/"+created.ID, nil, setScimAuth(scimAPIKey)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + + var got scim2User + require.NoError(t, json.NewDecoder(res.Body).Decode(&got)) + assert.Equal(t, created.ID, got.ID) + assert.Equal(t, sUser.UserName, got.UserName) + }) + + t.Run("patchUser", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + scimAPIKey := []byte("hi") + mockAudit := audit.NewMock() + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{Auditor: mockAudit}, + SCIMAPIKey: scimAPIKey, + AuditLogging: true, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{ + codersdk.FeatureSCIM: 1, + codersdk.FeatureAuditLog: 1, + codersdk.FeatureMultipleOrganizations: 1, + }, + }, + }) + + sUser := makeScim2User(t) + res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + var created scim2User + require.NoError(t, json.NewDecoder(res.Body).Decode(&created)) + _ = res.Body.Close() + require.Equal(t, http.StatusCreated, res.StatusCode) + + // PATCH with replace op setting active=false. + mockAudit.ResetLogs() + patchBody := map[string]interface{}{ + "schemas": []string{"urn:ietf:params:scim:api:messages:2.0:PatchOp"}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "active", "value": false}, + }, + } + res, err = client.Request(ctx, "PATCH", "/scim/v2/Users/"+created.ID, patchBody, setScimAuth(scimAPIKey)) + require.NoError(t, err) + _ = res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) + + userRes, err := client.User(ctx, created.ID) + require.NoError(t, err) + assert.Equal(t, codersdk.UserStatusSuspended, userRes.Status) + }) + + t.Run("putUser", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + scimAPIKey := []byte("hi") + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + SCIMAPIKey: scimAPIKey, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{ + codersdk.FeatureSCIM: 1, + codersdk.FeatureMultipleOrganizations: 1, + }, + }, + }) + + sUser := makeScim2User(t) + res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + var created scim2User + require.NoError(t, json.NewDecoder(res.Body).Decode(&created)) + _ = res.Body.Close() + require.Equal(t, http.StatusCreated, res.StatusCode) + + // PUT with active=false. + sUser.Active = ptr.Ref(false) + res, err = client.Request(ctx, "PUT", "/scim/v2/Users/"+created.ID, sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + _ = res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) + + userRes, err := client.User(ctx, created.ID) + require.NoError(t, err) + assert.Equal(t, codersdk.UserStatusSuspended, userRes.Status) + }) + + t.Run("deleteUser", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + scimAPIKey := []byte("hi") + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + SCIMAPIKey: scimAPIKey, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{ + codersdk.FeatureSCIM: 1, + codersdk.FeatureMultipleOrganizations: 1, + }, + }, + }) + + sUser := makeScim2User(t) + res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + var created scim2User + require.NoError(t, json.NewDecoder(res.Body).Decode(&created)) + _ = res.Body.Close() + require.Equal(t, http.StatusCreated, res.StatusCode) + + res, err = client.Request(ctx, "DELETE", "/scim/v2/Users/"+created.ID, nil, setScimAuth(scimAPIKey)) + require.NoError(t, err) + _ = res.Body.Close() + assert.Equal(t, http.StatusNoContent, res.StatusCode) + + // Coder does not hard-delete users. The user should remain but be suspended. + userRes, err := client.User(ctx, created.ID) + require.NoError(t, err) + assert.Equal(t, codersdk.UserStatusSuspended, userRes.Status) + }) +} + +func TestLegacyScimError(t *testing.T) { t.Parallel() // Demonstrates that we cannot use the standard errors @@ -876,7 +714,7 @@ func TestScimError(t *testing.T) { // Our error wrapper works rw = httptest.NewRecorder() - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("not found"))) + _ = handlerutil.WriteError(rw, legacyscim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("not found"))) resp = rw.Result() defer resp.Body.Close() require.Equal(t, http.StatusNotFound, resp.StatusCode) diff --git a/enterprise/coderd/scimroutes.go b/enterprise/coderd/scimroutes.go new file mode 100644 index 0000000000..891b760e2f --- /dev/null +++ b/enterprise/coderd/scimroutes.go @@ -0,0 +1,74 @@ +package coderd + +import ( + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/enterprise/coderd/legacyscim" + "github.com/coder/coder/v2/enterprise/coderd/scim" +) + +func (api *API) mountScimRoute(opt *Options, r chi.Router) error { + if len(opt.SCIMAPIKey) == 0 { + // Show a helpful 404 error. Because this is not under the /api/v2 routes, + // the frontend is the fallback. A html page is not a helpful error for + // a SCIM provider. This JSON has a call to action that __may__ resolve + // the issue. + // + // Using mount to cover all subroute possibilities + r.Mount("/", http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + httpapi.Write(r.Context(), w, http.StatusNotFound, codersdk.Response{ + Message: "SCIM is disabled, please contact your administrator if you believe this is an error", + Detail: "SCIM endpoints are disabled if no SCIM is configured. Configure 'CODER_SCIM_AUTH_HEADER' to enable.", + }) + }))) + return nil + } + + if opt.UseLegacySCIM { + // Legacy SCIM handler (imulab/go-scim based). Opt-in for + // backward compatibility during the transition period. + legacySrv := &legacyscim.LegacyServer{ + Logger: opt.Logger, + Database: opt.Database, + IDPSync: opt.IDPSync, + AGPL: api.AGPL, + AccessURL: api.AccessURL, + SCIMAPIKey: opt.SCIMAPIKey, + Auditor: &api.AGPL.Auditor, + } + r.Mount("/v2", chi.Chain( + api.RequireFeatureMW(codersdk.FeatureSCIM), + legacySrv.AuthMiddleware, + ).Handler(legacySrv.Handler())) + return nil + } + + // SCIM 2.0 handler (elimity-com/scim based). + scimSrv, err := scim.New(&scim.Options{ + DB: opt.Database, + Auditor: &api.AGPL.Auditor, + IDPSync: opt.IDPSync, + Logger: opt.Logger, + AGPL: api.AGPL, + SCIMAPIKey: opt.SCIMAPIKey, + }) + if err != nil { + return xerrors.Errorf("create scim server: %w", err) + } + + // The elimity-com/scim library reads r.URL.Path and strips "/v2" + // internally. Chi's Route/Mount modifies its own routing context + // but not r.URL.Path, so we use http.StripPrefix to ensure the + // library sees paths like "/v2/Users" instead of "/scim/v2/Users". + r.Mount("/", chi.Chain( + api.RequireFeatureMW(codersdk.FeatureSCIM), + middleware.StripPrefix("/scim"), + ).Handler(scimSrv.Handler())) + return nil +} diff --git a/enterprise/scaletest/agentfake/agent.go b/enterprise/scaletest/agentfake/agent.go index b03ebde8bd..c18d8e0310 100644 --- a/enterprise/scaletest/agentfake/agent.go +++ b/enterprise/scaletest/agentfake/agent.go @@ -2,34 +2,97 @@ package agentfake import ( "context" + "encoding/base64" "net/url" + "strings" "time" + "github.com/google/uuid" "golang.org/x/xerrors" "google.golang.org/protobuf/types/known/timestamppb" "cdr.dev/slog/v3" "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/codersdk/agentsdk" + tailnetproto "github.com/coder/coder/v2/tailnet/proto" + "github.com/coder/quartz" ) -const reconnectBackoff = 1 * time.Second +// rpcDialer is the subset of agentsdk.Client agentfake uses. Defined +// locally so tests can plug in *agent/agenttest.Client (or any other +// test double) without depending on the rest of the agentsdk.Client +// surface. +type rpcDialer interface { + ConnectRPC29WithRole(ctx context.Context, role string) ( + proto.DRPCAgentClient29, tailnetproto.DRPCTailnetClient28, error, + ) +} + +const ( + reconnectBackoff = 1 * time.Second + + // metadataTickInterval is the scheduler pulse for the per-agent metadata + // goroutine. Per-description cadence is enforced by tracking next-due + // timestamps; the ticker just wakes us up often enough to honor the + // shortest interval we expect (1s). + metadataTickInterval = 1 * time.Second + + // metadataValueBytes matches the payload size produced by the real + // scaletest template's metadata script (`dd if=/dev/urandom bs=3072 + // count=1 | base64`), so the synthetic load shape on the wire mirrors + // what a real agent emits. + metadataValueBytes = 3072 + + // metadataMinInterval is a floor applied to manifest-declared intervals + // to guard against a malformed manifest pinning the goroutine. + metadataMinInterval = 1 * time.Second +) // Agent is a single fake agent. It owns one workspace-agent auth token and one dRPC connection to coderd. type Agent struct { coderURL *url.URL token string logger slog.Logger + clock quartz.Clock + dialer rpcDialer // nil → built from coderURL+token in Run cancel context.CancelFunc } -func NewAgent(coderURL *url.URL, token string, logger slog.Logger) *Agent { - return &Agent{ +// Option configures an Agent. +type Option func(*Agent) + +// WithClock injects a clock for time-based operations. Defaults to +// quartz.NewReal(). Tests pass a *quartz.Mock to drive the metadata +// loop deterministically. The clock is per-agent so a future caller +// can give different agents slightly different cadences. +func WithClock(c quartz.Clock) Option { + return func(a *Agent) { + a.clock = c + } +} + +// WithDialer injects a custom RPC dialer. Defaults to a real +// agentsdk.Client built from coderURL + token. Tests use this to +// substitute *agent/agenttest.Client and avoid standing up a real +// coderd. +func WithDialer(d rpcDialer) Option { + return func(a *Agent) { + a.dialer = d + } +} + +func NewAgent(coderURL *url.URL, token string, logger slog.Logger, opts ...Option) *Agent { + a := &Agent{ coderURL: coderURL, token: token, logger: logger, + clock: quartz.NewReal(), } + for _, opt := range opts { + opt(a) + } + return a } // Run opens a dRPC websocket to coderd as the "agent" role and keeps it open until ctx is canceled or Close is called. @@ -42,7 +105,10 @@ func (a *Agent) Run(ctx context.Context) error { a.cancel = cancel defer a.cancel() - client := agentsdk.New(a.coderURL, agentsdk.WithFixedToken(a.token)) + client := a.dialer + if client == nil { + client = agentsdk.New(a.coderURL, agentsdk.WithFixedToken(a.token)) + } for { if err := runCtx.Err(); err != nil { return nil @@ -52,18 +118,20 @@ func (a *Agent) Run(ctx context.Context) error { a.logger.Warn(runCtx, "fake agent dRPC stream ended; reconnecting", slog.Error(err)) } + timer := a.clock.NewTimer(reconnectBackoff, "agentfake", "reconnect") select { case <-runCtx.Done(): + timer.Stop() return nil - case <-time.After(reconnectBackoff): + case <-timer.C: } } } // connectAndServe opens one dRPC websocket, announces lifecycle = READY, then blocks until ctx is canceled or the // connection is closed by either side. Returns the underlying error, if any. -func (a *Agent) connectAndServe(ctx context.Context, client *agentsdk.Client) error { - rpc, _, err := client.ConnectRPC28WithRole(ctx, "agent") +func (a *Agent) connectAndServe(ctx context.Context, client rpcDialer) error { + rpc, _, err := client.ConnectRPC29WithRole(ctx, "agent") if err != nil { return xerrors.Errorf("connect dRPC: %w", err) } @@ -87,6 +155,30 @@ func (a *Agent) connectAndServe(ctx context.Context, client *agentsdk.Client) er slog.Error(err)) } + // Fetch the agent manifest so we know which metadata descriptions the + // template declared. We synthesize values for each declared key at the + // declared interval. Failure here is non-fatal: a manifest fetch + // hiccup shouldn't tear the connection down, we just skip metadata + // for this session and let the next reconnect retry. + manifest, err := rpc.GetManifest(ctx, &proto.GetManifestRequest{}) + if err != nil { + if ctx.Err() == nil { + a.logger.Warn(ctx, "get manifest for metadata", slog.Error(err)) + } + } else if descs := manifest.GetMetadata(); len(descs) > 0 { + // Parse the workspace ID out of the manifest so we can embed it + // in the synthetic metadata payload below. If the manifest bytes + // are malformed (shouldn't happen in practice), fall back to + // uuid.Nil; the payload is still valid, just less identifiable. + workspaceID, idErr := uuid.FromBytes(manifest.GetWorkspaceId()) + if idErr != nil && ctx.Err() == nil { + a.logger.Warn(ctx, "parse workspace id from manifest; metadata payload will use uuid.Nil", + slog.Error(idErr)) + workspaceID = uuid.Nil + } + go a.runMetadata(ctx, rpc, workspaceID, descs) + } + select { case <-ctx.Done(): return nil @@ -95,6 +187,99 @@ func (a *Agent) connectAndServe(ctx context.Context, client *agentsdk.Client) er } } +// runMetadata sends synthetic values for every metadata description in the +// agent manifest, batching per-tick into a single BatchUpdateMetadata call. +// +// One goroutine per agent (not per description): a 1s ticker pulses and we +// track per-description next-due timestamps so each key reports at its own +// declared interval. The goroutine is scoped to the connection's ctx; on +// disconnect or shutdown it exits cleanly. +// +// The payload is a single fixed value, computed once: the workspace ID +// prepended to a constant padding so each metadata row in scaletest logs +// and the database is traceable back to the agent that emitted it. We +// intentionally do not vary the value per key or per tick; if a future +// scenario requires per-key/per-tick variation we can extend this then. +// +// Errors from BatchUpdateMetadata are logged and ignored. Tearing the +// connection down over a metadata RPC blip would be wasteful; real agents +// behave the same way (see agent.reportMetadata). +func (a *Agent) runMetadata(ctx context.Context, rpc proto.DRPCAgentClient29, workspaceID uuid.UUID, descs []*proto.WorkspaceAgentMetadata_Description) { + // Resolve declared intervals once, applying a floor so a malformed + // manifest can't spin us. Initialize all keys as immediately due so + // the first tick fires every description. + intervals := make([]time.Duration, len(descs)) + nextDue := make([]time.Time, len(descs)) + now := a.clock.Now() + for i, d := range descs { + // The Interval field on the proto is a durationpb.Duration but + // carries the raw int64 seconds value cast through time.Duration + // (see coderd/agentapi/manifest.go and agent/agent.go). Mirror the + // same recovery the real agent does so manifest-declared intervals + // of e.g. 10s are honored as 10s, not 10ns. + intervalSeconds := int64(d.GetInterval().AsDuration()) + interval := time.Duration(intervalSeconds) * time.Second + if interval < metadataMinInterval { + interval = metadataMinInterval + } + intervals[i] = interval + nextDue[i] = now + } + + // Build the metadata payload once: prepend the workspace ID so + // scaletest log lines and DB rows are traceable back to the + // emitting agent, then pad out to metadataValueBytes so the wire + // shape (base64-encoded ~4096 chars) mirrors the real scaletest + // template's `dd if=/dev/urandom bs=3072 count=1 | base64` output. + // coderd truncates the stored value to 2048 chars (see + // coderd/agentapi/metadata.go maxValueLen), and the workspace ID + // lives in the first ~50 chars of the base64 output, so it + // survives truncation. + const tag = "fake-agent-metadata workspace=" + prefix := tag + workspaceID.String() + " " + padLen := metadataValueBytes - len(prefix) + if padLen < 0 { + padLen = 0 + } + value := base64.StdEncoding.EncodeToString([]byte(prefix + strings.Repeat("a", padLen))) + + // TickerFunc spawns its own goroutine that ticks until ctx is + // done and then stops the underlying ticker. We Wait on the + // returned Waiter so that runMetadata (itself running in the + // goroutine spawned by connectAndServe) stays alive for the + // connection's lifetime, matching the pre-refactor for/select + // shape. The Wait error is discarded: ticker exits are expected + // (ctx cancellation), and our tick func never returns a non-nil + // error of its own. + _ = a.clock.TickerFunc(ctx, metadataTickInterval, func() error { + now := a.clock.Now() + var batch []*proto.Metadata + for i, d := range descs { + if now.Before(nextDue[i]) { + continue + } + batch = append(batch, &proto.Metadata{ + Key: d.GetKey(), + Result: &proto.WorkspaceAgentMetadata_Result{ + CollectedAt: timestamppb.New(now), + Value: value, + }, + }) + nextDue[i] = now.Add(intervals[i]) + } + if len(batch) == 0 { + return nil + } + if _, err := rpc.BatchUpdateMetadata(ctx, &proto.BatchUpdateMetadataRequest{ + Metadata: batch, + }); err != nil && ctx.Err() == nil { + a.logger.Debug(ctx, "batch update metadata failed", + slog.Error(err)) + } + return nil + }, "agentfake", "runMetadata").Wait() +} + // Close stops the agent. Safe to call multiple times. func (a *Agent) Close() { if a.cancel != nil { diff --git a/enterprise/scaletest/agentfake/agent_test.go b/enterprise/scaletest/agentfake/agent_test.go index d01776f66d..5997ef7f33 100644 --- a/enterprise/scaletest/agentfake/agent_test.go +++ b/enterprise/scaletest/agentfake/agent_test.go @@ -2,64 +2,62 @@ package agentfake_test import ( "context" + "encoding/base64" "testing" + "time" + "github.com/google/uuid" "github.com/stretchr/testify/require" "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" - "github.com/coder/coder/v2/coderd/coderdtest" - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbfake" + "github.com/coder/coder/v2/agent/agenttest" + agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/enterprise/scaletest/agentfake" + "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" ) // Assert that our fake agent routine establishes the drpc connection and sets its lifecycle status to Ready. func TestAgent_ConnectsAndReachesReady(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitLong) - - client, db := coderdtest.NewWithDatabase(t, nil) - user := coderdtest.CreateFirstUser(t, client) - - r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ - OrganizationID: user.OrganizationID, - OwnerID: user.UserID, - }).WithAgent().Do() + ctx := testutil.Context(t, testutil.WaitShort) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - a := agentfake.NewAgent(client.URL, r.AgentToken, logger) - t.Cleanup(func() { a.Close() }) + agentID := uuid.New() + manifest := agentsdk.Manifest{ + AgentID: agentID, + WorkspaceID: uuid.New(), + } + statsCh := make(chan *agentproto.Stats, 1) + coord := tailnet.NewCoordinator(logger) + t.Cleanup(func() { _ = coord.Close() }) + dialer := agenttest.NewClient(t, logger, agentID, manifest, statsCh, coord) + t.Cleanup(dialer.Close) + + a := agentfake.NewAgent(nil, "", logger, agentfake.WithDialer(dialer)) + t.Cleanup(a.Close) runCtx, cancel := context.WithCancel(ctx) t.Cleanup(cancel) runErr := make(chan error, 1) - go func() { - runErr <- a.Run(runCtx) - }() - - coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID). - WithContext(ctx). - Wait() + go func() { runErr <- a.Run(runCtx) }() + // The fake agent sends UpdateLifecycle(READY) once per dRPC + // connect; agenttest records every lifecycle update. require.Eventually(t, func() bool { - ws, err := client.Workspace(ctx, r.Workspace.ID) - if err != nil { - return false - } - for _, res := range ws.LatestBuild.Resources { - for _, agent := range res.Agents { - if agent.LifecycleState != codersdk.WorkspaceAgentLifecycleReady { - return false - } + for _, state := range dialer.GetLifecycleStates() { + if state == codersdk.WorkspaceAgentLifecycleReady { + return true } } - return true - }, testutil.WaitLong, testutil.IntervalFast, - "agent never reached Lifecycle=ready in workspace %s", r.Workspace.ID) + return false + }, testutil.WaitShort, testutil.IntervalFast, + "agent never reported Lifecycle=ready") // Cancel Run and confirm a clean exit (nil error, not ctx error). cancel() @@ -74,3 +72,84 @@ func TestAgent_ConnectsAndReachesReady(t *testing.T) { a.Close() a.Close() } + +// Assert that, when the workspace agent manifest declares metadata +// descriptions, the fake agent sends synthetic values for each key via +// BatchUpdateMetadata. The test drives the agent against +// agent/agenttest.Client (an in-process fake of the agent-side coderd +// API) rather than a real coderd, so the only quartz mock involved is +// the agentfake clock that drives the metadata ticker. +func TestAgent_SendsMetadata(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + mClock := quartz.NewMock(t) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + + agentID := uuid.New() + manifest := agentsdk.Manifest{ + AgentID: agentID, + WorkspaceID: uuid.New(), + Metadata: []codersdk.WorkspaceAgentMetadataDescription{ + {Key: "01_meta", DisplayName: "Meta 01", Script: "noop", Interval: 1, Timeout: 10}, + {Key: "02_meta", DisplayName: "Meta 02", Script: "noop", Interval: 1, Timeout: 10}, + }, + } + + // statsCh and coord are required by agenttest.NewClient but + // unused by agentfake. The dialer is the standin for the real + // agentsdk.Client; it records every RPC the agent makes so we + // can assert against the metadata batch directly. + statsCh := make(chan *agentproto.Stats, 1) + coord := tailnet.NewCoordinator(logger) + t.Cleanup(func() { _ = coord.Close() }) + dialer := agenttest.NewClient(t, logger, agentID, manifest, statsCh, coord) + t.Cleanup(dialer.Close) + + a := agentfake.NewAgent(nil, "", logger, + agentfake.WithDialer(dialer), + agentfake.WithClock(mClock), + ) + t.Cleanup(a.Close) + + // Trap the agent's runMetadata TickerFunc registration so we know + // the goroutine is parked on the mock clock before we Advance. + // Otherwise Advance could race the goroutine startup and the + // first tick would be missed. + tickerTrap := mClock.Trap().TickerFunc("agentfake", "runMetadata") + defer tickerTrap.Close() + + runCtx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + runErr := make(chan error, 1) + go func() { runErr <- a.Run(runCtx) }() + + tickerTrap.MustWait(ctx).Release(ctx) + + // One tick fires runMetadata's tick func, which calls + // BatchUpdateMetadata against agenttest.Client. The fake records + // it synchronously in-process; no pubsub, batcher, or SSE involved. + mClock.Advance(time.Second).MustWait(ctx) + + require.Eventually(t, func() bool { + md := dialer.GetMetadata() + for _, key := range []string{"01_meta", "02_meta"} { + m, ok := md[key] + if !ok || m.Value == "" { + return false + } + if _, err := base64.StdEncoding.DecodeString(m.Value); err != nil { + return false + } + } + return true + }, testutil.WaitShort, testutil.IntervalFast) + + cancel() + select { + case err := <-runErr: + require.NoError(t, err, "Agent.Run returned unexpected error") + case <-ctx.Done(): + t.Fatalf("timed out waiting for Agent.Run to return: %v", ctx.Err()) + } +} diff --git a/enterprise/scaletest/agentfake/manager.go b/enterprise/scaletest/agentfake/manager.go index d03e48307b..69315e99c6 100644 --- a/enterprise/scaletest/agentfake/manager.go +++ b/enterprise/scaletest/agentfake/manager.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "net/url" "strconv" "sync" "time" @@ -17,6 +18,16 @@ import ( "github.com/coder/coder/v2/codersdk" ) +// ExternalAgentClient is the subset of *codersdk.Client the Manager +// uses to enumerate external-agent workspaces under a template and +// fetch each agent's auth token. *codersdk.Client satisfies this +// interface, so production callers pass their client directly; tests +// substitute a fake without standing up a real coderd. +type ExternalAgentClient interface { + Workspaces(ctx context.Context, filter codersdk.WorkspaceFilter) (codersdk.WorkspacesResponse, error) + WorkspaceExternalAgentCredentials(ctx context.Context, workspaceID uuid.UUID, agentName string) (codersdk.ExternalAgentCredentials, error) +} + const ( enumeratePageSize = 100 maxEnumerateRetries = 5 @@ -48,9 +59,10 @@ type ManagerOptions struct { // (via coder_external_agent tokens on workspaces matching opts.Template), then opens a dRPC stream per agent and keeps // them connected until ctx is canceled. type Manager struct { - client *codersdk.Client - logger slog.Logger - opts ManagerOptions + coderURL *url.URL + client ExternalAgentClient + logger slog.Logger + opts ManagerOptions mu sync.Mutex agents []*Agent @@ -58,12 +70,14 @@ type Manager struct { // NewManager returns an Agent Manager. The provided client must already be authenticated with sufficient privilege // to list workspaces by template and to call the enterprise-only WorkspaceExternalAgentCredentials endpoint -// (template-admin or higher; FeatureWorkspaceExternalAgent must be enabled). -func NewManager(client *codersdk.Client, logger slog.Logger, opts ManagerOptions) *Manager { +// (template-admin or higher; FeatureWorkspaceExternalAgent must be enabled). coderURL is the URL the spawned +// fake agents will dial. +func NewManager(coderURL *url.URL, client ExternalAgentClient, logger slog.Logger, opts ManagerOptions) *Manager { return &Manager{ - client: client, - logger: logger, - opts: opts, + coderURL: coderURL, + client: client, + logger: logger, + opts: opts, } } @@ -84,7 +98,7 @@ func (m *Manager) Run(ctx context.Context) error { agents := make([]*Agent, 0, len(tokens)) for i, ti := range tokens { - agents = append(agents, NewAgent(m.client.URL, ti.Token, + agents = append(agents, NewAgent(m.coderURL, ti.Token, m.logger.Named("agent-"+strconv.Itoa(i)))) } m.mu.Lock() diff --git a/enterprise/scaletest/agentfake/manager_test.go b/enterprise/scaletest/agentfake/manager_test.go index 598729909f..769a773b1f 100644 --- a/enterprise/scaletest/agentfake/manager_test.go +++ b/enterprise/scaletest/agentfake/manager_test.go @@ -2,76 +2,131 @@ package agentfake_test import ( "context" - "database/sql" + "net/http" + "net/url" "sort" "testing" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" - "github.com/coder/coder/v2/coderd/coderdtest" - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" - "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/coder/v2/enterprise/scaletest/agentfake" - sdkproto "github.com/coder/coder/v2/provisionersdk/proto" "github.com/coder/coder/v2/testutil" ) -// Asserts the TokenInfo shape (workspace IDs, agent names, tokens) returned by the enumeration loop. +// fakeExternalAgentClient is an in-package fake for the +// ExternalAgentClient interface used by +// Manager.EnumerateExternalAgents. Tests populate workspaces / +// credentials / workspacesErr before calling the Manager. +type fakeExternalAgentClient struct { + // workspaces, in the order Workspaces() should return them. Each + // call returns up to filter.Limit entries starting at filter.Offset + // to model pagination, matching real coderd behavior. + workspaces []codersdk.Workspace + // credentials, keyed by "{workspaceID}/{agentName}". A nil entry + // causes WorkspaceExternalAgentCredentials to error with notFoundErr. + credentials map[string]codersdk.ExternalAgentCredentials + + // workspacesErr, if non-nil, is returned from every Workspaces call. + workspacesErr error +} + +func (f *fakeExternalAgentClient) Workspaces(_ context.Context, filter codersdk.WorkspaceFilter) (codersdk.WorkspacesResponse, error) { + if f.workspacesErr != nil { + return codersdk.WorkspacesResponse{}, f.workspacesErr + } + start := filter.Offset + if start > len(f.workspaces) { + start = len(f.workspaces) + } + end := start + filter.Limit + if end > len(f.workspaces) { + end = len(f.workspaces) + } + page := f.workspaces[start:end] + return codersdk.WorkspacesResponse{ + Workspaces: page, + Count: len(f.workspaces), + }, nil +} + +func (f *fakeExternalAgentClient) WorkspaceExternalAgentCredentials(_ context.Context, wsID uuid.UUID, agentName string) (codersdk.ExternalAgentCredentials, error) { + key := wsID.String() + "/" + agentName + creds, ok := f.credentials[key] + if !ok { + return codersdk.ExternalAgentCredentials{}, xerrors.Errorf("no credentials for %s", key) + } + return creds, nil +} + +// externalAgentWorkspace returns a codersdk.Workspace whose latest +// build has HasExternalAgent=true and one agent with the given name. +func externalAgentWorkspace(t *testing.T, name, agentName string) (codersdk.Workspace, uuid.UUID) { + t.Helper() + wsID := uuid.New() + agentID := uuid.New() + hasExternal := true + return codersdk.Workspace{ + ID: wsID, + Name: name, + LatestBuild: codersdk.WorkspaceBuild{ + HasExternalAgent: &hasExternal, + Resources: []codersdk.WorkspaceResource{{ + Name: "external", + Type: "coder_external_agent", + Agents: []codersdk.WorkspaceAgent{{ + ID: agentID, + Name: agentName, + }}, + }}, + }, + }, agentID +} + +// Asserts the TokenInfo shape (workspace IDs, agent names, tokens) +// returned by the enumeration loop given a fake client. func Test_Manager_EnumerateExternalAgents_returnsAllTokens(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitLong) - - client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{ - codersdk.FeatureWorkspaceExternalAgent: 1, - }, - }, - }) + ctx := testutil.Context(t, testutil.WaitShort) const numWorkspaces = 3 - first := buildExternalAgentWorkspace(t, db, user, uuid.Nil) - templateID := first.Workspace.TemplateID - want := []agentfake.TokenInfo{{ - WorkspaceID: first.Workspace.ID, - WorkspaceName: first.Workspace.Name, - AgentID: first.Agents[0].ID, - AgentName: first.Agents[0].Name, - Token: first.AgentToken, - }} - for i := 1; i < numWorkspaces; i++ { - r := buildExternalAgentWorkspace(t, db, user, templateID) + workspaces := make([]codersdk.Workspace, 0, numWorkspaces) + credentials := map[string]codersdk.ExternalAgentCredentials{} + want := make([]agentfake.TokenInfo, 0, numWorkspaces) + for i := 0; i < numWorkspaces; i++ { + agentName := "external" + ws, agentID := externalAgentWorkspace(t, "ws-"+uuid.NewString(), agentName) + workspaces = append(workspaces, ws) + token := uuid.NewString() + credentials[ws.ID.String()+"/"+agentName] = codersdk.ExternalAgentCredentials{ + AgentToken: token, + } want = append(want, agentfake.TokenInfo{ - WorkspaceID: r.Workspace.ID, - WorkspaceName: r.Workspace.Name, - AgentID: r.Agents[0].ID, - AgentName: r.Agents[0].Name, - Token: r.AgentToken, + WorkspaceID: ws.ID, + WorkspaceName: ws.Name, + AgentID: agentID, + AgentName: agentName, + Token: token, }) } - tmpl, err := client.Template(ctx, templateID) - require.NoError(t, err) - + client := &fakeExternalAgentClient{workspaces: workspaces, credentials: credentials} + coderURL, _ := url.Parse("http://fake") logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - m := agentfake.NewManager(client, logger, agentfake.ManagerOptions{Template: tmpl.Name}) + m := agentfake.NewManager(coderURL, client, logger, agentfake.ManagerOptions{Template: "tmpl"}) got, err := m.EnumerateExternalAgents(ctx) require.NoError(t, err) - // Order returned by coderd isn't guaranteed; sort both sides by WorkspaceID before comparing. sortTokenInfosByWorkspaceID(want) sortTokenInfosByWorkspaceID(got) - require.Equal(t, len(want), len(got), - "expected one TokenInfo per external-agent workspace under the template") + require.Equal(t, len(want), len(got), "expected one TokenInfo per external-agent workspace") for i := range want { assert.Equal(t, want[i].WorkspaceID, got[i].WorkspaceID, "WorkspaceID for entry %d", i) assert.Equal(t, want[i].AgentName, got[i].AgentName, "AgentName for entry %d", i) @@ -80,109 +135,25 @@ func Test_Manager_EnumerateExternalAgents_returnsAllTokens(t *testing.T) { } } -// Heavier-weight integration test for the agentfake harness: builds 5 external agents, sets up the client/Manager, -// and asserts that each of the agents the Manager sees via its enumeration function is properly connected and Ready. -func TestManager_FiveAgentsHeartbeat(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - - client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{ - codersdk.FeatureWorkspaceExternalAgent: 1, - }, - }, - }) - - const numAgents = 5 - first := buildExternalAgentWorkspace(t, db, user, uuid.Nil) - templateID := first.Workspace.TemplateID - workspaceIDs := []uuid.UUID{first.Workspace.ID} - for i := 1; i < numAgents; i++ { - r := buildExternalAgentWorkspace(t, db, user, templateID) - workspaceIDs = append(workspaceIDs, r.Workspace.ID) - } - - tmpl, err := client.Template(ctx, templateID) - require.NoError(t, err) - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - manager := agentfake.NewManager(client, logger, agentfake.ManagerOptions{ - Template: tmpl.Name, - }) - t.Cleanup(func() { manager.Close() }) - - managerCtx, cancelManager := context.WithCancel(ctx) - t.Cleanup(cancelManager) - - managerErr := make(chan error, 1) - go func() { - managerErr <- manager.Run(managerCtx) - }() - - // Each workspace's agent must reach Connected. Share the outer test ctx (testutil.WaitLong) across all five waiters - // so the total wait is bounded. - for _, wsID := range workspaceIDs { - coderdtest.NewWorkspaceAgentWaiter(t, client, wsID).WithContext(ctx).Wait() - } - - // Each workspace's agent must also reach Lifecycle=ready. The fake sends UpdateLifecycle(READY) once per dRPC - // connect; coderd persists that and exposes it on the agent. - for _, wsID := range workspaceIDs { - require.Eventually(t, func() bool { - ws, err := client.Workspace(ctx, wsID) - if err != nil { - return false - } - for _, res := range ws.LatestBuild.Resources { - for _, agent := range res.Agents { - if agent.LifecycleState != codersdk.WorkspaceAgentLifecycleReady { - return false - } - } - } - return true - }, testutil.WaitLong, testutil.IntervalFast, - "agent never reached Lifecycle=ready in workspace %s", wsID) - } - - // Cleanly stop the Manager and confirm it exits without a non-context error. - cancelManager() - select { - case err := <-managerErr: - if err != nil { - t.Fatalf("Manager.Run returned unexpected error: %v", err) - } - case <-ctx.Done(): - t.Fatalf("timed out waiting for Manager.Run to return: %v", ctx.Err()) - } -} - -// Asserts that an authentication failure during enumeration produces a fatal error, so the retry loop in -// enumerateWithRetry surfaces it immediately rather than hammering endpoints with credentials that will never work. +// Asserts that an authentication failure during enumeration produces a +// fatal error, so the retry loop in enumerateWithRetry surfaces it +// immediately rather than hammering endpoints with credentials that +// will never work. func Test_Manager_EnumerateExternalAgents_invalidTokenIsFatal(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitLong) - - client, db := coderdtest.NewWithDatabase(t, nil) - user := coderdtest.CreateFirstUser(t, client) - - r := buildExternalAgentWorkspace(t, db, user, uuid.Nil) - tmpl, err := client.Template(ctx, r.Workspace.TemplateID) - require.NoError(t, err) - - // Replace the client's session token with garbage to provoke a 401 from coderd's workspace-list endpoint. - // The Manager should surface that as a fatal error. - client.SetSessionToken("not-a-valid-session-token") + ctx := testutil.Context(t, testutil.WaitShort) + client := &fakeExternalAgentClient{ + workspacesErr: codersdk.NewError(http.StatusUnauthorized, codersdk.Response{Message: "unauthorized"}), + } + coderURL, _ := url.Parse("http://fake") logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - m := agentfake.NewManager(client, logger, agentfake.ManagerOptions{Template: tmpl.Name}) + m := agentfake.NewManager(coderURL, client, logger, agentfake.ManagerOptions{Template: "tmpl"}) - _, err = m.EnumerateExternalAgents(ctx) + _, err := m.EnumerateExternalAgents(ctx) require.Error(t, err, "expected enumeration to fail with an invalid session token") require.True(t, agentfake.IsFatalEnumerationError(err), - "expected error to be classified as fatal so the harness exits and Kubernetes can restart it; got: %v", err) + "expected error to be classified as fatal; got: %v", err) } func sortTokenInfosByWorkspaceID(s []agentfake.TokenInfo) { @@ -190,33 +161,3 @@ func sortTokenInfosByWorkspaceID(s []agentfake.TokenInfo) { return s[i].WorkspaceID.String() < s[j].WorkspaceID.String() }) } - -// buildExternalAgentWorkspace creates one workspace with a coder_external_agent resource, an agent, and -// HasExternalAgent=true on the latest build. If templateID is uuid.Nil, dbfake mints a fresh template (and the caller -// can pass the returned Workspace.TemplateID into subsequent calls to share the template). -func buildExternalAgentWorkspace( - t *testing.T, - db database.Store, - user codersdk.CreateFirstUserResponse, - templateID uuid.UUID, -) dbfake.WorkspaceResponse { - t.Helper() - - ws := database.WorkspaceTable{ - OrganizationID: user.OrganizationID, - OwnerID: user.UserID, - } - if templateID != uuid.Nil { - ws.TemplateID = templateID - } - return dbfake.WorkspaceBuild(t, db, ws). - Seed(database.WorkspaceBuild{ - HasExternalAgent: sql.NullBool{Bool: true, Valid: true}, - }). - Resource(&sdkproto.Resource{ - Name: "external", - Type: "coder_external_agent", - }). - WithAgent(). - Do() -} diff --git a/flake.nix b/flake.nix index 2465f94fac..e47b078777 100644 --- a/flake.nix +++ b/flake.nix @@ -111,13 +111,13 @@ # Keep Terraform aligned with provisioner/terraform/testdata/version.txt # so `make gen` remains deterministic in Nix shells. - terraform_1_15_2 = + terraform_1_15_5 = if pkgs.stdenv.isLinux && pkgs.stdenv.hostPlatform.isx86_64 then - pkgs.runCommand "terraform-1.15.2" { + pkgs.runCommand "terraform-1.15.5" { nativeBuildInputs = [ pkgs.unzip ]; src = pkgs.fetchurl { - url = "https://releases.hashicorp.com/terraform/1.15.2/terraform_1.15.2_linux_amd64.zip"; - hash = "sha256-xW/yvH5s6bOHmlA5KwPC6gdLR2iL9QP/lmyH+wGyqrg="; + url = "https://releases.hashicorp.com/terraform/1.15.5/terraform_1.15.5_linux_amd64.zip"; + hash = "sha256-cCshNq9nKMj/A3+EPdLbzit62IeGtzgdHXKu+iUPYBw="; }; } '' mkdir -p "$out/bin" @@ -208,7 +208,7 @@ # sqlc sqlc-custom syft - terraform_1_15_2 + terraform_1_15_5 typos which # Needed for many LD system libs! @@ -295,14 +295,6 @@ lib.optionalDrvAttr stdenv.isLinux "${glibcLocales}/lib/locale/locale-archive"; NODE_OPTIONS = "--max-old-space-size=8192"; - BIOME_BINARY = - if pkgs.stdenv.isLinux then - if pkgs.stdenv.hostPlatform.isAarch64 then - "@biomejs/cli-linux-arm64-musl/biome" - else - "@biomejs/cli-linux-x64-musl/biome" - else - ""; GOPRIVATE = "coder.com,cdr.dev,go.coder.com,github.com/cdr,github.com/coder"; }; }; diff --git a/go.mod b/go.mod index c84156c8c0..e13cebfcf3 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ replace github.com/tcnksm/go-httpstat => github.com/coder/go-httpstat v0.0.0-202 // There are a few minor changes we make to Tailscale that we're slowly upstreaming. Compare here: // https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main -replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20260519043957-6f014ff9434f +replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20260529105257-b7c5fc6e6399 // This is replaced to include // 1. a fix for a data race: c.f. https://github.com/tailscale/wireguard-go/pull/25 @@ -192,7 +192,7 @@ require ( github.com/mocktools/go-smtp-mock/v2 v2.5.0 github.com/muesli/termenv v0.16.0 github.com/natefinch/atomic v1.0.1 - github.com/open-policy-agent/opa v1.11.0 + github.com/open-policy-agent/opa v1.17.0 github.com/ory/dockertest/v3 v3.12.0 github.com/pion/udp v0.1.4 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c @@ -220,8 +220,8 @@ require ( github.com/zclconf/go-cty-yaml v1.2.0 go.nhat.io/otelsql v0.16.0 go.opentelemetry.io/otel v1.43.0 - go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 go.opentelemetry.io/otel/sdk v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 go.uber.org/atomic v1.11.0 @@ -327,7 +327,6 @@ require ( github.com/fxamacker/cbor/v2 v2.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.12 github.com/go-chi/hostrouter v0.3.0 // indirect - github.com/go-ini/ini v1.67.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-openapi/jsonpointer v0.22.4 // indirect @@ -417,7 +416,7 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect - github.com/prometheus/procfs v0.19.2 // indirect + github.com/prometheus/procfs v0.20.1 // indirect github.com/rcrowley/go-metrics v0.0.0-20250401214520-65e299d6c5c9 // indirect github.com/riandyrn/otelchi v0.5.1 // indirect github.com/richardartoul/molecule v1.0.1-0.20240531184615-7ca0df43c0b3 // indirect @@ -468,7 +467,7 @@ require ( go.opentelemetry.io/contrib v1.19.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 go.opentelemetry.io/otel/metric v1.43.0 // indirect - go.opentelemetry.io/proto/otlp v1.9.0 // indirect + go.opentelemetry.io/proto/otlp v1.10.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.1 // indirect go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect @@ -478,9 +477,9 @@ require ( golang.zx2c4.com/wireguard/windows v0.5.3 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260511170946-3700d4141b60 // indirect - gopkg.in/ini.v1 v1.67.1 // indirect + gopkg.in/ini.v1 v1.67.2 // indirect howett.net/plist v1.0.1 // indirect kernel.org/pub/linux/libs/security/libcap/psx v1.2.77 // indirect sigs.k8s.io/yaml v1.6.0 // indirect @@ -512,11 +511,15 @@ require ( github.com/danieljoos/wincred v1.2.3 github.com/dgraph-io/ristretto/v2 v2.4.0 github.com/elazarl/goproxy v1.8.0 + github.com/elimity-com/scim v0.0.0-20260506142751-830e1caafcc3 github.com/fsnotify/fsnotify v1.10.1 github.com/go-git/go-git/v5 v5.19.1 github.com/invopop/jsonschema v0.14.0 github.com/mark3labs/mcp-go v0.38.0 + github.com/nats-io/nats-server/v2 v2.12.8 + github.com/nats-io/nats.go v1.51.0 github.com/openai/openai-go/v3 v3.28.0 + github.com/scim2/filter-parser/v2 v2.2.0 github.com/shopspring/decimal v1.4.0 github.com/smallstep/pkcs7 v0.2.1 github.com/sony/gobreaker/v2 v2.4.0 @@ -546,6 +549,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 // indirect github.com/Masterminds/semver/v3 v3.4.0 // indirect github.com/alecthomas/chroma v0.10.0 // indirect + github.com/antithesishq/antithesis-sdk-go v0.6.0-default-no-op // indirect github.com/aquasecurity/go-version v0.0.1 // indirect github.com/aquasecurity/iamgo v0.0.10 // indirect github.com/aquasecurity/jfather v0.0.8 // indirect @@ -572,7 +576,9 @@ require ( github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/daixiang0/gci v0.13.7 // indirect - github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.1 // indirect + github.com/di-wu/parser v0.2.2 // indirect + github.com/di-wu/xsd-datetime v1.0.0 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/envoyproxy/go-control-plane/envoy v1.37.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.3.3 // indirect @@ -587,9 +593,10 @@ require ( github.com/go-openapi/swag/typeutils v0.25.4 // indirect github.com/go-openapi/swag/yamlutils v0.25.4 // indirect github.com/go-sql-driver/mysql v1.9.3 // indirect - github.com/goccy/go-json v0.10.5 // indirect + github.com/goccy/go-json v0.10.6 // indirect github.com/goccy/go-yaml v1.19.2 // indirect github.com/google/go-containerregistry v0.20.7 // indirect + github.com/google/go-tpm v0.9.8 // indirect github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.72 // indirect github.com/hashicorp/go-getter v1.8.6 // indirect @@ -605,17 +612,20 @@ require ( github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/landlock-lsm/go-landlock v0.0.0-20251103212306-430f8e5cd97c // indirect github.com/lestrrat-go/blackmagic v1.0.4 // indirect - github.com/lestrrat-go/dsig v1.0.0 // indirect + github.com/lestrrat-go/dsig v1.2.1 // indirect github.com/lestrrat-go/dsig-secp256k1 v1.0.0 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect - github.com/lestrrat-go/httprc/v3 v3.0.1 // indirect - github.com/lestrrat-go/jwx/v3 v3.0.12 // indirect - github.com/lestrrat-go/option v1.0.1 // indirect + github.com/lestrrat-go/httprc/v3 v3.0.5 // indirect + github.com/lestrrat-go/jwx/v3 v3.1.1 // indirect github.com/lestrrat-go/option/v2 v2.0.0 // indirect github.com/mattn/go-shellwords v1.0.12 // indirect + github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 // indirect github.com/moby/moby/api v1.54.0 // indirect github.com/moby/moby/client v0.3.0 // indirect github.com/moby/sys/user v0.4.0 // indirect + github.com/nats-io/jwt/v2 v2.8.1 // indirect + github.com/nats-io/nkeys v0.4.15 // indirect + github.com/nats-io/nuid v1.0.1 // indirect github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect github.com/openai/openai-go v1.12.0 // indirect github.com/package-url/packageurl-go v0.1.3 // indirect @@ -634,8 +644,8 @@ require ( github.com/tmaxmax/go-sse v0.11.0 // indirect github.com/ulikunitz/xz v0.5.15 // indirect github.com/urfave/cli/v2 v2.27.5 // indirect - github.com/valyala/fastjson v1.6.4 // indirect - github.com/vektah/gqlparser/v2 v2.5.31 // indirect + github.com/valyala/fastjson v1.6.10 // indirect + github.com/vektah/gqlparser/v2 v2.5.33 // indirect github.com/xhit/go-str2duration/v2 v2.1.0 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect @@ -643,7 +653,7 @@ require ( go.opentelemetry.io/contrib/detectors/gcp v1.42.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect - go.yaml.in/yaml/v2 v2.4.3 // indirect + go.yaml.in/yaml/v2 v2.4.4 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect go.yaml.in/yaml/v4 v4.0.0-rc.3 // indirect golang.org/x/telemetry v0.0.0-20260508192327-42602be52be6 // indirect diff --git a/go.sum b/go.sum index 314cc3fa1f..e639f168f3 100644 --- a/go.sum +++ b/go.sum @@ -132,6 +132,8 @@ github.com/andybalholm/brotli v1.2.1 h1:R+f5xP285VArJDRgowrfb9DqL18yVK0gKAW/F+eT github.com/andybalholm/brotli v1.2.1/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/antithesishq/antithesis-sdk-go v0.6.0-default-no-op h1:kpBdlEPbRvff0mDD1gk7o9BhI16b9p5yYAXRlidpqJE= +github.com/antithesishq/antithesis-sdk-go v0.6.0-default-no-op/go.mod h1:IUpT2DPAKh6i/YhSbt6Gl3v2yvUZjmKncl7U91fup7E= github.com/apparentlymart/go-cidr v1.1.0 h1:2mAhrMoF+nhXqxTzSZMUzDHkLjmIHC+Zzn4tdgBZjnU= github.com/apparentlymart/go-cidr v1.1.0/go.mod h1:EBcsNrHc3zQeuaeCeCtQruQm+n9/YjEn/vI25Lg7Gwc= github.com/apparentlymart/go-textseg/v12 v12.0.0/go.mod h1:S/4uRK2UtaQttw1GenVJEynmyUenKwP++x/+DdGV/Ec= @@ -255,8 +257,8 @@ github.com/brianvoe/gofakeit/v7 v7.15.0 h1:kGLYAWN8tnmxq2PelKVK6zwpM7kMxdz9SGPH3 github.com/brianvoe/gofakeit/v7 v7.15.0/go.mod h1:QXuPeBw164PJCzCUZVmgpgHJ3Llj49jSLVkKPMtxtxA= github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk= github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= -github.com/bytecodealliance/wasmtime-go/v39 v39.0.1 h1:RibaT47yiyCRxMOj/l2cvL8cWiWBSqDXHyqsa9sGcCE= -github.com/bytecodealliance/wasmtime-go/v39 v39.0.1/go.mod h1:miR4NYIEBXeDNamZIzpskhJ0z/p8al+lwMWylQ/ZJb4= +github.com/bytecodealliance/wasmtime-go/v44 v44.0.0 h1:WRZXnLPIer/TWs5aYPaMlmVcOlzmR6Ur6wjLRIQOhTQ= +github.com/bytecodealliance/wasmtime-go/v44 v44.0.0/go.mod h1:GP93piU+39CoFVCQ5xfHrPOUtL0APlMnkbblJ2d3YY0= github.com/cakturk/go-netstat v0.0.0-20200220111822-e5b49efee7a5 h1:BjkPE3785EwPhhyuFkbINB+2a1xATwk8SNDWnJiD41g= github.com/cakturk/go-netstat v0.0.0-20200220111822-e5b49efee7a5/go.mod h1:jtAfVaU/2cu1+wdSRPWE2c1N2qeAA3K4RH9pYgqwets= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= @@ -348,8 +350,8 @@ github.com/coder/serpent v0.15.0 h1:jobR7DnPsxzEMD0cRiailwlY+4v6HAPS/8emIgBpaIU= github.com/coder/serpent v0.15.0/go.mod h1:7OIvFBYMd+OqarMy5einBl8AtRr8LliopVU7pyrwucY= github.com/coder/ssh v0.0.0-20231128192721-70855dedb788 h1:YoUSJ19E8AtuUFVYBpXuOD6a/zVP3rcxezNsoDseTUw= github.com/coder/ssh v0.0.0-20231128192721-70855dedb788/go.mod h1:aGQbuCLyhRLMzZF067xc84Lh7JDs1FKwCmF1Crl9dxQ= -github.com/coder/tailscale v1.1.1-0.20260519043957-6f014ff9434f h1:gYivllu5CHhvRr4SM93zSQDj9cG2V+Pc0URTFy3fF/Y= -github.com/coder/tailscale v1.1.1-0.20260519043957-6f014ff9434f/go.mod h1:WTWP5ZNODDXHwWlQ1Jc2MFhqxu93pUs7lIy28Fd5a5E= +github.com/coder/tailscale v1.1.1-0.20260529105257-b7c5fc6e6399 h1:4IhFSmu0DSfWrvmHCb8aXDjWqSEYoIDA1L7Ar82Dm84= +github.com/coder/tailscale v1.1.1-0.20260529105257-b7c5fc6e6399/go.mod h1:IatCC3hlq/ncu6DjZ+GJ/hNjSf5TmO+Xtc6B20k0q/c= github.com/coder/terraform-config-inspect v0.0.0-20250107175719-6d06d90c630e h1:JNLPDi2P73laR1oAclY6jWzAbucf70ASAvf5mh2cME0= github.com/coder/terraform-config-inspect v0.0.0-20250107175719-6d06d90c630e/go.mod h1:Gz/z9Hbn+4KSp8A2FBtNszfLSdT2Tn/uAKGuVqqWmDI= github.com/coder/terraform-provider-coder/v2 v2.18.0 h1:b60ixwf7pVPuiL0GkHZf+1mVj94/HZhCNpsfjAK34mI= @@ -405,10 +407,10 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dblohm7/wingoes v0.0.0-20240820181039-f2b84150679e h1:L+XrFvD0vBIBm+Wf9sFN6aU395t7JROoai0qXZraA4U= github.com/dblohm7/wingoes v0.0.0-20240820181039-f2b84150679e/go.mod h1:SUxUaAK/0UG5lYyZR1L1nC4AaYYvSSYTWQSH3FPcxKU= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= -github.com/dgraph-io/badger/v4 v4.8.0 h1:JYph1ChBijCw8SLeybvPINizbDKWZ5n/GYbz2yhN/bs= -github.com/dgraph-io/badger/v4 v4.8.0/go.mod h1:U6on6e8k/RTbUWxqKR0MvugJuVmkxSNc79ap4917h4w= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.1 h1:5RVFMOWjMyRy8cARdy79nAmgYw3hK/4HUq48LQ6Wwqo= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.1/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= +github.com/dgraph-io/badger/v4 v4.9.1 h1:DocZXZkg5JJHJPtUErA0ibyHxOVUDVoXLSCV6t8NC8w= +github.com/dgraph-io/badger/v4 v4.9.1/go.mod h1:5/MEx97uzdPUHR4KtkNt8asfI2T4JiEiQlV7kWUo8c0= github.com/dgraph-io/ristretto/v2 v2.4.0 h1:I/w09yLjhdcVD2QV192UJcq8dPBaAJb9pOuMyNy0XlU= github.com/dgraph-io/ristretto/v2 v2.4.0/go.mod h1:0KsrXtXvnv0EqnzyowllbVJB8yBonswa2lTCK2gGo9E= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= @@ -418,6 +420,10 @@ github.com/dgryski/trifles v0.0.0-20230903005119-f50d829f2e54 h1:SG7nF6SRlWhcT7c github.com/dgryski/trifles v0.0.0-20230903005119-f50d829f2e54/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= github.com/dhui/dktest v0.4.6 h1:+DPKyScKSEp3VLtbMDHcUq6V5Lm5zfZZVb0Sk7Ahom4= github.com/dhui/dktest v0.4.6/go.mod h1:JHTSYDtKkvFNFHJKqCzVzqXecyv+tKt8EzceOmQOgbU= +github.com/di-wu/parser v0.2.2 h1:I9oHJ8spBXOeL7Wps0ffkFFFiXJf/pk7NX9lcAMqRMU= +github.com/di-wu/parser v0.2.2/go.mod h1:SLp58pW6WamdmznrVRrw2NTyn4wAvT9rrEFynKX7nYo= +github.com/di-wu/xsd-datetime v1.0.0 h1:vZoGNkbzpBNoc+JyfVLEbutNDNydYV8XwHeV7eUJoxI= +github.com/di-wu/xsd-datetime v1.0.0/go.mod h1:i3iEhrP3WchwseOBeIdW/zxeoleXTOzx1WyDXgdmOww= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= @@ -448,6 +454,8 @@ github.com/elastic/go-windows v1.0.0 h1:qLURgZFkkrYyTTkvYpsZIgf83AUsdIHfvlJaqaZ7 github.com/elastic/go-windows v1.0.0/go.mod h1:TsU0Nrp7/y3+VwE82FoZF8gC/XFg/Elz6CcloAxnPgU= github.com/elazarl/goproxy v1.8.0 h1:dt561rX7UAYMeFRLtzFx6uQGl2TpL1dr6uCG23nFQSY= github.com/elazarl/goproxy v1.8.0/go.mod h1:b5xm6W48AUHNpRTCvlnd0YVh+JafCCtsLsJZvvNTz+E= +github.com/elimity-com/scim v0.0.0-20260506142751-830e1caafcc3 h1:P+JJLBS2QNe5aWBpNoDWqmGwNv/DKP+WZpU/mPIS+28= +github.com/elimity-com/scim v0.0.0-20260506142751-830e1caafcc3/go.mod h1:JkjcmqbLW+khwt2fmBPJFBhx2zGZ8XobRZ+O0VhlwWo= github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 h1:OJyUGMJTzHTd1XQp98QTaHernxMYzRaOasRir9hUlFQ= github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= github.com/emersion/go-smtp v0.21.2 h1:OLDgvZKuofk4em9fT5tFG5j4jE1/hXnX75UMvcrL4AA= @@ -483,8 +491,8 @@ github.com/fergusstrange/embedded-postgres v1.34.0 h1:c6RKhPKFsLVU+Tdxsx8q0UxCHs github.com/fergusstrange/embedded-postgres v1.34.0/go.mod h1:w0YvnCgf19o6tskInrOOACtnqfVlOvluz3hlNLY7tRk= github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= -github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7DlmewI= -github.com/foxcpp/go-mockdns v1.1.0/go.mod h1:IhLeSFGed3mJIAXPH2aiRQB+kqz7oqu8ld2qVbOu7Wk= +github.com/foxcpp/go-mockdns v1.2.0 h1:omK3OrHRD1IWJz1FuFBCFquhXslXoF17OvBS6JPzZF0= +github.com/foxcpp/go-mockdns v1.2.0/go.mod h1:IhLeSFGed3mJIAXPH2aiRQB+kqz7oqu8ld2qVbOu7Wk= github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= @@ -516,8 +524,6 @@ github.com/go-git/go-billy/v5 v5.9.0 h1:jItGXszUDRtR/AlferWPTMN4j38BQ88XnXKbilmm github.com/go-git/go-billy/v5 v5.9.0/go.mod h1:jCnQMLj9eUgGU7+ludSTYoZL/GGmii14RxKFj7ROgHw= github.com/go-git/go-git/v5 v5.19.1 h1:nX27AnaU43/K5bKktKwgBmR9lawoYVe1Ckg0rgzzN00= github.com/go-git/go-git/v5 v5.19.1/go.mod h1:Pb1v0c7/g8aGQJwx9Us09W85yGoyvSwuhEGMH7zjDKQ= -github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= -github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU= @@ -589,8 +595,8 @@ github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= -github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= -github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/goccy/go-json v0.10.6 h1:p8HrPJzOakx/mn/bQtjgNjdTcN+/S6FcG2CTtQOrHVU= +github.com/goccy/go-json v0.10.6/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -655,6 +661,8 @@ github.com/google/go-github/v61 v61.0.0 h1:VwQCBwhyE9JclCI+22/7mLB1PuU9eowCXKY5p github.com/google/go-github/v61 v61.0.0/go.mod h1:0WR+KmsWX75G2EbpyGsGmradjo3IiciuI4BmdVCobQY= github.com/google/go-querystring v1.2.0 h1:yhqkPbu2/OH+V9BfpCVPZkNmUXhb2gBxJArfhIxNtP0= github.com/google/go-querystring v1.2.0/go.mod h1:8IFJqpSRITyJ8QhQ13bmbeMBDfmeEJZD5A0egEOmkqU= +github.com/google/go-tpm v0.9.8 h1:slArAR9Ft+1ybZu0lBwpSmpwhRXaa85hWtMinMyRAWo= +github.com/google/go-tpm v0.9.8/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= @@ -832,18 +840,16 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lestrrat-go/blackmagic v1.0.4 h1:IwQibdnf8l2KoO+qC3uT4OaTWsW7tuRQXy9TRN9QanA= github.com/lestrrat-go/blackmagic v1.0.4/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw= -github.com/lestrrat-go/dsig v1.0.0 h1:OE09s2r9Z81kxzJYRn07TFM9XA4akrUdoMwr0L8xj38= -github.com/lestrrat-go/dsig v1.0.0/go.mod h1:dEgoOYYEJvW6XGbLasr8TFcAxoWrKlbQvmJgCR0qkDo= +github.com/lestrrat-go/dsig v1.2.1 h1:MwxzZhE4+4fguHi+uDALKVlC3Cn+O1QU1Q/F8D7hVIc= +github.com/lestrrat-go/dsig v1.2.1/go.mod h1:RD2eOaidyPvpc7IJQoO3Qq52RWdy8ZcJs8lrOnoa1Kc= github.com/lestrrat-go/dsig-secp256k1 v1.0.0 h1:JpDe4Aybfl0soBvoVwjqDbp+9S1Y2OM7gcrVVMFPOzY= github.com/lestrrat-go/dsig-secp256k1 v1.0.0/go.mod h1:CxUgAhssb8FToqbL8NjSPoGQlnO4w3LG1P0qPWQm/NU= github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= -github.com/lestrrat-go/httprc/v3 v3.0.1 h1:3n7Es68YYGZb2Jf+k//llA4FTZMl3yCwIjFIk4ubevI= -github.com/lestrrat-go/httprc/v3 v3.0.1/go.mod h1:2uAvmbXE4Xq8kAUjVrZOq1tZVYYYs5iP62Cmtru00xk= -github.com/lestrrat-go/jwx/v3 v3.0.12 h1:p25r68Y4KrbBdYjIsQweYxq794CtGCzcrc5dGzJIRjg= -github.com/lestrrat-go/jwx/v3 v3.0.12/go.mod h1:HiUSaNmMLXgZ08OmGBaPVvoZQgJVOQphSrGr5zMamS8= -github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= -github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= +github.com/lestrrat-go/httprc/v3 v3.0.5 h1:S+Mb4L2I+bM6JGTibLmxExhyTOqnXjqx+zi9MoXw/TM= +github.com/lestrrat-go/httprc/v3 v3.0.5/go.mod h1:mSMtkZW92Z98M5YoNNztbRGxbXHql7tSitCvaxvo9l0= +github.com/lestrrat-go/jwx/v3 v3.1.1 h1:yd9AdPmZ4INnQ7k42IrzXYpnEG803+SrQ6hdMvzHJzw= +github.com/lestrrat-go/jwx/v3 v3.1.1/go.mod h1:uw/MN2M/Xiu4FhwcIwH11Zsh9JWx9SWzgALl7/uIEkU= github.com/lestrrat-go/option/v2 v2.0.0 h1:XxrcaJESE1fokHy3FpaQ/cXW8ZsIdWcdFzzLOcID3Ss= github.com/lestrrat-go/option/v2 v2.0.0/go.mod h1:oSySsmzMoR0iRzCDCaUfsCzxQHUEuhOViQObyy7S6Vg= github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4= @@ -891,6 +897,8 @@ github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwX github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA= github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= +github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 h1:KGuD/pM2JpL9FAYvBrnBBeENKZNh6eNtjqytV6TYjnk= +github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76/go.mod h1:GGYsuwP/fPD6Y9hMiXuapVvlIUEhFhMTh0rxU3ik1LQ= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= @@ -949,6 +957,16 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A= github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= +github.com/nats-io/jwt/v2 v2.8.1 h1:V0xpGuD/N8Mi+fQNDynXohVvp7ZztevW5io8CUWlPmU= +github.com/nats-io/jwt/v2 v2.8.1/go.mod h1:nWnOEEiVMiKHQpnAy4eXlizVEtSfzacZ1Q43LIRavZg= +github.com/nats-io/nats-server/v2 v2.12.8 h1:R6CyZl6cWXTkS9lwMnDxjJsUezoW+hAD+SkdcSOf4DI= +github.com/nats-io/nats-server/v2 v2.12.8/go.mod h1:VmV5LcQmqUq8g1TX9VyEKqnxTR/05F6skTALlL8AsvQ= +github.com/nats-io/nats.go v1.51.0 h1:ByW84XTz6W03GSSsygsZcA+xgKK8vPGaa/FCAAEHnAI= +github.com/nats-io/nats.go v1.51.0/go.mod h1:26HypzazeOkyO3/mqd1zZd53STJN0EjCYF9Uy2ZOBno= +github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4= +github.com/nats-io/nkeys v0.4.15/go.mod h1:CpMchTXC9fxA5zrMo4KpySxNjiDVvr8ANOSZdiNfUrs= +github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/niklasfasching/go-org v1.9.1 h1:/3s4uTPOF06pImGa2Yvlp24yKXZoTYM+nsIlMzfpg/0= @@ -967,8 +985,8 @@ github.com/olekukonko/ll v0.1.6 h1:lGVTHO+Qc4Qm+fce/2h2m5y9LvqaW+DCN7xW9hsU3uA= github.com/olekukonko/ll v0.1.6/go.mod h1:NVUmjBb/aCtUpjKk75BhWrOlARz3dqsM+OtszpY4o88= github.com/olekukonko/tablewriter v1.1.4 h1:ORUMI3dXbMnRlRggJX3+q7OzQFDdvgbN9nVWj1drm6I= github.com/olekukonko/tablewriter v1.1.4/go.mod h1:+kedxuyTtgoZLwif3P1Em4hARJs+mVnzKxmsCL/C5RY= -github.com/open-policy-agent/opa v1.11.0 h1:eOd/jJrbavakiX477yT4LrXZfUWViAot/AsKsjsfe7o= -github.com/open-policy-agent/opa v1.11.0/go.mod h1:QimuJO4T3KYxWzrmAymqlFvsIanCjKrGjmmC8GgAdgE= +github.com/open-policy-agent/opa v1.17.0 h1:TMm6bCyb3CEL4wjXsXn1d/kBSBbjF+5sEIyzQvbJiEw= +github.com/open-policy-agent/opa v1.17.0/go.mod h1:lcuZYSlqQpXFzsA6EJCELmfR5+nNOpZYX+eo7xaIIlk= github.com/open-telemetry/opentelemetry-collector-contrib/pkg/sampling v0.120.1 h1:lK/3zr73guK9apbXTcnDnYrC0YCQ25V3CIULYz3k2xU= github.com/open-telemetry/opentelemetry-collector-contrib/pkg/sampling v0.120.1/go.mod h1:01TvyaK8x640crO2iFwW/6CFCZgNsOvOGH3B5J239m0= github.com/open-telemetry/opentelemetry-collector-contrib/processor/probabilisticsamplerprocessor v0.120.1 h1:TCyOus9tym82PD1VYtthLKMVMlVyRwtDI4ck4SR2+Ok= @@ -1037,8 +1055,8 @@ github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNw github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4= github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= -github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= -github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= +github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEycfc= +github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo= github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= github.com/quasilyte/go-ruleguard/dsl v0.3.23 h1:lxjt5B6ZCiBeeNO8/oQsegE6fLeCzuMRoVWSkXC4uvY= @@ -1067,6 +1085,8 @@ github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 h1:KRzFb2m7YtdldCEkzs6KqmJw4nqEV github.com/santhosh-tekuri/jsonschema/v6 v6.0.2/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU= github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b h1:gQZ0qzfKHQIybLANtM3mBXNUtOfsCFXeTsnBqCsx1KM= github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/scim2/filter-parser/v2 v2.2.0 h1:QGadEcsmypxg8gYChRSM2j1edLyE/2j72j+hdmI4BJM= +github.com/scim2/filter-parser/v2 v2.2.0/go.mod h1:jWnkDToqX/Y0ugz0P5VvpVEUKcWcyHHj+X+je9ce5JA= github.com/secure-systems-lab/go-securesystemslib v0.10.0 h1:l+H5ErcW0PAehBNrBxoGv1jjNpGYdZ9RcheFkB2WI14= github.com/secure-systems-lab/go-securesystemslib v0.10.0/go.mod h1:MRKONWmRoFzPNQ9USRF9i1mc7MvAVvF1LlW8X5VWDvk= github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0= @@ -1198,12 +1218,12 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.71.0 h1:tepR7H+Guh9VUqxxcPggYi8R3lGUu2Rsdh+z7/FCY3k= github.com/valyala/fasthttp v1.71.0/go.mod h1:z1sDUvOShhXq/C9mwH/fSm1Vb71tUJwmQdgkBrBNwnA= -github.com/valyala/fastjson v1.6.4 h1:uAUNq9Z6ymTgGhcm0UynUAB6tlbakBrz6CQFax3BXVQ= -github.com/valyala/fastjson v1.6.4/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= +github.com/valyala/fastjson v1.6.10 h1:/yjJg8jaVQdYR3arGxPE2X5z89xrlhS0eGXdv+ADTh4= +github.com/valyala/fastjson v1.6.10/go.mod h1:e6FubmQouUNP73jtMLmcbxS6ydWIpOfhz34TSfO3JaE= github.com/vbatts/tar-split v0.12.2 h1:w/Y6tjxpeiFMR47yzZPlPj/FcPLpXbTUi/9H7d3CPa4= github.com/vbatts/tar-split v0.12.2/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= -github.com/vektah/gqlparser/v2 v2.5.31 h1:YhWGA1mfTjID7qJhd1+Vxhpk5HTgydrGU9IgkWBTJ7k= -github.com/vektah/gqlparser/v2 v2.5.31/go.mod h1:c1I28gSOVNzlfc4WuDlqU7voQnsqI6OG2amkBAFmgts= +github.com/vektah/gqlparser/v2 v2.5.33 h1:lRp8aIeNUNbimf/axZd7ETg24q06hBtPaas+TcvI/7E= +github.com/vektah/gqlparser/v2 v2.5.33/go.mod h1:c1I28gSOVNzlfc4WuDlqU7voQnsqI6OG2amkBAFmgts= github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= @@ -1322,12 +1342,10 @@ go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1: go.opentelemetry.io/otel v1.3.0/go.mod h1:PWIKzi6JCp7sM0k9yZ43VX+T345uNbAkDKwHVjb2PTs= go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 h1:QKdN8ly8zEMrByybbQgv8cWBcdAarwmIPZ6FThrWXJs= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0/go.mod h1:bTdK1nhqF76qiPoCCdyFIV+N/sRHYXYCTQc+3VCi3MI= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 h1:DvJDOPmSWQHWywQS6lKL+pb8s3gBLOZUtw4N+mavW1I= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0/go.mod h1:EtekO9DEJb4/jRyN4v4Qjc2yA7AtfCBuz2FynRUWTXs= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 h1:3iZJKlCZufyRzPzlQhUIWVmfltrXuGyfjREgGP3UUjc= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0/go.mod h1:/G+nUPfhq2e+qiXMGxMwumDrP5jtzU+mWN7/sjT2rak= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 h1:88Y4s2C8oTui1LGM6bTWkw0ICGcOLCAI5l6zsD1j20k= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 h1:RAE+JPfvEmvy+0LzyUA25/SGawPwIUbZ6u0Wug54sLc= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0/go.mod h1:AGmbycVGEsRx9mXMZ75CsOyhSP6MFIcj/6dnG+vhVjk= go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0 h1:ZrPRak/kS4xI3AVXy8F7pipuDXmDsrO8Lg+yQjBLjw0= go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0/go.mod h1:3y6kQCWztq6hyW8Z9YxQDDm0Je9AJoFar2G0yDcmhRk= go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.37.0 h1:SNhVp/9q4Go/XHBkQ1/d5u9P/U+L1yaGPoi0x+mStaI= @@ -1342,8 +1360,8 @@ go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHS go.opentelemetry.io/otel/trace v1.3.0/go.mod h1:c/VDhno8888bvQYmbYLqe41/Ldmr/KKunbvWM4/fEjk= go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= -go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= -go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= +go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g= +go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= @@ -1355,8 +1373,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= -go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= +go.yaml.in/yaml/v2 v2.4.4 h1:tuyd0P+2Ont/d6e2rl3be67goVK4R6deVxCUX5vyPaQ= +go.yaml.in/yaml/v2 v2.4.4/go.mod h1:gMZqIpDtDqOfM0uNfy0SkpRhvUryYH0Z6wdMYcacYXQ= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= go.yaml.in/yaml/v4 v4.0.0-rc.3 h1:3h1fjsh1CTAPjW7q/EMe+C8shx5d8ctzZTrLcs/j8Go= @@ -1462,6 +1480,7 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= @@ -1534,8 +1553,8 @@ google.golang.org/genai v1.51.0 h1:IZGuUqgfx40INv3hLFGCbOSGp0qFqm7LVmDghzNIYqg= google.golang.org/genai v1.51.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0= google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I= -google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI= -google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= google.golang.org/genproto/googleapis/rpc v0.0.0-20260511170946-3700d4141b60 h1:seT2EwLWM78plQ7wcDfuWBc/4FAEAXDDiaSol4ku4qo= google.golang.org/genproto/googleapis/rpc v0.0.0-20260511170946-3700d4141b60/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.81.1 h1:VnnIIZ88UzOOKLukQi+ImGz8O1Wdp8nAGGnvOfEIWQQ= @@ -1553,8 +1572,8 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/dnaeon/go-vcr.v4 v4.0.6 h1:PiJkrakkmzc5s7EfBnZOnyiLwi7o7A9fwPzN0X2uwe0= gopkg.in/dnaeon/go-vcr.v4 v4.0.6/go.mod h1:sbq5oMEcM4PXngbcNbHhzfCP9OdZodLhrbRYoyg09HY= -gopkg.in/ini.v1 v1.67.1 h1:tVBILHy0R6e4wkYOn3XmiITt/hEVH4TFMYvAX2Ytz6k= -gopkg.in/ini.v1 v1.67.1/go.mod h1:x/cyOwCgZqOkJoDIJ3c1KNHMo10+nLGAhh+kn3Zizss= +gopkg.in/ini.v1 v1.67.2 h1:JtOSMb9OuaCZKr7h5D/h6iii14sK0hLbplTc6frx4Ss= +gopkg.in/ini.v1 v1.67.2/go.mod h1:x/cyOwCgZqOkJoDIJ3c1KNHMo10+nLGAhh+kn3Zizss= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= diff --git a/install.sh b/install.sh index cbb74248fc..daf4f59836 100755 --- a/install.sh +++ b/install.sh @@ -276,7 +276,7 @@ EOF main() { MAINLINE=1 STABLE=0 - TERRAFORM_VERSION="1.15.2" + TERRAFORM_VERSION="1.15.5" if [ "${TRACE-}" ]; then set -x diff --git a/mise.lock b/mise.lock index 7f96bd9e3c..babc55e498 100644 --- a/mise.lock +++ b/mise.lock @@ -896,49 +896,49 @@ checksum = "sha256:b8bfdedb261de2a69768097422a73bc72273ee92136ff676a20c3161e6588 url = "https://github.com/anchore/syft/releases/download/v1.20.0/syft_1.20.0_windows_amd64.zip" [[tools.terraform]] -version = "1.15.2" +version = "1.15.5" backend = "aqua:hashicorp/terraform" [tools.terraform."platforms.linux-arm64"] -checksum = "sha256:cf27657e96bbdc6116f4c16a0c801d36ae6410d7210183a520ac6b2198fb723e" -url = "https://releases.hashicorp.com/terraform/1.15.2/terraform_1.15.2_linux_arm64.zip" +checksum = "sha256:06e7b48de826146c6d9331ba35b13da12332d8392be30d1dd6b789ba4713fff0" +url = "https://releases.hashicorp.com/terraform/1.15.5/terraform_1.15.5_linux_arm64.zip" [tools.terraform."platforms.linux-arm64-musl"] -checksum = "sha256:cf27657e96bbdc6116f4c16a0c801d36ae6410d7210183a520ac6b2198fb723e" -url = "https://releases.hashicorp.com/terraform/1.15.2/terraform_1.15.2_linux_arm64.zip" +checksum = "sha256:06e7b48de826146c6d9331ba35b13da12332d8392be30d1dd6b789ba4713fff0" +url = "https://releases.hashicorp.com/terraform/1.15.5/terraform_1.15.5_linux_arm64.zip" [tools.terraform."platforms.linux-x64"] -checksum = "sha256:c56ff2bc7e6ce9b3879a50392b03c2ea074b47688bf503ff966c87fb01b2aab8" -url = "https://releases.hashicorp.com/terraform/1.15.2/terraform_1.15.2_linux_amd64.zip" +checksum = "sha256:702b2136af6728c8ff037f843dd2dbce2b7ad88786b7381d1d72aefa250f601c" +url = "https://releases.hashicorp.com/terraform/1.15.5/terraform_1.15.5_linux_amd64.zip" [tools.terraform."platforms.linux-x64-baseline"] -checksum = "sha256:c56ff2bc7e6ce9b3879a50392b03c2ea074b47688bf503ff966c87fb01b2aab8" -url = "https://releases.hashicorp.com/terraform/1.15.2/terraform_1.15.2_linux_amd64.zip" +checksum = "sha256:702b2136af6728c8ff037f843dd2dbce2b7ad88786b7381d1d72aefa250f601c" +url = "https://releases.hashicorp.com/terraform/1.15.5/terraform_1.15.5_linux_amd64.zip" [tools.terraform."platforms.linux-x64-musl"] -checksum = "sha256:c56ff2bc7e6ce9b3879a50392b03c2ea074b47688bf503ff966c87fb01b2aab8" -url = "https://releases.hashicorp.com/terraform/1.15.2/terraform_1.15.2_linux_amd64.zip" +checksum = "sha256:702b2136af6728c8ff037f843dd2dbce2b7ad88786b7381d1d72aefa250f601c" +url = "https://releases.hashicorp.com/terraform/1.15.5/terraform_1.15.5_linux_amd64.zip" [tools.terraform."platforms.linux-x64-musl-baseline"] -checksum = "sha256:c56ff2bc7e6ce9b3879a50392b03c2ea074b47688bf503ff966c87fb01b2aab8" -url = "https://releases.hashicorp.com/terraform/1.15.2/terraform_1.15.2_linux_amd64.zip" +checksum = "sha256:702b2136af6728c8ff037f843dd2dbce2b7ad88786b7381d1d72aefa250f601c" +url = "https://releases.hashicorp.com/terraform/1.15.5/terraform_1.15.5_linux_amd64.zip" [tools.terraform."platforms.macos-arm64"] -checksum = "sha256:4204bc3450418a7ce423e58451b053e5daed625ad6c6a15de98bc09345269f99" -url = "https://releases.hashicorp.com/terraform/1.15.2/terraform_1.15.2_darwin_arm64.zip" +checksum = "sha256:01137660510005b918bba82154866fbeac4393163d8277c2abe861dfb5842c3c" +url = "https://releases.hashicorp.com/terraform/1.15.5/terraform_1.15.5_darwin_arm64.zip" [tools.terraform."platforms.macos-x64"] -checksum = "sha256:2bb701bc2db93ed39613df4f4e033ec4c2de9eba1c036d9a2f62cffc988af066" -url = "https://releases.hashicorp.com/terraform/1.15.2/terraform_1.15.2_darwin_amd64.zip" +checksum = "sha256:3687d07c034b3e7deed5b072cd8ae2b34835bcb139baec3fc4f5fd534dabf5ed" +url = "https://releases.hashicorp.com/terraform/1.15.5/terraform_1.15.5_darwin_amd64.zip" [tools.terraform."platforms.macos-x64-baseline"] -checksum = "sha256:2bb701bc2db93ed39613df4f4e033ec4c2de9eba1c036d9a2f62cffc988af066" -url = "https://releases.hashicorp.com/terraform/1.15.2/terraform_1.15.2_darwin_amd64.zip" +checksum = "sha256:3687d07c034b3e7deed5b072cd8ae2b34835bcb139baec3fc4f5fd534dabf5ed" +url = "https://releases.hashicorp.com/terraform/1.15.5/terraform_1.15.5_darwin_amd64.zip" [tools.terraform."platforms.windows-x64"] -checksum = "sha256:a7e25570dd85f363581e96cac0b468257c45945ca8875d951413b6606c9b86d4" -url = "https://releases.hashicorp.com/terraform/1.15.2/terraform_1.15.2_windows_amd64.zip" +checksum = "sha256:2f652dd854af7b7fbb51301afc55b5ef1d3f6e287be7889d4cc3818df891cd38" +url = "https://releases.hashicorp.com/terraform/1.15.5/terraform_1.15.5_windows_amd64.zip" [tools.terraform."platforms.windows-x64-baseline"] -checksum = "sha256:a7e25570dd85f363581e96cac0b468257c45945ca8875d951413b6606c9b86d4" -url = "https://releases.hashicorp.com/terraform/1.15.2/terraform_1.15.2_windows_amd64.zip" +checksum = "sha256:2f652dd854af7b7fbb51301afc55b5ef1d3f6e287be7889d4cc3818df891cd38" +url = "https://releases.hashicorp.com/terraform/1.15.5/terraform_1.15.5_windows_amd64.zip" diff --git a/mise.toml b/mise.toml index c7366deecd..b148fe41c6 100644 --- a/mise.toml +++ b/mise.toml @@ -40,7 +40,7 @@ golangci-lint = "1.64.8" helm = "3.21.0" kubectx = "0.9.4" syft = "1.20.0" -terraform = "1.15.2" +terraform = "1.15.5" # Developer-environment niceties for the dogfood image. Non-dogfood # users who run `mise install` here will pull these too; they are diff --git a/provisioner/terraform/provision.go b/provisioner/terraform/provision.go index 592bc3c9cc..90c96403bc 100644 --- a/provisioner/terraform/provision.go +++ b/provisioner/terraform/provision.go @@ -381,7 +381,7 @@ func provisionEnv( "CODER_WORKSPACE_BUILD_ID="+metadata.GetWorkspaceBuildId(), "CODER_TASK_ID="+metadata.GetTaskId(), "CODER_TASK_PROMPT="+metadata.GetTaskPrompt(), - "AWS_SDK_UA_APP_ID=APN_1.1/pc_cdfmjwn8i6u8l9fwz8h82e4w3$", + awsSDKUserAgentEnv(safeEnvironValue(env, awsSDKUserAgentEnvKey)), ) if metadata.GetPrebuiltWorkspaceBuildStage().IsPrebuild() { env = append(env, provider.IsPrebuildEnvironmentVariable()+"=true") diff --git a/provisioner/terraform/safeenv.go b/provisioner/terraform/safeenv.go index 4da2fc32cd..a42a899bc8 100644 --- a/provisioner/terraform/safeenv.go +++ b/provisioner/terraform/safeenv.go @@ -53,3 +53,39 @@ func safeEnviron() []string { } return strippedEnv } + +// safeEnvironValue returns the value of the named variable in the given +// `KEY=VALUE` environment slice, or an empty string if it is not present. +func safeEnvironValue(env []string, name string) string { + prefix := name + "=" + for _, e := range env { + if strings.HasPrefix(e, prefix) { + return strings.TrimPrefix(e, prefix) + } + } + return "" +} + +const ( + awsSDKUserAgentEnvKey = "AWS_SDK_UA_APP_ID" + // awsSDKUserAgentCoder is Coder's AWS Partner Revenue Measurement + // User-Agent string. The `APN_1.1/pc_$` format and the + // space-delimited append behavior below follow AWS's guidance: + // https://docs.aws.amazon.com/PRM/latest/aws-prm-onboarding-guide/automated-user-agent.html + awsSDKUserAgentCoder = "APN_1.1/pc_cdfmjwn8i6u8l9fwz8h82e4w3$" +) + +// awsSDKUserAgentEnv returns the AWS_SDK_UA_APP_ID value to pass to the +// Terraform subprocess. If the caller's environment already configures an +// Application ID (e.g. an operator who is also an AWS Partner and wants +// their own revenue attribution), Coder's value is appended with a space +// delimiter so both attributions are preserved. Otherwise Coder's value is +// used on its own. +// +// See: https://docs.aws.amazon.com/PRM/latest/aws-prm-onboarding-guide/automated-user-agent.html +func awsSDKUserAgentEnv(existing string) string { + if existing == "" { + return awsSDKUserAgentEnvKey + "=" + awsSDKUserAgentCoder + } + return awsSDKUserAgentEnvKey + "=" + existing + " " + awsSDKUserAgentCoder +} diff --git a/provisioner/terraform/safeenv_internal_test.go b/provisioner/terraform/safeenv_internal_test.go new file mode 100644 index 0000000000..1863f8fee1 --- /dev/null +++ b/provisioner/terraform/safeenv_internal_test.go @@ -0,0 +1,44 @@ +package terraform + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSafeEnvironValue(t *testing.T) { + t.Parallel() + + env := []string{ + "FOO=bar", + "AWS_SDK_UA_APP_ID=my-existing-id", + "BAZ=qux", + } + require.Equal(t, "my-existing-id", safeEnvironValue(env, "AWS_SDK_UA_APP_ID")) + require.Equal(t, "bar", safeEnvironValue(env, "FOO")) + require.Equal(t, "", safeEnvironValue(env, "MISSING")) +} + +func TestAWSSDKUserAgentEnv(t *testing.T) { + t.Parallel() + + t.Run("NoExisting", func(t *testing.T) { + t.Parallel() + require.Equal(t, + "AWS_SDK_UA_APP_ID=APN_1.1/pc_cdfmjwn8i6u8l9fwz8h82e4w3$", + awsSDKUserAgentEnv(""), + ) + }) + + t.Run("AppendToExisting", func(t *testing.T) { + t.Parallel() + // When the operator is themselves an AWS Partner and has set their own + // Application ID, we append Coder's with a space delimiter so both + // attributions are preserved. See: + // https://docs.aws.amazon.com/PRM/latest/aws-prm-onboarding-guide/automated-user-agent.html + require.Equal(t, + "AWS_SDK_UA_APP_ID=EXISTING_APP_ID APN_1.1/pc_cdfmjwn8i6u8l9fwz8h82e4w3$", + awsSDKUserAgentEnv("EXISTING_APP_ID"), + ) + }) +} diff --git a/provisioner/terraform/testdata/resources/ai-tasks-disabled/ai-tasks-disabled.tfplan.json b/provisioner/terraform/testdata/resources/ai-tasks-disabled/ai-tasks-disabled.tfplan.json index 455f32871c..a3ce227430 100644 --- a/provisioner/terraform/testdata/resources/ai-tasks-disabled/ai-tasks-disabled.tfplan.json +++ b/provisioner/terraform/testdata/resources/ai-tasks-disabled/ai-tasks-disabled.tfplan.json @@ -1,12 +1,12 @@ { "format_version": "1.2", - "terraform_version": "1.15.2", + "terraform_version": "1.15.5", "planned_values": { "root_module": {} }, "prior_state": { "format_version": "1.0", - "terraform_version": "1.15.2", + "terraform_version": "1.15.5", "values": { "root_module": { "resources": [ diff --git a/provisioner/terraform/testdata/version.txt b/provisioner/terraform/testdata/version.txt index 42cf0675c5..d32434904b 100644 --- a/provisioner/terraform/testdata/version.txt +++ b/provisioner/terraform/testdata/version.txt @@ -1 +1 @@ -1.15.2 +1.15.5 diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index 769bdb8446..2cbdb6eabd 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -533,7 +533,10 @@ func (p *Server) UploadModuleFiles(ctx context.Context, moduleFiles []byte) erro } defer stream.Close() - dataUp, chunks := sdkproto.BytesToDataUpload(sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, moduleFiles) + dataUp, chunks, err := sdkproto.BytesToDataUpload(sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, moduleFiles) + if err != nil { + return nil, xerrors.Errorf("prepare module files upload: %w", err) + } err = stream.Send(&sdkproto.FileUpload{Type: &sdkproto.FileUpload_DataUpload{DataUpload: dataUp}}) if err != nil { diff --git a/provisionerd/runner/init.go b/provisionerd/runner/init.go index 45c762b7fa..13a8c5066a 100644 --- a/provisionerd/runner/init.go +++ b/provisionerd/runner/init.go @@ -19,14 +19,17 @@ func (r *Runner) init(ctx context.Context, omitModules bool, templateArchive []b // If `moduleTar` is populated, `init` will send it over in multiple parts. This // It must be called before the initial request to populate the correct hash if // there is data to send. This is safe to call on nil or empty slices. - data, chunks := sdkproto.BytesToDataUpload(sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, moduleTar) + data, chunks, err := sdkproto.BytesToDataUpload(sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, moduleTar) + if err != nil { + return nil, r.failedJobf("prepare module files upload: %v", err) + } hash := []byte{} if len(moduleTar) > 0 { hash = data.DataHash } - err := r.session.Send(&sdkproto.Request{Type: &sdkproto.Request_Init{Init: &sdkproto.InitRequest{ + err = r.session.Send(&sdkproto.Request{Type: &sdkproto.Request_Init{Init: &sdkproto.InitRequest{ TemplateSourceArchive: templateArchive, OmitModuleFiles: omitModules, InitialModuleTarHash: hash, diff --git a/provisionersdk/proto/dataupload.go b/provisionersdk/proto/dataupload.go index e9b6d9ddfb..f8832d3616 100644 --- a/provisionersdk/proto/dataupload.go +++ b/provisionersdk/proto/dataupload.go @@ -9,7 +9,8 @@ import ( ) const ( - ChunkSize = 2 << 20 // 2 MiB + ChunkSize = 2 << 20 // 2 MiB + MaxFileSize = 10 * (10 << 20) // 100 MiB, matches coderd HTTPFileMaxBytes ) type DataBuilder struct { @@ -29,6 +30,21 @@ func NewDataBuilder(req *DataUpload) (*DataBuilder, error) { return nil, xerrors.Errorf("data hash must be 32 bytes, got %d bytes", len(req.DataHash)) } + if req.FileSize < 0 { + return nil, xerrors.Errorf("file size must not be negative, got %d", req.FileSize) + } + if req.FileSize > MaxFileSize { + return nil, xerrors.Errorf("file size %d exceeds maximum allowed %d", req.FileSize, MaxFileSize) + } + if req.Chunks < 0 { + return nil, xerrors.Errorf("chunk count must not be negative, got %d", req.Chunks) + } + //nolint:gosec // FileSize is validated to be <= MaxFileSize, well within int32 range + maxChunks := int32((req.FileSize + ChunkSize - 1) / ChunkSize) + if req.Chunks > maxChunks { + return nil, xerrors.Errorf("chunk count %d exceeds maximum %d for file size %d", req.Chunks, maxChunks, req.FileSize) + } + return &DataBuilder{ Type: req.UploadType, Hash: req.DataHash, @@ -60,7 +76,7 @@ func (b *DataBuilder) Add(chunk *ChunkPiece) (bool, error) { expectedSize := len(b.data) + len(chunk.Data) if expectedSize > int(b.Size) { return b.done(), xerrors.Errorf("data exceeds expected size, data is now %d bytes, %d bytes over the limit of %d", - expectedSize, b.Size-int64(expectedSize), b.Size) + expectedSize, int64(expectedSize)-b.Size, b.Size) } b.data = append(b.data, chunk.Data...) @@ -103,7 +119,11 @@ func (b *DataBuilder) done() bool { return b.chunkIndex >= b.ChunkCount } -func BytesToDataUpload(dataType DataUploadType, data []byte) (*DataUpload, []*ChunkPiece) { +func BytesToDataUpload(dataType DataUploadType, data []byte) (*DataUpload, []*ChunkPiece, error) { + if int64(len(data)) > MaxFileSize { + return nil, nil, xerrors.Errorf("data size %d exceeds maximum allowed %d", len(data), MaxFileSize) + } + fullHash := sha256.Sum256(data) //nolint:gosec // not going over int32 size := int32(len(data)) @@ -135,5 +155,5 @@ func BytesToDataUpload(dataType DataUploadType, data []byte) (*DataUpload, []*Ch chunks = append(chunks, chunk) } - return req, chunks + return req, chunks, nil } diff --git a/provisionersdk/proto/dataupload_test.go b/provisionersdk/proto/dataupload_test.go index 496a7956c9..d8876240b0 100644 --- a/provisionersdk/proto/dataupload_test.go +++ b/provisionersdk/proto/dataupload_test.go @@ -2,6 +2,7 @@ package proto_test import ( crand "crypto/rand" + "crypto/sha256" "math/rand" "testing" @@ -10,6 +11,101 @@ import ( "github.com/coder/coder/v2/provisionersdk/proto" ) +func TestNewDataBuilderValidation(t *testing.T) { + t.Parallel() + + validHash := sha256.Sum256([]byte{}) + + t.Run("ExactMaxFileSize", func(t *testing.T) { + t.Parallel() + builder, err := proto.NewDataBuilder(&proto.DataUpload{ + DataHash: validHash[:], + FileSize: proto.MaxFileSize, + Chunks: int32((proto.MaxFileSize + proto.ChunkSize - 1) / proto.ChunkSize), + UploadType: proto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, + }) + require.NoError(t, err) + require.NotNil(t, builder) + }) + + t.Run("OversizedFileSize", func(t *testing.T) { + t.Parallel() + _, err := proto.NewDataBuilder(&proto.DataUpload{ + DataHash: validHash[:], + FileSize: proto.MaxFileSize + 1, + Chunks: 1, + UploadType: proto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, + }) + require.ErrorContains(t, err, "exceeds maximum allowed") + }) + + t.Run("NegativeFileSize", func(t *testing.T) { + t.Parallel() + _, err := proto.NewDataBuilder(&proto.DataUpload{ + DataHash: validHash[:], + FileSize: -1, + Chunks: 1, + UploadType: proto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, + }) + require.ErrorContains(t, err, "must not be negative") + }) + + t.Run("NegativeChunks", func(t *testing.T) { + t.Parallel() + _, err := proto.NewDataBuilder(&proto.DataUpload{ + DataHash: validHash[:], + FileSize: 100, + Chunks: -1, + UploadType: proto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, + }) + require.ErrorContains(t, err, "chunk count must not be negative") + }) + + t.Run("ExcessiveChunkCount", func(t *testing.T) { + t.Parallel() + _, err := proto.NewDataBuilder(&proto.DataUpload{ + DataHash: validHash[:], + FileSize: 100, + Chunks: 1000, + UploadType: proto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, + }) + require.ErrorContains(t, err, "chunk count 1000 exceeds maximum") + }) + + t.Run("ZeroFileSize", func(t *testing.T) { + t.Parallel() + builder, err := proto.NewDataBuilder(&proto.DataUpload{ + DataHash: validHash[:], + FileSize: 0, + Chunks: 0, + UploadType: proto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, + }) + require.NoError(t, err) + require.True(t, builder.IsDone(), "zero-chunk upload should be immediately done") + }) + + t.Run("ValidRoundTrip", func(t *testing.T) { + t.Parallel() + data := make([]byte, 256) + _, _ = crand.Read(data) + + first, chunks, err := proto.BytesToDataUpload(proto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, data) + require.NoError(t, err) + + builder, err := proto.NewDataBuilder(first) + require.NoError(t, err) + + for _, chunk := range chunks { + _, err = builder.Add(chunk) + require.NoError(t, err) + } + + got, err := builder.Complete() + require.NoError(t, err) + require.Equal(t, data, got) + }) +} + // Fuzz must be run manually with the `-fuzz` flag to generate random test cases. // By default, it only runs the added seed corpus cases. // go test -fuzz=FuzzBytesToDataUpload @@ -25,7 +121,11 @@ func FuzzBytesToDataUpload(f *testing.F) { } f.Fuzz(func(t *testing.T, data []byte) { - first, chunks := proto.BytesToDataUpload(proto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, data) + first, chunks, err := proto.BytesToDataUpload(proto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, data) + if err != nil { + // Data exceeds MaxFileSize, which is expected for large fuzz inputs. + return + } builder, err := proto.NewDataBuilder(first) require.NoError(t, err) @@ -62,7 +162,9 @@ func TestBytesToDataUpload(t *testing.T) { _, err := crand.Read(data) require.NoError(t, err) - first, chunks := proto.BytesToDataUpload(proto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, data) + first, chunks, err := proto.BytesToDataUpload(proto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, data) + require.NoError(t, err) + builder, err := proto.NewDataBuilder(first) require.NoError(t, err) diff --git a/provisionersdk/session.go b/provisionersdk/session.go index 094fe38aba..543fdd3a51 100644 --- a/provisionersdk/session.go +++ b/provisionersdk/session.go @@ -246,24 +246,28 @@ func (s *Session) handleInitRequest(init *proto.InitRequest, requests <-chan *pr s.Logger.Info(s.Context(), "plan response too large, sending modules as stream", slog.F("size_bytes", len(complete.ModuleFiles)), ) - dataUp, chunks := proto.BytesToDataUpload(proto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, complete.ModuleFiles) - - complete.ModuleFiles = nil // sent over the stream - complete.ModuleFilesHash = dataUp.DataHash - - err := s.stream.Send(&proto.Response{Type: &proto.Response_DataUpload{DataUpload: dataUp}}) + dataUp, chunks, err := proto.BytesToDataUpload(proto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, complete.ModuleFiles) if err != nil { - complete.Error = fmt.Sprintf("send data upload: %s", err.Error()) + complete.Error = fmt.Sprintf("prepare module files upload: %s", err.Error()) } else { - for i, chunk := range chunks { - err := s.stream.Send(&proto.Response{Type: &proto.Response_ChunkPiece{ChunkPiece: chunk}}) - if err != nil { - complete.Error = fmt.Sprintf("send data piece upload %d/%d: %s", i, dataUp.Chunks, err.Error()) - break + complete.ModuleFiles = nil // sent over the stream + complete.ModuleFilesHash = dataUp.DataHash + + err := s.stream.Send(&proto.Response{Type: &proto.Response_DataUpload{DataUpload: dataUp}}) + if err != nil { + complete.Error = fmt.Sprintf("send data upload: %s", err.Error()) + } else { + for i, chunk := range chunks { + err := s.stream.Send(&proto.Response{Type: &proto.Response_ChunkPiece{ChunkPiece: chunk}}) + if err != nil { + complete.Error = fmt.Sprintf("send data piece upload %d/%d: %s", i, dataUp.Chunks, err.Error()) + break + } } } } } + s.initialized = true return complete, nil diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 7aaac5b2dc..191f4cf622 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -1,27 +1,14 @@ package ptytest import ( - "bufio" - "bytes" - "context" - "fmt" - "io" - "regexp" "runtime" - "slices" - "strings" "sync" "testing" - "time" - "unicode/utf8" - "github.com/acarl005/stripansi" "github.com/stretchr/testify/require" - "go.uber.org/atomic" - "golang.org/x/xerrors" "github.com/coder/coder/v2/pty" - "github.com/coder/coder/v2/testutil" + "github.com/coder/coder/v2/testutil/expecter" "github.com/coder/serpent" ) @@ -31,10 +18,11 @@ func New(t *testing.T, opts ...pty.Option) *PTY { ptty, err := newTestPTY(opts...) require.NoError(t, err) - e := newExpecter(t, ptty.Output(), "cmd") + e := expecter.New(t, ptty.Output(), "cmd") r := &PTY{ - outExpecter: e, - PTY: ptty, + t: t, + Expecter: *e, + PTY: ptty, } // Ensure pty is cleaned up at the end of test. t.Cleanup(func() { @@ -54,11 +42,12 @@ func Start(t *testing.T, cmd *pty.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Pr _ = ps.Kill() _ = ps.Wait() }) - ex := newExpecter(t, ptty.OutputReader(), cmd.Args[0]) + ex := expecter.New(t, ptty.OutputReader(), cmd.Args[0]) r := &PTYCmd{ - outExpecter: ex, - PTYCmd: ptty, + Expecter: *ex, + PTYCmd: ptty, + t: t, } t.Cleanup(func() { _ = r.Close() @@ -66,322 +55,12 @@ func Start(t *testing.T, cmd *pty.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Pr return r, ps } -func newExpecter(t *testing.T, r io.Reader, name string) outExpecter { - // Use pipe for logging. - logDone := make(chan struct{}) - logr, logw := io.Pipe() - - // Write to log and output buffer. - copyDone := make(chan struct{}) - out := newStdbuf() - w := io.MultiWriter(logw, out) - - ex := outExpecter{ - t: t, - out: out, - name: atomic.NewString(name), - - runeReader: bufio.NewReaderSize(out, utf8.UTFMax), - } - - logClose := func(name string, c io.Closer) { - ex.logf("closing %s", name) - err := c.Close() - ex.logf("closed %s: %v", name, err) - } - // Set the actual close function for the outExpecter. - ex.close = func(reason string) error { - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - - ex.logf("closing expecter: %s", reason) - - // Caller needs to have closed the PTY so that copying can complete - select { - case <-ctx.Done(): - ex.fatalf("close", "copy did not close in time") - case <-copyDone: - } - - logClose("logw", logw) - logClose("logr", logr) - select { - case <-ctx.Done(): - ex.fatalf("close", "log pipe did not close in time") - case <-logDone: - } - - ex.logf("closed expecter") - - return nil - } - - go func() { - defer close(copyDone) - _, err := io.Copy(w, r) - ex.logf("copy done: %v", err) - ex.logf("closing out") - err = out.closeErr(err) - ex.logf("closed out: %v", err) - }() - - // Log all output as part of test for easier debugging on errors. - go func() { - defer close(logDone) - s := bufio.NewScanner(logr) - for s.Scan() { - ex.logf("%q", stripansi.Strip(s.Text())) - } - // Surface non-EOF scanner errors; otherwise they're invisible. - if err := s.Err(); err != nil { - ex.logf("log scanner stopped: %v", err) - } - }() - - return ex -} - -type outExpecter struct { - t *testing.T - close func(reason string) error - out *stdbuf - name *atomic.String - - runeReader *bufio.Reader -} - -// Deprecated: use ExpectMatchContext instead. -// This uses a background context, so will not respect the test's context. -func (e *outExpecter) ExpectMatch(str string) string { - return e.expectMatchContextFunc(str, e.ExpectMatchContext) -} - -func (e *outExpecter) ExpectRegexMatch(str string) string { - return e.expectMatchContextFunc(str, e.ExpectRegexMatchContext) -} - -func (e *outExpecter) expectMatchContextFunc(str string, fn func(ctx context.Context, str string) string) string { - e.t.Helper() - - timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) - defer cancel() - - return fn(timeout, str) -} - -// TODO(mafredri): Rename this to ExpectMatch when refactoring. -func (e *outExpecter) ExpectMatchContext(ctx context.Context, str string) string { - return e.expectMatcherFunc(ctx, str, strings.Contains) -} - -func (e *outExpecter) ExpectRegexMatchContext(ctx context.Context, str string) string { - return e.expectMatcherFunc(ctx, str, func(src, pattern string) bool { - return regexp.MustCompile(pattern).MatchString(src) - }) -} - -func (e *outExpecter) expectMatcherFunc(ctx context.Context, str string, fn func(src, pattern string) bool) string { - e.t.Helper() - - var buffer bytes.Buffer - err := e.doMatchWithDeadline(ctx, "ExpectMatchContext", func(rd *bufio.Reader) error { - for { - r, _, err := rd.ReadRune() - if err != nil { - return err - } - _, err = buffer.WriteRune(r) - if err != nil { - return err - } - if fn(buffer.String(), str) { - return nil - } - } - }) - if err != nil { - e.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String()) - return "" - } - e.logf("matched %q = %q", str, buffer.String()) - return buffer.String() -} - -// ExpectNoMatchBefore validates that `match` does not occur before `before`. -func (e *outExpecter) ExpectNoMatchBefore(ctx context.Context, match, before string) string { - e.t.Helper() - - var buffer bytes.Buffer - err := e.doMatchWithDeadline(ctx, "ExpectNoMatchBefore", func(rd *bufio.Reader) error { - for { - r, _, err := rd.ReadRune() - if err != nil { - return err - } - _, err = buffer.WriteRune(r) - if err != nil { - return err - } - - if strings.Contains(buffer.String(), match) { - return xerrors.Errorf("found %q before %q", match, before) - } - - if strings.Contains(buffer.String(), before) { - return nil - } - } - }) - if err != nil { - e.fatalf("read error", "%v (wanted no %q before %q; got %q)", err, match, before, buffer.String()) - return "" - } - e.logf("matched %q = %q", before, stripansi.Strip(buffer.String())) - return buffer.String() -} - -func (e *outExpecter) Peek(ctx context.Context, n int) []byte { - e.t.Helper() - - var out []byte - err := e.doMatchWithDeadline(ctx, "Peek", func(rd *bufio.Reader) error { - var err error - out, err = rd.Peek(n) - return err - }) - if err != nil { - e.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out) - return nil - } - e.logf("peeked %d/%d bytes = %q", len(out), n, out) - return slices.Clone(out) -} - //nolint:govet // We don't care about conforming to ReadRune() (rune, int, error). -func (e *outExpecter) ReadRune(ctx context.Context) rune { - e.t.Helper() - - var r rune - err := e.doMatchWithDeadline(ctx, "ReadRune", func(rd *bufio.Reader) error { - var err error - r, _, err = rd.ReadRune() - return err - }) - if err != nil { - e.fatalf("read error", "%v (wanted rune; got %q)", err, r) - return 0 - } - e.logf("matched rune = %q", r) - return r -} - -func (e *outExpecter) ReadLine(ctx context.Context) string { - e.t.Helper() - - var buffer bytes.Buffer - err := e.doMatchWithDeadline(ctx, "ReadLine", func(rd *bufio.Reader) error { - for { - r, _, err := rd.ReadRune() - if err != nil { - return err - } - if r == '\n' { - return nil - } - if r == '\r' { - // Peek the next rune to see if it's an LF and then consume - // it. - - // Unicode code points can be up to 4 bytes, but the - // ones we're looking for are only 1 byte. - b, _ := rd.Peek(1) - if len(b) == 0 { - return nil - } - - r, _ = utf8.DecodeRune(b) - if r == '\n' { - _, _, err = rd.ReadRune() - if err != nil { - return err - } - } - - return nil - } - - _, err = buffer.WriteRune(r) - if err != nil { - return err - } - } - }) - if err != nil { - e.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String()) - return "" - } - e.logf("matched newline = %q", buffer.String()) - return buffer.String() -} - -func (e *outExpecter) ReadAll() []byte { - e.t.Helper() - return e.out.ReadAll() -} - -func (e *outExpecter) doMatchWithDeadline(ctx context.Context, name string, fn func(*bufio.Reader) error) error { - e.t.Helper() - - // A timeout is mandatory, caller can decide by passing a context - // that times out. - if _, ok := ctx.Deadline(); !ok { - timeout := testutil.WaitMedium - e.logf("%s ctx has no deadline, using %s", name, timeout) - var cancel context.CancelFunc - //nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*. - ctx, cancel = context.WithTimeout(ctx, timeout) - defer cancel() - } - - match := make(chan error, 1) - go func() { - defer close(match) - match <- fn(e.runeReader) - }() - select { - case err := <-match: - return err - case <-ctx.Done(): - // Ensure goroutine is cleaned up before test exit, do not call - // (*outExpecter).close here to let the caller decide. - _ = e.out.Close() - <-match - - return xerrors.Errorf("match deadline exceeded: %w", ctx.Err()) - } -} - -func (e *outExpecter) logf(format string, args ...interface{}) { - e.t.Helper() - - // Match regular logger timestamp format, we seem to be logging in - // UTC in other places as well, so match here. - e.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), e.name.Load(), fmt.Sprintf(format, args...)) -} - -func (e *outExpecter) fatalf(reason string, format string, args ...interface{}) { - e.t.Helper() - - // Ensure the message is part of the normal log stream before - // failing the test. - e.logf("%s: %s", reason, fmt.Sprintf(format, args...)) - - require.FailNowf(e.t, reason, format, args...) -} type PTY struct { - outExpecter + expecter.Expecter pty.PTY + t *testing.T closeOnce sync.Once closeErr error } @@ -391,17 +70,12 @@ func (p *PTY) Close() error { p.closeOnce.Do(func() { pErr := p.PTY.Close() if pErr != nil { - p.logf("PTY: Close failed: %v", pErr) - } - eErr := p.outExpecter.close("PTY close") - if eErr != nil { - p.logf("PTY: close expecter failed: %v", eErr) + p.Logf("PTY: Close failed: %v", pErr) } + p.Expecter.Close("PTY close") if pErr != nil { p.closeErr = pErr - return } - p.closeErr = eErr }) return p.closeErr } @@ -418,7 +92,7 @@ func (p *PTY) Attach(inv *serpent.Invocation) *PTY { func (p *PTY) Write(r rune) { p.t.Helper() - p.logf("stdin: %q", r) + p.Logf("stdin: %q", r) _, err := p.Input().Write([]byte{byte(r)}) require.NoError(p.t, err, "write failed") } @@ -430,7 +104,7 @@ func (p *PTY) WriteLine(str string) { if runtime.GOOS == "windows" { newline = append(newline, '\n') } - p.logf("stdin: %q", str+string(newline)) + p.Logf("stdin: %q", str+string(newline)) _, err := p.Input().Write(append([]byte(str), newline...)) require.NoError(p.t, err, "write line failed") } @@ -440,137 +114,22 @@ func (p *PTY) WriteLine(str string) { // // p := New(t).Named("myCmd") func (p *PTY) Named(name string) *PTY { - p.name.Store(name) + p.Rename(name) return p } type PTYCmd struct { - outExpecter + expecter.Expecter pty.PTYCmd + t *testing.T } func (p *PTYCmd) Close() error { p.t.Helper() pErr := p.PTYCmd.Close() if pErr != nil { - p.logf("PTYCmd: Close failed: %v", pErr) + p.Logf("PTYCmd: Close failed: %v", pErr) } - eErr := p.outExpecter.close("PTYCmd close") - if eErr != nil { - p.logf("PTYCmd: close expecter failed: %v", eErr) - } - if pErr != nil { - return pErr - } - return eErr -} - -// stdbuf is like a buffered stdout, it buffers writes until read. -type stdbuf struct { - r io.Reader - - mu sync.Mutex // Protects following. - b []byte - more chan struct{} - err error -} - -func newStdbuf() *stdbuf { - return &stdbuf{more: make(chan struct{}, 1)} -} - -func (b *stdbuf) ReadAll() []byte { - b.mu.Lock() - defer b.mu.Unlock() - - if b.err != nil { - return nil - } - p := append([]byte(nil), b.b...) - b.b = b.b[len(b.b):] - return p -} - -func (b *stdbuf) Read(p []byte) (int, error) { - if b.r == nil { - return b.readOrWaitForMore(p) - } - - n, err := b.r.Read(p) - if xerrors.Is(err, io.EOF) { - b.r = nil - err = nil - if n == 0 { - return b.readOrWaitForMore(p) - } - } - return n, err -} - -func (b *stdbuf) readOrWaitForMore(p []byte) (int, error) { - b.mu.Lock() - defer b.mu.Unlock() - - // Deplete channel so that more check - // is for future input into buffer. - select { - case <-b.more: - default: - } - - if len(b.b) == 0 { - if b.err != nil { - return 0, b.err - } - - b.mu.Unlock() - <-b.more - b.mu.Lock() - } - - b.r = bytes.NewReader(b.b) - b.b = b.b[len(b.b):] - - return b.r.Read(p) -} - -func (b *stdbuf) Write(p []byte) (int, error) { - if len(p) == 0 { - return 0, nil - } - - b.mu.Lock() - defer b.mu.Unlock() - - if b.err != nil { - return 0, b.err - } - - b.b = append(b.b, p...) - - select { - case b.more <- struct{}{}: - default: - } - - return len(p), nil -} - -func (b *stdbuf) Close() error { - return b.closeErr(nil) -} - -func (b *stdbuf) closeErr(err error) error { - b.mu.Lock() - defer b.mu.Unlock() - if b.err != nil { - return err - } - if err == nil { - b.err = io.EOF - } else { - b.err = err - } - close(b.more) - return err + p.Expecter.Close("PTYCmd close") + return pErr } diff --git a/scripts/Dockerfile.base b/scripts/Dockerfile.base index 337ee5a84f..315c099d78 100644 --- a/scripts/Dockerfile.base +++ b/scripts/Dockerfile.base @@ -27,7 +27,7 @@ RUN apk add --no-cache \ # Terraform was disabled in the edge repo due to a build issue. # https://gitlab.alpinelinux.org/alpine/aports/-/commit/f3e263d94cfac02d594bef83790c280e045eba35 # Using wget for now. Note that busybox unzip doesn't support streaming. -RUN ARCH="$(arch)"; if [ "${ARCH}" == "x86_64" ]; then ARCH="amd64"; elif [ "${ARCH}" == "aarch64" ]; then ARCH="arm64"; elif [ "${ARCH}" == "armv7l" ]; then ARCH="arm"; fi; wget -O /tmp/terraform.zip "https://releases.hashicorp.com/terraform/1.15.2/terraform_1.15.2_linux_${ARCH}.zip" && \ +RUN ARCH="$(arch)"; if [ "${ARCH}" == "x86_64" ]; then ARCH="amd64"; elif [ "${ARCH}" == "aarch64" ]; then ARCH="arm64"; elif [ "${ARCH}" == "armv7l" ]; then ARCH="arm"; fi; wget -O /tmp/terraform.zip "https://releases.hashicorp.com/terraform/1.15.5/terraform_1.15.5_linux_${ARCH}.zip" && \ busybox unzip /tmp/terraform.zip -d /usr/local/bin && \ rm -f /tmp/terraform.zip && \ chmod +x /usr/local/bin/terraform && \ diff --git a/scripts/check_emdash.sh b/scripts/check_emdash.sh index ffd4e092ff..2b95fd4584 100755 --- a/scripts/check_emdash.sh +++ b/scripts/check_emdash.sh @@ -39,70 +39,147 @@ scan_all_files() { fi } +# resolve_merge_base finds the merge-base between HEAD and the given ref. +# In shallow CI clones the merge-base is not directly reachable, so we +# query the PR commit count via `gh`, deepen HEAD by count+1, and +# resolve HEAD~N which is the parent of the first PR commit. +resolve_merge_base() { + local base_ref="$1" + + # Fast path: merge-base already reachable (full clone or sufficient depth). + local mb + mb=$(git merge-base HEAD "$base_ref" 2>/dev/null || true) + if [[ -n "$mb" ]]; then + echo "$mb" + return + fi + + if ! command -v gh >/dev/null 2>&1; then + echo "gh CLI not found, cannot determine PR commit count." >&2 + return + fi + + # Use the PR commit count to deepen HEAD past the PR commits. + # HEAD~N is the parent of the oldest PR commit, i.e. the merge-base. + local count + count=$(gh pr view --json commits --jq '.commits | length' 2>/dev/null || true) + if [[ -z "$count" || "$count" -le 0 ]]; then + echo "Could not determine PR commit count from gh." >&2 + return + fi + + echo "Deepening HEAD by $((count + 1)) to reach PR base..." >&2 + git fetch --deepen="$((count + 1))" 2>/dev/null || true + + # Retry merge-base now that we have more history. + mb=$(git merge-base HEAD "$base_ref" 2>/dev/null || true) + if [[ -n "$mb" ]]; then + echo "$mb" + return + fi + + # Last resort: walk first-parent history. This is correct for + # linear PRs but may traverse the wrong branch for merge-commit + # checkouts. + git rev-parse --verify "HEAD~${count}" 2>/dev/null || true +} + +# fetch_base_ref ensures origin/$GITHUB_BASE_REF is available locally. +# CI shallow clones (fetch-depth: 1) typically omit the base branch. +fetch_base_ref() { + local base_ref="$1" + + if git rev-parse --verify "$base_ref" >/dev/null 2>&1; then + return 0 + fi + + local ref="${base_ref#origin/}" + echo "Base ref $base_ref not found locally, fetching $ref..." >&2 + git fetch origin "$ref" --depth=1 2>/dev/null || true + + if ! git rev-parse --verify "$base_ref" >/dev/null 2>&1; then + echo "ERROR: could not fetch base ref $base_ref." >&2 + return 1 + fi +} + +# resolve_diff_base determines the base ref to diff against. +resolve_diff_base() { + # CI pull requests: use merge-base against the target branch. + if [[ -n "${GITHUB_BASE_REF:-}" ]]; then + local base_ref="origin/${GITHUB_BASE_REF}" + fetch_base_ref "$base_ref" || return 1 + + local base + base=$(resolve_merge_base "$base_ref") + if [[ -n "$base" ]]; then + echo "$base" + return + fi + + # Could not determine merge-base; fall back to branch tip. + echo "WARNING: could not find merge-base with $base_ref, using branch tip (diff may include non-PR changes)." >&2 + echo "$base_ref" + return + fi + + # Local dev: use merge-base with origin/main. + if git rev-parse --verify origin/main >/dev/null 2>&1; then + git merge-base HEAD origin/main 2>/dev/null || echo "origin/main" + return + fi +} + +# scan_diff checks only added lines in the diff for emdash/endash. +scan_diff() { + local base="$1" + + local diff_output + if ! diff_output=$(git diff "$base" -U0 -- . "${exclude_pathspecs[@]}" 2>&1); then + echo "ERROR: git diff against $base failed:" >&2 + echo "$diff_output" >&2 + exit 1 + fi + + if [[ -z "$diff_output" ]]; then + echo "OK: no changes to check." + exit 0 + fi + + local current_file="" current_line=0 + while IFS= read -r diff_line; do + if [[ "$diff_line" =~ ^\+\+\+\ b/(.*) ]]; then + current_file="${BASH_REMATCH[1]}" + fi + # Anchored to hunk header structure to avoid matching + # digits from trailing function context. + if [[ "$diff_line" =~ ^@@\ -[0-9,]+\ \+([0-9]+) ]]; then + current_line=${BASH_REMATCH[1]} + continue + fi + if [[ "$diff_line" =~ ^\+ ]] && [[ ! "$diff_line" =~ ^\+\+\+\ [ab/] ]]; then + if echo "$diff_line" | grep -Eq "$pattern"; then + echo "${current_file}:${current_line}:${diff_line:1}" + found=1 + fi + ((current_line++)) || true + fi + done <<<"$diff_output" +} + if [[ "$mode" == "all" ]]; then scan_all_files else - base="" - if [[ -n "${GITHUB_BASE_REF:-}" ]]; then - base="origin/${GITHUB_BASE_REF}" - elif git rev-parse --verify origin/main >/dev/null 2>&1; then - base=$(git merge-base HEAD origin/main 2>/dev/null || echo "origin/main") - fi - + base=$(resolve_diff_base) || { + echo "ERROR: could not determine base ref." >&2 + exit 1 + } if [[ -z "$base" ]]; then - echo "WARNING: no base ref found, scanning all tracked files." + echo "WARNING: no base ref found, scanning all tracked files." >&2 scan_all_files else - # Ensure the base ref is fetchable. CI shallow clones - # (fetch-depth: 1) may not have the base branch available. - if ! git rev-parse --verify "$base" >/dev/null 2>&1; then - ref="${base#origin/}" - echo "Base ref $base not found locally, fetching $ref..." - git fetch origin "$ref" --depth=1 2>/dev/null || true - if ! git rev-parse --verify "$base" >/dev/null 2>&1; then - if git rev-parse --verify origin/main >/dev/null 2>&1; then - echo "WARNING: could not fetch base ref $base, falling back to origin/main merge base." - base=$(git merge-base HEAD origin/main 2>/dev/null || echo "origin/main") - else - echo "ERROR: could not fetch base ref $base." - exit 1 - fi - fi - fi - found=0 - if ! diff_output=$(git diff "$base" -U0 -- . "${exclude_pathspecs[@]}" 2>&1); then - echo "ERROR: git diff against $base failed:" - echo "$diff_output" - exit 1 - fi - - if [[ -z "$diff_output" ]]; then - echo "OK: no changes to check." - exit 0 - fi - - # Parse the diff to check only added lines for emdash/endash. - current_file="" - current_line=0 - while IFS= read -r diff_line; do - if [[ "$diff_line" =~ ^\+\+\+\ b/(.*) ]]; then - current_file="${BASH_REMATCH[1]}" - fi - # Anchored to hunk header structure to avoid matching - # digits from trailing function context. - if [[ "$diff_line" =~ ^@@\ -[0-9,]+\ \+([0-9]+) ]]; then - current_line=${BASH_REMATCH[1]} - continue - fi - if [[ "$diff_line" =~ ^\+ ]] && [[ ! "$diff_line" =~ ^\+\+\+\ [ab/] ]]; then - if echo "$diff_line" | grep -Eq "$pattern"; then - echo "${current_file}:${current_line}:${diff_line:1}" - found=1 - fi - ((current_line++)) || true - fi - done <<<"$diff_output" + scan_diff "$base" fi fi diff --git a/scripts/metricsdocgen/metrics b/scripts/metricsdocgen/metrics index 653de99241..036ac496a1 100644 --- a/scripts/metricsdocgen/metrics +++ b/scripts/metricsdocgen/metrics @@ -208,3 +208,21 @@ coder_aibridgeproxyd_mitm_requests_total{provider=""} 0 # HELP coder_aibridgeproxyd_mitm_responses_total Total number of MITM responses by HTTP status code class. # TYPE coder_aibridgeproxyd_mitm_responses_total counter coder_aibridgeproxyd_mitm_responses_total{code="",provider=""} 0 +# HELP coder_aibridged_provider_info One series per configured AI provider. Value is always 1; the status label (enabled, disabled, error) carries the alertable signal. +# TYPE coder_aibridged_provider_info gauge +coder_aibridged_provider_info{provider_name="",provider_type="",status=""} 0 +# HELP coder_aibridged_providers_last_reload_timestamp_seconds Unix timestamp of the last provider reload attempt, success or failure. +# TYPE coder_aibridged_providers_last_reload_timestamp_seconds gauge +coder_aibridged_providers_last_reload_timestamp_seconds 0 +# HELP coder_aibridged_providers_last_reload_success_timestamp_seconds Unix timestamp of the last provider reload that successfully refreshed the pool. A gap against coder_aibridged_providers_last_reload_timestamp_seconds means the loop is firing but the refresh function is failing. +# TYPE coder_aibridged_providers_last_reload_success_timestamp_seconds gauge +coder_aibridged_providers_last_reload_success_timestamp_seconds 0 +# HELP coder_aibridgeproxyd_provider_info One series per configured AI provider. Value is always 1; the status label (enabled, disabled, error) carries the alertable signal. +# TYPE coder_aibridgeproxyd_provider_info gauge +coder_aibridgeproxyd_provider_info{provider_name="",provider_type="",status=""} 0 +# HELP coder_aibridgeproxyd_providers_last_reload_timestamp_seconds Unix timestamp of the last provider reload attempt, success or failure. +# TYPE coder_aibridgeproxyd_providers_last_reload_timestamp_seconds gauge +coder_aibridgeproxyd_providers_last_reload_timestamp_seconds 0 +# HELP coder_aibridgeproxyd_providers_last_reload_success_timestamp_seconds Unix timestamp of the last provider reload that successfully refreshed the router. A gap against coder_aibridgeproxyd_providers_last_reload_timestamp_seconds means the loop is firing but the refresh function is failing. +# TYPE coder_aibridgeproxyd_providers_last_reload_success_timestamp_seconds gauge +coder_aibridgeproxyd_providers_last_reload_success_timestamp_seconds 0 diff --git a/scripts/metricsdocgen/scanner/scanner.go b/scripts/metricsdocgen/scanner/scanner.go index eee4166e49..f7ab57f9d4 100644 --- a/scripts/metricsdocgen/scanner/scanner.go +++ b/scripts/metricsdocgen/scanner/scanner.go @@ -40,6 +40,7 @@ var scanDirs = []string{ // // eliminate the need for this skip list. var skipPaths = []string{ + "coderd/aibridged/metrics.go", "enterprise/aibridgeproxyd/metrics.go", } diff --git a/scripts/should_deploy.sh b/scripts/should_deploy.sh index 6259f9e109..003828b411 100755 --- a/scripts/should_deploy.sh +++ b/scripts/should_deploy.sh @@ -17,6 +17,20 @@ deploy_branch=main # branch names. branch_name=$(git branch --show-current) +# Short circuit: we no longer deploy release branches to dogfood, and instead +# test them on the stable deployment. +# TODO: once we're happy with the new deployment process, we can remove this +# script and the related github workflow stuff. +if [[ "$branch_name" == "main" ]]; then + log "VERDICT: DEPLOY" + echo "DEPLOY" # stdout + exit 0 +else + log "VERDICT: NOOP" + echo "NOOP" # stdout + exit 0 +fi + if [[ "$branch_name" != "main" && ! "$branch_name" =~ ^release/[0-9]+\.[0-9]+$ ]]; then error "Current branch '$branch_name' is not a supported branch name for dogfood, must be 'main' or 'release/x.y'" fi diff --git a/site/package.json b/site/package.json index d2cf5e5512..0df6778347 100644 --- a/site/package.json +++ b/site/package.json @@ -69,7 +69,7 @@ "@xterm/addon-webgl": "0.19.0", "@xterm/xterm": "5.5.0", "ansi-to-html": "0.7.2", - "axios": "1.15.2", + "axios": "1.16.0", "chroma-js": "2.6.0", "class-variance-authority": "0.7.1", "clsx": "2.1.1", diff --git a/site/pnpm-lock.yaml b/site/pnpm-lock.yaml index be78e8c75d..f99dc40d62 100644 --- a/site/pnpm-lock.yaml +++ b/site/pnpm-lock.yaml @@ -123,8 +123,8 @@ importers: specifier: 0.7.2 version: 0.7.2 axios: - specifier: 1.15.2 - version: 1.15.2 + specifier: 1.16.0 + version: 1.16.0 chroma-js: specifier: 2.6.0 version: 2.6.0 @@ -495,6 +495,10 @@ packages: resolution: {integrity: sha512-9NhCeYjq9+3uxgdtp20LSiJXJvN0FeCtNGpJxuMFZ1Kv3cWUNb6DOhJwUvcVCzKGR66cw4njwM6hrJLqgOwbcw==, tarball: https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.29.0.tgz} engines: {node: '>=6.9.0'} + '@babel/code-frame@7.29.7': + resolution: {integrity: sha512-Aup7aUOfpbAUg2ROOJN6Iw5f9DMBlzu0mIkm/malLQFN/YQgO48wCj0Kxa3sEHJvPVFg7siR+qRInwXd2qhQKw==, tarball: https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.29.7.tgz} + engines: {node: '>=6.9.0'} + '@babel/compat-data@7.29.0': resolution: {integrity: sha512-T1NCJqT/j9+cn8fvkt7jtwbLBfLC/1y1c7NtCeXFRgzGTsafi68MRv8yzkYSapBnFA6L3U2VSc02ciDzoAJhJg==, tarball: https://registry.npmjs.org/@babel/compat-data/-/compat-data-7.29.0.tgz} engines: {node: '>=6.9.0'} @@ -545,6 +549,10 @@ packages: resolution: {integrity: sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==, tarball: https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.28.5.tgz} engines: {node: '>=6.9.0'} + '@babel/helper-validator-identifier@7.29.7': + resolution: {integrity: sha512-qehxGkRj55h/ff8EMaJ+cYhyaKlHIxqYDn682wQD7RNp9UujOQsHog2uS0r2vzr4pW+sXf90NeeayjcNaX3fFg==, tarball: https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.29.7.tgz} + engines: {node: '>=6.9.0'} + '@babel/helper-validator-option@7.27.1': resolution: {integrity: sha512-YvjJow9FxbhFFKDSuFnVCe2WxXk1zWc22fFePVNEaWJEu8IrZVlda6N0uHwzZrUM1il7NC9Mlp4MaJYbYd9JSg==, tarball: https://registry.npmjs.org/@babel/helper-validator-option/-/helper-validator-option-7.27.1.tgz} engines: {node: '>=6.9.0'} @@ -3030,8 +3038,8 @@ packages: resolution: {integrity: sha512-BASOg+YwO2C+346x3LZOeoovTIoTrRqEsqMa6fmfAV0P+U9mFr9NsyOEpiYvFjbc64NMrSswhV50WdXzdb/Z5A==, tarball: https://registry.npmjs.org/axe-core/-/axe-core-4.11.1.tgz} engines: {node: '>=4'} - axios@1.15.2: - resolution: {integrity: sha512-wLrXxPtcrPTsNlJmKjkPnNPK2Ihe0hn0wGSaTEiHRPxwjvJwT3hKmXF4dpqxmPO9SoNb2FsYXj/xEo0gHN+D5A==, tarball: https://registry.npmjs.org/axios/-/axios-1.15.2.tgz} + axios@1.16.0: + resolution: {integrity: sha512-6hp5CwvTPlN2A31g5dxnwAX0orzM7pmCRDLnZSX772mv8WDqICwFjowHuPs04Mc8deIld1+ejhtaMn5vp6b+1w==, tarball: https://registry.npmjs.org/axios/-/axios-1.16.0.tgz} babel-plugin-macros@3.1.0: resolution: {integrity: sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==, tarball: https://registry.npmjs.org/babel-plugin-macros/-/babel-plugin-macros-3.1.0.tgz} @@ -3746,8 +3754,8 @@ packages: es-module-lexer@2.1.0: resolution: {integrity: sha512-n27zTYMjYu1aj4MjCWzSP7G9r75utsaoc8m61weK+W8JMBGGQybd43GstCXZ3WNmSFtGT9wi59qQTW6mhTR5LQ==, tarball: https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-2.1.0.tgz} - es-object-atoms@1.1.1: - resolution: {integrity: sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==, tarball: https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz} + es-object-atoms@1.1.2: + resolution: {integrity: sha512-HWcBoN6NileqtSydK2FqHbS/LoDd2pqrnQHLyJzBj4kOp/ky2MWMN694xOfkK8/SnUsW2DH7EfyVlydKCsm1Zw==, tarball: https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.2.tgz} engines: {node: '>= 0.4'} es-set-tostringtag@2.1.0: @@ -4015,8 +4023,8 @@ packages: resolution: {integrity: sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==, tarball: https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz} engines: {node: '>= 0.4'} - hasown@2.0.3: - resolution: {integrity: sha512-ej4AhfhfL2Q2zpMmLo7U1Uv9+PyhIZpgQLGT1F9miIGmiCJIoCgSmczFdrc97mWT4kVY72KA+WnnhJ5pghSvSg==, tarball: https://registry.npmjs.org/hasown/-/hasown-2.0.3.tgz} + hasown@2.0.4: + resolution: {integrity: sha512-T2UbfbBEF32wiepXIsMlTW9+dDYC6wMh/t/vYA4tuOMKqWz/n3vr1NFSxQiyP+zk2mXsoMA/i/7qV6LKut1t1A==, tarball: https://registry.npmjs.org/hasown/-/hasown-2.0.4.tgz} engines: {node: '>= 0.4'} hast-util-from-parse5@8.0.3: @@ -5343,6 +5351,7 @@ packages: recharts@2.15.4: resolution: {integrity: sha512-UT/q6fwS3c1dHbXv2uFgYJ9BMFHu3fwnd7AYZaEQhXuYQ4hgsxLvsUXzGdKeZrW5xopzDCvuA2N41WJ88I7zIw==, tarball: https://registry.npmjs.org/recharts/-/recharts-2.15.4.tgz} engines: {node: '>=14'} + deprecated: 1.x and 2.x branches are no longer active. Bump to Recharts v3 to receive latest features and bugfixes. See https://github.com/recharts/recharts/wiki/3.0-migration-guide peerDependencies: react: ^16.0.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.0.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 @@ -6401,6 +6410,12 @@ snapshots: js-tokens: 4.0.0 picocolors: 1.1.1 + '@babel/code-frame@7.29.7': + dependencies: + '@babel/helper-validator-identifier': 7.29.7 + js-tokens: 4.0.0 + picocolors: 1.1.1 + '@babel/compat-data@7.29.0': {} '@babel/core@7.29.0': @@ -6478,6 +6493,8 @@ snapshots: '@babel/helper-validator-identifier@7.28.5': {} + '@babel/helper-validator-identifier@7.29.7': {} + '@babel/helper-validator-option@7.27.1': {} '@babel/helpers@7.26.10': @@ -8442,7 +8459,7 @@ snapshots: '@testing-library/dom@10.4.0': dependencies: - '@babel/code-frame': 7.29.0 + '@babel/code-frame': 7.29.7 '@babel/runtime': 7.26.10 '@types/aria-query': 5.0.4 aria-query: 5.3.0 @@ -9079,7 +9096,7 @@ snapshots: axe-core@4.11.1: {} - axios@1.15.2: + axios@1.16.0: dependencies: follow-redirects: 1.16.0 form-data: 4.0.4 @@ -9799,7 +9816,7 @@ snapshots: es-module-lexer@2.1.0: {} - es-object-atoms@1.1.1: + es-object-atoms@1.1.2: dependencies: es-errors: 1.3.0 @@ -9808,7 +9825,7 @@ snapshots: es-errors: 1.3.0 get-intrinsic: 1.3.0 has-tostringtag: 1.0.2 - hasown: 2.0.3 + hasown: 2.0.4 esbuild@0.25.12: optionalDependencies: @@ -9972,7 +9989,7 @@ snapshots: asynckit: 0.4.0 combined-stream: 1.0.8 es-set-tostringtag: 2.1.0 - hasown: 2.0.3 + hasown: 2.0.4 mime-types: 2.1.35 format@0.2.2: {} @@ -10042,12 +10059,12 @@ snapshots: call-bind-apply-helpers: 1.0.2 es-define-property: 1.0.1 es-errors: 1.3.0 - es-object-atoms: 1.1.1 + es-object-atoms: 1.1.2 function-bind: 1.1.2 get-proto: 1.0.1 gopd: 1.2.0 has-symbols: 1.1.0 - hasown: 2.0.3 + hasown: 2.0.4 math-intrinsics: 1.1.0 get-nonce@1.0.1: {} @@ -10055,7 +10072,7 @@ snapshots: get-proto@1.0.1: dependencies: dunder-proto: 1.0.1 - es-object-atoms: 1.1.1 + es-object-atoms: 1.1.2 glob-parent@5.1.2: dependencies: @@ -10100,7 +10117,7 @@ snapshots: dependencies: has-symbols: 1.1.0 - hasown@2.0.3: + hasown@2.0.4: dependencies: function-bind: 1.1.2 @@ -10275,7 +10292,7 @@ snapshots: internal-slot@1.0.6: dependencies: get-intrinsic: 1.3.0 - hasown: 2.0.3 + hasown: 2.0.4 side-channel: 1.1.0 internmap@1.0.1: {} @@ -10328,7 +10345,7 @@ snapshots: is-core-module@2.16.1: dependencies: - hasown: 2.0.3 + hasown: 2.0.4 is-date-object@1.0.5: dependencies: diff --git a/site/src/api/queries/workspaces.ts b/site/src/api/queries/workspaces.ts index 5782c32d18..ea6ec316ad 100644 --- a/site/src/api/queries/workspaces.ts +++ b/site/src/api/queries/workspaces.ts @@ -148,6 +148,7 @@ type AutoCreateWorkspaceOptions = { match: string | null; templateVersionId?: string; buildParameters?: WorkspaceBuildParameter[]; + templateVersionPresetId?: string; }; export const autoCreateWorkspace = (queryClient: QueryClient) => { @@ -158,6 +159,7 @@ export const autoCreateWorkspace = (queryClient: QueryClient) => { workspaceName, templateVersionId, buildParameters, + templateVersionPresetId, match, }: AutoCreateWorkspaceOptions) => { if (match) { @@ -185,6 +187,7 @@ export const autoCreateWorkspace = (queryClient: QueryClient) => { ...templateVersionParameters, name: workspaceName, rich_parameter_values: buildParameters, + template_version_preset_id: templateVersionPresetId, }); }, onSuccess: async () => { diff --git a/site/src/api/rbacresourcesGenerated.ts b/site/src/api/rbacresourcesGenerated.ts index 23bd95350c..2ac260c98a 100644 --- a/site/src/api/rbacresourcesGenerated.ts +++ b/site/src/api/rbacresourcesGenerated.ts @@ -50,6 +50,11 @@ export const RBACResourceActions: Partial< create: "create new audit log entries", read: "read audit logs", }, + boundary_log: { + create: "create boundary log records", + delete: "delete boundary logs", + read: "read boundary logs and session metadata", + }, boundary_usage: { delete: "delete boundary usage statistics", read: "read boundary usage statistics", diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 56fbbf6000..c4fae7a3d0 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -378,7 +378,9 @@ export const AIProviderBedrockSettingsVersion = 1; */ export interface AIProviderConfig { /** - * Type is the provider type: "openai", "anthropic", or "copilot". + * Type is the provider type. Valid values are: "openai", + * "anthropic", "azure", "bedrock", "google", "openai-compat", + * "openrouter", "vercel", "copilot". */ readonly type: string; /** @@ -552,6 +554,10 @@ export type APIKeyScope = | "audit_log:*" | "audit_log:create" | "audit_log:read" + | "boundary_log:*" + | "boundary_log:create" + | "boundary_log:delete" + | "boundary_log:read" | "boundary_usage:*" | "boundary_usage:delete" | "boundary_usage:read" @@ -778,6 +784,10 @@ export const APIKeyScopes: APIKeyScope[] = [ "audit_log:*", "audit_log:create", "audit_log:read", + "boundary_log:*", + "boundary_log:create", + "boundary_log:delete", + "boundary_log:read", "boundary_usage:*", "boundary_usage:delete", "boundary_usage:read", @@ -1957,7 +1967,9 @@ export type ChatErrorKind = | "auth" | "config" | "generic" + | "missing_key" | "overloaded" + | "provider_disabled" | "rate_limit" | "startup_timeout" | "timeout" @@ -1967,7 +1979,9 @@ export const ChatErrorKinds: ChatErrorKind[] = [ "auth", "config", "generic", + "missing_key", "overloaded", + "provider_disabled", "rate_limit", "startup_timeout", "timeout", @@ -4086,6 +4100,7 @@ export interface DeploymentValues { readonly agent_fallback_troubleshooting_url?: string; readonly browser_only?: boolean; readonly scim_api_key?: string; + readonly scim_use_legacy?: boolean; readonly external_token_encryption_keys?: string; readonly provisioner?: ProvisionerConfig; readonly rate_limit?: RateLimitConfig; @@ -6865,6 +6880,7 @@ export type RBACResource = | "assign_org_role" | "assign_role" | "audit_log" + | "boundary_log" | "boundary_usage" | "chat" | "connection_log" @@ -6915,6 +6931,7 @@ export const RBACResources: RBACResource[] = [ "assign_org_role", "assign_role", "audit_log", + "boundary_log", "boundary_usage", "chat", "connection_log", @@ -9143,6 +9160,16 @@ export interface UpsertGroupAIBudgetRequest { readonly spend_limit_micros: number; } +// From codersdk/aibridge.go +export interface UpsertUserAIBudgetOverrideRequest { + /** + * GroupID is the group the user's spend is attributed to. The user must + * be a member of this group. + */ + readonly group_id: string; + readonly spend_limit_micros: number; +} + // From codersdk/workspaceagentportshare.go export interface UpsertWorkspaceAgentPortShareRequest { readonly agent_name: string; @@ -9187,6 +9214,15 @@ export interface User extends ReducedUser { readonly has_ai_seat: boolean; } +// From codersdk/aibridge.go +export interface UserAIBudgetOverride { + readonly user_id: string; + readonly group_id: string; + readonly spend_limit_micros: number; + readonly created_at: string; + readonly updated_at: string; +} + // From codersdk/chats.go /** * UserAIProviderKeyConfig is a provider summary from the current user's diff --git a/site/src/modules/resources/AgentLogs/AgentLogLine.test.tsx b/site/src/modules/resources/AgentLogs/AgentLogLine.test.tsx new file mode 100644 index 0000000000..80b1d900b0 --- /dev/null +++ b/site/src/modules/resources/AgentLogs/AgentLogLine.test.tsx @@ -0,0 +1,23 @@ +import { screen } from "@testing-library/react"; +import type { Line } from "#/components/Logs/LogLine"; +import { renderComponent } from "#/testHelpers/renderHelpers"; +import { AgentLogLine } from "./AgentLogLine"; + +const line: Line = { + id: 1, + level: "info", + output: 'safe xss', + sourceId: "source-id", + time: "2024-03-14T11:31:04.090715Z", +}; + +describe("AgentLogLine", () => { + it("renders log HTML as escaped text", () => { + renderComponent(); + + expect(screen.queryByTestId("agent-log-xss")).not.toBeInTheDocument(); + expect( + screen.getByText(/safe xss<\/span>/), + ).toBeInTheDocument(); + }); +}); diff --git a/site/src/modules/resources/AgentLogs/AgentLogLine.tsx b/site/src/modules/resources/AgentLogs/AgentLogLine.tsx index 2fc68a63c2..d7b3c50dbf 100644 --- a/site/src/modules/resources/AgentLogs/AgentLogLine.tsx +++ b/site/src/modules/resources/AgentLogs/AgentLogLine.tsx @@ -5,7 +5,7 @@ import { type Line, LogLine, LogLinePrefix } from "#/components/Logs/LogLine"; // Approximate height of a log line. Used to control virtualized list height. export const AGENT_LOG_LINE_HEIGHT = 20; -const convert = new AnsiToHTML(); +const convert = new AnsiToHTML({ escapeXML: true }); interface AgentLogLineProps { line: Line; diff --git a/site/src/modules/resources/DownloadAgentLogsButton.stories.tsx b/site/src/modules/resources/DownloadAgentLogsButton.stories.tsx index 08f3f60e5d..b2c9114a46 100644 --- a/site/src/modules/resources/DownloadAgentLogsButton.stories.tsx +++ b/site/src/modules/resources/DownloadAgentLogsButton.stories.tsx @@ -34,7 +34,7 @@ export const ClickOnDownload: Story = { play: async ({ canvasElement, args }) => { const canvas = within(canvasElement); await userEvent.click( - canvas.getByRole("button", { name: "Download logs" }), + canvas.getByRole("button", { name: "Download agent logs" }), ); await waitFor(() => expect(args.download).toHaveBeenCalledWith( diff --git a/site/src/modules/resources/DownloadAgentLogsButton.tsx b/site/src/modules/resources/DownloadAgentLogsButton.tsx index 6850578897..fbc676d9d9 100644 --- a/site/src/modules/resources/DownloadAgentLogsButton.tsx +++ b/site/src/modules/resources/DownloadAgentLogsButton.tsx @@ -58,7 +58,7 @@ export const DownloadAgentLogsButton: FC = ({ }} > - {isDownloading ? "Downloading..." : "Download logs"} + {isDownloading ? "Downloading..." : "Download agent logs"} ); }; diff --git a/site/src/modules/resources/DownloadSelectedAgentLogsButton.tsx b/site/src/modules/resources/DownloadSelectedAgentLogsButton.tsx index 7aab6fd56c..16509cb568 100644 --- a/site/src/modules/resources/DownloadSelectedAgentLogsButton.tsx +++ b/site/src/modules/resources/DownloadSelectedAgentLogsButton.tsx @@ -63,7 +63,7 @@ export const DownloadSelectedAgentLogsButton: FC< > - {isDownloading ? "Downloading..." : "Download logs"} + {isDownloading ? "Downloading..." : "Download agent logs"} diff --git a/site/src/modules/workspaces/WorkspaceSharingForm/UserOrGroupAutocomplete.tsx b/site/src/modules/workspaces/WorkspaceSharingForm/UserOrGroupAutocomplete.tsx index ff7e5dba59..4a09998718 100644 --- a/site/src/modules/workspaces/WorkspaceSharingForm/UserOrGroupAutocomplete.tsx +++ b/site/src/modules/workspaces/WorkspaceSharingForm/UserOrGroupAutocomplete.tsx @@ -22,6 +22,7 @@ type UserOrGroupAutocompleteProps = { onChange: (value: UserOrGroupAutocompleteValue) => void; organizationId: string; exclude: ExcludableOption[]; + className?: string; }; const normalizeMember = ( @@ -36,6 +37,7 @@ export const UserOrGroupAutocomplete: FC = ({ onChange, organizationId, exclude, + className = "w-80", }) => { const [inputValue, setInputValue] = useState(""); const [open, setOpen] = useState(false); @@ -132,7 +134,7 @@ export const UserOrGroupAutocomplete: FC = ({ loading={membersQuery.isFetching || groupsQuery.isFetching} placeholder="Search for user or group" noOptionsText="No users or groups found" - className="w-80" + className={className} id="workspace-user-or-group-autocomplete" /> ); diff --git a/site/src/pages/AISettingsPage/ProvidersPage/ProvidersPageView.tsx b/site/src/pages/AISettingsPage/ProvidersPage/ProvidersPageView.tsx index d507d8fcd2..ec82741894 100644 --- a/site/src/pages/AISettingsPage/ProvidersPage/ProvidersPageView.tsx +++ b/site/src/pages/AISettingsPage/ProvidersPage/ProvidersPageView.tsx @@ -1,5 +1,5 @@ import { ChevronDownIcon, PlusIcon } from "lucide-react"; -import { Link, useNavigate } from "react-router"; +import { useNavigate } from "react-router"; import type { AIProvider } from "#/api/typesGenerated"; import { ErrorAlert } from "#/components/Alert/ErrorAlert"; import { Button } from "#/components/Button/Button"; @@ -9,6 +9,7 @@ import { DropdownMenuItem, DropdownMenuTrigger, } from "#/components/DropdownMenu/DropdownMenu"; +import { Link } from "#/components/Link/Link"; import { SettingsHeader, SettingsHeaderDescription, @@ -87,12 +88,8 @@ const ProvidersPageView: React.FC = ({ Bedrock. Providers configured here power Coder Agents, AI Gateway, and other capabilities such as APIs, CLI or IDEs that use LLMs. By default, users can supply their own keys for any provider.{" "} - - Manage deployment-wide BYOK + + View docs diff --git a/site/src/pages/AgentsPage/AgentChatPage.stories.tsx b/site/src/pages/AgentsPage/AgentChatPage.stories.tsx index aeac0227b1..0945ca1fc9 100644 --- a/site/src/pages/AgentsPage/AgentChatPage.stories.tsx +++ b/site/src/pages/AgentsPage/AgentChatPage.stories.tsx @@ -190,15 +190,13 @@ const extractPromptsFromMessages = ( return prompts; }; type ChatAuthorizationFixture = { - action: "share" | "update"; + action: "share"; allowed: boolean; }; const buildChatAuthorizationQuery = ( chat: Pick, - checks: Partial< - Record<"canShareChat" | "canUpdateChat", ChatAuthorizationFixture> - >, + checks: Partial>, ) => { const authorizationChecks: TypesGen.AuthorizationRequest["checks"] = {}; const authorizationResponse: TypesGen.AuthorizationResponse = {}; @@ -1180,9 +1178,6 @@ export const WithMessageHistory: Story = { await canvas.findByText("Markdown rendering showcase"), ).toBeVisible(); await waitFor(() => { - expect( - canvas.queryByText(/^This is not your chat/), - ).not.toBeInTheDocument(); expect( canvas.queryByText(/^This chat is owned by/), ).not.toBeInTheDocument(); @@ -1220,7 +1215,7 @@ export const RootChatShareActionAvailable: Story = { await userEvent.click(canvas.getByLabelText("Share chat")); const body = within(document.body); await waitFor(() => { - expect(body.getByText("Chat Sharing")).toBeVisible(); + expect(body.getByText("Chat sharing")).toBeVisible(); }); await waitFor(() => { expect(body.getByText("No shared members or groups yet")).toBeVisible(); @@ -1246,87 +1241,106 @@ export const Loading: Story = { }, }; -export const AdminViewingOtherUserChat: Story = { +export const OtherUserChatReadOnly: Story = { parameters: { - queries: [ - ...buildQueries( - { - id: CHAT_ID, - ...baseChatFields, - owner_id: "other-user-id", - owner_username: "OtherUser", - owner_name: "Other User", - title: "Other user's chat", - status: "completed", - }, - { messages: [], queued_messages: [], has_more: false }, - { diffUrl: undefined }, - ), - buildChatAuthorizationQuery( - { - owner_id: "other-user-id", - organization_id: baseChatFields.organization_id, - }, - { - canUpdateChat: { action: "update", allowed: true }, - canShareChat: { action: "share", allowed: false }, - }, - ), - ], + queries: buildQueries( + { + id: CHAT_ID, + ...baseChatFields, + owner_id: "other-user-id", + owner_username: "OtherUser", + owner_name: "Other User", + title: "Other user's chat", + status: "completed", + }, + { messages: [], queued_messages: [], has_more: false }, + { diffUrl: undefined }, + ), }, play: async ({ canvasElement }) => { const canvas = within(canvasElement); const banner = await canvas.findByText( - "This is not your chat. Prompting here will use Other User's identity.", + "This chat is owned by Other User. It is read-only.", ); expect(banner).toBeVisible(); expect(banner).toHaveAttribute("role", "status"); expect(canvas.getByRole("textbox")).toHaveAttribute( "aria-disabled", - "false", + "true", ); }, }; -export const SharedReadOnlyChat: Story = { +export const OtherUserChatWithMessages: Story = { parameters: { - queries: [ - ...buildQueries( - { - id: CHAT_ID, - ...baseChatFields, - owner_id: "other-user-id", - owner_username: "OtherUser", - owner_name: "Other User", - title: "Shared read-only chat", - status: "completed", - }, - { messages: [], queued_messages: [], has_more: false }, - { diffUrl: undefined }, - ), - buildChatAuthorizationQuery( - { - owner_id: "other-user-id", - organization_id: baseChatFields.organization_id, - }, - { - canUpdateChat: { action: "update", allowed: false }, - canShareChat: { action: "share", allowed: false }, - }, - ), - ], + queries: buildQueries( + { + id: CHAT_ID, + ...baseChatFields, + owner_id: "other-user-id", + owner_username: "OtherUser", + owner_name: "Other User", + title: "Other user's chat with messages", + status: "completed", + }, + { + messages: [ + { + id: 1, + chat_id: CHAT_ID, + created_at: "2026-02-18T00:00:01.000Z", + role: "user", + content: [{ type: "text", text: "Please review this plan." }], + }, + { + id: 2, + chat_id: CHAT_ID, + created_at: "2026-02-18T00:00:02.000Z", + role: "assistant", + content: [ + { type: "text", text: "I prepared a plan." }, + { + type: "tool-call", + tool_call_id: "other-user-plan", + tool_name: "propose_plan", + args: { path: "/home/coder/PLAN.md" }, + }, + { + type: "tool-result", + tool_call_id: "other-user-plan", + tool_name: "propose_plan", + result: { + file_id: "other-user-plan-file", + content: "# Plan\n\n1. Keep this chat read-only.", + }, + }, + ], + }, + ] as TypesGen.ChatMessage[], + queued_messages: [], + has_more: false, + }, + { diffUrl: undefined }, + ), }, play: async ({ canvasElement }) => { const canvas = within(canvasElement); expect( await canvas.findByText( - "This chat is owned by Other User. You have read-only access.", + "This chat is owned by Other User. It is read-only.", ), ).toBeVisible(); + expect(await canvas.findByText("Please review this plan.")).toBeVisible(); expect(canvas.getByRole("textbox")).toHaveAttribute( "aria-disabled", "true", ); + expect( + canvas.queryByRole("button", { name: "Edit message" }), + ).not.toBeInTheDocument(); + expect( + canvas.queryByRole("button", { name: "Implement plan" }), + ).not.toBeInTheDocument(); }, }; @@ -1352,9 +1366,6 @@ export const ArchivedOtherUserChat: Story = { expect( await canvas.findByText("This agent has been archived and is read-only."), ).toBeVisible(); - expect( - canvas.queryByText(/^This is not your chat/), - ).not.toBeInTheDocument(); expect( canvas.queryByText(/^This chat is owned by/), ).not.toBeInTheDocument(); diff --git a/site/src/pages/AgentsPage/AgentChatPage.tsx b/site/src/pages/AgentsPage/AgentChatPage.tsx index 9a73a122cf..a6d968d8af 100644 --- a/site/src/pages/AgentsPage/AgentChatPage.tsx +++ b/site/src/pages/AgentsPage/AgentChatPage.tsx @@ -845,8 +845,6 @@ const AgentChatPage: FC = () => { chatRecord !== undefined && currentUser.id !== chatRecord.owner_id; const isRootChat = chatRecord !== undefined && getParentChatID(chatRecord) === undefined; - const shouldCheckCanUpdateOtherUserChat = - isViewerNotOwner && !isArchived && chatRecord !== undefined; const shouldCheckCanShareChat = isRootChat; const chatAuthorizationObject = chatRecord !== undefined @@ -857,15 +855,6 @@ const AgentChatPage: FC = () => { } : undefined; const chatAuthorizationChecks: TypesGen.AuthorizationRequest["checks"] = {}; - if ( - chatAuthorizationObject !== undefined && - shouldCheckCanUpdateOtherUserChat - ) { - chatAuthorizationChecks.canUpdateChat = { - object: chatAuthorizationObject, - action: "update", - }; - } if (chatAuthorizationObject !== undefined && shouldCheckCanShareChat) { chatAuthorizationChecks.canShareChat = { object: chatAuthorizationObject, @@ -876,20 +865,16 @@ const AgentChatPage: FC = () => { ...checkAuthorization({ checks: chatAuthorizationChecks }), enabled: Object.keys(chatAuthorizationChecks).length > 0, }); - const canUpdateOtherUserChat = Boolean( - chatAuthorizationQuery.data?.canUpdateChat, - ); - const canUpdateOtherUserChatLoading = - shouldCheckCanUpdateOtherUserChat && chatAuthorizationQuery.isLoading; const canShareChat = isRootChat && Boolean(chatAuthorizationQuery.data?.canShareChat); - const chatOwner = - isViewerNotOwner && chatRecord?.owner_username - ? { - username: chatRecord.owner_username, - ...(chatRecord.owner_name ? { name: chatRecord.owner_name } : {}), - } - : undefined; + const chatOwner = isViewerNotOwner + ? { + ...(chatRecord?.owner_username + ? { username: chatRecord.owner_username } + : {}), + ...(chatRecord?.owner_name ? { name: chatRecord.owner_name } : {}), + } + : undefined; const planModeEnabled = chatRecord?.plan_mode === "plan"; // Initialize MCP selection from chat record or defaults. @@ -1144,11 +1129,7 @@ const AgentChatPage: FC = () => { const isChatSettingsPending = isUpdateChatPlanModePending || isUpdateChatWorkspacePending; const isInputDisabled = - !hasModelOptions || - isArchived || - isChatSettingsPending || - (isViewerNotOwner && - (canUpdateOtherUserChatLoading || !canUpdateOtherUserChat)); + !hasModelOptions || isArchived || isChatSettingsPending || isViewerNotOwner; const selectedWorkspaceId = chatQuery.data?.workspace_id ?? null; const isWorkspaceLoading = @@ -1182,9 +1163,11 @@ const AgentChatPage: FC = () => { store.setStreamError(reason); setChatErrorReason(agentId, reason); } else if (isApiError(error)) { + const detail = error.response?.data?.detail?.trim() || undefined; const reason: ChatDetailError = { kind: "generic", - message: error.message || "An unexpected error occurred.", + message: getErrorMessage(error, "An unexpected error occurred."), + ...(detail ? { detail } : {}), }; store.setStreamError(reason); setChatErrorReason(agentId, reason); @@ -1598,8 +1581,6 @@ const AgentChatPage: FC = () => { persistedError={persistedError} isArchived={isArchived} chatOwner={chatOwner} - canUpdateOtherUserChat={canUpdateOtherUserChat} - canUpdateOtherUserChatLoading={canUpdateOtherUserChatLoading} canShareChat={canShareChat} workspace={workspace} workspaceAgent={workspaceAgent} diff --git a/site/src/pages/AgentsPage/AgentChatPageView.stories.tsx b/site/src/pages/AgentsPage/AgentChatPageView.stories.tsx index b14563d320..baed2e4d8e 100644 --- a/site/src/pages/AgentsPage/AgentChatPageView.stories.tsx +++ b/site/src/pages/AgentsPage/AgentChatPageView.stories.tsx @@ -153,8 +153,6 @@ const StoryAgentChatPageView: FC = ({ editing, ...overrides }) => { modelSelectorPlaceholder: "Select a model", hasModelOptions: true, compressionThreshold: undefined as number | undefined, - canUpdateOtherUserChat: false, - canUpdateOtherUserChatLoading: false, isInputDisabled: false, isSubmissionPending: false, isInterruptPending: false, @@ -231,7 +229,7 @@ export const Default: Story = { play: async ({ canvasElement }) => { const canvas = within(canvasElement); expect( - canvas.queryByText(/^This is not your chat/), + canvas.queryByText(/^This chat is owned by/), ).not.toBeInTheDocument(); }, }; @@ -241,38 +239,20 @@ export const Archived: Story = { render: () => , }; -export const AdminViewingOtherUserChat: Story = { +export const OtherUserChatReadOnly: Story = { render: () => ( - ), - play: async ({ canvasElement }) => { - const canvas = within(canvasElement); - const banner = canvas.getByText( - "This is not your chat. Prompting here will use Other User's identity.", - ); - expect(banner).toBeVisible(); - expect(banner).toHaveAttribute("role", "status"); - }, -}; - -export const OtherUserChatOwnerPending: Story = { - render: () => ( - ), play: async ({ canvasElement }) => { const canvas = within(canvasElement); - expect( - canvas.queryByText(/^This is not your chat/), - ).not.toBeInTheDocument(); - expect(canvas.queryByText(/other-user-id/)).not.toBeInTheDocument(); + const banner = canvas.getByText( + "This chat is owned by Other User. It is read-only.", + ); + expect(banner).toBeVisible(); + expect(banner).toHaveAttribute("role", "status"); expect(canvas.getByLabelText("Chat message")).toHaveAttribute( "aria-disabled", "true", @@ -280,32 +260,20 @@ export const OtherUserChatOwnerPending: Story = { }, }; -export const ReadOnlyOtherUserChatOwner: Story = { - render: () => ( - - ), - play: async ({ canvasElement }) => { - const canvas = within(canvasElement); - const banner = canvas.getByText( - "This chat is owned by @OtherUser. You have read-only access.", - ); - expect(banner).toBeVisible(); - expect(banner).toHaveAttribute("role", "status"); - }, -}; - -export const ReadOnlyOtherUserChatOwnerPending: Story = { +export const OtherUserChatUsernameFallback: Story = { render: () => ( ), play: async ({ canvasElement }) => { const canvas = within(canvasElement); - expect(canvas.queryByText(/^This chat is owned/)).not.toBeInTheDocument(); - expect(canvas.queryByText(/other-user-id/)).not.toBeInTheDocument(); + const banner = canvas.getByText( + "This chat is owned by @OtherUser. It is read-only.", + ); + expect(banner).toBeVisible(); + expect(banner).toHaveAttribute("role", "status"); expect(canvas.getByLabelText("Chat message")).toHaveAttribute( "aria-disabled", "true", @@ -313,7 +281,23 @@ export const ReadOnlyOtherUserChatOwnerPending: Story = { }, }; -/** Archived chats stay read-only without the identity warning banner. */ +export const OtherUserChatOwnerFallback: Story = { + render: () => , + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + const banner = canvas.getByText( + "This chat is owned by another user. It is read-only.", + ); + expect(banner).toBeVisible(); + expect(banner).toHaveAttribute("role", "status"); + expect(canvas.getByLabelText("Chat message")).toHaveAttribute( + "aria-disabled", + "true", + ); + }, +}; + +/** Archived chats stay read-only without the owner banner. */ export const ArchivedOtherUserChat: Story = { render: () => ( { const canvas = within(canvasElement); expect( - canvas.queryByText(/^This is not your chat/), + canvas.queryByText(/^This chat is owned by/), ).not.toBeInTheDocument(); expect( canvas.getByText("This agent has been archived and is read-only."), @@ -775,18 +759,25 @@ export const LoadingSidebarCollapsed: Story = { // Helpers for seeding stores with messages // --------------------------------------------------------------------------- -const buildMessage = ( +const buildMessageWithContent = ( id: number, role: TypesGen.ChatMessageRole, - text: string, + content: TypesGen.ChatMessagePart[], ): TypesGen.ChatMessage => ({ id, chat_id: AGENT_ID, created_at: new Date(Date.now() - (10 - id) * 60_000).toISOString(), role, - content: [{ type: "text", text }], + content, }); +const buildMessage = ( + id: number, + role: TypesGen.ChatMessageRole, + text: string, +): TypesGen.ChatMessage => + buildMessageWithContent(id, role, [{ type: "text", text }]); + const buildStoreWithMessages = ( msgs: TypesGen.ChatMessage[], status: TypesGen.ChatStatus = "completed", @@ -797,6 +788,52 @@ const buildStoreWithMessages = ( return store; }; +const otherUserActionMessages: TypesGen.ChatMessage[] = [ + buildMessage(1, "user", "Please review this plan."), + buildMessageWithContent(2, "assistant", [ + { type: "text", text: "I prepared a plan." }, + { + type: "tool-call", + tool_call_id: "other-user-plan", + tool_name: "propose_plan", + args: { path: "/home/coder/PLAN.md" }, + }, + { + type: "tool-result", + tool_call_id: "other-user-plan", + tool_name: "propose_plan", + result: { + file_id: "other-user-plan-file", + content: "# Plan\n\n1. Keep this chat read-only.", + }, + }, + ]), +]; + +export const OtherUserChatHidesInlineActions: Story = { + render: () => ( + + ), + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + expect( + canvas.getByText("This chat is owned by Other User. It is read-only."), + ).toBeVisible(); + expect(await canvas.findByText("Please review this plan.")).toBeVisible(); + expect( + canvas.queryByRole("button", { name: "Edit message" }), + ).not.toBeInTheDocument(); + expect( + canvas.queryByRole("button", { name: "Implement plan" }), + ).not.toBeInTheDocument(); + }, +}; + // --------------------------------------------------------------------------- // Editing flow stories // --------------------------------------------------------------------------- @@ -1511,7 +1548,7 @@ export const ArchivedWithSharing: Story = { await userEvent.click(canvas.getByLabelText("Share chat")); const body = within(document.body); await waitFor(() => { - expect(body.getByText("Chat Sharing")).toBeVisible(); + expect(body.getByText("Chat sharing")).toBeVisible(); }); await waitFor(() => { expect(body.getByText("No shared members or groups yet")).toBeVisible(); @@ -1543,7 +1580,7 @@ export const ShareChatPopoverFromTopBar: Story = { await userEvent.click(canvas.getByLabelText("Share chat")); const body = within(document.body); await waitFor(() => { - expect(body.getByText("Chat Sharing")).toBeVisible(); + expect(body.getByText("Chat sharing")).toBeVisible(); }); await waitFor(() => { expect(body.getByText("No shared members or groups yet")).toBeVisible(); diff --git a/site/src/pages/AgentsPage/AgentChatPageView.tsx b/site/src/pages/AgentsPage/AgentChatPageView.tsx index a99e3ece9e..b71f4bf0c4 100644 --- a/site/src/pages/AgentsPage/AgentChatPageView.tsx +++ b/site/src/pages/AgentsPage/AgentChatPageView.tsx @@ -51,6 +51,11 @@ import type { ChatDetailError } from "./utils/usageLimitMessage"; type ChatStoreHandle = ReturnType["store"]; +type ChatOwnerInfo = { + name?: string; + username?: string; +}; + // Re-use the inner presentational components directly. They are interface EditingState { @@ -93,9 +98,7 @@ interface AgentChatPageViewProps { parentChat: TypesGen.Chat | undefined; persistedError: ChatDetailError | undefined; isArchived: boolean; - chatOwner: Pick | undefined; - canUpdateOtherUserChat: boolean; - canUpdateOtherUserChatLoading: boolean; + chatOwner: ChatOwnerInfo | undefined; canShareChat: boolean; workspaceAgent?: TypesGen.WorkspaceAgent; workspace?: TypesGen.Workspace; @@ -201,8 +204,6 @@ export const AgentChatPageView: FC = ({ persistedError, isArchived, chatOwner, - canUpdateOtherUserChat, - canUpdateOtherUserChatLoading, canShareChat, workspaceAgent, workspace, @@ -422,16 +423,14 @@ export const AgentChatPageView: FC = ({ editing.editingMessageId !== null || editing.editingQueuedMessageID !== null; - const chatOwnerUsername = chatOwner?.username.trim(); + const chatOwnerUsername = chatOwner?.username?.trim(); const chatOwnerLabel = chatOwner?.name?.trim() || - (chatOwnerUsername ? `@${chatOwnerUsername}` : undefined); - const chatOwnerWarning = - isArchived || canUpdateOtherUserChatLoading || chatOwnerLabel === undefined - ? undefined - : canUpdateOtherUserChat - ? `This is not your chat. Prompting here will use ${chatOwnerLabel}'s identity.` - : `This chat is owned by ${chatOwnerLabel}. You have read-only access.`; + (chatOwnerUsername ? `@${chatOwnerUsername}` : "another user"); + const isOtherUserReadOnly = !isArchived && chatOwner !== undefined; + const chatOwnerWarning = isOtherUserReadOnly + ? `This chat is owned by ${chatOwnerLabel}. It is read-only.` + : undefined; const titleElement = ( @@ -535,12 +534,22 @@ export const AgentChatPageView: FC<AgentChatPageViewProps> = ({ chatID={agentId} store={store} persistedError={persistedError} - onEditUserMessage={editing.handleEditUserMessage} + onEditUserMessage={ + isOtherUserReadOnly + ? undefined + : editing.handleEditUserMessage + } editingMessageId={editing.editingMessageId} urlTransform={urlTransform} mcpServers={mcpServers} - onImplementPlan={onImplementPlan} - onSendAskUserQuestionResponse={canSendAskUserQuestionResponse} + onImplementPlan={ + isOtherUserReadOnly ? undefined : onImplementPlan + } + onSendAskUserQuestionResponse={ + isOtherUserReadOnly + ? undefined + : canSendAskUserQuestionResponse + } /> </div> </ChatScrollContainer> diff --git a/site/src/pages/AgentsPage/AgentSettingsModelsPage.tsx b/site/src/pages/AgentsPage/AgentSettingsModelsPage.tsx index b7a8b36ba7..4b8c50fd96 100644 --- a/site/src/pages/AgentsPage/AgentSettingsModelsPage.tsx +++ b/site/src/pages/AgentsPage/AgentSettingsModelsPage.tsx @@ -47,7 +47,7 @@ const AgentSettingsModelsPage: FC = () => { <ChatModelAdminPanel section="models" sectionLabel="Models" - sectionDescription="Choose which models from your configured providers are available for users to select. You can set a default and adjust context limits." + sectionDescription="Choose which models from your configured providers are available for Coder Agents. Set a default and adjust context limits." providerConfigsData={providerConfigsQuery.data} modelConfigsData={modelConfigsQuery.data} modelCatalogData={modelCatalogQuery.data} diff --git a/site/src/pages/AgentsPage/DesktopPopoutPage.stories.tsx b/site/src/pages/AgentsPage/DesktopPopoutPage.stories.tsx new file mode 100644 index 0000000000..583e89e034 --- /dev/null +++ b/site/src/pages/AgentsPage/DesktopPopoutPage.stories.tsx @@ -0,0 +1,69 @@ +import type { Meta, StoryObj } from "@storybook/react-vite"; +import { expect, fn, within } from "storybook/test"; +import { DesktopPopoutPageView } from "./DesktopPopoutPage"; + +const meta = { + title: "pages/AgentsPage/DesktopPopoutPage", + component: DesktopPopoutPageView, + parameters: { + layout: "fullscreen", + }, +} satisfies Meta<typeof DesktopPopoutPageView>; + +export default meta; +type Story = StoryObj<typeof meta>; + +export const Connecting: Story = { + args: { + status: "connecting", + reconnect: fn(), + attach: fn(), + scaleMode: "fit", + onScaleModeChange: fn(), + isControlling: false, + onTakeControl: fn(), + onReleaseControl: fn(), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + await expect( + canvas.getByText("Connecting to desktop..."), + ).toBeInTheDocument(); + }, +}; + +export const Connected: Story = { + args: { + ...Connecting.args, + status: "connected", + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + await expect(canvas.getByText("Take control")).toBeInTheDocument(); + await expect(canvas.getByText("Zoom to 100%")).toBeInTheDocument(); + }, +}; + +export const ErrorState: Story = { + args: { + ...Connecting.args, + status: "error", + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + await expect(canvas.getByText("Reconnect")).toBeInTheDocument(); + }, +}; + +export const Disconnected: Story = { + args: { + ...Connecting.args, + status: "disconnected", + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + await expect( + canvas.getByText("Desktop disconnected. Reconnecting..."), + ).toBeInTheDocument(); + }, +}; diff --git a/site/src/pages/AgentsPage/DesktopPopoutPage.tsx b/site/src/pages/AgentsPage/DesktopPopoutPage.tsx new file mode 100644 index 0000000000..2377e0e147 --- /dev/null +++ b/site/src/pages/AgentsPage/DesktopPopoutPage.tsx @@ -0,0 +1,161 @@ +import type { FC } from "react"; +import { useEffect, useState } from "react"; +import { useParams } from "react-router"; +import { Button } from "#/components/Button/Button"; +import { Spinner } from "#/components/Spinner/Spinner"; +import { + DesktopToolbar, + type ScaleMode, +} from "./components/RightPanel/DesktopToolbar"; +import { + type DesktopConnectionStatus, + useDesktopConnection, +} from "./hooks/useDesktopConnection"; +import { useZoomShortcuts } from "./hooks/useZoomShortcuts"; + +export default function DesktopPopoutPage() { + const { agentId } = useParams() as { agentId: string }; + const [scaleMode, setScaleMode] = useState<ScaleMode>("fit"); + const [isControlling, setIsControlling] = useState(false); + + const { status, reconnect, attach } = useDesktopConnection({ + chatId: agentId, + activated: true, + scaleViewport: scaleMode === "fit", + }); + + // BroadcastChannel for parent window communication. + useEffect(() => { + const channel = new BroadcastChannel(`coder-desktop-${agentId}`); + + channel.postMessage({ type: "popout-opened" }); + + // Retry in case the parent's listener registered after this message. + const retryTimer = setTimeout(() => { + channel.postMessage({ type: "popout-opened" }); + }, 300); + + channel.addEventListener("message", (event) => { + if (event.data?.type === "bring-back") { + close(); + } + }); + + const handleBeforeUnload = () => { + channel.postMessage({ type: "popout-closed" }); + }; + addEventListener("beforeunload", handleBeforeUnload); + + return () => { + clearTimeout(retryTimer); + handleBeforeUnload(); + removeEventListener("beforeunload", handleBeforeUnload); + channel.close(); + }; + }, [agentId]); + + useZoomShortcuts(setScaleMode); + + return ( + <DesktopPopoutPageView + status={status} + reconnect={reconnect} + attach={attach} + scaleMode={scaleMode} + onScaleModeChange={setScaleMode} + isControlling={isControlling} + onTakeControl={() => setIsControlling(true)} + onReleaseControl={() => setIsControlling(false)} + /> + ); +} + +export interface DesktopPopoutPageViewProps { + status: DesktopConnectionStatus; + reconnect: () => void; + attach: (container: HTMLElement) => void; + scaleMode: ScaleMode; + onScaleModeChange: (mode: ScaleMode) => void; + isControlling: boolean; + onTakeControl: () => void; + onReleaseControl: () => void; +} + +export const DesktopPopoutPageView: FC<DesktopPopoutPageViewProps> = ({ + status, + reconnect, + attach, + scaleMode, + onScaleModeChange, + isControlling, + onTakeControl, + onReleaseControl, +}) => { + if (status === "idle" || status === "connecting") { + return ( + <div className="flex h-screen w-screen items-center justify-center bg-surface-primary"> + <div className="flex flex-col items-center gap-2 text-content-secondary"> + <Spinner loading className="h-6 w-6" /> + <span className="text-sm"> + {status === "idle" + ? "Initializing desktop..." + : "Connecting to desktop..."} + </span> + </div> + </div> + ); + } + + if (status === "error") { + return ( + <div className="flex h-screen w-screen items-center justify-center bg-surface-primary"> + <div className="flex flex-col items-center gap-3 text-content-secondary"> + <span className="text-center text-sm"> + Failed to connect to the desktop session. The agent may not be + connected or the desktop environment may not be available. + </span> + <Button variant="outline" size="sm" onClick={reconnect}> + Reconnect + </Button> + </div> + </div> + ); + } + + if (status === "disconnected") { + return ( + <div className="flex h-screen w-screen items-center justify-center bg-surface-primary"> + <div className="flex flex-col items-center gap-2 text-content-secondary"> + <Spinner loading className="h-6 w-6" /> + <span className="text-sm">Desktop disconnected. Reconnecting...</span> + </div> + </div> + ); + } + + return ( + <div className="flex h-screen w-screen flex-col overflow-hidden bg-surface-secondary"> + <DesktopToolbar + scaleMode={scaleMode} + onScaleModeChange={onScaleModeChange} + isControlling={isControlling} + onTakeControl={onTakeControl} + onReleaseControl={onReleaseControl} + isPoppedOut + /> + <div + ref={(el) => { + if (el) attach(el); + }} + className="min-h-0 flex-1 overflow-hidden bg-surface-secondary" + inert={!isControlling ? true : undefined} + role="application" + aria-label={ + isControlling + ? "Remote desktop (interactive)" + : "Remote desktop (view only, take control to interact)" + } + /> + </div> + ); +}; diff --git a/site/src/pages/AgentsPage/components/ChatConversation/ChatStatusCallout.tsx b/site/src/pages/AgentsPage/components/ChatConversation/ChatStatusCallout.tsx index c030cba85f..857d15a5eb 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/ChatStatusCallout.tsx +++ b/site/src/pages/AgentsPage/components/ChatConversation/ChatStatusCallout.tsx @@ -1,7 +1,9 @@ import { type FC, useEffect, useState } from "react"; import { Alert, AlertDescription, AlertTitle } from "#/components/Alert/Alert"; import { Link } from "#/components/Link/Link"; -import { Response, Shimmer } from "../ChatElements"; +import { Shimmer } from "../ChatElements"; +import { TranscriptRow } from "../ChatElements/TranscriptRow"; +import { ToolIcon } from "../ChatElements/tools/ToolIcon"; import { getProviderStatusURL } from "./chatStatusHelpers"; import type { LiveStatusModel } from "./liveStatusModel"; @@ -18,25 +20,21 @@ type ReconnectingStatus = Extract<LiveStatusModel, { phase: "reconnecting" }>; const StatusPlaceholder: FC<{ text: string; shimmer?: boolean; -}> = ({ text, shimmer = false }) => { + showThinkingIcon?: boolean; +}> = ({ text, shimmer = false, showThinkingIcon = false }) => { return ( - <div className="relative min-h-6"> - {/* Reserve the final response height without exposing a selectable copy. */} - <Response aria-hidden className="invisible select-none"> - {text} - </Response> - <div className="pointer-events-none absolute inset-0 flex items-baseline gap-2"> - {shimmer ? ( - <Shimmer as="div" className="text-[13px] leading-relaxed"> - {text} - </Shimmer> - ) : ( - <span className="text-[13px] leading-relaxed text-content-secondary"> - {text} - </span> - )} - </div> - </div> + <TranscriptRow className="gap-2 text-content-secondary"> + {showThinkingIcon && <ToolIcon name="thinking" isError={false} />} + {shimmer ? ( + <Shimmer as="span" className="text-[13px] leading-6"> + {text} + </Shimmer> + ) : ( + <span className="text-[13px] leading-6 text-content-secondary"> + {text} + </span> + )} + </TranscriptRow> ); }; @@ -54,6 +52,7 @@ const StartingPlaceholder: FC = () => { <StatusPlaceholder text={isDelayed ? DELAYED_STARTUP_TEXT : THINKING_TEXT} shimmer={!isDelayed} + showThinkingIcon={!isDelayed} /> ); }; @@ -158,11 +157,17 @@ const StatusAlert: FC<{ status: RetryOrFailedStatus }> = ({ status }) => { </Link> )} </span> - {status.phase === "failed" && status.detail && ( - <span className="mt-1 block text-content-secondary"> - {status.detail} - </span> - )} + {status.phase === "failed" && + status.detail && + (status.kind === "generic" ? ( + <code className="mt-1 block whitespace-pre-wrap text-xs text-content-secondary font-mono bg-surface-secondary rounded-md"> + {status.detail} + </code> + ) : ( + <span className="mt-1 block text-content-secondary"> + {status.detail} + </span> + ))} </AlertDescription> </Alert> ); diff --git a/site/src/pages/AgentsPage/components/ChatConversation/ConversationTimeline.tsx b/site/src/pages/AgentsPage/components/ChatConversation/ConversationTimeline.tsx index daf1b4c007..83024550f3 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/ConversationTimeline.tsx +++ b/site/src/pages/AgentsPage/components/ChatConversation/ConversationTimeline.tsx @@ -39,6 +39,7 @@ import { } from "../ChatElements/tools/ReadFileTool"; import type { SubagentVariant } from "../ChatElements/tools/subagentDescriptor"; import { ToolCollapsible } from "../ChatElements/tools/ToolCollapsible"; +import { ToolIcon } from "../ChatElements/tools/ToolIcon"; import { ImageLightbox } from "../ImageLightbox"; import { TextPreviewDialog } from "../TextPreviewDialog"; import { @@ -160,13 +161,16 @@ const ReasoningDisclosure = memo<{ expanded={expanded} onExpandedChange={(open) => setManualToggle(open)} header={ - isStreaming ? ( - <Shimmer as="span" className="text-[13px]"> - {title} - </Shimmer> - ) : ( - <span className="text-[13px]">{title}</span> - ) + <> + <ToolIcon name="thinking" isError={false} /> + {isStreaming ? ( + <Shimmer as="span" className="text-[13px] leading-6"> + {title} + </Shimmer> + ) : ( + <span className="text-[13px] leading-6">{title}</span> + )} + </> } > {hasText && ( @@ -293,7 +297,7 @@ export const BlockList: FC<{ const thinkingDisplayMode: ThinkingDisplayMode = prefQuery.data?.thinking_display_mode || "auto"; const shellToolDisplayMode: TypesGen.AgentDisplayMode = - prefQuery.data?.shell_tool_display_mode || "auto"; + prefQuery.data?.shell_tool_display_mode || "always_collapsed"; const codeDiffDisplayMode: TypesGen.AgentDisplayMode = prefQuery.data?.code_diff_display_mode || "auto"; diff --git a/site/src/pages/AgentsPage/components/ChatConversation/LiveStreamTail.stories.tsx b/site/src/pages/AgentsPage/components/ChatConversation/LiveStreamTail.stories.tsx index a4ddf3bf6f..32484ed15c 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/LiveStreamTail.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatConversation/LiveStreamTail.stories.tsx @@ -193,6 +193,41 @@ export const TerminalTimeoutErrorUnknownProvider: Story = { }, }; +/** Missing API key shows the "Chat interrupted" terminal error. */ +export const TerminalMissingKeyError: Story = { + args: { + ...defaultArgs, + liveStatus: buildLiveStatus({ + streamError: { + kind: "missing_key", + message: + "This conversation was started with an API key that is no longer available. Send your message again to continue.", + retryable: false, + detail: + "If this error persists after resending, please report it as a bug.", + }, + }), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + expect( + canvas.getByRole("heading", { name: /chat interrupted/i }), + ).toBeVisible(); + expect( + canvas.getByText( + /this conversation was started with an api key that is no longer available/i, + ), + ).toBeVisible(); + expect( + canvas.getByText(/if this error persists after resending/i), + ).toBeVisible(); + // Guard against the generic fallback. + expect( + canvas.queryByText(/the chat request failed unexpectedly/i), + ).not.toBeInTheDocument(); + }, +}; + /** Retrying a transport timeout shows attempt + countdown. */ export const RetryingTimeoutAnthropic: Story = { args: { @@ -253,6 +288,40 @@ export const TerminalStartupTimeoutError: Story = { }, }; +/** Disabled provider errors render an admin-oriented message without retry. */ +export const TerminalProviderDisabledError: Story = { + args: { + ...defaultArgs, + liveStatus: buildLiveStatus({ + streamError: { + kind: "provider_disabled", + message: + "The OpenAI provider has been disabled. Contact your Coder administrator.", + provider: "openai", + retryable: false, + statusCode: 503, + }, + }), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + expect( + canvas.getByRole("heading", { name: /provider disabled/i }), + ).toBeVisible(); + expect( + canvas.getByText( + /the openai provider has been disabled.*contact your coder administrator/i, + ), + ).toBeVisible(); + expect(canvas.getByText(/^HTTP 503$/)).toBeVisible(); + // No retry or status link for administrative disablement. + expect(canvas.queryByText(/retrying/i)).not.toBeInTheDocument(); + expect( + canvas.queryByRole("link", { name: /status/i }), + ).not.toBeInTheDocument(); + }, +}; + /** Generic failures do not show usage or provider CTAs. */ export const GenericErrorDoesNotShowUsageAction: Story = { args: { @@ -282,7 +351,7 @@ export const GenericErrorDoesNotShowUsageAction: Story = { }, }; -/** Provider detail renders as a muted secondary line under the main error. */ +/** Provider detail renders in a monospace block for generic errors. */ export const GenericErrorShowsProviderDetail: Story = { args: { ...defaultArgs, diff --git a/site/src/pages/AgentsPage/components/ChatConversation/LiveStreamTail.tsx b/site/src/pages/AgentsPage/components/ChatConversation/LiveStreamTail.tsx index 93b2c5ce69..9d3d5da412 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/LiveStreamTail.tsx +++ b/site/src/pages/AgentsPage/components/ChatConversation/LiveStreamTail.tsx @@ -71,7 +71,13 @@ export const LiveStreamTailContent = ({ } return ( - <div className="flex flex-col gap-2"> + <div + className={ + isTranscriptEmpty + ? "flex flex-col gap-2" + : "mt-2 flex flex-col gap-2 empty:mt-0" + } + > {shouldRenderEmptyState && ( <div className="py-12 text-center text-content-secondary"> <p className="text-sm">Start a conversation with your agent.</p> diff --git a/site/src/pages/AgentsPage/components/ChatConversation/StreamingOutput.stories.tsx b/site/src/pages/AgentsPage/components/ChatConversation/StreamingOutput.stories.tsx index e78762dd6c..11546ed4db 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/StreamingOutput.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatConversation/StreamingOutput.stories.tsx @@ -282,7 +282,7 @@ export const ThinkingDuringStreamingWithToolCalls: Story = { expect(canvas.getAllByText("Thinking").length).toBeGreaterThanOrEqual(1); const executeButton = canvas.getByRole("button", { - name: /collapse command/i, + name: /expand command/i, }); const readFileLabel = canvas.getByText(/reading README\.md/i); const thinkingText = canvas.getAllByText("Thinking").at(-1); diff --git a/site/src/pages/AgentsPage/components/ChatConversation/StreamingOutput.tsx b/site/src/pages/AgentsPage/components/ChatConversation/StreamingOutput.tsx index d3ce729db2..af8534d960 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/StreamingOutput.tsx +++ b/site/src/pages/AgentsPage/components/ChatConversation/StreamingOutput.tsx @@ -9,6 +9,7 @@ import { } from "../ChatElements"; import { TranscriptRow } from "../ChatElements/TranscriptRow"; import type { SubagentVariant } from "../ChatElements/tools/subagentDescriptor"; +import { ToolIcon } from "../ChatElements/tools/ToolIcon"; import { ChatStatusCallout } from "./ChatStatusCallout"; import { BlockList } from "./ConversationTimeline"; import type { LiveStatusModel } from "./liveStatusModel"; @@ -36,7 +37,8 @@ const hasTextOrReasoningBlock = (blocks: readonly RenderBlock[]): boolean => const StreamingThinkingPlaceholder: FC = () => ( <div data-transcript-row="" className="text-content-secondary"> <TranscriptRow className="w-full gap-2"> - <Shimmer as="span" className="text-[13px] leading-relaxed"> + <ToolIcon name="thinking" isError={false} /> + <Shimmer as="span" className="text-[13px] leading-6"> Thinking </Shimmer> </TranscriptRow> diff --git a/site/src/pages/AgentsPage/components/ChatConversation/chatStatusHelpers.ts b/site/src/pages/AgentsPage/components/ChatConversation/chatStatusHelpers.ts index c718e2a007..d9ea6f6e59 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/chatStatusHelpers.ts +++ b/site/src/pages/AgentsPage/components/ChatConversation/chatStatusHelpers.ts @@ -42,6 +42,10 @@ export const getErrorTitle = ( return "Configuration error"; case "usage_limit": return "Usage limit reached"; + case "missing_key": + return "Chat interrupted"; + case "provider_disabled": + return "Provider disabled"; default: return mode === "retry" ? "Retrying request" : "Request failed"; } diff --git a/site/src/pages/AgentsPage/components/ChatElements/Conversation.stories.tsx b/site/src/pages/AgentsPage/components/ChatElements/Conversation.stories.tsx index e0038879eb..17ab70c005 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/Conversation.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/Conversation.stories.tsx @@ -2,7 +2,6 @@ import type { Meta, StoryObj } from "@storybook/react-vite"; import { Conversation, ConversationItem } from "./Conversation"; import { Message, MessageContent } from "./Message"; import { Shimmer } from "./Shimmer"; -import { Thinking } from "./Thinking"; const meta: Meta<typeof Conversation> = { title: "pages/AgentsPage/ChatElements/Conversation", @@ -37,10 +36,6 @@ export const ConversationWithMessages: Story = { <Message className="w-full"> <MessageContent className="whitespace-normal"> <div className="space-y-3"> - <Thinking> - Inspecting auth state and recent command output before - suggesting a fix. - </Thinking> <div className="text-sm text-content-primary"> The remote command failed because external auth needs to be refreshed. diff --git a/site/src/pages/AgentsPage/components/ChatElements/Thinking.tsx b/site/src/pages/AgentsPage/components/ChatElements/Thinking.tsx deleted file mode 100644 index f049c02183..0000000000 --- a/site/src/pages/AgentsPage/components/ChatElements/Thinking.tsx +++ /dev/null @@ -1,17 +0,0 @@ -import type { ComponentPropsWithRef } from "react"; -import { cn } from "#/utils/cn"; - -type ThinkingProps = ComponentPropsWithRef<"div">; - -export const Thinking = ({ className, ref, ...props }: ThinkingProps) => { - return ( - <div - ref={ref} - className={cn( - "rounded-lg border border-border bg-surface-primary px-3 py-2 text-xs text-content-secondary", - className, - )} - {...props} - /> - ); -}; diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/AdvisorTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/AdvisorTool.tsx index 1dac5ba647..eaaca11788 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/AdvisorTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/AdvisorTool.tsx @@ -53,55 +53,55 @@ export const AdvisorTool: React.FC<AdvisorToolProps> = ({ defaultExpanded headerClassName="items-start" header={(expanded) => ( - <> - <div className="flex min-w-0 flex-1 flex-col gap-0.5"> - <div className="flex min-w-0 items-center gap-2 leading-4"> - <ToolIcon - name="advisor" - isError={showError} - isRunning={isRunning} - /> - <ToolLabel - name="advisor" - args={{ question: questionText }} - result={resultType ? { type: resultType } : undefined} - /> - {isRunning && ( - <span className="shrink-0 rounded-full border border-solid border-border-default px-2 text-[13px] leading-4 text-content-secondary"> - {RUNNING_MESSAGE} - </span> - )} - {advisorModelText && ( - <span className="min-w-0 truncate rounded-full border border-solid border-border-default px-2 text-[13px] leading-4 text-content-secondary"> - {advisorModelText} - </span> - )} - {remainingUses !== undefined && ( - <span className="shrink-0 rounded-full border border-solid border-border-default px-2 text-[13px] leading-4 text-content-secondary"> - {remainingUses.toLocaleString("en-US")} uses left - </span> - )} - </div> - <span - className={cn( - "ml-6 block whitespace-normal break-words text-[13px]", - "font-normal leading-5 text-content-primary", - "[overflow-wrap:anywhere]", - !expanded && "line-clamp-2", - )} - > - {questionText} - </span> + <div className="flex min-w-0 flex-1 flex-col gap-0.5"> + <div className="flex min-w-0 items-center gap-2 leading-4"> + <ToolIcon + name="advisor" + isError={showError} + isRunning={isRunning} + /> + <ToolLabel + name="advisor" + args={{ question: questionText }} + result={resultType ? { type: resultType } : undefined} + /> + {isRunning && ( + <span className="shrink-0 rounded-full border border-solid border-border-default px-2 text-[13px] leading-4 text-content-secondary"> + {RUNNING_MESSAGE} + </span> + )} + {advisorModelText && ( + <span className="min-w-0 truncate rounded-full border border-solid border-border-default px-2 text-[13px] leading-4 text-content-secondary"> + {advisorModelText} + </span> + )} + {remainingUses !== undefined && ( + <span className="shrink-0 rounded-full border border-solid border-border-default px-2 text-[13px] leading-4 text-content-secondary"> + {remainingUses.toLocaleString("en-US")} uses left + </span> + )} </div> - {showLimitReached ? ( - <TriangleAlertIcon className="mt-0.5 size-3.5 shrink-0 text-content-warning" /> - ) : showError ? ( - <CircleAlertIcon className="mt-0.5 size-3.5 shrink-0 text-content-destructive" /> - ) : isRunning ? ( - <LoaderIcon className="mt-0.5 size-3.5 shrink-0 animate-spin motion-reduce:animate-none text-content-secondary" /> - ) : null} - </> + <span + className={cn( + "ml-6 block whitespace-normal break-words text-[13px]", + "font-normal leading-5 text-content-primary", + "[overflow-wrap:anywhere]", + !expanded && "line-clamp-2", + )} + > + {questionText} + </span> + </div> )} + headerStatus={ + showLimitReached ? ( + <TriangleAlertIcon className="mt-0.5 size-3.5 shrink-0 text-content-warning" /> + ) : showError ? ( + <CircleAlertIcon className="mt-0.5 size-3.5 shrink-0 text-content-destructive" /> + ) : isRunning ? ( + <LoaderIcon className="mt-0.5 size-3.5 shrink-0 animate-spin motion-reduce:animate-none text-content-secondary" /> + ) : null + } > <ScrollArea className="mt-1.5 rounded-md border border-solid border-border-default bg-surface-primary" diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/AskUserQuestionTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/AskUserQuestionTool.tsx index 6bb5b1a07c..0a0386b962 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/AskUserQuestionTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/AskUserQuestionTool.tsx @@ -10,6 +10,7 @@ import { Input } from "#/components/Input/Input"; import { RadioGroup, RadioGroupItem } from "#/components/RadioGroup/RadioGroup"; import { cn } from "#/utils/cn"; import { TranscriptRow } from "../TranscriptRow"; +import { ToolIcon } from "./ToolIcon"; import type { ToolStatus } from "./utils"; export type AskUserQuestion = { @@ -539,8 +540,9 @@ export const AskUserQuestionTool: FC<AskUserQuestionToolProps> = ({ <div className="w-full"> <TranscriptRow role="alert" - className="gap-1.5 text-[13px] text-content-secondary" + className="gap-2 text-[13px] text-content-secondary" > + <ToolIcon name="ask_user_question" isError={isError} /> <TriangleAlertIcon aria-label="Error" className="size-3.5 shrink-0 text-content-secondary" @@ -555,7 +557,16 @@ export const AskUserQuestionTool: FC<AskUserQuestionToolProps> = ({ return ( <div className="w-full"> {isRunning ? ( - <TranscriptRow role="status" aria-live="polite" className="gap-1.5"> + <TranscriptRow + role="status" + aria-live="polite" + className="gap-2 text-content-secondary" + > + <ToolIcon + name="ask_user_question" + isError={false} + isRunning={isRunning} + /> <span className="text-[13px] text-content-secondary"> Asking for clarification... </span> @@ -677,7 +688,16 @@ export const AskUserQuestionTool: FC<AskUserQuestionToolProps> = ({ return ( <div className="w-full"> {isRunning && ( - <TranscriptRow role="status" aria-live="polite" className="gap-1.5"> + <TranscriptRow + role="status" + aria-live="polite" + className="gap-2 text-content-secondary" + > + <ToolIcon + name="ask_user_question" + isError={false} + isRunning={isRunning} + /> <span className="text-[13px] text-content-secondary"> Asking for clarification... </span> diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ChatSummarizedTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ChatSummarizedTool.tsx index 3319ddfd46..61a041d890 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ChatSummarizedTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ChatSummarizedTool.tsx @@ -8,6 +8,7 @@ import { } from "#/components/Tooltip/Tooltip"; import { Response } from "../Response"; import { ToolCollapsible } from "./ToolCollapsible"; +import { ToolIcon } from "./ToolIcon"; import type { ToolStatus } from "./utils"; /** @@ -29,9 +30,18 @@ export const ChatSummarizedTool: React.FC<{ hasContent={hasSummary} header={ <> - <span className="text-[13px]"> + <ToolIcon + name="chat_summarized" + isError={isError} + isRunning={isRunning} + /> + <span className="text-[13px] leading-6"> {isRunning ? "Summarizing…" : "Summarized"} </span> + </> + } + headerStatus={ + <> {isError && ( <Tooltip> <TooltipTrigger asChild> diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ComputerTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ComputerTool.tsx index ae70532609..513ee35175 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ComputerTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ComputerTool.tsx @@ -8,6 +8,7 @@ import { } from "#/components/Tooltip/Tooltip"; import { ImageLightbox } from "../../ImageLightbox"; import { ToolCollapsible } from "./ToolCollapsible"; +import { ToolIcon } from "./ToolIcon"; import type { ToolStatus } from "./utils"; /** @@ -39,9 +40,14 @@ export const ComputerTool: React.FC<{ defaultExpanded={hasImage} header={ <> - <span className="text-[13px]"> + <ToolIcon name="computer" isError={isError} isRunning={isRunning} /> + <span className="text-[13px] leading-6"> {isRunning ? "Taking screenshot…" : "Screenshot"} </span> + </> + } + headerStatus={ + <> {isError && ( <Tooltip> <TooltipTrigger asChild> diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/CreateWorkspaceTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/CreateWorkspaceTool.tsx index 85a762705d..256717ef0d 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/CreateWorkspaceTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/CreateWorkspaceTool.tsx @@ -1,9 +1,4 @@ -import { - ExternalLinkIcon, - LoaderIcon, - MonitorIcon, - TriangleAlertIcon, -} from "lucide-react"; +import { ExternalLinkIcon, LoaderIcon, TriangleAlertIcon } from "lucide-react"; import type React from "react"; import { Link } from "react-router"; import { @@ -12,6 +7,7 @@ import { TooltipTrigger, } from "#/components/Tooltip/Tooltip"; import { ToolCollapsible } from "./ToolCollapsible"; +import { ToolIcon } from "./ToolIcon"; import { asRecord, asString, type ToolStatus } from "./utils"; import { WorkspaceBuildLogSection } from "./WorkspaceBuildLogSection"; @@ -72,8 +68,26 @@ export const CreateWorkspaceTool: React.FC<{ const header = ( <> - <MonitorIcon className="size-4 shrink-0 text-current" /> - <span className="text-[13px]">{label}</span> + <ToolIcon + name="create_workspace" + isError={isError} + isRunning={isRunning} + /> + <span className="text-[13px] leading-6">{label}</span> + {workspaceLink && !isRunning && ( + <Link + to={workspaceLink} + onClick={(e) => e.stopPropagation()} + className="ml-1 inline-flex align-middle text-content-secondary opacity-50 transition-opacity hover:opacity-100" + aria-label="View workspace" + > + <ExternalLinkIcon className="size-3" /> + </Link> + )} + </> + ); + const headerStatus = ( + <> {isError && ( <Tooltip> <TooltipTrigger asChild> @@ -87,16 +101,6 @@ export const CreateWorkspaceTool: React.FC<{ {isRunning && ( <LoaderIcon className="size-3.5 shrink-0 animate-spin motion-reduce:animate-none text-current" /> )} - {workspaceLink && !isRunning && ( - <Link - to={workspaceLink} - onClick={(e) => e.stopPropagation()} - className="ml-1 inline-flex align-middle text-content-secondary opacity-50 transition-opacity hover:opacity-100" - aria-label="View workspace" - > - <ExternalLinkIcon className="size-3" /> - </Link> - )} </> ); @@ -104,6 +108,7 @@ export const CreateWorkspaceTool: React.FC<{ <div className="w-full"> <ToolCollapsible header={header} + headerStatus={headerStatus} hasContent={hasBuildLogs} defaultExpanded={isRunning} > diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/EditFilesTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/EditFilesTool.tsx index 4f9f69cced..b7f9cf6a64 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/EditFilesTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/EditFilesTool.tsx @@ -16,6 +16,7 @@ import { resolveAgentDisplayState, } from "./displayMode"; import { AgentDisplayModeToolCollapsible } from "./ToolCollapsible"; +import { ToolIcon } from "./ToolIcon"; import { DIFFS_FONT_STYLE, type EditFilesFileEntry, @@ -69,7 +70,12 @@ export const EditFilesTool: React.FC<{ autoDisplayState={EDIT_FILES_AUTO_DISPLAY_STATE} header={ <> - <span className="text-[13px]">{label}</span> + <ToolIcon name="edit_files" isError={isError} isRunning={isRunning} /> + <span className="text-[13px] leading-6">{label}</span> + </> + } + headerStatus={ + <> {isError && ( <Tooltip> <TooltipTrigger asChild> diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ExecuteTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ExecuteTool.tsx index 7a38ce4be7..6f0a3b18f0 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ExecuteTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ExecuteTool.tsx @@ -26,6 +26,7 @@ import { isAgentDisplayOpen, resolveAgentDisplayState, } from "./displayMode"; +import { ToolIcon } from "./ToolIcon"; import { formatShellDurationMs, sanitizeExecuteModelIntent, @@ -188,22 +189,20 @@ const ShellCommandLine: React.FC<{ ? summarizeParsedCommands(parsedCommands) : ""; const commandDisplay = summary || command; + const commandLabel = intentLabel + ? `${intentLabel} using ${commandDisplay}` + : `Ran ${commandDisplay}`; + const durationSuffix = durationLabel ? ` for ${durationLabel}` : ""; + return ( <> - <span className="block min-w-0 truncate text-[13px] font-normal text-current"> - {intentLabel ? ( - <> - {intentLabel} using {commandDisplay} - </> - ) : ( - <>Ran {commandDisplay}</> + <ToolIcon name="execute" isError={false} /> + <span className="min-w-0 truncate text-[13px] font-normal leading-6 text-current"> + {commandLabel} + {durationSuffix && ( + <span className="text-content-secondary">{durationSuffix}</span> )} </span> - {durationLabel && ( - <span className="shrink-0 text-[13px] font-normal text-content-secondary"> - {intentLabel ? ` for ${durationLabel}` : durationLabel} - </span> - )} {expanded !== undefined && ( <ChevronDownIcon className={cn( diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ListTemplatesTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ListTemplatesTool.tsx index 5c40aadad5..b8ec34a15f 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ListTemplatesTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ListTemplatesTool.tsx @@ -7,6 +7,7 @@ import { TooltipTrigger, } from "#/components/Tooltip/Tooltip"; import { ToolCollapsible } from "./ToolCollapsible"; +import { ToolIcon } from "./ToolIcon"; import { asRecord, asString, type ToolStatus } from "./utils"; /** @@ -36,7 +37,16 @@ export const ListTemplatesTool: React.FC<{ hasContent={hasContent} header={ <> - <span className="text-[13px]">{label}</span> + <ToolIcon + name="list_templates" + isError={isError} + isRunning={isRunning} + /> + <span className="text-[13px] leading-6">{label}</span> + </> + } + headerStatus={ + <> {isError && ( <Tooltip> <TooltipTrigger asChild> diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ProcessOutputTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ProcessOutputTool.tsx index ecf7eb15cc..2ab190b020 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ProcessOutputTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ProcessOutputTool.tsx @@ -16,6 +16,7 @@ import { resolveAgentDisplayState, } from "./displayMode"; import { AgentDisplayModeToolCollapsible } from "./ToolCollapsible"; +import { ToolIcon } from "./ToolIcon"; import { COLLAPSED_OUTPUT_HEIGHT, signalTooltipLabel } from "./utils"; type ProcessOutputToolProps = { @@ -77,8 +78,7 @@ const ProcessOutputToolInner: React.FC<ProcessOutputToolInnerProps> = ({ const toggleOutputExpansion = () => { setOutputFullyExpanded((expanded) => !expanded); }; - const hasHeaderActions = - isRunning || Boolean(killedBySignal) || showExitCode || hasOutput; + const hasHeaderActions = Boolean(killedBySignal) || showExitCode || hasOutput; return ( <AgentDisplayModeToolCollapsible @@ -89,13 +89,24 @@ const ProcessOutputToolInner: React.FC<ProcessOutputToolInnerProps> = ({ ariaLabel={(expanded) => expanded ? "Collapse process output" : "Expand process output" } - header={<span className="text-[13px]">Process output</span>} + header={ + <> + <ToolIcon + name="process_output" + isError={isError} + isRunning={isRunning} + /> + <span className="text-[13px] leading-6">Process output</span> + </> + } + headerStatus={ + isRunning ? ( + <LoaderIcon className="size-3.5 shrink-0 animate-spin motion-reduce:animate-none text-content-secondary" /> + ) : undefined + } headerActions={ hasHeaderActions ? ( <> - {isRunning && ( - <LoaderIcon className="size-3.5 shrink-0 animate-spin motion-reduce:animate-none text-content-secondary" /> - )} {killedBySignal && !isRunning && ( <Tooltip> <TooltipTrigger asChild> diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ProposePlanTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ProposePlanTool.tsx index d4a92d4f46..a19a0426b9 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ProposePlanTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ProposePlanTool.tsx @@ -11,6 +11,7 @@ import { } from "#/components/Tooltip/Tooltip"; import { Response } from "../Response"; import { TranscriptRow } from "../TranscriptRow"; +import { ToolIcon } from "./ToolIcon"; import type { ToolStatus } from "./utils"; export const ProposePlanTool: React.FC<{ @@ -73,8 +74,13 @@ export const ProposePlanTool: React.FC<{ return ( <div className="w-full"> - <TranscriptRow className="gap-1.5 text-content-secondary"> - <span className="text-[13px]"> + <TranscriptRow className="gap-2 text-content-secondary"> + <ToolIcon + name="propose_plan" + isError={effectiveError} + isRunning={isRunning} + /> + <span className="text-[13px] leading-6"> {isRunning ? `Proposing ${filename}…` : `Proposed ${filename}`} </span> {effectiveError && ( @@ -138,7 +144,7 @@ export const ProposePlanTool: React.FC<{ ) )} {fetchLoading && ( - <TranscriptRow className="gap-1.5 text-[13px] text-content-secondary"> + <TranscriptRow className="gap-2 text-[13px] text-content-secondary"> <LoaderIcon className="size-3.5 animate-spin motion-reduce:animate-none" /> Loading plan… </TranscriptRow> diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ReadFileTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ReadFileTool.tsx index 2459a846b3..7e4391fba2 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ReadFileTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ReadFileTool.tsx @@ -10,6 +10,7 @@ import { } from "#/components/Tooltip/Tooltip"; import { asRecord, asString } from "../runtimeTypeUtils"; import { ToolCollapsible } from "./ToolCollapsible"; +import { ToolIcon } from "./ToolIcon"; import { DIFFS_FONT_STYLE, getFileViewerOptionsMinimal, @@ -96,7 +97,12 @@ export const ReadFileTool: React.FC<{ onExpandedChange={onExpandedChange} header={ <> - <span className="text-[13px]">{label}</span> + <ToolIcon name="read_file" isError={isError} isRunning={isRunning} /> + <span className="text-[13px] leading-6">{label}</span> + </> + } + headerStatus={ + <> {isError && ( <Tooltip> <TooltipTrigger asChild> diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ReadFilesTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ReadFilesTool.tsx index 93473bdb73..d157e9592d 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ReadFilesTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ReadFilesTool.tsx @@ -8,6 +8,7 @@ import { import type { MergedTool } from "../../ChatConversation/types"; import { getReadFileToolData, ReadFileTool } from "./ReadFileTool"; import { ToolCollapsible } from "./ToolCollapsible"; +import { ToolIcon } from "./ToolIcon"; type ReadFileItem = { id: string; @@ -50,7 +51,12 @@ export const ReadFilesTool: FC<{ onExpandedChange={onExpandedChange} header={ <> - <span className="text-[13px]">{label}</span> + <ToolIcon + name="read_file" + isError={isError} + isRunning={isRunning} + /> + <span className="text-[13px] leading-6">{label}</span> {isError && ( <Tooltip> <TooltipTrigger asChild> diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ReadSkillTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ReadSkillTool.tsx index 59df11ab67..33f210d7e2 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ReadSkillTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ReadSkillTool.tsx @@ -1,4 +1,4 @@ -import { BookOpenIcon, LoaderIcon, TriangleAlertIcon } from "lucide-react"; +import { LoaderIcon, TriangleAlertIcon } from "lucide-react"; import type React from "react"; import { ScrollArea } from "#/components/ScrollArea/ScrollArea"; import { @@ -8,6 +8,7 @@ import { } from "#/components/Tooltip/Tooltip"; import { Response } from "../Response"; import { ToolCollapsible } from "./ToolCollapsible"; +import { ToolIcon } from "./ToolIcon"; import type { ToolStatus } from "./utils"; export const ReadSkillTool: React.FC<{ @@ -26,10 +27,14 @@ export const ReadSkillTool: React.FC<{ hasContent={hasContent} header={ <> - <BookOpenIcon className="size-4 shrink-0 text-current" /> - <span className="text-[13px]"> + <ToolIcon name="read_skill" isError={isError} isRunning={isRunning} /> + <span className="text-[13px] leading-6"> {isRunning ? `Reading ${label}…` : `Read ${label}`} </span> + </> + } + headerStatus={ + <> {isError && ( <Tooltip> <TooltipTrigger asChild> diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ReadTemplateTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ReadTemplateTool.tsx index f8c24bd996..70b428cc41 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ReadTemplateTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ReadTemplateTool.tsx @@ -6,6 +6,7 @@ import { TooltipTrigger, } from "#/components/Tooltip/Tooltip"; import { TranscriptRow } from "../TranscriptRow"; +import { ToolIcon } from "./ToolIcon"; import type { ToolStatus } from "./utils"; /** @@ -27,8 +28,9 @@ export const ReadTemplateTool: React.FC<{ : "Read template"; return ( - <TranscriptRow className="gap-1.5 text-content-secondary"> - <span className="text-[13px]">{label}</span> + <TranscriptRow className="gap-2 text-content-secondary"> + <ToolIcon name="read_template" isError={isError} isRunning={isRunning} /> + <span className="text-[13px] leading-6">{label}</span> {isError && ( <Tooltip> <TooltipTrigger asChild> diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/StartWorkspaceTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/StartWorkspaceTool.tsx index 7cb14e41e1..98b1262efc 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/StartWorkspaceTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/StartWorkspaceTool.tsx @@ -1,4 +1,4 @@ -import { LoaderIcon, MonitorPlayIcon, TriangleAlertIcon } from "lucide-react"; +import { LoaderIcon, TriangleAlertIcon } from "lucide-react"; import type { FC } from "react"; import { Tooltip, @@ -6,6 +6,7 @@ import { TooltipTrigger, } from "#/components/Tooltip/Tooltip"; import { ToolCollapsible } from "./ToolCollapsible"; +import { ToolIcon } from "./ToolIcon"; import type { ToolStatus } from "./utils"; import { WorkspaceBuildLogSection } from "./WorkspaceBuildLogSection"; @@ -42,8 +43,16 @@ export const StartWorkspaceTool: FC<StartWorkspaceToolProps> = ({ const header = ( <> - <MonitorPlayIcon className="size-4 shrink-0 text-current" /> - <span className="text-[13px]">{label}</span> + <ToolIcon + name="start_workspace" + isError={isError} + isRunning={isRunning} + /> + <span className="text-[13px] leading-6">{label}</span> + </> + ); + const headerStatus = ( + <> {isError && ( <Tooltip> <TooltipTrigger asChild> @@ -67,6 +76,7 @@ export const StartWorkspaceTool: FC<StartWorkspaceToolProps> = ({ <div className="w-full"> <ToolCollapsible header={header} + headerStatus={headerStatus} hasContent={hasBuildLogs} defaultExpanded={isRunning} > diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/SubagentTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/SubagentTool.tsx index a6a558a60f..b3f34bbd8e 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/SubagentTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/SubagentTool.tsx @@ -69,7 +69,7 @@ function getSubagentLabel( ): React.ReactNode { if (showDesktopPreview && toolStatus === "running") { return ( - <Shimmer as="span" className="text-[13px]"> + <Shimmer as="span" className="text-[13px] leading-6"> Using the computer... </Shimmer> ); @@ -125,20 +125,22 @@ const SubagentStatusIcon: React.FC<{ const subagentCompleted = isSubagentSuccessStatus(subagentStatus); const DefaultIcon = iconKind === "monitor" ? MonitorIcon : BotIcon; if (isTimeout && !subagentCompleted) { - return <ClockIcon className="size-4 shrink-0 text-current" />; + return <ClockIcon className="size-4 shrink-0 stroke-[1.5] text-current" />; } if ((isError && !subagentCompleted) || toolStatus === "error") { return <CircleXIcon className="size-4 shrink-0 text-current" />; } if (toolStatus === "running") { if (showDesktopPreview) { - return <MonitorIcon className="size-4 shrink-0 text-current" />; + return ( + <MonitorIcon className="size-4 shrink-0 stroke-[1.5] text-current" /> + ); } return ( <LoaderIcon className="size-4 shrink-0 animate-spin motion-reduce:animate-none text-content-link" /> ); } - return <DefaultIcon className="size-4 shrink-0 text-current" />; + return <DefaultIcon className="size-4 shrink-0 stroke-[1.5] text-current" />; }; /** diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.stories.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.stories.tsx index 218be52f2a..dbd49bc456 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.stories.tsx @@ -10,6 +10,7 @@ import { } from "storybook/test"; import { reactRouterParameters } from "storybook-addon-remix-react-router"; import { ChatWorkspaceContext } from "../../../context/ChatWorkspaceContext"; +import { BlockList } from "../../ChatConversation/ConversationTimeline"; import { DesktopPanelContext } from "./DesktopPanelContext"; import { Tool } from "./Tool"; @@ -18,6 +19,10 @@ const executeIntentCommand = "npm test"; const longExecuteCommand = "docker build --no-cache --build-arg NODE_ENV=production --build-arg API_URL=https://coder.example.com/api --build-arg SENTRY_DSN=https://example.com/sentry --build-arg FEATURE_FLAGS=agents,shell-tools --tag coder-agent:latest ."; +// 1x1 solid coral (#FF6B6B) PNG encoded as base64. +const TEST_PNG_B64 = + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4n539HwAHFwLVF8kc1wAAAABJRU5ErkJggg=="; + const getDiffsText = (element: HTMLElement) => Array.from(element.querySelectorAll("diffs-container")) .map((container) => container.shadowRoot?.textContent ?? "") @@ -48,6 +53,212 @@ const meta: Meta<typeof Tool> = { export default meta; type Story = StoryObj<typeof Tool>; +type ToolShowcaseItem = { + name: string; + status?: React.ComponentProps<typeof Tool>["status"]; + args?: unknown; + result?: unknown; + isError?: boolean; + killedBySignal?: "kill" | "terminate"; + modelIntent?: string; + parsedCommands?: readonly string[][]; + subagentVariants?: Map<string, "general" | "explore" | "computer_use">; +}; + +const allToolShowcaseItems: ToolShowcaseItem[] = [ + { + name: "execute", + args: { command: "pnpm check", model_intent: "Checking frontend" }, + modelIntent: "Checking frontend", + parsedCommands: [["pnpm", "check"]], + result: { + output: "Checked 1799 files.", + wall_duration_ms: 2400, + exit_code: 0, + }, + }, + { + name: "process_output", + args: { process_id: "storybook-process" }, + result: { output: "dev server ready on :6006" }, + }, + { + name: "process_list", + args: {}, + result: { + processes: [ + { + id: "storybook-process", + command: "pnpm storybook", + status: "running", + }, + ], + }, + }, + { + name: "process_signal", + args: { process_id: "storybook-process", signal: "terminate" }, + result: { success: true }, + }, + { + name: "wait_for_external_auth", + args: { provider: "github" }, + result: { provider_display_name: "GitHub", authenticated: true }, + }, + { + name: "read_file", + args: { path: "site/src/pages/AgentsPage/AgentChatPage.tsx" }, + result: { content: "export const AgentChatPage = () => null;" }, + }, + { + name: "write_file", + args: { path: "docs/example.md", content: "# Example\n" }, + result: { path: "docs/example.md" }, + }, + { + name: "edit_files", + args: { + files: [ + { + path: "site/src/example.ts", + edits: [{ old_text: "foo", new_text: "bar" }], + }, + ], + }, + result: { files: [{ path: "site/src/example.ts", status: "edited" }] }, + }, + { + name: "list_templates", + result: { + templates: [ + { + id: "template-1", + name: "go-template", + display_name: "Go Development", + }, + ], + count: 1, + }, + }, + { + name: "read_template", + args: { template_id: "template-1" }, + result: { + template: { name: "go-template", display_name: "Go Development" }, + }, + }, + { + name: "create_workspace", + result: { + created: true, + workspace_name: "agent-icons", + build_id: "a1b2c3d4-e5f6-7890-abcd-ef1234567890", + }, + }, + { + name: "start_workspace", + result: { + started: true, + workspace_name: "agent-icons", + agent_status: "ready", + build_id: "a1b2c3d4-e5f6-7890-abcd-ef1234567890", + }, + }, + { + name: "chat_summarized", + result: { summary: "Earlier transcript content was compacted." }, + }, + { + name: "propose_plan", + args: { path: "/home/coder/.coder/plans/PLAN-example.md" }, + result: { path: "/home/coder/.coder/plans/PLAN-example.md" }, + }, + { + name: "ask_user_question", + args: { questions: [] }, + status: "running", + }, + { + name: "advisor", + args: { question: "Which icon family should represent transcript tools?" }, + result: { answer: "Use category-level icons for better scanning." }, + }, + { + name: "computer", + args: { action: "screenshot" }, + result: { output: { type: "image", data: TEST_PNG_B64 } }, + }, + { + name: "read_skill", + args: { name: "deep-review" }, + result: { + name: "deep-review", + content: "# Deep Review\nReview code carefully.", + }, + }, + { + name: "read_skill_file", + args: { name: "deep-review", path: "roles/security-reviewer.md" }, + result: { content: "# Security Reviewer Role\nCheck auth boundaries." }, + }, + { + name: "spawn_agent", + args: { title: "Repository review", prompt: "Review the code." }, + result: { + chat_id: "bot-child", + title: "Repository review", + status: "completed", + }, + }, + { + name: "wait_agent", + args: { chat_id: "bot-child" }, + result: { + chat_id: "bot-child", + title: "Repository review", + status: "completed", + report: "No issues found.", + }, + }, + { + name: "message_agent", + args: { chat_id: "bot-child", message: "Check icon consistency." }, + result: { chat_id: "bot-child", status: "completed" }, + }, + { + name: "close_agent", + args: { chat_id: "bot-child" }, + result: { chat_id: "bot-child", status: "completed" }, + }, + { + name: "spawn_computer_use_agent", + args: { prompt: "Inspect the UI." }, + result: { chat_id: "desktop-child", status: "completed" }, + subagentVariants: new Map([["desktop-child", "computer_use"]]), + }, + { + name: "read_file", + args: { path: "site/src/pages/AgentsPage/Missing.tsx" }, + status: "error", + isError: true, + result: { error: "File not found" }, + }, + { + name: "create_workspace", + status: "running", + args: { workspace_name: "agent-icons" }, + result: { + workspace_name: "agent-icons", + build_id: "a1b2c3d4-e5f6-7890-abcd-ef1234567890", + }, + }, + { + name: "unknown_tool", + args: { example: true }, + result: { ok: true }, + }, +]; + // --------------------------------------------------------------------------- // Execute stories // --------------------------------------------------------------------------- @@ -88,6 +299,7 @@ export const ExecuteModelIntent: Story = { export const ExecuteModelIntentRunning: Story = { args: { + shellToolDisplayMode: "always_expanded", status: "running", args: { command: executeCommand, @@ -129,7 +341,7 @@ export const ExecuteModelIntentLeadingUsing: Story = { const commandButton = canvas.getByRole("button", { name: "Expand command", }); - expect(commandButton).toHaveTextContent(`Ran ${executeCommand}2.3s`); + expect(commandButton).toHaveTextContent(`Ran ${executeCommand} for 2.3s`); expect(commandButton).not.toHaveTextContent("using git fetch origin using"); }, }; @@ -152,7 +364,7 @@ export const ExecuteSuccess: Story = { expect( canvas.queryByRole("img", { name: "Running in background" }), ).not.toBeInTheDocument(); - expect(canvas.getByText("47.2s")).toBeVisible(); + expect(canvas.getByText(/for 47\.2s/)).toBeVisible(); expect(canvas.queryByText("2 lines")).not.toBeInTheDocument(); }, }; @@ -263,7 +475,7 @@ export const ExecuteLongCommandCollapsed: Story = { expect(commandButton).toHaveTextContent(`Ran ${longExecuteCommand}`); expect(commandButton).toHaveAttribute("aria-expanded", "false"); expect(canvas.queryByText("exit 0")).not.toBeInTheDocument(); - expect(canvas.getByText("47.2s")).toBeVisible(); + expect(canvas.getByText(/for 47\.2s/)).toBeVisible(); expect(canvas.queryByText("61 lines")).not.toBeInTheDocument(); }, }; @@ -2578,3 +2790,65 @@ export const CreateWorkspaceBuildFailed: Story = { expect(canvas.getByText("Failed to create workspace")).toBeInTheDocument(); }, }; + +export const AllToolIconsTranscript: Story = { + render: () => ( + <ChatWorkspaceContext value={{ workspaceId: "test-workspace-id" }}> + <DesktopPanelContext.Provider + value={{ desktopChatId: "desktop-child", onOpenDesktop: fn() }} + > + <div className="flex flex-col gap-2"> + <BlockList + blocks={[ + { + type: "thinking", + text: "Thinking\nReviewing the available tools and grouping them by category.", + }, + ]} + tools={[]} + keyPrefix="all-tool-icons-thinking" + /> + {allToolShowcaseItems.map((tool, index) => ( + <Tool + key={`${tool.name}-${index}`} + name={tool.name} + status={tool.status ?? "completed"} + args={tool.args} + result={tool.result} + isError={tool.isError} + killedBySignal={tool.killedBySignal} + modelIntent={tool.modelIntent} + parsedCommands={tool.parsedCommands} + subagentVariants={tool.subagentVariants} + shellToolDisplayMode="always_collapsed" + codeDiffDisplayMode="always_collapsed" + showDesktopPreviews={false} + /> + ))} + </div> + </DesktopPanelContext.Provider> + </ChatWorkspaceContext> + ), + parameters: { + queries: [ + { + key: ["workspace", "test-workspace-id"], + data: { + id: "test-workspace-id", + latest_build: { + id: "test-build-id", + status: "running", + }, + }, + }, + { + key: [ + "workspaceBuilds", + "a1b2c3d4-e5f6-7890-abcd-ef1234567890", + "logs", + ], + data: [], + }, + ], + }, +}; diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.tsx index 10ea071eda..82c18f6fe7 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.tsx @@ -938,6 +938,10 @@ const GenericToolRenderer: FC<ToolRendererProps> = ({ mcpSlug={mcpServer?.slug} /> )} + </> + ); + const toolHeaderStatus = ( + <> {isError && ( <Tooltip> <TooltipTrigger asChild> @@ -963,7 +967,11 @@ const GenericToolRenderer: FC<ToolRendererProps> = ({ ); return ( - <ToolCollapsible hasContent={hasContent} header={toolHeader}> + <ToolCollapsible + hasContent={hasContent} + header={toolHeader} + headerStatus={toolHeaderStatus} + > {toolContent} </ToolCollapsible> ); diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ToolCollapsible.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ToolCollapsible.tsx index 8b8f236341..d894f33dec 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ToolCollapsible.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ToolCollapsible.tsx @@ -17,6 +17,7 @@ interface ToolCollapsibleProps { children: ReactNode; header: ToolCollapsibleHeader; headerActions?: ReactNode; + headerStatus?: ReactNode; hasContent?: boolean; defaultExpanded?: boolean; expanded?: boolean; @@ -50,6 +51,7 @@ export const ToolCollapsible: FC<ToolCollapsibleProps> = ({ children, header, headerActions, + headerStatus, hasContent = true, defaultExpanded = false, expanded: expandedProp, @@ -88,6 +90,7 @@ export const ToolCollapsible: FC<ToolCollapsibleProps> = ({ onClick={toggleExpanded} > {renderedHeader} + {headerStatus} <ChevronDownIcon className={cn( "size-3 shrink-0 text-current transition-transform", @@ -105,6 +108,7 @@ export const ToolCollapsible: FC<ToolCollapsibleProps> = ({ )} > {renderedHeader} + {headerStatus} </TranscriptRow> ); diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ToolIcon.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ToolIcon.tsx index ea3ce5e009..c2ff74debb 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ToolIcon.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ToolIcon.tsx @@ -1,13 +1,15 @@ import { - BookOpenIcon, + BadgeQuestionMarkIcon, BotIcon, - ClipboardListIcon, - FileIcon, - FilePenIcon, + CompassIcon, + FilePenLineIcon, + FileTextIcon, LightbulbIcon, + LogInIcon, MonitorIcon, - PlayIcon, - PlusCircleIcon, + PowerIcon, + RouteIcon, + ServerIcon, TerminalIcon, WrenchIcon, } from "lucide-react"; @@ -35,7 +37,12 @@ export const ToolIcon: React.FC<{ }> = ({ name, iconUrl, isRunning, serverName, subagentIconKind }) => { const [imgError, setImgError] = useState(false); const color = "text-current"; - const base = cn("size-4 shrink-0", color, isRunning && "grayscale"); + const base = cn( + "size-4 shrink-0", + color, + "stroke-[1.5]", + isRunning && "grayscale", + ); // If an MCP icon URL is provided and hasn't failed, render it. // Strip colour so external icons match the monochrome lucide @@ -93,28 +100,33 @@ export const ToolIcon: React.FC<{ case "process_list": case "process_signal": return <TerminalIcon className={base} />; + case "wait_for_external_auth": + return <LogInIcon className={base} />; case "read_file": - case "list_templates": - case "read_template": - return <FileIcon className={base} />; - case "write_file": - case "edit_files": - return <FilePenIcon className={base} />; - case "create_workspace": - return <PlusCircleIcon className={base} />; - case "chat_summarized": - return <BotIcon className={base} />; - case "propose_plan": - return <ClipboardListIcon className={base} />; - case "advisor": - return <LightbulbIcon className={base} />; - case "computer": - return <MonitorIcon className={base} />; case "read_skill": case "read_skill_file": - return <BookOpenIcon className={base} />; + return <FileTextIcon className={base} />; + case "write_file": + case "edit_files": + return <FilePenLineIcon className={base} />; + case "list_templates": + case "read_template": + case "create_workspace": + return <ServerIcon className={base} />; case "start_workspace": - return <PlayIcon className={base} />; + return <PowerIcon className={base} />; + case "chat_summarized": + return <BotIcon className={base} />; + case "thinking": + return <LightbulbIcon className={base} />; + case "propose_plan": + return <RouteIcon className={base} />; + case "ask_user_question": + return <BadgeQuestionMarkIcon className={base} />; + case "advisor": + return <CompassIcon className={base} />; + case "computer": + return <MonitorIcon className={base} />; default: return <WrenchIcon className={base} />; diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/WebSearchSources.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/WebSearchSources.tsx index b6822b2ef0..0ede427fd2 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/WebSearchSources.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/WebSearchSources.tsx @@ -36,8 +36,8 @@ const WebSearchSources: FC<WebSearchSourcesProps> = ({ sources }) => { hasContent={unique.length > 0} header={ <> - <GlobeIcon className="size-4 shrink-0 text-current" /> - <span className="text-[13px]"> + <GlobeIcon className="size-4 shrink-0 stroke-[1.5] text-current" /> + <span className="text-[13px] leading-6"> Searched <span className="text-content-secondary/60">{detail}</span> </span> </> diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/WriteFileTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/WriteFileTool.tsx index 11751bf4d7..e70cf930bf 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/WriteFileTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/WriteFileTool.tsx @@ -16,6 +16,7 @@ import { resolveAgentDisplayState, } from "./displayMode"; import { AgentDisplayModeToolCollapsible } from "./ToolCollapsible"; +import { ToolIcon } from "./ToolIcon"; import { DIFFS_FONT_STYLE, getDiffViewerOptions, @@ -53,7 +54,12 @@ export const WriteFileTool: React.FC<{ autoDisplayState={WRITE_FILE_AUTO_DISPLAY_STATE} header={ <> - <span className="text-[13px]">{label}</span> + <ToolIcon name="write_file" isError={isError} isRunning={isRunning} /> + <span className="text-[13px] leading-6">{label}</span> + </> + } + headerStatus={ + <> {isError && ( <Tooltip> <TooltipTrigger asChild> diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx index ec422bca5e..09a5386369 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx @@ -1374,9 +1374,9 @@ export const AnthropicKnownModelHappyPath: Story = { await openKnownModelPopover(body); const options = await body.findAllByRole("option"); - await userEvent.click(findOptionByText(options, "claude-opus-4-7")); + await userEvent.click(findOptionByText(options, "claude-opus-4-8")); - await expectModelIdentifierValue(body, "claude-opus-4-7"); + await expectModelIdentifierValue(body, "claude-opus-4-8"); await expect(body.getByLabelText(/Context limit/i)).toHaveValue("1000000"); await expandSection(body, "Advanced"); @@ -1858,7 +1858,7 @@ export const KnownModelAutoHidePopoverWhenNoMatches: Story = { name: /Model Identifier/i, }); await userEvent.click(input); - await expect(await body.findByText("Claude Opus 4.7")).toBeInTheDocument(); + await expect(await body.findByText("Claude Opus 4.8")).toBeInTheDocument(); await userEvent.clear(input); await userEvent.type(input, "claude-opus-4-5"); diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/knownModels/anthropic.test.ts b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/knownModels/anthropic.test.ts index 0ad417439c..2e1762d758 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/knownModels/anthropic.test.ts +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/knownModels/anthropic.test.ts @@ -22,6 +22,7 @@ describe("anthropicKnownModels", () => { (knownModel) => knownModel.modelIdentifier, ), ).toEqual([ + "claude-opus-4-8", "claude-opus-4-7", "claude-opus-4-6", "claude-sonnet-4-6", @@ -31,7 +32,11 @@ describe("anthropicKnownModels", () => { }); it("declares Anthropic reasoning defaults by API support", () => { - for (const modelIdentifier of ["claude-opus-4-7", "claude-opus-4-6"]) { + for (const modelIdentifier of [ + "claude-opus-4-8", + "claude-opus-4-7", + "claude-opus-4-6", + ]) { const knownModel = requireAnthropicKnownModel(modelIdentifier); expect(knownModel.reasoningEffort).toBe("high"); @@ -54,6 +59,7 @@ describe("anthropicKnownModels", () => { expect( anthropicKnownModels.map((knownModel) => knownModel.modelIdentifier), ).toEqual([ + "claude-opus-4-8", "claude-opus-4-7", "claude-opus-4-6", "claude-sonnet-4-6", @@ -64,8 +70,16 @@ describe("anthropicKnownModels", () => { for (const knownModel of anthropicKnownModels) { expect(knownModel.provider).toBe("anthropic"); expect(knownModel.sourceMetadata.sourceName).toBe("models.dev"); - expect(knownModel.sourceMetadata.sourceRetrievedAt).toBe("2026-04-30"); + expect(knownModel.sourceMetadata.sourceRetrievedAt).not.toBe(""); expect(knownModel.sourceMetadata.lastUpdated).not.toBe(""); } + + expect( + requireAnthropicKnownModel("claude-opus-4-8").sourceMetadata, + ).toEqual({ + sourceName: "models.dev", + sourceRetrievedAt: "2026-05-29", + lastUpdated: "2026-05-28", + }); }); }); diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/knownModels/anthropic.ts b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/knownModels/anthropic.ts index 297ddabe2c..985fb49f7e 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/knownModels/anthropic.ts +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/knownModels/anthropic.ts @@ -10,14 +10,33 @@ import type { KnownModel } from "./types"; // catalog and should be reviewed when the catalog is refreshed. // // Reasoning configuration is split per model based on Anthropic API support: -// models that support adaptive thinking (Opus 4.7, Opus 4.6, Sonnet 4.6) -// carry `reasoningEffort`, which Coder maps to `thinking.type: "adaptive"` -// with the `effort` parameter. Models that do not (Haiku 4.5, Sonnet 4.5) +// models that support adaptive thinking (Opus 4.8, Opus 4.7, Opus 4.6, +// Sonnet 4.6) carry `reasoningEffort`, which Coder maps to +// `thinking.type: "adaptive"` with the `effort` parameter. Models that do not +// (Haiku 4.5, Sonnet 4.5) // carry `thinkingBudgetTokens` instead, which Coder maps to the legacy // `thinking.type: "enabled"` path with `budget_tokens`. Setting `effort` on // the legacy path produces an "adaptive thinking is not supported on this // model" HTTP 400 from Anthropic. export const anthropicKnownModels = [ + { + provider: "anthropic", + modelIdentifier: "claude-opus-4-8", + displayName: "Claude Opus 4.8", + aliases: [], + contextLimit: 1_000_000, + maxOutputTokens: 128_000, + reasoningEffort: "high", + inputCost: 5, + outputCost: 25, + cacheReadCost: 0.5, + cacheWriteCost: 6.25, + sourceMetadata: { + sourceName: "models.dev", + sourceRetrievedAt: "2026-05-29", + lastUpdated: "2026-05-28", + }, + }, { provider: "anthropic", modelIdentifier: "claude-opus-4-7", diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/knownModels/applyKnownModelDefaults.test.ts b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/knownModels/applyKnownModelDefaults.test.ts index d325c53385..d537aa726c 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/knownModels/applyKnownModelDefaults.test.ts +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/knownModels/applyKnownModelDefaults.test.ts @@ -294,7 +294,7 @@ describe("applyKnownModelDefaults", () => { values: buildInitialModelFormValues(), initialValues: buildInitialModelFormValues(), provider: "anthropic", - knownModel: requireKnownModel("anthropic", "claude-opus-4-7"), + knownModel: requireKnownModel("anthropic", "claude-opus-4-8"), }); expect(getPath(result.values, "config.anthropic.effort")).toBe("high"); @@ -327,7 +327,7 @@ describe("applyKnownModelDefaults", () => { values: buildInitialModelFormValues(), initialValues: buildInitialModelFormValues(), provider: "anthropic", - knownModel: requireKnownModel("anthropic", "claude-opus-4-7"), + knownModel: requireKnownModel("anthropic", "claude-opus-4-8"), }); expect(getPath(result.values, "config.anthropic.sendReasoning")).toBe(""); diff --git a/site/src/pages/AgentsPage/components/ChatPageContent.tsx b/site/src/pages/AgentsPage/components/ChatPageContent.tsx index 098d767070..68d2644603 100644 --- a/site/src/pages/AgentsPage/components/ChatPageContent.tsx +++ b/site/src/pages/AgentsPage/components/ChatPageContent.tsx @@ -106,7 +106,7 @@ export const ChatPageTimeline: FC<ChatPageTimelineProps> = ({ <div data-testid="chat-timeline-wrapper" className={cn( - "mx-auto flex w-full flex-col gap-2 py-6", + "mx-auto flex w-full flex-col py-6", chatWidthClass(chatFullWidth), )} > diff --git a/site/src/pages/AgentsPage/components/ChatSharingPopover.stories.tsx b/site/src/pages/AgentsPage/components/ChatSharingPopover.stories.tsx index df4b142b1d..b35f0b3347 100644 --- a/site/src/pages/AgentsPage/components/ChatSharingPopover.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatSharingPopover.stories.tsx @@ -119,7 +119,7 @@ const openChatSharing = async (canvasElement: HTMLElement) => { const canvas = within(canvasElement); await userEvent.click(canvas.getByRole("button", { name: "Share" })); const body = within(canvasElement.ownerDocument.body); - await body.findByText("Chat Sharing"); + await body.findByText("Chat sharing"); return body; }; @@ -128,7 +128,7 @@ const closeChatSharing = async (canvasElement: HTMLElement) => { const body = within(canvasElement.ownerDocument.body); await userEvent.click(canvas.getByRole("button", { name: "Share" })); await waitFor(() => { - expect(body.queryByText("Chat Sharing")).not.toBeInTheDocument(); + expect(body.queryByText("Chat sharing")).not.toBeInTheDocument(); }); }; @@ -153,6 +153,12 @@ const addAutocompleteOption = async ( await userEvent.click(body.getByRole("button", { name: "Add member" })); }; +const MobileFrame = (Story: React.FC) => ( + <div className="w-[390px] max-w-full"> + <Story /> + </div> +); + const meta: Meta<typeof ChatShareButton> = { title: "pages/AgentsPage/ChatSharingPopover", component: ChatShareButton, @@ -174,10 +180,13 @@ export const EmptyACL: Story = { play: async ({ canvasElement }) => { const body = await openChatSharing(canvasElement); await waitFor(() => { - expect(body.getByText("No shared members or groups yet")).toBeVisible(); expect( - body.getByText("Add a member or group using the controls above."), - ).toBeVisible(); + body.getAllByText("No shared members or groups yet").length, + ).toBeGreaterThan(0); + expect( + body.getAllByText("Add a member or group using the controls above.") + .length, + ).toBeGreaterThan(0); }); }, }; @@ -187,11 +196,32 @@ export const PopulatedACL: Story = { play: async ({ canvasElement }) => { const body = await openChatSharing(canvasElement); await waitFor(() => { - expect(body.getByText(chatUser.username)).toBeInTheDocument(); - expect(body.getByText(chatGroup.name)).toBeInTheDocument(); + expect(body.getAllByText(chatUser.username).length).toBeGreaterThan(0); + expect(body.getAllByText(chatGroup.name).length).toBeGreaterThan(0); expect(body.getAllByText("Read").length).toBeGreaterThan(0); }); - expect(body.getAllByRole("button", { name: "Open menu" })).toHaveLength(2); + expect( + body.getAllByRole("button", { name: "Open menu" }).length, + ).toBeGreaterThan(0); + }, +}; + +export const MobilePopulatedACL: Story = { + decorators: [MobileFrame], + parameters: { + chromatic: { viewports: [390] }, + }, + beforeEach: () => mockDialogRequests({ acl: populatedACL }), + play: async ({ canvasElement }) => { + const body = await openChatSharing(canvasElement); + await waitFor(() => { + expect(body.getAllByText(chatUser.username).length).toBeGreaterThan(0); + expect(body.getAllByText(chatGroup.name).length).toBeGreaterThan(0); + expect(body.getAllByText("Read").length).toBeGreaterThan(0); + }); + expect( + body.getAllByRole("button", { name: "Open menu" }).length, + ).toBeGreaterThan(0); }, }; @@ -209,7 +239,7 @@ export const CurrentUserHidden: Story = { expect( body.queryByText(currentChatUser.username), ).not.toBeInTheDocument(); - expect(body.getByText(chatUser.username)).toBeVisible(); + expect(body.getAllByText(chatUser.username).length).toBeGreaterThan(0); }); }, }; @@ -307,7 +337,7 @@ export const RemoveUser: Story = { play: async ({ canvasElement }) => { const body = await openChatSharing(canvasElement); await waitFor(() => { - expect(body.getByText(chatUser.username)).toBeInTheDocument(); + expect(body.getAllByText(chatUser.username).length).toBeGreaterThan(0); }); // Groups render before users, so the user row menu is the second one. const menuButtons = await body.findAllByRole("button", { @@ -334,7 +364,7 @@ export const RemoveGroup: Story = { play: async ({ canvasElement }) => { const body = await openChatSharing(canvasElement); await waitFor(() => { - expect(body.getByText(chatGroup.name)).toBeInTheDocument(); + expect(body.getAllByText(chatGroup.name).length).toBeGreaterThan(0); }); // Groups render before users, so the group row menu is the first one. const menuButtons = await body.findAllByRole("button", { diff --git a/site/src/pages/AgentsPage/components/ChatSharingPopover.tsx b/site/src/pages/AgentsPage/components/ChatSharingPopover.tsx index 0a7720feb5..7fca07da7f 100644 --- a/site/src/pages/AgentsPage/components/ChatSharingPopover.tsx +++ b/site/src/pages/AgentsPage/components/ChatSharingPopover.tsx @@ -1,5 +1,5 @@ -import { EllipsisVerticalIcon, Share2Icon } from "lucide-react"; -import { type FC, useState } from "react"; +import { EllipsisVerticalIcon, Share2Icon, UserPlusIcon } from "lucide-react"; +import { type FC, type ReactNode, useState } from "react"; import { useMutation, useQuery, useQueryClient } from "react-query"; import { toast } from "sonner"; import { @@ -18,7 +18,6 @@ import { DropdownMenuItem, DropdownMenuTrigger, } from "#/components/DropdownMenu/DropdownMenu"; -import { EmptyState } from "#/components/EmptyState/EmptyState"; import { TopbarButton } from "#/components/FullPageLayout/Topbar"; import { Popover, @@ -40,7 +39,6 @@ import { UserOrGroupAutocomplete, type UserOrGroupAutocompleteValue, } from "#/modules/workspaces/WorkspaceSharingForm/UserOrGroupAutocomplete"; -import { AddWorkspaceMemberForm } from "#/modules/workspaces/WorkspaceSharingForm/WorkspaceSharingForm"; type ChatShareButtonProps = { chatId: string; @@ -57,7 +55,7 @@ type MemberRowMenuProps = { }; const ReadRoleBadge: FC = () => ( - <span className="bg-surface-secondary rounded-md px-3 py-0.5 inline-block"> + <span className="inline-block shrink-0 rounded-md bg-surface-secondary px-2 py-0.5 text-xs leading-5"> Read </span> ); @@ -85,6 +83,89 @@ const MemberRowMenu: FC<MemberRowMenuProps> = ({ disabled, onRemove }) => ( </DropdownMenu> ); +type AddChatMemberFormProps = { + isLoading: boolean; + onSubmit: () => void; + disabled: boolean; + children: ReactNode; +}; + +const AddChatMemberForm: FC<AddChatMemberFormProps> = ({ + isLoading, + onSubmit, + disabled, + children, +}) => ( + <form action={onSubmit}> + <div className="flex flex-col gap-2 sm:flex-row sm:items-center"> + <div className="min-w-0 flex-1">{children}</div> + <Button + disabled={disabled || isLoading} + type="submit" + className="w-full sm:w-auto" + > + <Spinner loading={isLoading}> + <UserPlusIcon className="size-icon-sm" /> + </Spinner> + Add member + </Button> + </div> + </form> +); + +type MemberIdentityProps = + | { kind: "group"; group: TypesGen.ChatGroup } + | { kind: "user"; user: TypesGen.ChatUser }; + +const MemberIdentity: FC<MemberIdentityProps> = (props) => { + if (props.kind === "group") { + const { group } = props; + return ( + <AvatarData + title={group.display_name || group.name} + subtitle={getGroupSubtitle(group)} + src={group.avatar_url} + avatar={ + <Avatar + src={group.avatar_url} + fallback={group.display_name || group.name} + variant="icon" + /> + } + /> + ); + } + + const { user } = props; + return ( + <AvatarData + title={user.username} + subtitle={user.name} + src={user.avatar_url} + /> + ); +}; + +type MobileMemberRowProps = { + children: ReactNode; + disabled: boolean; + onRemove: () => void; +}; + +const MobileMemberRow: FC<MobileMemberRowProps> = ({ + children, + disabled, + onRemove, +}) => ( + <div className="flex items-center justify-between gap-3 border-0 border-b border-solid border-border last:border-b-0 px-1 py-3"> + <div className="min-w-0 flex-1">{children}</div> + <div className="flex shrink-0 items-center gap-2"> + <ReadRoleBadge /> + <MemberRowMenu disabled={disabled} onRemove={onRemove} /> + </div> + </div> +); + export const ChatSharingPopoverContent: FC<ChatSharingPopoverContentProps> = ({ chatId, organizationId, @@ -196,9 +277,12 @@ export const ChatSharingPopoverContent: FC<ChatSharingPopoverContentProps> = ({ const isEmpty = groups.length === 0 && users.length === 0; return ( - <PopoverContent align="end" className="w-[580px] p-4"> + <PopoverContent + align="end" + className="w-[calc(100vw-2rem)] p-3 sm:w-[580px] sm:p-4" + > <div className="flex items-center gap-2 mb-4"> - <h3 className="text-lg font-semibold m-0">Chat Sharing</h3> + <h3 className="text-lg font-semibold m-0">Chat sharing</h3> </div> <div className="flex flex-col gap-4"> @@ -212,7 +296,7 @@ export const ChatSharingPopoverContent: FC<ChatSharingPopoverContentProps> = ({ </div> ) : acl ? ( <> - <AddWorkspaceMemberForm + <AddChatMemberForm isLoading={isMutating} disabled={!selectedOption} onSubmit={handleAddMember} @@ -222,52 +306,64 @@ export const ChatSharingPopoverContent: FC<ChatSharingPopoverContentProps> = ({ onChange={setSelectedOption} organizationId={organizationId} exclude={excludeFromAutocomplete} + className="w-full sm:w-80" /> - </AddWorkspaceMemberForm> + </AddChatMemberForm> - <Table - aria-label="Shared chat members and groups" - wrapperClassName="max-h-60 overflow-y-auto" - > - <TableHeader> - <TableRow> - <TableHead className="sticky top-0 z-10 w-[50%] bg-surface-primary py-2"> - Member - </TableHead> - <TableHead className="sticky top-0 z-10 w-[40%] bg-surface-primary py-2"> - Role - </TableHead> - <TableHead className="sticky top-0 z-10 w-[10%] bg-surface-primary py-2" /> - </TableRow> - </TableHeader> - <TableBody> - {isEmpty ? ( - <TableRow> - <TableCell colSpan={3}> - <EmptyState - message="No shared members or groups yet" - description="Add a member or group using the controls above." - isCompact - /> - </TableCell> - </TableRow> - ) : ( - <> + {isEmpty ? ( + <div className="flex min-h-44 flex-col items-center justify-center px-6 py-6 text-center"> + <h4 className="m-0 text-sm font-medium text-content-secondary"> + No shared members or groups yet + </h4> + <p className="m-0 mt-2 text-sm text-content-secondary"> + Add a member or group using the controls above. + </p> + </div> + ) : ( + <div className="max-h-[min(60vh,24rem)] overflow-y-auto rounded-md border border-solid border-border sm:hidden"> + {groups.map((group) => ( + <MobileMemberRow + key={group.id} + disabled={isMutating} + onRemove={() => handleRemoveGroup(group)} + > + <MemberIdentity kind="group" group={group} /> + </MobileMemberRow> + ))} + {users.map((user) => ( + <MobileMemberRow + key={user.id} + disabled={isMutating} + onRemove={() => handleRemoveUser(user)} + > + <MemberIdentity kind="user" user={user} /> + </MobileMemberRow> + ))} + </div> + )} + + {!isEmpty && ( + <div className="hidden sm:block"> + <Table + aria-label="Shared chat members and groups" + wrapperClassName="max-h-60 overflow-y-auto" + > + <TableHeader> + <TableRow> + <TableHead className="sticky top-0 z-10 w-[50%] bg-surface-primary py-2"> + Member + </TableHead> + <TableHead className="sticky top-0 z-10 w-[40%] bg-surface-primary py-2"> + Role + </TableHead> + <TableHead className="sticky top-0 z-10 w-[10%] bg-surface-primary py-2" /> + </TableRow> + </TableHeader> + <TableBody> {groups.map((group) => ( <TableRow key={group.id}> <TableCell className="py-2 w-[50%]"> - <AvatarData - title={group.display_name || group.name} - subtitle={getGroupSubtitle(group)} - src={group.avatar_url} - avatar={ - <Avatar - src={group.avatar_url} - fallback={group.display_name || group.name} - variant="icon" - /> - } - /> + <MemberIdentity kind="group" group={group} /> </TableCell> <TableCell className="py-2 w-[40%]"> <ReadRoleBadge /> @@ -283,11 +379,7 @@ export const ChatSharingPopoverContent: FC<ChatSharingPopoverContentProps> = ({ {users.map((user) => ( <TableRow key={user.id}> <TableCell className="py-2 w-[50%]"> - <AvatarData - title={user.username} - subtitle={user.name} - src={user.avatar_url} - /> + <MemberIdentity kind="user" user={user} /> </TableCell> <TableCell className="py-2 w-[40%]"> <ReadRoleBadge /> @@ -300,10 +392,10 @@ export const ChatSharingPopoverContent: FC<ChatSharingPopoverContentProps> = ({ </TableCell> </TableRow> ))} - </> - )} - </TableBody> - </Table> + </TableBody> + </Table> + </div> + )} </> ) : null} </div> diff --git a/site/src/pages/AgentsPage/components/ChatsSidebar/ChatsSidebar.stories.tsx b/site/src/pages/AgentsPage/components/ChatsSidebar/ChatsSidebar.stories.tsx index 5294ba61c0..a15c7cee8c 100644 --- a/site/src/pages/AgentsPage/components/ChatsSidebar/ChatsSidebar.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatsSidebar/ChatsSidebar.stories.tsx @@ -1,6 +1,7 @@ import type { Meta, StoryObj } from "@storybook/react-vite"; import type { ComponentProps } from "react"; import { useEffect, useState } from "react"; +import { useLocation } from "react-router"; import { expect, fn, userEvent, waitFor, within } from "storybook/test"; import { reactRouterParameters } from "storybook-addon-remix-react-router"; import { userChatProviderConfigsKey } from "#/api/queries/chats"; @@ -12,10 +13,25 @@ import { withDashboardProvider, } from "#/testHelpers/storybook"; import { useAgentsPageKeybindings } from "../../hooks/useAgentsPageKeybindings"; -import type { AgentSidebarFilters } from "../../utils/agentSidebarFilters"; +import { DEFAULT_AGENT_SIDEBAR_FILTERS as defaultSidebarFilters } from "../../utils/agentSidebarFilters"; import type { ModelSelectorOption } from "../ChatElements"; import { ChatsSidebar } from "./ChatsSidebar"; +// Probe element used by the archived-filter preservation story to surface the +// search string of whatever child route the sidebar's NavLink ends up at. +const ChildSearchProbe = () => { + const location = useLocation(); + return <div data-testid="child-search">{location.search}</div>; +}; + +// Probe element used by the settings-link preservation story to surface the +// state.from value passed when navigating to settings. +const SettingsStateProbe = () => { + const location = useLocation(); + const from = (location.state as { from?: string })?.from ?? ""; + return <div data-testid="settings-state-from">{from}</div>; +}; + const defaultModelOptions: ModelSelectorOption[] = [ { id: "openai:gpt-4o", @@ -25,13 +41,6 @@ const defaultModelOptions: ModelSelectorOption[] = [ }, ]; -const defaultSidebarFilters: AgentSidebarFilters = { - archiveStatus: "active", - groupBy: "date", - prStatuses: [], - chatStatuses: ["unread", "read"], -}; - const defaultModelConfigs: TypesGen.ChatModelConfig[] = [ { id: "config-openai-gpt-4o", @@ -106,8 +115,8 @@ const meta: Meta<typeof ChatsSidebar> = { isCreating: false, regeneratingTitleChatIds: [], sidebarFilters: defaultSidebarFilters, - onSidebarFiltersChange: fn(), isPersonalModelOverridesEnabled: true, + onSidebarFiltersChange: fn(), }, parameters: { layout: "fullscreen", @@ -724,6 +733,68 @@ export const SectionHeadersCollapse: Story = { }, }; +export const MobileHeaderActions: Story = { + render: ChatsSidebarWithKeybindings, + args: { + chats: sectionHeaderChats, + }, + parameters: { + viewport: { defaultViewport: "mobile1" }, + reactRouter: reactRouterParameters({ + location: { path: "/agents" }, + routing: agentsRouting, + }), + }, + decorators: [ + (Story) => ( + <div style={{ height: 500, width: 360 }}> + <Story /> + </div> + ), + ], + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + const searchButton = canvas.getByRole("button", { name: "Search chats" }); + const filterButton = canvas.getByRole("button", { name: "Filter agents" }); + const searchRect = searchButton.getBoundingClientRect(); + const filterRect = filterButton.getBoundingClientRect(); + + await expect(searchButton).not.toHaveTextContent("Search"); + expect(Math.round(searchRect.width)).toBeGreaterThanOrEqual(28); + expect(Math.round(filterRect.width)).toBeGreaterThanOrEqual(28); + expect(searchRect.right).toBeLessThan(filterRect.left); + }, +}; + +export const SidebarFilterMenu: Story = { + args: { + chats: sectionHeaderChats, + }, + parameters: { + reactRouter: reactRouterParameters({ + location: { path: "/agents" }, + routing: agentsRouting, + }), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + const body = within(document.body); + + await userEvent.click( + canvas.getByRole("button", { name: "Filter agents" }), + ); + await expect( + await body.findByRole("radio", { name: /Archived/i }), + ).toBeInTheDocument(); + await userEvent.keyboard("{Escape}"); + await waitFor(() => { + expect( + body.queryByRole("radio", { name: /Archived/i }), + ).not.toBeInTheDocument(); + }); + }, +}; + export const SearchDialogKeyboardShortcut: Story = { render: ChatsSidebarWithKeybindings, args: { @@ -740,11 +811,23 @@ export const SearchDialogKeyboardShortcut: Story = { const body = within(document.body); const searchButton = canvas.getByRole("button", { name: "Search chats" }); - await userEvent.hover(searchButton); - const tooltip = await body.findByRole("tooltip"); - await expect(tooltip).toHaveTextContent("Search chats"); - await expect(tooltip).toHaveTextContent("Ctrl"); - await expect(tooltip).toHaveTextContent("K"); + await expect(searchButton).toHaveTextContent("Search"); + await expect(searchButton).toHaveTextContent("Ctrl"); + await expect(searchButton).toHaveTextContent("K"); + + await userEvent.click(searchButton); + const clickedSearchInput = await body.findByRole("combobox", { + name: "Search chats", + }); + await waitFor(() => { + expect(clickedSearchInput).toHaveFocus(); + }); + await userEvent.keyboard("{Escape}"); + await waitFor(() => { + expect( + body.queryByRole("combobox", { name: "Search chats" }), + ).not.toBeInTheDocument(); + }); await userEvent.keyboard("{Control>}k{/Control}"); @@ -1346,10 +1429,7 @@ export const ArchivedFilterShowsArchivedAgents: Story = { updated_at: recentTimestamp, }), ], - sidebarFilters: { - ...defaultSidebarFilters, - archiveStatus: "archived", - }, + sidebarFilters: { ...defaultSidebarFilters, archiveStatus: "archived" }, }, parameters: { reactRouter: reactRouterParameters({ @@ -1367,6 +1447,44 @@ export const ArchivedFilterShowsArchivedAgents: Story = { }, }; +export const PreservesArchivedFilterOnChatNavigation: Story = { + args: { + chats: [ + buildChat({ + id: "archived-nav-1", + title: "Archived nav target", + archived: true, + updated_at: recentTimestamp, + }), + ], + sidebarFilters: { ...defaultSidebarFilters, archiveStatus: "archived" }, + }, + parameters: { + reactRouter: reactRouterParameters({ + location: { + path: "/agents", + searchParams: { archived: "archived" }, + }, + routing: [ + { path: "/agents", useStoryElement: true }, + { path: "/agents/:agentId", element: <ChildSearchProbe /> }, + ], + }), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + const link = await canvas.findByRole("link", { + name: /Archived nav target/, + }); + await userEvent.click(link); + await waitFor(() => { + expect(canvas.getByTestId("child-search")).toHaveTextContent( + "archived=archived", + ); + }); + }, +}; + export const NoArchivedSection: Story = { args: { chats: [ @@ -1757,10 +1875,7 @@ export const ArchivedAgentUnarchiveOption: Story = { updated_at: recentTimestamp, }), ], - sidebarFilters: { - ...defaultSidebarFilters, - archiveStatus: "archived", - }, + sidebarFilters: { ...defaultSidebarFilters, archiveStatus: "archived" }, }, parameters: { reactRouter: reactRouterParameters({ @@ -2100,3 +2215,43 @@ export const SettingsAdminAgentsEntryPreserved: Story = { expect(canvas.getByText("Manage Agents")).toBeInTheDocument(); }, }; + +export const PreservesArchivedFilterOnSettingsNavigation: Story = { + args: { + chats: [ + buildChat({ + id: "archived-settings-1", + title: "Archived settings target", + archived: true, + updated_at: recentTimestamp, + }), + ], + sidebarFilters: { ...defaultSidebarFilters, archiveStatus: "archived" }, + }, + parameters: { + reactRouter: reactRouterParameters({ + location: { + path: "/agents", + searchParams: { archived: "archived" }, + }, + routing: [ + { + path: "/agents/settings", + element: <SettingsStateProbe />, + }, + ...agentsRouting, + ], + }), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + const settingsLink = await canvas.findByRole("link", { name: "Settings" }); + await userEvent.click(settingsLink); + await waitFor(() => { + const fromValue = + canvas.getByTestId("settings-state-from").textContent ?? ""; + expect(fromValue).toContain("/agents"); + expect(fromValue).toContain("archived=archived"); + }); + }, +}; diff --git a/site/src/pages/AgentsPage/components/ChatsSidebar/ChatsSidebar.tsx b/site/src/pages/AgentsPage/components/ChatsSidebar/ChatsSidebar.tsx index 9bc3e69bdf..042c1417c0 100644 --- a/site/src/pages/AgentsPage/components/ChatsSidebar/ChatsSidebar.tsx +++ b/site/src/pages/AgentsPage/components/ChatsSidebar/ChatsSidebar.tsx @@ -151,6 +151,7 @@ export const ChatsSidebar: FC<ChatsSidebarProps> = (props) => { open={isSearchDialogOpen} onOpenChange={onSearchDialogOpenChange} location={location} + recentChats={chats} /> {onRenameTitle && ( <RenameChatDialog diff --git a/site/src/pages/AgentsPage/components/ChatsSidebar/chats/ChatsPanel.tsx b/site/src/pages/AgentsPage/components/ChatsSidebar/chats/ChatsPanel.tsx index d3190229ec..bd7482767d 100644 --- a/site/src/pages/AgentsPage/components/ChatsSidebar/chats/ChatsPanel.tsx +++ b/site/src/pages/AgentsPage/components/ChatsSidebar/chats/ChatsPanel.tsx @@ -30,11 +30,6 @@ import { ProductLogo } from "#/components/Icons/ProductLogo"; import { Kbd, KbdGroup } from "#/components/Kbd/Kbd"; import { ScrollArea } from "#/components/ScrollArea/ScrollArea"; import { Skeleton } from "#/components/Skeleton/Skeleton"; -import { - Tooltip, - TooltipContent, - TooltipTrigger, -} from "#/components/Tooltip/Tooltip"; import { cn } from "#/utils/cn"; import { getOSKey } from "#/utils/platform"; import { @@ -350,8 +345,11 @@ export const ChatsPanel: FC<ChatsPanelProps> = ({ aria-hidden={isSettingsPanel} inert={isSettingsPanel ? true : undefined} > - <div className="hidden border-b border-border-default px-2 py-1.5 sm:block"> - <div className="flex items-center justify-between mb-2.5"> + <nav + aria-label="Sidebar" + className="hidden border-b border-border-default px-2 py-1.5 sm:flex sm:flex-col sm:gap-0.5" + > + <div className="flex items-center justify-between mb-2.5 ml-2.5"> <div className="flex items-center gap-2"> <NavLink to="/workspaces" className="inline-flex"> <ProductLogo className="size-6" /> @@ -397,10 +395,50 @@ export const ChatsPanel: FC<ChatsPanelProps> = ({ onClick={onBeforeNewAgent} disabled={isCreating} /> - </div> - <div className="relative min-h-0 flex-1"> + {onOpenSearchDialog && ( + <SettingsNavItem + icon={SearchIcon} + label="Search" + active={false} + ariaLabel="Search chats" + onClick={onOpenSearchDialog} + className="group focus-visible:bg-surface-tertiary/50 focus-visible:text-content-primary" + trailing={ + <KbdGroup className="opacity-0 transition-opacity group-hover:opacity-100 group-focus-visible:opacity-100"> + <Kbd>{getOSKey()}</Kbd> + <Kbd>K</Kbd> + </KbdGroup> + } + /> + )} + </nav> + <div className="relative min-h-0 flex-1 flex flex-col"> + <div className="mx-2 pt-6 mb-1.5"> + <div className="ml-2.5 mr-2 flex h-7 items-center justify-between"> + <h2 className="m-0 text-sm font-normal leading-6 text-content-secondary"> + Chats + </h2> + <div className="flex items-center gap-1"> + {onOpenSearchDialog && ( + <Button + variant="subtle" + size="icon" + aria-label="Search chats" + onClick={onOpenSearchDialog} + className="h-7 w-7 sm:hidden" + > + <SearchIcon /> + </Button> + )} + <FilterPopover + filters={sidebarFilters} + onFiltersChange={onSidebarFiltersChange} + /> + </div> + </div> + </div> <ScrollArea - className="h-full [&_[data-radix-scroll-area-viewport]>div]:!block" + className="min-h-0 flex-1 [&_[data-radix-scroll-area-viewport]>div]:!block" scrollBarClassName="w-1.5" viewportClassName={cn( "[mask-image:linear-gradient(to_bottom,transparent_0,black_20px,black_calc(100%-20px),transparent_100%)]", @@ -408,40 +446,7 @@ export const ChatsPanel: FC<ChatsPanelProps> = ({ "sm:[mask-image:none] sm:[-webkit-mask-image:none]", )} > - <div className="flex flex-col gap-2 px-2 pb-3 pt-6"> - <div className="ml-2.5 mr-2 flex h-7 items-center justify-between"> - <h2 className="m-0 text-sm font-normal leading-6 text-content-primary"> - Chats - </h2> - <div className="flex flex-row -space-x-1"> - <Tooltip delayDuration={500}> - <TooltipTrigger asChild> - <Button - variant="subtle" - size="icon" - aria-label="Search chats" - onClick={onOpenSearchDialog} - className="size-7 justify-end px-0" - > - <SearchIcon /> - </Button> - </TooltipTrigger> - <TooltipContent side="bottom" align="end"> - <span className="flex items-center gap-1"> - <KbdGroup> - <Kbd>{getOSKey()}</Kbd> - <Kbd>K</Kbd> - </KbdGroup> - <span>Search chats</span> - </span> - </TooltipContent> - </Tooltip> - <FilterPopover - filters={sidebarFilters} - onFiltersChange={onSidebarFiltersChange} - /> - </div> - </div> + <div className="flex flex-col gap-2 px-2 pb-3"> {loadError ? ( <div className="space-y-3 px-1"> <ErrorAlert error={loadError} /> diff --git a/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchDialog.stories.tsx b/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchDialog.stories.tsx index 523c0f6266..24ffe62b86 100644 --- a/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchDialog.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchDialog.stories.tsx @@ -91,6 +91,7 @@ const meta: Meta<typeof ChatSearchDialog> = { args: { open: true, onOpenChange: fn(), + recentChats: mockChats, location: { pathname: "/agents", search: "", @@ -106,6 +107,8 @@ const meta: Meta<typeof ChatSearchDialog> = { { path: "/agents", useStoryElement: true }, { path: "/agents/:agentId", useStoryElement: true }, { path: "/agents/settings", useStoryElement: true }, + { path: "/agents/settings/personal-skills", useStoryElement: true }, + { path: "/agents/analytics", useStoryElement: true }, ], }), }, @@ -188,6 +191,13 @@ export const RefreshingResults: Story = { // spinner were always visible. expect(body.queryByLabelText("Searching chats")).not.toBeInTheDocument(); + // Ensure the first debounced API call has been registered before + // clearing, so the clear+retype cycle triggers a distinct second call + // rather than coalescing within a single debounce window. + await waitFor(() => { + expect(API.experimental.getChats).toHaveBeenCalledTimes(1); + }); + await userEvent.clear(searchInput); await userEvent.type(searchInput, "review"); @@ -338,3 +348,197 @@ export const ErrorState: Story = { await expect(await body.findByRole("alert")).toBeInTheDocument(); }, }; + +export const ErrorStateWithStackTrace: Story = { + beforeEach: () => { + const err = new Error( + "NetworkError: Failed to fetch chats from the server API endpoint /api/v2/chats", + ); + err.stack = [ + "Error: NetworkError: Failed to fetch chats from the server API endpoint /api/v2/chats", + " at fetchChats (http://localhost:6006/src/api/queries/chats.ts:42:11)", + " at async queryFn (http://localhost:6006/src/api/queries/chats.ts:58:14)", + " at async Object.fetchQuery (http://localhost:6006/node_modules/@tanstack/react-query/src/queryClient.ts:198:16)", + " at async ChatSearchDialogContent (http://localhost:6006/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchDialog.tsx:180:20)", + " at async renderWithHooks (http://localhost:6006/node_modules/react-dom/cjs/react-dom.development.js:14985:18)", + " at async mountIndeterminateComponent (http://localhost:6006/node_modules/react-dom/cjs/react-dom.development.js:17811:13)", + " at async beginWork (http://localhost:6006/node_modules/react-dom/cjs/react-dom.development.js:19049:16)", + ].join("\n"); + spyOn(API.experimental, "getChats").mockRejectedValue(err); + }, + play: async () => { + const body = within(document.body); + await userEvent.type( + body.getByRole("combobox", { name: "Search chats" }), + "title:", + ); + const alert = await body.findByRole("alert"); + await expect(alert).toBeInTheDocument(); + + // Open the stack trace details and verify it stays contained. + const details = body.getByText("Stack Trace"); + await userEvent.click(details); + await expect(body.getByText(/fetchChats/)).toBeInTheDocument(); + }, +}; + +// --------------------------------------------------------------------------- +// Interaction states: default view, filter pills, dropdown. +// --------------------------------------------------------------------------- + +export const DefaultViewWithRecentChats: Story = { + play: async () => { + const body = within(document.body); + await expect(await body.findByText("Recent chats")).toBeInTheDocument(); + await expect( + body.getByText("Fix race condition in auth middleware"), + ).toBeInTheDocument(); + }, +}; + +export const FilterDropdownOnFocus: Story = { + play: async () => { + const body = within(document.body); + const toggleButton = body.getByRole("button", { name: "Toggle filters" }); + + await userEvent.click(toggleButton); + await expect(await body.findByText("Filter by")).toBeInTheDocument(); + await expect(body.getByText("Unread")).toBeInTheDocument(); + await expect(body.getByText("Archived")).toBeInTheDocument(); + await expect(body.getByText("PR status")).toBeInTheDocument(); + await expect(body.getByText("Diff URL")).toBeInTheDocument(); + }, +}; + +export const BooleanFilterPill: Story = { + play: async () => { + const body = within(document.body); + const toggleButton = body.getByRole("button", { name: "Toggle filters" }); + + await userEvent.click(toggleButton); + await userEvent.click(await body.findByText("Unread")); + + await expect(await body.findByText("has_unread:true")).toBeInTheDocument(); + await expect( + body.getByRole("button", { name: "Remove has_unread filter" }), + ).toBeInTheDocument(); + + await waitFor(() => { + expect(API.experimental.getChats).toHaveBeenCalledWith({ + limit: CHAT_SEARCH_LIMIT, + q: "has_unread:true", + }); + }); + }, +}; + +export const ParameterizedFilterPill: Story = { + beforeEach: () => { + spyOn(API.experimental, "getChats").mockResolvedValue(mockChats); + }, + play: async () => { + const body = within(document.body); + const searchInput = body.getByRole("combobox", { name: "Search chats" }); + const toggleButton = body.getByRole("button", { name: "Toggle filters" }); + + await userEvent.click(toggleButton); + await userEvent.click(await body.findByText("PR status")); + + await expect(await body.findByText("pr_status:")).toBeInTheDocument(); + + await userEvent.click(searchInput); + await userEvent.type(searchInput, "open "); + + await expect(await body.findByText("pr_status:open")).toBeInTheDocument(); + + await waitFor(() => { + expect(API.experimental.getChats).toHaveBeenCalledWith({ + limit: CHAT_SEARCH_LIMIT, + q: "pr_status:open", + }); + }); + }, +}; + +export const ParameterizedFilterPillEnterCommit: Story = { + beforeEach: () => { + spyOn(API.experimental, "getChats").mockResolvedValue(mockChats); + }, + play: async () => { + const body = within(document.body); + const searchInput = body.getByRole("combobox", { name: "Search chats" }); + const toggleButton = body.getByRole("button", { name: "Toggle filters" }); + + await userEvent.click(toggleButton); + await userEvent.click(await body.findByText("PR status")); + + await expect(await body.findByText("pr_status:")).toBeInTheDocument(); + + await userEvent.click(searchInput); + await userEvent.type(searchInput, "closed"); + await userEvent.keyboard("{Enter}"); + + await expect(await body.findByText("pr_status:closed")).toBeInTheDocument(); + + await waitFor(() => { + expect(API.experimental.getChats).toHaveBeenCalledWith({ + limit: CHAT_SEARCH_LIMIT, + q: "pr_status:closed", + }); + }); + }, +}; + +export const BackspaceRemovesFilter: Story = { + play: async () => { + const body = within(document.body); + const searchInput = body.getByRole("combobox", { name: "Search chats" }); + const toggleButton = body.getByRole("button", { name: "Toggle filters" }); + + await userEvent.click(toggleButton); + await userEvent.click(await body.findByText("Unread")); + await expect(await body.findByText("has_unread:true")).toBeInTheDocument(); + + await userEvent.click(searchInput); + await userEvent.keyboard("{Backspace}"); + await waitFor(() => { + expect(body.queryByText("has_unread:true")).not.toBeInTheDocument(); + }); + }, +}; + +export const TypedFilterAutoDetection: Story = { + play: async () => { + const body = within(document.body); + const searchInput = body.getByRole("combobox", { name: "Search chats" }); + + await userEvent.type(searchInput, "has_unread:true "); + + await expect(await body.findByText("has_unread:true")).toBeInTheDocument(); + await expect( + body.getByRole("button", { name: "Remove has_unread filter" }), + ).toBeInTheDocument(); + }, +}; + +export const CombinedFilterAndText: Story = { + play: async () => { + const body = within(document.body); + const searchInput = body.getByRole("combobox", { name: "Search chats" }); + const toggleButton = body.getByRole("button", { name: "Toggle filters" }); + + await userEvent.click(toggleButton); + await userEvent.click(await body.findByText("Unread")); + await expect(await body.findByText("has_unread:true")).toBeInTheDocument(); + + await userEvent.click(searchInput); + await userEvent.type(searchInput, "Fix"); + + await waitFor(() => { + expect(API.experimental.getChats).toHaveBeenCalledWith({ + limit: CHAT_SEARCH_LIMIT, + q: 'has_unread:true title:"Fix"', + }); + }); + }, +}; diff --git a/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchDialog.tsx b/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchDialog.tsx index 912fe40271..25aa41dca7 100644 --- a/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchDialog.tsx +++ b/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchDialog.tsx @@ -1,18 +1,72 @@ +import { + ArchiveIcon, + CircleDotIcon, + FileTextIcon, + LinkIcon, +} from "lucide-react"; import type { FC, RefObject } from "react"; -import { type KeyboardEventHandler, useId, useRef, useState } from "react"; +import { + type KeyboardEventHandler, + useId, + useMemo, + useRef, + useState, +} from "react"; import { keepPreviousData, useQuery } from "react-query"; import { type Location, useNavigate } from "react-router"; import { chatSearch } from "#/api/queries/chats"; +import type { Chat } from "#/api/typesGenerated"; +import { Button } from "#/components/Button/Button"; import { Dialog, DialogContent, DialogTitle } from "#/components/Dialog/Dialog"; import { useDebouncedValue } from "#/hooks/debounce"; -import { ChatSearchInput } from "./ChatSearchInput"; +import { ChatSearchInput, type SearchFilter } from "./ChatSearchInput"; import { ChatSearchResults } from "./ChatSearchResults"; import { normalizeChatSearchInput } from "./searchQuery"; +// Filter definitions. Filters with a defaultValue are inserted as complete +// pills (e.g. has_unread:true). Filters without one are inserted as +// incomplete pills so the user can type the value. +type FilterDefinition = { + readonly key: string; + readonly label: string; + readonly icon: FC<{ className?: string }>; + readonly defaultValue: string | null; +}; + +const FILTER_DEFINITIONS: readonly FilterDefinition[] = [ + { + key: "has_unread", + label: "Unread", + icon: CircleDotIcon, + defaultValue: "true", + }, + { + key: "archived", + label: "Archived", + icon: ArchiveIcon, + defaultValue: "true", + }, + { + key: "pr_status", + label: "PR status", + icon: FileTextIcon, + defaultValue: null, + }, + { key: "diff_url", label: "Diff URL", icon: LinkIcon, defaultValue: null }, +]; + +// Set of recognized filter keys for detecting typed filter patterns +// (e.g. "has_unread:true" typed directly into the input). Derived from +// FILTER_DEFINITIONS; the backend equivalent lives in searchQuery.ts as +// passthroughChatSearchFilterKeys. +const KNOWN_FILTER_KEYS = new Set(FILTER_DEFINITIONS.map((def) => def.key)); + type ChatSearchDialogProps = { readonly open: boolean; readonly onOpenChange: (open: boolean) => void; + readonly focusInputOnOpen?: boolean; readonly location: Location; + readonly recentChats?: readonly Chat[]; }; const SEARCH_DEBOUNCE_MS = 500; @@ -20,13 +74,17 @@ const SEARCH_DEBOUNCE_MS = 500; export const ChatSearchDialog: FC<ChatSearchDialogProps> = ({ open, onOpenChange, + focusInputOnOpen = true, location, + recentChats = [], }) => { + const contentRef = useRef<HTMLDivElement | null>(null); const inputRef = useRef<HTMLInputElement | null>(null); return ( <Dialog open={open} onOpenChange={onOpenChange}> <DialogContent + ref={contentRef} // `top` is pinned (rather than the default `top-1/2 -translate-y-1/2`) // so the dialog doesn't re-center when its content height changes // between the empty hint, loading skeleton, and results states. @@ -41,11 +99,17 @@ export const ChatSearchDialog: FC<ChatSearchDialogProps> = ({ // dialog resizing visibly as results stream in. style={{ animation: "none", transition: "none" }} aria-describedby={undefined} + tabIndex={-1} + // When opened from the mobile sidebar button, skip autofocusing + // the input so the virtual keyboard doesn't push the dialog + // off-screen. Focus the dialog container instead to keep the + // element in the accessibility tree. onOpenAutoFocus={(event) => { + if (focusInputOnOpen) { + return; + } event.preventDefault(); - requestAnimationFrame(() => { - inputRef.current?.focus(); - }); + contentRef.current?.focus({ preventScroll: true }); }} > <ChatSearchDialogContent @@ -53,31 +117,90 @@ export const ChatSearchDialog: FC<ChatSearchDialogProps> = ({ onOpenChange={onOpenChange} location={location} inputRef={inputRef} + recentChats={recentChats} /> </DialogContent> </Dialog> ); }; -type ChatSearchDialogContentProps = ChatSearchDialogProps & { +type ChatSearchDialogContentProps = Omit< + ChatSearchDialogProps, + "focusInputOnOpen" +> & { readonly inputRef: RefObject<HTMLInputElement | null>; }; +// Build a raw query string from structured filters + freeform text, then +// normalize it through the existing parser that the backend expects. +const buildQuery = ( + filters: readonly SearchFilter[], + freeText: string, +): string | undefined => { + const parts: string[] = []; + for (const f of filters) { + if (f.value !== null && f.value !== "") { + // Strip internal quotes before wrapping so the resulting + // key:"value" token stays well-formed for the backend. + const stripped = f.value.replaceAll('"', ""); + const v = stripped.includes(" ") ? `"${stripped}"` : stripped; + parts.push(`${f.key}:${v}`); + } + } + if (freeText.trim()) { + parts.push(freeText.trim()); + } + const raw = parts.join(" "); + return normalizeChatSearchInput(raw); +}; + const ChatSearchDialogContent: FC<ChatSearchDialogContentProps> = ({ open, onOpenChange, location, inputRef, + recentChats = [], }) => { const navigate = useNavigate(); - const [inputValue, setInputValue] = useState(""); + const [filters, setFilters] = useState<SearchFilter[]>([]); + const [freeText, setFreeText] = useState(""); + // Tracks the key of a parameterized filter being typed (e.g. "pr_status"). + // While set, freeText holds the in-progress value and the pill shows as + // incomplete (dashed border). Space or Enter commits the value. + const [incompleteFilterKey, setIncompleteFilterKey] = useState<string | null>( + null, + ); + const [isDropdownOpen, setIsDropdownOpen] = useState(false); const [selectedChatIndex, setSelectedChatIndex] = useState< number | undefined >(undefined); const listboxId = useId(); - const debouncedInput = useDebouncedValue(inputValue, SEARCH_DEBOUNCE_MS); - const normalizedQuery = normalizeChatSearchInput(debouncedInput); - const hasQuery = inputValue.trim() !== "" && normalizedQuery !== undefined; + + // Build the full filter list for query building. When an incomplete filter + // has text, include it so debounced search can run against partial values. + const effectiveFilters = useMemo( + () => + incompleteFilterKey && freeText.trim() + ? [...filters, { key: incompleteFilterKey, value: freeText.trim() }] + : filters, + [filters, incompleteFilterKey, freeText], + ); + const hasActiveSearch = effectiveFilters.length > 0 || freeText.trim() !== ""; + + const debouncedFreeText = useDebouncedValue(freeText, SEARCH_DEBOUNCE_MS); + const debouncedFilters = useDebouncedValue( + effectiveFilters, + SEARCH_DEBOUNCE_MS, + ); + // When typing into an incomplete filter, only send the filter (not + // freeText as bare title search). + // When freeText is cleared (e.g. after committing a filter), zero + // queryFreeText immediately instead of waiting for the debounce to + // flush. Otherwise the stale debouncedFreeText leaks into the query. + const queryFreeText = + incompleteFilterKey || !freeText.trim() ? "" : debouncedFreeText; + const normalizedQuery = buildQuery(debouncedFilters, queryFreeText); + const hasQuery = hasActiveSearch && normalizedQuery !== undefined; const searchQuery = useQuery({ ...chatSearch(normalizedQuery ?? ""), @@ -85,14 +208,21 @@ const ChatSearchDialogContent: FC<ChatSearchDialogContentProps> = ({ placeholderData: keepPreviousData, }); - const resultCount = searchQuery.data?.length ?? 0; + // Use search results count when a query is active, otherwise count + // recent chats so keyboard navigation works in the default view too. + const recentChatsSlice = (recentChats ?? []).slice(0, 10); + const resultCount = hasQuery + ? (searchQuery.data?.length ?? 0) + : recentChatsSlice.length; const safeSelectedChatIndex = selectedChatIndex !== undefined && selectedChatIndex < resultCount ? selectedChatIndex : undefined; const selectedChat = safeSelectedChatIndex !== undefined - ? searchQuery.data?.[safeSelectedChatIndex] + ? hasQuery + ? searchQuery.data?.[safeSelectedChatIndex] + : recentChatsSlice[safeSelectedChatIndex] : undefined; const activeResultId = safeSelectedChatIndex !== undefined @@ -109,9 +239,130 @@ const ChatSearchDialogContent: FC<ChatSearchDialogContentProps> = ({ searchQuery.isFetching && searchQuery.isPlaceholderData && !showResultsLoading; + + const commitIncompleteFilter = () => { + if (incompleteFilterKey && freeText.trim()) { + setFilters((prev) => [ + ...prev, + { key: incompleteFilterKey, value: freeText.trim() }, + ]); + setFreeText(""); + setIncompleteFilterKey(null); + } + }; + + const addFilter = (def: FilterDefinition) => { + if ( + filters.some((f) => f.key === def.key) || + incompleteFilterKey === def.key + ) { + return; + } + commitIncompleteFilter(); + + if (def.defaultValue !== null) { + setFilters((prev) => [ + ...prev, + { key: def.key, value: def.defaultValue }, + ]); + } else { + setIncompleteFilterKey(def.key); + setFreeText(""); + } + setIsDropdownOpen(false); + setSelectedChatIndex(undefined); + requestAnimationFrame(() => inputRef.current?.focus()); + }; + + const removeFilter = (key: string) => { + if (incompleteFilterKey === key) { + setIncompleteFilterKey(null); + setFreeText(""); + } else { + setFilters((prev) => prev.filter((f) => f.key !== key)); + } + setSelectedChatIndex(undefined); + requestAnimationFrame(() => inputRef.current?.focus()); + }; + + const handleInputChange = (value: string) => { + setFreeText(value); + setSelectedChatIndex(undefined); + }; + + // Build the display filters for ChatSearchInput: completed filters plus + // the incomplete one (shown with dashed border). + const displayFilters: SearchFilter[] = incompleteFilterKey + ? [...filters, { key: incompleteFilterKey, value: null }] + : filters; + const handleInputKeyDown: KeyboardEventHandler<HTMLInputElement> = ( event, ) => { + if ( + (event.key === " " || event.key === "Enter") && + incompleteFilterKey && + freeText.trim() + ) { + event.preventDefault(); + commitIncompleteFilter(); + return; + } + + if ( + (event.key === " " || event.key === "Enter") && + !incompleteFilterKey && + freeText.trim() + ) { + const activeKeys = new Set(filters.map((f) => f.key)); + const tokens = freeText.trim().split(/\s+/); + const newFilters: SearchFilter[] = []; + const remaining: string[] = []; + + for (const token of tokens) { + const colonIndex = token.indexOf(":"); + if (colonIndex > 0 && colonIndex < token.length - 1) { + const key = token.slice(0, colonIndex); + const val = token.slice(colonIndex + 1); + if (KNOWN_FILTER_KEYS.has(key)) { + // Drop duplicate filter keys silently instead of + // letting them fall through to freeform text. + if (!activeKeys.has(key)) { + newFilters.push({ key, value: val }); + activeKeys.add(key); + } + continue; + } + } + remaining.push(token); + } + + if (newFilters.length > 0) { + event.preventDefault(); + setFilters((prev) => [...prev, ...newFilters]); + setFreeText(remaining.join(" ")); + return; + } + } + + if (event.key === "Backspace" && freeText === "") { + if (incompleteFilterKey) { + setIncompleteFilterKey(null); + return; + } + if (filters.length > 0) { + const lastFilter = filters[filters.length - 1]; + removeFilter(lastFilter.key); + return; + } + } + + if (event.key === "Escape" && isDropdownOpen) { + setIsDropdownOpen(false); + event.stopPropagation(); + return; + } + if (event.key === "ArrowDown" || event.key === "ArrowUp") { if (resultCount === 0) { return; @@ -145,21 +396,38 @@ const ChatSearchDialogContent: FC<ChatSearchDialogContentProps> = ({ return ( <> <DialogTitle className="sr-only">Search chats</DialogTitle> - <ChatSearchInput - activeResultId={activeResultId} - hasResults={resultCount > 0} - inputRef={inputRef} - listboxId={listboxId} - value={inputValue} - onChange={(event) => { - setInputValue(event.target.value); - setSelectedChatIndex(undefined); + {/* Wrap input + dropdown so onBlur on the container closes + the dropdown, but clicks within the dropdown (which is + inside the same container) don't trigger blur. */} + <div + className="relative" + onBlur={(e) => { + if (!e.currentTarget.contains(e.relatedTarget)) { + setIsDropdownOpen(false); + } }} - onKeyDown={handleInputKeyDown} - /> + > + <ChatSearchInput + activeResultId={activeResultId} + hasResults={resultCount > 0} + inputRef={inputRef} + listboxId={listboxId} + filters={displayFilters} + value={freeText} + onChange={(event) => handleInputChange(event.target.value)} + onKeyDown={handleInputKeyDown} + onRemoveFilter={removeFilter} + isDropdownOpen={isDropdownOpen} + onToggleDropdown={() => setIsDropdownOpen((prev) => !prev)} + /> + {isDropdownOpen && ( + <FilterDropdown filters={displayFilters} onSelectFilter={addFilter} /> + )} + </div> <ChatSearchResults chats={searchQuery.data} + recentChats={recentChats} error={searchQuery.error} hasQuery={hasQuery} location={location} @@ -167,8 +435,45 @@ const ChatSearchDialogContent: FC<ChatSearchDialogContentProps> = ({ selectedChatIndex={safeSelectedChatIndex} showLoading={showResultsLoading} isRefreshing={isRefreshing} - onSelectChat={closeDialog} + onDismiss={closeDialog} /> </> ); }; + +// --------------------------------------------------------------------------- +// Filter dropdown: appears on focus, shows clickable filter chips. +// --------------------------------------------------------------------------- + +const FilterDropdown: FC<{ + readonly filters: readonly SearchFilter[]; + readonly onSelectFilter: (def: FilterDefinition) => void; +}> = ({ filters, onSelectFilter }) => { + const activeKeys = new Set(filters.map((f) => f.key)); + + return ( + <div className="absolute left-0 right-0 top-full z-10 mt-1 rounded-md border border-solid border-border bg-surface-primary p-3 shadow-md"> + <h3 className="m-0 mb-2 text-xs font-medium text-content-secondary"> + Filter by + </h3> + <div className="flex flex-wrap gap-2"> + {FILTER_DEFINITIONS.map((def) => { + const Icon = def.icon; + const isActive = activeKeys.has(def.key); + return ( + <Button + key={def.key} + variant="outline" + size="sm" + disabled={isActive} + onClick={() => onSelectFilter(def)} + > + <Icon className="size-4" /> + {def.label} + </Button> + ); + })} + </div> + </div> + ); +}; diff --git a/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchInput.tsx b/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchInput.tsx index 1d5cc66dc0..b134f8854b 100644 --- a/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchInput.tsx +++ b/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchInput.tsx @@ -1,20 +1,29 @@ -import { SearchIcon } from "lucide-react"; +import { ListFilterIcon, SearchIcon, XIcon } from "lucide-react"; import type { ChangeEventHandler, FC, KeyboardEventHandler, RefObject, } from "react"; -import { Input } from "#/components/Input/Input"; +import { cn } from "#/utils/cn"; + +export type SearchFilter = { + readonly key: string; + readonly value: string | null; +}; type ChatSearchInputProps = { readonly activeResultId: string | undefined; readonly hasResults: boolean; readonly inputRef: RefObject<HTMLInputElement | null>; readonly listboxId: string; + readonly filters: readonly SearchFilter[]; readonly value: string; readonly onChange: ChangeEventHandler<HTMLInputElement>; readonly onKeyDown: KeyboardEventHandler<HTMLInputElement>; + readonly onRemoveFilter: (key: string) => void; + readonly isDropdownOpen: boolean; + readonly onToggleDropdown: () => void; }; export const ChatSearchInput: FC<ChatSearchInputProps> = ({ @@ -22,20 +31,58 @@ export const ChatSearchInput: FC<ChatSearchInputProps> = ({ hasResults, inputRef, listboxId, + filters, value, onChange, onKeyDown, + onRemoveFilter, + isDropdownOpen, + onToggleDropdown, }) => { + const completedFilters = filters.filter((f) => f.value !== null); + const incompleteFilter = filters.find((f) => f.value === null); + return ( - <div className="relative min-w-0"> - <SearchIcon className="pointer-events-none absolute left-3 top-1/2 size-4 -translate-y-1/2 text-content-secondary" /> - <Input + <div + className={cn( + "flex min-h-10 w-full items-center gap-1.5 rounded-md border border-solid border-border-default bg-surface-primary px-3", + "focus-within:ring-2 focus-within:ring-content-link", + )} + > + <SearchIcon className="size-4 shrink-0 text-content-secondary" /> + {completedFilters.map((f) => ( + <span + key={f.key} + className="inline-flex shrink-0 items-center gap-1 rounded-md border border-solid border-border bg-surface-secondary px-2 py-0.5 text-xs text-content-secondary" + > + <span> + {f.key}:{f.value} + </span> + <button + type="button" + onClick={(e) => { + e.stopPropagation(); + onRemoveFilter(f.key); + }} + className="inline-flex cursor-pointer items-center border-none bg-transparent p-0 text-content-secondary hover:text-content-primary" + aria-label={`Remove ${f.key} filter`} + > + <XIcon className="size-3" /> + </button> + </span> + ))} + {incompleteFilter && ( + <span className="inline-flex shrink-0 items-center rounded-md border border-dashed border-border bg-surface-secondary px-2 py-0.5 text-xs text-content-secondary"> + {incompleteFilter.key}: + </span> + )} + <input ref={inputRef} value={value} onChange={onChange} onKeyDown={onKeyDown} - placeholder="Search chats..." - className="h-10 border-border-default bg-surface-primary pl-9 pr-3 placeholder:text-content-disabled" + placeholder={filters.length > 0 ? "" : "Search chats..."} + className="min-w-[60px] flex-1 border-none bg-transparent py-2 text-sm text-content-primary outline-none placeholder:text-content-disabled" aria-label="Search chats" role="combobox" aria-controls={hasResults ? listboxId : undefined} @@ -43,6 +90,18 @@ export const ChatSearchInput: FC<ChatSearchInputProps> = ({ aria-haspopup="listbox" aria-activedescendant={activeResultId} /> + <button + type="button" + onClick={onToggleDropdown} + className={cn( + "inline-flex shrink-0 cursor-pointer items-center border-none bg-transparent p-0 text-content-secondary hover:text-content-primary", + isDropdownOpen && "text-content-primary", + )} + aria-label="Toggle filters" + aria-expanded={isDropdownOpen} + > + <ListFilterIcon className="size-4" /> + </button> </div> ); }; diff --git a/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchResults.tsx b/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchResults.tsx index e4a2cb170a..c49c900332 100644 --- a/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchResults.tsx +++ b/site/src/pages/AgentsPage/components/ChatsSidebar/dialogs/ChatSearchResults.tsx @@ -12,6 +12,7 @@ import { getChatDisplayConfig } from "../tree/statusConfig"; type ChatSearchResultsProps = { readonly chats: readonly Chat[] | undefined; + readonly recentChats: readonly Chat[]; readonly error: unknown; readonly hasQuery: boolean; readonly location: Location; @@ -19,11 +20,23 @@ type ChatSearchResultsProps = { readonly selectedChatIndex: number | undefined; readonly showLoading: boolean; readonly isRefreshing: boolean; - readonly onSelectChat: () => void; + readonly onDismiss: () => void; +}; + +const RECENT_CHATS_COUNT = 10; + +// !block overrides Radix ScrollArea viewport's display:table so truncated text can shrink. +const SCROLL_AREA_PROPS = { + className: + "h-[300px] w-full [&_[data-radix-scroll-area-viewport]>div]:!block", + scrollBarClassName: "w-[0.375rem]", + viewportClassName: "pr-3", + viewportTabIndex: -1, }; export const ChatSearchResults: FC<ChatSearchResultsProps> = ({ chats, + recentChats, error, hasQuery, location, @@ -31,25 +44,28 @@ export const ChatSearchResults: FC<ChatSearchResultsProps> = ({ selectedChatIndex, showLoading, isRefreshing, - onSelectChat, + onDismiss, }) => { if (error) { return ( <div className="min-h-[260px]"> - <ErrorAlert error={error} /> + <ErrorAlert + error={error} + className="max-h-[340px] overflow-y-auto [&_pre]:whitespace-pre-wrap [&_pre]:break-all [&_pre]:w-auto [&_pre]:min-w-0" + /> </div> ); } if (!hasQuery) { return ( - <div className="min-h-[260px]"> - <div className="pt-2 text-sm text-content-secondary"> - Type to search by title, or use filters like{" "} - <code>has_unread:true</code>, <code>archived:true</code>,{" "} - <code>pr_status:open</code>, or <code>diff_url:"..."</code>. - </div> - </div> + <DefaultView + recentChats={recentChats} + location={location} + listboxId={listboxId} + selectedChatIndex={selectedChatIndex} + onDismiss={onDismiss} + /> ); } @@ -71,33 +87,25 @@ export const ChatSearchResults: FC<ChatSearchResultsProps> = ({ return ( <div className="min-h-[260px]"> <div className="space-y-3"> - <p className="inline-flex items-center gap-1.5 text-sm text-content-secondary"> + <p className="m-0 text-sm text-content-secondary"> <span>{resultSummary}</span> {isRefreshing && ( <Spinner loading size="sm" - className="text-content-secondary" + className="ml-1.5 inline-block align-text-bottom text-content-secondary" aria-label="Searching chats" /> )} </p> - <ScrollArea - // `!block` overrides the Radix ScrollArea viewport wrapper's inline - // `display: table` so that descendants using `truncate` can shrink - // inside the scroll container instead of forcing the table to grow. - className="h-[300px] w-full [&_[data-radix-scroll-area-viewport]>div]:!block" - scrollBarClassName="w-[0.375rem]" - viewportClassName="pr-3" - viewportTabIndex={-1} - > + <ScrollArea {...SCROLL_AREA_PROPS}> <ChatSearchResultsList chats={chats} location={location} listboxId={listboxId} selectedChatIndex={selectedChatIndex} showLoading={showLoading} - onSelectChat={onSelectChat} + onDismiss={onDismiss} /> </ScrollArea> </div> @@ -105,13 +113,72 @@ export const ChatSearchResults: FC<ChatSearchResultsProps> = ({ ); }; +// --------------------------------------------------------------------------- +// Default view: recent chats (shown when no query is active). +// --------------------------------------------------------------------------- + +type DefaultViewProps = { + readonly recentChats: readonly Chat[]; + readonly location: Location; + readonly listboxId: string; + readonly selectedChatIndex: number | undefined; + readonly onDismiss: () => void; +}; + +const DefaultView: FC<DefaultViewProps> = ({ + recentChats, + location, + listboxId, + selectedChatIndex, + onDismiss, +}) => { + const visibleRecentChats = recentChats.slice(0, RECENT_CHATS_COUNT); + + return ( + <div className="min-h-[260px]"> + <div className="space-y-3"> + {visibleRecentChats.length > 0 && ( + <div> + <h3 className="m-0 mb-3 text-sm font-medium text-content-secondary"> + Recent chats + </h3> + <ScrollArea {...SCROLL_AREA_PROPS}> + <div + id={listboxId} + role="listbox" + aria-label="Recent chats" + className="space-y-1" + > + {visibleRecentChats.map((chat, index) => ( + <ChatSearchResultRow + key={chat.id} + chat={chat} + id={`${listboxId}-option-${index}`} + isSelected={selectedChatIndex === index} + location={location} + onSelect={onDismiss} + /> + ))} + </div> + </ScrollArea> + </div> + )} + </div> + </div> + ); +}; + +// --------------------------------------------------------------------------- +// Results list and row components. +// --------------------------------------------------------------------------- + type ChatSearchResultsListProps = { readonly chats: readonly Chat[] | undefined; readonly location: Location; readonly listboxId: string; readonly selectedChatIndex: number | undefined; readonly showLoading: boolean; - readonly onSelectChat: () => void; + readonly onDismiss: () => void; }; const ChatSearchResultsList: FC<ChatSearchResultsListProps> = ({ @@ -120,7 +187,7 @@ const ChatSearchResultsList: FC<ChatSearchResultsListProps> = ({ listboxId, selectedChatIndex, showLoading, - onSelectChat, + onDismiss, }) => { if (showLoading) { return <ChatSearchResultsSkeleton />; @@ -128,9 +195,9 @@ const ChatSearchResultsList: FC<ChatSearchResultsListProps> = ({ if ((chats?.length ?? 0) === 0) { return ( - <p className="px-1.5 py-2 text-sm text-content-secondary"> - No matching chats - </p> + <div className="flex h-[300px] items-center justify-center"> + <p className="text-sm text-content-secondary">No matching chats</p> + </div> ); } @@ -148,7 +215,7 @@ const ChatSearchResultsList: FC<ChatSearchResultsListProps> = ({ id={`${listboxId}-option-${index}`} isSelected={selectedChatIndex === index} location={location} - onSelect={onSelectChat} + onSelect={onDismiss} /> ))} </div> diff --git a/site/src/pages/AgentsPage/components/ChatsSidebar/filters/FilterPopover.tsx b/site/src/pages/AgentsPage/components/ChatsSidebar/filters/FilterPopover.tsx index 4241209507..9c017a0e74 100644 --- a/site/src/pages/AgentsPage/components/ChatsSidebar/filters/FilterPopover.tsx +++ b/site/src/pages/AgentsPage/components/ChatsSidebar/filters/FilterPopover.tsx @@ -226,7 +226,7 @@ export const FilterPopover: FC<FilterPopoverProps> = ({ size="icon" aria-label="Filter agents" className={cn( - "h-7 w-7 min-w-0 justify-end rounded-none px-0 text-content-secondary hover:text-content-primary", + "h-7 w-7 min-w-0 -mr-0.5 justify-end px-0 text-content-secondary hover:text-content-primary", hasActiveFilters(filters) && "text-content-primary", )} > diff --git a/site/src/pages/AgentsPage/components/ChatsSidebar/settings/SettingsNavItem.tsx b/site/src/pages/AgentsPage/components/ChatsSidebar/settings/SettingsNavItem.tsx index 83566ff389..97a6fbc856 100644 --- a/site/src/pages/AgentsPage/components/ChatsSidebar/settings/SettingsNavItem.tsx +++ b/site/src/pages/AgentsPage/components/ChatsSidebar/settings/SettingsNavItem.tsx @@ -1,6 +1,6 @@ import { ShieldIcon } from "lucide-react"; -import type { FC } from "react"; -import { Link, type To } from "react-router"; +import type { ComponentProps, FC, ReactNode } from "react"; +import { Link } from "react-router"; import { Tooltip, TooltipContent, @@ -13,32 +13,52 @@ type SettingsNavItemProps = { label: string; active: boolean; adminOnly?: boolean; + ariaLabel?: string; + className?: string; disabled?: boolean; + trailing?: ReactNode; trailingIcon?: FC<{ className?: string }>; } & ( - | { to: To; replace?: boolean; state?: unknown; onClick?: () => void } + | { + to: ComponentProps<typeof Link>["to"]; + replace?: boolean; + state?: unknown; + onClick?: () => void; + } | { to?: never; replace?: never; state?: never; onClick: () => void } ); -const navItemClassName = (active: boolean, disabled: boolean | undefined) => +const navItemClassName = ( + active: boolean, + disabled: boolean | undefined, + className: string | undefined, +) => cn( - "flex w-full items-center gap-2.5 rounded-md border-0 px-2.5 py-2 text-left text-sm cursor-pointer transition-colors no-underline", + "flex w-full items-center gap-2.5 rounded-md border-0 px-2.5 py-1.5 text-left text-sm cursor-pointer transition-colors no-underline", active ? "bg-surface-quaternary/25 text-content-primary font-medium" : "bg-transparent text-content-secondary hover:bg-surface-tertiary/50 hover:text-content-primary", disabled && "opacity-50 pointer-events-none", + className, ); const NavItemContent: FC<{ icon: FC<{ className?: string }>; label: string; adminOnly?: boolean; + trailing?: ReactNode; trailingIcon?: FC<{ className?: string }>; -}> = ({ icon: Icon, label, adminOnly, trailingIcon: TrailingIcon }) => ( +}> = ({ + icon: Icon, + label, + adminOnly, + trailing, + trailingIcon: TrailingIcon, +}) => ( <> <Icon className="size-4 shrink-0" /> <span className="min-w-0 flex-1">{label}</span> - {(adminOnly || TrailingIcon) && ( + {(adminOnly || trailing || TrailingIcon) && ( <span className="ml-auto flex items-center gap-2"> {adminOnly && ( <Tooltip> @@ -51,6 +71,7 @@ const NavItemContent: FC<{ </Tooltip> )} {TrailingIcon && <TrailingIcon className="size-4 shrink-0" />} + {trailing} </span> )} </> @@ -61,7 +82,10 @@ export const SettingsNavItem: FC<SettingsNavItemProps> = ({ label, active, adminOnly, + ariaLabel, + className, disabled, + trailing, trailingIcon, ...rest }) => { @@ -72,14 +96,16 @@ export const SettingsNavItem: FC<SettingsNavItemProps> = ({ replace={rest.replace} state={rest.state} onClick={rest.onClick} - className={navItemClassName(active, disabled)} + className={navItemClassName(active, disabled, className)} aria-current={active ? "page" : undefined} + aria-label={ariaLabel} tabIndex={disabled ? -1 : undefined} > <NavItemContent icon={icon} label={label} adminOnly={adminOnly} + trailing={trailing} trailingIcon={trailingIcon} /> </Link> @@ -91,13 +117,15 @@ export const SettingsNavItem: FC<SettingsNavItemProps> = ({ type="button" onClick={rest.onClick} disabled={disabled} - className={navItemClassName(active, disabled)} + className={navItemClassName(active, disabled, className)} aria-current={active ? "page" : undefined} + aria-label={ariaLabel} > <NavItemContent icon={icon} label={label} adminOnly={adminOnly} + trailing={trailing} trailingIcon={trailingIcon} /> </button> diff --git a/site/src/pages/AgentsPage/components/ChatsSidebar/tree/ChatTreeNode.tsx b/site/src/pages/AgentsPage/components/ChatsSidebar/tree/ChatTreeNode.tsx index 155058f4c0..aa0a693b40 100644 --- a/site/src/pages/AgentsPage/components/ChatsSidebar/tree/ChatTreeNode.tsx +++ b/site/src/pages/AgentsPage/components/ChatsSidebar/tree/ChatTreeNode.tsx @@ -208,7 +208,7 @@ export const ChatTreeNode: FC<ChatTreeNodeProps> = ({ chat, isChildNode }) => { ); return ( - <div className="flex min-w-0 flex-col"> + <div className="flex min-w-0 flex-col gap-0.5"> <ContextMenu> <ContextMenuTrigger asChild> <div @@ -329,7 +329,7 @@ export const ChatTreeNode: FC<ChatTreeNodeProps> = ({ chat, isChildNode }) => { <span className="flex items-center justify-end text-xs text-content-secondary/50 tabular-nums [@media(hover:hover)]:group-hover:hidden group-has-[[data-state=open]]:hidden"> {chat.has_unread && !isActiveChat ? ( <span - className="size-2 shrink-0 rounded-full bg-content-link" + className="size-2 shrink-0 rounded-full bg-content-link pr-1" data-testid={`unread-indicator-${chat.id}`} aria-hidden="true" /> @@ -374,7 +374,7 @@ export const ChatTreeNode: FC<ChatTreeNodeProps> = ({ chat, isChildNode }) => { </ContextMenu> {hasChildren && isExpanded && ( - <div className="relative ml-4 border-l border-border-default/60 pl-2.5"> + <div className="relative ml-4 flex flex-col border-l border-border-default/60 pl-2.5"> {childIDs.map((childID) => { const childChat = chatById.get(childID); if (!childChat) return null; diff --git a/site/src/pages/AgentsPage/components/RightPanel/DesktopPanel.stories.tsx b/site/src/pages/AgentsPage/components/RightPanel/DesktopPanel.stories.tsx index 02228c8678..b48c93ff17 100644 --- a/site/src/pages/AgentsPage/components/RightPanel/DesktopPanel.stories.tsx +++ b/site/src/pages/AgentsPage/components/RightPanel/DesktopPanel.stories.tsx @@ -7,9 +7,12 @@ const defaults: DesktopPanelViewProps = { status: "idle", reconnect: fn(), attach: fn(), + scaleMode: "native", + onScaleModeChange: fn(), isControlling: false, onTakeControl: fn(), onReleaseControl: fn(), + onPopOut: fn(), }; const meta: Meta<typeof DesktopPanelView> = { diff --git a/site/src/pages/AgentsPage/components/RightPanel/DesktopPanel.tsx b/site/src/pages/AgentsPage/components/RightPanel/DesktopPanel.tsx index ef20dcd86d..7e8f74b9f9 100644 --- a/site/src/pages/AgentsPage/components/RightPanel/DesktopPanel.tsx +++ b/site/src/pages/AgentsPage/components/RightPanel/DesktopPanel.tsx @@ -1,17 +1,15 @@ -import { HandIcon, MousePointer2Icon } from "lucide-react"; +import { ExternalLinkIcon } from "lucide-react"; import type { FC } from "react"; -import { useState } from "react"; +import { useEffect, useState } from "react"; + import { Button } from "#/components/Button/Button"; import { Spinner } from "#/components/Spinner/Spinner"; -import { cn } from "#/utils/cn"; -import { useDesktopConnection } from "../../hooks/useDesktopConnection"; - -type DesktopConnectionStatus = - | "idle" - | "connecting" - | "connected" - | "disconnected" - | "error"; +import { + type DesktopConnectionStatus, + useDesktopConnection, +} from "../../hooks/useDesktopConnection"; +import { useZoomShortcuts } from "../../hooks/useZoomShortcuts"; +import { DesktopToolbar, type ScaleMode } from "./DesktopToolbar"; interface DesktopPanelProps { chatId: string; @@ -19,19 +17,10 @@ interface DesktopPanelProps { isVisible?: boolean; } -export interface DesktopPanelViewProps { - status: DesktopConnectionStatus; - reconnect: () => void; - attach: (container: HTMLElement) => void; - isControlling: boolean; - onTakeControl: () => void; - onReleaseControl: () => void; -} - export const DesktopPanel: FC<DesktopPanelProps> = ({ chatId, isVisible }) => { // Delay the VNC connection until the desktop tab is first selected. // Once activated, the connection stays alive even when the tab is - // switched away — mirrors the terminal panel pattern from PR #23231. + // switched away. const [activated, setActivated] = useState(false); if (isVisible && !activated) { setActivated(true); @@ -42,29 +31,104 @@ export const DesktopPanel: FC<DesktopPanelProps> = ({ chatId, isVisible }) => { setIsControlling(false); } + const [scaleMode, setScaleMode] = useState<ScaleMode>("fit"); + const [isPoppedOut, setIsPoppedOut] = useState(false); + const { status, reconnect, attach } = useDesktopConnection({ - chatId, - activated, + chatId: isPoppedOut ? undefined : chatId, + activated: activated && !isPoppedOut, + scaleViewport: scaleMode === "fit", }); + + useZoomShortcuts(setScaleMode, isVisible); + + // Listen for BroadcastChannel messages from the pop-out window. + useEffect(() => { + const channel = new BroadcastChannel(`coder-desktop-${chatId}`); + + channel.addEventListener("message", (event) => { + if (event.data?.type === "popout-opened") { + setIsPoppedOut(true); + setIsControlling(false); + } else if (event.data?.type === "popout-closed") { + setIsPoppedOut(false); + } + }); + + return () => channel.close(); + }, [chatId]); + + const handlePopOut = () => { + const width = Math.round(screen.availWidth * 0.5); + const height = Math.round(screen.availHeight * 0.5); + const left = Math.round((screen.availWidth - width) / 2); + const top = Math.round((screen.availHeight - height) / 2); + open( + `/agents/${chatId}/desktop`, + `coder-desktop-${chatId}`, + `popup,width=${width},height=${height},left=${left},top=${top}`, + ); + }; + + const handleBringBack = () => { + const channel = new BroadcastChannel(`coder-desktop-${chatId}`); + channel.postMessage({ type: "bring-back" }); + channel.close(); + setIsPoppedOut(false); + }; + + if (isPoppedOut) { + return ( + <div + className="flex h-full flex-col items-center justify-center gap-3 text-content-secondary" + role="status" + > + <ExternalLinkIcon className="h-8 w-8" /> + <span className="text-sm">Desktop is open in a separate window.</span> + <Button variant="outline" size="sm" onClick={handleBringBack}> + Bring back + </Button> + </div> + ); + } + return ( <DesktopPanelView status={status} reconnect={reconnect} attach={attach} + scaleMode={scaleMode} + onScaleModeChange={setScaleMode} isControlling={isControlling} onTakeControl={() => setIsControlling(true)} onReleaseControl={() => setIsControlling(false)} + onPopOut={handlePopOut} /> ); }; +export interface DesktopPanelViewProps { + status: DesktopConnectionStatus; + reconnect: () => void; + attach: (container: HTMLElement) => void; + scaleMode: ScaleMode; + onScaleModeChange: (mode: ScaleMode) => void; + isControlling: boolean; + onTakeControl: () => void; + onReleaseControl: () => void; + onPopOut?: () => void; +} + export const DesktopPanelView: FC<DesktopPanelViewProps> = ({ status, reconnect, attach, + scaleMode, + onScaleModeChange, isControlling, onTakeControl, onReleaseControl, + onPopOut, }) => { if (status === "connecting") { return ( @@ -109,43 +173,31 @@ export const DesktopPanelView: FC<DesktopPanelViewProps> = ({ // status === "connected" return ( - <div className="relative size-full"> - {/* "Release Control" button — top-right, only when controlling */} - {isControlling && ( - <Button - variant="default" - size="sm" - onClick={onReleaseControl} - className="absolute top-2 right-2 z-20 shadow-xl drop-shadow-lg" - > - <HandIcon className="size-4" /> - Release control - </Button> - )} - {/* VNC container — pointer-events toggled */} - <div - ref={(el) => { - if (el) attach(el); - }} - className={cn("size-full", !isControlling && "pointer-events-none")} + <div className="flex h-full w-full flex-col"> + <DesktopToolbar + scaleMode={scaleMode} + onScaleModeChange={onScaleModeChange} + isControlling={isControlling} + onTakeControl={onTakeControl} + onReleaseControl={onReleaseControl} + onPopOut={onPopOut} /> - {/* "Take Control" hover overlay — only when NOT controlling */} - {!isControlling && ( - <div className="group/desktop absolute inset-0 z-10 flex items-center justify-center bg-black/0 transition-all duration-200 ease-in-out group-hover/desktop:bg-black/40"> - <span className="opacity-0 transition-opacity duration-200 ease-in-out group-hover/desktop:opacity-100"> - <Button - variant="default" - size="sm" - onClick={onTakeControl} - aria-label="Take control of desktop" - className="shadow-xl drop-shadow-lg" - > - <MousePointer2Icon className="size-4" /> - Take control - </Button> - </span> - </div> - )} + + <div className="min-h-0 flex-1 overflow-hidden bg-surface-secondary"> + <div + ref={(el) => { + if (el) attach(el); + }} + className="h-full w-full" + inert={!isControlling ? true : undefined} + role="application" + aria-label={ + isControlling + ? "Remote desktop (interactive)" + : "Remote desktop (view only, take control to interact)" + } + /> + </div> </div> ); }; diff --git a/site/src/pages/AgentsPage/components/RightPanel/DesktopToolbar.stories.tsx b/site/src/pages/AgentsPage/components/RightPanel/DesktopToolbar.stories.tsx new file mode 100644 index 0000000000..36be9a9475 --- /dev/null +++ b/site/src/pages/AgentsPage/components/RightPanel/DesktopToolbar.stories.tsx @@ -0,0 +1,75 @@ +import type { Meta, StoryObj } from "@storybook/react-vite"; +import { expect, fn, userEvent, within } from "storybook/test"; +import { DesktopToolbar } from "./DesktopToolbar"; + +const meta = { + title: "pages/AgentsPage/DesktopToolbar", + component: DesktopToolbar, +} satisfies Meta<typeof DesktopToolbar>; + +export default meta; +type Story = StoryObj<typeof meta>; + +export const ViewOnly: Story = { + args: { + scaleMode: "fit", + onScaleModeChange: fn(), + isControlling: false, + onTakeControl: fn(), + onReleaseControl: fn(), + onPopOut: fn(), + }, + play: async ({ canvasElement, args }) => { + const canvas = within(canvasElement); + const takeControl = canvas.getByText("Take control"); + await userEvent.click(takeControl); + await expect(args.onTakeControl).toHaveBeenCalled(); + + const zoom = canvas.getByText("Zoom to 100%"); + await userEvent.click(zoom); + await expect(args.onScaleModeChange).toHaveBeenCalledWith("native"); + + const detach = canvas.getByText("Detach"); + await userEvent.click(detach); + await expect(args.onPopOut).toHaveBeenCalled(); + }, +}; + +export const Controlling: Story = { + args: { + ...ViewOnly.args, + isControlling: true, + }, + play: async ({ canvasElement, args }) => { + const canvas = within(canvasElement); + const release = canvas.getByText("Release control"); + await userEvent.click(release); + await expect(args.onReleaseControl).toHaveBeenCalled(); + }, +}; + +export const NativeZoom: Story = { + args: { + ...ViewOnly.args, + scaleMode: "native", + }, + play: async ({ canvasElement, args }) => { + const canvas = within(canvasElement); + const zoom = canvas.getByText("Zoom to fit"); + await userEvent.click(zoom); + await expect(args.onScaleModeChange).toHaveBeenCalledWith("fit"); + }, +}; + +export const PoppedOut: Story = { + args: { + ...ViewOnly.args, + isPoppedOut: true, + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + // Detach button should not render when popped out. + const detach = canvas.queryByText("Detach"); + await expect(detach).toBeNull(); + }, +}; diff --git a/site/src/pages/AgentsPage/components/RightPanel/DesktopToolbar.tsx b/site/src/pages/AgentsPage/components/RightPanel/DesktopToolbar.tsx new file mode 100644 index 0000000000..a7c733f1fc --- /dev/null +++ b/site/src/pages/AgentsPage/components/RightPanel/DesktopToolbar.tsx @@ -0,0 +1,100 @@ +import { + ExternalLinkIcon, + HandIcon, + MaximizeIcon, + MousePointer2Icon, + ScalingIcon, +} from "lucide-react"; +import type { FC } from "react"; +import { Button } from "#/components/Button/Button"; +export type ScaleMode = "native" | "fit"; + +interface DesktopToolbarProps { + scaleMode: ScaleMode; + onScaleModeChange: (mode: ScaleMode) => void; + isControlling: boolean; + onTakeControl: () => void; + onReleaseControl: () => void; + onPopOut?: () => void; + isPoppedOut?: boolean; +} + +export const DesktopToolbar: FC<DesktopToolbarProps> = ({ + scaleMode, + onScaleModeChange, + isControlling, + onTakeControl, + onReleaseControl, + onPopOut, + isPoppedOut, +}) => { + return ( + <div + className="flex h-8 shrink-0 items-center justify-end gap-1 border-0 border-b border-solid border-border-default bg-surface-primary px-1.5" + role="group" + aria-label="Desktop controls" + > + {/* Take/Release control */} + <Button + variant="subtle" + size="sm" + onClick={isControlling ? onReleaseControl : onTakeControl} + aria-pressed={isControlling} + className="h-6 gap-1.5 px-2 text-xs" + > + {isControlling ? ( + <> + <HandIcon className="size-3.5" /> + Release control + </> + ) : ( + <> + <MousePointer2Icon className="size-3.5" /> + Take control + </> + )} + </Button> + + {/* Zoom toggle */} + <Button + variant="subtle" + size="sm" + onClick={() => + onScaleModeChange(scaleMode === "native" ? "fit" : "native") + } + aria-label={ + scaleMode === "native" + ? "Zoom to fit (Ctrl+0)" + : "Zoom to 100% (Ctrl+1)" + } + className="h-6 gap-1.5 px-2 text-xs" + > + {scaleMode === "native" ? ( + <> + <ScalingIcon className="size-3.5" /> + Zoom to fit + </> + ) : ( + <> + <MaximizeIcon className="size-3.5" /> + Zoom to 100% + </> + )} + </Button> + + {/* Detach button */} + {onPopOut && !isPoppedOut && ( + <Button + variant="subtle" + size="sm" + onClick={onPopOut} + aria-label="Detach desktop to new window" + className="h-6 gap-1.5 px-2 text-xs" + > + <ExternalLinkIcon className="size-3.5" /> + Detach + </Button> + )} + </div> + ); +}; diff --git a/site/src/pages/AgentsPage/hooks/useDesktopConnection.test.ts b/site/src/pages/AgentsPage/hooks/useDesktopConnection.test.ts index 34cfd36eaa..410ae38758 100644 --- a/site/src/pages/AgentsPage/hooks/useDesktopConnection.test.ts +++ b/site/src/pages/AgentsPage/hooks/useDesktopConnection.test.ts @@ -186,7 +186,11 @@ describe("useDesktopConnection", () => { it("sets scaleViewport and resizeSession on the RFB instance", () => { renderHook(() => - useDesktopConnection({ chatId: "chat-1", activated: true }), + useDesktopConnection({ + chatId: "chat-1", + activated: true, + scaleViewport: true, + }), ); const rfb = getLastRFBInstance(); @@ -194,6 +198,26 @@ describe("useDesktopConnection", () => { expect(rfb.resizeSession).toBe(false); }); + it("syncs scaleViewport changes to the RFB instance", () => { + const { rerender } = renderHook( + ({ scaleViewport }) => + useDesktopConnection({ + chatId: "chat-1", + activated: true, + scaleViewport, + }), + { initialProps: { scaleViewport: true } }, + ); + const rfb = getLastRFBInstance(); + expect(rfb.scaleViewport).toBe(true); + + rerender({ scaleViewport: false }); + expect(rfb.scaleViewport).toBe(false); + + rerender({ scaleViewport: true }); + expect(rfb.scaleViewport).toBe(true); + }); + it("transitions to error on securityfailure", () => { const { result } = renderHook(() => useDesktopConnection({ chatId: "chat-1", activated: true }), @@ -906,7 +930,11 @@ describe("useDesktopConnection", () => { it("forces scaleViewport on hidden→visible transition", () => { renderHook(() => - useDesktopConnection({ chatId: "chat-1", activated: true }), + useDesktopConnection({ + chatId: "chat-1", + activated: true, + scaleViewport: true, + }), ); const rfb = getLastRFBInstance(); act(() => rfb.simulateEvent("connect")); @@ -920,12 +948,12 @@ describe("useDesktopConnection", () => { act(() => observer.simulateResize(0, 0)); // Reset so we can detect re-assignment. - rfb.scaleViewport = false; + const spy = vi.spyOn(rfb, "scaleViewport", "set"); // Container visible again — should force rescale. act(() => observer.simulateResize(800, 600)); - expect(rfb.scaleViewport).toBe(true); + expect(spy).toHaveBeenCalled(); }); it("does not force scaleViewport on normal nonzero→nonzero resize", () => { diff --git a/site/src/pages/AgentsPage/hooks/useDesktopConnection.ts b/site/src/pages/AgentsPage/hooks/useDesktopConnection.ts index 8c9c7fa61c..71fb86216b 100644 --- a/site/src/pages/AgentsPage/hooks/useDesktopConnection.ts +++ b/site/src/pages/AgentsPage/hooks/useDesktopConnection.ts @@ -3,20 +3,21 @@ import { useEffect, useRef, useState } from "react"; import { toast } from "sonner"; import { watchChatDesktop } from "#/api/api"; import { useClipboard } from "#/hooks/useClipboard"; - -interface UseDesktopConnectionOptions { - chatId: string | undefined; - /** When false the hook stays dormant — no WebSocket, no RFB. */ - activated: boolean; -} - -type DesktopConnectionStatus = +export type DesktopConnectionStatus = | "idle" | "connecting" | "connected" | "disconnected" | "error"; +interface UseDesktopConnectionOptions { + chatId: string | undefined; + /** When false the hook stays dormant, no WebSocket, no RFB. */ + activated: boolean; + /** When true the viewport is scaled to fit the container. Default: false (native 100%). */ + scaleViewport?: boolean; +} + export interface UseDesktopConnectionResult { /** Current connection status. */ status: DesktopConnectionStatus; @@ -82,8 +83,10 @@ const isMacCutShortcut = (event: KeyboardEvent): boolean => { export function useDesktopConnection({ chatId, activated, + scaleViewport = false, }: UseDesktopConnectionOptions): UseDesktopConnectionResult { const [status, setStatus] = useState<DesktopConnectionStatus>("idle"); + const [hasConnected, setHasConnected] = useState(false); const [remoteClipboardText, setRemoteClipboardText] = useState<string | null>( null, @@ -228,6 +231,7 @@ export function useDesktopConnection({ offscreenContainerRef.current.style.width = "100%"; offscreenContainerRef.current.style.height = "100%"; offscreenContainerRef.current.style.position = "relative"; + offscreenContainerRef.current.style.overflow = "hidden"; const socket = watchChatDesktop(chatId); @@ -236,10 +240,20 @@ export function useDesktopConnection({ shared: true, }); - rfb.scaleViewport = true; + rfb.scaleViewport = false; rfb.resizeSession = false; rfb.focusOnClick = true; + // Override the noVNC default background (rgb(40,40,40)) + // so the letterbox margins match the app surface color + // in both light and dark themes. + const surfaceHsl = getComputedStyle(document.documentElement) + .getPropertyValue("--surface-secondary") + .trim(); + if (surfaceHsl) { + rfb.background = `hsl(${surfaceHsl})`; + } + // Per-session flags scoped to this RFB instance. // NOT refs — each doConnect() gets fresh copies so // state from a previous session cannot leak. @@ -458,7 +472,7 @@ export function useDesktopConnection({ // shrinks to 0×0. When the container becomes visible // again, noVNC may skip rescaling because it believes // the viewport size hasn't changed. Re-assigning - // scaleViewport = true forces a fresh scale pass + // scaleViewport forces a fresh scale pass // regardless. let prevContainerW = 0; let prevContainerH = 0; @@ -472,7 +486,12 @@ export function useDesktopConnection({ prevContainerW = width; prevContainerH = height; if (wasHidden && isVisible && rfbRef.current) { - rfbRef.current.scaleViewport = true; + // Re-assign the current value to force noVNC to + // recalculate the viewport. The setter triggers an + // internal rescale regardless of whether the value + // actually changed. + const current = rfbRef.current.scaleViewport; + rfbRef.current.scaleViewport = current; } }); visibilityObserver.observe(offscreenContainerRef.current); @@ -500,6 +519,12 @@ export function useDesktopConnection({ }; }, [activated, chatId, syncRemoteClipboardToLocal]); + useEffect(() => { + if (rfbInstance && rfbRef.current) { + rfbRef.current.scaleViewport = scaleViewport; + } + }, [rfbInstance, scaleViewport]); + return { status, hasConnected, diff --git a/site/src/pages/AgentsPage/hooks/useZoomShortcuts.ts b/site/src/pages/AgentsPage/hooks/useZoomShortcuts.ts new file mode 100644 index 0000000000..46d05eea2c --- /dev/null +++ b/site/src/pages/AgentsPage/hooks/useZoomShortcuts.ts @@ -0,0 +1,23 @@ +import { useEffect } from "react"; +import type { ScaleMode } from "../components/RightPanel/DesktopToolbar"; + +export function useZoomShortcuts( + setScaleMode: (mode: ScaleMode) => void, + enabled = true, +) { + useEffect(() => { + if (!enabled) return; + const handleKeyDown = (e: KeyboardEvent) => { + const mod = e.ctrlKey || e.metaKey; + if (mod && e.key === "0") { + e.preventDefault(); + setScaleMode("fit"); + } else if (mod && e.key === "1") { + e.preventDefault(); + setScaleMode("native"); + } + }; + addEventListener("keydown", handleKeyDown); + return () => removeEventListener("keydown", handleKeyDown); + }, [enabled, setScaleMode]); +} diff --git a/site/src/pages/AgentsPage/utils/usageLimitMessage.ts b/site/src/pages/AgentsPage/utils/usageLimitMessage.ts index 1986da0227..1c7f10f00f 100644 --- a/site/src/pages/AgentsPage/utils/usageLimitMessage.ts +++ b/site/src/pages/AgentsPage/utils/usageLimitMessage.ts @@ -11,9 +11,8 @@ type UsageLimitData = Partial< /** * Typed classification for errors surfaced in the agent detail view. * - "usage_limit": the user hit a spending cap (409 + valid usage data). - * - other kinds come from normalized stream/provider failures such as - * "generic", "overloaded", "rate_limit", "timeout", - * "startup_timeout", "auth", and "config". + * - other kinds come from normalized stream/provider failures. + * See ChatErrorKind for the full set. */ export type ChatDetailError = { message: string; diff --git a/site/src/pages/AuditPage/AuditLogRow/AuditLogDiff/AuditLogDiff.tsx b/site/src/pages/AuditPage/AuditLogRow/AuditLogDiff/AuditLogDiff.tsx index 7565a144e2..d3e035c33d 100644 --- a/site/src/pages/AuditPage/AuditLogRow/AuditLogDiff/AuditLogDiff.tsx +++ b/site/src/pages/AuditPage/AuditLogRow/AuditLogDiff/AuditLogDiff.tsx @@ -1,43 +1,6 @@ import type { FC } from "react"; import type { AuditDiff } from "#/api/typesGenerated"; - -const getDiffValue = (value: unknown): string => { - if (typeof value === "string") { - return `"${value}"`; - } - - if (isTimeObject(value)) { - if (!value.Valid) { - return "null"; - } - - return new Date(value.Time).toLocaleString(); - } - - if (Array.isArray(value)) { - const values = value.map((v) => getDiffValue(v)); - return `[${values.join(", ")}]`; - } - - if (value === null || value === undefined) { - return "null"; - } - - return String(value); -}; - -const isTimeObject = ( - value: unknown, -): value is { Time: string; Valid: boolean } => { - return ( - value !== null && - typeof value === "object" && - "Time" in value && - typeof value.Time === "string" && - "Valid" in value && - typeof value.Valid === "boolean" - ); -}; +import { formatAuditDiffValue } from "./auditUtils"; interface AuditLogDiffProps { diff: AuditDiff; @@ -58,7 +21,9 @@ export const AuditLogDiff: FC<AuditLogDiffProps> = ({ diff }) => { <div> {attrName}:{" "} <span className="rounded p-px bg-red-800"> - {valueDiff.secret ? "••••••••" : getDiffValue(valueDiff.old)} + {valueDiff.secret + ? "••••••••" + : formatAuditDiffValue(valueDiff.old)} </span> </div> </div> @@ -74,7 +39,9 @@ export const AuditLogDiff: FC<AuditLogDiffProps> = ({ diff }) => { <div> {attrName}:{" "} <span className="rounded p-px bg-green-800"> - {valueDiff.secret ? "••••••••" : getDiffValue(valueDiff.new)} + {valueDiff.secret + ? "••••••••" + : formatAuditDiffValue(valueDiff.new)} </span> </div> </div> diff --git a/site/src/pages/AuditPage/AuditLogRow/AuditLogDiff/auditUtils.test.ts b/site/src/pages/AuditPage/AuditLogRow/AuditLogDiff/auditUtils.test.ts index 12c6cabb41..74cca5a2c6 100644 --- a/site/src/pages/AuditPage/AuditLogRow/AuditLogDiff/auditUtils.test.ts +++ b/site/src/pages/AuditPage/AuditLogRow/AuditLogDiff/auditUtils.test.ts @@ -1,4 +1,4 @@ -import { determineGroupDiff } from "./auditUtils"; +import { determineGroupDiff, formatAuditDiffValue } from "./auditUtils"; const auditDiffForNewGroup = { id: { @@ -120,3 +120,70 @@ describe("determineAuditDiff", () => { expect(determineGroupDiff(AuditDiffForDeletedGroup)).toEqual(result); }); }); + +describe("formatAuditDiffValue", () => { + it.each([ + { name: "string", value: "hello", expected: '"hello"' }, + { + name: "string containing double quotes", + value: 'he said "hello"', + expected: '"he said \\"hello\\""', + }, + { + name: "array of primitives", + value: ["admin", "auditor"], + expected: '["admin", "auditor"]', + }, + { name: "boolean true", value: true, expected: "true" }, + { name: "boolean false", value: false, expected: "false" }, + { name: "number", value: 42, expected: "42" }, + { name: "null", value: null, expected: "null" }, + { name: "undefined", value: undefined, expected: "null" }, + { + name: "invalid SQL time", + value: { Time: "0001-01-01T00:00:00Z", Valid: false }, + expected: "null", + }, + ])("preserves current behavior for $name", ({ value, expected }) => { + expect(formatAuditDiffValue(value)).toBe(expected); + }); + + it("preserves current behavior for valid SQL time objects", () => { + const value = { Time: "2024-10-22T09:03:23.961702Z", Valid: true }; + + expect(formatAuditDiffValue(value)).toBe( + new Date(value.Time).toLocaleString(), + ); + }); + + it("formats plain objects as deterministic compact JSON", () => { + expect( + formatAuditDiffValue({ + z: ["read"], + a: { permissions: ["read"] }, + }), + ).toBe('{"a":{"permissions":["read"]},"z":["read"]}'); + }); + + it("formats chat ACL objects as deterministic compact JSON", () => { + expect( + formatAuditDiffValue({ + "user-2": { permissions: ["read"] }, + "user-1": { permissions: ["read"] }, + }), + ).toBe( + '{"user-1":{"permissions":["read"]},"user-2":{"permissions":["read"]}}', + ); + }); + + it("formats arrays containing objects without object string coercion", () => { + expect( + formatAuditDiffValue([ + { user_id: "user-2", permissions: ["read"] }, + { permissions: ["read"], user_id: "user-1" }, + ]), + ).toBe( + '[{"permissions":["read"],"user_id":"user-2"}, {"permissions":["read"],"user_id":"user-1"}]', + ); + }); +}); diff --git a/site/src/pages/AuditPage/AuditLogRow/AuditLogDiff/auditUtils.ts b/site/src/pages/AuditPage/AuditLogRow/AuditLogDiff/auditUtils.ts index 7e9033841c..881463d46e 100644 --- a/site/src/pages/AuditPage/AuditLogRow/AuditLogDiff/auditUtils.ts +++ b/site/src/pages/AuditPage/AuditLogRow/AuditLogDiff/auditUtils.ts @@ -25,27 +25,68 @@ export const determineGroupDiff = (auditLogDiff: AuditDiff): AuditDiff => { }; /** - * - * @param auditLogDiff - * @returns a diff with the 'mappings' as a JSON string. Otherwise, it is [Object object] + * Formats an audit diff value for display. Strings are quoted, nullish values + * become "null", SQL time objects are localized, arrays are recursed, and plain + * objects are serialized as compact JSON with sorted keys. */ -export const determineIdPSyncMappingDiff = ( - auditLogDiff: AuditDiff, -): AuditDiff => { - const old = auditLogDiff.mapping?.old as Record<string, string[]> | undefined; - const new_ = auditLogDiff.mapping?.new as - | Record<string, string[]> - | undefined; - if (!old || !new_) { - return auditLogDiff; +export const formatAuditDiffValue = (value: unknown): string => { + if (typeof value === "string") { + return JSON.stringify(value); } - return { - ...auditLogDiff, - mapping: { - old: JSON.stringify(old), - new: JSON.stringify(new_), - secret: auditLogDiff.mapping?.secret, - }, - }; + if (isTimeObject(value)) { + if (!value.Valid) { + return "null"; + } + + return new Date(value.Time).toLocaleString(); + } + + if (Array.isArray(value)) { + const values = value.map((v) => formatAuditDiffValue(v)); + return `[${values.join(", ")}]`; + } + + if (value === null || value === undefined) { + return "null"; + } + + if (isPlainObject(value)) { + return JSON.stringify(sortObjectKeys(value)); + } + + return String(value); +}; + +const isTimeObject = ( + value: unknown, +): value is { Time: string; Valid: boolean } => { + return ( + value !== null && + typeof value === "object" && + "Time" in value && + typeof value.Time === "string" && + "Valid" in value && + typeof value.Valid === "boolean" + ); +}; + +const isPlainObject = (value: unknown): value is Record<string, unknown> => { + return Object.prototype.toString.call(value) === "[object Object]"; +}; + +const sortObjectKeys = (value: unknown): unknown => { + if (Array.isArray(value)) { + return value.map(sortObjectKeys); + } + + if (!isPlainObject(value)) { + return value; + } + + const sorted: Record<string, unknown> = {}; + for (const key of Object.keys(value).sort()) { + sorted[key] = sortObjectKeys(value[key]); + } + return sorted; }; diff --git a/site/src/pages/AuditPage/AuditLogRow/AuditLogRow.stories.tsx b/site/src/pages/AuditPage/AuditLogRow/AuditLogRow.stories.tsx index 2d4fd6e28e..d84ef16697 100644 --- a/site/src/pages/AuditPage/AuditLogRow/AuditLogRow.stories.tsx +++ b/site/src/pages/AuditPage/AuditLogRow/AuditLogRow.stories.tsx @@ -1,4 +1,5 @@ import type { Meta, StoryObj } from "@storybook/react-vite"; +import type { AuditLog } from "#/api/typesGenerated"; import { Table, TableBody } from "#/components/Table/Table"; import { chromatic } from "#/testHelpers/chromatic"; import { @@ -187,3 +188,87 @@ export const WithConnectionType: Story = { }, }, }; + +const MockChatAuditLog: AuditLog = { + ...MockAuditLog, + resource_type: "chat", + resource_id: "c542b43f-4375-421a-a7e0-b39187e35131", + resource_target: "c542b43f", + resource_icon: "", + resource_link: "/agents/c542b43f-4375-421a-a7e0-b39187e35131", + description: "{user} updated chat {target}", + additional_fields: {}, +}; + +export const WithChatACLDiff: Story = { + parameters: { chromatic }, + args: { + auditLog: { + ...MockChatAuditLog, + id: "1d718c45-5dfb-4f24-9546-4f61fa8e3402", + action: "write", + description: "{user} updated sharing for chat {target}", + diff: { + user_acl: { + old: {}, + new: { + "9a68e35d-bf3a-43bd-8e68-130df721cc71": { + permissions: ["read"], + }, + }, + secret: false, + }, + group_acl: { + old: {}, + new: { + "6d130d81-017e-44ff-8fca-3a38623dcb14": { + permissions: ["read"], + }, + }, + secret: false, + }, + }, + }, + defaultIsDiffOpen: true, + }, +}; + +export const WithArchivedChatDescription: Story = { + args: { + auditLog: { + ...MockChatAuditLog, + id: "57329396-084a-4074-9930-385a7eed858a", + action: "write", + description: "{user} archived chat {target}", + diff: { + archived: { + old: false, + new: true, + secret: false, + }, + }, + }, + }, +}; + +export const WithUpdatedChatSharingDescription: Story = { + args: { + auditLog: { + ...MockChatAuditLog, + id: "8f26cabf-8867-4d2f-942d-77e759a16c1c", + action: "write", + description: "{user} updated sharing for chat {target}", + diff: { + user_acl: { + old: {}, + new: { + "9a68e35d-bf3a-43bd-8e68-130df721cc71": { + permissions: ["read"], + }, + }, + secret: false, + }, + }, + }, + }, +}; diff --git a/site/src/pages/AuditPage/AuditLogRow/AuditLogRow.tsx b/site/src/pages/AuditPage/AuditLogRow/AuditLogRow.tsx index ea0ccc37a8..448aa393af 100644 --- a/site/src/pages/AuditPage/AuditLogRow/AuditLogRow.tsx +++ b/site/src/pages/AuditPage/AuditLogRow/AuditLogRow.tsx @@ -22,10 +22,7 @@ import { cn } from "#/utils/cn"; import { buildReasonLabels } from "#/utils/workspace"; import { AuditLogDescription } from "./AuditLogDescription/AuditLogDescription"; import { AuditLogDiff } from "./AuditLogDiff/AuditLogDiff"; -import { - determineGroupDiff, - determineIdPSyncMappingDiff, -} from "./AuditLogDiff/auditUtils"; +import { determineGroupDiff } from "./AuditLogDiff/auditUtils"; interface AuditLogRowProps { auditLog: AuditLog; @@ -53,14 +50,6 @@ export const AuditLogRow: FC<AuditLogRowProps> = ({ auditDiff = determineGroupDiff(auditLog.diff); } - if ( - auditLog.resource_type === "idp_sync_settings_organization" || - auditLog.resource_type === "idp_sync_settings_group" || - auditLog.resource_type === "idp_sync_settings_role" - ) { - auditDiff = determineIdPSyncMappingDiff(auditLog.diff); - } - const toggle = () => { if (shouldDisplayDiff) { setIsDiffOpen((v) => !v); diff --git a/site/src/pages/CreateWorkspacePage/AutoCreateConsentDialog.stories.tsx b/site/src/pages/CreateWorkspacePage/AutoCreateConsentDialog.stories.tsx index 842d25aee0..8bd0dd78bc 100644 --- a/site/src/pages/CreateWorkspacePage/AutoCreateConsentDialog.stories.tsx +++ b/site/src/pages/CreateWorkspacePage/AutoCreateConsentDialog.stories.tsx @@ -1,5 +1,5 @@ import type { Meta, StoryObj } from "@storybook/react-vite"; -import { fn } from "storybook/test"; +import { expect, fn, screen } from "storybook/test"; import { AutoCreateConsentDialog } from "./AutoCreateConsentDialog"; const meta: Meta<typeof AutoCreateConsentDialog> = { @@ -77,6 +77,19 @@ export const WithLongValues: Story = { }, }; +export const WithPreset: Story = { + args: { + presetName: "gpu-large", + autofillParameters: [ + { name: "instance_type", value: "g6.4xlarge", source: "url" }, + ], + }, + play: async () => { + expect(screen.getAllByText("Preset:").length).toBeGreaterThan(0); + expect(screen.getAllByText("gpu-large").length).toBeGreaterThan(0); + }, +}; + export const NoParameters: Story = { args: { autofillParameters: [], diff --git a/site/src/pages/CreateWorkspacePage/AutoCreateConsentDialog.tsx b/site/src/pages/CreateWorkspacePage/AutoCreateConsentDialog.tsx index 58c3e552fc..8dfc900449 100644 --- a/site/src/pages/CreateWorkspacePage/AutoCreateConsentDialog.tsx +++ b/site/src/pages/CreateWorkspacePage/AutoCreateConsentDialog.tsx @@ -14,6 +14,7 @@ import type { AutofillBuildParameter } from "#/utils/richParameters"; interface AutoCreateConsentDialogProps { open: boolean; autofillParameters: AutofillBuildParameter[]; + presetName?: string; onConfirm: () => void; onDeny: () => void; } @@ -21,6 +22,7 @@ interface AutoCreateConsentDialogProps { export const AutoCreateConsentDialog: FC<AutoCreateConsentDialogProps> = ({ open, autofillParameters, + presetName, onConfirm, onDeny, }) => { @@ -43,6 +45,17 @@ export const AutoCreateConsentDialog: FC<AutoCreateConsentDialogProps> = ({ </DialogDescription> </DialogHeader> + {presetName && ( + <div className="flex min-w-0 flex-col gap-2"> + <span className="text-sm font-semibold text-content-primary"> + Preset: + </span> + <code className="block whitespace-pre overflow-x-auto"> + {presetName} + </code> + </div> + )} + {autofillParameters.length > 0 && ( <div className="flex min-w-0 flex-col gap-2"> <span className="text-sm font-semibold text-content-primary"> diff --git a/site/src/pages/CreateWorkspacePage/CreateWorkspacePage.test.tsx b/site/src/pages/CreateWorkspacePage/CreateWorkspacePage.test.tsx index 97547f9566..1b7963cf2a 100644 --- a/site/src/pages/CreateWorkspacePage/CreateWorkspacePage.test.tsx +++ b/site/src/pages/CreateWorkspacePage/CreateWorkspacePage.test.tsx @@ -2,7 +2,7 @@ import { screen, waitFor, within } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; import { act } from "react"; import { API } from "#/api/api"; -import type { DynamicParametersResponse } from "#/api/typesGenerated"; +import type { DynamicParametersResponse, Preset } from "#/api/typesGenerated"; import { MockDropdownParameter, MockDynamicParametersResponse, @@ -11,6 +11,7 @@ import { MockPreviewParameter, MockSliderParameter, MockTemplate, + MockTemplateVersion, MockTemplateVersionExternalAuthGithub, MockTemplateVersionExternalAuthGithubAuthenticated, MockUserOwner, @@ -40,10 +41,24 @@ describe("CreateWorkspacePage", () => { }); }; + const mockGpuPreset: Preset = { + ID: "preset-gpu", + Name: "gpu-large", + Parameters: [ + { Name: "instance_type", Value: "t3.medium" }, + { Name: "cpu_count", Value: "4" }, + ], + Default: false, + DesiredPrebuildInstances: null, + Description: "GPU Large preset", + Icon: "", + }; + beforeEach(() => { vi.clearAllMocks(); vi.spyOn(API, "getTemplate").mockResolvedValue(MockTemplate); + vi.spyOn(API, "getTemplateVersion").mockResolvedValue(MockTemplateVersion); vi.spyOn(API, "getTemplateVersionExternalAuth").mockResolvedValue([]); vi.spyOn(API, "getTemplateVersionPresets").mockResolvedValue([]); vi.spyOn(API, "createWorkspace").mockResolvedValue(MockWorkspace); @@ -446,7 +461,7 @@ describe("CreateWorkspacePage", () => { `/templates/${MockTemplate.name}/workspace?mode=auto`, ); - // Consent dialog appears for mode=auto — confirm to proceed. + // Consent dialog appears for mode=auto. Confirm to proceed. const confirmButton = await screen.findByRole("button", { name: /confirm and create/i, }); @@ -550,6 +565,158 @@ describe("CreateWorkspacePage", () => { }); }); + describe("URL Presets", () => { + it("resolves a preset from the URL and selects it in the form", async () => { + vi.spyOn(API, "getTemplateVersionPresets").mockResolvedValue([ + mockGpuPreset, + ]); + + renderCreateWorkspacePage( + `/templates/${MockTemplate.name}/workspace?preset=gpu-large`, + ); + await waitForLoaderToBeRemoved(); + + expect( + screen.getByRole("button", { name: /gpu-large/i }), + ).toBeInTheDocument(); + }); + + it("resolves a preset against the pinned template version", async () => { + const getTemplateVersionPresetsSpy = vi + .spyOn(API, "getTemplateVersionPresets") + .mockResolvedValue([mockGpuPreset]); + + renderCreateWorkspacePage( + `/templates/${MockTemplate.name}/workspace?version=custom-version&preset=gpu-large`, + ); + + await waitFor(() => { + expect(getTemplateVersionPresetsSpy).toHaveBeenCalledWith( + "custom-version", + ); + }); + }); + + it("falls back to form mode when auto-create cannot resolve the preset", async () => { + vi.spyOn(API, "getTemplateVersionExternalAuth").mockResolvedValue([ + MockTemplateVersionExternalAuthGithubAuthenticated, + ]); + vi.spyOn(API, "getTemplateVersionPresets").mockResolvedValue([ + mockGpuPreset, + ]); + + renderCreateWorkspacePage( + `/templates/${MockTemplate.name}/workspace?mode=auto&preset=missing`, + ); + await waitForLoaderToBeRemoved(); + + expect( + screen.queryByRole("button", { name: /confirm and create/i }), + ).not.toBeInTheDocument(); + expect( + screen.getByText(/auto-creation has been disabled/i), + ).toBeInTheDocument(); + expect( + screen.getByText( + /preset "missing" not found on template version "test-version"/i, + ), + ).toBeInTheDocument(); + expect(API.createWorkspace).not.toHaveBeenCalled(); + }); + + it("falls back to form mode when presets fail to load", async () => { + vi.spyOn(API, "getTemplateVersionExternalAuth").mockResolvedValue([ + MockTemplateVersionExternalAuthGithubAuthenticated, + ]); + vi.spyOn(API, "getTemplateVersionPresets").mockRejectedValue( + new Error("presets unavailable"), + ); + + renderCreateWorkspacePage( + `/templates/${MockTemplate.name}/workspace?mode=auto&preset=gpu-large`, + ); + await waitForLoaderToBeRemoved(); + + expect( + screen.queryByRole("button", { name: /confirm and create/i }), + ).not.toBeInTheDocument(); + expect( + screen.getByText(/auto-creation has been disabled/i), + ).toBeInTheDocument(); + expect( + screen.getByText(/failed to load presets: presets unavailable/i), + ).toBeInTheDocument(); + expect(API.createWorkspace).not.toHaveBeenCalled(); + }); + + it("uses preset parameters instead of param values", async () => { + vi.spyOn(API, "getTemplateVersionPresets").mockResolvedValue([ + mockGpuPreset, + ]); + + renderCreateWorkspacePage( + `/templates/${MockTemplate.name}/workspace?preset=gpu-large¶m.instance_type=t3.small¶m.cpu_count=99`, + ); + await waitForLoaderToBeRemoved(); + + expect(screen.getAllByText(/param\.\*/i).length).toBeGreaterThan(0); + + const nameInput = screen.getByRole("textbox", { + name: /workspace name/i, + }); + await userEvent.type(nameInput, "preset-workspace"); + + await userEvent.click( + screen.getByRole("button", { name: /create workspace/i }), + ); + + await waitFor(() => { + expect(API.createWorkspace).toHaveBeenCalledWith( + "test-user", + expect.objectContaining({ + template_version_preset_id: mockGpuPreset.ID, + rich_parameter_values: expect.arrayContaining([ + expect.objectContaining({ + name: "instance_type", + value: "t3.medium", + }), + expect.objectContaining({ name: "cpu_count", value: "4" }), + ]), + }), + ); + }); + }); + + it("auto-creates with the preset ID after the preset resolves", async () => { + vi.spyOn(API, "getTemplateVersionExternalAuth").mockResolvedValue([ + MockTemplateVersionExternalAuthGithubAuthenticated, + ]); + vi.spyOn(API, "getTemplateVersionPresets").mockResolvedValue([ + mockGpuPreset, + ]); + + renderCreateWorkspacePage( + `/templates/${MockTemplate.name}/workspace?mode=auto&preset=gpu-large&name=preset-workspace`, + ); + + const confirmButton = await screen.findByRole("button", { + name: /confirm and create/i, + }); + await userEvent.click(confirmButton); + + await waitFor(() => { + expect(API.createWorkspace).toHaveBeenCalledWith( + "me", + expect.objectContaining({ + name: "preset-workspace", + template_version_preset_id: mockGpuPreset.ID, + rich_parameter_values: [], + }), + ); + }); + }); + }); + describe("Navigation", () => { it("navigates to workspace after successful creation", async () => { const { router } = renderCreateWorkspacePage(); diff --git a/site/src/pages/CreateWorkspacePage/CreateWorkspacePage.tsx b/site/src/pages/CreateWorkspacePage/CreateWorkspacePage.tsx index 5dc9ab5f70..74a8356bcd 100644 --- a/site/src/pages/CreateWorkspacePage/CreateWorkspacePage.tsx +++ b/site/src/pages/CreateWorkspacePage/CreateWorkspacePage.tsx @@ -60,6 +60,7 @@ const CreateWorkspacePage: FC = () => { const customVersionId = searchParams.get("version") ?? undefined; const defaultName = searchParams.get("name"); const disabledParams = searchParams.get("disable_params")?.split(","); + const presetName = searchParams.get("preset") || undefined; const [mode, setMode] = useState(() => getWorkspaceMode(searchParams)); const [autoCreateConsented, setAutoCreateConsented] = useState(false); const [autoCreateError, setAutoCreateError] = @@ -76,9 +77,12 @@ const CreateWorkspacePage: FC = () => { const templateQuery = useQuery( templateByName(organizationName, templateName), ); + const realizedVersionId = + customVersionId ?? templateQuery.data?.active_version_id; + const templateVersionPresetsQuery = useQuery({ - ...templateVersionPresets(templateQuery.data?.active_version_id ?? ""), - enabled: Boolean(templateQuery.data), + ...templateVersionPresets(realizedVersionId ?? ""), + enabled: realizedVersionId !== undefined, }); const permissionsQuery = useQuery({ ...checkAuthorization({ @@ -89,15 +93,68 @@ const CreateWorkspacePage: FC = () => { }), enabled: Boolean(templateQuery.data), }); - const realizedVersionId = - customVersionId ?? templateQuery.data?.active_version_id; const templateVersionQuery = useQuery({ ...templateVersion(realizedVersionId ?? ""), enabled: realizedVersionId !== undefined, }); - const autofillParameters = getAutofillParameters(searchParams); + const effectivePresetName = mode === "duplicate" ? undefined : presetName; + + const presets = templateVersionPresetsQuery.data ?? []; + + const urlPresetResult = useMemo(() => { + if (!effectivePresetName) return { preset: undefined, error: undefined }; + + if (templateVersionPresetsQuery.isError) { + return { + preset: undefined, + error: `Failed to load presets: ${templateVersionPresetsQuery.error?.message ?? "unknown error"}. Please try refreshing the page.`, + }; + } + + if (!templateVersionPresetsQuery.isSuccess) { + return { preset: undefined, error: undefined }; // Still loading + } + + const found = presets.find((p) => p.Name === effectivePresetName); + if (!found) { + const versionLabel = templateVersionQuery.data?.name ?? realizedVersionId; + return { + preset: undefined, + error: `Preset "${effectivePresetName}" not found on template version "${versionLabel}". Check that the preset name matches exactly (names are case-sensitive).`, + }; + } + return { preset: found, error: undefined }; + }, [ + effectivePresetName, + presets, + templateVersionPresetsQuery.isSuccess, + templateVersionPresetsQuery.isError, + templateVersionPresetsQuery.error, + realizedVersionId, + templateVersionQuery.data?.name, + ]); + + const urlAutofillParameters = useMemo( + () => getAutofillParameters(searchParams), + [searchParams], + ); + const autofillParameters = useMemo(() => { + if (!urlPresetResult.preset) return urlAutofillParameters; + + const presetParams: AutofillBuildParameter[] = + urlPresetResult.preset.Parameters.map((p) => ({ + name: p.Name, + value: p.Value, + source: "url" as const, + })); + + return presetParams; + }, [urlPresetResult.preset, urlAutofillParameters]); + + const hasIgnoredUrlParams = + urlAutofillParameters.length > 0 && urlPresetResult.preset !== undefined; const sendMessage = ( formValues: Record<string, string>, @@ -227,10 +284,11 @@ const CreateWorkspacePage: FC = () => { const newWorkspace = await autoCreateWorkspaceMutation.mutateAsync({ organizationId, templateName, - buildParameters: autofillParameters, + buildParameters: urlPresetResult.preset ? [] : autofillParameters, workspaceName: defaultName ?? generateWorkspaceName(), templateVersionId: realizedVersionId, match: searchParams.get("match"), + templateVersionPresetId: urlPresetResult.preset?.ID, }); onCreateWorkspace(newWorkspace); @@ -244,11 +302,22 @@ const CreateWorkspacePage: FC = () => { externalAuth?.every((auth) => auth.optional || auth.authenticated), ); + const presetResolved = + !effectivePresetName || + (templateVersionPresetsQuery.isSuccess && + urlPresetResult.preset !== undefined); + let autoCreateReady = - mode === "auto" && hasAllRequiredExternalAuth && autoCreateConsented; + mode === "auto" && + hasAllRequiredExternalAuth && + autoCreateConsented && + presetResolved; const showAutoCreateConsent = - mode === "auto" && !autoCreateConsented && !autoCreateError; + mode === "auto" && + !autoCreateConsented && + !autoCreateError && + presetResolved; // `mode=auto` was set, but a prerequisite has failed, and so auto-mode should be abandoned. if ( @@ -275,6 +344,23 @@ const CreateWorkspacePage: FC = () => { }); } + if ( + mode === "auto" && + hasAllRequiredExternalAuth && + effectivePresetName && + ((templateVersionPresetsQuery.isSuccess && !urlPresetResult.preset) || + templateVersionPresetsQuery.isError) + ) { + setMode("form"); + autoCreateReady = false; + setAutoCreateError({ + message: "Auto-creation has been disabled.", + detail: + urlPresetResult.error ?? + "The requested preset could not be resolved. Please check the preset value before continuing.", + }); + } + useEffect(() => { if (autoCreateReady) { void automateWorkspaceCreation(); @@ -293,7 +379,10 @@ const CreateWorkspacePage: FC = () => { isLoadingFormData || isLoadingExternalAuth || autoCreateReady || - (!latestResponse && !wsError); + (!latestResponse && !wsError) || + (effectivePresetName && + !templateVersionPresetsQuery.isSuccess && + !templateVersionPresetsQuery.isError); return ( <> @@ -301,6 +390,7 @@ const CreateWorkspacePage: FC = () => { <AutoCreateConsentDialog open={showAutoCreateConsent} + presetName={effectivePresetName} autofillParameters={autofillParameters} onConfirm={() => setAutoCreateConsented(true)} onDeny={() => setMode("form")} @@ -336,7 +426,14 @@ const CreateWorkspacePage: FC = () => { hasAllRequiredExternalAuth={hasAllRequiredExternalAuth} permissions={permissionsQuery.data as CreateWorkspacePermissions} parameters={sortedParams} - presets={templateVersionPresetsQuery.data ?? []} + presets={presets} + urlPreset={urlPresetResult.preset} + urlPresetError={ + autoCreateError?.detail === urlPresetResult.error + ? undefined + : urlPresetResult.error + } + hasIgnoredUrlParams={hasIgnoredUrlParams} creatingWorkspace={createWorkspaceMutation.isPending} sendMessage={sendMessage} onCancel={() => { diff --git a/site/src/pages/CreateWorkspacePage/CreateWorkspacePageView.stories.tsx b/site/src/pages/CreateWorkspacePage/CreateWorkspacePageView.stories.tsx index b689d72bd0..ea10e83b2c 100644 --- a/site/src/pages/CreateWorkspacePage/CreateWorkspacePageView.stories.tsx +++ b/site/src/pages/CreateWorkspacePage/CreateWorkspacePageView.stories.tsx @@ -1,7 +1,7 @@ import type { Meta, StoryObj } from "@storybook/react-vite"; import { expect, screen, within } from "storybook/test"; import { DetailedError } from "#/api/errors"; -import type { PreviewParameter } from "#/api/typesGenerated"; +import type { Preset, PreviewParameter } from "#/api/typesGenerated"; import { chromatic } from "#/testHelpers/chromatic"; import { MockTemplate, MockUserOwner } from "#/testHelpers/entities"; import { CreateWorkspacePageView } from "./CreateWorkspacePageView"; @@ -277,6 +277,39 @@ const parameterTextarea: PreviewParameter = { ephemeral: false, }; +const gpuLargePreset: Preset = { + ID: "preset-1", + Name: "GPU Large", + Description: "GPU Large preset", + Parameters: [ + { Name: "instance_type", Value: "t3.large" }, + { Name: "enable_gpu", Value: "true" }, + ], + Default: false, + DesiredPrebuildInstances: null, + Icon: "/emojis/1f4bb.png", +}; + +const cpuSmallPreset: Preset = { + ID: "preset-2", + Name: "CPU Small", + Description: "CPU Small preset", + Parameters: [{ Name: "instance_type", Value: "t3.micro" }], + Default: false, + DesiredPrebuildInstances: null, + Icon: "/emojis/1f4bc.png", +}; + +const urlPreset: Preset = { + ID: "preset-url", + Name: "URL Preset", + Description: "The URL-specified preset", + Parameters: [{ Name: "instance_type", Value: "t3.large" }], + Default: false, + DesiredPrebuildInstances: null, + Icon: "/emojis/1f534.png", +}; + const parameterCheckbox: PreviewParameter = { name: "auto_stop", display_name: "Auto-stop", @@ -356,3 +389,70 @@ export const WithPresets: Story = { parameters: [parameterInput, parameterDropdown], }, }; + +export const WithUrlPreset: Story = { + args: { + presets: [gpuLargePreset, cpuSmallPreset], + urlPreset: gpuLargePreset, + parameters: [parameterDropdown, parameterSwitch], + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + expect( + canvas.getByRole("button", { name: /GPU Large/i }), + ).toBeInTheDocument(); + }, +}; + +export const WithUrlPresetNotFound: Story = { + args: { + presets: [gpuLargePreset], + urlPresetError: + 'Preset "gpu-large" not found on template version "test-version". Check that the preset name matches exactly (names are case-sensitive).', + parameters: [parameterDropdown], + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + expect( + canvas.getByText(/Preset "gpu-large" not found on template version/i), + ).toBeVisible(); + }, +}; + +export const WithUrlPresetAndIgnoredParams: Story = { + args: { + presets: [gpuLargePreset], + urlPreset: gpuLargePreset, + hasIgnoredUrlParams: true, + parameters: [parameterDropdown], + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + expect(canvas.getAllByText(/param\.\*/i).length).toBeGreaterThan(0); + }, +}; + +export const WithUrlPresetOverridesDefault: Story = { + args: { + presets: [ + { + ID: "preset-default", + Name: "Default Preset", + Description: "The default preset", + Parameters: [{ Name: "instance_type", Value: "t3.micro" }], + Default: true, + DesiredPrebuildInstances: null, + Icon: "/emojis/1f7e2.png", + }, + urlPreset, + ], + urlPreset, + parameters: [parameterDropdown], + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + expect( + canvas.getByRole("button", { name: /URL Preset/i }), + ).toBeInTheDocument(); + }, +}; diff --git a/site/src/pages/CreateWorkspacePage/CreateWorkspacePageView.tsx b/site/src/pages/CreateWorkspacePage/CreateWorkspacePageView.tsx index 167b93838d..fb12e3fc9a 100644 --- a/site/src/pages/CreateWorkspacePage/CreateWorkspacePageView.tsx +++ b/site/src/pages/CreateWorkspacePage/CreateWorkspacePageView.tsx @@ -68,11 +68,14 @@ interface CreateWorkspacePageViewProps { externalAuth: TypesGen.TemplateVersionExternalAuth[]; externalAuthPollingState: ExternalAuthPollingState; hasAllRequiredExternalAuth: boolean; + hasIgnoredUrlParams?: boolean; mode: CreateWorkspaceMode; parameters: PreviewParameter[]; permissions: CreateWorkspacePermissions; presets: TypesGen.Preset[]; template: TypesGen.Template; + urlPreset?: TypesGen.Preset; + urlPresetError?: string; versionId?: string; versionName?: string; onCancel: () => void; @@ -99,11 +102,14 @@ export const CreateWorkspacePageView: FC<CreateWorkspacePageViewProps> = ({ externalAuth, externalAuthPollingState, hasAllRequiredExternalAuth, + hasIgnoredUrlParams, mode, parameters, permissions, presets = [], template, + urlPreset, + urlPresetError, versionId, versionName, onSubmit, @@ -202,6 +208,15 @@ export const CreateWorkspacePageView: FC<CreateWorkspacePageViewProps> = ({ })), ]; setPresetOptions(options); + + // URL preset takes precedence over default preset. + if (urlPreset) { + const idx = presets.findIndex((p) => p.ID === urlPreset.ID) + 1; + setSelectedPresetIndex(idx); + form.setFieldValue("template_version_preset_id", urlPreset.ID); + return; + } + const defaultPreset = presets.find((p) => p.Default); if (defaultPreset) { const idx = presets.indexOf(defaultPreset) + 1; // +1 for "None" @@ -211,7 +226,7 @@ export const CreateWorkspacePageView: FC<CreateWorkspacePageViewProps> = ({ setSelectedPresetIndex(0); // Explicitly set to "None" form.setFieldValue("template_version_preset_id", undefined); } - }, [presets, form.setFieldValue]); + }, [presets, form.setFieldValue, urlPreset]); const [presetParameterNames, setPresetParameterNames] = useState<string[]>( [], @@ -451,6 +466,20 @@ export const CreateWorkspacePageView: FC<CreateWorkspacePageViewProps> = ({ > {Boolean(error) && <ErrorAlert error={error} />} + {urlPresetError && ( + <Alert severity="warning" dismissible> + {urlPresetError} + </Alert> + )} + + {hasIgnoredUrlParams && urlPreset && ( + <Alert severity="info" dismissible> + Preset selected. <code>param.*</code> URL parameters have been + ignored. Use either <code>preset</code> or <code>param.*</code>, + not both. + </Alert> + )} + {mode === "duplicate" && ( <Alert severity="info" @@ -713,7 +742,10 @@ export const CreateWorkspacePageView: FC<CreateWorkspacePageViewProps> = ({ } disabled={isDisabled} isPreset={isPresetParameter} - autofill={autofillByName[parameter.name] !== undefined} + autofill={ + !isPresetParameter && + autofillByName[parameter.name] !== undefined + } value={formValue} /> ); diff --git a/site/src/router.tsx b/site/src/router.tsx index 96871176c6..6ac39cec75 100644 --- a/site/src/router.tsx +++ b/site/src/router.tsx @@ -351,6 +351,9 @@ const ProvisionerJobsPage = lazy( const AgentsPage = lazy(() => import("./pages/AgentsPage/AgentsPage")); const AgentChatPage = lazy(() => import("./pages/AgentsPage/AgentChatPage")); const AgentEmbedPage = lazy(() => import("./pages/AgentsPage/AgentEmbedPage")); +const DesktopPopoutPage = lazy( + () => import("./pages/AgentsPage/DesktopPopoutPage"), +); const AgentCreatePage = lazy( () => import("./pages/AgentsPage/AgentCreatePage"), ); @@ -810,6 +813,18 @@ export const router = createBrowserRouter( } /> </Route> + <Route + path="/agents/:agentId/desktop" + element={ + <Suspense + fallback={ + <div className="flex h-screen w-screen items-center justify-center" /> + } + > + <DesktopPopoutPage /> + </Suspense> + } + /> </Route> <Route diff --git a/testutil/expecter/expecter.go b/testutil/expecter/expecter.go new file mode 100644 index 0000000000..333e9a18ab --- /dev/null +++ b/testutil/expecter/expecter.go @@ -0,0 +1,363 @@ +package expecter + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "regexp" + "slices" + "strings" + "testing" + "time" + "unicode/utf8" + + "github.com/acarl005/stripansi" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/testutil" + "github.com/coder/serpent" +) + +func New(t *testing.T, r io.Reader, name string) *Expecter { + // Use pipe for logging. + logDone := make(chan struct{}) + logr, logw := io.Pipe() + + // Write to log and output buffer. + copyDone := make(chan struct{}) + out := newStdbuf() + w := io.MultiWriter(logw, out) + + ex := &Expecter{ + t: t, + out: out, + name: atomic.NewString(name), + + runeReader: bufio.NewReaderSize(out, utf8.UTFMax), + logDone: logDone, + copyDone: copyDone, + logr: logr, + logw: logw, + } + + go func() { + defer close(copyDone) + _, err := io.Copy(w, r) + ex.Logf("copy done: %v", err) + ex.Logf("closing out") + err = out.closeErr(err) + ex.Logf("closed out: %v", err) + }() + + // Log all output as part of test for easier debugging on errors. + go func() { + defer close(logDone) + s := bufio.NewScanner(logr) + for s.Scan() { + ex.Logf("%q", stripansi.Strip(s.Text())) + } + // Surface non-EOF scanner errors; otherwise they're invisible. + if err := s.Err(); err != nil { + ex.Logf("log scanner stopped: %v", err) + } + }() + + return ex +} + +func NewAttachedToInvocation(t *testing.T, invocation *serpent.Invocation) *Expecter { + r, w := io.Pipe() + invocation.Stdout = w + invocation.Stderr = w + e := New(t, r, "cmd") + + t.Cleanup(func() { + // Serpent doesn't handle closing stdout after running the Invocation; normally the OS does that automatically when + // the process exits. Close it here at the end of the test to ensure we don't leak goroutines reading from the + // stdout/stderr. + _ = w.Close() + e.Close("test end") + }) + return e +} + +type Expecter struct { + t *testing.T + out *stdbuf + name *atomic.String + + runeReader *bufio.Reader + copyDone, logDone chan struct{} + logr, logw io.Closer +} + +// Rename the expecter. Make sure you set this before anything starts writing to the +// stream, or it may not be named consistently. +func (e *Expecter) Rename(name string) { + e.name.Store(name) +} + +func (e *Expecter) Close(reason string) { + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + e.Logf("closing expecter: %s", reason) + + // Caller needs to have closed the stream so that copying can complete + select { + case <-ctx.Done(): + e.fatalf("close", "copy did not close in time") + return + case <-e.copyDone: + } + + e.logClose("logw", e.logw) + e.logClose("logr", e.logr) + select { + case <-ctx.Done(): + e.fatalf("close", "log pipe did not close in time") + return + case <-e.logDone: + } + + e.Logf("closed expecter") +} + +func (e *Expecter) logClose(name string, c io.Closer) { + e.Logf("closing %s", name) + err := c.Close() + e.Logf("closed %s: %v", name, err) +} + +// Deprecated: use ExpectMatchContext instead. +// This uses a background context, so will not respect the test's context. +func (e *Expecter) ExpectMatch(str string) string { + return e.expectMatchContextFunc(str, e.ExpectMatchContext) +} + +func (e *Expecter) ExpectRegexMatch(str string) string { + return e.expectMatchContextFunc(str, e.ExpectRegexMatchContext) +} + +func (e *Expecter) expectMatchContextFunc(str string, fn func(ctx context.Context, str string) string) string { + e.t.Helper() + + timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + return fn(timeout, str) +} + +// TODO(mafredri): Rename this to ExpectMatch when refactoring. +func (e *Expecter) ExpectMatchContext(ctx context.Context, str string) string { + return e.expectMatcherFunc(ctx, str, strings.Contains) +} + +func (e *Expecter) ExpectRegexMatchContext(ctx context.Context, str string) string { + return e.expectMatcherFunc(ctx, str, func(src, pattern string) bool { + return regexp.MustCompile(pattern).MatchString(src) + }) +} + +func (e *Expecter) expectMatcherFunc(ctx context.Context, str string, fn func(src, pattern string) bool) string { + e.t.Helper() + + var buffer bytes.Buffer + err := e.doMatchWithDeadline(ctx, "ExpectMatchContext", func(rd *bufio.Reader) error { + for { + r, _, err := rd.ReadRune() + if err != nil { + return err + } + _, err = buffer.WriteRune(r) + if err != nil { + return err + } + if fn(buffer.String(), str) { + return nil + } + } + }) + if err != nil { + e.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String()) + return "" + } + e.Logf("matched %q = %q", str, buffer.String()) + return buffer.String() +} + +// ExpectNoMatchBefore validates that `match` does not occur before `before`. +func (e *Expecter) ExpectNoMatchBefore(ctx context.Context, match, before string) string { + e.t.Helper() + + var buffer bytes.Buffer + err := e.doMatchWithDeadline(ctx, "ExpectNoMatchBefore", func(rd *bufio.Reader) error { + for { + r, _, err := rd.ReadRune() + if err != nil { + return err + } + _, err = buffer.WriteRune(r) + if err != nil { + return err + } + + if strings.Contains(buffer.String(), match) { + return xerrors.Errorf("found %q before %q", match, before) + } + + if strings.Contains(buffer.String(), before) { + return nil + } + } + }) + if err != nil { + e.fatalf("read error", "%v (wanted no %q before %q; got %q)", err, match, before, buffer.String()) + return "" + } + e.Logf("matched %q = %q", before, stripansi.Strip(buffer.String())) + return buffer.String() +} + +func (e *Expecter) Peek(ctx context.Context, n int) []byte { + e.t.Helper() + + var out []byte + err := e.doMatchWithDeadline(ctx, "Peek", func(rd *bufio.Reader) error { + var err error + out, err = rd.Peek(n) + return err + }) + if err != nil { + e.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out) + return nil + } + e.Logf("peeked %d/%d bytes = %q", len(out), n, out) + return slices.Clone(out) +} + +//nolint:govet // We don't care about conforming to ReadRune() (rune, int, error). +func (e *Expecter) ReadRune(ctx context.Context) rune { + e.t.Helper() + + var r rune + err := e.doMatchWithDeadline(ctx, "ReadRune", func(rd *bufio.Reader) error { + var err error + r, _, err = rd.ReadRune() + return err + }) + if err != nil { + e.fatalf("read error", "%v (wanted rune; got %q)", err, r) + return 0 + } + e.Logf("matched rune = %q", r) + return r +} + +func (e *Expecter) ReadLine(ctx context.Context) string { + e.t.Helper() + + var buffer bytes.Buffer + err := e.doMatchWithDeadline(ctx, "ReadLine", func(rd *bufio.Reader) error { + for { + r, _, err := rd.ReadRune() + if err != nil { + return err + } + if r == '\n' { + return nil + } + if r == '\r' { + // Peek the next rune to see if it's an LF and then consume + // it. + + // Unicode code points can be up to 4 bytes, but the + // ones we're looking for are only 1 byte. + b, _ := rd.Peek(1) + if len(b) == 0 { + return nil + } + + r, _ = utf8.DecodeRune(b) + if r == '\n' { + _, _, err = rd.ReadRune() + if err != nil { + return err + } + } + + return nil + } + + _, err = buffer.WriteRune(r) + if err != nil { + return err + } + } + }) + if err != nil { + e.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String()) + return "" + } + e.Logf("matched newline = %q", buffer.String()) + return buffer.String() +} + +func (e *Expecter) ReadAll() []byte { + e.t.Helper() + return e.out.ReadAll() +} + +func (e *Expecter) doMatchWithDeadline(ctx context.Context, name string, fn func(*bufio.Reader) error) error { + e.t.Helper() + + // A timeout is mandatory, caller can decide by passing a context + // that times out. + if _, ok := ctx.Deadline(); !ok { + timeout := testutil.WaitMedium + e.Logf("%s ctx has no deadline, using %s", name, timeout) + var cancel context.CancelFunc + //nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*. + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + + match := make(chan error, 1) + go func() { + defer close(match) + match <- fn(e.runeReader) + }() + select { + case err := <-match: + return err + case <-ctx.Done(): + // Ensure goroutine is cleaned up before test exit, do not call + // (*outExpecter).close here to let the caller decide. + _ = e.out.Close() + <-match + + return xerrors.Errorf("match deadline exceeded: %w", ctx.Err()) + } +} + +func (e *Expecter) Logf(format string, args ...interface{}) { + e.t.Helper() + + // Match regular logger timestamp format, we seem to be logging in + // UTC in other places as well, so match here. + e.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), e.name.Load(), fmt.Sprintf(format, args...)) +} + +func (e *Expecter) fatalf(reason string, format string, args ...interface{}) { + e.t.Helper() + + // Ensure the message is part of the normal log stream before + // failing the test. + e.Logf("%s: %s", reason, fmt.Sprintf(format, args...)) + + require.FailNowf(e.t, reason, format, args...) +} diff --git a/testutil/expecter/stdbuf.go b/testutil/expecter/stdbuf.go new file mode 100644 index 0000000000..092f401d1e --- /dev/null +++ b/testutil/expecter/stdbuf.go @@ -0,0 +1,119 @@ +package expecter + +import ( + "bytes" + "io" + "sync" + + "golang.org/x/xerrors" +) + +// stdbuf is like a buffered stdout, it buffers writes until read. +type stdbuf struct { + r io.Reader + + mu sync.Mutex // Protects following. + b []byte + more chan struct{} + err error +} + +func newStdbuf() *stdbuf { + return &stdbuf{more: make(chan struct{}, 1)} +} + +func (b *stdbuf) ReadAll() []byte { + b.mu.Lock() + defer b.mu.Unlock() + + if b.err != nil { + return nil + } + p := append([]byte(nil), b.b...) + b.b = b.b[len(b.b):] + return p +} + +func (b *stdbuf) Read(p []byte) (int, error) { + if b.r == nil { + return b.readOrWaitForMore(p) + } + + n, err := b.r.Read(p) + if xerrors.Is(err, io.EOF) { + b.r = nil + err = nil + if n == 0 { + return b.readOrWaitForMore(p) + } + } + return n, err +} + +func (b *stdbuf) readOrWaitForMore(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + + // Deplete channel so that more check + // is for future input into buffer. + select { + case <-b.more: + default: + } + + if len(b.b) == 0 { + if b.err != nil { + return 0, b.err + } + + b.mu.Unlock() + <-b.more + b.mu.Lock() + } + + b.r = bytes.NewReader(b.b) + b.b = b.b[len(b.b):] + + return b.r.Read(p) +} + +func (b *stdbuf) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + b.mu.Lock() + defer b.mu.Unlock() + + if b.err != nil { + return 0, b.err + } + + b.b = append(b.b, p...) + + select { + case b.more <- struct{}{}: + default: + } + + return len(p), nil +} + +func (b *stdbuf) Close() error { + return b.closeErr(nil) +} + +func (b *stdbuf) closeErr(err error) error { + b.mu.Lock() + defer b.mu.Unlock() + if b.err != nil { + return err + } + if err == nil { + b.err = io.EOF + } else { + b.err = err + } + close(b.more) + return err +} diff --git a/pty/ptytest/ptytest_internal_test.go b/testutil/expecter/stdbuf_internal_test.go similarity index 97% rename from pty/ptytest/ptytest_internal_test.go rename to testutil/expecter/stdbuf_internal_test.go index 2915417863..02365a8ff6 100644 --- a/pty/ptytest/ptytest_internal_test.go +++ b/testutil/expecter/stdbuf_internal_test.go @@ -1,4 +1,4 @@ -package ptytest +package expecter import ( "bytes" diff --git a/testutil/writer.go b/testutil/writer.go new file mode 100644 index 0000000000..4def987e62 --- /dev/null +++ b/testutil/writer.go @@ -0,0 +1,55 @@ +package testutil + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "gvisor.dev/gvisor/pkg/context" + + "cdr.dev/slog/v3" + "github.com/coder/serpent" +) + +// Writer wraps an underlying io.Writer and provides friendlier methods to write to it, including logging. +type Writer struct { + t *testing.T + w io.Writer + l slog.Logger +} + +func NewWriterAttachedToInvocation(t *testing.T, logger slog.Logger, invocation *serpent.Invocation) *Writer { + r, w := io.Pipe() + invocation.Stdin = r + // Close the pipe at the end of the test to ensure any goroutine in the Invocation that reads from stdin won't leak. + t.Cleanup(func() { + _ = w.Close() + }) + return &Writer{ + t: t, + w: w, + l: logger, + } +} + +func (w *Writer) Write(r rune) { + w.t.Helper() + _, err := w.w.Write([]byte{byte(r)}) + if assert.NoError(w.t, err, "write failed") { + w.l.Debug(context.Background(), "wrote rune", slog.F("rune", r)) + } +} + +func (w *Writer) WriteLine(str string) { + w.t.Helper() + + // Always write Windows style endings since our CLI prompt readers trim both out. Note this is *different* than what + // PTY-based tests do. On Unix-like operating systems we write a single carriage-return (\r) to delimit a line + // and the PTY translates it to a line feed (\n) for the CLI command to read. Here there is no translation. + newline := []byte{'\r', '\n'} + + _, err := w.w.Write(append([]byte(str), newline...)) + if assert.NoError(w.t, err, "write line failed") { + w.l.Debug(context.Background(), "wrote line", slog.F("line", str+string(newline))) + } +}