// voice_engine/media/vad/tiny_ten.rs

1use super::{VADOption, VadEngine};
2use crate::media::{AudioFrame, PcmBuf, Samples};
3use anyhow::Result;
4use realfft::{RealFftPlanner, RealToComplex};
5use std::sync::Arc;
6
// ---------------------------------------------------------------------------
// Front-end and network dimensions
// ---------------------------------------------------------------------------

/// Input sample rate in Hz (`TinyTen::new` rejects any other rate).
const SAMPLE_RATE: u32 = 16000;
/// Frame advance in samples: 256 samples = 16 ms at 16 kHz.
const HOP_SIZE: usize = 256;
/// Real-FFT length; the analysed frame is zero-padded up to this size.
const FFT_SIZE: usize = 1024;
/// Maximum number of samples analysed per frame (length of the Hann window).
const WINDOW_SIZE: usize = 768;
/// Number of triangular mel filters.
const MEL_FILTER_BANK_NUM: usize = 40;
/// Per-frame feature width: 40 mel energies plus 1 pitch slot.
const FEATURE_LEN: usize = 41;
/// Number of consecutive feature frames stacked as the network input.
const CONTEXT_WINDOW_LEN: usize = 3;
/// Hidden-state width; presumably the LSTM hidden size — confirm at its use site.
const HIDDEN_SIZE: usize = 64;
/// Small constant added before `ln` and to the normalization denominators.
const EPS: f32 = 1e-20;
/// First-order pre-emphasis coefficient: y[n] = x[n] - 0.97 * x[n-1].
const PRE_EMPHASIS_COEFF: f32 = 0.97;

// ---------------------------------------------------------------------------
// Feature normalization statistics (feature = (raw - mean) / std)
// Indices 0..40 are log-mel energies from low to high frequency; index 40 is pitch.
// ---------------------------------------------------------------------------

/// Per-dimension means subtracted from the raw features.
const FEATURE_MEANS: [f32; FEATURE_LEN] = [
    // mel 0-7
    -8.198_236, -6.265_716_6, -5.483_818_5, -4.758_691_3,
    -4.417_089, -4.142_893, -3.912_850_4, -3.845_928,
    // mel 8-15
    -3.657_090_4, -3.723_418_7, -3.876_134_2, -3.843_891,
    -3.690_405_1, -3.756_065_8, -3.698_696_1, -3.650_463,
    // mel 16-23
    -3.700_468_8, -3.567_321_3, -3.498_900_2, -3.477_807,
    -3.458_816, -3.444_923_9, -3.401_328_6, -3.306_261_3,
    // mel 24-31
    -3.278_556_8, -3.233_250_9, -3.198_616, -3.204_526_4,
    -3.208_798_6, -3.257_838, -3.381_376_7, -3.534_021_4,
    // mel 32-39
    -3.640_868, -3.726_858_9, -3.773_731, -3.804_667_2,
    -3.832_901, -3.871_120_5, -3.990_593, -4.480_289_5,
    // pitch
    9.235_69e1,
];

