fix: Deadlock and race in peer, test improvements (#3086)

* fix: Potential deadlock in peer.Channel dc.OnOpen

* fix: Potential send on closed channel

* fix: Improve robustness of waitOpened during close

* chore: Simplify statements

* fix: Improve teardown and timeout of peer tests

* fix: Improve robustness of TestConn/Buffering test

* Update peer/channel.go

Co-authored-by: Steven Masley <Emyrk@users.noreply.github.com>
This commit is contained in:
Mathias Fredriksson
2022-07-21 18:47:17 +03:00
committed by GitHub
parent 62e685669f
commit e33a74975e
3 changed files with 87 additions and 42 deletions
+11 -9
View File
@@ -106,12 +106,15 @@ func (c *Channel) init() {
// write operations to block once the threshold is set.
c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
c.dc.OnBufferedAmountLow(func() {
// Grab the lock to protect the sendMore channel from being
// closed in between the isClosed check and the send.
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
if c.isClosed() {
return
}
select {
case <-c.closed:
return
case c.sendMore <- struct{}{}:
default:
}
@@ -122,15 +125,16 @@ func (c *Channel) init() {
})
c.dc.OnOpen(func() {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
c.conn.logger().Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
var err error
c.rwc, err = c.dc.Detach()
if err != nil {
c.closeMutex.Unlock()
_ = c.closeWithError(xerrors.Errorf("detach: %w", err))
return
}
c.closeMutex.Unlock()
// pion/webrtc will return an io.ErrShortBuffer when a read
// is triggerred with a buffer size less than the chunks written.
//
@@ -189,9 +193,6 @@ func (c *Channel) init() {
//
// This will block until the underlying DataChannel has been opened.
func (c *Channel) Read(bytes []byte) (int, error) {
if c.isClosed() {
return 0, c.closeError
}
err := c.waitOpened()
if err != nil {
return 0, err
@@ -228,9 +229,6 @@ func (c *Channel) Write(bytes []byte) (n int, err error) {
c.writeMutex.Lock()
defer c.writeMutex.Unlock()
if c.isClosed() {
return 0, c.closeWithError(nil)
}
err = c.waitOpened()
if err != nil {
return 0, err
@@ -308,6 +306,10 @@ func (c *Channel) isClosed() bool {
func (c *Channel) waitOpened() error {
select {
case <-c.opened:
// Re-check the closed channel to prioritize closure.
if c.isClosed() {
return c.closeError
}
return nil
case <-c.closed:
return c.closeError
-6
View File
@@ -3,7 +3,6 @@ package peer
import (
"bytes"
"context"
"crypto/rand"
"io"
"sync"
@@ -256,7 +255,6 @@ func (c *Conn) init() error {
c.logger().Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate))
select {
case <-c.closed:
break
case c.localCandidateChannel <- iceCandidate.ToJSON():
}
}()
@@ -265,7 +263,6 @@ func (c *Conn) init() error {
go func() {
select {
case <-c.closed:
return
case c.dcOpenChannel <- dc:
}
}()
@@ -435,9 +432,6 @@ func (c *Conn) pingEchoChannel() (*Channel, error) {
data := make([]byte, pingDataLength)
bytesRead, err := c.pingEchoChan.Read(data)
if err != nil {
if c.isClosed() {
return
}
_ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err))
return
}
+76 -27
View File
@@ -91,6 +91,8 @@ func TestConn(t *testing.T) {
// Create a channel that closes on disconnect.
channel, err := server.CreateChannel(context.Background(), "wow", nil)
assert.NoError(t, err)
defer channel.Close()
err = wan.Stop()
require.NoError(t, err)
// Once the connection is marked as disconnected, this
@@ -107,10 +109,13 @@ func TestConn(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
require.NoError(t, err)
defer cch.Close()
sch, err := server.Accept(context.Background())
sch, err := server.Accept(ctx)
require.NoError(t, err)
defer sch.Close()
@@ -123,9 +128,12 @@ func TestConn(t *testing.T) {
t.Parallel()
client, server, wan := createPair(t)
exchange(t, client, server)
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
defer cch.Close()
sch, err := server.Accept(ctx)
require.NoError(t, err)
defer sch.Close()
@@ -140,26 +148,44 @@ func TestConn(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
require.NoError(t, err)
defer sch.Close()
defer cch.Close()
readErr := make(chan error, 1)
go func() {
bytes := make([]byte, 4096)
for i := 0; i < 1024; i++ {
_, err := cch.Write(bytes)
require.NoError(t, err)
}
_ = cch.Close()
}()
bytes := make([]byte, 4096)
for {
_, err = sch.Read(bytes)
sch, err := server.Accept(ctx)
if err != nil {
require.ErrorIs(t, err, peer.ErrClosed)
break
readErr <- err
_ = cch.Close()
return
}
defer sch.Close()
bytes := make([]byte, 4096)
for {
_, err = sch.Read(bytes)
if err != nil {
readErr <- err
return
}
}
}()
bytes := make([]byte, 4096)
for i := 0; i < 1024; i++ {
_, err = cch.Write(bytes)
require.NoError(t, err, "write i=%d", i)
}
_ = cch.Close()
select {
case err = <-readErr:
require.ErrorIs(t, err, peer.ErrClosed, "read error")
case <-ctx.Done():
require.Fail(t, "timeout waiting for read error")
}
})
@@ -170,13 +196,29 @@ func TestConn(t *testing.T) {
srv, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer srv.Close()
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
go func() {
sch, err := server.Accept(context.Background())
assert.NoError(t, err)
sch, err := server.Accept(ctx)
if err != nil {
assert.NoError(t, err)
return
}
defer sch.Close()
nc2 := sch.NetConn()
defer nc2.Close()
nc1, err := net.Dial("tcp", srv.Addr().String())
assert.NoError(t, err)
if err != nil {
assert.NoError(t, err)
return
}
defer nc1.Close()
go func() {
defer nc1.Close()
defer nc2.Close()
_, _ = io.Copy(nc1, nc2)
}()
_, _ = io.Copy(nc2, nc1)
@@ -204,7 +246,7 @@ func TestConn(t *testing.T) {
c := http.Client{
Transport: defaultTransport,
}
req, err := http.NewRequestWithContext(context.Background(), "GET", "http://localhost/", nil)
req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost/", nil)
require.NoError(t, err)
resp, err := c.Do(req)
require.NoError(t, err)
@@ -272,14 +314,21 @@ func TestConn(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
go func() {
channel, err := client.CreateChannel(context.Background(), "test", nil)
assert.NoError(t, err)
channel, err := client.CreateChannel(ctx, "test", nil)
if err != nil {
assert.NoError(t, err)
return
}
defer channel.Close()
_, err = channel.Write([]byte{1, 2})
assert.NoError(t, err)
}()
channel, err := server.Accept(context.Background())
channel, err := server.Accept(ctx)
require.NoError(t, err)
defer channel.Close()
data := make([]byte, 1)
_, err = channel.Read(data)
require.NoError(t, err)