mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
6ed12ade54
There are sporadic flakes in some tests on Windows related to the use of `ptytest`. Since it's not clear why they fail, this commit adds some more logging and timeouts to the cleanup methods.
304 lines
5.7 KiB
Go
304 lines
5.7 KiB
Go
package ptytest
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"os/exec"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
"unicode/utf8"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/pty"
|
|
"github.com/coder/coder/testutil"
|
|
)
|
|
|
|
func New(t *testing.T, opts ...pty.Option) *PTY {
|
|
t.Helper()
|
|
|
|
ptty, err := pty.New(opts...)
|
|
require.NoError(t, err)
|
|
|
|
return create(t, ptty, "cmd")
|
|
}
|
|
|
|
func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTY, pty.Process) {
|
|
t.Helper()
|
|
|
|
ptty, ps, err := pty.Start(cmd, opts...)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
_ = ps.Kill()
|
|
_ = ps.Wait()
|
|
})
|
|
return create(t, ptty, cmd.Args[0]), ps
|
|
}
|
|
|
|
func create(t *testing.T, ptty pty.PTY, name string) *PTY {
|
|
// Use pipe for logging.
|
|
logDone := make(chan struct{})
|
|
logr, logw := io.Pipe()
|
|
t.Cleanup(func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
|
|
_ = logw.Close()
|
|
_ = logr.Close()
|
|
select {
|
|
case <-ctx.Done():
|
|
fatalf(t, name, "cleanup", "log pipe did not close in time")
|
|
case <-logDone: // Guard against logging after test.
|
|
}
|
|
})
|
|
|
|
// Write to log and output buffer.
|
|
copyDone := make(chan struct{})
|
|
out := newStdbuf()
|
|
w := io.MultiWriter(logw, out)
|
|
go func() {
|
|
defer close(copyDone)
|
|
_, err := io.Copy(w, ptty.Output())
|
|
_ = out.closeErr(err)
|
|
}()
|
|
t.Cleanup(func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
|
|
_ = out.Close()
|
|
_ = ptty.Close()
|
|
select {
|
|
case <-ctx.Done():
|
|
fatalf(t, name, "cleanup", "copy did not close in time")
|
|
case <-copyDone:
|
|
}
|
|
})
|
|
|
|
tpty := &PTY{
|
|
t: t,
|
|
PTY: ptty,
|
|
out: out,
|
|
name: name,
|
|
|
|
runeReader: bufio.NewReaderSize(out, utf8.UTFMax),
|
|
}
|
|
|
|
go func() {
|
|
defer close(logDone)
|
|
s := bufio.NewScanner(logr)
|
|
for s.Scan() {
|
|
// Quote output to avoid terminal escape codes, e.g. bell.
|
|
tpty.logf("stdout: %q", s.Text())
|
|
}
|
|
}()
|
|
|
|
return tpty
|
|
}
|
|
|
|
type PTY struct {
|
|
pty.PTY
|
|
t *testing.T
|
|
out *stdbuf
|
|
name string
|
|
|
|
runeReader *bufio.Reader
|
|
}
|
|
|
|
func (p *PTY) ExpectMatch(str string) string {
|
|
p.t.Helper()
|
|
|
|
timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
|
defer cancel()
|
|
|
|
var buffer bytes.Buffer
|
|
match := make(chan error, 1)
|
|
go func() {
|
|
defer close(match)
|
|
match <- func() error {
|
|
for {
|
|
r, _, err := p.runeReader.ReadRune()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = buffer.WriteRune(r)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if strings.Contains(buffer.String(), str) {
|
|
return nil
|
|
}
|
|
}
|
|
}()
|
|
}()
|
|
|
|
select {
|
|
case err := <-match:
|
|
if err != nil {
|
|
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
|
|
return ""
|
|
}
|
|
p.logf("matched %q = %q", str, buffer.String())
|
|
return buffer.String()
|
|
case <-timeout.Done():
|
|
// Ensure gorouine is cleaned up before test exit.
|
|
_ = p.out.closeErr(p.Close())
|
|
<-match
|
|
|
|
p.fatalf("match exceeded deadline", "wanted %q; got %q", str, buffer.String())
|
|
return ""
|
|
}
|
|
}
|
|
|
|
func (p *PTY) Write(r rune) {
|
|
p.t.Helper()
|
|
|
|
p.logf("stdin: %q", r)
|
|
_, err := p.Input().Write([]byte{byte(r)})
|
|
require.NoError(p.t, err, "write failed")
|
|
}
|
|
|
|
func (p *PTY) WriteLine(str string) {
|
|
p.t.Helper()
|
|
|
|
newline := []byte{'\r'}
|
|
if runtime.GOOS == "windows" {
|
|
newline = append(newline, '\n')
|
|
}
|
|
p.logf("stdin: %q", str+string(newline))
|
|
_, err := p.Input().Write(append([]byte(str), newline...))
|
|
require.NoError(p.t, err, "write line failed")
|
|
}
|
|
|
|
func (p *PTY) logf(format string, args ...interface{}) {
|
|
p.t.Helper()
|
|
logf(p.t, p.name, format, args...)
|
|
}
|
|
|
|
func (p *PTY) fatalf(reason string, format string, args ...interface{}) {
|
|
p.t.Helper()
|
|
fatalf(p.t, p.name, reason, format, args...)
|
|
}
|
|
|
|
// logf writes "<timestamp>: <name>: <message>" to the test log, where the
// message is format expanded with args.
func logf(t *testing.T, name, format string, args ...interface{}) {
	t.Helper()

	// Match regular logger timestamp format, we seem to be logging in
	// UTC in other places as well, so match here.
	const layout = "2006-01-02 15:04:05.000"
	stamp := time.Now().UTC().Format(layout)
	message := fmt.Sprintf(format, args...)

	t.Logf("%s: %s: %s", stamp, name, message)
}
|
|
|
|
func fatalf(t *testing.T, name, reason, format string, args ...interface{}) {
|
|
t.Helper()
|
|
|
|
// Ensure the message is part of the normal log stream before
|
|
// failing the test.
|
|
logf(t, name, "%s: %s", reason, fmt.Sprintf(format, args...))
|
|
|
|
require.FailNowf(t, reason, format, args...)
|
|
}
|
|
|
|
// stdbuf is like a buffered stdout, it buffers writes until read.
type stdbuf struct {
	// r serves the bytes most recently taken out of b; it is nil when
	// there is nothing pending. It is used by the read path without
	// holding mu.
	r io.Reader

	mu sync.Mutex // Protects following.
	// b holds written bytes that have not been handed to a reader yet.
	b []byte
	// more is signaled when b receives data and closed when the buffer
	// is closed.
	more chan struct{}
	// err is non-nil once the buffer has been closed.
	err error
}
|
|
|
|
func newStdbuf() *stdbuf {
|
|
return &stdbuf{more: make(chan struct{}, 1)}
|
|
}
|
|
|
|
func (b *stdbuf) Read(p []byte) (int, error) {
|
|
if b.r == nil {
|
|
return b.readOrWaitForMore(p)
|
|
}
|
|
|
|
n, err := b.r.Read(p)
|
|
if xerrors.Is(err, io.EOF) {
|
|
b.r = nil
|
|
err = nil
|
|
if n == 0 {
|
|
return b.readOrWaitForMore(p)
|
|
}
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
func (b *stdbuf) readOrWaitForMore(p []byte) (int, error) {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
|
|
// Deplete channel so that more check
|
|
// is for future input into buffer.
|
|
select {
|
|
case <-b.more:
|
|
default:
|
|
}
|
|
|
|
if len(b.b) == 0 {
|
|
if b.err != nil {
|
|
return 0, b.err
|
|
}
|
|
|
|
b.mu.Unlock()
|
|
<-b.more
|
|
b.mu.Lock()
|
|
}
|
|
|
|
b.r = bytes.NewReader(b.b)
|
|
b.b = b.b[len(b.b):]
|
|
|
|
return b.r.Read(p)
|
|
}
|
|
|
|
func (b *stdbuf) Write(p []byte) (int, error) {
|
|
if len(p) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
|
|
if b.err != nil {
|
|
return 0, b.err
|
|
}
|
|
|
|
b.b = append(b.b, p...)
|
|
|
|
select {
|
|
case b.more <- struct{}{}:
|
|
default:
|
|
}
|
|
|
|
return len(p), nil
|
|
}
|
|
|
|
func (b *stdbuf) Close() error {
|
|
return b.closeErr(nil)
|
|
}
|
|
|
|
func (b *stdbuf) closeErr(err error) error {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
if b.err != nil {
|
|
return err
|
|
}
|
|
if err == nil {
|
|
b.err = io.EOF
|
|
} else {
|
|
b.err = err
|
|
}
|
|
close(b.more)
|
|
return err
|
|
}
|