1#![allow(dead_code)]
16use crate::parameter::Parameter;
17use crate::{GraphData, GraphLayer};
18use torsh_tensor::{
19 creation::{randn, zeros},
20 Tensor,
21};
22
23use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
25use std::collections::HashMap;
26use std::sync::Arc;
27
/// Message Passing Neural Network (MPNN) convolution layer.
///
/// Implements the classic message/aggregate/update scheme: for every edge a
/// message is produced by a two-layer MLP over `[h_src, h_dst, e_ij]`,
/// messages are aggregated per destination node, and node states are
/// refreshed by a second two-layer MLP over `[h_i, m_i]`.
#[derive(Debug)]
pub struct MPNNConv {
    // Input node feature width.
    in_features: usize,
    // Output node feature width.
    out_features: usize,
    // Edge feature width; 0 disables the edge-feature path.
    edge_features: usize,
    // Hidden width of the message MLP.
    message_hidden_dim: usize,
    // Hidden width of the update MLP.
    update_hidden_dim: usize,

    // Two-layer message MLP: [2*in + edge] -> hidden -> out.
    message_layer1: Parameter,
    message_layer2: Parameter,
    message_bias1: Option<Parameter>,
    message_bias2: Option<Parameter>,

    // Two-layer update MLP: [in + out] -> hidden -> out.
    update_layer1: Parameter,
    update_layer2: Parameter,
    update_bias1: Option<Parameter>,
    update_bias2: Option<Parameter>,

    // Optional square projection applied to raw edge features.
    edge_embedding: Option<Parameter>,

