// torsh_graph/neural_operators.rs
1//! Graph Neural Operators
2//!
3//! Advanced implementation of graph neural operators for learning continuous
4//! functions on graphs. Inspired by Neural Operator Theory and Physics-Informed
5//! Neural Networks (PINNs) for graph-structured data.
6//!
7//! # Features:
8//! - Graph Fourier Neural Operators (GraphFNO)
9//! - Graph DeepONet for operator learning
10//! - Physics-informed graph neural networks
11//! - Multi-scale graph operators
12//! - Spectral graph convolutions with learnable kernels
13//! - Graph wavelet neural operators
14
15// Framework infrastructure - components designed for future use
16#![allow(dead_code)]
17use crate::parameter::Parameter;
18use crate::{GraphData, GraphLayer};
19use torsh_tensor::{
20    creation::{from_vec, randn, zeros},
21    Tensor,
22};
23
/// Graph Fourier Neural Operator (GraphFNO)
/// Learns operators in the spectral domain of graphs
///
/// Architecture: input projection -> `num_layers` Fourier layers (spectral
/// convolution + spatial convolution + residual + ReLU) -> output
/// projection, plus an optional additive bias.
#[derive(Debug)]
pub struct GraphFNO {
    /// Dimensionality of input node features
    in_features: usize,
    /// Dimensionality of output node features
    out_features: usize,
    /// Width of the hidden representation used by every Fourier layer
    hidden_features: usize,
    /// Number of spectral modes kept in the (simplified) graph Fourier basis
    num_modes: usize,
    /// Number of stacked Fourier layers
    num_layers: usize,

    // Fourier layers
    /// Per-layer spectral weights, each of shape [hidden, hidden, num_modes]
    fourier_weights: Vec<Parameter>,
    /// Per-layer spatial (pointwise) weights, each of shape [hidden, hidden]
    conv_weights: Vec<Parameter>,

    // Input/output projections
    /// Maps input features to the hidden width, shape [in_features, hidden_features]
    input_projection: Parameter,
    /// Maps hidden features to the output width, shape [hidden_features, out_features]
    output_projection: Parameter,

