diff --git a/Makefile b/Makefile index 9723827011..6d71bb1826 100644 --- a/Makefile +++ b/Makefile @@ -1293,6 +1293,7 @@ coderd/apidoc/.gen: \ $(wildcard enterprise/coderd/*.go) \ $(wildcard codersdk/*.go) \ $(wildcard enterprise/wsproxy/wsproxysdk/*.go) \ + $(wildcard coderd/workspaceconnwatcher/*.go) \ $(DB_GEN_FILES) \ coderd/rbac/object_gen.go \ .swaggo \ diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index a2ecc72bb7..9a3bfc39d7 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -12783,6 +12783,41 @@ const docTemplate = `{ ] } }, + "/api/v2/workspaces/{workspace}/agent-connection-watch": { + "get": { + "produces": [ + "application/json" + ], + "tags": [ + "Workspaces" + ], + "summary": "Workspace Agent Connection Watch", + "operationId": "workspace-agent-connection-watch", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace ID", + "name": "workspace", + "in": "path", + "required": true + } + ], + "responses": { + "101": { + "description": "Switching Protocols", + "schema": { + "$ref": "#/definitions/workspacesdk.ConnectionWatchEvent" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + } + }, "/api/v2/workspaces/{workspace}/autostart": { "put": { "consumes": [ @@ -27964,6 +27999,93 @@ const docTemplate = `{ } } }, + "workspacesdk.AgentUpdate": { + "type": "object", + "properties": { + "id": { + "type": "string", + "format": "uuid" + }, + "lifecycle": { + "$ref": "#/definitions/codersdk.WorkspaceAgentLifecycle" + } + } + }, + "workspacesdk.BuildUpdate": { + "type": "object", + "properties": { + "job_status": { + "$ref": "#/definitions/codersdk.ProvisionerJobStatus" + }, + "transition": { + "$ref": "#/definitions/codersdk.WorkspaceTransition" + } + } + }, + "workspacesdk.ConnectionWatchEvent": { + "type": "object", + "properties": { + "agent_update": { + "$ref": "#/definitions/workspacesdk.AgentUpdate" + }, + "build_update": { + "$ref": "#/definitions/workspacesdk.BuildUpdate" + }, 
+ "error": { + "$ref": "#/definitions/workspacesdk.WatchError" + } + } + }, + "workspacesdk.WatchError": { + "type": "object", + "properties": { + "code": { + "$ref": "#/definitions/workspacesdk.WatchErrorCode" + }, + "details": { + "type": "string" + }, + "message": { + "type": "string" + }, + "retryable": { + "type": "boolean" + } + } + }, + "workspacesdk.WatchErrorCode": { + "type": "integer", + "enum": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6 + ], + "x-enum-comments": { + "_": "Ensure that zero value is not a valid code" + }, + "x-enum-descriptions": [ + "Ensure that zero value is not a valid code", + "", + "", + "", + "", + "", + "" + ], + "x-enum-varnames": [ + "_", + "WatchErrorTooManyAgents", + "WatchErrorNameNotFound", + "WatchErrorNoAgents", + "WatchErrorServerShutdown", + "WatchErrorDatabase", + "WatchErrorInternal" + ] + }, "wsproxysdk.CryptoKeysResponse": { "type": "object", "properties": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 688778a27a..17aae9c74e 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -11343,6 +11343,37 @@ ] } }, + "/api/v2/workspaces/{workspace}/agent-connection-watch": { + "get": { + "produces": ["application/json"], + "tags": ["Workspaces"], + "summary": "Workspace Agent Connection Watch", + "operationId": "workspace-agent-connection-watch", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace ID", + "name": "workspace", + "in": "path", + "required": true + } + ], + "responses": { + "101": { + "description": "Switching Protocols", + "schema": { + "$ref": "#/definitions/workspacesdk.ConnectionWatchEvent" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + } + }, "/api/v2/workspaces/{workspace}/autostart": { "put": { "consumes": ["application/json"], @@ -25823,6 +25854,85 @@ } } }, + "workspacesdk.AgentUpdate": { + "type": "object", + "properties": { + "id": { + "type": "string", + "format": "uuid" + }, + "lifecycle": { + 
"$ref": "#/definitions/codersdk.WorkspaceAgentLifecycle" + } + } + }, + "workspacesdk.BuildUpdate": { + "type": "object", + "properties": { + "job_status": { + "$ref": "#/definitions/codersdk.ProvisionerJobStatus" + }, + "transition": { + "$ref": "#/definitions/codersdk.WorkspaceTransition" + } + } + }, + "workspacesdk.ConnectionWatchEvent": { + "type": "object", + "properties": { + "agent_update": { + "$ref": "#/definitions/workspacesdk.AgentUpdate" + }, + "build_update": { + "$ref": "#/definitions/workspacesdk.BuildUpdate" + }, + "error": { + "$ref": "#/definitions/workspacesdk.WatchError" + } + } + }, + "workspacesdk.WatchError": { + "type": "object", + "properties": { + "code": { + "$ref": "#/definitions/workspacesdk.WatchErrorCode" + }, + "details": { + "type": "string" + }, + "message": { + "type": "string" + }, + "retryable": { + "type": "boolean" + } + } + }, + "workspacesdk.WatchErrorCode": { + "type": "integer", + "enum": [0, 1, 2, 3, 4, 5, 6], + "x-enum-comments": { + "_": "Ensure that zero value is not a valid code" + }, + "x-enum-descriptions": [ + "Ensure that zero value is not a valid code", + "", + "", + "", + "", + "", + "" + ], + "x-enum-varnames": [ + "_", + "WatchErrorTooManyAgents", + "WatchErrorNameNotFound", + "WatchErrorNoAgents", + "WatchErrorServerShutdown", + "WatchErrorDatabase", + "WatchErrorInternal" + ] + }, "wsproxysdk.CryptoKeysResponse": { "type": "object", "properties": { diff --git a/coderd/coderd.go b/coderd/coderd.go index dedd032d80..9f1d85bca7 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -92,6 +92,7 @@ import ( "github.com/coder/coder/v2/coderd/webpush" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" + "github.com/coder/coder/v2/coderd/workspaceconnwatcher" "github.com/coder/coder/v2/coderd/workspacestats" "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/coder/v2/coderd/x/chatd" @@ -923,6 +924,8 @@ func New(options *Options) *API { 
APIKeyEncryptionKeycache: options.AppEncryptionKeyCache, }) + api.workspaceAgentConnWatcher = workspaceconnwatcher.New(api.ctx, options.Logger, options.Pubsub, options.Database) + apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ DB: options.Database, ActivateDormantUser: ActivateDormantUser(options.Logger, &api.Auditor, options.Database), @@ -1820,6 +1823,7 @@ func New(options *Options) *API { r.Patch("/", api.patchWorkspaceACL) r.Delete("/", api.deleteWorkspaceACL) }) + r.Get("/agent-connection-watch", api.workspaceAgentConnWatcher.WorkspaceAgentConnectionWatch) }) }) r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) { @@ -2238,6 +2242,8 @@ type API struct { // profile collection (via /debug/profile) can run at a time. The CPU // profiler is process-global, so concurrent collections would fail. ProfileCollecting atomic.Bool + + workspaceAgentConnWatcher *workspaceconnwatcher.Watcher } // Close waits for all WebSocket connections to drain before returning. 
@@ -2301,6 +2307,7 @@ func (api *API) Close() error { _ = api.AppSigningKeyCache.Close() _ = api.AppEncryptionKeyCache.Close() _ = api.UpdatesProvider.Close() + api.workspaceAgentConnWatcher.Close() if current := api.PrebuildsReconciler.Load(); current != nil { ctx, giveUp := context.WithTimeoutCause(context.Background(), time.Second*30, xerrors.New("gave up waiting for reconciler to stop before shutdown")) diff --git a/coderd/coderdtest/database.go b/coderd/coderdtest/database.go new file mode 100644 index 0000000000..2071e99178 --- /dev/null +++ b/coderd/coderdtest/database.go @@ -0,0 +1,28 @@ +package coderdtest + +import ( + "sync/atomic" + "testing" + + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/mock/gomock" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/rbac" +) + +func MockedDatabaseWithAuthz(t testing.TB, logger slog.Logger) (*gomock.Controller, *dbmock.MockStore, database.Store, rbac.Authorizer) { + ctrl := gomock.NewController(t) + mDB := dbmock.NewMockStore(ctrl) + auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry()) + accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{} + var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{} + accessControlStore.Store(&acs) + // dbauthz will call Wrappers() to check for wrapped databases + mDB.EXPECT().Wrappers().Return([]string{}).AnyTimes() + authDB := dbauthz.New(mDB, auth, logger, accessControlStore) + return ctrl, mDB, authDB, auth +} diff --git a/coderd/coderdtest/subjects.go b/coderd/coderdtest/subjects.go new file mode 100644 index 0000000000..97d61af42b --- /dev/null +++ b/coderd/coderdtest/subjects.go @@ -0,0 +1,31 @@ +package coderdtest + +import ( + "github.com/google/uuid" + + "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/rbac/rolestore" +) + 
+func MemberSubject(userID, orgID uuid.UUID) rbac.Subject { + memberRole, err := rbac.RoleByName(rbac.RoleMember()) + if err != nil { + panic(err) + } + orgMember, err := rolestore.TestingGetSystemRole( + rbac.RoleOrgMember(), + orgID, + rbac.OrgSettings{ShareableWorkspaceOwners: rbac.ShareableWorkspaceOwnersNone}, + ) + if err != nil { + panic(err) + } + return rbac.Subject{ + FriendlyName: "coderdtest-member", + Email: "member@coderd.test", + Type: rbac.SubjectTypeUser, + ID: userID.String(), + Roles: rbac.Roles{memberRole, orgMember}, + Scope: rbac.ScopeAll, + }.WithCachedASTValue() +} diff --git a/coderd/coderdtest/swagger_test.go b/coderd/coderdtest/swagger_test.go index 5f43eb4872..71db94d44c 100644 --- a/coderd/coderdtest/swagger_test.go +++ b/coderd/coderdtest/swagger_test.go @@ -16,7 +16,7 @@ import ( func TestEndpointsDocumented(t *testing.T) { t.Parallel() - swaggerComments, err := coderdtest.ParseSwaggerComments("..") + swaggerComments, err := coderdtest.ParseSwaggerComments("..", "../workspaceconnwatcher") require.NoError(t, err, "can't parse swagger comments") require.NotEmpty(t, swaggerComments, "swagger comments must be present") diff --git a/coderd/httpmw/workspaceparam.go b/coderd/httpmw/workspaceparam.go index 25b07aa669..cab77d6d92 100644 --- a/coderd/httpmw/workspaceparam.go +++ b/coderd/httpmw/workspaceparam.go @@ -54,3 +54,7 @@ func ExtractWorkspaceParam(db database.Store) func(http.Handler) http.Handler { }) } } + +func WithWorkspaceParam(ctx context.Context, workspace database.Workspace) context.Context { + return context.WithValue(ctx, workspaceParamContextKey{}, workspace) +} diff --git a/coderd/rbac/rolestore/rolestore.go b/coderd/rbac/rolestore/rolestore.go index c246778995..9f95c1870a 100644 --- a/coderd/rbac/rolestore/rolestore.go +++ b/coderd/rbac/rolestore/rolestore.go @@ -170,6 +170,25 @@ var systemRoles = map[string]permissionsFunc{ rbac.RoleOrgServiceAccount(): rbac.OrgServiceAccountPermissions, } +func TestingGetSystemRole(name 
// Watcher serves the workspace agent connection watch WebSocket endpoint.
//
// sub is used to subscribe to workspace events, db is handed to each
// connection's updater for build and agent lookups, and ctx/cancel bound the
// lifetime of every connection: Close cancels ctx and then waits for all
// in-flight handlers to return.
type Watcher struct {
	logger slog.Logger
	sub    pubsub.Subscriber
	db     database.Store
	ctx    context.Context
	cancel context.CancelFunc

	// mu guards closed. A handler adds itself to wg while holding mu and only
	// if closed is still false, so Close never waits on a handler that
	// registered after shutdown began.
	mu     sync.Mutex
	wg     sync.WaitGroup
	closed bool
}
{ + <-ctx.Done() + w.Close() + }() + return w +} + +// @Summary Workspace Agent Connection Watch +// @ID workspace-agent-connection-watch +// @Security CoderSessionToken +// @Produce json +// @Tags Workspaces +// @Param workspace path string true "Workspace ID" format(uuid) +// @Success 101 {object} workspacesdk.ConnectionWatchEvent +// @Router /api/v2/workspaces/{workspace}/agent-connection-watch [get] +func (w *Watcher) WorkspaceAgentConnectionWatch(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + workspace := httpmw.WorkspaceParam(r) + agentName := r.URL.Query().Get("agent_name") + + filteredEvents := make(chan event, 1) + filteredEvents <- event{sync: true} // init sync + cancelWorkspaceSubscribe, err := w.sub.SubscribeWithErr(wspubsub.WorkspaceEventChannel(workspace.OwnerID), + wspubsub.HandleWorkspaceEvent( + func(ctx context.Context, payload wspubsub.WorkspaceEvent, err error) { + if err != nil { + // subscription error, resync + select { + case filteredEvents <- event{sync: true}: + case <-ctx.Done(): + } + return + } + if payload.WorkspaceID != workspace.ID { + return + } + select { + case filteredEvents <- event{wsEvent: &payload}: + case <-ctx.Done(): + } + })) + if err != nil { + w.logger.Error(ctx, "failed to subscribe to workspace events", + slog.Error(err), slog.F("owner_id", workspace.OwnerID)) + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error setting up workspace event subscription", + // Don't include the error in case it leaks infra details about the pubsub + }) + return + } + defer cancelWorkspaceSubscribe() + + closed := false + w.mu.Lock() + closed = w.closed + if !closed { + w.wg.Add(1) + } + w.mu.Unlock() + if closed { + w.logger.Debug(ctx, "server is closed, writing error") + httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{ + Message: "Server instance is shutting down", + }) + return + } + defer w.wg.Done() + + conn, err := websocket.Accept(rw, r, 
nil) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Failed to accept WebSocket.", + Detail: err.Error(), + }) + return + } + + // CloseRead starts a goroutine to read and discard messages from the client, + // including Pong messages sent in response to our Ping heartbeats. + _ = conn.CloseRead(ctx) + + ctx, cancel := context.WithCancel(ctx) + go httpapi.HeartbeatClose(ctx, w.logger, cancel, conn) + defer cancel() + + u := &updater{ + db: w.db, + watcherCtx: w.ctx, + connCtx: ctx, + conn: conn, + workspaceID: workspace.ID, + events: filteredEvents, + agentName: agentName, + } + u.run() +} + +func (w *Watcher) Close() { + w.mu.Lock() + w.closed = true + w.mu.Unlock() + + w.cancel() + w.wg.Wait() +} + +type updater struct { + db database.Store + watcherCtx context.Context + connCtx context.Context + conn *websocket.Conn + enc *wsjson.Encoder[workspacesdk.ConnectionWatchEvent] + workspaceID uuid.UUID + events <-chan event + agentName string + + lastBuild database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow +} + +func (u *updater) run() { + u.enc = wsjson.NewEncoder[workspacesdk.ConnectionWatchEvent](u.conn, websocket.MessageText) + defer func() { + // this is a no-op if we have already closed for some other reason. 
+ _ = u.enc.Close(websocket.StatusNormalClosure) + }() + + for { + select { + case <-u.watcherCtx.Done(): + u.errorThenClose(workspacesdk.WatchError{ + Code: workspacesdk.WatchErrorServerShutdown, + Retryable: true, + Message: "server is shutting down", + }) + return + case <-u.connCtx.Done(): + return + case e := <-u.events: + if e.sync { + // zero this out so we'll send a full update + u.lastBuild = database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{} + if !u.buildUpdate() { + return + } + } + if e.wsEvent != nil { + switch e.wsEvent.Kind { + case wspubsub.WorkspaceEventKindStateChange: + if !u.buildUpdate() { + return + } + case wspubsub.WorkspaceEventKindAgentLifecycleUpdate: + if !u.maybeSendAgentUpdate() { + return + } + } + } + } + } +} + +func (u *updater) buildUpdate() bool { + build, err := u.db.GetLatestWorkspaceBuildWithStatusByWorkspaceID(u.connCtx, u.workspaceID) + if err != nil { + retryable := true + details := err.Error() + if errors.Is(err, sql.ErrNoRows) { + // There is no build (unlikely), or the workspace was deleted. In both cases, retrying won't help. 
+ retryable = false + } + if dbauthz.IsNotAuthorizedError(err) { + retryable = false + details = "unauthorized" // security: don't leak internal authz details + } + u.errorThenClose(workspacesdk.WatchError{ + Code: workspacesdk.WatchErrorDatabase, + Retryable: retryable, + Message: "failed to fetch latest workspace build", + Details: details, + }) + return false + } + + if build.BuildNumber != u.lastBuild.BuildNumber || + build.JobStatus != u.lastBuild.JobStatus || + build.Transition != u.lastBuild.Transition { + u.lastBuild = build + err = u.enc.Encode(workspacesdk.ConnectionWatchEvent{BuildUpdate: &workspacesdk.BuildUpdate{ + Transition: codersdk.WorkspaceTransition(build.Transition), + JobStatus: codersdk.ProvisionerJobStatus(build.JobStatus), + }}) + if err != nil { + // probably this is just that the connection is closed, but in case there is some actual JSON serialization + // error, send a close frame. + _ = u.conn.Close(websocket.StatusInternalError, "failed to encode build update") + return false + } + return u.maybeSendAgentUpdate() + } + return true +} + +func (u *updater) maybeSendAgentUpdate() (ok bool) { + if u.lastBuild.Transition != database.WorkspaceTransitionStart || + u.lastBuild.JobStatus != database.ProvisionerJobStatusSucceeded { + // only send agent updates for successfully started workspaces + return true + } + + agents, err := u.db.GetWorkspaceAgentsByWorkspaceAndBuildNumber(u.connCtx, + database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{ + WorkspaceID: u.workspaceID, + BuildNumber: u.lastBuild.BuildNumber, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + details := err.Error() + retryable := true + if dbauthz.IsNotAuthorizedError(err) { + retryable = false + details = "unauthorized" + } + u.errorThenClose(workspacesdk.WatchError{ + Code: workspacesdk.WatchErrorDatabase, + Retryable: retryable, + Message: "failed to fetch workspace agents", + Details: details, + }) + return false + } + if len(agents) == 0 { + 
u.errorThenClose(workspacesdk.WatchError{ + Code: workspacesdk.WatchErrorNoAgents, + Retryable: false, + Message: "no agents found for workspace", + }) + return false + } + if len(agents) > 1 && u.agentName == "" { + u.errorThenClose(workspacesdk.WatchError{ + Code: workspacesdk.WatchErrorTooManyAgents, + Retryable: false, + Message: "more than one agent on workspace and target not specified", + }) + return false + } + var agent database.WorkspaceAgent + if u.agentName == "" { + agent = agents[0] + } else { + for _, a := range agents { + if a.Name == u.agentName { + agent = a + break + } + } + if agent.ID == uuid.Nil { + u.errorThenClose(workspacesdk.WatchError{ + Code: workspacesdk.WatchErrorNameNotFound, + Retryable: false, + Message: "target agent not found by name", + }) + return false + } + } + + err = u.enc.Encode(workspacesdk.ConnectionWatchEvent{AgentUpdate: &workspacesdk.AgentUpdate{ + Lifecycle: codersdk.WorkspaceAgentLifecycle(agent.LifecycleState), + ID: agent.ID, + }}) + if err != nil { + // probably this is just that the connection is closed, but in case there is some actual JSON serialization + // error, send a close frame. + _ = u.conn.Close(websocket.StatusInternalError, "failed to encode agent update") + return false + } + return true +} + +func (u *updater) errorThenClose(err workspacesdk.WatchError) { + _ = u.enc.Encode(workspacesdk.ConnectionWatchEvent{Error: &err}) + // ignore encoding errors above because in any case, we are going to close the connection. 
+ _ = u.conn.Close(websocket.StatusNormalClosure, "error") +} diff --git a/coderd/workspaceconnwatcher/watcher_test.go b/coderd/workspaceconnwatcher/watcher_test.go new file mode 100644 index 0000000000..0fdab4fed1 --- /dev/null +++ b/coderd/workspaceconnwatcher/watcher_test.go @@ -0,0 +1,474 @@ +package workspaceconnwatcher_test + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/workspaceconnwatcher" + "github.com/coder/coder/v2/coderd/wspubsub" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/codersdk/wsjson" + "github.com/coder/coder/v2/testutil" + "github.com/coder/websocket" +) + +var ( + workspaceID = uuid.UUID{1} + userID = uuid.UUID{2} + orgID = uuid.UUID{3} + agentID = uuid.UUID{4} +) + +type harness struct { + db *dbmock.MockStore + watcher *workspaceconnwatcher.Watcher + pub pubsub.Publisher + logger slog.Logger + + // Initialized, but overridable before Dial() + workspace database.Workspace + userID, orgID uuid.UUID +} + +func newHarness(ctx context.Context, t *testing.T, logger slog.Logger) *harness { + h := &harness{ + workspace: database.Workspace{ + ID: workspaceID, + OrganizationID: orgID, + OwnerID: userID, + }, + orgID: orgID, + userID: userID, + logger: logger, + } + ps := pubsub.NewInMemory() + h.pub = ps + + var authzDB database.Store + _, h.db, authzDB, _ = coderdtest.MockedDatabaseWithAuthz(t, logger) + h.watcher = workspaceconnwatcher.New(ctx, logger.Named("watcher"), ps, 
authzDB) + t.Cleanup(h.watcher.Close) + return h +} + +func (h *harness) Dial(ctx context.Context, url string) (*wsjson.Decoder[workspacesdk.ConnectionWatchEvent], error) { + rt := testutil.InMemWebsocketRoundTripper{ + Handler: http.HandlerFunc(h.watcher.WorkspaceAgentConnectionWatch), + CtxMutator: func(ctx context.Context) context.Context { + ctx = httpmw.WithWorkspaceParam(ctx, h.workspace) + ctx = dbauthz.As(ctx, coderdtest.MemberSubject(userID, orgID)) + return ctx + }, + Logger: h.logger.Named("roundtripper"), + } + // nolint: bodyclose + clientSock, resp, err := websocket.Dial(ctx, url, &websocket.DialOptions{ + HTTPClient: &http.Client{Transport: rt}, + }) + if err != nil { + if resp.StatusCode != http.StatusSwitchingProtocols { + return nil, codersdk.ReadBodyAsError(resp) + } + return nil, err + } + + dec := wsjson.NewDecoder[workspacesdk.ConnectionWatchEvent]( + clientSock, websocket.MessageText, h.logger.Named("decoder")) + return dec, nil +} + +func TestWatcher_Agents(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + agents []database.WorkspaceAgent + agentDBError error + url string + expectedAgentUpdate *workspacesdk.AgentUpdate + expectedErrorCode workspacesdk.WatchErrorCode + expectedErrorRetryable bool + }{ + { + name: "noNameSingleAgent", + agents: []database.WorkspaceAgent{ + { + Name: "test", + ID: agentID, + LifecycleState: database.WorkspaceAgentLifecycleStateCreated, + }, + }, + url: "wss://local.test/", + expectedAgentUpdate: &workspacesdk.AgentUpdate{ + Lifecycle: codersdk.WorkspaceAgentLifecycleCreated, + ID: agentID, + }, + }, + { + name: "noNameMultiAgent", + agents: []database.WorkspaceAgent{ + { + Name: "agent0", + ID: agentID, + LifecycleState: database.WorkspaceAgentLifecycleStateCreated, + }, + { + Name: "agent1", + ID: uuid.UUID{77}, + LifecycleState: database.WorkspaceAgentLifecycleStateCreated, + }, + }, + url: "wss://local.test/", + expectedErrorCode: workspacesdk.WatchErrorTooManyAgents, + 
expectedErrorRetryable: false, + }, + { + name: "namedAgentMultiAgent", + agents: []database.WorkspaceAgent{ + { + Name: "agent0", + ID: agentID, + LifecycleState: database.WorkspaceAgentLifecycleStateCreated, + }, + { + Name: "agent1", + ID: uuid.UUID{77}, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + }, + }, + url: "wss://local.test/?agent_name=agent0", + expectedAgentUpdate: &workspacesdk.AgentUpdate{ + Lifecycle: codersdk.WorkspaceAgentLifecycleCreated, + ID: agentID, + }, + }, + { + name: "namedAgentNonexistent", + agents: []database.WorkspaceAgent{ + { + Name: "agent0", + ID: agentID, + LifecycleState: database.WorkspaceAgentLifecycleStateCreated, + }, + { + Name: "agent1", + ID: uuid.UUID{77}, + LifecycleState: database.WorkspaceAgentLifecycleStateCreated, + }, + }, + url: "wss://local.test/?agent_name=agent2", + expectedErrorCode: workspacesdk.WatchErrorNameNotFound, + expectedErrorRetryable: false, + }, + { + name: "dbError", + agentDBError: xerrors.New("a bad thing happened"), + url: "wss://local.test/", + expectedErrorCode: workspacesdk.WatchErrorDatabase, + expectedErrorRetryable: true, + }, + { + name: "unauthorized", + agentDBError: dbauthz.NotAuthorizedError{Err: xerrors.New("not allowed")}, + url: "wss://local.test/", + expectedErrorCode: workspacesdk.WatchErrorDatabase, + expectedErrorRetryable: false, + }, + { + name: "noAgents", + agents: []database.WorkspaceAgent{}, + url: "wss://local.test/", + expectedErrorCode: workspacesdk.WatchErrorNoAgents, + expectedErrorRetryable: false, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + logger := testutil.Logger(t) + h := newHarness(ctx, t, logger) + + h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID). + Times(1). 
+ Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{ + Transition: database.WorkspaceTransitionStart, + BuildNumber: 1, + JobStatus: database.ProvisionerJobStatusSucceeded, + WorkspaceTable: database.WorkspaceTable{ + ID: h.workspace.ID, + OwnerID: userID, + OrganizationID: orgID, + }, + }, nil) + // RBAC check for agent query + h.db.EXPECT().GetWorkspaceByID(gomock.Any(), h.workspace.ID). + Times(1). + Return(h.workspace, nil) + h.db.EXPECT().GetWorkspaceAgentsByWorkspaceAndBuildNumber( + gomock.Any(), + database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{ + WorkspaceID: h.workspace.ID, + BuildNumber: 1, + }). + Times(1). + Return(tc.agents, tc.agentDBError) + + dec, err := h.Dial(ctx, tc.url) + require.NoError(t, err) + defer dec.Close() + events := dec.Chan() + e0 := testutil.RequireReceive(ctx, t, events) + require.Equal(t, workspacesdk.ConnectionWatchEvent{ + BuildUpdate: &workspacesdk.BuildUpdate{ + Transition: codersdk.WorkspaceTransitionStart, + JobStatus: codersdk.ProvisionerJobSucceeded, + }, + }, e0) + + e1 := testutil.RequireReceive(ctx, t, events) + if tc.expectedAgentUpdate != nil { + require.Equal(t, workspacesdk.ConnectionWatchEvent{AgentUpdate: tc.expectedAgentUpdate}, e1) + } else { + require.NotNil(t, e1.Error) + require.Equal(t, tc.expectedErrorRetryable, e1.Error.Retryable) + require.Equal(t, tc.expectedErrorCode, e1.Error.Code) + } + }) + } +} + +func TestWatcher_LostAccess(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := testutil.Logger(t) + h := newHarness(ctx, t, logger) + + h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID). + Times(1). 
+ Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{ + Transition: database.WorkspaceTransitionStart, + BuildNumber: 1, + JobStatus: database.ProvisionerJobStatusSucceeded, + WorkspaceTable: database.WorkspaceTable{ + ID: h.workspace.ID, + OwnerID: uuid.UUID{99}, // workspace gets a new owner, e.g. + OrganizationID: orgID, + }, + }, nil) + + dec, err := h.Dial(ctx, "wss://local.test/") + require.NoError(t, err) + defer func() { + err := dec.Close() + require.NoError(t, err) + }() + events := dec.Chan() + e0 := testutil.RequireReceive(ctx, t, events) + require.NotNil(t, e0.Error) + require.Equal(t, workspacesdk.WatchErrorDatabase, e0.Error.Code) + require.False(t, e0.Error.Retryable) + require.Equal(t, "unauthorized", e0.Error.Details, "should not leak internal auth details") +} + +func TestWatcher_PublishChanges(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + logger := testutil.Logger(t) + h := newHarness(ctx, t, logger) + + // Initial build update, job is running. + build0 := h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID). + Times(1). + Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{ + Transition: database.WorkspaceTransitionStart, + BuildNumber: 1, + JobStatus: database.ProvisionerJobStatusRunning, + WorkspaceTable: database.WorkspaceTable{ + ID: h.workspace.ID, + OwnerID: userID, + OrganizationID: orgID, + }, + }, nil) + + dec, err := h.Dial(ctx, "wss://local.test/") + require.NoError(t, err) + defer func() { + err := dec.Close() + require.NoError(t, err) + }() + events := dec.Chan() + + e0 := testutil.RequireReceive(ctx, t, events) + require.Equal(t, workspacesdk.ConnectionWatchEvent{ + BuildUpdate: &workspacesdk.BuildUpdate{ + Transition: codersdk.WorkspaceTransitionStart, + JobStatus: codersdk.ProvisionerJobRunning, + }, + }, e0) + + // Since job is still running, we don't immediately query for agents. 
Next we set up the db queries and send in an + // update over the pubsub to kick a new query. + build1 := h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID). + After(build0). + Times(1). + Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{ + Transition: database.WorkspaceTransitionStart, + BuildNumber: 1, + JobStatus: database.ProvisionerJobStatusSucceeded, + WorkspaceTable: database.WorkspaceTable{ + ID: h.workspace.ID, + OwnerID: userID, + OrganizationID: orgID, + }, + }, nil) + // RBAC check for agent query + h.db.EXPECT().GetWorkspaceByID(gomock.Any(), h.workspace.ID). + After(build1). + Times(2). // these queries are identical between the initial and the update below + Return(h.workspace, nil) + agent0 := h.db.EXPECT().GetWorkspaceAgentsByWorkspaceAndBuildNumber( + gomock.Any(), + database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{ + WorkspaceID: h.workspace.ID, + BuildNumber: 1, + }). + After(build1). + Times(1). + Return([]database.WorkspaceAgent{ + { + Name: "test", + ID: agentID, + LifecycleState: database.WorkspaceAgentLifecycleStateCreated, + }, + }, nil) + changeMsg := wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindStateChange, + WorkspaceID: h.workspace.ID, + } + changeBytes, err := json.Marshal(changeMsg) + require.NoError(t, err) + err = h.pub.Publish(wspubsub.WorkspaceEventChannel(h.workspace.OwnerID), changeBytes) + require.NoError(t, err) + + e1 := testutil.RequireReceive(ctx, t, events) + require.Equal(t, workspacesdk.ConnectionWatchEvent{ + BuildUpdate: &workspacesdk.BuildUpdate{ + Transition: codersdk.WorkspaceTransitionStart, + JobStatus: codersdk.ProvisionerJobSucceeded, + }, + }, e1) + e2 := testutil.RequireReceive(ctx, t, events) + require.Equal(t, workspacesdk.ConnectionWatchEvent{AgentUpdate: &workspacesdk.AgentUpdate{ + ID: agentID, + Lifecycle: codersdk.WorkspaceAgentLifecycleCreated, + }}, e2) + + // Finally, send in a change event for the agent. 
But first, program the mock for the expected query. + h.db.EXPECT().GetWorkspaceAgentsByWorkspaceAndBuildNumber( + gomock.Any(), + database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{ + WorkspaceID: h.workspace.ID, + BuildNumber: 1, + }). + After(agent0). + Times(1). + Return([]database.WorkspaceAgent{ + { + Name: "test", + ID: agentID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + }, + }, nil) + changeMsg = wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindAgentLifecycleUpdate, + WorkspaceID: h.workspace.ID, + AgentID: &agentID, + } + changeBytes, err = json.Marshal(changeMsg) + require.NoError(t, err) + err = h.pub.Publish(wspubsub.WorkspaceEventChannel(h.workspace.OwnerID), changeBytes) + require.NoError(t, err) + + e3 := testutil.RequireReceive(ctx, t, events) + require.Equal(t, workspacesdk.ConnectionWatchEvent{AgentUpdate: &workspacesdk.AgentUpdate{ + ID: agentID, + Lifecycle: codersdk.WorkspaceAgentLifecycleReady, + }}, e3) +} + +func TestWatcher_ClosedBeforeDial(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + logger := testutil.Logger(t) + h := newHarness(ctx, t, logger) + h.watcher.Close() + _, err := h.Dial(ctx, "wss://local.test/") + var sdkError *codersdk.Error + require.True(t, errors.As(err, &sdkError)) + require.Equal(t, http.StatusServiceUnavailable, sdkError.StatusCode()) +} + +func TestWatcher_ClosedAfterDial(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + logger := testutil.Logger(t) + h := newHarness(ctx, t, logger) + + h.db.EXPECT().GetLatestWorkspaceBuildWithStatusByWorkspaceID(gomock.Any(), h.workspace.ID). + Times(1). 
+ Return(database.GetLatestWorkspaceBuildWithStatusByWorkspaceIDRow{ + Transition: database.WorkspaceTransitionStop, + BuildNumber: 1, + JobStatus: database.ProvisionerJobStatusSucceeded, + WorkspaceTable: database.WorkspaceTable{ + ID: h.workspace.ID, + OwnerID: userID, + OrganizationID: orgID, + }, + }, nil) + + dec, err := h.Dial(ctx, "wss://local.test/") + require.NoError(t, err) + events := dec.Chan() + _ = testutil.RequireReceive(ctx, t, events) + + closed := make(chan struct{}) + go func() { + defer close(closed) + h.watcher.Close() + }() + + e := testutil.RequireReceive(ctx, t, events) + require.NotNil(t, e.Error) + require.Equal(t, workspacesdk.WatchErrorServerShutdown, e.Error.Code) + require.True(t, e.Error.Retryable) + + select { + case <-ctx.Done(): + t.Fatal("context timed out") + case _, ok := <-events: + require.False(t, ok, "socket not closed") + } + testutil.TryReceive(ctx, t, closed) +} diff --git a/codersdk/workspacesdk/tunneler/integration_test.go b/codersdk/workspacesdk/tunneler/integration_test.go new file mode 100644 index 0000000000..3fe2703995 --- /dev/null +++ b/codersdk/workspacesdk/tunneler/integration_test.go @@ -0,0 +1,100 @@ +package tunneler_test + +import ( + "bytes" + "context" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "github.com/coder/coder/v2/agent/agenttest" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbfake" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/codersdk/workspacesdk/tunneler" + "github.com/coder/coder/v2/testutil" +) + +// TestTunneler_Integration is an integration test using coderdtest. It should be removed when we integrate the Tunneler +// into coder ssh and those integration test cover this functionality. 
+func TestTunneler_Integration(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client, store := coderdtest.NewWithDatabase(t, nil) + logger := testutil.Logger(t) + client.SetLogger(logger.Named("client")) + first := coderdtest.CreateFirstUser(t, client) + userClient, user := coderdtest.CreateAnotherUserMutators(t, client, first.OrganizationID, nil, func(r *codersdk.CreateUserRequestWithOrgs) { + r.Username = "myuser" + }) + userClient.SetLogger(logger.Named("userclient")) + r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{ + Name: "myworkspace", + OrganizationID: first.OrganizationID, + OwnerID: user.ID, + }).WithAgent().Do() + wsSDKClient := workspacesdk.New(userClient) + logs := &bytes.Buffer{} + + app := &sshApplication{ + t: t, + ctx: ctx, + done: make(chan struct{}), + } + + tun := tunneler.NewTunneler(wsSDKClient, tunneler.Config{ + WorkspaceID: r.Workspace.ID, + App: app, + WorkspaceStarter: nil, + AgentName: "", + LogWriter: logs, + DebugLogger: logger.Named("tunneler"), + }) + + testAgent := agenttest.New(t, client.URL, r.AgentToken) + defer testAgent.Close() + + testutil.TryReceive(ctx, t, app.done) + require.Equal(t, app.result, "foo\n") + + err := tun.GracefulShutdown(ctx) + require.NoError(t, err) +} + +type sshApplication struct { + t *testing.T + ctx context.Context + client *ssh.Client + done chan struct{} + result string +} + +func (s *sshApplication) Close() error { + return s.client.Close() +} + +func (s *sshApplication) Start(conn workspacesdk.AgentConn) error { + var err error + s.client, err = conn.SSHClient(s.ctx) + if err != nil { + s.t.Error(err) + return err + } + go func() { + defer close(s.done) + sess, err := s.client.NewSession() + if err != nil { + s.t.Error("failed to create session", err) + } + defer sess.Close() + out, err := sess.Output("echo foo") + if err != nil { + s.t.Error("failed to echo", err) + } + s.result = string(out) + }() + return nil +} diff --git 
a/codersdk/workspacesdk/tunneler/tunneler.go b/codersdk/workspacesdk/tunneler/tunneler.go index f1e7c1dd62..c2c6ceab90 100644 --- a/codersdk/workspacesdk/tunneler/tunneler.go +++ b/codersdk/workspacesdk/tunneler/tunneler.go @@ -11,6 +11,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/codersdk/wsjson" ) type state int @@ -33,6 +34,11 @@ type WorkspaceStarter interface { type Client interface { DialAgent(dialCtx context.Context, agentID uuid.UUID, options *workspacesdk.DialAgentOptions) (workspacesdk.AgentConn, error) + WorkspaceAgentConnectionWatch( + dialCtx context.Context, workspaceID uuid.UUID, agentName string, + ) ( + dec *wsjson.Decoder[workspacesdk.ConnectionWatchEvent], err error, + ) } const ( @@ -127,9 +133,9 @@ type Config struct { // ordering. type tunnelerEvent struct { shutdownSignal *shutdownSignal - buildUpdate *buildUpdate + buildUpdate *workspacesdk.BuildUpdate provisionerJobLog *codersdk.ProvisionerJobLog - agentUpdate *agentUpdate + agentUpdate *workspacesdk.AgentUpdate agentLog *codersdk.WorkspaceAgentLog appUpdate *networkedApplicationUpdate tailnetUpdate *tailnetUpdate @@ -137,16 +143,6 @@ type tunnelerEvent struct { type shutdownSignal struct{} -type buildUpdate struct { - transition codersdk.WorkspaceTransition - jobStatus codersdk.ProvisionerJobStatus -} - -type agentUpdate struct { - lifecycle codersdk.WorkspaceAgentLifecycle - id uuid.UUID -} - type networkedApplicationUpdate struct { // up is true if the application is up. False if it is down. 
up bool @@ -174,10 +170,64 @@ func NewTunneler(client Client, config Config) *Tunneler { return t } +func (t *Tunneler) GracefulShutdown(ctx context.Context) error { + select { + case t.events <- tunnelerEvent{shutdownSignal: &shutdownSignal{}}: + case <-ctx.Done(): + return ctx.Err() + case <-t.ctx.Done(): + } + done := make(chan struct{}) + go func() { + defer close(done) + t.wg.Wait() + }() + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + func (t *Tunneler) start() { defer t.wg.Done() - // here we would subscribe to updates. - // t.client.AgentConnectionWatch(t.config.WorkspaceID, t.config.AgentName) + d, err := t.client.WorkspaceAgentConnectionWatch(t.ctx, t.config.WorkspaceID, t.config.AgentName) + // TODO: handle retries + if err != nil { + return + } + defer d.Close() + c := d.Chan() + for { + select { + case <-t.ctx.Done(): + return + case event, ok := <-c: + if !ok { + t.config.DebugLogger.Error(t.ctx, "watch closed") + } + if event.Error != nil { + t.config.DebugLogger.Error(t.ctx, "workspace agent connection watch error", slog.Error(event.Error)) + } + if !ok || event.Error != nil { + // TODO: handle retries + select { + case t.events <- tunnelerEvent{shutdownSignal: &shutdownSignal{}}: + case <-t.ctx.Done(): + } + return + } + select { + case <-t.ctx.Done(): + return + case t.events <- tunnelerEvent{ + buildUpdate: event.BuildUpdate, + agentUpdate: event.AgentUpdate, + }: + } + } + } } func (t *Tunneler) eventLoop() { @@ -235,13 +285,13 @@ func (t *Tunneler) handleSignal() { } } -func (t *Tunneler) handleBuildUpdate(update *buildUpdate) { +func (t *Tunneler) handleBuildUpdate(update *workspacesdk.BuildUpdate) { if t.state == shutdownTailnet || t.state == shutdownApplication || t.state == exit { return // no-op } var canMakeProgress, jobUnhealthy bool - switch update.jobStatus { + switch update.JobStatus { case codersdk.ProvisionerJobPending, codersdk.ProvisionerJobRunning: canMakeProgress = true case 
codersdk.ProvisionerJobSucceeded: @@ -249,21 +299,21 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) { jobUnhealthy = true } - if update.transition == codersdk.WorkspaceTransitionDelete { - t.config.DebugLogger.Info(t.ctx, "workspace is being deleted", slog.F("job_status", update.jobStatus)) + if update.Transition == codersdk.WorkspaceTransitionDelete { + t.config.DebugLogger.Info(t.ctx, "workspace is being deleted", slog.F("job_status", update.JobStatus)) // treat same as signal t.handleSignal() return } if jobUnhealthy { - t.config.DebugLogger.Info(t.ctx, "build job is in unhealthy state", slog.F("job_status", update.jobStatus)) + t.config.DebugLogger.Info(t.ctx, "build job is in unhealthy state", slog.F("job_status", update.JobStatus)) // treat same as signal t.handleSignal() return } - if update.transition == codersdk.WorkspaceTransitionStart && canMakeProgress { - t.config.DebugLogger.Debug(t.ctx, "workspace is starting", slog.F("job_status", update.jobStatus)) + if update.Transition == codersdk.WorkspaceTransitionStart && canMakeProgress { + t.config.DebugLogger.Debug(t.ctx, "workspace is starting", slog.F("job_status", update.JobStatus)) switch t.state { // new build after we have already connected case establishTailnet: // we are starting the tailnet @@ -279,8 +329,8 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) { } return } - if update.transition == codersdk.WorkspaceTransitionStart && update.jobStatus == codersdk.ProvisionerJobSucceeded { - t.config.DebugLogger.Debug(t.ctx, "workspace is started", slog.F("job_status", update.jobStatus)) + if update.Transition == codersdk.WorkspaceTransitionStart && update.JobStatus == codersdk.ProvisionerJobSucceeded { + t.config.DebugLogger.Debug(t.ctx, "workspace is started", slog.F("job_status", update.JobStatus)) switch t.state { case establishTailnet, applicationUp, tailnetUp: // no-op. Later agent updates will tell us whether the tailnet connection is current. 
@@ -290,7 +340,7 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) { return } - if update.transition == codersdk.WorkspaceTransitionStop { + if update.Transition == codersdk.WorkspaceTransitionStop { // these cases take effect regardless of whether the transition is complete or not switch t.state { // all 3 of these mean a new build after we have already started connecting @@ -312,7 +362,7 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) { t.state = exit return } - if update.jobStatus == codersdk.ProvisionerJobSucceeded { + if update.JobStatus == codersdk.ProvisionerJobSucceeded { switch t.state { case stateInit, waitToStart, waitForAgent: t.wg.Add(1) @@ -335,29 +385,29 @@ func (t *Tunneler) handleBuildUpdate(update *buildUpdate) { } // unhittable t.config.DebugLogger.Critical(t.ctx, "unhandled build update", - slog.F("job_status", update.jobStatus), slog.F("transition", update.transition), slog.F("state", t.state)) + slog.F("job_status", update.JobStatus), slog.F("transition", update.Transition), slog.F("state", t.state)) } func (*Tunneler) handleProvisionerJobLog(*codersdk.ProvisionerJobLog) { } -func (t *Tunneler) handleAgentUpdate(update *agentUpdate) { +func (t *Tunneler) handleAgentUpdate(update *workspacesdk.AgentUpdate) { t.config.DebugLogger.Debug(t.ctx, "handling agent update", slog.F("state", t.state), - slog.F("lifecycle", update.lifecycle), - slog.F("agent_id", update.id)) + slog.F("lifecycle", update.Lifecycle), + slog.F("agent_id", update.ID)) if t.state != waitForAgent { return } doConnect := func() { t.wg.Add(1) t.state = establishTailnet - go t.connectTailnet(update.id) + go t.connectTailnet(update.ID) } // consequence of ignoring updates if we are not waiting for the agent is that we MUST receive // the start build succeeded update BEFORE we get the Agent connected / ready update. We should keep this // in mind when implementing the watch in Coderd. 
- switch update.lifecycle { + switch update.Lifecycle { case codersdk.WorkspaceAgentLifecycleReady: doConnect() return @@ -376,7 +426,7 @@ func (t *Tunneler) handleAgentUpdate(update *agentUpdate) { default: // unhittable, unless new states are added. We structure this with the switch and all cases covered to ensure // we cover all cases. - t.config.DebugLogger.Critical(t.ctx, "unhandled agent update", slog.F("lifecycle", update.lifecycle)) + t.config.DebugLogger.Critical(t.ctx, "unhandled agent update", slog.F("lifecycle", update.Lifecycle)) } } diff --git a/codersdk/workspacesdk/tunneler/tunneler_internal_test.go b/codersdk/workspacesdk/tunneler/tunneler_internal_test.go index ecd63ff940..0b6f22a4b4 100644 --- a/codersdk/workspacesdk/tunneler/tunneler_internal_test.go +++ b/codersdk/workspacesdk/tunneler/tunneler_internal_test.go @@ -12,6 +12,7 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/testutil" ) @@ -28,7 +29,7 @@ func TestHandleBuildUpdate_Coverage(t *testing.T) { t.Run(fmt.Sprintf("%d_%s_%s_%t_%t", s, trans, jobStatus, noAutostart, noWaitForScripts), func(t *testing.T) { t.Parallel() coverUpdate(t, workspaceID, noAutostart, noWaitForScripts, s, func(uut *Tunneler) { - uut.handleBuildUpdate(&buildUpdate{transition: trans, jobStatus: jobStatus}) + uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: trans, JobStatus: jobStatus}) }) }) } @@ -105,31 +106,31 @@ func TestBuildUpdatesStoppedWorkspace(t *testing.T) { state: stateInit, } - uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: codersdk.ProvisionerJobPending}) + uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: codersdk.ProvisionerJobPending}) require.Equal(t, waitToStart, uut.state) waitForGoroutines(testCtx, t, 
uut) require.False(t, fWorkspaceStarter.started) - uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: codersdk.ProvisionerJobRunning}) + uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: codersdk.ProvisionerJobRunning}) require.Equal(t, waitToStart, uut.state) waitForGoroutines(testCtx, t, uut) require.False(t, fWorkspaceStarter.started) // when stop job succeeds, we start the workspace - uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: codersdk.ProvisionerJobSucceeded}) + uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: codersdk.ProvisionerJobSucceeded}) require.Equal(t, waitForWorkspaceStarted, uut.state) waitForGoroutines(testCtx, t, uut) require.True(t, fWorkspaceStarter.started) - uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobPending}) + uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobPending}) require.Equal(t, waitForWorkspaceStarted, uut.state) waitForGoroutines(testCtx, t, uut) - uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobRunning}) + uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobRunning}) require.Equal(t, waitForWorkspaceStarted, uut.state) waitForGoroutines(testCtx, t, uut) - uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobSucceeded}) + uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobSucceeded}) require.Equal(t, waitForAgent, uut.state) waitForGoroutines(testCtx, t, uut) } @@ -157,7 +158,7 @@ func 
TestBuildUpdatesNewBuildWhileWaiting(t *testing.T) { } // New build comes in while we are waiting for the agent to start. We roll back to waiting for the workspace to start. - uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobRunning}) + uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobRunning}) require.Equal(t, waitForWorkspaceStarted, uut.state) waitForGoroutines(testCtx, t, uut) require.False(t, fWorkspaceStarter.started) @@ -193,12 +194,12 @@ func TestBuildUpdatesBadJobs(t *testing.T) { state: stateInit, } - uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStart, jobStatus: codersdk.ProvisionerJobRunning}) + uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStart, JobStatus: codersdk.ProvisionerJobRunning}) require.Equal(t, waitForWorkspaceStarted, uut.state) waitForGoroutines(testCtx, t, uut) require.False(t, fWorkspaceStarter.started) - uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: jobStatus}) + uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: jobStatus}) require.Equal(t, exit, uut.state) waitForGoroutines(testCtx, t, uut) require.False(t, fWorkspaceStarter.started) @@ -233,7 +234,7 @@ func TestBuildUpdatesNoAutostart(t *testing.T) { } // when stop job succeeds, we exit because autostart is disabled - uut.handleBuildUpdate(&buildUpdate{transition: codersdk.WorkspaceTransitionStop, jobStatus: codersdk.ProvisionerJobSucceeded}) + uut.handleBuildUpdate(&workspacesdk.BuildUpdate{Transition: codersdk.WorkspaceTransitionStop, JobStatus: codersdk.ProvisionerJobSucceeded}) require.Equal(t, exit, uut.state) waitForGoroutines(testCtx, t, uut) require.False(t, fWorkspaceStarter.started) @@ -254,7 +255,7 @@ func TestAgentUpdate_Coverage(t *testing.T) { 
t.Run(fmt.Sprintf("%d_%s_%t_%t", s, lifecycle, noAutostart, noWaitForScripts), func(t *testing.T) { t.Parallel() coverUpdate(t, workspaceID, noAutostart, noWaitForScripts, s, func(uut *Tunneler) { - uut.handleAgentUpdate(&agentUpdate{lifecycle: lifecycle, id: agentID}) + uut.handleAgentUpdate(&workspacesdk.AgentUpdate{Lifecycle: lifecycle, ID: agentID}) }) }) } @@ -288,7 +289,7 @@ func TestAgentUpdateReady(t *testing.T) { client: fClient, } - uut.handleAgentUpdate(&agentUpdate{lifecycle: codersdk.WorkspaceAgentLifecycleReady, id: agentID}) + uut.handleAgentUpdate(&workspacesdk.AgentUpdate{Lifecycle: codersdk.WorkspaceAgentLifecycleReady, ID: agentID}) require.Equal(t, establishTailnet, uut.state) event := testutil.RequireReceive(testCtx, t, uut.events) require.NotNil(t, event.tailnetUpdate) @@ -323,7 +324,7 @@ func TestAgentUpdateNoWait(t *testing.T) { client: fClient, } - uut.handleAgentUpdate(&agentUpdate{lifecycle: codersdk.WorkspaceAgentLifecycleStarting, id: agentID}) + uut.handleAgentUpdate(&workspacesdk.AgentUpdate{Lifecycle: codersdk.WorkspaceAgentLifecycleStarting, ID: agentID}) require.Equal(t, establishTailnet, uut.state) event := testutil.RequireReceive(testCtx, t, uut.events) require.NotNil(t, event.tailnetUpdate) @@ -526,27 +527,27 @@ func TestTunneler_EventLoop_Signal(t *testing.T) { go uut.eventLoop() testutil.RequireSend(testCtx, t, uut.events, tunnelerEvent{ - buildUpdate: &buildUpdate{ - transition: codersdk.WorkspaceTransitionStart, - jobStatus: codersdk.ProvisionerJobPending, + buildUpdate: &workspacesdk.BuildUpdate{ + Transition: codersdk.WorkspaceTransitionStart, + JobStatus: codersdk.ProvisionerJobPending, }, }) testutil.RequireSend(testCtx, t, uut.events, tunnelerEvent{ - buildUpdate: &buildUpdate{ - transition: codersdk.WorkspaceTransitionStart, - jobStatus: codersdk.ProvisionerJobRunning, + buildUpdate: &workspacesdk.BuildUpdate{ + Transition: codersdk.WorkspaceTransitionStart, + JobStatus: codersdk.ProvisionerJobRunning, }, }) 
testutil.RequireSend(testCtx, t, uut.events, tunnelerEvent{ - buildUpdate: &buildUpdate{ - transition: codersdk.WorkspaceTransitionStart, - jobStatus: codersdk.ProvisionerJobSucceeded, + buildUpdate: &workspacesdk.BuildUpdate{ + Transition: codersdk.WorkspaceTransitionStart, + JobStatus: codersdk.ProvisionerJobSucceeded, }, }) testutil.RequireSend(testCtx, t, uut.events, tunnelerEvent{ - agentUpdate: &agentUpdate{ - lifecycle: codersdk.WorkspaceAgentLifecycleReady, - id: agentID, + agentUpdate: &workspacesdk.AgentUpdate{ + Lifecycle: codersdk.WorkspaceAgentLifecycleReady, + ID: agentID, }, }) @@ -658,6 +659,11 @@ type fakeClient struct { dialed bool } +func (*fakeClient) WorkspaceAgentConnectionWatch(context.Context, uuid.UUID, string) (dec *wsjson.Decoder[workspacesdk.ConnectionWatchEvent], err error) { + // TODO implement me + panic("implement me") +} + func (f *fakeClient) DialAgent( _ context.Context, id uuid.UUID, _ *workspacesdk.DialAgentOptions, ) ( diff --git a/codersdk/workspacesdk/workspaceagentconnwatch.go b/codersdk/workspacesdk/workspaceagentconnwatch.go new file mode 100644 index 0000000000..a862554bc7 --- /dev/null +++ b/codersdk/workspacesdk/workspaceagentconnwatch.go @@ -0,0 +1,86 @@ +package workspacesdk + +import ( + "context" + "fmt" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/wsjson" + "github.com/coder/websocket" +) + +type WatchErrorCode int + +const ( + _ WatchErrorCode = iota // Ensure that zero value is not a valid code + WatchErrorTooManyAgents + WatchErrorNameNotFound + WatchErrorNoAgents + WatchErrorServerShutdown + WatchErrorDatabase + WatchErrorInternal +) + +type ConnectionWatchEvent struct { + Error *WatchError `json:"error"` + BuildUpdate *BuildUpdate `json:"build_update,omitempty"` + AgentUpdate *AgentUpdate `json:"agent_update,omitempty"` +} + +type WatchError struct { + Code WatchErrorCode `json:"code"` + Retryable bool `json:"retryable"` + 
Message string `json:"message"` + Details string `json:"details,omitempty"` +} + +func (e *WatchError) Error() string { + if e.Details != "" { + return fmt.Sprintf("%s: %s", e.Message, e.Details) + } + return e.Message +} + +type BuildUpdate struct { + Transition codersdk.WorkspaceTransition `json:"transition"` + JobStatus codersdk.ProvisionerJobStatus `json:"job_status"` +} + +type AgentUpdate struct { + Lifecycle codersdk.WorkspaceAgentLifecycle `json:"lifecycle"` + ID uuid.UUID `json:"id" format:"uuid"` +} + +func (c *Client) WorkspaceAgentConnectionWatch( + dialCtx context.Context, workspaceID uuid.UUID, agentName string, +) ( + dec *wsjson.Decoder[ConnectionWatchEvent], err error, +) { + wsOptions := &websocket.DialOptions{ + HTTPClient: c.client.HTTPClient, + // Need to disable compression to avoid a data-race. + CompressionMode: websocket.CompressionDisabled, + } + c.client.SessionTokenProvider.SetDialOption(wsOptions) + + watchURL, err := c.client.URL.Parse(fmt.Sprintf("/api/v2/workspaces/%s/agent-connection-watch", workspaceID)) + if err != nil { + return nil, xerrors.Errorf("parse url: %w", err) + } + if agentName != "" { + q := watchURL.Query() + q.Set("agent_name", agentName) + watchURL.RawQuery = q.Encode() + } + + // nolint:bodyclose + conn, res, err := websocket.Dial(dialCtx, watchURL.String(), wsOptions) + if err != nil { + bodyErr := codersdk.ReadBodyAsError(res) + return nil, bodyErr + } + return wsjson.NewDecoder[ConnectionWatchEvent](conn, websocket.MessageText, c.client.Logger()), nil +} diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index d2968250d3..d428bb5bd2 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -18616,6 +18616,101 @@ None | `disable_direct_connections` | boolean | false | | | | `hostname_suffix` | string | false | | | +## workspacesdk.AgentUpdate + +```json +{ + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "lifecycle": "created" +} +``` + +### Properties + +| Name | 
Type | Required | Restrictions | Description | +|-------------|----------------------------------------------------------------------|----------|--------------|-------------| +| `id` | string | false | | | +| `lifecycle` | [codersdk.WorkspaceAgentLifecycle](#codersdkworkspaceagentlifecycle) | false | | | + +## workspacesdk.BuildUpdate + +```json +{ + "job_status": "pending", + "transition": "start" +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|--------------|----------------------------------------------------------------|----------|--------------|-------------| +| `job_status` | [codersdk.ProvisionerJobStatus](#codersdkprovisionerjobstatus) | false | | | +| `transition` | [codersdk.WorkspaceTransition](#codersdkworkspacetransition) | false | | | + +## workspacesdk.ConnectionWatchEvent + +```json +{ + "agent_update": { + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "lifecycle": "created" + }, + "build_update": { + "job_status": "pending", + "transition": "start" + }, + "error": { + "code": 0, + "details": "string", + "message": "string", + "retryable": true + } +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|----------------|------------------------------------------------------|----------|--------------|-------------| +| `agent_update` | [workspacesdk.AgentUpdate](#workspacesdkagentupdate) | false | | | +| `build_update` | [workspacesdk.BuildUpdate](#workspacesdkbuildupdate) | false | | | +| `error` | [workspacesdk.WatchError](#workspacesdkwatcherror) | false | | | + +## workspacesdk.WatchError + +```json +{ + "code": 0, + "details": "string", + "message": "string", + "retryable": true +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|-------------|------------------------------------------------------------|----------|--------------|-------------| +| `code` | [workspacesdk.WatchErrorCode](#workspacesdkwatcherrorcode) | false | | | +| `details` | 
string | false | | | +| `message` | string | false | | | +| `retryable` | boolean | false | | | + +## workspacesdk.WatchErrorCode + +```json +0 +``` + +### Properties + +#### Enumerated Values + +| Value(s) | +|-----------------------------------| +| `0`, `1`, `2`, `3`, `4`, `5`, `6` | + ## wsproxysdk.CryptoKeysResponse ```json diff --git a/docs/reference/api/workspaces.md b/docs/reference/api/workspaces.md index b1faf32448..82fdea7c39 100644 --- a/docs/reference/api/workspaces.md +++ b/docs/reference/api/workspaces.md @@ -1817,6 +1817,56 @@ curl -X PATCH http://coder-server:8080/api/v2/workspaces/{workspace}/acl \ To perform this operation, you must be authenticated. [Learn more](authentication.md). +## Workspace Agent Connection Watch + +### Code samples + +```shell +# Example request using curl +curl -X GET http://coder-server:8080/api/v2/workspaces/{workspace}/agent-connection-watch \ + -H 'Accept: application/json' \ + -H 'Coder-Session-Token: API_KEY' +``` + +`GET /api/v2/workspaces/{workspace}/agent-connection-watch` + +### Parameters + +| Name | In | Type | Required | Description | +|-------------|------|--------------|----------|--------------| +| `workspace` | path | string(uuid) | true | Workspace ID | + +### Example responses + +> 101 Response + +```json +{ + "agent_update": { + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "lifecycle": "created" + }, + "build_update": { + "job_status": "pending", + "transition": "start" + }, + "error": { + "code": 0, + "details": "string", + "message": "string", + "retryable": true + } +} +``` + +### Responses + +| Status | Meaning | Description | Schema | +|--------|--------------------------------------------------------------------------|---------------------|----------------------------------------------------------------------------------| +| 101 | [Switching Protocols](https://tools.ietf.org/html/rfc7231#section-6.2.2) | Switching Protocols | 
[workspacesdk.ConnectionWatchEvent](schemas.md#workspacesdkconnectionwatchevent) | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). + ## Update workspace autostart schedule by ID ### Code samples diff --git a/enterprise/coderd/coderdenttest/swagger_test.go b/enterprise/coderd/coderdenttest/swagger_test.go index f727a68a89..0a48695f38 100644 --- a/enterprise/coderd/coderdenttest/swagger_test.go +++ b/enterprise/coderd/coderdenttest/swagger_test.go @@ -12,7 +12,8 @@ import ( func TestEnterpriseEndpointsDocumented(t *testing.T) { t.Parallel() - swaggerComments, err := coderdtest.ParseSwaggerComments("..", "../../../coderd") + swaggerComments, err := coderdtest.ParseSwaggerComments( + "..", "../../../coderd", "../../../coderd/workspaceconnwatcher") require.NoError(t, err, "can't parse swagger comments") require.NotEmpty(t, swaggerComments, "swagger comments must be present") diff --git a/scripts/apidocgen/swaginit/main.go b/scripts/apidocgen/swaginit/main.go index b6a60bb59e..4774323e81 100644 --- a/scripts/apidocgen/swaginit/main.go +++ b/scripts/apidocgen/swaginit/main.go @@ -22,7 +22,7 @@ func main() { } err := gen.New().Build(&gen.Config{ - SearchDir: "./coderd,./codersdk,./enterprise/coderd,./enterprise/wsproxy/wsproxysdk", + SearchDir: "./coderd,./coderd/workspaceconnwatcher,./codersdk,./enterprise/coderd,./enterprise/wsproxy/wsproxysdk", MainAPIFile: "coderd.go", OutputDir: outputDir, OutputTypes: []string{"go", "json"}, diff --git a/testutil/websocket.go b/testutil/websocket.go new file mode 100644 index 0000000000..026f1c3590 --- /dev/null +++ b/testutil/websocket.go @@ -0,0 +1,101 @@ +package testutil + +import ( + "bufio" + "context" + "io" + "net" + "net/http" + + "cdr.dev/slog/v3" +) + +// InMemWebsocketRoundTripper allows you to "dial" an HTTP handler that sets up a websocket using only in-memory +// primitives. No TCP or OS networking needed. CtxMutator gives you explicit control over the context the handler sees. 
+// +// Example: +// +// rt := testutil.InMemWebsocketRoundTripper{ +// Handler: MyHandler, +// CtxMutator: func(ctx context.Context) context.Context { +// ctx = httpmw.WithWorkspaceParam(ctx, ws) +// ctx = dbauthz.As(ctx, coderdtest.MemberSubject(userID, orgID)) +// return ctx +// }, +// Logger: logger.Named("roundtripper"), +// } +// clientSock, _, err := websocket.Dial(ctx, "wss://local.test/", &websocket.DialOptions{ +// HTTPClient: &http.Client{Transport: rt}, +// }) +type InMemWebsocketRoundTripper struct { + Logger slog.Logger + Handler http.Handler + CtxMutator func(ctx context.Context) context.Context +} + +func (i InMemWebsocketRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { + i.Logger.Debug(context.Background(), "round trip start") + defer i.Logger.Debug(context.Background(), "round trip end") + newCtx := i.CtxMutator(request.Context()) + request = request.WithContext(newCtx) + serverP, clientP := net.Pipe() + var _ io.ReadWriteCloser = clientP // compile time check that response body is OK for websocket + response := &http.Response{ + Header: make(http.Header), + Body: clientP, + } + rw := newInMemWebsocketResponseWriter(response, serverP) + go func() { + i.Handler.ServeHTTP(rw, request) + if !rw.hijacked { + i.Logger.Debug(context.Background(), "closing connection after handler did not hijack") + // If the handler didn't hijack the connection, we should close it when the handler finishes. + // This prevents a 3s delay in websocket.Dial() reading the non-upgraded response. 
+ _ = serverP.Close() + } + }() + select { + case <-newCtx.Done(): + return nil, newCtx.Err() + case <-rw.gotHeaders: + return response, nil + } +} + +func newInMemWebsocketResponseWriter(resp *http.Response, conn net.Conn) *inMemWebsocketResponseWriter { + r := bufio.NewReader(conn) + w := bufio.NewWriter(conn) + return &inMemWebsocketResponseWriter{ + r: resp, + b: bufio.NewReadWriter(r, w), + gotHeaders: make(chan struct{}), + conn: conn, + } +} + +type inMemWebsocketResponseWriter struct { + r *http.Response + b *bufio.ReadWriter + gotHeaders chan struct{} + hijacked bool + conn net.Conn +} + +func (rw *inMemWebsocketResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + rw.hijacked = true + return rw.conn, rw.b, nil +} + +func (rw *inMemWebsocketResponseWriter) Header() http.Header { + return rw.r.Header +} + +func (rw *inMemWebsocketResponseWriter) Write([]byte) (int, error) { + n, err := rw.b.Write([]byte{}) + return n, err +} + +func (rw *inMemWebsocketResponseWriter) WriteHeader(statusCode int) { + rw.r.StatusCode = statusCode + close(rw.gotHeaders) +}