    // How per-edge messages are combined at each destination node.
    aggregation_type: AggregationType,
}
59
/// Strategy used to combine incoming edge messages at a node.
#[derive(Debug, Clone, Copy)]
pub enum AggregationType {
    // Element-wise sum of incoming messages.
    Sum,
    // Sum divided by the node's incoming-message count.
    Mean,
    // Element-wise maximum of incoming messages.
    Max,
    // Attention-weighted sum; support depends on the layer implementation.
    Attention,
}
68
69impl MPNNConv {
70 pub fn new(
72 in_features: usize,
73 out_features: usize,
74 edge_features: usize,
75 message_hidden_dim: usize,
76 update_hidden_dim: usize,
77 aggregation_type: AggregationType,
78 bias: bool,
79 ) -> Self {
80 let message_input_dim = 2 * in_features + edge_features;
82 let message_layer1 = Parameter::new(
83 randn(&[message_input_dim, message_hidden_dim])
84 .expect("failed to create message layer 1 weights"),
85 );
86 let message_layer2 = Parameter::new(
87 randn(&[message_hidden_dim, out_features])
88 .expect("failed to create message layer 2 weights"),
89 );
90
91 let message_bias1 = if bias {
92 Some(Parameter::new(
93 zeros(&[message_hidden_dim]).expect("failed to create message bias 1"),
94 ))
95 } else {
96 None
97 };
98
99 let message_bias2 = if bias {
100 Some(Parameter::new(
101 zeros(&[out_features]).expect("failed to create message bias 2"),
102 ))
103 } else {
104 None
105 };
106
107 let update_input_dim = in_features + out_features;
109 let update_layer1 = Parameter::new(
110 randn(&[update_input_dim, update_hidden_dim])
111 .expect("failed to create update layer 1 weights"),
112 );
113 let update_layer2 = Parameter::new(
114 randn(&[update_hidden_dim, out_features])
115 .expect("failed to create update layer 2 weights"),
116 );
117
118 let update_bias1 = if bias {
119 Some(Parameter::new(
120 zeros(&[update_hidden_dim]).expect("failed to create update bias 1"),
121 ))
122 } else {
123 None
124 };
125
126 let update_bias2 = if bias {
127 Some(Parameter::new(
128 zeros(&[out_features]).expect("failed to create update bias 2"),
129 ))
130 } else {
131 None
132 };
133
134 let edge_embedding = if edge_features > 0 {
136 Some(Parameter::new(
137 randn(&[edge_features, edge_features])
138 .expect("failed to create edge embedding weights"),
139 ))
140 } else {
141 None
142 };
143
144 Self {
145 in_features,
146 out_features,
147 edge_features,
148 message_hidden_dim,
149 update_hidden_dim,
150 message_layer1,
151 message_layer2,
152 message_bias1,
153 message_bias2,
154 update_layer1,
155 update_layer2,
156 update_bias1,
157 update_bias2,
158 edge_embedding,
159 aggregation_type,
160 }
161 }
162
163 pub fn forward(&self, graph: &GraphData) -> GraphData {
165 let num_nodes = graph.num_nodes;
166 let edge_data = crate::utils::tensor_to_vec2::<f32>(&graph.edge_index)
167 .expect("failed to extract edge index data");
168 let _num_edges = edge_data[0].len();
169
170 let messages = self.compute_messages(graph);
172
173 let aggregated = self.aggregate_messages(&messages, &edge_data, num_nodes);
175
176 let updated_features = self.update_nodes(&graph.x, &aggregated);
178
179 GraphData {
180 x: updated_features,
181 edge_index: graph.edge_index.clone(),
182 edge_attr: graph.edge_attr.clone(),
183 batch: graph.batch.clone(),
184 num_nodes: graph.num_nodes,
185 num_edges: graph.num_edges,
186 }
187 }
188
    /// Produces one message vector per edge by running the two-layer message
    /// MLP on `[h_src ++ h_dst ++ e_ij]`.
    ///
    /// Returns a `[num_edges, out_features]` tensor; an edge-less graph
    /// yields an empty `[0, out_features]` tensor.
    fn compute_messages(&self, graph: &GraphData) -> Tensor {
        let edge_data = crate::utils::tensor_to_vec2::<f32>(&graph.edge_index)
            .expect("failed to extract edge index data");
        let num_edges = edge_data[0].len();

        let mut all_messages = Vec::new();

        for edge_idx in 0..num_edges {
            // Row 0 of edge_index holds sources, row 1 destinations.
            let src_idx = edge_data[0][edge_idx] as usize;
            let dst_idx = edge_data[1][edge_idx] as usize;

            // Source / destination node feature vectors (1-D after squeeze).
            let h_i = graph
                .x
                .slice_tensor(0, src_idx, src_idx + 1)
                .expect("failed to slice source node features")
                .squeeze_tensor(0)
                .expect("failed to squeeze source node features");
            let h_j = graph
                .x
                .slice_tensor(0, dst_idx, dst_idx + 1)
                .expect("failed to slice destination node features")
                .squeeze_tensor(0)
                .expect("failed to squeeze destination node features");

            // Edge feature vector: sliced from edge_attr (optionally passed
            // through the learned edge embedding); all zeros otherwise.
            let edge_feat = if let Some(ref edge_attr) = graph.edge_attr {
                if self.edge_features > 0 {
                    let e_ij = edge_attr
                        .slice_tensor(0, edge_idx, edge_idx + 1)
                        .expect("failed to slice edge attributes")
                        .squeeze_tensor(0)
                        .expect("failed to squeeze edge attributes");

                    if let Some(ref edge_emb) = self.edge_embedding {
                        // matmul needs 2-D operands, hence unsqueeze/squeeze.
                        let e_ij_2d = e_ij
                            .unsqueeze_tensor(0)
                            .expect("failed to unsqueeze edge features");
                        e_ij_2d
                            .matmul(&edge_emb.clone_data())
                            .expect("failed to apply edge embedding")
                            .squeeze_tensor(0)
                            .expect("failed to squeeze embedded edge features")
                    } else {
                        e_ij
                    }
                } else {
                    zeros(&[self.edge_features]).expect("failed to create zero edge features")
                }
            } else {
                zeros(&[self.edge_features]).expect("failed to create zero edge features")
            };

            // MLP input: [h_i ++ h_j ++ e_ij].
            let message_input = Tensor::cat(&[&h_i, &h_j, &edge_feat], 0)
                .expect("failed to concatenate message input");

            // Layer 1 (linear), again via a temporary 2-D view for matmul.
            let message_input_2d = message_input
                .unsqueeze_tensor(0)
                .expect("failed to unsqueeze message input");
            let mut message = message_input_2d
                .matmul(&self.message_layer1.clone_data())
                .expect("failed to apply message layer 1")
                .squeeze_tensor(0)
                .expect("failed to squeeze message layer 1 output");

            if let Some(ref bias1) = self.message_bias1 {
                message = message
                    .add(&bias1.clone_data())
                    .expect("operation should succeed");
            }

            // ReLU, implemented as max(x, 0).
            message = message
                .maximum(
                    &zeros(&message.shape().dims()).expect("failed to create zero tensor for ReLU"),
                )
                .expect("failed to apply ReLU activation");

            // Layer 2 (linear).
            let message_2d = message
                .unsqueeze_tensor(0)
                .expect("failed to unsqueeze message for layer 2");
            message = message_2d
                .matmul(&self.message_layer2.clone_data())
                .expect("failed to apply message layer 2")
                .squeeze_tensor(0)
                .expect("failed to squeeze message layer 2 output");

            if let Some(ref bias2) = self.message_bias2 {
                message = message
                    .add(&bias2.clone_data())
                    .expect("operation should succeed");
            }

            all_messages.push(message);
        }

        if all_messages.is_empty() {
            // Edge-less graph: keep the expected [0, out_features] shape.
            zeros(&[0, self.out_features]).expect("failed to create empty messages tensor")
        } else {
            // Flatten every message and rebuild one stacked CPU tensor.
            let mut message_data = Vec::new();
            for msg in &all_messages {
                let msg_vec = msg.to_vec().expect("conversion should succeed");
                message_data.extend(msg_vec);
            }

            torsh_tensor::creation::from_vec(
                message_data,
                &[all_messages.len(), self.out_features],
                torsh_core::device::DeviceType::Cpu,
            )
            .expect("failed to create messages tensor from data")
        }
    }
311
312 fn aggregate_messages(
314 &self,
315 messages: &Tensor,
316 edge_data: &[Vec<f32>],
317 num_nodes: usize,
318 ) -> Tensor {
319 let mut aggregated = zeros(&[num_nodes, self.out_features])
320 .expect("failed to create aggregated messages tensor");
321 let num_edges = edge_data[0].len();
322
323 if num_edges == 0 {
324 return aggregated;
325 }
326
327 match self.aggregation_type {
328 AggregationType::Sum | AggregationType::Mean => {
329 let mut node_counts = vec![0; num_nodes];
330
331 for edge_idx in 0..num_edges {
333 let dst_idx = edge_data[1][edge_idx] as usize;
334 if dst_idx < num_nodes {
335 let message = messages
336 .slice_tensor(0, edge_idx, edge_idx + 1)
337 .expect("failed to slice message")
338 .squeeze_tensor(0)
339 .expect("failed to squeeze message");
340
341 let current = aggregated
342 .slice_tensor(0, dst_idx, dst_idx + 1)
343 .expect("failed to slice aggregated tensor")
344 .squeeze_tensor(0)
345 .expect("failed to squeeze aggregated tensor");
346 let updated = current.add(&message).expect("operation should succeed");
347
348 aggregated
349 .slice_tensor(0, dst_idx, dst_idx + 1)
350 .expect("failed to slice aggregated tensor for update")
351 .copy_(
352 &updated
353 .unsqueeze_tensor(0)
354 .expect("failed to unsqueeze updated tensor"),
355 )
356 .expect("failed to copy updated tensor");
357
358 node_counts[dst_idx] += 1;
359 }
360 }
361
362 if matches!(self.aggregation_type, AggregationType::Mean) {
364 for node in 0..num_nodes {
365 if node_counts[node] > 0 {
366 let current = aggregated
367 .slice_tensor(0, node, node + 1)
368 .expect("failed to slice aggregated tensor for mean")
369 .squeeze_tensor(0)
370 .expect("failed to squeeze aggregated tensor for mean");
371 let normalized = current
372 .div_scalar(node_counts[node] as f32)
373 .expect("failed to normalize aggregated tensor");
374
375 aggregated
376 .slice_tensor(0, node, node + 1)
377 .expect("failed to slice aggregated tensor for normalized update")
378 .copy_(
379 &normalized
380 .unsqueeze_tensor(0)
381 .expect("failed to unsqueeze normalized tensor"),
382 )
383 .expect("failed to copy normalized tensor");
384 }
385 }
386 }
387 }
388
389 AggregationType::Max => {
390 aggregated
392 .fill_(-1e9_f32)
393 .expect("failed to fill aggregated tensor with initial values");
394
395 for edge_idx in 0..num_edges {
396 let dst_idx = edge_data[1][edge_idx] as usize;
397 if dst_idx < num_nodes {
398 let message = messages
399 .slice_tensor(0, edge_idx, edge_idx + 1)
400 .expect("failed to slice message for max aggregation")
401 .squeeze_tensor(0)
402 .expect("failed to squeeze message for max aggregation");
403
404 let current = aggregated
405 .slice_tensor(0, dst_idx, dst_idx + 1)
406 .expect("failed to slice aggregated tensor for max")
407 .squeeze_tensor(0)
408 .expect("failed to squeeze aggregated tensor for max");
409 let updated = current
410 .maximum(&message)
411 .expect("failed to compute maximum");
412
413 aggregated
414 .slice_tensor(0, dst_idx, dst_idx + 1)
415 .expect("failed to slice aggregated tensor for max update")
416 .copy_(
417 &updated
418 .unsqueeze_tensor(0)
419 .expect("failed to unsqueeze max updated tensor"),
420 )
421 .expect("failed to copy max updated tensor");
422 }
423 }
424
425 let aggregated_data = aggregated.to_vec().expect("conversion should succeed");
428 let filtered_data: Vec<f32> = aggregated_data
429 .iter()
430 .map(|&x| if x <= -1e8_f32 { 0.0 } else { x })
431 .collect();
432 aggregated = Tensor::from_data(
433 filtered_data,
434 aggregated.shape().dims().to_vec(),
435 aggregated.device(),
436 )
437 .expect("failed to create filtered aggregated tensor");
438 }
439
440 AggregationType::Attention => {
441 return self.aggregate_messages(messages, edge_data, num_nodes);
444 }
445 }
446
447 aggregated
448 }
449
    /// Refreshes every node state by running the two-layer update MLP on
    /// `[h_i ++ m_i]` (current state concatenated with the node's aggregated
    /// message). Returns a `[num_nodes, out_features]` tensor.
    fn update_nodes(&self, current_states: &Tensor, aggregated_messages: &Tensor) -> Tensor {
        let num_nodes = current_states.shape().dims()[0];
        let mut updated_states =
            zeros(&[num_nodes, self.out_features]).expect("failed to create updated states tensor");

        for node in 0..num_nodes {
            // Current node state h_i.
            let h_i = current_states
                .slice_tensor(0, node, node + 1)
                .expect("failed to slice current node state")
                .squeeze_tensor(0)
                .expect("failed to squeeze current node state");

            // Aggregated incoming message m_i.
            let m_i = aggregated_messages
                .slice_tensor(0, node, node + 1)
                .expect("failed to slice aggregated message")
                .squeeze_tensor(0)
                .expect("failed to squeeze aggregated message");

            let update_input =
                Tensor::cat(&[&h_i, &m_i], 0).expect("failed to concatenate update input");

            // Layer 1 (linear) via a temporary 2-D view for matmul.
            let update_input_2d = update_input
                .unsqueeze_tensor(0)
                .expect("failed to unsqueeze update input");
            let mut updated = update_input_2d
                .matmul(&self.update_layer1.clone_data())
                .expect("failed to apply update layer 1")
                .squeeze_tensor(0)
                .expect("failed to squeeze update layer 1 output");

            if let Some(ref bias1) = self.update_bias1 {
                updated = updated
                    .add(&bias1.clone_data())
                    .expect("operation should succeed");
            }

            // ReLU via in-place clamp to [0, +inf).
            let mut updated_temp = updated;
            updated_temp
                .clamp_(0.0, f32::INFINITY)
                .expect("failed to clamp update values");
            updated = updated_temp;

            // Layer 2 (linear).
            let updated_2d = updated
                .unsqueeze_tensor(0)
                .expect("failed to unsqueeze for update layer 2");
            updated = updated_2d
                .matmul(&self.update_layer2.clone_data())
                .expect("failed to apply update layer 2")
                .squeeze_tensor(0)
                .expect("failed to squeeze update layer 2 output");

            if let Some(ref bias2) = self.update_bias2 {
                updated = updated
                    .add(&bias2.clone_data())
                    .expect("operation should succeed");
            }

            // Write the result element-wise into the output row.
            let updated_data = updated.to_vec().expect("conversion should succeed");
            for (i, &value) in updated_data.iter().enumerate() {
                updated_states
                    .set_item(&[node, i], value)
                    .expect("failed to set updated state value");
            }
        }

        updated_states
    }
526}
527
528impl GraphLayer for MPNNConv {
529 fn forward(&self, graph: &GraphData) -> GraphData {
530 self.forward(graph)
531 }
532
533 fn parameters(&self) -> Vec<Tensor> {
534 let mut params = vec![
535 self.message_layer1.clone_data(),
536 self.message_layer2.clone_data(),
537 self.update_layer1.clone_data(),
538 self.update_layer2.clone_data(),
539 ];
540
541 if let Some(ref bias1) = self.message_bias1 {
542 params.push(bias1.clone_data());
543 }
544
545 if let Some(ref bias2) = self.message_bias2 {
546 params.push(bias2.clone_data());
547 }
548
549 if let Some(ref bias1) = self.update_bias1 {
550 params.push(bias1.clone_data());
551 }
552
553 if let Some(ref bias2) = self.update_bias2 {
554 params.push(bias2.clone_data());
555 }
556
557 if let Some(ref edge_emb) = self.edge_embedding {
558 params.push(edge_emb.clone_data());
559 }
560
561 params
562 }
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    // With bias=true and edge_features>0 the layer owns 4 weight matrices,
    // 4 bias vectors and 1 edge embedding: between 4 and 9 parameters.
    #[test]
    fn test_mpnn_creation() {
        let mpnn = MPNNConv::new(8, 16, 4, 32, 32, AggregationType::Sum, true);
        let params = mpnn.parameters();

        assert!(params.len() >= 4);
        assert!(params.len() <= 9);
    }

    // Forward pass on a 3-node cycle with 2-dim edge attributes must map
    // node features from width 3 to width 8 without changing node count.
    #[test]
    fn test_mpnn_forward() {
        let mpnn = MPNNConv::new(3, 8, 2, 16, 16, AggregationType::Mean, false);

        let x = from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            &[3, 3],
            DeviceType::Cpu,
        )
        .unwrap();

        // Edges 0->1, 1->2, 2->0 (row 0 sources, row 1 destinations).
        let edge_index =
            from_vec(vec![0.0, 1.0, 2.0, 1.0, 2.0, 0.0], &[2, 3], DeviceType::Cpu).unwrap();

        let edge_attr =
            from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], &[3, 2], DeviceType::Cpu).unwrap();

        let graph = GraphData::new(x, edge_index).with_edge_attr(edge_attr);

        let output = mpnn.forward(&graph);
        assert_eq!(output.x.shape().dims(), &[3, 8]);
        assert_eq!(output.num_nodes, 3);
    }

    // Smoke test: every aggregation variant must run without panicking.
    #[test]
    fn test_mpnn_aggregation_types() {
        let mpnn_sum = MPNNConv::new(2, 4, 0, 8, 8, AggregationType::Sum, false);
        let mpnn_mean = MPNNConv::new(2, 4, 0, 8, 8, AggregationType::Mean, false);
        let mpnn_max = MPNNConv::new(2, 4, 0, 8, 8, AggregationType::Max, false);

        let x = from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], DeviceType::Cpu).unwrap();

        // Single edge 0 -> 1.
        let edge_index = from_vec(vec![0.0, 1.0], &[2, 1], DeviceType::Cpu).unwrap();

        let graph = GraphData::new(x, edge_index);

        let _output_sum = mpnn_sum.forward(&graph);
        let _output_mean = mpnn_mean.forward(&graph);
        let _output_max = mpnn_max.forward(&graph);
    }

    // One node, zero edges: output must still be [1, 8] (zero messages).
    #[test]
    fn test_mpnn_empty_graph() {
        let mpnn = MPNNConv::new(3, 8, 0, 16, 16, AggregationType::Sum, false);

        let x = from_vec(vec![1.0, 2.0, 3.0], &[1, 3], DeviceType::Cpu).unwrap();

        // Shape [2, 0]: an empty edge index.
        let edge_index = zeros(&[2, 0]).unwrap();
        let graph = GraphData::new(x, edge_index);

        let output = mpnn.forward(&graph);
        assert_eq!(output.x.shape().dims(), &[1, 8]);
        assert_eq!(output.num_nodes, 1);
    }
}
645
/// SIMD-oriented MPNN variant operating on `f64` ndarray matrices instead
/// of `torsh` tensors, with chunked edge processing, optional attention
/// aggregation and a small runtime performance cache.
#[derive(Debug, Clone)]
pub struct AdvancedSIMDMPNN {
    // Input node feature width.
    in_features: usize,
    // Output node feature width.
    out_features: usize,
    // Edge feature width; 0 disables the edge-feature path.
    edge_features: usize,