    // Bias terms
    /// Optional additive bias on the output, shape [out_features]
    bias: Option<Parameter>,
}
45
46impl GraphFNO {
47    /// Create a new Graph Fourier Neural Operator
48    pub fn new(
49        in_features: usize,
50        out_features: usize,
51        hidden_features: usize,
52        num_modes: usize,
53        num_layers: usize,
54        bias: bool,
55    ) -> Self {
56        let mut fourier_weights = Vec::new();
57        let mut conv_weights = Vec::new();
58
59        // Initialize Fourier weights for each layer
60        for _ in 0..num_layers {
61            fourier_weights.push(Parameter::new(
62                randn(&[hidden_features, hidden_features, num_modes])
63                    .expect("failed to create fourier_weights tensor"),
64            ));
65            conv_weights.push(Parameter::new(
66                randn(&[hidden_features, hidden_features])
67                    .expect("failed to create conv_weights tensor"),
68            ));
69        }
70
71        let input_projection = Parameter::new(
72            randn(&[in_features, hidden_features])
73                .expect("failed to create input_projection tensor"),
74        );
75        let output_projection = Parameter::new(
76            randn(&[hidden_features, out_features])
77                .expect("failed to create output_projection tensor"),
78        );
79
80        let bias = if bias {
81            Some(Parameter::new(
82                zeros::<f32>(&[out_features]).expect("failed to create bias tensor"),
83            ))
84        } else {
85            None
86        };
87
88        Self {
89            in_features,
90            out_features,
91            hidden_features,
92            num_modes,
93            num_layers,
94            fourier_weights,
95            conv_weights,
96            input_projection,
97            output_projection,
98            bias,
99        }
100    }
101
102    /// Forward pass through GraphFNO
103    pub fn forward(&self, graph: &GraphData) -> GraphData {
104        let _num_nodes = graph.num_nodes;
105
106        // Input projection
107        let mut x = graph
108            .x
109            .matmul(&self.input_projection.clone_data())
110            .expect("operation should succeed");
111
112        // Apply Fourier layers
113        for layer in 0..self.num_layers {
114            x = self.fourier_layer(&x, layer, graph);
115        }
116
117        // Output projection
118        let mut output = x
119            .matmul(&self.output_projection.clone_data())
120            .expect("operation should succeed");
121
122        // Add bias if present
123        if let Some(ref bias) = self.bias {
124            output = output
125                .add(&bias.clone_data())
126                .expect("operation should succeed");
127        }
128
129        // Create output graph
130        let mut output_graph = graph.clone();
131        output_graph.x = output;
132        output_graph
133    }
134
135    /// Apply a single Fourier layer
136    fn fourier_layer(&self, x: &Tensor, layer: usize, graph: &GraphData) -> Tensor {
137        // Step 1: Apply graph Fourier transform (simplified)
138        let fourier_x = self.graph_fourier_transform(x, graph);
139
140        // Step 2: Apply learnable Fourier weights
141        let fourier_weights = &self.fourier_weights[layer];
142        let spectral_conv = self.spectral_convolution(&fourier_x, fourier_weights);
143
144        // Step 3: Inverse Fourier transform
145        let spatial_features = self.inverse_graph_fourier_transform(&spectral_conv, graph);
146
147        // Step 4: Apply spatial convolution
148        let conv_weights = &self.conv_weights[layer];
149        let conv_output = spatial_features
150            .matmul(&conv_weights.clone_data())
151            .expect("operation should succeed");
152
153        // Step 5: Residual connection and activation
154        let residual = x.add(&conv_output).expect("operation should succeed");
155
156        // Apply ReLU activation (simplified)
157        self.relu(&residual)
158    }
159
160    /// Graph Fourier Transform (simplified eigendecomposition)
161    fn graph_fourier_transform(&self, x: &Tensor, graph: &GraphData) -> Tensor {
162        // For simplicity, we'll use a learned transformation matrix
163        // In practice, this would use graph Laplacian eigendecomposition
164        let num_nodes = graph.num_nodes;
165
166        // Create a simple transformation that captures spectral properties
167        let mut transform_data = Vec::new();
168        for i in 0..num_nodes {
169            for j in 0..self.num_modes {
170                let freq = (j as f32 + 1.0) * std::f32::consts::PI / num_nodes as f32;
171                let basis = (freq * i as f32).cos();
172                transform_data.push(basis);
173            }
174        }
175
176        let transform_matrix = from_vec(
177            transform_data,
178            &[num_nodes, self.num_modes],
179            torsh_core::device::DeviceType::Cpu,
180        )
181        .expect("GFT transform matrix creation should succeed");
182
183        // Project to spectral domain
184        transform_matrix
185            .t()
186            .expect("operation should succeed")
187            .matmul(x)
188            .expect("operation should succeed")
189    }
190
191    /// Inverse Graph Fourier Transform
192    fn inverse_graph_fourier_transform(&self, fourier_x: &Tensor, graph: &GraphData) -> Tensor {
193        let num_nodes = graph.num_nodes;
194
195        // Create inverse transformation matrix
196        let mut inv_transform_data = Vec::new();
197        for i in 0..num_nodes {
198            for j in 0..self.num_modes {
199                let freq = (j as f32 + 1.0) * std::f32::consts::PI / num_nodes as f32;
200                let basis = (freq * i as f32).cos();
201                inv_transform_data.push(basis);
202            }
203        }
204
205        let inv_transform_matrix = from_vec(
206            inv_transform_data,
207            &[num_nodes, self.num_modes],
208            torsh_core::device::DeviceType::Cpu,
209        )
210        .expect("inverse GFT transform matrix creation should succeed");
211
212        // Project back to spatial domain
213        inv_transform_matrix
214            .matmul(fourier_x)
215            .expect("operation should succeed")
216    }
217
218    /// Spectral convolution in Fourier domain
219    fn spectral_convolution(&self, fourier_x: &Tensor, weights: &Parameter) -> Tensor {
220        // Apply Fourier weights (simplified)
221        let weight_data = weights.clone_data();
222
223        // For simplicity, use only the first mode slice
224        // In practice, this would involve complex multiplication across all modes
225        let weight_2d = weight_data
226            .slice_tensor(2, 0, 1)
227            .expect("spectral weight slice should succeed")
228            .squeeze_tensor(2)
229            .expect("spectral weight squeeze should succeed");
230
231        fourier_x
232            .matmul(&weight_2d)
233            .expect("operation should succeed")
234    }
235
236    /// ReLU activation function
237    fn relu(&self, x: &Tensor) -> Tensor {
238        // Simplified ReLU - clamp negative values to 0
239        let data = x.to_vec().expect("conversion should succeed");
240        let activated_data: Vec<f32> = data.iter().map(|&val| val.max(0.0)).collect();
241
242        from_vec(
243            activated_data,
244            x.shape().dims(),
245            torsh_core::device::DeviceType::Cpu,
246        )
247        .expect("GraphFNO relu tensor creation should succeed")
248    }
249}
250
251impl GraphLayer for GraphFNO {
252    fn forward(&self, graph: &GraphData) -> GraphData {
253        self.forward(graph)
254    }
255
256    fn parameters(&self) -> Vec<Tensor> {
257        let mut params = vec![
258            self.input_projection.clone_data(),
259            self.output_projection.clone_data(),
260        ];
261
262        for weight in &self.fourier_weights {
263            params.push(weight.clone_data());
264        }
265
266        for weight in &self.conv_weights {
267            params.push(weight.clone_data());
268        }
269
270        if let Some(ref bias) = self.bias {
271            params.push(bias.clone_data());
272        }
273
274        params
275    }
276}
277
/// Graph DeepONet for operator learning on graphs
///
/// Classic DeepONet split: a branch network encodes the input function
/// sampled at `num_sensors` points, a trunk network encodes query
/// locations, and the two outputs are combined element-wise.
#[derive(Debug)]
pub struct GraphDeepONet {
    /// Input dimensionality of the trunk network (location/coordinate features)
    trunk_net_features: usize,
    // NOTE(review): branch_net_features is stored but never read by the
    // visible code — the branch net's input width is `num_sensors`; confirm
    // which is intended.
    branch_net_features: usize,
    /// Hidden width of both branch and trunk MLPs
    hidden_features: usize,
    /// Output width of both sub-networks and of the combined result
    output_features: usize,
    /// Number of sensor samples consumed by the branch network's first layer
    num_sensors: usize,

