// Files: seriatim/internal/trim/apply.go
// 365 lines, 10 KiB, Go

package trim
import (
"fmt"
"gitea.maximumdirect.net/eric/seriatim/internal/model"
"gitea.maximumdirect.net/eric/seriatim/internal/overlap"
"gitea.maximumdirect.net/eric/seriatim/schema"
)
// Mode names the strategy used to interpret the selector's segment IDs
// when trimming a transcript.
type Mode string

// Supported trim modes.
const (
	// ModeKeep retains only the segments whose IDs are selected.
	ModeKeep = Mode("keep")
	// ModeRemove drops the selected segments and retains all others.
	ModeRemove = Mode("remove")
)
// Options bundles the caller-supplied settings for a trim operation.
type Options struct {
	// Mode chooses whether the selected IDs are kept or removed. It must be
	// ModeKeep or ModeRemove.
	Mode Mode
	// Selector supplies the segment IDs the operation acts on. The Apply
	// functions reject a selector that yields no IDs, or IDs that are absent
	// from the input transcript.
	Selector Selector
	// AllowEmpty permits a result that contains no segments. When false, a
	// trim that keeps nothing is reported as an error.
	AllowEmpty bool
}
// Result is the outcome of Apply: the trimmed full transcript plus
// bookkeeping that relates the original segment IDs to the new ones.
type Result struct {
	// Transcript is the trimmed transcript, with kept segments renumbered
	// from 1 and overlap groups recomputed.
	Transcript schema.Transcript
	// OldToNewID maps each kept segment's original ID to its new ID.
	OldToNewID map[int]int
	// RemovedIDs lists the original IDs of dropped segments, in input order.
	RemovedIDs []int
}
// IntermediateResult is the outcome of ApplyIntermediate: the trimmed
// intermediate-schema transcript plus ID mapping metadata.
type IntermediateResult struct {
	// Transcript is the trimmed transcript, with kept segments renumbered
	// from 1.
	Transcript schema.IntermediateTranscript
	// OldToNewID maps each kept segment's original ID to its new ID.
	OldToNewID map[int]int
	// RemovedIDs lists the original IDs of dropped segments, in input order.
	RemovedIDs []int
}
// MinimalResult is the outcome of ApplyMinimal: the trimmed minimal-schema
// transcript plus ID mapping metadata.
type MinimalResult struct {
	// Transcript is the trimmed transcript, with kept segments renumbered
	// from 1.
	Transcript schema.MinimalTranscript
	// OldToNewID maps each kept segment's original ID to its new ID.
	OldToNewID map[int]int
	// RemovedIDs lists the original IDs of dropped segments, in input order.
	RemovedIDs []int
}
// Apply trims a full seriatim output transcript by segment ID.
//
// Validation runs in a fixed order: the mode must be ModeKeep or ModeRemove,
// the selector must yield at least one ID, the input segment IDs must pass
// validateInputIDs, and every selected ID must be present in the input.
//
// Surviving segments are copied with copySegment, renumbered 1..K in input
// order, and have their overlap groups recomputed from scratch. An error is
// returned when nothing survives unless opts.AllowEmpty is set.
func Apply(input schema.Transcript, opts Options) (Result, error) {
	if err := validateMode(opts.Mode); err != nil {
		return Result{}, err
	}

	wanted := opts.Selector.IDs()
	if len(wanted) == 0 {
		return Result{}, fmt.Errorf("selector cannot be empty")
	}

	total := len(input.Segments)
	ids := make([]int, 0, total)
	for i := range input.Segments {
		ids = append(ids, input.Segments[i].ID)
	}
	known, err := validateInputIDs(ids)
	if err != nil {
		return Result{}, err
	}
	if err := validateSelectedIDsExist(wanted, known); err != nil {
		return Result{}, err
	}

	// The mode was validated above, so it is either keep or remove. A
	// segment is dropped when it is selected in remove mode or unselected in
	// keep mode, i.e. when "selected" equals "remove mode".
	removing := opts.Mode == ModeRemove

	survivors := make([]schema.Segment, 0, total)
	dropped := make([]int, 0, total)
	renumbered := make(map[int]int, total)
	for _, seg := range input.Segments {
		if opts.Selector.Contains(seg.ID) == removing {
			dropped = append(dropped, seg.ID)
			continue
		}

		next := copySegment(seg)
		next.ID = len(survivors) + 1
		// Stale group membership is cleared; it is recomputed below.
		next.OverlapGroupID = 0
		survivors = append(survivors, next)
		renumbered[seg.ID] = next.ID
	}

	if !opts.AllowEmpty && len(survivors) == 0 {
		return Result{}, fmt.Errorf("trim operation produced an empty transcript; set AllowEmpty to true to permit this")
	}

	survivors, groups := recomputeOverlapGroups(survivors)

	trimmed := copyTranscript(input)
	trimmed.Segments = survivors
	trimmed.OverlapGroups = groups

	result := Result{
		Transcript: trimmed,
		OldToNewID: renumbered,
		RemovedIDs: dropped,
	}
	return result, nil
}
// ApplyIntermediate trims an intermediate seriatim output transcript by
// segment ID.
//
// It performs the same validation sequence as the other Apply variants:
// mode, non-empty selector, well-formed input IDs, and existence of every
// selected ID. Surviving segments are rebuilt from their ID, Start, End,
// Speaker, Text, and Categories fields and renumbered 1..K in input order.
// The output metadata carries over Application, Version, and OutputSchema.
// An error is returned when nothing survives unless opts.AllowEmpty is set.
func ApplyIntermediate(input schema.IntermediateTranscript, opts Options) (IntermediateResult, error) {
	if err := validateMode(opts.Mode); err != nil {
		return IntermediateResult{}, err
	}

	wanted := opts.Selector.IDs()
	if len(wanted) == 0 {
		return IntermediateResult{}, fmt.Errorf("selector cannot be empty")
	}

	total := len(input.Segments)
	ids := make([]int, 0, total)
	for i := range input.Segments {
		ids = append(ids, input.Segments[i].ID)
	}
	known, err := validateInputIDs(ids)
	if err != nil {
		return IntermediateResult{}, err
	}
	if err := validateSelectedIDsExist(wanted, known); err != nil {
		return IntermediateResult{}, err
	}

	// The mode was validated above, so a segment is dropped exactly when
	// "selected" equals "remove mode".
	removing := opts.Mode == ModeRemove

	survivors := make([]schema.IntermediateSegment, 0, total)
	dropped := make([]int, 0, total)
	renumbered := make(map[int]int, total)
	for _, seg := range input.Segments {
		if opts.Selector.Contains(seg.ID) == removing {
			dropped = append(dropped, seg.ID)
			continue
		}

		// Categories is cloned so the result does not share a backing array
		// with the input.
		categories := append([]string(nil), seg.Categories...)
		next := schema.IntermediateSegment{
			ID:         len(survivors) + 1,
			Start:      seg.Start,
			End:        seg.End,
			Speaker:    seg.Speaker,
			Text:       seg.Text,
			Categories: categories,
		}
		survivors = append(survivors, next)
		renumbered[seg.ID] = next.ID
	}

	if !opts.AllowEmpty && len(survivors) == 0 {
		return IntermediateResult{}, fmt.Errorf("trim operation produced an empty transcript; set AllowEmpty to true to permit this")
	}

	metadata := schema.IntermediateMetadata{
		Application:  input.Metadata.Application,
		Version:      input.Metadata.Version,
		OutputSchema: input.Metadata.OutputSchema,
	}
	result := IntermediateResult{
		Transcript: schema.IntermediateTranscript{
			Metadata: metadata,
			Segments: survivors,
		},
		OldToNewID: renumbered,
		RemovedIDs: dropped,
	}
	return result, nil
}
// ApplyMinimal trims a minimal seriatim output transcript by segment ID.
//
// It performs the same validation sequence as the other Apply variants:
// mode, non-empty selector, well-formed input IDs, and existence of every
// selected ID. Surviving segments are rebuilt from their ID, Start, End,
// Speaker, and Text fields and renumbered 1..K in input order. The output
// metadata carries over Application, Version, and OutputSchema. An error is
// returned when nothing survives unless opts.AllowEmpty is set.
func ApplyMinimal(input schema.MinimalTranscript, opts Options) (MinimalResult, error) {
	if err := validateMode(opts.Mode); err != nil {
		return MinimalResult{}, err
	}

	wanted := opts.Selector.IDs()
	if len(wanted) == 0 {
		return MinimalResult{}, fmt.Errorf("selector cannot be empty")
	}

	total := len(input.Segments)
	ids := make([]int, 0, total)
	for i := range input.Segments {
		ids = append(ids, input.Segments[i].ID)
	}
	known, err := validateInputIDs(ids)
	if err != nil {
		return MinimalResult{}, err
	}
	if err := validateSelectedIDsExist(wanted, known); err != nil {
		return MinimalResult{}, err
	}

	// The mode was validated above, so a segment is dropped exactly when
	// "selected" equals "remove mode".
	removing := opts.Mode == ModeRemove

	survivors := make([]schema.MinimalSegment, 0, total)
	dropped := make([]int, 0, total)
	renumbered := make(map[int]int, total)
	for _, seg := range input.Segments {
		if opts.Selector.Contains(seg.ID) == removing {
			dropped = append(dropped, seg.ID)
			continue
		}

		next := schema.MinimalSegment{
			ID:      len(survivors) + 1,
			Start:   seg.Start,
			End:     seg.End,
			Speaker: seg.Speaker,
			Text:    seg.Text,
		}
		survivors = append(survivors, next)
		renumbered[seg.ID] = next.ID
	}

	if !opts.AllowEmpty && len(survivors) == 0 {
		return MinimalResult{}, fmt.Errorf("trim operation produced an empty transcript; set AllowEmpty to true to permit this")
	}

	metadata := schema.MinimalMetadata{
		Application:  input.Metadata.Application,
		Version:      input.Metadata.Version,
		OutputSchema: input.Metadata.OutputSchema,
	}
	result := MinimalResult{
		Transcript: schema.MinimalTranscript{
			Metadata: metadata,
			Segments: survivors,
		},
		OldToNewID: renumbered,
		RemovedIDs: dropped,
	}
	return result, nil
}
// validateMode reports an error unless mode is one of the supported trim
// modes (ModeKeep or ModeRemove).
func validateMode(mode Mode) error {
	if mode == ModeKeep || mode == ModeRemove {
		return nil
	}
	return fmt.Errorf("invalid trim mode %q", mode)
}
// validateInputIDs checks that ids contains every value in 1..len(ids)
// exactly once and returns a map from each ID to its index in ids.
//
// The first pass rejects, in index order, any non-positive ID and any ID
// that repeats an earlier one. The second pass reports the smallest ID in
// 1..len(ids) that is absent. The order of the IDs within the slice is not
// checked. On error the returned map is nil.
func validateInputIDs(ids []int) (map[int]int, error) {
	positions := make(map[int]int, len(ids))
	for pos, id := range ids {
		if id < 1 {
			return nil, fmt.Errorf("input transcript has non-positive segment ID %d at index %d", id, pos)
		}
		earlier, duplicate := positions[id]
		if duplicate {
			return nil, fmt.Errorf("input transcript has duplicate segment ID %d at indexes %d and %d", id, earlier, pos)
		}
		positions[id] = pos
	}

	count := len(ids)
	for want := 1; want <= count; want++ {
		if _, present := positions[want]; present {
			continue
		}
		return nil, fmt.Errorf("input transcript segment IDs must be sequential 1..%d; missing ID %d", count, want)
	}
	return positions, nil
}
// validateSelectedIDsExist returns an error naming the first entry of
// selected that is not a key of idIndex, or nil when every entry is present.
func validateSelectedIDsExist(selected []int, idIndex map[int]int) error {
	for i := range selected {
		if _, known := idIndex[selected[i]]; known {
			continue
		}
		return fmt.Errorf("selected segment ID %d does not exist in input transcript", selected[i])
	}
	return nil
}
// recomputeOverlapGroups rebuilds overlap-group information for segments.
//
// The segments are converted to model.Segment values, passed to
// overlap.Detect, and the resulting OverlapGroupID of each detected segment
// is written onto a fresh copy of the corresponding input segment. The
// detected groups are converted back to schema.OverlapGroup values.
//
// An empty input is returned unchanged together with a nil group slice.
//
// NOTE(review): the detected segments are matched to the inputs by index,
// which assumes overlap.Detect returns segments in input order and count.
// Confirm against the overlap package.
func recomputeOverlapGroups(segments []schema.Segment) ([]schema.Segment, []schema.OverlapGroup) {
	count := len(segments)
	if count == 0 {
		return segments, nil
	}

	converted := make([]model.Segment, 0, count)
	for _, seg := range segments {
		converted = append(converted, model.Segment{
			ID:                 seg.ID,
			Source:             seg.Source,
			SourceSegmentIndex: copyIntPtr(seg.SourceSegmentIndex),
			SourceRef:          seg.SourceRef,
			DerivedFrom:        append([]string(nil), seg.DerivedFrom...),
			Speaker:            seg.Speaker,
			Start:              seg.Start,
			End:                seg.End,
			Text:               seg.Text,
			Categories:         append([]string(nil), seg.Categories...),
			OverlapGroupID:     seg.OverlapGroupID,
		})
	}

	detected := overlap.Detect(model.MergedTranscript{
		Segments: converted,
	})

	updated := make([]schema.Segment, 0, count)
	for i, seg := range segments {
		clone := copySegment(seg)
		clone.OverlapGroupID = detected.Segments[i].OverlapGroupID
		updated = append(updated, clone)
	}

	groups := make([]schema.OverlapGroup, 0, len(detected.OverlapGroups))
	for _, found := range detected.OverlapGroups {
		groups = append(groups, schema.OverlapGroup{
			ID:         found.ID,
			Start:      found.Start,
			End:        found.End,
			Segments:   append([]string(nil), found.Segments...),
			Speakers:   append([]string(nil), found.Speakers...),
			Class:      found.Class,
			Resolution: found.Resolution,
		})
	}
	return updated, groups
}
// copyTranscript returns a copy of input whose metadata string slices and
// top-level Segments and OverlapGroups slices have their own backing arrays.
//
// The segment and overlap-group elements themselves are copied by value, so
// any slices or pointers inside them still refer to the input's data.
func copyTranscript(input schema.Transcript) schema.Transcript {
	source := input.Metadata
	metadata := schema.Metadata{
		Application:           source.Application,
		Version:               source.Version,
		InputReader:           source.InputReader,
		InputFiles:            append([]string(nil), source.InputFiles...),
		PreprocessingModules:  append([]string(nil), source.PreprocessingModules...),
		PostprocessingModules: append([]string(nil), source.PostprocessingModules...),
		OutputModules:         append([]string(nil), source.OutputModules...),
	}

	var out schema.Transcript
	out.Metadata = metadata
	out.Segments = append([]schema.Segment(nil), input.Segments...)
	out.OverlapGroups = append([]schema.OverlapGroup(nil), input.OverlapGroups...)
	return out
}
// copySegment returns a copy of input that shares no slice backing arrays or
// pointer targets with it for the fields listed below: DerivedFrom and
// Categories are cloned and SourceSegmentIndex is re-allocated.
//
// Fields are assigned one by one rather than copying the whole struct, so
// only the fields named here are carried over.
func copySegment(input schema.Segment) schema.Segment {
	var out schema.Segment
	out.ID = input.ID
	out.Source = input.Source
	out.SourceSegmentIndex = copyIntPtr(input.SourceSegmentIndex)
	out.SourceRef = input.SourceRef
	out.DerivedFrom = append([]string(nil), input.DerivedFrom...)
	out.Speaker = input.Speaker
	out.Start = input.Start
	out.End = input.End
	out.Text = input.Text
	out.Categories = append([]string(nil), input.Categories...)
	out.OverlapGroupID = input.OverlapGroupID
	return out
}
// copyIntPtr returns a pointer to a newly allocated int holding the value
// that value points to, or nil when value is nil.
func copyIntPtr(value *int) *int {
	if value == nil {
		return nil
	}
	clone := new(int)
	*clone = *value
	return clone
}