simple_whisper/
lib.rs

1use std::{
2    fs::File,
3    io::{self, BufReader},
4    path::{Path, PathBuf},
5    sync::Arc,
6    time::Duration,
7};
8
9use derive_builder::Builder;
10
11mod download;
12mod language;
13mod model;
14mod transcribe;
15
16use download::ProgressType;
17pub use language::Language;
18pub use model::Model;
19use rodio::{Decoder, Source, source::UniformSourceIterator};
20use strum::{Display, EnumIs};
21use thiserror::Error;
22use tokio::{
23    spawn,
24    sync::{Notify, mpsc::unbounded_channel},
25    task::spawn_blocking,
26};
27pub use transcribe::TranscribeBuilderError;
28
29use tokio_stream::{Stream, wrappers::UnboundedReceiverStream};
30use transcribe::TranscribeBuilder;
31use whisper_rs::WhisperError;
32
33type Barrier = Arc<Notify>;
34
/// Sample rate, in Hz, that every input audio file is resampled to (mono)
/// before transcription.
pub const SAMPLE_RATE: u32 = 16_000;
36
37/// The Whisper audio transcription model.
38#[derive(Default, Builder, Debug)]
39#[builder(setter(into), build_fn(validate = "Self::validate"))]
40pub struct Whisper {
41    language: Language,
42    model: Model,
43    #[builder(default = "false")]
44    progress_bar: bool,
45    #[builder(default = "false")]
46    force_download: bool,
47    #[builder(default = "false")]
48    force_single_segment: bool,
49}
50
51/// Error conditions
52#[derive(Error, Debug)]
53pub enum Error {
54    /// Error that can occur during model files download from huggingface
55    #[error(transparent)]
56    Download(#[from] hf_hub::api::tokio::ApiError),
57    #[error(transparent)]
58    Io(#[from] io::Error),
59    /// Error that can occur during audio file decoding phase
60    #[error(transparent)]
61    AudioDecoder(#[from] rodio::decoder::DecoderError),
62    /// The library was unable to determine the audio file duration
63    #[error("Unable to find duration")]
64    AudioDuration,
65    #[error(transparent)]
66    /// Missing parameters to instantiate the whisper cpp backend
67    ComputeBuilder(#[from] TranscribeBuilderError),
68    #[error(transparent)]
69    Whisper(#[from] WhisperError),
70}
71
72/// Events generated by the [Whisper::transcribe] method
73#[derive(Clone, Debug, Display, EnumIs)]
74pub enum Event {
75    #[strum(to_string = "Downloading {file}")]
76    DownloadStarted { file: String },
77    #[strum(to_string = "{file} has been downloaded")]
78    DownloadCompleted { file: String },
79    #[strum(
80        to_string = "Downloading {file} --> {percentage} {elapsed_time:#?} | {remaining_time:#?}"
81    )]
82    DownloadProgress {
83        /// The resource to download
84        file: String,
85
86        /// The progress expressed as %
87        percentage: f32,
88
89        /// Time elapsed since the download as being started
90        elapsed_time: Duration,
91
92        /// Estimated time to complete the download
93        remaining_time: Duration,
94    },
95    /// Audio chunk transcript
96    #[strum(to_string = "{transcription}")]
97    Segment {
98        start_offset: Duration,
99        end_offset: Duration,
100        percentage: f32,
101        transcription: String,
102    },
103}
104
105impl WhisperBuilder {
106    fn validate(&self) -> Result<(), WhisperBuilderError> {
107        if self.language.as_ref().is_some_and(|l| !l.is_english())
108            && self.model.as_ref().is_some_and(|m| !m.is_multilingual())
109        {
110            let err = format!(
111                "The requested language {} is not compatible with {} model",
112                self.language.as_ref().unwrap(),
113                self.model.as_ref().unwrap()
114            );
115            return Err(WhisperBuilderError::ValidationError(err));
116        }
117        Ok(())
118    }
119}
120
121impl Whisper {
122    /// Transcribe an audio file into text.
123    pub fn transcribe(self, path: impl AsRef<Path>) -> impl Stream<Item = Result<Event, Error>> {
124        let (tx, rx) = unbounded_channel();
125        let (tx_event, mut rx_event) = unbounded_channel();
126
127        let wait_download = Barrier::default();
128        let download_completed = wait_download.clone();
129
130        let path = path.as_ref().into();
131
132        // Download events forwarder
133        let tx_forwarder = tx.clone();
134        spawn(async move {
135            while let Some(msg) = rx_event.recv().await {
136                let _ = tx_forwarder.send(Ok(msg));
137            }
138            wait_download.notify_one();
139        });
140
141        spawn(async move {
142            // Download model data from Hugging Face
143            let progress = if self.progress_bar {
144                drop(tx_event);
145                ProgressType::ProgressBar
146            } else {
147                ProgressType::Callback(tx_event)
148            };
149            let model = self
150                .model
151                .internal_download_model(self.force_download, progress)
152                .await;
153            download_completed.notified().await;
154
155            spawn_blocking(move || {
156                // Load audio file
157                let audio = Self::load_audio(path);
158
159                match audio.map(|audio| (audio, model)) {
160                    Ok((audio, Ok(model_files))) => {
161                        match TranscribeBuilder::default()
162                            .language(self.language)
163                            .audio(audio)
164                            .single_segment(self.force_single_segment)
165                            .tx(tx.clone())
166                            .model(model_files)
167                            .build()
168                        {
169                            Ok(compute) => compute.transcribe(),
170                            Err(err) => {
171                                let _ = tx.send(Err(err.into()));
172                            }
173                        }
174                    }
175                    Ok((_, Err(err))) => {
176                        let _ = tx.send(Err(err));
177                    }
178                    Err(err) => {
179                        let _ = tx.send(Err(err));
180                    }
181                }
182            });
183        });
184
185        UnboundedReceiverStream::new(rx)
186    }
187
188    fn load_audio(path: PathBuf) -> Result<(Vec<f32>, Duration), Error> {
189        let reader = BufReader::new(File::open(&path)?);
190        let decoder = Decoder::new(reader)?;
191        let resample: UniformSourceIterator<Decoder<BufReader<File>>, f32> =
192            UniformSourceIterator::new(decoder, 1, SAMPLE_RATE);
193        let samples = resample
194            .low_pass(3000)
195            .high_pass(200)
196            .convert_samples()
197            .collect::<Vec<f32>>();
198
199        let duration = Self::get_audio_duration(samples.len());
200
201        Ok((samples, duration))
202    }
203
204    fn get_audio_duration(samples: usize) -> Duration {
205        let secs = samples as f64 / SAMPLE_RATE as f64;
206        Duration::from_secs_f64(secs)
207    }
208}
209
#[cfg(test)]
mod tests {
    use tokio_stream::StreamExt;

    use super::*;

    /// Absolute path of a file in the `assets` directory next to this crate.
    macro_rules! test_file {
        ($file_name:expr) => {
            concat!(env!("CARGO_MANIFEST_DIR"), "/../assets/", $file_name)
        };
    }

    #[test]
    fn incompatible_lang_model() {
        let error = WhisperBuilder::default()
            .language(Language::Italian)
            .model(Model::BaseEn)
            .build()
            .unwrap_err();
        assert!(matches!(error, WhisperBuilderError::ValidationError(_)));
    }

    #[test]
    fn compatible_lang_model() {
        WhisperBuilder::default()
            .language(Language::Italian)
            .model(Model::Base)
            .build()
            .unwrap();
    }

    #[test]
    fn audio_duration_from_sample_count() {
        assert_eq!(Whisper::get_audio_duration(0), Duration::ZERO);
        assert_eq!(
            Whisper::get_audio_duration(SAMPLE_RATE as usize),
            Duration::from_secs(1)
        );
        assert_eq!(
            Whisper::get_audio_duration(SAMPLE_RATE as usize / 2),
            Duration::from_millis(500)
        );
    }

    #[ignore = "downloads the Tiny model from Hugging Face and runs a full transcription"]
    #[tokio::test]
    async fn simple_transcribe_ok() {
        let mut rx = WhisperBuilder::default()
            .language(Language::English)
            .model(Model::Tiny)
            .progress_bar(true)
            .build()
            .unwrap()
            .transcribe(test_file!("samples_jfk.wav"));

        let mut segments = 0_usize;
        while let Some(msg) = rx.next().await {
            // Print first so a failing item is visible in the test output.
            println!("{msg:?}");
            let event = msg.expect("transcription stream yielded an error");
            if matches!(event, Event::Segment { .. }) {
                segments += 1;
            }
        }
        // An empty stream must not count as a successful transcription.
        assert!(segments > 0, "no transcription segment was produced");
    }
}