1#[cfg(feature = "native")]
4mod ffi;
5
6#[cfg(feature = "native")]
7use std::ffi::CStr;
8#[cfg(any(feature = "native", test))]
9use std::ffi::CString;
10use std::fmt::{Display, Formatter};
11#[cfg(feature = "native")]
12use std::fs::File;
13#[cfg(any(feature = "native", test))]
14use std::fs::{self, OpenOptions};
15#[cfg(any(feature = "native", test))]
16use std::io::Write;
17#[cfg(feature = "native")]
18use std::io::{BufWriter, Read};
19use std::path::{Path, PathBuf};
20#[cfg(any(feature = "native", test))]
21use std::thread;
22#[cfg(any(feature = "native", test))]
23use std::time::{Duration, Instant};
24
25use serde::{Deserialize, Serialize};
26#[cfg(feature = "native")]
27use sha2::{Digest, Sha256};
28
/// A ggml-format whisper.cpp model that can be downloaded into a [`ModelStore`].
///
/// The serde names below are the same strings returned by `WhisperCppModel::id`.
/// Variant order matters: it defines the derived `PartialOrd`/`Ord` ordering.
/// `BaseEn` is the `Default` model.
#[derive(
    Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash, Default,
)]
pub enum WhisperCppModel {
    /// English-only tiny model (`tiny.en`).
    #[serde(rename = "tiny.en")]
    TinyEn,
    /// Multilingual tiny model (`tiny`).
    #[serde(rename = "tiny")]
    Tiny,
    /// English-only base model (`base.en`); the default.
    #[serde(rename = "base.en")]
    #[default]
    BaseEn,
    /// Multilingual base model (`base`).
    #[serde(rename = "base")]
    Base,
    /// English-only small model (`small.en`).
    #[serde(rename = "small.en")]
    SmallEn,
    /// Multilingual small model (`small`).
    #[serde(rename = "small")]
    Small,
    /// English-only medium model (`medium.en`).
    #[serde(rename = "medium.en")]
    MediumEn,
    /// Multilingual medium model (`medium`).
    #[serde(rename = "medium")]
    Medium,
    /// Multilingual large model, version 1 (`large-v1`).
    #[serde(rename = "large-v1")]
    LargeV1,
    /// Multilingual large model, version 2 (`large-v2`).
    #[serde(rename = "large-v2")]
    LargeV2,
    /// Multilingual large model, version 3 (`large-v3`).
    #[serde(rename = "large-v3")]
    LargeV3,
    /// Multilingual large model, version 3 "turbo" variant (`large-v3-turbo`).
    #[serde(rename = "large-v3-turbo")]
    LargeV3Turbo,
}
72
73impl WhisperCppModel {
74 pub const ALL: [Self; 12] = [
76 Self::TinyEn,
77 Self::Tiny,
78 Self::BaseEn,
79 Self::Base,
80 Self::SmallEn,
81 Self::Small,
82 Self::MediumEn,
83 Self::Medium,
84 Self::LargeV1,
85 Self::LargeV2,
86 Self::LargeV3,
87 Self::LargeV3Turbo,
88 ];
89
90 pub fn id(self) -> &'static str {
92 match self {
93 Self::TinyEn => "tiny.en",
94 Self::Tiny => "tiny",
95 Self::BaseEn => "base.en",
96 Self::Base => "base",
97 Self::SmallEn => "small.en",
98 Self::Small => "small",
99 Self::MediumEn => "medium.en",
100 Self::Medium => "medium",
101 Self::LargeV1 => "large-v1",
102 Self::LargeV2 => "large-v2",
103 Self::LargeV3 => "large-v3",
104 Self::LargeV3Turbo => "large-v3-turbo",
105 }
106 }
107
108 pub fn file_name(self) -> String {
110 format!("ggml-{}.bin", self.id())
111 }
112
113 pub fn download_url(self) -> String {
115 format!(
116 "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/{}",
117 self.file_name()
118 )
119 }
120
121 pub fn checksum_sha256(self) -> &'static str {
123 match self {
124 Self::TinyEn => "0d686a2a6a22b02da2ef3101d4c86e68461363a623c58f27f81b1b2d36b42317",
125 Self::Tiny => "518970a29bedb265f23ac48d486ddbc63bedffd90967b10140ae5ac61243acf3",
126 Self::BaseEn => "a03779c86df3323075f5e796cb2ce5029f00ec8869eee3fdfb897afe36c6d002",
127 Self::Base => "2f62d18b50c3f3feafbf990eec23a93d319660b1efbdd3fff55e52b7cde2e374",
128 Self::SmallEn => "0d57184d34ae7d736e5bb2db5bf83debe730bd53dcefa235a0979b9dcfd33fb3",
129 Self::Small => "edd29d67e70b000132af65205b99bb774b77abc13d10103e14f80ce2242913e1",
130 Self::MediumEn => "a163589aa264d5188df3b05ed4eac56bfd97e26910f207809d869f7e99886fd2",
131 Self::Medium => "d3d5696e6a3e0ca2aa08eb31cad208ffa1e87b3cc341f59e628fbdcf8122de9b",
132 Self::LargeV1 => "cbcb187d1e1abe979d33636cdc63381de20738eeda0885c39440b086e184248a",
133 Self::LargeV2 => "c6d6d3dcebc5e0074175386e17eba305fc5cc7d3d5dff3ecfd11e8f2bd4222d7",
134 Self::LargeV3 => "766d11cebbdf5a67c179c5774e2642b609e35e1a30240e7b559d5647c655b0a4",
135 Self::LargeV3Turbo => {
136 "5a4b65b05933d70ce9d5aa6265eb128fa5eba38f6fee40836fdedc4d2fde42ad"
137 }
138 }
139 }
140
141 pub fn multilingual(self) -> bool {
143 !matches!(
144 self,
145 Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn
146 )
147 }
148}
149
150impl Display for WhisperCppModel {
151 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
152 f.write_str(self.id())
153 }
154}
155
/// Options for a transcription run.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct WhisperCppConfig {
    /// Model to use; defaults to [`WhisperCppModel::BaseEn`] when absent from serialized input.
    #[serde(default)]
    pub model: WhisperCppModel,
    /// Language to transcribe. `None`, an empty string, or `"auto"` selects the model's
    /// default: English for `.en` models, automatic detection otherwise.
    pub language: Option<String>,
    /// Forwarded to whisper.cpp's `translate` parameter; `false` when absent.
    #[serde(default)]
    pub translate: bool,
    /// Inference thread count; `None` uses the machine's available parallelism
    /// (falling back to 4 if that cannot be queried).
    pub threads: Option<usize>,
}

