1use scirs2_core::ndarray::{Array1, Array2, Axis};
23use scirs2_core::random::{Rng, RngExt, SeedableRng};
24
/// Hyper-parameters for GraphCL-style graph contrastive learning.
///
/// The struct only carries values; the field names mirror the parameters of
/// the augmentation and loss functions in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphClConfig {
    /// Softmax temperature, intended as the `temperature` argument of
    /// [`nt_xent_loss`] (which requires it to be strictly positive).
    pub temperature: f64,
    /// Output dimension of the projection head.
    pub proj_dim: usize,
    /// Per-entry feature masking probability, intended for [`augment_features`].
    pub mask_feature_rate: f64,
    /// Probability of dropping an existing edge, intended for [`augment_edges`].
    pub drop_edge_rate: f64,
    /// Probability of adding a missing edge, intended for [`augment_edges`].
    pub add_edge_rate: f64,
}
43
44impl Default for GraphClConfig {
45 fn default() -> Self {
46 Self {
47 temperature: 0.5,
48 proj_dim: 128,
49 mask_feature_rate: 0.1,
50 drop_edge_rate: 0.1,
51 add_edge_rate: 0.0,
52 }
53 }
54}
55
56pub fn augment_features(features: &Array2<f64>, mask_rate: f64, seed: u64) -> Array2<f64> {
71 if mask_rate <= 0.0 {
72 return features.clone();
73 }
74 if mask_rate >= 1.0 {
75 return Array2::zeros(features.dim());
76 }
77
78 let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(seed);
79 let mut out = features.clone();
80 let (n_nodes, feat_dim) = features.dim();
81
82 for i in 0..n_nodes {
83 for j in 0..feat_dim {
84 if rng.random::<f64>() < mask_rate {
85 out[[i, j]] = 0.0;
86 }
87 }
88 }
89 out
90}
91
92pub fn augment_edges(adj: &Array2<f64>, drop_rate: f64, add_rate: f64, seed: u64) -> Array2<f64> {
105 let n = adj.dim().0;
106 let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(seed);
107 let mut out = adj.clone();
108
109 for i in 0..n {
110 for j in (i + 1)..n {
111 if adj[[i, j]] > 0.0 {
112 if drop_rate > 0.0 && rng.random::<f64>() < drop_rate {
114 out[[i, j]] = 0.0;
115 out[[j, i]] = 0.0;
116 }
117 } else {
118 if add_rate > 0.0 && rng.random::<f64>() < add_rate {
120 out[[i, j]] = 1.0;
121 out[[j, i]] = 1.0;
122 }
123 }
124 }
125 }
126 out
127}
128
129pub fn nt_xent_loss(z1: &Array2<f64>, z2: &Array2<f64>, temperature: f64) -> f64 {
147 let (n, _d) = z1.dim();
148 assert_eq!(z1.dim(), z2.dim(), "z1 and z2 must have the same shape");
149 assert!(temperature > 0.0, "temperature must be positive");
150
151 let norm_z1 = l2_normalise_rows(z1);
153 let norm_z2 = l2_normalise_rows(z2);
154
155 let mut stacked = Array2::zeros((2 * n, z1.dim().1));
157 for i in 0..n {
158 for d in 0..z1.dim().1 {
159 stacked[[i, d]] = norm_z1[[i, d]];
160 stacked[[i + n, d]] = norm_z2[[i, d]];
161 }
162 }
163
164 let two_n = 2 * n;
166 let mut sim = Array2::zeros((two_n, two_n));
167 for i in 0..two_n {
168 for j in 0..two_n {
169 let mut dot = 0.0;
170 for d in 0..stacked.dim().1 {
171 dot += stacked[[i, d]] * stacked[[j, d]];
172 }
173 sim[[i, j]] = dot / temperature;
174 }
175 }
176
177 for i in 0..two_n {
179 sim[[i, i]] = f64::NEG_INFINITY;
180 }
181
182 let mut loss = 0.0;
186 for i in 0..two_n {
187 let pos_j = if i < n { i + n } else { i - n };
188 let pos_score = sim[[i, pos_j]];
189
190 let row_scores: Vec<f64> = (0..two_n)
192 .filter(|&j| j != i)
193 .map(|j| sim[[i, j]])
194 .collect();
195 let max_s = row_scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
196 let log_sum_exp = max_s
197 + row_scores
198 .iter()
199 .map(|&s| (s - max_s).exp())
200 .sum::<f64>()
201 .ln();
202
203 loss += -(pos_score - log_sum_exp);
204 }
205
206 loss / two_n as f64
207}
208
209fn l2_normalise_rows(x: &Array2<f64>) -> Array2<f64> {
211 let norms: Array1<f64> = x.map_axis(Axis(1), |row| {
212 let s: f64 = row.iter().map(|&v| v * v).sum();
213 s.sqrt().max(1e-12)
214 });
215 let mut out = x.clone();
216 let (n, _d) = x.dim();
217 for i in 0..n {
218 for d in 0.._d {
219 out[[i, d]] /= norms[i];
220 }
221 }
222 out
223}
224
225pub struct ProjectionHead {
233 w1: Array2<f64>,
234 b1: Array1<f64>,
235 w2: Array2<f64>,
236 b2: Array1<f64>,
237}
238
239impl ProjectionHead {
240 pub fn new(in_dim: usize, hidden_dim: usize, out_dim: usize, seed: u64) -> Self {
248 let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(seed);
249
250 let s1 = (6.0 / (in_dim + hidden_dim) as f64).sqrt();
251 let w1 = Array2::from_shape_fn((in_dim, hidden_dim), |_| {
252 rng.random::<f64>() * 2.0 * s1 - s1
253 });
254 let b1 = Array1::zeros(hidden_dim);
255
256 let s2 = (6.0 / (hidden_dim + out_dim) as f64).sqrt();
257 let w2 = Array2::from_shape_fn((hidden_dim, out_dim), |_| {
258 rng.random::<f64>() * 2.0 * s2 - s2
259 });
260 let b2 = Array1::zeros(out_dim);
261
262 ProjectionHead { w1, b1, w2, b2 }
263 }
264
265 pub fn forward(&self, x: &Array2<f64>) -> Array2<f64> {
273 let batch = x.dim().0;
274 let hidden_dim = self.w1.dim().1;
275 let out_dim = self.w2.dim().1;
276
277 let mut h = Array2::zeros((batch, hidden_dim));
279 for i in 0..batch {
280 for j in 0..hidden_dim {
281 let mut val = self.b1[j];
282 for d in 0..x.dim().1 {
283 val += x[[i, d]] * self.w1[[d, j]];
284 }
285 h[[i, j]] = if val > 0.0 { val } else { 0.0 };
286 }
287 }
288
289 let mut out = Array2::zeros((batch, out_dim));
291 for i in 0..batch {
292 for k in 0..out_dim {
293 let mut val = self.b2[k];
294 for j in 0..hidden_dim {
295 val += h[[i, j]] * self.w2[[j, k]];
296 }
297 out[[i, k]] = val;
298 }
299 }
300
301 out
302 }
303
304 pub fn in_dim(&self) -> usize {
306 self.w1.dim().0
307 }
308
309 pub fn out_dim(&self) -> usize {
311 self.w2.dim().1
312 }
313}
314
315pub fn simgrace_perturb(weights: &Array2<f64>, noise_scale: f64, seed: u64) -> Array2<f64> {
333 let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(seed);
334 weights.mapv(|v| {
335 let u1: f64 = rng.random::<f64>().max(1e-12);
337 let u2: f64 = rng.random::<f64>();
338 let noise = (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos();
339 v + noise_scale * noise
340 })
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_augment_features_zero_rate_identity() {
        let x = Array2::from_shape_vec((3, 4), (0..12).map(|v| v as f64).collect()).expect("ok");
        let out = augment_features(&x, 0.0, 0);
        for (a, b) in x.iter().zip(out.iter()) {
            assert_eq!(a, b);
        }
    }

    #[test]
    fn test_augment_features_full_rate_zeros() {
        let x = Array2::ones((5, 8));
        let out = augment_features(&x, 1.0, 0);
        for v in out.iter() {
            assert_eq!(*v, 0.0);
        }
    }

    #[test]
    fn test_augment_features_same_seed_same_mask() {
        let x = Array2::from_shape_fn((6, 5), |(i, j)| (i * 5 + j) as f64 + 1.0);
        let a = augment_features(&x, 0.5, 11);
        let b = augment_features(&x, 0.5, 11);
        assert_eq!(a, b, "masking must be deterministic for a fixed seed");
        // Every entry is either kept as-is or zeroed; nothing else changes.
        for (orig, masked) in x.iter().zip(a.iter()) {
            assert!(*masked == *orig || *masked == 0.0);
        }
    }

    #[test]
    fn test_nt_xent_identical_views_low_loss() {
        // Orthonormal rows: positives are maximally similar, negatives orthogonal.
        let z = Array2::from_shape_fn((8, 16), |(i, j)| if i == j { 1.0 } else { 0.0 });
        let loss = nt_xent_loss(&z, &z, 0.5);
        assert!(loss >= 0.0, "loss should be non-negative, got {loss}");
        // Random (non-orthogonal) identical views should not do better.
        let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(0);
        let z_rand = Array2::from_shape_fn((8, 16), |_| rng.random::<f64>() - 0.5);
        let loss_rand = nt_xent_loss(&z_rand, &z_rand, 0.5);
        assert!(loss <= loss_rand + 1e-6);
    }

    #[test]
    fn test_nt_xent_random_views_positive_loss() {
        let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(42);
        let z1 = Array2::from_shape_fn((6, 8), |_| rng.random::<f64>() - 0.5);
        let z2 = Array2::from_shape_fn((6, 8), |_| rng.random::<f64>() - 0.5);
        let loss = nt_xent_loss(&z1, &z2, 0.5);
        assert!(
            loss > 0.0,
            "loss with random views should be positive, got {loss}"
        );
    }

    #[test]
    fn test_nt_xent_view_order_symmetric() {
        // Swapping the two views only permutes anchors, so the mean is unchanged.
        let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(3);
        let z1 = Array2::from_shape_fn((5, 4), |_| rng.random::<f64>() - 0.5);
        let z2 = Array2::from_shape_fn((5, 4), |_| rng.random::<f64>() - 0.5);
        let forward = nt_xent_loss(&z1, &z2, 0.5);
        let swapped = nt_xent_loss(&z2, &z1, 0.5);
        assert!(
            (forward - swapped).abs() < 1e-9,
            "view order should not matter: {forward} vs {swapped}"
        );
    }

    #[test]
    fn test_projection_head_output_shape() {
        let head = ProjectionHead::new(32, 64, 16, 0);
        let x = Array2::ones((10, 32));
        let out = head.forward(&x);
        assert_eq!(out.dim(), (10, 16));
    }

    #[test]
    fn test_projection_head_dims() {
        let head = ProjectionHead::new(32, 64, 16, 0);
        assert_eq!(head.in_dim(), 32);
        assert_eq!(head.out_dim(), 16);
    }

    #[test]
    fn test_projection_head_deterministic_for_seed() {
        let a = ProjectionHead::new(4, 8, 3, 21);
        let b = ProjectionHead::new(4, 8, 3, 21);
        let x = Array2::from_shape_fn((2, 4), |(i, j)| (i + j) as f64 * 0.25);
        assert_eq!(a.forward(&x), b.forward(&x));
    }

    #[test]
    fn test_simgrace_perturb_changes_weights() {
        let w = Array2::ones((8, 8));
        let perturbed = simgrace_perturb(&w, 0.1, 99);
        let diff: f64 = w
            .iter()
            .zip(perturbed.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-10,
            "perturbed weights should differ from original"
        );
    }

    #[test]
    fn test_simgrace_zero_noise_preserves_weights() {
        let w = Array2::ones((4, 4));
        let perturbed = simgrace_perturb(&w, 0.0, 0);
        for (a, b) in w.iter().zip(perturbed.iter()) {
            assert!((a - b).abs() < 1e-12);
        }
    }

    #[test]
    fn test_augment_edges_symmetry() {
        // Path graph 0-1-2-3.
        let mut adj = Array2::zeros((4, 4));
        adj[[0, 1]] = 1.0;
        adj[[1, 0]] = 1.0;
        adj[[1, 2]] = 1.0;
        adj[[2, 1]] = 1.0;
        adj[[2, 3]] = 1.0;
        adj[[3, 2]] = 1.0;

        let aug = augment_edges(&adj, 0.3, 0.1, 7);
        let n = 4;
        for i in 0..n {
            for j in 0..n {
                assert_eq!(
                    aug[[i, j]],
                    aug[[j, i]],
                    "augmented adjacency must remain symmetric at ({i},{j})"
                );
            }
        }
    }

    #[test]
    fn test_augment_edges_zero_rates_identity() {
        let mut adj = Array2::zeros((3, 3));
        adj[[0, 1]] = 1.0;
        adj[[1, 0]] = 1.0;
        let aug = augment_edges(&adj, 0.0, 0.0, 5);
        assert_eq!(aug, adj);
    }

    #[test]
    fn test_augment_edges_full_drop_removes_all_edges() {
        let adj = Array2::from_shape_fn((4, 4), |(i, j)| if i != j { 1.0 } else { 0.0 });
        let aug = augment_edges(&adj, 1.0, 0.0, 3);
        assert!(aug.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_temperature_sensitivity() {
        let z = Array2::from_shape_fn((4, 8), |(i, j)| if i == j { 1.0 } else { 0.0 });
        let loss_low_t = nt_xent_loss(&z, &z, 0.1);
        let loss_high_t = nt_xent_loss(&z, &z, 2.0);
        assert!(loss_low_t >= 0.0);
        assert!(loss_high_t >= 0.0);
        assert!(
            (loss_low_t - loss_high_t).abs() > 1e-6,
            "temperature should affect loss magnitude"
        );
    }
}