wavekat_turn/audio/pipecat.rs
1//! Pipecat Smart Turn v3 backend.
2//!
3//! Audio-based turn detection using the Smart Turn ONNX model.
4//! Expects 16 kHz f32 PCM input. Telephony audio at 8 kHz must be
5//! upsampled before feeding to this detector.
6//!
7//! # Model
8//!
9//! - Source: <https://huggingface.co/pipecat-ai/smart-turn-v3>
10//! - File: `smart-turn-v3.2-cpu.onnx` (int8 quantized, ~8 MB)
11//! - License: BSD 2-Clause
12//!
13//! # Tensor specification
14//!
15//! | Role | Name | Shape | Dtype |
16//! |--------|------------------|----------------|---------|
17//! | Input | `input_features` | `[B, 80, 800]` | float32 |
18//! | Output | `logits` | `[B, 1]` | float32 |
19//!
20//! Despite the name, `logits` is a **sigmoid probability** P(turn complete)
21//! in [0, 1] — the sigmoid is fused into the model before ONNX export.
22//! Threshold: `probability > 0.5` → `TurnState::Finished`.
23//!
24//! # Mel-feature specification
25//!
26//! The model was trained with HuggingFace `WhisperFeatureExtractor(chunk_length=8)`:
27//!
28//! | Parameter | Value |
29//! |---------------|--------------------------------|
30//! | Sample rate | 16 000 Hz |
31//! | n_fft | 400 samples (25 ms) |
32//! | hop_length | 160 samples (10 ms) |
33//! | n_mels | 80 |
34//! | Freq range | 0 – 8 000 Hz |
35//! | Mel scale | Slaney (NOT HTK) |
36//! | Window | Hann (periodic, size 400) |
37//! | Pre-emphasis | None |
38//! | Log | log10 with ε = 1e-10 |
39//! | Normalization | clamp(max − 8), (x + 4) / 4 |
40//!
41//! # Audio buffer
42//!
43//! - Exactly **8 seconds = 128 000 samples** at 16 kHz.
44//! - Shorter input: **front-padded** with zeros (audio is at the end).
45//! - Longer input: the **last** 8 s is used (oldest samples discarded).
46
47use std::collections::VecDeque;
48use std::path::Path;
49use std::sync::Arc;
50use std::time::Instant;
51
52use ndarray::{s, Array2, Array3};
53use ort::{inputs, value::Tensor};
54use realfft::num_complex::Complex;
55use realfft::{RealFftPlanner, RealToComplex};
56
57use crate::onnx;
58use crate::{AudioFrame, AudioTurnDetector, StageTiming, TurnError, TurnPrediction, TurnState};
59
60// ---------------------------------------------------------------------------
61// Constants
62// ---------------------------------------------------------------------------
63
/// Sample rate the model expects. Telephony audio (8 kHz) must be upsampled first.
const SAMPLE_RATE: u32 = 16_000;
/// FFT window size in samples (25 ms at 16 kHz).
const N_FFT: usize = 400;
/// STFT hop length in samples (10 ms at 16 kHz).
const HOP_LENGTH: usize = 160;
/// Number of mel filterbank bins.
const N_MELS: usize = 80;
/// Number of STFT frames the model expects (8 s × 100 frames/s = RING_CAPACITY / HOP_LENGTH).
const N_FRAMES: usize = 800;
/// FFT frequency bins per frame: N_FFT/2 + 1 (real FFT keeps only non-negative frequencies).
const N_FREQS: usize = N_FFT / 2 + 1; // 201
/// Ring buffer capacity: 8 s × 16 kHz.
const RING_CAPACITY: usize = 8 * SAMPLE_RATE as usize; // 128 000