/// One timed segment of a transcription.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WhisperCppSegment {
    /// Zero-based position of the segment in whisper.cpp's output.
    pub index: u64,
    /// Segment start, in seconds from the beginning of the audio.
    pub start_seconds: Option<f64>,
    /// Segment end, in seconds from the beginning of the audio.
    pub end_seconds: Option<f64>,
    /// Segment text with surrounding whitespace trimmed.
    pub text: String,
    /// Mean token probability over the segment's tokens; `None` when it has no tokens.
    pub confidence: Option<f32>,
}

/// Result of transcribing one audio file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WhisperCppTranscription {
    /// Non-empty segment texts joined with single spaces; `None` if nothing was recognized.
    pub text: Option<String>,
    /// Language code reported by whisper.cpp for this run, when it reports one.
    pub language: Option<String>,
    /// Individual timed segments, in output order.
    pub segments: Vec<WhisperCppSegment>,
    /// Path of the model file that produced this transcription.
    pub source: Option<String>,
}
198
/// Stage of a transcription run, as reported in progress events.
///
/// Serialized in snake_case (e.g. `downloading_model`), the same strings returned by
/// `WhisperCppPhase::as_str`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum WhisperCppPhase {
    /// First event of a run, emitted before any work is done.
    Preparing,
    /// The model file is being fetched into the model store.
    DownloadingModel,
    /// The model is available locally; loading it is about to start.
    LoadingModel,
    /// Emitted before whisper.cpp inference begins.
    Transcribing,
}
212
213impl WhisperCppPhase {
214 pub fn as_str(self) -> &'static str {
216 match self {
217 Self::Preparing => "preparing",
218 Self::DownloadingModel => "downloading_model",
219 Self::LoadingModel => "loading_model",
220 Self::Transcribing => "transcribing",
221 }
222 }
223}
224
/// Progress notification delivered to a transcription progress callback.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WhisperCppProgressEvent {
    /// Current stage of the run.
    pub phase: WhisperCppPhase,
    /// Human-readable description of the current step.
    pub message: String,
    /// Completion fraction in `0.0..=1.0` when known. Only model-download events carry a
    /// value, and only once the size is known (or at the start of a download).
    pub progress: Option<f32>,
}

/// Cache state of a single model in a [`ModelStore`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WhisperCppModelStatus {
    /// The model described.
    pub model: WhisperCppModel,
    /// `true` if the model file exists in the store (its checksum is not re-checked).
    pub cached: bool,
}

