204 lines
5.7 KiB
Go
204 lines
5.7 KiB
Go
package overlap
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"gitea.maximumdirect.net/eric/seriatim/internal/model"
|
|
)
|
|
|
|
func TestDetectNoOverlaps(t *testing.T) {
|
|
merged := model.MergedTranscript{
|
|
Segments: []model.Segment{
|
|
segment("a.json", 0, "Alice", 1, 2),
|
|
segment("b.json", 0, "Bob", 2, 3),
|
|
},
|
|
}
|
|
|
|
got := Detect(merged)
|
|
if len(got.OverlapGroups) != 0 {
|
|
t.Fatalf("expected no overlap groups, got %d", len(got.OverlapGroups))
|
|
}
|
|
assertNoSegmentAnnotations(t, got)
|
|
}
|
|
|
|
func TestDetectSimpleTwoSpeakerOverlap(t *testing.T) {
|
|
merged := model.MergedTranscript{
|
|
Segments: []model.Segment{
|
|
segment("a.json", 0, "Alice", 1, 3),
|
|
segment("b.json", 0, "Bob", 2, 4),
|
|
},
|
|
}
|
|
|
|
got := Detect(merged)
|
|
assertGroup(t, got, 0, 1, 1, 4, []string{"a.json#0", "b.json#0"}, []string{"Alice", "Bob"})
|
|
assertSegmentGroupIDs(t, got, []int{1, 1})
|
|
}
|
|
|
|
func TestDetectTransitiveOverlap(t *testing.T) {
|
|
merged := model.MergedTranscript{
|
|
Segments: []model.Segment{
|
|
segment("a.json", 0, "Alice", 10, 14),
|
|
segment("b.json", 0, "Bob", 12, 13),
|
|
segment("c.json", 0, "Carol", 13.5, 15),
|
|
},
|
|
}
|
|
|
|
got := Detect(merged)
|
|
assertGroup(t, got, 0, 1, 10, 15, []string{"a.json#0", "b.json#0", "c.json#0"}, []string{"Alice", "Bob", "Carol"})
|
|
assertSegmentGroupIDs(t, got, []int{1, 1, 1})
|
|
}
|
|
|
|
func TestDetectBoundaryContactDoesNotOverlap(t *testing.T) {
|
|
merged := model.MergedTranscript{
|
|
Segments: []model.Segment{
|
|
segment("a.json", 0, "Alice", 1, 2),
|
|
segment("b.json", 0, "Bob", 2, 3),
|
|
},
|
|
}
|
|
|
|
got := Detect(merged)
|
|
if len(got.OverlapGroups) != 0 {
|
|
t.Fatalf("expected no overlap groups, got %d", len(got.OverlapGroups))
|
|
}
|
|
assertNoSegmentAnnotations(t, got)
|
|
}
|
|
|
|
func TestDetectSameSpeakerOnlyOverlapDoesNotCreateGroup(t *testing.T) {
|
|
merged := model.MergedTranscript{
|
|
Segments: []model.Segment{
|
|
segment("a.json", 0, "Alice", 1, 3),
|
|
segment("a.json", 1, "Alice", 2, 4),
|
|
},
|
|
}
|
|
|
|
got := Detect(merged)
|
|
if len(got.OverlapGroups) != 0 {
|
|
t.Fatalf("expected no overlap groups, got %d", len(got.OverlapGroups))
|
|
}
|
|
assertNoSegmentAnnotations(t, got)
|
|
}
|
|
|
|
func TestDetectIncludesSameSpeakerSegmentInsideMultiSpeakerOverlap(t *testing.T) {
|
|
merged := model.MergedTranscript{
|
|
Segments: []model.Segment{
|
|
segment("a.json", 0, "Alice", 1, 5),
|
|
segment("a.json", 1, "Alice", 2, 3),
|
|
segment("b.json", 0, "Bob", 4, 6),
|
|
},
|
|
}
|
|
|
|
got := Detect(merged)
|
|
assertGroup(t, got, 0, 1, 1, 6, []string{"a.json#0", "a.json#1", "b.json#0"}, []string{"Alice", "Bob"})
|
|
assertSegmentGroupIDs(t, got, []int{1, 1, 1})
|
|
}
|
|
|
|
func TestDetectMultipleGroups(t *testing.T) {
|
|
merged := model.MergedTranscript{
|
|
Segments: []model.Segment{
|
|
segment("a.json", 0, "Alice", 1, 3),
|
|
segment("b.json", 0, "Bob", 2, 4),
|
|
segment("c.json", 0, "Carol", 5, 7),
|
|
segment("d.json", 0, "Dan", 6, 8),
|
|
},
|
|
}
|
|
|
|
got := Detect(merged)
|
|
if len(got.OverlapGroups) != 2 {
|
|
t.Fatalf("group count = %d, want 2", len(got.OverlapGroups))
|
|
}
|
|
assertGroup(t, got, 0, 1, 1, 4, []string{"a.json#0", "b.json#0"}, []string{"Alice", "Bob"})
|
|
assertGroup(t, got, 1, 2, 5, 8, []string{"c.json#0", "d.json#0"}, []string{"Carol", "Dan"})
|
|
assertSegmentGroupIDs(t, got, []int{1, 1, 2, 2})
|
|
}
|
|
|
|
func TestDetectIsIdempotent(t *testing.T) {
|
|
merged := model.MergedTranscript{
|
|
Segments: []model.Segment{
|
|
segment("a.json", 0, "Alice", 1, 3),
|
|
segment("b.json", 0, "Bob", 2, 4),
|
|
},
|
|
OverlapGroups: []model.OverlapGroup{
|
|
{ID: 99},
|
|
},
|
|
}
|
|
merged.Segments[0].OverlapGroupID = 99
|
|
|
|
once := Detect(merged)
|
|
twice := Detect(once)
|
|
assertGroup(t, twice, 0, 1, 1, 4, []string{"a.json#0", "b.json#0"}, []string{"Alice", "Bob"})
|
|
assertSegmentGroupIDs(t, twice, []int{1, 1})
|
|
}
|
|
|
|
func segment(source string, sourceIndex int, speaker string, start float64, end float64) model.Segment {
|
|
return model.Segment{
|
|
Source: source,
|
|
SourceSegmentIndex: sourceIndex,
|
|
Speaker: speaker,
|
|
Start: start,
|
|
End: end,
|
|
Text: speaker,
|
|
}
|
|
}
|
|
|
|
func assertGroup(t *testing.T, merged model.MergedTranscript, groupIndex int, id int, start float64, end float64, refs []string, speakers []string) {
|
|
t.Helper()
|
|
if len(merged.OverlapGroups) <= groupIndex {
|
|
t.Fatalf("missing overlap group index %d", groupIndex)
|
|
}
|
|
group := merged.OverlapGroups[groupIndex]
|
|
if group.ID != id {
|
|
t.Fatalf("group ID = %d, want %d", group.ID, id)
|
|
}
|
|
if group.Start != start {
|
|
t.Fatalf("group start = %f, want %f", group.Start, start)
|
|
}
|
|
if group.End != end {
|
|
t.Fatalf("group end = %f, want %f", group.End, end)
|
|
}
|
|
if group.Class != "unknown" {
|
|
t.Fatalf("group class = %q, want unknown", group.Class)
|
|
}
|
|
if group.Resolution != "unresolved" {
|
|
t.Fatalf("group resolution = %q, want unresolved", group.Resolution)
|
|
}
|
|
if !equalStrings(group.Segments, refs) {
|
|
t.Fatalf("group refs = %v, want %v", group.Segments, refs)
|
|
}
|
|
if !equalStrings(group.Speakers, speakers) {
|
|
t.Fatalf("group speakers = %v, want %v", group.Speakers, speakers)
|
|
}
|
|
}
|
|
|
|
func assertSegmentGroupIDs(t *testing.T, merged model.MergedTranscript, ids []int) {
|
|
t.Helper()
|
|
if len(merged.Segments) != len(ids) {
|
|
t.Fatalf("segment count = %d, want %d", len(merged.Segments), len(ids))
|
|
}
|
|
for index, id := range ids {
|
|
if merged.Segments[index].OverlapGroupID != id {
|
|
t.Fatalf("segment %d overlap group ID = %d, want %d", index, merged.Segments[index].OverlapGroupID, id)
|
|
}
|
|
}
|
|
}
|
|
|
|
func assertNoSegmentAnnotations(t *testing.T, merged model.MergedTranscript) {
|
|
t.Helper()
|
|
for index, segment := range merged.Segments {
|
|
if segment.OverlapGroupID != 0 {
|
|
t.Fatalf("segment %d overlap group ID = %d, want 0", index, segment.OverlapGroupID)
|
|
}
|
|
}
|
|
}
|
|
|
|
// equalStrings reports whether left and right have the same length and the
// same elements in the same order. A nil slice and an empty slice compare
// equal.
func equalStrings(left []string, right []string) bool {
	if len(right) != len(left) {
		return false
	}
	for i, l := range left {
		if right[i] != l {
			return false
		}
	}
	return true
}
|