whisper_cpp_plus/
params.rs1use std::ffi::CString;
2use whisper_cpp_plus_sys as ffi;
3
/// Decoding strategy selected when constructing a `FullParams`.
///
/// Each variant carries the single tuning knob that `FullParams::new` copies
/// into the matching sub-struct of the raw whisper parameters.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SamplingStrategy {
    /// Greedy sampling; `best_of` is written to `greedy.best_of`.
    Greedy { best_of: i32 },
    /// Beam search; `beam_size` is written to `beam_search.beam_size`.
    BeamSearch { beam_size: i32 },
}
9
10#[derive(Clone)]
11pub struct FullParams {
12 pub(crate) inner: ffi::whisper_full_params,
13 language: Option<CString>,
14 initial_prompt: Option<CString>,
15}
16
17unsafe impl Send for FullParams {}
19unsafe impl Sync for FullParams {}
20
21impl FullParams {
22 pub fn new(strategy: SamplingStrategy) -> Self {
23 let inner = unsafe {
24 match strategy {
25 SamplingStrategy::Greedy { best_of } => {
26 let mut params = ffi::whisper_full_default_params(
27 ffi::whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY,
28 );
29 params.greedy.best_of = best_of;
30 params
31 }
32 SamplingStrategy::BeamSearch { beam_size } => {
33 let mut params = ffi::whisper_full_default_params(
34 ffi::whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH,
35 );
36 params.beam_search.beam_size = beam_size;
37 params
38 }
39 }
40 };
41
42 let mut params = Self {
43 inner,
44 language: None,
45 initial_prompt: None,
46 };
47
48 params.inner.n_threads = (num_cpus::get() / 2).max(1) as i32;
49 params.inner.suppress_blank = true;
50 params.inner.suppress_nst = true;
51 params.inner.temperature = 0.0;
52 params.inner.max_initial_ts = 1.0;
53 params.inner.length_penalty = -1.0;
54
55 params
56 }
57
58 pub(crate) fn as_raw(&self) -> ffi::whisper_full_params {
59 let mut params = self.inner;
60
61 if let Some(ref lang) = self.language {
62 params.language = lang.as_ptr();
63 }
64
65 if let Some(ref prompt) = self.initial_prompt {
66 params.initial_prompt = prompt.as_ptr();
67 }
68
69 params
70 }
71
72 pub fn language(mut self, lang: &str) -> Self {
73 self.language = CString::new(lang).ok();
74 if let Some(ref lang_cstr) = self.language {
75 self.inner.language = lang_cstr.as_ptr();
76 }
77 self
78 }
79
80 pub fn translate(mut self, translate: bool) -> Self {
81 self.inner.translate = translate;
82 self
83 }
84
85 pub fn no_context(mut self, no_context: bool) -> Self {
86 self.inner.no_context = no_context;
87 self
88 }
89
90 pub fn no_timestamps(mut self, no_timestamps: bool) -> Self {
91 self.inner.no_timestamps = no_timestamps;
92 self
93 }
94
95 pub fn single_segment(mut self, single_segment: bool) -> Self {
96 self.inner.single_segment = single_segment;
97 self
98 }
99
100 pub fn print_special(mut self, print_special: bool) -> Self {
101 self.inner.print_special = print_special;
102 self
103 }
104
105 pub fn print_progress(mut self, print_progress: bool) -> Self {
106 self.inner.print_progress = print_progress;
107 self
108 }
109
110 pub fn print_realtime(mut self, print_realtime: bool) -> Self {
111 self.inner.print_realtime = print_realtime;
112 self
113 }
114
115 pub fn print_timestamps(mut self, print_timestamps: bool) -> Self {
116 self.inner.print_timestamps = print_timestamps;
117 self
118 }
119
120 pub fn token_timestamps(mut self, token_timestamps: bool) -> Self {
121 self.inner.token_timestamps = token_timestamps;
122 self
123 }
124
125 pub fn thold_pt(mut self, thold_pt: f32) -> Self {
126 self.inner.thold_pt = thold_pt;
127 self
128 }
129
130 pub fn thold_ptsum(mut self, thold_ptsum: f32) -> Self {
131 self.inner.thold_ptsum = thold_ptsum;
132 self
133 }
134
135 pub fn max_len(mut self, max_len: i32) -> Self {
136 self.inner.max_len = max_len;
137 self
138 }
139
140 pub fn split_on_word(mut self, split_on_word: bool) -> Self {
141 self.inner.split_on_word = split_on_word;
142 self
143 }
144
145 pub fn max_tokens(mut self, max_tokens: i32) -> Self {
146 self.inner.max_tokens = max_tokens;
147 self
148 }
149
150
151 pub fn debug_mode(mut self, debug_mode: bool) -> Self {
152 self.inner.debug_mode = debug_mode;
153 self
154 }
155
156 pub fn audio_ctx(mut self, audio_ctx: i32) -> Self {
157 self.inner.audio_ctx = audio_ctx;
158 self
159 }
160
161 pub fn tdrz_enable(mut self, tdrz_enable: bool) -> Self {
162 self.inner.tdrz_enable = tdrz_enable;
163 self
164 }
165
166 pub fn suppress_regex(mut self, suppress_regex: Option<&str>) -> Self {
167 if let Some(regex) = suppress_regex {
168 if let Ok(c_regex) = CString::new(regex) {
169 self.inner.suppress_regex = c_regex.as_ptr();
170 }
171 } else {
172 self.inner.suppress_regex = std::ptr::null();
173 }
174 self
175 }
176
177 pub fn initial_prompt(mut self, prompt: &str) -> Self {
178 self.initial_prompt = CString::new(prompt).ok();
179 if let Some(ref prompt_cstr) = self.initial_prompt {
180 self.inner.initial_prompt = prompt_cstr.as_ptr();
181 }
182 self
183 }
184
185 pub fn prompt_tokens(mut self, tokens: &[i32]) -> Self {
186 self.inner.prompt_tokens = tokens.as_ptr();
187 self.inner.prompt_n_tokens = tokens.len() as i32;
188 self
189 }
190
191 pub fn temperature(mut self, temperature: f32) -> Self {
192 self.inner.temperature = temperature;
193 self
194 }
195
196 pub fn temperature_inc(mut self, temperature_inc: f32) -> Self {
197 self.inner.temperature_inc = temperature_inc;
198 self
199 }
200
201 pub fn entropy_thold(mut self, entropy_thold: f32) -> Self {
202 self.inner.entropy_thold = entropy_thold;
203 self
204 }
205
206 pub fn logprob_thold(mut self, logprob_thold: f32) -> Self {
207 self.inner.logprob_thold = logprob_thold;
208 self
209 }
210
211 pub fn n_threads(mut self, n_threads: i32) -> Self {
212 self.inner.n_threads = n_threads;
213 self
214 }
215
216 pub fn offset_ms(mut self, offset_ms: i32) -> Self {
217 self.inner.offset_ms = offset_ms;
218 self
219 }
220
221 pub fn duration_ms(mut self, duration_ms: i32) -> Self {
222 self.inner.duration_ms = duration_ms;
223 self
224 }
225}
226
227impl Default for FullParams {
228 fn default() -> Self {
229 Self::new(SamplingStrategy::Greedy { best_of: 1 })
230 }
231}
232
233#[derive(Clone)]
234pub struct TranscriptionParams {
235 params: FullParams,
236}
237
238impl TranscriptionParams {
239 pub fn builder() -> TranscriptionParamsBuilder {
240 TranscriptionParamsBuilder::new()
241 }
242
243 pub(crate) fn into_full_params(self) -> FullParams {
244 self.params
245 }
246}
247
248#[derive(Clone)]
249pub struct TranscriptionParamsBuilder {
250 params: FullParams,
251}
252
253impl TranscriptionParamsBuilder {
254 pub fn new() -> Self {
255 Self {
256 params: FullParams::default(),
257 }
258 }
259
260 pub fn language(mut self, lang: &str) -> Self {
261 self.params = self.params.language(lang);
262 self
263 }
264
265 pub fn translate(mut self, translate: bool) -> Self {
266 self.params = self.params.translate(translate);
267 self
268 }
269
270 pub fn temperature(mut self, temperature: f32) -> Self {
271 self.params = self.params.temperature(temperature);
272 self
273 }
274
275 pub fn enable_timestamps(mut self) -> Self {
276 self.params = self.params.no_timestamps(false);
277 self
278 }
279
280 pub fn disable_timestamps(mut self) -> Self {
281 self.params = self.params.no_timestamps(true);
282 self
283 }
284
285 pub fn single_segment(mut self, single: bool) -> Self {
286 self.params = self.params.single_segment(single);
287 self
288 }
289
290 pub fn max_tokens(mut self, max_tokens: i32) -> Self {
291 self.params = self.params.max_tokens(max_tokens);
292 self
293 }
294
295 pub fn initial_prompt(mut self, prompt: &str) -> Self {
296 self.params = self.params.initial_prompt(prompt);
297 self
298 }
299
300 pub fn n_threads(mut self, n_threads: i32) -> Self {
301 self.params = self.params.n_threads(n_threads);
302 self
303 }
304
305 pub fn build(self) -> TranscriptionParams {
306 TranscriptionParams {
307 params: self.params,
308 }
309 }
310}
311
312impl Default for TranscriptionParamsBuilder {
313 fn default() -> Self {
314 Self::new()
315 }
316}