wavify/lib.rs
1use hound::WavReader;
2use std::ffi::{CStr, CString};
3use std::os::raw::{c_char, c_float};
4
5#[macro_use]
6extern crate log;
7use log::Level;
8
/// Represents the Speech-to-Text Engine.
///
/// Wraps the raw handle returned by the native `create_stt_engine` function.
/// The handle is released by calling [`SttEngine::destroy`].
pub struct SttEngine {
    /// Opaque pointer to the native engine; handed back to the native library on every call.
    inner: *mut SttEngineInner,
}
13
/// Opaque stand-in for the native STT engine type.
///
/// In this file it is only ever used behind a raw pointer; its single field is never read.
#[repr(C)]
struct SttEngineInner {
    // C ABI does not allow zero-sized structs so we add a dummy field
    _dummy: c_char,
}
19
/// A struct representing an array of floating-point numbers.
///
/// It is `#[repr(C)]` and passed by value to the native `stt` and `detect_wake_word`
/// functions. It does not own the memory it points to: the wrappers in this file fill it
/// from a borrowed `&[f32]`, so the slice must stay alive for the duration of the call.
#[repr(C)]
pub struct FloatArray {
    /// Pointer to the array data.
    pub data: *const f32,
    /// Length of the array, counted in `f32` elements (not bytes).
    pub len: usize,
}
28
/// Errors reported by the engine wrappers in this crate.
#[derive(Debug)]
pub enum WavifyError {
    /// An engine could not be initialized.
    InitError,
    /// Running inference failed.
    InferenceError,
}

impl std::fmt::Display for WavifyError {
    /// Writes a short, fixed, human-readable message for each variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            WavifyError::InitError => "Failed to initialize",
            WavifyError::InferenceError => "Failed to run inference",
        };
        f.write_str(message)
    }
}
46
/// Logging levels accepted by [`set_log_level`].
#[derive(Debug)]
pub enum LogLevel {
    /// Passed to the native logger as `"trace"`.
    Trace,
    /// Passed to the native logger as `"debug"`.
    Debug,
    /// Passed to the native logger as `"info"`.
    Info,
    /// Passed to the native logger as `"warn"`.
    Warn,
    /// Passed to the native logger as `"error"`.
    Error,
}

