1use std::collections::HashMap;
40use thiserror::Error;
41
42#[derive(Debug, Clone, Copy, PartialEq, Error)]
44pub enum MessageConfigError {
45 #[error("radius must be positive")]
46 InvalidRadius,
47}
48
49#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
51pub enum MessagePhaseError {
52 #[error("cannot output messages after finalize")]
54 AlreadyFinalized,
55 #[error("must finalize messages before reading")]
57 NotFinalized,
58}
59
/// Convenience bound bundling `Clone + Send + Sync + 'static` for message
/// payload types.
///
/// The blanket implementation below covers every type meeting those bounds,
/// so this trait never needs to be implemented by hand.
pub trait Message: Clone + Send + Sync + 'static {}

impl<T> Message for T where T: Clone + Send + Sync + 'static {}
68
69#[derive(Debug, Clone)]
79pub struct BruteForceMessages<M> {
80 buffer: Vec<M>,
81 finalized: bool,
82}
83
84impl<M: Clone> BruteForceMessages<M> {
85 pub fn new() -> Self {
87 Self {
88 buffer: Vec::new(),
89 finalized: false,
90 }
91 }
92
93 pub fn with_capacity(capacity: usize) -> Self {
95 Self {
96 buffer: Vec::with_capacity(capacity),
97 finalized: false,
98 }
99 }
100
101 pub fn output(&mut self, message: M) {
103 self.try_output(message)
104 .expect("cannot output after finalize");
105 }
106
107 pub fn try_output(&mut self, message: M) -> Result<(), MessagePhaseError> {
109 if self.finalized {
110 return Err(MessagePhaseError::AlreadyFinalized);
111 }
112 self.buffer.push(message);
113 Ok(())
114 }
115
116 pub fn finalize(&mut self) {
120 self.finalized = true;
121 }
122
123 pub fn read_all(&self) -> &[M] {
130 self.try_read_all()
131 .expect("must finalize messages before reading")
132 }
133
134 pub fn try_read_all(&self) -> Result<&[M], MessagePhaseError> {
136 if !self.finalized {
137 return Err(MessagePhaseError::NotFinalized);
138 }
139 Ok(&self.buffer)
140 }
141
142 pub fn is_finalized(&self) -> bool {
144 self.finalized
145 }
146
147 pub fn len(&self) -> usize {
149 self.buffer.len()
150 }
151
152 pub fn is_empty(&self) -> bool {
154 self.buffer.is_empty()
155 }
156
157 pub fn clear(&mut self) {
159 self.buffer.clear();
160 self.finalized = false;
161 }
162}
163
164impl<M: Clone> Default for BruteForceMessages<M> {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
170#[derive(Debug, Clone)]
182pub struct SpatialMessages2D<M> {
183 messages: Vec<M>,
184 positions: Vec<(f32, f32)>,
185 radius: f32,
186 bin_map: HashMap<(i32, i32), (usize, usize)>,
188 sorted_indices: Vec<usize>,
190 finalized: bool,
191}
192
193impl<M: Clone> SpatialMessages2D<M> {
194 pub fn new(radius: f32) -> Result<Self, MessageConfigError> {
198 if radius <= 0.0 {
199 return Err(MessageConfigError::InvalidRadius);
200 }
201 Ok(Self {
202 messages: Vec::new(),
203 positions: Vec::new(),
204 radius,
205 bin_map: HashMap::new(),
206 sorted_indices: Vec::new(),
207 finalized: false,
208 })
209 }
210
211 pub fn output(&mut self, message: M, x: f32, y: f32) {
213 self.try_output(message, x, y)
214 .expect("cannot output after finalize");
215 }
216
217 pub fn try_output(&mut self, message: M, x: f32, y: f32) -> Result<(), MessagePhaseError> {
219 if self.finalized {
220 return Err(MessagePhaseError::AlreadyFinalized);
221 }
222 self.messages.push(message);
223 self.positions.push((x, y));
224 Ok(())
225 }
226
227 pub fn finalize(&mut self) {
232 let n = self.messages.len();
233 let inv_radius = 1.0 / self.radius;
234
235 let mut bin_assignments: Vec<(i32, i32)> = Vec::with_capacity(n);
237 for &(x, y) in &self.positions {
238 let bx = (x * inv_radius).floor() as i32;
239 let by = (y * inv_radius).floor() as i32;
240 bin_assignments.push((bx, by));
241 }
242
243 self.sorted_indices.clear();
245 self.sorted_indices.extend(0..n);
246 self.sorted_indices
247 .sort_unstable_by(|&a, &b| bin_assignments[a].cmp(&bin_assignments[b]));
248
249 self.bin_map.clear();
251 if !self.sorted_indices.is_empty() {
252 let mut current_bin = bin_assignments[self.sorted_indices[0]];
253 let mut start = 0;
254
255 for i in 1..n {
256 let bin = bin_assignments[self.sorted_indices[i]];
257 if bin != current_bin {
258 self.bin_map.insert(current_bin, (start, i - start));
259 current_bin = bin;
260 start = i;
261 }
262 }
263 self.bin_map.insert(current_bin, (start, n - start));
264 }
265
266 self.finalized = true;
267 }
268
269 pub fn read_nearby(&self, x: f32, y: f32, radius: f32) -> SpatialIter2D<'_, M> {
281 self.try_read_nearby(x, y, radius)
282 .expect("must finalize messages before reading")
283 }
284
285 pub fn try_read_nearby(
287 &self,
288 x: f32,
289 y: f32,
290 radius: f32,
291 ) -> Result<SpatialIter2D<'_, M>, MessagePhaseError> {
292 if !self.finalized {
293 return Err(MessagePhaseError::NotFinalized);
294 }
295 let inv = 1.0 / self.radius;
296 let center_bx = (x * inv).floor() as i32;
297 let center_by = (y * inv).floor() as i32;
298 let grid_r = ((radius / self.radius).ceil() as i32).max(1);
299
300 let first_dx = -grid_r;
302 let first_dy = -grid_r;
303 let bx = center_bx + first_dx;
304 let by = center_by + first_dy;
305 let (bin_start, bin_count) = self.bin_map.get(&(bx, by)).copied().unwrap_or((0, 0));
306
307 Ok(SpatialIter2D {
308 messages: &self.messages,
309 positions: &self.positions,
310 sorted_indices: &self.sorted_indices,
311 bin_map: &self.bin_map,
312 query_x: x,
313 query_y: y,
314 radius_sq: radius * radius,
315 center_bx,
316 center_by,
317 grid_r,
318 cur_dx: first_dx,
319 cur_dy: first_dy,
320 bin_start,
321 bin_offset: 0,
322 bin_count,
323 })
324 }
325
326 pub fn is_finalized(&self) -> bool {
328 self.finalized
329 }
330
331 pub fn len(&self) -> usize {
333 self.messages.len()
334 }
335
336 pub fn is_empty(&self) -> bool {
338 self.messages.is_empty()
339 }
340
341 pub fn clear(&mut self) {
343 self.messages.clear();
344 self.positions.clear();
345 self.bin_map.clear();
346 self.sorted_indices.clear();
347 self.finalized = false;
348 }
349}
350
/// Iterator over the messages of a [`SpatialMessages2D`] near a query point.
///
/// Visits the `(2 * grid_r + 1)^2` grid cells centred on the query cell in
/// row-major order (`cur_dx` varies fastest). For every message whose squared
/// distance to the query point is at most `radius_sq`, it yields
/// `(message, squared_distance)`. Once it has returned `None`, it keeps
/// returning `None`.
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct SpatialIter2D<'a, M> {
    /// All payloads of the board, indexed by message index.
    messages: &'a [M],
    /// Position of each payload, indexed like `messages`.
    positions: &'a [(f32, f32)],
    /// Message indices grouped by cell.
    sorted_indices: &'a [usize],
    /// Cell coordinate -> `(start, len)` run inside `sorted_indices`.
    bin_map: &'a HashMap<(i32, i32), (usize, usize)>,
    /// Query point.
    query_x: f32,
    query_y: f32,
    /// Squared query radius; messages farther away are skipped.
    radius_sq: f32,
    /// Cell containing the query point.
    center_bx: i32,
    center_by: i32,
    /// Half-width, in cells, of the scanned square window.
    grid_r: i32,
    /// Offset of the cell currently being drained, relative to the centre cell.
    cur_dx: i32,
    cur_dy: i32,
    /// Current cell's run is `sorted_indices[bin_start..bin_start + bin_count]`;
    /// its first `bin_offset` entries have already been examined.
    bin_start: usize,
    bin_offset: usize,
    bin_count: usize,
}

