Skip to main content

ruvector_gnn/
graphmae.rs

1//! # GraphMAE: Masked Autoencoders for Graphs
2//!
3//! Self-supervised graph learning via masked feature reconstruction. Traditional
4//! supervised graph learning requires expensive node/edge labels that are scarce in
5//! real-world graphs. GraphMAE learns representations by masking and reconstructing
6//! node features, requiring **zero labels**. The learned embeddings transfer well to
7//! downstream tasks (classification, link prediction, clustering) because the model
8//! must capture structural and semantic graph properties to reconstruct masked features
9//! from their neighborhood context.
10//!
11//! Pipeline: Mask -> GAT Encode -> Re-mask latent -> Decode masked only -> SCE loss.
12//!
13//! Reference: Hou et al., "GraphMAE: Self-Supervised Masked Graph Autoencoders", KDD 2022.
14
15use crate::error::GnnError;
16use crate::layer::{LayerNorm, Linear};
17use rand::seq::SliceRandom;
18use rand::Rng;
19
/// Reconstruction objective used to score decoded features against the originals.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LossFn {
    /// Scaled Cosine Error, `(1 - cos_sim)^gamma`; the objective GraphMAE uses by default.
    Sce {
        /// Exponent that sharpens the penalty (default `2.0`).
        gamma: f32,
    },
    /// Plain mean squared error.
    Mse,
}