/// Embedded ONNX model bytes, downloaded by build.rs at compile time.
const MODEL_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/smart-turn-v3.2-cpu.onnx"));
81
82// ---------------------------------------------------------------------------
83// Mel feature extractor
84// ---------------------------------------------------------------------------
85
86/// Pre-computed Whisper-style log-mel feature extractor.
87///
88/// All expensive setup (filterbank, window, FFT plan) happens once in [`new`].
89/// [`MelExtractor::extract`] is then called per inference.
/// Pre-computed Whisper-style log-mel feature extractor.
///
/// All expensive setup (filterbank, window, FFT plan) happens once in [`new`].
/// [`MelExtractor::extract`] is then called per inference. The two `cached_*`
/// fields implement incremental recomputation: they are only valid when the
/// input audio has shifted left by a whole number of STFT hops since the
/// previous `extract` call; any other change requires `invalidate_cache`.
struct MelExtractor {
    /// Slaney-normalised mel filterbank: shape [N_MELS, N_FREQS].
    mel_filters: Array2<f32>,
    /// Periodic Hann window of length N_FFT.
    hann_window: Vec<f32>,
    /// Reusable forward real FFT plan.
    fft: Arc<dyn RealToComplex<f32>>,
    /// Reusable scratch buffer for the FFT.
    fft_scratch: Vec<Complex<f32>>,
    /// Reusable output spectrum buffer (N_FREQS complex values).
    spectrum_buf: Vec<Complex<f32>>,
    /// Cached power spectrogram [N_FREQS × (N_FRAMES+1)] from the previous call.
    /// Enables incremental STFT: only new/edge frames are recomputed.
    cached_power_spec: Option<Array2<f32>>,
    /// Cached mel spectrogram [N_MELS × N_FRAMES] from the previous call.
    /// Enables incremental mel filterbank: only new/edge columns are recomputed.
    cached_mel_spec: Option<Array2<f32>>,
}
108
109impl MelExtractor {
110 fn new() -> Self {
111 let mel_filters = build_mel_filters(
112 SAMPLE_RATE as usize,
113 N_FFT,
114 N_MELS,
115 0.0,
116 SAMPLE_RATE as f32 / 2.0,
117 );
118 let hann_window = periodic_hann(N_FFT);
119
120 let mut planner = RealFftPlanner::<f32>::new();
121 let fft = planner.plan_fft_forward(N_FFT);
122 let fft_scratch = fft.make_scratch_vec();
123 let spectrum_buf = fft.make_output_vec();
124
125 Self {
126 mel_filters,
127 hann_window,
128 fft,
129 fft_scratch,
130 spectrum_buf,
131 cached_power_spec: None,
132 cached_mel_spec: None,
133 }
134 }
135
136 /// Compute a [N_MELS × N_FRAMES] log-mel spectrogram from exactly
137 /// `RING_CAPACITY` samples of 16 kHz mono audio.
138 ///
139 /// `shift_frames` is how many STFT frames worth of new audio were added
140 /// since the last call. When a valid cache exists and `shift_frames` is
141 /// in range, only the last `shift_frames` columns of the power spectrogram
142 /// are recomputed; the rest are copied from the shifted cache.
143 fn extract(&mut self, audio: &[f32], shift_frames: usize) -> Array2<f32> {
144 debug_assert_eq!(audio.len(), RING_CAPACITY);
145
146 // ---- Center-pad: N_FFT/2 reflect samples on each side → 128 400 samples ----
147 // Matches WhisperFeatureExtractor: np.pad(waveform, n_fft//2, mode="reflect").
148 // Reflect (not zero) padding ensures the boundary frames match Python exactly.
149 // Gives exactly N_FRAMES + 1 = 801 frames; we discard the last one.
150 let pad = N_FFT / 2; // 200
151 let n = audio.len(); // 128 000
152 let mut padded = vec![0.0f32; pad + n + pad];
153 padded[pad..pad + n].copy_from_slice(audio);
154 // Left reflect: padded[0..pad] = audio[pad..1] reversed (exclude edge)
155 for i in 0..pad {
156 padded[i] = audio[pad - i];
157 }
158 // Right reflect: padded[pad+n..pad+n+pad] = audio[n-2..n-2-pad] reversed
159 for i in 0..pad {
160 padded[pad + n + i] = audio[n - 2 - i];
161 }
162
163 // n_total = (128 400 − 400) / 160 + 1 = 801
164 let n_total_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1;
165
166 // ---- Incremental STFT ----
167 // If we have a cached power spec and shift_frames < n_total_frames,
168 // reuse the unchanged frames by shifting the cache left and only
169 // computing the `shift_frames` new columns at the end.
170 let first_new_frame = match &self.cached_power_spec {
171 Some(cached) if shift_frames > 0 && shift_frames < n_total_frames => {
172 let kept = n_total_frames - shift_frames;
173 let mut power_spec = Array2::<f32>::zeros((N_FREQS, n_total_frames));
174 power_spec
175 .slice_mut(s![.., ..kept])
176 .assign(&cached.slice(s![.., shift_frames..]));
177 self.cached_power_spec = Some(power_spec);
178 kept // only compute frames [kept..n_total_frames]
179 }
180 _ => {
181 self.cached_power_spec = Some(Array2::<f32>::zeros((N_FREQS, n_total_frames)));
182 0 // cold start: compute all frames
183 }
184 };
185
186 let power_spec = self.cached_power_spec.as_mut().unwrap();
187 let mut frame_buf = vec![0.0f32; N_FFT];
188
189 for frame_idx in first_new_frame..n_total_frames {
190 let start = frame_idx * HOP_LENGTH;
191 // Apply periodic Hann window
192 for (i, (&s, &w)) in padded[start..start + N_FFT]
193 .iter()
194 .zip(self.hann_window.iter())
195 .enumerate()
196 {
197 frame_buf[i] = s * w;
198 }
199
200 self.fft
201 .process_with_scratch(
202 &mut frame_buf,
203 &mut self.spectrum_buf,
204 &mut self.fft_scratch,
205 )
206 .expect("FFT failed: internal buffer size mismatch");
207
208 for (k, c) in self.spectrum_buf.iter().enumerate() {
209 power_spec[[k, frame_idx]] = c.re * c.re + c.im * c.im;
210 }
211 }
212
213 // Take first N_FRAMES columns (drop the trailing frame)
214 let power_spec_view = power_spec.slice(s![.., ..N_FRAMES]);
215
216 // ---- Incremental mel filterbank: [N_MELS, N_FREQS] × [N_FREQS, shift_frames] ----
217 // Reuse the cached mel columns for the unchanged frames; only multiply
218 // the new power-spectrum columns against the filterbank.
219 let mel_spec = match &self.cached_mel_spec {
220 Some(cached) if shift_frames > 0 && shift_frames <= N_FRAMES => {
221 let kept = N_FRAMES - shift_frames;
222 let mut ms = Array2::<f32>::zeros((N_MELS, N_FRAMES));
223 // Shift old columns left
224 ms.slice_mut(s![.., ..kept])
225 .assign(&cached.slice(s![.., shift_frames..]));
226 // Apply filterbank only to the new power-spectrum columns
227 let new_power = power_spec_view.slice(s![.., kept..]);
228 ms.slice_mut(s![.., kept..])
229 .assign(&self.mel_filters.dot(&new_power));
230 ms
231 }
232 _ => self.mel_filters.dot(&power_spec_view),
233 };
234 self.cached_mel_spec = Some(mel_spec.clone());
235
236 // ---- Log10 with floor at 1e-10 ----
237 let mut log_mel = mel_spec.mapv(|x| x.max(1e-10_f32).log10());
238
239 // ---- Dynamic range compression and normalization ----
240 // Matches WhisperFeatureExtractor: clamp to [max−8, ∞], then (x+4)/4
241 let max_val = log_mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
242 log_mel.mapv_inplace(|x| (x.max(max_val - 8.0) + 4.0) / 4.0);
243
244 log_mel
245 }
246
247 /// Invalidate all caches (call on reset).
248 fn invalidate_cache(&mut self) {
249 self.cached_power_spec = None;
250 self.cached_mel_spec = None;
251 }
252}
253
254// ---------------------------------------------------------------------------
255// Mel filterbank construction — Slaney scale, slaney norm
256// ---------------------------------------------------------------------------
257
258/// Convert Hz to mel (Slaney/librosa scale, NOT HTK).
/// Convert Hz to mel (Slaney/librosa scale, NOT HTK).
///
/// Linear below 1 kHz (200/3 Hz per mel), logarithmic above; the two regions
/// meet continuously at 1 kHz = 15 mel.
fn hz_to_mel(hz: f32) -> f32 {
    const F_SP: f32 = 200.0 / 3.0; // linear-region slope (Hz per mel)
    const MIN_LOG_HZ: f32 = 1000.0; // start of the logarithmic region
    const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP; // = 15.0
    // Every 27 mel above the break multiplies frequency by 6.4 (≈ 0.068752 per mel).
    let logstep = (6.4_f32).ln() / 27.0;
    if hz < MIN_LOG_HZ {
        hz / F_SP
    } else {
        MIN_LOG_MEL + (hz / MIN_LOG_HZ).ln() / logstep
    }
}
271
272/// Convert mel back to Hz (Slaney scale).
/// Convert mel back to Hz (Slaney scale) — exact inverse of [`hz_to_mel`].
fn mel_to_hz(mel: f32) -> f32 {
    const F_SP: f32 = 200.0 / 3.0;
    const MIN_LOG_HZ: f32 = 1000.0;
    const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP;
    let logstep = (6.4_f32).ln() / 27.0;
    if mel < MIN_LOG_MEL {
        // Linear region below 15 mel (= 1 kHz).
        mel * F_SP
    } else {
        // Exponential region above the break.
        MIN_LOG_HZ * ((mel - MIN_LOG_MEL) * logstep).exp()
    }
}
284
285/// Build a Slaney-normalised mel filterbank of shape [n_mels, n_freqs].
286///
287/// Matches `librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax,
288/// norm="slaney", dtype=float32)` which is what HuggingFace's
289/// `WhisperFeatureExtractor` uses internally.
290fn build_mel_filters(
291 sr: usize,
292 n_fft: usize,
293 n_mels: usize,
294 f_min: f32,
295 f_max: f32,
296) -> Array2<f32> {
297 let n_freqs = n_fft / 2 + 1;
298
299 // FFT frequency bins: 0, sr/n_fft, 2·sr/n_fft, …
300 let fft_freqs: Vec<f32> = (0..n_freqs)
301 .map(|i| i as f32 * sr as f32 / n_fft as f32)
302 .collect();
303
304 // n_mels + 2 equally-spaced mel points (edge + n_mels centres + edge)
305 let mel_min = hz_to_mel(f_min);
306 let mel_max = hz_to_mel(f_max);
307 let mel_pts: Vec<f32> = (0..=(n_mels + 1))
308 .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32)
309 .collect();
310 let hz_pts: Vec<f32> = mel_pts.iter().map(|&m| mel_to_hz(m)).collect();
311
312 // Build triangular filters with Slaney normalisation
313 let mut filters = Array2::<f32>::zeros((n_mels, n_freqs));
314 for m in 0..n_mels {
315 let f_left = hz_pts[m];
316 let f_center = hz_pts[m + 1];
317 let f_right = hz_pts[m + 2];
318 // Slaney norm: 2 / (right_hz − left_hz)
319 let enorm = 2.0 / (f_right - f_left);
320
321 for (k, &f) in fft_freqs.iter().enumerate() {
322 let w = if f >= f_left && f <= f_center {
323 (f - f_left) / (f_center - f_left)
324 } else if f > f_center && f <= f_right {
325 (f_right - f) / (f_right - f_center)
326 } else {
327 0.0
328 };
329 filters[[m, k]] = w * enorm;
330 }
331 }
332 filters
333}
334
335// ---------------------------------------------------------------------------
336// Hann window
337// ---------------------------------------------------------------------------
338
339/// Periodic Hann window of length `n`, matching `torch.hann_window(n, periodic=True)`.
340///
341/// Formula: `w[k] = 0.5 · (1 − cos(2π·k / n))` for k in 0..n.
342/// This differs from the symmetric variant (which divides by n−1).
/// Periodic Hann window of length `n`, matching `torch.hann_window(n, periodic=True)`.
///
/// Formula: `w[k] = 0.5 · (1 − cos(2π·k / n))` for k in 0..n — note the
/// divisor is `n`, not `n − 1` as in the symmetric variant.
fn periodic_hann(n: usize) -> Vec<f32> {
    use std::f32::consts::PI;
    let mut window = Vec::with_capacity(n);
    for k in 0..n {
        window.push(0.5 * (1.0 - (2.0 * PI * k as f32 / n as f32).cos()));
    }
    window
}
349
350// ---------------------------------------------------------------------------
351// Audio preparation
352// ---------------------------------------------------------------------------
353
354/// Pad or truncate `samples` to exactly `RING_CAPACITY` samples.
355///
356/// - Longer: keep the **last** 8 s (discard oldest).
357/// - Shorter: **front-pad** with zeros so audio is right-aligned.
358fn prepare_audio(samples: &[f32]) -> Vec<f32> {
359 match samples.len().cmp(&RING_CAPACITY) {
360 std::cmp::Ordering::Equal => samples.to_vec(),
361 std::cmp::Ordering::Greater => samples[samples.len() - RING_CAPACITY..].to_vec(),
362 std::cmp::Ordering::Less => {
363 let mut out = vec![0.0f32; RING_CAPACITY - samples.len()];
364 out.extend_from_slice(samples);
365 out
366 }
367 }
368}
369
370// ---------------------------------------------------------------------------
371// PipecatSmartTurn
372// ---------------------------------------------------------------------------
373
374/// Pipecat Smart Turn v3 detector.
375///
376/// Buffers up to 8 seconds of audio internally. Call [`push_audio`] with
377/// every incoming 16 kHz frame, then call [`predict`] when the VAD fires
378/// end-of-speech to get a [`TurnPrediction`].
379///
380/// # Usage with VAD
381///
382/// ```no_run
383/// # #[cfg(feature = "pipecat")]
384/// # {
385/// use wavekat_turn::audio::PipecatSmartTurn;
386/// use wavekat_turn::AudioTurnDetector;
387///
388/// let mut detector = PipecatSmartTurn::new().unwrap();
389/// // ... feed frames via push_audio ...
390/// let prediction = detector.predict().unwrap();
391/// println!("{:?} ({:.2})", prediction.state, prediction.confidence);
392/// # }
393/// ```
394///
395/// [`push_audio`]: AudioTurnDetector::push_audio
396/// [`predict`]: AudioTurnDetector::predict
pub struct PipecatSmartTurn {
    /// ONNX Runtime session holding the loaded Smart Turn model.
    session: ort::session::Session,
    /// Up to `RING_CAPACITY` (8 s) of the most recent 16 kHz samples.
    ring_buffer: VecDeque<f32>,
    /// Whisper-style log-mel extractor with incremental caching.
    mel: MelExtractor,
    /// Counts samples pushed since the last `predict()` call.
    /// Used to compute `shift_frames` for incremental STFT.
    samples_since_predict: usize,
}
405
// SAFETY: ort::Session is Send in ort 2.x. Sync is safe because every
// method that touches the session takes &mut self, preventing concurrent use.
//
// NOTE(review): if every field (Session, VecDeque<f32>, MelExtractor) is
// already Send, the compiler derives Send automatically and the first impl
// is redundant — confirm which field actually blocks the auto-impl. The
// Sync argument holds only while no `&self` method exists on this type;
// revisit these impls if a shared-reference accessor is ever added.
unsafe impl Send for PipecatSmartTurn {}
unsafe impl Sync for PipecatSmartTurn {}
410
411impl PipecatSmartTurn {
412 /// Load the Smart Turn v3.2 model embedded at compile time.
413 pub fn new() -> Result<Self, TurnError> {
414 let session = onnx::session_from_memory(MODEL_BYTES)?;
415 Ok(Self::build(session))
416 }
417
418 /// Load a model from a custom path on disk.
419 ///
420 /// Useful for CI environments that supply the model file separately, or
421 /// for evaluating fine-tuned variants without recompiling.
422 pub fn from_file(path: impl AsRef<Path>) -> Result<Self, TurnError> {
423 let session = onnx::session_from_file(path)?;
424 Ok(Self::build(session))
425 }
426
427 fn build(session: ort::session::Session) -> Self {
428 Self {
429 session,
430 ring_buffer: VecDeque::with_capacity(RING_CAPACITY),
431 mel: MelExtractor::new(),
432 samples_since_predict: 0,
433 }
434 }
435}
436
437impl AudioTurnDetector for PipecatSmartTurn {
438 /// Append audio to the internal ring buffer.
439 ///
440 /// Frames with a sample rate other than 16 kHz are silently dropped.
441 /// The ring buffer holds at most 8 s; older samples are evicted.
442 fn push_audio(&mut self, frame: &AudioFrame) {
443 if frame.sample_rate() != SAMPLE_RATE {
444 return;
445 }
446 let samples = frame.samples();
447 // Evict oldest samples to make room
448 let overflow = (self.ring_buffer.len() + samples.len()).saturating_sub(RING_CAPACITY);
449 if overflow > 0 {
450 self.ring_buffer.drain(..overflow);
451 }
452 self.ring_buffer.extend(samples.iter().copied());
453 self.samples_since_predict += samples.len();
454 }
455
456 /// Run inference on the buffered audio.
457 ///
458 /// Takes a snapshot of the ring buffer, pads/truncates to 8 s, extracts
459 /// Whisper log-mel features, and runs ONNX inference.
460 fn predict(&mut self) -> Result<TurnPrediction, TurnError> {
461 let t_start = Instant::now();
462
463 // Stage 1: Snapshot the ring buffer and prepare exactly 128 000 samples
464 let shift_frames = self.samples_since_predict / HOP_LENGTH;
465 self.samples_since_predict = 0;
466
467 let buffered: Vec<f32> = self.ring_buffer.iter().copied().collect();
468 let audio = prepare_audio(&buffered);
469 let t_after_audio_prep = Instant::now();
470
471 // Stage 2: Extract [N_MELS × N_FRAMES] log-mel features (incremental)
472 let mel_spec = self.mel.extract(&audio, shift_frames);
473 let t_after_mel = Instant::now();
474
475 // Stage 3: Reshape to [1, N_MELS, N_FRAMES] and run ONNX inference
476 let (raw, _) = mel_spec.into_raw_vec_and_offset();
477 let input_array = Array3::from_shape_vec((1, N_MELS, N_FRAMES), raw)
478 .expect("internal: mel output has wrong element count");
479
480 let input_tensor = Tensor::from_array(input_array)
481 .map_err(|e| TurnError::BackendError(format!("failed to create input tensor: {e}")))?;
482
483 let outputs = self
484 .session
485 .run(inputs!["input_features" => input_tensor])
486 .map_err(|e| TurnError::BackendError(format!("inference failed: {e}")))?;
487 let t_after_onnx = Instant::now();
488
489 // Extract sigmoid probability from the "logits" output
490 let output = outputs
491 .get("logits")
492 .ok_or_else(|| TurnError::BackendError("missing 'logits' output tensor".into()))?;
493 let (_, data): (_, &[f32]) = output
494 .try_extract_tensor()
495 .map_err(|e| TurnError::BackendError(format!("failed to extract logits: {e}")))?;
496 let probability = *data
497 .first()
498 .ok_or_else(|| TurnError::BackendError("logits tensor is empty".into()))?;
499
500 let latency_ms = t_start.elapsed().as_millis() as u64;
501
502 let us = |a: Instant, b: Instant| (b - a).as_secs_f64() * 1_000_000.0;
503 let stage_times = vec![
504 StageTiming {
505 name: "audio_prep",
506 us: us(t_start, t_after_audio_prep),
507 },
508 StageTiming {
509 name: "mel",
510 us: us(t_after_audio_prep, t_after_mel),
511 },
512 StageTiming {
513 name: "onnx",
514 us: us(t_after_mel, t_after_onnx),
515 },
516 ];
517
518 // probability = P(turn complete); > 0.5 means the speaker has finished
519 let (state, confidence) = if probability > 0.5 {
520 (TurnState::Finished, probability)
521 } else {
522 (TurnState::Unfinished, 1.0 - probability)
523 };
524
525 Ok(TurnPrediction {
526 state,
527 confidence,
528 latency_ms,
529 stage_times,
530 })
531 }
532
533 /// Clear the ring buffer. Call at the start of each new speech turn.
534 fn reset(&mut self) {
535 self.ring_buffer.clear();
536 self.samples_since_predict = 0;
537 self.mel.invalidate_cache();
538 }
539}
540
541// ---------------------------------------------------------------------------
542// Mel comparison tests (unit tests — need access to private MelExtractor)
543// ---------------------------------------------------------------------------
544
#[cfg(test)]
mod mel_tests {
    //! Compares the Rust mel extractor against reference `.npy` spectrograms
    //! generated by the Python pipeline (see `scripts/gen_reference.py`).

