Skip to main content

svod_arch/vad/
mod.rs

//! VAD-aware chunker for long-form ASR.
//!
//! Operates on `&[f32]` per-frame speech probabilities — the output of any
//! frame-level VAD — and packs them into bounded-length [`AudioChunk`]s
//! suitable for feeding to an encoder one chunk at a time. Speech-bearing
//! regions of the waveform are preserved; pure-silence regions between
//! chunks are dropped.
//!
//! The chunker is purely algorithmic: no Tensor or model dependency, no
//! coupling to a specific VAD. The output is sample-index ranges that any
//! downstream decoder can consume.
//!
//! # Algorithm
//!
//! ```text
//! 1. threshold + smoothing  → speech runs (prob-grid indices)
//! 2. split runs > strict_limit at internal prob troughs
//! 3. greedy-pack runs into chunks of ~[min_duration, max_duration]
//!    (closing at inter-segment silence rather than mid-speech)
//! 4. convert prob indices → samples, apply pad, align to align_to
//! ```
//!
//! All knobs live in [`ChunkerOpts`]; nothing inside the algorithm hardcodes
//! sample rates, prob granularity, or alignment.
25
26pub(crate) mod segment;
27
28#[cfg(feature = "serde")]
29use serde::Deserialize;
30use snafu::Snafu;
31
32use segment::threshold_segments;
33
34// ─── Config ───────────────────────────────────────────────────────────────
35
/// Tuning knobs consumed by [`chunks_from_probs`].
///
/// Every `*_duration` field is expressed in wall-clock seconds; the chunker
/// turns it into a prob-grid count using `sample_rate / samples_per_prob`
/// (probs per second) and rounds up.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Deserialize))]
#[cfg_attr(feature = "serde", serde(default))]
pub struct ChunkerOpts {
    /// Source waveform sample rate, in Hz.
    pub sample_rate: u32,
    /// How many waveform samples a single `probs` entry spans, i.e. the
    /// stride of the upstream frame-level VAD. Supplying it here is what
    /// keeps the chunker VAD-agnostic. Must be non-zero.
    pub samples_per_prob: usize,
    /// Speech cut-off: a prob entry `>= threshold` is treated as speech.
    pub threshold: f32,
    /// Soft lower bound on chunk duration; the chunker does not
    /// voluntarily close a chunk that is still shorter than this.
    pub min_duration: f32,
    /// Soft upper bound on chunk duration. Once a chunk has reached
    /// `min_duration`, it is closed at the next inter-segment silence (or,
    /// inside one long run, at a local prob trough) rather than being
    /// grown past this value.
    pub max_duration: f32,
    /// Hard ceiling on duration. A single VAD segment longer than this is
    /// cut internally at prob troughs, and an under-`min_duration` chunk is
    /// not extended when doing so would push it past this value.
    pub strict_limit_duration: f32,
    /// Smoothing applied before segmentation: a speech run is kept only if
    /// it contains at least this many above-threshold probs.
    pub min_speech_probs: usize,
    /// Smoothing applied before segmentation: a silence gap ends a speech
    /// run only if it spans at least this many below-threshold probs.
    pub min_silence_probs: usize,
    /// Speech runs separated by at most this many silence probs are joined
    /// before chunking.
    pub merge_gap_probs: usize,
    /// Radius, in prob-grid units, of the window searched for a trough when
    /// an overlong run has to be split. `None` (the default) falls back to
    /// `min_silence_probs`; set a value to tune the two independently.
    pub trough_search_probs: Option<usize>,
    /// Optional secondary threshold (typically below `threshold`) used when
    /// splitting overlong runs. With `Some(t)`, the whole legal split range
    /// is searched for the frame with prob `< t` closest to the geometric
    /// target, and the narrow argmin around the target is used only when no
    /// frame qualifies. With `None`, the narrow argmin is always used.
    pub trough_threshold: Option<f32>,
    /// Padding, in samples, requested on each side of a chunk to give the
    /// encoder context at chunk boundaries. The amount actually applied per
    /// side is capped at the distance to the waveform edge (first / last
    /// chunk) or at half the gap to the neighbouring chunk.
    pub pad_samples: usize,
    /// Chunk boundaries are snapped to integer multiples of this many
    /// samples (start rounded down, end rounded up); `1` keeps them
    /// sample-precise. Set it to the encoder's effective frame stride (e.g.
    /// `mel_hop * subsample_factor`) so chunks land on encoder-frame
    /// boundaries. Each boundary can move by up to `align_to - 1` samples;
    /// pathological values (e.g. larger than `min_duration`) are the
    /// caller's responsibility. Must be non-zero.
    pub align_to: usize,
}
96
97impl Default for ChunkerOpts {
98    fn default() -> Self {
99        Self {
100            sample_rate: 16_000,
101            samples_per_prob: 512,
102            threshold: 0.5,
103            min_duration: 15.0,
104            max_duration: 22.0,
105            strict_limit_duration: 30.0,
106            min_speech_probs: 8,
107            min_silence_probs: 4,
108            merge_gap_probs: 8,
109            trough_search_probs: None,
110            trough_threshold: None,
111            pad_samples: 0,
112            align_to: 1,
113        }
114    }
115}
116
117// ─── Output ───────────────────────────────────────────────────────────────
118
/// One speech-bearing span of the source waveform, as a half-open sample
/// range `[start_sample, end_sample)`.
///
/// Both indices refer to the *original* waveform that was fed to the VAD.
/// They are derived from prob-grid indices multiplied by
/// `samples_per_prob`, then padded and aligned. To offset a per-chunk
/// transcript, callers can compute
/// `start_sec = start_sample as f32 / sample_rate as f32`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct AudioChunk {
    /// First sample of the chunk (inclusive).
    pub start_sample: usize,
    /// One past the last sample of the chunk (exclusive). The chunker only
    /// knows `probs.len() * samples_per_prob`, so when the final prob entry
    /// covered samples beyond the real waveform end this value can exceed
    /// the waveform length; clamp it when slicing.
    pub end_sample: usize,
}
134
135// ─── Errors ───────────────────────────────────────────────────────────────
136
137#[derive(Debug, Snafu)]
138#[snafu(visibility(pub))]
139pub enum Error {
140    #[snafu(display("samples_per_prob must be > 0"))]
141    ZeroSamplesPerProb,
142    #[snafu(display("align_to must be > 0"))]
143    ZeroAlignTo,
144    #[snafu(display("min_duration ({min}) must be ≤ max_duration ({max})"))]
145    MinExceedsMax { min: f32, max: f32 },
146    #[snafu(display("max_duration ({max}) must be ≤ strict_limit_duration ({strict})"))]
147    MaxExceedsStrict { max: f32, strict: f32 },
148}
149
150pub type Result<T> = std::result::Result<T, Error>;
151
/// Upper bound (in samples) on any chunk [`chunks_from_probs`] can emit.
///
/// The bound is the sum of three terms:
///
/// * `(strict_limit_probs + 2 * trough_radius) * samples_per_prob` — the
///   hard ceiling plus the slack `split_long_runs` may take on either side
///   of a split;
/// * `2 * pad_samples` — padding on both edges;
/// * `2 * align_to` — alignment slack on both edges. The chunk start is
///   rounded *down* and the chunk end is rounded *up* to a multiple of
///   `align_to`, so each edge can move by up to `align_to - 1` samples.
///   Counting `align_to` only once under-estimates the worst case whenever
///   the prob grid is not itself a multiple of `align_to` (e.g.
///   `samples_per_prob = 512`, `align_to = 640`).
///
/// Single source of truth for downstream callers that need to size
/// buffers or assert the contract.
pub fn strict_chunk_sample_bound(
    strict_limit_probs: usize,
    trough_radius: usize,
    samples_per_prob: usize,
    pad_samples: usize,
    align_to: usize,
) -> usize {
    let speech_samples = (strict_limit_probs + 2 * trough_radius) * samples_per_prob;
    let pad_slack = 2 * pad_samples;
    // Floor on the start edge + ceil on the end edge: one `align_to` each.
    let align_slack = 2 * align_to;
    speech_samples + pad_slack + align_slack
}
166
167// ─── Public entry point ───────────────────────────────────────────────────
168
169/// Pack VAD speech probabilities into bounded-length chunks.
170///
171/// Output chunks cover only speech-bearing portions of the waveform; silence
172/// between chunks is dropped. Boundaries are padded by `opts.pad_samples` and
173/// snapped to `opts.align_to` multiples (start floored, end ceil'd, so
174/// coverage is preserved). Adjacent chunks that overlap after padding are
175/// merged.
176pub fn chunks_from_probs(probs: &[f32], opts: &ChunkerOpts) -> Result<Vec<AudioChunk>> {
177    validate(opts)?;
178    if probs.is_empty() {
179        return Ok(Vec::new());
180    }
181
182    let probs_per_sec = opts.sample_rate as f32 / opts.samples_per_prob as f32;
183    let strict_limit_probs = (opts.strict_limit_duration * probs_per_sec).ceil() as usize;
184    let min_probs = (opts.min_duration * probs_per_sec).ceil() as usize;
185    let max_probs = (opts.max_duration * probs_per_sec).ceil() as usize;
186
187    let trough_radius = opts.trough_search_probs.unwrap_or(opts.min_silence_probs);
188    let trough_threshold = opts.trough_threshold;
189
190    // Halve the silence-sensitivity knobs and retry if `threshold_segments`
191    // produced any segment exceeding `strict_limit_probs` — gives `split_long_runs`
192    // less work / more silence to cut at. Floor at 2 because a single
193    // sub-threshold prob is reliably a VAD micro-dip mid-word, not silence.
194    let mut adapted = opts.clone();
195    let segments = loop {
196        let segs = threshold_segments(probs, &adapted);
197        let any_over = segs.iter().any(|&(s, e)| e - s > strict_limit_probs);
198        if !any_over || adapted.min_silence_probs <= 2 {
199            break segs;
200        }
201        adapted.min_silence_probs = (adapted.min_silence_probs / 2).max(2);
202        adapted.merge_gap_probs = (adapted.merge_gap_probs / 2).max(1);
203    };
204    let segments = split_long_runs(segments, probs, trough_radius, trough_threshold, strict_limit_probs);
205    let chunks = pack_segments(&segments, min_probs, max_probs, strict_limit_probs);
206
207    Ok(post_process(&chunks, probs.len(), opts))
208}
209
210// ─── Internals ────────────────────────────────────────────────────────────
211
212fn validate(opts: &ChunkerOpts) -> Result<()> {
213    if opts.samples_per_prob == 0 {
214        return ZeroSamplesPerProbSnafu.fail();
215    }
216    if opts.align_to == 0 {
217        return ZeroAlignToSnafu.fail();
218    }
219    if opts.min_duration > opts.max_duration {
220        return MinExceedsMaxSnafu { min: opts.min_duration, max: opts.max_duration }.fail();
221    }
222    if opts.max_duration > opts.strict_limit_duration {
223        return MaxExceedsStrictSnafu { max: opts.max_duration, strict: opts.strict_limit_duration }.fail();
224    }
225    Ok(())
226}
227
228/// Break any speech segment whose length exceeds `strict_limit_probs` into
229/// `ceil(len / strict_limit)` near-equal pieces, choosing each split point
230/// as the prob argmin within ±`search_radius` of the geometric target. Lands
231/// on natural pauses inside long unbroken runs instead of hard-cutting at
232/// fixed time intervals.
233///
234/// Each emitted piece is at least `len / (2 * n)` long. Without that floor
235/// a wide `search_radius` can let the argmin land arbitrarily close to a
236/// split's neighbours and produce 1-prob shards that downstream code has to
237/// special-case. With the floor the worst-case shrinkage is half the
238/// average piece length.
239fn split_long_runs(
240    segments: Vec<(usize, usize)>,
241    probs: &[f32],
242    search_radius: usize,
243    trough_threshold: Option<f32>,
244    strict_limit_probs: usize,
245) -> Vec<(usize, usize)> {
246    if strict_limit_probs == 0 {
247        return segments;
248    }
249    let mut out = Vec::with_capacity(segments.len());
250    for (start, end) in segments {
251        let len = end - start;
252        if len <= strict_limit_probs {
253            out.push((start, end));
254            continue;
255        }
256        let n = len.div_ceil(strict_limit_probs);
257        let min_piece = (len / (2 * n)).max(1);
258        let mut cur = start;
259        for k in 1..n {
260            let target = start + (len * k) / n;
261            let pieces_left = n - k;
262            // Constrain the argmin window so this split is at least
263            // min_piece away from cur and from `end - pieces_left * min_piece`
264            // (i.e. each remaining piece can still hit min_piece).
265            let lo_narrow = target.saturating_sub(search_radius).max(cur + min_piece);
266            let hi_floor = end.saturating_sub(pieces_left * min_piece);
267            let hi_narrow = (target + search_radius).min(hi_floor.saturating_sub(1));
268
269            // With `trough_threshold`: prefer a real silence frame anywhere
270            // in the legal range (closest to target for balance) over the
271            // narrow-radius argmin which may land inside speech.
272            let trough_split = trough_threshold.and_then(|t| {
273                let lo_wide = cur + min_piece;
274                let hi_wide = hi_floor.saturating_sub(1);
275                if hi_wide < lo_wide {
276                    return None;
277                }
278                let slice = &probs[lo_wide..=hi_wide];
279                slice
280                    .iter()
281                    .enumerate()
282                    .filter(|&(_, &p)| p < t)
283                    .min_by_key(|(i, _)| (lo_wide + i).abs_diff(target))
284                    .map(|(i, _)| lo_wide + i)
285            });
286            let split = if let Some(s) = trough_split {
287                s
288            } else if hi_narrow >= lo_narrow {
289                lo_narrow + argmin(&probs[lo_narrow..=hi_narrow])
290            } else {
291                // Constraints incompatible (radius wider than the available
292                // slack). Fall back to the geometric target, clamped so the
293                // remaining pieces are still non-empty.
294                target.clamp(cur + min_piece, hi_floor.saturating_sub(1).max(cur + min_piece))
295            };
296            if split > cur && split < end {
297                out.push((cur, split));
298                cur = split;
299            }
300        }
301        if cur < end {
302            out.push((cur, end));
303        }
304    }
305    out
306}
307
/// Index of the smallest value in `slice`; the earliest index wins a tie.
///
/// A value replaces the current best only when it compares strictly less,
/// so a NaN is never selected unless it is the first element.
///
/// # Panics
///
/// Panics if `slice` is empty.
fn argmin(slice: &[f32]) -> usize {
    let seed = (0usize, slice[0]);
    slice
        .iter()
        .copied()
        .enumerate()
        .skip(1)
        .fold(seed, |best, (idx, value)| if value < best.1 { (idx, value) } else { best })
        .0
}
319
/// Greedily join consecutive speech segments into chunks.
///
/// The chunk being built is closed — and a new one started at the incoming
/// segment — when absorbing that segment would make it longer than
/// `max_probs` AND either it has already reached `min_probs` or the
/// extended length would exceed `strict_limit_probs` (the hard ceiling).
/// In every other case the chunk is stretched to the end of the incoming
/// segment, swallowing the silence in between.
fn pack_segments(
    segments: &[(usize, usize)],
    min_probs: usize,
    max_probs: usize,
    strict_limit_probs: usize,
) -> Vec<(usize, usize)> {
    let (first, rest) = match segments.split_first() {
        Some((&first, rest)) => (first, rest),
        None => return Vec::new(),
    };

    let mut packed = Vec::new();
    let mut open = first;
    for &(seg_start, seg_end) in rest {
        let grown_len = seg_end - open.0;
        let open_len = open.1 - open.0;
        let over_soft_max = grown_len > max_probs;
        let may_close = open_len >= min_probs || grown_len > strict_limit_probs;
        if over_soft_max && may_close {
            packed.push(open);
            open = (seg_start, seg_end);
        } else {
            open.1 = seg_end;
        }
    }
    packed.push(open);
    packed
}
352
353/// Convert prob-index ranges to sample ranges. Padding is adaptive: each
354/// side is capped at half the silence gap to the neighbour, so chunks
355/// never overlap their neighbours' speech. Alignment-induced overlap
356/// (floor-start / ceil-end rounding) is clipped to preserve splits.
357fn post_process(chunks: &[(usize, usize)], probs_len: usize, opts: &ChunkerOpts) -> Vec<AudioChunk> {
358    let max_sample = probs_len * opts.samples_per_prob;
359    let pad = opts.pad_samples;
360    let align = opts.align_to;
361
362    let mut out: Vec<AudioChunk> = Vec::with_capacity(chunks.len());
363    for (i, &(s, e)) in chunks.iter().enumerate() {
364        let raw_start = s * opts.samples_per_prob;
365        let raw_end = e * opts.samples_per_prob;
366
367        // Cap each side's padding at half the silence gap to the neighbour
368        // (or the full margin at the waveform edges). Floor division: total
369        // pad consumed by adjacent chunks ≤ gap, so they never overlap.
370        let pad_left = if i == 0 {
371            pad.min(raw_start)
372        } else {
373            let prev_raw_end = chunks[i - 1].1 * opts.samples_per_prob;
374            pad.min(raw_start.saturating_sub(prev_raw_end) / 2)
375        };
376        let pad_right = if i + 1 == chunks.len() {
377            pad.min(max_sample.saturating_sub(raw_end))
378        } else {
379            let next_raw_start = chunks[i + 1].0 * opts.samples_per_prob;
380            pad.min(next_raw_start.saturating_sub(raw_end) / 2)
381        };
382
383        let padded_start = raw_start - pad_left;
384        let padded_end = (raw_end + pad_right).min(max_sample);
385        let aligned_start = (padded_start / align) * align;
386        let mut aligned_end = padded_end.div_ceil(align) * align;
387        if aligned_end > max_sample {
388            aligned_end = max_sample;
389        }
390        if aligned_end <= aligned_start {
391            continue;
392        }
393        if let Some(last) = out.last_mut()
394            && aligned_start < last.end_sample
395        {
396            // Alignment-only overlap (asymmetric floor/ceil rounding put us
397            // inside the previous chunk). Clip our start up to preserve the
398            // split. Drop if it collapses to empty.
399            let bumped_start = last.end_sample;
400            if aligned_end > bumped_start {
401                out.push(AudioChunk { start_sample: bumped_start, end_sample: aligned_end });
402            }
403            continue;
404        }
405        out.push(AudioChunk { start_sample: aligned_start, end_sample: aligned_end });
406    }
407    out
408}