Implemented an overlap detection module in the postprocessing chain
This commit is contained in:
118
internal/overlap/detect.go
Normal file
118
internal/overlap/detect.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package overlap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/model"
|
||||
)
|
||||
|
||||
const (
	// defaultClass is the Class value given to every overlap group this
	// package creates.
	defaultClass = "unknown"

	// defaultResolution is the Resolution value given to every overlap group
	// this package creates.
	defaultResolution = "unresolved"
)
|
||||
|
||||
// Detect annotates overlapping segment groups in an already sorted merged transcript.
|
||||
func Detect(in model.MergedTranscript) model.MergedTranscript {
|
||||
clearExisting(&in)
|
||||
if len(in.Segments) < 2 {
|
||||
return in
|
||||
}
|
||||
|
||||
var groupID int
|
||||
var candidate overlapCandidate
|
||||
for index := range in.Segments {
|
||||
segment := in.Segments[index]
|
||||
if !candidate.active {
|
||||
candidate = newCandidate(index, segment)
|
||||
continue
|
||||
}
|
||||
|
||||
if segment.Start < candidate.end {
|
||||
candidate.add(index, segment)
|
||||
continue
|
||||
}
|
||||
|
||||
groupID = finalizeCandidate(&in, candidate, groupID)
|
||||
candidate = newCandidate(index, segment)
|
||||
}
|
||||
|
||||
finalizeCandidate(&in, candidate, groupID)
|
||||
return in
|
||||
}
|
||||
|
||||
// overlapCandidate collects a run of segments that may form an overlap group.
type overlapCandidate struct {
	// active is true for candidates built by newCandidate; finalizeCandidate
	// ignores a candidate whose active flag is false (the zero value).
	active bool

	// indices holds the positions of the member segments in the transcript's
	// segment slice, in the order they were added.
	indices []int

	// start is the Start of the segment that opened the candidate; end is the
	// largest End among the segments added so far.
	start, end float64
}
|
||||
|
||||
func newCandidate(index int, segment model.Segment) overlapCandidate {
|
||||
return overlapCandidate{
|
||||
active: true,
|
||||
indices: []int{index},
|
||||
start: segment.Start,
|
||||
end: segment.End,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *overlapCandidate) add(index int, segment model.Segment) {
|
||||
c.indices = append(c.indices, index)
|
||||
if segment.End > c.end {
|
||||
c.end = segment.End
|
||||
}
|
||||
}
|
||||
|
||||
func finalizeCandidate(in *model.MergedTranscript, candidate overlapCandidate, currentGroupID int) int {
|
||||
if !candidate.active || len(candidate.indices) < 2 {
|
||||
return currentGroupID
|
||||
}
|
||||
|
||||
speakers := distinctSpeakers(in.Segments, candidate.indices)
|
||||
if len(speakers) < 2 {
|
||||
return currentGroupID
|
||||
}
|
||||
|
||||
groupID := currentGroupID + 1
|
||||
refs := make([]string, 0, len(candidate.indices))
|
||||
for _, index := range candidate.indices {
|
||||
in.Segments[index].OverlapGroupID = groupID
|
||||
refs = append(refs, segmentRef(in.Segments[index]))
|
||||
}
|
||||
|
||||
in.OverlapGroups = append(in.OverlapGroups, model.OverlapGroup{
|
||||
ID: groupID,
|
||||
Start: candidate.start,
|
||||
End: candidate.end,
|
||||
Segments: refs,
|
||||
Speakers: speakers,
|
||||
Class: defaultClass,
|
||||
Resolution: defaultResolution,
|
||||
})
|
||||
return groupID
|
||||
}
|
||||
|
||||
func distinctSpeakers(segments []model.Segment, indices []int) []string {
|
||||
seen := make(map[string]struct{}, len(indices))
|
||||
speakers := make([]string, 0, len(indices))
|
||||
for _, index := range indices {
|
||||
speaker := segments[index].Speaker
|
||||
if _, exists := seen[speaker]; exists {
|
||||
continue
|
||||
}
|
||||
seen[speaker] = struct{}{}
|
||||
speakers = append(speakers, speaker)
|
||||
}
|
||||
return speakers
|
||||
}
|
||||
|
||||
func segmentRef(segment model.Segment) string {
|
||||
return fmt.Sprintf("%s#%d", segment.Source, segment.SourceSegmentIndex)
|
||||
}
|
||||
|
||||
func clearExisting(in *model.MergedTranscript) {
|
||||
in.OverlapGroups = make([]model.OverlapGroup, 0)
|
||||
for index := range in.Segments {
|
||||
in.Segments[index].OverlapGroupID = 0
|
||||
}
|
||||
}
|
||||
203
internal/overlap/detect_test.go
Normal file
203
internal/overlap/detect_test.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package overlap
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/model"
|
||||
)
|
||||
|
||||
func TestDetectNoOverlaps(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 2),
|
||||
segment("b.json", 0, "Bob", 2, 3),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
if len(got.OverlapGroups) != 0 {
|
||||
t.Fatalf("expected no overlap groups, got %d", len(got.OverlapGroups))
|
||||
}
|
||||
assertNoSegmentAnnotations(t, got)
|
||||
}
|
||||
|
||||
func TestDetectSimpleTwoSpeakerOverlap(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 3),
|
||||
segment("b.json", 0, "Bob", 2, 4),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
assertGroup(t, got, 0, 1, 1, 4, []string{"a.json#0", "b.json#0"}, []string{"Alice", "Bob"})
|
||||
assertSegmentGroupIDs(t, got, []int{1, 1})
|
||||
}
|
||||
|
||||
func TestDetectTransitiveOverlap(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 10, 14),
|
||||
segment("b.json", 0, "Bob", 12, 13),
|
||||
segment("c.json", 0, "Carol", 13.5, 15),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
assertGroup(t, got, 0, 1, 10, 15, []string{"a.json#0", "b.json#0", "c.json#0"}, []string{"Alice", "Bob", "Carol"})
|
||||
assertSegmentGroupIDs(t, got, []int{1, 1, 1})
|
||||
}
|
||||
|
||||
func TestDetectBoundaryContactDoesNotOverlap(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 2),
|
||||
segment("b.json", 0, "Bob", 2, 3),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
if len(got.OverlapGroups) != 0 {
|
||||
t.Fatalf("expected no overlap groups, got %d", len(got.OverlapGroups))
|
||||
}
|
||||
assertNoSegmentAnnotations(t, got)
|
||||
}
|
||||
|
||||
func TestDetectSameSpeakerOnlyOverlapDoesNotCreateGroup(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 3),
|
||||
segment("a.json", 1, "Alice", 2, 4),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
if len(got.OverlapGroups) != 0 {
|
||||
t.Fatalf("expected no overlap groups, got %d", len(got.OverlapGroups))
|
||||
}
|
||||
assertNoSegmentAnnotations(t, got)
|
||||
}
|
||||
|
||||
func TestDetectIncludesSameSpeakerSegmentInsideMultiSpeakerOverlap(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 5),
|
||||
segment("a.json", 1, "Alice", 2, 3),
|
||||
segment("b.json", 0, "Bob", 4, 6),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
assertGroup(t, got, 0, 1, 1, 6, []string{"a.json#0", "a.json#1", "b.json#0"}, []string{"Alice", "Bob"})
|
||||
assertSegmentGroupIDs(t, got, []int{1, 1, 1})
|
||||
}
|
||||
|
||||
func TestDetectMultipleGroups(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 3),
|
||||
segment("b.json", 0, "Bob", 2, 4),
|
||||
segment("c.json", 0, "Carol", 5, 7),
|
||||
segment("d.json", 0, "Dan", 6, 8),
|
||||
},
|
||||
}
|
||||
|
||||
got := Detect(merged)
|
||||
if len(got.OverlapGroups) != 2 {
|
||||
t.Fatalf("group count = %d, want 2", len(got.OverlapGroups))
|
||||
}
|
||||
assertGroup(t, got, 0, 1, 1, 4, []string{"a.json#0", "b.json#0"}, []string{"Alice", "Bob"})
|
||||
assertGroup(t, got, 1, 2, 5, 8, []string{"c.json#0", "d.json#0"}, []string{"Carol", "Dan"})
|
||||
assertSegmentGroupIDs(t, got, []int{1, 1, 2, 2})
|
||||
}
|
||||
|
||||
func TestDetectIsIdempotent(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segment("a.json", 0, "Alice", 1, 3),
|
||||
segment("b.json", 0, "Bob", 2, 4),
|
||||
},
|
||||
OverlapGroups: []model.OverlapGroup{
|
||||
{ID: 99},
|
||||
},
|
||||
}
|
||||
merged.Segments[0].OverlapGroupID = 99
|
||||
|
||||
once := Detect(merged)
|
||||
twice := Detect(once)
|
||||
assertGroup(t, twice, 0, 1, 1, 4, []string{"a.json#0", "b.json#0"}, []string{"Alice", "Bob"})
|
||||
assertSegmentGroupIDs(t, twice, []int{1, 1})
|
||||
}
|
||||
|
||||
func segment(source string, sourceIndex int, speaker string, start float64, end float64) model.Segment {
|
||||
return model.Segment{
|
||||
Source: source,
|
||||
SourceSegmentIndex: sourceIndex,
|
||||
Speaker: speaker,
|
||||
Start: start,
|
||||
End: end,
|
||||
Text: speaker,
|
||||
}
|
||||
}
|
||||
|
||||
func assertGroup(t *testing.T, merged model.MergedTranscript, groupIndex int, id int, start float64, end float64, refs []string, speakers []string) {
|
||||
t.Helper()
|
||||
if len(merged.OverlapGroups) <= groupIndex {
|
||||
t.Fatalf("missing overlap group index %d", groupIndex)
|
||||
}
|
||||
group := merged.OverlapGroups[groupIndex]
|
||||
if group.ID != id {
|
||||
t.Fatalf("group ID = %d, want %d", group.ID, id)
|
||||
}
|
||||
if group.Start != start {
|
||||
t.Fatalf("group start = %f, want %f", group.Start, start)
|
||||
}
|
||||
if group.End != end {
|
||||
t.Fatalf("group end = %f, want %f", group.End, end)
|
||||
}
|
||||
if group.Class != "unknown" {
|
||||
t.Fatalf("group class = %q, want unknown", group.Class)
|
||||
}
|
||||
if group.Resolution != "unresolved" {
|
||||
t.Fatalf("group resolution = %q, want unresolved", group.Resolution)
|
||||
}
|
||||
if !equalStrings(group.Segments, refs) {
|
||||
t.Fatalf("group refs = %v, want %v", group.Segments, refs)
|
||||
}
|
||||
if !equalStrings(group.Speakers, speakers) {
|
||||
t.Fatalf("group speakers = %v, want %v", group.Speakers, speakers)
|
||||
}
|
||||
}
|
||||
|
||||
func assertSegmentGroupIDs(t *testing.T, merged model.MergedTranscript, ids []int) {
|
||||
t.Helper()
|
||||
if len(merged.Segments) != len(ids) {
|
||||
t.Fatalf("segment count = %d, want %d", len(merged.Segments), len(ids))
|
||||
}
|
||||
for index, id := range ids {
|
||||
if merged.Segments[index].OverlapGroupID != id {
|
||||
t.Fatalf("segment %d overlap group ID = %d, want %d", index, merged.Segments[index].OverlapGroupID, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertNoSegmentAnnotations(t *testing.T, merged model.MergedTranscript) {
|
||||
t.Helper()
|
||||
for index, segment := range merged.Segments {
|
||||
if segment.OverlapGroupID != 0 {
|
||||
t.Fatalf("segment %d overlap group ID = %d, want 0", index, segment.OverlapGroupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// equalStrings reports whether left and right have the same length and the
// same elements in the same order. A nil slice and an empty slice are equal.
func equalStrings(left []string, right []string) bool {
	if len(left) != len(right) {
		return false
	}
	for i, value := range left {
		if value != right[i] {
			return false
		}
	}
	return true
}
|
||||
Reference in New Issue
Block a user