Files
coder/coderd/x/chatd/messagepartbuffer/message_part_buffer.go
T
Hugo Dutka 658a04d28f pr 3
2026-06-04 18:51:22 +00:00

493 lines
12 KiB
Go

package messagepartbuffer
import (
"container/heap"
"context"
"encoding/json"
"sync"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
// Fallback values that New substitutes for Options fields that are unset or
// not positive.
const (
// defaultMaxEpisodeBytes is the default per-episode limit on the summed
// JSON-encoded size of buffered parts (1 MiB).
defaultMaxEpisodeBytes = int64(1024 * 1024)
// defaultClosedEpisodeRetention is the default age after which a closed
// episode becomes eligible for garbage collection. The same duration is
// used as the cleanup ticker interval.
defaultClosedEpisodeRetention = 15 * time.Second
// defaultSubscriberSendTimeout is the default upper bound on a single
// blocked send to a subscriber's output channel.
defaultSubscriberSendTimeout = 10 * time.Second
// defaultSubscriberChannelSize is the fallback New stores in
// Options.SubscriberChannelSize when the caller's value is not positive.
defaultSubscriberChannelSize = 16
)
// Sentinel errors returned by Buffer methods.
var (
// ErrEpisodeExists means the episode already exists. It is returned by
// CreateEpisode.
ErrEpisodeExists = xerrors.New("message part episode already exists")
// ErrEpisodeNotFound means the episode has not been created. It is
// returned by AddPart and GetParts.
ErrEpisodeNotFound = xerrors.New("message part episode not found")
// ErrEpisodeClosed means the episode no longer accepts parts. It is
// returned by AddPart.
ErrEpisodeClosed = xerrors.New("message part episode closed")
// ErrEpisodeFull means the episode byte limit would be exceeded. It is
// returned by AddPart.
ErrEpisodeFull = xerrors.New("message part episode full")
// ErrMessagePartBufferClosed means the whole buffer is closed, i.e. Close
// has already been called on it.
ErrMessagePartBufferClosed = xerrors.New("message part buffer closed")
)
// Key identifies a buffered message part episode. It is used directly as the
// map key for the buffer's episodes, so two keys name the same episode only
// when all three fields are equal.
type Key struct {
// ChatID is the chat the episode belongs to.
ChatID uuid.UUID
// HistoryVersion and GenerationAttempt further distinguish episodes of the
// same chat. NOTE(review): their exact meaning is defined by callers and
// is not visible in this file.
HistoryVersion int64
GenerationAttempt int64
}
// Part is a buffered chat message part with its sequence number.
type Part struct {
// Seq is the part's 1-based position within its episode, assigned by
// AddPart.
Seq int64
// Role is the chat message role supplied to AddPart for this part.
Role codersdk.ChatMessageRole
// MessagePart is the payload supplied to AddPart.
MessagePart codersdk.ChatMessagePart
}
// partJSON is the JSON shape of a Part. In this file it is only marshaled to
// measure how many bytes a part contributes to its episode's size.
type partJSON struct {
Seq int64 `json:"seq"`
Role codersdk.ChatMessageRole `json:"role"`
Part codersdk.ChatMessagePart `json:"part"`
}
// jsonValue converts p into its JSON representation, copying every field.
func (p Part) jsonValue() partJSON {
	var wire partJSON
	wire.Seq = p.Seq
	wire.Role = p.Role
	wire.Part = p.MessagePart
	return wire
}
// Options configures a Buffer. New replaces unset or non-positive fields with
// package defaults.
type Options struct {
// MaxEpisodeBytes limits the summed JSON-encoded size of the parts in one
// episode. AddPart rejects a part that would exceed it.
MaxEpisodeBytes int64
// ClosedEpisodeRetention is how long a closed episode is kept before
// garbage collection may remove it. It is also the interval of the
// background cleanup ticker.
ClosedEpisodeRetention time.Duration
// SubscriberSendTimeout bounds how long delivery of a single part to a
// subscriber may block before that subscription is dropped.
SubscriberSendTimeout time.Duration
// SubscriberChannelSize is presumably the capacity of each subscriber's
// output channel; verify at the subscription site.
SubscriberChannelSize int
// Clock supplies the current time, tickers, and timers. A nil Clock is
// replaced with quartz.NewReal().
Clock quartz.Clock
}
// Buffer stores streamed message parts by episode. A Buffer must be created
// with New and must not be copied, because it contains a mutex.
type Buffer struct {
// mu guards every field below as well as the episodeState values reachable
// through episodes.
mu sync.Mutex
// opts holds the options after New has applied defaults.
opts Options
// episodes maps a key to its state. It includes placeholder entries for
// episodes that have subscribers but have not been created yet.
episodes map[Key]*episodeState
// closedEpisodes is a min-heap, ordered by close time, of closed episodes
// that are awaiting garbage collection.
closedEpisodes closedEpisodeHeap
// closed is set by Close. Once set, the other public methods return
// ErrMessagePartBufferClosed.
closed bool
// done is closed by Close to stop the cleanup loop and all subscriber
// delivery goroutines.
done chan struct{}
}
// episodeState is the per-episode bookkeeping held in Buffer.episodes. All
// fields are accessed with Buffer.mu held.
type episodeState struct {
// created is false for a placeholder that exists only because something
// subscribed before the episode was created.
created bool
// createdCh is closed at the moment created becomes true.
createdCh chan struct{}
// closed reports that the episode accepts no more parts.
closed bool
// closedAt is the clock time at which the episode was closed.
closedAt time.Time
// closedHeapItem is this episode's entry in Buffer.closedEpisodes, or nil
// when the episode is not currently queued for garbage collection.
closedHeapItem *closedEpisodeItem
// parts holds the buffered parts in sequence order.
parts []Part
// bytes is the summed JSON-encoded size of parts.
bytes int64
// subscribers is the set of active subscriptions. It is allocated lazily
// on the first subscription.
subscribers map[*episodeSubscriber]struct{}
}
// closedEpisodeItem is one entry in closedEpisodeHeap. It records which
// episode was closed and when.
type closedEpisodeItem struct {
// key identifies the closed episode in Buffer.episodes.
key Key
// closedAt is the episode's close time. The heap is ordered by it.
closedAt time.Time
}
// closedEpisodeHeap is a min-heap of closed episodes ordered by close time,
// oldest first. It implements heap.Interface and must be modified through
// the container/heap functions.
type closedEpisodeHeap []*closedEpisodeItem
// Len reports how many items are queued. It is part of heap.Interface.
func (h closedEpisodeHeap) Len() int {
	count := len(h)
	return count
}
// Less orders items by close time so the oldest closed episode sits at the
// root of the heap. It is part of heap.Interface.
func (h closedEpisodeHeap) Less(i, j int) bool {
	left, right := h[i], h[j]
	return right.closedAt.After(left.closedAt)
}
// Swap exchanges the items at positions i and j. It is part of
// heap.Interface.
func (h closedEpisodeHeap) Swap(i, j int) {
	held := h[i]
	h[i] = h[j]
	h[j] = held
}
// Push appends value to the underlying slice. It is part of heap.Interface
// and is meant to be called through heap.Push. It panics when value is not a
// *closedEpisodeItem.
func (h *closedEpisodeHeap) Push(value any) {
	switch entry := value.(type) {
	case *closedEpisodeItem:
		*h = append(*h, entry)
	default:
		panic("closed episode heap received invalid item")
	}
}
// Pop removes and returns the final element of the underlying slice. It is
// part of heap.Interface and is meant to be called through heap.Pop.
func (h *closedEpisodeHeap) Pop() any {
	items := *h
	tail := len(items) - 1
	popped := items[tail]
	// Clear the vacated slot so the backing array stops referencing the item.
	items[tail] = nil
	*h = items[:tail]
	return popped
}
// episodeSubscriber is the state of one subscription to an episode.
type episodeSubscriber struct {
// out is the channel handed to the caller. The delivery goroutine sends
// parts on it and closes it when the subscription ends.
out chan Part
// notifyCh wakes the delivery goroutine. It has capacity one so that
// several notifications collapse into a single wake-up.
notifyCh chan struct{}
// stopCh is closed by stop to end the subscription.
stopCh chan struct{}
// next is the index into the episode's parts of the first part not yet
// handed to this subscriber. It is read and written with Buffer.mu held.
next int
// stopOnce ensures stopCh is closed at most once.
stopOnce sync.Once
}
// New returns a message part buffer configured by options and starts its
// background cleanup loop. Option fields that are unset or not positive fall
// back to the package defaults, and a nil Clock becomes quartz.NewReal().
// Call Close to stop the cleanup loop.
func New(options Options) *Buffer {
	if options.Clock == nil {
		options.Clock = quartz.NewReal()
	}
	if options.MaxEpisodeBytes <= 0 {
		options.MaxEpisodeBytes = defaultMaxEpisodeBytes
	}
	if options.ClosedEpisodeRetention <= 0 {
		options.ClosedEpisodeRetention = defaultClosedEpisodeRetention
	}
	if options.SubscriberSendTimeout <= 0 {
		options.SubscriberSendTimeout = defaultSubscriberSendTimeout
	}
	// Both zero and negative sizes select the default.
	if options.SubscriberChannelSize <= 0 {
		options.SubscriberChannelSize = defaultSubscriberChannelSize
	}

	b := &Buffer{
		done:     make(chan struct{}),
		episodes: make(map[Key]*episodeState),
		opts:     options,
	}
	b.startCleanupLoop()
	return b
}
// CreateEpisode marks the episode identified by key as created, after first
// garbage-collecting expired closed episodes. A placeholder left by an
// earlier subscription is reused.
//
// It returns ErrMessagePartBufferClosed when the buffer is closed and
// ErrEpisodeExists when the episode has already been created.
func (b *Buffer) CreateEpisode(key Key) error {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.closed {
		return ErrMessagePartBufferClosed
	}

	now := b.opts.Clock.Now("message-part-buffer", "create")
	b.gcClosedEpisodesLocked(now)

	state := b.episodeLocked(key)
	if state.created {
		return ErrEpisodeExists
	}
	close(state.createdCh)
	state.created = true
	return nil
}
// AddPart appends part, tagged with role, to the episode identified by key
// and wakes every subscriber of that episode.
//
// The part gets the next 1-based sequence number, and its JSON-encoded size
// is counted against Options.MaxEpisodeBytes. It returns
// ErrMessagePartBufferClosed, ErrEpisodeNotFound, ErrEpisodeClosed,
// ErrEpisodeFull, or the JSON encoding error; in every error case the episode
// is left unchanged.
func (b *Buffer) AddPart(key Key, role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) error {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.closed {
		return ErrMessagePartBufferClosed
	}

	state := b.episodes[key]
	switch {
	case state == nil || !state.created:
		return ErrEpisodeNotFound
	case state.closed:
		return ErrEpisodeClosed
	}

	next := Part{
		Role:        role,
		MessagePart: part,
		Seq:         int64(len(state.parts) + 1),
	}
	size, err := serializedPartBytes(next)
	if err != nil {
		return err
	}
	if total := state.bytes + size; total > b.opts.MaxEpisodeBytes {
		return ErrEpisodeFull
	}

	state.bytes += size
	state.parts = append(state.parts, next)
	for sub := range state.subscribers {
		notifySubscriber(sub)
	}
	return nil
}
// CloseEpisode marks the episode identified by key as closed, queues it for
// garbage collection, and wakes its subscribers so they can drain the
// remaining parts and finish.
//
// An episode that was never created is created and closed in one step.
// Closing an episode that is already closed is a no-op. The only error is
// ErrMessagePartBufferClosed.
func (b *Buffer) CloseEpisode(key Key) error {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.closed {
		return ErrMessagePartBufferClosed
	}

	state := b.episodeLocked(key)
	if !state.created {
		close(state.createdCh)
		state.created = true
	}
	if !state.closed {
		state.closedAt = b.opts.Clock.Now("message-part-buffer", "close")
		state.closed = true
		b.queueClosedEpisodeLocked(key, state)
		for sub := range state.subscribers {
			notifySubscriber(sub)
		}
	}
	return nil
}
// GetParts returns a copy of the parts currently buffered for the episode
// identified by key, after first garbage-collecting expired closed episodes.
// The returned slice is independent of the buffer's storage and is nil when
// the episode has no parts.
//
// It returns ErrMessagePartBufferClosed when the buffer is closed and
// ErrEpisodeNotFound when the episode has not been created.
func (b *Buffer) GetParts(key Key) ([]Part, error) {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.closed {
		return nil, ErrMessagePartBufferClosed
	}

	now := b.opts.Clock.Now("message-part-buffer", "get")
	b.gcClosedEpisodesLocked(now)

	state, known := b.episodes[key]
	if !known || state == nil || !state.created {
		return nil, ErrEpisodeNotFound
	}

	var snapshot []Part
	snapshot = append(snapshot, state.parts...)
	return snapshot, nil
}
// SubscribeToEpisode replays the parts already buffered for key and then
// streams new parts as they are added.
//
// If the episode does not exist yet, a placeholder is registered so the
// subscription starts receiving once parts arrive. The returned channel has
// capacity Options.SubscriberChannelSize. It is closed when the episode has
// been closed and fully delivered, when ctx is done, when the returned cancel
// function is called, when a send blocks for longer than
// Options.SubscriberSendTimeout, or when the buffer is closed. Parts already
// sitting in the channel buffer stay readable after it is closed.
//
// The cancel function is safe to call more than once. The only error is
// ErrMessagePartBufferClosed.
func (b *Buffer) SubscribeToEpisode(ctx context.Context, key Key) (<-chan Part, func(), error) {
	b.mu.Lock()
	if b.closed {
		b.mu.Unlock()
		return nil, nil, ErrMessagePartBufferClosed
	}
	episode := b.episodeLocked(key)
	subscriber := &episodeSubscriber{
		// Honor the configured channel size. It was previously normalized by
		// New but never used, which left this channel unbuffered.
		out: make(chan Part, b.opts.SubscriberChannelSize),
		// Capacity one lets multiple notifications coalesce into one wake-up.
		notifyCh: make(chan struct{}, 1),
		stopCh:   make(chan struct{}),
	}
	if episode.subscribers == nil {
		episode.subscribers = make(map[*episodeSubscriber]struct{})
	}
	episode.subscribers[subscriber] = struct{}{}
	// Wake the delivery goroutine immediately so existing parts are replayed.
	notifySubscriber(subscriber)
	b.mu.Unlock()

	go b.deliverSubscriber(ctx, key, subscriber)

	cancel := func() {
		b.cancelSubscriber(key, subscriber)
	}
	return subscriber.out, cancel, nil
}
// Close shuts the buffer down. It stops the cleanup loop, stops every
// subscriber, and marks placeholder episodes as created so their createdCh is
// closed. Calling Close again is a no-op.
func (b *Buffer) Close() {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.closed {
		return
	}
	b.closed = true
	close(b.done)

	for _, state := range b.episodes {
		for sub := range state.subscribers {
			b.stopSubscriberLocked(state, sub)
		}
		if state.created {
			continue
		}
		state.created = true
		close(state.createdCh)
	}
}
// startCleanupLoop starts a goroutine that garbage-collects expired closed
// episodes once per ClosedEpisodeRetention interval. The goroutine exits when
// the buffer is closed.
func (b *Buffer) startCleanupLoop() {
	ticker := b.opts.Clock.NewTicker(b.opts.ClosedEpisodeRetention, "message-part-buffer", "cleanup")

	// sweep runs one collection pass and reports whether the buffer has been
	// closed in the meantime.
	sweep := func() (stopped bool) {
		b.mu.Lock()
		defer b.mu.Unlock()
		if b.closed {
			return true
		}
		b.gcClosedEpisodesLocked(b.opts.Clock.Now("message-part-buffer", "cleanup"))
		return false
	}

	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-b.done:
				return
			case <-ticker.C:
			}
			if sweep() {
				return
			}
		}
	}()
}
// gcClosedEpisodesLocked removes closed episodes whose close time is at or
// before now minus ClosedEpisodeRetention. The caller must hold b.mu.
//
// Expired episodes that still have subscribers are not removed. They are put
// back on the heap after the scan so that a later pass reconsiders them.
func (b *Buffer) gcClosedEpisodesLocked(now time.Time) {
cutoff := now.Add(-b.opts.ClosedEpisodeRetention)
type retainedEpisode struct {
key Key
episode *episodeState
}
retained := make([]retainedEpisode, 0)
for b.closedEpisodes.Len() > 0 {
// The heap root is the oldest closed episode. Once it is newer than the
// cutoff, nothing else on the heap has expired either.
item := b.closedEpisodes[0]
if item.closedAt.After(cutoff) {
break
}
popped, ok := heap.Pop(&b.closedEpisodes).(*closedEpisodeItem)
if !ok || popped != item {
continue
}
// Skip heap entries that no longer match the episode currently stored
// under this key.
episode := b.episodes[item.key]
if episode == nil || episode.closedHeapItem != item || !episode.closed {
continue
}
episode.closedHeapItem = nil
if len(episode.subscribers) > 0 {
// Still being read. Keep the episode and re-queue it below, after the
// loop, so this pass does not pop it again.
retained = append(retained, retainedEpisode{key: item.key, episode: episode})
continue
}
delete(b.episodes, item.key)
}
for _, item := range retained {
if b.episodes[item.key] != item.episode || !item.episode.closed || item.episode.closedHeapItem != nil {
continue
}
b.queueClosedEpisodeLocked(item.key, item.episode)
}
}
// queueClosedEpisodeLocked puts episode on the closed-episode heap under its
// close time, unless it is already queued. The caller must hold b.mu.
func (b *Buffer) queueClosedEpisodeLocked(key Key, episode *episodeState) {
	if episode.closedHeapItem == nil {
		entry := &closedEpisodeItem{closedAt: episode.closedAt, key: key}
		heap.Push(&b.closedEpisodes, entry)
		episode.closedHeapItem = entry
	}
}
// episodeLocked returns the state stored for key. If there is none, it first
// inserts a fresh placeholder that is not yet created. The caller must hold
// b.mu.
func (b *Buffer) episodeLocked(key Key) *episodeState {
	if existing := b.episodes[key]; existing != nil {
		return existing
	}
	placeholder := &episodeState{createdCh: make(chan struct{})}
	b.episodes[key] = placeholder
	return placeholder
}
// subscriberParts returns a copy of the parts subscriber has not yet been
// handed and advances its cursor past them.
//
// closed reports that the episode is closed, so the returned parts are the
// last ones. ok is false when the subscription should end: the buffer is
// closed, the episode is gone, or the cursor is beyond the stored parts. For
// an episode that has not been created yet it returns no parts with ok true.
func (b *Buffer) subscriberParts(key Key, subscriber *episodeSubscriber) (parts []Part, closed bool, ok bool) {
	b.mu.Lock()
	defer b.mu.Unlock()

	state := b.episodes[key]
	switch {
	case b.closed, state == nil:
		return nil, false, false
	case !state.created:
		return nil, false, true
	}

	total := len(state.parts)
	if subscriber.next > total {
		return nil, false, false
	}

	var fresh []Part
	fresh = append(fresh, state.parts[subscriber.next:]...)
	subscriber.next = total
	// The cursor now sits at the end of the parts, so the subscriber is done
	// exactly when the episode is closed.
	return fresh, state.closed, true
}
// deliverSubscriber is the per-subscription goroutine. It repeatedly fetches
// undelivered parts and sends them on subscriber.out, then waits for the next
// notification.
//
// It returns when the episode is closed and drained, the subscription becomes
// invalid, a send fails, the subscriber is stopped, ctx is done, or the
// buffer is closed. On return the subscriber is unregistered first and its
// output channel is closed afterwards.
func (b *Buffer) deliverSubscriber(ctx context.Context, key Key, subscriber *episodeSubscriber) {
	defer close(subscriber.out)
	defer b.removeSubscriber(key, subscriber)

	for {
		pending, finished, alive := b.subscriberParts(key, subscriber)
		if !alive {
			return
		}
		for i := range pending {
			if delivered := b.sendSubscriberPart(ctx, subscriber, pending[i]); !delivered {
				return
			}
		}
		if finished {
			return
		}

		select {
		case <-subscriber.notifyCh:
			// New parts or a close are pending; go around again.
			continue
		case <-subscriber.stopCh:
		case <-ctx.Done():
		case <-b.done:
		}
		return
	}
}
// sendSubscriberPart sends part on subscriber.out and reports whether the
// send happened. It gives up and returns false when SubscriberSendTimeout
// elapses, the subscriber is stopped, ctx is done, or the buffer is closed.
func (b *Buffer) sendSubscriberPart(ctx context.Context, subscriber *episodeSubscriber, part Part) bool {
	deadline := b.opts.Clock.NewTimer(b.opts.SubscriberSendTimeout, "message-part-buffer", "subscriber-send")
	defer deadline.Stop()

	select {
	case subscriber.out <- part:
		return true
	case <-deadline.C:
	case <-subscriber.stopCh:
	case <-ctx.Done():
	case <-b.done:
	}
	return false
}
// cancelSubscriber stops subscriber. If the episode stored under key still
// exists, subscriber is also removed from that episode's subscriber set.
func (b *Buffer) cancelSubscriber(key Key, subscriber *episodeSubscriber) {
	b.mu.Lock()
	defer b.mu.Unlock()

	state := b.episodes[key]
	if state == nil {
		// The episode is already gone, so only the stop signal is needed.
		subscriber.stop()
		return
	}
	b.stopSubscriberLocked(state, subscriber)
}
// removeSubscriber unregisters subscriber from the episode stored under key
// and releases the episode once nobody is left to read it.
//
// When the last subscriber leaves:
//   - a closed episode is queued for garbage collection, if not already queued;
//   - a placeholder that was never created is deleted immediately. Such a
//     placeholder holds no parts and exists only because of a subscription,
//     so without this it would stay in the map forever for any key that is
//     subscribed to but never created.
//
// It does nothing when no episode is stored under key.
func (b *Buffer) removeSubscriber(key Key, subscriber *episodeSubscriber) {
	b.mu.Lock()
	defer b.mu.Unlock()

	episode := b.episodes[key]
	if episode == nil {
		return
	}
	delete(episode.subscribers, subscriber)
	if len(episode.subscribers) > 0 {
		return
	}

	switch {
	case episode.closed:
		b.queueClosedEpisodeLocked(key, episode)
	case !episode.created:
		// A later CreateEpisode or SubscribeToEpisode call for this key
		// allocates a fresh state via episodeLocked.
		delete(b.episodes, key)
	}
}
// stopSubscriberLocked signals subscriber to stop and removes it from
// episode's subscriber set. The caller must hold the buffer mutex.
func (*Buffer) stopSubscriberLocked(episode *episodeState, subscriber *episodeSubscriber) {
	subscriber.stop()
	delete(episode.subscribers, subscriber)
}
// notifySubscriber wakes subscriber's delivery goroutine without blocking.
// If a wake-up is already pending on notifyCh, this one is dropped, since a
// single wake-up is enough for the goroutine to fetch everything new.
func notifySubscriber(subscriber *episodeSubscriber) {
	wake := struct{}{}
	select {
	default:
		return
	case subscriber.notifyCh <- wake:
		return
	}
}
// stop closes stopCh to end the subscription. It is safe to call more than
// once and from several goroutines.
func (s *episodeSubscriber) stop() {
	closeStop := func() {
		close(s.stopCh)
	}
	s.stopOnce.Do(closeStop)
}
// serializedPartBytes returns the length in bytes of part's JSON encoding.
// This is the size charged against an episode's byte limit. A marshaling
// failure is returned unchanged.
func serializedPartBytes(part Part) (int64, error) {
	encoded, err := json.Marshal(part.jsonValue())
	if err != nil {
		return 0, err
	}
	size := int64(len(encoded))
	return size, nil
}