pub mod equivariant;
pub mod gat;
pub mod gcn;
pub mod hgt;
pub mod kg_completion;
pub mod relation_message;
pub mod rgcn;
pub mod sage;
pub mod transformers;

pub use gat::{gat_forward, GraphAttentionLayer};
pub use gcn::{add_self_loops, gcn_forward, symmetric_normalize, CsrMatrix, Gcn, GcnLayer};
pub use sage::{sage_aggregate, sample_neighbors, GraphSage, GraphSageLayer, SageAggregation};

use std::collections::HashMap;

use scirs2_core::random::{Rng, RngExt};

use crate::base::{EdgeWeight, Graph, Node};
use crate::error::{GraphError, Result};
/// Neighbor-aggregation strategy used by message-passing layers.
///
/// Fieldless enum, so it derives `Copy`, `Eq`, and `Hash` for free — callers
/// can pass it by value and use it as a map key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum MessagePassing {
    /// Sum of neighbor features.
    Sum,
    /// Arithmetic mean of neighbor features (the default).
    #[default]
    Mean,
    /// Element-wise maximum over neighbor features.
    Max,
    /// Element-wise minimum over neighbor features.
    Min,
    /// Attention-weighted combination (layers without attention support
    /// fall back to `Mean`).
    Attention,
}
63
/// Message-passing abstraction shared by the GNN layers in this module.
///
/// A layer is split into two phases: `aggregate` pools features along the
/// edge list, and `update` maps the pooled features into the layer's output
/// space. The provided `forward` chains the two.
pub trait MessagePassingLayer {
    /// Pool features from neighbors.
    ///
    /// `adjacency` is an edge list of `(source, target, weight)` triples;
    /// messages flow from `source` into `target`'s aggregate. Returns one
    /// pooled feature vector per node.
    fn aggregate(
        &self,
        node_features: &[Vec<f64>],
        adjacency: &[(usize, usize, f64)],
        n_nodes: usize,
    ) -> Result<Vec<Vec<f64>>>;