impl<M> SpatialIter2D<'_, M> {
    /// Moves the cursor to the next occupied cell of the window and loads its run.
    ///
    /// Returns `false` when the cursor is already on the window's last cell
    /// `(grid_r, grid_r)`. The cursor stays there, so every later call also
    /// returns `false` instead of wandering outside the window.
    fn advance_cell(&mut self) -> bool {
        loop {
            // Increment only while strictly below `grid_r`, so the cursor can
            // never overflow even when `grid_r` is `i32::MAX`.
            if self.cur_dx < self.grid_r {
                self.cur_dx += 1;
            } else if self.cur_dy < self.grid_r {
                self.cur_dx = -self.grid_r;
                self.cur_dy += 1;
            } else {
                return false;
            }

            // Cells whose coordinates fall outside `i32` cannot be map keys;
            // skip them rather than overflowing (a panic in debug builds).
            let cell = match (
                self.center_bx.checked_add(self.cur_dx),
                self.center_by.checked_add(self.cur_dy),
            ) {
                (Some(bx), Some(by)) => self.bin_map.get(&(bx, by)).copied(),
                _ => None,
            };
            if let Some((start, count)) = cell {
                self.bin_start = start;
                self.bin_offset = 0;
                self.bin_count = count;
                return true;
            }
        }
    }
}