/// Per-dimension standard deviations dividing the mean-centred features.
const FEATURE_STDS: [f32; FEATURE_LEN] = [
    // mel 0-7
    5.166_064, 4.977_21, 4.698_896, 4.630_621_4,
    4.634_348, 4.641_156, 4.640_676_5, 4.666_367,
    // mel 8-15
    4.650_534_6, 4.640_021, 4.637_4, 4.620_099,
    4.596_316_3, 4.562_655, 4.554_36, 4.566_910_7,
    // mel 16-23
    4.562_49, 4.562_413, 4.585_299_5, 4.600_179_7,
    4.592_846, 4.585_923, 4.583_496_6, 4.626_093,
    // mel 24-31
    4.626_958, 4.626_289_4, 4.637_006, 4.683_016,
    4.726_814, 4.734_29, 4.753_227, 4.849_723,
    // mel 32-39
    4.869_435, 4.884_483, 4.921_327, 4.959_212_3,
    4.996_619, 5.044_823_6, 5.072_217, 5.096_439_4,
    // pitch
    1.152_136_9e2,
];
107
108pub struct TenFeatureExtractor {
109    pre_emphasis_prev: f32,
110    mel_filters: ndarray::Array2<f32>,
111    mel_filter_ranges: Vec<(usize, usize)>,
112    window: Vec<f32>,
113    // FFT related fields
114    rfft: Arc<dyn RealToComplex<f32>>,
115    fft_scratch: Vec<realfft::num_complex::Complex<f32>>,
116    fft_output: Vec<realfft::num_complex::Complex<f32>>,
117    fft_input: Vec<f32>,
118    power_spectrum: Vec<f32>,
119    inv_stds: Vec<f32>,
120}
121
122impl TenFeatureExtractor {
123    pub fn new() -> Self {
124        // Generate mel filter bank
125        let (mel_filters, mel_filter_ranges) = Self::generate_mel_filters();
126
127        // Generate Hann window
128        let window = super::utils::generate_hann_window(WINDOW_SIZE, false);
129
130        // Initialize FFT
131        let mut planner = RealFftPlanner::<f32>::new();
132        let rfft = planner.plan_fft_forward(FFT_SIZE);
133        let fft_scratch = rfft.make_scratch_vec();
134        let fft_output = rfft.make_output_vec();
135        let fft_input = rfft.make_input_vec();
136        let power_spectrum = vec![0.0; FFT_SIZE / 2 + 1];
137
138        // Pre-calculate inverse STDs
139        let inv_stds: Vec<f32> = FEATURE_STDS.iter().map(|&std| 1.0 / (std + EPS)).collect();
140
141        Self {
142            pre_emphasis_prev: 0.0,
143            mel_filters,
144            mel_filter_ranges,
145            window,
146            rfft,
147            fft_scratch,
148            fft_output,
149            fft_input,
150            power_spectrum,
151            inv_stds,
152        }
153    }
154
155    fn generate_mel_filters() -> (ndarray::Array2<f32>, Vec<(usize, usize)>) {
156        let n_bins = FFT_SIZE / 2 + 1;
157
158        // Generate mel frequency points
159        let low_mel = 2595.0_f32 * (1.0_f32 + 0.0_f32 / 700.0_f32).log10();
160        let high_mel = 2595.0_f32 * (1.0_f32 + 8000.0_f32 / 700.0_f32).log10();
161
162        let mut mel_points = Vec::new();
163        for i in 0..=MEL_FILTER_BANK_NUM + 1 {
164            let mel = low_mel + (high_mel - low_mel) * i as f32 / (MEL_FILTER_BANK_NUM + 1) as f32;
165            mel_points.push(mel);
166        }
167
168        // Convert to Hz
169        let mut hz_points = Vec::new();
170        for mel in mel_points {
171            let hz = 700.0_f32 * (10.0_f32.powf(mel / 2595.0_f32) - 1.0_f32);
172            hz_points.push(hz);
173        }
174
175        // Convert to FFT bin indices
176        let mut bin_points = Vec::new();
177        for hz in hz_points {
178            let bin = ((FFT_SIZE + 1) as f32 * hz / SAMPLE_RATE as f32).floor() as usize;
179            bin_points.push(bin);
180        }
181
182        // Build mel filter bank
183        let mut mel_filters = ndarray::Array2::<f32>::zeros((MEL_FILTER_BANK_NUM, n_bins));
184        let mut ranges = Vec::with_capacity(MEL_FILTER_BANK_NUM);
185
186        for i in 0..MEL_FILTER_BANK_NUM {
187            let start = bin_points[i];
188            let end = bin_points[i + 2];
189            ranges.push((start, end));
190
191            // Left slope
192            for j in bin_points[i]..bin_points[i + 1] {
193                if j < n_bins {
194                    mel_filters[[i, j]] =
195                        (j - bin_points[i]) as f32 / (bin_points[i + 1] - bin_points[i]) as f32;
196                }
197            }
198
199            // Right slope
200            for j in bin_points[i + 1]..bin_points[i + 2] {
201                if j < n_bins {
202                    mel_filters[[i, j]] = (bin_points[i + 2] - j) as f32
203                        / (bin_points[i + 2] - bin_points[i + 1]) as f32;
204                }
205            }
206        }
207
208        (mel_filters, ranges)
209    }
210
211    fn pre_emphasis(prev_state: &mut f32, audio_frame: &[i16], output: &mut [f32]) {
212        if !audio_frame.is_empty() {
213            let inv_scale = 1.0 / 32768.0;
214            let first_sample = audio_frame[0] as f32;
215            output[0] = (first_sample - PRE_EMPHASIS_COEFF * *prev_state) * inv_scale;
216
217            // Use windows(2) to iterate over pairs (prev, curr)
218            // This avoids bounds checks and allows better vectorization
219            for (out, samples) in output[1..].iter_mut().zip(audio_frame.windows(2)) {
220                let prev = samples[0] as f32;
221                let curr = samples[1] as f32;
222                *out = (curr - PRE_EMPHASIS_COEFF * prev) * inv_scale;
223            }
224
225            if !audio_frame.is_empty() {
226                // Store unscaled last sample for next frame
227                *prev_state = audio_frame[audio_frame.len() - 1] as f32;
228            }
229        }
230    }
231
232    pub fn extract_features(&mut self, audio_frame: &[i16]) -> ndarray::Array1<f32> {
233        // Prepare FFT input buffer
234        // 1. Clear buffer
235        self.fft_input.fill(0.0);
236
237        // 2. Pre-emphasis directly into fft_input
238        let copy_len = audio_frame.len().min(WINDOW_SIZE);
239        Self::pre_emphasis(
240            &mut self.pre_emphasis_prev,
241            audio_frame,
242            &mut self.fft_input[..copy_len],
243        );
244
245        // 3. Windowing
246        for (i, sample) in self.fft_input.iter_mut().enumerate().take(copy_len) {
247            *sample *= self.window[i];
248        }
249
250        // 4. FFT
251        self.rfft
252            .process_with_scratch(
253                &mut self.fft_input,
254                &mut self.fft_output,
255                &mut self.fft_scratch,
256            )
257            .unwrap();
258
259        // 5. Power spectrum
260        let n_bins = FFT_SIZE / 2 + 1;
261        let scale = 1.0 / (32768.0 * 32768.0);
262
263        // Compute power spectrum once
264        // Use iterators to avoid bounds checks
265        for (pow, complex) in self.power_spectrum.iter_mut().zip(self.fft_output.iter()) {
266            *pow = (complex.re * complex.re + complex.im * complex.im) * scale;
267        }
268
269        // Mel filter bank features
270        let mut mel_features = ndarray::Array1::<f32>::zeros(MEL_FILTER_BANK_NUM);
271
272        for i in 0..MEL_FILTER_BANK_NUM {
273            let (start, end) = self.mel_filter_ranges[i];
274            let valid_end = end.min(n_bins);
275
276            let mut sum = 0.0;
277            if start < valid_end {
278                // Use slices for dot product to enable vectorization
279                let filter_row = self.mel_filters.row(i);
280                // Safety: we know the row is contiguous because we created it that way
281                // and we haven't modified layout.
282                if let Some(filter_slice) = filter_row.as_slice() {
283                    let filter_sub = &filter_slice[start..valid_end];
284                    let power_sub = &self.power_spectrum[start..valid_end];
285
286                    // This dot product should be auto-vectorized
287                    sum = super::simd::dot_product(filter_sub, power_sub);
288                } else {
289                    // Fallback if not contiguous (should not happen)
290                    for j in start..valid_end {
291                        sum += self.mel_filters[[i, j]] * self.power_spectrum[j];
292                    }
293                }
294            }
295            mel_features[i] = (sum + EPS).ln();
296        }
297
298        // Simple pitch estimation (using 0 as in Python code)
299        let pitch_freq = 0.0;
300
301        // Combine features
302        let mut features = ndarray::Array1::<f32>::zeros(FEATURE_LEN);
303        features
304            .slice_mut(ndarray::s![..MEL_FILTER_BANK_NUM])
305            .assign(&mel_features);
306        features[MEL_FILTER_BANK_NUM] = pitch_freq;
307
308        // Feature normalization
309        // Use pre-calculated inverse STDs and iterators
310        for (feat, (&mean, &inv_std)) in features
311            .iter_mut()
312            .zip(FEATURE_MEANS.iter().zip(self.inv_stds.iter()))
313        {
314            *feat = (*feat - mean) * inv_std;
315        }
316
317        features
318    }
319}
320
/// Dense 3-D activation tensor stored row-major in (H, W, C) order:
/// element `(y, x, ch)` lives at `data[(y * w + x) * c + ch]`.
#[derive(Clone, Debug)]
struct Tensor3D {
    data: Vec<f32>,
    h: usize,
    w: usize,
    c: usize,
}