    /// Map the aggregated features (and, optionally, the original features)
    /// to the layer's output representation.
    fn update(&self, aggregated: &[Vec<f64>], node_features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>>;

    /// Full layer pass: aggregate, then update. `n_nodes` is taken from
    /// `node_features.len()`, so the two always agree on this path.
    fn forward(
        &self,
        node_features: &[Vec<f64>],
        adjacency: &[(usize, usize, f64)],
    ) -> Result<Vec<Vec<f64>>> {
        let n = node_features.len();
        let aggregated = self.aggregate(node_features, adjacency, n)?;
        self.update(&aggregated, node_features)
    }
}
95
96fn validate_features(features: &[Vec<f64>]) -> Result<usize> {
101 if features.is_empty() {
102 return Ok(0);
103 }
104 let dim = features[0].len();
105 for (i, row) in features.iter().enumerate() {
106 if row.len() != dim {
107 return Err(GraphError::InvalidParameter {
108 param: "node_features".to_string(),
109 value: format!("row {} has {} dims, expected {}", i, row.len(), dim),
110 expected: format!("all rows must have {} dimensions", dim),
111 context: "GNN feature validation".to_string(),
112 });
113 }
114 }
115 Ok(dim)
116}
117
/// Rectified linear unit: passes positive inputs through, clamps the rest
/// (including NaN) to zero — matching `f64::max(x, 0.0)` semantics.
fn relu(x: f64) -> f64 {
    if x > 0.0 {
        x
    } else {
        0.0
    }
}
121
/// Inner product of two slices. `zip` truncates to the shorter slice, so
/// extra trailing elements in either argument contribute nothing.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (x, y) in a.iter().zip(b) {
        acc += x * y;
    }
    acc
}
125
/// Numerically stable softmax: subtracts the maximum before exponentiating
/// and floors the normalizer at 1e-10 to avoid division by zero. An empty
/// input yields an empty output.
fn softmax_vec(xs: &[f64]) -> Vec<f64> {
    if xs.is_empty() {
        return Vec::new();
    }
    let peak = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut out = Vec::with_capacity(xs.len());
    let mut total = 0.0f64;
    for &x in xs {
        let e = (x - peak).exp();
        total += e;
        out.push(e);
    }
    let denom = total.max(1e-10);
    for v in &mut out {
        *v /= denom;
    }
    out
}
135
/// Dense matrix-vector product: `out[i] = Σ_j w[i][j] * x[j]`.
/// Rows longer than `x` are truncated to `x`'s length (zip semantics).
fn matvec(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    w.iter()
        .map(|row| row.iter().zip(x.iter()).map(|(wi, xi)| wi * xi).sum())
        .collect()
}
139
140pub fn graph_to_adjacency<N, E, Ix>(graph: &Graph<N, E, Ix>) -> (Vec<N>, Vec<(usize, usize, f64)>)
142where
143 N: Node + Clone + std::fmt::Debug,
144 E: EdgeWeight + Clone + Into<f64>,
145 Ix: petgraph::graph::IndexType,
146{
147 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
148 let node_to_idx: HashMap<N, usize> = nodes
149 .iter()
150 .enumerate()
151 .map(|(i, n)| (n.clone(), i))
152 .collect();
153
154 let mut adjacency = Vec::new();
155 for edge in graph.edges() {
156 if let (Some(&si), Some(&ti)) =
157 (node_to_idx.get(&edge.source), node_to_idx.get(&edge.target))
158 {
159 let w: f64 = edge.weight.clone().into();
160 adjacency.push((si, ti, w));
161 adjacency.push((ti, si, w)); }
163 }
164
165 (nodes, adjacency)
166}
167
/// Single graph convolutional (GCN) layer: symmetric-normalized neighborhood
/// averaging followed by a linear transform.
#[derive(Debug, Clone)]
pub struct GCNLayer {
    /// Weight matrix, `out_dim` rows by `in_dim` columns (row-major).
    pub weights: Vec<Vec<f64>>,
    /// Per-output-channel bias, length `out_dim`.
    pub bias: Vec<f64>,
    /// Output feature dimension.
    pub out_dim: usize,
    /// Aggregation mode tag. NOTE(review): the GCN aggregate always uses the
    /// symmetric-normalized sum and never reads this field.
    pub aggregation: MessagePassing,
    /// Apply ReLU after the linear transform when true.
    pub use_activation: bool,
}
191
192impl GCNLayer {
193 pub fn new(in_dim: usize, out_dim: usize) -> Self {
195 let scale = (2.0 / (in_dim + out_dim) as f64).sqrt();
196 let mut weights = vec![vec![0.0f64; in_dim]; out_dim];
197 for (i, row) in weights.iter_mut().enumerate() {
198 for (j, w) in row.iter_mut().enumerate() {
199 *w = if i == j {
200 scale
201 } else {
202 scale * 0.01 * ((i as f64 - j as f64).sin())
203 };
204 }
205 }
206 GCNLayer {
207 weights,
208 bias: vec![0.0; out_dim],
209 out_dim,
210 aggregation: MessagePassing::Mean,
211 use_activation: true,
212 }
213 }
214
215 pub fn with_weights(mut self, weights: Vec<Vec<f64>>) -> Result<Self> {
217 if weights.len() != self.out_dim {
218 return Err(GraphError::InvalidParameter {
219 param: "weights".to_string(),
220 value: format!("rows={}", weights.len()),
221 expected: format!("rows={}", self.out_dim),
222 context: "GCNLayer::with_weights".to_string(),
223 });
224 }
225 self.weights = weights;
226 Ok(self)
227 }
228}
229
230impl MessagePassingLayer for GCNLayer {
231 fn aggregate(
232 &self,
233 node_features: &[Vec<f64>],
234 adjacency: &[(usize, usize, f64)],
235 n_nodes: usize,
236 ) -> Result<Vec<Vec<f64>>> {
237 let in_dim = validate_features(node_features)?;
238 if in_dim == 0 {
239 return Ok(Vec::new());
240 }
241
242 let mut deg = vec![1.0f64; n_nodes];
243 for &(src, dst, _) in adjacency {
244 deg[src] += 1.0;
245 let _ = dst;
246 }
247
248 let mut agg: Vec<Vec<f64>> = (0..n_nodes).map(|_| vec![0.0f64; in_dim]).collect();
249
250 for i in 0..n_nodes {
251 let d_inv = 1.0 / deg[i].sqrt();
252 for k in 0..in_dim {
253 agg[i][k] += d_inv * node_features[i][k] * d_inv;
254 }
255 }
256
257 for &(src, dst, w) in adjacency {
258 if src < n_nodes && dst < n_nodes {
259 let norm = w / (deg[src].sqrt() * deg[dst].sqrt());
260 for k in 0..in_dim {
261 agg[dst][k] += norm * node_features[src][k];
262 }
263 }
264 }
265
266 Ok(agg)
267 }
268
269 fn update(
270 &self,
271 aggregated: &[Vec<f64>],
272 _node_features: &[Vec<f64>],
273 ) -> Result<Vec<Vec<f64>>> {
274 let mut result = Vec::with_capacity(aggregated.len());
275 for agg in aggregated {
276 let mut h = matvec(&self.weights, agg);
277 for (hi, bi) in h.iter_mut().zip(self.bias.iter()) {
278 *hi += bi;
279 if self.use_activation {
280 *hi = relu(*hi);
281 }
282 }
283 result.push(h);
284 }
285 Ok(result)
286 }
287}
288
/// GraphSAGE layer: concatenates each node's own features with an aggregate
/// of its neighbors' features, then applies a linear transform followed by
/// L2 normalization.
#[derive(Debug, Clone)]
pub struct GraphSAGELayer {
    /// Weight matrix, `out_dim` rows by `2 * in_dim` columns — applied to
    /// the `[self, neighbor]` concatenation.
    pub weights: Vec<Vec<f64>>,
    /// Per-output-channel bias, length `out_dim`.
    pub bias: Vec<f64>,
    /// Output feature dimension.
    pub out_dim: usize,
    /// Neighbor pooling mode (Sum/Mean/Max/Min; Attention falls back to
    /// Mean in this layer's aggregate).
    pub aggregation: MessagePassing,
    /// Apply ReLU before the L2 normalization when true.
    pub use_activation: bool,
}
307
308impl GraphSAGELayer {
309 pub fn new(in_dim: usize, out_dim: usize) -> Self {
311 let concat_dim = 2 * in_dim;
312 let scale = (2.0 / (concat_dim + out_dim) as f64).sqrt();
313 let mut weights = vec![vec![0.0f64; concat_dim]; out_dim];
314 for (i, row) in weights.iter_mut().enumerate() {
315 for (j, w) in row.iter_mut().enumerate() {
316 *w = if i == j % out_dim {
317 scale
318 } else {
319 scale * 0.01 * ((i as f64 - j as f64).cos())
320 };
321 }
322 }
323 GraphSAGELayer {
324 weights,
325 bias: vec![0.0; out_dim],
326 out_dim,
327 aggregation: MessagePassing::Mean,
328 use_activation: true,
329 }
330 }
331}
332
impl MessagePassingLayer for GraphSAGELayer {
    /// Pool neighbor features per `self.aggregation`, then concatenate each
    /// node's own features in front of the pooled vector, so output rows
    /// have `2 * in_dim` entries. Edge weights are ignored here — only
    /// connectivity counts. Nodes with no in-edges get a zero aggregate.
    fn aggregate(
        &self,
        node_features: &[Vec<f64>],
        adjacency: &[(usize, usize, f64)],
        n_nodes: usize,
    ) -> Result<Vec<Vec<f64>>> {
        let in_dim = validate_features(node_features)?;
        if in_dim == 0 {
            return Ok(Vec::new());
        }

        // Running sum / count / max / min per target node — one pass over
        // the edge list answers every supported aggregation mode.
        let mut neighbor_sums: Vec<Vec<f64>> = (0..n_nodes).map(|_| vec![0.0f64; in_dim]).collect();
        let mut neighbor_counts: Vec<f64> = vec![0.0; n_nodes];
        let mut neighbor_max: Vec<Vec<f64>> = (0..n_nodes)
            .map(|_| vec![f64::NEG_INFINITY; in_dim])
            .collect();
        let mut neighbor_min: Vec<Vec<f64>> =
            (0..n_nodes).map(|_| vec![f64::INFINITY; in_dim]).collect();

        // Messages flow src -> dst. NOTE(review): assumes node_features has
        // at least n_nodes rows; the default `forward` guarantees this.
        for &(src, dst, _) in adjacency {
            if src < n_nodes && dst < n_nodes {
                neighbor_counts[dst] += 1.0;
                for k in 0..in_dim {
                    neighbor_sums[dst][k] += node_features[src][k];
                    if node_features[src][k] > neighbor_max[dst][k] {
                        neighbor_max[dst][k] = node_features[src][k];
                    }
                    if node_features[src][k] < neighbor_min[dst][k] {
                        neighbor_min[dst][k] = node_features[src][k];
                    }
                }
            }
        }

        // Resolve the chosen aggregation. Isolated nodes: Mean uses a count
        // floor of 1 (yielding zeros), Max/Min map their untouched infinity
        // sentinels back to 0.0.
        let agg_neighbor: Vec<Vec<f64>> = (0..n_nodes)
            .map(|i| {
                let count = neighbor_counts[i].max(1.0);
                match &self.aggregation {
                    MessagePassing::Sum => neighbor_sums[i].clone(),
                    MessagePassing::Mean => neighbor_sums[i].iter().map(|s| s / count).collect(),
                    MessagePassing::Max => neighbor_max[i]
                        .iter()
                        .map(|&v| if v == f64::NEG_INFINITY { 0.0 } else { v })
                        .collect(),
                    MessagePassing::Min => neighbor_min[i]
                        .iter()
                        .map(|&v| if v == f64::INFINITY { 0.0 } else { v })
                        .collect(),
                    // No attention mechanism in this layer: fall back to mean.
                    MessagePassing::Attention => {
                        neighbor_sums[i].iter().map(|s| s / count).collect()
                    }
                }
            })
            .collect();

        // [self, neighbor] concatenation, one row per node.
        let concat: Vec<Vec<f64>> = node_features
            .iter()
            .zip(agg_neighbor.iter())
            .map(|(self_feat, nbr)| {
                let mut cat = self_feat.clone();
                cat.extend_from_slice(nbr);
                cat
            })
            .collect();

        Ok(concat)
    }

