// whisper_rs_2/whisper_params.rs
use std::ffi::{c_int, CStr, CString};
use std::marker::PhantomData;

use crate::WhisperContext;
use libffi::high::Closure2;

/// Decoding strategy selected when building [`FullParams`].
///
/// The variant chooses which native sampling-strategy constant is passed to
/// `whisper_full_default_params`, and its fields are then copied into the
/// matching sub-struct of the native parameters by `FullParams::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SamplingStrategy {
    /// Greedy decoding.
    Greedy {
        /// Copied into the native `greedy.n_past` field.
        n_past: c_int,
    },
    /// Beam-search decoding.
    BeamSearch {
        /// Copied into the native `beam_search.n_past` field.
        n_past: c_int,
        /// Copied into the native `beam_search.beam_width` field.
        beam_width: c_int,
        /// Copied into the native `beam_search.n_best` field.
        n_best: c_int,
    },
}

18pub struct FullParams<'a> {
19 pub(crate) fp: whisper_rs_sys2::whisper_full_params,
20 phantom: PhantomData<&'a str>,
21}
22
23impl<'a> FullParams<'a> {
24 pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a> {
26 let mut fp = unsafe {
27 whisper_rs_sys2::whisper_full_default_params(match sampling_strategy {
28 SamplingStrategy::Greedy { .. } => {
29 whisper_rs_sys2::whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY
30 }
31 SamplingStrategy::BeamSearch { .. } => {
32 whisper_rs_sys2::whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH
33 }
34 } as _)
35 };
36
37 match sampling_strategy {
38 SamplingStrategy::Greedy { n_past } => {
39 fp.greedy.n_past = n_past;
40 }
41 SamplingStrategy::BeamSearch {
42 n_past,
43 beam_width,
44 n_best,
45 } => {
46 fp.beam_search.n_past = n_past;
47 fp.beam_search.beam_width = beam_width;
48 fp.beam_search.n_best = n_best;
49 }
50 }
51
52 Self {
53 fp,
54 phantom: PhantomData,
55 }
56 }
57
58 pub fn set_n_threads(&mut self, n_threads: c_int) {
62 self.fp.n_threads = n_threads;
63 }
64
65 pub fn set_offset_ms(&mut self, offset_ms: c_int) {
69 self.fp.offset_ms = offset_ms;
70 }
71
72 pub fn set_translate(&mut self, translate: bool) {
76 self.fp.translate = translate;
77 }
78
79 pub fn set_no_context(&mut self, no_context: bool) {
83 self.fp.no_context = no_context;
84 }
85
86 pub fn set_print_special_tokens(&mut self, print_special_tokens: bool) {
90 self.fp.print_special_tokens = print_special_tokens;
91 }
92
93 pub fn set_print_progress(&mut self, print_progress: bool) {
97 self.fp.print_progress = print_progress;
98 }
99
100 pub fn set_print_realtime(&mut self, print_realtime: bool) {
104 self.fp.print_realtime = print_realtime;
105 }
106
107 pub fn set_print_timestamps(&mut self, print_timestamps: bool) {
111 self.fp.print_timestamps = print_timestamps;
112 }
113
114 pub fn set_language(&mut self, language: &'a str) {
118 self.fp.language = language.as_ptr() as *const _;
119 }
120
121 pub unsafe fn set_new_segment_callback(
133 &mut self,
134 _new_segment_callback: fn(a: WhisperContext) -> (),
135 ) {
136 let closure: &'static _ = Box::leak(Box::new(move |ctx: *mut whisper_rs_sys2::whisper_context, _user_data: *mut std::ffi::c_void| {
137 let num_segments = whisper_rs_sys2::whisper_full_n_segments(ctx);
138 let last_segment_index = num_segments - 1 ;
139
140 let ret1 = whisper_rs_sys2::whisper_full_get_segment_text(ctx, last_segment_index);
141 let c_str1 = CStr::from_ptr(ret1);
142 let segment = c_str1.to_str().unwrap().to_string();
143
144 let start_timestamp: i64 = whisper_rs_sys2::whisper_full_get_segment_t0(ctx, last_segment_index);
145 let end_timestamp: i64 = whisper_rs_sys2::whisper_full_get_segment_t1(ctx, last_segment_index);
146
147 println!("Voila [{} - {}]: {}", start_timestamp, end_timestamp, segment);
148 }));
149 let callback = Closure2::new(closure);
150 let &code = callback.code_ptr();
151 let ptr: unsafe extern "C" fn(ctx: *mut whisper_rs_sys2::whisper_context, user_data: *mut ::std::os::raw::c_void) = std::mem::transmute(code);
152 std::mem::forget(callback);
153
154 self.fp.new_segment_callback = Some(ptr);
155 }
156
157 pub unsafe fn set_new_segment_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) {
164 self.fp.new_segment_callback_user_data = user_data;
165 }
166}
167
168unsafe impl<'a> Send for FullParams<'a> {}
172unsafe impl<'a> Sync for FullParams<'a> {}