feat: add experimental agents support (#22290)

feat: add AI chat system with agent tools and chat UI

Introduce the chatd subsystem and Agents UI for AI-powered chat
within Coder workspaces.

- Add chatd package with chat loop, message compaction, prompt
  management, and LLM provider integration (OpenAI, Anthropic)
- Add agent tools: create workspace, list/read templates, read/write/
  edit files, execute commands
- Add chat API endpoints with streaming, message editing, and
  durable reconnection
- Add database schema and migrations for chats, chat messages, chat
  providers, and chat model configs
- Add RBAC policies and dbauthz enforcement for chat resources
- Add Agents UI pages with conversation timeline, queued messages
  list, diff viewer, and model configuration panel
- Add comprehensive test coverage including coderd integration tests,
  chatd unit tests, and Storybook stories
- Gate feature behind experiments flag

---------

Co-authored-by: Cian Johnston <cian@coder.com>
Co-authored-by: Danielle Maywood <danielle@themaywoods.com>
Co-authored-by: Jeremy Ruppel <jeremy@coder.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Kyle Carberry
2026-02-27 11:50:56 -05:00
committed by GitHub
parent 67da4e8b56
commit edee917d88
201 changed files with 44828 additions and 1859 deletions
+16 -8
View File
@@ -854,7 +854,7 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
# -C sets the directory for the go run command
go run -C ./scripts/apitypings main.go > $@
(cd site/ && pnpm exec biome format --write src/api/typesGenerated.ts)
./scripts/biome_format.sh src/api/typesGenerated.ts
touch "$@"
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
@@ -863,7 +863,7 @@ site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/prot
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*)
go run ./scripts/gensite/ -icons "$@"
(cd site/ && pnpm exec biome format --write src/theme/icons.json)
./scripts/biome_format.sh src/theme/icons.json
touch "$@"
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates)
@@ -901,12 +901,12 @@ codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scope
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
go run scripts/typegen/main.go rbac typescript > "$@"
(cd site/ && pnpm exec biome format --write src/api/rbacresourcesGenerated.ts)
./scripts/biome_format.sh src/api/rbacresourcesGenerated.ts
touch "$@"
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go
go run scripts/typegen/main.go countries > "$@"
(cd site/ && pnpm exec biome format --write src/api/countriesGenerated.ts)
./scripts/biome_format.sh src/api/countriesGenerated.ts
touch "$@"
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES)
@@ -950,11 +950,11 @@ coderd/apidoc/.gen: \
touch "$@"
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md
(cd site/ && pnpm exec biome format --write ../docs/manifest.json)
./scripts/biome_format.sh ../docs/manifest.json
touch "$@"
coderd/apidoc/swagger.json: site/node_modules/.installed coderd/apidoc/.gen
(cd site/ && pnpm exec biome format --write ../coderd/apidoc/swagger.json)
./scripts/biome_format.sh ../coderd/apidoc/swagger.json
touch "$@"
update-golden-files:
@@ -999,11 +999,19 @@ enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/
touch "$@"
helm/coder/tests/testdata/.gen-golden: $(wildcard helm/coder/tests/testdata/*.yaml) $(wildcard helm/coder/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/coder/tests/*_test.go)
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
if command -v helm >/dev/null 2>&1; then
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
else
echo "WARNING: helm not found; skipping helm/coder golden generation" >&2
fi
touch "$@"
helm/provisioner/tests/testdata/.gen-golden: $(wildcard helm/provisioner/tests/testdata/*.yaml) $(wildcard helm/provisioner/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/provisioner/tests/*_test.go)
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
if command -v helm >/dev/null 2>&1; then
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
else
echo "WARNING: helm not found; skipping helm/provisioner golden generation" >&2
fi
touch "$@"
coderd/.gen-golden: $(wildcard coderd/testdata/*/*.golden) $(GO_SRC_FILES) $(wildcard coderd/*_test.go)
+45 -1
View File
@@ -4,6 +4,9 @@ import (
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"strings"
"time"
"golang.org/x/xerrors"
@@ -16,6 +19,29 @@ import (
"github.com/coder/serpent"
)
// detectGitRef attempts to resolve the current git branch and remote
// origin URL from the given working directory. These are sent to the
// control plane so it can look up PR/diff status via the GitHub API
// without SSHing into the workspace. Failures are silently ignored
// since this is best-effort.
func detectGitRef(workingDirectory string) (branch string, remoteOrigin string) {
run := func(args ...string) string {
//nolint:gosec
cmd := exec.Command(args[0], args[1:]...)
if workingDirectory != "" {
cmd.Dir = workingDirectory
}
out, err := cmd.Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(out))
}
branch = run("git", "rev-parse", "--abbrev-ref", "HEAD")
remoteOrigin = run("git", "config", "--get", "remote.origin.url")
return branch, remoteOrigin
}
// gitAskpass is used by the Coder agent to automatically authenticate
// with Git providers based on a hostname.
func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
@@ -38,8 +64,20 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("create agent client: %w", err)
}
workingDirectory, err := os.Getwd()
if err != nil {
workingDirectory = ""
}
// Detect the current git branch and remote origin so
// the control plane can resolve diffs without needing
// to SSH back into the workspace.
gitBranch, gitRemoteOrigin := detectGitRef(workingDirectory)
token, err := client.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
Match: host,
Match: host,
GitBranch: gitBranch,
GitRemoteOrigin: gitRemoteOrigin,
})
if err != nil {
var apiError *codersdk.Error
@@ -58,6 +96,12 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("get git token: %w", err)
}
if token.URL != "" {
// This is to help the agent authenticate with Git.
if inv.Environ.Get("CODER_CHAT_AGENT") == "true" {
_, _ = fmt.Fprintf(inv.Stderr, `You must notify the user to authenticate with Git.\n\nThe URL is: %s\n`, token.URL)
return cliui.ErrCanceled
}
if err := openURL(inv, token.URL); err == nil {
cliui.Infof(inv.Stderr, "Your browser has been opened to authenticate with Git:\n%s", token.URL)
} else {
+111 -43
View File
@@ -617,28 +617,8 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
}
}
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
if err != nil {
return xerrors.Errorf("read external auth providers from env: %w", err)
}
promRegistry := prometheus.NewRegistry()
oauthInstrument := promoauth.NewFactory(promRegistry)
vals.ExternalAuthConfigs.Value = append(vals.ExternalAuthConfigs.Value, extAuthEnv...)
externalAuthConfigs, err := externalauth.ConvertConfig(
oauthInstrument,
vals.ExternalAuthConfigs.Value,
vals.AccessURL.Value(),
)
if err != nil {
return xerrors.Errorf("convert external auth config: %w", err)
}
for _, c := range externalAuthConfigs {
logger.Debug(
ctx, "loaded external auth config",
slog.F("id", c.ID),
)
}
realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
if err != nil {
@@ -669,7 +649,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
Pubsub: nil,
CacheDir: cacheDir,
GoogleTokenValidator: googleTokenValidator,
ExternalAuthConfigs: externalAuthConfigs,
ExternalAuthConfigs: nil,
RealIPConfig: realIPConfig,
SSHKeygenAlgorithm: sshKeygenAlgorithm,
TracerProvider: tracerProvider,
@@ -829,6 +809,40 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
return xerrors.Errorf("set deployment id: %w", err)
}
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
if err != nil {
return xerrors.Errorf("read external auth providers from env: %w", err)
}
mergedExternalAuthProviders := append([]codersdk.ExternalAuthConfig{}, vals.ExternalAuthConfigs.Value...)
mergedExternalAuthProviders = append(mergedExternalAuthProviders, extAuthEnv...)
vals.ExternalAuthConfigs.Value = mergedExternalAuthProviders
mergedExternalAuthProviders, err = maybeAppendDefaultGithubExternalAuthProvider(
ctx,
options.Logger,
options.Database,
vals,
mergedExternalAuthProviders,
)
if err != nil {
return xerrors.Errorf("maybe append default github external auth provider: %w", err)
}
options.ExternalAuthConfigs, err = externalauth.ConvertConfig(
oauthInstrument,
mergedExternalAuthProviders,
vals.AccessURL.Value(),
)
if err != nil {
return xerrors.Errorf("convert external auth config: %w", err)
}
for _, c := range options.ExternalAuthConfigs {
logger.Debug(
ctx, "loaded external auth config",
slog.F("id", c.ID),
)
}
// Manage push notifications.
experiments := coderd.ReadExperiments(options.Logger, options.DeploymentValues.Experiments.Value())
if experiments.Enabled(codersdk.ExperimentWebPush) {
@@ -1926,6 +1940,79 @@ type githubOAuth2ConfigParams struct {
enterpriseBaseURL string
}
func isDeploymentEligibleForGithubDefaultProvider(ctx context.Context, db database.Store) (bool, error) {
// We want to enable the default provider only for new deployments, and avoid
// enabling it if a deployment was upgraded from an older version.
// nolint:gocritic // Requires system privileges
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return false, xerrors.Errorf("get github default eligible: %w", err)
}
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
if defaultEligibleNotSet {
// nolint:gocritic // User count requires system privileges
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
if err != nil {
return false, xerrors.Errorf("get user count: %w", err)
}
// We check if a deployment is new by checking if it has any users.
defaultEligible = userCount == 0
// nolint:gocritic // Requires system privileges
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
return false, xerrors.Errorf("upsert github default eligible: %w", err)
}
}
return defaultEligible, nil
}
func maybeAppendDefaultGithubExternalAuthProvider(
ctx context.Context,
logger slog.Logger,
db database.Store,
vals *codersdk.DeploymentValues,
mergedExplicitProviders []codersdk.ExternalAuthConfig,
) ([]codersdk.ExternalAuthConfig, error) {
if !vals.ExternalAuthGithubDefaultProviderEnable.Value() {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "disabled by configuration"),
slog.F("flag", "external-auth-github-default-provider-enable"),
)
return mergedExplicitProviders, nil
}
if len(mergedExplicitProviders) > 0 {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "explicit external auth providers configured"),
slog.F("provider_count", len(mergedExplicitProviders)),
)
return mergedExplicitProviders, nil
}
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
if err != nil {
return nil, err
}
if !defaultEligible {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "deployment is not eligible"),
)
return mergedExplicitProviders, nil
}
logger.Info(ctx, "injecting default github external auth provider",
slog.F("type", codersdk.EnhancedExternalAuthProviderGitHub.String()),
slog.F("client_id", GithubOAuth2DefaultProviderClientID),
slog.F("device_flow", GithubOAuth2DefaultProviderDeviceFlow),
)
return append(mergedExplicitProviders, codersdk.ExternalAuthConfig{
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
ClientID: GithubOAuth2DefaultProviderClientID,
DeviceFlow: GithubOAuth2DefaultProviderDeviceFlow,
}), nil
}
func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *codersdk.DeploymentValues) (*githubOAuth2ConfigParams, error) {
params := githubOAuth2ConfigParams{
accessURL: vals.AccessURL.Value(),
@@ -1950,28 +2037,9 @@ func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *c
return nil, nil //nolint:nilnil
}
// Check if the deployment is eligible for the default GitHub OAuth2 provider.
// We want to enable it only for new deployments, and avoid enabling it
// if a deployment was upgraded from an older version.
// nolint:gocritic // Requires system privileges
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get github default eligible: %w", err)
}
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
if defaultEligibleNotSet {
// nolint:gocritic // User count requires system privileges
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
if err != nil {
return nil, xerrors.Errorf("get user count: %w", err)
}
// We check if a deployment is new by checking if it has any users.
defaultEligible = userCount == 0
// nolint:gocritic // Requires system privileges
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
return nil, xerrors.Errorf("upsert github default eligible: %w", err)
}
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
if err != nil {
return nil, err
}
if !defaultEligible {
+151
View File
@@ -53,6 +53,7 @@ import (
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/userpassword"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/pty/ptytest"
@@ -302,6 +303,7 @@ func TestServer(t *testing.T) {
"open install.sh: file does not exist",
"telemetry disabled, unable to notify of security issues",
"installed terraform version newer than expected",
"report generator",
}
countLines := func(fullOutput string) int {
@@ -1805,6 +1807,155 @@ func TestServer(t *testing.T) {
})
}
//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_ExternalAuthGitHubDefaultProvider(t *testing.T) {
type testCase struct {
name string
args []string
env map[string]string
createUserPreStart bool
expectedProviders []string
}
run := func(t *testing.T, tc testCase) {
ctx := testutil.Context(t, testutil.WaitLong)
unsetPrefixedEnv := func(prefix string) {
t.Helper()
for _, envVar := range os.Environ() {
envKey, _, found := strings.Cut(envVar, "=")
if !found || !strings.HasPrefix(envKey, prefix) {
continue
}
value, had := os.LookupEnv(envKey)
require.True(t, had)
require.NoError(t, os.Unsetenv(envKey))
keyCopy := envKey
valueCopy := value
t.Cleanup(func() {
// This is for setting/unsetting a number of prefixed env vars.
// t.Setenv doesn't cover this use case.
// nolint:usetesting
_ = os.Setenv(keyCopy, valueCopy)
})
}
}
unsetPrefixedEnv("CODER_EXTERNAL_AUTH_")
unsetPrefixedEnv("CODER_GITAUTH_")
dbURL, err := dbtestutil.Open(t)
require.NoError(t, err)
db, _ := dbtestutil.NewDB(t, dbtestutil.WithURL(dbURL))
const (
existingUserEmail = "existing-user@coder.com"
existingUserUsername = "existing-user"
existingUserPassword = "SomeSecurePassword!"
)
if tc.createUserPreStart {
hashedPassword, err := userpassword.Hash(existingUserPassword)
require.NoError(t, err)
_ = dbgen.User(t, db, database.User{
Email: existingUserEmail,
Username: existingUserUsername,
HashedPassword: []byte(hashedPassword),
})
}
args := []string{
"server",
"--postgres-url", dbURL,
"--http-address", ":0",
"--access-url", "https://example.com",
}
args = append(args, tc.args...)
inv, cfg := clitest.New(t, args...)
for envKey, value := range tc.env {
t.Setenv(envKey, value)
}
clitest.Start(t, inv)
accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL)
if tc.createUserPreStart {
loginResp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{
Email: existingUserEmail,
Password: existingUserPassword,
})
require.NoError(t, err)
client.SetSessionToken(loginResp.SessionToken)
} else {
_ = coderdtest.CreateFirstUser(t, client)
}
externalAuthResp, err := client.ListExternalAuths(ctx)
require.NoError(t, err)
gotProviders := map[string]codersdk.ExternalAuthLinkProvider{}
for _, provider := range externalAuthResp.Providers {
gotProviders[provider.ID] = provider
}
require.Len(t, gotProviders, len(tc.expectedProviders))
for _, providerID := range tc.expectedProviders {
provider, ok := gotProviders[providerID]
require.Truef(t, ok, "expected provider %q to be configured", providerID)
if providerID == codersdk.EnhancedExternalAuthProviderGitHub.String() {
require.Equal(t, codersdk.EnhancedExternalAuthProviderGitHub.String(), provider.Type)
require.True(t, provider.Device)
}
}
}
for _, tc := range []testCase{
{
name: "NewDeployment_NoExplicitProviders_InjectsDefaultGithub",
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitHub.String()},
},
{
name: "ExistingDeployment_DoesNotInjectDefaultGithub",
createUserPreStart: true,
expectedProviders: nil,
},
{
name: "DefaultProviderDisabled_DoesNotInjectDefaultGithub",
args: []string{
"--external-auth-github-default-provider-enable=false",
},
expectedProviders: nil,
},
{
name: "ExplicitProviderViaConfig_DoesNotInjectDefaultGithub",
args: []string{
`--external-auth-providers=[{"type":"gitlab","client_id":"config-client-id"}]`,
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
{
name: "ExplicitProviderViaEnv_DoesNotInjectDefaultGithub",
env: map[string]string{
"CODER_EXTERNAL_AUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
"CODER_EXTERNAL_AUTH_0_CLIENT_ID": "env-client-id",
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
{
name: "ExplicitProviderViaLegacyEnv_DoesNotInjectDefaultGithub",
env: map[string]string{
"CODER_GITAUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
"CODER_GITAUTH_0_CLIENT_ID": "legacy-env-client-id",
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
} {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_Logging_NoParallel(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+3
View File
@@ -62,6 +62,9 @@ OPTIONS:
Separate multiple experiments with commas, or enter '*' to opt-in to
all available experiments.
--external-auth-github-default-provider-enable bool, $CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE (default: true)
Enable the default GitHub external auth provider managed by Coder.
--postgres-auth password|awsiamrds, $CODER_PG_AUTH (default: password)
Type of auth to use when connecting to postgres. For AWS RDS, using
IAM authentication (awsiamrds) is recommended.
+3
View File
@@ -564,6 +564,9 @@ supportLinks: []
# External Authentication providers.
# (default: <unset>, type: struct[[]codersdk.ExternalAuthConfig])
externalAuthProviders: []
# Enable the default GitHub external auth provider managed by Coder.
# (default: true, type: bool)
externalAuthGithubDefaultProviderEnable: true
# Hostname of HTTPS server that runs https://github.com/coder/wgtunnel. By
# default, this will pick the best available wgtunnel server hosted by Coder. e.g.
# "tunnel.example.com".
+2 -1
View File
@@ -192,7 +192,8 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
})
defer commitAuditWS()
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, r, &createWorkspaceOptions{
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, &createWorkspaceOptions{
remoteAddr: r.RemoteAddr,
// Before creating the workspace, ensure that this task can be created.
preCreateInTX: func(ctx context.Context, tx database.Store) error {
// Create task record in the database before creating the workspace so that
+19
View File
@@ -12783,6 +12783,11 @@ const docTemplate = `{
"boundary_usage:delete",
"boundary_usage:read",
"boundary_usage:update",
"chat:*",
"chat:create",
"chat:delete",
"chat:read",
"chat:update",
"coder:all",
"coder:apikeys.manage_self",
"coder:application_connect",
@@ -12987,6 +12992,11 @@ const docTemplate = `{
"APIKeyScopeBoundaryUsageDelete",
"APIKeyScopeBoundaryUsageRead",
"APIKeyScopeBoundaryUsageUpdate",
"APIKeyScopeChatAll",
"APIKeyScopeChatCreate",
"APIKeyScopeChatDelete",
"APIKeyScopeChatRead",
"APIKeyScopeChatUpdate",
"APIKeyScopeCoderAll",
"APIKeyScopeCoderApikeysManageSelf",
"APIKeyScopeCoderApplicationConnect",
@@ -14848,6 +14858,9 @@ const docTemplate = `{
"external_auth": {
"$ref": "#/definitions/serpent.Struct-array_codersdk_ExternalAuthConfig"
},
"external_auth_github_default_provider_enable": {
"type": "boolean"
},
"external_token_encryption_keys": {
"type": "array",
"items": {
@@ -15132,9 +15145,11 @@ const docTemplate = `{
"workspace-usage",
"web-push",
"oauth2",
"agents",
"mcp-server-http"
],
"x-enum-comments": {
"ExperimentAgents": "Enables agent-powered chat functionality.",
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
"ExperimentExample": "This isn't used for anything.",
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
@@ -15150,6 +15165,7 @@ const docTemplate = `{
"Enables the new workspace usage tracking.",
"Enables web push notifications through the browser.",
"Enables OAuth2 provider functionality.",
"Enables agent-powered chat functionality.",
"Enables the MCP HTTP server functionality."
],
"x-enum-varnames": [
@@ -15159,6 +15175,7 @@ const docTemplate = `{
"ExperimentWorkspaceUsage",
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentAgents",
"ExperimentMCPServerHTTP"
]
},
@@ -18099,6 +18116,7 @@ const docTemplate = `{
"assign_role",
"audit_log",
"boundary_usage",
"chat",
"connection_log",
"crypto_key",
"debug_info",
@@ -18144,6 +18162,7 @@ const docTemplate = `{
"ResourceAssignRole",
"ResourceAuditLog",
"ResourceBoundaryUsage",
"ResourceChat",
"ResourceConnectionLog",
"ResourceCryptoKey",
"ResourceDebugInfo",
+19
View File
@@ -11387,6 +11387,11 @@
"boundary_usage:delete",
"boundary_usage:read",
"boundary_usage:update",
"chat:*",
"chat:create",
"chat:delete",
"chat:read",
"chat:update",
"coder:all",
"coder:apikeys.manage_self",
"coder:application_connect",
@@ -11591,6 +11596,11 @@
"APIKeyScopeBoundaryUsageDelete",
"APIKeyScopeBoundaryUsageRead",
"APIKeyScopeBoundaryUsageUpdate",
"APIKeyScopeChatAll",
"APIKeyScopeChatCreate",
"APIKeyScopeChatDelete",
"APIKeyScopeChatRead",
"APIKeyScopeChatUpdate",
"APIKeyScopeCoderAll",
"APIKeyScopeCoderApikeysManageSelf",
"APIKeyScopeCoderApplicationConnect",
@@ -13378,6 +13388,9 @@
"external_auth": {
"$ref": "#/definitions/serpent.Struct-array_codersdk_ExternalAuthConfig"
},
"external_auth_github_default_provider_enable": {
"type": "boolean"
},
"external_token_encryption_keys": {
"type": "array",
"items": {
@@ -13655,9 +13668,11 @@
"workspace-usage",
"web-push",
"oauth2",
"agents",
"mcp-server-http"
],
"x-enum-comments": {
"ExperimentAgents": "Enables agent-powered chat functionality.",
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
"ExperimentExample": "This isn't used for anything.",
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
@@ -13673,6 +13688,7 @@
"Enables the new workspace usage tracking.",
"Enables web push notifications through the browser.",
"Enables OAuth2 provider functionality.",
"Enables agent-powered chat functionality.",
"Enables the MCP HTTP server functionality."
],
"x-enum-varnames": [
@@ -13682,6 +13698,7 @@
"ExperimentWorkspaceUsage",
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentAgents",
"ExperimentMCPServerHTTP"
]
},
@@ -16507,6 +16524,7 @@
"assign_role",
"audit_log",
"boundary_usage",
"chat",
"connection_log",
"crypto_key",
"debug_info",
@@ -16552,6 +16570,7 @@
"ResourceAssignRole",
"ResourceAuditLog",
"ResourceBoundaryUsage",
"ResourceChat",
"ResourceConnectionLog",
"ResourceCryptoKey",
"ResourceDebugInfo",
+48
View File
@@ -1,6 +1,7 @@
package coderd
import (
"context"
"fmt"
"net/http"
@@ -8,6 +9,7 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac"
@@ -91,6 +93,36 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action policy.Action, object
return true
}
// AuthorizeContext checks whether the RBAC subject on the context
// is authorized to perform the given action. The subject must have
// been set via dbauthz.As or the ExtractAPIKey middleware. Returns
// false if the subject is missing or unauthorized.
func (h *HTTPAuthorizer) AuthorizeContext(ctx context.Context, action policy.Action, object rbac.Objecter) bool {
roles, ok := dbauthz.ActorFromContext(ctx)
if !ok {
h.Logger.Error(ctx, "no authorization actor in context")
return false
}
err := h.Authorizer.Authorize(ctx, roles, action, object.RBACObject())
if err != nil {
internalError := new(rbac.UnauthorizedError)
logger := h.Logger
if xerrors.As(err, internalError) {
logger = h.Logger.With(slog.F("internal_error", internalError.Internal()))
}
logger.Warn(ctx, "requester is not authorized to access the object",
slog.F("roles", roles.SafeRoleNames()),
slog.F("actor_id", roles.ID),
slog.F("actor_name", roles),
slog.F("scope", roles.SafeScopeName()),
slog.F("action", action),
slog.F("object", object),
)
return false
}
return true
}
// AuthorizeSQLFilter returns an authorization filter that can used in a
// SQL 'WHERE' clause. If the filter is used, the resulting rows returned
// from postgres are already authorized, and the caller does not need to
@@ -106,6 +138,22 @@ func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action policy.Actio
return prepared, nil
}
// AuthorizeSQLFilterContext is like AuthorizeSQLFilter but reads the
// RBAC subject from the context directly rather than from an
// *http.Request. The subject must have been set via dbauthz.As.
func (h *HTTPAuthorizer) AuthorizeSQLFilterContext(ctx context.Context, action policy.Action, objectType string) (rbac.PreparedAuthorized, error) {
roles, ok := dbauthz.ActorFromContext(ctx)
if !ok {
return nil, xerrors.New("no authorization actor in context")
}
prepared, err := h.Authorizer.Prepare(ctx, roles, action, objectType)
if err != nil {
return nil, xerrors.Errorf("prepare filter: %w", err)
}
return prepared, nil
}
// checkAuthorization returns if the current API key can use the given
// permissions, factoring in the current user's roles and the API key scopes.
//
File diff suppressed because it is too large Load Diff
+461
View File
@@ -0,0 +1,461 @@
package chatd_test
import (
"context"
"database/sql"
"encoding/json"
"errors"
"testing"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/chatd"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replicaA := newTestServer(t, db, ps, uuid.New())
replicaB := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replicaA.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "interrupt-me",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
})
require.NoError(t, err)
runningWorker := uuid.New()
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: runningWorker, Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
_, events, cancel, ok := replicaB.Subscribe(ctx, chat.ID, nil)
require.True(t, ok)
t.Cleanup(cancel)
updated := replicaA.InterruptChat(ctx, chat)
require.Equal(t, database.ChatStatusWaiting, updated.Status)
require.False(t, updated.WorkerID.Valid)
require.Eventually(t, func() bool {
select {
case event := <-events:
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
return false
}
return event.Status.Status == codersdk.ChatStatusWaiting
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
}
func TestInterruptChatClearsWorkerInDatabase(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "db-transition",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
})
require.NoError(t, err)
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
updated := replica.InterruptChat(ctx, chat)
require.Equal(t, database.ChatStatusWaiting, updated.Status)
require.False(t, updated.WorkerID.Valid)
fromDB, err := db.GetChatByID(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
require.False(t, fromDB.WorkerID.Valid)
}
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "heartbeat-ownership",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
})
require.NoError(t, err)
workerID := uuid.New()
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
rows, err := db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
ID: chat.ID,
WorkerID: uuid.New(),
})
require.NoError(t, err)
require.Equal(t, int64(0), rows)
rows, err = db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
ID: chat.ID,
WorkerID: workerID,
})
require.NoError(t, err)
require.Equal(t, int64(1), rows)
}
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "queue-when-busy",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
})
require.NoError(t, err)
workerID := uuid.New()
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []fantasy.Content{fantasy.TextContent{Text: "queued"}},
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
})
require.NoError(t, err)
require.True(t, result.Queued)
require.NotNil(t, result.QueuedMessage)
require.Equal(t, database.ChatStatusRunning, result.Chat.Status)
require.Equal(t, workerID, result.Chat.WorkerID.UUID)
require.True(t, result.Chat.WorkerID.Valid)
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, queued, 1)
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, messages, 1)
}
func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "interrupt-when-busy",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
})
require.NoError(t, err)
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []fantasy.Content{fantasy.TextContent{Text: "interrupt"}},
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
})
require.NoError(t, err)
require.False(t, result.Queued)
require.Equal(t, database.ChatStatusPending, result.Chat.Status)
require.False(t, result.Chat.WorkerID.Valid)
fromDB, err := db.GetChatByID(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, database.ChatStatusPending, fromDB.Status)
require.False(t, fromDB.WorkerID.Valid)
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, queued, 0)
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, messages, 2)
require.Equal(t, messages[len(messages)-1].ID, result.Message.ID)
}
func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "edit-message",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "original"}},
})
require.NoError(t, err)
initialMessages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, initialMessages, 1)
editedMessageID := initialMessages[0].ID
_, err = replica.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []fantasy.Content{fantasy.TextContent{Text: "follow-up"}},
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
})
require.NoError(t, err)
_, err = replica.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []fantasy.Content{fantasy.TextContent{Text: "another"}},
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
})
require.NoError(t, err)
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
ChatID: chat.ID,
Content: json.RawMessage(`"queued"`),
})
require.NoError(t, err)
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
editResult, err := replica.EditMessage(ctx, chatd.EditMessageOptions{
ChatID: chat.ID,
EditedMessageID: editedMessageID,
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
})
require.NoError(t, err)
require.Equal(t, editedMessageID, editResult.Message.ID)
require.Equal(t, database.ChatStatusPending, editResult.Chat.Status)
require.False(t, editResult.Chat.WorkerID.Valid)
editedSDK := db2sdk.ChatMessage(editResult.Message)
require.Len(t, editedSDK.Content, 1)
require.Equal(t, "edited", editedSDK.Content[0].Text)
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, messages, 1)
require.Equal(t, editedMessageID, messages[0].ID)
onlyMessage := db2sdk.ChatMessage(messages[0])
require.Len(t, onlyMessage.Content, 1)
require.Equal(t, "edited", onlyMessage.Content[0].Text)
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, queued, 0)
chatFromDB, err := db.GetChatByID(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, database.ChatStatusPending, chatFromDB.Status)
require.False(t, chatFromDB.WorkerID.Valid)
}
func TestEditMessageRejectsMissingMessage(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "missing-edited-message",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
})
require.NoError(t, err)
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
ChatID: chat.ID,
EditedMessageID: 999999,
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
})
require.Error(t, err)
require.True(t, errors.Is(err, chatd.ErrEditedMessageNotFound))
}
func TestEditMessageRejectsNonUserMessage(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "non-user-edited-message",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
})
require.NoError(t, err)
assistantMessage, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: "assistant",
Content: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`"assistant"`),
Valid: true,
},
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{},
OutputTokens: sql.NullInt64{},
TotalTokens: sql.NullInt64{},
ReasoningTokens: sql.NullInt64{},
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
})
require.NoError(t, err)
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
ChatID: chat.ID,
EditedMessageID: assistantMessage.ID,
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
})
require.Error(t, err)
require.True(t, errors.Is(err, chatd.ErrEditedMessageNotUser))
}
func newTestServer(
t *testing.T,
db database.Store,
ps dbpubsub.Pubsub,
replicaID uuid.UUID,
) *chatd.Server {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := chatd.New(chatd.Config{
Logger: logger,
Database: db,
ReplicaID: replicaID,
Pubsub: ps,
PendingChatAcquireInterval: testutil.WaitSuperLong,
})
t.Cleanup(func() {
require.NoError(t, server.Close())
})
return server
}
func seedChatDependencies(
ctx context.Context,
t *testing.T,
db database.Store,
) (database.User, database.ChatModelConfig) {
t.Helper()
user := dbgen.User(t, db, database.User{})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
BaseUrl: "",
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
})
require.NoError(t, err)
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "gpt-4o-mini",
DisplayName: "Test Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 70,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
return user, model
}
+676
View File
@@ -0,0 +1,676 @@
package chatloop
import (
"context"
"database/sql"
"encoding/json"
"errors"
"strconv"
"strings"
"sync"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/codersdk"
)
const (
interruptedToolResultErrorMessage = "tool call was interrupted before it produced a result"
)
var ErrInterrupted = xerrors.New("chat interrupted")
// PersistedStep contains the full content of a completed or
// interrupted agent step. Content includes both assistant blocks
// (text, reasoning, tool calls) and tool result blocks, mirroring
// what fantasy provides in StepResult.Content. The persistence
// layer is responsible for splitting these into separate database
// messages by role.
type PersistedStep struct {
Content []fantasy.Content
Usage fantasy.Usage
ContextLimit sql.NullInt64
}
// RunOptions configures a single streaming chat loop run.
type RunOptions struct {
Model fantasy.LanguageModel
Messages []fantasy.Message
Tools []fantasy.AgentTool
StreamCall fantasy.AgentStreamCall
MaxSteps int
ActiveTools []string
ContextLimitFallback int64
PersistStep func(context.Context, PersistedStep) error
PublishMessagePart func(
role fantasy.MessageRole,
part codersdk.ChatMessagePart,
)
Compaction *CompactionOptions
OnInterruptedPersistError func(error)
}
// Run executes the chat step-stream loop and delegates persistence/publishing to callbacks.
func Run(ctx context.Context, opts RunOptions) (*fantasy.AgentResult, error) {
if opts.Model == nil {
return nil, xerrors.New("chat model is required")
}
if opts.PersistStep == nil {
return nil, xerrors.New("persist step callback is required")
}
if opts.MaxSteps <= 0 {
opts.MaxSteps = 1
}
publishMessagePart := func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
if opts.PublishMessagePart == nil {
return
}
opts.PublishMessagePart(role, part)
}
var (
stepStateMu sync.Mutex
streamToolNames map[string]string
streamReasoningTitles map[string]string
streamReasoningText map[string]string
// stepToolResultContents tracks tool results received during
// streaming. These are needed for the interrupted-step path
// where OnStepFinish never fires.
stepToolResultContents []fantasy.ToolResultContent
stepAssistantDraft []fantasy.Content
stepToolCallIndexByID map[string]int
)
resetStepState := func() {
stepStateMu.Lock()
streamToolNames = make(map[string]string)
streamReasoningTitles = make(map[string]string)
streamReasoningText = make(map[string]string)
stepToolResultContents = nil
stepAssistantDraft = nil
stepToolCallIndexByID = make(map[string]int)
stepStateMu.Unlock()
}
setReasoningTitleFromText := func(id string, text string) {
if id == "" || strings.TrimSpace(text) == "" {
return
}
stepStateMu.Lock()
defer stepStateMu.Unlock()
if streamReasoningTitles[id] != "" {
return
}
streamReasoningText[id] += text
if !strings.ContainsAny(streamReasoningText[id], "\r\n") {
return
}
title := chatprompt.ReasoningTitleFromFirstLine(streamReasoningText[id])
if title == "" {
return
}
streamReasoningTitles[id] = title
}
appendDraftText := func(text string) {
if text == "" {
return
}
stepStateMu.Lock()
defer stepStateMu.Unlock()
if len(stepAssistantDraft) > 0 {
lastIndex := len(stepAssistantDraft) - 1
switch last := stepAssistantDraft[lastIndex].(type) {
case fantasy.TextContent:
last.Text += text
stepAssistantDraft[lastIndex] = last
return
case *fantasy.TextContent:
last.Text += text
stepAssistantDraft[lastIndex] = fantasy.TextContent{Text: last.Text}
return
}
}
stepAssistantDraft = append(stepAssistantDraft, fantasy.TextContent{Text: text})
}
appendDraftReasoning := func(text string) {
if text == "" {
return
}
stepStateMu.Lock()
defer stepStateMu.Unlock()
if len(stepAssistantDraft) > 0 {
lastIndex := len(stepAssistantDraft) - 1
switch last := stepAssistantDraft[lastIndex].(type) {
case fantasy.ReasoningContent:
last.Text += text
stepAssistantDraft[lastIndex] = last
return
case *fantasy.ReasoningContent:
last.Text += text
stepAssistantDraft[lastIndex] = fantasy.ReasoningContent{Text: last.Text}
return
}
}
stepAssistantDraft = append(stepAssistantDraft, fantasy.ReasoningContent{Text: text})
}
upsertDraftToolCall := func(toolCallID, toolName, input string, appendInput bool) {
if toolCallID == "" {
return
}
stepStateMu.Lock()
defer stepStateMu.Unlock()
if strings.TrimSpace(toolName) != "" {
streamToolNames[toolCallID] = toolName
}
index, exists := stepToolCallIndexByID[toolCallID]
if !exists {
stepToolCallIndexByID[toolCallID] = len(stepAssistantDraft)
stepAssistantDraft = append(stepAssistantDraft, fantasy.ToolCallContent{
ToolCallID: toolCallID,
ToolName: toolName,
Input: input,
})
return
}
if index < 0 || index >= len(stepAssistantDraft) {
stepToolCallIndexByID[toolCallID] = len(stepAssistantDraft)
stepAssistantDraft = append(stepAssistantDraft, fantasy.ToolCallContent{
ToolCallID: toolCallID,
ToolName: toolName,
Input: input,
})
return
}
existingCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](stepAssistantDraft[index])
if !ok {
if ptrCall, ptrOK := fantasy.AsContentType[*fantasy.ToolCallContent](stepAssistantDraft[index]); ptrOK && ptrCall != nil {
existingCall = *ptrCall
ok = true
}
}
if !ok {
stepToolCallIndexByID[toolCallID] = len(stepAssistantDraft)
stepAssistantDraft = append(stepAssistantDraft, fantasy.ToolCallContent{
ToolCallID: toolCallID,
ToolName: toolName,
Input: input,
})
return
}
if strings.TrimSpace(toolName) != "" {
existingCall.ToolName = toolName
}
if appendInput {
existingCall.Input += input
} else if input != "" || existingCall.Input == "" {
existingCall.Input = input
}
stepAssistantDraft[index] = existingCall
}
appendDraftSource := func(source fantasy.SourceContent) {
stepStateMu.Lock()
stepAssistantDraft = append(stepAssistantDraft, source)
stepStateMu.Unlock()
}
persistInterruptedStep := func() error {
stepStateMu.Lock()
draft := append([]fantasy.Content(nil), stepAssistantDraft...)
toolResults := append([]fantasy.ToolResultContent(nil), stepToolResultContents...)
toolNameByCallID := make(map[string]string, len(streamToolNames))
for id, name := range streamToolNames {
toolNameByCallID[id] = name
}
stepStateMu.Unlock()
if len(draft) == 0 && len(toolResults) == 0 {
return nil
}
// Track which tool calls already have results.
answeredToolCalls := make(map[string]struct{}, len(toolResults))
for _, tr := range toolResults {
if tr.ToolCallID != "" {
answeredToolCalls[tr.ToolCallID] = struct{}{}
}
}
// Build the combined content: draft + received tool results
// + synthetic interrupted results for unanswered tool calls.
content := make([]fantasy.Content, 0, len(draft)+len(toolResults))
content = append(content, draft...)
for _, tr := range toolResults {
content = append(content, tr)
}
for _, block := range draft {
toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block)
if !ok {
if ptrCall, ptrOK := fantasy.AsContentType[*fantasy.ToolCallContent](block); ptrOK && ptrCall != nil {
toolCall = *ptrCall
ok = true
}
}
if !ok || toolCall.ToolCallID == "" {
continue
}
if _, exists := answeredToolCalls[toolCall.ToolCallID]; exists {
continue
}
toolName := strings.TrimSpace(toolCall.ToolName)
if toolName == "" {
toolName = strings.TrimSpace(toolNameByCallID[toolCall.ToolCallID])
}
content = append(content, fantasy.ToolResultContent{
ToolCallID: toolCall.ToolCallID,
ToolName: toolName,
Result: fantasy.ToolResultOutputContentError{
Error: xerrors.New(interruptedToolResultErrorMessage),
},
})
answeredToolCalls[toolCall.ToolCallID] = struct{}{}
}
persistCtx := context.WithoutCancel(ctx)
return opts.PersistStep(persistCtx, PersistedStep{
Content: content,
})
}
resetStepState()
agent := fantasy.NewAgent(
opts.Model,
fantasy.WithTools(opts.Tools...),
fantasy.WithStopConditions(fantasy.StepCountIs(opts.MaxSteps)),
)
applyAnthropicCaching := shouldApplyAnthropicPromptCaching(opts.Model)
// Fantasy's AgentStreamCall currently requires a non-empty Prompt and always
// appends it as a user message. chatd already supplies the full history in
// Messages, so we pass and then strip a sentinel user message in PrepareStep.
sentinelPrompt := "__chatd_agent_prompt_sentinel_" + uuid.NewString()
streamCall := opts.StreamCall
streamCall.Prompt = sentinelPrompt
streamCall.Messages = opts.Messages
streamCall.PrepareStep = func(
stepCtx context.Context,
options fantasy.PrepareStepFunctionOptions,
) (context.Context, fantasy.PrepareStepResult, error) {
return stepCtx, prepareStepResult(
options.Messages,
sentinelPrompt,
opts.ActiveTools,
applyAnthropicCaching,
), nil
}
streamCall.OnStepStart = func(_ int) error {
resetStepState()
return nil
}
streamCall.OnTextDelta = func(_ string, text string) error {
appendDraftText(text)
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: text,
})
return nil
}
streamCall.OnReasoningDelta = func(id string, text string) error {
appendDraftReasoning(text)
setReasoningTitleFromText(id, text)
stepStateMu.Lock()
title := streamReasoningTitles[id]
stepStateMu.Unlock()
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: text,
Title: title,
})
return nil
}
streamCall.OnReasoningEnd = func(id string, _ fantasy.ReasoningContent) error {
stepStateMu.Lock()
if streamReasoningTitles[id] == "" {
// At the end of reasoning we have the full text, so we can
// safely evaluate first-line title format even if no newline
// ever arrived in deltas.
streamReasoningTitles[id] = chatprompt.ReasoningTitleFromFirstLine(
streamReasoningText[id],
)
}
title := streamReasoningTitles[id]
stepStateMu.Unlock()
if title != "" {
// Publish a title-only reasoning part so clients can update the
// reasoning header when metadata arrives at the end of streaming.
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Title: title,
})
}
return nil
}
streamCall.OnToolInputStart = func(id, toolName string) error {
upsertDraftToolCall(id, toolName, "", false)
return nil
}
streamCall.OnToolInputDelta = func(id, delta string) error {
stepStateMu.Lock()
toolName := streamToolNames[id]
stepStateMu.Unlock()
upsertDraftToolCall(id, toolName, delta, true)
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: id,
ToolName: toolName,
ArgsDelta: delta,
})
return nil
}
streamCall.OnToolCall = func(toolCall fantasy.ToolCallContent) error {
upsertDraftToolCall(toolCall.ToolCallID, toolCall.ToolName, toolCall.Input, false)
publishMessagePart(
fantasy.MessageRoleAssistant,
chatprompt.PartFromContent(toolCall),
)
return nil
}
streamCall.OnSource = func(source fantasy.SourceContent) error {
appendDraftSource(source)
publishMessagePart(
fantasy.MessageRoleAssistant,
chatprompt.PartFromContent(source),
)
return nil
}
streamCall.OnToolResult = func(result fantasy.ToolResultContent) error {
publishMessagePart(
fantasy.MessageRoleTool,
chatprompt.PartFromContent(result),
)
stepStateMu.Lock()
if result.ToolCallID != "" && strings.TrimSpace(result.ToolName) != "" {
streamToolNames[result.ToolCallID] = result.ToolName
}
stepToolResultContents = append(stepToolResultContents, result)
stepStateMu.Unlock()
return nil
}
streamCall.OnStepFinish = func(stepResult fantasy.StepResult) error {
contextLimit := extractContextLimit(stepResult.ProviderMetadata)
if !contextLimit.Valid && opts.ContextLimitFallback > 0 {
contextLimit = sql.NullInt64{
Int64: opts.ContextLimitFallback,
Valid: true,
}
}
return opts.PersistStep(ctx, PersistedStep{
Content: stepResult.Content,
Usage: stepResult.Usage,
ContextLimit: contextLimit,
})
}
result, err := agent.Stream(ctx, streamCall)
if err != nil {
if errors.Is(err, context.Canceled) &&
errors.Is(context.Cause(ctx), ErrInterrupted) {
if persistErr := persistInterruptedStep(); persistErr != nil {
if opts.OnInterruptedPersistError != nil {
opts.OnInterruptedPersistError(persistErr)
}
}
return nil, ErrInterrupted
}
return nil, xerrors.Errorf("stream response: %w", err)
}
if opts.Compaction != nil {
if err := maybeCompact(ctx, opts, result); err != nil {
if opts.Compaction.OnError != nil {
opts.Compaction.OnError(err)
}
}
}
return result, nil
}
//nolint:revive // Boolean controls Anthropic-specific caching behavior.
func prepareStepResult(
messages []fantasy.Message,
sentinel string,
activeTools []string,
anthropicCaching bool,
) fantasy.PrepareStepResult {
filtered := make([]fantasy.Message, 0, len(messages))
removed := false
for _, message := range messages {
if !removed &&
message.Role == fantasy.MessageRoleUser &&
len(message.Content) == 1 {
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](message.Content[0])
if ok && textPart.Text == sentinel {
removed = true
continue
}
}
filtered = append(filtered, message)
}
result := fantasy.PrepareStepResult{
Messages: filtered,
}
if anthropicCaching {
result.Messages = addAnthropicPromptCaching(result.Messages)
}
if len(activeTools) > 0 {
result.ActiveTools = append([]string(nil), activeTools...)
}
return result
}
func shouldApplyAnthropicPromptCaching(model fantasy.LanguageModel) bool {
if model == nil {
return false
}
return model.Provider() == fantasyanthropic.Name
}
func addAnthropicPromptCaching(messages []fantasy.Message) []fantasy.Message {
for i := range messages {
messages[i].ProviderOptions = nil
}
providerOption := fantasy.ProviderOptions{
fantasyanthropic.Name: &fantasyanthropic.ProviderCacheControlOptions{
CacheControl: fantasyanthropic.CacheControl{Type: "ephemeral"},
},
}
lastSystemRoleIdx := -1
systemMessageUpdated := false
for i, msg := range messages {
if msg.Role == fantasy.MessageRoleSystem {
lastSystemRoleIdx = i
} else if !systemMessageUpdated && lastSystemRoleIdx >= 0 {
messages[lastSystemRoleIdx].ProviderOptions = providerOption
systemMessageUpdated = true
}
if i > len(messages)-3 {
messages[i].ProviderOptions = providerOption
}
}
return messages
}
func extractContextLimit(metadata fantasy.ProviderMetadata) sql.NullInt64 {
if len(metadata) == 0 {
return sql.NullInt64{}
}
encoded, err := json.Marshal(metadata)
if err != nil || len(encoded) == 0 {
return sql.NullInt64{}
}
var payload any
if err := json.Unmarshal(encoded, &payload); err != nil {
return sql.NullInt64{}
}
limit, ok := findContextLimitValue(payload)
if !ok {
return sql.NullInt64{}
}
return sql.NullInt64{
Int64: limit,
Valid: true,
}
}
func findContextLimitValue(value any) (int64, bool) {
var (
limit int64
found bool
)
collectContextLimitValues(value, func(candidate int64) {
if !found || candidate > limit {
limit = candidate
found = true
}
})
return limit, found
}
func collectContextLimitValues(value any, onValue func(int64)) {
switch typed := value.(type) {
case map[string]any:
for key, child := range typed {
if isContextLimitKey(key) {
if numeric, ok := numericContextLimitValue(child); ok {
onValue(numeric)
}
}
collectContextLimitValues(child, onValue)
}
case []any:
for _, child := range typed {
collectContextLimitValues(child, onValue)
}
}
}
func isContextLimitKey(key string) bool {
normalized := normalizeMetadataKey(key)
if normalized == "" {
return false
}
switch normalized {
case
"contextlimit",
"contextwindow",
"contextlength",
"maxcontext",
"maxcontexttokens",
"maxinputtokens",
"maxinputtoken",
"inputtokenlimit":
return true
}
return strings.Contains(normalized, "context") &&
(strings.Contains(normalized, "limit") ||
strings.Contains(normalized, "window") ||
strings.Contains(normalized, "length") ||
strings.HasPrefix(normalized, "max"))
}
func normalizeMetadataKey(key string) string {
var b strings.Builder
b.Grow(len(key))
for _, r := range key {
switch {
case r >= 'a' && r <= 'z':
_, _ = b.WriteRune(r)
case r >= 'A' && r <= 'Z':
_, _ = b.WriteRune(r + ('a' - 'A'))
case r >= '0' && r <= '9':
_, _ = b.WriteRune(r)
}
}
return b.String()
}
func numericContextLimitValue(value any) (int64, bool) {
switch typed := value.(type) {
case int64:
return positiveInt64(typed)
case int32:
return positiveInt64(int64(typed))
case int:
return positiveInt64(int64(typed))
case float64:
casted := int64(typed)
if typed > 0 && float64(casted) == typed {
return casted, true
}
case string:
parsed, err := strconv.ParseInt(strings.TrimSpace(typed), 10, 64)
if err == nil {
return positiveInt64(parsed)
}
case json.Number:
parsed, err := typed.Int64()
if err == nil {
return positiveInt64(parsed)
}
}
return 0, false
}
func positiveInt64(value int64) (int64, bool) {
if value <= 0 {
return 0, false
}
return value, true
}
+289
View File
@@ -0,0 +1,289 @@
package chatloop //nolint:testpackage // Uses internal symbols.
import (
"context"
"iter"
"strings"
"testing"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)
const activeToolName = "read_file"
func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
t.Parallel()
var capturedCall fantasy.Call
model := &loopTestModel{
provider: fantasyanthropic.Name,
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
capturedCall = call
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}), nil
},
}
persistStepCalls := 0
var persistedStep PersistedStep
_, err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleSystem, "sys-1"),
textMessage(fantasy.MessageRoleSystem, "sys-2"),
textMessage(fantasy.MessageRoleUser, "hello"),
textMessage(fantasy.MessageRoleAssistant, "working"),
textMessage(fantasy.MessageRoleUser, "continue"),
},
Tools: []fantasy.AgentTool{
newNoopTool(activeToolName),
newNoopTool("write_file"),
},
MaxSteps: 3,
ActiveTools: []string{activeToolName},
ContextLimitFallback: 4096,
PersistStep: func(_ context.Context, step PersistedStep) error {
persistStepCalls++
persistedStep = step
return nil
},
})
require.NoError(t, err)
require.Equal(t, 1, persistStepCalls)
require.True(t, persistedStep.ContextLimit.Valid)
require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64)
require.NotEmpty(t, capturedCall.Prompt)
require.False(t, containsPromptSentinel(capturedCall.Prompt))
require.Len(t, capturedCall.Tools, 1)
require.Equal(t, activeToolName, capturedCall.Tools[0].GetName())
require.Len(t, capturedCall.Prompt, 5)
require.False(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[0]))
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[1]))
require.False(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[2]))
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[3]))
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[4]))
}
func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
t.Parallel()
started := make(chan struct{})
model := &loopTestModel{
provider: "fake",
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
parts := []fantasy.StreamPart{
{
Type: fantasy.StreamPartTypeToolInputStart,
ID: "interrupt-tool-1",
ToolCallName: "read_file",
},
{
Type: fantasy.StreamPartTypeToolInputDelta,
ID: "interrupt-tool-1",
ToolCallName: "read_file",
Delta: `{"path":"main.go"`,
},
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "partial assistant output"},
}
for _, part := range parts {
if !yield(part) {
return
}
}
select {
case <-started:
default:
close(started)
}
<-ctx.Done()
_ = yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
Error: ctx.Err(),
})
}), nil
},
}
ctx, cancel := context.WithCancelCause(context.Background())
defer cancel(nil)
go func() {
<-started
cancel(ErrInterrupted)
}()
persistedAssistantCtxErr := xerrors.New("unset")
var persistedContent []fantasy.Content
_, err := Run(ctx, RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
Tools: []fantasy.AgentTool{
newNoopTool("read_file"),
},
MaxSteps: 3,
PersistStep: func(persistCtx context.Context, step PersistedStep) error {
persistedAssistantCtxErr = persistCtx.Err()
persistedContent = append([]fantasy.Content(nil), step.Content...)
return nil
},
})
require.ErrorIs(t, err, ErrInterrupted)
require.NoError(t, persistedAssistantCtxErr)
require.NotEmpty(t, persistedContent)
var (
foundText bool
foundToolCall bool
foundToolResult bool
)
for _, block := range persistedContent {
if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok {
if strings.Contains(text.Text, "partial assistant output") {
foundText = true
}
continue
}
if toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block); ok {
if toolCall.ToolCallID == "interrupt-tool-1" &&
toolCall.ToolName == "read_file" &&
strings.Contains(toolCall.Input, `"path":"main.go"`) {
foundToolCall = true
}
continue
}
if toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok {
if toolResult.ToolCallID == "interrupt-tool-1" &&
toolResult.ToolName == "read_file" {
_, isErr := toolResult.Result.(fantasy.ToolResultOutputContentError)
require.True(t, isErr, "interrupted tool result should be an error")
foundToolResult = true
}
}
}
require.True(t, foundText)
require.True(t, foundToolCall)
require.True(t, foundToolResult)
}
type loopTestModel struct {
provider string
model string
generateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
streamFn func(context.Context, fantasy.Call) (fantasy.StreamResponse, error)
}
func (m *loopTestModel) Provider() string {
if m.provider != "" {
return m.provider
}
return "fake"
}
func (m *loopTestModel) Model() string {
if m.model != "" {
return m.model
}
return "fake"
}
func (m *loopTestModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
if m.generateFn != nil {
return m.generateFn(ctx, call)
}
return &fantasy.Response{}, nil
}
func (m *loopTestModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
if m.streamFn != nil {
return m.streamFn(ctx, call)
}
return streamFromParts([]fantasy.StreamPart{{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
}}), nil
}
func (*loopTestModel) GenerateObject(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
return nil, xerrors.New("not implemented")
}
func (*loopTestModel) StreamObject(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
return nil, xerrors.New("not implemented")
}
func streamFromParts(parts []fantasy.StreamPart) fantasy.StreamResponse {
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
for _, part := range parts {
if !yield(part) {
return
}
}
})
}
func newNoopTool(name string) fantasy.AgentTool {
return fantasy.NewAgentTool(
name,
"test noop tool",
func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.ToolResponse{}, nil
},
)
}
func textMessage(role fantasy.MessageRole, text string) fantasy.Message {
return fantasy.Message{
Role: role,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: text},
},
}
}
func containsPromptSentinel(prompt []fantasy.Message) bool {
for _, message := range prompt {
if message.Role != fantasy.MessageRoleUser || len(message.Content) != 1 {
continue
}
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](message.Content[0])
if !ok {
continue
}
if strings.HasPrefix(textPart.Text, "__chatd_agent_prompt_sentinel_") {
return true
}
}
return false
}
func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
if len(message.ProviderOptions) == 0 {
return false
}
options, ok := message.ProviderOptions[fantasyanthropic.Name]
if !ok {
return false
}
cacheOptions, ok := options.(*fantasyanthropic.ProviderCacheControlOptions)
return ok && cacheOptions.CacheControl.Type == "ephemeral"
}
+209
View File
@@ -0,0 +1,209 @@
package chatloop
import (
"context"
"strings"
"time"
"charm.land/fantasy"
"golang.org/x/xerrors"
)
const (
defaultCompactionThresholdPercent = int32(70)
minCompactionThresholdPercent = int32(0)
maxCompactionThresholdPercent = int32(100)
defaultCompactionSummaryPrompt = "Summarize the current chat so a " +
"new assistant can continue seamlessly. Include the user's goals, " +
"decisions made, concrete technical details (files, commands, APIs), " +
"errors encountered and fixes, and open questions. Be dense and factual. " +
"Omit pleasantries and next-step suggestions."
defaultCompactionSystemSummaryPrefix = "Summary of earlier chat context:"
defaultCompactionTimeout = 90 * time.Second
)
type CompactionOptions struct {
ThresholdPercent int32
ContextLimit int64
SummaryPrompt string
SystemSummaryPrefix string
Timeout time.Duration
Persist func(context.Context, CompactionResult) error
OnError func(error)
}
type CompactionResult struct {
SystemSummary string
SummaryReport string
ThresholdPercent int32
UsagePercent float64
ContextTokens int64
ContextLimit int64
}
func maybeCompact(
ctx context.Context,
runOpts RunOptions,
runResult *fantasy.AgentResult,
) error {
if runResult == nil || runOpts.Compaction == nil {
return nil
}
config := *runOpts.Compaction
if config.Persist == nil {
return xerrors.New("compaction persist callback is required")
}
if strings.TrimSpace(config.SummaryPrompt) == "" {
config.SummaryPrompt = defaultCompactionSummaryPrompt
}
if strings.TrimSpace(config.SystemSummaryPrefix) == "" {
config.SystemSummaryPrefix = defaultCompactionSystemSummaryPrefix
}
if config.Timeout <= 0 {
config.Timeout = defaultCompactionTimeout
}
if config.ThresholdPercent < minCompactionThresholdPercent ||
config.ThresholdPercent > maxCompactionThresholdPercent {
config.ThresholdPercent = defaultCompactionThresholdPercent
}
if config.ThresholdPercent >= maxCompactionThresholdPercent {
return nil
}
if runOpts.MaxSteps > 0 && len(runResult.Steps) >= runOpts.MaxSteps {
lastStep := runResult.Steps[len(runResult.Steps)-1]
if lastStep.FinishReason == fantasy.FinishReasonToolCalls &&
len(lastStep.Content.ToolCalls()) > 0 {
return nil
}
}
contextTokens := int64(0)
contextLimitFromMetadata := int64(0)
for i := len(runResult.Steps) - 1; i >= 0; i-- {
usage := runResult.Steps[i].Usage
total := int64(0)
hasContextTokens := false
if usage.InputTokens > 0 {
total += usage.InputTokens
hasContextTokens = true
}
if usage.CacheReadTokens > 0 {
total += usage.CacheReadTokens
hasContextTokens = true
}
if usage.CacheCreationTokens > 0 {
total += usage.CacheCreationTokens
hasContextTokens = true
}
if !hasContextTokens && usage.TotalTokens > 0 {
total = usage.TotalTokens
hasContextTokens = true
}
if !hasContextTokens || total <= 0 {
continue
}
contextTokens = total
metadataLimit := extractContextLimit(runResult.Steps[i].ProviderMetadata)
if metadataLimit.Valid && metadataLimit.Int64 > 0 {
contextLimitFromMetadata = metadataLimit.Int64
}
break
}
if contextTokens <= 0 {
return nil
}
contextLimit := contextLimitFromMetadata
if contextLimit <= 0 && config.ContextLimit > 0 {
contextLimit = config.ContextLimit
}
if contextLimit <= 0 && runOpts.ContextLimitFallback > 0 {
contextLimit = runOpts.ContextLimitFallback
}
if contextLimit <= 0 {
return nil
}
usagePercent := (float64(contextTokens) / float64(contextLimit)) * 100
if usagePercent < float64(config.ThresholdPercent) {
return nil
}
summary, err := generateCompactionSummary(
ctx,
runOpts.Model,
runOpts.Messages,
runResult.Steps,
config,
)
if err != nil {
return err
}
if summary == "" {
return nil
}
systemSummary := strings.TrimSpace(
config.SystemSummaryPrefix + "\n\n" + summary,
)
return config.Persist(ctx, CompactionResult{
SystemSummary: systemSummary,
SummaryReport: summary,
ThresholdPercent: config.ThresholdPercent,
UsagePercent: usagePercent,
ContextTokens: contextTokens,
ContextLimit: contextLimit,
})
}
func generateCompactionSummary(
ctx context.Context,
model fantasy.LanguageModel,
messages []fantasy.Message,
steps []fantasy.StepResult,
options CompactionOptions,
) (string, error) {
summaryPrompt := make([]fantasy.Message, 0, len(messages)+len(steps)+1)
summaryPrompt = append(summaryPrompt, messages...)
for _, step := range steps {
summaryPrompt = append(summaryPrompt, step.Messages...)
}
summaryPrompt = append(summaryPrompt, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: options.SummaryPrompt},
},
})
toolChoice := fantasy.ToolChoiceNone
summaryCtx, cancel := context.WithTimeout(ctx, options.Timeout)
defer cancel()
response, err := model.Generate(summaryCtx, fantasy.Call{
Prompt: summaryPrompt,
ToolChoice: &toolChoice,
})
if err != nil {
return "", xerrors.Errorf("generate summary text: %w", err)
}
parts := make([]string, 0, len(response.Content))
for _, block := range response.Content {
textBlock, ok := fantasy.AsContentType[fantasy.TextContent](block)
if !ok {
continue
}
text := strings.TrimSpace(textBlock.Text)
if text == "" {
continue
}
parts = append(parts, text)
}
return strings.TrimSpace(strings.Join(parts, " ")), nil
}
+132
View File
@@ -0,0 +1,132 @@
package chatloop //nolint:testpackage // Uses internal symbols.
import (
"context"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)
func TestRun_Compaction(t *testing.T) {
t.Parallel()
t.Run("PersistsWhenThresholdReached", func(t *testing.T) {
t.Parallel()
persistCompactionCalls := 0
var persistedCompaction CompactionResult
const summaryText = "summary text for compaction"
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{
InputTokens: 80,
TotalTokens: 85,
},
},
}), nil
},
generateFn: func(_ context.Context, call fantasy.Call) (*fantasy.Response, error) {
require.NotEmpty(t, call.Prompt)
lastPrompt := call.Prompt[len(call.Prompt)-1]
require.Equal(t, fantasy.MessageRoleUser, lastPrompt.Role)
require.Len(t, lastPrompt.Content, 1)
instruction, ok := fantasy.AsMessagePart[fantasy.TextPart](lastPrompt.Content[0])
require.True(t, ok)
require.Equal(t, "summarize now", instruction.Text)
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
},
}, nil
},
}
_, err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 1,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
ContextLimitFallback: 100,
Compaction: &CompactionOptions{
ThresholdPercent: 70,
SummaryPrompt: "summarize now",
Persist: func(_ context.Context, result CompactionResult) error {
persistCompactionCalls++
persistedCompaction = result
return nil
},
},
})
require.NoError(t, err)
require.Equal(t, 1, persistCompactionCalls)
require.Contains(t, persistedCompaction.SystemSummary, summaryText)
require.Equal(t, summaryText, persistedCompaction.SummaryReport)
require.Equal(t, int64(80), persistedCompaction.ContextTokens)
require.Equal(t, int64(100), persistedCompaction.ContextLimit)
require.InDelta(t, 80.0, persistedCompaction.UsagePercent, 0.0001)
})
t.Run("ErrorsAreReported", func(t *testing.T) {
t.Parallel()
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{
InputTokens: 80,
},
},
}), nil
},
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return nil, xerrors.New("generate failed")
},
}
compactionErr := xerrors.New("unset")
_, err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 1,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
ContextLimitFallback: 100,
Compaction: &CompactionOptions{
ThresholdPercent: 70,
Persist: func(_ context.Context, _ CompactionResult) error {
return nil
},
OnError: func(err error) {
compactionErr = err
},
},
})
require.NoError(t, err)
require.Error(t, compactionErr)
require.ErrorContains(t, compactionErr, "generate summary text")
})
}
+982
View File
@@ -0,0 +1,982 @@
package chatprompt
import (
"encoding/json"
"regexp"
"strings"
"charm.land/fantasy"
fantasyopenai "charm.land/fantasy/providers/openai"
"github.com/sqlc-dev/pqtype"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk"
)
var toolCallIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
func ConvertMessages(
messages []database.ChatMessage,
) ([]fantasy.Message, error) {
prompt := make([]fantasy.Message, 0, len(messages))
toolNameByCallID := make(map[string]string)
for _, message := range messages {
visibility := message.Visibility
if visibility == "" {
visibility = database.ChatMessageVisibilityBoth
}
if visibility != database.ChatMessageVisibilityModel &&
visibility != database.ChatMessageVisibilityBoth {
continue
}
switch message.Role {
case string(fantasy.MessageRoleSystem):
content, err := parseSystemContent(message.Content)
if err != nil {
return nil, err
}
if strings.TrimSpace(content) == "" {
continue
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: content},
},
})
case string(fantasy.MessageRoleUser):
content, err := ParseContent(string(fantasy.MessageRoleUser), message.Content)
if err != nil {
return nil, err
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: ToMessageParts(content),
})
case string(fantasy.MessageRoleAssistant):
content, err := ParseContent(string(fantasy.MessageRoleAssistant), message.Content)
if err != nil {
return nil, err
}
parts := normalizeAssistantToolCallInputs(ToMessageParts(content))
for _, toolCall := range ExtractToolCalls(parts) {
if toolCall.ToolCallID == "" || strings.TrimSpace(toolCall.ToolName) == "" {
continue
}
toolNameByCallID[sanitizeToolCallID(toolCall.ToolCallID)] = toolCall.ToolName
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleAssistant,
Content: parts,
})
case string(fantasy.MessageRoleTool):
rows, err := parseToolResultRows(message.Content)
if err != nil {
return nil, err
}
parts := make([]fantasy.MessagePart, 0, len(rows))
for _, row := range rows {
if row.ToolCallID != "" && row.ToolName != "" {
toolNameByCallID[sanitizeToolCallID(row.ToolCallID)] = row.ToolName
}
parts = append(parts, row.toToolResultPart())
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleTool,
Content: parts,
})
default:
return nil, xerrors.Errorf("unsupported chat message role %q", message.Role)
}
}
prompt = injectMissingToolResults(prompt)
prompt = injectMissingToolUses(
prompt,
toolNameByCallID,
)
return prompt, nil
}
// PrependSystem prepends a system message unless an existing system
// message already mentions create_workspace guidance.
func PrependSystem(prompt []fantasy.Message, instruction string) []fantasy.Message {
instruction = strings.TrimSpace(instruction)
if instruction == "" {
return prompt
}
for _, message := range prompt {
if message.Role != fantasy.MessageRoleSystem {
continue
}
for _, part := range message.Content {
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
if !ok {
continue
}
if strings.Contains(strings.ToLower(textPart.Text), "create_workspace") {
return prompt
}
}
}
out := make([]fantasy.Message, 0, len(prompt)+1)
out = append(out, fantasy.Message{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: instruction},
},
})
out = append(out, prompt...)
return out
}
// InsertSystem inserts a system message after the existing system
// block and before the first non-system message.
func InsertSystem(prompt []fantasy.Message, instruction string) []fantasy.Message {
instruction = strings.TrimSpace(instruction)
if instruction == "" {
return prompt
}
systemMessage := fantasy.Message{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: instruction},
},
}
out := make([]fantasy.Message, 0, len(prompt)+1)
inserted := false
for _, message := range prompt {
if !inserted && message.Role != fantasy.MessageRoleSystem {
out = append(out, systemMessage)
inserted = true
}
out = append(out, message)
}
if !inserted {
out = append(out, systemMessage)
}
return out
}
// AppendUser appends an instruction as a user message at the end of
// the prompt.
func AppendUser(prompt []fantasy.Message, instruction string) []fantasy.Message {
instruction = strings.TrimSpace(instruction)
if instruction == "" {
return prompt
}
out := make([]fantasy.Message, 0, len(prompt)+1)
out = append(out, prompt...)
out = append(out, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: instruction},
},
})
return out
}
// ParseContent decodes persisted chat message content blocks.
func ParseContent(role string, raw pqtype.NullRawMessage) ([]fantasy.Content, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return nil, nil
}
var text string
if err := json.Unmarshal(raw.RawMessage, &text); err == nil {
return []fantasy.Content{fantasy.TextContent{Text: text}}, nil
}
var rawBlocks []json.RawMessage
if err := json.Unmarshal(raw.RawMessage, &rawBlocks); err != nil {
return nil, xerrors.Errorf("parse %s content: %w", role, err)
}
content := make([]fantasy.Content, 0, len(rawBlocks))
for i, rawBlock := range rawBlocks {
block, err := fantasy.UnmarshalContent(rawBlock)
if err != nil {
return nil, xerrors.Errorf("parse %s content block %d: %w", role, i, err)
}
content = append(content, block)
}
return content, nil
}
// toolResultRaw is an untyped representation of a persisted tool
// result row. We intentionally avoid a strict Go struct so that
// historical shapes are never rejected.
type toolResultRaw struct {
ToolCallID string `json:"tool_call_id"`
ToolName string `json:"tool_name"`
Result json.RawMessage `json:"result"`
IsError bool `json:"is_error,omitempty"`
}
// parseToolResultRows decodes persisted tool result rows.
func parseToolResultRows(raw pqtype.NullRawMessage) ([]toolResultRaw, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return nil, nil
}
var rows []toolResultRaw
if err := json.Unmarshal(raw.RawMessage, &rows); err != nil {
return nil, xerrors.Errorf("parse tool content: %w", err)
}
return rows, nil
}
func (r toolResultRaw) toToolResultPart() fantasy.ToolResultPart {
toolCallID := sanitizeToolCallID(r.ToolCallID)
resultText := string(r.Result)
if resultText == "" || resultText == "null" {
resultText = "{}"
}
if r.IsError {
message := strings.TrimSpace(resultText)
if extracted := extractErrorString(r.Result); extracted != "" {
message = extracted
}
return fantasy.ToolResultPart{
ToolCallID: toolCallID,
Output: fantasy.ToolResultOutputContentError{
Error: xerrors.New(message),
},
}
}
return fantasy.ToolResultPart{
ToolCallID: toolCallID,
Output: fantasy.ToolResultOutputContentText{
Text: resultText,
},
}
}
// extractErrorString pulls the "error" field from a JSON object if
// present, returning it as a string. Returns "" if the field is
// missing or the input is not an object.
func extractErrorString(raw json.RawMessage) string {
var fields map[string]json.RawMessage
if err := json.Unmarshal(raw, &fields); err != nil {
return ""
}
errField, ok := fields["error"]
if !ok {
return ""
}
var s string
if err := json.Unmarshal(errField, &s); err != nil {
return ""
}
return strings.TrimSpace(s)
}
// ToMessageParts converts fantasy content blocks into message parts.
func ToMessageParts(content []fantasy.Content) []fantasy.MessagePart {
parts := make([]fantasy.MessagePart, 0, len(content))
for _, block := range content {
switch value := block.(type) {
case fantasy.TextContent:
parts = append(parts, fantasy.TextPart{
Text: value.Text,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.TextContent:
parts = append(parts, fantasy.TextPart{
Text: value.Text,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case fantasy.ReasoningContent:
parts = append(parts, fantasy.ReasoningPart{
Text: value.Text,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.ReasoningContent:
parts = append(parts, fantasy.ReasoningPart{
Text: value.Text,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case fantasy.ToolCallContent:
parts = append(parts, fantasy.ToolCallPart{
ToolCallID: sanitizeToolCallID(value.ToolCallID),
ToolName: value.ToolName,
Input: value.Input,
ProviderExecuted: value.ProviderExecuted,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.ToolCallContent:
parts = append(parts, fantasy.ToolCallPart{
ToolCallID: sanitizeToolCallID(value.ToolCallID),
ToolName: value.ToolName,
Input: value.Input,
ProviderExecuted: value.ProviderExecuted,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case fantasy.FileContent:
parts = append(parts, fantasy.FilePart{
Data: value.Data,
MediaType: value.MediaType,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.FileContent:
parts = append(parts, fantasy.FilePart{
Data: value.Data,
MediaType: value.MediaType,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case fantasy.ToolResultContent:
parts = append(parts, fantasy.ToolResultPart{
ToolCallID: sanitizeToolCallID(value.ToolCallID),
Output: value.Result,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.ToolResultContent:
parts = append(parts, fantasy.ToolResultPart{
ToolCallID: sanitizeToolCallID(value.ToolCallID),
Output: value.Result,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
}
}
return parts
}
func normalizeAssistantToolCallInputs(
parts []fantasy.MessagePart,
) []fantasy.MessagePart {
normalized := make([]fantasy.MessagePart, 0, len(parts))
for _, part := range parts {
toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
if !ok {
normalized = append(normalized, part)
continue
}
toolCall.Input = normalizeToolCallInput(toolCall.Input)
normalized = append(normalized, toolCall)
}
return normalized
}
// normalizeToolCallInput guarantees tool call input is a JSON object string.
// Anthropic drops assistant tool calls with malformed input, which can leave
// following tool results orphaned.
func normalizeToolCallInput(input string) string {
input = strings.TrimSpace(input)
if input == "" {
return "{}"
}
var object map[string]any
if err := json.Unmarshal([]byte(input), &object); err != nil || object == nil {
return "{}"
}
return input
}
// ExtractToolCalls returns all tool call parts as content blocks.
func ExtractToolCalls(parts []fantasy.MessagePart) []fantasy.ToolCallContent {
toolCalls := make([]fantasy.ToolCallContent, 0, len(parts))
for _, part := range parts {
toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
if !ok {
continue
}
toolCalls = append(toolCalls, fantasy.ToolCallContent{
ToolCallID: toolCall.ToolCallID,
ToolName: toolCall.ToolName,
Input: toolCall.Input,
ProviderExecuted: toolCall.ProviderExecuted,
})
}
return toolCalls
}
// MarshalContent encodes message content blocks for persistence.
func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) {
if len(blocks) == 0 {
return pqtype.NullRawMessage{}, nil
}
encodedBlocks := make([]json.RawMessage, 0, len(blocks))
for i, block := range blocks {
encoded, err := marshalContentBlock(block)
if err != nil {
return pqtype.NullRawMessage{}, xerrors.Errorf(
"encode content block %d: %w",
i,
err,
)
}
encodedBlocks = append(encodedBlocks, encoded)
}
data, err := json.Marshal(encodedBlocks)
if err != nil {
return pqtype.NullRawMessage{}, xerrors.Errorf("encode content blocks: %w", err)
}
return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
}
// MarshalToolResult encodes a single tool result for persistence as
// an opaque JSON blob. The stored shape is
// [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…}].
func MarshalToolResult(toolCallID, toolName string, result json.RawMessage, isError bool) (pqtype.NullRawMessage, error) {
row := toolResultRaw{
ToolCallID: toolCallID,
ToolName: toolName,
Result: result,
IsError: isError,
}
data, err := json.Marshal([]toolResultRaw{row})
if err != nil {
return pqtype.NullRawMessage{}, xerrors.Errorf("encode tool result: %w", err)
}
return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
}
// MarshalToolResultContent encodes a fantasy tool result content
// block for persistence. It extracts the raw fields and delegates
// to MarshalToolResult.
func MarshalToolResultContent(content fantasy.ToolResultContent) (pqtype.NullRawMessage, error) {
var result json.RawMessage
var isError bool
switch output := content.Result.(type) {
case fantasy.ToolResultOutputContentError:
isError = true
if output.Error != nil {
result, _ = json.Marshal(map[string]any{"error": output.Error.Error()})
} else {
result = []byte(`{"error":""}`)
}
case fantasy.ToolResultOutputContentText:
result = json.RawMessage(output.Text)
if !json.Valid(result) {
result, _ = json.Marshal(map[string]any{"output": output.Text})
}
case fantasy.ToolResultOutputContentMedia:
result, _ = json.Marshal(map[string]any{
"data": output.Data,
"mime_type": output.MediaType,
"text": output.Text,
})
default:
result = []byte(`{}`)
}
return MarshalToolResult(content.ToolCallID, content.ToolName, result, isError)
}
// PartFromContent converts fantasy content into a SDK chat message part.
func PartFromContent(block fantasy.Content) codersdk.ChatMessagePart {
switch value := block.(type) {
case fantasy.TextContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: value.Text,
}
case *fantasy.TextContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: value.Text,
}
case fantasy.ReasoningContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: value.Text,
Title: reasoningSummaryTitle(value.ProviderMetadata),
}
case *fantasy.ReasoningContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: value.Text,
Title: reasoningSummaryTitle(value.ProviderMetadata),
}
case fantasy.ToolCallContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: value.ToolCallID,
ToolName: value.ToolName,
Args: []byte(value.Input),
}
case *fantasy.ToolCallContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: value.ToolCallID,
ToolName: value.ToolName,
Args: []byte(value.Input),
}
case fantasy.SourceContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSource,
SourceID: value.ID,
URL: value.URL,
Title: value.Title,
}
case *fantasy.SourceContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSource,
SourceID: value.ID,
URL: value.URL,
Title: value.Title,
}
case fantasy.FileContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeFile,
MediaType: value.MediaType,
Data: value.Data,
}
case *fantasy.FileContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeFile,
MediaType: value.MediaType,
Data: value.Data,
}
case fantasy.ToolResultContent:
return toolResultContentToPart(value)
case *fantasy.ToolResultContent:
return toolResultContentToPart(*value)
default:
return codersdk.ChatMessagePart{}
}
}
// ToolResultToPart converts a tool call ID, raw result, and error
// flag into a ChatMessagePart. This is the minimal conversion used
// both during streaming and when reading from the database.
func ToolResultToPart(toolCallID, toolName string, result json.RawMessage, isError bool) codersdk.ChatMessagePart {
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: toolCallID,
ToolName: toolName,
Result: result,
IsError: isError,
}
}
// toolResultContentToPart converts a fantasy ToolResultContent
// directly into a ChatMessagePart without an intermediate struct.
func toolResultContentToPart(content fantasy.ToolResultContent) codersdk.ChatMessagePart {
var result json.RawMessage
var isError bool
switch output := content.Result.(type) {
case fantasy.ToolResultOutputContentError:
isError = true
if output.Error != nil {
result, _ = json.Marshal(map[string]any{"error": output.Error.Error()})
} else {
result = []byte(`{"error":""}`)
}
case fantasy.ToolResultOutputContentText:
result = json.RawMessage(output.Text)
// Ensure valid JSON; wrap in an object if not.
if !json.Valid(result) {
result, _ = json.Marshal(map[string]any{"output": output.Text})
}
case fantasy.ToolResultOutputContentMedia:
result, _ = json.Marshal(map[string]any{
"data": output.Data,
"mime_type": output.MediaType,
"text": output.Text,
})
default:
result = []byte(`{}`)
}
return ToolResultToPart(content.ToolCallID, content.ToolName, result, isError)
}
// ReasoningTitleFromFirstLine extracts a compact markdown title.
func ReasoningTitleFromFirstLine(text string) string {
text = strings.TrimSpace(text)
if text == "" {
return ""
}
firstLine := text
if idx := strings.IndexAny(firstLine, "\r\n"); idx >= 0 {
firstLine = firstLine[:idx]
}
firstLine = strings.TrimSpace(firstLine)
if firstLine == "" || !strings.HasPrefix(firstLine, "**") {
return ""
}
rest := firstLine[2:]
end := strings.Index(rest, "**")
if end < 0 {
return ""
}
title := strings.TrimSpace(rest[:end])
if title == "" {
return ""
}
// Require the first line to be exactly "**title**" (ignoring
// surrounding whitespace) so providers without this format don't
// accidentally emit a title.
if strings.TrimSpace(rest[end+2:]) != "" {
return ""
}
return compactReasoningSummaryTitle(title)
}
func injectMissingToolResults(prompt []fantasy.Message) []fantasy.Message {
result := make([]fantasy.Message, 0, len(prompt))
for i := 0; i < len(prompt); i++ {
msg := prompt[i]
result = append(result, msg)
if msg.Role != fantasy.MessageRoleAssistant {
continue
}
toolCalls := ExtractToolCalls(msg.Content)
if len(toolCalls) == 0 {
continue
}
// Collect the tool call IDs that have results in the
// following tool message(s).
answered := make(map[string]struct{})
j := i + 1
for ; j < len(prompt); j++ {
if prompt[j].Role != fantasy.MessageRoleTool {
break
}
for _, part := range prompt[j].Content {
tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
if !ok {
continue
}
answered[tr.ToolCallID] = struct{}{}
}
}
if i+1 < j {
// Preserve persisted tool result ordering and inject any
// synthetic results after the existing contiguous tool messages.
result = append(result, prompt[i+1:j]...)
i = j - 1
}
// Build synthetic results for any unanswered tool calls.
var missing []fantasy.MessagePart
for _, tc := range toolCalls {
if _, ok := answered[tc.ToolCallID]; !ok {
missing = append(missing, fantasy.ToolResultPart{
ToolCallID: tc.ToolCallID,
Output: fantasy.ToolResultOutputContentError{
Error: xerrors.New("tool call was interrupted and did not receive a result"),
},
})
}
}
if len(missing) > 0 {
result = append(result, fantasy.Message{
Role: fantasy.MessageRoleTool,
Content: missing,
})
}
}
return result
}
func injectMissingToolUses(
prompt []fantasy.Message,
toolNameByCallID map[string]string,
) []fantasy.Message {
result := make([]fantasy.Message, 0, len(prompt))
for _, msg := range prompt {
if msg.Role != fantasy.MessageRoleTool {
result = append(result, msg)
continue
}
toolResults := make([]fantasy.ToolResultPart, 0, len(msg.Content))
for _, part := range msg.Content {
toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
if !ok {
continue
}
toolResults = append(toolResults, toolResult)
}
if len(toolResults) == 0 {
result = append(result, msg)
continue
}
// Walk backwards through the result to find the nearest
// preceding assistant message (skipping over other tool
// messages that belong to the same batch of results).
answeredByPrevious := make(map[string]struct{})
for k := len(result) - 1; k >= 0; k-- {
if result[k].Role == fantasy.MessageRoleAssistant {
for _, toolCall := range ExtractToolCalls(result[k].Content) {
toolCallID := sanitizeToolCallID(toolCall.ToolCallID)
if toolCallID == "" {
continue
}
answeredByPrevious[toolCallID] = struct{}{}
}
break
}
if result[k].Role != fantasy.MessageRoleTool {
break
}
}
matchingResults := make([]fantasy.ToolResultPart, 0, len(toolResults))
orphanResults := make([]fantasy.ToolResultPart, 0, len(toolResults))
for _, toolResult := range toolResults {
toolCallID := sanitizeToolCallID(toolResult.ToolCallID)
if _, ok := answeredByPrevious[toolCallID]; ok {
matchingResults = append(matchingResults, toolResult)
continue
}
orphanResults = append(orphanResults, toolResult)
}
if len(orphanResults) == 0 {
result = append(result, msg)
continue
}
syntheticToolUse := syntheticToolUseMessage(
orphanResults,
toolNameByCallID,
)
if len(syntheticToolUse.Content) == 0 {
result = append(result, msg)
continue
}
if len(matchingResults) > 0 {
result = append(result, toolMessageFromToolResultParts(matchingResults))
}
result = append(result, syntheticToolUse)
result = append(result, toolMessageFromToolResultParts(orphanResults))
}
return result
}
func toolMessageFromToolResultParts(results []fantasy.ToolResultPart) fantasy.Message {
parts := make([]fantasy.MessagePart, 0, len(results))
for _, result := range results {
parts = append(parts, result)
}
return fantasy.Message{
Role: fantasy.MessageRoleTool,
Content: parts,
}
}
func syntheticToolUseMessage(
toolResults []fantasy.ToolResultPart,
toolNameByCallID map[string]string,
) fantasy.Message {
parts := make([]fantasy.MessagePart, 0, len(toolResults))
seen := make(map[string]struct{}, len(toolResults))
for _, toolResult := range toolResults {
toolCallID := sanitizeToolCallID(toolResult.ToolCallID)
if toolCallID == "" {
continue
}
if _, ok := seen[toolCallID]; ok {
continue
}
toolName := strings.TrimSpace(toolNameByCallID[toolCallID])
if toolName == "" {
continue
}
seen[toolCallID] = struct{}{}
parts = append(parts, fantasy.ToolCallPart{
ToolCallID: toolCallID,
ToolName: toolName,
Input: "{}",
})
}
return fantasy.Message{
Role: fantasy.MessageRoleAssistant,
Content: parts,
}
}
func parseSystemContent(raw pqtype.NullRawMessage) (string, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return "", nil
}
var content string
if err := json.Unmarshal(raw.RawMessage, &content); err != nil {
return "", xerrors.Errorf("parse system message content: %w", err)
}
return content, nil
}
func sanitizeToolCallID(id string) string {
if id == "" {
return ""
}
return toolCallIDSanitizer.ReplaceAllString(id, "_")
}
func marshalContentBlock(block fantasy.Content) (json.RawMessage, error) {
encoded, err := json.Marshal(block)
if err != nil {
return nil, err
}
title, ok := reasoningTitleFromContent(block)
if !ok || title == "" {
return encoded, nil
}
var envelope struct {
Type string `json:"type"`
Data map[string]any `json:"data"`
}
if err := json.Unmarshal(encoded, &envelope); err != nil {
return nil, err
}
if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeReasoning)) {
return encoded, nil
}
if envelope.Data == nil {
envelope.Data = map[string]any{}
}
envelope.Data["title"] = title
encodedWithTitle, err := json.Marshal(envelope)
if err != nil {
return nil, err
}
return encodedWithTitle, nil
}
func reasoningTitleFromContent(block fantasy.Content) (string, bool) {
switch value := block.(type) {
case fantasy.ReasoningContent:
return ReasoningTitleFromFirstLine(value.Text), true
case *fantasy.ReasoningContent:
if value == nil {
return "", false
}
return ReasoningTitleFromFirstLine(value.Text), true
default:
return "", false
}
}
func reasoningSummaryTitle(metadata fantasy.ProviderMetadata) string {
if len(metadata) == 0 {
return ""
}
reasoningMetadata := fantasyopenai.GetReasoningMetadata(
fantasy.ProviderOptions(metadata),
)
if reasoningMetadata == nil {
return ""
}
for _, summary := range reasoningMetadata.Summary {
if title := compactReasoningSummaryTitle(summary); title != "" {
return title
}
}
return ""
}
func compactReasoningSummaryTitle(summary string) string {
const maxWords = 8
const maxRunes = 80
summary = strings.TrimSpace(summary)
if summary == "" {
return ""
}
summary = strings.Trim(summary, "\"'`")
summary = reasoningSummaryHeadline(summary)
words := strings.Fields(summary)
if len(words) == 0 {
return ""
}
truncated := false
if len(words) > maxWords {
words = words[:maxWords]
truncated = true
}
title := strings.Join(words, " ")
if truncated {
title += "…"
}
return truncateRunes(title, maxRunes)
}
func reasoningSummaryHeadline(summary string) string {
summary = strings.TrimSpace(summary)
if summary == "" {
return ""
}
// OpenAI summary_text may be markdown like:
// "**Title**\n\nLonger explanation ...".
// Keep only the heading segment for UI titles.
if idx := strings.Index(summary, "\n\n"); idx >= 0 {
summary = summary[:idx]
}
if idx := strings.IndexAny(summary, "\r\n"); idx >= 0 {
summary = summary[:idx]
}
summary = strings.TrimSpace(summary)
if summary == "" {
return ""
}
if strings.HasPrefix(summary, "**") {
rest := summary[2:]
if end := strings.Index(rest, "**"); end >= 0 {
bold := strings.TrimSpace(rest[:end])
if bold != "" {
summary = bold
}
}
}
return strings.TrimSpace(strings.Trim(summary, "\"'`"))
}
func truncateRunes(value string, maxLen int) string {
if maxLen <= 0 {
return ""
}
runes := []rune(value)
if len(runes) <= maxLen {
return value
}
return string(runes[:maxLen])
}
@@ -0,0 +1,91 @@
package chatprompt_test
import (
"encoding/json"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
)
func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input string
expected string
}{
{
name: "empty input",
input: "",
expected: "{}",
},
{
name: "invalid json",
input: "{\"command\":",
expected: "{}",
},
{
name: "non-object json",
input: "[]",
expected: "{}",
},
{
name: "valid object json",
input: "{\"command\":\"ls\"}",
expected: "{\"command\":\"ls\"}",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
assistantContent, err := chatprompt.MarshalContent([]fantasy.Content{
fantasy.ToolCallContent{
ToolCallID: "toolu_01C4PqN6F2493pi7Ebag8Vg7",
ToolName: "execute",
Input: tc.input,
},
})
require.NoError(t, err)
toolContent, err := chatprompt.MarshalToolResult(
"toolu_01C4PqN6F2493pi7Ebag8Vg7",
"execute",
json.RawMessage(`{"error":"tool call was interrupted before it produced a result"}`),
true,
)
require.NoError(t, err)
prompt, err := chatprompt.ConvertMessages([]database.ChatMessage{
{
Role: string(fantasy.MessageRoleAssistant),
Visibility: database.ChatMessageVisibilityBoth,
Content: assistantContent,
},
{
Role: string(fantasy.MessageRoleTool),
Visibility: database.ChatMessageVisibilityBoth,
Content: toolContent,
},
})
require.NoError(t, err)
require.Len(t, prompt, 2)
require.Equal(t, fantasy.MessageRoleAssistant, prompt[0].Role)
toolCalls := chatprompt.ExtractToolCalls(prompt[0].Content)
require.Len(t, toolCalls, 1)
require.Equal(t, tc.expected, toolCalls[0].Input)
require.Equal(t, "execute", toolCalls[0].ToolName)
require.Equal(t, "toolu_01C4PqN6F2493pi7Ebag8Vg7", toolCalls[0].ToolCallID)
require.Equal(t, fantasy.MessageRoleTool, prompt[1].Role)
})
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,191 @@
package chatprovider_test
import (
"testing"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
fantasyopenai "charm.land/fantasy/providers/openai"
fantasyopenrouter "charm.land/fantasy/providers/openrouter"
fantasyvercel "charm.land/fantasy/providers/vercel"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/codersdk"
)
func TestReasoningEffortFromChat(t *testing.T) {
t.Parallel()
tests := []struct {
name string
provider string
input *string
want *string
}{
{
name: "OpenAICaseInsensitive",
provider: "openai",
input: stringPtr(" HIGH "),
want: stringPtr(string(fantasyopenai.ReasoningEffortHigh)),
},
{
name: "AnthropicEffort",
provider: "anthropic",
input: stringPtr("max"),
want: stringPtr(string(fantasyanthropic.EffortMax)),
},
{
name: "OpenRouterEffort",
provider: "openrouter",
input: stringPtr("medium"),
want: stringPtr(string(fantasyopenrouter.ReasoningEffortMedium)),
},
{
name: "VercelEffort",
provider: "vercel",
input: stringPtr("xhigh"),
want: stringPtr(string(fantasyvercel.ReasoningEffortXHigh)),
},
{
name: "InvalidEffortReturnsNil",
provider: "openai",
input: stringPtr("unknown"),
want: nil,
},
{
name: "UnsupportedProviderReturnsNil",
provider: "bedrock",
input: stringPtr("high"),
want: nil,
},
{
name: "NilInputReturnsNil",
provider: "openai",
input: nil,
want: nil,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := chatprovider.ReasoningEffortFromChat(tt.provider, tt.input)
require.Equal(t, tt.want, got)
})
}
}
func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
t.Parallel()
options := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Enabled: boolPtr(true),
},
Provider: &codersdk.ChatModelOpenRouterProvider{
Order: []string{"openai"},
},
},
}
defaults := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Enabled: boolPtr(false),
Exclude: boolPtr(true),
MaxTokens: int64Ptr(123),
Effort: stringPtr("high"),
},
IncludeUsage: boolPtr(true),
Provider: &codersdk.ChatModelOpenRouterProvider{
Order: []string{"anthropic"},
AllowFallbacks: boolPtr(true),
RequireParameters: boolPtr(false),
DataCollection: stringPtr("allow"),
Only: []string{"openai"},
Ignore: []string{"foo"},
Quantizations: []string{"int8"},
Sort: stringPtr("latency"),
},
},
}
chatprovider.MergeMissingProviderOptions(&options, defaults)
require.NotNil(t, options)
require.NotNil(t, options.OpenRouter)
require.NotNil(t, options.OpenRouter.Reasoning)
require.True(t, *options.OpenRouter.Reasoning.Enabled)
require.Equal(t, true, *options.OpenRouter.Reasoning.Exclude)
require.EqualValues(t, 123, *options.OpenRouter.Reasoning.MaxTokens)
require.Equal(t, "high", *options.OpenRouter.Reasoning.Effort)
require.NotNil(t, options.OpenRouter.IncludeUsage)
require.True(t, *options.OpenRouter.IncludeUsage)
require.NotNil(t, options.OpenRouter.Provider)
require.Equal(t, []string{"openai"}, options.OpenRouter.Provider.Order)
require.NotNil(t, options.OpenRouter.Provider.AllowFallbacks)
require.True(t, *options.OpenRouter.Provider.AllowFallbacks)
require.NotNil(t, options.OpenRouter.Provider.RequireParameters)
require.False(t, *options.OpenRouter.Provider.RequireParameters)
require.Equal(t, "allow", *options.OpenRouter.Provider.DataCollection)
require.Equal(t, []string{"openai"}, options.OpenRouter.Provider.Only)
require.Equal(t, []string{"foo"}, options.OpenRouter.Provider.Ignore)
require.Equal(t, []string{"int8"}, options.OpenRouter.Provider.Quantizations)
require.Equal(t, "latency", *options.OpenRouter.Provider.Sort)
}
func TestMergeMissingCallConfig_FillsUnsetFields(t *testing.T) {
t.Parallel()
dst := codersdk.ChatModelCallConfig{
Temperature: float64Ptr(0.2),
ProviderOptions: &codersdk.ChatModelProviderOptions{
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
User: stringPtr("alice"),
},
},
}
defaults := codersdk.ChatModelCallConfig{
MaxOutputTokens: int64Ptr(512),
Temperature: float64Ptr(0.9),
TopP: float64Ptr(0.8),
ProviderOptions: &codersdk.ChatModelProviderOptions{
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
User: stringPtr("bob"),
ReasoningEffort: stringPtr("medium"),
},
},
}
chatprovider.MergeMissingCallConfig(&dst, defaults)
require.NotNil(t, dst.MaxOutputTokens)
require.EqualValues(t, 512, *dst.MaxOutputTokens)
require.NotNil(t, dst.Temperature)
require.Equal(t, 0.2, *dst.Temperature)
require.NotNil(t, dst.TopP)
require.Equal(t, 0.8, *dst.TopP)
require.NotNil(t, dst.ProviderOptions)
require.NotNil(t, dst.ProviderOptions.OpenAI)
require.Equal(t, "alice", *dst.ProviderOptions.OpenAI.User)
require.Equal(t, "medium", *dst.ProviderOptions.OpenAI.ReasoningEffort)
}
func stringPtr(value string) *string {
return &value
}
func boolPtr(value bool) *bool {
return &value
}
func int64Ptr(value int64) *int64 {
return &value
}
func float64Ptr(value float64) *float64 {
return &value
}
+403
View File
@@ -0,0 +1,403 @@
package chattest
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/google/uuid"
)
// AnthropicHandler handles Anthropic API requests and returns a response.
type AnthropicHandler func(req *AnthropicRequest) AnthropicResponse
// AnthropicResponse represents a response to an Anthropic request.
// Either StreamingChunks or Response should be set, not both.
type AnthropicResponse struct {
StreamingChunks <-chan AnthropicChunk
Response *AnthropicMessage
}
// AnthropicRequest represents an Anthropic messages request.
type AnthropicRequest struct {
*http.Request // Embed http.Request
Model string `json:"model"`
Messages []AnthropicRequestMessage `json:"messages"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
// TODO: encoding/json ignores inline tags. Add custom UnmarshalJSON to capture unknown keys.
Options map[string]interface{} `json:",inline"` //nolint:revive
}
// AnthropicRequestMessage represents a message in an Anthropic request.
// Content may be either a string or a structured content array.
type AnthropicRequestMessage struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
}
// AnthropicMessage represents a message in an Anthropic response.
type AnthropicMessage struct {
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Role string `json:"role"`
Content string `json:"content,omitempty"`
Model string `json:"model,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
Usage AnthropicUsage `json:"usage,omitempty"`
}
// AnthropicUsage represents usage information in an Anthropic response.
type AnthropicUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// AnthropicChunk represents a streaming chunk from Anthropic.
type AnthropicChunk struct {
Type string `json:"type"`
Index int `json:"index,omitempty"`
Message AnthropicChunkMessage `json:"message,omitempty"`
ContentBlock AnthropicContentBlock `json:"content_block,omitempty"`
Delta AnthropicDeltaBlock `json:"delta,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence *string `json:"stop_sequence,omitempty"`
Usage AnthropicUsage `json:"usage,omitempty"`
}
// AnthropicChunkMessage represents message metadata in a chunk.
type AnthropicChunkMessage struct {
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Model string `json:"model"`
}
// AnthropicContentBlock represents a content block in a chunk.
type AnthropicContentBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input json.RawMessage `json:"input,omitempty"`
}
// AnthropicDeltaBlock represents a delta block in a chunk.
type AnthropicDeltaBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
}
// anthropicServer is a test server that mocks the Anthropic API.
type anthropicServer struct {
mu sync.Mutex
server *httptest.Server
handler AnthropicHandler
request *AnthropicRequest
}
// NewAnthropic creates a new Anthropic test server with a handler function.
// The handler is called for each request and should return either a streaming
// response (via channel) or a non-streaming response.
// Returns the base URL of the server.
func NewAnthropic(t testing.TB, handler AnthropicHandler) string {
t.Helper()
s := &anthropicServer{
handler: handler,
}
mux := http.NewServeMux()
mux.HandleFunc("POST /v1/messages", s.handleMessages)
s.server = httptest.NewServer(mux)
t.Cleanup(func() {
s.server.Close()
})
return s.server.URL
}
func (s *anthropicServer) handleMessages(w http.ResponseWriter, r *http.Request) {
var req AnthropicRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
// Return a more detailed error for debugging
http.Error(w, fmt.Sprintf("decode request: %v", err), http.StatusBadRequest)
return
}
req.Request = r // Embed the original http.Request
s.mu.Lock()
s.request = &req
s.mu.Unlock()
resp := s.handler(&req)
s.writeResponse(w, &req, resp)
}
func (s *anthropicServer) writeResponse(w http.ResponseWriter, req *AnthropicRequest, resp AnthropicResponse) {
hasStreaming := resp.StreamingChunks != nil
hasNonStreaming := resp.Response != nil
switch {
case hasStreaming && hasNonStreaming:
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
return
case !hasStreaming && !hasNonStreaming:
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
return
case req.Stream && !hasStreaming:
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
return
case !req.Stream && !hasNonStreaming:
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
return
case hasStreaming:
s.writeStreamingResponse(w, resp.StreamingChunks)
default:
s.writeNonStreamingResponse(w, resp.Response)
}
}
func (s *anthropicServer) writeStreamingResponse(w http.ResponseWriter, chunks <-chan AnthropicChunk) {
_ = s // receiver unused but kept for consistency
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("anthropic-version", "2023-06-01")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
for chunk := range chunks {
chunkData := make(map[string]interface{})
chunkData["type"] = chunk.Type
switch chunk.Type {
case "message_start":
chunkData["message"] = chunk.Message
case "content_block_start":
chunkData["index"] = chunk.Index
chunkData["content_block"] = chunk.ContentBlock
case "content_block_delta":
chunkData["index"] = chunk.Index
chunkData["delta"] = chunk.Delta
case "content_block_stop":
chunkData["index"] = chunk.Index
case "message_delta":
chunkData["delta"] = map[string]interface{}{
"stop_reason": chunk.StopReason,
"stop_sequence": chunk.StopSequence,
}
chunkData["usage"] = chunk.Usage
case "message_stop":
// No additional fields
}
chunkBytes, err := json.Marshal(chunkData)
if err != nil {
return
}
// Send both event and data lines to match Anthropic API format
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", chunk.Type, chunkBytes); err != nil {
return
}
flusher.Flush()
}
}
func (s *anthropicServer) writeNonStreamingResponse(w http.ResponseWriter, resp *AnthropicMessage) {
_ = s // receiver unused but kept for consistency
response := map[string]interface{}{
"id": resp.ID,
"type": resp.Type,
"role": resp.Role,
"model": resp.Model,
"content": []map[string]interface{}{
{
"type": "text",
"text": resp.Content,
},
},
"stop_reason": resp.StopReason,
"usage": resp.Usage,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("anthropic-version", "2023-06-01")
_ = json.NewEncoder(w).Encode(response)
}
// AnthropicStreamingResponse creates a streaming response from chunks.
func AnthropicStreamingResponse(chunks ...AnthropicChunk) AnthropicResponse {
ch := make(chan AnthropicChunk, len(chunks))
go func() {
for _, chunk := range chunks {
ch <- chunk
}
close(ch)
}()
return AnthropicResponse{StreamingChunks: ch}
}
// AnthropicNonStreamingResponse creates a non-streaming response with the given text.
func AnthropicNonStreamingResponse(text string) AnthropicResponse {
return AnthropicResponse{
Response: &AnthropicMessage{
ID: fmt.Sprintf("msg-%s", uuid.New().String()[:8]),
Type: "message",
Role: "assistant",
Content: text,
Model: "claude-3-opus-20240229",
StopReason: "end_turn",
Usage: AnthropicUsage{
InputTokens: 10,
OutputTokens: 5,
},
},
}
}
// AnthropicTextChunks creates a complete streaming response with text deltas.
// Takes text deltas and creates all required chunks (message_start,
// content_block_start, content_block_delta for each delta,
// content_block_stop, message_delta, message_stop).
func AnthropicTextChunks(deltas ...string) []AnthropicChunk {
if len(deltas) == 0 {
return nil
}
messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8])
model := "claude-3-opus-20240229"
chunks := []AnthropicChunk{
{
Type: "message_start",
Message: AnthropicChunkMessage{
ID: messageID,
Type: "message",
Role: "assistant",
Model: model,
},
},
{
Type: "content_block_start",
Index: 0,
ContentBlock: AnthropicContentBlock{
Type: "text",
Text: "", // According to Anthropic API spec, text should be empty in content_block_start
},
},
}
// Add a delta chunk for each delta
for _, delta := range deltas {
chunks = append(chunks, AnthropicChunk{
Type: "content_block_delta",
Index: 0,
Delta: AnthropicDeltaBlock{
Type: "text_delta",
Text: delta,
},
})
}
chunks = append(chunks,
AnthropicChunk{
Type: "content_block_stop",
Index: 0,
},
AnthropicChunk{
Type: "message_delta",
StopReason: "end_turn",
Usage: AnthropicUsage{
InputTokens: 10,
OutputTokens: 5,
},
},
AnthropicChunk{
Type: "message_stop",
},
)
return chunks
}
// AnthropicToolCallChunks creates a complete streaming response for a tool call.
// Input JSON can be split across multiple deltas, matching Anthropic's
// input_json_delta streaming behavior.
func AnthropicToolCallChunks(toolName string, inputJSONDeltas ...string) []AnthropicChunk {
if len(inputJSONDeltas) == 0 {
return nil
}
if toolName == "" {
toolName = "tool"
}
messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8])
model := "claude-3-opus-20240229"
toolCallID := fmt.Sprintf("toolu_%s", uuid.New().String()[:8])
chunks := []AnthropicChunk{
{
Type: "message_start",
Message: AnthropicChunkMessage{
ID: messageID,
Type: "message",
Role: "assistant",
Model: model,
},
},
{
Type: "content_block_start",
Index: 0,
ContentBlock: AnthropicContentBlock{
Type: "tool_use",
ID: toolCallID,
Name: toolName,
Input: json.RawMessage("{}"),
},
},
}
for _, delta := range inputJSONDeltas {
chunks = append(chunks, AnthropicChunk{
Type: "content_block_delta",
Index: 0,
Delta: AnthropicDeltaBlock{
Type: "input_json_delta",
PartialJSON: delta,
},
})
}
chunks = append(chunks,
AnthropicChunk{
Type: "content_block_stop",
Index: 0,
},
AnthropicChunk{
Type: "message_delta",
StopReason: "tool_use",
Usage: AnthropicUsage{
InputTokens: 10,
OutputTokens: 5,
},
},
AnthropicChunk{
Type: "message_stop",
},
)
return chunks
}
+221
View File
@@ -0,0 +1,221 @@
package chattest_test
import (
"context"
"sync/atomic"
"testing"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chattest"
)
func TestAnthropic_Streaming(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicStreamingResponse(
chattest.AnthropicTextChunks("Hello", " world", "!")...,
)
})
// Create fantasy client pointing to our test server
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "claude-3-opus-20240229")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Say hello"},
},
},
},
}
stream, err := model.Stream(ctx, call)
require.NoError(t, err)
expectedDeltas := []string{"Hello", " world", "!"}
deltaIndex := 0
var allParts []fantasy.StreamPart
for part := range stream {
allParts = append(allParts, part)
if part.Type == fantasy.StreamPartTypeTextDelta {
require.Less(t, deltaIndex, len(expectedDeltas), "Received more deltas than expected")
require.Equal(t, expectedDeltas[deltaIndex], part.Delta,
"Delta at index %d should be %q, got %q", deltaIndex, expectedDeltas[deltaIndex], part.Delta)
deltaIndex++
}
}
require.Equal(t, len(expectedDeltas), deltaIndex, "Expected %d deltas, got %d. Total parts received: %d", len(expectedDeltas), deltaIndex, len(allParts))
}
func TestAnthropic_ToolCalls(t *testing.T) {
t.Parallel()
var requestCount atomic.Int32
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
switch requestCount.Add(1) {
case 1:
return chattest.AnthropicStreamingResponse(
chattest.AnthropicToolCallChunks("get_weather", `{"location":"San Francisco"}`)...,
)
default:
return chattest.AnthropicStreamingResponse(
chattest.AnthropicTextChunks("The weather in San Francisco is 72F.")...,
)
}
})
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
require.NoError(t, err)
type weatherInput struct {
Location string `json:"location"`
}
var toolCallCount atomic.Int32
weatherTool := fantasy.NewAgentTool(
"get_weather",
"Get weather for a location.",
func(ctx context.Context, input weatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
toolCallCount.Add(1)
require.Equal(t, "San Francisco", input.Location)
return fantasy.NewTextResponse("72F"), nil
},
)
agent := fantasy.NewAgent(
model,
fantasy.WithSystemPrompt("You are a helpful assistant."),
fantasy.WithTools(weatherTool),
)
result, err := agent.Stream(context.Background(), fantasy.AgentStreamCall{
Prompt: "What's the weather in San Francisco?",
})
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, int32(1), toolCallCount.Load(), "expected exactly one tool execution")
require.GreaterOrEqual(t, requestCount.Load(), int32(2), "expected follow-up model call after tool execution")
}
func TestAnthropic_NonStreaming(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicNonStreamingResponse("Response text")
})
// Create fantasy client pointing to our test server
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "claude-3-opus-20240229")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Test message"},
},
},
},
}
response, err := model.Generate(ctx, call)
require.NoError(t, err)
require.NotNil(t, response)
}
func TestAnthropic_Streaming_MismatchReturnsErrorPart(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicNonStreamingResponse("wrong response type")
})
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
require.NoError(t, err)
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.NoError(t, err)
var streamErr error
for part := range stream {
if part.Type == fantasy.StreamPartTypeError {
streamErr = part.Error
break
}
}
require.Error(t, streamErr)
require.Contains(t, streamErr.Error(), "500 Internal Server Error")
}
func TestAnthropic_NonStreaming_MismatchReturnsError(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicStreamingResponse(
chattest.AnthropicTextChunks("wrong", " response")...,
)
})
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
require.NoError(t, err)
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "500 Internal Server Error")
}
+458
View File
@@ -0,0 +1,458 @@
package chattest
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/google/uuid"
)
// OpenAIHandler handles OpenAI API requests and returns a response.
type OpenAIHandler func(req *OpenAIRequest) OpenAIResponse
// OpenAIResponse represents a response to an OpenAI request.
// Either StreamingChunks or Response should be set, not both.
type OpenAIResponse struct {
StreamingChunks <-chan OpenAIChunk
Response *OpenAICompletion
}
// OpenAIRequest represents an OpenAI chat completion request.
type OpenAIRequest struct {
*http.Request
Model string `json:"model"`
Messages []OpenAIMessage `json:"messages"`
Stream bool `json:"stream,omitempty"`
Prompt []interface{} `json:"prompt,omitempty"` // For responses API
// TODO: encoding/json ignores inline tags. Add custom UnmarshalJSON to capture unknown keys.
Options map[string]interface{} `json:",inline"` //nolint:revive
}
// OpenAIMessage represents a message in an OpenAI request.
type OpenAIMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// OpenAIToolCallFunction represents the function details in a tool call.
type OpenAIToolCallFunction struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
}
// OpenAIToolCall represents a tool call in a streaming chunk or completion.
type OpenAIToolCall struct {
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Function OpenAIToolCallFunction `json:"function,omitempty"`
Index int `json:"index,omitempty"` // For streaming deltas
}
// OpenAIChunkChoice represents a choice in a streaming chunk.
type OpenAIChunkChoice struct {
Index int `json:"index"`
Delta string `json:"delta,omitempty"`
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
}
// OpenAIChunk represents a streaming chunk from OpenAI.
type OpenAIChunk struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAIChunkChoice `json:"choices"`
}
// OpenAICompletionChoice represents a choice in a completion response.
type OpenAICompletionChoice struct {
Index int `json:"index"`
Message OpenAIMessage `json:"message"`
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason"`
}
// OpenAICompletionUsage represents usage information in a completion response.
type OpenAICompletionUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
// OpenAICompletion represents a non-streaming OpenAI completion response.
type OpenAICompletion struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAICompletionChoice `json:"choices"`
Usage OpenAICompletionUsage `json:"usage"`
}
// openAIServer is a test server that mocks the OpenAI API.
type openAIServer struct {
mu sync.Mutex
server *httptest.Server
handler OpenAIHandler
request *OpenAIRequest
}
// NewOpenAI creates a new OpenAI test server with a handler function.
// The handler is called for each request and should return either a streaming
// response (via channel) or a non-streaming response.
// Returns the base URL of the server.
func NewOpenAI(t testing.TB, handler OpenAIHandler) string {
t.Helper()
s := &openAIServer{
handler: handler,
}
mux := http.NewServeMux()
mux.HandleFunc("POST /chat/completions", s.handleChatCompletions)
mux.HandleFunc("POST /responses", s.handleResponses)
s.server = httptest.NewServer(mux)
t.Cleanup(func() {
s.server.Close()
})
return s.server.URL
}
func (s *openAIServer) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
var req OpenAIRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
req.Request = r
s.mu.Lock()
s.request = &req
s.mu.Unlock()
resp := s.handler(&req)
s.writeChatCompletionsResponse(w, &req, resp)
}
func (s *openAIServer) handleResponses(w http.ResponseWriter, r *http.Request) {
var req OpenAIRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
req.Request = r
s.mu.Lock()
s.request = &req
s.mu.Unlock()
resp := s.handler(&req)
s.writeResponsesAPIResponse(w, &req, resp)
}
func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
hasStreaming := resp.StreamingChunks != nil
hasNonStreaming := resp.Response != nil
switch {
case hasStreaming && hasNonStreaming:
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
return
case !hasStreaming && !hasNonStreaming:
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
return
case req.Stream && !hasStreaming:
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
return
case !req.Stream && !hasNonStreaming:
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
return
case hasStreaming:
s.writeChatCompletionsStreaming(w, resp.StreamingChunks)
default:
s.writeChatCompletionsNonStreaming(w, resp.Response)
}
}
func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
hasStreaming := resp.StreamingChunks != nil
hasNonStreaming := resp.Response != nil
switch {
case hasStreaming && hasNonStreaming:
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
return
case !hasStreaming && !hasNonStreaming:
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
return
case req.Stream && !hasStreaming:
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
return
case !req.Stream && !hasNonStreaming:
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
return
case hasStreaming:
s.writeResponsesAPIStreaming(w, resp.StreamingChunks)
default:
s.writeResponsesAPINonStreaming(w, resp.Response)
}
}
func (s *openAIServer) writeChatCompletionsStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) {
_ = s // receiver unused but kept for consistency
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
for chunk := range chunks {
choicesData := make([]map[string]interface{}, len(chunk.Choices))
for i, choice := range chunk.Choices {
choiceData := map[string]interface{}{
"index": choice.Index,
}
if choice.Delta != "" {
choiceData["delta"] = map[string]interface{}{
"content": choice.Delta,
}
}
if len(choice.ToolCalls) > 0 {
// Tool calls come in the delta
if choiceData["delta"] == nil {
choiceData["delta"] = make(map[string]interface{})
}
delta, ok := choiceData["delta"].(map[string]interface{})
if !ok {
delta = make(map[string]interface{})
choiceData["delta"] = delta
}
delta["tool_calls"] = choice.ToolCalls
}
if choice.FinishReason != "" {
choiceData["finish_reason"] = choice.FinishReason
}
choicesData[i] = choiceData
}
chunkData := map[string]interface{}{
"id": chunk.ID,
"object": chunk.Object,
"created": chunk.Created,
"model": chunk.Model,
"choices": choicesData,
}
chunkBytes, err := json.Marshal(chunkData)
if err != nil {
return
}
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
return
}
flusher.Flush()
}
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
}
func (s *openAIServer) writeResponsesAPIStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) {
_ = s // receiver unused but kept for consistency
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
itemIDs := make(map[int]string)
for chunk := range chunks {
// Responses API sends one event per choice
for outputIndex, choice := range chunk.Choices {
if choice.Index != 0 {
outputIndex = choice.Index
}
itemID, found := itemIDs[outputIndex]
if !found {
itemID = fmt.Sprintf("msg_%s", uuid.New().String()[:8])
itemIDs[outputIndex] = itemID
}
chunkData := map[string]interface{}{
"type": "response.output_text.delta",
"item_id": itemID,
"output_index": outputIndex,
"created": chunk.Created,
"model": chunk.Model,
"content_index": 0,
"delta": choice.Delta,
}
chunkBytes, err := json.Marshal(chunkData)
if err != nil {
return
}
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
return
}
flusher.Flush()
}
}
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
}
func (s *openAIServer) writeChatCompletionsNonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
_ = s // receiver unused but kept for consistency
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}
func (s *openAIServer) writeResponsesAPINonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
_ = s // receiver unused but kept for consistency
// Convert all choices to output format
outputs := make([]map[string]interface{}, len(resp.Choices))
for i, choice := range resp.Choices {
outputs[i] = map[string]interface{}{
"id": uuid.New().String(),
"type": "message",
"role": "assistant",
"content": []map[string]interface{}{
{
"type": "output_text",
"text": choice.Message.Content,
},
},
}
}
response := map[string]interface{}{
"id": resp.ID,
"object": "response",
"created": resp.Created,
"model": resp.Model,
"output": outputs,
"usage": resp.Usage,
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(response)
}
// OpenAIStreamingResponse creates a streaming response from chunks.
func OpenAIStreamingResponse(chunks ...OpenAIChunk) OpenAIResponse {
ch := make(chan OpenAIChunk, len(chunks))
go func() {
for _, chunk := range chunks {
ch <- chunk
}
close(ch)
}()
return OpenAIResponse{StreamingChunks: ch}
}
// OpenAINonStreamingResponse creates a non-streaming response with the given text.
func OpenAINonStreamingResponse(text string) OpenAIResponse {
return OpenAIResponse{
Response: &OpenAICompletion{
ID: fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8]),
Object: "chat.completion",
Created: time.Now().Unix(),
Model: "gpt-4",
Choices: []OpenAICompletionChoice{
{
Index: 0,
Message: OpenAIMessage{
Role: "assistant",
Content: text,
},
FinishReason: "stop",
},
},
Usage: OpenAICompletionUsage{
PromptTokens: 10,
CompletionTokens: 5,
TotalTokens: 15,
},
},
}
}
// OpenAITextChunks creates streaming chunks with text deltas.
// Each delta string becomes a separate chunk with a single choice.
// Returns a slice of chunks, one per delta, with each choice having its index (0, 1, 2, ...).
func OpenAITextChunks(deltas ...string) []OpenAIChunk {
if len(deltas) == 0 {
return nil
}
chunkID := fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8])
now := time.Now().Unix()
chunks := make([]OpenAIChunk, len(deltas))
for i, delta := range deltas {
chunks[i] = OpenAIChunk{
ID: chunkID,
Object: "chat.completion.chunk",
Created: now,
Model: "gpt-4",
Choices: []OpenAIChunkChoice{
{
Index: i,
Delta: delta,
},
},
}
}
return chunks
}
// OpenAIToolCallChunk creates a streaming chunk with a tool call.
// Takes the tool name and arguments JSON string, creates a tool call for choice index 0.
func OpenAIToolCallChunk(toolName, arguments string) OpenAIChunk {
return OpenAIChunk{
ID: fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8]),
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
Model: "gpt-4",
Choices: []OpenAIChunkChoice{
{
Index: 0,
ToolCalls: []OpenAIToolCall{
{
Index: 0,
ID: fmt.Sprintf("call_%s", uuid.New().String()[:8]),
Type: "function",
Function: OpenAIToolCallFunction{
Name: toolName,
Arguments: arguments,
},
},
},
},
},
}
}
+367
View File
@@ -0,0 +1,367 @@
package chattest_test
import (
"context"
"sync/atomic"
"testing"
"charm.land/fantasy"
fantasyopenai "charm.land/fantasy/providers/openai"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chattest"
)
func TestOpenAI_Streaming(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAIStreamingResponse(
append(
append(
chattest.OpenAITextChunks("Hello", "Hi"),
chattest.OpenAITextChunks(" world", " there")...,
),
chattest.OpenAITextChunks("!", "!")...,
)...,
)
})
// Create fantasy client pointing to our test server
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Say hello"},
},
},
},
}
stream, err := model.Stream(ctx, call)
require.NoError(t, err)
// We expect chunks in order: one choice per chunk
// So we get: "Hello" (choice 0), "Hi" (choice 1), " world" (choice 0), " there" (choice 1), "!" (choice 0), "!" (choice 1)
expectedDeltas := []string{"Hello", "Hi", " world", " there", "!", "!"}
deltaIndex := 0
for part := range stream {
if part.Type == fantasy.StreamPartTypeTextDelta {
// Verify we're getting deltas in the expected order
require.Less(t, deltaIndex, len(expectedDeltas), "Received more deltas than expected")
require.Equal(t, expectedDeltas[deltaIndex], part.Delta,
"Delta at index %d should be %q, got %q", deltaIndex, expectedDeltas[deltaIndex], part.Delta)
deltaIndex++
}
}
// Verify we received all expected deltas
require.Equal(t, len(expectedDeltas), deltaIndex, "Expected %d deltas, got %d", len(expectedDeltas), deltaIndex)
}
func TestOpenAI_Streaming_ResponsesAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAIStreamingResponse(
append(
append(
chattest.OpenAITextChunks("First", "Second"),
chattest.OpenAITextChunks(" output", " output")...,
),
chattest.OpenAITextChunks("!", "!")...,
)...,
)
})
// Create fantasy client pointing to our test server (responses API)
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
fantasyopenai.WithUseResponsesAPI(),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Say hello"},
},
},
},
}
stream, err := model.Stream(ctx, call)
require.NoError(t, err)
var parts []fantasy.StreamPart
for part := range stream {
parts = append(parts, part)
}
// Verify we received the chunks in order
require.Greater(t, len(parts), 0)
// Extract text deltas from parts and verify they match expected chunks in order
// We expect: "First", " output", "!" for choice 0, and "Second", " output", "!" for choice 1
var allDeltas []string
for _, part := range parts {
if part.Type == fantasy.StreamPartTypeTextDelta {
allDeltas = append(allDeltas, part.Delta)
}
}
// Verify we received deltas (responses API may handle multiple choices differently)
// If we got text deltas, verify the content
if len(allDeltas) > 0 {
allText := ""
for _, delta := range allDeltas {
allText += delta
}
require.Contains(t, allText, "First")
require.Contains(t, allText, "Second")
require.Contains(t, allText, "output")
require.Contains(t, allText, "!")
} else {
// If no text deltas, at least verify we got some parts (may be different format)
require.Greater(t, len(parts), 0, "Expected at least one stream part")
}
}
func TestOpenAI_NonStreaming_CompletionsAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAINonStreamingResponse("First response")
})
// Create fantasy client pointing to our test server (completions API)
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Test message"},
},
},
},
}
response, err := model.Generate(ctx, call)
require.NoError(t, err)
require.NotNil(t, response)
}
func TestOpenAI_ToolCalls(t *testing.T) {
t.Parallel()
var requestCount atomic.Int32
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
switch requestCount.Add(1) {
case 1:
return chattest.OpenAIStreamingResponse(
chattest.OpenAIToolCallChunk("get_weather", `{"location":"San Francisco"}`),
)
default:
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("The weather in San Francisco is 72F.")...,
)
}
})
// Create fantasy client pointing to our test server
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
type weatherInput struct {
Location string `json:"location"`
}
var toolCallCount atomic.Int32
weatherTool := fantasy.NewAgentTool(
"get_weather",
"Get weather for a location.",
func(ctx context.Context, input weatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
toolCallCount.Add(1)
require.Equal(t, "San Francisco", input.Location)
return fantasy.NewTextResponse("72F"), nil
},
)
agent := fantasy.NewAgent(
model,
fantasy.WithSystemPrompt("You are a helpful assistant."),
fantasy.WithTools(weatherTool),
)
result, err := agent.Stream(ctx, fantasy.AgentStreamCall{
Prompt: "What's the weather in San Francisco?",
})
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, int32(1), toolCallCount.Load(), "expected exactly one tool execution")
require.GreaterOrEqual(t, requestCount.Load(), int32(2), "expected follow-up model call after tool execution")
}
func TestOpenAI_NonStreaming_ResponsesAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAINonStreamingResponse("First output")
})
// Create fantasy client pointing to our test server (responses API)
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
fantasyopenai.WithUseResponsesAPI(),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Test message"},
},
},
},
}
response, err := model.Generate(ctx, call)
require.NoError(t, err)
require.NotNil(t, response)
}
func TestOpenAI_Streaming_MismatchReturnsErrorPart(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAINonStreamingResponse("wrong response type")
})
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "gpt-4")
require.NoError(t, err)
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.NoError(t, err)
var streamErr error
for part := range stream {
if part.Type == fantasy.StreamPartTypeError {
streamErr = part.Error
break
}
}
require.Error(t, streamErr)
require.Contains(t, streamErr.Error(), "non-streaming response for streaming request")
}
func TestOpenAI_NonStreaming_MismatchReturnsError_CompletionsAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("wrong response type")...)
})
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "gpt-4")
require.NoError(t, err)
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "streaming response for non-streaming request")
}
func TestOpenAI_NonStreaming_MismatchReturnsError_ResponsesAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("wrong response type")...)
})
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
fantasyopenai.WithUseResponsesAPI(),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "gpt-4")
require.NoError(t, err)
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "streaming response for non-streaming request")
}
+33
View File
@@ -0,0 +1,33 @@
package chattool
import (
"encoding/json"
"unicode/utf8"
"charm.land/fantasy"
)
// toolResponse builds a fantasy.ToolResponse from a JSON-serializable
// result payload.
func toolResponse(result map[string]any) fantasy.ToolResponse {
data, err := json.Marshal(result)
if err != nil {
return fantasy.NewTextResponse("{}")
}
return fantasy.NewTextResponse(string(data))
}
func truncateRunes(value string, maxLen int) string {
if maxLen <= 0 || value == "" {
return ""
}
if utf8.RuneCountInString(value) <= maxLen {
return value
}
runes := []rune(value)
if maxLen > len(runes) {
maxLen = len(runes)
}
return string(runes[:maxLen])
}
+426
View File
@@ -0,0 +1,426 @@
package chattool
import (
"context"
"database/sql"
"fmt"
"strings"
"sync"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/util/namesgenerator"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
// buildPollInterval is how often we check if the workspace
// build has completed.
buildPollInterval = 2 * time.Second
// buildTimeout is the maximum time to wait for a workspace
// build to complete before giving up.
buildTimeout = 10 * time.Minute
// agentConnectTimeout is the maximum time to wait for the
// workspace agent to become reachable after a successful build.
agentConnectTimeout = 2 * time.Minute
// agentRetryInterval is how often we retry connecting to the
// workspace agent.
agentRetryInterval = 2 * time.Second
// agentAttemptTimeout is the timeout for a single connection
// attempt to the workspace agent during the retry loop.
agentAttemptTimeout = 5 * time.Second
// agentPingTimeout is the timeout for a single agent ping
// when checking whether an existing workspace is alive.
agentPingTimeout = 5 * time.Second
)
// CreateWorkspaceFn creates a workspace for the given owner.
type CreateWorkspaceFn func(
ctx context.Context,
ownerID uuid.UUID,
req codersdk.CreateWorkspaceRequest,
) (codersdk.Workspace, error)
// AgentConnFunc provides access to workspace agent connections.
type AgentConnFunc func(
ctx context.Context,
agentID uuid.UUID,
) (workspacesdk.AgentConn, func(), error)
// CreateWorkspaceOptions configures the create_workspace tool.
type CreateWorkspaceOptions struct {
DB database.Store
OwnerID uuid.UUID
ChatID uuid.UUID
CreateFn CreateWorkspaceFn
AgentConnFn AgentConnFunc
WorkspaceMu *sync.Mutex
}
type createWorkspaceArgs struct {
TemplateID string `json:"template_id"`
Name string `json:"name,omitempty"`
Parameters map[string]string `json:"parameters,omitempty"`
}
// CreateWorkspace returns a tool that creates a new workspace from a
// template. The tool is idempotent: if the chat already has a
// workspace that is building or running, it returns the existing
// workspace instead of creating a new one. A mutex prevents parallel
// calls from creating duplicate workspaces.
func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"create_workspace",
"Create a new workspace from a template. Requires a "+
"template_id (from list_templates). Optionally provide "+
"a name and parameter values (from read_template). "+
"If no name is given, one will be generated. "+
"This tool is idempotent — if the chat already has a "+
"workspace that is building or running, the existing "+
"workspace is returned.",
func(ctx context.Context, args createWorkspaceArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.CreateFn == nil {
return fantasy.NewTextErrorResponse("workspace creator is not configured"), nil
}
templateIDStr := strings.TrimSpace(args.TemplateID)
if templateIDStr == "" {
return fantasy.NewTextErrorResponse("template_id is required; use list_templates to find one"), nil
}
templateID, err := uuid.Parse(templateIDStr)
if err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("invalid template_id: %w", err).Error(),
), nil
}
// Serialize workspace creation to prevent parallel
// tool calls from creating duplicate workspaces.
if options.WorkspaceMu != nil {
options.WorkspaceMu.Lock()
defer options.WorkspaceMu.Unlock()
}
// Check for an existing workspace on the chat.
if options.DB != nil && options.ChatID != uuid.Nil {
existing, done, existErr := checkExistingWorkspace(
ctx, options.DB, options.ChatID,
options.AgentConnFn,
)
if existErr != nil {
return fantasy.NewTextErrorResponse(existErr.Error()), nil
}
if done {
return toolResponse(existing), nil
}
}
ownerID := options.OwnerID
// Set up dbauthz context for DB lookups.
if options.DB != nil {
ownerCtx, ownerErr := asOwner(ctx, options.DB, ownerID)
if ownerErr != nil {
return fantasy.NewTextErrorResponse(ownerErr.Error()), nil
}
ctx = ownerCtx
}
createReq := codersdk.CreateWorkspaceRequest{
TemplateID: templateID,
}
// Resolve workspace name.
name := strings.TrimSpace(args.Name)
if name == "" {
seed := "workspace"
if options.DB != nil {
if t, lookupErr := options.DB.GetTemplateByID(ctx, templateID); lookupErr == nil {
seed = t.Name
}
}
name = generatedWorkspaceName(seed)
} else if err := codersdk.NameValid(name); err != nil {
name = generatedWorkspaceName(name)
}
createReq.Name = name
// Map parameters.
for k, v := range args.Parameters {
createReq.RichParameterValues = append(
createReq.RichParameterValues,
codersdk.WorkspaceBuildParameter{Name: k, Value: v},
)
}
workspace, err := options.CreateFn(ctx, ownerID, createReq)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
// Wait for the build to complete and the agent to
// come online so subsequent tools can use the
// workspace immediately.
if options.DB != nil {
if err := waitForBuild(ctx, options.DB, workspace.ID); err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("workspace build failed: %w", err).Error(),
), nil
}
}
// Look up the first agent so we can link it to the chat.
workspaceAgentID := uuid.Nil
if options.DB != nil {
agents, agentErr := options.DB.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID)
if agentErr == nil && len(agents) > 0 {
workspaceAgentID = agents[0].ID
}
}
// Persist workspace + agent association on the chat.
if options.DB != nil && options.ChatID != uuid.Nil {
_, _ = options.DB.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{
ID: options.ChatID,
WorkspaceID: uuid.NullUUID{
UUID: workspace.ID,
Valid: true,
},
WorkspaceAgentID: uuid.NullUUID{
UUID: workspaceAgentID,
Valid: workspaceAgentID != uuid.Nil,
},
})
}
// Wait for the agent to come online.
if workspaceAgentID != uuid.Nil && options.AgentConnFn != nil {
if err := waitForAgent(ctx, options.AgentConnFn, workspaceAgentID); err != nil {
// Non-fatal: the workspace was created
// successfully, the agent just isn't ready
// yet. The model can retry.
return toolResponse(map[string]any{
"created": true,
"workspace_name": workspace.FullName(),
"agent_status": "not_ready",
"agent_error": err.Error(),
}), nil
}
}
return toolResponse(map[string]any{
"created": true,
"workspace_name": workspace.FullName(),
}), nil
},
)
}
// checkExistingWorkspace checks whether the chat already has a usable
// workspace. Returns the result map and true if the caller should
// return early (workspace exists and is alive or building). Returns
// false if the caller should proceed with creation (workspace is dead
// or missing).
func checkExistingWorkspace(
ctx context.Context,
db database.Store,
chatID uuid.UUID,
agentConnFn AgentConnFunc,
) (map[string]any, bool, error) {
chat, err := db.GetChatByID(ctx, chatID)
if err != nil {
return nil, false, xerrors.Errorf("load chat: %w", err)
}
if !chat.WorkspaceID.Valid {
return nil, false, nil
}
// Check if workspace still exists.
ws, err := db.GetWorkspaceByID(ctx, chat.WorkspaceID.UUID)
if err != nil {
if xerrors.Is(err, sql.ErrNoRows) {
// Workspace was deleted — allow creation.
return nil, false, nil
}
return nil, false, xerrors.Errorf("load workspace: %w", err)
}
// Check the latest build status.
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID)
if err != nil {
// Can't determine status — allow creation.
return nil, false, nil
}
job, err := db.GetProvisionerJobByID(ctx, build.JobID)
if err != nil {
return nil, false, nil
}
switch job.JobStatus {
case database.ProvisionerJobStatusPending,
database.ProvisionerJobStatusRunning:
// Build is in progress — wait for it instead of
// creating a new workspace.
if err := waitForBuild(ctx, db, ws.ID); err != nil {
return nil, false, xerrors.Errorf(
"existing workspace build failed: %w", err,
)
}
return map[string]any{
"created": false,
"workspace_name": ws.Name,
"status": "already_exists",
"message": "workspace was already being built and is now ready",
}, true, nil
case database.ProvisionerJobStatusSucceeded:
// Build succeeded — check if agent is reachable.
if chat.WorkspaceAgentID.Valid && agentConnFn != nil {
pingCtx, cancel := context.WithTimeout(
ctx, agentPingTimeout,
)
defer cancel()
conn, release, connErr := agentConnFn(
pingCtx, chat.WorkspaceAgentID.UUID,
)
if connErr == nil {
release()
_ = conn
return map[string]any{
"created": false,
"workspace_name": ws.Name,
"status": "already_exists",
"message": "workspace is already running and reachable",
}, true, nil
}
// Agent unreachable — workspace is dead, allow
// creation.
}
// No agent ID or no conn func — allow creation.
return nil, false, nil
default:
// Failed, canceled, etc — allow creation.
return nil, false, nil
}
}
// waitForBuild polls the workspace's latest build until it
// completes or the context expires.
func waitForBuild(
ctx context.Context,
db database.Store,
workspaceID uuid.UUID,
) error {
buildCtx, cancel := context.WithTimeout(ctx, buildTimeout)
defer cancel()
ticker := time.NewTicker(buildPollInterval)
defer ticker.Stop()
for {
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(
buildCtx, workspaceID,
)
if err != nil {
return xerrors.Errorf("get latest build: %w", err)
}
job, err := db.GetProvisionerJobByID(buildCtx, build.JobID)
if err != nil {
return xerrors.Errorf("get provisioner job: %w", err)
}
switch job.JobStatus {
case database.ProvisionerJobStatusSucceeded:
return nil
case database.ProvisionerJobStatusFailed:
errMsg := "build failed"
if job.Error.Valid {
errMsg = job.Error.String
}
return xerrors.New(errMsg)
case database.ProvisionerJobStatusCanceled:
return xerrors.New("build was canceled")
case database.ProvisionerJobStatusPending,
database.ProvisionerJobStatusRunning,
database.ProvisionerJobStatusCanceling:
// Still in progress — keep waiting.
default:
return xerrors.Errorf("unexpected job status: %s", job.JobStatus)
}
select {
case <-buildCtx.Done():
return xerrors.Errorf(
"timed out waiting for workspace build: %w",
buildCtx.Err(),
)
case <-ticker.C:
}
}
}
// waitForAgent retries connecting to the workspace agent until it
// succeeds or the timeout expires.
func waitForAgent(
ctx context.Context,
agentConnFn AgentConnFunc,
agentID uuid.UUID,
) error {
agentCtx, cancel := context.WithTimeout(ctx, agentConnectTimeout)
defer cancel()
ticker := time.NewTicker(agentRetryInterval)
defer ticker.Stop()
var lastErr error
for {
attemptCtx, attemptCancel := context.WithTimeout(agentCtx, agentAttemptTimeout)
conn, release, err := agentConnFn(attemptCtx, agentID)
attemptCancel()
if err == nil {
release()
_ = conn
return nil
}
lastErr = err
select {
case <-agentCtx.Done():
return xerrors.Errorf(
"timed out waiting for workspace agent: %w",
lastErr,
)
case <-ticker.C:
}
}
}
func generatedWorkspaceName(seed string) string {
base := codersdk.UsernameFrom(strings.TrimSpace(strings.ToLower(seed)))
if strings.TrimSpace(base) == "" {
base = "workspace"
}
suffix := strings.ReplaceAll(uuid.NewString(), "-", "")[:4]
if len(base) > 27 {
base = strings.Trim(base[:27], "-")
}
if base == "" {
base = "workspace"
}
name := fmt.Sprintf("%s-%s", base, suffix)
if err := codersdk.NameValid(name); err == nil {
return name
}
return namesgenerator.NameDigitWith("-")
}
+50
View File
@@ -0,0 +1,50 @@
package chattool
import (
"context"
"charm.land/fantasy"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
type EditFilesOptions struct {
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
}
type EditFilesArgs struct {
Files []workspacesdk.FileEdits `json:"files"`
}
func EditFiles(options EditFilesOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"edit_files",
"Perform search-and-replace edits on one or more files in the workspace."+
" Each file can have multiple edits applied atomically.",
func(ctx context.Context, args EditFilesArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
}
conn, err := options.GetWorkspaceConn(ctx)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return executeEditFilesTool(ctx, conn, args)
},
)
}
func executeEditFilesTool(
ctx context.Context,
conn workspacesdk.AgentConn,
args EditFilesArgs,
) (fantasy.ToolResponse, error) {
if len(args.Files) == 0 {
return fantasy.NewTextErrorResponse("files is required"), nil
}
if err := conn.EditFiles(ctx, workspacesdk.FileEditRequest{Files: args.Files}); err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolResponse(map[string]any{"ok": true}), nil
}
+133
View File
@@ -0,0 +1,133 @@
package chattool
import (
"context"
"time"
"charm.land/fantasy"
"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
defaultExecuteTimeout = 60 * time.Second
chatAgentEnvVar = "CODER_CHAT_AGENT"
gitAuthRequiredPrefix = "CODER_GITAUTH_REQUIRED:"
authRequiredResultReason = "authentication_required"
)
type ExecuteOptions struct {
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
DefaultTimeout time.Duration
}
type ExecuteArgs struct {
Command string `json:"command"`
TimeoutSeconds *int `json:"timeout_seconds,omitempty"`
}
func Execute(options ExecuteOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"execute",
"Execute a shell command in the workspace.",
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
}
conn, err := options.GetWorkspaceConn(ctx)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return executeTool(ctx, conn, args, options.DefaultTimeout), nil
},
)
}
func executeTool(
ctx context.Context,
conn workspacesdk.AgentConn,
args ExecuteArgs,
defaultTimeout time.Duration,
) fantasy.ToolResponse {
if args.Command == "" {
return fantasy.NewTextErrorResponse("command is required")
}
timeout := defaultTimeout
if timeout <= 0 {
timeout = defaultExecuteTimeout
}
if args.TimeoutSeconds != nil {
timeout = time.Duration(*args.TimeoutSeconds) * time.Second
}
cmdCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
output, exitCode, err := runCommand(cmdCtx, conn, args.Command)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error())
}
return toolResponse(map[string]any{
"output": output,
"exit_code": exitCode,
})
}
func runCommand(
ctx context.Context,
conn workspacesdk.AgentConn,
command string,
) (string, int, error) {
sshClient, err := conn.SSHClient(ctx)
if err != nil {
return "", 0, err
}
defer sshClient.Close()
session, err := sshClient.NewSession()
if err != nil {
return "", 0, err
}
defer session.Close()
if err := session.Setenv(chatAgentEnvVar, "true"); err != nil {
return "", 0, xerrors.Errorf("set %s: %w", chatAgentEnvVar, err)
}
resultCh := make(chan struct {
output string
exitCode int
err error
}, 1)
go func() {
output, err := session.CombinedOutput(command)
exitCode := 0
if err != nil {
var exitErr *ssh.ExitError
if xerrors.As(err, &exitErr) {
exitCode = exitErr.ExitStatus()
} else {
exitCode = 1
}
}
resultCh <- struct {
output string
exitCode int
err error
}{
output: string(output),
exitCode: exitCode,
err: err,
}
}()
select {
case <-ctx.Done():
_ = session.Close()
return "", 0, ctx.Err()
case result := <-resultCh:
return result.output, result.exitCode, result.err
}
}
+94
View File
@@ -0,0 +1,94 @@
package chattool
import (
"context"
"database/sql"
"strings"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac"
)
// ListTemplatesOptions configures the list_templates tool.
type ListTemplatesOptions struct {
DB database.Store
OwnerID uuid.UUID
}
type listTemplatesArgs struct {
Query string `json:"query,omitempty"`
}
// ListTemplates returns a tool that lists available workspace templates.
// The agent uses this to discover templates before creating a workspace.
func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"list_templates",
"List available workspace templates. Optionally filter by a "+
"search query matching template name or description. "+
"Use this to find a template before creating a workspace.",
func(ctx context.Context, args listTemplatesArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.DB == nil {
return fantasy.NewTextErrorResponse("database is not configured"), nil
}
ctx, err := asOwner(ctx, options.DB, options.OwnerID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
filterParams := database.GetTemplatesWithFilterParams{
Deleted: false,
Deprecated: sql.NullBool{
Bool: false,
Valid: true,
},
}
query := strings.TrimSpace(args.Query)
if query != "" {
filterParams.FuzzyName = query
}
templates, err := options.DB.GetTemplatesWithFilter(ctx, filterParams)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
items := make([]map[string]any, 0, len(templates))
for _, t := range templates {
item := map[string]any{
"id": t.ID.String(),
"name": t.Name,
}
if display := strings.TrimSpace(t.DisplayName); display != "" {
item["display_name"] = display
}
if desc := strings.TrimSpace(t.Description); desc != "" {
item["description"] = truncateRunes(desc, 200)
}
items = append(items, item)
}
return toolResponse(map[string]any{
"templates": items,
"count": len(items),
}), nil
},
)
}
// asOwner sets up a dbauthz context for the given owner so that
// subsequent database calls are scoped to what that user can access.
func asOwner(ctx context.Context, db database.Store, ownerID uuid.UUID) (context.Context, error) {
actor, _, err := httpmw.UserRBACSubject(ctx, db, ownerID, rbac.ScopeAll)
if err != nil {
return ctx, xerrors.Errorf("load user authorization: %w", err)
}
return dbauthz.As(ctx, actor), nil
}
+72
View File
@@ -0,0 +1,72 @@
package chattool
import (
"context"
"io"
"charm.land/fantasy"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
type ReadFileOptions struct {
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
}
type ReadFileArgs struct {
Path string `json:"path"`
Offset *int64 `json:"offset,omitempty"`
Limit *int64 `json:"limit,omitempty"`
}
func ReadFile(options ReadFileOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"read_file",
"Read a file from the workspace.",
func(ctx context.Context, args ReadFileArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
}
conn, err := options.GetWorkspaceConn(ctx)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return executeReadFileTool(ctx, conn, args)
},
)
}
func executeReadFileTool(
ctx context.Context,
conn workspacesdk.AgentConn,
args ReadFileArgs,
) (fantasy.ToolResponse, error) {
if args.Path == "" {
return fantasy.NewTextErrorResponse("path is required"), nil
}
offset := int64(0)
limit := int64(0)
if args.Offset != nil {
offset = *args.Offset
}
if args.Limit != nil {
limit = *args.Limit
}
reader, mimeType, err := conn.ReadFile(ctx, args.Path, offset, limit)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
defer reader.Close()
data, err := io.ReadAll(reader)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolResponse(map[string]any{
"content": string(data),
"mime_type": mimeType,
}), nil
}
+130
View File
@@ -0,0 +1,130 @@
package chattool
import (
"context"
"encoding/json"
"strings"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
)
// ReadTemplateOptions configures the read_template tool.
type ReadTemplateOptions struct {
DB database.Store
OwnerID uuid.UUID
}
type readTemplateArgs struct {
TemplateID string `json:"template_id"`
}
// ReadTemplate returns a tool that retrieves details about a specific
// template, including its configurable rich parameters. The agent
// uses this after list_templates and before create_workspace.
func ReadTemplate(options ReadTemplateOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"read_template",
"Get details about a workspace template, including its "+
"configurable parameters. Use this after finding a "+
"template with list_templates and before creating a "+
"workspace with create_workspace.",
func(ctx context.Context, args readTemplateArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.DB == nil {
return fantasy.NewTextErrorResponse("database is not configured"), nil
}
templateIDStr := strings.TrimSpace(args.TemplateID)
if templateIDStr == "" {
return fantasy.NewTextErrorResponse("template_id is required"), nil
}
templateID, err := uuid.Parse(templateIDStr)
if err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("invalid template_id: %w", err).Error(),
), nil
}
ctx, err = asOwner(ctx, options.DB, options.OwnerID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
template, err := options.DB.GetTemplateByID(ctx, templateID)
if err != nil {
return fantasy.NewTextErrorResponse("template not found"), nil
}
params, err := options.DB.GetTemplateVersionParameters(ctx, template.ActiveVersionID)
if err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("failed to get template parameters: %w", err).Error(),
), nil
}
templateInfo := map[string]any{
"id": template.ID.String(),
"name": template.Name,
"active_version_id": template.ActiveVersionID.String(),
}
if display := strings.TrimSpace(template.DisplayName); display != "" {
templateInfo["display_name"] = display
}
if desc := strings.TrimSpace(template.Description); desc != "" {
templateInfo["description"] = desc
}
paramList := make([]map[string]any, 0, len(params))
for _, p := range params {
param := map[string]any{
"name": p.Name,
"type": p.Type,
"required": p.Required,
}
if display := strings.TrimSpace(p.DisplayName); display != "" {
param["display_name"] = display
}
if desc := strings.TrimSpace(p.Description); desc != "" {
param["description"] = truncateRunes(desc, 300)
}
if p.DefaultValue != "" {
param["default"] = p.DefaultValue
}
if p.Mutable {
param["mutable"] = true
}
if p.Ephemeral {
param["ephemeral"] = true
}
if p.FormType != "" {
param["form_type"] = string(p.FormType)
}
if len(p.Options) > 0 && string(p.Options) != "null" && string(p.Options) != "[]" {
var opts []map[string]any
if err := json.Unmarshal(p.Options, &opts); err == nil && len(opts) > 0 {
param["options"] = opts
}
}
if p.ValidationRegex != "" {
param["validation_regex"] = p.ValidationRegex
}
if p.ValidationMin.Valid {
param["validation_min"] = p.ValidationMin.Int32
}
if p.ValidationMax.Valid {
param["validation_max"] = p.ValidationMax.Int32
}
paramList = append(paramList, param)
}
return toolResponse(map[string]any{
"template": templateInfo,
"parameters": paramList,
}), nil
},
)
}
+51
View File
@@ -0,0 +1,51 @@
package chattool
import (
"context"
"strings"
"charm.land/fantasy"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
type WriteFileOptions struct {
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
}
type WriteFileArgs struct {
Path string `json:"path"`
Content string `json:"content"`
}
func WriteFile(options WriteFileOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"write_file",
"Write a file to the workspace.",
func(ctx context.Context, args WriteFileArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
}
conn, err := options.GetWorkspaceConn(ctx)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return executeWriteFileTool(ctx, conn, args)
},
)
}
func executeWriteFileTool(
ctx context.Context,
conn workspacesdk.AgentConn,
args WriteFileArgs,
) (fantasy.ToolResponse, error) {
if args.Path == "" {
return fantasy.NewTextErrorResponse("path is required"), nil
}
if err := conn.WriteFile(ctx, args.Path, strings.NewReader(args.Content)); err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolResponse(map[string]any{"ok": true}), nil
}
+126
View File
@@ -0,0 +1,126 @@
package chatd
import (
"context"
"io"
"net/http"
"regexp"
"strings"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
coderHomeInstructionDir = ".coder"
coderHomeInstructionFile = "AGENTS.md"
maxInstructionFileBytes = 64 * 1024
)
var markdownCommentPattern = regexp.MustCompile(`<!--[\s\S]*?-->`)
func readHomeInstructionFile(
ctx context.Context,
conn workspacesdk.AgentConn,
) (content string, sourcePath string, truncated bool, err error) {
if conn == nil {
return "", "", false, nil
}
coderDir, err := conn.LS(ctx, "", workspacesdk.LSRequest{
Path: []string{coderHomeInstructionDir},
Relativity: workspacesdk.LSRelativityHome,
})
if err != nil {
if isCodersdkStatusCode(err, http.StatusNotFound) {
return "", "", false, nil
}
return "", "", false, xerrors.Errorf("list home instruction directory: %w", err)
}
var filePath string
for _, entry := range coderDir.Contents {
if entry.IsDir {
continue
}
if strings.EqualFold(strings.TrimSpace(entry.Name), coderHomeInstructionFile) {
filePath = strings.TrimSpace(entry.AbsolutePathString)
break
}
}
if filePath == "" {
return "", "", false, nil
}
reader, _, err := conn.ReadFile(
ctx,
filePath,
0,
maxInstructionFileBytes+1,
)
if err != nil {
if isCodersdkStatusCode(err, http.StatusNotFound) {
return "", "", false, nil
}
return "", "", false, xerrors.Errorf("read home instruction file: %w", err)
}
defer reader.Close()
raw, err := io.ReadAll(reader)
if err != nil {
return "", "", false, xerrors.Errorf("read home instruction bytes: %w", err)
}
truncated = int64(len(raw)) > maxInstructionFileBytes
if truncated {
raw = raw[:maxInstructionFileBytes]
}
content = sanitizeInstructionMarkdown(string(raw))
if content == "" {
return "", "", truncated, nil
}
return content, filePath, truncated, nil
}
func sanitizeInstructionMarkdown(content string) string {
content = strings.ReplaceAll(content, "\r\n", "\n")
content = strings.ReplaceAll(content, "\r", "\n")
content = markdownCommentPattern.ReplaceAllString(content, "")
return strings.TrimSpace(content)
}
//nolint:revive // Boolean indicates content was truncated.
func formatHomeInstruction(content string, sourcePath string, truncated bool) string {
content = strings.TrimSpace(content)
if content == "" {
return ""
}
sourcePath = strings.TrimSpace(sourcePath)
if sourcePath == "" {
sourcePath = "~/.coder/AGENTS.md"
}
var b strings.Builder
_, _ = b.WriteString("<coder-home-instructions>\n")
_, _ = b.WriteString("Source: ")
_, _ = b.WriteString(sourcePath)
if truncated {
_, _ = b.WriteString(" (truncated to 64KiB)")
}
_, _ = b.WriteString("\n\n")
_, _ = b.WriteString(content)
_, _ = b.WriteString("\n</coder-home-instructions>")
return b.String()
}
func isCodersdkStatusCode(err error, statusCode int) bool {
var sdkErr *codersdk.Error
if !xerrors.As(err, &sdkErr) {
return false
}
return sdkErr.StatusCode() == statusCode
}
+134
View File
@@ -0,0 +1,134 @@
package chatd //nolint:testpackage // Uses internal symbols.
import (
"context"
"io"
"strings"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
)
func TestSanitizeInstructionMarkdown(t *testing.T) {
t.Parallel()
input := "line 1\r\n<!-- hidden -->\r\nline 2\r\n"
require.Equal(t, "line 1\n\nline 2", sanitizeInstructionMarkdown(input))
}
func TestReadHomeInstructionFileNotFound(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).DoAndReturn(
func(context.Context, string, workspacesdk.LSRequest) (workspacesdk.LSResponse, error) {
return workspacesdk.LSResponse{}, codersdk.NewTestError(404, "POST", "/api/v0/list-directory")
},
)
content, sourcePath, truncated, err := readHomeInstructionFile(context.Background(), conn)
require.NoError(t, err)
require.Empty(t, content)
require.Empty(t, sourcePath)
require.False(t, truncated)
}
func TestReadHomeInstructionFileSuccess(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).DoAndReturn(
func(context.Context, string, workspacesdk.LSRequest) (workspacesdk.LSResponse, error) {
return workspacesdk.LSResponse{
Contents: []workspacesdk.LSFile{{
Name: "AGENTS.md",
AbsolutePathString: "/home/coder/.coder/AGENTS.md",
}},
}, nil
},
)
conn.EXPECT().ReadFile(
gomock.Any(),
"/home/coder/.coder/AGENTS.md",
int64(0),
int64(maxInstructionFileBytes+1),
).Return(
io.NopCloser(strings.NewReader("base\n<!-- hidden -->\nlocal")),
"text/markdown",
nil,
)
content, sourcePath, truncated, err := readHomeInstructionFile(context.Background(), conn)
require.NoError(t, err)
require.Equal(t, "base\n\nlocal", content)
require.Equal(t, "/home/coder/.coder/AGENTS.md", sourcePath)
require.False(t, truncated)
}
func TestReadHomeInstructionFileTruncates(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
conn := agentconnmock.NewMockAgentConn(ctrl)
content := strings.Repeat("a", maxInstructionFileBytes+8)
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
workspacesdk.LSResponse{
Contents: []workspacesdk.LSFile{{
Name: "AGENTS.md",
AbsolutePathString: "/home/coder/.coder/AGENTS.md",
}},
},
nil,
)
conn.EXPECT().ReadFile(
gomock.Any(),
"/home/coder/.coder/AGENTS.md",
int64(0),
int64(maxInstructionFileBytes+1),
).Return(io.NopCloser(strings.NewReader(content)), "text/markdown", nil)
got, _, truncated, err := readHomeInstructionFile(context.Background(), conn)
require.NoError(t, err)
require.True(t, truncated)
require.Len(t, got, maxInstructionFileBytes)
}
func TestInsertSystemInstructionAfterSystemMessages(t *testing.T) {
t.Parallel()
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "base"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "hello"},
},
},
}
got := chatprompt.InsertSystem(prompt, "project rules")
require.Len(t, got, 3)
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
require.Equal(t, fantasy.MessageRoleSystem, got[1].Role)
require.Equal(t, fantasy.MessageRoleUser, got[2].Role)
part, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
require.True(t, ok)
require.Equal(t, "project rules", part.Text)
}
+73
View File
@@ -0,0 +1,73 @@
package chatd
// DefaultSystemPrompt is used for new chats when no deployment override is
// configured.
const DefaultSystemPrompt = `You are the Coder agent an interactive chat tool that helps users with software-engineering tasks inside of the Coder product.
Use the instructions below and the tools available to you to assist User.
IMPORTANT obey every rule in this prompt before anything else.
Do EXACTLY what the User asked, never more, never less.
<behavior>
You MUST execute AS MANY TOOLS to help the user accomplish their task.
You are COMFORTABLE with vague tasks - using your tools to collect the most relevant answer possible.
If a user asks how something works, no matter how vague, you MUST use your tools to collect the most relevant answer possible.
DO NOT ask the user for clarification - just use your tools.
</behavior>
<personality>
Analytical You break problems into measurable steps, relying on tool output and data rather than intuition.
Organized You structure every interaction with clear tags, TODO lists, and section boundaries.
Precision-Oriented You insist on exact formatting, package-manager choice, and rule adherence.
Efficiency-Focused You minimize chatter, run tasks in parallel, and favor small, complete answers.
Clarity-Seeking You ask for missing details instead of guessing, avoiding any ambiguity.
</personality>
<communication>
Be concise, direct, and to the point.
NO emojis unless the User explicitly asks for them.
If a task appears incomplete or ambiguous, **pause and ask the User** rather than guessing or marking "done".
Prefer accuracy over reassurance; confirm facts with tool calls instead of assuming the User is right.
If you face an architectural, tooling, or package-manager choice, **ask the User's preference first**.
Default to the project's existing package manager / tooling; never substitute without confirmation.
You MUST avoid text before/after your response, such as "The answer is" or "Short answer:", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
Mimic the style of the User's messages.
Do not remind the User you are happy to help.
Do not inherently assume the User is correct; they may be making assumptions.
If you are not confident in your answer, DO NOT provide an answer. Use your tools to collect more information, or ask the User for help.
Do not act with sycophantic flattery or over-the-top enthusiasm.
Here are examples to demonstrate appropriate communication style and level of verbosity:
<example>
user: find me a good issue to work on
assistant: Issue [#1234](https://example) indicates a bug in the frontend, which you've contributed to in the past.
</example>
<example>
user: work on this issue <url>
...assistant does work...
assistant: I've put up this pull request: https://github.com/example/example/pull/1824. Please let me know your thoughts!
</example>
<example>
user: what is 2+2?
assistant: 4
</example>
<example>
user: how does X work in <popular-repository-name>?
assistant: Let me take a look at the code...
[tool calls to investigate the repository]
</example>
</communication>
<collaboration>
When a user asks for help with a task or there is ambiguity on the objective, always start by asking clarifying questions to understand:
- What specific aspect they want to focus on
- Their goals and vision for the changes
- Their preferences for approach or style
- What problems they're trying to solve
Don't assume what needs to be done - collaborate to define the scope together.
</collaboration>`
+512
View File
@@ -0,0 +1,512 @@
package chatd
import (
"context"
"encoding/json"
"sort"
"strings"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
)
var ErrSubagentNotDescendant = xerrors.New("target chat is not a descendant of current chat")
const (
subagentAwaitPollInterval = 200 * time.Millisecond
defaultSubagentWaitTimeout = 5 * time.Minute
)
type spawnAgentArgs struct {
Prompt string `json:"prompt"`
Title string `json:"title,omitempty"`
}
type waitAgentArgs struct {
ChatID string `json:"chat_id"`
TimeoutSeconds *int `json:"timeout_seconds,omitempty"`
}
type messageAgentArgs struct {
ChatID string `json:"chat_id"`
Message string `json:"message"`
Interrupt bool `json:"interrupt,omitempty"`
}
type closeAgentArgs struct {
ChatID string `json:"chat_id"`
}
func (p *Server) subagentTools(currentChat func() database.Chat) []fantasy.AgentTool {
return []fantasy.AgentTool{
fantasy.NewAgentTool(
"spawn_agent",
"Spawn a delegated child agent chat from the root chat.",
func(ctx context.Context, args spawnAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if currentChat == nil {
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
}
parent := currentChat()
if parent.ParentChatID.Valid {
return fantasy.NewTextErrorResponse("delegated chats cannot create child subagents"), nil
}
parent, err := p.db.GetChatByID(ctx, parent.ID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
childChat, err := p.createChildSubagentChat(
ctx,
parent,
args.Prompt,
args.Title,
)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolJSONResponse(map[string]any{
"chat_id": childChat.ID.String(),
"title": childChat.Title,
"status": string(childChat.Status),
}), nil
},
),
fantasy.NewAgentTool(
"wait_agent",
"Wait until a delegated descendant agent reaches a non-streaming status.",
func(ctx context.Context, args waitAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if currentChat == nil {
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
}
targetChatID, err := parseSubagentToolChatID(args.ChatID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
timeout := defaultSubagentWaitTimeout
if args.TimeoutSeconds != nil {
timeout = time.Duration(*args.TimeoutSeconds) * time.Second
}
parent := currentChat()
targetChat, report, err := p.awaitSubagentCompletion(
ctx,
parent.ID,
targetChatID,
timeout,
)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolJSONResponse(map[string]any{
"chat_id": targetChatID.String(),
"title": targetChat.Title,
"report": report,
"status": string(targetChat.Status),
}), nil
},
),
fantasy.NewAgentTool(
"message_agent",
"Send a message to a delegated descendant agent. Use wait_agent to collect a response.",
func(ctx context.Context, args messageAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if currentChat == nil {
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
}
targetChatID, err := parseSubagentToolChatID(args.ChatID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
parent := currentChat()
busyBehavior := SendMessageBusyBehaviorQueue
if args.Interrupt {
busyBehavior = SendMessageBusyBehaviorInterrupt
}
targetChat, err := p.sendSubagentMessage(
ctx,
parent.ID,
targetChatID,
args.Message,
busyBehavior,
)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolJSONResponse(map[string]any{
"chat_id": targetChatID.String(),
"title": targetChat.Title,
"status": string(targetChat.Status),
"interrupted": args.Interrupt,
}), nil
},
),
fantasy.NewAgentTool(
"close_agent",
"Interrupt a delegated descendant agent immediately.",
func(ctx context.Context, args closeAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if currentChat == nil {
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
}
targetChatID, err := parseSubagentToolChatID(args.ChatID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
parent := currentChat()
targetChat, err := p.closeSubagent(
ctx,
parent.ID,
targetChatID,
)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolJSONResponse(map[string]any{
"chat_id": targetChatID.String(),
"title": targetChat.Title,
"terminated": true,
"status": string(targetChat.Status),
}), nil
},
),
}
}
func parseSubagentToolChatID(raw string) (uuid.UUID, error) {
chatID, err := uuid.Parse(strings.TrimSpace(raw))
if err != nil {
return uuid.Nil, xerrors.New("chat_id must be a valid UUID")
}
return chatID, nil
}
func (p *Server) createChildSubagentChat(
ctx context.Context,
parent database.Chat,
prompt string,
title string,
) (database.Chat, error) {
if parent.ParentChatID.Valid {
return database.Chat{}, xerrors.New("delegated chats cannot create child subagents")
}
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return database.Chat{}, xerrors.New("prompt is required")
}
title = strings.TrimSpace(title)
if title == "" {
title = subagentFallbackChatTitle(prompt)
}
rootChatID := parent.ID
if parent.RootChatID.Valid {
rootChatID = parent.RootChatID.UUID
}
if parent.LastModelConfigID == uuid.Nil {
return database.Chat{}, xerrors.New("parent chat model config id is required")
}
child, err := p.CreateChat(ctx, CreateOptions{
OwnerID: parent.OwnerID,
WorkspaceID: parent.WorkspaceID,
WorkspaceAgentID: parent.WorkspaceAgentID,
ParentChatID: uuid.NullUUID{
UUID: parent.ID,
Valid: true,
},
RootChatID: uuid.NullUUID{
UUID: rootChatID,
Valid: true,
},
ModelConfigID: parent.LastModelConfigID,
Title: title,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: prompt}},
})
if err != nil {
return database.Chat{}, xerrors.Errorf("create child chat: %w", err)
}
return child, nil
}
func (p *Server) sendSubagentMessage(
ctx context.Context,
parentChatID uuid.UUID,
targetChatID uuid.UUID,
message string,
busyBehavior SendMessageBusyBehavior,
) (database.Chat, error) {
message = strings.TrimSpace(message)
if message == "" {
return database.Chat{}, xerrors.New("message is required")
}
isDescendant, err := isSubagentDescendant(ctx, p.db, parentChatID, targetChatID)
if err != nil {
return database.Chat{}, err
}
if !isDescendant {
return database.Chat{}, ErrSubagentNotDescendant
}
sendResult, err := p.SendMessage(ctx, SendMessageOptions{
ChatID: targetChatID,
Content: []fantasy.Content{fantasy.TextContent{Text: message}},
BusyBehavior: busyBehavior,
})
if err != nil {
return database.Chat{}, err
}
return sendResult.Chat, nil
}
func (p *Server) awaitSubagentCompletion(
ctx context.Context,
parentChatID uuid.UUID,
targetChatID uuid.UUID,
timeout time.Duration,
) (database.Chat, string, error) {
isDescendant, err := isSubagentDescendant(ctx, p.db, parentChatID, targetChatID)
if err != nil {
return database.Chat{}, "", err
}
if !isDescendant {
return database.Chat{}, "", ErrSubagentNotDescendant
}
if timeout <= 0 {
timeout = defaultSubagentWaitTimeout
}
timer := time.NewTimer(timeout)
defer timer.Stop()
ticker := time.NewTicker(subagentAwaitPollInterval)
defer ticker.Stop()
for {
targetChat, report, done, checkErr := p.checkSubagentCompletion(ctx, targetChatID)
if checkErr != nil {
return database.Chat{}, "", checkErr
}
if done {
if targetChat.Status == database.ChatStatusError {
reason := strings.TrimSpace(report)
if reason == "" {
reason = "agent reached error status"
}
return database.Chat{}, "", xerrors.New(reason)
}
return targetChat, report, nil
}
select {
case <-ticker.C:
case <-timer.C:
return database.Chat{}, "", xerrors.New("timed out waiting for delegated subagent completion")
case <-ctx.Done():
return database.Chat{}, "", ctx.Err()
}
}
}
func (p *Server) closeSubagent(
ctx context.Context,
parentChatID uuid.UUID,
targetChatID uuid.UUID,
) (database.Chat, error) {
isDescendant, err := isSubagentDescendant(ctx, p.db, parentChatID, targetChatID)
if err != nil {
return database.Chat{}, err
}
if !isDescendant {
return database.Chat{}, ErrSubagentNotDescendant
}
targetChat, err := p.db.GetChatByID(ctx, targetChatID)
if err != nil {
return database.Chat{}, xerrors.Errorf("get target chat: %w", err)
}
if targetChat.Status == database.ChatStatusWaiting {
return targetChat, nil
}
updatedChat := p.InterruptChat(ctx, targetChat)
if updatedChat.Status != database.ChatStatusWaiting {
return database.Chat{}, xerrors.New("set target chat waiting")
}
return updatedChat, nil
}
func (p *Server) checkSubagentCompletion(
ctx context.Context,
chatID uuid.UUID,
) (database.Chat, string, bool, error) {
chat, err := p.db.GetChatByID(ctx, chatID)
if err != nil {
return database.Chat{}, "", false, xerrors.Errorf("get chat: %w", err)
}
if chat.Status == database.ChatStatusPending || chat.Status == database.ChatStatusRunning {
return database.Chat{}, "", false, nil
}
report, err := latestSubagentAssistantMessage(ctx, p.db, chatID)
if err != nil {
return database.Chat{}, "", false, err
}
return chat, report, true, nil
}
func latestSubagentAssistantMessage(
ctx context.Context,
store database.Store,
chatID uuid.UUID,
) (string, error) {
messages, err := store.GetChatMessagesByChatID(ctx, chatID)
if err != nil {
return "", xerrors.Errorf("get chat messages: %w", err)
}
sort.Slice(messages, func(i, j int) bool {
if messages[i].CreatedAt.Equal(messages[j].CreatedAt) {
return messages[i].ID < messages[j].ID
}
return messages[i].CreatedAt.Before(messages[j].CreatedAt)
})
for i := len(messages) - 1; i >= 0; i-- {
message := messages[i]
if message.Role != string(fantasy.MessageRoleAssistant) ||
message.Visibility == database.ChatMessageVisibilityModel {
continue
}
content, parseErr := chatprompt.ParseContent(message.Role, message.Content)
if parseErr != nil {
continue
}
text := strings.TrimSpace(contentBlocksToText(content))
if text == "" {
continue
}
return text, nil
}
return "", nil
}
func isSubagentDescendant(
ctx context.Context,
store database.Store,
ancestorChatID uuid.UUID,
targetChatID uuid.UUID,
) (bool, error) {
if ancestorChatID == targetChatID {
return false, nil
}
descendants, err := listSubagentDescendants(ctx, store, ancestorChatID)
if err != nil {
return false, err
}
for _, descendant := range descendants {
if descendant.ID == targetChatID {
return true, nil
}
}
return false, nil
}
func listSubagentDescendants(
ctx context.Context,
store database.Store,
chatID uuid.UUID,
) ([]database.Chat, error) {
queue := []uuid.UUID{chatID}
visited := map[uuid.UUID]struct{}{chatID: {}}
out := make([]database.Chat, 0)
for len(queue) > 0 {
parentChatID := queue[0]
queue = queue[1:]
children, err := store.ListChildChatsByParentID(ctx, parentChatID)
if err != nil {
return nil, xerrors.Errorf("list child chats for %s: %w", parentChatID, err)
}
for _, child := range children {
if _, ok := visited[child.ID]; ok {
continue
}
visited[child.ID] = struct{}{}
out = append(out, child)
queue = append(queue, child.ID)
}
}
return out, nil
}
func subagentFallbackChatTitle(message string) string {
const maxWords = 6
const maxRunes = 80
words := strings.Fields(message)
if len(words) == 0 {
return "New Chat"
}
truncated := false
if len(words) > maxWords {
words = words[:maxWords]
truncated = true
}
title := strings.Join(words, " ")
if truncated {
title += "..."
}
return subagentTruncateRunes(title, maxRunes)
}
func subagentTruncateRunes(value string, maxRunes int) string {
if maxRunes <= 0 {
return ""
}
runes := []rune(value)
if len(runes) <= maxRunes {
return value
}
return string(runes[:maxRunes])
}
func toolJSONResponse(result map[string]any) fantasy.ToolResponse {
data, err := json.Marshal(result)
if err != nil {
return fantasy.NewTextResponse("{}")
}
return fantasy.NewTextResponse(string(data))
}
+216
View File
@@ -0,0 +1,216 @@
package chatd
import (
"context"
"strings"
"time"
"charm.land/fantasy"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
)
const titleGenerationPrompt = "Generate a concise title (max 8 words, under 128 characters) for " +
"the user's first message. Return plain text only — no quotes, no emoji, " +
"no markdown, no special characters."
// maybeGenerateChatTitle generates an AI title for the chat when
// appropriate (first user message, no assistant reply yet, and the
// current title is either empty or still the fallback truncation).
// It is a best-effort operation that logs and swallows errors.
func (p *Server) maybeGenerateChatTitle(
ctx context.Context,
chat database.Chat,
messages []database.ChatMessage,
model fantasy.LanguageModel,
logger slog.Logger,
) {
input, ok := titleInput(chat, messages)
if !ok {
return
}
titleCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
title, err := generateTitle(titleCtx, model, input)
if err != nil {
logger.Debug(ctx, "failed to generate chat title",
slog.F("chat_id", chat.ID),
slog.Error(err),
)
return
}
if title == "" || title == chat.Title {
return
}
_, err = p.db.UpdateChatByID(ctx, database.UpdateChatByIDParams{
ID: chat.ID,
Title: title,
})
if err != nil {
logger.Warn(ctx, "failed to update generated chat title",
slog.F("chat_id", chat.ID),
slog.Error(err),
)
return
}
chat.Title = title
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange)
}
// generateTitle calls the model with a title-generation system prompt
// and returns the normalized result.
func generateTitle(
ctx context.Context,
model fantasy.LanguageModel,
input string,
) (string, error) {
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: titleGenerationPrompt},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: input},
},
},
}
toolChoice := fantasy.ToolChoiceNone
response, err := model.Generate(ctx, fantasy.Call{
Prompt: prompt,
ToolChoice: &toolChoice,
})
if err != nil {
return "", xerrors.Errorf("generate title text: %w", err)
}
title := normalizeTitleOutput(contentBlocksToText(response.Content))
if title == "" {
return "", xerrors.New("generated title was empty")
}
return title, nil
}
// titleInput returns the first user message text and whether title
// generation should proceed. It returns false when the chat already
// has assistant/tool replies, has more than one visible user message,
// or the current title doesn't look like a candidate for replacement.
func titleInput(
chat database.Chat,
messages []database.ChatMessage,
) (string, bool) {
userCount := 0
firstUserText := ""
for _, message := range messages {
if message.Visibility == database.ChatMessageVisibilityModel {
continue
}
switch message.Role {
case string(fantasy.MessageRoleAssistant), string(fantasy.MessageRoleTool):
return "", false
case string(fantasy.MessageRoleUser):
userCount++
if firstUserText == "" {
parsed, err := chatprompt.ParseContent(
string(fantasy.MessageRoleUser), message.Content,
)
if err != nil {
return "", false
}
firstUserText = strings.TrimSpace(
contentBlocksToText(parsed),
)
}
}
}
if userCount != 1 || firstUserText == "" {
return "", false
}
currentTitle := strings.TrimSpace(chat.Title)
if currentTitle == "" {
return firstUserText, true
}
if currentTitle != fallbackChatTitle(firstUserText) {
return "", false
}
return firstUserText, true
}
func normalizeTitleOutput(title string) string {
title = strings.TrimSpace(title)
if title == "" {
return ""
}
title = strings.Trim(title, "\"'`")
title = strings.Join(strings.Fields(title), " ")
return truncateRunes(title, 80)
}
func fallbackChatTitle(message string) string {
const maxWords = 6
const maxRunes = 80
words := strings.Fields(message)
if len(words) == 0 {
return "New Chat"
}
truncated := false
if len(words) > maxWords {
words = words[:maxWords]
truncated = true
}
title := strings.Join(words, " ")
if truncated {
title += "…"
}
return truncateRunes(title, maxRunes)
}
// contentBlocksToText concatenates the text parts of content blocks
// into a single space-separated string.
func contentBlocksToText(content []fantasy.Content) string {
parts := make([]string, 0, len(content))
for _, block := range content {
textBlock, ok := fantasy.AsContentType[fantasy.TextContent](block)
if !ok {
continue
}
text := strings.TrimSpace(textBlock.Text)
if text == "" {
continue
}
parts = append(parts, text)
}
return strings.Join(parts, " ")
}
func truncateRunes(value string, maxLen int) string {
if maxLen <= 0 {
return ""
}
runes := []rune(value)
if len(runes) <= maxLen {
return value
}
return string(runes[:maxLen])
}
+3093
View File
File diff suppressed because it is too large Load Diff
+2067
View File
File diff suppressed because it is too large Load Diff
+62 -2
View File
@@ -49,6 +49,7 @@ import (
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/awsidentity"
"github.com/coder/coder/v2/coderd/boundaryusage"
"github.com/coder/coder/v2/coderd/chatd"
"github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database"
@@ -238,6 +239,9 @@ type Options struct {
SSHConfig codersdk.SSHConfigResponse
HTTPClient *http.Client
// ChatRemotePartsProvider provides cross-replica message_part streaming.
// Set by enterprise for HA deployments. Nil in AGPL single-replica.
ChatRemotePartsProvider chatd.RemotePartsProvider
UpdateAgentMetrics func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
StatsBatcher workspacestats.Batcher
@@ -588,7 +592,6 @@ func New(options *Options) *API {
var buildUsageChecker atomic.Pointer[wsbuilder.UsageChecker]
var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{}
buildUsageChecker.Store(&noopUsageChecker)
api := &API{
ctx: ctx,
cancel: cancel,
@@ -754,6 +757,17 @@ func New(options *Options) *API {
panic("failed to setup server tailnet: " + err.Error())
}
api.agentProvider = stn
api.chatDaemon = chatd.New(chatd.Config{
Logger: options.Logger.Named("chats"),
Database: options.Database,
ReplicaID: api.ID,
RemotePartsProvider: options.ChatRemotePartsProvider,
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
AgentConn: api.agentProvider.AgentConn,
CreateWorkspace: api.chatCreateWorkspace,
Pubsub: options.Pubsub,
})
if options.DeploymentValues.Prometheus.Enable {
options.PrometheusRegistry.MustRegister(stn)
api.lifecycleMetrics = agentapi.NewLifecycleMetrics(options.PrometheusRegistry)
@@ -1085,6 +1099,48 @@ func New(options *Options) *API {
})
})
})
r.Route("/chats", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentAgents),
)
r.Get("/", api.listChats)
r.Post("/", api.postChats)
r.Get("/models", api.listChatModels)
r.Get("/watch", api.watchChats)
r.Route("/providers", func(r chi.Router) {
r.Get("/", api.listChatProviders)
r.Post("/", api.createChatProvider)
r.Route("/{providerConfig}", func(r chi.Router) {
r.Patch("/", api.updateChatProvider)
r.Delete("/", api.deleteChatProvider)
})
})
r.Route("/model-configs", func(r chi.Router) {
r.Get("/", api.listChatModelConfigs)
r.Post("/", api.createChatModelConfig)
r.Route("/{modelConfig}", func(r chi.Router) {
r.Patch("/", api.updateChatModelConfig)
r.Delete("/", api.deleteChatModelConfig)
})
})
r.Route("/{chat}", func(r chi.Router) {
r.Use(httpmw.ExtractChatParam(options.Database))
r.Get("/", api.getChat)
r.Delete("/", api.deleteChat)
r.Post("/messages", api.postChatMessages)
r.Patch("/messages/{message}", api.patchChatMessage)
r.Get("/stream", api.streamChat)
r.Post("/interrupt", api.interruptChat)
r.Get("/diff-status", api.getChatDiffStatus)
r.Get("/diff", api.getChatDiffContents)
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
r.Delete("/", api.deleteChatQueuedMessage)
r.Post("/promote", api.promoteChatQueuedMessage)
})
})
})
r.Route("/mcp", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
@@ -1902,6 +1958,8 @@ type API struct {
// dbRolluper rolls up template usage stats from raw agent and app
// stats. This is used to provide insights in the WebUI.
dbRolluper *dbrollup.Rolluper
// chatDaemon handles background processing of pending chats.
chatDaemon *chatd.Server
}
// Close waits for all WebSocket connections to drain before returning.
@@ -1930,8 +1988,10 @@ func (api *API) Close() error {
case <-timer.C:
api.Logger.Warn(api.ctx, "websocket shutdown timed out after 10 seconds")
}
api.dbRolluper.Close()
if err := api.chatDaemon.Close(); err != nil {
api.Logger.Warn(api.ctx, "close chat processor", slog.Error(err))
}
api.metricsCache.Close()
if api.updateChecker != nil {
api.updateChecker.Close()
+16 -13
View File
@@ -6,17 +6,20 @@ type CheckConstraint string
// CheckConstraint enums.
const (
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
)
+345
View File
@@ -2,6 +2,7 @@
package db2sdk
import (
"database/sql"
"encoding/json"
"fmt"
"net/url"
@@ -11,6 +12,7 @@ import (
"strings"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/hashicorp/hcl/v2"
"github.com/sqlc-dev/pqtype"
@@ -18,6 +20,7 @@ import (
"tailscale.com/tailcfg"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
@@ -1050,3 +1053,345 @@ func jsonOrEmptyMap(rawMessage pqtype.NullRawMessage) map[string]any {
}
return m
}
func ChatMessage(m database.ChatMessage) codersdk.ChatMessage {
modelConfigID := &m.ModelConfigID.UUID
if !m.ModelConfigID.Valid {
modelConfigID = nil
}
msg := codersdk.ChatMessage{
ID: m.ID,
ChatID: m.ChatID,
ModelConfigID: modelConfigID,
CreatedAt: m.CreatedAt,
Role: m.Role,
}
if m.Content.Valid {
parts, err := chatMessageParts(m.Role, m.Content)
if err == nil {
msg.Content = parts
}
}
usage := chatMessageUsage(m)
if usage != nil {
msg.Usage = usage
}
return msg
}
// chatMessageUsage builds a ChatMessageUsage from the database row,
// returning nil when no token fields are populated.
func chatMessageUsage(m database.ChatMessage) *codersdk.ChatMessageUsage {
inputTokens := nullInt64Ptr(m.InputTokens)
outputTokens := nullInt64Ptr(m.OutputTokens)
totalTokens := nullInt64Ptr(m.TotalTokens)
reasoningTokens := nullInt64Ptr(m.ReasoningTokens)
cacheCreationTokens := nullInt64Ptr(m.CacheCreationTokens)
cacheReadTokens := nullInt64Ptr(m.CacheReadTokens)
contextLimit := nullInt64Ptr(m.ContextLimit)
if inputTokens == nil && outputTokens == nil && totalTokens == nil &&
reasoningTokens == nil && cacheCreationTokens == nil &&
cacheReadTokens == nil && contextLimit == nil {
return nil
}
return &codersdk.ChatMessageUsage{
InputTokens: inputTokens,
OutputTokens: outputTokens,
TotalTokens: totalTokens,
ReasoningTokens: reasoningTokens,
CacheCreationTokens: cacheCreationTokens,
CacheReadTokens: cacheReadTokens,
ContextLimit: contextLimit,
}
}
// ChatQueuedMessage converts a queued message to its SDK representation.
func ChatQueuedMessage(message database.ChatQueuedMessage) codersdk.ChatQueuedMessage {
parts, err := chatMessageParts(string(fantasy.MessageRoleUser), pqtype.NullRawMessage{
RawMessage: message.Content,
Valid: len(message.Content) > 0,
})
if err != nil {
parts = nil
}
return codersdk.ChatQueuedMessage{
ID: message.ID,
ChatID: message.ChatID,
Content: parts,
CreatedAt: message.CreatedAt,
}
}
// ChatQueuedMessages converts a slice of database queued messages
// to their SDK representation.
func ChatQueuedMessages(messages []database.ChatQueuedMessage) []codersdk.ChatQueuedMessage {
out := make([]codersdk.ChatQueuedMessage, 0, len(messages))
for _, message := range messages {
out = append(out, ChatQueuedMessage(message))
}
return out
}
func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMessagePart, error) {
switch role {
case string(fantasy.MessageRoleSystem):
content, err := parseSystemContent(raw)
if err != nil {
return nil, err
}
if strings.TrimSpace(content) == "" {
return nil, nil
}
return []codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: content,
}}, nil
case string(fantasy.MessageRoleUser), string(fantasy.MessageRoleAssistant):
content, err := parseContentBlocks(role, raw)
if err != nil {
return nil, err
}
var rawBlocks []json.RawMessage
if role == string(fantasy.MessageRoleAssistant) {
_ = json.Unmarshal(raw.RawMessage, &rawBlocks)
}
parts := make([]codersdk.ChatMessagePart, 0, len(content))
for i, block := range content {
part := contentBlockToPart(block)
if part.Type == "" {
continue
}
if part.Type == codersdk.ChatMessagePartTypeReasoning {
part.Title = ""
if i < len(rawBlocks) {
part.Title = reasoningStoredTitle(rawBlocks[i])
}
}
parts = append(parts, part)
}
return parts, nil
case string(fantasy.MessageRoleTool):
results, err := parseToolResults(raw)
if err != nil {
return nil, err
}
parts := make([]codersdk.ChatMessagePart, 0, len(results))
for _, result := range results {
parts = append(parts, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: result.ToolCallID,
ToolName: result.ToolName,
Result: result.Result,
IsError: result.IsError,
})
}
return parts, nil
default:
return nil, nil
}
}
func parseSystemContent(raw pqtype.NullRawMessage) (string, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return "", nil
}
var content string
if err := json.Unmarshal(raw.RawMessage, &content); err != nil {
return "", xerrors.Errorf("parse system content: %w", err)
}
return content, nil
}
func parseContentBlocks(role string, raw pqtype.NullRawMessage) ([]fantasy.Content, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return nil, nil
}
if role == string(fantasy.MessageRoleUser) {
var text string
if err := json.Unmarshal(raw.RawMessage, &text); err == nil {
return []fantasy.Content{
fantasy.TextContent{Text: text},
}, nil
}
}
var blocks []json.RawMessage
if err := json.Unmarshal(raw.RawMessage, &blocks); err != nil {
return nil, xerrors.Errorf("parse content blocks: %w", err)
}
content := make([]fantasy.Content, 0, len(blocks))
for _, block := range blocks {
decoded, err := fantasy.UnmarshalContent(block)
if err != nil {
return nil, xerrors.Errorf("parse content block: %w", err)
}
content = append(content, decoded)
}
return content, nil
}
// toolResultRow is used only for extracting top-level fields from
// persisted tool result JSON. The result payload is kept as raw JSON.
type toolResultRow struct {
ToolCallID string `json:"tool_call_id"`
ToolName string `json:"tool_name"`
Result json.RawMessage `json:"result"`
IsError bool `json:"is_error,omitempty"`
}
func parseToolResults(raw pqtype.NullRawMessage) ([]toolResultRow, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return nil, nil
}
var results []toolResultRow
if err := json.Unmarshal(raw.RawMessage, &results); err != nil {
return nil, xerrors.Errorf("parse tool results: %w", err)
}
return results, nil
}
func reasoningStoredTitle(raw json.RawMessage) string {
var envelope struct {
Type string `json:"type"`
Data struct {
Title string `json:"title"`
} `json:"data"`
}
if err := json.Unmarshal(raw, &envelope); err != nil {
return ""
}
if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeReasoning)) {
return ""
}
return strings.TrimSpace(envelope.Data.Title)
}
func contentBlockToPart(block fantasy.Content) codersdk.ChatMessagePart {
switch value := block.(type) {
case fantasy.TextContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: value.Text,
}
case *fantasy.TextContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: value.Text,
}
case fantasy.ReasoningContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: value.Text,
}
case *fantasy.ReasoningContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: value.Text,
}
case fantasy.ToolCallContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: value.ToolCallID,
ToolName: value.ToolName,
Args: []byte(value.Input),
}
case *fantasy.ToolCallContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: value.ToolCallID,
ToolName: value.ToolName,
Args: []byte(value.Input),
}
case fantasy.SourceContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSource,
SourceID: value.ID,
URL: value.URL,
Title: value.Title,
}
case *fantasy.SourceContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSource,
SourceID: value.ID,
URL: value.URL,
Title: value.Title,
}
case fantasy.FileContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeFile,
MediaType: value.MediaType,
Data: value.Data,
}
case *fantasy.FileContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeFile,
MediaType: value.MediaType,
Data: value.Data,
}
case fantasy.ToolResultContent:
return chatprompt.ToolResultToPart(
value.ToolCallID,
value.ToolName,
toolResultOutputToRawJSON(value.Result),
toolResultOutputIsError(value.Result),
)
case *fantasy.ToolResultContent:
return chatprompt.ToolResultToPart(
value.ToolCallID,
value.ToolName,
toolResultOutputToRawJSON(value.Result),
toolResultOutputIsError(value.Result),
)
default:
return codersdk.ChatMessagePart{}
}
}
func toolResultOutputToRawJSON(output fantasy.ToolResultOutputContent) json.RawMessage {
switch v := output.(type) {
case fantasy.ToolResultOutputContentError:
if v.Error != nil {
data, _ := json.Marshal(map[string]any{"error": v.Error.Error()})
return data
}
return json.RawMessage(`{"error":""}`)
case fantasy.ToolResultOutputContentText:
raw := json.RawMessage(v.Text)
if json.Valid(raw) {
return raw
}
data, _ := json.Marshal(map[string]any{"output": v.Text})
return data
case fantasy.ToolResultOutputContentMedia:
data, _ := json.Marshal(map[string]any{
"data": v.Data,
"mime_type": v.MediaType,
"text": v.Text,
})
return data
default:
return json.RawMessage(`{}`)
}
}
func toolResultOutputIsError(output fantasy.ToolResultOutputContent) bool {
_, ok := output.(fantasy.ToolResultOutputContentError)
return ok
}
func nullInt64Ptr(v sql.NullInt64) *int64 {
if !v.Valid {
return nil
}
value := v.Int64
return &value
}
+131
View File
@@ -8,6 +8,8 @@ import (
"testing"
"time"
"charm.land/fantasy"
fantasyopenai "charm.land/fantasy/providers/openai"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
@@ -435,3 +437,132 @@ func TestAIBridgeInterception(t *testing.T) {
})
}
}
func TestChatMessage_ReasoningPartWithoutPersistedTitleIsEmpty(t *testing.T) {
t.Parallel()
assistantContent, err := json.Marshal([]fantasy.Content{
fantasy.ReasoningContent{
Text: "Plan migration",
ProviderMetadata: fantasy.ProviderMetadata{
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
ItemID: "reasoning-1",
Summary: []string{"Plan migration"},
},
},
},
})
require.NoError(t, err)
message := db2sdk.ChatMessage(database.ChatMessage{
ID: 1,
ChatID: uuid.New(),
CreatedAt: time.Now(),
Role: string(fantasy.MessageRoleAssistant),
Content: pqtype.NullRawMessage{
RawMessage: assistantContent,
Valid: true,
},
})
require.Len(t, message.Content, 1)
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Content[0].Type)
require.Equal(t, "Plan migration", message.Content[0].Text)
require.Empty(t, message.Content[0].Title)
}
func TestChatMessage_ReasoningPartPrefersPersistedTitle(t *testing.T) {
t.Parallel()
reasoningContent, err := json.Marshal(fantasy.ReasoningContent{
Text: "Verify schema updates, then apply changes in order.",
ProviderMetadata: fantasy.ProviderMetadata{
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
ItemID: "reasoning-1",
Summary: []string{
"**Metadata-derived title**\n\nLonger explanation.",
},
},
},
})
require.NoError(t, err)
var envelope map[string]any
require.NoError(t, json.Unmarshal(reasoningContent, &envelope))
dataValue, ok := envelope["data"].(map[string]any)
require.True(t, ok)
dataValue["title"] = "Persisted stream title"
encodedReasoning, err := json.Marshal(envelope)
require.NoError(t, err)
assistantContent, err := json.Marshal([]json.RawMessage{encodedReasoning})
require.NoError(t, err)
message := db2sdk.ChatMessage(database.ChatMessage{
ID: 1,
ChatID: uuid.New(),
CreatedAt: time.Now(),
Role: string(fantasy.MessageRoleAssistant),
Content: pqtype.NullRawMessage{
RawMessage: assistantContent,
Valid: true,
},
})
require.Len(t, message.Content, 1)
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Content[0].Type)
require.Equal(t, "Persisted stream title", message.Content[0].Title)
}
func TestChatQueuedMessage_ParsesUserContentParts(t *testing.T) {
t.Parallel()
rawContent, err := json.Marshal([]fantasy.Content{
fantasy.TextContent{Text: "queued text"},
})
require.NoError(t, err)
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
ID: 1,
ChatID: uuid.New(),
Content: rawContent,
CreatedAt: time.Now(),
})
require.Len(t, queued.Content, 1)
require.Equal(t, codersdk.ChatMessagePartTypeText, queued.Content[0].Type)
require.Equal(t, "queued text", queued.Content[0].Text)
}
func TestChatQueuedMessage_FallsBackToTextForLegacyContent(t *testing.T) {
t.Parallel()
t.Run("legacy_string", func(t *testing.T) {
t.Parallel()
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
ID: 1,
ChatID: uuid.New(),
Content: json.RawMessage(`"legacy queued text"`),
CreatedAt: time.Now(),
})
require.Len(t, queued.Content, 1)
require.Equal(t, codersdk.ChatMessagePartTypeText, queued.Content[0].Type)
require.Equal(t, "legacy queued text", queued.Content[0].Text)
})
t.Run("malformed_payload", func(t *testing.T) {
t.Parallel()
raw := json.RawMessage(`{"unexpected":"shape"}`)
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
ID: 1,
ChatID: uuid.New(),
Content: raw,
CreatedAt: time.Now(),
})
require.Empty(t, queued.Content)
})
}
+415 -10
View File
@@ -453,6 +453,7 @@ var (
rbac.ResourceProvisionerJobs.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreate},
rbac.ResourceOauth2App.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
rbac.ResourceOauth2AppSecret.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
rbac.ResourceChat.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
}),
User: []rbac.Permission{},
ByOrgID: map[string]rbac.OrgPermissions{},
@@ -1484,6 +1485,15 @@ func (q *querier) authorizeProvisionerJob(ctx context.Context, job database.Prov
return nil
}
func (q *querier) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
// AcquireChat is a system-level operation used by the chat processor.
// Authorization is done at the system level, not per-user.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return database.Chat{}, err
}
return q.db.AcquireChat(ctx, arg)
}
func (q *querier) AcquireLock(ctx context.Context, id int64) error {
return q.db.AcquireLock(ctx, id)
}
@@ -1712,6 +1722,17 @@ func (q *querier) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) e
return q.db.DeleteAPIKeysByUserID(ctx, userID)
}
func (q *querier) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error {
chat, err := q.db.GetChatByID(ctx, chatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.DeleteAllChatQueuedMessages(ctx, chatID)
}
func (q *querier) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil {
return err
@@ -1736,6 +1757,66 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
}
func (q *querier) DeleteChatByID(ctx context.Context, id uuid.UUID) error {
chat, err := q.db.GetChatByID(ctx, id)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionDelete, chat); err != nil {
return err
}
return q.db.DeleteChatByID(ctx, id)
}
func (q *querier) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
// Authorize update on the parent chat.
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.DeleteChatMessagesAfterID(ctx, arg)
}
func (q *querier) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
// Authorize delete on the parent chat.
chat, err := q.db.GetChatByID(ctx, chatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionDelete, chat); err != nil {
return err
}
return q.db.DeleteChatMessagesByChatID(ctx, chatID)
}
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteChatModelConfigByID(ctx, id)
}
func (q *querier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteChatProviderByID(ctx, id)
}
func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.DeleteChatQueuedMessage(ctx, arg)
}
func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceCryptoKey); err != nil {
return database.CryptoKey{}, err
@@ -2304,6 +2385,131 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
return q.db.GetAuthorizationUserRoles(ctx, userID)
}
func (q *querier) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
return fetch(q.log, q.auth, q.db.GetChatByID)(ctx, id)
}
func (q *querier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (database.Chat, error) {
return fetch(q.log, q.auth, q.db.GetChatByIDForUpdate)(ctx, id)
}
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
if err != nil {
return database.ChatDiffStatus{}, err
}
return q.db.GetChatDiffStatusByChatID(ctx, chatID)
}
func (q *querier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uuid.UUID) ([]database.ChatDiffStatus, error) {
if len(chatIDs) == 0 {
return []database.ChatDiffStatus{}, nil
}
actor, ok := ActorFromContext(ctx)
if ok && actor.Type == rbac.SubjectTypeSystemRestricted {
return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
}
for _, chatID := range chatIDs {
// Authorize read on each parent chat.
_, err := q.GetChatByID(ctx, chatID)
if err != nil {
return nil, err
}
}
return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
}
func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
// ChatMessages are authorized through their parent Chat.
// We need to fetch the message first to get its chat_id.
msg, err := q.db.GetChatMessageByID(ctx, id)
if err != nil {
return database.ChatMessage{}, err
}
// Authorize read on the parent chat.
_, err = q.GetChatByID(ctx, msg.ChatID)
if err != nil {
return database.ChatMessage{}, err
}
return msg, nil
}
func (q *querier) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
if err != nil {
return nil, err
}
return q.db.GetChatMessagesByChatID(ctx, chatID)
}
func (q *querier) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
if err != nil {
return nil, err
}
return q.db.GetChatMessagesForPromptByChatID(ctx, chatID)
}
func (q *querier) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatModelConfig{}, err
}
return q.db.GetChatModelConfigByID(ctx, id)
}
func (q *querier) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatModelConfig{}, err
}
return q.db.GetChatModelConfigByProviderAndModel(ctx, arg)
}
func (q *querier) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetChatModelConfigs(ctx)
}
func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.GetChatProviderByID(ctx, id)
}
func (q *querier) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.GetChatProviderByProvider(ctx, provider)
}
func (q *querier) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetChatProviders(ctx)
}
func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
_, err := q.GetChatByID(ctx, chatID)
if err != nil {
return nil, err
}
return q.db.GetChatQueuedMessages(ctx, chatID)
}
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
}
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
// Just like with the audit logs query, shortcut if the user is an owner.
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
@@ -2361,6 +2567,13 @@ func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
return q.db.GetDERPMeshKey(ctx)
}
func (q *querier) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatModelConfig{}, err
}
return q.db.GetDefaultChatModelConfig(ctx)
}
func (q *querier) GetDefaultOrganization(ctx context.Context) (database.Organization, error) {
return fetch(q.log, q.auth, func(ctx context.Context, _ any) (database.Organization, error) {
return q.db.GetDefaultOrganization(ctx)
@@ -2401,6 +2614,20 @@ func (q *querier) GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx context.C
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetEligibleProvisionerDaemonsByProvisionerJobIDs)(ctx, provisionerJobIDs)
}
func (q *querier) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetEnabledChatModelConfigs(ctx)
}
func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetEnabledChatProviders(ctx)
}
func (q *querier) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetExternalAuthLink)(ctx, arg)
}
@@ -3100,6 +3327,14 @@ func (q *querier) GetRuntimeConfig(ctx context.Context, key string) (string, err
return q.db.GetRuntimeConfig(ctx, key)
}
func (q *querier) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]database.Chat, error) {
// GetStaleChats is a system-level operation used by the chat processor for recovery.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat); err != nil {
return nil, err
}
return q.db.GetStaleChats(ctx, staleThreshold)
}
func (q *querier) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database.TailnetPeer, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return nil, err
@@ -4223,6 +4458,47 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo
return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg)
}
func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg)
}
func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
// Authorize create on the parent chat (using update permission).
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatMessage{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatMessage{}, err
}
return q.db.InsertChatMessage(ctx, arg)
}
func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatModelConfig{}, err
}
return q.db.InsertChatModelConfig(ctx, arg)
}
func (q *querier) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.InsertChatProvider(ctx, arg)
}
func (q *querier) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatQueuedMessage{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatQueuedMessage{}, err
}
return q.db.InsertChatQueuedMessage(ctx, arg)
}
func (q *querier) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceCryptoKey); err != nil {
return database.CryptoKey{}, err
@@ -4829,6 +5105,14 @@ func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context,
return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs)
}
func (q *querier) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListChatsByRootID)(ctx, rootChatID)
}
func (q *querier) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListChildChatsByParentID)(ctx, parentChatID)
}
func (q *querier) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListProvisionerKeysByOrganization)(ctx, organizationID)
}
@@ -4909,6 +5193,17 @@ func (q *querier) PaginatedOrganizationMembers(ctx context.Context, arg database
return q.db.PaginatedOrganizationMembers(ctx, arg)
}
func (q *querier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
chat, err := q.db.GetChatByID(ctx, chatID)
if err != nil {
return database.ChatQueuedMessage{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatQueuedMessage{}, err
}
return q.db.PopNextQueuedMessage(ctx, chatID)
}
func (q *querier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
template, err := q.db.GetTemplateByID(ctx, templateID)
if err != nil {
@@ -4987,6 +5282,13 @@ func (q *querier) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error {
return update(q.log, q.auth, fetch, q.db.UnfavoriteWorkspace)(ctx, id)
}
func (q *querier) UnsetDefaultChatModelConfigs(ctx context.Context) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.UnsetDefaultChatModelConfigs(ctx)
}
func (q *querier) UpdateAIBridgeInterceptionEnded(ctx context.Context, params database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, params.ID); err != nil {
return database.AIBridgeInterception{}, err
@@ -5001,6 +5303,91 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe
return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg)
}
func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
return q.db.UpdateChatByID(ctx, arg)
}
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return 0, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return 0, err
}
return q.db.UpdateChatHeartbeat(ctx, arg)
}
func (q *querier) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
// Authorize update on the parent chat of the edited message.
msg, err := q.db.GetChatMessageByID(ctx, arg.ID)
if err != nil {
return database.ChatMessage{}, err
}
chat, err := q.db.GetChatByID(ctx, msg.ChatID)
if err != nil {
return database.ChatMessage{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatMessage{}, err
}
return q.db.UpdateChatMessageByID(ctx, arg)
}
func (q *querier) UpdateChatModelConfig(ctx context.Context, arg database.UpdateChatModelConfigParams) (database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatModelConfig{}, err
}
return q.db.UpdateChatModelConfig(ctx, arg)
}
func (q *querier) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.UpdateChatProvider(ctx, arg)
}
func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
// UpdateChatStatus is used by the chat processor to change chat status.
// It should be called with system context.
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
return q.db.UpdateChatStatus(ctx, arg)
}
func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
// UpdateChatWorkspace is manually implemented for chat tables and may not be
// present on every wrapped store interface yet.
chatWorkspaceUpdater, ok := q.db.(interface {
UpdateChatWorkspace(context.Context, database.UpdateChatWorkspaceParams) (database.Chat, error)
})
if !ok {
return database.Chat{}, xerrors.New("update chat workspace is not implemented by wrapped store")
}
return chatWorkspaceUpdater.UpdateChatWorkspace(ctx, arg)
}
func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceCryptoKey); err != nil {
return database.CryptoKey{}, err
@@ -6056,6 +6443,30 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
return q.db.UpsertBoundaryUsageStats(ctx, arg)
}
func (q *querier) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
// Authorize update on the parent chat.
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatDiffStatus{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatDiffStatus{}, err
}
return q.db.UpsertChatDiffStatus(ctx, arg)
}
func (q *querier) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
// Authorize update on the parent chat.
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatDiffStatus{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatDiffStatus{}, err
}
return q.db.UpsertChatDiffStatusReference(ctx, arg)
}
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return database.ConnectionLog{}, err
@@ -6347,18 +6758,12 @@ func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg databas
return q.CountConnectionLogs(ctx, arg)
}
func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
// TODO: Delete this function, all ListAIBridgeInterceptions should be authorized. For now just call ListAIBridgeInterceptions on the authz querier.
// This cannot be deleted for now because it's included in the
// database.Store interface, so dbauthz needs to implement it.
return q.ListAIBridgeInterceptions(ctx, arg)
func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
return q.db.ListAuthorizedAIBridgeInterceptions(ctx, arg, prepared)
}
func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) (int64, error) {
// TODO: Delete this function, all CountAIBridgeInterceptions should be authorized. For now just call CountAIBridgeInterceptions on the authz querier.
// This cannot be deleted for now because it's included in the
// database.Store interface, so dbauthz needs to implement it.
return q.CountAIBridgeInterceptions(ctx, arg)
func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
return q.db.CountAuthorizedAIBridgeInterceptions(ctx, arg, prepared)
}
func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams, _ rbac.PreparedAuthorized) ([]string, error) {
+373 -1
View File
@@ -170,6 +170,7 @@ func TestDBAuthzRecursive(t *testing.T) {
Groups: []string{},
Scope: rbac.ScopeAll,
}
preparedAuthorizedType := reflect.TypeOf((*rbac.PreparedAuthorized)(nil)).Elem()
for i := 0; i < reflect.TypeOf(q).NumMethod(); i++ {
var ins []reflect.Value
ctx := dbauthz.As(context.Background(), actor)
@@ -177,7 +178,13 @@ func TestDBAuthzRecursive(t *testing.T) {
ins = append(ins, reflect.ValueOf(ctx))
method := reflect.TypeOf(q).Method(i)
for i := 2; i < method.Type.NumIn(); i++ {
ins = append(ins, reflect.New(method.Type.In(i)).Elem())
inType := method.Type.In(i)
if inType.Implements(preparedAuthorizedType) {
ins = append(ins, reflect.ValueOf(emptyPreparedAuthorized{}))
continue
}
ins = append(ins, reflect.New(inType).Elem())
}
if method.Name == "InTx" ||
method.Name == "Ping" ||
@@ -364,6 +371,371 @@ func (s *MethodTestSuite) TestConnectionLogs() {
}))
}
func (s *MethodTestSuite) TestChats() {
s.Run("AcquireChat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.AcquireChatParams{
StartedAt: dbtime.Now(),
WorkerID: uuid.New(),
}
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().AcquireChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns(chat)
}))
s.Run("DeleteAllChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteAllChatQueuedMessages(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("DeleteChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionDelete).Returns()
}))
s.Run("DeleteChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteChatMessagesByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionDelete).Returns()
}))
s.Run("DeleteChatMessagesAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.DeleteChatMessagesAfterIDParams{
ChatID: chat.ID,
AfterID: 123,
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteChatMessagesAfterID(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("DeleteChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
id := uuid.New()
dbm.EXPECT().DeleteChatModelConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("DeleteChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
id := uuid.New()
dbm.EXPECT().DeleteChatProviderByID(gomock.Any(), id).Return(nil).AnyTimes()
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("DeleteChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
args := database.DeleteChatQueuedMessageParams{ID: 123, ChatID: chat.ID}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteChatQueuedMessage(gomock.Any(), args).Return(nil).AnyTimes()
check.Args(args).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("GetChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
}))
s.Run("GetChatByIDForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByIDForUpdate(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
}))
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatDiffStatusByChatID(gomock.Any(), chat.ID).Return(diffStatus, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(diffStatus)
}))
s.Run("GetChatDiffStatusesByChatIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chatA := testutil.Fake(s.T(), faker, database.Chat{})
chatB := testutil.Fake(s.T(), faker, database.Chat{})
ids := []uuid.UUID{chatA.ID, chatB.ID}
diffStatusA := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chatA.ID})
diffStatusB := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chatB.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chatA.ID).Return(chatA, nil).AnyTimes()
dbm.EXPECT().GetChatByID(gomock.Any(), chatB.ID).Return(chatB, nil).AnyTimes()
dbm.EXPECT().GetChatDiffStatusesByChatIDs(gomock.Any(), ids).Return([]database.ChatDiffStatus{diffStatusA, diffStatusB}, nil).AnyTimes()
check.Args(ids).
Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).
Returns([]database.ChatDiffStatus{diffStatusA, diffStatusB})
}))
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
dbm.EXPECT().GetChatMessageByID(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
check.Args(msg.ID).Asserts(chat, policy.ActionRead).Returns(msg)
}))
s.Run("GetChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), chat.ID).Return(msgs, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(msgs)
}))
s.Run("GetChatMessagesForPromptByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chat.ID).Return(msgs, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(msgs)
}))
s.Run("GetChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
dbm.EXPECT().GetChatModelConfigByID(gomock.Any(), config.ID).Return(config, nil).AnyTimes()
check.Args(config.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetDefaultChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
dbm.EXPECT().GetDefaultChatModelConfig(gomock.Any()).Return(config, nil).AnyTimes()
check.Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetChatModelConfigByProviderAndModel", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
args := database.GetChatModelConfigByProviderAndModelParams{
Provider: config.Provider,
Model: config.Model,
}
dbm.EXPECT().GetChatModelConfigByProviderAndModel(gomock.Any(), args).Return(config, nil).AnyTimes()
check.Args(args).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
dbm.EXPECT().GetChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
}))
s.Run("GetChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
dbm.EXPECT().GetChatProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
}))
s.Run("GetChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
providerName := "test-provider"
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName})
dbm.EXPECT().GetChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
}))
s.Run("GetChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
}))
s.Run("GetChatsByOwnerID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
c1 := testutil.Fake(s.T(), faker, database.Chat{})
c2 := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatsByOwnerID(gomock.Any(), c1.OwnerID).Return([]database.Chat{c1, c2}, nil).AnyTimes()
check.Args(c1.OwnerID).Asserts(c1, policy.ActionRead, c2, policy.ActionRead).Returns([]database.Chat{c1, c2})
}))
s.Run("GetChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
qms := []database.ChatQueuedMessage{testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatQueuedMessages(gomock.Any(), chat.ID).Return(qms, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(qms)
}))
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
dbm.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
}))
s.Run("GetEnabledChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
dbm.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
}))
s.Run("ListChatsByRootID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
rootChatID := uuid.New()
chatA := testutil.Fake(s.T(), faker, database.Chat{RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}})
chatB := testutil.Fake(s.T(), faker, database.Chat{RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}})
dbm.EXPECT().ListChatsByRootID(gomock.Any(), rootChatID).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
check.Args(rootChatID).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
}))
s.Run("ListChildChatsByParentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
parentChatID := uuid.New()
chatA := testutil.Fake(s.T(), faker, database.Chat{ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}})
chatB := testutil.Fake(s.T(), faker, database.Chat{ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}})
dbm.EXPECT().ListChildChatsByParentID(gomock.Any(), parentChatID).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
check.Args(parentChatID).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
}))
s.Run("GetStaleChats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
threshold := dbtime.Now()
chats := []database.Chat{testutil.Fake(s.T(), faker, database.Chat{})}
dbm.EXPECT().GetStaleChats(gomock.Any(), threshold).Return(chats, nil).AnyTimes()
check.Args(threshold).Asserts(rbac.ResourceChat, policy.ActionRead).Returns(chats)
}))
s.Run("InsertChat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := testutil.Fake(s.T(), faker, database.InsertChatParams{})
chat := testutil.Fake(s.T(), faker, database.Chat{OwnerID: arg.OwnerID})
dbm.EXPECT().InsertChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionCreate).Returns(chat)
}))
s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().InsertChatMessage(gomock.Any(), arg).Return(msg, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msg)
}))
s.Run("InsertChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := testutil.Fake(s.T(), faker, database.InsertChatQueuedMessageParams{ChatID: chat.ID})
qm := testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().InsertChatQueuedMessage(gomock.Any(), arg).Return(qm, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(qm)
}))
s.Run("InsertChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.InsertChatModelConfigParams{
Provider: "test-provider",
Model: "test-model",
DisplayName: "Test Model",
Enabled: true,
}
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{Provider: arg.Provider, Model: arg.Model, DisplayName: arg.DisplayName, Enabled: arg.Enabled})
dbm.EXPECT().InsertChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("InsertChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.InsertChatProviderParams{
Provider: "test-provider",
DisplayName: "Test Provider",
APIKey: "test-api-key",
Enabled: true,
}
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: arg.Provider, DisplayName: arg.DisplayName, APIKey: arg.APIKey, Enabled: arg.Enabled})
dbm.EXPECT().InsertChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
}))
s.Run("PopNextQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
qm := testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().PopNextQueuedMessage(gomock.Any(), chat.ID).Return(qm, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns(qm)
}))
s.Run("UpdateChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatByIDParams{
ID: chat.ID,
Title: "Updated title",
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatByID(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatHeartbeatParams{
ID: chat.ID,
WorkerID: uuid.New(),
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
}))
s.Run("UpdateChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
arg := database.UpdateChatMessageByIDParams{
ID: msg.ID,
ModelConfigID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
Content: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"blocks":[{"type":"text","text":"updated"}]}`),
Valid: true,
},
}
updated := testutil.Fake(s.T(), faker, database.ChatMessage{ID: msg.ID, ChatID: chat.ID})
dbm.EXPECT().GetChatMessageByID(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatMessageByID(gomock.Any(), arg).Return(updated, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updated)
}))
s.Run("UpdateChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
arg := database.UpdateChatModelConfigParams{
ID: config.ID,
Provider: "updated-provider",
Model: "updated-model",
DisplayName: "Updated Model",
Enabled: true,
}
dbm.EXPECT().UpdateChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("UpdateChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
arg := database.UpdateChatProviderParams{
ID: provider.ID,
DisplayName: "Updated Provider",
APIKey: "updated-api-key",
Enabled: true,
}
dbm.EXPECT().UpdateChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
}))
s.Run("UpdateChatStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatStatus(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateChatWorkspace", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatWorkspaceParams{
ID: chat.ID,
WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
WorkspaceAgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
}
updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatWorkspace(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat)
}))
s.Run("UnsetDefaultChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UnsetDefaultChatModelConfigs(gomock.Any()).Return(nil).AnyTimes()
check.Args().Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("UpsertChatDiffStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
now := dbtime.Now()
arg := database.UpsertChatDiffStatusParams{
ChatID: chat.ID,
Url: sql.NullString{String: "https://example.com/pr/123", Valid: true},
PullRequestState: sql.NullString{String: "open", Valid: true},
ChangesRequested: false,
Additions: 10,
Deletions: 5,
ChangedFiles: 2,
RefreshedAt: now,
StaleAt: now.Add(time.Hour),
}
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpsertChatDiffStatus(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus)
}))
s.Run("UpsertChatDiffStatusReference", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpsertChatDiffStatusReferenceParams{
ChatID: chat.ID,
Url: sql.NullString{String: "https://example.com/pr/123", Valid: true},
GitBranch: "feature/test",
GitRemoteOrigin: "origin",
StaleAt: dbtime.Now().Add(time.Hour),
}
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus)
}))
}
func (s *MethodTestSuite) TestFile() {
s.Run("GetFileByHashAndCreator", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
f := testutil.Fake(s.T(), faker, database.File{})
+361
View File
@@ -104,6 +104,14 @@ func (m queryMetricsStore) DeleteOrganization(ctx context.Context, id uuid.UUID)
return r0
}
func (m queryMetricsStore) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.AcquireChat(ctx, arg)
m.queryLatencies.WithLabelValues("AcquireChat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "AcquireChat").Inc()
return r0, r1
}
func (m queryMetricsStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error {
start := time.Now()
r0 := m.s.AcquireLock(ctx, pgAdvisoryXactLock)
@@ -156,6 +164,7 @@ func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context
start := time.Now()
r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg)
m.queryLatencies.WithLabelValues("BatchUpdateWorkspaceAgentMetadata").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpdateWorkspaceAgentMetadata").Inc()
return r0
}
@@ -311,6 +320,14 @@ func (m queryMetricsStore) DeleteAPIKeysByUserID(ctx context.Context, userID uui
return r0
}
func (m queryMetricsStore) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteAllChatQueuedMessages(ctx, chatID)
m.queryLatencies.WithLabelValues("DeleteAllChatQueuedMessages").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteAllChatQueuedMessages").Inc()
return r0
}
func (m queryMetricsStore) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error {
start := time.Now()
r0 := m.s.DeleteAllTailnetTunnels(ctx, arg)
@@ -335,6 +352,54 @@ func (m queryMetricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.C
return r0
}
func (m queryMetricsStore) DeleteChatByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatByID(ctx, id)
m.queryLatencies.WithLabelValues("DeleteChatByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatByID").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
start := time.Now()
r0 := m.s.DeleteChatMessagesAfterID(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteChatMessagesAfterID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatMessagesAfterID").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatMessagesByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("DeleteChatMessagesByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatMessagesByChatID").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatModelConfigByID(ctx, id)
m.queryLatencies.WithLabelValues("DeleteChatModelConfigByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatModelConfigByID").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatProviderByID(ctx, id)
m.queryLatencies.WithLabelValues("DeleteChatProviderByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatProviderByID").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
start := time.Now()
r0 := m.s.DeleteChatQueuedMessage(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteChatQueuedMessage").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatQueuedMessage").Inc()
return r0
}
func (m queryMetricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
start := time.Now()
r0, r1 := m.s.DeleteCryptoKey(ctx, arg)
@@ -902,6 +967,126 @@ func (m queryMetricsStore) GetAuthorizationUserRoles(ctx context.Context, userID
return r0, r1
}
func (m queryMetricsStore) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetChatByID(ctx, id)
m.queryLatencies.WithLabelValues("GetChatByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetChatByIDForUpdate(ctx, id)
m.queryLatencies.WithLabelValues("GetChatByIDForUpdate").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatByIDForUpdate").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("GetChatDiffStatusByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDiffStatusByChatID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uuid.UUID) ([]database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
m.queryLatencies.WithLabelValues("GetChatDiffStatusesByChatIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDiffStatusesByChatIDs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessageByID(ctx, id)
m.queryLatencies.WithLabelValues("GetChatMessageByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessageByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("GetChatMessagesByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesByChatID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessagesForPromptByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("GetChatMessagesForPromptByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesForPromptByChatID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.GetChatModelConfigByID(ctx, id)
m.queryLatencies.WithLabelValues("GetChatModelConfigByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.GetChatModelConfigByProviderAndModel(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatModelConfigByProviderAndModel").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigByProviderAndModel").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.GetChatModelConfigs(ctx)
m.queryLatencies.WithLabelValues("GetChatModelConfigs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetChatProviderByID(ctx, id)
m.queryLatencies.WithLabelValues("GetChatProviderByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetChatProviderByProvider(ctx, provider)
m.queryLatencies.WithLabelValues("GetChatProviderByProvider").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByProvider").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetChatProviders(ctx)
m.queryLatencies.WithLabelValues("GetChatProviders").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviders").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatQueuedMessages(ctx, chatID)
m.queryLatencies.WithLabelValues("GetChatQueuedMessages").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatQueuedMessages").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID)
m.queryLatencies.WithLabelValues("GetChatsByOwnerID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsByOwnerID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
start := time.Now()
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
@@ -958,6 +1143,14 @@ func (m queryMetricsStore) GetDERPMeshKey(ctx context.Context) (string, error) {
return r0, r1
}
func (m queryMetricsStore) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.GetDefaultChatModelConfig(ctx)
m.queryLatencies.WithLabelValues("GetDefaultChatModelConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetDefaultChatModelConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) GetDefaultOrganization(ctx context.Context) (database.Organization, error) {
start := time.Now()
r0, r1 := m.s.GetDefaultOrganization(ctx)
@@ -1022,6 +1215,22 @@ func (m queryMetricsStore) GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx
return r0, r1
}
func (m queryMetricsStore) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.GetEnabledChatModelConfigs(ctx)
m.queryLatencies.WithLabelValues("GetEnabledChatModelConfigs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatModelConfigs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetEnabledChatProviders(ctx)
m.queryLatencies.WithLabelValues("GetEnabledChatProviders").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatProviders").Inc()
return r0, r1
}
func (m queryMetricsStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
start := time.Now()
r0, r1 := m.s.GetExternalAuthLink(ctx, arg)
@@ -1718,6 +1927,14 @@ func (m queryMetricsStore) GetRuntimeConfig(ctx context.Context, key string) (st
return r0, r1
}
func (m queryMetricsStore) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetStaleChats(ctx, staleThreshold)
m.queryLatencies.WithLabelValues("GetStaleChats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetStaleChats").Inc()
return r0, r1
}
func (m queryMetricsStore) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database.TailnetPeer, error) {
start := time.Now()
r0, r1 := m.s.GetTailnetPeers(ctx, id)
@@ -2710,6 +2927,46 @@ func (m queryMetricsStore) InsertAuditLog(ctx context.Context, arg database.Inse
return r0, r1
}
func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.InsertChat(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChat").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.InsertChatMessage(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatMessage").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatMessage").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.InsertChatModelConfig(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatModelConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatModelConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.InsertChatProvider(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatProvider").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatProvider").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
start := time.Now()
r0, r1 := m.s.InsertChatQueuedMessage(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatQueuedMessage").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatQueuedMessage").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) {
start := time.Now()
r0, r1 := m.s.InsertCryptoKey(ctx, arg)
@@ -3246,6 +3503,22 @@ func (m queryMetricsStore) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.
return r0, r1
}
func (m queryMetricsStore) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.ListChatsByRootID(ctx, rootChatID)
m.queryLatencies.WithLabelValues("ListChatsByRootID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatsByRootID").Inc()
return r0, r1
}
func (m queryMetricsStore) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.ListChildChatsByParentID(ctx, parentChatID)
m.queryLatencies.WithLabelValues("ListChildChatsByParentID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChildChatsByParentID").Inc()
return r0, r1
}
func (m queryMetricsStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
start := time.Now()
r0, r1 := m.s.ListProvisionerKeysByOrganization(ctx, organizationID)
@@ -3326,6 +3599,14 @@ func (m queryMetricsStore) PaginatedOrganizationMembers(ctx context.Context, arg
return r0, r1
}
func (m queryMetricsStore) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
start := time.Now()
r0, r1 := m.s.PopNextQueuedMessage(ctx, chatID)
m.queryLatencies.WithLabelValues("PopNextQueuedMessage").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "PopNextQueuedMessage").Inc()
return r0, r1
}
func (m queryMetricsStore) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
start := time.Now()
r0 := m.s.ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx, templateID)
@@ -3398,6 +3679,14 @@ func (m queryMetricsStore) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID
return r0
}
func (m queryMetricsStore) UnsetDefaultChatModelConfigs(ctx context.Context) error {
start := time.Now()
r0 := m.s.UnsetDefaultChatModelConfigs(ctx)
m.queryLatencies.WithLabelValues("UnsetDefaultChatModelConfigs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UnsetDefaultChatModelConfigs").Inc()
return r0
}
func (m queryMetricsStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
start := time.Now()
r0, r1 := m.s.UpdateAIBridgeInterceptionEnded(ctx, arg)
@@ -3414,6 +3703,62 @@ func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.Up
return r0
}
func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatByID(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatByID").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatMessageByID(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatMessageByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatMessageByID").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatModelConfig(ctx context.Context, arg database.UpdateChatModelConfigParams) (database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatModelConfig(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatModelConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatModelConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatProvider(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatProvider").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatProvider").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatStatus(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatStatus").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatStatus").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatWorkspace(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatWorkspace").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspace").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
start := time.Now()
r0, r1 := m.s.UpdateCryptoKeyDeletesAt(ctx, arg)
@@ -4125,6 +4470,22 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
return r0, r1
}
func (m queryMetricsStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatDiffStatus(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertChatDiffStatus").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDiffStatus").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatDiffStatusReference(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertChatDiffStatusReference").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDiffStatusReference").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
start := time.Now()
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
+667
View File
@@ -44,6 +44,21 @@ func (m *MockStore) EXPECT() *MockStoreMockRecorder {
return m.recorder
}
// AcquireChat mocks base method.
func (m *MockStore) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcquireChat", ctx, arg)
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcquireChat indicates an expected call of AcquireChat.
func (mr *MockStoreMockRecorder) AcquireChat(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireChat", reflect.TypeOf((*MockStore)(nil).AcquireChat), ctx, arg)
}
// AcquireLock mocks base method.
func (m *MockStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error {
m.ctrl.T.Helper()
@@ -469,6 +484,20 @@ func (mr *MockStoreMockRecorder) DeleteAPIKeysByUserID(ctx, userID any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteAPIKeysByUserID), ctx, userID)
}
// DeleteAllChatQueuedMessages mocks base method.
func (m *MockStore) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAllChatQueuedMessages", ctx, chatID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAllChatQueuedMessages indicates an expected call of DeleteAllChatQueuedMessages.
func (mr *MockStoreMockRecorder) DeleteAllChatQueuedMessages(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllChatQueuedMessages", reflect.TypeOf((*MockStore)(nil).DeleteAllChatQueuedMessages), ctx, chatID)
}
// DeleteAllTailnetTunnels mocks base method.
func (m *MockStore) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error {
m.ctrl.T.Helper()
@@ -511,6 +540,90 @@ func (mr *MockStoreMockRecorder) DeleteApplicationConnectAPIKeysByUserID(ctx, us
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteApplicationConnectAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteApplicationConnectAPIKeysByUserID), ctx, userID)
}
// DeleteChatByID mocks base method.
func (m *MockStore) DeleteChatByID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatByID", ctx, id)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatByID indicates an expected call of DeleteChatByID.
func (mr *MockStoreMockRecorder) DeleteChatByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatByID", reflect.TypeOf((*MockStore)(nil).DeleteChatByID), ctx, id)
}
// DeleteChatMessagesAfterID mocks base method.
func (m *MockStore) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatMessagesAfterID", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatMessagesAfterID indicates an expected call of DeleteChatMessagesAfterID.
func (mr *MockStoreMockRecorder) DeleteChatMessagesAfterID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).DeleteChatMessagesAfterID), ctx, arg)
}
// DeleteChatMessagesByChatID mocks base method.
func (m *MockStore) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatMessagesByChatID", ctx, chatID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatMessagesByChatID indicates an expected call of DeleteChatMessagesByChatID.
func (mr *MockStoreMockRecorder) DeleteChatMessagesByChatID(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).DeleteChatMessagesByChatID), ctx, chatID)
}
// DeleteChatModelConfigByID mocks base method.
func (m *MockStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatModelConfigByID", ctx, id)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatModelConfigByID indicates an expected call of DeleteChatModelConfigByID.
func (mr *MockStoreMockRecorder) DeleteChatModelConfigByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatModelConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteChatModelConfigByID), ctx, id)
}
// DeleteChatProviderByID mocks base method.
func (m *MockStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatProviderByID", ctx, id)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatProviderByID indicates an expected call of DeleteChatProviderByID.
func (mr *MockStoreMockRecorder) DeleteChatProviderByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatProviderByID", reflect.TypeOf((*MockStore)(nil).DeleteChatProviderByID), ctx, id)
}
// DeleteChatQueuedMessage mocks base method.
func (m *MockStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatQueuedMessage", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatQueuedMessage indicates an expected call of DeleteChatQueuedMessage.
func (mr *MockStoreMockRecorder) DeleteChatQueuedMessage(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatQueuedMessage", reflect.TypeOf((*MockStore)(nil).DeleteChatQueuedMessage), ctx, arg)
}
// DeleteCryptoKey mocks base method.
func (m *MockStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
m.ctrl.T.Helper()
@@ -1649,6 +1762,231 @@ func (mr *MockStoreMockRecorder) GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedWorkspacesAndAgentsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetAuthorizedWorkspacesAndAgentsByOwnerID), ctx, ownerID, prepared)
}
// GetChatByID mocks base method.
func (m *MockStore) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatByID", ctx, id)
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatByID indicates an expected call of GetChatByID.
func (mr *MockStoreMockRecorder) GetChatByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatByID", reflect.TypeOf((*MockStore)(nil).GetChatByID), ctx, id)
}
// GetChatByIDForUpdate mocks base method.
func (m *MockStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatByIDForUpdate", ctx, id)
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatByIDForUpdate indicates an expected call of GetChatByIDForUpdate.
func (mr *MockStoreMockRecorder) GetChatByIDForUpdate(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatByIDForUpdate), ctx, id)
}
// GetChatDiffStatusByChatID mocks base method.
func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatDiffStatusByChatID", ctx, chatID)
ret0, _ := ret[0].(database.ChatDiffStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatDiffStatusByChatID indicates an expected call of GetChatDiffStatusByChatID.
func (mr *MockStoreMockRecorder) GetChatDiffStatusByChatID(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDiffStatusByChatID", reflect.TypeOf((*MockStore)(nil).GetChatDiffStatusByChatID), ctx, chatID)
}
// GetChatDiffStatusesByChatIDs mocks base method.
func (m *MockStore) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatDiffStatusesByChatIDs", ctx, chatIds)
ret0, _ := ret[0].([]database.ChatDiffStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatDiffStatusesByChatIDs indicates an expected call of GetChatDiffStatusesByChatIDs.
func (mr *MockStoreMockRecorder) GetChatDiffStatusesByChatIDs(ctx, chatIds any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDiffStatusesByChatIDs", reflect.TypeOf((*MockStore)(nil).GetChatDiffStatusesByChatIDs), ctx, chatIds)
}
// GetChatMessageByID mocks base method.
func (m *MockStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatMessageByID", ctx, id)
ret0, _ := ret[0].(database.ChatMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatMessageByID indicates an expected call of GetChatMessageByID.
func (mr *MockStoreMockRecorder) GetChatMessageByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageByID", reflect.TypeOf((*MockStore)(nil).GetChatMessageByID), ctx, id)
}
// GetChatMessagesByChatID mocks base method.
func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatMessagesByChatID", ctx, chatID)
ret0, _ := ret[0].([]database.ChatMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatMessagesByChatID indicates an expected call of GetChatMessagesByChatID.
func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, chatID)
}
// GetChatMessagesForPromptByChatID mocks base method.
func (m *MockStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatMessagesForPromptByChatID", ctx, chatID)
ret0, _ := ret[0].([]database.ChatMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatMessagesForPromptByChatID indicates an expected call of GetChatMessagesForPromptByChatID.
func (mr *MockStoreMockRecorder) GetChatMessagesForPromptByChatID(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesForPromptByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesForPromptByChatID), ctx, chatID)
}
// GetChatModelConfigByID mocks base method.
func (m *MockStore) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatModelConfigByID", ctx, id)
ret0, _ := ret[0].(database.ChatModelConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatModelConfigByID indicates an expected call of GetChatModelConfigByID.
func (mr *MockStoreMockRecorder) GetChatModelConfigByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigByID", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigByID), ctx, id)
}
// GetChatModelConfigByProviderAndModel mocks base method.
func (m *MockStore) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatModelConfigByProviderAndModel", ctx, arg)
ret0, _ := ret[0].(database.ChatModelConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatModelConfigByProviderAndModel indicates an expected call of GetChatModelConfigByProviderAndModel.
func (mr *MockStoreMockRecorder) GetChatModelConfigByProviderAndModel(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigByProviderAndModel", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigByProviderAndModel), ctx, arg)
}
// GetChatModelConfigs mocks base method.
func (m *MockStore) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatModelConfigs", ctx)
ret0, _ := ret[0].([]database.ChatModelConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatModelConfigs indicates an expected call of GetChatModelConfigs.
func (mr *MockStoreMockRecorder) GetChatModelConfigs(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigs), ctx)
}
// GetChatProviderByID mocks base method.
func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatProviderByID", ctx, id)
ret0, _ := ret[0].(database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatProviderByID indicates an expected call of GetChatProviderByID.
func (mr *MockStoreMockRecorder) GetChatProviderByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByID", reflect.TypeOf((*MockStore)(nil).GetChatProviderByID), ctx, id)
}
// GetChatProviderByProvider mocks base method.
func (m *MockStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatProviderByProvider", ctx, provider)
ret0, _ := ret[0].(database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatProviderByProvider indicates an expected call of GetChatProviderByProvider.
func (mr *MockStoreMockRecorder) GetChatProviderByProvider(ctx, provider any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByProvider", reflect.TypeOf((*MockStore)(nil).GetChatProviderByProvider), ctx, provider)
}
// GetChatProviders mocks base method.
func (m *MockStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatProviders", ctx)
ret0, _ := ret[0].([]database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatProviders indicates an expected call of GetChatProviders.
func (mr *MockStoreMockRecorder) GetChatProviders(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviders", reflect.TypeOf((*MockStore)(nil).GetChatProviders), ctx)
}
// GetChatQueuedMessages mocks base method.
func (m *MockStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatQueuedMessages", ctx, chatID)
ret0, _ := ret[0].([]database.ChatQueuedMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatQueuedMessages indicates an expected call of GetChatQueuedMessages.
func (mr *MockStoreMockRecorder) GetChatQueuedMessages(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatQueuedMessages", reflect.TypeOf((*MockStore)(nil).GetChatQueuedMessages), ctx, chatID)
}
// GetChatsByOwnerID mocks base method.
func (m *MockStore) GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatsByOwnerID", ctx, ownerID)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatsByOwnerID indicates an expected call of GetChatsByOwnerID.
func (mr *MockStoreMockRecorder) GetChatsByOwnerID(ctx, ownerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetChatsByOwnerID), ctx, ownerID)
}
// GetConnectionLogsOffset mocks base method.
func (m *MockStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
m.ctrl.T.Helper()
@@ -1754,6 +2092,21 @@ func (mr *MockStoreMockRecorder) GetDERPMeshKey(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDERPMeshKey", reflect.TypeOf((*MockStore)(nil).GetDERPMeshKey), ctx)
}
// GetDefaultChatModelConfig mocks base method.
func (m *MockStore) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDefaultChatModelConfig", ctx)
ret0, _ := ret[0].(database.ChatModelConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetDefaultChatModelConfig indicates an expected call of GetDefaultChatModelConfig.
func (mr *MockStoreMockRecorder) GetDefaultChatModelConfig(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultChatModelConfig", reflect.TypeOf((*MockStore)(nil).GetDefaultChatModelConfig), ctx)
}
// GetDefaultOrganization mocks base method.
func (m *MockStore) GetDefaultOrganization(ctx context.Context) (database.Organization, error) {
m.ctrl.T.Helper()
@@ -1874,6 +2227,36 @@ func (mr *MockStoreMockRecorder) GetEligibleProvisionerDaemonsByProvisionerJobID
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEligibleProvisionerDaemonsByProvisionerJobIDs", reflect.TypeOf((*MockStore)(nil).GetEligibleProvisionerDaemonsByProvisionerJobIDs), ctx, provisionerJobIds)
}
// GetEnabledChatModelConfigs mocks base method.
func (m *MockStore) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEnabledChatModelConfigs", ctx)
ret0, _ := ret[0].([]database.ChatModelConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEnabledChatModelConfigs indicates an expected call of GetEnabledChatModelConfigs.
func (mr *MockStoreMockRecorder) GetEnabledChatModelConfigs(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetEnabledChatModelConfigs), ctx)
}
// GetEnabledChatProviders mocks base method.
func (m *MockStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEnabledChatProviders", ctx)
ret0, _ := ret[0].([]database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEnabledChatProviders indicates an expected call of GetEnabledChatProviders.
func (mr *MockStoreMockRecorder) GetEnabledChatProviders(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatProviders", reflect.TypeOf((*MockStore)(nil).GetEnabledChatProviders), ctx)
}
// GetExternalAuthLink mocks base method.
func (m *MockStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
m.ctrl.T.Helper()
@@ -3179,6 +3562,21 @@ func (mr *MockStoreMockRecorder) GetRuntimeConfig(ctx, key any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRuntimeConfig", reflect.TypeOf((*MockStore)(nil).GetRuntimeConfig), ctx, key)
}
// GetStaleChats mocks base method.
func (m *MockStore) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetStaleChats", ctx, staleThreshold)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetStaleChats indicates an expected call of GetStaleChats.
func (mr *MockStoreMockRecorder) GetStaleChats(ctx, staleThreshold any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStaleChats", reflect.TypeOf((*MockStore)(nil).GetStaleChats), ctx, staleThreshold)
}
// GetTailnetPeers mocks base method.
func (m *MockStore) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database.TailnetPeer, error) {
m.ctrl.T.Helper()
@@ -5083,6 +5481,81 @@ func (mr *MockStoreMockRecorder) InsertAuditLog(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAuditLog", reflect.TypeOf((*MockStore)(nil).InsertAuditLog), ctx, arg)
}
// InsertChat mocks base method.
func (m *MockStore) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertChat", ctx, arg)
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertChat indicates an expected call of InsertChat.
func (mr *MockStoreMockRecorder) InsertChat(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChat", reflect.TypeOf((*MockStore)(nil).InsertChat), ctx, arg)
}
// InsertChatMessage mocks base method.
func (m *MockStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertChatMessage", ctx, arg)
ret0, _ := ret[0].(database.ChatMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertChatMessage indicates an expected call of InsertChatMessage.
func (mr *MockStoreMockRecorder) InsertChatMessage(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatMessage", reflect.TypeOf((*MockStore)(nil).InsertChatMessage), ctx, arg)
}
// InsertChatModelConfig mocks base method.
func (m *MockStore) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertChatModelConfig", ctx, arg)
ret0, _ := ret[0].(database.ChatModelConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertChatModelConfig indicates an expected call of InsertChatModelConfig.
func (mr *MockStoreMockRecorder) InsertChatModelConfig(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatModelConfig", reflect.TypeOf((*MockStore)(nil).InsertChatModelConfig), ctx, arg)
}
// InsertChatProvider mocks base method.
func (m *MockStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertChatProvider", ctx, arg)
ret0, _ := ret[0].(database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertChatProvider indicates an expected call of InsertChatProvider.
func (mr *MockStoreMockRecorder) InsertChatProvider(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatProvider", reflect.TypeOf((*MockStore)(nil).InsertChatProvider), ctx, arg)
}
// InsertChatQueuedMessage mocks base method.
func (m *MockStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertChatQueuedMessage", ctx, arg)
ret0, _ := ret[0].(database.ChatQueuedMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertChatQueuedMessage indicates an expected call of InsertChatQueuedMessage.
func (mr *MockStoreMockRecorder) InsertChatQueuedMessage(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatQueuedMessage", reflect.TypeOf((*MockStore)(nil).InsertChatQueuedMessage), ctx, arg)
}
// InsertCryptoKey mocks base method.
func (m *MockStore) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) {
m.ctrl.T.Helper()
@@ -6102,6 +6575,36 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
}
// ListChatsByRootID mocks base method.
func (m *MockStore) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListChatsByRootID", ctx, rootChatID)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListChatsByRootID indicates an expected call of ListChatsByRootID.
func (mr *MockStoreMockRecorder) ListChatsByRootID(ctx, rootChatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatsByRootID", reflect.TypeOf((*MockStore)(nil).ListChatsByRootID), ctx, rootChatID)
}
// ListChildChatsByParentID mocks base method.
func (m *MockStore) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListChildChatsByParentID", ctx, parentChatID)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListChildChatsByParentID indicates an expected call of ListChildChatsByParentID.
func (mr *MockStoreMockRecorder) ListChildChatsByParentID(ctx, parentChatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChildChatsByParentID", reflect.TypeOf((*MockStore)(nil).ListChildChatsByParentID), ctx, parentChatID)
}
// ListProvisionerKeysByOrganization mocks base method.
func (m *MockStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
m.ctrl.T.Helper()
@@ -6281,6 +6784,21 @@ func (mr *MockStoreMockRecorder) Ping(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockStore)(nil).Ping), ctx)
}
// PopNextQueuedMessage mocks base method.
func (m *MockStore) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PopNextQueuedMessage", ctx, chatID)
ret0, _ := ret[0].(database.ChatQueuedMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PopNextQueuedMessage indicates an expected call of PopNextQueuedMessage.
func (mr *MockStoreMockRecorder) PopNextQueuedMessage(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopNextQueuedMessage", reflect.TypeOf((*MockStore)(nil).PopNextQueuedMessage), ctx, chatID)
}
// ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate mocks base method.
func (m *MockStore) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
m.ctrl.T.Helper()
@@ -6411,6 +6929,20 @@ func (mr *MockStoreMockRecorder) UnfavoriteWorkspace(ctx, id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnfavoriteWorkspace", reflect.TypeOf((*MockStore)(nil).UnfavoriteWorkspace), ctx, id)
}
// UnsetDefaultChatModelConfigs mocks base method.
func (m *MockStore) UnsetDefaultChatModelConfigs(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnsetDefaultChatModelConfigs", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// UnsetDefaultChatModelConfigs indicates an expected call of UnsetDefaultChatModelConfigs.
func (mr *MockStoreMockRecorder) UnsetDefaultChatModelConfigs(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsetDefaultChatModelConfigs", reflect.TypeOf((*MockStore)(nil).UnsetDefaultChatModelConfigs), ctx)
}
// UpdateAIBridgeInterceptionEnded mocks base method.
func (m *MockStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
m.ctrl.T.Helper()
@@ -6440,6 +6972,111 @@ func (mr *MockStoreMockRecorder) UpdateAPIKeyByID(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyByID", reflect.TypeOf((*MockStore)(nil).UpdateAPIKeyByID), ctx, arg)
}
// UpdateChatByID mocks base method.
func (m *MockStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatByID", ctx, arg)
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatByID indicates an expected call of UpdateChatByID.
func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
}
// UpdateChatHeartbeat mocks base method.
func (m *MockStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatHeartbeat", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatHeartbeat indicates an expected call of UpdateChatHeartbeat.
func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
}
// UpdateChatMessageByID mocks base method.
func (m *MockStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatMessageByID", ctx, arg)
ret0, _ := ret[0].(database.ChatMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatMessageByID indicates an expected call of UpdateChatMessageByID.
func (mr *MockStoreMockRecorder) UpdateChatMessageByID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatMessageByID", reflect.TypeOf((*MockStore)(nil).UpdateChatMessageByID), ctx, arg)
}
// UpdateChatModelConfig mocks base method.
func (m *MockStore) UpdateChatModelConfig(ctx context.Context, arg database.UpdateChatModelConfigParams) (database.ChatModelConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatModelConfig", ctx, arg)
ret0, _ := ret[0].(database.ChatModelConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatModelConfig indicates an expected call of UpdateChatModelConfig.
func (mr *MockStoreMockRecorder) UpdateChatModelConfig(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatModelConfig", reflect.TypeOf((*MockStore)(nil).UpdateChatModelConfig), ctx, arg)
}
// UpdateChatProvider mocks base method.
func (m *MockStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatProvider", ctx, arg)
ret0, _ := ret[0].(database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatProvider indicates an expected call of UpdateChatProvider.
func (mr *MockStoreMockRecorder) UpdateChatProvider(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatProvider", reflect.TypeOf((*MockStore)(nil).UpdateChatProvider), ctx, arg)
}
// UpdateChatStatus mocks base method.
func (m *MockStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatStatus", ctx, arg)
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatStatus indicates an expected call of UpdateChatStatus.
func (mr *MockStoreMockRecorder) UpdateChatStatus(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatus", reflect.TypeOf((*MockStore)(nil).UpdateChatStatus), ctx, arg)
}
// UpdateChatWorkspace mocks base method.
func (m *MockStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatWorkspace", ctx, arg)
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatWorkspace indicates an expected call of UpdateChatWorkspace.
func (mr *MockStoreMockRecorder) UpdateChatWorkspace(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatWorkspace", reflect.TypeOf((*MockStore)(nil).UpdateChatWorkspace), ctx, arg)
}
// UpdateCryptoKeyDeletesAt mocks base method.
func (m *MockStore) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
m.ctrl.T.Helper()
@@ -7722,6 +8359,36 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
}
// UpsertChatDiffStatus mocks base method.
func (m *MockStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatDiffStatus", ctx, arg)
ret0, _ := ret[0].(database.ChatDiffStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertChatDiffStatus indicates an expected call of UpsertChatDiffStatus.
func (mr *MockStoreMockRecorder) UpsertChatDiffStatus(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDiffStatus", reflect.TypeOf((*MockStore)(nil).UpsertChatDiffStatus), ctx, arg)
}
// UpsertChatDiffStatusReference mocks base method.
func (m *MockStore) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatDiffStatusReference", ctx, arg)
ret0, _ := ret[0].(database.ChatDiffStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertChatDiffStatusReference indicates an expected call of UpsertChatDiffStatusReference.
func (mr *MockStoreMockRecorder) UpsertChatDiffStatusReference(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDiffStatusReference", reflect.TypeOf((*MockStore)(nil).UpsertChatDiffStatusReference), ctx, arg)
}
// UpsertConnectionLog mocks base method.
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
m.ctrl.T.Helper()
+235 -1
View File
@@ -210,7 +210,12 @@ CREATE TYPE api_key_scope AS ENUM (
'boundary_usage:read',
'boundary_usage:update',
'workspace:update_agent',
'workspace_dormant:update_agent'
'workspace_dormant:update_agent',
'chat:create',
'chat:read',
'chat:update',
'chat:delete',
'chat:*'
);
CREATE TYPE app_sharing_level AS ENUM (
@@ -260,6 +265,21 @@ CREATE TYPE build_reason AS ENUM (
'task_resume'
);
CREATE TYPE chat_message_visibility AS ENUM (
'user',
'model',
'both'
);
CREATE TYPE chat_status AS ENUM (
'waiting',
'pending',
'running',
'paused',
'completed',
'error'
);
CREATE TYPE connection_status AS ENUM (
'connected',
'disconnected'
@@ -1144,6 +1164,118 @@ COMMENT ON COLUMN boundary_usage_stats.window_start IS 'Start of the time window
COMMENT ON COLUMN boundary_usage_stats.updated_at IS 'Timestamp of the last update to this row.';
CREATE TABLE chat_diff_statuses (
chat_id uuid NOT NULL,
url text,
pull_request_state text,
changes_requested boolean DEFAULT false NOT NULL,
additions integer DEFAULT 0 NOT NULL,
deletions integer DEFAULT 0 NOT NULL,
changed_files integer DEFAULT 0 NOT NULL,
refreshed_at timestamp with time zone,
stale_at timestamp with time zone DEFAULT now() NOT NULL,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
git_branch text DEFAULT ''::text NOT NULL,
git_remote_origin text DEFAULT ''::text NOT NULL
);
CREATE TABLE chat_messages (
id bigint NOT NULL,
chat_id uuid NOT NULL,
model_config_id uuid,
created_at timestamp with time zone DEFAULT now() NOT NULL,
role text NOT NULL,
content jsonb,
visibility chat_message_visibility DEFAULT 'both'::chat_message_visibility NOT NULL,
input_tokens bigint,
output_tokens bigint,
total_tokens bigint,
reasoning_tokens bigint,
cache_creation_tokens bigint,
cache_read_tokens bigint,
context_limit bigint,
compressed boolean DEFAULT false NOT NULL
);
CREATE SEQUENCE chat_messages_id_seq
START WITH 1
INCREMENT BY 1
NO MINVALUE
NO MAXVALUE
CACHE 1;
ALTER SEQUENCE chat_messages_id_seq OWNED BY chat_messages.id;
CREATE TABLE chat_model_configs (
id uuid DEFAULT gen_random_uuid() NOT NULL,
provider text NOT NULL,
model text NOT NULL,
display_name text DEFAULT ''::text NOT NULL,
created_by uuid,
updated_by uuid,
enabled boolean DEFAULT true NOT NULL,
is_default boolean DEFAULT false NOT NULL,
deleted boolean DEFAULT false NOT NULL,
deleted_at timestamp with time zone,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
context_limit bigint NOT NULL,
compression_threshold integer NOT NULL,
options jsonb DEFAULT '{}'::jsonb NOT NULL,
CONSTRAINT chat_model_configs_compression_threshold_check CHECK (((compression_threshold >= 0) AND (compression_threshold <= 100))),
CONSTRAINT chat_model_configs_context_limit_check CHECK ((context_limit > 0))
);
CREATE TABLE chat_providers (
id uuid DEFAULT gen_random_uuid() NOT NULL,
provider text NOT NULL,
display_name text DEFAULT ''::text NOT NULL,
api_key text DEFAULT ''::text NOT NULL,
api_key_key_id text,
created_by uuid,
enabled boolean DEFAULT true NOT NULL,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
base_url text DEFAULT ''::text NOT NULL,
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text])))
);
COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted';
CREATE TABLE chat_queued_messages (
id bigint NOT NULL,
chat_id uuid NOT NULL,
content jsonb NOT NULL,
created_at timestamp with time zone DEFAULT now() NOT NULL
);
CREATE SEQUENCE chat_queued_messages_id_seq
START WITH 1
INCREMENT BY 1
NO MINVALUE
NO MAXVALUE
CACHE 1;
ALTER SEQUENCE chat_queued_messages_id_seq OWNED BY chat_queued_messages.id;
CREATE TABLE chats (
id uuid DEFAULT gen_random_uuid() NOT NULL,
owner_id uuid NOT NULL,
workspace_id uuid,
workspace_agent_id uuid,
title text DEFAULT 'New Chat'::text NOT NULL,
status chat_status DEFAULT 'waiting'::chat_status NOT NULL,
worker_id uuid,
started_at timestamp with time zone,
heartbeat_at timestamp with time zone,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
parent_chat_id uuid,
root_chat_id uuid,
last_model_config_id uuid NOT NULL
);
CREATE TABLE connection_logs (
id uuid NOT NULL,
connect_time timestamp with time zone NOT NULL,
@@ -2951,6 +3083,10 @@ CREATE VIEW workspaces_expanded AS
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
ALTER TABLE ONLY chat_messages ALTER COLUMN id SET DEFAULT nextval('chat_messages_id_seq'::regclass);
ALTER TABLE ONLY chat_queued_messages ALTER COLUMN id SET DEFAULT nextval('chat_queued_messages_id_seq'::regclass);
ALTER TABLE ONLY licenses ALTER COLUMN id SET DEFAULT nextval('licenses_id_seq'::regclass);
ALTER TABLE ONLY provisioner_job_logs ALTER COLUMN id SET DEFAULT nextval('provisioner_job_logs_id_seq'::regclass);
@@ -2987,6 +3123,27 @@ ALTER TABLE ONLY audit_logs
ALTER TABLE ONLY boundary_usage_stats
ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
ALTER TABLE ONLY chat_diff_statuses
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
ALTER TABLE ONLY chat_messages
ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_model_configs
ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_providers
ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_providers
ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
ALTER TABLE ONLY chat_queued_messages
ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chats
ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
ALTER TABLE ONLY connection_logs
ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id);
@@ -3314,6 +3471,38 @@ CREATE INDEX idx_audit_log_user_id ON audit_logs USING btree (user_id);
CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
CREATE INDEX idx_chat_messages_chat ON chat_messages USING btree (chat_id);
CREATE INDEX idx_chat_messages_chat_created ON chat_messages USING btree (chat_id, created_at);
CREATE INDEX idx_chat_messages_compressed_summary_boundary ON chat_messages USING btree (chat_id, created_at DESC, id DESC) WHERE ((compressed = true) AND (role = 'system'::text) AND (visibility = ANY (ARRAY['model'::chat_message_visibility, 'both'::chat_message_visibility])));
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs USING btree (enabled);
CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs USING btree (provider);
CREATE INDEX idx_chat_model_configs_provider_model ON chat_model_configs USING btree (provider, model);
CREATE UNIQUE INDEX idx_chat_model_configs_single_default ON chat_model_configs USING btree ((1)) WHERE ((is_default = true) AND (deleted = false));
CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_config_id);
CREATE INDEX idx_chats_owner ON chats USING btree (owner_id);
CREATE INDEX idx_chats_parent_chat_id ON chats USING btree (parent_chat_id);
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
CREATE INDEX idx_chats_root_chat_id ON chats USING btree (root_chat_id);
CREATE INDEX idx_chats_workspace ON chats USING btree (workspace_id);
CREATE INDEX idx_connection_logs_connect_time_desc ON connection_logs USING btree (connect_time DESC);
CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name);
@@ -3560,6 +3749,51 @@ ALTER TABLE ONLY aibridge_interceptions
ALTER TABLE ONLY api_keys
ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_diff_statuses
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_messages
ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_messages
ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id);
ALTER TABLE ONLY chat_model_configs
ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
ALTER TABLE ONLY chat_model_configs
ADD CONSTRAINT chat_model_configs_provider_fkey FOREIGN KEY (provider) REFERENCES chat_providers(provider) ON DELETE CASCADE;
ALTER TABLE ONLY chat_model_configs
ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id);
ALTER TABLE ONLY chat_providers
ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY chat_providers
ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
ALTER TABLE ONLY chat_queued_messages
ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ALTER TABLE ONLY chats
ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id);
ALTER TABLE ONLY chats
ADD CONSTRAINT chats_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY chats
ADD CONSTRAINT chats_parent_chat_id_fkey FOREIGN KEY (parent_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
ALTER TABLE ONLY chats
ADD CONSTRAINT chats_root_chat_id_fkey FOREIGN KEY (root_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
ALTER TABLE ONLY chats
ADD CONSTRAINT chats_workspace_agent_id_fkey FOREIGN KEY (workspace_agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
ALTER TABLE ONLY chats
ADD CONSTRAINT chats_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE SET NULL;
ALTER TABLE ONLY connection_logs
ADD CONSTRAINT connection_logs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
+15
View File
@@ -8,6 +8,21 @@ type ForeignKeyConstraint string
const (
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatMessagesModelConfigID ForeignKeyConstraint = "chat_messages_model_config_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id);
ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
ForeignKeyChatModelConfigsProvider ForeignKeyConstraint = "chat_model_configs_provider_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_provider_fkey FOREIGN KEY (provider) REFERENCES chat_providers(provider) ON DELETE CASCADE;
ForeignKeyChatModelConfigsUpdatedBy ForeignKeyConstraint = "chat_model_configs_updated_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id);
ForeignKeyChatProvidersAPIKeyKeyID ForeignKeyConstraint = "chat_providers_api_key_key_id_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyChatProvidersCreatedBy ForeignKeyConstraint = "chat_providers_created_by_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatsLastModelConfigID ForeignKeyConstraint = "chats_last_model_config_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id);
ForeignKeyChatsOwnerID ForeignKeyConstraint = "chats_owner_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyChatsParentChatID ForeignKeyConstraint = "chats_parent_chat_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_parent_chat_id_fkey FOREIGN KEY (parent_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
ForeignKeyChatsRootChatID ForeignKeyConstraint = "chats_root_chat_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_root_chat_id_fkey FOREIGN KEY (root_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
ForeignKeyChatsWorkspaceAgentID ForeignKeyConstraint = "chats_workspace_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_workspace_agent_id_fkey FOREIGN KEY (workspace_agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
ForeignKeyChatsWorkspaceID ForeignKeyConstraint = "chats_workspace_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE SET NULL;
ForeignKeyConnectionLogsOrganizationID ForeignKeyConstraint = "connection_logs_organization_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ForeignKeyConnectionLogsWorkspaceID ForeignKeyConstraint = "connection_logs_workspace_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
ForeignKeyConnectionLogsWorkspaceOwnerID ForeignKeyConstraint = "connection_logs_workspace_owner_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_owner_id_fkey FOREIGN KEY (workspace_owner_id) REFERENCES users(id) ON DELETE CASCADE;
@@ -0,0 +1,8 @@
DROP TABLE IF EXISTS chat_queued_messages;
DROP TABLE IF EXISTS chat_diff_statuses;
DROP TABLE IF EXISTS chat_messages;
DROP TABLE IF EXISTS chats;
DROP TABLE IF EXISTS chat_model_configs;
DROP TABLE IF EXISTS chat_providers;
DROP TYPE IF EXISTS chat_message_visibility;
DROP TYPE IF EXISTS chat_status;
@@ -0,0 +1,167 @@
CREATE TYPE chat_status AS ENUM (
'waiting',
'pending',
'running',
'paused',
'completed',
'error'
);
CREATE TYPE chat_message_visibility AS ENUM (
'user',
'model',
'both'
);
CREATE TABLE chats (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
workspace_id UUID REFERENCES workspaces(id) ON DELETE SET NULL,
workspace_agent_id UUID REFERENCES workspace_agents(id) ON DELETE SET NULL,
title TEXT NOT NULL DEFAULT 'New Chat',
status chat_status NOT NULL DEFAULT 'waiting',
worker_id UUID,
started_at TIMESTAMPTZ,
heartbeat_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
parent_chat_id UUID REFERENCES chats(id) ON DELETE SET NULL,
root_chat_id UUID REFERENCES chats(id) ON DELETE SET NULL,
last_model_config_id UUID NOT NULL
);
CREATE INDEX idx_chats_owner ON chats(owner_id);
CREATE INDEX idx_chats_workspace ON chats(workspace_id);
CREATE INDEX idx_chats_pending ON chats(status) WHERE status = 'pending';
CREATE INDEX idx_chats_parent_chat_id ON chats(parent_chat_id);
CREATE INDEX idx_chats_root_chat_id ON chats(root_chat_id);
CREATE INDEX idx_chats_last_model_config_id ON chats(last_model_config_id);
CREATE TABLE chat_messages (
id BIGSERIAL PRIMARY KEY,
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
model_config_id UUID,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
role TEXT NOT NULL,
content JSONB,
visibility chat_message_visibility NOT NULL DEFAULT 'both',
input_tokens BIGINT,
output_tokens BIGINT,
total_tokens BIGINT,
reasoning_tokens BIGINT,
cache_creation_tokens BIGINT,
cache_read_tokens BIGINT,
context_limit BIGINT,
compressed BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE INDEX idx_chat_messages_chat ON chat_messages(chat_id);
CREATE INDEX idx_chat_messages_chat_created ON chat_messages(chat_id, created_at);
CREATE INDEX idx_chat_messages_compressed_summary_boundary
ON chat_messages(chat_id, created_at DESC, id DESC)
WHERE compressed = TRUE
AND role = 'system'
AND visibility IN ('model', 'both');
CREATE TABLE chat_diff_statuses (
chat_id UUID PRIMARY KEY REFERENCES chats(id) ON DELETE CASCADE,
url TEXT,
pull_request_state TEXT,
changes_requested BOOLEAN NOT NULL DEFAULT FALSE,
additions INTEGER NOT NULL DEFAULT 0,
deletions INTEGER NOT NULL DEFAULT 0,
changed_files INTEGER NOT NULL DEFAULT 0,
refreshed_at TIMESTAMPTZ,
stale_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
git_branch TEXT NOT NULL DEFAULT '',
git_remote_origin TEXT NOT NULL DEFAULT ''
);
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses(stale_at);
CREATE TABLE chat_providers (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
provider TEXT NOT NULL UNIQUE,
display_name TEXT NOT NULL DEFAULT '',
api_key TEXT NOT NULL DEFAULT '',
api_key_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
created_by UUID REFERENCES users(id),
enabled BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
base_url TEXT NOT NULL DEFAULT '',
CONSTRAINT chat_providers_provider_check CHECK (
provider = ANY (
ARRAY[
'anthropic'::text,
'azure'::text,
'bedrock'::text,
'google'::text,
'openai'::text,
'openai-compat'::text,
'openrouter'::text,
'vercel'::text
]
)
)
);
COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted';
CREATE INDEX idx_chat_providers_enabled ON chat_providers(enabled);
CREATE TABLE chat_model_configs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
provider TEXT NOT NULL REFERENCES chat_providers(provider) ON DELETE CASCADE,
model TEXT NOT NULL,
display_name TEXT NOT NULL DEFAULT '',
created_by UUID REFERENCES users(id),
updated_by UUID REFERENCES users(id),
enabled BOOLEAN NOT NULL DEFAULT TRUE,
is_default BOOLEAN NOT NULL DEFAULT FALSE,
deleted BOOLEAN NOT NULL DEFAULT FALSE,
deleted_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
context_limit BIGINT NOT NULL,
compression_threshold INTEGER NOT NULL,
options JSONB NOT NULL DEFAULT '{}'::jsonb,
CONSTRAINT chat_model_configs_context_limit_check
CHECK (context_limit > 0),
CONSTRAINT chat_model_configs_compression_threshold_check
CHECK (compression_threshold >= 0 AND compression_threshold <= 100)
);
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs(enabled);
CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs(provider);
CREATE INDEX idx_chat_model_configs_provider_model
ON chat_model_configs(provider, model);
CREATE UNIQUE INDEX idx_chat_model_configs_single_default
ON chat_model_configs ((1))
WHERE is_default = TRUE
AND deleted = FALSE;
ALTER TABLE chat_messages
ADD CONSTRAINT chat_messages_model_config_id_fkey
FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id);
ALTER TABLE chats
ADD CONSTRAINT chats_last_model_config_id_fkey
FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id);
CREATE TABLE chat_queued_messages (
id BIGSERIAL PRIMARY KEY,
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
content JSONB NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages(chat_id);
ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'chat:create';
ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'chat:read';
ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'chat:update';
ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'chat:delete';
ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'chat:*';
@@ -0,0 +1,114 @@
INSERT INTO chat_providers (
id,
provider,
display_name,
api_key,
api_key_key_id,
enabled,
created_at,
updated_at
) VALUES (
'0a8b2f84-b5a8-4c44-8c9f-e58c44a534a7',
'openai',
'OpenAI',
'',
NULL,
TRUE,
'2024-01-01 00:00:00+00',
'2024-01-01 00:00:00+00'
);
INSERT INTO chat_model_configs (
id,
provider,
model,
display_name,
enabled,
context_limit,
compression_threshold,
created_at,
updated_at
) VALUES (
'9af5f8d5-6a57-4505-8a69-3d6c787b95fd',
'openai',
'gpt-5.2',
'GPT 5.2',
TRUE,
200000,
70,
'2024-01-01 00:00:00+00',
'2024-01-01 00:00:00+00'
);
INSERT INTO chats (
id,
owner_id,
last_model_config_id,
title,
status,
created_at,
updated_at
)
SELECT
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
id,
'9af5f8d5-6a57-4505-8a69-3d6c787b95fd',
'Fixture Chat',
'completed',
'2024-01-01 00:00:00+00',
'2024-01-01 00:00:00+00'
FROM users
ORDER BY created_at, id
LIMIT 1;
INSERT INTO chat_messages (
chat_id,
created_at,
role,
content
) VALUES (
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
'2024-01-01 00:00:00+00',
'assistant',
'{"type":"text","text":"fixture"}'::jsonb
);
INSERT INTO chat_diff_statuses (
chat_id,
url,
pull_request_state,
changes_requested,
additions,
deletions,
changed_files,
refreshed_at,
stale_at,
created_at,
updated_at,
git_branch,
git_remote_origin
) VALUES (
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
'https://example.com/pr/1',
'open',
FALSE,
1,
0,
1,
'2024-01-01 00:00:00+00',
'2024-01-01 00:00:00+00',
'2024-01-01 00:00:00+00',
'2024-01-01 00:00:00+00',
'main',
'origin'
);
INSERT INTO chat_queued_messages (
chat_id,
content,
created_at
) VALUES (
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
'{"type":"text","text":"queued fixture"}'::jsonb,
'2024-01-01 00:00:00+00'
);
+4
View File
@@ -165,6 +165,10 @@ func (t TaskTable) RBACObject() rbac.Object {
InOrg(t.OrganizationID)
}
func (c Chat) RBACObject() rbac.Object {
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String())
}
func (s APIKeyScope) ToRBAC() rbac.ScopeName {
switch s {
case ApiKeyScopeCoderAll:
+237 -1
View File
@@ -219,6 +219,11 @@ const (
ApiKeyScopeBoundaryUsageUpdate APIKeyScope = "boundary_usage:update"
ApiKeyScopeWorkspaceUpdateAgent APIKeyScope = "workspace:update_agent"
ApiKeyScopeWorkspaceDormantUpdateAgent APIKeyScope = "workspace_dormant:update_agent"
ApiKeyScopeChatCreate APIKeyScope = "chat:create"
ApiKeyScopeChatRead APIKeyScope = "chat:read"
ApiKeyScopeChatUpdate APIKeyScope = "chat:update"
ApiKeyScopeChatDelete APIKeyScope = "chat:delete"
ApiKeyScopeChat APIKeyScope = "chat:*"
)
func (e *APIKeyScope) Scan(src interface{}) error {
@@ -457,7 +462,12 @@ func (e APIKeyScope) Valid() bool {
ApiKeyScopeBoundaryUsageRead,
ApiKeyScopeBoundaryUsageUpdate,
ApiKeyScopeWorkspaceUpdateAgent,
ApiKeyScopeWorkspaceDormantUpdateAgent:
ApiKeyScopeWorkspaceDormantUpdateAgent,
ApiKeyScopeChatCreate,
ApiKeyScopeChatRead,
ApiKeyScopeChatUpdate,
ApiKeyScopeChatDelete,
ApiKeyScopeChat:
return true
}
return false
@@ -665,6 +675,11 @@ func AllAPIKeyScopeValues() []APIKeyScope {
ApiKeyScopeBoundaryUsageUpdate,
ApiKeyScopeWorkspaceUpdateAgent,
ApiKeyScopeWorkspaceDormantUpdateAgent,
ApiKeyScopeChatCreate,
ApiKeyScopeChatRead,
ApiKeyScopeChatUpdate,
ApiKeyScopeChatDelete,
ApiKeyScopeChat,
}
}
@@ -1034,6 +1049,137 @@ func AllBuildReasonValues() []BuildReason {
}
}
type ChatMessageVisibility string
const (
ChatMessageVisibilityUser ChatMessageVisibility = "user"
ChatMessageVisibilityModel ChatMessageVisibility = "model"
ChatMessageVisibilityBoth ChatMessageVisibility = "both"
)
func (e *ChatMessageVisibility) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ChatMessageVisibility(s)
case string:
*e = ChatMessageVisibility(s)
default:
return fmt.Errorf("unsupported scan type for ChatMessageVisibility: %T", src)
}
return nil
}
type NullChatMessageVisibility struct {
ChatMessageVisibility ChatMessageVisibility `json:"chat_message_visibility"`
Valid bool `json:"valid"` // Valid is true if ChatMessageVisibility is not NULL
}
// Scan implements the Scanner interface.
func (ns *NullChatMessageVisibility) Scan(value interface{}) error {
if value == nil {
ns.ChatMessageVisibility, ns.Valid = "", false
return nil
}
ns.Valid = true
return ns.ChatMessageVisibility.Scan(value)
}
// Value implements the driver Valuer interface.
func (ns NullChatMessageVisibility) Value() (driver.Value, error) {
if !ns.Valid {
return nil, nil
}
return string(ns.ChatMessageVisibility), nil
}
func (e ChatMessageVisibility) Valid() bool {
switch e {
case ChatMessageVisibilityUser,
ChatMessageVisibilityModel,
ChatMessageVisibilityBoth:
return true
}
return false
}
func AllChatMessageVisibilityValues() []ChatMessageVisibility {
return []ChatMessageVisibility{
ChatMessageVisibilityUser,
ChatMessageVisibilityModel,
ChatMessageVisibilityBoth,
}
}
type ChatStatus string
const (
ChatStatusWaiting ChatStatus = "waiting"
ChatStatusPending ChatStatus = "pending"
ChatStatusRunning ChatStatus = "running"
ChatStatusPaused ChatStatus = "paused"
ChatStatusCompleted ChatStatus = "completed"
ChatStatusError ChatStatus = "error"
)
func (e *ChatStatus) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ChatStatus(s)
case string:
*e = ChatStatus(s)
default:
return fmt.Errorf("unsupported scan type for ChatStatus: %T", src)
}
return nil
}
type NullChatStatus struct {
ChatStatus ChatStatus `json:"chat_status"`
Valid bool `json:"valid"` // Valid is true if ChatStatus is not NULL
}
// Scan implements the Scanner interface.
func (ns *NullChatStatus) Scan(value interface{}) error {
if value == nil {
ns.ChatStatus, ns.Valid = "", false
return nil
}
ns.Valid = true
return ns.ChatStatus.Scan(value)
}
// Value implements the driver Valuer interface.
func (ns NullChatStatus) Value() (driver.Value, error) {
if !ns.Valid {
return nil, nil
}
return string(ns.ChatStatus), nil
}
func (e ChatStatus) Valid() bool {
switch e {
case ChatStatusWaiting,
ChatStatusPending,
ChatStatusRunning,
ChatStatusPaused,
ChatStatusCompleted,
ChatStatusError:
return true
}
return false
}
func AllChatStatusValues() []ChatStatus {
return []ChatStatus{
ChatStatusWaiting,
ChatStatusPending,
ChatStatusRunning,
ChatStatusPaused,
ChatStatusCompleted,
ChatStatusError,
}
}
type ConnectionStatus string
const (
@@ -3739,6 +3885,96 @@ type BoundaryUsageStat struct {
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
type Chat struct {
ID uuid.UUID `db:"id" json:"id"`
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
WorkspaceAgentID uuid.NullUUID `db:"workspace_agent_id" json:"workspace_agent_id"`
Title string `db:"title" json:"title"`
Status ChatStatus `db:"status" json:"status"`
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
}
type ChatDiffStatus struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
Url sql.NullString `db:"url" json:"url"`
PullRequestState sql.NullString `db:"pull_request_state" json:"pull_request_state"`
ChangesRequested bool `db:"changes_requested" json:"changes_requested"`
Additions int32 `db:"additions" json:"additions"`
Deletions int32 `db:"deletions" json:"deletions"`
ChangedFiles int32 `db:"changed_files" json:"changed_files"`
RefreshedAt sql.NullTime `db:"refreshed_at" json:"refreshed_at"`
StaleAt time.Time `db:"stale_at" json:"stale_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
GitBranch string `db:"git_branch" json:"git_branch"`
GitRemoteOrigin string `db:"git_remote_origin" json:"git_remote_origin"`
}
type ChatMessage struct {
ID int64 `db:"id" json:"id"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
Role string `db:"role" json:"role"`
Content pqtype.NullRawMessage `db:"content" json:"content"`
Visibility ChatMessageVisibility `db:"visibility" json:"visibility"`
InputTokens sql.NullInt64 `db:"input_tokens" json:"input_tokens"`
OutputTokens sql.NullInt64 `db:"output_tokens" json:"output_tokens"`
TotalTokens sql.NullInt64 `db:"total_tokens" json:"total_tokens"`
ReasoningTokens sql.NullInt64 `db:"reasoning_tokens" json:"reasoning_tokens"`
CacheCreationTokens sql.NullInt64 `db:"cache_creation_tokens" json:"cache_creation_tokens"`
CacheReadTokens sql.NullInt64 `db:"cache_read_tokens" json:"cache_read_tokens"`
ContextLimit sql.NullInt64 `db:"context_limit" json:"context_limit"`
Compressed bool `db:"compressed" json:"compressed"`
}
type ChatModelConfig struct {
ID uuid.UUID `db:"id" json:"id"`
Provider string `db:"provider" json:"provider"`
Model string `db:"model" json:"model"`
DisplayName string `db:"display_name" json:"display_name"`
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
UpdatedBy uuid.NullUUID `db:"updated_by" json:"updated_by"`
Enabled bool `db:"enabled" json:"enabled"`
IsDefault bool `db:"is_default" json:"is_default"`
Deleted bool `db:"deleted" json:"deleted"`
DeletedAt sql.NullTime `db:"deleted_at" json:"deleted_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
ContextLimit int64 `db:"context_limit" json:"context_limit"`
CompressionThreshold int32 `db:"compression_threshold" json:"compression_threshold"`
Options json.RawMessage `db:"options" json:"options"`
}
type ChatProvider struct {
ID uuid.UUID `db:"id" json:"id"`
Provider string `db:"provider" json:"provider"`
DisplayName string `db:"display_name" json:"display_name"`
APIKey string `db:"api_key" json:"api_key"`
// The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
Enabled bool `db:"enabled" json:"enabled"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
BaseUrl string `db:"base_url" json:"base_url"`
}
type ChatQueuedMessage struct {
ID int64 `db:"id" json:"id"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
Content json.RawMessage `db:"content" json:"content"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
}
type ConnectionLog struct {
ID uuid.UUID `db:"id" json:"id"`
ConnectTime time.Time `db:"connect_time" json:"connect_time"`
+51
View File
@@ -12,6 +12,9 @@ import (
)
type sqlcQuerier interface {
// Acquires a pending chat for processing. Uses SKIP LOCKED to prevent
// multiple replicas from acquiring the same chat.
AcquireChat(ctx context.Context, arg AcquireChatParams) (Chat, error)
// Blocks until the lock is acquired.
//
// This must be called from within a transaction. The lock will be automatically
@@ -81,6 +84,7 @@ type sqlcQuerier interface {
CustomRoles(ctx context.Context, arg CustomRolesParams) ([]CustomRole, error)
DeleteAPIKeyByID(ctx context.Context, id string) error
DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error
DeleteAllTailnetTunnels(ctx context.Context, arg DeleteAllTailnetTunnelsParams) error
// Deletes all existing webpush subscriptions.
// This should be called when the VAPID keypair is regenerated, as the old
@@ -88,6 +92,12 @@ type sqlcQuerier interface {
// be recreated.
DeleteAllWebpushSubscriptions(ctx context.Context) error
DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
DeleteChatByID(ctx context.Context, id uuid.UUID) error
DeleteChatMessagesAfterID(ctx context.Context, arg DeleteChatMessagesAfterIDParams) error
DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error
DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error
DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error
DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error
DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error)
DeleteCustomRole(ctx context.Context, arg DeleteCustomRoleParams) error
DeleteExpiredAPIKeys(ctx context.Context, arg DeleteExpiredAPIKeysParams) (int64, error)
@@ -199,6 +209,21 @@ type sqlcQuerier interface {
// This function returns roles for authorization purposes. Implied member roles
// are included.
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error)
GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error)
GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error)
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
GetChatModelConfigByProviderAndModel(ctx context.Context, arg GetChatModelConfigByProviderAndModelParams) (ChatModelConfig, error)
GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error)
GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error)
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]Chat, error)
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
@@ -206,6 +231,7 @@ type sqlcQuerier interface {
GetCryptoKeysByFeature(ctx context.Context, feature CryptoKeyFeature) ([]CryptoKey, error)
GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error)
GetDERPMeshKey(ctx context.Context) (string, error)
GetDefaultChatModelConfig(ctx context.Context) (ChatModelConfig, error)
GetDefaultOrganization(ctx context.Context) (Organization, error)
GetDefaultProxyConfig(ctx context.Context) (GetDefaultProxyConfigRow, error)
GetDeploymentDAUs(ctx context.Context, tzOffset int32) ([]GetDeploymentDAUsRow, error)
@@ -214,6 +240,8 @@ type sqlcQuerier interface {
GetDeploymentWorkspaceAgentUsageStats(ctx context.Context, createdAt time.Time) (GetDeploymentWorkspaceAgentUsageStatsRow, error)
GetDeploymentWorkspaceStats(ctx context.Context) (GetDeploymentWorkspaceStatsRow, error)
GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx context.Context, provisionerJobIds []uuid.UUID) ([]GetEligibleProvisionerDaemonsByProvisionerJobIDsRow, error)
GetEnabledChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error)
GetExternalAuthLink(ctx context.Context, arg GetExternalAuthLinkParams) (ExternalAuthLink, error)
GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error)
GetFailedWorkspaceBuildsByTemplateID(ctx context.Context, arg GetFailedWorkspaceBuildsByTemplateIDParams) ([]GetFailedWorkspaceBuildsByTemplateIDRow, error)
@@ -352,6 +380,9 @@ type sqlcQuerier interface {
GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error)
GetRunningPrebuiltWorkspaces(ctx context.Context) ([]GetRunningPrebuiltWorkspacesRow, error)
GetRuntimeConfig(ctx context.Context, key string) (string, error)
// Find chats that appear stuck (running but heartbeat has expired).
// Used for recovery after coderd crashes or long hangs.
GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error)
GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerBindingsRow, error)
GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerIDsRow, error)
@@ -574,6 +605,11 @@ type sqlcQuerier interface {
// every member of the org.
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error)
InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error)
InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error)
InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error)
InsertChatQueuedMessage(ctx context.Context, arg InsertChatQueuedMessageParams) (ChatQueuedMessage, error)
InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error)
InsertCustomRole(ctx context.Context, arg InsertCustomRoleParams) (CustomRole, error)
InsertDBCryptKey(ctx context.Context, arg InsertDBCryptKeyParams) error
@@ -657,6 +693,8 @@ type sqlcQuerier interface {
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]Chat, error)
ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]Chat, error)
ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
@@ -673,6 +711,7 @@ type sqlcQuerier interface {
// - Use both to get a specific org member row
OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error)
PaginatedOrganizationMembers(ctx context.Context, arg PaginatedOrganizationMembersParams) ([]PaginatedOrganizationMembersRow, error)
PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (ChatQueuedMessage, error)
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error
@@ -691,8 +730,18 @@ type sqlcQuerier interface {
// This will always work regardless of the current state of the template version.
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error
UnsetDefaultChatModelConfigs(ctx context.Context) error
UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
// Bumps the heartbeat timestamp for a running chat so that other
// replicas know the worker is still alive.
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
UpdateChatMessageByID(ctx context.Context, arg UpdateChatMessageByIDParams) (ChatMessage, error)
UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error)
UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error)
UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error)
UpdateChatWorkspace(ctx context.Context, arg UpdateChatWorkspaceParams) (Chat, error)
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error)
@@ -790,6 +839,8 @@ type sqlcQuerier interface {
// cumulative values for unique counts (accurate period totals). Request counts
// are always deltas, accumulated in DB. Returns true if insert, false if update.
UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBoundaryUsageStatsParams) (bool, error)
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error)
UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error
// The default proxy is implied and not actually stored in the database.
+41
View File
@@ -7548,6 +7548,47 @@ func TestGetTaskByWorkspaceID(t *testing.T) {
}
}
func TestDeleteTaskDeletesTaskSnapshot(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
org := dbgen.Organization(t, db, database.Organization{})
user := dbgen.User(t, db, database.User{})
template := dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true},
OrganizationID: org.ID,
CreatedBy: user.ID,
})
task := dbgen.Task(t, db, database.TaskTable{
OrganizationID: org.ID,
OwnerID: user.ID,
TemplateVersionID: templateVersion.ID,
Prompt: "Test prompt",
})
err := db.UpsertTaskSnapshot(ctx, database.UpsertTaskSnapshotParams{
TaskID: task.ID,
LogSnapshot: json.RawMessage(`{"messages":[]}`),
LogSnapshotCreatedAt: dbtime.Now(),
})
require.NoError(t, err)
_, err = db.DeleteTask(ctx, database.DeleteTaskParams{
ID: task.ID,
DeletedAt: dbtime.Now(),
})
require.NoError(t, err)
_, err = db.GetTaskSnapshot(ctx, task.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
}
func TestTaskNameUniqueness(t *testing.T) {
t.Parallel()
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,129 @@
-- name: GetChatModelConfigByID :one
SELECT
*
FROM
chat_model_configs
WHERE
id = @id::uuid
AND deleted = FALSE;
-- name: GetDefaultChatModelConfig :one
SELECT
*
FROM
chat_model_configs
WHERE
is_default = TRUE
AND deleted = FALSE;
-- name: GetChatModelConfigByProviderAndModel :one
SELECT
*
FROM
chat_model_configs
WHERE
provider = @provider::text
AND model = @model::text
AND deleted = FALSE
ORDER BY
updated_at DESC,
created_at DESC,
id DESC
LIMIT 1;
-- name: GetChatModelConfigs :many
SELECT
*
FROM
chat_model_configs
WHERE
deleted = FALSE
ORDER BY
provider ASC,
model ASC,
updated_at DESC,
id DESC;
-- name: GetEnabledChatModelConfigs :many
SELECT
cmc.*
FROM
chat_model_configs cmc
JOIN
chat_providers cp ON cp.provider = cmc.provider
WHERE
cmc.enabled = TRUE
AND cmc.deleted = FALSE
AND cp.enabled = TRUE
ORDER BY
cmc.provider ASC,
cmc.model ASC,
cmc.updated_at DESC,
cmc.id DESC;
-- name: InsertChatModelConfig :one
INSERT INTO chat_model_configs (
provider,
model,
display_name,
created_by,
updated_by,
enabled,
is_default,
context_limit,
compression_threshold,
options
) VALUES (
@provider::text,
@model::text,
@display_name::text,
sqlc.narg('created_by')::uuid,
sqlc.narg('updated_by')::uuid,
@enabled::boolean,
@is_default::boolean,
@context_limit::bigint,
@compression_threshold::integer,
@options::jsonb
)
RETURNING
*;
-- name: UpdateChatModelConfig :one
UPDATE
chat_model_configs
SET
provider = @provider::text,
model = @model::text,
display_name = @display_name::text,
updated_by = sqlc.narg('updated_by')::uuid,
enabled = @enabled::boolean,
is_default = @is_default::boolean,
context_limit = @context_limit::bigint,
compression_threshold = @compression_threshold::integer,
options = @options::jsonb,
updated_at = NOW()
WHERE
id = @id::uuid
AND deleted = FALSE
RETURNING
*;
-- name: UnsetDefaultChatModelConfigs :exec
UPDATE
chat_model_configs
SET
is_default = FALSE,
updated_at = NOW()
WHERE
is_default = TRUE
AND deleted = FALSE;
-- name: DeleteChatModelConfigByID :exec
UPDATE
chat_model_configs
SET
deleted = TRUE,
deleted_at = NOW(),
updated_at = NOW()
WHERE
id = @id::uuid;
+75
View File
@@ -0,0 +1,75 @@
-- name: GetChatProviderByID :one
SELECT
*
FROM
chat_providers
WHERE
id = @id::uuid;
-- name: GetChatProviderByProvider :one
SELECT
*
FROM
chat_providers
WHERE
provider = @provider::text;
-- name: GetChatProviders :many
SELECT
*
FROM
chat_providers
ORDER BY
provider ASC;
-- name: GetEnabledChatProviders :many
SELECT
*
FROM
chat_providers
WHERE
enabled = TRUE
ORDER BY
provider ASC;
-- name: InsertChatProvider :one
INSERT INTO chat_providers (
provider,
display_name,
api_key,
base_url,
api_key_key_id,
created_by,
enabled
) VALUES (
@provider::text,
@display_name::text,
@api_key::text,
@base_url::text,
sqlc.narg('api_key_key_id')::text,
sqlc.narg('created_by')::uuid,
@enabled::boolean
)
RETURNING
*;
-- name: UpdateChatProvider :one
UPDATE
chat_providers
SET
display_name = @display_name::text,
api_key = @api_key::text,
base_url = @base_url::text,
api_key_key_id = sqlc.narg('api_key_key_id')::text,
enabled = @enabled::boolean,
updated_at = NOW()
WHERE
id = @id::uuid
RETURNING
*;
-- name: DeleteChatProviderByID :exec
DELETE FROM
chat_providers
WHERE
id = @id::uuid;
+409
View File
@@ -0,0 +1,409 @@
-- name: DeleteChatByID :exec
DELETE FROM
chats
WHERE
id = @id::uuid;
-- name: DeleteChatMessagesByChatID :exec
DELETE FROM
chat_messages
WHERE
chat_id = @chat_id::uuid;
-- name: DeleteChatMessagesAfterID :exec
DELETE FROM
chat_messages
WHERE
chat_id = @chat_id::uuid
AND id > @after_id::bigint;
-- name: GetChatByID :one
SELECT
*
FROM
chats
WHERE
id = @id::uuid;
-- name: GetChatMessageByID :one
SELECT
*
FROM
chat_messages
WHERE
id = @id::bigint;
-- name: GetChatMessagesByChatID :many
SELECT
*
FROM
chat_messages
WHERE
chat_id = @chat_id::uuid
AND visibility IN ('user', 'both')
ORDER BY
created_at ASC;
-- name: GetChatMessagesForPromptByChatID :many
WITH latest_compressed_summary AS (
SELECT
id
FROM
chat_messages
WHERE
chat_id = @chat_id::uuid
AND role = 'system'
AND visibility IN ('model', 'both')
AND compressed = TRUE
ORDER BY
created_at DESC,
id DESC
LIMIT
1
)
SELECT
*
FROM
chat_messages
WHERE
chat_id = @chat_id::uuid
AND visibility IN ('model', 'both')
AND (
(
role = 'system'
AND compressed = FALSE
)
OR (
compressed = FALSE
AND (
NOT EXISTS (
SELECT
1
FROM
latest_compressed_summary
)
OR id > (
SELECT
id
FROM
latest_compressed_summary
)
)
)
OR id = (
SELECT
id
FROM
latest_compressed_summary
)
)
ORDER BY
created_at ASC,
id ASC;
-- name: GetChatsByOwnerID :many
SELECT
*
FROM
chats
WHERE
owner_id = @owner_id::uuid
ORDER BY
updated_at DESC;
-- name: ListChildChatsByParentID :many
SELECT
*
FROM
chats
WHERE
parent_chat_id = @parent_chat_id::uuid
ORDER BY
created_at ASC;
-- name: ListChatsByRootID :many
SELECT
*
FROM
chats
WHERE
root_chat_id = @root_chat_id::uuid
ORDER BY
created_at ASC;
-- name: InsertChat :one
INSERT INTO chats (
owner_id,
workspace_id,
workspace_agent_id,
parent_chat_id,
root_chat_id,
last_model_config_id,
title
) VALUES (
@owner_id::uuid,
sqlc.narg('workspace_id')::uuid,
sqlc.narg('workspace_agent_id')::uuid,
sqlc.narg('parent_chat_id')::uuid,
sqlc.narg('root_chat_id')::uuid,
@last_model_config_id::uuid,
@title::text
)
RETURNING
*;
-- name: InsertChatMessage :one
WITH updated_chat AS (
UPDATE
chats
SET
last_model_config_id = sqlc.narg('model_config_id')::uuid
WHERE
id = @chat_id::uuid
AND sqlc.narg('model_config_id')::uuid IS NOT NULL
)
INSERT INTO chat_messages (
chat_id,
model_config_id,
role,
content,
visibility,
input_tokens,
output_tokens,
total_tokens,
reasoning_tokens,
cache_creation_tokens,
cache_read_tokens,
context_limit,
compressed
) VALUES (
@chat_id::uuid,
sqlc.narg('model_config_id')::uuid,
@role::text,
sqlc.narg('content')::jsonb,
@visibility::chat_message_visibility,
sqlc.narg('input_tokens')::bigint,
sqlc.narg('output_tokens')::bigint,
sqlc.narg('total_tokens')::bigint,
sqlc.narg('reasoning_tokens')::bigint,
sqlc.narg('cache_creation_tokens')::bigint,
sqlc.narg('cache_read_tokens')::bigint,
sqlc.narg('context_limit')::bigint,
COALESCE(sqlc.narg('compressed')::boolean, FALSE)
)
RETURNING
*;
-- name: UpdateChatMessageByID :one
UPDATE
chat_messages
SET
model_config_id = COALESCE(sqlc.narg('model_config_id')::uuid, model_config_id),
content = sqlc.narg('content')::jsonb
WHERE
id = @id::bigint
RETURNING
*;
-- name: UpdateChatByID :one
UPDATE
chats
SET
title = @title::text,
updated_at = NOW()
WHERE
id = @id::uuid
RETURNING
*;
-- name: UpdateChatWorkspace :one
UPDATE
chats
SET
workspace_id = sqlc.narg('workspace_id')::uuid,
workspace_agent_id = sqlc.narg('workspace_agent_id')::uuid,
updated_at = NOW()
WHERE
id = @id::uuid
RETURNING
*;
-- name: AcquireChat :one
-- Acquires a pending chat for processing. Uses SKIP LOCKED to prevent
-- multiple replicas from acquiring the same chat.
UPDATE
chats
SET
status = 'running'::chat_status,
started_at = @started_at::timestamptz,
heartbeat_at = @started_at::timestamptz,
updated_at = @started_at::timestamptz,
worker_id = @worker_id::uuid
WHERE
id = (
SELECT
id
FROM
chats
WHERE
status = 'pending'::chat_status
ORDER BY
updated_at ASC
FOR UPDATE
SKIP LOCKED
LIMIT
1
)
RETURNING
*;
-- name: UpdateChatStatus :one
UPDATE
chats
SET
status = @status::chat_status,
worker_id = sqlc.narg('worker_id')::uuid,
started_at = sqlc.narg('started_at')::timestamptz,
heartbeat_at = sqlc.narg('heartbeat_at')::timestamptz,
updated_at = NOW()
WHERE
id = @id::uuid
RETURNING
*;
-- name: GetStaleChats :many
-- Find chats that appear stuck (running but heartbeat has expired).
-- Used for recovery after coderd crashes or long hangs.
SELECT
*
FROM
chats
WHERE
status = 'running'::chat_status
AND heartbeat_at < @stale_threshold::timestamptz;
-- name: UpdateChatHeartbeat :execrows
-- Bumps the heartbeat timestamp for a running chat so that other
-- replicas know the worker is still alive.
UPDATE
chats
SET
heartbeat_at = NOW()
WHERE
id = @id::uuid
AND worker_id = @worker_id::uuid
AND status = 'running'::chat_status;
-- name: GetChatDiffStatusByChatID :one
SELECT
*
FROM
chat_diff_statuses
WHERE
chat_id = @chat_id::uuid;
-- name: GetChatDiffStatusesByChatIDs :many
SELECT
*
FROM
chat_diff_statuses
WHERE
chat_id = ANY(@chat_ids::uuid[]);
-- name: UpsertChatDiffStatusReference :one
INSERT INTO chat_diff_statuses (
chat_id,
url,
git_branch,
git_remote_origin,
stale_at
) VALUES (
@chat_id::uuid,
sqlc.narg('url')::text,
@git_branch::text,
@git_remote_origin::text,
@stale_at::timestamptz
)
ON CONFLICT (chat_id) DO UPDATE
SET
url = CASE
WHEN EXCLUDED.url IS NOT NULL THEN EXCLUDED.url
ELSE chat_diff_statuses.url
END,
git_branch = CASE
WHEN EXCLUDED.git_branch != '' THEN EXCLUDED.git_branch
ELSE chat_diff_statuses.git_branch
END,
git_remote_origin = CASE
WHEN EXCLUDED.git_remote_origin != '' THEN EXCLUDED.git_remote_origin
ELSE chat_diff_statuses.git_remote_origin
END,
stale_at = EXCLUDED.stale_at,
updated_at = NOW()
RETURNING
*;
-- name: UpsertChatDiffStatus :one
INSERT INTO chat_diff_statuses (
chat_id,
url,
pull_request_state,
changes_requested,
additions,
deletions,
changed_files,
refreshed_at,
stale_at
) VALUES (
@chat_id::uuid,
sqlc.narg('url')::text,
sqlc.narg('pull_request_state')::text,
@changes_requested::boolean,
@additions::integer,
@deletions::integer,
@changed_files::integer,
@refreshed_at::timestamptz,
@stale_at::timestamptz
)
ON CONFLICT (chat_id) DO UPDATE
SET
url = EXCLUDED.url,
pull_request_state = EXCLUDED.pull_request_state,
changes_requested = EXCLUDED.changes_requested,
additions = EXCLUDED.additions,
deletions = EXCLUDED.deletions,
changed_files = EXCLUDED.changed_files,
refreshed_at = EXCLUDED.refreshed_at,
stale_at = EXCLUDED.stale_at,
updated_at = NOW()
RETURNING
*;
-- name: InsertChatQueuedMessage :one
INSERT INTO chat_queued_messages (chat_id, content)
VALUES (@chat_id, @content)
RETURNING *;
-- name: GetChatQueuedMessages :many
SELECT * FROM chat_queued_messages
WHERE chat_id = @chat_id
ORDER BY id ASC;
-- name: DeleteChatQueuedMessage :exec
DELETE FROM chat_queued_messages WHERE id = @id AND chat_id = @chat_id;
-- name: DeleteAllChatQueuedMessages :exec
DELETE FROM chat_queued_messages WHERE chat_id = @chat_id;
-- name: PopNextQueuedMessage :one
DELETE FROM chat_queued_messages
WHERE id = (
SELECT cqm.id FROM chat_queued_messages cqm
WHERE cqm.chat_id = @chat_id
ORDER BY cqm.id ASC
LIMIT 1
)
RETURNING *;
-- name: GetChatByIDForUpdate :one
SELECT * FROM chats WHERE id = @id::uuid FOR UPDATE;
+1 -1
View File
@@ -399,7 +399,7 @@ WHERE
filtered_workspaces fw
ORDER BY
-- To ensure that 'favorite' workspaces show up first in the list only for their owner.
CASE WHEN owner_id = @requester_id AND favorite THEN 0 ELSE 1 END ASC,
CASE WHEN favorite AND owner_username = (SELECT users.username FROM users WHERE users.id = @requester_id) THEN 0 ELSE 1 END ASC,
(latest_build_completed_at IS NOT NULL AND
latest_build_canceled_at IS NULL AND
latest_build_error IS NULL AND
+8
View File
@@ -14,6 +14,13 @@ const (
UniqueAPIKeysPkey UniqueConstraint = "api_keys_pkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id);
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
UniqueChatProvidersProviderKey UniqueConstraint = "chat_providers_provider_key" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
UniqueChatQueuedMessagesPkey UniqueConstraint = "chat_queued_messages_pkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
UniqueChatsPkey UniqueConstraint = "chats_pkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
UniqueConnectionLogsPkey UniqueConstraint = "connection_logs_pkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id);
UniqueCryptoKeysPkey UniqueConstraint = "crypto_keys_pkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence);
UniqueCustomRolesUniqueKey UniqueConstraint = "custom_roles_unique_key" // ALTER TABLE ONLY custom_roles ADD CONSTRAINT custom_roles_unique_key UNIQUE (name, organization_id);
@@ -110,6 +117,7 @@ const (
UniqueWorkspaceResourcesPkey UniqueConstraint = "workspace_resources_pkey" // ALTER TABLE ONLY workspace_resources ADD CONSTRAINT workspace_resources_pkey PRIMARY KEY (id);
UniqueWorkspacesPkey UniqueConstraint = "workspaces_pkey" // ALTER TABLE ONLY workspaces ADD CONSTRAINT workspaces_pkey PRIMARY KEY (id);
UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type);
UniqueIndexChatModelConfigsSingleDefault UniqueConstraint = "idx_chat_model_configs_single_default" // CREATE UNIQUE INDEX idx_chat_model_configs_single_default ON chat_model_configs USING btree ((1)) WHERE ((is_default = true) AND (deleted = false));
UniqueIndexConnectionLogsConnectionIDWorkspaceIDAgentName UniqueConstraint = "idx_connection_logs_connection_id_workspace_id_agent_name" // CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name);
UniqueIndexCustomRolesNameLowerOrganizationID UniqueConstraint = "idx_custom_roles_name_lower_organization_id" // CREATE UNIQUE INDEX idx_custom_roles_name_lower_organization_id ON custom_roles USING btree (lower(name), COALESCE(organization_id, '00000000-0000-0000-0000-000000000000'::uuid));
UniqueIndexOrganizationNameLower UniqueConstraint = "idx_organization_name_lower" // CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name)) WHERE (deleted = false);
+50
View File
@@ -0,0 +1,50 @@
package httpmw
import (
"context"
"net/http"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
)
type chatParamContextKey struct{}
// ChatParam returns the chat from the ExtractChatParam handler.
func ChatParam(r *http.Request) database.Chat {
chat, ok := r.Context().Value(chatParamContextKey{}).(database.Chat)
if !ok {
panic("developer error: chat param middleware not provided")
}
return chat
}
// ExtractChatParam grabs a chat from the "chat" URL parameter.
func ExtractChatParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chatID, parsed := ParseUUIDParam(rw, r, "chat")
if !parsed {
return
}
chat, err := db.GetChatByID(ctx, chatID)
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching chat.",
Detail: err.Error(),
})
return
}
ctx = context.WithValue(ctx, chatParamContextKey{}, chat)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}
+159
View File
@@ -0,0 +1,159 @@
package httpmw_test
import (
"context"
"database/sql"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/codersdk"
)
func TestChatParam(t *testing.T) {
t.Parallel()
setupAuthentication := func(db database.Store) (*http.Request, database.User) {
user := dbgen.User(t, db, database.User{})
_, token := dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
})
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set(codersdk.SessionTokenHeader, token)
ctx := chi.NewRouteContext()
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
return r, user
}
insertChat := func(t *testing.T, db database.Store, ownerID uuid.UUID) database.Chat {
t.Helper()
_, err := db.InsertChatProvider(context.Background(), database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-api-key",
BaseUrl: "https://api.openai.com/v1",
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true},
Enabled: true,
})
require.NoError(t, err)
modelConfig, err := db.InsertChatModelConfig(context.Background(), database.InsertChatModelConfigParams{
Provider: "openai",
Model: "gpt-4o-mini",
DisplayName: "Test model",
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 70,
Options: []byte("{}"),
})
require.NoError(t, err)
chat, err := db.InsertChat(context.Background(), database.InsertChatParams{
OwnerID: ownerID,
WorkspaceID: uuid.NullUUID{},
WorkspaceAgentID: uuid.NullUUID{},
ParentChatID: uuid.NullUUID{},
RootChatID: uuid.NullUUID{},
LastModelConfigID: modelConfig.ID,
Title: "Test chat",
})
require.NoError(t, err)
return chat
}
t.Run("None", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractChatParam(db))
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractChatParam(db))
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
chi.RouteContext(r.Context()).URLParams.Add("chat", uuid.NewString())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusNotFound, res.StatusCode)
})
t.Run("BadUUID", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractChatParam(db))
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
chi.RouteContext(r.Context()).URLParams.Add("chat", "not-a-uuid")
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("Found", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
RedirectToLogin: false,
}),
httpmw.ExtractChatParam(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.ChatParam(r)
rw.WriteHeader(http.StatusOK)
})
r, user := setupAuthentication(db)
chat := insertChat(t, db, user.ID)
chi.RouteContext(r.Context()).URLParams.Add("chat", chat.ID.String())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
}
+12 -1
View File
@@ -15,13 +15,24 @@ type requestIDContextKey struct{}
// RequestID returns the ID of the request.
func RequestID(r *http.Request) uuid.UUID {
rid, ok := r.Context().Value(requestIDContextKey{}).(uuid.UUID)
rid, ok := RequestIDOptional(r)
if !ok {
panic("developer error: request id middleware not provided")
}
return rid
}
// RequestIDOptional returns the request ID when present.
func RequestIDOptional(r *http.Request) (uuid.UUID, bool) {
rid, ok := r.Context().Value(requestIDContextKey{}).(uuid.UUID)
return rid, ok
}
// WithRequestID stores a request ID in the context.
func WithRequestID(ctx context.Context, rid uuid.UUID) context.Context {
return context.WithValue(ctx, requestIDContextKey{}, rid)
}
// AttachRequestID adds a request ID to each HTTP request.
func AttachRequestID(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+15
View File
@@ -1,11 +1,13 @@
package httpmw_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/httpmw"
@@ -31,3 +33,16 @@ func TestRequestID(t *testing.T) {
require.NotEmpty(t, res.Header.Get("X-Coder-Request-ID"))
require.NotEmpty(t, rw.Body.Bytes())
}
func TestRequestIDHelpers(t *testing.T) {
t.Parallel()
requestID := uuid.New()
ctx := httpmw.WithRequestID(context.Background(), requestID)
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
gotRequestID, ok := httpmw.RequestIDOptional(req)
require.True(t, ok)
require.Equal(t, requestID, gotRequestID)
require.Equal(t, requestID, httpmw.RequestID(req))
}
+46
View File
@@ -0,0 +1,46 @@
package pubsub
import (
"context"
"encoding/json"
"fmt"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
)
func ChatEventChannel(ownerID uuid.UUID) string {
return fmt.Sprintf("chat:owner:%s", ownerID)
}
func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error)) func(ctx context.Context, message []byte, err error) {
return func(ctx context.Context, message []byte, err error) {
if err != nil {
cb(ctx, ChatEvent{}, xerrors.Errorf("chat event pubsub: %w", err))
return
}
var payload ChatEvent
if err := json.Unmarshal(message, &payload); err != nil {
cb(ctx, ChatEvent{}, xerrors.Errorf("unmarshal chat event"))
return
}
cb(ctx, payload, err)
}
}
type ChatEvent struct {
Kind ChatEventKind `json:"kind"`
Chat codersdk.Chat `json:"chat"`
}
type ChatEventKind string
const (
ChatEventKindStatusChange ChatEventKind = "status_change"
ChatEventKindTitleChange ChatEventKind = "title_change"
ChatEventKindCreated ChatEventKind = "created"
ChatEventKindDeleted ChatEventKind = "deleted"
)
+37
View File
@@ -0,0 +1,37 @@
package pubsub
import (
"fmt"
"github.com/google/uuid"
)
// ChatStreamNotifyChannel returns the pubsub channel for per-chat
// stream notifications. Subscribers receive lightweight notifications
// and read actual content from the database.
func ChatStreamNotifyChannel(chatID uuid.UUID) string {
return fmt.Sprintf("chat:stream:%s", chatID)
}
// ChatStreamNotifyMessage is the payload published on the per-chat
// stream notification channel. The actual message content is read
// from the database by subscribers.
type ChatStreamNotifyMessage struct {
// AfterMessageID tells subscribers to query messages after this
// ID. Set when a new message is persisted.
AfterMessageID int64 `json:"after_message_id,omitempty"`
// Status is set when the chat status changes. Subscribers use
// this to update clients and to manage relay lifecycle.
Status string `json:"status,omitempty"`
// WorkerID identifies which replica is running the chat. Used
// by enterprise relay to know where to connect.
WorkerID string `json:"worker_id,omitempty"`
// Error is set when a processing error occurs.
Error string `json:"error,omitempty"`
// QueueUpdate is set when the queued messages change.
QueueUpdate bool `json:"queue_update,omitempty"`
}
+11
View File
@@ -72,6 +72,16 @@ var (
Type: "boundary_usage",
}
// ResourceChat
// Valid Actions
// - "ActionCreate" :: create a new chat
// - "ActionDelete" :: delete a chat
// - "ActionRead" :: read chat messages and metadata
// - "ActionUpdate" :: update chat title or settings
ResourceChat = Object{
Type: "chat",
}
// ResourceConnectionLog
// Valid Actions
// - "ActionRead" :: read connection logs
@@ -429,6 +439,7 @@ func AllResources() []Objecter {
ResourceAssignRole,
ResourceAuditLog,
ResourceBoundaryUsage,
ResourceChat,
ResourceConnectionLog,
ResourceCryptoKey,
ResourceDebugInfo,
+10
View File
@@ -77,6 +77,13 @@ var taskActions = map[Action]ActionDefinition{
ActionDelete: "delete task",
}
var chatActions = map[Action]ActionDefinition{
ActionCreate: "create a new chat",
ActionRead: "read chat messages and metadata",
ActionUpdate: "update chat title or settings",
ActionDelete: "delete a chat",
}
// RBACPermissions is indexed by the type
var RBACPermissions = map[string]PermissionDefinition{
// Wildcard is every object, and the action "*" provides all actions.
@@ -103,6 +110,9 @@ var RBACPermissions = map[string]PermissionDefinition{
"task": {
Actions: taskActions,
},
"chat": {
Actions: chatActions,
},
// Dormant workspaces have the same perms as workspaces.
"workspace_dormant": {
Actions: workspaceActions,
+14
View File
@@ -1030,6 +1030,20 @@ func TestRolePermissions(t *testing.T) {
false: {owner, setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
},
},
{
Name: "ChatUsage",
Actions: []policy.Action{policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
Resource: rbac.ResourceChat.WithOwner(currentUser.String()),
AuthorizeMap: map[bool][]hasAuthSubjects{
true: {owner, memberMe},
false: {
orgAdmin, otherOrgAdmin,
orgAuditor, otherOrgAuditor,
templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin,
userAdmin, orgUserAdmin, otherOrgUserAdmin,
},
},
},
}
// We expect every permission to be tested above.
+12
View File
@@ -28,6 +28,10 @@ const (
ScopeBoundaryUsageDelete ScopeName = "boundary_usage:delete"
ScopeBoundaryUsageRead ScopeName = "boundary_usage:read"
ScopeBoundaryUsageUpdate ScopeName = "boundary_usage:update"
ScopeChatCreate ScopeName = "chat:create"
ScopeChatDelete ScopeName = "chat:delete"
ScopeChatRead ScopeName = "chat:read"
ScopeChatUpdate ScopeName = "chat:update"
ScopeConnectionLogRead ScopeName = "connection_log:read"
ScopeConnectionLogUpdate ScopeName = "connection_log:update"
ScopeCryptoKeyCreate ScopeName = "crypto_key:create"
@@ -188,6 +192,10 @@ func (e ScopeName) Valid() bool {
ScopeBoundaryUsageDelete,
ScopeBoundaryUsageRead,
ScopeBoundaryUsageUpdate,
ScopeChatCreate,
ScopeChatDelete,
ScopeChatRead,
ScopeChatUpdate,
ScopeConnectionLogRead,
ScopeConnectionLogUpdate,
ScopeCryptoKeyCreate,
@@ -349,6 +357,10 @@ func AllScopeNameValues() []ScopeName {
ScopeBoundaryUsageDelete,
ScopeBoundaryUsageRead,
ScopeBoundaryUsageUpdate,
ScopeChatCreate,
ScopeChatDelete,
ScopeChatRead,
ScopeChatUpdate,
ScopeConnectionLogRead,
ScopeConnectionLogUpdate,
ScopeCryptoKeyCreate,
+16 -4
View File
@@ -139,7 +139,6 @@ const AgentAPIVersionREST = "1.0"
func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
var req agentsdk.PatchLogs
if !httpapi.Read(ctx, rw, r, &req) {
return
@@ -1832,6 +1831,10 @@ func convertWorkspaceAgentMetadata(db []database.WorkspaceAgentMetadatum) []code
func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query()
gitRef := chatGitRef{
Branch: strings.TrimSpace(query.Get("git_branch")),
RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")),
}
// Either match or configID must be provided!
match := query.Get("match")
if match == "" {
@@ -1854,7 +1857,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
// listen determines if the request will wait for a
// new token to be issued!
listen := r.URL.Query().Has("listen")
listen := query.Has("listen")
var externalAuthConfig *externalauth.Config
for _, extAuth := range api.ExternalAuthConfigs {
@@ -1925,6 +1928,13 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
return
}
// Persist git refs as soon as the agent requests external auth so branch
// context is retained even if the flow requires an out-of-band login.
if gitRef.Branch != "" || gitRef.RemoteOrigin != "" {
//nolint:gocritic // System context required to persist chat git refs.
api.storeChatGitRef(dbauthz.AsSystemRestricted(ctx), workspace.ID, workspace.OwnerID, gitRef)
}
var previousToken *database.ExternalAuthLink
// handleRetrying will attempt to continually check for a new token
// if listen is true. This is useful if an error is encountered in the
@@ -1938,7 +1948,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
return
}
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace)
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, gitRef)
}
// This is the URL that will redirect the user with a state token.
@@ -1996,10 +2006,11 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
})
return
}
api.triggerWorkspaceChatDiffStatusRefresh(workspace, gitRef)
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace) {
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, gitRef chatGitRef) {
// Since we're ticking frequently and this sign-in operation is rare,
// we are OK with polling to avoid the complexity of pubsub.
ticker, done := api.NewTicker(time.Second)
@@ -2069,6 +2080,7 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
})
return
}
api.triggerWorkspaceChatDiffStatusRefresh(workspace, gitRef)
httpapi.Write(ctx, rw, http.StatusOK, resp)
return
}
+15 -8
View File
@@ -404,7 +404,9 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
AvatarURL: member.AvatarURL,
}
w, err := createWorkspace(ctx, aReq, apiKey.UserID, api, owner, req, r, nil)
w, err := createWorkspace(ctx, aReq, apiKey.UserID, api, owner, req, &createWorkspaceOptions{
remoteAddr: r.RemoteAddr,
})
if err != nil {
httperror.WriteResponseError(ctx, rw, err)
return
@@ -500,7 +502,9 @@ func (api *API) postUserWorkspaces(rw http.ResponseWriter, r *http.Request) {
defer commitAudit()
w, err := createWorkspace(ctx, aReq, apiKey.UserID, api, owner, req, r, nil)
w, err := createWorkspace(ctx, aReq, apiKey.UserID, api, owner, req, &createWorkspaceOptions{
remoteAddr: r.RemoteAddr,
})
if err != nil {
httperror.WriteResponseError(ctx, rw, err)
return
@@ -522,6 +526,10 @@ type createWorkspaceOptions struct {
// postCreateInTX is a function that is called within the transaction, after
// the workspace is created but before the workspace build is created.
postCreateInTX func(ctx context.Context, tx database.Store, workspace database.Workspace) error
// remoteAddr is the IP address of the request initiator, used for
// audit logging. HTTP handlers should pass r.RemoteAddr;
// programmatic callers may leave it empty.
remoteAddr string
}
func createWorkspace(
@@ -531,7 +539,6 @@ func createWorkspace(
api *API,
owner workspaceOwner,
req codersdk.CreateWorkspaceRequest,
r *http.Request,
opts *createWorkspaceOptions,
) (codersdk.Workspace, error) {
if opts == nil {
@@ -545,7 +552,7 @@ func createWorkspace(
// This is a premature auth check to avoid doing unnecessary work if the user
// doesn't have permission to create a workspace.
if !api.Authorize(r, policy.ActionCreate,
if !api.HTTPAuth.AuthorizeContext(ctx, policy.ActionCreate,
rbac.ResourceWorkspace.InOrg(template.OrganizationID).WithOwner(owner.ID.String())) {
// If this check fails, return a proper unauthorized error to the user to indicate
// what is going on.
@@ -562,14 +569,14 @@ func createWorkspace(
// Do this upfront to save work. If this fails, the rest of the work
// would be wasted.
if !api.Authorize(r, policy.ActionCreate,
if !api.HTTPAuth.AuthorizeContext(ctx, policy.ActionCreate,
rbac.ResourceWorkspace.InOrg(template.OrganizationID).WithOwner(owner.ID.String())) {
return codersdk.Workspace{}, httperror.ErrResourceNotFound
}
// The user also needs permission to use the template. At this point they have
// read perms, but not necessarily "use". This is also checked in `db.InsertWorkspace`.
// Doing this up front can save some work below if the user doesn't have permission.
if !api.Authorize(r, policy.ActionUse, template) {
if !api.HTTPAuth.AuthorizeContext(ctx, policy.ActionUse, template) {
return codersdk.Workspace{}, httperror.NewResponseError(http.StatusForbidden, codersdk.Response{
Message: fmt.Sprintf("Unauthorized access to use the template %q.", template.Name),
Detail: "Although you are able to view the template, you are unable to create a workspace using it. " +
@@ -801,9 +808,9 @@ func createWorkspace(
db,
api.FileCache,
func(action policy.Action, object rbac.Objecter) bool {
return api.Authorize(r, action, object)
return api.HTTPAuth.AuthorizeContext(ctx, action, object)
},
audit.WorkspaceBuildBaggageFromRequest(r),
audit.WorkspaceBuildBaggage{IP: opts.remoteAddr},
)
return err
}, nil)
+14
View File
@@ -638,6 +638,14 @@ type ExternalAuthRequest struct {
ID string
// Match is an arbitrary string matched against the regex of the provider.
Match string
// GitBranch is the current git branch in the working directory.
// Sent by the agent so the control plane can resolve diffs
// without SSHing into the workspace.
GitBranch string
// GitRemoteOrigin is the remote origin URL of the git repository.
// Sent by the agent so the control plane can resolve diffs
// without SSHing into the workspace.
GitRemoteOrigin string
// Listen indicates that the request should be long-lived and listen for
// a new token to be requested.
Listen bool
@@ -653,6 +661,12 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext
if req.Listen {
q.Set("listen", "true")
}
if req.GitBranch != "" {
q.Set("git_branch", req.GitBranch)
}
if req.GitRemoteOrigin != "" {
q.Set("git_remote_origin", req.GitRemoteOrigin)
}
reqURL := "/api/v2/workspaceagents/me/external-auth?" + q.Encode()
res, err := c.SDK.Request(ctx, http.MethodGet, reqURL, nil)
if err != nil {
+30
View File
@@ -153,3 +153,33 @@ func TestRewriteDERPMap(t *testing.T) {
require.Equal(t, "coconuts.org", node.HostName)
require.Equal(t, 44558, node.DERPPort)
}
func TestExternalAuthRequestQuery(t *testing.T) {
t.Parallel()
t.Run("IncludesGitRefFieldsAndOmitsWorkdir", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/api/v2/workspaceagents/me/external-auth", r.URL.Path)
require.Equal(t, "true", r.URL.Query().Get("listen"))
require.Equal(t, "main", r.URL.Query().Get("git_branch"))
require.Equal(t, "https://github.com/coder/coder.git", r.URL.Query().Get("git_remote_origin"))
require.False(t, r.URL.Query().Has("workdir"))
_, _ = w.Write([]byte(`{"type":"github","access_token":"token"}`))
}))
defer srv.Close()
parsedURL, err := url.Parse(srv.URL)
require.NoError(t, err)
client := agentsdk.New(parsedURL, agentsdk.WithFixedToken("token"))
_, err = client.ExternalAuth(testutil.Context(t, testutil.WaitShort), agentsdk.ExternalAuthRequest{
Match: "github.com",
Listen: true,
GitBranch: "main",
GitRemoteOrigin: "https://github.com/coder/coder.git",
})
require.NoError(t, err)
})
}
+5
View File
@@ -33,6 +33,11 @@ const (
APIKeyScopeBoundaryUsageDelete APIKeyScope = "boundary_usage:delete"
APIKeyScopeBoundaryUsageRead APIKeyScope = "boundary_usage:read"
APIKeyScopeBoundaryUsageUpdate APIKeyScope = "boundary_usage:update"
APIKeyScopeChatAll APIKeyScope = "chat:*"
APIKeyScopeChatCreate APIKeyScope = "chat:create"
APIKeyScopeChatDelete APIKeyScope = "chat:delete"
APIKeyScopeChatRead APIKeyScope = "chat:read"
APIKeyScopeChatUpdate APIKeyScope = "chat:update"
APIKeyScopeCoderAll APIKeyScope = "coder:all"
APIKeyScopeCoderApikeysManageSelf APIKeyScope = "coder:apikeys.manage_self"
APIKeyScopeCoderApplicationConnect APIKeyScope = "coder:application_connect"
+903
View File
@@ -0,0 +1,903 @@
package codersdk
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
// ChatStatus represents the status of a chat.
type ChatStatus string
const (
ChatStatusWaiting ChatStatus = "waiting"
ChatStatusPending ChatStatus = "pending"
ChatStatusRunning ChatStatus = "running"
ChatStatusPaused ChatStatus = "paused"
ChatStatusCompleted ChatStatus = "completed"
ChatStatusError ChatStatus = "error"
)
// Chat represents a chat session with an AI agent.
type Chat struct {
ID uuid.UUID `json:"id" format:"uuid"`
OwnerID uuid.UUID `json:"owner_id" format:"uuid"`
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
WorkspaceAgentID *uuid.UUID `json:"workspace_agent_id,omitempty" format:"uuid"`
ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"`
RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"`
LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"`
Title string `json:"title"`
Status ChatStatus `json:"status"`
DiffStatus *ChatDiffStatus `json:"diff_status,omitempty"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
}
// ChatMessage represents a single message in a chat.
type ChatMessage struct {
ID int64 `json:"id"`
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
Role string `json:"role"`
Content []ChatMessagePart `json:"content,omitempty"`
Usage *ChatMessageUsage `json:"usage,omitempty"`
}
// ChatMessageUsage contains token usage information for a chat message.
type ChatMessageUsage struct {
InputTokens *int64 `json:"input_tokens,omitempty"`
OutputTokens *int64 `json:"output_tokens,omitempty"`
TotalTokens *int64 `json:"total_tokens,omitempty"`
ReasoningTokens *int64 `json:"reasoning_tokens,omitempty"`
CacheCreationTokens *int64 `json:"cache_creation_tokens,omitempty"`
CacheReadTokens *int64 `json:"cache_read_tokens,omitempty"`
ContextLimit *int64 `json:"context_limit,omitempty"`
}
// ChatMessagePartType represents a structured message part type.
type ChatMessagePartType string
const (
ChatMessagePartTypeText ChatMessagePartType = "text"
ChatMessagePartTypeReasoning ChatMessagePartType = "reasoning"
ChatMessagePartTypeToolCall ChatMessagePartType = "tool-call"
ChatMessagePartTypeToolResult ChatMessagePartType = "tool-result"
ChatMessagePartTypeSource ChatMessagePartType = "source"
ChatMessagePartTypeFile ChatMessagePartType = "file"
)
// ChatMessagePart is a structured chunk of a chat message.
type ChatMessagePart struct {
Type ChatMessagePartType `json:"type"`
Text string `json:"text,omitempty"`
Signature string `json:"signature,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
ToolName string `json:"tool_name,omitempty"`
Args json.RawMessage `json:"args,omitempty"`
ArgsDelta string `json:"args_delta,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
ResultDelta string `json:"result_delta,omitempty"`
IsError bool `json:"is_error,omitempty"`
SourceID string `json:"source_id,omitempty"`
URL string `json:"url,omitempty"`
Title string `json:"title,omitempty"`
MediaType string `json:"media_type,omitempty"`
Data []byte `json:"data,omitempty"`
}
// ChatInputPartType represents an input part type for user chat input.
type ChatInputPartType string
const (
ChatInputPartTypeText ChatInputPartType = "text"
)
// ChatInputPart is a single user input part for creating a chat.
type ChatInputPart struct {
Type ChatInputPartType `json:"type"`
Text string `json:"text,omitempty"`
}
// CreateChatRequest is the request to create a new chat.
type CreateChatRequest struct {
Content []ChatInputPart `json:"content"`
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
}
// UpdateChatRequest is the request to update a chat.
type UpdateChatRequest struct {
Title string `json:"title"`
}
// CreateChatMessageRequest is the request to add a message to a chat.
type CreateChatMessageRequest struct {
Content []ChatInputPart `json:"content"`
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
}
// EditChatMessageRequest is the request to edit a user message in a chat.
type EditChatMessageRequest struct {
Content []ChatInputPart `json:"content"`
}
// CreateChatMessageResponse is the response from adding a message to a chat.
type CreateChatMessageResponse struct {
Message *ChatMessage `json:"message,omitempty"`
QueuedMessage *ChatQueuedMessage `json:"queued_message,omitempty"`
Queued bool `json:"queued"`
}
// ChatWithMessages is a chat along with its messages.
type ChatWithMessages struct {
Chat Chat `json:"chat"`
Messages []ChatMessage `json:"messages"`
QueuedMessages []ChatQueuedMessage `json:"queued_messages"`
}
// ChatModelProviderUnavailableReason explains why a provider cannot be used.
type ChatModelProviderUnavailableReason string
const (
ChatModelProviderUnavailableMissingAPIKey ChatModelProviderUnavailableReason = "missing_api_key"
ChatModelProviderUnavailableFetchFailed ChatModelProviderUnavailableReason = "fetch_failed"
)
// ChatModel represents a model in the chat model catalog.
type ChatModel struct {
ID string `json:"id"`
Provider string `json:"provider"`
Model string `json:"model"`
DisplayName string `json:"display_name"`
}
// ChatModelProvider represents provider availability and model results.
type ChatModelProvider struct {
Provider string `json:"provider"`
Available bool `json:"available"`
UnavailableReason ChatModelProviderUnavailableReason `json:"unavailable_reason,omitempty"`
Models []ChatModel `json:"models"`
}
// ChatModelsResponse is the catalog returned from chat model discovery.
type ChatModelsResponse struct {
Providers []ChatModelProvider `json:"providers"`
}
// ChatProviderConfigSource describes how a provider entry is sourced.
type ChatProviderConfigSource string
const (
ChatProviderConfigSourceDatabase ChatProviderConfigSource = "database"
ChatProviderConfigSourceEnvPreset ChatProviderConfigSource = "env_preset"
ChatProviderConfigSourceSupported ChatProviderConfigSource = "supported"
)
// ChatProviderConfig is an admin-managed provider configuration.
type ChatProviderConfig struct {
ID uuid.UUID `json:"id" format:"uuid"`
Provider string `json:"provider"`
DisplayName string `json:"display_name"`
Enabled bool `json:"enabled"`
HasAPIKey bool `json:"has_api_key"`
BaseURL string `json:"base_url,omitempty"`
Source ChatProviderConfigSource `json:"source"`
CreatedAt time.Time `json:"created_at,omitempty" format:"date-time"`
UpdatedAt time.Time `json:"updated_at,omitempty" format:"date-time"`
}
// CreateChatProviderConfigRequest creates a chat provider config.
type CreateChatProviderConfigRequest struct {
Provider string `json:"provider"`
DisplayName string `json:"display_name,omitempty"`
APIKey string `json:"api_key,omitempty"`
BaseURL string `json:"base_url,omitempty"`
Enabled *bool `json:"enabled,omitempty"`
}
// UpdateChatProviderConfigRequest updates a chat provider config.
type UpdateChatProviderConfigRequest struct {
DisplayName string `json:"display_name,omitempty"`
APIKey *string `json:"api_key,omitempty"`
BaseURL *string `json:"base_url,omitempty"`
Enabled *bool `json:"enabled,omitempty"`
}
// ChatModelConfig is an admin-managed model configuration.
type ChatModelConfig struct {
ID uuid.UUID `json:"id" format:"uuid"`
Provider string `json:"provider"`
Model string `json:"model"`
DisplayName string `json:"display_name"`
Enabled bool `json:"enabled"`
IsDefault bool `json:"is_default"`
ContextLimit int64 `json:"context_limit"`
CompressionThreshold int32 `json:"compression_threshold"`
ModelConfig *ChatModelCallConfig `json:"model_config,omitempty"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
}
// ChatModelProviderOptions contains typed provider-specific options.
//
// Note: Azure models use the `openai` options shape.
// Note: Bedrock models use the `anthropic` options shape.
type ChatModelProviderOptions struct {
OpenAI *ChatModelOpenAIProviderOptions `json:"openai,omitempty"`
Anthropic *ChatModelAnthropicProviderOptions `json:"anthropic,omitempty"`
Google *ChatModelGoogleProviderOptions `json:"google,omitempty"`
OpenAICompat *ChatModelOpenAICompatProviderOptions `json:"openaicompat,omitempty"`
OpenRouter *ChatModelOpenRouterProviderOptions `json:"openrouter,omitempty"`
Vercel *ChatModelVercelProviderOptions `json:"vercel,omitempty"`
}
// ChatModelOpenAIProviderOptions configures OpenAI provider behavior.
type ChatModelOpenAIProviderOptions struct {
Include []string `json:"include,omitempty"`
Instructions *string `json:"instructions,omitempty"`
LogitBias map[string]int64 `json:"logit_bias,omitempty"`
LogProbs *bool `json:"log_probs,omitempty"`
TopLogProbs *int64 `json:"top_log_probs,omitempty"`
MaxToolCalls *int64 `json:"max_tool_calls,omitempty"`
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
User *string `json:"user,omitempty"`
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
ReasoningSummary *string `json:"reasoning_summary,omitempty"`
MaxCompletionTokens *int64 `json:"max_completion_tokens,omitempty"`
TextVerbosity *string `json:"text_verbosity,omitempty"`
Prediction map[string]any `json:"prediction,omitempty"`
Store *bool `json:"store,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
PromptCacheKey *string `json:"prompt_cache_key,omitempty"`
SafetyIdentifier *string `json:"safety_identifier,omitempty"`
ServiceTier *string `json:"service_tier,omitempty"`
StructuredOutputs *bool `json:"structured_outputs,omitempty"`
StrictJSONSchema *bool `json:"strict_json_schema,omitempty"`
}
// ChatModelAnthropicThinkingOptions configures Anthropic thinking budget.
type ChatModelAnthropicThinkingOptions struct {
BudgetTokens *int64 `json:"budget_tokens,omitempty"`
}
// ChatModelAnthropicProviderOptions configures Anthropic provider behavior.
type ChatModelAnthropicProviderOptions struct {
SendReasoning *bool `json:"send_reasoning,omitempty"`
Thinking *ChatModelAnthropicThinkingOptions `json:"thinking,omitempty"`
Effort *string `json:"effort,omitempty"`
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
}
// ChatModelGoogleThinkingConfig configures Google thinking behavior.
type ChatModelGoogleThinkingConfig struct {
ThinkingBudget *int64 `json:"thinking_budget,omitempty"`
IncludeThoughts *bool `json:"include_thoughts,omitempty"`
}
// ChatModelGoogleSafetySetting configures Google safety filtering.
type ChatModelGoogleSafetySetting struct {
Category string `json:"category,omitempty"`
Threshold string `json:"threshold,omitempty"`
}
// ChatModelGoogleProviderOptions configures Google provider behavior.
type ChatModelGoogleProviderOptions struct {
ThinkingConfig *ChatModelGoogleThinkingConfig `json:"thinking_config,omitempty"`
CachedContent string `json:"cached_content,omitempty"`
SafetySettings []ChatModelGoogleSafetySetting `json:"safety_settings,omitempty"`
Threshold string `json:"threshold,omitempty"`
}
// ChatModelOpenAICompatProviderOptions configures OpenAI-compatible behavior.
type ChatModelOpenAICompatProviderOptions struct {
User *string `json:"user,omitempty"`
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
}
// ChatModelOpenRouterReasoningOptions configures OpenRouter reasoning behavior.
type ChatModelOpenRouterReasoningOptions struct {
Enabled *bool `json:"enabled,omitempty"`
Exclude *bool `json:"exclude,omitempty"`
MaxTokens *int64 `json:"max_tokens,omitempty"`
Effort *string `json:"effort,omitempty"`
}
// ChatModelOpenRouterProvider configures OpenRouter routing preferences.
type ChatModelOpenRouterProvider struct {
Order []string `json:"order,omitempty"`
AllowFallbacks *bool `json:"allow_fallbacks,omitempty"`
RequireParameters *bool `json:"require_parameters,omitempty"`
DataCollection *string `json:"data_collection,omitempty"`
Only []string `json:"only,omitempty"`
Ignore []string `json:"ignore,omitempty"`
Quantizations []string `json:"quantizations,omitempty"`
Sort *string `json:"sort,omitempty"`
}
// ChatModelOpenRouterProviderOptions configures OpenRouter provider behavior.
type ChatModelOpenRouterProviderOptions struct {
Reasoning *ChatModelOpenRouterReasoningOptions `json:"reasoning,omitempty"`
ExtraBody map[string]any `json:"extra_body,omitempty"`
IncludeUsage *bool `json:"include_usage,omitempty"`
LogitBias map[string]int64 `json:"logit_bias,omitempty"`
LogProbs *bool `json:"log_probs,omitempty"`
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
User *string `json:"user,omitempty"`
Provider *ChatModelOpenRouterProvider `json:"provider,omitempty"`
}
// ChatModelVercelReasoningOptions configures Vercel reasoning behavior.
type ChatModelVercelReasoningOptions struct {
Enabled *bool `json:"enabled,omitempty"`
MaxTokens *int64 `json:"max_tokens,omitempty"`
Effort *string `json:"effort,omitempty"`
Exclude *bool `json:"exclude,omitempty"`
}
// ChatModelVercelGatewayProviderOptions configures Vercel routing behavior.
type ChatModelVercelGatewayProviderOptions struct {
Order []string `json:"order,omitempty"`
Models []string `json:"models,omitempty"`
}
// ChatModelVercelProviderOptions configures Vercel provider behavior.
type ChatModelVercelProviderOptions struct {
Reasoning *ChatModelVercelReasoningOptions `json:"reasoning,omitempty"`
ProviderOptions *ChatModelVercelGatewayProviderOptions `json:"providerOptions,omitempty"`
User *string `json:"user,omitempty"`
LogitBias map[string]int64 `json:"logit_bias,omitempty"`
LogProbs *bool `json:"logprobs,omitempty"`
TopLogProbs *int64 `json:"top_logprobs,omitempty"`
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
ExtraBody map[string]any `json:"extra_body,omitempty"`
}
// ChatModelCallConfig configures per-call model behavior defaults.
type ChatModelCallConfig struct {
MaxOutputTokens *int64 `json:"max_output_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int64 `json:"top_k,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
ProviderOptions *ChatModelProviderOptions `json:"provider_options,omitempty"`
}
// CreateChatModelConfigRequest creates a chat model config.
type CreateChatModelConfigRequest struct {
Provider string `json:"provider"`
Model string `json:"model"`
DisplayName string `json:"display_name,omitempty"`
Enabled *bool `json:"enabled,omitempty"`
IsDefault *bool `json:"is_default,omitempty"`
ContextLimit *int64 `json:"context_limit,omitempty"`
CompressionThreshold *int32 `json:"compression_threshold,omitempty"`
ModelConfig *ChatModelCallConfig `json:"model_config,omitempty"`
}
// UpdateChatModelConfigRequest updates a chat model config.
type UpdateChatModelConfigRequest struct {
Provider string `json:"provider,omitempty"`
Model string `json:"model,omitempty"`
DisplayName string `json:"display_name,omitempty"`
Enabled *bool `json:"enabled,omitempty"`
IsDefault *bool `json:"is_default,omitempty"`
ContextLimit *int64 `json:"context_limit,omitempty"`
CompressionThreshold *int32 `json:"compression_threshold,omitempty"`
ModelConfig *ChatModelCallConfig `json:"model_config,omitempty"`
}
// ChatGitChange represents a git file change detected during a chat session.
type ChatGitChange struct {
ID uuid.UUID `json:"id" format:"uuid"`
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
FilePath string `json:"file_path"`
ChangeType string `json:"change_type"` // added, modified, deleted, renamed
OldPath *string `json:"old_path,omitempty"`
DiffSummary *string `json:"diff_summary,omitempty"`
DetectedAt time.Time `json:"detected_at" format:"date-time"`
}
// ChatDiffStatus represents cached diff status for a chat. The URL
// may point to a pull request or a branch page depending on whether
// a PR has been opened.
type ChatDiffStatus struct {
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
URL *string `json:"url,omitempty"`
PullRequestState *string `json:"pull_request_state,omitempty"`
ChangesRequested bool `json:"changes_requested"`
Additions int32 `json:"additions"`
Deletions int32 `json:"deletions"`
ChangedFiles int32 `json:"changed_files"`
RefreshedAt *time.Time `json:"refreshed_at,omitempty" format:"date-time"`
StaleAt *time.Time `json:"stale_at,omitempty" format:"date-time"`
}
// ChatDiffContents represents the resolved diff text for a chat.
type ChatDiffContents struct {
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
Provider *string `json:"provider,omitempty"`
RemoteOrigin *string `json:"remote_origin,omitempty"`
Branch *string `json:"branch,omitempty"`
PullRequestURL *string `json:"pull_request_url,omitempty"`
Diff string `json:"diff,omitempty"`
}
// ChatStreamEventType represents the kind of chat stream update.
type ChatStreamEventType string
const (
ChatStreamEventTypeMessagePart ChatStreamEventType = "message_part"
ChatStreamEventTypeMessage ChatStreamEventType = "message"
ChatStreamEventTypeStatus ChatStreamEventType = "status"
ChatStreamEventTypeError ChatStreamEventType = "error"
ChatStreamEventTypeQueueUpdate ChatStreamEventType = "queue_update"
)
// ChatQueuedMessage represents a queued message waiting to be processed.
type ChatQueuedMessage struct {
ID int64 `json:"id"`
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
Content []ChatMessagePart `json:"content"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
}
// ChatStreamMessagePart is a streamed message part update.
type ChatStreamMessagePart struct {
Role string `json:"role,omitempty"`
Part ChatMessagePart `json:"part"`
}
// ChatStreamStatus represents an updated chat status.
type ChatStreamStatus struct {
Status ChatStatus `json:"status"`
}
// ChatStreamError represents an error event in the stream.
type ChatStreamError struct {
Message string `json:"message"`
}
// ChatStreamEvent represents a real-time update for chat streaming.
type ChatStreamEvent struct {
Type ChatStreamEventType `json:"type"`
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
Message *ChatMessage `json:"message,omitempty"`
MessagePart *ChatStreamMessagePart `json:"message_part,omitempty"`
Status *ChatStreamStatus `json:"status,omitempty"`
Error *ChatStreamError `json:"error,omitempty"`
QueuedMessages []ChatQueuedMessage `json:"queued_messages,omitempty"`
}
type chatStreamEnvelope struct {
Type ServerSentEventType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
}
// ListChats returns all chats for the authenticated user.
func (c *Client) ListChats(ctx context.Context) ([]Chat, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats", nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, ReadBodyAsError(res)
}
var chats []Chat
return chats, json.NewDecoder(res.Body).Decode(&chats)
}
// ListChatModels returns the available chat model catalog.
func (c *Client) ListChatModels(ctx context.Context) (ChatModelsResponse, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/models", nil)
if err != nil {
return ChatModelsResponse{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ChatModelsResponse{}, ReadBodyAsError(res)
}
var catalog ChatModelsResponse
return catalog, json.NewDecoder(res.Body).Decode(&catalog)
}
// ListChatProviders returns admin-managed chat provider configs.
func (c *Client) ListChatProviders(ctx context.Context) ([]ChatProviderConfig, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/providers", nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, ReadBodyAsError(res)
}
var providers []ChatProviderConfig
return providers, json.NewDecoder(res.Body).Decode(&providers)
}
// CreateChatProvider creates an admin-managed chat provider config.
func (c *Client) CreateChatProvider(ctx context.Context, req CreateChatProviderConfigRequest) (ChatProviderConfig, error) {
res, err := c.Request(ctx, http.MethodPost, "/api/experimental/chats/providers", req)
if err != nil {
return ChatProviderConfig{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return ChatProviderConfig{}, ReadBodyAsError(res)
}
var provider ChatProviderConfig
return provider, json.NewDecoder(res.Body).Decode(&provider)
}
// UpdateChatProvider updates an admin-managed chat provider config.
func (c *Client) UpdateChatProvider(ctx context.Context, providerID uuid.UUID, req UpdateChatProviderConfigRequest) (ChatProviderConfig, error) {
res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/experimental/chats/providers/%s", providerID), req)
if err != nil {
return ChatProviderConfig{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ChatProviderConfig{}, ReadBodyAsError(res)
}
var provider ChatProviderConfig
return provider, json.NewDecoder(res.Body).Decode(&provider)
}
// DeleteChatProvider deletes an admin-managed chat provider config.
func (c *Client) DeleteChatProvider(ctx context.Context, providerID uuid.UUID) error {
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/providers/%s", providerID), nil)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return ReadBodyAsError(res)
}
return nil
}
// ListChatModelConfigs returns admin-managed chat model configs.
func (c *Client) ListChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/model-configs", nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, ReadBodyAsError(res)
}
var configs []ChatModelConfig
return configs, json.NewDecoder(res.Body).Decode(&configs)
}
// CreateChatModelConfig creates an admin-managed chat model config.
func (c *Client) CreateChatModelConfig(ctx context.Context, req CreateChatModelConfigRequest) (ChatModelConfig, error) {
res, err := c.Request(ctx, http.MethodPost, "/api/experimental/chats/model-configs", req)
if err != nil {
return ChatModelConfig{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return ChatModelConfig{}, ReadBodyAsError(res)
}
var config ChatModelConfig
return config, json.NewDecoder(res.Body).Decode(&config)
}
// UpdateChatModelConfig updates an admin-managed chat model config.
func (c *Client) UpdateChatModelConfig(ctx context.Context, modelConfigID uuid.UUID, req UpdateChatModelConfigRequest) (ChatModelConfig, error) {
res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/experimental/chats/model-configs/%s", modelConfigID), req)
if err != nil {
return ChatModelConfig{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ChatModelConfig{}, ReadBodyAsError(res)
}
var config ChatModelConfig
return config, json.NewDecoder(res.Body).Decode(&config)
}
// DeleteChatModelConfig deletes an admin-managed chat model config.
func (c *Client) DeleteChatModelConfig(ctx context.Context, modelConfigID uuid.UUID) error {
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/model-configs/%s", modelConfigID), nil)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return ReadBodyAsError(res)
}
return nil
}
// CreateChat creates a new chat.
func (c *Client) CreateChat(ctx context.Context, req CreateChatRequest) (Chat, error) {
res, err := c.Request(ctx, http.MethodPost, "/api/experimental/chats", req)
if err != nil {
return Chat{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return Chat{}, ReadBodyAsError(res)
}
var chat Chat
return chat, json.NewDecoder(res.Body).Decode(&chat)
}
// StreamChat streams chat updates in real time.
//
// The returned channel includes initial snapshot events first, followed by
// live updates. Callers must close the returned io.Closer to release the
// websocket connection when done.
func (c *Client) StreamChat(ctx context.Context, chatID uuid.UUID) (<-chan ChatStreamEvent, io.Closer, error) {
conn, err := c.Dial(
ctx,
fmt.Sprintf("/api/experimental/chats/%s/stream", chatID),
&websocket.DialOptions{CompressionMode: websocket.CompressionDisabled},
)
if err != nil {
return nil, nil, err
}
conn.SetReadLimit(1 << 22) // 4MiB
streamCtx, streamCancel := context.WithCancel(ctx)
events := make(chan ChatStreamEvent, 128)
send := func(event ChatStreamEvent) bool {
if event.ChatID == uuid.Nil {
event.ChatID = chatID
}
select {
case <-streamCtx.Done():
return false
case events <- event:
return true
}
}
go func() {
defer close(events)
defer streamCancel()
defer func() {
_ = conn.Close(websocket.StatusNormalClosure, "")
}()
for {
var envelope chatStreamEnvelope
if err := wsjson.Read(streamCtx, conn, &envelope); err != nil {
if streamCtx.Err() != nil {
return
}
switch websocket.CloseStatus(err) {
case websocket.StatusNormalClosure, websocket.StatusGoingAway:
return
}
_ = send(ChatStreamEvent{
Type: ChatStreamEventTypeError,
Error: &ChatStreamError{
Message: fmt.Sprintf("read chat stream: %v", err),
},
})
return
}
switch envelope.Type {
case ServerSentEventTypePing:
continue
case ServerSentEventTypeData:
var batch []ChatStreamEvent
decodeErr := json.Unmarshal(envelope.Data, &batch)
if decodeErr == nil {
for _, streamedEvent := range batch {
if !send(streamedEvent) {
return
}
}
continue
}
{
_ = send(ChatStreamEvent{
Type: ChatStreamEventTypeError,
Error: &ChatStreamError{
Message: fmt.Sprintf(
"decode chat stream event batch: %v",
decodeErr,
),
},
})
return
}
case ServerSentEventTypeError:
message := "chat stream returned an error"
if len(envelope.Data) > 0 {
var response Response
if err := json.Unmarshal(envelope.Data, &response); err == nil {
message = formatChatStreamResponseError(response)
} else {
trimmed := strings.TrimSpace(string(envelope.Data))
if trimmed != "" {
message = trimmed
}
}
}
_ = send(ChatStreamEvent{
Type: ChatStreamEventTypeError,
Error: &ChatStreamError{
Message: message,
},
})
return
default:
_ = send(ChatStreamEvent{
Type: ChatStreamEventTypeError,
Error: &ChatStreamError{
Message: fmt.Sprintf("unknown chat stream event type %q", envelope.Type),
},
})
return
}
}
}()
return events, closeFunc(func() error {
streamCancel()
return conn.Close(websocket.StatusNormalClosure, "")
}), nil
}
// GetChat returns a chat by ID, including its messages.
func (c *Client) GetChat(ctx context.Context, chatID uuid.UUID) (ChatWithMessages, error) {
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s", chatID), nil)
if err != nil {
return ChatWithMessages{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ChatWithMessages{}, ReadBodyAsError(res)
}
var chat ChatWithMessages
return chat, json.NewDecoder(res.Body).Decode(&chat)
}
// DeleteChat deletes a chat by ID.
func (c *Client) DeleteChat(ctx context.Context, chatID uuid.UUID) error {
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/%s", chatID), nil)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return ReadBodyAsError(res)
}
return nil
}
// CreateChatMessage adds a message to a chat.
func (c *Client) CreateChatMessage(ctx context.Context, chatID uuid.UUID, req CreateChatMessageRequest) (CreateChatMessageResponse, error) {
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/messages", chatID), req)
if err != nil {
return CreateChatMessageResponse{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return CreateChatMessageResponse{}, ReadBodyAsError(res)
}
var resp CreateChatMessageResponse
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
// EditChatMessage edits an existing user message in a chat and re-runs from there.
func (c *Client) EditChatMessage(
ctx context.Context,
chatID uuid.UUID,
messageID int64,
req EditChatMessageRequest,
) (ChatMessage, error) {
res, err := c.Request(
ctx,
http.MethodPatch,
fmt.Sprintf("/api/experimental/chats/%s/messages/%d", chatID, messageID),
req,
)
if err != nil {
return ChatMessage{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ChatMessage{}, ReadBodyAsError(res)
}
var message ChatMessage
return message, json.NewDecoder(res.Body).Decode(&message)
}
// InterruptChat cancels an in-flight chat run and leaves it waiting.
func (c *Client) InterruptChat(ctx context.Context, chatID uuid.UUID) (Chat, error) {
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/interrupt", chatID), nil)
if err != nil {
return Chat{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return Chat{}, ReadBodyAsError(res)
}
var chat Chat
return chat, json.NewDecoder(res.Body).Decode(&chat)
}
// GetChatGitChanges returns git changes for a chat.
func (c *Client) GetChatGitChanges(ctx context.Context, chatID uuid.UUID) ([]ChatGitChange, error) {
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s/git-changes", chatID), nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, ReadBodyAsError(res)
}
var changes []ChatGitChange
return changes, json.NewDecoder(res.Body).Decode(&changes)
}
// GetChatDiffStatus returns cached GitHub pull request diff status for a chat.
func (c *Client) GetChatDiffStatus(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error) {
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s/diff-status", chatID), nil)
if err != nil {
return ChatDiffStatus{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ChatDiffStatus{}, ReadBodyAsError(res)
}
var status ChatDiffStatus
return status, json.NewDecoder(res.Body).Decode(&status)
}
// GetChatDiffContents returns resolved diff contents for a chat.
func (c *Client) GetChatDiffContents(ctx context.Context, chatID uuid.UUID) (ChatDiffContents, error) {
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s/diff", chatID), nil)
if err != nil {
return ChatDiffContents{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ChatDiffContents{}, ReadBodyAsError(res)
}
var diff ChatDiffContents
return diff, json.NewDecoder(res.Body).Decode(&diff)
}
func formatChatStreamResponseError(response Response) string {
message := strings.TrimSpace(response.Message)
detail := strings.TrimSpace(response.Detail)
switch {
case message == "" && detail == "":
return "chat stream returned an error"
case message == "":
return detail
case detail == "":
return message
default:
return fmt.Sprintf("%s: %s", message, detail)
}
}
+53
View File
@@ -0,0 +1,53 @@
package codersdk_test
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/codersdk"
)
func TestChatModelProviderOptions_MarshalJSON_UsesPlainProviderPayload(t *testing.T) {
t.Parallel()
sendReasoning := true
effort := "high"
raw, err := json.Marshal(codersdk.ChatModelProviderOptions{
Anthropic: &codersdk.ChatModelAnthropicProviderOptions{
SendReasoning: &sendReasoning,
Effort: &effort,
},
})
require.NoError(t, err)
require.NotContains(t, string(raw), `"type":"anthropic.options"`)
require.NotContains(t, string(raw), `"data":`)
require.Contains(t, string(raw), `"send_reasoning":true`)
require.Contains(t, string(raw), `"effort":"high"`)
}
func TestChatModelProviderOptions_UnmarshalJSON_ParsesPlainProviderPayloads(t *testing.T) {
t.Parallel()
raw := []byte(`{
"anthropic": {
"send_reasoning": true,
"effort": "high"
}
}`)
var decoded codersdk.ChatModelProviderOptions
err := json.Unmarshal(raw, &decoded)
require.NoError(t, err)
require.NotNil(t, decoded.Anthropic)
require.NotNil(t, decoded.Anthropic.SendReasoning)
require.True(t, *decoded.Anthropic.SendReasoning)
require.NotNil(t, decoded.Anthropic.Effort)
require.Equal(
t,
"high",
*decoded.Anthropic.Effort,
)
}
+77 -63
View File
@@ -579,68 +579,69 @@ type DeploymentValues struct {
DocsURL serpent.URL `json:"docs_url,omitempty"`
RedirectToAccessURL serpent.Bool `json:"redirect_to_access_url,omitempty"`
// HTTPAddress is a string because it may be set to zero to disable.
HTTPAddress serpent.String `json:"http_address,omitempty" typescript:",notnull"`
AutobuildPollInterval serpent.Duration `json:"autobuild_poll_interval,omitempty"`
JobReaperDetectorInterval serpent.Duration `json:"job_hang_detector_interval,omitempty"`
DERP DERP `json:"derp,omitempty" typescript:",notnull"`
Prometheus PrometheusConfig `json:"prometheus,omitempty" typescript:",notnull"`
Pprof PprofConfig `json:"pprof,omitempty" typescript:",notnull"`
ProxyTrustedHeaders serpent.StringArray `json:"proxy_trusted_headers,omitempty" typescript:",notnull"`
ProxyTrustedOrigins serpent.StringArray `json:"proxy_trusted_origins,omitempty" typescript:",notnull"`
CacheDir serpent.String `json:"cache_directory,omitempty" typescript:",notnull"`
EphemeralDeployment serpent.Bool `json:"ephemeral_deployment,omitempty" typescript:",notnull"`
PostgresURL serpent.String `json:"pg_connection_url,omitempty" typescript:",notnull"`
PostgresAuth string `json:"pg_auth,omitempty" typescript:",notnull"`
PostgresConnMaxOpen serpent.Int64 `json:"pg_conn_max_open,omitempty" typescript:",notnull"`
PostgresConnMaxIdle serpent.String `json:"pg_conn_max_idle,omitempty" typescript:",notnull"`
OAuth2 OAuth2Config `json:"oauth2,omitempty" typescript:",notnull"`
OIDC OIDCConfig `json:"oidc,omitempty" typescript:",notnull"`
Telemetry TelemetryConfig `json:"telemetry,omitempty" typescript:",notnull"`
TLS TLSConfig `json:"tls,omitempty" typescript:",notnull"`
Trace TraceConfig `json:"trace,omitempty" typescript:",notnull"`
HTTPCookies HTTPCookieConfig `json:"http_cookies,omitempty" typescript:",notnull"`
StrictTransportSecurity serpent.Int64 `json:"strict_transport_security,omitempty" typescript:",notnull"`
StrictTransportSecurityOptions serpent.StringArray `json:"strict_transport_security_options,omitempty" typescript:",notnull"`
SSHKeygenAlgorithm serpent.String `json:"ssh_keygen_algorithm,omitempty" typescript:",notnull"`
MetricsCacheRefreshInterval serpent.Duration `json:"metrics_cache_refresh_interval,omitempty" typescript:",notnull"`
AgentStatRefreshInterval serpent.Duration `json:"agent_stat_refresh_interval,omitempty" typescript:",notnull"`
AgentFallbackTroubleshootingURL serpent.URL `json:"agent_fallback_troubleshooting_url,omitempty" typescript:",notnull"`
BrowserOnly serpent.Bool `json:"browser_only,omitempty" typescript:",notnull"`
SCIMAPIKey serpent.String `json:"scim_api_key,omitempty" typescript:",notnull"`
ExternalTokenEncryptionKeys serpent.StringArray `json:"external_token_encryption_keys,omitempty" typescript:",notnull"`
Provisioner ProvisionerConfig `json:"provisioner,omitempty" typescript:",notnull"`
RateLimit RateLimitConfig `json:"rate_limit,omitempty" typescript:",notnull"`
Experiments serpent.StringArray `json:"experiments,omitempty" typescript:",notnull"`
UpdateCheck serpent.Bool `json:"update_check,omitempty" typescript:",notnull"`
Swagger SwaggerConfig `json:"swagger,omitempty" typescript:",notnull"`
Logging LoggingConfig `json:"logging,omitempty" typescript:",notnull"`
Dangerous DangerousConfig `json:"dangerous,omitempty" typescript:",notnull"`
DisablePathApps serpent.Bool `json:"disable_path_apps,omitempty" typescript:",notnull"`
Sessions SessionLifetime `json:"session_lifetime,omitempty" typescript:",notnull"`
DisablePasswordAuth serpent.Bool `json:"disable_password_auth,omitempty" typescript:",notnull"`
Support SupportConfig `json:"support,omitempty" typescript:",notnull"`
EnableAuthzRecording serpent.Bool `json:"enable_authz_recording,omitempty" typescript:",notnull"`
ExternalAuthConfigs serpent.Struct[[]ExternalAuthConfig] `json:"external_auth,omitempty" typescript:",notnull"`
SSHConfig SSHConfig `json:"config_ssh,omitempty" typescript:",notnull"`
WgtunnelHost serpent.String `json:"wgtunnel_host,omitempty" typescript:",notnull"`
DisableOwnerWorkspaceExec serpent.Bool `json:"disable_owner_workspace_exec,omitempty" typescript:",notnull"`
DisableWorkspaceSharing serpent.Bool `json:"disable_workspace_sharing,omitempty" typescript:",notnull"`
ProxyHealthStatusInterval serpent.Duration `json:"proxy_health_status_interval,omitempty" typescript:",notnull"`
EnableTerraformDebugMode serpent.Bool `json:"enable_terraform_debug_mode,omitempty" typescript:",notnull"`
UserQuietHoursSchedule UserQuietHoursScheduleConfig `json:"user_quiet_hours_schedule,omitempty" typescript:",notnull"`
WebTerminalRenderer serpent.String `json:"web_terminal_renderer,omitempty" typescript:",notnull"`
AllowWorkspaceRenames serpent.Bool `json:"allow_workspace_renames,omitempty" typescript:",notnull"`
Healthcheck HealthcheckConfig `json:"healthcheck,omitempty" typescript:",notnull"`
Retention RetentionConfig `json:"retention,omitempty" typescript:",notnull"`
CLIUpgradeMessage serpent.String `json:"cli_upgrade_message,omitempty" typescript:",notnull"`
TermsOfServiceURL serpent.String `json:"terms_of_service_url,omitempty" typescript:",notnull"`
Notifications NotificationsConfig `json:"notifications,omitempty" typescript:",notnull"`
AdditionalCSPPolicy serpent.StringArray `json:"additional_csp_policy,omitempty" typescript:",notnull"`
WorkspaceHostnameSuffix serpent.String `json:"workspace_hostname_suffix,omitempty" typescript:",notnull"`
Prebuilds PrebuildsConfig `json:"workspace_prebuilds,omitempty" typescript:",notnull"`
HideAITasks serpent.Bool `json:"hide_ai_tasks,omitempty" typescript:",notnull"`
AI AIConfig `json:"ai,omitempty"`
StatsCollection StatsCollectionConfig `json:"stats_collection,omitempty" typescript:",notnull"`
HTTPAddress serpent.String `json:"http_address,omitempty" typescript:",notnull"`
AutobuildPollInterval serpent.Duration `json:"autobuild_poll_interval,omitempty"`
JobReaperDetectorInterval serpent.Duration `json:"job_hang_detector_interval,omitempty"`
DERP DERP `json:"derp,omitempty" typescript:",notnull"`
Prometheus PrometheusConfig `json:"prometheus,omitempty" typescript:",notnull"`
Pprof PprofConfig `json:"pprof,omitempty" typescript:",notnull"`
ProxyTrustedHeaders serpent.StringArray `json:"proxy_trusted_headers,omitempty" typescript:",notnull"`
ProxyTrustedOrigins serpent.StringArray `json:"proxy_trusted_origins,omitempty" typescript:",notnull"`
CacheDir serpent.String `json:"cache_directory,omitempty" typescript:",notnull"`
EphemeralDeployment serpent.Bool `json:"ephemeral_deployment,omitempty" typescript:",notnull"`
PostgresURL serpent.String `json:"pg_connection_url,omitempty" typescript:",notnull"`
PostgresAuth string `json:"pg_auth,omitempty" typescript:",notnull"`
PostgresConnMaxOpen serpent.Int64 `json:"pg_conn_max_open,omitempty" typescript:",notnull"`
PostgresConnMaxIdle serpent.String `json:"pg_conn_max_idle,omitempty" typescript:",notnull"`
OAuth2 OAuth2Config `json:"oauth2,omitempty" typescript:",notnull"`
OIDC OIDCConfig `json:"oidc,omitempty" typescript:",notnull"`
Telemetry TelemetryConfig `json:"telemetry,omitempty" typescript:",notnull"`
TLS TLSConfig `json:"tls,omitempty" typescript:",notnull"`
Trace TraceConfig `json:"trace,omitempty" typescript:",notnull"`
HTTPCookies HTTPCookieConfig `json:"http_cookies,omitempty" typescript:",notnull"`
StrictTransportSecurity serpent.Int64 `json:"strict_transport_security,omitempty" typescript:",notnull"`
StrictTransportSecurityOptions serpent.StringArray `json:"strict_transport_security_options,omitempty" typescript:",notnull"`
SSHKeygenAlgorithm serpent.String `json:"ssh_keygen_algorithm,omitempty" typescript:",notnull"`
MetricsCacheRefreshInterval serpent.Duration `json:"metrics_cache_refresh_interval,omitempty" typescript:",notnull"`
AgentStatRefreshInterval serpent.Duration `json:"agent_stat_refresh_interval,omitempty" typescript:",notnull"`
AgentFallbackTroubleshootingURL serpent.URL `json:"agent_fallback_troubleshooting_url,omitempty" typescript:",notnull"`
BrowserOnly serpent.Bool `json:"browser_only,omitempty" typescript:",notnull"`
SCIMAPIKey serpent.String `json:"scim_api_key,omitempty" typescript:",notnull"`
ExternalTokenEncryptionKeys serpent.StringArray `json:"external_token_encryption_keys,omitempty" typescript:",notnull"`
Provisioner ProvisionerConfig `json:"provisioner,omitempty" typescript:",notnull"`
RateLimit RateLimitConfig `json:"rate_limit,omitempty" typescript:",notnull"`
Experiments serpent.StringArray `json:"experiments,omitempty" typescript:",notnull"`
UpdateCheck serpent.Bool `json:"update_check,omitempty" typescript:",notnull"`
Swagger SwaggerConfig `json:"swagger,omitempty" typescript:",notnull"`
Logging LoggingConfig `json:"logging,omitempty" typescript:",notnull"`
Dangerous DangerousConfig `json:"dangerous,omitempty" typescript:",notnull"`
DisablePathApps serpent.Bool `json:"disable_path_apps,omitempty" typescript:",notnull"`
Sessions SessionLifetime `json:"session_lifetime,omitempty" typescript:",notnull"`
DisablePasswordAuth serpent.Bool `json:"disable_password_auth,omitempty" typescript:",notnull"`
Support SupportConfig `json:"support,omitempty" typescript:",notnull"`
EnableAuthzRecording serpent.Bool `json:"enable_authz_recording,omitempty" typescript:",notnull"`
ExternalAuthConfigs serpent.Struct[[]ExternalAuthConfig] `json:"external_auth,omitempty" typescript:",notnull"`
ExternalAuthGithubDefaultProviderEnable serpent.Bool `json:"external_auth_github_default_provider_enable,omitempty" typescript:",notnull"`
SSHConfig SSHConfig `json:"config_ssh,omitempty" typescript:",notnull"`
WgtunnelHost serpent.String `json:"wgtunnel_host,omitempty" typescript:",notnull"`
DisableOwnerWorkspaceExec serpent.Bool `json:"disable_owner_workspace_exec,omitempty" typescript:",notnull"`
DisableWorkspaceSharing serpent.Bool `json:"disable_workspace_sharing,omitempty" typescript:",notnull"`
ProxyHealthStatusInterval serpent.Duration `json:"proxy_health_status_interval,omitempty" typescript:",notnull"`
EnableTerraformDebugMode serpent.Bool `json:"enable_terraform_debug_mode,omitempty" typescript:",notnull"`
UserQuietHoursSchedule UserQuietHoursScheduleConfig `json:"user_quiet_hours_schedule,omitempty" typescript:",notnull"`
WebTerminalRenderer serpent.String `json:"web_terminal_renderer,omitempty" typescript:",notnull"`
AllowWorkspaceRenames serpent.Bool `json:"allow_workspace_renames,omitempty" typescript:",notnull"`
Healthcheck HealthcheckConfig `json:"healthcheck,omitempty" typescript:",notnull"`
Retention RetentionConfig `json:"retention,omitempty" typescript:",notnull"`
CLIUpgradeMessage serpent.String `json:"cli_upgrade_message,omitempty" typescript:",notnull"`
TermsOfServiceURL serpent.String `json:"terms_of_service_url,omitempty" typescript:",notnull"`
Notifications NotificationsConfig `json:"notifications,omitempty" typescript:",notnull"`
AdditionalCSPPolicy serpent.StringArray `json:"additional_csp_policy,omitempty" typescript:",notnull"`
WorkspaceHostnameSuffix serpent.String `json:"workspace_hostname_suffix,omitempty" typescript:",notnull"`
Prebuilds PrebuildsConfig `json:"workspace_prebuilds,omitempty" typescript:",notnull"`
HideAITasks serpent.Bool `json:"hide_ai_tasks,omitempty" typescript:",notnull"`
AI AIConfig `json:"ai,omitempty"`
StatsCollection StatsCollectionConfig `json:"stats_collection,omitempty" typescript:",notnull"`
Config serpent.YAMLConfigPath `json:"config,omitempty" typescript:",notnull"`
WriteConfig serpent.Bool `json:"write_config,omitempty" typescript:",notnull"`
@@ -3153,6 +3154,15 @@ Write out the current server config as YAML to stdout.`,
Value: &c.ExternalAuthConfigs,
Hidden: true,
},
{
Name: "External Auth GitHub Default Provider Enable",
Description: "Enable the default GitHub external auth provider managed by Coder.",
Flag: "external-auth-github-default-provider-enable",
Env: "CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE",
YAML: "externalAuthGithubDefaultProviderEnable",
Value: &c.ExternalAuthGithubDefaultProviderEnable,
Default: "true",
},
{
Name: "Custom wgtunnel Host",
Description: `Hostname of HTTPS server that runs https://github.com/coder/wgtunnel. By default, this will pick the best available wgtunnel server hosted by Coder. e.g. "tunnel.example.com".`,
@@ -3583,7 +3593,6 @@ Write out the current server config as YAML to stdout.`,
Group: &deploymentGroupClient,
YAML: "hideAITasks",
},
// AI Bridge Options
{
Name: "AI Bridge Enabled",
@@ -4264,6 +4273,7 @@ const (
ExperimentWorkspaceUsage Experiment = "workspace-usage" // Enables the new workspace usage tracking.
ExperimentWebPush Experiment = "web-push" // Enables web push notifications through the browser.
ExperimentOAuth2 Experiment = "oauth2" // Enables OAuth2 provider functionality.
ExperimentAgents Experiment = "agents" // Enables agent-powered chat functionality.
ExperimentMCPServerHTTP Experiment = "mcp-server-http" // Enables the MCP HTTP server functionality.
)
@@ -4281,6 +4291,8 @@ func (e Experiment) DisplayName() string {
return "Browser Push Notifications"
case ExperimentOAuth2:
return "OAuth2 Provider Functionality"
case ExperimentAgents:
return "Agents"
case ExperimentMCPServerHTTP:
return "MCP HTTP Server Functionality"
default:
@@ -4299,6 +4311,7 @@ var ExperimentsKnown = Experiments{
ExperimentWorkspaceUsage,
ExperimentWebPush,
ExperimentOAuth2,
ExperimentAgents,
ExperimentMCPServerHTTP,
}
@@ -4306,6 +4319,7 @@ var ExperimentsKnown = Experiments{
// users to opt-in to via --experimental='*'.
// Experiments that are not ready for consumption by all users should
// not be included here and will be essentially hidden.
// TODO: Add ExperimentAgents to ExperimentsSafe once it is safe for general use.
var ExperimentsSafe = Experiments{}
// Experiments is a list of experiments.
+2
View File
@@ -11,6 +11,7 @@ const (
ResourceAssignRole RBACResource = "assign_role"
ResourceAuditLog RBACResource = "audit_log"
ResourceBoundaryUsage RBACResource = "boundary_usage"
ResourceChat RBACResource = "chat"
ResourceConnectionLog RBACResource = "connection_log"
ResourceCryptoKey RBACResource = "crypto_key"
ResourceDebugInfo RBACResource = "debug_info"
@@ -82,6 +83,7 @@ var RBACResourceActions = map[RBACResource][]RBACAction{
ResourceAssignRole: {ActionAssign, ActionRead, ActionUnassign},
ResourceAuditLog: {ActionCreate, ActionRead},
ResourceBoundaryUsage: {ActionDelete, ActionRead, ActionUpdate},
ResourceChat: {ActionCreate, ActionDelete, ActionRead, ActionUpdate},
ResourceConnectionLog: {ActionRead, ActionUpdate},
ResourceCryptoKey: {ActionCreate, ActionDelete, ActionRead, ActionUpdate},
ResourceDebugInfo: {ActionRead},
+1
View File
@@ -304,6 +304,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \
}
]
},
"external_auth_github_default_provider_enable": true,
"external_token_encryption_keys": [
"string"
],
+20 -20
View File
@@ -172,10 +172,10 @@ Status Code **200**
#### Enumerated Values
| Property | Value(s) |
|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
| Property | Value(s) |
|-----------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
@@ -305,10 +305,10 @@ Status Code **200**
#### Enumerated Values
| Property | Value(s) |
|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
| Property | Value(s) |
|-----------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
@@ -438,10 +438,10 @@ Status Code **200**
#### Enumerated Values
| Property | Value(s) |
|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
| Property | Value(s) |
|-----------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
@@ -533,10 +533,10 @@ Status Code **200**
#### Enumerated Values
| Property | Value(s) |
|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
| Property | Value(s) |
|-----------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
@@ -909,9 +909,9 @@ Status Code **200**
#### Enumerated Values
| Property | Value(s) |
|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
| Property | Value(s) |
|-----------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
+84 -81
View File
@@ -865,9 +865,9 @@
#### Enumerated Values
| Value(s) |
|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `aibridge_interception:*`, `aibridge_interception:create`, `aibridge_interception:read`, `aibridge_interception:update`, `all`, `api_key:*`, `api_key:create`, `api_key:delete`, `api_key:read`, `api_key:update`, `application_connect`, `assign_org_role:*`, `assign_org_role:assign`, `assign_org_role:create`, `assign_org_role:delete`, `assign_org_role:read`, `assign_org_role:unassign`, `assign_org_role:update`, `assign_role:*`, `assign_role:assign`, `assign_role:read`, `assign_role:unassign`, `audit_log:*`, `audit_log:create`, `audit_log:read`, `boundary_usage:*`, `boundary_usage:delete`, `boundary_usage:read`, `boundary_usage:update`, `coder:all`, `coder:apikeys.manage_self`, `coder:application_connect`, `coder:templates.author`, `coder:templates.build`, `coder:workspaces.access`, `coder:workspaces.create`, `coder:workspaces.delete`, `coder:workspaces.operate`, `connection_log:*`, `connection_log:read`, `connection_log:update`, `crypto_key:*`, `crypto_key:create`, `crypto_key:delete`, `crypto_key:read`, `crypto_key:update`, `debug_info:*`, `debug_info:read`, `deployment_config:*`, `deployment_config:read`, `deployment_config:update`, `deployment_stats:*`, `deployment_stats:read`, `file:*`, `file:create`, `file:read`, `group:*`, `group:create`, `group:delete`, `group:read`, `group:update`, `group_member:*`, `group_member:read`, `idpsync_settings:*`, `idpsync_settings:read`, `idpsync_settings:update`, `inbox_notification:*`, `inbox_notification:create`, `inbox_notification:read`, `inbox_notification:update`, `license:*`, `license:create`, `license:delete`, `license:read`, `notification_message:*`, `notification_message:create`, `notification_message:delete`, `notification_message:read`, `notification_message:update`, `notification_preference:*`, `notification_preference:read`, `notification_preference:update`, `notification_template:*`, `notification_template:read`, `notification_template:update`, `oauth2_app:*`, `oauth2_app:create`, `oauth2_app:delete`, `oauth2_app:read`, `oauth2_app:update`, `oauth2_app_code_token:*`, `oauth2_app_code_token:create`, `oauth2_app_code_token:delete`, `oauth2_app_code_token:read`, `oauth2_app_secret:*`, `oauth2_app_secret:create`, `oauth2_app_secret:delete`, `oauth2_app_secret:read`, `oauth2_app_secret:update`, `organization:*`, `organization:create`, `organization:delete`, `organization:read`, `organization:update`, `organization_member:*`, `organization_member:create`, `organization_member:delete`, `organization_member:read`, `organization_member:update`, `prebuilt_workspace:*`, `prebuilt_workspace:delete`, `prebuilt_workspace:update`, `provisioner_daemon:*`, `provisioner_daemon:create`, `provisioner_daemon:delete`, `provisioner_daemon:read`, `provisioner_daemon:update`, `provisioner_jobs:*`, `provisioner_jobs:create`, `provisioner_jobs:read`, `provisioner_jobs:update`, `replicas:*`, `replicas:read`, `system:*`, `system:create`, `system:delete`, `system:read`, `system:update`, `tailnet_coordinator:*`, `tailnet_coordinator:create`, `tailnet_coordinator:delete`, `tailnet_coordinator:read`, `tailnet_coordinator:update`, `task:*`, `task:create`, `task:delete`, `task:read`, `task:update`, `template:*`, `template:create`, `template:delete`, `template:read`, `template:update`, `template:use`, `template:view_insights`, `usage_event:*`, `usage_event:create`, `usage_event:read`, `usage_event:update`, `user:*`, `user:create`, `user:delete`, `user:read`, `user:read_personal`, `user:update`, `user:update_personal`, `user_secret:*`, `user_secret:create`, `user_secret:delete`, `user_secret:read`, `user_secret:update`, `webpush_subscription:*`, `webpush_subscription:create`, `webpush_subscription:delete`, `webpush_subscription:read`, `workspace:*`, `workspace:application_connect`, `workspace:create`, `workspace:create_agent`, `workspace:delete`, `workspace:delete_agent`, `workspace:read`, `workspace:share`, `workspace:ssh`, `workspace:start`, `workspace:stop`, `workspace:update`, `workspace:update_agent`, `workspace_agent_devcontainers:*`, `workspace_agent_devcontainers:create`, `workspace_agent_resource_monitor:*`, `workspace_agent_resource_monitor:create`, `workspace_agent_resource_monitor:read`, `workspace_agent_resource_monitor:update`, `workspace_dormant:*`, `workspace_dormant:application_connect`, `workspace_dormant:create`, `workspace_dormant:create_agent`, `workspace_dormant:delete`, `workspace_dormant:delete_agent`, `workspace_dormant:read`, `workspace_dormant:share`, `workspace_dormant:ssh`, `workspace_dormant:start`, `workspace_dormant:stop`, `workspace_dormant:update`, `workspace_dormant:update_agent`, `workspace_proxy:*`, `workspace_proxy:create`, `workspace_proxy:delete`, `workspace_proxy:read`, `workspace_proxy:update` |
| Value(s) |
|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `aibridge_interception:*`, `aibridge_interception:create`, `aibridge_interception:read`, `aibridge_interception:update`, `all`, `api_key:*`, `api_key:create`, `api_key:delete`, `api_key:read`, `api_key:update`, `application_connect`, `assign_org_role:*`, `assign_org_role:assign`, `assign_org_role:create`, `assign_org_role:delete`, `assign_org_role:read`, `assign_org_role:unassign`, `assign_org_role:update`, `assign_role:*`, `assign_role:assign`, `assign_role:read`, `assign_role:unassign`, `audit_log:*`, `audit_log:create`, `audit_log:read`, `boundary_usage:*`, `boundary_usage:delete`, `boundary_usage:read`, `boundary_usage:update`, `chat:*`, `chat:create`, `chat:delete`, `chat:read`, `chat:update`, `coder:all`, `coder:apikeys.manage_self`, `coder:application_connect`, `coder:templates.author`, `coder:templates.build`, `coder:workspaces.access`, `coder:workspaces.create`, `coder:workspaces.delete`, `coder:workspaces.operate`, `connection_log:*`, `connection_log:read`, `connection_log:update`, `crypto_key:*`, `crypto_key:create`, `crypto_key:delete`, `crypto_key:read`, `crypto_key:update`, `debug_info:*`, `debug_info:read`, `deployment_config:*`, `deployment_config:read`, `deployment_config:update`, `deployment_stats:*`, `deployment_stats:read`, `file:*`, `file:create`, `file:read`, `group:*`, `group:create`, `group:delete`, `group:read`, `group:update`, `group_member:*`, `group_member:read`, `idpsync_settings:*`, `idpsync_settings:read`, `idpsync_settings:update`, `inbox_notification:*`, `inbox_notification:create`, `inbox_notification:read`, `inbox_notification:update`, `license:*`, `license:create`, `license:delete`, `license:read`, `notification_message:*`, `notification_message:create`, `notification_message:delete`, `notification_message:read`, `notification_message:update`, `notification_preference:*`, `notification_preference:read`, `notification_preference:update`, `notification_template:*`, `notification_template:read`, `notification_template:update`, `oauth2_app:*`, `oauth2_app:create`, `oauth2_app:delete`, `oauth2_app:read`, `oauth2_app:update`, `oauth2_app_code_token:*`, `oauth2_app_code_token:create`, `oauth2_app_code_token:delete`, `oauth2_app_code_token:read`, `oauth2_app_secret:*`, `oauth2_app_secret:create`, `oauth2_app_secret:delete`, `oauth2_app_secret:read`, `oauth2_app_secret:update`, `organization:*`, `organization:create`, `organization:delete`, `organization:read`, `organization:update`, `organization_member:*`, `organization_member:create`, `organization_member:delete`, `organization_member:read`, `organization_member:update`, `prebuilt_workspace:*`, `prebuilt_workspace:delete`, `prebuilt_workspace:update`, `provisioner_daemon:*`, `provisioner_daemon:create`, `provisioner_daemon:delete`, `provisioner_daemon:read`, `provisioner_daemon:update`, `provisioner_jobs:*`, `provisioner_jobs:create`, `provisioner_jobs:read`, `provisioner_jobs:update`, `replicas:*`, `replicas:read`, `system:*`, `system:create`, `system:delete`, `system:read`, `system:update`, `tailnet_coordinator:*`, `tailnet_coordinator:create`, `tailnet_coordinator:delete`, `tailnet_coordinator:read`, `tailnet_coordinator:update`, `task:*`, `task:create`, `task:delete`, `task:read`, `task:update`, `template:*`, `template:create`, `template:delete`, `template:read`, `template:update`, `template:use`, `template:view_insights`, `usage_event:*`, `usage_event:create`, `usage_event:read`, `usage_event:update`, `user:*`, `user:create`, `user:delete`, `user:read`, `user:read_personal`, `user:update`, `user:update_personal`, `user_secret:*`, `user_secret:create`, `user_secret:delete`, `user_secret:read`, `user_secret:update`, `webpush_subscription:*`, `webpush_subscription:create`, `webpush_subscription:delete`, `webpush_subscription:read`, `workspace:*`, `workspace:application_connect`, `workspace:create`, `workspace:create_agent`, `workspace:delete`, `workspace:delete_agent`, `workspace:read`, `workspace:share`, `workspace:ssh`, `workspace:start`, `workspace:stop`, `workspace:update`, `workspace:update_agent`, `workspace_agent_devcontainers:*`, `workspace_agent_devcontainers:create`, `workspace_agent_resource_monitor:*`, `workspace_agent_resource_monitor:create`, `workspace_agent_resource_monitor:read`, `workspace_agent_resource_monitor:update`, `workspace_dormant:*`, `workspace_dormant:application_connect`, `workspace_dormant:create`, `workspace_dormant:create_agent`, `workspace_dormant:delete`, `workspace_dormant:delete_agent`, `workspace_dormant:read`, `workspace_dormant:share`, `workspace_dormant:ssh`, `workspace_dormant:start`, `workspace_dormant:stop`, `workspace_dormant:update`, `workspace_dormant:update_agent`, `workspace_proxy:*`, `workspace_proxy:create`, `workspace_proxy:delete`, `workspace_proxy:read`, `workspace_proxy:update` |
## codersdk.AddLicenseRequest
@@ -2805,6 +2805,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
}
]
},
"external_auth_github_default_provider_enable": true,
"external_token_encryption_keys": [
"string"
],
@@ -3373,6 +3374,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
}
]
},
"external_auth_github_default_provider_enable": true,
"external_token_encryption_keys": [
"string"
],
@@ -3682,78 +3684,79 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
### Properties
| Name | Type | Required | Restrictions | Description |
|--------------------------------------|------------------------------------------------------------------------------------------------------|----------|--------------|--------------------------------------------------------------------|
| `access_url` | [serpent.URL](#serpenturl) | false | | |
| `additional_csp_policy` | array of string | false | | |
| `address` | [serpent.HostPort](#serpenthostport) | false | | Deprecated: Use HTTPAddress or TLS.Address instead. |
| `agent_fallback_troubleshooting_url` | [serpent.URL](#serpenturl) | false | | |
| `agent_stat_refresh_interval` | integer | false | | |
| `ai` | [codersdk.AIConfig](#codersdkaiconfig) | false | | |
| `allow_workspace_renames` | boolean | false | | |
| `autobuild_poll_interval` | integer | false | | |
| `browser_only` | boolean | false | | |
| `cache_directory` | string | false | | |
| `cli_upgrade_message` | string | false | | |
| `config` | string | false | | |
| `config_ssh` | [codersdk.SSHConfig](#codersdksshconfig) | false | | |
| `dangerous` | [codersdk.DangerousConfig](#codersdkdangerousconfig) | false | | |
| `derp` | [codersdk.DERP](#codersdkderp) | false | | |
| `disable_owner_workspace_exec` | boolean | false | | |
| `disable_password_auth` | boolean | false | | |
| `disable_path_apps` | boolean | false | | |
| `disable_workspace_sharing` | boolean | false | | |
| `docs_url` | [serpent.URL](#serpenturl) | false | | |
| `enable_authz_recording` | boolean | false | | |
| `enable_terraform_debug_mode` | boolean | false | | |
| `ephemeral_deployment` | boolean | false | | |
| `experiments` | array of string | false | | |
| `external_auth` | [serpent.Struct-array_codersdk_ExternalAuthConfig](#serpentstruct-array_codersdk_externalauthconfig) | false | | |
| `external_token_encryption_keys` | array of string | false | | |
| `healthcheck` | [codersdk.HealthcheckConfig](#codersdkhealthcheckconfig) | false | | |
| `hide_ai_tasks` | boolean | false | | |
| `http_address` | string | false | | Http address is a string because it may be set to zero to disable. |
| `http_cookies` | [codersdk.HTTPCookieConfig](#codersdkhttpcookieconfig) | false | | |
| `job_hang_detector_interval` | integer | false | | |
| `logging` | [codersdk.LoggingConfig](#codersdkloggingconfig) | false | | |
| `metrics_cache_refresh_interval` | integer | false | | |
| `notifications` | [codersdk.NotificationsConfig](#codersdknotificationsconfig) | false | | |
| `oauth2` | [codersdk.OAuth2Config](#codersdkoauth2config) | false | | |
| `oidc` | [codersdk.OIDCConfig](#codersdkoidcconfig) | false | | |
| `pg_auth` | string | false | | |
| `pg_conn_max_idle` | string | false | | |
| `pg_conn_max_open` | integer | false | | |
| `pg_connection_url` | string | false | | |
| `pprof` | [codersdk.PprofConfig](#codersdkpprofconfig) | false | | |
| `prometheus` | [codersdk.PrometheusConfig](#codersdkprometheusconfig) | false | | |
| `provisioner` | [codersdk.ProvisionerConfig](#codersdkprovisionerconfig) | false | | |
| `proxy_health_status_interval` | integer | false | | |
| `proxy_trusted_headers` | array of string | false | | |
| `proxy_trusted_origins` | array of string | false | | |
| `rate_limit` | [codersdk.RateLimitConfig](#codersdkratelimitconfig) | false | | |
| `redirect_to_access_url` | boolean | false | | |
| `retention` | [codersdk.RetentionConfig](#codersdkretentionconfig) | false | | |
| `scim_api_key` | string | false | | |
| `session_lifetime` | [codersdk.SessionLifetime](#codersdksessionlifetime) | false | | |
| `ssh_keygen_algorithm` | string | false | | |
| `stats_collection` | [codersdk.StatsCollectionConfig](#codersdkstatscollectionconfig) | false | | |
| `strict_transport_security` | integer | false | | |
| `strict_transport_security_options` | array of string | false | | |
| `support` | [codersdk.SupportConfig](#codersdksupportconfig) | false | | |
| `swagger` | [codersdk.SwaggerConfig](#codersdkswaggerconfig) | false | | |
| `telemetry` | [codersdk.TelemetryConfig](#codersdktelemetryconfig) | false | | |
| `terms_of_service_url` | string | false | | |
| `tls` | [codersdk.TLSConfig](#codersdktlsconfig) | false | | |
| `trace` | [codersdk.TraceConfig](#codersdktraceconfig) | false | | |
| `update_check` | boolean | false | | |
| `user_quiet_hours_schedule` | [codersdk.UserQuietHoursScheduleConfig](#codersdkuserquiethoursscheduleconfig) | false | | |
| `verbose` | boolean | false | | |
| `web_terminal_renderer` | string | false | | |
| `wgtunnel_host` | string | false | | |
| `wildcard_access_url` | string | false | | |
| `workspace_hostname_suffix` | string | false | | |
| `workspace_prebuilds` | [codersdk.PrebuildsConfig](#codersdkprebuildsconfig) | false | | |
| `write_config` | boolean | false | | |
| Name | Type | Required | Restrictions | Description |
|------------------------------------------------|------------------------------------------------------------------------------------------------------|----------|--------------|--------------------------------------------------------------------|
| `access_url` | [serpent.URL](#serpenturl) | false | | |
| `additional_csp_policy` | array of string | false | | |
| `address` | [serpent.HostPort](#serpenthostport) | false | | Deprecated: Use HTTPAddress or TLS.Address instead. |
| `agent_fallback_troubleshooting_url` | [serpent.URL](#serpenturl) | false | | |
| `agent_stat_refresh_interval` | integer | false | | |
| `ai` | [codersdk.AIConfig](#codersdkaiconfig) | false | | |
| `allow_workspace_renames` | boolean | false | | |
| `autobuild_poll_interval` | integer | false | | |
| `browser_only` | boolean | false | | |
| `cache_directory` | string | false | | |
| `cli_upgrade_message` | string | false | | |
| `config` | string | false | | |
| `config_ssh` | [codersdk.SSHConfig](#codersdksshconfig) | false | | |
| `dangerous` | [codersdk.DangerousConfig](#codersdkdangerousconfig) | false | | |
| `derp` | [codersdk.DERP](#codersdkderp) | false | | |
| `disable_owner_workspace_exec` | boolean | false | | |
| `disable_password_auth` | boolean | false | | |
| `disable_path_apps` | boolean | false | | |
| `disable_workspace_sharing` | boolean | false | | |
| `docs_url` | [serpent.URL](#serpenturl) | false | | |
| `enable_authz_recording` | boolean | false | | |
| `enable_terraform_debug_mode` | boolean | false | | |
| `ephemeral_deployment` | boolean | false | | |
| `experiments` | array of string | false | | |
| `external_auth` | [serpent.Struct-array_codersdk_ExternalAuthConfig](#serpentstruct-array_codersdk_externalauthconfig) | false | | |
| `external_auth_github_default_provider_enable` | boolean | false | | |
| `external_token_encryption_keys` | array of string | false | | |
| `healthcheck` | [codersdk.HealthcheckConfig](#codersdkhealthcheckconfig) | false | | |
| `hide_ai_tasks` | boolean | false | | |
| `http_address` | string | false | | Http address is a string because it may be set to zero to disable. |
| `http_cookies` | [codersdk.HTTPCookieConfig](#codersdkhttpcookieconfig) | false | | |
| `job_hang_detector_interval` | integer | false | | |
| `logging` | [codersdk.LoggingConfig](#codersdkloggingconfig) | false | | |
| `metrics_cache_refresh_interval` | integer | false | | |
| `notifications` | [codersdk.NotificationsConfig](#codersdknotificationsconfig) | false | | |
| `oauth2` | [codersdk.OAuth2Config](#codersdkoauth2config) | false | | |
| `oidc` | [codersdk.OIDCConfig](#codersdkoidcconfig) | false | | |
| `pg_auth` | string | false | | |
| `pg_conn_max_idle` | string | false | | |
| `pg_conn_max_open` | integer | false | | |
| `pg_connection_url` | string | false | | |
| `pprof` | [codersdk.PprofConfig](#codersdkpprofconfig) | false | | |
| `prometheus` | [codersdk.PrometheusConfig](#codersdkprometheusconfig) | false | | |
| `provisioner` | [codersdk.ProvisionerConfig](#codersdkprovisionerconfig) | false | | |
| `proxy_health_status_interval` | integer | false | | |
| `proxy_trusted_headers` | array of string | false | | |
| `proxy_trusted_origins` | array of string | false | | |
| `rate_limit` | [codersdk.RateLimitConfig](#codersdkratelimitconfig) | false | | |
| `redirect_to_access_url` | boolean | false | | |
| `retention` | [codersdk.RetentionConfig](#codersdkretentionconfig) | false | | |
| `scim_api_key` | string | false | | |
| `session_lifetime` | [codersdk.SessionLifetime](#codersdksessionlifetime) | false | | |
| `ssh_keygen_algorithm` | string | false | | |
| `stats_collection` | [codersdk.StatsCollectionConfig](#codersdkstatscollectionconfig) | false | | |
| `strict_transport_security` | integer | false | | |
| `strict_transport_security_options` | array of string | false | | |
| `support` | [codersdk.SupportConfig](#codersdksupportconfig) | false | | |
| `swagger` | [codersdk.SwaggerConfig](#codersdkswaggerconfig) | false | | |
| `telemetry` | [codersdk.TelemetryConfig](#codersdktelemetryconfig) | false | | |
| `terms_of_service_url` | string | false | | |
| `tls` | [codersdk.TLSConfig](#codersdktlsconfig) | false | | |
| `trace` | [codersdk.TraceConfig](#codersdktraceconfig) | false | | |
| `update_check` | boolean | false | | |
| `user_quiet_hours_schedule` | [codersdk.UserQuietHoursScheduleConfig](#codersdkuserquiethoursscheduleconfig) | false | | |
| `verbose` | boolean | false | | |
| `web_terminal_renderer` | string | false | | |
| `wgtunnel_host` | string | false | | |
| `wildcard_access_url` | string | false | | |
| `workspace_hostname_suffix` | string | false | | |
| `workspace_prebuilds` | [codersdk.PrebuildsConfig](#codersdkprebuildsconfig) | false | | |
| `write_config` | boolean | false | | |
## codersdk.DiagnosticExtra
@@ -3981,9 +3984,9 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
#### Enumerated Values
| Value(s) |
|----------------------------------------------------------------------------------------------------------------|
| `auto-fill-parameters`, `example`, `mcp-server-http`, `notifications`, `oauth2`, `web-push`, `workspace-usage` |
| Value(s) |
|--------------------------------------------------------------------------------------------------------------------------|
| `agents`, `auto-fill-parameters`, `example`, `mcp-server-http`, `notifications`, `oauth2`, `web-push`, `workspace-usage` |
## codersdk.ExternalAPIKeyScopes
@@ -7322,9 +7325,9 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
#### Enumerated Values
| Value(s) |
|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
| Value(s) |
|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
## codersdk.RateLimitConfig
+5 -5
View File
@@ -811,11 +811,11 @@ Status Code **200**
#### Enumerated Values
| Property | Value(s) |
|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
| `login_type` | `github`, `oidc`, `password`, `token` |
| `scope` | `all`, `application_connect` |
| Property | Value(s) |
|--------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
| `login_type` | `github`, `oidc`, `password`, `token` |
| `scope` | `all`, `application_connect` |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
+11
View File
@@ -1258,6 +1258,17 @@ The upgrade message to display to users when a client/server mismatch is detecte
Support links to display in the top right drop down menu.
### --external-auth-github-default-provider-enable
| | |
|-------------|------------------------------------------------------------------|
| Type | <code>bool</code> |
| Environment | <code>$CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE</code> |
| YAML | <code>externalAuthGithubDefaultProviderEnable</code> |
| Default | <code>true</code> |
Enable the default GitHub external auth provider managed by Coder.
### --proxy-health-interval
| | |
+3
View File
@@ -63,6 +63,9 @@ OPTIONS:
Separate multiple experiments with commas, or enter '*' to opt-in to
all available experiments.
--external-auth-github-default-provider-enable bool, $CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE (default: true)
Enable the default GitHub external auth provider managed by Coder.
--postgres-auth password|awsiamrds, $CODER_PG_AUTH (default: password)
Type of auth to use when connecting to postgres. For AWS RDS, using
IAM authentication (awsiamrds) is recommended.
+172
View File
@@ -0,0 +1,172 @@
package coderd
import (
"context"
"net/http"
"net/url"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/websocket"
)
// RelaySourceHeader marks replica-relayed stream requests.
const RelaySourceHeader = "X-Coder-Relay-Source-Replica"
const (
authorizationHeader = "Authorization"
cookieHeader = "Cookie"
)
// newRemotePartsProvider creates a RemotePartsProvider that dials a remote
// replica's stream endpoint to fetch message_part events. It filters to only
// forward message_part events since durable events come via pubsub.
func newRemotePartsProvider(
resolveReplicaAddress func(context.Context, uuid.UUID) (string, bool),
replicaHTTPClient *http.Client,
replicaID uuid.UUID,
) chatd.RemotePartsProvider {
return func(
ctx context.Context,
chatID uuid.UUID,
workerID uuid.UUID,
requestHeader http.Header,
) (
[]codersdk.ChatStreamEvent,
<-chan codersdk.ChatStreamEvent,
func(),
error,
) {
address, ok := resolveReplicaAddress(ctx, workerID)
if !ok {
return nil, nil, nil, xerrors.New("worker replica not found")
}
baseURL, err := url.Parse(address)
if err != nil {
return nil, nil, nil, xerrors.Errorf("parse relay address %q: %w", address, err)
}
relayCtx, relayCancel := context.WithCancel(ctx)
sdkClient := codersdk.New(baseURL)
sdkClient.HTTPClient = replicaHTTPClient
sdkClient.SessionTokenProvider = relayHeaderTokenProvider{
header: relayHeaders(requestHeader, replicaID),
}
sourceEvents, sourceStream, err := sdkClient.StreamChat(relayCtx, chatID)
if err != nil {
relayCancel()
return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", err)
}
snapshot := make([]codersdk.ChatStreamEvent, 0, 100)
preloaded := make([]codersdk.ChatStreamEvent, 0, 100)
drainInitial:
for len(snapshot) < cap(snapshot) {
select {
case <-relayCtx.Done():
_ = sourceStream.Close()
relayCancel()
return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", relayCtx.Err())
case event, ok := <-sourceEvents:
if !ok {
break drainInitial
}
if event.Type != codersdk.ChatStreamEventTypeMessagePart {
continue
}
snapshot = append(snapshot, event)
preloaded = append(preloaded, event)
default:
break drainInitial
}
}
events := make(chan codersdk.ChatStreamEvent, 128)
go func() {
defer close(events)
defer relayCancel()
defer func() {
_ = sourceStream.Close()
}()
for _, event := range preloaded {
select {
case events <- event:
case <-relayCtx.Done():
return
}
}
for {
select {
case <-relayCtx.Done():
return
case event, ok := <-sourceEvents:
if !ok {
return
}
if event.Type != codersdk.ChatStreamEventTypeMessagePart {
continue
}
select {
case events <- event:
case <-relayCtx.Done():
return
}
}
}
}()
cancel := func() {
relayCancel()
_ = sourceStream.Close()
}
return snapshot, events, cancel, nil
}
}
type relayHeaderTokenProvider struct {
header http.Header
}
func (p relayHeaderTokenProvider) AsRequestOption() codersdk.RequestOption {
return func(req *http.Request) {
for key, values := range p.header {
for _, value := range values {
req.Header.Add(key, value)
}
}
}
}
func (p relayHeaderTokenProvider) SetDialOption(opts *websocket.DialOptions) {
if opts.HTTPHeader == nil {
opts.HTTPHeader = make(http.Header)
}
for key, values := range p.header {
for _, value := range values {
opts.HTTPHeader.Add(key, value)
}
}
}
func (p relayHeaderTokenProvider) GetSessionToken() string {
return p.header.Get(codersdk.SessionTokenHeader)
}
func relayHeaders(source http.Header, replicaID uuid.UUID) http.Header {
header := make(http.Header)
if source != nil {
for _, key := range []string{codersdk.SessionTokenHeader, authorizationHeader, cookieHeader} {
for _, value := range source.Values(key) {
header.Add(key, value)
}
}
}
header.Set(RelaySourceHeader, replicaID.String())
return header
}
+355
View File
@@ -0,0 +1,355 @@
package coderd_test
import (
"context"
"net/url"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chattest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/testutil"
)
func TestChatStreamRelay(t *testing.T) {
t.Parallel()
t.Run("RelayMessagePartsAcrossReplicas", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, pubsub := dbtestutil.NewDB(t)
firstClient, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureHighAvailability: 1,
},
},
})
secondClient, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
},
DontAddLicense: true,
DontAddFirstUser: true,
})
secondClient.SetSessionToken(firstClient.SessionToken())
// Verify we have two replicas
replicas, err := secondClient.Replicas(ctx)
require.NoError(t, err)
require.Len(t, replicas, 2)
firstReplicaID := replicaIDForClientURL(t, firstClient.URL, replicas)
secondReplicaID := replicaIDForClientURL(t, secondClient.URL, replicas)
streamingChunks := make(chan chattest.OpenAIChunk, 8)
chatStreamStarted := make(chan struct{}, 1)
openai := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if req.Stream {
select {
case chatStreamStarted <- struct{}{}:
default:
}
return chattest.OpenAIResponse{StreamingChunks: streamingChunks}
}
return chattest.OpenAINonStreamingResponse("ok")
})
//nolint:gocritic // Test uses owner client to configure chat providers.
provider, err := firstClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test",
BaseURL: openai,
})
require.NoError(t, err)
require.Equal(t, codersdk.ChatProviderConfigSourceDatabase, provider.Source)
model, err := firstClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
Model: "gpt-4",
DisplayName: "GPT-4",
ContextLimit: &[]int64{1000}[0],
CompressionThreshold: &[]int32{70}[0],
})
require.NoError(t, err)
// Create a chat on the first replica
chat, err := firstClient.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "Test chat for relay",
}},
ModelConfigID: &model.ID,
})
require.NoError(t, err)
require.Equal(t, codersdk.ChatStatusPending, chat.Status)
var runningChat database.Chat
require.Eventually(t, func() bool {
current, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
if current.Status != database.ChatStatusRunning || !current.WorkerID.Valid {
return false
}
runningChat = current
return true
}, testutil.WaitLong, testutil.IntervalFast)
var localClient *codersdk.Client
var relayClient *codersdk.Client
switch runningChat.WorkerID.UUID {
case firstReplicaID:
localClient = firstClient
relayClient = secondClient
case secondReplicaID:
localClient = secondClient
relayClient = firstClient
default:
require.FailNowf(
t,
"worker replica was not recognized",
"worker %s was not one of %s or %s",
runningChat.WorkerID.UUID,
firstReplicaID,
secondReplicaID,
)
}
firstEvents, firstStream, err := localClient.StreamChat(ctx, chat.ID)
require.NoError(t, err)
defer firstStream.Close()
select {
case <-chatStreamStarted:
case <-ctx.Done():
require.FailNowf(
t,
"timed out waiting for OpenAI stream request",
"chat stream request did not start before context deadline: %v",
ctx.Err(),
)
}
firstChunkText := "relay-part-one"
streamingChunks <- chattest.OpenAITextChunks(firstChunkText)[0]
firstEvent := waitForStreamTextPart(ctx, t, firstEvents, firstChunkText)
require.Equal(t, "assistant", firstEvent.MessagePart.Role)
secondEvents, secondStream, err := relayClient.StreamChat(ctx, chat.ID)
require.NoError(t, err)
defer secondStream.Close()
secondSnapshotEvent := waitForStreamTextPart(ctx, t, secondEvents, firstChunkText)
require.Equal(t, "assistant", secondSnapshotEvent.MessagePart.Role)
secondChunkText := "relay-part-two"
streamingChunks <- chattest.OpenAITextChunks(secondChunkText)[0]
waitForStreamTextPart(ctx, t, firstEvents, secondChunkText)
waitForStreamTextPart(ctx, t, secondEvents, secondChunkText)
close(streamingChunks)
})
}
func waitForStreamTextPart(
ctx context.Context,
t *testing.T,
events <-chan codersdk.ChatStreamEvent,
expectedText string,
) codersdk.ChatStreamEvent {
t.Helper()
for {
select {
case <-ctx.Done():
require.FailNowf(
t,
"timed out waiting for chat stream event",
"expected text part %q before context deadline: %v",
expectedText,
ctx.Err(),
)
case event, ok := <-events:
require.Truef(t, ok, "chat stream closed while waiting for %q", expectedText)
if event.Type == codersdk.ChatStreamEventTypeError {
errMessage := "unknown chat stream error"
if event.Error != nil && event.Error.Message != "" {
errMessage = event.Error.Message
}
require.FailNowf(
t,
"chat stream returned error event",
"while waiting for %q: %s",
expectedText,
errMessage,
)
}
if event.Type != codersdk.ChatStreamEventTypeMessagePart || event.MessagePart == nil {
continue
}
if event.MessagePart.Part.Type != codersdk.ChatMessagePartTypeText {
continue
}
require.Equal(t, expectedText, event.MessagePart.Part.Text)
return event
}
}
}
func replicaIDForClientURL(
t *testing.T,
clientURL *url.URL,
replicas []codersdk.Replica,
) uuid.UUID {
t.Helper()
for _, replica := range replicas {
relayURL, err := url.Parse(replica.RelayAddress)
require.NoErrorf(
t,
err,
"parse replica relay address %q",
replica.RelayAddress,
)
if relayURL.Host == clientURL.Host {
return replica.ID
}
}
require.FailNowf(
t,
"missing replica for client URL",
"client host %q not present in replica list",
clientURL.Host,
)
return uuid.Nil
}
func TestChatModelConfigDefault(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, _ := coderdenttest.New(t, nil)
//nolint:gocritic // Test uses owner client to configure chat providers.
provider, err := client.CreateChatProvider(
ctx,
codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test",
BaseURL: "https://example.com",
},
)
require.NoError(t, err)
contextLimit := int64(1000)
compressionThreshold := int32(70)
trueValue := true
falseValue := false
firstModel, err := client.CreateChatModelConfig(
ctx,
codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
Model: "gpt-5-a",
DisplayName: "GPT 5 A",
IsDefault: &trueValue,
ContextLimit: &contextLimit,
CompressionThreshold: &compressionThreshold,
},
)
require.NoError(t, err)
require.True(t, firstModel.IsDefault)
secondModel, err := client.CreateChatModelConfig(
ctx,
codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
Model: "gpt-5-b",
DisplayName: "GPT 5 B",
IsDefault: &trueValue,
ContextLimit: &contextLimit,
CompressionThreshold: &compressionThreshold,
},
)
require.NoError(t, err)
require.True(t, secondModel.IsDefault)
modelConfigs, err := client.ListChatModelConfigs(ctx)
require.NoError(t, err)
firstStored := findChatModelConfigByID(t, modelConfigs, firstModel.ID)
secondStored := findChatModelConfigByID(t, modelConfigs, secondModel.ID)
require.False(t, firstStored.IsDefault)
require.True(t, secondStored.IsDefault)
updatedFirst, err := client.UpdateChatModelConfig(
ctx,
firstModel.ID,
codersdk.UpdateChatModelConfigRequest{
IsDefault: &trueValue,
},
)
require.NoError(t, err)
require.True(t, updatedFirst.IsDefault)
modelConfigs, err = client.ListChatModelConfigs(ctx)
require.NoError(t, err)
firstStored = findChatModelConfigByID(t, modelConfigs, firstModel.ID)
secondStored = findChatModelConfigByID(t, modelConfigs, secondModel.ID)
require.True(t, firstStored.IsDefault)
require.False(t, secondStored.IsDefault)
updatedFirst, err = client.UpdateChatModelConfig(
ctx,
firstModel.ID,
codersdk.UpdateChatModelConfigRequest{
IsDefault: &falseValue,
},
)
require.NoError(t, err)
require.False(t, updatedFirst.IsDefault)
modelConfigs, err = client.ListChatModelConfigs(ctx)
require.NoError(t, err)
firstStored = findChatModelConfigByID(t, modelConfigs, firstModel.ID)
secondStored = findChatModelConfigByID(t, modelConfigs, secondModel.ID)
require.False(t, firstStored.IsDefault)
require.True(t, secondStored.IsDefault)
}
func findChatModelConfigByID(
t *testing.T,
modelConfigs []codersdk.ChatModelConfig,
id uuid.UUID,
) codersdk.ChatModelConfig {
t.Helper()
for _, modelConfig := range modelConfigs {
if modelConfig.ID == id {
return modelConfig
}
}
require.FailNowf(t, "missing model config", "model config %s not found", id)
return codersdk.ChatModelConfig{}
}
+95 -4
View File
@@ -3,6 +3,7 @@ package coderd
import (
"context"
"crypto/ed25519"
"crypto/tls"
"fmt"
"math"
"net/http"
@@ -15,6 +16,7 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
@@ -100,6 +102,11 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
}
ctx, cancelFunc := context.WithCancel(ctx)
defer func() {
if err != nil {
cancelFunc()
}
}()
if options.ExternalTokenEncryption == nil {
options.ExternalTokenEncryption = make([]dbcrypt.Cipher, 0)
@@ -141,6 +148,33 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
)
}
meshTLSConfig, err := replicasync.CreateDERPMeshTLSConfig(options.AccessURL.Hostname(), options.TLSCertificates)
if err != nil {
return nil, xerrors.Errorf("create DERP mesh TLS config: %w", err)
}
var replicaManagerPtr atomic.Pointer[replicasync.Manager]
resolveReplicaAddress := func(
_ context.Context,
replicaID uuid.UUID,
) (string, bool) {
manager := replicaManagerPtr.Load()
if manager == nil {
return "", false
}
for _, replica := range manager.AllPrimary() {
if replica.ID != replicaID {
continue
}
relayAddress := strings.TrimSpace(replica.RelayAddress)
if relayAddress == "" {
return "", false
}
return relayAddress, true
}
return "", false
}
api := &API{
ctx: ctx,
cancel: cancelFunc,
@@ -156,6 +190,44 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
}
// This must happen before coderd initialization!
options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader
// Wire up enterprise chat relay for cross-replica message_part streaming.
// Must be set before coderd.New so the chat processor gets it.
replicaHTTPClient := replicaRelayHTTPClient(options.HTTPClient, meshTLSConfig)
if replicaHTTPClient == nil {
replicaHTTPClient = options.Options.HTTPClient
}
if replicaHTTPClient == nil {
replicaHTTPClient = http.DefaultClient
}
// Use a closure that captures api by reference so it can access api.AGPL.ID
// after coderd.New is called. The provider is only invoked when Subscribe
// is called, which happens after initialization, so api.AGPL will be set.
options.Options.ChatRemotePartsProvider = func(
ctx context.Context,
chatID uuid.UUID,
workerID uuid.UUID,
requestHeader http.Header,
) (
[]codersdk.ChatStreamEvent,
<-chan codersdk.ChatStreamEvent,
func(),
error,
) {
// Get the replica ID from the API (will be set after coderd.New)
replicaID := api.AGPL.ID
if replicaID == uuid.Nil {
// Fallback if somehow called before initialization
replicaID = uuid.New()
}
provider := newRemotePartsProvider(
resolveReplicaAddress,
replicaHTTPClient,
replicaID,
)
return provider(ctx, chatID, workerID, requestHeader)
}
api.AGPL = coderd.New(options.Options)
defer func() {
if err != nil {
@@ -583,10 +655,6 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
})))
}
meshTLSConfig, err := replicasync.CreateDERPMeshTLSConfig(options.AccessURL.Hostname(), options.TLSCertificates)
if err != nil {
return nil, xerrors.Errorf("create DERP mesh TLS config: %w", err)
}
// We always want to run the replica manager even if we don't have DERP
// enabled, since it's used to detect other coder servers for licensing.
api.replicaManager, err = replicasync.New(ctx, options.Logger, options.Database, options.Pubsub, &replicasync.Options{
@@ -600,6 +668,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
if err != nil {
return nil, xerrors.Errorf("initialize replica: %w", err)
}
replicaManagerPtr.Store(api.replicaManager)
if api.DERPServer != nil {
api.derpMesh = derpmesh.New(options.Logger.Named("derpmesh"), api.DERPServer, meshTLSConfig)
}
@@ -651,6 +720,28 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
return api, nil
}
func replicaRelayHTTPClient(base *http.Client, tlsConfig *tls.Config) *http.Client {
if base == nil {
base = http.DefaultClient
}
clone := *base
var transport *http.Transport
switch t := base.Transport.(type) {
case *http.Transport:
transport = t.Clone()
default:
if defaultTransport, ok := http.DefaultTransport.(*http.Transport); ok {
transport = defaultTransport.Clone()
} else {
transport = &http.Transport{}
}
}
transport.TLSClientConfig = tlsConfig
clone.Transport = transport
return &clone
}
type Options struct {
*coderd.Options
+55 -2
View File
@@ -3,6 +3,7 @@ package dbcrypt
import (
"context"
"database/sql"
"strings"
"golang.org/x/xerrors"
@@ -82,6 +83,32 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
providers, err := cryptDB.GetChatProviders(ctx)
if err != nil {
return xerrors.Errorf("get chat providers: %w", err)
}
log.Info(ctx, "encrypting chat provider keys", slog.F("provider_count", len(providers)))
for idx, provider := range providers {
if strings.TrimSpace(provider.APIKey) == "" {
continue
}
if provider.ApiKeyKeyID.Valid && provider.ApiKeyKeyID.String == ciphers[0].HexDigest() {
log.Debug(ctx, "skipping chat provider", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
continue
}
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: provider.BaseUrl,
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
Enabled: provider.Enabled,
ID: provider.ID,
}); err != nil {
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
}
log.Debug(ctx, "encrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
// Revoke old keys
for _, c := range ciphers[1:] {
if err := db.RevokeDBCryptKey(ctx, c.HexDigest()); err != nil {
@@ -172,6 +199,28 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
providers, err := cryptDB.GetChatProviders(ctx)
if err != nil {
return xerrors.Errorf("get chat providers: %w", err)
}
log.Info(ctx, "decrypting chat provider keys", slog.F("provider_count", len(providers)))
for idx, provider := range providers {
if !provider.ApiKeyKeyID.Valid {
continue
}
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: provider.BaseUrl,
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
Enabled: provider.Enabled,
ID: provider.ID,
}); err != nil {
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
}
log.Debug(ctx, "decrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
// Revoke _all_ keys
for _, c := range ciphers {
if err := db.RevokeDBCryptKey(ctx, c.HexDigest()); err != nil {
@@ -192,6 +241,10 @@ DELETE FROM user_links
DELETE FROM external_auth_links
WHERE oauth_access_token_key_id IS NOT NULL
OR oauth_refresh_token_key_id IS NOT NULL;
UPDATE chat_providers
SET api_key = '',
api_key_key_id = NULL
WHERE api_key_key_id IS NOT NULL;
COMMIT;
`
@@ -203,9 +256,9 @@ func Delete(ctx context.Context, log slog.Logger, sqlDB *sql.DB) error {
store := database.New(sqlDB)
_, err := sqlDB.ExecContext(ctx, sqlDeleteEncryptedUserTokens)
if err != nil {
return xerrors.Errorf("delete user links: %w", err)
return xerrors.Errorf("delete encrypted tokens and chat provider keys: %w", err)
}
log.Info(ctx, "deleted encrypted user tokens")
log.Info(ctx, "deleted encrypted user tokens and chat provider API keys")
log.Info(ctx, "revoking all active keys")
keys, err := store.GetDBCryptKeys(ctx)
+87
View File
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/base64"
"strings"
"github.com/google/uuid"
"golang.org/x/xerrors"
@@ -351,6 +352,92 @@ func (db *dbCrypt) GetCryptoKeysByFeature(ctx context.Context, feature database.
return keys, nil
}
func (db *dbCrypt) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
provider, err := db.Store.GetChatProviderByID(ctx, id)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) GetChatProviderByProvider(ctx context.Context, providerName string) (database.ChatProvider, error) {
provider, err := db.Store.GetChatProviderByProvider(ctx, providerName)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
providers, err := db.Store.GetChatProviders(ctx)
if err != nil {
return nil, err
}
for i := range providers {
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
return nil, err
}
}
return providers, nil
}
func (db *dbCrypt) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
providers, err := db.Store.GetEnabledChatProviders(ctx)
if err != nil {
return nil, err
}
for i := range providers {
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
return nil, err
}
}
return providers, nil
}
func (db *dbCrypt) InsertChatProvider(ctx context.Context, params database.InsertChatProviderParams) (database.ChatProvider, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
provider, err := db.Store.InsertChatProvider(ctx, params)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.UpdateChatProviderParams) (database.ChatProvider, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
provider, err := db.Store.UpdateChatProvider(ctx, params)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error {
// If no cipher is loaded, then we can't encrypt anything!
if db.ciphers == nil || db.primaryCipherDigest == "" {
Generated
+3 -3
View File
@@ -76,11 +76,11 @@
},
"nixpkgs-unstable": {
"locked": {
"lastModified": 1758035966,
"narHash": "sha256-qqIJ3yxPiB0ZQTT9//nFGQYn8X/PBoJbofA7hRKZnmE=",
"lastModified": 1771369470,
"narHash": "sha256-0NBlEBKkN3lufyvFegY4TYv5mCNHbi5OmBDrzihbBMQ=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "8d4ddb19d03c65a36ad8d189d001dc32ffb0306b",
"rev": "0182a361324364ae3f436a63005877674cf45efb",
"type": "github"
},
"original": {
+29 -3
View File
@@ -109,6 +109,24 @@
vendorHash = "sha256-69kg3qkvEWyCAzjaCSr3a73MNonub9sZTYyGaCW+UTI=";
};
# Keep Terraform aligned with provisioner/terraform/testdata/version.txt
# so `make gen` remains deterministic in Nix shells.
terraform_1_14_1 =
if pkgs.stdenv.isLinux && pkgs.stdenv.hostPlatform.isx86_64 then
pkgs.runCommand "terraform-1.14.1" {
nativeBuildInputs = [ pkgs.unzip ];
src = pkgs.fetchurl {
url = "https://releases.hashicorp.com/terraform/1.14.1/terraform_1.14.1_linux_amd64.zip";
hash = "sha256-n1MHDuYm354VeIfB0/mvPYEHobZUNxzZkEBinu1piyc=";
};
} ''
mkdir -p "$out/bin"
unzip -p "$src" terraform > "$out/bin/terraform"
chmod +x "$out/bin/terraform"
''
else
unstablePkgs.terraform;
# Packages required to build the frontend
frontendPackages =
with pkgs;
@@ -156,7 +174,7 @@
gnused
gnugrep
gnutar
unstablePkgs.go_1_25
unstablePkgs.go_1_26
gofumpt
go-migrate
(pinnedPkgs.golangci-lint)
@@ -170,7 +188,7 @@
lazydocker
lazygit
less
mockgen
unstablePkgs.mockgen
moreutils
nfpm
nix-prefetch-git
@@ -191,7 +209,7 @@
# sqlc
sqlc-custom
syft
unstablePkgs.terraform
terraform_1_14_1
typos
which
# Needed for many LD system libs!
@@ -285,6 +303,14 @@
lib.optionalDrvAttr stdenv.isLinux "${glibcLocales}/lib/locale/locale-archive";
NODE_OPTIONS = "--max-old-space-size=8192";
BIOME_BINARY =
if pkgs.stdenv.isLinux then
if pkgs.stdenv.hostPlatform.isAarch64 then
"@biomejs/cli-linux-arm64-musl/biome"
else
"@biomejs/cli-linux-x64-musl/biome"
else
"";
GOPRIVATE = "coder.com,cdr.dev,go.coder.com,github.com/cdr,github.com/coder";
};
};
+73 -53
View File
@@ -72,6 +72,14 @@ replace github.com/aquasecurity/trivy => github.com/coder/trivy v0.0.0-202508072
// https://github.com/spf13/afero/pull/487
replace github.com/spf13/afero => github.com/aslilac/afero v0.0.0-20250403163713-f06e86036696
// Forked for two reasons:
// 1) Adds thinking effort to Anthropic provider
// 2) Downgraded to Go 1.25 due to issue with Windows CI
// https://github.com/kylecarbs/fantasy/compare/main...kylecarbs:fantasy:cj/go1.25
replace charm.land/fantasy => github.com/kylecarbs/fantasy v0.0.0-20260225152134-45ae0791c21f
replace github.com/charmbracelet/anthropic-sdk-go => github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab
require (
cdr.dev/slog/v3 v3.0.0-rc1
cloud.google.com/go/compute/metadata v0.9.0
@@ -83,7 +91,7 @@ require (
github.com/aquasecurity/trivy-iac v0.8.0
github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2
github.com/awalterschulze/gographviz v2.0.3+incompatible
github.com/aws/smithy-go v1.24.0
github.com/aws/smithy-go v1.24.1
github.com/bramvdbogaerde/go-scp v1.6.0
github.com/briandowns/spinner v1.23.0
github.com/cakturk/go-netstat v0.0.0-20200220111822-e5b49efee7a5
@@ -138,7 +146,7 @@ require (
github.com/google/uuid v1.6.0
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-reap v0.0.0-20170704170343-bf58d8a43e7b
github.com/hashicorp/go-version v1.7.0
github.com/hashicorp/go-version v1.8.0
github.com/hashicorp/hc-install v0.9.2
github.com/hashicorp/terraform-config-inspect v0.0.0-20211115214459-90acf1ca460f
github.com/hashicorp/terraform-json v0.27.2
@@ -150,7 +158,7 @@ require (
github.com/justinas/nosurf v1.2.0
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51
github.com/kirsle/configdir v0.0.0-20170128060238-e45d2f54772f
github.com/klauspost/compress v1.18.2
github.com/klauspost/compress v1.18.4
github.com/lib/pq v1.10.9
github.com/mattn/go-isatty v0.0.20
github.com/mitchellh/go-wordwrap v1.0.1
@@ -167,7 +175,7 @@ require (
github.com/prometheus-community/pro-bing v0.8.0
github.com/prometheus/client_golang v1.23.2
github.com/prometheus/client_model v0.6.2
github.com/prometheus/common v0.67.4
github.com/prometheus/common v0.67.5
github.com/quasilyte/go-ruleguard/dsl v0.3.22
github.com/robfig/cron/v3 v3.0.1
github.com/shirou/gopsutil/v4 v4.26.1
@@ -186,17 +194,17 @@ require (
github.com/zclconf/go-cty-yaml v1.2.0
go.mozilla.org/pkcs7 v0.9.0
go.nhat.io/otelsql v0.16.0
go.opentelemetry.io/otel v1.39.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.37.0
go.opentelemetry.io/otel/sdk v1.39.0
go.opentelemetry.io/otel/trace v1.39.0
go.opentelemetry.io/otel v1.40.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0
go.opentelemetry.io/otel/sdk v1.40.0
go.opentelemetry.io/otel/trace v1.40.0
go.uber.org/atomic v1.11.0
go.uber.org/goleak v1.3.1-0.20240429205332-517bace7cc29
go.uber.org/mock v0.6.0
go4.org/netipx v0.0.0-20230728180743-ad4cb58a6516
golang.org/x/crypto v0.48.0
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546
golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa
golang.org/x/mod v0.33.0
golang.org/x/net v0.50.0
golang.org/x/oauth2 v0.35.0
@@ -219,7 +227,7 @@ require (
)
require (
cloud.google.com/go/auth v0.18.1 // indirect
cloud.google.com/go/auth v0.18.2 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
dario.cat/mergo v1.0.1 // indirect
filippo.io/edwards25519 v1.1.1 // indirect
@@ -253,19 +261,19 @@ require (
github.com/armon/go-radix v1.0.1-0.20221118154546-54df44f2176c // indirect
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aws/aws-sdk-go-v2 v1.41.1
github.com/aws/aws-sdk-go-v2/config v1.32.1
github.com/aws/aws-sdk-go-v2/credentials v1.19.1 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.14 // indirect
github.com/aws/aws-sdk-go-v2/config v1.32.9
github.com/aws/aws-sdk-go-v2/credentials v1.19.9 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.17 // indirect
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.2
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.14 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.17 // indirect
github.com/aws/aws-sdk-go-v2/service/ssm v1.60.1 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.4 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.9 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.1 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.10 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.14 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
@@ -288,7 +296,7 @@ require (
github.com/dop251/goja v0.0.0-20241024094426-79f3a7efcdbd // indirect
github.com/dustin/go-humanize v1.0.1
github.com/eapache/queue/v2 v2.0.0-20230407133247-75960ed334e4 // indirect
github.com/ebitengine/purego v0.9.1 // indirect
github.com/ebitengine/purego v0.10.0-alpha.5 // indirect
github.com/elastic/go-windows v1.0.0 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
@@ -305,7 +313,7 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/gobwas/glob v0.2.3 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
@@ -321,10 +329,10 @@ require (
github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.12 // indirect
github.com/googleapis/gax-go/v2 v2.17.0 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-cty v1.5.0 // indirect
@@ -336,12 +344,11 @@ require (
github.com/hashicorp/hcl/v2 v2.24.0
github.com/hashicorp/logutils v1.0.0 // indirect
github.com/hashicorp/terraform-plugin-go v0.29.0 // indirect
github.com/hashicorp/terraform-plugin-log v0.9.0 // indirect
github.com/hashicorp/terraform-plugin-log v0.10.0 // indirect
github.com/hashicorp/terraform-plugin-sdk/v2 v2.38.1 // indirect
github.com/hdevalence/ed25519consensus v0.1.0 // indirect
github.com/illarion/gonotify v1.0.1 // indirect
github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 // indirect
github.com/jmespath/go-jmespath v0.4.1-0.20220621161143-b0104c826a24 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 // indirect
github.com/jsimonetti/rtnetlink v1.3.5 // indirect
@@ -388,14 +395,14 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect
github.com/riandyrn/otelchi v0.5.1 // indirect
github.com/richardartoul/molecule v1.0.1-0.20240531184615-7ca0df43c0b3 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b // indirect
github.com/secure-systems-lab/go-securesystemslib v0.9.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sirupsen/logrus v1.9.4 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/swaggo/files/v2 v2.0.0 // indirect
@@ -437,9 +444,9 @@ require (
go.opentelemetry.io/collector/pdata/pprofile v0.121.0 // indirect
go.opentelemetry.io/collector/semconv v0.123.0 // indirect
go.opentelemetry.io/contrib v1.19.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0
go.opentelemetry.io/otel/metric v1.39.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0
go.opentelemetry.io/otel/metric v1.40.0 // indirect
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
@@ -448,10 +455,10 @@ require (
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
google.golang.org/genproto v0.0.0-20260217215200-42d3e9bedb6d // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260217215200-42d3e9bedb6d // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260217215200-42d3e9bedb6d // indirect
gopkg.in/ini.v1 v1.67.1 // indirect
howett.net/plist v1.0.0 // indirect
kernel.org/pub/linux/libs/security/libcap/psx v1.2.77 // indirect
sigs.k8s.io/yaml v1.5.0 // indirect
@@ -464,12 +471,13 @@ require github.com/SherClockHolmes/webpush-go v1.4.0
require (
github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2 // indirect
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e // indirect
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
)
require (
charm.land/fantasy v0.8.1
github.com/anthropics/anthropic-sdk-go v1.19.0
github.com/brianvoe/gofakeit/v7 v7.14.0
github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225
@@ -491,17 +499,19 @@ require (
cel.dev/expr v0.25.1 // indirect
cloud.google.com/go v0.123.0 // indirect
cloud.google.com/go/iam v1.5.3 // indirect
cloud.google.com/go/logging v1.13.1 // indirect
cloud.google.com/go/logging v1.13.2 // indirect
cloud.google.com/go/longrunning v0.8.0 // indirect
cloud.google.com/go/monitoring v1.24.3 // indirect
cloud.google.com/go/storage v1.56.0 // indirect
cloud.google.com/go/storage v1.60.0 // indirect
git.sr.ht/~jackmordaunt/go-toast v1.1.2 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 // indirect
github.com/DataDog/datadog-agent/comp/core/tagger/origindetection v0.64.2 // indirect
github.com/DataDog/datadog-agent/pkg/version v0.64.2 // indirect
github.com/DataDog/dd-trace-go/v2 v2.0.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.53.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 // indirect
github.com/Masterminds/semver/v3 v3.3.1 // indirect
github.com/alecthomas/chroma v0.10.0 // indirect
github.com/aquasecurity/go-version v0.0.1 // indirect
@@ -509,24 +519,29 @@ require (
github.com/aquasecurity/jfather v0.0.8 // indirect
github.com/aquasecurity/trivy v0.61.1-0.20250407075540-f1329c7ea1aa // indirect
github.com/aquasecurity/trivy-checks v1.11.3-0.20250604022615-9a7efa7c9169 // indirect
github.com/aws/aws-sdk-go v1.55.7 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.3 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.0.1 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.17 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.8 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.17 // indirect
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.0 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.0.5 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d // indirect
github.com/bits-and-blooms/bitset v1.24.4 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20250904123553-b4e2667e5ad5 // indirect
github.com/charmbracelet/x/json v0.2.0 // indirect
github.com/clipperhouse/displaywidth v0.9.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 // indirect
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 // indirect
github.com/coder/paralleltestctx v0.0.1 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect
github.com/daixiang0/gci v0.13.7 // indirect
github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect
github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect
github.com/envoyproxy/go-control-plane/envoy v1.37.0 // indirect
github.com/envoyproxy/protoc-gen-validate v1.3.3 // indirect
github.com/esiqveland/notify v0.13.3 // indirect
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
github.com/go-git/go-billy/v5 v5.6.2 // indirect
@@ -534,19 +549,24 @@ require (
github.com/goccy/go-yaml v1.19.2 // indirect
github.com/google/go-containerregistry v0.20.6 // indirect
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect
github.com/hashicorp/go-getter v1.7.9 // indirect
github.com/hashicorp/go-safetemp v1.0.0 // indirect
github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.70 // indirect
github.com/hashicorp/go-getter v1.8.4 // indirect
github.com/hexops/gotextdiff v1.0.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect
github.com/jackmordaunt/icns/v3 v3.0.1 // indirect
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
github.com/kaptinlin/go-i18n v0.2.4 // indirect
github.com/kaptinlin/jsonpointer v0.4.10 // indirect
github.com/kaptinlin/jsonschema v0.6.10 // indirect
github.com/kaptinlin/messageformat-go v0.4.10 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/landlock-lsm/go-landlock v0.0.0-20251103212306-430f8e5cd97c // indirect
github.com/mattn/go-shellwords v1.0.12 // indirect
github.com/moby/sys/user v0.4.0 // indirect
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
github.com/openai/openai-go v1.12.0 // indirect
github.com/openai/openai-go/v2 v2.7.1 // indirect
github.com/openai/openai-go/v3 v3.15.0 // indirect
github.com/package-url/packageurl-go v0.1.3 // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
@@ -569,13 +589,13 @@ require (
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
go.opentelemetry.io/contrib/detectors/gcp v1.39.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.62.0 // indirect
go.opentelemetry.io/otel/sdk/metric v1.39.0 // indirect
go.opentelemetry.io/contrib/detectors/gcp v1.40.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0 // indirect
go.opentelemetry.io/otel/sdk/metric v1.40.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
go.yaml.in/yaml/v4 v4.0.0-rc.3 // indirect
golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4 // indirect
google.golang.org/genai v1.12.0 // indirect
google.golang.org/genai v1.47.0 // indirect
gopkg.in/warnings.v0 v0.1.2 // indirect
k8s.io/utils v0.0.0-20241210054802-24370beab758 // indirect
mvdan.cc/gofumpt v0.8.0 // indirect
+141 -1502
View File
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -44,7 +44,7 @@ func Test_absoluteBinaryPath(t *testing.T) {
{
name: "TestMalformedVersion",
terraformVersion: "version",
expectedErr: xerrors.Errorf("Terraform binary get version failed: Malformed version: version"),
expectedErr: xerrors.Errorf("Terraform binary get version failed: malformed version: version"),
},
}
// nolint:paralleltest
+33
View File
@@ -0,0 +1,33 @@
#!/usr/bin/env bash
set -euo pipefail
if [[ $# -ne 1 ]]; then
echo "usage: $0 <path-relative-to-site>" >&2
exit 2
fi
script_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
repo_root=$(cd "$script_dir/.." && pwd)
target=$1
output_file=$(mktemp)
trap 'rm -f "$output_file"' EXIT
if (
cd "$repo_root/site"
pnpm exec biome format --write "$target"
) >"$output_file" 2>&1; then
cat "$output_file"
exit 0
fi
status=$?
cat "$output_file" >&2
if [[ $status -eq 127 ]] || grep -q "Could not start dynamically linked executable" "$output_file" || grep -q "NixOS cannot run dynamically linked executables" "$output_file"; then
echo "WARNING: skipping biome format for '$target' because the biome binary is unavailable in this environment." >&2
exit 0
fi
exit $status
+4
View File
@@ -48,6 +48,10 @@ func prepareEnv() {
if err != nil {
panic(err)
}
err = os.Setenv("TMPDIR", "/tmp")
if err != nil {
panic(err)
}
}
func deleteEmptyDirs(dir string) error {
+16 -1
View File
@@ -1042,7 +1042,22 @@ const fillParameters = async (
case "number":
{
const parameterField = parameterLabel.locator("input");
await parameterField.fill(buildParameter.value);
// Dynamic parameters can hydrate after initial render and
// overwrite an early fill. Re-apply until the desired value
// is stable.
for (let attempt = 0; attempt < 3; attempt++) {
await parameterField.fill(buildParameter.value);
try {
await expect(parameterField).toHaveValue(buildParameter.value, {
timeout: 1000,
});
break;
} catch (error) {
if (attempt === 2) {
throw error;
}
}
}
}
break;
default:

Some files were not shown because too many files have changed in this diff Show More