mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
chore: extract Expecter into its own package (#25806)
Relates to https://github.com/coder/internal/issues/1400 Extracts the code that matches command output from the code that sets up a PTY, so it can be used independently. Subsequent PRs will actually refactor the tests to use this directly over an inmemory pipe.<!-- If you have used AI to produce some or all of this PR, please ensure you have read our [AI Contribution guidelines](https://coder.com/docs/about/contributing/AI_CONTRIBUTING) before submitting. -->
This commit is contained in:
@@ -0,0 +1,346 @@
|
||||
package expecter
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/acarl005/stripansi"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/atomic"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func New(t *testing.T, r io.Reader, name string) *Expecter {
|
||||
// Use pipe for logging.
|
||||
logDone := make(chan struct{})
|
||||
logr, logw := io.Pipe()
|
||||
|
||||
// Write to log and output buffer.
|
||||
copyDone := make(chan struct{})
|
||||
out := newStdbuf()
|
||||
w := io.MultiWriter(logw, out)
|
||||
|
||||
ex := &Expecter{
|
||||
t: t,
|
||||
out: out,
|
||||
name: atomic.NewString(name),
|
||||
|
||||
runeReader: bufio.NewReaderSize(out, utf8.UTFMax),
|
||||
logDone: logDone,
|
||||
copyDone: copyDone,
|
||||
logr: logr,
|
||||
logw: logw,
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(copyDone)
|
||||
_, err := io.Copy(w, r)
|
||||
ex.Logf("copy done: %v", err)
|
||||
ex.Logf("closing out")
|
||||
err = out.closeErr(err)
|
||||
ex.Logf("closed out: %v", err)
|
||||
}()
|
||||
|
||||
// Log all output as part of test for easier debugging on errors.
|
||||
go func() {
|
||||
defer close(logDone)
|
||||
s := bufio.NewScanner(logr)
|
||||
for s.Scan() {
|
||||
ex.Logf("%q", stripansi.Strip(s.Text()))
|
||||
}
|
||||
// Surface non-EOF scanner errors; otherwise they're invisible.
|
||||
if err := s.Err(); err != nil {
|
||||
ex.Logf("log scanner stopped: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ex
|
||||
}
|
||||
|
||||
type Expecter struct {
|
||||
t *testing.T
|
||||
out *stdbuf
|
||||
name *atomic.String
|
||||
|
||||
runeReader *bufio.Reader
|
||||
copyDone, logDone chan struct{}
|
||||
logr, logw io.Closer
|
||||
}
|
||||
|
||||
// Rename the expecter. Make sure you set this before anything starts writing to the
|
||||
// stream, or it may not be named consistently.
|
||||
func (e *Expecter) Rename(name string) {
|
||||
e.name.Store(name)
|
||||
}
|
||||
|
||||
func (e *Expecter) Close(reason string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
e.Logf("closing expecter: %s", reason)
|
||||
|
||||
// Caller needs to have closed the stream so that copying can complete
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
e.fatalf("close", "copy did not close in time")
|
||||
case <-e.copyDone:
|
||||
}
|
||||
|
||||
e.logClose("logw", e.logw)
|
||||
e.logClose("logr", e.logr)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
e.fatalf("close", "log pipe did not close in time")
|
||||
case <-e.logDone:
|
||||
}
|
||||
|
||||
e.Logf("closed expecter")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Expecter) logClose(name string, c io.Closer) {
|
||||
e.Logf("closing %s", name)
|
||||
err := c.Close()
|
||||
e.Logf("closed %s: %v", name, err)
|
||||
}
|
||||
|
||||
// Deprecated: use ExpectMatchContext instead.
|
||||
// This uses a background context, so will not respect the test's context.
|
||||
func (e *Expecter) ExpectMatch(str string) string {
|
||||
return e.expectMatchContextFunc(str, e.ExpectMatchContext)
|
||||
}
|
||||
|
||||
func (e *Expecter) ExpectRegexMatch(str string) string {
|
||||
return e.expectMatchContextFunc(str, e.ExpectRegexMatchContext)
|
||||
}
|
||||
|
||||
func (e *Expecter) expectMatchContextFunc(str string, fn func(ctx context.Context, str string) string) string {
|
||||
e.t.Helper()
|
||||
|
||||
timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
return fn(timeout, str)
|
||||
}
|
||||
|
||||
// TODO(mafredri): Rename this to ExpectMatch when refactoring.
|
||||
func (e *Expecter) ExpectMatchContext(ctx context.Context, str string) string {
|
||||
return e.expectMatcherFunc(ctx, str, strings.Contains)
|
||||
}
|
||||
|
||||
func (e *Expecter) ExpectRegexMatchContext(ctx context.Context, str string) string {
|
||||
return e.expectMatcherFunc(ctx, str, func(src, pattern string) bool {
|
||||
return regexp.MustCompile(pattern).MatchString(src)
|
||||
})
|
||||
}
|
||||
|
||||
func (e *Expecter) expectMatcherFunc(ctx context.Context, str string, fn func(src, pattern string) bool) string {
|
||||
e.t.Helper()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
err := e.doMatchWithDeadline(ctx, "ExpectMatchContext", func(rd *bufio.Reader) error {
|
||||
for {
|
||||
r, _, err := rd.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = buffer.WriteRune(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fn(buffer.String(), str) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
e.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
|
||||
return ""
|
||||
}
|
||||
e.Logf("matched %q = %q", str, buffer.String())
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
// ExpectNoMatchBefore validates that `match` does not occur before `before`.
|
||||
func (e *Expecter) ExpectNoMatchBefore(ctx context.Context, match, before string) string {
|
||||
e.t.Helper()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
err := e.doMatchWithDeadline(ctx, "ExpectNoMatchBefore", func(rd *bufio.Reader) error {
|
||||
for {
|
||||
r, _, err := rd.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = buffer.WriteRune(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if strings.Contains(buffer.String(), match) {
|
||||
return xerrors.Errorf("found %q before %q", match, before)
|
||||
}
|
||||
|
||||
if strings.Contains(buffer.String(), before) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
e.fatalf("read error", "%v (wanted no %q before %q; got %q)", err, match, before, buffer.String())
|
||||
return ""
|
||||
}
|
||||
e.Logf("matched %q = %q", before, stripansi.Strip(buffer.String()))
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
func (e *Expecter) Peek(ctx context.Context, n int) []byte {
|
||||
e.t.Helper()
|
||||
|
||||
var out []byte
|
||||
err := e.doMatchWithDeadline(ctx, "Peek", func(rd *bufio.Reader) error {
|
||||
var err error
|
||||
out, err = rd.Peek(n)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
e.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out)
|
||||
return nil
|
||||
}
|
||||
e.Logf("peeked %d/%d bytes = %q", len(out), n, out)
|
||||
return slices.Clone(out)
|
||||
}
|
||||
|
||||
//nolint:govet // We don't care about conforming to ReadRune() (rune, int, error).
|
||||
func (e *Expecter) ReadRune(ctx context.Context) rune {
|
||||
e.t.Helper()
|
||||
|
||||
var r rune
|
||||
err := e.doMatchWithDeadline(ctx, "ReadRune", func(rd *bufio.Reader) error {
|
||||
var err error
|
||||
r, _, err = rd.ReadRune()
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
e.fatalf("read error", "%v (wanted rune; got %q)", err, r)
|
||||
return 0
|
||||
}
|
||||
e.Logf("matched rune = %q", r)
|
||||
return r
|
||||
}
|
||||
|
||||
func (e *Expecter) ReadLine(ctx context.Context) string {
|
||||
e.t.Helper()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
err := e.doMatchWithDeadline(ctx, "ReadLine", func(rd *bufio.Reader) error {
|
||||
for {
|
||||
r, _, err := rd.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r == '\n' {
|
||||
return nil
|
||||
}
|
||||
if r == '\r' {
|
||||
// Peek the next rune to see if it's an LF and then consume
|
||||
// it.
|
||||
|
||||
// Unicode code points can be up to 4 bytes, but the
|
||||
// ones we're looking for are only 1 byte.
|
||||
b, _ := rd.Peek(1)
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
r, _ = utf8.DecodeRune(b)
|
||||
if r == '\n' {
|
||||
_, _, err = rd.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = buffer.WriteRune(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
e.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
|
||||
return ""
|
||||
}
|
||||
e.Logf("matched newline = %q", buffer.String())
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
func (e *Expecter) ReadAll() []byte {
|
||||
e.t.Helper()
|
||||
return e.out.ReadAll()
|
||||
}
|
||||
|
||||
func (e *Expecter) doMatchWithDeadline(ctx context.Context, name string, fn func(*bufio.Reader) error) error {
|
||||
e.t.Helper()
|
||||
|
||||
// A timeout is mandatory, caller can decide by passing a context
|
||||
// that times out.
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
timeout := testutil.WaitMedium
|
||||
e.Logf("%s ctx has no deadline, using %s", name, timeout)
|
||||
var cancel context.CancelFunc
|
||||
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
match := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(match)
|
||||
match <- fn(e.runeReader)
|
||||
}()
|
||||
select {
|
||||
case err := <-match:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
// Ensure goroutine is cleaned up before test exit, do not call
|
||||
// (*outExpecter).close here to let the caller decide.
|
||||
_ = e.out.Close()
|
||||
<-match
|
||||
|
||||
return xerrors.Errorf("match deadline exceeded: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Expecter) Logf(format string, args ...interface{}) {
|
||||
e.t.Helper()
|
||||
|
||||
// Match regular logger timestamp format, we seem to be logging in
|
||||
// UTC in other places as well, so match here.
|
||||
e.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), e.name.Load(), fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (e *Expecter) fatalf(reason string, format string, args ...interface{}) {
|
||||
e.t.Helper()
|
||||
|
||||
// Ensure the message is part of the normal log stream before
|
||||
// failing the test.
|
||||
e.Logf("%s: %s", reason, fmt.Sprintf(format, args...))
|
||||
|
||||
require.FailNowf(e.t, reason, format, args...)
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package expecter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// stdbuf is like a buffered stdout, it buffers writes until read.
|
||||
type stdbuf struct {
|
||||
r io.Reader
|
||||
|
||||
mu sync.Mutex // Protects following.
|
||||
b []byte
|
||||
more chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func newStdbuf() *stdbuf {
|
||||
return &stdbuf{more: make(chan struct{}, 1)}
|
||||
}
|
||||
|
||||
func (b *stdbuf) ReadAll() []byte {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.err != nil {
|
||||
return nil
|
||||
}
|
||||
p := append([]byte(nil), b.b...)
|
||||
b.b = b.b[len(b.b):]
|
||||
return p
|
||||
}
|
||||
|
||||
func (b *stdbuf) Read(p []byte) (int, error) {
|
||||
if b.r == nil {
|
||||
return b.readOrWaitForMore(p)
|
||||
}
|
||||
|
||||
n, err := b.r.Read(p)
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
b.r = nil
|
||||
err = nil
|
||||
if n == 0 {
|
||||
return b.readOrWaitForMore(p)
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (b *stdbuf) readOrWaitForMore(p []byte) (int, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
// Deplete channel so that more check
|
||||
// is for future input into buffer.
|
||||
select {
|
||||
case <-b.more:
|
||||
default:
|
||||
}
|
||||
|
||||
if len(b.b) == 0 {
|
||||
if b.err != nil {
|
||||
return 0, b.err
|
||||
}
|
||||
|
||||
b.mu.Unlock()
|
||||
<-b.more
|
||||
b.mu.Lock()
|
||||
}
|
||||
|
||||
b.r = bytes.NewReader(b.b)
|
||||
b.b = b.b[len(b.b):]
|
||||
|
||||
return b.r.Read(p)
|
||||
}
|
||||
|
||||
func (b *stdbuf) Write(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.err != nil {
|
||||
return 0, b.err
|
||||
}
|
||||
|
||||
b.b = append(b.b, p...)
|
||||
|
||||
select {
|
||||
case b.more <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (b *stdbuf) Close() error {
|
||||
return b.closeErr(nil)
|
||||
}
|
||||
|
||||
func (b *stdbuf) closeErr(err error) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.err != nil {
|
||||
return err
|
||||
}
|
||||
if err == nil {
|
||||
b.err = io.EOF
|
||||
} else {
|
||||
b.err = err
|
||||
}
|
||||
close(b.more)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package expecter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStdbuf(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var got bytes.Buffer
|
||||
|
||||
b := newStdbuf()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
_, err := io.Copy(&got, b)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
_, err := b.Write([]byte("hello "))
|
||||
require.NoError(t, err)
|
||||
_, err = b.Write([]byte("world\n"))
|
||||
require.NoError(t, err)
|
||||
_, err = b.Write([]byte("bye\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = b.Close()
|
||||
require.NoError(t, err)
|
||||
<-done
|
||||
|
||||
assert.Equal(t, "hello world\nbye\n", got.String())
|
||||
}
|
||||
Reference in New Issue
Block a user