diff --git a/internal/trim/apply.go b/internal/trim/apply.go
new file mode 100644
index 0000000..533fdef
--- /dev/null
+++ b/internal/trim/apply.go
@@ -0,0 +1,201 @@
+package trim
+
+import (
+	"fmt"
+
+	"gitea.maximumdirect.net/eric/seriatim/schema"
+)
+
+// Mode controls how selector IDs are applied.
+type Mode string
+
+const (
+	// ModeKeep retains only the selected segments.
+	ModeKeep Mode = "keep"
+	// ModeRemove drops the selected segments and retains all others.
+	ModeRemove Mode = "remove"
+)
+
+// Options configures transcript trimming.
+type Options struct {
+	// Mode decides whether Selector names the segments to keep or to remove.
+	Mode Mode
+	// Selector supplies the segment IDs the operation acts on.
+	Selector Selector
+	// AllowEmpty permits a result with no segments; Apply fails otherwise.
+	AllowEmpty bool
+}
+
+// Result contains trimming output and ID mapping metadata.
+type Result struct {
+	// Transcript is the trimmed transcript with segments renumbered from 1.
+	Transcript schema.Transcript
+	// OldToNewID maps each retained segment's input ID to its new ID.
+	OldToNewID map[int]int
+	// RemovedIDs lists the input IDs of dropped segments in input order.
+	RemovedIDs []int
+}
+
+// Apply trims a full seriatim output transcript by segment ID.
+//
+// The input's segment IDs must be exactly 1..len(input.Segments), and every
+// ID reported by opts.Selector must exist in the input. Retained segments
+// keep their input order regardless of selector order, are renumbered
+// consecutively from 1, and have OverlapGroupID reset to 0; the output
+// carries no overlap groups. Retained segments are copied, so the input's
+// segments are left untouched.
+//
+// Apply returns an error for an unknown mode, an empty selector, malformed
+// input IDs, a selected ID that is missing from the input, or a result with
+// no segments when opts.AllowEmpty is false.
+func Apply(input schema.Transcript, opts Options) (Result, error) {
+	if err := validateMode(opts.Mode); err != nil {
+		return Result{}, err
+	}
+
+	selected := opts.Selector.IDs()
+	if len(selected) == 0 {
+		return Result{}, fmt.Errorf("selector cannot be empty")
+	}
+
+	idIndex, err := validateInputIDs(input.Segments)
+	if err != nil {
+		return Result{}, err
+	}
+
+	for _, id := range selected {
+		if _, exists := idIndex[id]; !exists {
+			return Result{}, fmt.Errorf("selected segment ID %d does not exist in input transcript", id)
+		}
+	}
+
+	kept := make([]schema.Segment, 0, len(input.Segments))
+	removed := make([]int, 0, len(input.Segments))
+	oldToNew := make(map[int]int, len(input.Segments))
+	for _, segment := range input.Segments {
+		// validateMode has already limited Mode to keep or remove.
+		keep := opts.Mode == ModeKeep && opts.Selector.Contains(segment.ID)
+		if opts.Mode == ModeRemove {
+			keep = !opts.Selector.Contains(segment.ID)
+		}
+
+		if !keep {
+			removed = append(removed, segment.ID)
+			continue
+		}
+
+		rewritten := copySegment(segment)
+		// New IDs are dense and 1-based, assigned in transcript order.
+		rewritten.ID = len(kept) + 1
+		rewritten.OverlapGroupID = 0
+		kept = append(kept, rewritten)
+		oldToNew[segment.ID] = rewritten.ID
+	}
+
+	if len(kept) == 0 && !opts.AllowEmpty {
+		return Result{}, fmt.Errorf("trim operation produced an empty transcript; set AllowEmpty to true to permit this")
+	}
+
+	out := copyTranscript(input)
+	out.Segments = kept
+	out.OverlapGroups = nil
+	return Result{
+		Transcript: out,
+		OldToNewID: oldToNew,
+		RemovedIDs: removed,
+	}, nil
+}
+
+// validateMode reports an error unless mode is ModeKeep or ModeRemove.
+func validateMode(mode Mode) error {
+	switch mode {
+	case ModeKeep, ModeRemove:
+		return nil
+	default:
+		return fmt.Errorf("invalid trim mode %q", mode)
+	}
+}
+
+// validateInputIDs checks that the segment IDs are positive, unique, and
+// cover exactly 1..len(segments). On success it returns a map from segment
+// ID to the segment's index in segments.
+func validateInputIDs(segments []schema.Segment) (map[int]int, error) {
+	seen := make(map[int]int, len(segments))
+	for index, segment := range segments {
+		id := segment.ID
+		if id <= 0 {
+			return nil, fmt.Errorf("input transcript has non-positive segment ID %d at index %d", id, index)
+		}
+		if firstIndex, exists := seen[id]; exists {
+			return nil, fmt.Errorf("input transcript has duplicate segment ID %d at indexes %d and %d", id, firstIndex, index)
+		}
+		seen[id] = index
+	}
+
+	for id := 1; id <= len(segments); id++ {
+		if _, exists := seen[id]; !exists {
+			return nil, fmt.Errorf("input transcript segment IDs must be sequential 1..%d; missing ID %d", len(segments), id)
+		}
+	}
+	return seen, nil
+}
+
+// copyTranscript returns a copy of input whose metadata, segment, and
+// overlap-group slices can be appended to or overwritten without affecting
+// input. Segment and overlap-group elements are copied by assignment, so
+// slices and pointers inside those elements are still shared with input.
+func copyTranscript(input schema.Transcript) schema.Transcript {
+	return schema.Transcript{
+		Metadata: schema.Metadata{
+			Application:           input.Metadata.Application,
+			Version:               input.Metadata.Version,
+			InputReader:           input.Metadata.InputReader,
+			InputFiles:            cloneStrings(input.Metadata.InputFiles),
+			PreprocessingModules:  cloneStrings(input.Metadata.PreprocessingModules),
+			PostprocessingModules: cloneStrings(input.Metadata.PostprocessingModules),
+			OutputModules:         cloneStrings(input.Metadata.OutputModules),
+		},
+		// A zero-capacity reslice forces append to allocate, and unlike
+		// append([]T(nil), s...) it keeps an empty slice non-nil.
+		Segments:      append(input.Segments[:0:0], input.Segments...),
+		OverlapGroups: append(input.OverlapGroups[:0:0], input.OverlapGroups...),
+	}
+}
+
+// copySegment returns a copy of input whose SourceSegmentIndex pointer and
+// string slices do not alias input's.
+func copySegment(input schema.Segment) schema.Segment {
+	return schema.Segment{
+		ID:                 input.ID,
+		Source:             input.Source,
+		SourceSegmentIndex: copyIntPtr(input.SourceSegmentIndex),
+		SourceRef:          input.SourceRef,
+		DerivedFrom:        cloneStrings(input.DerivedFrom),
+		Speaker:            input.Speaker,
+		Start:              input.Start,
+		End:                input.End,
+		Text:               input.Text,
+		Categories:         cloneStrings(input.Categories),
+		OverlapGroupID:     input.OverlapGroupID,
+	}
+}
+
+// cloneStrings returns an independent copy of values that stays nil for a
+// nil input and non-nil for an empty one. append([]string(nil), values...)
+// would collapse an empty slice to nil, which encoding/json renders as null
+// rather than [] unless the field is tagged omitempty.
+func cloneStrings(values []string) []string {
+	// The zero-capacity reslice forces append to allocate whenever there is
+	// something to copy, and passes nil and empty inputs through unchanged.
+	return append(values[:0:0], values...)
+}
+
+// copyIntPtr returns a pointer to a fresh copy of *value, or nil when value
+// is nil.
+func copyIntPtr(value *int) *int {
+	if value == nil {
+		return nil
+	}
+	copied := *value
+	return &copied
+}
diff --git a/internal/trim/apply_test.go b/internal/trim/apply_test.go
new file mode 100644
index 0000000..8c0a5ec
--- /dev/null
+++ b/internal/trim/apply_test.go
@@ -0,0 +1,399 @@
+package trim
+
+import (
+	"strings"
+	"testing"
+
+	"gitea.maximumdirect.net/eric/seriatim/schema"
+)
+
+func TestApplyKeepModeRenumbersFromOne(t *testing.T) {
+	input := fullTranscriptFixture()
+	selector := mustParseSelector(t, "2,4")
+
+	result, err := Apply(input, Options{
+		Mode:     ModeKeep,
+		Selector: selector,
+	})
+	if err != nil {
+		t.Fatalf("apply failed: %v", err)
+	}
+
+	if len(result.Transcript.Segments) != 2 {
+		t.Fatalf("segment count = %d, want 2", len(result.Transcript.Segments))
+	}
+	assertSegmentIDs(t, result.Transcript.Segments, []int{1, 2})
+	assertSegmentTexts(t, result.Transcript.Segments, []string{"beta", "delta"})
+	assertIntMap(t, result.OldToNewID, map[int]int{2: 1, 4: 2})
+	assertIntSlice(t, result.RemovedIDs, []int{1, 3})
+}
+
+func TestApplyRemoveModeRenumbersFromOne(t *testing.T) {
+	input := fullTranscriptFixture()
+	selector := mustParseSelector(t, "2,4")
+
+	result, err := Apply(input, Options{
+		Mode:     ModeRemove,
+		Selector: selector,
+	})
+	if err != nil {
+		t.Fatalf("apply failed: %v", err)
+	}
+
+	assertSegmentIDs(t, result.Transcript.Segments, []int{1, 2})
+	assertSegmentTexts(t, result.Transcript.Segments, []string{"alpha", "gamma"})
+	assertIntMap(t, result.OldToNewID, map[int]int{1: 1, 3: 2})
+	assertIntSlice(t, result.RemovedIDs, []int{2, 4})
+}
+
+func TestApplySelectorOrderDoesNotChangeTranscriptOrder(t *testing.T) {
+	input := fullTranscriptFixture()
+	selector := mustParseSelector(t, "4,1,3")
+
+	result, err := Apply(input, Options{
+		Mode:     ModeKeep,
+		Selector: selector,
+	})
+	if err != nil {
+		t.Fatalf("apply failed: %v", err)
+	}
+
+	assertSegmentIDs(t, result.Transcript.Segments, []int{1, 2, 3})
+	assertSegmentTexts(t, result.Transcript.Segments, []string{"alpha", "gamma", "delta"})
+}
+
+func TestApplyFailsWhenSelectedIDDoesNotExist(t *testing.T) {
+	input := fullTranscriptFixture()
+	selector := mustParseSelector(t, "2,99")
+
+	_, err := Apply(input, Options{
+		Mode:     ModeKeep,
+		Selector: selector,
+	})
+	if err == nil {
+		t.Fatal("expected missing selected ID error")
+	}
+	if !strings.Contains(err.Error(), "does not exist") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+}
+
+func TestApplyFailsOnDuplicateInputIDs(t *testing.T) {
+	input := fullTranscriptFixture()
+	input.Segments[2].ID = 2
+	selector := mustParseSelector(t, "2")
+
+	_, err := Apply(input, Options{
+		Mode:     ModeKeep,
+		Selector: selector,
+	})
+	if err == nil {
+		t.Fatal("expected duplicate input ID error")
+	}
+	if !strings.Contains(err.Error(), "duplicate segment ID") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+}
+
+func TestApplyFailsOnMissingOrNonSequentialInputIDs(t *testing.T) {
+	input := fullTranscriptFixture()
+	input.Segments[1].ID = 5
+	selector := mustParseSelector(t, "1")
+
+	_, err := Apply(input, Options{
+		Mode:     ModeKeep,
+		Selector: selector,
+	})
+	if err == nil {
+		t.Fatal("expected non-sequential input ID error")
+	}
+	if !strings.Contains(err.Error(), "must be sequential") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+}
+
+func TestApplyFailsOnNonPositiveInputIDs(t *testing.T) {
+	input := fullTranscriptFixture()
+	input.Segments[0].ID = 0
+	selector := mustParseSelector(t, "1")
+
+	_, err := Apply(input, Options{
+		Mode:     ModeKeep,
+		Selector: selector,
+	})
+	if err == nil {
+		t.Fatal("expected non-positive input ID error")
+	}
+	if !strings.Contains(err.Error(), "non-positive") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+}
+
+func TestApplyEmptyOutputFailsUnlessAllowEmpty(t *testing.T) {
+	input := fullTranscriptFixture()
+	selector := mustParseSelector(t, "1-4")
+
+	_, err := Apply(input, Options{
+		Mode:     ModeRemove,
+		Selector: selector,
+	})
+	if err == nil {
+		t.Fatal("expected empty-output error")
+	}
+	if !strings.Contains(err.Error(), "empty transcript") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	allowed, err := Apply(input, Options{
+		Mode:       ModeRemove,
+		Selector:   selector,
+		AllowEmpty: true,
+	})
+	if err != nil {
+		t.Fatalf("apply with AllowEmpty failed: %v", err)
+	}
+	if len(allowed.Transcript.Segments) != 0 {
+		t.Fatalf("segment count = %d, want 0", len(allowed.Transcript.Segments))
+	}
+	assertIntMap(t, allowed.OldToNewID, map[int]int{})
+	assertIntSlice(t, allowed.RemovedIDs, []int{1, 2, 3, 4})
+}
+
+func TestApplyPreservesRetainedSegmentFieldsAndClearsOverlapIDs(t *testing.T) {
+	input := fullTranscriptFixture()
+	selector := mustParseSelector(t, "2")
+
+	result, err := Apply(input, Options{
+		Mode:     ModeKeep,
+		Selector: selector,
+	})
+	if err != nil {
+		t.Fatalf("apply failed: %v", err)
+	}
+
+	if len(result.Transcript.Segments) != 1 {
+		t.Fatalf("segment count = %d, want 1", len(result.Transcript.Segments))
+	}
+	segment := result.Transcript.Segments[0]
+	if segment.ID != 1 {
+		t.Fatalf("segment ID = %d, want 1", segment.ID)
+	}
+	if segment.Source != "b.json" {
+		t.Fatalf("source = %q, want %q", segment.Source, "b.json")
+	}
+	if segment.SourceSegmentIndex == nil || *segment.SourceSegmentIndex != 20 {
+		t.Fatalf("source_segment_index = %v, want 20", segment.SourceSegmentIndex)
+	}
+	if segment.SourceRef != "b.json#20" {
+		t.Fatalf("source_ref = %q, want %q", segment.SourceRef, "b.json#20")
+	}
+	if !equalStringSlices(segment.DerivedFrom, []string{"b.json#19", "b.json#20"}) {
+		t.Fatalf("derived_from = %v, want %v", segment.DerivedFrom, []string{"b.json#19", "b.json#20"})
+	}
+	if !equalStringSlices(segment.Categories, []string{"filler", "backchannel"}) {
+		t.Fatalf("categories = %v, want %v", segment.Categories, []string{"filler", "backchannel"})
+	}
+	if segment.Speaker != "Bob" {
+		t.Fatalf("speaker = %q, want Bob", segment.Speaker)
+	}
+	if segment.Start != 2 || segment.End != 3 {
+		t.Fatalf("times = %.3f-%.3f, want 2.000-3.000", segment.Start, segment.End)
+	}
+	if segment.Text != "beta" {
+		t.Fatalf("text = %q, want beta", segment.Text)
+	}
+	if segment.OverlapGroupID != 0 {
+		t.Fatalf("overlap_group_id = %d, want 0", segment.OverlapGroupID)
+	}
+	if len(result.Transcript.OverlapGroups) != 0 {
+		t.Fatalf("overlap_groups count = %d, want 0", len(result.Transcript.OverlapGroups))
+	}
+}
+
+func TestApplyPreservesNilAndEmptySlices(t *testing.T) {
+	input := fullTranscriptFixture()
+	input.Metadata.OutputModules = []string{}
+	input.Segments[1].Categories = []string{}
+	input.Segments[1].DerivedFrom = nil
+	selector := mustParseSelector(t, "2")
+
+	result, err := Apply(input, Options{
+		Mode:     ModeKeep,
+		Selector: selector,
+	})
+	if err != nil {
+		t.Fatalf("apply failed: %v", err)
+	}
+
+	if got := result.Transcript.Metadata.OutputModules; got == nil || len(got) != 0 {
+		t.Fatalf("output_modules = %#v, want empty non-nil slice", got)
+	}
+	segment := result.Transcript.Segments[0]
+	if segment.Categories == nil || len(segment.Categories) != 0 {
+		t.Fatalf("categories = %#v, want empty non-nil slice", segment.Categories)
+	}
+	if segment.DerivedFrom != nil {
+		t.Fatalf("derived_from = %#v, want nil", segment.DerivedFrom)
+	}
+}
+
+// mustParseSelector parses value with ParseSelector and fails the test on error.
+func mustParseSelector(t *testing.T, value string) Selector {
+	t.Helper()
+	selector, err := ParseSelector(value)
+	if err != nil {
+		t.Fatalf("selector parse failed for %q: %v", value, err)
+	}
+	return selector
+}
+
+// fullTranscriptFixture returns a fresh four-segment transcript with populated
+// metadata and one overlap group, so tests may mutate it freely.
+func fullTranscriptFixture() schema.Transcript {
+	firstIndex := 10
+	secondIndex := 20
+	thirdIndex := 30
+	fourthIndex := 40
+
+	return schema.Transcript{
+		Metadata: schema.Metadata{
+			Application:           "seriatim",
+			Version:               "v-test",
+			InputReader:           "json-files",
+			InputFiles:            []string{"a.json", "b.json"},
+			PreprocessingModules:  []string{"validate-raw"},
+			PostprocessingModules: []string{"detect-overlaps"},
+			OutputModules:         []string{"json"},
+		},
+		Segments: []schema.Segment{
+			{
+				ID:                 1,
+				Source:             "a.json",
+				SourceSegmentIndex: &firstIndex,
+				SourceRef:          "a.json#10",
+				DerivedFrom:        []string{"a.json#10"},
+				Speaker:            "Alice",
+				Start:              1,
+				End:                2,
+				Text:               "alpha",
+				Categories:         []string{"word-run"},
+				OverlapGroupID:     7,
+			},
+			{
+				ID:                 2,
+				Source:             "b.json",
+				SourceSegmentIndex: &secondIndex,
+				SourceRef:          "b.json#20",
+				DerivedFrom:        []string{"b.json#19", "b.json#20"},
+				Speaker:            "Bob",
+				Start:              2,
+				End:                3,
+				Text:               "beta",
+				Categories:         []string{"filler", "backchannel"},
+				OverlapGroupID:     7,
+			},
+			{
+				ID:                 3,
+				Source:             "c.json",
+				SourceSegmentIndex: &thirdIndex,
+				SourceRef:          "c.json#30",
+				DerivedFrom:        []string{"c.json#30"},
+				Speaker:            "Carol",
+				Start:              3,
+				End:                4,
+				Text:               "gamma",
+				Categories:         []string{"normal"},
+				OverlapGroupID:     8,
+			},
+			{
+				ID:                 4,
+				Source:             "d.json",
+				SourceSegmentIndex: &fourthIndex,
+				SourceRef:          "d.json#40",
+				DerivedFrom:        []string{"d.json#40"},
+				Speaker:            "Dan",
+				Start:              4,
+				End:                5,
+				Text:               "delta",
+				Categories:         []string{"normal"},
+				OverlapGroupID:     9,
+			},
+		},
+		OverlapGroups: []schema.OverlapGroup{
+			{
+				ID:         7,
+				Start:      1.5,
+				End:        3.1,
+				Segments:   []string{"a.json#10", "b.json#20"},
+				Speakers:   []string{"Alice", "Bob"},
+				Class:      "unknown",
+				Resolution: "unresolved",
+			},
+		},
+	}
+}
+
+// assertSegmentIDs fails the test unless the segments' IDs equal want in order.
+func assertSegmentIDs(t *testing.T, segments []schema.Segment, want []int) {
+	t.Helper()
+	got := make([]int, len(segments))
+	for index, segment := range segments {
+		got[index] = segment.ID
+	}
+	assertIntSlice(t, got, want)
+}
+
+// assertSegmentTexts fails the test unless the segments' texts equal want in order.
+func assertSegmentTexts(t *testing.T, segments []schema.Segment, want []string) {
+	t.Helper()
+	got := make([]string, len(segments))
+	for index, segment := range segments {
+		got[index] = segment.Text
+	}
+	if !equalStringSlices(got, want) {
+		t.Fatalf("segment texts = %v, want %v", got, want)
+	}
+}
+
+// assertIntSlice fails the test unless got and want have equal length and elements.
+func assertIntSlice(t *testing.T, got []int, want []int) {
+	t.Helper()
+	if len(got) != len(want) {
+		t.Fatalf("slice length = %d, want %d", len(got), len(want))
+	}
+	for index := range got {
+		if got[index] != want[index] {
+			t.Fatalf("slice[%d] = %d, want %d (full got=%v, want=%v)", index, got[index], want[index], got, want)
+		}
+	}
+}
+
+// assertIntMap fails the test unless got and want hold the same key/value pairs.
+func assertIntMap(t *testing.T, got map[int]int, want map[int]int) {
+	t.Helper()
+	if len(got) != len(want) {
+		t.Fatalf("map length = %d, want %d", len(got), len(want))
+	}
+	for key, wantValue := range want {
+		gotValue, exists := got[key]
+		if !exists {
+			t.Fatalf("missing map key %d", key)
+		}
+		if gotValue != wantValue {
+			t.Fatalf("map[%d] = %d, want %d", key, gotValue, wantValue)
+		}
+	}
+}
+
+// equalStringSlices reports whether got and want have the same length and elements.
+func equalStringSlices(got []string, want []string) bool {
+	if len(got) != len(want) {
+		return false
+	}
+	for index := range got {
+		if got[index] != want[index] {
+			return false
+		}
+	}
+	return true
+}