impl<'a, M> Iterator for SpatialIter2D<'a, M> {
    /// A message and its squared distance to the query point.
    type Item = (&'a M, f32);

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Examine the remaining candidates of the current cell.
            while self.bin_offset < self.bin_count {
                let idx = self.sorted_indices[self.bin_start + self.bin_offset];
                self.bin_offset += 1;

                let (px, py) = self.positions[idx];
                let dx = self.query_x - px;
                let dy = self.query_y - py;
                let dist_sq = dx * dx + dy * dy;
                if dist_sq <= self.radius_sq {
                    return Some((&self.messages[idx], dist_sq));
                }
            }
            if !self.advance_cell() {
                return None;
            }
        }
    }
}

// `advance_cell` never moves past the last cell, so exhaustion is permanent.
impl<M> std::iter::FusedIterator for SpatialIter2D<'_, M> {}
413
414#[derive(Debug, Clone)]
420pub struct SpatialMessages3D<M> {
421 messages: Vec<M>,
422 positions: Vec<(f32, f32, f32)>,
423 radius: f32,
424 bin_map: HashMap<(i32, i32, i32), (usize, usize)>,
425 sorted_indices: Vec<usize>,
426 finalized: bool,
427}
428
429impl<M: Clone> SpatialMessages3D<M> {
430 pub fn new(radius: f32) -> Result<Self, MessageConfigError> {
432 if radius <= 0.0 {
433 return Err(MessageConfigError::InvalidRadius);
434 }
435 Ok(Self {
436 messages: Vec::new(),
437 positions: Vec::new(),
438 radius,
439 bin_map: HashMap::new(),
440 sorted_indices: Vec::new(),
441 finalized: false,
442 })
443 }
444
445 pub fn output(&mut self, message: M, x: f32, y: f32, z: f32) {
447 self.try_output(message, x, y, z)
448 .expect("cannot output after finalize");
449 }
450
451 pub fn try_output(
453 &mut self,
454 message: M,
455 x: f32,
456 y: f32,
457 z: f32,
458 ) -> Result<(), MessagePhaseError> {
459 if self.finalized {
460 return Err(MessagePhaseError::AlreadyFinalized);
461 }
462 self.messages.push(message);
463 self.positions.push((x, y, z));
464 Ok(())
465 }
466
467 pub fn finalize(&mut self) {
469 let n = self.messages.len();
470 let inv_radius = 1.0 / self.radius;
471
472 let mut bin_assignments: Vec<(i32, i32, i32)> = Vec::with_capacity(n);
473 for &(x, y, z) in &self.positions {
474 let bx = (x * inv_radius).floor() as i32;
475 let by = (y * inv_radius).floor() as i32;
476 let bz = (z * inv_radius).floor() as i32;
477 bin_assignments.push((bx, by, bz));
478 }
479
480 self.sorted_indices.clear();
481 self.sorted_indices.extend(0..n);
482 self.sorted_indices
483 .sort_unstable_by(|&a, &b| bin_assignments[a].cmp(&bin_assignments[b]));
484
485 self.bin_map.clear();
486 if !self.sorted_indices.is_empty() {
487 let mut current_bin = bin_assignments[self.sorted_indices[0]];
488 let mut start = 0;
489
490 for i in 1..n {
491 let bin = bin_assignments[self.sorted_indices[i]];
492 if bin != current_bin {
493 self.bin_map.insert(current_bin, (start, i - start));
494 current_bin = bin;
495 start = i;
496 }
497 }
498 self.bin_map.insert(current_bin, (start, n - start));
499 }
500
501 self.finalized = true;
502 }
503
504 pub fn read_nearby(&self, x: f32, y: f32, z: f32, radius: f32) -> SpatialIter3D<'_, M> {
513 self.try_read_nearby(x, y, z, radius)
514 .expect("must finalize messages before reading")
515 }
516
517 pub fn try_read_nearby(
519 &self,
520 x: f32,
521 y: f32,
522 z: f32,
523 radius: f32,
524 ) -> Result<SpatialIter3D<'_, M>, MessagePhaseError> {
525 if !self.finalized {
526 return Err(MessagePhaseError::NotFinalized);
527 }
528 let inv = 1.0 / self.radius;
529 let center_bx = (x * inv).floor() as i32;
530 let center_by = (y * inv).floor() as i32;
531 let center_bz = (z * inv).floor() as i32;
532 let grid_r = ((radius / self.radius).ceil() as i32).max(1);
533
534 let first_dx = -grid_r;
536 let first_dy = -grid_r;
537 let first_dz = -grid_r;
538 let bx = center_bx + first_dx;
539 let by = center_by + first_dy;
540 let bz = center_bz + first_dz;
541 let (bin_start, bin_count) = self.bin_map.get(&(bx, by, bz)).copied().unwrap_or((0, 0));
542
543 Ok(SpatialIter3D {
544 messages: &self.messages,
545 positions: &self.positions,
546 sorted_indices: &self.sorted_indices,
547 bin_map: &self.bin_map,
548 query_x: x,
549 query_y: y,
550 query_z: z,
551 radius_sq: radius * radius,
552 center_bx,
553 center_by,
554 center_bz,
555 grid_r,
556 cur_dx: first_dx,
557 cur_dy: first_dy,
558 cur_dz: first_dz,
559 bin_start,
560 bin_offset: 0,
561 bin_count,
562 })
563 }
564
565 pub fn is_finalized(&self) -> bool {
567 self.finalized
568 }
569
570 pub fn len(&self) -> usize {
572 self.messages.len()
573 }
574
575 pub fn is_empty(&self) -> bool {
577 self.messages.is_empty()
578 }
579
580 pub fn clear(&mut self) {
582 self.messages.clear();
583 self.positions.clear();
584 self.bin_map.clear();
585 self.sorted_indices.clear();
586 self.finalized = false;
587 }
588}
589
/// Iterator over the messages of a [`SpatialMessages3D`] near a query point.
///
/// Visits the `(2 * grid_r + 1)^3` grid cells centred on the query cell
/// (`cur_dx` varies fastest, then `cur_dy`, then `cur_dz`). For every message
/// whose squared distance to the query point is at most `radius_sq`, it yields
/// `(message, squared_distance)`. Once it has returned `None`, it keeps
/// returning `None`.
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct SpatialIter3D<'a, M> {
    /// All payloads of the board, indexed by message index.
    messages: &'a [M],
    /// Position of each payload, indexed like `messages`.
    positions: &'a [(f32, f32, f32)],
    /// Message indices grouped by cell.
    sorted_indices: &'a [usize],
    /// Cell coordinate -> `(start, len)` run inside `sorted_indices`.
    bin_map: &'a HashMap<(i32, i32, i32), (usize, usize)>,
    /// Query point.
    query_x: f32,
    query_y: f32,
    query_z: f32,
    /// Squared query radius; messages farther away are skipped.
    radius_sq: f32,
    /// Cell containing the query point.
    center_bx: i32,
    center_by: i32,
    center_bz: i32,
    /// Half-width, in cells, of the scanned cubic window.
    grid_r: i32,
    /// Offset of the cell currently being drained, relative to the centre cell.
    cur_dx: i32,
    cur_dy: i32,
    cur_dz: i32,
    /// Current cell's run is `sorted_indices[bin_start..bin_start + bin_count]`;
    /// its first `bin_offset` entries have already been examined.
    bin_start: usize,
    bin_offset: usize,
    bin_count: usize,
}