impl Tensor3D {
    /// Allocates a zero-filled `h x w x c` tensor.
    fn new(h: usize, w: usize, c: usize) -> Self {
        Self {
            data: vec![0.0; h * w * c],
            h,
            w,
            c,
        }
    }

    /// Resets every element to 0.0 in place, keeping the allocation.
    fn zeros(&mut self) {
        self.data.fill(0.0);
    }

    /// Flat offset of `(y, x, ch)`.
    ///
    /// `Vec` indexing only checks the flat index, so an out-of-range `x` or `ch`
    /// would silently alias a neighbouring pixel; debug builds reject that here.
    #[inline]
    fn offset(&self, y: usize, x: usize, ch: usize) -> usize {
        debug_assert!(
            y < self.h && x < self.w && ch < self.c,
            "Tensor3D index ({}, {}, {}) out of bounds for shape ({}, {}, {})",
            y,
            x,
            ch,
            self.h,
            self.w,
            self.c
        );
        (y * self.w + x) * self.c + ch
    }

    /// Reads element `(y, x, ch)`; panics if the flat index is out of range.
    #[inline]
    fn get(&self, y: usize, x: usize, ch: usize) -> f32 {
        self.data[self.offset(y, x, ch)]
    }

    /// Writes element `(y, x, ch)`; panics if the flat index is out of range.
    #[inline]
    fn set(&mut self, y: usize, x: usize, ch: usize, val: f32) {
        let idx = self.offset(y, x, ch);
        self.data[idx] = val;
    }
}

/// 2-D convolution over `Tensor3D` activations with zero padding and channel
/// groups (`groups == in_channels == out_channels` is a depthwise convolution).
struct Conv2dLayer {
    /// Kernel weights laid out as `[out_c, in_c / groups, kernel_h, kernel_w]`.
    weights: Vec<f32>,
    /// Optional per-output-channel bias (`out_c` values).
    bias: Option<Vec<f32>>,
    in_channels: usize,
    out_channels: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    /// Zero padding as `[top, left, bottom, right]`.
    padding: [usize; 4],
    groups: usize,
}

