// scirs2_graph/ssl/masked_autoencoder.rs
//! Graph Masked Autoencoder (GraphMAE, Hou et al. 2022).
//!
//! GraphMAE is a generative self-supervised learning method that:
//!
//! 1. **Masks** a random subset of node features by replacing them with a
//!    learnable mask token.
//! 2. **Encodes** the masked graph with a GNN encoder (here approximated by a
//!    linear layer + ReLU for simplicity).
//! 3. **Decodes** the latent representation back to the original feature space
//!    via a lightweight decoder.
//! 4. Computes the **scaled cosine error (SCE)** reconstruction loss only over
//!    the masked nodes:
//!
//! ```text
//! L_SCE = (1/|M|) Σ_{i∈M} (1 - cos_sim(ẑ_i, h_i))^γ
//! ```
//!
//! where `ẑ_i` is the reconstructed feature and `h_i` is the original feature,
//! and `γ ≥ 1` controls the sharpness of the penalty.
//!
//! ## Reference
//! Hou, Z., Liu, X., Cen, Y., Dong, Y., Yang, H., Wang, C., & Tang, J. (2022).
//! *GraphMAE: Self-Supervised Masked Graph Autoencoders.* KDD 2022.

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{Rng, RngExt, SeedableRng};

// ============================================================================
// Configuration
// ============================================================================

/// Configuration for the Graph Masked Autoencoder.
#[derive(Debug, Clone)]
pub struct GraphMaeConfig {
    /// Fraction of nodes whose features are masked during training.
    /// Each node is masked independently with this probability, so it is
    /// expected to lie in `[0, 1]`.
    pub mask_rate: f64,
    /// Encoder output (latent) dimension.
    pub encoder_dim: usize,
    /// Decoder output dimension (must equal the input feature dimension).
    /// NOTE(review): the visible implementation always decodes back to the
    /// model's `feature_dim` and never reads this field — confirm whether it
    /// is intended to be informational only.
    pub decoder_dim: usize,
    /// Scale of the random initialisation for the mask replacement token
    /// (drawn uniformly from `[-scale, +scale]`).
    pub replace_token_scale: f64,
}

45impl Default for GraphMaeConfig {
46    fn default() -> Self {
47        Self {
48            mask_rate: 0.25,
49            encoder_dim: 64,
50            decoder_dim: 64,
51            replace_token_scale: 0.1,
52        }
53    }
54}
55
// ============================================================================
// GraphMae
// ============================================================================

/// Graph Masked Autoencoder.
///
/// Maintains three learnable parameter sets, all initialised in [`GraphMae::new`]:
/// - A learnable **mask token** of shape `[feature_dim]` used to replace masked
///   node features.
/// - An **encoder weight** matrix `[feature_dim × encoder_dim]`.
/// - A **decoder weight** matrix `[encoder_dim × feature_dim]`.
pub struct GraphMae {
    /// Learnable mask replacement token: `[feature_dim]`
    mask_token: Array1<f64>,
    /// Encoder weight: `[feature_dim × encoder_dim]`
    encoder_weight: Array2<f64>,
    /// Decoder weight: `[encoder_dim × feature_dim]`
    decoder_weight: Array2<f64>,
    /// Feature dimension (= decoder output dimension).
    feature_dim: usize,
    /// Hyper-parameters used at construction and masking time.
    config: GraphMaeConfig,
}