/// Every known model with its cache state, plus the default model.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WhisperCppCatalog {
    /// Model used when a config does not name one.
    pub default_model: WhisperCppModel,
    /// One entry per model, in [`WhisperCppModel::ALL`] order.
    pub models: Vec<WhisperCppModelStatus>,
}
253
/// Errors produced while managing models or running whisper.cpp.
#[derive(Debug, thiserror::Error)]
pub enum WhisperCppError {
    /// Filesystem or stream failure (this also covers lock-file acquisition timeouts).
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// `hound` could not open or decode the WAV input.
    #[error("wave input error: {0}")]
    Wav(#[from] hound::Error),
    /// The model download failed (request or body read); carries the underlying message.
    #[error("network error: {0}")]
    Http(String),
    /// The audio was decodable but unsupported (for example, not 16 kHz).
    #[error("invalid input: {0}")]
    InvalidInput(String),
    /// The requested language is unknown to whisper.cpp or cannot be passed as a C string.
    #[error("unsupported language `{0}`")]
    UnsupportedLanguage(String),
    /// The downloaded file's SHA-256 did not match the pinned digest.
    #[error("downloaded model `{model}` failed checksum verification")]
    InvalidChecksum {
        /// Model whose download was rejected.
        model: WhisperCppModel,
    },
    /// whisper.cpp could not be initialized; carries the model path or a reason.
    #[error("failed to initialize whisper.cpp from `{0}`")]
    Initialization(String),
    /// `whisper_full` returned a non-zero status; carries the model path.
    #[error("whisper.cpp inference failed for `{0}`")]
    Inference(String),
    /// whisper.cpp returned a string that is not valid UTF-8.
    #[error("invalid utf-8 returned by whisper.cpp")]
    InvalidUtf8,
}
288
/// Result alias used throughout this module.
pub type Result<T> = std::result::Result<T, WhisperCppError>;

/// Owned progress callback type stored (boxed) by [`WhisperCppTranscriber`].
type OwnedProgressCallback = dyn FnMut(WhisperCppProgressEvent) + 'static;
293
294#[derive(Clone)]
295pub struct ModelStore {
297 root: PathBuf,
298}
299
300impl Default for ModelStore {
301 fn default() -> Self {
302 Self {
303 root: cache_root().join("whisper-cpp"),
304 }
305 }
306}
307
308impl ModelStore {
309 pub fn new(root: PathBuf) -> Self {
311 Self { root }
312 }
313
314 pub fn models_dir(&self) -> PathBuf {
316 self.root.join("models")
317 }
318
319 pub fn model_path(&self, model: WhisperCppModel) -> PathBuf {
321 self.models_dir().join(model.file_name())
322 }
323
324 pub fn lock_path(&self, model: WhisperCppModel) -> PathBuf {
326 self.models_dir()
327 .join(format!("{}.lock", model.file_name()))
328 }
329
330 pub fn catalog(&self) -> WhisperCppCatalog {
332 WhisperCppCatalog {
333 default_model: WhisperCppModel::default(),
334 models: WhisperCppModel::ALL
335 .into_iter()
336 .map(|model| WhisperCppModelStatus {
337 model,
338 cached: self.model_path(model).is_file(),
339 })
340 .collect(),
341 }
342 }
343
344 #[cfg(feature = "native")]
345 fn ensure_model(
346 &self,
347 model: WhisperCppModel,
348 progress: &mut ProgressSink<'_>,
349 ) -> Result<PathBuf> {
350 fs::create_dir_all(self.models_dir())?;
351 let model_path = self.model_path(model);
352 if model_path.is_file() {
353 return Ok(model_path);
354 }
355
356 let _lock = FileLock::acquire(self.lock_path(model))?;
357 if model_path.is_file() {
358 return Ok(model_path);
359 }
360
361 progress.emit(
362 WhisperCppPhase::DownloadingModel,
363 format!("downloading whisper.cpp model `{model}`"),
364 Some(0.0),
365 );
366
367 let temp_path = model_path.with_extension("bin.part");
368 if temp_path.exists() {
369 let _ = fs::remove_file(&temp_path);
370 }
371
372 let response = ureq::get(&model.download_url())
373 .call()
374 .map_err(|error| WhisperCppError::Http(error.to_string()))?;
375 let total_bytes = response
376 .header("Content-Length")
377 .and_then(|value| value.parse::<u64>().ok());
378 let mut reader = response.into_reader();
379 let mut file = BufWriter::new(File::create(&temp_path)?);
380 let mut hasher = Sha256::new();
381 let mut downloaded = 0_u64;
382 let mut buffer = [0_u8; 64 * 1024];
383
384 loop {
385 let read = reader
386 .read(&mut buffer)
387 .map_err(|error| WhisperCppError::Http(error.to_string()))?;
388 if read == 0 {
389 break;
390 }
391 file.write_all(&buffer[..read])?;
392 hasher.update(&buffer[..read]);
393 downloaded += read as u64;
394 let fraction =
395 total_bytes.map(|total| (downloaded as f32 / total as f32).clamp(0.0, 1.0));
396 progress.emit(
397 WhisperCppPhase::DownloadingModel,
398 format!("downloading whisper.cpp model `{model}`"),
399 fraction,
400 );
401 }
402 file.flush()?;
403
404 let checksum = format!("{:x}", hasher.finalize());
405 if checksum != model.checksum_sha256() {
406 let _ = fs::remove_file(&temp_path);
407 return Err(WhisperCppError::InvalidChecksum { model });
408 }
409
410 fs::rename(temp_path, &model_path)?;
411 Ok(model_path)
412 }
413}
414
/// High-level entry point: transcribes WAV files with a configured model, caching the
/// model in a [`ModelStore`] and optionally reporting progress.
pub struct WhisperCppTranscriber {
    /// Model, language, translation and threading options.
    config: WhisperCppConfig,
    /// Where models are cached; `ModelStore::default()` unless replaced.
    store: ModelStore,
    /// Callback registered through `on_progress`, if any.
    progress: Option<Box<OwnedProgressCallback>>,
}
421
422impl WhisperCppTranscriber {
423 pub fn new(config: WhisperCppConfig) -> Self {
425 Self {
426 config,
427 store: ModelStore::default(),
428 progress: None,
429 }
430 }
431
432 pub fn with_model_store(mut self, store: ModelStore) -> Self {
434 self.store = store;
435 self
436 }
437
438 pub fn on_progress<F>(mut self, callback: F) -> Self
440 where
441 F: FnMut(WhisperCppProgressEvent) + 'static,
442 {
443 self.progress = Some(Box::new(callback));
444 self
445 }
446
447 pub fn transcribe_file(&mut self, input: &Path) -> Result<WhisperCppTranscription> {
449 let store = self.store.clone();
450 let config = self.config.clone();
451 let mut progress = ProgressSink::new(self.progress_deref_mut());
452 transcribe_impl(&store, &config, input, &mut progress)
453 }
454
455 pub fn transcribe_file_with_progress(
457 &mut self,
458 input: &Path,
459 progress: &mut dyn FnMut(WhisperCppProgressEvent),
460 ) -> Result<WhisperCppTranscription> {
461 let mut progress = ProgressSink::new(Some(progress));
462 transcribe_impl(&self.store, &self.config, input, &mut progress)
463 }
464
465 fn progress_deref_mut(&mut self) -> Option<&mut dyn FnMut(WhisperCppProgressEvent)> {
466 self.progress
467 .as_mut()
468 .map(|callback| callback.as_mut() as &mut dyn FnMut(WhisperCppProgressEvent))
469 }
470}
471
472pub fn transcription_catalog() -> WhisperCppCatalog {
474 ModelStore::default().catalog()
475}
476
/// whisper.cpp's system-information string, or `None` when the crate was built without
/// the `native` feature, the pointer is null, or the text is not valid UTF-8.
pub fn whisper_cpp_system_info() -> Option<String> {
    #[cfg(feature = "native")]
    {
        // SAFETY: FFI call with no arguments; the returned pointer is null-checked below.
        let raw_info = unsafe { ffi::whisper_print_system_info() };
        if raw_info.is_null() {
            None
        } else {
            // SAFETY: non-null pointer from whisper.cpp, presumed NUL-terminated; the
            // bytes are copied out immediately.
            let info = unsafe { CStr::from_ptr(raw_info) };
            info.to_str().ok().map(str::to_owned)
        }
    }

    #[cfg(not(feature = "native"))]
    {
        None
    }
}
496
/// Runs a full native whisper.cpp transcription of the 16 kHz WAV file at `input`.
///
/// Order of work:
/// 1. decode and validate the audio and the requested language;
/// 2. make sure the model is cached (downloading it if necessary);
/// 3. load the model;
/// 4. run inference.
///
/// Validating inputs first means a bad file or language fails fast instead of after a
/// possibly multi-gigabyte download. The progress phases are emitted right before the
/// work they describe.
///
/// # Errors
/// - [`WhisperCppError::Wav`] / [`WhisperCppError::InvalidInput`] for unreadable,
///   unsupported or over-long audio.
/// - [`WhisperCppError::UnsupportedLanguage`] if whisper.cpp does not know the language.
/// - Errors from model download (I/O, network, checksum).
/// - [`WhisperCppError::Initialization`] if the model cannot be loaded.
/// - [`WhisperCppError::Inference`] if `whisper_full` reports failure.
/// - [`WhisperCppError::InvalidUtf8`] if the reported language name is not UTF-8.
#[cfg(feature = "native")]
fn transcribe_impl(
    store: &ModelStore,
    config: &WhisperCppConfig,
    input: &Path,
    progress: &mut ProgressSink<'_>,
) -> Result<WhisperCppTranscription> {
    let model = config.model;
    progress.emit(
        WhisperCppPhase::Preparing,
        format!(
            "preparing native whisper.cpp transcription for {}",
            input.display()
        ),
        None,
    );

    let audio = read_wav_mono_f32(input)?;
    if audio.samples.len() > i32::MAX as usize {
        return Err(WhisperCppError::InvalidInput(format!(
            "wav file is too long for whisper.cpp ({} samples)",
            audio.samples.len()
        )));
    }
    // `whisper_full` takes the sample count as a C int; the check above makes this lossless.
    let sample_count = audio.samples.len() as i32;

    let language = resolve_language(config)?;
    if let Some(language) = language.as_ref() {
        // SAFETY: `language` is a NUL-terminated CString alive for the whole call.
        let lang_id = unsafe { ffi::whisper_lang_id(language.as_ptr()) };
        if lang_id < 0 {
            return Err(WhisperCppError::UnsupportedLanguage(
                language.to_string_lossy().into_owned(),
            ));
        }
    }

    let model_path = store.ensure_model(model, progress)?;
    progress.emit(
        WhisperCppPhase::LoadingModel,
        format!("loading whisper.cpp model `{model}`"),
        None,
    );
    let context = WhisperContext::from_model(&model_path)?;

    progress.emit(
        WhisperCppPhase::Transcribing,
        format!("transcribing audio with whisper.cpp model `{model}`"),
        None,
    );

    // SAFETY: FFI constructor that returns a parameter struct by value.
    let mut params = unsafe {
        ffi::whisper_full_default_params(ffi::whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY)
    };
    params.n_threads = resolve_threads(config.threads);
    params.translate = config.translate;
    params.print_progress = false;
    params.print_realtime = false;
    params.print_special = false;
    params.print_timestamps = false;
    params.no_timestamps = false;
    // A null language leaves detection to whisper.cpp. The pointer borrows `language`,
    // which lives until the end of this function and so outlives `whisper_full`.
    params.language = match language.as_ref() {
        Some(language) => language.as_ptr(),
        None => std::ptr::null(),
    };
    params.detect_language = false;

    // SAFETY: `context.raw` is a live context, `audio.samples` holds exactly
    // `sample_count` floats, and every pointer stored in `params` outlives the call.
    let status = unsafe {
        ffi::whisper_full(context.raw, params, audio.samples.as_ptr(), sample_count)
    };
    if status != 0 {
        return Err(WhisperCppError::Inference(model_path.display().to_string()));
    }

    // SAFETY (FFI reads below): `context.raw` stays live until `context` is dropped, and
    // the segment/token indices are within the counts whisper.cpp just reported.
    let segment_count = unsafe { ffi::whisper_full_n_segments(context.raw) };
    let mut segments = Vec::with_capacity(segment_count.max(0) as usize);
    for index in 0..segment_count {
        let text_ptr = unsafe { ffi::whisper_full_get_segment_text(context.raw, index) };
        // Decode lossily: a segment boundary can fall inside a multi-byte UTF-8 character,
        // and one such segment must not fail the whole transcription.
        let text = if text_ptr.is_null() {
            String::new()
        } else {
            // SAFETY: non-null pointer to a NUL-terminated string owned by `context`;
            // it is copied out immediately.
            unsafe { CStr::from_ptr(text_ptr) }
                .to_string_lossy()
                .trim()
                .to_string()
        };
        let start = unsafe { ffi::whisper_full_get_segment_t0(context.raw, index) };
        let end = unsafe { ffi::whisper_full_get_segment_t1(context.raw, index) };
        let token_count = unsafe { ffi::whisper_full_n_tokens(context.raw, index) };
        let confidence = if token_count > 0 {
            let mut total = 0.0_f32;
            for token_index in 0..token_count {
                total += unsafe { ffi::whisper_full_get_token_p(context.raw, index, token_index) };
            }
            Some(total / token_count as f32)
        } else {
            None
        };
        segments.push(WhisperCppSegment {
            index: index as u64,
            start_seconds: Some(timestamp_to_seconds(start)),
            end_seconds: Some(timestamp_to_seconds(end)),
            text,
            confidence,
        });
    }

    let detected_language = unsafe { ffi::whisper_full_lang_id(context.raw) };
    let detected_language = if detected_language >= 0 {
        Some(c_string(unsafe { ffi::whisper_lang_str(detected_language) })?)
    } else {
        None
    };
    let text = join_segments(&segments);

    Ok(WhisperCppTranscription {
        text,
        language: detected_language,
        segments,
        source: Some(model_path.to_string_lossy().into_owned()),
    })
}
607
608#[cfg(not(feature = "native"))]
609fn transcribe_impl(
610 _store: &ModelStore,
611 _config: &WhisperCppConfig,
612 _input: &Path,
613 _progress: &mut ProgressSink<'_>,
614) -> Result<WhisperCppTranscription> {
615 Err(WhisperCppError::Initialization(
616 "text-whisper-cpp was built without the `native` feature".to_string(),
617 ))
618}
619
/// Maps the configured language onto the C string handed to whisper.cpp.
///
/// Surrounding whitespace is ignored and the value is ASCII-lower-cased, because
/// whisper.cpp's language table uses lower-case codes and names (`en`, `english`).
/// Without this, inputs such as `"EN"` or `"English"` would be rejected. `None`, an empty
/// string, or `"auto"` (any case) fall back to [`resolve_default_language`].
///
/// # Errors
/// [`WhisperCppError::UnsupportedLanguage`] if the value contains an interior NUL byte.
#[cfg(any(feature = "native", test))]
fn resolve_language(config: &WhisperCppConfig) -> Result<Option<CString>> {
    let requested = config.language.as_deref().map(str::trim).unwrap_or("");
    if requested.is_empty() || requested.eq_ignore_ascii_case("auto") {
        return resolve_default_language(config.model);
    }
    CString::new(requested.to_ascii_lowercase())
        .map(Some)
        .map_err(|_| WhisperCppError::UnsupportedLanguage(requested.to_string()))
}
631
/// Language used when none was requested: `None` (auto-detect) for multilingual models,
/// `"en"` for English-only ones.
#[cfg(any(feature = "native", test))]
fn resolve_default_language(model: WhisperCppModel) -> Result<Option<CString>> {
    if !model.multilingual() {
        return CString::new("en")
            .map(Some)
            .map_err(|_| WhisperCppError::UnsupportedLanguage("en".to_string()));
    }
    Ok(None)
}
642
643#[cfg_attr(not(feature = "native"), allow(dead_code))]
644struct ProgressSink<'a> {
645 callback: Option<&'a mut dyn FnMut(WhisperCppProgressEvent)>,
646}
647
648impl<'a> ProgressSink<'a> {
649 fn new(callback: Option<&'a mut dyn FnMut(WhisperCppProgressEvent)>) -> Self {
650 Self { callback }
651 }
652
653 #[cfg(feature = "native")]
654 fn emit(&mut self, phase: WhisperCppPhase, message: String, progress: Option<f32>) {
655 if let Some(callback) = self.callback.as_mut() {
656 callback(WhisperCppProgressEvent {
657 phase,
658 message,
659 progress,
660 });
661 }
662 }
663}
664
/// Decodes a 16 kHz WAV file into mono `f32` samples, averaging channels when there is
/// more than one.
///
/// # Errors
/// - [`WhisperCppError::Wav`] if hound cannot open the file or decode a sample.
/// - [`WhisperCppError::InvalidInput`] if the file declares zero channels or a sample rate
///   other than 16 000 Hz.
#[cfg(feature = "native")]
fn read_wav_mono_f32(path: &Path) -> Result<AudioSamples> {
    let mut wav = hound::WavReader::open(path)?;
    let spec = wav.spec();

    if spec.channels == 0 {
        return Err(WhisperCppError::InvalidInput(
            "wav file has no channels".to_string(),
        ));
    }
    if spec.sample_rate != 16_000 {
        return Err(WhisperCppError::InvalidInput(format!(
            "expected 16 kHz wav input, got {} Hz",
            spec.sample_rate
        )));
    }

    let interleaved: Vec<f32> = match spec.sample_format {
        hound::SampleFormat::Float => wav
            .samples::<f32>()
            .collect::<std::result::Result<Vec<f32>, _>>()?,
        hound::SampleFormat::Int => read_int_samples(&mut wav, spec.bits_per_sample)?,
    };

    let channel_count = usize::from(spec.channels);
    let samples = match channel_count {
        1 => interleaved,
        _ => interleaved
            .chunks(channel_count)
            .map(|frame| {
                // Downmix one interleaved frame by averaging its channels.
                let total: f32 = frame.iter().sum();
                total / frame.len() as f32
            })
            .collect(),
    };

    Ok(AudioSamples { samples })
}
700
/// Reads integer PCM samples and normalizes them into `[-1.0, 1.0)`.
///
/// Samples are divided by `2^(bits_per_sample - 1)`, e.g. 32768 for 16-bit audio. This is
/// the full-scale convention whisper.cpp's own WAV loader uses, and it maps the most
/// negative sample to exactly -1.0. Depths up to 16 bits are read as `i16`, larger ones
/// as `i32`.
///
/// # Errors
/// - [`WhisperCppError::InvalidInput`] if `bits_per_sample` is 0 or above 32. A zero depth
///   would yield a zero scale, and an absurd depth would overflow the shift.
/// - [`WhisperCppError::Wav`] if hound fails to decode a sample.
#[cfg(feature = "native")]
fn read_int_samples(
    reader: &mut hound::WavReader<std::io::BufReader<File>>,
    bits_per_sample: u16,
) -> Result<Vec<f32>> {
    if bits_per_sample == 0 || bits_per_sample > 32 {
        return Err(WhisperCppError::InvalidInput(format!(
            "unsupported integer wav bit depth: {bits_per_sample}"
        )));
    }
    // Exact in f32: a power of two no larger than 2^31.
    let scale = (1_u64 << (bits_per_sample - 1)) as f32;
    let samples = if bits_per_sample <= 16 {
        reader
            .samples::<i16>()
            .map(|sample| sample.map(|sample| f32::from(sample) / scale))
            .collect::<std::result::Result<Vec<_>, _>>()?
    } else {
        reader
            .samples::<i32>()
            .map(|sample| sample.map(|sample| sample as f32 / scale))
            .collect::<std::result::Result<Vec<_>, _>>()?
    };
    Ok(samples)
}
719
/// Number of inference threads to hand to whisper.cpp.
///
/// An explicit positive `value` wins. `None` or `Some(0)` falls back to
/// [`thread::available_parallelism`], and then to 4 if that cannot be queried. Previously
/// `Some(0)` was passed straight through as zero threads. The result is clamped to
/// `1..=i32::MAX` because whisper.cpp takes a C `int`.
#[cfg(feature = "native")]
fn resolve_threads(value: Option<usize>) -> i32 {
    value
        .filter(|&threads| threads > 0)
        .or_else(|| thread::available_parallelism().ok().map(usize::from))
        .unwrap_or(4)
        .min(i32::MAX as usize) as i32
}
727
/// Converts a whisper.cpp segment timestamp, counted in 1/100 s ticks, to seconds.
#[cfg(feature = "native")]
fn timestamp_to_seconds(value: i64) -> f64 {
    const TICKS_PER_SECOND: f64 = 100.0;
    let ticks = value as f64;
    ticks / TICKS_PER_SECOND
}
732
/// Joins the trimmed, non-blank segment texts with single spaces; `None` if nothing remains.
#[cfg(feature = "native")]
fn join_segments(segments: &[WhisperCppSegment]) -> Option<String> {
    let mut joined = String::new();
    for piece in segments.iter().map(|segment| segment.text.trim()) {
        if piece.is_empty() {
            continue;
        }
        if !joined.is_empty() {
            joined.push(' ');
        }
        joined.push_str(piece);
    }
    if joined.is_empty() {
        None
    } else {
        Some(joined)
    }
}
743
/// Copies a NUL-terminated C string returned by whisper.cpp into an owned `String`.
/// A null pointer yields an empty string.
///
/// # Errors
/// [`WhisperCppError::InvalidUtf8`] if the bytes are not valid UTF-8.
#[cfg(feature = "native")]
fn c_string(value: *const std::ffi::c_char) -> Result<String> {
    if !value.is_null() {
        // SAFETY: non-null; callers pass pointers obtained from whisper.cpp that are
        // presumed NUL-terminated and valid for the duration of this call.
        let raw = unsafe { CStr::from_ptr(value) };
        return match raw.to_str() {
            Ok(text) => Ok(text.to_owned()),
            Err(_) => Err(WhisperCppError::InvalidUtf8),
        };
    }
    Ok(String::new())
}
754
/// Root directory for this application's caches.
///
/// Resolution order:
/// 1. `VIDEO_ANALYSIS_STUDIO_CACHE_DIR`, used as-is;
/// 2. `$XDG_CACHE_HOME/video-analysis-studio`, only if `XDG_CACHE_HOME` is absolute (the
///    XDG Base Directory spec says relative values are invalid and must be ignored);
/// 3. on Windows, `%LOCALAPPDATA%/video-analysis-studio`;
/// 4. `$HOME/.cache/video-analysis-studio`;
/// 5. the relative path `.cache/video-analysis-studio` as a last resort.
///
/// A variable that is set but empty is treated as unset, so it cannot silently turn the
/// cache into a path relative to the current working directory.
fn cache_root() -> PathBuf {
    /// The variable's value as a path, or `None` when it is unset or empty.
    fn non_empty_var(name: &str) -> Option<PathBuf> {
        std::env::var_os(name)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    }

    const APP_DIR: &str = "video-analysis-studio";

    if let Some(dir) = non_empty_var("VIDEO_ANALYSIS_STUDIO_CACHE_DIR") {
        return dir;
    }
    if let Some(dir) = non_empty_var("XDG_CACHE_HOME").filter(|dir| dir.is_absolute()) {
        return dir.join(APP_DIR);
    }
    if cfg!(target_os = "windows") {
        if let Some(dir) = non_empty_var("LOCALAPPDATA") {
            return dir.join(APP_DIR);
        }
    }
    if let Some(home) = non_empty_var("HOME") {
        return home.join(".cache").join(APP_DIR);
    }
    PathBuf::from(".cache/video-analysis-studio")
}
774
/// Decoded audio ready to hand to whisper.cpp.
#[cfg(feature = "native")]
struct AudioSamples {
    /// Mono `f32` samples at 16 kHz (the only rate `read_wav_mono_f32` accepts).
    samples: Vec<f32>,
}
779
/// Owning handle to a `whisper_context` created by whisper.cpp; freed on drop.
#[cfg(feature = "native")]
struct WhisperContext {
    /// Context pointer; non-null once `from_model` has succeeded.
    raw: *mut ffi::whisper_context,
}

