Skip to main content

ruvector_gnn/
graphmae.rs

1//! # GraphMAE: Masked Autoencoders for Graphs
2//!
3//! Self-supervised graph learning via masked feature reconstruction. Traditional
4//! supervised graph learning requires expensive node/edge labels that are scarce in
5//! real-world graphs. GraphMAE learns representations by masking and reconstructing
6//! node features, requiring **zero labels**. The learned embeddings transfer well to
7//! downstream tasks (classification, link prediction, clustering) because the model
8//! must capture structural and semantic graph properties to reconstruct masked features
9//! from their neighborhood context.
10//!
11//! Pipeline: Mask -> GAT Encode -> Re-mask latent -> Decode masked only -> SCE loss.
12//!
13//! Reference: Hou et al., "GraphMAE: Self-Supervised Masked Graph Autoencoders", KDD 2022.
14
15use crate::error::GnnError;
16use crate::layer::{LayerNorm, Linear};
17use rand::seq::SliceRandom;
18use rand::Rng;
19
20/// Loss function variant for reconstruction.
/// Loss function variant for reconstruction.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LossFn {
    /// Scaled Cosine Error: `(1 - cos_sim)^gamma`. The GraphMAE default.
    Sce {
        /// Scaling exponent (default 2.0).
        gamma: f32,
    },
    /// Standard Mean Squared Error.
    Mse,
}
29
30impl Default for LossFn {
31    fn default() -> Self { Self::Sce { gamma: 2.0 } }
32}
33
34/// Configuration for a GraphMAE model.
35#[derive(Debug, Clone)]
36pub struct GraphMAEConfig {
37    /// Fraction of nodes to mask (default 0.5).
38    pub mask_ratio: f32,
39    /// Number of GAT encoder layers.
40    pub num_layers: usize,
41    /// Hidden / latent dimension.
42    pub hidden_dim: usize,
43    /// Number of attention heads per encoder layer.
44    pub num_heads: usize,
45    /// Number of decoder layers.
46    pub decoder_layers: usize,
47    /// Secondary mask ratio applied to latent before decoding (default 0.0).
48    pub re_mask_ratio: f32,
49    /// Reconstruction loss function.
50    pub loss_fn: LossFn,
51    /// Input feature dimension.
52    pub input_dim: usize,
53}
54
55impl Default for GraphMAEConfig {
56    fn default() -> Self {
57        Self {
58            mask_ratio: 0.5, num_layers: 2, hidden_dim: 64, num_heads: 4,
59            decoder_layers: 1, re_mask_ratio: 0.0, loss_fn: LossFn::default(), input_dim: 64,
60        }
61    }
62}
63
64/// Sparse graph representation.
/// Sparse graph representation: dense per-node features plus an adjacency list.
#[derive(Debug, Clone)]
pub struct GraphData {
    /// Per-node feature vectors; `node_features[i]` belongs to node `i`.
    pub node_features: Vec<Vec<f32>>,
    /// Neighbor indices per node; `adjacency[i]` lists the neighbors of node `i`.
    pub adjacency: Vec<Vec<usize>>,
    /// Total node count.
    pub num_nodes: usize,
}
74
75/// Result of masking node features.
/// Output of a masking pass over node features.
#[derive(Debug, Clone)]
pub struct MaskResult {
    /// Feature matrix with the `[MASK]` token substituted at masked positions.
    pub masked_features: Vec<Vec<f32>>,
    /// Which node indices were masked.
    pub mask_indices: Vec<usize>,
}
83
84/// Feature masking strategies for GraphMAE.
/// Feature masking strategies for GraphMAE.
pub struct FeatureMasking {
    /// Learnable `[MASK]` token substituted for masked node features.
    mask_token: Vec<f32>,
}
88
89impl FeatureMasking {
90    /// Create a masking module with a learnable `[MASK]` token of given dimension.
91    pub fn new(dim: usize) -> Self {
92        let mut rng = rand::thread_rng();
93        Self { mask_token: (0..dim).map(|_| rng.gen::<f32>() * 0.02 - 0.01).collect() }
94    }
95
96    /// Randomly mask `mask_ratio` of nodes, replacing features with `[MASK]` token.
97    pub fn mask_nodes(&self, features: &[Vec<f32>], mask_ratio: f32) -> MaskResult {
98        let n = features.len();
99        let num_mask = ((n as f32) * mask_ratio.clamp(0.0, 1.0)).round() as usize;
100        let mut rng = rand::thread_rng();
101        let mut indices: Vec<usize> = (0..n).collect();
102        indices.shuffle(&mut rng);
103        let mask_indices = indices[..num_mask.min(n)].to_vec();
104        let mut masked = features.to_vec();
105        for &i in &mask_indices { masked[i] = self.mask_token.clone(); }
106        MaskResult { masked_features: masked, mask_indices }
107    }
108
109    /// Degree-centrality masking: higher-degree nodes are masked with higher probability.
110    pub fn mask_by_degree(
111        &self, features: &[Vec<f32>], adjacency: &[Vec<usize>], mask_ratio: f32,
112    ) -> MaskResult {
113        let n = features.len();
114        let num_mask = ((n as f32) * mask_ratio.clamp(0.0, 1.0)).round() as usize;
115        let degrees: Vec<f32> = adjacency.iter().map(|a| a.len() as f32 + 1.0).collect();
116        let total: f32 = degrees.iter().sum();
117        let probs: Vec<f32> = degrees.iter().map(|d| d / total).collect();
118        let mut rng = rand::thread_rng();
119        let mut avail: Vec<usize> = (0..n).collect();
120        let mut mask_indices = Vec::with_capacity(num_mask);
121        for _ in 0..num_mask.min(n) {
122            if avail.is_empty() { break; }
123            let rp: Vec<f32> = avail.iter().map(|&i| probs[i]).collect();
124            let s: f32 = rp.iter().sum();
125            if s <= 0.0 { break; }
126            let thr = rng.gen::<f32>() * s;
127            let mut cum = 0.0;
128            let mut chosen = 0;
129            for (pos, &p) in rp.iter().enumerate() {
130                cum += p;
131                if cum >= thr { chosen = pos; break; }
132            }
133            mask_indices.push(avail[chosen]);
134            avail.swap_remove(chosen);
135        }
136        let mut masked = features.to_vec();
137        for &i in &mask_indices { masked[i] = self.mask_token.clone(); }
138        MaskResult { masked_features: masked, mask_indices }
139    }
140}
141
142/// Single GAT layer with residual connection and layer normalization.
/// Single GAT layer with residual connection and layer normalization.
struct GATLayer {
    // Shared linear projection applied to every node before attention.
    linear: Linear,
    // Attention parameter vector of per-head dimension for the source (center)
    // node; the same vector is reused by every head in `forward`.
    attn_src: Vec<f32>,
    // Counterpart attention vector applied to each neighbor's projected slice.
    attn_dst: Vec<f32>,
    // Applied after aggregation + residual, before the ELU activation.
    norm: LayerNorm,
    num_heads: usize,
}
150
151impl GATLayer {
152    fn new(input_dim: usize, output_dim: usize, num_heads: usize) -> Self {
153        let mut rng = rand::thread_rng();
154        let hd = output_dim / num_heads.max(1);
155        Self {
156            linear: Linear::new(input_dim, output_dim),
157            attn_src: (0..hd).map(|_| rng.gen::<f32>() * 0.1).collect(),
158            attn_dst: (0..hd).map(|_| rng.gen::<f32>() * 0.1).collect(),
159            norm: LayerNorm::new(output_dim, 1e-5),
160            num_heads,
161        }
162    }
163
164    fn forward(&self, features: &[Vec<f32>], adj: &[Vec<usize>]) -> Vec<Vec<f32>> {
165        let proj: Vec<Vec<f32>> = features.iter().map(|f| self.linear.forward(f)).collect();
166        let od = proj.first().map_or(0, |v| v.len());
167        let hd = od / self.num_heads.max(1);
168        let mut output = Vec::with_capacity(features.len());
169        for i in 0..features.len() {
170            if adj[i].is_empty() {
171                output.push(elu_vec(&proj[i]));
172                continue;
173            }
174            let mut agg = vec![0.0f32; od];
175            for h in 0..self.num_heads {
176                let (s, e) = (h * hd, (h + 1) * hd);
177                let ss: f32 = proj[i][s..e].iter().zip(&self.attn_src).map(|(a, b)| a * b).sum();
178                let mut scores: Vec<f32> = adj[i].iter().map(|&j| {
179                    let ds: f32 = proj[j][s..e].iter().zip(&self.attn_dst).map(|(a, b)| a * b).sum();
180                    let v = ss + ds;
181                    if v >= 0.0 { v } else { 0.2 * v } // leaky relu
182                }).collect();
183                let mx = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
184                let exp: Vec<f32> = scores.iter_mut().map(|v| (*v - mx).exp()).collect();
185                let sm = exp.iter().sum::<f32>().max(1e-10);
186                for (k, &j) in adj[i].iter().enumerate() {
187                    let w = exp[k] / sm;
188                    for d in s..e { agg[d] += w * proj[j][d]; }
189                }
190            }
191            for v in &mut agg { *v /= self.num_heads as f32; }
192            if features[i].len() == od {
193                for (a, &f) in agg.iter_mut().zip(features[i].iter()) { *a += f; }
194            }
195            output.push(elu_vec(&self.norm.forward(&agg)));
196        }
197        output
198    }
199}
200
201/// Multi-layer GAT encoder for GraphMAE.
202pub struct GATEncoder { layers: Vec<GATLayer> }
203
204impl GATEncoder {
205    /// Build an encoder with `num_layers` GAT layers.
206    pub fn new(input_dim: usize, hidden_dim: usize, num_layers: usize, num_heads: usize) -> Self {
207        let layers = (0..num_layers).map(|i| {
208            GATLayer::new(if i == 0 { input_dim } else { hidden_dim }, hidden_dim, num_heads)
209        }).collect();
210        Self { layers }
211    }
212
213    /// Encode node features through all GAT layers.
214    pub fn encode(&self, features: &[Vec<f32>], adj: &[Vec<usize>]) -> Vec<Vec<f32>> {
215        self.layers.iter().fold(features.to_vec(), |h, l| l.forward(&h, adj))
216    }
217}
218
219/// Decoder that reconstructs only masked node features (key efficiency gain).
220pub struct GraphMAEDecoder { layers: Vec<Linear>, norm: LayerNorm }
221
222impl GraphMAEDecoder {
223    /// Create a decoder mapping `hidden_dim` -> `output_dim`.
224    pub fn new(hidden_dim: usize, output_dim: usize, num_layers: usize) -> Self {
225        let n = num_layers.max(1);
226        let layers = (0..n).map(|i| {
227            let out = if i == n - 1 { output_dim } else { hidden_dim };
228            Linear::new(if i == 0 { hidden_dim } else { hidden_dim }, out)
229        }).collect();
230        Self { layers, norm: LayerNorm::new(output_dim, 1e-5) }
231    }
232
233    /// Decode latent for masked nodes. Applies re-masking (zeroing dims) for regularization.
234    pub fn decode(&self, latent: &[Vec<f32>], mask_idx: &[usize], re_mask: f32) -> Vec<Vec<f32>> {
235        let mut rng = rand::thread_rng();
236        mask_idx.iter().map(|&idx| {
237            let mut h = latent[idx].clone();
238            if re_mask > 0.0 {
239                let nz = ((h.len() as f32) * re_mask).round() as usize;
240                let mut dims: Vec<usize> = (0..h.len()).collect();
241                dims.shuffle(&mut rng);
242                for &d in dims.iter().take(nz) { h[d] = 0.0; }
243            }
244            for layer in &self.layers { h = elu_vec(&layer.forward(&h)); }
245            self.norm.forward(&h)
246        }).collect()
247    }
248}
249
250/// Scaled Cosine Error: `mean((1 - cos_sim(pred, target))^gamma)` over masked nodes.
/// Scaled Cosine Error: `mean((1 - cos_sim(pred, target))^gamma)` over masked nodes.
///
/// Returns 0.0 for empty input. Norms are floored at `1e-8` and the cosine
/// similarity clamped to `[-1, 1]` for numerical safety.
pub fn sce_loss(preds: &[Vec<f32>], targets: &[Vec<f32>], gamma: f32) -> f32 {
    if preds.is_empty() {
        return 0.0;
    }
    let mut total = 0.0f32;
    for (pred, tgt) in preds.iter().zip(targets) {
        let mut dot = 0.0f32;
        let mut pp = 0.0f32;
        let mut tt = 0.0f32;
        for (&a, &b) in pred.iter().zip(tgt) {
            dot += a * b;
            pp += a * a;
            tt += b * b;
        }
        let denom = pp.sqrt().max(1e-8) * tt.sqrt().max(1e-8);
        let cos = (dot / denom).clamp(-1.0, 1.0);
        total += (1.0 - cos).powf(gamma);
    }
    total / preds.len() as f32
}
260
261/// Mean Squared Error across masked node reconstructions.
/// Mean Squared Error across masked node reconstructions.
///
/// Averages squared differences over the total number of prediction elements;
/// returns 0.0 for empty or zero-length input.
pub fn mse_loss(preds: &[Vec<f32>], targets: &[Vec<f32>]) -> f32 {
    if preds.is_empty() {
        return 0.0;
    }
    let count: usize = preds.iter().map(Vec::len).sum();
    if count == 0 {
        return 0.0;
    }
    let mut acc = 0.0f32;
    for (pred, tgt) in preds.iter().zip(targets) {
        for (&a, &b) in pred.iter().zip(tgt) {
            let d = a - b;
            acc += d * d;
        }
    }
    acc / count as f32
}
270
271/// GraphMAE self-supervised model.
272pub struct GraphMAE {
273    config: GraphMAEConfig,
274    masking: FeatureMasking,
275    encoder: GATEncoder,
276    decoder: GraphMAEDecoder,
277}
278
279impl GraphMAE {
280    /// Construct a new GraphMAE model from configuration.
281    ///
282    /// # Errors
283    /// Returns `GnnError::LayerConfig` if dimensions are incompatible.
284    pub fn new(config: GraphMAEConfig) -> Result<Self, GnnError> {
285        if config.hidden_dim % config.num_heads != 0 {
286            return Err(GnnError::layer_config(format!(
287                "hidden_dim ({}) must be divisible by num_heads ({})",
288                config.hidden_dim, config.num_heads
289            )));
290        }
291        if !(0.0..=1.0).contains(&config.mask_ratio) {
292            return Err(GnnError::layer_config("mask_ratio must be in [0.0, 1.0]"));
293        }
294        let masking = FeatureMasking::new(config.input_dim);
295        let encoder = GATEncoder::new(config.input_dim, config.hidden_dim, config.num_layers, config.num_heads);
296        let decoder = GraphMAEDecoder::new(config.hidden_dim, config.input_dim, config.decoder_layers);
297        Ok(Self { config, masking, encoder, decoder })
298    }
299
300    /// Run one training step: mask -> encode -> re-mask -> decode -> loss.
301    /// Returns the reconstruction loss computed only on masked nodes.
302    pub fn train_step(&self, graph: &GraphData) -> f32 {
303        let mr = self.masking.mask_nodes(&graph.node_features, self.config.mask_ratio);
304        let latent = self.encoder.encode(&mr.masked_features, &graph.adjacency);
305        let recon = self.decoder.decode(&latent, &mr.mask_indices, self.config.re_mask_ratio);
306        let targets: Vec<Vec<f32>> = mr.mask_indices.iter().map(|&i| graph.node_features[i].clone()).collect();
307        match self.config.loss_fn {
308            LossFn::Sce { gamma } => sce_loss(&recon, &targets, gamma),
309            LossFn::Mse => mse_loss(&recon, &targets),
310        }
311    }
312
313    /// Encode without masking (inference mode). Returns latent embeddings for all nodes.
314    pub fn encode(&self, graph: &GraphData) -> Vec<Vec<f32>> {
315        self.encoder.encode(&graph.node_features, &graph.adjacency)
316    }
317
318    /// Returns node-level representations for downstream tasks.
319    pub fn get_embeddings(&self, graph: &GraphData) -> Vec<Vec<f32>> { self.encode(graph) }
320}
321
/// Element-wise ELU activation: identity for non-negative inputs, `exp(x) - 1` otherwise.
fn elu_vec(v: &[f32]) -> Vec<f32> {
    let mut out = Vec::with_capacity(v.len());
    for &x in v {
        out.push(if x < 0.0 { x.exp() - 1.0 } else { x });
    }
    out
}
325
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a path graph (0-1-2-...-n-1) with deterministic, distinct features.
    fn graph(n: usize, d: usize) -> GraphData {
        let feats: Vec<Vec<f32>> = (0..n)
            .map(|i| (0..d).map(|j| (i * d + j) as f32 * 0.1).collect()).collect();
        let adj: Vec<Vec<usize>> = (0..n).map(|i| {
            let mut nb = Vec::new();
            if i > 0 { nb.push(i - 1); }
            if i + 1 < n { nb.push(i + 1); }
            nb
        }).collect();
        GraphData { node_features: feats, adjacency: adj, num_nodes: n }
    }

    // Small valid config (hidden_dim 16 divisible by 4 heads) for the given input dim.
    fn cfg(dim: usize) -> GraphMAEConfig {
        GraphMAEConfig {
            input_dim: dim, hidden_dim: 16, num_heads: 4, num_layers: 2,
            decoder_layers: 1, mask_ratio: 0.5, re_mask_ratio: 0.0, loss_fn: LossFn::default(),
        }
    }

    // mask_nodes with ratio 0.3 over 100 nodes should mask ~30 (±1 for rounding).
    #[test]
    fn test_masking_ratio() {
        let feats: Vec<Vec<f32>> = (0..100).map(|i| vec![i as f32; 8]).collect();
        let m = FeatureMasking::new(8);
        let r = m.mask_nodes(&feats, 0.3);
        assert!((r.mask_indices.len() as i32 - 30).unsigned_abs() <= 1);
    }

    // Encoder preserves node count and emits hidden_dim-sized embeddings.
    #[test]
    fn test_encoder_forward() {
        let g = graph(5, 16);
        let enc = GATEncoder::new(16, 16, 2, 4);
        let out = enc.encode(&g.node_features, &g.adjacency);
        assert_eq!(out.len(), 5);
        assert_eq!(out[0].len(), 16);
    }

    // Decoder returns one output_dim-sized vector per masked index.
    #[test]
    fn test_decoder_reconstruction_shape() {
        let dec = GraphMAEDecoder::new(16, 8, 1);
        let lat: Vec<Vec<f32>> = (0..5).map(|_| vec![0.5; 16]).collect();
        let r = dec.decode(&lat, &[0, 2, 4], 0.0);
        assert_eq!(r.len(), 3);
        assert_eq!(r[0].len(), 8);
    }

    // cos_sim = 1 for identical vectors, so (1 - 1)^gamma must be ~0.
    #[test]
    fn test_sce_loss_identical() {
        let loss = sce_loss(&[vec![1.0, 0.0, 0.0]], &[vec![1.0, 0.0, 0.0]], 2.0);
        assert!(loss < 1e-6, "SCE identical should be ~0, got {loss}");
    }

    // cos_sim = 0 for orthogonal vectors, so loss is (1 - 0)^2 = 1.
    #[test]
    fn test_sce_loss_orthogonal() {
        let loss = sce_loss(&[vec![1.0, 0.0]], &[vec![0.0, 1.0]], 2.0);
        assert!((loss - 1.0).abs() < 1e-5, "SCE orthogonal should be 1.0, got {loss}");
    }

    // MSE: 0 for identical inputs, 1.0 for unit-offset inputs.
    #[test]
    fn test_mse_loss() {
        assert!(mse_loss(&[vec![1.0, 2.0]], &[vec![1.0, 2.0]]) < 1e-8);
        assert!((mse_loss(&[vec![0.0, 0.0]], &[vec![1.0, 1.0]]) - 1.0).abs() < 1e-6);
    }

    // Full pipeline smoke test: loss must be finite and non-negative.
    #[test]
    fn test_train_step_returns_finite_loss() {
        let model = GraphMAE::new(cfg(16)).unwrap();
        let loss = model.train_step(&graph(10, 16));
        assert!(loss.is_finite() && loss >= 0.0, "bad loss: {loss}");
    }

    // With re_mask 0.8 most latent dims are zeroed, so the decode must differ
    // from the unmasked decode of the same all-ones latent.
    #[test]
    fn test_re_masking() {
        let dec = GraphMAEDecoder::new(16, 8, 1);
        let lat = vec![vec![1.0; 16]; 3];
        let a = dec.decode(&lat, &[0, 1, 2], 0.0);
        let b = dec.decode(&lat, &[0, 1, 2], 0.8);
        let diff: f32 = a[0].iter().zip(&b[0]).map(|(x, y)| (x - y).abs()).sum();
        assert!(diff > 1e-6, "re-masking should change output");
    }

    // Star graph (node 0 is the hub); degree masking still masks exactly n/2 nodes.
    #[test]
    fn test_degree_based_masking() {
        let feats: Vec<Vec<f32>> = (0..10).map(|_| vec![1.0; 8]).collect();
        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); 10];
        for i in 1..10 { adj[0].push(i); adj[i].push(0); }
        let r = FeatureMasking::new(8).mask_by_degree(&feats, &adj, 0.5);
        assert_eq!(r.mask_indices.len(), 5);
    }

    // Edge case: a single isolated node must not panic or produce NaN/inf loss.
    #[test]
    fn test_single_node_graph() {
        let g = GraphData { node_features: vec![vec![1.0; 16]], adjacency: vec![vec![]], num_nodes: 1 };
        assert!(GraphMAE::new(cfg(16)).unwrap().train_step(&g).is_finite());
    }

    // Inference path: embeddings have the right shape and contain only finite values.
    #[test]
    fn test_encode_for_downstream() {
        let model = GraphMAE::new(cfg(16)).unwrap();
        let emb = model.get_embeddings(&graph(8, 16));
        assert_eq!(emb.len(), 8);
        assert_eq!(emb[0].len(), 16);
        for e in &emb { for &v in e { assert!(v.is_finite()); } }
    }

    // Constructor must reject hidden_dim not divisible by num_heads and
    // out-of-range mask ratios.
    #[test]
    fn test_invalid_config() {
        assert!(GraphMAE::new(GraphMAEConfig { hidden_dim: 15, num_heads: 4, ..cfg(16) }).is_err());
        assert!(GraphMAE::new(GraphMAEConfig { mask_ratio: 1.5, ..cfg(16) }).is_err());
    }
}