From 3928e0c4a7fdb84ca5ddbb92818ad29296bd697a Mon Sep 17 00:00:00 2001 From: Eric Rakestraw Date: Sun, 26 Apr 2026 19:33:23 -0500 Subject: [PATCH] Implemented an autocorrect module at the postprocessing stage --- README.md | 41 ++++- internal/autocorrect/autocorrect.go | 132 ++++++++++++++++ internal/autocorrect/autocorrect_test.go | 190 +++++++++++++++++++++++ internal/builtin/postprocess.go | 30 ++++ internal/builtin/registry.go | 3 +- internal/cli/merge_test.go | 90 ++++++++++- internal/config/config.go | 2 +- 7 files changed, 482 insertions(+), 6 deletions(-) create mode 100644 internal/autocorrect/autocorrect.go create mode 100644 internal/autocorrect/autocorrect_test.go diff --git a/README.md b/README.md index e3fbf75..3074f58 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ Optional flags: - `--output-modules`: comma-separated output modules. Default: `json`. - `--preprocessing-modules`: comma-separated preprocessing modules. Default: `validate-raw,normalize-speakers,trim-text`. - `--postprocessing-modules`: comma-separated postprocessing modules. Default: `detect-overlaps,resolve-overlaps,assign-ids,validate-output`. -- `--autocorrect`: autocorrect rules file. Reserved for the `autocorrect` module; not part of the default pipeline. +- `--autocorrect`: autocorrect rules file. Required when the postprocessing `autocorrect` module is enabled. ## Input JSON Format @@ -163,9 +163,46 @@ Segments are sorted deterministically by: Final segment IDs are assigned after sorting and start at `1`. +## Autocorrect + +Autocorrect is an opt-in postprocessing module. It is not part of the default pipeline. + +Enable it by adding `autocorrect` to `--postprocessing-modules` and passing `--autocorrect`: + +```sh +go run ./cmd/seriatim merge \ + --input-file input.json \ + --speakers speakers.yml \ + --autocorrect autocorrect.yml \ + --postprocessing-modules detect-overlaps,resolve-overlaps,autocorrect,assign-ids,validate-output \ + --output-file merged.json +``` + +`autocorrect.yml` format: + +```yaml +autocorrect: + - target: "Hrank" + match: + - "hrank" + - "Frank" + + - target: "Mike Brown" + match: + - "Mike Pat" +``` + +Matching behavior: + +- Matching is case-sensitive. +- Matches apply only to whole tokens, not substrings inside larger words. +- Punctuation and whitespace can surround a match. +- Multi-word and hyphenated matches are supported. +- Duplicate match strings are invalid, including duplicates across separate rules. + ## Current Limitations - Only JSON input is supported. - Word-level timing data is not preserved yet. - Overlap detection and overlap resolution are currently no-op modules. -- Autocorrect, coalescing, and alternate output formats are not implemented yet. +- Coalescing and alternate output formats are not implemented yet. diff --git a/internal/autocorrect/autocorrect.go b/internal/autocorrect/autocorrect.go new file mode 100644 index 0000000..9331d8f --- /dev/null +++ b/internal/autocorrect/autocorrect.go @@ -0,0 +1,132 @@ +package autocorrect + +import ( + "fmt" + "os" + "strings" + + "gopkg.in/yaml.v3" +) + +// Rules stores ordered autocorrect replacement rules. +type Rules struct { + rules []Rule +} + +// Rule replaces ordered match strings with a canonical target. +type Rule struct { + Target string `yaml:"target"` + Match []string `yaml:"match"` +} + +type fileSchema struct { + Autocorrect []Rule `yaml:"autocorrect"` +} + +// Load parses and validates an autocorrect.yml file. +func Load(path string) (Rules, error) { + data, err := os.ReadFile(path) + if err != nil { + return Rules{}, err + } + + var parsed fileSchema + if err := yaml.Unmarshal(data, &parsed); err != nil { + return Rules{}, fmt.Errorf("parse autocorrect file %q: %w", path, err) + } + if len(parsed.Autocorrect) == 0 { + return Rules{}, fmt.Errorf("autocorrect file %q must contain at least one autocorrect rule", path) + } + + seenMatches := make(map[string]int) + rules := make([]Rule, 0, len(parsed.Autocorrect)) + for ruleIndex, rule := range parsed.Autocorrect { + rule.Target = strings.TrimSpace(rule.Target) + if rule.Target == "" { + return Rules{}, fmt.Errorf("autocorrect rule %d must include target", ruleIndex) + } + if len(rule.Match) == 0 { + return Rules{}, fmt.Errorf("autocorrect rule %d for target %q must include at least one match string", ruleIndex, rule.Target) + } + + localMatches := make(map[string]struct{}, len(rule.Match)) + for matchIndex, match := range rule.Match { + match = strings.TrimSpace(match) + if match == "" { + return Rules{}, fmt.Errorf("autocorrect rule %d for target %q contains empty match string at index %d", ruleIndex, rule.Target, matchIndex) + } + if _, exists := localMatches[match]; exists { + return Rules{}, fmt.Errorf("autocorrect rule %d for target %q contains duplicate match string %q", ruleIndex, rule.Target, match) + } + localMatches[match] = struct{}{} + + if previousRuleIndex, exists := seenMatches[match]; exists { + return Rules{}, fmt.Errorf("autocorrect match string %q appears in both rule %d and rule %d", match, previousRuleIndex, ruleIndex) + } + seenMatches[match] = ruleIndex + rule.Match[matchIndex] = match + } + + rules = append(rules, rule) + } + + return Rules{rules: rules}, nil +} + +// Apply replaces configured whole-token matches and returns the updated text and replacement count. +func (r Rules) Apply(text string) (string, int) { + total := 0 + for _, rule := range r.rules { + for _, match := range rule.Match { + var count int + text, count = replaceWholeToken(text, match, rule.Target) + total += count + } + } + return text, total +} + +func replaceWholeToken(text string, match string, target string) (string, int) { + if text == "" || match == "" { + return text, 0 + } + + var builder strings.Builder + replacements := 0 + searchStart := 0 + writeStart := 0 + for { + index := strings.Index(text[searchStart:], match) + if index == -1 { + break + } + index += searchStart + end := index + len(match) + + if isTokenBoundary(text, index-1) && isTokenBoundary(text, end) { + builder.WriteString(text[writeStart:index]) + builder.WriteString(target) + replacements++ + writeStart = end + searchStart = end + continue + } + + searchStart = index + 1 + } + + if replacements == 0 { + return text, 0 + } + + builder.WriteString(text[writeStart:]) + return builder.String(), replacements +} + +func isTokenBoundary(text string, index int) bool { + if index < 0 || index >= len(text) { + return true + } + char := text[index] + return !((char >= 'A' && char <= 'Z') || (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '_') +} diff --git a/internal/autocorrect/autocorrect_test.go b/internal/autocorrect/autocorrect_test.go new file mode 100644 index 0000000..ba72929 --- /dev/null +++ b/internal/autocorrect/autocorrect_test.go @@ -0,0 +1,190 @@ +package autocorrect + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestLoadValidRules(t *testing.T) { + dir := t.TempDir() + path := writeAutocorrect(t, dir, `autocorrect: + - target: "Hrank" + match: + - "Frank" + - "frank" +`) + + rules, err := Load(path) + if err != nil { + t.Fatalf("load rules: %v", err) + } + + got, count := rules.Apply("Frank and frank") + if got != "Hrank and Hrank" { + t.Fatalf("text = %q, want %q", got, "Hrank and Hrank") + } + if count != 2 { + t.Fatalf("count = %d, want 2", count) + } +} + +func TestLoadValidation(t *testing.T) { + tests := []struct { + name string + content string + want string + }{ + { + name: "missing top-level autocorrect", + content: `other: []`, + want: "must contain at least one autocorrect rule", + }, + { + name: "empty rules list", + content: `autocorrect: []`, + want: "must contain at least one autocorrect rule", + }, + { + name: "empty target", + content: `autocorrect: + - target: "" + match: ["Frank"] +`, + want: "must include target", + }, + { + name: "empty match list", + content: `autocorrect: + - target: "Hrank" + match: [] +`, + want: "must include at least one match string", + }, + { + name: "empty match string", + content: `autocorrect: + - target: "Hrank" + match: [" "] +`, + want: "contains empty match string", + }, + { + name: "duplicate match across rules", + content: `autocorrect: + - target: "Hrank" + match: ["Frank"] + - target: "Other" + match: ["Frank"] +`, + want: `appears in both rule 0 and rule 1`, + }, + { + name: "duplicate match within rule", + content: `autocorrect: + - target: "Hrank" + match: ["Frank", "Frank"] +`, + want: `contains duplicate match string "Frank"`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dir := t.TempDir() + path := writeAutocorrect(t, dir, test.content) + + _, err := Load(path) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), test.want) { + t.Fatalf("expected error to contain %q, got %v", test.want, err) + } + }) + } +} + +func TestApplyReplacementBehavior(t *testing.T) { + rules := Rules{rules: []Rule{ + { + Target: "Hrank", + Match: []string{"Frank"}, + }, + { + Target: "Mike Brown", + Match: []string{"Mike Pat"}, + }, + { + Target: "Godfrey", + Match: []string{"God-free"}, + }, + }} + + tests := []struct { + name string + input string + want string + wantCount int + }{ + { + name: "case sensitive", + input: "Frank and FRANK", + want: "Hrank and FRANK", + wantCount: 1, + }, + { + name: "punctuation boundary", + input: "Frank, are you there?", + want: "Hrank, are you there?", + wantCount: 1, + }, + { + name: "no substring in larger token", + input: "Franklin and xFrank Frank_y Frank2", + want: "Franklin and xFrank Frank_y Frank2", + wantCount: 0, + }, + { + name: "multi word match", + input: "Hello Mike Pat.", + want: "Hello Mike Brown.", + wantCount: 1, + }, + { + name: "hyphenated match", + input: "God-free is here.", + want: "Godfrey is here.", + wantCount: 1, + }, + { + name: "hyphen outside match is boundary", + input: "x-Frank-y", + want: "x-Hrank-y", + wantCount: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, count := rules.Apply(test.input) + if got != test.want { + t.Fatalf("text = %q, want %q", got, test.want) + } + if count != test.wantCount { + t.Fatalf("count = %d, want %d", count, test.wantCount) + } + }) + } +} + +func writeAutocorrect(t *testing.T, dir string, content string) string { + t.Helper() + + path := filepath.Join(dir, "autocorrect.yml") + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatalf("write autocorrect file: %v", err) + } + return path +} diff --git a/internal/builtin/postprocess.go b/internal/builtin/postprocess.go index b0678e0..ac1d585 100644 --- a/internal/builtin/postprocess.go +++ b/internal/builtin/postprocess.go @@ -2,7 +2,9 @@ package builtin import ( "context" + "fmt" + "gitea.maximumdirect.net/eric/seriatim/internal/autocorrect" "gitea.maximumdirect.net/eric/seriatim/internal/config" "gitea.maximumdirect.net/eric/seriatim/internal/model" "gitea.maximumdirect.net/eric/seriatim/internal/report" @@ -45,3 +47,31 @@ func (assignIDs) Process(ctx context.Context, in model.MergedTranscript, cfg con report.Info("postprocessing", "assign-ids", "assigned final segment IDs"), }, nil } + +type autocorrectPostprocessor struct{} + +func (autocorrectPostprocessor) Name() string { + return "autocorrect" +} + +func (autocorrectPostprocessor) Process(ctx context.Context, in model.MergedTranscript, cfg config.Config) (model.MergedTranscript, []report.Event, error) { + if err := ctx.Err(); err != nil { + return model.MergedTranscript{}, nil, err + } + + rules, err := autocorrect.Load(cfg.AutocorrectFile) + if err != nil { + return model.MergedTranscript{}, nil, err + } + + replacements := 0 + for index := range in.Segments { + var count int + in.Segments[index].Text, count = rules.Apply(in.Segments[index].Text) + replacements += count + } + + return in, []report.Event{ + report.Info("postprocessing", "autocorrect", fmt.Sprintf("applied %d autocorrect replacement(s)", replacements)), + }, nil +} diff --git a/internal/builtin/registry.go b/internal/builtin/registry.go index dcbb428..91e071b 100644 --- a/internal/builtin/registry.go +++ b/internal/builtin/registry.go @@ -10,13 +10,12 @@ func NewRegistry() *pipeline.Registry { registry.RegisterPreprocessor(noopPreprocessor{name: "validate-raw", requires: pipeline.StateRaw, produces: pipeline.StateRaw}) registry.RegisterPreprocessor(normalizeSpeakers{}) registry.RegisterPreprocessor(trimText{}) - registry.RegisterPreprocessor(noopPreprocessor{name: "autocorrect", requires: pipeline.StateCanonical, produces: pipeline.StateCanonical}) registry.RegisterMerger(placeholderMerger{}) registry.RegisterPostprocessor(noopPostprocessor{name: "detect-overlaps"}) registry.RegisterPostprocessor(noopPostprocessor{name: "resolve-overlaps"}) registry.RegisterPostprocessor(assignIDs{}) registry.RegisterPostprocessor(noopPostprocessor{name: "validate-output"}) - registry.RegisterPostprocessor(noopPostprocessor{name: "autocorrect"}) + registry.RegisterPostprocessor(autocorrectPostprocessor{}) registry.RegisterOutputWriter(jsonOutputWriter{}) return registry diff --git a/internal/cli/merge_test.go b/internal/cli/merge_test.go index 5d1d1c3..64aa351 100644 --- a/internal/cli/merge_test.go +++ b/internal/cli/merge_test.go @@ -304,7 +304,7 @@ func TestAutocorrectRequiresAutocorrectFile(t *testing.T) { "--input-file", input, "--speakers", speakers, "--output-file", output, - "--preprocessing-modules", "validate-raw,normalize-speakers,autocorrect", + "--postprocessing-modules", "detect-overlaps,resolve-overlaps,autocorrect,assign-ids,validate-output", ) if err == nil { t.Fatal("expected error") @@ -314,6 +314,94 @@ func TestAutocorrectRequiresAutocorrectFile(t *testing.T) { } } +func TestPreprocessingAutocorrectIsUnknownModule(t *testing.T) { + dir := t.TempDir() + input := writeJSONFile(t, dir, "input.json", `{"segments":[]}`) + speakers := writeYAMLFile(t, dir, "speakers.yml", `match: + - speaker: Alice + match: ["input.json"] +`) + autocorrect := writeYAMLFile(t, dir, "autocorrect.yml", `autocorrect: + - target: Hrank + match: ["Frank"] +`) + output := filepath.Join(dir, "merged.json") + + err := executeMerge( + "--input-file", input, + "--speakers", speakers, + "--autocorrect", autocorrect, + "--output-file", output, + "--preprocessing-modules", "validate-raw,normalize-speakers,autocorrect", + ) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), `unknown preprocessing module "autocorrect"`) { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestPostprocessingAutocorrectUpdatesOutputAndReport(t *testing.T) { + dir := t.TempDir() + input := writeJSONFile(t, dir, "input.json", `{ + "segments": [ + {"start": 1, "end": 2, "text": "Frank met Mike Pat, not Franklin."}, + {"start": 3, "end": 4, "text": "God-free and FRANK stayed."} + ] + }`) + speakers := writeYAMLFile(t, dir, "speakers.yml", `match: + - speaker: Alice + match: ["input.json"] +`) + autocorrect := writeYAMLFile(t, dir, "autocorrect.yml", `autocorrect: + - target: Hrank + match: ["Frank"] + - target: Mike Brown + match: ["Mike Pat"] + - target: Godfrey + match: ["God-free"] +`) + output := filepath.Join(dir, "merged.json") + reportPath := filepath.Join(dir, "report.json") + + err := executeMerge( + "--input-file", input, + "--speakers", speakers, + "--autocorrect", autocorrect, + "--output-file", output, + "--report-file", reportPath, + "--postprocessing-modules", "detect-overlaps,resolve-overlaps,autocorrect,assign-ids,validate-output", + ) + if err != nil { + t.Fatalf("merge failed: %v", err) + } + + var transcript model.FinalTranscript + readJSON(t, output, &transcript) + if got, want := transcript.Segments[0].Text, "Hrank met Mike Brown, not Franklin."; got != want { + t.Fatalf("segment 0 text = %q, want %q", got, want) + } + if got, want := transcript.Segments[1].Text, "Godfrey and FRANK stayed."; got != want { + t.Fatalf("segment 1 text = %q, want %q", got, want) + } + + var rpt report.Report + readJSON(t, reportPath, &rpt) + found := false + for _, event := range rpt.Events { + if event.Stage == "postprocessing" && event.Module == "autocorrect" { + found = true + if !strings.Contains(event.Message, "applied 3 autocorrect replacement(s)") { + t.Fatalf("unexpected autocorrect report message: %q", event.Message) + } + } + } + if !found { + t.Fatal("expected autocorrect report event") + } +} + func TestOutputJSONIsByteStable(t *testing.T) { dir := t.TempDir() inputA := writeJSONFile(t, dir, "a.json", `{"segments":[{"start":2,"end":3,"text":"a"}]}`) diff --git a/internal/config/config.go b/internal/config/config.go index 4651e1a..791cb93 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -111,7 +111,7 @@ func NewMergeConfig(opts MergeOptions) (Config, error) { } } - if contains(cfg.PreprocessingModules, "autocorrect") || contains(cfg.PostprocessingModules, "autocorrect") { + if contains(cfg.PostprocessingModules, "autocorrect") { if cfg.AutocorrectFile == "" { return Config{}, errors.New("--autocorrect is required when autocorrect is enabled") }