Recompute overlap groups during trim
This commit is contained in:
@@ -3,6 +3,8 @@ package trim
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/model"
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/overlap"
|
||||
"gitea.maximumdirect.net/eric/seriatim/schema"
|
||||
)
|
||||
|
||||
@@ -28,6 +30,20 @@ type Result struct {
|
||||
RemovedIDs []int
|
||||
}
|
||||
|
||||
// IntermediateResult contains trimming output for intermediate schema artifacts.
|
||||
type IntermediateResult struct {
|
||||
Transcript schema.IntermediateTranscript
|
||||
OldToNewID map[int]int
|
||||
RemovedIDs []int
|
||||
}
|
||||
|
||||
// MinimalResult contains trimming output for minimal schema artifacts.
|
||||
type MinimalResult struct {
|
||||
Transcript schema.MinimalTranscript
|
||||
OldToNewID map[int]int
|
||||
RemovedIDs []int
|
||||
}
|
||||
|
||||
// Apply trims a full seriatim output transcript by segment ID.
|
||||
func Apply(input schema.Transcript, opts Options) (Result, error) {
|
||||
if err := validateMode(opts.Mode); err != nil {
|
||||
@@ -39,15 +55,18 @@ func Apply(input schema.Transcript, opts Options) (Result, error) {
|
||||
return Result{}, fmt.Errorf("selector cannot be empty")
|
||||
}
|
||||
|
||||
idIndex, err := validateInputIDs(input.Segments)
|
||||
inputIDs := make([]int, len(input.Segments))
|
||||
for index, segment := range input.Segments {
|
||||
inputIDs[index] = segment.ID
|
||||
}
|
||||
|
||||
idIndex, err := validateInputIDs(inputIDs)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
for _, id := range selected {
|
||||
if _, exists := idIndex[id]; !exists {
|
||||
return Result{}, fmt.Errorf("selected segment ID %d does not exist in input transcript", id)
|
||||
}
|
||||
if err := validateSelectedIDsExist(selected, idIndex); err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
kept := make([]schema.Segment, 0, len(input.Segments))
|
||||
@@ -75,9 +94,11 @@ func Apply(input schema.Transcript, opts Options) (Result, error) {
|
||||
return Result{}, fmt.Errorf("trim operation produced an empty transcript; set AllowEmpty to true to permit this")
|
||||
}
|
||||
|
||||
kept, groups := recomputeOverlapGroups(kept)
|
||||
|
||||
out := copyTranscript(input)
|
||||
out.Segments = kept
|
||||
out.OverlapGroups = nil
|
||||
out.OverlapGroups = groups
|
||||
return Result{
|
||||
Transcript: out,
|
||||
OldToNewID: oldToNew,
|
||||
@@ -85,6 +106,138 @@ func Apply(input schema.Transcript, opts Options) (Result, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ApplyIntermediate trims an intermediate seriatim output transcript by
|
||||
// segment ID.
|
||||
func ApplyIntermediate(input schema.IntermediateTranscript, opts Options) (IntermediateResult, error) {
|
||||
if err := validateMode(opts.Mode); err != nil {
|
||||
return IntermediateResult{}, err
|
||||
}
|
||||
|
||||
selected := opts.Selector.IDs()
|
||||
if len(selected) == 0 {
|
||||
return IntermediateResult{}, fmt.Errorf("selector cannot be empty")
|
||||
}
|
||||
|
||||
inputIDs := make([]int, len(input.Segments))
|
||||
for index, segment := range input.Segments {
|
||||
inputIDs[index] = segment.ID
|
||||
}
|
||||
idIndex, err := validateInputIDs(inputIDs)
|
||||
if err != nil {
|
||||
return IntermediateResult{}, err
|
||||
}
|
||||
if err := validateSelectedIDsExist(selected, idIndex); err != nil {
|
||||
return IntermediateResult{}, err
|
||||
}
|
||||
|
||||
kept := make([]schema.IntermediateSegment, 0, len(input.Segments))
|
||||
removed := make([]int, 0, len(input.Segments))
|
||||
oldToNew := make(map[int]int, len(input.Segments))
|
||||
for _, segment := range input.Segments {
|
||||
keep := opts.Mode == ModeKeep && opts.Selector.Contains(segment.ID)
|
||||
if opts.Mode == ModeRemove {
|
||||
keep = !opts.Selector.Contains(segment.ID)
|
||||
}
|
||||
if !keep {
|
||||
removed = append(removed, segment.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
rewritten := schema.IntermediateSegment{
|
||||
ID: len(kept) + 1,
|
||||
Start: segment.Start,
|
||||
End: segment.End,
|
||||
Speaker: segment.Speaker,
|
||||
Text: segment.Text,
|
||||
Categories: append([]string(nil), segment.Categories...),
|
||||
}
|
||||
kept = append(kept, rewritten)
|
||||
oldToNew[segment.ID] = rewritten.ID
|
||||
}
|
||||
|
||||
if len(kept) == 0 && !opts.AllowEmpty {
|
||||
return IntermediateResult{}, fmt.Errorf("trim operation produced an empty transcript; set AllowEmpty to true to permit this")
|
||||
}
|
||||
|
||||
return IntermediateResult{
|
||||
Transcript: schema.IntermediateTranscript{
|
||||
Metadata: schema.IntermediateMetadata{
|
||||
Application: input.Metadata.Application,
|
||||
Version: input.Metadata.Version,
|
||||
OutputSchema: input.Metadata.OutputSchema,
|
||||
},
|
||||
Segments: kept,
|
||||
},
|
||||
OldToNewID: oldToNew,
|
||||
RemovedIDs: removed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ApplyMinimal trims a minimal seriatim output transcript by segment ID.
|
||||
func ApplyMinimal(input schema.MinimalTranscript, opts Options) (MinimalResult, error) {
|
||||
if err := validateMode(opts.Mode); err != nil {
|
||||
return MinimalResult{}, err
|
||||
}
|
||||
|
||||
selected := opts.Selector.IDs()
|
||||
if len(selected) == 0 {
|
||||
return MinimalResult{}, fmt.Errorf("selector cannot be empty")
|
||||
}
|
||||
|
||||
inputIDs := make([]int, len(input.Segments))
|
||||
for index, segment := range input.Segments {
|
||||
inputIDs[index] = segment.ID
|
||||
}
|
||||
idIndex, err := validateInputIDs(inputIDs)
|
||||
if err != nil {
|
||||
return MinimalResult{}, err
|
||||
}
|
||||
if err := validateSelectedIDsExist(selected, idIndex); err != nil {
|
||||
return MinimalResult{}, err
|
||||
}
|
||||
|
||||
kept := make([]schema.MinimalSegment, 0, len(input.Segments))
|
||||
removed := make([]int, 0, len(input.Segments))
|
||||
oldToNew := make(map[int]int, len(input.Segments))
|
||||
for _, segment := range input.Segments {
|
||||
keep := opts.Mode == ModeKeep && opts.Selector.Contains(segment.ID)
|
||||
if opts.Mode == ModeRemove {
|
||||
keep = !opts.Selector.Contains(segment.ID)
|
||||
}
|
||||
if !keep {
|
||||
removed = append(removed, segment.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
rewritten := schema.MinimalSegment{
|
||||
ID: len(kept) + 1,
|
||||
Start: segment.Start,
|
||||
End: segment.End,
|
||||
Speaker: segment.Speaker,
|
||||
Text: segment.Text,
|
||||
}
|
||||
kept = append(kept, rewritten)
|
||||
oldToNew[segment.ID] = rewritten.ID
|
||||
}
|
||||
|
||||
if len(kept) == 0 && !opts.AllowEmpty {
|
||||
return MinimalResult{}, fmt.Errorf("trim operation produced an empty transcript; set AllowEmpty to true to permit this")
|
||||
}
|
||||
|
||||
return MinimalResult{
|
||||
Transcript: schema.MinimalTranscript{
|
||||
Metadata: schema.MinimalMetadata{
|
||||
Application: input.Metadata.Application,
|
||||
Version: input.Metadata.Version,
|
||||
OutputSchema: input.Metadata.OutputSchema,
|
||||
},
|
||||
Segments: kept,
|
||||
},
|
||||
OldToNewID: oldToNew,
|
||||
RemovedIDs: removed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func validateMode(mode Mode) error {
|
||||
switch mode {
|
||||
case ModeKeep, ModeRemove:
|
||||
@@ -94,10 +247,9 @@ func validateMode(mode Mode) error {
|
||||
}
|
||||
}
|
||||
|
||||
func validateInputIDs(segments []schema.Segment) (map[int]int, error) {
|
||||
seen := make(map[int]int, len(segments))
|
||||
for index, segment := range segments {
|
||||
id := segment.ID
|
||||
func validateInputIDs(ids []int) (map[int]int, error) {
|
||||
seen := make(map[int]int, len(ids))
|
||||
for index, id := range ids {
|
||||
if id <= 0 {
|
||||
return nil, fmt.Errorf("input transcript has non-positive segment ID %d at index %d", id, index)
|
||||
}
|
||||
@@ -107,14 +259,70 @@ func validateInputIDs(segments []schema.Segment) (map[int]int, error) {
|
||||
seen[id] = index
|
||||
}
|
||||
|
||||
for id := 1; id <= len(segments); id++ {
|
||||
for id := 1; id <= len(ids); id++ {
|
||||
if _, exists := seen[id]; !exists {
|
||||
return nil, fmt.Errorf("input transcript segment IDs must be sequential 1..%d; missing ID %d", len(segments), id)
|
||||
return nil, fmt.Errorf("input transcript segment IDs must be sequential 1..%d; missing ID %d", len(ids), id)
|
||||
}
|
||||
}
|
||||
return seen, nil
|
||||
}
|
||||
|
||||
func validateSelectedIDsExist(selected []int, idIndex map[int]int) error {
|
||||
for _, id := range selected {
|
||||
if _, exists := idIndex[id]; !exists {
|
||||
return fmt.Errorf("selected segment ID %d does not exist in input transcript", id)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func recomputeOverlapGroups(segments []schema.Segment) ([]schema.Segment, []schema.OverlapGroup) {
|
||||
if len(segments) == 0 {
|
||||
return segments, nil
|
||||
}
|
||||
|
||||
modelSegments := make([]model.Segment, len(segments))
|
||||
for index, segment := range segments {
|
||||
modelSegments[index] = model.Segment{
|
||||
ID: segment.ID,
|
||||
Source: segment.Source,
|
||||
SourceSegmentIndex: copyIntPtr(segment.SourceSegmentIndex),
|
||||
SourceRef: segment.SourceRef,
|
||||
DerivedFrom: append([]string(nil), segment.DerivedFrom...),
|
||||
Speaker: segment.Speaker,
|
||||
Start: segment.Start,
|
||||
End: segment.End,
|
||||
Text: segment.Text,
|
||||
Categories: append([]string(nil), segment.Categories...),
|
||||
OverlapGroupID: segment.OverlapGroupID,
|
||||
}
|
||||
}
|
||||
|
||||
detected := overlap.Detect(model.MergedTranscript{
|
||||
Segments: modelSegments,
|
||||
})
|
||||
rewrittenSegments := make([]schema.Segment, len(segments))
|
||||
for index, segment := range segments {
|
||||
rewritten := copySegment(segment)
|
||||
rewritten.OverlapGroupID = detected.Segments[index].OverlapGroupID
|
||||
rewrittenSegments[index] = rewritten
|
||||
}
|
||||
|
||||
groups := make([]schema.OverlapGroup, len(detected.OverlapGroups))
|
||||
for index, group := range detected.OverlapGroups {
|
||||
groups[index] = schema.OverlapGroup{
|
||||
ID: group.ID,
|
||||
Start: group.Start,
|
||||
End: group.End,
|
||||
Segments: append([]string(nil), group.Segments...),
|
||||
Speakers: append([]string(nil), group.Speakers...),
|
||||
Class: group.Class,
|
||||
Resolution: group.Resolution,
|
||||
}
|
||||
}
|
||||
return rewrittenSegments, groups
|
||||
}
|
||||
|
||||
func copyTranscript(input schema.Transcript) schema.Transcript {
|
||||
return schema.Transcript{
|
||||
Metadata: schema.Metadata{
|
||||
|
||||
Reference in New Issue
Block a user