impl Conv2dLayer {
    /// Creates a layer with zeroed weights sized
    /// `out_channels * (in_channels / groups) * kernel_h * kernel_w` and no bias.
    // The argument list mirrors the convolution's hyper-parameters one-to-one.
    #[allow(clippy::too_many_arguments)]
    fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
        padding: [usize; 4],
        groups: usize,
    ) -> Self {
        Self {
            weights: vec![0.0; out_channels * (in_channels / groups) * kernel_h * kernel_w],
            bias: None,
            in_channels,
            out_channels,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            padding,
            groups,
        }
    }

    /// Convolves `input` into the pre-allocated `output`.
    ///
    /// The caller sizes `output`; output pixel `(y, x)` reads the input window whose
    /// top-left corner is `(y * stride_h - pad_top, x * stride_w - pad_left)`, and
    /// taps falling outside `input` contribute zero. Specialized fast paths are used
    /// only when the layer AND the tensor shapes match them; everything else goes
    /// through the general implementation.
    fn forward_into(&self, input: &Tensor3D, output: &mut Tensor3D) {
        debug_assert_eq!(input.c, self.in_channels, "conv input channel mismatch");
        debug_assert_eq!(output.c, self.out_channels, "conv output channel mismatch");

        if self.forward_single_channel_3x3(input, output) {
            return;
        }
        if self.forward_pointwise(input, output) {
            return;
        }
        if self.forward_depthwise_row(input, output) {
            return;
        }
        self.forward_generic(input, output);
    }

    /// Fast path: one-channel 3x3 "valid" convolution (stride 1, no padding) that
    /// yields a single output row. Returns `false` without touching `output` if it
    /// does not apply.
    fn forward_single_channel_3x3(&self, input: &Tensor3D, output: &mut Tensor3D) -> bool {
        let applies = self.in_channels == 1
            && self.out_channels == 1
            && self.kernel_h == 3
            && self.kernel_w == 3
            && self.stride_h == 1
            && self.stride_w == 1
            && self.padding == [0, 0, 0, 0]
            && output.h == 1
            && input.h >= 3
            && input.w >= output.w + 2;
        if !applies {
            return false;
        }

        let bias = self.bias.as_ref().map_or(0.0, |b| b[0]);
        for x in 0..output.w {
            // Start from the bias and add the taps in row-major kernel order.
            let mut sum = bias;
            for ky in 0..3 {
                for kx in 0..3 {
                    sum += input.get(ky, x + kx, 0) * self.weights[ky * 3 + kx];
                }
            }
            output.set(0, x, 0, sum);
        }
        true
    }

    /// Fast path: 1x1 convolution (stride 1, no padding, one group) for any channel
    /// counts, i.e. a per-pixel matrix-vector product. Returns `false` without
    /// touching `output` if it does not apply.
    fn forward_pointwise(&self, input: &Tensor3D, output: &mut Tensor3D) -> bool {
        let (in_c, out_c) = (self.in_channels, self.out_channels);
        let applies = self.kernel_h == 1
            && self.kernel_w == 1
            && self.stride_h == 1
            && self.stride_w == 1
            && self.padding == [0, 0, 0, 0]
            && self.groups == 1
            && in_c > 0
            && out_c > 0
            && input.c == in_c
            && output.c == out_c
            && input.h == output.h
            && input.w == output.w
            && self.weights.len() == in_c * out_c;
        if !applies {
            return false;
        }

        // In (H, W, C) layout each pixel's channels are contiguous.
        let pixels_in = input.data.chunks_exact(in_c);
        let pixels_out = output.data.chunks_exact_mut(out_c);
        for (src, dst) in pixels_in.zip(pixels_out) {
            for (oc, (dst_val, w_row)) in dst
                .iter_mut()
                .zip(self.weights.chunks_exact(in_c))
                .enumerate()
            {
                let mut acc = self.bias.as_ref().map_or(0.0, |b| b[oc]);
                for (&v, &wt) in src.iter().zip(w_row) {
                    acc += v * wt;
                }
                *dst_val = acc;
            }
        }
        true
    }

    /// Fast path: depthwise `1 x kernel_w` convolution on single-row tensors with any
    /// horizontal stride and left/right padding. Out-of-range taps are skipped, which
    /// is equivalent to zero padding. Returns `false` without touching `output` if it
    /// does not apply.
    fn forward_depthwise_row(&self, input: &Tensor3D, output: &mut Tensor3D) -> bool {
        let channels = self.out_channels;
        let applies = self.groups == self.in_channels
            && self.in_channels == channels
            && self.kernel_h == 1
            && self.padding[0] == 0
            && input.h == 1
            && output.h == 1
            && input.c == channels
            && output.c == channels;
        if !applies {
            return false;
        }

        let kernel_w = self.kernel_w;
        let pad_left = self.padding[1] as isize;
        let in_w = input.w as isize;
        for c in 0..channels {
            let taps = &self.weights[c * kernel_w..(c + 1) * kernel_w];
            let bias = self.bias.as_ref().map_or(0.0, |b| b[c]);
            for x in 0..output.w {
                let origin = (x * self.stride_w) as isize - pad_left;
                let mut acc = 0.0;
                for (k, &wt) in taps.iter().enumerate() {
                    let in_x = origin + k as isize;
                    if (0..in_w).contains(&in_x) {
                        acc += input.get(0, in_x as usize, c) * wt;
                    }
                }
                output.set(0, x, c, acc + bias);
            }
        }
        true
    }

    /// General grouped convolution with zero padding, used when no fast path applies.
    fn forward_generic(&self, input: &Tensor3D, output: &mut Tensor3D) {
        output.zeros();

        let in_c_per_group = self.in_channels / self.groups;
        let out_c_per_group = self.out_channels / self.groups;
        let taps_per_channel = self.kernel_h * self.kernel_w;
        let pad_top = self.padding[0] as isize;
        let pad_left = self.padding[1] as isize;
        let (in_h, in_w) = (input.h as isize, input.w as isize);

        for g in 0..self.groups {
            for oc in 0..out_c_per_group {
                let out_ch = g * out_c_per_group + oc;
                let bias = self.bias.as_ref().map_or(0.0, |b| b[out_ch]);
                let w_base = out_ch * in_c_per_group * taps_per_channel;

                for y in 0..output.h {
                    let in_y_origin = (y * self.stride_h) as isize - pad_top;

                    for x in 0..output.w {
                        let in_x_origin = (x * self.stride_w) as isize - pad_left;
                        let mut sum = 0.0;

                        for ic in 0..in_c_per_group {
                            let in_ch = g * in_c_per_group + ic;
                            for ky in 0..self.kernel_h {
                                let in_y = in_y_origin + ky as isize;
                                if !(0..in_h).contains(&in_y) {
                                    continue;
                                }
                                // Safe slice of this kernel row; a malformed weight
                                // vector panics here instead of reading out of bounds.
                                let row_start =
                                    w_base + ic * taps_per_channel + ky * self.kernel_w;
                                let kernel_row = &self.weights[row_start..row_start + self.kernel_w];

                                for (kx, &wt) in kernel_row.iter().enumerate() {
                                    let in_x = in_x_origin + kx as isize;
                                    if (0..in_w).contains(&in_x) {
                                        sum += input.get(in_y as usize, in_x as usize, in_ch) * wt;
                                    }
                                }
                            }
                        }

                        output.set(y, x, out_ch, bias + sum);
                    }
                }
            }
        }
    }
}
718
719// MaxPool2D Layer
720struct MaxPool2dLayer {
721    kernel_h: usize,
722    kernel_w: usize,
723    stride_h: usize,
724    stride_w: usize,
725}
726
727impl MaxPool2dLayer {
728    fn forward_into(&self, input: &Tensor3D, output: &mut Tensor3D) {
729        let out_h = output.h;
730        let out_w = output.w;
731
732        // Optimization for MaxPool (1x3, s=1x2)
733        if self.kernel_h == 1 && self.kernel_w == 3 && self.stride_h == 1 && self.stride_w == 2 {
734            for c in 0..input.c {
735                // y is always 0
736                for x in 0..out_w {
737                    let in_x = x * 2;
738                    // We assume valid padding so in_x+2 is within bounds
739                    let v0 = input.get(0, in_x, c);
740                    let v1 = input.get(0, in_x + 1, c);
741                    let v2 = input.get(0, in_x + 2, c);
742
743                    let max_v = v0.max(v1).max(v2);
744                    output.set(0, x, c, max_v);
745                }
746            }
747            return;
748        }
749
750        for c in 0..input.c {
751            for y in 0..out_h {
752                for x in 0..out_w {
753                    let mut max_val = f32::NEG_INFINITY;
754
755                    for ky in 0..self.kernel_h {
756                        for kx in 0..self.kernel_w {
757                            let in_y = y * self.stride_h + ky;
758                            let in_x = x * self.stride_w + kx;
759                            // MaxPool usually doesn't have padding in this model (valid padding)
760                            // So we can skip bounds check if we trust output size calculation
761                            let val = input.get(in_y, in_x, c);
762                            if val > max_val {
763                                max_val = val;
764                            }
765                        }
766                    }
767                    output.set(y, x, c, max_val);
768                }
769            }
770        }
771    }
772}
773
/// Fully connected layer: `output[i] = dot(weights[i, ..], input) + bias[i]`.
struct LinearLayer {
    /// Row-major `[out_features, in_features]` weight matrix.
    weights: Vec<f32>,
    /// One bias per output feature.
    bias: Vec<f32>,
    in_features: usize,
    out_features: usize,
}

