mirror of
https://github.com/coder/coder.git
synced 2026-06-05 14:08:20 +00:00
378 lines
12 KiB
Go
378 lines
12 KiB
Go
package messagepartbuffer_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
func TestBuffer_CreateEpisodeRejectsDuplicate(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
|
|
key := testEpisodeKey()
|
|
require.NoError(t, buffer.CreateEpisode(key))
|
|
require.ErrorIs(t, buffer.CreateEpisode(key), messagepartbuffer.ErrEpisodeExists)
|
|
}
|
|
|
|
func TestBuffer_AddPartAndGetParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
|
|
key := testEpisodeKey()
|
|
require.NoError(t, buffer.CreateEpisode(key))
|
|
require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("hello")))
|
|
|
|
parts, err := buffer.GetParts(key)
|
|
require.NoError(t, err)
|
|
require.Len(t, parts, 1)
|
|
require.Equal(t, int64(1), parts[0].Seq)
|
|
require.Equal(t, codersdk.ChatMessageRoleAssistant, parts[0].Role)
|
|
require.Equal(t, codersdk.ChatMessageText("hello"), parts[0].MessagePart)
|
|
}
|
|
|
|
func TestBuffer_AddPartMissingEpisodeReturnsNotFound(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
|
|
err := buffer.AddPart(testEpisodeKey(), codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("hello"))
|
|
require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeNotFound)
|
|
}
|
|
|
|
func TestBuffer_GetPartsMissingEpisodeReturnsNotFound(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
|
|
_, err := buffer.GetParts(testEpisodeKey())
|
|
require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeNotFound)
|
|
}
|
|
|
|
func TestBuffer_AddPartFullEpisodeReturnsFull(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{MaxEpisodeBytes: 1})
|
|
key := testEpisodeKey()
|
|
require.NoError(t, buffer.CreateEpisode(key))
|
|
err := buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("hello"))
|
|
require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeFull)
|
|
parts, getErr := buffer.GetParts(key)
|
|
require.NoError(t, getErr)
|
|
require.Empty(t, parts)
|
|
}
|
|
|
|
func TestBuffer_CloseEpisodeMissingCreatesClosedEpisode(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
|
|
key := testEpisodeKey()
|
|
require.NoError(t, buffer.CloseEpisode(key))
|
|
parts, err := buffer.GetParts(key)
|
|
require.NoError(t, err)
|
|
require.Empty(t, parts)
|
|
err = buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("tail"))
|
|
require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeClosed)
|
|
}
|
|
|
|
func TestBuffer_CloseEpisodeIdempotent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
|
|
key := testEpisodeKey()
|
|
require.NoError(t, buffer.CreateEpisode(key))
|
|
require.NoError(t, buffer.CloseEpisode(key))
|
|
require.NoError(t, buffer.CloseEpisode(key))
|
|
}
|
|
|
|
func TestBuffer_SubscribeExistingReplaysThenStreamsLiveParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
|
|
key := testEpisodeKey()
|
|
require.NoError(t, buffer.CreateEpisode(key))
|
|
require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("before")))
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
ch, cancel, err := buffer.SubscribeToEpisode(ctx, key)
|
|
require.NoError(t, err)
|
|
defer cancel()
|
|
require.Equal(t, "before", receivePart(t, ch).MessagePart.Text)
|
|
|
|
require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("after")))
|
|
require.Equal(t, "after", receivePart(t, ch).MessagePart.Text)
|
|
}
|
|
|
|
func TestBuffer_SubscribeClosedEpisodeReplaysThenCloses(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
|
|
key := testEpisodeKey()
|
|
require.NoError(t, buffer.CreateEpisode(key))
|
|
require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("before")))
|
|
require.NoError(t, buffer.CloseEpisode(key))
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
ch, cancel, err := buffer.SubscribeToEpisode(ctx, key)
|
|
require.NoError(t, err)
|
|
defer cancel()
|
|
require.Equal(t, "before", receivePart(t, ch).MessagePart.Text)
|
|
assertChannelClosed(t, ch)
|
|
}
|
|
|
|
func TestBuffer_SubscribeBeforeCreateReturnsAndWaitsWithoutNotFound(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
|
|
key := testEpisodeKey()
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
ch, cancel, err := buffer.SubscribeToEpisode(ctx, key)
|
|
require.NoError(t, err)
|
|
defer cancel()
|
|
|
|
select {
|
|
case part := <-ch:
|
|
t.Fatalf("received part before episode create: %+v", part)
|
|
default:
|
|
}
|
|
|
|
require.NoError(t, buffer.CreateEpisode(key))
|
|
require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("live")))
|
|
require.Equal(t, "live", receivePart(t, ch).MessagePart.Text)
|
|
}
|
|
|
|
func TestBuffer_AddPartAssignsContiguousSeq(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
|
|
key := testEpisodeKey()
|
|
require.NoError(t, buffer.CreateEpisode(key))
|
|
for i := range 3 {
|
|
require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText(string(rune('a'+i)))))
|
|
}
|
|
parts, err := buffer.GetParts(key)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []int64{1, 2, 3}, []int64{parts[0].Seq, parts[1].Seq, parts[2].Seq})
|
|
}
|
|
|
|
func TestBuffer_EpisodeByteLimitUsesJSONAccounting(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
part := codersdk.ChatMessageText("hello")
|
|
limit := serializedPartBytes(t, messagepartbuffer.Part{Seq: 1, Role: codersdk.ChatMessageRoleAssistant, MessagePart: part})
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{MaxEpisodeBytes: limit})
|
|
key := testEpisodeKey()
|
|
require.NoError(t, buffer.CreateEpisode(key))
|
|
require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, part))
|
|
err := buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("too much"))
|
|
require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeFull)
|
|
}
|
|
|
|
func TestBuffer_GCClosedEpisodeAfterGraceAndNoSubscribers(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
clock := quartz.NewMock(t)
|
|
trap := clock.Trap().NewTimer("message-part-buffer", "subscriber-send")
|
|
defer trap.Close()
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{
|
|
Clock: clock,
|
|
ClosedEpisodeRetention: time.Minute,
|
|
SubscriberSendTimeout: 10 * time.Minute,
|
|
})
|
|
key := testEpisodeKey()
|
|
require.NoError(t, buffer.CreateEpisode(key))
|
|
require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("held")))
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
ch, cancel, err := buffer.SubscribeToEpisode(ctx, key)
|
|
require.NoError(t, err)
|
|
require.NoError(t, buffer.CloseEpisode(key))
|
|
call := trap.MustWait(ctx)
|
|
call.MustRelease(ctx)
|
|
clock.Advance(time.Minute).MustWait(ctx)
|
|
clock.Advance(time.Second).MustWait(ctx)
|
|
_, err = buffer.GetParts(key)
|
|
require.NoError(t, err)
|
|
|
|
cancel()
|
|
drainUntilClosed(t, ch)
|
|
_, err = buffer.GetParts(key)
|
|
require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeNotFound)
|
|
}
|
|
|
|
func TestBuffer_GCRetainedSubscribedEpisodeDoesNotBlockOtherExpiredEpisodes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
clock := quartz.NewMock(t)
|
|
trap := clock.Trap().NewTimer("message-part-buffer", "subscriber-send")
|
|
defer trap.Close()
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{
|
|
Clock: clock,
|
|
ClosedEpisodeRetention: time.Minute,
|
|
SubscriberSendTimeout: 10 * time.Minute,
|
|
})
|
|
retainedKey := testEpisodeKey()
|
|
collectedKey := testEpisodeKey()
|
|
require.NoError(t, buffer.CreateEpisode(retainedKey))
|
|
require.NoError(t, buffer.AddPart(retainedKey, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("held")))
|
|
require.NoError(t, buffer.CreateEpisode(collectedKey))
|
|
require.NoError(t, buffer.AddPart(collectedKey, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("collect me")))
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
ch, cancel, err := buffer.SubscribeToEpisode(ctx, retainedKey)
|
|
require.NoError(t, err)
|
|
defer cancel()
|
|
require.NoError(t, buffer.CloseEpisode(retainedKey))
|
|
require.NoError(t, buffer.CloseEpisode(collectedKey))
|
|
call := trap.MustWait(ctx)
|
|
call.MustRelease(ctx)
|
|
clock.Advance(time.Minute).MustWait(ctx)
|
|
clock.Advance(time.Second).MustWait(ctx)
|
|
|
|
_, err = buffer.GetParts(retainedKey)
|
|
require.NoError(t, err)
|
|
_, err = buffer.GetParts(collectedKey)
|
|
require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeNotFound)
|
|
|
|
cancel()
|
|
drainUntilClosed(t, ch)
|
|
_, err = buffer.GetParts(retainedKey)
|
|
require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeNotFound)
|
|
}
|
|
|
|
func TestBuffer_SlowSubscriberClosed(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
clock := quartz.NewMock(t)
|
|
trap := clock.Trap().NewTimer("message-part-buffer", "subscriber-send")
|
|
defer trap.Close()
|
|
stopTrap := clock.Trap().TimerStop()
|
|
defer stopTrap.Close()
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{
|
|
Clock: clock,
|
|
SubscriberSendTimeout: time.Second,
|
|
})
|
|
key := testEpisodeKey()
|
|
require.NoError(t, buffer.CreateEpisode(key))
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
ch, cancel, err := buffer.SubscribeToEpisode(ctx, key)
|
|
require.NoError(t, err)
|
|
defer cancel()
|
|
|
|
require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("blocked")))
|
|
call := trap.MustWait(ctx)
|
|
call.MustRelease(ctx)
|
|
clock.Advance(time.Second).MustWait(ctx)
|
|
stopCall := stopTrap.MustWait(ctx)
|
|
stopCall.MustRelease(ctx)
|
|
assertChannelClosed(t, ch)
|
|
}
|
|
|
|
func TestBuffer_BurstyOutputDoesNotCloseSubscriberBeforeSendTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{SubscriberChannelSize: 1})
|
|
key := testEpisodeKey()
|
|
require.NoError(t, buffer.CreateEpisode(key))
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
ch, cancel, err := buffer.SubscribeToEpisode(ctx, key)
|
|
require.NoError(t, err)
|
|
defer cancel()
|
|
|
|
for i := range 8 {
|
|
require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText(string(rune('a'+i)))))
|
|
}
|
|
for i := range 8 {
|
|
part := receivePart(t, ch)
|
|
require.Equal(t, string(rune('a'+i)), part.MessagePart.Text)
|
|
}
|
|
}
|
|
|
|
func TestBuffer_SubscribeCanceledBeforeCreateCanCreateEpisode(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
|
|
key := testEpisodeKey()
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
ch, cancelSub, err := buffer.SubscribeToEpisode(ctx, key)
|
|
require.NoError(t, err)
|
|
cancel()
|
|
drainUntilClosed(t, ch)
|
|
cancelSub()
|
|
require.NoError(t, buffer.CreateEpisode(key))
|
|
}
|
|
|
|
func TestBuffer_CloseClosesPendingSubscriptionAndRejectsOperations(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
|
|
key := testEpisodeKey()
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
ch, cancel, err := buffer.SubscribeToEpisode(ctx, key)
|
|
require.NoError(t, err)
|
|
defer cancel()
|
|
buffer.Close()
|
|
assertChannelClosed(t, ch)
|
|
require.ErrorIs(t, buffer.CreateEpisode(key), messagepartbuffer.ErrMessagePartBufferClosed)
|
|
}
|
|
|
|
func testEpisodeKey() messagepartbuffer.Key {
|
|
return messagepartbuffer.Key{ChatID: uuid.New(), HistoryVersion: 1, GenerationAttempt: 1}
|
|
}
|
|
|
|
func receivePart(t *testing.T, ch <-chan messagepartbuffer.Part) messagepartbuffer.Part {
|
|
t.Helper()
|
|
select {
|
|
case part, ok := <-ch:
|
|
require.True(t, ok)
|
|
return part
|
|
case <-time.After(testutil.WaitLong):
|
|
t.Fatal("timed out waiting for buffered part")
|
|
return messagepartbuffer.Part{}
|
|
}
|
|
}
|
|
|
|
func assertChannelClosed[T any](t *testing.T, ch <-chan T) {
|
|
t.Helper()
|
|
select {
|
|
case _, ok := <-ch:
|
|
require.False(t, ok)
|
|
case <-time.After(testutil.WaitLong):
|
|
t.Fatal("timed out waiting for channel close")
|
|
}
|
|
}
|
|
|
|
func drainUntilClosed[T any](t *testing.T, ch <-chan T) {
|
|
t.Helper()
|
|
for {
|
|
select {
|
|
case _, ok := <-ch:
|
|
if !ok {
|
|
return
|
|
}
|
|
case <-time.After(testutil.WaitLong):
|
|
t.Fatal("timed out waiting for channel close")
|
|
}
|
|
}
|
|
}
|
|
|
|
func serializedPartBytes(t *testing.T, part messagepartbuffer.Part) int64 {
|
|
t.Helper()
|
|
data, err := json.Marshal(struct {
|
|
Seq int64 `json:"seq"`
|
|
Role codersdk.ChatMessageRole `json:"role"`
|
|
Part codersdk.ChatMessagePart `json:"part"`
|
|
}{
|
|
Seq: part.Seq,
|
|
Role: part.Role,
|
|
Part: part.MessagePart,
|
|
})
|
|
require.NoError(t, err)
|
|
return int64(len(data))
|
|
}
|