diff --git a/agent/x/agentdesktop/api.go b/agent/x/agentdesktop/api.go index 33ff0fb7ca..fc7686b072 100644 --- a/agent/x/agentdesktop/api.go +++ b/agent/x/agentdesktop/api.go @@ -183,6 +183,38 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) { } resp.Output = "key action performed" + case "key_down": + if action.Text == nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Missing \"text\" for key_down action.", + }) + return + } + if err := a.desktop.KeyDown(ctx, *action.Text); err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Key down failed.", + Detail: err.Error(), + }) + return + } + resp.Output = "key_down action performed" + + case "key_up": + if action.Text == nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Missing \"text\" for key_up action.", + }) + return + } + if err := a.desktop.KeyUp(ctx, *action.Text); err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Key up failed.", + Detail: err.Error(), + }) + return + } + resp.Output = "key_up action performed" + case "type": if action.Text == nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ diff --git a/agent/x/agentdesktop/api_test.go b/agent/x/agentdesktop/api_test.go index 7663d677bc..a8c232d978 100644 --- a/agent/x/agentdesktop/api_test.go +++ b/agent/x/agentdesktop/api_test.go @@ -523,6 +523,59 @@ func TestHandleAction_TypeAction(t *testing.T) { assert.Equal(t, "hello world", fake.lastTyped) } +func TestHandleAction_KeyDownAndUp(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + action string + wantOutput string + }{ + {name: "KeyDown", action: "key_down", wantOutput: "key_down action performed"}, + {name: "KeyUp", action: "key_up", wantOutput: "key_up action performed"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil) + fake := &fakeDesktop{ + startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080}, + } + api := agentdesktop.NewAPI(logger, fake, nil) + defer api.Close() + + text := "ctrl" + body := agentdesktop.DesktopAction{ + Action: tt.action, + Text: &text, + } + b, err := json.Marshal(body) + require.NoError(t, err) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + + handler := api.Routes() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + var resp agentdesktop.DesktopActionResponse + err = json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + assert.Equal(t, tt.wantOutput, resp.Output) + if tt.action == "key_down" { + assert.Equal(t, "ctrl", fake.lastKeyDown) + } else { + assert.Equal(t, "ctrl", fake.lastKeyUp) + } + }) + } +} + func TestHandleAction_HoldKey(t *testing.T) { t.Parallel() diff --git a/coderd/coderd.go b/coderd/coderd.go index 1ff2a5ed48..3901727a06 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1196,6 +1196,8 @@ func New(options *Options) *API { r.Put("/model-override/{context}", api.putChatModelOverride) r.Get("/desktop-enabled", api.getChatDesktopEnabled) r.Put("/desktop-enabled", api.putChatDesktopEnabled) + r.Get("/computer-use-provider", api.getChatComputerUseProvider) + r.Put("/computer-use-provider", api.putChatComputerUseProvider) r.Get("/debug-logging", api.getChatDebugLogging) r.Put("/debug-logging", api.putChatDebugLogging) r.Get("/user-debug-logging", api.getUserChatDebugLogging) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 0ab0b52123..7901aec54a 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2603,6 +2603,18 @@ func (q *querier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (datab return fetch(q.log, q.auth, q.db.GetChatByIDForUpdate)(ctx, id) } +func (q *querier) GetChatComputerUseProvider(ctx context.Context) (string, error) { + // The computer-use provider is a deployment-wide runtime chat setting + // read by authenticated chat users and chatd. Feature and experiment + // access is enforced at caller and API boundaries where applicable, so + // this matches peer runtime config getters and only requires an explicit + // actor so unauthenticated calls fail closed. + if _, ok := ActorFromContext(ctx); !ok { + return "", ErrNoActor + } + return q.db.GetChatComputerUseProvider(ctx) +} + func (q *querier) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) { // The owner's chats, may cross orgs. AnyOrganization() authorizes // the caller if they hold read permission on chats owned by @@ -7437,6 +7449,13 @@ func (q *querier) UpsertChatAutoArchiveDays(ctx context.Context, autoArchiveDays return q.db.UpsertChatAutoArchiveDays(ctx, autoArchiveDays) } +func (q *querier) UpsertChatComputerUseProvider(ctx context.Context, provider string) error { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { + return err + } + return q.db.UpsertChatComputerUseProvider(ctx, provider) +} + func (q *querier) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 2d6b189c8c..41490dcd5a 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -910,6 +910,10 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes() check.Args().Asserts() })) + s.Run("GetChatComputerUseProvider", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetChatComputerUseProvider(gomock.Any()).Return("anthropic", nil).AnyTimes() + check.Args().Asserts() + })) s.Run("GetChatGeneralModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().GetChatGeneralModelOverride(gomock.Any()).Return("", nil).AnyTimes() check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead) @@ -1233,6 +1237,10 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes() check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) + s.Run("UpsertChatComputerUseProvider", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertChatComputerUseProvider(gomock.Any(), "anthropic").Return(nil).AnyTimes() + check.Args("anthropic").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) + })) s.Run("UpsertChatGeneralModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().UpsertChatGeneralModelOverride(gomock.Any(), "").Return(nil).AnyTimes() check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index f9b7c9651b..a4e24772d1 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1152,6 +1152,14 @@ func (m queryMetricsStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUI return r0, r1 } +func (m queryMetricsStore) GetChatComputerUseProvider(ctx context.Context) (string, error) { + start := time.Now() + r0, r1 := m.s.GetChatComputerUseProvider(ctx) + m.queryLatencies.WithLabelValues("GetChatComputerUseProvider").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatComputerUseProvider").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) { start := time.Now() r0, r1 := m.s.GetChatCostPerChat(ctx, arg) @@ -5328,6 +5336,14 @@ func (m queryMetricsStore) UpsertChatAutoArchiveDays(ctx context.Context, autoAr return r0 } +func (m queryMetricsStore) UpsertChatComputerUseProvider(ctx context.Context, provider string) error { + start := time.Now() + r0 := m.s.UpsertChatComputerUseProvider(ctx, provider) + m.queryLatencies.WithLabelValues("UpsertChatComputerUseProvider").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatComputerUseProvider").Inc() + return r0 +} + func (m queryMetricsStore) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error { start := time.Now() r0 := m.s.UpsertChatDebugLoggingAllowUsers(ctx, allowUsers) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 625c5a53cf..1252d277e0 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2117,6 +2117,21 @@ func (mr *MockStoreMockRecorder) GetChatByIDForUpdate(ctx, id any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatByIDForUpdate), ctx, id) } +// GetChatComputerUseProvider mocks base method. +func (m *MockStore) GetChatComputerUseProvider(ctx context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatComputerUseProvider", ctx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatComputerUseProvider indicates an expected call of GetChatComputerUseProvider. +func (mr *MockStoreMockRecorder) GetChatComputerUseProvider(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatComputerUseProvider", reflect.TypeOf((*MockStore)(nil).GetChatComputerUseProvider), ctx) +} + // GetChatCostPerChat mocks base method. func (m *MockStore) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) { m.ctrl.T.Helper() @@ -10011,6 +10026,20 @@ func (mr *MockStoreMockRecorder) UpsertChatAutoArchiveDays(ctx, autoArchiveDays return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatAutoArchiveDays", reflect.TypeOf((*MockStore)(nil).UpsertChatAutoArchiveDays), ctx, autoArchiveDays) } +// UpsertChatComputerUseProvider mocks base method. +func (m *MockStore) UpsertChatComputerUseProvider(ctx context.Context, provider string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertChatComputerUseProvider", ctx, provider) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpsertChatComputerUseProvider indicates an expected call of UpsertChatComputerUseProvider. +func (mr *MockStoreMockRecorder) UpsertChatComputerUseProvider(ctx, provider any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatComputerUseProvider", reflect.TypeOf((*MockStore)(nil).UpsertChatComputerUseProvider), ctx, provider) +} + // UpsertChatDebugLoggingAllowUsers mocks base method. func (m *MockStore) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 6a8cedd3ab..c67ce67231 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -282,6 +282,7 @@ type sqlcQuerier interface { GetChatAutoArchiveDays(ctx context.Context, defaultAutoArchiveDays int32) (int32, error) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) + GetChatComputerUseProvider(ctx context.Context) (string, error) // Per-root-chat cost breakdown for a single user within a date range. // Groups by root_chat_id so forked chats roll up under their root. // Only counts assistant-role messages. @@ -1194,6 +1195,7 @@ type sqlcQuerier interface { // to JSON before invoking this query. UpsertChatAdvisorConfig(ctx context.Context, value string) error UpsertChatAutoArchiveDays(ctx context.Context, autoArchiveDays int32) error + UpsertChatComputerUseProvider(ctx context.Context, provider string) error // UpsertChatDebugLoggingAllowUsers updates the runtime admin setting that // allows users to opt into chat debug logging. UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index ee331ee2f6..c4a80c2b46 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -20568,6 +20568,18 @@ func (q *sqlQuerier) GetChatAutoArchiveDays(ctx context.Context, defaultAutoArch return auto_archive_days, err } +const getChatComputerUseProvider = `-- name: GetChatComputerUseProvider :one +SELECT + COALESCE((SELECT value FROM site_configs WHERE key = 'agents_computer_use_provider'), '') :: text AS provider +` + +func (q *sqlQuerier) GetChatComputerUseProvider(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, getChatComputerUseProvider) + var provider string + err := row.Scan(&provider) + return provider, err +} + const getChatDebugLoggingAllowUsers = `-- name: GetChatDebugLoggingAllowUsers :one SELECT COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_chat_debug_logging_allow_users'), false) :: boolean AS allow_users @@ -20967,6 +20979,16 @@ func (q *sqlQuerier) UpsertChatAutoArchiveDays(ctx context.Context, autoArchiveD return err } +const upsertChatComputerUseProvider = `-- name: UpsertChatComputerUseProvider :exec +INSERT INTO site_configs (key, value) VALUES ('agents_computer_use_provider', $1) +ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_computer_use_provider' +` + +func (q *sqlQuerier) UpsertChatComputerUseProvider(ctx context.Context, provider string) error { + _, err := q.db.ExecContext(ctx, upsertChatComputerUseProvider, provider) + return err +} + const upsertChatDebugLoggingAllowUsers = `-- name: UpsertChatDebugLoggingAllowUsers :exec INSERT INTO site_configs (key, value) VALUES ( diff --git a/coderd/database/queries/siteconfig.sql b/coderd/database/queries/siteconfig.sql index 5c6e591023..629d89fc05 100644 --- a/coderd/database/queries/siteconfig.sql +++ b/coderd/database/queries/siteconfig.sql @@ -226,6 +226,14 @@ SELECT INSERT INTO site_configs (key, value) VALUES ('agents_advisor_config', $1) ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_advisor_config'; +-- name: GetChatComputerUseProvider :one +SELECT + COALESCE((SELECT value FROM site_configs WHERE key = 'agents_computer_use_provider'), '') :: text AS provider; + +-- name: UpsertChatComputerUseProvider :exec +INSERT INTO site_configs (key, value) VALUES ('agents_computer_use_provider', sqlc.arg(provider)) +ON CONFLICT (key) DO UPDATE SET value = sqlc.arg(provider) WHERE site_configs.key = 'agents_computer_use_provider'; + -- GetChatDebugLoggingAllowUsers returns the runtime admin setting that -- allows users to opt into chat debug logging when the deployment does -- not already force debug logging on globally. diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 9892f36ad4..202bdc1d2f 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -49,6 +49,7 @@ import ( "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/coder/v2/coderd/x/chatd" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/coderd/x/chatd/chattool" "github.com/coder/coder/v2/coderd/x/chatfiles" "github.com/coder/coder/v2/coderd/x/gitsync" "github.com/coder/coder/v2/codersdk" @@ -4106,6 +4107,58 @@ func (api *API) putChatDesktopEnabled(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusNoContent) } +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +// +//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler. +func (api *API) getChatComputerUseProvider(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + provider, err := api.Database.GetChatComputerUseProvider(ctx) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching computer use provider.", + Detail: err.Error(), + }) + return + } + httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatComputerUseProviderResponse{ + Provider: chattool.DefaultComputerUseProvider(provider), + }) +} + +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +func (api *API) putChatComputerUseProvider(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { + httpapi.Forbidden(rw) + return + } + + var req codersdk.UpdateChatComputerUseProviderRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + if !chattool.IsSupportedComputerUseProvider(req.Provider) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid computer use provider.", + Detail: fmt.Sprintf( + "Expected one of: %s. Got %q.", + strings.Join(chattool.SupportedComputerUseProviders(), ", "), + req.Provider, + ), + }) + return + } + + if err := api.Database.UpsertChatComputerUseProvider(ctx, req.Provider); err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error updating computer use provider.", + Detail: err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} + func (api *API) deploymentChatDebugLoggingEnabled() bool { return api.DeploymentValues != nil && api.DeploymentValues.AI.Chat.DebugLoggingEnabled.Value() } diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index c93bf71cef..35fffae06c 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -10417,6 +10417,139 @@ func TestChatDesktopEnabled(t *testing.T) { }) } +func TestChatComputerUseProvider(t *testing.T) { + t.Parallel() + + t.Run("ReturnsAnthropicWhenUnset", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + adminClient := newChatClient(t) + coderdtest.CreateFirstUser(t, adminClient.Client) + + resp, err := adminClient.GetChatComputerUseProvider(ctx) + require.NoError(t, err) + require.Equal(t, "anthropic", resp.Provider) + }) + + t.Run("AdminCanSetAnthropic", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + adminClient := newChatClient(t) + coderdtest.CreateFirstUser(t, adminClient.Client) + + err := adminClient.UpdateChatComputerUseProvider(ctx, codersdk.UpdateChatComputerUseProviderRequest{ + Provider: "anthropic", + }) + require.NoError(t, err) + + resp, err := adminClient.GetChatComputerUseProvider(ctx) + require.NoError(t, err) + require.Equal(t, "anthropic", resp.Provider) + }) + + t.Run("AdminCanSetOpenAI", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + adminClient := newChatClient(t) + coderdtest.CreateFirstUser(t, adminClient.Client) + + err := adminClient.UpdateChatComputerUseProvider(ctx, codersdk.UpdateChatComputerUseProviderRequest{ + Provider: "openai", + }) + require.NoError(t, err) + + resp, err := adminClient.GetChatComputerUseProvider(ctx) + require.NoError(t, err) + require.Equal(t, "openai", resp.Provider) + }) + + t.Run("AdminCanSwitchProviders", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + adminClient := newChatClient(t) + coderdtest.CreateFirstUser(t, adminClient.Client) + + err := adminClient.UpdateChatComputerUseProvider(ctx, codersdk.UpdateChatComputerUseProviderRequest{ + Provider: "openai", + }) + require.NoError(t, err) + + err = adminClient.UpdateChatComputerUseProvider(ctx, codersdk.UpdateChatComputerUseProviderRequest{ + Provider: "anthropic", + }) + require.NoError(t, err) + + resp, err := adminClient.GetChatComputerUseProvider(ctx) + require.NoError(t, err) + require.Equal(t, "anthropic", resp.Provider) + }) + + t.Run("InvalidProviderRejected", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + adminClient := newChatClient(t) + coderdtest.CreateFirstUser(t, adminClient.Client) + + for _, provider := range []string{"", "invalid"} { + err := adminClient.UpdateChatComputerUseProvider(ctx, codersdk.UpdateChatComputerUseProviderRequest{ + Provider: provider, + }) + requireSDKError(t, err, http.StatusBadRequest) + } + }) + + t.Run("NonAdminCanRead", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + adminClient := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + + err := adminClient.UpdateChatComputerUseProvider(ctx, codersdk.UpdateChatComputerUseProviderRequest{ + Provider: "openai", + }) + require.NoError(t, err) + + resp, err := memberClient.GetChatComputerUseProvider(ctx) + require.NoError(t, err) + require.Equal(t, "openai", resp.Provider) + }) + + t.Run("NonAdminWriteFails", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + adminClient := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + + err := memberClient.UpdateChatComputerUseProvider(ctx, codersdk.UpdateChatComputerUseProviderRequest{ + Provider: "openai", + }) + requireSDKError(t, err, http.StatusForbidden) + }) + + t.Run("UnauthenticatedReadFails", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + adminClient := newChatClient(t) + coderdtest.CreateFirstUser(t, adminClient.Client) + + anonClient := codersdk.NewExperimentalClient(codersdk.New(adminClient.URL)) + _, err := anonClient.GetChatComputerUseProvider(ctx) + requireSDKError(t, err, http.StatusUnauthorized) + }) +} + func TestChatDebugLoggingSettings(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 97127e13b1..e3788f265c 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -6146,6 +6146,22 @@ func (p *Server) runChat( // Detect computer-use subagent via the mode column. isComputerUse := chat.Mode.Valid && chat.Mode.ChatMode == database.ChatModeComputerUse + var ( + computerUseProvider string + computerUseModelProvider string + computerUseModelName string + ) + if isComputerUse { + var err error + computerUseProvider, computerUseModelProvider, computerUseModelName, err = p.computerUseProviderAndModelFromConfig(ctx) + if err != nil { + return result, xerrors.Errorf( + "resolve computer use provider and model: %w", + err, + ) + } + } + // NOTE: Buffering was already started in processChat before // the running status was published, so message_part events // are captured from the moment subscribers can see @@ -6717,24 +6733,16 @@ func (p *Server) runChat( if isComputerUse { // Override model for computer use subagent. - resolvedProvider, resolvedModel, resolveErr := chatprovider.ResolveModelWithProviderHint( - chattool.ComputerUseModelName, - chattool.ComputerUseModelProvider, - ) - if resolveErr != nil { - return result, xerrors.Errorf("resolve computer use model metadata: %w", resolveErr) - } - cuModel, cuDebugEnabled, cuErr := p.newDebugAwareModelFromConfig( + cuModel, cuDebugEnabled, resolvedProvider, resolvedModel, cuErr := p.resolveComputerUseModel( ctx, chat, - chattool.ComputerUseModelProvider, - chattool.ComputerUseModelName, providerKeys, - chatprovider.UserAgent(), - chatprovider.CoderHeaders(chat), + computerUseProvider, + computerUseModelProvider, + computerUseModelName, ) if cuErr != nil { - return result, xerrors.Errorf("resolve computer use model: %w", cuErr) + return result, cuErr } model = cuModel debugEnabled = cuDebugEnabled @@ -6900,22 +6908,24 @@ func (p *Server) runChat( } } - if !isPlanModeTurn && !isExploreSubagent && isComputerUse { - desktopGeometry := workspacesdk.DefaultDesktopGeometry() - providerTools = append(providerTools, chatloop.ProviderTool{ - Definition: chattool.ComputerUseProviderTool( - desktopGeometry.DeclaredWidth, - desktopGeometry.DeclaredHeight, - ), - Runner: chattool.NewComputerUseTool( - desktopGeometry.DeclaredWidth, - desktopGeometry.DeclaredHeight, - workspaceCtx.getWorkspaceConn, - storeChatAttachment, - quartz.NewReal(), - p.logger.Named("computer_use"), - ), - }) + providerTools, err = appendComputerUseProviderTool( + providerTools, + computerUseProviderToolOptions{ + provider: computerUseProvider, + isPlanModeTurn: isPlanModeTurn, + isComputerUse: isComputerUse, + getWorkspaceConn: workspaceCtx.getWorkspaceConn, + storeFile: storeChatAttachment, + clock: p.clock, + logger: p.logger.Named("computer_use"), + }, + ) + if err != nil { + return result, xerrors.Errorf( + "register computer use provider tool for provider %q: %w", + computerUseProvider, + err, + ) } providerOptions := chatprovider.ProviderOptionsFromChatModelConfig( diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index 896e3e8363..5a351f3371 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -18,11 +18,13 @@ import ( "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbmock" dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/x/chatd/chaterror" "github.com/coder/coder/v2/coderd/x/chatd/chatloop" + openaicomputeruse "github.com/coder/coder/v2/coderd/x/chatd/chatopenai/computeruse" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" "github.com/coder/coder/v2/coderd/x/chatd/chattest" "github.com/coder/coder/v2/coderd/x/chatd/chattool" @@ -75,6 +77,172 @@ func (t *testMCPAgentTool) MCPServerConfigID() uuid.UUID { return t.configID } +func TestComputerUseProviderAndModelFromConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + rawProvider string + wantProvider string + wantErr string + }{ + { + name: "DefaultAnthropic", + rawProvider: "", + wantProvider: chattool.ComputerUseProviderAnthropic, + }, + { + name: "OpenAI", + rawProvider: " openai ", + wantProvider: chattool.ComputerUseProviderOpenAI, + }, + { + name: "Unknown", + rawProvider: "bogus", + wantErr: `unknown computer-use provider "bogus" configured in agents_computer_use_provider`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + server := &Server{db: db} + + db.EXPECT().GetChatComputerUseProvider(gomock.Any()).DoAndReturn( + func(ctx context.Context) (string, error) { + _, ok := dbauthz.ActorFromContext(ctx) + require.True(t, ok, "config reads must have an actor") + return tt.rawProvider, nil + }, + ) + + provider, modelProvider, modelName, err := server.computerUseProviderAndModelFromConfig(context.Background()) + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantProvider, provider) + + wantModelProvider, wantModelName, ok := chattool.DefaultComputerUseModel(tt.wantProvider) + require.True(t, ok) + require.Equal(t, wantModelProvider, modelProvider) + require.Equal(t, wantModelName, modelName) + }) + } +} + +func TestResolveComputerUseModel_OpenAIMissingCredentials(t *testing.T) { + t.Parallel() + + server := &Server{} + provider := chattool.ComputerUseProviderOpenAI + modelProvider, modelName, ok := chattool.DefaultComputerUseModel(provider) + require.True(t, ok) + + model, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveComputerUseModel( + context.Background(), + database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, + chatprovider.ProviderAPIKeys{}, + provider, + modelProvider, + modelName, + ) + require.Error(t, err) + require.Nil(t, model) + require.False(t, debugEnabled) + require.Empty(t, resolvedProvider) + require.Empty(t, resolvedModel) + require.Contains(t, err.Error(), `provider "openai" model "gpt-5.5"`) + require.Contains(t, err.Error(), "OPENAI_API_KEY is not set") + require.NotContains(t, err.Error(), "ANTHROPIC_API_KEY") +} + +func TestAppendComputerUseProviderTool(t *testing.T) { + t.Parallel() + + providerTools, err := appendComputerUseProviderTool( + nil, + computerUseProviderToolOptions{ + provider: chattool.ComputerUseProviderOpenAI, + isComputerUse: true, + logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + }, + ) + require.NoError(t, err) + require.Len(t, providerTools, 1) + require.True(t, openaicomputeruse.IsTool(providerTools[0].Definition)) + require.Equal(t, "computer", providerTools[0].Definition.GetName()) + require.Equal(t, "computer", providerTools[0].Runner.Info().Name) + require.NotNil(t, providerTools[0].ResultProviderMetadata) + + metadata := providerTools[0].ResultProviderMetadata( + fantasy.NewImageResponse([]byte("png"), "image/png"), + ) + require.NotNil(t, metadata) +} + +func TestAppendComputerUseProviderTool_Gates(t *testing.T) { + t.Parallel() + + baseTools := []chatloop.ProviderTool{{ + Definition: fantasy.ProviderDefinedTool{ + ID: "web_search", + Name: "web_search", + }, + }} + + tests := []struct { + name string + isPlanModeTurn bool + isComputerUse bool + }{ + {name: "PlanMode", isPlanModeTurn: true, isComputerUse: true}, + // Non-computer-use includes regular, master, general, and explore chats. + // Mode cannot be both ChatModeComputerUse and another chat mode. + {name: "NonComputerUseModes"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + providerTools, err := appendComputerUseProviderTool( + baseTools, + computerUseProviderToolOptions{ + provider: chattool.ComputerUseProviderOpenAI, + isPlanModeTurn: tt.isPlanModeTurn, + isComputerUse: tt.isComputerUse, + }, + ) + require.NoError(t, err) + require.Len(t, providerTools, 1) + require.Equal(t, "web_search", providerTools[0].Definition.GetName()) + }) + } +} + +func TestAppendComputerUseProviderTool_AnthropicHasNoResultMetadata(t *testing.T) { + t.Parallel() + + providerTools, err := appendComputerUseProviderTool( + nil, + computerUseProviderToolOptions{ + provider: chattool.ComputerUseProviderAnthropic, + isComputerUse: true, + logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + }, + ) + require.NoError(t, err) + require.Len(t, providerTools, 1) + require.Equal(t, "computer", providerTools[0].Definition.GetName()) + require.Nil(t, providerTools[0].ResultProviderMetadata) +} + func TestFilterExternalMCPConfigsForTurn(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index 621b1af9bd..46e99fbf89 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -6485,6 +6485,10 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) { db, ps := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitLong) + computerUseModelProvider, computerUseModelName, ok := chattool.DefaultComputerUseModel(chattool.ComputerUseProviderAnthropic) + require.True(t, ok) + require.Equal(t, chattool.ComputerUseProviderAnthropic, computerUseModelProvider) + // Track tools and model from the Anthropic LLM calls (the // computer use child chat). We use a raw HTTP handler because // the chattest AnthropicRequest struct does not capture tools. @@ -6532,7 +6536,7 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) { "id": "msg-test", "type": "message", "role": "assistant", - "model": chattool.ComputerUseModelName, + "model": computerUseModelName, "content": []map[string]any{{"type": "text", "text": "Done."}}, "stop_reason": "end_turn", "usage": map[string]any{"input_tokens": 10, "output_tokens": 5}, @@ -6552,7 +6556,7 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) { "id": "msg-test", "type": "message", "role": "assistant", - "model": chattool.ComputerUseModelName, + "model": computerUseModelName, }, }, { @@ -6713,9 +6717,9 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) { childTools := calls[0].Tools // 1. Verify the model is the computer use model. - require.Equal(t, chattool.ComputerUseModelName, childModel, + require.Equal(t, computerUseModelName, childModel, "computer use subagent should use %s", - chattool.ComputerUseModelName) + computerUseModelName) // 2. Verify the computer tool is present. require.Contains(t, childTools, "computer", diff --git a/coderd/x/chatd/chatloop/chatloop.go b/coderd/x/chatd/chatloop/chatloop.go index bda2167dca..d82032015b 100644 --- a/coderd/x/chatd/chatloop/chatloop.go +++ b/coderd/x/chatd/chatloop/chatloop.go @@ -195,6 +195,11 @@ type RunOptions struct { type ProviderTool struct { Definition fantasy.Tool Runner fantasy.AgentTool + // ResultProviderMetadata extracts provider-specific metadata from successful + // local runner responses. The chat loop attaches returned metadata to the tool + // result sent back to the model. OpenAI computer-use uses this to request + // original screenshot detail for image results. + ResultProviderMetadata func(response fantasy.ToolResponse) fantasy.ProviderMetadata } // stepResult holds the accumulated output of a single streaming @@ -1020,13 +1025,22 @@ func executeTools( toolMap[t.Info().Name] = t } providerRunnerNames := make(map[string]struct{}, len(providerTools)) + resultProviderMetadata := make( + map[string]func(fantasy.ToolResponse) fantasy.ProviderMetadata, + len(providerTools), + ) // Include runners from provider tools so locally-executed // provider tools (e.g. computer use) can be dispatched. for _, pt := range providerTools { - if pt.Runner != nil { - name := pt.Runner.Info().Name - toolMap[name] = pt.Runner - providerRunnerNames[name] = struct{}{} + if pt.Runner == nil { + continue + } + + name := pt.Runner.Info().Name + toolMap[name] = pt.Runner + providerRunnerNames[name] = struct{}{} + if pt.ResultProviderMetadata != nil { + resultProviderMetadata[name] = pt.ResultProviderMetadata } } @@ -1052,7 +1066,19 @@ func executeTools( // accurate individual completion times. completedAt[i] = dbtime.Now() }() - results[i] = executeSingleTool(ctx, toolMap, tc, metrics, logger, provider, model, builtinToolNames, activeTools, providerRunnerNames) + results[i] = executeSingleTool( + ctx, + toolMap, + tc, + metrics, + logger, + provider, + model, + builtinToolNames, + activeTools, + providerRunnerNames, + resultProviderMetadata, + ) }() } wg.Wait() @@ -1349,6 +1375,7 @@ func executeSingleTool( builtinToolNames map[string]bool, activeTools []string, providerRunnerNames map[string]struct{}, + resultProviderMetadata map[string]func(fantasy.ToolResponse) fantasy.ProviderMetadata, ) fantasy.ToolResultContent { result := fantasy.ToolResultContent{ ToolCallID: tc.ToolCallID, @@ -1430,6 +1457,18 @@ func executeSingleTool( Text: strings.ToValidUTF8(resp.Content, "\uFFFD"), } } + + if _, isError := result.Result.(fantasy.ToolResultOutputContentError); isError { + return result + } + if len(result.ProviderMetadata) == 0 { + if callback := resultProviderMetadata[tc.ToolName]; callback != nil { + metadata := callback(resp) + if len(metadata) > 0 { + result.ProviderMetadata = metadata + } + } + } return result } diff --git a/coderd/x/chatd/chatloop/chatloop_test.go b/coderd/x/chatd/chatloop/chatloop_test.go index 3b3a4483ed..89f5996e1a 100644 --- a/coderd/x/chatd/chatloop/chatloop_test.go +++ b/coderd/x/chatd/chatloop/chatloop_test.go @@ -339,6 +339,137 @@ func TestRun_ActiveToolsAllowsProviderRunnerExecution(t *testing.T) { "persisted step should include the provider runner result") } +func TestRun_ProviderToolResultProviderMetadata(t *testing.T) { + t.Parallel() + + expectedMetadata := fantasy.ProviderMetadata{ + "openai": &testProviderData{data: map[string]any{ + "detail": "original", + }}, + } + + tests := []struct { + name string + callback func(fantasy.ToolResponse) fantasy.ProviderMetadata + want fantasy.ProviderMetadata + }{ + { + name: "callback returns metadata", + callback: func(fantasy.ToolResponse) fantasy.ProviderMetadata { + return expectedMetadata + }, + want: expectedMetadata, + }, + { + name: "callback nil", + want: nil, + }, + { + name: "callback returns nil", + callback: func(fantasy.ToolResponse) fantasy.ProviderMetadata { + return nil + }, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + providerRunnerName := "computer" + model := &chattest.FakeModel{ + ProviderName: "fake", + StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + return streamFromParts([]fantasy.StreamPart{ + {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-provider-runner", ToolCallName: providerRunnerName}, + {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-provider-runner", Delta: `{}`}, + {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-provider-runner"}, + { + Type: fantasy.StreamPartTypeToolCall, + ID: "tc-provider-runner", + ToolCallName: providerRunnerName, + ToolCallInput: `{}`, + }, + {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, + }), nil + }, + } + + runnerTool := fantasy.NewAgentTool( + providerRunnerName, + "provider runner", + func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.ToolResponse{ + Type: "image", + Data: []byte("image bytes"), + MediaType: "image/png", + Content: "screenshot", + }, nil + }, + ) + + var persistedStep PersistedStep + err := Run(context.Background(), RunOptions{ + Model: model, + Messages: []fantasy.Message{ + textMessage(fantasy.MessageRoleUser, "use the computer"), + }, + ProviderTools: []ProviderTool{ + { + Definition: fantasy.FunctionTool{ + Name: providerRunnerName, + Description: "provider runner", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{}, + }, + }, + Runner: runnerTool, + ResultProviderMetadata: tt.callback, + }, + }, + MaxSteps: 1, + PersistStep: func(_ context.Context, step PersistedStep) error { + persistedStep = step + return nil + }, + }) + require.NoError(t, err) + + var foundResult fantasy.ToolResultContent + for _, block := range persistedStep.Content { + toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block) + if !ok || toolResult.ToolName != providerRunnerName { + continue + } + foundResult = toolResult + break + } + require.NotEmpty(t, foundResult.ToolCallID, + "persisted step should include the provider runner result") + + mediaResult, ok := foundResult.Result.(fantasy.ToolResultOutputContentMedia) + require.True(t, ok, "expected media result") + assert.Equal(t, "image/png", mediaResult.MediaType) + assert.Equal(t, tt.want, foundResult.ProviderMetadata) + + if tt.want == nil { + return + } + + messages := stepResult{content: persistedStep.Content}.toResponseMessages() + require.Len(t, messages, 2) + require.Equal(t, fantasy.MessageRoleTool, messages[1].Role) + require.Len(t, messages[1].Content, 1) + + resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](messages[1].Content[0]) + require.True(t, ok, "expected outbound tool result part") + assert.Equal(t, fantasy.ProviderOptions(tt.want), resultPart.ProviderOptions) + }) + } +} + func TestProcessStepStream_AnthropicUsageMatchesFinalDelta(t *testing.T) { t.Parallel() @@ -3917,6 +4048,7 @@ func TestExecuteSingleTool_MediaBase64Encoding(t *testing.T) { map[string]bool{}, []string{"screenshot"}, map[string]struct{}{}, + nil, ) media, ok := result.Result.(fantasy.ToolResultOutputContentMedia) @@ -3963,6 +4095,7 @@ func TestExecuteSingleTool_MediaBase64Encoding(t *testing.T) { map[string]bool{}, []string{"screenshot"}, map[string]struct{}{}, + nil, ) media, ok := result.Result.(fantasy.ToolResultOutputContentMedia) @@ -4004,6 +4137,7 @@ func TestExecuteSingleTool_MediaBase64Encoding(t *testing.T) { map[string]bool{}, []string{"echo"}, map[string]struct{}{}, + nil, ) textOutput, ok := result.Result.(fantasy.ToolResultOutputContentText) diff --git a/coderd/x/chatd/chatopenai/computeruse/computeruse.go b/coderd/x/chatd/chatopenai/computeruse/computeruse.go new file mode 100644 index 0000000000..116b2a78bb --- /dev/null +++ b/coderd/x/chatd/chatopenai/computeruse/computeruse.go @@ -0,0 +1,494 @@ +package computeruse + +import ( + "slices" + "strings" + "unicode" + + "charm.land/fantasy" + fantasyopenai "charm.land/fantasy/providers/openai" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk/workspacesdk" +) + +// ComputerUseTool returns the OpenAI provider-defined computer-use tool. +func Tool() fantasy.Tool { + return fantasyopenai.NewComputerUseTool(nil).Definition() +} + +// IsComputerUseTool reports whether tool is the OpenAI provider-defined +// computer-use tool. +func IsTool(tool fantasy.Tool) bool { + return fantasyopenai.IsComputerUseTool(tool) +} + +// ParseInput parses an OpenAI computer-use tool call input. +func ParseInput(input string) (*fantasyopenai.ComputerUseInput, error) { + return fantasyopenai.ParseComputerUseInput(input) +} + +// ComputerUseResultProviderMetadata returns metadata that should accompany an +// OpenAI computer-use screenshot result. +func ResultProviderMetadata(response fantasy.ToolResponse) fantasy.ProviderMetadata { + if response.IsError || response.Type != "image" || len(response.Data) == 0 || + !strings.HasPrefix(response.MediaType, "image/") { + return nil + } + + return fantasy.ProviderMetadata{ + fantasyopenai.Name: &fantasyopenai.ComputerCallOutputOptions{ + Detail: "original", + }, + } +} + +// OpenAI scroll deltas are pixels, but Coder desktop scroll amounts are +// wheel clicks. +const computerUseScrollPixelsPerWheelClick int64 = 100 + +// ComputerUseDesktopAction is a Coder desktop operation requested by an +// OpenAI computer-use tool call. +type DesktopAction struct { + Action workspacesdk.DesktopAction + WaitDurationMillis int64 + ReleaseMouseOnFailure bool + ReleaseKeysOnFailure []string +} + +// ComputerUseDesktopActions converts an OpenAI computer-use tool call into +// Coder desktop actions. A caller should execute the returned actions in order, +// wait for WaitDurationMillis entries, and then return a final screenshot. +func DesktopActions( + parsed *fantasyopenai.ComputerUseInput, + declaredWidth, declaredHeight int, +) ([]DesktopAction, error) { + if parsed == nil { + return nil, xerrors.New("OpenAI computer use input is nil") + } + var err error + actions := make([]DesktopAction, 0, len(parsed.Actions)) + for _, action := range parsed.Actions { + switch action.Type { + case "screenshot": + // OpenAI returns one screenshot per response; individual screenshot + // actions in the batch are fulfilled by the batch-final capture. + continue + case "move": + actions = append(actions, DesktopAction{ + Action: desktopActionWithCoordinate( + "mouse_move", + declaredWidth, + declaredHeight, + action.X, + action.Y, + ), + }) + case "click": + actionSet, err := clickActions( + action.Button, + declaredWidth, + declaredHeight, + action.X, + action.Y, + ) + if err != nil { + return nil, err + } + actions, err = appendWithModifiers(actions, action.Keys, actionSet) + if err != nil { + return nil, err + } + case "double_click": + actionName, ok := DoubleClickAction(action.Button) + if !ok { + return nil, xerrors.Errorf( + "unsupported OpenAI double-click button %q", + action.Button, + ) + } + actionSet := []DesktopAction{{ + Action: desktopActionWithCoordinate( + actionName, + declaredWidth, + declaredHeight, + action.X, + action.Y, + ), + }} + actions, err = appendWithModifiers(actions, action.Keys, actionSet) + if err != nil { + return nil, err + } + case "drag": + if len(action.Path) < 2 { + return nil, xerrors.New("OpenAI drag action requires at least two path points") + } + actionSet := []DesktopAction{ + { + Action: desktopActionWithCoordinate( + "mouse_move", + declaredWidth, + declaredHeight, + action.Path[0].X, + action.Path[0].Y, + ), + }, + { + Action: desktopAction( + "left_mouse_down", + declaredWidth, + declaredHeight, + ), + ReleaseMouseOnFailure: true, + }, + } + for _, point := range action.Path[1:] { + actionSet = append(actionSet, DesktopAction{ + Action: desktopActionWithCoordinate( + "mouse_move", + declaredWidth, + declaredHeight, + point.X, + point.Y, + ), + ReleaseMouseOnFailure: true, + }) + } + actionSet = append(actionSet, DesktopAction{ + Action: desktopAction( + "left_mouse_up", + declaredWidth, + declaredHeight, + ), + ReleaseMouseOnFailure: true, + }) + actions, err = appendWithModifiers(actions, action.Keys, actionSet) + if err != nil { + return nil, err + } + case "keypress": + text, err := NormalizeKeys(action.Keys) + if err != nil { + return nil, err + } + desktopAction := desktopAction("key", declaredWidth, declaredHeight) + desktopAction.Text = &text + actions = append(actions, DesktopAction{Action: desktopAction}) + case "type": + desktopAction := desktopAction("type", declaredWidth, declaredHeight) + desktopAction.Text = &action.Text + actions = append(actions, DesktopAction{Action: desktopAction}) + case "scroll": + actionSet := computerUseScrollActions( + declaredWidth, + declaredHeight, + action.X, + action.Y, + action.ScrollX, + action.ScrollY, + ) + actions, err = appendWithModifiers(actions, action.Keys, actionSet) + if err != nil { + return nil, err + } + case "wait": + actions = append(actions, DesktopAction{WaitDurationMillis: 1000}) + default: + return nil, xerrors.Errorf( + "unsupported OpenAI computer action type %q", + action.Type, + ) + } + } + return actions, nil +} + +func appendWithModifiers( + actions []DesktopAction, + keys []string, + actionSet []DesktopAction, +) ([]DesktopAction, error) { + if len(keys) == 0 { + return append(actions, actionSet...), nil + } + + modifiers := make([]string, 0, len(keys)) + for _, key := range keys { + modifier, err := normalizeComputerUseKey(key) + if err != nil { + return nil, err + } + modifiers = append(modifiers, modifier) + } + + heldKeys := make([]string, 0, len(modifiers)) + for _, modifier := range modifiers { + nextHeldKeys := append(slices.Clone(heldKeys), modifier) + desktopAction := desktopAction("key_down", 0, 0) + desktopAction.Text = &modifier + actions = append(actions, DesktopAction{ + Action: desktopAction, + ReleaseKeysOnFailure: nextHeldKeys, + }) + heldKeys = nextHeldKeys + } + + for _, action := range actionSet { + action.ReleaseKeysOnFailure = slices.Clone(heldKeys) + actions = append(actions, action) + } + + for i := len(heldKeys) - 1; i >= 0; i-- { + key := heldKeys[i] + desktopAction := desktopAction("key_up", 0, 0) + desktopAction.Text = &key + actions = append(actions, DesktopAction{ + Action: desktopAction, + ReleaseKeysOnFailure: slices.Clone(heldKeys[:i+1]), + }) + } + return actions, nil +} + +func computerUseScrollActions( + declaredWidth, declaredHeight int, + x, y, scrollX, scrollY int64, +) []DesktopAction { + coord := coordinateFromInt64(x, y) + moveAction := desktopAction("mouse_move", declaredWidth, declaredHeight) + moveAction.Coordinate = &coord + actions := []DesktopAction{{Action: moveAction}} + + if scrollY != 0 { + direction := "down" + if scrollY < 0 { + direction = "up" + } + scrollAction := desktopAction("scroll", declaredWidth, declaredHeight) + scrollAction.Coordinate = &coord + scrollAction.ScrollDirection = &direction + amount := scrollPixelsToWheelClicks(scrollY) + scrollAction.ScrollAmount = &amount + actions = append(actions, DesktopAction{Action: scrollAction}) + } + + if scrollX != 0 { + direction := "right" + if scrollX < 0 { + direction = "left" + } + scrollAction := desktopAction("scroll", declaredWidth, declaredHeight) + scrollAction.Coordinate = &coord + scrollAction.ScrollDirection = &direction + amount := scrollPixelsToWheelClicks(scrollX) + scrollAction.ScrollAmount = &amount + actions = append(actions, DesktopAction{Action: scrollAction}) + } + return actions +} + +func desktopActionWithCoordinate( + action string, + declaredWidth, declaredHeight int, + x, y int64, +) workspacesdk.DesktopAction { + desktopAction := desktopAction(action, declaredWidth, declaredHeight) + coord := coordinateFromInt64(x, y) + desktopAction.Coordinate = &coord + return desktopAction +} + +func desktopAction( + action string, + declaredWidth, declaredHeight int, +) workspacesdk.DesktopAction { + return workspacesdk.DesktopAction{ + Action: action, + ScaledWidth: &declaredWidth, + ScaledHeight: &declaredHeight, + } +} + +func coordinateFromInt64(x, y int64) [2]int { + return [2]int{int(x), int(y)} +} + +func scrollPixelsToWheelClicks(pixels int64) int { + if pixels < 0 { + pixels = -pixels + } + if pixels == 0 { + return 0 + } + return int((pixels + computerUseScrollPixelsPerWheelClick - 1) / + computerUseScrollPixelsPerWheelClick) +} + +func clickActions( + button string, + declaredWidth, declaredHeight int, + x, y int64, +) ([]DesktopAction, error) { + actionName, ok := ClickAction(button) + if ok { + return []DesktopAction{{ + Action: desktopActionWithCoordinate( + actionName, + declaredWidth, + declaredHeight, + x, + y, + ), + }}, nil + } + + navigationKey := "" + switch button { + case "back": + navigationKey = "alt+Left" + case "forward": + navigationKey = "alt+Right" + default: + return nil, xerrors.Errorf("unsupported OpenAI click button %q", button) + } + + keyAction := desktopAction("key", 0, 0) + keyAction.Text = &navigationKey + return []DesktopAction{ + { + Action: desktopActionWithCoordinate( + "mouse_move", + declaredWidth, + declaredHeight, + x, + y, + ), + }, + {Action: keyAction}, + }, nil +} + +// DoubleClickAction maps an OpenAI computer-use double-click button to a Coder +// desktop action name. The desktop API currently supports only left-button +// double-clicks. +func DoubleClickAction(button string) (string, bool) { + switch button { + case "", "left": + return "double_click", true + default: + return "", false + } +} + +// ComputerUseClickAction maps an OpenAI computer-use click button to a Coder +// desktop action name. +func ClickAction(button string) (string, bool) { + switch button { + case "", "left": + return "left_click", true + case "right": + return "right_click", true + case "middle", "wheel": + return "middle_click", true + default: + return "", false + } +} + +// NormalizeComputerUseKeys maps OpenAI keypress tokens to Coder desktop key +// action tokens. +func NormalizeKeys(keys []string) (string, error) { + if len(keys) == 0 { + return "", xerrors.New("OpenAI keypress action requires at least one key") + } + normalized := make([]string, 0, len(keys)) + for _, key := range keys { + normalizedKey, err := normalizeComputerUseKey(key) + if err != nil { + return "", err + } + normalized = append(normalized, normalizedKey) + } + return strings.Join(normalized, "+"), nil +} + +func normalizeComputerUseKey(key string) (string, error) { + trimmed := strings.TrimSpace(key) + if trimmed == "" { + return "", xerrors.New("OpenAI keypress action contains an empty key") + } + + lower := strings.ToLower(trimmed) + switch lower { + case "ctrl", "control": + return "ctrl", nil + case "cmd", "command", "meta", "super": + return "meta", nil + case "shift": + return "shift", nil + case "alt", "option": + return "alt", nil + case "enter", "return": + return "Return", nil + case "escape", "esc": + return "Escape", nil + case "tab": + return "Tab", nil + case "space": + return "space", nil + case "backspace": + return "BackSpace", nil + case "delete", "del": + return "Delete", nil + case "arrowup", "up": + return "Up", nil + case "arrowdown", "down": + return "Down", nil + case "arrowleft", "left": + return "Left", nil + case "arrowright", "right": + return "Right", nil + } + + if isFunctionKey(lower) { + return "F" + lower[1:], nil + } + + runes := []rune(trimmed) + if len(runes) == 1 { + r := runes[0] + if unicode.IsLetter(r) { + return strings.ToLower(trimmed), nil + } + if unicode.IsDigit(r) { + return trimmed, nil + } + if unicode.IsPunct(r) || unicode.IsSymbol(r) { + return trimmed, nil + } + return "", xerrors.Errorf("unsupported OpenAI keypress %q", trimmed) + } + + return "", xerrors.Errorf("unsupported OpenAI keypress %q", trimmed) +} + +func isFunctionKey(key string) bool { + if len(key) < 2 || key[0] != 'f' { + return false + } + number, ok := strings.CutPrefix(key, "f") + if !ok || number == "" { + return false + } + for _, r := range number { + if r < '0' || r > '9' { + return false + } + } + value := 0 + for _, r := range number { + value = value*10 + int(r-'0') + } + return value >= 1 && value <= 35 +} diff --git a/coderd/x/chatd/chatopenai/computeruse/computeruse_test.go b/coderd/x/chatd/chatopenai/computeruse/computeruse_test.go new file mode 100644 index 0000000000..f75efc1f8b --- /dev/null +++ b/coderd/x/chatd/chatopenai/computeruse/computeruse_test.go @@ -0,0 +1,199 @@ +package computeruse_test + +import ( + "testing" + + "charm.land/fantasy" + fantasyopenai "charm.land/fantasy/providers/openai" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/x/chatd/chatopenai/computeruse" +) + +func TestComputerUseTool(t *testing.T) { + t.Parallel() + + tool := computeruse.Tool() + require.True(t, computeruse.IsTool(tool)) + require.Equal(t, "computer", tool.GetName()) +} + +func TestComputerUseResultProviderMetadata(t *testing.T) { + t.Parallel() + + t.Run("SuccessfulImage", func(t *testing.T) { + t.Parallel() + + metadata := computeruse.ResultProviderMetadata( + fantasy.NewImageResponse([]byte("png"), "image/png"), + ) + outputOptions, ok := metadata[fantasyopenai.Name].(*fantasyopenai.ComputerCallOutputOptions) + require.True(t, ok) + require.Equal(t, "original", outputOptions.Detail) + }) + + tests := []struct { + name string + response fantasy.ToolResponse + }{ + {name: "Error", response: fantasy.NewTextErrorResponse("failed")}, + {name: "Text", response: fantasy.NewTextResponse("ok")}, + {name: "EmptyImage", response: fantasy.NewImageResponse(nil, "image/png")}, + { + name: "NonImageMediaType", + response: fantasy.NewImageResponse([]byte("png"), "application/octet-stream"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + metadata := computeruse.ResultProviderMetadata(tt.response) + require.Nil(t, metadata) + }) + } +} + +func TestDesktopActionsWrapsPointerActionsWithModifiers(t *testing.T) { + t.Parallel() + + input, err := computeruse.ParseInput(`{ + "call_id":"call_click_modifier", + "actions":[{"type":"click","button":"left","x":70,"y":80,"keys":["ctrl","shift"]}] + }`) + require.NoError(t, err) + + actions, err := computeruse.DesktopActions(input, 1440, 900) + require.NoError(t, err) + require.Len(t, actions, 5) + + require.Equal(t, "key_down", actions[0].Action.Action) + require.NotNil(t, actions[0].Action.Text) + require.Equal(t, "ctrl", *actions[0].Action.Text) + require.Equal(t, []string{"ctrl"}, actions[0].ReleaseKeysOnFailure) + + require.Equal(t, "key_down", actions[1].Action.Action) + require.NotNil(t, actions[1].Action.Text) + require.Equal(t, "shift", *actions[1].Action.Text) + require.Equal(t, []string{"ctrl", "shift"}, actions[1].ReleaseKeysOnFailure) + + require.Equal(t, "left_click", actions[2].Action.Action) + require.Equal(t, []string{"ctrl", "shift"}, actions[2].ReleaseKeysOnFailure) + + require.Equal(t, "key_up", actions[3].Action.Action) + require.NotNil(t, actions[3].Action.Text) + require.Equal(t, "shift", *actions[3].Action.Text) + require.Equal(t, []string{"ctrl", "shift"}, actions[3].ReleaseKeysOnFailure) + + require.Equal(t, "key_up", actions[4].Action.Action) + require.NotNil(t, actions[4].Action.Text) + require.Equal(t, "ctrl", *actions[4].Action.Text) + require.Equal(t, []string{"ctrl"}, actions[4].ReleaseKeysOnFailure) +} + +func TestDesktopActionsMarksFinalDragReleaseForCleanup(t *testing.T) { + t.Parallel() + + input, err := computeruse.ParseInput(`{ + "call_id":"call_drag", + "actions":[{"type":"drag","path":[{"x":1,"y":2},{"x":3,"y":4}]}] + }`) + require.NoError(t, err) + + actions, err := computeruse.DesktopActions(input, 1440, 900) + require.NoError(t, err) + require.Len(t, actions, 4) + require.Equal(t, "left_mouse_down", actions[1].Action.Action) + require.True(t, actions[1].ReleaseMouseOnFailure) + require.Equal(t, "left_mouse_up", actions[3].Action.Action) + require.True(t, actions[3].ReleaseMouseOnFailure) +} + +func TestDesktopActionsDefaultsEmptyClickButtonToLeft(t *testing.T) { + t.Parallel() + + input, err := computeruse.ParseInput(`{ + "call_id":"call_empty_button", + "actions":[{"type":"click","x":70,"y":80}] + }`) + require.NoError(t, err) + + actions, err := computeruse.DesktopActions(input, 1440, 900) + require.NoError(t, err) + require.Len(t, actions, 1) + require.Equal(t, "left_click", actions[0].Action.Action) +} + +func TestDesktopActionsMapsBackForwardClickButtons(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + button string + wantKey string + }{ + {name: "Back", button: "back", wantKey: "alt+Left"}, + {name: "Forward", button: "forward", wantKey: "alt+Right"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + input, err := computeruse.ParseInput(`{ + "call_id":"call_side_button", + "actions":[{"type":"click","button":"` + tt.button + `","x":70,"y":80}] + }`) + require.NoError(t, err) + + actions, err := computeruse.DesktopActions(input, 1440, 900) + require.NoError(t, err) + require.Len(t, actions, 2) + require.Equal(t, "mouse_move", actions[0].Action.Action) + require.Equal(t, "key", actions[1].Action.Action) + require.NotNil(t, actions[1].Action.Text) + require.Equal(t, tt.wantKey, *actions[1].Action.Text) + }) + } +} + +func TestDesktopActionsRejectsUnsupportedDoubleClickButton(t *testing.T) { + t.Parallel() + + input, err := computeruse.ParseInput(`{ + "call_id":"call_double_click", + "actions":[{"type":"double_click","button":"right","x":70,"y":80}] + }`) + require.NoError(t, err) + + _, err = computeruse.DesktopActions(input, 1440, 900) + require.Error(t, err) + require.Contains(t, err.Error(), `unsupported OpenAI double-click button "right"`) +} + +func TestDesktopActionsConvertsScrollPixelsToWheelClicks(t *testing.T) { + t.Parallel() + + input, err := computeruse.ParseInput(`{ + "call_id":"call_scroll", + "actions":[{"type":"scroll","x":70,"y":80,"scroll_y":401,"scroll_x":-99}] + }`) + require.NoError(t, err) + + actions, err := computeruse.DesktopActions(input, 1440, 900) + require.NoError(t, err) + require.Len(t, actions, 3) + + vertical := actions[1].Action + require.NotNil(t, vertical.ScrollAmount) + require.NotNil(t, vertical.ScrollDirection) + require.Equal(t, "down", *vertical.ScrollDirection) + require.Equal(t, 5, *vertical.ScrollAmount) + + horizontal := actions[2].Action + require.NotNil(t, horizontal.ScrollAmount) + require.NotNil(t, horizontal.ScrollDirection) + require.Equal(t, "left", *horizontal.ScrollDirection) + require.Equal(t, 1, *horizontal.ScrollAmount) +} diff --git a/coderd/x/chatd/chattool/computeruse.go b/coderd/x/chatd/chattool/computeruse.go index 1b7a639148..fcff921b49 100644 --- a/coderd/x/chatd/chattool/computeruse.go +++ b/coderd/x/chatd/chattool/computeruse.go @@ -4,28 +4,84 @@ import ( "context" "encoding/base64" "fmt" + "slices" + "strings" "time" "charm.land/fantasy" fantasyanthropic "charm.land/fantasy/providers/anthropic" + "golang.org/x/xerrors" "cdr.dev/slog/v3" + openaicomputeruse "github.com/coder/coder/v2/coderd/x/chatd/chatopenai/computeruse" "github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/quartz" ) const ( - // ComputerUseModelProvider is the provider for the computer - // use model. - ComputerUseModelProvider = "anthropic" - // ComputerUseModelName is the model used for computer use - // subagents. - ComputerUseModelName = "claude-opus-4-6" + // ComputerUseProviderAnthropic identifies Anthropic computer use. + ComputerUseProviderAnthropic = "anthropic" + // ComputerUseProviderOpenAI identifies OpenAI computer use. + ComputerUseProviderOpenAI = "openai" + // ComputerUseModelProviderDefault is the default model provider name for + // computer use, equal to ComputerUseProviderAnthropic. + ComputerUseModelProviderDefault = ComputerUseProviderAnthropic + // ComputerUseAnthropicModelName is the default Anthropic model used for + // computer use subagents. + ComputerUseAnthropicModelName = "claude-opus-4-6" + // ComputerUseOpenAIModelName is the default OpenAI model used for computer use. + ComputerUseOpenAIModelName = "gpt-5.5" ) -// computerUseTool implements fantasy.AgentTool and -// chatloop.ToolDefiner for Anthropic computer use. +// SupportedComputerUseProviders returns the providers supported by computer use. +// The returned slice is a fresh copy and safe to mutate. +func SupportedComputerUseProviders() []string { + return []string{ + ComputerUseProviderAnthropic, + ComputerUseProviderOpenAI, + } +} + +// IsSupportedComputerUseProvider reports whether provider supports computer use. +func IsSupportedComputerUseProvider(provider string) bool { + return slices.Contains(SupportedComputerUseProviders(), provider) +} + +// DefaultComputerUseProvider returns the effective computer use provider. +func DefaultComputerUseProvider(provider string) string { + if provider == "" { + return ComputerUseProviderAnthropic + } + return provider +} + +// DefaultComputerUseModel returns the default model for a computer use provider. +func DefaultComputerUseModel(provider string) (modelProvider, modelName string, ok bool) { + switch DefaultComputerUseProvider(provider) { + case ComputerUseProviderAnthropic: + return ComputerUseModelProviderDefault, ComputerUseAnthropicModelName, true + case ComputerUseProviderOpenAI: + // Keep OpenAI isolated here because computer-use models may advance. + return ComputerUseProviderOpenAI, ComputerUseOpenAIModelName, true + default: + return "", "", false + } +} + +// DefaultComputerUseDesktopGeometry returns provider-specific model-facing +// desktop geometry for computer use. +func DefaultComputerUseDesktopGeometry(provider string) workspacesdk.DesktopGeometry { + switch DefaultComputerUseProvider(provider) { + case ComputerUseProviderOpenAI: + return workspacesdk.DefaultOpenAIComputerUseDesktopGeometry() + default: + return workspacesdk.DefaultDesktopGeometry() + } +} + +// computerUseTool implements fantasy.AgentTool and chatloop.ToolDefiner. type computerUseTool struct { + provider string declaredWidth int declaredHeight int getWorkspaceConn func(ctx context.Context) (workspacesdk.AgentConn, error) @@ -35,11 +91,12 @@ type computerUseTool struct { logger slog.Logger } -// NewComputerUseTool creates a computer use AgentTool that delegates to the -// agent's desktop endpoints. declaredWidth and declaredHeight are the -// model-facing desktop dimensions advertised to Anthropic and requested for -// screenshots. +// NewComputerUseTool creates a provider-aware computer use AgentTool that +// delegates to the agent's desktop endpoints. declaredWidth and declaredHeight +// are the model-facing desktop dimensions advertised to providers and requested +// for screenshots. func NewComputerUseTool( + provider string, declaredWidth, declaredHeight int, getWorkspaceConn func(ctx context.Context) (workspacesdk.AgentConn, error), storeFile StoreFileFunc, @@ -47,6 +104,7 @@ func NewComputerUseTool( logger slog.Logger, ) fantasy.AgentTool { return &computerUseTool{ + provider: DefaultComputerUseProvider(provider), declaredWidth: declaredWidth, declaredHeight: declaredHeight, getWorkspaceConn: getWorkspaceConn, @@ -67,20 +125,31 @@ func (*computerUseTool) Info() fantasy.ToolInfo { } } -// ComputerUseProviderTool creates the provider-defined Anthropic computer-use -// tool definition using the declared model-facing desktop geometry. -func ComputerUseProviderTool(declaredWidth, declaredHeight int) fantasy.Tool { - // The run callback is nil because execution is handled separately - // by the AgentTool runner in the chatloop. We extract just the - // provider-defined tool definition. - return fantasyanthropic.NewComputerUseTool( - fantasyanthropic.ComputerUseToolOptions{ - DisplayWidthPx: int64(declaredWidth), - DisplayHeightPx: int64(declaredHeight), - ToolVersion: fantasyanthropic.ComputerUse20251124, - }, - nil, - ).Definition() +// ComputerUseProviderTool creates the provider-defined computer-use tool +// definition using the declared model-facing desktop geometry. +func ComputerUseProviderTool(provider string, declaredWidth, declaredHeight int) (fantasy.Tool, error) { + switch DefaultComputerUseProvider(provider) { + case ComputerUseProviderAnthropic: + // The run callback is nil because execution is handled separately + // by the AgentTool runner in the chatloop. We extract just the + // provider-defined tool definition. + return fantasyanthropic.NewComputerUseTool( + fantasyanthropic.ComputerUseToolOptions{ + DisplayWidthPx: int64(declaredWidth), + DisplayHeightPx: int64(declaredHeight), + ToolVersion: fantasyanthropic.ComputerUse20251124, + }, + nil, + ).Definition(), nil + case ComputerUseProviderOpenAI: + // OpenAI's GA computer tool schema does not accept display + // dimensions. The declared geometry is applied through screenshot + // sizing and desktop action coordinate scaling. + return openaicomputeruse.Tool(), nil + default: + return nil, xerrors.Errorf("unsupported computer use provider %q, supported providers: %s", provider, + strings.Join(SupportedComputerUseProviders(), ", ")) + } } func (t *computerUseTool) ProviderOptions() fantasy.ProviderOptions { @@ -92,6 +161,24 @@ func (t *computerUseTool) SetProviderOptions(opts fantasy.ProviderOptions) { } func (t *computerUseTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) { + switch DefaultComputerUseProvider(t.provider) { + case ComputerUseProviderAnthropic: + return t.runAnthropicComputerUse(ctx, call) + case ComputerUseProviderOpenAI: + return t.runOpenAIComputerUse(ctx, call) + default: + return fantasy.NewTextErrorResponse(fmt.Sprintf( + "unsupported computer use provider %q, supported providers: %s", + t.provider, + strings.Join(SupportedComputerUseProviders(), ", "), + )), nil + } +} + +func (t *computerUseTool) runAnthropicComputerUse( + ctx context.Context, + call fantasy.ToolCall, +) (fantasy.ToolResponse, error) { input, err := fantasyanthropic.ParseComputerUseInput(call.Input) if err != nil { return fantasy.NewTextErrorResponse( @@ -110,16 +197,7 @@ func (t *computerUseTool) Run(ctx context.Context, call fantasy.ToolCall) (fanta // For wait actions, sleep then return a screenshot. if input.Action == fantasyanthropic.ActionWait { - d := input.Duration - if d <= 0 { - d = 1000 - } - timer := t.clock.NewTimer(time.Duration(d)*time.Millisecond, "computeruse", "wait") - defer timer.Stop() - select { - case <-ctx.Done(): - case <-timer.C: - } + t.wait(ctx, input.Duration) return t.captureScreenshot(ctx, conn, declaredWidth, declaredHeight) } @@ -129,17 +207,13 @@ func (t *computerUseTool) Run(ctx context.Context, call fantasy.ToolCall) (fanta } // Build the action request. - action := workspacesdk.DesktopAction{ - Action: string(input.Action), - ScaledWidth: &declaredWidth, - ScaledHeight: &declaredHeight, - } + action := t.desktopAction(string(input.Action), declaredWidth, declaredHeight) if input.Coordinate != ([2]int64{}) { - coord := [2]int{int(input.Coordinate[0]), int(input.Coordinate[1])} + coord := coordinateFromInt64(input.Coordinate[0], input.Coordinate[1]) action.Coordinate = &coord } if input.StartCoordinate != ([2]int64{}) { - coord := [2]int{int(input.StartCoordinate[0]), int(input.StartCoordinate[1])} + coord := coordinateFromInt64(input.StartCoordinate[0], input.StartCoordinate[1]) action.StartCoordinate = &coord } if input.Text != "" { @@ -157,18 +231,124 @@ func (t *computerUseTool) Run(ctx context.Context, call fantasy.ToolCall) (fanta action.ScrollDirection = &input.ScrollDirection } - // Execute the action. - _, err = conn.ExecuteDesktopAction(ctx, action) - if err != nil { - return fantasy.NewTextErrorResponse( - fmt.Sprintf("action %q failed: %v", input.Action, err), - ), nil + if resp, done := t.executeDesktopAction(ctx, conn, action); done { + return resp, nil } // Take a screenshot after every action (Anthropic pattern). return t.captureScreenshot(ctx, conn, declaredWidth, declaredHeight) } +func (t *computerUseTool) runOpenAIComputerUse( + ctx context.Context, + call fantasy.ToolCall, +) (fantasy.ToolResponse, error) { + input, err := openaicomputeruse.ParseInput(call.Input) + if err != nil { + return fantasy.NewTextErrorResponse( + fmt.Sprintf("invalid computer use input: %v", err), + ), nil + } + conn, err := t.getWorkspaceConn(ctx) + if err != nil { + return fantasy.NewTextErrorResponse( + fmt.Sprintf("failed to connect to workspace: %v", err), + ), nil + } + + declaredWidth, declaredHeight := t.declaredActionDimensions() + actions, err := openaicomputeruse.DesktopActions( + input, + declaredWidth, + declaredHeight, + ) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + for _, action := range actions { + if action.WaitDurationMillis > 0 { + t.wait(ctx, action.WaitDurationMillis) + continue + } + if resp, done := t.executeDesktopAction(ctx, conn, action.Action); done { + if action.ReleaseMouseOnFailure { + _, err := conn.ExecuteDesktopAction( + ctx, + t.desktopAction("left_mouse_up", declaredWidth, declaredHeight), + ) + if err != nil { + t.logger.Warn(ctx, "failed to release mouse after OpenAI drag error", + slog.Error(err), + ) + } + } + t.releaseOpenAIModifierKeys(ctx, conn, action.ReleaseKeysOnFailure) + return resp, nil + } + } + return t.captureSharedScreenshot(ctx, conn, declaredWidth, declaredHeight) +} + +func (t *computerUseTool) releaseOpenAIModifierKeys( + ctx context.Context, + conn workspacesdk.AgentConn, + keys []string, +) { + for i := len(keys) - 1; i >= 0; i-- { + key := keys[i] + action := t.desktopAction("key_up", 0, 0) + action.Text = &key + if _, err := conn.ExecuteDesktopAction(ctx, action); err != nil { + t.logger.Warn(ctx, "failed to release OpenAI modifier key", + slog.F("key", key), + slog.Error(err), + ) + } + } +} + +func (*computerUseTool) executeDesktopAction( + ctx context.Context, + conn workspacesdk.AgentConn, + action workspacesdk.DesktopAction, +) (fantasy.ToolResponse, bool) { + _, err := conn.ExecuteDesktopAction(ctx, action) + if err != nil { + return fantasy.NewTextErrorResponse( + fmt.Sprintf("action %q failed: %v", action.Action, err), + ), true + } + return fantasy.ToolResponse{}, false +} + +func (*computerUseTool) desktopAction( + action string, + declaredWidth, declaredHeight int, +) workspacesdk.DesktopAction { + return workspacesdk.DesktopAction{ + Action: action, + ScaledWidth: &declaredWidth, + ScaledHeight: &declaredHeight, + } +} + +func (t *computerUseTool) wait(ctx context.Context, durationMillis int64) { + d := durationMillis + if d <= 0 { + d = 1000 + } + timer := t.clock.NewTimer(time.Duration(d)*time.Millisecond, "computeruse", "wait") + defer timer.Stop() + select { + case <-ctx.Done(): + case <-timer.C: + } +} + +func coordinateFromInt64(x, y int64) [2]int { + return [2]int{int(x), int(y)} +} + func (t *computerUseTool) captureScreenshot( ctx context.Context, conn workspacesdk.AgentConn, @@ -256,7 +436,7 @@ func executeScreenshotAction( func (t *computerUseTool) declaredActionDimensions() (declaredWidth, declaredHeight int) { if t.declaredWidth <= 0 || t.declaredHeight <= 0 { - geometry := workspacesdk.DefaultDesktopGeometry() + geometry := DefaultComputerUseDesktopGeometry(t.provider) return geometry.DeclaredWidth, geometry.DeclaredHeight } return t.declaredWidth, t.declaredHeight diff --git a/coderd/x/chatd/chattool/computeruse_test.go b/coderd/x/chatd/chattool/computeruse_test.go index fa92a55449..5138003345 100644 --- a/coderd/x/chatd/chattool/computeruse_test.go +++ b/coderd/x/chatd/chattool/computeruse_test.go @@ -5,8 +5,10 @@ import ( "context" "encoding/base64" "testing" + "time" "charm.land/fantasy" + fantasyanthropic "charm.land/fantasy/providers/anthropic" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -14,23 +16,137 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog/v3/sloggers/slogtest" + openaicomputeruse "github.com/coder/coder/v2/coderd/x/chatd/chatopenai/computeruse" "github.com/coder/coder/v2/coderd/x/chatd/chattool" "github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock" + "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" ) +func TestDefaultComputerUseModel(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider string + wantModelProvider string + wantModelName string + wantOK bool + }{ + { + name: "empty defaults to Anthropic", + provider: "", + wantModelProvider: chattool.ComputerUseModelProviderDefault, + wantModelName: chattool.ComputerUseAnthropicModelName, + wantOK: true, + }, + { + name: "Anthropic", + provider: chattool.ComputerUseProviderAnthropic, + wantModelProvider: chattool.ComputerUseModelProviderDefault, + wantModelName: chattool.ComputerUseAnthropicModelName, + wantOK: true, + }, + { + name: "OpenAI", + provider: chattool.ComputerUseProviderOpenAI, + wantModelProvider: chattool.ComputerUseProviderOpenAI, + wantModelName: chattool.ComputerUseOpenAIModelName, + wantOK: true, + }, + { + name: "unsupported", + provider: "unsupported", + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + modelProvider, modelName, ok := chattool.DefaultComputerUseModel(tt.provider) + assert.Equal(t, tt.wantOK, ok) + assert.Equal(t, tt.wantModelProvider, modelProvider) + assert.Equal(t, tt.wantModelName, modelName) + }) + } +} + +func TestDefaultComputerUseDesktopGeometry(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider string + declaredWidth int + declaredHeight int + }{ + { + name: "empty defaults to Anthropic geometry", + provider: "", + declaredWidth: 1280, + declaredHeight: 720, + }, + { + name: "Anthropic", + provider: chattool.ComputerUseProviderAnthropic, + declaredWidth: 1280, + declaredHeight: 720, + }, + { + name: "OpenAI", + provider: chattool.ComputerUseProviderOpenAI, + declaredWidth: 1600, + declaredHeight: 900, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + geometry := chattool.DefaultComputerUseDesktopGeometry(tt.provider) + assert.Equal(t, tt.declaredWidth, geometry.DeclaredWidth) + assert.Equal(t, tt.declaredHeight, geometry.DeclaredHeight) + }) + } +} + func TestComputerUseProviderTool(t *testing.T) { t.Parallel() geometry := workspacesdk.DefaultDesktopGeometry() - def := chattool.ComputerUseProviderTool(geometry.DeclaredWidth, geometry.DeclaredHeight) + def, err := chattool.ComputerUseProviderTool( + chattool.ComputerUseProviderAnthropic, + geometry.DeclaredWidth, + geometry.DeclaredHeight, + ) + require.NoError(t, err) pdt, ok := def.(fantasy.ProviderDefinedTool) require.True(t, ok, "ComputerUseProviderTool should return a ProviderDefinedTool") + assert.True(t, fantasyanthropic.IsComputerUseTool(def)) assert.Contains(t, pdt.ID, "computer") assert.Equal(t, "computer", pdt.Name) assert.Equal(t, int64(geometry.DeclaredWidth), pdt.Args["display_width_px"]) assert.Equal(t, int64(geometry.DeclaredHeight), pdt.Args["display_height_px"]) + + openAITool, err := chattool.ComputerUseProviderTool( + chattool.ComputerUseProviderOpenAI, + geometry.DeclaredWidth, + geometry.DeclaredHeight, + ) + require.NoError(t, err) + assert.True(t, openaicomputeruse.IsTool(openAITool)) + + _, err = chattool.ComputerUseProviderTool( + "unsupported", + geometry.DeclaredWidth, + geometry.DeclaredHeight, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported computer use provider") } func TestComputerUseTool_Run_Screenshot(t *testing.T) { @@ -56,7 +172,7 @@ func TestComputerUseTool_Run_Screenshot(t *testing.T) { }, nil }) - tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { + tool := chattool.NewComputerUseTool(chattool.ComputerUseProviderAnthropic, geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { return mockConn, nil }, nil, quartz.NewReal(), slogtest.Make(t, nil)) @@ -100,7 +216,7 @@ func TestComputerUseTool_Run_Screenshot_PersistsAttachment(t *testing.T) { var storedName string var storedType string var storedData []byte - tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { + tool := chattool.NewComputerUseTool(chattool.ComputerUseProviderAnthropic, geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { return mockConn, nil }, func(_ context.Context, name string, detectName string, data []byte) (chattool.AttachmentMetadata, error) { storedName = name @@ -154,7 +270,7 @@ func TestComputerUseTool_Run_Screenshot_StoreErrorFallsBackToImage(t *testing.T) ScreenshotHeight: geometry.DeclaredHeight, }, nil) - tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { + tool := chattool.NewComputerUseTool(chattool.ComputerUseProviderAnthropic, geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { return mockConn, nil }, func(_ context.Context, _ string, _ string, _ []byte) (chattool.AttachmentMetadata, error) { return chattool.AttachmentMetadata{}, xerrors.New("chat already has the maximum of 20 linked files") @@ -190,7 +306,7 @@ func TestComputerUseTool_Run_Screenshot_OversizedAttachmentFallsBackToImage(t *t ScreenshotHeight: geometry.DeclaredHeight, }, nil) - tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { + tool := chattool.NewComputerUseTool(chattool.ComputerUseProviderAnthropic, geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { return mockConn, nil }, func(_ context.Context, _ string, _ string, _ []byte) (chattool.AttachmentMetadata, error) { t.Fatal("storeFile should not be called for oversized screenshots") @@ -250,7 +366,7 @@ func TestComputerUseTool_Run_LeftClick(t *testing.T) { }, nil }) - tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { + tool := chattool.NewComputerUseTool(chattool.ComputerUseProviderAnthropic, geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { return mockConn, nil }, func(_ context.Context, _ string, _ string, _ []byte) (chattool.AttachmentMetadata, error) { t.Fatal("storeFile should not be called for left_click follow-up screenshots") @@ -298,7 +414,7 @@ func TestComputerUseTool_Run_Wait(t *testing.T) { }, nil }) - tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { + tool := chattool.NewComputerUseTool(chattool.ComputerUseProviderAnthropic, geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { return mockConn, nil }, func(_ context.Context, _ string, _ string, _ []byte) (chattool.AttachmentMetadata, error) { t.Fatal("storeFile should not be called for wait screenshots") @@ -345,6 +461,7 @@ func TestComputerUseTool_Run_ScreenshotDataIsDecodedBinary(t *testing.T) { }, nil) tool := chattool.NewComputerUseTool( + chattool.ComputerUseProviderAnthropic, geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { @@ -386,7 +503,7 @@ func TestComputerUseTool_Run_ConnError(t *testing.T) { t.Parallel() geometry := workspacesdk.DefaultDesktopGeometry() - tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { + tool := chattool.NewComputerUseTool(chattool.ComputerUseProviderAnthropic, geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { return nil, xerrors.New("workspace not available") }, nil, quartz.NewReal(), slogtest.Make(t, nil)) @@ -406,7 +523,7 @@ func TestComputerUseTool_Run_InvalidInput(t *testing.T) { t.Parallel() geometry := workspacesdk.DefaultDesktopGeometry() - tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { + tool := chattool.NewComputerUseTool(chattool.ComputerUseProviderAnthropic, geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) { return nil, xerrors.New("should not be called") }, nil, quartz.NewReal(), slogtest.Make(t, nil)) @@ -421,3 +538,561 @@ func TestComputerUseTool_Run_InvalidInput(t *testing.T) { assert.True(t, resp.IsError) assert.Contains(t, resp.Content, "invalid computer use input") } + +func TestComputerUseTool_Run_OpenAI_BatchedActions(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + geometry := workspacesdk.DefaultDesktopGeometry() + const screenshotPNG = "aW1hZ2UtZGF0YQ==" + actions := recordDesktopActions(t, mockConn, geometry, 16, screenshotPNG) + + tool := newOpenAIComputerUseTool(t, geometry, mockConn, nil, quartz.NewReal()) + resp, err := tool.Run(context.Background(), openAIComputerUseCall(`{ + "call_id":"call_batch", + "actions":[ + {"type":"screenshot"}, + {"type":"move","x":10,"y":20}, + {"type":"click","button":"left","x":30,"y":40}, + {"type":"click","button":"right","x":31,"y":41}, + {"type":"click","button":"middle","x":32,"y":42}, + {"type":"double_click","x":50,"y":60}, + {"type":"drag","path":[{"x":1,"y":2},{"x":3,"y":4},{"x":5,"y":6}]}, + {"type":"keypress","keys":["ctrl","s"]}, + {"type":"type","text":"hello"}, + {"type":"scroll","x":70,"y":80,"scroll_y":500,"scroll_x":-200} + ] + }`)) + require.NoError(t, err) + assert.Equal(t, "image", resp.Type) + assert.Equal(t, "image/png", resp.MediaType) + assert.False(t, resp.IsError) + expectedImage, err := base64.StdEncoding.DecodeString(screenshotPNG) + require.NoError(t, err) + assert.Equal(t, expectedImage, resp.Data) + attachments, err := chattool.AttachmentsFromMetadata(resp.Metadata) + require.NoError(t, err) + assert.Empty(t, attachments) + + require.Len(t, *actions, 16) + for _, action := range *actions { + assertDesktopActionScaled(t, geometry, action) + } + assertDesktopAction(t, (*actions)[0], "mouse_move", [2]int{10, 20}) + assertDesktopAction(t, (*actions)[1], "left_click", [2]int{30, 40}) + assertDesktopAction(t, (*actions)[2], "right_click", [2]int{31, 41}) + assertDesktopAction(t, (*actions)[3], "middle_click", [2]int{32, 42}) + assertDesktopAction(t, (*actions)[4], "double_click", [2]int{50, 60}) + assertDesktopAction(t, (*actions)[5], "mouse_move", [2]int{1, 2}) + assert.Equal(t, "left_mouse_down", (*actions)[6].Action) + assert.Nil(t, (*actions)[6].Coordinate) + assertDesktopAction(t, (*actions)[7], "mouse_move", [2]int{3, 4}) + assertDesktopAction(t, (*actions)[8], "mouse_move", [2]int{5, 6}) + assert.Equal(t, "left_mouse_up", (*actions)[9].Action) + assert.Nil(t, (*actions)[9].Coordinate) + assertTextAction(t, (*actions)[10], "key", "ctrl+s") + assertTextAction(t, (*actions)[11], "type", "hello") + assertDesktopAction(t, (*actions)[12], "mouse_move", [2]int{70, 80}) + assertScrollAction(t, (*actions)[13], [2]int{70, 80}, "down", 5) + assertScrollAction(t, (*actions)[14], [2]int{70, 80}, "left", 2) + assert.Equal(t, "screenshot", (*actions)[15].Action) + assert.Nil(t, (*actions)[15].Coordinate) +} + +func TestComputerUseTool_Run_OpenAI_EmptyActionsCapturesScreenshotAndStoresAttachment(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + geometry := workspacesdk.DefaultDesktopGeometry() + const screenshotPNG = "ZmluYWwtc2NyZWVuc2hvdA==" + actions := recordDesktopActions(t, mockConn, geometry, 1, screenshotPNG) + + var storedName string + var storedData []byte + tool := newOpenAIComputerUseTool(t, geometry, mockConn, func(_ context.Context, name string, detectName string, data []byte) (chattool.AttachmentMetadata, error) { + storedName = name + require.Equal(t, name, detectName) + storedData = append([]byte(nil), data...) + return chattool.AttachmentMetadata{ + FileID: uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"), + MediaType: "image/png", + Name: name, + }, nil + }, quartz.NewReal()) + + resp, err := tool.Run(context.Background(), openAIComputerUseCall(`{ + "call_id":"call_empty", + "actions":[] + }`)) + require.NoError(t, err) + assert.Equal(t, "image", resp.Type) + require.Len(t, *actions, 1) + assert.Equal(t, "screenshot", (*actions)[0].Action) + assert.Contains(t, storedName, "screenshot-") + expectedData, err := base64.StdEncoding.DecodeString(screenshotPNG) + require.NoError(t, err) + assert.Equal(t, expectedData, storedData) + + attachments, err := chattool.AttachmentsFromMetadata(resp.Metadata) + require.NoError(t, err) + require.Len(t, attachments, 1) + assert.Equal(t, uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"), attachments[0].FileID) + assert.Equal(t, "image/png", attachments[0].MediaType) +} + +func TestComputerUseTool_Run_OpenAI_FinalScreenshotStoreErrorFallsBackToImage(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + geometry := workspacesdk.DefaultDesktopGeometry() + const screenshotPNG = "ZmluYWwtc2NyZWVuc2hvdA==" + recordDesktopActions(t, mockConn, geometry, 1, screenshotPNG) + + tool := newOpenAIComputerUseTool(t, geometry, mockConn, func(_ context.Context, _ string, _ string, _ []byte) (chattool.AttachmentMetadata, error) { + return chattool.AttachmentMetadata{}, xerrors.New("chat already has the maximum of 20 linked files") + }, quartz.NewReal()) + + resp, err := tool.Run(context.Background(), openAIComputerUseCall(`{ + "call_id":"call_store_error", + "actions":[{"type":"screenshot"}] + }`)) + require.NoError(t, err) + assert.Equal(t, "image", resp.Type) + assert.Equal(t, "image/png", resp.MediaType) + assert.False(t, resp.IsError) + attachments, err := chattool.AttachmentsFromMetadata(resp.Metadata) + require.NoError(t, err) + assert.Empty(t, attachments) +} + +func TestComputerUseTool_Run_OpenAI_DragReleaseFailureRetriesMouseUp(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + geometry := workspacesdk.DefaultDesktopGeometry() + + gomock.InOrder( + mockConn.EXPECT().ExecuteDesktopAction( + gomock.Any(), + gomock.AssignableToTypeOf(workspacesdk.DesktopAction{}), + ).DoAndReturn(func(_ context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) { + assertDesktopAction(t, action, "mouse_move", [2]int{1, 2}) + assertDesktopActionScaled(t, geometry, action) + return workspacesdk.DesktopActionResponse{Output: "mouse_move performed"}, nil + }), + mockConn.EXPECT().ExecuteDesktopAction( + gomock.Any(), + gomock.AssignableToTypeOf(workspacesdk.DesktopAction{}), + ).DoAndReturn(func(_ context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) { + assert.Equal(t, "left_mouse_down", action.Action) + assertDesktopActionScaled(t, geometry, action) + return workspacesdk.DesktopActionResponse{Output: "mouse_down performed"}, nil + }), + mockConn.EXPECT().ExecuteDesktopAction( + gomock.Any(), + gomock.AssignableToTypeOf(workspacesdk.DesktopAction{}), + ).DoAndReturn(func(_ context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) { + assertDesktopAction(t, action, "mouse_move", [2]int{3, 4}) + assertDesktopActionScaled(t, geometry, action) + return workspacesdk.DesktopActionResponse{Output: "mouse_move performed"}, nil + }), + mockConn.EXPECT().ExecuteDesktopAction( + gomock.Any(), + gomock.AssignableToTypeOf(workspacesdk.DesktopAction{}), + ).DoAndReturn(func(_ context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) { + assert.Equal(t, "left_mouse_up", action.Action) + assertDesktopActionScaled(t, geometry, action) + return workspacesdk.DesktopActionResponse{}, xerrors.New("release failed") + }), + mockConn.EXPECT().ExecuteDesktopAction( + gomock.Any(), + gomock.AssignableToTypeOf(workspacesdk.DesktopAction{}), + ).DoAndReturn(func(_ context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) { + assert.Equal(t, "left_mouse_up", action.Action) + assertDesktopActionScaled(t, geometry, action) + return workspacesdk.DesktopActionResponse{Output: "mouse_up performed"}, nil + }), + ) + + tool := newOpenAIComputerUseTool(t, geometry, mockConn, nil, quartz.NewReal()) + resp, err := tool.Run(context.Background(), openAIComputerUseCall(`{ + "call_id":"call_release_failure", + "actions":[{"type":"drag","path":[{"x":1,"y":2},{"x":3,"y":4}]}] + }`)) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Contains(t, resp.Content, `action "left_mouse_up" failed`) +} + +func TestComputerUseTool_Run_OpenAI_ActionFailureSkipsFinalScreenshot(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + geometry := workspacesdk.DefaultDesktopGeometry() + + gomock.InOrder( + mockConn.EXPECT().ExecuteDesktopAction( + gomock.Any(), + gomock.AssignableToTypeOf(workspacesdk.DesktopAction{}), + ).DoAndReturn(func(_ context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) { + assertDesktopAction(t, action, "mouse_move", [2]int{10, 20}) + assertDesktopActionScaled(t, geometry, action) + return workspacesdk.DesktopActionResponse{Output: "mouse_move performed"}, nil + }), + mockConn.EXPECT().ExecuteDesktopAction( + gomock.Any(), + gomock.AssignableToTypeOf(workspacesdk.DesktopAction{}), + ).DoAndReturn(func(_ context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) { + assertTextAction(t, action, "type", "fail") + assertDesktopActionScaled(t, geometry, action) + return workspacesdk.DesktopActionResponse{}, xerrors.New("desktop failed") + }), + ) + + tool := newOpenAIComputerUseTool(t, geometry, mockConn, nil, quartz.NewReal()) + resp, err := tool.Run(context.Background(), openAIComputerUseCall(`{ + "call_id":"call_failure", + "actions":[ + {"type":"move","x":10,"y":20}, + {"type":"type","text":"fail"} + ] + }`)) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Contains(t, resp.Content, `action "type" failed`) +} + +func TestComputerUseTool_Run_OpenAI_UnsupportedClickButtons(t *testing.T) { + t.Parallel() + + for _, button := range []string{"extra"} { + t.Run(button, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + geometry := workspacesdk.DefaultDesktopGeometry() + tool := newOpenAIComputerUseTool(t, geometry, mockConn, nil, quartz.NewReal()) + + resp, err := tool.Run(context.Background(), openAIComputerUseCall(`{ + "call_id":"call_unsupported_button", + "actions":[{"type":"click","button":"`+button+`","x":10,"y":20}] + }`)) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Contains(t, resp.Content, "unsupported OpenAI click button") + }) + } +} + +func TestComputerUseTool_Run_OpenAI_WheelClickIsMiddle(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + geometry := workspacesdk.DefaultDesktopGeometry() + actions := recordDesktopActions(t, mockConn, geometry, 2, "d2hlZWwtY2xpY2s=") + tool := newOpenAIComputerUseTool(t, geometry, mockConn, nil, quartz.NewReal()) + + resp, err := tool.Run(context.Background(), openAIComputerUseCall(`{ + "call_id":"call_wheel_click", + "actions":[{"type":"click","button":"wheel","x":10,"y":20}] + }`)) + require.NoError(t, err) + assert.False(t, resp.IsError) + require.Len(t, *actions, 2) + assertDesktopAction(t, (*actions)[0], "middle_click", [2]int{10, 20}) + assert.Equal(t, "screenshot", (*actions)[1].Action) +} + +func TestComputerUseTool_Run_OpenAI_UnsupportedActionType(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + geometry := workspacesdk.DefaultDesktopGeometry() + tool := newOpenAIComputerUseTool(t, geometry, mockConn, nil, quartz.NewReal()) + + resp, err := tool.Run(context.Background(), openAIComputerUseCall(`{ + "call_id":"call_unknown_action", + "actions":[{"type":"hover","x":10,"y":20}] + }`)) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Contains(t, resp.Content, `unsupported OpenAI computer action type "hover"`) +} + +func TestComputerUseTool_Run_OpenAI_InvalidInput(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + geometry := workspacesdk.DefaultDesktopGeometry() + tool := newOpenAIComputerUseTool(t, geometry, mockConn, nil, quartz.NewReal()) + + resp, err := tool.Run(context.Background(), openAIComputerUseCall(`{invalid json`)) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Contains(t, resp.Content, "invalid") +} + +func TestComputerUseTool_Run_OpenAI_DragRequiresTwoPoints(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + geometry := workspacesdk.DefaultDesktopGeometry() + tool := newOpenAIComputerUseTool(t, geometry, mockConn, nil, quartz.NewReal()) + + resp, err := tool.Run(context.Background(), openAIComputerUseCall(`{ + "call_id":"call_short_drag", + "actions":[{"type":"drag","path":[{"x":10,"y":20}]}] + }`)) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Contains(t, resp.Content, "requires at least two path points") +} + +func TestComputerUseTool_Run_OpenAI_KeyNormalization(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + keysJSON string + wantText string + }{ + {name: "ctrl s", keysJSON: `["ctrl","s"]`, wantText: "ctrl+s"}, + {name: "modifier aliases", keysJSON: `["control","shift","alt","command","A"]`, wantText: "ctrl+shift+alt+meta+a"}, + {name: "special keys", keysJSON: `["enter","escape","tab","space","backspace","delete"]`, wantText: "Return+Escape+Tab+space+BackSpace+Delete"}, + {name: "arrows", keysJSON: `["ArrowUp","arrowdown","left","Right"]`, wantText: "Up+Down+Left+Right"}, + {name: "function letters digits", keysJSON: `["f1","F12","5","Z"]`, wantText: "F1+F12+5+z"}, + {name: "minus key", keysJSON: `["-"]`, wantText: "-"}, + {name: "equals key", keysJSON: `["="]`, wantText: "="}, + {name: "slash key", keysJSON: `["/"]`, wantText: "/"}, + {name: "period key", keysJSON: `["."]`, wantText: "."}, + {name: "left bracket key", keysJSON: `["["]`, wantText: "["}, + {name: "right bracket key", keysJSON: `["]"]`, wantText: "]"}, + {name: "semicolon key", keysJSON: `[";"]`, wantText: ";"}, + {name: "apostrophe key", keysJSON: `["'"]`, wantText: "'"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + geometry := workspacesdk.DefaultDesktopGeometry() + actions := recordDesktopActions(t, mockConn, geometry, 2, "a2V5LWltYWdl") + tool := newOpenAIComputerUseTool(t, geometry, mockConn, nil, quartz.NewReal()) + + resp, err := tool.Run(context.Background(), openAIComputerUseCall(`{ + "call_id":"call_key", + "actions":[{"type":"keypress","keys":`+tt.keysJSON+`}] + }`)) + require.NoError(t, err) + assert.False(t, resp.IsError) + require.Len(t, *actions, 2) + assertTextAction(t, (*actions)[0], "key", tt.wantText) + assert.Equal(t, "screenshot", (*actions)[1].Action) + }) + } +} + +func TestComputerUseTool_Run_OpenAI_KeyNormalizationErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + keysJSON string + want string + }{ + {name: "empty array", keysJSON: `[]`, want: "requires at least one key"}, + {name: "empty token", keysJSON: `["ctrl",""]`, want: "contains an empty key"}, + {name: "unsupported multi-rune", keysJSON: `["ab"]`, want: `unsupported OpenAI keypress "ab"`}, + {name: "unsupported function key", keysJSON: `["f99"]`, want: `unsupported OpenAI keypress "f99"`}, + {name: "unsupported named key", keysJSON: `["PageDown"]`, want: `unsupported OpenAI keypress "PageDown"`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + geometry := workspacesdk.DefaultDesktopGeometry() + tool := newOpenAIComputerUseTool(t, geometry, mockConn, nil, quartz.NewReal()) + + resp, err := tool.Run(context.Background(), openAIComputerUseCall(`{ + "call_id":"call_key_error", + "actions":[{"type":"keypress","keys":`+tt.keysJSON+`}] + }`)) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Contains(t, resp.Content, tt.want) + }) + } +} + +func TestComputerUseTool_Run_OpenAI_WaitUsesMockClock(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + geometry := workspacesdk.DefaultDesktopGeometry() + mClock := quartz.NewMock(t) + const screenshotPNG = "d2FpdC1zY3JlZW5zaG90" + + mockConn.EXPECT().ExecuteDesktopAction( + gomock.Any(), + gomock.AssignableToTypeOf(workspacesdk.DesktopAction{}), + ).DoAndReturn(func(_ context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) { + assert.Equal(t, "screenshot", action.Action) + assertDesktopActionScaled(t, geometry, action) + return workspacesdk.DesktopActionResponse{ + Output: "screenshot", + ScreenshotData: screenshotPNG, + ScreenshotWidth: geometry.DeclaredWidth, + ScreenshotHeight: geometry.DeclaredHeight, + }, nil + }).Times(1) + + trap := mClock.Trap().NewTimer("computeruse", "wait") + tool := newOpenAIComputerUseTool(t, geometry, mockConn, nil, mClock) + + type toolResult struct { + resp fantasy.ToolResponse + err error + } + resultCh := make(chan toolResult, 1) + go func() { + resp, err := tool.Run(ctx, openAIComputerUseCall(`{ + "call_id":"call_wait", + "actions":[{"type":"wait"}] + }`)) + resultCh <- toolResult{resp: resp, err: err} + }() + + trap.MustWait(ctx).MustRelease(ctx) + trap.Close() + mClock.Advance(time.Second).MustWait(ctx) + + result := testutil.RequireReceive(ctx, t, resultCh) + require.NoError(t, result.err) + assert.Equal(t, "image", result.resp.Type) + assert.Equal(t, "image/png", result.resp.MediaType) + assert.False(t, result.resp.IsError) +} + +func newOpenAIComputerUseTool( + t testing.TB, + geometry workspacesdk.DesktopGeometry, + conn workspacesdk.AgentConn, + storeFile chattool.StoreFileFunc, + clock quartz.Clock, +) fantasy.AgentTool { + t.Helper() + return chattool.NewComputerUseTool( + chattool.ComputerUseProviderOpenAI, + geometry.DeclaredWidth, + geometry.DeclaredHeight, + func(_ context.Context) (workspacesdk.AgentConn, error) { + return conn, nil + }, + storeFile, + clock, + slogtest.Make(t, nil), + ) +} + +func openAIComputerUseCall(input string) fantasy.ToolCall { + return fantasy.ToolCall{ + ID: "openai-call", + Name: "computer", + Input: input, + } +} + +func recordDesktopActions( + t testing.TB, + mockConn *agentconnmock.MockAgentConn, + geometry workspacesdk.DesktopGeometry, + times int, + screenshotPNG string, +) *[]workspacesdk.DesktopAction { + t.Helper() + actions := make([]workspacesdk.DesktopAction, 0, times) + mockConn.EXPECT().ExecuteDesktopAction( + gomock.Any(), + gomock.AssignableToTypeOf(workspacesdk.DesktopAction{}), + ).DoAndReturn(func(_ context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) { + actions = append(actions, action) + if action.Action == "screenshot" { + return workspacesdk.DesktopActionResponse{ + Output: "screenshot", + ScreenshotData: screenshotPNG, + ScreenshotWidth: geometry.DeclaredWidth, + ScreenshotHeight: geometry.DeclaredHeight, + }, nil + } + return workspacesdk.DesktopActionResponse{Output: action.Action + " performed"}, nil + }).Times(times) + return &actions +} + +func assertDesktopActionScaled( + t testing.TB, + geometry workspacesdk.DesktopGeometry, + action workspacesdk.DesktopAction, +) { + t.Helper() + require.NotNil(t, action.ScaledWidth) + require.NotNil(t, action.ScaledHeight) + assert.Equal(t, geometry.DeclaredWidth, *action.ScaledWidth) + assert.Equal(t, geometry.DeclaredHeight, *action.ScaledHeight) +} + +func assertDesktopAction( + t testing.TB, + action workspacesdk.DesktopAction, + actionName string, + coordinate [2]int, +) { + t.Helper() + assert.Equal(t, actionName, action.Action) + require.NotNil(t, action.Coordinate) + assert.Equal(t, coordinate, *action.Coordinate) +} + +func assertTextAction( + t testing.TB, + action workspacesdk.DesktopAction, + actionName string, + text string, +) { + t.Helper() + assert.Equal(t, actionName, action.Action) + require.NotNil(t, action.Text) + assert.Equal(t, text, *action.Text) +} + +func assertScrollAction( + t testing.TB, + action workspacesdk.DesktopAction, + coordinate [2]int, + direction string, + amount int, +) { + t.Helper() + assertDesktopAction(t, action, "scroll", coordinate) + require.NotNil(t, action.ScrollDirection) + require.NotNil(t, action.ScrollAmount) + assert.Equal(t, direction, *action.ScrollDirection) + assert.Equal(t, amount, *action.ScrollAmount) +} diff --git a/coderd/x/chatd/computer_use.go b/coderd/x/chatd/computer_use.go new file mode 100644 index 0000000000..d41214f558 --- /dev/null +++ b/coderd/x/chatd/computer_use.go @@ -0,0 +1,167 @@ +package chatd + +import ( + "context" + "strings" + + "charm.land/fantasy" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/x/chatd/chatloop" + openaicomputeruse "github.com/coder/coder/v2/coderd/x/chatd/chatopenai/computeruse" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/coderd/x/chatd/chattool" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/quartz" +) + +// computerUseConfigContext lets internal and worker callers read +// deployment-wide chat settings when they lack an HTTP-derived actor. HTTP +// handlers always carry an actor, so the AsChatd fallback never elevates user +// contexts and this function is a no-op in that path. The setting it gates is +// global and readable by any authenticated actor, not a back-door. +func computerUseConfigContext(ctx context.Context) context.Context { + if _, ok := dbauthz.ActorFromContext(ctx); ok { + return ctx + } + //nolint:gocritic // Worker contexts may lack an actor. + return dbauthz.AsChatd(ctx) +} + +func (p *Server) computerUseProviderAndModelFromConfig( + ctx context.Context, +) (provider, modelProvider, modelName string, err error) { + rawProvider, err := p.db.GetChatComputerUseProvider( + computerUseConfigContext(ctx), + ) + if err != nil { + return "", "", "", xerrors.Errorf("get computer use provider: %w", err) + } + + provider = strings.TrimSpace(rawProvider) + if provider == "" { + provider = chattool.ComputerUseProviderAnthropic + } + + modelProvider, modelName, ok := chattool.DefaultComputerUseModel(provider) + if !ok { + return "", "", "", xerrors.Errorf( + "unknown computer-use provider %q configured in agents_computer_use_provider", + provider, + ) + } + + return provider, modelProvider, modelName, nil +} + +func (p *Server) resolveComputerUseModel( + ctx context.Context, + chat database.Chat, + providerKeys chatprovider.ProviderAPIKeys, + computerUseProvider string, + computerUseModelProvider string, + computerUseModelName string, +) ( + model fantasy.LanguageModel, + debugEnabled bool, + resolvedProvider string, + resolvedModel string, + err error, +) { + resolvedProvider, resolvedModel, err = chatprovider.ResolveModelWithProviderHint( + computerUseModelName, + computerUseModelProvider, + ) + if err != nil { + return nil, false, "", "", xerrors.Errorf( + "resolve computer use model metadata for provider %q model %q: %w", + computerUseProvider, + computerUseModelName, + err, + ) + } + + model, debugEnabled, err = p.newDebugAwareModelFromConfig( + ctx, + chat, + computerUseModelProvider, + computerUseModelName, + providerKeys, + chatprovider.UserAgent(), + chatprovider.CoderHeaders(chat), + ) + if err != nil { + return nil, false, "", "", xerrors.Errorf( + "resolve computer use model for provider %q model %q: %w", + computerUseProvider, + computerUseModelName, + err, + ) + } + + return model, debugEnabled, resolvedProvider, resolvedModel, nil +} + +type computerUseProviderToolOptions struct { + provider string + isPlanModeTurn bool + isComputerUse bool + getWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error) + storeFile chattool.StoreFileFunc + clock quartz.Clock + logger slog.Logger +} + +func appendComputerUseProviderTool( + providerTools []chatloop.ProviderTool, + opts computerUseProviderToolOptions, +) ([]chatloop.ProviderTool, error) { + // This helper is called for every chat turn. Only chats created by the + // computer_use subagent definition have ChatModeComputerUse, which filters + // out root, general, and explore chats. Plan mode is separate from Mode, so + // planning turns stay gated even for computer-use chats. + if opts.isPlanModeTurn || !opts.isComputerUse { + return providerTools, nil + } + + desktopGeometry := chattool.DefaultComputerUseDesktopGeometry(opts.provider) + definition, err := chattool.ComputerUseProviderTool( + opts.provider, + desktopGeometry.DeclaredWidth, + desktopGeometry.DeclaredHeight, + ) + if err != nil { + return providerTools, xerrors.Errorf( + "build computer use provider tool for provider %q: %w", + opts.provider, + err, + ) + } + + clock := opts.clock + if clock == nil { + clock = quartz.NewReal() + } + providerTool := chatloop.ProviderTool{ + Definition: definition, + Runner: chattool.NewComputerUseTool( + opts.provider, + desktopGeometry.DeclaredWidth, + desktopGeometry.DeclaredHeight, + opts.getWorkspaceConn, + opts.storeFile, + clock, + opts.logger, + ), + } + if opts.provider == chattool.ComputerUseProviderOpenAI { + // OpenAI computer-use image results need detail metadata so the model receives + // the screenshot at original detail when the chat loop sends the tool result. + providerTool.ResultProviderMetadata = openaicomputeruse.ResultProviderMetadata + } + + return append(providerTools, providerTool), nil +} diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index ccf0d01446..4f1207bda5 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -77,22 +77,28 @@ type closeAgentArgs struct { ChatID string `json:"chat_id"` } -// isAnthropicConfigured reports whether an Anthropic API key is -// available, either from static provider keys or from the database. -func (p *Server) isAnthropicConfigured(ctx context.Context) bool { - if p.providerAPIKeys.APIKey("anthropic") != "" { - return true +// providerConfigured reports whether a provider has an API key from +// static configuration or from the database provider configuration. +func (p *Server) providerConfigured(ctx context.Context, provider string) (bool, error) { + normalizedProvider := chatprovider.NormalizeProvider(provider) + if normalizedProvider == "" { + return false, nil } + if p.providerAPIKeys.APIKey(normalizedProvider) != "" { + return true, nil + } + dbProviders, err := p.configCache.EnabledProviders(ctx) if err != nil { - return false + return false, xerrors.Errorf("list enabled chat providers: %w", err) } for _, prov := range dbProviders { - if chatprovider.NormalizeProvider(prov.Provider) == "anthropic" && strings.TrimSpace(prov.APIKey) != "" { - return true + if chatprovider.NormalizeProvider(prov.Provider) == normalizedProvider && + strings.TrimSpace(prov.APIKey) != "" { + return true, nil } } - return false + return false, nil } func (p *Server) isDesktopEnabled(ctx context.Context) bool { diff --git a/coderd/x/chatd/subagent_catalog.go b/coderd/x/chatd/subagent_catalog.go index eea0004475..895bd8ec16 100644 --- a/coderd/x/chatd/subagent_catalog.go +++ b/coderd/x/chatd/subagent_catalog.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" "golang.org/x/xerrors" + "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/codersdk" ) @@ -103,12 +104,34 @@ func allSubagentDefinitions() []subagentDefinition { if currentChat.PlanMode.Valid && currentChat.PlanMode.ChatPlanMode == database.ChatPlanModePlan { return `type "computer_use" is unavailable in plan mode` } - if !p.isAnthropicConfigured(ctx) || !p.isDesktopEnabled(ctx) { - return `type "computer_use" is unavailable because computer use is not configured` + if !p.isDesktopEnabled(ctx) { + return `type "computer_use" is unavailable because desktop access is not enabled` + } + _, _, _, err := p.computerUseProviderAndModelFromConfig(ctx) + if err != nil { + p.logger.Warn(ctx, "computer-use provider config is unavailable", + slog.F("chat_id", currentChat.ID), + slog.Error(err), + ) + return `type "computer_use" is unavailable because its provider configuration could not be loaded` } return "" }, - buildOptions: func(_ context.Context, _ *Server, _ database.Chat, _ database.Chat, _ uuid.UUID, prompt string) (childSubagentChatOptions, error) { + buildOptions: func(ctx context.Context, p *Server, _ database.Chat, _ database.Chat, _ uuid.UUID, prompt string) (childSubagentChatOptions, error) { + provider, _, _, err := p.computerUseProviderAndModelFromConfig(ctx) + if err != nil { + return childSubagentChatOptions{}, err + } + configured, err := p.providerConfigured(ctx, provider) + if err != nil { + return childSubagentChatOptions{}, err + } + if !configured { + return childSubagentChatOptions{}, xerrors.Errorf( + `API key for computer-use provider %q is not configured`, + provider, + ) + } return childSubagentChatOptions{ chatMode: database.NullChatMode{ ChatMode: database.ChatModeComputerUse, diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index 3827a894e4..5df6069209 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -1325,7 +1325,7 @@ func TestSpawnAgent_DescriptionListsAllAvailableTypes(t *testing.T) { require.Contains(t, description, subagentTypeComputerUse) } -func TestSpawnAgent_DescriptionOmitsComputerUseWhenUnavailable(t *testing.T) { +func TestSpawnAgent_DescriptionIncludesComputerUseWithMissingProviderKey(t *testing.T) { t.Parallel() db, ps := dbtestutil.NewDB(t) @@ -1335,7 +1335,7 @@ func TestSpawnAgent_DescriptionOmitsComputerUseWhenUnavailable(t *testing.T) { ctx := chatdTestContext(t) user, org, model := seedInternalChatDeps(t, db) parentChat := createInternalParentChat( - ctx, t, server, db, org.ID, user.ID, model.ID, "parent-description-unavailable", + ctx, t, server, db, org.ID, user.ID, model.ID, "parent-description-missing-key", ) tools := server.subagentTools(ctx, func() database.Chat { return parentChat }, parentChat.LastModelConfigID) @@ -1344,7 +1344,7 @@ func TestSpawnAgent_DescriptionOmitsComputerUseWhenUnavailable(t *testing.T) { description := tool.Info().Description require.Contains(t, description, subagentTypeGeneral) require.Contains(t, description, subagentTypeExplore) - require.NotContains(t, description, subagentTypeComputerUse) + require.Contains(t, description, subagentTypeComputerUse) } func TestSpawnAgent_PlanModeDescriptionOmitsComputerUse(t *testing.T) { @@ -1429,7 +1429,7 @@ func TestPlanningOverlaySubagentGuidance_UsesPlanModeSafeDescriptions(t *testing require.NotContains(t, guidance, "may inspect or modify workspace files") } -func TestSpawnAgent_InvalidTypeAndUnavailableTypeAreDistinct(t *testing.T) { +func TestSpawnAgent_InvalidTypeAndCredentialErrorAreDistinct(t *testing.T) { t.Parallel() db, ps := dbtestutil.NewDB(t) @@ -1452,9 +1452,9 @@ func TestSpawnAgent_InvalidTypeAndUnavailableTypeAreDistinct(t *testing.T) { spawnAgentArgs{Type: "invalid", Prompt: "delegate work"}, ) require.True(t, invalidResp.IsError) - require.Contains(t, invalidResp.Content, "type must be one of: general, explore") + require.Contains(t, invalidResp.Content, "type must be one of: general, explore, computer_use") - unavailableResp := runSubagentTool( + credentialResp := runSubagentTool( ctx, t, server, @@ -1463,8 +1463,140 @@ func TestSpawnAgent_InvalidTypeAndUnavailableTypeAreDistinct(t *testing.T) { spawnAgentToolName, spawnAgentArgs{Type: subagentTypeComputerUse, Prompt: "open browser"}, ) - require.True(t, unavailableResp.IsError) - require.Contains(t, unavailableResp.Content, `type "computer_use" is unavailable because computer use is not configured`) + require.True(t, credentialResp.IsError) + require.Contains(t, credentialResp.Content, "API key") + require.Contains(t, credentialResp.Content, "computer-use") + require.Contains(t, credentialResp.Content, "anthropic") +} + +func TestSpawnAgent_ComputerUseAvailabilityUsesConfiguredProvider(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + require.NoError(t, db.UpsertChatDesktopEnabled(ctx, true)) + require.NoError(t, db.UpsertChatComputerUseProvider( + ctx, + chattool.ComputerUseProviderOpenAI, + )) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + user, org, model := seedInternalChatDeps(t, db) + parentChat := createInternalParentChat( + ctx, t, server, db, org.ID, user.ID, model.ID, "parent-openai-computer-use", + ) + + ids := availableSubagentTypeIDs(ctx, server, parentChat) + require.Contains(t, ids, subagentTypeComputerUse) +} + +func TestSpawnAgent_ComputerUseRejectsMissingConfiguredProvider(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + require.NoError(t, db.UpsertChatDesktopEnabled(ctx, true)) + require.NoError(t, db.UpsertChatComputerUseProvider( + ctx, + chattool.ComputerUseProviderOpenAI, + )) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + model := insertInternalChatModelConfigForProvider( + t, + db, + chattool.ComputerUseProviderOpenAI, + "gpt-4o-mini", + true, + ) + parentChat := createInternalParentChat( + ctx, t, server, db, org.ID, user.ID, model.ID, "parent-openai-missing", + ) + + ids := availableSubagentTypeIDs(ctx, server, parentChat) + require.Contains(t, ids, subagentTypeComputerUse) + beforeChats, err := db.GetChats(ctx, database.GetChatsParams{ + OwnerID: user.ID, + AfterID: uuid.Nil, + OffsetOpt: 0, + LimitOpt: 100, + }) + require.NoError(t, err) + + resp := runSpawnAgentTool(ctx, t, server, parentChat, spawnAgentArgs{ + Type: subagentTypeComputerUse, + Prompt: "open the browser", + }) + require.True(t, resp.IsError) + require.Contains(t, resp.Content, "API key") + require.Contains(t, resp.Content, "computer-use") + require.Contains(t, resp.Content, "openai") + afterChats, err := db.GetChats(ctx, database.GetChatsParams{ + OwnerID: user.ID, + AfterID: uuid.Nil, + OffsetOpt: 0, + LimitOpt: 100, + }) + require.NoError(t, err) + require.Len(t, afterChats, len(beforeChats)) +} + +func TestSpawnAgent_ComputerUseRejectsInvalidConfiguredProviderWithStableReason(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + require.NoError(t, db.UpsertChatDesktopEnabled(ctx, true)) + require.NoError(t, db.UpsertChatComputerUseProvider(ctx, "bogus")) + logSink := &subagentTestLogSink{} + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink) + server := newInternalTestServerWithLogger(t, db, ps, chatprovider.ProviderAPIKeys{}, logger) + + user, org, model := seedInternalChatDeps(t, db) + parentChat := createInternalParentChat( + ctx, t, server, db, org.ID, user.ID, model.ID, "parent-invalid-computer-use-provider", + ) + + resp := runSpawnAgentTool(ctx, t, server, parentChat, spawnAgentArgs{ + Type: subagentTypeComputerUse, + Prompt: "open the browser", + }) + require.True(t, resp.IsError) + require.Contains(t, resp.Content, `type "computer_use" is unavailable because its provider configuration could not be loaded`) + require.NotContains(t, resp.Content, "bogus") + require.NotContains(t, resp.Content, "agents_computer_use_provider") + require.NotEmpty(t, logSink.entriesAtLevelWithMessage( + slog.LevelWarn, + "computer-use provider config is unavailable", + )) +} + +func TestSpawnAgent_ComputerUseRejectsDesktopDisabled(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{ + Anthropic: "test-anthropic-key", + }) + + ctx := chatdTestContext(t) + user, org, model := seedInternalChatDeps(t, db) + parentChat := createInternalParentChat( + ctx, t, server, db, org.ID, user.ID, model.ID, "parent-desktop-disabled", + ) + + resp := runSpawnAgentTool(ctx, t, server, parentChat, spawnAgentArgs{ + Type: subagentTypeComputerUse, + Prompt: "open the browser", + }) + require.True(t, resp.IsError) + require.Contains(t, resp.Content, `type "computer_use" is unavailable because desktop access is not enabled`) } func TestSpawnAgent_BlankTypeReturnsValidOptions(t *testing.T) { @@ -1779,10 +1911,12 @@ func TestSpawnAgent_ComputerUseUsesComputerUseModelNotParent(t *testing.T) { require.Equal(t, parentChat.AgentID, childChat.AgentID) require.True(t, childChat.Mode.Valid) assert.Equal(t, database.ChatModeComputerUse, childChat.Mode.ChatMode) - assert.NotEqual(t, model.Provider, chattool.ComputerUseModelProvider, + computerUseModelProvider, computerUseModelName, ok := chattool.DefaultComputerUseModel(chattool.ComputerUseProviderAnthropic) + require.True(t, ok) + assert.NotEqual(t, model.Provider, computerUseModelProvider, "computer use model provider must differ from parent model provider") - assert.Equal(t, "anthropic", chattool.ComputerUseModelProvider) - assert.NotEmpty(t, chattool.ComputerUseModelName) + assert.Equal(t, "anthropic", computerUseModelProvider) + assert.NotEmpty(t, computerUseModelName) } func TestSpawnAgent_ComputerUseInheritsMCPServerIDs(t *testing.T) { diff --git a/codersdk/chats.go b/codersdk/chats.go index 1d8682b733..257fc6f9ab 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -672,6 +672,18 @@ type AdvisorConfig struct { // the request and response shapes are currently identical. type UpdateAdvisorConfigRequest = AdvisorConfig +// ChatComputerUseProviderResponse is the response for getting the computer use +// provider setting. +type ChatComputerUseProviderResponse struct { + Provider string `json:"provider"` +} + +// UpdateChatComputerUseProviderRequest is the request to update the computer use +// provider setting. +type UpdateChatComputerUseProviderRequest struct { + Provider string `json:"provider"` +} + // ChatDebugLoggingAdminSettings describes the runtime admin setting // that allows users to opt into chat debug logging. type ChatDebugLoggingAdminSettings struct { @@ -2206,6 +2218,34 @@ func (c *ExperimentalClient) UpdateChatAdvisorConfig(ctx context.Context, req Up return nil } +// GetChatComputerUseProvider returns the deployment-wide computer use provider. +func (c *ExperimentalClient) GetChatComputerUseProvider(ctx context.Context) (ChatComputerUseProviderResponse, error) { + res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/computer-use-provider", nil) + if err != nil { + return ChatComputerUseProviderResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return ChatComputerUseProviderResponse{}, ReadBodyAsError(res) + } + var resp ChatComputerUseProviderResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} + +// UpdateChatComputerUseProvider updates the deployment-wide computer use +// provider. +func (c *ExperimentalClient) UpdateChatComputerUseProvider(ctx context.Context, req UpdateChatComputerUseProviderRequest) error { + res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/computer-use-provider", req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusNoContent { + return ReadBodyAsError(res) + } + return nil +} + // GetChatWorkspaceTTL returns the configured chat workspace TTL. func (c *ExperimentalClient) GetChatWorkspaceTTL(ctx context.Context) (ChatWorkspaceTTLResponse, error) { res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/workspace-ttl", nil) diff --git a/codersdk/workspacesdk/display.go b/codersdk/workspacesdk/display.go index a3e4e4d040..7f180b4fee 100644 --- a/codersdk/workspacesdk/display.go +++ b/codersdk/workspacesdk/display.go @@ -12,6 +12,11 @@ const ( desktopDeclaredMaxLongEdge = 1568 desktopDeclaredMaxTotalPixels = 1_150_000 + + // OpenAI recommends 1440x900 or 1600x900 for computer use. + // Use 1600x900 so screenshots keep the native 16:9 aspect ratio. + desktopOpenAIComputerUseDeclaredWidth = 1600 + desktopOpenAIComputerUseDeclaredHeight = 900 ) var preferredDeclaredDesktopWidths = []int{1280, 1024} @@ -31,6 +36,17 @@ func DefaultDesktopGeometry() DesktopGeometry { return NewDesktopGeometry(DesktopNativeWidth, DesktopNativeHeight) } +// DefaultOpenAIComputerUseDesktopGeometry returns the default native desktop +// geometry with OpenAI's recommended computer-use declared dimensions. +func DefaultOpenAIComputerUseDesktopGeometry() DesktopGeometry { + return NewDesktopGeometryWithDeclared( + DesktopNativeWidth, + DesktopNativeHeight, + desktopOpenAIComputerUseDeclaredWidth, + desktopOpenAIComputerUseDeclaredHeight, + ) +} + // NewDesktopGeometry derives a declared model-facing geometry from the native // desktop size. func NewDesktopGeometry(nativeWidth, nativeHeight int) DesktopGeometry { diff --git a/codersdk/workspacesdk/display_test.go b/codersdk/workspacesdk/display_test.go index 6ff5606c86..69dae9f0cb 100644 --- a/codersdk/workspacesdk/display_test.go +++ b/codersdk/workspacesdk/display_test.go @@ -99,6 +99,19 @@ func TestDefaultDesktopGeometry(t *testing.T) { assert.Equal(t, 720, geometry.DeclaredHeight) } +// TestDefaultOpenAIComputerUseDesktopGeometry pins the model-facing coordinate +// system for OpenAI computer use so future geometry changes are intentional. +func TestDefaultOpenAIComputerUseDesktopGeometry(t *testing.T) { + t.Parallel() + + geometry := workspacesdk.DefaultOpenAIComputerUseDesktopGeometry() + + assert.Equal(t, 1920, geometry.NativeWidth) + assert.Equal(t, 1080, geometry.NativeHeight) + assert.Equal(t, 1600, geometry.DeclaredWidth) + assert.Equal(t, 900, geometry.DeclaredHeight) +} + func TestDesktopGeometryDeclaredPointToNative(t *testing.T) { t.Parallel() diff --git a/go.mod b/go.mod index 7dcd74dbf2..9d16338994 100644 --- a/go.mod +++ b/go.mod @@ -87,7 +87,7 @@ replace github.com/spf13/afero => github.com/aslilac/afero v0.0.0-20250403163713 // replay stored reasoning item references, only replay web_search references // when paired with reasoning, and validate function_call output pairing. // See: https://github.com/coder/fantasy/commits/f83367a4a205 -replace charm.land/fantasy => github.com/coder/fantasy v0.0.0-20260426185602-951a49c681df +replace charm.land/fantasy => github.com/coder/fantasy v0.0.0-20260427164812-d0e6ce2243af // coder/coder uses a fork of charmbracelet's fork of the Anthropic Go SDK // with performance improvements and Bedrock header cleanup. diff --git a/go.sum b/go.sum index 4037a81e94..d1c51d3d9f 100644 --- a/go.sum +++ b/go.sum @@ -322,8 +322,8 @@ github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41 h1:SBN/DA63+ZHwu github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41/go.mod h1:I9ULxr64UaOSUv7hcb3nX4kowodJCVS7vt7VVJk/kW4= github.com/coder/clistat v1.2.1 h1:P9/10njXMyj5cWzIU5wkRsSy5LVQH49+tcGMsAgWX0w= github.com/coder/clistat v1.2.1/go.mod h1:m7SC0uj88eEERgvF8Kn6+w6XF21BeSr+15f7GoLAw0A= -github.com/coder/fantasy v0.0.0-20260426185602-951a49c681df h1:Xog/dBDcnXxr98lGZqRxOeFrCrhVZUBrFldtXH7v0EY= -github.com/coder/fantasy v0.0.0-20260426185602-951a49c681df/go.mod h1:wZ0e3lEPqrM0XiIdAUQLvMKCLYhc3gi96MRX2wjbX44= +github.com/coder/fantasy v0.0.0-20260427164812-d0e6ce2243af h1:5X38dLzIc5FSgVm9EuKkuKgtXt4fNV5iSCraxfgQXns= +github.com/coder/fantasy v0.0.0-20260427164812-d0e6ce2243af/go.mod h1:wZ0e3lEPqrM0XiIdAUQLvMKCLYhc3gi96MRX2wjbX44= github.com/coder/flog v1.1.0 h1:kbAes1ai8fIS5OeV+QAnKBQE22ty1jRF/mcAwHpLBa4= github.com/coder/flog v1.1.0/go.mod h1:UQlQvrkJBvnRGo69Le8E24Tcl5SJleAAR7gYEHzAmdQ= github.com/coder/go-httpstat v0.0.0-20230801153223-321c88088322 h1:m0lPZjlQ7vdVpRBPKfYIFlmgevoTkBxB10wv6l2gOaU= diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 58b29eb768..48b1f7910d 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -3360,6 +3360,24 @@ class ExperimentalApiMethods { await this.axios.put("/api/experimental/chats/config/advisor", req); }; + getChatComputerUseProvider = + async (): Promise => { + const response = + await this.axios.get( + "/api/experimental/chats/config/computer-use-provider", + ); + return response.data; + }; + + updateChatComputerUseProvider = async ( + req: TypesGen.UpdateChatComputerUseProviderRequest, + ): Promise => { + await this.axios.put( + "/api/experimental/chats/config/computer-use-provider", + req, + ); + }; + getChatWorkspaceTTL = async (): Promise => { const response = await this.axios.get( diff --git a/site/src/api/queries/chats.ts b/site/src/api/queries/chats.ts index 26b0c1b8f2..e61a15afe8 100644 --- a/site/src/api/queries/chats.ts +++ b/site/src/api/queries/chats.ts @@ -1348,6 +1348,22 @@ export const updateChatAdvisorConfig = (queryClient: QueryClient) => ({ }, }); +const chatComputerUseProviderKey = ["chat-computer-use-provider"] as const; + +export const chatComputerUseProvider = () => ({ + queryKey: chatComputerUseProviderKey, + queryFn: () => API.experimental.getChatComputerUseProvider(), +}); + +export const updateChatComputerUseProvider = (queryClient: QueryClient) => ({ + mutationFn: API.experimental.updateChatComputerUseProvider, + onSuccess: async () => { + await queryClient.invalidateQueries({ + queryKey: chatComputerUseProviderKey, + }); + }, +}); + const chatWorkspaceTTLKey = ["chat-workspace-ttl"] as const; export const chatWorkspaceTTL = () => ({ diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index e395c71108..c996ef9188 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1346,6 +1346,15 @@ export const ChatClientTypes: ChatClientType[] = ["api", "ui"]; export const ChatCompactionThresholdKeyPrefix = "chat_compaction_threshold_pct:"; +// From codersdk/chats.go +/** + * ChatComputerUseProviderResponse is the response for getting the computer use + * provider setting. + */ +export interface ChatComputerUseProviderResponse { + readonly provider: string; +} + // From codersdk/deployment.go export interface ChatConfig { readonly acquire_batch_size: number; @@ -7817,6 +7826,15 @@ export interface UpdateChatAutoArchiveDaysRequest { readonly auto_archive_days: number; } +// From codersdk/chats.go +/** + * UpdateChatComputerUseProviderRequest is the request to update the computer use + * provider setting. + */ +export interface UpdateChatComputerUseProviderRequest { + readonly provider: string; +} + // From codersdk/chats.go /** * UpdateChatDebugLoggingAllowUsersRequest is the admin request to diff --git a/site/src/pages/AgentsPage/AgentSettingsExperimentsPage.tsx b/site/src/pages/AgentsPage/AgentSettingsExperimentsPage.tsx index e576783815..c577126afc 100644 --- a/site/src/pages/AgentsPage/AgentSettingsExperimentsPage.tsx +++ b/site/src/pages/AgentsPage/AgentSettingsExperimentsPage.tsx @@ -2,10 +2,12 @@ import type { FC } from "react"; import { useMutation, useQuery, useQueryClient } from "react-query"; import { chatAdvisorConfig, + chatComputerUseProvider, chatDebugLogging, chatDesktopEnabled, chatModelConfigs, updateChatAdvisorConfig, + updateChatComputerUseProvider, updateChatDebugLogging, updateChatDesktopEnabled, } from "#/api/queries/chats"; @@ -20,6 +22,10 @@ const AgentSettingsExperimentsPage: FC = () => { ...chatDesktopEnabled(), enabled: permissions.editDeploymentConfig, }); + const computerUseProviderQuery = useQuery({ + ...chatComputerUseProvider(), + enabled: permissions.editDeploymentConfig, + }); const debugLoggingQuery = useQuery({ ...chatDebugLogging(), enabled: permissions.editDeploymentConfig, @@ -35,6 +41,9 @@ const AgentSettingsExperimentsPage: FC = () => { const saveDesktopEnabledMutation = useMutation( updateChatDesktopEnabled(queryClient), ); + const saveComputerUseProviderMutation = useMutation( + updateChatComputerUseProvider(queryClient), + ); const saveDebugLoggingMutation = useMutation( updateChatDebugLogging(queryClient), ); @@ -46,10 +55,17 @@ const AgentSettingsExperimentsPage: FC = () => { ; +const getComputerUseProviderSelect = async (canvasElement: HTMLElement) => { + const canvas = within(canvasElement); + return canvas.findByRole("combobox", { + name: "Computer use provider", + }); +}; + +const selectComputerUseProvider = async ( + canvasElement: HTMLElement, + currentSelectionName: string, + optionName: string, +) => { + const trigger = await getComputerUseProviderSelect(canvasElement); + expect(trigger).toHaveTextContent(currentSelectionName); + + await userEvent.click(trigger); + const body = within(canvasElement.ownerDocument.body); + await userEvent.click(await body.findByRole("option", { name: optionName })); + await waitFor(() => expect(trigger).toHaveTextContent(optionName)); +}; + +function InteractiveComputerUseProviderStory( + args: AgentSettingsExperimentsPageViewProps, +) { + const [computerUseProviderData, setComputerUseProviderData] = useState( + args.computerUseProviderData, + ); + + return ( + { + if (options) { + args.onSaveComputerUseProvider(request, options); + } else { + args.onSaveComputerUseProvider(request); + } + setComputerUseProviderData({ provider: request.provider }); + }} + /> + ); +} + export const AllowUsersOff: Story = { play: async ({ canvasElement }) => { const canvas = within(canvasElement); @@ -110,6 +162,26 @@ export const DesktopSetting: Story = { }, }; +export const VirtualDesktopLoading: Story = { + args: { + desktopEnabledData: undefined, + isLoadingDesktopEnabled: true, + computerUseProviderData: undefined, + isLoadingComputerUseProvider: true, + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + + // While loading, the Switch is replaced by a skeleton placeholder. + expect( + canvas.queryByRole("switch", { name: "Enable" }), + ).not.toBeInTheDocument(); + + const providerSelect = await getComputerUseProviderSelect(canvasElement); + expect(providerSelect).toBeDisabled(); + }, +}; + export const TogglesDesktop: Story = { play: async ({ canvasElement, args }) => { const canvas = within(canvasElement); @@ -123,3 +195,84 @@ export const TogglesDesktop: Story = { }); }, }; + +export const ComputerUseProviderAnthropic: Story = { + args: { + desktopEnabledData: { enable_desktop: true }, + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + await canvas.findByText("Computer use provider"); + const providerSelect = await getComputerUseProviderSelect(canvasElement); + + expect(providerSelect).not.toBeDisabled(); + expect(providerSelect).toHaveTextContent("Anthropic"); + }, +}; + +export const ComputerUseProviderDisabledWhenDesktopDisabled: Story = { + play: async ({ canvasElement }) => { + const providerSelect = await getComputerUseProviderSelect(canvasElement); + + expect(providerSelect).toBeDisabled(); + }, +}; + +export const SelectsOpenAIProvider: Story = { + args: { + desktopEnabledData: { enable_desktop: true }, + onSaveComputerUseProvider: fn(), + }, + render: InteractiveComputerUseProviderStory, + play: async ({ canvasElement, args }) => { + await selectComputerUseProvider(canvasElement, "Anthropic", "OpenAI"); + + await waitFor(() => { + expect(args.onSaveComputerUseProvider).toHaveBeenCalledWith({ + provider: "openai", + }); + }); + }, +}; + +export const SelectsAnthropicProvider: Story = { + args: { + desktopEnabledData: { enable_desktop: true }, + computerUseProviderData: { provider: "openai" }, + onSaveComputerUseProvider: fn(), + }, + render: InteractiveComputerUseProviderStory, + play: async ({ canvasElement, args }) => { + await selectComputerUseProvider(canvasElement, "OpenAI", "Anthropic"); + + await waitFor(() => { + expect(args.onSaveComputerUseProvider).toHaveBeenCalledWith({ + provider: "anthropic", + }); + }); + }, +}; + +export const ComputerUseProviderSaveError: Story = { + args: { + computerUseProviderSaveError: new Error("Failed to save."), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + expect( + await canvas.findByText("Failed to save computer use provider."), + ).toBeInTheDocument(); + }, +}; + +export const ComputerUseProviderSaving: Story = { + args: { + desktopEnabledData: { enable_desktop: true }, + isSavingComputerUseProvider: true, + }, + play: async ({ canvasElement }) => { + const providerSelect = await getComputerUseProviderSelect(canvasElement); + + expect(providerSelect).toBeDisabled(); + }, +}; diff --git a/site/src/pages/AgentsPage/AgentSettingsExperimentsPageView.tsx b/site/src/pages/AgentsPage/AgentSettingsExperimentsPageView.tsx index 5671cbade9..afb36452d1 100644 --- a/site/src/pages/AgentsPage/AgentSettingsExperimentsPageView.tsx +++ b/site/src/pages/AgentsPage/AgentSettingsExperimentsPageView.tsx @@ -13,6 +13,7 @@ interface MutationCallbacks { export interface AgentSettingsExperimentsPageViewProps { desktopEnabledData: TypesGen.ChatDesktopEnabledResponse | undefined; + isLoadingDesktopEnabled: boolean; onSaveDesktopEnabled: UseMutateFunction< void, Error, @@ -21,7 +22,18 @@ export interface AgentSettingsExperimentsPageViewProps { >; isSavingDesktopEnabled: boolean; isSaveDesktopEnabledError: boolean; + computerUseProviderData: TypesGen.ChatComputerUseProviderResponse | undefined; + isLoadingComputerUseProvider: boolean; + onSaveComputerUseProvider: UseMutateFunction< + void, + Error, + TypesGen.UpdateChatComputerUseProviderRequest, + unknown + >; + isSavingComputerUseProvider: boolean; + computerUseProviderSaveError: Error | null; debugLoggingData: TypesGen.ChatDebugLoggingAdminSettings | undefined; + isLoadingDebugLogging: boolean; onSaveDebugLogging: UseMutateFunction< void, Error, @@ -51,10 +63,17 @@ export const AgentSettingsExperimentsPageView: FC< AgentSettingsExperimentsPageViewProps > = ({ desktopEnabledData, + isLoadingDesktopEnabled, onSaveDesktopEnabled, isSavingDesktopEnabled, isSaveDesktopEnabledError, + computerUseProviderData, + isLoadingComputerUseProvider, + onSaveComputerUseProvider, + isSavingComputerUseProvider, + computerUseProviderSaveError, debugLoggingData, + isLoadingDebugLogging, onSaveDebugLogging, isSavingDebugLogging, isSaveDebugLoggingError, @@ -79,9 +98,15 @@ export const AgentSettingsExperimentsPageView: FC< /> = ({ adminSettings, + isLoadingAdminSetting, onSaveAdminSetting, isSavingAdminSetting, isSaveAdminSettingError, @@ -48,14 +51,24 @@ export const AdminChatDebugLoggingSettings: FC<

)} - - onSaveAdminSetting({ allow_users: checked }) - } - aria-label="Allow users to enable chat debug logging" - disabled={forcedByDeployment || isSavingAdminSetting} - /> +
+ {isLoadingAdminSetting ? ( +
{isSaveAdminSettingError && (

diff --git a/site/src/pages/AgentsPage/components/VirtualDesktopSettings.tsx b/site/src/pages/AgentsPage/components/VirtualDesktopSettings.tsx index 999498831f..69d65914b3 100644 --- a/site/src/pages/AgentsPage/components/VirtualDesktopSettings.tsx +++ b/site/src/pages/AgentsPage/components/VirtualDesktopSettings.tsx @@ -3,6 +3,15 @@ import type { FC } from "react"; import type * as TypesGen from "#/api/typesGenerated"; import { Badge } from "#/components/Badge/Badge"; import { Link } from "#/components/Link/Link"; +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectTrigger, + SelectValue, +} from "#/components/Select/Select"; +import { Skeleton } from "#/components/Skeleton/Skeleton"; import { Switch } from "#/components/Switch/Switch"; interface MutationCallbacks { @@ -12,21 +21,57 @@ interface MutationCallbacks { interface VirtualDesktopSettingsProps { desktopEnabledData: TypesGen.ChatDesktopEnabledResponse | undefined; + isLoadingDesktopEnabled: boolean; onSaveDesktopEnabled: ( req: TypesGen.UpdateChatDesktopEnabledRequest, options?: MutationCallbacks, ) => void; isSavingDesktopEnabled: boolean; isSaveDesktopEnabledError: boolean; + computerUseProviderData: TypesGen.ChatComputerUseProviderResponse | undefined; + isLoadingComputerUseProvider: boolean; + onSaveComputerUseProvider: ( + req: TypesGen.UpdateChatComputerUseProviderRequest, + options?: MutationCallbacks, + ) => void; + isSavingComputerUseProvider: boolean; + computerUseProviderSaveError: Error | null; } +const computerUseProviderOptions = [ + { label: "Anthropic", value: "anthropic" }, + { label: "OpenAI", value: "openai" }, +] as const; + +const getComputerUseProviderLabel = (provider: string) => { + return ( + computerUseProviderOptions.find((option) => option.value === provider) + ?.label ?? provider + ); +}; + export const VirtualDesktopSettings: FC = ({ desktopEnabledData, + isLoadingDesktopEnabled, onSaveDesktopEnabled, isSavingDesktopEnabled, isSaveDesktopEnabledError, + computerUseProviderData, + isLoadingComputerUseProvider, + onSaveComputerUseProvider, + isSavingComputerUseProvider, + computerUseProviderSaveError, }) => { const desktopEnabled = desktopEnabledData?.enable_desktop ?? false; + const computerUseProvider = computerUseProviderData?.provider ?? ""; + const isDesktopSwitchDisabled = + isSavingDesktopEnabled || isLoadingDesktopEnabled; + const isComputerUseProviderSelectDisabled = + !desktopEnabled || + isSavingDesktopEnabled || + isLoadingDesktopEnabled || + isSavingComputerUseProvider || + isLoadingComputerUseProvider; return (

@@ -40,14 +85,20 @@ export const VirtualDesktopSettings: FC = ({ Experimental feature
- - onSaveDesktopEnabled({ enable_desktop: checked }) - } - aria-label="Enable" - disabled={isSavingDesktopEnabled} - /> +
+ {isLoadingDesktopEnabled ? ( +

@@ -60,15 +111,65 @@ export const VirtualDesktopSettings: FC = ({ > portabledesktop module {" "} - to be installed in the workspace and the Anthropic provider to be - configured. + to be installed in the workspace and the selected computer use + provider to be configured.

+
+
+

+ Computer use provider +

+

+ Select the provider agents use for computer-use actions when virtual + desktop is enabled. +

+
+ +
{isSaveDesktopEnabledError && (

Failed to save desktop setting.

)} + {computerUseProviderSaveError && ( +

+ Failed to save computer use provider. +

+ )} ); };