package scheduler

import (
	"context"
	"fmt"
	"hash/fnv"
	"math/rand"
	"sync"
	"time"

	"gitea.maximumdirect.net/ejr/feedkit/event"
	"gitea.maximumdirect.net/ejr/feedkit/logging"
	"gitea.maximumdirect.net/ejr/feedkit/sources"
)

// Logger is a printf-style logger used throughout the scheduler package.
// It is an alias to the shared feedkit logging type so callers can pass
// one function everywhere without type mismatch friction.
type Logger = logging.Logf
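
// For instance, the standard library logger can be passed straight through
// (this assumes logging.Logf has the usual printf shape, func(string, ...any)):
//
//	s := &Scheduler{Logf: log.Printf}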

// Job describes one scheduler task.
//
// A Job may be backed by either:
// - a polling source (sources.PollSource): uses Every + jitter and calls Poll()
// - a stream source (sources.StreamSource): ignores Every and calls Run()
//
// Jitter behavior:
// - For polling sources: Jitter is applied at startup and before each poll tick.
// - For stream sources: Jitter is applied once at startup only (optional; useful to avoid
//   reconnect storms when many instances start together).
type Job struct {
	Source           sources.Input
	Every            time.Duration
	StreamExitPolicy StreamExitPolicy
	StreamBackoff    StreamBackoff

	// Jitter is the maximum additional delay added before each poll.
	// Example: if Every=15m and Jitter=30s, each poll will occur at:
	//   tick time + random(0..30s)
	//
	// If Jitter == 0 for polling sources, we compute a default jitter based on Every.
	//
	// For stream sources, Jitter is treated as *startup jitter only*.
	Jitter time.Duration
}
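
// Illustrative job configurations (rssSource and wsSource are hypothetical
// sources.PollSource and sources.StreamSource values):
//
//	poll := Job{Source: rssSource, Every: 15 * time.Minute, Jitter: 30 * time.Second}
//	stream := Job{
//		Source:           wsSource,
//		Jitter:           5 * time.Second, // startup-only for streams
//		StreamExitPolicy: StreamExitPolicyRestart,
//	}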

// StreamExitPolicy controls how the scheduler handles non-fatal stream exits.
type StreamExitPolicy string

const (
	// StreamExitPolicyRestart restarts the stream after a backoff delay (the default).
	StreamExitPolicyRestart StreamExitPolicy = "restart"
	// StreamExitPolicyStop logs the exit and stops supervising the job.
	StreamExitPolicyStop StreamExitPolicy = "stop"
	// StreamExitPolicyFatal escalates the exit, shutting down the whole scheduler.
	StreamExitPolicyFatal StreamExitPolicy = "fatal"
)

// StreamBackoff controls restart pacing for stream supervision.
type StreamBackoff struct {
	Initial time.Duration
	Max     time.Duration
	Jitter  time.Duration
}
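
// For example, restart pacing could start at one second and cap at two
// minutes (illustrative values, not the package defaults below):
//
//	job.StreamBackoff = StreamBackoff{
//		Initial: 1 * time.Second,
//		Max:     2 * time.Minute,
//		Jitter:  500 * time.Millisecond,
//	}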

// Scheduler runs a set of Jobs, one goroutine per job, and emits their
// events on Out.
type Scheduler struct {
	Jobs []Job
	Out  chan<- event.Event
	Logf Logger
}
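
// A minimal wiring sketch, assuming a hypothetical src (any sources.Input)
// and a consumer draining out on the other side:
//
//	out := make(chan event.Event, 64)
//	s := &Scheduler{
//		Jobs: []Job{{Source: src, Every: time.Minute}},
//		Out:  out,
//		Logf: log.Printf,
//	}
//	if err := s.Run(ctx); err != nil {
//		log.Fatal(err)
//	}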

const (
	defaultStreamBackoffInitial = 1 * time.Second
	defaultStreamBackoffMax     = 1 * time.Minute
	defaultStreamBackoffJitter  = 250 * time.Millisecond
	streamBackoffResetAfter     = 5 * time.Minute
)

// timeNow is split out so tests can substitute a fake clock.
var timeNow = time.Now

// Run starts one goroutine per job.
// Poll jobs run on their own interval and emit 0..N events per poll.
// Stream jobs run continuously and emit events as they arrive.
// Run blocks until the context is cancelled, a stream job fails fatally,
// or every job has exited on its own; it returns only after all job
// goroutines have finished.
func (s *Scheduler) Run(ctx context.Context) error {
	if s.Out == nil {
		return fmt.Errorf("scheduler.Run: Out channel is nil")
	}
	if len(s.Jobs) == 0 {
		return fmt.Errorf("scheduler.Run: no jobs configured")
	}

	runCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	fatalErrCh := make(chan error, 1)
	var wg sync.WaitGroup
	for _, job := range s.Jobs {
		job := job // capture loop variable
		wg.Add(1)
		go func() {
			defer wg.Done()
			s.runJob(runCtx, job, fatalErrCh)
		}()
	}

	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case err := <-fatalErrCh:
		cancel()
		<-done
		return err
	case <-done:
		// Every job exited on its own (e.g. all stream jobs hit a stop
		// policy). Prefer a fatal error if one raced in just before exit.
		select {
		case err := <-fatalErrCh:
			return err
		default:
		}
		return runCtx.Err()
	case <-runCtx.Done():
		<-done
		return runCtx.Err()
	}
}

func (s *Scheduler) runJob(ctx context.Context, job Job, fatalErrCh chan<- error) {
	if job.Source == nil {
		s.logf("scheduler: job has nil source")
		return
	}

	// Stream sources: event-driven.
	if ss, ok := job.Source.(sources.StreamSource); ok {
		s.runStream(ctx, job, ss, fatalErrCh)
		return
	}

	// Poll sources: time-based.
	ps, ok := job.Source.(sources.PollSource)
	if !ok {
		s.logf("scheduler: source %T (%s) implements neither Poll() nor Run()", job.Source, job.Source.Name())
		return
	}
	if job.Every <= 0 {
		s.logf("scheduler: polling job %q missing/invalid interval (sources[].every)", ps.Name())
		return
	}

	s.runPoller(ctx, job, ps)
}

func (s *Scheduler) runStream(ctx context.Context, job Job, src sources.StreamSource, fatalErrCh chan<- error) {
	policy := effectiveStreamExitPolicy(job.StreamExitPolicy)
	backoff := effectiveStreamBackoff(job.StreamBackoff)
	rng := seededRNG(src.Name())

	// Optional startup jitter: helps avoid reconnect storms if many daemons start at once.
	if job.Jitter > 0 {
		if !sleepJitter(ctx, rng, job.Jitter) {
			return
		}
	}

	nextDelay := backoff.Initial
	for {
		startedAt := timeNow()
		err := src.Run(ctx, s.Out)
		if ctx.Err() != nil {
			return
		}

		normalizedErr := normalizeStreamExitError(src.Name(), err)
		if sources.IsStreamFatal(normalizedErr) {
			s.reportFatal(fatalErrCh, fmt.Errorf("scheduler: stream source %q exited fatally: %w", src.Name(), normalizedErr))
			return
		}

		switch policy {
		case StreamExitPolicyStop:
			s.logf("scheduler: stream source %q stopped after exit: %v", src.Name(), normalizedErr)
			return
		case StreamExitPolicyFatal:
			s.reportFatal(fatalErrCh, fmt.Errorf("scheduler: stream source %q exited under fatal policy: %w", src.Name(), normalizedErr))
			return
		}

		// Reset backoff after a sufficiently long, healthy run.
		if streamRunWasStable(startedAt, timeNow()) {
			nextDelay = backoff.Initial
		}

		delay := nextDelay + randomDuration(rng, backoff.Jitter)
		s.logf("scheduler: stream source %q exited; restarting in %s: %v", src.Name(), delay, normalizedErr)
		if !sleepDuration(ctx, delay) {
			return
		}
		nextDelay = nextStreamBackoff(nextDelay, backoff.Max)
	}
}

func (s *Scheduler) runPoller(ctx context.Context, job Job, src sources.PollSource) {
	// Compute jitter: either configured per job, or a sensible default.
	jitter := effectiveJitter(job.Every, job.Jitter)

	// Each worker gets its own RNG (safe + no lock contention).
	rng := seededRNG(src.Name())

	// Optional startup jitter: avoids all jobs firing at the exact moment the daemon starts.
	if !sleepJitter(ctx, rng, jitter) {
		return
	}

	// Immediate poll at startup (after startup jitter).
	s.pollOnce(ctx, src)

	t := time.NewTicker(job.Every)
	defer t.Stop()

	for {
		select {
		case <-t.C:
			// Per-tick jitter: spreads calls out within the interval.
			if !sleepJitter(ctx, rng, jitter) {
				return
			}
			s.pollOnce(ctx, src)

		case <-ctx.Done():
			return
		}
	}
}

func (s *Scheduler) pollOnce(ctx context.Context, src sources.PollSource) {
	events, err := src.Poll(ctx)
	if err != nil {
		s.logf("scheduler: poll failed (%s): %v", src.Name(), err)
		return
	}

	for _, e := range events {
		select {
		case s.Out <- e:
		case <-ctx.Done():
			return
		}
	}
}

func (s *Scheduler) logf(format string, args ...any) {
	if s.Logf == nil {
		return
	}
	s.Logf(format, args...)
}

// reportFatal delivers err without blocking. The fatal channel is created
// with capacity 1 in Run, so the first error wins and later ones are dropped.
func (s *Scheduler) reportFatal(ch chan<- error, err error) {
	if err == nil {
		return
	}
	select {
	case ch <- err:
	default:
	}
}

// ---- helpers ----

func effectiveStreamExitPolicy(policy StreamExitPolicy) StreamExitPolicy {
	switch policy {
	case StreamExitPolicyStop, StreamExitPolicyFatal:
		return policy
	default:
		return StreamExitPolicyRestart
	}
}

func effectiveStreamBackoff(cfg StreamBackoff) StreamBackoff {
	out := cfg
	if out.Initial <= 0 {
		out.Initial = defaultStreamBackoffInitial
	}
	if out.Max <= 0 {
		out.Max = defaultStreamBackoffMax
	}
	if out.Max < out.Initial {
		out.Max = out.Initial
	}
	if out.Jitter < 0 {
		out.Jitter = 0
	}
	return out
}

func normalizeStreamExitError(sourceName string, err error) error {
	if err != nil {
		return err
	}
	return sources.StreamRetryable(fmt.Errorf("stream source %q exited unexpectedly without error", sourceName))
}

func nextStreamBackoff(current, max time.Duration) time.Duration {
	if current <= 0 {
		current = defaultStreamBackoffInitial
	}
	if max <= 0 {
		max = defaultStreamBackoffMax
	}
	if current >= max {
		return max
	}
	next := current * 2
	if next < current || next > max {
		return max
	}
	return next
}
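
// With the package defaults (Initial=1s, Max=1m, Jitter=250ms), successive
// restarts wait roughly 1s, 2s, 4s, 8s, 16s, 32s, then 60s thereafter, plus
// up to 250ms of jitter each time; a run lasting streamBackoffResetAfter (5m)
// or more resets the sequence to Initial.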

func streamRunWasStable(startedAt, endedAt time.Time) bool {
	if startedAt.IsZero() || endedAt.IsZero() {
		return false
	}
	return endedAt.Sub(startedAt) >= streamBackoffResetAfter
}

// seededRNG returns a dedicated RNG for one worker, seeded from the current
// time XOR an FNV-1a hash of the source name so that workers started in the
// same instant still jitter differently.
func seededRNG(name string) *rand.Rand {
	seed := timeNow().UnixNano() ^ int64(hashStringFNV32a(name))
	return rand.New(rand.NewSource(seed))
}

// effectiveJitter chooses a jitter value.
// - If configuredMax > 0, use it (but clamp).
// - Else default to min(every/10, 30s).
// - Clamp to at most every/2 (so jitter can't delay more than half the interval).
func effectiveJitter(every time.Duration, configuredMax time.Duration) time.Duration {
	if every <= 0 {
		return 0
	}

	j := configuredMax
	if j <= 0 {
		j = every / 10
		if j > 30*time.Second {
			j = 30 * time.Second
		}
	}

	// Clamp jitter so it doesn't dominate the schedule.
	maxAllowed := every / 2
	if j > maxAllowed {
		j = maxAllowed
	}
	if j < 0 {
		j = 0
	}
	return j
}
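
// Worked examples:
// - Every=15m, Jitter=0: default is min(15m/10, 30s) = 30s
// - Every=1m, Jitter=0: default is 1m/10 = 6s
// - Every=20s, Jitter=30s: clamped to every/2 = 10s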

// sleepJitter sleeps for a random duration in [0, max].
// Returns false if the context is cancelled while waiting.
func sleepJitter(ctx context.Context, rng *rand.Rand, max time.Duration) bool {
	if max <= 0 {
		return true
	}

	return sleepDuration(ctx, randomDuration(rng, max))
}

func randomDuration(rng *rand.Rand, max time.Duration) time.Duration {
	if max <= 0 {
		return 0
	}
	// Int63n requires a positive argument.
	// We add 1 so max itself is attainable.
	n := rng.Int63n(int64(max) + 1)
	return time.Duration(n)
}

func sleepDuration(ctx context.Context, d time.Duration) bool {
	if d <= 0 {
		return true
	}
	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-timer.C:
		return true
	case <-ctx.Done():
		return false
	}
}

func hashStringFNV32a(s string) uint32 {
	h := fnv.New32a()
	_, _ = h.Write([]byte(s))
	return h.Sum32()
}