    /// Linear transform plus bias (ReLU when `use_activation`), then L2
    /// normalization with the norm floored at 1e-10 to avoid dividing by 0.
    fn update(
        &self,
        aggregated: &[Vec<f64>],
        _node_features: &[Vec<f64>],
    ) -> Result<Vec<Vec<f64>>> {
        let mut result = Vec::with_capacity(aggregated.len());
        for agg in aggregated {
            let mut h = matvec(&self.weights, agg);
            for (hi, bi) in h.iter_mut().zip(self.bias.iter()) {
                *hi += bi;
                if self.use_activation {
                    *hi = relu(*hi);
                }
            }
            let norm: f64 = h.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-10);
            h.iter_mut().for_each(|x| *x /= norm);
            result.push(h);
        }
        Ok(result)
    }
}
423
/// Graph attention (GAT) layer: a shared linear transform followed by
/// additive attention over each node's in-neighborhood (self-loop included).
#[derive(Debug, Clone)]
pub struct GATLayer {
    /// Shared linear transform, `out_dim` rows by `in_dim` columns.
    pub weights: Vec<Vec<f64>>,
    /// Attention vector of length `2 * out_dim`, dotted with the
    /// concatenation `[W h_i, W h_j]` to score each edge.
    pub attention_weights: Vec<f64>,
    /// Output feature dimension.
    pub out_dim: usize,
    /// LeakyReLU slope applied to negative attention logits.
    pub negative_slope: f64,
    /// Apply ELU to the aggregated output when true.
    pub use_activation: bool,
}
442
443impl GATLayer {
444 pub fn new(in_dim: usize, out_dim: usize) -> Self {
446 let scale = (2.0 / (in_dim + out_dim) as f64).sqrt();
447 let mut weights = vec![vec![0.0f64; in_dim]; out_dim];
448 for (i, row) in weights.iter_mut().enumerate() {
449 for (j, w) in row.iter_mut().enumerate() {
450 *w = if i == j { scale } else { scale * 0.01 };
451 }
452 }
453 let attention_weights: Vec<f64> = (0..2 * out_dim)
454 .map(|i| if i % 2 == 0 { 0.5 } else { -0.5 })
455 .collect();
456 GATLayer {
457 weights,
458 attention_weights,
459 out_dim,
460 negative_slope: 0.2,
461 use_activation: true,
462 }
463 }
464
465 fn leaky_relu(&self, x: f64) -> f64 {
466 if x >= 0.0 {
467 x
468 } else {
469 self.negative_slope * x
470 }
471 }
472}
473
474impl MessagePassingLayer for GATLayer {
475 fn aggregate(
476 &self,
477 node_features: &[Vec<f64>],
478 adjacency: &[(usize, usize, f64)],
479 n_nodes: usize,
480 ) -> Result<Vec<Vec<f64>>> {
481 let _in_dim = validate_features(node_features)?;
482
483 let transformed: Vec<Vec<f64>> = node_features
484 .iter()
485 .map(|h| matvec(&self.weights, h))
486 .collect();
487
488 let mut neighbors: Vec<Vec<usize>> = vec![Vec::new(); n_nodes];
489 for &(src, dst, _) in adjacency {
490 if src < n_nodes && dst < n_nodes {
491 neighbors[dst].push(src);
492 }
493 }
494 for i in 0..n_nodes {
495 if !neighbors[i].contains(&i) {
496 neighbors[i].push(i);
497 }
498 }
499
500 let mut aggregated: Vec<Vec<f64>> = vec![vec![0.0; self.out_dim]; n_nodes];
501
502 for i in 0..n_nodes {
503 let nbrs = &neighbors[i];
504 if nbrs.is_empty() {
505 continue;
506 }
507
508 let scores: Vec<f64> = nbrs
509 .iter()
510 .map(|&j| {
511 let mut concat = transformed[i].clone();
512 concat.extend_from_slice(&transformed[j]);
513 let e = dot(&self.attention_weights, &concat);
514 self.leaky_relu(e)
515 })
516 .collect();
517
518 let alphas = softmax_vec(&scores);
519
520 for (k, &j) in nbrs.iter().enumerate() {
521 let alpha = alphas[k];
522 for d in 0..self.out_dim {
523 aggregated[i][d] += alpha * transformed[j][d];
524 }
525 }
526 }
527
528 Ok(aggregated)
529 }
530
531 fn update(
532 &self,
533 aggregated: &[Vec<f64>],
534 _node_features: &[Vec<f64>],
535 ) -> Result<Vec<Vec<f64>>> {
536 if !self.use_activation {
537 return Ok(aggregated.to_vec());
538 }
539 let result: Vec<Vec<f64>> = aggregated
540 .iter()
541 .map(|row| {
542 row.iter()
543 .map(|&x| if x >= 0.0 { x } else { x.exp() - 1.0 })
544 .collect()
545 })
546 .collect();
547 Ok(result)
548 }
549}
550
/// Dense per-node embedding matrix produced and consumed by the GNN layers.
#[derive(Debug, Clone)]
pub struct NodeEmbeddings {
    /// Display name per node (defaults to the node's index as a string).
    pub node_names: Vec<String>,
    /// One embedding row per node; all rows share length `dim`.
    pub embeddings: Vec<Vec<f64>>,
    /// Embedding dimensionality.
    pub dim: usize,
}
565
impl NodeEmbeddings {
    /// Wrap an existing feature matrix; node names default to row indices.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` (via `validate_features`) when
    /// the rows have unequal lengths.
    pub fn new(embeddings: Vec<Vec<f64>>) -> Result<Self> {
        let dim = validate_features(&embeddings)?;
        let n = embeddings.len();
        Ok(NodeEmbeddings {
            node_names: (0..n).map(|i| i.to_string()).collect(),
            embeddings,
            dim,
        })
    }

