1use eyre::{bail, Result};
2use ndarray::{Array1, Array2, Array3, ArrayBase, Ix1, Ix3, OwnedRepr};
3use ort::session::Session;
4use std::path::Path;
5
6use crate::{session, vad_result::VadResult};
7
/// Stateful wrapper around an `ort` [`Session`] that produces a [`VadResult`]
/// probability for each chunk of samples passed to `compute`.
///
/// The `h` / `c` tensors are fed to the model and replaced by its `hn` / `cn`
/// outputs after every `compute` call. A single `Vad` is therefore presumably meant
/// to receive consecutive chunks of one audio stream; `reset` zeroes that state.
#[derive(Debug)]
pub struct Vad {
    /// ONNX Runtime session created in `new` via `session::create_session`.
    session: Session,
    /// State passed as the model's `h` input; zero-initialised with shape `(2, 1, 64)`.
    h_tensor: ArrayBase<OwnedRepr<f32>, Ix3>,
    /// State passed as the model's `c` input; zero-initialised with shape `(2, 1, 64)`.
    c_tensor: ArrayBase<OwnedRepr<f32>, Ix3>,
    /// One-element tensor holding the sample rate (8000 or 16000), passed as `sr`.
    sample_rate_tensor: ArrayBase<OwnedRepr<i64>, Ix1>,
}
15
16impl Vad {
17 pub fn new<P: AsRef<Path>>(model_path: P, sample_rate: usize) -> Result<Self> {
18 if ![8000_usize, 16000].contains(&sample_rate) {
19 bail!("Unsupported sample rate, use 8000 or 16000!");
20 }
21 let session = session::create_session(model_path)?;
22 let h_tensor = Array3::<f32>::zeros((2, 1, 64));
23 let c_tensor = Array3::<f32>::zeros((2, 1, 64));
24 let sample_rate_tensor = Array1::from_vec(vec![sample_rate as i64]);
25
26 Ok(Self {
27 session,
28 h_tensor,
29 c_tensor,
30 sample_rate_tensor,
31 })
32 }
33
34 pub fn compute(&mut self, samples: &[f32]) -> Result<VadResult> {
35 let samples_tensor = Array2::from_shape_vec((1, samples.len()), samples.to_vec())?;
36 let result = self.session.run(ort::inputs![
37 "input" => samples_tensor.view(),
38 "sr" => self.sample_rate_tensor.view(),
39 "h" => self.h_tensor.view(),
40 "c" => self.c_tensor.view()
41 ]?)?;
42
43 self.h_tensor = result
45 .get("hn")
46 .unwrap()
47 .try_extract_tensor::<f32>()
48 .unwrap()
49 .to_owned()
50 .into_shape_with_order((2, 1, 64))
51 .expect("Shape mismatch for h_tensor");
52 self.c_tensor = result
53 .get("cn")
54 .unwrap()
55 .try_extract_tensor::<f32>()
56 .unwrap()
57 .to_owned()
58 .into_shape_with_order((2, 1, 64))
59 .expect("Shape mismatch for h_tensor");
60
61 let prob = *result
62 .get("output")
63 .unwrap()
64 .try_extract_tensor::<f32>()
65 .unwrap()
66 .first()
67 .unwrap();
68 Ok(VadResult { prob })
69 }
70
71 pub fn reset(&mut self) {
72 self.h_tensor.fill(0.0);
73 self.c_tensor.fill(0.0);
74 }
75}