diff --git a/internal/trim/apply.go b/internal/trim/apply.go index 533fdef..07bd6ef 100644 --- a/internal/trim/apply.go +++ b/internal/trim/apply.go @@ -3,6 +3,8 @@ package trim import ( "fmt" + "gitea.maximumdirect.net/eric/seriatim/internal/model" + "gitea.maximumdirect.net/eric/seriatim/internal/overlap" "gitea.maximumdirect.net/eric/seriatim/schema" ) @@ -28,6 +30,20 @@ type Result struct { RemovedIDs []int } +// IntermediateResult contains trimming output for intermediate schema artifacts. +type IntermediateResult struct { + Transcript schema.IntermediateTranscript + OldToNewID map[int]int + RemovedIDs []int +} + +// MinimalResult contains trimming output for minimal schema artifacts. +type MinimalResult struct { + Transcript schema.MinimalTranscript + OldToNewID map[int]int + RemovedIDs []int +} + // Apply trims a full seriatim output transcript by segment ID. func Apply(input schema.Transcript, opts Options) (Result, error) { if err := validateMode(opts.Mode); err != nil { @@ -39,15 +55,18 @@ func Apply(input schema.Transcript, opts Options) (Result, error) { return Result{}, fmt.Errorf("selector cannot be empty") } - idIndex, err := validateInputIDs(input.Segments) + inputIDs := make([]int, len(input.Segments)) + for index, segment := range input.Segments { + inputIDs[index] = segment.ID + } + + idIndex, err := validateInputIDs(inputIDs) if err != nil { return Result{}, err } - for _, id := range selected { - if _, exists := idIndex[id]; !exists { - return Result{}, fmt.Errorf("selected segment ID %d does not exist in input transcript", id) - } + if err := validateSelectedIDsExist(selected, idIndex); err != nil { + return Result{}, err } kept := make([]schema.Segment, 0, len(input.Segments)) @@ -75,9 +94,11 @@ func Apply(input schema.Transcript, opts Options) (Result, error) { return Result{}, fmt.Errorf("trim operation produced an empty transcript; set AllowEmpty to true to permit this") } + kept, groups := recomputeOverlapGroups(kept) + out := copyTranscript(input) out.Segments = kept - out.OverlapGroups = nil + out.OverlapGroups = groups return Result{ Transcript: out, OldToNewID: oldToNew, @@ -85,6 +106,138 @@ func Apply(input schema.Transcript, opts Options) (Result, error) { }, nil } +// ApplyIntermediate trims an intermediate seriatim output transcript by +// segment ID. +func ApplyIntermediate(input schema.IntermediateTranscript, opts Options) (IntermediateResult, error) { + if err := validateMode(opts.Mode); err != nil { + return IntermediateResult{}, err + } + + selected := opts.Selector.IDs() + if len(selected) == 0 { + return IntermediateResult{}, fmt.Errorf("selector cannot be empty") + } + + inputIDs := make([]int, len(input.Segments)) + for index, segment := range input.Segments { + inputIDs[index] = segment.ID + } + idIndex, err := validateInputIDs(inputIDs) + if err != nil { + return IntermediateResult{}, err + } + if err := validateSelectedIDsExist(selected, idIndex); err != nil { + return IntermediateResult{}, err + } + + kept := make([]schema.IntermediateSegment, 0, len(input.Segments)) + removed := make([]int, 0, len(input.Segments)) + oldToNew := make(map[int]int, len(input.Segments)) + for _, segment := range input.Segments { + keep := opts.Mode == ModeKeep && opts.Selector.Contains(segment.ID) + if opts.Mode == ModeRemove { + keep = !opts.Selector.Contains(segment.ID) + } + if !keep { + removed = append(removed, segment.ID) + continue + } + + rewritten := schema.IntermediateSegment{ + ID: len(kept) + 1, + Start: segment.Start, + End: segment.End, + Speaker: segment.Speaker, + Text: segment.Text, + Categories: append([]string(nil), segment.Categories...), + } + kept = append(kept, rewritten) + oldToNew[segment.ID] = rewritten.ID + } + + if len(kept) == 0 && !opts.AllowEmpty { + return IntermediateResult{}, fmt.Errorf("trim operation produced an empty transcript; set AllowEmpty to true to permit this") + } + + return IntermediateResult{ + Transcript: schema.IntermediateTranscript{ + Metadata: schema.IntermediateMetadata{ + Application: input.Metadata.Application, + Version: input.Metadata.Version, + OutputSchema: input.Metadata.OutputSchema, + }, + Segments: kept, + }, + OldToNewID: oldToNew, + RemovedIDs: removed, + }, nil +} + +// ApplyMinimal trims a minimal seriatim output transcript by segment ID. +func ApplyMinimal(input schema.MinimalTranscript, opts Options) (MinimalResult, error) { + if err := validateMode(opts.Mode); err != nil { + return MinimalResult{}, err + } + + selected := opts.Selector.IDs() + if len(selected) == 0 { + return MinimalResult{}, fmt.Errorf("selector cannot be empty") + } + + inputIDs := make([]int, len(input.Segments)) + for index, segment := range input.Segments { + inputIDs[index] = segment.ID + } + idIndex, err := validateInputIDs(inputIDs) + if err != nil { + return MinimalResult{}, err + } + if err := validateSelectedIDsExist(selected, idIndex); err != nil { + return MinimalResult{}, err + } + + kept := make([]schema.MinimalSegment, 0, len(input.Segments)) + removed := make([]int, 0, len(input.Segments)) + oldToNew := make(map[int]int, len(input.Segments)) + for _, segment := range input.Segments { + keep := opts.Mode == ModeKeep && opts.Selector.Contains(segment.ID) + if opts.Mode == ModeRemove { + keep = !opts.Selector.Contains(segment.ID) + } + if !keep { + removed = append(removed, segment.ID) + continue + } + + rewritten := schema.MinimalSegment{ + ID: len(kept) + 1, + Start: segment.Start, + End: segment.End, + Speaker: segment.Speaker, + Text: segment.Text, + } + kept = append(kept, rewritten) + oldToNew[segment.ID] = rewritten.ID + } + + if len(kept) == 0 && !opts.AllowEmpty { + return MinimalResult{}, fmt.Errorf("trim operation produced an empty transcript; set AllowEmpty to true to permit this") + } + + return MinimalResult{ + Transcript: schema.MinimalTranscript{ + Metadata: schema.MinimalMetadata{ + Application: input.Metadata.Application, + Version: input.Metadata.Version, + OutputSchema: input.Metadata.OutputSchema, + }, + Segments: kept, + }, + OldToNewID: oldToNew, + RemovedIDs: removed, + }, nil +} + func validateMode(mode Mode) error { switch mode { case ModeKeep, ModeRemove: @@ -94,10 +247,9 @@ func validateMode(mode Mode) error { } } -func validateInputIDs(segments []schema.Segment) (map[int]int, error) { - seen := make(map[int]int, len(segments)) - for index, segment := range segments { - id := segment.ID +func validateInputIDs(ids []int) (map[int]int, error) { + seen := make(map[int]int, len(ids)) + for index, id := range ids { if id <= 0 { return nil, fmt.Errorf("input transcript has non-positive segment ID %d at index %d", id, index) } @@ -107,14 +259,70 @@ func validateInputIDs(segments []schema.Segment) (map[int]int, error) { seen[id] = index } - for id := 1; id <= len(segments); id++ { + for id := 1; id <= len(ids); id++ { if _, exists := seen[id]; !exists { - return nil, fmt.Errorf("input transcript segment IDs must be sequential 1..%d; missing ID %d", len(segments), id) + return nil, fmt.Errorf("input transcript segment IDs must be sequential 1..%d; missing ID %d", len(ids), id) } } return seen, nil } +func validateSelectedIDsExist(selected []int, idIndex map[int]int) error { + for _, id := range selected { + if _, exists := idIndex[id]; !exists { + return fmt.Errorf("selected segment ID %d does not exist in input transcript", id) + } + } + return nil +} + +func recomputeOverlapGroups(segments []schema.Segment) ([]schema.Segment, []schema.OverlapGroup) { + if len(segments) == 0 { + return segments, nil + } + + modelSegments := make([]model.Segment, len(segments)) + for index, segment := range segments { + modelSegments[index] = model.Segment{ + ID: segment.ID, + Source: segment.Source, + SourceSegmentIndex: copyIntPtr(segment.SourceSegmentIndex), + SourceRef: segment.SourceRef, + DerivedFrom: append([]string(nil), segment.DerivedFrom...), + Speaker: segment.Speaker, + Start: segment.Start, + End: segment.End, + Text: segment.Text, + Categories: append([]string(nil), segment.Categories...), + OverlapGroupID: segment.OverlapGroupID, + } + } + + detected := overlap.Detect(model.MergedTranscript{ + Segments: modelSegments, + }) + rewrittenSegments := make([]schema.Segment, len(segments)) + for index, segment := range segments { + rewritten := copySegment(segment) + rewritten.OverlapGroupID = detected.Segments[index].OverlapGroupID + rewrittenSegments[index] = rewritten + } + + groups := make([]schema.OverlapGroup, len(detected.OverlapGroups)) + for index, group := range detected.OverlapGroups { + groups[index] = schema.OverlapGroup{ + ID: group.ID, + Start: group.Start, + End: group.End, + Segments: append([]string(nil), group.Segments...), + Speakers: append([]string(nil), group.Speakers...), + Class: group.Class, + Resolution: group.Resolution, + } + } + return rewrittenSegments, groups +} + func copyTranscript(input schema.Transcript) schema.Transcript { return schema.Transcript{ Metadata: schema.Metadata{ diff --git a/internal/trim/apply_test.go b/internal/trim/apply_test.go index 8c0a5ec..cabdfbd 100644 --- a/internal/trim/apply_test.go +++ b/internal/trim/apply_test.go @@ -210,6 +210,212 @@ func TestApplyPreservesRetainedSegmentFieldsAndClearsOverlapIDs(t *testing.T) { } } +func TestApplyFullSchemaRemovesStaleOverlapGroups(t *testing.T) { + input := overlapTranscriptFixture() + selector := mustParseSelector(t, "1,3") + + result, err := Apply(input, Options{ + Mode: ModeKeep, + Selector: selector, + }) + if err != nil { + t.Fatalf("apply failed: %v", err) + } + + if len(result.Transcript.OverlapGroups) != 0 { + t.Fatalf("overlap_groups count = %d, want 0", len(result.Transcript.OverlapGroups)) + } + for index, segment := range result.Transcript.Segments { + if segment.OverlapGroupID != 0 { + t.Fatalf("segment %d overlap_group_id = %d, want 0", index, segment.OverlapGroupID) + } + } +} + +func TestApplyFullSchemaRecomputesOverlapGroup(t *testing.T) { + input := overlapTranscriptFixture() + selector := mustParseSelector(t, "1,2") + + result, err := Apply(input, Options{ + Mode: ModeKeep, + Selector: selector, + }) + if err != nil { + t.Fatalf("apply failed: %v", err) + } + + assertSegmentIDs(t, result.Transcript.Segments, []int{1, 2}) + assertIntSlice(t, []int{ + result.Transcript.Segments[0].OverlapGroupID, + result.Transcript.Segments[1].OverlapGroupID, + }, []int{1, 1}) + if len(result.Transcript.OverlapGroups) != 1 { + t.Fatalf("overlap_groups count = %d, want 1", len(result.Transcript.OverlapGroups)) + } + group := result.Transcript.OverlapGroups[0] + if group.ID != 1 { + t.Fatalf("group ID = %d, want 1", group.ID) + } + if group.Start != 1 || group.End != 4 { + t.Fatalf("group times = %.3f-%.3f, want 1.000-4.000", group.Start, group.End) + } + if !equalStringSlices(group.Segments, []string{"a.json#10", "b.json#20"}) { + t.Fatalf("group segments = %v, want %v", group.Segments, []string{"a.json#10", "b.json#20"}) + } + if !equalStringSlices(group.Speakers, []string{"Alice", "Bob"}) { + t.Fatalf("group speakers = %v, want %v", group.Speakers, []string{"Alice", "Bob"}) + } +} + +func TestApplyFullSchemaDropsGroupWhenFewerThanTwoSpeakersRemain(t *testing.T) { + input := overlapTranscriptFixture() + selector := mustParseSelector(t, "1") + + result, err := Apply(input, Options{ + Mode: ModeKeep, + Selector: selector, + }) + if err != nil { + t.Fatalf("apply failed: %v", err) + } + + if len(result.Transcript.OverlapGroups) != 0 { + t.Fatalf("overlap_groups count = %d, want 0", len(result.Transcript.OverlapGroups)) + } + if len(result.Transcript.Segments) != 1 { + t.Fatalf("segment count = %d, want 1", len(result.Transcript.Segments)) + } + if result.Transcript.Segments[0].OverlapGroupID != 0 { + t.Fatalf("segment overlap_group_id = %d, want 0", result.Transcript.Segments[0].OverlapGroupID) + } +} + +func TestApplyFullSchemaHandlesTransitiveOverlaps(t *testing.T) { + input := transitiveOverlapFixture() + selector := mustParseSelector(t, "1-3") + + result, err := Apply(input, Options{ + Mode: ModeKeep, + Selector: selector, + }) + if err != nil { + t.Fatalf("apply failed: %v", err) + } + + if len(result.Transcript.OverlapGroups) != 1 { + t.Fatalf("overlap_groups count = %d, want 1", len(result.Transcript.OverlapGroups)) + } + assertIntSlice(t, []int{ + result.Transcript.Segments[0].OverlapGroupID, + result.Transcript.Segments[1].OverlapGroupID, + result.Transcript.Segments[2].OverlapGroupID, + }, []int{1, 1, 1}) + group := result.Transcript.OverlapGroups[0] + if group.Start != 10 || group.End != 15 { + t.Fatalf("group times = %.3f-%.3f, want 10.000-15.000", group.Start, group.End) + } +} + +func TestApplyFullSchemaBoundaryTouchingNotGrouped(t *testing.T) { + input := boundaryFixture() + selector := mustParseSelector(t, "1-2") + + result, err := Apply(input, Options{ + Mode: ModeKeep, + Selector: selector, + }) + if err != nil { + t.Fatalf("apply failed: %v", err) + } + + if len(result.Transcript.OverlapGroups) != 0 { + t.Fatalf("overlap_groups count = %d, want 0", len(result.Transcript.OverlapGroups)) + } + assertIntSlice(t, []int{ + result.Transcript.Segments[0].OverlapGroupID, + result.Transcript.Segments[1].OverlapGroupID, + }, []int{0, 0}) +} + +func TestApplyIntermediateDoesNotIncludeOverlapGroups(t *testing.T) { + input := schema.IntermediateTranscript{ + Metadata: schema.IntermediateMetadata{ + Application: "seriatim", + Version: "v-test", + OutputSchema: "seriatim-intermediate", + }, + Segments: []schema.IntermediateSegment{ + {ID: 1, Start: 1, End: 3, Speaker: "Alice", Text: "alpha", Categories: []string{"word-run"}}, + {ID: 2, Start: 2, End: 4, Speaker: "Bob", Text: "beta", Categories: []string{"filler"}}, + }, + } + selector := mustParseSelector(t, "1") + result, err := ApplyIntermediate(input, Options{ + Mode: ModeKeep, + Selector: selector, + }) + if err != nil { + t.Fatalf("apply intermediate failed: %v", err) + } + if len(result.Transcript.Segments) != 1 { + t.Fatalf("segment count = %d, want 1", len(result.Transcript.Segments)) + } + if result.Transcript.Segments[0].ID != 1 { + t.Fatalf("segment id = %d, want 1", result.Transcript.Segments[0].ID) + } + if err := schema.ValidateIntermediateTranscript(result.Transcript); err != nil { + t.Fatalf("intermediate output should remain valid: %v", err) + } +} + +func TestApplyMinimalDoesNotIncludeOverlapGroups(t *testing.T) { + input := schema.MinimalTranscript{ + Metadata: schema.MinimalMetadata{ + Application: "seriatim", + Version: "v-test", + OutputSchema: "seriatim-minimal", + }, + Segments: []schema.MinimalSegment{ + {ID: 1, Start: 1, End: 3, Speaker: "Alice", Text: "alpha"}, + {ID: 2, Start: 2, End: 4, Speaker: "Bob", Text: "beta"}, + }, + } + selector := mustParseSelector(t, "2") + result, err := ApplyMinimal(input, Options{ + Mode: ModeKeep, + Selector: selector, + }) + if err != nil { + t.Fatalf("apply minimal failed: %v", err) + } + if len(result.Transcript.Segments) != 1 { + t.Fatalf("segment count = %d, want 1", len(result.Transcript.Segments)) + } + if result.Transcript.Segments[0].ID != 1 { + t.Fatalf("segment id = %d, want 1", result.Transcript.Segments[0].ID) + } + if err := schema.ValidateMinimalTranscript(result.Transcript); err != nil { + t.Fatalf("minimal output should remain valid: %v", err) + } +} + +func TestApplyOutputInvariantsValidAfterRenumberAndOverlapRecompute(t *testing.T) { + input := overlapTranscriptFixture() + selector := mustParseSelector(t, "2,1") + + result, err := Apply(input, Options{ + Mode: ModeKeep, + Selector: selector, + }) + if err != nil { + t.Fatalf("apply failed: %v", err) + } + + if err := schema.ValidateTranscript(result.Transcript); err != nil { + t.Fatalf("trim output should remain valid: %v", err) + } +} + func mustParseSelector(t *testing.T, value string) Selector { t.Helper() selector, err := ParseSelector(value) @@ -303,6 +509,104 @@ func fullTranscriptFixture() schema.Transcript { } } +func overlapTranscriptFixture() schema.Transcript { + first := 10 + second := 20 + third := 30 + + return schema.Transcript{ + Metadata: schema.Metadata{ + Application: "seriatim", + Version: "v-test", + InputReader: "json-files", + InputFiles: []string{"a.json", "b.json", "c.json"}, + PreprocessingModules: []string{"validate-raw"}, + PostprocessingModules: []string{"detect-overlaps"}, + OutputModules: []string{"json"}, + }, + Segments: []schema.Segment{ + { + ID: 1, + Source: "a.json", + SourceSegmentIndex: &first, + SourceRef: "a.json#10", + Speaker: "Alice", + Start: 1, + End: 4, + Text: "a", + OverlapGroupID: 99, + }, + { + ID: 2, + Source: "b.json", + SourceSegmentIndex: &second, + SourceRef: "b.json#20", + Speaker: "Bob", + Start: 2, + End: 3, + Text: "b", + OverlapGroupID: 99, + }, + { + ID: 3, + Source: "c.json", + SourceSegmentIndex: &third, + SourceRef: "c.json#30", + Speaker: "Carol", + Start: 10, + End: 11, + Text: "c", + OverlapGroupID: 100, + }, + }, + OverlapGroups: []schema.OverlapGroup{ + { + ID: 99, + Start: 0, + End: 100, + Segments: []string{"stale#1", "stale#2"}, + Speakers: []string{"stale"}, + Class: "unknown", + Resolution: "unresolved", + }, + }, + } +} + +func transitiveOverlapFixture() schema.Transcript { + one := 1 + two := 2 + three := 3 + return schema.Transcript{ + Metadata: schema.Metadata{ + Application: "seriatim", + Version: "v-test", + }, + Segments: []schema.Segment{ + {ID: 1, Source: "a.json", SourceSegmentIndex: &one, Speaker: "Alice", Start: 10, End: 14, Text: "a"}, + {ID: 2, Source: "b.json", SourceSegmentIndex: &two, Speaker: "Bob", Start: 12, End: 13, Text: "b"}, + {ID: 3, Source: "c.json", SourceSegmentIndex: &three, Speaker: "Carol", Start: 13.5, End: 15, Text: "c"}, + }, + OverlapGroups: []schema.OverlapGroup{{ID: 77}}, + } +} + +func boundaryFixture() schema.Transcript { + one := 1 + two := 2 + return schema.Transcript{ + Metadata: schema.Metadata{ + Application: "seriatim", + Version: "v-test", + }, + Segments: []schema.Segment{ + {ID: 1, Source: "a.json", SourceSegmentIndex: &one, Speaker: "Alice", Start: 1, End: 2, Text: "a", OverlapGroupID: 7}, + {ID: 2, Source: "b.json", SourceSegmentIndex: &two, Speaker: "Bob", Start: 2, End: 3, Text: "b", OverlapGroupID: 7}, + }, + OverlapGroups: []schema.OverlapGroup{{ID: 7, Start: 1, End: 3}}, + } +} + func assertSegmentIDs(t *testing.T, segments []schema.Segment, want []int) { t.Helper() got := make([]int, len(segments))