whisper_stream_rs/
model.rs

use std::ffi::OsString;
use std::fmt;
use std::fs;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::str::FromStr;

use log::info;

use crate::error::WhisperStreamError;

#[cfg(feature = "coreml")]
use std::fs::File;

#[cfg(feature = "coreml")]
use log::warn;
#[cfg(feature = "coreml")]
use zip::ZipArchive;
15
/// Supported Whisper models.
///
/// Each variant maps to one ggml checkpoint: see [`Model::file_name`] for the
/// on-disk name and [`Model::url`] for where it is downloaded from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Model {
    /// The default model: base.en
    BaseEn,
    /// The tiny.en model
    TinyEn,
    /// The small.en model
    SmallEn,
}

impl Model {
    /// Every supported model, in the order reported by [`Model::list`].
    ///
    /// Single source of truth for both `list()` and string parsing, so adding a
    /// variant here is enough to make it listable and parseable.
    const ALL: [Model; 3] = [Model::BaseEn, Model::TinyEn, Model::SmallEn];

    /// Returns the user-facing name for this model (e.g., "base.en").
    ///
    /// This is exactly the string accepted when parsing a `Model` with `str::parse`.
    pub fn name(&self) -> &'static str {
        match self {
            Model::BaseEn => "base.en",
            Model::TinyEn => "tiny.en",
            Model::SmallEn => "small.en",
        }
    }

    /// Returns the model file name (e.g., "ggml-base.en.bin").
    pub fn file_name(&self) -> &'static str {
        match self {
            Model::BaseEn => "ggml-base.en.bin",
            Model::TinyEn => "ggml-tiny.en.bin",
            Model::SmallEn => "ggml-small.en.bin",
        }
    }

    /// Returns the model download URL.
    pub fn url(&self) -> &'static str {
        match self {
            Model::BaseEn => "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin",
            Model::TinyEn => "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin",
            Model::SmallEn => "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin",
        }
    }

    /// Returns all supported models.
    pub fn list() -> Vec<Model> {
        Self::ALL.to_vec()
    }
}

impl fmt::Display for Model {
    /// Writes the model's user-facing name.
    ///
    /// Uses `Formatter::pad` so width/alignment/precision flags are honoured,
    /// e.g. `format!("{:>10}", Model::TinyEn)` yields `"   tiny.en"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}

impl FromStr for Model {
    type Err = ();