impl<M> SpatialIter3D<'_, M> {
    /// Moves the cursor to the next occupied cell of the window and loads its run.
    ///
    /// Returns `false` when the cursor is already on the window's last cell
    /// `(grid_r, grid_r, grid_r)`. The cursor stays there, so every later call
    /// also returns `false` instead of wandering outside the window.
    fn advance_cell(&mut self) -> bool {
        loop {
            // Increment only while strictly below `grid_r`, so the cursor can
            // never overflow even when `grid_r` is `i32::MAX`.
            if self.cur_dx < self.grid_r {
                self.cur_dx += 1;
            } else if self.cur_dy < self.grid_r {
                self.cur_dx = -self.grid_r;
                self.cur_dy += 1;
            } else if self.cur_dz < self.grid_r {
                self.cur_dx = -self.grid_r;
                self.cur_dy = -self.grid_r;
                self.cur_dz += 1;
            } else {
                return false;
            }

            // Cells whose coordinates fall outside `i32` cannot be map keys;
            // skip them rather than overflowing (a panic in debug builds).
            let cell = match (
                self.center_bx.checked_add(self.cur_dx),
                self.center_by.checked_add(self.cur_dy),
                self.center_bz.checked_add(self.cur_dz),
            ) {
                (Some(bx), Some(by), Some(bz)) => self.bin_map.get(&(bx, by, bz)).copied(),
                _ => None,
            };
            if let Some((start, count)) = cell {
                self.bin_start = start;
                self.bin_offset = 0;
                self.bin_count = count;
                return true;
            }
        }
    }
}

