fix: lock log sink against concurrent write and close (#10668)

fixes #10663
This commit is contained in:
Spike Curtis
2023-11-14 16:38:34 +04:00
committed by GitHub
parent 530be2f96a
commit dc4b1ef406
4 changed files with 101 additions and 5 deletions
+38
View File
@@ -0,0 +1,38 @@
package cliutil
import (
"io"
"sync"
)
type discardAfterClose struct {
sync.Mutex
wc io.WriteCloser
closed bool
}
// DiscardAfterClose is an io.WriteCloser that discards writes after it is closed without errors.
// It is useful as a target for a slog.Sink such that an underlying WriteCloser, like a file, can
// be cleaned up without race conditions from still-active loggers.
func DiscardAfterClose(wc io.WriteCloser) io.WriteCloser {
return &discardAfterClose{wc: wc}
}
func (d *discardAfterClose) Write(p []byte) (n int, err error) {
d.Lock()
defer d.Unlock()
if d.closed {
return len(p), nil
}
return d.wc.Write(p)
}
func (d *discardAfterClose) Close() error {
d.Lock()
defer d.Unlock()
if d.closed {
return nil
}
d.closed = true
return d.wc.Close()
}
+54
View File
@@ -0,0 +1,54 @@
package cliutil_test
import (
"errors"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/cliutil"
)
func TestDiscardAfterClose(t *testing.T) {
t.Parallel()
exErr := errors.New("test")
fwc := &fakeWriteCloser{err: exErr}
uut := cliutil.DiscardAfterClose(fwc)
n, err := uut.Write([]byte("one"))
require.Equal(t, 3, n)
require.NoError(t, err)
n, err = uut.Write([]byte("two"))
require.Equal(t, 3, n)
require.NoError(t, err)
err = uut.Close()
require.Equal(t, exErr, err)
n, err = uut.Write([]byte("three"))
require.Equal(t, 5, n)
require.NoError(t, err)
require.Len(t, fwc.writes, 2)
require.EqualValues(t, "one", fwc.writes[0])
require.EqualValues(t, "two", fwc.writes[1])
}
type fakeWriteCloser struct {
writes [][]byte
closed bool
err error
}
func (f *fakeWriteCloser) Write(p []byte) (n int, err error) {
q := make([]byte, len(p))
copy(q, p)
f.writes = append(f.writes, q)
return len(p), nil
}
func (f *fakeWriteCloser) Close() error {
f.closed = true
return f.err
}
+4 -2
View File
@@ -28,6 +28,7 @@ import (
"github.com/coder/coder/v2/cli/clibase" "github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/cli/cliutil"
"github.com/coder/coder/v2/coderd/autobuild/notify" "github.com/coder/coder/v2/coderd/autobuild/notify"
"github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
@@ -114,12 +115,13 @@ func (r *RootCmd) ssh() *clibase.Cmd {
if err != nil { if err != nil {
return xerrors.Errorf("error opening %s for logging: %w", logDirPath, err) return xerrors.Errorf("error opening %s for logging: %w", logDirPath, err)
} }
dc := cliutil.DiscardAfterClose(logFile)
go func() { go func() {
wg.Wait() wg.Wait()
_ = logFile.Close() _ = dc.Close()
}() }()
logger = slog.Make(sloghuman.Sink(logFile)) logger = logger.AppendSinks(sloghuman.Sink(dc))
if r.verbose { if r.verbose {
logger = logger.Leveled(slog.LevelDebug) logger = logger.Leveled(slog.LevelDebug)
} }
+5 -3
View File
@@ -21,6 +21,7 @@ import (
"cdr.dev/slog/sloggers/sloghuman" "cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/v2/cli/clibase" "github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/cli/cliutil"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
) )
@@ -137,15 +138,16 @@ func (r *RootCmd) vscodeSSH() *clibase.Cmd {
// command via the ProxyCommand SSH option. // command via the ProxyCommand SSH option.
pid := os.Getppid() pid := os.Getppid()
var logger slog.Logger logger := slog.Make()
if logDir != "" { if logDir != "" {
logFilePath := filepath.Join(logDir, fmt.Sprintf("%d.log", pid)) logFilePath := filepath.Join(logDir, fmt.Sprintf("%d.log", pid))
logFile, err := fs.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY, 0o600) logFile, err := fs.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY, 0o600)
if err != nil { if err != nil {
return xerrors.Errorf("open log file %q: %w", logFilePath, err) return xerrors.Errorf("open log file %q: %w", logFilePath, err)
} }
defer logFile.Close() dc := cliutil.DiscardAfterClose(logFile)
logger = slog.Make(sloghuman.Sink(logFile)).Leveled(slog.LevelDebug) defer dc.Close()
logger = logger.AppendSinks(sloghuman.Sink(dc)).Leveled(slog.LevelDebug)
} }
if r.disableDirect { if r.disableDirect {
logger.Info(ctx, "direct connections disabled") logger.Info(ctx, "direct connections disabled")