1#![allow(dead_code)]
5use crate::{Result, VisionError};
6use scirs2_core::ndarray::{arr1, arr2, Array1, Array2, ArrayView2, Axis};
7use scirs2_spatial::kdtree::KDTree;
8use scirs2_spatial::octree::Octree;
9use scirs2_spatial::quadtree::Quadtree;
10use scirs2_spatial::rtree::RTree;
11use std::collections::HashMap;
12use torsh_tensor::Tensor;
13
/// Identifier of a single point in a point cloud.
///
/// [`PointCloudProcessor`] builds these from the row index of a point in its
/// coordinate array.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PointId(pub usize);
17
/// Axis-aligned bounding box described by per-dimension lower and upper bounds.
///
/// The fields are public, so a value assembled directly (instead of through
/// [`BoundingBox::new`]) has not been checked for consistency.
#[derive(Debug, Clone)]
pub struct BoundingBox {
    /// Lower bound of the box along each dimension.
    pub min: Array1<f64>,
    /// Upper bound of the box along each dimension.
    pub max: Array1<f64>,
}
24
25impl BoundingBox {
26 pub fn new(min: Array1<f64>, max: Array1<f64>) -> Result<Self> {
28 if min.len() != max.len() {
29 return Err(VisionError::InvalidArgument(
30 "Min and max coordinates must have same dimension".to_string(),
31 ));
32 }
33
34 for (i, (&min_val, &max_val)) in min.iter().zip(max.iter()).enumerate() {
35 if min_val > max_val {
36 return Err(VisionError::InvalidArgument(format!(
37 "Min coordinate {} is greater than max coordinate at dimension {}",
38 min_val, i
39 )));
40 }
41 }
42
43 Ok(Self { min, max })
44 }
45
46 pub fn from_points(points: &Array2<f64>) -> Result<Self> {
48 if points.is_empty() {
49 return Err(VisionError::InvalidArgument(
50 "Cannot create bounding box from empty points".to_string(),
51 ));
52 }
53
54 let dims = points.ncols();
55 let mut min = Array1::from_elem(dims, f64::INFINITY);
56 let mut max = Array1::from_elem(dims, f64::NEG_INFINITY);
57
58 for point in points.outer_iter() {
59 for (i, &coord) in point.iter().enumerate() {
60 min[i] = min[i].min(coord);
61 max[i] = max[i].max(coord);
62 }
63 }
64
65 Self::new(min, max)
66 }
67
68 pub fn contains(&self, point: &ArrayView2<f64>) -> bool {
70 if point.len() != self.min.len() {
71 return false;
72 }
73
74 for (i, &coord) in point.iter().enumerate() {
75 if coord < self.min[i] || coord > self.max[i] {
76 return false;
77 }
78 }
79
80 true
81 }
82
83 pub fn volume(&self) -> f64 {
85 self.max
86 .iter()
87 .zip(self.min.iter())
88 .map(|(&max_val, &min_val)| max_val - min_val)
89 .product()
90 }
91}
92
/// Records detections frame by frame and answers region queries over the most
/// recently stored bounding box of each object.
pub struct SpatialObjectTracker {
    /// Optional R-tree over object ids; `new` initialises it to `None`.
    spatial_index: Option<RTree<ObjectId>>,
    /// Latest metadata stored for each object id.
    object_data: HashMap<ObjectId, ObjectMetadata>,
    /// Recorded frames, oldest first.
    frame_history: Vec<FrameData>,
    /// Maximum number of frames kept in `frame_history`.
    max_history: usize,
}
100
/// Identifier of an object handled by [`SpatialObjectTracker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ObjectId(pub usize);
103
/// Per-object state kept by [`SpatialObjectTracker`].
#[derive(Debug, Clone)]
pub struct ObjectMetadata {
    /// Bounding box most recently stored for the object.
    pub bbox: BoundingBox,
    /// Confidence value copied from the detection that last updated the entry.
    pub confidence: f64,
    /// Class id copied from the detection that last updated the entry.
    pub class_id: usize,
    /// Id of the frame in which the entry was last written.
    pub last_seen_frame: usize,
}
111
/// Snapshot of one processed frame, as stored in the tracker's history.
#[derive(Debug, Clone)]
pub struct FrameData {
    /// Frame id supplied by the caller.
    pub frame_id: usize,
    /// Copy of the detections passed in for this frame.
    pub detections: Vec<Detection>,
    /// Wall-clock time at which the frame was recorded, in seconds relative to
    /// the Unix epoch.
    pub timestamp: f64,
}
118
/// A single detection handed to [`SpatialObjectTracker::track_objects`].
#[derive(Debug, Clone)]
pub struct Detection {
    /// Bounding box of the detection.
    pub bbox: BoundingBox,
    /// Confidence value of the detection (presumably a detector score; the
    /// expected range is not enforced here).
    pub confidence: f64,
    /// Class id of the detection.
    pub class_id: usize,
    /// Optional feature vector; nothing in this file reads it.
    pub features: Option<Array1<f64>>,
}
126
/// Result entry returned by [`SpatialObjectTracker::track_objects`].
#[derive(Debug, Clone)]
pub struct TrackedObject {
    /// Id assigned to the object by the tracker.
    pub object_id: ObjectId,
    /// Bounding box copied from the detection.
    pub bbox: BoundingBox,
    /// Confidence value copied from the detection.
    pub confidence: f64,
    /// Class id copied from the detection.
    pub class_id: usize,
    /// Track length of the object; `track_objects` currently always sets `1`.
    pub track_length: usize,
}
135
136impl SpatialObjectTracker {
137 pub fn new(max_history: usize) -> Self {
139 Self {
140 spatial_index: None,
141 object_data: HashMap::new(),
142 frame_history: Vec::new(),
143 max_history,
144 }
145 }
146
147 pub fn track_objects(
149 &mut self,
150 detections: &[Detection],
151 frame_id: usize,
152 ) -> Result<Vec<TrackedObject>> {
153 let frame_data = FrameData {
155 frame_id,
156 detections: detections.to_vec(),
157 timestamp: std::time::SystemTime::now()
158 .duration_since(std::time::UNIX_EPOCH)
159 .expect("system time should be after UNIX_EPOCH")
160 .as_secs_f64(),
161 };
162
163 self.frame_history.push(frame_data);
164 if self.frame_history.len() > self.max_history {
165 self.frame_history.remove(0);
166 }
167
168 let mut tracked_objects = Vec::new();
170
171 for (i, detection) in detections.iter().enumerate() {
172 let object_id = ObjectId(i);
173
174 let tracked_object = TrackedObject {
176 object_id,
177 bbox: detection.bbox.clone(),
178 confidence: detection.confidence,
179 class_id: detection.class_id,
180 track_length: 1,
181 };
182
183 tracked_objects.push(tracked_object);
184
185 let metadata = ObjectMetadata {
187 bbox: detection.bbox.clone(),
188 confidence: detection.confidence,
189 class_id: detection.class_id,
190 last_seen_frame: frame_id,
191 };
192
193 self.object_data.insert(object_id, metadata);
194 }
195
196 Ok(tracked_objects)
197 }
198
199 pub fn query_region(&self, region: &BoundingBox) -> Result<Vec<ObjectId>> {
201 let mut objects_in_region = Vec::new();
203
204 for (&object_id, metadata) in &self.object_data {
205 if self.bboxes_overlap(&metadata.bbox, region) {
207 objects_in_region.push(object_id);
208 }
209 }
210
211 Ok(objects_in_region)
212 }
213
214 fn bboxes_overlap(&self, bbox1: &BoundingBox, bbox2: &BoundingBox) -> bool {
215 if bbox1.min.len() != bbox2.min.len() {
216 return false;
217 }
218
219 for i in 0..bbox1.min.len() {
220 if bbox1.max[i] < bbox2.min[i] || bbox1.min[i] > bbox2.max[i] {
221 return false;
222 }
223 }
224
225 true
226 }
227
228 pub fn get_trajectory(&self, object_id: ObjectId) -> Vec<BoundingBox> {
230 if let Some(metadata) = self.object_data.get(&object_id) {
232 vec![metadata.bbox.clone()]
233 } else {
234 Vec::new()
235 }
236 }
237}
238
/// Holds a 3D point cloud and answers region and nearest-neighbour queries
/// over it.
pub struct PointCloudProcessor {
    /// Optional octree over the points; `new` initialises it to `None`.
    octree: Option<Octree>,
    /// Point coordinates, one point per row with three columns.
    points: Array2<f64>,
    /// Optional per-point attributes keyed by point id.
    point_metadata: HashMap<PointId, PointMetadata>,
}
245
/// Optional attributes that can be attached to a point of a point cloud.
#[derive(Debug, Clone)]
pub struct PointMetadata {
    /// Optional colour values of the point (channel layout is not fixed here).
    pub color: Option<Array1<f32>>,
    /// Optional normal vector of the point.
    pub normal: Option<Array1<f64>>,
    /// Optional intensity value of the point.
    pub intensity: Option<f64>,
}
252
253impl PointCloudProcessor {
254 pub fn new() -> Self {
256 Self {
257 octree: None,
258 points: Array2::zeros((0, 3)),
259 point_metadata: HashMap::new(),
260 }
261 }
262
263 pub fn build_index(&mut self, points: Array2<f64>) -> Result<()> {
265 if points.ncols() != 3 {
266 return Err(VisionError::InvalidArgument(
267 "Point cloud must have 3D coordinates".to_string(),
268 ));
269 }
270
271 self.points = points;
272
273 let _bbox = BoundingBox::from_points(&self.points)?;
275
276 Ok(())
284 }
285
286 pub fn query_region(&self, region: &BoundingBox) -> Result<Vec<PointId>> {
288 let mut points_in_region = Vec::new();
290
291 for (i, point) in self.points.outer_iter().enumerate() {
292 if region.contains(&point.view().insert_axis(Axis(1))) {
293 points_in_region.push(PointId(i));
294 }
295 }
296
297 Ok(points_in_region)
298 }
299
300 pub fn find_neighbors(&self, query_point: &Array1<f64>, k: usize) -> Result<Vec<PointId>> {
302 if query_point.len() != 3 {
303 return Err(VisionError::InvalidArgument(
304 "Query point must be 3D".to_string(),
305 ));
306 }
307
308 let mut distances: Vec<(PointId, f64)> = Vec::new();
310
311 for (i, point) in self.points.outer_iter().enumerate() {
312 let diff = &point - query_point;
313 let distance = (diff.mapv(|x| x * x).sum()).sqrt();
314 distances.push((PointId(i), distance));
315 }
316
317 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("comparison should succeed"));
318 distances.truncate(k);
319
320 Ok(distances.into_iter().map(|(id, _)| id).collect())
321 }
322
323 pub fn segment_regions(&self, _region_size: f64) -> Result<Vec<Vec<PointId>>> {
325 let mut regions = Vec::new();
327
328 if !self.points.is_empty() {
330 let bbox = BoundingBox::from_points(&self.points)?;
331 let _dims = bbox.max.len();
332
333 let all_points: Vec<PointId> = (0..self.points.nrows()).map(PointId).collect();
335 regions.push(all_points);
336 }
337
338 Ok(regions)
339 }
340}
341
342impl Default for PointCloudProcessor {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// A box whose lower corner does not exceed its upper corner is accepted.
    #[test]
    fn test_bounding_box_creation() {
        let lower = arr1(&[0.0, 0.0]);
        let upper = arr1(&[1.0, 1.0]);

        assert!(BoundingBox::new(lower, upper).is_ok());
    }

    /// A box whose lower corner exceeds its upper corner is rejected.
    #[test]
    fn test_bounding_box_invalid() {
        let outcome = BoundingBox::new(arr1(&[1.0, 1.0]), arr1(&[0.0, 0.0]));

        assert!(outcome.is_err());
    }

    /// The box around the unit-square corners spans `[0, 1]` and has volume 1.
    #[test]
    fn test_bounding_box_from_points() {
        let cloud = arr2(&[[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]]);
        let outcome = BoundingBox::from_points(&cloud);
        assert!(outcome.is_ok());

        let enclosing = outcome.expect("operation should succeed");
        assert_eq!(enclosing.min[0], 0.0);
        assert_eq!(enclosing.max[0], 1.0);
        assert_eq!(enclosing.volume(), 1.0);
    }

    /// Tracking a single detection yields exactly one tracked object.
    #[test]
    fn test_spatial_object_tracker() {
        let unit_box = BoundingBox::new(arr1(&[0.0, 0.0]), arr1(&[1.0, 1.0]))
            .expect("operation should succeed");
        let detection = Detection {
            bbox: unit_box,
            confidence: 0.9,
            class_id: 1,
            features: None,
        };

        let mut tracker = SpatialObjectTracker::new(10);
        let outcome = tracker.track_objects(&[detection], 0);

        assert!(outcome.is_ok());
        assert_eq!(outcome.expect("operation should succeed").len(), 1);
    }

    /// Indexing a small 3D cloud succeeds.
    #[test]
    fn test_point_cloud_processor() {
        let cloud = arr2(&[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]);
        let mut processor = PointCloudProcessor::new();

        assert!(processor.build_index(cloud).is_ok());
    }
}