impl LinearLayer {
    /// Creates a layer whose weights and biases are all zero.
    ///
    /// NOTE(review): presumably the real parameters are written in by the model
    /// loader after construction — confirm against the caller.
    fn new(in_features: usize, out_features: usize) -> Self {
        let weights = vec![0.0; out_features * in_features];
        let bias = vec![0.0; out_features];
        Self {
            weights,
            bias,
            in_features,
            out_features,
        }
    }

    /// Computes `output = W · input + b` for a single input vector.
    ///
    /// # Panics
    /// Panics if `input.len() != in_features` or `output.len() != out_features`.
    fn forward(&self, input: &[f32], output: &mut [f32]) {
        assert_eq!(input.len(), self.in_features);
        assert_eq!(output.len(), self.out_features);

        let width = self.in_features;
        for (row, dst) in output.iter_mut().enumerate() {
            let start = row * width;
            let coeffs = &self.weights[start..start + width];
            let acc = coeffs
                .iter()
                .zip(input)
                .map(|(&w, &x)| w * x)
                .sum::<f32>();
            *dst = acc + self.bias[row];
        }
    }
}
815
816// LSTM Layer
817struct LstmLayer {
818    input_size: usize,
819    hidden_size: usize,
820    // Weights: 4 * hidden_size rows (i, f, g, o)
821    weight_ih: Vec<f32>, // [4 * hidden_size, input_size]
822    weight_hh: Vec<f32>, // [4 * hidden_size, hidden_size]
823    bias_ih: Vec<f32>,   // [4 * hidden_size]
824    bias_hh: Vec<f32>,   // [4 * hidden_size]
825
826    // Scratch buffers
827    gates_buffer: Vec<f32>, // [4 * hidden_size]
828}
829
830impl LstmLayer {
831    fn new(input_size: usize, hidden_size: usize) -> Self {
832        Self {
833            input_size,
834            hidden_size,
835            weight_ih: vec![0.0; 4 * hidden_size * input_size],
836            weight_hh: vec![0.0; 4 * hidden_size * hidden_size],
837            bias_ih: vec![0.0; 4 * hidden_size],
838            bias_hh: vec![0.0; 4 * hidden_size],
839            gates_buffer: vec![0.0; 4 * hidden_size],
840        }
841    }
842
843    fn forward_optimized(&mut self, input: &[f32], hidden: &mut [f32], cell: &mut [f32]) {
844        let h_size = self.hidden_size;
845
846        // 1. Compute W_ih * x + b_ih for all gates (i, f, g, o)
847        for i in 0..4 * h_size {
848            let w_start = i * self.input_size;
849            let w_row = &self.weight_ih[w_start..w_start + self.input_size];
850            let dot: f32 = w_row.iter().zip(input).map(|(&w, &x)| w * x).sum();
851            self.gates_buffer[i] = dot + self.bias_ih[i];
852        }
853
854        // 2. Compute W_hh * h + b_hh for all gates
855        // We can add directly to gates_buffer
856        for i in 0..4 * h_size {
857            let w_start = i * h_size;
858            let w_row = &self.weight_hh[w_start..w_start + h_size];
859            let dot: f32 = w_row.iter().zip(hidden.iter()).map(|(&w, &h)| w * h).sum();
860            self.gates_buffer[i] += dot + self.bias_hh[i];
861        }
862
863        // 3. Apply activations and update states
864        // ONNX Gates order: i, o, f, g (c)
865        for i in 0..h_size {
866            let i_gate = crate::media::vad::utils::sigmoid(self.gates_buffer[i]);
867            let o_gate = crate::media::vad::utils::sigmoid(self.gates_buffer[i + h_size]);
868            let f_gate = crate::media::vad::utils::sigmoid(self.gates_buffer[i + 2 * h_size]);
869            let g_gate = crate::media::vad::utils::tanh(self.gates_buffer[i + 3 * h_size]);
870
871            // c_t = f_t * c_{t-1} + i_t * g_t
872            cell[i] = f_gate * cell[i] + i_gate * g_gate;
873
874            // h_t = o_t * tanh(c_t)
875            hidden[i] = o_gate * crate::media::vad::utils::tanh(cell[i]);
876        }
877    }
878}
879
880pub struct TinyTen {
881    config: VADOption,
882    buffer: PcmBuf,
883    last_timestamp: u64,
884    chunk_size: usize,
885
886    feature_extractor: TenFeatureExtractor,
887    feature_buffer: ndarray::Array2<f32>,
888
889    // Model Layers
890    // Block 1
891    conv1_dw: Conv2dLayer,
892    conv1_pw: Conv2dLayer,
893    maxpool: MaxPool2dLayer,
894
895    // Block 2
896    conv2_dw: Conv2dLayer,
897    conv2_pw: Conv2dLayer,
898
899    // Block 3
900    conv3_dw: Conv2dLayer,
901    conv3_pw: Conv2dLayer,
902
903    lstm1: LstmLayer,
904    lstm2: LstmLayer,
905    dense1: LinearLayer,
906    dense2: LinearLayer,
907
908    // Model States
909    h1: Vec<f32>,
910    c1: Vec<f32>,
911    h2: Vec<f32>,
912    c2: Vec<f32>,
913
914    // Scratch Buffers (Pre-allocated)
915    t_input: Tensor3D,
916    t_conv1_dw: Tensor3D,
917    t_conv1_pw: Tensor3D,
918    t_maxpool: Tensor3D,
919    t_conv2_dw: Tensor3D,
920    t_conv2_pw: Tensor3D,
921    t_conv3_dw: Tensor3D,
922    t_conv3_pw: Tensor3D,
923
924    dense_input_buffer: Vec<f32>,
925    dense1_out_buffer: Vec<f32>,
926
927    last_score: Option<f32>,
928}
929
930const WEIGHTS_BYTES: &[u8] = include_bytes!("tiny_tenvad.bin");
931
932impl TinyTen {
933    pub fn new(config: VADOption) -> Result<Self> {
934        if config.samplerate != 16000 {
935            return Err(anyhow::anyhow!("TinyVad only supports 16kHz audio"));
936        }
937
938        let feature_extractor = TenFeatureExtractor::new();
939        let feature_buffer = ndarray::Array2::<f32>::zeros((CONTEXT_WINDOW_LEN, FEATURE_LEN));
940
941        // Initialize layers
942        // Conv1: Input [1, 3, 41, 1]
943        // DW: 3x3, stride 1, pad 0. Out: [1, 1, 39, 1]
944        let conv1_dw = Conv2dLayer::new(1, 1, 3, 3, 1, 1, [0, 0, 0, 0], 1);
945        // PW: 1x1, stride 1, pad 0. Out: [1, 1, 39, 16]
946        let conv1_pw = Conv2dLayer::new(1, 16, 1, 1, 1, 1, [0, 0, 0, 0], 1);
947
948        // MaxPool: 1x3, stride 1x2. Out: [1, 1, 19, 16]
949        let maxpool = MaxPool2dLayer {
950            kernel_h: 1,
951            kernel_w: 3,
952            stride_h: 1,
953            stride_w: 2,
954        };
955
956        // Conv2: Input [1, 1, 19, 16]
957        // DW: 1x3, stride 2x2, pad [0, 1, 0, 1]. Out: [1, 1, 10, 16]
958        let conv2_dw = Conv2dLayer::new(16, 16, 1, 3, 2, 2, [0, 1, 0, 1], 16);
959        // PW: 1x1, stride 1, pad 0. Out: [1, 1, 10, 16]
960        let conv2_pw = Conv2dLayer::new(16, 16, 1, 1, 1, 1, [0, 0, 0, 0], 1);
961
962        // Conv3: Input [1, 1, 10, 16]
963        // DW: 1x3, stride 2x2, pad [0, 0, 0, 1]. Out: [1, 1, 5, 16]
964        let conv3_dw = Conv2dLayer::new(16, 16, 1, 3, 2, 2, [0, 0, 0, 1], 16);
965        // PW: 1x1, stride 1, pad 0. Out: [1, 1, 5, 16]
966        let conv3_pw = Conv2dLayer::new(16, 16, 1, 1, 1, 1, [0, 0, 0, 0], 1);
967
968        // LSTM Input size: 5 * 16 = 80.
969        let lstm1 = LstmLayer::new(80, HIDDEN_SIZE);
970        let lstm2 = LstmLayer::new(HIDDEN_SIZE, HIDDEN_SIZE);
971
972        let dense1 = LinearLayer::new(HIDDEN_SIZE * 2, 32);
973        let dense2 = LinearLayer::new(32, 1);
974
975        // Pre-allocate scratch buffers
976        let t_input = Tensor3D::new(CONTEXT_WINDOW_LEN, FEATURE_LEN, 1);
977        let t_conv1_dw = Tensor3D::new(1, 39, 1);
978        let t_conv1_pw = Tensor3D::new(1, 39, 16);
979        let t_maxpool = Tensor3D::new(1, 19, 16);
980        let t_conv2_dw = Tensor3D::new(1, 10, 16);
981        let t_conv2_pw = Tensor3D::new(1, 10, 16);
982        let t_conv3_dw = Tensor3D::new(1, 5, 16);
983        let t_conv3_pw = Tensor3D::new(1, 5, 16);
984
985        let dense_input_buffer = vec![0.0; HIDDEN_SIZE * 2];
986        let dense1_out_buffer = vec![0.0; 32];
987
988        let mut vad = Self {
989            config,
990            buffer: Vec::new(),
991            chunk_size: HOP_SIZE,
992            last_timestamp: 0,
993            feature_extractor,
994            feature_buffer,
995            conv1_dw,
996            conv1_pw,
997            maxpool,
998            conv2_dw,
999            conv2_pw,
1000            conv3_dw,
1001            conv3_pw,
1002            lstm1,
1003            lstm2,
1004            dense1,
1005            dense2,
1006            h1: vec![0.0; HIDDEN_SIZE],
1007            c1: vec![0.0; HIDDEN_SIZE],
1008            h2: vec![0.0; HIDDEN_SIZE],
1009            c2: vec![0.0; HIDDEN_SIZE],
1010            t_input,
1011            t_conv1_dw,
1012            t_conv1_pw,
1013            t_maxpool,
1014            t_conv2_dw,
1015            t_conv2_pw,
1016            t_conv3_dw,
1017            t_conv3_pw,
1018            dense_input_buffer,
1019            dense1_out_buffer,
1020            last_score: None,
1021        };
1022
1023        vad.load_weights_from_bytes(WEIGHTS_BYTES)?;
1024        Ok(vad)
1025    }
1026
1027    pub fn predict(&mut self, samples: &[i16]) -> f32 {
1028        // 1. Extract features
1029        let features = self.feature_extractor.extract_features(samples);
1030
1031        // 2. Update context window
1032        for i in 0..CONTEXT_WINDOW_LEN - 1 {
1033            for j in 0..FEATURE_LEN {
1034                self.feature_buffer[[i, j]] = self.feature_buffer[[i + 1, j]];
1035            }
1036        }
1037        for j in 0..FEATURE_LEN {
1038            self.feature_buffer[[CONTEXT_WINDOW_LEN - 1, j]] = features[j];
1039        }
1040
1041        // 3. Prepare Input Tensor [1, 3, 41, 1]
1042        // H=3 (Time), W=41 (Freq), C=1
1043        // Reuse t_input
1044        for i in 0..CONTEXT_WINDOW_LEN {
1045            for j in 0..FEATURE_LEN {
1046                self.t_input.set(i, j, 0, self.feature_buffer[[i, j]]);
1047            }
1048        }
1049
1050        // 4. Forward Pass
1051        // Block 1
1052        self.conv1_dw
1053            .forward_into(&self.t_input, &mut self.t_conv1_dw);
1054        self.conv1_pw
1055            .forward_into(&self.t_conv1_dw, &mut self.t_conv1_pw);
1056
1057        // Apply Relu
1058        for val in self.t_conv1_pw.data.iter_mut() {
1059            *val = val.max(0.0);
1060        }
1061
1062        self.maxpool
1063            .forward_into(&self.t_conv1_pw, &mut self.t_maxpool);
1064
1065        // Block 2
1066        self.conv2_dw
1067            .forward_into(&self.t_maxpool, &mut self.t_conv2_dw);
1068        self.conv2_pw
1069            .forward_into(&self.t_conv2_dw, &mut self.t_conv2_pw);
1070
1071        for val in self.t_conv2_pw.data.iter_mut() {
1072            *val = val.max(0.0);
1073        }
1074
1075        // Block 3
1076        self.conv3_dw
1077            .forward_into(&self.t_conv2_pw, &mut self.t_conv3_dw);
1078        self.conv3_pw
1079            .forward_into(&self.t_conv3_dw, &mut self.t_conv3_pw);
1080
1081        for val in self.t_conv3_pw.data.iter_mut() {
1082            *val = val.max(0.0);
1083        }
1084
1085        // Flatten for LSTM
1086        // x shape should be [1, 5, 16] -> 80 elements
1087        let lstm_input = &self.t_conv3_pw.data;
1088
1089        // LSTM 1
1090        self.lstm1
1091            .forward_optimized(lstm_input, &mut self.h1, &mut self.c1);
1092
1093        // LSTM 2
1094        self.lstm2
1095            .forward_optimized(&self.h1, &mut self.h2, &mut self.c2);
1096
1097        // Concat h2, h1 (Graph says concat_1 inputs: lstm2, lstm1)
1098        // dense_input_buffer is [h2, h1]
1099        let h_size = HIDDEN_SIZE;
1100        self.dense_input_buffer[0..h_size].copy_from_slice(&self.h2);
1101        self.dense_input_buffer[h_size..2 * h_size].copy_from_slice(&self.h1);
1102
1103        // Dense 1
1104        self.dense1
1105            .forward(&self.dense_input_buffer, &mut self.dense1_out_buffer);
1106        // Relu
1107        for val in self.dense1_out_buffer.iter_mut() {
1108            *val = val.max(0.0);
1109        }
1110
1111        // Dense 2
1112        let mut output = [0.0; 1];
1113        self.dense2.forward(&self.dense1_out_buffer, &mut output);
1114
1115        let score = 1.0 / (1.0 + (-output[0]).exp()); // Sigmoid
1116        self.last_score = Some(score);
1117
1118        score
1119    }
1120
1121    fn load_weights_from_bytes(&mut self, bytes: &[u8]) -> Result<()> {
1122        let mut offset = 0;
1123
1124        // Helper to read u32
1125        let read_u32 = |offset: &mut usize, buf: &[u8]| -> u32 {
1126            let val = u32::from_le_bytes(buf[*offset..*offset + 4].try_into().unwrap());
1127            *offset += 4;
1128            val
1129        };
1130
1131        let num_tensors = read_u32(&mut offset, bytes);
1132
1133        let mut weights = std::collections::HashMap::new();
1134
1135        for _ in 0..num_tensors {
1136            let name_len = read_u32(&mut offset, bytes) as usize;
1137            let name_bytes = &bytes[offset..offset + name_len];
1138            let name = std::str::from_utf8(name_bytes)?.to_string();
1139            offset += name_len;
1140
1141            let shape_len = read_u32(&mut offset, bytes) as usize;
1142            let mut shape = Vec::new();
1143            for _ in 0..shape_len {
1144                shape.push(read_u32(&mut offset, bytes));
1145            }
1146
1147            let data_len = read_u32(&mut offset, bytes) as usize;
1148            let data_bytes = &bytes[offset..offset + data_len];
1149            let floats: Vec<f32> = data_bytes
1150                .chunks_exact(4)
1151                .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
1152                .collect();
1153            offset += data_len;
1154
1155            weights.insert(name, (shape, floats));
1156        }
1157
1158        // Assign weights
1159        if let Some(w) = weights.get("conv1_dw_weight") {
1160            self.conv1_dw.weights = w.1.clone();
1161        }
1162        if let Some(w) = weights.get("conv1_pw_weight") {
1163            self.conv1_pw.weights = w.1.clone();
1164        }
1165        if let Some(w) = weights.get("conv1_bias") {
1166            self.conv1_pw.bias = Some(w.1.clone());
1167        }
1168
1169        if let Some(w) = weights.get("conv2_dw_weight") {
1170            self.conv2_dw.weights = w.1.clone();
1171        }
1172        if let Some(w) = weights.get("conv2_pw_weight") {
1173            self.conv2_pw.weights = w.1.clone();
1174        }
1175        if let Some(w) = weights.get("conv2_bias") {
1176            self.conv2_pw.bias = Some(w.1.clone());
1177        }
1178
1179        if let Some(w) = weights.get("conv3_dw_weight") {
1180            self.conv3_dw.weights = w.1.clone();
1181        }
1182        if let Some(w) = weights.get("conv3_pw_weight") {
1183            self.conv3_pw.weights = w.1.clone();
1184        }
1185        if let Some(w) = weights.get("conv3_bias") {
1186            self.conv3_pw.bias = Some(w.1.clone());
1187        }
1188
1189        if let Some(w) = weights.get("lstm1_w_ih") {
1190            self.lstm1.weight_ih = w.1.clone();
1191        }
1192        if let Some(w) = weights.get("lstm1_w_hh") {
1193            self.lstm1.weight_hh = w.1.clone();
1194        }
1195        if let Some(w) = weights.get("lstm1_bias") {
1196            // Split bias into ih and hh if needed, or just use as is.
1197            // ONNX LSTM bias is [8*H]. Our LstmLayer expects bias_ih [4*H] and bias_hh [4*H].
1198            // Usually first half is W_b, second half is R_b.
1199            let b = &w.1;
1200            if b.len() == 8 * HIDDEN_SIZE {
1201                self.lstm1.bias_ih = b[0..4 * HIDDEN_SIZE].to_vec();
1202                self.lstm1.bias_hh = b[4 * HIDDEN_SIZE..].to_vec();
1203            }
1204        }
1205
1206        if let Some(w) = weights.get("lstm2_w_ih") {
1207            self.lstm2.weight_ih = w.1.clone();
1208        }
1209        if let Some(w) = weights.get("lstm2_w_hh") {
1210            self.lstm2.weight_hh = w.1.clone();
1211        }
1212        if let Some(w) = weights.get("lstm2_bias") {
1213            let b = &w.1;
1214            if b.len() == 8 * HIDDEN_SIZE {
1215                self.lstm2.bias_ih = b[0..4 * HIDDEN_SIZE].to_vec();
1216                self.lstm2.bias_hh = b[4 * HIDDEN_SIZE..].to_vec();
1217            }
1218        }
1219
1220        if let Some(w) = weights.get("dense1_weight") {
1221            self.dense1.weights = w.1.clone();
1222        }
1223        if let Some(w) = weights.get("dense1_bias") {
1224            self.dense1.bias = w.1.clone();
1225        }
1226
1227        if let Some(w) = weights.get("dense2_weight") {
1228            self.dense2.weights = w.1.clone();
1229        }
1230        if let Some(w) = weights.get("dense2_bias") {
1231            self.dense2.bias = w.1.clone();
1232        }
1233        Ok(())
1234    }
1235}
1236
1237impl VadEngine for TinyTen {
1238    fn process(&mut self, frame: &mut AudioFrame) -> Option<(bool, u64)> {
1239        let samples = match &frame.samples {
1240            Samples::PCM { samples } => samples,
1241            _ => return Some((false, frame.timestamp)),
1242        };
1243
1244        self.buffer.extend_from_slice(samples);
1245
1246        if self.buffer.len() >= self.chunk_size {
1247            let chunk: Vec<i16> = self.buffer.drain(..self.chunk_size).collect();
1248            let score = self.predict(&chunk);
1249
1250            let is_voice = score > self.config.voice_threshold;
1251            let chunk_duration_ms = (self.chunk_size as u64 * 1000) / (frame.sample_rate as u64);
1252
1253            if self.last_timestamp == 0 {
1254                self.last_timestamp = frame.timestamp;
1255            }
1256
1257            let chunk_timestamp = self.last_timestamp;
1258            self.last_timestamp += chunk_duration_ms;
1259
1260            return Some((is_voice, chunk_timestamp));
1261        }
1262
1263        None
1264    }
1265}