Merge branch 'main' into fix/codagt-517-testagent-stats-ssh

This commit is contained in:
Mathias Fredriksson
2026-06-02 14:11:18 +03:00
committed by GitHub
25 changed files with 976 additions and 555 deletions
+13 -18
View File
@@ -166,23 +166,18 @@ runs:
mise_dir: ${{ steps.mise-data-dir.outputs.path }} mise_dir: ${{ steps.mise-data-dir.outputs.path }}
install_args: ${{ steps.cache-key.outputs.install-args }} install_args: ${{ steps.cache-key.outputs.install-args }}
cache: "false" cache: "false"
# Do not export mise's resolved env (every tool install dir) into
# GITHUB_ENV. Tools resolve through the shims dir on GITHUB_PATH, so
# the export only bloats PATH. On Windows the mise go shim re-prepends
# those dirs at invocation, and the resulting PATH crosses cmd.exe's
# ~8191 character limit, which makes cmd.exe drop PATH entirely and
# fail to resolve native executables in subprocesses spawned by tests.
env: false
- name: Ensure Git usr/bin is in PATH (Windows) - name: Add Git usr/bin to PATH (Windows)
if: runner.os == 'Windows' if: runner.os == 'Windows'
shell: pwsh shell: bash
# jdx/mise-action exports "Path" via GITHUB_ENV which may # GITHUB_PATH is the casing-safe channel and keeps the entry short.
# collide with bash's "PATH". Ensure Git usr/bin is present # cmd.exe subprocesses spawned by Go tests need MSYS coreutils such as
# and remove any duplicate Path/PATH entries from GITHUB_ENV # printf, which live here.
# by writing both forms. run: echo "C:\Program Files\Git\usr\bin" >> "$GITHUB_PATH"
run: | # zizmor: ignore[github-env]
$gitdir = "C:\Program Files\Git\usr\bin"
$current = $env:Path
if ($current -notlike "*$gitdir*") {
$current = "$gitdir;$current"
}
# Write both Path and PATH to GITHUB_ENV so that both
# cmd.exe (uses Path) and bash/Go (uses PATH) see the
# same value including Git usr/bin.
"Path=$current" | Out-File -Append -FilePath $env:GITHUB_ENV -Encoding utf8
"PATH=$current" | Out-File -Append -FilePath $env:GITHUB_ENV -Encoding utf8
+3 -1
View File
@@ -236,6 +236,9 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
traceAttrs := interceptor.TraceAttributes(r) traceAttrs := interceptor.TraceAttributes(r)
span.SetAttributes(traceAttrs...) span.SetAttributes(traceAttrs...)
ctx = tracing.WithInterceptionAttributesInContext(ctx, traceAttrs) ctx = tracing.WithInterceptionAttributesInContext(ctx, traceAttrs)
// Attach the interception ID to the context so every log line
// emitted with this context can be correlated to the interception.
ctx = slog.With(ctx, slog.F("interception_id", interceptor.ID()))
r = r.WithContext(ctx) r = r.WithContext(ctx)
// Record usage in the background to not block request flow. // Record usage in the background to not block request flow.
@@ -272,7 +275,6 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
log := logger.With( log := logger.With(
slog.F("route", route), slog.F("route", route),
slog.F("provider", p.Name()), slog.F("provider", p.Name()),
slog.F("interception_id", interceptor.ID()),
slog.F("user_agent", r.UserAgent()), slog.F("user_agent", r.UserAgent()),
slog.F("streaming", interceptor.Streaming()), slog.F("streaming", interceptor.Streaming()),
slog.F("credential_kind", string(cred.Kind)), slog.F("credential_kind", string(cred.Kind)),
+6 -6
View File
@@ -40,7 +40,7 @@ func (r *WrappedRecorder) RecordInterception(ctx context.Context, req *Intercept
return nil return nil
} }
r.logger.Warn(ctx, "failed to record interception", slog.Error(err), slog.F("interception_id", req.ID)) r.logger.Warn(ctx, "failed to record interception", slog.Error(err))
return err return err
} }
@@ -58,7 +58,7 @@ func (r *WrappedRecorder) RecordInterceptionEnded(ctx context.Context, req *Inte
return nil return nil
} }
r.logger.Warn(ctx, "failed to record that interception ended", slog.Error(err), slog.F("interception_id", req.ID)) r.logger.Warn(ctx, "failed to record that interception ended", slog.Error(err))
return err return err
} }
@@ -76,7 +76,7 @@ func (r *WrappedRecorder) RecordPromptUsage(ctx context.Context, req *PromptUsag
return nil return nil
} }
r.logger.Warn(ctx, "failed to record prompt usage", slog.Error(err), slog.F("interception_id", req.InterceptionID)) r.logger.Warn(ctx, "failed to record prompt usage", slog.Error(err))
return err return err
} }
@@ -94,7 +94,7 @@ func (r *WrappedRecorder) RecordTokenUsage(ctx context.Context, req *TokenUsageR
return nil return nil
} }
r.logger.Warn(ctx, "failed to record token usage", slog.Error(err), slog.F("interception_id", req.InterceptionID)) r.logger.Warn(ctx, "failed to record token usage", slog.Error(err))
return err return err
} }
@@ -112,7 +112,7 @@ func (r *WrappedRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRec
return nil return nil
} }
r.logger.Warn(ctx, "failed to record tool usage", slog.Error(err), slog.F("interception_id", req.InterceptionID)) r.logger.Warn(ctx, "failed to record tool usage", slog.Error(err))
return err return err
} }
@@ -130,7 +130,7 @@ func (r *WrappedRecorder) RecordModelThought(ctx context.Context, req *ModelThou
return nil return nil
} }
r.logger.Warn(ctx, "failed to record model thought", slog.Error(err), slog.F("interception_id", req.InterceptionID)) r.logger.Warn(ctx, "failed to record model thought", slog.Error(err))
return err return err
} }
+1 -3
View File
@@ -146,10 +146,8 @@ func TestWorkspaceAgent(t *testing.T) {
}).WithAgent().Do() }).WithAgent().Do()
coderURLEnv := "$CODER_URL" coderURLEnv := "$CODER_URL"
headerCmd := "printf X-Process-Testing=very-wow-" + coderURLEnv + "'\\r\\n'X-Process-Testing2=more-wow"
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
coderURLEnv = "%CODER_URL%" coderURLEnv = "%CODER_URL%"
headerCmd = "echo X-Process-Testing=very-wow-" + coderURLEnv + "& echo X-Process-Testing2=more-wow"
} }
logDir := t.TempDir() logDir := t.TempDir()
@@ -161,7 +159,7 @@ func TestWorkspaceAgent(t *testing.T) {
"--log-dir", logDir, "--log-dir", logDir,
"--agent-header", "X-Testing=agent", "--agent-header", "X-Testing=agent",
"--agent-header", "Cool-Header=Ethan was Here!", "--agent-header", "Cool-Header=Ethan was Here!",
"--agent-header-command", headerCmd, "--agent-header-command", "printf X-Process-Testing=very-wow-"+coderURLEnv+"'\\r\\n'X-Process-Testing2=more-wow",
"--socket-path", testutil.AgentSocketPath(t), "--socket-path", testutil.AgentSocketPath(t),
) )
clitest.Start(t, agentInv) clitest.Start(t, agentInv)
+2 -9
View File
@@ -229,15 +229,8 @@ func Test_sshConfigMatchExecEscape(t *testing.T) {
// OpenSSH processes %% escape sequences into % // OpenSSH processes %% escape sequences into %
escaped = strings.ReplaceAll(escaped, "%%", "%") escaped = strings.ReplaceAll(escaped, "%%", "%")
c := exec.Command(cmd, arg, escaped) //nolint:gosec b, err := exec.Command(cmd, arg, escaped).CombinedOutput() //nolint:gosec
if runtime.GOOS == "windows" { require.NoError(t, err)
// Deduplicate Path/PATH env vars so cmd.exe
// subprocesses (like powershell.exe used for
// paths with spaces) resolve correctly.
c.Env = appendAndDedupEnv(os.Environ())
}
b, err := c.CombinedOutput()
require.NoError(t, err, "command output: %s", string(b))
got := strings.TrimSpace(string(b)) got := strings.TrimSpace(string(b))
require.Equal(t, "yay", got) require.Equal(t, "yay", got)
}) })
+1 -9
View File
@@ -4,8 +4,6 @@ package cli
import ( import (
"fmt" "fmt"
"os"
"path/filepath"
"strings" "strings"
"golang.org/x/xerrors" "golang.org/x/xerrors"
@@ -52,13 +50,7 @@ func sshConfigMatchExecEscape(path string) (string, error) {
if strings.ContainsAny(path, " ") { if strings.ContainsAny(path, " ") {
// c.f. function comment for how this works. // c.f. function comment for how this works.
// Use absolute paths for powershell.exe and cmd.exe path = fmt.Sprintf("for /f %%%%a in ('powershell.exe -Command [char]34') do @cmd.exe /c %%%%a%s%%%%a", path) //nolint:gocritic // We don't want %q here.
// to avoid PATH resolution issues when both Path and
// PATH (MSYS-translated) exist in the environment.
sysRoot := os.Getenv("SYSTEMROOT")
pwsh := filepath.Join(sysRoot, "System32", "WindowsPowerShell", "v1.0", "powershell.exe")
cmd := filepath.Join(sysRoot, "System32", "cmd.exe")
path = fmt.Sprintf("for /f %%%%a in ('%s -Command [char]34') do @%s /c %%%%a%s%%%%a", pwsh, cmd, path) //nolint:gocritic // We don't want %q here.
} }
return path, nil return path, nil
} }
+2 -39
View File
@@ -1701,44 +1701,7 @@ func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r(req) return r(req)
} }
// appendAndDedupEnv appends extra environment variables and // HeaderTransport creates a new transport that executes `--header-command`
// deduplicates entries with the same key (case-insensitive on
// Windows). For the PATH variable specifically, it prefers the
// value that contains native Windows paths (with backslashes)
// over MSYS-translated paths (with forward slashes). For all
// other variables, the last value wins.
func appendAndDedupEnv(env []string, extra ...string) []string {
env = append(env, extra...)
if runtime.GOOS != "windows" {
return env
}
seen := make(map[string]int, len(env))
result := make([]string, 0, len(env))
for _, e := range env {
key, val, ok := strings.Cut(e, "=")
if !ok {
result = append(result, e)
continue
}
upper := strings.ToUpper(key)
if idx, exists := seen[upper]; exists {
if upper == "PATH" {
// Prefer the value with native Windows paths.
existingVal := result[idx][len(key)+1:]
if strings.Contains(existingVal, "\\") && !strings.Contains(val, "\\") {
continue
}
}
result[idx] = e
continue
}
seen[upper] = len(result)
result = append(result, e)
}
return result
}
// headerTransport creates a new transport that executes `--header-command`
// if it is set to add headers for all outbound requests. // if it is set to add headers for all outbound requests.
func headerTransport(ctx context.Context, serverURL *url.URL, header []string, headerCommand string) (*codersdk.HeaderTransport, error) { func headerTransport(ctx context.Context, serverURL *url.URL, header []string, headerCommand string) (*codersdk.HeaderTransport, error) {
transport := &codersdk.HeaderTransport{ transport := &codersdk.HeaderTransport{
@@ -1756,7 +1719,7 @@ func headerTransport(ctx context.Context, serverURL *url.URL, header []string, h
var outBuf bytes.Buffer var outBuf bytes.Buffer
// #nosec // #nosec
cmd := exec.CommandContext(ctx, shell, caller, headerCommand) cmd := exec.CommandContext(ctx, shell, caller, headerCommand)
cmd.Env = appendAndDedupEnv(os.Environ(), "CODER_URL="+serverURL.String()) cmd.Env = append(os.Environ(), "CODER_URL="+serverURL.String())
cmd.Stdout = &outBuf cmd.Stdout = &outBuf
cmd.Stderr = io.Discard cmd.Stderr = io.Discard
err := cmd.Run() err := cmd.Run()
+2 -4
View File
@@ -177,17 +177,15 @@ func TestRoot(t *testing.T) {
url = srv.URL url = srv.URL
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
coderURLEnv := "$CODER_URL" coderURLEnv := "$CODER_URL"
headerCmd := "printf X-Process-Testing=very-wow-" + coderURLEnv + "'\\r\\n'X-Process-Testing2=more-wow"
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
coderURLEnv = "%CODER_URL%" coderURLEnv = "%CODER_URL%"
headerCmd = "echo X-Process-Testing=very-wow-" + coderURLEnv + "& echo X-Process-Testing2=more-wow"
} }
inv, _ := clitest.New(t, inv, _ := clitest.New(t,
"--no-feature-warning", "--no-feature-warning",
"--no-version-warning", "--no-version-warning",
"--header", "X-Testing=wow", "--header", "X-Testing=wow",
"--header", "Cool-Header=Dean was Here!", "--header", "Cool-Header=Dean was Here!",
"--header-command", headerCmd, "--header-command", "printf X-Process-Testing=very-wow-"+coderURLEnv+"'\\r\\n'X-Process-Testing2=more-wow",
"login", srv.URL, "login", srv.URL,
) )
inv.Stdout = buf inv.Stdout = buf
@@ -268,7 +266,7 @@ func TestDERPHeaders(t *testing.T) {
"--no-version-warning", "--no-version-warning",
"ping", workspace.Name, "ping", workspace.Name,
"-n", "1", "-n", "1",
"--header-command", "echo X-Process-Testing=very-wow", "--header-command", "printf X-Process-Testing=very-wow",
} }
for k, v := range expectedHeaders { for k, v := range expectedHeaders {
if k != "X-Process-Testing" { if k != "X-Process-Testing" {
+20 -14
View File
@@ -6468,22 +6468,14 @@ type runChatResult struct {
HistoryTipMessageID int64 HistoryTipMessageID int64
} }
func contextWithActiveTurnAPIKeyID(ctx context.Context, messages []database.ChatMessage) context.Context {
apiKeyID, ok := activeTurnAPIKeyIDFromMessages(messages)
if !ok {
return ctx
}
return aibridge.WithDelegatedAPIKeyID(ctx, apiKeyID)
}
func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) { func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) {
for i := len(messages) - 1; i >= 0; i-- { for i := len(messages) - 1; i >= 0; i-- {
message := messages[i] message := messages[i]
if message.Role != database.ChatMessageRoleUser { if message.Role != database.ChatMessageRoleUser {
continue continue
} }
if message.Visibility != database.ChatMessageVisibilityBoth && if !isUserVisibleChatMessage(message) &&
message.Visibility != database.ChatMessageVisibilityUser { !(message.Visibility == database.ChatMessageVisibilityModel && message.Compressed) {
continue continue
} }
if !message.APIKeyID.Valid || message.APIKeyID.String == "" { if !message.APIKeyID.Valid || message.APIKeyID.String == "" {
@@ -6494,6 +6486,11 @@ func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bo
return "", false return "", false
} }
func isUserVisibleChatMessage(message database.ChatMessage) bool {
return message.Visibility == database.ChatMessageVisibilityBoth ||
message.Visibility == database.ChatMessageVisibilityUser
}
func allToolNames(allTools []fantasy.AgentTool) []string { func allToolNames(allTools []fantasy.AgentTool) []string {
toolNames := make([]string, 0, len(allTools)) toolNames := make([]string, 0, len(allTools))
for _, tool := range allTools { for _, tool := range allTools {
@@ -7124,7 +7121,9 @@ func (p *Server) runChat(
return result, xerrors.Errorf("get chat messages: %w", err) return result, xerrors.Errorf("get chat messages: %w", err)
} }
modelOpts := modelBuildOptionsFromMessages(messages) modelOpts := modelBuildOptionsFromMessages(messages)
ctx = contextWithActiveTurnAPIKeyID(ctx, messages) if modelOpts.ActiveAPIKeyID != "" {
ctx = aibridge.WithDelegatedAPIKeyID(ctx, modelOpts.ActiveAPIKeyID)
}
// Load MCP server configs and user tokens in parallel with model // Load MCP server configs and user tokens in parallel with model
// resolution. These queries have no dependencies on each other and all // resolution. These queries have no dependencies on each other and all
@@ -7831,6 +7830,7 @@ func (p *Server) runChat(
persistCtx, persistCtx,
chat.ID, chat.ID,
modelConfig.ID, modelConfig.ID,
modelOpts.ActiveAPIKeyID,
compactionToolCallID, compactionToolCallID,
result, result,
); err != nil { ); err != nil {
@@ -8460,12 +8460,14 @@ func buildProviderTools(options *codersdk.ChatModelProviderOptions) []chatloop.P
return tools return tools
} }
// persistChatContextSummary persists a chat context summary to the database. // persistChatContextSummary is called from the chat loop's compaction
// This is invoked via the chat loop's compaction callback. // callback. activeAPIKeyID is stamped onto the summary user message. When
// empty, it falls back to the delegated key in ctx.
func (p *Server) persistChatContextSummary( func (p *Server) persistChatContextSummary(
ctx context.Context, ctx context.Context,
chatID uuid.UUID, chatID uuid.UUID,
modelConfigID uuid.UUID, modelConfigID uuid.UUID,
activeAPIKeyID string,
toolCallID string, toolCallID string,
result chatloop.CompactionResult, result chatloop.CompactionResult,
) error { ) error {
@@ -8514,6 +8516,11 @@ func (p *Server) persistChatContextSummary(
return xerrors.Errorf("encode summary tool result: %w", err) return xerrors.Errorf("encode summary tool result: %w", err)
} }
summaryAPIKeyID := activeAPIKeyID
if summaryAPIKeyID == "" {
summaryAPIKeyID, _ = aibridge.DelegatedAPIKeyIDFromContext(ctx)
}
var insertedMessages []database.ChatMessage var insertedMessages []database.ChatMessage
txErr := p.db.InTx(func(tx database.Store) error { txErr := p.db.InTx(func(tx database.Store) error {
@@ -8522,7 +8529,6 @@ func (p *Server) persistChatContextSummary(
} }
// Hidden summary user message (not published to subscribers). // Hidden summary user message (not published to subscribers).
summaryAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx)
summaryUserMsg := newUserChatMessage( summaryUserMsg := newUserChatMessage(
summaryAPIKeyID, summaryAPIKeyID,
systemContent, systemContent,
+28 -7
View File
@@ -6651,15 +6651,22 @@ func TestPersistChatContextSummarySetsAPIKeyID(t *testing.T) {
UserID: user.ID, UserID: user.ID,
}) })
ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)
server := &Server{db: db} server := &Server{db: db}
persistAndAssertSummaryKey := func(
summaryCtx context.Context,
chatID uuid.UUID,
activeAPIKeyID string,
wantAPIKeyID string,
toolCallID string,
) {
t.Helper()
err := server.persistChatContextSummary( err := server.persistChatContextSummary(
ctx, summaryCtx,
chat.ID, chatID,
modelConfig.ID, modelConfig.ID,
"tool-call-id-1", activeAPIKeyID,
toolCallID,
chatloop.CompactionResult{ chatloop.CompactionResult{
SystemSummary: "summarized context", SystemSummary: "summarized context",
SummaryReport: "context was summarized", SummaryReport: "context was summarized",
@@ -6671,7 +6678,7 @@ func TestPersistChatContextSummarySetsAPIKeyID(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
msgs, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID) msgs, err := db.GetChatMessagesForPromptByChatID(ctx, chatID)
require.NoError(t, err) require.NoError(t, err)
// GetChatMessagesForPromptByChatID uses a compaction boundary CTE // GetChatMessagesForPromptByChatID uses a compaction boundary CTE
@@ -6685,8 +6692,22 @@ func TestPersistChatContextSummarySetsAPIKeyID(t *testing.T) {
if msg.Role == database.ChatMessageRoleUser { if msg.Role == database.ChatMessageRoleUser {
foundUserSummary = true foundUserSummary = true
require.True(t, msg.APIKeyID.Valid, "summary user message must have APIKeyID set") require.True(t, msg.APIKeyID.Valid, "summary user message must have APIKeyID set")
require.Equal(t, apiKey.ID, msg.APIKeyID.String, "summary user message APIKeyID must match") require.Equal(t, wantAPIKeyID, msg.APIKeyID.String, "summary user message APIKeyID must match")
} }
} }
require.True(t, foundUserSummary, "expected to find compressed user summary message") require.True(t, foundUserSummary, "expected to find compressed user summary message")
} }
persistAndAssertSummaryKey(ctx, chat.ID, apiKey.ID, apiKey.ID, "tool-call-id-1")
fallbackChat := dbgen.Chat(t, db, database.Chat{
OwnerID: user.ID,
OrganizationID: org.ID,
LastModelConfigID: modelConfig.ID,
})
fallbackKey, _ := dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
})
fallbackCtx := aibridge.WithDelegatedAPIKeyID(ctx, fallbackKey.ID)
persistAndAssertSummaryKey(fallbackCtx, fallbackChat.ID, "", fallbackKey.ID, "tool-call-id-2")
}
+12 -17
View File
@@ -26,6 +26,7 @@ import (
mcpserver "github.com/mark3labs/mcp-go/server" mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/sqlc-dev/pqtype" "github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"golang.org/x/xerrors" "golang.org/x/xerrors"
@@ -9914,7 +9915,7 @@ func TestAdvisorHappyPath_RootChat(t *testing.T) {
MaxUsesPerRun: 3, MaxUsesPerRun: 3,
MaxOutputTokens: 16384, MaxOutputTokens: 16384,
}) })
server := newActiveTestServer(t, db, ps) server := newTestServer(t, db, ps, uuid.New())
chat, err := server.CreateChat(ctx, chatd.CreateOptions{ chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OrganizationID: org.ID, OrganizationID: org.ID,
@@ -9927,13 +9928,7 @@ func TestAdvisorHappyPath_RootChat(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
// Subscribe before the worker commits any durable messages so we // Advisor deltas are transient; a late subscriber misses them.
// observe the advisor tool-result deltas live. Buffered parts are
// claimed by their committed durable message ID at publishMessage
// time and dropped from snapshots of late-connecting subscribers, so
// a post-completion Subscribe() would no longer see streaming
// deltas. Collecting events from the live channel covers the
// streaming UX contract this test exists to verify.
_, liveEvents, cancelLive, ok := server.Subscribe(ctx, chat.ID, nil, 0) _, liveEvents, cancelLive, ok := server.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok) require.True(t, ok)
var ( var (
@@ -9969,6 +9964,8 @@ func TestAdvisorHappyPath_RootChat(t *testing.T) {
} }
}() }()
server.Start()
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID) got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil { if getErr != nil {
@@ -10023,17 +10020,15 @@ func TestAdvisorHappyPath_RootChat(t *testing.T) {
require.True(t, parentSawAdvisorResult, require.True(t, parentSawAdvisorResult,
"parent must see the advisor reply in its continuation call") "parent must see the advisor reply in its continuation call")
// Stop the live collector and assert it captured the streaming require.EventuallyWithT(t, func(c *assert.CollectT) {
// advisor deltas during processing. Late subscribers no longer livePartsMu.Lock()
// see committed parts because publishMessage claims them out of defer livePartsMu.Unlock()
// new snapshots, so the assertion must use the live collector. assert.Equal(c, advisorDeltas, liveAdvisorDeltas,
"advisor nested text deltas must stream into the parent tool card")
}, testutil.WaitLong, testutil.IntervalFast)
cancelLive() cancelLive()
<-liveCollectorDone <-liveCollectorDone
livePartsMu.Lock()
collectedAdvisorDeltas := append([]string(nil), liveAdvisorDeltas...)
livePartsMu.Unlock()
require.Equal(t, advisorDeltas, collectedAdvisorDeltas,
"advisor nested text deltas must stream into the parent tool card")
persisted, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ persisted, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID, ChatID: chat.ID,
+110 -8
View File
@@ -405,7 +405,7 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
}, },
}, },
{ {
name: "SkipsModelOnlyUserMessages", name: "SkipsUncompressedModelOnlyUserMessages",
messages: []database.ChatMessage{ messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)}, {ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)}, {ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
@@ -413,6 +413,54 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
wantKey: oldKeyID, wantKey: oldKeyID,
wantOK: true, wantOK: true,
}, },
{
name: "CompressedSummaryFallback",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(currentKeyID)},
{ID: 2, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth},
},
wantKey: currentKeyID,
wantOK: true,
},
{
name: "LatestCompressedSummaryWins",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(currentKeyID)},
{ID: 3, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth},
},
wantKey: currentKeyID,
wantOK: true,
},
{
name: "VisibleUserWinsOverCompressedSummary",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(currentKeyID)},
},
wantKey: currentKeyID,
wantOK: true,
},
{
name: "MissingVisibleUserKeyDoesNotFallBackToCompressedSummary",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth},
},
},
{
name: "UncompressedModelOnlyUserIgnored",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
},
},
{
name: "CompressedSummaryMissingKeyDoesNotFallBack",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true},
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@@ -421,15 +469,11 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
gotKey, gotOK := activeTurnAPIKeyIDFromMessages(tt.messages) gotKey, gotOK := activeTurnAPIKeyIDFromMessages(tt.messages)
require.Equal(t, tt.wantOK, gotOK) require.Equal(t, tt.wantOK, gotOK)
require.Equal(t, tt.wantKey, gotKey) require.Equal(t, tt.wantKey, gotKey)
ctx := contextWithActiveTurnAPIKeyID(t.Context(), tt.messages)
ctxKey, ctxOK := aibridge.DelegatedAPIKeyIDFromContext(ctx)
require.Equal(t, tt.wantOK, ctxOK)
require.Equal(t, tt.wantKey, ctxKey)
}) })
} }
} }
func TestActiveTurnContextUsesPromptMessages(t *testing.T) { func TestPromptMessagesForVisibleUserPreserveActiveAPIKeyID(t *testing.T) {
t.Parallel() t.Parallel()
db, _ := dbtestutil.NewDB(t) db, _ := dbtestutil.NewDB(t)
@@ -477,12 +521,70 @@ func TestActiveTurnContextUsesPromptMessages(t *testing.T) {
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID) messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err) require.NoError(t, err)
ctx = contextWithActiveTurnAPIKeyID(ctx, messages) gotKey, ok := activeTurnAPIKeyIDFromMessages(messages)
gotKey, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx)
require.True(t, ok) require.True(t, ok)
require.Equal(t, currentKey.ID, gotKey) require.Equal(t, currentKey.ID, gotKey)
} }
func TestPromptMessagesForCompactedChatPreserveActiveAPIKeyID(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := t.Context()
user := dbgen.User(t, db, database.User{})
org := dbgen.Organization(t, db, database.Organization{})
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
chat := dbgen.Chat(t, db, database.Chat{OrganizationID: org.ID, OwnerID: user.ID, LastModelConfigID: model.ID})
key, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
visibleUser := dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleUser,
Visibility: database.ChatMessageVisibilityBoth,
APIKeyID: sqlNullString(key.ID),
})
dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleAssistant,
Visibility: database.ChatMessageVisibilityBoth,
})
compressedSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleUser,
Visibility: database.ChatMessageVisibilityModel,
Compressed: true,
APIKeyID: sqlNullString(key.ID),
})
afterSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleAssistant,
Visibility: database.ChatMessageVisibilityBoth,
})
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err)
ids := make(map[int64]struct{}, len(messages))
for _, message := range messages {
ids[message.ID] = struct{}{}
}
_, hasVisibleUser := ids[visibleUser.ID]
require.False(t, hasVisibleUser)
_, hasSummary := ids[compressedSummary.ID]
require.True(t, hasSummary)
_, hasAfterSummary := ids[afterSummary.ID]
require.True(t, hasAfterSummary)
gotKey, ok := activeTurnAPIKeyIDFromMessages(messages)
require.True(t, ok)
require.Equal(t, key.ID, gotKey)
}
func sqlNullString(value string) sql.NullString { func sqlNullString(value string) sql.NullString {
return sql.NullString{String: value, Valid: value != ""} return sql.NullString{String: value, Valid: value != ""}
} }
+1 -1
View File
@@ -48,7 +48,7 @@ func Test_ProxyServer_Headers(t *testing.T) {
"--access-url", "http://localhost:8080", "--access-url", "http://localhost:8080",
"--http-address", ":0", "--http-address", ":0",
"--header", fmt.Sprintf("%s=%s", headerName1, headerVal1), "--header", fmt.Sprintf("%s=%s", headerName1, headerVal1),
"--header-command", fmt.Sprintf("echo %s=%s", headerName2, headerVal2), "--header-command", fmt.Sprintf("printf %s=%s", headerName2, headerVal2),
) )
pty := ptytest.New(t) pty := ptytest.New(t)
inv.Stdout = pty.Output() inv.Stdout = pty.Output()
+11 -11
View File
@@ -89,7 +89,7 @@
"lodash": "4.18.1", "lodash": "4.18.1",
"lucide-react": "0.555.0", "lucide-react": "0.555.0",
"monaco-editor": "0.55.1", "monaco-editor": "0.55.1",
"motion": "12.38.0", "motion": "12.40.0",
"pretty-bytes": "6.1.1", "pretty-bytes": "6.1.1",
"radix-ui": "1.4.3", "radix-ui": "1.4.3",
"react": "19.2.6", "react": "19.2.6",
@@ -101,7 +101,7 @@
"react-markdown": "9.1.0", "react-markdown": "9.1.0",
"react-query": "npm:@tanstack/react-query@5.77.0", "react-query": "npm:@tanstack/react-query@5.77.0",
"react-resizable-panels": "3.0.6", "react-resizable-panels": "3.0.6",
"react-router": "7.12.0", "react-router": "7.15.1",
"react-syntax-highlighter": "15.6.6", "react-syntax-highlighter": "15.6.6",
"react-textarea-autosize": "8.5.9", "react-textarea-autosize": "8.5.9",
"react-virtualized-auto-sizer": "1.0.26", "react-virtualized-auto-sizer": "1.0.26",
@@ -111,7 +111,7 @@
"semver": "7.7.3", "semver": "7.7.3",
"sonner": "2.0.7", "sonner": "2.0.7",
"streamdown": "2.5.0", "streamdown": "2.5.0",
"tailwind-merge": "2.6.0", "tailwind-merge": "2.6.1",
"tailwindcss-animate": "1.0.7", "tailwindcss-animate": "1.0.7",
"tzdata": "1.0.46", "tzdata": "1.0.46",
"ua-parser-js": "1.0.41", "ua-parser-js": "1.0.41",
@@ -123,7 +123,7 @@
}, },
"devDependencies": { "devDependencies": {
"@babel/core": "7.29.7", "@babel/core": "7.29.7",
"@babel/plugin-syntax-typescript": "7.28.6", "@babel/plugin-syntax-typescript": "7.29.7",
"@biomejs/biome": "2.4.10", "@biomejs/biome": "2.4.10",
"@chromatic-com/storybook": "5.0.1", "@chromatic-com/storybook": "5.0.1",
"@octokit/types": "12.6.0", "@octokit/types": "12.6.0",
@@ -145,8 +145,8 @@
"@types/express": "4.17.17", "@types/express": "4.17.17",
"@types/file-saver": "2.0.7", "@types/file-saver": "2.0.7",
"@types/humanize-duration": "3.27.4", "@types/humanize-duration": "3.27.4",
"@types/lodash": "4.17.21", "@types/lodash": "4.17.24",
"@types/node": "20.19.39", "@types/node": "20.19.41",
"@types/novnc__novnc": "1.5.0", "@types/novnc__novnc": "1.5.0",
"@types/react": "19.2.15", "@types/react": "19.2.15",
"@types/react-color": "3.0.13", "@types/react-color": "3.0.13",
@@ -158,8 +158,8 @@
"@types/ssh2": "1.15.5", "@types/ssh2": "1.15.5",
"@types/ua-parser-js": "0.7.36", "@types/ua-parser-js": "0.7.36",
"@types/uuid": "9.0.2", "@types/uuid": "9.0.2",
"@vitejs/plugin-react": "6.0.1", "@vitejs/plugin-react": "6.0.2",
"@vitest/browser-playwright": "4.1.1", "@vitest/browser-playwright": "4.1.7",
"autoprefixer": "10.5.0", "autoprefixer": "10.5.0",
"babel-plugin-react-compiler": "1.0.0", "babel-plugin-react-compiler": "1.0.0",
"chromatic": "11.29.0", "chromatic": "11.29.0",
@@ -170,7 +170,7 @@
"jsdom": "27.2.0", "jsdom": "27.2.0",
"knip": "5.71.0", "knip": "5.71.0",
"msw": "2.4.8", "msw": "2.4.8",
"postcss": "8.5.10", "postcss": "8.5.15",
"protobufjs": "7.6.1", "protobufjs": "7.6.1",
"resize-observer-polyfill": "1.5.1", "resize-observer-polyfill": "1.5.1",
"rollup-plugin-visualizer": "7.0.1", "rollup-plugin-visualizer": "7.0.1",
@@ -181,9 +181,9 @@
"tailwindcss": "3.4.18", "tailwindcss": "3.4.18",
"ts-proto": "1.181.2", "ts-proto": "1.181.2",
"typescript": "6.0.2", "typescript": "6.0.2",
"vite": "8.0.10", "vite": "8.0.14",
"vite-plugin-checker": "0.13.0", "vite-plugin-checker": "0.13.0",
"vitest": "4.1.5" "vitest": "4.1.7"
}, },
"browserslist": [ "browserslist": [
"chrome 110", "chrome 110",
+323 -333
View File
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -8,7 +8,7 @@ import type {
const aiProvidersListKey = ["ai", "providers"] as const; const aiProvidersListKey = ["ai", "providers"] as const;
const aiProviderKeyFor = (idOrName: string) => export const aiProviderKeyFor = (idOrName: string) =>
[...aiProvidersListKey, idOrName] as const; [...aiProvidersListKey, idOrName] as const;
export const aiProvidersList = () => ({ export const aiProvidersList = () => ({
@@ -7,6 +7,7 @@ import { toast } from "sonner";
import { getErrorMessage } from "#/api/errors"; import { getErrorMessage } from "#/api/errors";
import { import {
aiProvider, aiProvider,
aiProviderKeyFor,
deleteAIProviderMutation, deleteAIProviderMutation,
updateAIProviderMutation, updateAIProviderMutation,
} from "#/api/queries/aiProviders"; } from "#/api/queries/aiProviders";
@@ -171,6 +172,10 @@ const UpdateProviderPageView: React.FC = () => {
{ enabled: checked }, { enabled: checked },
{ {
onSuccess: (updated) => { onSuccess: (updated) => {
queryClient.setQueryData(
aiProviderKeyFor(providerId),
updated,
);
toast.success( toast.success(
`Provider "${updated.display_name || updated.name}" ${checked ? "enabled" : "disabled"}.`, `Provider "${updated.display_name || updated.name}" ${checked ? "enabled" : "disabled"}.`,
); );
@@ -200,6 +205,7 @@ const UpdateProviderPageView: React.FC = () => {
const request = providerFormValuesToUpdate(values, provider); const request = providerFormValuesToUpdate(values, provider);
try { try {
const updated = await updateMutation.mutateAsync(request); const updated = await updateMutation.mutateAsync(request);
queryClient.setQueryData(aiProviderKeyFor(providerId), updated);
toast.success( toast.success(
`Provider "${updated.display_name || updated.name}" updated.`, `Provider "${updated.display_name || updated.name}" updated.`,
); );
@@ -10,6 +10,7 @@ type CredentialFieldProps = {
placeholder?: string; placeholder?: string;
description?: React.ReactNode; description?: React.ReactNode;
required?: boolean; required?: boolean;
onBlur?: () => void;
onFocus?: () => void; onFocus?: () => void;
}; };
@@ -20,6 +21,7 @@ export const CredentialField: React.FC<CredentialFieldProps> = ({
placeholder, placeholder,
description, description,
required = false, required = false,
onBlur,
onFocus, onFocus,
}) => { }) => {
const inputId = useId(); const inputId = useId();
@@ -62,9 +64,13 @@ export const CredentialField: React.FC<CredentialFieldProps> = ({
<Input <Input
id={inputId} id={inputId}
name={helpers.name} name={helpers.name}
className="font-mono text-[13px]"
value={helpers.value} value={helpers.value}
onChange={helpers.onChange} onChange={helpers.onChange}
onBlur={helpers.onBlur} onBlur={(event) => {
helpers.onBlur(event);
onBlur?.();
}}
onFocus={onFocus} onFocus={onFocus}
autoComplete={autoComplete} autoComplete={autoComplete}
placeholder={placeholder} placeholder={placeholder}
@@ -1,6 +1,8 @@
import type { Meta, StoryObj } from "@storybook/react-vite"; import type { Meta, StoryObj } from "@storybook/react-vite";
import { type ComponentProps, useState } from "react";
import { expect, fn, screen, userEvent, waitFor, within } from "storybook/test"; import { expect, fn, screen, userEvent, waitFor, within } from "storybook/test";
import { ProviderForm } from "./ProviderForm"; import { createDeferred, type Deferred } from "#/testHelpers/deferred";
import { ProviderForm, SAVED_CREDENTIAL_MASK } from "./ProviderForm";
const meta: Meta<typeof ProviderForm> = { const meta: Meta<typeof ProviderForm> = {
title: "pages/AISettingsPage/ProviderForm", title: "pages/AISettingsPage/ProviderForm",
@@ -15,6 +17,88 @@ const meta: Meta<typeof ProviderForm> = {
export default meta; export default meta;
type Story = StoryObj<typeof ProviderForm>; type Story = StoryObj<typeof ProviderForm>;
const SuccessfulSubmitProviderForm = ({
args,
deferred,
}: {
args: ComponentProps<typeof ProviderForm>;
deferred: Deferred<void>;
}) => {
const [isLoading, setIsLoading] = useState(false);
return (
<ProviderForm
{...args}
isLoading={isLoading}
onSubmit={async (values) => {
args.onSubmit?.(values);
setIsLoading(true);
await deferred.promise;
setIsLoading(false);
}}
/>
);
};
const FailedSubmitProviderForm = ({
args,
deferred,
}: {
args: ComponentProps<typeof ProviderForm>;
deferred: Deferred<void>;
}) => {
const [isLoading, setIsLoading] = useState(false);
const [submitError, setSubmitError] = useState<unknown>();
return (
<ProviderForm
{...args}
isLoading={isLoading}
submitError={submitError}
onSubmit={async (values) => {
args.onSubmit?.(values);
setIsLoading(true);
await deferred.promise;
setSubmitError(new Error(errorSubmitMessage));
setIsLoading(false);
}}
/>
);
};
const ExternalLoadingProviderForm = ({
args,
deferred,
}: {
args: ComponentProps<typeof ProviderForm>;
deferred: Deferred<void>;
}) => {
const [isLoading, setIsLoading] = useState(false);
return (
<>
<ProviderForm {...args} isLoading={isLoading} />
<button
type="button"
onClick={async () => {
setIsLoading(true);
await deferred.promise;
setIsLoading(false);
}}
>
Simulate external save
</button>
</>
);
};
const errorSubmitMessage = "Failed to update provider.";
let bedrockSubmitDeferred = createDeferred<void>();
let apiKeySubmitDeferred = createDeferred<void>();
let failedSubmitDeferred = createDeferred<void>();
let externalSaveDeferred = createDeferred<void>();
export const AddAnthropicDefault: Story = {}; export const AddAnthropicDefault: Story = {};
export const AddOpenAI: Story = { export const AddOpenAI: Story = {
@@ -47,6 +131,15 @@ export const AddBedrock: Story = {
}; };
export const EditBedrockKeepCredentials: Story = { export const EditBedrockKeepCredentials: Story = {
render: (args) => {
bedrockSubmitDeferred = createDeferred<void>();
return (
<SuccessfulSubmitProviderForm
args={args}
deferred={bedrockSubmitDeferred}
/>
);
},
args: { args: {
editing: true, editing: true,
bedrockSavedAccessCredentials: true, bedrockSavedAccessCredentials: true,
@@ -62,6 +155,59 @@ export const EditBedrockKeepCredentials: Story = {
enabled: true, enabled: true,
}, },
}, },
play: async ({ canvasElement, args }) => {
const canvas = within(canvasElement);
const accessKeyInput = await canvas.findByLabelText(/^access key\s*\*?$/i);
const accessKeySecretInput =
await canvas.findByLabelText(/access key secret/i);
expect(accessKeyInput).toHaveProperty("type", "text");
expect(accessKeySecretInput).toHaveProperty("type", "text");
expect(accessKeyInput).toHaveValue(SAVED_CREDENTIAL_MASK);
expect(accessKeySecretInput).toHaveValue(SAVED_CREDENTIAL_MASK);
await userEvent.click(accessKeyInput);
await waitFor(() => expect(accessKeyInput).toHaveValue(""));
await userEvent.click(accessKeySecretInput);
await waitFor(() =>
expect(accessKeyInput).toHaveValue(SAVED_CREDENTIAL_MASK),
);
await userEvent.click(accessKeyInput);
await waitFor(() => expect(accessKeyInput).toHaveValue(""));
await userEvent.type(accessKeyInput, "AKIAI1lO0EXAMPLE");
expect(accessKeyInput).toHaveValue("AKIAI1lO0EXAMPLE");
await userEvent.click(accessKeySecretInput);
await waitFor(() => expect(accessKeySecretInput).toHaveValue(""));
await userEvent.type(accessKeySecretInput, "wJalrI1lO0Secret");
expect(accessKeySecretInput).toHaveValue("wJalrI1lO0Secret");
const displayName = canvas.getByLabelText(/display name/i);
await userEvent.clear(displayName);
await userEvent.type(displayName, "Updated Bedrock");
const submitButton = canvas.getByRole("button", {
name: /update provider/i,
});
await waitFor(() => expect(submitButton).toBeEnabled());
await userEvent.click(submitButton);
await waitFor(() =>
expect(args.onSubmit).toHaveBeenCalledWith(
expect.objectContaining({
accessKey: "AKIAI1lO0EXAMPLE",
accessKeySecret: "wJalrI1lO0Secret",
}),
),
);
await waitFor(() => expect(submitButton).toBeDisabled());
bedrockSubmitDeferred.resolve();
await waitFor(() => {
expect(accessKeyInput).toHaveValue(SAVED_CREDENTIAL_MASK);
expect(accessKeySecretInput).toHaveValue(SAVED_CREDENTIAL_MASK);
});
},
}; };
export const AddCopilot: Story = { export const AddCopilot: Story = {
@@ -141,6 +287,134 @@ export const Submitting: Story = {
}; };
export const CredentialFocusClear: Story = { export const CredentialFocusClear: Story = {
render: (args) => {
apiKeySubmitDeferred = createDeferred<void>();
return (
<SuccessfulSubmitProviderForm
args={args}
deferred={apiKeySubmitDeferred}
/>
);
},
args: {
editing: true,
openAiAnthropicSavedApiKey: true,
openAiAnthropicMaskedApiKey: "sk-ant-***\u2026***ABCD",
initialValues: {
type: "anthropic",
name: "production-anthropic",
displayName: "Production Anthropic",
baseUrl: "https://api.anthropic.com",
apiKey: "",
enabled: true,
},
},
play: async ({ canvasElement, args }) => {
const canvas = within(canvasElement);
const apiKeyInput = await canvas.findByLabelText(/api key/i);
expect(apiKeyInput).toHaveProperty("type", "text");
expect(apiKeyInput).toHaveValue("sk-ant-***\u2026***ABCD");
await userEvent.click(apiKeyInput);
await waitFor(() => expect(apiKeyInput).toHaveValue(""));
const displayName = canvas.getByLabelText(/display name/i);
await userEvent.click(displayName);
await waitFor(() =>
expect(apiKeyInput).toHaveValue("sk-ant-***\u2026***ABCD"),
);
await userEvent.click(apiKeyInput);
await waitFor(() => expect(apiKeyInput).toHaveValue(""));
await userEvent.type(apiKeyInput, "sk-ant-I1lO0-new-secret");
expect(apiKeyInput).toHaveValue("sk-ant-I1lO0-new-secret");
await userEvent.clear(displayName);
await userEvent.type(displayName, "Updated Anthropic");
const submitButton = canvas.getByRole("button", {
name: /update provider/i,
});
await waitFor(() => expect(submitButton).toBeEnabled());
await userEvent.click(submitButton);
await waitFor(() =>
expect(args.onSubmit).toHaveBeenCalledWith(
expect.objectContaining({
apiKey: "sk-ant-I1lO0-new-secret",
}),
),
);
await waitFor(() => expect(submitButton).toBeDisabled());
apiKeySubmitDeferred.resolve();
await waitFor(() =>
expect(apiKeyInput).toHaveValue("sk-ant-***\u2026***ABCD"),
);
},
};
export const FailedSubmitKeepsCredential: Story = {
render: (args) => {
failedSubmitDeferred = createDeferred<void>();
return (
<FailedSubmitProviderForm args={args} deferred={failedSubmitDeferred} />
);
},
args: {
editing: true,
openAiAnthropicSavedApiKey: true,
openAiAnthropicMaskedApiKey: "sk-ant-***\u2026***ABCD",
initialValues: {
type: "anthropic",
name: "production-anthropic",
displayName: "Production Anthropic",
baseUrl: "https://api.anthropic.com",
apiKey: "",
enabled: true,
},
},
play: async ({ canvasElement, args }) => {
const canvas = within(canvasElement);
const apiKeyInput = await canvas.findByLabelText(/api key/i);
await userEvent.click(apiKeyInput);
await waitFor(() => expect(apiKeyInput).toHaveValue(""));
await userEvent.type(apiKeyInput, "sk-ant-I1lO0-new-secret");
const displayName = canvas.getByLabelText(/display name/i);
await userEvent.clear(displayName);
await userEvent.type(displayName, "Failed Anthropic");
const submitButton = canvas.getByRole("button", {
name: /update provider/i,
});
await waitFor(() => expect(submitButton).toBeEnabled());
await userEvent.click(submitButton);
await waitFor(() =>
expect(args.onSubmit).toHaveBeenCalledWith(
expect.objectContaining({
apiKey: "sk-ant-I1lO0-new-secret",
}),
),
);
await waitFor(() => expect(submitButton).toBeDisabled());
failedSubmitDeferred.resolve();
await expect(await canvas.findByText(errorSubmitMessage)).toBeVisible();
expect(apiKeyInput).toHaveValue("sk-ant-I1lO0-new-secret");
},
};
export const ExternalLoadingKeepsCredential: Story = {
render: (args) => {
externalSaveDeferred = createDeferred<void>();
return (
<ExternalLoadingProviderForm
args={args}
deferred={externalSaveDeferred}
/>
);
},
args: { args: {
editing: true, editing: true,
openAiAnthropicSavedApiKey: true, openAiAnthropicSavedApiKey: true,
@@ -157,11 +431,25 @@ export const CredentialFocusClear: Story = {
play: async ({ canvasElement }) => { play: async ({ canvasElement }) => {
const canvas = within(canvasElement); const canvas = within(canvasElement);
const apiKeyInput = await canvas.findByLabelText(/api key/i); const apiKeyInput = await canvas.findByLabelText(/api key/i);
expect(apiKeyInput).toHaveValue("sk-ant-***\u2026***ABCD"); const submitButton = canvas.getByRole("button", {
name: /update provider/i,
});
await userEvent.click(apiKeyInput); await userEvent.click(apiKeyInput);
await waitFor(() => expect(apiKeyInput).toHaveValue("")); await waitFor(() => expect(apiKeyInput).toHaveValue(""));
await userEvent.type(apiKeyInput, "sk-ant-I1lO0-new-secret");
await waitFor(() => expect(submitButton).toBeEnabled());
await userEvent.click(
canvas.getByRole("button", { name: /simulate external save/i }),
);
await waitFor(() => expect(submitButton).toBeDisabled());
externalSaveDeferred.resolve();
await waitFor(() => expect(submitButton).toBeEnabled());
expect(apiKeyInput).toHaveValue("sk-ant-I1lO0-new-secret");
}, },
}; };
export const UnsavedChangesPrompt: Story = { export const UnsavedChangesPrompt: Story = {
args: { args: {
editing: true, editing: true,
@@ -259,6 +259,21 @@ export const ProviderForm: FC<ProviderFormProps> = ({
const typeDefaults = const typeDefaults =
providerDefaults[resolvedType as keyof typeof providerDefaults]; providerDefaults[resolvedType as keyof typeof providerDefaults];
// Seed Bedrock credentials with the mask when on file; focus clears it,
// and a re-submitted "" tells the API mapping to keep the value.
const maskedAccessKey = bedrockSavedAccessCredentials
? SAVED_CREDENTIAL_MASK
: "";
const maskedAccessKeySecret = bedrockSavedAccessCredentials
? SAVED_CREDENTIAL_MASK
: "";
// Same pattern for openai/anthropic. Prefer the API-supplied masked
// rendering so the user sees the key's identifying suffix.
const maskedApiKey = openAiAnthropicSavedApiKey
? (openAiAnthropicMaskedApiKey ?? SAVED_CREDENTIAL_MASK)
: "";
const didSubmit = useRef(false);
const form = useFormik<ProviderFormValues>({ const form = useFormik<ProviderFormValues>({
initialValues: { initialValues: {
...defaultInitialValues, ...defaultInitialValues,
@@ -266,21 +281,16 @@ export const ProviderForm: FC<ProviderFormProps> = ({
// Edit overrides prefills with server values; create gets them as-is. // Edit overrides prefills with server values; create gets them as-is.
...(typeDefaults ?? {}), ...(typeDefaults ?? {}),
...initialValues, ...initialValues,
// Seed Bedrock credentials with the mask when on file; focus clears it, accessKey: maskedAccessKey,
// and a re-submitted "" tells the API mapping to keep the value. accessKeySecret: maskedAccessKeySecret,
accessKey: bedrockSavedAccessCredentials ? SAVED_CREDENTIAL_MASK : "", apiKey: maskedApiKey,
accessKeySecret: bedrockSavedAccessCredentials
? SAVED_CREDENTIAL_MASK
: "",
// Same pattern for openai/anthropic. Prefer the API-supplied masked
// rendering so the user sees the key's identifying suffix.
apiKey: openAiAnthropicSavedApiKey
? (openAiAnthropicMaskedApiKey ?? SAVED_CREDENTIAL_MASK)
: "",
}, },
validationSchema: getProviderFormSchema(editing), validationSchema: getProviderFormSchema(editing),
validateOnMount: true, validateOnMount: true,
onSubmit: onSubmit ?? (() => {}), onSubmit: (values) => {
didSubmit.current = true;
return onSubmit?.(values);
},
}); });
const getFieldHelpers = getFormHelpers(form, submitError); const getFieldHelpers = getFormHelpers(form, submitError);
@@ -297,17 +307,46 @@ export const ProviderForm: FC<ProviderFormProps> = ({
} }
}; };
// Restores the mask when the user leaves the field without entering
// a new value, keeping the saved-credential appearance.
const handleCredentialBlur = (
field: "apiKey" | "accessKey" | "accessKeySecret",
) => {
const initial = form.initialValues[field];
if (form.values[field] === "" && initial !== "") {
void form.setFieldValue(field, initial);
}
};
// When the parent's mutation finishes without an error, treat the just- // When the parent's mutation finishes without an error, treat the just-
// submitted values as the new baseline so the unsaved-changes prompt does // submitted values as the new baseline so the unsaved-changes prompt does
// not fire on subsequent navigations. React Query reports a missing error // not fire on subsequent navigations. React Query reports a missing error
// as `null`, so a truthy check covers both null and undefined. // as `null`, so a truthy check covers both null and undefined.
const previousIsLoading = useRef(isLoading); const previousIsLoading = useRef(isLoading);
useEffect(() => { useEffect(() => {
if (previousIsLoading.current && !isLoading && !submitError) { if (previousIsLoading.current && !isLoading) {
form.resetForm({ values: form.values }); if (didSubmit.current && !submitError) {
// Restore credential fields to their initial masked sentinels so
// the raw key is never left visible after a successful save.
const remaskedValues = {
...form.values,
apiKey: maskedApiKey,
accessKey: maskedAccessKey,
accessKeySecret: maskedAccessKeySecret,
};
form.resetForm({ values: remaskedValues });
}
didSubmit.current = false;
} }
previousIsLoading.current = isLoading; previousIsLoading.current = isLoading;
}, [isLoading, submitError, form]); }, [
isLoading,
submitError,
form,
maskedApiKey,
maskedAccessKey,
maskedAccessKeySecret,
]);
const unsavedChanges = useUnsavedChangesPrompt( const unsavedChanges = useUnsavedChangesPrompt(
form.dirty && !form.isSubmitting, form.dirty && !form.isSubmitting,
@@ -367,6 +406,7 @@ export const ProviderForm: FC<ProviderFormProps> = ({
required required
label="API key" label="API key"
helpers={getFieldHelpers("apiKey")} helpers={getFieldHelpers("apiKey")}
onBlur={() => handleCredentialBlur("apiKey")}
onFocus={() => handleCredentialFocus("apiKey")} onFocus={() => handleCredentialFocus("apiKey")}
autoComplete="new-password" autoComplete="new-password"
placeholder={apiKeyPlaceholder(form.values.type)} placeholder={apiKeyPlaceholder(form.values.type)}
@@ -430,12 +470,15 @@ export const ProviderForm: FC<ProviderFormProps> = ({
required required
label="Access key" label="Access key"
helpers={getFieldHelpers("accessKey")} helpers={getFieldHelpers("accessKey")}
onBlur={() => handleCredentialBlur("accessKey")}
onFocus={() => handleCredentialFocus("accessKey")} onFocus={() => handleCredentialFocus("accessKey")}
autoComplete="new-password"
/> />
<CredentialField <CredentialField
required required
label="Access key secret" label="Access key secret"
helpers={getFieldHelpers("accessKeySecret")} helpers={getFieldHelpers("accessKeySecret")}
onBlur={() => handleCredentialBlur("accessKeySecret")}
onFocus={() => handleCredentialFocus("accessKeySecret")} onFocus={() => handleCredentialFocus("accessKeySecret")}
autoComplete="new-password" autoComplete="new-password"
/> />
@@ -2,6 +2,7 @@ import { act, renderHook } from "@testing-library/react";
import { createRef } from "react"; import { createRef } from "react";
import { beforeEach, describe, expect, it, vi } from "vitest"; import { beforeEach, describe, expect, it, vi } from "vitest";
import type { ChatQueuedMessage } from "#/api/typesGenerated"; import type { ChatQueuedMessage } from "#/api/typesGenerated";
import { createDeferred } from "#/testHelpers/deferred";
import { MockUserOwner, MockWorkspace } from "#/testHelpers/entities"; import { MockUserOwner, MockWorkspace } from "#/testHelpers/entities";
import { import {
draftInputStorageKeyPrefix, draftInputStorageKeyPrefix,
@@ -79,22 +80,6 @@ const setMobileViewport = (isMobile: boolean) => {
}); });
}; };
type Deferred<T> = {
promise: Promise<T>;
resolve: (value: T | PromiseLike<T>) => void;
reject: (reason?: unknown) => void;
};
const createDeferred = <T>(): Deferred<T> => {
let resolve!: (value: T | PromiseLike<T>) => void;
let reject!: (reason?: unknown) => void;
const promise = new Promise<T>((res, rej) => {
resolve = res;
reject = rej;
});
return { promise, resolve, reject };
};
describe("getWorkspaceOptionsWithLinkedWorkspace", () => { describe("getWorkspaceOptionsWithLinkedWorkspace", () => {
it("includes a missing linked workspace only when the current user owns it", () => { it("includes a missing linked workspace only when the current user owns it", () => {
const existingWorkspace = { const existingWorkspace = {
@@ -106,6 +106,35 @@ describe("applyKnownModelDefaults", () => {
expect(result.appliedFields).toEqual([]); expect(result.appliedFields).toEqual([]);
}); });
it("populates display name with the Known Model display name", () => {
const result = applyDefaults({
values: buildInitialModelFormValues(),
initialValues: buildInitialModelFormValues(),
provider: "anthropic",
knownModel: requireKnownModel("anthropic", "claude-opus-4-8"),
});
expect(result.values.displayName).toBe("Claude Opus 4.8");
expect(result.appliedFields).toContain("displayName");
});
it("skips display name when current value differs from initial value", () => {
const values = setPath(
buildInitialModelFormValues(),
"displayName",
"My Custom Name",
);
const result = applyDefaults({
values,
initialValues: buildInitialModelFormValues(),
provider: "anthropic",
knownModel: requireKnownModel("anthropic", "claude-opus-4-8"),
});
expect(result.values.displayName).toBe("My Custom Name");
expect(result.appliedFields).not.toContain("displayName");
});
it("populates context limit when current value still equals initial value", () => { it("populates context limit when current value still equals initial value", () => {
const result = applyDefaults({ const result = applyDefaults({
values: buildInitialModelFormValues(), values: buildInitialModelFormValues(),
@@ -83,6 +83,15 @@ export const applyKnownModelDefaults = ({
const nextValues = structuredClone(values); const nextValues = structuredClone(values);
const appliedFields: string[] = []; const appliedFields: string[] = [];
maybeApplyDefault({
appliedFields,
initialValues,
nextValues,
path: "displayName",
value: knownModel.displayName,
values,
});
if (knownModel.contextLimit !== undefined) { if (knownModel.contextLimit !== undefined) {
maybeApplyDefault({ maybeApplyDefault({
appliedFields, appliedFields,
@@ -1,28 +1,13 @@
import { act, renderHook } from "@testing-library/react"; import { act, renderHook } from "@testing-library/react";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { API } from "#/api/api"; import { API } from "#/api/api";
import { createDeferred } from "#/testHelpers/deferred";
import { chatDraftAttachmentStorageKey } from "../utils/chatDraftAttachmentStorage"; import { chatDraftAttachmentStorageKey } from "../utils/chatDraftAttachmentStorage";
import { import {
resetChatDraftAttachmentRegistryForTest, resetChatDraftAttachmentRegistryForTest,
useChatDraftAttachments, useChatDraftAttachments,
} from "./useChatDraftAttachments"; } from "./useChatDraftAttachments";
type Deferred<T> = {
promise: Promise<T>;
resolve: (value: T | PromiseLike<T>) => void;
reject: (reason?: unknown) => void;
};
const createDeferred = <T>(): Deferred<T> => {
let resolve!: (value: T | PromiseLike<T>) => void;
let reject!: (reason?: unknown) => void;
const promise = new Promise<T>((res, rej) => {
resolve = res;
reject = rej;
});
return { promise, resolve, reject };
};
const orgID = "org-1"; const orgID = "org-1";
const chatID = "chat-a"; const chatID = "chat-a";
const storageKey = chatDraftAttachmentStorageKey(orgID, chatID); const storageKey = chatDraftAttachmentStorageKey(orgID, chatID);
+15
View File
@@ -0,0 +1,15 @@
export type Deferred<T> = {
promise: Promise<T>;
resolve: (value: T | PromiseLike<T>) => void;
reject: (reason?: unknown) => void;
};
export const createDeferred = <T>(): Deferred<T> => {
let resolve!: (value: T | PromiseLike<T>) => void;
let reject!: (reason?: unknown) => void;
const promise = new Promise<T>((res, rej) => {
resolve = res;
reject = rej;
});
return { promise, resolve, reject };
};