    /// Random embeddings scaled as `random * 2 - 1`.
    /// NOTE(review): assumes `rng.random::<f64>()` samples [0, 1), which
    /// would put values in [-1, 1) — confirm against the scirs2_core RNG
    /// contract.
    pub fn random(n_nodes: usize, dim: usize) -> Self {
        let mut rng = scirs2_core::random::rng();
        let embeddings: Vec<Vec<f64>> = (0..n_nodes)
            .map(|_| (0..dim).map(|_| rng.random::<f64>() * 2.0 - 1.0).collect())
            .collect();
        NodeEmbeddings {
            node_names: (0..n_nodes).map(|i| i.to_string()).collect(),
            embeddings,
            dim,
        }
    }

    /// Identity-matrix embeddings: node `i` gets a one-hot vector of length
    /// `n_nodes`. Used as the default input when no features are supplied.
    pub fn one_hot(n_nodes: usize) -> Self {
        let embeddings: Vec<Vec<f64>> = (0..n_nodes)
            .map(|i| {
                let mut v = vec![0.0f64; n_nodes];
                v[i] = 1.0;
                v
            })
            .collect();
        NodeEmbeddings {
            node_names: (0..n_nodes).map(|i| i.to_string()).collect(),
            embeddings,
            dim: n_nodes,
        }
    }

