Skip to main content

whisper_cpp_plus/
context.rs

1use crate::error::{Result, WhisperError};
2use std::path::Path;
3use std::sync::Arc;
4use whisper_cpp_plus_sys as ffi;
5
6pub struct WhisperContext {
7    pub(crate) ptr: Arc<ContextPtr>,
8}
9
10pub(crate) struct ContextPtr(pub(crate) *mut ffi::whisper_context);
11
12unsafe impl Send for ContextPtr {}
13unsafe impl Sync for ContextPtr {}
14
15impl Drop for ContextPtr {
16    fn drop(&mut self) {
17        unsafe {
18            if !self.0.is_null() {
19                ffi::whisper_free(self.0);
20            }
21        }
22    }
23}
24
25impl WhisperContext {
26    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
27        let path_str = model_path
28            .as_ref()
29            .to_str()
30            .ok_or_else(|| WhisperError::ModelLoadError("Invalid path".into()))?;
31
32        let c_path = std::ffi::CString::new(path_str)?;
33
34        let ptr = unsafe {
35            ffi::whisper_init_from_file_with_params(
36                c_path.as_ptr(),
37                ffi::whisper_context_default_params(),
38            )
39        };
40
41        if ptr.is_null() {
42            return Err(WhisperError::ModelLoadError("Failed to load model".into()));
43        }
44
45        Ok(Self {
46            ptr: Arc::new(ContextPtr(ptr)),
47        })
48    }
49
50    pub fn new_from_buffer(buffer: &[u8]) -> Result<Self> {
51        let ptr = unsafe {
52            ffi::whisper_init_from_buffer_with_params(
53                buffer.as_ptr() as *mut std::os::raw::c_void,
54                buffer.len(),
55                ffi::whisper_context_default_params(),
56            )
57        };
58
59        if ptr.is_null() {
60            return Err(WhisperError::ModelLoadError(
61                "Failed to load model from buffer".into(),
62            ));
63        }
64
65        Ok(Self {
66            ptr: Arc::new(ContextPtr(ptr)),
67        })
68    }
69
70    pub fn is_multilingual(&self) -> bool {
71        unsafe { ffi::whisper_is_multilingual(self.ptr.0) != 0 }
72    }
73
74    pub fn n_vocab(&self) -> i32 {
75        unsafe { ffi::whisper_n_vocab(self.ptr.0) }
76    }
77
78    pub fn n_audio_ctx(&self) -> i32 {
79        unsafe { ffi::whisper_n_audio_ctx(self.ptr.0) }
80    }
81
82    pub fn n_text_ctx(&self) -> i32 {
83        unsafe { ffi::whisper_n_text_ctx(self.ptr.0) }
84    }
85
86    pub fn n_len(&self) -> i32 {
87        unsafe { ffi::whisper_n_len(self.ptr.0) }
88    }
89}
90
91impl Clone for WhisperContext {
92    fn clone(&self) -> Self {
93        Self {
94            ptr: Arc::clone(&self.ptr),
95        }
96    }
97}