#[cfg(feature = "native")]
impl WhisperContext {
    /// Loads the ggml model at `path` into a new whisper.cpp context.
    ///
    /// GPU use is requested only on macOS, and flash attention is disabled.
    ///
    /// # Errors
    /// [`WhisperCppError::Initialization`], carrying the path, if the path contains a NUL
    /// byte or whisper.cpp returns a null context.
    fn from_model(path: &Path) -> Result<Self> {
        let init_error = || WhisperCppError::Initialization(path.display().to_string());
        let c_path = CString::new(path.to_string_lossy().into_owned()).map_err(|_| init_error())?;

        // SAFETY: FFI constructor that returns a parameter struct by value.
        let mut context_params = unsafe { ffi::whisper_context_default_params() };
        context_params.use_gpu = cfg!(target_os = "macos");
        context_params.flash_attn = false;

        // SAFETY: `c_path` is NUL-terminated and outlives the call.
        let raw =
            unsafe { ffi::whisper_init_from_file_with_params(c_path.as_ptr(), context_params) };
        if raw.is_null() {
            Err(init_error())
        } else {
            Ok(Self { raw })
        }
    }
}

#[cfg(feature = "native")]
impl Drop for WhisperContext {
    /// Releases the whisper.cpp context, if one was created.
    fn drop(&mut self) {
        if self.raw.is_null() {
            return;
        }
        // SAFETY: `raw` came from `whisper_init_from_file_with_params` and is freed only here.
        unsafe { ffi::whisper_free(self.raw) };
    }
}
809
/// Cross-process lock implemented as an exclusively created file.
///
/// Only one holder can create the file, because `create_new` fails with `AlreadyExists`
/// while it is present. The file is deleted on drop. If the holder dies without running
/// `Drop`, the file remains and must be removed by hand.
#[cfg(any(feature = "native", test))]
struct FileLock {
    /// Lock file owned by this holder; deleted in `Drop`.
    path: PathBuf,
}