impl<'a, M> Iterator for SpatialIter3D<'a, M> {
    /// A message and its squared distance to the query point.
    type Item = (&'a M, f32);

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Examine the remaining candidates of the current cell.
            while self.bin_offset < self.bin_count {
                let idx = self.sorted_indices[self.bin_start + self.bin_offset];
                self.bin_offset += 1;

                let (px, py, pz) = self.positions[idx];
                let dx = self.query_x - px;
                let dy = self.query_y - py;
                let dz = self.query_z - pz;
                let dist_sq = dx * dx + dy * dy + dz * dz;
                if dist_sq <= self.radius_sq {
                    return Some((&self.messages[idx], dist_sq));
                }
            }
            if !self.advance_cell() {
                return None;
            }
        }
    }
}

// `advance_cell` never moves past the last cell, so exhaustion is permanent.
impl<M> std::iter::FusedIterator for SpatialIter3D<'_, M> {}
660
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the labels from the `(message, squared_distance)` pairs of a
    /// spatial query.
    fn labels_of<'a>(hits: impl Iterator<Item = (&'a &'static str, f32)>) -> Vec<&'static str> {
        hits.map(|(label, _)| *label).collect()
    }

    #[test]
    fn brute_force_basic() {
        let mut board = BruteForceMessages::new();
        board.output(42i32);
        board.output(99);
        board.finalize();
        assert_eq!(board.read_all(), &[42, 99]);

        // Clearing drops every message.
        board.clear();
        assert!(board.is_empty());
    }

    #[test]
    fn brute_force_phase_errors_are_typed() {
        let mut board = BruteForceMessages::new();
        // Reading before finalize is a phase error.
        assert_eq!(board.try_read_all(), Err(MessagePhaseError::NotFinalized));

        board.try_output(1).unwrap();
        board.finalize();

        // Writing after finalize is a phase error; the earlier write survives.
        assert_eq!(board.try_output(2), Err(MessagePhaseError::AlreadyFinalized));
        assert_eq!(board.try_read_all().unwrap(), &[1]);
    }

    #[test]
    fn spatial_2d_basic() {
        let mut board = SpatialMessages2D::new(1.0).unwrap();
        board.output("a", 0.0, 0.0);
        board.output("b", 0.5, 0.5);
        board.output("c", 10.0, 10.0);
        board.finalize();

        // "c" lies far outside the query radius; only "a" and "b" qualify.
        let found = labels_of(board.read_nearby(0.0, 0.0, 1.0));
        assert_eq!(found.len(), 2);
        assert!(found.contains(&"a"));
        assert!(found.contains(&"b"));
    }

    #[test]
    fn spatial_2d_phase_errors_are_typed() {
        let mut board = SpatialMessages2D::new(1.0).unwrap();
        let early = board.try_read_nearby(0.0, 0.0, 1.0).err();
        assert_eq!(early, Some(MessagePhaseError::NotFinalized));

        board.try_output("a", 0.0, 0.0).unwrap();
        board.finalize();

        let late = board.try_output("b", 1.0, 1.0);
        assert_eq!(late, Err(MessagePhaseError::AlreadyFinalized));
    }

    #[test]
    fn spatial_3d_basic() {
        let mut board = SpatialMessages3D::new(1.0).unwrap();
        board.output("a", 0.0, 0.0, 0.0);
        board.output("b", 0.5, 0.5, 0.5);
        board.output("c", 10.0, 10.0, 10.0);
        board.finalize();

        // "c" lies far outside the query radius; only "a" and "b" qualify.
        let found = labels_of(board.read_nearby(0.0, 0.0, 0.0, 1.0));
        assert_eq!(found.len(), 2);
        assert!(found.contains(&"a"));
        assert!(found.contains(&"b"));
    }

    #[test]
    fn spatial_3d_phase_errors_are_typed() {
        let mut board = SpatialMessages3D::new(1.0).unwrap();
        let early = board.try_read_nearby(0.0, 0.0, 0.0, 1.0).err();
        assert_eq!(early, Some(MessagePhaseError::NotFinalized));

        board.try_output("a", 0.0, 0.0, 0.0).unwrap();
        board.finalize();

        let late = board.try_output("b", 1.0, 1.0, 1.0);
        assert_eq!(late, Err(MessagePhaseError::AlreadyFinalized));
    }
}