    // Edge-batch size for the chunked message path.
    simd_chunk_size: usize,
    // When true, graphs larger than `simd_chunk_size` use chunking.
    memory_efficient: bool,
    // Enables attention-weighted aggregation (requires `attention_weights`).
    use_attention: bool,
    // Head count used to size the attention weight matrix.
    num_attention_heads: usize,

    // Message MLP weights: [2*in + edge, hidden].
    message_weights: Array2<f64>,
    // Update MLP weights: [hidden + in, out].
    update_weights: Array2<f64>,
    // Attention projection, present only when `use_attention` is set.
    attention_weights: Option<Array2<f64>>,

    // Optional additive biases for the two MLPs.
    message_bias: Option<Array1<f64>>,
    update_bias: Option<Array1<f64>>,

    // Aggregation strategy configuration.
    aggregation_config: AdvancedAggregationConfig,

    // Cached runtime statistics (SIMD speedup estimate etc.).
    performance_cache: PerformanceCache,
}
679
/// Configuration for message aggregation in [`AdvancedSIMDMPNN`].
///
/// NOTE(review): only `primary_aggregation` is read by the kernels in this
/// file; the remaining fields appear to be reserved for future use —
/// confirm against other callers before relying on them.
#[derive(Debug, Clone)]
pub struct AdvancedAggregationConfig {
    // Aggregation applied to incoming messages.
    primary_aggregation: AggregationType,
    // Optional second-stage aggregation (unused within this file).
    secondary_aggregation: Option<AggregationType>,
    // Number of hierarchy levels (unused within this file).
    hierarchical_levels: usize,
    // Softmax temperature for attention (unused within this file).
    attention_temperature: f64,
    // Dynamic routing toggle (unused within this file).
    dynamic_routing: bool,
}
694
/// Runtime cache of graph statistics and reusable intermediates.
///
/// NOTE(review): only `simd_speedup_factor` is ever written in this file
/// (by `update_performance_cache`); the three maps are never populated
/// here — their key semantics should be confirmed against other callers.
#[derive(Debug, Clone)]
pub struct PerformanceCache {
    // Cached adjacency patterns, keyed by string (unpopulated here).
    adjacency_patterns: HashMap<String, Arc<Array2<f64>>>,
    // Per-size degree statistics as (f64, f64) pairs (unpopulated here).
    degree_stats: HashMap<usize, (f64, f64)>,
    // Cached message matrices, keyed by string (unpopulated here).
    message_cache: HashMap<String, Arc<Array2<f64>>>,
    // Most recent estimated SIMD speedup multiplier.
    simd_speedup_factor: f64,
}
707
708impl AdvancedSIMDMPNN {
709 pub fn new(
711 in_features: usize,
712 out_features: usize,
713 edge_features: usize,
714 config: AdvancedMPNNConfig,
715 ) -> Self {
716 let message_input_dim = 2 * in_features + edge_features;
717 let hidden_dim = config.hidden_dim;
718
719 let message_weights = Self::initialize_weights_simd(message_input_dim, hidden_dim);
721 let update_weights = Self::initialize_weights_simd(hidden_dim + in_features, out_features);
722
723 let attention_weights = if config.use_attention {
725 Some(Self::initialize_weights_simd(
726 hidden_dim,
727 config.num_attention_heads * hidden_dim,
728 ))
729 } else {
730 None
731 };
732
733 let message_bias = if config.use_bias {
735 Some(Array1::zeros(hidden_dim))
736 } else {
737 None
738 };
739
740 let update_bias = if config.use_bias {
741 Some(Array1::zeros(out_features))
742 } else {
743 None
744 };
745
746 Self {
747 in_features,
748 out_features,
749 edge_features,
750 simd_chunk_size: config.simd_chunk_size,
751 memory_efficient: config.memory_efficient,
752 use_attention: config.use_attention,
753 num_attention_heads: config.num_attention_heads,
754 message_weights,
755 update_weights,
756 attention_weights,
757 message_bias,
758 update_bias,
759 aggregation_config: config.aggregation_config,
760 performance_cache: PerformanceCache::new(),
761 }
762 }
763
764 pub fn forward_simd(&mut self, graph: &GraphData) -> GraphData {
766 let batch_size = graph.num_nodes;
767
768 if batch_size == 0 {
769 return graph.clone();
770 }
771
772 let node_features = self.tensor_to_array2(&graph.x);
774 let edge_indices = self.extract_edge_indices(&graph.edge_index);
775 let edge_attributes = graph
776 .edge_attr
777 .as_ref()
778 .map(|attr| self.tensor_to_array2(attr));
779
780 let messages = if self.memory_efficient && batch_size > self.simd_chunk_size {
782 self.compute_messages_chunked(&node_features, &edge_indices, &edge_attributes)
783 } else {
784 self.compute_messages_vectorized(&node_features, &edge_indices, &edge_attributes)
785 };
786
787 let aggregated_messages =
789 self.aggregate_messages_simd(&messages, &edge_indices, batch_size);
790
791 let updated_features = self.update_nodes_simd(&node_features, &aggregated_messages);
793
794 let output_tensor = self.array2_to_tensor(&updated_features);
796
797 self.update_performance_cache(batch_size, edge_indices.len());
799
800 GraphData::new(output_tensor, graph.edge_index.clone())
801 .with_edge_attr_opt(graph.edge_attr.clone())
802 }
803
804 fn initialize_weights_simd(input_dim: usize, output_dim: usize) -> Array2<f64> {
806 let mut weights = Array2::zeros((input_dim, output_dim));
807 let scale = (2.0 / input_dim as f64).sqrt();
808
809 use std::collections::hash_map::DefaultHasher;
811 use std::hash::{Hash, Hasher};
812
813 for i in 0..input_dim {
814 for j in 0..output_dim {
815 let mut hasher = DefaultHasher::new();
816 (i, j).hash(&mut hasher);
817 let hash_val = hasher.finish();
818 let normalized = (hash_val as f64) / (u64::MAX as f64);
819 weights[[i, j]] = (normalized - 0.5) * 2.0 * scale;
820 }
821 }
822
823 weights
824 }
825
826 fn compute_messages_vectorized(
828 &self,
829 node_features: &Array2<f64>,
830 edge_indices: &[(usize, usize)],
831 edge_attributes: &Option<Array2<f64>>,
832 ) -> Array2<f64> {
833 let num_edges = edge_indices.len();
834 let message_dim = self.message_weights.ncols();
835 let mut messages = Array2::zeros((num_edges, message_dim));
836
837 for (edge_idx, &(src, dst)) in edge_indices.iter().enumerate() {
839 if src < node_features.nrows() && dst < node_features.nrows() {
840 let src_features = node_features.row(src);
842 let dst_features = node_features.row(dst);
843
844 let mut message_input =
845 Vec::with_capacity(self.in_features * 2 + self.edge_features);
846
847 message_input.extend(src_features.iter());
849 message_input.extend(dst_features.iter());
850
851 if let Some(ref edge_attr) = edge_attributes {
853 if edge_idx < edge_attr.nrows() {
854 message_input.extend(edge_attr.row(edge_idx).iter());
855 } else {
856 message_input.resize(message_input.len() + self.edge_features, 0.0);
858 }
859 } else {
860 message_input.resize(message_input.len() + self.edge_features, 0.0);
862 }
863
864 let input_array = Array1::from_vec(message_input);
866 let message = self.compute_message_mlp(&input_array);
867
868 for (i, &val) in message.iter().enumerate() {
870 if i < message_dim {
871 messages[[edge_idx, i]] = val;
872 }
873 }
874 }
875 }
876
877 messages
878 }
879
    /// Chunked variant of message computation: edges are processed in
    /// `simd_chunk_size` batches to bound the working set on large graphs.
    /// Produces the same `[num_edges, message_dim]` matrix as the
    /// vectorized path; out-of-range edges leave their row zeroed.
    fn compute_messages_chunked(
        &self,
        node_features: &Array2<f64>,
        edge_indices: &[(usize, usize)],
        edge_attributes: &Option<Array2<f64>>,
    ) -> Array2<f64> {
        let num_edges = edge_indices.len();
        let message_dim = self.message_weights.ncols();
        let mut messages = Array2::zeros((num_edges, message_dim));

        for chunk_start in (0..num_edges).step_by(self.simd_chunk_size) {
            let chunk_end = (chunk_start + self.simd_chunk_size).min(num_edges);
            let chunk_indices = &edge_indices[chunk_start..chunk_end];

            for (local_idx, &(src, dst)) in chunk_indices.iter().enumerate() {
                // Index of this edge in the full edge list.
                let edge_idx = chunk_start + local_idx;

                // Skip edges that reference nonexistent nodes.
                if src < node_features.nrows() && dst < node_features.nrows() {
                    let message = self.compute_single_message(
                        &node_features.row(src),
                        &node_features.row(dst),
                        // Pass edge attributes only when a row exists for
                        // this edge; None triggers zero-padding downstream.
                        edge_attributes.as_ref().and_then(|attr| {
                            if edge_idx < attr.nrows() {
                                Some(attr.row(edge_idx))
                            } else {
                                None
                            }
                        }),
                    );

                    // Copy the message into its row, clipped to the width.
                    for (i, &val) in message.iter().enumerate() {
                        if i < message_dim {
                            messages[[edge_idx, i]] = val;
                        }
                    }
                }
            }
        }

        messages
    }
925
926 fn compute_message_mlp(&self, input: &Array1<f64>) -> Array1<f64> {
928 let mut hidden = Array1::zeros(self.message_weights.ncols());
930
931 for (i, _row) in self.message_weights.axis_iter(Axis(1)).enumerate() {
933 let dot_product = input
934 .iter()
935 .zip(self.message_weights.axis_iter(Axis(0)))
936 .map(|(&x, weight_col)| x * weight_col[i])
937 .sum::<f64>();
938
939 hidden[i] = dot_product;
940 }
941
942 if let Some(ref bias) = self.message_bias {
944 for i in 0..hidden.len() {
945 if i < bias.len() {
946 hidden[i] += bias[i];
947 }
948 }
949 }
950
951 hidden.mapv_inplace(|x| x.max(0.0));
953
954 hidden
956 }
957
958 fn compute_single_message(
960 &self,
961 src_features: &ArrayView1<f64>,
962 dst_features: &ArrayView1<f64>,
963 edge_features: Option<ArrayView1<f64>>,
964 ) -> Array1<f64> {
965 let mut message_input = Vec::with_capacity(self.in_features * 2 + self.edge_features);
966
967 message_input.extend(src_features.iter());
969 message_input.extend(dst_features.iter());
970
971 if let Some(edge_feat) = edge_features {
972 message_input.extend(edge_feat.iter());
973 } else {
974 message_input.resize(message_input.len() + self.edge_features, 0.0);
975 }
976
977 let input_array = Array1::from_vec(message_input);
978 self.compute_message_mlp(&input_array)
979 }
980
981 fn aggregate_messages_simd(
983 &self,
984 messages: &Array2<f64>,
985 edge_indices: &[(usize, usize)],
986 num_nodes: usize,
987 ) -> Array2<f64> {
988 let message_dim = messages.ncols();
989 let mut aggregated = Array2::zeros((num_nodes, message_dim));
990
991 match self.aggregation_config.primary_aggregation {
992 AggregationType::Sum => {
993 self.aggregate_sum_simd(messages, edge_indices, &mut aggregated)
994 }
995 AggregationType::Mean => {
996 self.aggregate_mean_simd(messages, edge_indices, &mut aggregated)
997 }
998 AggregationType::Max => {
999 self.aggregate_max_simd(messages, edge_indices, &mut aggregated)
1000 }
1001 AggregationType::Attention => {
1002 self.aggregate_attention_simd(messages, edge_indices, &mut aggregated)
1003 }
1004 }
1005
1006 aggregated
1007 }
1008
1009 fn aggregate_sum_simd(
1011 &self,
1012 messages: &Array2<f64>,
1013 edge_indices: &[(usize, usize)],
1014 aggregated: &mut Array2<f64>,
1015 ) {
1016 for (edge_idx, &(_, dst)) in edge_indices.iter().enumerate() {
1017 if dst < aggregated.nrows() && edge_idx < messages.nrows() {
1018 let message = messages.row(edge_idx);
1019 let mut dst_row = aggregated.row_mut(dst);
1020
1021 for (i, &msg_val) in message.iter().enumerate() {
1023 if i < dst_row.len() {
1024 dst_row[i] += msg_val;
1025 }
1026 }
1027 }
1028 }
1029 }
1030
1031 fn aggregate_mean_simd(
1033 &self,
1034 messages: &Array2<f64>,
1035 edge_indices: &[(usize, usize)],
1036 aggregated: &mut Array2<f64>,
1037 ) {
1038 self.aggregate_sum_simd(messages, edge_indices, aggregated);
1040
1041 let mut neighbor_counts = vec![0usize; aggregated.nrows()];
1043 for &(_, dst) in edge_indices {
1044 if dst < neighbor_counts.len() {
1045 neighbor_counts[dst] += 1;
1046 }
1047 }
1048
1049 for (node_idx, count) in neighbor_counts.iter().enumerate() {
1051 if *count > 0 && node_idx < aggregated.nrows() {
1052 let count_f64 = *count as f64;
1053 let mut row = aggregated.row_mut(node_idx);
1054 row.mapv_inplace(|x| x / count_f64);
1055 }
1056 }
1057 }
1058
1059 fn aggregate_max_simd(
1061 &self,
1062 messages: &Array2<f64>,
1063 edge_indices: &[(usize, usize)],
1064 aggregated: &mut Array2<f64>,
1065 ) {
1066 aggregated.fill(f64::NEG_INFINITY);
1068
1069 for (edge_idx, &(_, dst)) in edge_indices.iter().enumerate() {
1070 if dst < aggregated.nrows() && edge_idx < messages.nrows() {
1071 let message = messages.row(edge_idx);
1072 let mut dst_row = aggregated.row_mut(dst);
1073
1074 for (i, &msg_val) in message.iter().enumerate() {
1076 if i < dst_row.len() {
1077 dst_row[i] = dst_row[i].max(msg_val);
1078 }
1079 }
1080 }
1081 }
1082
1083 aggregated.mapv_inplace(|x| if x == f64::NEG_INFINITY { 0.0 } else { x });
1085 }
1086
    /// Attention-weighted aggregation: each message row is scaled by its
    /// softmax-normalized attention score before being summed into its
    /// destination row. Falls back to plain sum aggregation when no
    /// attention weights were configured.
    fn aggregate_attention_simd(
        &self,
        messages: &Array2<f64>,
        edge_indices: &[(usize, usize)],
        aggregated: &mut Array2<f64>,
    ) {
        if let Some(ref attention_weights) = self.attention_weights {
            let attention_scores = self.compute_attention_scores_simd(messages, attention_weights);

            for (edge_idx, &(_, dst)) in edge_indices.iter().enumerate() {
                if dst < aggregated.nrows() && edge_idx < messages.nrows() {
                    let message = messages.row(edge_idx);
                    // A missing score contributes a zero-weighted message.
                    let attention_weight = attention_scores.get(edge_idx).copied().unwrap_or(0.0);
                    let mut dst_row = aggregated.row_mut(dst);

                    for (i, &msg_val) in message.iter().enumerate() {
                        if i < dst_row.len() {
                            dst_row[i] += msg_val * attention_weight;
                        }
                    }
                }
            }
        } else {
            // No attention parameters: degrade gracefully to sum.
            self.aggregate_sum_simd(messages, edge_indices, aggregated);
        }
    }
1118
    /// Scores every message and softmax-normalizes the scores across ALL
    /// messages (global, not per destination node — TODO confirm that
    /// per-node normalization was not intended).
    ///
    /// NOTE(review): only column 0 of `attention_weights` participates in
    /// the dot product, i.e. effectively a single attention head is used
    /// regardless of `num_attention_heads`.
    fn compute_attention_scores_simd(
        &self,
        messages: &Array2<f64>,
        attention_weights: &Array2<f64>,
    ) -> Vec<f64> {
        let num_messages = messages.nrows();
        let mut scores = Vec::with_capacity(num_messages);

        for i in 0..num_messages {
            let message = messages.row(i);

            // Dot product of the message with the first attention column.
            let score = message
                .iter()
                .zip(attention_weights.column(0).iter())
                .map(|(&m, &w)| m * w)
                .sum::<f64>();

            scores.push(score);
        }

        self.softmax_simd(&mut scores);
        scores
    }
1145
1146 fn softmax_simd(&self, scores: &mut Vec<f64>) {
1148 if scores.is_empty() {
1149 return;
1150 }
1151
1152 let max_score = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
1154
1155 for score in scores.iter_mut() {
1157 *score = (*score - max_score).exp();
1158 }
1159
1160 let sum: f64 = scores.iter().sum();
1162 if sum > 1e-15 {
1163 for score in scores.iter_mut() {
1164 *score /= sum;
1165 }
1166 }
1167 }
1168
1169 fn update_nodes_simd(
1171 &self,
1172 node_features: &Array2<f64>,
1173 aggregated_messages: &Array2<f64>,
1174 ) -> Array2<f64> {
1175 let num_nodes = node_features.nrows();
1176 let output_dim = self.out_features;
1177 let mut updated_features = Array2::zeros((num_nodes, output_dim));
1178
1179 for node_idx in 0..num_nodes {
1180 if node_idx < aggregated_messages.nrows() {
1181 let node_feat = node_features.row(node_idx);
1182 let agg_msg = aggregated_messages.row(node_idx);
1183
1184 let mut update_input = Vec::with_capacity(node_feat.len() + agg_msg.len());
1186 update_input.extend(node_feat.iter());
1187 update_input.extend(agg_msg.iter());
1188
1189 let input_array = Array1::from_vec(update_input);
1190 let updated = self.compute_update_mlp(&input_array);
1191
1192 for (i, &val) in updated.iter().enumerate() {
1194 if i < output_dim {
1195 updated_features[[node_idx, i]] = val;
1196 }
1197 }
1198 }
1199 }
1200
1201 updated_features
1202 }
1203
1204 fn compute_update_mlp(&self, input: &Array1<f64>) -> Array1<f64> {
1206 let mut output = Array1::zeros(self.out_features);
1207
1208 for (i, weight_col) in self.update_weights.axis_iter(Axis(1)).enumerate() {
1210 if i < output.len() {
1211 let dot_product = input
1212 .iter()
1213 .zip(weight_col.iter())
1214 .map(|(&x, &w)| x * w)
1215 .sum::<f64>();
1216
1217 output[i] = dot_product;
1218 }
1219 }
1220
1221 if let Some(ref bias) = self.update_bias {
1223 for i in 0..output.len() {
1224 if i < bias.len() {
1225 output[i] += bias[i];
1226 }
1227 }
1228 }
1229
1230 output.mapv_inplace(|x| x.max(0.0));
1232
1233 output
1234 }
1235
    /// Converts a 2-D `f32` tensor into an `f64` ndarray matrix.
    ///
    /// NOTE(review): non-2-D tensors and `to_vec` failures silently yield a
    /// 1x1 zero matrix rather than an error — callers cannot distinguish
    /// that from real data; confirm this fallback is intended.
    fn tensor_to_array2(&self, tensor: &Tensor) -> Array2<f64> {
        match tensor.to_vec() {
            Ok(vec_data) => {
                let shape = tensor.shape();
                let dims = shape.dims();
                if dims.len() == 2 {
                    let rows = dims[0];
                    let cols = dims[1];
                    // Widen f32 -> f64 for the numeric kernels.
                    let data_f64: Vec<f64> = vec_data.iter().map(|&x| x as f64).collect();
                    Array2::from_shape_vec((rows, cols), data_f64)
                        .expect("failed to create Array2 from shape and data")
                } else {
                    Array2::zeros((1, 1))
                }
            }
            Err(_) => Array2::zeros((1, 1)),
        }
    }
1255
1256 fn array2_to_tensor(&self, array: &Array2<f64>) -> Tensor {
1257 let (rows, cols) = array.dim();
1258 let data_f32: Vec<f32> = array.iter().map(|&x| x as f32).collect();
1259
1260 torsh_tensor::creation::from_vec(
1261 data_f32,
1262 &[rows, cols],
1263 torsh_core::device::DeviceType::Cpu,
1264 )
1265 .expect("failed to create tensor from array data")
1266 }
1267
1268 fn extract_edge_indices(&self, edge_index: &Tensor) -> Vec<(usize, usize)> {
1269 match edge_index.to_vec() {
1270 Ok(vec_data) => {
1271 let shape = edge_index.shape();
1272 let dims = shape.dims();
1273 if dims.len() == 2 && dims[0] == 2 {
1274 let num_edges = dims[1];
1275 let mut edges = Vec::with_capacity(num_edges);
1276 for i in 0..num_edges {
1277 let src = vec_data[i] as usize;
1278 let dst = vec_data[num_edges + i] as usize;
1279 edges.push((src, dst));
1280 }
1281 edges
1282 } else {
1283 Vec::new()
1284 }
1285 }
1286 Err(_) => Vec::new(),
1287 }
1288 }
1289
1290 fn update_performance_cache(&mut self, num_nodes: usize, num_edges: usize) {
1292 let base_speedup = if num_nodes > self.simd_chunk_size {
1294 2.5 } else {
1296 1.5 };
1298
1299 self.performance_cache.simd_speedup_factor =
1300 base_speedup * (1.0 + (num_edges as f64 / num_nodes as f64).ln());
1301 }
1302}
1303
/// Construction-time options for [`AdvancedSIMDMPNN`].
#[derive(Debug, Clone)]
pub struct AdvancedMPNNConfig {
    /// Hidden width of the message MLP.
    pub hidden_dim: usize,
    /// Adds bias vectors to both MLPs.
    pub use_bias: bool,
    /// Allocates attention weights for attention-based aggregation.
    pub use_attention: bool,
    /// Head count used to size the attention weight matrix.
    pub num_attention_heads: usize,
    /// Edge-batch size for the chunked message path.
    pub simd_chunk_size: usize,
    /// Prefer the chunked path on graphs larger than `simd_chunk_size`.
    pub memory_efficient: bool,
    /// Aggregation strategy settings.
    pub aggregation_config: AdvancedAggregationConfig,
}
1315
impl Default for AdvancedMPNNConfig {
    /// Defaults: 128 hidden units, biases on, 4-head attention, 1024-edge
    /// chunks with memory-efficient processing enabled.
    fn default() -> Self {
        Self {
            hidden_dim: 128,
            use_bias: true,
            use_attention: true,
            num_attention_heads: 4,
            simd_chunk_size: 1024,
            memory_efficient: true,
            aggregation_config: AdvancedAggregationConfig::default(),
        }
    }
}
1329
impl Default for AdvancedAggregationConfig {
    /// Defaults: attention-first aggregation with a mean second stage,
    /// two hierarchy levels, unit temperature, dynamic routing on.
    fn default() -> Self {
        Self {
            primary_aggregation: AggregationType::Attention,
            secondary_aggregation: Some(AggregationType::Mean),
            hierarchical_levels: 2,
            attention_temperature: 1.0,
            dynamic_routing: true,
        }
    }
}
1341
1342impl PerformanceCache {
1343 fn new() -> Self {
1344 Self {
1345 adjacency_patterns: HashMap::new(),
1346 degree_stats: HashMap::new(),
1347 message_cache: HashMap::new(),
1348 simd_speedup_factor: 1.0,
1349 }
1350 }
1351}