Skip to main content

torsh_graph/
geometric.rs

1//! Geometric Graph Neural Networks
2//!
3//! This module provides geometric deep learning capabilities for graph-structured
4//! data with spatial coordinates. It includes geometric graph construction methods,
5//! spatial convolutions, and geometric transformations inspired by scirs2-spatial.
6//!
7//! # Features:
8//! - Geometric graph construction (k-NN, radius, Delaunay)
9//! - Point cloud to graph conversion
10//! - Spatial graph convolutions with distance-based weighting
11//! - Geometric transformations (rotation, translation, scaling)
12//! - 3D mesh processing
13//! - Geometric pooling operations
14
15use crate::parameter::Parameter;
16use crate::{GraphData, GraphLayer};
17use scirs2_core::random::thread_rng;
18use std::cmp::Ordering;
19use std::collections::HashMap;
20use torsh_tensor::{
21    creation::{from_vec, randn, zeros},
22    Tensor,
23};
24
/// A location (or displacement vector) in three-dimensional Euclidean space.
#[derive(Debug, Clone, Copy)]
pub struct Point3D {
    /// Coordinate along the x axis.
    pub x: f32,
    /// Coordinate along the y axis.
    pub y: f32,
    /// Coordinate along the z axis.
    pub z: f32,
}

impl Point3D {
    /// Builds a point from its three coordinates.
    pub fn new(x: f32, y: f32, z: f32) -> Self {
        Point3D { x, y, z }
    }

    /// Euclidean (straight-line) distance between `self` and `other`.
    pub fn distance(&self, other: &Point3D) -> f32 {
        let dx = self.x - other.x;
        let dy = self.y - other.y;
        let dz = self.z - other.z;
        let squared = dx.powi(2) + dy.powi(2) + dz.powi(2);
        squared.sqrt()
    }

    /// Dot (inner) product of `self` and `other`, both read as vectors from the origin.
    pub fn dot(&self, other: &Point3D) -> f32 {
        let Point3D { x, y, z } = *other;
        self.x * x + self.y * y + self.z * z
    }

