Skip to main content

spatial_maker/
model.rs

1use crate::error::{SpatialError, SpatialResult};
2use std::path::{Path, PathBuf};
3use tokio::io::AsyncWriteExt;
4
5pub fn get_checkpoint_dir() -> SpatialResult<PathBuf> {
6	if let Ok(custom_dir) = std::env::var("SPATIAL_MAKER_CHECKPOINTS") {
7		Ok(PathBuf::from(custom_dir))
8	} else {
9		let home = dirs::home_dir().ok_or_else(|| {
10			SpatialError::ConfigError("Could not determine home directory".to_string())
11		})?;
12		Ok(home.join(".spatial-maker").join("checkpoints"))
13	}
14}
15
/// Static description of one downloadable depth-model artifact.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
	/// Human-readable model identifier, used in log messages.
	pub name: String,
	/// File or package name the model is stored under inside the checkpoint directory.
	pub filename: String,
	/// Download URL; a `.tar.gz` suffix means the artifact must be extracted after download.
	pub url: String,
	/// Approximate size in megabytes, used as the progress total when the server
	/// does not report a content length.
	pub size_mb: u32,
}
23
24impl ModelMetadata {
25	pub fn coreml(encoder_size: &str) -> SpatialResult<Self> {
26		match encoder_size {
27			"s" | "small" => Ok(ModelMetadata {
28				name: "depth-anything-v2-small".to_string(),
29				filename: "DepthAnythingV2SmallF16.mlpackage".to_string(),
30				url: "https://huggingface.co/mrgnw/depth-anything-v2-coreml/resolve/main/DepthAnythingV2SmallF16.mlpackage.tar.gz".to_string(),
31				size_mb: 48,
32			}),
33			"b" | "base" => Ok(ModelMetadata {
34				name: "depth-anything-v2-base".to_string(),
35				filename: "DepthAnythingV2BaseF16.mlpackage".to_string(),
36				url: "https://huggingface.co/mrgnw/depth-anything-v2-coreml/resolve/main/DepthAnythingV2BaseF16.mlpackage.tar.gz".to_string(),
37				size_mb: 186,
38			}),
39			"l" | "large" => Ok(ModelMetadata {
40				name: "depth-anything-v2-large".to_string(),
41				filename: "DepthAnythingV2LargeF16.mlpackage".to_string(),
42				url: "https://huggingface.co/mrgnw/depth-anything-v2-coreml/resolve/main/DepthAnythingV2LargeF16.mlpackage.tar.gz".to_string(),
43				size_mb: 638,
44			}),
45			other => Err(SpatialError::ConfigError(
46				format!("Unknown encoder size: '{}'. Use 's', 'b', or 'l'", other)
47			)),
48		}
49	}
50
51	#[cfg(feature = "onnx")]
52	pub fn onnx(encoder_size: &str) -> SpatialResult<Self> {
53		match encoder_size {
54			"s" | "small" => Ok(ModelMetadata {
55				name: "depth-anything-v2-small".to_string(),
56				filename: "depth_anything_v2_small.onnx".to_string(),
57				url: "https://huggingface.co/onnx-community/depth-anything-v2-small/resolve/main/onnx/model.onnx".to_string(),
58				size_mb: 99,
59			}),
60			"b" | "base" => Ok(ModelMetadata {
61				name: "depth-anything-v2-base".to_string(),
62				filename: "depth_anything_v2_base.onnx".to_string(),
63				url: "https://huggingface.co/onnx-community/depth-anything-v2-base/resolve/main/onnx/model.onnx".to_string(),
64				size_mb: 380,
65			}),
66			"l" | "large" => Ok(ModelMetadata {
67				name: "depth-anything-v2-large".to_string(),
68				filename: "depth_anything_v2_large.onnx".to_string(),
69				url: "https://huggingface.co/onnx-community/depth-anything-v2-large/resolve/main/onnx/model.onnx".to_string(),
70				size_mb: 1300,
71			}),
72			other => Err(SpatialError::ConfigError(
73				format!("Unknown encoder size: '{}'. Use 's', 'b', or 'l'", other)
74			)),
75		}
76	}
77}
78
79pub fn find_model(encoder_size: &str) -> SpatialResult<PathBuf> {
80	let checkpoint_dir = get_checkpoint_dir()?;
81
82	#[cfg(all(target_os = "macos", feature = "coreml"))]
83	{
84		let meta = ModelMetadata::coreml(encoder_size)?;
85		let model_path = checkpoint_dir.join(&meta.filename);
86		if model_path.exists() {
87			return Ok(model_path);
88		}
89	}
90
91	#[cfg(feature = "onnx")]
92	{
93		let meta = ModelMetadata::onnx(encoder_size)?;
94		let model_path = checkpoint_dir.join(&meta.filename);
95		if model_path.exists() {
96			return Ok(model_path);
97		}
98	}
99
100	// Also check development paths
101	let dev_paths = [
102		PathBuf::from("checkpoints"),
103		dirs::home_dir()
104			.unwrap_or_default()
105			.join(".spatial-maker")
106			.join("checkpoints"),
107	];
108
109	for dir in &dev_paths {
110		if dir.exists() {
111			if let Ok(entries) = std::fs::read_dir(dir) {
112				for entry in entries.flatten() {
113					let name = entry.file_name().to_string_lossy().to_string();
114					if name.contains("DepthAnything") || name.contains("depth_anything") {
115						let lower_size = encoder_size.to_lowercase();
116						let name_lower = name.to_lowercase();
117						let matches = match lower_size.as_str() {
118							"s" | "small" => name_lower.contains("small"),
119							"b" | "base" => name_lower.contains("base"),
120							"l" | "large" => name_lower.contains("large"),
121							_ => false,
122						};
123						if matches {
124							return Ok(entry.path());
125						}
126					}
127				}
128			}
129		}
130	}
131
132	Err(SpatialError::ModelError(format!(
133		"Model not found for encoder size '{}'. Run download first.",
134		encoder_size
135	)))
136}
137
138pub fn model_exists(encoder_size: &str) -> bool {
139	find_model(encoder_size).is_ok()
140}
141
142pub async fn ensure_model_exists<F>(
143	encoder_size: &str,
144	progress_fn: Option<F>,
145) -> SpatialResult<PathBuf>
146where
147	F: FnMut(u64, u64),
148{
149	if let Ok(path) = find_model(encoder_size) {
150		return Ok(path);
151	}
152
153	let checkpoint_dir = get_checkpoint_dir()?;
154	tokio::fs::create_dir_all(&checkpoint_dir)
155		.await
156		.map_err(|e| {
157			SpatialError::IoError(format!("Failed to create checkpoint directory: {}", e))
158		})?;
159
160	#[cfg(all(target_os = "macos", feature = "coreml"))]
161	{
162		let meta = ModelMetadata::coreml(encoder_size)?;
163		let model_path = checkpoint_dir.join(&meta.filename);
164		download_model(&meta, &model_path, progress_fn).await?;
165		return Ok(model_path);
166	}
167
168	#[cfg(all(feature = "onnx", not(all(target_os = "macos", feature = "coreml"))))]
169	{
170		let meta = ModelMetadata::onnx(encoder_size)?;
171		let model_path = checkpoint_dir.join(&meta.filename);
172		download_model(&meta, &model_path, progress_fn).await?;
173		return Ok(model_path);
174	}
175
176	#[cfg(not(any(all(target_os = "macos", feature = "coreml"), feature = "onnx")))]
177	{
178		let _ = progress_fn;
179		Err(SpatialError::ConfigError(
180			"No depth backend enabled. Enable 'coreml' (macOS) or 'onnx' feature.".to_string(),
181		))
182	}
183}
184
185async fn download_model<F>(
186	metadata: &ModelMetadata,
187	destination: &Path,
188	mut progress_fn: Option<F>,
189) -> SpatialResult<()>
190where
191	F: FnMut(u64, u64),
192{
193	tracing::info!("Downloading model: {} from {}", metadata.name, metadata.url);
194
195	let response = reqwest::get(&metadata.url)
196		.await
197		.map_err(|e| SpatialError::Other(format!("Failed to download model: {}", e)))?;
198
199	let total_bytes = response
200		.content_length()
201		.unwrap_or(metadata.size_mb as u64 * 1_000_000);
202
203	let is_tar_gz = metadata.url.ends_with(".tar.gz");
204
205	if is_tar_gz {
206		let temp_path = destination.with_extension("tar.gz");
207		let mut file = tokio::fs::File::create(&temp_path)
208			.await
209			.map_err(|e| SpatialError::IoError(format!("Failed to create file: {}", e)))?;
210
211		let mut downloaded = 0u64;
212		let mut stream = response.bytes_stream();
213		use futures_util::StreamExt;
214
215		while let Some(chunk) = stream.next().await {
216			let chunk = chunk.map_err(|e| SpatialError::Other(format!("Download interrupted: {}", e)))?;
217			file.write_all(&chunk)
218				.await
219				.map_err(|e| SpatialError::IoError(format!("Failed to write to file: {}", e)))?;
220			downloaded += chunk.len() as u64;
221			if let Some(ref mut f) = progress_fn {
222				f(downloaded, total_bytes);
223			}
224		}
225		drop(file);
226
227		let parent = destination
228			.parent()
229			.ok_or_else(|| SpatialError::IoError("Invalid destination path".to_string()))?;
230
231		let output = std::process::Command::new("tar")
232			.args(&["xzf"])
233			.arg(&temp_path)
234			.arg("-C")
235			.arg(parent)
236			.output()
237			.map_err(|e| SpatialError::IoError(format!("Failed to extract tar.gz: {}", e)))?;
238
239		if !output.status.success() {
240			let stderr = String::from_utf8_lossy(&output.stderr);
241			return Err(SpatialError::IoError(format!("tar extraction failed: {}", stderr)));
242		}
243
244		let _ = tokio::fs::remove_file(&temp_path).await;
245	} else {
246		let mut file = tokio::fs::File::create(destination)
247			.await
248			.map_err(|e| SpatialError::IoError(format!("Failed to create file: {}", e)))?;
249
250		let mut downloaded = 0u64;
251		let mut stream = response.bytes_stream();
252		use futures_util::StreamExt;
253
254		while let Some(chunk) = stream.next().await {
255			let chunk = chunk.map_err(|e| SpatialError::Other(format!("Download interrupted: {}", e)))?;
256			file.write_all(&chunk)
257				.await
258				.map_err(|e| SpatialError::IoError(format!("Failed to write to file: {}", e)))?;
259			downloaded += chunk.len() as u64;
260			if let Some(ref mut f) = progress_fn {
261				f(downloaded, total_bytes);
262			}
263		}
264	}
265
266	tracing::info!("Model downloaded: {:?}", destination);
267	Ok(())
268}