Skip to main content

wavekat_turn/audio/
pipecat.rs

//! Pipecat Smart Turn v3 backend.
//!
//! Audio-based turn detection using the Smart Turn ONNX model.
//! Expects 16 kHz f32 PCM input. Telephony audio at 8 kHz must be
//! upsampled before feeding to this detector.
//!
//! # Model
//!
//! - Source:  <https://huggingface.co/pipecat-ai/smart-turn-v3>
//! - File:    `smart-turn-v3.2-cpu.onnx` (int8 quantized, ~8 MB)
//! - License: BSD 2-Clause
//!
//! # Tensor specification
//!
//! | Role   | Name             | Shape          | Dtype   |
//! |--------|------------------|----------------|---------|
//! | Input  | `input_features` | `[B, 80, 800]` | float32 |
//! | Output | `logits`         | `[B, 1]`       | float32 |
//!
//! Despite the name, `logits` is a **sigmoid probability** P(turn complete)
//! in [0, 1] — the sigmoid is fused into the model before ONNX export.
//! Threshold: `probability > 0.5` → `TurnState::Finished`.
//!
//! # Mel-feature specification
//!
//! The model was trained with HuggingFace `WhisperFeatureExtractor(chunk_length=8)`:
//!
//! | Parameter     | Value                          |
//! |---------------|--------------------------------|
//! | Sample rate   | 16 000 Hz                      |
//! | n_fft         | 400 samples (25 ms)            |
//! | hop_length    | 160 samples (10 ms)            |
//! | n_mels        | 80                             |
//! | Freq range    | 0 – 8 000 Hz                   |
//! | Mel scale     | Slaney (NOT HTK)               |
//! | Window        | Hann (periodic, size 400)      |
//! | Pre-emphasis  | None                           |
//! | Log           | log10 with ε = 1e-10           |
//! | Normalization | clamp(max − 8), then (x + 4) / 4 |
//!
//! # Audio buffer
//!
//! - Exactly **8 seconds = 128 000 samples** at 16 kHz.
//! - Shorter input: **front-padded** with zeros (audio is at the end).
//! - Longer input: the **last** 8 s is used (oldest samples discarded).

47use std::collections::VecDeque;
48use std::path::Path;
49use std::sync::Arc;
50use std::time::Instant;
51
52use ndarray::{s, Array2, Array3};
53use ort::{inputs, value::Tensor};
54use realfft::num_complex::Complex;
55use realfft::{RealFftPlanner, RealToComplex};
56
57use crate::onnx;
58use crate::{AudioFrame, AudioTurnDetector, StageTiming, TurnError, TurnPrediction, TurnState};
59
// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------

/// Sample rate the model expects (Hz); frames at any other rate are dropped
/// by `push_audio`.
const SAMPLE_RATE: u32 = 16_000;
/// FFT window size in samples (25 ms at 16 kHz).
const N_FFT: usize = 400;
/// STFT hop length in samples (10 ms at 16 kHz).
const HOP_LENGTH: usize = 160;
/// Number of mel filterbank bins.
const N_MELS: usize = 80;
/// Number of STFT frames the model expects (8 s × 100 frames per second).
const N_FRAMES: usize = 800;
/// FFT frequency bins: N_FFT/2 + 1.
const N_FREQS: usize = N_FFT / 2 + 1; // 201
/// Ring buffer capacity: 8 s × 16 kHz.
const RING_CAPACITY: usize = 8 * SAMPLE_RATE as usize; // 128 000

