Skip to main content

spatial_maker/
depth.rs

1#[cfg(feature = "onnx")]
2use crate::error::{SpatialError, SpatialResult};
3#[cfg(feature = "onnx")]
4use image::DynamicImage;
5#[cfg(feature = "onnx")]
6use ndarray::Array2;
7#[cfg(feature = "onnx")]
8use ort::session::{builder::GraphOptimizationLevel, Session};
9
// Preprocessing constants for the ONNX depth model; only compiled when the
// `onnx` feature is enabled, like everything else in this module.

/// Per-channel RGB mean subtracted from each pixel after it has been scaled
/// from `0..=255` to `[0, 1]` (the standard ImageNet statistics).
#[cfg(feature = "onnx")]
const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];

/// Per-channel RGB standard deviation that each mean-centred pixel is divided by.
#[cfg(feature = "onnx")]
const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Side length, in pixels, of the square image handed to the model.
/// NOTE(review): presumably the fixed input resolution of the exported model — confirm.
#[cfg(feature = "onnx")]
const INPUT_SIZE: u32 = 518;
16
/// Estimates a per-pixel depth map from a single image using an ONNX Runtime
/// model.
///
/// Create one with [`OnnxDepthEstimator::new`], then call `estimate` for each image.
#[cfg(feature = "onnx")]
pub struct OnnxDepthEstimator {
	/// The loaded ONNX Runtime session that performs inference.
	session: Session,
}
21
#[cfg(feature = "onnx")]
impl OnnxDepthEstimator {
	/// Loads the ONNX depth model at `model_path` using 4 intra-op threads.
	///
	/// Equivalent to calling [`OnnxDepthEstimator::with_threads`] with `4`.
	///
	/// # Errors
	///
	/// Returns [`SpatialError::ModelError`] if the ONNX Runtime session cannot be
	/// created or configured, or if the model file cannot be loaded.
	pub fn new(model_path: &str) -> SpatialResult<Self> {
		Self::with_threads(model_path, 4)
	}

	/// Loads the ONNX depth model at `model_path` with graph optimization level 3
	/// and `intra_threads` intra-op threads.
	///
	/// # Errors
	///
	/// Returns [`SpatialError::ModelError`] if the ONNX Runtime session cannot be
	/// created or configured, or if the model file cannot be loaded.
	pub fn with_threads(model_path: &str, intra_threads: usize) -> SpatialResult<Self> {
		let session = Session::builder()
			.map_err(|e| SpatialError::ModelError(format!("Failed to create session: {}", e)))?
			.with_optimization_level(GraphOptimizationLevel::Level3)
			.map_err(|e| SpatialError::ModelError(format!("Failed to set opt level: {}", e)))?
			.with_intra_threads(intra_threads)
			.map_err(|e| SpatialError::ModelError(format!("Failed to set threads: {}", e)))?
			.commit_from_file(model_path)
			.map_err(|e| SpatialError::ModelError(format!("Failed to load ONNX model: {}", e)))?;

		Ok(Self { session })
	}