impl GraphMae {
    /// Construct a new GraphMAE with randomly initialised parameters.
    ///
    /// All parameters are drawn from a single `ChaCha20` RNG seeded with
    /// `seed`, so construction is fully reproducible; the *order* of the
    /// three initialisations below is therefore significant.
    ///
    /// # Arguments
    /// * `feature_dim` – dimension of each node's input features
    /// * `config`      – MAE hyper-parameters
    /// * `seed`        – RNG seed for reproducible initialisation
    pub fn new(feature_dim: usize, config: GraphMaeConfig, seed: u64) -> Self {
        let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(seed);

        // Mask token: uniform in [-replace_token_scale, +replace_token_scale]
        let s = config.replace_token_scale;
        let mask_token = Array1::from_shape_fn(feature_dim, |_| rng.random::<f64>() * 2.0 * s - s);

        // Encoder: Xavier/Glorot uniform [feature_dim × encoder_dim],
        // i.e. uniform in ±sqrt(6 / (fan_in + fan_out)).
        let enc_scale = (6.0 / (feature_dim + config.encoder_dim) as f64).sqrt();
        let encoder_weight = Array2::from_shape_fn((feature_dim, config.encoder_dim), |_| {
            rng.random::<f64>() * 2.0 * enc_scale - enc_scale
        });

        // Decoder: Xavier uniform [encoder_dim × feature_dim].
        // NOTE(review): the decoder always maps back to `feature_dim`;
        // `config.decoder_dim` is not read anywhere in this impl — confirm
        // whether that field is intended to be informational only.
        let dec_scale = (6.0 / (config.encoder_dim + feature_dim) as f64).sqrt();
        let decoder_weight = Array2::from_shape_fn((config.encoder_dim, feature_dim), |_| {
            rng.random::<f64>() * 2.0 * dec_scale - dec_scale
        });

        GraphMae {
            mask_token,
            encoder_weight,
            decoder_weight,
            feature_dim,
            config,
        }
    }

    /// Apply random feature masking.
    ///
    /// Each node is masked independently with probability `config.mask_rate`;
    /// a masked node's entire feature row is overwritten with the learnable
    /// mask token.
    ///
    /// # Arguments
    /// * `features` – node feature matrix `[n_nodes × feature_dim]`
    /// * `seed`     – RNG seed (different from the model seed so each call can
    ///   produce a different mask)
    ///
    /// # Returns
    /// `(masked_features, mask_indices)` where `mask_indices` contains the
    /// row indices of the masked nodes (sorted ascending).
    pub fn mask_features(&self, features: &Array2<f64>, seed: u64) -> (Array2<f64>, Vec<usize>) {
        let n_nodes = features.dim().0;
        let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(seed);

        let mut masked = features.clone();
        let mut mask_indices = Vec::new();

        for i in 0..n_nodes {
            // Independent Bernoulli(mask_rate) draw per node; exactly one RNG
            // draw per node keeps the masking pattern reproducible for a seed.
            if rng.random::<f64>() < self.config.mask_rate {
                mask_indices.push(i);
                for d in 0..self.feature_dim {
                    masked[[i, d]] = self.mask_token[d];
                }
            }
        }

        // Indices are pushed in increasing order already; the sort is a cheap
        // safety net guaranteeing the documented ascending ordering.
        mask_indices.sort_unstable();
        (masked, mask_indices)
    }

    /// Encode node features: `Z = ReLU(X @ W_enc)`
    ///
    /// A naive triple-loop mat-mul; accumulation order over `d` is fixed, so
    /// results are bit-for-bit deterministic.
    ///
    /// # Arguments
    /// * `features` – (possibly masked) node features `[n_nodes × feature_dim]`
    ///
    /// # Returns
    /// Latent representations `[n_nodes × encoder_dim]`
    pub fn encode(&self, features: &Array2<f64>) -> Array2<f64> {
        let n_nodes = features.dim().0;
        let enc_dim = self.config.encoder_dim;

        let mut z = Array2::zeros((n_nodes, enc_dim));
        for i in 0..n_nodes {
            for k in 0..enc_dim {
                let mut val = 0.0;
                for d in 0..self.feature_dim {
                    val += features[[i, d]] * self.encoder_weight[[d, k]];
                }
                z[[i, k]] = if val > 0.0 { val } else { 0.0 }; // ReLU
            }
        }
        z
    }