    // Branch network (processes input functions)
    /// Dense weight matrices of the branch MLP
    branch_layers: Vec<Parameter>,

    // Trunk network (processes locations/coordinates)
    /// Dense weight matrices of the trunk MLP
    trunk_layers: Vec<Parameter>,

    // Output bias
    /// Optional additive bias, shape [output_features]
    bias: Option<Parameter>,
}
296
297impl GraphDeepONet {
298    /// Create a new Graph DeepONet
299    pub fn new(
300        trunk_net_features: usize,
301        branch_net_features: usize,
302        hidden_features: usize,
303        output_features: usize,
304        num_sensors: usize,
305        num_layers: usize,
306        bias: bool,
307    ) -> Self {
308        let mut branch_layers = Vec::new();
309        let mut trunk_layers = Vec::new();
310
311        // Initialize branch network layers
312        for i in 0..num_layers {
313            let in_dim = if i == 0 { num_sensors } else { hidden_features };
314            let out_dim = if i == num_layers - 1 {
315                output_features
316            } else {
317                hidden_features
318            };
319            branch_layers.push(Parameter::new(
320                randn(&[in_dim, out_dim]).expect("failed to create branch layer tensor"),
321            ));
322        }
323
324        // Initialize trunk network layers
325        for i in 0..num_layers {
326            let in_dim = if i == 0 {
327                trunk_net_features
328            } else {
329                hidden_features
330            };
331            let out_dim = if i == num_layers - 1 {
332                output_features
333            } else {
334                hidden_features
335            };
336            trunk_layers.push(Parameter::new(
337                randn(&[in_dim, out_dim]).expect("failed to create trunk layer tensor"),
338            ));
339        }
340
341        let bias = if bias {
342            Some(Parameter::new(
343                zeros::<f32>(&[output_features]).expect("failed to create DeepONet bias tensor"),
344            ))
345        } else {
346            None
347        };
348
349        Self {
350            trunk_net_features,
351            branch_net_features,
352            hidden_features,
353            output_features,
354            num_sensors,
355            branch_layers,
356            trunk_layers,
357            bias,
358        }
359    }
360
361    /// Forward pass through Graph DeepONet
362    pub fn forward(
363        &self,
364        graph: &GraphData,
365        sensor_data: &Tensor,
366        locations: &Tensor,
367    ) -> GraphData {
368        // Process sensor data through branch network
369        let branch_output = self.forward_branch_net(sensor_data);
370
371        // Process locations through trunk network
372        let trunk_output = self.forward_trunk_net(locations);
373
374        // Combine branch and trunk outputs (dot product)
375        let combined = self.combine_outputs(&branch_output, &trunk_output);
376
377        // Add bias if present
378        let mut output = combined;
379        if let Some(ref bias) = self.bias {
380            output = output
381                .add(&bias.clone_data())
382                .expect("operation should succeed");
383        }
384
385        // Create output graph
386        let mut output_graph = graph.clone();
387        output_graph.x = output;
388        output_graph
389    }
390
391    /// Forward pass through branch network
392    fn forward_branch_net(&self, sensor_data: &Tensor) -> Tensor {
393        let mut x = sensor_data.clone();
394
395        for (i, layer) in self.branch_layers.iter().enumerate() {
396            x = x
397                .matmul(&layer.clone_data())
398                .expect("operation should succeed");
399
400            // Apply activation function except for last layer
401            if i < self.branch_layers.len() - 1 {
402                x = self.tanh(&x);
403            }
404        }
405
406        x
407    }
408
409    /// Forward pass through trunk network
410    fn forward_trunk_net(&self, locations: &Tensor) -> Tensor {
411        let mut x = locations.clone();
412
413        for (i, layer) in self.trunk_layers.iter().enumerate() {
414            x = x
415                .matmul(&layer.clone_data())
416                .expect("operation should succeed");
417
418            // Apply activation function except for last layer
419            if i < self.trunk_layers.len() - 1 {
420                x = self.tanh(&x);
421            }
422        }
423
424        x
425    }
426
427    /// Combine branch and trunk network outputs
428    fn combine_outputs(&self, branch_output: &Tensor, trunk_output: &Tensor) -> Tensor {
429        // Element-wise multiplication and sum
430        branch_output
431            .mul(trunk_output)
432            .expect("operation should succeed")
433    }
434
435    /// Tanh activation function
436    fn tanh(&self, x: &Tensor) -> Tensor {
437        let data = x.to_vec().expect("conversion should succeed");
438        let activated_data: Vec<f32> = data.iter().map(|&val| val.tanh()).collect();
439
440        from_vec(
441            activated_data,
442            x.shape().dims(),
443            torsh_core::device::DeviceType::Cpu,
444        )
445        .expect("DeepONet tanh tensor creation should succeed")
446    }
447}
448
449impl GraphLayer for GraphDeepONet {
450    fn forward(&self, graph: &GraphData) -> GraphData {
451        // Default forward using graph features as both sensor data and locations
452        let sensor_data = graph
453            .x
454            .slice_tensor(1, 0, self.num_sensors.min(graph.x.shape().dims()[1]))
455            .expect("sensor data slice should succeed");
456        let locations = graph.x.clone();
457
458        self.forward(graph, &sensor_data, &locations)
459    }
460
461    fn parameters(&self) -> Vec<Tensor> {
462        let mut params = Vec::new();
463
464        for layer in &self.branch_layers {
465            params.push(layer.clone_data());
466        }
467
468        for layer in &self.trunk_layers {
469            params.push(layer.clone_data());
470        }
471
472        if let Some(ref bias) = self.bias {
473            params.push(bias.clone_data());
474        }
475
476        params
477    }
478}
479
/// Physics-Informed Graph Neural Network
///
/// A plain MLP over node features whose output is post-processed with a
/// diffusion-reaction correction built from the graph Laplacian.
#[derive(Debug)]
pub struct PhysicsInformedGNN {
    /// Dimensionality of input node features
    in_features: usize,
    /// Dimensionality of output node features
    out_features: usize,
    /// Hidden width of the intermediate MLP layers
    hidden_features: usize,