	/// Estimates a depth map for `image`.
	///
	/// Processing steps:
	/// 1. Resize the image (Lanczos3) to `INPUT_SIZE`×`INPUT_SIZE`.
	/// 2. Normalize it with the ImageNet mean/std.
	/// 3. Feed it to the model as a `[1, 3, INPUT_SIZE, INPUT_SIZE]` tensor.
	/// 4. Min-max normalize the model output to `[0, 1]`. A flat output becomes a
	///    constant `0.5` map.
	/// 5. Resize the result (Lanczos3) back to the original image size.
	///
	/// Returns an array of shape `(image.height(), image.width())`.
	///
	/// # Errors
	///
	/// - [`SpatialError::TensorError`] in these cases:
	///   - the image has zero width or height;
	///   - the input tensor cannot be built;
	///   - the output cannot be extracted as `f32`;
	///   - the output shape cannot be interpreted as a single `H×W` map.
	/// - [`SpatialError::ModelError`] if inference fails.
	pub fn estimate(&mut self, image: &DynamicImage) -> SpatialResult<Array2<f32>> {
		let (orig_width, orig_height) = (image.width(), image.height());
		// Resampling from or to a zero-sized image is degenerate; reject it up front.
		if orig_width == 0 || orig_height == 0 {
			return Err(SpatialError::TensorError(format!(
				"Cannot estimate depth for an empty image ({}x{})",
				orig_width, orig_height
			)));
		}

		let size = INPUT_SIZE as usize;
		let input_data = Self::preprocess(image);
		let input_value = ort::value::Value::from_array(([1usize, 3, size, size], input_data))
			.map_err(|e| SpatialError::TensorError(format!("Failed to create input: {}", e)))?;

		let outputs = self
			.session
			.run(ort::inputs![input_value])
			.map_err(|e| SpatialError::ModelError(format!("Inference failed: {}", e)))?;

		let (shape, data) = outputs[0]
			.try_extract_tensor::<f32>()
			.map_err(|e| SpatialError::TensorError(format!("Failed to extract output: {}", e)))?;

		let dims: Vec<i64> = shape.iter().copied().collect();
		let (height, width) = Self::output_hw(&dims, data.len())?;
		let normalized = Self::min_max_normalize(data);

		// `output_hw` guarantees `height * width == normalized.len()`, so this
		// only fails if that invariant is broken.
		let depth_image =
			image::ImageBuffer::<image::Luma<f32>, Vec<f32>>::from_raw(width, height, normalized)
				.ok_or_else(|| {
					SpatialError::TensorError(
						"Depth buffer does not match output dimensions".to_string(),
					)
				})?;

		let resized_depth = image::imageops::resize(
			&depth_image,
			orig_width,
			orig_height,
			image::imageops::FilterType::Lanczos3,
		);

		// `into_raw` yields the single-channel samples in row-major order.
		Array2::from_shape_vec(
			(orig_height as usize, orig_width as usize),
			resized_depth.into_raw(),
		)
		.map_err(|e| SpatialError::TensorError(format!("Failed to reshape depth: {}", e)))
	}

	/// Resizes `image` to `INPUT_SIZE`×`INPUT_SIZE` and returns its RGB values
	/// ImageNet-normalized in planar (channel-major) order, i.e. all R values,
	/// then all G, then all B.
	fn preprocess(image: &DynamicImage) -> Vec<f32> {
		let size = INPUT_SIZE as usize;
		let plane = size * size;

		let rgb = image
			.resize_exact(INPUT_SIZE, INPUT_SIZE, image::imageops::FilterType::Lanczos3)
			.to_rgb8();

		let mut input_data = vec![0.0f32; 3 * plane];
		for (i, pixel) in rgb.pixels().enumerate() {
			for c in 0..3 {
				input_data[c * plane + i] =
					(f32::from(pixel[c]) / 255.0 - IMAGENET_MEAN[c]) / IMAGENET_STD[c];
			}
		}
		input_data
	}

	/// Interprets the model's output shape `dims` as a single depth map and
	/// returns its `(height, width)`.
	///
	/// The last two axes are used. Both `[1, H, W]` and `[1, 1, H, W]` layouts
	/// are accepted. The original code indexed `dims[1]`/`dims[2]`, which
	/// misreads the 4-D case.
	///
	/// Returns an error in these cases:
	/// - fewer than two axes;
	/// - an extent that is non-positive or larger than `u32::MAX`;
	/// - `H * W` does not equal the number of output values `len`. This also
	///   rejects batches of more than one map.
	fn output_hw(dims: &[i64], len: usize) -> SpatialResult<(u32, u32)> {
		let bad_shape = || {
			SpatialError::TensorError(format!(
				"Unexpected depth output shape {:?} for {} values",
				dims, len
			))
		};

		let [.., h, w] = dims else {
			return Err(bad_shape());
		};

		// Image extents must be positive and fit in `u32`.
		let to_extent = |d: i64| (1..=i64::from(u32::MAX)).contains(&d).then(|| d as u32);
		let (height, width) = match (to_extent(*h), to_extent(*w)) {
			(Some(height), Some(width)) => (height, width),
			_ => return Err(bad_shape()),
		};

		if (height as usize).checked_mul(width as usize) != Some(len) {
			return Err(bad_shape());
		}
		Ok((height, width))
	}

	/// Linearly rescales `values` to `[0, 1]`.
	///
	/// If the spread is at most `1e-6`, every value maps to `0.5`. That avoids
	/// dividing by a near-zero range. `f32::min`/`f32::max` ignore NaN
	/// operands, so NaNs do not affect the computed range.
	fn min_max_normalize(values: &[f32]) -> Vec<f32> {
		let (min_val, max_val) = values
			.iter()
			.fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| {
				(lo.min(v), hi.max(v))
			});
		let range = max_val - min_val;

		if range > 1e-6 {
			values.iter().map(|&v| (v - min_val) / range).collect()
		} else {
			vec![0.5; values.len()]
		}
	}
}