From 2840fdcb54502b46b5810b0c73389d2d66af9a5e Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Fri, 21 Nov 2025 13:09:27 +0200 Subject: [PATCH] feat(agent): add agent socket API (#20717) relates to: https://github.com/coder/internal/issues/1094 This is number 2 of 5 pull requests in an effort to add agent script ordering. It adds a drpc API that is exposed via a local socket. This API serves access to a lightweight DAG based dependency manager that was inspired by systemd. In follow-up PRs: * This unit manager will be plumbed into the workspace agent struct. * CLI commands will use this agentsocket api to express dependencies between coder scripts I used an LLM to produce some of these changes, but I have conducted thorough self review and consider this contribution to be ready for an external reviewer. --- Makefile | 10 + agent/agentsocket/proto/agentsocket.pb.go | 968 ++++++++++++++++++ agent/agentsocket/proto/agentsocket.proto | 69 ++ .../agentsocket/proto/agentsocket_drpc.pb.go | 311 ++++++ agent/agentsocket/proto/version.go | 17 + agent/agentsocket/server.go | 185 ++++ agent/agentsocket/server_test.go | 52 + agent/agentsocket/service.go | 142 +++ agent/agentsocket/service_test.go | 470 +++++++++ agent/agentsocket/socket_unix.go | 83 ++ agent/agentsocket/socket_windows.go | 27 + agent/unit/manager.go | 45 +- agent/unit/manager_test.go | 249 ++++- 13 files changed, 2568 insertions(+), 60 deletions(-) create mode 100644 agent/agentsocket/proto/agentsocket.pb.go create mode 100644 agent/agentsocket/proto/agentsocket.proto create mode 100644 agent/agentsocket/proto/agentsocket_drpc.pb.go create mode 100644 agent/agentsocket/proto/version.go create mode 100644 agent/agentsocket/server.go create mode 100644 agent/agentsocket/server_test.go create mode 100644 agent/agentsocket/service.go create mode 100644 agent/agentsocket/service_test.go create mode 100644 agent/agentsocket/socket_unix.go create mode 100644 agent/agentsocket/socket_windows.go diff --git a/Makefile b/Makefile index 7ecb64975e..4997430f9d 100644 --- a/Makefile +++ b/Makefile @@ -642,6 +642,7 @@ AIBRIDGED_MOCKS := \ GEN_FILES := \ tailnet/proto/tailnet.pb.go \ agent/proto/agent.pb.go \ + agent/agentsocket/proto/agentsocket.pb.go \ provisionersdk/proto/provisioner.pb.go \ provisionerd/proto/provisionerd.pb.go \ vpn/vpn.pb.go \ @@ -696,6 +697,7 @@ gen/mark-fresh: agent/proto/agent.pb.go \ provisionersdk/proto/provisioner.pb.go \ provisionerd/proto/provisionerd.pb.go \ + agent/agentsocket/proto/agentsocket.pb.go \ vpn/vpn.pb.go \ enterprise/aibridged/proto/aibridged.pb.go \ coderd/database/dump.sql \ @@ -800,6 +802,14 @@ agent/proto/agent.pb.go: agent/proto/agent.proto --go-drpc_opt=paths=source_relative \ ./agent/proto/agent.proto +agent/agentsocket/proto/agentsocket.pb.go: agent/agentsocket/proto/agentsocket.proto + protoc \ + --go_out=. \ + --go_opt=paths=source_relative \ + --go-drpc_out=. \ + --go-drpc_opt=paths=source_relative \ + ./agent/agentsocket/proto/agentsocket.proto + provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto protoc \ --go_out=. \ diff --git a/agent/agentsocket/proto/agentsocket.pb.go b/agent/agentsocket/proto/agentsocket.pb.go new file mode 100644 index 0000000000..b2b1d922a8 --- /dev/null +++ b/agent/agentsocket/proto/agentsocket.pb.go @@ -0,0 +1,968 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.30.0 +// protoc v4.23.4 +// source: agent/agentsocket/proto/agentsocket.proto + +package proto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type PingRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *PingRequest) Reset() { + *x = PingRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PingRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PingRequest) ProtoMessage() {} + +func (x *PingRequest) ProtoReflect() protoreflect.Message { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PingRequest.ProtoReflect.Descriptor instead. +func (*PingRequest) Descriptor() ([]byte, []int) { + return file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP(), []int{0} +} + +type PingResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *PingResponse) Reset() { + *x = PingResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PingResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PingResponse) ProtoMessage() {} + +func (x *PingResponse) ProtoReflect() protoreflect.Message { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PingResponse.ProtoReflect.Descriptor instead. +func (*PingResponse) Descriptor() ([]byte, []int) { + return file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP(), []int{1} +} + +type SyncStartRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"` +} + +func (x *SyncStartRequest) Reset() { + *x = SyncStartRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncStartRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncStartRequest) ProtoMessage() {} + +func (x *SyncStartRequest) ProtoReflect() protoreflect.Message { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncStartRequest.ProtoReflect.Descriptor instead. +func (*SyncStartRequest) Descriptor() ([]byte, []int) { + return file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP(), []int{2} +} + +func (x *SyncStartRequest) GetUnit() string { + if x != nil { + return x.Unit + } + return "" +} + +type SyncStartResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *SyncStartResponse) Reset() { + *x = SyncStartResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncStartResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncStartResponse) ProtoMessage() {} + +func (x *SyncStartResponse) ProtoReflect() protoreflect.Message { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncStartResponse.ProtoReflect.Descriptor instead. +func (*SyncStartResponse) Descriptor() ([]byte, []int) { + return file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP(), []int{3} +} + +type SyncWantRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"` + DependsOn string `protobuf:"bytes,2,opt,name=depends_on,json=dependsOn,proto3" json:"depends_on,omitempty"` +} + +func (x *SyncWantRequest) Reset() { + *x = SyncWantRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncWantRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncWantRequest) ProtoMessage() {} + +func (x *SyncWantRequest) ProtoReflect() protoreflect.Message { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncWantRequest.ProtoReflect.Descriptor instead. +func (*SyncWantRequest) Descriptor() ([]byte, []int) { + return file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP(), []int{4} +} + +func (x *SyncWantRequest) GetUnit() string { + if x != nil { + return x.Unit + } + return "" +} + +func (x *SyncWantRequest) GetDependsOn() string { + if x != nil { + return x.DependsOn + } + return "" +} + +type SyncWantResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *SyncWantResponse) Reset() { + *x = SyncWantResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncWantResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncWantResponse) ProtoMessage() {} + +func (x *SyncWantResponse) ProtoReflect() protoreflect.Message { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncWantResponse.ProtoReflect.Descriptor instead. +func (*SyncWantResponse) Descriptor() ([]byte, []int) { + return file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP(), []int{5} +} + +type SyncCompleteRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"` +} + +func (x *SyncCompleteRequest) Reset() { + *x = SyncCompleteRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncCompleteRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncCompleteRequest) ProtoMessage() {} + +func (x *SyncCompleteRequest) ProtoReflect() protoreflect.Message { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncCompleteRequest.ProtoReflect.Descriptor instead. +func (*SyncCompleteRequest) Descriptor() ([]byte, []int) { + return file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP(), []int{6} +} + +func (x *SyncCompleteRequest) GetUnit() string { + if x != nil { + return x.Unit + } + return "" +} + +type SyncCompleteResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *SyncCompleteResponse) Reset() { + *x = SyncCompleteResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncCompleteResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncCompleteResponse) ProtoMessage() {} + +func (x *SyncCompleteResponse) ProtoReflect() protoreflect.Message { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[7] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncCompleteResponse.ProtoReflect.Descriptor instead. +func (*SyncCompleteResponse) Descriptor() ([]byte, []int) { + return file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP(), []int{7} +} + +type SyncReadyRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"` +} + +func (x *SyncReadyRequest) Reset() { + *x = SyncReadyRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncReadyRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncReadyRequest) ProtoMessage() {} + +func (x *SyncReadyRequest) ProtoReflect() protoreflect.Message { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[8] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncReadyRequest.ProtoReflect.Descriptor instead. +func (*SyncReadyRequest) Descriptor() ([]byte, []int) { + return file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP(), []int{8} +} + +func (x *SyncReadyRequest) GetUnit() string { + if x != nil { + return x.Unit + } + return "" +} + +type SyncReadyResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Ready bool `protobuf:"varint,1,opt,name=ready,proto3" json:"ready,omitempty"` +} + +func (x *SyncReadyResponse) Reset() { + *x = SyncReadyResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncReadyResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncReadyResponse) ProtoMessage() {} + +func (x *SyncReadyResponse) ProtoReflect() protoreflect.Message { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[9] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncReadyResponse.ProtoReflect.Descriptor instead. +func (*SyncReadyResponse) Descriptor() ([]byte, []int) { + return file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP(), []int{9} +} + +func (x *SyncReadyResponse) GetReady() bool { + if x != nil { + return x.Ready + } + return false +} + +type SyncStatusRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"` +} + +func (x *SyncStatusRequest) Reset() { + *x = SyncStatusRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncStatusRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncStatusRequest) ProtoMessage() {} + +func (x *SyncStatusRequest) ProtoReflect() protoreflect.Message { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[10] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncStatusRequest.ProtoReflect.Descriptor instead. +func (*SyncStatusRequest) Descriptor() ([]byte, []int) { + return file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP(), []int{10} +} + +func (x *SyncStatusRequest) GetUnit() string { + if x != nil { + return x.Unit + } + return "" +} + +type DependencyInfo struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Unit string `protobuf:"bytes,1,opt,name=unit,proto3" json:"unit,omitempty"` + DependsOn string `protobuf:"bytes,2,opt,name=depends_on,json=dependsOn,proto3" json:"depends_on,omitempty"` + RequiredStatus string `protobuf:"bytes,3,opt,name=required_status,json=requiredStatus,proto3" json:"required_status,omitempty"` + CurrentStatus string `protobuf:"bytes,4,opt,name=current_status,json=currentStatus,proto3" json:"current_status,omitempty"` + IsSatisfied bool `protobuf:"varint,5,opt,name=is_satisfied,json=isSatisfied,proto3" json:"is_satisfied,omitempty"` +} + +func (x *DependencyInfo) Reset() { + *x = DependencyInfo{} + if protoimpl.UnsafeEnabled { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DependencyInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DependencyInfo) ProtoMessage() {} + +func (x *DependencyInfo) ProtoReflect() protoreflect.Message { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[11] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DependencyInfo.ProtoReflect.Descriptor instead. +func (*DependencyInfo) Descriptor() ([]byte, []int) { + return file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP(), []int{11} +} + +func (x *DependencyInfo) GetUnit() string { + if x != nil { + return x.Unit + } + return "" +} + +func (x *DependencyInfo) GetDependsOn() string { + if x != nil { + return x.DependsOn + } + return "" +} + +func (x *DependencyInfo) GetRequiredStatus() string { + if x != nil { + return x.RequiredStatus + } + return "" +} + +func (x *DependencyInfo) GetCurrentStatus() string { + if x != nil { + return x.CurrentStatus + } + return "" +} + +func (x *DependencyInfo) GetIsSatisfied() bool { + if x != nil { + return x.IsSatisfied + } + return false +} + +type SyncStatusResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Status string `protobuf:"bytes,1,opt,name=status,proto3" json:"status,omitempty"` + IsReady bool `protobuf:"varint,2,opt,name=is_ready,json=isReady,proto3" json:"is_ready,omitempty"` + Dependencies []*DependencyInfo `protobuf:"bytes,3,rep,name=dependencies,proto3" json:"dependencies,omitempty"` +} + +func (x *SyncStatusResponse) Reset() { + *x = SyncStatusResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncStatusResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncStatusResponse) ProtoMessage() {} + +func (x *SyncStatusResponse) ProtoReflect() protoreflect.Message { + mi := &file_agent_agentsocket_proto_agentsocket_proto_msgTypes[12] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncStatusResponse.ProtoReflect.Descriptor instead. +func (*SyncStatusResponse) Descriptor() ([]byte, []int) { + return file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP(), []int{12} +} + +func (x *SyncStatusResponse) GetStatus() string { + if x != nil { + return x.Status + } + return "" +} + +func (x *SyncStatusResponse) GetIsReady() bool { + if x != nil { + return x.IsReady + } + return false +} + +func (x *SyncStatusResponse) GetDependencies() []*DependencyInfo { + if x != nil { + return x.Dependencies + } + return nil +} + +var File_agent_agentsocket_proto_agentsocket_proto protoreflect.FileDescriptor + +var file_agent_agentsocket_proto_agentsocket_proto_rawDesc = []byte{ + 0x0a, 0x29, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, + 0x6b, 0x65, 0x74, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, + 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x14, 0x63, 0x6f, 0x64, + 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x76, + 0x31, 0x22, 0x0d, 0x0a, 0x0b, 0x50, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x22, 0x0e, 0x0a, 0x0c, 0x50, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x26, 0x0a, 0x10, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x72, 0x74, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x22, 0x13, 0x0a, 0x11, 0x53, 0x79, 0x6e, 0x63, + 0x53, 0x74, 0x61, 0x72, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x44, 0x0a, + 0x0f, 0x53, 0x79, 0x6e, 0x63, 0x57, 0x61, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x12, 0x0a, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x75, 0x6e, 0x69, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x64, 0x65, 0x70, 0x65, 0x6e, 0x64, 0x73, 0x5f, + 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x64, 0x65, 0x70, 0x65, 0x6e, 0x64, + 0x73, 0x4f, 0x6e, 0x22, 0x12, 0x0a, 0x10, 0x53, 0x79, 0x6e, 0x63, 0x57, 0x61, 0x6e, 0x74, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x29, 0x0a, 0x13, 0x53, 0x79, 0x6e, 0x63, 0x43, + 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, + 0x0a, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x6e, + 0x69, 0x74, 0x22, 0x16, 0x0a, 0x14, 0x53, 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, + 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x26, 0x0a, 0x10, 0x53, 0x79, + 0x6e, 0x63, 0x52, 0x65, 0x61, 0x64, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, + 0x0a, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x6e, + 0x69, 0x74, 0x22, 0x29, 0x0a, 0x11, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x61, 0x64, 0x79, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x72, 0x65, 0x61, 0x64, 0x79, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x72, 0x65, 0x61, 0x64, 0x79, 0x22, 0x27, 0x0a, + 0x11, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x22, 0xb6, 0x01, 0x0a, 0x0e, 0x44, 0x65, 0x70, 0x65, 0x6e, + 0x64, 0x65, 0x6e, 0x63, 0x79, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x6e, 0x69, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x12, 0x1d, 0x0a, + 0x0a, 0x64, 0x65, 0x70, 0x65, 0x6e, 0x64, 0x73, 0x5f, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x09, 0x64, 0x65, 0x70, 0x65, 0x6e, 0x64, 0x73, 0x4f, 0x6e, 0x12, 0x27, 0x0a, 0x0f, + 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, + 0x5f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x63, + 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x21, 0x0a, 0x0c, + 0x69, 0x73, 0x5f, 0x73, 0x61, 0x74, 0x69, 0x73, 0x66, 0x69, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0b, 0x69, 0x73, 0x53, 0x61, 0x74, 0x69, 0x73, 0x66, 0x69, 0x65, 0x64, 0x22, + 0x91, 0x01, 0x0a, 0x12, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x19, + 0x0a, 0x08, 0x69, 0x73, 0x5f, 0x72, 0x65, 0x61, 0x64, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x07, 0x69, 0x73, 0x52, 0x65, 0x61, 0x64, 0x79, 0x12, 0x48, 0x0a, 0x0c, 0x64, 0x65, 0x70, + 0x65, 0x6e, 0x64, 0x65, 0x6e, 0x63, 0x69, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x24, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, + 0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x65, 0x70, 0x65, 0x6e, 0x64, 0x65, 0x6e, 0x63, + 0x79, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0c, 0x64, 0x65, 0x70, 0x65, 0x6e, 0x64, 0x65, 0x6e, 0x63, + 0x69, 0x65, 0x73, 0x32, 0xbb, 0x04, 0x0a, 0x0b, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x53, 0x6f, 0x63, + 0x6b, 0x65, 0x74, 0x12, 0x4d, 0x0a, 0x04, 0x50, 0x69, 0x6e, 0x67, 0x12, 0x21, 0x2e, 0x63, 0x6f, + 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, + 0x76, 0x31, 0x2e, 0x50, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, + 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, + 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x50, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x5c, 0x0a, 0x09, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x72, 0x74, 0x12, + 0x26, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, + 0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x72, 0x74, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x27, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, + 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, + 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x72, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x59, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x57, 0x61, 0x6e, 0x74, 0x12, 0x25, 0x2e, 0x63, + 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, + 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x57, 0x61, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x26, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, + 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x57, + 0x61, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x65, 0x0a, 0x0c, 0x53, + 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x12, 0x29, 0x2e, 0x63, 0x6f, + 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, + 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2a, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, + 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, + 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x5c, 0x0a, 0x09, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x61, 0x64, 0x79, 0x12, + 0x26, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, + 0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x61, 0x64, 0x79, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x27, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, + 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, + 0x79, 0x6e, 0x63, 0x52, 0x65, 0x61, 0x64, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x5f, 0x0a, 0x0a, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x27, + 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, + 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2e, + 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, + 0x79, 0x6e, 0x63, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x42, 0x33, 0x5a, 0x31, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, + 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, 0x76, 0x32, 0x2f, 0x61, + 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, + 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_agent_agentsocket_proto_agentsocket_proto_rawDescOnce sync.Once + file_agent_agentsocket_proto_agentsocket_proto_rawDescData = file_agent_agentsocket_proto_agentsocket_proto_rawDesc +) + +func file_agent_agentsocket_proto_agentsocket_proto_rawDescGZIP() []byte { + file_agent_agentsocket_proto_agentsocket_proto_rawDescOnce.Do(func() { + file_agent_agentsocket_proto_agentsocket_proto_rawDescData = protoimpl.X.CompressGZIP(file_agent_agentsocket_proto_agentsocket_proto_rawDescData) + }) + return file_agent_agentsocket_proto_agentsocket_proto_rawDescData +} + +var file_agent_agentsocket_proto_agentsocket_proto_msgTypes = make([]protoimpl.MessageInfo, 13) +var file_agent_agentsocket_proto_agentsocket_proto_goTypes = []interface{}{ + (*PingRequest)(nil), // 0: coder.agentsocket.v1.PingRequest + (*PingResponse)(nil), // 1: coder.agentsocket.v1.PingResponse + (*SyncStartRequest)(nil), // 2: coder.agentsocket.v1.SyncStartRequest + (*SyncStartResponse)(nil), // 3: coder.agentsocket.v1.SyncStartResponse + (*SyncWantRequest)(nil), // 4: coder.agentsocket.v1.SyncWantRequest + (*SyncWantResponse)(nil), // 5: coder.agentsocket.v1.SyncWantResponse + (*SyncCompleteRequest)(nil), // 6: coder.agentsocket.v1.SyncCompleteRequest + (*SyncCompleteResponse)(nil), // 7: coder.agentsocket.v1.SyncCompleteResponse + (*SyncReadyRequest)(nil), // 8: coder.agentsocket.v1.SyncReadyRequest + (*SyncReadyResponse)(nil), // 9: coder.agentsocket.v1.SyncReadyResponse + (*SyncStatusRequest)(nil), // 10: coder.agentsocket.v1.SyncStatusRequest + (*DependencyInfo)(nil), // 11: coder.agentsocket.v1.DependencyInfo + (*SyncStatusResponse)(nil), // 12: coder.agentsocket.v1.SyncStatusResponse +} +var file_agent_agentsocket_proto_agentsocket_proto_depIdxs = []int32{ + 11, // 0: coder.agentsocket.v1.SyncStatusResponse.dependencies:type_name -> coder.agentsocket.v1.DependencyInfo + 0, // 1: coder.agentsocket.v1.AgentSocket.Ping:input_type -> coder.agentsocket.v1.PingRequest + 2, // 2: coder.agentsocket.v1.AgentSocket.SyncStart:input_type -> coder.agentsocket.v1.SyncStartRequest + 4, // 3: coder.agentsocket.v1.AgentSocket.SyncWant:input_type -> coder.agentsocket.v1.SyncWantRequest + 6, // 4: coder.agentsocket.v1.AgentSocket.SyncComplete:input_type -> coder.agentsocket.v1.SyncCompleteRequest + 8, // 5: coder.agentsocket.v1.AgentSocket.SyncReady:input_type -> coder.agentsocket.v1.SyncReadyRequest + 10, // 6: coder.agentsocket.v1.AgentSocket.SyncStatus:input_type -> coder.agentsocket.v1.SyncStatusRequest + 1, // 7: coder.agentsocket.v1.AgentSocket.Ping:output_type -> coder.agentsocket.v1.PingResponse + 3, // 8: coder.agentsocket.v1.AgentSocket.SyncStart:output_type -> coder.agentsocket.v1.SyncStartResponse + 5, // 9: coder.agentsocket.v1.AgentSocket.SyncWant:output_type -> coder.agentsocket.v1.SyncWantResponse + 7, // 10: coder.agentsocket.v1.AgentSocket.SyncComplete:output_type -> coder.agentsocket.v1.SyncCompleteResponse + 9, // 11: coder.agentsocket.v1.AgentSocket.SyncReady:output_type -> coder.agentsocket.v1.SyncReadyResponse + 12, // 12: coder.agentsocket.v1.AgentSocket.SyncStatus:output_type -> coder.agentsocket.v1.SyncStatusResponse + 7, // [7:13] is the sub-list for method output_type + 1, // [1:7] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_agent_agentsocket_proto_agentsocket_proto_init() } +func file_agent_agentsocket_proto_agentsocket_proto_init() { + if File_agent_agentsocket_proto_agentsocket_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_agent_agentsocket_proto_agentsocket_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PingRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_agent_agentsocket_proto_agentsocket_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PingResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_agent_agentsocket_proto_agentsocket_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncStartRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_agent_agentsocket_proto_agentsocket_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncStartResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_agent_agentsocket_proto_agentsocket_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncWantRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_agent_agentsocket_proto_agentsocket_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncWantResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_agent_agentsocket_proto_agentsocket_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncCompleteRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_agent_agentsocket_proto_agentsocket_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncCompleteResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_agent_agentsocket_proto_agentsocket_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncReadyRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_agent_agentsocket_proto_agentsocket_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncReadyResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_agent_agentsocket_proto_agentsocket_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncStatusRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_agent_agentsocket_proto_agentsocket_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DependencyInfo); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_agent_agentsocket_proto_agentsocket_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncStatusResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_agent_agentsocket_proto_agentsocket_proto_rawDesc, + NumEnums: 0, + NumMessages: 13, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_agent_agentsocket_proto_agentsocket_proto_goTypes, + DependencyIndexes: file_agent_agentsocket_proto_agentsocket_proto_depIdxs, + MessageInfos: file_agent_agentsocket_proto_agentsocket_proto_msgTypes, + }.Build() + File_agent_agentsocket_proto_agentsocket_proto = out.File + file_agent_agentsocket_proto_agentsocket_proto_rawDesc = nil + file_agent_agentsocket_proto_agentsocket_proto_goTypes = nil + file_agent_agentsocket_proto_agentsocket_proto_depIdxs = nil +} diff --git a/agent/agentsocket/proto/agentsocket.proto b/agent/agentsocket/proto/agentsocket.proto new file mode 100644 index 0000000000..2da2ad7380 --- /dev/null +++ b/agent/agentsocket/proto/agentsocket.proto @@ -0,0 +1,69 @@ +syntax = "proto3"; +option go_package = "github.com/coder/coder/v2/agent/agentsocket/proto"; + +package coder.agentsocket.v1; + +message PingRequest {} + +message PingResponse {} + +message SyncStartRequest { + string unit = 1; +} + +message SyncStartResponse {} + +message SyncWantRequest { + string unit = 1; + string depends_on = 2; +} + +message SyncWantResponse {} + +message SyncCompleteRequest { + string unit = 1; +} + +message SyncCompleteResponse {} + +message SyncReadyRequest { + string unit = 1; +} + +message SyncReadyResponse { + bool ready = 1; +} + +message SyncStatusRequest { + string unit = 1; +} + +message DependencyInfo { + string unit = 1; + string depends_on = 2; + string required_status = 3; + string current_status = 4; + bool is_satisfied = 5; +} + +message SyncStatusResponse { + string status = 1; + bool is_ready = 2; + repeated DependencyInfo dependencies = 3; +} + +// AgentSocket provides direct access to the agent over local IPC. +service AgentSocket { + // Ping the agent to check if it is alive. + rpc Ping(PingRequest) returns (PingResponse); + // Report the start of a unit. + rpc SyncStart(SyncStartRequest) returns (SyncStartResponse); + // Declare a dependency between units. + rpc SyncWant(SyncWantRequest) returns (SyncWantResponse); + // Report the completion of a unit. + rpc SyncComplete(SyncCompleteRequest) returns (SyncCompleteResponse); + // Request whether a unit is ready to be started. That is, all dependencies are satisfied. + rpc SyncReady(SyncReadyRequest) returns (SyncReadyResponse); + // Get the status of a unit and list its dependencies. + rpc SyncStatus(SyncStatusRequest) returns (SyncStatusResponse); +} diff --git a/agent/agentsocket/proto/agentsocket_drpc.pb.go b/agent/agentsocket/proto/agentsocket_drpc.pb.go new file mode 100644 index 0000000000..f9749ee0ff --- /dev/null +++ b/agent/agentsocket/proto/agentsocket_drpc.pb.go @@ -0,0 +1,311 @@ +// Code generated by protoc-gen-go-drpc. DO NOT EDIT. +// protoc-gen-go-drpc version: v0.0.34 +// source: agent/agentsocket/proto/agentsocket.proto + +package proto + +import ( + context "context" + errors "errors" + protojson "google.golang.org/protobuf/encoding/protojson" + proto "google.golang.org/protobuf/proto" + drpc "storj.io/drpc" + drpcerr "storj.io/drpc/drpcerr" +) + +type drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto struct{} + +func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) Marshal(msg drpc.Message) ([]byte, error) { + return proto.Marshal(msg.(proto.Message)) +} + +func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) { + return proto.MarshalOptions{}.MarshalAppend(buf, msg.(proto.Message)) +} + +func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) Unmarshal(buf []byte, msg drpc.Message) error { + return proto.Unmarshal(buf, msg.(proto.Message)) +} + +func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) JSONMarshal(msg drpc.Message) ([]byte, error) { + return protojson.Marshal(msg.(proto.Message)) +} + +func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error { + return protojson.Unmarshal(buf, msg.(proto.Message)) +} + +type DRPCAgentSocketClient interface { + DRPCConn() drpc.Conn + + Ping(ctx context.Context, in *PingRequest) (*PingResponse, error) + SyncStart(ctx context.Context, in *SyncStartRequest) (*SyncStartResponse, error) + SyncWant(ctx context.Context, in *SyncWantRequest) (*SyncWantResponse, error) + SyncComplete(ctx context.Context, in *SyncCompleteRequest) (*SyncCompleteResponse, error) + SyncReady(ctx context.Context, in *SyncReadyRequest) (*SyncReadyResponse, error) + SyncStatus(ctx context.Context, in *SyncStatusRequest) (*SyncStatusResponse, error) +} + +type drpcAgentSocketClient struct { + cc drpc.Conn +} + +func NewDRPCAgentSocketClient(cc drpc.Conn) DRPCAgentSocketClient { + return &drpcAgentSocketClient{cc} +} + +func (c *drpcAgentSocketClient) DRPCConn() drpc.Conn { return c.cc } + +func (c *drpcAgentSocketClient) Ping(ctx context.Context, in *PingRequest) (*PingResponse, error) { + out := new(PingResponse) + err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/Ping", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *drpcAgentSocketClient) SyncStart(ctx context.Context, in *SyncStartRequest) (*SyncStartResponse, error) { + out := new(SyncStartResponse) + err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncStart", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *drpcAgentSocketClient) SyncWant(ctx context.Context, in *SyncWantRequest) (*SyncWantResponse, error) { + out := new(SyncWantResponse) + err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncWant", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *drpcAgentSocketClient) SyncComplete(ctx context.Context, in *SyncCompleteRequest) (*SyncCompleteResponse, error) { + out := new(SyncCompleteResponse) + err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncComplete", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *drpcAgentSocketClient) SyncReady(ctx context.Context, in *SyncReadyRequest) (*SyncReadyResponse, error) { + out := new(SyncReadyResponse) + err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncReady", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *drpcAgentSocketClient) SyncStatus(ctx context.Context, in *SyncStatusRequest) (*SyncStatusResponse, error) { + out := new(SyncStatusResponse) + err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncStatus", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out) + if err != nil { + return nil, err + } + return out, nil +} + +type DRPCAgentSocketServer interface { + Ping(context.Context, *PingRequest) (*PingResponse, error) + SyncStart(context.Context, *SyncStartRequest) (*SyncStartResponse, error) + SyncWant(context.Context, *SyncWantRequest) (*SyncWantResponse, error) + SyncComplete(context.Context, *SyncCompleteRequest) (*SyncCompleteResponse, error) + SyncReady(context.Context, *SyncReadyRequest) (*SyncReadyResponse, error) + SyncStatus(context.Context, *SyncStatusRequest) (*SyncStatusResponse, error) +} + +type DRPCAgentSocketUnimplementedServer struct{} + +func (s *DRPCAgentSocketUnimplementedServer) Ping(context.Context, *PingRequest) (*PingResponse, error) { + return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) +} + +func (s *DRPCAgentSocketUnimplementedServer) SyncStart(context.Context, *SyncStartRequest) (*SyncStartResponse, error) { + return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) +} + +func (s *DRPCAgentSocketUnimplementedServer) SyncWant(context.Context, *SyncWantRequest) (*SyncWantResponse, error) { + return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) +} + +func (s *DRPCAgentSocketUnimplementedServer) SyncComplete(context.Context, *SyncCompleteRequest) (*SyncCompleteResponse, error) { + return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) +} + +func (s *DRPCAgentSocketUnimplementedServer) SyncReady(context.Context, *SyncReadyRequest) (*SyncReadyResponse, error) { + return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) +} + +func (s *DRPCAgentSocketUnimplementedServer) SyncStatus(context.Context, *SyncStatusRequest) (*SyncStatusResponse, error) { + return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) +} + +type DRPCAgentSocketDescription struct{} + +func (DRPCAgentSocketDescription) NumMethods() int { return 6 } + +func (DRPCAgentSocketDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) { + switch n { + case 0: + return "/coder.agentsocket.v1.AgentSocket/Ping", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, + func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { + return srv.(DRPCAgentSocketServer). + Ping( + ctx, + in1.(*PingRequest), + ) + }, DRPCAgentSocketServer.Ping, true + case 1: + return "/coder.agentsocket.v1.AgentSocket/SyncStart", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, + func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { + return srv.(DRPCAgentSocketServer). + SyncStart( + ctx, + in1.(*SyncStartRequest), + ) + }, DRPCAgentSocketServer.SyncStart, true + case 2: + return "/coder.agentsocket.v1.AgentSocket/SyncWant", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, + func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { + return srv.(DRPCAgentSocketServer). + SyncWant( + ctx, + in1.(*SyncWantRequest), + ) + }, DRPCAgentSocketServer.SyncWant, true + case 3: + return "/coder.agentsocket.v1.AgentSocket/SyncComplete", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, + func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { + return srv.(DRPCAgentSocketServer). + SyncComplete( + ctx, + in1.(*SyncCompleteRequest), + ) + }, DRPCAgentSocketServer.SyncComplete, true + case 4: + return "/coder.agentsocket.v1.AgentSocket/SyncReady", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, + func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { + return srv.(DRPCAgentSocketServer). + SyncReady( + ctx, + in1.(*SyncReadyRequest), + ) + }, DRPCAgentSocketServer.SyncReady, true + case 5: + return "/coder.agentsocket.v1.AgentSocket/SyncStatus", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, + func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { + return srv.(DRPCAgentSocketServer). + SyncStatus( + ctx, + in1.(*SyncStatusRequest), + ) + }, DRPCAgentSocketServer.SyncStatus, true + default: + return "", nil, nil, nil, false + } +} + +func DRPCRegisterAgentSocket(mux drpc.Mux, impl DRPCAgentSocketServer) error { + return mux.Register(impl, DRPCAgentSocketDescription{}) +} + +type DRPCAgentSocket_PingStream interface { + drpc.Stream + SendAndClose(*PingResponse) error +} + +type drpcAgentSocket_PingStream struct { + drpc.Stream +} + +func (x *drpcAgentSocket_PingStream) SendAndClose(m *PingResponse) error { + if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil { + return err + } + return x.CloseSend() +} + +type DRPCAgentSocket_SyncStartStream interface { + drpc.Stream + SendAndClose(*SyncStartResponse) error +} + +type drpcAgentSocket_SyncStartStream struct { + drpc.Stream +} + +func (x *drpcAgentSocket_SyncStartStream) SendAndClose(m *SyncStartResponse) error { + if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil { + return err + } + return x.CloseSend() +} + +type DRPCAgentSocket_SyncWantStream interface { + drpc.Stream + SendAndClose(*SyncWantResponse) error +} + +type drpcAgentSocket_SyncWantStream struct { + drpc.Stream +} + +func (x *drpcAgentSocket_SyncWantStream) SendAndClose(m *SyncWantResponse) error { + if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil { + return err + } + return x.CloseSend() +} + +type DRPCAgentSocket_SyncCompleteStream interface { + drpc.Stream + SendAndClose(*SyncCompleteResponse) error +} + +type drpcAgentSocket_SyncCompleteStream struct { + drpc.Stream +} + +func (x *drpcAgentSocket_SyncCompleteStream) SendAndClose(m *SyncCompleteResponse) error { + if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil { + return err + } + return x.CloseSend() +} + +type DRPCAgentSocket_SyncReadyStream interface { + drpc.Stream + SendAndClose(*SyncReadyResponse) error +} + +type drpcAgentSocket_SyncReadyStream struct { + drpc.Stream +} + +func (x *drpcAgentSocket_SyncReadyStream) SendAndClose(m *SyncReadyResponse) error { + if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil { + return err + } + return x.CloseSend() +} + +type DRPCAgentSocket_SyncStatusStream interface { + drpc.Stream + SendAndClose(*SyncStatusResponse) error +} + +type drpcAgentSocket_SyncStatusStream struct { + drpc.Stream +} + +func (x *drpcAgentSocket_SyncStatusStream) SendAndClose(m *SyncStatusResponse) error { + if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil { + return err + } + return x.CloseSend() +} diff --git a/agent/agentsocket/proto/version.go b/agent/agentsocket/proto/version.go new file mode 100644 index 0000000000..9c6f2cb2a4 --- /dev/null +++ b/agent/agentsocket/proto/version.go @@ -0,0 +1,17 @@ +package proto + +import "github.com/coder/coder/v2/apiversion" + +// Version history: +// +// API v1.0: +// - Initial release +// - Ping +// - Sync operations: SyncStart, SyncWant, SyncComplete, SyncWait, SyncStatus + +const ( + CurrentMajor = 1 + CurrentMinor = 0 +) + +var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor) diff --git a/agent/agentsocket/server.go b/agent/agentsocket/server.go new file mode 100644 index 0000000000..c9f9a4ca42 --- /dev/null +++ b/agent/agentsocket/server.go @@ -0,0 +1,185 @@ +package agentsocket + +import ( + "context" + "errors" + "net" + "sync" + + "golang.org/x/xerrors" + + "github.com/hashicorp/yamux" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + + "cdr.dev/slog" + "github.com/coder/coder/v2/agent/agentsocket/proto" + "github.com/coder/coder/v2/agent/unit" + "github.com/coder/coder/v2/codersdk/drpcsdk" +) + +// Server provides access to the DRPCAgentSocketService via a Unix domain socket. +// Do not invoke Server{} directly. Use NewServer() instead. +type Server struct { + logger slog.Logger + path string + drpcServer *drpcserver.Server + service *DRPCAgentSocketService + + mu sync.Mutex + listener net.Listener + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +func NewServer(path string, logger slog.Logger) (*Server, error) { + logger = logger.Named("agentsocket-server") + server := &Server{ + logger: logger, + path: path, + service: &DRPCAgentSocketService{ + logger: logger, + unitManager: unit.NewManager(), + }, + } + + mux := drpcmux.New() + err := proto.DRPCRegisterAgentSocket(mux, server.service) + if err != nil { + return nil, xerrors.Errorf("failed to register drpc service: %w", err) + } + + server.drpcServer = drpcserver.NewWithOptions(mux, drpcserver.Options{ + Manager: drpcsdk.DefaultDRPCOptions(nil), + Log: func(err error) { + if errors.Is(err, context.Canceled) || + errors.Is(err, context.DeadlineExceeded) { + return + } + logger.Debug(context.Background(), "drpc server error", slog.Error(err)) + }, + }) + + if server.path == "" { + var err error + server.path, err = getDefaultSocketPath() + if err != nil { + return nil, xerrors.Errorf("get default socket path: %w", err) + } + } + + listener, err := createSocket(server.path) + if err != nil { + return nil, xerrors.Errorf("create socket: %w", err) + } + + server.listener = listener + + // This context is canceled by server.Close(). + // canceling it will close all connections. + server.ctx, server.cancel = context.WithCancel(context.Background()) + + server.logger.Info(server.ctx, "agent socket server started", slog.F("path", server.path)) + + server.wg.Add(1) + go func() { + defer server.wg.Done() + server.acceptConnections() + }() + + return server, nil +} + +func (s *Server) Close() error { + s.mu.Lock() + + if s.listener == nil { + s.mu.Unlock() + return nil + } + + s.logger.Info(s.ctx, "stopping agent socket server") + + s.cancel() + + if err := s.listener.Close(); err != nil { + s.logger.Warn(s.ctx, "error closing socket listener", slog.Error(err)) + } + + s.listener = nil + + s.mu.Unlock() + + // Wait for all connections to finish + s.wg.Wait() + + if err := cleanupSocket(s.path); err != nil { + s.logger.Warn(s.ctx, "error cleaning up socket file", slog.Error(err)) + } + + s.logger.Info(s.ctx, "agent socket server stopped") + + return nil +} + +func (s *Server) acceptConnections() { + // In an edge case, Close() might race with acceptConnections() and set s.listener to nil. + // Therefore, we grab a copy of the listener under a lock. We might still get a nil listener, + // but then we know close has already run and we can return early. + s.mu.Lock() + listener := s.listener + s.mu.Unlock() + if listener == nil { + return + } + + for { + select { + case <-s.ctx.Done(): + return + default: + } + + conn, err := listener.Accept() + if err != nil { + s.logger.Warn(s.ctx, "error accepting connection", slog.Error(err)) + continue + } + + s.mu.Lock() + if s.listener == nil { + s.mu.Unlock() + _ = conn.Close() + return + } + s.wg.Add(1) + s.mu.Unlock() + + go func() { + defer s.wg.Done() + s.handleConnection(conn) + }() + } +} + +func (s *Server) handleConnection(conn net.Conn) { + defer conn.Close() + + s.logger.Debug(s.ctx, "new connection accepted", slog.F("remote_addr", conn.RemoteAddr())) + + config := yamux.DefaultConfig() + config.LogOutput = nil + config.Logger = slog.Stdlib(s.ctx, s.logger.Named("agentsocket-yamux"), slog.LevelInfo) + session, err := yamux.Server(conn, config) + if err != nil { + s.logger.Warn(s.ctx, "failed to create yamux session", slog.Error(err)) + return + } + defer session.Close() + + err = s.drpcServer.Serve(s.ctx, session) + if err != nil { + s.logger.Debug(s.ctx, "drpc server finished", slog.Error(err)) + } +} diff --git a/agent/agentsocket/server_test.go b/agent/agentsocket/server_test.go new file mode 100644 index 0000000000..cf06aff170 --- /dev/null +++ b/agent/agentsocket/server_test.go @@ -0,0 +1,52 @@ +package agentsocket_test + +import ( + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "github.com/coder/coder/v2/agent/agentsocket" +) + +func TestServer(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.Skip("agentsocket is not supported on Windows") + } + + t.Run("StartStop", func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "test.sock") + logger := slog.Make().Leveled(slog.LevelDebug) + server, err := agentsocket.NewServer(socketPath, logger) + require.NoError(t, err) + require.NoError(t, server.Close()) + }) + + t.Run("AlreadyStarted", func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "test.sock") + logger := slog.Make().Leveled(slog.LevelDebug) + server1, err := agentsocket.NewServer(socketPath, logger) + require.NoError(t, err) + defer server1.Close() + _, err = agentsocket.NewServer(socketPath, logger) + require.ErrorContains(t, err, "create socket") + }) + + t.Run("AutoSocketPath", func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "test.sock") + logger := slog.Make().Leveled(slog.LevelDebug) + server, err := agentsocket.NewServer(socketPath, logger) + require.NoError(t, err) + require.NoError(t, server.Close()) + }) +} diff --git a/agent/agentsocket/service.go b/agent/agentsocket/service.go new file mode 100644 index 0000000000..f3384dfd43 --- /dev/null +++ b/agent/agentsocket/service.go @@ -0,0 +1,142 @@ +package agentsocket + +import ( + "context" + "errors" + + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/agent/agentsocket/proto" + "github.com/coder/coder/v2/agent/unit" +) + +var _ proto.DRPCAgentSocketServer = (*DRPCAgentSocketService)(nil) + +var ErrUnitManagerNotAvailable = xerrors.New("unit manager not available") + +type DRPCAgentSocketService struct { + unitManager *unit.Manager + logger slog.Logger +} + +func (*DRPCAgentSocketService) Ping(_ context.Context, _ *proto.PingRequest) (*proto.PingResponse, error) { + return &proto.PingResponse{}, nil +} + +func (s *DRPCAgentSocketService) SyncStart(_ context.Context, req *proto.SyncStartRequest) (*proto.SyncStartResponse, error) { + if s.unitManager == nil { + return nil, xerrors.Errorf("SyncStart: %w", ErrUnitManagerNotAvailable) + } + + unitID := unit.ID(req.Unit) + + if err := s.unitManager.Register(unitID); err != nil { + if !errors.Is(err, unit.ErrUnitAlreadyRegistered) { + return nil, xerrors.Errorf("SyncStart: %w", err) + } + } + + isReady, err := s.unitManager.IsReady(unitID) + if err != nil { + return nil, xerrors.Errorf("cannot check readiness: %w", err) + } + if !isReady { + return nil, xerrors.Errorf("cannot start unit %q: unit not ready", req.Unit) + } + + err = s.unitManager.UpdateStatus(unitID, unit.StatusStarted) + if err != nil { + return nil, xerrors.Errorf("cannot start unit %q: %w", req.Unit, err) + } + + return &proto.SyncStartResponse{}, nil +} + +func (s *DRPCAgentSocketService) SyncWant(_ context.Context, req *proto.SyncWantRequest) (*proto.SyncWantResponse, error) { + if s.unitManager == nil { + return nil, xerrors.Errorf("cannot add dependency: %w", ErrUnitManagerNotAvailable) + } + + unitID := unit.ID(req.Unit) + dependsOnID := unit.ID(req.DependsOn) + + if err := s.unitManager.Register(unitID); err != nil && !errors.Is(err, unit.ErrUnitAlreadyRegistered) { + return nil, xerrors.Errorf("cannot add dependency: %w", err) + } + + if err := s.unitManager.AddDependency(unitID, dependsOnID, unit.StatusComplete); err != nil { + return nil, xerrors.Errorf("cannot add dependency: %w", err) + } + + return &proto.SyncWantResponse{}, nil +} + +func (s *DRPCAgentSocketService) SyncComplete(_ context.Context, req *proto.SyncCompleteRequest) (*proto.SyncCompleteResponse, error) { + if s.unitManager == nil { + return nil, xerrors.Errorf("cannot complete unit: %w", ErrUnitManagerNotAvailable) + } + + unitID := unit.ID(req.Unit) + + if err := s.unitManager.UpdateStatus(unitID, unit.StatusComplete); err != nil { + return nil, xerrors.Errorf("cannot complete unit %q: %w", req.Unit, err) + } + + return &proto.SyncCompleteResponse{}, nil +} + +func (s *DRPCAgentSocketService) SyncReady(_ context.Context, req *proto.SyncReadyRequest) (*proto.SyncReadyResponse, error) { + if s.unitManager == nil { + return nil, xerrors.Errorf("cannot check readiness: %w", ErrUnitManagerNotAvailable) + } + + unitID := unit.ID(req.Unit) + isReady, err := s.unitManager.IsReady(unitID) + if err != nil { + return nil, xerrors.Errorf("cannot check readiness: %w", err) + } + + return &proto.SyncReadyResponse{ + Ready: isReady, + }, nil +} + +func (s *DRPCAgentSocketService) SyncStatus(_ context.Context, req *proto.SyncStatusRequest) (*proto.SyncStatusResponse, error) { + if s.unitManager == nil { + return nil, xerrors.Errorf("cannot get status for unit %q: %w", req.Unit, ErrUnitManagerNotAvailable) + } + + unitID := unit.ID(req.Unit) + + isReady, err := s.unitManager.IsReady(unitID) + if err != nil { + return nil, xerrors.Errorf("cannot check readiness: %w", err) + } + + dependencies, err := s.unitManager.GetAllDependencies(unitID) + if err != nil { + return nil, xerrors.Errorf("failed to get dependencies: %w", err) + } + + var depInfos []*proto.DependencyInfo + for _, dep := range dependencies { + depInfos = append(depInfos, &proto.DependencyInfo{ + Unit: string(dep.Unit), + DependsOn: string(dep.DependsOn), + RequiredStatus: string(dep.RequiredStatus), + CurrentStatus: string(dep.CurrentStatus), + IsSatisfied: dep.IsSatisfied, + }) + } + + u, err := s.unitManager.Unit(unitID) + if err != nil { + return nil, xerrors.Errorf("cannot get status for unit %q: %w", req.Unit, err) + } + return &proto.SyncStatusResponse{ + Status: string(u.Status()), + IsReady: isReady, + Dependencies: depInfos, + }, nil +} diff --git a/agent/agentsocket/service_test.go b/agent/agentsocket/service_test.go new file mode 100644 index 0000000000..0d6be345b9 --- /dev/null +++ b/agent/agentsocket/service_test.go @@ -0,0 +1,470 @@ +package agentsocket_test + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "net" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/hashicorp/yamux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "github.com/coder/coder/v2/agent/agentsocket" + "github.com/coder/coder/v2/agent/agentsocket/proto" + "github.com/coder/coder/v2/agent/unit" + "github.com/coder/coder/v2/codersdk/drpcsdk" +) + +// tempDirUnixSocket returns a temporary directory that can safely hold unix +// sockets (probably). +// +// During tests on darwin we hit the max path length limit for unix sockets +// pretty easily in the default location, so this function uses /tmp instead to +// get shorter paths. To keep paths short, we use a hash of the test name +// instead of the full test name. +func tempDirUnixSocket(t *testing.T) string { + t.Helper() + if runtime.GOOS == "darwin" { + // Use a short hash of the test name to keep the path under 104 chars + hash := sha256.Sum256([]byte(t.Name())) + hashStr := hex.EncodeToString(hash[:])[:8] // Use first 8 chars of hash + dir, err := os.MkdirTemp("/tmp", fmt.Sprintf("c-%s-", hashStr)) + require.NoError(t, err, "create temp dir for unix socket test") + t.Cleanup(func() { + err := os.RemoveAll(dir) + assert.NoError(t, err, "remove temp dir", dir) + }) + return dir + } + return t.TempDir() +} + +// newSocketClient creates a DRPC client connected to the Unix socket at the given path. +func newSocketClient(t *testing.T, socketPath string) proto.DRPCAgentSocketClient { + t.Helper() + + conn, err := net.Dial("unix", socketPath) + require.NoError(t, err) + + config := yamux.DefaultConfig() + config.Logger = nil + session, err := yamux.Client(conn, config) + require.NoError(t, err) + + client := proto.NewDRPCAgentSocketClient(drpcsdk.MultiplexedConn(session)) + + t.Cleanup(func() { + _ = session.Close() + _ = conn.Close() + }) + return client +} + +func TestDRPCAgentSocketService(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.Skip("agentsocket is not supported on Windows") + } + + t.Run("Ping", func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(tempDirUnixSocket(t), "test.sock") + + server, err := agentsocket.NewServer( + socketPath, + slog.Make().Leveled(slog.LevelDebug), + ) + require.NoError(t, err) + defer server.Close() + + client := newSocketClient(t, socketPath) + + _, err = client.Ping(context.Background(), &proto.PingRequest{}) + require.NoError(t, err) + }) + + t.Run("SyncStart", func(t *testing.T) { + t.Parallel() + + t.Run("NewUnit", func(t *testing.T) { + t.Parallel() + socketPath := filepath.Join(tempDirUnixSocket(t), "test.sock") + + server, err := agentsocket.NewServer( + socketPath, + slog.Make().Leveled(slog.LevelDebug), + ) + require.NoError(t, err) + defer server.Close() + + client := newSocketClient(t, socketPath) + + _, err = client.SyncStart(context.Background(), &proto.SyncStartRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + + status, err := client.SyncStatus(context.Background(), &proto.SyncStatusRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + require.Equal(t, "started", status.Status) + }) + + t.Run("UnitAlreadyStarted", func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(tempDirUnixSocket(t), "test.sock") + + server, err := agentsocket.NewServer( + socketPath, + slog.Make().Leveled(slog.LevelDebug), + ) + require.NoError(t, err) + defer server.Close() + + client := newSocketClient(t, socketPath) + + // First Start + _, err = client.SyncStart(context.Background(), &proto.SyncStartRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + status, err := client.SyncStatus(context.Background(), &proto.SyncStatusRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + require.Equal(t, "started", status.Status) + + // Second Start + _, err = client.SyncStart(context.Background(), &proto.SyncStartRequest{ + Unit: "test-unit", + }) + require.ErrorContains(t, err, unit.ErrSameStatusAlreadySet.Error()) + + status, err = client.SyncStatus(context.Background(), &proto.SyncStatusRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + require.Equal(t, "started", status.Status) + }) + + t.Run("UnitAlreadyCompleted", func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(tempDirUnixSocket(t), "test.sock") + + server, err := agentsocket.NewServer( + socketPath, + slog.Make().Leveled(slog.LevelDebug), + ) + require.NoError(t, err) + defer server.Close() + + client := newSocketClient(t, socketPath) + + // First start + _, err = client.SyncStart(context.Background(), &proto.SyncStartRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + + status, err := client.SyncStatus(context.Background(), &proto.SyncStatusRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + require.Equal(t, "started", status.Status) + + // Complete the unit + _, err = client.SyncComplete(context.Background(), &proto.SyncCompleteRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + + status, err = client.SyncStatus(context.Background(), &proto.SyncStatusRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + require.Equal(t, "completed", status.Status) + + // Second start + _, err = client.SyncStart(context.Background(), &proto.SyncStartRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + + status, err = client.SyncStatus(context.Background(), &proto.SyncStatusRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + require.Equal(t, "started", status.Status) + }) + + t.Run("UnitNotReady", func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(tempDirUnixSocket(t), "test.sock") + + server, err := agentsocket.NewServer( + socketPath, + slog.Make().Leveled(slog.LevelDebug), + ) + require.NoError(t, err) + defer server.Close() + + client := newSocketClient(t, socketPath) + + _, err = client.SyncWant(context.Background(), &proto.SyncWantRequest{ + Unit: "test-unit", + DependsOn: "dependency-unit", + }) + require.NoError(t, err) + + _, err = client.SyncStart(context.Background(), &proto.SyncStartRequest{ + Unit: "test-unit", + }) + require.ErrorContains(t, err, "unit not ready") + + status, err := client.SyncStatus(context.Background(), &proto.SyncStatusRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + require.Equal(t, string(unit.StatusPending), status.Status) + require.False(t, status.IsReady) + }) + }) + + t.Run("SyncWant", func(t *testing.T) { + t.Parallel() + + t.Run("NewUnits", func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(tempDirUnixSocket(t), "test.sock") + + server, err := agentsocket.NewServer( + socketPath, + slog.Make().Leveled(slog.LevelDebug), + ) + require.NoError(t, err) + defer server.Close() + + client := newSocketClient(t, socketPath) + + // If dependency units are not registered, they are registered automatically + _, err = client.SyncWant(context.Background(), &proto.SyncWantRequest{ + Unit: "test-unit", + DependsOn: "dependency-unit", + }) + require.NoError(t, err) + + status, err := client.SyncStatus(context.Background(), &proto.SyncStatusRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + require.Len(t, status.Dependencies, 1) + require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn) + require.Equal(t, "completed", status.Dependencies[0].RequiredStatus) + }) + + t.Run("DependencyAlreadyRegistered", func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(tempDirUnixSocket(t), "test.sock") + + server, err := agentsocket.NewServer( + socketPath, + slog.Make().Leveled(slog.LevelDebug), + ) + require.NoError(t, err) + defer server.Close() + + client := newSocketClient(t, socketPath) + + // Start the dependency unit + _, err = client.SyncStart(context.Background(), &proto.SyncStartRequest{ + Unit: "dependency-unit", + }) + require.NoError(t, err) + + status, err := client.SyncStatus(context.Background(), &proto.SyncStatusRequest{ + Unit: "dependency-unit", + }) + require.NoError(t, err) + require.Equal(t, "started", status.Status) + + // Add the dependency after the dependency unit has already started + _, err = client.SyncWant(context.Background(), &proto.SyncWantRequest{ + Unit: "test-unit", + DependsOn: "dependency-unit", + }) + + // Dependencies can be added even if the dependency unit has already started + require.NoError(t, err) + + // The dependency is now reflected in the test unit's status + status, err = client.SyncStatus(context.Background(), &proto.SyncStatusRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn) + require.Equal(t, "completed", status.Dependencies[0].RequiredStatus) + }) + + t.Run("DependencyAddedAfterDependentStarted", func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(tempDirUnixSocket(t), "test.sock") + + server, err := agentsocket.NewServer( + socketPath, + slog.Make().Leveled(slog.LevelDebug), + ) + require.NoError(t, err) + defer server.Close() + + client := newSocketClient(t, socketPath) + + // Start the dependent unit + _, err = client.SyncStart(context.Background(), &proto.SyncStartRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + + status, err := client.SyncStatus(context.Background(), &proto.SyncStatusRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + require.Equal(t, "started", status.Status) + + // Add the dependency after the dependency unit has already started + _, err = client.SyncWant(context.Background(), &proto.SyncWantRequest{ + Unit: "test-unit", + DependsOn: "dependency-unit", + }) + + // Dependencies can be added even if the dependent unit has already started. + // The dependency applies the next time a unit is started. The current status is not updated. + // This is to allow flexible dependency management. It does mean that users of this API should + // take care to add dependencies before they start their dependent units. + require.NoError(t, err) + + // The dependency is now reflected in the test unit's status + status, err = client.SyncStatus(context.Background(), &proto.SyncStatusRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn) + require.Equal(t, "completed", status.Dependencies[0].RequiredStatus) + }) + }) + + t.Run("SyncReady", func(t *testing.T) { + t.Parallel() + + t.Run("UnregisteredUnit", func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(tempDirUnixSocket(t), "test.sock") + + server, err := agentsocket.NewServer( + socketPath, + slog.Make().Leveled(slog.LevelDebug), + ) + require.NoError(t, err) + defer server.Close() + + client := newSocketClient(t, socketPath) + + response, err := client.SyncReady(context.Background(), &proto.SyncReadyRequest{ + Unit: "unregistered-unit", + }) + require.NoError(t, err) + require.False(t, response.Ready) + }) + + t.Run("UnitNotReady", func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(tempDirUnixSocket(t), "test.sock") + + server, err := agentsocket.NewServer( + socketPath, + slog.Make().Leveled(slog.LevelDebug), + ) + require.NoError(t, err) + defer server.Close() + + client := newSocketClient(t, socketPath) + + // Register a unit with an unsatisfied dependency + _, err = client.SyncWant(context.Background(), &proto.SyncWantRequest{ + Unit: "test-unit", + DependsOn: "dependency-unit", + }) + require.NoError(t, err) + + // Check readiness - should be false because dependency is not satisfied + response, err := client.SyncReady(context.Background(), &proto.SyncReadyRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + require.False(t, response.Ready) + }) + + t.Run("UnitReady", func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(tempDirUnixSocket(t), "test.sock") + + server, err := agentsocket.NewServer( + socketPath, + slog.Make().Leveled(slog.LevelDebug), + ) + require.NoError(t, err) + defer server.Close() + + client := newSocketClient(t, socketPath) + + // Register a unit with no dependencies - should be ready immediately + _, err = client.SyncStart(context.Background(), &proto.SyncStartRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + + // Check readiness - should be true + _, err = client.SyncReady(context.Background(), &proto.SyncReadyRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + + // Also test a unit with satisfied dependencies + _, err = client.SyncWant(context.Background(), &proto.SyncWantRequest{ + Unit: "dependent-unit", + DependsOn: "test-unit", + }) + require.NoError(t, err) + + // Complete the dependency + _, err = client.SyncComplete(context.Background(), &proto.SyncCompleteRequest{ + Unit: "test-unit", + }) + require.NoError(t, err) + + // Now dependent-unit should be ready + _, err = client.SyncReady(context.Background(), &proto.SyncReadyRequest{ + Unit: "dependent-unit", + }) + require.NoError(t, err) + }) + }) +} diff --git a/agent/agentsocket/socket_unix.go b/agent/agentsocket/socket_unix.go new file mode 100644 index 0000000000..0fa062656a --- /dev/null +++ b/agent/agentsocket/socket_unix.go @@ -0,0 +1,83 @@ +//go:build !windows + +package agentsocket + +import ( + "crypto/rand" + "encoding/hex" + "net" + "os" + "path/filepath" + "time" + + "golang.org/x/xerrors" +) + +// createSocket creates a Unix domain socket listener +func createSocket(path string) (net.Listener, error) { + if !isSocketAvailable(path) { + return nil, xerrors.Errorf("socket path %s is not available", path) + } + + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return nil, xerrors.Errorf("remove existing socket: %w", err) + } + + // Create parent directory if it doesn't exist + parentDir := filepath.Dir(path) + if err := os.MkdirAll(parentDir, 0o700); err != nil { + return nil, xerrors.Errorf("create socket directory: %w", err) + } + + listener, err := net.Listen("unix", path) + if err != nil { + return nil, xerrors.Errorf("listen on unix socket: %w", err) + } + + if err := os.Chmod(path, 0o600); err != nil { + _ = listener.Close() + return nil, xerrors.Errorf("set socket permissions: %w", err) + } + return listener, nil +} + +// getDefaultSocketPath returns the default socket path for Unix-like systems +func getDefaultSocketPath() (string, error) { + randomBytes := make([]byte, 4) + if _, err := rand.Read(randomBytes); err != nil { + return "", xerrors.Errorf("generate random socket name: %w", err) + } + randomSuffix := hex.EncodeToString(randomBytes) + + // Try XDG_RUNTIME_DIR first + if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" { + return filepath.Join(runtimeDir, "coder-agent-"+randomSuffix+".sock"), nil + } + + return filepath.Join("/tmp", "coder-agent-"+randomSuffix+".sock"), nil +} + +// CleanupSocket removes the socket file +func cleanupSocket(path string) error { + return os.Remove(path) +} + +// isSocketAvailable checks if a socket path is available for use +func isSocketAvailable(path string) bool { + // Check if file exists + if _, err := os.Stat(path); os.IsNotExist(err) { + return true + } + + // Try to connect to see if it's actually listening + dialer := net.Dialer{Timeout: 10 * time.Second} + conn, err := dialer.Dial("unix", path) + if err != nil { + // If we can't connect, the socket is not in use + // Socket is available for use + return true + } + _ = conn.Close() + // Socket is in use + return false +} diff --git a/agent/agentsocket/socket_windows.go b/agent/agentsocket/socket_windows.go new file mode 100644 index 0000000000..69785de43e --- /dev/null +++ b/agent/agentsocket/socket_windows.go @@ -0,0 +1,27 @@ +//go:build windows + +package agentsocket + +import ( + "net" + + "golang.org/x/xerrors" +) + +// createSocket returns an error indicating that agentsocket is not supported on Windows. +// This feature is unix-only in its current experimental state. +func createSocket(_ string) (net.Listener, error) { + return nil, xerrors.New("agentsocket is not supported on Windows") +} + +// getDefaultSocketPath returns an error indicating that agentsocket is not supported on Windows. +// This feature is unix-only in its current experimental state. +func getDefaultSocketPath() (string, error) { + return "", xerrors.New("agentsocket is not supported on Windows") +} + +// cleanupSocket is a no-op on Windows since agentsocket is not supported. +func cleanupSocket(_ string) error { + // No-op since agentsocket is not supported on Windows + return nil +} diff --git a/agent/unit/manager.go b/agent/unit/manager.go index 56ec0213c1..14727d43f9 100644 --- a/agent/unit/manager.go +++ b/agent/unit/manager.go @@ -10,6 +10,7 @@ import ( ) var ( + ErrUnitIDRequired = xerrors.New("unit name is required") ErrUnitNotFound = xerrors.New("unit not found") ErrUnitAlreadyRegistered = xerrors.New("unit already registered") ErrCannotUpdateOtherUnit = xerrors.New("cannot update other unit's status") @@ -91,6 +92,10 @@ func (m *Manager) Register(id ID) error { m.mu.Lock() defer m.mu.Unlock() + if id == "" { + return xerrors.Errorf("registering unit %q: %w", id, ErrUnitIDRequired) + } + if m.registered(id) { return xerrors.Errorf("registering unit %q: %w", id, ErrUnitAlreadyRegistered) } @@ -112,22 +117,30 @@ func (m *Manager) registered(id ID) bool { // Unit fetches a unit from the manager. If the unit does not exist, // it returns the Unit zero-value as a placeholder unit, because // units may depend on other units that have not yet been created. -func (m *Manager) Unit(id ID) Unit { +func (m *Manager) Unit(id ID) (Unit, error) { + if id == "" { + return Unit{}, xerrors.Errorf("unit ID cannot be empty: %w", ErrUnitIDRequired) + } + m.mu.RLock() defer m.mu.RUnlock() - return m.units[id] + return m.units[id], nil } -func (m *Manager) IsReady(id ID) bool { +func (m *Manager) IsReady(id ID) (bool, error) { + if id == "" { + return false, xerrors.Errorf("unit ID cannot be empty: %w", ErrUnitIDRequired) + } + m.mu.RLock() defer m.mu.RUnlock() if !m.registered(id) { - return false + return false, nil } - return m.units[id].ready + return m.units[id].ready, nil } // AddDependency adds a dependency relationship between units. @@ -136,8 +149,13 @@ func (m *Manager) AddDependency(unit ID, dependsOn ID, requiredStatus Status) er m.mu.Lock() defer m.mu.Unlock() - if !m.registered(unit) { - return xerrors.Errorf("checking registration for unit %q: %w", unit, ErrUnitNotFound) + switch { + case unit == "": + return xerrors.Errorf("dependent name cannot be empty: %w", ErrUnitIDRequired) + case dependsOn == "": + return xerrors.Errorf("dependency name cannot be empty: %w", ErrUnitIDRequired) + case !m.registered(unit): + return xerrors.Errorf("dependent unit %q must be registered first: %w", unit, ErrUnitNotFound) } // Add the dependency edge to the graph @@ -158,8 +176,11 @@ func (m *Manager) UpdateStatus(unit ID, newStatus Status) error { m.mu.Lock() defer m.mu.Unlock() - if !m.registered(unit) { - return xerrors.Errorf("checking registration for unit %q: %w", unit, ErrUnitNotFound) + switch { + case unit == "": + return xerrors.Errorf("updating status for unit %q: %w", unit, ErrUnitIDRequired) + case !m.registered(unit): + return xerrors.Errorf("unit %q must be registered first: %w", unit, ErrUnitNotFound) } u := m.units[unit] @@ -212,6 +233,10 @@ func (m *Manager) GetAllDependencies(unit ID) ([]Dependency, error) { m.mu.RLock() defer m.mu.RUnlock() + if unit == "" { + return nil, xerrors.Errorf("unit ID cannot be empty: %w", ErrUnitIDRequired) + } + if !m.registered(unit) { return nil, xerrors.Errorf("checking registration for unit %q: %w", unit, ErrUnitNotFound) } @@ -225,7 +250,7 @@ func (m *Manager) GetAllDependencies(unit ID) ([]Dependency, error) { requiredStatus := dependency.Edge allDependencies = append(allDependencies, Dependency{ Unit: unit, - DependsOn: dependsOnUnit.id, + DependsOn: dependency.To, RequiredStatus: requiredStatus, CurrentStatus: dependsOnUnit.status, IsSatisfied: dependsOnUnit.status == requiredStatus, diff --git a/agent/unit/manager_test.go b/agent/unit/manager_test.go index d85b1752a1..0f1eab93ab 100644 --- a/agent/unit/manager_test.go +++ b/agent/unit/manager_test.go @@ -16,6 +16,37 @@ const ( unitD unit.ID = "serviceD" ) +func TestManager_UnitValidation(t *testing.T) { + t.Parallel() + + t.Run("Empty Unit Name", func(t *testing.T) { + t.Parallel() + + manager := unit.NewManager() + + err := manager.Register("") + require.ErrorIs(t, err, unit.ErrUnitIDRequired) + err = manager.AddDependency("", unitA, unit.StatusStarted) + require.ErrorIs(t, err, unit.ErrUnitIDRequired) + err = manager.AddDependency(unitA, "", unit.StatusStarted) + require.ErrorIs(t, err, unit.ErrUnitIDRequired) + dependencies, err := manager.GetAllDependencies("") + require.ErrorIs(t, err, unit.ErrUnitIDRequired) + require.Len(t, dependencies, 0) + unmetDependencies, err := manager.GetUnmetDependencies("") + require.ErrorIs(t, err, unit.ErrUnitIDRequired) + require.Len(t, unmetDependencies, 0) + err = manager.UpdateStatus("", unit.StatusStarted) + require.ErrorIs(t, err, unit.ErrUnitIDRequired) + isReady, err := manager.IsReady("") + require.ErrorIs(t, err, unit.ErrUnitIDRequired) + require.False(t, isReady) + u, err := manager.Unit("") + require.ErrorIs(t, err, unit.ErrUnitIDRequired) + assert.Equal(t, unit.Unit{}, u) + }) +} + func TestManager_Register(t *testing.T) { t.Parallel() @@ -29,10 +60,13 @@ func TestManager_Register(t *testing.T) { require.NoError(t, err) // Then: the unit should be ready (no dependencies) - u := manager.Unit(unitA) + u, err := manager.Unit(unitA) + require.NoError(t, err) assert.Equal(t, unitA, u.ID()) assert.Equal(t, unit.StatusPending, u.Status()) - assert.True(t, manager.IsReady(unitA)) + isReady, err := manager.IsReady(unitA) + require.NoError(t, err) + assert.True(t, isReady) }) t.Run("RegisterDuplicateUnit", func(t *testing.T) { @@ -56,9 +90,12 @@ func TestManager_Register(t *testing.T) { require.ErrorIs(t, err, unit.ErrUnitAlreadyRegistered) // Then: the unit status should not be overwritten - u := manager.Unit(unitA) + u, err := manager.Unit(unitA) + require.NoError(t, err) assert.Equal(t, unit.StatusStarted, u.Status()) - assert.True(t, manager.IsReady(unitA)) + isReady, err := manager.IsReady(unitA) + require.NoError(t, err) + assert.True(t, isReady) }) t.Run("RegisterMultipleUnits", func(t *testing.T) { @@ -75,9 +112,12 @@ func TestManager_Register(t *testing.T) { // Then: all units should be ready initially for _, unitID := range unitIDs { - u := manager.Unit(unitID) + u, err := manager.Unit(unitID) + require.NoError(t, err) assert.Equal(t, unit.StatusPending, u.Status()) - assert.True(t, manager.IsReady(unitID)) + isReady, err := manager.IsReady(unitID) + require.NoError(t, err) + assert.True(t, isReady) } }) } @@ -101,28 +141,38 @@ func TestManager_AddDependency(t *testing.T) { require.NoError(t, err) // Then: Unit A should not be ready (depends on B) - u := manager.Unit(unitA) + u, err := manager.Unit(unitA) + require.NoError(t, err) assert.Equal(t, unit.StatusPending, u.Status()) - assert.False(t, manager.IsReady(unitA)) + isReady, err := manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) // Then: Unit B should still be ready (no dependencies) - u = manager.Unit(unitB) + u, err = manager.Unit(unitB) + require.NoError(t, err) assert.Equal(t, unit.StatusPending, u.Status()) - assert.True(t, manager.IsReady(unitB)) + isReady, err = manager.IsReady(unitB) + require.NoError(t, err) + assert.True(t, isReady) // When: Unit B is started err = manager.UpdateStatus(unitB, unit.StatusStarted) require.NoError(t, err) // Then: Unit A should be ready, because its dependency is now in the desired state. - assert.True(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.True(t, isReady) // When: Unit B is stopped err = manager.UpdateStatus(unitB, unit.StatusPending) require.NoError(t, err) // Then: Unit A should no longer be ready, because its dependency is not in the desired state. - assert.False(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) }) t.Run("AddDependencyByAnUnregisteredDependentUnit", func(t *testing.T) { @@ -156,11 +206,22 @@ func TestManager_AddDependency(t *testing.T) { err = manager.AddDependency(unitA, unitB, unit.StatusStarted) require.NoError(t, err) - u := manager.Unit(unitB) + // Then: The dependency should be visible in Unit A's status + dependencies, err := manager.GetAllDependencies(unitA) + require.NoError(t, err) + require.Len(t, dependencies, 1) + assert.Equal(t, unitB, dependencies[0].DependsOn) + assert.Equal(t, unit.StatusStarted, dependencies[0].RequiredStatus) + assert.False(t, dependencies[0].IsSatisfied) + + u, err := manager.Unit(unitB) + require.NoError(t, err) assert.Equal(t, unit.StatusNotRegistered, u.Status()) // Then: Unit A should not be ready, because it depends on Unit B - assert.False(t, manager.IsReady(unitA)) + isReady, err := manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) // When: Unit B is registered err = manager.Register(unitB) @@ -168,14 +229,18 @@ func TestManager_AddDependency(t *testing.T) { // Then: Unit A should still not be ready. // Unit B is not registered, but it has not been started as required by the dependency. - assert.False(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) // When: Unit B is started err = manager.UpdateStatus(unitB, unit.StatusStarted) require.NoError(t, err) // Then: Unit A should be ready, because its dependency is now in the desired state. - assert.True(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.True(t, isReady) }) t.Run("AddDependencyCreatesACyclicDependency", func(t *testing.T) { @@ -208,6 +273,32 @@ func TestManager_AddDependency(t *testing.T) { err = manager.AddDependency(unitD, unitA, unit.StatusStarted) require.ErrorIs(t, err, unit.ErrCycleDetected) }) + + t.Run("UpdatingADependency", func(t *testing.T) { + t.Parallel() + + manager := unit.NewManager() + + // Given units A and B are registered + err := manager.Register(unitA) + require.NoError(t, err) + err = manager.Register(unitB) + require.NoError(t, err) + + // Given Unit A depends on Unit B being unit.StatusStarted + err = manager.AddDependency(unitA, unitB, unit.StatusStarted) + require.NoError(t, err) + + // When: The dependency is updated to unit.StatusComplete + err = manager.AddDependency(unitA, unitB, unit.StatusComplete) + require.NoError(t, err) + + // Then: Unit A should only have one dependency, and it should be unit.StatusComplete + dependencies, err := manager.GetAllDependencies(unitA) + require.NoError(t, err) + require.Len(t, dependencies, 1) + assert.Equal(t, unit.StatusComplete, dependencies[0].RequiredStatus) + }) } func TestManager_UpdateStatus(t *testing.T) { @@ -229,18 +320,24 @@ func TestManager_UpdateStatus(t *testing.T) { require.NoError(t, err) // Then: Unit A should not be ready (depends on B) - u := manager.Unit(unitA) + u, err := manager.Unit(unitA) + require.NoError(t, err) assert.Equal(t, unit.StatusPending, u.Status()) - assert.False(t, manager.IsReady(unitA)) + isReady, err := manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) // When: Unit B is started err = manager.UpdateStatus(unitB, unit.StatusStarted) require.NoError(t, err) // Then: Unit A should be ready, because its dependency is now in the desired state. - u = manager.Unit(unitA) + u, err = manager.Unit(unitA) + require.NoError(t, err) assert.Equal(t, unit.StatusPending, u.Status()) - assert.True(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.True(t, isReady) }) t.Run("UpdateStatusWithUnregisteredUnit", func(t *testing.T) { @@ -276,43 +373,64 @@ func TestManager_UpdateStatus(t *testing.T) { require.NoError(t, err) // Then: only Unit C should be ready (no dependencies) - u := manager.Unit(unitC) + u, err := manager.Unit(unitC) + require.NoError(t, err) assert.Equal(t, unit.StatusPending, u.Status()) - assert.True(t, manager.IsReady(unitC)) + isReady, err := manager.IsReady(unitC) + require.NoError(t, err) + assert.True(t, isReady) - u = manager.Unit(unitB) + u, err = manager.Unit(unitB) + require.NoError(t, err) assert.Equal(t, unit.StatusPending, u.Status()) - assert.False(t, manager.IsReady(unitB)) + isReady, err = manager.IsReady(unitB) + require.NoError(t, err) + assert.False(t, isReady) - u = manager.Unit(unitA) + u, err = manager.Unit(unitA) + require.NoError(t, err) assert.Equal(t, unit.StatusPending, u.Status()) - assert.False(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) // When: Unit C is completed err = manager.UpdateStatus(unitC, unit.StatusComplete) require.NoError(t, err) // Then: Unit B should be ready, because its dependency is now in the desired state. - u = manager.Unit(unitB) + u, err = manager.Unit(unitB) + require.NoError(t, err) assert.Equal(t, unit.StatusPending, u.Status()) - assert.True(t, manager.IsReady(unitB)) + isReady, err = manager.IsReady(unitB) + require.NoError(t, err) + assert.True(t, isReady) - u = manager.Unit(unitA) + u, err = manager.Unit(unitA) + require.NoError(t, err) assert.Equal(t, unit.StatusPending, u.Status()) - assert.False(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) - u = manager.Unit(unitB) + u, err = manager.Unit(unitB) + require.NoError(t, err) assert.Equal(t, unit.StatusPending, u.Status()) - assert.True(t, manager.IsReady(unitB)) + isReady, err = manager.IsReady(unitB) + require.NoError(t, err) + assert.True(t, isReady) // When: Unit B is started err = manager.UpdateStatus(unitB, unit.StatusStarted) require.NoError(t, err) // Then: Unit A should be ready, because its dependency is now in the desired state. - u = manager.Unit(unitA) + u, err = manager.Unit(unitA) + require.NoError(t, err) assert.Equal(t, unit.StatusPending, u.Status()) - assert.True(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.True(t, isReady) }) } @@ -419,19 +537,25 @@ func TestManager_MultipleDependencies(t *testing.T) { require.NoError(t, err) // A should not be ready (depends on both B and C) - assert.False(t, manager.IsReady(unitA)) + isReady, err := manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) // Update B to unit.StatusStarted - A should still not be ready (needs C too) err = manager.UpdateStatus(unitB, unit.StatusStarted) require.NoError(t, err) - assert.False(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) // Update C to "started" - A should now be ready err = manager.UpdateStatus(unitC, unit.StatusStarted) require.NoError(t, err) - assert.True(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.True(t, isReady) }) t.Run("ComplexDependencyChain", func(t *testing.T) { @@ -460,30 +584,48 @@ func TestManager_MultipleDependencies(t *testing.T) { require.NoError(t, err) // Initially only D is ready - assert.True(t, manager.IsReady(unitD)) - assert.False(t, manager.IsReady(unitB)) - assert.False(t, manager.IsReady(unitC)) - assert.False(t, manager.IsReady(unitA)) + isReady, err := manager.IsReady(unitD) + require.NoError(t, err) + assert.True(t, isReady) + isReady, err = manager.IsReady(unitB) + require.NoError(t, err) + assert.False(t, isReady) + isReady, err = manager.IsReady(unitC) + require.NoError(t, err) + assert.False(t, isReady) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) // Update D to "completed" - B and C should become ready err = manager.UpdateStatus(unitD, unit.StatusComplete) require.NoError(t, err) - assert.True(t, manager.IsReady(unitB)) - assert.True(t, manager.IsReady(unitC)) - assert.False(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitB) + require.NoError(t, err) + assert.True(t, isReady) + isReady, err = manager.IsReady(unitC) + require.NoError(t, err) + assert.True(t, isReady) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) // Update B to unit.StatusStarted - A should still not be ready (needs C) err = manager.UpdateStatus(unitB, unit.StatusStarted) require.NoError(t, err) - assert.False(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) // Update C to "started" - A should now be ready err = manager.UpdateStatus(unitC, unit.StatusStarted) require.NoError(t, err) - assert.True(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.True(t, isReady) }) t.Run("DifferentStatusTypes", func(t *testing.T) { @@ -512,14 +654,18 @@ func TestManager_MultipleDependencies(t *testing.T) { // Then: Unit A should not be ready, because only one of its dependencies is in the desired state. // It still requires Unit C to be completed. - assert.False(t, manager.IsReady(unitA)) + isReady, err := manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) // When: Unit C is completed err = manager.UpdateStatus(unitC, unit.StatusComplete) require.NoError(t, err) // Then: Unit A should be ready, because both of its dependencies are in the desired state. - assert.True(t, manager.IsReady(unitA)) + isReady, err = manager.IsReady(unitA) + require.NoError(t, err) + assert.True(t, isReady) }) } @@ -532,10 +678,13 @@ func TestManager_IsReady(t *testing.T) { manager := unit.NewManager() // Given: a unit is not registered - u := manager.Unit(unitA) + u, err := manager.Unit(unitA) + require.NoError(t, err) assert.Equal(t, unit.StatusNotRegistered, u.Status()) // Then: the unit is not ready - assert.False(t, manager.IsReady(unitA)) + isReady, err := manager.IsReady(unitA) + require.NoError(t, err) + assert.False(t, isReady) }) }