use super::{VADOption, VadEngine};
use crate::media::{AudioFrame, PcmBuf, Samples};
use anyhow::Result;
use realfft::{RealFftPlanner, RealToComplex};
use std::sync::Arc;

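// Framing and feature parameters: 16 kHz input, 256-sample (16 ms) hops, a 1024-point
// FFT over a 768-sample analysis window, 40 mel bands plus one pitch slot per frame
// (FEATURE_LEN = 41), and a 3-frame context window fed to the network.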
const SAMPLE_RATE: u32 = 16000;
const HOP_SIZE: usize = 256;
const FFT_SIZE: usize = 1024;
const WINDOW_SIZE: usize = 768;
const MEL_FILTER_BANK_NUM: usize = 40;
const FEATURE_LEN: usize = 41;
const CONTEXT_WINDOW_LEN: usize = 3;
const HIDDEN_SIZE: usize = 64;
const EPS: f32 = 1e-20;
const PRE_EMPHASIS_COEFF: f32 = 0.97;

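// Per-feature normalization statistics (mean and standard deviation) applied to each
// 41-dimensional feature vector; the final slot corresponds to the pitch feature.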
const FEATURE_MEANS: [f32; FEATURE_LEN] = [
    -8.198_236,
    -6.265_716_6,
    -5.483_818_5,
    -4.758_691_3,
    -4.417_089,
    -4.142_893,
    -3.912_850_4,
    -3.845_928,
    -3.657_090_4,
    -3.723_418_7,
    -3.876_134_2,
    -3.843_891,
    -3.690_405_1,
    -3.756_065_8,
    -3.698_696_1,
    -3.650_463,
    -3.700_468_8,
    -3.567_321_3,
    -3.498_900_2,
    -3.477_807,
    -3.458_816,
    -3.444_923_9,
    -3.401_328_6,
    -3.306_261_3,
    -3.278_556_8,
    -3.233_250_9,
    -3.198_616,
    -3.204_526_4,
    -3.208_798_6,
    -3.257_838,
    -3.381_376_7,
    -3.534_021_4,
    -3.640_868,
    -3.726_858_9,
    -3.773_731,
    -3.804_667_2,
    -3.832_901,
    -3.871_120_5,
    -3.990_593,
    -4.480_289_5,
    9.235_69e1,
];

const FEATURE_STDS: [f32; FEATURE_LEN] = [
    5.166_064,
    4.977_21,
    4.698_896,
    4.630_621_4,
    4.634_348,
    4.641_156,
    4.640_676_5,
    4.666_367,
    4.650_534_6,
    4.640_021,
    4.637_4,
    4.620_099,
    4.596_316_3,
    4.562_655,
    4.554_36,
    4.566_910_7,
    4.562_49,
    4.562_413,
    4.585_299_5,
    4.600_179_7,
    4.592_846,
    4.585_923,
    4.583_496_6,
    4.626_093,
    4.626_958,
    4.626_289_4,
    4.637_006,
    4.683_016,
    4.726_814,
    4.734_29,
    4.753_227,
    4.849_723,
    4.869_435,
    4.884_483,
    4.921_327,
    4.959_212_3,
    4.996_619,
    5.044_823_6,
    5.072_217,
    5.096_439_4,
    1.152_136_9e2,
];

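/// Streaming log-mel feature extractor: pre-emphasis, Hann windowing, real FFT,
/// triangular mel filter bank, log compression, and per-feature mean/std normalization.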
pub struct TenFeatureExtractor {
    pre_emphasis_prev: f32,
    mel_filters: ndarray::Array2<f32>,
    mel_filter_ranges: Vec<(usize, usize)>,
    window: Vec<f32>,
    rfft: Arc<dyn RealToComplex<f32>>,
    fft_scratch: Vec<realfft::num_complex::Complex<f32>>,
    fft_output: Vec<realfft::num_complex::Complex<f32>>,
    fft_input: Vec<f32>,
    power_spectrum: Vec<f32>,
    inv_stds: Vec<f32>,
}

impl TenFeatureExtractor {
    pub fn new() -> Self {
        let (mel_filters, mel_filter_ranges) = Self::generate_mel_filters();

        let window = super::utils::generate_hann_window(WINDOW_SIZE, false);

        let mut planner = RealFftPlanner::<f32>::new();
        let rfft = planner.plan_fft_forward(FFT_SIZE);
        let fft_scratch = rfft.make_scratch_vec();
        let fft_output = rfft.make_output_vec();
        let fft_input = rfft.make_input_vec();
        let power_spectrum = vec![0.0; FFT_SIZE / 2 + 1];

        let inv_stds: Vec<f32> = FEATURE_STDS.iter().map(|&std| 1.0 / (std + EPS)).collect();

        Self {
            pre_emphasis_prev: 0.0,
            mel_filters,
            mel_filter_ranges,
            window,
            rfft,
            fft_scratch,
            fft_output,
            fft_input,
            power_spectrum,
            inv_stds,
        }
    }

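    /// Builds MEL_FILTER_BANK_NUM triangular mel filters spanning 0-8000 Hz over the
    /// FFT_SIZE / 2 + 1 spectrum bins, returning the dense filter matrix together with
    /// each filter's non-zero bin range so the dot product can skip zero weights.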
    fn generate_mel_filters() -> (ndarray::Array2<f32>, Vec<(usize, usize)>) {
        let n_bins = FFT_SIZE / 2 + 1;

        let low_mel = 2595.0_f32 * (1.0_f32 + 0.0_f32 / 700.0_f32).log10();
        let high_mel = 2595.0_f32 * (1.0_f32 + 8000.0_f32 / 700.0_f32).log10();

        let mut mel_points = Vec::new();
        for i in 0..=MEL_FILTER_BANK_NUM + 1 {
            let mel = low_mel + (high_mel - low_mel) * i as f32 / (MEL_FILTER_BANK_NUM + 1) as f32;
            mel_points.push(mel);
        }

        let mut hz_points = Vec::new();
        for mel in mel_points {
            let hz = 700.0_f32 * (10.0_f32.powf(mel / 2595.0_f32) - 1.0_f32);
            hz_points.push(hz);
        }

        let mut bin_points = Vec::new();
        for hz in hz_points {
            let bin = ((FFT_SIZE + 1) as f32 * hz / SAMPLE_RATE as f32).floor() as usize;
            bin_points.push(bin);
        }

        let mut mel_filters = ndarray::Array2::<f32>::zeros((MEL_FILTER_BANK_NUM, n_bins));
        let mut ranges = Vec::with_capacity(MEL_FILTER_BANK_NUM);

        for i in 0..MEL_FILTER_BANK_NUM {
            let start = bin_points[i];
            let end = bin_points[i + 2];
            ranges.push((start, end));

            for j in bin_points[i]..bin_points[i + 1] {
                if j < n_bins {
                    mel_filters[[i, j]] =
                        (j - bin_points[i]) as f32 / (bin_points[i + 1] - bin_points[i]) as f32;
                }
            }

            for j in bin_points[i + 1]..bin_points[i + 2] {
                if j < n_bins {
                    mel_filters[[i, j]] = (bin_points[i + 2] - j) as f32
                        / (bin_points[i + 2] - bin_points[i + 1]) as f32;
                }
            }
        }

        (mel_filters, ranges)
    }

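    /// Applies the pre-emphasis filter y[n] = x[n] - PRE_EMPHASIS_COEFF * x[n-1] across
    /// frame boundaries (`prev_state` carries the last sample of the previous frame) and
    /// scales i16 samples into [-1.0, 1.0).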
    fn pre_emphasis(prev_state: &mut f32, audio_frame: &[i16], output: &mut [f32]) {
        if !audio_frame.is_empty() {
            let inv_scale = 1.0 / 32768.0;
            let first_sample = audio_frame[0] as f32;
            output[0] = (first_sample - PRE_EMPHASIS_COEFF * *prev_state) * inv_scale;

            for (out, samples) in output[1..].iter_mut().zip(audio_frame.windows(2)) {
                let prev = samples[0] as f32;
                let curr = samples[1] as f32;
                *out = (curr - PRE_EMPHASIS_COEFF * prev) * inv_scale;
            }

            *prev_state = audio_frame[audio_frame.len() - 1] as f32;
        }
    }

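    /// Extracts one FEATURE_LEN-dimensional feature vector from a frame of PCM samples:
    /// pre-emphasis, Hann window, FFT power spectrum, log-mel energies, a pitch slot
    /// (left at zero in this implementation), then mean/std normalization.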
    pub fn extract_features(&mut self, audio_frame: &[i16]) -> ndarray::Array1<f32> {
        self.fft_input.fill(0.0);

        let copy_len = audio_frame.len().min(WINDOW_SIZE);
        Self::pre_emphasis(
            &mut self.pre_emphasis_prev,
            audio_frame,
            &mut self.fft_input[..copy_len],
        );

        for (i, sample) in self.fft_input.iter_mut().enumerate().take(copy_len) {
            *sample *= self.window[i];
        }

        self.rfft
            .process_with_scratch(
                &mut self.fft_input,
                &mut self.fft_output,
                &mut self.fft_scratch,
            )
            .unwrap();

        let n_bins = FFT_SIZE / 2 + 1;
        let scale = 1.0 / (32768.0 * 32768.0);

        for (pow, complex) in self.power_spectrum.iter_mut().zip(self.fft_output.iter()) {
            *pow = (complex.re * complex.re + complex.im * complex.im) * scale;
        }

        let mut mel_features = ndarray::Array1::<f32>::zeros(MEL_FILTER_BANK_NUM);

        for i in 0..MEL_FILTER_BANK_NUM {
            let (start, end) = self.mel_filter_ranges[i];
            let valid_end = end.min(n_bins);

            let mut sum = 0.0;
            if start < valid_end {
                let filter_row = self.mel_filters.row(i);
                if let Some(filter_slice) = filter_row.as_slice() {
                    let filter_sub = &filter_slice[start..valid_end];
                    let power_sub = &self.power_spectrum[start..valid_end];

                    sum = super::simd::dot_product(filter_sub, power_sub);
                } else {
                    for j in start..valid_end {
                        sum += self.mel_filters[[i, j]] * self.power_spectrum[j];
                    }
                }
            }
            mel_features[i] = (sum + EPS).ln();
        }

        let pitch_freq = 0.0;

        let mut features = ndarray::Array1::<f32>::zeros(FEATURE_LEN);
        features
            .slice_mut(ndarray::s![..MEL_FILTER_BANK_NUM])
            .assign(&mel_features);
        features[MEL_FILTER_BANK_NUM] = pitch_freq;

        for (feat, (&mean, &inv_std)) in features
            .iter_mut()
            .zip(FEATURE_MEANS.iter().zip(self.inv_stds.iter()))
        {
            *feat = (*feat - mean) * inv_std;
        }

        features
    }
}

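/// Minimal tensor in HWC (height, width, channel) layout backed by a flat Vec<f32>.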
#[derive(Clone, Debug)]
struct Tensor3D {
    data: Vec<f32>,
    h: usize,
    w: usize,
    c: usize,
}

impl Tensor3D {
    fn new(h: usize, w: usize, c: usize) -> Self {
        Self {
            data: vec![0.0; h * w * c],
            h,
            w,
            c,
        }
    }

    fn zeros(&mut self) {
        self.data.fill(0.0);
    }

    #[inline(always)]
    fn get(&self, y: usize, x: usize, ch: usize) -> f32 {
        self.data[y * self.w * self.c + x * self.c + ch]
    }

    #[inline(always)]
    fn set(&mut self, y: usize, x: usize, ch: usize, val: f32) {
        self.data[y * self.w * self.c + x * self.c + ch] = val;
    }
}

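/// 2-D convolution with optional bias and grouped channels. The 4-element padding array
/// is used as [top, left, bottom, right]; the generic path only consumes the top/left
/// offsets because output dimensions are fixed by the caller-supplied tensor.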
struct Conv2dLayer {
    weights: Vec<f32>,
    bias: Option<Vec<f32>>,
    in_channels: usize,
    out_channels: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    padding: [usize; 4],
    groups: usize,
}

impl Conv2dLayer {
    fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
        padding: [usize; 4],
        groups: usize,
    ) -> Self {
        Self {
            weights: vec![0.0; out_channels * (in_channels / groups) * kernel_h * kernel_w],
            bias: None,
            in_channels,
            out_channels,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            padding,
            groups,
        }
    }

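    /// Writes the convolution of `input` into the preallocated `output` tensor. The
    /// shape checks below select hand-unrolled fast paths for specific layer
    /// configurations; anything else falls through to the generic grouped convolution.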
    fn forward_into(&self, input: &Tensor3D, output: &mut Tensor3D) {
        let out_h = output.h;
        let out_w = output.w;

        if self.in_channels == 1
            && self.out_channels == 1
            && self.kernel_h == 3
            && self.kernel_w == 3
            && self.stride_h == 1
            && self.stride_w == 1
            && self.padding == [0, 0, 0, 0]
        {
            let bias = self.bias.as_ref().map(|b| b[0]).unwrap_or(0.0);
            let w = &self.weights;

            for x in 0..out_w {
                let mut sum = bias;

                sum += input.get(0, x, 0) * w[0];
                sum += input.get(0, x + 1, 0) * w[1];
                sum += input.get(0, x + 2, 0) * w[2];

                sum += input.get(1, x, 0) * w[3];
                sum += input.get(1, x + 1, 0) * w[4];
                sum += input.get(1, x + 2, 0) * w[5];

                sum += input.get(2, x, 0) * w[6];
                sum += input.get(2, x + 1, 0) * w[7];
                sum += input.get(2, x + 2, 0) * w[8];

                output.set(0, x, 0, sum);
            }
            return;
        }

        if self.in_channels == 1
            && self.out_channels == 16
            && self.kernel_h == 1
            && self.kernel_w == 1
            && self.stride_h == 1
            && self.stride_w == 1
        {
            let w = &self.weights;
            let b = self.bias.as_ref();

            for x in 0..out_w {
                let val = input.get(0, x, 0);

                for oc in 0..16 {
                    let bias = if let Some(bias_vec) = b {
                        bias_vec[oc]
                    } else {
                        0.0
                    };
                    let res = val * w[oc] + bias;
                    output.set(0, x, oc, res);
                }
            }
            return;
        }

        if self.groups == 16
            && self.in_channels == 16
            && self.out_channels == 16
            && self.kernel_h == 1
            && self.kernel_w == 3
            && self.stride_w == 2
            && self.padding == [0, 1, 0, 1]
        {
            let w = &self.weights;
            let b = self.bias.as_ref();

            for c in 0..16 {
                let w_offset = c * 3;
                let w0 = w[w_offset];
                let w1 = w[w_offset + 1];
                let w2 = w[w_offset + 2];
                let bias = if let Some(bias_vec) = b {
                    bias_vec[c]
                } else {
                    0.0
                };

                let val0 = input.get(0, 0, c);
                let val1 = input.get(0, 1, c);
                let sum0 = val0 * w1 + val1 * w2 + bias;
                output.set(0, 0, c, sum0);

                for x in 1..9 {
                    let in_x_origin = x * 2 - 1;
                    let v0 = input.get(0, in_x_origin, c);
                    let v1 = input.get(0, in_x_origin + 1, c);
                    let v2 = input.get(0, in_x_origin + 2, c);
                    let sum = v0 * w0 + v1 * w1 + v2 * w2 + bias;
                    output.set(0, x, c, sum);
                }

                let v0 = input.get(0, 17, c);
                let v1 = input.get(0, 18, c);
                let sum9 = v0 * w0 + v1 * w1 + bias;
                output.set(0, 9, c, sum9);
            }
            return;
        }

        if self.in_channels == 16
            && self.out_channels == 16
            && self.kernel_h == 1
            && self.kernel_w == 1
            && self.stride_h == 1
            && self.stride_w == 1
            && self.groups == 1
        {
            let w = &self.weights;
            let b = self.bias.as_ref();

            for x in 0..out_w {
                let mut in_vals = [0.0; 16];
                for ic in 0..16 {
                    in_vals[ic] = input.get(0, x, ic);
                }

                for oc in 0..16 {
                    let mut sum = if let Some(bias_vec) = b {
                        bias_vec[oc]
                    } else {
                        0.0
                    };
                    let w_offset = oc * 16;

                    for ic in 0..16 {
                        sum += in_vals[ic] * w[w_offset + ic];
                    }
                    output.set(0, x, oc, sum);
                }
            }
            return;
        }

        if self.groups == 16
            && self.in_channels == 16
            && self.out_channels == 16
            && self.kernel_h == 1
            && self.kernel_w == 3
            && self.stride_w == 2
            && self.padding == [0, 1, 0, 1]
            && out_w == 5
        {
            let w = &self.weights;
            let b = self.bias.as_ref();

            for c in 0..16 {
                let w_offset = c * 3;
                let w0 = w[w_offset];
                let w1 = w[w_offset + 1];
                let w2 = w[w_offset + 2];
                let bias = if let Some(bias_vec) = b {
                    bias_vec[c]
                } else {
                    0.0
                };

                let val0 = input.get(0, 0, c);
                let val1 = input.get(0, 1, c);
                let sum0 = val0 * w1 + val1 * w2 + bias;
                output.set(0, 0, c, sum0);

                for x in 1..5 {
                    let in_x_origin = x * 2 - 1;
                    let v0 = input.get(0, in_x_origin, c);
                    let v1 = input.get(0, in_x_origin + 1, c);
                    let v2 = input.get(0, in_x_origin + 2, c);
                    let sum = v0 * w0 + v1 * w1 + v2 * w2 + bias;
                    output.set(0, x, c, sum);
                }
            }
            return;
        }

        if self.in_channels == 16
            && self.out_channels == 32
            && self.kernel_h == 1
            && self.kernel_w == 1
            && self.stride_h == 1
            && self.stride_w == 1
            && self.groups == 1
        {
            let w = &self.weights;
            let b = self.bias.as_ref();

            for x in 0..out_w {
                let mut in_vals = [0.0; 16];
                for ic in 0..16 {
                    in_vals[ic] = input.get(0, x, ic);
                }

                for oc in 0..32 {
                    let mut sum = if let Some(bias_vec) = b {
                        bias_vec[oc]
                    } else {
                        0.0
                    };
                    let w_offset = oc * 16;

                    for ic in 0..16 {
                        sum += in_vals[ic] * w[w_offset + ic];
                    }
                    output.set(0, x, oc, sum);
                }
            }
            return;
        }

        output.zeros();

        let in_c_per_group = self.in_channels / self.groups;
        let out_c_per_group = self.out_channels / self.groups;

        if let Some(b) = &self.bias {
            for g in 0..self.groups {
                for oc in 0..out_c_per_group {
                    let out_ch_idx = g * out_c_per_group + oc;
                    let bias_val = b[out_ch_idx];
                    for y in 0..out_h {
                        for x in 0..out_w {
                            output.set(y, x, out_ch_idx, bias_val);
                        }
                    }
                }
            }
        }

        for g in 0..self.groups {
            for oc in 0..out_c_per_group {
                let out_ch_idx = g * out_c_per_group + oc;

                let w_base = out_ch_idx * (in_c_per_group * self.kernel_h * self.kernel_w);

                for y in 0..out_h {
                    let in_y_origin = (y * self.stride_h) as isize - self.padding[0] as isize;

                    for x in 0..out_w {
                        let in_x_origin = (x * self.stride_w) as isize - self.padding[1] as isize;

                        let mut sum = 0.0;

                        for ic in 0..in_c_per_group {
                            let in_ch_idx = g * in_c_per_group + ic;
                            let w_ic_base = w_base + ic * (self.kernel_h * self.kernel_w);

                            for ky in 0..self.kernel_h {
                                let in_y = in_y_origin + ky as isize;
                                if in_y >= 0 && in_y < input.h as isize {
                                    let w_ky_base = w_ic_base + ky * self.kernel_w;

                                    for kx in 0..self.kernel_w {
                                        let in_x = in_x_origin + kx as isize;

                                        if in_x >= 0 && in_x < input.w as isize {
                                            let val =
                                                input.get(in_y as usize, in_x as usize, in_ch_idx);
                                            let w_idx = w_ky_base + kx;
                                            let w = unsafe { *self.weights.get_unchecked(w_idx) };
                                            sum += val * w;
                                        }
                                    }
                                }
                            }
                        }

                        let current = output.get(y, x, out_ch_idx);
                        output.set(y, x, out_ch_idx, current + sum);
                    }
                }
            }
        }
    }
}

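/// 2-D max pooling; a specialized path handles the 1x3 kernel with horizontal stride 2
/// used by this model, with a generic fallback for other shapes.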
struct MaxPool2dLayer {
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
}

impl MaxPool2dLayer {
    fn forward_into(&self, input: &Tensor3D, output: &mut Tensor3D) {
        let out_h = output.h;
        let out_w = output.w;

        if self.kernel_h == 1 && self.kernel_w == 3 && self.stride_h == 1 && self.stride_w == 2 {
            for c in 0..input.c {
                for x in 0..out_w {
                    let in_x = x * 2;
                    let v0 = input.get(0, in_x, c);
                    let v1 = input.get(0, in_x + 1, c);
                    let v2 = input.get(0, in_x + 2, c);

                    let max_v = v0.max(v1).max(v2);
                    output.set(0, x, c, max_v);
                }
            }
            return;
        }

        for c in 0..input.c {
            for y in 0..out_h {
                for x in 0..out_w {
                    let mut max_val = f32::NEG_INFINITY;

                    for ky in 0..self.kernel_h {
                        for kx in 0..self.kernel_w {
                            let in_y = y * self.stride_h + ky;
                            let in_x = x * self.stride_w + kx;
                            let val = input.get(in_y, in_x, c);
                            if val > max_val {
                                max_val = val;
                            }
                        }
                    }
                    output.set(y, x, c, max_val);
                }
            }
        }
    }
}

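/// Fully connected layer computing output = weights * input + bias (row-major weights).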
struct LinearLayer {
    weights: Vec<f32>,
    bias: Vec<f32>,
    in_features: usize,
    out_features: usize,
}

impl LinearLayer {
    fn new(in_features: usize, out_features: usize) -> Self {
        Self {
            weights: vec![0.0; out_features * in_features],
            bias: vec![0.0; out_features],
            in_features,
            out_features,
        }
    }

    fn forward(&self, input: &[f32], output: &mut [f32]) {
        assert_eq!(input.len(), self.in_features);
        assert_eq!(output.len(), self.out_features);

        for (i, out_val) in output.iter_mut().enumerate() {
            let weight_row_start = i * self.in_features;
            let weight_row = &self.weights[weight_row_start..weight_row_start + self.in_features];

            let dot_product: f32 = weight_row
                .iter()
                .zip(input.iter())
                .map(|(&w, &x)| w * x)
                .sum();

            *out_val = dot_product + self.bias[i];
        }
    }
}

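/// Single LSTM cell with packed gate weights. As consumed in `forward_optimized`, the
/// gate order within the packed buffers is [input, output, forget, cell-candidate].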
struct LstmLayer {
    input_size: usize,
    hidden_size: usize,
    weight_ih: Vec<f32>,
    weight_hh: Vec<f32>,
    bias_ih: Vec<f32>,
    bias_hh: Vec<f32>,
    gates_buffer: Vec<f32>,
}

impl LstmLayer {
    fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            input_size,
            hidden_size,
            weight_ih: vec![0.0; 4 * hidden_size * input_size],
            weight_hh: vec![0.0; 4 * hidden_size * hidden_size],
            bias_ih: vec![0.0; 4 * hidden_size],
            bias_hh: vec![0.0; 4 * hidden_size],
            gates_buffer: vec![0.0; 4 * hidden_size],
        }
    }

    fn forward_optimized(&mut self, input: &[f32], hidden: &mut [f32], cell: &mut [f32]) {
        let h_size = self.hidden_size;

        for i in 0..4 * h_size {
            let w_start = i * self.input_size;
            let w_row = &self.weight_ih[w_start..w_start + self.input_size];
            let dot: f32 = w_row.iter().zip(input).map(|(&w, &x)| w * x).sum();
            self.gates_buffer[i] = dot + self.bias_ih[i];
        }

        for i in 0..4 * h_size {
            let w_start = i * h_size;
            let w_row = &self.weight_hh[w_start..w_start + h_size];
            let dot: f32 = w_row.iter().zip(hidden.iter()).map(|(&w, &h)| w * h).sum();
            self.gates_buffer[i] += dot + self.bias_hh[i];
        }

        for i in 0..h_size {
            let i_gate = crate::media::vad::utils::sigmoid(self.gates_buffer[i]);
            let o_gate = crate::media::vad::utils::sigmoid(self.gates_buffer[i + h_size]);
            let f_gate = crate::media::vad::utils::sigmoid(self.gates_buffer[i + 2 * h_size]);
            let g_gate = crate::media::vad::utils::tanh(self.gates_buffer[i + 3 * h_size]);

            cell[i] = f_gate * cell[i] + i_gate * g_gate;

            hidden[i] = o_gate * crate::media::vad::utils::tanh(cell[i]);
        }
    }
}

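/// Lightweight voice activity detector: log-mel features over a 3-frame context, three
/// depthwise-separable convolution stages with a max-pool, two stacked LSTM cells, and
/// two dense layers producing a per-chunk speech probability. Intermediate network
/// tensors and buffers are preallocated and reused across calls.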
pub struct TinyTen {
    config: VADOption,
    buffer: PcmBuf,
    last_timestamp: u64,
    chunk_size: usize,

    feature_extractor: TenFeatureExtractor,
    feature_buffer: ndarray::Array2<f32>,

    conv1_dw: Conv2dLayer,
    conv1_pw: Conv2dLayer,
    maxpool: MaxPool2dLayer,

    conv2_dw: Conv2dLayer,
    conv2_pw: Conv2dLayer,

    conv3_dw: Conv2dLayer,
    conv3_pw: Conv2dLayer,

    lstm1: LstmLayer,
    lstm2: LstmLayer,
    dense1: LinearLayer,
    dense2: LinearLayer,

    h1: Vec<f32>,
    c1: Vec<f32>,
    h2: Vec<f32>,
    c2: Vec<f32>,

    t_input: Tensor3D,
    t_conv1_dw: Tensor3D,
    t_conv1_pw: Tensor3D,
    t_maxpool: Tensor3D,
    t_conv2_dw: Tensor3D,
    t_conv2_pw: Tensor3D,
    t_conv3_dw: Tensor3D,
    t_conv3_pw: Tensor3D,

    dense_input_buffer: Vec<f32>,
    dense1_out_buffer: Vec<f32>,

    last_score: Option<f32>,
}

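// Model weights embedded at compile time, serialized as a simple list of named tensors
// (name, shape, raw little-endian f32 data); see `load_weights_from_bytes`.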
const WEIGHTS_BYTES: &[u8] = include_bytes!("tiny_tenvad.bin");

impl TinyTen {
    pub fn new(config: VADOption) -> Result<Self> {
        if config.samplerate != 16000 {
            return Err(anyhow::anyhow!("TinyVad only supports 16kHz audio"));
        }

        let feature_extractor = TenFeatureExtractor::new();
        let feature_buffer = ndarray::Array2::<f32>::zeros((CONTEXT_WINDOW_LEN, FEATURE_LEN));

        let conv1_dw = Conv2dLayer::new(1, 1, 3, 3, 1, 1, [0, 0, 0, 0], 1);
        let conv1_pw = Conv2dLayer::new(1, 16, 1, 1, 1, 1, [0, 0, 0, 0], 1);

        let maxpool = MaxPool2dLayer {
            kernel_h: 1,
            kernel_w: 3,
            stride_h: 1,
            stride_w: 2,
        };

        let conv2_dw = Conv2dLayer::new(16, 16, 1, 3, 2, 2, [0, 1, 0, 1], 16);
        let conv2_pw = Conv2dLayer::new(16, 16, 1, 1, 1, 1, [0, 0, 0, 0], 1);

        let conv3_dw = Conv2dLayer::new(16, 16, 1, 3, 2, 2, [0, 0, 0, 1], 16);
        let conv3_pw = Conv2dLayer::new(16, 16, 1, 1, 1, 1, [0, 0, 0, 0], 1);

        let lstm1 = LstmLayer::new(80, HIDDEN_SIZE);
        let lstm2 = LstmLayer::new(HIDDEN_SIZE, HIDDEN_SIZE);

        let dense1 = LinearLayer::new(HIDDEN_SIZE * 2, 32);
        let dense2 = LinearLayer::new(32, 1);

        let t_input = Tensor3D::new(CONTEXT_WINDOW_LEN, FEATURE_LEN, 1);
        let t_conv1_dw = Tensor3D::new(1, 39, 1);
        let t_conv1_pw = Tensor3D::new(1, 39, 16);
        let t_maxpool = Tensor3D::new(1, 19, 16);
        let t_conv2_dw = Tensor3D::new(1, 10, 16);
        let t_conv2_pw = Tensor3D::new(1, 10, 16);
        let t_conv3_dw = Tensor3D::new(1, 5, 16);
        let t_conv3_pw = Tensor3D::new(1, 5, 16);

        let dense_input_buffer = vec![0.0; HIDDEN_SIZE * 2];
        let dense1_out_buffer = vec![0.0; 32];

        let mut vad = Self {
            config,
            buffer: Vec::new(),
            chunk_size: HOP_SIZE,
            last_timestamp: 0,
            feature_extractor,
            feature_buffer,
            conv1_dw,
            conv1_pw,
            maxpool,
            conv2_dw,
            conv2_pw,
            conv3_dw,
            conv3_pw,
            lstm1,
            lstm2,
            dense1,
            dense2,
            h1: vec![0.0; HIDDEN_SIZE],
            c1: vec![0.0; HIDDEN_SIZE],
            h2: vec![0.0; HIDDEN_SIZE],
            c2: vec![0.0; HIDDEN_SIZE],
            t_input,
            t_conv1_dw,
            t_conv1_pw,
            t_maxpool,
            t_conv2_dw,
            t_conv2_pw,
            t_conv3_dw,
            t_conv3_pw,
            dense_input_buffer,
            dense1_out_buffer,
            last_score: None,
        };

        vad.load_weights_from_bytes(WEIGHTS_BYTES)?;
        Ok(vad)
    }

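    /// Runs the network on one chunk of 16 kHz PCM samples (HOP_SIZE samples per call
    /// from `process`) and returns a speech probability in [0.0, 1.0].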
    pub fn predict(&mut self, samples: &[i16]) -> f32 {
        let features = self.feature_extractor.extract_features(samples);

        for i in 0..CONTEXT_WINDOW_LEN - 1 {
            for j in 0..FEATURE_LEN {
                self.feature_buffer[[i, j]] = self.feature_buffer[[i + 1, j]];
            }
        }
        for j in 0..FEATURE_LEN {
            self.feature_buffer[[CONTEXT_WINDOW_LEN - 1, j]] = features[j];
        }

        for i in 0..CONTEXT_WINDOW_LEN {
            for j in 0..FEATURE_LEN {
                self.t_input.set(i, j, 0, self.feature_buffer[[i, j]]);
            }
        }

        self.conv1_dw
            .forward_into(&self.t_input, &mut self.t_conv1_dw);
        self.conv1_pw
            .forward_into(&self.t_conv1_dw, &mut self.t_conv1_pw);

        for val in self.t_conv1_pw.data.iter_mut() {
            *val = val.max(0.0);
        }

        self.maxpool
            .forward_into(&self.t_conv1_pw, &mut self.t_maxpool);

        self.conv2_dw
            .forward_into(&self.t_maxpool, &mut self.t_conv2_dw);
        self.conv2_pw
            .forward_into(&self.t_conv2_dw, &mut self.t_conv2_pw);

        for val in self.t_conv2_pw.data.iter_mut() {
            *val = val.max(0.0);
        }

        self.conv3_dw
            .forward_into(&self.t_conv2_pw, &mut self.t_conv3_dw);
        self.conv3_pw
            .forward_into(&self.t_conv3_dw, &mut self.t_conv3_pw);

        for val in self.t_conv3_pw.data.iter_mut() {
            *val = val.max(0.0);
        }

        let lstm_input = &self.t_conv3_pw.data;

        self.lstm1
            .forward_optimized(lstm_input, &mut self.h1, &mut self.c1);

        self.lstm2
            .forward_optimized(&self.h1, &mut self.h2, &mut self.c2);

        let h_size = HIDDEN_SIZE;
        self.dense_input_buffer[0..h_size].copy_from_slice(&self.h2);
        self.dense_input_buffer[h_size..2 * h_size].copy_from_slice(&self.h1);

        self.dense1
            .forward(&self.dense_input_buffer, &mut self.dense1_out_buffer);
        for val in self.dense1_out_buffer.iter_mut() {
            *val = val.max(0.0);
        }

        let mut output = [0.0; 1];
        self.dense2.forward(&self.dense1_out_buffer, &mut output);

        let score = 1.0 / (1.0 + (-output[0]).exp());
        self.last_score = Some(score);

        score
    }

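    /// Parses the embedded weight blob: a u32 tensor count followed by, for each tensor,
    /// a length-prefixed name, a length-prefixed shape (u32 dims), and a byte-length-prefixed
    /// block of little-endian f32 values, which are then copied into the matching layers.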
    fn load_weights_from_bytes(&mut self, bytes: &[u8]) -> Result<()> {
        let mut offset = 0;

        let read_u32 = |offset: &mut usize, buf: &[u8]| -> u32 {
            let val = u32::from_le_bytes(buf[*offset..*offset + 4].try_into().unwrap());
            *offset += 4;
            val
        };

        let num_tensors = read_u32(&mut offset, bytes);

        let mut weights = std::collections::HashMap::new();

        for _ in 0..num_tensors {
            let name_len = read_u32(&mut offset, bytes) as usize;
            let name_bytes = &bytes[offset..offset + name_len];
            let name = std::str::from_utf8(name_bytes)?.to_string();
            offset += name_len;

            let shape_len = read_u32(&mut offset, bytes) as usize;
            let mut shape = Vec::new();
            for _ in 0..shape_len {
                shape.push(read_u32(&mut offset, bytes));
            }

            let data_len = read_u32(&mut offset, bytes) as usize;
            let data_bytes = &bytes[offset..offset + data_len];
            let floats: Vec<f32> = data_bytes
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
                .collect();
            offset += data_len;

            weights.insert(name, (shape, floats));
        }

        if let Some(w) = weights.get("conv1_dw_weight") {
            self.conv1_dw.weights = w.1.clone();
        }
        if let Some(w) = weights.get("conv1_pw_weight") {
            self.conv1_pw.weights = w.1.clone();
        }
        if let Some(w) = weights.get("conv1_bias") {
            self.conv1_pw.bias = Some(w.1.clone());
        }

        if let Some(w) = weights.get("conv2_dw_weight") {
            self.conv2_dw.weights = w.1.clone();
        }
        if let Some(w) = weights.get("conv2_pw_weight") {
            self.conv2_pw.weights = w.1.clone();
        }
        if let Some(w) = weights.get("conv2_bias") {
            self.conv2_pw.bias = Some(w.1.clone());
        }

        if let Some(w) = weights.get("conv3_dw_weight") {
            self.conv3_dw.weights = w.1.clone();
        }
        if let Some(w) = weights.get("conv3_pw_weight") {
            self.conv3_pw.weights = w.1.clone();
        }
        if let Some(w) = weights.get("conv3_bias") {
            self.conv3_pw.bias = Some(w.1.clone());
        }

        if let Some(w) = weights.get("lstm1_w_ih") {
            self.lstm1.weight_ih = w.1.clone();
        }
        if let Some(w) = weights.get("lstm1_w_hh") {
            self.lstm1.weight_hh = w.1.clone();
        }
        if let Some(w) = weights.get("lstm1_bias") {
            let b = &w.1;
            if b.len() == 8 * HIDDEN_SIZE {
                self.lstm1.bias_ih = b[0..4 * HIDDEN_SIZE].to_vec();
                self.lstm1.bias_hh = b[4 * HIDDEN_SIZE..].to_vec();
            }
        }

        if let Some(w) = weights.get("lstm2_w_ih") {
            self.lstm2.weight_ih = w.1.clone();
        }
        if let Some(w) = weights.get("lstm2_w_hh") {
            self.lstm2.weight_hh = w.1.clone();
        }
        if let Some(w) = weights.get("lstm2_bias") {
            let b = &w.1;
            if b.len() == 8 * HIDDEN_SIZE {
                self.lstm2.bias_ih = b[0..4 * HIDDEN_SIZE].to_vec();
                self.lstm2.bias_hh = b[4 * HIDDEN_SIZE..].to_vec();
            }
        }

        if let Some(w) = weights.get("dense1_weight") {
            self.dense1.weights = w.1.clone();
        }
        if let Some(w) = weights.get("dense1_bias") {
            self.dense1.bias = w.1.clone();
        }

        if let Some(w) = weights.get("dense2_weight") {
            self.dense2.weights = w.1.clone();
        }
        if let Some(w) = weights.get("dense2_bias") {
            self.dense2.bias = w.1.clone();
        }
        Ok(())
    }
}

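// Streaming entry point: PCM samples are buffered until a full HOP_SIZE chunk is
// available, then one speech/non-speech decision is emitted with the timestamp of the
// start of that chunk.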
impl VadEngine for TinyTen {
    fn process(&mut self, frame: &mut AudioFrame) -> Option<(bool, u64)> {
        let samples = match &frame.samples {
            Samples::PCM { samples } => samples,
            _ => return Some((false, frame.timestamp)),
        };

        self.buffer.extend_from_slice(samples);

        if self.buffer.len() >= self.chunk_size {
            let chunk: Vec<i16> = self.buffer.drain(..self.chunk_size).collect();
            let score = self.predict(&chunk);

            let is_voice = score > self.config.voice_threshold;
            let chunk_duration_ms = (self.chunk_size as u64 * 1000) / (frame.sample_rate as u64);

            if self.last_timestamp == 0 {
                self.last_timestamp = frame.timestamp;
            }

            let chunk_timestamp = self.last_timestamp;
            self.last_timestamp += chunk_duration_ms;

            return Some((is_voice, chunk_timestamp));
        }

        None
    }
}
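
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal sanity-check sketch (not part of the original test suite): exercises the
    // flat HWC indexing of Tensor3D and the LinearLayer matrix-vector product with
    // hand-picked weights.
    #[test]
    fn tensor3d_indexing_and_linear_forward() {
        let mut t = Tensor3D::new(2, 3, 4);
        t.set(1, 2, 3, 7.0);
        assert_eq!(t.get(1, 2, 3), 7.0);

        let mut layer = LinearLayer::new(2, 1);
        layer.weights = vec![1.0, 2.0];
        layer.bias = vec![0.5];
        let mut out = [0.0f32; 1];
        // 1.0 * 3.0 + 2.0 * 4.0 + 0.5 = 11.5
        layer.forward(&[3.0, 4.0], &mut out);
        assert!((out[0] - 11.5).abs() < 1e-6);
    }
}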