#[cfg(any(feature = "native", test))]
impl FileLock {
    /// How long `acquire` keeps retrying before giving up.
    const TIMEOUT: Duration = Duration::from_secs(120);
    /// Delay between acquisition attempts.
    const RETRY_INTERVAL: Duration = Duration::from_millis(250);

    /// Creates `path` exclusively, polling until that succeeds or the timeout elapses.
    ///
    /// The current process id is written into the file (best effort) so a leftover lock
    /// can be traced to its owner.
    ///
    /// # Errors
    /// - [`WhisperCppError::Io`] with kind `AlreadyExists` if the lock is still held after
    ///   the timeout. The message names the lock file so a stale one can be found and
    ///   removed; a bare "File exists" gave no clue which file was involved.
    /// - [`WhisperCppError::Io`] for any other failure to create the file.
    fn acquire(path: PathBuf) -> Result<Self> {
        let deadline = Instant::now() + Self::TIMEOUT;
        loop {
            match OpenOptions::new().create_new(true).write(true).open(&path) {
                Ok(mut file) => {
                    let _ = writeln!(file, "{}", std::process::id());
                    return Ok(Self { path });
                }
                Err(error) if error.kind() == std::io::ErrorKind::AlreadyExists => {
                    if Instant::now() >= deadline {
                        return Err(WhisperCppError::Io(std::io::Error::new(
                            error.kind(),
                            format!(
                                "timed out after {}s waiting for lock file `{}` ({error}); \
                                 remove it if no other download is running",
                                Self::TIMEOUT.as_secs(),
                                path.display()
                            ),
                        )));
                    }
                    thread::sleep(Self::RETRY_INTERVAL);
                }
                Err(error) => return Err(WhisperCppError::Io(error)),
            }
        }
    }
}