    /// Euclidean length of `self` read as a vector from the origin.
    pub fn norm(&self) -> f32 {
        let squared = self.x.powi(2) + self.y.powi(2) + self.z.powi(2);
        squared.sqrt()
    }
}
51
52/// Geometric graph construction methods
53pub struct GeometricGraphBuilder;
54
55impl GeometricGraphBuilder {
56    /// Build k-nearest neighbors graph from point cloud
57    pub fn knn_graph(points: &[Point3D], k: usize, features: Option<Tensor>) -> GraphData {
58        let num_points = points.len();
59        let mut edges = Vec::new();
60        let mut edge_weights = Vec::new();
61
62        for i in 0..num_points {
63            // Find k nearest neighbors
64            let mut distances: Vec<(usize, f32)> = (0..num_points)
65                .filter(|&j| j != i)
66                .map(|j| (j, points[i].distance(&points[j])))
67                .collect();
68
69            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
70
71            for (j, dist) in distances.iter().take(k) {
72                edges.push(i as f32);
73                edges.push(*j as f32);
74                edge_weights.push(*dist);
75            }
76        }
77
78        let num_edges = edges.len() / 2;
79        let edge_index = from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
80            .expect("knn edge tensor creation should succeed");
81
82        // Use provided features or create default features
83        let x = features.unwrap_or_else(|| {
84            let coords: Vec<f32> = points.iter().flat_map(|p| vec![p.x, p.y, p.z]).collect();
85            from_vec(
86                coords,
87                &[num_points, 3],
88                torsh_core::device::DeviceType::Cpu,
89            )
90            .expect("coordinate feature tensor creation should succeed")
91        });
92
93        let mut graph = GraphData::new(x, edge_index);
94
95        // Store edge weights as edge attributes
96        let edge_attr = from_vec(
97            edge_weights,
98            &[num_edges, 1],
99            torsh_core::device::DeviceType::Cpu,
100        )
101        .expect("edge attribute tensor creation should succeed");
102        graph.edge_attr = Some(edge_attr);
103
104        graph
105    }
106
107    /// Build radius graph (connect all points within radius)
108    pub fn radius_graph(points: &[Point3D], radius: f32, features: Option<Tensor>) -> GraphData {
109        let num_points = points.len();
110        let mut edges = Vec::new();
111        let mut edge_weights = Vec::new();
112
113        for i in 0..num_points {
114            for j in (i + 1)..num_points {
115                let dist = points[i].distance(&points[j]);
116
117                if dist <= radius {
118                    edges.push(i as f32);
119                    edges.push(j as f32);
120                    edges.push(j as f32);
121                    edges.push(i as f32);
122                    edge_weights.push(dist);
123                    edge_weights.push(dist);
124                }
125            }
126        }
127
128        let num_edges = edges.len() / 2;
129        let edge_index = if num_edges > 0 {
130            from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
131                .expect("radius edge tensor creation should succeed")
132        } else {
133            from_vec(vec![], &[2, 0], torsh_core::device::DeviceType::Cpu)
134                .expect("empty edge tensor creation should succeed")
135        };
136
137        let x = features.unwrap_or_else(|| {
138            let coords: Vec<f32> = points.iter().flat_map(|p| vec![p.x, p.y, p.z]).collect();
139            from_vec(
140                coords,
141                &[num_points, 3],
142                torsh_core::device::DeviceType::Cpu,
143            )
144            .expect("coordinate feature tensor creation should succeed")
145        });
146
147        let mut graph = GraphData::new(x, edge_index);
148
149        if num_edges > 0 {
150            let edge_attr = from_vec(
151                edge_weights,
152                &[num_edges, 1],
153                torsh_core::device::DeviceType::Cpu,
154            )
155            .expect("edge attribute tensor creation should succeed");
156            graph.edge_attr = Some(edge_attr);
157        }
158
159        graph
160    }
161
162    /// Build Delaunay triangulation graph (2D simplified version)
163    pub fn delaunay_graph_2d(points: &[(f32, f32)], features: Option<Tensor>) -> GraphData {
164        let num_points = points.len();
165
166        // Simplified Delaunay: connect points that are close
167        // Full Delaunay would require more complex algorithms
168        let mut edges = Vec::new();
169        let mut visited_pairs: std::collections::HashSet<(usize, usize)> =
170            std::collections::HashSet::new();
171
172        for i in 0..num_points {
173            // Find k nearest neighbors for simplified triangulation
174            let k = 5;
175            let mut distances: Vec<(usize, f32)> = (0..num_points)
176                .filter(|&j| j != i)
177                .map(|j| {
178                    let dx = points[i].0 - points[j].0;
179                    let dy = points[i].1 - points[j].1;
180                    (j, (dx * dx + dy * dy).sqrt())
181                })
182                .collect();
183
184            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
185
186            for (j, _) in distances.iter().take(k) {
187                let pair = if i < *j { (i, *j) } else { (*j, i) };
188
189                if !visited_pairs.contains(&pair) {
190                    visited_pairs.insert(pair);
191                    edges.push(i as f32);
192                    edges.push(*j as f32);
193                    edges.push(*j as f32);
194                    edges.push(i as f32);
195                }
196            }
197        }
198
199        let num_edges = edges.len() / 2;
200        let edge_index = from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
201            .expect("delaunay edge tensor creation should succeed");
202
203        let x = features.unwrap_or_else(|| {
204            let coords: Vec<f32> = points.iter().flat_map(|(x, y)| vec![*x, *y]).collect();
205            from_vec(
206                coords,
207                &[num_points, 2],
208                torsh_core::device::DeviceType::Cpu,
209            )
210            .expect("coordinate feature tensor creation should succeed")
211        });
212
213        GraphData::new(x, edge_index)
214    }
215}
216
217/// Geometric convolution layer with distance-based attention
218#[derive(Debug)]
219pub struct GeometricConv {
220    in_features: usize,
221    out_features: usize,
222    hidden_dim: usize,
223
224    // MLP for message generation
225    message_mlp: Vec<Parameter>,
226
227    // Distance encoding
228    distance_encoder: Parameter,
229
230    // Output projection
231    output_weight: Parameter,
232
233    bias: Option<Parameter>,
234}
235
236impl GeometricConv {
237    /// Create a new geometric convolution layer
238    pub fn new(in_features: usize, out_features: usize, hidden_dim: usize, use_bias: bool) -> Self {
239        // MLP layers for message generation
240        let message_layer1 = Parameter::new(
241            randn(&[in_features * 2 + 1, hidden_dim])
242                .expect("randn should succeed for valid dimensions"),
243        );
244        let message_layer2 = Parameter::new(
245            randn(&[hidden_dim, hidden_dim]).expect("randn should succeed for valid dimensions"),
246        );
247
248        let distance_encoder = Parameter::new(
249            randn(&[1, hidden_dim]).expect("randn should succeed for valid dimensions"),
250        );
251        let output_weight = Parameter::new(
252            randn(&[hidden_dim, out_features]).expect("randn should succeed for valid dimensions"),
253        );
254
255        let bias = if use_bias {
256            Some(Parameter::new(
257                zeros(&[out_features]).expect("zeros should succeed for valid dimensions"),
258            ))
259        } else {
260            None
261        };
262
263        Self {
264            in_features,
265            out_features,
266            hidden_dim,
267            message_mlp: vec![message_layer1, message_layer2],
268            distance_encoder,
269            output_weight,
270            bias,
271        }
272    }
273
274    /// Forward pass through geometric convolution
275    pub fn forward(&self, graph: &GraphData) -> GraphData {
276        let num_nodes = graph.num_nodes;
277        let num_edges = graph.num_edges;
278
279        // Get edge distances if available
280        let edge_distances = if let Some(ref edge_attr) = graph.edge_attr {
281            edge_attr.to_vec().expect("conversion should succeed")
282        } else {
283            vec![1.0; num_edges]
284        };
285
286        // Aggregate messages
287        let edge_data = graph
288            .edge_index
289            .to_vec()
290            .expect("conversion should succeed");
291        let mut aggregated = vec![0.0; num_nodes * self.hidden_dim];
292
293        let node_features = graph.x.to_vec().expect("conversion should succeed");
294
295        for edge_idx in 0..num_edges {
296            let src = edge_data[edge_idx * 2] as usize;
297            let dst = edge_data[edge_idx * 2 + 1] as usize;
298
299            if src >= num_nodes || dst >= num_nodes {
300                continue;
301            }
302
303            // Get source and destination features
304            let src_features = &node_features[src * self.in_features..(src + 1) * self.in_features];
305            let dst_features = &node_features[dst * self.in_features..(dst + 1) * self.in_features];
306
307            // Distance encoding
308            let dist = edge_distances[edge_idx.min(edge_distances.len() - 1)];
309
310            // Concatenate features and distance
311            let mut message_input = Vec::new();
312            message_input.extend_from_slice(src_features);
313            message_input.extend_from_slice(dst_features);
314            message_input.push(dist);
315
316            // Compute message through MLP (simplified)
317            let message = self.compute_message(&message_input);
318
319            // Aggregate to destination node
320            for (i, &val) in message.iter().enumerate() {
321                aggregated[dst * self.hidden_dim + i] += val;
322            }
323        }
324
325        // Apply output projection
326        let mut output_features = vec![0.0; num_nodes * self.out_features];
327
328        for node in 0..num_nodes {
329            let agg_features = &aggregated[node * self.hidden_dim..(node + 1) * self.hidden_dim];
330            let output_proj = self
331                .output_weight
332                .clone_data()
333                .to_vec()
334                .expect("conversion should succeed");
335
336            for out_idx in 0..self.out_features {
337                let mut sum = 0.0;
338                for hid_idx in 0..self.hidden_dim {
339                    sum +=
340                        agg_features[hid_idx] * output_proj[hid_idx * self.out_features + out_idx];
341                }
342
343                if let Some(ref bias) = self.bias {
344                    let bias_data = bias
345                        .clone_data()
346                        .to_vec()
347                        .expect("conversion should succeed");
348                    if out_idx < bias_data.len() {
349                        sum += bias_data[out_idx];
350                    }
351                }
352
353                output_features[node * self.out_features + out_idx] = sum;
354            }
355        }
356
357        let output = from_vec(
358            output_features,
359            &[num_nodes, self.out_features],
360            torsh_core::device::DeviceType::Cpu,
361        )
362        .expect("output tensor creation should succeed");
363
364        let mut output_graph = graph.clone();
365        output_graph.x = output;
366        output_graph
367    }
368
369    /// Compute message from concatenated features and distance
370    fn compute_message(&self, input: &[f32]) -> Vec<f32> {
371        // Layer 1
372        let layer1_weights = self.message_mlp[0]
373            .clone_data()
374            .to_vec()
375            .expect("conversion should succeed");
376        let input_dim = self.in_features * 2 + 1;
377        let mut hidden = vec![0.0; self.hidden_dim];
378
379        for h in 0..self.hidden_dim {
380            let mut sum = 0.0;
381            for i in 0..input_dim.min(input.len()) {
382                sum += input[i] * layer1_weights[i * self.hidden_dim + h];
383            }
384            hidden[h] = sum.max(0.0); // ReLU
385        }
386
387        // Layer 2
388        let layer2_weights = self.message_mlp[1]
389            .clone_data()
390            .to_vec()
391            .expect("conversion should succeed");
392        let mut output = vec![0.0; self.hidden_dim];
393
394        for h in 0..self.hidden_dim {
395            let mut sum = 0.0;
396            for i in 0..self.hidden_dim {
397                sum += hidden[i] * layer2_weights[i * self.hidden_dim + h];
398            }
399            output[h] = sum.max(0.0); // ReLU
400        }
401
402        output
403    }
404}
405
406impl GraphLayer for GeometricConv {
407    fn forward(&self, graph: &GraphData) -> GraphData {
408        self.forward(graph)
409    }
410
411    fn parameters(&self) -> Vec<Tensor> {
412        let mut params = Vec::new();
413
414        for layer in &self.message_mlp {
415            params.push(layer.clone_data());
416        }
417
418        params.push(self.distance_encoder.clone_data());
419        params.push(self.output_weight.clone_data());
420
421        if let Some(ref bias) = self.bias {
422            params.push(bias.clone_data());
423        }
424
425        params
426    }
427}
428
429/// Geometric transformations for point clouds and graphs
430pub struct GeometricTransformer;
431
432impl GeometricTransformer {
433    /// Apply rotation to point cloud
434    pub fn rotate_3d(points: &mut [Point3D], axis: &Point3D, angle: f32) {
435        let cos_theta = angle.cos();
436        let sin_theta = angle.sin();
437
438        // Normalize axis
439        let norm = axis.norm();
440        if norm == 0.0 {
441            return;
442        }
443
444        let ux = axis.x / norm;
445        let uy = axis.y / norm;
446        let uz = axis.z / norm;
447
448        // Rotation matrix (Rodrigues' rotation formula)
449        for point in points.iter_mut() {
450            let x = point.x;
451            let y = point.y;
452            let z = point.z;
453
454            // Dot product with axis
455            let dot = ux * x + uy * y + uz * z;
456
457            // Cross product with axis
458            let cross_x = uy * z - uz * y;
459            let cross_y = uz * x - ux * z;
460            let cross_z = ux * y - uy * x;
461
462            // Apply rotation
463            point.x = x * cos_theta + cross_x * sin_theta + ux * dot * (1.0 - cos_theta);
464            point.y = y * cos_theta + cross_y * sin_theta + uy * dot * (1.0 - cos_theta);
465            point.z = z * cos_theta + cross_z * sin_theta + uz * dot * (1.0 - cos_theta);
466        }
467    }
468
469    /// Apply translation to point cloud
470    pub fn translate_3d(points: &mut [Point3D], offset: &Point3D) {
471        for point in points.iter_mut() {
472            point.x += offset.x;
473            point.y += offset.y;
474            point.z += offset.z;
475        }
476    }
477
478    /// Apply scaling to point cloud
479    pub fn scale_3d(points: &mut [Point3D], scale: f32) {
480        for point in points.iter_mut() {
481            point.x *= scale;
482            point.y *= scale;
483            point.z *= scale;
484        }
485    }
486
487    /// Normalize point cloud to unit sphere
488    pub fn normalize_to_unit_sphere(points: &mut [Point3D]) {
489        if points.is_empty() {
490            return;
491        }
492
493        // Find center
494        let mut center = Point3D::new(0.0, 0.0, 0.0);
495        for point in points.iter() {
496            center.x += point.x;
497            center.y += point.y;
498            center.z += point.z;
499        }
500        center.x /= points.len() as f32;
501        center.y /= points.len() as f32;
502        center.z /= points.len() as f32;
503
504        // Translate to origin
505        Self::translate_3d(points, &Point3D::new(-center.x, -center.y, -center.z));
506
507        // Find max distance
508        let max_dist = points
509            .iter()
510            .map(|p| p.norm())
511            .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal))
512            .unwrap_or(1.0);
513
514        // Scale to unit sphere
515        if max_dist > 0.0 {
516            Self::scale_3d(points, 1.0 / max_dist);
517        }
518    }
519}
520
521/// Geometric pooling operations
522pub struct GeometricPooling;
523
524impl GeometricPooling {
525    /// Voxel-based pooling (divide space into voxels and pool within each)
526    pub fn voxel_pool(
527        points: &[Point3D],
528        features: &Tensor,
529        voxel_size: f32,
530    ) -> (Vec<Point3D>, Tensor) {
531        let feature_data = features.to_vec().expect("conversion should succeed");
532        let feature_dim = features.shape().dims()[1];
533
534        // Compute voxel indices
535        let mut voxel_map: HashMap<(i32, i32, i32), Vec<usize>> = HashMap::new();
536
537        for (i, point) in points.iter().enumerate() {
538            let vx = (point.x / voxel_size).floor() as i32;
539            let vy = (point.y / voxel_size).floor() as i32;
540            let vz = (point.z / voxel_size).floor() as i32;
541
542            voxel_map
543                .entry((vx, vy, vz))
544                .or_insert_with(Vec::new)
545                .push(i);
546        }
547
548        // Pool points and features within each voxel
549        let mut pooled_points = Vec::new();
550        let mut pooled_features = Vec::new();
551
552        for (_voxel, indices) in voxel_map {
553            if indices.is_empty() {
554                continue;
555            }
556
557            // Average position
558            let mut avg_point = Point3D::new(0.0, 0.0, 0.0);
559            for &idx in &indices {
560                avg_point.x += points[idx].x;
561                avg_point.y += points[idx].y;
562                avg_point.z += points[idx].z;
563            }
564            avg_point.x /= indices.len() as f32;
565            avg_point.y /= indices.len() as f32;
566            avg_point.z /= indices.len() as f32;
567
568            pooled_points.push(avg_point);
569
570            // Average features
571            let mut avg_features = vec![0.0; feature_dim];
572            for &idx in &indices {
573                for d in 0..feature_dim {
574                    avg_features[d] += feature_data[idx * feature_dim + d];
575                }
576            }
577            for val in &mut avg_features {
578                *val /= indices.len() as f32;
579            }
580
581            pooled_features.extend(avg_features);
582        }
583
584        let pooled_tensor = from_vec(
585            pooled_features,
586            &[pooled_points.len(), feature_dim],
587            torsh_core::device::DeviceType::Cpu,
588        )
589        .expect("pooled tensor creation should succeed");
590
591        (pooled_points, pooled_tensor)
592    }
593
594    /// Farthest point sampling
595    pub fn farthest_point_sampling(
596        points: &[Point3D],
597        features: &Tensor,
598        num_samples: usize,
599    ) -> (Vec<Point3D>, Tensor) {
600        let num_points = points.len();
601        let feature_dim = features.shape().dims()[1];
602        let feature_data = features.to_vec().expect("conversion should succeed");
603
604        if num_samples >= num_points {
605            return (points.to_vec(), features.clone());
606        }
607
608        let mut selected = Vec::new();
609        let mut distances = vec![f32::MAX; num_points];
610
611        // Start with random point
612        let mut rng = thread_rng();
613        let first_idx = rng.gen_range(0..num_points);
614        selected.push(first_idx);
615
616        // Update distances
617        for i in 0..num_points {
618            distances[i] = points[i].distance(&points[first_idx]);
619        }
620
621        // Iteratively select farthest point
622        for _ in 1..num_samples {
623            let farthest_idx = distances
624                .iter()
625                .enumerate()
626                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
627                .map(|(idx, _)| idx)
628                .unwrap_or(0);
629
630            selected.push(farthest_idx);
631
632            // Update distances
633            for i in 0..num_points {
634                let dist = points[i].distance(&points[farthest_idx]);
635                distances[i] = distances[i].min(dist);
636            }
637        }
638
639        // Extract selected points and features
640        let sampled_points: Vec<_> = selected.iter().map(|&idx| points[idx]).collect();
641        let sampled_features: Vec<_> = selected
642            .iter()
643            .flat_map(|&idx| {
644                let start = idx * feature_dim;
645                let end = start + feature_dim;
646                &feature_data[start..end]
647            })
648            .copied()
649            .collect();
650
651        let sampled_tensor = from_vec(
652            sampled_features,
653            &[num_samples, feature_dim],
654            torsh_core::device::DeviceType::Cpu,
655        )
656        .expect("sampled tensor creation should succeed");
657
658        (sampled_points, sampled_tensor)
659    }
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons.
    const EPS: f32 = 1e-5;

    /// A 3-4-5 right triangle gives a distance of exactly 5.
    #[test]
    fn test_point3d_distance() {
        let origin = Point3D::new(0.0, 0.0, 0.0);
        let corner = Point3D::new(3.0, 4.0, 0.0);

        let gap = origin.distance(&corner) - 5.0;
        assert!(gap.abs() < EPS);
    }

    /// k-NN over the unit square keeps all nodes, uses coordinates as
    /// features and attaches edge attributes.
    #[test]
    fn test_knn_graph() {
        let square = [
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 1.0, 0.0),
            Point3D::new(1.0, 1.0, 0.0),
        ];

        let graph = GeometricGraphBuilder::knn_graph(&square, 2, None);

        assert_eq!(graph.num_nodes, 4);
        // Default features are the 3D coordinates.
        assert_eq!(graph.x.shape().dims()[1], 3);
        assert!(graph.edge_attr.is_some());
    }

    /// Two of the three collinear points are within the radius of each other.
    #[test]
    fn test_radius_graph() {
        let line = [
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(0.5, 0.0, 0.0),
            Point3D::new(2.0, 0.0, 0.0),
        ];

        let graph = GeometricGraphBuilder::radius_graph(&line, 1.0, None);

        assert_eq!(graph.num_nodes, 3);
        // Points 0 and 1 are connected in both directions.
        assert!(graph.num_edges >= 2);
    }

    /// The convolution keeps the node count and maps 3 features to 6.
    #[test]
    fn test_geometric_conv() {
        let triangle = [
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 1.0, 0.0),
        ];

        let input = GeometricGraphBuilder::knn_graph(&triangle, 2, None);
        let layer = GeometricConv::new(3, 6, 8, true);

        let result = layer.forward(&input);

        assert_eq!(result.num_nodes, 3);
        assert_eq!(result.x.shape().dims()[1], 6);
    }

    /// A quarter turn about the Z axis sends (1, 0, 0) to (0, 1, 0).
    #[test]
    fn test_geometric_rotation() {
        let mut cloud = [Point3D::new(1.0, 0.0, 0.0)];

        let z_axis = Point3D::new(0.0, 0.0, 1.0);
        let quarter_turn = std::f32::consts::PI / 2.0;

        GeometricTransformer::rotate_3d(&mut cloud, &z_axis, quarter_turn);

        let rotated = cloud[0];
        assert!((rotated.x - 0.0).abs() < EPS);
        assert!((rotated.y - 1.0).abs() < EPS);
    }

    /// After normalisation no point lies outside the unit sphere.
    #[test]
    fn test_normalize_to_unit_sphere() {
        let mut cloud = [
            Point3D::new(2.0, 0.0, 0.0),
            Point3D::new(0.0, 2.0, 0.0),
            Point3D::new(0.0, 0.0, 2.0),
        ];

        GeometricTransformer::normalize_to_unit_sphere(&mut cloud);

        for p in cloud.iter() {
            assert!(p.norm() <= 1.0 + EPS);
        }
    }

    /// Pooling never produces more points than it received and keeps the
    /// feature width.
    #[test]
    fn test_voxel_pooling() {
        let cloud = [
            Point3D::new(0.1, 0.1, 0.1),
            Point3D::new(0.2, 0.2, 0.2),
            Point3D::new(1.1, 1.1, 1.1),
        ];

        let features = randn(&[3, 4]).unwrap();

        let (pooled_points, pooled_features) =
            GeometricPooling::voxel_pool(&cloud, &features, 1.0);

        assert!(pooled_points.len() <= 3);
        assert_eq!(pooled_features.shape().dims()[1], 4);
    }

    /// Sampling 2 of 4 points yields 2 points and a `[2, 3]` feature tensor.
    #[test]
    fn test_farthest_point_sampling() {
        let cloud = [
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 1.0, 0.0),
            Point3D::new(0.0, 0.0, 1.0),
        ];

        let features = randn(&[4, 3]).unwrap();

        let (sampled_points, sampled_features) =
            GeometricPooling::farthest_point_sampling(&cloud, &features, 2);

        assert_eq!(sampled_points.len(), 2);
        assert_eq!(sampled_features.shape().dims(), &[2, 3]);
    }
}