impl Default for LossFn {
    /// Scaled Cosine Error with `gamma = 2.0`.
    fn default() -> Self {
        LossFn::Sce { gamma: 2.0 }
    }
}
37
38/// Configuration for a GraphMAE model.
39#[derive(Debug, Clone)]
40pub struct GraphMAEConfig {
41    /// Fraction of nodes to mask (default 0.5).
42    pub mask_ratio: f32,
43    /// Number of GAT encoder layers.
44    pub num_layers: usize,
45    /// Hidden / latent dimension.
46    pub hidden_dim: usize,
47    /// Number of attention heads per encoder layer.
48    pub num_heads: usize,
49    /// Number of decoder layers.
50    pub decoder_layers: usize,
51    /// Secondary mask ratio applied to latent before decoding (default 0.0).
52    pub re_mask_ratio: f32,
53    /// Reconstruction loss function.
54    pub loss_fn: LossFn,
55    /// Input feature dimension.
56    pub input_dim: usize,
57}
58
59impl Default for GraphMAEConfig {
60    fn default() -> Self {
61        Self {
62            mask_ratio: 0.5,
63            num_layers: 2,
64            hidden_dim: 64,
65            num_heads: 4,
66            decoder_layers: 1,
67            re_mask_ratio: 0.0,
68            loss_fn: LossFn::default(),
69            input_dim: 64,
70        }
71    }
72}
73
/// Graph stored as per-node feature rows plus an adjacency list.
#[derive(Clone, Debug)]
pub struct GraphData {
    /// One feature row per node; row `i` belongs to node `i`.
    pub node_features: Vec<Vec<f32>>,
    /// Neighbour lists: `adjacency[i]` holds the indices of the nodes adjacent to node `i`.
    pub adjacency: Vec<Vec<usize>>,
    /// Node count (presumably equal to `node_features.len()`; not checked by this type).
    pub num_nodes: usize,
}
84
/// Output of a masking pass over node features.
#[derive(Clone, Debug)]
pub struct MaskResult {
    /// Copy of the input features in which every masked row holds the `[MASK]` token.
    pub masked_features: Vec<Vec<f32>>,
    /// Indices of the rows that were masked, in selection order.
    pub mask_indices: Vec<usize>,
}
93
94/// Feature masking strategies for GraphMAE.
95pub struct FeatureMasking {
96    mask_token: Vec<f32>,
97}
98
99impl FeatureMasking {
100    /// Create a masking module with a learnable `[MASK]` token of given dimension.
101    pub fn new(dim: usize) -> Self {
102        let mut rng = rand::thread_rng();
103        Self {
104            mask_token: (0..dim).map(|_| rng.gen::<f32>() * 0.02 - 0.01).collect(),
105        }
106    }
107
108    /// Randomly mask `mask_ratio` of nodes, replacing features with `[MASK]` token.
109    pub fn mask_nodes(&self, features: &[Vec<f32>], mask_ratio: f32) -> MaskResult {
110        let n = features.len();
111        let num_mask = ((n as f32) * mask_ratio.clamp(0.0, 1.0)).round() as usize;
112        let mut rng = rand::thread_rng();
113        let mut indices: Vec<usize> = (0..n).collect();
114        indices.shuffle(&mut rng);
115        let mask_indices = indices[..num_mask.min(n)].to_vec();
116        let mut masked = features.to_vec();
117        for &i in &mask_indices {
118            masked[i] = self.mask_token.clone();
119        }
120        MaskResult {
121            masked_features: masked,
122            mask_indices,
123        }
124    }
125
126    /// Degree-centrality masking: higher-degree nodes are masked with higher probability.
127    pub fn mask_by_degree(
128        &self,
129        features: &[Vec<f32>],
130        adjacency: &[Vec<usize>],
131        mask_ratio: f32,
132    ) -> MaskResult {
133        let n = features.len();
134        let num_mask = ((n as f32) * mask_ratio.clamp(0.0, 1.0)).round() as usize;
135        let degrees: Vec<f32> = adjacency.iter().map(|a| a.len() as f32 + 1.0).collect();
136        let total: f32 = degrees.iter().sum();
137        let probs: Vec<f32> = degrees.iter().map(|d| d / total).collect();
138        let mut rng = rand::thread_rng();
139        let mut avail: Vec<usize> = (0..n).collect();
140        let mut mask_indices = Vec::with_capacity(num_mask);
141        for _ in 0..num_mask.min(n) {
142            if avail.is_empty() {
143                break;
144            }
145            let rp: Vec<f32> = avail.iter().map(|&i| probs[i]).collect();
146            let s: f32 = rp.iter().sum();
147            if s <= 0.0 {
148                break;
149            }
150            let thr = rng.gen::<f32>() * s;
151            let mut cum = 0.0;
152            let mut chosen = 0;
153            for (pos, &p) in rp.iter().enumerate() {
154                cum += p;
155                if cum >= thr {
156                    chosen = pos;
157                    break;
158                }
159            }
160            mask_indices.push(avail[chosen]);
161            avail.swap_remove(chosen);
162        }
163        let mut masked = features.to_vec();
164        for &i in &mask_indices {
165            masked[i] = self.mask_token.clone();
166        }
167        MaskResult {
168            masked_features: masked,
169            mask_indices,
170        }
171    }
172}
173
174/// Single GAT layer with residual connection and layer normalization.
175struct GATLayer {
176    linear: Linear,
177    attn_src: Vec<f32>,
178    attn_dst: Vec<f32>,
179    norm: LayerNorm,
180    num_heads: usize,
181}
182
183impl GATLayer {
184    fn new(input_dim: usize, output_dim: usize, num_heads: usize) -> Self {
185        let mut rng = rand::thread_rng();
186        let hd = output_dim / num_heads.max(1);
187        Self {
188            linear: Linear::new(input_dim, output_dim),
189            attn_src: (0..hd).map(|_| rng.gen::<f32>() * 0.1).collect(),
190            attn_dst: (0..hd).map(|_| rng.gen::<f32>() * 0.1).collect(),
191            norm: LayerNorm::new(output_dim, 1e-5),
192            num_heads,
193        }
194    }
195
196    fn forward(&self, features: &[Vec<f32>], adj: &[Vec<usize>]) -> Vec<Vec<f32>> {
197        let proj: Vec<Vec<f32>> = features.iter().map(|f| self.linear.forward(f)).collect();
198        let od = proj.first().map_or(0, |v| v.len());
199        let hd = od / self.num_heads.max(1);
200        let mut output = Vec::with_capacity(features.len());
201        for i in 0..features.len() {
202            if adj[i].is_empty() {
203                output.push(elu_vec(&proj[i]));
204                continue;
205            }
206            let mut agg = vec![0.0f32; od];
207            for h in 0..self.num_heads {
208                let (s, e) = (h * hd, (h + 1) * hd);
209                let ss: f32 = proj[i][s..e]
210                    .iter()
211                    .zip(&self.attn_src)
212                    .map(|(a, b)| a * b)
213                    .sum();
214                let mut scores: Vec<f32> = adj[i]
215                    .iter()
216                    .map(|&j| {
217                        let ds: f32 = proj[j][s..e]
218                            .iter()
219                            .zip(&self.attn_dst)
220                            .map(|(a, b)| a * b)
221                            .sum();
222                        let v = ss + ds;
223                        if v >= 0.0 {
224                            v
225                        } else {
226                            0.2 * v
227                        } // leaky relu
228                    })
229                    .collect();
230                let mx = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
231                let exp: Vec<f32> = scores.iter_mut().map(|v| (*v - mx).exp()).collect();
232                let sm = exp.iter().sum::<f32>().max(1e-10);
233                for (k, &j) in adj[i].iter().enumerate() {
234                    let w = exp[k] / sm;
235                    for d in s..e {
236                        agg[d] += w * proj[j][d];
237                    }
238                }
239            }
240            for v in &mut agg {
241                *v /= self.num_heads as f32;
242            }
243            if features[i].len() == od {
244                for (a, &f) in agg.iter_mut().zip(features[i].iter()) {
245                    *a += f;
246                }
247            }
248            output.push(elu_vec(&self.norm.forward(&agg)));
249        }
250        output
251    }
252}
253
254/// Multi-layer GAT encoder for GraphMAE.
255pub struct GATEncoder {
256    layers: Vec<GATLayer>,
257}
258
259impl GATEncoder {
260    /// Build an encoder with `num_layers` GAT layers.
261    pub fn new(input_dim: usize, hidden_dim: usize, num_layers: usize, num_heads: usize) -> Self {
262        let layers = (0..num_layers)
263            .map(|i| {
264                GATLayer::new(
265                    if i == 0 { input_dim } else { hidden_dim },
266                    hidden_dim,
267                    num_heads,
268                )
269            })
270            .collect();
271        Self { layers }
272    }
273
274    /// Encode node features through all GAT layers.
275    pub fn encode(&self, features: &[Vec<f32>], adj: &[Vec<usize>]) -> Vec<Vec<f32>> {
276        self.layers
277            .iter()
278            .fold(features.to_vec(), |h, l| l.forward(&h, adj))
279    }
280}
281
282/// Decoder that reconstructs only masked node features (key efficiency gain).
283pub struct GraphMAEDecoder {
284    layers: Vec<Linear>,
285    norm: LayerNorm,
286}
287
288impl GraphMAEDecoder {
289    /// Create a decoder mapping `hidden_dim` -> `output_dim`.
290    pub fn new(hidden_dim: usize, output_dim: usize, num_layers: usize) -> Self {
291        let n = num_layers.max(1);
292        let layers = (0..n)
293            .map(|i| {
294                let out = if i == n - 1 { output_dim } else { hidden_dim };
295                Linear::new(hidden_dim, out)
296            })
297            .collect();
298        Self {
299            layers,
300            norm: LayerNorm::new(output_dim, 1e-5),
301        }
302    }
303
304    /// Decode latent for masked nodes. Applies re-masking (zeroing dims) for regularization.
305    pub fn decode(&self, latent: &[Vec<f32>], mask_idx: &[usize], re_mask: f32) -> Vec<Vec<f32>> {
306        let mut rng = rand::thread_rng();
307        mask_idx
308            .iter()
309            .map(|&idx| {
310                let mut h = latent[idx].clone();
311                if re_mask > 0.0 {
312                    let nz = ((h.len() as f32) * re_mask).round() as usize;
313                    let mut dims: Vec<usize> = (0..h.len()).collect();
314                    dims.shuffle(&mut rng);
315                    for &d in dims.iter().take(nz) {
316                        h[d] = 0.0;
317                    }
318                }
319                for layer in &self.layers {
320                    h = elu_vec(&layer.forward(&h));
321                }
322                self.norm.forward(&h)
323            })
324            .collect()
325    }
326}
327
/// Scaled Cosine Error: `mean((1 - cos_sim(pred, target))^gamma)` over masked nodes.
///
/// `preds[k]` is compared with `targets[k]`. If the slices differ in length, only the first
/// `min(preds.len(), targets.len())` pairs are scored, and the mean is taken over exactly
/// those pairs, so unmatched predictions no longer dilute the loss. Cosine similarity is
/// clamped to `[-1, 1]` and norms are floored at `1e-8`, so a zero vector scores
/// `1^gamma = 1`. Returns `0.0` when there are no pairs.
pub fn sce_loss(preds: &[Vec<f32>], targets: &[Vec<f32>], gamma: f32) -> f32 {
    let (total, count) = preds
        .iter()
        .zip(targets)
        .map(|(p, t)| {
            let dot: f32 = p.iter().zip(t).map(|(a, b)| a * b).sum();
            let np = p.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
            let nt = t.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
            (1.0 - (dot / (np * nt)).clamp(-1.0, 1.0)).powf(gamma)
        })
        .fold((0.0f32, 0usize), |(acc, n), term| (acc + term, n + 1));
    if count == 0 {
        0.0
    } else {
        total / count as f32
    }
}
345
/// Mean Squared Error across masked node reconstructions.
///
/// Pairs `preds[k]` with `targets[k]`, and within each pair compares elements position by
/// position. Only positions present in both vectors are counted, and the mean is taken over
/// exactly those positions. Returns `0.0` when nothing is compared, including empty input.
pub fn mse_loss(preds: &[Vec<f32>], targets: &[Vec<f32>]) -> f32 {
    let (sum, count) = preds
        .iter()
        .zip(targets)
        .flat_map(|(p, t)| p.iter().zip(t))
        .fold((0.0f32, 0usize), |(acc, n), (a, b)| (acc + (a - b).powi(2), n + 1));
    if count == 0 {
        0.0
    } else {
        sum / count as f32
    }
}
362
363/// GraphMAE self-supervised model.
364pub struct GraphMAE {
365    config: GraphMAEConfig,
366    masking: FeatureMasking,
367    encoder: GATEncoder,
368    decoder: GraphMAEDecoder,
369}
370
371impl GraphMAE {
372    /// Construct a new GraphMAE model from configuration.
373    ///
374    /// # Errors
375    /// Returns `GnnError::LayerConfig` if dimensions are incompatible.
376    pub fn new(config: GraphMAEConfig) -> Result<Self, GnnError> {
377        if config.hidden_dim % config.num_heads != 0 {
378            return Err(GnnError::layer_config(format!(
379                "hidden_dim ({}) must be divisible by num_heads ({})",
380                config.hidden_dim, config.num_heads
381            )));
382        }
383        if !(0.0..=1.0).contains(&config.mask_ratio) {
384            return Err(GnnError::layer_config("mask_ratio must be in [0.0, 1.0]"));
385        }
386        let masking = FeatureMasking::new(config.input_dim);
387        let encoder = GATEncoder::new(
388            config.input_dim,
389            config.hidden_dim,
390            config.num_layers,
391            config.num_heads,
392        );
393        let decoder =
394            GraphMAEDecoder::new(config.hidden_dim, config.input_dim, config.decoder_layers);
395        Ok(Self {
396            config,
397            masking,
398            encoder,
399            decoder,
400        })
401    }
402
403    /// Run one training step: mask -> encode -> re-mask -> decode -> loss.
404    /// Returns the reconstruction loss computed only on masked nodes.
405    pub fn train_step(&self, graph: &GraphData) -> f32 {
406        let mr = self
407            .masking
408            .mask_nodes(&graph.node_features, self.config.mask_ratio);
409        let latent = self.encoder.encode(&mr.masked_features, &graph.adjacency);
410        let recon = self
411            .decoder
412            .decode(&latent, &mr.mask_indices, self.config.re_mask_ratio);
413        let targets: Vec<Vec<f32>> = mr
414            .mask_indices
415            .iter()
416            .map(|&i| graph.node_features[i].clone())
417            .collect();
418        match self.config.loss_fn {
419            LossFn::Sce { gamma } => sce_loss(&recon, &targets, gamma),
420            LossFn::Mse => mse_loss(&recon, &targets),
421        }
422    }
423
424    /// Encode without masking (inference mode). Returns latent embeddings for all nodes.
425    pub fn encode(&self, graph: &GraphData) -> Vec<Vec<f32>> {
426        self.encoder.encode(&graph.node_features, &graph.adjacency)
427    }
428
429    /// Returns node-level representations for downstream tasks.
430    pub fn get_embeddings(&self, graph: &GraphData) -> Vec<Vec<f32>> {
431        self.encode(graph)
432    }
433}
434
/// Element-wise ELU with `alpha = 1`: `x` for `x >= 0`, otherwise `exp(x) - 1`.
fn elu_vec(v: &[f32]) -> Vec<f32> {
    let mut out = Vec::with_capacity(v.len());
    for &x in v {
        let y = if x < 0.0 { x.exp() - 1.0 } else { x };
        out.push(y);
    }
    out
}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Path graph `0 - 1 - ... - (n - 1)`; node `i` carries `d` features
    /// `0.1 * (i * d + j)` for `j in 0..d`.
    fn graph(n: usize, d: usize) -> GraphData {
        let node_features: Vec<Vec<f32>> = (0..n)
            .map(|i| (0..d).map(|j| 0.1 * (i * d + j) as f32).collect())
            .collect();
        let adjacency: Vec<Vec<usize>> = (0..n)
            .map(|i| {
                let prev = i.checked_sub(1);
                let next = if i + 1 < n { Some(i + 1) } else { None };
                prev.into_iter().chain(next).collect()
            })
            .collect();
        GraphData {
            node_features,
            adjacency,
            num_nodes: n,
        }
    }

    /// Small configuration: 16-wide, 4-head, 2-layer encoder over `dim`-wide inputs.
    fn cfg(dim: usize) -> GraphMAEConfig {
        GraphMAEConfig {
            loss_fn: LossFn::default(),
            re_mask_ratio: 0.0,
            mask_ratio: 0.5,
            decoder_layers: 1,
            num_layers: 2,
            num_heads: 4,
            hidden_dim: 16,
            input_dim: dim,
        }
    }

    #[test]
    fn test_masking_ratio() {
        let feats: Vec<Vec<f32>> = (0..100).map(|i| vec![i as f32; 8]).collect();
        let masked = FeatureMasking::new(8).mask_nodes(&feats, 0.3).mask_indices.len();
        // round(100 * 0.3) = 30, allowing one node of slack.
        assert!((29..=31).contains(&masked));
    }

    #[test]
    fn test_encoder_forward() {
        let g = graph(5, 16);
        let out = GATEncoder::new(16, 16, 2, 4).encode(&g.node_features, &g.adjacency);
        assert_eq!(out.len(), 5);
        assert_eq!(out[0].len(), 16);
    }

    #[test]
    fn test_decoder_reconstruction_shape() {
        let latent: Vec<Vec<f32>> = vec![vec![0.5; 16]; 5];
        let recon = GraphMAEDecoder::new(16, 8, 1).decode(&latent, &[0, 2, 4], 0.0);
        assert_eq!(recon.len(), 3);
        assert_eq!(recon[0].len(), 8);
    }

    #[test]
    fn test_sce_loss_identical() {
        let unit = vec![1.0, 0.0, 0.0];
        let loss = sce_loss(&[unit.clone()], &[unit], 2.0);
        assert!(loss < 1e-6, "SCE identical should be ~0, got {loss}");
    }

    #[test]
    fn test_sce_loss_orthogonal() {
        let loss = sce_loss(&[vec![1.0, 0.0]], &[vec![0.0, 1.0]], 2.0);
        let err = (loss - 1.0).abs();
        assert!(err < 1e-5, "SCE orthogonal should be 1.0, got {loss}");
    }

    #[test]
    fn test_mse_loss() {
        let same = mse_loss(&[vec![1.0, 2.0]], &[vec![1.0, 2.0]]);
        assert!(same < 1e-8);
        let unit_gap = mse_loss(&[vec![0.0, 0.0]], &[vec![1.0, 1.0]]);
        assert!((unit_gap - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_train_step_returns_finite_loss() {
        let loss = GraphMAE::new(cfg(16)).unwrap().train_step(&graph(10, 16));
        assert!(loss.is_finite() && loss >= 0.0, "bad loss: {loss}");
    }

    #[test]
    fn test_re_masking() {
        let dec = GraphMAEDecoder::new(16, 8, 1);
        let latent: Vec<Vec<f32>> = vec![vec![1.0; 16]; 3];
        let plain = dec.decode(&latent, &[0, 1, 2], 0.0);
        let remasked = dec.decode(&latent, &[0, 1, 2], 0.8);
        let diff: f32 = plain[0]
            .iter()
            .zip(&remasked[0])
            .map(|(x, y)| (x - y).abs())
            .sum();
        assert!(diff > 1e-6, "re-masking should change output");
    }

    #[test]
    fn test_degree_based_masking() {
        // Star graph: hub 0 connected to leaves 1..=9.
        let feats: Vec<Vec<f32>> = vec![vec![1.0; 8]; 10];
        let adj: Vec<Vec<usize>> = (0..10)
            .map(|i| if i == 0 { (1..10).collect() } else { vec![0] })
            .collect();
        let result = FeatureMasking::new(8).mask_by_degree(&feats, &adj, 0.5);
        assert_eq!(result.mask_indices.len(), 5);
    }

    #[test]
    fn test_single_node_graph() {
        let lone = GraphData {
            node_features: vec![vec![1.0; 16]],
            adjacency: vec![vec![]],
            num_nodes: 1,
        };
        let loss = GraphMAE::new(cfg(16)).unwrap().train_step(&lone);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_encode_for_downstream() {
        let emb = GraphMAE::new(cfg(16))
            .unwrap()
            .get_embeddings(&graph(8, 16));
        assert_eq!(emb.len(), 8);
        assert_eq!(emb[0].len(), 16);
        assert!(emb.iter().flatten().all(|v| v.is_finite()));
    }

    #[test]
    fn test_invalid_config() {
        let indivisible = GraphMAEConfig {
            hidden_dim: 15,
            num_heads: 4,
            ..cfg(16)
        };
        assert!(GraphMAE::new(indivisible).is_err());
        let bad_ratio = GraphMAEConfig {
            mask_ratio: 1.5,
            ..cfg(16)
        };
        assert!(GraphMAE::new(bad_ratio).is_err());
    }
}