sherpa_rs/
lib.rs

1pub mod audio_tag;
2pub mod diarize;
3pub mod dolphin;
4pub mod embedding_manager;
5pub mod keyword_spot;
6pub mod language_id;
7pub mod moonshine;
8pub mod paraformer;
9pub mod punctuate;
10pub mod sense_voice;
11pub mod silero_vad;
12pub mod speaker_id;
13pub mod ten_vad;
14pub mod transducer;
15pub mod whisper;
16pub mod zipformer;
17
18mod utils;
19
20#[cfg(feature = "tts")]
21pub mod tts;
22
23use std::ffi::CStr;
24
25#[cfg(feature = "sys")]
26pub use sherpa_rs_sys;
27
28use eyre::{bail, Result};
29use utils::cstr_to_string;
30
/// Returns the name of the provider that is used when none is chosen explicitly.
///
/// The result is always `"cpu"`. Selecting `"cuda"`, `"coreml"` or `"directml"`
/// from the `cuda` / `directml` features or a macOS target is intentionally
/// disabled: the other providers have many issues with different models.
pub fn get_default_provider() -> String {
    String::from("cpu")
}
45
46pub fn read_audio_file(path: &str) -> Result<(Vec<f32>, u32)> {
47    let mut reader = hound::WavReader::open(path)?;
48    let sample_rate = reader.spec().sample_rate;
49
50    // Check if the sample rate is 16000
51    if sample_rate != 16000 {
52        bail!("The sample rate must be 16000.");
53    }
54
55    // Collect samples into a Vec<f32>
56    let samples: Vec<f32> = reader
57        .samples::<i16>()
58        .map(|s| (s.unwrap() as f32) / (i16::MAX as f32))
59        .collect();
60
61    Ok((samples, sample_rate))
62}
63
64pub fn write_audio_file(path: &str, samples: &[f32], sample_rate: u32) -> Result<()> {
65    // Create a WAV file writer
66    let spec = hound::WavSpec {
67        channels: 1,
68        sample_rate,
69        bits_per_sample: 16,
70        sample_format: hound::SampleFormat::Int,
71    };
72
73    let mut writer = hound::WavWriter::create(path, spec)?;
74
75    // Convert samples from f32 to i16 and write them to the WAV file
76    for &sample in samples {
77        let scaled_sample =
78            (sample * (i16::MAX as f32)).clamp(i16::MIN as f32, i16::MAX as f32) as i16;
79        writer.write_sample(scaled_sample)?;
80    }
81
82    writer.finalize()?;
83    Ok(())
84}
85
/// Runtime settings shared by the model wrappers in this crate.
///
/// `Default` yields the provider from [`get_default_provider`], `debug = false`
/// and `num_threads = 1`.
#[derive(Debug, Clone)]
pub struct OnnxConfig {
    /// Name of the provider to run on (for example `"cpu"`).
    pub provider: String,
    /// Debug flag forwarded to the backend.
    pub debug: bool,
    /// Number of threads to use.
    pub num_threads: i32,
}
91
/// Owned copy of the result produced by an offline recognizer.
#[derive(Clone, Debug)]
pub struct OfflineRecognizerResult {
    /// Language string reported with the result.
    pub lang: String,
    /// Recognized text.
    pub text: String,
    /// Timestamps copied from the raw result; empty when none were provided.
    /// Presumably one entry per token (both are sized by the same count) — confirm
    /// against the backend documentation.
    pub timestamps: Vec<f32>,
    /// Individual tokens of the result, in the order they were reported.
    pub tokens: Vec<String>,
}
99
100impl OfflineRecognizerResult {
101    fn new(result: &sherpa_rs_sys::SherpaOnnxOfflineRecognizerResult) -> Self {
102        let lang = unsafe { cstr_to_string(result.lang) };
103        let text = unsafe { cstr_to_string(result.text) };
104        let count = result.count.try_into().unwrap();
105        let timestamps = if result.timestamps.is_null() {
106            Vec::new()
107        } else {
108            unsafe { std::slice::from_raw_parts(result.timestamps, count).to_vec() }
109        };
110        let mut tokens = Vec::with_capacity(count);
111        let mut next_token = result.tokens;
112
113        for _ in 0..count {
114            let token = unsafe { CStr::from_ptr(next_token) };
115            tokens.push(token.to_string_lossy().into_owned());
116            next_token = next_token
117                .wrapping_byte_offset(token.to_bytes_with_nul().len().try_into().unwrap());
118        }
119
120        Self {
121            lang,
122            text,
123            timestamps,
124            tokens,
125        }
126    }
127}
128
129impl Default for OnnxConfig {
130    fn default() -> Self {
131        Self {
132            provider: get_default_provider(),
133            debug: false,
134            num_threads: 1,
135        }
136    }
137}