    /// Number of embedded nodes (rows).
    pub fn n_nodes(&self) -> usize {
        self.embeddings.len()
    }

    /// Borrow node `i`'s embedding, or `None` when `i` is out of range.
    pub fn get(&self, i: usize) -> Option<&Vec<f64>> {
        self.embeddings.get(i)
    }

    /// Push these embeddings through one message-passing layer, producing a
    /// new `NodeEmbeddings` with the layer's output dimension. Node names
    /// are carried over unchanged.
    pub fn apply_layer<L: MessagePassingLayer>(
        &self,
        layer: &L,
        adjacency: &[(usize, usize, f64)],
    ) -> Result<NodeEmbeddings> {
        let new_embeddings = layer.forward(&self.embeddings, adjacency)?;
        let dim = validate_features(&new_embeddings)?;
        Ok(NodeEmbeddings {
            node_names: self.node_names.clone(),
            embeddings: new_embeddings,
            dim,
        })
    }
}
632
633pub fn run_gnn_pipeline<N, E, Ix, L>(
635 graph: &Graph<N, E, Ix>,
636 initial_features: Option<NodeEmbeddings>,
637 layers: &[L],
638) -> Result<NodeEmbeddings>
639where
640 N: Node + Clone + std::fmt::Debug,
641 E: EdgeWeight + Clone + Into<f64>,
642 Ix: petgraph::graph::IndexType,
643 L: MessagePassingLayer,
644{
645 let (_, adjacency) = graph_to_adjacency(graph);
646 let n = graph.nodes().len();
647
648 let mut embeddings = match initial_features {
649 Some(e) => e,
650 None => NodeEmbeddings::one_hot(n),
651 };
652
653 for layer in layers {
654 embeddings = embeddings.apply_layer(layer, &adjacency)?;
655 }
656
657 Ok(embeddings)
658}
659
#[cfg(test)]
mod tests {
    use super::*;
    use crate::base::Graph;