#[cfg(any(feature = "native", test))]
impl Drop for FileLock {
    /// Releases the lock by deleting the file; failures are ignored (best effort).
    fn drop(&mut self) {
        let _ = fs::remove_file(&self.path);
    }
}
843
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Root for tests that only compute paths; nothing is written there.
    const TEST_ROOT: &str = "/tmp/video-analysis-studio-test";

    /// Config for `model` with an optional language; other fields use their defaults.
    fn config(model: WhisperCppModel, language: Option<&str>) -> WhisperCppConfig {
        WhisperCppConfig {
            model,
            language: language.map(str::to_string),
            ..WhisperCppConfig::default()
        }
    }

    #[test]
    fn model_metadata_matches_expected_file_names() {
        let base_en = WhisperCppModel::BaseEn;
        assert_eq!(base_en.file_name(), "ggml-base.en.bin");

        let turbo_url = WhisperCppModel::LargeV3Turbo.download_url();
        assert_eq!(
            turbo_url,
            "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo.bin"
        );
    }

    #[test]
    fn catalog_uses_base_en_by_default() {
        let store = ModelStore::new(PathBuf::from(TEST_ROOT));
        let catalog = store.catalog();
        assert_eq!(catalog.default_model, WhisperCppModel::BaseEn);
        assert_eq!(catalog.models.len(), WhisperCppModel::ALL.len());
    }

    #[test]
    fn cache_paths_are_stable() {
        let store = ModelStore::new(PathBuf::from(TEST_ROOT));
        let model = WhisperCppModel::SmallEn;
        assert_eq!(
            store.model_path(model),
            PathBuf::from("/tmp/video-analysis-studio-test/models/ggml-small.en.bin")
        );
        assert_eq!(
            store.lock_path(model),
            PathBuf::from("/tmp/video-analysis-studio-test/models/ggml-small.en.bin.lock")
        );
    }

    #[test]
    fn file_lock_creates_and_releases_lock_path() {
        let dir = tempdir().unwrap();
        let lock_path = dir.path().join("model.lock");

        let lock = FileLock::acquire(lock_path.clone()).unwrap();
        assert!(lock_path.is_file());

        drop(lock);
        assert!(!lock_path.exists());
    }

    #[test]
    fn english_only_models_default_to_english() {
        let resolved = resolve_language(&config(WhisperCppModel::BaseEn, None)).unwrap();
        assert_eq!(resolved.unwrap().to_str().unwrap(), "en");
    }

    #[test]
    fn multilingual_models_default_to_auto_detection_without_detect_only_mode() {
        let resolved = resolve_language(&config(WhisperCppModel::Base, None)).unwrap();
        assert_eq!(resolved, None);
    }

    #[test]
    fn auto_language_uses_english_for_english_only_models() {
        let resolved =
            resolve_language(&config(WhisperCppModel::SmallEn, Some("auto"))).unwrap();
        assert_eq!(resolved.unwrap().to_str().unwrap(), "en");
    }
}