vad_rs/
vad.rs

1use eyre::{bail, Result};
2use ndarray::{Array1, Array2, Array3, ArrayBase, Ix1, Ix3, OwnedRepr};
3use ort::session::Session;
4use std::path::Path;
5
6use crate::{session, vad_result::VadResult};
7
8#[derive(Debug)]
9pub struct Vad {
10    session: Session,
11    h_tensor: ArrayBase<OwnedRepr<f32>, Ix3>,
12    c_tensor: ArrayBase<OwnedRepr<f32>, Ix3>,
13    sample_rate_tensor: ArrayBase<OwnedRepr<i64>, Ix1>,
14}
15
16impl Vad {
17    pub fn new<P: AsRef<Path>>(model_path: P, sample_rate: usize) -> Result<Self> {
18        if ![8000_usize, 16000].contains(&sample_rate) {
19            bail!("Unsupported sample rate, use 8000 or 16000!");
20        }
21        let session = session::create_session(model_path)?;
22        let h_tensor = Array3::<f32>::zeros((2, 1, 64));
23        let c_tensor = Array3::<f32>::zeros((2, 1, 64));
24        let sample_rate_tensor = Array1::from_vec(vec![sample_rate as i64]);
25
26        Ok(Self {
27            session,
28            h_tensor,
29            c_tensor,
30            sample_rate_tensor,
31        })
32    }
33
34    pub fn compute(&mut self, samples: &[f32]) -> Result<VadResult> {
35        let samples_tensor = Array2::from_shape_vec((1, samples.len()), samples.to_vec())?;
36        let result = self.session.run(ort::inputs![
37            "input" => samples_tensor.view(),
38            "sr" => self.sample_rate_tensor.view(),
39            "h" => self.h_tensor.view(),
40            "c" => self.c_tensor.view()
41        ]?)?;
42
43        // Update internal state tensors.
44        self.h_tensor = result
45            .get("hn")
46            .unwrap()
47            .try_extract_tensor::<f32>()
48            .unwrap()
49            .to_owned()
50            .into_shape_with_order((2, 1, 64))
51            .expect("Shape mismatch for h_tensor");
52        self.c_tensor = result
53            .get("cn")
54            .unwrap()
55            .try_extract_tensor::<f32>()
56            .unwrap()
57            .to_owned()
58            .into_shape_with_order((2, 1, 64))
59            .expect("Shape mismatch for h_tensor");
60
61        let prob = *result
62            .get("output")
63            .unwrap()
64            .try_extract_tensor::<f32>()
65            .unwrap()
66            .first()
67            .unwrap();
68        Ok(VadResult { prob })
69    }
70
71    pub fn reset(&mut self) {
72        self.h_tensor.fill(0.0);
73        self.c_tensor.fill(0.0);
74    }
75}