    /// Decode latent representations: `X̂ = Z @ W_dec`
    ///
    /// Linear decoder with no activation, mapping back to the input feature
    /// space.
    ///
    /// # Arguments
    /// * `encoded` – latent representations `[n_nodes × encoder_dim]`
    ///
    /// # Returns
    /// Reconstructed features `[n_nodes × feature_dim]`
    pub fn decode(&self, encoded: &Array2<f64>) -> Array2<f64> {
        let n_nodes = encoded.dim().0;

        let mut out = Array2::zeros((n_nodes, self.feature_dim));
        for i in 0..n_nodes {
            for d in 0..self.feature_dim {
                let mut val = 0.0;
                for k in 0..self.config.encoder_dim {
                    val += encoded[[i, k]] * self.decoder_weight[[k, d]];
                }
                out[[i, d]] = val;
            }
        }
        out
    }

    /// Scaled Cosine Error (SCE) reconstruction loss on masked nodes.
    ///
    /// ```text
    /// L = (1/|M|) Σ_{i∈M} (1 - cosine_sim(reconstructed_i, original_i))^γ
    /// ```
    ///
    /// If `mask_indices` is empty, returns `0.0`.
    ///
    /// # Arguments
    /// * `original`       – original node features `[n_nodes × feature_dim]`
    /// * `reconstructed`  – decoder output `[n_nodes × feature_dim]`
    /// * `mask_indices`   – indices of masked nodes
    /// * `gamma`          – exponent ≥ 1 (typical: 2 or 3)
    pub fn sce_loss(
        &self,
        original: &Array2<f64>,
        reconstructed: &Array2<f64>,
        mask_indices: &[usize],
        gamma: f64,
    ) -> f64 {
        if mask_indices.is_empty() {
            return 0.0;
        }

        let mut total = 0.0;
        let d = self.feature_dim;

        for &i in mask_indices {
            // Dot product and squared norms of the two rows in one pass.
            let mut dot = 0.0;
            let mut norm_r = 0.0;
            let mut norm_o = 0.0;
            for k in 0..d {
                let r = reconstructed[[i, k]];
                let o = original[[i, k]];
                dot += r * o;
                norm_r += r * r;
                norm_o += o * o;
            }
            // Epsilon floor on each norm guards against division by zero for
            // all-zero rows; the clamp absorbs floating-point drift outside
            // the valid cosine range [-1, 1].
            let denom = norm_r.sqrt().max(1e-12) * norm_o.sqrt().max(1e-12);
            let cos_sim = (dot / denom).clamp(-1.0, 1.0);
            // SCE: (1 - cos_sim)^gamma
            let term = (1.0 - cos_sim).powf(gamma);
            total += term;
        }

        total / mask_indices.len() as f64
    }

    /// Full GraphMAE forward pass.
    ///
    /// 1. Mask features randomly.
    /// 2. Encode masked features.
    /// 3. Decode back to feature space.
    /// 4. Compute SCE loss over masked nodes (γ = 2).
    ///
    /// # Arguments
    /// * `features` – original node feature matrix `[n_nodes × feature_dim]`
    /// * `seed`     – RNG seed for the masking step
    ///
    /// # Returns
    /// `(reconstructed_features, sce_loss)`
    pub fn forward(&self, features: &Array2<f64>, seed: u64) -> (Array2<f64>, f64) {
        let (masked, mask_indices) = self.mask_features(features, seed);
        let encoded = self.encode(&masked);
        let reconstructed = self.decode(&encoded);
        // Loss compares the reconstruction against the *unmasked* originals.
        let loss = self.sce_loss(features, &reconstructed, &mask_indices, 2.0);
        (reconstructed, loss)
    }

    /// The learnable mask token vector `[feature_dim]`.
    pub fn mask_token(&self) -> &Array1<f64> {
        &self.mask_token
    }

    /// Input / output feature dimension.
    pub fn feature_dim(&self) -> usize {
        self.feature_dim
    }

