1use crate::error::GnnError;
16use crate::layer::{LayerNorm, Linear};
17use rand::seq::SliceRandom;
18use rand::Rng;
19
/// Reconstruction loss used to compare decoded features with the original ones.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LossFn {
    /// Scaled cosine error.
    Sce {
        /// Exponent applied to `1 - cosine_similarity`.
        gamma: f32,
    },
    /// Mean squared error.
    Mse,
}

impl Default for LossFn {
    /// Returns the scaled cosine error with `gamma = 2.0`.
    fn default() -> Self {
        LossFn::Sce { gamma: 2.0 }
    }
}
37
38#[derive(Debug, Clone)]
40pub struct GraphMAEConfig {
41 pub mask_ratio: f32,
43 pub num_layers: usize,
45 pub hidden_dim: usize,
47 pub num_heads: usize,
49 pub decoder_layers: usize,
51 pub re_mask_ratio: f32,
53 pub loss_fn: LossFn,
55 pub input_dim: usize,
57}
58
59impl Default for GraphMAEConfig {
60 fn default() -> Self {
61 Self {
62 mask_ratio: 0.5,
63 num_layers: 2,
64 hidden_dim: 64,
65 num_heads: 4,
66 decoder_layers: 1,
67 re_mask_ratio: 0.0,
68 loss_fn: LossFn::default(),
69 input_dim: 64,
70 }
71 }
72}
73
/// A graph described by per-node feature vectors and adjacency lists.
#[derive(Clone, Debug)]
pub struct GraphData {
    /// One feature vector per node.
    pub node_features: Vec<Vec<f32>>,
    /// `adjacency[i]` holds the indices of the neighbours of node `i`.
    pub adjacency: Vec<Vec<usize>>,
    /// Node count supplied by the caller; presumably equal to `node_features.len()`
    /// (not checked here).
    pub num_nodes: usize,
}
84
/// Outcome of a masking pass over the node features.
#[derive(Clone, Debug)]
pub struct MaskResult {
    /// Copy of the input features in which every masked node holds the mask token.
    pub masked_features: Vec<Vec<f32>>,
    /// Indices of the nodes that were masked, in the order they were selected.
    pub mask_indices: Vec<usize>,
}
93
94pub struct FeatureMasking {
96 mask_token: Vec<f32>,
97}
98
99impl FeatureMasking {
100 pub fn new(dim: usize) -> Self {
102 let mut rng = rand::thread_rng();
103 Self {
104 mask_token: (0..dim).map(|_| rng.gen::<f32>() * 0.02 - 0.01).collect(),
105 }
106 }
107
108 pub fn mask_nodes(&self, features: &[Vec<f32>], mask_ratio: f32) -> MaskResult {
110 let n = features.len();
111 let num_mask = ((n as f32) * mask_ratio.clamp(0.0, 1.0)).round() as usize;
112 let mut rng = rand::thread_rng();
113 let mut indices: Vec<usize> = (0..n).collect();
114 indices.shuffle(&mut rng);
115 let mask_indices = indices[..num_mask.min(n)].to_vec();
116 let mut masked = features.to_vec();
117 for &i in &mask_indices {
118 masked[i] = self.mask_token.clone();
119 }
120 MaskResult {
121 masked_features: masked,
122 mask_indices,
123 }
124 }
125
126 pub fn mask_by_degree(
128 &self,
129 features: &[Vec<f32>],
130 adjacency: &[Vec<usize>],
131 mask_ratio: f32,
132 ) -> MaskResult {
133 let n = features.len();
134 let num_mask = ((n as f32) * mask_ratio.clamp(0.0, 1.0)).round() as usize;
135 let degrees: Vec<f32> = adjacency.iter().map(|a| a.len() as f32 + 1.0).collect();
136 let total: f32 = degrees.iter().sum();
137 let probs: Vec<f32> = degrees.iter().map(|d| d / total).collect();
138 let mut rng = rand::thread_rng();
139 let mut avail: Vec<usize> = (0..n).collect();
140 let mut mask_indices = Vec::with_capacity(num_mask);
141 for _ in 0..num_mask.min(n) {
142 if avail.is_empty() {
143 break;
144 }
145 let rp: Vec<f32> = avail.iter().map(|&i| probs[i]).collect();
146 let s: f32 = rp.iter().sum();
147 if s <= 0.0 {
148 break;
149 }
150 let thr = rng.gen::<f32>() * s;
151 let mut cum = 0.0;
152 let mut chosen = 0;
153 for (pos, &p) in rp.iter().enumerate() {
154 cum += p;
155 if cum >= thr {
156 chosen = pos;
157 break;
158 }
159 }
160 mask_indices.push(avail[chosen]);
161 avail.swap_remove(chosen);
162 }
163 let mut masked = features.to_vec();
164 for &i in &mask_indices {
165 masked[i] = self.mask_token.clone();
166 }
167 MaskResult {
168 masked_features: masked,
169 mask_indices,
170 }
171 }
172}
173
174struct GATLayer {
176 linear: Linear,
177 attn_src: Vec<f32>,
178 attn_dst: Vec<f32>,
179 norm: LayerNorm,
180 num_heads: usize,
181}
182
183impl GATLayer {
184 fn new(input_dim: usize, output_dim: usize, num_heads: usize) -> Self {
185 let mut rng = rand::thread_rng();
186 let hd = output_dim / num_heads.max(1);
187 Self {
188 linear: Linear::new(input_dim, output_dim),
189 attn_src: (0..hd).map(|_| rng.gen::<f32>() * 0.1).collect(),
190 attn_dst: (0..hd).map(|_| rng.gen::<f32>() * 0.1).collect(),
191 norm: LayerNorm::new(output_dim, 1e-5),
192 num_heads,
193 }
194 }
195
196 fn forward(&self, features: &[Vec<f32>], adj: &[Vec<usize>]) -> Vec<Vec<f32>> {
197 let proj: Vec<Vec<f32>> = features.iter().map(|f| self.linear.forward(f)).collect();
198 let od = proj.first().map_or(0, |v| v.len());
199 let hd = od / self.num_heads.max(1);
200 let mut output = Vec::with_capacity(features.len());
201 for i in 0..features.len() {
202 if adj[i].is_empty() {
203 output.push(elu_vec(&proj[i]));
204 continue;
205 }
206 let mut agg = vec![0.0f32; od];
207 for h in 0..self.num_heads {
208 let (s, e) = (h * hd, (h + 1) * hd);
209 let ss: f32 = proj[i][s..e]
210 .iter()
211 .zip(&self.attn_src)
212 .map(|(a, b)| a * b)
213 .sum();
214 let mut scores: Vec<f32> = adj[i]
215 .iter()
216 .map(|&j| {
217 let ds: f32 = proj[j][s..e]
218 .iter()
219 .zip(&self.attn_dst)
220 .map(|(a, b)| a * b)
221 .sum();
222 let v = ss + ds;
223 if v >= 0.0 {
224 v
225 } else {
226 0.2 * v
227 } })
229 .collect();
230 let mx = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
231 let exp: Vec<f32> = scores.iter_mut().map(|v| (*v - mx).exp()).collect();
232 let sm = exp.iter().sum::<f32>().max(1e-10);
233 for (k, &j) in adj[i].iter().enumerate() {
234 let w = exp[k] / sm;
235 for d in s..e {
236 agg[d] += w * proj[j][d];
237 }
238 }
239 }
240 for v in &mut agg {
241 *v /= self.num_heads as f32;
242 }
243 if features[i].len() == od {
244 for (a, &f) in agg.iter_mut().zip(features[i].iter()) {
245 *a += f;
246 }
247 }
248 output.push(elu_vec(&self.norm.forward(&agg)));
249 }
250 output
251 }
252}
253
254pub struct GATEncoder {
256 layers: Vec<GATLayer>,
257}
258
259impl GATEncoder {
260 pub fn new(input_dim: usize, hidden_dim: usize, num_layers: usize, num_heads: usize) -> Self {
262 let layers = (0..num_layers)
263 .map(|i| {
264 GATLayer::new(
265 if i == 0 { input_dim } else { hidden_dim },
266 hidden_dim,
267 num_heads,
268 )
269 })
270 .collect();
271 Self { layers }
272 }
273
274 pub fn encode(&self, features: &[Vec<f32>], adj: &[Vec<usize>]) -> Vec<Vec<f32>> {
276 self.layers
277 .iter()
278 .fold(features.to_vec(), |h, l| l.forward(&h, adj))
279 }
280}
281
282pub struct GraphMAEDecoder {
284 layers: Vec<Linear>,
285 norm: LayerNorm,
286}
287
288impl GraphMAEDecoder {
289 pub fn new(hidden_dim: usize, output_dim: usize, num_layers: usize) -> Self {
291 let n = num_layers.max(1);
292 let layers = (0..n)
293 .map(|i| {
294 let out = if i == n - 1 { output_dim } else { hidden_dim };
295 Linear::new(hidden_dim, out)
296 })
297 .collect();
298 Self {
299 layers,
300 norm: LayerNorm::new(output_dim, 1e-5),
301 }
302 }
303
304 pub fn decode(&self, latent: &[Vec<f32>], mask_idx: &[usize], re_mask: f32) -> Vec<Vec<f32>> {
306 let mut rng = rand::thread_rng();
307 mask_idx
308 .iter()
309 .map(|&idx| {
310 let mut h = latent[idx].clone();
311 if re_mask > 0.0 {
312 let nz = ((h.len() as f32) * re_mask).round() as usize;
313 let mut dims: Vec<usize> = (0..h.len()).collect();
314 dims.shuffle(&mut rng);
315 for &d in dims.iter().take(nz) {
316 h[d] = 0.0;
317 }
318 }
319 for layer in &self.layers {
320 h = elu_vec(&layer.forward(&h));
321 }
322 self.norm.forward(&h)
323 })
324 .collect()
325 }
326}
327
/// Scaled cosine error averaged over `preds`.
///
/// Each prediction/target pair contributes `(1 - cos)^gamma`, where `cos` is the cosine
/// similarity clamped to `[-1, 1]` and each norm is floored at `1e-8`. The sum is divided
/// by `preds.len()`. Returns `0.0` when `preds` is empty.
pub fn sce_loss(preds: &[Vec<f32>], targets: &[Vec<f32>], gamma: f32) -> f32 {
    /// Euclidean norm floored at `1e-8` so it is safe to divide by.
    fn safe_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8)
    }

    if preds.is_empty() {
        return 0.0;
    }

    let pair_error = |pred: &Vec<f32>, target: &Vec<f32>| -> f32 {
        let dot: f32 = pred.iter().zip(target).map(|(a, b)| a * b).sum();
        let cosine = (dot / (safe_norm(pred) * safe_norm(target))).clamp(-1.0, 1.0);
        (1.0 - cosine).powf(gamma)
    };

    let total = preds
        .iter()
        .zip(targets)
        .map(|(p, t)| pair_error(p, t))
        .sum::<f32>();
    total / preds.len() as f32
}
345
/// Mean squared error between `preds` and `targets`.
///
/// Squared differences of the element-wise paired values are summed and divided by the
/// total number of elements in `preds`. Returns `0.0` when `preds` holds no elements.
pub fn mse_loss(preds: &[Vec<f32>], targets: &[Vec<f32>]) -> f32 {
    let element_count: usize = preds.iter().map(|row| row.len()).sum();
    // Covers both an empty slice and a slice of empty rows.
    if element_count == 0 {
        return 0.0;
    }

    let squared_error: f32 = preds
        .iter()
        .zip(targets)
        .flat_map(|(pred, target)| pred.iter().zip(target))
        .map(|(a, b)| (a - b).powi(2))
        .sum();
    squared_error / element_count as f32
}
362
363pub struct GraphMAE {
365 config: GraphMAEConfig,
366 masking: FeatureMasking,
367 encoder: GATEncoder,
368 decoder: GraphMAEDecoder,
369}
370
371impl GraphMAE {
372 pub fn new(config: GraphMAEConfig) -> Result<Self, GnnError> {
377 if config.hidden_dim % config.num_heads != 0 {
378 return Err(GnnError::layer_config(format!(
379 "hidden_dim ({}) must be divisible by num_heads ({})",
380 config.hidden_dim, config.num_heads
381 )));
382 }
383 if !(0.0..=1.0).contains(&config.mask_ratio) {
384 return Err(GnnError::layer_config("mask_ratio must be in [0.0, 1.0]"));
385 }
386 let masking = FeatureMasking::new(config.input_dim);
387 let encoder = GATEncoder::new(
388 config.input_dim,
389 config.hidden_dim,
390 config.num_layers,
391 config.num_heads,
392 );
393 let decoder =
394 GraphMAEDecoder::new(config.hidden_dim, config.input_dim, config.decoder_layers);
395 Ok(Self {
396 config,
397 masking,
398 encoder,
399 decoder,
400 })
401 }
402
403 pub fn train_step(&self, graph: &GraphData) -> f32 {
406 let mr = self
407 .masking
408 .mask_nodes(&graph.node_features, self.config.mask_ratio);
409 let latent = self.encoder.encode(&mr.masked_features, &graph.adjacency);
410 let recon = self
411 .decoder
412 .decode(&latent, &mr.mask_indices, self.config.re_mask_ratio);
413 let targets: Vec<Vec<f32>> = mr
414 .mask_indices
415 .iter()
416 .map(|&i| graph.node_features[i].clone())
417 .collect();
418 match self.config.loss_fn {
419 LossFn::Sce { gamma } => sce_loss(&recon, &targets, gamma),
420 LossFn::Mse => mse_loss(&recon, &targets),
421 }
422 }
423
424 pub fn encode(&self, graph: &GraphData) -> Vec<Vec<f32>> {
426 self.encoder.encode(&graph.node_features, &graph.adjacency)
427 }
428
429 pub fn get_embeddings(&self, graph: &GraphData) -> Vec<Vec<f32>> {
431 self.encode(graph)
432 }
433}
434
/// Applies ELU (with alpha = 1) element-wise: `x` for `x >= 0`, `exp(x) - 1` otherwise.
///
/// The negative branch uses `exp_m1`, which stays accurate near zero. Computing
/// `x.exp() - 1.0` directly cancels to `0.0` for tiny negative inputs.
fn elu_vec(v: &[f32]) -> Vec<f32> {
    v.iter()
        .map(|&x| if x >= 0.0 { x } else { x.exp_m1() })
        .collect()
}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a path graph of `n` nodes (node `i` linked to `i - 1` and `i + 1`) whose
    /// `d`-dimensional features are `(i * d + j) * 0.1`.
    fn graph(n: usize, d: usize) -> GraphData {
        let mut node_features = Vec::with_capacity(n);
        let mut adjacency = Vec::with_capacity(n);
        for i in 0..n {
            let row: Vec<f32> = (0..d).map(|j| (i * d + j) as f32 * 0.1).collect();
            node_features.push(row);

            let mut neighbours = Vec::new();
            if i > 0 {
                neighbours.push(i - 1);
            }
            if i + 1 < n {
                neighbours.push(i + 1);
            }
            adjacency.push(neighbours);
        }
        GraphData {
            node_features,
            adjacency,
            num_nodes: n,
        }
    }

    /// Small test configuration: hidden width 16, four heads, two encoder layers.
    fn cfg(dim: usize) -> GraphMAEConfig {
        GraphMAEConfig {
            mask_ratio: 0.5,
            num_layers: 2,
            hidden_dim: 16,
            num_heads: 4,
            decoder_layers: 1,
            re_mask_ratio: 0.0,
            loss_fn: LossFn::default(),
            input_dim: dim,
        }
    }

    #[test]
    fn test_masking_ratio() {
        let feats: Vec<Vec<f32>> = (0..100).map(|i| vec![i as f32; 8]).collect();
        let masked = FeatureMasking::new(8).mask_nodes(&feats, 0.3).mask_indices.len();
        // 30% of 100 nodes, allowing one node of rounding slack.
        assert!((29..=31).contains(&masked));
    }

    #[test]
    fn test_encoder_forward() {
        let g = graph(5, 16);
        let out = GATEncoder::new(16, 16, 2, 4).encode(&g.node_features, &g.adjacency);
        assert_eq!(out.len(), 5);
        assert_eq!(out[0].len(), 16);
    }

    #[test]
    fn test_decoder_reconstruction_shape() {
        let latent = vec![vec![0.5f32; 16]; 5];
        let recon = GraphMAEDecoder::new(16, 8, 1).decode(&latent, &[0, 2, 4], 0.0);
        assert_eq!(recon.len(), 3);
        assert_eq!(recon[0].len(), 8);
    }

    #[test]
    fn test_sce_loss_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let loss = sce_loss(&[v.clone()], &[v], 2.0);
        assert!(loss < 1e-6, "SCE identical should be ~0, got {loss}");
    }

    #[test]
    fn test_sce_loss_orthogonal() {
        let loss = sce_loss(&[vec![1.0, 0.0]], &[vec![0.0, 1.0]], 2.0);
        let off_by = (loss - 1.0).abs();
        assert!(off_by < 1e-5, "SCE orthogonal should be 1.0, got {loss}");
    }

    #[test]
    fn test_mse_loss() {
        let same = mse_loss(&[vec![1.0, 2.0]], &[vec![1.0, 2.0]]);
        assert!(same < 1e-8);
        let unit = mse_loss(&[vec![0.0, 0.0]], &[vec![1.0, 1.0]]);
        assert!((unit - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_train_step_returns_finite_loss() {
        let model = GraphMAE::new(cfg(16)).unwrap();
        let g = graph(10, 16);
        let loss = model.train_step(&g);
        assert!(loss.is_finite() && loss >= 0.0, "bad loss: {loss}");
    }

    #[test]
    fn test_re_masking() {
        let dec = GraphMAEDecoder::new(16, 8, 1);
        let latent = vec![vec![1.0; 16]; 3];
        let nodes = [0, 1, 2];
        let plain = dec.decode(&latent, &nodes, 0.0);
        let remasked = dec.decode(&latent, &nodes, 0.8);
        let diff: f32 = plain[0]
            .iter()
            .zip(&remasked[0])
            .map(|(x, y)| (x - y).abs())
            .sum();
        assert!(diff > 1e-6, "re-masking should change output");
    }

    #[test]
    fn test_degree_based_masking() {
        let feats = vec![vec![1.0f32; 8]; 10];
        // Star graph: node 0 is the hub, every other node links only to it.
        let mut adj: Vec<Vec<usize>> = vec![vec![0]; 10];
        adj[0] = (1..10).collect();
        let result = FeatureMasking::new(8).mask_by_degree(&feats, &adj, 0.5);
        assert_eq!(result.mask_indices.len(), 5);
    }

    #[test]
    fn test_single_node_graph() {
        let lone = GraphData {
            node_features: vec![vec![1.0; 16]],
            adjacency: vec![vec![]],
            num_nodes: 1,
        };
        let model = GraphMAE::new(cfg(16)).unwrap();
        assert!(model.train_step(&lone).is_finite());
    }

    #[test]
    fn test_encode_for_downstream() {
        let model = GraphMAE::new(cfg(16)).unwrap();
        let emb = model.get_embeddings(&graph(8, 16));
        assert_eq!(emb.len(), 8);
        assert_eq!(emb[0].len(), 16);
        for row in &emb {
            for value in row {
                assert!(value.is_finite());
            }
        }
    }

    #[test]
    fn test_invalid_config() {
        let indivisible = GraphMAEConfig {
            hidden_dim: 15,
            num_heads: 4,
            ..cfg(16)
        };
        assert!(GraphMAE::new(indivisible).is_err());

        let ratio_too_large = GraphMAEConfig {
            mask_ratio: 1.5,
            ..cfg(16)
        };
        assert!(GraphMAE::new(ratio_too_large).is_err());
    }
}