mirror of
https://github.com/coder/coder.git
synced 2026-06-06 06:28:20 +00:00
493 lines
12 KiB
Go
493 lines
12 KiB
Go
package messagepartbuffer
|
|
|
|
import (
|
|
"container/heap"
|
|
"context"
|
|
"encoding/json"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
// Defaults applied by New when the corresponding Options field is zero or
// negative.
const (
	// defaultMaxEpisodeBytes caps the serialized size of one episode at 1 MiB.
	defaultMaxEpisodeBytes int64 = 1 << 20
	// defaultClosedEpisodeRetention is how long a closed episode is kept
	// before it becomes eligible for garbage collection.
	defaultClosedEpisodeRetention time.Duration = 15 * time.Second
	// defaultSubscriberSendTimeout bounds how long delivery waits on a single
	// send to a subscriber before giving up on that subscriber.
	defaultSubscriberSendTimeout time.Duration = 10 * time.Second
	// defaultSubscriberChannelSize is the fallback for
	// Options.SubscriberChannelSize.
	defaultSubscriberChannelSize = 16
)
|
|
|
|
var (
|
|
// ErrEpisodeExists means the episode already exists.
|
|
ErrEpisodeExists = xerrors.New("message part episode already exists")
|
|
// ErrEpisodeNotFound means the episode has not been created.
|
|
ErrEpisodeNotFound = xerrors.New("message part episode not found")
|
|
// ErrEpisodeClosed means the episode no longer accepts parts.
|
|
ErrEpisodeClosed = xerrors.New("message part episode closed")
|
|
// ErrEpisodeFull means the episode byte limit would be exceeded.
|
|
ErrEpisodeFull = xerrors.New("message part episode full")
|
|
// ErrMessagePartBufferClosed means the whole buffer is closed.
|
|
ErrMessagePartBufferClosed = xerrors.New("message part buffer closed")
|
|
)
|
|
|
|
// Key identifies a buffered message part episode. It is used as a map key,
// so all three fields take part in equality and each distinct combination is
// a separate episode.
// NOTE(review): presumably one episode per generation attempt of a chat at a
// given history version — confirm with callers.
type Key struct {
	ChatID            uuid.UUID
	HistoryVersion    int64
	GenerationAttempt int64
}

// Part is a buffered chat message part with its sequence number.
// Seq is assigned by AddPart as the part's 1-based position in its episode.
type Part struct {
	Seq         int64
	Role        codersdk.ChatMessageRole
	MessagePart codersdk.ChatMessagePart
}

// partJSON is the JSON shape of a Part. serializedPartBytes encodes parts in
// this shape, so these field tags determine how many bytes each part counts
// against Options.MaxEpisodeBytes.
type partJSON struct {
	Seq  int64                    `json:"seq"`
	Role codersdk.ChatMessageRole `json:"role"`
	Part codersdk.ChatMessagePart `json:"part"`
}
|
|
|
|
func (p Part) jsonValue() partJSON {
|
|
return partJSON{
|
|
Seq: p.Seq,
|
|
Role: p.Role,
|
|
Part: p.MessagePart,
|
|
}
|
|
}
|
|
|
|
// Options configures a Buffer. New replaces zero or negative values (and a
// nil Clock) with package defaults, so the zero Options is valid.
type Options struct {
	// MaxEpisodeBytes caps the total JSON-encoded size of the parts in one
	// episode; AddPart returns ErrEpisodeFull when a part would exceed it.
	// Defaults to 1 MiB.
	MaxEpisodeBytes int64
	// ClosedEpisodeRetention is how long a closed episode is kept before it
	// becomes eligible for removal. It is also the period of the background
	// cleanup ticker. Defaults to 15s.
	ClosedEpisodeRetention time.Duration
	// SubscriberSendTimeout bounds how long delivery blocks on a single send
	// to a subscriber; if the receiver does not take the part in time, the
	// subscription ends and its channel is closed. Defaults to 10s.
	SubscriberSendTimeout time.Duration
	// SubscriberChannelSize is the size setting for subscriber output
	// channels. Zero or negative values are replaced with 16 by New.
	// NOTE(review): verify SubscribeToEpisode actually consumes this value.
	SubscriberChannelSize int
	// Clock is the time source for timestamps, tickers and timers. Defaults
	// to quartz.NewReal().
	Clock quartz.Clock
}
|
|
|
|
// Buffer stores streamed message parts by episode and fans them out to
// subscribers. Every method takes mu, so a Buffer may be used from multiple
// goroutines. Construct it with New (the episodes map must be initialized)
// and release it with Close.
type Buffer struct {
	// mu guards the fields below as well as all episodeState and subscriber
	// bookkeeping.
	mu sync.Mutex
	// opts are the Options passed to New with defaults applied.
	opts Options
	// episodes maps keys to their state, including uncreated placeholders
	// inserted by episodeLocked.
	episodes map[Key]*episodeState
	// closedEpisodes is a min-heap (ordered by closedAt) of closed episodes
	// awaiting garbage collection.
	closedEpisodes closedEpisodeHeap
	// closed is set once Close has run.
	closed bool
	// done is closed by Close to stop the cleanup loop and every delivery
	// goroutine.
	done chan struct{}
}
|
|
|
|
// episodeState holds everything buffered for one Key. An entry can exist
// before CreateEpisode as a placeholder inserted by episodeLocked (for
// example by an early SubscribeToEpisode). All fields are guarded by
// Buffer.mu.
type episodeState struct {
	// created is set by CreateEpisode, CloseEpisode, or Close; until then the
	// entry is a placeholder and AddPart/GetParts report ErrEpisodeNotFound.
	created bool
	// createdCh is closed at the same moment created becomes true.
	createdCh chan struct{}
	// closed is set by CloseEpisode; AddPart rejects parts afterwards.
	closed bool
	// closedAt records when CloseEpisode ran, according to Options.Clock.
	closedAt time.Time
	// closedHeapItem is this episode's current entry in
	// Buffer.closedEpisodes, or nil when it is not queued for collection.
	closedHeapItem *closedEpisodeItem
	// parts are the buffered parts in append order (Seq == index+1).
	parts []Part
	// bytes is the total JSON-encoded size of parts, checked against
	// Options.MaxEpisodeBytes.
	bytes int64
	// subscribers are the active subscriptions for this episode.
	subscribers map[*episodeSubscriber]struct{}
}
|
|
|
|
// closedEpisodeItem is one entry in closedEpisodeHeap: the key of a closed
// episode and the time it was closed, which orders garbage collection.
type closedEpisodeItem struct {
	key      Key
	closedAt time.Time
}
|
|
|
|
type closedEpisodeHeap []*closedEpisodeItem
|
|
|
|
func (h closedEpisodeHeap) Len() int {
|
|
return len(h)
|
|
}
|
|
|
|
func (h closedEpisodeHeap) Less(i, j int) bool {
|
|
return h[i].closedAt.Before(h[j].closedAt)
|
|
}
|
|
|
|
func (h closedEpisodeHeap) Swap(i, j int) {
|
|
h[i], h[j] = h[j], h[i]
|
|
}
|
|
|
|
func (h *closedEpisodeHeap) Push(value any) {
|
|
item, ok := value.(*closedEpisodeItem)
|
|
if !ok {
|
|
panic("closed episode heap received invalid item")
|
|
}
|
|
*h = append(*h, item)
|
|
}
|
|
|
|
func (h *closedEpisodeHeap) Pop() any {
|
|
old := *h
|
|
last := old[len(old)-1]
|
|
old[len(old)-1] = nil
|
|
*h = old[:len(old)-1]
|
|
return last
|
|
}
|
|
|
|
// episodeSubscriber is the state of one SubscribeToEpisode subscription.
type episodeSubscriber struct {
	// out receives delivered parts; deliverSubscriber closes it when
	// delivery ends.
	out chan Part
	// notifyCh carries wake-ups from AddPart, CloseEpisode and
	// SubscribeToEpisode. It is created with capacity 1 and sent to without
	// blocking, so pending wake-ups coalesce.
	notifyCh chan struct{}
	// stopCh is closed by stop to make the delivery goroutine exit.
	stopCh chan struct{}
	// next is the index in episodeState.parts of the next part to deliver;
	// it is read and advanced under Buffer.mu in subscriberParts.
	next int
	// stopOnce ensures stopCh is closed at most once.
	stopOnce sync.Once
}
|
|
|
|
// New returns a message part buffer.
|
|
func New(options Options) *Buffer {
|
|
if options.MaxEpisodeBytes <= 0 {
|
|
options.MaxEpisodeBytes = defaultMaxEpisodeBytes
|
|
}
|
|
if options.ClosedEpisodeRetention <= 0 {
|
|
options.ClosedEpisodeRetention = defaultClosedEpisodeRetention
|
|
}
|
|
if options.SubscriberSendTimeout <= 0 {
|
|
options.SubscriberSendTimeout = defaultSubscriberSendTimeout
|
|
}
|
|
if options.SubscriberChannelSize < 0 {
|
|
options.SubscriberChannelSize = 0
|
|
}
|
|
if options.SubscriberChannelSize == 0 {
|
|
options.SubscriberChannelSize = defaultSubscriberChannelSize
|
|
}
|
|
if options.Clock == nil {
|
|
options.Clock = quartz.NewReal()
|
|
}
|
|
buffer := &Buffer{
|
|
opts: options,
|
|
episodes: make(map[Key]*episodeState),
|
|
done: make(chan struct{}),
|
|
}
|
|
buffer.startCleanupLoop()
|
|
return buffer
|
|
}
|
|
|
|
// CreateEpisode creates a new episode.
|
|
func (b *Buffer) CreateEpisode(key Key) error {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
if b.closed {
|
|
return ErrMessagePartBufferClosed
|
|
}
|
|
b.gcClosedEpisodesLocked(b.opts.Clock.Now("message-part-buffer", "create"))
|
|
episode := b.episodeLocked(key)
|
|
if episode.created {
|
|
return ErrEpisodeExists
|
|
}
|
|
episode.created = true
|
|
close(episode.createdCh)
|
|
return nil
|
|
}
|
|
|
|
// AddPart appends a part to an existing episode.
|
|
func (b *Buffer) AddPart(key Key, role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) error {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
if b.closed {
|
|
return ErrMessagePartBufferClosed
|
|
}
|
|
episode := b.episodes[key]
|
|
if episode == nil || !episode.created {
|
|
return ErrEpisodeNotFound
|
|
}
|
|
if episode.closed {
|
|
return ErrEpisodeClosed
|
|
}
|
|
buffered := Part{
|
|
Seq: int64(len(episode.parts) + 1),
|
|
Role: role,
|
|
MessagePart: part,
|
|
}
|
|
sizeBytes, err := serializedPartBytes(buffered)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if episode.bytes+sizeBytes > b.opts.MaxEpisodeBytes {
|
|
return ErrEpisodeFull
|
|
}
|
|
episode.parts = append(episode.parts, buffered)
|
|
episode.bytes += sizeBytes
|
|
for subscriber := range episode.subscribers {
|
|
notifySubscriber(subscriber)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CloseEpisode marks an episode closed and closes its subscribers.
|
|
func (b *Buffer) CloseEpisode(key Key) error {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
if b.closed {
|
|
return ErrMessagePartBufferClosed
|
|
}
|
|
episode := b.episodeLocked(key)
|
|
if !episode.created {
|
|
episode.created = true
|
|
close(episode.createdCh)
|
|
}
|
|
if episode.closed {
|
|
return nil
|
|
}
|
|
episode.closed = true
|
|
episode.closedAt = b.opts.Clock.Now("message-part-buffer", "close")
|
|
b.queueClosedEpisodeLocked(key, episode)
|
|
for subscriber := range episode.subscribers {
|
|
notifySubscriber(subscriber)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetParts returns a snapshot of buffered parts for an episode.
|
|
func (b *Buffer) GetParts(key Key) ([]Part, error) {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
if b.closed {
|
|
return nil, ErrMessagePartBufferClosed
|
|
}
|
|
b.gcClosedEpisodesLocked(b.opts.Clock.Now("message-part-buffer", "get"))
|
|
episode := b.episodes[key]
|
|
if episode == nil || !episode.created {
|
|
return nil, ErrEpisodeNotFound
|
|
}
|
|
return append([]Part(nil), episode.parts...), nil
|
|
}
|
|
|
|
// SubscribeToEpisode replays existing parts and streams new parts.
|
|
func (b *Buffer) SubscribeToEpisode(ctx context.Context, key Key) (<-chan Part, func(), error) {
|
|
b.mu.Lock()
|
|
if b.closed {
|
|
b.mu.Unlock()
|
|
return nil, nil, ErrMessagePartBufferClosed
|
|
}
|
|
episode := b.episodeLocked(key)
|
|
subscriber := &episodeSubscriber{
|
|
out: make(chan Part),
|
|
notifyCh: make(chan struct{}, 1),
|
|
stopCh: make(chan struct{}),
|
|
}
|
|
if episode.subscribers == nil {
|
|
episode.subscribers = make(map[*episodeSubscriber]struct{})
|
|
}
|
|
episode.subscribers[subscriber] = struct{}{}
|
|
notifySubscriber(subscriber)
|
|
b.mu.Unlock()
|
|
|
|
go b.deliverSubscriber(ctx, key, subscriber)
|
|
cancel := func() {
|
|
b.cancelSubscriber(key, subscriber)
|
|
}
|
|
return subscriber.out, cancel, nil
|
|
}
|
|
|
|
// Close closes the buffer and all pending subscriptions.
|
|
func (b *Buffer) Close() {
|
|
b.mu.Lock()
|
|
if b.closed {
|
|
b.mu.Unlock()
|
|
return
|
|
}
|
|
b.closed = true
|
|
close(b.done)
|
|
for _, episode := range b.episodes {
|
|
for subscriber := range episode.subscribers {
|
|
b.stopSubscriberLocked(episode, subscriber)
|
|
}
|
|
if !episode.created {
|
|
episode.created = true
|
|
close(episode.createdCh)
|
|
}
|
|
}
|
|
b.mu.Unlock()
|
|
}
|
|
|
|
func (b *Buffer) startCleanupLoop() {
|
|
ticker := b.opts.Clock.NewTicker(b.opts.ClosedEpisodeRetention, "message-part-buffer", "cleanup")
|
|
go func() {
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
b.mu.Lock()
|
|
if b.closed {
|
|
b.mu.Unlock()
|
|
return
|
|
}
|
|
b.gcClosedEpisodesLocked(b.opts.Clock.Now("message-part-buffer", "cleanup"))
|
|
b.mu.Unlock()
|
|
case <-b.done:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// gcClosedEpisodesLocked removes closed episodes whose closedAt is at or
// before now minus Options.ClosedEpisodeRetention. Entries are popped from
// the closedEpisodes min-heap in closedAt order, stopping at the first entry
// still within retention. An expired episode that still has subscribers is
// kept and re-queued with its original closedAt, so a later pass removes it
// once the subscribers are gone. The caller must hold b.mu.
func (b *Buffer) gcClosedEpisodesLocked(now time.Time) {
	cutoff := now.Add(-b.opts.ClosedEpisodeRetention)
	type retainedEpisode struct {
		key     Key
		episode *episodeState
	}
	// Expired episodes that cannot be dropped yet because subscribers are
	// attached. They are re-queued after the loop; re-queuing inside it would
	// make this pass pop them again, since their closedAt is already expired.
	retained := make([]retainedEpisode, 0)
	for b.closedEpisodes.Len() > 0 {
		item := b.closedEpisodes[0]
		// Min-heap root is the oldest entry: if it is still within
		// retention, so is everything else.
		if item.closedAt.After(cutoff) {
			break
		}
		// Defensive: heap.Pop returns the root, which is item.
		popped, ok := heap.Pop(&b.closedEpisodes).(*closedEpisodeItem)
		if !ok || popped != item {
			continue
		}
		// Skip entries that no longer match the episode's current heap item
		// (episode gone, not closed, or tracked by a different item).
		episode := b.episodes[item.key]
		if episode == nil || episode.closedHeapItem != item || !episode.closed {
			continue
		}
		episode.closedHeapItem = nil
		if len(episode.subscribers) > 0 {
			retained = append(retained, retainedEpisode{key: item.key, episode: episode})
			continue
		}
		delete(b.episodes, item.key)
	}
	for _, item := range retained {
		if b.episodes[item.key] != item.episode || !item.episode.closed || item.episode.closedHeapItem != nil {
			continue
		}
		b.queueClosedEpisodeLocked(item.key, item.episode)
	}
}
|
|
|
|
func (b *Buffer) queueClosedEpisodeLocked(key Key, episode *episodeState) {
|
|
if episode.closedHeapItem != nil {
|
|
return
|
|
}
|
|
item := &closedEpisodeItem{key: key, closedAt: episode.closedAt}
|
|
episode.closedHeapItem = item
|
|
heap.Push(&b.closedEpisodes, item)
|
|
}
|
|
|
|
func (b *Buffer) episodeLocked(key Key) *episodeState {
|
|
episode := b.episodes[key]
|
|
if episode != nil {
|
|
return episode
|
|
}
|
|
episode = &episodeState{createdCh: make(chan struct{})}
|
|
b.episodes[key] = episode
|
|
return episode
|
|
}
|
|
|
|
func (b *Buffer) subscriberParts(key Key, subscriber *episodeSubscriber) (parts []Part, closed bool, ok bool) {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
if b.closed {
|
|
return nil, false, false
|
|
}
|
|
episode := b.episodes[key]
|
|
if episode == nil {
|
|
return nil, false, false
|
|
}
|
|
if !episode.created {
|
|
return nil, false, true
|
|
}
|
|
if subscriber.next > len(episode.parts) {
|
|
return nil, false, false
|
|
}
|
|
parts = append([]Part(nil), episode.parts[subscriber.next:]...)
|
|
subscriber.next = len(episode.parts)
|
|
return parts, episode.closed && subscriber.next == len(episode.parts), true
|
|
}
|
|
|
|
func (b *Buffer) deliverSubscriber(ctx context.Context, key Key, subscriber *episodeSubscriber) {
|
|
defer close(subscriber.out)
|
|
defer b.removeSubscriber(key, subscriber)
|
|
for {
|
|
parts, closed, ok := b.subscriberParts(key, subscriber)
|
|
if !ok {
|
|
return
|
|
}
|
|
for _, part := range parts {
|
|
if !b.sendSubscriberPart(ctx, subscriber, part) {
|
|
return
|
|
}
|
|
}
|
|
if closed {
|
|
return
|
|
}
|
|
select {
|
|
case <-subscriber.notifyCh:
|
|
case <-subscriber.stopCh:
|
|
return
|
|
case <-ctx.Done():
|
|
return
|
|
case <-b.done:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (b *Buffer) sendSubscriberPart(ctx context.Context, subscriber *episodeSubscriber, part Part) bool {
|
|
timer := b.opts.Clock.NewTimer(b.opts.SubscriberSendTimeout, "message-part-buffer", "subscriber-send")
|
|
defer timer.Stop()
|
|
select {
|
|
case subscriber.out <- part:
|
|
return true
|
|
case <-timer.C:
|
|
return false
|
|
case <-subscriber.stopCh:
|
|
return false
|
|
case <-ctx.Done():
|
|
return false
|
|
case <-b.done:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (b *Buffer) cancelSubscriber(key Key, subscriber *episodeSubscriber) {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
episode := b.episodes[key]
|
|
if episode != nil {
|
|
b.stopSubscriberLocked(episode, subscriber)
|
|
return
|
|
}
|
|
subscriber.stop()
|
|
}
|
|
|
|
func (b *Buffer) removeSubscriber(key Key, subscriber *episodeSubscriber) {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
episode := b.episodes[key]
|
|
if episode == nil {
|
|
return
|
|
}
|
|
delete(episode.subscribers, subscriber)
|
|
if episode.closed && len(episode.subscribers) == 0 {
|
|
b.queueClosedEpisodeLocked(key, episode)
|
|
}
|
|
}
|
|
|
|
func (*Buffer) stopSubscriberLocked(episode *episodeState, subscriber *episodeSubscriber) {
|
|
delete(episode.subscribers, subscriber)
|
|
subscriber.stop()
|
|
}
|
|
|
|
func notifySubscriber(subscriber *episodeSubscriber) {
|
|
select {
|
|
case subscriber.notifyCh <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
|
|
func (s *episodeSubscriber) stop() {
|
|
s.stopOnce.Do(func() { close(s.stopCh) })
|
|
}
|
|
|
|
func serializedPartBytes(part Part) (int64, error) {
|
|
data, err := json.Marshal(part.jsonValue())
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return int64(len(data)), nil
|
|
}
|