Files
coder/tailnet/service.go
T
Spike Curtis d6154c4310 chore: remove tailnet v1 API support (#14641)
Drops support for v1 of the tailnet API, which was the original coordination protocol where we only sent node updates, never marked them lost or disconnected.

v2 of the tailnet API went GA for CLI clients in Coder 2.8.0, so clients older than that would stop working.
2024-09-12 07:56:31 +04:00

364 lines
9.7 KiB
Go

package tailnet
import (
"context"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/hashicorp/yamux"
"golang.org/x/xerrors"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
)
// ErrUnsupportedVersion is returned by ClientService.ServeClient when the
// requested tailnet API version parses but its major version is not one this
// service handles (only major version 2 is served).
var ErrUnsupportedVersion = xerrors.New("unsupported version")

// streamIDContextKey is the unexported context key under which WithStreamID
// stores a StreamID; using a private type prevents collisions with other
// packages' context values.
type streamIDContextKey struct{}
// StreamID identifies the caller of the tailnet Coordinate RPC. We store this
// on the context, since the information is extracted at the HTTP layer for
// remote clients of the API, or set outside tailnet for local clients (e.g.
// Coderd's single_tailnet).
type StreamID struct {
	// Name is a label for the peer, passed to Coordinator.Coordinate and logged
	// by DRPCService.Coordinate. ServeClient sets it to "client".
	Name string
	// ID is the peer's UUID. It is passed to Coordinator.Coordinate and used as
	// the subject when generating resume tokens.
	ID uuid.UUID
	// Auth is handed to Coordinator.Coordinate alongside the peer. ServeClient
	// fills it with a ClientCoordinateeAuth for a single agent; presumably it
	// restricts which peers this stream may coordinate with — confirm in the
	// CoordinateeAuth implementations.
	Auth CoordinateeAuth
}
// WithStreamID returns a derived context carrying streamID. DRPCService's
// Coordinate and RefreshResumeToken handlers read it back from the stream or
// request context to learn which peer is calling.
func WithStreamID(ctx context.Context, streamID StreamID) context.Context {
	key := streamIDContextKey{}
	derived := context.WithValue(ctx, key, streamID)
	return derived
}
// ClientServiceOptions configures NewClientService. Every field is copied into
// the DRPCService that implements the tailnet RPCs; Logger and CoordPtr are
// additionally kept on the returned ClientService.
type ClientServiceOptions struct {
	// Logger is used by the service and for non-routine dRPC server errors.
	Logger slog.Logger
	// CoordPtr holds the Coordinator. It is dereferenced on each Coordinate
	// call, so a Coordinator stored later is picked up by new streams.
	CoordPtr *atomic.Pointer[Coordinator]
	// DERPMapUpdateFrequency is how often StreamDERPMaps re-polls DERPMapFn.
	DERPMapUpdateFrequency time.Duration
	// DERPMapFn returns the current DERP map. Returning nil ends
	// StreamDERPMaps with io.EOF (used in tests).
	DERPMapFn func() *tailcfg.DERPMap
	// NetworkTelemetryHandler receives the events of each PostTelemetry call.
	// It may be nil, in which case telemetry is discarded.
	NetworkTelemetryHandler func(batch []*proto.TelemetryEvent)
	// ResumeTokenProvider generates tokens for RefreshResumeToken.
	ResumeTokenProvider ResumeTokenProvider
}
// ClientService is a tailnet coordination service that accepts a connection and version from a
// tailnet client, and supports version 2.x of the Tailnet API protocol.
type ClientService struct {
	Logger slog.Logger
	// CoordPtr is the same Coordinator pointer given to the DRPCService that
	// NewClientService builds.
	CoordPtr *atomic.Pointer[Coordinator]
	// drpc serves the registered tailnet dRPC service; ServeConnV2 runs it over
	// a yamux session.
	drpc *drpcserver.Server
}
// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
// loaded on each processed connection.
//
// It registers a DRPCService built from options on a new dRPC mux and wraps
// the mux in a dRPC server. The server's error log drops io.EOF,
// context.Canceled and context.DeadlineExceeded and records anything else at
// debug level. An error is returned only if registering the service fails.
func NewClientService(options ClientServiceOptions) (
	*ClientService, error,
) {
	impl := &DRPCService{
		CoordPtr:                options.CoordPtr,
		Logger:                  options.Logger,
		DerpMapUpdateFrequency:  options.DERPMapUpdateFrequency,
		DerpMapFn:               options.DERPMapFn,
		NetworkTelemetryHandler: options.NetworkTelemetryHandler,
		ResumeTokenProvider:     options.ResumeTokenProvider,
	}

	mux := drpcmux.New()
	if err := proto.DRPCRegisterTailnet(mux, impl); err != nil {
		return nil, xerrors.Errorf("register DRPC service: %w", err)
	}

	logServerErr := func(err error) {
		switch {
		case xerrors.Is(err, io.EOF),
			xerrors.Is(err, context.Canceled),
			xerrors.Is(err, context.DeadlineExceeded):
			// Treated as routine shutdown and not logged.
			return
		}
		options.Logger.Debug(context.Background(), "drpc server error", slog.Error(err))
	}

	return &ClientService{
		Logger:   options.Logger,
		CoordPtr: options.CoordPtr,
		drpc:     drpcserver.NewWithOptions(mux, drpcserver.Options{Log: logServerErr}),
	}, nil
}
// ServeClient serves the tailnet API to a single client over conn.
//
// version is the API version the client asked for. It must parse with
// apiversion.Parse (the parse error is returned as-is otherwise), and its major
// version must be 2 (ErrUnsupportedVersion otherwise). id is the client's peer
// ID, and agent is recorded in the client's ClientCoordinateeAuth (presumably
// limiting it to coordinating with that agent — confirm in that type).
func (s *ClientService) ServeClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
	major, _, err := apiversion.Parse(version)
	if err != nil {
		s.Logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err))
		return err
	}
	if major != 2 {
		s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
		return ErrUnsupportedVersion
	}
	return s.ServeConnV2(ctx, conn, StreamID{
		Name: "client",
		ID:   id,
		Auth: ClientCoordinateeAuth{AgentID: agent},
	})
}
// ServeConnV2 runs a yamux server session on conn and serves the v2 dRPC
// tailnet API over it. streamID is attached to the context (see WithStreamID)
// so the RPC handlers know which peer is calling. It returns a wrapped error if
// the yamux session cannot be created, and otherwise whatever the dRPC
// server's Serve returns.
//
// NOTE(review): this uses a value receiver, unlike the other ClientService
// methods; it is left as-is because switching to a pointer receiver would
// remove it from the value type's method set.
func (s ClientService) ServeConnV2(ctx context.Context, conn net.Conn, streamID StreamID) error {
	cfg := yamux.DefaultConfig()
	// Silence yamux's own logging.
	cfg.LogOutput = io.Discard

	session, err := yamux.Server(conn, cfg)
	if err != nil {
		return xerrors.Errorf("yamux init failed: %w", err)
	}

	streamCtx := WithStreamID(ctx, streamID)
	s.Logger.Debug(streamCtx, "serving dRPC tailnet v2 API session",
		slog.F("peer_id", streamID.ID.String()))
	return s.drpc.Serve(streamCtx, session)
}
// DRPCService is the dRPC-based, version 2.x of the tailnet API. It is
// registered with proto.DRPCRegisterTailnet, so it presumably implements
// proto.DRPCTailnetServer (the previous comment named proto.DRPCClientServer —
// NOTE(review): confirm the generated interface name).
type DRPCService struct {
	// CoordPtr holds the Coordinator; it is loaded afresh on every Coordinate call.
	CoordPtr *atomic.Pointer[Coordinator]
	Logger   slog.Logger
	// DerpMapUpdateFrequency is the interval at which StreamDERPMaps re-polls DerpMapFn.
	DerpMapUpdateFrequency time.Duration
	// DerpMapFn returns the current DERP map; a nil result ends StreamDERPMaps with io.EOF.
	DerpMapFn func() *tailcfg.DERPMap
	// NetworkTelemetryHandler receives events from PostTelemetry; it may be nil.
	NetworkTelemetryHandler func(batch []*proto.TelemetryEvent)
	// ResumeTokenProvider issues tokens for RefreshResumeToken.
	ResumeTokenProvider ResumeTokenProvider
}
// PostTelemetry passes req.Events to NetworkTelemetryHandler when one is
// configured and always acknowledges with an empty response and no error.
func (s *DRPCService) PostTelemetry(_ context.Context, req *proto.TelemetryRequest) (*proto.TelemetryResponse, error) {
	if handler := s.NetworkTelemetryHandler; handler != nil {
		handler(req.Events)
	}
	return &proto.TelemetryResponse{}, nil
}
// StreamDERPMaps sends the current DERP map right away and then re-polls
// DerpMapFn every DerpMapUpdateFrequency. A map is sent only when it differs
// from the last one sent (per CompareDERPMaps).
//
// It returns nil once the stream's context is done, io.EOF when DerpMapFn
// returns nil, and a wrapped error if a send fails. The stream is closed on
// every return path.
func (s *DRPCService) StreamDERPMaps(_ *proto.StreamDERPMapsRequest, stream proto.DRPCTailnet_StreamDERPMapsStream) error {
	defer stream.Close()

	ticker := time.NewTicker(s.DerpMapUpdateFrequency)
	defer ticker.Stop()

	var sent *tailcfg.DERPMap
	for {
		current := s.DerpMapFn()
		if current == nil {
			// in testing, we send nil to close the stream.
			return io.EOF
		}

		changed := sent == nil || !CompareDERPMaps(sent, current)
		if changed {
			if err := stream.Send(DERPMapToProto(current)); err != nil {
				return xerrors.Errorf("send derp map: %w", err)
			}
			sent = current
		}

		// Restart the interval so the next poll is a full period after this one.
		ticker.Reset(s.DerpMapUpdateFrequency)
		select {
		case <-ticker.C:
		case <-stream.Context().Done():
			return nil
		}
	}
}
// RefreshResumeToken asks ResumeTokenProvider for a fresh resume token for the
// peer whose StreamID is stored on ctx. It fails if ctx carries no StreamID or
// if token generation fails.
func (s *DRPCService) RefreshResumeToken(ctx context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) {
	sid, found := ctx.Value(streamIDContextKey{}).(StreamID)
	if !found {
		return nil, xerrors.New("no Stream ID")
	}

	token, err := s.ResumeTokenProvider.GenerateResumeToken(sid.ID)
	if err != nil {
		return nil, xerrors.Errorf("generate resume token: %w", err)
	}
	return token, nil
}
// Coordinate implements the bidirectional Coordinate RPC.
//
// The caller's StreamID must already be on the stream context (see
// WithStreamID). The currently stored Coordinator is asked for a
// request/response channel pair for this peer. A communicator then relays
// stream requests into the coordinator and coordinator responses back onto
// the stream until its response loop finishes; that loop closes the stream.
//
// It returns an error, after closing the stream, when the context carries no
// StreamID or when no Coordinator has been stored; otherwise it returns nil.
func (s *DRPCService) Coordinate(stream proto.DRPCTailnet_CoordinateStream) error {
	ctx := stream.Context()
	streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID)
	if !ok {
		_ = stream.Close()
		return xerrors.New("no Stream ID")
	}
	// Log the peer's UUID (as ServeConnV2 does) rather than the whole StreamID
	// struct, which would also dump the Auth value.
	logger := s.Logger.With(slog.F("peer_id", streamID.ID.String()), slog.F("name", streamID.Name))
	logger.Debug(ctx, "starting tailnet Coordinate")

	// The pointer is loaded per stream so a newly stored Coordinator is used by
	// subsequent streams. Guard against nothing having been stored yet instead
	// of panicking on a nil dereference.
	var coord Coordinator
	if s.CoordPtr != nil {
		if loaded := s.CoordPtr.Load(); loaded != nil {
			coord = *loaded
		}
	}
	if coord == nil {
		_ = stream.Close()
		return xerrors.New("no coordinator")
	}

	reqs, resps := coord.Coordinate(ctx, streamID.ID, streamID.Name, streamID.Auth)
	c := communicator{
		logger: logger,
		stream: stream,
		reqs:   reqs,
		resps:  resps,
	}
	// Blocks until the response loop ends; errors there are logged, not returned.
	c.communicate()
	return nil
}
// communicator relays Coordinate messages between one dRPC stream and the
// request/response channels returned by Coordinator.Coordinate.
type communicator struct {
	logger slog.Logger
	stream proto.DRPCTailnet_CoordinateStream
	// reqs receives requests read from the stream; loopReq closes it on exit.
	reqs chan<- *proto.CoordinateRequest
	// resps supplies responses that loopResp writes to the stream.
	resps <-chan *proto.CoordinateResponse
}
// communicate runs both relay directions: stream->coordinator requests on a
// new goroutine (loopReq) and coordinator->stream responses on the calling
// goroutine (loopResp). It returns when loopResp returns; loopReq is not
// waited for here.
func (c communicator) communicate() {
	go c.loopReq()
	c.loopResp()
}
// loopReq forwards each request received from the dRPC stream into c.reqs.
// It returns when stream.Recv fails or when SendCtx fails (logged as the
// context being done), and closes c.reqs on the way out — presumably telling
// the coordinator this peer sends no more requests; confirm in the
// Coordinator implementation.
func (c communicator) loopReq() {
	ctx := c.stream.Context()
	defer close(c.reqs)
	for {
		req, recvErr := c.stream.Recv()
		if recvErr != nil {
			c.logger.Debug(ctx, "error receiving requests from DRPC stream", slog.Error(recvErr))
			return
		}
		if sendErr := SendCtx(ctx, c.reqs, req); sendErr != nil {
			c.logger.Debug(ctx, "context done while sending coordinate request", slog.Error(ctx.Err()))
			return
		}
	}
}
// loopResp writes each response taken from c.resps (via RecvCtx) onto the
// dRPC stream. It returns when RecvCtx fails or a Send fails, and always
// closes the stream on exit, logging any close error at debug level.
func (c communicator) loopResp() {
	ctx := c.stream.Context()
	defer func() {
		if closeErr := c.stream.Close(); closeErr != nil {
			c.logger.Debug(ctx, "loopResp hit error closing stream", slog.Error(closeErr))
		}
	}()
	for {
		resp, recvErr := RecvCtx(ctx, c.resps)
		if recvErr != nil {
			c.logger.Debug(ctx, "loopResp failed to get response", slog.Error(recvErr))
			return
		}
		if sendErr := c.stream.Send(resp); sendErr != nil {
			c.logger.Debug(ctx, "loopResp failed to send response to DRPC stream", slog.Error(sendErr))
			return
		}
	}
}
// NetworkTelemetryBatcher buffers network telemetry events passed to Handler
// and delivers them to batchFn in batches: on every tick of a periodic ticker,
// whenever maxSize events have accumulated, and once more on Close. Create it
// with NewNetworkTelemetryBatcher.
type NetworkTelemetryBatcher struct {
	clock quartz.Clock
	// frequency is the interval between periodic flushes.
	frequency time.Duration
	// maxSize is the buffered-event count at which Handler flushes immediately.
	maxSize int
	// batchFn receives each flushed batch; replaced by a no-op if nil at construction.
	batchFn func(batch []*proto.TelemetryEvent)

	// mu guards pending and is held by Handler and by the background
	// goroutine's exit path.
	mu sync.Mutex
	// closed is closed by Close to ask the background goroutine to flush and exit.
	closed chan struct{}
	// done is closed by the background goroutine when it exits.
	done chan struct{}
	// ticker drives periodic flushes; created in start.
	ticker *quartz.Ticker
	// pending holds events not yet handed to batchFn.
	pending []*proto.TelemetryEvent
}
// NewNetworkTelemetryBatcher creates a batcher that flushes buffered events to
// batchFn every frequency (as measured by clk) and whenever maxSize events are
// buffered. A nil batchFn is replaced with a no-op. The background flushing
// goroutine is already running when this returns; call Close to stop it.
func NewNetworkTelemetryBatcher(clk quartz.Clock, frequency time.Duration, maxSize int, batchFn func(batch []*proto.TelemetryEvent)) *NetworkTelemetryBatcher {
	if batchFn == nil {
		batchFn = func([]*proto.TelemetryEvent) {}
	}
	b := &NetworkTelemetryBatcher{
		clock:     clk,
		frequency: frequency,
		maxSize:   maxSize,
		batchFn:   batchFn,
		closed:    make(chan struct{}),
		done:      make(chan struct{}),
	}
	b.start()
	return b
}
// Close stops the batcher. It signals the background goroutine to flush any
// pending events through batchFn and exit, then waits up to 5 seconds (wall
// clock) for that to finish, returning an error on timeout. Events passed to
// Handler after Close are dropped.
//
// Calling Close again after it has returned no longer panics; later calls
// simply wait for the same shutdown. Concurrent Close calls are still not
// supported.
func (b *NetworkTelemetryBatcher) Close() error {
	// Previously a second Close panicked by closing an already-closed channel.
	// The lock is deliberately not taken here: the background goroutine may hold
	// b.mu while flushing, and blocking on it would bypass the timeout below.
	select {
	case <-b.closed:
	default:
		close(b.closed)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	select {
	case <-ctx.Done():
		return xerrors.New("timed out waiting for batcher to close")
	case <-b.done:
	}
	return nil
}
// sendTelemetryBatch hands all pending events to batchFn, doing nothing when
// the buffer is empty.
//
// The buffer is swapped out while holding b.mu, but batchFn is called after
// the lock is released. Previously batchFn ran under the lock, so a slow
// callback blocked Handler, which is on the DRPC call path that Handler itself
// avoids blocking. batchFn can already run concurrently via Handler's
// asynchronous flushes, so calling it unlocked adds no new requirement.
func (b *NetworkTelemetryBatcher) sendTelemetryBatch() {
	b.mu.Lock()
	events := b.pending
	b.pending = nil
	b.mu.Unlock()

	if len(events) == 0 {
		return
	}
	b.batchFn(events)
}
// start creates the flush ticker and launches the background goroutine. The
// goroutine flushes pending events on every tick and once more when b.closed
// is closed, then exits. On exit it closes b.done (which Close waits on) and
// stops the ticker.
func (b *NetworkTelemetryBatcher) start() {
	// b.ticker is assigned synchronously, before the goroutine starts, so it is
	// already set when NewNetworkTelemetryBatcher returns and Handler uses it.
	b.ticker = b.clock.NewTicker(b.frequency)
	go func() {
		defer func() {
			// The lock prevents Handler from racing with Close.
			// Handler also Resets b.ticker while holding b.mu, so stopping the
			// ticker under the lock keeps the two ordered.
			b.mu.Lock()
			defer b.mu.Unlock()
			close(b.done)
			b.ticker.Stop()
		}()
		for {
			select {
			case <-b.ticker.C:
				b.sendTelemetryBatch()
				// Restart the interval once the flush has completed.
				b.ticker.Reset(b.frequency)
			case <-b.closed:
				// Send any remaining telemetry events before exiting.
				b.sendTelemetryBatch()
				return
			}
		}
	}()
}
// Handler buffers events for the next batch. Its signature matches
// ClientServiceOptions.NetworkTelemetryHandler (presumably it is wired up
// there — confirm at the call site). Events received after Close are dropped.
//
// Whenever the buffer reaches maxSize, it is flushed immediately on a new
// goroutine and the periodic ticker is restarted, so the caller is never
// blocked on batchFn.
func (b *NetworkTelemetryBatcher) Handler(events []*proto.TelemetryEvent) {
	b.mu.Lock()
	defer b.mu.Unlock()
	select {
	case <-b.closed:
		return
	default:
	}
	for _, event := range events {
		b.pending = append(b.pending, event)
		if len(b.pending) < b.maxSize {
			continue
		}
		// This can't call sendTelemetryBatch directly because we already
		// hold the lock. Use a distinct name rather than shadowing the
		// `events` parameter that this loop ranges over.
		batch := b.pending
		b.pending = nil
		// Resetting the ticker is best effort. We don't care if the ticker
		// has already fired or has a pending message, because the only risk
		// is that we send two telemetry events in short succession (which
		// is totally fine).
		b.ticker.Reset(b.frequency)
		// Perform the send in a goroutine to avoid blocking the DRPC call.
		go b.batchFn(batch)
	}
}