1#![allow(dead_code)]
17use crate::parameter::Parameter;
18use crate::{GraphData, GraphLayer};
19use torsh_tensor::{
20 creation::{from_vec, randn, zeros},
21 Tensor,
22};
23
24#[derive(Debug)]
27pub struct GraphFNO {
28 in_features: usize,
29 out_features: usize,
30 hidden_features: usize,
31 num_modes: usize,
32 num_layers: usize,
33
34 fourier_weights: Vec<Parameter>,
36 conv_weights: Vec<Parameter>,
37
38 input_projection: Parameter,
40 output_projection: Parameter,
41
42 bias: Option<Parameter>,
44}
45
46impl GraphFNO {
47 pub fn new(
49 in_features: usize,
50 out_features: usize,
51 hidden_features: usize,
52 num_modes: usize,
53 num_layers: usize,
54 bias: bool,
55 ) -> Self {
56 let mut fourier_weights = Vec::new();
57 let mut conv_weights = Vec::new();
58
59 for _ in 0..num_layers {
61 fourier_weights.push(Parameter::new(
62 randn(&[hidden_features, hidden_features, num_modes])
63 .expect("failed to create fourier_weights tensor"),
64 ));
65 conv_weights.push(Parameter::new(
66 randn(&[hidden_features, hidden_features])
67 .expect("failed to create conv_weights tensor"),
68 ));
69 }
70
71 let input_projection = Parameter::new(
72 randn(&[in_features, hidden_features])
73 .expect("failed to create input_projection tensor"),
74 );
75 let output_projection = Parameter::new(
76 randn(&[hidden_features, out_features])
77 .expect("failed to create output_projection tensor"),
78 );
79
80 let bias = if bias {
81 Some(Parameter::new(
82 zeros::<f32>(&[out_features]).expect("failed to create bias tensor"),
83 ))
84 } else {
85 None
86 };
87
88 Self {
89 in_features,
90 out_features,
91 hidden_features,
92 num_modes,
93 num_layers,
94 fourier_weights,
95 conv_weights,
96 input_projection,
97 output_projection,
98 bias,
99 }
100 }
101
102 pub fn forward(&self, graph: &GraphData) -> GraphData {
104 let _num_nodes = graph.num_nodes;
105
106 let mut x = graph
108 .x
109 .matmul(&self.input_projection.clone_data())
110 .expect("operation should succeed");
111
112 for layer in 0..self.num_layers {
114 x = self.fourier_layer(&x, layer, graph);
115 }
116
117 let mut output = x
119 .matmul(&self.output_projection.clone_data())
120 .expect("operation should succeed");
121
122 if let Some(ref bias) = self.bias {
124 output = output
125 .add(&bias.clone_data())
126 .expect("operation should succeed");
127 }
128
129 let mut output_graph = graph.clone();
131 output_graph.x = output;
132 output_graph
133 }
134
135 fn fourier_layer(&self, x: &Tensor, layer: usize, graph: &GraphData) -> Tensor {
137 let fourier_x = self.graph_fourier_transform(x, graph);
139
140 let fourier_weights = &self.fourier_weights[layer];
142 let spectral_conv = self.spectral_convolution(&fourier_x, fourier_weights);
143
144 let spatial_features = self.inverse_graph_fourier_transform(&spectral_conv, graph);
146
147 let conv_weights = &self.conv_weights[layer];
149 let conv_output = spatial_features
150 .matmul(&conv_weights.clone_data())
151 .expect("operation should succeed");
152
153 let residual = x.add(&conv_output).expect("operation should succeed");
155
156 self.relu(&residual)
158 }
159
160 fn graph_fourier_transform(&self, x: &Tensor, graph: &GraphData) -> Tensor {
162 let num_nodes = graph.num_nodes;
165
166 let mut transform_data = Vec::new();
168 for i in 0..num_nodes {
169 for j in 0..self.num_modes {
170 let freq = (j as f32 + 1.0) * std::f32::consts::PI / num_nodes as f32;
171 let basis = (freq * i as f32).cos();
172 transform_data.push(basis);
173 }
174 }
175
176 let transform_matrix = from_vec(
177 transform_data,
178 &[num_nodes, self.num_modes],
179 torsh_core::device::DeviceType::Cpu,
180 )
181 .expect("GFT transform matrix creation should succeed");
182
183 transform_matrix
185 .t()
186 .expect("operation should succeed")
187 .matmul(x)
188 .expect("operation should succeed")
189 }
190
191 fn inverse_graph_fourier_transform(&self, fourier_x: &Tensor, graph: &GraphData) -> Tensor {
193 let num_nodes = graph.num_nodes;
194
195 let mut inv_transform_data = Vec::new();
197 for i in 0..num_nodes {
198 for j in 0..self.num_modes {
199 let freq = (j as f32 + 1.0) * std::f32::consts::PI / num_nodes as f32;
200 let basis = (freq * i as f32).cos();
201 inv_transform_data.push(basis);
202 }
203 }
204
205 let inv_transform_matrix = from_vec(
206 inv_transform_data,
207 &[num_nodes, self.num_modes],
208 torsh_core::device::DeviceType::Cpu,
209 )
210 .expect("inverse GFT transform matrix creation should succeed");
211
212 inv_transform_matrix
214 .matmul(fourier_x)
215 .expect("operation should succeed")
216 }
217
218 fn spectral_convolution(&self, fourier_x: &Tensor, weights: &Parameter) -> Tensor {
220 let weight_data = weights.clone_data();
222
223 let weight_2d = weight_data
226 .slice_tensor(2, 0, 1)
227 .expect("spectral weight slice should succeed")
228 .squeeze_tensor(2)
229 .expect("spectral weight squeeze should succeed");
230
231 fourier_x
232 .matmul(&weight_2d)
233 .expect("operation should succeed")
234 }
235
236 fn relu(&self, x: &Tensor) -> Tensor {
238 let data = x.to_vec().expect("conversion should succeed");
240 let activated_data: Vec<f32> = data.iter().map(|&val| val.max(0.0)).collect();
241
242 from_vec(
243 activated_data,
244 x.shape().dims(),
245 torsh_core::device::DeviceType::Cpu,
246 )
247 .expect("GraphFNO relu tensor creation should succeed")
248 }
249}
250
251impl GraphLayer for GraphFNO {
252 fn forward(&self, graph: &GraphData) -> GraphData {
253 self.forward(graph)
254 }
255
256 fn parameters(&self) -> Vec<Tensor> {
257 let mut params = vec![
258 self.input_projection.clone_data(),
259 self.output_projection.clone_data(),
260 ];
261
262 for weight in &self.fourier_weights {
263 params.push(weight.clone_data());
264 }
265
266 for weight in &self.conv_weights {
267 params.push(weight.clone_data());
268 }
269
270 if let Some(ref bias) = self.bias {
271 params.push(bias.clone_data());
272 }
273
274 params
275 }
276}
277
278#[derive(Debug)]
280pub struct GraphDeepONet {
281 trunk_net_features: usize,
282 branch_net_features: usize,
283 hidden_features: usize,
284 output_features: usize,
285 num_sensors: usize,
286
287 branch_layers: Vec<Parameter>,
289
290 trunk_layers: Vec<Parameter>,
292
293 bias: Option<Parameter>,
295}
296
297impl GraphDeepONet {
298 pub fn new(
300 trunk_net_features: usize,
301 branch_net_features: usize,
302 hidden_features: usize,
303 output_features: usize,
304 num_sensors: usize,
305 num_layers: usize,
306 bias: bool,
307 ) -> Self {
308 let mut branch_layers = Vec::new();
309 let mut trunk_layers = Vec::new();
310
311 for i in 0..num_layers {
313 let in_dim = if i == 0 { num_sensors } else { hidden_features };
314 let out_dim = if i == num_layers - 1 {
315 output_features
316 } else {
317 hidden_features
318 };
319 branch_layers.push(Parameter::new(
320 randn(&[in_dim, out_dim]).expect("failed to create branch layer tensor"),
321 ));
322 }
323
324 for i in 0..num_layers {
326 let in_dim = if i == 0 {
327 trunk_net_features
328 } else {
329 hidden_features
330 };
331 let out_dim = if i == num_layers - 1 {
332 output_features
333 } else {
334 hidden_features
335 };
336 trunk_layers.push(Parameter::new(
337 randn(&[in_dim, out_dim]).expect("failed to create trunk layer tensor"),
338 ));
339 }
340
341 let bias = if bias {
342 Some(Parameter::new(
343 zeros::<f32>(&[output_features]).expect("failed to create DeepONet bias tensor"),
344 ))
345 } else {
346 None
347 };
348
349 Self {
350 trunk_net_features,
351 branch_net_features,
352 hidden_features,
353 output_features,
354 num_sensors,
355 branch_layers,
356 trunk_layers,
357 bias,
358 }
359 }
360
361 pub fn forward(
363 &self,
364 graph: &GraphData,
365 sensor_data: &Tensor,
366 locations: &Tensor,
367 ) -> GraphData {
368 let branch_output = self.forward_branch_net(sensor_data);
370
371 let trunk_output = self.forward_trunk_net(locations);
373
374 let combined = self.combine_outputs(&branch_output, &trunk_output);
376
377 let mut output = combined;
379 if let Some(ref bias) = self.bias {
380 output = output
381 .add(&bias.clone_data())
382 .expect("operation should succeed");
383 }
384
385 let mut output_graph = graph.clone();
387 output_graph.x = output;
388 output_graph
389 }
390
391 fn forward_branch_net(&self, sensor_data: &Tensor) -> Tensor {
393 let mut x = sensor_data.clone();
394
395 for (i, layer) in self.branch_layers.iter().enumerate() {
396 x = x
397 .matmul(&layer.clone_data())
398 .expect("operation should succeed");
399
400 if i < self.branch_layers.len() - 1 {
402 x = self.tanh(&x);
403 }
404 }
405
406 x
407 }
408
409 fn forward_trunk_net(&self, locations: &Tensor) -> Tensor {
411 let mut x = locations.clone();
412
413 for (i, layer) in self.trunk_layers.iter().enumerate() {
414 x = x
415 .matmul(&layer.clone_data())
416 .expect("operation should succeed");
417
418 if i < self.trunk_layers.len() - 1 {
420 x = self.tanh(&x);
421 }
422 }
423
424 x
425 }
426
427 fn combine_outputs(&self, branch_output: &Tensor, trunk_output: &Tensor) -> Tensor {
429 branch_output
431 .mul(trunk_output)
432 .expect("operation should succeed")
433 }
434
435 fn tanh(&self, x: &Tensor) -> Tensor {
437 let data = x.to_vec().expect("conversion should succeed");
438 let activated_data: Vec<f32> = data.iter().map(|&val| val.tanh()).collect();
439
440 from_vec(
441 activated_data,
442 x.shape().dims(),
443 torsh_core::device::DeviceType::Cpu,
444 )
445 .expect("DeepONet tanh tensor creation should succeed")
446 }
447}
448
449impl GraphLayer for GraphDeepONet {
450 fn forward(&self, graph: &GraphData) -> GraphData {
451 let sensor_data = graph
453 .x
454 .slice_tensor(1, 0, self.num_sensors.min(graph.x.shape().dims()[1]))
455 .expect("sensor data slice should succeed");
456 let locations = graph.x.clone();
457
458 self.forward(graph, &sensor_data, &locations)
459 }
460
461 fn parameters(&self) -> Vec<Tensor> {
462 let mut params = Vec::new();
463
464 for layer in &self.branch_layers {
465 params.push(layer.clone_data());
466 }
467
468 for layer in &self.trunk_layers {
469 params.push(layer.clone_data());
470 }
471
472 if let Some(ref bias) = self.bias {
473 params.push(bias.clone_data());
474 }
475
476 params
477 }
478}
479
480#[derive(Debug)]
482pub struct PhysicsInformedGNN {
483 in_features: usize,
484 out_features: usize,
485 hidden_features: usize,
486
487 layers: Vec<Parameter>,
489
490 diffusion_coefficient: f32,
492 reaction_rate: f32,
493
494 bias: Option<Parameter>,
496}
497
498impl PhysicsInformedGNN {
499 pub fn new(
501 in_features: usize,
502 out_features: usize,
503 hidden_features: usize,
504 num_layers: usize,
505 diffusion_coefficient: f32,
506 reaction_rate: f32,
507 bias: bool,
508 ) -> Self {
509 let mut layers = Vec::new();
510
511 for i in 0..num_layers {
512 let in_dim = if i == 0 { in_features } else { hidden_features };
513 let out_dim = if i == num_layers - 1 {
514 out_features
515 } else {
516 hidden_features
517 };
518 layers.push(Parameter::new(
519 randn(&[in_dim, out_dim]).expect("failed to create PIGNN layer tensor"),
520 ));
521 }
522
523 let bias = if bias {
524 Some(Parameter::new(
525 zeros::<f32>(&[out_features]).expect("failed to create PIGNN bias tensor"),
526 ))
527 } else {
528 None
529 };
530
531 Self {
532 in_features,
533 out_features,
534 hidden_features,
535 layers,
536 diffusion_coefficient,
537 reaction_rate,
538 bias,
539 }
540 }
541
542 pub fn forward(&self, graph: &GraphData) -> GraphData {
544 let mut x = graph.x.clone();
546
547 for (i, layer) in self.layers.iter().enumerate() {
548 x = x
549 .matmul(&layer.clone_data())
550 .expect("operation should succeed");
551
552 if i < self.layers.len() - 1 {
554 x = self.swish(&x);
555 }
556 }
557
558 let physics_constrained = self.apply_physics_constraints(&x, graph);
560
561 let mut output = physics_constrained;
563 if let Some(ref bias) = self.bias {
564 output = output
565 .add(&bias.clone_data())
566 .expect("operation should succeed");
567 }
568
569 let mut output_graph = graph.clone();
571 output_graph.x = output;
572 output_graph
573 }
574
575 fn apply_physics_constraints(&self, prediction: &Tensor, graph: &GraphData) -> Tensor {
577 let laplacian = self.compute_graph_laplacian(graph);
579
580 let diffusion_term = laplacian
582 .matmul(prediction)
583 .expect("operation should succeed")
584 .mul_scalar(self.diffusion_coefficient)
585 .expect("operation should succeed");
586
587 let reaction_term = prediction
589 .mul_scalar(self.reaction_rate)
590 .expect("operation should succeed");
591
592 prediction
594 .add(&diffusion_term)
595 .expect("operation should succeed")
596 .add(&reaction_term)
597 .expect("operation should succeed")
598 }
599
600 fn compute_graph_laplacian(&self, graph: &GraphData) -> Tensor {
602 let num_nodes = graph.num_nodes;
603 let _num_edges = graph.num_edges;
604
605 let mut adj_data = vec![0.0f32; num_nodes * num_nodes];
607
608 let edge_data = graph
610 .edge_index
611 .to_vec()
612 .expect("conversion should succeed");
613 for i in (0..edge_data.len()).step_by(2) {
614 if i + 1 < edge_data.len() {
615 let src = edge_data[i] as usize;
616 let dst = edge_data[i + 1] as usize;
617
618 if src < num_nodes && dst < num_nodes {
619 adj_data[src * num_nodes + dst] = 1.0;
620 adj_data[dst * num_nodes + src] = 1.0; }
622 }
623 }
624
625 let mut degree_data = vec![0.0f32; num_nodes * num_nodes];
627 for i in 0..num_nodes {
628 let mut degree = 0.0;
629 for j in 0..num_nodes {
630 degree += adj_data[i * num_nodes + j];
631 }
632 degree_data[i * num_nodes + i] = degree;
633 }
634
635 let mut laplacian_data = Vec::new();
637 for i in 0..num_nodes * num_nodes {
638 laplacian_data.push(degree_data[i] - adj_data[i]);
639 }
640
641 from_vec(
642 laplacian_data,
643 &[num_nodes, num_nodes],
644 torsh_core::device::DeviceType::Cpu,
645 )
646 .expect("graph Laplacian tensor creation should succeed")
647 }
648
649 fn swish(&self, x: &Tensor) -> Tensor {
651 let data = x.to_vec().expect("conversion should succeed");
652 let activated_data: Vec<f32> = data
653 .iter()
654 .map(|&val| val * (1.0 / (1.0 + (-val).exp())))
655 .collect();
656
657 from_vec(
658 activated_data,
659 x.shape().dims(),
660 torsh_core::device::DeviceType::Cpu,
661 )
662 .expect("PIGNN swish tensor creation should succeed")
663 }
664}
665
666impl GraphLayer for PhysicsInformedGNN {
667 fn forward(&self, graph: &GraphData) -> GraphData {
668 self.forward(graph)
669 }
670
671 fn parameters(&self) -> Vec<Tensor> {
672 let mut params = Vec::new();
673
674 for layer in &self.layers {
675 params.push(layer.clone_data());
676 }
677
678 if let Some(ref bias) = self.bias {
679 params.push(bias.clone_data());
680 }
681
682 params
683 }
684}
685
686#[derive(Debug)]
688pub struct MultiScaleGNO {
689 in_features: usize,
690 out_features: usize,
691 num_scales: usize,
692 hidden_features: usize,
693
694 scale_operators: Vec<Parameter>,
696
697 fusion_weights: Parameter,
699
700 output_projection: Parameter,
702
703 bias: Option<Parameter>,
704}
705
706impl MultiScaleGNO {
707 pub fn new(
709 in_features: usize,
710 out_features: usize,
711 num_scales: usize,
712 hidden_features: usize,
713 bias: bool,
714 ) -> Self {
715 let mut scale_operators = Vec::new();
716
717 for _ in 0..num_scales {
719 scale_operators.push(Parameter::new(
720 randn(&[in_features, hidden_features])
721 .expect("failed to create scale operator tensor"),
722 ));
723 }
724
725 let fusion_weights = Parameter::new(
726 randn(&[num_scales * hidden_features, hidden_features])
727 .expect("failed to create fusion_weights tensor"),
728 );
729
730 let output_projection = Parameter::new(
731 randn(&[hidden_features, out_features])
732 .expect("failed to create MultiScaleGNO output_projection tensor"),
733 );
734
735 let bias = if bias {
736 Some(Parameter::new(
737 zeros::<f32>(&[out_features]).expect("failed to create MultiScaleGNO bias tensor"),
738 ))
739 } else {
740 None
741 };
742
743 Self {
744 in_features,
745 out_features,
746 num_scales,
747 hidden_features,
748 scale_operators,
749 fusion_weights,
750 output_projection,
751 bias,
752 }
753 }
754
755 pub fn forward(&self, graph: &GraphData) -> GraphData {
757 let mut scale_features = Vec::new();
758
759 for scale in 0..self.num_scales {
761 let scale_graph = self.coarsen_graph(graph, scale);
762 let features = self.process_scale(&scale_graph, scale);
763 let upsampled = self.upsample_features(&features, graph.num_nodes);
764 scale_features.push(upsampled);
765 }
766
767 let fused_features = self.fuse_scales(&scale_features);
769
770 let mut output = fused_features
772 .matmul(&self.output_projection.clone_data())
773 .expect("operation should succeed");
774
775 if let Some(ref bias) = self.bias {
777 output = output
778 .add(&bias.clone_data())
779 .expect("operation should succeed");
780 }
781
782 let mut output_graph = graph.clone();
784 output_graph.x = output;
785 output_graph
786 }
787
788 fn coarsen_graph(&self, graph: &GraphData, scale: usize) -> GraphData {
790 let coarsening_factor = 2_usize.pow(scale as u32);
791 let coarse_nodes = (graph.num_nodes + coarsening_factor - 1) / coarsening_factor;
792
793 let mut coarse_features = Vec::new();
795
796 for coarse_id in 0..coarse_nodes {
797 let start_node = coarse_id * coarsening_factor;
798 let end_node = ((coarse_id + 1) * coarsening_factor).min(graph.num_nodes);
799
800 let mut sum_features = vec![0.0f32; graph.x.shape().dims()[1]];
802 let mut count = 0;
803
804 for node_id in start_node..end_node {
805 let features = graph
806 .x
807 .slice_tensor(0, node_id, node_id + 1)
808 .expect("node feature slice should succeed");
809 let feature_data = features.to_vec().expect("conversion should succeed");
810
811 for (i, &val) in feature_data.iter().enumerate() {
812 if i < sum_features.len() {
813 sum_features[i] += val;
814 }
815 }
816 count += 1;
817 }
818
819 if count > 0 {
821 for val in &mut sum_features {
822 *val /= count as f32;
823 }
824 }
825
826 coarse_features.extend(sum_features);
827 }
828
829 let coarse_x = from_vec(
830 coarse_features,
831 &[coarse_nodes, graph.x.shape().dims()[1]],
832 torsh_core::device::DeviceType::Cpu,
833 )
834 .expect("coarse node features tensor creation should succeed");
835
836 let mut coarse_edges = Vec::new();
838 for i in 0..coarse_nodes.saturating_sub(1) {
839 coarse_edges.push(i as f32);
840 coarse_edges.push((i + 1) as f32);
841 }
842
843 let coarse_edge_index = from_vec(
844 coarse_edges,
845 &[2, coarse_nodes.saturating_sub(1)],
846 torsh_core::device::DeviceType::Cpu,
847 )
848 .expect("coarse edge index tensor creation should succeed");
849
850 GraphData::new(coarse_x, coarse_edge_index)
851 }
852
853 fn process_scale(&self, graph: &GraphData, scale: usize) -> Tensor {
855 let operator = &self.scale_operators[scale];
856 graph
857 .x
858 .matmul(&operator.clone_data())
859 .expect("operation should succeed")
860 }
861
862 fn upsample_features(&self, features: &Tensor, target_nodes: usize) -> Tensor {
864 let current_nodes = features.shape().dims()[0];
865 let feature_dim = features.shape().dims()[1];
866
867 if current_nodes >= target_nodes {
868 return features
870 .slice_tensor(0, 0, target_nodes)
871 .expect("feature truncation should succeed");
872 }
873
874 let feature_data = features.to_vec().expect("conversion should succeed");
876 let mut upsampled_data = Vec::new();
877
878 for target_id in 0..target_nodes {
879 let source_id = (target_id * current_nodes) / target_nodes;
880 let start_idx = source_id * feature_dim;
881 let end_idx = start_idx + feature_dim;
882
883 if end_idx <= feature_data.len() {
884 upsampled_data.extend(&feature_data[start_idx..end_idx]);
885 } else {
886 upsampled_data.extend(vec![0.0f32; feature_dim]);
888 }
889 }
890
891 from_vec(
892 upsampled_data,
893 &[target_nodes, feature_dim],
894 torsh_core::device::DeviceType::Cpu,
895 )
896 .expect("upsampled features tensor creation should succeed")
897 }
898
899 fn fuse_scales(&self, scale_features: &[Tensor]) -> Tensor {
901 let mut concatenated_data = Vec::new();
903 let num_nodes = scale_features[0].shape().dims()[0];
904
905 for node_id in 0..num_nodes {
906 for scale_feature in scale_features {
907 let node_features = scale_feature
908 .slice_tensor(0, node_id, node_id + 1)
909 .expect("scale feature slice should succeed");
910 let feature_data = node_features.to_vec().expect("conversion should succeed");
911 concatenated_data.extend(feature_data);
912 }
913 }
914
915 let concatenated = from_vec(
916 concatenated_data,
917 &[num_nodes, self.num_scales * self.hidden_features],
918 torsh_core::device::DeviceType::Cpu,
919 )
920 .expect("concatenated scale features tensor creation should succeed");
921
922 concatenated
924 .matmul(&self.fusion_weights.clone_data())
925 .expect("operation should succeed")
926 }
927}
928
929impl GraphLayer for MultiScaleGNO {
930 fn forward(&self, graph: &GraphData) -> GraphData {
931 self.forward(graph)
932 }
933
934 fn parameters(&self) -> Vec<Tensor> {
935 let mut params = vec![
936 self.fusion_weights.clone_data(),
937 self.output_projection.clone_data(),
938 ];
939
940 for operator in &self.scale_operators {
941 params.push(operator.clone_data());
942 }
943
944 if let Some(ref bias) = self.bias {
945 params.push(bias.clone_data());
946 }
947
948 params
949 }
950}
951
952pub mod utils {
954 use super::*;
955
956 pub fn compute_spectral_features(graph: &GraphData, num_eigenvalues: usize) -> Tensor {
958 let num_nodes = graph.num_nodes;
960 let mut spectral_data = Vec::new();
961
962 for i in 0..num_nodes {
963 for j in 0..num_eigenvalues {
964 let eigenvalue = (j as f32 + 1.0) / num_eigenvalues as f32;
965 let eigenvector_val = (std::f32::consts::PI * (i as f32 + 1.0) * (j as f32 + 1.0)
966 / num_nodes as f32)
967 .sin();
968 spectral_data.push(eigenvalue * eigenvector_val);
969 }
970 }
971
972 from_vec(
973 spectral_data,
974 &[num_nodes, num_eigenvalues],
975 torsh_core::device::DeviceType::Cpu,
976 )
977 .expect("spectral features tensor creation should succeed")
978 }
979
980 pub fn generate_operator_data(
982 num_graphs: usize,
983 num_nodes: usize,
984 feature_dim: usize,
985 ) -> Vec<(GraphData, GraphData)> {
986 let mut rng = scirs2_core::random::thread_rng();
987 let mut data_pairs = Vec::new();
988
989 for _ in 0..num_graphs {
990 let input_features = randn(&[num_nodes, feature_dim])
992 .expect("input features tensor creation should succeed");
993 let mut edge_data = Vec::new();
994
995 for _ in 0..(num_nodes * 2) {
997 let src = rng.gen_range(0..num_nodes) as f32;
998 let dst = rng.gen_range(0..num_nodes) as f32;
999 edge_data.push(src);
1000 edge_data.push(dst);
1001 }
1002
1003 let edge_index = from_vec(
1004 edge_data,
1005 &[2, num_nodes * 2],
1006 torsh_core::device::DeviceType::Cpu,
1007 )
1008 .expect("edge index tensor creation should succeed");
1009
1010 let input_graph = GraphData::new(input_features, edge_index);
1011
1012 let output_features = input_graph
1014 .x
1015 .mul_scalar(2.0)
1016 .expect("output transformation should succeed");
1017 let output_graph = GraphData::new(output_features, input_graph.edge_index.clone());
1018
1019 data_pairs.push((input_graph, output_graph));
1020 }
1021
1022 data_pairs
1023 }
1024
1025 pub fn compute_operator_error(predicted: &GraphData, target: &GraphData) -> f32 {
1027 let pred_data = predicted.x.to_vec().expect("conversion should succeed");
1028 let target_data = target.x.to_vec().expect("conversion should succeed");
1029
1030 let mut mse = 0.0;
1031 let mut count = 0;
1032
1033 for (pred, target) in pred_data.iter().zip(target_data.iter()) {
1034 mse += (pred - target).powi(2);
1035 count += 1;
1036 }
1037
1038 if count > 0 {
1039 mse / count as f32
1040 } else {
1041 0.0
1042 }
1043 }
1044}
1045
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    /// Wraps `features` and a flat edge list into a `GraphData`; the edge list
    /// is stored as a `[2, edges.len() / 2]` tensor.
    fn graph_with_edges(features: Tensor, edges: Vec<f32>) -> GraphData {
        let num_edges = edges.len() / 2;
        let edge_index = from_vec(edges, &[2, num_edges], DeviceType::Cpu).unwrap();
        GraphData::new(features, edge_index)
    }

    #[test]
    fn test_graph_fno_creation() {
        let model = GraphFNO::new(4, 8, 16, 10, 3, true);

        assert_eq!(model.in_features, 4);
        assert_eq!(model.out_features, 8);
        assert_eq!(model.hidden_features, 16);
        assert_eq!(model.num_modes, 10);
        assert_eq!(model.num_layers, 3);
    }

    #[test]
    fn test_graph_fno_forward() {
        let graph = graph_with_edges(
            randn(&[5, 4]).unwrap(),
            vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0],
        );

        let model = GraphFNO::new(4, 8, 16, 10, 3, true);
        let result = model.forward(&graph);

        assert_eq!(result.x.shape().dims(), &[5, 8]);
    }

    #[test]
    fn test_graph_deeponet_creation() {
        let model = GraphDeepONet::new(3, 4, 16, 8, 10, 3, true);

        assert_eq!(model.trunk_net_features, 3);
        assert_eq!(model.branch_net_features, 4);
        assert_eq!(model.output_features, 8);
        assert_eq!(model.num_sensors, 10);
    }

    #[test]
    fn test_physics_informed_gnn() {
        let graph = graph_with_edges(
            randn(&[4, 3]).unwrap(),
            vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0],
        );

        let model = PhysicsInformedGNN::new(3, 6, 12, 2, 0.1, 0.05, true);
        let result = model.forward(&graph);

        assert_eq!(result.x.shape().dims(), &[4, 6]);
    }

    #[test]
    fn test_multi_scale_gno() {
        let graph = graph_with_edges(
            randn(&[8, 4]).unwrap(),
            vec![
                0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0,
            ],
        );

        let model = MultiScaleGNO::new(4, 6, 3, 8, true);
        let result = model.forward(&graph);

        assert_eq!(result.x.shape().dims(), &[8, 6]);
    }

    #[test]
    fn test_spectral_features() {
        let graph = graph_with_edges(
            randn(&[6, 3]).unwrap(),
            vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0],
        );

        let spectral = utils::compute_spectral_features(&graph, 4);

        assert_eq!(spectral.shape().dims(), &[6, 4]);
    }

    #[test]
    fn test_operator_data_generation() {
        let pairs = utils::generate_operator_data(3, 5, 4);
        assert_eq!(pairs.len(), 3);

        for (input, output) in pairs.iter() {
            assert_eq!(input.num_nodes, 5);
            assert_eq!(output.num_nodes, 5);
            assert_eq!(input.x.shape().dims()[1], 4);
            assert_eq!(output.x.shape().dims()[1], 4);
        }
    }

    #[test]
    fn test_operator_error_computation() {
        let predicted = graph_with_edges(
            from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], DeviceType::Cpu).unwrap(),
            vec![0.0, 1.0],
        );
        let target = graph_with_edges(
            from_vec(vec![1.1, 2.1, 3.1, 4.1], &[2, 2], DeviceType::Cpu).unwrap(),
            vec![0.0, 1.0],
        );

        let error = utils::compute_operator_error(&predicted, &target);

        // The features differ slightly, so the error is positive but small.
        assert!(error > 0.0);
        assert!(error < 1.0);
    }
}