Skip to main content

ten_vad_rs/
lib.rs

1#![allow(clippy::excessive_precision)]
2
3mod biquad;
4mod buffer;
5mod error;
6mod pitch_est;
7
// Re-export the audio frame buffer and the error types as part of the public API
9pub use crate::buffer::AudioFrameBuffer;
10pub use crate::error::{TenVadError, TenVadResult};
11
12/// Target sample rate for TEN VAD (16kHz)
13pub const TARGET_SAMPLE_RATE: u32 = 16000;
14
15use ndarray::{Array1, Array2, ArrayView1, Axis, aview1};
16use ort::session::builder::{GraphOptimizationLevel, SessionBuilder};
17use ort::session::{SessionInputValue, SessionInputs};
18use ort::{session::Session, value::TensorRef};
19use pitch_est::PitchEstimator;
20use rustfft::{Fft, FftPlanner, num_complex::Complex32};
21use std::sync::Arc;
22
/// FFT length in samples; the windowed frame is zero-padded up to this size.
const FFT_SIZE: usize = 1024;
/// Number of samples in the sliding STFT analysis window.
const WINDOW_SIZE: usize = 768;
/// Number of triangular mel filters applied to the power spectrum.
const MEL_FILTER_BANK_NUM: usize = 40;
/// Length of one feature vector: the mel features plus one trailing pitch feature.
const FEATURE_LEN: usize = MEL_FILTER_BANK_NUM + 1;
/// Number of consecutive feature vectors presented to the model per inference.
const CONTEXT_WINDOW_LEN: usize = 3;
/// Width of each recurrent hidden-state tensor exchanged with the model.
const MODEL_HIDDEN_DIM: usize = 64;
/// Number of model inputs (and outputs): one feature/score tensor plus the hidden states.
const MODEL_IO_NUM: usize = 5;
/// Small offset used to keep logarithms and divisions well-defined.
const EPS: f32 = 1e-20;
/// Coefficient of the first-order pre-emphasis filter.
const PRE_EMPHASIS_COEFF: f32 = 0.97;
32
/// Per-feature means used for input normalization: 40 mel filter-bank values followed by the pitch feature (from coeff.h)
34#[rustfmt::skip]
35const FEATURE_MEANS: [f32; 41] = [
36    -8.198236465454e+00, -6.265716552734e+00, -5.483818531036e+00,
37    -4.758691310883e+00, -4.417088985443e+00, -4.142892837524e+00,
38    -3.912850379944e+00, -3.845927953720e+00, -3.657090425491e+00,
39    -3.723418712616e+00, -3.876134157181e+00, -3.843890905380e+00,
40    -3.690405130386e+00, -3.756065845490e+00, -3.698696136475e+00,
41    -3.650463104248e+00, -3.700468778610e+00, -3.567321300507e+00,
42    -3.498900175095e+00, -3.477807044983e+00, -3.458816051483e+00,
43    -3.444923877716e+00, -3.401328563690e+00, -3.306261301041e+00,
44    -3.278556823730e+00, -3.233250856400e+00, -3.198616027832e+00,
45    -3.204526424408e+00, -3.208798646927e+00, -3.257838010788e+00,
46    -3.381376743317e+00, -3.534021377563e+00, -3.640867948532e+00,
47    -3.726858854294e+00, -3.773730993271e+00, -3.804667234421e+00,
48    -3.832901000977e+00, -3.871120452881e+00, -3.990592956543e+00,
49    -4.480289459229e+00, 9.235690307617e+01
50];
51
/// Per-feature standard deviations used for input normalization: 40 mel filter-bank values followed by the pitch feature (from coeff.h)
53#[rustfmt::skip]
54const FEATURE_STDS: [f32; 41] = [
55    5.166063785553e+00, 4.977209568024e+00, 4.698895931244e+00,
56    4.630621433258e+00, 4.634347915649e+00, 4.641156196594e+00,
57    4.640676498413e+00, 4.666367053986e+00, 4.650534629822e+00,
58    4.640020847321e+00, 4.637400150299e+00, 4.620099067688e+00,
59    4.596316337585e+00, 4.562654972076e+00, 4.554360389709e+00,
60    4.566910743713e+00, 4.562489986420e+00, 4.562412738800e+00,
61    4.585299491882e+00, 4.600179672241e+00, 4.592845916748e+00,
62    4.585922718048e+00, 4.583496570587e+00, 4.626092910767e+00,
63    4.626957893372e+00, 4.626289367676e+00, 4.637005805969e+00,
64    4.683015823364e+00, 4.726813793182e+00, 4.734289646149e+00,
65    4.753227233887e+00, 4.849722862244e+00, 4.869434833527e+00,
66    4.884482860565e+00, 4.921327114105e+00, 4.959212303162e+00,
67    4.996619224548e+00, 5.044823646545e+00, 5.072216987610e+00,
68    5.096439361572e+00, 1.152136917114e+02
69];
70
/// Precomputed periodic Hann window coefficients for a 768-sample frame (from coeff.h)
72#[rustfmt::skip]
73const HANN_WINDOW_768: [f32; WINDOW_SIZE] = [
74    0.0000000e+00, 1.6733041e-05, 6.6931045e-05, 1.5059065e-04,
75    2.6770626e-04, 4.1827004e-04, 6.0227190e-04, 8.1969953e-04,
76    1.0705384e-03, 1.3547717e-03, 1.6723803e-03, 2.0233432e-03,
77    2.4076367e-03, 2.8252351e-03, 3.2761105e-03, 3.7602327e-03,
78    4.2775693e-03, 4.8280857e-03, 5.4117450e-03, 6.0285082e-03,
79    6.6783340e-03, 7.3611788e-03, 8.0769970e-03, 8.8257407e-03,
80    9.6073598e-03, 1.0421802e-02, 1.1269013e-02, 1.2148935e-02,
81    1.3061510e-02, 1.4006678e-02, 1.4984373e-02, 1.5994532e-02,
82    1.7037087e-02, 1.8111967e-02, 1.9219101e-02, 2.0358415e-02,
83    2.1529832e-02, 2.2733274e-02, 2.3968661e-02, 2.5235910e-02,
84    2.6534935e-02, 2.7865651e-02, 2.9227967e-02, 3.0621794e-02,
85    3.2047037e-02, 3.3503601e-02, 3.4991388e-02, 3.6510300e-02,
86    3.8060234e-02, 3.9641086e-02, 4.1252752e-02, 4.2895122e-02,
87    4.4568088e-02, 4.6271536e-02, 4.8005353e-02, 4.9769424e-02,
88    5.1563629e-02, 5.3387849e-02, 5.5241962e-02, 5.7125844e-02,
89    5.9039368e-02, 6.0982406e-02, 6.2954829e-02, 6.4956504e-02,
90    6.6987298e-02, 6.9047074e-02, 7.1135695e-02, 7.3253021e-02,
91    7.5398909e-02, 7.7573217e-02, 7.9775799e-02, 8.2006508e-02,
92    8.4265194e-02, 8.6551706e-02, 8.8865891e-02, 9.1207593e-02,
93    9.3576658e-02, 9.5972925e-02, 9.8396234e-02, 1.0084642e-01,
94    1.0332333e-01, 1.0582679e-01, 1.0835663e-01, 1.1091268e-01,
95    1.1349477e-01, 1.1610274e-01, 1.1873640e-01, 1.2139558e-01,
96    1.2408010e-01, 1.2678978e-01, 1.2952444e-01, 1.3228389e-01,
97    1.3506796e-01, 1.3787646e-01, 1.4070919e-01, 1.4356597e-01,
98    1.4644661e-01, 1.4935091e-01, 1.5227868e-01, 1.5522973e-01,
99    1.5820385e-01, 1.6120085e-01, 1.6422052e-01, 1.6726267e-01,
100    1.7032709e-01, 1.7341358e-01, 1.7652192e-01, 1.7965192e-01,
101    1.8280336e-01, 1.8597603e-01, 1.8916971e-01, 1.9238420e-01,
102    1.9561929e-01, 1.9887474e-01, 2.0215035e-01, 2.0544589e-01,
103    2.0876115e-01, 2.1209590e-01, 2.1544993e-01, 2.1882300e-01,
104    2.2221488e-01, 2.2562536e-01, 2.2905421e-01, 2.3250119e-01,
105    2.3596607e-01, 2.3944863e-01, 2.4294863e-01, 2.4646583e-01,
106    2.5000000e-01, 2.5355090e-01, 2.5711830e-01, 2.6070196e-01,
107    2.6430163e-01, 2.6791708e-01, 2.7154806e-01, 2.7519434e-01,
108    2.7885565e-01, 2.8253178e-01, 2.8622245e-01, 2.8992744e-01,
109    2.9364649e-01, 2.9737934e-01, 3.0112576e-01, 3.0488549e-01,
110    3.0865828e-01, 3.1244388e-01, 3.1624203e-01, 3.2005248e-01,
111    3.2387498e-01, 3.2770926e-01, 3.3155507e-01, 3.3541216e-01,
112    3.3928027e-01, 3.4315913e-01, 3.4704849e-01, 3.5094809e-01,
113    3.5485766e-01, 3.5877695e-01, 3.6270569e-01, 3.6664362e-01,
114    3.7059048e-01, 3.7454600e-01, 3.7850991e-01, 3.8248196e-01,
115    3.8646187e-01, 3.9044938e-01, 3.9444422e-01, 3.9844613e-01,
116    4.0245484e-01, 4.0647007e-01, 4.1049157e-01, 4.1451906e-01,
117    4.1855226e-01, 4.2259092e-01, 4.2663476e-01, 4.3068351e-01,
118    4.3473690e-01, 4.3879466e-01, 4.4285652e-01, 4.4692220e-01,
119    4.5099143e-01, 4.5506394e-01, 4.5913946e-01, 4.6321772e-01,
120    4.6729844e-01, 4.7138134e-01, 4.7546616e-01, 4.7955263e-01,
121    4.8364046e-01, 4.8772939e-01, 4.9181913e-01, 4.9590943e-01,
122    5.0000000e-01, 5.0409057e-01, 5.0818087e-01, 5.1227061e-01,
123    5.1635954e-01, 5.2044737e-01, 5.2453384e-01, 5.2861866e-01,
124    5.3270156e-01, 5.3678228e-01, 5.4086054e-01, 5.4493606e-01,
125    5.4900857e-01, 5.5307780e-01, 5.5714348e-01, 5.6120534e-01,
126    5.6526310e-01, 5.6931649e-01, 5.7336524e-01, 5.7740908e-01,
127    5.8144774e-01, 5.8548094e-01, 5.8950843e-01, 5.9352993e-01,
128    5.9754516e-01, 6.0155387e-01, 6.0555578e-01, 6.0955062e-01,
129    6.1353813e-01, 6.1751804e-01, 6.2149009e-01, 6.2545400e-01,
130    6.2940952e-01, 6.3335638e-01, 6.3729431e-01, 6.4122305e-01,
131    6.4514234e-01, 6.4905191e-01, 6.5295151e-01, 6.5684087e-01,
132    6.6071973e-01, 6.6458784e-01, 6.6844493e-01, 6.7229074e-01,
133    6.7612502e-01, 6.7994752e-01, 6.8375797e-01, 6.8755612e-01,
134    6.9134172e-01, 6.9511451e-01, 6.9887424e-01, 7.0262066e-01,
135    7.0635351e-01, 7.1007256e-01, 7.1377755e-01, 7.1746822e-01,
136    7.2114435e-01, 7.2480566e-01, 7.2845194e-01, 7.3208292e-01,
137    7.3569837e-01, 7.3929804e-01, 7.4288170e-01, 7.4644910e-01,
138    7.5000000e-01, 7.5353417e-01, 7.5705137e-01, 7.6055137e-01,
139    7.6403393e-01, 7.6749881e-01, 7.7094579e-01, 7.7437464e-01,
140    7.7778512e-01, 7.8117700e-01, 7.8455007e-01, 7.8790410e-01,
141    7.9123885e-01, 7.9455411e-01, 7.9784965e-01, 8.0112526e-01,
142    8.0438071e-01, 8.0761580e-01, 8.1083029e-01, 8.1402397e-01,
143    8.1719664e-01, 8.2034808e-01, 8.2347808e-01, 8.2658642e-01,
144    8.2967291e-01, 8.3273733e-01, 8.3577948e-01, 8.3879915e-01,
145    8.4179615e-01, 8.4477027e-01, 8.4772132e-01, 8.5064909e-01,
146    8.5355339e-01, 8.5643403e-01, 8.5929081e-01, 8.6212354e-01,
147    8.6493204e-01, 8.6771611e-01, 8.7047556e-01, 8.7321022e-01,
148    8.7591990e-01, 8.7860442e-01, 8.8126360e-01, 8.8389726e-01,
149    8.8650523e-01, 8.8908732e-01, 8.9164337e-01, 8.9417321e-01,
150    8.9667667e-01, 8.9915358e-01, 9.0160377e-01, 9.0402708e-01,
151    9.0642334e-01, 9.0879241e-01, 9.1113411e-01, 9.1344829e-01,
152    9.1573481e-01, 9.1799349e-01, 9.2022420e-01, 9.2242678e-01,
153    9.2460109e-01, 9.2674698e-01, 9.2886431e-01, 9.3095293e-01,
154    9.3301270e-01, 9.3504350e-01, 9.3704517e-01, 9.3901759e-01,
155    9.4096063e-01, 9.4287416e-01, 9.4475804e-01, 9.4661215e-01,
156    9.4843637e-01, 9.5023058e-01, 9.5199465e-01, 9.5372846e-01,
157    9.5543191e-01, 9.5710488e-01, 9.5874725e-01, 9.6035891e-01,
158    9.6193977e-01, 9.6348970e-01, 9.6500861e-01, 9.6649640e-01,
159    9.6795296e-01, 9.6937821e-01, 9.7077203e-01, 9.7213435e-01,
160    9.7346506e-01, 9.7476409e-01, 9.7603134e-01, 9.7726673e-01,
161    9.7847017e-01, 9.7964159e-01, 9.8078090e-01, 9.8188803e-01,
162    9.8296291e-01, 9.8400547e-01, 9.8501563e-01, 9.8599332e-01,
163    9.8693849e-01, 9.8785107e-01, 9.8873099e-01, 9.8957820e-01,
164    9.9039264e-01, 9.9117426e-01, 9.9192300e-01, 9.9263882e-01,
165    9.9332167e-01, 9.9397149e-01, 9.9458825e-01, 9.9517191e-01,
166    9.9572243e-01, 9.9623977e-01, 9.9672389e-01, 9.9717476e-01,
167    9.9759236e-01, 9.9797666e-01, 9.9832762e-01, 9.9864523e-01,
168    9.9892946e-01, 9.9918030e-01, 9.9939773e-01, 9.9958173e-01,
169    9.9973229e-01, 9.9984941e-01, 9.9993307e-01, 9.9998327e-01,
170    1.0000000e+00, 9.9998327e-01, 9.9993307e-01, 9.9984941e-01,
171    9.9973229e-01, 9.9958173e-01, 9.9939773e-01, 9.9918030e-01,
172    9.9892946e-01, 9.9864523e-01, 9.9832762e-01, 9.9797666e-01,
173    9.9759236e-01, 9.9717476e-01, 9.9672389e-01, 9.9623977e-01,
174    9.9572243e-01, 9.9517191e-01, 9.9458825e-01, 9.9397149e-01,
175    9.9332167e-01, 9.9263882e-01, 9.9192300e-01, 9.9117426e-01,
176    9.9039264e-01, 9.8957820e-01, 9.8873099e-01, 9.8785107e-01,
177    9.8693849e-01, 9.8599332e-01, 9.8501563e-01, 9.8400547e-01,
178    9.8296291e-01, 9.8188803e-01, 9.8078090e-01, 9.7964159e-01,
179    9.7847017e-01, 9.7726673e-01, 9.7603134e-01, 9.7476409e-01,
180    9.7346506e-01, 9.7213435e-01, 9.7077203e-01, 9.6937821e-01,
181    9.6795296e-01, 9.6649640e-01, 9.6500861e-01, 9.6348970e-01,
182    9.6193977e-01, 9.6035891e-01, 9.5874725e-01, 9.5710488e-01,
183    9.5543191e-01, 9.5372846e-01, 9.5199465e-01, 9.5023058e-01,
184    9.4843637e-01, 9.4661215e-01, 9.4475804e-01, 9.4287416e-01,
185    9.4096063e-01, 9.3901759e-01, 9.3704517e-01, 9.3504350e-01,
186    9.3301270e-01, 9.3095293e-01, 9.2886431e-01, 9.2674698e-01,
187    9.2460109e-01, 9.2242678e-01, 9.2022420e-01, 9.1799349e-01,
188    9.1573481e-01, 9.1344829e-01, 9.1113411e-01, 9.0879241e-01,
189    9.0642334e-01, 9.0402708e-01, 9.0160377e-01, 8.9915358e-01,
190    8.9667667e-01, 8.9417321e-01, 8.9164337e-01, 8.8908732e-01,
191    8.8650523e-01, 8.8389726e-01, 8.8126360e-01, 8.7860442e-01,
192    8.7591990e-01, 8.7321022e-01, 8.7047556e-01, 8.6771611e-01,
193    8.6493204e-01, 8.6212354e-01, 8.5929081e-01, 8.5643403e-01,
194    8.5355339e-01, 8.5064909e-01, 8.4772132e-01, 8.4477027e-01,
195    8.4179615e-01, 8.3879915e-01, 8.3577948e-01, 8.3273733e-01,
196    8.2967291e-01, 8.2658642e-01, 8.2347808e-01, 8.2034808e-01,
197    8.1719664e-01, 8.1402397e-01, 8.1083029e-01, 8.0761580e-01,
198    8.0438071e-01, 8.0112526e-01, 7.9784965e-01, 7.9455411e-01,
199    7.9123885e-01, 7.8790410e-01, 7.8455007e-01, 7.8117700e-01,
200    7.7778512e-01, 7.7437464e-01, 7.7094579e-01, 7.6749881e-01,
201    7.6403393e-01, 7.6055137e-01, 7.5705137e-01, 7.5353417e-01,
202    7.5000000e-01, 7.4644910e-01, 7.4288170e-01, 7.3929804e-01,
203    7.3569837e-01, 7.3208292e-01, 7.2845194e-01, 7.2480566e-01,
204    7.2114435e-01, 7.1746822e-01, 7.1377755e-01, 7.1007256e-01,
205    7.0635351e-01, 7.0262066e-01, 6.9887424e-01, 6.9511451e-01,
206    6.9134172e-01, 6.8755612e-01, 6.8375797e-01, 6.7994752e-01,
207    6.7612502e-01, 6.7229074e-01, 6.6844493e-01, 6.6458784e-01,
208    6.6071973e-01, 6.5684087e-01, 6.5295151e-01, 6.4905191e-01,
209    6.4514234e-01, 6.4122305e-01, 6.3729431e-01, 6.3335638e-01,
210    6.2940952e-01, 6.2545400e-01, 6.2149009e-01, 6.1751804e-01,
211    6.1353813e-01, 6.0955062e-01, 6.0555578e-01, 6.0155387e-01,
212    5.9754516e-01, 5.9352993e-01, 5.8950843e-01, 5.8548094e-01,
213    5.8144774e-01, 5.7740908e-01, 5.7336524e-01, 5.6931649e-01,
214    5.6526310e-01, 5.6120534e-01, 5.5714348e-01, 5.5307780e-01,
215    5.4900857e-01, 5.4493606e-01, 5.4086054e-01, 5.3678228e-01,
216    5.3270156e-01, 5.2861866e-01, 5.2453384e-01, 5.2044737e-01,
217    5.1635954e-01, 5.1227061e-01, 5.0818087e-01, 5.0409057e-01,
218    5.0000000e-01, 4.9590943e-01, 4.9181913e-01, 4.8772939e-01,
219    4.8364046e-01, 4.7955263e-01, 4.7546616e-01, 4.7138134e-01,
220    4.6729844e-01, 4.6321772e-01, 4.5913946e-01, 4.5506394e-01,
221    4.5099143e-01, 4.4692220e-01, 4.4285652e-01, 4.3879466e-01,
222    4.3473690e-01, 4.3068351e-01, 4.2663476e-01, 4.2259092e-01,
223    4.1855226e-01, 4.1451906e-01, 4.1049157e-01, 4.0647007e-01,
224    4.0245484e-01, 3.9844613e-01, 3.9444422e-01, 3.9044938e-01,
225    3.8646187e-01, 3.8248196e-01, 3.7850991e-01, 3.7454600e-01,
226    3.7059048e-01, 3.6664362e-01, 3.6270569e-01, 3.5877695e-01,
227    3.5485766e-01, 3.5094809e-01, 3.4704849e-01, 3.4315913e-01,
228    3.3928027e-01, 3.3541216e-01, 3.3155507e-01, 3.2770926e-01,
229    3.2387498e-01, 3.2005248e-01, 3.1624203e-01, 3.1244388e-01,
230    3.0865828e-01, 3.0488549e-01, 3.0112576e-01, 2.9737934e-01,
231    2.9364649e-01, 2.8992744e-01, 2.8622245e-01, 2.8253178e-01,
232    2.7885565e-01, 2.7519434e-01, 2.7154806e-01, 2.6791708e-01,
233    2.6430163e-01, 2.6070196e-01, 2.5711830e-01, 2.5355090e-01,
234    2.5000000e-01, 2.4646583e-01, 2.4294863e-01, 2.3944863e-01,
235    2.3596607e-01, 2.3250119e-01, 2.2905421e-01, 2.2562536e-01,
236    2.2221488e-01, 2.1882300e-01, 2.1544993e-01, 2.1209590e-01,
237    2.0876115e-01, 2.0544589e-01, 2.0215035e-01, 1.9887474e-01,
238    1.9561929e-01, 1.9238420e-01, 1.8916971e-01, 1.8597603e-01,
239    1.8280336e-01, 1.7965192e-01, 1.7652192e-01, 1.7341358e-01,
240    1.7032709e-01, 1.6726267e-01, 1.6422052e-01, 1.6120085e-01,
241    1.5820385e-01, 1.5522973e-01, 1.5227868e-01, 1.4935091e-01,
242    1.4644661e-01, 1.4356597e-01, 1.4070919e-01, 1.3787646e-01,
243    1.3506796e-01, 1.3228389e-01, 1.2952444e-01, 1.2678978e-01,
244    1.2408010e-01, 1.2139558e-01, 1.1873640e-01, 1.1610274e-01,
245    1.1349477e-01, 1.1091268e-01, 1.0835663e-01, 1.0582679e-01,
246    1.0332333e-01, 1.0084642e-01, 9.8396234e-02, 9.5972925e-02,
247    9.3576658e-02, 9.1207593e-02, 8.8865891e-02, 8.6551706e-02,
248    8.4265194e-02, 8.2006508e-02, 7.9775799e-02, 7.7573217e-02,
249    7.5398909e-02, 7.3253021e-02, 7.1135695e-02, 6.9047074e-02,
250    6.6987298e-02, 6.4956504e-02, 6.2954829e-02, 6.0982406e-02,
251    5.9039368e-02, 5.7125844e-02, 5.5241962e-02, 5.3387849e-02,
252    5.1563629e-02, 4.9769424e-02, 4.8005353e-02, 4.6271536e-02,
253    4.4568088e-02, 4.2895122e-02, 4.1252752e-02, 3.9641086e-02,
254    3.8060234e-02, 3.6510300e-02, 3.4991388e-02, 3.3503601e-02,
255    3.2047037e-02, 3.0621794e-02, 2.9227967e-02, 2.7865651e-02,
256    2.6534935e-02, 2.5235910e-02, 2.3968661e-02, 2.2733274e-02,
257    2.1529832e-02, 2.0358415e-02, 1.9219101e-02, 1.8111967e-02,
258    1.7037087e-02, 1.5994532e-02, 1.4984373e-02, 1.4006678e-02,
259    1.3061510e-02, 1.2148935e-02, 1.1269013e-02, 1.0421802e-02,
260    9.6073598e-03, 8.8257407e-03, 8.0769970e-03, 7.3611788e-03,
261    6.6783340e-03, 6.0285082e-03, 5.4117450e-03, 4.8280857e-03,
262    4.2775693e-03, 3.7602327e-03, 3.2761105e-03, 2.8252351e-03,
263    2.4076367e-03, 2.0233432e-03, 1.6723803e-03, 1.3547717e-03,
264    1.0705384e-03, 8.1969953e-04, 6.0227190e-04, 4.1827004e-04,
265    2.6770626e-04, 1.5059065e-04, 6.6931045e-05, 1.6733041e-05
266];
267
268/// TEN VAD ONNX model runner
269pub struct TenVad {
270    session: Session,                 // ONNX session for inference
271    hidden_states: Vec<Array2<f32>>, // Vector of 2D arrays: [MODEL_IO_NUM - 1] each [1, MODEL_HIDDEN_DIM]
272    feature_buffer: Array2<f32>,     // 2D array: [CONTEXT_WINDOW_LEN, FEATURE_LEN]
273    pre_emphasis_prev: f32,          // Previous value for pre-emphasis filtering
274    mel_filters: Array2<f32>,        // 2D array: [MEL_FILTER_BANK_NUM, n_bins]
275    window: ArrayView1<'static, f32>, // Hann window view: [WINDOW_SIZE]
276    fft_instance: Arc<dyn Fft<f32>>, // Cached FFT instance
277    fft_buffer: Vec<Complex32>,      // Reusable FFT buffer
278    stft_input_q: Vec<f32>,          // Sliding STFT input queue (pre-emphasized samples)
279    stft_windowed_buf: Array1<f32>,  // Reusable STFT windowed buffer
280    pitch_estimator: PitchEstimator, // Pitch estimator state
281}
282
283impl TenVad {
284    /// Create a new TenVadOnnx instance with the specified ONNX model path and sample rate.
285    ///
286    /// # Arguments
287    /// * `onnx_model_path` - Path to the ONNX model file.
288    /// * `sample_rate` - Sample rate in Hz. **Must be 16000 (16kHz)**, otherwise returns an error.
289    ///
290    /// # Returns
291    /// * A `TenVadResult` containing the initialized `TenVadOnnx` instance or an error.
292    ///
293    /// # Errors
294    /// Returns `TenVadError::UnsupportedSampleRate` if the sample rate is not 16000 Hz.
295    pub fn new(onnx_model_path: &str, sample_rate: u32) -> TenVadResult<Self> {
296        if sample_rate != TARGET_SAMPLE_RATE {
297            return Err(TenVadError::UnsupportedSampleRate(sample_rate));
298        }
299
300        let mut builder = Self::configure_session_builder()?;
301        let session = builder.commit_from_file(onnx_model_path)?;
302
303        Self::from_session(session)
304    }
305
306    /// Create a new TenVad instance from in-memory model bytes.
307    ///
308    /// This uses `commit_from_memory` from the `ort` crate to build the session directly
309    /// from the provided bytes (avoids writing a tempfile).
310    pub fn new_from_bytes(model_bytes: &[u8], sample_rate: u32) -> TenVadResult<Self> {
311        if sample_rate != TARGET_SAMPLE_RATE {
312            return Err(TenVadError::UnsupportedSampleRate(sample_rate));
313        }
314
315        let mut builder = Self::configure_session_builder()?;
316        let session = builder.commit_from_memory(model_bytes)?;
317
318        Self::from_session(session)
319    }
320
321    /// Shared initialization from an already-built `Session`.
322    fn from_session(session: Session) -> TenVadResult<Self> {
323        // Initialize hidden states: Vector of 2D arrays [MODEL_IO_NUM - 1] each [1, MODEL_HIDDEN_DIM]
324        let mut hidden_states = Vec::new();
325        for _ in 0..MODEL_IO_NUM - 1 {
326            hidden_states.push(Array2::zeros((1, MODEL_HIDDEN_DIM)));
327        }
328
329        // Initialize feature buffer: 2D array [CONTEXT_WINDOW_LEN, FEATURE_LEN]
330        let feature_buffer = Array2::zeros((CONTEXT_WINDOW_LEN, FEATURE_LEN));
331
332        // Initialize pre-emphasis previous value
333        let pre_emphasis_prev = 0.0f32;
334
335        // Generate mel filter bank
336        let mel_filters = Self::generate_mel_filters()?;
337
338        // Generate Hann window
339        let window = Self::generate_hann_window();
340
341        // Create and cache FFT planner and instance
342        let mut fft_planner = FftPlanner::new();
343        let fft_instance = fft_planner.plan_fft_forward(FFT_SIZE);
344        let fft_buffer = vec![Complex32::new(0.0, 0.0); FFT_SIZE];
345        let stft_input_q = vec![0.0f32; WINDOW_SIZE];
346        let stft_windowed_buf = Array1::zeros(WINDOW_SIZE);
347        let pitch_estimator = PitchEstimator::new();
348
349        Ok(Self {
350            session,
351            hidden_states,
352            feature_buffer,
353            pre_emphasis_prev,
354            mel_filters,
355            window,
356            fft_instance,
357            fft_buffer,
358            stft_input_q,
359            stft_windowed_buf,
360            pitch_estimator,
361        })
362    }
363
364    /// Configure a common Session builder with project defaults (optimization level and threads).
365    fn configure_session_builder() -> TenVadResult<SessionBuilder> {
366        Ok(Session::builder()?
367            .with_optimization_level(GraphOptimizationLevel::Level3)?
368            .with_intra_threads(1)?
369            .with_inter_threads(1)?)
370    }
371
372    /// Generate mel filter-bank coefficients(Adapted from aed.cc).
373    ///
374    /// A mel filter bank is a set of filters used in audio processing to mimic the human ear's perception of sound frequencies.
375    /// These filters are spaced according to the mel scale, which is more sensitive to lower frequencies and less sensitive to higher frequencies.
376    fn generate_mel_filters() -> TenVadResult<Array2<f32>> {
377        let n_bins = FFT_SIZE / 2 + 1;
378
379        // Generate mel filter-bank coefficients
380        let low_mel = 2595.0f32 * (1.0f32 + 0.0f32 / 700.0f32).log10();
381        let high_mel = 2595.0f32 * (1.0f32 + 8000.0f32 / 700.0f32).log10();
382
383        // Create mel points
384        let mut mel_points = Vec::new();
385        for i in 0..=MEL_FILTER_BANK_NUM + 1 {
386            let mel = low_mel + (high_mel - low_mel) * i as f32 / (MEL_FILTER_BANK_NUM + 1) as f32;
387            mel_points.push(mel);
388        }
389
390        // Convert to Hz
391        let mut hz_points = Vec::new();
392        for mel in mel_points {
393            let hz = 700.0f32 * (10.0f32.powf(mel / 2595.0f32) - 1.0f32);
394            hz_points.push(hz);
395        }
396
397        // Convert to FFT bin indices
398        let mut bin_points = Vec::new();
399        for hz in hz_points {
400            let bin = ((FFT_SIZE + 1) as f32 * hz / 16000.0f32) as usize;
401            bin_points.push(bin);
402        }
403        for i in 1..bin_points.len() {
404            if bin_points[i] == bin_points[i - 1] {
405                return Err(TenVadError::InvalidConfiguration(
406                    "Duplicate mel bin points are not supported".to_string(),
407                ));
408            }
409        }
410
411        // Build mel filter bank as 2D array
412        let mut mel_filters = Array2::zeros((MEL_FILTER_BANK_NUM, n_bins));
413
414        for i in 0..MEL_FILTER_BANK_NUM {
415            // Left slope
416            for j in bin_points[i]..bin_points[i + 1] {
417                if j < n_bins {
418                    mel_filters[[i, j]] =
419                        (j - bin_points[i]) as f32 / (bin_points[i + 1] - bin_points[i]) as f32;
420                }
421            }
422
423            // Right slope
424            for j in bin_points[i + 1]..bin_points[i + 2] {
425                if j < n_bins {
426                    mel_filters[[i, j]] = (bin_points[i + 2] - j) as f32
427                        / (bin_points[i + 2] - bin_points[i + 1]) as f32;
428                }
429            }
430        }
431
432        Ok(mel_filters)
433    }
434
435    /// Generate Hann window coefficients
436    fn generate_hann_window() -> ArrayView1<'static, f32> {
437        aview1(&HANN_WINDOW_768)
438    }
439
440    /// Pre-emphasis filtering
441    fn pre_emphasis(&mut self, audio_frame: &[f32]) -> Array1<f32> {
442        if audio_frame.is_empty() {
443            return Array1::zeros(0);
444        }
445
446        let mut emphasized = Array1::zeros(audio_frame.len());
447
448        // First sample
449        emphasized[0] = audio_frame[0] - PRE_EMPHASIS_COEFF * self.pre_emphasis_prev;
450
451        // Remaining samples
452        for i in 1..audio_frame.len() {
453            emphasized[i] = audio_frame[i] - PRE_EMPHASIS_COEFF * audio_frame[i - 1];
454        }
455
456        // Update previous value for next call
457        self.pre_emphasis_prev = audio_frame[audio_frame.len() - 1];
458
459        emphasized
460    }
461
462    /// Extract features from audio frame
463    fn extract_features(&mut self, audio_frame: &[f32]) -> Array1<f32> {
464        // Pre-emphasis
465        let emphasized = self.pre_emphasis(audio_frame);
466
467        // Sliding STFT window (hop of 256 samples in the reference pipeline).
468        let hop_size = 256.min(WINDOW_SIZE).min(emphasized.len());
469        if hop_size > 0 {
470            self.stft_input_q.copy_within(hop_size.., 0);
471            let dst_start = WINDOW_SIZE - hop_size;
472            for i in 0..hop_size {
473                self.stft_input_q[dst_start + i] = emphasized[emphasized.len() - hop_size + i];
474            }
475        }
476
477        // Windowing into a reusable buffer.
478        for i in 0..WINDOW_SIZE {
479            self.stft_windowed_buf[i] = self.stft_input_q[i] * self.window[i];
480        }
481
482        // Zero the FFT buffer before use to clear any previous data (using cached FFT instance and reusable buffer)
483        self.fft_buffer.fill(Complex32::new(0.0, 0.0));
484
485        // Prepare input for FFT (real to complex)
486        for i in 0..WINDOW_SIZE.min(FFT_SIZE) {
487            self.fft_buffer[i] = Complex32::new(self.stft_windowed_buf[i], 0.0);
488        }
489
490        // Perform FFT using cached instance
491        self.fft_instance.process(&mut self.fft_buffer);
492
493        // Compute power spectrum (only positive frequencies)
494        let n_bins = FFT_SIZE / 2 + 1;
495        let mut power_spectrum = Array1::zeros(n_bins);
496        for i in 0..n_bins {
497            power_spectrum[i] = self.fft_buffer[i].norm_sqr();
498        }
499
500        // Pitch estimation consumes raw (non-pre-emphasized) signal and unnormalized bin power.
501        let pitch_freq = self
502            .pitch_estimator
503            .process(audio_frame, power_spectrum.as_slice().unwrap_or(&[]));
504
505        // Normalize mel path (corresponding to powerNormal = 32768^2 in C++).
506        let power_normal = 32768.0f32.powi(2);
507        power_spectrum /= power_normal;
508
509        // Mel filter bank features
510        let mel_features = self.mel_filters.dot(&power_spectrum);
511        let mel_features = mel_features.mapv(|x| (x + EPS).ln());
512
513        // Combine features
514        let mut features = Array1::zeros(FEATURE_LEN);
515        features
516            .slice_mut(ndarray::s![..MEL_FILTER_BANK_NUM])
517            .assign(&mel_features);
518        features[MEL_FILTER_BANK_NUM] = pitch_freq;
519
520        // Feature normalization
521        for i in 0..FEATURE_LEN {
522            features[i] = (features[i] - FEATURE_MEANS[i]) / (FEATURE_STDS[i] + EPS);
523        }
524
525        features
526    }
527
528    /// Process a single audio frame and return VAD score and decision
529    /// # Arguments
530    /// * `audio_frame` - A slice of i16 audio samples in 16kHz (e.g., from a microphone)
531    /// # Returns
532    /// * The VAD score (f32)
533    pub fn process_frame(&mut self, audio_frame: &[i16]) -> TenVadResult<f32> {
534        // Check if audio frame is empty
535        if audio_frame.is_empty() {
536            return Err(TenVadError::EmptyAudioData);
537        }
538
539        // Convert i16 to f32
540        let audio_f32: Vec<f32> = audio_frame.iter().map(|&x| x as f32).collect();
541
542        // Extract features
543        let features = self.extract_features(&audio_f32);
544
545        // Update feature buffer (sliding window)
546        // Shift existing features up and add new features at the end
547        if CONTEXT_WINDOW_LEN > 1 {
548            // Use a simple loop to shift rows up
549            for i in 0..CONTEXT_WINDOW_LEN - 1 {
550                // Copy row i+1 to row i
551                let src_row = self.feature_buffer.row(i + 1).to_owned();
552                self.feature_buffer.row_mut(i).assign(&src_row);
553            }
554        }
555        // Set the last row to new features
556        self.feature_buffer
557            .row_mut(CONTEXT_WINDOW_LEN - 1)
558            .assign(&features);
559
560        // Prepare ONNX inference input
561        // Reshape feature buffer, [CONTEXT_WINDOW_LEN, FEATURE_LEN] to [1, CONTEXT_WINDOW_LEN, FEATURE_LEN]
562        let input_features = self.feature_buffer.view().insert_axis(Axis(0)); // shape: (1, CONTEXT_WINDOW_LEN, FEATURE_LEN)
563
564        // Build input array directly
565        let input_tensors: [SessionInputValue; MODEL_IO_NUM] = [
566            // Input features as first input
567            SessionInputValue::from(TensorRef::from_array_view(input_features)?),
568            // Add hidden states as inputs
569            SessionInputValue::from(TensorRef::from_array_view(self.hidden_states[0].view())?),
570            SessionInputValue::from(TensorRef::from_array_view(self.hidden_states[1].view())?),
571            SessionInputValue::from(TensorRef::from_array_view(self.hidden_states[2].view())?),
572            SessionInputValue::from(TensorRef::from_array_view(self.hidden_states[3].view())?),
573        ];
574
575        let session_inputs = SessionInputs::ValueArray(input_tensors);
576
577        // Run inference with all inputs
578        let outputs = self.session.run(session_inputs)?;
579
580        // Get VAD score from first output (outputs[0])
581        let vad_score = outputs[0].try_extract_array::<f32>()?[[0, 0, 0]];
582
583        // Update hidden states with outputs[1], outputs[2], outputs[3], outputs[4]
584        for i in 0..MODEL_IO_NUM - 1 {
585            let output_tensor = outputs[i + 1].try_extract_array::<f32>()?;
586            self.hidden_states[i].assign(&output_tensor);
587        }
588
589        Ok(vad_score)
590    }
591
592    /// Reset the VAD state
593    pub fn reset(&mut self) {
594        // Reset hidden states
595        for hidden_state in &mut self.hidden_states {
596            hidden_state.fill(0.0f32);
597        }
598
599        // Reset feature buffer
600        self.feature_buffer.fill(0.0f32);
601
602        // Reset pre-emphasis previous value
603        self.pre_emphasis_prev = 0.0f32;
604        self.stft_input_q.fill(0.0f32);
605        self.stft_windowed_buf.fill(0.0f32);
606        self.pitch_estimator.reset();
607    }
608}
609
610impl std::fmt::Debug for TenVad {
611    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
612        f.debug_struct("TenVad")
613            .field("session", &"Session")
614            .field("hidden_states", &self.hidden_states.len())
615            .field("feature_buffer", &self.feature_buffer.shape())
616            .field("pre_emphasis_prev", &self.pre_emphasis_prev)
617            .field("mel_filters", &self.mel_filters.shape())
618            .field("window", &self.window.len())
619            .field("stft_input_q", &self.stft_input_q.len())
620            .field("stft_windowed_buf", &self.stft_windowed_buf.len())
621            .field("pitch_estimator", &self.pitch_estimator)
622            .finish()
623    }
624}
625
626#[cfg(test)]
627mod tests {
628    use super::*;
629    use std::f32::consts::PI;
630
631    // Helper function to create a valid TenVad instance for testing
632    fn create_test_vad() -> TenVad {
633        TenVad::new("onnx/ten-vad.onnx", TARGET_SAMPLE_RATE)
634            .expect("Failed to create TenVad instance for testing")
635    }
636
637    // Helper function to generate test audio with specific properties
638    fn generate_test_audio(length: usize, frequency: f32, sample_rate: f32) -> Vec<f32> {
639        (0..length)
640            .map(|i| (2.0 * PI * frequency * i as f32 / sample_rate).sin() * 0.5)
641            .collect()
642    }
643
644    #[test]
645    fn test_generate_mel_filters() {
646        let mel_filters = TenVad::generate_mel_filters().expect("Failed to generate mel filters");
647
648        // Check dimensions
649        assert_eq!(
650            mel_filters.shape(),
651            &[MEL_FILTER_BANK_NUM, FFT_SIZE / 2 + 1]
652        );
653
654        // Check that filters are non-negative
655        assert!(mel_filters.iter().all(|&x| x >= 0.0));
656
657        // Check that each filter has some non-zero values
658        for i in 0..MEL_FILTER_BANK_NUM {
659            let filter_sum: f32 = mel_filters.row(i).sum();
660            assert!(filter_sum > 0.0, "Filter {i} should have non-zero values");
661        }
662
663        // Check that filters have triangular shape (max value should be around 1.0)
664        for i in 0..MEL_FILTER_BANK_NUM {
665            let max_val = mel_filters.row(i).iter().fold(0.0f32, |a, &b| a.max(b));
666            assert!(
667                max_val <= 1.0 + f32::EPSILON,
668                "Filter {i} max value should not exceed 1.0"
669            );
670        }
671    }
672
673    #[test]
674    fn test_generate_hann_window() {
675        let window = TenVad::generate_hann_window();
676
677        // Check length
678        assert_eq!(window.len(), WINDOW_SIZE);
679
680        // Check range [0, 1]
681        assert!(window.iter().all(|&x| (0.0..=1.0).contains(&x)));
682
683        // Check exact edge values from reference table.
684        assert!((window[0] - 0.0).abs() < 1e-12, "Window[0] mismatch");
685        assert!(
686            (window[WINDOW_SIZE - 1] - 1.6733041e-05).abs() < 1e-10,
687            "Window[last] mismatch"
688        );
689
690        // Check that window peaks near the middle
691        let mid_idx = WINDOW_SIZE / 2;
692        assert!(window[mid_idx] > 0.99, "Window should peak near the middle");
693    }
694
695    #[test]
696    fn test_pre_emphasis_basic() {
697        let mut vad = create_test_vad();
698        let audio_frame = vec![0.0, 1.0, 2.0, 3.0, 4.0];
699        let emphasized = vad.pre_emphasis(&audio_frame);
700
701        assert_eq!(emphasized.len(), audio_frame.len());
702
703        // First sample should be original (no previous sample)
704        assert_eq!(emphasized[0], audio_frame[0]);
705
706        // Check that pre-emphasis is applied correctly
707        for i in 1..audio_frame.len() {
708            let expected = audio_frame[i] - PRE_EMPHASIS_COEFF * audio_frame[i - 1];
709            assert!((emphasized[i] - expected).abs() < f32::EPSILON);
710        }
711    }
712
713    #[test]
714    fn test_pre_emphasis_state_preservation() {
715        let mut vad = create_test_vad();
716
717        // Process first frame
718        let frame1 = vec![1.0, 2.0, 3.0];
719        let _ = vad.pre_emphasis(&frame1);
720
721        // Process second frame - should use last value from frame1 as previous
722        let frame2 = vec![4.0, 5.0, 6.0];
723        let emphasized2 = vad.pre_emphasis(&frame2);
724
725        // First sample of frame2 should use last sample of frame1
726        let expected = frame2[0] - PRE_EMPHASIS_COEFF * frame1[frame1.len() - 1];
727        assert!((emphasized2[0] - expected).abs() < f32::EPSILON);
728    }
729
730    #[test]
731    fn test_pre_emphasis_empty_frame() {
732        let mut vad = create_test_vad();
733        let empty_frame: Vec<f32> = vec![];
734        let emphasized = vad.pre_emphasis(&empty_frame);
735        assert_eq!(emphasized.len(), 0);
736    }
737
738    #[test]
739    fn test_pre_emphasis_single_sample() {
740        let mut vad = create_test_vad();
741        let single_frame = vec![5.0];
742        let emphasized = vad.pre_emphasis(&single_frame);
743
744        assert_eq!(emphasized.len(), 1);
745        // With no previous sample (initial state), should be close to original
746        assert!((emphasized[0] - single_frame[0]).abs() < f32::EPSILON);
747    }
748
749    #[test]
750    fn test_extract_features_basic() {
751        let mut vad = create_test_vad();
752        let audio_frame = vec![0.0; WINDOW_SIZE];
753        let features = vad.extract_features(&audio_frame);
754
755        assert_eq!(features.len(), FEATURE_LEN);
756
757        // All features should be finite numbers
758        assert!(features.iter().all(|&x| x.is_finite()));
759    }
760
761    #[test]
762    fn test_extract_features_sine_wave() {
763        let mut vad = create_test_vad();
764        let audio_frame = generate_test_audio(WINDOW_SIZE, 440.0, 16000.0);
765        let features = vad.extract_features(&audio_frame);
766
767        assert_eq!(features.len(), FEATURE_LEN);
768        assert!(features.iter().all(|&x| x.is_finite()));
769
770        // For a sine wave, features should be different from silence
771        let silence_features = vad.extract_features(&vec![0.0; WINDOW_SIZE]);
772        let features_diff: f32 = features
773            .iter()
774            .zip(silence_features.iter())
775            .map(|(a, b)| (a - b).abs())
776            .sum();
777
778        assert!(
779            features_diff > 0.1,
780            "Sine wave features should be different from silence"
781        );
782    }
783
784    #[test]
785    fn test_extract_features_short_frame() {
786        let mut vad = create_test_vad();
787        let short_frame = vec![1.0; 100]; // Shorter than WINDOW_SIZE
788        let features = vad.extract_features(&short_frame);
789
790        assert_eq!(features.len(), FEATURE_LEN);
791        assert!(features.iter().all(|&x| x.is_finite()));
792    }
793
794    #[test]
795    fn test_extract_features_long_frame() {
796        let mut vad = create_test_vad();
797        let long_frame = vec![1.0; WINDOW_SIZE * 2]; // Longer than WINDOW_SIZE
798        let features = vad.extract_features(&long_frame);
799
800        assert_eq!(features.len(), FEATURE_LEN);
801        assert!(features.iter().all(|&x| x.is_finite()));
802    }
803
804    #[test]
805    fn test_extract_features_normalization() {
806        let mut vad = create_test_vad();
807        let audio_frame = generate_test_audio(WINDOW_SIZE, 1000.0, 16000.0);
808        let features = vad.extract_features(&audio_frame);
809
810        // Features should be normalized - check basic properties
811        assert!(
812            features.iter().all(|&x| x.is_finite()),
813            "All features should be finite"
814        );
815
816        // Check that features are not all identical (indicating processing worked)
817        let first_feature = features[0];
818        let has_variation = features.iter().any(|&x| (x - first_feature).abs() > 0.01);
819        assert!(
820            has_variation,
821            "Features should show variation after processing"
822        );
823
824        // Check that features have reasonable magnitude (normalized features typically in [-5, 5] range)
825        let max_abs = features.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);
826        assert!(
827            max_abs < 10.0,
828            "Normalized features should have reasonable magnitude"
829        );
830    }
831
832    #[test]
833    fn test_new_vad_initialization() {
834        // Test that initialization works
835        let vad = TenVad::new("onnx/ten-vad.onnx", TARGET_SAMPLE_RATE);
836        assert!(vad.is_ok(), "TenVad initialization should succeed");
837
838        let vad = vad.unwrap();
839
840        // Check initial states
841        assert_eq!(vad.hidden_states.len(), MODEL_IO_NUM - 1);
842        for (i, hidden_state) in vad.hidden_states.iter().enumerate() {
843            assert_eq!(
844                hidden_state.shape(),
845                &[1, MODEL_HIDDEN_DIM],
846                "Hidden state {i} should have correct shape"
847            );
848            assert!(
849                hidden_state.iter().all(|&x| x == 0.0),
850                "Hidden state {i} should be initialized to zero"
851            );
852        }
853
854        assert_eq!(
855            vad.feature_buffer.shape(),
856            &[CONTEXT_WINDOW_LEN, FEATURE_LEN]
857        );
858        assert!(
859            vad.feature_buffer.iter().all(|&x| x == 0.0),
860            "Feature buffer should be initialized to zero"
861        );
862
863        assert_eq!(vad.pre_emphasis_prev, 0.0);
864    }
865
866    #[test]
867    fn test_new_vad_invalid_path() {
868        let result = TenVad::new("nonexistent/path/model.onnx", TARGET_SAMPLE_RATE);
869        assert!(result.is_err(), "Should fail with invalid model path");
870    }
871
872    #[test]
873    fn test_new_vad_unsupported_sample_rate() {
874        let result = TenVad::new("onnx/ten-vad.onnx", 48000);
875        assert!(result.is_err(), "Should fail with unsupported sample rate");
876
877        match result.unwrap_err() {
878            TenVadError::UnsupportedSampleRate(rate) => {
879                assert_eq!(rate, 48000, "Error should contain the invalid sample rate");
880            }
881            _ => panic!("Expected UnsupportedSampleRate error"),
882        }
883    }
884
885    #[test]
886    fn test_reset_vad_state() {
887        let mut vad = create_test_vad();
888
889        // Process some audio to change internal state
890        let audio_frame = generate_test_audio(256, 440.0, 16000.0);
891        let audio_i16: Vec<i16> = audio_frame.iter().map(|&x| (x * 32767.0) as i16).collect();
892        let _ = vad.process_frame(&audio_i16);
893
894        // Reset the VAD
895        vad.reset();
896
897        // Check that states are reset
898        for hidden_state in &vad.hidden_states {
899            assert!(
900                hidden_state.iter().all(|&x| x == 0.0),
901                "Hidden states should be reset to zero"
902            );
903        }
904
905        assert!(
906            vad.feature_buffer.iter().all(|&x| x == 0.0),
907            "Feature buffer should be reset to zero"
908        );
909
910        assert_eq!(
911            vad.pre_emphasis_prev, 0.0,
912            "Pre-emphasis state should be reset"
913        );
914    }
915
916    #[test]
917    fn test_process_frame_basic() {
918        let mut vad = create_test_vad();
919        let audio_frame = vec![0i16; 256];
920        let result = vad.process_frame(&audio_frame);
921
922        assert!(result.is_ok(), "Processing frame should succeed");
923        let vad_score = result.unwrap();
924        assert!(vad_score.is_finite(), "VAD score should be finite");
925        assert!(
926            (0.0..=1.0).contains(&vad_score),
927            "VAD score should be in [0, 1] range"
928        );
929    }
930
931    #[test]
932    fn test_process_frame_empty() {
933        let mut vad = create_test_vad();
934        let empty_frame: Vec<i16> = vec![];
935        let result = vad.process_frame(&empty_frame);
936
937        assert!(result.is_err(), "Processing empty frame should fail");
938    }
939
940    #[test]
941    fn test_process_frame_different_sizes() {
942        let mut vad = create_test_vad();
943
944        let sizes = vec![64, 128, 256, 512, 1024];
945        for size in sizes {
946            let audio_frame = vec![100i16; size];
947            let result = vad.process_frame(&audio_frame);
948            assert!(
949                result.is_ok(),
950                "Processing frame of size {size} should succeed"
951            );
952        }
953    }
954
955    #[test]
956    fn test_process_frame_extreme_values() {
957        let mut vad = create_test_vad();
958
959        // Test with maximum values
960        let max_frame = vec![i16::MAX; 256];
961        let result = vad.process_frame(&max_frame);
962        assert!(result.is_ok(), "Processing max values should succeed");
963
964        // Test with minimum values
965        let min_frame = vec![i16::MIN; 256];
966        let result = vad.process_frame(&min_frame);
967        assert!(result.is_ok(), "Processing min values should succeed");
968    }
969
970    #[test]
971    fn test_process_frame_sequence() {
972        let mut vad = create_test_vad();
973        let frame_size = 256;
974
975        // Process multiple frames in sequence
976        for i in 0..10 {
977            let audio_frame: Vec<i16> = (0..frame_size)
978                .map(|j| ((i * 100 + j) % 1000) as i16)
979                .collect();
980            let result = vad.process_frame(&audio_frame);
981            assert!(result.is_ok(), "Processing frame {i} should succeed");
982
983            let vad_score = result.unwrap();
984            assert!(vad_score.is_finite(), "VAD score {i} should be finite");
985        }
986    }
987
988    #[test]
989    fn test_process_frame_consistent_results() {
990        let mut vad1 = create_test_vad();
991        let mut vad2 = create_test_vad();
992
993        let audio_frame = generate_test_audio(256, 440.0, 16000.0);
994        let audio_i16: Vec<i16> = audio_frame.iter().map(|&x| (x * 32767.0) as i16).collect();
995
996        let score1 = vad1.process_frame(&audio_i16).unwrap();
997        let score2 = vad2.process_frame(&audio_i16).unwrap();
998
999        assert!(
1000            (score1 - score2).abs() < f32::EPSILON,
1001            "Same input should produce same output"
1002        );
1003    }
1004
1005    #[test]
1006    fn test_feature_buffer_sliding_window() {
1007        let mut vad = create_test_vad();
1008
1009        // Feature buffer should initially be zeros
1010        let initial_sum: f32 = vad.feature_buffer.sum();
1011        assert_eq!(initial_sum, 0.0, "Initial feature buffer should be zeros");
1012
1013        // Process several frames with different signals
1014        for i in 0..CONTEXT_WINDOW_LEN + 2 {
1015            // Create audio with some variation to ensure features are different
1016            let audio_frame = generate_test_audio(WINDOW_SIZE, 200.0 + i as f32 * 100.0, 16000.0);
1017            let _ = vad.extract_features(&audio_frame);
1018        }
1019
1020        // Feature buffer should contain the last CONTEXT_WINDOW_LEN frames
1021        assert_eq!(
1022            vad.feature_buffer.shape(),
1023            &[CONTEXT_WINDOW_LEN, FEATURE_LEN]
1024        );
1025
1026        // The buffer should have been updated from its initial zero state
1027        // Even after normalization, processed audio should produce different features than silence
1028        let silence_features = {
1029            let mut temp_vad = create_test_vad();
1030            temp_vad.extract_features(&vec![0.0; WINDOW_SIZE])
1031        };
1032
1033        // At least one row should be different from silence features
1034        let mut has_difference = false;
1035        for row_idx in 0..CONTEXT_WINDOW_LEN {
1036            let row = vad.feature_buffer.row(row_idx);
1037            let diff: f32 = row
1038                .iter()
1039                .zip(silence_features.iter())
1040                .map(|(a, b)| (a - b).abs())
1041                .sum();
1042            if diff > 0.1 {
1043                // Allow for some tolerance
1044                has_difference = true;
1045                break;
1046            }
1047        }
1048
1049        // If no significant difference found, at least verify the buffer structure is correct
1050        assert!(
1051            has_difference || vad.feature_buffer.shape() == [CONTEXT_WINDOW_LEN, FEATURE_LEN],
1052            "Feature buffer should either show processing changes or maintain correct structure"
1053        );
1054    }
1055
1056    #[test]
1057    fn test_constants_validity() {
1058        // Test that constants are reasonable (these help document expected values)
1059        // The following lines use `const _: () = assert!(...)` for compile-time assertions.
1060        // This idiom causes a compilation error if the assertion fails, ensuring the condition is checked at compile time.
1061        const _: () = assert!(FFT_SIZE > 0, "FFT_SIZE should be positive");
1062        const _: () = assert!(WINDOW_SIZE > 0, "WINDOW_SIZE should be positive");
1063        const _: () = assert!(
1064            MEL_FILTER_BANK_NUM > 0,
1065            "MEL_FILTER_BANK_NUM should be positive"
1066        );
1067        const _: () = assert!(FEATURE_LEN > 0, "FEATURE_LEN should be positive");
1068        const _: () = assert!(
1069            CONTEXT_WINDOW_LEN > 0,
1070            "CONTEXT_WINDOW_LEN should be positive"
1071        );
1072        const _: () = assert!(MODEL_HIDDEN_DIM > 0, "MODEL_HIDDEN_DIM should be positive");
1073        const _: () = assert!(MODEL_IO_NUM > 1, "MODEL_IO_NUM should be greater than 1");
1074
1075        // Test runtime checks
1076        assert!(
1077            FFT_SIZE.is_power_of_two(),
1078            "FFT_SIZE should be a power of 2"
1079        );
1080        assert!(
1081            (0.0..1.0).contains(&PRE_EMPHASIS_COEFF),
1082            "PRE_EMPHASIS_COEFF should be in (0,1)"
1083        );
1084
1085        // Test feature normalization constants
1086        assert_eq!(
1087            FEATURE_MEANS.len(),
1088            FEATURE_LEN,
1089            "FEATURE_MEANS length should match FEATURE_LEN"
1090        );
1091        assert_eq!(
1092            FEATURE_STDS.len(),
1093            FEATURE_LEN,
1094            "FEATURE_STDS length should match FEATURE_LEN"
1095        );
1096
1097        // All standard deviations should be positive
1098        assert!(
1099            FEATURE_STDS.iter().all(|&x| x > 0.0),
1100            "All feature stds should be positive"
1101        );
1102    }
1103
1104    #[test]
1105    fn test_debug_implementation() {
1106        let vad = create_test_vad();
1107        let debug_str = format!("{vad:?}");
1108
1109        // Debug output should contain key information
1110        assert!(debug_str.contains("TenVad"));
1111        assert!(debug_str.contains("hidden_states"));
1112        assert!(debug_str.contains("feature_buffer"));
1113    }
1114
1115    #[test]
1116    fn test_multiple_vad_instances() {
1117        // Test that multiple VAD instances can coexist
1118        let mut vad1 = create_test_vad();
1119        let mut vad2 = create_test_vad();
1120
1121        let frame1 = vec![100i16; 256];
1122        let frame2 = vec![200i16; 256];
1123
1124        let score1 = vad1.process_frame(&frame1).unwrap();
1125        let score2 = vad2.process_frame(&frame2).unwrap();
1126
1127        // Different inputs should potentially produce different outputs
1128        assert!(score1.is_finite() && score2.is_finite());
1129
1130        // Reset both instances so they evaluate the same frame from the same recurrent state.
1131        vad1.reset();
1132        vad2.reset();
1133
1134        // Process same frame with both instances
1135        let same_frame = vec![150i16; 256];
1136        let score1_same = vad1.process_frame(&same_frame).unwrap();
1137        let score2_same = vad2.process_frame(&same_frame).unwrap();
1138
1139        // Should produce same result for same input when their state is aligned.
1140        assert!(
1141            (score1_same - score2_same).abs() < 0.01,
1142            "Different instances should produce similar results for same input"
1143        );
1144    }
1145}