1use crate::error::{SpatialError, SpatialResult};
2use std::path::{Path, PathBuf};
3use tokio::io::AsyncWriteExt;
4
5pub fn get_checkpoint_dir() -> SpatialResult<PathBuf> {
6 if let Ok(custom_dir) = std::env::var("SPATIAL_MAKER_CHECKPOINTS") {
7 Ok(PathBuf::from(custom_dir))
8 } else {
9 let home = dirs::home_dir().ok_or_else(|| {
10 SpatialError::ConfigError("Could not determine home directory".to_string())
11 })?;
12 Ok(home.join(".spatial-maker").join("checkpoints"))
13 }
14}
15
/// Describes a single downloadable depth model.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Identifier of the model, used in log and progress messages.
    pub name: String,
    /// Name the model is stored under inside the checkpoint directory.
    pub filename: String,
    /// Location the model is downloaded from.
    pub url: String,
    /// Approximate download size in megabytes; used for display and as a
    /// fallback total when the server reports no content length.
    pub size_mb: u32,
}
23
24impl ModelMetadata {
25 pub fn coreml(encoder_size: &str) -> SpatialResult<Self> {
26 match encoder_size {
27 "s" | "small" => Ok(ModelMetadata {
28 name: "depth-anything-v2-small".to_string(),
29 filename: "DepthAnythingV2SmallF16.mlpackage".to_string(),
30 url: "https://huggingface.co/mrgnw/depth-anything-v2-coreml/resolve/main/DepthAnythingV2SmallF16.mlpackage.tar.gz".to_string(),
31 size_mb: 48,
32 }),
33 "b" | "base" => Ok(ModelMetadata {
34 name: "depth-anything-v2-base".to_string(),
35 filename: "DepthAnythingV2BaseF16.mlpackage".to_string(),
36 url: "https://huggingface.co/mrgnw/depth-anything-v2-coreml/resolve/main/DepthAnythingV2BaseF16.mlpackage.tar.gz".to_string(),
37 size_mb: 186,
38 }),
39 "l" | "large" => Ok(ModelMetadata {
40 name: "depth-anything-v2-large".to_string(),
41 filename: "DepthAnythingV2LargeF16.mlpackage".to_string(),
42 url: "https://huggingface.co/mrgnw/depth-anything-v2-coreml/resolve/main/DepthAnythingV2LargeF16.mlpackage.tar.gz".to_string(),
43 size_mb: 638,
44 }),
45 other => Err(SpatialError::ConfigError(
46 format!("Unknown encoder size: '{}'. Use 's', 'b', or 'l'", other)
47 )),
48 }
49 }
50
51 #[cfg(feature = "onnx")]
52 pub fn onnx(encoder_size: &str) -> SpatialResult<Self> {
53 match encoder_size {
54 "s" | "small" => Ok(ModelMetadata {
55 name: "depth-anything-v2-small".to_string(),
56 filename: "depth_anything_v2_small.onnx".to_string(),
57 url: "https://huggingface.co/onnx-community/depth-anything-v2-small/resolve/main/onnx/model.onnx".to_string(),
58 size_mb: 99,
59 }),
60 "b" | "base" => Ok(ModelMetadata {
61 name: "depth-anything-v2-base".to_string(),
62 filename: "depth_anything_v2_base.onnx".to_string(),
63 url: "https://huggingface.co/onnx-community/depth-anything-v2-base/resolve/main/onnx/model.onnx".to_string(),
64 size_mb: 380,
65 }),
66 "l" | "large" => Ok(ModelMetadata {
67 name: "depth-anything-v2-large".to_string(),
68 filename: "depth_anything_v2_large.onnx".to_string(),
69 url: "https://huggingface.co/onnx-community/depth-anything-v2-large/resolve/main/onnx/model.onnx".to_string(),
70 size_mb: 1300,
71 }),
72 other => Err(SpatialError::ConfigError(
73 format!("Unknown encoder size: '{}'. Use 's', 'b', or 'l'", other)
74 )),
75 }
76 }
77}
78
79pub fn find_model(encoder_size: &str) -> SpatialResult<PathBuf> {
80 let checkpoint_dir = get_checkpoint_dir()?;
81
82 #[cfg(all(target_os = "macos", feature = "coreml"))]
83 {
84 let meta = ModelMetadata::coreml(encoder_size)?;
85 let model_path = checkpoint_dir.join(&meta.filename);
86 if model_path.exists() {
87 return Ok(model_path);
88 }
89 }
90
91 #[cfg(feature = "onnx")]
92 {
93 let meta = ModelMetadata::onnx(encoder_size)?;
94 let model_path = checkpoint_dir.join(&meta.filename);
95 if model_path.exists() {
96 return Ok(model_path);
97 }
98 }
99
100 let dev_paths = [
102 PathBuf::from("checkpoints"),
103 dirs::home_dir()
104 .unwrap_or_default()
105 .join(".spatial-maker")
106 .join("checkpoints"),
107 ];
108
109 for dir in &dev_paths {
110 if dir.exists() {
111 if let Ok(entries) = std::fs::read_dir(dir) {
112 for entry in entries.flatten() {
113 let name = entry.file_name().to_string_lossy().to_string();
114 if name.ends_with(".tar.gz") || name.ends_with(".downloading") {
115 continue;
116 }
117 if name.contains("DepthAnything") || name.contains("depth_anything") {
118 let lower_size = encoder_size.to_lowercase();
119 let name_lower = name.to_lowercase();
120 let matches = match lower_size.as_str() {
121 "s" | "small" => name_lower.contains("small"),
122 "b" | "base" => name_lower.contains("base"),
123 "l" | "large" => name_lower.contains("large"),
124 _ => false,
125 };
126 if matches {
127 return Ok(entry.path());
128 }
129 }
130 }
131 }
132 }
133 }
134
135 Err(SpatialError::ModelError(format!(
136 "Model not found for encoder size '{}'. Run download first.",
137 encoder_size
138 )))
139}
140
141pub fn model_exists(encoder_size: &str) -> bool {
142 find_model(encoder_size).is_ok()
143}
144
145pub async fn ensure_model_exists<F>(
146 encoder_size: &str,
147 progress_fn: Option<F>,
148) -> SpatialResult<PathBuf>
149where
150 F: FnMut(u64, u64),
151{
152 if let Ok(path) = find_model(encoder_size) {
153 return Ok(path);
154 }
155
156 let checkpoint_dir = get_checkpoint_dir()?;
157 tokio::fs::create_dir_all(&checkpoint_dir)
158 .await
159 .map_err(|e| {
160 SpatialError::IoError(format!("Failed to create checkpoint directory: {}", e))
161 })?;
162
163 #[cfg(all(target_os = "macos", feature = "coreml"))]
164 {
165 let meta = ModelMetadata::coreml(encoder_size)?;
166 let model_path = checkpoint_dir.join(&meta.filename);
167 download_model(&meta, &model_path, progress_fn).await?;
168 return Ok(model_path);
169 }
170
171 #[cfg(all(feature = "onnx", not(all(target_os = "macos", feature = "coreml"))))]
172 {
173 let meta = ModelMetadata::onnx(encoder_size)?;
174 let model_path = checkpoint_dir.join(&meta.filename);
175 download_model(&meta, &model_path, progress_fn).await?;
176 return Ok(model_path);
177 }
178
179 #[cfg(not(any(all(target_os = "macos", feature = "coreml"), feature = "onnx")))]
180 {
181 let _ = progress_fn;
182 Err(SpatialError::ConfigError(
183 "No depth backend enabled. Enable 'coreml' (macOS) or 'onnx' feature.".to_string(),
184 ))
185 }
186}
187
188async fn download_model<F>(
189 metadata: &ModelMetadata,
190 destination: &Path,
191 mut progress_fn: Option<F>,
192) -> SpatialResult<()>
193where
194 F: FnMut(u64, u64),
195{
196 eprintln!("Downloading model: {} ({} MB)...", metadata.name, metadata.size_mb);
197 tracing::info!("Downloading model: {} from {}", metadata.name, metadata.url);
198
199 let response = reqwest::get(&metadata.url)
200 .await
201 .map_err(|e| SpatialError::Other(format!("Failed to download model: {}", e)))?;
202
203 if !response.status().is_success() {
204 return Err(SpatialError::Other(format!(
205 "Failed to download model: HTTP {} from {}",
206 response.status(),
207 metadata.url
208 )));
209 }
210
211 let total_bytes = response
212 .content_length()
213 .unwrap_or(metadata.size_mb as u64 * 1_000_000);
214
215 let is_tar_gz = metadata.url.ends_with(".tar.gz");
216
217 if is_tar_gz {
218 let temp_path = destination.with_extension("tar.gz");
219 let mut file = tokio::fs::File::create(&temp_path)
220 .await
221 .map_err(|e| SpatialError::IoError(format!("Failed to create file: {}", e)))?;
222
223 let mut downloaded = 0u64;
224 let mut stream = response.bytes_stream();
225 use futures_util::StreamExt;
226
227 let mut last_pct: u64 = 0;
228 while let Some(chunk) = stream.next().await {
229 let chunk = chunk.map_err(|e| SpatialError::Other(format!("Download interrupted: {}", e)))?;
230 file.write_all(&chunk)
231 .await
232 .map_err(|e| SpatialError::IoError(format!("Failed to write to file: {}", e)))?;
233 downloaded += chunk.len() as u64;
234 if let Some(ref mut f) = progress_fn {
235 f(downloaded, total_bytes);
236 }
237 if total_bytes > 0 {
238 let pct = downloaded * 100 / total_bytes;
239 if pct != last_pct {
240 last_pct = pct;
241 eprint!("\rDownloading... {}%", pct);
242 }
243 }
244 }
245 eprintln!();
246 drop(file);
247
248 let parent = destination
249 .parent()
250 .ok_or_else(|| SpatialError::IoError("Invalid destination path".to_string()))?;
251
252 eprintln!("Extracting...");
253 let output = std::process::Command::new("tar")
254 .args(&["xzf"])
255 .arg(&temp_path)
256 .arg("-C")
257 .arg(parent)
258 .output()
259 .map_err(|e| SpatialError::IoError(format!("Failed to extract tar.gz: {}", e)))?;
260
261 if !output.status.success() {
262 let stderr = String::from_utf8_lossy(&output.stderr);
263 return Err(SpatialError::IoError(format!("tar extraction failed: {}", stderr)));
264 }
265
266 let _ = tokio::fs::remove_file(&temp_path).await;
267
268 if !destination.exists() {
269 return Err(SpatialError::ModelError(format!(
270 "Extraction succeeded but model not found at {:?}",
271 destination
272 )));
273 }
274 } else {
275 let mut file = tokio::fs::File::create(destination)
276 .await
277 .map_err(|e| SpatialError::IoError(format!("Failed to create file: {}", e)))?;
278
279 let mut downloaded = 0u64;
280 let mut stream = response.bytes_stream();
281 use futures_util::StreamExt;
282
283 while let Some(chunk) = stream.next().await {
284 let chunk = chunk.map_err(|e| SpatialError::Other(format!("Download interrupted: {}", e)))?;
285 file.write_all(&chunk)
286 .await
287 .map_err(|e| SpatialError::IoError(format!("Failed to write to file: {}", e)))?;
288 downloaded += chunk.len() as u64;
289 if let Some(ref mut f) = progress_fn {
290 f(downloaded, total_bytes);
291 }
292 }
293 }
294
295 tracing::info!("Model downloaded: {:?}", destination);
296 Ok(())
297}