whisper_stream_rs/
model.rs1use std::path::{PathBuf, Path};
2use std::fs;
3use std::io::{self, Write};
4use crate::error::WhisperStreamError;
5use log::{info};
6use std::fmt;
7use std::str::FromStr;
8
9#[cfg(feature = "coreml")]
10use zip::ZipArchive;
11#[cfg(feature = "coreml")]
12use std::fs::File;
13#[cfg(feature = "coreml")]
14use log::{warn};
15
/// The Whisper ggml checkpoints this module knows how to fetch.
///
/// Each variant maps to a `.en` checkpoint; the methods on `Model` give its
/// short name, on-disk file name and download URL. The type is `Copy` and
/// `Hash`, so it can be passed by value and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Model {
    /// The `base.en` checkpoint.
    BaseEn,
    /// The `tiny.en` checkpoint.
    TinyEn,
    /// The `small.en` checkpoint.
    SmallEn,
}
26
27impl Model {
28 pub fn name(&self) -> &'static str {
30 match self {
31 Model::BaseEn => "base.en",
32 Model::TinyEn => "tiny.en",
33 Model::SmallEn => "small.en",
34 }
35 }
36 pub fn file_name(&self) -> &'static str {
38 match self {
39 Model::BaseEn => "ggml-base.en.bin",
40 Model::TinyEn => "ggml-tiny.en.bin",
41 Model::SmallEn => "ggml-small.en.bin",
42 }
43 }
44 pub fn url(&self) -> &'static str {
46 match self {
47 Model::BaseEn => "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin",
48 Model::TinyEn => "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin",
49 Model::SmallEn => "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin",
50 }
51 }
52 pub fn list() -> Vec<Model> {
54 vec![Model::BaseEn, Model::TinyEn, Model::SmallEn]
55 }
56}
57
58impl fmt::Display for Model {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 write!(f, "{}", self.name())
61 }
62}
63
64impl FromStr for Model {
65 type Err = ();
66 fn from_str(s: &str) -> Result<Self, Self::Err> {
67 match s {
68 "base.en" => Ok(Model::BaseEn),
69 "tiny.en" => Ok(Model::TinyEn),
70 "small.en" => Ok(Model::SmallEn),
71 _ => Err(()),
72 }
73 }
74}
75
76
77#[cfg(feature = "coreml")]
78const COREML_MODEL_URL_TEMPLATE: &str = "https://models.milan.place/whisper-cpp/metal//{}-encoder.mlmodelc.zip";
79#[cfg(feature = "coreml")]
80const BASE_MODEL_NAME_FOR_COREML: &str = "ggml-base.en"; pub fn ensure_model(model: Model) -> Result<PathBuf, WhisperStreamError> {
84 let cache_dir = dirs::data_local_dir()
85 .ok_or_else(|| WhisperStreamError::Io {
86 source: io::Error::new(io::ErrorKind::NotFound, "Could not find local data dir")
87 })?
88 .join("whisper-stream-rs");
89
90 fs::create_dir_all(&cache_dir).map_err(WhisperStreamError::from)?;
91
92 let model_path = cache_dir.join(model.file_name());
93
94 if !model_path.exists() {
95 info!("Downloading Whisper model to {}...", model_path.display());
96 download_file(model.url(), &model_path)?;
97 info!("Whisper model downloaded.");
98 }
99
100 #[cfg(feature = "coreml")]
101 {
102 ensure_coreml_model_if_enabled(&cache_dir)?;
103 }
104
105 Ok(model_path) }
107
/// Makes sure the CoreML encoder directory `<base>-encoder.mlmodelc` exists
/// in `cache_dir`, downloading and extracting its ZIP archive if it does not.
///
/// After a successful extraction the archive is deleted; a failed delete is
/// only logged. If extraction fails, both the archive and the encoder
/// directory are removed on a best-effort basis (failures logged) and the
/// extraction error is returned.
///
/// NOTE(review): presence is judged solely by the directory existing, which
/// assumes the archive unpacks to a top-level `<base>-encoder.mlmodelc`
/// directory — confirm against the hosted archive.
#[cfg(feature = "coreml")]
fn ensure_coreml_model_if_enabled(cache_dir: &Path) -> Result<(), WhisperStreamError> {
    info!("CoreML feature enabled. Checking for CoreML model...");
    let base_name = BASE_MODEL_NAME_FOR_COREML;
    let encoder_dir = cache_dir.join(format!("{}-encoder.mlmodelc", base_name));

    if encoder_dir.exists() {
        info!("CoreML model already present at {}.", encoder_dir.display());
        return Ok(());
    }

    let zip_url = COREML_MODEL_URL_TEMPLATE.replace("{}", base_name);
    let zip_path = cache_dir.join(format!("{}-encoder.mlmodelc.zip", base_name));

    info!("Downloading CoreML model from {} to {}...", zip_url, zip_path.display());
    download_file(&zip_url, &zip_path)?;
    info!("CoreML model ZIP downloaded.");

    info!("Unzipping CoreML model to {}...", cache_dir.display());
    unzip_file(&zip_path, cache_dir).map_err(|unzip_err| {
        // Best-effort cleanup so a later run starts from a clean slate.
        if let Err(remove_err) = fs::remove_file(&zip_path) {
            warn!("Failed to remove zip file {} during cleanup: {}", zip_path.display(), remove_err);
        }
        if let Err(remove_dir_err) = fs::remove_dir_all(&encoder_dir) {
            warn!("Failed to remove directory {} during cleanup: {}", encoder_dir.display(), remove_dir_err);
        }
        unzip_err
    })?;
    info!("CoreML model unzipped and available at {}.", encoder_dir.display());

    if fs::remove_file(&zip_path).is_err() {
        warn!("Could not remove CoreML zip file: {}", zip_path.display());
    }
    Ok(())
}
147
148fn download_file(url: &str, path: &Path) -> Result<(), WhisperStreamError> {
149 let mut resp = reqwest::blocking::get(url)
150 .map_err(|e| WhisperStreamError::ModelFetch(format!("Failed to initiate download from {}: {}", url, e)))?;
151
152 if !resp.status().is_success() {
153 return Err(WhisperStreamError::ModelFetch(format!("Failed to download from {}: HTTP Status {}", url, resp.status())));
154 }
155
156 let mut out = fs::File::create(path)
157 .map_err(|e| WhisperStreamError::Io { source: e })?;
158
159 io::copy(&mut resp, &mut out)
160 .map_err(|e| WhisperStreamError::Io { source: e })?;
161
162 out.flush().map_err(|e| WhisperStreamError::Io { source: e })?;
163 Ok(())
164}
165
/// Extracts every entry of the ZIP archive at `zip_path` into `dest_dir`.
///
/// Entries whose name ends in `/` are created as directories. For other
/// entries, a missing parent directory is created first, then the file is
/// created (or truncated) and filled with the entry's contents. Entries for
/// which `enclosed_name()` yields `None` are skipped; the zip crate uses
/// `None` to flag paths that would escape the destination directory.
///
/// # Errors
///
/// - `WhisperStreamError::ModelFetch` if the archive cannot be parsed or an
///   entry cannot be accessed.
/// - `WhisperStreamError::Io` if opening the archive, creating directories or
///   files, or copying entry contents fails.
#[cfg(feature = "coreml")]
fn unzip_file(zip_path: &Path, dest_dir: &Path) -> Result<(), WhisperStreamError> {
    let archive_file = File::open(zip_path).map_err(|source| WhisperStreamError::Io { source })?;
    let mut archive = ZipArchive::new(archive_file).map_err(|err| {
        WhisperStreamError::ModelFetch(format!("Failed to open zip archive '{}': {}", zip_path.display(), err))
    })?;

    for index in 0..archive.len() {
        let mut entry = archive.by_index(index).map_err(|err| {
            WhisperStreamError::ModelFetch(format!("Failed to access file in zip '{}': {}", zip_path.display(), err))
        })?;

        let target = match entry.enclosed_name().map(|relative| dest_dir.join(relative)) {
            Some(target) => target,
            None => continue,
        };

        if entry.name().ends_with('/') {
            fs::create_dir_all(&target).map_err(|source| WhisperStreamError::Io { source })?;
            continue;
        }

        if let Some(parent) = target.parent().filter(|dir| !dir.exists()) {
            fs::create_dir_all(parent).map_err(|source| WhisperStreamError::Io { source })?;
        }
        let mut out = fs::File::create(&target).map_err(|source| WhisperStreamError::Io { source })?;
        io::copy(&mut entry, &mut out).map_err(|source| WhisperStreamError::Io { source })?;
    }

    Ok(())
}