1use crate::bridge::{Point3D, PointCloud};
7use crate::perception::clustering;
8use crate::perception::config::ObstacleConfig;
9
/// Coarse, shape-based category assigned to a detected obstacle.
///
/// The category is derived purely from bounding-box proportions by
/// `ObstacleDetector::classify_obstacles`; no motion history is involved.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ObstacleClass {
    /// Presumed stationary. Assigned to elongated boxes (longest / shortest
    /// extent ratio greater than 3).
    Static,
    /// Presumed able to move. Assigned to compact boxes (ratio of at most 2).
    Dynamic,
    /// Shape is ambiguous (ratio between 2 and 3) or degenerate (a near-zero extent).
    Unknown,
}
24
/// One obstacle extracted from a point cloud cluster, summarised as an
/// axis-aligned box.
#[derive(Clone, Debug)]
pub struct DetectedObstacle {
    /// Centroid of the cluster points as `[x, y, z]`.
    pub center: [f64; 3],
    /// Per-axis half-size of the cluster's bounding box, already inflated by the
    /// detector's configured safety margin.
    pub extent: [f64; 3],
    /// Number of cluster points the obstacle was built from.
    pub point_count: usize,
    /// Distance from the robot to this obstacle, as computed by
    /// `ObstacleDetector::detect`; used there for range filtering and for
    /// nearest-first ordering.
    pub min_distance: f64,
}
37
38#[derive(Debug, Clone)]
40pub struct ClassifiedObstacle {
41 pub obstacle: DetectedObstacle,
42 pub class: ObstacleClass,
43 pub confidence: f32,
44}
45
46#[derive(Debug, Clone)]
52pub struct ObstacleDetector {
53 config: ObstacleConfig,
54}
55
56impl ObstacleDetector {
57 pub fn new(config: ObstacleConfig) -> Self {
59 Self { config }
60 }
61
62 pub fn detect(
73 &self,
74 cloud: &PointCloud,
75 robot_pos: &[f64; 3],
76 ) -> Vec<DetectedObstacle> {
77 if cloud.is_empty() {
78 return Vec::new();
79 }
80
81 let cell_size = (self.config.safety_margin * 5.0).max(0.5);
82 let clusters = clustering::cluster_point_cloud(cloud, cell_size);
83
84 let mut obstacles: Vec<DetectedObstacle> = clusters
85 .into_iter()
86 .filter(|pts| pts.len() >= self.config.min_obstacle_size)
87 .filter_map(|pts| self.cluster_to_obstacle(&pts, robot_pos))
88 .filter(|o| o.min_distance <= self.config.max_detection_range)
89 .collect();
90
91 obstacles.sort_by(|a, b| {
92 a.min_distance
93 .partial_cmp(&b.min_distance)
94 .unwrap_or(std::cmp::Ordering::Equal)
95 });
96
97 obstacles
98 }
99
100 pub fn classify_obstacles(
108 &self,
109 obstacles: &[DetectedObstacle],
110 ) -> Vec<ClassifiedObstacle> {
111 obstacles
112 .iter()
113 .map(|o| {
114 let (class, confidence) = self.classify_single(o);
115 ClassifiedObstacle {
116 obstacle: o.clone(),
117 class,
118 confidence,
119 }
120 })
121 .collect()
122 }
123
124 fn cluster_to_obstacle(
127 &self,
128 points: &[Point3D],
129 robot_pos: &[f64; 3],
130 ) -> Option<DetectedObstacle> {
131 if points.is_empty() {
132 return None;
133 }
134
135 let (mut min_x, mut min_y, mut min_z) = (f64::MAX, f64::MAX, f64::MAX);
136 let (mut max_x, mut max_y, mut max_z) = (f64::MIN, f64::MIN, f64::MIN);
137 let (mut sum_x, mut sum_y, mut sum_z) = (0.0_f64, 0.0_f64, 0.0_f64);
138
139 for p in points {
140 let (px, py, pz) = (p.x as f64, p.y as f64, p.z as f64);
141 min_x = min_x.min(px);
142 min_y = min_y.min(py);
143 min_z = min_z.min(pz);
144 max_x = max_x.max(px);
145 max_y = max_y.max(py);
146 max_z = max_z.max(pz);
147 sum_x += px;
148 sum_y += py;
149 sum_z += pz;
150 }
151
152 let n = points.len() as f64;
153 let center = [sum_x / n, sum_y / n, sum_z / n];
154 let extent = [
155 (max_x - min_x) / 2.0 + self.config.safety_margin,
156 (max_y - min_y) / 2.0 + self.config.safety_margin,
157 (max_z - min_z) / 2.0 + self.config.safety_margin,
158 ];
159
160 let dist = ((center[0] - robot_pos[0]).powi(2)
161 + (center[1] - robot_pos[1]).powi(2)
162 + (center[2] - robot_pos[2]).powi(2))
163 .sqrt();
164
165 Some(DetectedObstacle {
166 center,
167 extent,
168 point_count: points.len(),
169 min_distance: dist,
170 })
171 }
172
173 fn classify_single(&self, obstacle: &DetectedObstacle) -> (ObstacleClass, f32) {
174 let exts = &obstacle.extent;
175 let max_ext = exts[0].max(exts[1]).max(exts[2]);
176 let min_ext = exts[0].min(exts[1]).min(exts[2]);
177
178 if min_ext < f64::EPSILON {
179 return (ObstacleClass::Unknown, 0.3);
180 }
181
182 let ratio = max_ext / min_ext;
183
184 if ratio > 3.0 {
185 let confidence = (ratio / 10.0).min(1.0) as f32;
187 (ObstacleClass::Static, confidence.max(0.6))
188 } else if ratio <= 2.0 {
189 let confidence = (1.0 - (ratio - 1.0) / 2.0).max(0.5) as f32;
191 (ObstacleClass::Dynamic, confidence)
192 } else {
193 (ObstacleClass::Unknown, 0.4)
194 }
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PointCloud` from raw `[x, y, z]` triples.
    fn make_cloud(raw: &[[f32; 3]]) -> PointCloud {
        let points: Vec<Point3D> = raw.iter().map(|p| Point3D::new(p[0], p[1], p[2])).collect();
        PointCloud::new(points, 0)
    }

    /// Detector with a 0.1 safety margin and the given size/range limits.
    fn detector(min_obstacle_size: usize, max_detection_range: f64) -> ObstacleDetector {
        ObstacleDetector::new(ObstacleConfig {
            min_obstacle_size,
            max_detection_range,
            safety_margin: 0.1,
        })
    }

    /// Classifies a single obstacle with the given extents using the default config.
    fn classify_extent(extent: [f64; 3]) -> ClassifiedObstacle {
        let det = ObstacleDetector::new(ObstacleConfig::default());
        let obstacle = DetectedObstacle {
            center: [5.0, 0.0, 0.0],
            extent,
            point_count: 10,
            min_distance: 5.0,
        };
        let mut classified = det.classify_obstacles(&[obstacle]);
        assert_eq!(classified.len(), 1);
        classified.remove(0)
    }

    #[test]
    fn test_detect_empty_cloud() {
        let det = ObstacleDetector::new(ObstacleConfig::default());
        let cloud = PointCloud::default();
        let result = det.detect(&cloud, &[0.0, 0.0, 0.0]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_detect_single_cluster() {
        let det = detector(3, 100.0);
        let cloud = make_cloud(&[
            [1.0, 1.0, 0.0],
            [1.1, 1.0, 0.0],
            [1.0, 1.1, 0.0],
            [1.1, 1.1, 0.0],
        ]);
        let result = det.detect(&cloud, &[0.0, 0.0, 0.0]);
        assert_eq!(result.len(), 1);
        assert!(result[0].min_distance > 0.0);
        assert_eq!(result[0].point_count, 4);
    }

    #[test]
    fn test_detect_filters_by_range() {
        let det = detector(3, 1.0);
        // A cluster around x = 10 lies well outside the 1.0 detection range.
        let cloud = make_cloud(&[
            [10.0, 0.0, 0.0],
            [10.1, 0.0, 0.0],
            [10.0, 0.1, 0.0],
        ]);
        let result = det.detect(&cloud, &[0.0, 0.0, 0.0]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_detect_filters_small_clusters() {
        let det = detector(5, 100.0);
        // Only three points: below the minimum obstacle size of five.
        let cloud = make_cloud(&[
            [1.0, 1.0, 0.0],
            [1.1, 1.0, 0.0],
            [1.0, 1.1, 0.0],
        ]);
        let result = det.detect(&cloud, &[0.0, 0.0, 0.0]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_detect_sorted_by_distance() {
        let det = detector(3, 100.0);
        // The far cluster is listed first so that ordering must come from the sort.
        let cloud = make_cloud(&[
            [10.0, 0.0, 0.0],
            [10.1, 0.0, 0.0],
            [10.0, 0.1, 0.0],
            [1.0, 0.0, 0.0],
            [1.1, 0.0, 0.0],
            [1.0, 0.1, 0.0],
        ]);
        let result = det.detect(&cloud, &[0.0, 0.0, 0.0]);
        assert_eq!(result.len(), 2);
        assert!(result[0].min_distance < result[1].min_distance);
        // The near (x ~ 1) cluster must come first.
        assert!(result[0].center[0] < 5.0);
        assert!(result[1].center[0] > 5.0);
    }

    #[test]
    fn test_classify_static_obstacle() {
        let c = classify_extent([10.0, 0.5, 0.5]);
        assert_eq!(c.class, ObstacleClass::Static);
        assert!(c.confidence >= 0.6 && c.confidence <= 1.0);
        assert_eq!(c.obstacle.point_count, 10);
    }

    #[test]
    fn test_classify_dynamic_obstacle() {
        let c = classify_extent([1.0, 1.0, 1.0]);
        assert_eq!(c.class, ObstacleClass::Dynamic);
        // A perfect cube has ratio 1 and therefore full confidence.
        assert!((c.confidence - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_classify_unknown_obstacle() {
        // Ratio exactly 3.0 falls in the ambiguous (2, 3] band.
        let c = classify_extent([3.0, 1.1, 1.0]);
        assert_eq!(c.class, ObstacleClass::Unknown);
        assert!((c.confidence - 0.4).abs() < 1e-6);
    }

    #[test]
    fn test_classify_degenerate_extent() {
        let c = classify_extent([1.0, 1.0, 0.0]);
        assert_eq!(c.class, ObstacleClass::Unknown);
        assert!((c.confidence - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_classify_empty_list() {
        let det = ObstacleDetector::new(ObstacleConfig::default());
        let classified = det.classify_obstacles(&[]);
        assert!(classified.is_empty());
    }

    #[test]
    fn test_obstacle_detector_debug() {
        let det = ObstacleDetector::new(ObstacleConfig::default());
        let dbg = format!("{:?}", det);
        assert!(dbg.contains("ObstacleDetector"));
    }

    #[test]
    fn test_detect_two_separated_clusters() {
        let det = detector(3, 200.0);
        let cloud = make_cloud(&[
            [0.0, 0.0, 0.0],
            [0.1, 0.0, 0.0],
            [0.0, 0.1, 0.0],
            [100.0, 100.0, 0.0],
            [100.1, 100.0, 0.0],
            [100.0, 100.1, 0.0],
        ]);
        let result = det.detect(&cloud, &[50.0, 50.0, 0.0]);
        assert_eq!(result.len(), 2);
    }
}