Implemented an overlap detection module in the postprocessing chain
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/autocorrect"
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/config"
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/model"
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/overlap"
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/report"
|
||||
)
|
||||
|
||||
@@ -48,6 +49,23 @@ func (assignIDs) Process(ctx context.Context, in model.MergedTranscript, cfg con
|
||||
}, nil
|
||||
}
|
||||
|
||||
type detectOverlaps struct{}
|
||||
|
||||
func (detectOverlaps) Name() string {
|
||||
return "detect-overlaps"
|
||||
}
|
||||
|
||||
func (detectOverlaps) Process(ctx context.Context, in model.MergedTranscript, cfg config.Config) (model.MergedTranscript, []report.Event, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return model.MergedTranscript{}, nil, err
|
||||
}
|
||||
|
||||
in = overlap.Detect(in)
|
||||
return in, []report.Event{
|
||||
report.Info("postprocessing", "detect-overlaps", fmt.Sprintf("detected %d overlap group(s)", len(in.OverlapGroups))),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type autocorrectPostprocessor struct{}
|
||||
|
||||
func (autocorrectPostprocessor) Name() string {
|
||||
|
||||
@@ -11,7 +11,7 @@ func NewRegistry() *pipeline.Registry {
|
||||
registry.RegisterPreprocessor(normalizeSpeakers{})
|
||||
registry.RegisterPreprocessor(trimText{})
|
||||
registry.RegisterMerger(placeholderMerger{})
|
||||
registry.RegisterPostprocessor(noopPostprocessor{name: "detect-overlaps"})
|
||||
registry.RegisterPostprocessor(detectOverlaps{})
|
||||
registry.RegisterPostprocessor(noopPostprocessor{name: "resolve-overlaps"})
|
||||
registry.RegisterPostprocessor(assignIDs{})
|
||||
registry.RegisterPostprocessor(noopPostprocessor{name: "validate-output"})
|
||||
|
||||
@@ -149,6 +149,74 @@ func TestMergeTieBreakOrder(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeDetectsOverlapGroups(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
inputA := writeJSONFile(t, dir, "a.json", `{
|
||||
"segments": [
|
||||
{"start": 1, "end": 5, "text": "alice long"},
|
||||
{"start": 2, "end": 3, "text": "alice nested"}
|
||||
]
|
||||
}`)
|
||||
inputB := writeJSONFile(t, dir, "b.json", `{
|
||||
"segments": [
|
||||
{"start": 4, "end": 6, "text": "bob overlap"}
|
||||
]
|
||||
}`)
|
||||
speakers := writeYAMLFile(t, dir, "speakers.yml", `match:
|
||||
- speaker: Alice
|
||||
match: ["a.json"]
|
||||
- speaker: Bob
|
||||
match: ["b.json"]
|
||||
`)
|
||||
output := filepath.Join(dir, "merged.json")
|
||||
reportPath := filepath.Join(dir, "report.json")
|
||||
|
||||
err := executeMerge(
|
||||
"--input-file", inputB,
|
||||
"--input-file", inputA,
|
||||
"--speakers", speakers,
|
||||
"--output-file", output,
|
||||
"--report-file", reportPath,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("merge failed: %v", err)
|
||||
}
|
||||
|
||||
var transcript model.FinalTranscript
|
||||
readJSON(t, output, &transcript)
|
||||
if len(transcript.OverlapGroups) != 1 {
|
||||
t.Fatalf("overlap group count = %d, want 1", len(transcript.OverlapGroups))
|
||||
}
|
||||
group := transcript.OverlapGroups[0]
|
||||
if group.ID != 1 {
|
||||
t.Fatalf("group ID = %d, want 1", group.ID)
|
||||
}
|
||||
if group.Start != 1 || group.End != 6 {
|
||||
t.Fatalf("group bounds = %f-%f, want 1-6", group.Start, group.End)
|
||||
}
|
||||
wantRefs := []string{inputA + "#0", inputA + "#1", inputB + "#0"}
|
||||
if !equalStrings(group.Segments, wantRefs) {
|
||||
t.Fatalf("group refs = %v, want %v", group.Segments, wantRefs)
|
||||
}
|
||||
if !equalStrings(group.Speakers, []string{"Alice", "Bob"}) {
|
||||
t.Fatalf("group speakers = %v, want [Alice Bob]", group.Speakers)
|
||||
}
|
||||
if group.Class != "unknown" || group.Resolution != "unresolved" {
|
||||
t.Fatalf("unexpected group class/resolution: %q/%q", group.Class, group.Resolution)
|
||||
}
|
||||
for index, segment := range transcript.Segments {
|
||||
if segment.OverlapGroupID != 1 {
|
||||
t.Fatalf("segment %d overlap group ID = %d, want 1", index, segment.OverlapGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
var rpt report.Report
|
||||
readJSON(t, reportPath, &rpt)
|
||||
if !hasReportEvent(rpt, "postprocessing", "detect-overlaps", "detected 1 overlap group(s)") {
|
||||
t.Fatal("expected detect-overlaps report event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpeakerMatchingUsesFirstMatchingRuleCaseInsensitive(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
input := writeJSONFile(t, dir, "2026-04-19-Adam_Rakestraw.json", `{
|
||||
|
||||
@@ -137,6 +137,7 @@ func normalizeInputFiles(paths []string) ([]string, error) {
|
||||
}
|
||||
|
||||
normalized := make([]string, 0, len(paths))
|
||||
seen := make(map[string]struct{}, len(paths))
|
||||
for _, path := range paths {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
@@ -147,6 +148,10 @@ func normalizeInputFiles(paths []string) ([]string, error) {
|
||||
if err := requireFile(clean, "--input-file"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, exists := seen[clean]; exists {
|
||||
return nil, fmt.Errorf("duplicate --input-file %q", clean)
|
||||
}
|
||||
seen[clean] = struct{}{}
|
||||
normalized = append(normalized, clean)
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
|
||||
@@ -3,6 +3,7 @@ package config
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -24,6 +25,27 @@ func TestOmittingNormalizeSpeakersDoesNotRequireSpeakers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuplicateInputFilesFailValidation(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
input := writeTempFile(t, dir, "input.json")
|
||||
output := filepath.Join(dir, "merged.json")
|
||||
|
||||
_, err := NewMergeConfig(MergeOptions{
|
||||
InputFiles: []string{input, input},
|
||||
OutputFile: output,
|
||||
InputReader: DefaultInputReader,
|
||||
OutputModules: DefaultOutputModules,
|
||||
PreprocessingModules: DefaultPreprocessingModules,
|
||||
PostprocessingModules: DefaultPostprocessingModules,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected duplicate input error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "duplicate --input-file") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeTempFile(t *testing.T, dir string, name string) string {
|
||||
t.Helper()
|
||||
|
||||
|
||||
118
internal/overlap/detect.go
Normal file
118
internal/overlap/detect.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package overlap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultClass = "unknown"
|
||||
defaultResolution = "unresolved"
|
||||
)
|
||||
|
||||
// Detect annotates overlapping segment groups in an already sorted merged transcript.
|
||||
func Detect(in model.MergedTranscript) model.MergedTranscript {
|
||||
clearExisting(&in)
|
||||
if len(in.Segments) < 2 {
|
||||
return in
|
||||
}
|
||||
|
||||
var groupID int
|
||||
var candidate overlapCandidate
|
||||
for index := range in.Segments {
|
||||
segment := in.Segments[index]
|
||||
if !candidate.active {
|
||||
candidate = newCandidate(index, segment)
|
||||
continue
|
||||
}
|
||||
|
||||
if segment.Start < candidate.end {
|
||||
candidate.add(index, segment)
|
||||
continue
|
||||
}
|
||||
|
||||
groupID = finalizeCandidate(&in, candidate, groupID)
|
||||
candidate = newCandidate(index, segment)
|
||||
}
|
||||
|
||||
finalizeCandidate(&in, candidate, groupID)
|
||||
return in
|
||||
}
|
||||
|
||||
type overlapCandidate struct {
|
||||
active bool
|
||||
indices []int
|
||||
start float64
|
||||
end float64
|
||||
}
|
||||
|
||||
func newCandidate(index int, segment model.Segment) overlapCandidate {
|
||||
return overlapCandidate{
|
||||
active: true,
|
||||
indices: []int{index},
|
||||
start: segment.Start,
|
||||
end: segment.End,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *overlapCandidate) add(index int, segment model.Segment) {
|
||||
c.indices = append(c.indices, index)
|
||||
if segment.End > c.end {
|
||||
c.end = segment.End
|
||||
}
|
||||
}
|
||||
|
||||
func finalizeCandidate(in *model.MergedTranscript, candidate overlapCandidate, currentGroupID int) int {
|
||||
if !candidate.active || len(candidate.indices) < 2 {
|
||||
return currentGroupID
|
||||
}
|
||||
|
||||
speakers := distinctSpeakers(in.Segments, candidate.indices)
|
||||
if len(speakers) < 2 {
|
||||
return currentGroupID
|
||||
}
|
||||
|
||||
groupID := currentGroupID + 1
|
||||
refs := make([]string, 0, len(candidate.indices))
|
||||
for _, index := range candidate.indices {
|
||||
in.Segments[index].OverlapGroupID = groupID
|
||||
refs = append(refs, segmentRef(in.Segments[index]))
|
||||
}
|
||||
|
||||
in.OverlapGroups = append(in.OverlapGroups, model.OverlapGroup{
|
||||
ID: groupID,
|
||||
Start: candidate.start,
|
||||
End: candidate.end,
|
||||
Segments: refs,
|
||||
Speakers: speakers,
|
||||
Class: defaultClass,
|
||||
Resolution: defaultResolution,
|
||||
})
|
||||
return groupID
|
||||
}
|
||||
|
||||
func distinctSpeakers(segments []model.Segment, indices []int) []string {
|
||||
seen := make(map[string]struct{}, len(indices))
|
||||
speakers := make([]string, 0, len(indices))
|
||||
for _, index := range indices {
|
||||
speaker := segments[index].Speaker
|
||||
if _, exists := seen[speaker]; exists {
|
||||
continue
|
||||
}
|
||||
seen[speaker] = struct{}{}
|
||||
speakers = append(speakers, speaker)
|
||||
}
|
||||
return speakers
|
||||
}
|
||||
|
||||
func segmentRef(segment model.Segment) string {
|
||||
return fmt.Sprintf("%s#%d", segment.Source, segment.SourceSegmentIndex)
|
||||
}
|
||||
|
||||
func clearExisting(in *model.MergedTranscript) {
|
||||
in.OverlapGroups = make([]model.OverlapGroup, 0)
|
||||
for index := range in.Segments {
|
||||
in.Segments[index].OverlapGroupID = 0
|
||||
}
|
||||
}
|
||||
203
internal/overlap/detect_test.go
Normal file
203
internal/overlap/detect_test.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package overlap
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/model"
|
||||
)
|
||||
|
||||
func TestDetectNoOverlaps(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 2),
|
||||
segment("b.json", 0, "Bob", 2, 3),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
if len(got.OverlapGroups) != 0 {
|
||||
t.Fatalf("expected no overlap groups, got %d", len(got.OverlapGroups))
|
||||
}
|
||||
assertNoSegmentAnnotations(t, got)
|
||||
}
|
||||
|
||||
func TestDetectSimpleTwoSpeakerOverlap(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 3),
|
||||
segment("b.json", 0, "Bob", 2, 4),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
assertGroup(t, got, 0, 1, 1, 4, []string{"a.json#0", "b.json#0"}, []string{"Alice", "Bob"})
|
||||
assertSegmentGroupIDs(t, got, []int{1, 1})
|
||||
}
|
||||
|
||||
func TestDetectTransitiveOverlap(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 10, 14),
|
||||
segment("b.json", 0, "Bob", 12, 13),
|
||||
segment("c.json", 0, "Carol", 13.5, 15),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
assertGroup(t, got, 0, 1, 10, 15, []string{"a.json#0", "b.json#0", "c.json#0"}, []string{"Alice", "Bob", "Carol"})
|
||||
assertSegmentGroupIDs(t, got, []int{1, 1, 1})
|
||||
}
|
||||
|
||||
func TestDetectBoundaryContactDoesNotOverlap(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 2),
|
||||
segment("b.json", 0, "Bob", 2, 3),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
if len(got.OverlapGroups) != 0 {
|
||||
t.Fatalf("expected no overlap groups, got %d", len(got.OverlapGroups))
|
||||
}
|
||||
assertNoSegmentAnnotations(t, got)
|
||||
}
|
||||
|
||||
func TestDetectSameSpeakerOnlyOverlapDoesNotCreateGroup(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 3),
|
||||
segment("a.json", 1, "Alice", 2, 4),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
if len(got.OverlapGroups) != 0 {
|
||||
t.Fatalf("expected no overlap groups, got %d", len(got.OverlapGroups))
|
||||
}
|
||||
assertNoSegmentAnnotations(t, got)
|
||||
}
|
||||
|
||||
func TestDetectIncludesSameSpeakerSegmentInsideMultiSpeakerOverlap(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 5),
|
||||
segment("a.json", 1, "Alice", 2, 3),
|
||||
segment("b.json", 0, "Bob", 4, 6),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
assertGroup(t, got, 0, 1, 1, 6, []string{"a.json#0", "a.json#1", "b.json#0"}, []string{"Alice", "Bob"})
|
||||
assertSegmentGroupIDs(t, got, []int{1, 1, 1})
|
||||
}
|
||||
|
||||
func TestDetectMultipleGroups(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 3),
|
||||
segment("b.json", 0, "Bob", 2, 4),
|
||||
segment("c.json", 0, "Carol", 5, 7),
|
||||
segment("d.json", 0, "Dan", 6, 8),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
if len(got.OverlapGroups) != 2 {
|
||||
t.Fatalf("group count = %d, want 2", len(got.OverlapGroups))
|
||||
}
|
||||
assertGroup(t, got, 0, 1, 1, 4, []string{"a.json#0", "b.json#0"}, []string{"Alice", "Bob"})
|
||||
assertGroup(t, got, 1, 2, 5, 8, []string{"c.json#0", "d.json#0"}, []string{"Carol", "Dan"})
|
||||
assertSegmentGroupIDs(t, got, []int{1, 1, 2, 2})
|
||||
}
|
||||
|
||||
func TestDetectIsIdempotent(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 3),
|
||||
segment("b.json", 0, "Bob", 2, 4),
|
||||
},
|
||||
OverlapGroups: []model.OverlapGroup{
|
||||
{ID: 99},
|
||||
},
|
||||
}
|
||||
merged.Segments[0].OverlapGroupID = 99
|
||||
|
||||
once := Detect(merged)
|
||||
twice := Detect(once)
|
||||
assertGroup(t, twice, 0, 1, 1, 4, []string{"a.json#0", "b.json#0"}, []string{"Alice", "Bob"})
|
||||
assertSegmentGroupIDs(t, twice, []int{1, 1})
|
||||
}
|
||||
|
||||
func segment(source string, sourceIndex int, speaker string, start float64, end float64) model.Segment {
|
||||
return model.Segment{
|
||||
Source: source,
|
||||
SourceSegmentIndex: sourceIndex,
|
||||
Speaker: speaker,
|
||||
Start: start,
|
||||
End: end,
|
||||
Text: speaker,
|
||||
}
|
||||
}
|
||||
|
||||
func assertGroup(t *testing.T, merged model.MergedTranscript, groupIndex int, id int, start float64, end float64, refs []string, speakers []string) {
|
||||
t.Helper()
|
||||
if len(merged.OverlapGroups) <= groupIndex {
|
||||
t.Fatalf("missing overlap group index %d", groupIndex)
|
||||
}
|
||||
group := merged.OverlapGroups[groupIndex]
|
||||
if group.ID != id {
|
||||
t.Fatalf("group ID = %d, want %d", group.ID, id)
|
||||
}
|
||||
if group.Start != start {
|
||||
t.Fatalf("group start = %f, want %f", group.Start, start)
|
||||
}
|
||||
if group.End != end {
|
||||
t.Fatalf("group end = %f, want %f", group.End, end)
|
||||
}
|
||||
if group.Class != "unknown" {
|
||||
t.Fatalf("group class = %q, want unknown", group.Class)
|
||||
}
|
||||
if group.Resolution != "unresolved" {
|
||||
t.Fatalf("group resolution = %q, want unresolved", group.Resolution)
|
||||
}
|
||||
if !equalStrings(group.Segments, refs) {
|
||||
t.Fatalf("group refs = %v, want %v", group.Segments, refs)
|
||||
}
|
||||
if !equalStrings(group.Speakers, speakers) {
|
||||
t.Fatalf("group speakers = %v, want %v", group.Speakers, speakers)
|
||||
}
|
||||
}
|
||||
|
||||
func assertSegmentGroupIDs(t *testing.T, merged model.MergedTranscript, ids []int) {
|
||||
t.Helper()
|
||||
if len(merged.Segments) != len(ids) {
|
||||
t.Fatalf("segment count = %d, want %d", len(merged.Segments), len(ids))
|
||||
}
|
||||
for index, id := range ids {
|
||||
if merged.Segments[index].OverlapGroupID != id {
|
||||
t.Fatalf("segment %d overlap group ID = %d, want %d", index, merged.Segments[index].OverlapGroupID, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertNoSegmentAnnotations(t *testing.T, merged model.MergedTranscript) {
|
||||
t.Helper()
|
||||
for index, segment := range merged.Segments {
|
||||
if segment.OverlapGroupID != 0 {
|
||||
t.Fatalf("segment %d overlap group ID = %d, want 0", index, segment.OverlapGroupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func equalStrings(left []string, right []string) bool {
|
||||
if len(left) != len(right) {
|
||||
return false
|
||||
}
|
||||
for index := range left {
|
||||
if left[index] != right[index] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user