wavify/
lib.rs

1use hound::WavReader;
2use std::ffi::{CStr, CString};
3use std::os::raw::{c_char, c_float};
4
5#[macro_use]
6extern crate log;
7use log::Level;
8
/// Handle to a native Speech-to-Text engine.
///
/// Instances are created with [`SttEngine::new`]; the wrapped native pointer is owned by
/// this value and is only ever handed back to the native library.
#[derive(Debug)]
pub struct SttEngine {
    inner: *mut SttEngineInner,
}

/// Opaque stand-in for the native STT engine object.
///
/// Rust code only passes `*mut SttEngineInner` pointers obtained from the native library
/// back to it and never dereferences them, so this layout is never inspected.
#[repr(C)]
struct SttEngineInner {
    // C ABI does not allow zero-sized structs so we add a dummy field
    _dummy: c_char,
}
19
/// A borrowed view of a contiguous `f32` buffer, laid out for passing across the C ABI.
///
/// `FloatArray` does not own the memory it points at: whoever builds one must keep the
/// underlying buffer alive for as long as the native side may read through `data`.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct FloatArray {
    /// Pointer to the first element of the array.
    pub data: *const f32,
    /// Number of `f32` elements (not bytes) starting at `data`.
    pub len: usize,
}
28
/// Errors that can occur while using the Wavify engines.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WavifyError {
    /// Error that occurs during initialization of an engine (for example, when an
    /// argument cannot be converted to a C string).
    InitError,
    /// Error that occurs during inference.
    InferenceError,
}

impl std::fmt::Display for WavifyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::InitError => "Failed to initialize",
            Self::InferenceError => "Failed to run inference",
        };
        f.write_str(message)
    }
}
46
/// Verbosity levels accepted by [`set_log_level`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogLevel {
    /// Maps to the native level name `"trace"`.
    Trace,
    /// Maps to the native level name `"debug"`.
    Debug,
    /// Maps to the native level name `"info"`.
    Info,
    /// Maps to the native level name `"warn"`.
    Warn,
    /// Maps to the native level name `"error"`.
    Error,
}

impl LogLevel {
    /// Returns the lowercase level name handed to the native logger (e.g. `"info"`).
    fn as_str(&self) -> &'static str {
        match self {
            LogLevel::Trace => "trace",
            LogLevel::Debug => "debug",
            LogLevel::Info => "info",
            LogLevel::Warn => "warn",
            LogLevel::Error => "error",
        }
    }
}
66
67impl std::error::Error for WavifyError {}
68
69extern "C" {
70    fn create_stt_engine(model_path: *const c_char, api_key: *const c_char) -> *mut SttEngineInner;
71    fn destroy_stt_engine(model: *mut SttEngineInner);
72    fn stt(model: *mut SttEngineInner, data: FloatArray) -> *mut c_char;
73    fn free_result(result: *mut c_char);
74    fn create_wake_word_engine(
75        model_path: *const c_char,
76        api_key: *const c_char,
77    ) -> *mut WakeWordEngineInner;
78    fn destroy_wake_word_engine(model: *mut WakeWordEngineInner);
79    fn detect_wake_word(model: *mut WakeWordEngineInner, data: FloatArray) -> c_float;
80    fn setup_logger(level: *const c_char);
81}
82
83impl SttEngine {
84    /// Creates a new instance of `SttEngine`.
85    ///
86    /// # Arguments
87    ///
88    /// * `model_path` - A string slice that holds the path to the model.
89    /// * `api_key` - A string slice that holds the API key.
90    ///
91    /// # Returns
92    ///
93    /// A result that, if successful, contains a new instance of `SttEngine`. Otherwise, it contains a `WavifyError`.
94    ///
95    /// # Examples
96    ///
97    /// ```
98    /// let engine = SttEngine::new("path/to/model", "api_key");
99    /// ```
100    pub fn new(model_path: &str, api_key: &str) -> Result<SttEngine, WavifyError> {
101        let maybe_model_path_c = CString::new(model_path);
102        let maybe_api_key_c = CString::new(api_key);
103        match (maybe_model_path_c, maybe_api_key_c) {
104            (Ok(model_path_c), Ok(api_key_c)) => unsafe {
105                let inner = create_stt_engine(model_path_c.as_ptr(), api_key_c.as_ptr());
106                Ok(SttEngine { inner })
107            },
108            (_, _) => Err(WavifyError::InitError),
109        }
110    }
111
112    /// Destroys the `SttEngine` instance, freeing any resources.
113    pub fn destroy(self) {
114        unsafe { destroy_stt_engine(self.inner) }
115    }
116
117    /// Performs speech-to-text on the given audio data.
118    ///
119    /// # Arguments
120    ///
121    /// * `data` - A slice of floating-point numbers representing the audio data.
122    ///
123    /// # Returns
124    ///
125    /// A result that, if successful, contains a `String` with the recognized text. Otherwise, it contains a `WavifyError`.
126    ///
127    /// # Examples
128    ///
129    /// ```
130    /// let text = engine.stt(&audio_data).unwrap();
131    /// ```
132    pub fn stt(&self, data: &[f32]) -> Result<String, WavifyError> {
133        let float_array = FloatArray {
134            data: data.as_ptr(),
135            len: data.len(),
136        };
137
138        unsafe {
139            let result_ptr = stt(self.inner, float_array);
140            if result_ptr.is_null() {
141                return Err(WavifyError::InferenceError);
142            }
143
144            let result = CStr::from_ptr(result_ptr).to_string_lossy().into_owned();
145            free_result(result_ptr);
146            Ok(result)
147        }
148    }
149
150    /// Performs speech-to-text on an audio file.
151    ///
152    /// # Arguments
153    ///
154    /// * `filename` - A string slice that holds the path to the audio file.
155    ///
156    /// # Returns
157    ///
158    /// A result that, if successful, contains a `String` with the recognized text. Otherwise, it contains a `WavifyError`.
159    ///
160    /// # Examples
161    ///
162    /// ```
163    /// let text = engine.stt_from_file("path/to/audio.wav").unwrap();
164    /// ```
165    pub fn stt_from_file(&self, filename: &str) -> Result<String, WavifyError> {
166        let mut reader = WavReader::open(filename).unwrap();
167
168        let mut float_data = Vec::new();
169
170        for sample in reader.samples::<i16>() {
171            let float_sample = sample.unwrap() as f64 / i16::MAX as f64;
172            float_data.push(float_sample);
173        }
174
175        let data: Vec<f32> = float_data.iter().map(|v| *v as f32).collect();
176        log!(
177            Level::Debug,
178            "Audio codec: {:?} with len: {}",
179            &data[..10],
180            data.len()
181        );
182
183        self.stt(&data)
184    }
185}
186
/// Handle to a native wake-word detection engine.
///
/// Instances are created with [`WakeWordEngine::new`]; the wrapped native pointer is owned
/// by this value and is only ever handed back to the native library.
#[derive(Debug)]
pub struct WakeWordEngine {
    inner: *mut WakeWordEngineInner,
}

