1use crate::parameter::Parameter;
16use crate::{GraphData, GraphLayer};
17use scirs2_core::random::thread_rng;
18use std::cmp::Ordering;
19use std::collections::HashMap;
20use torsh_tensor::{
21 creation::{from_vec, randn, zeros},
22 Tensor,
23};
24
/// A point (or free vector) in 3-D Euclidean space with `f32` coordinates.
///
/// `Default` is the origin `(0, 0, 0)`. `PartialEq` compares coordinates with plain
/// `f32` equality.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Point3D {
    pub x: f32,
    pub y: f32,
    pub z: f32,
}

impl Point3D {
    /// Creates a point from its three coordinates.
    pub fn new(x: f32, y: f32, z: f32) -> Self {
        Self { x, y, z }
    }

    /// Euclidean distance between `self` and `other`.
    pub fn distance(&self, other: &Point3D) -> f32 {
        ((self.x - other.x).powi(2) + (self.y - other.y).powi(2) + (self.z - other.z).powi(2))
            .sqrt()
    }

    /// Dot product of `self` and `other` treated as vectors from the origin.
    pub fn dot(&self, other: &Point3D) -> f32 {
        self.x * other.x + self.y * other.y + self.z * other.z
    }

    /// Euclidean length of `self` treated as a vector from the origin.
    pub fn norm(&self) -> f32 {
        (self.x.powi(2) + self.y.powi(2) + self.z.powi(2)).sqrt()
    }
}
51
52pub struct GeometricGraphBuilder;
54
55impl GeometricGraphBuilder {
56 pub fn knn_graph(points: &[Point3D], k: usize, features: Option<Tensor>) -> GraphData {
58 let num_points = points.len();
59 let mut edges = Vec::new();
60 let mut edge_weights = Vec::new();
61
62 for i in 0..num_points {
63 let mut distances: Vec<(usize, f32)> = (0..num_points)
65 .filter(|&j| j != i)
66 .map(|j| (j, points[i].distance(&points[j])))
67 .collect();
68
69 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
70
71 for (j, dist) in distances.iter().take(k) {
72 edges.push(i as f32);
73 edges.push(*j as f32);
74 edge_weights.push(*dist);
75 }
76 }
77
78 let num_edges = edges.len() / 2;
79 let edge_index = from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
80 .expect("knn edge tensor creation should succeed");
81
82 let x = features.unwrap_or_else(|| {
84 let coords: Vec<f32> = points.iter().flat_map(|p| vec![p.x, p.y, p.z]).collect();
85 from_vec(
86 coords,
87 &[num_points, 3],
88 torsh_core::device::DeviceType::Cpu,
89 )
90 .expect("coordinate feature tensor creation should succeed")
91 });
92
93 let mut graph = GraphData::new(x, edge_index);
94
95 let edge_attr = from_vec(
97 edge_weights,
98 &[num_edges, 1],
99 torsh_core::device::DeviceType::Cpu,
100 )
101 .expect("edge attribute tensor creation should succeed");
102 graph.edge_attr = Some(edge_attr);
103
104 graph
105 }
106
107 pub fn radius_graph(points: &[Point3D], radius: f32, features: Option<Tensor>) -> GraphData {
109 let num_points = points.len();
110 let mut edges = Vec::new();
111 let mut edge_weights = Vec::new();
112
113 for i in 0..num_points {
114 for j in (i + 1)..num_points {
115 let dist = points[i].distance(&points[j]);
116
117 if dist <= radius {
118 edges.push(i as f32);
119 edges.push(j as f32);
120 edges.push(j as f32);
121 edges.push(i as f32);
122 edge_weights.push(dist);
123 edge_weights.push(dist);
124 }
125 }
126 }
127
128 let num_edges = edges.len() / 2;
129 let edge_index = if num_edges > 0 {
130 from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
131 .expect("radius edge tensor creation should succeed")
132 } else {
133 from_vec(vec![], &[2, 0], torsh_core::device::DeviceType::Cpu)
134 .expect("empty edge tensor creation should succeed")
135 };
136
137 let x = features.unwrap_or_else(|| {
138 let coords: Vec<f32> = points.iter().flat_map(|p| vec![p.x, p.y, p.z]).collect();
139 from_vec(
140 coords,
141 &[num_points, 3],
142 torsh_core::device::DeviceType::Cpu,
143 )
144 .expect("coordinate feature tensor creation should succeed")
145 });
146
147 let mut graph = GraphData::new(x, edge_index);
148
149 if num_edges > 0 {
150 let edge_attr = from_vec(
151 edge_weights,
152 &[num_edges, 1],
153 torsh_core::device::DeviceType::Cpu,
154 )
155 .expect("edge attribute tensor creation should succeed");
156 graph.edge_attr = Some(edge_attr);
157 }
158
159 graph
160 }
161
162 pub fn delaunay_graph_2d(points: &[(f32, f32)], features: Option<Tensor>) -> GraphData {
164 let num_points = points.len();
165
166 let mut edges = Vec::new();
169 let mut visited_pairs: std::collections::HashSet<(usize, usize)> =
170 std::collections::HashSet::new();
171
172 for i in 0..num_points {
173 let k = 5;
175 let mut distances: Vec<(usize, f32)> = (0..num_points)
176 .filter(|&j| j != i)
177 .map(|j| {
178 let dx = points[i].0 - points[j].0;
179 let dy = points[i].1 - points[j].1;
180 (j, (dx * dx + dy * dy).sqrt())
181 })
182 .collect();
183
184 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
185
186 for (j, _) in distances.iter().take(k) {
187 let pair = if i < *j { (i, *j) } else { (*j, i) };
188
189 if !visited_pairs.contains(&pair) {
190 visited_pairs.insert(pair);
191 edges.push(i as f32);
192 edges.push(*j as f32);
193 edges.push(*j as f32);
194 edges.push(i as f32);
195 }
196 }
197 }
198
199 let num_edges = edges.len() / 2;
200 let edge_index = from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
201 .expect("delaunay edge tensor creation should succeed");
202
203 let x = features.unwrap_or_else(|| {
204 let coords: Vec<f32> = points.iter().flat_map(|(x, y)| vec![*x, *y]).collect();
205 from_vec(
206 coords,
207 &[num_points, 2],
208 torsh_core::device::DeviceType::Cpu,
209 )
210 .expect("coordinate feature tensor creation should succeed")
211 });
212
213 GraphData::new(x, edge_index)
214 }
215}
216
217#[derive(Debug)]
219pub struct GeometricConv {
220 in_features: usize,
221 out_features: usize,
222 hidden_dim: usize,
223
224 message_mlp: Vec<Parameter>,
226
227 distance_encoder: Parameter,
229
230 output_weight: Parameter,
232
233 bias: Option<Parameter>,
234}
235
236impl GeometricConv {
237 pub fn new(in_features: usize, out_features: usize, hidden_dim: usize, use_bias: bool) -> Self {
239 let message_layer1 = Parameter::new(
241 randn(&[in_features * 2 + 1, hidden_dim])
242 .expect("randn should succeed for valid dimensions"),
243 );
244 let message_layer2 = Parameter::new(
245 randn(&[hidden_dim, hidden_dim]).expect("randn should succeed for valid dimensions"),
246 );
247
248 let distance_encoder = Parameter::new(
249 randn(&[1, hidden_dim]).expect("randn should succeed for valid dimensions"),
250 );
251 let output_weight = Parameter::new(
252 randn(&[hidden_dim, out_features]).expect("randn should succeed for valid dimensions"),
253 );
254
255 let bias = if use_bias {
256 Some(Parameter::new(
257 zeros(&[out_features]).expect("zeros should succeed for valid dimensions"),
258 ))
259 } else {
260 None
261 };
262
263 Self {
264 in_features,
265 out_features,
266 hidden_dim,
267 message_mlp: vec![message_layer1, message_layer2],
268 distance_encoder,
269 output_weight,
270 bias,
271 }
272 }
273
274 pub fn forward(&self, graph: &GraphData) -> GraphData {
276 let num_nodes = graph.num_nodes;
277 let num_edges = graph.num_edges;
278
279 let edge_distances = if let Some(ref edge_attr) = graph.edge_attr {
281 edge_attr.to_vec().expect("conversion should succeed")
282 } else {
283 vec![1.0; num_edges]
284 };
285
286 let edge_data = graph
288 .edge_index
289 .to_vec()
290 .expect("conversion should succeed");
291 let mut aggregated = vec![0.0; num_nodes * self.hidden_dim];
292
293 let node_features = graph.x.to_vec().expect("conversion should succeed");
294
295 for edge_idx in 0..num_edges {
296 let src = edge_data[edge_idx * 2] as usize;
297 let dst = edge_data[edge_idx * 2 + 1] as usize;
298
299 if src >= num_nodes || dst >= num_nodes {
300 continue;
301 }
302
303 let src_features = &node_features[src * self.in_features..(src + 1) * self.in_features];
305 let dst_features = &node_features[dst * self.in_features..(dst + 1) * self.in_features];
306
307 let dist = edge_distances[edge_idx.min(edge_distances.len() - 1)];
309
310 let mut message_input = Vec::new();
312 message_input.extend_from_slice(src_features);
313 message_input.extend_from_slice(dst_features);
314 message_input.push(dist);
315
316 let message = self.compute_message(&message_input);
318
319 for (i, &val) in message.iter().enumerate() {
321 aggregated[dst * self.hidden_dim + i] += val;
322 }
323 }
324
325 let mut output_features = vec![0.0; num_nodes * self.out_features];
327
328 for node in 0..num_nodes {
329 let agg_features = &aggregated[node * self.hidden_dim..(node + 1) * self.hidden_dim];
330 let output_proj = self
331 .output_weight
332 .clone_data()
333 .to_vec()
334 .expect("conversion should succeed");
335
336 for out_idx in 0..self.out_features {
337 let mut sum = 0.0;
338 for hid_idx in 0..self.hidden_dim {
339 sum +=
340 agg_features[hid_idx] * output_proj[hid_idx * self.out_features + out_idx];
341 }
342
343 if let Some(ref bias) = self.bias {
344 let bias_data = bias
345 .clone_data()
346 .to_vec()
347 .expect("conversion should succeed");
348 if out_idx < bias_data.len() {
349 sum += bias_data[out_idx];
350 }
351 }
352
353 output_features[node * self.out_features + out_idx] = sum;
354 }
355 }
356
357 let output = from_vec(
358 output_features,
359 &[num_nodes, self.out_features],
360 torsh_core::device::DeviceType::Cpu,
361 )
362 .expect("output tensor creation should succeed");
363
364 let mut output_graph = graph.clone();
365 output_graph.x = output;
366 output_graph
367 }
368
369 fn compute_message(&self, input: &[f32]) -> Vec<f32> {
371 let layer1_weights = self.message_mlp[0]
373 .clone_data()
374 .to_vec()
375 .expect("conversion should succeed");
376 let input_dim = self.in_features * 2 + 1;
377 let mut hidden = vec![0.0; self.hidden_dim];
378
379 for h in 0..self.hidden_dim {
380 let mut sum = 0.0;
381 for i in 0..input_dim.min(input.len()) {
382 sum += input[i] * layer1_weights[i * self.hidden_dim + h];
383 }
384 hidden[h] = sum.max(0.0); }
386
387 let layer2_weights = self.message_mlp[1]
389 .clone_data()
390 .to_vec()
391 .expect("conversion should succeed");
392 let mut output = vec![0.0; self.hidden_dim];
393
394 for h in 0..self.hidden_dim {
395 let mut sum = 0.0;
396 for i in 0..self.hidden_dim {
397 sum += hidden[i] * layer2_weights[i * self.hidden_dim + h];
398 }
399 output[h] = sum.max(0.0); }
401
402 output
403 }
404}
405
406impl GraphLayer for GeometricConv {
407 fn forward(&self, graph: &GraphData) -> GraphData {
408 self.forward(graph)
409 }
410
411 fn parameters(&self) -> Vec<Tensor> {
412 let mut params = Vec::new();
413
414 for layer in &self.message_mlp {
415 params.push(layer.clone_data());
416 }
417
418 params.push(self.distance_encoder.clone_data());
419 params.push(self.output_weight.clone_data());
420
421 if let Some(ref bias) = self.bias {
422 params.push(bias.clone_data());
423 }
424
425 params
426 }
427}
428
429pub struct GeometricTransformer;
431
432impl GeometricTransformer {
433 pub fn rotate_3d(points: &mut [Point3D], axis: &Point3D, angle: f32) {
435 let cos_theta = angle.cos();
436 let sin_theta = angle.sin();
437
438 let norm = axis.norm();
440 if norm == 0.0 {
441 return;
442 }
443
444 let ux = axis.x / norm;
445 let uy = axis.y / norm;
446 let uz = axis.z / norm;
447
448 for point in points.iter_mut() {
450 let x = point.x;
451 let y = point.y;
452 let z = point.z;
453
454 let dot = ux * x + uy * y + uz * z;
456
457 let cross_x = uy * z - uz * y;
459 let cross_y = uz * x - ux * z;
460 let cross_z = ux * y - uy * x;
461
462 point.x = x * cos_theta + cross_x * sin_theta + ux * dot * (1.0 - cos_theta);
464 point.y = y * cos_theta + cross_y * sin_theta + uy * dot * (1.0 - cos_theta);
465 point.z = z * cos_theta + cross_z * sin_theta + uz * dot * (1.0 - cos_theta);
466 }
467 }
468
469 pub fn translate_3d(points: &mut [Point3D], offset: &Point3D) {
471 for point in points.iter_mut() {
472 point.x += offset.x;
473 point.y += offset.y;
474 point.z += offset.z;
475 }
476 }
477
478 pub fn scale_3d(points: &mut [Point3D], scale: f32) {
480 for point in points.iter_mut() {
481 point.x *= scale;
482 point.y *= scale;
483 point.z *= scale;
484 }
485 }
486
487 pub fn normalize_to_unit_sphere(points: &mut [Point3D]) {
489 if points.is_empty() {
490 return;
491 }
492
493 let mut center = Point3D::new(0.0, 0.0, 0.0);
495 for point in points.iter() {
496 center.x += point.x;
497 center.y += point.y;
498 center.z += point.z;
499 }
500 center.x /= points.len() as f32;
501 center.y /= points.len() as f32;
502 center.z /= points.len() as f32;
503
504 Self::translate_3d(points, &Point3D::new(-center.x, -center.y, -center.z));
506
507 let max_dist = points
509 .iter()
510 .map(|p| p.norm())
511 .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal))
512 .unwrap_or(1.0);
513
514 if max_dist > 0.0 {
516 Self::scale_3d(points, 1.0 / max_dist);
517 }
518 }
519}
520
521pub struct GeometricPooling;
523
524impl GeometricPooling {
525 pub fn voxel_pool(
527 points: &[Point3D],
528 features: &Tensor,
529 voxel_size: f32,
530 ) -> (Vec<Point3D>, Tensor) {
531 let feature_data = features.to_vec().expect("conversion should succeed");
532 let feature_dim = features.shape().dims()[1];
533
534 let mut voxel_map: HashMap<(i32, i32, i32), Vec<usize>> = HashMap::new();
536
537 for (i, point) in points.iter().enumerate() {
538 let vx = (point.x / voxel_size).floor() as i32;
539 let vy = (point.y / voxel_size).floor() as i32;
540 let vz = (point.z / voxel_size).floor() as i32;
541
542 voxel_map
543 .entry((vx, vy, vz))
544 .or_insert_with(Vec::new)
545 .push(i);
546 }
547
548 let mut pooled_points = Vec::new();
550 let mut pooled_features = Vec::new();
551
552 for (_voxel, indices) in voxel_map {
553 if indices.is_empty() {
554 continue;
555 }
556
557 let mut avg_point = Point3D::new(0.0, 0.0, 0.0);
559 for &idx in &indices {
560 avg_point.x += points[idx].x;
561 avg_point.y += points[idx].y;
562 avg_point.z += points[idx].z;
563 }
564 avg_point.x /= indices.len() as f32;
565 avg_point.y /= indices.len() as f32;
566 avg_point.z /= indices.len() as f32;
567
568 pooled_points.push(avg_point);
569
570 let mut avg_features = vec![0.0; feature_dim];
572 for &idx in &indices {
573 for d in 0..feature_dim {
574 avg_features[d] += feature_data[idx * feature_dim + d];
575 }
576 }
577 for val in &mut avg_features {
578 *val /= indices.len() as f32;
579 }
580
581 pooled_features.extend(avg_features);
582 }
583
584 let pooled_tensor = from_vec(
585 pooled_features,
586 &[pooled_points.len(), feature_dim],
587 torsh_core::device::DeviceType::Cpu,
588 )
589 .expect("pooled tensor creation should succeed");
590
591 (pooled_points, pooled_tensor)
592 }
593
594 pub fn farthest_point_sampling(
596 points: &[Point3D],
597 features: &Tensor,
598 num_samples: usize,
599 ) -> (Vec<Point3D>, Tensor) {
600 let num_points = points.len();
601 let feature_dim = features.shape().dims()[1];
602 let feature_data = features.to_vec().expect("conversion should succeed");
603
604 if num_samples >= num_points {
605 return (points.to_vec(), features.clone());
606 }
607
608 let mut selected = Vec::new();
609 let mut distances = vec![f32::MAX; num_points];
610
611 let mut rng = thread_rng();
613 let first_idx = rng.gen_range(0..num_points);
614 selected.push(first_idx);
615
616 for i in 0..num_points {
618 distances[i] = points[i].distance(&points[first_idx]);
619 }
620
621 for _ in 1..num_samples {
623 let farthest_idx = distances
624 .iter()
625 .enumerate()
626 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
627 .map(|(idx, _)| idx)
628 .unwrap_or(0);
629
630 selected.push(farthest_idx);
631
632 for i in 0..num_points {
634 let dist = points[i].distance(&points[farthest_idx]);
635 distances[i] = distances[i].min(dist);
636 }
637 }
638
639 let sampled_points: Vec<_> = selected.iter().map(|&idx| points[idx]).collect();
641 let sampled_features: Vec<_> = selected
642 .iter()
643 .flat_map(|&idx| {
644 let start = idx * feature_dim;
645 let end = start + feature_dim;
646 &feature_data[start..end]
647 })
648 .copied()
649 .collect();
650
651 let sampled_tensor = from_vec(
652 sampled_features,
653 &[num_samples, feature_dim],
654 torsh_core::device::DeviceType::Cpu,
655 )
656 .expect("sampled tensor creation should succeed");
657
658 (sampled_points, sampled_tensor)
659 }
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_point3d_distance() {
        let p1 = Point3D::new(0.0, 0.0, 0.0);
        let p2 = Point3D::new(3.0, 4.0, 0.0);

        assert!((p1.distance(&p2) - 5.0).abs() < 1e-5);
        assert!((p2.norm() - 5.0).abs() < 1e-5);
        assert!((p2.dot(&Point3D::new(1.0, 1.0, 1.0)) - 7.0).abs() < 1e-5);
    }

    #[test]
    fn test_knn_graph() {
        // Unit square: every corner has exactly two neighbours at distance 1.
        let points = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 1.0, 0.0),
            Point3D::new(1.0, 1.0, 0.0),
        ];

        let graph = GeometricGraphBuilder::knn_graph(&points, 2, None);

        assert_eq!(graph.num_nodes, 4);
        assert_eq!(graph.x.shape().dims(), &[4, 3]);
        assert_eq!(graph.edge_index.shape().dims(), &[2, 8]);

        let edge_attr = graph
            .edge_attr
            .as_ref()
            .expect("knn graph should carry edge attributes");
        assert_eq!(edge_attr.shape().dims(), &[8, 1]);
        let weights = edge_attr.to_vec().unwrap();
        assert!(weights.iter().all(|&w| (w - 1.0).abs() < 1e-6));
    }

    #[test]
    fn test_radius_graph() {
        let points = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(0.5, 0.0, 0.0),
            Point3D::new(2.0, 0.0, 0.0),
        ];

        let graph = GeometricGraphBuilder::radius_graph(&points, 1.0, None);

        assert_eq!(graph.num_nodes, 3);
        assert!(graph.num_edges >= 2);
        // Only the (0, 1) pair is within the radius; it is stored in both directions.
        assert_eq!(graph.edge_index.shape().dims(), &[2, 2]);
        let weights = graph
            .edge_attr
            .as_ref()
            .expect("radius graph with edges should carry edge attributes")
            .to_vec()
            .unwrap();
        assert_eq!(weights.len(), 2);
        assert!(weights.iter().all(|&w| (w - 0.5).abs() < 1e-6));
    }

    #[test]
    fn test_geometric_conv() {
        let points = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 1.0, 0.0),
        ];

        let graph = GeometricGraphBuilder::knn_graph(&points, 2, None);
        let conv = GeometricConv::new(3, 6, 8, true);

        let output = conv.forward(&graph);

        assert_eq!(output.num_nodes, 3);
        assert_eq!(output.x.shape().dims(), &[3, 6]);
        assert_eq!(output.edge_index.shape().dims(), &[2, 6]);
        assert!(output.x.to_vec().unwrap().iter().all(|v| v.is_finite()));

        // Two message layers + distance encoder + output projection + bias.
        assert_eq!(GraphLayer::parameters(&conv).len(), 5);
    }

    #[test]
    fn test_geometric_rotation() {
        let mut points = vec![Point3D::new(1.0, 0.0, 0.0)];

        let axis = Point3D::new(0.0, 0.0, 1.0);
        let angle = std::f32::consts::PI / 2.0;

        GeometricTransformer::rotate_3d(&mut points, &axis, angle);

        // A quarter turn about +z maps +x onto +y.
        assert!((points[0].x - 0.0).abs() < 1e-5);
        assert!((points[0].y - 1.0).abs() < 1e-5);
        assert!(points[0].z.abs() < 1e-5);
    }

    #[test]
    fn test_normalize_to_unit_sphere() {
        let mut points = vec![
            Point3D::new(2.0, 0.0, 0.0),
            Point3D::new(0.0, 2.0, 0.0),
            Point3D::new(0.0, 0.0, 2.0),
        ];

        GeometricTransformer::normalize_to_unit_sphere(&mut points);

        // The three points are symmetric about their centroid, so after normalisation
        // every one of them lies on the unit sphere.
        for point in &points {
            assert!((point.norm() - 1.0).abs() < 1e-5);
        }
        let (sx, sy, sz) = points
            .iter()
            .fold((0.0f32, 0.0f32, 0.0f32), |a, p| (a.0 + p.x, a.1 + p.y, a.2 + p.z));
        assert!(sx.abs() < 1e-5 && sy.abs() < 1e-5 && sz.abs() < 1e-5);
    }

    #[test]
    fn test_voxel_pooling() {
        let points = vec![
            Point3D::new(0.1, 0.1, 0.1),
            Point3D::new(0.2, 0.2, 0.2),
            Point3D::new(1.1, 1.1, 1.1),
        ];

        let features: Tensor = randn(&[3, 4]).unwrap();

        let (pooled_points, pooled_features) =
            GeometricPooling::voxel_pool(&points, &features, 1.0);

        // The first two points share voxel (0, 0, 0); the third sits alone in (1, 1, 1).
        assert_eq!(pooled_points.len(), 2);
        assert_eq!(pooled_features.shape().dims(), &[2, 4]);

        // Output order is not asserted here; each pooled row is matched by its centroid.
        let input: Vec<f32> = features.to_vec().unwrap();
        let pooled: Vec<f32> = pooled_features.to_vec().unwrap();
        for (slot, centroid) in pooled_points.iter().enumerate() {
            for d in 0..4 {
                let expected = if centroid.x < 1.0 {
                    (input[d] + input[4 + d]) / 2.0
                } else {
                    input[8 + d]
                };
                assert!((pooled[slot * 4 + d] - expected).abs() < 1e-5);
            }
        }
    }

    #[test]
    fn test_farthest_point_sampling() {
        let points = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 1.0, 0.0),
            Point3D::new(0.0, 0.0, 1.0),
        ];

        let features = randn(&[4, 3]).unwrap();

        let (sampled_points, sampled_features) =
            GeometricPooling::farthest_point_sampling(&points, &features, 2);

        assert_eq!(sampled_points.len(), 2);
        assert_eq!(sampled_features.shape().dims(), &[2, 3]);
        // The second pick is the farthest point from the first, so the two must differ.
        assert!(sampled_points[0].distance(&sampled_points[1]) > 0.0);
    }
}