Added a module to coalesce adjacent same-speaker segments
This commit is contained in:
@@ -2,6 +2,7 @@ package overlap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/model"
|
||||
)
|
||||
@@ -18,9 +19,10 @@ func Detect(in model.MergedTranscript) model.MergedTranscript {
|
||||
return in
|
||||
}
|
||||
|
||||
order := sortedSegmentIndices(in.Segments)
|
||||
var groupID int
|
||||
var candidate overlapCandidate
|
||||
for index := range in.Segments {
|
||||
for _, index := range order {
|
||||
segment := in.Segments[index]
|
||||
if !candidate.active {
|
||||
candidate = newCandidate(index, segment)
|
||||
@@ -40,6 +42,17 @@ func Detect(in model.MergedTranscript) model.MergedTranscript {
|
||||
return in
|
||||
}
|
||||
|
||||
func sortedSegmentIndices(segments []model.Segment) []int {
|
||||
indices := make([]int, len(segments))
|
||||
for index := range segments {
|
||||
indices[index] = index
|
||||
}
|
||||
sort.SliceStable(indices, func(i, j int) bool {
|
||||
return model.SegmentLess(segments[indices[i]], segments[indices[j]])
|
||||
})
|
||||
return indices
|
||||
}
|
||||
|
||||
type overlapCandidate struct {
|
||||
active bool
|
||||
indices []int
|
||||
|
||||
@@ -18,7 +18,7 @@ type ResolutionSummary struct {
|
||||
|
||||
// Resolve replaces detected overlap-group segments with word-run segments when
|
||||
// word-level timing is available.
|
||||
func Resolve(in model.MergedTranscript, wordRunGap float64) (model.MergedTranscript, ResolutionSummary, error) {
|
||||
func Resolve(in model.MergedTranscript, wordRunGap float64, wordRunReorderWindow float64) (model.MergedTranscript, ResolutionSummary, error) {
|
||||
summary := ResolutionSummary{
|
||||
GroupsProcessed: len(in.OverlapGroups),
|
||||
}
|
||||
@@ -35,9 +35,10 @@ func Resolve(in model.MergedTranscript, wordRunGap float64) (model.MergedTranscr
|
||||
clearAnnotationRefs := make(map[string]struct{})
|
||||
removeGroupIDs := make(map[int]struct{})
|
||||
replacements := make([]model.Segment, 0)
|
||||
replacementOrder := make(map[string]replacementOrder)
|
||||
|
||||
for _, group := range in.OverlapGroups {
|
||||
resolved, err := resolveGroup(in, group, refToIndex, wordRunGap)
|
||||
resolved, err := resolveGroup(in, group, refToIndex, wordRunGap, wordRunReorderWindow)
|
||||
if err != nil {
|
||||
return model.MergedTranscript{}, ResolutionSummary{}, err
|
||||
}
|
||||
@@ -48,6 +49,9 @@ func Resolve(in model.MergedTranscript, wordRunGap float64) (model.MergedTranscr
|
||||
summary.GroupsChanged++
|
||||
removeGroupIDs[group.ID] = struct{}{}
|
||||
replacements = append(replacements, resolved.replacements...)
|
||||
for sourceRef, order := range resolved.replacementOrder {
|
||||
replacementOrder[sourceRef] = order
|
||||
}
|
||||
|
||||
for _, ref := range group.Segments {
|
||||
clearAnnotationRefs[ref] = struct{}{}
|
||||
@@ -78,7 +82,7 @@ func Resolve(in model.MergedTranscript, wordRunGap float64) (model.MergedTranscr
|
||||
}
|
||||
segments = append(segments, replacements...)
|
||||
sort.SliceStable(segments, func(i, j int) bool {
|
||||
return model.SegmentLess(segments[i], segments[j])
|
||||
return resolvedSegmentLess(segments[i], segments[j], replacementOrder)
|
||||
})
|
||||
|
||||
overlapGroups := make([]model.OverlapGroup, 0, len(in.OverlapGroups)-len(removeGroupIDs))
|
||||
@@ -96,8 +100,15 @@ func Resolve(in model.MergedTranscript, wordRunGap float64) (model.MergedTranscr
|
||||
}
|
||||
|
||||
type resolvedGroup struct {
|
||||
removeRefs []string
|
||||
replacements []model.Segment
|
||||
removeRefs []string
|
||||
replacements []model.Segment
|
||||
replacementOrder map[string]replacementOrder
|
||||
}
|
||||
|
||||
// replacementOrder is the ordering metadata attached to one replacement
// segment. Values are built by reorderReplacementSegments and consulted by
// resolvedSegmentLess.
type replacementOrder struct {
	// cluster is the key of the cluster the segment belongs to, formatted as
	// "<groupID>:<clusterIndex>".
	cluster string
	// rank is the segment's zero-based position inside its cluster after the
	// cluster has been sorted.
	rank int
	// anchor is the Start of the cluster's first segment, captured before the
	// cluster is re-sorted by duration.
	anchor float64
}
|
||||
|
||||
type resolutionWord struct {
|
||||
@@ -114,7 +125,7 @@ type wordRun struct {
|
||||
end float64
|
||||
}
|
||||
|
||||
func resolveGroup(in model.MergedTranscript, group model.OverlapGroup, refToIndex map[string]int, wordRunGap float64) (resolvedGroup, error) {
|
||||
func resolveGroup(in model.MergedTranscript, group model.OverlapGroup, refToIndex map[string]int, wordRunGap float64, wordRunReorderWindow float64) (resolvedGroup, error) {
|
||||
segmentsBySpeaker := make(map[string][]model.Segment)
|
||||
refsBySpeaker := make(map[string][]string)
|
||||
for _, ref := range group.Segments {
|
||||
@@ -147,9 +158,76 @@ func resolveGroup(in model.MergedTranscript, group model.OverlapGroup, refToInde
|
||||
}
|
||||
}
|
||||
|
||||
resolved.replacements, resolved.replacementOrder = reorderReplacementSegments(group.ID, resolved.replacements, wordRunReorderWindow)
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
func reorderReplacementSegments(groupID int, replacements []model.Segment, wordRunReorderWindow float64) ([]model.Segment, map[string]replacementOrder) {
|
||||
if len(replacements) == 0 {
|
||||
return replacements, nil
|
||||
}
|
||||
|
||||
ordered := append([]model.Segment(nil), replacements...)
|
||||
sort.SliceStable(ordered, func(i, j int) bool {
|
||||
return model.SegmentLess(ordered[i], ordered[j])
|
||||
})
|
||||
|
||||
ranks := make(map[string]replacementOrder, len(ordered))
|
||||
clusterStart := 0
|
||||
clusterIndex := 1
|
||||
for clusterStart < len(ordered) {
|
||||
clusterEnd := clusterStart + 1
|
||||
for clusterEnd < len(ordered) && ordered[clusterEnd].Start-ordered[clusterEnd-1].Start <= wordRunReorderWindow {
|
||||
clusterEnd++
|
||||
}
|
||||
|
||||
cluster := ordered[clusterStart:clusterEnd]
|
||||
anchor := cluster[0].Start
|
||||
sort.SliceStable(cluster, func(i, j int) bool {
|
||||
leftDuration := cluster[i].End - cluster[i].Start
|
||||
rightDuration := cluster[j].End - cluster[j].Start
|
||||
if leftDuration != rightDuration {
|
||||
return leftDuration < rightDuration
|
||||
}
|
||||
return model.SegmentLess(cluster[i], cluster[j])
|
||||
})
|
||||
|
||||
clusterKey := fmt.Sprintf("%d:%d", groupID, clusterIndex)
|
||||
for index := range cluster {
|
||||
ranks[cluster[index].SourceRef] = replacementOrder{
|
||||
cluster: clusterKey,
|
||||
rank: index,
|
||||
anchor: anchor,
|
||||
}
|
||||
}
|
||||
|
||||
clusterStart = clusterEnd
|
||||
clusterIndex++
|
||||
}
|
||||
|
||||
return ordered, ranks
|
||||
}
|
||||
|
||||
func resolvedSegmentLess(left model.Segment, right model.Segment, replacementOrder map[string]replacementOrder) bool {
|
||||
leftOrder, leftHasOrder := replacementOrder[left.SourceRef]
|
||||
rightOrder, rightHasOrder := replacementOrder[right.SourceRef]
|
||||
if leftHasOrder && rightHasOrder && leftOrder.cluster == rightOrder.cluster && leftOrder.rank != rightOrder.rank {
|
||||
return leftOrder.rank < rightOrder.rank
|
||||
}
|
||||
leftStart := left.Start
|
||||
if leftHasOrder {
|
||||
leftStart = leftOrder.anchor
|
||||
}
|
||||
rightStart := right.Start
|
||||
if rightHasOrder {
|
||||
rightStart = rightOrder.anchor
|
||||
}
|
||||
if leftStart != rightStart {
|
||||
return leftStart < rightStart
|
||||
}
|
||||
return model.SegmentLess(left, right)
|
||||
}
|
||||
|
||||
func groupSpeakerOrder(group model.OverlapGroup, segmentsBySpeaker map[string][]model.Segment) []string {
|
||||
seen := make(map[string]struct{}, len(group.Speakers))
|
||||
speakers := make([]string, 0, len(group.Speakers))
|
||||
|
||||
@@ -2,6 +2,7 @@ package overlap
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/model"
|
||||
@@ -14,7 +15,7 @@ func TestResolveNoOverlapGroupsIsNoOp(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
got, summary, err := Resolve(merged, 0.75)
|
||||
got, summary, err := Resolve(merged, 0.75, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
@@ -39,7 +40,7 @@ func TestResolveCreatesChronologicalWordRunSegments(t *testing.T) {
|
||||
merged.Segments[0].OverlapGroupID = 1
|
||||
merged.Segments[1].OverlapGroupID = 1
|
||||
|
||||
got, summary, err := Resolve(merged, 0.75)
|
||||
got, summary, err := Resolve(merged, 0.75, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
@@ -93,7 +94,7 @@ func TestResolveIncludesWordsByIntervalIntersection(t *testing.T) {
|
||||
}
|
||||
merged.Segments[0].OverlapGroupID = 1
|
||||
|
||||
got, _, err := Resolve(merged, 10)
|
||||
got, _, err := Resolve(merged, 10, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
@@ -116,7 +117,7 @@ func TestResolveWordRunGapThreshold(t *testing.T) {
|
||||
}
|
||||
merged.Segments[0].OverlapGroupID = 1
|
||||
|
||||
got, _, err := Resolve(merged, 0.75)
|
||||
got, _, err := Resolve(merged, 0.75, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
@@ -141,7 +142,7 @@ func TestResolvePartialResolutionKeepsNoWordSpeakerOriginals(t *testing.T) {
|
||||
merged.Segments[0].OverlapGroupID = 1
|
||||
merged.Segments[1].OverlapGroupID = 1
|
||||
|
||||
got, summary, err := Resolve(merged, 0.75)
|
||||
got, summary, err := Resolve(merged, 0.75, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
@@ -178,7 +179,7 @@ func TestResolveGroupWithNoUsableWordsRemainsUnchanged(t *testing.T) {
|
||||
merged.Segments[0].OverlapGroupID = 1
|
||||
merged.Segments[1].OverlapGroupID = 1
|
||||
|
||||
got, summary, err := Resolve(merged, 0.75)
|
||||
got, summary, err := Resolve(merged, 0.75, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
@@ -201,7 +202,7 @@ func TestResolveReplacementProvenanceIsDeterministic(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
got, _, err := Resolve(merged, 0.75)
|
||||
got, _, err := Resolve(merged, 0.75, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
@@ -238,7 +239,7 @@ func TestResolveIncludesUntimedWordsInTextWithoutChangingBounds(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
got, _, err := Resolve(merged, 0.75)
|
||||
got, _, err := Resolve(merged, 0.75, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
@@ -273,7 +274,7 @@ func TestResolveUntimedWordsDoNotBridgeWordRunGap(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
got, _, err := Resolve(merged, 0.75)
|
||||
got, _, err := Resolve(merged, 0.75, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
@@ -299,7 +300,7 @@ func TestResolveSpeakerWithOnlyUntimedWordsIsNotReplaced(t *testing.T) {
|
||||
}
|
||||
merged.Segments[0].OverlapGroupID = 1
|
||||
|
||||
got, summary, err := Resolve(merged, 0.75)
|
||||
got, summary, err := Resolve(merged, 0.75, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
@@ -311,6 +312,93 @@ func TestResolveSpeakerWithOnlyUntimedWordsIsNotReplaced(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveReordersNearStartWordRunsByDuration(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segmentWithWords("a.json", 0, "Alice", 1, 3, word("long", 1, 2)),
|
||||
segmentWithWords("b.json", 0, "Bob", 1, 3, word("short", 1.2, 1.3)),
|
||||
},
|
||||
OverlapGroups: []model.OverlapGroup{
|
||||
group(1, 1, 3, []string{"a.json#0", "b.json#0"}, []string{"Alice", "Bob"}),
|
||||
},
|
||||
}
|
||||
|
||||
got, _, err := Resolve(merged, 0.75, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
if gotTexts(got.Segments) != "short,long" {
|
||||
t.Fatalf("segment order = %s, want short,long", gotTexts(got.Segments))
|
||||
}
|
||||
if got.Segments[0].Start != 1.2 || got.Segments[0].End != 1.3 {
|
||||
t.Fatalf("short segment bounds changed: %#v", got.Segments[0])
|
||||
}
|
||||
if got.Segments[1].SourceRef != "word-run:1:1:1" || got.Segments[1].Text != "long" {
|
||||
t.Fatalf("long segment provenance/text changed: %#v", got.Segments[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveDoesNotReorderWordRunsOutsideWindow(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segmentWithWords("a.json", 0, "Alice", 1, 3, word("long", 1, 2)),
|
||||
segmentWithWords("b.json", 0, "Bob", 1, 3, word("short", 1.5, 1.6)),
|
||||
},
|
||||
OverlapGroups: []model.OverlapGroup{
|
||||
group(1, 1, 3, []string{"a.json#0", "b.json#0"}, []string{"Alice", "Bob"}),
|
||||
},
|
||||
}
|
||||
|
||||
got, _, err := Resolve(merged, 0.75, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
if gotTexts(got.Segments) != "long,short" {
|
||||
t.Fatalf("segment order = %s, want long,short", gotTexts(got.Segments))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveReordersTransitiveNearStartClustersByDuration(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segmentWithWords("a.json", 0, "Alice", 1, 3, word("long", 1, 2)),
|
||||
segmentWithWords("b.json", 0, "Bob", 1, 3, word("medium", 1.3, 1.8)),
|
||||
segmentWithWords("c.json", 0, "Carol", 1, 3, word("short", 1.65, 1.75)),
|
||||
},
|
||||
OverlapGroups: []model.OverlapGroup{
|
||||
group(1, 1, 3, []string{"a.json#0", "b.json#0", "c.json#0"}, []string{"Alice", "Bob", "Carol"}),
|
||||
},
|
||||
}
|
||||
|
||||
got, _, err := Resolve(merged, 0.75, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
if gotTexts(got.Segments) != "short,medium,long" {
|
||||
t.Fatalf("segment order = %s, want short,medium,long", gotTexts(got.Segments))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveReorderFallsBackToDeterministicOrderForEqualDurations(t *testing.T) {
|
||||
merged := model.MergedTranscript{
|
||||
Segments: []model.Segment{
|
||||
segmentWithWords("b.json", 0, "Bob", 1, 3, word("bob", 1, 1.5)),
|
||||
segmentWithWords("a.json", 0, "Alice", 1, 3, word("alice", 1.2, 1.7)),
|
||||
},
|
||||
OverlapGroups: []model.OverlapGroup{
|
||||
group(1, 1, 3, []string{"b.json#0", "a.json#0"}, []string{"Bob", "Alice"}),
|
||||
},
|
||||
}
|
||||
|
||||
got, _, err := Resolve(merged, 0.75, 0.4)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
if gotTexts(got.Segments) != "bob,alice" {
|
||||
t.Fatalf("segment order = %s, want bob,alice", gotTexts(got.Segments))
|
||||
}
|
||||
}
|
||||
|
||||
func segmentWithWords(source string, sourceIndex int, speaker string, start float64, end float64, words ...model.Word) model.Segment {
|
||||
segment := segment(source, sourceIndex, speaker, start, end)
|
||||
segment.Words = words
|
||||
@@ -326,6 +414,14 @@ func word(text string, start float64, end float64) model.Word {
|
||||
}
|
||||
}
|
||||
|
||||
func gotTexts(segments []model.Segment) string {
|
||||
texts := make([]string, 0, len(segments))
|
||||
for _, segment := range segments {
|
||||
texts = append(texts, segment.Text)
|
||||
}
|
||||
return strings.Join(texts, ",")
|
||||
}
|
||||
|
||||
func untimedWord(text string) model.Word {
|
||||
return model.Word{
|
||||
Text: text,
|
||||
|
||||
Reference in New Issue
Block a user