whisper_rs_2/
whisper_params.rs

1use std::ffi::{c_int, CStr};
2use std::marker::PhantomData;
3use crate::WhisperContext;
4use libffi::high::Closure2;
5
/// Decoding strategy selected when building a [`FullParams`].
///
/// Every field is a plain C integer that `FullParams::new` copies verbatim into the matching
/// sub-struct of the underlying `whisper_full_params`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SamplingStrategy {
    /// Greedy decoding.
    Greedy {
        /// Copied into `greedy.n_past` of the underlying parameters.
        n_past: c_int,
    },
    /// not implemented yet, results of using this unknown
    BeamSearch {
        /// Copied into `beam_search.n_past` of the underlying parameters.
        n_past: c_int,
        /// Copied into `beam_search.beam_width` of the underlying parameters.
        beam_width: c_int,
        /// Copied into `beam_search.n_best` of the underlying parameters.
        n_best: c_int,
    },
}
17
18pub struct FullParams<'a> {
19    pub(crate) fp: whisper_rs_sys2::whisper_full_params,
20    phantom: PhantomData<&'a str>,
21}
22
23impl<'a> FullParams<'a> {
24    /// Create a new set of parameters for the decoder.
25    pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a> {
26        let mut fp = unsafe {
27            whisper_rs_sys2::whisper_full_default_params(match sampling_strategy {
28                SamplingStrategy::Greedy { .. } => {
29                    whisper_rs_sys2::whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY
30                }
31                SamplingStrategy::BeamSearch { .. } => {
32                    whisper_rs_sys2::whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH
33                }
34            } as _)
35        };
36
37        match sampling_strategy {
38            SamplingStrategy::Greedy { n_past } => {
39                fp.greedy.n_past = n_past;
40            }
41            SamplingStrategy::BeamSearch {
42                n_past,
43                beam_width,
44                n_best,
45            } => {
46                fp.beam_search.n_past = n_past;
47                fp.beam_search.beam_width = beam_width;
48                fp.beam_search.n_best = n_best;
49            }
50        }
51
52        Self {
53            fp,
54            phantom: PhantomData,
55        }
56    }
57
58    /// Set the number of threads to use for decoding.
59    ///
60    /// Defaults to min(4, std::thread::hardware_concurrency()).
61    pub fn set_n_threads(&mut self, n_threads: c_int) {
62        self.fp.n_threads = n_threads;
63    }
64
65    /// Set the offset in milliseconds to use for decoding.
66    ///
67    /// Defaults to 0.
68    pub fn set_offset_ms(&mut self, offset_ms: c_int) {
69        self.fp.offset_ms = offset_ms;
70    }
71
72    /// Set whether to translate the output to the language specified by `language`.
73    ///
74    /// Defaults to false.
75    pub fn set_translate(&mut self, translate: bool) {
76        self.fp.translate = translate;
77    }
78
79    /// Set no_context. Usage unknown.
80    ///
81    /// Defaults to false.
82    pub fn set_no_context(&mut self, no_context: bool) {
83        self.fp.no_context = no_context;
84    }
85
86    /// Set whether to print special tokens.
87    ///
88    /// Defaults to false.
89    pub fn set_print_special_tokens(&mut self, print_special_tokens: bool) {
90        self.fp.print_special_tokens = print_special_tokens;
91    }
92
93    /// Set whether to print progress.
94    ///
95    /// Defaults to true.
96    pub fn set_print_progress(&mut self, print_progress: bool) {
97        self.fp.print_progress = print_progress;
98    }
99
100    /// Set print_realtime. Usage unknown.
101    ///
102    /// Defaults to false.
103    pub fn set_print_realtime(&mut self, print_realtime: bool) {
104        self.fp.print_realtime = print_realtime;
105    }
106
107    /// Set whether to print timestamps.
108    ///
109    /// Defaults to true.
110    pub fn set_print_timestamps(&mut self, print_timestamps: bool) {
111        self.fp.print_timestamps = print_timestamps;
112    }
113
114    /// Set the target language.
115    ///
116    /// Defaults to "en".
117    pub fn set_language(&mut self, language: &'a str) {
118        self.fp.language = language.as_ptr() as *const _;
119    }
120
121    /// Set the callback for new segments.
122    ///
123    /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so).
124    /// It is still a C callback.
125    ///
126    /// # Safety
127    /// Do not use this function unless you know what you are doing.
128    /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback.
129    ///   This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library.
130    ///
131    /// Defaults to None.
132    pub unsafe fn set_new_segment_callback(
133        &mut self,
134        _new_segment_callback: fn(a: WhisperContext) -> (),
135    ) {
136        let closure: &'static _ = Box::leak(Box::new(move |ctx: *mut whisper_rs_sys2::whisper_context, _user_data: *mut std::ffi::c_void| {
137            let num_segments = whisper_rs_sys2::whisper_full_n_segments(ctx);
138            let last_segment_index = num_segments - 1 ;
139
140            let ret1 = whisper_rs_sys2::whisper_full_get_segment_text(ctx, last_segment_index);
141            let c_str1 = CStr::from_ptr(ret1);
142            let segment = c_str1.to_str().unwrap().to_string();
143
144            let start_timestamp: i64 = whisper_rs_sys2::whisper_full_get_segment_t0(ctx, last_segment_index);
145            let end_timestamp: i64 = whisper_rs_sys2::whisper_full_get_segment_t1(ctx, last_segment_index);
146
147            println!("Voila [{} - {}]: {}", start_timestamp, end_timestamp, segment);
148        }));
149        let callback = Closure2::new(closure);
150        let &code = callback.code_ptr();
151        let ptr: unsafe extern "C" fn(ctx: *mut whisper_rs_sys2::whisper_context, user_data: *mut ::std::os::raw::c_void) = std::mem::transmute(code);
152        std::mem::forget(callback);
153
154        self.fp.new_segment_callback = Some(ptr);
155    }
156
157    /// Set the user data to be passed to the new segment callback.
158    ///
159    /// # Safety
160    /// See the safety notes for `set_new_segment_callback`.
161    ///
162    /// Defaults to None.
163    pub unsafe fn set_new_segment_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) {
164        self.fp.new_segment_callback_user_data = user_data;
165    }
166}
167
168// following implementations are safe
169// see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
170// concurrent usage is prevented by &mut self on methods that modify the struct
171unsafe impl<'a> Send for FullParams<'a> {}
172unsafe impl<'a> Sync for FullParams<'a> {}