impl LogLevel {
    /// Returns the lowercase name of the level, as handed to the native logger.
    fn as_str(&self) -> &str {
        match *self {
            Self::Trace => "trace",
            Self::Debug => "debug",
            Self::Info => "info",
            Self::Warn => "warn",
            Self::Error => "error",
        }
    }
}
66
// Marks `WavifyError` as a standard error type (usable as `dyn std::error::Error`);
// the trait's default methods are sufficient, so the impl is empty.
impl std::error::Error for WavifyError {}
68
// Functions provided by the native library. All of them are `unsafe` to call; the safe
// wrappers further down in this file are the intended entry points.
extern "C" {
    /// Creates an STT engine from NUL-terminated model-path and API-key strings.
    fn create_stt_engine(model_path: *const c_char, api_key: *const c_char) -> *mut SttEngineInner;
    /// Releases an engine previously returned by `create_stt_engine`.
    fn destroy_stt_engine(model: *mut SttEngineInner);
    /// Runs speech-to-text on `data` and returns a C string. The Rust wrapper treats a
    /// null return as failure and hands a non-null result back to `free_result`.
    fn stt(model: *mut SttEngineInner, data: FloatArray) -> *mut c_char;
    /// Frees a string returned by `stt`.
    fn free_result(result: *mut c_char);
    /// Creates a wake-word engine from NUL-terminated model-path and API-key strings.
    fn create_wake_word_engine(
        model_path: *const c_char,
        api_key: *const c_char,
    ) -> *mut WakeWordEngineInner;
    /// Releases an engine previously returned by `create_wake_word_engine`.
    fn destroy_wake_word_engine(model: *mut WakeWordEngineInner);
    /// Runs wake-word detection on `data`. The Rust wrapper treats a NaN return as failure.
    fn detect_wake_word(model: *mut WakeWordEngineInner, data: FloatArray) -> c_float;
    /// Configures the native logger; `level` is a NUL-terminated level name such as `"info"`.
    fn setup_logger(level: *const c_char);
}
82
83impl SttEngine {
84 /// Creates a new instance of `SttEngine`.
85 ///
86 /// # Arguments
87 ///
88 /// * `model_path` - A string slice that holds the path to the model.
89 /// * `api_key` - A string slice that holds the API key.
90 ///
91 /// # Returns
92 ///
93 /// A result that, if successful, contains a new instance of `SttEngine`. Otherwise, it contains a `WavifyError`.
94 ///
95 /// # Examples
96 ///
97 /// ```
98 /// let engine = SttEngine::new("path/to/model", "api_key");
99 /// ```
100 pub fn new(model_path: &str, api_key: &str) -> Result<SttEngine, WavifyError> {
101 let maybe_model_path_c = CString::new(model_path);
102 let maybe_api_key_c = CString::new(api_key);
103 match (maybe_model_path_c, maybe_api_key_c) {
104 (Ok(model_path_c), Ok(api_key_c)) => unsafe {
105 let inner = create_stt_engine(model_path_c.as_ptr(), api_key_c.as_ptr());
106 Ok(SttEngine { inner })
107 },
108 (_, _) => Err(WavifyError::InitError),
109 }
110 }
111
112 /// Destroys the `SttEngine` instance, freeing any resources.
113 pub fn destroy(self) {
114 unsafe { destroy_stt_engine(self.inner) }
115 }
116
117 /// Performs speech-to-text on the given audio data.
118 ///
119 /// # Arguments
120 ///
121 /// * `data` - A slice of floating-point numbers representing the audio data.
122 ///
123 /// # Returns
124 ///
125 /// A result that, if successful, contains a `String` with the recognized text. Otherwise, it contains a `WavifyError`.
126 ///
127 /// # Examples
128 ///
129 /// ```
130 /// let text = engine.stt(&audio_data).unwrap();
131 /// ```
132 pub fn stt(&self, data: &[f32]) -> Result<String, WavifyError> {
133 let float_array = FloatArray {
134 data: data.as_ptr(),
135 len: data.len(),
136 };
137
138 unsafe {
139 let result_ptr = stt(self.inner, float_array);
140 if result_ptr.is_null() {
141 return Err(WavifyError::InferenceError);
142 }
143
144 let result = CStr::from_ptr(result_ptr).to_string_lossy().into_owned();
145 free_result(result_ptr);
146 Ok(result)
147 }
148 }
149
150 /// Performs speech-to-text on an audio file.
151 ///
152 /// # Arguments
153 ///
154 /// * `filename` - A string slice that holds the path to the audio file.
155 ///
156 /// # Returns
157 ///
158 /// A result that, if successful, contains a `String` with the recognized text. Otherwise, it contains a `WavifyError`.
159 ///
160 /// # Examples
161 ///
162 /// ```
163 /// let text = engine.stt_from_file("path/to/audio.wav").unwrap();
164 /// ```
165 pub fn stt_from_file(&self, filename: &str) -> Result<String, WavifyError> {
166 let mut reader = WavReader::open(filename).unwrap();
167
168 let mut float_data = Vec::new();
169
170 for sample in reader.samples::<i16>() {
171 let float_sample = sample.unwrap() as f64 / i16::MAX as f64;
172 float_data.push(float_sample);
173 }
174
175 let data: Vec<f32> = float_data.iter().map(|v| *v as f32).collect();
176 log!(
177 Level::Debug,
178 "Audio codec: {:?} with len: {}",
179 &data[..10],
180 data.len()
181 );
182
183 self.stt(&data)
184 }
185}
186
/// Represents the wake word detection engine.
///
/// Wraps the raw handle returned by the native `create_wake_word_engine` function.
/// The handle is released by calling [`WakeWordEngine::destroy`].
pub struct WakeWordEngine {
    /// Opaque pointer to the native engine; handed back to the native library on every call.
    inner: *mut WakeWordEngineInner,
}
191
/// Opaque stand-in for the native wake-word engine type.
///
/// In this file it is only ever used behind a raw pointer; its single field is never read.
#[repr(C)]
struct WakeWordEngineInner {
    // C ABI does not allow zero-sized structs so we add a dummy field
    _dummy: c_char,
}
197
198impl WakeWordEngine {
199 /// Creates a new instance of `WakeWordEngine`.
200 ///
201 /// # Arguments
202 ///
203 /// * `model_path` - A string slice that holds the path to the model.
204 /// * `api_key` - A string slice that holds the API key.
205 ///
206 /// # Returns
207 ///
208 /// A result that, if successful, contains a new instance of `WakeWordEngine`. Otherwise, it contains a `WavifyError`.
209 ///
210 /// # Examples
211 ///
212 /// ```
213 /// let engine = WakeWordEngine::new("path/to/model", "api_key");
214 /// ```
215 pub fn new(model_path: &str, api_key: &str) -> Result<WakeWordEngine, WavifyError> {
216 let maybe_model_path_c = CString::new(model_path);
217 let maybe_api_key_c = CString::new(api_key);
218 match (maybe_model_path_c, maybe_api_key_c) {
219 (Ok(model_path_c), Ok(api_key_c)) => unsafe {
220 let inner = create_wake_word_engine(model_path_c.as_ptr(), api_key_c.as_ptr());
221 Ok(WakeWordEngine { inner })
222 },
223 (_, _) => Err(WavifyError::InitError),
224 }
225 }
226
227 /// Destroys the `WakeWordEngine` instance, freeing any resources.
228 pub fn destroy(self) {
229 unsafe { destroy_wake_word_engine(self.inner) }
230 }
231
232 /// Performs the wake word detection on the given audio data.
233 ///
234 /// # Arguments
235 ///
236 /// * `data` - A slice of floating-point numbers representing the audio data. The length should be equal to 2 seconds sampled at 16kHz.
237 ///
238 /// # Returns
239 ///
240 /// A result that, if successful, contains the probability of a detected wake word. Otherwise, it contains a `WavifyError`.
241 ///
242 /// # Examples
243 ///
244 /// ```
245 /// let probability = engine.detect(&audio_data).unwrap();
246 /// ```
247 pub fn detect(&self, data: &[f32]) -> Result<f32, WavifyError> {
248 let float_array = FloatArray {
249 data: data.as_ptr(),
250 len: data.len(),
251 };
252
253 unsafe {
254 let result = detect_wake_word(self.inner, float_array);
255 if result.is_nan() {
256 return Err(WavifyError::InferenceError);
257 }
258 Ok(result)
259 }
260 }
261}
262
263/// Sets up the logger using the underlying library.
264///
265/// Available values are: `LogLevel::Trace`, `LogLevel::Debug`, `LogLevel::Info`, `LogLevel::Warn`, `LogLevel::Error`.
266/// If `None` is provided, the log level is set to `LogLevel::Info`.
267///
268/// # Arguments
269///
270/// * `level` - The logging level. This can be `Some(LogLevel)` or `None`.
271///
272/// # Examples
273///
274/// ```
275/// set_log_level(Some(LogLevel::Debug)); // Sets log level to Debug
276/// set_log_level(None); // Sets log level to default (Info)
277/// ```
278///
279/// # Panics
280///
281/// This function does not panic.
282///
283/// # Errors
284///
285/// This function prints an error message if it fails to create a C-compatible string for the log level.
286pub fn set_log_level(level: Option<LogLevel>) {
287 let level_str = level.as_ref().unwrap_or(&LogLevel::Info).as_str();
288 let c_level = match CString::new(level_str) {
289 Ok(lev) => lev,
290 Err(e) => {
291 eprintln!("Failed to create CString for logging: {:?}", e);
292 return;
293 }
294 };
295 unsafe {
296 setup_logger(c_level.as_ptr());
297 }
298}