Skip to main content

whisper_cpp_plus/
state.rs

1use crate::context::{ContextPtr, WhisperContext};
2use crate::error::{Result, WhisperError};
3use crate::params::FullParams;
4use std::sync::Arc;
5use whisper_cpp_plus_sys as ffi;
6
7pub struct WhisperState {
8    pub(crate) ptr: *mut ffi::whisper_state,
9    pub(crate) _context: Arc<ContextPtr>,
10}
11
12impl Drop for WhisperState {
13    fn drop(&mut self) {
14        unsafe {
15            if !self.ptr.is_null() {
16                ffi::whisper_free_state(self.ptr);
17            }
18        }
19    }
20}
21
22impl WhisperState {
23    pub fn new(context: &WhisperContext) -> Result<Self> {
24        let ptr = unsafe { ffi::whisper_init_state(context.ptr.0) };
25
26        if ptr.is_null() {
27            return Err(WhisperError::OutOfMemory);
28        }
29
30        Ok(Self {
31            ptr,
32            _context: Arc::clone(&context.ptr),
33        })
34    }
35
36    pub fn full(&mut self, params: FullParams, audio: &[f32]) -> Result<()> {
37        if audio.is_empty() {
38            return Err(WhisperError::InvalidAudioFormat);
39        }
40
41        let ret = unsafe {
42            ffi::whisper_full_with_state(
43                self._context.0,
44                self.ptr,
45                params.as_raw(),
46                audio.as_ptr(),
47                audio.len() as i32,
48            )
49        };
50
51        if ret != 0 {
52            return Err(WhisperError::TranscriptionError(format!(
53                "whisper_full returned {}",
54                ret
55            )));
56        }
57
58        Ok(())
59    }
60
61    pub fn full_parallel(
62        &mut self,
63        params: FullParams,
64        audio: &[f32],
65        n_processors: i32,
66    ) -> Result<()> {
67        if audio.is_empty() {
68            return Err(WhisperError::InvalidAudioFormat);
69        }
70
71        if n_processors < 1 {
72            return Err(WhisperError::InvalidParameter(
73                "n_processors must be at least 1".into(),
74            ));
75        }
76
77        let ret = unsafe {
78            ffi::whisper_full_parallel(
79                self._context.0,
80                params.as_raw(),
81                audio.as_ptr(),
82                audio.len() as i32,
83                n_processors,
84            )
85        };
86
87        if ret != 0 {
88            return Err(WhisperError::TranscriptionError(format!(
89                "whisper_full_parallel returned {}",
90                ret
91            )));
92        }
93
94        Ok(())
95    }
96
97    pub fn full_n_segments(&self) -> i32 {
98        unsafe { ffi::whisper_full_n_segments_from_state(self.ptr) }
99    }
100
101    pub fn full_lang_id(&self) -> i32 {
102        unsafe { ffi::whisper_full_lang_id_from_state(self.ptr) }
103    }
104
105    pub fn full_get_segment_text(&self, i_segment: i32) -> Result<String> {
106        let text_ptr =
107            unsafe { ffi::whisper_full_get_segment_text_from_state(self.ptr, i_segment) };
108
109        if text_ptr.is_null() {
110            return Err(WhisperError::InvalidContext);
111        }
112
113        let c_str = unsafe { std::ffi::CStr::from_ptr(text_ptr) };
114        Ok(c_str.to_string_lossy().into_owned())
115    }
116
117    pub fn full_get_segment_timestamps(&self, i_segment: i32) -> (i64, i64) {
118        unsafe {
119            let t0 = ffi::whisper_full_get_segment_t0_from_state(self.ptr, i_segment);
120            let t1 = ffi::whisper_full_get_segment_t1_from_state(self.ptr, i_segment);
121            (t0, t1)
122        }
123    }
124
125    pub fn full_get_segment_speaker_turn_next(&self, i_segment: i32) -> bool {
126        unsafe { ffi::whisper_full_get_segment_speaker_turn_next_from_state(self.ptr, i_segment) }
127    }
128
129    pub fn full_n_tokens(&self, i_segment: i32) -> i32 {
130        unsafe { ffi::whisper_full_n_tokens_from_state(self.ptr, i_segment) }
131    }
132
133    pub fn full_get_token_text(&self, i_segment: i32, i_token: i32) -> Result<String> {
134        let text_ptr = unsafe {
135            ffi::whisper_full_get_token_text_from_state(
136                self._context.0,
137                self.ptr,
138                i_segment,
139                i_token,
140            )
141        };
142
143        if text_ptr.is_null() {
144            return Err(WhisperError::InvalidContext);
145        }
146
147        let c_str = unsafe { std::ffi::CStr::from_ptr(text_ptr) };
148        Ok(c_str.to_string_lossy().into_owned())
149    }
150
151    pub fn full_get_token_id(&self, i_segment: i32, i_token: i32) -> i32 {
152        unsafe { ffi::whisper_full_get_token_id_from_state(self.ptr, i_segment, i_token) }
153    }
154
155    pub fn full_get_token_data(
156        &self,
157        i_segment: i32,
158        i_token: i32,
159    ) -> Option<ffi::whisper_token_data> {
160        let data =
161            unsafe { ffi::whisper_full_get_token_data_from_state(self.ptr, i_segment, i_token) };
162
163        if data.id == -1 {
164            None
165        } else {
166            Some(data)
167        }
168    }
169
170    pub fn full_get_token_prob(&self, i_segment: i32, i_token: i32) -> f32 {
171        unsafe { ffi::whisper_full_get_token_p_from_state(self.ptr, i_segment, i_token) }
172    }
173}
174
175unsafe impl Send for WhisperState {}
176
177#[derive(Debug, Clone)]
178pub struct TranscriptionResult {
179    pub text: String,
180    pub segments: Vec<Segment>,
181}
182
/// One transcribed segment with its time span and text.
#[derive(Clone, Debug)]
pub struct Segment {
    /// Segment start, in milliseconds (per the field name).
    pub start_ms: i64,
    /// Segment end, in milliseconds (per the field name).
    pub end_ms: i64,
    /// Text of this segment.
    pub text: String,
    /// Speaker-turn flag for this segment (presumably mirrors
    /// `WhisperState::full_get_segment_speaker_turn_next` — confirm at the
    /// construction site).
    pub speaker_turn_next: bool,
}
190
191impl Segment {
192    pub fn start_seconds(&self) -> f64 {
193        self.start_ms as f64 / 1000.0
194    }
195
196    pub fn end_seconds(&self) -> f64 {
197        self.end_ms as f64 / 1000.0
198    }
199}