1use scirs2_core::ndarray::{Array1, Array2};
18use scirs2_core::random::{Rng, RngExt};
19
20use crate::error::{GraphError, Result};
21use crate::gnn::gcn::CsrMatrix;
22
/// Strategy used by [`sage_aggregate`] to combine the feature rows of a
/// node's neighbours into a single vector.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum SageAggregation {
    /// Element-wise mean over the neighbour rows (the default).
    #[default]
    Mean,
    /// Element-wise maximum over the neighbour rows.
    Max,
    /// Element-wise sum over the neighbour rows.
    Sum,
    /// Gated recurrent pass over the neighbours in stored order
    /// (a lightweight stand-in for an LSTM aggregator).
    Lstm,
}
40
41pub fn sample_neighbors(adj: &CsrMatrix, k: usize) -> Vec<Vec<usize>> {
58 let n = adj.n_rows;
59 let mut rng = scirs2_core::random::rng();
60
61 (0..n)
62 .map(|i| {
63 let start = adj.indptr[i];
64 let end = adj.indptr[i + 1];
65 let neighbors: Vec<usize> = adj.indices[start..end].to_vec();
66 if neighbors.len() <= k {
67 neighbors
68 } else {
69 let mut reservoir: Vec<usize> = neighbors[..k].to_vec();
71 for idx in k..neighbors.len() {
72 let j = (rng.random::<f64>() * (idx + 1) as f64) as usize;
73 if j < k {
74 reservoir[j] = neighbors[idx];
75 }
76 }
77 reservoir
78 }
79 })
80 .collect()
81}
82
83pub fn sage_aggregate(
98 adj: &CsrMatrix,
99 features: &Array2<f64>,
100 aggr_type: &SageAggregation,
101) -> Result<Array2<f64>> {
102 let n = adj.n_rows;
103 let (feat_n, feat_dim) = features.dim();
104
105 if feat_n != n {
106 return Err(GraphError::InvalidParameter {
107 param: "features".to_string(),
108 value: format!("{feat_n} rows"),
109 expected: format!("{n} rows (matching adj.n_rows)"),
110 context: "sage_aggregate".to_string(),
111 });
112 }
113
114 let mut agg = Array2::<f64>::zeros((n, feat_dim));
115
116 match aggr_type {
117 SageAggregation::Mean | SageAggregation::Sum => {
118 let mut counts = vec![0usize; n];
119 for (row, col, _) in adj.iter() {
120 if col < feat_n {
121 counts[row] += 1;
122 for k in 0..feat_dim {
123 agg[[row, k]] += features[[col, k]];
124 }
125 }
126 }
127 if *aggr_type == SageAggregation::Mean {
128 for i in 0..n {
129 if counts[i] > 0 {
130 let inv = 1.0 / counts[i] as f64;
131 for k in 0..feat_dim {
132 agg[[i, k]] *= inv;
133 }
134 }
135 }
136 }
137 }
138
139 SageAggregation::Max => {
140 let mut initialized = vec![false; n];
142 for (row, col, _) in adj.iter() {
143 if col < feat_n {
144 if !initialized[row] {
145 for k in 0..feat_dim {
146 agg[[row, k]] = features[[col, k]];
147 }
148 initialized[row] = true;
149 } else {
150 for k in 0..feat_dim {
151 if features[[col, k]] > agg[[row, k]] {
152 agg[[row, k]] = features[[col, k]];
153 }
154 }
155 }
156 }
157 }
158 }
160
161 SageAggregation::Lstm => {
162 for i in 0..n {
168 let start = adj.indptr[i];
169 let end = adj.indptr[i + 1];
170 let neighbor_indices = &adj.indices[start..end];
171
172 if neighbor_indices.is_empty() {
173 continue;
174 }
175
176 let mut h = vec![0.0f64; feat_dim];
177 for &nb in neighbor_indices {
178 if nb < feat_n {
179 for k in 0..feat_dim {
180 let x = features[[nb, k]];
181 let z = 1.0 / (1.0 + (-(x + h[k])).exp());
183 h[k] = z * h[k] + (1.0 - z) * x;
184 }
185 }
186 }
187 for k in 0..feat_dim {
188 agg[[i, k]] = h[k];
189 }
190 }
191 }
192 }
193
194 Ok(agg)
195}
196
197#[derive(Debug, Clone)]
210pub struct GraphSageLayer {
211 pub weights: Array2<f64>,
213 pub bias: Option<Array1<f64>>,
215 pub in_dim: usize,
217 pub out_dim: usize,
219 pub aggregation: SageAggregation,
221 pub use_relu: bool,
223 pub normalize: bool,
225}
226
227impl GraphSageLayer {
228 pub fn new(in_dim: usize, out_dim: usize) -> Self {
234 let concat_dim = 2 * in_dim;
235 let scale = (6.0_f64 / (concat_dim + out_dim) as f64).sqrt();
236 let mut rng = scirs2_core::random::rng();
237 let weights = Array2::from_shape_fn((concat_dim, out_dim), |_| {
238 rng.random::<f64>() * 2.0 * scale - scale
239 });
240 GraphSageLayer {
241 weights,
242 bias: None,
243 in_dim,
244 out_dim,
245 aggregation: SageAggregation::Mean,
246 use_relu: true,
247 normalize: true,
248 }
249 }
250
251 pub fn with_aggregation(mut self, aggr: SageAggregation) -> Self {
253 self.aggregation = aggr;
254 self
255 }
256
257 pub fn without_normalize(mut self) -> Self {
259 self.normalize = false;
260 self
261 }
262
263 pub fn without_activation(mut self) -> Self {
265 self.use_relu = false;
266 self
267 }
268
269 pub fn forward(&self, adj: &CsrMatrix, features: &Array2<f64>) -> Result<Array2<f64>> {
275 let n = adj.n_rows;
276 let (feat_n, feat_dim) = features.dim();
277
278 if feat_n != n {
279 return Err(GraphError::InvalidParameter {
280 param: "features".to_string(),
281 value: format!("{feat_n}"),
282 expected: format!("{n}"),
283 context: "GraphSageLayer::forward".to_string(),
284 });
285 }
286 if feat_dim != self.in_dim {
287 return Err(GraphError::InvalidParameter {
288 param: "features feat_dim".to_string(),
289 value: format!("{feat_dim}"),
290 expected: format!("{}", self.in_dim),
291 context: "GraphSageLayer::forward".to_string(),
292 });
293 }
294
295 let agg = sage_aggregate(adj, features, &self.aggregation)?;
297
298 let concat_dim = 2 * self.in_dim;
300 let mut concat = Array2::<f64>::zeros((n, concat_dim));
301 for i in 0..n {
302 for k in 0..feat_dim {
303 concat[[i, k]] = features[[i, k]];
304 concat[[i, feat_dim + k]] = agg[[i, k]];
305 }
306 }
307
308 let (_, out_dim) = self.weights.dim();
310 let mut output = Array2::<f64>::zeros((n, out_dim));
311 for i in 0..n {
312 for j in 0..out_dim {
313 let mut sum = 0.0;
314 for k in 0..concat_dim {
315 sum += concat[[i, k]] * self.weights[[k, j]];
316 }
317 output[[i, j]] = sum;
318 }
319 }
320
321 if let Some(ref b) = self.bias {
323 for i in 0..n {
324 for j in 0..out_dim {
325 output[[i, j]] += b[j];
326 }
327 }
328 }
329
330 if self.use_relu {
332 output.mapv_inplace(|x| x.max(0.0));
333 }
334
335 if self.normalize {
337 for i in 0..n {
338 let norm = {
339 let row = output.row(i);
340 row.iter().map(|&x| x * x).sum::<f64>().sqrt()
341 };
342 if norm > 1e-10 {
343 for j in 0..out_dim {
344 output[[i, j]] /= norm;
345 }
346 }
347 }
348 }
349
350 Ok(output)
351 }
352}
353
354pub struct GraphSage {
362 pub layers: Vec<GraphSageLayer>,
364 pub neighbor_samples: Vec<Option<usize>>,
366}
367
368impl GraphSage {
369 pub fn new(dims: &[usize], aggr: SageAggregation) -> Result<Self> {
375 if dims.len() < 2 {
376 return Err(GraphError::InvalidParameter {
377 param: "dims".to_string(),
378 value: format!("len={}", dims.len()),
379 expected: "at least 2 elements".to_string(),
380 context: "GraphSage::new".to_string(),
381 });
382 }
383 let mut layers = Vec::with_capacity(dims.len() - 1);
384 for i in 0..(dims.len() - 1) {
385 let is_last = i == dims.len() - 2;
386 let mut layer =
387 GraphSageLayer::new(dims[i], dims[i + 1]).with_aggregation(aggr.clone());
388 if is_last {
389 layer = layer.without_activation();
390 }
391 layers.push(layer);
392 }
393 let neighbor_samples = vec![None; dims.len() - 1];
394 Ok(GraphSage {
395 layers,
396 neighbor_samples,
397 })
398 }
399
400 pub fn with_neighbor_samples(mut self, sizes: Vec<Option<usize>>) -> Result<Self> {
405 if sizes.len() != self.layers.len() {
406 return Err(GraphError::InvalidParameter {
407 param: "sizes".to_string(),
408 value: format!("len={}", sizes.len()),
409 expected: format!("len={}", self.layers.len()),
410 context: "GraphSage::with_neighbor_samples".to_string(),
411 });
412 }
413 self.neighbor_samples = sizes;
414 Ok(self)
415 }
416
417 pub fn forward(&self, adj: &CsrMatrix, features: &Array2<f64>) -> Result<Array2<f64>> {
423 let mut h = features.clone();
424 for (i, layer) in self.layers.iter().enumerate() {
425 let sampled_adj = if let Some(k) = self.neighbor_samples[i] {
427 let sampled = sample_neighbors(adj, k);
429 let mut coo = Vec::new();
430 for (node_i, nbrs) in sampled.iter().enumerate() {
431 for &nb in nbrs {
432 coo.push((node_i, nb, 1.0f64));
433 }
434 }
435 CsrMatrix::from_coo(adj.n_rows, adj.n_cols, &coo)?
436 } else {
437 adj.clone()
438 };
439 h = layer.forward(&sampled_adj, &h)?;
440 }
441 Ok(h)
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Undirected path graph 0 - 1 - ... - (n-1) as a symmetric CSR matrix.
    fn path_csr(n: usize) -> CsrMatrix {
        let mut coo = Vec::new();
        for i in 0..(n - 1) {
            coo.push((i, i + 1, 1.0));
            coo.push((i + 1, i, 1.0));
        }
        CsrMatrix::from_coo(n, n, &coo).expect("path CSR")
    }

    /// Deterministic features whose entries strictly increase with node index.
    fn features(n: usize, d: usize) -> Array2<f64> {
        Array2::from_shape_fn((n, d), |(i, j)| (i * d + j) as f64 * 0.1)
    }

    /// Asserts two floats agree within a tight absolute tolerance.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "actual={actual}, expected={expected}"
        );
    }

    #[test]
    fn test_mean_aggregate_shape() {
        let adj = path_csr(4);
        let feats = features(4, 6);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Mean).expect("mean agg");
        assert_eq!(agg.dim(), (4, 6));
    }

    #[test]
    fn test_max_aggregate_shape() {
        let adj = path_csr(4);
        let feats = features(4, 6);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Max).expect("max agg");
        assert_eq!(agg.dim(), (4, 6));
    }

    #[test]
    fn test_sum_aggregate_shape() {
        let adj = path_csr(4);
        let feats = features(4, 6);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Sum).expect("sum agg");
        assert_eq!(agg.dim(), (4, 6));
    }

    #[test]
    fn test_lstm_aggregate_shape() {
        let adj = path_csr(4);
        let feats = features(4, 6);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Lstm).expect("lstm agg");
        assert_eq!(agg.dim(), (4, 6));
    }

    #[test]
    fn test_mean_aggregate_values() {
        let adj = path_csr(4);
        let feats = features(4, 3);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Mean).expect("mean agg");
        for k in 0..3 {
            // Node 0 has the single neighbour 1.
            assert_close(agg[[0, k]], feats[[1, k]]);
            // Node 1 averages neighbours 0 and 2.
            assert_close(agg[[1, k]], (feats[[0, k]] + feats[[2, k]]) / 2.0);
        }
    }

    #[test]
    fn test_sum_aggregate_values() {
        let adj = path_csr(4);
        let feats = features(4, 3);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Sum).expect("sum agg");
        for k in 0..3 {
            assert_close(agg[[1, k]], feats[[0, k]] + feats[[2, k]]);
            assert_close(agg[[3, k]], feats[[2, k]]);
        }
    }

    #[test]
    fn test_max_aggregate_values() {
        let adj = path_csr(4);
        let feats = features(4, 3);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Max).expect("max agg");
        for k in 0..3 {
            // Features increase with node index, so the higher neighbour wins.
            assert_eq!(agg[[1, k]], feats[[2, k]]);
            assert_eq!(agg[[0, k]], feats[[1, k]]);
        }
    }

    #[test]
    fn test_max_aggregate_keeps_negative_values() {
        let adj = path_csr(2);
        let feats = Array2::from_shape_fn((2, 2), |(i, j)| -1.0 - (i * 2 + j) as f64);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Max).expect("max agg");
        for k in 0..2 {
            assert_eq!(agg[[0, k]], feats[[1, k]]);
            assert_eq!(agg[[1, k]], feats[[0, k]]);
        }
    }

    #[test]
    fn test_lstm_aggregate_single_neighbor_value() {
        let adj = path_csr(4);
        let feats = features(4, 3);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Lstm).expect("lstm agg");
        for k in 0..3 {
            // With h = 0 and one neighbour x: h' = (1 - sigmoid(x)) * x.
            let x = feats[[1, k]];
            let z = 1.0 / (1.0 + (-x).exp());
            assert_close(agg[[0, k]], (1.0 - z) * x);
        }
    }

    #[test]
    fn test_aggregate_rejects_row_mismatch() {
        let adj = path_csr(4);
        let feats = features(3, 2);
        assert!(sage_aggregate(&adj, &feats, &SageAggregation::Mean).is_err());
    }

    #[test]
    fn test_sage_layer_output_shape() {
        let adj = path_csr(5);
        let feats = features(5, 4);
        let layer = GraphSageLayer::new(4, 8);
        let out = layer.forward(&adj, &feats).expect("sage forward");
        assert_eq!(out.dim(), (5, 8));
    }

    #[test]
    fn test_sage_layer_weight_shape_and_defaults() {
        let layer = GraphSageLayer::new(4, 8);
        assert_eq!(layer.weights.dim(), (8, 8));
        assert!(layer.bias.is_none());
        assert_eq!(layer.aggregation, SageAggregation::Mean);
        assert!(layer.use_relu);
        assert!(layer.normalize);
    }

    #[test]
    fn test_sage_layer_rejects_wrong_feature_dim() {
        let adj = path_csr(5);
        let feats = features(5, 3);
        let layer = GraphSageLayer::new(4, 8);
        assert!(layer.forward(&adj, &feats).is_err());
    }

    #[test]
    fn test_sage_layer_l2_normalization() {
        let adj = path_csr(5);
        let feats = features(5, 4);
        let layer = GraphSageLayer::new(4, 8);
        let out = layer.forward(&adj, &feats).expect("sage forward");
        for (i, row) in out.rows().into_iter().enumerate() {
            let norm: f64 = row.iter().map(|&x| x * x).sum::<f64>().sqrt();
            assert!(
                norm < 1e-10 || (norm - 1.0).abs() < 1e-9,
                "norm={norm} for row {i}"
            );
        }
    }

    #[test]
    fn test_sage_layer_linear_matches_manual_concat() {
        let adj = path_csr(3);
        let feats = features(3, 2);
        let layer = GraphSageLayer::new(2, 3)
            .without_activation()
            .without_normalize();
        let out = layer.forward(&adj, &feats).expect("sage forward");
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Mean).expect("mean agg");
        for i in 0..3 {
            for j in 0..3 {
                // Layout is [self features | aggregated features].
                let mut expected = 0.0;
                for k in 0..2 {
                    expected += feats[[i, k]] * layer.weights[[k, j]];
                    expected += agg[[i, k]] * layer.weights[[2 + k, j]];
                }
                assert_close(out[[i, j]], expected);
            }
        }
    }

    #[test]
    fn test_sage_layer_applies_bias() {
        let adj = path_csr(3);
        let feats = features(3, 2);
        let mut layer = GraphSageLayer::new(2, 3)
            .without_activation()
            .without_normalize();
        let base = layer.forward(&adj, &feats).expect("forward without bias");
        let bias = vec![1.0, -2.0, 0.5];
        layer.bias = Some(Array1::from(bias.clone()));
        let biased = layer.forward(&adj, &feats).expect("forward with bias");
        for i in 0..3 {
            for (j, &b) in bias.iter().enumerate() {
                assert_close(biased[[i, j]] - base[[i, j]], b);
            }
        }
    }

    #[test]
    fn test_graphsage_multilayer() {
        let adj = path_csr(6);
        let feats = features(6, 8);
        let model = GraphSage::new(&[8, 16, 4], SageAggregation::Mean).expect("sage model");
        let out = model.forward(&adj, &feats).expect("forward");
        assert_eq!(out.dim(), (6, 4));
    }

    #[test]
    fn test_graphsage_builder_wiring() {
        let model = GraphSage::new(&[4, 16, 2], SageAggregation::Max).expect("sage model");
        assert_eq!(model.layers.len(), 2);
        assert_eq!(model.layers[0].weights.dim(), (8, 16));
        assert_eq!(model.layers[1].weights.dim(), (32, 2));
        // Only the last layer drops its activation.
        assert!(model.layers[0].use_relu);
        assert!(!model.layers[1].use_relu);
        assert!(model
            .layers
            .iter()
            .all(|l| l.aggregation == SageAggregation::Max));
        assert_eq!(model.neighbor_samples, vec![None, None]);
    }

    #[test]
    fn test_graphsage_rejects_bad_dims_and_sample_lengths() {
        assert!(GraphSage::new(&[4], SageAggregation::Mean).is_err());
        let model = GraphSage::new(&[4, 8, 2], SageAggregation::Mean).expect("sage model");
        assert!(model.with_neighbor_samples(vec![Some(1)]).is_err());
    }

    #[test]
    fn test_neighbor_sampling() {
        let adj = path_csr(4);
        let sampled = sample_neighbors(&adj, 1);
        assert_eq!(sampled.len(), 4);
        assert!(sampled[1].len() <= 1);
        assert!(sampled[2].len() <= 1);
    }

    #[test]
    fn test_neighbor_sampling_returns_true_neighbors() {
        let adj = path_csr(5);
        let sampled = sample_neighbors(&adj, 1);
        // End nodes have exactly one neighbour, which is kept as-is.
        assert_eq!(sampled[0], vec![1]);
        assert_eq!(sampled[4], vec![3]);
        for (i, nbrs) in sampled.iter().enumerate().take(4).skip(1) {
            assert_eq!(nbrs.len(), 1, "node {i}");
            let nb = nbrs[0];
            assert!(nb == i - 1 || nb == i + 1, "node {i} sampled {nb}");
        }
    }

    #[test]
    fn test_neighbor_sampling_keeps_small_neighborhoods() {
        let adj = path_csr(5);
        let sampled = sample_neighbors(&adj, 10);
        let mut middle = sampled[2].clone();
        middle.sort_unstable();
        assert_eq!(middle, vec![1, 3]);
    }

    #[test]
    fn test_graphsage_with_sampling() {
        let adj = path_csr(6);
        let feats = features(6, 4);
        let model = GraphSage::new(&[4, 8, 4], SageAggregation::Mean)
            .expect("sage model")
            .with_neighbor_samples(vec![Some(2), Some(2)])
            .expect("samples");
        let out = model.forward(&adj, &feats).expect("forward");
        assert_eq!(out.dim(), (6, 4));
    }

    #[test]
    fn test_sage_aggregation_isolated_node() {
        // Node 0 has no edges; nodes 1 and 2 are connected.
        let coo = vec![(1, 2, 1.0), (2, 1, 1.0)];
        let adj = CsrMatrix::from_coo(3, 3, &coo).expect("isolated CSR");
        let feats = features(3, 4);
        for aggr in [
            SageAggregation::Mean,
            SageAggregation::Max,
            SageAggregation::Sum,
            SageAggregation::Lstm,
        ] {
            let agg = sage_aggregate(&adj, &feats, &aggr).expect("aggregate");
            for k in 0..4 {
                assert_eq!(agg[[0, k]], 0.0, "{aggr:?} column {k}");
            }
        }
    }
}