refactor: make ChatMessagePart a discriminated union in TypeScript (#23168)

The flat ChatMessagePart interface had 20+ optional fields, preventing
TypeScript from narrowing types on switch(part.type). Each consumer
needed runtime validation, type assertions, or defensive ?. chains.

Add `variants` struct tags to ChatMessagePart fields declaring which
union variants include each field. A codegen mutation in apitypings
reads these tags via reflect and generates per-variant sub-interfaces
(ChatTextPart, ChatReasoningPart, etc.) plus a union type alias.
A test validates every field has a variants tag or is explicitly
excluded, and every part type is covered.

Remove dead frontend code: normalizeBlockType, alias case branches
("thinking", "toolcall", "toolresult"), legacy field fallbacks
(line_number, typedBlock.name/id/input/output), and result_delta
handling. Add test coverage for args_delta streaming, provider_executed
skip logic, and source part parsing.
This commit is contained in:
Mathias Fredriksson
2026-03-18 11:27:51 +02:00
committed by GitHub
parent 563c00fb2c
commit 66f809388e
11 changed files with 568 additions and 196 deletions
+167
View File
@@ -3,9 +3,12 @@ package main
import (
"fmt"
"log"
"reflect"
"strings"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/guts"
"github.com/coder/guts/bindings"
"github.com/coder/guts/config"
@@ -74,6 +77,7 @@ func TSMutations(ts *guts.Typescript) {
// of referencing maps that are actually null.
config.NotNullMaps,
FixSerpentStruct,
DiscriminatedChatMessagePart,
// Prefer enums as types
config.EnumAsTypes,
// Enum list generator
@@ -142,6 +146,169 @@ func TypeMappings(gen *guts.GoParser) error {
return nil
}
// DiscriminatedChatMessagePart splits the flat ChatMessagePart
// interface into a discriminated union of per-type sub-interfaces.
// Each sub-interface narrows the `type` field to a string literal
// and includes only the fields relevant to that part type.
//
// Variant membership is declared via `variants` struct tags on
// ChatMessagePart fields in codersdk/chats.go. This function
// reads those tags via reflect and builds the union from them.
func DiscriminatedChatMessagePart(ts *guts.Typescript) {
node, ok := ts.Node("ChatMessagePart")
if !ok {
return
}
iface, ok := node.(*bindings.Interface)
if !ok {
return
}
// Build a lookup from field name to its PropertySignature so
// we can copy type information from the original interface.
fieldMap := make(map[string]*bindings.PropertySignature, len(iface.Fields))
for _, f := range iface.Fields {
fieldMap[f.Name] = f
}
// copyField copies a field from the original interface into a
// sub-interface, setting QuestionToken based on whether the
// field is required for that variant.
copyField := func(name string, required bool) *bindings.PropertySignature {
orig, exists := fieldMap[name]
if !exists {
return nil
}
return &bindings.PropertySignature{
Name: orig.Name,
Modifiers: orig.Modifiers,
QuestionToken: !required,
Type: orig.Type,
SupportComments: orig.SupportComments,
}
}
variants := parseVariantTags()
unionMembers := make([]bindings.ExpressionType, 0, len(variants))
for _, v := range variants {
fields := make([]*bindings.PropertySignature, 0, 1+len(v.required)+len(v.optional))
// Discriminant field: type narrowed to a string literal.
fields = append(fields, &bindings.PropertySignature{
Name: "type",
Type: &bindings.LiteralType{Value: string(v.typeLiteral)},
})
for _, name := range v.required {
if f := copyField(name, true); f != nil {
fields = append(fields, f)
}
}
for _, name := range v.optional {
if f := copyField(name, false); f != nil {
fields = append(fields, f)
}
}
tsName := chatMessagePartTSName(v.typeLiteral)
subIface := &bindings.Interface{
Name: bindings.Identifier{
Name: tsName,
Package: iface.Name.Package,
Prefix: iface.Name.Prefix,
},
Fields: fields,
Source: iface.Source,
}
// Inject the sub-interface as a new top-level type.
if err := ts.SetNode(tsName, subIface); err != nil {
panic(fmt.Sprintf("ChatMessagePart variant %q: %v", v.typeLiteral, err))
}
unionMembers = append(unionMembers, bindings.Reference(bindings.Identifier{
Name: tsName,
Package: iface.Name.Package,
Prefix: iface.Name.Prefix,
}))
}
// Replace the original flat interface with a union alias.
ts.ReplaceNode("ChatMessagePart", &bindings.Alias{
Name: iface.Name,
Modifiers: iface.Modifiers,
Type: bindings.Union(unionMembers...),
SupportComments: iface.SupportComments,
Source: iface.Source,
})
}
// chatPartVariant holds the parsed variant info for one part type.
type chatPartVariant struct {
typeLiteral codersdk.ChatMessagePartType
required []string // JSON field names
optional []string // JSON field names
}
// parseVariantTags reads `variants` struct tags from ChatMessagePart
// and returns the per-type field sets using JSON tag names. Variants
// are returned in AllChatMessagePartTypes order for stable codegen.
func parseVariantTags() []chatPartVariant {
t := reflect.TypeFor[codersdk.ChatMessagePart]()
type fieldSets struct {
required []string
optional []string
}
byType := make(map[codersdk.ChatMessagePartType]*fieldSets)
for i := range t.NumField() {
f := t.Field(i)
varTag := f.Tag.Get("variants")
if varTag == "" {
continue
}
jsonName, _, _ := strings.Cut(f.Tag.Get("json"), ",")
for entry := range strings.SplitSeq(varTag, ",") {
isOptional := strings.HasSuffix(entry, "?")
typeLit := codersdk.ChatMessagePartType(strings.TrimSuffix(entry, "?"))
if byType[typeLit] == nil {
byType[typeLit] = &fieldSets{}
}
if isOptional {
byType[typeLit].optional = append(byType[typeLit].optional, jsonName)
} else {
byType[typeLit].required = append(byType[typeLit].required, jsonName)
}
}
}
result := make([]chatPartVariant, 0, len(byType))
for _, pt := range codersdk.AllChatMessagePartTypes() {
if fs, ok := byType[pt]; ok {
result = append(result, chatPartVariant{
typeLiteral: pt,
required: fs.required,
optional: fs.optional,
})
}
}
return result
}
// chatMessagePartTSName derives a TypeScript interface name from
// a ChatMessagePartType literal. "tool-call" → "ChatToolCallPart".
func chatMessagePartTSName(t codersdk.ChatMessagePartType) string {
words := strings.Split(string(t), "-")
for i, w := range words {
if len(w) > 0 {
words[i] = strings.ToUpper(w[:1]) + w[1:]
}
}
return "Chat" + strings.Join(words, "") + "Part"
}
// FixSerpentStruct fixes 'serpent.Struct'.
// 'serpent.Struct' overrides the json.Marshal to use the underlying type,
// so the typescript type should be the underlying type.