    // Neural network layers
    /// Dense weight matrices of the MLP
    layers: Vec<Parameter>,

    // Physics constraints
    /// Coefficient D in the diffusion term D * L * u
    diffusion_coefficient: f32,
    /// Coefficient r in the reaction term r * u
    reaction_rate: f32,

    // Bias
    /// Optional additive bias, shape [out_features]
    bias: Option<Parameter>,
}
497
498impl PhysicsInformedGNN {
499    /// Create a new Physics-Informed GNN
500    pub fn new(
501        in_features: usize,
502        out_features: usize,
503        hidden_features: usize,
504        num_layers: usize,
505        diffusion_coefficient: f32,
506        reaction_rate: f32,
507        bias: bool,
508    ) -> Self {
509        let mut layers = Vec::new();
510
511        for i in 0..num_layers {
512            let in_dim = if i == 0 { in_features } else { hidden_features };
513            let out_dim = if i == num_layers - 1 {
514                out_features
515            } else {
516                hidden_features
517            };
518            layers.push(Parameter::new(
519                randn(&[in_dim, out_dim]).expect("failed to create PIGNN layer tensor"),
520            ));
521        }
522
523        let bias = if bias {
524            Some(Parameter::new(
525                zeros::<f32>(&[out_features]).expect("failed to create PIGNN bias tensor"),
526            ))
527        } else {
528            None
529        };
530
531        Self {
532            in_features,
533            out_features,
534            hidden_features,
535            layers,
536            diffusion_coefficient,
537            reaction_rate,
538            bias,
539        }
540    }
541
542    /// Forward pass with physics constraints
543    pub fn forward(&self, graph: &GraphData) -> GraphData {
544        // Neural network forward pass
545        let mut x = graph.x.clone();
546
547        for (i, layer) in self.layers.iter().enumerate() {
548            x = x
549                .matmul(&layer.clone_data())
550                .expect("operation should succeed");
551
552            // Apply activation except for last layer
553            if i < self.layers.len() - 1 {
554                x = self.swish(&x);
555            }
556        }
557
558        // Apply physics constraints
559        let physics_constrained = self.apply_physics_constraints(&x, graph);
560
561        // Add bias if present
562        let mut output = physics_constrained;
563        if let Some(ref bias) = self.bias {
564            output = output
565                .add(&bias.clone_data())
566                .expect("operation should succeed");
567        }
568
569        // Create output graph
570        let mut output_graph = graph.clone();
571        output_graph.x = output;
572        output_graph
573    }
574
575    /// Apply physics constraints (diffusion-reaction equation)
576    fn apply_physics_constraints(&self, prediction: &Tensor, graph: &GraphData) -> Tensor {
577        // Compute graph Laplacian for diffusion term
578        let laplacian = self.compute_graph_laplacian(graph);
579
580        // Diffusion term: D * L * u
581        let diffusion_term = laplacian
582            .matmul(prediction)
583            .expect("operation should succeed")
584            .mul_scalar(self.diffusion_coefficient)
585            .expect("operation should succeed");
586
587        // Reaction term: r * u
588        let reaction_term = prediction
589            .mul_scalar(self.reaction_rate)
590            .expect("operation should succeed");
591
592        // Combine terms (simplified physics equation)
593        prediction
594            .add(&diffusion_term)
595            .expect("operation should succeed")
596            .add(&reaction_term)
597            .expect("operation should succeed")
598    }
599
600    /// Compute graph Laplacian matrix
601    fn compute_graph_laplacian(&self, graph: &GraphData) -> Tensor {
602        let num_nodes = graph.num_nodes;
603        let _num_edges = graph.num_edges;
604
605        // Initialize adjacency matrix
606        let mut adj_data = vec![0.0f32; num_nodes * num_nodes];
607
608        // Fill adjacency matrix from edge_index
609        let edge_data = graph
610            .edge_index
611            .to_vec()
612            .expect("conversion should succeed");
613        for i in (0..edge_data.len()).step_by(2) {
614            if i + 1 < edge_data.len() {
615                let src = edge_data[i] as usize;
616                let dst = edge_data[i + 1] as usize;
617
618                if src < num_nodes && dst < num_nodes {
619                    adj_data[src * num_nodes + dst] = 1.0;
620                    adj_data[dst * num_nodes + src] = 1.0; // Undirected graph
621                }
622            }
623        }
624
625        // Compute degree matrix
626        let mut degree_data = vec![0.0f32; num_nodes * num_nodes];
627        for i in 0..num_nodes {
628            let mut degree = 0.0;
629            for j in 0..num_nodes {
630                degree += adj_data[i * num_nodes + j];
631            }
632            degree_data[i * num_nodes + i] = degree;
633        }
634
635        // Laplacian = Degree - Adjacency
636        let mut laplacian_data = Vec::new();
637        for i in 0..num_nodes * num_nodes {
638            laplacian_data.push(degree_data[i] - adj_data[i]);
639        }
640
641        from_vec(
642            laplacian_data,
643            &[num_nodes, num_nodes],
644            torsh_core::device::DeviceType::Cpu,
645        )
646        .expect("graph Laplacian tensor creation should succeed")
647    }
648
649    /// Swish activation function (x * sigmoid(x))
650    fn swish(&self, x: &Tensor) -> Tensor {
651        let data = x.to_vec().expect("conversion should succeed");
652        let activated_data: Vec<f32> = data
653            .iter()
654            .map(|&val| val * (1.0 / (1.0 + (-val).exp())))
655            .collect();
656
657        from_vec(
658            activated_data,
659            x.shape().dims(),
660            torsh_core::device::DeviceType::Cpu,
661        )
662        .expect("PIGNN swish tensor creation should succeed")
663    }
664}
665
666impl GraphLayer for PhysicsInformedGNN {
667    fn forward(&self, graph: &GraphData) -> GraphData {
668        self.forward(graph)
669    }
670
671    fn parameters(&self) -> Vec<Tensor> {
672        let mut params = Vec::new();
673
674        for layer in &self.layers {
675            params.push(layer.clone_data());
676        }
677
678        if let Some(ref bias) = self.bias {
679            params.push(bias.clone_data());
680        }
681
682        params
683    }
684}
685
/// Multi-scale Graph Neural Operator
///
/// Processes the graph at `num_scales` coarsening levels (factor 2^scale),
/// upsamples each level back to the original node count, concatenates the
/// per-scale features, and fuses them with a learned projection.
#[derive(Debug)]
pub struct MultiScaleGNO {
    /// Dimensionality of input node features
    in_features: usize,
    /// Dimensionality of output node features
    out_features: usize,
    /// Number of coarsening levels processed (scale s uses factor 2^s)
    num_scales: usize,
    /// Per-scale hidden width
    hidden_features: usize,