    /// A 3-node triangle graph paired with its symmetric edge list.
    type TriangleGraph = (Graph<usize, f64>, Vec<(usize, usize, f64)>);

    // Triangle 0-1-2 with unit weights; adjacency contains both directions
    // of each edge (see `graph_to_adjacency`).
    fn make_triangle_graph() -> TriangleGraph {
        let mut g: Graph<usize, f64> = Graph::new();
        let _ = g.add_edge(0, 1, 1.0);
        let _ = g.add_edge(1, 2, 1.0);
        let _ = g.add_edge(0, 2, 1.0);
        let (_, adj) = graph_to_adjacency(&g);
        (g, adj)
    }

    // Deterministic n x dim feature matrix: entry (i, j) = (i*dim + j) / 10.
    fn make_features(n: usize, dim: usize) -> Vec<Vec<f64>> {
        (0..n)
            .map(|i| (0..dim).map(|j| (i * dim + j) as f64 / 10.0).collect())
            .collect()
    }

    // GCN forward on 3 nodes maps 4-dim features to 8-dim outputs.
    #[test]
    fn test_gcn_layer_output_shape() {
        let (_, adj) = make_triangle_graph();
        let features = make_features(3, 4);
        let layer = GCNLayer::new(4, 8);
        let out = layer.forward(&features, &adj).expect("GCN forward failed");
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].len(), 8);
    }

    // GraphSAGE forward on 3 nodes maps 4-dim features to 6-dim outputs.
    #[test]
    fn test_graphsage_layer_output_shape() {
        let (_, adj) = make_triangle_graph();
        let features = make_features(3, 4);
        let layer = GraphSAGELayer::new(4, 6);
        let out = layer.forward(&features, &adj).expect("SAGE forward failed");
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].len(), 6);
    }

    // GAT forward on 3 nodes maps 4-dim features to 8-dim outputs.
    #[test]
    fn test_gat_layer_output_shape() {
        let (_, adj) = make_triangle_graph();
        let features = make_features(3, 4);
        let layer = GATLayer::new(4, 8);
        let out = layer.forward(&features, &adj).expect("GAT forward failed");
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].len(), 8);
    }

    // One-hot embeddings form the identity matrix.
    #[test]
    fn test_node_embeddings_one_hot() {
        let emb = NodeEmbeddings::one_hot(3);
        assert_eq!(emb.n_nodes(), 3);
        assert_eq!(emb.dim, 3);
        let row0 = emb.get(0).expect("No embedding for node 0");
        assert!((row0[0] - 1.0).abs() < 1e-10);
        assert!((row0[1]).abs() < 1e-10);
    }

    // Two stacked GCN layers over a 4-node path preserve the node count.
    #[test]
    fn test_run_gnn_pipeline() {
        let mut g: Graph<usize, f64> = Graph::new();
        let _ = g.add_edge(0, 1, 1.0);
        let _ = g.add_edge(1, 2, 1.0);
        let _ = g.add_edge(2, 3, 1.0);
        let layers = vec![GCNLayer::new(4, 4), GCNLayer::new(4, 4)];
        let features = NodeEmbeddings::new(make_features(4, 4)).expect("Features");
        let result = run_gnn_pipeline(&g, Some(features), &layers).expect("Pipeline");
        assert_eq!(result.n_nodes(), 4);
    }
}