157 lines
4.2 KiB
Go
157 lines
4.2 KiB
Go
package trim
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"gitea.maximumdirect.net/eric/seriatim/schema"
|
|
)
|
|
|
|
// Mode selects how Apply interprets the segment IDs named by a Selector.
type Mode string

// ModeKeep retains only the selected segments and drops every other one.
const ModeKeep Mode = "keep"

// ModeRemove drops the selected segments and retains every other one.
const ModeRemove Mode = "remove"
|
|
|
|
// Options configures transcript trimming.
type Options struct {
	// Mode must be ModeKeep or ModeRemove; Apply rejects any other value.
	Mode Mode

	// Selector names the segment IDs that Mode applies to. Apply requires
	// it to name at least one ID, and every named ID must exist in the
	// input transcript.
	Selector Selector

	// AllowEmpty permits Apply to return a transcript with no segments.
	// When false, a trim that would drop every segment is an error.
	AllowEmpty bool
}
|
|
|
|
// Result contains trimming output and ID mapping metadata.
type Result struct {
	// Transcript is the trimmed transcript. Its segments keep their order
	// from the input and are renumbered 1..n. Overlap groups are not
	// carried over: OverlapGroups is nil and every segment's OverlapGroupID
	// is 0.
	Transcript schema.Transcript

	// OldToNewID maps the original ID of each retained segment to its new
	// ID. Removed segments have no entry.
	OldToNewID map[int]int

	// RemovedIDs lists the original IDs of the dropped segments, in the
	// order they appeared in the input.
	RemovedIDs []int
}
|
|
|
|
// Apply trims a full seriatim output transcript by segment ID.
|
|
func Apply(input schema.Transcript, opts Options) (Result, error) {
|
|
if err := validateMode(opts.Mode); err != nil {
|
|
return Result{}, err
|
|
}
|
|
|
|
selected := opts.Selector.IDs()
|
|
if len(selected) == 0 {
|
|
return Result{}, fmt.Errorf("selector cannot be empty")
|
|
}
|
|
|
|
idIndex, err := validateInputIDs(input.Segments)
|
|
if err != nil {
|
|
return Result{}, err
|
|
}
|
|
|
|
for _, id := range selected {
|
|
if _, exists := idIndex[id]; !exists {
|
|
return Result{}, fmt.Errorf("selected segment ID %d does not exist in input transcript", id)
|
|
}
|
|
}
|
|
|
|
kept := make([]schema.Segment, 0, len(input.Segments))
|
|
removed := make([]int, 0, len(input.Segments))
|
|
oldToNew := make(map[int]int, len(input.Segments))
|
|
for _, segment := range input.Segments {
|
|
keep := opts.Mode == ModeKeep && opts.Selector.Contains(segment.ID)
|
|
if opts.Mode == ModeRemove {
|
|
keep = !opts.Selector.Contains(segment.ID)
|
|
}
|
|
|
|
if !keep {
|
|
removed = append(removed, segment.ID)
|
|
continue
|
|
}
|
|
|
|
rewritten := copySegment(segment)
|
|
rewritten.ID = len(kept) + 1
|
|
rewritten.OverlapGroupID = 0
|
|
kept = append(kept, rewritten)
|
|
oldToNew[segment.ID] = rewritten.ID
|
|
}
|
|
|
|
if len(kept) == 0 && !opts.AllowEmpty {
|
|
return Result{}, fmt.Errorf("trim operation produced an empty transcript; set AllowEmpty to true to permit this")
|
|
}
|
|
|
|
out := copyTranscript(input)
|
|
out.Segments = kept
|
|
out.OverlapGroups = nil
|
|
return Result{
|
|
Transcript: out,
|
|
OldToNewID: oldToNew,
|
|
RemovedIDs: removed,
|
|
}, nil
|
|
}
|
|
|
|
func validateMode(mode Mode) error {
|
|
switch mode {
|
|
case ModeKeep, ModeRemove:
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("invalid trim mode %q", mode)
|
|
}
|
|
}
|
|
|
|
func validateInputIDs(segments []schema.Segment) (map[int]int, error) {
|
|
seen := make(map[int]int, len(segments))
|
|
for index, segment := range segments {
|
|
id := segment.ID
|
|
if id <= 0 {
|
|
return nil, fmt.Errorf("input transcript has non-positive segment ID %d at index %d", id, index)
|
|
}
|
|
if firstIndex, exists := seen[id]; exists {
|
|
return nil, fmt.Errorf("input transcript has duplicate segment ID %d at indexes %d and %d", id, firstIndex, index)
|
|
}
|
|
seen[id] = index
|
|
}
|
|
|
|
for id := 1; id <= len(segments); id++ {
|
|
if _, exists := seen[id]; !exists {
|
|
return nil, fmt.Errorf("input transcript segment IDs must be sequential 1..%d; missing ID %d", len(segments), id)
|
|
}
|
|
}
|
|
return seen, nil
|
|
}
|
|
|
|
func copyTranscript(input schema.Transcript) schema.Transcript {
|
|
return schema.Transcript{
|
|
Metadata: schema.Metadata{
|
|
Application: input.Metadata.Application,
|
|
Version: input.Metadata.Version,
|
|
InputReader: input.Metadata.InputReader,
|
|
InputFiles: append([]string(nil), input.Metadata.InputFiles...),
|
|
PreprocessingModules: append([]string(nil), input.Metadata.PreprocessingModules...),
|
|
PostprocessingModules: append([]string(nil), input.Metadata.PostprocessingModules...),
|
|
OutputModules: append([]string(nil), input.Metadata.OutputModules...),
|
|
},
|
|
Segments: append([]schema.Segment(nil), input.Segments...),
|
|
OverlapGroups: append([]schema.OverlapGroup(nil), input.OverlapGroups...),
|
|
}
|
|
}
|
|
|
|
func copySegment(input schema.Segment) schema.Segment {
|
|
return schema.Segment{
|
|
ID: input.ID,
|
|
Source: input.Source,
|
|
SourceSegmentIndex: copyIntPtr(input.SourceSegmentIndex),
|
|
SourceRef: input.SourceRef,
|
|
DerivedFrom: append([]string(nil), input.DerivedFrom...),
|
|
Speaker: input.Speaker,
|
|
Start: input.Start,
|
|
End: input.End,
|
|
Text: input.Text,
|
|
Categories: append([]string(nil), input.Categories...),
|
|
OverlapGroupID: input.OverlapGroupID,
|
|
}
|
|
}
|
|
|
|
// copyIntPtr returns a pointer to a freshly allocated int holding *value,
// so the result never aliases value. A nil value yields nil.
func copyIntPtr(value *int) *int {
	if value != nil {
		dup := new(int)
		*dup = *value
		return dup
	}
	return nil
}
|