Skip to main content

text_whisper_cpp/
lib.rs

1//! Native whisper.cpp transcription bindings and model management.
2
3#[cfg(feature = "native")]
4mod ffi;
5
6#[cfg(feature = "native")]
7use std::ffi::CStr;
8#[cfg(any(feature = "native", test))]
9use std::ffi::CString;
10use std::fmt::{Display, Formatter};
11#[cfg(feature = "native")]
12use std::fs::File;
13#[cfg(any(feature = "native", test))]
14use std::fs::{self, OpenOptions};
15#[cfg(any(feature = "native", test))]
16use std::io::Write;
17#[cfg(feature = "native")]
18use std::io::{BufWriter, Read};
19use std::path::{Path, PathBuf};
20#[cfg(any(feature = "native", test))]
21use std::thread;
22#[cfg(any(feature = "native", test))]
23use std::time::{Duration, Instant};
24
25use serde::{Deserialize, Serialize};
26#[cfg(feature = "native")]
27use sha2::{Digest, Sha256};
28
#[derive(
    Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash, Default,
)]
/// Identifies one of the whisper.cpp models this crate knows how to name and download.
///
/// Every variant is (de)serialized under the same string that [`WhisperCppModel::id`]
/// returns, for example `"base.en"`. [`WhisperCppModel::BaseEn`] is the `Default`.
pub enum WhisperCppModel {
    #[serde(rename = "tiny.en")]
    /// The `tiny.en` model (reported as English-only by `multilingual()`).
    TinyEn,
    #[serde(rename = "tiny")]
    /// The `tiny` model.
    Tiny,
    #[serde(rename = "base.en")]
    #[default]
    /// The `base.en` model (reported as English-only by `multilingual()`); the default.
    BaseEn,
    #[serde(rename = "base")]
    /// The `base` model.
    Base,
    #[serde(rename = "small.en")]
    /// The `small.en` model (reported as English-only by `multilingual()`).
    SmallEn,
    #[serde(rename = "small")]
    /// The `small` model.
    Small,
    #[serde(rename = "medium.en")]
    /// The `medium.en` model (reported as English-only by `multilingual()`).
    MediumEn,
    #[serde(rename = "medium")]
    /// The `medium` model.
    Medium,
    #[serde(rename = "large-v1")]
    /// The `large-v1` model.
    LargeV1,
    #[serde(rename = "large-v2")]
    /// The `large-v2` model.
    LargeV2,
    #[serde(rename = "large-v3")]
    /// The `large-v3` model.
    LargeV3,
    #[serde(rename = "large-v3-turbo")]
    /// The `large-v3-turbo` model.
    LargeV3Turbo,
}
72
73impl WhisperCppModel {
74    /// Constant for all.
75    pub const ALL: [Self; 12] = [
76        Self::TinyEn,
77        Self::Tiny,
78        Self::BaseEn,
79        Self::Base,
80        Self::SmallEn,
81        Self::Small,
82        Self::MediumEn,
83        Self::Medium,
84        Self::LargeV1,
85        Self::LargeV2,
86        Self::LargeV3,
87        Self::LargeV3Turbo,
88    ];
89
90    /// Returns identifier.
91    pub fn id(self) -> &'static str {
92        match self {
93            Self::TinyEn => "tiny.en",
94            Self::Tiny => "tiny",
95            Self::BaseEn => "base.en",
96            Self::Base => "base",
97            Self::SmallEn => "small.en",
98            Self::Small => "small",
99            Self::MediumEn => "medium.en",
100            Self::Medium => "medium",
101            Self::LargeV1 => "large-v1",
102            Self::LargeV2 => "large-v2",
103            Self::LargeV3 => "large-v3",
104            Self::LargeV3Turbo => "large-v3-turbo",
105        }
106    }
107
108    /// Returns file name.
109    pub fn file_name(self) -> String {
110        format!("ggml-{}.bin", self.id())
111    }
112
113    /// Returns download URL.
114    pub fn download_url(self) -> String {
115        format!(
116            "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/{}",
117            self.file_name()
118        )
119    }
120
121    /// Returns checksum sha256.
122    pub fn checksum_sha256(self) -> &'static str {
123        match self {
124            Self::TinyEn => "0d686a2a6a22b02da2ef3101d4c86e68461363a623c58f27f81b1b2d36b42317",
125            Self::Tiny => "518970a29bedb265f23ac48d486ddbc63bedffd90967b10140ae5ac61243acf3",
126            Self::BaseEn => "a03779c86df3323075f5e796cb2ce5029f00ec8869eee3fdfb897afe36c6d002",
127            Self::Base => "2f62d18b50c3f3feafbf990eec23a93d319660b1efbdd3fff55e52b7cde2e374",
128            Self::SmallEn => "0d57184d34ae7d736e5bb2db5bf83debe730bd53dcefa235a0979b9dcfd33fb3",
129            Self::Small => "edd29d67e70b000132af65205b99bb774b77abc13d10103e14f80ce2242913e1",
130            Self::MediumEn => "a163589aa264d5188df3b05ed4eac56bfd97e26910f207809d869f7e99886fd2",
131            Self::Medium => "d3d5696e6a3e0ca2aa08eb31cad208ffa1e87b3cc341f59e628fbdcf8122de9b",
132            Self::LargeV1 => "cbcb187d1e1abe979d33636cdc63381de20738eeda0885c39440b086e184248a",
133            Self::LargeV2 => "c6d6d3dcebc5e0074175386e17eba305fc5cc7d3d5dff3ecfd11e8f2bd4222d7",
134            Self::LargeV3 => "766d11cebbdf5a67c179c5774e2642b609e35e1a30240e7b559d5647c655b0a4",
135            Self::LargeV3Turbo => {
136                "5a4b65b05933d70ce9d5aa6265eb128fa5eba38f6fee40836fdedc4d2fde42ad"
137            }
138        }
139    }
140
141    /// Returns multilingual.
142    pub fn multilingual(self) -> bool {
143        !matches!(
144            self,
145            Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn
146        )
147    }
148}
149
150impl Display for WhisperCppModel {
151    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
152        f.write_str(self.id())
153    }
154}
155
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
/// User-facing settings for one whisper.cpp transcription run.
pub struct WhisperCppConfig {
    #[serde(default)]
    /// Model to transcribe with; the default model is used when the field is absent
    /// from deserialized input.
    pub model: WhisperCppModel,
    /// Requested language. `None`, a blank string, or `auto` (any ASCII case) defer to
    /// the model: English-only models use `en`, other models pass no explicit language
    /// to whisper.cpp.
    pub language: Option<String>,
    #[serde(default)]
    /// Copied into whisper.cpp's `translate` parameter.
    pub translate: bool,
    /// Thread count handed to whisper.cpp; when `None` the machine's available
    /// parallelism is used, falling back to 4 if that cannot be queried.
    pub threads: Option<usize>,
}
170
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// One segment of a transcription.
pub struct WhisperCppSegment {
    /// Zero-based position of the segment within the transcription.
    pub index: u64,
    /// Start of the segment in seconds, when known.
    pub start_seconds: Option<f64>,
    /// End of the segment in seconds, when known.
    pub end_seconds: Option<f64>,
    /// Text of the segment.
    pub text: String,
    /// Confidence for the segment. The native transcriber in this crate fills it with
    /// the mean of the per-token probabilities, or `None` when the segment has no tokens.
    pub confidence: Option<f32>,
}
185
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// Result of transcribing one audio file.
pub struct WhisperCppTranscription {
    /// The non-empty segment texts joined with single spaces; `None` when no segment
    /// contains text.
    pub text: Option<String>,
    /// Language reported by whisper.cpp for the run, when it reports one.
    pub language: Option<String>,
    /// The individual segments, in the order whisper.cpp returned them.
    pub segments: Vec<WhisperCppSegment>,
    /// Where the transcription came from; the native transcriber stores the path of the
    /// model file it used.
    pub source: Option<String>,
}
198
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
/// Stage of a transcription run, attached to every progress event.
///
/// Serialized in `snake_case`, e.g. `downloading_model`.
pub enum WhisperCppPhase {
    /// Reported once at the start of a run, before any other work.
    Preparing,
    /// Reported while a model that is not in the store is being downloaded.
    DownloadingModel,
    /// Reported for loading the model.
    LoadingModel,
    /// Reported for running inference on the audio.
    Transcribing,
}
212
213impl WhisperCppPhase {
214    /// Borrows this value as a str.
215    pub fn as_str(self) -> &'static str {
216        match self {
217            Self::Preparing => "preparing",
218            Self::DownloadingModel => "downloading_model",
219            Self::LoadingModel => "loading_model",
220            Self::Transcribing => "transcribing",
221        }
222    }
223}
224
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// A progress notification delivered to progress callbacks during a transcription run.
pub struct WhisperCppProgressEvent {
    /// Stage of the run this event belongs to.
    pub phase: WhisperCppPhase,
    /// Human-readable description of what is happening.
    pub message: String,
    /// Completed fraction, when known. Download events carry a value clamped to
    /// `0.0..=1.0` if the server sent a `Content-Length`; other events carry `None`.
    pub progress: Option<f32>,
}
235
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
/// Availability of a single model in a model store.
pub struct WhisperCppModelStatus {
    /// The model this entry describes.
    pub model: WhisperCppModel,
    /// `true` when the model's file already exists in the store.
    pub cached: bool,
}
244
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
/// List of all known models together with their cache status.
pub struct WhisperCppCatalog {
    /// The model used when a configuration does not name one.
    pub default_model: WhisperCppModel,
    /// One entry per known model.
    pub models: Vec<WhisperCppModelStatus>,
}
253
#[derive(Debug, thiserror::Error)]
/// Errors returned by this crate.
pub enum WhisperCppError {
    #[error("I/O error: {0}")]
    /// A file-system or other I/O operation failed.
    Io(#[from] std::io::Error),
    #[error("wave input error: {0}")]
    /// The WAV input could not be opened or decoded.
    Wav(#[from] hound::Error),
    #[error("network error: {0}")]
    /// The model download failed; the payload is the underlying error rendered as text.
    Http(String),
    #[error("invalid input: {0}")]
    /// The audio input is not usable (for example, wrong sample rate or no channels).
    InvalidInput(String),
    #[error("unsupported language `{0}`")]
    /// The requested language was rejected; the payload is the language as requested.
    UnsupportedLanguage(String),
    #[error("downloaded model `{model}` failed checksum verification")]
    /// The SHA-256 digest of a downloaded model did not match the expected value.
    InvalidChecksum {
        /// The model whose download failed verification.
        model: WhisperCppModel,
    },
    #[error("failed to initialize whisper.cpp from `{0}`")]
    /// whisper.cpp could not be initialized; the payload is the model path, or an
    /// explanation when the crate was built without native support.
    Initialization(String),
    #[error("whisper.cpp inference failed for `{0}`")]
    /// whisper.cpp reported a failure while transcribing; the payload is the model path.
    Inference(String),
    #[error("invalid utf-8 returned by whisper.cpp")]
    /// A string returned by whisper.cpp was not valid UTF-8.
    InvalidUtf8,
}
288
/// Result type used throughout this crate, with [`WhisperCppError`] as the error.
pub type Result<T> = std::result::Result<T, WhisperCppError>;

// Unsized type of a progress callback that the transcriber owns (stored boxed).
type OwnedProgressCallback = dyn FnMut(WhisperCppProgressEvent) + 'static;
293
#[derive(Debug, Clone)]
/// On-disk location where whisper.cpp model files are kept.
///
/// Model files and their lock files live in the `models` directory under `root`.
pub struct ModelStore {
    // Directory that contains the `models` sub-directory.
    root: PathBuf,
}
299
300impl Default for ModelStore {
301    fn default() -> Self {
302        Self {
303            root: cache_root().join("whisper-cpp"),
304        }
305    }
306}
307
308impl ModelStore {
309    /// Creates a new value.
310    pub fn new(root: PathBuf) -> Self {
311        Self { root }
312    }
313
314    /// Returns models dir.
315    pub fn models_dir(&self) -> PathBuf {
316        self.root.join("models")
317    }
318
319    /// Returns model path.
320    pub fn model_path(&self, model: WhisperCppModel) -> PathBuf {
321        self.models_dir().join(model.file_name())
322    }
323
324    /// Returns lock path.
325    pub fn lock_path(&self, model: WhisperCppModel) -> PathBuf {
326        self.models_dir()
327            .join(format!("{}.lock", model.file_name()))
328    }
329
330    /// Returns catalog.
331    pub fn catalog(&self) -> WhisperCppCatalog {
332        WhisperCppCatalog {
333            default_model: WhisperCppModel::default(),
334            models: WhisperCppModel::ALL
335                .into_iter()
336                .map(|model| WhisperCppModelStatus {
337                    model,
338                    cached: self.model_path(model).is_file(),
339                })
340                .collect(),
341        }
342    }
343
344    #[cfg(feature = "native")]
345    fn ensure_model(
346        &self,
347        model: WhisperCppModel,
348        progress: &mut ProgressSink<'_>,
349    ) -> Result<PathBuf> {
350        fs::create_dir_all(self.models_dir())?;
351        let model_path = self.model_path(model);
352        if model_path.is_file() {
353            return Ok(model_path);
354        }
355
356        let _lock = FileLock::acquire(self.lock_path(model))?;
357        if model_path.is_file() {
358            return Ok(model_path);
359        }
360
361        progress.emit(
362            WhisperCppPhase::DownloadingModel,
363            format!("downloading whisper.cpp model `{model}`"),
364            Some(0.0),
365        );
366
367        let temp_path = model_path.with_extension("bin.part");
368        if temp_path.exists() {
369            let _ = fs::remove_file(&temp_path);
370        }
371
372        let response = ureq::get(&model.download_url())
373            .call()
374            .map_err(|error| WhisperCppError::Http(error.to_string()))?;
375        let total_bytes = response
376            .header("Content-Length")
377            .and_then(|value| value.parse::<u64>().ok());
378        let mut reader = response.into_reader();
379        let mut file = BufWriter::new(File::create(&temp_path)?);
380        let mut hasher = Sha256::new();
381        let mut downloaded = 0_u64;
382        let mut buffer = [0_u8; 64 * 1024];
383
384        loop {
385            let read = reader
386                .read(&mut buffer)
387                .map_err(|error| WhisperCppError::Http(error.to_string()))?;
388            if read == 0 {
389                break;
390            }
391            file.write_all(&buffer[..read])?;
392            hasher.update(&buffer[..read]);
393            downloaded += read as u64;
394            let fraction =
395                total_bytes.map(|total| (downloaded as f32 / total as f32).clamp(0.0, 1.0));
396            progress.emit(
397                WhisperCppPhase::DownloadingModel,
398                format!("downloading whisper.cpp model `{model}`"),
399                fraction,
400            );
401        }
402        file.flush()?;
403
404        let checksum = format!("{:x}", hasher.finalize());
405        if checksum != model.checksum_sha256() {
406            let _ = fs::remove_file(&temp_path);
407            return Err(WhisperCppError::InvalidChecksum { model });
408        }
409
410        fs::rename(temp_path, &model_path)?;
411        Ok(model_path)
412    }
413}
414
/// Transcribes audio files with whisper.cpp using a fixed configuration.
pub struct WhisperCppTranscriber {
    // Settings applied to every transcription.
    config: WhisperCppConfig,
    // Where model files are looked up and downloaded to.
    store: ModelStore,
    // Optional callback registered through `on_progress`.
    progress: Option<Box<OwnedProgressCallback>>,
}
421
422impl WhisperCppTranscriber {
423    /// Creates a new value.
424    pub fn new(config: WhisperCppConfig) -> Self {
425        Self {
426            config,
427            store: ModelStore::default(),
428            progress: None,
429        }
430    }
431
432    /// Returns this value with model store.
433    pub fn with_model_store(mut self, store: ModelStore) -> Self {
434        self.store = store;
435        self
436    }
437
438    /// Returns on progress.
439    pub fn on_progress<F>(mut self, callback: F) -> Self
440    where
441        F: FnMut(WhisperCppProgressEvent) + 'static,
442    {
443        self.progress = Some(Box::new(callback));
444        self
445    }
446
447    /// Returns transcribe file.
448    pub fn transcribe_file(&mut self, input: &Path) -> Result<WhisperCppTranscription> {
449        let store = self.store.clone();
450        let config = self.config.clone();
451        let mut progress = ProgressSink::new(self.progress_deref_mut());
452        transcribe_impl(&store, &config, input, &mut progress)
453    }
454
455    /// Returns transcribe file with progress.
456    pub fn transcribe_file_with_progress(
457        &mut self,
458        input: &Path,
459        progress: &mut dyn FnMut(WhisperCppProgressEvent),
460    ) -> Result<WhisperCppTranscription> {
461        let mut progress = ProgressSink::new(Some(progress));
462        transcribe_impl(&self.store, &self.config, input, &mut progress)
463    }
464
465    fn progress_deref_mut(&mut self) -> Option<&mut dyn FnMut(WhisperCppProgressEvent)> {
466        self.progress
467            .as_mut()
468            .map(|callback| callback.as_mut() as &mut dyn FnMut(WhisperCppProgressEvent))
469    }
470}
471
472/// Returns transcription catalog.
473pub fn transcription_catalog() -> WhisperCppCatalog {
474    ModelStore::default().catalog()
475}
476
/// Returns the system information string reported by whisper.cpp.
///
/// Yields `None` when the crate was built without the `native` feature, when
/// whisper.cpp returns a null pointer, or when the string is not valid UTF-8.
pub fn whisper_cpp_system_info() -> Option<String> {
    #[cfg(not(feature = "native"))]
    {
        None
    }

    #[cfg(feature = "native")]
    {
        let raw = unsafe { ffi::whisper_print_system_info() };
        if raw.is_null() {
            return None;
        }
        // SAFETY: `raw` is non-null; it is assumed to point at a NUL-terminated string
        // that stays valid for the duration of this call — NOTE(review): confirm against
        // the whisper.cpp header.
        let info = unsafe { CStr::from_ptr(raw) };
        info.to_str().ok().map(str::to_owned)
    }
}
496
#[cfg(feature = "native")]
/// Runs a complete whisper.cpp transcription of the WAV file at `input`.
///
/// In order: makes sure the configured model is present in `store` (downloading it when
/// missing), decodes the audio, loads the model, runs inference, and collects the
/// segments and the reported language. Each progress event is emitted immediately before
/// the work it names, so `LoadingModel` precedes loading the model and `Transcribing`
/// precedes inference.
///
/// # Errors
///
/// Propagates model-store and WAV-decoding errors and additionally returns
/// `InvalidInput` when the audio has more samples than fit in an `i32`,
/// `Initialization` when the model cannot be loaded, `UnsupportedLanguage` when
/// whisper.cpp does not know the requested language, `Inference` when whisper.cpp
/// reports a non-zero status, and `InvalidUtf8` when it returns a non-UTF-8 string.
fn transcribe_impl(
    store: &ModelStore,
    config: &WhisperCppConfig,
    input: &Path,
    progress: &mut ProgressSink<'_>,
) -> Result<WhisperCppTranscription> {
    let model = config.model;
    progress.emit(
        WhisperCppPhase::Preparing,
        format!(
            "preparing native whisper.cpp transcription for {}",
            input.display()
        ),
        None,
    );

    let model_path = store.ensure_model(model, progress)?;

    let audio = read_wav_mono_f32(input)?;
    // The sample count is passed to whisper.cpp as an `i32`; reject anything larger
    // instead of letting the cast wrap.
    if audio.samples.len() > i32::MAX as usize {
        return Err(WhisperCppError::InvalidInput(format!(
            "audio has {} samples, more than whisper.cpp accepts in a single call",
            audio.samples.len()
        )));
    }
    let sample_count = audio.samples.len() as i32;

    progress.emit(
        WhisperCppPhase::LoadingModel,
        format!("loading whisper.cpp model `{model}`"),
        None,
    );
    let context = WhisperContext::from_model(&model_path)?;

    progress.emit(
        WhisperCppPhase::Transcribing,
        format!("transcribing audio with whisper.cpp model `{model}`"),
        None,
    );

    // SAFETY: FFI call that only takes the sampling strategy by value.
    let mut params = unsafe {
        ffi::whisper_full_default_params(ffi::whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY)
    };
    params.n_threads = resolve_threads(config.threads);
    params.translate = config.translate;
    params.print_progress = false;
    params.print_realtime = false;
    params.print_special = false;
    params.print_timestamps = false;
    params.no_timestamps = false;

    // `language` owns the C string that `params.language` points at, so it must stay
    // alive until `whisper_full` has returned; it does, because it lives to the end of
    // this function.
    let language = resolve_language(config)?;
    if let Some(language) = language.as_ref() {
        // SAFETY: `language` is a valid NUL-terminated C string for the whole call.
        let lang_id = unsafe { ffi::whisper_lang_id(language.as_ptr()) };
        if lang_id < 0 {
            return Err(WhisperCppError::UnsupportedLanguage(
                language.to_string_lossy().into_owned(),
            ));
        }
        params.language = language.as_ptr();
    } else {
        params.language = std::ptr::null();
    }
    params.detect_language = false;

    // SAFETY: `context.raw` was checked to be non-null by `WhisperContext::from_model`
    // and is only freed when `context` is dropped; the sample pointer and count describe
    // the live `audio.samples` buffer.
    let status = unsafe {
        ffi::whisper_full(
            context.raw,
            params,
            audio.samples.as_ptr(),
            sample_count,
        )
    };
    if status != 0 {
        return Err(WhisperCppError::Inference(model_path.display().to_string()));
    }

    // SAFETY (all accessor calls below): `context.raw` is the live context that was just
    // used for inference, and indices stay within the counts whisper.cpp reported.
    let segment_count = unsafe { ffi::whisper_full_n_segments(context.raw) };
    let mut segments = Vec::with_capacity(segment_count.max(0) as usize);
    for index in 0..segment_count {
        let text_ptr = unsafe { ffi::whisper_full_get_segment_text(context.raw, index) };
        let text = c_string(text_ptr)?.trim().to_string();
        let start = unsafe { ffi::whisper_full_get_segment_t0(context.raw, index) };
        let end = unsafe { ffi::whisper_full_get_segment_t1(context.raw, index) };
        let token_count = unsafe { ffi::whisper_full_n_tokens(context.raw, index) };
        // Segment confidence is the mean token probability; undefined without tokens.
        let confidence = if token_count > 0 {
            let mut total = 0.0_f32;
            for token_index in 0..token_count {
                total += unsafe { ffi::whisper_full_get_token_p(context.raw, index, token_index) };
            }
            Some(total / token_count as f32)
        } else {
            None
        };
        segments.push(WhisperCppSegment {
            index: index as u64,
            start_seconds: Some(timestamp_to_seconds(start)),
            end_seconds: Some(timestamp_to_seconds(end)),
            text,
            confidence,
        });
    }

    let language = unsafe { ffi::whisper_full_lang_id(context.raw) };
    let language = if language >= 0 {
        Some(c_string(unsafe { ffi::whisper_lang_str(language) })?)
    } else {
        None
    };
    let text = join_segments(&segments);

    Ok(WhisperCppTranscription {
        text,
        language,
        segments,
        source: Some(model_path.to_string_lossy().into_owned()),
    })
}
607
608#[cfg(not(feature = "native"))]
609fn transcribe_impl(
610    _store: &ModelStore,
611    _config: &WhisperCppConfig,
612    _input: &Path,
613    _progress: &mut ProgressSink<'_>,
614) -> Result<WhisperCppTranscription> {
615    Err(WhisperCppError::Initialization(
616        "text-whisper-cpp was built without the `native` feature".to_string(),
617    ))
618}
619
#[cfg(any(feature = "native", test))]
/// Picks the language to hand to whisper.cpp for `config`.
///
/// An explicitly configured language (after trimming) is used as-is. A missing or blank
/// language, or `auto` in any ASCII case, falls back to the model's default language.
/// A language containing an interior NUL byte yields `UnsupportedLanguage`.
fn resolve_language(config: &WhisperCppConfig) -> Result<Option<CString>> {
    let explicit = config
        .language
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty() && !value.eq_ignore_ascii_case("auto"));

    if let Some(value) = explicit {
        return CString::new(value)
            .map(Some)
            .map_err(|_| WhisperCppError::UnsupportedLanguage(value.to_string()));
    }
    resolve_default_language(config.model)
}
631
#[cfg(any(feature = "native", test))]
/// Returns the language used when none is configured: `en` for models that are not
/// multilingual, and `None` (no explicit language) for multilingual models.
fn resolve_default_language(model: WhisperCppModel) -> Result<Option<CString>> {
    if !model.multilingual() {
        return CString::new("en")
            .map(Some)
            .map_err(|_| WhisperCppError::UnsupportedLanguage("en".to_string()));
    }
    Ok(None)
}
642
// Without the `native` feature nothing reads the callback, hence the `dead_code` allowance.
#[cfg_attr(not(feature = "native"), allow(dead_code))]
/// Destination for progress events: forwards them to a borrowed callback when one is
/// present.
struct ProgressSink<'a> {
    // Callback that receives every event; `None` means events are not delivered.
    callback: Option<&'a mut dyn FnMut(WhisperCppProgressEvent)>,
}
647
648impl<'a> ProgressSink<'a> {
649    fn new(callback: Option<&'a mut dyn FnMut(WhisperCppProgressEvent)>) -> Self {
650        Self { callback }
651    }
652
653    #[cfg(feature = "native")]
654    fn emit(&mut self, phase: WhisperCppPhase, message: String, progress: Option<f32>) {
655        if let Some(callback) = self.callback.as_mut() {
656            callback(WhisperCppProgressEvent {
657                phase,
658                message,
659                progress,
660            });
661        }
662    }
663}
664
#[cfg(feature = "native")]
/// Reads the WAV file at `path` as mono `f32` samples.
///
/// The file must have at least one channel and a 16 kHz sample rate. Integer samples are
/// scaled by `read_int_samples`; multi-channel audio is reduced to mono by averaging the
/// samples of each frame.
///
/// # Errors
///
/// Returns `Wav` when the file cannot be opened or decoded and `InvalidInput` when the
/// channel count is zero or the sample rate is not 16 kHz.
fn read_wav_mono_f32(path: &Path) -> Result<AudioSamples> {
    let mut reader = hound::WavReader::open(path)?;
    let spec = reader.spec();
    let channels = spec.channels as usize;

    if channels == 0 {
        return Err(WhisperCppError::InvalidInput(
            "wav file has no channels".to_string(),
        ));
    }
    if spec.sample_rate != 16_000 {
        return Err(WhisperCppError::InvalidInput(format!(
            "expected 16 kHz wav input, got {} Hz",
            spec.sample_rate
        )));
    }

    let interleaved: Vec<f32> = match spec.sample_format {
        hound::SampleFormat::Float => {
            let decoded: std::result::Result<Vec<f32>, hound::Error> =
                reader.samples::<f32>().collect();
            decoded?
        }
        hound::SampleFormat::Int => read_int_samples(&mut reader, spec.bits_per_sample)?,
    };

    // Mono input is already in the target layout.
    if channels == 1 {
        return Ok(AudioSamples {
            samples: interleaved,
        });
    }

    // Down-mix: one output sample per frame, the mean of that frame's channel samples.
    let samples = interleaved
        .chunks(channels)
        .map(|frame| {
            let sum: f32 = frame.iter().copied().sum();
            sum / frame.len() as f32
        })
        .collect();
    Ok(AudioSamples { samples })
}
700
#[cfg(feature = "native")]
/// Reads all integer samples from `reader` and scales them to `f32`.
///
/// Each sample is divided by `2^(bits_per_sample - 1) - 1`. Samples of up to 16 bits are
/// read as `i16`, wider samples as `i32`.
///
/// # Errors
///
/// Returns `Wav` when a sample cannot be decoded.
fn read_int_samples(
    reader: &mut hound::WavReader<std::io::BufReader<File>>,
    bits_per_sample: u16,
) -> Result<Vec<f32>> {
    let shift = u32::from(bits_per_sample.saturating_sub(1));
    let scale = ((1_i64 << shift) - 1) as f32;

    let decoded: std::result::Result<Vec<f32>, hound::Error> = if bits_per_sample <= 16 {
        reader
            .samples::<i16>()
            .map(|sample| sample.map(|value| value as f32 / scale))
            .collect()
    } else {
        reader
            .samples::<i32>()
            .map(|sample| sample.map(|value| value as f32 / scale))
            .collect()
    };
    Ok(decoded?)
}
719
720#[cfg(feature = "native")]
/// Resolves the thread count handed to whisper.cpp.
///
/// Uses `value` when given, otherwise the machine's available parallelism, otherwise 4.
/// The result is always within `1..=i32::MAX`: a request for zero threads is raised to
/// one, and oversized requests are capped.
fn resolve_threads(value: Option<usize>) -> i32 {
    let requested = value
        .or_else(|| thread::available_parallelism().ok().map(usize::from))
        .unwrap_or(4);
    // After clamping, the value is known to fit in an `i32`, so the cast cannot truncate.
    requested.clamp(1, i32::MAX as usize) as i32
}
727
728#[cfg(feature = "native")]
/// Converts a timestamp counted in hundredths of a second into seconds.
fn timestamp_to_seconds(value: i64) -> f64 {
    const TICKS_PER_SECOND: f64 = 100.0;
    let ticks = value as f64;
    ticks / TICKS_PER_SECOND
}
732
#[cfg(feature = "native")]
/// Joins the trimmed, non-empty segment texts with single spaces.
///
/// Returns `None` when no segment contributes any text.
fn join_segments(segments: &[WhisperCppSegment]) -> Option<String> {
    let mut joined = String::new();
    let pieces = segments
        .iter()
        .map(|segment| segment.text.trim())
        .filter(|piece| !piece.is_empty());
    for piece in pieces {
        // Every piece is non-empty, so `joined` is empty only before the first one.
        if !joined.is_empty() {
            joined.push(' ');
        }
        joined.push_str(piece);
    }

    if joined.is_empty() {
        None
    } else {
        Some(joined)
    }
}
743
#[cfg(feature = "native")]
/// Copies a C string returned by whisper.cpp into an owned `String`.
///
/// A null pointer yields an empty string; bytes that are not valid UTF-8 yield
/// `InvalidUtf8`.
fn c_string(value: *const std::ffi::c_char) -> Result<String> {
    if value.is_null() {
        return Ok(String::new());
    }
    // SAFETY: `value` is non-null; callers are trusted to pass a NUL-terminated string
    // that stays valid for the duration of this call — NOTE(review): confirm at each
    // call site.
    let raw = unsafe { CStr::from_ptr(value) };
    match raw.to_str() {
        Ok(text) => Ok(text.to_owned()),
        Err(_) => Err(WhisperCppError::InvalidUtf8),
    }
}
754
/// Returns the directory under which this application caches data.
///
/// The first of these that applies wins:
/// 1. `VIDEO_ANALYSIS_STUDIO_CACHE_DIR`, used as-is;
/// 2. `XDG_CACHE_HOME`, with `video-analysis-studio` appended;
/// 3. on Windows, `LOCALAPPDATA`, with `video-analysis-studio` appended;
/// 4. `HOME`, with `.cache/video-analysis-studio` appended;
/// 5. the relative path `.cache/video-analysis-studio`.
///
/// A variable that is set to the empty string is treated as unset, so it can never
/// produce a path relative to the current directory by accident.
fn cache_root() -> PathBuf {
    let non_empty = |name: &str| std::env::var_os(name).filter(|value| !value.is_empty());

    if let Some(dir) = non_empty("VIDEO_ANALYSIS_STUDIO_CACHE_DIR") {
        return PathBuf::from(dir);
    }
    if let Some(dir) = non_empty("XDG_CACHE_HOME") {
        return PathBuf::from(dir).join("video-analysis-studio");
    }
    if cfg!(target_os = "windows") {
        if let Some(dir) = non_empty("LOCALAPPDATA") {
            return PathBuf::from(dir).join("video-analysis-studio");
        }
    }
    if let Some(home) = non_empty("HOME") {
        return PathBuf::from(home)
            .join(".cache")
            .join("video-analysis-studio");
    }
    PathBuf::from(".cache/video-analysis-studio")
}
774
#[cfg(feature = "native")]
/// Decoded audio ready to be handed to whisper.cpp.
struct AudioSamples {
    // Mono samples as produced by `read_wav_mono_f32`.
    samples: Vec<f32>,
}
779
#[cfg(feature = "native")]
/// Owning wrapper around a raw whisper.cpp context; the context is freed on drop.
struct WhisperContext {
    // Pointer returned by whisper.cpp; `from_model` only constructs the wrapper when it
    // is non-null.
    raw: *mut ffi::whisper_context,
}
784
#[cfg(feature = "native")]
impl WhisperContext {
    /// Loads the model file at `path` into a new whisper.cpp context.
    ///
    /// The GPU is requested only on macOS and flash attention is disabled.
    ///
    /// # Errors
    ///
    /// Returns `Initialization` (carrying the displayed path) when the path contains a
    /// NUL byte or whisper.cpp returns a null context.
    fn from_model(path: &Path) -> Result<Self> {
        let init_error = || WhisperCppError::Initialization(path.display().to_string());

        let lossy_path = path.to_string_lossy().into_owned();
        let model_path = match CString::new(lossy_path) {
            Ok(c_path) => c_path,
            Err(_) => return Err(init_error()),
        };

        let mut params = unsafe { ffi::whisper_context_default_params() };
        params.use_gpu = cfg!(target_os = "macos");
        params.flash_attn = false;

        // SAFETY: `model_path` is a valid NUL-terminated C string that outlives the call.
        let raw = unsafe { ffi::whisper_init_from_file_with_params(model_path.as_ptr(), params) };
        if raw.is_null() {
            Err(init_error())
        } else {
            Ok(Self { raw })
        }
    }
}
800
#[cfg(feature = "native")]
impl Drop for WhisperContext {
    /// Frees the wrapped whisper.cpp context unless the pointer is null.
    fn drop(&mut self) {
        if self.raw.is_null() {
            return;
        }
        // SAFETY: `raw` is non-null and this wrapper is being dropped, so the pointer is
        // not used again afterwards.
        unsafe { ffi::whisper_free(self.raw) };
    }
}
809
#[cfg(any(feature = "native", test))]
/// Lock represented by the existence of a file: acquiring creates the file, dropping the
/// guard removes it.
struct FileLock {
    // Path of the lock file owned by this guard.
    path: PathBuf,
}
814
#[cfg(any(feature = "native", test))]
impl FileLock {
    /// Acquires the lock by exclusively creating the file at `path`.
    ///
    /// While the file already exists the attempt is repeated every 250 ms for up to
    /// 120 seconds. On success the current process id is written into the file on a
    /// best-effort basis.
    ///
    /// # Errors
    ///
    /// Returns `Io` when the file still exists once the deadline has passed, or
    /// immediately for any other failure to create it.
    fn acquire(path: PathBuf) -> Result<Self> {
        let timeout = Duration::from_secs(120);
        let retry_delay = Duration::from_millis(250);
        let deadline = Instant::now() + timeout;

        loop {
            let attempt = OpenOptions::new().create_new(true).write(true).open(&path);
            let error = match attempt {
                Ok(mut file) => {
                    // The pid is informational only; failing to record it is not fatal.
                    let _ = writeln!(file, "{}", std::process::id());
                    return Ok(Self { path });
                }
                Err(error) => error,
            };

            // Only "file already exists" means the lock is held and is worth waiting for.
            let lock_is_held = error.kind() == std::io::ErrorKind::AlreadyExists;
            if !lock_is_held || Instant::now() >= deadline {
                return Err(WhisperCppError::Io(error));
            }
            thread::sleep(retry_delay);
        }
    }
}
836
#[cfg(any(feature = "native", test))]
impl Drop for FileLock {
    /// Releases the lock by deleting the lock file; a failed deletion is ignored.
    fn drop(&mut self) {
        fs::remove_file(&self.path).ok();
    }
}
843
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Store rooted at a fixed path; these tests only compute paths from it.
    fn test_store() -> ModelStore {
        ModelStore::new(PathBuf::from("/tmp/video-analysis-studio-test"))
    }

    /// Configuration with the given model and language and default remaining fields.
    fn config_for(model: WhisperCppModel, language: Option<&str>) -> WhisperCppConfig {
        WhisperCppConfig {
            model,
            language: language.map(str::to_string),
            ..WhisperCppConfig::default()
        }
    }

    #[test]
    fn model_metadata_matches_expected_file_names() {
        let file_name = WhisperCppModel::BaseEn.file_name();
        assert_eq!(file_name, "ggml-base.en.bin");

        let url = WhisperCppModel::LargeV3Turbo.download_url();
        assert_eq!(
            url,
            "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo.bin"
        );
    }

    #[test]
    fn catalog_uses_base_en_by_default() {
        let catalog = test_store().catalog();
        assert_eq!(catalog.default_model, WhisperCppModel::BaseEn);
        assert_eq!(catalog.models.len(), WhisperCppModel::ALL.len());
    }

    #[test]
    fn cache_paths_are_stable() {
        let store = test_store();
        let model = WhisperCppModel::SmallEn;

        assert_eq!(
            store.model_path(model),
            PathBuf::from("/tmp/video-analysis-studio-test/models/ggml-small.en.bin")
        );
        assert_eq!(
            store.lock_path(model),
            PathBuf::from("/tmp/video-analysis-studio-test/models/ggml-small.en.bin.lock")
        );
    }

    #[test]
    fn file_lock_creates_and_releases_lock_path() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("model.lock");

        let lock = FileLock::acquire(path.clone()).unwrap();
        assert!(path.is_file());

        // Dropping the guard must delete the lock file.
        drop(lock);
        assert!(!path.exists());
    }

    #[test]
    fn english_only_models_default_to_english() {
        let config = config_for(WhisperCppModel::BaseEn, None);

        let language = resolve_language(&config).unwrap().unwrap();
        assert_eq!(language.to_str().unwrap(), "en");
    }

    #[test]
    fn multilingual_models_default_to_auto_detection_without_detect_only_mode() {
        let config = config_for(WhisperCppModel::Base, None);

        assert_eq!(resolve_language(&config).unwrap(), None);
    }

    #[test]
    fn auto_language_uses_english_for_english_only_models() {
        let config = config_for(WhisperCppModel::SmallEn, Some("auto"));

        let language = resolve_language(&config).unwrap().unwrap();
        assert_eq!(language.to_str().unwrap(), "en");
    }
}