// whisper_cpp_plus/context.rs

use crate::error::{Result, WhisperError};
use std::ffi::{c_void, CString};
use std::path::Path;
use std::sync::Arc;
use whisper_cpp_plus_sys as ffi;

6pub struct WhisperContext {
7    pub(crate) ptr: Arc<ContextPtr>,
8}
9
10pub(crate) struct ContextPtr(pub(crate) *mut ffi::whisper_context);
11
12unsafe impl Send for ContextPtr {}
13unsafe impl Sync for ContextPtr {}
14
15impl Drop for ContextPtr {
16    fn drop(&mut self) {
17        unsafe {
18            if !self.0.is_null() {
19                ffi::whisper_free(self.0);
20            }
21        }
22    }
23}
24
25impl WhisperContext {
26    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
27        let path_str = model_path
28            .as_ref()
29            .to_str()
30            .ok_or_else(|| WhisperError::ModelLoadError("Invalid path".into()))?;
31
32        let c_path = std::ffi::CString::new(path_str)?;
33
34        let ptr = unsafe {
35            ffi::whisper_init_from_file_with_params(
36                c_path.as_ptr(),
37                ffi::whisper_context_default_params(),
38            )
39        };
40
41        if ptr.is_null() {
42            return Err(WhisperError::ModelLoadError(
43                "Failed to load model".into(),
44            ));
45        }
46
47        Ok(Self {
48            ptr: Arc::new(ContextPtr(ptr)),
49        })
50    }
51
52    pub fn new_from_buffer(buffer: &[u8]) -> Result<Self> {
53        let ptr = unsafe {
54            ffi::whisper_init_from_buffer_with_params(
55                buffer.as_ptr() as *mut std::os::raw::c_void,
56                buffer.len(),
57                ffi::whisper_context_default_params(),
58            )
59        };
60
61        if ptr.is_null() {
62            return Err(WhisperError::ModelLoadError(
63                "Failed to load model from buffer".into(),
64            ));
65        }
66
67        Ok(Self {
68            ptr: Arc::new(ContextPtr(ptr)),
69        })
70    }
71
72    pub fn is_multilingual(&self) -> bool {
73        unsafe { ffi::whisper_is_multilingual(self.ptr.0) != 0 }
74    }
75
76    pub fn n_vocab(&self) -> i32 {
77        unsafe { ffi::whisper_n_vocab(self.ptr.0) }
78    }
79
80    pub fn n_audio_ctx(&self) -> i32 {
81        unsafe { ffi::whisper_n_audio_ctx(self.ptr.0) }
82    }
83
84    pub fn n_text_ctx(&self) -> i32 {
85        unsafe { ffi::whisper_n_text_ctx(self.ptr.0) }
86    }
87
88    pub fn n_len(&self) -> i32 {
89        unsafe { ffi::whisper_n_len(self.ptr.0) }
90    }
91
92}
93
94impl Clone for WhisperContext {
95    fn clone(&self) -> Self {
96        Self {
97            ptr: Arc::clone(&self.ptr),
98        }
99    }
100}