    // Scale-specific operators
    /// One [in_features, hidden_features] projection per scale
    scale_operators: Vec<Parameter>,

    // Cross-scale fusion
    /// Projects concatenated scale features,
    /// shape [num_scales * hidden_features, hidden_features]
    fusion_weights: Parameter,

    // Output projection
    /// Final projection, shape [hidden_features, out_features]
    output_projection: Parameter,

    /// Optional additive bias, shape [out_features]
    bias: Option<Parameter>,
}
705
706impl MultiScaleGNO {
707    /// Create a new Multi-scale Graph Neural Operator
708    pub fn new(
709        in_features: usize,
710        out_features: usize,
711        num_scales: usize,
712        hidden_features: usize,
713        bias: bool,
714    ) -> Self {
715        let mut scale_operators = Vec::new();
716
717        // Initialize scale-specific operators
718        for _ in 0..num_scales {
719            scale_operators.push(Parameter::new(
720                randn(&[in_features, hidden_features])
721                    .expect("failed to create scale operator tensor"),
722            ));
723        }
724
725        let fusion_weights = Parameter::new(
726            randn(&[num_scales * hidden_features, hidden_features])
727                .expect("failed to create fusion_weights tensor"),
728        );
729
730        let output_projection = Parameter::new(
731            randn(&[hidden_features, out_features])
732                .expect("failed to create MultiScaleGNO output_projection tensor"),
733        );
734
735        let bias = if bias {
736            Some(Parameter::new(
737                zeros::<f32>(&[out_features]).expect("failed to create MultiScaleGNO bias tensor"),
738            ))
739        } else {
740            None
741        };
742
743        Self {
744            in_features,
745            out_features,
746            num_scales,
747            hidden_features,
748            scale_operators,
749            fusion_weights,
750            output_projection,
751            bias,
752        }
753    }
754
755    /// Forward pass through multi-scale operator
756    pub fn forward(&self, graph: &GraphData) -> GraphData {
757        let mut scale_features = Vec::new();
758
759        // Process each scale
760        for scale in 0..self.num_scales {
761            let scale_graph = self.coarsen_graph(graph, scale);
762            let features = self.process_scale(&scale_graph, scale);
763            let upsampled = self.upsample_features(&features, graph.num_nodes);
764            scale_features.push(upsampled);
765        }
766
767        // Fuse multi-scale features
768        let fused_features = self.fuse_scales(&scale_features);
769
770        // Output projection
771        let mut output = fused_features
772            .matmul(&self.output_projection.clone_data())
773            .expect("operation should succeed");
774
775        // Add bias if present
776        if let Some(ref bias) = self.bias {
777            output = output
778                .add(&bias.clone_data())
779                .expect("operation should succeed");
780        }
781
782        // Create output graph
783        let mut output_graph = graph.clone();
784        output_graph.x = output;
785        output_graph
786    }
787
788    /// Coarsen graph for multi-scale processing
789    fn coarsen_graph(&self, graph: &GraphData, scale: usize) -> GraphData {
790        let coarsening_factor = 2_usize.pow(scale as u32);
791        let coarse_nodes = (graph.num_nodes + coarsening_factor - 1) / coarsening_factor;
792
793        // Simple node pooling - average features of neighboring nodes
794        let mut coarse_features = Vec::new();
795
796        for coarse_id in 0..coarse_nodes {
797            let start_node = coarse_id * coarsening_factor;
798            let end_node = ((coarse_id + 1) * coarsening_factor).min(graph.num_nodes);
799
800            // Average features of nodes in this coarse group
801            let mut sum_features = vec![0.0f32; graph.x.shape().dims()[1]];
802            let mut count = 0;
803
804            for node_id in start_node..end_node {
805                let features = graph
806                    .x
807                    .slice_tensor(0, node_id, node_id + 1)
808                    .expect("node feature slice should succeed");
809                let feature_data = features.to_vec().expect("conversion should succeed");
810
811                for (i, &val) in feature_data.iter().enumerate() {
812                    if i < sum_features.len() {
813                        sum_features[i] += val;
814                    }
815                }
816                count += 1;
817            }
818
819            // Normalize
820            if count > 0 {
821                for val in &mut sum_features {
822                    *val /= count as f32;
823                }
824            }
825
826            coarse_features.extend(sum_features);
827        }
828
829        let coarse_x = from_vec(
830            coarse_features,
831            &[coarse_nodes, graph.x.shape().dims()[1]],
832            torsh_core::device::DeviceType::Cpu,
833        )
834        .expect("coarse node features tensor creation should succeed");
835
836        // Simplified edge index (connect sequential nodes)
837        let mut coarse_edges = Vec::new();
838        for i in 0..coarse_nodes.saturating_sub(1) {
839            coarse_edges.push(i as f32);
840            coarse_edges.push((i + 1) as f32);
841        }
842
843        let coarse_edge_index = from_vec(
844            coarse_edges,
845            &[2, coarse_nodes.saturating_sub(1)],
846            torsh_core::device::DeviceType::Cpu,
847        )
848        .expect("coarse edge index tensor creation should succeed");
849
850        GraphData::new(coarse_x, coarse_edge_index)
851    }
852
853    /// Process features at a specific scale
854    fn process_scale(&self, graph: &GraphData, scale: usize) -> Tensor {
855        let operator = &self.scale_operators[scale];
856        graph
857            .x
858            .matmul(&operator.clone_data())
859            .expect("operation should succeed")
860    }
861
862    /// Upsample features to original graph size
863    fn upsample_features(&self, features: &Tensor, target_nodes: usize) -> Tensor {
864        let current_nodes = features.shape().dims()[0];
865        let feature_dim = features.shape().dims()[1];
866
867        if current_nodes >= target_nodes {
868            // Truncate if necessary
869            return features
870                .slice_tensor(0, 0, target_nodes)
871                .expect("feature truncation should succeed");
872        }
873
874        // Simple upsampling by repetition
875        let feature_data = features.to_vec().expect("conversion should succeed");
876        let mut upsampled_data = Vec::new();
877
878        for target_id in 0..target_nodes {
879            let source_id = (target_id * current_nodes) / target_nodes;
880            let start_idx = source_id * feature_dim;
881            let end_idx = start_idx + feature_dim;
882
883            if end_idx <= feature_data.len() {
884                upsampled_data.extend(&feature_data[start_idx..end_idx]);
885            } else {
886                // Pad with zeros if needed
887                upsampled_data.extend(vec![0.0f32; feature_dim]);
888            }
889        }
890
891        from_vec(
892            upsampled_data,
893            &[target_nodes, feature_dim],
894            torsh_core::device::DeviceType::Cpu,
895        )
896        .expect("upsampled features tensor creation should succeed")
897    }
898
899    /// Fuse multi-scale features
900    fn fuse_scales(&self, scale_features: &[Tensor]) -> Tensor {
901        // Concatenate features from all scales
902        let mut concatenated_data = Vec::new();
903        let num_nodes = scale_features[0].shape().dims()[0];
904
905        for node_id in 0..num_nodes {
906            for scale_feature in scale_features {
907                let node_features = scale_feature
908                    .slice_tensor(0, node_id, node_id + 1)
909                    .expect("scale feature slice should succeed");
910                let feature_data = node_features.to_vec().expect("conversion should succeed");
911                concatenated_data.extend(feature_data);
912            }
913        }
914
915        let concatenated = from_vec(
916            concatenated_data,
917            &[num_nodes, self.num_scales * self.hidden_features],
918            torsh_core::device::DeviceType::Cpu,
919        )
920        .expect("concatenated scale features tensor creation should succeed");
921
922        // Apply fusion weights
923        concatenated
924            .matmul(&self.fusion_weights.clone_data())
925            .expect("operation should succeed")
926    }
927}
928
929impl GraphLayer for MultiScaleGNO {
930    fn forward(&self, graph: &GraphData) -> GraphData {
931        self.forward(graph)
932    }
933
934    fn parameters(&self) -> Vec<Tensor> {
935        let mut params = vec![
936            self.fusion_weights.clone_data(),
937            self.output_projection.clone_data(),
938        ];
939
940        for operator in &self.scale_operators {
941            params.push(operator.clone_data());
942        }
943
944        if let Some(ref bias) = self.bias {
945            params.push(bias.clone_data());
946        }
947
948        params
949    }
950}
951
952/// Graph Neural Operator utilities
953pub mod utils {
954    use super::*;
955
956    /// Compute spectral features of a graph
957    pub fn compute_spectral_features(graph: &GraphData, num_eigenvalues: usize) -> Tensor {
958        // Simplified spectral computation
959        let num_nodes = graph.num_nodes;
960        let mut spectral_data = Vec::new();
961
962        for i in 0..num_nodes {
963            for j in 0..num_eigenvalues {
964                let eigenvalue = (j as f32 + 1.0) / num_eigenvalues as f32;
965                let eigenvector_val = (std::f32::consts::PI * (i as f32 + 1.0) * (j as f32 + 1.0)
966                    / num_nodes as f32)
967                    .sin();
968                spectral_data.push(eigenvalue * eigenvector_val);
969            }
970        }
971
972        from_vec(
973            spectral_data,
974            &[num_nodes, num_eigenvalues],
975            torsh_core::device::DeviceType::Cpu,
976        )
977        .expect("spectral features tensor creation should succeed")
978    }
979
980    /// Generate synthetic operator learning data
981    pub fn generate_operator_data(
982        num_graphs: usize,
983        num_nodes: usize,
984        feature_dim: usize,
985    ) -> Vec<(GraphData, GraphData)> {
986        let mut rng = scirs2_core::random::thread_rng();
987        let mut data_pairs = Vec::new();
988
989        for _ in 0..num_graphs {
990            // Generate input graph
991            let input_features = randn(&[num_nodes, feature_dim])
992                .expect("input features tensor creation should succeed");
993            let mut edge_data = Vec::new();
994
995            // Create random edges
996            for _ in 0..(num_nodes * 2) {
997                let src = rng.gen_range(0..num_nodes) as f32;
998                let dst = rng.gen_range(0..num_nodes) as f32;
999                edge_data.push(src);
1000                edge_data.push(dst);
1001            }
1002
1003            let edge_index = from_vec(
1004                edge_data,
1005                &[2, num_nodes * 2],
1006                torsh_core::device::DeviceType::Cpu,
1007            )
1008            .expect("edge index tensor creation should succeed");
1009
1010            let input_graph = GraphData::new(input_features, edge_index);
1011
1012            // Generate corresponding output (apply some transformation)
1013            let output_features = input_graph
1014                .x
1015                .mul_scalar(2.0)
1016                .expect("output transformation should succeed");
1017            let output_graph = GraphData::new(output_features, input_graph.edge_index.clone());
1018
1019            data_pairs.push((input_graph, output_graph));
1020        }
1021
1022        data_pairs
1023    }
1024
1025    /// Evaluate operator approximation error
1026    pub fn compute_operator_error(predicted: &GraphData, target: &GraphData) -> f32 {
1027        let pred_data = predicted.x.to_vec().expect("conversion should succeed");
1028        let target_data = target.x.to_vec().expect("conversion should succeed");
1029
1030        let mut mse = 0.0;
1031        let mut count = 0;
1032
1033        for (pred, target) in pred_data.iter().zip(target_data.iter()) {
1034            mse += (pred - target).powi(2);
1035            count += 1;
1036        }
1037
1038        if count > 0 {
1039            mse / count as f32
1040        } else {
1041            0.0
1042        }
1043    }
1044}
1045
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    /// The constructor must record every hyper-parameter verbatim.
    #[test]
    fn test_graph_fno_creation() {
        let fno = GraphFNO::new(4, 8, 16, 10, 3, true);
        assert_eq!(fno.in_features, 4);
        assert_eq!(fno.out_features, 8);
        assert_eq!(fno.hidden_features, 16);
        assert_eq!(fno.num_modes, 10);
        assert_eq!(fno.num_layers, 3);
    }

    /// A forward pass maps [5, 4] node features to [5, out_features].
    #[test]
    fn test_graph_fno_forward() {
        let node_features = randn(&[5, 4]).unwrap();
        let edge_values = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0];
        let connectivity = from_vec(edge_values, &[2, 4], DeviceType::Cpu).unwrap();
        let input_graph = GraphData::new(node_features, connectivity);

        let fno = GraphFNO::new(4, 8, 16, 10, 3, true);
        let result = fno.forward(&input_graph);

        assert_eq!(result.x.shape().dims(), &[5, 8]);
    }

    /// DeepONet construction should preserve its configuration fields.
    #[test]
    fn test_graph_deeponet_creation() {
        let deeponet = GraphDeepONet::new(3, 4, 16, 8, 10, 3, true);
        assert_eq!(deeponet.trunk_net_features, 3);
        assert_eq!(deeponet.branch_net_features, 4);
        assert_eq!(deeponet.output_features, 8);
        assert_eq!(deeponet.num_sensors, 10);
    }

    /// Physics-informed GNN forward pass preserves node count and projects
    /// features to the configured output dimension.
    #[test]
    fn test_physics_informed_gnn() {
        let node_features = randn(&[4, 3]).unwrap();
        let edge_values = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0];
        let connectivity = from_vec(edge_values, &[2, 3], DeviceType::Cpu).unwrap();
        let input_graph = GraphData::new(node_features, connectivity);

        let pignn = PhysicsInformedGNN::new(3, 6, 12, 2, 0.1, 0.05, true);
        let result = pignn.forward(&input_graph);

        assert_eq!(result.x.shape().dims(), &[4, 6]);
    }

    /// Multi-scale operator keeps all 8 nodes and emits 6 output features.
    #[test]
    fn test_multi_scale_gno() {
        let node_features = randn(&[8, 4]).unwrap();
        let edge_values = vec![
            0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0,
        ];
        let connectivity = from_vec(edge_values, &[2, 7], DeviceType::Cpu).unwrap();
        let input_graph = GraphData::new(node_features, connectivity);

        let ms_gno = MultiScaleGNO::new(4, 6, 3, 8, true);
        let result = ms_gno.forward(&input_graph);

        assert_eq!(result.x.shape().dims(), &[8, 6]);
    }

    /// Spectral features come out as [num_nodes, num_eigenvalues].
    #[test]
    fn test_spectral_features() {
        let node_features = randn(&[6, 3]).unwrap();
        let edge_values = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0];
        let connectivity = from_vec(edge_values, &[2, 5], DeviceType::Cpu).unwrap();
        let input_graph = GraphData::new(node_features, connectivity);

        let spectral = utils::compute_spectral_features(&input_graph, 4);
        assert_eq!(spectral.shape().dims(), &[6, 4]);
    }

    /// Synthetic data generation yields the requested number of pairs with
    /// consistent node counts and feature dimensions.
    #[test]
    fn test_operator_data_generation() {
        let pairs = utils::generate_operator_data(3, 5, 4);
        assert_eq!(pairs.len(), 3);

        for (input, output) in &pairs {
            assert_eq!(input.num_nodes, 5);
            assert_eq!(output.num_nodes, 5);
            assert_eq!(input.x.shape().dims()[1], 4);
            assert_eq!(output.x.shape().dims()[1], 4);
        }
    }

    /// MSE between two nearly identical graphs is positive but small.
    #[test]
    fn test_operator_error_computation() {
        let features_a = from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], DeviceType::Cpu).unwrap();
        let features_b = from_vec(vec![1.1, 2.1, 3.1, 4.1], &[2, 2], DeviceType::Cpu).unwrap();
        let edge_values = vec![0.0, 1.0];
        let connectivity = from_vec(edge_values, &[2, 1], DeviceType::Cpu).unwrap();

        let graph_a = GraphData::new(features_a, connectivity.clone());
        let graph_b = GraphData::new(features_b, connectivity);

        let error = utils::compute_operator_error(&graph_a, &graph_b);
        assert!(error > 0.0);
        assert!(error < 1.0); // Should be small for similar graphs
    }
}
1149}