1use std::{
2 fs::File,
3 io::{self, BufReader},
4 path::{Path, PathBuf},
5 sync::Arc,
6 time::Duration,
7};
8
9use derive_builder::Builder;
10
11mod download;
12mod language;
13mod model;
14mod transcribe;
15
16use download::ProgressType;
17pub use language::Language;
18pub use model::Model;
19use rodio::{Decoder, Source, source::UniformSourceIterator};
20use strum::{Display, EnumIs};
21use thiserror::Error;
22use tokio::{
23 spawn,
24 sync::{Notify, mpsc::unbounded_channel},
25 task::spawn_blocking,
26};
27pub use transcribe::TranscribeBuilderError;
28
29use tokio_stream::{Stream, wrappers::UnboundedReceiverStream};
30use transcribe::TranscribeBuilder;
31use whisper_rs::WhisperError;
32
/// Shared [`Notify`] handle used by [`Whisper::transcribe`] so one spawned task can
/// signal another that every download event has been forwarded.
type Barrier = Arc<Notify>;

/// Sample rate, in samples per second, that decoded audio is resampled to and that is
/// used to convert a sample count into a duration.
pub const SAMPLE_RATE: u32 = 16000;
36
/// Configuration for a transcription run.
///
/// Instances are created through the derived `WhisperBuilder`, whose `build` step runs
/// `WhisperBuilder::validate` before returning the value.
#[derive(Default, Builder, Debug)]
#[builder(setter(into), build_fn(validate = "Self::validate"))]
pub struct Whisper {
    /// Language handed to the transcriber.
    language: Language,
    /// Model whose files are resolved through `internal_download_model` before
    /// transcription starts.
    model: Model,
    /// When `true`, download progress uses `ProgressType::ProgressBar` and no download
    /// events are sent through the returned stream. Defaults to `false`.
    #[builder(default = "false")]
    progress_bar: bool,
    /// Passed as the first argument to `internal_download_model`. Defaults to `false`.
    // NOTE(review): presumably forces a re-download of cached files — confirm in `model`.
    #[builder(default = "false")]
    force_download: bool,
    /// Passed to the transcriber's `single_segment` setting. Defaults to `false`.
    #[builder(default = "false")]
    force_single_segment: bool,
}
50
/// Errors yielded as `Err` items by the stream returned from [`Whisper::transcribe`].
#[derive(Error, Debug)]
pub enum Error {
    /// Error reported by the `hf_hub` API client.
    #[error(transparent)]
    Download(#[from] hf_hub::api::tokio::ApiError),
    /// I/O failure, e.g. the audio file could not be opened.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// The audio file could not be decoded by `rodio`.
    #[error(transparent)]
    AudioDecoder(#[from] rodio::decoder::DecoderError),
    /// The duration of the audio could not be determined.
    #[error("Unable to find duration")]
    AudioDuration,
    /// The internal transcription job could not be built.
    #[error(transparent)]
    ComputeBuilder(#[from] TranscribeBuilderError),
    /// Error reported by `whisper_rs`.
    #[error(transparent)]
    Whisper(#[from] WhisperError),
}
71
/// Notifications yielded as `Ok` items by the stream returned from [`Whisper::transcribe`].
///
/// The `Display` implementation renders each variant using the template given in its
/// `strum(to_string = ...)` attribute.
#[derive(Clone, Debug, Display, EnumIs)]
pub enum Event {
    /// The download of `file` has started.
    #[strum(to_string = "Downloading {file}")]
    DownloadStarted { file: String },
    /// The download of `file` has finished.
    #[strum(to_string = "{file} has been downloaded")]
    DownloadCompleted { file: String },
    /// Progress update for the download of `file`.
    #[strum(
        to_string = "Downloading {file} --> {percentage} {elapsed_time:#?} | {remaining_time:#?}"
    )]
    DownloadProgress {
        /// File being downloaded.
        file: String,
        /// Download completion value.
        // NOTE(review): scale (0–1 vs 0–100) is set by the download module — confirm.
        percentage: f32,
        /// Time elapsed, as reported by the download module.
        elapsed_time: Duration,
        /// Time remaining, as reported by the download module (presumably an estimate).
        remaining_time: Duration,
    },
    /// A transcribed segment; displayed as its text only.
    // NOTE(review): the field values are produced in the `transcribe` module, which is not
    // visible here; the offsets are presumably relative to the start of the audio — confirm.
    #[strum(to_string = "{transcription}")]
    Segment {
        /// Start of the segment.
        start_offset: Duration,
        /// End of the segment.
        end_offset: Duration,
        /// Transcription completion value reported alongside the segment.
        percentage: f32,
        /// Text recognised for this segment.
        transcription: String,
    },
}
104
105impl WhisperBuilder {
106 fn validate(&self) -> Result<(), WhisperBuilderError> {
107 if self.language.as_ref().is_some_and(|l| !l.is_english())
108 && self.model.as_ref().is_some_and(|m| !m.is_multilingual())
109 {
110 let err = format!(
111 "The requested language {} is not compatible with {} model",
112 self.language.as_ref().unwrap(),
113 self.model.as_ref().unwrap()
114 );
115 return Err(WhisperBuilderError::ValidationError(err));
116 }
117 Ok(())
118 }
119}
120
121impl Whisper {
122 pub fn transcribe(self, path: impl AsRef<Path>) -> impl Stream<Item = Result<Event, Error>> {
124 let (tx, rx) = unbounded_channel();
125 let (tx_event, mut rx_event) = unbounded_channel();
126
127 let wait_download = Barrier::default();
128 let download_completed = wait_download.clone();
129
130 let path = path.as_ref().into();
131
132 let tx_forwarder = tx.clone();
134 spawn(async move {
135 while let Some(msg) = rx_event.recv().await {
136 let _ = tx_forwarder.send(Ok(msg));
137 }
138 wait_download.notify_one();
139 });
140
141 spawn(async move {
142 let progress = if self.progress_bar {
144 drop(tx_event);
145 ProgressType::ProgressBar
146 } else {
147 ProgressType::Callback(tx_event)
148 };
149 let model = self
150 .model
151 .internal_download_model(self.force_download, progress)
152 .await;
153 download_completed.notified().await;
154
155 spawn_blocking(move || {
156 let audio = Self::load_audio(path);
158
159 match audio.map(|audio| (audio, model)) {
160 Ok((audio, Ok(model_files))) => {
161 match TranscribeBuilder::default()
162 .language(self.language)
163 .audio(audio)
164 .single_segment(self.force_single_segment)
165 .tx(tx.clone())
166 .model(model_files)
167 .build()
168 {
169 Ok(compute) => compute.transcribe(),
170 Err(err) => {
171 let _ = tx.send(Err(err.into()));
172 }
173 }
174 }
175 Ok((_, Err(err))) => {
176 let _ = tx.send(Err(err));
177 }
178 Err(err) => {
179 let _ = tx.send(Err(err));
180 }
181 }
182 });
183 });
184
185 UnboundedReceiverStream::new(rx)
186 }
187
188 fn load_audio(path: PathBuf) -> Result<(Vec<f32>, Duration), Error> {
189 let reader = BufReader::new(File::open(&path)?);
190 let decoder = Decoder::new(reader)?;
191 let resample: UniformSourceIterator<Decoder<BufReader<File>>, f32> =
192 UniformSourceIterator::new(decoder, 1, SAMPLE_RATE);
193 let samples = resample
194 .low_pass(3000)
195 .high_pass(200)
196 .convert_samples()
197 .collect::<Vec<f32>>();
198
199 let duration = Self::get_audio_duration(samples.len());
200
201 Ok((samples, duration))
202 }
203
204 fn get_audio_duration(samples: usize) -> Duration {
205 let secs = samples as f64 / SAMPLE_RATE as f64;
206 Duration::from_secs_f64(secs)
207 }
208}
209
#[cfg(test)]
mod tests {
    use tokio_stream::StreamExt;

    use super::*;

    /// Expands to the path of a fixture inside the `assets` directory located one level
    /// above the crate's manifest directory.
    macro_rules! test_file {
        ($file_name:expr) => {
            concat!(env!("CARGO_MANIFEST_DIR"), "/../assets/", $file_name)
        };
    }

    /// A non-English language combined with an English-only model must be rejected.
    #[test]
    fn incompatible_lang_model() {
        let outcome = WhisperBuilder::default()
            .language(Language::Italian)
            .model(Model::BaseEn)
            .build();

        assert!(matches!(
            outcome,
            Err(WhisperBuilderError::ValidationError(_))
        ));
    }

    /// A non-English language combined with a multilingual model must be accepted.
    #[test]
    fn compatible_lang_model() {
        let outcome = WhisperBuilder::default()
            .language(Language::Italian)
            .model(Model::Base)
            .build();

        outcome.unwrap();
    }

    /// End-to-end run against a sample file; ignored by default.
    #[ignore]
    #[tokio::test]
    async fn simple_transcribe_ok() {
        let whisper = WhisperBuilder::default()
            .language(Language::English)
            .model(Model::Tiny)
            .progress_bar(true)
            .build()
            .unwrap();

        let mut events = whisper.transcribe(test_file!("samples_jfk.wav"));

        while let Some(msg) = events.next().await {
            assert!(msg.is_ok());
            println!("{msg:?}");
        }
    }
}