wavekat_turn/audio/pipecat.rs
1//! Pipecat Smart Turn v3 backend.
2//!
3//! Audio-based turn detection using the Smart Turn ONNX model.
4//! Expects 16 kHz f32 PCM input. Telephony audio at 8 kHz must be
5//! upsampled before feeding to this detector.
6//!
7//! # Model
8//!
9//! - Source: <https://huggingface.co/pipecat-ai/smart-turn-v3>
10//! - File: `smart-turn-v3.2-cpu.onnx` (int8 quantized, ~8 MB)
11//! - License: BSD 2-Clause
12//!
13//! # Tensor specification
14//!
15//! | Role | Name | Shape | Dtype |
16//! |--------|------------------|----------------|---------|
17//! | Input | `input_features` | `[B, 80, 800]` | float32 |
18//! | Output | `logits` | `[B, 1]` | float32 |
19//!
20//! Despite the name, `logits` is a **sigmoid probability** P(turn complete)
21//! in [0, 1] — the sigmoid is fused into the model before ONNX export.
22//! Threshold: `probability > 0.5` → `TurnState::Finished`.
23//!
24//! # Mel-feature specification
25//!
26//! The model was trained with HuggingFace `WhisperFeatureExtractor(chunk_length=8)`:
27//!
28//! | Parameter | Value |
29//! |---------------|--------------------------------|
30//! | Sample rate | 16 000 Hz |
31//! | n_fft | 400 samples (25 ms) |
32//! | hop_length | 160 samples (10 ms) |
33//! | n_mels | 80 |
34//! | Freq range | 0 – 8 000 Hz |
35//! | Mel scale | Slaney (NOT HTK) |
36//! | Window | Hann (periodic, size 400) |
37//! | Pre-emphasis | None |
38//! | Log | log10 with ε = 1e-10 |
//! | Normalization | x ← max(x, max(x) − 8), then (x + 4) / 4 |
40//!
41//! # Audio buffer
42//!
43//! - Exactly **8 seconds = 128 000 samples** at 16 kHz.
44//! - Shorter input: **front-padded** with zeros (audio is at the end).
45//! - Longer input: the **last** 8 s is used (oldest samples discarded).
46
47use std::collections::VecDeque;
48use std::path::Path;
49use std::sync::Arc;
50use std::time::Instant;
51
52use ndarray::{s, Array2, Array3};
53use ort::{inputs, value::Tensor};
54use realfft::num_complex::Complex;
55use realfft::{RealFftPlanner, RealToComplex};
56
57use crate::onnx;
58use crate::{AudioFrame, AudioTurnDetector, StageTiming, TurnError, TurnPrediction, TurnState};
59
60// ---------------------------------------------------------------------------
61// Constants
62// ---------------------------------------------------------------------------
63
/// Sample rate the model expects, in Hz.
const SAMPLE_RATE: u32 = 16_000;
/// Ring buffer capacity in samples: 8 s of audio at `SAMPLE_RATE` (128 000).
const RING_CAPACITY: usize = SAMPLE_RATE as usize * 8;
/// STFT window length in samples — 25 ms at 16 kHz.
const N_FFT: usize = 400;
/// STFT hop length in samples — 10 ms at 16 kHz.
const HOP_LENGTH: usize = 160;
/// Number of one-sided FFT bins, `N_FFT / 2 + 1` (= 201).
const N_FREQS: usize = 1 + N_FFT / 2;
/// Number of mel filterbank channels.
const N_MELS: usize = 80;
/// Number of STFT frames fed to the model: one per hop across the 8 s buffer (= 800).
const N_FRAMES: usize = RING_CAPACITY / HOP_LENGTH;
78
/// Embedded ONNX model bytes for Smart Turn v3.2, downloaded by build.rs at
/// compile time.
///
/// The file must exist at `$OUT_DIR/smart-turn-v3.2-cpu.onnx` when this crate
/// compiles; `include_bytes!` fails the build otherwise. The bytes become part
/// of the compiled binary and are handed to `onnx::session_from_memory` by
/// `PipecatSmartTurn::new`.
const MODEL_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/smart-turn-v3.2-cpu.onnx"));
81
82// ---------------------------------------------------------------------------
83// Mel feature extractor
84// ---------------------------------------------------------------------------
85
86/// Pre-computed Whisper-style log-mel feature extractor.
87///
88/// All expensive setup (filterbank, window, FFT plan) happens once in [`new`].
89/// [`MelExtractor::extract`] is then called per inference.
90struct MelExtractor {
91 /// Slaney-normalised mel filterbank: shape [N_MELS, N_FREQS].
92 mel_filters: Array2<f32>,
93 /// Periodic Hann window of length N_FFT.
94 hann_window: Vec<f32>,
95 /// Reusable forward real FFT plan.
96 fft: Arc<dyn RealToComplex<f32>>,
97 /// Reusable scratch buffer for the FFT.
98 fft_scratch: Vec<Complex<f32>>,
99 /// Reusable output spectrum buffer (N_FREQS complex values).
100 spectrum_buf: Vec<Complex<f32>>,
101 /// Cached power spectrogram [N_FREQS × (N_FRAMES+1)] from the previous call.
102 /// Enables incremental STFT: only new frames are recomputed.
103 cached_power_spec: Option<Array2<f32>>,
104 /// Cached mel spectrogram [N_MELS × N_FRAMES] from the previous call.
105 /// Enables incremental mel filterbank: only new columns are recomputed.
106 cached_mel_spec: Option<Array2<f32>>,
107}
108
109impl MelExtractor {
110 fn new() -> Self {
111 let mel_filters = build_mel_filters(
112 SAMPLE_RATE as usize,
113 N_FFT,
114 N_MELS,
115 0.0,
116 SAMPLE_RATE as f32 / 2.0,
117 );
118 let hann_window = periodic_hann(N_FFT);
119
120 let mut planner = RealFftPlanner::<f32>::new();
121 let fft = planner.plan_fft_forward(N_FFT);
122 let fft_scratch = fft.make_scratch_vec();
123 let spectrum_buf = fft.make_output_vec();
124
125 Self {
126 mel_filters,
127 hann_window,
128 fft,
129 fft_scratch,
130 spectrum_buf,
131 cached_power_spec: None,
132 cached_mel_spec: None,
133 }
134 }
135
136 /// Compute a [N_MELS × N_FRAMES] log-mel spectrogram from exactly
137 /// `RING_CAPACITY` samples of 16 kHz mono audio.
138 ///
139 /// `shift_frames` is how many STFT frames worth of new audio were added
140 /// since the last call. When a valid cache exists and `shift_frames` is
141 /// in range, only the last `shift_frames` columns of the power spectrogram
142 /// are recomputed; the rest are copied from the shifted cache.
143 fn extract(&mut self, audio: &[f32], shift_frames: usize) -> Array2<f32> {
144 debug_assert_eq!(audio.len(), RING_CAPACITY);
145
146 // ---- Center-pad: N_FFT/2 zeros on each side → 128 400 samples ----
147 // This replicates librosa/PyTorch `center=True` STFT behaviour, which
148 // gives exactly N_FRAMES + 1 = 801 frames; we discard the last one.
149 let pad = N_FFT / 2;
150 let mut padded = vec![0.0f32; pad + audio.len() + pad];
151 padded[pad..pad + audio.len()].copy_from_slice(audio);
152
153 // n_total = (128 400 − 400) / 160 + 1 = 801
154 let n_total_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1;
155
156 // ---- Incremental STFT ----
157 // If we have a cached power spec and shift_frames < n_total_frames,
158 // reuse the unchanged frames by shifting the cache left and only
159 // computing the `shift_frames` new columns at the end.
160 let first_new_frame = match &self.cached_power_spec {
161 Some(cached) if shift_frames > 0 && shift_frames < n_total_frames => {
162 let kept = n_total_frames - shift_frames;
163 let mut power_spec = Array2::<f32>::zeros((N_FREQS, n_total_frames));
164 power_spec
165 .slice_mut(s![.., ..kept])
166 .assign(&cached.slice(s![.., shift_frames..]));
167 self.cached_power_spec = Some(power_spec);
168 kept // only compute frames [kept..n_total_frames]
169 }
170 _ => {
171 self.cached_power_spec = Some(Array2::<f32>::zeros((N_FREQS, n_total_frames)));
172 0 // cold start: compute all frames
173 }
174 };
175
176 let power_spec = self.cached_power_spec.as_mut().unwrap();
177 let mut frame_buf = vec![0.0f32; N_FFT];
178
179 for frame_idx in first_new_frame..n_total_frames {
180 let start = frame_idx * HOP_LENGTH;
181 // Apply periodic Hann window
182 for (i, (&s, &w)) in padded[start..start + N_FFT]
183 .iter()
184 .zip(self.hann_window.iter())
185 .enumerate()
186 {
187 frame_buf[i] = s * w;
188 }
189
190 self.fft
191 .process_with_scratch(
192 &mut frame_buf,
193 &mut self.spectrum_buf,
194 &mut self.fft_scratch,
195 )
196 .expect("FFT failed: internal buffer size mismatch");
197
198 for (k, c) in self.spectrum_buf.iter().enumerate() {
199 power_spec[[k, frame_idx]] = c.re * c.re + c.im * c.im;
200 }
201 }
202
203 // Take first N_FRAMES columns (drop the trailing frame)
204 let power_spec_view = power_spec.slice(s![.., ..N_FRAMES]);
205
206 // ---- Incremental mel filterbank: [N_MELS, N_FREQS] × [N_FREQS, shift_frames] ----
207 // Reuse the cached mel columns for the unchanged frames; only multiply
208 // the new power-spectrum columns against the filterbank.
209 let mel_spec = match &self.cached_mel_spec {
210 Some(cached) if shift_frames > 0 && shift_frames <= N_FRAMES => {
211 let kept = N_FRAMES - shift_frames;
212 let mut ms = Array2::<f32>::zeros((N_MELS, N_FRAMES));
213 // Shift old columns left
214 ms.slice_mut(s![.., ..kept])
215 .assign(&cached.slice(s![.., shift_frames..]));
216 // Apply filterbank only to the new power-spectrum columns
217 let new_power = power_spec_view.slice(s![.., kept..]);
218 ms.slice_mut(s![.., kept..])
219 .assign(&self.mel_filters.dot(&new_power));
220 ms
221 }
222 _ => self.mel_filters.dot(&power_spec_view),
223 };
224 self.cached_mel_spec = Some(mel_spec.clone());
225
226 // ---- Log10 with floor at 1e-10 ----
227 let mut log_mel = mel_spec.mapv(|x| x.max(1e-10_f32).log10());
228
229 // ---- Dynamic range compression and normalization ----
230 // Matches WhisperFeatureExtractor: clamp to [max−8, ∞], then (x+4)/4
231 let max_val = log_mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
232 log_mel.mapv_inplace(|x| (x.max(max_val - 8.0) + 4.0) / 4.0);
233
234 log_mel
235 }
236
237 /// Invalidate all caches (call on reset).
238 fn invalidate_cache(&mut self) {
239 self.cached_power_spec = None;
240 self.cached_mel_spec = None;
241 }
242}
243
244// ---------------------------------------------------------------------------
245// Mel filterbank construction — Slaney scale, slaney norm
246// ---------------------------------------------------------------------------
247
/// Map a frequency in Hz onto the Slaney mel scale (librosa default, NOT HTK).
///
/// The scale has two regions:
/// - below 1 kHz it is linear, with 3 mel per 200 Hz;
/// - above 1 kHz it is logarithmic, with 27 mel per factor-of-6.4 rise in
///   frequency.
fn hz_to_mel(hz: f32) -> f32 {
    const F_SP: f32 = 200.0 / 3.0; // Hz per mel in the linear region
    const MIN_LOG_HZ: f32 = 1000.0;
    const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP; // = 15.0

    if hz < MIN_LOG_HZ {
        return hz / F_SP;
    }
    // Natural-log width of one mel step in the log region (≈ 0.068752).
    let logstep = (6.4_f32).ln() / 27.0;
    MIN_LOG_MEL + (hz / MIN_LOG_HZ).ln() / logstep
}
261
/// Inverse of the Slaney mel mapping: convert a mel value back to Hz.
///
/// Values below 15 mel (1 kHz) use the linear region. Larger values grow
/// exponentially, with 27 mel per factor of 6.4 in frequency.
fn mel_to_hz(mel: f32) -> f32 {
    const F_SP: f32 = 200.0 / 3.0;
    const MIN_LOG_HZ: f32 = 1000.0;
    const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP;

    if mel < MIN_LOG_MEL {
        mel * F_SP
    } else {
        let logstep = (6.4_f32).ln() / 27.0;
        MIN_LOG_HZ * ((mel - MIN_LOG_MEL) * logstep).exp()
    }
}
274
275/// Build a Slaney-normalised mel filterbank of shape [n_mels, n_freqs].
276///
277/// Matches `librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax,
278/// norm="slaney", dtype=float32)` which is what HuggingFace's
279/// `WhisperFeatureExtractor` uses internally.
280fn build_mel_filters(
281 sr: usize,
282 n_fft: usize,
283 n_mels: usize,
284 f_min: f32,
285 f_max: f32,
286) -> Array2<f32> {
287 let n_freqs = n_fft / 2 + 1;
288
289 // FFT frequency bins: 0, sr/n_fft, 2·sr/n_fft, …
290 let fft_freqs: Vec<f32> = (0..n_freqs)
291 .map(|i| i as f32 * sr as f32 / n_fft as f32)
292 .collect();
293
294 // n_mels + 2 equally-spaced mel points (edge + n_mels centres + edge)
295 let mel_min = hz_to_mel(f_min);
296 let mel_max = hz_to_mel(f_max);
297 let mel_pts: Vec<f32> = (0..=(n_mels + 1))
298 .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32)
299 .collect();
300 let hz_pts: Vec<f32> = mel_pts.iter().map(|&m| mel_to_hz(m)).collect();
301
302 // Build triangular filters with Slaney normalisation
303 let mut filters = Array2::<f32>::zeros((n_mels, n_freqs));
304 for m in 0..n_mels {
305 let f_left = hz_pts[m];
306 let f_center = hz_pts[m + 1];
307 let f_right = hz_pts[m + 2];
308 // Slaney norm: 2 / (right_hz − left_hz)
309 let enorm = 2.0 / (f_right - f_left);
310
311 for (k, &f) in fft_freqs.iter().enumerate() {
312 let w = if f >= f_left && f <= f_center {
313 (f - f_left) / (f_center - f_left)
314 } else if f > f_center && f <= f_right {
315 (f_right - f) / (f_right - f_center)
316 } else {
317 0.0
318 };
319 filters[[m, k]] = w * enorm;
320 }
321 }
322 filters
323}
324
325// ---------------------------------------------------------------------------
326// Hann window
327// ---------------------------------------------------------------------------
328
/// Periodic Hann window of length `n`, equivalent to
/// `torch.hann_window(n, periodic=True)`.
///
/// Each sample is `w[k] = 0.5 · (1 − cos(2π·k / n))`. The divisor is `n`,
/// not the `n − 1` used by the symmetric variant, so the window tiles
/// seamlessly for STFT use.
fn periodic_hann(n: usize) -> Vec<f32> {
    use std::f32::consts::PI;

    let mut window = Vec::with_capacity(n);
    for k in 0..n {
        let phase = 2.0 * PI * k as f32 / n as f32;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
339
340// ---------------------------------------------------------------------------
341// Audio preparation
342// ---------------------------------------------------------------------------
343
344/// Pad or truncate `samples` to exactly `RING_CAPACITY` samples.
345///
346/// - Longer: keep the **last** 8 s (discard oldest).
347/// - Shorter: **front-pad** with zeros so audio is right-aligned.
348fn prepare_audio(samples: &[f32]) -> Vec<f32> {
349 match samples.len().cmp(&RING_CAPACITY) {
350 std::cmp::Ordering::Equal => samples.to_vec(),
351 std::cmp::Ordering::Greater => samples[samples.len() - RING_CAPACITY..].to_vec(),
352 std::cmp::Ordering::Less => {
353 let mut out = vec![0.0f32; RING_CAPACITY - samples.len()];
354 out.extend_from_slice(samples);
355 out
356 }
357 }
358}
359
360// ---------------------------------------------------------------------------
361// PipecatSmartTurn
362// ---------------------------------------------------------------------------
363
/// Pipecat Smart Turn v3 detector.
///
/// Buffers up to 8 seconds of audio internally. Call [`push_audio`] with
/// every incoming 16 kHz frame (frames at other sample rates are ignored),
/// then call [`predict`] when the VAD fires end-of-speech to get a
/// [`TurnPrediction`].
///
/// # Usage with VAD
///
/// ```no_run
/// # #[cfg(feature = "pipecat")]
/// # {
/// use wavekat_turn::audio::PipecatSmartTurn;
/// use wavekat_turn::AudioTurnDetector;
///
/// let mut detector = PipecatSmartTurn::new().unwrap();
/// // ... feed frames via push_audio ...
/// let prediction = detector.predict().unwrap();
/// println!("{:?} ({:.2})", prediction.state, prediction.confidence);
/// # }
/// ```
///
/// [`push_audio`]: AudioTurnDetector::push_audio
/// [`predict`]: AudioTurnDetector::predict
pub struct PipecatSmartTurn {
    /// ONNX Runtime session running the Smart Turn model.
    session: ort::session::Session,
    /// Most recent 16 kHz samples, oldest first. Older samples are evicted
    /// from the front so at most `RING_CAPACITY` (8 s) are kept.
    ring_buffer: VecDeque<f32>,
    /// Log-mel feature extractor; holds a feature cache between `predict` calls.
    mel: MelExtractor,
    /// Counts samples pushed since the last `predict()` call.
    /// Used to compute `shift_frames` for incremental STFT.
    samples_since_predict: usize,
}
395
// SAFETY (Send): relies on ort 2.x's `Session` being safe to move to another
// thread. NOTE(review): confirm this against the pinned ort version. If
// `Session` and the other fields are already `Send + Sync`, both impls are
// redundant and could be removed.
// SAFETY (Sync): every method defined in this file that touches `session` or
// the mutable buffers (`push_audio`, `predict`, `reset`) takes `&mut self`.
// A shared `&PipecatSmartTurn` therefore cannot drive concurrent inference or
// mutate state. Any `&self` method added later (here or via a trait default)
// must preserve that property.
unsafe impl Send for PipecatSmartTurn {}
unsafe impl Sync for PipecatSmartTurn {}
400
401impl PipecatSmartTurn {
402 /// Load the Smart Turn v3.2 model embedded at compile time.
403 pub fn new() -> Result<Self, TurnError> {
404 let session = onnx::session_from_memory(MODEL_BYTES)?;
405 Ok(Self::build(session))
406 }
407
408 /// Load a model from a custom path on disk.
409 ///
410 /// Useful for CI environments that supply the model file separately, or
411 /// for evaluating fine-tuned variants without recompiling.
412 pub fn from_file(path: impl AsRef<Path>) -> Result<Self, TurnError> {
413 let session = onnx::session_from_file(path)?;
414 Ok(Self::build(session))
415 }
416
417 fn build(session: ort::session::Session) -> Self {
418 Self {
419 session,
420 ring_buffer: VecDeque::with_capacity(RING_CAPACITY),
421 mel: MelExtractor::new(),
422 samples_since_predict: 0,
423 }
424 }
425}
426
427impl AudioTurnDetector for PipecatSmartTurn {
428 /// Append audio to the internal ring buffer.
429 ///
430 /// Frames with a sample rate other than 16 kHz are silently dropped.
431 /// The ring buffer holds at most 8 s; older samples are evicted.
432 fn push_audio(&mut self, frame: &AudioFrame) {
433 if frame.sample_rate() != SAMPLE_RATE {
434 return;
435 }
436 let samples = frame.samples();
437 // Evict oldest samples to make room
438 let overflow = (self.ring_buffer.len() + samples.len()).saturating_sub(RING_CAPACITY);
439 if overflow > 0 {
440 self.ring_buffer.drain(..overflow);
441 }
442 self.ring_buffer.extend(samples.iter().copied());
443 self.samples_since_predict += samples.len();
444 }
445
446 /// Run inference on the buffered audio.
447 ///
448 /// Takes a snapshot of the ring buffer, pads/truncates to 8 s, extracts
449 /// Whisper log-mel features, and runs ONNX inference.
450 fn predict(&mut self) -> Result<TurnPrediction, TurnError> {
451 let t_start = Instant::now();
452
453 // Stage 1: Snapshot the ring buffer and prepare exactly 128 000 samples
454 let shift_frames = self.samples_since_predict / HOP_LENGTH;
455 self.samples_since_predict = 0;
456
457 let buffered: Vec<f32> = self.ring_buffer.iter().copied().collect();
458 let audio = prepare_audio(&buffered);
459 let t_after_audio_prep = Instant::now();
460
461 // Stage 2: Extract [N_MELS × N_FRAMES] log-mel features (incremental)
462 let mel_spec = self.mel.extract(&audio, shift_frames);
463 let t_after_mel = Instant::now();
464
465 // Stage 3: Reshape to [1, N_MELS, N_FRAMES] and run ONNX inference
466 let (raw, _) = mel_spec.into_raw_vec_and_offset();
467 let input_array = Array3::from_shape_vec((1, N_MELS, N_FRAMES), raw)
468 .expect("internal: mel output has wrong element count");
469
470 let input_tensor = Tensor::from_array(input_array)
471 .map_err(|e| TurnError::BackendError(format!("failed to create input tensor: {e}")))?;
472
473 let outputs = self
474 .session
475 .run(inputs!["input_features" => input_tensor])
476 .map_err(|e| TurnError::BackendError(format!("inference failed: {e}")))?;
477 let t_after_onnx = Instant::now();
478
479 // Extract sigmoid probability from the "logits" output
480 let output = outputs
481 .get("logits")
482 .ok_or_else(|| TurnError::BackendError("missing 'logits' output tensor".into()))?;
483 let (_, data): (_, &[f32]) = output
484 .try_extract_tensor()
485 .map_err(|e| TurnError::BackendError(format!("failed to extract logits: {e}")))?;
486 let probability = *data
487 .first()
488 .ok_or_else(|| TurnError::BackendError("logits tensor is empty".into()))?;
489
490 let latency_ms = t_start.elapsed().as_millis() as u64;
491
492 let us = |a: Instant, b: Instant| (b - a).as_secs_f64() * 1_000_000.0;
493 let stage_times = vec![
494 StageTiming {
495 name: "audio_prep",
496 us: us(t_start, t_after_audio_prep),
497 },
498 StageTiming {
499 name: "mel",
500 us: us(t_after_audio_prep, t_after_mel),
501 },
502 StageTiming {
503 name: "onnx",
504 us: us(t_after_mel, t_after_onnx),
505 },
506 ];
507
508 // probability = P(turn complete); > 0.5 means the speaker has finished
509 let (state, confidence) = if probability > 0.5 {
510 (TurnState::Finished, probability)
511 } else {
512 (TurnState::Unfinished, 1.0 - probability)
513 };
514
515 Ok(TurnPrediction {
516 state,
517 confidence,
518 latency_ms,
519 stage_times,
520 })
521 }
522
523 /// Clear the ring buffer. Call at the start of each new speech turn.
524 fn reset(&mut self) {
525 self.ring_buffer.clear();
526 self.samples_since_predict = 0;
527 self.mel.invalidate_cache();
528 }
529}