    /// Parses a user-facing model name, exactly as returned by `Model::name`.
    ///
    /// Matching is case-sensitive; any unrecognised name yields `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::ALL.iter().copied().find(|m| m.name() == s).ok_or(())
    }
}
75
76
77#[cfg(feature = "coreml")]
78const COREML_MODEL_URL_TEMPLATE: &str = "https://models.milan.place/whisper-cpp/metal//{}-encoder.mlmodelc.zip";
79#[cfg(feature = "coreml")]
80const BASE_MODEL_NAME_FOR_COREML: &str = "ggml-base.en"; // Corresponds to ggml-base.en.bin
81
82/// Ensures the Whisper model (and CoreML model if 'coreml' feature is enabled) is present, downloading if necessary.
83pub fn ensure_model(model: Model) -> Result<PathBuf, WhisperStreamError> {
84    let cache_dir = dirs::data_local_dir()
85        .ok_or_else(|| WhisperStreamError::Io {
86            source: io::Error::new(io::ErrorKind::NotFound, "Could not find local data dir")
87        })?
88        .join("whisper-stream-rs");
89
90    fs::create_dir_all(&cache_dir).map_err(WhisperStreamError::from)?;
91
92    let model_path = cache_dir.join(model.file_name());
93
94    if !model_path.exists() {
95        info!("Downloading Whisper model to {}...", model_path.display());
96        download_file(model.url(), &model_path)?;
97        info!("Whisper model downloaded.");
98    }
99
100    #[cfg(feature = "coreml")]
101    {
102        ensure_coreml_model_if_enabled(&cache_dir)?;
103    }
104
105    Ok(model_path) // Return path to the main .bin model
106}
107
108#[cfg(feature = "coreml")]
109fn ensure_coreml_model_if_enabled(cache_dir: &Path) -> Result<(), WhisperStreamError> {
110    info!("CoreML feature enabled. Checking for CoreML model...");
111    let coreml_base_name = BASE_MODEL_NAME_FOR_COREML;
112    let coreml_encoder_dir_name = format!("{}-encoder.mlmodelc", coreml_base_name);
113    let coreml_model_dir_path = cache_dir.join(&coreml_encoder_dir_name);
114
115    if !coreml_model_dir_path.exists() {
116        let coreml_model_zip_url = COREML_MODEL_URL_TEMPLATE.replace("{}", coreml_base_name);
117        let coreml_zip_filename = format!("{}-encoder.mlmodelc.zip", coreml_base_name);
118        let coreml_zip_path = cache_dir.join(&coreml_zip_filename);
119
120        info!("Downloading CoreML model from {} to {}...", coreml_model_zip_url, coreml_zip_path.display());
121        download_file(&coreml_model_zip_url, &coreml_zip_path)?;
122        info!("CoreML model ZIP downloaded.");
123
124        info!("Unzipping CoreML model to {}...", cache_dir.display());
125        if let Err(e) = unzip_file(&coreml_zip_path, &cache_dir) {
126            // Attempt to clean up the potentially corrupted zip file or partial extraction
127            if let Err(remove_err) = fs::remove_file(&coreml_zip_path) {
128                warn!("Failed to remove zip file {} during cleanup: {}", coreml_zip_path.display(), remove_err);
129            }
130            if let Err(remove_dir_err) = fs::remove_dir_all(&coreml_model_dir_path) {
131                warn!("Failed to remove directory {} during cleanup: {}", coreml_model_dir_path.display(), remove_dir_err);
132            }
133            // The error is returned from this function, so no need for error! here, caller handles it.
134            return Err(e);
135        }
136        info!("CoreML model unzipped and available at {}.", coreml_model_dir_path.display());
137
138        // Clean up the downloaded zip file after successful extraction
139        if fs::remove_file(&coreml_zip_path).is_err() {
140            warn!("Could not remove CoreML zip file: {}", coreml_zip_path.display());
141        }
142    } else {
143        info!("CoreML model already present at {}.", coreml_model_dir_path.display());
144    }
145    Ok(())
146}
147
148fn download_file(url: &str, path: &Path) -> Result<(), WhisperStreamError> {
149    let mut resp = reqwest::blocking::get(url)
150        .map_err(|e| WhisperStreamError::ModelFetch(format!("Failed to initiate download from {}: {}", url, e)))?;
151
152    if !resp.status().is_success() {
153        return Err(WhisperStreamError::ModelFetch(format!("Failed to download from {}: HTTP Status {}", url, resp.status())));
154    }
155
156    let mut out = fs::File::create(path)
157        .map_err(|e| WhisperStreamError::Io { source: e })?;
158
159    io::copy(&mut resp, &mut out)
160        .map_err(|e| WhisperStreamError::Io { source: e })?;
161
162    out.flush().map_err(|e| WhisperStreamError::Io { source: e })?;
163    Ok(())
164}
165
/// Extracts every entry of the ZIP archive at `zip_path` into `dest_dir`.
///
/// Entries whose names end in `/` are created as directories. Files are written
/// (overwriting any existing file) after creating their parent directories.
/// Entries whose path would escape `dest_dir` (per `ZipFile::enclosed_name`, e.g.
/// `../x`) are skipped with a warning rather than extracted.
///
/// # Errors
/// - `WhisperStreamError::ModelFetch` if the archive cannot be opened or an entry
///   cannot be read from it.
/// - `WhisperStreamError::Io` for filesystem failures while extracting.
#[cfg(feature = "coreml")]
fn unzip_file(zip_path: &Path, dest_dir: &Path) -> Result<(), WhisperStreamError> {
    let file = File::open(zip_path).map_err(|e| WhisperStreamError::Io { source: e })?;
    let mut archive = ZipArchive::new(file).map_err(|e| WhisperStreamError::ModelFetch(format!("Failed to open zip archive '{}': {}", zip_path.display(), e)))?;

    for i in 0..archive.len() {
        let mut file_in_zip = archive.by_index(i).map_err(|e| WhisperStreamError::ModelFetch(format!("Failed to access file in zip '{}': {}", zip_path.display(), e)))?;
        let outpath = match file_in_zip.enclosed_name() {
            Some(path) => dest_dir.join(path),
            None => {
                // Refuse path-traversal entries, but make the skip visible.
                warn!("Skipping zip entry with unsafe path '{}' in '{}'", file_in_zip.name(), zip_path.display());
                continue;
            }
        };

        if file_in_zip.name().ends_with('/') {
            fs::create_dir_all(&outpath).map_err(|e| WhisperStreamError::Io { source: e })?;
            continue;
        }

        // Archives may list a file before (or without) its parent directory entry.
        // create_dir_all is a no-op when the directory already exists.
        if let Some(parent) = outpath.parent() {
            fs::create_dir_all(parent).map_err(|e| WhisperStreamError::Io { source: e })?;
        }
        let mut outfile = fs::File::create(&outpath).map_err(|e| WhisperStreamError::Io { source: e })?;
        io::copy(&mut file_in_zip, &mut outfile).map_err(|e| WhisperStreamError::Io { source: e })?;
    }
    Ok(())
}