/// Embedded ONNX model bytes, downloaded by build.rs at compile time.
const MODEL_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/smart-turn-v3.2-cpu.onnx"));
81
// ---------------------------------------------------------------------------
// Mel feature extractor
// ---------------------------------------------------------------------------
85
/// Pre-computed Whisper-style log-mel feature extractor.
///
/// All expensive setup (filterbank, window, FFT plan) happens once in
/// [`MelExtractor::new`]; [`MelExtractor::extract`] is then called per
/// inference and reuses the buffers and caches below.
struct MelExtractor {
    /// Slaney-normalised mel filterbank: shape [N_MELS, N_FREQS].
    mel_filters: Array2<f32>,
    /// Periodic Hann window of length N_FFT.
    hann_window: Vec<f32>,
    /// Reusable forward real FFT plan.
    fft: Arc<dyn RealToComplex<f32>>,
    /// Reusable scratch buffer for the FFT.
    fft_scratch: Vec<Complex<f32>>,
    /// Reusable output spectrum buffer (N_FREQS complex values).
    spectrum_buf: Vec<Complex<f32>>,
    /// Cached power spectrogram [N_FREQS × (N_FRAMES+1)] from the previous call.
    /// Enables incremental STFT: only new frames are recomputed.
    cached_power_spec: Option<Array2<f32>>,
    /// Cached mel spectrogram [N_MELS × N_FRAMES] from the previous call.
    /// Enables incremental mel filterbank: only new columns are recomputed.
    cached_mel_spec: Option<Array2<f32>>,
}
108
109impl MelExtractor {
110    fn new() -> Self {
111        let mel_filters = build_mel_filters(
112            SAMPLE_RATE as usize,
113            N_FFT,
114            N_MELS,
115            0.0,
116            SAMPLE_RATE as f32 / 2.0,
117        );
118        let hann_window = periodic_hann(N_FFT);
119
120        let mut planner = RealFftPlanner::<f32>::new();
121        let fft = planner.plan_fft_forward(N_FFT);
122        let fft_scratch = fft.make_scratch_vec();
123        let spectrum_buf = fft.make_output_vec();
124
125        Self {
126            mel_filters,
127            hann_window,
128            fft,
129            fft_scratch,
130            spectrum_buf,
131            cached_power_spec: None,
132            cached_mel_spec: None,
133        }
134    }
135
136    /// Compute a [N_MELS × N_FRAMES] log-mel spectrogram from exactly
137    /// `RING_CAPACITY` samples of 16 kHz mono audio.
138    ///
139    /// `shift_frames` is how many STFT frames worth of new audio were added
140    /// since the last call. When a valid cache exists and `shift_frames` is
141    /// in range, only the last `shift_frames` columns of the power spectrogram
142    /// are recomputed; the rest are copied from the shifted cache.
143    fn extract(&mut self, audio: &[f32], shift_frames: usize) -> Array2<f32> {
144        debug_assert_eq!(audio.len(), RING_CAPACITY);
145
146        // ---- Center-pad: N_FFT/2 zeros on each side → 128 400 samples ----
147        // This replicates librosa/PyTorch `center=True` STFT behaviour, which
148        // gives exactly N_FRAMES + 1 = 801 frames; we discard the last one.
149        let pad = N_FFT / 2;
150        let mut padded = vec![0.0f32; pad + audio.len() + pad];
151        padded[pad..pad + audio.len()].copy_from_slice(audio);
152
153        // n_total = (128 400 − 400) / 160 + 1 = 801
154        let n_total_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1;
155
156        // ---- Incremental STFT ----
157        // If we have a cached power spec and shift_frames < n_total_frames,
158        // reuse the unchanged frames by shifting the cache left and only
159        // computing the `shift_frames` new columns at the end.
160        let first_new_frame = match &self.cached_power_spec {
161            Some(cached) if shift_frames > 0 && shift_frames < n_total_frames => {
162                let kept = n_total_frames - shift_frames;
163                let mut power_spec = Array2::<f32>::zeros((N_FREQS, n_total_frames));
164                power_spec
165                    .slice_mut(s![.., ..kept])
166                    .assign(&cached.slice(s![.., shift_frames..]));
167                self.cached_power_spec = Some(power_spec);
168                kept // only compute frames [kept..n_total_frames]
169            }
170            _ => {
171                self.cached_power_spec = Some(Array2::<f32>::zeros((N_FREQS, n_total_frames)));
172                0 // cold start: compute all frames
173            }
174        };
175
176        let power_spec = self.cached_power_spec.as_mut().unwrap();
177        let mut frame_buf = vec![0.0f32; N_FFT];
178
179        for frame_idx in first_new_frame..n_total_frames {
180            let start = frame_idx * HOP_LENGTH;
181            // Apply periodic Hann window
182            for (i, (&s, &w)) in padded[start..start + N_FFT]
183                .iter()
184                .zip(self.hann_window.iter())
185                .enumerate()
186            {
187                frame_buf[i] = s * w;
188            }
189
190            self.fft
191                .process_with_scratch(
192                    &mut frame_buf,
193                    &mut self.spectrum_buf,
194                    &mut self.fft_scratch,
195                )
196                .expect("FFT failed: internal buffer size mismatch");
197
198            for (k, c) in self.spectrum_buf.iter().enumerate() {
199                power_spec[[k, frame_idx]] = c.re * c.re + c.im * c.im;
200            }
201        }
202
203        // Take first N_FRAMES columns (drop the trailing frame)
204        let power_spec_view = power_spec.slice(s![.., ..N_FRAMES]);
205
206        // ---- Incremental mel filterbank: [N_MELS, N_FREQS] × [N_FREQS, shift_frames] ----
207        // Reuse the cached mel columns for the unchanged frames; only multiply
208        // the new power-spectrum columns against the filterbank.
209        let mel_spec = match &self.cached_mel_spec {
210            Some(cached) if shift_frames > 0 && shift_frames <= N_FRAMES => {
211                let kept = N_FRAMES - shift_frames;
212                let mut ms = Array2::<f32>::zeros((N_MELS, N_FRAMES));
213                // Shift old columns left
214                ms.slice_mut(s![.., ..kept])
215                    .assign(&cached.slice(s![.., shift_frames..]));
216                // Apply filterbank only to the new power-spectrum columns
217                let new_power = power_spec_view.slice(s![.., kept..]);
218                ms.slice_mut(s![.., kept..])
219                    .assign(&self.mel_filters.dot(&new_power));
220                ms
221            }
222            _ => self.mel_filters.dot(&power_spec_view),
223        };
224        self.cached_mel_spec = Some(mel_spec.clone());
225
226        // ---- Log10 with floor at 1e-10 ----
227        let mut log_mel = mel_spec.mapv(|x| x.max(1e-10_f32).log10());
228
229        // ---- Dynamic range compression and normalization ----
230        // Matches WhisperFeatureExtractor: clamp to [max−8, ∞], then (x+4)/4
231        let max_val = log_mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
232        log_mel.mapv_inplace(|x| (x.max(max_val - 8.0) + 4.0) / 4.0);
233
234        log_mel
235    }
236
237    /// Invalidate all caches (call on reset).
238    fn invalidate_cache(&mut self) {
239        self.cached_power_spec = None;
240        self.cached_mel_spec = None;
241    }
242}
243
// ---------------------------------------------------------------------------
// Mel filterbank construction — Slaney scale, slaney norm
// ---------------------------------------------------------------------------
247
/// Convert Hz to mel on the Slaney/librosa scale (NOT HTK).
///
/// Linear below 1 kHz (200/3 Hz per mel step), logarithmic above.
fn hz_to_mel(hz: f32) -> f32 {
    const F_SP: f32 = 200.0 / 3.0; // linear region slope (Hz per mel)
    const MIN_LOG_HZ: f32 = 1000.0; // start of the log region
    const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP; // = 15.0
    // logstep = ln(6.4) / 27  (≈ 0.068752)
    let logstep = (6.4_f32).ln() / 27.0;
    if hz < MIN_LOG_HZ {
        hz / F_SP
    } else {
        MIN_LOG_MEL + (hz / MIN_LOG_HZ).ln() / logstep
    }
}
261
/// Convert mel back to Hz (Slaney scale); inverse of `hz_to_mel`.
fn mel_to_hz(mel: f32) -> f32 {
    const F_SP: f32 = 200.0 / 3.0;
    const MIN_LOG_HZ: f32 = 1000.0;
    const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP; // = 15.0
    let logstep = (6.4_f32).ln() / 27.0;
    if mel < MIN_LOG_MEL {
        mel * F_SP
    } else {
        MIN_LOG_HZ * ((mel - MIN_LOG_MEL) * logstep).exp()
    }
}
274
275/// Build a Slaney-normalised mel filterbank of shape [n_mels, n_freqs].
276///
277/// Matches `librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax,
278///   norm="slaney", dtype=float32)` which is what HuggingFace's
279/// `WhisperFeatureExtractor` uses internally.
280fn build_mel_filters(
281    sr: usize,
282    n_fft: usize,
283    n_mels: usize,
284    f_min: f32,
285    f_max: f32,
286) -> Array2<f32> {
287    let n_freqs = n_fft / 2 + 1;
288
289    // FFT frequency bins: 0, sr/n_fft, 2·sr/n_fft, …
290    let fft_freqs: Vec<f32> = (0..n_freqs)
291        .map(|i| i as f32 * sr as f32 / n_fft as f32)
292        .collect();
293
294    // n_mels + 2 equally-spaced mel points (edge + n_mels centres + edge)
295    let mel_min = hz_to_mel(f_min);
296    let mel_max = hz_to_mel(f_max);
297    let mel_pts: Vec<f32> = (0..=(n_mels + 1))
298        .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32)
299        .collect();
300    let hz_pts: Vec<f32> = mel_pts.iter().map(|&m| mel_to_hz(m)).collect();
301
302    // Build triangular filters with Slaney normalisation
303    let mut filters = Array2::<f32>::zeros((n_mels, n_freqs));
304    for m in 0..n_mels {
305        let f_left = hz_pts[m];
306        let f_center = hz_pts[m + 1];
307        let f_right = hz_pts[m + 2];
308        // Slaney norm: 2 / (right_hz − left_hz)
309        let enorm = 2.0 / (f_right - f_left);
310
311        for (k, &f) in fft_freqs.iter().enumerate() {
312            let w = if f >= f_left && f <= f_center {
313                (f - f_left) / (f_center - f_left)
314            } else if f > f_center && f <= f_right {
315                (f_right - f) / (f_right - f_center)
316            } else {
317                0.0
318            };
319            filters[[m, k]] = w * enorm;
320        }
321    }
322    filters
323}
324
// ---------------------------------------------------------------------------
// Hann window
// ---------------------------------------------------------------------------
328
/// Periodic Hann window of length `n`, matching `torch.hann_window(n, periodic=True)`.
///
/// Formula: `w[k] = 0.5 · (1 − cos(2π·k / n))` for k in 0..n.
/// The periodic variant divides by `n`; the symmetric one divides by `n − 1`.
fn periodic_hann(n: usize) -> Vec<f32> {
    use std::f32::consts::PI;
    let mut window = Vec::with_capacity(n);
    for k in 0..n {
        let phase = 2.0 * PI * k as f32 / n as f32;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
339
// ---------------------------------------------------------------------------
// Audio preparation
// ---------------------------------------------------------------------------
343
344/// Pad or truncate `samples` to exactly `RING_CAPACITY` samples.
345///
346/// - Longer: keep the **last** 8 s (discard oldest).
347/// - Shorter: **front-pad** with zeros so audio is right-aligned.
348fn prepare_audio(samples: &[f32]) -> Vec<f32> {
349    match samples.len().cmp(&RING_CAPACITY) {
350        std::cmp::Ordering::Equal => samples.to_vec(),
351        std::cmp::Ordering::Greater => samples[samples.len() - RING_CAPACITY..].to_vec(),
352        std::cmp::Ordering::Less => {
353            let mut out = vec![0.0f32; RING_CAPACITY - samples.len()];
354            out.extend_from_slice(samples);
355            out
356        }
357    }
358}
359
// ---------------------------------------------------------------------------
// PipecatSmartTurn
// ---------------------------------------------------------------------------
363
/// Pipecat Smart Turn v3 detector.
///
/// Buffers up to 8 seconds of audio internally. Call [`push_audio`] with
/// every incoming 16 kHz frame, then call [`predict`] when the VAD fires
/// end-of-speech to get a [`TurnPrediction`].
///
/// # Usage with VAD
///
/// ```no_run
/// # #[cfg(feature = "pipecat")]
/// # {
/// use wavekat_turn::audio::PipecatSmartTurn;
/// use wavekat_turn::AudioTurnDetector;
///
/// let mut detector = PipecatSmartTurn::new().unwrap();
/// // ... feed frames via push_audio ...
/// let prediction = detector.predict().unwrap();
/// println!("{:?} ({:.2})", prediction.state, prediction.confidence);
/// # }
/// ```
///
/// [`push_audio`]: AudioTurnDetector::push_audio
/// [`predict`]: AudioTurnDetector::predict
pub struct PipecatSmartTurn {
    // ONNX Runtime session holding the loaded Smart Turn model.
    session: ort::session::Session,
    // Rolling window of the most recent ≤ 8 s of 16 kHz samples.
    ring_buffer: VecDeque<f32>,
    // Whisper-style log-mel extractor with incremental STFT/mel caches.
    mel: MelExtractor,
    /// Counts samples pushed since the last `predict()` call.
    /// Used to compute `shift_frames` for incremental STFT.
    samples_since_predict: usize,
}
395
// SAFETY: all owned fields (VecDeque<f32>, Vec buffers, Array2) are plain
// Send data, and ort::Session is Send in ort 2.x. NOTE(review): this also
// assumes the realfft plan behind `Arc<dyn RealToComplex<f32>>` is
// Send + Sync — realfft declares its plan traits with those supertraits, but
// verify against the pinned version; if that ever changes these impls become
// unsound. Sync rests on the weaker argument that every method touching the
// session takes `&mut self`, so a shared `&PipecatSmartTurn` exposes no
// mutation path. If Session and the plan already satisfy Send + Sync, these
// manual impls are redundant and could simply be removed — TODO confirm.
unsafe impl Send for PipecatSmartTurn {}
unsafe impl Sync for PipecatSmartTurn {}
400
401impl PipecatSmartTurn {
402    /// Load the Smart Turn v3.2 model embedded at compile time.
403    pub fn new() -> Result<Self, TurnError> {
404        let session = onnx::session_from_memory(MODEL_BYTES)?;
405        Ok(Self::build(session))
406    }
407
408    /// Load a model from a custom path on disk.
409    ///
410    /// Useful for CI environments that supply the model file separately, or
411    /// for evaluating fine-tuned variants without recompiling.
412    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, TurnError> {
413        let session = onnx::session_from_file(path)?;
414        Ok(Self::build(session))
415    }
416
417    fn build(session: ort::session::Session) -> Self {
418        Self {
419            session,
420            ring_buffer: VecDeque::with_capacity(RING_CAPACITY),
421            mel: MelExtractor::new(),
422            samples_since_predict: 0,
423        }
424    }
425}
426
427impl AudioTurnDetector for PipecatSmartTurn {
428    /// Append audio to the internal ring buffer.
429    ///
430    /// Frames with a sample rate other than 16 kHz are silently dropped.
431    /// The ring buffer holds at most 8 s; older samples are evicted.
432    fn push_audio(&mut self, frame: &AudioFrame) {
433        if frame.sample_rate() != SAMPLE_RATE {
434            return;
435        }
436        let samples = frame.samples();
437        // Evict oldest samples to make room
438        let overflow = (self.ring_buffer.len() + samples.len()).saturating_sub(RING_CAPACITY);
439        if overflow > 0 {
440            self.ring_buffer.drain(..overflow);
441        }
442        self.ring_buffer.extend(samples.iter().copied());
443        self.samples_since_predict += samples.len();
444    }
445
446    /// Run inference on the buffered audio.
447    ///
448    /// Takes a snapshot of the ring buffer, pads/truncates to 8 s, extracts
449    /// Whisper log-mel features, and runs ONNX inference.
450    fn predict(&mut self) -> Result<TurnPrediction, TurnError> {
451        let t_start = Instant::now();
452
453        // Stage 1: Snapshot the ring buffer and prepare exactly 128 000 samples
454        let shift_frames = self.samples_since_predict / HOP_LENGTH;
455        self.samples_since_predict = 0;
456
457        let buffered: Vec<f32> = self.ring_buffer.iter().copied().collect();
458        let audio = prepare_audio(&buffered);
459        let t_after_audio_prep = Instant::now();
460
461        // Stage 2: Extract [N_MELS × N_FRAMES] log-mel features (incremental)
462        let mel_spec = self.mel.extract(&audio, shift_frames);
463        let t_after_mel = Instant::now();
464
465        // Stage 3: Reshape to [1, N_MELS, N_FRAMES] and run ONNX inference
466        let (raw, _) = mel_spec.into_raw_vec_and_offset();
467        let input_array = Array3::from_shape_vec((1, N_MELS, N_FRAMES), raw)
468            .expect("internal: mel output has wrong element count");
469
470        let input_tensor = Tensor::from_array(input_array)
471            .map_err(|e| TurnError::BackendError(format!("failed to create input tensor: {e}")))?;
472
473        let outputs = self
474            .session
475            .run(inputs!["input_features" => input_tensor])
476            .map_err(|e| TurnError::BackendError(format!("inference failed: {e}")))?;
477        let t_after_onnx = Instant::now();
478
479        // Extract sigmoid probability from the "logits" output
480        let output = outputs
481            .get("logits")
482            .ok_or_else(|| TurnError::BackendError("missing 'logits' output tensor".into()))?;
483        let (_, data): (_, &[f32]) = output
484            .try_extract_tensor()
485            .map_err(|e| TurnError::BackendError(format!("failed to extract logits: {e}")))?;
486        let probability = *data
487            .first()
488            .ok_or_else(|| TurnError::BackendError("logits tensor is empty".into()))?;
489
490        let latency_ms = t_start.elapsed().as_millis() as u64;
491
492        let us = |a: Instant, b: Instant| (b - a).as_secs_f64() * 1_000_000.0;
493        let stage_times = vec![
494            StageTiming {
495                name: "audio_prep",
496                us: us(t_start, t_after_audio_prep),
497            },
498            StageTiming {
499                name: "mel",
500                us: us(t_after_audio_prep, t_after_mel),
501            },
502            StageTiming {
503                name: "onnx",
504                us: us(t_after_mel, t_after_onnx),
505            },
506        ];
507
508        // probability = P(turn complete); > 0.5 means the speaker has finished
509        let (state, confidence) = if probability > 0.5 {
510            (TurnState::Finished, probability)
511        } else {
512            (TurnState::Unfinished, 1.0 - probability)
513        };
514
515        Ok(TurnPrediction {
516            state,
517            confidence,
518            latency_ms,
519            stage_times,
520        })
521    }
522
523    /// Clear the ring buffer. Call at the start of each new speech turn.
524    fn reset(&mut self) {
525        self.ring_buffer.clear();
526        self.samples_since_predict = 0;
527        self.mel.invalidate_cache();
528    }
529}