// simple_whisper/transcribe.rs

use std::{
    path::{Path, PathBuf},
    time::Duration,
};

use derive_builder::Builder;
use thiserror::Error;
use tokio::sync::mpsc::UnboundedSender;
use whisper_rs::{
    FullParams, SamplingStrategy, SegmentCallbackData, WhisperContext, WhisperContextParameters,
    WhisperError, WhisperState,
};

use crate::{Error, Event, Language};
14
/// Runs a whisper.cpp transcription over a pre-decoded audio buffer and
/// streams the resulting segments through an unbounded channel.
///
/// Built via [`TranscribeBuilder`]; the generated `build_fn` is skipped and
/// replaced by the hand-written `TranscribeBuilder::build` below so that the
/// whisper state can be created from the model path.
#[derive(Builder)]
#[builder(
    setter(into),
    pattern = "owned",
    build_fn(skip, error = "TranscribeBuilderError")
)]
pub struct Transcribe {
    // Language hint handed to whisper via `set_language`.
    language: Language,
    // Decoded f32 PCM samples plus the clip's total duration; the duration
    // is used to turn segment end timestamps into a progress percentage.
    audio: (Vec<f32>, Duration),
    // Channel on which segments (or a whisper error) are delivered.
    tx: UnboundedSender<Result<Event, Error>>,
    // Path to the model file; consumed by `state_builder` during `build()`.
    // Exposed to users as the `model` setter.
    #[builder(setter(name = "model"))]
    _model: PathBuf,
    // Whisper inference state; created in `build()`, never set directly.
    #[builder(setter(skip))]
    state: WhisperState,
    // When true, whisper emits the whole transcription as a single segment.
    single_segment: bool,
}
31
32impl TranscribeBuilder {
33 pub fn build(self) -> Result<Transcribe, TranscribeBuilderError> {
34 if self.language.is_none() {
35 return Err(TranscribeBuilderError::UninitializedFieldError("language"));
36 }
37
38 if self.audio.is_none() {
39 return Err(TranscribeBuilderError::UninitializedFieldError("audio"));
40 }
41
42 if self.tx.is_none() {
43 return Err(TranscribeBuilderError::UninitializedFieldError("tx"));
44 }
45
46 if self._model.is_none() {
47 return Err(TranscribeBuilderError::UninitializedFieldError("model"));
48 }
49
50 let state = state_builder(self._model.as_ref().unwrap())?;
51
52 Ok(Transcribe {
53 language: self.language.unwrap(),
54 audio: self.audio.unwrap(),
55 tx: self.tx.unwrap(),
56 _model: self._model.unwrap(),
57 state,
58 single_segment: self.single_segment.unwrap_or(false),
59 })
60 }
61}
62
/// Errors produced while building a [`Transcribe`] instance.
#[derive(Error, Debug)]
pub enum TranscribeBuilderError {
    /// A required builder field was never set; the payload names the field.
    #[error("Field not initialized: {0}")]
    UninitializedFieldError(&'static str),
    /// Model loading or state creation failed inside whisper.cpp.
    #[error(transparent)]
    WhisperCppError(#[from] WhisperError),
}
71
72fn state_builder(model: &Path) -> Result<WhisperState, WhisperError> {
73 #![allow(unused_mut)]
74 let mut context_param = WhisperContextParameters::default();
75 #[cfg(any(
76 feature = "metal",
77 feature = "vulkan",
78 feature = "cuda",
79 feature = "hipblas"
80 ))]
81 {
82 context_param.use_gpu(true);
83 }
84
85 let ctx = WhisperContext::new_with_params(model.to_str().unwrap(), context_param)?;
86
87 ctx.create_state()
88}
89
90impl Transcribe {
91 pub fn transcribe(mut self) {
92 let tx_callback = self.tx.downgrade();
93
94 let (audio, duration) = &self.audio;
95 let duration = *duration;
96 let lang = self.language.to_string();
97
98 let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
99 params.set_single_segment(self.single_segment);
100 params.set_n_threads(num_cpus::get().try_into().unwrap());
101 params.set_language(Some(&lang));
102 params.set_print_special(false);
103 params.set_print_progress(false);
104 params.set_print_timestamps(false);
105
106 params.set_segment_callback_safe(move |seg: SegmentCallbackData| {
107 let start_offset = Duration::from_millis(seg.start_timestamp as u64 * 10);
108 let end_offset = Duration::from_millis(seg.end_timestamp as u64 * 10);
109 let mut percentage = end_offset.as_millis() as f32 / duration.as_millis() as f32;
110 if percentage > 1. {
111 percentage = 1.;
112 }
113 let seg = Event::Segment {
114 start_offset,
115 end_offset,
116 percentage,
117 transcription: seg.text,
118 };
119 let _ = tx_callback.upgrade().unwrap().send(Ok(seg));
120 });
121
122 if let Err(err) = self.state.full(params, audio) {
123 let _ = self.tx.send(Err(Error::Whisper(err)));
124 }
125 }
126}