mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
Merge branch 'main' into fix/codagt-517-testagent-stats-ssh
This commit is contained in:
@@ -166,23 +166,18 @@ runs:
|
||||
mise_dir: ${{ steps.mise-data-dir.outputs.path }}
|
||||
install_args: ${{ steps.cache-key.outputs.install-args }}
|
||||
cache: "false"
|
||||
# Do not export mise's resolved env (every tool install dir) into
|
||||
# GITHUB_ENV. Tools resolve through the shims dir on GITHUB_PATH, so
|
||||
# the export only bloats PATH. On Windows the mise go shim re-prepends
|
||||
# those dirs at invocation, and the resulting PATH crosses cmd.exe's
|
||||
# ~8191 character limit, which makes cmd.exe drop PATH entirely and
|
||||
# fail to resolve native executables in subprocesses spawned by tests.
|
||||
env: false
|
||||
|
||||
- name: Ensure Git usr/bin is in PATH (Windows)
|
||||
- name: Add Git usr/bin to PATH (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
# jdx/mise-action exports "Path" via GITHUB_ENV which may
|
||||
# collide with bash's "PATH". Ensure Git usr/bin is present
|
||||
# and remove any duplicate Path/PATH entries from GITHUB_ENV
|
||||
# by writing both forms.
|
||||
run: | # zizmor: ignore[github-env]
|
||||
$gitdir = "C:\Program Files\Git\usr\bin"
|
||||
$current = $env:Path
|
||||
if ($current -notlike "*$gitdir*") {
|
||||
$current = "$gitdir;$current"
|
||||
}
|
||||
# Write both Path and PATH to GITHUB_ENV so that both
|
||||
# cmd.exe (uses Path) and bash/Go (uses PATH) see the
|
||||
# same value including Git usr/bin.
|
||||
"Path=$current" | Out-File -Append -FilePath $env:GITHUB_ENV -Encoding utf8
|
||||
"PATH=$current" | Out-File -Append -FilePath $env:GITHUB_ENV -Encoding utf8
|
||||
|
||||
shell: bash
|
||||
# GITHUB_PATH is the casing-safe channel and keeps the entry short.
|
||||
# cmd.exe subprocesses spawned by Go tests need MSYS coreutils such as
|
||||
# printf, which live here.
|
||||
run: echo "C:\Program Files\Git\usr\bin" >> "$GITHUB_PATH"
|
||||
|
||||
+3
-1
@@ -236,6 +236,9 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
|
||||
traceAttrs := interceptor.TraceAttributes(r)
|
||||
span.SetAttributes(traceAttrs...)
|
||||
ctx = tracing.WithInterceptionAttributesInContext(ctx, traceAttrs)
|
||||
// Attach the interception ID to the context so every log line
|
||||
// emitted with this context can be correlated to the interception.
|
||||
ctx = slog.With(ctx, slog.F("interception_id", interceptor.ID()))
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Record usage in the background to not block request flow.
|
||||
@@ -272,7 +275,6 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
|
||||
log := logger.With(
|
||||
slog.F("route", route),
|
||||
slog.F("provider", p.Name()),
|
||||
slog.F("interception_id", interceptor.ID()),
|
||||
slog.F("user_agent", r.UserAgent()),
|
||||
slog.F("streaming", interceptor.Streaming()),
|
||||
slog.F("credential_kind", string(cred.Kind)),
|
||||
|
||||
@@ -40,7 +40,7 @@ func (r *WrappedRecorder) RecordInterception(ctx context.Context, req *Intercept
|
||||
return nil
|
||||
}
|
||||
|
||||
r.logger.Warn(ctx, "failed to record interception", slog.Error(err), slog.F("interception_id", req.ID))
|
||||
r.logger.Warn(ctx, "failed to record interception", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func (r *WrappedRecorder) RecordInterceptionEnded(ctx context.Context, req *Inte
|
||||
return nil
|
||||
}
|
||||
|
||||
r.logger.Warn(ctx, "failed to record that interception ended", slog.Error(err), slog.F("interception_id", req.ID))
|
||||
r.logger.Warn(ctx, "failed to record that interception ended", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func (r *WrappedRecorder) RecordPromptUsage(ctx context.Context, req *PromptUsag
|
||||
return nil
|
||||
}
|
||||
|
||||
r.logger.Warn(ctx, "failed to record prompt usage", slog.Error(err), slog.F("interception_id", req.InterceptionID))
|
||||
r.logger.Warn(ctx, "failed to record prompt usage", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ func (r *WrappedRecorder) RecordTokenUsage(ctx context.Context, req *TokenUsageR
|
||||
return nil
|
||||
}
|
||||
|
||||
r.logger.Warn(ctx, "failed to record token usage", slog.Error(err), slog.F("interception_id", req.InterceptionID))
|
||||
r.logger.Warn(ctx, "failed to record token usage", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ func (r *WrappedRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRec
|
||||
return nil
|
||||
}
|
||||
|
||||
r.logger.Warn(ctx, "failed to record tool usage", slog.Error(err), slog.F("interception_id", req.InterceptionID))
|
||||
r.logger.Warn(ctx, "failed to record tool usage", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -130,7 +130,7 @@ func (r *WrappedRecorder) RecordModelThought(ctx context.Context, req *ModelThou
|
||||
return nil
|
||||
}
|
||||
|
||||
r.logger.Warn(ctx, "failed to record model thought", slog.Error(err), slog.F("interception_id", req.InterceptionID))
|
||||
r.logger.Warn(ctx, "failed to record model thought", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
+1
-3
@@ -146,10 +146,8 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
}).WithAgent().Do()
|
||||
|
||||
coderURLEnv := "$CODER_URL"
|
||||
headerCmd := "printf X-Process-Testing=very-wow-" + coderURLEnv + "'\\r\\n'X-Process-Testing2=more-wow"
|
||||
if runtime.GOOS == "windows" {
|
||||
coderURLEnv = "%CODER_URL%"
|
||||
headerCmd = "echo X-Process-Testing=very-wow-" + coderURLEnv + "& echo X-Process-Testing2=more-wow"
|
||||
}
|
||||
|
||||
logDir := t.TempDir()
|
||||
@@ -161,7 +159,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
"--log-dir", logDir,
|
||||
"--agent-header", "X-Testing=agent",
|
||||
"--agent-header", "Cool-Header=Ethan was Here!",
|
||||
"--agent-header-command", headerCmd,
|
||||
"--agent-header-command", "printf X-Process-Testing=very-wow-"+coderURLEnv+"'\\r\\n'X-Process-Testing2=more-wow",
|
||||
"--socket-path", testutil.AgentSocketPath(t),
|
||||
)
|
||||
clitest.Start(t, agentInv)
|
||||
|
||||
@@ -229,15 +229,8 @@ func Test_sshConfigMatchExecEscape(t *testing.T) {
|
||||
|
||||
// OpenSSH processes %% escape sequences into %
|
||||
escaped = strings.ReplaceAll(escaped, "%%", "%")
|
||||
c := exec.Command(cmd, arg, escaped) //nolint:gosec
|
||||
if runtime.GOOS == "windows" {
|
||||
// Deduplicate Path/PATH env vars so cmd.exe
|
||||
// subprocesses (like powershell.exe used for
|
||||
// paths with spaces) resolve correctly.
|
||||
c.Env = appendAndDedupEnv(os.Environ())
|
||||
}
|
||||
b, err := c.CombinedOutput()
|
||||
require.NoError(t, err, "command output: %s", string(b))
|
||||
b, err := exec.Command(cmd, arg, escaped).CombinedOutput() //nolint:gosec
|
||||
require.NoError(t, err)
|
||||
got := strings.TrimSpace(string(b))
|
||||
require.Equal(t, "yay", got)
|
||||
})
|
||||
|
||||
@@ -4,8 +4,6 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
@@ -52,13 +50,7 @@ func sshConfigMatchExecEscape(path string) (string, error) {
|
||||
|
||||
if strings.ContainsAny(path, " ") {
|
||||
// c.f. function comment for how this works.
|
||||
// Use absolute paths for powershell.exe and cmd.exe
|
||||
// to avoid PATH resolution issues when both Path and
|
||||
// PATH (MSYS-translated) exist in the environment.
|
||||
sysRoot := os.Getenv("SYSTEMROOT")
|
||||
pwsh := filepath.Join(sysRoot, "System32", "WindowsPowerShell", "v1.0", "powershell.exe")
|
||||
cmd := filepath.Join(sysRoot, "System32", "cmd.exe")
|
||||
path = fmt.Sprintf("for /f %%%%a in ('%s -Command [char]34') do @%s /c %%%%a%s%%%%a", pwsh, cmd, path) //nolint:gocritic // We don't want %q here.
|
||||
path = fmt.Sprintf("for /f %%%%a in ('powershell.exe -Command [char]34') do @cmd.exe /c %%%%a%s%%%%a", path) //nolint:gocritic // We don't want %q here.
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
+2
-39
@@ -1701,44 +1701,7 @@ func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return r(req)
|
||||
}
|
||||
|
||||
// appendAndDedupEnv appends extra environment variables and
|
||||
// deduplicates entries with the same key (case-insensitive on
|
||||
// Windows). For the PATH variable specifically, it prefers the
|
||||
// value that contains native Windows paths (with backslashes)
|
||||
// over MSYS-translated paths (with forward slashes). For all
|
||||
// other variables, the last value wins.
|
||||
func appendAndDedupEnv(env []string, extra ...string) []string {
|
||||
env = append(env, extra...)
|
||||
if runtime.GOOS != "windows" {
|
||||
return env
|
||||
}
|
||||
seen := make(map[string]int, len(env))
|
||||
result := make([]string, 0, len(env))
|
||||
for _, e := range env {
|
||||
key, val, ok := strings.Cut(e, "=")
|
||||
if !ok {
|
||||
result = append(result, e)
|
||||
continue
|
||||
}
|
||||
upper := strings.ToUpper(key)
|
||||
if idx, exists := seen[upper]; exists {
|
||||
if upper == "PATH" {
|
||||
// Prefer the value with native Windows paths.
|
||||
existingVal := result[idx][len(key)+1:]
|
||||
if strings.Contains(existingVal, "\\") && !strings.Contains(val, "\\") {
|
||||
continue
|
||||
}
|
||||
}
|
||||
result[idx] = e
|
||||
continue
|
||||
}
|
||||
seen[upper] = len(result)
|
||||
result = append(result, e)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// headerTransport creates a new transport that executes `--header-command`
|
||||
// HeaderTransport creates a new transport that executes `--header-command`
|
||||
// if it is set to add headers for all outbound requests.
|
||||
func headerTransport(ctx context.Context, serverURL *url.URL, header []string, headerCommand string) (*codersdk.HeaderTransport, error) {
|
||||
transport := &codersdk.HeaderTransport{
|
||||
@@ -1756,7 +1719,7 @@ func headerTransport(ctx context.Context, serverURL *url.URL, header []string, h
|
||||
var outBuf bytes.Buffer
|
||||
// #nosec
|
||||
cmd := exec.CommandContext(ctx, shell, caller, headerCommand)
|
||||
cmd.Env = appendAndDedupEnv(os.Environ(), "CODER_URL="+serverURL.String())
|
||||
cmd.Env = append(os.Environ(), "CODER_URL="+serverURL.String())
|
||||
cmd.Stdout = &outBuf
|
||||
cmd.Stderr = io.Discard
|
||||
err := cmd.Run()
|
||||
|
||||
+2
-4
@@ -177,17 +177,15 @@ func TestRoot(t *testing.T) {
|
||||
url = srv.URL
|
||||
buf := new(bytes.Buffer)
|
||||
coderURLEnv := "$CODER_URL"
|
||||
headerCmd := "printf X-Process-Testing=very-wow-" + coderURLEnv + "'\\r\\n'X-Process-Testing2=more-wow"
|
||||
if runtime.GOOS == "windows" {
|
||||
coderURLEnv = "%CODER_URL%"
|
||||
headerCmd = "echo X-Process-Testing=very-wow-" + coderURLEnv + "& echo X-Process-Testing2=more-wow"
|
||||
}
|
||||
inv, _ := clitest.New(t,
|
||||
"--no-feature-warning",
|
||||
"--no-version-warning",
|
||||
"--header", "X-Testing=wow",
|
||||
"--header", "Cool-Header=Dean was Here!",
|
||||
"--header-command", headerCmd,
|
||||
"--header-command", "printf X-Process-Testing=very-wow-"+coderURLEnv+"'\\r\\n'X-Process-Testing2=more-wow",
|
||||
"login", srv.URL,
|
||||
)
|
||||
inv.Stdout = buf
|
||||
@@ -268,7 +266,7 @@ func TestDERPHeaders(t *testing.T) {
|
||||
"--no-version-warning",
|
||||
"ping", workspace.Name,
|
||||
"-n", "1",
|
||||
"--header-command", "echo X-Process-Testing=very-wow",
|
||||
"--header-command", "printf X-Process-Testing=very-wow",
|
||||
}
|
||||
for k, v := range expectedHeaders {
|
||||
if k != "X-Process-Testing" {
|
||||
|
||||
+20
-14
@@ -6468,22 +6468,14 @@ type runChatResult struct {
|
||||
HistoryTipMessageID int64
|
||||
}
|
||||
|
||||
func contextWithActiveTurnAPIKeyID(ctx context.Context, messages []database.ChatMessage) context.Context {
|
||||
apiKeyID, ok := activeTurnAPIKeyIDFromMessages(messages)
|
||||
if !ok {
|
||||
return ctx
|
||||
}
|
||||
return aibridge.WithDelegatedAPIKeyID(ctx, apiKeyID)
|
||||
}
|
||||
|
||||
func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
message := messages[i]
|
||||
if message.Role != database.ChatMessageRoleUser {
|
||||
continue
|
||||
}
|
||||
if message.Visibility != database.ChatMessageVisibilityBoth &&
|
||||
message.Visibility != database.ChatMessageVisibilityUser {
|
||||
if !isUserVisibleChatMessage(message) &&
|
||||
!(message.Visibility == database.ChatMessageVisibilityModel && message.Compressed) {
|
||||
continue
|
||||
}
|
||||
if !message.APIKeyID.Valid || message.APIKeyID.String == "" {
|
||||
@@ -6494,6 +6486,11 @@ func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bo
|
||||
return "", false
|
||||
}
|
||||
|
||||
func isUserVisibleChatMessage(message database.ChatMessage) bool {
|
||||
return message.Visibility == database.ChatMessageVisibilityBoth ||
|
||||
message.Visibility == database.ChatMessageVisibilityUser
|
||||
}
|
||||
|
||||
func allToolNames(allTools []fantasy.AgentTool) []string {
|
||||
toolNames := make([]string, 0, len(allTools))
|
||||
for _, tool := range allTools {
|
||||
@@ -7124,7 +7121,9 @@ func (p *Server) runChat(
|
||||
return result, xerrors.Errorf("get chat messages: %w", err)
|
||||
}
|
||||
modelOpts := modelBuildOptionsFromMessages(messages)
|
||||
ctx = contextWithActiveTurnAPIKeyID(ctx, messages)
|
||||
if modelOpts.ActiveAPIKeyID != "" {
|
||||
ctx = aibridge.WithDelegatedAPIKeyID(ctx, modelOpts.ActiveAPIKeyID)
|
||||
}
|
||||
|
||||
// Load MCP server configs and user tokens in parallel with model
|
||||
// resolution. These queries have no dependencies on each other and all
|
||||
@@ -7831,6 +7830,7 @@ func (p *Server) runChat(
|
||||
persistCtx,
|
||||
chat.ID,
|
||||
modelConfig.ID,
|
||||
modelOpts.ActiveAPIKeyID,
|
||||
compactionToolCallID,
|
||||
result,
|
||||
); err != nil {
|
||||
@@ -8460,12 +8460,14 @@ func buildProviderTools(options *codersdk.ChatModelProviderOptions) []chatloop.P
|
||||
return tools
|
||||
}
|
||||
|
||||
// persistChatContextSummary persists a chat context summary to the database.
|
||||
// This is invoked via the chat loop's compaction callback.
|
||||
// persistChatContextSummary is called from the chat loop's compaction
|
||||
// callback. activeAPIKeyID is stamped onto the summary user message. When
|
||||
// empty, it falls back to the delegated key in ctx.
|
||||
func (p *Server) persistChatContextSummary(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
modelConfigID uuid.UUID,
|
||||
activeAPIKeyID string,
|
||||
toolCallID string,
|
||||
result chatloop.CompactionResult,
|
||||
) error {
|
||||
@@ -8514,6 +8516,11 @@ func (p *Server) persistChatContextSummary(
|
||||
return xerrors.Errorf("encode summary tool result: %w", err)
|
||||
}
|
||||
|
||||
summaryAPIKeyID := activeAPIKeyID
|
||||
if summaryAPIKeyID == "" {
|
||||
summaryAPIKeyID, _ = aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
||||
}
|
||||
|
||||
var insertedMessages []database.ChatMessage
|
||||
|
||||
txErr := p.db.InTx(func(tx database.Store) error {
|
||||
@@ -8522,7 +8529,6 @@ func (p *Server) persistChatContextSummary(
|
||||
}
|
||||
|
||||
// Hidden summary user message (not published to subscribers).
|
||||
summaryAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
||||
summaryUserMsg := newUserChatMessage(
|
||||
summaryAPIKeyID,
|
||||
systemContent,
|
||||
|
||||
@@ -6651,42 +6651,63 @@ func TestPersistChatContextSummarySetsAPIKeyID(t *testing.T) {
|
||||
UserID: user.ID,
|
||||
})
|
||||
|
||||
ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)
|
||||
|
||||
server := &Server{db: db}
|
||||
persistAndAssertSummaryKey := func(
|
||||
summaryCtx context.Context,
|
||||
chatID uuid.UUID,
|
||||
activeAPIKeyID string,
|
||||
wantAPIKeyID string,
|
||||
toolCallID string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
err := server.persistChatContextSummary(
|
||||
ctx,
|
||||
chat.ID,
|
||||
modelConfig.ID,
|
||||
"tool-call-id-1",
|
||||
chatloop.CompactionResult{
|
||||
SystemSummary: "summarized context",
|
||||
SummaryReport: "context was summarized",
|
||||
ThresholdPercent: 70,
|
||||
UsagePercent: 85.0,
|
||||
ContextTokens: 8500,
|
||||
ContextLimit: 10000,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err := server.persistChatContextSummary(
|
||||
summaryCtx,
|
||||
chatID,
|
||||
modelConfig.ID,
|
||||
activeAPIKeyID,
|
||||
toolCallID,
|
||||
chatloop.CompactionResult{
|
||||
SystemSummary: "summarized context",
|
||||
SummaryReport: "context was summarized",
|
||||
ThresholdPercent: 70,
|
||||
UsagePercent: 85.0,
|
||||
ContextTokens: 8500,
|
||||
ContextLimit: 10000,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
msgs, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
msgs, err := db.GetChatMessagesForPromptByChatID(ctx, chatID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// GetChatMessagesForPromptByChatID uses a compaction boundary CTE
|
||||
// that selects compressed=true, visibility='model'. Only the user
|
||||
// summary qualifies; the assistant (visibility=user) and tool
|
||||
// result (visibility=both) are excluded by the CTE filter.
|
||||
require.NotEmpty(t, msgs)
|
||||
// GetChatMessagesForPromptByChatID uses a compaction boundary CTE
|
||||
// that selects compressed=true, visibility='model'. Only the user
|
||||
// summary qualifies; the assistant (visibility=user) and tool
|
||||
// result (visibility=both) are excluded by the CTE filter.
|
||||
require.NotEmpty(t, msgs)
|
||||
|
||||
var foundUserSummary bool
|
||||
for _, msg := range msgs {
|
||||
if msg.Role == database.ChatMessageRoleUser {
|
||||
foundUserSummary = true
|
||||
require.True(t, msg.APIKeyID.Valid, "summary user message must have APIKeyID set")
|
||||
require.Equal(t, apiKey.ID, msg.APIKeyID.String, "summary user message APIKeyID must match")
|
||||
var foundUserSummary bool
|
||||
for _, msg := range msgs {
|
||||
if msg.Role == database.ChatMessageRoleUser {
|
||||
foundUserSummary = true
|
||||
require.True(t, msg.APIKeyID.Valid, "summary user message must have APIKeyID set")
|
||||
require.Equal(t, wantAPIKeyID, msg.APIKeyID.String, "summary user message APIKeyID must match")
|
||||
}
|
||||
}
|
||||
require.True(t, foundUserSummary, "expected to find compressed user summary message")
|
||||
}
|
||||
require.True(t, foundUserSummary, "expected to find compressed user summary message")
|
||||
|
||||
persistAndAssertSummaryKey(ctx, chat.ID, apiKey.ID, apiKey.ID, "tool-call-id-1")
|
||||
|
||||
fallbackChat := dbgen.Chat(t, db, database.Chat{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
})
|
||||
fallbackKey, _ := dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: user.ID,
|
||||
})
|
||||
fallbackCtx := aibridge.WithDelegatedAPIKeyID(ctx, fallbackKey.ID)
|
||||
persistAndAssertSummaryKey(fallbackCtx, fallbackChat.ID, "", fallbackKey.ID, "tool-call-id-2")
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -9914,7 +9915,7 @@ func TestAdvisorHappyPath_RootChat(t *testing.T) {
|
||||
MaxUsesPerRun: 3,
|
||||
MaxOutputTokens: 16384,
|
||||
})
|
||||
server := newActiveTestServer(t, db, ps)
|
||||
server := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OrganizationID: org.ID,
|
||||
@@ -9927,13 +9928,7 @@ func TestAdvisorHappyPath_RootChat(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Subscribe before the worker commits any durable messages so we
|
||||
// observe the advisor tool-result deltas live. Buffered parts are
|
||||
// claimed by their committed durable message ID at publishMessage
|
||||
// time and dropped from snapshots of late-connecting subscribers, so
|
||||
// a post-completion Subscribe() would no longer see streaming
|
||||
// deltas. Collecting events from the live channel covers the
|
||||
// streaming UX contract this test exists to verify.
|
||||
// Advisor deltas are transient; a late subscriber misses them.
|
||||
_, liveEvents, cancelLive, ok := server.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
var (
|
||||
@@ -9969,6 +9964,8 @@ func TestAdvisorHappyPath_RootChat(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
server.Start()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
@@ -10023,17 +10020,15 @@ func TestAdvisorHappyPath_RootChat(t *testing.T) {
|
||||
require.True(t, parentSawAdvisorResult,
|
||||
"parent must see the advisor reply in its continuation call")
|
||||
|
||||
// Stop the live collector and assert it captured the streaming
|
||||
// advisor deltas during processing. Late subscribers no longer
|
||||
// see committed parts because publishMessage claims them out of
|
||||
// new snapshots, so the assertion must use the live collector.
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
livePartsMu.Lock()
|
||||
defer livePartsMu.Unlock()
|
||||
assert.Equal(c, advisorDeltas, liveAdvisorDeltas,
|
||||
"advisor nested text deltas must stream into the parent tool card")
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
cancelLive()
|
||||
<-liveCollectorDone
|
||||
livePartsMu.Lock()
|
||||
collectedAdvisorDeltas := append([]string(nil), liveAdvisorDeltas...)
|
||||
livePartsMu.Unlock()
|
||||
require.Equal(t, advisorDeltas, collectedAdvisorDeltas,
|
||||
"advisor nested text deltas must stream into the parent tool card")
|
||||
|
||||
persisted, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
|
||||
@@ -405,7 +405,7 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SkipsModelOnlyUserMessages",
|
||||
name: "SkipsUncompressedModelOnlyUserMessages",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
|
||||
@@ -413,6 +413,54 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
|
||||
wantKey: oldKeyID,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "CompressedSummaryFallback",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(currentKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth},
|
||||
},
|
||||
wantKey: currentKeyID,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "LatestCompressedSummaryWins",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(currentKeyID)},
|
||||
{ID: 3, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth},
|
||||
},
|
||||
wantKey: currentKeyID,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "VisibleUserWinsOverCompressedSummary",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(currentKeyID)},
|
||||
},
|
||||
wantKey: currentKeyID,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "MissingVisibleUserKeyDoesNotFallBackToCompressedSummary",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UncompressedModelOnlyUserIgnored",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CompressedSummaryMissingKeyDoesNotFallBack",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -421,15 +469,11 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
|
||||
gotKey, gotOK := activeTurnAPIKeyIDFromMessages(tt.messages)
|
||||
require.Equal(t, tt.wantOK, gotOK)
|
||||
require.Equal(t, tt.wantKey, gotKey)
|
||||
ctx := contextWithActiveTurnAPIKeyID(t.Context(), tt.messages)
|
||||
ctxKey, ctxOK := aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
||||
require.Equal(t, tt.wantOK, ctxOK)
|
||||
require.Equal(t, tt.wantKey, ctxKey)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestActiveTurnContextUsesPromptMessages(t *testing.T) {
|
||||
func TestPromptMessagesForVisibleUserPreserveActiveAPIKeyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
@@ -477,12 +521,70 @@ func TestActiveTurnContextUsesPromptMessages(t *testing.T) {
|
||||
|
||||
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
ctx = contextWithActiveTurnAPIKeyID(ctx, messages)
|
||||
gotKey, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
||||
gotKey, ok := activeTurnAPIKeyIDFromMessages(messages)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, currentKey.ID, gotKey)
|
||||
}
|
||||
|
||||
func TestPromptMessagesForCompactedChatPreserveActiveAPIKeyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := t.Context()
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
|
||||
chat := dbgen.Chat(t, db, database.Chat{OrganizationID: org.ID, OwnerID: user.ID, LastModelConfigID: model.ID})
|
||||
key, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
||||
|
||||
visibleUser := dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleUser,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
APIKeyID: sqlNullString(key.ID),
|
||||
})
|
||||
dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
})
|
||||
compressedSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleUser,
|
||||
Visibility: database.ChatMessageVisibilityModel,
|
||||
Compressed: true,
|
||||
APIKeyID: sqlNullString(key.ID),
|
||||
})
|
||||
afterSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
})
|
||||
|
||||
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
ids := make(map[int64]struct{}, len(messages))
|
||||
for _, message := range messages {
|
||||
ids[message.ID] = struct{}{}
|
||||
}
|
||||
_, hasVisibleUser := ids[visibleUser.ID]
|
||||
require.False(t, hasVisibleUser)
|
||||
_, hasSummary := ids[compressedSummary.ID]
|
||||
require.True(t, hasSummary)
|
||||
_, hasAfterSummary := ids[afterSummary.ID]
|
||||
require.True(t, hasAfterSummary)
|
||||
|
||||
gotKey, ok := activeTurnAPIKeyIDFromMessages(messages)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, key.ID, gotKey)
|
||||
}
|
||||
|
||||
func sqlNullString(value string) sql.NullString {
|
||||
return sql.NullString{String: value, Valid: value != ""}
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func Test_ProxyServer_Headers(t *testing.T) {
|
||||
"--access-url", "http://localhost:8080",
|
||||
"--http-address", ":0",
|
||||
"--header", fmt.Sprintf("%s=%s", headerName1, headerVal1),
|
||||
"--header-command", fmt.Sprintf("echo %s=%s", headerName2, headerVal2),
|
||||
"--header-command", fmt.Sprintf("printf %s=%s", headerName2, headerVal2),
|
||||
)
|
||||
pty := ptytest.New(t)
|
||||
inv.Stdout = pty.Output()
|
||||
|
||||
+11
-11
@@ -89,7 +89,7 @@
|
||||
"lodash": "4.18.1",
|
||||
"lucide-react": "0.555.0",
|
||||
"monaco-editor": "0.55.1",
|
||||
"motion": "12.38.0",
|
||||
"motion": "12.40.0",
|
||||
"pretty-bytes": "6.1.1",
|
||||
"radix-ui": "1.4.3",
|
||||
"react": "19.2.6",
|
||||
@@ -101,7 +101,7 @@
|
||||
"react-markdown": "9.1.0",
|
||||
"react-query": "npm:@tanstack/react-query@5.77.0",
|
||||
"react-resizable-panels": "3.0.6",
|
||||
"react-router": "7.12.0",
|
||||
"react-router": "7.15.1",
|
||||
"react-syntax-highlighter": "15.6.6",
|
||||
"react-textarea-autosize": "8.5.9",
|
||||
"react-virtualized-auto-sizer": "1.0.26",
|
||||
@@ -111,7 +111,7 @@
|
||||
"semver": "7.7.3",
|
||||
"sonner": "2.0.7",
|
||||
"streamdown": "2.5.0",
|
||||
"tailwind-merge": "2.6.0",
|
||||
"tailwind-merge": "2.6.1",
|
||||
"tailwindcss-animate": "1.0.7",
|
||||
"tzdata": "1.0.46",
|
||||
"ua-parser-js": "1.0.41",
|
||||
@@ -123,7 +123,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/core": "7.29.7",
|
||||
"@babel/plugin-syntax-typescript": "7.28.6",
|
||||
"@babel/plugin-syntax-typescript": "7.29.7",
|
||||
"@biomejs/biome": "2.4.10",
|
||||
"@chromatic-com/storybook": "5.0.1",
|
||||
"@octokit/types": "12.6.0",
|
||||
@@ -145,8 +145,8 @@
|
||||
"@types/express": "4.17.17",
|
||||
"@types/file-saver": "2.0.7",
|
||||
"@types/humanize-duration": "3.27.4",
|
||||
"@types/lodash": "4.17.21",
|
||||
"@types/node": "20.19.39",
|
||||
"@types/lodash": "4.17.24",
|
||||
"@types/node": "20.19.41",
|
||||
"@types/novnc__novnc": "1.5.0",
|
||||
"@types/react": "19.2.15",
|
||||
"@types/react-color": "3.0.13",
|
||||
@@ -158,8 +158,8 @@
|
||||
"@types/ssh2": "1.15.5",
|
||||
"@types/ua-parser-js": "0.7.36",
|
||||
"@types/uuid": "9.0.2",
|
||||
"@vitejs/plugin-react": "6.0.1",
|
||||
"@vitest/browser-playwright": "4.1.1",
|
||||
"@vitejs/plugin-react": "6.0.2",
|
||||
"@vitest/browser-playwright": "4.1.7",
|
||||
"autoprefixer": "10.5.0",
|
||||
"babel-plugin-react-compiler": "1.0.0",
|
||||
"chromatic": "11.29.0",
|
||||
@@ -170,7 +170,7 @@
|
||||
"jsdom": "27.2.0",
|
||||
"knip": "5.71.0",
|
||||
"msw": "2.4.8",
|
||||
"postcss": "8.5.10",
|
||||
"postcss": "8.5.15",
|
||||
"protobufjs": "7.6.1",
|
||||
"resize-observer-polyfill": "1.5.1",
|
||||
"rollup-plugin-visualizer": "7.0.1",
|
||||
@@ -181,9 +181,9 @@
|
||||
"tailwindcss": "3.4.18",
|
||||
"ts-proto": "1.181.2",
|
||||
"typescript": "6.0.2",
|
||||
"vite": "8.0.10",
|
||||
"vite": "8.0.14",
|
||||
"vite-plugin-checker": "0.13.0",
|
||||
"vitest": "4.1.5"
|
||||
"vitest": "4.1.7"
|
||||
},
|
||||
"browserslist": [
|
||||
"chrome 110",
|
||||
|
||||
Generated
+323
-333
File diff suppressed because it is too large
Load Diff
@@ -8,7 +8,7 @@ import type {
|
||||
|
||||
const aiProvidersListKey = ["ai", "providers"] as const;
|
||||
|
||||
const aiProviderKeyFor = (idOrName: string) =>
|
||||
export const aiProviderKeyFor = (idOrName: string) =>
|
||||
[...aiProvidersListKey, idOrName] as const;
|
||||
|
||||
export const aiProvidersList = () => ({
|
||||
|
||||
+6
@@ -7,6 +7,7 @@ import { toast } from "sonner";
|
||||
import { getErrorMessage } from "#/api/errors";
|
||||
import {
|
||||
aiProvider,
|
||||
aiProviderKeyFor,
|
||||
deleteAIProviderMutation,
|
||||
updateAIProviderMutation,
|
||||
} from "#/api/queries/aiProviders";
|
||||
@@ -171,6 +172,10 @@ const UpdateProviderPageView: React.FC = () => {
|
||||
{ enabled: checked },
|
||||
{
|
||||
onSuccess: (updated) => {
|
||||
queryClient.setQueryData(
|
||||
aiProviderKeyFor(providerId),
|
||||
updated,
|
||||
);
|
||||
toast.success(
|
||||
`Provider "${updated.display_name || updated.name}" ${checked ? "enabled" : "disabled"}.`,
|
||||
);
|
||||
@@ -200,6 +205,7 @@ const UpdateProviderPageView: React.FC = () => {
|
||||
const request = providerFormValuesToUpdate(values, provider);
|
||||
try {
|
||||
const updated = await updateMutation.mutateAsync(request);
|
||||
queryClient.setQueryData(aiProviderKeyFor(providerId), updated);
|
||||
toast.success(
|
||||
`Provider "${updated.display_name || updated.name}" updated.`,
|
||||
);
|
||||
|
||||
@@ -10,6 +10,7 @@ type CredentialFieldProps = {
|
||||
placeholder?: string;
|
||||
description?: React.ReactNode;
|
||||
required?: boolean;
|
||||
onBlur?: () => void;
|
||||
onFocus?: () => void;
|
||||
};
|
||||
|
||||
@@ -20,6 +21,7 @@ export const CredentialField: React.FC<CredentialFieldProps> = ({
|
||||
placeholder,
|
||||
description,
|
||||
required = false,
|
||||
onBlur,
|
||||
onFocus,
|
||||
}) => {
|
||||
const inputId = useId();
|
||||
@@ -62,9 +64,13 @@ export const CredentialField: React.FC<CredentialFieldProps> = ({
|
||||
<Input
|
||||
id={inputId}
|
||||
name={helpers.name}
|
||||
className="font-mono text-[13px]"
|
||||
value={helpers.value}
|
||||
onChange={helpers.onChange}
|
||||
onBlur={helpers.onBlur}
|
||||
onBlur={(event) => {
|
||||
helpers.onBlur(event);
|
||||
onBlur?.();
|
||||
}}
|
||||
onFocus={onFocus}
|
||||
autoComplete={autoComplete}
|
||||
placeholder={placeholder}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react-vite";
|
||||
import { type ComponentProps, useState } from "react";
|
||||
import { expect, fn, screen, userEvent, waitFor, within } from "storybook/test";
|
||||
import { ProviderForm } from "./ProviderForm";
|
||||
import { createDeferred, type Deferred } from "#/testHelpers/deferred";
|
||||
import { ProviderForm, SAVED_CREDENTIAL_MASK } from "./ProviderForm";
|
||||
|
||||
const meta: Meta<typeof ProviderForm> = {
|
||||
title: "pages/AISettingsPage/ProviderForm",
|
||||
@@ -15,6 +17,88 @@ const meta: Meta<typeof ProviderForm> = {
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof ProviderForm>;
|
||||
|
||||
const SuccessfulSubmitProviderForm = ({
|
||||
args,
|
||||
deferred,
|
||||
}: {
|
||||
args: ComponentProps<typeof ProviderForm>;
|
||||
deferred: Deferred<void>;
|
||||
}) => {
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
|
||||
return (
|
||||
<ProviderForm
|
||||
{...args}
|
||||
isLoading={isLoading}
|
||||
onSubmit={async (values) => {
|
||||
args.onSubmit?.(values);
|
||||
setIsLoading(true);
|
||||
await deferred.promise;
|
||||
setIsLoading(false);
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const FailedSubmitProviderForm = ({
|
||||
args,
|
||||
deferred,
|
||||
}: {
|
||||
args: ComponentProps<typeof ProviderForm>;
|
||||
deferred: Deferred<void>;
|
||||
}) => {
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [submitError, setSubmitError] = useState<unknown>();
|
||||
|
||||
return (
|
||||
<ProviderForm
|
||||
{...args}
|
||||
isLoading={isLoading}
|
||||
submitError={submitError}
|
||||
onSubmit={async (values) => {
|
||||
args.onSubmit?.(values);
|
||||
setIsLoading(true);
|
||||
await deferred.promise;
|
||||
setSubmitError(new Error(errorSubmitMessage));
|
||||
setIsLoading(false);
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const ExternalLoadingProviderForm = ({
|
||||
args,
|
||||
deferred,
|
||||
}: {
|
||||
args: ComponentProps<typeof ProviderForm>;
|
||||
deferred: Deferred<void>;
|
||||
}) => {
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
|
||||
return (
|
||||
<>
|
||||
<ProviderForm {...args} isLoading={isLoading} />
|
||||
<button
|
||||
type="button"
|
||||
onClick={async () => {
|
||||
setIsLoading(true);
|
||||
await deferred.promise;
|
||||
setIsLoading(false);
|
||||
}}
|
||||
>
|
||||
Simulate external save
|
||||
</button>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const errorSubmitMessage = "Failed to update provider.";
|
||||
|
||||
let bedrockSubmitDeferred = createDeferred<void>();
|
||||
let apiKeySubmitDeferred = createDeferred<void>();
|
||||
let failedSubmitDeferred = createDeferred<void>();
|
||||
let externalSaveDeferred = createDeferred<void>();
|
||||
|
||||
export const AddAnthropicDefault: Story = {};
|
||||
|
||||
export const AddOpenAI: Story = {
|
||||
@@ -47,6 +131,15 @@ export const AddBedrock: Story = {
|
||||
};
|
||||
|
||||
export const EditBedrockKeepCredentials: Story = {
|
||||
render: (args) => {
|
||||
bedrockSubmitDeferred = createDeferred<void>();
|
||||
return (
|
||||
<SuccessfulSubmitProviderForm
|
||||
args={args}
|
||||
deferred={bedrockSubmitDeferred}
|
||||
/>
|
||||
);
|
||||
},
|
||||
args: {
|
||||
editing: true,
|
||||
bedrockSavedAccessCredentials: true,
|
||||
@@ -62,6 +155,59 @@ export const EditBedrockKeepCredentials: Story = {
|
||||
enabled: true,
|
||||
},
|
||||
},
|
||||
play: async ({ canvasElement, args }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const accessKeyInput = await canvas.findByLabelText(/^access key\s*\*?$/i);
|
||||
const accessKeySecretInput =
|
||||
await canvas.findByLabelText(/access key secret/i);
|
||||
|
||||
expect(accessKeyInput).toHaveProperty("type", "text");
|
||||
expect(accessKeySecretInput).toHaveProperty("type", "text");
|
||||
expect(accessKeyInput).toHaveValue(SAVED_CREDENTIAL_MASK);
|
||||
expect(accessKeySecretInput).toHaveValue(SAVED_CREDENTIAL_MASK);
|
||||
|
||||
await userEvent.click(accessKeyInput);
|
||||
await waitFor(() => expect(accessKeyInput).toHaveValue(""));
|
||||
await userEvent.click(accessKeySecretInput);
|
||||
await waitFor(() =>
|
||||
expect(accessKeyInput).toHaveValue(SAVED_CREDENTIAL_MASK),
|
||||
);
|
||||
|
||||
await userEvent.click(accessKeyInput);
|
||||
await waitFor(() => expect(accessKeyInput).toHaveValue(""));
|
||||
await userEvent.type(accessKeyInput, "AKIAI1lO0EXAMPLE");
|
||||
expect(accessKeyInput).toHaveValue("AKIAI1lO0EXAMPLE");
|
||||
|
||||
await userEvent.click(accessKeySecretInput);
|
||||
await waitFor(() => expect(accessKeySecretInput).toHaveValue(""));
|
||||
await userEvent.type(accessKeySecretInput, "wJalrI1lO0Secret");
|
||||
expect(accessKeySecretInput).toHaveValue("wJalrI1lO0Secret");
|
||||
|
||||
const displayName = canvas.getByLabelText(/display name/i);
|
||||
await userEvent.clear(displayName);
|
||||
await userEvent.type(displayName, "Updated Bedrock");
|
||||
|
||||
const submitButton = canvas.getByRole("button", {
|
||||
name: /update provider/i,
|
||||
});
|
||||
await waitFor(() => expect(submitButton).toBeEnabled());
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
await waitFor(() =>
|
||||
expect(args.onSubmit).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
accessKey: "AKIAI1lO0EXAMPLE",
|
||||
accessKeySecret: "wJalrI1lO0Secret",
|
||||
}),
|
||||
),
|
||||
);
|
||||
await waitFor(() => expect(submitButton).toBeDisabled());
|
||||
bedrockSubmitDeferred.resolve();
|
||||
await waitFor(() => {
|
||||
expect(accessKeyInput).toHaveValue(SAVED_CREDENTIAL_MASK);
|
||||
expect(accessKeySecretInput).toHaveValue(SAVED_CREDENTIAL_MASK);
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
export const AddCopilot: Story = {
|
||||
@@ -141,6 +287,134 @@ export const Submitting: Story = {
|
||||
};
|
||||
|
||||
export const CredentialFocusClear: Story = {
|
||||
render: (args) => {
|
||||
apiKeySubmitDeferred = createDeferred<void>();
|
||||
return (
|
||||
<SuccessfulSubmitProviderForm
|
||||
args={args}
|
||||
deferred={apiKeySubmitDeferred}
|
||||
/>
|
||||
);
|
||||
},
|
||||
args: {
|
||||
editing: true,
|
||||
openAiAnthropicSavedApiKey: true,
|
||||
openAiAnthropicMaskedApiKey: "sk-ant-***\u2026***ABCD",
|
||||
initialValues: {
|
||||
type: "anthropic",
|
||||
name: "production-anthropic",
|
||||
displayName: "Production Anthropic",
|
||||
baseUrl: "https://api.anthropic.com",
|
||||
apiKey: "",
|
||||
enabled: true,
|
||||
},
|
||||
},
|
||||
play: async ({ canvasElement, args }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const apiKeyInput = await canvas.findByLabelText(/api key/i);
|
||||
|
||||
expect(apiKeyInput).toHaveProperty("type", "text");
|
||||
expect(apiKeyInput).toHaveValue("sk-ant-***\u2026***ABCD");
|
||||
|
||||
await userEvent.click(apiKeyInput);
|
||||
await waitFor(() => expect(apiKeyInput).toHaveValue(""));
|
||||
|
||||
const displayName = canvas.getByLabelText(/display name/i);
|
||||
await userEvent.click(displayName);
|
||||
await waitFor(() =>
|
||||
expect(apiKeyInput).toHaveValue("sk-ant-***\u2026***ABCD"),
|
||||
);
|
||||
|
||||
await userEvent.click(apiKeyInput);
|
||||
await waitFor(() => expect(apiKeyInput).toHaveValue(""));
|
||||
await userEvent.type(apiKeyInput, "sk-ant-I1lO0-new-secret");
|
||||
expect(apiKeyInput).toHaveValue("sk-ant-I1lO0-new-secret");
|
||||
|
||||
await userEvent.clear(displayName);
|
||||
await userEvent.type(displayName, "Updated Anthropic");
|
||||
|
||||
const submitButton = canvas.getByRole("button", {
|
||||
name: /update provider/i,
|
||||
});
|
||||
await waitFor(() => expect(submitButton).toBeEnabled());
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
await waitFor(() =>
|
||||
expect(args.onSubmit).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
apiKey: "sk-ant-I1lO0-new-secret",
|
||||
}),
|
||||
),
|
||||
);
|
||||
await waitFor(() => expect(submitButton).toBeDisabled());
|
||||
apiKeySubmitDeferred.resolve();
|
||||
await waitFor(() =>
|
||||
expect(apiKeyInput).toHaveValue("sk-ant-***\u2026***ABCD"),
|
||||
);
|
||||
},
|
||||
};
|
||||
export const FailedSubmitKeepsCredential: Story = {
|
||||
render: (args) => {
|
||||
failedSubmitDeferred = createDeferred<void>();
|
||||
return (
|
||||
<FailedSubmitProviderForm args={args} deferred={failedSubmitDeferred} />
|
||||
);
|
||||
},
|
||||
args: {
|
||||
editing: true,
|
||||
openAiAnthropicSavedApiKey: true,
|
||||
openAiAnthropicMaskedApiKey: "sk-ant-***\u2026***ABCD",
|
||||
initialValues: {
|
||||
type: "anthropic",
|
||||
name: "production-anthropic",
|
||||
displayName: "Production Anthropic",
|
||||
baseUrl: "https://api.anthropic.com",
|
||||
apiKey: "",
|
||||
enabled: true,
|
||||
},
|
||||
},
|
||||
play: async ({ canvasElement, args }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const apiKeyInput = await canvas.findByLabelText(/api key/i);
|
||||
|
||||
await userEvent.click(apiKeyInput);
|
||||
await waitFor(() => expect(apiKeyInput).toHaveValue(""));
|
||||
await userEvent.type(apiKeyInput, "sk-ant-I1lO0-new-secret");
|
||||
|
||||
const displayName = canvas.getByLabelText(/display name/i);
|
||||
await userEvent.clear(displayName);
|
||||
await userEvent.type(displayName, "Failed Anthropic");
|
||||
|
||||
const submitButton = canvas.getByRole("button", {
|
||||
name: /update provider/i,
|
||||
});
|
||||
await waitFor(() => expect(submitButton).toBeEnabled());
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
await waitFor(() =>
|
||||
expect(args.onSubmit).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
apiKey: "sk-ant-I1lO0-new-secret",
|
||||
}),
|
||||
),
|
||||
);
|
||||
await waitFor(() => expect(submitButton).toBeDisabled());
|
||||
failedSubmitDeferred.resolve();
|
||||
await expect(await canvas.findByText(errorSubmitMessage)).toBeVisible();
|
||||
expect(apiKeyInput).toHaveValue("sk-ant-I1lO0-new-secret");
|
||||
},
|
||||
};
|
||||
|
||||
export const ExternalLoadingKeepsCredential: Story = {
|
||||
render: (args) => {
|
||||
externalSaveDeferred = createDeferred<void>();
|
||||
return (
|
||||
<ExternalLoadingProviderForm
|
||||
args={args}
|
||||
deferred={externalSaveDeferred}
|
||||
/>
|
||||
);
|
||||
},
|
||||
args: {
|
||||
editing: true,
|
||||
openAiAnthropicSavedApiKey: true,
|
||||
@@ -157,11 +431,25 @@ export const CredentialFocusClear: Story = {
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const apiKeyInput = await canvas.findByLabelText(/api key/i);
|
||||
expect(apiKeyInput).toHaveValue("sk-ant-***\u2026***ABCD");
|
||||
const submitButton = canvas.getByRole("button", {
|
||||
name: /update provider/i,
|
||||
});
|
||||
|
||||
await userEvent.click(apiKeyInput);
|
||||
await waitFor(() => expect(apiKeyInput).toHaveValue(""));
|
||||
await userEvent.type(apiKeyInput, "sk-ant-I1lO0-new-secret");
|
||||
await waitFor(() => expect(submitButton).toBeEnabled());
|
||||
|
||||
await userEvent.click(
|
||||
canvas.getByRole("button", { name: /simulate external save/i }),
|
||||
);
|
||||
await waitFor(() => expect(submitButton).toBeDisabled());
|
||||
externalSaveDeferred.resolve();
|
||||
await waitFor(() => expect(submitButton).toBeEnabled());
|
||||
expect(apiKeyInput).toHaveValue("sk-ant-I1lO0-new-secret");
|
||||
},
|
||||
};
|
||||
|
||||
export const UnsavedChangesPrompt: Story = {
|
||||
args: {
|
||||
editing: true,
|
||||
|
||||
@@ -259,6 +259,21 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
const typeDefaults =
|
||||
providerDefaults[resolvedType as keyof typeof providerDefaults];
|
||||
|
||||
// Seed Bedrock credentials with the mask when on file; focus clears it,
|
||||
// and a re-submitted "" tells the API mapping to keep the value.
|
||||
const maskedAccessKey = bedrockSavedAccessCredentials
|
||||
? SAVED_CREDENTIAL_MASK
|
||||
: "";
|
||||
const maskedAccessKeySecret = bedrockSavedAccessCredentials
|
||||
? SAVED_CREDENTIAL_MASK
|
||||
: "";
|
||||
// Same pattern for openai/anthropic. Prefer the API-supplied masked
|
||||
// rendering so the user sees the key's identifying suffix.
|
||||
const maskedApiKey = openAiAnthropicSavedApiKey
|
||||
? (openAiAnthropicMaskedApiKey ?? SAVED_CREDENTIAL_MASK)
|
||||
: "";
|
||||
|
||||
const didSubmit = useRef(false);
|
||||
const form = useFormik<ProviderFormValues>({
|
||||
initialValues: {
|
||||
...defaultInitialValues,
|
||||
@@ -266,21 +281,16 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
// Edit overrides prefills with server values; create gets them as-is.
|
||||
...(typeDefaults ?? {}),
|
||||
...initialValues,
|
||||
// Seed Bedrock credentials with the mask when on file; focus clears it,
|
||||
// and a re-submitted "" tells the API mapping to keep the value.
|
||||
accessKey: bedrockSavedAccessCredentials ? SAVED_CREDENTIAL_MASK : "",
|
||||
accessKeySecret: bedrockSavedAccessCredentials
|
||||
? SAVED_CREDENTIAL_MASK
|
||||
: "",
|
||||
// Same pattern for openai/anthropic. Prefer the API-supplied masked
|
||||
// rendering so the user sees the key's identifying suffix.
|
||||
apiKey: openAiAnthropicSavedApiKey
|
||||
? (openAiAnthropicMaskedApiKey ?? SAVED_CREDENTIAL_MASK)
|
||||
: "",
|
||||
accessKey: maskedAccessKey,
|
||||
accessKeySecret: maskedAccessKeySecret,
|
||||
apiKey: maskedApiKey,
|
||||
},
|
||||
validationSchema: getProviderFormSchema(editing),
|
||||
validateOnMount: true,
|
||||
onSubmit: onSubmit ?? (() => {}),
|
||||
onSubmit: (values) => {
|
||||
didSubmit.current = true;
|
||||
return onSubmit?.(values);
|
||||
},
|
||||
});
|
||||
const getFieldHelpers = getFormHelpers(form, submitError);
|
||||
|
||||
@@ -297,17 +307,46 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
}
|
||||
};
|
||||
|
||||
// Restores the mask when the user leaves the field without entering
|
||||
// a new value, keeping the saved-credential appearance.
|
||||
const handleCredentialBlur = (
|
||||
field: "apiKey" | "accessKey" | "accessKeySecret",
|
||||
) => {
|
||||
const initial = form.initialValues[field];
|
||||
if (form.values[field] === "" && initial !== "") {
|
||||
void form.setFieldValue(field, initial);
|
||||
}
|
||||
};
|
||||
|
||||
// When the parent's mutation finishes without an error, treat the just-
|
||||
// submitted values as the new baseline so the unsaved-changes prompt does
|
||||
// not fire on subsequent navigations. React Query reports a missing error
|
||||
// as `null`, so a truthy check covers both null and undefined.
|
||||
const previousIsLoading = useRef(isLoading);
|
||||
useEffect(() => {
|
||||
if (previousIsLoading.current && !isLoading && !submitError) {
|
||||
form.resetForm({ values: form.values });
|
||||
if (previousIsLoading.current && !isLoading) {
|
||||
if (didSubmit.current && !submitError) {
|
||||
// Restore credential fields to their initial masked sentinels so
|
||||
// the raw key is never left visible after a successful save.
|
||||
const remaskedValues = {
|
||||
...form.values,
|
||||
apiKey: maskedApiKey,
|
||||
accessKey: maskedAccessKey,
|
||||
accessKeySecret: maskedAccessKeySecret,
|
||||
};
|
||||
form.resetForm({ values: remaskedValues });
|
||||
}
|
||||
didSubmit.current = false;
|
||||
}
|
||||
previousIsLoading.current = isLoading;
|
||||
}, [isLoading, submitError, form]);
|
||||
}, [
|
||||
isLoading,
|
||||
submitError,
|
||||
form,
|
||||
maskedApiKey,
|
||||
maskedAccessKey,
|
||||
maskedAccessKeySecret,
|
||||
]);
|
||||
|
||||
const unsavedChanges = useUnsavedChangesPrompt(
|
||||
form.dirty && !form.isSubmitting,
|
||||
@@ -367,6 +406,7 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
required
|
||||
label="API key"
|
||||
helpers={getFieldHelpers("apiKey")}
|
||||
onBlur={() => handleCredentialBlur("apiKey")}
|
||||
onFocus={() => handleCredentialFocus("apiKey")}
|
||||
autoComplete="new-password"
|
||||
placeholder={apiKeyPlaceholder(form.values.type)}
|
||||
@@ -430,12 +470,15 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
required
|
||||
label="Access key"
|
||||
helpers={getFieldHelpers("accessKey")}
|
||||
onBlur={() => handleCredentialBlur("accessKey")}
|
||||
onFocus={() => handleCredentialFocus("accessKey")}
|
||||
autoComplete="new-password"
|
||||
/>
|
||||
<CredentialField
|
||||
required
|
||||
label="Access key secret"
|
||||
helpers={getFieldHelpers("accessKeySecret")}
|
||||
onBlur={() => handleCredentialBlur("accessKeySecret")}
|
||||
onFocus={() => handleCredentialFocus("accessKeySecret")}
|
||||
autoComplete="new-password"
|
||||
/>
|
||||
|
||||
@@ -2,6 +2,7 @@ import { act, renderHook } from "@testing-library/react";
|
||||
import { createRef } from "react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import type { ChatQueuedMessage } from "#/api/typesGenerated";
|
||||
import { createDeferred } from "#/testHelpers/deferred";
|
||||
import { MockUserOwner, MockWorkspace } from "#/testHelpers/entities";
|
||||
import {
|
||||
draftInputStorageKeyPrefix,
|
||||
@@ -79,22 +80,6 @@ const setMobileViewport = (isMobile: boolean) => {
|
||||
});
|
||||
};
|
||||
|
||||
type Deferred<T> = {
|
||||
promise: Promise<T>;
|
||||
resolve: (value: T | PromiseLike<T>) => void;
|
||||
reject: (reason?: unknown) => void;
|
||||
};
|
||||
|
||||
const createDeferred = <T>(): Deferred<T> => {
|
||||
let resolve!: (value: T | PromiseLike<T>) => void;
|
||||
let reject!: (reason?: unknown) => void;
|
||||
const promise = new Promise<T>((res, rej) => {
|
||||
resolve = res;
|
||||
reject = rej;
|
||||
});
|
||||
return { promise, resolve, reject };
|
||||
};
|
||||
|
||||
describe("getWorkspaceOptionsWithLinkedWorkspace", () => {
|
||||
it("includes a missing linked workspace only when the current user owns it", () => {
|
||||
const existingWorkspace = {
|
||||
|
||||
+29
@@ -106,6 +106,35 @@ describe("applyKnownModelDefaults", () => {
|
||||
expect(result.appliedFields).toEqual([]);
|
||||
});
|
||||
|
||||
it("populates display name with the Known Model display name", () => {
|
||||
const result = applyDefaults({
|
||||
values: buildInitialModelFormValues(),
|
||||
initialValues: buildInitialModelFormValues(),
|
||||
provider: "anthropic",
|
||||
knownModel: requireKnownModel("anthropic", "claude-opus-4-8"),
|
||||
});
|
||||
|
||||
expect(result.values.displayName).toBe("Claude Opus 4.8");
|
||||
expect(result.appliedFields).toContain("displayName");
|
||||
});
|
||||
|
||||
it("skips display name when current value differs from initial value", () => {
|
||||
const values = setPath(
|
||||
buildInitialModelFormValues(),
|
||||
"displayName",
|
||||
"My Custom Name",
|
||||
);
|
||||
const result = applyDefaults({
|
||||
values,
|
||||
initialValues: buildInitialModelFormValues(),
|
||||
provider: "anthropic",
|
||||
knownModel: requireKnownModel("anthropic", "claude-opus-4-8"),
|
||||
});
|
||||
|
||||
expect(result.values.displayName).toBe("My Custom Name");
|
||||
expect(result.appliedFields).not.toContain("displayName");
|
||||
});
|
||||
|
||||
it("populates context limit when current value still equals initial value", () => {
|
||||
const result = applyDefaults({
|
||||
values: buildInitialModelFormValues(),
|
||||
|
||||
+9
@@ -83,6 +83,15 @@ export const applyKnownModelDefaults = ({
|
||||
const nextValues = structuredClone(values);
|
||||
const appliedFields: string[] = [];
|
||||
|
||||
maybeApplyDefault({
|
||||
appliedFields,
|
||||
initialValues,
|
||||
nextValues,
|
||||
path: "displayName",
|
||||
value: knownModel.displayName,
|
||||
values,
|
||||
});
|
||||
|
||||
if (knownModel.contextLimit !== undefined) {
|
||||
maybeApplyDefault({
|
||||
appliedFields,
|
||||
|
||||
@@ -1,28 +1,13 @@
|
||||
import { act, renderHook } from "@testing-library/react";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { API } from "#/api/api";
|
||||
import { createDeferred } from "#/testHelpers/deferred";
|
||||
import { chatDraftAttachmentStorageKey } from "../utils/chatDraftAttachmentStorage";
|
||||
import {
|
||||
resetChatDraftAttachmentRegistryForTest,
|
||||
useChatDraftAttachments,
|
||||
} from "./useChatDraftAttachments";
|
||||
|
||||
type Deferred<T> = {
|
||||
promise: Promise<T>;
|
||||
resolve: (value: T | PromiseLike<T>) => void;
|
||||
reject: (reason?: unknown) => void;
|
||||
};
|
||||
|
||||
const createDeferred = <T>(): Deferred<T> => {
|
||||
let resolve!: (value: T | PromiseLike<T>) => void;
|
||||
let reject!: (reason?: unknown) => void;
|
||||
const promise = new Promise<T>((res, rej) => {
|
||||
resolve = res;
|
||||
reject = rej;
|
||||
});
|
||||
return { promise, resolve, reject };
|
||||
};
|
||||
|
||||
const orgID = "org-1";
|
||||
const chatID = "chat-a";
|
||||
const storageKey = chatDraftAttachmentStorageKey(orgID, chatID);
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
export type Deferred<T> = {
|
||||
promise: Promise<T>;
|
||||
resolve: (value: T | PromiseLike<T>) => void;
|
||||
reject: (reason?: unknown) => void;
|
||||
};
|
||||
|
||||
export const createDeferred = <T>(): Deferred<T> => {
|
||||
let resolve!: (value: T | PromiseLike<T>) => void;
|
||||
let reject!: (reason?: unknown) => void;
|
||||
const promise = new Promise<T>((res, rej) => {
|
||||
resolve = res;
|
||||
reject = rej;
|
||||
});
|
||||
return { promise, resolve, reject };
|
||||
};
|
||||
Reference in New Issue
Block a user