/// Opaque stand-in for the native wake-word engine object.
///
/// Rust code only passes `*mut WakeWordEngineInner` pointers obtained from the native
/// library back to it and never dereferences them, so this layout is never inspected.
#[repr(C)]
struct WakeWordEngineInner {
    // C ABI does not allow zero-sized structs so we add a dummy field
    _dummy: c_char,
}
197
198impl WakeWordEngine {
199    /// Creates a new instance of `WakeWordEngine`.
200    ///
201    /// # Arguments
202    ///
203    /// * `model_path` - A string slice that holds the path to the model.
204    /// * `api_key` - A string slice that holds the API key.
205    ///
206    /// # Returns
207    ///
208    /// A result that, if successful, contains a new instance of `WakeWordEngine`. Otherwise, it contains a `WavifyError`.
209    ///
210    /// # Examples
211    ///
212    /// ```
213    /// let engine = WakeWordEngine::new("path/to/model", "api_key");
214    /// ```
215    pub fn new(model_path: &str, api_key: &str) -> Result<WakeWordEngine, WavifyError> {
216        let maybe_model_path_c = CString::new(model_path);
217        let maybe_api_key_c = CString::new(api_key);
218        match (maybe_model_path_c, maybe_api_key_c) {
219            (Ok(model_path_c), Ok(api_key_c)) => unsafe {
220                let inner = create_wake_word_engine(model_path_c.as_ptr(), api_key_c.as_ptr());
221                Ok(WakeWordEngine { inner })
222            },
223            (_, _) => Err(WavifyError::InitError),
224        }
225    }
226
227    /// Destroys the `WakeWordEngine` instance, freeing any resources.
228    pub fn destroy(self) {
229        unsafe { destroy_wake_word_engine(self.inner) }
230    }
231
232    /// Performs the wake word detection  on the given audio data.
233    ///
234    /// # Arguments
235    ///
236    /// * `data` - A slice of floating-point numbers representing the audio data. The length should be equal to 2 seconds sampled at 16kHz.
237    ///
238    /// # Returns
239    ///
240    /// A result that, if successful, contains the probability of a detected wake word. Otherwise, it contains a `WavifyError`.
241    ///
242    /// # Examples
243    ///
244    /// ```
245    /// let probability = engine.detect(&audio_data).unwrap();
246    /// ```
247    pub fn detect(&self, data: &[f32]) -> Result<f32, WavifyError> {
248        let float_array = FloatArray {
249            data: data.as_ptr(),
250            len: data.len(),
251        };
252
253        unsafe {
254            let result = detect_wake_word(self.inner, float_array);
255            if result.is_nan() {
256                return Err(WavifyError::InferenceError);
257            }
258            Ok(result)
259        }
260    }
261}
262
263/// Sets up the logger using the underlying library.
264///
265/// Available values are: `LogLevel::Trace`, `LogLevel::Debug`, `LogLevel::Info`, `LogLevel::Warn`, `LogLevel::Error`.
266/// If `None` is provided, the log level is set to `LogLevel::Info`.
267///
268/// # Arguments
269///
270/// * `level` - The logging level. This can be `Some(LogLevel)` or `None`.
271///
272/// # Examples
273///
274/// ```
275/// set_log_level(Some(LogLevel::Debug)); // Sets log level to Debug
276/// set_log_level(None); // Sets log level to default (Info)
277/// ```
278///
279/// # Panics
280///
281/// This function does not panic.
282///
283/// # Errors
284///
285/// This function prints an error message if it fails to create a C-compatible string for the log level.
286pub fn set_log_level(level: Option<LogLevel>) {
287    let level_str = level.as_ref().unwrap_or(&LogLevel::Info).as_str();
288    let c_level = match CString::new(level_str) {
289        Ok(lev) => lev,
290        Err(e) => {
291            eprintln!("Failed to create CString for logging: {:?}", e);
292            return;
293        }
294    };
295    unsafe {
296        setup_logger(c_level.as_ptr());
297    }
298}