mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: Deadlock and race in peer, test improvements (#3086)
* fix: Potential deadlock in peer.Channel dc.OnOpen * fix: Potential send on closed channel * fix: Improve robustness of waitOpened during close * chore: Simplify statements * fix: Improve teardown and timeout of peer tests * fix: Improve robustness of TestConn/Buffering test * Update peer/channel.go Co-authored-by: Steven Masley <Emyrk@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
62e685669f
commit
e33a74975e
+11
-9
@@ -106,12 +106,15 @@ func (c *Channel) init() {
|
||||
// write operations to block once the threshold is set.
|
||||
c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
|
||||
c.dc.OnBufferedAmountLow(func() {
|
||||
// Grab the lock to protect the sendMore channel from being
|
||||
// closed in between the isClosed check and the send.
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
if c.isClosed() {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
case c.sendMore <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
@@ -122,15 +125,16 @@ func (c *Channel) init() {
|
||||
})
|
||||
c.dc.OnOpen(func() {
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
|
||||
c.conn.logger().Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
|
||||
var err error
|
||||
c.rwc, err = c.dc.Detach()
|
||||
if err != nil {
|
||||
c.closeMutex.Unlock()
|
||||
_ = c.closeWithError(xerrors.Errorf("detach: %w", err))
|
||||
return
|
||||
}
|
||||
c.closeMutex.Unlock()
|
||||
|
||||
// pion/webrtc will return an io.ErrShortBuffer when a read
|
||||
// is triggerred with a buffer size less than the chunks written.
|
||||
//
|
||||
@@ -189,9 +193,6 @@ func (c *Channel) init() {
|
||||
//
|
||||
// This will block until the underlying DataChannel has been opened.
|
||||
func (c *Channel) Read(bytes []byte) (int, error) {
|
||||
if c.isClosed() {
|
||||
return 0, c.closeError
|
||||
}
|
||||
err := c.waitOpened()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -228,9 +229,6 @@ func (c *Channel) Write(bytes []byte) (n int, err error) {
|
||||
c.writeMutex.Lock()
|
||||
defer c.writeMutex.Unlock()
|
||||
|
||||
if c.isClosed() {
|
||||
return 0, c.closeWithError(nil)
|
||||
}
|
||||
err = c.waitOpened()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -308,6 +306,10 @@ func (c *Channel) isClosed() bool {
|
||||
func (c *Channel) waitOpened() error {
|
||||
select {
|
||||
case <-c.opened:
|
||||
// Re-check the closed channel to prioritize closure.
|
||||
if c.isClosed() {
|
||||
return c.closeError
|
||||
}
|
||||
return nil
|
||||
case <-c.closed:
|
||||
return c.closeError
|
||||
|
||||
@@ -3,7 +3,6 @@ package peer
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"sync"
|
||||
@@ -256,7 +255,6 @@ func (c *Conn) init() error {
|
||||
c.logger().Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate))
|
||||
select {
|
||||
case <-c.closed:
|
||||
break
|
||||
case c.localCandidateChannel <- iceCandidate.ToJSON():
|
||||
}
|
||||
}()
|
||||
@@ -265,7 +263,6 @@ func (c *Conn) init() error {
|
||||
go func() {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
case c.dcOpenChannel <- dc:
|
||||
}
|
||||
}()
|
||||
@@ -435,9 +432,6 @@ func (c *Conn) pingEchoChannel() (*Channel, error) {
|
||||
data := make([]byte, pingDataLength)
|
||||
bytesRead, err := c.pingEchoChan.Read(data)
|
||||
if err != nil {
|
||||
if c.isClosed() {
|
||||
return
|
||||
}
|
||||
_ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
+76
-27
@@ -91,6 +91,8 @@ func TestConn(t *testing.T) {
|
||||
// Create a channel that closes on disconnect.
|
||||
channel, err := server.CreateChannel(context.Background(), "wow", nil)
|
||||
assert.NoError(t, err)
|
||||
defer channel.Close()
|
||||
|
||||
err = wan.Stop()
|
||||
require.NoError(t, err)
|
||||
// Once the connection is marked as disconnected, this
|
||||
@@ -107,10 +109,13 @@ func TestConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, _ := createPair(t)
|
||||
exchange(t, client, server)
|
||||
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
|
||||
require.NoError(t, err)
|
||||
defer cch.Close()
|
||||
|
||||
sch, err := server.Accept(context.Background())
|
||||
sch, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sch.Close()
|
||||
|
||||
@@ -123,9 +128,12 @@ func TestConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, wan := createPair(t)
|
||||
exchange(t, client, server)
|
||||
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
|
||||
require.NoError(t, err)
|
||||
sch, err := server.Accept(context.Background())
|
||||
defer cch.Close()
|
||||
sch, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sch.Close()
|
||||
|
||||
@@ -140,26 +148,44 @@ func TestConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, _ := createPair(t)
|
||||
exchange(t, client, server)
|
||||
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
|
||||
require.NoError(t, err)
|
||||
sch, err := server.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer sch.Close()
|
||||
defer cch.Close()
|
||||
|
||||
readErr := make(chan error, 1)
|
||||
go func() {
|
||||
bytes := make([]byte, 4096)
|
||||
for i := 0; i < 1024; i++ {
|
||||
_, err := cch.Write(bytes)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
_ = cch.Close()
|
||||
}()
|
||||
bytes := make([]byte, 4096)
|
||||
for {
|
||||
_, err = sch.Read(bytes)
|
||||
sch, err := server.Accept(ctx)
|
||||
if err != nil {
|
||||
require.ErrorIs(t, err, peer.ErrClosed)
|
||||
break
|
||||
readErr <- err
|
||||
_ = cch.Close()
|
||||
return
|
||||
}
|
||||
defer sch.Close()
|
||||
|
||||
bytes := make([]byte, 4096)
|
||||
for {
|
||||
_, err = sch.Read(bytes)
|
||||
if err != nil {
|
||||
readErr <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
bytes := make([]byte, 4096)
|
||||
for i := 0; i < 1024; i++ {
|
||||
_, err = cch.Write(bytes)
|
||||
require.NoError(t, err, "write i=%d", i)
|
||||
}
|
||||
_ = cch.Close()
|
||||
|
||||
select {
|
||||
case err = <-readErr:
|
||||
require.ErrorIs(t, err, peer.ErrClosed, "read error")
|
||||
case <-ctx.Done():
|
||||
require.Fail(t, "timeout waiting for read error")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -170,13 +196,29 @@ func TestConn(t *testing.T) {
|
||||
srv, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
go func() {
|
||||
sch, err := server.Accept(context.Background())
|
||||
assert.NoError(t, err)
|
||||
sch, err := server.Accept(ctx)
|
||||
if err != nil {
|
||||
assert.NoError(t, err)
|
||||
return
|
||||
}
|
||||
defer sch.Close()
|
||||
|
||||
nc2 := sch.NetConn()
|
||||
defer nc2.Close()
|
||||
|
||||
nc1, err := net.Dial("tcp", srv.Addr().String())
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
assert.NoError(t, err)
|
||||
return
|
||||
}
|
||||
defer nc1.Close()
|
||||
|
||||
go func() {
|
||||
defer nc1.Close()
|
||||
defer nc2.Close()
|
||||
_, _ = io.Copy(nc1, nc2)
|
||||
}()
|
||||
_, _ = io.Copy(nc2, nc1)
|
||||
@@ -204,7 +246,7 @@ func TestConn(t *testing.T) {
|
||||
c := http.Client{
|
||||
Transport: defaultTransport,
|
||||
}
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", "http://localhost/", nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost/", nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := c.Do(req)
|
||||
require.NoError(t, err)
|
||||
@@ -272,14 +314,21 @@ func TestConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, _ := createPair(t)
|
||||
exchange(t, client, server)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
go func() {
|
||||
channel, err := client.CreateChannel(context.Background(), "test", nil)
|
||||
assert.NoError(t, err)
|
||||
channel, err := client.CreateChannel(ctx, "test", nil)
|
||||
if err != nil {
|
||||
assert.NoError(t, err)
|
||||
return
|
||||
}
|
||||
defer channel.Close()
|
||||
_, err = channel.Write([]byte{1, 2})
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
channel, err := server.Accept(context.Background())
|
||||
channel, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
defer channel.Close()
|
||||
data := make([]byte, 1)
|
||||
_, err = channel.Read(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
Reference in New Issue
Block a user