mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add experimental agents support (#22290)
feat: add AI chat system with agent tools and chat UI Introduce the chatd subsystem and Agents UI for AI-powered chat within Coder workspaces. - Add chatd package with chat loop, message compaction, prompt management, and LLM provider integration (OpenAI, Anthropic) - Add agent tools: create workspace, list/read templates, read/write/ edit files, execute commands - Add chat API endpoints with streaming, message editing, and durable reconnection - Add database schema and migrations for chats, chat messages, chat providers, and chat model configs - Add RBAC policies and dbauthz enforcement for chat resources - Add Agents UI pages with conversation timeline, queued messages list, diff viewer, and model configuration panel - Add comprehensive test coverage including coderd integration tests, chatd unit tests, and Storybook stories - Gate feature behind experiments flag --------- Co-authored-by: Cian Johnston <cian@coder.com> Co-authored-by: Danielle Maywood <danielle@themaywoods.com> Co-authored-by: Jeremy Ruppel <jeremy@coder.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -854,7 +854,7 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
|
||||
# -C sets the directory for the go run command
|
||||
go run -C ./scripts/apitypings main.go > $@
|
||||
(cd site/ && pnpm exec biome format --write src/api/typesGenerated.ts)
|
||||
./scripts/biome_format.sh src/api/typesGenerated.ts
|
||||
touch "$@"
|
||||
|
||||
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
|
||||
@@ -863,7 +863,7 @@ site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/prot
|
||||
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*)
|
||||
go run ./scripts/gensite/ -icons "$@"
|
||||
(cd site/ && pnpm exec biome format --write src/theme/icons.json)
|
||||
./scripts/biome_format.sh src/theme/icons.json
|
||||
touch "$@"
|
||||
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates)
|
||||
@@ -901,12 +901,12 @@ codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scope
|
||||
|
||||
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
|
||||
go run scripts/typegen/main.go rbac typescript > "$@"
|
||||
(cd site/ && pnpm exec biome format --write src/api/rbacresourcesGenerated.ts)
|
||||
./scripts/biome_format.sh src/api/rbacresourcesGenerated.ts
|
||||
touch "$@"
|
||||
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go
|
||||
go run scripts/typegen/main.go countries > "$@"
|
||||
(cd site/ && pnpm exec biome format --write src/api/countriesGenerated.ts)
|
||||
./scripts/biome_format.sh src/api/countriesGenerated.ts
|
||||
touch "$@"
|
||||
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES)
|
||||
@@ -950,11 +950,11 @@ coderd/apidoc/.gen: \
|
||||
touch "$@"
|
||||
|
||||
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md
|
||||
(cd site/ && pnpm exec biome format --write ../docs/manifest.json)
|
||||
./scripts/biome_format.sh ../docs/manifest.json
|
||||
touch "$@"
|
||||
|
||||
coderd/apidoc/swagger.json: site/node_modules/.installed coderd/apidoc/.gen
|
||||
(cd site/ && pnpm exec biome format --write ../coderd/apidoc/swagger.json)
|
||||
./scripts/biome_format.sh ../coderd/apidoc/swagger.json
|
||||
touch "$@"
|
||||
|
||||
update-golden-files:
|
||||
@@ -999,11 +999,19 @@ enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/
|
||||
touch "$@"
|
||||
|
||||
helm/coder/tests/testdata/.gen-golden: $(wildcard helm/coder/tests/testdata/*.yaml) $(wildcard helm/coder/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/coder/tests/*_test.go)
|
||||
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
|
||||
if command -v helm >/dev/null 2>&1; then
|
||||
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
|
||||
else
|
||||
echo "WARNING: helm not found; skipping helm/coder golden generation" >&2
|
||||
fi
|
||||
touch "$@"
|
||||
|
||||
helm/provisioner/tests/testdata/.gen-golden: $(wildcard helm/provisioner/tests/testdata/*.yaml) $(wildcard helm/provisioner/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/provisioner/tests/*_test.go)
|
||||
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
|
||||
if command -v helm >/dev/null 2>&1; then
|
||||
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
|
||||
else
|
||||
echo "WARNING: helm not found; skipping helm/provisioner golden generation" >&2
|
||||
fi
|
||||
touch "$@"
|
||||
|
||||
coderd/.gen-golden: $(wildcard coderd/testdata/*/*.golden) $(GO_SRC_FILES) $(wildcard coderd/*_test.go)
|
||||
|
||||
+45
-1
@@ -4,6 +4,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
@@ -16,6 +19,29 @@ import (
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
// detectGitRef attempts to resolve the current git branch and remote
|
||||
// origin URL from the given working directory. These are sent to the
|
||||
// control plane so it can look up PR/diff status via the GitHub API
|
||||
// without SSHing into the workspace. Failures are silently ignored
|
||||
// since this is best-effort.
|
||||
func detectGitRef(workingDirectory string) (branch string, remoteOrigin string) {
|
||||
run := func(args ...string) string {
|
||||
//nolint:gosec
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
if workingDirectory != "" {
|
||||
cmd.Dir = workingDirectory
|
||||
}
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
branch = run("git", "rev-parse", "--abbrev-ref", "HEAD")
|
||||
remoteOrigin = run("git", "config", "--get", "remote.origin.url")
|
||||
return branch, remoteOrigin
|
||||
}
|
||||
|
||||
// gitAskpass is used by the Coder agent to automatically authenticate
|
||||
// with Git providers based on a hostname.
|
||||
func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
@@ -38,8 +64,20 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
workingDirectory, err := os.Getwd()
|
||||
if err != nil {
|
||||
workingDirectory = ""
|
||||
}
|
||||
|
||||
// Detect the current git branch and remote origin so
|
||||
// the control plane can resolve diffs without needing
|
||||
// to SSH back into the workspace.
|
||||
gitBranch, gitRemoteOrigin := detectGitRef(workingDirectory)
|
||||
|
||||
token, err := client.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
|
||||
Match: host,
|
||||
Match: host,
|
||||
GitBranch: gitBranch,
|
||||
GitRemoteOrigin: gitRemoteOrigin,
|
||||
})
|
||||
if err != nil {
|
||||
var apiError *codersdk.Error
|
||||
@@ -58,6 +96,12 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
return xerrors.Errorf("get git token: %w", err)
|
||||
}
|
||||
if token.URL != "" {
|
||||
// This is to help the agent authenticate with Git.
|
||||
if inv.Environ.Get("CODER_CHAT_AGENT") == "true" {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, `You must notify the user to authenticate with Git.\n\nThe URL is: %s\n`, token.URL)
|
||||
return cliui.ErrCanceled
|
||||
}
|
||||
|
||||
if err := openURL(inv, token.URL); err == nil {
|
||||
cliui.Infof(inv.Stderr, "Your browser has been opened to authenticate with Git:\n%s", token.URL)
|
||||
} else {
|
||||
|
||||
+111
-43
@@ -617,28 +617,8 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
}
|
||||
}
|
||||
|
||||
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("read external auth providers from env: %w", err)
|
||||
}
|
||||
|
||||
promRegistry := prometheus.NewRegistry()
|
||||
oauthInstrument := promoauth.NewFactory(promRegistry)
|
||||
vals.ExternalAuthConfigs.Value = append(vals.ExternalAuthConfigs.Value, extAuthEnv...)
|
||||
externalAuthConfigs, err := externalauth.ConvertConfig(
|
||||
oauthInstrument,
|
||||
vals.ExternalAuthConfigs.Value,
|
||||
vals.AccessURL.Value(),
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("convert external auth config: %w", err)
|
||||
}
|
||||
for _, c := range externalAuthConfigs {
|
||||
logger.Debug(
|
||||
ctx, "loaded external auth config",
|
||||
slog.F("id", c.ID),
|
||||
)
|
||||
}
|
||||
|
||||
realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
|
||||
if err != nil {
|
||||
@@ -669,7 +649,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
Pubsub: nil,
|
||||
CacheDir: cacheDir,
|
||||
GoogleTokenValidator: googleTokenValidator,
|
||||
ExternalAuthConfigs: externalAuthConfigs,
|
||||
ExternalAuthConfigs: nil,
|
||||
RealIPConfig: realIPConfig,
|
||||
SSHKeygenAlgorithm: sshKeygenAlgorithm,
|
||||
TracerProvider: tracerProvider,
|
||||
@@ -829,6 +809,40 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
return xerrors.Errorf("set deployment id: %w", err)
|
||||
}
|
||||
|
||||
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("read external auth providers from env: %w", err)
|
||||
}
|
||||
mergedExternalAuthProviders := append([]codersdk.ExternalAuthConfig{}, vals.ExternalAuthConfigs.Value...)
|
||||
mergedExternalAuthProviders = append(mergedExternalAuthProviders, extAuthEnv...)
|
||||
vals.ExternalAuthConfigs.Value = mergedExternalAuthProviders
|
||||
|
||||
mergedExternalAuthProviders, err = maybeAppendDefaultGithubExternalAuthProvider(
|
||||
ctx,
|
||||
options.Logger,
|
||||
options.Database,
|
||||
vals,
|
||||
mergedExternalAuthProviders,
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("maybe append default github external auth provider: %w", err)
|
||||
}
|
||||
|
||||
options.ExternalAuthConfigs, err = externalauth.ConvertConfig(
|
||||
oauthInstrument,
|
||||
mergedExternalAuthProviders,
|
||||
vals.AccessURL.Value(),
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("convert external auth config: %w", err)
|
||||
}
|
||||
for _, c := range options.ExternalAuthConfigs {
|
||||
logger.Debug(
|
||||
ctx, "loaded external auth config",
|
||||
slog.F("id", c.ID),
|
||||
)
|
||||
}
|
||||
|
||||
// Manage push notifications.
|
||||
experiments := coderd.ReadExperiments(options.Logger, options.DeploymentValues.Experiments.Value())
|
||||
if experiments.Enabled(codersdk.ExperimentWebPush) {
|
||||
@@ -1926,6 +1940,79 @@ type githubOAuth2ConfigParams struct {
|
||||
enterpriseBaseURL string
|
||||
}
|
||||
|
||||
func isDeploymentEligibleForGithubDefaultProvider(ctx context.Context, db database.Store) (bool, error) {
|
||||
// We want to enable the default provider only for new deployments, and avoid
|
||||
// enabling it if a deployment was upgraded from an older version.
|
||||
// nolint:gocritic // Requires system privileges
|
||||
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return false, xerrors.Errorf("get github default eligible: %w", err)
|
||||
}
|
||||
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
|
||||
|
||||
if defaultEligibleNotSet {
|
||||
// nolint:gocritic // User count requires system privileges
|
||||
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
|
||||
if err != nil {
|
||||
return false, xerrors.Errorf("get user count: %w", err)
|
||||
}
|
||||
// We check if a deployment is new by checking if it has any users.
|
||||
defaultEligible = userCount == 0
|
||||
// nolint:gocritic // Requires system privileges
|
||||
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
|
||||
return false, xerrors.Errorf("upsert github default eligible: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return defaultEligible, nil
|
||||
}
|
||||
|
||||
func maybeAppendDefaultGithubExternalAuthProvider(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
db database.Store,
|
||||
vals *codersdk.DeploymentValues,
|
||||
mergedExplicitProviders []codersdk.ExternalAuthConfig,
|
||||
) ([]codersdk.ExternalAuthConfig, error) {
|
||||
if !vals.ExternalAuthGithubDefaultProviderEnable.Value() {
|
||||
logger.Info(ctx, "default github external auth provider suppressed",
|
||||
slog.F("reason", "disabled by configuration"),
|
||||
slog.F("flag", "external-auth-github-default-provider-enable"),
|
||||
)
|
||||
return mergedExplicitProviders, nil
|
||||
}
|
||||
|
||||
if len(mergedExplicitProviders) > 0 {
|
||||
logger.Info(ctx, "default github external auth provider suppressed",
|
||||
slog.F("reason", "explicit external auth providers configured"),
|
||||
slog.F("provider_count", len(mergedExplicitProviders)),
|
||||
)
|
||||
return mergedExplicitProviders, nil
|
||||
}
|
||||
|
||||
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !defaultEligible {
|
||||
logger.Info(ctx, "default github external auth provider suppressed",
|
||||
slog.F("reason", "deployment is not eligible"),
|
||||
)
|
||||
return mergedExplicitProviders, nil
|
||||
}
|
||||
|
||||
logger.Info(ctx, "injecting default github external auth provider",
|
||||
slog.F("type", codersdk.EnhancedExternalAuthProviderGitHub.String()),
|
||||
slog.F("client_id", GithubOAuth2DefaultProviderClientID),
|
||||
slog.F("device_flow", GithubOAuth2DefaultProviderDeviceFlow),
|
||||
)
|
||||
return append(mergedExplicitProviders, codersdk.ExternalAuthConfig{
|
||||
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
|
||||
ClientID: GithubOAuth2DefaultProviderClientID,
|
||||
DeviceFlow: GithubOAuth2DefaultProviderDeviceFlow,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *codersdk.DeploymentValues) (*githubOAuth2ConfigParams, error) {
|
||||
params := githubOAuth2ConfigParams{
|
||||
accessURL: vals.AccessURL.Value(),
|
||||
@@ -1950,28 +2037,9 @@ func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *c
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
// Check if the deployment is eligible for the default GitHub OAuth2 provider.
|
||||
// We want to enable it only for new deployments, and avoid enabling it
|
||||
// if a deployment was upgraded from an older version.
|
||||
// nolint:gocritic // Requires system privileges
|
||||
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get github default eligible: %w", err)
|
||||
}
|
||||
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
|
||||
|
||||
if defaultEligibleNotSet {
|
||||
// nolint:gocritic // User count requires system privileges
|
||||
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get user count: %w", err)
|
||||
}
|
||||
// We check if a deployment is new by checking if it has any users.
|
||||
defaultEligible = userCount == 0
|
||||
// nolint:gocritic // Requires system privileges
|
||||
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
|
||||
return nil, xerrors.Errorf("upsert github default eligible: %w", err)
|
||||
}
|
||||
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !defaultEligible {
|
||||
|
||||
@@ -53,6 +53,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/migrations"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
"github.com/coder/coder/v2/coderd/userpassword"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/cryptorand"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
@@ -302,6 +303,7 @@ func TestServer(t *testing.T) {
|
||||
"open install.sh: file does not exist",
|
||||
"telemetry disabled, unable to notify of security issues",
|
||||
"installed terraform version newer than expected",
|
||||
"report generator",
|
||||
}
|
||||
|
||||
countLines := func(fullOutput string) int {
|
||||
@@ -1805,6 +1807,155 @@ func TestServer(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // This test sets environment variables.
|
||||
func TestServer_ExternalAuthGitHubDefaultProvider(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
args []string
|
||||
env map[string]string
|
||||
createUserPreStart bool
|
||||
expectedProviders []string
|
||||
}
|
||||
|
||||
run := func(t *testing.T, tc testCase) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
unsetPrefixedEnv := func(prefix string) {
|
||||
t.Helper()
|
||||
for _, envVar := range os.Environ() {
|
||||
envKey, _, found := strings.Cut(envVar, "=")
|
||||
if !found || !strings.HasPrefix(envKey, prefix) {
|
||||
continue
|
||||
}
|
||||
value, had := os.LookupEnv(envKey)
|
||||
require.True(t, had)
|
||||
require.NoError(t, os.Unsetenv(envKey))
|
||||
keyCopy := envKey
|
||||
valueCopy := value
|
||||
t.Cleanup(func() {
|
||||
// This is for setting/unsetting a number of prefixed env vars.
|
||||
// t.Setenv doesn't cover this use case.
|
||||
// nolint:usetesting
|
||||
_ = os.Setenv(keyCopy, valueCopy)
|
||||
})
|
||||
}
|
||||
}
|
||||
unsetPrefixedEnv("CODER_EXTERNAL_AUTH_")
|
||||
unsetPrefixedEnv("CODER_GITAUTH_")
|
||||
|
||||
dbURL, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
db, _ := dbtestutil.NewDB(t, dbtestutil.WithURL(dbURL))
|
||||
|
||||
const (
|
||||
existingUserEmail = "existing-user@coder.com"
|
||||
existingUserUsername = "existing-user"
|
||||
existingUserPassword = "SomeSecurePassword!"
|
||||
)
|
||||
if tc.createUserPreStart {
|
||||
hashedPassword, err := userpassword.Hash(existingUserPassword)
|
||||
require.NoError(t, err)
|
||||
_ = dbgen.User(t, db, database.User{
|
||||
Email: existingUserEmail,
|
||||
Username: existingUserUsername,
|
||||
HashedPassword: []byte(hashedPassword),
|
||||
})
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"server",
|
||||
"--postgres-url", dbURL,
|
||||
"--http-address", ":0",
|
||||
"--access-url", "https://example.com",
|
||||
}
|
||||
args = append(args, tc.args...)
|
||||
|
||||
inv, cfg := clitest.New(t, args...)
|
||||
for envKey, value := range tc.env {
|
||||
t.Setenv(envKey, value)
|
||||
}
|
||||
clitest.Start(t, inv)
|
||||
|
||||
accessURL := waitAccessURL(t, cfg)
|
||||
client := codersdk.New(accessURL)
|
||||
|
||||
if tc.createUserPreStart {
|
||||
loginResp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{
|
||||
Email: existingUserEmail,
|
||||
Password: existingUserPassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
client.SetSessionToken(loginResp.SessionToken)
|
||||
} else {
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
}
|
||||
|
||||
externalAuthResp, err := client.ListExternalAuths(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
gotProviders := map[string]codersdk.ExternalAuthLinkProvider{}
|
||||
for _, provider := range externalAuthResp.Providers {
|
||||
gotProviders[provider.ID] = provider
|
||||
}
|
||||
require.Len(t, gotProviders, len(tc.expectedProviders))
|
||||
|
||||
for _, providerID := range tc.expectedProviders {
|
||||
provider, ok := gotProviders[providerID]
|
||||
require.Truef(t, ok, "expected provider %q to be configured", providerID)
|
||||
if providerID == codersdk.EnhancedExternalAuthProviderGitHub.String() {
|
||||
require.Equal(t, codersdk.EnhancedExternalAuthProviderGitHub.String(), provider.Type)
|
||||
require.True(t, provider.Device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range []testCase{
|
||||
{
|
||||
name: "NewDeployment_NoExplicitProviders_InjectsDefaultGithub",
|
||||
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitHub.String()},
|
||||
},
|
||||
{
|
||||
name: "ExistingDeployment_DoesNotInjectDefaultGithub",
|
||||
createUserPreStart: true,
|
||||
expectedProviders: nil,
|
||||
},
|
||||
{
|
||||
name: "DefaultProviderDisabled_DoesNotInjectDefaultGithub",
|
||||
args: []string{
|
||||
"--external-auth-github-default-provider-enable=false",
|
||||
},
|
||||
expectedProviders: nil,
|
||||
},
|
||||
{
|
||||
name: "ExplicitProviderViaConfig_DoesNotInjectDefaultGithub",
|
||||
args: []string{
|
||||
`--external-auth-providers=[{"type":"gitlab","client_id":"config-client-id"}]`,
|
||||
},
|
||||
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
|
||||
},
|
||||
{
|
||||
name: "ExplicitProviderViaEnv_DoesNotInjectDefaultGithub",
|
||||
env: map[string]string{
|
||||
"CODER_EXTERNAL_AUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
|
||||
"CODER_EXTERNAL_AUTH_0_CLIENT_ID": "env-client-id",
|
||||
},
|
||||
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
|
||||
},
|
||||
{
|
||||
name: "ExplicitProviderViaLegacyEnv_DoesNotInjectDefaultGithub",
|
||||
env: map[string]string{
|
||||
"CODER_GITAUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
|
||||
"CODER_GITAUTH_0_CLIENT_ID": "legacy-env-client-id",
|
||||
},
|
||||
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
run(t, tc)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // This test sets environment variables.
|
||||
func TestServer_Logging_NoParallel(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
+3
@@ -62,6 +62,9 @@ OPTIONS:
|
||||
Separate multiple experiments with commas, or enter '*' to opt-in to
|
||||
all available experiments.
|
||||
|
||||
--external-auth-github-default-provider-enable bool, $CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE (default: true)
|
||||
Enable the default GitHub external auth provider managed by Coder.
|
||||
|
||||
--postgres-auth password|awsiamrds, $CODER_PG_AUTH (default: password)
|
||||
Type of auth to use when connecting to postgres. For AWS RDS, using
|
||||
IAM authentication (awsiamrds) is recommended.
|
||||
|
||||
+3
@@ -564,6 +564,9 @@ supportLinks: []
|
||||
# External Authentication providers.
|
||||
# (default: <unset>, type: struct[[]codersdk.ExternalAuthConfig])
|
||||
externalAuthProviders: []
|
||||
# Enable the default GitHub external auth provider managed by Coder.
|
||||
# (default: true, type: bool)
|
||||
externalAuthGithubDefaultProviderEnable: true
|
||||
# Hostname of HTTPS server that runs https://github.com/coder/wgtunnel. By
|
||||
# default, this will pick the best available wgtunnel server hosted by Coder. e.g.
|
||||
# "tunnel.example.com".
|
||||
|
||||
+2
-1
@@ -192,7 +192,8 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
defer commitAuditWS()
|
||||
|
||||
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, r, &createWorkspaceOptions{
|
||||
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, &createWorkspaceOptions{
|
||||
remoteAddr: r.RemoteAddr,
|
||||
// Before creating the workspace, ensure that this task can be created.
|
||||
preCreateInTX: func(ctx context.Context, tx database.Store) error {
|
||||
// Create task record in the database before creating the workspace so that
|
||||
|
||||
Generated
+19
@@ -12783,6 +12783,11 @@ const docTemplate = `{
|
||||
"boundary_usage:delete",
|
||||
"boundary_usage:read",
|
||||
"boundary_usage:update",
|
||||
"chat:*",
|
||||
"chat:create",
|
||||
"chat:delete",
|
||||
"chat:read",
|
||||
"chat:update",
|
||||
"coder:all",
|
||||
"coder:apikeys.manage_self",
|
||||
"coder:application_connect",
|
||||
@@ -12987,6 +12992,11 @@ const docTemplate = `{
|
||||
"APIKeyScopeBoundaryUsageDelete",
|
||||
"APIKeyScopeBoundaryUsageRead",
|
||||
"APIKeyScopeBoundaryUsageUpdate",
|
||||
"APIKeyScopeChatAll",
|
||||
"APIKeyScopeChatCreate",
|
||||
"APIKeyScopeChatDelete",
|
||||
"APIKeyScopeChatRead",
|
||||
"APIKeyScopeChatUpdate",
|
||||
"APIKeyScopeCoderAll",
|
||||
"APIKeyScopeCoderApikeysManageSelf",
|
||||
"APIKeyScopeCoderApplicationConnect",
|
||||
@@ -14848,6 +14858,9 @@ const docTemplate = `{
|
||||
"external_auth": {
|
||||
"$ref": "#/definitions/serpent.Struct-array_codersdk_ExternalAuthConfig"
|
||||
},
|
||||
"external_auth_github_default_provider_enable": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"external_token_encryption_keys": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
@@ -15132,9 +15145,11 @@ const docTemplate = `{
|
||||
"workspace-usage",
|
||||
"web-push",
|
||||
"oauth2",
|
||||
"agents",
|
||||
"mcp-server-http"
|
||||
],
|
||||
"x-enum-comments": {
|
||||
"ExperimentAgents": "Enables agent-powered chat functionality.",
|
||||
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
|
||||
"ExperimentExample": "This isn't used for anything.",
|
||||
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
|
||||
@@ -15150,6 +15165,7 @@ const docTemplate = `{
|
||||
"Enables the new workspace usage tracking.",
|
||||
"Enables web push notifications through the browser.",
|
||||
"Enables OAuth2 provider functionality.",
|
||||
"Enables agent-powered chat functionality.",
|
||||
"Enables the MCP HTTP server functionality."
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
@@ -15159,6 +15175,7 @@ const docTemplate = `{
|
||||
"ExperimentWorkspaceUsage",
|
||||
"ExperimentWebPush",
|
||||
"ExperimentOAuth2",
|
||||
"ExperimentAgents",
|
||||
"ExperimentMCPServerHTTP"
|
||||
]
|
||||
},
|
||||
@@ -18099,6 +18116,7 @@ const docTemplate = `{
|
||||
"assign_role",
|
||||
"audit_log",
|
||||
"boundary_usage",
|
||||
"chat",
|
||||
"connection_log",
|
||||
"crypto_key",
|
||||
"debug_info",
|
||||
@@ -18144,6 +18162,7 @@ const docTemplate = `{
|
||||
"ResourceAssignRole",
|
||||
"ResourceAuditLog",
|
||||
"ResourceBoundaryUsage",
|
||||
"ResourceChat",
|
||||
"ResourceConnectionLog",
|
||||
"ResourceCryptoKey",
|
||||
"ResourceDebugInfo",
|
||||
|
||||
Generated
+19
@@ -11387,6 +11387,11 @@
|
||||
"boundary_usage:delete",
|
||||
"boundary_usage:read",
|
||||
"boundary_usage:update",
|
||||
"chat:*",
|
||||
"chat:create",
|
||||
"chat:delete",
|
||||
"chat:read",
|
||||
"chat:update",
|
||||
"coder:all",
|
||||
"coder:apikeys.manage_self",
|
||||
"coder:application_connect",
|
||||
@@ -11591,6 +11596,11 @@
|
||||
"APIKeyScopeBoundaryUsageDelete",
|
||||
"APIKeyScopeBoundaryUsageRead",
|
||||
"APIKeyScopeBoundaryUsageUpdate",
|
||||
"APIKeyScopeChatAll",
|
||||
"APIKeyScopeChatCreate",
|
||||
"APIKeyScopeChatDelete",
|
||||
"APIKeyScopeChatRead",
|
||||
"APIKeyScopeChatUpdate",
|
||||
"APIKeyScopeCoderAll",
|
||||
"APIKeyScopeCoderApikeysManageSelf",
|
||||
"APIKeyScopeCoderApplicationConnect",
|
||||
@@ -13378,6 +13388,9 @@
|
||||
"external_auth": {
|
||||
"$ref": "#/definitions/serpent.Struct-array_codersdk_ExternalAuthConfig"
|
||||
},
|
||||
"external_auth_github_default_provider_enable": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"external_token_encryption_keys": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
@@ -13655,9 +13668,11 @@
|
||||
"workspace-usage",
|
||||
"web-push",
|
||||
"oauth2",
|
||||
"agents",
|
||||
"mcp-server-http"
|
||||
],
|
||||
"x-enum-comments": {
|
||||
"ExperimentAgents": "Enables agent-powered chat functionality.",
|
||||
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
|
||||
"ExperimentExample": "This isn't used for anything.",
|
||||
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
|
||||
@@ -13673,6 +13688,7 @@
|
||||
"Enables the new workspace usage tracking.",
|
||||
"Enables web push notifications through the browser.",
|
||||
"Enables OAuth2 provider functionality.",
|
||||
"Enables agent-powered chat functionality.",
|
||||
"Enables the MCP HTTP server functionality."
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
@@ -13682,6 +13698,7 @@
|
||||
"ExperimentWorkspaceUsage",
|
||||
"ExperimentWebPush",
|
||||
"ExperimentOAuth2",
|
||||
"ExperimentAgents",
|
||||
"ExperimentMCPServerHTTP"
|
||||
]
|
||||
},
|
||||
@@ -16507,6 +16524,7 @@
|
||||
"assign_role",
|
||||
"audit_log",
|
||||
"boundary_usage",
|
||||
"chat",
|
||||
"connection_log",
|
||||
"crypto_key",
|
||||
"debug_info",
|
||||
@@ -16552,6 +16570,7 @@
|
||||
"ResourceAssignRole",
|
||||
"ResourceAuditLog",
|
||||
"ResourceBoundaryUsage",
|
||||
"ResourceChat",
|
||||
"ResourceConnectionLog",
|
||||
"ResourceCryptoKey",
|
||||
"ResourceDebugInfo",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
@@ -91,6 +93,36 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action policy.Action, object
|
||||
return true
|
||||
}
|
||||
|
||||
// AuthorizeContext checks whether the RBAC subject on the context
|
||||
// is authorized to perform the given action. The subject must have
|
||||
// been set via dbauthz.As or the ExtractAPIKey middleware. Returns
|
||||
// false if the subject is missing or unauthorized.
|
||||
func (h *HTTPAuthorizer) AuthorizeContext(ctx context.Context, action policy.Action, object rbac.Objecter) bool {
|
||||
roles, ok := dbauthz.ActorFromContext(ctx)
|
||||
if !ok {
|
||||
h.Logger.Error(ctx, "no authorization actor in context")
|
||||
return false
|
||||
}
|
||||
err := h.Authorizer.Authorize(ctx, roles, action, object.RBACObject())
|
||||
if err != nil {
|
||||
internalError := new(rbac.UnauthorizedError)
|
||||
logger := h.Logger
|
||||
if xerrors.As(err, internalError) {
|
||||
logger = h.Logger.With(slog.F("internal_error", internalError.Internal()))
|
||||
}
|
||||
logger.Warn(ctx, "requester is not authorized to access the object",
|
||||
slog.F("roles", roles.SafeRoleNames()),
|
||||
slog.F("actor_id", roles.ID),
|
||||
slog.F("actor_name", roles),
|
||||
slog.F("scope", roles.SafeScopeName()),
|
||||
slog.F("action", action),
|
||||
slog.F("object", object),
|
||||
)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// AuthorizeSQLFilter returns an authorization filter that can used in a
|
||||
// SQL 'WHERE' clause. If the filter is used, the resulting rows returned
|
||||
// from postgres are already authorized, and the caller does not need to
|
||||
@@ -106,6 +138,22 @@ func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action policy.Actio
|
||||
return prepared, nil
|
||||
}
|
||||
|
||||
// AuthorizeSQLFilterContext is like AuthorizeSQLFilter but reads the
|
||||
// RBAC subject from the context directly rather than from an
|
||||
// *http.Request. The subject must have been set via dbauthz.As.
|
||||
func (h *HTTPAuthorizer) AuthorizeSQLFilterContext(ctx context.Context, action policy.Action, objectType string) (rbac.PreparedAuthorized, error) {
|
||||
roles, ok := dbauthz.ActorFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, xerrors.New("no authorization actor in context")
|
||||
}
|
||||
prepared, err := h.Authorizer.Prepare(ctx, roles, action, objectType)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("prepare filter: %w", err)
|
||||
}
|
||||
|
||||
return prepared, nil
|
||||
}
|
||||
|
||||
// checkAuthorization returns if the current API key can use the given
|
||||
// permissions, factoring in the current user's roles and the API key scopes.
|
||||
//
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,461 @@
|
||||
package chatd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replicaA := newTestServer(t, db, ps, uuid.New())
|
||||
replicaB := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replicaA.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "interrupt-me",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
runningWorker := uuid.New()
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: runningWorker, Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, events, cancel, ok := replicaB.Subscribe(ctx, chat.ID, nil)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
updated := replicaA.InterruptChat(ctx, chat)
|
||||
require.Equal(t, database.ChatStatusWaiting, updated.Status)
|
||||
require.False(t, updated.WorkerID.Valid)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
|
||||
return false
|
||||
}
|
||||
return event.Status.Status == codersdk.ChatStatusWaiting
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestInterruptChatClearsWorkerInDatabase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "db-transition",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated := replica.InterruptChat(ctx, chat)
|
||||
require.Equal(t, database.ChatStatusWaiting, updated.Status)
|
||||
require.False(t, updated.WorkerID.Valid)
|
||||
|
||||
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
||||
require.False(t, fromDB.WorkerID.Valid)
|
||||
}
|
||||
|
||||
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "heartbeat-ownership",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
workerID := uuid.New()
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
WorkerID: uuid.New(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), rows)
|
||||
|
||||
rows, err = db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
WorkerID: workerID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), rows)
|
||||
}
|
||||
|
||||
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "queue-when-busy",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
workerID := uuid.New()
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "queued"}},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Queued)
|
||||
require.NotNil(t, result.QueuedMessage)
|
||||
require.Equal(t, database.ChatStatusRunning, result.Chat.Status)
|
||||
require.Equal(t, workerID, result.Chat.WorkerID.UUID)
|
||||
require.True(t, result.Chat.WorkerID.Valid)
|
||||
|
||||
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queued, 1)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
}
|
||||
|
||||
func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "interrupt-when-busy",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "interrupt"}},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.Queued)
|
||||
require.Equal(t, database.ChatStatusPending, result.Chat.Status)
|
||||
require.False(t, result.Chat.WorkerID.Valid)
|
||||
|
||||
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusPending, fromDB.Status)
|
||||
require.False(t, fromDB.WorkerID.Valid)
|
||||
|
||||
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queued, 0)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 2)
|
||||
require.Equal(t, messages[len(messages)-1].ID, result.Message.ID)
|
||||
}
|
||||
|
||||
func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "edit-message",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "original"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
initialMessages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, initialMessages, 1)
|
||||
editedMessageID := initialMessages[0].ID
|
||||
|
||||
_, err = replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "follow-up"}},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "another"}},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
||||
ChatID: chat.ID,
|
||||
Content: json.RawMessage(`"queued"`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
editResult, err := replica.EditMessage(ctx, chatd.EditMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
EditedMessageID: editedMessageID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, editedMessageID, editResult.Message.ID)
|
||||
require.Equal(t, database.ChatStatusPending, editResult.Chat.Status)
|
||||
require.False(t, editResult.Chat.WorkerID.Valid)
|
||||
|
||||
editedSDK := db2sdk.ChatMessage(editResult.Message)
|
||||
require.Len(t, editedSDK.Content, 1)
|
||||
require.Equal(t, "edited", editedSDK.Content[0].Text)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
require.Equal(t, editedMessageID, messages[0].ID)
|
||||
onlyMessage := db2sdk.ChatMessage(messages[0])
|
||||
require.Len(t, onlyMessage.Content, 1)
|
||||
require.Equal(t, "edited", onlyMessage.Content[0].Text)
|
||||
|
||||
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queued, 0)
|
||||
|
||||
chatFromDB, err := db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusPending, chatFromDB.Status)
|
||||
require.False(t, chatFromDB.WorkerID.Valid)
|
||||
}
|
||||
|
||||
func TestEditMessageRejectsMissingMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "missing-edited-message",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
EditedMessageID: 999999,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, chatd.ErrEditedMessageNotFound))
|
||||
}
|
||||
|
||||
func TestEditMessageRejectsNonUserMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "non-user-edited-message",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assistantMessage, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`"assistant"`),
|
||||
Valid: true,
|
||||
},
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{},
|
||||
OutputTokens: sql.NullInt64{},
|
||||
TotalTokens: sql.NullInt64{},
|
||||
ReasoningTokens: sql.NullInt64{},
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
EditedMessageID: assistantMessage.ID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, chatd.ErrEditedMessageNotUser))
|
||||
}
|
||||
|
||||
func newTestServer(
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
ps dbpubsub.Pubsub,
|
||||
replicaID uuid.UUID,
|
||||
) *chatd.Server {
|
||||
t.Helper()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: replicaID,
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: testutil.WaitSuperLong,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
return server
|
||||
}
|
||||
|
||||
func seedChatDependencies(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
) (database.User, database.ChatModelConfig) {
|
||||
t.Helper()
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: "",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return user, model
|
||||
}
|
||||
@@ -0,0 +1,676 @@
|
||||
package chatloop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
const (
|
||||
interruptedToolResultErrorMessage = "tool call was interrupted before it produced a result"
|
||||
)
|
||||
|
||||
var ErrInterrupted = xerrors.New("chat interrupted")
|
||||
|
||||
// PersistedStep contains the full content of a completed or
|
||||
// interrupted agent step. Content includes both assistant blocks
|
||||
// (text, reasoning, tool calls) and tool result blocks, mirroring
|
||||
// what fantasy provides in StepResult.Content. The persistence
|
||||
// layer is responsible for splitting these into separate database
|
||||
// messages by role.
|
||||
type PersistedStep struct {
|
||||
Content []fantasy.Content
|
||||
Usage fantasy.Usage
|
||||
ContextLimit sql.NullInt64
|
||||
}
|
||||
|
||||
// RunOptions configures a single streaming chat loop run.
|
||||
type RunOptions struct {
|
||||
Model fantasy.LanguageModel
|
||||
Messages []fantasy.Message
|
||||
Tools []fantasy.AgentTool
|
||||
StreamCall fantasy.AgentStreamCall
|
||||
MaxSteps int
|
||||
|
||||
ActiveTools []string
|
||||
ContextLimitFallback int64
|
||||
|
||||
PersistStep func(context.Context, PersistedStep) error
|
||||
PublishMessagePart func(
|
||||
role fantasy.MessageRole,
|
||||
part codersdk.ChatMessagePart,
|
||||
)
|
||||
Compaction *CompactionOptions
|
||||
|
||||
OnInterruptedPersistError func(error)
|
||||
}
|
||||
|
||||
// Run executes the chat step-stream loop and delegates persistence/publishing to callbacks.
|
||||
func Run(ctx context.Context, opts RunOptions) (*fantasy.AgentResult, error) {
|
||||
if opts.Model == nil {
|
||||
return nil, xerrors.New("chat model is required")
|
||||
}
|
||||
if opts.PersistStep == nil {
|
||||
return nil, xerrors.New("persist step callback is required")
|
||||
}
|
||||
if opts.MaxSteps <= 0 {
|
||||
opts.MaxSteps = 1
|
||||
}
|
||||
|
||||
publishMessagePart := func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
|
||||
if opts.PublishMessagePart == nil {
|
||||
return
|
||||
}
|
||||
opts.PublishMessagePart(role, part)
|
||||
}
|
||||
|
||||
var (
|
||||
stepStateMu sync.Mutex
|
||||
streamToolNames map[string]string
|
||||
streamReasoningTitles map[string]string
|
||||
streamReasoningText map[string]string
|
||||
// stepToolResultContents tracks tool results received during
|
||||
// streaming. These are needed for the interrupted-step path
|
||||
// where OnStepFinish never fires.
|
||||
stepToolResultContents []fantasy.ToolResultContent
|
||||
stepAssistantDraft []fantasy.Content
|
||||
stepToolCallIndexByID map[string]int
|
||||
)
|
||||
|
||||
resetStepState := func() {
|
||||
stepStateMu.Lock()
|
||||
streamToolNames = make(map[string]string)
|
||||
streamReasoningTitles = make(map[string]string)
|
||||
streamReasoningText = make(map[string]string)
|
||||
stepToolResultContents = nil
|
||||
stepAssistantDraft = nil
|
||||
stepToolCallIndexByID = make(map[string]int)
|
||||
stepStateMu.Unlock()
|
||||
}
|
||||
|
||||
setReasoningTitleFromText := func(id string, text string) {
|
||||
if id == "" || strings.TrimSpace(text) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
stepStateMu.Lock()
|
||||
defer stepStateMu.Unlock()
|
||||
|
||||
if streamReasoningTitles[id] != "" {
|
||||
return
|
||||
}
|
||||
|
||||
streamReasoningText[id] += text
|
||||
if !strings.ContainsAny(streamReasoningText[id], "\r\n") {
|
||||
return
|
||||
}
|
||||
title := chatprompt.ReasoningTitleFromFirstLine(streamReasoningText[id])
|
||||
if title == "" {
|
||||
return
|
||||
}
|
||||
|
||||
streamReasoningTitles[id] = title
|
||||
}
|
||||
|
||||
appendDraftText := func(text string) {
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
|
||||
stepStateMu.Lock()
|
||||
defer stepStateMu.Unlock()
|
||||
|
||||
if len(stepAssistantDraft) > 0 {
|
||||
lastIndex := len(stepAssistantDraft) - 1
|
||||
switch last := stepAssistantDraft[lastIndex].(type) {
|
||||
case fantasy.TextContent:
|
||||
last.Text += text
|
||||
stepAssistantDraft[lastIndex] = last
|
||||
return
|
||||
case *fantasy.TextContent:
|
||||
last.Text += text
|
||||
stepAssistantDraft[lastIndex] = fantasy.TextContent{Text: last.Text}
|
||||
return
|
||||
}
|
||||
}
|
||||
stepAssistantDraft = append(stepAssistantDraft, fantasy.TextContent{Text: text})
|
||||
}
|
||||
|
||||
appendDraftReasoning := func(text string) {
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
|
||||
stepStateMu.Lock()
|
||||
defer stepStateMu.Unlock()
|
||||
|
||||
if len(stepAssistantDraft) > 0 {
|
||||
lastIndex := len(stepAssistantDraft) - 1
|
||||
switch last := stepAssistantDraft[lastIndex].(type) {
|
||||
case fantasy.ReasoningContent:
|
||||
last.Text += text
|
||||
stepAssistantDraft[lastIndex] = last
|
||||
return
|
||||
case *fantasy.ReasoningContent:
|
||||
last.Text += text
|
||||
stepAssistantDraft[lastIndex] = fantasy.ReasoningContent{Text: last.Text}
|
||||
return
|
||||
}
|
||||
}
|
||||
stepAssistantDraft = append(stepAssistantDraft, fantasy.ReasoningContent{Text: text})
|
||||
}
|
||||
|
||||
upsertDraftToolCall := func(toolCallID, toolName, input string, appendInput bool) {
|
||||
if toolCallID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
stepStateMu.Lock()
|
||||
defer stepStateMu.Unlock()
|
||||
|
||||
if strings.TrimSpace(toolName) != "" {
|
||||
streamToolNames[toolCallID] = toolName
|
||||
}
|
||||
|
||||
index, exists := stepToolCallIndexByID[toolCallID]
|
||||
if !exists {
|
||||
stepToolCallIndexByID[toolCallID] = len(stepAssistantDraft)
|
||||
stepAssistantDraft = append(stepAssistantDraft, fantasy.ToolCallContent{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Input: input,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if index < 0 || index >= len(stepAssistantDraft) {
|
||||
stepToolCallIndexByID[toolCallID] = len(stepAssistantDraft)
|
||||
stepAssistantDraft = append(stepAssistantDraft, fantasy.ToolCallContent{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Input: input,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
existingCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](stepAssistantDraft[index])
|
||||
if !ok {
|
||||
if ptrCall, ptrOK := fantasy.AsContentType[*fantasy.ToolCallContent](stepAssistantDraft[index]); ptrOK && ptrCall != nil {
|
||||
existingCall = *ptrCall
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
stepToolCallIndexByID[toolCallID] = len(stepAssistantDraft)
|
||||
stepAssistantDraft = append(stepAssistantDraft, fantasy.ToolCallContent{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Input: input,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(toolName) != "" {
|
||||
existingCall.ToolName = toolName
|
||||
}
|
||||
if appendInput {
|
||||
existingCall.Input += input
|
||||
} else if input != "" || existingCall.Input == "" {
|
||||
existingCall.Input = input
|
||||
}
|
||||
stepAssistantDraft[index] = existingCall
|
||||
}
|
||||
|
||||
appendDraftSource := func(source fantasy.SourceContent) {
|
||||
stepStateMu.Lock()
|
||||
stepAssistantDraft = append(stepAssistantDraft, source)
|
||||
stepStateMu.Unlock()
|
||||
}
|
||||
|
||||
persistInterruptedStep := func() error {
|
||||
stepStateMu.Lock()
|
||||
draft := append([]fantasy.Content(nil), stepAssistantDraft...)
|
||||
toolResults := append([]fantasy.ToolResultContent(nil), stepToolResultContents...)
|
||||
toolNameByCallID := make(map[string]string, len(streamToolNames))
|
||||
for id, name := range streamToolNames {
|
||||
toolNameByCallID[id] = name
|
||||
}
|
||||
stepStateMu.Unlock()
|
||||
|
||||
if len(draft) == 0 && len(toolResults) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Track which tool calls already have results.
|
||||
answeredToolCalls := make(map[string]struct{}, len(toolResults))
|
||||
for _, tr := range toolResults {
|
||||
if tr.ToolCallID != "" {
|
||||
answeredToolCalls[tr.ToolCallID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the combined content: draft + received tool results
|
||||
// + synthetic interrupted results for unanswered tool calls.
|
||||
content := make([]fantasy.Content, 0, len(draft)+len(toolResults))
|
||||
content = append(content, draft...)
|
||||
for _, tr := range toolResults {
|
||||
content = append(content, tr)
|
||||
}
|
||||
|
||||
for _, block := range draft {
|
||||
toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block)
|
||||
if !ok {
|
||||
if ptrCall, ptrOK := fantasy.AsContentType[*fantasy.ToolCallContent](block); ptrOK && ptrCall != nil {
|
||||
toolCall = *ptrCall
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
if !ok || toolCall.ToolCallID == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := answeredToolCalls[toolCall.ToolCallID]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
toolName := strings.TrimSpace(toolCall.ToolName)
|
||||
if toolName == "" {
|
||||
toolName = strings.TrimSpace(toolNameByCallID[toolCall.ToolCallID])
|
||||
}
|
||||
|
||||
content = append(content, fantasy.ToolResultContent{
|
||||
ToolCallID: toolCall.ToolCallID,
|
||||
ToolName: toolName,
|
||||
Result: fantasy.ToolResultOutputContentError{
|
||||
Error: xerrors.New(interruptedToolResultErrorMessage),
|
||||
},
|
||||
})
|
||||
answeredToolCalls[toolCall.ToolCallID] = struct{}{}
|
||||
}
|
||||
|
||||
persistCtx := context.WithoutCancel(ctx)
|
||||
return opts.PersistStep(persistCtx, PersistedStep{
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
resetStepState()
|
||||
|
||||
agent := fantasy.NewAgent(
|
||||
opts.Model,
|
||||
fantasy.WithTools(opts.Tools...),
|
||||
fantasy.WithStopConditions(fantasy.StepCountIs(opts.MaxSteps)),
|
||||
)
|
||||
applyAnthropicCaching := shouldApplyAnthropicPromptCaching(opts.Model)
|
||||
// Fantasy's AgentStreamCall currently requires a non-empty Prompt and always
|
||||
// appends it as a user message. chatd already supplies the full history in
|
||||
// Messages, so we pass and then strip a sentinel user message in PrepareStep.
|
||||
sentinelPrompt := "__chatd_agent_prompt_sentinel_" + uuid.NewString()
|
||||
|
||||
streamCall := opts.StreamCall
|
||||
streamCall.Prompt = sentinelPrompt
|
||||
streamCall.Messages = opts.Messages
|
||||
streamCall.PrepareStep = func(
|
||||
stepCtx context.Context,
|
||||
options fantasy.PrepareStepFunctionOptions,
|
||||
) (context.Context, fantasy.PrepareStepResult, error) {
|
||||
return stepCtx, prepareStepResult(
|
||||
options.Messages,
|
||||
sentinelPrompt,
|
||||
opts.ActiveTools,
|
||||
applyAnthropicCaching,
|
||||
), nil
|
||||
}
|
||||
streamCall.OnStepStart = func(_ int) error {
|
||||
resetStepState()
|
||||
return nil
|
||||
}
|
||||
streamCall.OnTextDelta = func(_ string, text string) error {
|
||||
appendDraftText(text)
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: text,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
streamCall.OnReasoningDelta = func(id string, text string) error {
|
||||
appendDraftReasoning(text)
|
||||
setReasoningTitleFromText(id, text)
|
||||
stepStateMu.Lock()
|
||||
title := streamReasoningTitles[id]
|
||||
stepStateMu.Unlock()
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: text,
|
||||
Title: title,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
streamCall.OnReasoningEnd = func(id string, _ fantasy.ReasoningContent) error {
|
||||
stepStateMu.Lock()
|
||||
if streamReasoningTitles[id] == "" {
|
||||
// At the end of reasoning we have the full text, so we can
|
||||
// safely evaluate first-line title format even if no newline
|
||||
// ever arrived in deltas.
|
||||
streamReasoningTitles[id] = chatprompt.ReasoningTitleFromFirstLine(
|
||||
streamReasoningText[id],
|
||||
)
|
||||
}
|
||||
title := streamReasoningTitles[id]
|
||||
stepStateMu.Unlock()
|
||||
if title != "" {
|
||||
// Publish a title-only reasoning part so clients can update the
|
||||
// reasoning header when metadata arrives at the end of streaming.
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Title: title,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
streamCall.OnToolInputStart = func(id, toolName string) error {
|
||||
upsertDraftToolCall(id, toolName, "", false)
|
||||
return nil
|
||||
}
|
||||
streamCall.OnToolInputDelta = func(id, delta string) error {
|
||||
stepStateMu.Lock()
|
||||
toolName := streamToolNames[id]
|
||||
stepStateMu.Unlock()
|
||||
upsertDraftToolCall(id, toolName, delta, true)
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: id,
|
||||
ToolName: toolName,
|
||||
ArgsDelta: delta,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
streamCall.OnToolCall = func(toolCall fantasy.ToolCallContent) error {
|
||||
upsertDraftToolCall(toolCall.ToolCallID, toolCall.ToolName, toolCall.Input, false)
|
||||
publishMessagePart(
|
||||
fantasy.MessageRoleAssistant,
|
||||
chatprompt.PartFromContent(toolCall),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
streamCall.OnSource = func(source fantasy.SourceContent) error {
|
||||
appendDraftSource(source)
|
||||
publishMessagePart(
|
||||
fantasy.MessageRoleAssistant,
|
||||
chatprompt.PartFromContent(source),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
streamCall.OnToolResult = func(result fantasy.ToolResultContent) error {
|
||||
publishMessagePart(
|
||||
fantasy.MessageRoleTool,
|
||||
chatprompt.PartFromContent(result),
|
||||
)
|
||||
|
||||
stepStateMu.Lock()
|
||||
if result.ToolCallID != "" && strings.TrimSpace(result.ToolName) != "" {
|
||||
streamToolNames[result.ToolCallID] = result.ToolName
|
||||
}
|
||||
stepToolResultContents = append(stepToolResultContents, result)
|
||||
stepStateMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
streamCall.OnStepFinish = func(stepResult fantasy.StepResult) error {
|
||||
contextLimit := extractContextLimit(stepResult.ProviderMetadata)
|
||||
if !contextLimit.Valid && opts.ContextLimitFallback > 0 {
|
||||
contextLimit = sql.NullInt64{
|
||||
Int64: opts.ContextLimitFallback,
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
return opts.PersistStep(ctx, PersistedStep{
|
||||
Content: stepResult.Content,
|
||||
Usage: stepResult.Usage,
|
||||
ContextLimit: contextLimit,
|
||||
})
|
||||
}
|
||||
|
||||
result, err := agent.Stream(ctx, streamCall)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) &&
|
||||
errors.Is(context.Cause(ctx), ErrInterrupted) {
|
||||
if persistErr := persistInterruptedStep(); persistErr != nil {
|
||||
if opts.OnInterruptedPersistError != nil {
|
||||
opts.OnInterruptedPersistError(persistErr)
|
||||
}
|
||||
}
|
||||
return nil, ErrInterrupted
|
||||
}
|
||||
return nil, xerrors.Errorf("stream response: %w", err)
|
||||
}
|
||||
if opts.Compaction != nil {
|
||||
if err := maybeCompact(ctx, opts, result); err != nil {
|
||||
if opts.Compaction.OnError != nil {
|
||||
opts.Compaction.OnError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
//nolint:revive // Boolean controls Anthropic-specific caching behavior.
|
||||
func prepareStepResult(
|
||||
messages []fantasy.Message,
|
||||
sentinel string,
|
||||
activeTools []string,
|
||||
anthropicCaching bool,
|
||||
) fantasy.PrepareStepResult {
|
||||
filtered := make([]fantasy.Message, 0, len(messages))
|
||||
removed := false
|
||||
for _, message := range messages {
|
||||
if !removed &&
|
||||
message.Role == fantasy.MessageRoleUser &&
|
||||
len(message.Content) == 1 {
|
||||
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](message.Content[0])
|
||||
if ok && textPart.Text == sentinel {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, message)
|
||||
}
|
||||
|
||||
result := fantasy.PrepareStepResult{
|
||||
Messages: filtered,
|
||||
}
|
||||
if anthropicCaching {
|
||||
result.Messages = addAnthropicPromptCaching(result.Messages)
|
||||
}
|
||||
if len(activeTools) > 0 {
|
||||
result.ActiveTools = append([]string(nil), activeTools...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func shouldApplyAnthropicPromptCaching(model fantasy.LanguageModel) bool {
|
||||
if model == nil {
|
||||
return false
|
||||
}
|
||||
return model.Provider() == fantasyanthropic.Name
|
||||
}
|
||||
|
||||
func addAnthropicPromptCaching(messages []fantasy.Message) []fantasy.Message {
|
||||
for i := range messages {
|
||||
messages[i].ProviderOptions = nil
|
||||
}
|
||||
|
||||
providerOption := fantasy.ProviderOptions{
|
||||
fantasyanthropic.Name: &fantasyanthropic.ProviderCacheControlOptions{
|
||||
CacheControl: fantasyanthropic.CacheControl{Type: "ephemeral"},
|
||||
},
|
||||
}
|
||||
|
||||
lastSystemRoleIdx := -1
|
||||
systemMessageUpdated := false
|
||||
for i, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleSystem {
|
||||
lastSystemRoleIdx = i
|
||||
} else if !systemMessageUpdated && lastSystemRoleIdx >= 0 {
|
||||
messages[lastSystemRoleIdx].ProviderOptions = providerOption
|
||||
systemMessageUpdated = true
|
||||
}
|
||||
if i > len(messages)-3 {
|
||||
messages[i].ProviderOptions = providerOption
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func extractContextLimit(metadata fantasy.ProviderMetadata) sql.NullInt64 {
|
||||
if len(metadata) == 0 {
|
||||
return sql.NullInt64{}
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(metadata)
|
||||
if err != nil || len(encoded) == 0 {
|
||||
return sql.NullInt64{}
|
||||
}
|
||||
|
||||
var payload any
|
||||
if err := json.Unmarshal(encoded, &payload); err != nil {
|
||||
return sql.NullInt64{}
|
||||
}
|
||||
|
||||
limit, ok := findContextLimitValue(payload)
|
||||
if !ok {
|
||||
return sql.NullInt64{}
|
||||
}
|
||||
|
||||
return sql.NullInt64{
|
||||
Int64: limit,
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
func findContextLimitValue(value any) (int64, bool) {
|
||||
var (
|
||||
limit int64
|
||||
found bool
|
||||
)
|
||||
|
||||
collectContextLimitValues(value, func(candidate int64) {
|
||||
if !found || candidate > limit {
|
||||
limit = candidate
|
||||
found = true
|
||||
}
|
||||
})
|
||||
|
||||
return limit, found
|
||||
}
|
||||
|
||||
func collectContextLimitValues(value any, onValue func(int64)) {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
for key, child := range typed {
|
||||
if isContextLimitKey(key) {
|
||||
if numeric, ok := numericContextLimitValue(child); ok {
|
||||
onValue(numeric)
|
||||
}
|
||||
}
|
||||
collectContextLimitValues(child, onValue)
|
||||
}
|
||||
case []any:
|
||||
for _, child := range typed {
|
||||
collectContextLimitValues(child, onValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isContextLimitKey(key string) bool {
|
||||
normalized := normalizeMetadataKey(key)
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
switch normalized {
|
||||
case
|
||||
"contextlimit",
|
||||
"contextwindow",
|
||||
"contextlength",
|
||||
"maxcontext",
|
||||
"maxcontexttokens",
|
||||
"maxinputtokens",
|
||||
"maxinputtoken",
|
||||
"inputtokenlimit":
|
||||
return true
|
||||
}
|
||||
|
||||
return strings.Contains(normalized, "context") &&
|
||||
(strings.Contains(normalized, "limit") ||
|
||||
strings.Contains(normalized, "window") ||
|
||||
strings.Contains(normalized, "length") ||
|
||||
strings.HasPrefix(normalized, "max"))
|
||||
}
|
||||
|
||||
func normalizeMetadataKey(key string) string {
|
||||
var b strings.Builder
|
||||
b.Grow(len(key))
|
||||
|
||||
for _, r := range key {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
_, _ = b.WriteRune(r)
|
||||
case r >= 'A' && r <= 'Z':
|
||||
_, _ = b.WriteRune(r + ('a' - 'A'))
|
||||
case r >= '0' && r <= '9':
|
||||
_, _ = b.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func numericContextLimitValue(value any) (int64, bool) {
|
||||
switch typed := value.(type) {
|
||||
case int64:
|
||||
return positiveInt64(typed)
|
||||
case int32:
|
||||
return positiveInt64(int64(typed))
|
||||
case int:
|
||||
return positiveInt64(int64(typed))
|
||||
case float64:
|
||||
casted := int64(typed)
|
||||
if typed > 0 && float64(casted) == typed {
|
||||
return casted, true
|
||||
}
|
||||
case string:
|
||||
parsed, err := strconv.ParseInt(strings.TrimSpace(typed), 10, 64)
|
||||
if err == nil {
|
||||
return positiveInt64(parsed)
|
||||
}
|
||||
case json.Number:
|
||||
parsed, err := typed.Int64()
|
||||
if err == nil {
|
||||
return positiveInt64(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func positiveInt64(value int64) (int64, bool) {
|
||||
if value <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
@@ -0,0 +1,289 @@
|
||||
package chatloop //nolint:testpackage // Uses internal symbols.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"iter"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
const activeToolName = "read_file"
|
||||
|
||||
func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var capturedCall fantasy.Call
|
||||
model := &loopTestModel{
|
||||
provider: fantasyanthropic.Name,
|
||||
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
capturedCall = call
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}), nil
|
||||
},
|
||||
}
|
||||
|
||||
persistStepCalls := 0
|
||||
var persistedStep PersistedStep
|
||||
|
||||
_, err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleSystem, "sys-1"),
|
||||
textMessage(fantasy.MessageRoleSystem, "sys-2"),
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
textMessage(fantasy.MessageRoleAssistant, "working"),
|
||||
textMessage(fantasy.MessageRoleUser, "continue"),
|
||||
},
|
||||
Tools: []fantasy.AgentTool{
|
||||
newNoopTool(activeToolName),
|
||||
newNoopTool("write_file"),
|
||||
},
|
||||
MaxSteps: 3,
|
||||
ActiveTools: []string{activeToolName},
|
||||
ContextLimitFallback: 4096,
|
||||
PersistStep: func(_ context.Context, step PersistedStep) error {
|
||||
persistStepCalls++
|
||||
persistedStep = step
|
||||
return nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, 1, persistStepCalls)
|
||||
require.True(t, persistedStep.ContextLimit.Valid)
|
||||
require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64)
|
||||
|
||||
require.NotEmpty(t, capturedCall.Prompt)
|
||||
require.False(t, containsPromptSentinel(capturedCall.Prompt))
|
||||
require.Len(t, capturedCall.Tools, 1)
|
||||
require.Equal(t, activeToolName, capturedCall.Tools[0].GetName())
|
||||
|
||||
require.Len(t, capturedCall.Prompt, 5)
|
||||
require.False(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[0]))
|
||||
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[1]))
|
||||
require.False(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[2]))
|
||||
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[3]))
|
||||
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[4]))
|
||||
}
|
||||
|
||||
func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
started := make(chan struct{})
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
|
||||
parts := []fantasy.StreamPart{
|
||||
{
|
||||
Type: fantasy.StreamPartTypeToolInputStart,
|
||||
ID: "interrupt-tool-1",
|
||||
ToolCallName: "read_file",
|
||||
},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeToolInputDelta,
|
||||
ID: "interrupt-tool-1",
|
||||
ToolCallName: "read_file",
|
||||
Delta: `{"path":"main.go"`,
|
||||
},
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "partial assistant output"},
|
||||
}
|
||||
for _, part := range parts {
|
||||
if !yield(part) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
default:
|
||||
close(started)
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
_ = yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeError,
|
||||
Error: ctx.Err(),
|
||||
})
|
||||
}), nil
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
defer cancel(nil)
|
||||
|
||||
go func() {
|
||||
<-started
|
||||
cancel(ErrInterrupted)
|
||||
}()
|
||||
|
||||
persistedAssistantCtxErr := xerrors.New("unset")
|
||||
var persistedContent []fantasy.Content
|
||||
|
||||
_, err := Run(ctx, RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
},
|
||||
Tools: []fantasy.AgentTool{
|
||||
newNoopTool("read_file"),
|
||||
},
|
||||
MaxSteps: 3,
|
||||
PersistStep: func(persistCtx context.Context, step PersistedStep) error {
|
||||
persistedAssistantCtxErr = persistCtx.Err()
|
||||
persistedContent = append([]fantasy.Content(nil), step.Content...)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
require.ErrorIs(t, err, ErrInterrupted)
|
||||
require.NoError(t, persistedAssistantCtxErr)
|
||||
|
||||
require.NotEmpty(t, persistedContent)
|
||||
var (
|
||||
foundText bool
|
||||
foundToolCall bool
|
||||
foundToolResult bool
|
||||
)
|
||||
for _, block := range persistedContent {
|
||||
if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok {
|
||||
if strings.Contains(text.Text, "partial assistant output") {
|
||||
foundText = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block); ok {
|
||||
if toolCall.ToolCallID == "interrupt-tool-1" &&
|
||||
toolCall.ToolName == "read_file" &&
|
||||
strings.Contains(toolCall.Input, `"path":"main.go"`) {
|
||||
foundToolCall = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok {
|
||||
if toolResult.ToolCallID == "interrupt-tool-1" &&
|
||||
toolResult.ToolName == "read_file" {
|
||||
_, isErr := toolResult.Result.(fantasy.ToolResultOutputContentError)
|
||||
require.True(t, isErr, "interrupted tool result should be an error")
|
||||
foundToolResult = true
|
||||
}
|
||||
}
|
||||
}
|
||||
require.True(t, foundText)
|
||||
require.True(t, foundToolCall)
|
||||
require.True(t, foundToolResult)
|
||||
}
|
||||
|
||||
type loopTestModel struct {
|
||||
provider string
|
||||
model string
|
||||
generateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
|
||||
streamFn func(context.Context, fantasy.Call) (fantasy.StreamResponse, error)
|
||||
}
|
||||
|
||||
func (m *loopTestModel) Provider() string {
|
||||
if m.provider != "" {
|
||||
return m.provider
|
||||
}
|
||||
return "fake"
|
||||
}
|
||||
|
||||
func (m *loopTestModel) Model() string {
|
||||
if m.model != "" {
|
||||
return m.model
|
||||
}
|
||||
return "fake"
|
||||
}
|
||||
|
||||
func (m *loopTestModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
if m.generateFn != nil {
|
||||
return m.generateFn(ctx, call)
|
||||
}
|
||||
return &fantasy.Response{}, nil
|
||||
}
|
||||
|
||||
func (m *loopTestModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
if m.streamFn != nil {
|
||||
return m.streamFn(ctx, call)
|
||||
}
|
||||
return streamFromParts([]fantasy.StreamPart{{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
}}), nil
|
||||
}
|
||||
|
||||
func (*loopTestModel) GenerateObject(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
return nil, xerrors.New("not implemented")
|
||||
}
|
||||
|
||||
func (*loopTestModel) StreamObject(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
||||
return nil, xerrors.New("not implemented")
|
||||
}
|
||||
|
||||
func streamFromParts(parts []fantasy.StreamPart) fantasy.StreamResponse {
|
||||
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
|
||||
for _, part := range parts {
|
||||
if !yield(part) {
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newNoopTool(name string) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
name,
|
||||
"test noop tool",
|
||||
func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return fantasy.ToolResponse{}, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func textMessage(role fantasy.MessageRole, text string) fantasy.Message {
|
||||
return fantasy.Message{
|
||||
Role: role,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: text},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func containsPromptSentinel(prompt []fantasy.Message) bool {
|
||||
for _, message := range prompt {
|
||||
if message.Role != fantasy.MessageRoleUser || len(message.Content) != 1 {
|
||||
continue
|
||||
}
|
||||
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](message.Content[0])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(textPart.Text, "__chatd_agent_prompt_sentinel_") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
|
||||
if len(message.ProviderOptions) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
options, ok := message.ProviderOptions[fantasyanthropic.Name]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
cacheOptions, ok := options.(*fantasyanthropic.ProviderCacheControlOptions)
|
||||
return ok && cacheOptions.CacheControl.Type == "ephemeral"
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package chatloop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCompactionThresholdPercent = int32(70)
|
||||
minCompactionThresholdPercent = int32(0)
|
||||
maxCompactionThresholdPercent = int32(100)
|
||||
|
||||
defaultCompactionSummaryPrompt = "Summarize the current chat so a " +
|
||||
"new assistant can continue seamlessly. Include the user's goals, " +
|
||||
"decisions made, concrete technical details (files, commands, APIs), " +
|
||||
"errors encountered and fixes, and open questions. Be dense and factual. " +
|
||||
"Omit pleasantries and next-step suggestions."
|
||||
defaultCompactionSystemSummaryPrefix = "Summary of earlier chat context:"
|
||||
defaultCompactionTimeout = 90 * time.Second
|
||||
)
|
||||
|
||||
type CompactionOptions struct {
|
||||
ThresholdPercent int32
|
||||
ContextLimit int64
|
||||
SummaryPrompt string
|
||||
SystemSummaryPrefix string
|
||||
Timeout time.Duration
|
||||
Persist func(context.Context, CompactionResult) error
|
||||
OnError func(error)
|
||||
}
|
||||
|
||||
type CompactionResult struct {
|
||||
SystemSummary string
|
||||
SummaryReport string
|
||||
ThresholdPercent int32
|
||||
UsagePercent float64
|
||||
ContextTokens int64
|
||||
ContextLimit int64
|
||||
}
|
||||
|
||||
func maybeCompact(
|
||||
ctx context.Context,
|
||||
runOpts RunOptions,
|
||||
runResult *fantasy.AgentResult,
|
||||
) error {
|
||||
if runResult == nil || runOpts.Compaction == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := *runOpts.Compaction
|
||||
if config.Persist == nil {
|
||||
return xerrors.New("compaction persist callback is required")
|
||||
}
|
||||
if strings.TrimSpace(config.SummaryPrompt) == "" {
|
||||
config.SummaryPrompt = defaultCompactionSummaryPrompt
|
||||
}
|
||||
if strings.TrimSpace(config.SystemSummaryPrefix) == "" {
|
||||
config.SystemSummaryPrefix = defaultCompactionSystemSummaryPrefix
|
||||
}
|
||||
if config.Timeout <= 0 {
|
||||
config.Timeout = defaultCompactionTimeout
|
||||
}
|
||||
if config.ThresholdPercent < minCompactionThresholdPercent ||
|
||||
config.ThresholdPercent > maxCompactionThresholdPercent {
|
||||
config.ThresholdPercent = defaultCompactionThresholdPercent
|
||||
}
|
||||
|
||||
if config.ThresholdPercent >= maxCompactionThresholdPercent {
|
||||
return nil
|
||||
}
|
||||
if runOpts.MaxSteps > 0 && len(runResult.Steps) >= runOpts.MaxSteps {
|
||||
lastStep := runResult.Steps[len(runResult.Steps)-1]
|
||||
if lastStep.FinishReason == fantasy.FinishReasonToolCalls &&
|
||||
len(lastStep.Content.ToolCalls()) > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
contextTokens := int64(0)
|
||||
contextLimitFromMetadata := int64(0)
|
||||
for i := len(runResult.Steps) - 1; i >= 0; i-- {
|
||||
usage := runResult.Steps[i].Usage
|
||||
total := int64(0)
|
||||
hasContextTokens := false
|
||||
|
||||
if usage.InputTokens > 0 {
|
||||
total += usage.InputTokens
|
||||
hasContextTokens = true
|
||||
}
|
||||
if usage.CacheReadTokens > 0 {
|
||||
total += usage.CacheReadTokens
|
||||
hasContextTokens = true
|
||||
}
|
||||
if usage.CacheCreationTokens > 0 {
|
||||
total += usage.CacheCreationTokens
|
||||
hasContextTokens = true
|
||||
}
|
||||
if !hasContextTokens && usage.TotalTokens > 0 {
|
||||
total = usage.TotalTokens
|
||||
hasContextTokens = true
|
||||
}
|
||||
if !hasContextTokens || total <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
contextTokens = total
|
||||
metadataLimit := extractContextLimit(runResult.Steps[i].ProviderMetadata)
|
||||
if metadataLimit.Valid && metadataLimit.Int64 > 0 {
|
||||
contextLimitFromMetadata = metadataLimit.Int64
|
||||
}
|
||||
break
|
||||
}
|
||||
if contextTokens <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
contextLimit := contextLimitFromMetadata
|
||||
if contextLimit <= 0 && config.ContextLimit > 0 {
|
||||
contextLimit = config.ContextLimit
|
||||
}
|
||||
if contextLimit <= 0 && runOpts.ContextLimitFallback > 0 {
|
||||
contextLimit = runOpts.ContextLimitFallback
|
||||
}
|
||||
if contextLimit <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
usagePercent := (float64(contextTokens) / float64(contextLimit)) * 100
|
||||
if usagePercent < float64(config.ThresholdPercent) {
|
||||
return nil
|
||||
}
|
||||
|
||||
summary, err := generateCompactionSummary(
|
||||
ctx,
|
||||
runOpts.Model,
|
||||
runOpts.Messages,
|
||||
runResult.Steps,
|
||||
config,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if summary == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
systemSummary := strings.TrimSpace(
|
||||
config.SystemSummaryPrefix + "\n\n" + summary,
|
||||
)
|
||||
|
||||
return config.Persist(ctx, CompactionResult{
|
||||
SystemSummary: systemSummary,
|
||||
SummaryReport: summary,
|
||||
ThresholdPercent: config.ThresholdPercent,
|
||||
UsagePercent: usagePercent,
|
||||
ContextTokens: contextTokens,
|
||||
ContextLimit: contextLimit,
|
||||
})
|
||||
}
|
||||
|
||||
func generateCompactionSummary(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
messages []fantasy.Message,
|
||||
steps []fantasy.StepResult,
|
||||
options CompactionOptions,
|
||||
) (string, error) {
|
||||
summaryPrompt := make([]fantasy.Message, 0, len(messages)+len(steps)+1)
|
||||
summaryPrompt = append(summaryPrompt, messages...)
|
||||
for _, step := range steps {
|
||||
summaryPrompt = append(summaryPrompt, step.Messages...)
|
||||
}
|
||||
summaryPrompt = append(summaryPrompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: options.SummaryPrompt},
|
||||
},
|
||||
})
|
||||
toolChoice := fantasy.ToolChoiceNone
|
||||
|
||||
summaryCtx, cancel := context.WithTimeout(ctx, options.Timeout)
|
||||
defer cancel()
|
||||
|
||||
response, err := model.Generate(summaryCtx, fantasy.Call{
|
||||
Prompt: summaryPrompt,
|
||||
ToolChoice: &toolChoice,
|
||||
})
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("generate summary text: %w", err)
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(response.Content))
|
||||
for _, block := range response.Content {
|
||||
textBlock, ok := fantasy.AsContentType[fantasy.TextContent](block)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
text := strings.TrimSpace(textBlock.Text)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, text)
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(parts, " ")), nil
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package chatloop //nolint:testpackage // Uses internal symbols.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
func TestRun_Compaction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("PersistsWhenThresholdReached", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
persistCompactionCalls := 0
|
||||
var persistedCompaction CompactionResult
|
||||
const summaryText = "summary text for compaction"
|
||||
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{
|
||||
InputTokens: 80,
|
||||
TotalTokens: 85,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
},
|
||||
generateFn: func(_ context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
require.NotEmpty(t, call.Prompt)
|
||||
lastPrompt := call.Prompt[len(call.Prompt)-1]
|
||||
require.Equal(t, fantasy.MessageRoleUser, lastPrompt.Role)
|
||||
require.Len(t, lastPrompt.Content, 1)
|
||||
|
||||
instruction, ok := fantasy.AsMessagePart[fantasy.TextPart](lastPrompt.Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "summarize now", instruction.Text)
|
||||
|
||||
return &fantasy.Response{
|
||||
Content: []fantasy.Content{
|
||||
fantasy.TextContent{Text: summaryText},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
_, err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
},
|
||||
MaxSteps: 1,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
ContextLimitFallback: 100,
|
||||
Compaction: &CompactionOptions{
|
||||
ThresholdPercent: 70,
|
||||
SummaryPrompt: "summarize now",
|
||||
Persist: func(_ context.Context, result CompactionResult) error {
|
||||
persistCompactionCalls++
|
||||
persistedCompaction = result
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, persistCompactionCalls)
|
||||
require.Contains(t, persistedCompaction.SystemSummary, summaryText)
|
||||
require.Equal(t, summaryText, persistedCompaction.SummaryReport)
|
||||
require.Equal(t, int64(80), persistedCompaction.ContextTokens)
|
||||
require.Equal(t, int64(100), persistedCompaction.ContextLimit)
|
||||
require.InDelta(t, 80.0, persistedCompaction.UsagePercent, 0.0001)
|
||||
})
|
||||
|
||||
t.Run("ErrorsAreReported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{
|
||||
InputTokens: 80,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
},
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return nil, xerrors.New("generate failed")
|
||||
},
|
||||
}
|
||||
|
||||
compactionErr := xerrors.New("unset")
|
||||
_, err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
},
|
||||
MaxSteps: 1,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
ContextLimitFallback: 100,
|
||||
Compaction: &CompactionOptions{
|
||||
ThresholdPercent: 70,
|
||||
Persist: func(_ context.Context, _ CompactionResult) error {
|
||||
return nil
|
||||
},
|
||||
OnError: func(err error) {
|
||||
compactionErr = err
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Error(t, compactionErr)
|
||||
require.ErrorContains(t, compactionErr, "generate summary text")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,982 @@
|
||||
package chatprompt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
var toolCallIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
|
||||
|
||||
func ConvertMessages(
|
||||
messages []database.ChatMessage,
|
||||
) ([]fantasy.Message, error) {
|
||||
prompt := make([]fantasy.Message, 0, len(messages))
|
||||
toolNameByCallID := make(map[string]string)
|
||||
for _, message := range messages {
|
||||
visibility := message.Visibility
|
||||
if visibility == "" {
|
||||
visibility = database.ChatMessageVisibilityBoth
|
||||
}
|
||||
if visibility != database.ChatMessageVisibilityModel &&
|
||||
visibility != database.ChatMessageVisibilityBoth {
|
||||
continue
|
||||
}
|
||||
|
||||
switch message.Role {
|
||||
case string(fantasy.MessageRoleSystem):
|
||||
content, err := parseSystemContent(message.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(content) == "" {
|
||||
continue
|
||||
}
|
||||
prompt = append(prompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: content},
|
||||
},
|
||||
})
|
||||
case string(fantasy.MessageRoleUser):
|
||||
content, err := ParseContent(string(fantasy.MessageRoleUser), message.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prompt = append(prompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: ToMessageParts(content),
|
||||
})
|
||||
case string(fantasy.MessageRoleAssistant):
|
||||
content, err := ParseContent(string(fantasy.MessageRoleAssistant), message.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := normalizeAssistantToolCallInputs(ToMessageParts(content))
|
||||
for _, toolCall := range ExtractToolCalls(parts) {
|
||||
if toolCall.ToolCallID == "" || strings.TrimSpace(toolCall.ToolName) == "" {
|
||||
continue
|
||||
}
|
||||
toolNameByCallID[sanitizeToolCallID(toolCall.ToolCallID)] = toolCall.ToolName
|
||||
}
|
||||
prompt = append(prompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: parts,
|
||||
})
|
||||
case string(fantasy.MessageRoleTool):
|
||||
rows, err := parseToolResultRows(message.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := make([]fantasy.MessagePart, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
if row.ToolCallID != "" && row.ToolName != "" {
|
||||
toolNameByCallID[sanitizeToolCallID(row.ToolCallID)] = row.ToolName
|
||||
}
|
||||
parts = append(parts, row.toToolResultPart())
|
||||
}
|
||||
prompt = append(prompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleTool,
|
||||
Content: parts,
|
||||
})
|
||||
default:
|
||||
return nil, xerrors.Errorf("unsupported chat message role %q", message.Role)
|
||||
}
|
||||
}
|
||||
prompt = injectMissingToolResults(prompt)
|
||||
prompt = injectMissingToolUses(
|
||||
prompt,
|
||||
toolNameByCallID,
|
||||
)
|
||||
return prompt, nil
|
||||
}
|
||||
|
||||
// PrependSystem prepends a system message unless an existing system
|
||||
// message already mentions create_workspace guidance.
|
||||
func PrependSystem(prompt []fantasy.Message, instruction string) []fantasy.Message {
|
||||
instruction = strings.TrimSpace(instruction)
|
||||
if instruction == "" {
|
||||
return prompt
|
||||
}
|
||||
for _, message := range prompt {
|
||||
if message.Role != fantasy.MessageRoleSystem {
|
||||
continue
|
||||
}
|
||||
for _, part := range message.Content {
|
||||
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(strings.ToLower(textPart.Text), "create_workspace") {
|
||||
return prompt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]fantasy.Message, 0, len(prompt)+1)
|
||||
out = append(out, fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: instruction},
|
||||
},
|
||||
})
|
||||
out = append(out, prompt...)
|
||||
return out
|
||||
}
|
||||
|
||||
// InsertSystem inserts a system message after the existing system
|
||||
// block and before the first non-system message.
|
||||
func InsertSystem(prompt []fantasy.Message, instruction string) []fantasy.Message {
|
||||
instruction = strings.TrimSpace(instruction)
|
||||
if instruction == "" {
|
||||
return prompt
|
||||
}
|
||||
|
||||
systemMessage := fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: instruction},
|
||||
},
|
||||
}
|
||||
|
||||
out := make([]fantasy.Message, 0, len(prompt)+1)
|
||||
inserted := false
|
||||
for _, message := range prompt {
|
||||
if !inserted && message.Role != fantasy.MessageRoleSystem {
|
||||
out = append(out, systemMessage)
|
||||
inserted = true
|
||||
}
|
||||
out = append(out, message)
|
||||
}
|
||||
if !inserted {
|
||||
out = append(out, systemMessage)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// AppendUser appends an instruction as a user message at the end of
|
||||
// the prompt.
|
||||
func AppendUser(prompt []fantasy.Message, instruction string) []fantasy.Message {
|
||||
instruction = strings.TrimSpace(instruction)
|
||||
if instruction == "" {
|
||||
return prompt
|
||||
}
|
||||
out := make([]fantasy.Message, 0, len(prompt)+1)
|
||||
out = append(out, prompt...)
|
||||
out = append(out, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: instruction},
|
||||
},
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
// ParseContent decodes persisted chat message content blocks.
|
||||
func ParseContent(role string, raw pqtype.NullRawMessage) ([]fantasy.Content, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var text string
|
||||
if err := json.Unmarshal(raw.RawMessage, &text); err == nil {
|
||||
return []fantasy.Content{fantasy.TextContent{Text: text}}, nil
|
||||
}
|
||||
|
||||
var rawBlocks []json.RawMessage
|
||||
if err := json.Unmarshal(raw.RawMessage, &rawBlocks); err != nil {
|
||||
return nil, xerrors.Errorf("parse %s content: %w", role, err)
|
||||
}
|
||||
|
||||
content := make([]fantasy.Content, 0, len(rawBlocks))
|
||||
for i, rawBlock := range rawBlocks {
|
||||
block, err := fantasy.UnmarshalContent(rawBlock)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse %s content block %d: %w", role, i, err)
|
||||
}
|
||||
content = append(content, block)
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// toolResultRaw is an untyped representation of a persisted tool
|
||||
// result row. We intentionally avoid a strict Go struct so that
|
||||
// historical shapes are never rejected.
|
||||
type toolResultRaw struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Result json.RawMessage `json:"result"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
// parseToolResultRows decodes persisted tool result rows.
|
||||
func parseToolResultRows(raw pqtype.NullRawMessage) ([]toolResultRaw, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var rows []toolResultRaw
|
||||
if err := json.Unmarshal(raw.RawMessage, &rows); err != nil {
|
||||
return nil, xerrors.Errorf("parse tool content: %w", err)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (r toolResultRaw) toToolResultPart() fantasy.ToolResultPart {
|
||||
toolCallID := sanitizeToolCallID(r.ToolCallID)
|
||||
resultText := string(r.Result)
|
||||
if resultText == "" || resultText == "null" {
|
||||
resultText = "{}"
|
||||
}
|
||||
|
||||
if r.IsError {
|
||||
message := strings.TrimSpace(resultText)
|
||||
if extracted := extractErrorString(r.Result); extracted != "" {
|
||||
message = extracted
|
||||
}
|
||||
return fantasy.ToolResultPart{
|
||||
ToolCallID: toolCallID,
|
||||
Output: fantasy.ToolResultOutputContentError{
|
||||
Error: xerrors.New(message),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return fantasy.ToolResultPart{
|
||||
ToolCallID: toolCallID,
|
||||
Output: fantasy.ToolResultOutputContentText{
|
||||
Text: resultText,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// extractErrorString pulls the "error" field from a JSON object if
|
||||
// present, returning it as a string. Returns "" if the field is
|
||||
// missing or the input is not an object.
|
||||
func extractErrorString(raw json.RawMessage) string {
|
||||
var fields map[string]json.RawMessage
|
||||
if err := json.Unmarshal(raw, &fields); err != nil {
|
||||
return ""
|
||||
}
|
||||
errField, ok := fields["error"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(errField, &s); err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
// ToMessageParts converts fantasy content blocks into message parts.
|
||||
func ToMessageParts(content []fantasy.Content) []fantasy.MessagePart {
|
||||
parts := make([]fantasy.MessagePart, 0, len(content))
|
||||
for _, block := range content {
|
||||
switch value := block.(type) {
|
||||
case fantasy.TextContent:
|
||||
parts = append(parts, fantasy.TextPart{
|
||||
Text: value.Text,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case *fantasy.TextContent:
|
||||
parts = append(parts, fantasy.TextPart{
|
||||
Text: value.Text,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case fantasy.ReasoningContent:
|
||||
parts = append(parts, fantasy.ReasoningPart{
|
||||
Text: value.Text,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case *fantasy.ReasoningContent:
|
||||
parts = append(parts, fantasy.ReasoningPart{
|
||||
Text: value.Text,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case fantasy.ToolCallContent:
|
||||
parts = append(parts, fantasy.ToolCallPart{
|
||||
ToolCallID: sanitizeToolCallID(value.ToolCallID),
|
||||
ToolName: value.ToolName,
|
||||
Input: value.Input,
|
||||
ProviderExecuted: value.ProviderExecuted,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case *fantasy.ToolCallContent:
|
||||
parts = append(parts, fantasy.ToolCallPart{
|
||||
ToolCallID: sanitizeToolCallID(value.ToolCallID),
|
||||
ToolName: value.ToolName,
|
||||
Input: value.Input,
|
||||
ProviderExecuted: value.ProviderExecuted,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case fantasy.FileContent:
|
||||
parts = append(parts, fantasy.FilePart{
|
||||
Data: value.Data,
|
||||
MediaType: value.MediaType,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case *fantasy.FileContent:
|
||||
parts = append(parts, fantasy.FilePart{
|
||||
Data: value.Data,
|
||||
MediaType: value.MediaType,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case fantasy.ToolResultContent:
|
||||
parts = append(parts, fantasy.ToolResultPart{
|
||||
ToolCallID: sanitizeToolCallID(value.ToolCallID),
|
||||
Output: value.Result,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case *fantasy.ToolResultContent:
|
||||
parts = append(parts, fantasy.ToolResultPart{
|
||||
ToolCallID: sanitizeToolCallID(value.ToolCallID),
|
||||
Output: value.Result,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
}
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func normalizeAssistantToolCallInputs(
|
||||
parts []fantasy.MessagePart,
|
||||
) []fantasy.MessagePart {
|
||||
normalized := make([]fantasy.MessagePart, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
|
||||
if !ok {
|
||||
normalized = append(normalized, part)
|
||||
continue
|
||||
}
|
||||
|
||||
toolCall.Input = normalizeToolCallInput(toolCall.Input)
|
||||
normalized = append(normalized, toolCall)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
// normalizeToolCallInput guarantees tool call input is a JSON object string.
|
||||
// Anthropic drops assistant tool calls with malformed input, which can leave
|
||||
// following tool results orphaned.
|
||||
func normalizeToolCallInput(input string) string {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
var object map[string]any
|
||||
if err := json.Unmarshal([]byte(input), &object); err != nil || object == nil {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
return input
|
||||
}
|
||||
|
||||
// ExtractToolCalls returns all tool call parts as content blocks.
|
||||
func ExtractToolCalls(parts []fantasy.MessagePart) []fantasy.ToolCallContent {
|
||||
toolCalls := make([]fantasy.ToolCallContent, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
toolCalls = append(toolCalls, fantasy.ToolCallContent{
|
||||
ToolCallID: toolCall.ToolCallID,
|
||||
ToolName: toolCall.ToolName,
|
||||
Input: toolCall.Input,
|
||||
ProviderExecuted: toolCall.ProviderExecuted,
|
||||
})
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// MarshalContent encodes message content blocks for persistence.
|
||||
func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) {
|
||||
if len(blocks) == 0 {
|
||||
return pqtype.NullRawMessage{}, nil
|
||||
}
|
||||
|
||||
encodedBlocks := make([]json.RawMessage, 0, len(blocks))
|
||||
for i, block := range blocks {
|
||||
encoded, err := marshalContentBlock(block)
|
||||
if err != nil {
|
||||
return pqtype.NullRawMessage{}, xerrors.Errorf(
|
||||
"encode content block %d: %w",
|
||||
i,
|
||||
err,
|
||||
)
|
||||
}
|
||||
encodedBlocks = append(encodedBlocks, encoded)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(encodedBlocks)
|
||||
if err != nil {
|
||||
return pqtype.NullRawMessage{}, xerrors.Errorf("encode content blocks: %w", err)
|
||||
}
|
||||
return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
|
||||
}
|
||||
|
||||
// MarshalToolResult encodes a single tool result for persistence as
|
||||
// an opaque JSON blob. The stored shape is
|
||||
// [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…}].
|
||||
func MarshalToolResult(toolCallID, toolName string, result json.RawMessage, isError bool) (pqtype.NullRawMessage, error) {
|
||||
row := toolResultRaw{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Result: result,
|
||||
IsError: isError,
|
||||
}
|
||||
data, err := json.Marshal([]toolResultRaw{row})
|
||||
if err != nil {
|
||||
return pqtype.NullRawMessage{}, xerrors.Errorf("encode tool result: %w", err)
|
||||
}
|
||||
return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
|
||||
}
|
||||
|
||||
// MarshalToolResultContent encodes a fantasy tool result content
|
||||
// block for persistence. It extracts the raw fields and delegates
|
||||
// to MarshalToolResult.
|
||||
func MarshalToolResultContent(content fantasy.ToolResultContent) (pqtype.NullRawMessage, error) {
|
||||
var result json.RawMessage
|
||||
var isError bool
|
||||
|
||||
switch output := content.Result.(type) {
|
||||
case fantasy.ToolResultOutputContentError:
|
||||
isError = true
|
||||
if output.Error != nil {
|
||||
result, _ = json.Marshal(map[string]any{"error": output.Error.Error()})
|
||||
} else {
|
||||
result = []byte(`{"error":""}`)
|
||||
}
|
||||
case fantasy.ToolResultOutputContentText:
|
||||
result = json.RawMessage(output.Text)
|
||||
if !json.Valid(result) {
|
||||
result, _ = json.Marshal(map[string]any{"output": output.Text})
|
||||
}
|
||||
case fantasy.ToolResultOutputContentMedia:
|
||||
result, _ = json.Marshal(map[string]any{
|
||||
"data": output.Data,
|
||||
"mime_type": output.MediaType,
|
||||
"text": output.Text,
|
||||
})
|
||||
default:
|
||||
result = []byte(`{}`)
|
||||
}
|
||||
|
||||
return MarshalToolResult(content.ToolCallID, content.ToolName, result, isError)
|
||||
}
|
||||
|
||||
// PartFromContent converts fantasy content into a SDK chat message part.
|
||||
func PartFromContent(block fantasy.Content) codersdk.ChatMessagePart {
|
||||
switch value := block.(type) {
|
||||
case fantasy.TextContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: value.Text,
|
||||
}
|
||||
case *fantasy.TextContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: value.Text,
|
||||
}
|
||||
case fantasy.ReasoningContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: value.Text,
|
||||
Title: reasoningSummaryTitle(value.ProviderMetadata),
|
||||
}
|
||||
case *fantasy.ReasoningContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: value.Text,
|
||||
Title: reasoningSummaryTitle(value.ProviderMetadata),
|
||||
}
|
||||
case fantasy.ToolCallContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: value.ToolCallID,
|
||||
ToolName: value.ToolName,
|
||||
Args: []byte(value.Input),
|
||||
}
|
||||
case *fantasy.ToolCallContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: value.ToolCallID,
|
||||
ToolName: value.ToolName,
|
||||
Args: []byte(value.Input),
|
||||
}
|
||||
case fantasy.SourceContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSource,
|
||||
SourceID: value.ID,
|
||||
URL: value.URL,
|
||||
Title: value.Title,
|
||||
}
|
||||
case *fantasy.SourceContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSource,
|
||||
SourceID: value.ID,
|
||||
URL: value.URL,
|
||||
Title: value.Title,
|
||||
}
|
||||
case fantasy.FileContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeFile,
|
||||
MediaType: value.MediaType,
|
||||
Data: value.Data,
|
||||
}
|
||||
case *fantasy.FileContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeFile,
|
||||
MediaType: value.MediaType,
|
||||
Data: value.Data,
|
||||
}
|
||||
case fantasy.ToolResultContent:
|
||||
return toolResultContentToPart(value)
|
||||
case *fantasy.ToolResultContent:
|
||||
return toolResultContentToPart(*value)
|
||||
default:
|
||||
return codersdk.ChatMessagePart{}
|
||||
}
|
||||
}
|
||||
|
||||
// ToolResultToPart converts a tool call ID, raw result, and error
|
||||
// flag into a ChatMessagePart. This is the minimal conversion used
|
||||
// both during streaming and when reading from the database.
|
||||
func ToolResultToPart(toolCallID, toolName string, result json.RawMessage, isError bool) codersdk.ChatMessagePart {
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolResult,
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Result: result,
|
||||
IsError: isError,
|
||||
}
|
||||
}
|
||||
|
||||
// toolResultContentToPart converts a fantasy ToolResultContent
|
||||
// directly into a ChatMessagePart without an intermediate struct.
|
||||
func toolResultContentToPart(content fantasy.ToolResultContent) codersdk.ChatMessagePart {
|
||||
var result json.RawMessage
|
||||
var isError bool
|
||||
|
||||
switch output := content.Result.(type) {
|
||||
case fantasy.ToolResultOutputContentError:
|
||||
isError = true
|
||||
if output.Error != nil {
|
||||
result, _ = json.Marshal(map[string]any{"error": output.Error.Error()})
|
||||
} else {
|
||||
result = []byte(`{"error":""}`)
|
||||
}
|
||||
case fantasy.ToolResultOutputContentText:
|
||||
result = json.RawMessage(output.Text)
|
||||
// Ensure valid JSON; wrap in an object if not.
|
||||
if !json.Valid(result) {
|
||||
result, _ = json.Marshal(map[string]any{"output": output.Text})
|
||||
}
|
||||
case fantasy.ToolResultOutputContentMedia:
|
||||
result, _ = json.Marshal(map[string]any{
|
||||
"data": output.Data,
|
||||
"mime_type": output.MediaType,
|
||||
"text": output.Text,
|
||||
})
|
||||
default:
|
||||
result = []byte(`{}`)
|
||||
}
|
||||
|
||||
return ToolResultToPart(content.ToolCallID, content.ToolName, result, isError)
|
||||
}
|
||||
|
||||
// ReasoningTitleFromFirstLine extracts a compact markdown title.
|
||||
func ReasoningTitleFromFirstLine(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
firstLine := text
|
||||
if idx := strings.IndexAny(firstLine, "\r\n"); idx >= 0 {
|
||||
firstLine = firstLine[:idx]
|
||||
}
|
||||
firstLine = strings.TrimSpace(firstLine)
|
||||
if firstLine == "" || !strings.HasPrefix(firstLine, "**") {
|
||||
return ""
|
||||
}
|
||||
|
||||
rest := firstLine[2:]
|
||||
end := strings.Index(rest, "**")
|
||||
if end < 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
title := strings.TrimSpace(rest[:end])
|
||||
if title == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Require the first line to be exactly "**title**" (ignoring
|
||||
// surrounding whitespace) so providers without this format don't
|
||||
// accidentally emit a title.
|
||||
if strings.TrimSpace(rest[end+2:]) != "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return compactReasoningSummaryTitle(title)
|
||||
}
|
||||
|
||||
func injectMissingToolResults(prompt []fantasy.Message) []fantasy.Message {
|
||||
result := make([]fantasy.Message, 0, len(prompt))
|
||||
for i := 0; i < len(prompt); i++ {
|
||||
msg := prompt[i]
|
||||
result = append(result, msg)
|
||||
|
||||
if msg.Role != fantasy.MessageRoleAssistant {
|
||||
continue
|
||||
}
|
||||
toolCalls := ExtractToolCalls(msg.Content)
|
||||
if len(toolCalls) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Collect the tool call IDs that have results in the
|
||||
// following tool message(s).
|
||||
answered := make(map[string]struct{})
|
||||
j := i + 1
|
||||
for ; j < len(prompt); j++ {
|
||||
if prompt[j].Role != fantasy.MessageRoleTool {
|
||||
break
|
||||
}
|
||||
for _, part := range prompt[j].Content {
|
||||
tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
answered[tr.ToolCallID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if i+1 < j {
|
||||
// Preserve persisted tool result ordering and inject any
|
||||
// synthetic results after the existing contiguous tool messages.
|
||||
result = append(result, prompt[i+1:j]...)
|
||||
i = j - 1
|
||||
}
|
||||
|
||||
// Build synthetic results for any unanswered tool calls.
|
||||
var missing []fantasy.MessagePart
|
||||
for _, tc := range toolCalls {
|
||||
if _, ok := answered[tc.ToolCallID]; !ok {
|
||||
missing = append(missing, fantasy.ToolResultPart{
|
||||
ToolCallID: tc.ToolCallID,
|
||||
Output: fantasy.ToolResultOutputContentError{
|
||||
Error: xerrors.New("tool call was interrupted and did not receive a result"),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
result = append(result, fantasy.Message{
|
||||
Role: fantasy.MessageRoleTool,
|
||||
Content: missing,
|
||||
})
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func injectMissingToolUses(
|
||||
prompt []fantasy.Message,
|
||||
toolNameByCallID map[string]string,
|
||||
) []fantasy.Message {
|
||||
result := make([]fantasy.Message, 0, len(prompt))
|
||||
for _, msg := range prompt {
|
||||
if msg.Role != fantasy.MessageRoleTool {
|
||||
result = append(result, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
toolResults := make([]fantasy.ToolResultPart, 0, len(msg.Content))
|
||||
for _, part := range msg.Content {
|
||||
toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
toolResults = append(toolResults, toolResult)
|
||||
}
|
||||
if len(toolResults) == 0 {
|
||||
result = append(result, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk backwards through the result to find the nearest
|
||||
// preceding assistant message (skipping over other tool
|
||||
// messages that belong to the same batch of results).
|
||||
answeredByPrevious := make(map[string]struct{})
|
||||
for k := len(result) - 1; k >= 0; k-- {
|
||||
if result[k].Role == fantasy.MessageRoleAssistant {
|
||||
for _, toolCall := range ExtractToolCalls(result[k].Content) {
|
||||
toolCallID := sanitizeToolCallID(toolCall.ToolCallID)
|
||||
if toolCallID == "" {
|
||||
continue
|
||||
}
|
||||
answeredByPrevious[toolCallID] = struct{}{}
|
||||
}
|
||||
break
|
||||
}
|
||||
if result[k].Role != fantasy.MessageRoleTool {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
matchingResults := make([]fantasy.ToolResultPart, 0, len(toolResults))
|
||||
orphanResults := make([]fantasy.ToolResultPart, 0, len(toolResults))
|
||||
for _, toolResult := range toolResults {
|
||||
toolCallID := sanitizeToolCallID(toolResult.ToolCallID)
|
||||
if _, ok := answeredByPrevious[toolCallID]; ok {
|
||||
matchingResults = append(matchingResults, toolResult)
|
||||
continue
|
||||
}
|
||||
orphanResults = append(orphanResults, toolResult)
|
||||
}
|
||||
|
||||
if len(orphanResults) == 0 {
|
||||
result = append(result, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
syntheticToolUse := syntheticToolUseMessage(
|
||||
orphanResults,
|
||||
toolNameByCallID,
|
||||
)
|
||||
if len(syntheticToolUse.Content) == 0 {
|
||||
result = append(result, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(matchingResults) > 0 {
|
||||
result = append(result, toolMessageFromToolResultParts(matchingResults))
|
||||
}
|
||||
result = append(result, syntheticToolUse)
|
||||
result = append(result, toolMessageFromToolResultParts(orphanResults))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func toolMessageFromToolResultParts(results []fantasy.ToolResultPart) fantasy.Message {
|
||||
parts := make([]fantasy.MessagePart, 0, len(results))
|
||||
for _, result := range results {
|
||||
parts = append(parts, result)
|
||||
}
|
||||
return fantasy.Message{
|
||||
Role: fantasy.MessageRoleTool,
|
||||
Content: parts,
|
||||
}
|
||||
}
|
||||
|
||||
func syntheticToolUseMessage(
|
||||
toolResults []fantasy.ToolResultPart,
|
||||
toolNameByCallID map[string]string,
|
||||
) fantasy.Message {
|
||||
parts := make([]fantasy.MessagePart, 0, len(toolResults))
|
||||
seen := make(map[string]struct{}, len(toolResults))
|
||||
|
||||
for _, toolResult := range toolResults {
|
||||
toolCallID := sanitizeToolCallID(toolResult.ToolCallID)
|
||||
if toolCallID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[toolCallID]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
toolName := strings.TrimSpace(toolNameByCallID[toolCallID])
|
||||
if toolName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
seen[toolCallID] = struct{}{}
|
||||
parts = append(parts, fantasy.ToolCallPart{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Input: "{}",
|
||||
})
|
||||
}
|
||||
|
||||
return fantasy.Message{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: parts,
|
||||
}
|
||||
}
|
||||
|
||||
func parseSystemContent(raw pqtype.NullRawMessage) (string, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var content string
|
||||
if err := json.Unmarshal(raw.RawMessage, &content); err != nil {
|
||||
return "", xerrors.Errorf("parse system message content: %w", err)
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func sanitizeToolCallID(id string) string {
|
||||
if id == "" {
|
||||
return ""
|
||||
}
|
||||
return toolCallIDSanitizer.ReplaceAllString(id, "_")
|
||||
}
|
||||
|
||||
func marshalContentBlock(block fantasy.Content) (json.RawMessage, error) {
|
||||
encoded, err := json.Marshal(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
title, ok := reasoningTitleFromContent(block)
|
||||
if !ok || title == "" {
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data map[string]any `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(encoded, &envelope); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeReasoning)) {
|
||||
return encoded, nil
|
||||
}
|
||||
if envelope.Data == nil {
|
||||
envelope.Data = map[string]any{}
|
||||
}
|
||||
envelope.Data["title"] = title
|
||||
|
||||
encodedWithTitle, err := json.Marshal(envelope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encodedWithTitle, nil
|
||||
}
|
||||
|
||||
func reasoningTitleFromContent(block fantasy.Content) (string, bool) {
|
||||
switch value := block.(type) {
|
||||
case fantasy.ReasoningContent:
|
||||
return ReasoningTitleFromFirstLine(value.Text), true
|
||||
case *fantasy.ReasoningContent:
|
||||
if value == nil {
|
||||
return "", false
|
||||
}
|
||||
return ReasoningTitleFromFirstLine(value.Text), true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func reasoningSummaryTitle(metadata fantasy.ProviderMetadata) string {
|
||||
if len(metadata) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
reasoningMetadata := fantasyopenai.GetReasoningMetadata(
|
||||
fantasy.ProviderOptions(metadata),
|
||||
)
|
||||
if reasoningMetadata == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, summary := range reasoningMetadata.Summary {
|
||||
if title := compactReasoningSummaryTitle(summary); title != "" {
|
||||
return title
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func compactReasoningSummaryTitle(summary string) string {
|
||||
const maxWords = 8
|
||||
const maxRunes = 80
|
||||
|
||||
summary = strings.TrimSpace(summary)
|
||||
if summary == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
summary = strings.Trim(summary, "\"'`")
|
||||
summary = reasoningSummaryHeadline(summary)
|
||||
words := strings.Fields(summary)
|
||||
if len(words) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
truncated := false
|
||||
if len(words) > maxWords {
|
||||
words = words[:maxWords]
|
||||
truncated = true
|
||||
}
|
||||
|
||||
title := strings.Join(words, " ")
|
||||
if truncated {
|
||||
title += "…"
|
||||
}
|
||||
return truncateRunes(title, maxRunes)
|
||||
}
|
||||
|
||||
func reasoningSummaryHeadline(summary string) string {
|
||||
summary = strings.TrimSpace(summary)
|
||||
if summary == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// OpenAI summary_text may be markdown like:
|
||||
// "**Title**\n\nLonger explanation ...".
|
||||
// Keep only the heading segment for UI titles.
|
||||
if idx := strings.Index(summary, "\n\n"); idx >= 0 {
|
||||
summary = summary[:idx]
|
||||
}
|
||||
|
||||
if idx := strings.IndexAny(summary, "\r\n"); idx >= 0 {
|
||||
summary = summary[:idx]
|
||||
}
|
||||
|
||||
summary = strings.TrimSpace(summary)
|
||||
if summary == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if strings.HasPrefix(summary, "**") {
|
||||
rest := summary[2:]
|
||||
if end := strings.Index(rest, "**"); end >= 0 {
|
||||
bold := strings.TrimSpace(rest[:end])
|
||||
if bold != "" {
|
||||
summary = bold
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(strings.Trim(summary, "\"'`"))
|
||||
}
|
||||
|
||||
func truncateRunes(value string, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
runes := []rune(value)
|
||||
if len(runes) <= maxLen {
|
||||
return value
|
||||
}
|
||||
|
||||
return string(runes[:maxLen])
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package chatprompt_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
expected: "{}",
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
input: "{\"command\":",
|
||||
expected: "{}",
|
||||
},
|
||||
{
|
||||
name: "non-object json",
|
||||
input: "[]",
|
||||
expected: "{}",
|
||||
},
|
||||
{
|
||||
name: "valid object json",
|
||||
input: "{\"command\":\"ls\"}",
|
||||
expected: "{\"command\":\"ls\"}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assistantContent, err := chatprompt.MarshalContent([]fantasy.Content{
|
||||
fantasy.ToolCallContent{
|
||||
ToolCallID: "toolu_01C4PqN6F2493pi7Ebag8Vg7",
|
||||
ToolName: "execute",
|
||||
Input: tc.input,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
toolContent, err := chatprompt.MarshalToolResult(
|
||||
"toolu_01C4PqN6F2493pi7Ebag8Vg7",
|
||||
"execute",
|
||||
json.RawMessage(`{"error":"tool call was interrupted before it produced a result"}`),
|
||||
true,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
prompt, err := chatprompt.ConvertMessages([]database.ChatMessage{
|
||||
{
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
Content: assistantContent,
|
||||
},
|
||||
{
|
||||
Role: string(fantasy.MessageRoleTool),
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
Content: toolContent,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, prompt, 2)
|
||||
|
||||
require.Equal(t, fantasy.MessageRoleAssistant, prompt[0].Role)
|
||||
toolCalls := chatprompt.ExtractToolCalls(prompt[0].Content)
|
||||
require.Len(t, toolCalls, 1)
|
||||
require.Equal(t, tc.expected, toolCalls[0].Input)
|
||||
require.Equal(t, "execute", toolCalls[0].ToolName)
|
||||
require.Equal(t, "toolu_01C4PqN6F2493pi7Ebag8Vg7", toolCalls[0].ToolCallID)
|
||||
|
||||
require.Equal(t, fantasy.MessageRoleTool, prompt[1].Role)
|
||||
})
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,191 @@
|
||||
package chatprovider_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
fantasyopenrouter "charm.land/fantasy/providers/openrouter"
|
||||
fantasyvercel "charm.land/fantasy/providers/vercel"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestReasoningEffortFromChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
provider string
|
||||
input *string
|
||||
want *string
|
||||
}{
|
||||
{
|
||||
name: "OpenAICaseInsensitive",
|
||||
provider: "openai",
|
||||
input: stringPtr(" HIGH "),
|
||||
want: stringPtr(string(fantasyopenai.ReasoningEffortHigh)),
|
||||
},
|
||||
{
|
||||
name: "AnthropicEffort",
|
||||
provider: "anthropic",
|
||||
input: stringPtr("max"),
|
||||
want: stringPtr(string(fantasyanthropic.EffortMax)),
|
||||
},
|
||||
{
|
||||
name: "OpenRouterEffort",
|
||||
provider: "openrouter",
|
||||
input: stringPtr("medium"),
|
||||
want: stringPtr(string(fantasyopenrouter.ReasoningEffortMedium)),
|
||||
},
|
||||
{
|
||||
name: "VercelEffort",
|
||||
provider: "vercel",
|
||||
input: stringPtr("xhigh"),
|
||||
want: stringPtr(string(fantasyvercel.ReasoningEffortXHigh)),
|
||||
},
|
||||
{
|
||||
name: "InvalidEffortReturnsNil",
|
||||
provider: "openai",
|
||||
input: stringPtr("unknown"),
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "UnsupportedProviderReturnsNil",
|
||||
provider: "bedrock",
|
||||
input: stringPtr("high"),
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "NilInputReturnsNil",
|
||||
provider: "openai",
|
||||
input: nil,
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := chatprovider.ReasoningEffortFromChat(tt.provider, tt.input)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
options := &codersdk.ChatModelProviderOptions{
|
||||
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
|
||||
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
|
||||
Enabled: boolPtr(true),
|
||||
},
|
||||
Provider: &codersdk.ChatModelOpenRouterProvider{
|
||||
Order: []string{"openai"},
|
||||
},
|
||||
},
|
||||
}
|
||||
defaults := &codersdk.ChatModelProviderOptions{
|
||||
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
|
||||
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
|
||||
Enabled: boolPtr(false),
|
||||
Exclude: boolPtr(true),
|
||||
MaxTokens: int64Ptr(123),
|
||||
Effort: stringPtr("high"),
|
||||
},
|
||||
IncludeUsage: boolPtr(true),
|
||||
Provider: &codersdk.ChatModelOpenRouterProvider{
|
||||
Order: []string{"anthropic"},
|
||||
AllowFallbacks: boolPtr(true),
|
||||
RequireParameters: boolPtr(false),
|
||||
DataCollection: stringPtr("allow"),
|
||||
Only: []string{"openai"},
|
||||
Ignore: []string{"foo"},
|
||||
Quantizations: []string{"int8"},
|
||||
Sort: stringPtr("latency"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chatprovider.MergeMissingProviderOptions(&options, defaults)
|
||||
|
||||
require.NotNil(t, options)
|
||||
require.NotNil(t, options.OpenRouter)
|
||||
require.NotNil(t, options.OpenRouter.Reasoning)
|
||||
require.True(t, *options.OpenRouter.Reasoning.Enabled)
|
||||
require.Equal(t, true, *options.OpenRouter.Reasoning.Exclude)
|
||||
require.EqualValues(t, 123, *options.OpenRouter.Reasoning.MaxTokens)
|
||||
require.Equal(t, "high", *options.OpenRouter.Reasoning.Effort)
|
||||
require.NotNil(t, options.OpenRouter.IncludeUsage)
|
||||
require.True(t, *options.OpenRouter.IncludeUsage)
|
||||
|
||||
require.NotNil(t, options.OpenRouter.Provider)
|
||||
require.Equal(t, []string{"openai"}, options.OpenRouter.Provider.Order)
|
||||
require.NotNil(t, options.OpenRouter.Provider.AllowFallbacks)
|
||||
require.True(t, *options.OpenRouter.Provider.AllowFallbacks)
|
||||
require.NotNil(t, options.OpenRouter.Provider.RequireParameters)
|
||||
require.False(t, *options.OpenRouter.Provider.RequireParameters)
|
||||
require.Equal(t, "allow", *options.OpenRouter.Provider.DataCollection)
|
||||
require.Equal(t, []string{"openai"}, options.OpenRouter.Provider.Only)
|
||||
require.Equal(t, []string{"foo"}, options.OpenRouter.Provider.Ignore)
|
||||
require.Equal(t, []string{"int8"}, options.OpenRouter.Provider.Quantizations)
|
||||
require.Equal(t, "latency", *options.OpenRouter.Provider.Sort)
|
||||
}
|
||||
|
||||
func TestMergeMissingCallConfig_FillsUnsetFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dst := codersdk.ChatModelCallConfig{
|
||||
Temperature: float64Ptr(0.2),
|
||||
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
||||
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
||||
User: stringPtr("alice"),
|
||||
},
|
||||
},
|
||||
}
|
||||
defaults := codersdk.ChatModelCallConfig{
|
||||
MaxOutputTokens: int64Ptr(512),
|
||||
Temperature: float64Ptr(0.9),
|
||||
TopP: float64Ptr(0.8),
|
||||
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
||||
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
||||
User: stringPtr("bob"),
|
||||
ReasoningEffort: stringPtr("medium"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chatprovider.MergeMissingCallConfig(&dst, defaults)
|
||||
|
||||
require.NotNil(t, dst.MaxOutputTokens)
|
||||
require.EqualValues(t, 512, *dst.MaxOutputTokens)
|
||||
require.NotNil(t, dst.Temperature)
|
||||
require.Equal(t, 0.2, *dst.Temperature)
|
||||
require.NotNil(t, dst.TopP)
|
||||
require.Equal(t, 0.8, *dst.TopP)
|
||||
require.NotNil(t, dst.ProviderOptions)
|
||||
require.NotNil(t, dst.ProviderOptions.OpenAI)
|
||||
require.Equal(t, "alice", *dst.ProviderOptions.OpenAI.User)
|
||||
require.Equal(t, "medium", *dst.ProviderOptions.OpenAI.ReasoningEffort)
|
||||
}
|
||||
|
||||
func stringPtr(value string) *string {
|
||||
return &value
|
||||
}
|
||||
|
||||
func boolPtr(value bool) *bool {
|
||||
return &value
|
||||
}
|
||||
|
||||
func int64Ptr(value int64) *int64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
func float64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
@@ -0,0 +1,403 @@
|
||||
package chattest
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AnthropicHandler handles Anthropic API requests and returns a response.
|
||||
type AnthropicHandler func(req *AnthropicRequest) AnthropicResponse
|
||||
|
||||
// AnthropicResponse represents a response to an Anthropic request.
|
||||
// Either StreamingChunks or Response should be set, not both.
|
||||
type AnthropicResponse struct {
|
||||
StreamingChunks <-chan AnthropicChunk
|
||||
Response *AnthropicMessage
|
||||
}
|
||||
|
||||
// AnthropicRequest represents an Anthropic messages request.
|
||||
type AnthropicRequest struct {
|
||||
*http.Request // Embed http.Request
|
||||
Model string `json:"model"`
|
||||
Messages []AnthropicRequestMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
// TODO: encoding/json ignores inline tags. Add custom UnmarshalJSON to capture unknown keys.
|
||||
Options map[string]interface{} `json:",inline"` //nolint:revive
|
||||
}
|
||||
|
||||
// AnthropicRequestMessage represents a message in an Anthropic request.
|
||||
// Content may be either a string or a structured content array.
|
||||
type AnthropicRequestMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
// AnthropicMessage represents a message in an Anthropic response.
|
||||
type AnthropicMessage struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
Usage AnthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicUsage represents usage information in an Anthropic response.
|
||||
type AnthropicUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// AnthropicChunk represents a streaming chunk from Anthropic.
|
||||
type AnthropicChunk struct {
|
||||
Type string `json:"type"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Message AnthropicChunkMessage `json:"message,omitempty"`
|
||||
ContentBlock AnthropicContentBlock `json:"content_block,omitempty"`
|
||||
Delta AnthropicDeltaBlock `json:"delta,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence *string `json:"stop_sequence,omitempty"`
|
||||
Usage AnthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicChunkMessage represents message metadata in a chunk.
|
||||
type AnthropicChunkMessage struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// AnthropicContentBlock represents a content block in a chunk.
|
||||
type AnthropicContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicDeltaBlock represents a delta block in a chunk.
|
||||
type AnthropicDeltaBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
}
|
||||
|
||||
// anthropicServer is a test server that mocks the Anthropic API.
|
||||
type anthropicServer struct {
|
||||
mu sync.Mutex
|
||||
server *httptest.Server
|
||||
handler AnthropicHandler
|
||||
request *AnthropicRequest
|
||||
}
|
||||
|
||||
// NewAnthropic creates a new Anthropic test server with a handler function.
|
||||
// The handler is called for each request and should return either a streaming
|
||||
// response (via channel) or a non-streaming response.
|
||||
// Returns the base URL of the server.
|
||||
func NewAnthropic(t testing.TB, handler AnthropicHandler) string {
|
||||
t.Helper()
|
||||
|
||||
s := &anthropicServer{
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("POST /v1/messages", s.handleMessages)
|
||||
|
||||
s.server = httptest.NewServer(mux)
|
||||
|
||||
t.Cleanup(func() {
|
||||
s.server.Close()
|
||||
})
|
||||
|
||||
return s.server.URL
|
||||
}
|
||||
|
||||
func (s *anthropicServer) handleMessages(w http.ResponseWriter, r *http.Request) {
|
||||
var req AnthropicRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
// Return a more detailed error for debugging
|
||||
http.Error(w, fmt.Sprintf("decode request: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
req.Request = r // Embed the original http.Request
|
||||
|
||||
s.mu.Lock()
|
||||
s.request = &req
|
||||
s.mu.Unlock()
|
||||
|
||||
resp := s.handler(&req)
|
||||
s.writeResponse(w, &req, resp)
|
||||
}
|
||||
|
||||
func (s *anthropicServer) writeResponse(w http.ResponseWriter, req *AnthropicRequest, resp AnthropicResponse) {
|
||||
hasStreaming := resp.StreamingChunks != nil
|
||||
hasNonStreaming := resp.Response != nil
|
||||
|
||||
switch {
|
||||
case hasStreaming && hasNonStreaming:
|
||||
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
|
||||
return
|
||||
case !hasStreaming && !hasNonStreaming:
|
||||
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
|
||||
return
|
||||
case req.Stream && !hasStreaming:
|
||||
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
|
||||
return
|
||||
case !req.Stream && !hasNonStreaming:
|
||||
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
|
||||
return
|
||||
case hasStreaming:
|
||||
s.writeStreamingResponse(w, resp.StreamingChunks)
|
||||
default:
|
||||
s.writeNonStreamingResponse(w, resp.Response)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *anthropicServer) writeStreamingResponse(w http.ResponseWriter, chunks <-chan AnthropicChunk) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("anthropic-version", "2023-06-01")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
for chunk := range chunks {
|
||||
chunkData := make(map[string]interface{})
|
||||
chunkData["type"] = chunk.Type
|
||||
|
||||
switch chunk.Type {
|
||||
case "message_start":
|
||||
chunkData["message"] = chunk.Message
|
||||
case "content_block_start":
|
||||
chunkData["index"] = chunk.Index
|
||||
chunkData["content_block"] = chunk.ContentBlock
|
||||
case "content_block_delta":
|
||||
chunkData["index"] = chunk.Index
|
||||
chunkData["delta"] = chunk.Delta
|
||||
case "content_block_stop":
|
||||
chunkData["index"] = chunk.Index
|
||||
case "message_delta":
|
||||
chunkData["delta"] = map[string]interface{}{
|
||||
"stop_reason": chunk.StopReason,
|
||||
"stop_sequence": chunk.StopSequence,
|
||||
}
|
||||
chunkData["usage"] = chunk.Usage
|
||||
case "message_stop":
|
||||
// No additional fields
|
||||
}
|
||||
|
||||
chunkBytes, err := json.Marshal(chunkData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Send both event and data lines to match Anthropic API format
|
||||
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", chunk.Type, chunkBytes); err != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *anthropicServer) writeNonStreamingResponse(w http.ResponseWriter, resp *AnthropicMessage) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
response := map[string]interface{}{
|
||||
"id": resp.ID,
|
||||
"type": resp.Type,
|
||||
"role": resp.Role,
|
||||
"model": resp.Model,
|
||||
"content": []map[string]interface{}{
|
||||
{
|
||||
"type": "text",
|
||||
"text": resp.Content,
|
||||
},
|
||||
},
|
||||
"stop_reason": resp.StopReason,
|
||||
"usage": resp.Usage,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("anthropic-version", "2023-06-01")
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// AnthropicStreamingResponse creates a streaming response from chunks.
|
||||
func AnthropicStreamingResponse(chunks ...AnthropicChunk) AnthropicResponse {
|
||||
ch := make(chan AnthropicChunk, len(chunks))
|
||||
go func() {
|
||||
for _, chunk := range chunks {
|
||||
ch <- chunk
|
||||
}
|
||||
close(ch)
|
||||
}()
|
||||
return AnthropicResponse{StreamingChunks: ch}
|
||||
}
|
||||
|
||||
// AnthropicNonStreamingResponse creates a non-streaming response with the given text.
|
||||
func AnthropicNonStreamingResponse(text string) AnthropicResponse {
|
||||
return AnthropicResponse{
|
||||
Response: &AnthropicMessage{
|
||||
ID: fmt.Sprintf("msg-%s", uuid.New().String()[:8]),
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Content: text,
|
||||
Model: "claude-3-opus-20240229",
|
||||
StopReason: "end_turn",
|
||||
Usage: AnthropicUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// AnthropicTextChunks creates a complete streaming response with text deltas.
|
||||
// Takes text deltas and creates all required chunks (message_start,
|
||||
// content_block_start, content_block_delta for each delta,
|
||||
// content_block_stop, message_delta, message_stop).
|
||||
func AnthropicTextChunks(deltas ...string) []AnthropicChunk {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8])
|
||||
model := "claude-3-opus-20240229"
|
||||
|
||||
chunks := []AnthropicChunk{
|
||||
{
|
||||
Type: "message_start",
|
||||
Message: AnthropicChunkMessage{
|
||||
ID: messageID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: model,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "content_block_start",
|
||||
Index: 0,
|
||||
ContentBlock: AnthropicContentBlock{
|
||||
Type: "text",
|
||||
Text: "", // According to Anthropic API spec, text should be empty in content_block_start
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add a delta chunk for each delta
|
||||
for _, delta := range deltas {
|
||||
chunks = append(chunks, AnthropicChunk{
|
||||
Type: "content_block_delta",
|
||||
Index: 0,
|
||||
Delta: AnthropicDeltaBlock{
|
||||
Type: "text_delta",
|
||||
Text: delta,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
chunks = append(chunks,
|
||||
AnthropicChunk{
|
||||
Type: "content_block_stop",
|
||||
Index: 0,
|
||||
},
|
||||
AnthropicChunk{
|
||||
Type: "message_delta",
|
||||
StopReason: "end_turn",
|
||||
Usage: AnthropicUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
},
|
||||
},
|
||||
AnthropicChunk{
|
||||
Type: "message_stop",
|
||||
},
|
||||
)
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// AnthropicToolCallChunks creates a complete streaming response for a tool call.
|
||||
// Input JSON can be split across multiple deltas, matching Anthropic's
|
||||
// input_json_delta streaming behavior.
|
||||
func AnthropicToolCallChunks(toolName string, inputJSONDeltas ...string) []AnthropicChunk {
|
||||
if len(inputJSONDeltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
if toolName == "" {
|
||||
toolName = "tool"
|
||||
}
|
||||
|
||||
messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8])
|
||||
model := "claude-3-opus-20240229"
|
||||
toolCallID := fmt.Sprintf("toolu_%s", uuid.New().String()[:8])
|
||||
|
||||
chunks := []AnthropicChunk{
|
||||
{
|
||||
Type: "message_start",
|
||||
Message: AnthropicChunkMessage{
|
||||
ID: messageID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: model,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "content_block_start",
|
||||
Index: 0,
|
||||
ContentBlock: AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: toolCallID,
|
||||
Name: toolName,
|
||||
Input: json.RawMessage("{}"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, delta := range inputJSONDeltas {
|
||||
chunks = append(chunks, AnthropicChunk{
|
||||
Type: "content_block_delta",
|
||||
Index: 0,
|
||||
Delta: AnthropicDeltaBlock{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: delta,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
chunks = append(chunks,
|
||||
AnthropicChunk{
|
||||
Type: "content_block_stop",
|
||||
Index: 0,
|
||||
},
|
||||
AnthropicChunk{
|
||||
Type: "message_delta",
|
||||
StopReason: "tool_use",
|
||||
Usage: AnthropicUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
},
|
||||
},
|
||||
AnthropicChunk{
|
||||
Type: "message_stop",
|
||||
},
|
||||
)
|
||||
|
||||
return chunks
|
||||
}
|
||||
@@ -0,0 +1,221 @@
|
||||
package chattest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
||||
)
|
||||
|
||||
func TestAnthropic_Streaming(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
return chattest.AnthropicStreamingResponse(
|
||||
chattest.AnthropicTextChunks("Hello", " world", "!")...,
|
||||
)
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server
|
||||
client, err := fantasyanthropic.New(
|
||||
fantasyanthropic.WithAPIKey("test-key"),
|
||||
fantasyanthropic.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "claude-3-opus-20240229")
|
||||
require.NoError(t, err)
|
||||
|
||||
call := fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "Say hello"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stream, err := model.Stream(ctx, call)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedDeltas := []string{"Hello", " world", "!"}
|
||||
deltaIndex := 0
|
||||
|
||||
var allParts []fantasy.StreamPart
|
||||
for part := range stream {
|
||||
allParts = append(allParts, part)
|
||||
if part.Type == fantasy.StreamPartTypeTextDelta {
|
||||
require.Less(t, deltaIndex, len(expectedDeltas), "Received more deltas than expected")
|
||||
require.Equal(t, expectedDeltas[deltaIndex], part.Delta,
|
||||
"Delta at index %d should be %q, got %q", deltaIndex, expectedDeltas[deltaIndex], part.Delta)
|
||||
deltaIndex++
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, len(expectedDeltas), deltaIndex, "Expected %d deltas, got %d. Total parts received: %d", len(expectedDeltas), deltaIndex, len(allParts))
|
||||
}
|
||||
|
||||
func TestAnthropic_ToolCalls(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var requestCount atomic.Int32
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
switch requestCount.Add(1) {
|
||||
case 1:
|
||||
return chattest.AnthropicStreamingResponse(
|
||||
chattest.AnthropicToolCallChunks("get_weather", `{"location":"San Francisco"}`)...,
|
||||
)
|
||||
default:
|
||||
return chattest.AnthropicStreamingResponse(
|
||||
chattest.AnthropicTextChunks("The weather in San Francisco is 72F.")...,
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
client, err := fantasyanthropic.New(
|
||||
fantasyanthropic.WithAPIKey("test-key"),
|
||||
fantasyanthropic.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
|
||||
require.NoError(t, err)
|
||||
|
||||
type weatherInput struct {
|
||||
Location string `json:"location"`
|
||||
}
|
||||
var toolCallCount atomic.Int32
|
||||
weatherTool := fantasy.NewAgentTool(
|
||||
"get_weather",
|
||||
"Get weather for a location.",
|
||||
func(ctx context.Context, input weatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
toolCallCount.Add(1)
|
||||
require.Equal(t, "San Francisco", input.Location)
|
||||
return fantasy.NewTextResponse("72F"), nil
|
||||
},
|
||||
)
|
||||
|
||||
agent := fantasy.NewAgent(
|
||||
model,
|
||||
fantasy.WithSystemPrompt("You are a helpful assistant."),
|
||||
fantasy.WithTools(weatherTool),
|
||||
)
|
||||
|
||||
result, err := agent.Stream(context.Background(), fantasy.AgentStreamCall{
|
||||
Prompt: "What's the weather in San Francisco?",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
require.Equal(t, int32(1), toolCallCount.Load(), "expected exactly one tool execution")
|
||||
require.GreaterOrEqual(t, requestCount.Load(), int32(2), "expected follow-up model call after tool execution")
|
||||
}
|
||||
|
||||
func TestAnthropic_NonStreaming(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
return chattest.AnthropicNonStreamingResponse("Response text")
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server
|
||||
client, err := fantasyanthropic.New(
|
||||
fantasyanthropic.WithAPIKey("test-key"),
|
||||
fantasyanthropic.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "claude-3-opus-20240229")
|
||||
require.NoError(t, err)
|
||||
|
||||
call := fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "Test message"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
response, err := model.Generate(ctx, call)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
}
|
||||
|
||||
func TestAnthropic_Streaming_MismatchReturnsErrorPart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
return chattest.AnthropicNonStreamingResponse("wrong response type")
|
||||
})
|
||||
|
||||
client, err := fantasyanthropic.New(
|
||||
fantasyanthropic.WithAPIKey("test-key"),
|
||||
fantasyanthropic.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
|
||||
require.NoError(t, err)
|
||||
|
||||
stream, err := model.Stream(context.Background(), fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var streamErr error
|
||||
for part := range stream {
|
||||
if part.Type == fantasy.StreamPartTypeError {
|
||||
streamErr = part.Error
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Error(t, streamErr)
|
||||
require.Contains(t, streamErr.Error(), "500 Internal Server Error")
|
||||
}
|
||||
|
||||
func TestAnthropic_NonStreaming_MismatchReturnsError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
return chattest.AnthropicStreamingResponse(
|
||||
chattest.AnthropicTextChunks("wrong", " response")...,
|
||||
)
|
||||
})
|
||||
|
||||
client, err := fantasyanthropic.New(
|
||||
fantasyanthropic.WithAPIKey("test-key"),
|
||||
fantasyanthropic.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(context.Background(), fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "500 Internal Server Error")
|
||||
}
|
||||
@@ -0,0 +1,458 @@
|
||||
package chattest
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// OpenAIHandler handles OpenAI API requests and returns a response.
|
||||
type OpenAIHandler func(req *OpenAIRequest) OpenAIResponse
|
||||
|
||||
// OpenAIResponse represents a response to an OpenAI request.
|
||||
// Either StreamingChunks or Response should be set, not both.
|
||||
type OpenAIResponse struct {
|
||||
StreamingChunks <-chan OpenAIChunk
|
||||
Response *OpenAICompletion
|
||||
}
|
||||
|
||||
// OpenAIRequest represents an OpenAI chat completion request.
|
||||
type OpenAIRequest struct {
|
||||
*http.Request
|
||||
Model string `json:"model"`
|
||||
Messages []OpenAIMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Prompt []interface{} `json:"prompt,omitempty"` // For responses API
|
||||
// TODO: encoding/json ignores inline tags. Add custom UnmarshalJSON to capture unknown keys.
|
||||
Options map[string]interface{} `json:",inline"` //nolint:revive
|
||||
}
|
||||
|
||||
// OpenAIMessage represents a message in an OpenAI request.
|
||||
type OpenAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// OpenAIToolCallFunction represents the function details in a tool call.
|
||||
type OpenAIToolCallFunction struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIToolCall represents a tool call in a streaming chunk or completion.
|
||||
type OpenAIToolCall struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function OpenAIToolCallFunction `json:"function,omitempty"`
|
||||
Index int `json:"index,omitempty"` // For streaming deltas
|
||||
}
|
||||
|
||||
// OpenAIChunkChoice represents a choice in a streaming chunk.
|
||||
type OpenAIChunkChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIChunk represents a streaming chunk from OpenAI.
|
||||
type OpenAIChunk struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []OpenAIChunkChoice `json:"choices"`
|
||||
}
|
||||
|
||||
// OpenAICompletionChoice represents a choice in a completion response.
|
||||
type OpenAICompletionChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message OpenAIMessage `json:"message"`
|
||||
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// OpenAICompletionUsage represents usage information in a completion response.
|
||||
type OpenAICompletionUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// OpenAICompletion represents a non-streaming OpenAI completion response.
|
||||
type OpenAICompletion struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []OpenAICompletionChoice `json:"choices"`
|
||||
Usage OpenAICompletionUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// openAIServer is a test server that mocks the OpenAI API.
|
||||
type openAIServer struct {
|
||||
mu sync.Mutex
|
||||
server *httptest.Server
|
||||
handler OpenAIHandler
|
||||
request *OpenAIRequest
|
||||
}
|
||||
|
||||
// NewOpenAI creates a new OpenAI test server with a handler function.
|
||||
// The handler is called for each request and should return either a streaming
|
||||
// response (via channel) or a non-streaming response.
|
||||
// Returns the base URL of the server.
|
||||
func NewOpenAI(t testing.TB, handler OpenAIHandler) string {
|
||||
t.Helper()
|
||||
|
||||
s := &openAIServer{
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("POST /chat/completions", s.handleChatCompletions)
|
||||
mux.HandleFunc("POST /responses", s.handleResponses)
|
||||
|
||||
s.server = httptest.NewServer(mux)
|
||||
|
||||
t.Cleanup(func() {
|
||||
s.server.Close()
|
||||
})
|
||||
|
||||
return s.server.URL
|
||||
}
|
||||
|
||||
func (s *openAIServer) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
var req OpenAIRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
req.Request = r
|
||||
|
||||
s.mu.Lock()
|
||||
s.request = &req
|
||||
s.mu.Unlock()
|
||||
|
||||
resp := s.handler(&req)
|
||||
s.writeChatCompletionsResponse(w, &req, resp)
|
||||
}
|
||||
|
||||
func (s *openAIServer) handleResponses(w http.ResponseWriter, r *http.Request) {
|
||||
var req OpenAIRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
req.Request = r
|
||||
|
||||
s.mu.Lock()
|
||||
s.request = &req
|
||||
s.mu.Unlock()
|
||||
|
||||
resp := s.handler(&req)
|
||||
s.writeResponsesAPIResponse(w, &req, resp)
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
|
||||
hasStreaming := resp.StreamingChunks != nil
|
||||
hasNonStreaming := resp.Response != nil
|
||||
|
||||
switch {
|
||||
case hasStreaming && hasNonStreaming:
|
||||
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
|
||||
return
|
||||
case !hasStreaming && !hasNonStreaming:
|
||||
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
|
||||
return
|
||||
case req.Stream && !hasStreaming:
|
||||
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
|
||||
return
|
||||
case !req.Stream && !hasNonStreaming:
|
||||
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
|
||||
return
|
||||
case hasStreaming:
|
||||
s.writeChatCompletionsStreaming(w, resp.StreamingChunks)
|
||||
default:
|
||||
s.writeChatCompletionsNonStreaming(w, resp.Response)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
|
||||
hasStreaming := resp.StreamingChunks != nil
|
||||
hasNonStreaming := resp.Response != nil
|
||||
|
||||
switch {
|
||||
case hasStreaming && hasNonStreaming:
|
||||
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
|
||||
return
|
||||
case !hasStreaming && !hasNonStreaming:
|
||||
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
|
||||
return
|
||||
case req.Stream && !hasStreaming:
|
||||
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
|
||||
return
|
||||
case !req.Stream && !hasNonStreaming:
|
||||
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
|
||||
return
|
||||
case hasStreaming:
|
||||
s.writeResponsesAPIStreaming(w, resp.StreamingChunks)
|
||||
default:
|
||||
s.writeResponsesAPINonStreaming(w, resp.Response)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeChatCompletionsStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
for chunk := range chunks {
|
||||
choicesData := make([]map[string]interface{}, len(chunk.Choices))
|
||||
for i, choice := range chunk.Choices {
|
||||
choiceData := map[string]interface{}{
|
||||
"index": choice.Index,
|
||||
}
|
||||
if choice.Delta != "" {
|
||||
choiceData["delta"] = map[string]interface{}{
|
||||
"content": choice.Delta,
|
||||
}
|
||||
}
|
||||
if len(choice.ToolCalls) > 0 {
|
||||
// Tool calls come in the delta
|
||||
if choiceData["delta"] == nil {
|
||||
choiceData["delta"] = make(map[string]interface{})
|
||||
}
|
||||
delta, ok := choiceData["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
delta = make(map[string]interface{})
|
||||
choiceData["delta"] = delta
|
||||
}
|
||||
delta["tool_calls"] = choice.ToolCalls
|
||||
}
|
||||
if choice.FinishReason != "" {
|
||||
choiceData["finish_reason"] = choice.FinishReason
|
||||
}
|
||||
choicesData[i] = choiceData
|
||||
}
|
||||
|
||||
chunkData := map[string]interface{}{
|
||||
"id": chunk.ID,
|
||||
"object": chunk.Object,
|
||||
"created": chunk.Created,
|
||||
"model": chunk.Model,
|
||||
"choices": choicesData,
|
||||
}
|
||||
|
||||
chunkBytes, err := json.Marshal(chunkData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeResponsesAPIStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
itemIDs := make(map[int]string)
|
||||
|
||||
for chunk := range chunks {
|
||||
// Responses API sends one event per choice
|
||||
for outputIndex, choice := range chunk.Choices {
|
||||
if choice.Index != 0 {
|
||||
outputIndex = choice.Index
|
||||
}
|
||||
itemID, found := itemIDs[outputIndex]
|
||||
if !found {
|
||||
itemID = fmt.Sprintf("msg_%s", uuid.New().String()[:8])
|
||||
itemIDs[outputIndex] = itemID
|
||||
}
|
||||
|
||||
chunkData := map[string]interface{}{
|
||||
"type": "response.output_text.delta",
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"created": chunk.Created,
|
||||
"model": chunk.Model,
|
||||
"content_index": 0,
|
||||
"delta": choice.Delta,
|
||||
}
|
||||
|
||||
chunkBytes, err := json.Marshal(chunkData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeChatCompletionsNonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeResponsesAPINonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
// Convert all choices to output format
|
||||
outputs := make([]map[string]interface{}, len(resp.Choices))
|
||||
for i, choice := range resp.Choices {
|
||||
outputs[i] = map[string]interface{}{
|
||||
"id": uuid.New().String(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []map[string]interface{}{
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": choice.Message.Content,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": resp.ID,
|
||||
"object": "response",
|
||||
"created": resp.Created,
|
||||
"model": resp.Model,
|
||||
"output": outputs,
|
||||
"usage": resp.Usage,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// OpenAIStreamingResponse creates a streaming response from chunks.
|
||||
func OpenAIStreamingResponse(chunks ...OpenAIChunk) OpenAIResponse {
|
||||
ch := make(chan OpenAIChunk, len(chunks))
|
||||
go func() {
|
||||
for _, chunk := range chunks {
|
||||
ch <- chunk
|
||||
}
|
||||
close(ch)
|
||||
}()
|
||||
return OpenAIResponse{StreamingChunks: ch}
|
||||
}
|
||||
|
||||
// OpenAINonStreamingResponse creates a non-streaming response with the given text.
|
||||
func OpenAINonStreamingResponse(text string) OpenAIResponse {
|
||||
return OpenAIResponse{
|
||||
Response: &OpenAICompletion{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8]),
|
||||
Object: "chat.completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: "gpt-4",
|
||||
Choices: []OpenAICompletionChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: OpenAIMessage{
|
||||
Role: "assistant",
|
||||
Content: text,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
Usage: OpenAICompletionUsage{
|
||||
PromptTokens: 10,
|
||||
CompletionTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAITextChunks creates streaming chunks with text deltas.
|
||||
// Each delta string becomes a separate chunk with a single choice.
|
||||
// Returns a slice of chunks, one per delta, with each choice having its index (0, 1, 2, ...).
|
||||
func OpenAITextChunks(deltas ...string) []OpenAIChunk {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
chunkID := fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8])
|
||||
now := time.Now().Unix()
|
||||
chunks := make([]OpenAIChunk, len(deltas))
|
||||
|
||||
for i, delta := range deltas {
|
||||
chunks[i] = OpenAIChunk{
|
||||
ID: chunkID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: now,
|
||||
Model: "gpt-4",
|
||||
Choices: []OpenAIChunkChoice{
|
||||
{
|
||||
Index: i,
|
||||
Delta: delta,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// OpenAIToolCallChunk creates a streaming chunk with a tool call.
|
||||
// Takes the tool name and arguments JSON string, creates a tool call for choice index 0.
|
||||
func OpenAIToolCallChunk(toolName, arguments string) OpenAIChunk {
|
||||
return OpenAIChunk{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8]),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: time.Now().Unix(),
|
||||
Model: "gpt-4",
|
||||
Choices: []OpenAIChunkChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ToolCalls: []OpenAIToolCall{
|
||||
{
|
||||
Index: 0,
|
||||
ID: fmt.Sprintf("call_%s", uuid.New().String()[:8]),
|
||||
Type: "function",
|
||||
Function: OpenAIToolCallFunction{
|
||||
Name: toolName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,367 @@
|
||||
package chattest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
||||
)
|
||||
|
||||
func TestOpenAI_Streaming(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
append(
|
||||
append(
|
||||
chattest.OpenAITextChunks("Hello", "Hi"),
|
||||
chattest.OpenAITextChunks(" world", " there")...,
|
||||
),
|
||||
chattest.OpenAITextChunks("!", "!")...,
|
||||
)...,
|
||||
)
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
call := fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "Say hello"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stream, err := model.Stream(ctx, call)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We expect chunks in order: one choice per chunk
|
||||
// So we get: "Hello" (choice 0), "Hi" (choice 1), " world" (choice 0), " there" (choice 1), "!" (choice 0), "!" (choice 1)
|
||||
expectedDeltas := []string{"Hello", "Hi", " world", " there", "!", "!"}
|
||||
deltaIndex := 0
|
||||
|
||||
for part := range stream {
|
||||
if part.Type == fantasy.StreamPartTypeTextDelta {
|
||||
// Verify we're getting deltas in the expected order
|
||||
require.Less(t, deltaIndex, len(expectedDeltas), "Received more deltas than expected")
|
||||
require.Equal(t, expectedDeltas[deltaIndex], part.Delta,
|
||||
"Delta at index %d should be %q, got %q", deltaIndex, expectedDeltas[deltaIndex], part.Delta)
|
||||
deltaIndex++
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we received all expected deltas
|
||||
require.Equal(t, len(expectedDeltas), deltaIndex, "Expected %d deltas, got %d", len(expectedDeltas), deltaIndex)
|
||||
}
|
||||
|
||||
func TestOpenAI_Streaming_ResponsesAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
append(
|
||||
append(
|
||||
chattest.OpenAITextChunks("First", "Second"),
|
||||
chattest.OpenAITextChunks(" output", " output")...,
|
||||
),
|
||||
chattest.OpenAITextChunks("!", "!")...,
|
||||
)...,
|
||||
)
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server (responses API)
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
fantasyopenai.WithUseResponsesAPI(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
call := fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "Say hello"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stream, err := model.Stream(ctx, call)
|
||||
require.NoError(t, err)
|
||||
|
||||
var parts []fantasy.StreamPart
|
||||
for part := range stream {
|
||||
parts = append(parts, part)
|
||||
}
|
||||
|
||||
// Verify we received the chunks in order
|
||||
require.Greater(t, len(parts), 0)
|
||||
|
||||
// Extract text deltas from parts and verify they match expected chunks in order
|
||||
// We expect: "First", " output", "!" for choice 0, and "Second", " output", "!" for choice 1
|
||||
var allDeltas []string
|
||||
for _, part := range parts {
|
||||
if part.Type == fantasy.StreamPartTypeTextDelta {
|
||||
allDeltas = append(allDeltas, part.Delta)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we received deltas (responses API may handle multiple choices differently)
|
||||
// If we got text deltas, verify the content
|
||||
if len(allDeltas) > 0 {
|
||||
allText := ""
|
||||
for _, delta := range allDeltas {
|
||||
allText += delta
|
||||
}
|
||||
require.Contains(t, allText, "First")
|
||||
require.Contains(t, allText, "Second")
|
||||
require.Contains(t, allText, "output")
|
||||
require.Contains(t, allText, "!")
|
||||
} else {
|
||||
// If no text deltas, at least verify we got some parts (may be different format)
|
||||
require.Greater(t, len(parts), 0, "Expected at least one stream part")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAI_NonStreaming_CompletionsAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAINonStreamingResponse("First response")
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server (completions API)
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
call := fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "Test message"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
response, err := model.Generate(ctx, call)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
}
|
||||
|
||||
func TestOpenAI_ToolCalls(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var requestCount atomic.Int32
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
switch requestCount.Add(1) {
|
||||
case 1:
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk("get_weather", `{"location":"San Francisco"}`),
|
||||
)
|
||||
default:
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("The weather in San Francisco is 72F.")...,
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
type weatherInput struct {
|
||||
Location string `json:"location"`
|
||||
}
|
||||
var toolCallCount atomic.Int32
|
||||
weatherTool := fantasy.NewAgentTool(
|
||||
"get_weather",
|
||||
"Get weather for a location.",
|
||||
func(ctx context.Context, input weatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
toolCallCount.Add(1)
|
||||
require.Equal(t, "San Francisco", input.Location)
|
||||
return fantasy.NewTextResponse("72F"), nil
|
||||
},
|
||||
)
|
||||
|
||||
agent := fantasy.NewAgent(
|
||||
model,
|
||||
fantasy.WithSystemPrompt("You are a helpful assistant."),
|
||||
fantasy.WithTools(weatherTool),
|
||||
)
|
||||
|
||||
result, err := agent.Stream(ctx, fantasy.AgentStreamCall{
|
||||
Prompt: "What's the weather in San Francisco?",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, int32(1), toolCallCount.Load(), "expected exactly one tool execution")
|
||||
require.GreaterOrEqual(t, requestCount.Load(), int32(2), "expected follow-up model call after tool execution")
|
||||
}
|
||||
|
||||
func TestOpenAI_NonStreaming_ResponsesAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAINonStreamingResponse("First output")
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server (responses API)
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
fantasyopenai.WithUseResponsesAPI(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
call := fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "Test message"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
response, err := model.Generate(ctx, call)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
}
|
||||
|
||||
func TestOpenAI_Streaming_MismatchReturnsErrorPart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAINonStreamingResponse("wrong response type")
|
||||
})
|
||||
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := client.LanguageModel(context.Background(), "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
stream, err := model.Stream(context.Background(), fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var streamErr error
|
||||
for part := range stream {
|
||||
if part.Type == fantasy.StreamPartTypeError {
|
||||
streamErr = part.Error
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Error(t, streamErr)
|
||||
require.Contains(t, streamErr.Error(), "non-streaming response for streaming request")
|
||||
}
|
||||
|
||||
func TestOpenAI_NonStreaming_MismatchReturnsError_CompletionsAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("wrong response type")...)
|
||||
})
|
||||
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := client.LanguageModel(context.Background(), "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(context.Background(), fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "streaming response for non-streaming request")
|
||||
}
|
||||
|
||||
func TestOpenAI_NonStreaming_MismatchReturnsError_ResponsesAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("wrong response type")...)
|
||||
})
|
||||
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
fantasyopenai.WithUseResponsesAPI(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := client.LanguageModel(context.Background(), "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(context.Background(), fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "streaming response for non-streaming request")
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"unicode/utf8"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// toolResponse builds a fantasy.ToolResponse from a JSON-serializable
|
||||
// result payload.
|
||||
func toolResponse(result map[string]any) fantasy.ToolResponse {
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return fantasy.NewTextResponse("{}")
|
||||
}
|
||||
return fantasy.NewTextResponse(string(data))
|
||||
}
|
||||
|
||||
func truncateRunes(value string, maxLen int) string {
|
||||
if maxLen <= 0 || value == "" {
|
||||
return ""
|
||||
}
|
||||
if utf8.RuneCountInString(value) <= maxLen {
|
||||
return value
|
||||
}
|
||||
|
||||
runes := []rune(value)
|
||||
if maxLen > len(runes) {
|
||||
maxLen = len(runes)
|
||||
}
|
||||
return string(runes[:maxLen])
|
||||
}
|
||||
@@ -0,0 +1,426 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/util/namesgenerator"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
// buildPollInterval is how often we check if the workspace
|
||||
// build has completed.
|
||||
buildPollInterval = 2 * time.Second
|
||||
// buildTimeout is the maximum time to wait for a workspace
|
||||
// build to complete before giving up.
|
||||
buildTimeout = 10 * time.Minute
|
||||
// agentConnectTimeout is the maximum time to wait for the
|
||||
// workspace agent to become reachable after a successful build.
|
||||
agentConnectTimeout = 2 * time.Minute
|
||||
// agentRetryInterval is how often we retry connecting to the
|
||||
// workspace agent.
|
||||
agentRetryInterval = 2 * time.Second
|
||||
// agentAttemptTimeout is the timeout for a single connection
|
||||
// attempt to the workspace agent during the retry loop.
|
||||
agentAttemptTimeout = 5 * time.Second
|
||||
// agentPingTimeout is the timeout for a single agent ping
|
||||
// when checking whether an existing workspace is alive.
|
||||
agentPingTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// CreateWorkspaceFn creates a workspace for the given owner.
|
||||
type CreateWorkspaceFn func(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
req codersdk.CreateWorkspaceRequest,
|
||||
) (codersdk.Workspace, error)
|
||||
|
||||
// AgentConnFunc provides access to workspace agent connections.
|
||||
type AgentConnFunc func(
|
||||
ctx context.Context,
|
||||
agentID uuid.UUID,
|
||||
) (workspacesdk.AgentConn, func(), error)
|
||||
|
||||
// CreateWorkspaceOptions configures the create_workspace tool.
|
||||
type CreateWorkspaceOptions struct {
|
||||
DB database.Store
|
||||
OwnerID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
CreateFn CreateWorkspaceFn
|
||||
AgentConnFn AgentConnFunc
|
||||
WorkspaceMu *sync.Mutex
|
||||
}
|
||||
|
||||
type createWorkspaceArgs struct {
|
||||
TemplateID string `json:"template_id"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Parameters map[string]string `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
// CreateWorkspace returns a tool that creates a new workspace from a
|
||||
// template. The tool is idempotent: if the chat already has a
|
||||
// workspace that is building or running, it returns the existing
|
||||
// workspace instead of creating a new one. A mutex prevents parallel
|
||||
// calls from creating duplicate workspaces.
|
||||
func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"create_workspace",
|
||||
"Create a new workspace from a template. Requires a "+
|
||||
"template_id (from list_templates). Optionally provide "+
|
||||
"a name and parameter values (from read_template). "+
|
||||
"If no name is given, one will be generated. "+
|
||||
"This tool is idempotent — if the chat already has a "+
|
||||
"workspace that is building or running, the existing "+
|
||||
"workspace is returned.",
|
||||
func(ctx context.Context, args createWorkspaceArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.CreateFn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace creator is not configured"), nil
|
||||
}
|
||||
|
||||
templateIDStr := strings.TrimSpace(args.TemplateID)
|
||||
if templateIDStr == "" {
|
||||
return fantasy.NewTextErrorResponse("template_id is required; use list_templates to find one"), nil
|
||||
}
|
||||
templateID, err := uuid.Parse(templateIDStr)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("invalid template_id: %w", err).Error(),
|
||||
), nil
|
||||
}
|
||||
|
||||
// Serialize workspace creation to prevent parallel
|
||||
// tool calls from creating duplicate workspaces.
|
||||
if options.WorkspaceMu != nil {
|
||||
options.WorkspaceMu.Lock()
|
||||
defer options.WorkspaceMu.Unlock()
|
||||
}
|
||||
|
||||
// Check for an existing workspace on the chat.
|
||||
if options.DB != nil && options.ChatID != uuid.Nil {
|
||||
existing, done, existErr := checkExistingWorkspace(
|
||||
ctx, options.DB, options.ChatID,
|
||||
options.AgentConnFn,
|
||||
)
|
||||
if existErr != nil {
|
||||
return fantasy.NewTextErrorResponse(existErr.Error()), nil
|
||||
}
|
||||
if done {
|
||||
return toolResponse(existing), nil
|
||||
}
|
||||
}
|
||||
|
||||
ownerID := options.OwnerID
|
||||
|
||||
// Set up dbauthz context for DB lookups.
|
||||
if options.DB != nil {
|
||||
ownerCtx, ownerErr := asOwner(ctx, options.DB, ownerID)
|
||||
if ownerErr != nil {
|
||||
return fantasy.NewTextErrorResponse(ownerErr.Error()), nil
|
||||
}
|
||||
ctx = ownerCtx
|
||||
}
|
||||
|
||||
createReq := codersdk.CreateWorkspaceRequest{
|
||||
TemplateID: templateID,
|
||||
}
|
||||
|
||||
// Resolve workspace name.
|
||||
name := strings.TrimSpace(args.Name)
|
||||
if name == "" {
|
||||
seed := "workspace"
|
||||
if options.DB != nil {
|
||||
if t, lookupErr := options.DB.GetTemplateByID(ctx, templateID); lookupErr == nil {
|
||||
seed = t.Name
|
||||
}
|
||||
}
|
||||
name = generatedWorkspaceName(seed)
|
||||
} else if err := codersdk.NameValid(name); err != nil {
|
||||
name = generatedWorkspaceName(name)
|
||||
}
|
||||
createReq.Name = name
|
||||
|
||||
// Map parameters.
|
||||
for k, v := range args.Parameters {
|
||||
createReq.RichParameterValues = append(
|
||||
createReq.RichParameterValues,
|
||||
codersdk.WorkspaceBuildParameter{Name: k, Value: v},
|
||||
)
|
||||
}
|
||||
|
||||
workspace, err := options.CreateFn(ctx, ownerID, createReq)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// Wait for the build to complete and the agent to
|
||||
// come online so subsequent tools can use the
|
||||
// workspace immediately.
|
||||
if options.DB != nil {
|
||||
if err := waitForBuild(ctx, options.DB, workspace.ID); err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("workspace build failed: %w", err).Error(),
|
||||
), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Look up the first agent so we can link it to the chat.
|
||||
workspaceAgentID := uuid.Nil
|
||||
if options.DB != nil {
|
||||
agents, agentErr := options.DB.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID)
|
||||
if agentErr == nil && len(agents) > 0 {
|
||||
workspaceAgentID = agents[0].ID
|
||||
}
|
||||
}
|
||||
|
||||
// Persist workspace + agent association on the chat.
|
||||
if options.DB != nil && options.ChatID != uuid.Nil {
|
||||
_, _ = options.DB.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{
|
||||
ID: options.ChatID,
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspace.ID,
|
||||
Valid: true,
|
||||
},
|
||||
WorkspaceAgentID: uuid.NullUUID{
|
||||
UUID: workspaceAgentID,
|
||||
Valid: workspaceAgentID != uuid.Nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Wait for the agent to come online.
|
||||
if workspaceAgentID != uuid.Nil && options.AgentConnFn != nil {
|
||||
if err := waitForAgent(ctx, options.AgentConnFn, workspaceAgentID); err != nil {
|
||||
// Non-fatal: the workspace was created
|
||||
// successfully, the agent just isn't ready
|
||||
// yet. The model can retry.
|
||||
return toolResponse(map[string]any{
|
||||
"created": true,
|
||||
"workspace_name": workspace.FullName(),
|
||||
"agent_status": "not_ready",
|
||||
"agent_error": err.Error(),
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
|
||||
return toolResponse(map[string]any{
|
||||
"created": true,
|
||||
"workspace_name": workspace.FullName(),
|
||||
}), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// checkExistingWorkspace checks whether the chat already has a usable
|
||||
// workspace. Returns the result map and true if the caller should
|
||||
// return early (workspace exists and is alive or building). Returns
|
||||
// false if the caller should proceed with creation (workspace is dead
|
||||
// or missing).
|
||||
func checkExistingWorkspace(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
agentConnFn AgentConnFunc,
|
||||
) (map[string]any, bool, error) {
|
||||
chat, err := db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, false, xerrors.Errorf("load chat: %w", err)
|
||||
}
|
||||
if !chat.WorkspaceID.Valid {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// Check if workspace still exists.
|
||||
ws, err := db.GetWorkspaceByID(ctx, chat.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
// Workspace was deleted — allow creation.
|
||||
return nil, false, nil
|
||||
}
|
||||
return nil, false, xerrors.Errorf("load workspace: %w", err)
|
||||
}
|
||||
|
||||
// Check the latest build status.
|
||||
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID)
|
||||
if err != nil {
|
||||
// Can't determine status — allow creation.
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
job, err := db.GetProvisionerJobByID(ctx, build.JobID)
|
||||
if err != nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
switch job.JobStatus {
|
||||
case database.ProvisionerJobStatusPending,
|
||||
database.ProvisionerJobStatusRunning:
|
||||
// Build is in progress — wait for it instead of
|
||||
// creating a new workspace.
|
||||
if err := waitForBuild(ctx, db, ws.ID); err != nil {
|
||||
return nil, false, xerrors.Errorf(
|
||||
"existing workspace build failed: %w", err,
|
||||
)
|
||||
}
|
||||
return map[string]any{
|
||||
"created": false,
|
||||
"workspace_name": ws.Name,
|
||||
"status": "already_exists",
|
||||
"message": "workspace was already being built and is now ready",
|
||||
}, true, nil
|
||||
|
||||
case database.ProvisionerJobStatusSucceeded:
|
||||
// Build succeeded — check if agent is reachable.
|
||||
if chat.WorkspaceAgentID.Valid && agentConnFn != nil {
|
||||
pingCtx, cancel := context.WithTimeout(
|
||||
ctx, agentPingTimeout,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
conn, release, connErr := agentConnFn(
|
||||
pingCtx, chat.WorkspaceAgentID.UUID,
|
||||
)
|
||||
if connErr == nil {
|
||||
release()
|
||||
_ = conn
|
||||
return map[string]any{
|
||||
"created": false,
|
||||
"workspace_name": ws.Name,
|
||||
"status": "already_exists",
|
||||
"message": "workspace is already running and reachable",
|
||||
}, true, nil
|
||||
}
|
||||
// Agent unreachable — workspace is dead, allow
|
||||
// creation.
|
||||
}
|
||||
// No agent ID or no conn func — allow creation.
|
||||
return nil, false, nil
|
||||
|
||||
default:
|
||||
// Failed, canceled, etc — allow creation.
|
||||
return nil, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// waitForBuild polls the workspace's latest build until it
|
||||
// completes or the context expires.
|
||||
func waitForBuild(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
workspaceID uuid.UUID,
|
||||
) error {
|
||||
buildCtx, cancel := context.WithTimeout(ctx, buildTimeout)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(buildPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(
|
||||
buildCtx, workspaceID,
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get latest build: %w", err)
|
||||
}
|
||||
|
||||
job, err := db.GetProvisionerJobByID(buildCtx, build.JobID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get provisioner job: %w", err)
|
||||
}
|
||||
|
||||
switch job.JobStatus {
|
||||
case database.ProvisionerJobStatusSucceeded:
|
||||
return nil
|
||||
case database.ProvisionerJobStatusFailed:
|
||||
errMsg := "build failed"
|
||||
if job.Error.Valid {
|
||||
errMsg = job.Error.String
|
||||
}
|
||||
return xerrors.New(errMsg)
|
||||
case database.ProvisionerJobStatusCanceled:
|
||||
return xerrors.New("build was canceled")
|
||||
case database.ProvisionerJobStatusPending,
|
||||
database.ProvisionerJobStatusRunning,
|
||||
database.ProvisionerJobStatusCanceling:
|
||||
// Still in progress — keep waiting.
|
||||
default:
|
||||
return xerrors.Errorf("unexpected job status: %s", job.JobStatus)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-buildCtx.Done():
|
||||
return xerrors.Errorf(
|
||||
"timed out waiting for workspace build: %w",
|
||||
buildCtx.Err(),
|
||||
)
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// waitForAgent retries connecting to the workspace agent until it
|
||||
// succeeds or the timeout expires.
|
||||
func waitForAgent(
|
||||
ctx context.Context,
|
||||
agentConnFn AgentConnFunc,
|
||||
agentID uuid.UUID,
|
||||
) error {
|
||||
agentCtx, cancel := context.WithTimeout(ctx, agentConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(agentRetryInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastErr error
|
||||
for {
|
||||
attemptCtx, attemptCancel := context.WithTimeout(agentCtx, agentAttemptTimeout)
|
||||
conn, release, err := agentConnFn(attemptCtx, agentID)
|
||||
attemptCancel()
|
||||
if err == nil {
|
||||
release()
|
||||
_ = conn
|
||||
return nil
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
select {
|
||||
case <-agentCtx.Done():
|
||||
return xerrors.Errorf(
|
||||
"timed out waiting for workspace agent: %w",
|
||||
lastErr,
|
||||
)
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func generatedWorkspaceName(seed string) string {
|
||||
base := codersdk.UsernameFrom(strings.TrimSpace(strings.ToLower(seed)))
|
||||
if strings.TrimSpace(base) == "" {
|
||||
base = "workspace"
|
||||
}
|
||||
|
||||
suffix := strings.ReplaceAll(uuid.NewString(), "-", "")[:4]
|
||||
if len(base) > 27 {
|
||||
base = strings.Trim(base[:27], "-")
|
||||
}
|
||||
if base == "" {
|
||||
base = "workspace"
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("%s-%s", base, suffix)
|
||||
if err := codersdk.NameValid(name); err == nil {
|
||||
return name
|
||||
}
|
||||
return namesgenerator.NameDigitWith("-")
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
type EditFilesOptions struct {
|
||||
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
|
||||
}
|
||||
|
||||
type EditFilesArgs struct {
|
||||
Files []workspacesdk.FileEdits `json:"files"`
|
||||
}
|
||||
|
||||
func EditFiles(options EditFilesOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"edit_files",
|
||||
"Perform search-and-replace edits on one or more files in the workspace."+
|
||||
" Each file can have multiple edits applied atomically.",
|
||||
func(ctx context.Context, args EditFilesArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
}
|
||||
conn, err := options.GetWorkspaceConn(ctx)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return executeEditFilesTool(ctx, conn, args)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func executeEditFilesTool(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
args EditFilesArgs,
|
||||
) (fantasy.ToolResponse, error) {
|
||||
if len(args.Files) == 0 {
|
||||
return fantasy.NewTextErrorResponse("files is required"), nil
|
||||
}
|
||||
|
||||
if err := conn.EditFiles(ctx, workspacesdk.FileEditRequest{Files: args.Files}); err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return toolResponse(map[string]any{"ok": true}), nil
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultExecuteTimeout = 60 * time.Second
|
||||
chatAgentEnvVar = "CODER_CHAT_AGENT"
|
||||
gitAuthRequiredPrefix = "CODER_GITAUTH_REQUIRED:"
|
||||
authRequiredResultReason = "authentication_required"
|
||||
)
|
||||
|
||||
type ExecuteOptions struct {
|
||||
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
|
||||
DefaultTimeout time.Duration
|
||||
}
|
||||
|
||||
type ExecuteArgs struct {
|
||||
Command string `json:"command"`
|
||||
TimeoutSeconds *int `json:"timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
func Execute(options ExecuteOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"execute",
|
||||
"Execute a shell command in the workspace.",
|
||||
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
}
|
||||
conn, err := options.GetWorkspaceConn(ctx)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return executeTool(ctx, conn, args, options.DefaultTimeout), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func executeTool(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
args ExecuteArgs,
|
||||
defaultTimeout time.Duration,
|
||||
) fantasy.ToolResponse {
|
||||
if args.Command == "" {
|
||||
return fantasy.NewTextErrorResponse("command is required")
|
||||
}
|
||||
|
||||
timeout := defaultTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = defaultExecuteTimeout
|
||||
}
|
||||
if args.TimeoutSeconds != nil {
|
||||
timeout = time.Duration(*args.TimeoutSeconds) * time.Second
|
||||
}
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
output, exitCode, err := runCommand(cmdCtx, conn, args.Command)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error())
|
||||
}
|
||||
return toolResponse(map[string]any{
|
||||
"output": output,
|
||||
"exit_code": exitCode,
|
||||
})
|
||||
}
|
||||
|
||||
func runCommand(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
command string,
|
||||
) (string, int, error) {
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
defer sshClient.Close()
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
defer session.Close()
|
||||
if err := session.Setenv(chatAgentEnvVar, "true"); err != nil {
|
||||
return "", 0, xerrors.Errorf("set %s: %w", chatAgentEnvVar, err)
|
||||
}
|
||||
|
||||
resultCh := make(chan struct {
|
||||
output string
|
||||
exitCode int
|
||||
err error
|
||||
}, 1)
|
||||
|
||||
go func() {
|
||||
output, err := session.CombinedOutput(command)
|
||||
exitCode := 0
|
||||
if err != nil {
|
||||
var exitErr *ssh.ExitError
|
||||
if xerrors.As(err, &exitErr) {
|
||||
exitCode = exitErr.ExitStatus()
|
||||
} else {
|
||||
exitCode = 1
|
||||
}
|
||||
}
|
||||
resultCh <- struct {
|
||||
output string
|
||||
exitCode int
|
||||
err error
|
||||
}{
|
||||
output: string(output),
|
||||
exitCode: exitCode,
|
||||
err: err,
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Close()
|
||||
return "", 0, ctx.Err()
|
||||
case result := <-resultCh:
|
||||
return result.output, result.exitCode, result.err
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
)
|
||||
|
||||
// ListTemplatesOptions configures the list_templates tool.
|
||||
type ListTemplatesOptions struct {
|
||||
DB database.Store
|
||||
OwnerID uuid.UUID
|
||||
}
|
||||
|
||||
type listTemplatesArgs struct {
|
||||
Query string `json:"query,omitempty"`
|
||||
}
|
||||
|
||||
// ListTemplates returns a tool that lists available workspace templates.
|
||||
// The agent uses this to discover templates before creating a workspace.
|
||||
func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"list_templates",
|
||||
"List available workspace templates. Optionally filter by a "+
|
||||
"search query matching template name or description. "+
|
||||
"Use this to find a template before creating a workspace.",
|
||||
func(ctx context.Context, args listTemplatesArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.DB == nil {
|
||||
return fantasy.NewTextErrorResponse("database is not configured"), nil
|
||||
}
|
||||
|
||||
ctx, err := asOwner(ctx, options.DB, options.OwnerID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
filterParams := database.GetTemplatesWithFilterParams{
|
||||
Deleted: false,
|
||||
Deprecated: sql.NullBool{
|
||||
Bool: false,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
query := strings.TrimSpace(args.Query)
|
||||
if query != "" {
|
||||
filterParams.FuzzyName = query
|
||||
}
|
||||
|
||||
templates, err := options.DB.GetTemplatesWithFilter(ctx, filterParams)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
items := make([]map[string]any, 0, len(templates))
|
||||
for _, t := range templates {
|
||||
item := map[string]any{
|
||||
"id": t.ID.String(),
|
||||
"name": t.Name,
|
||||
}
|
||||
if display := strings.TrimSpace(t.DisplayName); display != "" {
|
||||
item["display_name"] = display
|
||||
}
|
||||
if desc := strings.TrimSpace(t.Description); desc != "" {
|
||||
item["description"] = truncateRunes(desc, 200)
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
return toolResponse(map[string]any{
|
||||
"templates": items,
|
||||
"count": len(items),
|
||||
}), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// asOwner sets up a dbauthz context for the given owner so that
|
||||
// subsequent database calls are scoped to what that user can access.
|
||||
func asOwner(ctx context.Context, db database.Store, ownerID uuid.UUID) (context.Context, error) {
|
||||
actor, _, err := httpmw.UserRBACSubject(ctx, db, ownerID, rbac.ScopeAll)
|
||||
if err != nil {
|
||||
return ctx, xerrors.Errorf("load user authorization: %w", err)
|
||||
}
|
||||
return dbauthz.As(ctx, actor), nil
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
type ReadFileOptions struct {
|
||||
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
|
||||
}
|
||||
|
||||
type ReadFileArgs struct {
|
||||
Path string `json:"path"`
|
||||
Offset *int64 `json:"offset,omitempty"`
|
||||
Limit *int64 `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
func ReadFile(options ReadFileOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"read_file",
|
||||
"Read a file from the workspace.",
|
||||
func(ctx context.Context, args ReadFileArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
}
|
||||
conn, err := options.GetWorkspaceConn(ctx)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return executeReadFileTool(ctx, conn, args)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func executeReadFileTool(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
args ReadFileArgs,
|
||||
) (fantasy.ToolResponse, error) {
|
||||
if args.Path == "" {
|
||||
return fantasy.NewTextErrorResponse("path is required"), nil
|
||||
}
|
||||
|
||||
offset := int64(0)
|
||||
limit := int64(0)
|
||||
if args.Offset != nil {
|
||||
offset = *args.Offset
|
||||
}
|
||||
if args.Limit != nil {
|
||||
limit = *args.Limit
|
||||
}
|
||||
|
||||
reader, mimeType, err := conn.ReadFile(ctx, args.Path, offset, limit)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
data, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
return toolResponse(map[string]any{
|
||||
"content": string(data),
|
||||
"mime_type": mimeType,
|
||||
}), nil
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
// ReadTemplateOptions configures the read_template tool.
|
||||
type ReadTemplateOptions struct {
|
||||
DB database.Store
|
||||
OwnerID uuid.UUID
|
||||
}
|
||||
|
||||
type readTemplateArgs struct {
|
||||
TemplateID string `json:"template_id"`
|
||||
}
|
||||
|
||||
// ReadTemplate returns a tool that retrieves details about a specific
|
||||
// template, including its configurable rich parameters. The agent
|
||||
// uses this after list_templates and before create_workspace.
|
||||
func ReadTemplate(options ReadTemplateOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"read_template",
|
||||
"Get details about a workspace template, including its "+
|
||||
"configurable parameters. Use this after finding a "+
|
||||
"template with list_templates and before creating a "+
|
||||
"workspace with create_workspace.",
|
||||
func(ctx context.Context, args readTemplateArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.DB == nil {
|
||||
return fantasy.NewTextErrorResponse("database is not configured"), nil
|
||||
}
|
||||
|
||||
templateIDStr := strings.TrimSpace(args.TemplateID)
|
||||
if templateIDStr == "" {
|
||||
return fantasy.NewTextErrorResponse("template_id is required"), nil
|
||||
}
|
||||
templateID, err := uuid.Parse(templateIDStr)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("invalid template_id: %w", err).Error(),
|
||||
), nil
|
||||
}
|
||||
|
||||
ctx, err = asOwner(ctx, options.DB, options.OwnerID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
template, err := options.DB.GetTemplateByID(ctx, templateID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("template not found"), nil
|
||||
}
|
||||
|
||||
params, err := options.DB.GetTemplateVersionParameters(ctx, template.ActiveVersionID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("failed to get template parameters: %w", err).Error(),
|
||||
), nil
|
||||
}
|
||||
|
||||
templateInfo := map[string]any{
|
||||
"id": template.ID.String(),
|
||||
"name": template.Name,
|
||||
"active_version_id": template.ActiveVersionID.String(),
|
||||
}
|
||||
if display := strings.TrimSpace(template.DisplayName); display != "" {
|
||||
templateInfo["display_name"] = display
|
||||
}
|
||||
if desc := strings.TrimSpace(template.Description); desc != "" {
|
||||
templateInfo["description"] = desc
|
||||
}
|
||||
|
||||
paramList := make([]map[string]any, 0, len(params))
|
||||
for _, p := range params {
|
||||
param := map[string]any{
|
||||
"name": p.Name,
|
||||
"type": p.Type,
|
||||
"required": p.Required,
|
||||
}
|
||||
if display := strings.TrimSpace(p.DisplayName); display != "" {
|
||||
param["display_name"] = display
|
||||
}
|
||||
if desc := strings.TrimSpace(p.Description); desc != "" {
|
||||
param["description"] = truncateRunes(desc, 300)
|
||||
}
|
||||
if p.DefaultValue != "" {
|
||||
param["default"] = p.DefaultValue
|
||||
}
|
||||
if p.Mutable {
|
||||
param["mutable"] = true
|
||||
}
|
||||
if p.Ephemeral {
|
||||
param["ephemeral"] = true
|
||||
}
|
||||
if p.FormType != "" {
|
||||
param["form_type"] = string(p.FormType)
|
||||
}
|
||||
if len(p.Options) > 0 && string(p.Options) != "null" && string(p.Options) != "[]" {
|
||||
var opts []map[string]any
|
||||
if err := json.Unmarshal(p.Options, &opts); err == nil && len(opts) > 0 {
|
||||
param["options"] = opts
|
||||
}
|
||||
}
|
||||
if p.ValidationRegex != "" {
|
||||
param["validation_regex"] = p.ValidationRegex
|
||||
}
|
||||
if p.ValidationMin.Valid {
|
||||
param["validation_min"] = p.ValidationMin.Int32
|
||||
}
|
||||
if p.ValidationMax.Valid {
|
||||
param["validation_max"] = p.ValidationMax.Int32
|
||||
}
|
||||
|
||||
paramList = append(paramList, param)
|
||||
}
|
||||
|
||||
return toolResponse(map[string]any{
|
||||
"template": templateInfo,
|
||||
"parameters": paramList,
|
||||
}), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
type WriteFileOptions struct {
|
||||
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
|
||||
}
|
||||
|
||||
type WriteFileArgs struct {
|
||||
Path string `json:"path"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func WriteFile(options WriteFileOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"write_file",
|
||||
"Write a file to the workspace.",
|
||||
func(ctx context.Context, args WriteFileArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
}
|
||||
conn, err := options.GetWorkspaceConn(ctx)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return executeWriteFileTool(ctx, conn, args)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func executeWriteFileTool(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
args WriteFileArgs,
|
||||
) (fantasy.ToolResponse, error) {
|
||||
if args.Path == "" {
|
||||
return fantasy.NewTextErrorResponse("path is required"), nil
|
||||
}
|
||||
|
||||
if err := conn.WriteFile(ctx, args.Path, strings.NewReader(args.Content)); err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return toolResponse(map[string]any{"ok": true}), nil
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
coderHomeInstructionDir = ".coder"
|
||||
coderHomeInstructionFile = "AGENTS.md"
|
||||
maxInstructionFileBytes = 64 * 1024
|
||||
)
|
||||
|
||||
var markdownCommentPattern = regexp.MustCompile(`<!--[\s\S]*?-->`)
|
||||
|
||||
func readHomeInstructionFile(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
) (content string, sourcePath string, truncated bool, err error) {
|
||||
if conn == nil {
|
||||
return "", "", false, nil
|
||||
}
|
||||
|
||||
coderDir, err := conn.LS(ctx, "", workspacesdk.LSRequest{
|
||||
Path: []string{coderHomeInstructionDir},
|
||||
Relativity: workspacesdk.LSRelativityHome,
|
||||
})
|
||||
if err != nil {
|
||||
if isCodersdkStatusCode(err, http.StatusNotFound) {
|
||||
return "", "", false, nil
|
||||
}
|
||||
return "", "", false, xerrors.Errorf("list home instruction directory: %w", err)
|
||||
}
|
||||
|
||||
var filePath string
|
||||
for _, entry := range coderDir.Contents {
|
||||
if entry.IsDir {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(entry.Name), coderHomeInstructionFile) {
|
||||
filePath = strings.TrimSpace(entry.AbsolutePathString)
|
||||
break
|
||||
}
|
||||
}
|
||||
if filePath == "" {
|
||||
return "", "", false, nil
|
||||
}
|
||||
|
||||
reader, _, err := conn.ReadFile(
|
||||
ctx,
|
||||
filePath,
|
||||
0,
|
||||
maxInstructionFileBytes+1,
|
||||
)
|
||||
if err != nil {
|
||||
if isCodersdkStatusCode(err, http.StatusNotFound) {
|
||||
return "", "", false, nil
|
||||
}
|
||||
return "", "", false, xerrors.Errorf("read home instruction file: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
raw, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return "", "", false, xerrors.Errorf("read home instruction bytes: %w", err)
|
||||
}
|
||||
|
||||
truncated = int64(len(raw)) > maxInstructionFileBytes
|
||||
if truncated {
|
||||
raw = raw[:maxInstructionFileBytes]
|
||||
}
|
||||
|
||||
content = sanitizeInstructionMarkdown(string(raw))
|
||||
if content == "" {
|
||||
return "", "", truncated, nil
|
||||
}
|
||||
|
||||
return content, filePath, truncated, nil
|
||||
}
|
||||
|
||||
func sanitizeInstructionMarkdown(content string) string {
|
||||
content = strings.ReplaceAll(content, "\r\n", "\n")
|
||||
content = strings.ReplaceAll(content, "\r", "\n")
|
||||
content = markdownCommentPattern.ReplaceAllString(content, "")
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
//nolint:revive // Boolean indicates content was truncated.
|
||||
func formatHomeInstruction(content string, sourcePath string, truncated bool) string {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
sourcePath = strings.TrimSpace(sourcePath)
|
||||
if sourcePath == "" {
|
||||
sourcePath = "~/.coder/AGENTS.md"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
_, _ = b.WriteString("<coder-home-instructions>\n")
|
||||
_, _ = b.WriteString("Source: ")
|
||||
_, _ = b.WriteString(sourcePath)
|
||||
if truncated {
|
||||
_, _ = b.WriteString(" (truncated to 64KiB)")
|
||||
}
|
||||
_, _ = b.WriteString("\n\n")
|
||||
_, _ = b.WriteString(content)
|
||||
_, _ = b.WriteString("\n</coder-home-instructions>")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func isCodersdkStatusCode(err error, statusCode int) bool {
|
||||
var sdkErr *codersdk.Error
|
||||
if !xerrors.As(err, &sdkErr) {
|
||||
return false
|
||||
}
|
||||
return sdkErr.StatusCode() == statusCode
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package chatd //nolint:testpackage // Uses internal symbols.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
)
|
||||
|
||||
func TestSanitizeInstructionMarkdown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := "line 1\r\n<!-- hidden -->\r\nline 2\r\n"
|
||||
require.Equal(t, "line 1\n\nline 2", sanitizeInstructionMarkdown(input))
|
||||
}
|
||||
|
||||
func TestReadHomeInstructionFileNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).DoAndReturn(
|
||||
func(context.Context, string, workspacesdk.LSRequest) (workspacesdk.LSResponse, error) {
|
||||
return workspacesdk.LSResponse{}, codersdk.NewTestError(404, "POST", "/api/v0/list-directory")
|
||||
},
|
||||
)
|
||||
|
||||
content, sourcePath, truncated, err := readHomeInstructionFile(context.Background(), conn)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, content)
|
||||
require.Empty(t, sourcePath)
|
||||
require.False(t, truncated)
|
||||
}
|
||||
|
||||
func TestReadHomeInstructionFileSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).DoAndReturn(
|
||||
func(context.Context, string, workspacesdk.LSRequest) (workspacesdk.LSResponse, error) {
|
||||
return workspacesdk.LSResponse{
|
||||
Contents: []workspacesdk.LSFile{{
|
||||
Name: "AGENTS.md",
|
||||
AbsolutePathString: "/home/coder/.coder/AGENTS.md",
|
||||
}},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/.coder/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(
|
||||
io.NopCloser(strings.NewReader("base\n<!-- hidden -->\nlocal")),
|
||||
"text/markdown",
|
||||
nil,
|
||||
)
|
||||
|
||||
content, sourcePath, truncated, err := readHomeInstructionFile(context.Background(), conn)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "base\n\nlocal", content)
|
||||
require.Equal(t, "/home/coder/.coder/AGENTS.md", sourcePath)
|
||||
require.False(t, truncated)
|
||||
}
|
||||
|
||||
func TestReadHomeInstructionFileTruncates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
content := strings.Repeat("a", maxInstructionFileBytes+8)
|
||||
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
|
||||
workspacesdk.LSResponse{
|
||||
Contents: []workspacesdk.LSFile{{
|
||||
Name: "AGENTS.md",
|
||||
AbsolutePathString: "/home/coder/.coder/AGENTS.md",
|
||||
}},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/.coder/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(io.NopCloser(strings.NewReader(content)), "text/markdown", nil)
|
||||
|
||||
got, _, truncated, err := readHomeInstructionFile(context.Background(), conn)
|
||||
require.NoError(t, err)
|
||||
require.True(t, truncated)
|
||||
require.Len(t, got, maxInstructionFileBytes)
|
||||
}
|
||||
|
||||
func TestInsertSystemInstructionAfterSystemMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
prompt := []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "base"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "hello"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := chatprompt.InsertSystem(prompt, "project rules")
|
||||
require.Len(t, got, 3)
|
||||
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
|
||||
require.Equal(t, fantasy.MessageRoleSystem, got[1].Role)
|
||||
require.Equal(t, fantasy.MessageRoleUser, got[2].Role)
|
||||
|
||||
part, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "project rules", part.Text)
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package chatd
|
||||
|
||||
// DefaultSystemPrompt is used for new chats when no deployment override is
|
||||
// configured.
|
||||
const DefaultSystemPrompt = `You are the Coder agent — an interactive chat tool that helps users with software-engineering tasks inside of the Coder product.
|
||||
Use the instructions below and the tools available to you to assist User.
|
||||
|
||||
IMPORTANT — obey every rule in this prompt before anything else.
|
||||
Do EXACTLY what the User asked, never more, never less.
|
||||
|
||||
<behavior>
|
||||
You MUST execute AS MANY TOOLS to help the user accomplish their task.
|
||||
You are COMFORTABLE with vague tasks - using your tools to collect the most relevant answer possible.
|
||||
If a user asks how something works, no matter how vague, you MUST use your tools to collect the most relevant answer possible.
|
||||
DO NOT ask the user for clarification - just use your tools.
|
||||
</behavior>
|
||||
|
||||
<personality>
|
||||
Analytical — You break problems into measurable steps, relying on tool output and data rather than intuition.
|
||||
Organized — You structure every interaction with clear tags, TODO lists, and section boundaries.
|
||||
Precision-Oriented — You insist on exact formatting, package-manager choice, and rule adherence.
|
||||
Efficiency-Focused — You minimize chatter, run tasks in parallel, and favor small, complete answers.
|
||||
Clarity-Seeking — You ask for missing details instead of guessing, avoiding any ambiguity.
|
||||
</personality>
|
||||
|
||||
<communication>
|
||||
Be concise, direct, and to the point.
|
||||
NO emojis unless the User explicitly asks for them.
|
||||
If a task appears incomplete or ambiguous, **pause and ask the User** rather than guessing or marking "done".
|
||||
Prefer accuracy over reassurance; confirm facts with tool calls instead of assuming the User is right.
|
||||
If you face an architectural, tooling, or package-manager choice, **ask the User's preference first**.
|
||||
Default to the project's existing package manager / tooling; never substitute without confirmation.
|
||||
You MUST avoid text before/after your response, such as "The answer is" or "Short answer:", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
|
||||
Mimic the style of the User's messages.
|
||||
Do not remind the User you are happy to help.
|
||||
Do not inherently assume the User is correct; they may be making assumptions.
|
||||
If you are not confident in your answer, DO NOT provide an answer. Use your tools to collect more information, or ask the User for help.
|
||||
Do not act with sycophantic flattery or over-the-top enthusiasm.
|
||||
|
||||
Here are examples to demonstrate appropriate communication style and level of verbosity:
|
||||
|
||||
<example>
|
||||
user: find me a good issue to work on
|
||||
assistant: Issue [#1234](https://example) indicates a bug in the frontend, which you've contributed to in the past.
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: work on this issue <url>
|
||||
...assistant does work...
|
||||
assistant: I've put up this pull request: https://github.com/example/example/pull/1824. Please let me know your thoughts!
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: what is 2+2?
|
||||
assistant: 4
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: how does X work in <popular-repository-name>?
|
||||
assistant: Let me take a look at the code...
|
||||
[tool calls to investigate the repository]
|
||||
</example>
|
||||
</communication>
|
||||
|
||||
<collaboration>
|
||||
When a user asks for help with a task or there is ambiguity on the objective, always start by asking clarifying questions to understand:
|
||||
- What specific aspect they want to focus on
|
||||
- Their goals and vision for the changes
|
||||
- Their preferences for approach or style
|
||||
- What problems they're trying to solve
|
||||
|
||||
Don't assume what needs to be done - collaborate to define the scope together.
|
||||
</collaboration>`
|
||||
@@ -0,0 +1,512 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
var ErrSubagentNotDescendant = xerrors.New("target chat is not a descendant of current chat")
|
||||
|
||||
const (
|
||||
subagentAwaitPollInterval = 200 * time.Millisecond
|
||||
defaultSubagentWaitTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
type spawnAgentArgs struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Title string `json:"title,omitempty"`
|
||||
}
|
||||
|
||||
type waitAgentArgs struct {
|
||||
ChatID string `json:"chat_id"`
|
||||
TimeoutSeconds *int `json:"timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
type messageAgentArgs struct {
|
||||
ChatID string `json:"chat_id"`
|
||||
Message string `json:"message"`
|
||||
Interrupt bool `json:"interrupt,omitempty"`
|
||||
}
|
||||
|
||||
type closeAgentArgs struct {
|
||||
ChatID string `json:"chat_id"`
|
||||
}
|
||||
|
||||
func (p *Server) subagentTools(currentChat func() database.Chat) []fantasy.AgentTool {
|
||||
return []fantasy.AgentTool{
|
||||
fantasy.NewAgentTool(
|
||||
"spawn_agent",
|
||||
"Spawn a delegated child agent chat from the root chat.",
|
||||
func(ctx context.Context, args spawnAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if currentChat == nil {
|
||||
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
|
||||
}
|
||||
|
||||
parent := currentChat()
|
||||
if parent.ParentChatID.Valid {
|
||||
return fantasy.NewTextErrorResponse("delegated chats cannot create child subagents"), nil
|
||||
}
|
||||
|
||||
parent, err := p.db.GetChatByID(ctx, parent.ID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
childChat, err := p.createChildSubagentChat(
|
||||
ctx,
|
||||
parent,
|
||||
args.Prompt,
|
||||
args.Title,
|
||||
)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
return toolJSONResponse(map[string]any{
|
||||
"chat_id": childChat.ID.String(),
|
||||
"title": childChat.Title,
|
||||
"status": string(childChat.Status),
|
||||
}), nil
|
||||
},
|
||||
),
|
||||
fantasy.NewAgentTool(
|
||||
"wait_agent",
|
||||
"Wait until a delegated descendant agent reaches a non-streaming status.",
|
||||
func(ctx context.Context, args waitAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if currentChat == nil {
|
||||
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
|
||||
}
|
||||
|
||||
targetChatID, err := parseSubagentToolChatID(args.ChatID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
timeout := defaultSubagentWaitTimeout
|
||||
if args.TimeoutSeconds != nil {
|
||||
timeout = time.Duration(*args.TimeoutSeconds) * time.Second
|
||||
}
|
||||
|
||||
parent := currentChat()
|
||||
targetChat, report, err := p.awaitSubagentCompletion(
|
||||
ctx,
|
||||
parent.ID,
|
||||
targetChatID,
|
||||
timeout,
|
||||
)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
return toolJSONResponse(map[string]any{
|
||||
"chat_id": targetChatID.String(),
|
||||
"title": targetChat.Title,
|
||||
"report": report,
|
||||
"status": string(targetChat.Status),
|
||||
}), nil
|
||||
},
|
||||
),
|
||||
fantasy.NewAgentTool(
|
||||
"message_agent",
|
||||
"Send a message to a delegated descendant agent. Use wait_agent to collect a response.",
|
||||
func(ctx context.Context, args messageAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if currentChat == nil {
|
||||
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
|
||||
}
|
||||
|
||||
targetChatID, err := parseSubagentToolChatID(args.ChatID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
parent := currentChat()
|
||||
busyBehavior := SendMessageBusyBehaviorQueue
|
||||
if args.Interrupt {
|
||||
busyBehavior = SendMessageBusyBehaviorInterrupt
|
||||
}
|
||||
targetChat, err := p.sendSubagentMessage(
|
||||
ctx,
|
||||
parent.ID,
|
||||
targetChatID,
|
||||
args.Message,
|
||||
busyBehavior,
|
||||
)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
return toolJSONResponse(map[string]any{
|
||||
"chat_id": targetChatID.String(),
|
||||
"title": targetChat.Title,
|
||||
"status": string(targetChat.Status),
|
||||
"interrupted": args.Interrupt,
|
||||
}), nil
|
||||
},
|
||||
),
|
||||
fantasy.NewAgentTool(
|
||||
"close_agent",
|
||||
"Interrupt a delegated descendant agent immediately.",
|
||||
func(ctx context.Context, args closeAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if currentChat == nil {
|
||||
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
|
||||
}
|
||||
|
||||
targetChatID, err := parseSubagentToolChatID(args.ChatID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
parent := currentChat()
|
||||
targetChat, err := p.closeSubagent(
|
||||
ctx,
|
||||
parent.ID,
|
||||
targetChatID,
|
||||
)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
return toolJSONResponse(map[string]any{
|
||||
"chat_id": targetChatID.String(),
|
||||
"title": targetChat.Title,
|
||||
"terminated": true,
|
||||
"status": string(targetChat.Status),
|
||||
}), nil
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func parseSubagentToolChatID(raw string) (uuid.UUID, error) {
|
||||
chatID, err := uuid.Parse(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.New("chat_id must be a valid UUID")
|
||||
}
|
||||
return chatID, nil
|
||||
}
|
||||
|
||||
func (p *Server) createChildSubagentChat(
|
||||
ctx context.Context,
|
||||
parent database.Chat,
|
||||
prompt string,
|
||||
title string,
|
||||
) (database.Chat, error) {
|
||||
if parent.ParentChatID.Valid {
|
||||
return database.Chat{}, xerrors.New("delegated chats cannot create child subagents")
|
||||
}
|
||||
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return database.Chat{}, xerrors.New("prompt is required")
|
||||
}
|
||||
|
||||
title = strings.TrimSpace(title)
|
||||
if title == "" {
|
||||
title = subagentFallbackChatTitle(prompt)
|
||||
}
|
||||
|
||||
rootChatID := parent.ID
|
||||
if parent.RootChatID.Valid {
|
||||
rootChatID = parent.RootChatID.UUID
|
||||
}
|
||||
if parent.LastModelConfigID == uuid.Nil {
|
||||
return database.Chat{}, xerrors.New("parent chat model config id is required")
|
||||
}
|
||||
|
||||
child, err := p.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
WorkspaceID: parent.WorkspaceID,
|
||||
WorkspaceAgentID: parent.WorkspaceAgentID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: rootChatID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: parent.LastModelConfigID,
|
||||
Title: title,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: prompt}},
|
||||
})
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("create child chat: %w", err)
|
||||
}
|
||||
|
||||
return child, nil
|
||||
}
|
||||
|
||||
func (p *Server) sendSubagentMessage(
|
||||
ctx context.Context,
|
||||
parentChatID uuid.UUID,
|
||||
targetChatID uuid.UUID,
|
||||
message string,
|
||||
busyBehavior SendMessageBusyBehavior,
|
||||
) (database.Chat, error) {
|
||||
message = strings.TrimSpace(message)
|
||||
if message == "" {
|
||||
return database.Chat{}, xerrors.New("message is required")
|
||||
}
|
||||
|
||||
isDescendant, err := isSubagentDescendant(ctx, p.db, parentChatID, targetChatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if !isDescendant {
|
||||
return database.Chat{}, ErrSubagentNotDescendant
|
||||
}
|
||||
|
||||
sendResult, err := p.SendMessage(ctx, SendMessageOptions{
|
||||
ChatID: targetChatID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: message}},
|
||||
BusyBehavior: busyBehavior,
|
||||
})
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
return sendResult.Chat, nil
|
||||
}
|
||||
|
||||
func (p *Server) awaitSubagentCompletion(
|
||||
ctx context.Context,
|
||||
parentChatID uuid.UUID,
|
||||
targetChatID uuid.UUID,
|
||||
timeout time.Duration,
|
||||
) (database.Chat, string, error) {
|
||||
isDescendant, err := isSubagentDescendant(ctx, p.db, parentChatID, targetChatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, "", err
|
||||
}
|
||||
if !isDescendant {
|
||||
return database.Chat{}, "", ErrSubagentNotDescendant
|
||||
}
|
||||
|
||||
if timeout <= 0 {
|
||||
timeout = defaultSubagentWaitTimeout
|
||||
}
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
ticker := time.NewTicker(subagentAwaitPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
targetChat, report, done, checkErr := p.checkSubagentCompletion(ctx, targetChatID)
|
||||
if checkErr != nil {
|
||||
return database.Chat{}, "", checkErr
|
||||
}
|
||||
if done {
|
||||
if targetChat.Status == database.ChatStatusError {
|
||||
reason := strings.TrimSpace(report)
|
||||
if reason == "" {
|
||||
reason = "agent reached error status"
|
||||
}
|
||||
return database.Chat{}, "", xerrors.New(reason)
|
||||
}
|
||||
return targetChat, report, nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-timer.C:
|
||||
return database.Chat{}, "", xerrors.New("timed out waiting for delegated subagent completion")
|
||||
case <-ctx.Done():
|
||||
return database.Chat{}, "", ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Server) closeSubagent(
|
||||
ctx context.Context,
|
||||
parentChatID uuid.UUID,
|
||||
targetChatID uuid.UUID,
|
||||
) (database.Chat, error) {
|
||||
isDescendant, err := isSubagentDescendant(ctx, p.db, parentChatID, targetChatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if !isDescendant {
|
||||
return database.Chat{}, ErrSubagentNotDescendant
|
||||
}
|
||||
|
||||
targetChat, err := p.db.GetChatByID(ctx, targetChatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("get target chat: %w", err)
|
||||
}
|
||||
|
||||
if targetChat.Status == database.ChatStatusWaiting {
|
||||
return targetChat, nil
|
||||
}
|
||||
|
||||
updatedChat := p.InterruptChat(ctx, targetChat)
|
||||
if updatedChat.Status != database.ChatStatusWaiting {
|
||||
return database.Chat{}, xerrors.New("set target chat waiting")
|
||||
}
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
func (p *Server) checkSubagentCompletion(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
) (database.Chat, string, bool, error) {
|
||||
chat, err := p.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, "", false, xerrors.Errorf("get chat: %w", err)
|
||||
}
|
||||
|
||||
if chat.Status == database.ChatStatusPending || chat.Status == database.ChatStatusRunning {
|
||||
return database.Chat{}, "", false, nil
|
||||
}
|
||||
|
||||
report, err := latestSubagentAssistantMessage(ctx, p.db, chatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, "", false, err
|
||||
}
|
||||
|
||||
return chat, report, true, nil
|
||||
}
|
||||
|
||||
func latestSubagentAssistantMessage(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
chatID uuid.UUID,
|
||||
) (string, error) {
|
||||
messages, err := store.GetChatMessagesByChatID(ctx, chatID)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("get chat messages: %w", err)
|
||||
}
|
||||
|
||||
sort.Slice(messages, func(i, j int) bool {
|
||||
if messages[i].CreatedAt.Equal(messages[j].CreatedAt) {
|
||||
return messages[i].ID < messages[j].ID
|
||||
}
|
||||
return messages[i].CreatedAt.Before(messages[j].CreatedAt)
|
||||
})
|
||||
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
message := messages[i]
|
||||
if message.Role != string(fantasy.MessageRoleAssistant) ||
|
||||
message.Visibility == database.ChatMessageVisibilityModel {
|
||||
continue
|
||||
}
|
||||
|
||||
content, parseErr := chatprompt.ParseContent(message.Role, message.Content)
|
||||
if parseErr != nil {
|
||||
continue
|
||||
}
|
||||
text := strings.TrimSpace(contentBlocksToText(content))
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
return text, nil
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func isSubagentDescendant(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
ancestorChatID uuid.UUID,
|
||||
targetChatID uuid.UUID,
|
||||
) (bool, error) {
|
||||
if ancestorChatID == targetChatID {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
descendants, err := listSubagentDescendants(ctx, store, ancestorChatID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, descendant := range descendants {
|
||||
if descendant.ID == targetChatID {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func listSubagentDescendants(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
chatID uuid.UUID,
|
||||
) ([]database.Chat, error) {
|
||||
queue := []uuid.UUID{chatID}
|
||||
visited := map[uuid.UUID]struct{}{chatID: {}}
|
||||
|
||||
out := make([]database.Chat, 0)
|
||||
for len(queue) > 0 {
|
||||
parentChatID := queue[0]
|
||||
queue = queue[1:]
|
||||
|
||||
children, err := store.ListChildChatsByParentID(ctx, parentChatID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("list child chats for %s: %w", parentChatID, err)
|
||||
}
|
||||
|
||||
for _, child := range children {
|
||||
if _, ok := visited[child.ID]; ok {
|
||||
continue
|
||||
}
|
||||
visited[child.ID] = struct{}{}
|
||||
out = append(out, child)
|
||||
queue = append(queue, child.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func subagentFallbackChatTitle(message string) string {
|
||||
const maxWords = 6
|
||||
const maxRunes = 80
|
||||
|
||||
words := strings.Fields(message)
|
||||
if len(words) == 0 {
|
||||
return "New Chat"
|
||||
}
|
||||
|
||||
truncated := false
|
||||
if len(words) > maxWords {
|
||||
words = words[:maxWords]
|
||||
truncated = true
|
||||
}
|
||||
|
||||
title := strings.Join(words, " ")
|
||||
if truncated {
|
||||
title += "..."
|
||||
}
|
||||
|
||||
return subagentTruncateRunes(title, maxRunes)
|
||||
}
|
||||
|
||||
func subagentTruncateRunes(value string, maxRunes int) string {
|
||||
if maxRunes <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
runes := []rune(value)
|
||||
if len(runes) <= maxRunes {
|
||||
return value
|
||||
}
|
||||
|
||||
return string(runes[:maxRunes])
|
||||
}
|
||||
|
||||
func toolJSONResponse(result map[string]any) fantasy.ToolResponse {
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return fantasy.NewTextResponse("{}")
|
||||
}
|
||||
return fantasy.NewTextResponse(string(data))
|
||||
}
|
||||
@@ -0,0 +1,216 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
)
|
||||
|
||||
const titleGenerationPrompt = "Generate a concise title (max 8 words, under 128 characters) for " +
|
||||
"the user's first message. Return plain text only — no quotes, no emoji, " +
|
||||
"no markdown, no special characters."
|
||||
|
||||
// maybeGenerateChatTitle generates an AI title for the chat when
|
||||
// appropriate (first user message, no assistant reply yet, and the
|
||||
// current title is either empty or still the fallback truncation).
|
||||
// It is a best-effort operation that logs and swallows errors.
|
||||
func (p *Server) maybeGenerateChatTitle(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
messages []database.ChatMessage,
|
||||
model fantasy.LanguageModel,
|
||||
logger slog.Logger,
|
||||
) {
|
||||
input, ok := titleInput(chat, messages)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
titleCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
title, err := generateTitle(titleCtx, model, input)
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "failed to generate chat title",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if title == "" || title == chat.Title {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = p.db.UpdateChatByID(ctx, database.UpdateChatByIDParams{
|
||||
ID: chat.ID,
|
||||
Title: title,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to update generated chat title",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
chat.Title = title
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange)
|
||||
}
|
||||
|
||||
// generateTitle calls the model with a title-generation system prompt
|
||||
// and returns the normalized result.
|
||||
func generateTitle(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
input string,
|
||||
) (string, error) {
|
||||
prompt := []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: titleGenerationPrompt},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: input},
|
||||
},
|
||||
},
|
||||
}
|
||||
toolChoice := fantasy.ToolChoiceNone
|
||||
response, err := model.Generate(ctx, fantasy.Call{
|
||||
Prompt: prompt,
|
||||
ToolChoice: &toolChoice,
|
||||
})
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("generate title text: %w", err)
|
||||
}
|
||||
|
||||
title := normalizeTitleOutput(contentBlocksToText(response.Content))
|
||||
if title == "" {
|
||||
return "", xerrors.New("generated title was empty")
|
||||
}
|
||||
return title, nil
|
||||
}
|
||||
|
||||
// titleInput returns the first user message text and whether title
|
||||
// generation should proceed. It returns false when the chat already
|
||||
// has assistant/tool replies, has more than one visible user message,
|
||||
// or the current title doesn't look like a candidate for replacement.
|
||||
func titleInput(
|
||||
chat database.Chat,
|
||||
messages []database.ChatMessage,
|
||||
) (string, bool) {
|
||||
userCount := 0
|
||||
firstUserText := ""
|
||||
|
||||
for _, message := range messages {
|
||||
if message.Visibility == database.ChatMessageVisibilityModel {
|
||||
continue
|
||||
}
|
||||
|
||||
switch message.Role {
|
||||
case string(fantasy.MessageRoleAssistant), string(fantasy.MessageRoleTool):
|
||||
return "", false
|
||||
case string(fantasy.MessageRoleUser):
|
||||
userCount++
|
||||
if firstUserText == "" {
|
||||
parsed, err := chatprompt.ParseContent(
|
||||
string(fantasy.MessageRoleUser), message.Content,
|
||||
)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
firstUserText = strings.TrimSpace(
|
||||
contentBlocksToText(parsed),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if userCount != 1 || firstUserText == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
currentTitle := strings.TrimSpace(chat.Title)
|
||||
if currentTitle == "" {
|
||||
return firstUserText, true
|
||||
}
|
||||
|
||||
if currentTitle != fallbackChatTitle(firstUserText) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return firstUserText, true
|
||||
}
|
||||
|
||||
func normalizeTitleOutput(title string) string {
|
||||
title = strings.TrimSpace(title)
|
||||
if title == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
title = strings.Trim(title, "\"'`")
|
||||
title = strings.Join(strings.Fields(title), " ")
|
||||
return truncateRunes(title, 80)
|
||||
}
|
||||
|
||||
func fallbackChatTitle(message string) string {
|
||||
const maxWords = 6
|
||||
const maxRunes = 80
|
||||
|
||||
words := strings.Fields(message)
|
||||
if len(words) == 0 {
|
||||
return "New Chat"
|
||||
}
|
||||
|
||||
truncated := false
|
||||
if len(words) > maxWords {
|
||||
words = words[:maxWords]
|
||||
truncated = true
|
||||
}
|
||||
|
||||
title := strings.Join(words, " ")
|
||||
if truncated {
|
||||
title += "…"
|
||||
}
|
||||
|
||||
return truncateRunes(title, maxRunes)
|
||||
}
|
||||
|
||||
// contentBlocksToText concatenates the text parts of content blocks
|
||||
// into a single space-separated string.
|
||||
func contentBlocksToText(content []fantasy.Content) string {
|
||||
parts := make([]string, 0, len(content))
|
||||
for _, block := range content {
|
||||
textBlock, ok := fantasy.AsContentType[fantasy.TextContent](block)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
text := strings.TrimSpace(textBlock.Text)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, text)
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func truncateRunes(value string, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
runes := []rune(value)
|
||||
if len(runes) <= maxLen {
|
||||
return value
|
||||
}
|
||||
return string(runes[:maxLen])
|
||||
}
|
||||
+3093
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
+62
-2
@@ -49,6 +49,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/awsidentity"
|
||||
"github.com/coder/coder/v2/coderd/boundaryusage"
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/connectionlog"
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -238,6 +239,9 @@ type Options struct {
|
||||
SSHConfig codersdk.SSHConfigResponse
|
||||
|
||||
HTTPClient *http.Client
|
||||
// ChatRemotePartsProvider provides cross-replica message_part streaming.
|
||||
// Set by enterprise for HA deployments. Nil in AGPL single-replica.
|
||||
ChatRemotePartsProvider chatd.RemotePartsProvider
|
||||
|
||||
UpdateAgentMetrics func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
|
||||
StatsBatcher workspacestats.Batcher
|
||||
@@ -588,7 +592,6 @@ func New(options *Options) *API {
|
||||
var buildUsageChecker atomic.Pointer[wsbuilder.UsageChecker]
|
||||
var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{}
|
||||
buildUsageChecker.Store(&noopUsageChecker)
|
||||
|
||||
api := &API{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
@@ -754,6 +757,17 @@ func New(options *Options) *API {
|
||||
panic("failed to setup server tailnet: " + err.Error())
|
||||
}
|
||||
api.agentProvider = stn
|
||||
|
||||
api.chatDaemon = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chats"),
|
||||
Database: options.Database,
|
||||
ReplicaID: api.ID,
|
||||
RemotePartsProvider: options.ChatRemotePartsProvider,
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
})
|
||||
if options.DeploymentValues.Prometheus.Enable {
|
||||
options.PrometheusRegistry.MustRegister(stn)
|
||||
api.lifecycleMetrics = agentapi.NewLifecycleMetrics(options.PrometheusRegistry)
|
||||
@@ -1085,6 +1099,48 @@ func New(options *Options) *API {
|
||||
})
|
||||
})
|
||||
})
|
||||
r.Route("/chats", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentAgents),
|
||||
)
|
||||
r.Get("/", api.listChats)
|
||||
r.Post("/", api.postChats)
|
||||
r.Get("/models", api.listChatModels)
|
||||
r.Get("/watch", api.watchChats)
|
||||
r.Route("/providers", func(r chi.Router) {
|
||||
r.Get("/", api.listChatProviders)
|
||||
r.Post("/", api.createChatProvider)
|
||||
r.Route("/{providerConfig}", func(r chi.Router) {
|
||||
r.Patch("/", api.updateChatProvider)
|
||||
r.Delete("/", api.deleteChatProvider)
|
||||
})
|
||||
})
|
||||
r.Route("/model-configs", func(r chi.Router) {
|
||||
r.Get("/", api.listChatModelConfigs)
|
||||
r.Post("/", api.createChatModelConfig)
|
||||
r.Route("/{modelConfig}", func(r chi.Router) {
|
||||
r.Patch("/", api.updateChatModelConfig)
|
||||
r.Delete("/", api.deleteChatModelConfig)
|
||||
})
|
||||
})
|
||||
r.Route("/{chat}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractChatParam(options.Database))
|
||||
r.Get("/", api.getChat)
|
||||
r.Delete("/", api.deleteChat)
|
||||
r.Post("/messages", api.postChatMessages)
|
||||
r.Patch("/messages/{message}", api.patchChatMessage)
|
||||
r.Get("/stream", api.streamChat)
|
||||
r.Post("/interrupt", api.interruptChat)
|
||||
r.Get("/diff-status", api.getChatDiffStatus)
|
||||
r.Get("/diff", api.getChatDiffContents)
|
||||
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
|
||||
r.Delete("/", api.deleteChatQueuedMessage)
|
||||
r.Post("/promote", api.promoteChatQueuedMessage)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/mcp", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
@@ -1902,6 +1958,8 @@ type API struct {
|
||||
// dbRolluper rolls up template usage stats from raw agent and app
|
||||
// stats. This is used to provide insights in the WebUI.
|
||||
dbRolluper *dbrollup.Rolluper
|
||||
// chatDaemon handles background processing of pending chats.
|
||||
chatDaemon *chatd.Server
|
||||
}
|
||||
|
||||
// Close waits for all WebSocket connections to drain before returning.
|
||||
@@ -1930,8 +1988,10 @@ func (api *API) Close() error {
|
||||
case <-timer.C:
|
||||
api.Logger.Warn(api.ctx, "websocket shutdown timed out after 10 seconds")
|
||||
}
|
||||
|
||||
api.dbRolluper.Close()
|
||||
if err := api.chatDaemon.Close(); err != nil {
|
||||
api.Logger.Warn(api.ctx, "close chat processor", slog.Error(err))
|
||||
}
|
||||
api.metricsCache.Close()
|
||||
if api.updateChecker != nil {
|
||||
api.updateChecker.Close()
|
||||
|
||||
@@ -6,17 +6,20 @@ type CheckConstraint string
|
||||
|
||||
// CheckConstraint enums.
|
||||
const (
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
|
||||
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
|
||||
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
|
||||
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
|
||||
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
|
||||
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
|
||||
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
|
||||
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package db2sdk
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/hcl/v2"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
@@ -18,6 +20,7 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
@@ -1050,3 +1053,345 @@ func jsonOrEmptyMap(rawMessage pqtype.NullRawMessage) map[string]any {
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func ChatMessage(m database.ChatMessage) codersdk.ChatMessage {
|
||||
modelConfigID := &m.ModelConfigID.UUID
|
||||
if !m.ModelConfigID.Valid {
|
||||
modelConfigID = nil
|
||||
}
|
||||
msg := codersdk.ChatMessage{
|
||||
ID: m.ID,
|
||||
ChatID: m.ChatID,
|
||||
ModelConfigID: modelConfigID,
|
||||
CreatedAt: m.CreatedAt,
|
||||
Role: m.Role,
|
||||
}
|
||||
if m.Content.Valid {
|
||||
parts, err := chatMessageParts(m.Role, m.Content)
|
||||
if err == nil {
|
||||
msg.Content = parts
|
||||
}
|
||||
}
|
||||
usage := chatMessageUsage(m)
|
||||
if usage != nil {
|
||||
msg.Usage = usage
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// chatMessageUsage builds a ChatMessageUsage from the database row,
|
||||
// returning nil when no token fields are populated.
|
||||
func chatMessageUsage(m database.ChatMessage) *codersdk.ChatMessageUsage {
|
||||
inputTokens := nullInt64Ptr(m.InputTokens)
|
||||
outputTokens := nullInt64Ptr(m.OutputTokens)
|
||||
totalTokens := nullInt64Ptr(m.TotalTokens)
|
||||
reasoningTokens := nullInt64Ptr(m.ReasoningTokens)
|
||||
cacheCreationTokens := nullInt64Ptr(m.CacheCreationTokens)
|
||||
cacheReadTokens := nullInt64Ptr(m.CacheReadTokens)
|
||||
contextLimit := nullInt64Ptr(m.ContextLimit)
|
||||
|
||||
if inputTokens == nil && outputTokens == nil && totalTokens == nil &&
|
||||
reasoningTokens == nil && cacheCreationTokens == nil &&
|
||||
cacheReadTokens == nil && contextLimit == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &codersdk.ChatMessageUsage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
TotalTokens: totalTokens,
|
||||
ReasoningTokens: reasoningTokens,
|
||||
CacheCreationTokens: cacheCreationTokens,
|
||||
CacheReadTokens: cacheReadTokens,
|
||||
ContextLimit: contextLimit,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatQueuedMessage converts a queued message to its SDK representation.
|
||||
func ChatQueuedMessage(message database.ChatQueuedMessage) codersdk.ChatQueuedMessage {
|
||||
parts, err := chatMessageParts(string(fantasy.MessageRoleUser), pqtype.NullRawMessage{
|
||||
RawMessage: message.Content,
|
||||
Valid: len(message.Content) > 0,
|
||||
})
|
||||
if err != nil {
|
||||
parts = nil
|
||||
}
|
||||
|
||||
return codersdk.ChatQueuedMessage{
|
||||
ID: message.ID,
|
||||
ChatID: message.ChatID,
|
||||
Content: parts,
|
||||
CreatedAt: message.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatQueuedMessages converts a slice of database queued messages
|
||||
// to their SDK representation.
|
||||
func ChatQueuedMessages(messages []database.ChatQueuedMessage) []codersdk.ChatQueuedMessage {
|
||||
out := make([]codersdk.ChatQueuedMessage, 0, len(messages))
|
||||
for _, message := range messages {
|
||||
out = append(out, ChatQueuedMessage(message))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMessagePart, error) {
|
||||
switch role {
|
||||
case string(fantasy.MessageRoleSystem):
|
||||
content, err := parseSystemContent(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return []codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: content,
|
||||
}}, nil
|
||||
case string(fantasy.MessageRoleUser), string(fantasy.MessageRoleAssistant):
|
||||
content, err := parseContentBlocks(role, raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rawBlocks []json.RawMessage
|
||||
if role == string(fantasy.MessageRoleAssistant) {
|
||||
_ = json.Unmarshal(raw.RawMessage, &rawBlocks)
|
||||
}
|
||||
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(content))
|
||||
for i, block := range content {
|
||||
part := contentBlockToPart(block)
|
||||
if part.Type == "" {
|
||||
continue
|
||||
}
|
||||
if part.Type == codersdk.ChatMessagePartTypeReasoning {
|
||||
part.Title = ""
|
||||
if i < len(rawBlocks) {
|
||||
part.Title = reasoningStoredTitle(rawBlocks[i])
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
return parts, nil
|
||||
case string(fantasy.MessageRoleTool):
|
||||
results, err := parseToolResults(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(results))
|
||||
for _, result := range results {
|
||||
parts = append(parts, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolResult,
|
||||
ToolCallID: result.ToolCallID,
|
||||
ToolName: result.ToolName,
|
||||
Result: result.Result,
|
||||
IsError: result.IsError,
|
||||
})
|
||||
}
|
||||
return parts, nil
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseSystemContent(raw pqtype.NullRawMessage) (string, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
var content string
|
||||
if err := json.Unmarshal(raw.RawMessage, &content); err != nil {
|
||||
return "", xerrors.Errorf("parse system content: %w", err)
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func parseContentBlocks(role string, raw pqtype.NullRawMessage) ([]fantasy.Content, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if role == string(fantasy.MessageRoleUser) {
|
||||
var text string
|
||||
if err := json.Unmarshal(raw.RawMessage, &text); err == nil {
|
||||
return []fantasy.Content{
|
||||
fantasy.TextContent{Text: text},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
var blocks []json.RawMessage
|
||||
if err := json.Unmarshal(raw.RawMessage, &blocks); err != nil {
|
||||
return nil, xerrors.Errorf("parse content blocks: %w", err)
|
||||
}
|
||||
|
||||
content := make([]fantasy.Content, 0, len(blocks))
|
||||
for _, block := range blocks {
|
||||
decoded, err := fantasy.UnmarshalContent(block)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse content block: %w", err)
|
||||
}
|
||||
content = append(content, decoded)
|
||||
}
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// toolResultRow is used only for extracting top-level fields from
|
||||
// persisted tool result JSON. The result payload is kept as raw JSON.
|
||||
type toolResultRow struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Result json.RawMessage `json:"result"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
func parseToolResults(raw pqtype.NullRawMessage) ([]toolResultRow, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var results []toolResultRow
|
||||
if err := json.Unmarshal(raw.RawMessage, &results); err != nil {
|
||||
return nil, xerrors.Errorf("parse tool results: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func reasoningStoredTitle(raw json.RawMessage) string {
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data struct {
|
||||
Title string `json:"title"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &envelope); err != nil {
|
||||
return ""
|
||||
}
|
||||
if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeReasoning)) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(envelope.Data.Title)
|
||||
}
|
||||
|
||||
func contentBlockToPart(block fantasy.Content) codersdk.ChatMessagePart {
|
||||
switch value := block.(type) {
|
||||
case fantasy.TextContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: value.Text,
|
||||
}
|
||||
case *fantasy.TextContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: value.Text,
|
||||
}
|
||||
case fantasy.ReasoningContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: value.Text,
|
||||
}
|
||||
case *fantasy.ReasoningContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: value.Text,
|
||||
}
|
||||
case fantasy.ToolCallContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: value.ToolCallID,
|
||||
ToolName: value.ToolName,
|
||||
Args: []byte(value.Input),
|
||||
}
|
||||
case *fantasy.ToolCallContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: value.ToolCallID,
|
||||
ToolName: value.ToolName,
|
||||
Args: []byte(value.Input),
|
||||
}
|
||||
case fantasy.SourceContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSource,
|
||||
SourceID: value.ID,
|
||||
URL: value.URL,
|
||||
Title: value.Title,
|
||||
}
|
||||
case *fantasy.SourceContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSource,
|
||||
SourceID: value.ID,
|
||||
URL: value.URL,
|
||||
Title: value.Title,
|
||||
}
|
||||
case fantasy.FileContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeFile,
|
||||
MediaType: value.MediaType,
|
||||
Data: value.Data,
|
||||
}
|
||||
case *fantasy.FileContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeFile,
|
||||
MediaType: value.MediaType,
|
||||
Data: value.Data,
|
||||
}
|
||||
case fantasy.ToolResultContent:
|
||||
return chatprompt.ToolResultToPart(
|
||||
value.ToolCallID,
|
||||
value.ToolName,
|
||||
toolResultOutputToRawJSON(value.Result),
|
||||
toolResultOutputIsError(value.Result),
|
||||
)
|
||||
case *fantasy.ToolResultContent:
|
||||
return chatprompt.ToolResultToPart(
|
||||
value.ToolCallID,
|
||||
value.ToolName,
|
||||
toolResultOutputToRawJSON(value.Result),
|
||||
toolResultOutputIsError(value.Result),
|
||||
)
|
||||
default:
|
||||
return codersdk.ChatMessagePart{}
|
||||
}
|
||||
}
|
||||
|
||||
func toolResultOutputToRawJSON(output fantasy.ToolResultOutputContent) json.RawMessage {
|
||||
switch v := output.(type) {
|
||||
case fantasy.ToolResultOutputContentError:
|
||||
if v.Error != nil {
|
||||
data, _ := json.Marshal(map[string]any{"error": v.Error.Error()})
|
||||
return data
|
||||
}
|
||||
return json.RawMessage(`{"error":""}`)
|
||||
case fantasy.ToolResultOutputContentText:
|
||||
raw := json.RawMessage(v.Text)
|
||||
if json.Valid(raw) {
|
||||
return raw
|
||||
}
|
||||
data, _ := json.Marshal(map[string]any{"output": v.Text})
|
||||
return data
|
||||
case fantasy.ToolResultOutputContentMedia:
|
||||
data, _ := json.Marshal(map[string]any{
|
||||
"data": v.Data,
|
||||
"mime_type": v.MediaType,
|
||||
"text": v.Text,
|
||||
})
|
||||
return data
|
||||
default:
|
||||
return json.RawMessage(`{}`)
|
||||
}
|
||||
}
|
||||
|
||||
func toolResultOutputIsError(output fantasy.ToolResultOutputContent) bool {
|
||||
_, ok := output.(fantasy.ToolResultOutputContentError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func nullInt64Ptr(v sql.NullInt64) *int64 {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
value := v.Int64
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -435,3 +437,132 @@ func TestAIBridgeInterception(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatMessage_ReasoningPartWithoutPersistedTitleIsEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assistantContent, err := json.Marshal([]fantasy.Content{
|
||||
fantasy.ReasoningContent{
|
||||
Text: "Plan migration",
|
||||
ProviderMetadata: fantasy.ProviderMetadata{
|
||||
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
|
||||
ItemID: "reasoning-1",
|
||||
Summary: []string{"Plan migration"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
message := db2sdk.ChatMessage(database.ChatMessage{
|
||||
ID: 1,
|
||||
ChatID: uuid.New(),
|
||||
CreatedAt: time.Now(),
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: assistantContent,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, message.Content, 1)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Content[0].Type)
|
||||
require.Equal(t, "Plan migration", message.Content[0].Text)
|
||||
require.Empty(t, message.Content[0].Title)
|
||||
}
|
||||
|
||||
func TestChatMessage_ReasoningPartPrefersPersistedTitle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reasoningContent, err := json.Marshal(fantasy.ReasoningContent{
|
||||
Text: "Verify schema updates, then apply changes in order.",
|
||||
ProviderMetadata: fantasy.ProviderMetadata{
|
||||
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
|
||||
ItemID: "reasoning-1",
|
||||
Summary: []string{
|
||||
"**Metadata-derived title**\n\nLonger explanation.",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var envelope map[string]any
|
||||
require.NoError(t, json.Unmarshal(reasoningContent, &envelope))
|
||||
dataValue, ok := envelope["data"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
dataValue["title"] = "Persisted stream title"
|
||||
|
||||
encodedReasoning, err := json.Marshal(envelope)
|
||||
require.NoError(t, err)
|
||||
assistantContent, err := json.Marshal([]json.RawMessage{encodedReasoning})
|
||||
require.NoError(t, err)
|
||||
|
||||
message := db2sdk.ChatMessage(database.ChatMessage{
|
||||
ID: 1,
|
||||
ChatID: uuid.New(),
|
||||
CreatedAt: time.Now(),
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: assistantContent,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, message.Content, 1)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Content[0].Type)
|
||||
require.Equal(t, "Persisted stream title", message.Content[0].Title)
|
||||
}
|
||||
|
||||
func TestChatQueuedMessage_ParsesUserContentParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rawContent, err := json.Marshal([]fantasy.Content{
|
||||
fantasy.TextContent{Text: "queued text"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
|
||||
ID: 1,
|
||||
ChatID: uuid.New(),
|
||||
Content: rawContent,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
require.Len(t, queued.Content, 1)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeText, queued.Content[0].Type)
|
||||
require.Equal(t, "queued text", queued.Content[0].Text)
|
||||
}
|
||||
|
||||
func TestChatQueuedMessage_FallsBackToTextForLegacyContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("legacy_string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
|
||||
ID: 1,
|
||||
ChatID: uuid.New(),
|
||||
Content: json.RawMessage(`"legacy queued text"`),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
require.Len(t, queued.Content, 1)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeText, queued.Content[0].Type)
|
||||
require.Equal(t, "legacy queued text", queued.Content[0].Text)
|
||||
})
|
||||
|
||||
t.Run("malformed_payload", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
raw := json.RawMessage(`{"unexpected":"shape"}`)
|
||||
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
|
||||
ID: 1,
|
||||
ChatID: uuid.New(),
|
||||
Content: raw,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
require.Empty(t, queued.Content)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -453,6 +453,7 @@ var (
|
||||
rbac.ResourceProvisionerJobs.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreate},
|
||||
rbac.ResourceOauth2App.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
rbac.ResourceOauth2AppSecret.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
rbac.ResourceChat.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
}),
|
||||
User: []rbac.Permission{},
|
||||
ByOrgID: map[string]rbac.OrgPermissions{},
|
||||
@@ -1484,6 +1485,15 @@ func (q *querier) authorizeProvisionerJob(ctx context.Context, job database.Prov
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *querier) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
|
||||
// AcquireChat is a system-level operation used by the chat processor.
|
||||
// Authorization is done at the system level, not per-user.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.AcquireChat(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) AcquireLock(ctx context.Context, id int64) error {
|
||||
return q.db.AcquireLock(ctx, id)
|
||||
}
|
||||
@@ -1712,6 +1722,17 @@ func (q *querier) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) e
|
||||
return q.db.DeleteAPIKeysByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteAllChatQueuedMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return err
|
||||
@@ -1736,6 +1757,66 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
|
||||
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
|
||||
// Authorize update on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatMessagesAfterID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
// Authorize delete on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatMessagesByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatModelConfigByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatProviderByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatQueuedMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceCryptoKey); err != nil {
|
||||
return database.CryptoKey{}, err
|
||||
@@ -2304,6 +2385,131 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
|
||||
return q.db.GetAuthorizationUserRoles(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
return fetch(q.log, q.auth, q.db.GetChatByID)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
return fetch(q.log, q.auth, q.db.GetChatByIDForUpdate)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
return q.db.GetChatDiffStatusByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uuid.UUID) ([]database.ChatDiffStatus, error) {
|
||||
if len(chatIDs) == 0 {
|
||||
return []database.ChatDiffStatus{}, nil
|
||||
}
|
||||
|
||||
actor, ok := ActorFromContext(ctx)
|
||||
if ok && actor.Type == rbac.SubjectTypeSystemRestricted {
|
||||
return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
||||
}
|
||||
|
||||
for _, chatID := range chatIDs {
|
||||
// Authorize read on each parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
// ChatMessages are authorized through their parent Chat.
|
||||
// We need to fetch the message first to get its chat_id.
|
||||
msg, err := q.db.GetChatMessageByID(ctx, id)
|
||||
if err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
// Authorize read on the parent chat.
|
||||
_, err = q.GetChatByID(ctx, msg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessagesByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessagesForPromptByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.GetChatModelConfigByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.GetChatModelConfigByProviderAndModel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatModelConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.GetChatProviderByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.GetChatProviderByProvider(ctx, provider)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatProviders(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatQueuedMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
|
||||
}
|
||||
|
||||
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
// Just like with the audit logs query, shortcut if the user is an owner.
|
||||
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
|
||||
@@ -2361,6 +2567,13 @@ func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
|
||||
return q.db.GetDERPMeshKey(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.GetDefaultChatModelConfig(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetDefaultOrganization(ctx context.Context) (database.Organization, error) {
|
||||
return fetch(q.log, q.auth, func(ctx context.Context, _ any) (database.Organization, error) {
|
||||
return q.db.GetDefaultOrganization(ctx)
|
||||
@@ -2401,6 +2614,20 @@ func (q *querier) GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx context.C
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetEligibleProvisionerDaemonsByProvisionerJobIDs)(ctx, provisionerJobIDs)
|
||||
}
|
||||
|
||||
func (q *querier) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetEnabledChatModelConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetEnabledChatProviders(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetExternalAuthLink)(ctx, arg)
|
||||
}
|
||||
@@ -3100,6 +3327,14 @@ func (q *querier) GetRuntimeConfig(ctx context.Context, key string) (string, err
|
||||
return q.db.GetRuntimeConfig(ctx, key)
|
||||
}
|
||||
|
||||
func (q *querier) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]database.Chat, error) {
|
||||
// GetStaleChats is a system-level operation used by the chat processor for recovery.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetStaleChats(ctx, staleThreshold)
|
||||
}
|
||||
|
||||
func (q *querier) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database.TailnetPeer, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return nil, err
|
||||
@@ -4223,6 +4458,47 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo
|
||||
return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
|
||||
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
// Authorize create on the parent chat (using update permission).
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
return q.db.InsertChatMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.InsertChatModelConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.InsertChatProvider(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatQueuedMessage{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatQueuedMessage{}, err
|
||||
}
|
||||
return q.db.InsertChatQueuedMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceCryptoKey); err != nil {
|
||||
return database.CryptoKey{}, err
|
||||
@@ -4829,6 +5105,14 @@ func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context,
|
||||
return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs)
|
||||
}
|
||||
|
||||
func (q *querier) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListChatsByRootID)(ctx, rootChatID)
|
||||
}
|
||||
|
||||
func (q *querier) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListChildChatsByParentID)(ctx, parentChatID)
|
||||
}
|
||||
|
||||
func (q *querier) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListProvisionerKeysByOrganization)(ctx, organizationID)
|
||||
}
|
||||
@@ -4909,6 +5193,17 @@ func (q *querier) PaginatedOrganizationMembers(ctx context.Context, arg database
|
||||
return q.db.PaginatedOrganizationMembers(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return database.ChatQueuedMessage{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatQueuedMessage{}, err
|
||||
}
|
||||
return q.db.PopNextQueuedMessage(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
|
||||
template, err := q.db.GetTemplateByID(ctx, templateID)
|
||||
if err != nil {
|
||||
@@ -4987,6 +5282,13 @@ func (q *querier) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error {
|
||||
return update(q.log, q.auth, fetch, q.db.UnfavoriteWorkspace)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) UnsetDefaultChatModelConfigs(ctx context.Context) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UnsetDefaultChatModelConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateAIBridgeInterceptionEnded(ctx context.Context, params database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
|
||||
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, params.ID); err != nil {
|
||||
return database.AIBridgeInterception{}, err
|
||||
@@ -5001,6 +5303,91 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe
|
||||
return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.UpdateChatHeartbeat(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
|
||||
// Authorize update on the parent chat of the edited message.
|
||||
msg, err := q.db.GetChatMessageByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
chat, err := q.db.GetChatByID(ctx, msg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
return q.db.UpdateChatMessageByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatModelConfig(ctx context.Context, arg database.UpdateChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.UpdateChatModelConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.UpdateChatProvider(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
// UpdateChatStatus is used by the chat processor to change chat status.
|
||||
// It should be called with system context.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
// UpdateChatWorkspace is manually implemented for chat tables and may not be
|
||||
// present on every wrapped store interface yet.
|
||||
chatWorkspaceUpdater, ok := q.db.(interface {
|
||||
UpdateChatWorkspace(context.Context, database.UpdateChatWorkspaceParams) (database.Chat, error)
|
||||
})
|
||||
if !ok {
|
||||
return database.Chat{}, xerrors.New("update chat workspace is not implemented by wrapped store")
|
||||
}
|
||||
return chatWorkspaceUpdater.UpdateChatWorkspace(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceCryptoKey); err != nil {
|
||||
return database.CryptoKey{}, err
|
||||
@@ -6056,6 +6443,30 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
|
||||
return q.db.UpsertBoundaryUsageStats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
// Authorize update on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
return q.db.UpsertChatDiffStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
// Authorize update on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
return q.db.UpsertChatDiffStatusReference(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
|
||||
return database.ConnectionLog{}, err
|
||||
@@ -6347,18 +6758,12 @@ func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg databas
|
||||
return q.CountConnectionLogs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
// TODO: Delete this function, all ListAIBridgeInterceptions should be authorized. For now just call ListAIBridgeInterceptions on the authz querier.
|
||||
// This cannot be deleted for now because it's included in the
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.ListAIBridgeInterceptions(ctx, arg)
|
||||
func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
return q.db.ListAuthorizedAIBridgeInterceptions(ctx, arg, prepared)
|
||||
}
|
||||
|
||||
func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) (int64, error) {
|
||||
// TODO: Delete this function, all CountAIBridgeInterceptions should be authorized. For now just call CountAIBridgeInterceptions on the authz querier.
|
||||
// This cannot be deleted for now because it's included in the
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.CountAIBridgeInterceptions(ctx, arg)
|
||||
func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
return q.db.CountAuthorizedAIBridgeInterceptions(ctx, arg, prepared)
|
||||
}
|
||||
|
||||
func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams, _ rbac.PreparedAuthorized) ([]string, error) {
|
||||
|
||||
@@ -170,6 +170,7 @@ func TestDBAuthzRecursive(t *testing.T) {
|
||||
Groups: []string{},
|
||||
Scope: rbac.ScopeAll,
|
||||
}
|
||||
preparedAuthorizedType := reflect.TypeOf((*rbac.PreparedAuthorized)(nil)).Elem()
|
||||
for i := 0; i < reflect.TypeOf(q).NumMethod(); i++ {
|
||||
var ins []reflect.Value
|
||||
ctx := dbauthz.As(context.Background(), actor)
|
||||
@@ -177,7 +178,13 @@ func TestDBAuthzRecursive(t *testing.T) {
|
||||
ins = append(ins, reflect.ValueOf(ctx))
|
||||
method := reflect.TypeOf(q).Method(i)
|
||||
for i := 2; i < method.Type.NumIn(); i++ {
|
||||
ins = append(ins, reflect.New(method.Type.In(i)).Elem())
|
||||
inType := method.Type.In(i)
|
||||
if inType.Implements(preparedAuthorizedType) {
|
||||
ins = append(ins, reflect.ValueOf(emptyPreparedAuthorized{}))
|
||||
continue
|
||||
}
|
||||
|
||||
ins = append(ins, reflect.New(inType).Elem())
|
||||
}
|
||||
if method.Name == "InTx" ||
|
||||
method.Name == "Ping" ||
|
||||
@@ -364,6 +371,371 @@ func (s *MethodTestSuite) TestConnectionLogs() {
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestChats() {
|
||||
s.Run("AcquireChat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.AcquireChatParams{
|
||||
StartedAt: dbtime.Now(),
|
||||
WorkerID: uuid.New(),
|
||||
}
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().AcquireChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("DeleteAllChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteAllChatQueuedMessages(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("DeleteChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionDelete).Returns()
|
||||
}))
|
||||
s.Run("DeleteChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatMessagesByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionDelete).Returns()
|
||||
}))
|
||||
s.Run("DeleteChatMessagesAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.DeleteChatMessagesAfterIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 123,
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatMessagesAfterID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("DeleteChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
id := uuid.New()
|
||||
dbm.EXPECT().DeleteChatModelConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
|
||||
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("DeleteChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
id := uuid.New()
|
||||
dbm.EXPECT().DeleteChatProviderByID(gomock.Any(), id).Return(nil).AnyTimes()
|
||||
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("DeleteChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
args := database.DeleteChatQueuedMessageParams{ID: 123, ChatID: chat.ID}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatQueuedMessage(gomock.Any(), args).Return(nil).AnyTimes()
|
||||
check.Args(args).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("GetChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
|
||||
}))
|
||||
s.Run("GetChatByIDForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByIDForUpdate(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
|
||||
}))
|
||||
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatDiffStatusByChatID(gomock.Any(), chat.ID).Return(diffStatus, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(diffStatus)
|
||||
}))
|
||||
s.Run("GetChatDiffStatusesByChatIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chatA := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
chatB := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
ids := []uuid.UUID{chatA.ID, chatB.ID}
|
||||
diffStatusA := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chatA.ID})
|
||||
diffStatusB := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chatB.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chatA.ID).Return(chatA, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chatB.ID).Return(chatB, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatDiffStatusesByChatIDs(gomock.Any(), ids).Return([]database.ChatDiffStatus{diffStatusA, diffStatusB}, nil).AnyTimes()
|
||||
check.Args(ids).
|
||||
Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).
|
||||
Returns([]database.ChatDiffStatus{diffStatusA, diffStatusB})
|
||||
}))
|
||||
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatMessageByID(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
check.Args(msg.ID).Asserts(chat, policy.ActionRead).Returns(msg)
|
||||
}))
|
||||
s.Run("GetChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), chat.ID).Return(msgs, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetChatMessagesForPromptByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chat.ID).Return(msgs, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
dbm.EXPECT().GetChatModelConfigByID(gomock.Any(), config.ID).Return(config, nil).AnyTimes()
|
||||
check.Args(config.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetDefaultChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
dbm.EXPECT().GetDefaultChatModelConfig(gomock.Any()).Return(config, nil).AnyTimes()
|
||||
check.Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetChatModelConfigByProviderAndModel", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
args := database.GetChatModelConfigByProviderAndModelParams{
|
||||
Provider: config.Provider,
|
||||
Model: config.Model,
|
||||
}
|
||||
dbm.EXPECT().GetChatModelConfigByProviderAndModel(gomock.Any(), args).Return(config, nil).AnyTimes()
|
||||
check.Args(args).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
dbm.EXPECT().GetChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
dbm.EXPECT().GetChatProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
|
||||
check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
|
||||
}))
|
||||
s.Run("GetChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerName := "test-provider"
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName})
|
||||
dbm.EXPECT().GetChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
|
||||
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
|
||||
}))
|
||||
s.Run("GetChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
|
||||
}))
|
||||
s.Run("GetChatsByOwnerID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
c1 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
c2 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatsByOwnerID(gomock.Any(), c1.OwnerID).Return([]database.Chat{c1, c2}, nil).AnyTimes()
|
||||
check.Args(c1.OwnerID).Asserts(c1, policy.ActionRead, c2, policy.ActionRead).Returns([]database.Chat{c1, c2})
|
||||
}))
|
||||
s.Run("GetChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
qms := []database.ChatQueuedMessage{testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatQueuedMessages(gomock.Any(), chat.ID).Return(qms, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(qms)
|
||||
}))
|
||||
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
dbm.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetEnabledChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
dbm.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
|
||||
}))
|
||||
s.Run("ListChatsByRootID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
rootChatID := uuid.New()
|
||||
chatA := testutil.Fake(s.T(), faker, database.Chat{RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}})
|
||||
chatB := testutil.Fake(s.T(), faker, database.Chat{RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}})
|
||||
dbm.EXPECT().ListChatsByRootID(gomock.Any(), rootChatID).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
|
||||
check.Args(rootChatID).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
|
||||
}))
|
||||
s.Run("ListChildChatsByParentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
parentChatID := uuid.New()
|
||||
chatA := testutil.Fake(s.T(), faker, database.Chat{ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}})
|
||||
chatB := testutil.Fake(s.T(), faker, database.Chat{ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}})
|
||||
dbm.EXPECT().ListChildChatsByParentID(gomock.Any(), parentChatID).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
|
||||
check.Args(parentChatID).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
|
||||
}))
|
||||
s.Run("GetStaleChats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
threshold := dbtime.Now()
|
||||
chats := []database.Chat{testutil.Fake(s.T(), faker, database.Chat{})}
|
||||
dbm.EXPECT().GetStaleChats(gomock.Any(), threshold).Return(chats, nil).AnyTimes()
|
||||
check.Args(threshold).Asserts(rbac.ResourceChat, policy.ActionRead).Returns(chats)
|
||||
}))
|
||||
s.Run("InsertChat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatParams{})
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{OwnerID: arg.OwnerID})
|
||||
dbm.EXPECT().InsertChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionCreate).Returns(chat)
|
||||
}))
|
||||
s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().InsertChatMessage(gomock.Any(), arg).Return(msg, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msg)
|
||||
}))
|
||||
s.Run("InsertChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatQueuedMessageParams{ChatID: chat.ID})
|
||||
qm := testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().InsertChatQueuedMessage(gomock.Any(), arg).Return(qm, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(qm)
|
||||
}))
|
||||
s.Run("InsertChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.InsertChatModelConfigParams{
|
||||
Provider: "test-provider",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
Enabled: true,
|
||||
}
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{Provider: arg.Provider, Model: arg.Model, DisplayName: arg.DisplayName, Enabled: arg.Enabled})
|
||||
dbm.EXPECT().InsertChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("InsertChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.InsertChatProviderParams{
|
||||
Provider: "test-provider",
|
||||
DisplayName: "Test Provider",
|
||||
APIKey: "test-api-key",
|
||||
Enabled: true,
|
||||
}
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: arg.Provider, DisplayName: arg.DisplayName, APIKey: arg.APIKey, Enabled: arg.Enabled})
|
||||
dbm.EXPECT().InsertChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
|
||||
}))
|
||||
s.Run("PopNextQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
qm := testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().PopNextQueuedMessage(gomock.Any(), chat.ID).Return(qm, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns(qm)
|
||||
}))
|
||||
s.Run("UpdateChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatByIDParams{
|
||||
ID: chat.ID,
|
||||
Title: "Updated title",
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatByID(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
WorkerID: uuid.New(),
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
|
||||
}))
|
||||
s.Run("UpdateChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
arg := database.UpdateChatMessageByIDParams{
|
||||
ID: msg.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"blocks":[{"type":"text","text":"updated"}]}`),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
updated := testutil.Fake(s.T(), faker, database.ChatMessage{ID: msg.ID, ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatMessageByID(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatMessageByID(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updated)
|
||||
}))
|
||||
s.Run("UpdateChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
arg := database.UpdateChatModelConfigParams{
|
||||
ID: config.ID,
|
||||
Provider: "updated-provider",
|
||||
Model: "updated-model",
|
||||
DisplayName: "Updated Model",
|
||||
Enabled: true,
|
||||
}
|
||||
dbm.EXPECT().UpdateChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("UpdateChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
arg := database.UpdateChatProviderParams{
|
||||
ID: provider.ID,
|
||||
DisplayName: "Updated Provider",
|
||||
APIKey: "updated-api-key",
|
||||
Enabled: true,
|
||||
}
|
||||
dbm.EXPECT().UpdateChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
|
||||
}))
|
||||
s.Run("UpdateChatStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatStatus(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatWorkspace", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatWorkspaceParams{
|
||||
ID: chat.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
WorkspaceAgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
}
|
||||
updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatWorkspace(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat)
|
||||
}))
|
||||
s.Run("UnsetDefaultChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UnsetDefaultChatModelConfigs(gomock.Any()).Return(nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertChatDiffStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
now := dbtime.Now()
|
||||
arg := database.UpsertChatDiffStatusParams{
|
||||
ChatID: chat.ID,
|
||||
Url: sql.NullString{String: "https://example.com/pr/123", Valid: true},
|
||||
PullRequestState: sql.NullString{String: "open", Valid: true},
|
||||
ChangesRequested: false,
|
||||
Additions: 10,
|
||||
Deletions: 5,
|
||||
ChangedFiles: 2,
|
||||
RefreshedAt: now,
|
||||
StaleAt: now.Add(time.Hour),
|
||||
}
|
||||
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpsertChatDiffStatus(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus)
|
||||
}))
|
||||
s.Run("UpsertChatDiffStatusReference", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpsertChatDiffStatusReferenceParams{
|
||||
ChatID: chat.ID,
|
||||
Url: sql.NullString{String: "https://example.com/pr/123", Valid: true},
|
||||
GitBranch: "feature/test",
|
||||
GitRemoteOrigin: "origin",
|
||||
StaleAt: dbtime.Now().Add(time.Hour),
|
||||
}
|
||||
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus)
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestFile() {
|
||||
s.Run("GetFileByHashAndCreator", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
f := testutil.Fake(s.T(), faker, database.File{})
|
||||
|
||||
@@ -104,6 +104,14 @@ func (m queryMetricsStore) DeleteOrganization(ctx context.Context, id uuid.UUID)
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.AcquireChat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("AcquireChat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "AcquireChat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.AcquireLock(ctx, pgAdvisoryXactLock)
|
||||
@@ -156,6 +164,7 @@ func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context
|
||||
start := time.Now()
|
||||
r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("BatchUpdateWorkspaceAgentMetadata").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpdateWorkspaceAgentMetadata").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -311,6 +320,14 @@ func (m queryMetricsStore) DeleteAPIKeysByUserID(ctx context.Context, userID uui
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteAllChatQueuedMessages(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("DeleteAllChatQueuedMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteAllChatQueuedMessages").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteAllTailnetTunnels(ctx, arg)
|
||||
@@ -335,6 +352,54 @@ func (m queryMetricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.C
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatMessagesAfterID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatMessagesAfterID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatMessagesAfterID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatMessagesByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatMessagesByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatMessagesByChatID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatModelConfigByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatModelConfigByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatModelConfigByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatProviderByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatProviderByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatProviderByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatQueuedMessage(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatQueuedMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatQueuedMessage").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteCryptoKey(ctx, arg)
|
||||
@@ -902,6 +967,126 @@ func (m queryMetricsStore) GetAuthorizationUserRoles(ctx context.Context, userID
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatByIDForUpdate(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatByIDForUpdate").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatByIDForUpdate").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatDiffStatusByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDiffStatusByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uuid.UUID) ([]database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
||||
m.queryLatencies.WithLabelValues("GetChatDiffStatusesByChatIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDiffStatusesByChatIDs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessageByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessageByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessageByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessagesByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesForPromptByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessagesForPromptByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesForPromptByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatModelConfigByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatModelConfigByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatModelConfigByProviderAndModel(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatModelConfigByProviderAndModel").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigByProviderAndModel").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatModelConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatModelConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviderByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByProvider(ctx, provider)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviderByProvider").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByProvider").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviders(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviders").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviders").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatQueuedMessages(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatQueuedMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatQueuedMessages").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID)
|
||||
m.queryLatencies.WithLabelValues("GetChatsByOwnerID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsByOwnerID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
|
||||
@@ -958,6 +1143,14 @@ func (m queryMetricsStore) GetDERPMeshKey(ctx context.Context) (string, error) {
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetDefaultChatModelConfig(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetDefaultChatModelConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetDefaultChatModelConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetDefaultOrganization(ctx context.Context) (database.Organization, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetDefaultOrganization(ctx)
|
||||
@@ -1022,6 +1215,22 @@ func (m queryMetricsStore) GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetEnabledChatModelConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetEnabledChatModelConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatModelConfigs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetEnabledChatProviders(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetEnabledChatProviders").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatProviders").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetExternalAuthLink(ctx, arg)
|
||||
@@ -1718,6 +1927,14 @@ func (m queryMetricsStore) GetRuntimeConfig(ctx context.Context, key string) (st
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetStaleChats(ctx, staleThreshold)
|
||||
m.queryLatencies.WithLabelValues("GetStaleChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetStaleChats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database.TailnetPeer, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetTailnetPeers(ctx, id)
|
||||
@@ -2710,6 +2927,46 @@ func (m queryMetricsStore) InsertAuditLog(ctx context.Context, arg database.Inse
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatMessage(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatMessage").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatModelConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatModelConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatModelConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatProvider(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatProvider").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatProvider").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatQueuedMessage(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatQueuedMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatQueuedMessage").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertCryptoKey(ctx, arg)
|
||||
@@ -3246,6 +3503,22 @@ func (m queryMetricsStore) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListChatsByRootID(ctx, rootChatID)
|
||||
m.queryLatencies.WithLabelValues("ListChatsByRootID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatsByRootID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListChildChatsByParentID(ctx, parentChatID)
|
||||
m.queryLatencies.WithLabelValues("ListChildChatsByParentID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChildChatsByParentID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListProvisionerKeysByOrganization(ctx, organizationID)
|
||||
@@ -3326,6 +3599,14 @@ func (m queryMetricsStore) PaginatedOrganizationMembers(ctx context.Context, arg
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.PopNextQueuedMessage(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("PopNextQueuedMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "PopNextQueuedMessage").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx, templateID)
|
||||
@@ -3398,6 +3679,14 @@ func (m queryMetricsStore) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UnsetDefaultChatModelConfigs(ctx context.Context) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UnsetDefaultChatModelConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("UnsetDefaultChatModelConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UnsetDefaultChatModelConfigs").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateAIBridgeInterceptionEnded(ctx, arg)
|
||||
@@ -3414,6 +3703,62 @@ func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.Up
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatByID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatMessageByID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatMessageByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatMessageByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatModelConfig(ctx context.Context, arg database.UpdateChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatModelConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatModelConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatModelConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatProvider(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatProvider").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatProvider").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatStatus(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatStatus").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatStatus").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatWorkspace(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatWorkspace").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspace").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateCryptoKeyDeletesAt(ctx, arg)
|
||||
@@ -4125,6 +4470,22 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatDiffStatus(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatDiffStatus").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDiffStatus").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatDiffStatusReference(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatDiffStatusReference").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDiffStatusReference").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
|
||||
|
||||
@@ -44,6 +44,21 @@ func (m *MockStore) EXPECT() *MockStoreMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AcquireChat mocks base method.
|
||||
func (m *MockStore) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcquireChat", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcquireChat indicates an expected call of AcquireChat.
|
||||
func (mr *MockStoreMockRecorder) AcquireChat(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireChat", reflect.TypeOf((*MockStore)(nil).AcquireChat), ctx, arg)
|
||||
}
|
||||
|
||||
// AcquireLock mocks base method.
|
||||
func (m *MockStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -469,6 +484,20 @@ func (mr *MockStoreMockRecorder) DeleteAPIKeysByUserID(ctx, userID any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteAPIKeysByUserID), ctx, userID)
|
||||
}
|
||||
|
||||
// DeleteAllChatQueuedMessages mocks base method.
|
||||
func (m *MockStore) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAllChatQueuedMessages", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAllChatQueuedMessages indicates an expected call of DeleteAllChatQueuedMessages.
|
||||
func (mr *MockStoreMockRecorder) DeleteAllChatQueuedMessages(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllChatQueuedMessages", reflect.TypeOf((*MockStore)(nil).DeleteAllChatQueuedMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// DeleteAllTailnetTunnels mocks base method.
|
||||
func (m *MockStore) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -511,6 +540,90 @@ func (mr *MockStoreMockRecorder) DeleteApplicationConnectAPIKeysByUserID(ctx, us
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteApplicationConnectAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteApplicationConnectAPIKeysByUserID), ctx, userID)
|
||||
}
|
||||
|
||||
// DeleteChatByID mocks base method.
|
||||
func (m *MockStore) DeleteChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatByID indicates an expected call of DeleteChatByID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatByID", reflect.TypeOf((*MockStore)(nil).DeleteChatByID), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteChatMessagesAfterID mocks base method.
|
||||
func (m *MockStore) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatMessagesAfterID", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatMessagesAfterID indicates an expected call of DeleteChatMessagesAfterID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatMessagesAfterID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).DeleteChatMessagesAfterID), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteChatMessagesByChatID mocks base method.
|
||||
func (m *MockStore) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatMessagesByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatMessagesByChatID indicates an expected call of DeleteChatMessagesByChatID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatMessagesByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).DeleteChatMessagesByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// DeleteChatModelConfigByID mocks base method.
|
||||
func (m *MockStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatModelConfigByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatModelConfigByID indicates an expected call of DeleteChatModelConfigByID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatModelConfigByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatModelConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteChatModelConfigByID), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteChatProviderByID mocks base method.
|
||||
func (m *MockStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatProviderByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatProviderByID indicates an expected call of DeleteChatProviderByID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatProviderByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatProviderByID", reflect.TypeOf((*MockStore)(nil).DeleteChatProviderByID), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteChatQueuedMessage mocks base method.
|
||||
func (m *MockStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatQueuedMessage", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatQueuedMessage indicates an expected call of DeleteChatQueuedMessage.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatQueuedMessage(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatQueuedMessage", reflect.TypeOf((*MockStore)(nil).DeleteChatQueuedMessage), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteCryptoKey mocks base method.
|
||||
func (m *MockStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1649,6 +1762,231 @@ func (mr *MockStoreMockRecorder) GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedWorkspacesAndAgentsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetAuthorizedWorkspacesAndAgentsByOwnerID), ctx, ownerID, prepared)
|
||||
}
|
||||
|
||||
// GetChatByID mocks base method.
|
||||
func (m *MockStore) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatByID indicates an expected call of GetChatByID.
|
||||
func (mr *MockStoreMockRecorder) GetChatByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatByID", reflect.TypeOf((*MockStore)(nil).GetChatByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatByIDForUpdate mocks base method.
|
||||
func (m *MockStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatByIDForUpdate", ctx, id)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatByIDForUpdate indicates an expected call of GetChatByIDForUpdate.
|
||||
func (mr *MockStoreMockRecorder) GetChatByIDForUpdate(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatByIDForUpdate), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatDiffStatusByChatID mocks base method.
|
||||
func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDiffStatusByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].(database.ChatDiffStatus)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDiffStatusByChatID indicates an expected call of GetChatDiffStatusByChatID.
|
||||
func (mr *MockStoreMockRecorder) GetChatDiffStatusByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDiffStatusByChatID", reflect.TypeOf((*MockStore)(nil).GetChatDiffStatusByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatDiffStatusesByChatIDs mocks base method.
|
||||
func (m *MockStore) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDiffStatusesByChatIDs", ctx, chatIds)
|
||||
ret0, _ := ret[0].([]database.ChatDiffStatus)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDiffStatusesByChatIDs indicates an expected call of GetChatDiffStatusesByChatIDs.
|
||||
func (mr *MockStoreMockRecorder) GetChatDiffStatusesByChatIDs(ctx, chatIds any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDiffStatusesByChatIDs", reflect.TypeOf((*MockStore)(nil).GetChatDiffStatusesByChatIDs), ctx, chatIds)
|
||||
}
|
||||
|
||||
// GetChatMessageByID mocks base method.
|
||||
func (m *MockStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessageByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessageByID indicates an expected call of GetChatMessageByID.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessageByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageByID", reflect.TypeOf((*MockStore)(nil).GetChatMessageByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatID mocks base method.
|
||||
func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessagesByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].([]database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatID indicates an expected call of GetChatMessagesByChatID.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatMessagesForPromptByChatID mocks base method.
|
||||
func (m *MockStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessagesForPromptByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].([]database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessagesForPromptByChatID indicates an expected call of GetChatMessagesForPromptByChatID.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessagesForPromptByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesForPromptByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesForPromptByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatModelConfigByID mocks base method.
|
||||
func (m *MockStore) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatModelConfigByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.ChatModelConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatModelConfigByID indicates an expected call of GetChatModelConfigByID.
|
||||
func (mr *MockStoreMockRecorder) GetChatModelConfigByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigByID", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatModelConfigByProviderAndModel mocks base method.
|
||||
func (m *MockStore) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatModelConfigByProviderAndModel", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatModelConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatModelConfigByProviderAndModel indicates an expected call of GetChatModelConfigByProviderAndModel.
|
||||
func (mr *MockStoreMockRecorder) GetChatModelConfigByProviderAndModel(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigByProviderAndModel", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigByProviderAndModel), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatModelConfigs mocks base method.
|
||||
func (m *MockStore) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatModelConfigs", ctx)
|
||||
ret0, _ := ret[0].([]database.ChatModelConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatModelConfigs indicates an expected call of GetChatModelConfigs.
|
||||
func (mr *MockStoreMockRecorder) GetChatModelConfigs(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetChatProviderByID mocks base method.
|
||||
func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatProviderByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatProviderByID indicates an expected call of GetChatProviderByID.
|
||||
func (mr *MockStoreMockRecorder) GetChatProviderByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByID", reflect.TypeOf((*MockStore)(nil).GetChatProviderByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatProviderByProvider mocks base method.
|
||||
func (m *MockStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatProviderByProvider", ctx, provider)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatProviderByProvider indicates an expected call of GetChatProviderByProvider.
|
||||
func (mr *MockStoreMockRecorder) GetChatProviderByProvider(ctx, provider any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByProvider", reflect.TypeOf((*MockStore)(nil).GetChatProviderByProvider), ctx, provider)
|
||||
}
|
||||
|
||||
// GetChatProviders mocks base method.
|
||||
func (m *MockStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatProviders", ctx)
|
||||
ret0, _ := ret[0].([]database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatProviders indicates an expected call of GetChatProviders.
|
||||
func (mr *MockStoreMockRecorder) GetChatProviders(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviders", reflect.TypeOf((*MockStore)(nil).GetChatProviders), ctx)
|
||||
}
|
||||
|
||||
// GetChatQueuedMessages mocks base method.
|
||||
func (m *MockStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatQueuedMessages", ctx, chatID)
|
||||
ret0, _ := ret[0].([]database.ChatQueuedMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatQueuedMessages indicates an expected call of GetChatQueuedMessages.
|
||||
func (mr *MockStoreMockRecorder) GetChatQueuedMessages(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatQueuedMessages", reflect.TypeOf((*MockStore)(nil).GetChatQueuedMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatsByOwnerID mocks base method.
|
||||
func (m *MockStore) GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatsByOwnerID", ctx, ownerID)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatsByOwnerID indicates an expected call of GetChatsByOwnerID.
|
||||
func (mr *MockStoreMockRecorder) GetChatsByOwnerID(ctx, ownerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetChatsByOwnerID), ctx, ownerID)
|
||||
}
|
||||
|
||||
// GetConnectionLogsOffset mocks base method.
|
||||
func (m *MockStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1754,6 +2092,21 @@ func (mr *MockStoreMockRecorder) GetDERPMeshKey(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDERPMeshKey", reflect.TypeOf((*MockStore)(nil).GetDERPMeshKey), ctx)
|
||||
}
|
||||
|
||||
// GetDefaultChatModelConfig mocks base method.
|
||||
func (m *MockStore) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetDefaultChatModelConfig", ctx)
|
||||
ret0, _ := ret[0].(database.ChatModelConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetDefaultChatModelConfig indicates an expected call of GetDefaultChatModelConfig.
|
||||
func (mr *MockStoreMockRecorder) GetDefaultChatModelConfig(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultChatModelConfig", reflect.TypeOf((*MockStore)(nil).GetDefaultChatModelConfig), ctx)
|
||||
}
|
||||
|
||||
// GetDefaultOrganization mocks base method.
|
||||
func (m *MockStore) GetDefaultOrganization(ctx context.Context) (database.Organization, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1874,6 +2227,36 @@ func (mr *MockStoreMockRecorder) GetEligibleProvisionerDaemonsByProvisionerJobID
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEligibleProvisionerDaemonsByProvisionerJobIDs", reflect.TypeOf((*MockStore)(nil).GetEligibleProvisionerDaemonsByProvisionerJobIDs), ctx, provisionerJobIds)
|
||||
}
|
||||
|
||||
// GetEnabledChatModelConfigs mocks base method.
|
||||
func (m *MockStore) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetEnabledChatModelConfigs", ctx)
|
||||
ret0, _ := ret[0].([]database.ChatModelConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetEnabledChatModelConfigs indicates an expected call of GetEnabledChatModelConfigs.
|
||||
func (mr *MockStoreMockRecorder) GetEnabledChatModelConfigs(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetEnabledChatModelConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetEnabledChatProviders mocks base method.
|
||||
func (m *MockStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetEnabledChatProviders", ctx)
|
||||
ret0, _ := ret[0].([]database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetEnabledChatProviders indicates an expected call of GetEnabledChatProviders.
|
||||
func (mr *MockStoreMockRecorder) GetEnabledChatProviders(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatProviders", reflect.TypeOf((*MockStore)(nil).GetEnabledChatProviders), ctx)
|
||||
}
|
||||
|
||||
// GetExternalAuthLink mocks base method.
|
||||
func (m *MockStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -3179,6 +3562,21 @@ func (mr *MockStoreMockRecorder) GetRuntimeConfig(ctx, key any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRuntimeConfig", reflect.TypeOf((*MockStore)(nil).GetRuntimeConfig), ctx, key)
|
||||
}
|
||||
|
||||
// GetStaleChats mocks base method.
|
||||
func (m *MockStore) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetStaleChats", ctx, staleThreshold)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetStaleChats indicates an expected call of GetStaleChats.
|
||||
func (mr *MockStoreMockRecorder) GetStaleChats(ctx, staleThreshold any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStaleChats", reflect.TypeOf((*MockStore)(nil).GetStaleChats), ctx, staleThreshold)
|
||||
}
|
||||
|
||||
// GetTailnetPeers mocks base method.
|
||||
func (m *MockStore) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database.TailnetPeer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5083,6 +5481,81 @@ func (mr *MockStoreMockRecorder) InsertAuditLog(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAuditLog", reflect.TypeOf((*MockStore)(nil).InsertAuditLog), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChat mocks base method.
|
||||
func (m *MockStore) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChat", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChat indicates an expected call of InsertChat.
|
||||
func (mr *MockStoreMockRecorder) InsertChat(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChat", reflect.TypeOf((*MockStore)(nil).InsertChat), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatMessage mocks base method.
|
||||
func (m *MockStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatMessage", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatMessage indicates an expected call of InsertChatMessage.
|
||||
func (mr *MockStoreMockRecorder) InsertChatMessage(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatMessage", reflect.TypeOf((*MockStore)(nil).InsertChatMessage), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatModelConfig mocks base method.
|
||||
func (m *MockStore) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatModelConfig", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatModelConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatModelConfig indicates an expected call of InsertChatModelConfig.
|
||||
func (mr *MockStoreMockRecorder) InsertChatModelConfig(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatModelConfig", reflect.TypeOf((*MockStore)(nil).InsertChatModelConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatProvider mocks base method.
|
||||
func (m *MockStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatProvider", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatProvider indicates an expected call of InsertChatProvider.
|
||||
func (mr *MockStoreMockRecorder) InsertChatProvider(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatProvider", reflect.TypeOf((*MockStore)(nil).InsertChatProvider), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatQueuedMessage mocks base method.
|
||||
func (m *MockStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatQueuedMessage", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatQueuedMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatQueuedMessage indicates an expected call of InsertChatQueuedMessage.
|
||||
func (mr *MockStoreMockRecorder) InsertChatQueuedMessage(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatQueuedMessage", reflect.TypeOf((*MockStore)(nil).InsertChatQueuedMessage), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertCryptoKey mocks base method.
|
||||
func (m *MockStore) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6102,6 +6575,36 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListChatsByRootID mocks base method.
|
||||
func (m *MockStore) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListChatsByRootID", ctx, rootChatID)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListChatsByRootID indicates an expected call of ListChatsByRootID.
|
||||
func (mr *MockStoreMockRecorder) ListChatsByRootID(ctx, rootChatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatsByRootID", reflect.TypeOf((*MockStore)(nil).ListChatsByRootID), ctx, rootChatID)
|
||||
}
|
||||
|
||||
// ListChildChatsByParentID mocks base method.
|
||||
func (m *MockStore) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListChildChatsByParentID", ctx, parentChatID)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListChildChatsByParentID indicates an expected call of ListChildChatsByParentID.
|
||||
func (mr *MockStoreMockRecorder) ListChildChatsByParentID(ctx, parentChatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChildChatsByParentID", reflect.TypeOf((*MockStore)(nil).ListChildChatsByParentID), ctx, parentChatID)
|
||||
}
|
||||
|
||||
// ListProvisionerKeysByOrganization mocks base method.
|
||||
func (m *MockStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6281,6 +6784,21 @@ func (mr *MockStoreMockRecorder) Ping(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockStore)(nil).Ping), ctx)
|
||||
}
|
||||
|
||||
// PopNextQueuedMessage mocks base method.
|
||||
func (m *MockStore) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PopNextQueuedMessage", ctx, chatID)
|
||||
ret0, _ := ret[0].(database.ChatQueuedMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// PopNextQueuedMessage indicates an expected call of PopNextQueuedMessage.
|
||||
func (mr *MockStoreMockRecorder) PopNextQueuedMessage(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopNextQueuedMessage", reflect.TypeOf((*MockStore)(nil).PopNextQueuedMessage), ctx, chatID)
|
||||
}
|
||||
|
||||
// ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate mocks base method.
|
||||
func (m *MockStore) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6411,6 +6929,20 @@ func (mr *MockStoreMockRecorder) UnfavoriteWorkspace(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnfavoriteWorkspace", reflect.TypeOf((*MockStore)(nil).UnfavoriteWorkspace), ctx, id)
|
||||
}
|
||||
|
||||
// UnsetDefaultChatModelConfigs mocks base method.
|
||||
func (m *MockStore) UnsetDefaultChatModelConfigs(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnsetDefaultChatModelConfigs", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnsetDefaultChatModelConfigs indicates an expected call of UnsetDefaultChatModelConfigs.
|
||||
func (mr *MockStoreMockRecorder) UnsetDefaultChatModelConfigs(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsetDefaultChatModelConfigs", reflect.TypeOf((*MockStore)(nil).UnsetDefaultChatModelConfigs), ctx)
|
||||
}
|
||||
|
||||
// UpdateAIBridgeInterceptionEnded mocks base method.
|
||||
func (m *MockStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6440,6 +6972,111 @@ func (mr *MockStoreMockRecorder) UpdateAPIKeyByID(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyByID", reflect.TypeOf((*MockStore)(nil).UpdateAPIKeyByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatByID mocks base method.
|
||||
func (m *MockStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatByID", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatByID indicates an expected call of UpdateChatByID.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatHeartbeat", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeat indicates an expected call of UpdateChatHeartbeat.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatMessageByID mocks base method.
|
||||
func (m *MockStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatMessageByID", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatMessageByID indicates an expected call of UpdateChatMessageByID.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatMessageByID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatMessageByID", reflect.TypeOf((*MockStore)(nil).UpdateChatMessageByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatModelConfig mocks base method.
|
||||
func (m *MockStore) UpdateChatModelConfig(ctx context.Context, arg database.UpdateChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatModelConfig", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatModelConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatModelConfig indicates an expected call of UpdateChatModelConfig.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatModelConfig(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatModelConfig", reflect.TypeOf((*MockStore)(nil).UpdateChatModelConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatProvider mocks base method.
|
||||
func (m *MockStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatProvider", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatProvider indicates an expected call of UpdateChatProvider.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatProvider(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatProvider", reflect.TypeOf((*MockStore)(nil).UpdateChatProvider), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatStatus mocks base method.
|
||||
func (m *MockStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatStatus", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatStatus indicates an expected call of UpdateChatStatus.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatStatus(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatus", reflect.TypeOf((*MockStore)(nil).UpdateChatStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatWorkspace mocks base method.
|
||||
func (m *MockStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatWorkspace", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatWorkspace indicates an expected call of UpdateChatWorkspace.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatWorkspace(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatWorkspace", reflect.TypeOf((*MockStore)(nil).UpdateChatWorkspace), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateCryptoKeyDeletesAt mocks base method.
|
||||
func (m *MockStore) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7722,6 +8359,36 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatus mocks base method.
|
||||
func (m *MockStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatDiffStatus", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDiffStatus)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatus indicates an expected call of UpsertChatDiffStatus.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatDiffStatus(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDiffStatus", reflect.TypeOf((*MockStore)(nil).UpsertChatDiffStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatusReference mocks base method.
|
||||
func (m *MockStore) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatDiffStatusReference", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDiffStatus)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatusReference indicates an expected call of UpsertChatDiffStatusReference.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatDiffStatusReference(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDiffStatusReference", reflect.TypeOf((*MockStore)(nil).UpsertChatDiffStatusReference), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertConnectionLog mocks base method.
|
||||
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+235
-1
@@ -210,7 +210,12 @@ CREATE TYPE api_key_scope AS ENUM (
|
||||
'boundary_usage:read',
|
||||
'boundary_usage:update',
|
||||
'workspace:update_agent',
|
||||
'workspace_dormant:update_agent'
|
||||
'workspace_dormant:update_agent',
|
||||
'chat:create',
|
||||
'chat:read',
|
||||
'chat:update',
|
||||
'chat:delete',
|
||||
'chat:*'
|
||||
);
|
||||
|
||||
CREATE TYPE app_sharing_level AS ENUM (
|
||||
@@ -260,6 +265,21 @@ CREATE TYPE build_reason AS ENUM (
|
||||
'task_resume'
|
||||
);
|
||||
|
||||
CREATE TYPE chat_message_visibility AS ENUM (
|
||||
'user',
|
||||
'model',
|
||||
'both'
|
||||
);
|
||||
|
||||
CREATE TYPE chat_status AS ENUM (
|
||||
'waiting',
|
||||
'pending',
|
||||
'running',
|
||||
'paused',
|
||||
'completed',
|
||||
'error'
|
||||
);
|
||||
|
||||
CREATE TYPE connection_status AS ENUM (
|
||||
'connected',
|
||||
'disconnected'
|
||||
@@ -1144,6 +1164,118 @@ COMMENT ON COLUMN boundary_usage_stats.window_start IS 'Start of the time window
|
||||
|
||||
COMMENT ON COLUMN boundary_usage_stats.updated_at IS 'Timestamp of the last update to this row.';
|
||||
|
||||
CREATE TABLE chat_diff_statuses (
|
||||
chat_id uuid NOT NULL,
|
||||
url text,
|
||||
pull_request_state text,
|
||||
changes_requested boolean DEFAULT false NOT NULL,
|
||||
additions integer DEFAULT 0 NOT NULL,
|
||||
deletions integer DEFAULT 0 NOT NULL,
|
||||
changed_files integer DEFAULT 0 NOT NULL,
|
||||
refreshed_at timestamp with time zone,
|
||||
stale_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
git_branch text DEFAULT ''::text NOT NULL,
|
||||
git_remote_origin text DEFAULT ''::text NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE chat_messages (
|
||||
id bigint NOT NULL,
|
||||
chat_id uuid NOT NULL,
|
||||
model_config_id uuid,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
role text NOT NULL,
|
||||
content jsonb,
|
||||
visibility chat_message_visibility DEFAULT 'both'::chat_message_visibility NOT NULL,
|
||||
input_tokens bigint,
|
||||
output_tokens bigint,
|
||||
total_tokens bigint,
|
||||
reasoning_tokens bigint,
|
||||
cache_creation_tokens bigint,
|
||||
cache_read_tokens bigint,
|
||||
context_limit bigint,
|
||||
compressed boolean DEFAULT false NOT NULL
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_messages_id_seq
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
ALTER SEQUENCE chat_messages_id_seq OWNED BY chat_messages.id;
|
||||
|
||||
CREATE TABLE chat_model_configs (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
provider text NOT NULL,
|
||||
model text NOT NULL,
|
||||
display_name text DEFAULT ''::text NOT NULL,
|
||||
created_by uuid,
|
||||
updated_by uuid,
|
||||
enabled boolean DEFAULT true NOT NULL,
|
||||
is_default boolean DEFAULT false NOT NULL,
|
||||
deleted boolean DEFAULT false NOT NULL,
|
||||
deleted_at timestamp with time zone,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
context_limit bigint NOT NULL,
|
||||
compression_threshold integer NOT NULL,
|
||||
options jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
CONSTRAINT chat_model_configs_compression_threshold_check CHECK (((compression_threshold >= 0) AND (compression_threshold <= 100))),
|
||||
CONSTRAINT chat_model_configs_context_limit_check CHECK ((context_limit > 0))
|
||||
);
|
||||
|
||||
CREATE TABLE chat_providers (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
provider text NOT NULL,
|
||||
display_name text DEFAULT ''::text NOT NULL,
|
||||
api_key text DEFAULT ''::text NOT NULL,
|
||||
api_key_key_id text,
|
||||
created_by uuid,
|
||||
enabled boolean DEFAULT true NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
base_url text DEFAULT ''::text NOT NULL,
|
||||
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text])))
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted';
|
||||
|
||||
CREATE TABLE chat_queued_messages (
|
||||
id bigint NOT NULL,
|
||||
chat_id uuid NOT NULL,
|
||||
content jsonb NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_queued_messages_id_seq
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
ALTER SEQUENCE chat_queued_messages_id_seq OWNED BY chat_queued_messages.id;
|
||||
|
||||
CREATE TABLE chats (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
owner_id uuid NOT NULL,
|
||||
workspace_id uuid,
|
||||
workspace_agent_id uuid,
|
||||
title text DEFAULT 'New Chat'::text NOT NULL,
|
||||
status chat_status DEFAULT 'waiting'::chat_status NOT NULL,
|
||||
worker_id uuid,
|
||||
started_at timestamp with time zone,
|
||||
heartbeat_at timestamp with time zone,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
parent_chat_id uuid,
|
||||
root_chat_id uuid,
|
||||
last_model_config_id uuid NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
id uuid NOT NULL,
|
||||
connect_time timestamp with time zone NOT NULL,
|
||||
@@ -2951,6 +3083,10 @@ CREATE VIEW workspaces_expanded AS
|
||||
|
||||
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
|
||||
|
||||
ALTER TABLE ONLY chat_messages ALTER COLUMN id SET DEFAULT nextval('chat_messages_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY chat_queued_messages ALTER COLUMN id SET DEFAULT nextval('chat_queued_messages_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY licenses ALTER COLUMN id SET DEFAULT nextval('licenses_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY provisioner_job_logs ALTER COLUMN id SET DEFAULT nextval('provisioner_job_logs_id_seq'::regclass);
|
||||
@@ -2987,6 +3123,27 @@ ALTER TABLE ONLY audit_logs
|
||||
ALTER TABLE ONLY boundary_usage_stats
|
||||
ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
|
||||
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
|
||||
ALTER TABLE ONLY chat_messages
|
||||
ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_model_configs
|
||||
ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_providers
|
||||
ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_providers
|
||||
ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
|
||||
|
||||
ALTER TABLE ONLY chat_queued_messages
|
||||
ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY connection_logs
|
||||
ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3314,6 +3471,38 @@ CREATE INDEX idx_audit_log_user_id ON audit_logs USING btree (user_id);
|
||||
|
||||
CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
|
||||
|
||||
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
|
||||
|
||||
CREATE INDEX idx_chat_messages_chat ON chat_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chat_messages_chat_created ON chat_messages USING btree (chat_id, created_at);
|
||||
|
||||
CREATE INDEX idx_chat_messages_compressed_summary_boundary ON chat_messages USING btree (chat_id, created_at DESC, id DESC) WHERE ((compressed = true) AND (role = 'system'::text) AND (visibility = ANY (ARRAY['model'::chat_message_visibility, 'both'::chat_message_visibility])));
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs USING btree (provider);
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_provider_model ON chat_model_configs USING btree (provider, model);
|
||||
|
||||
CREATE UNIQUE INDEX idx_chat_model_configs_single_default ON chat_model_configs USING btree ((1)) WHERE ((is_default = true) AND (deleted = false));
|
||||
|
||||
CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_config_id);
|
||||
|
||||
CREATE INDEX idx_chats_owner ON chats USING btree (owner_id);
|
||||
|
||||
CREATE INDEX idx_chats_parent_chat_id ON chats USING btree (parent_chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
|
||||
|
||||
CREATE INDEX idx_chats_root_chat_id ON chats USING btree (root_chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_workspace ON chats USING btree (workspace_id);
|
||||
|
||||
CREATE INDEX idx_connection_logs_connect_time_desc ON connection_logs USING btree (connect_time DESC);
|
||||
|
||||
CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name);
|
||||
@@ -3560,6 +3749,51 @@ ALTER TABLE ONLY aibridge_interceptions
|
||||
ALTER TABLE ONLY api_keys
|
||||
ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_messages
|
||||
ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_messages
|
||||
ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id);
|
||||
|
||||
ALTER TABLE ONLY chat_model_configs
|
||||
ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
|
||||
|
||||
ALTER TABLE ONLY chat_model_configs
|
||||
ADD CONSTRAINT chat_model_configs_provider_fkey FOREIGN KEY (provider) REFERENCES chat_providers(provider) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_model_configs
|
||||
ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id);
|
||||
|
||||
ALTER TABLE ONLY chat_providers
|
||||
ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY chat_providers
|
||||
ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
|
||||
|
||||
ALTER TABLE ONLY chat_queued_messages
|
||||
ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id);
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_parent_chat_id_fkey FOREIGN KEY (parent_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_root_chat_id_fkey FOREIGN KEY (root_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_workspace_agent_id_fkey FOREIGN KEY (workspace_agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY connection_logs
|
||||
ADD CONSTRAINT connection_logs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -8,6 +8,21 @@ type ForeignKeyConstraint string
|
||||
const (
|
||||
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
|
||||
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatMessagesModelConfigID ForeignKeyConstraint = "chat_messages_model_config_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id);
|
||||
ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
|
||||
ForeignKeyChatModelConfigsProvider ForeignKeyConstraint = "chat_model_configs_provider_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_provider_fkey FOREIGN KEY (provider) REFERENCES chat_providers(provider) ON DELETE CASCADE;
|
||||
ForeignKeyChatModelConfigsUpdatedBy ForeignKeyConstraint = "chat_model_configs_updated_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id);
|
||||
ForeignKeyChatProvidersAPIKeyKeyID ForeignKeyConstraint = "chat_providers_api_key_key_id_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyChatProvidersCreatedBy ForeignKeyConstraint = "chat_providers_created_by_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
|
||||
ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatsLastModelConfigID ForeignKeyConstraint = "chats_last_model_config_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id);
|
||||
ForeignKeyChatsOwnerID ForeignKeyConstraint = "chats_owner_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatsParentChatID ForeignKeyConstraint = "chats_parent_chat_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_parent_chat_id_fkey FOREIGN KEY (parent_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatsRootChatID ForeignKeyConstraint = "chats_root_chat_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_root_chat_id_fkey FOREIGN KEY (root_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatsWorkspaceAgentID ForeignKeyConstraint = "chats_workspace_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_workspace_agent_id_fkey FOREIGN KEY (workspace_agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatsWorkspaceID ForeignKeyConstraint = "chats_workspace_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE SET NULL;
|
||||
ForeignKeyConnectionLogsOrganizationID ForeignKeyConstraint = "connection_logs_organization_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
ForeignKeyConnectionLogsWorkspaceID ForeignKeyConstraint = "connection_logs_workspace_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
|
||||
ForeignKeyConnectionLogsWorkspaceOwnerID ForeignKeyConstraint = "connection_logs_workspace_owner_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_owner_id_fkey FOREIGN KEY (workspace_owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
DROP TABLE IF EXISTS chat_queued_messages;
|
||||
DROP TABLE IF EXISTS chat_diff_statuses;
|
||||
DROP TABLE IF EXISTS chat_messages;
|
||||
DROP TABLE IF EXISTS chats;
|
||||
DROP TABLE IF EXISTS chat_model_configs;
|
||||
DROP TABLE IF EXISTS chat_providers;
|
||||
DROP TYPE IF EXISTS chat_message_visibility;
|
||||
DROP TYPE IF EXISTS chat_status;
|
||||
@@ -0,0 +1,167 @@
|
||||
CREATE TYPE chat_status AS ENUM (
|
||||
'waiting',
|
||||
'pending',
|
||||
'running',
|
||||
'paused',
|
||||
'completed',
|
||||
'error'
|
||||
);
|
||||
|
||||
CREATE TYPE chat_message_visibility AS ENUM (
|
||||
'user',
|
||||
'model',
|
||||
'both'
|
||||
);
|
||||
|
||||
CREATE TABLE chats (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
workspace_id UUID REFERENCES workspaces(id) ON DELETE SET NULL,
|
||||
workspace_agent_id UUID REFERENCES workspace_agents(id) ON DELETE SET NULL,
|
||||
title TEXT NOT NULL DEFAULT 'New Chat',
|
||||
status chat_status NOT NULL DEFAULT 'waiting',
|
||||
worker_id UUID,
|
||||
started_at TIMESTAMPTZ,
|
||||
heartbeat_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
parent_chat_id UUID REFERENCES chats(id) ON DELETE SET NULL,
|
||||
root_chat_id UUID REFERENCES chats(id) ON DELETE SET NULL,
|
||||
last_model_config_id UUID NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chats_owner ON chats(owner_id);
|
||||
CREATE INDEX idx_chats_workspace ON chats(workspace_id);
|
||||
CREATE INDEX idx_chats_pending ON chats(status) WHERE status = 'pending';
|
||||
CREATE INDEX idx_chats_parent_chat_id ON chats(parent_chat_id);
|
||||
CREATE INDEX idx_chats_root_chat_id ON chats(root_chat_id);
|
||||
CREATE INDEX idx_chats_last_model_config_id ON chats(last_model_config_id);
|
||||
|
||||
CREATE TABLE chat_messages (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
|
||||
model_config_id UUID,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
role TEXT NOT NULL,
|
||||
content JSONB,
|
||||
visibility chat_message_visibility NOT NULL DEFAULT 'both',
|
||||
input_tokens BIGINT,
|
||||
output_tokens BIGINT,
|
||||
total_tokens BIGINT,
|
||||
reasoning_tokens BIGINT,
|
||||
cache_creation_tokens BIGINT,
|
||||
cache_read_tokens BIGINT,
|
||||
context_limit BIGINT,
|
||||
compressed BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_messages_chat ON chat_messages(chat_id);
|
||||
CREATE INDEX idx_chat_messages_chat_created ON chat_messages(chat_id, created_at);
|
||||
CREATE INDEX idx_chat_messages_compressed_summary_boundary
|
||||
ON chat_messages(chat_id, created_at DESC, id DESC)
|
||||
WHERE compressed = TRUE
|
||||
AND role = 'system'
|
||||
AND visibility IN ('model', 'both');
|
||||
|
||||
CREATE TABLE chat_diff_statuses (
|
||||
chat_id UUID PRIMARY KEY REFERENCES chats(id) ON DELETE CASCADE,
|
||||
url TEXT,
|
||||
pull_request_state TEXT,
|
||||
changes_requested BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
additions INTEGER NOT NULL DEFAULT 0,
|
||||
deletions INTEGER NOT NULL DEFAULT 0,
|
||||
changed_files INTEGER NOT NULL DEFAULT 0,
|
||||
refreshed_at TIMESTAMPTZ,
|
||||
stale_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
git_branch TEXT NOT NULL DEFAULT '',
|
||||
git_remote_origin TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses(stale_at);
|
||||
|
||||
CREATE TABLE chat_providers (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
provider TEXT NOT NULL UNIQUE,
|
||||
display_name TEXT NOT NULL DEFAULT '',
|
||||
api_key TEXT NOT NULL DEFAULT '',
|
||||
api_key_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
|
||||
created_by UUID REFERENCES users(id),
|
||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
base_url TEXT NOT NULL DEFAULT '',
|
||||
CONSTRAINT chat_providers_provider_check CHECK (
|
||||
provider = ANY (
|
||||
ARRAY[
|
||||
'anthropic'::text,
|
||||
'azure'::text,
|
||||
'bedrock'::text,
|
||||
'google'::text,
|
||||
'openai'::text,
|
||||
'openai-compat'::text,
|
||||
'openrouter'::text,
|
||||
'vercel'::text
|
||||
]
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted';
|
||||
|
||||
CREATE INDEX idx_chat_providers_enabled ON chat_providers(enabled);
|
||||
|
||||
CREATE TABLE chat_model_configs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
provider TEXT NOT NULL REFERENCES chat_providers(provider) ON DELETE CASCADE,
|
||||
model TEXT NOT NULL,
|
||||
display_name TEXT NOT NULL DEFAULT '',
|
||||
created_by UUID REFERENCES users(id),
|
||||
updated_by UUID REFERENCES users(id),
|
||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
is_default BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
deleted_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
context_limit BIGINT NOT NULL,
|
||||
compression_threshold INTEGER NOT NULL,
|
||||
options JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
CONSTRAINT chat_model_configs_context_limit_check
|
||||
CHECK (context_limit > 0),
|
||||
CONSTRAINT chat_model_configs_compression_threshold_check
|
||||
CHECK (compression_threshold >= 0 AND compression_threshold <= 100)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs(enabled);
|
||||
CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs(provider);
|
||||
CREATE INDEX idx_chat_model_configs_provider_model
|
||||
ON chat_model_configs(provider, model);
|
||||
CREATE UNIQUE INDEX idx_chat_model_configs_single_default
|
||||
ON chat_model_configs ((1))
|
||||
WHERE is_default = TRUE
|
||||
AND deleted = FALSE;
|
||||
|
||||
ALTER TABLE chat_messages
|
||||
ADD CONSTRAINT chat_messages_model_config_id_fkey
|
||||
FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id);
|
||||
|
||||
ALTER TABLE chats
|
||||
ADD CONSTRAINT chats_last_model_config_id_fkey
|
||||
FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id);
|
||||
|
||||
CREATE TABLE chat_queued_messages (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
|
||||
content JSONB NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages(chat_id);
|
||||
|
||||
ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'chat:create';
|
||||
ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'chat:read';
|
||||
ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'chat:update';
|
||||
ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'chat:delete';
|
||||
ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'chat:*';
|
||||
+114
@@ -0,0 +1,114 @@
|
||||
INSERT INTO chat_providers (
|
||||
id,
|
||||
provider,
|
||||
display_name,
|
||||
api_key,
|
||||
api_key_key_id,
|
||||
enabled,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
'0a8b2f84-b5a8-4c44-8c9f-e58c44a534a7',
|
||||
'openai',
|
||||
'OpenAI',
|
||||
'',
|
||||
NULL,
|
||||
TRUE,
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:00+00'
|
||||
);
|
||||
|
||||
INSERT INTO chat_model_configs (
|
||||
id,
|
||||
provider,
|
||||
model,
|
||||
display_name,
|
||||
enabled,
|
||||
context_limit,
|
||||
compression_threshold,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
'9af5f8d5-6a57-4505-8a69-3d6c787b95fd',
|
||||
'openai',
|
||||
'gpt-5.2',
|
||||
'GPT 5.2',
|
||||
TRUE,
|
||||
200000,
|
||||
70,
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:00+00'
|
||||
);
|
||||
|
||||
INSERT INTO chats (
|
||||
id,
|
||||
owner_id,
|
||||
last_model_config_id,
|
||||
title,
|
||||
status,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
SELECT
|
||||
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
|
||||
id,
|
||||
'9af5f8d5-6a57-4505-8a69-3d6c787b95fd',
|
||||
'Fixture Chat',
|
||||
'completed',
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:00+00'
|
||||
FROM users
|
||||
ORDER BY created_at, id
|
||||
LIMIT 1;
|
||||
|
||||
INSERT INTO chat_messages (
|
||||
chat_id,
|
||||
created_at,
|
||||
role,
|
||||
content
|
||||
) VALUES (
|
||||
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
|
||||
'2024-01-01 00:00:00+00',
|
||||
'assistant',
|
||||
'{"type":"text","text":"fixture"}'::jsonb
|
||||
);
|
||||
|
||||
INSERT INTO chat_diff_statuses (
|
||||
chat_id,
|
||||
url,
|
||||
pull_request_state,
|
||||
changes_requested,
|
||||
additions,
|
||||
deletions,
|
||||
changed_files,
|
||||
refreshed_at,
|
||||
stale_at,
|
||||
created_at,
|
||||
updated_at,
|
||||
git_branch,
|
||||
git_remote_origin
|
||||
) VALUES (
|
||||
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
|
||||
'https://example.com/pr/1',
|
||||
'open',
|
||||
FALSE,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:00+00',
|
||||
'main',
|
||||
'origin'
|
||||
);
|
||||
|
||||
INSERT INTO chat_queued_messages (
|
||||
chat_id,
|
||||
content,
|
||||
created_at
|
||||
) VALUES (
|
||||
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
|
||||
'{"type":"text","text":"queued fixture"}'::jsonb,
|
||||
'2024-01-01 00:00:00+00'
|
||||
);
|
||||
@@ -165,6 +165,10 @@ func (t TaskTable) RBACObject() rbac.Object {
|
||||
InOrg(t.OrganizationID)
|
||||
}
|
||||
|
||||
func (c Chat) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String())
|
||||
}
|
||||
|
||||
func (s APIKeyScope) ToRBAC() rbac.ScopeName {
|
||||
switch s {
|
||||
case ApiKeyScopeCoderAll:
|
||||
|
||||
+237
-1
@@ -219,6 +219,11 @@ const (
|
||||
ApiKeyScopeBoundaryUsageUpdate APIKeyScope = "boundary_usage:update"
|
||||
ApiKeyScopeWorkspaceUpdateAgent APIKeyScope = "workspace:update_agent"
|
||||
ApiKeyScopeWorkspaceDormantUpdateAgent APIKeyScope = "workspace_dormant:update_agent"
|
||||
ApiKeyScopeChatCreate APIKeyScope = "chat:create"
|
||||
ApiKeyScopeChatRead APIKeyScope = "chat:read"
|
||||
ApiKeyScopeChatUpdate APIKeyScope = "chat:update"
|
||||
ApiKeyScopeChatDelete APIKeyScope = "chat:delete"
|
||||
ApiKeyScopeChat APIKeyScope = "chat:*"
|
||||
)
|
||||
|
||||
func (e *APIKeyScope) Scan(src interface{}) error {
|
||||
@@ -457,7 +462,12 @@ func (e APIKeyScope) Valid() bool {
|
||||
ApiKeyScopeBoundaryUsageRead,
|
||||
ApiKeyScopeBoundaryUsageUpdate,
|
||||
ApiKeyScopeWorkspaceUpdateAgent,
|
||||
ApiKeyScopeWorkspaceDormantUpdateAgent:
|
||||
ApiKeyScopeWorkspaceDormantUpdateAgent,
|
||||
ApiKeyScopeChatCreate,
|
||||
ApiKeyScopeChatRead,
|
||||
ApiKeyScopeChatUpdate,
|
||||
ApiKeyScopeChatDelete,
|
||||
ApiKeyScopeChat:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -665,6 +675,11 @@ func AllAPIKeyScopeValues() []APIKeyScope {
|
||||
ApiKeyScopeBoundaryUsageUpdate,
|
||||
ApiKeyScopeWorkspaceUpdateAgent,
|
||||
ApiKeyScopeWorkspaceDormantUpdateAgent,
|
||||
ApiKeyScopeChatCreate,
|
||||
ApiKeyScopeChatRead,
|
||||
ApiKeyScopeChatUpdate,
|
||||
ApiKeyScopeChatDelete,
|
||||
ApiKeyScopeChat,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1034,6 +1049,137 @@ func AllBuildReasonValues() []BuildReason {
|
||||
}
|
||||
}
|
||||
|
||||
type ChatMessageVisibility string
|
||||
|
||||
const (
|
||||
ChatMessageVisibilityUser ChatMessageVisibility = "user"
|
||||
ChatMessageVisibilityModel ChatMessageVisibility = "model"
|
||||
ChatMessageVisibilityBoth ChatMessageVisibility = "both"
|
||||
)
|
||||
|
||||
func (e *ChatMessageVisibility) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ChatMessageVisibility(s)
|
||||
case string:
|
||||
*e = ChatMessageVisibility(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ChatMessageVisibility: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullChatMessageVisibility struct {
|
||||
ChatMessageVisibility ChatMessageVisibility `json:"chat_message_visibility"`
|
||||
Valid bool `json:"valid"` // Valid is true if ChatMessageVisibility is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (ns *NullChatMessageVisibility) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
ns.ChatMessageVisibility, ns.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
ns.Valid = true
|
||||
return ns.ChatMessageVisibility.Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (ns NullChatMessageVisibility) Value() (driver.Value, error) {
|
||||
if !ns.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return string(ns.ChatMessageVisibility), nil
|
||||
}
|
||||
|
||||
func (e ChatMessageVisibility) Valid() bool {
|
||||
switch e {
|
||||
case ChatMessageVisibilityUser,
|
||||
ChatMessageVisibilityModel,
|
||||
ChatMessageVisibilityBoth:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AllChatMessageVisibilityValues() []ChatMessageVisibility {
|
||||
return []ChatMessageVisibility{
|
||||
ChatMessageVisibilityUser,
|
||||
ChatMessageVisibilityModel,
|
||||
ChatMessageVisibilityBoth,
|
||||
}
|
||||
}
|
||||
|
||||
type ChatStatus string
|
||||
|
||||
const (
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
)
|
||||
|
||||
func (e *ChatStatus) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ChatStatus(s)
|
||||
case string:
|
||||
*e = ChatStatus(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ChatStatus: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullChatStatus struct {
|
||||
ChatStatus ChatStatus `json:"chat_status"`
|
||||
Valid bool `json:"valid"` // Valid is true if ChatStatus is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (ns *NullChatStatus) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
ns.ChatStatus, ns.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
ns.Valid = true
|
||||
return ns.ChatStatus.Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (ns NullChatStatus) Value() (driver.Value, error) {
|
||||
if !ns.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return string(ns.ChatStatus), nil
|
||||
}
|
||||
|
||||
func (e ChatStatus) Valid() bool {
|
||||
switch e {
|
||||
case ChatStatusWaiting,
|
||||
ChatStatusPending,
|
||||
ChatStatusRunning,
|
||||
ChatStatusPaused,
|
||||
ChatStatusCompleted,
|
||||
ChatStatusError:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AllChatStatusValues() []ChatStatus {
|
||||
return []ChatStatus{
|
||||
ChatStatusWaiting,
|
||||
ChatStatusPending,
|
||||
ChatStatusRunning,
|
||||
ChatStatusPaused,
|
||||
ChatStatusCompleted,
|
||||
ChatStatusError,
|
||||
}
|
||||
}
|
||||
|
||||
type ConnectionStatus string
|
||||
|
||||
const (
|
||||
@@ -3739,6 +3885,96 @@ type BoundaryUsageStat struct {
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
type Chat struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
WorkspaceAgentID uuid.NullUUID `db:"workspace_agent_id" json:"workspace_agent_id"`
|
||||
Title string `db:"title" json:"title"`
|
||||
Status ChatStatus `db:"status" json:"status"`
|
||||
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
|
||||
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
|
||||
HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
Url sql.NullString `db:"url" json:"url"`
|
||||
PullRequestState sql.NullString `db:"pull_request_state" json:"pull_request_state"`
|
||||
ChangesRequested bool `db:"changes_requested" json:"changes_requested"`
|
||||
Additions int32 `db:"additions" json:"additions"`
|
||||
Deletions int32 `db:"deletions" json:"deletions"`
|
||||
ChangedFiles int32 `db:"changed_files" json:"changed_files"`
|
||||
RefreshedAt sql.NullTime `db:"refreshed_at" json:"refreshed_at"`
|
||||
StaleAt time.Time `db:"stale_at" json:"stale_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
GitBranch string `db:"git_branch" json:"git_branch"`
|
||||
GitRemoteOrigin string `db:"git_remote_origin" json:"git_remote_origin"`
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
Role string `db:"role" json:"role"`
|
||||
Content pqtype.NullRawMessage `db:"content" json:"content"`
|
||||
Visibility ChatMessageVisibility `db:"visibility" json:"visibility"`
|
||||
InputTokens sql.NullInt64 `db:"input_tokens" json:"input_tokens"`
|
||||
OutputTokens sql.NullInt64 `db:"output_tokens" json:"output_tokens"`
|
||||
TotalTokens sql.NullInt64 `db:"total_tokens" json:"total_tokens"`
|
||||
ReasoningTokens sql.NullInt64 `db:"reasoning_tokens" json:"reasoning_tokens"`
|
||||
CacheCreationTokens sql.NullInt64 `db:"cache_creation_tokens" json:"cache_creation_tokens"`
|
||||
CacheReadTokens sql.NullInt64 `db:"cache_read_tokens" json:"cache_read_tokens"`
|
||||
ContextLimit sql.NullInt64 `db:"context_limit" json:"context_limit"`
|
||||
Compressed bool `db:"compressed" json:"compressed"`
|
||||
}
|
||||
|
||||
type ChatModelConfig struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
Model string `db:"model" json:"model"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
UpdatedBy uuid.NullUUID `db:"updated_by" json:"updated_by"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
IsDefault bool `db:"is_default" json:"is_default"`
|
||||
Deleted bool `db:"deleted" json:"deleted"`
|
||||
DeletedAt sql.NullTime `db:"deleted_at" json:"deleted_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ContextLimit int64 `db:"context_limit" json:"context_limit"`
|
||||
CompressionThreshold int32 `db:"compression_threshold" json:"compression_threshold"`
|
||||
Options json.RawMessage `db:"options" json:"options"`
|
||||
}
|
||||
|
||||
type ChatProvider struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
// The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
}
|
||||
|
||||
type ChatQueuedMessage struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
Content json.RawMessage `db:"content" json:"content"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
type ConnectionLog struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
ConnectTime time.Time `db:"connect_time" json:"connect_time"`
|
||||
|
||||
@@ -12,6 +12,9 @@ import (
|
||||
)
|
||||
|
||||
type sqlcQuerier interface {
|
||||
// Acquires a pending chat for processing. Uses SKIP LOCKED to prevent
|
||||
// multiple replicas from acquiring the same chat.
|
||||
AcquireChat(ctx context.Context, arg AcquireChatParams) (Chat, error)
|
||||
// Blocks until the lock is acquired.
|
||||
//
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
@@ -81,6 +84,7 @@ type sqlcQuerier interface {
|
||||
CustomRoles(ctx context.Context, arg CustomRolesParams) ([]CustomRole, error)
|
||||
DeleteAPIKeyByID(ctx context.Context, id string) error
|
||||
DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
|
||||
DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error
|
||||
DeleteAllTailnetTunnels(ctx context.Context, arg DeleteAllTailnetTunnelsParams) error
|
||||
// Deletes all existing webpush subscriptions.
|
||||
// This should be called when the VAPID keypair is regenerated, as the old
|
||||
@@ -88,6 +92,12 @@ type sqlcQuerier interface {
|
||||
// be recreated.
|
||||
DeleteAllWebpushSubscriptions(ctx context.Context) error
|
||||
DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
|
||||
DeleteChatByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatMessagesAfterID(ctx context.Context, arg DeleteChatMessagesAfterIDParams) error
|
||||
DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error
|
||||
DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error
|
||||
DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error)
|
||||
DeleteCustomRole(ctx context.Context, arg DeleteCustomRoleParams) error
|
||||
DeleteExpiredAPIKeys(ctx context.Context, arg DeleteExpiredAPIKeysParams) (int64, error)
|
||||
@@ -199,6 +209,21 @@ type sqlcQuerier interface {
|
||||
// This function returns roles for authorization purposes. Implied member roles
|
||||
// are included.
|
||||
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error)
|
||||
GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error)
|
||||
GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error)
|
||||
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
|
||||
GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
|
||||
GetChatModelConfigByProviderAndModel(ctx context.Context, arg GetChatModelConfigByProviderAndModelParams) (ChatModelConfig, error)
|
||||
GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
|
||||
GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error)
|
||||
GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error)
|
||||
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
|
||||
GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]Chat, error)
|
||||
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
|
||||
GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
|
||||
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
|
||||
@@ -206,6 +231,7 @@ type sqlcQuerier interface {
|
||||
GetCryptoKeysByFeature(ctx context.Context, feature CryptoKeyFeature) ([]CryptoKey, error)
|
||||
GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error)
|
||||
GetDERPMeshKey(ctx context.Context) (string, error)
|
||||
GetDefaultChatModelConfig(ctx context.Context) (ChatModelConfig, error)
|
||||
GetDefaultOrganization(ctx context.Context) (Organization, error)
|
||||
GetDefaultProxyConfig(ctx context.Context) (GetDefaultProxyConfigRow, error)
|
||||
GetDeploymentDAUs(ctx context.Context, tzOffset int32) ([]GetDeploymentDAUsRow, error)
|
||||
@@ -214,6 +240,8 @@ type sqlcQuerier interface {
|
||||
GetDeploymentWorkspaceAgentUsageStats(ctx context.Context, createdAt time.Time) (GetDeploymentWorkspaceAgentUsageStatsRow, error)
|
||||
GetDeploymentWorkspaceStats(ctx context.Context) (GetDeploymentWorkspaceStatsRow, error)
|
||||
GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx context.Context, provisionerJobIds []uuid.UUID) ([]GetEligibleProvisionerDaemonsByProvisionerJobIDsRow, error)
|
||||
GetEnabledChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
|
||||
GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetExternalAuthLink(ctx context.Context, arg GetExternalAuthLinkParams) (ExternalAuthLink, error)
|
||||
GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error)
|
||||
GetFailedWorkspaceBuildsByTemplateID(ctx context.Context, arg GetFailedWorkspaceBuildsByTemplateIDParams) ([]GetFailedWorkspaceBuildsByTemplateIDRow, error)
|
||||
@@ -352,6 +380,9 @@ type sqlcQuerier interface {
|
||||
GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error)
|
||||
GetRunningPrebuiltWorkspaces(ctx context.Context) ([]GetRunningPrebuiltWorkspacesRow, error)
|
||||
GetRuntimeConfig(ctx context.Context, key string) (string, error)
|
||||
// Find chats that appear stuck (running but heartbeat has expired).
|
||||
// Used for recovery after coderd crashes or long hangs.
|
||||
GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error)
|
||||
GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
|
||||
GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerBindingsRow, error)
|
||||
GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerIDsRow, error)
|
||||
@@ -574,6 +605,11 @@ type sqlcQuerier interface {
|
||||
// every member of the org.
|
||||
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
|
||||
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
|
||||
InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error)
|
||||
InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error)
|
||||
InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error)
|
||||
InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error)
|
||||
InsertChatQueuedMessage(ctx context.Context, arg InsertChatQueuedMessageParams) (ChatQueuedMessage, error)
|
||||
InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error)
|
||||
InsertCustomRole(ctx context.Context, arg InsertCustomRoleParams) (CustomRole, error)
|
||||
InsertDBCryptKey(ctx context.Context, arg InsertDBCryptKeyParams) error
|
||||
@@ -657,6 +693,8 @@ type sqlcQuerier interface {
|
||||
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
|
||||
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
|
||||
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
|
||||
ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]Chat, error)
|
||||
ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]Chat, error)
|
||||
ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
|
||||
@@ -673,6 +711,7 @@ type sqlcQuerier interface {
|
||||
// - Use both to get a specific org member row
|
||||
OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error)
|
||||
PaginatedOrganizationMembers(ctx context.Context, arg PaginatedOrganizationMembersParams) ([]PaginatedOrganizationMembersRow, error)
|
||||
PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (ChatQueuedMessage, error)
|
||||
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
|
||||
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
|
||||
RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error
|
||||
@@ -691,8 +730,18 @@ type sqlcQuerier interface {
|
||||
// This will always work regardless of the current state of the template version.
|
||||
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
|
||||
UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error
|
||||
UnsetDefaultChatModelConfigs(ctx context.Context) error
|
||||
UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error)
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
|
||||
UpdateChatMessageByID(ctx context.Context, arg UpdateChatMessageByIDParams) (ChatMessage, error)
|
||||
UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error)
|
||||
UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error)
|
||||
UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error)
|
||||
UpdateChatWorkspace(ctx context.Context, arg UpdateChatWorkspaceParams) (Chat, error)
|
||||
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
|
||||
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
|
||||
UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error)
|
||||
@@ -790,6 +839,8 @@ type sqlcQuerier interface {
|
||||
// cumulative values for unique counts (accurate period totals). Request counts
|
||||
// are always deltas, accumulated in DB. Returns true if insert, false if update.
|
||||
UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBoundaryUsageStatsParams) (bool, error)
|
||||
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
|
||||
UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error)
|
||||
UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error
|
||||
// The default proxy is implied and not actually stored in the database.
|
||||
|
||||
@@ -7548,6 +7548,47 @@ func TestGetTaskByWorkspaceID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteTaskDeletesTaskSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
template := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
})
|
||||
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true},
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
})
|
||||
task := dbgen.Task(t, db, database.TaskTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
TemplateVersionID: templateVersion.ID,
|
||||
Prompt: "Test prompt",
|
||||
})
|
||||
|
||||
err := db.UpsertTaskSnapshot(ctx, database.UpsertTaskSnapshotParams{
|
||||
TaskID: task.ID,
|
||||
LogSnapshot: json.RawMessage(`{"messages":[]}`),
|
||||
LogSnapshotCreatedAt: dbtime.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.DeleteTask(ctx, database.DeleteTaskParams{
|
||||
ID: task.ID,
|
||||
DeletedAt: dbtime.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.GetTaskSnapshot(ctx, task.ID)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
}
|
||||
|
||||
func TestTaskNameUniqueness(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+1875
-1
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,129 @@
|
||||
-- name: GetChatModelConfigByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_model_configs
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
AND deleted = FALSE;
|
||||
|
||||
-- name: GetDefaultChatModelConfig :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_model_configs
|
||||
WHERE
|
||||
is_default = TRUE
|
||||
AND deleted = FALSE;
|
||||
|
||||
-- name: GetChatModelConfigByProviderAndModel :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_model_configs
|
||||
WHERE
|
||||
provider = @provider::text
|
||||
AND model = @model::text
|
||||
AND deleted = FALSE
|
||||
ORDER BY
|
||||
updated_at DESC,
|
||||
created_at DESC,
|
||||
id DESC
|
||||
LIMIT 1;
|
||||
|
||||
-- name: GetChatModelConfigs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_model_configs
|
||||
WHERE
|
||||
deleted = FALSE
|
||||
ORDER BY
|
||||
provider ASC,
|
||||
model ASC,
|
||||
updated_at DESC,
|
||||
id DESC;
|
||||
|
||||
-- name: GetEnabledChatModelConfigs :many
|
||||
SELECT
|
||||
cmc.*
|
||||
FROM
|
||||
chat_model_configs cmc
|
||||
JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider
|
||||
WHERE
|
||||
cmc.enabled = TRUE
|
||||
AND cmc.deleted = FALSE
|
||||
AND cp.enabled = TRUE
|
||||
ORDER BY
|
||||
cmc.provider ASC,
|
||||
cmc.model ASC,
|
||||
cmc.updated_at DESC,
|
||||
cmc.id DESC;
|
||||
|
||||
-- name: InsertChatModelConfig :one
|
||||
INSERT INTO chat_model_configs (
|
||||
provider,
|
||||
model,
|
||||
display_name,
|
||||
created_by,
|
||||
updated_by,
|
||||
enabled,
|
||||
is_default,
|
||||
context_limit,
|
||||
compression_threshold,
|
||||
options
|
||||
) VALUES (
|
||||
@provider::text,
|
||||
@model::text,
|
||||
@display_name::text,
|
||||
sqlc.narg('created_by')::uuid,
|
||||
sqlc.narg('updated_by')::uuid,
|
||||
@enabled::boolean,
|
||||
@is_default::boolean,
|
||||
@context_limit::bigint,
|
||||
@compression_threshold::integer,
|
||||
@options::jsonb
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatModelConfig :one
|
||||
UPDATE
|
||||
chat_model_configs
|
||||
SET
|
||||
provider = @provider::text,
|
||||
model = @model::text,
|
||||
display_name = @display_name::text,
|
||||
updated_by = sqlc.narg('updated_by')::uuid,
|
||||
enabled = @enabled::boolean,
|
||||
is_default = @is_default::boolean,
|
||||
context_limit = @context_limit::bigint,
|
||||
compression_threshold = @compression_threshold::integer,
|
||||
options = @options::jsonb,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
AND deleted = FALSE
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UnsetDefaultChatModelConfigs :exec
|
||||
UPDATE
|
||||
chat_model_configs
|
||||
SET
|
||||
is_default = FALSE,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
is_default = TRUE
|
||||
AND deleted = FALSE;
|
||||
|
||||
-- name: DeleteChatModelConfigByID :exec
|
||||
UPDATE
|
||||
chat_model_configs
|
||||
SET
|
||||
deleted = TRUE,
|
||||
deleted_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
@@ -0,0 +1,75 @@
|
||||
-- name: GetChatProviderByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
|
||||
-- name: GetChatProviderByProvider :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
provider = @provider::text;
|
||||
|
||||
-- name: GetChatProviders :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
ORDER BY
|
||||
provider ASC;
|
||||
|
||||
-- name: GetEnabledChatProviders :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
enabled = TRUE
|
||||
ORDER BY
|
||||
provider ASC;
|
||||
|
||||
-- name: InsertChatProvider :one
|
||||
INSERT INTO chat_providers (
|
||||
provider,
|
||||
display_name,
|
||||
api_key,
|
||||
base_url,
|
||||
api_key_key_id,
|
||||
created_by,
|
||||
enabled
|
||||
) VALUES (
|
||||
@provider::text,
|
||||
@display_name::text,
|
||||
@api_key::text,
|
||||
@base_url::text,
|
||||
sqlc.narg('api_key_key_id')::text,
|
||||
sqlc.narg('created_by')::uuid,
|
||||
@enabled::boolean
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatProvider :one
|
||||
UPDATE
|
||||
chat_providers
|
||||
SET
|
||||
display_name = @display_name::text,
|
||||
api_key = @api_key::text,
|
||||
base_url = @base_url::text,
|
||||
api_key_key_id = sqlc.narg('api_key_key_id')::text,
|
||||
enabled = @enabled::boolean,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: DeleteChatProviderByID :exec
|
||||
DELETE FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
@@ -0,0 +1,409 @@
|
||||
-- name: DeleteChatByID :exec
|
||||
DELETE FROM
|
||||
chats
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
|
||||
-- name: DeleteChatMessagesByChatID :exec
|
||||
DELETE FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid;
|
||||
|
||||
-- name: DeleteChatMessagesAfterID :exec
|
||||
DELETE FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND id > @after_id::bigint;
|
||||
|
||||
-- name: GetChatByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
|
||||
-- name: GetChatMessageByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
id = @id::bigint;
|
||||
|
||||
-- name: GetChatMessagesByChatID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND visibility IN ('user', 'both')
|
||||
ORDER BY
|
||||
created_at ASC;
|
||||
|
||||
-- name: GetChatMessagesForPromptByChatID :many
|
||||
WITH latest_compressed_summary AS (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND role = 'system'
|
||||
AND visibility IN ('model', 'both')
|
||||
AND compressed = TRUE
|
||||
ORDER BY
|
||||
created_at DESC,
|
||||
id DESC
|
||||
LIMIT
|
||||
1
|
||||
)
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND visibility IN ('model', 'both')
|
||||
AND (
|
||||
(
|
||||
role = 'system'
|
||||
AND compressed = FALSE
|
||||
)
|
||||
OR (
|
||||
compressed = FALSE
|
||||
AND (
|
||||
NOT EXISTS (
|
||||
SELECT
|
||||
1
|
||||
FROM
|
||||
latest_compressed_summary
|
||||
)
|
||||
OR id > (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
latest_compressed_summary
|
||||
)
|
||||
)
|
||||
)
|
||||
OR id = (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
latest_compressed_summary
|
||||
)
|
||||
)
|
||||
ORDER BY
|
||||
created_at ASC,
|
||||
id ASC;
|
||||
|
||||
-- name: GetChatsByOwnerID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
owner_id = @owner_id::uuid
|
||||
ORDER BY
|
||||
updated_at DESC;
|
||||
|
||||
-- name: ListChildChatsByParentID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
parent_chat_id = @parent_chat_id::uuid
|
||||
ORDER BY
|
||||
created_at ASC;
|
||||
|
||||
-- name: ListChatsByRootID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
root_chat_id = @root_chat_id::uuid
|
||||
ORDER BY
|
||||
created_at ASC;
|
||||
|
||||
-- name: InsertChat :one
|
||||
INSERT INTO chats (
|
||||
owner_id,
|
||||
workspace_id,
|
||||
workspace_agent_id,
|
||||
parent_chat_id,
|
||||
root_chat_id,
|
||||
last_model_config_id,
|
||||
title
|
||||
) VALUES (
|
||||
@owner_id::uuid,
|
||||
sqlc.narg('workspace_id')::uuid,
|
||||
sqlc.narg('workspace_agent_id')::uuid,
|
||||
sqlc.narg('parent_chat_id')::uuid,
|
||||
sqlc.narg('root_chat_id')::uuid,
|
||||
@last_model_config_id::uuid,
|
||||
@title::text
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: InsertChatMessage :one
|
||||
WITH updated_chat AS (
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
last_model_config_id = sqlc.narg('model_config_id')::uuid
|
||||
WHERE
|
||||
id = @chat_id::uuid
|
||||
AND sqlc.narg('model_config_id')::uuid IS NOT NULL
|
||||
)
|
||||
INSERT INTO chat_messages (
|
||||
chat_id,
|
||||
model_config_id,
|
||||
role,
|
||||
content,
|
||||
visibility,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
total_tokens,
|
||||
reasoning_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
context_limit,
|
||||
compressed
|
||||
) VALUES (
|
||||
@chat_id::uuid,
|
||||
sqlc.narg('model_config_id')::uuid,
|
||||
@role::text,
|
||||
sqlc.narg('content')::jsonb,
|
||||
@visibility::chat_message_visibility,
|
||||
sqlc.narg('input_tokens')::bigint,
|
||||
sqlc.narg('output_tokens')::bigint,
|
||||
sqlc.narg('total_tokens')::bigint,
|
||||
sqlc.narg('reasoning_tokens')::bigint,
|
||||
sqlc.narg('cache_creation_tokens')::bigint,
|
||||
sqlc.narg('cache_read_tokens')::bigint,
|
||||
sqlc.narg('context_limit')::bigint,
|
||||
COALESCE(sqlc.narg('compressed')::boolean, FALSE)
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatMessageByID :one
|
||||
UPDATE
|
||||
chat_messages
|
||||
SET
|
||||
model_config_id = COALESCE(sqlc.narg('model_config_id')::uuid, model_config_id),
|
||||
content = sqlc.narg('content')::jsonb
|
||||
WHERE
|
||||
id = @id::bigint
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatByID :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
title = @title::text,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatWorkspace :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
workspace_id = sqlc.narg('workspace_id')::uuid,
|
||||
workspace_agent_id = sqlc.narg('workspace_agent_id')::uuid,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: AcquireChat :one
|
||||
-- Acquires a pending chat for processing. Uses SKIP LOCKED to prevent
|
||||
-- multiple replicas from acquiring the same chat.
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
status = 'running'::chat_status,
|
||||
started_at = @started_at::timestamptz,
|
||||
heartbeat_at = @started_at::timestamptz,
|
||||
updated_at = @started_at::timestamptz,
|
||||
worker_id = @worker_id::uuid
|
||||
WHERE
|
||||
id = (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
status = 'pending'::chat_status
|
||||
ORDER BY
|
||||
updated_at ASC
|
||||
FOR UPDATE
|
||||
SKIP LOCKED
|
||||
LIMIT
|
||||
1
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatStatus :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
status = @status::chat_status,
|
||||
worker_id = sqlc.narg('worker_id')::uuid,
|
||||
started_at = sqlc.narg('started_at')::timestamptz,
|
||||
heartbeat_at = sqlc.narg('heartbeat_at')::timestamptz,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: GetStaleChats :many
|
||||
-- Find chats that appear stuck (running but heartbeat has expired).
|
||||
-- Used for recovery after coderd crashes or long hangs.
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz;
|
||||
|
||||
-- name: UpdateChatHeartbeat :execrows
|
||||
-- Bumps the heartbeat timestamp for a running chat so that other
|
||||
-- replicas know the worker is still alive.
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
heartbeat_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
AND worker_id = @worker_id::uuid
|
||||
AND status = 'running'::chat_status;
|
||||
|
||||
-- name: GetChatDiffStatusByChatID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_diff_statuses
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid;
|
||||
|
||||
-- name: GetChatDiffStatusesByChatIDs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_diff_statuses
|
||||
WHERE
|
||||
chat_id = ANY(@chat_ids::uuid[]);
|
||||
|
||||
-- name: UpsertChatDiffStatusReference :one
|
||||
INSERT INTO chat_diff_statuses (
|
||||
chat_id,
|
||||
url,
|
||||
git_branch,
|
||||
git_remote_origin,
|
||||
stale_at
|
||||
) VALUES (
|
||||
@chat_id::uuid,
|
||||
sqlc.narg('url')::text,
|
||||
@git_branch::text,
|
||||
@git_remote_origin::text,
|
||||
@stale_at::timestamptz
|
||||
)
|
||||
ON CONFLICT (chat_id) DO UPDATE
|
||||
SET
|
||||
url = CASE
|
||||
WHEN EXCLUDED.url IS NOT NULL THEN EXCLUDED.url
|
||||
ELSE chat_diff_statuses.url
|
||||
END,
|
||||
git_branch = CASE
|
||||
WHEN EXCLUDED.git_branch != '' THEN EXCLUDED.git_branch
|
||||
ELSE chat_diff_statuses.git_branch
|
||||
END,
|
||||
git_remote_origin = CASE
|
||||
WHEN EXCLUDED.git_remote_origin != '' THEN EXCLUDED.git_remote_origin
|
||||
ELSE chat_diff_statuses.git_remote_origin
|
||||
END,
|
||||
stale_at = EXCLUDED.stale_at,
|
||||
updated_at = NOW()
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpsertChatDiffStatus :one
|
||||
INSERT INTO chat_diff_statuses (
|
||||
chat_id,
|
||||
url,
|
||||
pull_request_state,
|
||||
changes_requested,
|
||||
additions,
|
||||
deletions,
|
||||
changed_files,
|
||||
refreshed_at,
|
||||
stale_at
|
||||
) VALUES (
|
||||
@chat_id::uuid,
|
||||
sqlc.narg('url')::text,
|
||||
sqlc.narg('pull_request_state')::text,
|
||||
@changes_requested::boolean,
|
||||
@additions::integer,
|
||||
@deletions::integer,
|
||||
@changed_files::integer,
|
||||
@refreshed_at::timestamptz,
|
||||
@stale_at::timestamptz
|
||||
)
|
||||
ON CONFLICT (chat_id) DO UPDATE
|
||||
SET
|
||||
url = EXCLUDED.url,
|
||||
pull_request_state = EXCLUDED.pull_request_state,
|
||||
changes_requested = EXCLUDED.changes_requested,
|
||||
additions = EXCLUDED.additions,
|
||||
deletions = EXCLUDED.deletions,
|
||||
changed_files = EXCLUDED.changed_files,
|
||||
refreshed_at = EXCLUDED.refreshed_at,
|
||||
stale_at = EXCLUDED.stale_at,
|
||||
updated_at = NOW()
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: InsertChatQueuedMessage :one
|
||||
INSERT INTO chat_queued_messages (chat_id, content)
|
||||
VALUES (@chat_id, @content)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetChatQueuedMessages :many
|
||||
SELECT * FROM chat_queued_messages
|
||||
WHERE chat_id = @chat_id
|
||||
ORDER BY id ASC;
|
||||
|
||||
-- name: DeleteChatQueuedMessage :exec
|
||||
DELETE FROM chat_queued_messages WHERE id = @id AND chat_id = @chat_id;
|
||||
|
||||
-- name: DeleteAllChatQueuedMessages :exec
|
||||
DELETE FROM chat_queued_messages WHERE chat_id = @chat_id;
|
||||
|
||||
-- name: PopNextQueuedMessage :one
|
||||
DELETE FROM chat_queued_messages
|
||||
WHERE id = (
|
||||
SELECT cqm.id FROM chat_queued_messages cqm
|
||||
WHERE cqm.chat_id = @chat_id
|
||||
ORDER BY cqm.id ASC
|
||||
LIMIT 1
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetChatByIDForUpdate :one
|
||||
SELECT * FROM chats WHERE id = @id::uuid FOR UPDATE;
|
||||
@@ -399,7 +399,7 @@ WHERE
|
||||
filtered_workspaces fw
|
||||
ORDER BY
|
||||
-- To ensure that 'favorite' workspaces show up first in the list only for their owner.
|
||||
CASE WHEN owner_id = @requester_id AND favorite THEN 0 ELSE 1 END ASC,
|
||||
CASE WHEN favorite AND owner_username = (SELECT users.username FROM users WHERE users.id = @requester_id) THEN 0 ELSE 1 END ASC,
|
||||
(latest_build_completed_at IS NOT NULL AND
|
||||
latest_build_canceled_at IS NULL AND
|
||||
latest_build_error IS NULL AND
|
||||
|
||||
@@ -14,6 +14,13 @@ const (
|
||||
UniqueAPIKeysPkey UniqueConstraint = "api_keys_pkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id);
|
||||
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
|
||||
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
|
||||
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
|
||||
UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
|
||||
UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
|
||||
UniqueChatProvidersProviderKey UniqueConstraint = "chat_providers_provider_key" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
|
||||
UniqueChatQueuedMessagesPkey UniqueConstraint = "chat_queued_messages_pkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
|
||||
UniqueChatsPkey UniqueConstraint = "chats_pkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
|
||||
UniqueConnectionLogsPkey UniqueConstraint = "connection_logs_pkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id);
|
||||
UniqueCryptoKeysPkey UniqueConstraint = "crypto_keys_pkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence);
|
||||
UniqueCustomRolesUniqueKey UniqueConstraint = "custom_roles_unique_key" // ALTER TABLE ONLY custom_roles ADD CONSTRAINT custom_roles_unique_key UNIQUE (name, organization_id);
|
||||
@@ -110,6 +117,7 @@ const (
|
||||
UniqueWorkspaceResourcesPkey UniqueConstraint = "workspace_resources_pkey" // ALTER TABLE ONLY workspace_resources ADD CONSTRAINT workspace_resources_pkey PRIMARY KEY (id);
|
||||
UniqueWorkspacesPkey UniqueConstraint = "workspaces_pkey" // ALTER TABLE ONLY workspaces ADD CONSTRAINT workspaces_pkey PRIMARY KEY (id);
|
||||
UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type);
|
||||
UniqueIndexChatModelConfigsSingleDefault UniqueConstraint = "idx_chat_model_configs_single_default" // CREATE UNIQUE INDEX idx_chat_model_configs_single_default ON chat_model_configs USING btree ((1)) WHERE ((is_default = true) AND (deleted = false));
|
||||
UniqueIndexConnectionLogsConnectionIDWorkspaceIDAgentName UniqueConstraint = "idx_connection_logs_connection_id_workspace_id_agent_name" // CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name);
|
||||
UniqueIndexCustomRolesNameLowerOrganizationID UniqueConstraint = "idx_custom_roles_name_lower_organization_id" // CREATE UNIQUE INDEX idx_custom_roles_name_lower_organization_id ON custom_roles USING btree (lower(name), COALESCE(organization_id, '00000000-0000-0000-0000-000000000000'::uuid));
|
||||
UniqueIndexOrganizationNameLower UniqueConstraint = "idx_organization_name_lower" // CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name)) WHERE (deleted = false);
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
type chatParamContextKey struct{}
|
||||
|
||||
// ChatParam returns the chat from the ExtractChatParam handler.
|
||||
func ChatParam(r *http.Request) database.Chat {
|
||||
chat, ok := r.Context().Value(chatParamContextKey{}).(database.Chat)
|
||||
if !ok {
|
||||
panic("developer error: chat param middleware not provided")
|
||||
}
|
||||
return chat
|
||||
}
|
||||
|
||||
// ExtractChatParam grabs a chat from the "chat" URL parameter.
|
||||
func ExtractChatParam(db database.Store) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chatID, parsed := ParseUUIDParam(rw, r, "chat")
|
||||
if !parsed {
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := db.GetChatByID(ctx, chatID)
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, chatParamContextKey{}, chat)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestChatParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupAuthentication := func(db database.Store) (*http.Request, database.User) {
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_, token := dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: user.ID,
|
||||
})
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
r.Header.Set(codersdk.SessionTokenHeader, token)
|
||||
|
||||
ctx := chi.NewRouteContext()
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
|
||||
return r, user
|
||||
}
|
||||
|
||||
insertChat := func(t *testing.T, db database.Store, ownerID uuid.UUID) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.InsertChatProvider(context.Background(), database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-api-key",
|
||||
BaseUrl: "https://api.openai.com/v1",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelConfig, err := db.InsertChatModelConfig(context.Background(), database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test model",
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: []byte("{}"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(context.Background(), database.InsertChatParams{
|
||||
OwnerID: ownerID,
|
||||
WorkspaceID: uuid.NullUUID{},
|
||||
WorkspaceAgentID: uuid.NullUUID{},
|
||||
ParentChatID: uuid.NullUUID{},
|
||||
RootChatID: uuid.NullUUID{},
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "Test chat",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return chat
|
||||
}
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.ExtractChatParam(db))
|
||||
rtr.Get("/", nil)
|
||||
|
||||
r, _ := setupAuthentication(db)
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusBadRequest, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.ExtractChatParam(db))
|
||||
rtr.Get("/", nil)
|
||||
|
||||
r, _ := setupAuthentication(db)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("chat", uuid.NewString())
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusNotFound, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("BadUUID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.ExtractChatParam(db))
|
||||
rtr.Get("/", nil)
|
||||
|
||||
r, _ := setupAuthentication(db)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("chat", "not-a-uuid")
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusBadRequest, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
}),
|
||||
httpmw.ExtractChatParam(db),
|
||||
)
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
_ = httpmw.ChatParam(r)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
r, user := setupAuthentication(db)
|
||||
chat := insertChat(t, db, user.ID)
|
||||
|
||||
chi.RouteContext(r.Context()).URLParams.Add("chat", chat.ID.String())
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
})
|
||||
}
|
||||
@@ -15,13 +15,24 @@ type requestIDContextKey struct{}
|
||||
|
||||
// RequestID returns the ID of the request.
|
||||
func RequestID(r *http.Request) uuid.UUID {
|
||||
rid, ok := r.Context().Value(requestIDContextKey{}).(uuid.UUID)
|
||||
rid, ok := RequestIDOptional(r)
|
||||
if !ok {
|
||||
panic("developer error: request id middleware not provided")
|
||||
}
|
||||
return rid
|
||||
}
|
||||
|
||||
// RequestIDOptional returns the request ID when present.
|
||||
func RequestIDOptional(r *http.Request) (uuid.UUID, bool) {
|
||||
rid, ok := r.Context().Value(requestIDContextKey{}).(uuid.UUID)
|
||||
return rid, ok
|
||||
}
|
||||
|
||||
// WithRequestID stores a request ID in the context.
|
||||
func WithRequestID(ctx context.Context, rid uuid.UUID) context.Context {
|
||||
return context.WithValue(ctx, requestIDContextKey{}, rid)
|
||||
}
|
||||
|
||||
// AttachRequestID adds a request ID to each HTTP request.
|
||||
func AttachRequestID(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
@@ -31,3 +33,16 @@ func TestRequestID(t *testing.T) {
|
||||
require.NotEmpty(t, res.Header.Get("X-Coder-Request-ID"))
|
||||
require.NotEmpty(t, rw.Body.Bytes())
|
||||
}
|
||||
|
||||
func TestRequestIDHelpers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
requestID := uuid.New()
|
||||
ctx := httpmw.WithRequestID(context.Background(), requestID)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
||||
|
||||
gotRequestID, ok := httpmw.RequestIDOptional(req)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, requestID, gotRequestID)
|
||||
require.Equal(t, requestID, httpmw.RequestID(req))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func ChatEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("chat event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload ChatEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("unmarshal chat event"))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
|
||||
type ChatEvent struct {
|
||||
Kind ChatEventKind `json:"kind"`
|
||||
Chat codersdk.Chat `json:"chat"`
|
||||
}
|
||||
|
||||
type ChatEventKind string
|
||||
|
||||
const (
|
||||
ChatEventKindStatusChange ChatEventKind = "status_change"
|
||||
ChatEventKindTitleChange ChatEventKind = "title_change"
|
||||
ChatEventKindCreated ChatEventKind = "created"
|
||||
ChatEventKindDeleted ChatEventKind = "deleted"
|
||||
)
|
||||
@@ -0,0 +1,37 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ChatStreamNotifyChannel returns the pubsub channel for per-chat
|
||||
// stream notifications. Subscribers receive lightweight notifications
|
||||
// and read actual content from the database.
|
||||
func ChatStreamNotifyChannel(chatID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:stream:%s", chatID)
|
||||
}
|
||||
|
||||
// ChatStreamNotifyMessage is the payload published on the per-chat
|
||||
// stream notification channel. The actual message content is read
|
||||
// from the database by subscribers.
|
||||
type ChatStreamNotifyMessage struct {
|
||||
// AfterMessageID tells subscribers to query messages after this
|
||||
// ID. Set when a new message is persisted.
|
||||
AfterMessageID int64 `json:"after_message_id,omitempty"`
|
||||
|
||||
// Status is set when the chat status changes. Subscribers use
|
||||
// this to update clients and to manage relay lifecycle.
|
||||
Status string `json:"status,omitempty"`
|
||||
|
||||
// WorkerID identifies which replica is running the chat. Used
|
||||
// by enterprise relay to know where to connect.
|
||||
WorkerID string `json:"worker_id,omitempty"`
|
||||
|
||||
// Error is set when a processing error occurs.
|
||||
Error string `json:"error,omitempty"`
|
||||
|
||||
// QueueUpdate is set when the queued messages change.
|
||||
QueueUpdate bool `json:"queue_update,omitempty"`
|
||||
}
|
||||
@@ -72,6 +72,16 @@ var (
|
||||
Type: "boundary_usage",
|
||||
}
|
||||
|
||||
// ResourceChat
|
||||
// Valid Actions
|
||||
// - "ActionCreate" :: create a new chat
|
||||
// - "ActionDelete" :: delete a chat
|
||||
// - "ActionRead" :: read chat messages and metadata
|
||||
// - "ActionUpdate" :: update chat title or settings
|
||||
ResourceChat = Object{
|
||||
Type: "chat",
|
||||
}
|
||||
|
||||
// ResourceConnectionLog
|
||||
// Valid Actions
|
||||
// - "ActionRead" :: read connection logs
|
||||
@@ -429,6 +439,7 @@ func AllResources() []Objecter {
|
||||
ResourceAssignRole,
|
||||
ResourceAuditLog,
|
||||
ResourceBoundaryUsage,
|
||||
ResourceChat,
|
||||
ResourceConnectionLog,
|
||||
ResourceCryptoKey,
|
||||
ResourceDebugInfo,
|
||||
|
||||
@@ -77,6 +77,13 @@ var taskActions = map[Action]ActionDefinition{
|
||||
ActionDelete: "delete task",
|
||||
}
|
||||
|
||||
var chatActions = map[Action]ActionDefinition{
|
||||
ActionCreate: "create a new chat",
|
||||
ActionRead: "read chat messages and metadata",
|
||||
ActionUpdate: "update chat title or settings",
|
||||
ActionDelete: "delete a chat",
|
||||
}
|
||||
|
||||
// RBACPermissions is indexed by the type
|
||||
var RBACPermissions = map[string]PermissionDefinition{
|
||||
// Wildcard is every object, and the action "*" provides all actions.
|
||||
@@ -103,6 +110,9 @@ var RBACPermissions = map[string]PermissionDefinition{
|
||||
"task": {
|
||||
Actions: taskActions,
|
||||
},
|
||||
"chat": {
|
||||
Actions: chatActions,
|
||||
},
|
||||
// Dormant workspaces have the same perms as workspaces.
|
||||
"workspace_dormant": {
|
||||
Actions: workspaceActions,
|
||||
|
||||
@@ -1030,6 +1030,20 @@ func TestRolePermissions(t *testing.T) {
|
||||
false: {owner, setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "ChatUsage",
|
||||
Actions: []policy.Action{policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
Resource: rbac.ResourceChat.WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, memberMe},
|
||||
false: {
|
||||
orgAdmin, otherOrgAdmin,
|
||||
orgAuditor, otherOrgAuditor,
|
||||
templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin,
|
||||
userAdmin, orgUserAdmin, otherOrgUserAdmin,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// We expect every permission to be tested above.
|
||||
|
||||
@@ -28,6 +28,10 @@ const (
|
||||
ScopeBoundaryUsageDelete ScopeName = "boundary_usage:delete"
|
||||
ScopeBoundaryUsageRead ScopeName = "boundary_usage:read"
|
||||
ScopeBoundaryUsageUpdate ScopeName = "boundary_usage:update"
|
||||
ScopeChatCreate ScopeName = "chat:create"
|
||||
ScopeChatDelete ScopeName = "chat:delete"
|
||||
ScopeChatRead ScopeName = "chat:read"
|
||||
ScopeChatUpdate ScopeName = "chat:update"
|
||||
ScopeConnectionLogRead ScopeName = "connection_log:read"
|
||||
ScopeConnectionLogUpdate ScopeName = "connection_log:update"
|
||||
ScopeCryptoKeyCreate ScopeName = "crypto_key:create"
|
||||
@@ -188,6 +192,10 @@ func (e ScopeName) Valid() bool {
|
||||
ScopeBoundaryUsageDelete,
|
||||
ScopeBoundaryUsageRead,
|
||||
ScopeBoundaryUsageUpdate,
|
||||
ScopeChatCreate,
|
||||
ScopeChatDelete,
|
||||
ScopeChatRead,
|
||||
ScopeChatUpdate,
|
||||
ScopeConnectionLogRead,
|
||||
ScopeConnectionLogUpdate,
|
||||
ScopeCryptoKeyCreate,
|
||||
@@ -349,6 +357,10 @@ func AllScopeNameValues() []ScopeName {
|
||||
ScopeBoundaryUsageDelete,
|
||||
ScopeBoundaryUsageRead,
|
||||
ScopeBoundaryUsageUpdate,
|
||||
ScopeChatCreate,
|
||||
ScopeChatDelete,
|
||||
ScopeChatRead,
|
||||
ScopeChatUpdate,
|
||||
ScopeConnectionLogRead,
|
||||
ScopeConnectionLogUpdate,
|
||||
ScopeCryptoKeyCreate,
|
||||
|
||||
@@ -139,7 +139,6 @@ const AgentAPIVersionREST = "1.0"
|
||||
func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
|
||||
var req agentsdk.PatchLogs
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
@@ -1832,6 +1831,10 @@ func convertWorkspaceAgentMetadata(db []database.WorkspaceAgentMetadatum) []code
|
||||
func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
query := r.URL.Query()
|
||||
gitRef := chatGitRef{
|
||||
Branch: strings.TrimSpace(query.Get("git_branch")),
|
||||
RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")),
|
||||
}
|
||||
// Either match or configID must be provided!
|
||||
match := query.Get("match")
|
||||
if match == "" {
|
||||
@@ -1854,7 +1857,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
|
||||
// listen determines if the request will wait for a
|
||||
// new token to be issued!
|
||||
listen := r.URL.Query().Has("listen")
|
||||
listen := query.Has("listen")
|
||||
|
||||
var externalAuthConfig *externalauth.Config
|
||||
for _, extAuth := range api.ExternalAuthConfigs {
|
||||
@@ -1925,6 +1928,13 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
// Persist git refs as soon as the agent requests external auth so branch
|
||||
// context is retained even if the flow requires an out-of-band login.
|
||||
if gitRef.Branch != "" || gitRef.RemoteOrigin != "" {
|
||||
//nolint:gocritic // System context required to persist chat git refs.
|
||||
api.storeChatGitRef(dbauthz.AsSystemRestricted(ctx), workspace.ID, workspace.OwnerID, gitRef)
|
||||
}
|
||||
|
||||
var previousToken *database.ExternalAuthLink
|
||||
// handleRetrying will attempt to continually check for a new token
|
||||
// if listen is true. This is useful if an error is encountered in the
|
||||
@@ -1938,7 +1948,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace)
|
||||
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, gitRef)
|
||||
}
|
||||
|
||||
// This is the URL that will redirect the user with a state token.
|
||||
@@ -1996,10 +2006,11 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
})
|
||||
return
|
||||
}
|
||||
api.triggerWorkspaceChatDiffStatusRefresh(workspace, gitRef)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace) {
|
||||
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, gitRef chatGitRef) {
|
||||
// Since we're ticking frequently and this sign-in operation is rare,
|
||||
// we are OK with polling to avoid the complexity of pubsub.
|
||||
ticker, done := api.NewTicker(time.Second)
|
||||
@@ -2069,6 +2080,7 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
|
||||
})
|
||||
return
|
||||
}
|
||||
api.triggerWorkspaceChatDiffStatusRefresh(workspace, gitRef)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
+15
-8
@@ -404,7 +404,9 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
|
||||
AvatarURL: member.AvatarURL,
|
||||
}
|
||||
|
||||
w, err := createWorkspace(ctx, aReq, apiKey.UserID, api, owner, req, r, nil)
|
||||
w, err := createWorkspace(ctx, aReq, apiKey.UserID, api, owner, req, &createWorkspaceOptions{
|
||||
remoteAddr: r.RemoteAddr,
|
||||
})
|
||||
if err != nil {
|
||||
httperror.WriteResponseError(ctx, rw, err)
|
||||
return
|
||||
@@ -500,7 +502,9 @@ func (api *API) postUserWorkspaces(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
defer commitAudit()
|
||||
|
||||
w, err := createWorkspace(ctx, aReq, apiKey.UserID, api, owner, req, r, nil)
|
||||
w, err := createWorkspace(ctx, aReq, apiKey.UserID, api, owner, req, &createWorkspaceOptions{
|
||||
remoteAddr: r.RemoteAddr,
|
||||
})
|
||||
if err != nil {
|
||||
httperror.WriteResponseError(ctx, rw, err)
|
||||
return
|
||||
@@ -522,6 +526,10 @@ type createWorkspaceOptions struct {
|
||||
// postCreateInTX is a function that is called within the transaction, after
|
||||
// the workspace is created but before the workspace build is created.
|
||||
postCreateInTX func(ctx context.Context, tx database.Store, workspace database.Workspace) error
|
||||
// remoteAddr is the IP address of the request initiator, used for
|
||||
// audit logging. HTTP handlers should pass r.RemoteAddr;
|
||||
// programmatic callers may leave it empty.
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
func createWorkspace(
|
||||
@@ -531,7 +539,6 @@ func createWorkspace(
|
||||
api *API,
|
||||
owner workspaceOwner,
|
||||
req codersdk.CreateWorkspaceRequest,
|
||||
r *http.Request,
|
||||
opts *createWorkspaceOptions,
|
||||
) (codersdk.Workspace, error) {
|
||||
if opts == nil {
|
||||
@@ -545,7 +552,7 @@ func createWorkspace(
|
||||
|
||||
// This is a premature auth check to avoid doing unnecessary work if the user
|
||||
// doesn't have permission to create a workspace.
|
||||
if !api.Authorize(r, policy.ActionCreate,
|
||||
if !api.HTTPAuth.AuthorizeContext(ctx, policy.ActionCreate,
|
||||
rbac.ResourceWorkspace.InOrg(template.OrganizationID).WithOwner(owner.ID.String())) {
|
||||
// If this check fails, return a proper unauthorized error to the user to indicate
|
||||
// what is going on.
|
||||
@@ -562,14 +569,14 @@ func createWorkspace(
|
||||
|
||||
// Do this upfront to save work. If this fails, the rest of the work
|
||||
// would be wasted.
|
||||
if !api.Authorize(r, policy.ActionCreate,
|
||||
if !api.HTTPAuth.AuthorizeContext(ctx, policy.ActionCreate,
|
||||
rbac.ResourceWorkspace.InOrg(template.OrganizationID).WithOwner(owner.ID.String())) {
|
||||
return codersdk.Workspace{}, httperror.ErrResourceNotFound
|
||||
}
|
||||
// The user also needs permission to use the template. At this point they have
|
||||
// read perms, but not necessarily "use". This is also checked in `db.InsertWorkspace`.
|
||||
// Doing this up front can save some work below if the user doesn't have permission.
|
||||
if !api.Authorize(r, policy.ActionUse, template) {
|
||||
if !api.HTTPAuth.AuthorizeContext(ctx, policy.ActionUse, template) {
|
||||
return codersdk.Workspace{}, httperror.NewResponseError(http.StatusForbidden, codersdk.Response{
|
||||
Message: fmt.Sprintf("Unauthorized access to use the template %q.", template.Name),
|
||||
Detail: "Although you are able to view the template, you are unable to create a workspace using it. " +
|
||||
@@ -801,9 +808,9 @@ func createWorkspace(
|
||||
db,
|
||||
api.FileCache,
|
||||
func(action policy.Action, object rbac.Objecter) bool {
|
||||
return api.Authorize(r, action, object)
|
||||
return api.HTTPAuth.AuthorizeContext(ctx, action, object)
|
||||
},
|
||||
audit.WorkspaceBuildBaggageFromRequest(r),
|
||||
audit.WorkspaceBuildBaggage{IP: opts.remoteAddr},
|
||||
)
|
||||
return err
|
||||
}, nil)
|
||||
|
||||
@@ -638,6 +638,14 @@ type ExternalAuthRequest struct {
|
||||
ID string
|
||||
// Match is an arbitrary string matched against the regex of the provider.
|
||||
Match string
|
||||
// GitBranch is the current git branch in the working directory.
|
||||
// Sent by the agent so the control plane can resolve diffs
|
||||
// without SSHing into the workspace.
|
||||
GitBranch string
|
||||
// GitRemoteOrigin is the remote origin URL of the git repository.
|
||||
// Sent by the agent so the control plane can resolve diffs
|
||||
// without SSHing into the workspace.
|
||||
GitRemoteOrigin string
|
||||
// Listen indicates that the request should be long-lived and listen for
|
||||
// a new token to be requested.
|
||||
Listen bool
|
||||
@@ -653,6 +661,12 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext
|
||||
if req.Listen {
|
||||
q.Set("listen", "true")
|
||||
}
|
||||
if req.GitBranch != "" {
|
||||
q.Set("git_branch", req.GitBranch)
|
||||
}
|
||||
if req.GitRemoteOrigin != "" {
|
||||
q.Set("git_remote_origin", req.GitRemoteOrigin)
|
||||
}
|
||||
reqURL := "/api/v2/workspaceagents/me/external-auth?" + q.Encode()
|
||||
res, err := c.SDK.Request(ctx, http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -153,3 +153,33 @@ func TestRewriteDERPMap(t *testing.T) {
|
||||
require.Equal(t, "coconuts.org", node.HostName)
|
||||
require.Equal(t, 44558, node.DERPPort)
|
||||
}
|
||||
|
||||
func TestExternalAuthRequestQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("IncludesGitRefFieldsAndOmitsWorkdir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/api/v2/workspaceagents/me/external-auth", r.URL.Path)
|
||||
require.Equal(t, "true", r.URL.Query().Get("listen"))
|
||||
require.Equal(t, "main", r.URL.Query().Get("git_branch"))
|
||||
require.Equal(t, "https://github.com/coder/coder.git", r.URL.Query().Get("git_remote_origin"))
|
||||
require.False(t, r.URL.Query().Has("workdir"))
|
||||
_, _ = w.Write([]byte(`{"type":"github","access_token":"token"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
parsedURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := agentsdk.New(parsedURL, agentsdk.WithFixedToken("token"))
|
||||
_, err = client.ExternalAuth(testutil.Context(t, testutil.WaitShort), agentsdk.ExternalAuthRequest{
|
||||
Match: "github.com",
|
||||
Listen: true,
|
||||
GitBranch: "main",
|
||||
GitRemoteOrigin: "https://github.com/coder/coder.git",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -33,6 +33,11 @@ const (
|
||||
APIKeyScopeBoundaryUsageDelete APIKeyScope = "boundary_usage:delete"
|
||||
APIKeyScopeBoundaryUsageRead APIKeyScope = "boundary_usage:read"
|
||||
APIKeyScopeBoundaryUsageUpdate APIKeyScope = "boundary_usage:update"
|
||||
APIKeyScopeChatAll APIKeyScope = "chat:*"
|
||||
APIKeyScopeChatCreate APIKeyScope = "chat:create"
|
||||
APIKeyScopeChatDelete APIKeyScope = "chat:delete"
|
||||
APIKeyScopeChatRead APIKeyScope = "chat:read"
|
||||
APIKeyScopeChatUpdate APIKeyScope = "chat:update"
|
||||
APIKeyScopeCoderAll APIKeyScope = "coder:all"
|
||||
APIKeyScopeCoderApikeysManageSelf APIKeyScope = "coder:apikeys.manage_self"
|
||||
APIKeyScopeCoderApplicationConnect APIKeyScope = "coder:application_connect"
|
||||
|
||||
@@ -0,0 +1,903 @@
|
||||
package codersdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
)
|
||||
|
||||
// ChatStatus represents the status of a chat.
|
||||
type ChatStatus string
|
||||
|
||||
const (
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
)
|
||||
|
||||
// Chat represents a chat session with an AI agent.
|
||||
type Chat struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
OwnerID uuid.UUID `json:"owner_id" format:"uuid"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
|
||||
WorkspaceAgentID *uuid.UUID `json:"workspace_agent_id,omitempty" format:"uuid"`
|
||||
ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"`
|
||||
RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"`
|
||||
LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"`
|
||||
Title string `json:"title"`
|
||||
Status ChatStatus `json:"status"`
|
||||
DiffStatus *ChatDiffStatus `json:"diff_status,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatMessage represents a single message in a chat.
|
||||
type ChatMessage struct {
|
||||
ID int64 `json:"id"`
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
Role string `json:"role"`
|
||||
Content []ChatMessagePart `json:"content,omitempty"`
|
||||
Usage *ChatMessageUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// ChatMessageUsage contains token usage information for a chat message.
|
||||
type ChatMessageUsage struct {
|
||||
InputTokens *int64 `json:"input_tokens,omitempty"`
|
||||
OutputTokens *int64 `json:"output_tokens,omitempty"`
|
||||
TotalTokens *int64 `json:"total_tokens,omitempty"`
|
||||
ReasoningTokens *int64 `json:"reasoning_tokens,omitempty"`
|
||||
CacheCreationTokens *int64 `json:"cache_creation_tokens,omitempty"`
|
||||
CacheReadTokens *int64 `json:"cache_read_tokens,omitempty"`
|
||||
ContextLimit *int64 `json:"context_limit,omitempty"`
|
||||
}
|
||||
|
||||
// ChatMessagePartType represents a structured message part type.
|
||||
type ChatMessagePartType string
|
||||
|
||||
const (
|
||||
ChatMessagePartTypeText ChatMessagePartType = "text"
|
||||
ChatMessagePartTypeReasoning ChatMessagePartType = "reasoning"
|
||||
ChatMessagePartTypeToolCall ChatMessagePartType = "tool-call"
|
||||
ChatMessagePartTypeToolResult ChatMessagePartType = "tool-result"
|
||||
ChatMessagePartTypeSource ChatMessagePartType = "source"
|
||||
ChatMessagePartTypeFile ChatMessagePartType = "file"
|
||||
)
|
||||
|
||||
// ChatMessagePart is a structured chunk of a chat message.
|
||||
type ChatMessagePart struct {
|
||||
Type ChatMessagePartType `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
Args json.RawMessage `json:"args,omitempty"`
|
||||
ArgsDelta string `json:"args_delta,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
ResultDelta string `json:"result_delta,omitempty"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
SourceID string `json:"source_id,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Title string `json:"title,omitempty"`
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
Data []byte `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ChatInputPartType represents an input part type for user chat input.
|
||||
type ChatInputPartType string
|
||||
|
||||
const (
|
||||
ChatInputPartTypeText ChatInputPartType = "text"
|
||||
)
|
||||
|
||||
// ChatInputPart is a single user input part for creating a chat.
|
||||
type ChatInputPart struct {
|
||||
Type ChatInputPartType `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// CreateChatRequest is the request to create a new chat.
|
||||
type CreateChatRequest struct {
|
||||
Content []ChatInputPart `json:"content"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
}
|
||||
|
||||
// UpdateChatRequest is the request to update a chat.
|
||||
type UpdateChatRequest struct {
|
||||
Title string `json:"title"`
|
||||
}
|
||||
|
||||
// CreateChatMessageRequest is the request to add a message to a chat.
|
||||
type CreateChatMessageRequest struct {
|
||||
Content []ChatInputPart `json:"content"`
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
}
|
||||
|
||||
// EditChatMessageRequest is the request to edit a user message in a chat.
|
||||
type EditChatMessageRequest struct {
|
||||
Content []ChatInputPart `json:"content"`
|
||||
}
|
||||
|
||||
// CreateChatMessageResponse is the response from adding a message to a chat.
|
||||
type CreateChatMessageResponse struct {
|
||||
Message *ChatMessage `json:"message,omitempty"`
|
||||
QueuedMessage *ChatQueuedMessage `json:"queued_message,omitempty"`
|
||||
Queued bool `json:"queued"`
|
||||
}
|
||||
|
||||
// ChatWithMessages is a chat along with its messages.
|
||||
type ChatWithMessages struct {
|
||||
Chat Chat `json:"chat"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
QueuedMessages []ChatQueuedMessage `json:"queued_messages"`
|
||||
}
|
||||
|
||||
// ChatModelProviderUnavailableReason explains why a provider cannot be used.
|
||||
type ChatModelProviderUnavailableReason string
|
||||
|
||||
const (
|
||||
ChatModelProviderUnavailableMissingAPIKey ChatModelProviderUnavailableReason = "missing_api_key"
|
||||
ChatModelProviderUnavailableFetchFailed ChatModelProviderUnavailableReason = "fetch_failed"
|
||||
)
|
||||
|
||||
// ChatModel represents a model in the chat model catalog.
|
||||
type ChatModel struct {
|
||||
ID string `json:"id"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
// ChatModelProvider represents provider availability and model results.
|
||||
type ChatModelProvider struct {
|
||||
Provider string `json:"provider"`
|
||||
Available bool `json:"available"`
|
||||
UnavailableReason ChatModelProviderUnavailableReason `json:"unavailable_reason,omitempty"`
|
||||
Models []ChatModel `json:"models"`
|
||||
}
|
||||
|
||||
// ChatModelsResponse is the catalog returned from chat model discovery.
|
||||
type ChatModelsResponse struct {
|
||||
Providers []ChatModelProvider `json:"providers"`
|
||||
}
|
||||
|
||||
// ChatProviderConfigSource describes how a provider entry is sourced.
|
||||
type ChatProviderConfigSource string
|
||||
|
||||
const (
|
||||
ChatProviderConfigSourceDatabase ChatProviderConfigSource = "database"
|
||||
ChatProviderConfigSourceEnvPreset ChatProviderConfigSource = "env_preset"
|
||||
ChatProviderConfigSourceSupported ChatProviderConfigSource = "supported"
|
||||
)
|
||||
|
||||
// ChatProviderConfig is an admin-managed provider configuration.
|
||||
type ChatProviderConfig struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
Provider string `json:"provider"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
HasAPIKey bool `json:"has_api_key"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
Source ChatProviderConfigSource `json:"source"`
|
||||
CreatedAt time.Time `json:"created_at,omitempty" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty" format:"date-time"`
|
||||
}
|
||||
|
||||
// CreateChatProviderConfigRequest creates a chat provider config.
|
||||
type CreateChatProviderConfigRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateChatProviderConfigRequest updates a chat provider config.
|
||||
type UpdateChatProviderConfigRequest struct {
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
APIKey *string `json:"api_key,omitempty"`
|
||||
BaseURL *string `json:"base_url,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelConfig is an admin-managed model configuration.
|
||||
type ChatModelConfig struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
ContextLimit int64 `json:"context_limit"`
|
||||
CompressionThreshold int32 `json:"compression_threshold"`
|
||||
ModelConfig *ChatModelCallConfig `json:"model_config,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatModelProviderOptions contains typed provider-specific options.
|
||||
//
|
||||
// Note: Azure models use the `openai` options shape.
|
||||
// Note: Bedrock models use the `anthropic` options shape.
|
||||
type ChatModelProviderOptions struct {
|
||||
OpenAI *ChatModelOpenAIProviderOptions `json:"openai,omitempty"`
|
||||
Anthropic *ChatModelAnthropicProviderOptions `json:"anthropic,omitempty"`
|
||||
Google *ChatModelGoogleProviderOptions `json:"google,omitempty"`
|
||||
OpenAICompat *ChatModelOpenAICompatProviderOptions `json:"openaicompat,omitempty"`
|
||||
OpenRouter *ChatModelOpenRouterProviderOptions `json:"openrouter,omitempty"`
|
||||
Vercel *ChatModelVercelProviderOptions `json:"vercel,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelOpenAIProviderOptions configures OpenAI provider behavior.
|
||||
type ChatModelOpenAIProviderOptions struct {
|
||||
Include []string `json:"include,omitempty"`
|
||||
Instructions *string `json:"instructions,omitempty"`
|
||||
LogitBias map[string]int64 `json:"logit_bias,omitempty"`
|
||||
LogProbs *bool `json:"log_probs,omitempty"`
|
||||
TopLogProbs *int64 `json:"top_log_probs,omitempty"`
|
||||
MaxToolCalls *int64 `json:"max_tool_calls,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
User *string `json:"user,omitempty"`
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
ReasoningSummary *string `json:"reasoning_summary,omitempty"`
|
||||
MaxCompletionTokens *int64 `json:"max_completion_tokens,omitempty"`
|
||||
TextVerbosity *string `json:"text_verbosity,omitempty"`
|
||||
Prediction map[string]any `json:"prediction,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
PromptCacheKey *string `json:"prompt_cache_key,omitempty"`
|
||||
SafetyIdentifier *string `json:"safety_identifier,omitempty"`
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
StructuredOutputs *bool `json:"structured_outputs,omitempty"`
|
||||
StrictJSONSchema *bool `json:"strict_json_schema,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelAnthropicThinkingOptions configures Anthropic thinking budget.
|
||||
type ChatModelAnthropicThinkingOptions struct {
|
||||
BudgetTokens *int64 `json:"budget_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelAnthropicProviderOptions configures Anthropic provider behavior.
|
||||
type ChatModelAnthropicProviderOptions struct {
|
||||
SendReasoning *bool `json:"send_reasoning,omitempty"`
|
||||
Thinking *ChatModelAnthropicThinkingOptions `json:"thinking,omitempty"`
|
||||
Effort *string `json:"effort,omitempty"`
|
||||
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelGoogleThinkingConfig configures Google thinking behavior.
|
||||
type ChatModelGoogleThinkingConfig struct {
|
||||
ThinkingBudget *int64 `json:"thinking_budget,omitempty"`
|
||||
IncludeThoughts *bool `json:"include_thoughts,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelGoogleSafetySetting configures Google safety filtering.
|
||||
type ChatModelGoogleSafetySetting struct {
|
||||
Category string `json:"category,omitempty"`
|
||||
Threshold string `json:"threshold,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelGoogleProviderOptions configures Google provider behavior.
|
||||
type ChatModelGoogleProviderOptions struct {
|
||||
ThinkingConfig *ChatModelGoogleThinkingConfig `json:"thinking_config,omitempty"`
|
||||
CachedContent string `json:"cached_content,omitempty"`
|
||||
SafetySettings []ChatModelGoogleSafetySetting `json:"safety_settings,omitempty"`
|
||||
Threshold string `json:"threshold,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelOpenAICompatProviderOptions configures OpenAI-compatible behavior.
|
||||
type ChatModelOpenAICompatProviderOptions struct {
|
||||
User *string `json:"user,omitempty"`
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelOpenRouterReasoningOptions configures OpenRouter reasoning behavior.
|
||||
type ChatModelOpenRouterReasoningOptions struct {
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
Exclude *bool `json:"exclude,omitempty"`
|
||||
MaxTokens *int64 `json:"max_tokens,omitempty"`
|
||||
Effort *string `json:"effort,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelOpenRouterProvider configures OpenRouter routing preferences.
|
||||
type ChatModelOpenRouterProvider struct {
|
||||
Order []string `json:"order,omitempty"`
|
||||
AllowFallbacks *bool `json:"allow_fallbacks,omitempty"`
|
||||
RequireParameters *bool `json:"require_parameters,omitempty"`
|
||||
DataCollection *string `json:"data_collection,omitempty"`
|
||||
Only []string `json:"only,omitempty"`
|
||||
Ignore []string `json:"ignore,omitempty"`
|
||||
Quantizations []string `json:"quantizations,omitempty"`
|
||||
Sort *string `json:"sort,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelOpenRouterProviderOptions configures OpenRouter provider behavior.
|
||||
type ChatModelOpenRouterProviderOptions struct {
|
||||
Reasoning *ChatModelOpenRouterReasoningOptions `json:"reasoning,omitempty"`
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"`
|
||||
IncludeUsage *bool `json:"include_usage,omitempty"`
|
||||
LogitBias map[string]int64 `json:"logit_bias,omitempty"`
|
||||
LogProbs *bool `json:"log_probs,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
User *string `json:"user,omitempty"`
|
||||
Provider *ChatModelOpenRouterProvider `json:"provider,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelVercelReasoningOptions configures Vercel reasoning behavior.
|
||||
type ChatModelVercelReasoningOptions struct {
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
MaxTokens *int64 `json:"max_tokens,omitempty"`
|
||||
Effort *string `json:"effort,omitempty"`
|
||||
Exclude *bool `json:"exclude,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelVercelGatewayProviderOptions configures Vercel routing behavior.
|
||||
type ChatModelVercelGatewayProviderOptions struct {
|
||||
Order []string `json:"order,omitempty"`
|
||||
Models []string `json:"models,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelVercelProviderOptions configures Vercel provider behavior.
|
||||
type ChatModelVercelProviderOptions struct {
|
||||
Reasoning *ChatModelVercelReasoningOptions `json:"reasoning,omitempty"`
|
||||
ProviderOptions *ChatModelVercelGatewayProviderOptions `json:"providerOptions,omitempty"`
|
||||
User *string `json:"user,omitempty"`
|
||||
LogitBias map[string]int64 `json:"logit_bias,omitempty"`
|
||||
LogProbs *bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs *int64 `json:"top_logprobs,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelCallConfig configures per-call model behavior defaults.
|
||||
type ChatModelCallConfig struct {
|
||||
MaxOutputTokens *int64 `json:"max_output_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int64 `json:"top_k,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
ProviderOptions *ChatModelProviderOptions `json:"provider_options,omitempty"`
|
||||
}
|
||||
|
||||
// CreateChatModelConfigRequest creates a chat model config.
|
||||
type CreateChatModelConfigRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
IsDefault *bool `json:"is_default,omitempty"`
|
||||
ContextLimit *int64 `json:"context_limit,omitempty"`
|
||||
CompressionThreshold *int32 `json:"compression_threshold,omitempty"`
|
||||
ModelConfig *ChatModelCallConfig `json:"model_config,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateChatModelConfigRequest updates a chat model config.
|
||||
type UpdateChatModelConfigRequest struct {
|
||||
Provider string `json:"provider,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
IsDefault *bool `json:"is_default,omitempty"`
|
||||
ContextLimit *int64 `json:"context_limit,omitempty"`
|
||||
CompressionThreshold *int32 `json:"compression_threshold,omitempty"`
|
||||
ModelConfig *ChatModelCallConfig `json:"model_config,omitempty"`
|
||||
}
|
||||
|
||||
// ChatGitChange represents a git file change detected during a chat session.
|
||||
type ChatGitChange struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
FilePath string `json:"file_path"`
|
||||
ChangeType string `json:"change_type"` // added, modified, deleted, renamed
|
||||
OldPath *string `json:"old_path,omitempty"`
|
||||
DiffSummary *string `json:"diff_summary,omitempty"`
|
||||
DetectedAt time.Time `json:"detected_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatDiffStatus represents cached diff status for a chat. The URL
|
||||
// may point to a pull request or a branch page depending on whether
|
||||
// a PR has been opened.
|
||||
type ChatDiffStatus struct {
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
URL *string `json:"url,omitempty"`
|
||||
PullRequestState *string `json:"pull_request_state,omitempty"`
|
||||
ChangesRequested bool `json:"changes_requested"`
|
||||
Additions int32 `json:"additions"`
|
||||
Deletions int32 `json:"deletions"`
|
||||
ChangedFiles int32 `json:"changed_files"`
|
||||
RefreshedAt *time.Time `json:"refreshed_at,omitempty" format:"date-time"`
|
||||
StaleAt *time.Time `json:"stale_at,omitempty" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatDiffContents represents the resolved diff text for a chat.
|
||||
type ChatDiffContents struct {
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
Provider *string `json:"provider,omitempty"`
|
||||
RemoteOrigin *string `json:"remote_origin,omitempty"`
|
||||
Branch *string `json:"branch,omitempty"`
|
||||
PullRequestURL *string `json:"pull_request_url,omitempty"`
|
||||
Diff string `json:"diff,omitempty"`
|
||||
}
|
||||
|
||||
// ChatStreamEventType represents the kind of chat stream update.
|
||||
type ChatStreamEventType string
|
||||
|
||||
const (
|
||||
ChatStreamEventTypeMessagePart ChatStreamEventType = "message_part"
|
||||
ChatStreamEventTypeMessage ChatStreamEventType = "message"
|
||||
ChatStreamEventTypeStatus ChatStreamEventType = "status"
|
||||
ChatStreamEventTypeError ChatStreamEventType = "error"
|
||||
ChatStreamEventTypeQueueUpdate ChatStreamEventType = "queue_update"
|
||||
)
|
||||
|
||||
// ChatQueuedMessage represents a queued message waiting to be processed.
|
||||
type ChatQueuedMessage struct {
|
||||
ID int64 `json:"id"`
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
Content []ChatMessagePart `json:"content"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatStreamMessagePart is a streamed message part update.
|
||||
type ChatStreamMessagePart struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Part ChatMessagePart `json:"part"`
|
||||
}
|
||||
|
||||
// ChatStreamStatus represents an updated chat status.
|
||||
type ChatStreamStatus struct {
|
||||
Status ChatStatus `json:"status"`
|
||||
}
|
||||
|
||||
// ChatStreamError represents an error event in the stream.
|
||||
type ChatStreamError struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ChatStreamEvent represents a real-time update for chat streaming.
|
||||
type ChatStreamEvent struct {
|
||||
Type ChatStreamEventType `json:"type"`
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
Message *ChatMessage `json:"message,omitempty"`
|
||||
MessagePart *ChatStreamMessagePart `json:"message_part,omitempty"`
|
||||
Status *ChatStreamStatus `json:"status,omitempty"`
|
||||
Error *ChatStreamError `json:"error,omitempty"`
|
||||
QueuedMessages []ChatQueuedMessage `json:"queued_messages,omitempty"`
|
||||
}
|
||||
|
||||
type chatStreamEnvelope struct {
|
||||
Type ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ListChats returns all chats for the authenticated user.
|
||||
func (c *Client) ListChats(ctx context.Context) ([]Chat, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
var chats []Chat
|
||||
return chats, json.NewDecoder(res.Body).Decode(&chats)
|
||||
}
|
||||
|
||||
// ListChatModels returns the available chat model catalog.
|
||||
func (c *Client) ListChatModels(ctx context.Context) (ChatModelsResponse, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/models", nil)
|
||||
if err != nil {
|
||||
return ChatModelsResponse{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatModelsResponse{}, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var catalog ChatModelsResponse
|
||||
return catalog, json.NewDecoder(res.Body).Decode(&catalog)
|
||||
}
|
||||
|
||||
// ListChatProviders returns admin-managed chat provider configs.
|
||||
func (c *Client) ListChatProviders(ctx context.Context) ([]ChatProviderConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/providers", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var providers []ChatProviderConfig
|
||||
return providers, json.NewDecoder(res.Body).Decode(&providers)
|
||||
}
|
||||
|
||||
// CreateChatProvider creates an admin-managed chat provider config.
|
||||
func (c *Client) CreateChatProvider(ctx context.Context, req CreateChatProviderConfigRequest) (ChatProviderConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, "/api/experimental/chats/providers", req)
|
||||
if err != nil {
|
||||
return ChatProviderConfig{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
return ChatProviderConfig{}, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var provider ChatProviderConfig
|
||||
return provider, json.NewDecoder(res.Body).Decode(&provider)
|
||||
}
|
||||
|
||||
// UpdateChatProvider updates an admin-managed chat provider config.
|
||||
func (c *Client) UpdateChatProvider(ctx context.Context, providerID uuid.UUID, req UpdateChatProviderConfigRequest) (ChatProviderConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/experimental/chats/providers/%s", providerID), req)
|
||||
if err != nil {
|
||||
return ChatProviderConfig{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatProviderConfig{}, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var provider ChatProviderConfig
|
||||
return provider, json.NewDecoder(res.Body).Decode(&provider)
|
||||
}
|
||||
|
||||
// DeleteChatProvider deletes an admin-managed chat provider config.
|
||||
func (c *Client) DeleteChatProvider(ctx context.Context, providerID uuid.UUID) error {
|
||||
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/providers/%s", providerID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListChatModelConfigs returns admin-managed chat model configs.
|
||||
func (c *Client) ListChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/model-configs", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var configs []ChatModelConfig
|
||||
return configs, json.NewDecoder(res.Body).Decode(&configs)
|
||||
}
|
||||
|
||||
// CreateChatModelConfig creates an admin-managed chat model config.
|
||||
func (c *Client) CreateChatModelConfig(ctx context.Context, req CreateChatModelConfigRequest) (ChatModelConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, "/api/experimental/chats/model-configs", req)
|
||||
if err != nil {
|
||||
return ChatModelConfig{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
return ChatModelConfig{}, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var config ChatModelConfig
|
||||
return config, json.NewDecoder(res.Body).Decode(&config)
|
||||
}
|
||||
|
||||
// UpdateChatModelConfig updates an admin-managed chat model config.
|
||||
func (c *Client) UpdateChatModelConfig(ctx context.Context, modelConfigID uuid.UUID, req UpdateChatModelConfigRequest) (ChatModelConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/experimental/chats/model-configs/%s", modelConfigID), req)
|
||||
if err != nil {
|
||||
return ChatModelConfig{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatModelConfig{}, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var config ChatModelConfig
|
||||
return config, json.NewDecoder(res.Body).Decode(&config)
|
||||
}
|
||||
|
||||
// DeleteChatModelConfig deletes an admin-managed chat model config.
|
||||
func (c *Client) DeleteChatModelConfig(ctx context.Context, modelConfigID uuid.UUID) error {
|
||||
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/model-configs/%s", modelConfigID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateChat creates a new chat.
|
||||
func (c *Client) CreateChat(ctx context.Context, req CreateChatRequest) (Chat, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, "/api/experimental/chats", req)
|
||||
if err != nil {
|
||||
return Chat{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
return Chat{}, ReadBodyAsError(res)
|
||||
}
|
||||
var chat Chat
|
||||
return chat, json.NewDecoder(res.Body).Decode(&chat)
|
||||
}
|
||||
|
||||
// StreamChat streams chat updates in real time.
|
||||
//
|
||||
// The returned channel includes initial snapshot events first, followed by
|
||||
// live updates. Callers must close the returned io.Closer to release the
|
||||
// websocket connection when done.
|
||||
func (c *Client) StreamChat(ctx context.Context, chatID uuid.UUID) (<-chan ChatStreamEvent, io.Closer, error) {
|
||||
conn, err := c.Dial(
|
||||
ctx,
|
||||
fmt.Sprintf("/api/experimental/chats/%s/stream", chatID),
|
||||
&websocket.DialOptions{CompressionMode: websocket.CompressionDisabled},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
conn.SetReadLimit(1 << 22) // 4MiB
|
||||
|
||||
streamCtx, streamCancel := context.WithCancel(ctx)
|
||||
events := make(chan ChatStreamEvent, 128)
|
||||
|
||||
send := func(event ChatStreamEvent) bool {
|
||||
if event.ChatID == uuid.Nil {
|
||||
event.ChatID = chatID
|
||||
}
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
return false
|
||||
case events <- event:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(events)
|
||||
defer streamCancel()
|
||||
defer func() {
|
||||
_ = conn.Close(websocket.StatusNormalClosure, "")
|
||||
}()
|
||||
|
||||
for {
|
||||
var envelope chatStreamEnvelope
|
||||
if err := wsjson.Read(streamCtx, conn, &envelope); err != nil {
|
||||
if streamCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
switch websocket.CloseStatus(err) {
|
||||
case websocket.StatusNormalClosure, websocket.StatusGoingAway:
|
||||
return
|
||||
}
|
||||
_ = send(ChatStreamEvent{
|
||||
Type: ChatStreamEventTypeError,
|
||||
Error: &ChatStreamError{
|
||||
Message: fmt.Sprintf("read chat stream: %v", err),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
switch envelope.Type {
|
||||
case ServerSentEventTypePing:
|
||||
continue
|
||||
case ServerSentEventTypeData:
|
||||
var batch []ChatStreamEvent
|
||||
decodeErr := json.Unmarshal(envelope.Data, &batch)
|
||||
if decodeErr == nil {
|
||||
for _, streamedEvent := range batch {
|
||||
if !send(streamedEvent) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
{
|
||||
_ = send(ChatStreamEvent{
|
||||
Type: ChatStreamEventTypeError,
|
||||
Error: &ChatStreamError{
|
||||
Message: fmt.Sprintf(
|
||||
"decode chat stream event batch: %v",
|
||||
decodeErr,
|
||||
),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
case ServerSentEventTypeError:
|
||||
message := "chat stream returned an error"
|
||||
if len(envelope.Data) > 0 {
|
||||
var response Response
|
||||
if err := json.Unmarshal(envelope.Data, &response); err == nil {
|
||||
message = formatChatStreamResponseError(response)
|
||||
} else {
|
||||
trimmed := strings.TrimSpace(string(envelope.Data))
|
||||
if trimmed != "" {
|
||||
message = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = send(ChatStreamEvent{
|
||||
Type: ChatStreamEventTypeError,
|
||||
Error: &ChatStreamError{
|
||||
Message: message,
|
||||
},
|
||||
})
|
||||
return
|
||||
default:
|
||||
_ = send(ChatStreamEvent{
|
||||
Type: ChatStreamEventTypeError,
|
||||
Error: &ChatStreamError{
|
||||
Message: fmt.Sprintf("unknown chat stream event type %q", envelope.Type),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return events, closeFunc(func() error {
|
||||
streamCancel()
|
||||
return conn.Close(websocket.StatusNormalClosure, "")
|
||||
}), nil
|
||||
}
|
||||
|
||||
// GetChat returns a chat by ID, including its messages.
|
||||
func (c *Client) GetChat(ctx context.Context, chatID uuid.UUID) (ChatWithMessages, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s", chatID), nil)
|
||||
if err != nil {
|
||||
return ChatWithMessages{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatWithMessages{}, ReadBodyAsError(res)
|
||||
}
|
||||
var chat ChatWithMessages
|
||||
return chat, json.NewDecoder(res.Body).Decode(&chat)
|
||||
}
|
||||
|
||||
// DeleteChat deletes a chat by ID.
|
||||
func (c *Client) DeleteChat(ctx context.Context, chatID uuid.UUID) error {
|
||||
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/%s", chatID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateChatMessage adds a message to a chat.
|
||||
func (c *Client) CreateChatMessage(ctx context.Context, chatID uuid.UUID, req CreateChatMessageRequest) (CreateChatMessageResponse, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/messages", chatID), req)
|
||||
if err != nil {
|
||||
return CreateChatMessageResponse{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return CreateChatMessageResponse{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp CreateChatMessageResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// EditChatMessage edits an existing user message in a chat and re-runs from there.
|
||||
func (c *Client) EditChatMessage(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
messageID int64,
|
||||
req EditChatMessageRequest,
|
||||
) (ChatMessage, error) {
|
||||
res, err := c.Request(
|
||||
ctx,
|
||||
http.MethodPatch,
|
||||
fmt.Sprintf("/api/experimental/chats/%s/messages/%d", chatID, messageID),
|
||||
req,
|
||||
)
|
||||
if err != nil {
|
||||
return ChatMessage{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatMessage{}, ReadBodyAsError(res)
|
||||
}
|
||||
var message ChatMessage
|
||||
return message, json.NewDecoder(res.Body).Decode(&message)
|
||||
}
|
||||
|
||||
// InterruptChat cancels an in-flight chat run and leaves it waiting.
|
||||
func (c *Client) InterruptChat(ctx context.Context, chatID uuid.UUID) (Chat, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/interrupt", chatID), nil)
|
||||
if err != nil {
|
||||
return Chat{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return Chat{}, ReadBodyAsError(res)
|
||||
}
|
||||
var chat Chat
|
||||
return chat, json.NewDecoder(res.Body).Decode(&chat)
|
||||
}
|
||||
|
||||
// GetChatGitChanges returns git changes for a chat.
|
||||
func (c *Client) GetChatGitChanges(ctx context.Context, chatID uuid.UUID) ([]ChatGitChange, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s/git-changes", chatID), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
var changes []ChatGitChange
|
||||
return changes, json.NewDecoder(res.Body).Decode(&changes)
|
||||
}
|
||||
|
||||
// GetChatDiffStatus returns cached GitHub pull request diff status for a chat.
|
||||
func (c *Client) GetChatDiffStatus(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s/diff-status", chatID), nil)
|
||||
if err != nil {
|
||||
return ChatDiffStatus{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatDiffStatus{}, ReadBodyAsError(res)
|
||||
}
|
||||
var status ChatDiffStatus
|
||||
return status, json.NewDecoder(res.Body).Decode(&status)
|
||||
}
|
||||
|
||||
// GetChatDiffContents returns resolved diff contents for a chat.
|
||||
func (c *Client) GetChatDiffContents(ctx context.Context, chatID uuid.UUID) (ChatDiffContents, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s/diff", chatID), nil)
|
||||
if err != nil {
|
||||
return ChatDiffContents{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatDiffContents{}, ReadBodyAsError(res)
|
||||
}
|
||||
var diff ChatDiffContents
|
||||
return diff, json.NewDecoder(res.Body).Decode(&diff)
|
||||
}
|
||||
|
||||
func formatChatStreamResponseError(response Response) string {
|
||||
message := strings.TrimSpace(response.Message)
|
||||
detail := strings.TrimSpace(response.Detail)
|
||||
switch {
|
||||
case message == "" && detail == "":
|
||||
return "chat stream returned an error"
|
||||
case message == "":
|
||||
return detail
|
||||
case detail == "":
|
||||
return message
|
||||
default:
|
||||
return fmt.Sprintf("%s: %s", message, detail)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package codersdk_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestChatModelProviderOptions_MarshalJSON_UsesPlainProviderPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sendReasoning := true
|
||||
effort := "high"
|
||||
|
||||
raw, err := json.Marshal(codersdk.ChatModelProviderOptions{
|
||||
Anthropic: &codersdk.ChatModelAnthropicProviderOptions{
|
||||
SendReasoning: &sendReasoning,
|
||||
Effort: &effort,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(raw), `"type":"anthropic.options"`)
|
||||
require.NotContains(t, string(raw), `"data":`)
|
||||
require.Contains(t, string(raw), `"send_reasoning":true`)
|
||||
require.Contains(t, string(raw), `"effort":"high"`)
|
||||
}
|
||||
|
||||
func TestChatModelProviderOptions_UnmarshalJSON_ParsesPlainProviderPayloads(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
raw := []byte(`{
|
||||
"anthropic": {
|
||||
"send_reasoning": true,
|
||||
"effort": "high"
|
||||
}
|
||||
}`)
|
||||
|
||||
var decoded codersdk.ChatModelProviderOptions
|
||||
err := json.Unmarshal(raw, &decoded)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, decoded.Anthropic)
|
||||
require.NotNil(t, decoded.Anthropic.SendReasoning)
|
||||
require.True(t, *decoded.Anthropic.SendReasoning)
|
||||
require.NotNil(t, decoded.Anthropic.Effort)
|
||||
require.Equal(
|
||||
t,
|
||||
"high",
|
||||
*decoded.Anthropic.Effort,
|
||||
)
|
||||
}
|
||||
+77
-63
@@ -579,68 +579,69 @@ type DeploymentValues struct {
|
||||
DocsURL serpent.URL `json:"docs_url,omitempty"`
|
||||
RedirectToAccessURL serpent.Bool `json:"redirect_to_access_url,omitempty"`
|
||||
// HTTPAddress is a string because it may be set to zero to disable.
|
||||
HTTPAddress serpent.String `json:"http_address,omitempty" typescript:",notnull"`
|
||||
AutobuildPollInterval serpent.Duration `json:"autobuild_poll_interval,omitempty"`
|
||||
JobReaperDetectorInterval serpent.Duration `json:"job_hang_detector_interval,omitempty"`
|
||||
DERP DERP `json:"derp,omitempty" typescript:",notnull"`
|
||||
Prometheus PrometheusConfig `json:"prometheus,omitempty" typescript:",notnull"`
|
||||
Pprof PprofConfig `json:"pprof,omitempty" typescript:",notnull"`
|
||||
ProxyTrustedHeaders serpent.StringArray `json:"proxy_trusted_headers,omitempty" typescript:",notnull"`
|
||||
ProxyTrustedOrigins serpent.StringArray `json:"proxy_trusted_origins,omitempty" typescript:",notnull"`
|
||||
CacheDir serpent.String `json:"cache_directory,omitempty" typescript:",notnull"`
|
||||
EphemeralDeployment serpent.Bool `json:"ephemeral_deployment,omitempty" typescript:",notnull"`
|
||||
PostgresURL serpent.String `json:"pg_connection_url,omitempty" typescript:",notnull"`
|
||||
PostgresAuth string `json:"pg_auth,omitempty" typescript:",notnull"`
|
||||
PostgresConnMaxOpen serpent.Int64 `json:"pg_conn_max_open,omitempty" typescript:",notnull"`
|
||||
PostgresConnMaxIdle serpent.String `json:"pg_conn_max_idle,omitempty" typescript:",notnull"`
|
||||
OAuth2 OAuth2Config `json:"oauth2,omitempty" typescript:",notnull"`
|
||||
OIDC OIDCConfig `json:"oidc,omitempty" typescript:",notnull"`
|
||||
Telemetry TelemetryConfig `json:"telemetry,omitempty" typescript:",notnull"`
|
||||
TLS TLSConfig `json:"tls,omitempty" typescript:",notnull"`
|
||||
Trace TraceConfig `json:"trace,omitempty" typescript:",notnull"`
|
||||
HTTPCookies HTTPCookieConfig `json:"http_cookies,omitempty" typescript:",notnull"`
|
||||
StrictTransportSecurity serpent.Int64 `json:"strict_transport_security,omitempty" typescript:",notnull"`
|
||||
StrictTransportSecurityOptions serpent.StringArray `json:"strict_transport_security_options,omitempty" typescript:",notnull"`
|
||||
SSHKeygenAlgorithm serpent.String `json:"ssh_keygen_algorithm,omitempty" typescript:",notnull"`
|
||||
MetricsCacheRefreshInterval serpent.Duration `json:"metrics_cache_refresh_interval,omitempty" typescript:",notnull"`
|
||||
AgentStatRefreshInterval serpent.Duration `json:"agent_stat_refresh_interval,omitempty" typescript:",notnull"`
|
||||
AgentFallbackTroubleshootingURL serpent.URL `json:"agent_fallback_troubleshooting_url,omitempty" typescript:",notnull"`
|
||||
BrowserOnly serpent.Bool `json:"browser_only,omitempty" typescript:",notnull"`
|
||||
SCIMAPIKey serpent.String `json:"scim_api_key,omitempty" typescript:",notnull"`
|
||||
ExternalTokenEncryptionKeys serpent.StringArray `json:"external_token_encryption_keys,omitempty" typescript:",notnull"`
|
||||
Provisioner ProvisionerConfig `json:"provisioner,omitempty" typescript:",notnull"`
|
||||
RateLimit RateLimitConfig `json:"rate_limit,omitempty" typescript:",notnull"`
|
||||
Experiments serpent.StringArray `json:"experiments,omitempty" typescript:",notnull"`
|
||||
UpdateCheck serpent.Bool `json:"update_check,omitempty" typescript:",notnull"`
|
||||
Swagger SwaggerConfig `json:"swagger,omitempty" typescript:",notnull"`
|
||||
Logging LoggingConfig `json:"logging,omitempty" typescript:",notnull"`
|
||||
Dangerous DangerousConfig `json:"dangerous,omitempty" typescript:",notnull"`
|
||||
DisablePathApps serpent.Bool `json:"disable_path_apps,omitempty" typescript:",notnull"`
|
||||
Sessions SessionLifetime `json:"session_lifetime,omitempty" typescript:",notnull"`
|
||||
DisablePasswordAuth serpent.Bool `json:"disable_password_auth,omitempty" typescript:",notnull"`
|
||||
Support SupportConfig `json:"support,omitempty" typescript:",notnull"`
|
||||
EnableAuthzRecording serpent.Bool `json:"enable_authz_recording,omitempty" typescript:",notnull"`
|
||||
ExternalAuthConfigs serpent.Struct[[]ExternalAuthConfig] `json:"external_auth,omitempty" typescript:",notnull"`
|
||||
SSHConfig SSHConfig `json:"config_ssh,omitempty" typescript:",notnull"`
|
||||
WgtunnelHost serpent.String `json:"wgtunnel_host,omitempty" typescript:",notnull"`
|
||||
DisableOwnerWorkspaceExec serpent.Bool `json:"disable_owner_workspace_exec,omitempty" typescript:",notnull"`
|
||||
DisableWorkspaceSharing serpent.Bool `json:"disable_workspace_sharing,omitempty" typescript:",notnull"`
|
||||
ProxyHealthStatusInterval serpent.Duration `json:"proxy_health_status_interval,omitempty" typescript:",notnull"`
|
||||
EnableTerraformDebugMode serpent.Bool `json:"enable_terraform_debug_mode,omitempty" typescript:",notnull"`
|
||||
UserQuietHoursSchedule UserQuietHoursScheduleConfig `json:"user_quiet_hours_schedule,omitempty" typescript:",notnull"`
|
||||
WebTerminalRenderer serpent.String `json:"web_terminal_renderer,omitempty" typescript:",notnull"`
|
||||
AllowWorkspaceRenames serpent.Bool `json:"allow_workspace_renames,omitempty" typescript:",notnull"`
|
||||
Healthcheck HealthcheckConfig `json:"healthcheck,omitempty" typescript:",notnull"`
|
||||
Retention RetentionConfig `json:"retention,omitempty" typescript:",notnull"`
|
||||
CLIUpgradeMessage serpent.String `json:"cli_upgrade_message,omitempty" typescript:",notnull"`
|
||||
TermsOfServiceURL serpent.String `json:"terms_of_service_url,omitempty" typescript:",notnull"`
|
||||
Notifications NotificationsConfig `json:"notifications,omitempty" typescript:",notnull"`
|
||||
AdditionalCSPPolicy serpent.StringArray `json:"additional_csp_policy,omitempty" typescript:",notnull"`
|
||||
WorkspaceHostnameSuffix serpent.String `json:"workspace_hostname_suffix,omitempty" typescript:",notnull"`
|
||||
Prebuilds PrebuildsConfig `json:"workspace_prebuilds,omitempty" typescript:",notnull"`
|
||||
HideAITasks serpent.Bool `json:"hide_ai_tasks,omitempty" typescript:",notnull"`
|
||||
AI AIConfig `json:"ai,omitempty"`
|
||||
StatsCollection StatsCollectionConfig `json:"stats_collection,omitempty" typescript:",notnull"`
|
||||
HTTPAddress serpent.String `json:"http_address,omitempty" typescript:",notnull"`
|
||||
AutobuildPollInterval serpent.Duration `json:"autobuild_poll_interval,omitempty"`
|
||||
JobReaperDetectorInterval serpent.Duration `json:"job_hang_detector_interval,omitempty"`
|
||||
DERP DERP `json:"derp,omitempty" typescript:",notnull"`
|
||||
Prometheus PrometheusConfig `json:"prometheus,omitempty" typescript:",notnull"`
|
||||
Pprof PprofConfig `json:"pprof,omitempty" typescript:",notnull"`
|
||||
ProxyTrustedHeaders serpent.StringArray `json:"proxy_trusted_headers,omitempty" typescript:",notnull"`
|
||||
ProxyTrustedOrigins serpent.StringArray `json:"proxy_trusted_origins,omitempty" typescript:",notnull"`
|
||||
CacheDir serpent.String `json:"cache_directory,omitempty" typescript:",notnull"`
|
||||
EphemeralDeployment serpent.Bool `json:"ephemeral_deployment,omitempty" typescript:",notnull"`
|
||||
PostgresURL serpent.String `json:"pg_connection_url,omitempty" typescript:",notnull"`
|
||||
PostgresAuth string `json:"pg_auth,omitempty" typescript:",notnull"`
|
||||
PostgresConnMaxOpen serpent.Int64 `json:"pg_conn_max_open,omitempty" typescript:",notnull"`
|
||||
PostgresConnMaxIdle serpent.String `json:"pg_conn_max_idle,omitempty" typescript:",notnull"`
|
||||
OAuth2 OAuth2Config `json:"oauth2,omitempty" typescript:",notnull"`
|
||||
OIDC OIDCConfig `json:"oidc,omitempty" typescript:",notnull"`
|
||||
Telemetry TelemetryConfig `json:"telemetry,omitempty" typescript:",notnull"`
|
||||
TLS TLSConfig `json:"tls,omitempty" typescript:",notnull"`
|
||||
Trace TraceConfig `json:"trace,omitempty" typescript:",notnull"`
|
||||
HTTPCookies HTTPCookieConfig `json:"http_cookies,omitempty" typescript:",notnull"`
|
||||
StrictTransportSecurity serpent.Int64 `json:"strict_transport_security,omitempty" typescript:",notnull"`
|
||||
StrictTransportSecurityOptions serpent.StringArray `json:"strict_transport_security_options,omitempty" typescript:",notnull"`
|
||||
SSHKeygenAlgorithm serpent.String `json:"ssh_keygen_algorithm,omitempty" typescript:",notnull"`
|
||||
MetricsCacheRefreshInterval serpent.Duration `json:"metrics_cache_refresh_interval,omitempty" typescript:",notnull"`
|
||||
AgentStatRefreshInterval serpent.Duration `json:"agent_stat_refresh_interval,omitempty" typescript:",notnull"`
|
||||
AgentFallbackTroubleshootingURL serpent.URL `json:"agent_fallback_troubleshooting_url,omitempty" typescript:",notnull"`
|
||||
BrowserOnly serpent.Bool `json:"browser_only,omitempty" typescript:",notnull"`
|
||||
SCIMAPIKey serpent.String `json:"scim_api_key,omitempty" typescript:",notnull"`
|
||||
ExternalTokenEncryptionKeys serpent.StringArray `json:"external_token_encryption_keys,omitempty" typescript:",notnull"`
|
||||
Provisioner ProvisionerConfig `json:"provisioner,omitempty" typescript:",notnull"`
|
||||
RateLimit RateLimitConfig `json:"rate_limit,omitempty" typescript:",notnull"`
|
||||
Experiments serpent.StringArray `json:"experiments,omitempty" typescript:",notnull"`
|
||||
UpdateCheck serpent.Bool `json:"update_check,omitempty" typescript:",notnull"`
|
||||
Swagger SwaggerConfig `json:"swagger,omitempty" typescript:",notnull"`
|
||||
Logging LoggingConfig `json:"logging,omitempty" typescript:",notnull"`
|
||||
Dangerous DangerousConfig `json:"dangerous,omitempty" typescript:",notnull"`
|
||||
DisablePathApps serpent.Bool `json:"disable_path_apps,omitempty" typescript:",notnull"`
|
||||
Sessions SessionLifetime `json:"session_lifetime,omitempty" typescript:",notnull"`
|
||||
DisablePasswordAuth serpent.Bool `json:"disable_password_auth,omitempty" typescript:",notnull"`
|
||||
Support SupportConfig `json:"support,omitempty" typescript:",notnull"`
|
||||
EnableAuthzRecording serpent.Bool `json:"enable_authz_recording,omitempty" typescript:",notnull"`
|
||||
ExternalAuthConfigs serpent.Struct[[]ExternalAuthConfig] `json:"external_auth,omitempty" typescript:",notnull"`
|
||||
ExternalAuthGithubDefaultProviderEnable serpent.Bool `json:"external_auth_github_default_provider_enable,omitempty" typescript:",notnull"`
|
||||
SSHConfig SSHConfig `json:"config_ssh,omitempty" typescript:",notnull"`
|
||||
WgtunnelHost serpent.String `json:"wgtunnel_host,omitempty" typescript:",notnull"`
|
||||
DisableOwnerWorkspaceExec serpent.Bool `json:"disable_owner_workspace_exec,omitempty" typescript:",notnull"`
|
||||
DisableWorkspaceSharing serpent.Bool `json:"disable_workspace_sharing,omitempty" typescript:",notnull"`
|
||||
ProxyHealthStatusInterval serpent.Duration `json:"proxy_health_status_interval,omitempty" typescript:",notnull"`
|
||||
EnableTerraformDebugMode serpent.Bool `json:"enable_terraform_debug_mode,omitempty" typescript:",notnull"`
|
||||
UserQuietHoursSchedule UserQuietHoursScheduleConfig `json:"user_quiet_hours_schedule,omitempty" typescript:",notnull"`
|
||||
WebTerminalRenderer serpent.String `json:"web_terminal_renderer,omitempty" typescript:",notnull"`
|
||||
AllowWorkspaceRenames serpent.Bool `json:"allow_workspace_renames,omitempty" typescript:",notnull"`
|
||||
Healthcheck HealthcheckConfig `json:"healthcheck,omitempty" typescript:",notnull"`
|
||||
Retention RetentionConfig `json:"retention,omitempty" typescript:",notnull"`
|
||||
CLIUpgradeMessage serpent.String `json:"cli_upgrade_message,omitempty" typescript:",notnull"`
|
||||
TermsOfServiceURL serpent.String `json:"terms_of_service_url,omitempty" typescript:",notnull"`
|
||||
Notifications NotificationsConfig `json:"notifications,omitempty" typescript:",notnull"`
|
||||
AdditionalCSPPolicy serpent.StringArray `json:"additional_csp_policy,omitempty" typescript:",notnull"`
|
||||
WorkspaceHostnameSuffix serpent.String `json:"workspace_hostname_suffix,omitempty" typescript:",notnull"`
|
||||
Prebuilds PrebuildsConfig `json:"workspace_prebuilds,omitempty" typescript:",notnull"`
|
||||
HideAITasks serpent.Bool `json:"hide_ai_tasks,omitempty" typescript:",notnull"`
|
||||
AI AIConfig `json:"ai,omitempty"`
|
||||
StatsCollection StatsCollectionConfig `json:"stats_collection,omitempty" typescript:",notnull"`
|
||||
|
||||
Config serpent.YAMLConfigPath `json:"config,omitempty" typescript:",notnull"`
|
||||
WriteConfig serpent.Bool `json:"write_config,omitempty" typescript:",notnull"`
|
||||
@@ -3153,6 +3154,15 @@ Write out the current server config as YAML to stdout.`,
|
||||
Value: &c.ExternalAuthConfigs,
|
||||
Hidden: true,
|
||||
},
|
||||
{
|
||||
Name: "External Auth GitHub Default Provider Enable",
|
||||
Description: "Enable the default GitHub external auth provider managed by Coder.",
|
||||
Flag: "external-auth-github-default-provider-enable",
|
||||
Env: "CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE",
|
||||
YAML: "externalAuthGithubDefaultProviderEnable",
|
||||
Value: &c.ExternalAuthGithubDefaultProviderEnable,
|
||||
Default: "true",
|
||||
},
|
||||
{
|
||||
Name: "Custom wgtunnel Host",
|
||||
Description: `Hostname of HTTPS server that runs https://github.com/coder/wgtunnel. By default, this will pick the best available wgtunnel server hosted by Coder. e.g. "tunnel.example.com".`,
|
||||
@@ -3583,7 +3593,6 @@ Write out the current server config as YAML to stdout.`,
|
||||
Group: &deploymentGroupClient,
|
||||
YAML: "hideAITasks",
|
||||
},
|
||||
|
||||
// AI Bridge Options
|
||||
{
|
||||
Name: "AI Bridge Enabled",
|
||||
@@ -4264,6 +4273,7 @@ const (
|
||||
ExperimentWorkspaceUsage Experiment = "workspace-usage" // Enables the new workspace usage tracking.
|
||||
ExperimentWebPush Experiment = "web-push" // Enables web push notifications through the browser.
|
||||
ExperimentOAuth2 Experiment = "oauth2" // Enables OAuth2 provider functionality.
|
||||
ExperimentAgents Experiment = "agents" // Enables agent-powered chat functionality.
|
||||
ExperimentMCPServerHTTP Experiment = "mcp-server-http" // Enables the MCP HTTP server functionality.
|
||||
)
|
||||
|
||||
@@ -4281,6 +4291,8 @@ func (e Experiment) DisplayName() string {
|
||||
return "Browser Push Notifications"
|
||||
case ExperimentOAuth2:
|
||||
return "OAuth2 Provider Functionality"
|
||||
case ExperimentAgents:
|
||||
return "Agents"
|
||||
case ExperimentMCPServerHTTP:
|
||||
return "MCP HTTP Server Functionality"
|
||||
default:
|
||||
@@ -4299,6 +4311,7 @@ var ExperimentsKnown = Experiments{
|
||||
ExperimentWorkspaceUsage,
|
||||
ExperimentWebPush,
|
||||
ExperimentOAuth2,
|
||||
ExperimentAgents,
|
||||
ExperimentMCPServerHTTP,
|
||||
}
|
||||
|
||||
@@ -4306,6 +4319,7 @@ var ExperimentsKnown = Experiments{
|
||||
// users to opt-in to via --experimental='*'.
|
||||
// Experiments that are not ready for consumption by all users should
|
||||
// not be included here and will be essentially hidden.
|
||||
// TODO: Add ExperimentAgents to ExperimentsSafe once it is safe for general use.
|
||||
var ExperimentsSafe = Experiments{}
|
||||
|
||||
// Experiments is a list of experiments.
|
||||
|
||||
@@ -11,6 +11,7 @@ const (
|
||||
ResourceAssignRole RBACResource = "assign_role"
|
||||
ResourceAuditLog RBACResource = "audit_log"
|
||||
ResourceBoundaryUsage RBACResource = "boundary_usage"
|
||||
ResourceChat RBACResource = "chat"
|
||||
ResourceConnectionLog RBACResource = "connection_log"
|
||||
ResourceCryptoKey RBACResource = "crypto_key"
|
||||
ResourceDebugInfo RBACResource = "debug_info"
|
||||
@@ -82,6 +83,7 @@ var RBACResourceActions = map[RBACResource][]RBACAction{
|
||||
ResourceAssignRole: {ActionAssign, ActionRead, ActionUnassign},
|
||||
ResourceAuditLog: {ActionCreate, ActionRead},
|
||||
ResourceBoundaryUsage: {ActionDelete, ActionRead, ActionUpdate},
|
||||
ResourceChat: {ActionCreate, ActionDelete, ActionRead, ActionUpdate},
|
||||
ResourceConnectionLog: {ActionRead, ActionUpdate},
|
||||
ResourceCryptoKey: {ActionCreate, ActionDelete, ActionRead, ActionUpdate},
|
||||
ResourceDebugInfo: {ActionRead},
|
||||
|
||||
Generated
+1
@@ -304,6 +304,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \
|
||||
}
|
||||
]
|
||||
},
|
||||
"external_auth_github_default_provider_enable": true,
|
||||
"external_token_encryption_keys": [
|
||||
"string"
|
||||
],
|
||||
|
||||
Generated
+20
-20
@@ -172,10 +172,10 @@ Status Code **200**
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Property | Value(s) |
|
||||
|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
|
||||
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
| Property | Value(s) |
|
||||
|-----------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
|
||||
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
@@ -305,10 +305,10 @@ Status Code **200**
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Property | Value(s) |
|
||||
|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
|
||||
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
| Property | Value(s) |
|
||||
|-----------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
|
||||
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
@@ -438,10 +438,10 @@ Status Code **200**
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Property | Value(s) |
|
||||
|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
|
||||
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
| Property | Value(s) |
|
||||
|-----------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
|
||||
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
@@ -533,10 +533,10 @@ Status Code **200**
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Property | Value(s) |
|
||||
|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
|
||||
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
| Property | Value(s) |
|
||||
|-----------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
|
||||
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
@@ -909,9 +909,9 @@ Status Code **200**
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Property | Value(s) |
|
||||
|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
|
||||
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
| Property | Value(s) |
|
||||
|-----------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `action` | `application_connect`, `assign`, `create`, `create_agent`, `delete`, `delete_agent`, `read`, `read_personal`, `share`, `ssh`, `start`, `stop`, `unassign`, `update`, `update_agent`, `update_personal`, `use`, `view_insights` |
|
||||
| `resource_type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
Generated
+84
-81
@@ -865,9 +865,9 @@
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value(s) |
|
||||
|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `aibridge_interception:*`, `aibridge_interception:create`, `aibridge_interception:read`, `aibridge_interception:update`, `all`, `api_key:*`, `api_key:create`, `api_key:delete`, `api_key:read`, `api_key:update`, `application_connect`, `assign_org_role:*`, `assign_org_role:assign`, `assign_org_role:create`, `assign_org_role:delete`, `assign_org_role:read`, `assign_org_role:unassign`, `assign_org_role:update`, `assign_role:*`, `assign_role:assign`, `assign_role:read`, `assign_role:unassign`, `audit_log:*`, `audit_log:create`, `audit_log:read`, `boundary_usage:*`, `boundary_usage:delete`, `boundary_usage:read`, `boundary_usage:update`, `coder:all`, `coder:apikeys.manage_self`, `coder:application_connect`, `coder:templates.author`, `coder:templates.build`, `coder:workspaces.access`, `coder:workspaces.create`, `coder:workspaces.delete`, `coder:workspaces.operate`, `connection_log:*`, `connection_log:read`, `connection_log:update`, `crypto_key:*`, `crypto_key:create`, `crypto_key:delete`, `crypto_key:read`, `crypto_key:update`, `debug_info:*`, `debug_info:read`, `deployment_config:*`, `deployment_config:read`, `deployment_config:update`, `deployment_stats:*`, `deployment_stats:read`, `file:*`, `file:create`, `file:read`, `group:*`, `group:create`, `group:delete`, `group:read`, `group:update`, `group_member:*`, `group_member:read`, `idpsync_settings:*`, `idpsync_settings:read`, `idpsync_settings:update`, `inbox_notification:*`, `inbox_notification:create`, `inbox_notification:read`, `inbox_notification:update`, `license:*`, `license:create`, `license:delete`, `license:read`, `notification_message:*`, `notification_message:create`, `notification_message:delete`, `notification_message:read`, `notification_message:update`, `notification_preference:*`, `notification_preference:read`, `notification_preference:update`, `notification_template:*`, `notification_template:read`, `notification_template:update`, `oauth2_app:*`, `oauth2_app:create`, `oauth2_app:delete`, `oauth2_app:read`, `oauth2_app:update`, `oauth2_app_code_token:*`, `oauth2_app_code_token:create`, `oauth2_app_code_token:delete`, `oauth2_app_code_token:read`, `oauth2_app_secret:*`, `oauth2_app_secret:create`, `oauth2_app_secret:delete`, `oauth2_app_secret:read`, `oauth2_app_secret:update`, `organization:*`, `organization:create`, `organization:delete`, `organization:read`, `organization:update`, `organization_member:*`, `organization_member:create`, `organization_member:delete`, `organization_member:read`, `organization_member:update`, `prebuilt_workspace:*`, `prebuilt_workspace:delete`, `prebuilt_workspace:update`, `provisioner_daemon:*`, `provisioner_daemon:create`, `provisioner_daemon:delete`, `provisioner_daemon:read`, `provisioner_daemon:update`, `provisioner_jobs:*`, `provisioner_jobs:create`, `provisioner_jobs:read`, `provisioner_jobs:update`, `replicas:*`, `replicas:read`, `system:*`, `system:create`, `system:delete`, `system:read`, `system:update`, `tailnet_coordinator:*`, `tailnet_coordinator:create`, `tailnet_coordinator:delete`, `tailnet_coordinator:read`, `tailnet_coordinator:update`, `task:*`, `task:create`, `task:delete`, `task:read`, `task:update`, `template:*`, `template:create`, `template:delete`, `template:read`, `template:update`, `template:use`, `template:view_insights`, `usage_event:*`, `usage_event:create`, `usage_event:read`, `usage_event:update`, `user:*`, `user:create`, `user:delete`, `user:read`, `user:read_personal`, `user:update`, `user:update_personal`, `user_secret:*`, `user_secret:create`, `user_secret:delete`, `user_secret:read`, `user_secret:update`, `webpush_subscription:*`, `webpush_subscription:create`, `webpush_subscription:delete`, `webpush_subscription:read`, `workspace:*`, `workspace:application_connect`, `workspace:create`, `workspace:create_agent`, `workspace:delete`, `workspace:delete_agent`, `workspace:read`, `workspace:share`, `workspace:ssh`, `workspace:start`, `workspace:stop`, `workspace:update`, `workspace:update_agent`, `workspace_agent_devcontainers:*`, `workspace_agent_devcontainers:create`, `workspace_agent_resource_monitor:*`, `workspace_agent_resource_monitor:create`, `workspace_agent_resource_monitor:read`, `workspace_agent_resource_monitor:update`, `workspace_dormant:*`, `workspace_dormant:application_connect`, `workspace_dormant:create`, `workspace_dormant:create_agent`, `workspace_dormant:delete`, `workspace_dormant:delete_agent`, `workspace_dormant:read`, `workspace_dormant:share`, `workspace_dormant:ssh`, `workspace_dormant:start`, `workspace_dormant:stop`, `workspace_dormant:update`, `workspace_dormant:update_agent`, `workspace_proxy:*`, `workspace_proxy:create`, `workspace_proxy:delete`, `workspace_proxy:read`, `workspace_proxy:update` |
|
||||
| Value(s) |
|
||||
|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `aibridge_interception:*`, `aibridge_interception:create`, `aibridge_interception:read`, `aibridge_interception:update`, `all`, `api_key:*`, `api_key:create`, `api_key:delete`, `api_key:read`, `api_key:update`, `application_connect`, `assign_org_role:*`, `assign_org_role:assign`, `assign_org_role:create`, `assign_org_role:delete`, `assign_org_role:read`, `assign_org_role:unassign`, `assign_org_role:update`, `assign_role:*`, `assign_role:assign`, `assign_role:read`, `assign_role:unassign`, `audit_log:*`, `audit_log:create`, `audit_log:read`, `boundary_usage:*`, `boundary_usage:delete`, `boundary_usage:read`, `boundary_usage:update`, `chat:*`, `chat:create`, `chat:delete`, `chat:read`, `chat:update`, `coder:all`, `coder:apikeys.manage_self`, `coder:application_connect`, `coder:templates.author`, `coder:templates.build`, `coder:workspaces.access`, `coder:workspaces.create`, `coder:workspaces.delete`, `coder:workspaces.operate`, `connection_log:*`, `connection_log:read`, `connection_log:update`, `crypto_key:*`, `crypto_key:create`, `crypto_key:delete`, `crypto_key:read`, `crypto_key:update`, `debug_info:*`, `debug_info:read`, `deployment_config:*`, `deployment_config:read`, `deployment_config:update`, `deployment_stats:*`, `deployment_stats:read`, `file:*`, `file:create`, `file:read`, `group:*`, `group:create`, `group:delete`, `group:read`, `group:update`, `group_member:*`, `group_member:read`, `idpsync_settings:*`, `idpsync_settings:read`, `idpsync_settings:update`, `inbox_notification:*`, `inbox_notification:create`, `inbox_notification:read`, `inbox_notification:update`, `license:*`, `license:create`, `license:delete`, `license:read`, `notification_message:*`, `notification_message:create`, `notification_message:delete`, `notification_message:read`, `notification_message:update`, `notification_preference:*`, `notification_preference:read`, `notification_preference:update`, `notification_template:*`, `notification_template:read`, `notification_template:update`, `oauth2_app:*`, `oauth2_app:create`, `oauth2_app:delete`, `oauth2_app:read`, `oauth2_app:update`, `oauth2_app_code_token:*`, `oauth2_app_code_token:create`, `oauth2_app_code_token:delete`, `oauth2_app_code_token:read`, `oauth2_app_secret:*`, `oauth2_app_secret:create`, `oauth2_app_secret:delete`, `oauth2_app_secret:read`, `oauth2_app_secret:update`, `organization:*`, `organization:create`, `organization:delete`, `organization:read`, `organization:update`, `organization_member:*`, `organization_member:create`, `organization_member:delete`, `organization_member:read`, `organization_member:update`, `prebuilt_workspace:*`, `prebuilt_workspace:delete`, `prebuilt_workspace:update`, `provisioner_daemon:*`, `provisioner_daemon:create`, `provisioner_daemon:delete`, `provisioner_daemon:read`, `provisioner_daemon:update`, `provisioner_jobs:*`, `provisioner_jobs:create`, `provisioner_jobs:read`, `provisioner_jobs:update`, `replicas:*`, `replicas:read`, `system:*`, `system:create`, `system:delete`, `system:read`, `system:update`, `tailnet_coordinator:*`, `tailnet_coordinator:create`, `tailnet_coordinator:delete`, `tailnet_coordinator:read`, `tailnet_coordinator:update`, `task:*`, `task:create`, `task:delete`, `task:read`, `task:update`, `template:*`, `template:create`, `template:delete`, `template:read`, `template:update`, `template:use`, `template:view_insights`, `usage_event:*`, `usage_event:create`, `usage_event:read`, `usage_event:update`, `user:*`, `user:create`, `user:delete`, `user:read`, `user:read_personal`, `user:update`, `user:update_personal`, `user_secret:*`, `user_secret:create`, `user_secret:delete`, `user_secret:read`, `user_secret:update`, `webpush_subscription:*`, `webpush_subscription:create`, `webpush_subscription:delete`, `webpush_subscription:read`, `workspace:*`, `workspace:application_connect`, `workspace:create`, `workspace:create_agent`, `workspace:delete`, `workspace:delete_agent`, `workspace:read`, `workspace:share`, `workspace:ssh`, `workspace:start`, `workspace:stop`, `workspace:update`, `workspace:update_agent`, `workspace_agent_devcontainers:*`, `workspace_agent_devcontainers:create`, `workspace_agent_resource_monitor:*`, `workspace_agent_resource_monitor:create`, `workspace_agent_resource_monitor:read`, `workspace_agent_resource_monitor:update`, `workspace_dormant:*`, `workspace_dormant:application_connect`, `workspace_dormant:create`, `workspace_dormant:create_agent`, `workspace_dormant:delete`, `workspace_dormant:delete_agent`, `workspace_dormant:read`, `workspace_dormant:share`, `workspace_dormant:ssh`, `workspace_dormant:start`, `workspace_dormant:stop`, `workspace_dormant:update`, `workspace_dormant:update_agent`, `workspace_proxy:*`, `workspace_proxy:create`, `workspace_proxy:delete`, `workspace_proxy:read`, `workspace_proxy:update` |
|
||||
|
||||
## codersdk.AddLicenseRequest
|
||||
|
||||
@@ -2805,6 +2805,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
}
|
||||
]
|
||||
},
|
||||
"external_auth_github_default_provider_enable": true,
|
||||
"external_token_encryption_keys": [
|
||||
"string"
|
||||
],
|
||||
@@ -3373,6 +3374,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
}
|
||||
]
|
||||
},
|
||||
"external_auth_github_default_provider_enable": true,
|
||||
"external_token_encryption_keys": [
|
||||
"string"
|
||||
],
|
||||
@@ -3682,78 +3684,79 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|--------------------------------------|------------------------------------------------------------------------------------------------------|----------|--------------|--------------------------------------------------------------------|
|
||||
| `access_url` | [serpent.URL](#serpenturl) | false | | |
|
||||
| `additional_csp_policy` | array of string | false | | |
|
||||
| `address` | [serpent.HostPort](#serpenthostport) | false | | Deprecated: Use HTTPAddress or TLS.Address instead. |
|
||||
| `agent_fallback_troubleshooting_url` | [serpent.URL](#serpenturl) | false | | |
|
||||
| `agent_stat_refresh_interval` | integer | false | | |
|
||||
| `ai` | [codersdk.AIConfig](#codersdkaiconfig) | false | | |
|
||||
| `allow_workspace_renames` | boolean | false | | |
|
||||
| `autobuild_poll_interval` | integer | false | | |
|
||||
| `browser_only` | boolean | false | | |
|
||||
| `cache_directory` | string | false | | |
|
||||
| `cli_upgrade_message` | string | false | | |
|
||||
| `config` | string | false | | |
|
||||
| `config_ssh` | [codersdk.SSHConfig](#codersdksshconfig) | false | | |
|
||||
| `dangerous` | [codersdk.DangerousConfig](#codersdkdangerousconfig) | false | | |
|
||||
| `derp` | [codersdk.DERP](#codersdkderp) | false | | |
|
||||
| `disable_owner_workspace_exec` | boolean | false | | |
|
||||
| `disable_password_auth` | boolean | false | | |
|
||||
| `disable_path_apps` | boolean | false | | |
|
||||
| `disable_workspace_sharing` | boolean | false | | |
|
||||
| `docs_url` | [serpent.URL](#serpenturl) | false | | |
|
||||
| `enable_authz_recording` | boolean | false | | |
|
||||
| `enable_terraform_debug_mode` | boolean | false | | |
|
||||
| `ephemeral_deployment` | boolean | false | | |
|
||||
| `experiments` | array of string | false | | |
|
||||
| `external_auth` | [serpent.Struct-array_codersdk_ExternalAuthConfig](#serpentstruct-array_codersdk_externalauthconfig) | false | | |
|
||||
| `external_token_encryption_keys` | array of string | false | | |
|
||||
| `healthcheck` | [codersdk.HealthcheckConfig](#codersdkhealthcheckconfig) | false | | |
|
||||
| `hide_ai_tasks` | boolean | false | | |
|
||||
| `http_address` | string | false | | Http address is a string because it may be set to zero to disable. |
|
||||
| `http_cookies` | [codersdk.HTTPCookieConfig](#codersdkhttpcookieconfig) | false | | |
|
||||
| `job_hang_detector_interval` | integer | false | | |
|
||||
| `logging` | [codersdk.LoggingConfig](#codersdkloggingconfig) | false | | |
|
||||
| `metrics_cache_refresh_interval` | integer | false | | |
|
||||
| `notifications` | [codersdk.NotificationsConfig](#codersdknotificationsconfig) | false | | |
|
||||
| `oauth2` | [codersdk.OAuth2Config](#codersdkoauth2config) | false | | |
|
||||
| `oidc` | [codersdk.OIDCConfig](#codersdkoidcconfig) | false | | |
|
||||
| `pg_auth` | string | false | | |
|
||||
| `pg_conn_max_idle` | string | false | | |
|
||||
| `pg_conn_max_open` | integer | false | | |
|
||||
| `pg_connection_url` | string | false | | |
|
||||
| `pprof` | [codersdk.PprofConfig](#codersdkpprofconfig) | false | | |
|
||||
| `prometheus` | [codersdk.PrometheusConfig](#codersdkprometheusconfig) | false | | |
|
||||
| `provisioner` | [codersdk.ProvisionerConfig](#codersdkprovisionerconfig) | false | | |
|
||||
| `proxy_health_status_interval` | integer | false | | |
|
||||
| `proxy_trusted_headers` | array of string | false | | |
|
||||
| `proxy_trusted_origins` | array of string | false | | |
|
||||
| `rate_limit` | [codersdk.RateLimitConfig](#codersdkratelimitconfig) | false | | |
|
||||
| `redirect_to_access_url` | boolean | false | | |
|
||||
| `retention` | [codersdk.RetentionConfig](#codersdkretentionconfig) | false | | |
|
||||
| `scim_api_key` | string | false | | |
|
||||
| `session_lifetime` | [codersdk.SessionLifetime](#codersdksessionlifetime) | false | | |
|
||||
| `ssh_keygen_algorithm` | string | false | | |
|
||||
| `stats_collection` | [codersdk.StatsCollectionConfig](#codersdkstatscollectionconfig) | false | | |
|
||||
| `strict_transport_security` | integer | false | | |
|
||||
| `strict_transport_security_options` | array of string | false | | |
|
||||
| `support` | [codersdk.SupportConfig](#codersdksupportconfig) | false | | |
|
||||
| `swagger` | [codersdk.SwaggerConfig](#codersdkswaggerconfig) | false | | |
|
||||
| `telemetry` | [codersdk.TelemetryConfig](#codersdktelemetryconfig) | false | | |
|
||||
| `terms_of_service_url` | string | false | | |
|
||||
| `tls` | [codersdk.TLSConfig](#codersdktlsconfig) | false | | |
|
||||
| `trace` | [codersdk.TraceConfig](#codersdktraceconfig) | false | | |
|
||||
| `update_check` | boolean | false | | |
|
||||
| `user_quiet_hours_schedule` | [codersdk.UserQuietHoursScheduleConfig](#codersdkuserquiethoursscheduleconfig) | false | | |
|
||||
| `verbose` | boolean | false | | |
|
||||
| `web_terminal_renderer` | string | false | | |
|
||||
| `wgtunnel_host` | string | false | | |
|
||||
| `wildcard_access_url` | string | false | | |
|
||||
| `workspace_hostname_suffix` | string | false | | |
|
||||
| `workspace_prebuilds` | [codersdk.PrebuildsConfig](#codersdkprebuildsconfig) | false | | |
|
||||
| `write_config` | boolean | false | | |
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------------------------|------------------------------------------------------------------------------------------------------|----------|--------------|--------------------------------------------------------------------|
|
||||
| `access_url` | [serpent.URL](#serpenturl) | false | | |
|
||||
| `additional_csp_policy` | array of string | false | | |
|
||||
| `address` | [serpent.HostPort](#serpenthostport) | false | | Deprecated: Use HTTPAddress or TLS.Address instead. |
|
||||
| `agent_fallback_troubleshooting_url` | [serpent.URL](#serpenturl) | false | | |
|
||||
| `agent_stat_refresh_interval` | integer | false | | |
|
||||
| `ai` | [codersdk.AIConfig](#codersdkaiconfig) | false | | |
|
||||
| `allow_workspace_renames` | boolean | false | | |
|
||||
| `autobuild_poll_interval` | integer | false | | |
|
||||
| `browser_only` | boolean | false | | |
|
||||
| `cache_directory` | string | false | | |
|
||||
| `cli_upgrade_message` | string | false | | |
|
||||
| `config` | string | false | | |
|
||||
| `config_ssh` | [codersdk.SSHConfig](#codersdksshconfig) | false | | |
|
||||
| `dangerous` | [codersdk.DangerousConfig](#codersdkdangerousconfig) | false | | |
|
||||
| `derp` | [codersdk.DERP](#codersdkderp) | false | | |
|
||||
| `disable_owner_workspace_exec` | boolean | false | | |
|
||||
| `disable_password_auth` | boolean | false | | |
|
||||
| `disable_path_apps` | boolean | false | | |
|
||||
| `disable_workspace_sharing` | boolean | false | | |
|
||||
| `docs_url` | [serpent.URL](#serpenturl) | false | | |
|
||||
| `enable_authz_recording` | boolean | false | | |
|
||||
| `enable_terraform_debug_mode` | boolean | false | | |
|
||||
| `ephemeral_deployment` | boolean | false | | |
|
||||
| `experiments` | array of string | false | | |
|
||||
| `external_auth` | [serpent.Struct-array_codersdk_ExternalAuthConfig](#serpentstruct-array_codersdk_externalauthconfig) | false | | |
|
||||
| `external_auth_github_default_provider_enable` | boolean | false | | |
|
||||
| `external_token_encryption_keys` | array of string | false | | |
|
||||
| `healthcheck` | [codersdk.HealthcheckConfig](#codersdkhealthcheckconfig) | false | | |
|
||||
| `hide_ai_tasks` | boolean | false | | |
|
||||
| `http_address` | string | false | | Http address is a string because it may be set to zero to disable. |
|
||||
| `http_cookies` | [codersdk.HTTPCookieConfig](#codersdkhttpcookieconfig) | false | | |
|
||||
| `job_hang_detector_interval` | integer | false | | |
|
||||
| `logging` | [codersdk.LoggingConfig](#codersdkloggingconfig) | false | | |
|
||||
| `metrics_cache_refresh_interval` | integer | false | | |
|
||||
| `notifications` | [codersdk.NotificationsConfig](#codersdknotificationsconfig) | false | | |
|
||||
| `oauth2` | [codersdk.OAuth2Config](#codersdkoauth2config) | false | | |
|
||||
| `oidc` | [codersdk.OIDCConfig](#codersdkoidcconfig) | false | | |
|
||||
| `pg_auth` | string | false | | |
|
||||
| `pg_conn_max_idle` | string | false | | |
|
||||
| `pg_conn_max_open` | integer | false | | |
|
||||
| `pg_connection_url` | string | false | | |
|
||||
| `pprof` | [codersdk.PprofConfig](#codersdkpprofconfig) | false | | |
|
||||
| `prometheus` | [codersdk.PrometheusConfig](#codersdkprometheusconfig) | false | | |
|
||||
| `provisioner` | [codersdk.ProvisionerConfig](#codersdkprovisionerconfig) | false | | |
|
||||
| `proxy_health_status_interval` | integer | false | | |
|
||||
| `proxy_trusted_headers` | array of string | false | | |
|
||||
| `proxy_trusted_origins` | array of string | false | | |
|
||||
| `rate_limit` | [codersdk.RateLimitConfig](#codersdkratelimitconfig) | false | | |
|
||||
| `redirect_to_access_url` | boolean | false | | |
|
||||
| `retention` | [codersdk.RetentionConfig](#codersdkretentionconfig) | false | | |
|
||||
| `scim_api_key` | string | false | | |
|
||||
| `session_lifetime` | [codersdk.SessionLifetime](#codersdksessionlifetime) | false | | |
|
||||
| `ssh_keygen_algorithm` | string | false | | |
|
||||
| `stats_collection` | [codersdk.StatsCollectionConfig](#codersdkstatscollectionconfig) | false | | |
|
||||
| `strict_transport_security` | integer | false | | |
|
||||
| `strict_transport_security_options` | array of string | false | | |
|
||||
| `support` | [codersdk.SupportConfig](#codersdksupportconfig) | false | | |
|
||||
| `swagger` | [codersdk.SwaggerConfig](#codersdkswaggerconfig) | false | | |
|
||||
| `telemetry` | [codersdk.TelemetryConfig](#codersdktelemetryconfig) | false | | |
|
||||
| `terms_of_service_url` | string | false | | |
|
||||
| `tls` | [codersdk.TLSConfig](#codersdktlsconfig) | false | | |
|
||||
| `trace` | [codersdk.TraceConfig](#codersdktraceconfig) | false | | |
|
||||
| `update_check` | boolean | false | | |
|
||||
| `user_quiet_hours_schedule` | [codersdk.UserQuietHoursScheduleConfig](#codersdkuserquiethoursscheduleconfig) | false | | |
|
||||
| `verbose` | boolean | false | | |
|
||||
| `web_terminal_renderer` | string | false | | |
|
||||
| `wgtunnel_host` | string | false | | |
|
||||
| `wildcard_access_url` | string | false | | |
|
||||
| `workspace_hostname_suffix` | string | false | | |
|
||||
| `workspace_prebuilds` | [codersdk.PrebuildsConfig](#codersdkprebuildsconfig) | false | | |
|
||||
| `write_config` | boolean | false | | |
|
||||
|
||||
## codersdk.DiagnosticExtra
|
||||
|
||||
@@ -3981,9 +3984,9 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value(s) |
|
||||
|----------------------------------------------------------------------------------------------------------------|
|
||||
| `auto-fill-parameters`, `example`, `mcp-server-http`, `notifications`, `oauth2`, `web-push`, `workspace-usage` |
|
||||
| Value(s) |
|
||||
|--------------------------------------------------------------------------------------------------------------------------|
|
||||
| `agents`, `auto-fill-parameters`, `example`, `mcp-server-http`, `notifications`, `oauth2`, `web-push`, `workspace-usage` |
|
||||
|
||||
## codersdk.ExternalAPIKeyScopes
|
||||
|
||||
@@ -7322,9 +7325,9 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value(s) |
|
||||
|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
| Value(s) |
|
||||
|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
|
||||
## codersdk.RateLimitConfig
|
||||
|
||||
|
||||
Generated
+5
-5
@@ -811,11 +811,11 @@ Status Code **200**
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Property | Value(s) |
|
||||
|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
| `login_type` | `github`, `oidc`, `password`, `token` |
|
||||
| `scope` | `all`, `application_connect` |
|
||||
| Property | Value(s) |
|
||||
|--------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `type` | `*`, `aibridge_interception`, `api_key`, `assign_org_role`, `assign_role`, `audit_log`, `boundary_usage`, `chat`, `connection_log`, `crypto_key`, `debug_info`, `deployment_config`, `deployment_stats`, `file`, `group`, `group_member`, `idpsync_settings`, `inbox_notification`, `license`, `notification_message`, `notification_preference`, `notification_template`, `oauth2_app`, `oauth2_app_code_token`, `oauth2_app_secret`, `organization`, `organization_member`, `prebuilt_workspace`, `provisioner_daemon`, `provisioner_jobs`, `replicas`, `system`, `tailnet_coordinator`, `task`, `template`, `usage_event`, `user`, `user_secret`, `webpush_subscription`, `workspace`, `workspace_agent_devcontainers`, `workspace_agent_resource_monitor`, `workspace_dormant`, `workspace_proxy` |
|
||||
| `login_type` | `github`, `oidc`, `password`, `token` |
|
||||
| `scope` | `all`, `application_connect` |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
|
||||
Generated
+11
@@ -1258,6 +1258,17 @@ The upgrade message to display to users when a client/server mismatch is detecte
|
||||
|
||||
Support links to display in the top right drop down menu.
|
||||
|
||||
### --external-auth-github-default-provider-enable
|
||||
|
||||
| | |
|
||||
|-------------|------------------------------------------------------------------|
|
||||
| Type | <code>bool</code> |
|
||||
| Environment | <code>$CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE</code> |
|
||||
| YAML | <code>externalAuthGithubDefaultProviderEnable</code> |
|
||||
| Default | <code>true</code> |
|
||||
|
||||
Enable the default GitHub external auth provider managed by Coder.
|
||||
|
||||
### --proxy-health-interval
|
||||
|
||||
| | |
|
||||
|
||||
@@ -63,6 +63,9 @@ OPTIONS:
|
||||
Separate multiple experiments with commas, or enter '*' to opt-in to
|
||||
all available experiments.
|
||||
|
||||
--external-auth-github-default-provider-enable bool, $CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE (default: true)
|
||||
Enable the default GitHub external auth provider managed by Coder.
|
||||
|
||||
--postgres-auth password|awsiamrds, $CODER_PG_AUTH (default: password)
|
||||
Type of auth to use when connecting to postgres. For AWS RDS, using
|
||||
IAM authentication (awsiamrds) is recommended.
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// RelaySourceHeader marks replica-relayed stream requests.
|
||||
const RelaySourceHeader = "X-Coder-Relay-Source-Replica"
|
||||
|
||||
const (
|
||||
authorizationHeader = "Authorization"
|
||||
cookieHeader = "Cookie"
|
||||
)
|
||||
|
||||
// newRemotePartsProvider creates a RemotePartsProvider that dials a remote
|
||||
// replica's stream endpoint to fetch message_part events. It filters to only
|
||||
// forward message_part events since durable events come via pubsub.
|
||||
func newRemotePartsProvider(
|
||||
resolveReplicaAddress func(context.Context, uuid.UUID) (string, bool),
|
||||
replicaHTTPClient *http.Client,
|
||||
replicaID uuid.UUID,
|
||||
) chatd.RemotePartsProvider {
|
||||
return func(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
workerID uuid.UUID,
|
||||
requestHeader http.Header,
|
||||
) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
error,
|
||||
) {
|
||||
address, ok := resolveReplicaAddress(ctx, workerID)
|
||||
if !ok {
|
||||
return nil, nil, nil, xerrors.New("worker replica not found")
|
||||
}
|
||||
|
||||
baseURL, err := url.Parse(address)
|
||||
if err != nil {
|
||||
return nil, nil, nil, xerrors.Errorf("parse relay address %q: %w", address, err)
|
||||
}
|
||||
relayCtx, relayCancel := context.WithCancel(ctx)
|
||||
sdkClient := codersdk.New(baseURL)
|
||||
sdkClient.HTTPClient = replicaHTTPClient
|
||||
sdkClient.SessionTokenProvider = relayHeaderTokenProvider{
|
||||
header: relayHeaders(requestHeader, replicaID),
|
||||
}
|
||||
sourceEvents, sourceStream, err := sdkClient.StreamChat(relayCtx, chatID)
|
||||
if err != nil {
|
||||
relayCancel()
|
||||
return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", err)
|
||||
}
|
||||
|
||||
snapshot := make([]codersdk.ChatStreamEvent, 0, 100)
|
||||
preloaded := make([]codersdk.ChatStreamEvent, 0, 100)
|
||||
drainInitial:
|
||||
for len(snapshot) < cap(snapshot) {
|
||||
select {
|
||||
case <-relayCtx.Done():
|
||||
_ = sourceStream.Close()
|
||||
relayCancel()
|
||||
return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", relayCtx.Err())
|
||||
case event, ok := <-sourceEvents:
|
||||
if !ok {
|
||||
break drainInitial
|
||||
}
|
||||
if event.Type != codersdk.ChatStreamEventTypeMessagePart {
|
||||
continue
|
||||
}
|
||||
snapshot = append(snapshot, event)
|
||||
preloaded = append(preloaded, event)
|
||||
default:
|
||||
break drainInitial
|
||||
}
|
||||
}
|
||||
|
||||
events := make(chan codersdk.ChatStreamEvent, 128)
|
||||
|
||||
go func() {
|
||||
defer close(events)
|
||||
defer relayCancel()
|
||||
defer func() {
|
||||
_ = sourceStream.Close()
|
||||
}()
|
||||
|
||||
for _, event := range preloaded {
|
||||
select {
|
||||
case events <- event:
|
||||
case <-relayCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-relayCtx.Done():
|
||||
return
|
||||
case event, ok := <-sourceEvents:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Type != codersdk.ChatStreamEventTypeMessagePart {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case events <- event:
|
||||
case <-relayCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
cancel := func() {
|
||||
relayCancel()
|
||||
_ = sourceStream.Close()
|
||||
}
|
||||
return snapshot, events, cancel, nil
|
||||
}
|
||||
}
|
||||
|
||||
type relayHeaderTokenProvider struct {
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func (p relayHeaderTokenProvider) AsRequestOption() codersdk.RequestOption {
|
||||
return func(req *http.Request) {
|
||||
for key, values := range p.header {
|
||||
for _, value := range values {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p relayHeaderTokenProvider) SetDialOption(opts *websocket.DialOptions) {
|
||||
if opts.HTTPHeader == nil {
|
||||
opts.HTTPHeader = make(http.Header)
|
||||
}
|
||||
for key, values := range p.header {
|
||||
for _, value := range values {
|
||||
opts.HTTPHeader.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p relayHeaderTokenProvider) GetSessionToken() string {
|
||||
return p.header.Get(codersdk.SessionTokenHeader)
|
||||
}
|
||||
|
||||
func relayHeaders(source http.Header, replicaID uuid.UUID) http.Header {
|
||||
header := make(http.Header)
|
||||
if source != nil {
|
||||
for _, key := range []string{codersdk.SessionTokenHeader, authorizationHeader, cookieHeader} {
|
||||
for _, value := range source.Values(key) {
|
||||
header.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
header.Set(RelaySourceHeader, replicaID.String())
|
||||
return header
|
||||
}
|
||||
@@ -0,0 +1,355 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/license"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestChatStreamRelay(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("RelayMessagePartsAcrossReplicas", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
db, pubsub := dbtestutil.NewDB(t)
|
||||
firstClient, _ := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureHighAvailability: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
secondClient, _ := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
},
|
||||
DontAddLicense: true,
|
||||
DontAddFirstUser: true,
|
||||
})
|
||||
secondClient.SetSessionToken(firstClient.SessionToken())
|
||||
|
||||
// Verify we have two replicas
|
||||
replicas, err := secondClient.Replicas(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, replicas, 2)
|
||||
firstReplicaID := replicaIDForClientURL(t, firstClient.URL, replicas)
|
||||
secondReplicaID := replicaIDForClientURL(t, secondClient.URL, replicas)
|
||||
|
||||
streamingChunks := make(chan chattest.OpenAIChunk, 8)
|
||||
chatStreamStarted := make(chan struct{}, 1)
|
||||
openai := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if req.Stream {
|
||||
select {
|
||||
case chatStreamStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return chattest.OpenAIResponse{StreamingChunks: streamingChunks}
|
||||
}
|
||||
return chattest.OpenAINonStreamingResponse("ok")
|
||||
})
|
||||
|
||||
//nolint:gocritic // Test uses owner client to configure chat providers.
|
||||
provider, err := firstClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test",
|
||||
BaseURL: openai,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ChatProviderConfigSourceDatabase, provider.Source)
|
||||
|
||||
model, err := firstClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
Model: "gpt-4",
|
||||
DisplayName: "GPT-4",
|
||||
ContextLimit: &[]int64{1000}[0],
|
||||
CompressionThreshold: &[]int32{70}[0],
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a chat on the first replica
|
||||
chat, err := firstClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "Test chat for relay",
|
||||
}},
|
||||
ModelConfigID: &model.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ChatStatusPending, chat.Status)
|
||||
|
||||
var runningChat database.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
current, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
if current.Status != database.ChatStatusRunning || !current.WorkerID.Valid {
|
||||
return false
|
||||
}
|
||||
runningChat = current
|
||||
return true
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
var localClient *codersdk.Client
|
||||
var relayClient *codersdk.Client
|
||||
switch runningChat.WorkerID.UUID {
|
||||
case firstReplicaID:
|
||||
localClient = firstClient
|
||||
relayClient = secondClient
|
||||
case secondReplicaID:
|
||||
localClient = secondClient
|
||||
relayClient = firstClient
|
||||
default:
|
||||
require.FailNowf(
|
||||
t,
|
||||
"worker replica was not recognized",
|
||||
"worker %s was not one of %s or %s",
|
||||
runningChat.WorkerID.UUID,
|
||||
firstReplicaID,
|
||||
secondReplicaID,
|
||||
)
|
||||
}
|
||||
|
||||
firstEvents, firstStream, err := localClient.StreamChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
defer firstStream.Close()
|
||||
|
||||
select {
|
||||
case <-chatStreamStarted:
|
||||
case <-ctx.Done():
|
||||
require.FailNowf(
|
||||
t,
|
||||
"timed out waiting for OpenAI stream request",
|
||||
"chat stream request did not start before context deadline: %v",
|
||||
ctx.Err(),
|
||||
)
|
||||
}
|
||||
|
||||
firstChunkText := "relay-part-one"
|
||||
streamingChunks <- chattest.OpenAITextChunks(firstChunkText)[0]
|
||||
firstEvent := waitForStreamTextPart(ctx, t, firstEvents, firstChunkText)
|
||||
require.Equal(t, "assistant", firstEvent.MessagePart.Role)
|
||||
|
||||
secondEvents, secondStream, err := relayClient.StreamChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
defer secondStream.Close()
|
||||
|
||||
secondSnapshotEvent := waitForStreamTextPart(ctx, t, secondEvents, firstChunkText)
|
||||
require.Equal(t, "assistant", secondSnapshotEvent.MessagePart.Role)
|
||||
|
||||
secondChunkText := "relay-part-two"
|
||||
streamingChunks <- chattest.OpenAITextChunks(secondChunkText)[0]
|
||||
waitForStreamTextPart(ctx, t, firstEvents, secondChunkText)
|
||||
waitForStreamTextPart(ctx, t, secondEvents, secondChunkText)
|
||||
|
||||
close(streamingChunks)
|
||||
})
|
||||
}
|
||||
|
||||
func waitForStreamTextPart(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
events <-chan codersdk.ChatStreamEvent,
|
||||
expectedText string,
|
||||
) codersdk.ChatStreamEvent {
|
||||
t.Helper()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
require.FailNowf(
|
||||
t,
|
||||
"timed out waiting for chat stream event",
|
||||
"expected text part %q before context deadline: %v",
|
||||
expectedText,
|
||||
ctx.Err(),
|
||||
)
|
||||
case event, ok := <-events:
|
||||
require.Truef(t, ok, "chat stream closed while waiting for %q", expectedText)
|
||||
|
||||
if event.Type == codersdk.ChatStreamEventTypeError {
|
||||
errMessage := "unknown chat stream error"
|
||||
if event.Error != nil && event.Error.Message != "" {
|
||||
errMessage = event.Error.Message
|
||||
}
|
||||
require.FailNowf(
|
||||
t,
|
||||
"chat stream returned error event",
|
||||
"while waiting for %q: %s",
|
||||
expectedText,
|
||||
errMessage,
|
||||
)
|
||||
}
|
||||
|
||||
if event.Type != codersdk.ChatStreamEventTypeMessagePart || event.MessagePart == nil {
|
||||
continue
|
||||
}
|
||||
if event.MessagePart.Part.Type != codersdk.ChatMessagePartTypeText {
|
||||
continue
|
||||
}
|
||||
|
||||
require.Equal(t, expectedText, event.MessagePart.Part.Text)
|
||||
return event
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func replicaIDForClientURL(
|
||||
t *testing.T,
|
||||
clientURL *url.URL,
|
||||
replicas []codersdk.Replica,
|
||||
) uuid.UUID {
|
||||
t.Helper()
|
||||
|
||||
for _, replica := range replicas {
|
||||
relayURL, err := url.Parse(replica.RelayAddress)
|
||||
require.NoErrorf(
|
||||
t,
|
||||
err,
|
||||
"parse replica relay address %q",
|
||||
replica.RelayAddress,
|
||||
)
|
||||
if relayURL.Host == clientURL.Host {
|
||||
return replica.ID
|
||||
}
|
||||
}
|
||||
|
||||
require.FailNowf(
|
||||
t,
|
||||
"missing replica for client URL",
|
||||
"client host %q not present in replica list",
|
||||
clientURL.Host,
|
||||
)
|
||||
return uuid.Nil
|
||||
}
|
||||
|
||||
func TestChatModelConfigDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
client, _ := coderdenttest.New(t, nil)
|
||||
|
||||
//nolint:gocritic // Test uses owner client to configure chat providers.
|
||||
provider, err := client.CreateChatProvider(
|
||||
ctx,
|
||||
codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test",
|
||||
BaseURL: "https://example.com",
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLimit := int64(1000)
|
||||
compressionThreshold := int32(70)
|
||||
trueValue := true
|
||||
falseValue := false
|
||||
|
||||
firstModel, err := client.CreateChatModelConfig(
|
||||
ctx,
|
||||
codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
Model: "gpt-5-a",
|
||||
DisplayName: "GPT 5 A",
|
||||
IsDefault: &trueValue,
|
||||
ContextLimit: &contextLimit,
|
||||
CompressionThreshold: &compressionThreshold,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, firstModel.IsDefault)
|
||||
|
||||
secondModel, err := client.CreateChatModelConfig(
|
||||
ctx,
|
||||
codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
Model: "gpt-5-b",
|
||||
DisplayName: "GPT 5 B",
|
||||
IsDefault: &trueValue,
|
||||
ContextLimit: &contextLimit,
|
||||
CompressionThreshold: &compressionThreshold,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, secondModel.IsDefault)
|
||||
|
||||
modelConfigs, err := client.ListChatModelConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
firstStored := findChatModelConfigByID(t, modelConfigs, firstModel.ID)
|
||||
secondStored := findChatModelConfigByID(t, modelConfigs, secondModel.ID)
|
||||
require.False(t, firstStored.IsDefault)
|
||||
require.True(t, secondStored.IsDefault)
|
||||
|
||||
updatedFirst, err := client.UpdateChatModelConfig(
|
||||
ctx,
|
||||
firstModel.ID,
|
||||
codersdk.UpdateChatModelConfigRequest{
|
||||
IsDefault: &trueValue,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, updatedFirst.IsDefault)
|
||||
|
||||
modelConfigs, err = client.ListChatModelConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
firstStored = findChatModelConfigByID(t, modelConfigs, firstModel.ID)
|
||||
secondStored = findChatModelConfigByID(t, modelConfigs, secondModel.ID)
|
||||
require.True(t, firstStored.IsDefault)
|
||||
require.False(t, secondStored.IsDefault)
|
||||
|
||||
updatedFirst, err = client.UpdateChatModelConfig(
|
||||
ctx,
|
||||
firstModel.ID,
|
||||
codersdk.UpdateChatModelConfigRequest{
|
||||
IsDefault: &falseValue,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.False(t, updatedFirst.IsDefault)
|
||||
|
||||
modelConfigs, err = client.ListChatModelConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
firstStored = findChatModelConfigByID(t, modelConfigs, firstModel.ID)
|
||||
secondStored = findChatModelConfigByID(t, modelConfigs, secondModel.ID)
|
||||
require.False(t, firstStored.IsDefault)
|
||||
require.True(t, secondStored.IsDefault)
|
||||
}
|
||||
|
||||
func findChatModelConfigByID(
|
||||
t *testing.T,
|
||||
modelConfigs []codersdk.ChatModelConfig,
|
||||
id uuid.UUID,
|
||||
) codersdk.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
for _, modelConfig := range modelConfigs {
|
||||
if modelConfig.ID == id {
|
||||
return modelConfig
|
||||
}
|
||||
}
|
||||
|
||||
require.FailNowf(t, "missing model config", "model config %s not found", id)
|
||||
return codersdk.ChatModelConfig{}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package coderd
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -100,6 +102,11 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancelFunc()
|
||||
}
|
||||
}()
|
||||
|
||||
if options.ExternalTokenEncryption == nil {
|
||||
options.ExternalTokenEncryption = make([]dbcrypt.Cipher, 0)
|
||||
@@ -141,6 +148,33 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||
)
|
||||
}
|
||||
|
||||
meshTLSConfig, err := replicasync.CreateDERPMeshTLSConfig(options.AccessURL.Hostname(), options.TLSCertificates)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create DERP mesh TLS config: %w", err)
|
||||
}
|
||||
|
||||
var replicaManagerPtr atomic.Pointer[replicasync.Manager]
|
||||
resolveReplicaAddress := func(
|
||||
_ context.Context,
|
||||
replicaID uuid.UUID,
|
||||
) (string, bool) {
|
||||
manager := replicaManagerPtr.Load()
|
||||
if manager == nil {
|
||||
return "", false
|
||||
}
|
||||
for _, replica := range manager.AllPrimary() {
|
||||
if replica.ID != replicaID {
|
||||
continue
|
||||
}
|
||||
relayAddress := strings.TrimSpace(replica.RelayAddress)
|
||||
if relayAddress == "" {
|
||||
return "", false
|
||||
}
|
||||
return relayAddress, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
api := &API{
|
||||
ctx: ctx,
|
||||
cancel: cancelFunc,
|
||||
@@ -156,6 +190,44 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||
}
|
||||
// This must happen before coderd initialization!
|
||||
options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader
|
||||
|
||||
// Wire up enterprise chat relay for cross-replica message_part streaming.
|
||||
// Must be set before coderd.New so the chat processor gets it.
|
||||
replicaHTTPClient := replicaRelayHTTPClient(options.HTTPClient, meshTLSConfig)
|
||||
if replicaHTTPClient == nil {
|
||||
replicaHTTPClient = options.Options.HTTPClient
|
||||
}
|
||||
if replicaHTTPClient == nil {
|
||||
replicaHTTPClient = http.DefaultClient
|
||||
}
|
||||
// Use a closure that captures api by reference so it can access api.AGPL.ID
|
||||
// after coderd.New is called. The provider is only invoked when Subscribe
|
||||
// is called, which happens after initialization, so api.AGPL will be set.
|
||||
options.Options.ChatRemotePartsProvider = func(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
workerID uuid.UUID,
|
||||
requestHeader http.Header,
|
||||
) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
error,
|
||||
) {
|
||||
// Get the replica ID from the API (will be set after coderd.New)
|
||||
replicaID := api.AGPL.ID
|
||||
if replicaID == uuid.Nil {
|
||||
// Fallback if somehow called before initialization
|
||||
replicaID = uuid.New()
|
||||
}
|
||||
provider := newRemotePartsProvider(
|
||||
resolveReplicaAddress,
|
||||
replicaHTTPClient,
|
||||
replicaID,
|
||||
)
|
||||
return provider(ctx, chatID, workerID, requestHeader)
|
||||
}
|
||||
|
||||
api.AGPL = coderd.New(options.Options)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
@@ -583,10 +655,6 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||
})))
|
||||
}
|
||||
|
||||
meshTLSConfig, err := replicasync.CreateDERPMeshTLSConfig(options.AccessURL.Hostname(), options.TLSCertificates)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create DERP mesh TLS config: %w", err)
|
||||
}
|
||||
// We always want to run the replica manager even if we don't have DERP
|
||||
// enabled, since it's used to detect other coder servers for licensing.
|
||||
api.replicaManager, err = replicasync.New(ctx, options.Logger, options.Database, options.Pubsub, &replicasync.Options{
|
||||
@@ -600,6 +668,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("initialize replica: %w", err)
|
||||
}
|
||||
replicaManagerPtr.Store(api.replicaManager)
|
||||
if api.DERPServer != nil {
|
||||
api.derpMesh = derpmesh.New(options.Logger.Named("derpmesh"), api.DERPServer, meshTLSConfig)
|
||||
}
|
||||
@@ -651,6 +720,28 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||
return api, nil
|
||||
}
|
||||
|
||||
func replicaRelayHTTPClient(base *http.Client, tlsConfig *tls.Config) *http.Client {
|
||||
if base == nil {
|
||||
base = http.DefaultClient
|
||||
}
|
||||
|
||||
clone := *base
|
||||
var transport *http.Transport
|
||||
switch t := base.Transport.(type) {
|
||||
case *http.Transport:
|
||||
transport = t.Clone()
|
||||
default:
|
||||
if defaultTransport, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||
transport = defaultTransport.Clone()
|
||||
} else {
|
||||
transport = &http.Transport{}
|
||||
}
|
||||
}
|
||||
transport.TLSClientConfig = tlsConfig
|
||||
clone.Transport = transport
|
||||
return &clone
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
*coderd.Options
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package dbcrypt
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
@@ -82,6 +83,32 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
|
||||
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
providers, err := cryptDB.GetChatProviders(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat providers: %w", err)
|
||||
}
|
||||
log.Info(ctx, "encrypting chat provider keys", slog.F("provider_count", len(providers)))
|
||||
for idx, provider := range providers {
|
||||
if strings.TrimSpace(provider.APIKey) == "" {
|
||||
continue
|
||||
}
|
||||
if provider.ApiKeyKeyID.Valid && provider.ApiKeyKeyID.String == ciphers[0].HexDigest() {
|
||||
log.Debug(ctx, "skipping chat provider", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
continue
|
||||
}
|
||||
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: provider.BaseUrl,
|
||||
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
Enabled: provider.Enabled,
|
||||
ID: provider.ID,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
|
||||
}
|
||||
log.Debug(ctx, "encrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
// Revoke old keys
|
||||
for _, c := range ciphers[1:] {
|
||||
if err := db.RevokeDBCryptKey(ctx, c.HexDigest()); err != nil {
|
||||
@@ -172,6 +199,28 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
|
||||
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
providers, err := cryptDB.GetChatProviders(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat providers: %w", err)
|
||||
}
|
||||
log.Info(ctx, "decrypting chat provider keys", slog.F("provider_count", len(providers)))
|
||||
for idx, provider := range providers {
|
||||
if !provider.ApiKeyKeyID.Valid {
|
||||
continue
|
||||
}
|
||||
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: provider.BaseUrl,
|
||||
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
|
||||
Enabled: provider.Enabled,
|
||||
ID: provider.ID,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
|
||||
}
|
||||
log.Debug(ctx, "decrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
// Revoke _all_ keys
|
||||
for _, c := range ciphers {
|
||||
if err := db.RevokeDBCryptKey(ctx, c.HexDigest()); err != nil {
|
||||
@@ -192,6 +241,10 @@ DELETE FROM user_links
|
||||
DELETE FROM external_auth_links
|
||||
WHERE oauth_access_token_key_id IS NOT NULL
|
||||
OR oauth_refresh_token_key_id IS NOT NULL;
|
||||
UPDATE chat_providers
|
||||
SET api_key = '',
|
||||
api_key_key_id = NULL
|
||||
WHERE api_key_key_id IS NOT NULL;
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
@@ -203,9 +256,9 @@ func Delete(ctx context.Context, log slog.Logger, sqlDB *sql.DB) error {
|
||||
store := database.New(sqlDB)
|
||||
_, err := sqlDB.ExecContext(ctx, sqlDeleteEncryptedUserTokens)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete user links: %w", err)
|
||||
return xerrors.Errorf("delete encrypted tokens and chat provider keys: %w", err)
|
||||
}
|
||||
log.Info(ctx, "deleted encrypted user tokens")
|
||||
log.Info(ctx, "deleted encrypted user tokens and chat provider API keys")
|
||||
|
||||
log.Info(ctx, "revoking all active keys")
|
||||
keys, err := store.GetDBCryptKeys(ctx)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -351,6 +352,92 @@ func (db *dbCrypt) GetCryptoKeysByFeature(ctx context.Context, feature database.
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
provider, err := db.Store.GetChatProviderByID(ctx, id)
|
||||
if err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetChatProviderByProvider(ctx context.Context, providerName string) (database.ChatProvider, error) {
|
||||
provider, err := db.Store.GetChatProviderByProvider(ctx, providerName)
|
||||
if err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
providers, err := db.Store.GetChatProviders(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range providers {
|
||||
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
providers, err := db.Store.GetEnabledChatProviders(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range providers {
|
||||
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) InsertChatProvider(ctx context.Context, params database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
|
||||
provider, err := db.Store.InsertChatProvider(ctx, params)
|
||||
if err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
|
||||
provider, err := db.Store.UpdateChatProvider(ctx, params)
|
||||
if err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error {
|
||||
// If no cipher is loaded, then we can't encrypt anything!
|
||||
if db.ciphers == nil || db.primaryCipherDigest == "" {
|
||||
|
||||
Generated
+3
-3
@@ -76,11 +76,11 @@
|
||||
},
|
||||
"nixpkgs-unstable": {
|
||||
"locked": {
|
||||
"lastModified": 1758035966,
|
||||
"narHash": "sha256-qqIJ3yxPiB0ZQTT9//nFGQYn8X/PBoJbofA7hRKZnmE=",
|
||||
"lastModified": 1771369470,
|
||||
"narHash": "sha256-0NBlEBKkN3lufyvFegY4TYv5mCNHbi5OmBDrzihbBMQ=",
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "8d4ddb19d03c65a36ad8d189d001dc32ffb0306b",
|
||||
"rev": "0182a361324364ae3f436a63005877674cf45efb",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
@@ -109,6 +109,24 @@
|
||||
vendorHash = "sha256-69kg3qkvEWyCAzjaCSr3a73MNonub9sZTYyGaCW+UTI=";
|
||||
};
|
||||
|
||||
# Keep Terraform aligned with provisioner/terraform/testdata/version.txt
|
||||
# so `make gen` remains deterministic in Nix shells.
|
||||
terraform_1_14_1 =
|
||||
if pkgs.stdenv.isLinux && pkgs.stdenv.hostPlatform.isx86_64 then
|
||||
pkgs.runCommand "terraform-1.14.1" {
|
||||
nativeBuildInputs = [ pkgs.unzip ];
|
||||
src = pkgs.fetchurl {
|
||||
url = "https://releases.hashicorp.com/terraform/1.14.1/terraform_1.14.1_linux_amd64.zip";
|
||||
hash = "sha256-n1MHDuYm354VeIfB0/mvPYEHobZUNxzZkEBinu1piyc=";
|
||||
};
|
||||
} ''
|
||||
mkdir -p "$out/bin"
|
||||
unzip -p "$src" terraform > "$out/bin/terraform"
|
||||
chmod +x "$out/bin/terraform"
|
||||
''
|
||||
else
|
||||
unstablePkgs.terraform;
|
||||
|
||||
# Packages required to build the frontend
|
||||
frontendPackages =
|
||||
with pkgs;
|
||||
@@ -156,7 +174,7 @@
|
||||
gnused
|
||||
gnugrep
|
||||
gnutar
|
||||
unstablePkgs.go_1_25
|
||||
unstablePkgs.go_1_26
|
||||
gofumpt
|
||||
go-migrate
|
||||
(pinnedPkgs.golangci-lint)
|
||||
@@ -170,7 +188,7 @@
|
||||
lazydocker
|
||||
lazygit
|
||||
less
|
||||
mockgen
|
||||
unstablePkgs.mockgen
|
||||
moreutils
|
||||
nfpm
|
||||
nix-prefetch-git
|
||||
@@ -191,7 +209,7 @@
|
||||
# sqlc
|
||||
sqlc-custom
|
||||
syft
|
||||
unstablePkgs.terraform
|
||||
terraform_1_14_1
|
||||
typos
|
||||
which
|
||||
# Needed for many LD system libs!
|
||||
@@ -285,6 +303,14 @@
|
||||
lib.optionalDrvAttr stdenv.isLinux "${glibcLocales}/lib/locale/locale-archive";
|
||||
|
||||
NODE_OPTIONS = "--max-old-space-size=8192";
|
||||
BIOME_BINARY =
|
||||
if pkgs.stdenv.isLinux then
|
||||
if pkgs.stdenv.hostPlatform.isAarch64 then
|
||||
"@biomejs/cli-linux-arm64-musl/biome"
|
||||
else
|
||||
"@biomejs/cli-linux-x64-musl/biome"
|
||||
else
|
||||
"";
|
||||
GOPRIVATE = "coder.com,cdr.dev,go.coder.com,github.com/cdr,github.com/coder";
|
||||
};
|
||||
};
|
||||
|
||||
@@ -72,6 +72,14 @@ replace github.com/aquasecurity/trivy => github.com/coder/trivy v0.0.0-202508072
|
||||
// https://github.com/spf13/afero/pull/487
|
||||
replace github.com/spf13/afero => github.com/aslilac/afero v0.0.0-20250403163713-f06e86036696
|
||||
|
||||
// Forked for two reasons:
|
||||
// 1) Adds thinking effort to Anthropic provider
|
||||
// 2) Downgraded to Go 1.25 due to issue with Windows CI
|
||||
// https://github.com/kylecarbs/fantasy/compare/main...kylecarbs:fantasy:cj/go1.25
|
||||
replace charm.land/fantasy => github.com/kylecarbs/fantasy v0.0.0-20260225152134-45ae0791c21f
|
||||
|
||||
replace github.com/charmbracelet/anthropic-sdk-go => github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab
|
||||
|
||||
require (
|
||||
cdr.dev/slog/v3 v3.0.0-rc1
|
||||
cloud.google.com/go/compute/metadata v0.9.0
|
||||
@@ -83,7 +91,7 @@ require (
|
||||
github.com/aquasecurity/trivy-iac v0.8.0
|
||||
github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2
|
||||
github.com/awalterschulze/gographviz v2.0.3+incompatible
|
||||
github.com/aws/smithy-go v1.24.0
|
||||
github.com/aws/smithy-go v1.24.1
|
||||
github.com/bramvdbogaerde/go-scp v1.6.0
|
||||
github.com/briandowns/spinner v1.23.0
|
||||
github.com/cakturk/go-netstat v0.0.0-20200220111822-e5b49efee7a5
|
||||
@@ -138,7 +146,7 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/go-reap v0.0.0-20170704170343-bf58d8a43e7b
|
||||
github.com/hashicorp/go-version v1.7.0
|
||||
github.com/hashicorp/go-version v1.8.0
|
||||
github.com/hashicorp/hc-install v0.9.2
|
||||
github.com/hashicorp/terraform-config-inspect v0.0.0-20211115214459-90acf1ca460f
|
||||
github.com/hashicorp/terraform-json v0.27.2
|
||||
@@ -150,7 +158,7 @@ require (
|
||||
github.com/justinas/nosurf v1.2.0
|
||||
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51
|
||||
github.com/kirsle/configdir v0.0.0-20170128060238-e45d2f54772f
|
||||
github.com/klauspost/compress v1.18.2
|
||||
github.com/klauspost/compress v1.18.4
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/mattn/go-isatty v0.0.20
|
||||
github.com/mitchellh/go-wordwrap v1.0.1
|
||||
@@ -167,7 +175,7 @@ require (
|
||||
github.com/prometheus-community/pro-bing v0.8.0
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/prometheus/client_model v0.6.2
|
||||
github.com/prometheus/common v0.67.4
|
||||
github.com/prometheus/common v0.67.5
|
||||
github.com/quasilyte/go-ruleguard/dsl v0.3.22
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
github.com/shirou/gopsutil/v4 v4.26.1
|
||||
@@ -186,17 +194,17 @@ require (
|
||||
github.com/zclconf/go-cty-yaml v1.2.0
|
||||
go.mozilla.org/pkcs7 v0.9.0
|
||||
go.nhat.io/otelsql v0.16.0
|
||||
go.opentelemetry.io/otel v1.39.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.37.0
|
||||
go.opentelemetry.io/otel/sdk v1.39.0
|
||||
go.opentelemetry.io/otel/trace v1.39.0
|
||||
go.opentelemetry.io/otel v1.40.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0
|
||||
go.opentelemetry.io/otel/sdk v1.40.0
|
||||
go.opentelemetry.io/otel/trace v1.40.0
|
||||
go.uber.org/atomic v1.11.0
|
||||
go.uber.org/goleak v1.3.1-0.20240429205332-517bace7cc29
|
||||
go.uber.org/mock v0.6.0
|
||||
go4.org/netipx v0.0.0-20230728180743-ad4cb58a6516
|
||||
golang.org/x/crypto v0.48.0
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546
|
||||
golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa
|
||||
golang.org/x/mod v0.33.0
|
||||
golang.org/x/net v0.50.0
|
||||
golang.org/x/oauth2 v0.35.0
|
||||
@@ -219,7 +227,7 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go/auth v0.18.1 // indirect
|
||||
cloud.google.com/go/auth v0.18.2 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
dario.cat/mergo v1.0.1 // indirect
|
||||
filippo.io/edwards25519 v1.1.1 // indirect
|
||||
@@ -253,19 +261,19 @@ require (
|
||||
github.com/armon/go-radix v1.0.1-0.20221118154546-54df44f2176c // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.1
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.1
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.14 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.9
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.2
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.14 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssm v1.60.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.10 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.14 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
@@ -288,7 +296,7 @@ require (
|
||||
github.com/dop251/goja v0.0.0-20241024094426-79f3a7efcdbd // indirect
|
||||
github.com/dustin/go-humanize v1.0.1
|
||||
github.com/eapache/queue/v2 v2.0.0-20230407133247-75960ed334e4 // indirect
|
||||
github.com/ebitengine/purego v0.9.1 // indirect
|
||||
github.com/ebitengine/purego v0.10.0-alpha.5 // indirect
|
||||
github.com/elastic/go-windows v1.0.0 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
@@ -305,7 +313,7 @@ require (
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/gobwas/glob v0.2.3 // indirect
|
||||
github.com/gobwas/httphead v0.1.0 // indirect
|
||||
github.com/gobwas/pool v0.2.1 // indirect
|
||||
@@ -321,10 +329,10 @@ require (
|
||||
github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.12 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.17.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
|
||||
github.com/hashicorp/go-cty v1.5.0 // indirect
|
||||
@@ -336,12 +344,11 @@ require (
|
||||
github.com/hashicorp/hcl/v2 v2.24.0
|
||||
github.com/hashicorp/logutils v1.0.0 // indirect
|
||||
github.com/hashicorp/terraform-plugin-go v0.29.0 // indirect
|
||||
github.com/hashicorp/terraform-plugin-log v0.9.0 // indirect
|
||||
github.com/hashicorp/terraform-plugin-log v0.10.0 // indirect
|
||||
github.com/hashicorp/terraform-plugin-sdk/v2 v2.38.1 // indirect
|
||||
github.com/hdevalence/ed25519consensus v0.1.0 // indirect
|
||||
github.com/illarion/gonotify v1.0.1 // indirect
|
||||
github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.1-0.20220621161143-b0104c826a24 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 // indirect
|
||||
github.com/jsimonetti/rtnetlink v1.3.5 // indirect
|
||||
@@ -388,14 +395,14 @@ require (
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/prometheus/procfs v0.19.2 // indirect
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect
|
||||
github.com/riandyrn/otelchi v0.5.1 // indirect
|
||||
github.com/richardartoul/molecule v1.0.1-0.20240531184615-7ca0df43c0b3 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b // indirect
|
||||
github.com/secure-systems-lab/go-securesystemslib v0.9.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sirupsen/logrus v1.9.4 // indirect
|
||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/swaggo/files/v2 v2.0.0 // indirect
|
||||
@@ -437,9 +444,9 @@ require (
|
||||
go.opentelemetry.io/collector/pdata/pprofile v0.121.0 // indirect
|
||||
go.opentelemetry.io/collector/semconv v0.123.0 // indirect
|
||||
go.opentelemetry.io/contrib v1.19.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0
|
||||
go.opentelemetry.io/otel/metric v1.39.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0
|
||||
go.opentelemetry.io/otel/metric v1.40.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.uber.org/zap v1.27.0 // indirect
|
||||
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
|
||||
@@ -448,10 +455,10 @@ require (
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||
google.golang.org/appengine v1.6.8 // indirect
|
||||
google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20260217215200-42d3e9bedb6d // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260217215200-42d3e9bedb6d // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260217215200-42d3e9bedb6d // indirect
|
||||
gopkg.in/ini.v1 v1.67.1 // indirect
|
||||
howett.net/plist v1.0.0 // indirect
|
||||
kernel.org/pub/linux/libs/security/libcap/psx v1.2.77 // indirect
|
||||
sigs.k8s.io/yaml v1.5.0 // indirect
|
||||
@@ -464,12 +471,13 @@ require github.com/SherClockHolmes/webpush-go v1.4.0
|
||||
require (
|
||||
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
charm.land/fantasy v0.8.1
|
||||
github.com/anthropics/anthropic-sdk-go v1.19.0
|
||||
github.com/brianvoe/gofakeit/v7 v7.14.0
|
||||
github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225
|
||||
@@ -491,17 +499,19 @@ require (
|
||||
cel.dev/expr v0.25.1 // indirect
|
||||
cloud.google.com/go v0.123.0 // indirect
|
||||
cloud.google.com/go/iam v1.5.3 // indirect
|
||||
cloud.google.com/go/logging v1.13.1 // indirect
|
||||
cloud.google.com/go/logging v1.13.2 // indirect
|
||||
cloud.google.com/go/longrunning v0.8.0 // indirect
|
||||
cloud.google.com/go/monitoring v1.24.3 // indirect
|
||||
cloud.google.com/go/storage v1.56.0 // indirect
|
||||
cloud.google.com/go/storage v1.60.0 // indirect
|
||||
git.sr.ht/~jackmordaunt/go-toast v1.1.2 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.1 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 // indirect
|
||||
github.com/DataDog/datadog-agent/comp/core/tagger/origindetection v0.64.2 // indirect
|
||||
github.com/DataDog/datadog-agent/pkg/version v0.64.2 // indirect
|
||||
github.com/DataDog/dd-trace-go/v2 v2.0.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.53.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 // indirect
|
||||
github.com/Masterminds/semver/v3 v3.3.1 // indirect
|
||||
github.com/alecthomas/chroma v0.10.0 // indirect
|
||||
github.com/aquasecurity/go-version v0.0.1 // indirect
|
||||
@@ -509,24 +519,29 @@ require (
|
||||
github.com/aquasecurity/jfather v0.0.8 // indirect
|
||||
github.com/aquasecurity/trivy v0.61.1-0.20250407075540-f1329c7ea1aa // indirect
|
||||
github.com/aquasecurity/trivy-checks v1.11.3-0.20250604022615-9a7efa7c9169 // indirect
|
||||
github.com/aws/aws-sdk-go v1.55.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.3 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.5 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d // indirect
|
||||
github.com/bits-and-blooms/bitset v1.24.4 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20250904123553-b4e2667e5ad5 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.9.0 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
|
||||
github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 // indirect
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 // indirect
|
||||
github.com/coder/paralleltestctx v0.0.1 // indirect
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect
|
||||
github.com/daixiang0/gci v0.13.7 // indirect
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.37.0 // indirect
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.3 // indirect
|
||||
github.com/esiqveland/notify v0.13.3 // indirect
|
||||
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
|
||||
github.com/go-git/go-billy/v5 v5.6.2 // indirect
|
||||
@@ -534,19 +549,24 @@ require (
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/google/go-containerregistry v0.20.6 // indirect
|
||||
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect
|
||||
github.com/hashicorp/go-getter v1.7.9 // indirect
|
||||
github.com/hashicorp/go-safetemp v1.0.0 // indirect
|
||||
github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.70 // indirect
|
||||
github.com/hashicorp/go-getter v1.8.4 // indirect
|
||||
github.com/hexops/gotextdiff v1.0.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/invopop/jsonschema v0.13.0 // indirect
|
||||
github.com/jackmordaunt/icns/v3 v3.0.1 // indirect
|
||||
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.2.4 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.10 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.6.10 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.10 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
||||
github.com/landlock-lsm/go-landlock v0.0.0-20251103212306-430f8e5cd97c // indirect
|
||||
github.com/mattn/go-shellwords v1.0.12 // indirect
|
||||
github.com/moby/sys/user v0.4.0 // indirect
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
|
||||
github.com/openai/openai-go v1.12.0 // indirect
|
||||
github.com/openai/openai-go/v2 v2.7.1 // indirect
|
||||
github.com/openai/openai-go/v3 v3.15.0 // indirect
|
||||
github.com/package-url/packageurl-go v0.1.3 // indirect
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
|
||||
@@ -569,13 +589,13 @@ require (
|
||||
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/zeebo/xxh3 v1.0.2 // indirect
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.39.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.62.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk/metric v1.39.0 // indirect
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.40.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk/metric v1.40.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
go.yaml.in/yaml/v4 v4.0.0-rc.3 // indirect
|
||||
golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4 // indirect
|
||||
google.golang.org/genai v1.12.0 // indirect
|
||||
google.golang.org/genai v1.47.0 // indirect
|
||||
gopkg.in/warnings.v0 v0.1.2 // indirect
|
||||
k8s.io/utils v0.0.0-20241210054802-24370beab758 // indirect
|
||||
mvdan.cc/gofumpt v0.8.0 // indirect
|
||||
|
||||
@@ -44,7 +44,7 @@ func Test_absoluteBinaryPath(t *testing.T) {
|
||||
{
|
||||
name: "TestMalformedVersion",
|
||||
terraformVersion: "version",
|
||||
expectedErr: xerrors.Errorf("Terraform binary get version failed: Malformed version: version"),
|
||||
expectedErr: xerrors.Errorf("Terraform binary get version failed: malformed version: version"),
|
||||
},
|
||||
}
|
||||
// nolint:paralleltest
|
||||
|
||||
Executable
+33
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [[ $# -ne 1 ]]; then
|
||||
echo "usage: $0 <path-relative-to-site>" >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
script_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
||||
repo_root=$(cd "$script_dir/.." && pwd)
|
||||
target=$1
|
||||
|
||||
output_file=$(mktemp)
|
||||
trap 'rm -f "$output_file"' EXIT
|
||||
|
||||
if (
|
||||
cd "$repo_root/site"
|
||||
pnpm exec biome format --write "$target"
|
||||
) >"$output_file" 2>&1; then
|
||||
cat "$output_file"
|
||||
exit 0
|
||||
fi
|
||||
status=$?
|
||||
|
||||
cat "$output_file" >&2
|
||||
|
||||
if [[ $status -eq 127 ]] || grep -q "Could not start dynamically linked executable" "$output_file" || grep -q "NixOS cannot run dynamically linked executables" "$output_file"; then
|
||||
echo "WARNING: skipping biome format for '$target' because the biome binary is unavailable in this environment." >&2
|
||||
exit 0
|
||||
fi
|
||||
|
||||
exit $status
|
||||
@@ -48,6 +48,10 @@ func prepareEnv() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = os.Setenv("TMPDIR", "/tmp")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func deleteEmptyDirs(dir string) error {
|
||||
|
||||
+16
-1
@@ -1042,7 +1042,22 @@ const fillParameters = async (
|
||||
case "number":
|
||||
{
|
||||
const parameterField = parameterLabel.locator("input");
|
||||
await parameterField.fill(buildParameter.value);
|
||||
// Dynamic parameters can hydrate after initial render and
|
||||
// overwrite an early fill. Re-apply until the desired value
|
||||
// is stable.
|
||||
for (let attempt = 0; attempt < 3; attempt++) {
|
||||
await parameterField.fill(buildParameter.value);
|
||||
try {
|
||||
await expect(parameterField).toHaveValue(buildParameter.value, {
|
||||
timeout: 1000,
|
||||
});
|
||||
break;
|
||||
} catch (error) {
|
||||
if (attempt === 2) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user