Implemented an initial transcript merge stage
This commit is contained in:
@@ -2,7 +2,9 @@ package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/config"
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/model"
|
||||
@@ -22,10 +24,92 @@ func (jsonFilesReader) Read(ctx context.Context, cfg config.Config) ([]model.Raw
|
||||
|
||||
raw := make([]model.RawTranscript, 0, len(cfg.InputFiles))
|
||||
for _, inputFile := range cfg.InputFiles {
|
||||
raw = append(raw, model.RawTranscript{Source: inputFile})
|
||||
transcript, err := readRawTranscript(inputFile)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
raw = append(raw, transcript)
|
||||
}
|
||||
|
||||
return raw, []report.Event{
|
||||
report.Info("input", "json-files", fmt.Sprintf("accepted %d input file(s)", len(raw))),
|
||||
report.Info("input", "json-files", fmt.Sprintf("decoded %d input file(s)", len(raw))),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type rawTranscriptFile struct {
|
||||
Segments json.RawMessage `json:"segments"`
|
||||
}
|
||||
|
||||
type rawSegmentFile struct {
|
||||
Start json.RawMessage `json:"start"`
|
||||
End json.RawMessage `json:"end"`
|
||||
Text json.RawMessage `json:"text"`
|
||||
}
|
||||
|
||||
func readRawTranscript(path string) (model.RawTranscript, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return model.RawTranscript{}, fmt.Errorf("read input file %q: %w", path, err)
|
||||
}
|
||||
|
||||
var parsed rawTranscriptFile
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
return model.RawTranscript{}, fmt.Errorf("parse input file %q: %w", path, err)
|
||||
}
|
||||
if parsed.Segments == nil || isJSONNull(parsed.Segments) {
|
||||
return model.RawTranscript{}, fmt.Errorf("input file %q must contain top-level segments array", path)
|
||||
}
|
||||
|
||||
var rawSegments []rawSegmentFile
|
||||
if err := json.Unmarshal(parsed.Segments, &rawSegments); err != nil {
|
||||
return model.RawTranscript{}, fmt.Errorf("input file %q top-level segments must be an array: %w", path, err)
|
||||
}
|
||||
|
||||
segments := make([]model.RawSegment, 0, len(rawSegments))
|
||||
for index, segment := range rawSegments {
|
||||
if segment.Start == nil || isJSONNull(segment.Start) {
|
||||
return model.RawTranscript{}, fmt.Errorf("input file %q segment %d missing numeric start", path, index)
|
||||
}
|
||||
if segment.End == nil || isJSONNull(segment.End) {
|
||||
return model.RawTranscript{}, fmt.Errorf("input file %q segment %d missing numeric end", path, index)
|
||||
}
|
||||
if segment.Text == nil || isJSONNull(segment.Text) {
|
||||
return model.RawTranscript{}, fmt.Errorf("input file %q segment %d missing string text", path, index)
|
||||
}
|
||||
|
||||
var start float64
|
||||
if err := json.Unmarshal(segment.Start, &start); err != nil {
|
||||
return model.RawTranscript{}, fmt.Errorf("input file %q segment %d start must be numeric", path, index)
|
||||
}
|
||||
var end float64
|
||||
if err := json.Unmarshal(segment.End, &end); err != nil {
|
||||
return model.RawTranscript{}, fmt.Errorf("input file %q segment %d end must be numeric", path, index)
|
||||
}
|
||||
var text string
|
||||
if err := json.Unmarshal(segment.Text, &text); err != nil {
|
||||
return model.RawTranscript{}, fmt.Errorf("input file %q segment %d text must be a string", path, index)
|
||||
}
|
||||
|
||||
if start < 0 {
|
||||
return model.RawTranscript{}, fmt.Errorf("input file %q segment %d has negative start", path, index)
|
||||
}
|
||||
if end < start {
|
||||
return model.RawTranscript{}, fmt.Errorf("input file %q segment %d has end before start", path, index)
|
||||
}
|
||||
|
||||
segments = append(segments, model.RawSegment{
|
||||
Start: start,
|
||||
End: end,
|
||||
Text: text,
|
||||
})
|
||||
}
|
||||
|
||||
return model.RawTranscript{
|
||||
Source: path,
|
||||
Segments: segments,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isJSONNull(value json.RawMessage) bool {
|
||||
return string(value) == "null"
|
||||
}
|
||||
|
||||
@@ -3,11 +3,13 @@ package builtin
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/config"
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/model"
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/pipeline"
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/report"
|
||||
"gitea.maximumdirect.net/eric/seriatim/internal/speaker"
|
||||
)
|
||||
|
||||
type noopPreprocessor struct {
|
||||
@@ -42,6 +44,39 @@ func (p noopPreprocessor) Process(ctx context.Context, in pipeline.PreprocessSta
|
||||
}, nil
|
||||
}
|
||||
|
||||
type trimText struct{}
|
||||
|
||||
func (trimText) Name() string {
|
||||
return "trim-text"
|
||||
}
|
||||
|
||||
func (trimText) Requires() pipeline.ModelState {
|
||||
return pipeline.StateCanonical
|
||||
}
|
||||
|
||||
func (trimText) Produces() pipeline.ModelState {
|
||||
return pipeline.StateCanonical
|
||||
}
|
||||
|
||||
func (trimText) Process(ctx context.Context, in pipeline.PreprocessState, cfg config.Config) (pipeline.PreprocessState, []report.Event, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return pipeline.PreprocessState{}, nil, err
|
||||
}
|
||||
if in.State != pipeline.StateCanonical {
|
||||
return pipeline.PreprocessState{}, nil, fmt.Errorf("preprocessing module %q requires state %q but received %q", "trim-text", pipeline.StateCanonical, in.State)
|
||||
}
|
||||
|
||||
for transcriptIndex := range in.Canonical {
|
||||
for segmentIndex := range in.Canonical[transcriptIndex].Segments {
|
||||
in.Canonical[transcriptIndex].Segments[segmentIndex].Text = strings.TrimSpace(in.Canonical[transcriptIndex].Segments[segmentIndex].Text)
|
||||
}
|
||||
}
|
||||
|
||||
return in, []report.Event{
|
||||
report.Info("preprocessing", "trim-text", "trimmed canonical segment text"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type normalizeSpeakers struct{}
|
||||
|
||||
func (normalizeSpeakers) Name() string {
|
||||
@@ -64,11 +99,33 @@ func (normalizeSpeakers) Process(ctx context.Context, in pipeline.PreprocessStat
|
||||
return pipeline.PreprocessState{}, nil, fmt.Errorf("preprocessing module %q requires state %q but received %q", "normalize-speakers", pipeline.StateRaw, in.State)
|
||||
}
|
||||
|
||||
speakers, err := speaker.LoadMap(cfg.SpeakersFile)
|
||||
if err != nil {
|
||||
return pipeline.PreprocessState{}, nil, err
|
||||
}
|
||||
|
||||
canonical := make([]model.CanonicalTranscript, 0, len(in.Raw))
|
||||
for _, raw := range in.Raw {
|
||||
canonicalSpeaker, err := speakers.SpeakerForSource(raw.Source)
|
||||
if err != nil {
|
||||
return pipeline.PreprocessState{}, nil, err
|
||||
}
|
||||
|
||||
segments := make([]model.Segment, 0, len(raw.Segments))
|
||||
for index, rawSegment := range raw.Segments {
|
||||
segments = append(segments, model.Segment{
|
||||
Source: raw.Source,
|
||||
SourceSegmentIndex: index,
|
||||
Speaker: canonicalSpeaker,
|
||||
Start: rawSegment.Start,
|
||||
End: rawSegment.End,
|
||||
Text: rawSegment.Text,
|
||||
})
|
||||
}
|
||||
|
||||
canonical = append(canonical, model.CanonicalTranscript{
|
||||
Source: raw.Source,
|
||||
Segments: nil,
|
||||
Segments: segments,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -77,6 +134,6 @@ func (normalizeSpeakers) Process(ctx context.Context, in pipeline.PreprocessStat
|
||||
Raw: append([]model.RawTranscript(nil), in.Raw...),
|
||||
Canonical: canonical,
|
||||
}, []report.Event{
|
||||
report.Info("preprocessing", "normalize-speakers", "created placeholder canonical transcript(s)"),
|
||||
report.Info("preprocessing", "normalize-speakers", "created canonical transcript(s) from raw input"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ func NewRegistry() *pipeline.Registry {
|
||||
registry.RegisterInputReader(jsonFilesReader{})
|
||||
registry.RegisterPreprocessor(noopPreprocessor{name: "validate-raw", requires: pipeline.StateRaw, produces: pipeline.StateRaw})
|
||||
registry.RegisterPreprocessor(normalizeSpeakers{})
|
||||
registry.RegisterPreprocessor(noopPreprocessor{name: "trim-text", requires: pipeline.StateCanonical, produces: pipeline.StateCanonical})
|
||||
registry.RegisterPreprocessor(trimText{})
|
||||
registry.RegisterPreprocessor(noopPreprocessor{name: "autocorrect", requires: pipeline.StateCanonical, produces: pipeline.StateCanonical})
|
||||
registry.RegisterMerger(placeholderMerger{})
|
||||
registry.RegisterPostprocessor(noopPostprocessor{name: "detect-overlaps"})
|
||||
|
||||
Reference in New Issue
Block a user