mirror of
https://github.com/coder/coder.git
synced 2026-06-04 05:28:20 +00:00
49b34a716a
Upgrades to slog v3, which includes a small but backward-incompatible API change to the acceptable call arguments when logging. This change allows us to verify via compile-time type checking that arguments are correct and won't cause a panic, as was possible in slog v1, which this replaces (v2 was tagged but never used in coder/coder). It also updates dependencies that use slog and have themselves been updated. I've left the `aibridge` dependency as a commit SHA, under the assumption that the team there (cc @pawbana @dannykopping) will tag and update the dependency soon, on their own schedule. For the other dependencies, I pushed new tags.
122 lines
3.3 KiB
Go
122 lines
3.3 KiB
Go
package agentssh
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/gliderlabs/ssh"
|
|
"github.com/google/uuid"
|
|
"go.uber.org/atomic"
|
|
gossh "golang.org/x/crypto/ssh"
|
|
|
|
"cdr.dev/slog/v3"
|
|
)
|
|
|
|
// localForwardChannelData is copied from the ssh package.
//
// It holds the extra data sent with a local port-forward channel open request
// and is decoded from newChannel.ExtraData() via gossh.Unmarshal. Unmarshal
// reads the SSH wire format field by field in declaration order, so the order
// and types of these fields must not change.
type localForwardChannelData struct {
	// DestAddr/DestPort: the forward's destination. DestPort is the port whose
	// listening process gets inspected for JetBrains software.
	DestAddr string
	DestPort uint32

	// OriginAddr/OriginPort: the originator of the forwarded connection as
	// reported in the request. OriginAddr is what gets passed to
	// reportConnection.
	OriginAddr string
	OriginPort uint32
}
|
|
|
|
// JetbrainsChannelWatcher is used to track JetBrains port forwarded (Gateway)
// channels. NewJetbrainsChannelWatcher only creates one when the forwarded
// port belongs to a JetBrains process; any other port forward gets the
// original gossh.NewChannel back, untouched.
//
// The wrapped gossh.NewChannel is embedded, so every NewChannel method except
// Accept is served directly by the original channel.
type JetbrainsChannelWatcher struct {
	gossh.NewChannel
	// jetbrainsCounter is shared across watchers: incremented when a channel
	// is accepted, decremented when that accepted channel is closed.
	jetbrainsCounter *atomic.Int64
	// logger is already annotated with the forward's destination_port.
	logger slog.Logger
	// originAddr is the originator address from the forward request; it is
	// passed to reportConnection on Accept.
	originAddr string
	// reportConnection is called once per Accept; the callback it returns is
	// invoked when the session fails to open or the channel is closed.
	reportConnection reportConnectionFunc
}
|
|
|
|
func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, reportConnection reportConnectionFunc, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel {
|
|
d := localForwardChannelData{}
|
|
if err := gossh.Unmarshal(newChannel.ExtraData(), &d); err != nil {
|
|
// If the data fails to unmarshal, do nothing.
|
|
logger.Warn(ctx, "failed to unmarshal port forward data", slog.Error(err))
|
|
return newChannel
|
|
}
|
|
|
|
// If we do get a port, we should be able to get the matching PID and from
|
|
// there look up the invocation.
|
|
cmdline, err := getListeningPortProcessCmdline(d.DestPort)
|
|
if err != nil {
|
|
logger.Warn(ctx, "failed to inspect port",
|
|
slog.F("destination_port", d.DestPort),
|
|
slog.Error(err))
|
|
return newChannel
|
|
}
|
|
|
|
// If this is not JetBrains, then we do not need to do anything special. We
|
|
// attempt to match on something that appears unique to JetBrains software.
|
|
if !isJetbrainsProcess(cmdline) {
|
|
return newChannel
|
|
}
|
|
|
|
logger.Debug(ctx, "discovered forwarded JetBrains process",
|
|
slog.F("destination_port", d.DestPort))
|
|
|
|
return &JetbrainsChannelWatcher{
|
|
NewChannel: newChannel,
|
|
jetbrainsCounter: counter,
|
|
logger: logger.With(slog.F("destination_port", d.DestPort)),
|
|
originAddr: d.OriginAddr,
|
|
reportConnection: reportConnection,
|
|
}
|
|
}
|
|
|
|
func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request, error) {
|
|
disconnected := w.reportConnection(uuid.New(), MagicSessionTypeJetBrains, w.originAddr)
|
|
|
|
c, r, err := w.NewChannel.Accept()
|
|
if err != nil {
|
|
disconnected(1, err.Error())
|
|
return c, r, err
|
|
}
|
|
w.jetbrainsCounter.Add(1)
|
|
// nolint: gocritic // JetBrains is a proper noun and should be capitalized
|
|
w.logger.Debug(context.Background(), "JetBrains watcher accepted channel")
|
|
|
|
return &ChannelOnClose{
|
|
Channel: c,
|
|
done: func() {
|
|
w.jetbrainsCounter.Add(-1)
|
|
disconnected(0, "")
|
|
// nolint: gocritic // JetBrains is a proper noun and should be capitalized
|
|
w.logger.Debug(context.Background(), "JetBrains watcher channel closed")
|
|
},
|
|
}, r, err
|
|
}
|
|
|
|
// ChannelOnClose wraps a gossh.Channel and runs a callback the first time
// Close is called, before closing the wrapped channel.
type ChannelOnClose struct {
	gossh.Channel
	// once ensures done runs only a single time, because Close can be
	// called multiple times.
	once sync.Once
	// done is the callback run on the first Close; for JetBrains channels it
	// decrements the session counter and reports the disconnect.
	done func()
}
|
|
|
|
func (c *ChannelOnClose) Close() error {
|
|
c.once.Do(c.done)
|
|
return c.Channel.Close()
|
|
}
|
|
|
|
func isJetbrainsProcess(cmdline string) bool {
|
|
opts := []string{
|
|
MagicProcessCmdlineJetBrains,
|
|
MagicProcessCmdlineToolbox,
|
|
MagicProcessCmdlineGateway,
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
if strings.Contains(strings.ToLower(cmdline), strings.ToLower(opt)) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|