Skip to main content

scribble/backends/whisper/
mod.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use anyhow::{Result as AnyResult, anyhow, ensure};
5use whisper_rs::WhisperContext;
6
7use crate::Result;
8use crate::backend::{Backend, BackendStream};
9use crate::decoder::SamplesSink;
10use crate::opts::Opts;
11use crate::segment_encoder::SegmentEncoder;
12
13mod ctx;
14mod incremental;
15mod logging;
16mod segments;
17mod token;
18
19use incremental::BufferedSegmentTranscriber;
20use segments::emit_segments;
21
22/// Built-in backend powered by `whisper-rs` / `whisper.cpp`.
23pub struct WhisperBackend {
24    first_model_key: String,
25    first_model: WhisperContext,
26    models: HashMap<String, WhisperContext>,
27    vad_model_path: String,
28}
29
30/// Streaming state for [`WhisperBackend`].
31pub struct WhisperStream<'a> {
32    inner: BufferedSegmentTranscriber<'a>,
33}
34
35impl BackendStream for WhisperStream<'_> {
36    fn on_samples(&mut self, samples_16k_mono: &[f32]) -> Result<bool> {
37        self.inner.on_samples(samples_16k_mono).map_err(Into::into)
38    }
39
40    fn finish(&mut self) -> Result<()> {
41        self.inner.finish().map_err(Into::into)
42    }
43}
44
45impl WhisperBackend {
46    /// Load whisper.cpp model(s) and initialize a backend.
47    ///
48    /// Model keys are derived from the model filename (not the full path).
49    pub fn new<I, P>(model_paths: I, vad_model_path: &str) -> Result<Self>
50    where
51        I: IntoIterator<Item = P>,
52        P: AsRef<str>,
53    {
54        Self::new_anyhow(model_paths, vad_model_path).map_err(Into::into)
55    }
56
57    fn new_anyhow<I, P>(model_paths: I, vad_model_path: &str) -> AnyResult<Self>
58    where
59        I: IntoIterator<Item = P>,
60        P: AsRef<str>,
61    {
62        ensure!(
63            !vad_model_path.trim().is_empty(),
64            "VAD model path must be provided"
65        );
66
67        let vad_path = Path::new(vad_model_path);
68        ensure!(
69            vad_path.exists(),
70            "VAD model not found at '{}'",
71            vad_model_path
72        );
73        ensure!(
74            vad_path.is_file(),
75            "VAD model path is not a file: '{}'",
76            vad_model_path
77        );
78
79        let mut first_model_key: Option<String> = None;
80        let mut first_model: Option<WhisperContext> = None;
81        let mut models = HashMap::new();
82
83        for model_path in model_paths {
84            let model_path = model_path.as_ref();
85            ensure!(!model_path.trim().is_empty(), "model path must be provided");
86
87            let model_key = Self::model_key_from_path(model_path)?;
88            ensure!(
89                first_model_key.as_deref() != Some(&model_key) && !models.contains_key(&model_key),
90                "duplicate model key '{model_key}' derived from path '{model_path}'"
91            );
92
93            let ctx = ctx::get_context(model_path)?;
94            if first_model_key.is_none() {
95                first_model_key = Some(model_key);
96                first_model = Some(ctx);
97            } else {
98                models.insert(model_key, ctx);
99            }
100        }
101
102        let first_model_key = first_model_key
103            .ok_or_else(|| anyhow!("at least one whisper model must be provided"))?;
104        let first_model =
105            first_model.ok_or_else(|| anyhow!("missing default whisper model context"))?;
106
107        Ok(Self {
108            first_model_key,
109            first_model,
110            models,
111            vad_model_path: vad_model_path.to_owned(),
112        })
113    }
114
115    /// Access the default Whisper context (the first loaded model).
116    pub fn context(&self) -> &WhisperContext {
117        &self.first_model
118    }
119
120    /// Access the configured VAD model path.
121    pub fn vad_model_path(&self) -> &str {
122        &self.vad_model_path
123    }
124
125    /// The model key used when `Opts::model_key` is `None`.
126    pub fn default_model_key(&self) -> &str {
127        self.first_model_key.as_str()
128    }
129
130    /// List available model keys (sorted).
131    pub fn model_keys(&self) -> Vec<String> {
132        let mut keys = Vec::with_capacity(self.models.len() + 1);
133        keys.push(self.first_model_key.clone());
134        keys.extend(self.models.keys().cloned());
135        keys.sort_unstable();
136        keys
137    }
138
139    fn model_key_from_path(model_path: &str) -> AnyResult<String> {
140        let path = Path::new(model_path);
141        let Some(file_name) = path.file_name() else {
142            return Err(anyhow!(
143                "model path '{model_path}' does not have a filename"
144            ));
145        };
146        let Some(file_name) = file_name.to_str() else {
147            return Err(anyhow!(
148                "model filename for path '{model_path}' is not valid UTF-8"
149            ));
150        };
151        ensure!(
152            !file_name.trim().is_empty(),
153            "model filename for path '{model_path}' is empty"
154        );
155        Ok(file_name.to_owned())
156    }
157
158    fn selected_model_key<'a>(&'a self, opts: &'a Opts) -> AnyResult<&'a str> {
159        if let Some(key) = opts.model_key.as_deref() {
160            if key == self.first_model_key || self.models.contains_key(key) {
161                return Ok(key);
162            }
163            return Err(anyhow!(
164                "unknown model key '{key}' (available: {})",
165                self.available_model_keys()
166            ));
167        }
168
169        Ok(self.first_model_key.as_str())
170    }
171
172    fn selected_context<'a>(&'a self, opts: &'a Opts) -> AnyResult<&'a WhisperContext> {
173        let key = self.selected_model_key(opts)?;
174        if key == self.first_model_key {
175            return Ok(&self.first_model);
176        }
177        self.models
178            .get(key)
179            .ok_or_else(|| anyhow!("selected model '{key}' was not loaded"))
180    }
181
182    fn available_model_keys(&self) -> String {
183        let mut keys: Vec<&str> = self.models.keys().map(|k| k.as_str()).collect();
184        keys.push(self.first_model_key.as_str());
185        keys.sort_unstable();
186        keys.join(", ")
187    }
188}
189
190impl Backend for WhisperBackend {
191    type Stream<'a>
192        = WhisperStream<'a>
193    where
194        Self: 'a;
195
196    fn transcribe_full(
197        &self,
198        opts: &Opts,
199        encoder: &mut dyn SegmentEncoder,
200        samples: &[f32],
201    ) -> Result<()> {
202        self.transcribe_full_anyhow(opts, encoder, samples)
203            .map_err(Into::into)
204    }
205
206    fn create_stream<'a>(
207        &'a self,
208        opts: &'a Opts,
209        encoder: &'a mut dyn SegmentEncoder,
210    ) -> Result<Self::Stream<'a>> {
211        self.create_stream_anyhow(opts, encoder).map_err(Into::into)
212    }
213}
214
215impl WhisperBackend {
216    fn transcribe_full_anyhow(
217        &self,
218        opts: &Opts,
219        encoder: &mut dyn SegmentEncoder,
220        samples: &[f32],
221    ) -> AnyResult<()> {
222        if samples.is_empty() {
223            return Ok(());
224        }
225
226        let ctx = self.selected_context(opts)?;
227
228        // VAD workflow is temporarily disabled while the streaming-focused version is reworked.
229        let _ = opts.enable_voice_activity_detection;
230        emit_segments(ctx, opts, samples, &mut |seg| {
231            encoder.write_segment(seg).map_err(Into::into)
232        })
233    }
234
235    fn create_stream_anyhow<'a>(
236        &'a self,
237        opts: &'a Opts,
238        encoder: &'a mut dyn SegmentEncoder,
239    ) -> AnyResult<WhisperStream<'a>> {
240        let ctx = self.selected_context(opts)?;
241
242        // VAD workflow is temporarily disabled while the streaming-focused version is reworked.
243        let _ = opts.enable_voice_activity_detection;
244        Ok(WhisperStream {
245            inner: BufferedSegmentTranscriber::new(ctx, opts, encoder),
246        })
247    }
248}