    /// Encoder output dimension.
    pub fn encoder_dim(&self) -> usize {
        self.config.encoder_dim
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

288    fn make_mae(feature_dim: usize, mask_rate: f64) -> GraphMae {
289        let cfg = GraphMaeConfig {
290            mask_rate,
291            encoder_dim: 16,
292            decoder_dim: feature_dim,
293            replace_token_scale: 0.1,
294        };
295        GraphMae::new(feature_dim, cfg, 42)
296    }
297
298    #[test]
299    fn test_mask_features_approximate_rate() {
300        let mae = make_mae(8, 0.5);
301        let x = Array2::ones((100, 8));
302        let (_, mask_idx) = mae.mask_features(&x, 0);
303        // With mask_rate=0.5 and 100 nodes, expected ~50 masked;
304        // allow generous tolerance for a stochastic test
305        let frac = mask_idx.len() as f64 / 100.0;
306        assert!(
307            (frac - 0.5).abs() < 0.2,
308            "masking fraction {frac} too far from 0.5"
309        );
310    }
311
312    #[test]
313    fn test_encode_output_shape() {
314        let mae = make_mae(8, 0.25);
315        let x = Array2::ones((10, 8));
316        let z = mae.encode(&x);
317        assert_eq!(z.dim(), (10, 16));
318    }
319
320    #[test]
321    fn test_decode_output_shape_matches_feature_dim() {
322        let mae = make_mae(8, 0.25);
323        let z = Array2::ones((10, 16));
324        let out = mae.decode(&z);
325        assert_eq!(out.dim(), (10, 8));
326    }
327
328    #[test]
329    fn test_sce_loss_identical_is_zero() {
330        let mae = make_mae(4, 0.25);
331        let x = Array2::from_shape_fn((6, 4), |(i, j)| (i + j + 1) as f64);
332        // Compute something that equals x by passing x as reconstructed
333        let loss = mae.sce_loss(&x, &x, &[0, 1, 2, 3, 4, 5], 2.0);
334        assert!(
335            loss.abs() < 1e-9,
336            "SCE loss for identical tensors should be ~0, got {loss}"
337        );
338    }
339
340    #[test]
341    fn test_sce_loss_orthogonal_positive() {
342        let mae = make_mae(4, 0.25);
343        // Two orthogonal vectors: cos_sim = 0 → loss per element = 1.0
344        let original = Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0])
345            .expect("ok");
346        let recon = Array2::from_shape_vec((2, 4), vec![0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
347            .expect("ok");
348        let loss = mae.sce_loss(&original, &recon, &[0, 1], 2.0);
349        // cos_sim = 0 → (1-0)^2 = 1 for each node
350        assert!(
351            (loss - 1.0).abs() < 1e-9,
352            "SCE loss for orthogonal vectors should be 1.0, got {loss}"
353        );
354    }
355
356    #[test]
357    fn test_forward_output_shape_consistency() {
358        let mae = make_mae(8, 0.25);
359        let x = Array2::ones((12, 8));
360        let (recon, _loss) = mae.forward(&x, 0);
361        assert_eq!(recon.dim(), (12, 8));
362    }
363
364    #[test]
365    fn test_mask_rate_zero_nothing_masked() {
366        let mae = make_mae(4, 0.0);
367        let x = Array2::ones((20, 4));
368        let (_, idx) = mae.mask_features(&x, 0);
369        assert!(idx.is_empty(), "mask_rate=0 should mask no nodes");
370        // Loss should be 0 since no nodes are masked
371        let encoded = mae.encode(&x);
372        let recon = mae.decode(&encoded);
373        let loss = mae.sce_loss(&x, &recon, &idx, 2.0);
374        assert_eq!(loss, 0.0);
375    }
376
377    #[test]
378    fn test_mask_rate_one_all_masked() {
379        let mae = make_mae(4, 1.0);
380        let x = Array2::ones((10, 4));
381        let (_, idx) = mae.mask_features(&x, 0);
382        assert_eq!(idx.len(), 10, "mask_rate=1 should mask all nodes");
383    }
384
385    #[test]
386    fn test_forward_loss_is_finite() {
387        let mae = make_mae(8, 0.3);
388        let x = Array2::from_shape_fn((20, 8), |(i, j)| (i as f64 * 0.1) + (j as f64 * 0.01));
389        let (_recon, loss) = mae.forward(&x, 7);
390        assert!(loss.is_finite(), "forward loss must be finite, got {loss}");
391        assert!(loss >= 0.0, "SCE loss must be non-negative, got {loss}");
392    }
393}