    use std::path::{Path, PathBuf};

    use ndarray::Array2;
    use ndarray_npy::ReadNpyExt;

    use super::{prepare_audio, MelExtractor, RING_CAPACITY, SAMPLE_RATE};

    /// Max allowed element-wise absolute difference between Rust and Python mel.
    const MEL_TOLERANCE: f32 = 0.05;

    /// Resolve the repo-root `tests/fixtures` directory from this crate's
    /// manifest dir (two levels up: crate → crates/ → repo root).
    fn fixtures_dir() -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .unwrap() // crates/
            .parent()
            .unwrap() // repo root
            .join("tests/fixtures")
    }

    /// Load 16 kHz mono WAV as f32 in [-1, 1], normalised the same way as
    /// Python's soundfile (divide by 32768, not i16::MAX).
    fn load_wav_f32(path: &Path) -> Vec<f32> {
        let mut reader = hound::WavReader::open(path)
            .unwrap_or_else(|e| panic!("failed to open {}: {}", path.display(), e));
        let spec = reader.spec();
        assert_eq!(spec.sample_rate, SAMPLE_RATE, "expected 16 kHz");
        assert_eq!(spec.channels, 1, "expected mono");
        match spec.sample_format {
            // Integer PCM: scale by 32768 to match soundfile's convention.
            hound::SampleFormat::Int => reader
                .samples::<i16>()
                .map(|s| s.unwrap() as f32 / 32768.0)
                .collect(),
            // Float PCM: already in [-1, 1].
            hound::SampleFormat::Float => reader.samples::<f32>().map(|s| s.unwrap()).collect(),
        }
    }

    /// Load the Python-generated reference mel spectrogram for `clip`.
    fn load_python_mel(clip: &str) -> Array2<f32> {
        let path = fixtures_dir().join(format!("{clip}.mel.npy"));
        let file = std::fs::File::open(&path).unwrap_or_else(|_| {
            panic!(
                "missing {}: run `python scripts/gen_reference.py` first",
                path.display()
            )
        });
        Array2::<f32>::read_npy(file).expect("failed to parse .npy")
    }

    /// Summary statistics of the element-wise Rust-vs-Python mel diff.
    struct MelDiff {
        max_diff: f32,
        mean_diff: f32,
        /// (mel_bin, frame) of the single largest diff
        max_at: (usize, usize),
        /// fraction of elements with diff > 0.01
        outlier_frac: f32,
    }

    /// Run the Rust extractor cold (shift 0, no cache) on `clip` and diff
    /// every element against the Python reference mel.
    fn compare_mel(clip: &str) -> MelDiff {
        let samples = load_wav_f32(&fixtures_dir().join(clip));
        let audio = prepare_audio(&samples);
        assert_eq!(audio.len(), RING_CAPACITY);

        let mut extractor = MelExtractor::new();
        let rust_mel = extractor.extract(&audio, 0);
        let python_mel = load_python_mel(clip);

        assert_eq!(
            rust_mel.shape(),
            python_mel.shape(),
            "{clip}: mel shape mismatch"
        );

        let shape = rust_mel.shape();
        let (n_mels, n_frames) = (shape[0], shape[1]);

        let mut max_diff = 0.0f32;
        let mut max_at = (0, 0);
        let mut sum_diff = 0.0f32;
        let mut outliers = 0usize;

        for m in 0..n_mels {
            for t in 0..n_frames {
                let d = (rust_mel[[m, t]] - python_mel[[m, t]]).abs();
                sum_diff += d;
                if d > max_diff {
                    max_diff = d;
                    max_at = (m, t);
                }
                if d > 0.01 {
                    outliers += 1;
                }
            }
        }

        let total = (n_mels * n_frames) as f32;
        MelDiff {
            max_diff,
            mean_diff: sum_diff / total,
            max_at,
            outlier_frac: outliers as f32 / total,
        }
    }

    /// Print a markdown table of mel-level diffs between Rust and Python.
    /// Run with: `make mel`
    #[test]
    #[ignore]
    fn mel_report() {
        let clips = ["silence_2s.wav", "speech_finished.wav", "speech_mid.wav"];

        println!();
        println!("MEL_TOLERANCE={MEL_TOLERANCE}");
        println!();
        println!("| Clip | Max Diff | Mean Diff | Max at (mel,frame) | Outliers >0.01 | Status |");
        println!("|------|----------|-----------|---------------------|----------------|--------|");
        for clip in clips {
            let d = compare_mel(clip);
            let status = if d.max_diff <= MEL_TOLERANCE {
                "PASS"
            } else {
                "FAIL"
            };
            println!(
                "| `{clip}` | {:.6} | {:.6} | ({},{}) | {:.2}% | {status} |",
                d.max_diff,
                d.mean_diff,
                d.max_at.0,
                d.max_at.1,
                d.outlier_frac * 100.0,
            );
        }
        println!();
    }
}
679}