simple_whisper/
transcribe.rs

1use std::{
2    path::{Path, PathBuf},
3    time::Duration,
4};
5
6use derive_builder::Builder;
7use tokio::sync::mpsc::UnboundedSender;
8use whisper_rs::{
9    FullParams, SamplingStrategy, SegmentCallbackData, WhisperContext, WhisperContextParameters,
10    WhisperError, WhisperState,
11};
12
13use crate::{Error, Event, Language};
14
15#[derive(Builder)]
16#[builder(
17    setter(into),
18    pattern = "owned",
19    build_fn(skip, error = "TranscribeBuilderError")
20)]
21pub struct Transcribe {
22    language: Language,
23    audio: (Vec<f32>, Duration),
24    tx: UnboundedSender<Result<Event, Error>>,
25    #[builder(setter(name = "model"))]
26    _model: PathBuf,
27    #[builder(setter(skip))]
28    state: WhisperState,
29    single_segment: bool,
30}
31
32impl TranscribeBuilder {
33    pub fn build(self) -> Result<Transcribe, TranscribeBuilderError> {
34        if self.language.is_none() {
35            return Err(TranscribeBuilderError::UninitializedFieldError("language"));
36        }
37
38        if self.audio.is_none() {
39            return Err(TranscribeBuilderError::UninitializedFieldError("audio"));
40        }
41
42        if self.tx.is_none() {
43            return Err(TranscribeBuilderError::UninitializedFieldError("tx"));
44        }
45
46        if self._model.is_none() {
47            return Err(TranscribeBuilderError::UninitializedFieldError("model"));
48        }
49
50        let state = state_builder(self._model.as_ref().unwrap())?;
51
52        Ok(Transcribe {
53            language: self.language.unwrap(),
54            audio: self.audio.unwrap(),
55            tx: self.tx.unwrap(),
56            _model: self._model.unwrap(),
57            state,
58            single_segment: self.single_segment.unwrap_or(false),
59        })
60    }
61}
62
63/// Error type for TrascriveBuilder
64#[derive(Error, Debug)]
65pub enum TranscribeBuilderError {
66    #[error("Field not initialized: {0}")]
67    UninitializedFieldError(&'static str),
68    #[error(transparent)]
69    WhisperCppError(#[from] WhisperError),
70}
71
72fn state_builder(model: &Path) -> Result<WhisperState, WhisperError> {
73    #![allow(unused_mut)]
74    let mut context_param = WhisperContextParameters::default();
75    #[cfg(any(
76        feature = "metal",
77        feature = "vulkan",
78        feature = "cuda",
79        feature = "hipblas"
80    ))]
81    {
82        context_param.use_gpu(true);
83    }
84
85    let ctx = WhisperContext::new_with_params(model.to_str().unwrap(), context_param)?;
86
87    ctx.create_state()
88}
89
90impl Transcribe {
91    pub fn transcribe(mut self) {
92        let tx_callback = self.tx.downgrade();
93
94        let (audio, duration) = &self.audio;
95        let duration = *duration;
96        let lang = self.language.to_string();
97
98        let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
99        params.set_single_segment(self.single_segment);
100        params.set_n_threads(num_cpus::get().try_into().unwrap());
101        params.set_language(Some(&lang));
102        params.set_print_special(false);
103        params.set_print_progress(false);
104        params.set_print_timestamps(false);
105
106        params.set_segment_callback_safe(move |seg: SegmentCallbackData| {
107            let start_offset = Duration::from_millis(seg.start_timestamp as u64 * 10);
108            let end_offset = Duration::from_millis(seg.end_timestamp as u64 * 10);
109            let mut percentage = end_offset.as_millis() as f32 / duration.as_millis() as f32;
110            if percentage > 1. {
111                percentage = 1.;
112            }
113            let seg = Event::Segment {
114                start_offset,
115                end_offset,
116                percentage,
117                transcription: seg.text,
118            };
119            let _ = tx_callback.upgrade().unwrap().send(Ok(seg));
120        });
121
122        if let Err(err) = self.state.full(params, audio) {
123            let _ = self.tx.send(Err(Error::Whisper(err)));
124        }
125    }
126}