scribble/backends/whisper/
mod.rs1use std::collections::HashMap;
2use std::path::Path;
3
4use anyhow::{Result as AnyResult, anyhow, ensure};
5use whisper_rs::WhisperContext;
6
7use crate::Result;
8use crate::backend::{Backend, BackendStream};
9use crate::decoder::SamplesSink;
10use crate::opts::Opts;
11use crate::segment_encoder::SegmentEncoder;
12
13mod ctx;
14mod incremental;
15mod logging;
16mod segments;
17mod token;
18
19use incremental::BufferedSegmentTranscriber;
20use segments::emit_segments;
21
/// Whisper transcription backend holding one or more loaded model contexts.
///
/// The first model path passed to [`WhisperBackend::new`] becomes the default model
/// and is stored outside `models`; every additional model lives in `models`.
/// Model keys are the file names of the model paths.
pub struct WhisperBackend {
    /// Key (file name) of the default model.
    first_model_key: String,
    /// Loaded context for the default model.
    first_model: WhisperContext,
    /// Additional models keyed by file name. Construction rejects duplicate keys,
    /// so this never contains `first_model_key`.
    models: HashMap<String, WhisperContext>,
    /// VAD model path as supplied to the constructor. It is checked to be an existing
    /// file at construction time.
    vad_model_path: String,
}
29
30pub struct WhisperStream<'a> {
32 inner: BufferedSegmentTranscriber<'a>,
33}
34
35impl BackendStream for WhisperStream<'_> {
36 fn on_samples(&mut self, samples_16k_mono: &[f32]) -> Result<bool> {
37 self.inner.on_samples(samples_16k_mono).map_err(Into::into)
38 }
39
40 fn finish(&mut self) -> Result<()> {
41 self.inner.finish().map_err(Into::into)
42 }
43}
44
45impl WhisperBackend {
46 pub fn new<I, P>(model_paths: I, vad_model_path: &str) -> Result<Self>
50 where
51 I: IntoIterator<Item = P>,
52 P: AsRef<str>,
53 {
54 Self::new_anyhow(model_paths, vad_model_path).map_err(Into::into)
55 }
56
57 fn new_anyhow<I, P>(model_paths: I, vad_model_path: &str) -> AnyResult<Self>
58 where
59 I: IntoIterator<Item = P>,
60 P: AsRef<str>,
61 {
62 ensure!(
63 !vad_model_path.trim().is_empty(),
64 "VAD model path must be provided"
65 );
66
67 let vad_path = Path::new(vad_model_path);
68 ensure!(
69 vad_path.exists(),
70 "VAD model not found at '{}'",
71 vad_model_path
72 );
73 ensure!(
74 vad_path.is_file(),
75 "VAD model path is not a file: '{}'",
76 vad_model_path
77 );
78
79 let mut first_model_key: Option<String> = None;
80 let mut first_model: Option<WhisperContext> = None;
81 let mut models = HashMap::new();
82
83 for model_path in model_paths {
84 let model_path = model_path.as_ref();
85 ensure!(!model_path.trim().is_empty(), "model path must be provided");
86
87 let model_key = Self::model_key_from_path(model_path)?;
88 ensure!(
89 first_model_key.as_deref() != Some(&model_key) && !models.contains_key(&model_key),
90 "duplicate model key '{model_key}' derived from path '{model_path}'"
91 );
92
93 let ctx = ctx::get_context(model_path)?;
94 if first_model_key.is_none() {
95 first_model_key = Some(model_key);
96 first_model = Some(ctx);
97 } else {
98 models.insert(model_key, ctx);
99 }
100 }
101
102 let first_model_key = first_model_key
103 .ok_or_else(|| anyhow!("at least one whisper model must be provided"))?;
104 let first_model =
105 first_model.ok_or_else(|| anyhow!("missing default whisper model context"))?;
106
107 Ok(Self {
108 first_model_key,
109 first_model,
110 models,
111 vad_model_path: vad_model_path.to_owned(),
112 })
113 }
114
115 pub fn context(&self) -> &WhisperContext {
117 &self.first_model
118 }
119
120 pub fn vad_model_path(&self) -> &str {
122 &self.vad_model_path
123 }
124
125 pub fn default_model_key(&self) -> &str {
127 self.first_model_key.as_str()
128 }
129
130 pub fn model_keys(&self) -> Vec<String> {
132 let mut keys = Vec::with_capacity(self.models.len() + 1);
133 keys.push(self.first_model_key.clone());
134 keys.extend(self.models.keys().cloned());
135 keys.sort_unstable();
136 keys
137 }
138
139 fn model_key_from_path(model_path: &str) -> AnyResult<String> {
140 let path = Path::new(model_path);
141 let Some(file_name) = path.file_name() else {
142 return Err(anyhow!(
143 "model path '{model_path}' does not have a filename"
144 ));
145 };
146 let Some(file_name) = file_name.to_str() else {
147 return Err(anyhow!(
148 "model filename for path '{model_path}' is not valid UTF-8"
149 ));
150 };
151 ensure!(
152 !file_name.trim().is_empty(),
153 "model filename for path '{model_path}' is empty"
154 );
155 Ok(file_name.to_owned())
156 }
157
158 fn selected_model_key<'a>(&'a self, opts: &'a Opts) -> AnyResult<&'a str> {
159 if let Some(key) = opts.model_key.as_deref() {
160 if key == self.first_model_key || self.models.contains_key(key) {
161 return Ok(key);
162 }
163 return Err(anyhow!(
164 "unknown model key '{key}' (available: {})",
165 self.available_model_keys()
166 ));
167 }
168
169 Ok(self.first_model_key.as_str())
170 }
171
172 fn selected_context<'a>(&'a self, opts: &'a Opts) -> AnyResult<&'a WhisperContext> {
173 let key = self.selected_model_key(opts)?;
174 if key == self.first_model_key {
175 return Ok(&self.first_model);
176 }
177 self.models
178 .get(key)
179 .ok_or_else(|| anyhow!("selected model '{key}' was not loaded"))
180 }
181
182 fn available_model_keys(&self) -> String {
183 let mut keys: Vec<&str> = self.models.keys().map(|k| k.as_str()).collect();
184 keys.push(self.first_model_key.as_str());
185 keys.sort_unstable();
186 keys.join(", ")
187 }
188}
189
190impl Backend for WhisperBackend {
191 type Stream<'a>
192 = WhisperStream<'a>
193 where
194 Self: 'a;
195
196 fn transcribe_full(
197 &self,
198 opts: &Opts,
199 encoder: &mut dyn SegmentEncoder,
200 samples: &[f32],
201 ) -> Result<()> {
202 self.transcribe_full_anyhow(opts, encoder, samples)
203 .map_err(Into::into)
204 }
205
206 fn create_stream<'a>(
207 &'a self,
208 opts: &'a Opts,
209 encoder: &'a mut dyn SegmentEncoder,
210 ) -> Result<Self::Stream<'a>> {
211 self.create_stream_anyhow(opts, encoder).map_err(Into::into)
212 }
213}
214
215impl WhisperBackend {
216 fn transcribe_full_anyhow(
217 &self,
218 opts: &Opts,
219 encoder: &mut dyn SegmentEncoder,
220 samples: &[f32],
221 ) -> AnyResult<()> {
222 if samples.is_empty() {
223 return Ok(());
224 }
225
226 let ctx = self.selected_context(opts)?;
227
228 let _ = opts.enable_voice_activity_detection;
230 emit_segments(ctx, opts, samples, &mut |seg| {
231 encoder.write_segment(seg).map_err(Into::into)
232 })
233 }
234
235 fn create_stream_anyhow<'a>(
236 &'a self,
237 opts: &'a Opts,
238 encoder: &'a mut dyn SegmentEncoder,
239 ) -> AnyResult<WhisperStream<'a>> {
240 let ctx = self.selected_context(opts)?;
241
242 let _ = opts.enable_voice_activity_detection;
244 Ok(WhisperStream {
245 inner: BufferedSegmentTranscriber::new(ctx, opts, encoder),
246 })
247 }
248}