whisper_cpp_plus/
state.rs1use crate::context::{ContextPtr, WhisperContext};
2use crate::error::{Result, WhisperError};
3use crate::params::FullParams;
4use std::sync::Arc;
5use whisper_cpp_plus_sys as ffi;
6
7pub struct WhisperState {
8 pub(crate) ptr: *mut ffi::whisper_state,
9 pub(crate) _context: Arc<ContextPtr>,
10}
11
12impl Drop for WhisperState {
13 fn drop(&mut self) {
14 unsafe {
15 if !self.ptr.is_null() {
16 ffi::whisper_free_state(self.ptr);
17 }
18 }
19 }
20}
21
22impl WhisperState {
23 pub fn new(context: &WhisperContext) -> Result<Self> {
24 let ptr = unsafe { ffi::whisper_init_state(context.ptr.0) };
25
26 if ptr.is_null() {
27 return Err(WhisperError::OutOfMemory);
28 }
29
30 Ok(Self {
31 ptr,
32 _context: Arc::clone(&context.ptr),
33 })
34 }
35
36 pub fn full(&mut self, params: FullParams, audio: &[f32]) -> Result<()> {
37 if audio.is_empty() {
38 return Err(WhisperError::InvalidAudioFormat);
39 }
40
41 let ret = unsafe {
42 ffi::whisper_full_with_state(
43 self._context.0,
44 self.ptr,
45 params.as_raw(),
46 audio.as_ptr(),
47 audio.len() as i32,
48 )
49 };
50
51 if ret != 0 {
52 return Err(WhisperError::TranscriptionError(format!(
53 "whisper_full returned {}",
54 ret
55 )));
56 }
57
58 Ok(())
59 }
60
61 pub fn full_parallel(
62 &mut self,
63 params: FullParams,
64 audio: &[f32],
65 n_processors: i32,
66 ) -> Result<()> {
67 if audio.is_empty() {
68 return Err(WhisperError::InvalidAudioFormat);
69 }
70
71 if n_processors < 1 {
72 return Err(WhisperError::InvalidParameter(
73 "n_processors must be at least 1".into(),
74 ));
75 }
76
77 let ret = unsafe {
78 ffi::whisper_full_parallel(
79 self._context.0,
80 params.as_raw(),
81 audio.as_ptr(),
82 audio.len() as i32,
83 n_processors,
84 )
85 };
86
87 if ret != 0 {
88 return Err(WhisperError::TranscriptionError(format!(
89 "whisper_full_parallel returned {}",
90 ret
91 )));
92 }
93
94 Ok(())
95 }
96
97 pub fn full_n_segments(&self) -> i32 {
98 unsafe { ffi::whisper_full_n_segments_from_state(self.ptr) }
99 }
100
101 pub fn full_lang_id(&self) -> i32 {
102 unsafe { ffi::whisper_full_lang_id_from_state(self.ptr) }
103 }
104
105 pub fn full_get_segment_text(&self, i_segment: i32) -> Result<String> {
106 let text_ptr =
107 unsafe { ffi::whisper_full_get_segment_text_from_state(self.ptr, i_segment) };
108
109 if text_ptr.is_null() {
110 return Err(WhisperError::InvalidContext);
111 }
112
113 let c_str = unsafe { std::ffi::CStr::from_ptr(text_ptr) };
114 Ok(c_str.to_string_lossy().into_owned())
115 }
116
117 pub fn full_get_segment_timestamps(&self, i_segment: i32) -> (i64, i64) {
118 unsafe {
119 let t0 = ffi::whisper_full_get_segment_t0_from_state(self.ptr, i_segment);
120 let t1 = ffi::whisper_full_get_segment_t1_from_state(self.ptr, i_segment);
121 (t0, t1)
122 }
123 }
124
125 pub fn full_get_segment_speaker_turn_next(&self, i_segment: i32) -> bool {
126 unsafe { ffi::whisper_full_get_segment_speaker_turn_next_from_state(self.ptr, i_segment) }
127 }
128
129 pub fn full_n_tokens(&self, i_segment: i32) -> i32 {
130 unsafe { ffi::whisper_full_n_tokens_from_state(self.ptr, i_segment) }
131 }
132
133 pub fn full_get_token_text(&self, i_segment: i32, i_token: i32) -> Result<String> {
134 let text_ptr = unsafe {
135 ffi::whisper_full_get_token_text_from_state(
136 self._context.0,
137 self.ptr,
138 i_segment,
139 i_token,
140 )
141 };
142
143 if text_ptr.is_null() {
144 return Err(WhisperError::InvalidContext);
145 }
146
147 let c_str = unsafe { std::ffi::CStr::from_ptr(text_ptr) };
148 Ok(c_str.to_string_lossy().into_owned())
149 }
150
151 pub fn full_get_token_id(&self, i_segment: i32, i_token: i32) -> i32 {
152 unsafe { ffi::whisper_full_get_token_id_from_state(self.ptr, i_segment, i_token) }
153 }
154
155 pub fn full_get_token_data(
156 &self,
157 i_segment: i32,
158 i_token: i32,
159 ) -> Option<ffi::whisper_token_data> {
160 let data =
161 unsafe { ffi::whisper_full_get_token_data_from_state(self.ptr, i_segment, i_token) };
162
163 if data.id == -1 {
164 None
165 } else {
166 Some(data)
167 }
168 }
169
170 pub fn full_get_token_prob(&self, i_segment: i32, i_token: i32) -> f32 {
171 unsafe { ffi::whisper_full_get_token_p_from_state(self.ptr, i_segment, i_token) }
172 }
173}
174
175unsafe impl Send for WhisperState {}
176
177#[derive(Debug, Clone)]
178pub struct TranscriptionResult {
179 pub text: String,
180 pub segments: Vec<Segment>,
181}
182
/// One timed span of transcribed text.
#[derive(Clone, Debug)]
pub struct Segment {
    /// Start of the span, in milliseconds.
    pub start_ms: i64,
    /// End of the span, in milliseconds.
    pub end_ms: i64,
    /// Text of the span.
    pub text: String,
    /// Speaker-turn flag for this span (presumably filled from
    /// `whisper_full_get_segment_speaker_turn_next`).
    pub speaker_turn_next: bool,
}
190
191impl Segment {
192 pub fn start_seconds(&self) -> f64 {
193 self.start_ms as f64 / 1000.0
194 }
195
196 pub fn end_seconds(&self) -> f64 {
197 self.end_ms as f64 / 1000.0
198 }
199}