1use scirs2_core::ndarray::{Array1, Array2};
26use scirs2_core::random::{Rng, RngExt, SeedableRng};
27
/// Hyper-parameters for the [`GraphMae`] masked-autoencoder model.
#[derive(Debug, Clone)]
pub struct GraphMaeConfig {
    /// Probability that each node's feature row is replaced by the mask token.
    pub mask_rate: f64,
    /// Width of the encoder output (node embedding dimension).
    pub encoder_dim: usize,
    /// Intended decoder width.
    /// NOTE(review): not read anywhere in this file — the decoder maps
    /// `encoder_dim` straight back to `feature_dim`; confirm whether this
    /// field is consumed elsewhere.
    pub decoder_dim: usize,
    /// Half-width `s` of the uniform range `[-s, s]` the mask token is drawn from.
    pub replace_token_scale: f64,
}
44
45impl Default for GraphMaeConfig {
46 fn default() -> Self {
47 Self {
48 mask_rate: 0.25,
49 encoder_dim: 64,
50 decoder_dim: 64,
51 replace_token_scale: 0.1,
52 }
53 }
54}
55
/// Graph masked autoencoder: masks a random subset of node feature rows,
/// encodes the corrupted features, and reconstructs the originals.
pub struct GraphMae {
    // Replacement vector substituted for the features of masked nodes.
    mask_token: Array1<f64>,
    // Linear encoder weights, shape (feature_dim, encoder_dim).
    encoder_weight: Array2<f64>,
    // Linear decoder weights, shape (encoder_dim, feature_dim).
    decoder_weight: Array2<f64>,
    // Dimensionality of the input node features.
    feature_dim: usize,
    // Masking / architecture hyper-parameters.
    config: GraphMaeConfig,
}
78
79impl GraphMae {
80 pub fn new(feature_dim: usize, config: GraphMaeConfig, seed: u64) -> Self {
87 let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(seed);
88
89 let s = config.replace_token_scale;
91 let mask_token = Array1::from_shape_fn(feature_dim, |_| rng.random::<f64>() * 2.0 * s - s);
92
93 let enc_scale = (6.0 / (feature_dim + config.encoder_dim) as f64).sqrt();
95 let encoder_weight = Array2::from_shape_fn((feature_dim, config.encoder_dim), |_| {
96 rng.random::<f64>() * 2.0 * enc_scale - enc_scale
97 });
98
99 let dec_scale = (6.0 / (config.encoder_dim + feature_dim) as f64).sqrt();
101 let decoder_weight = Array2::from_shape_fn((config.encoder_dim, feature_dim), |_| {
102 rng.random::<f64>() * 2.0 * dec_scale - dec_scale
103 });
104
105 GraphMae {
106 mask_token,
107 encoder_weight,
108 decoder_weight,
109 feature_dim,
110 config,
111 }
112 }
113
114 pub fn mask_features(&self, features: &Array2<f64>, seed: u64) -> (Array2<f64>, Vec<usize>) {
128 let n_nodes = features.dim().0;
129 let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(seed);
130
131 let mut masked = features.clone();
132 let mut mask_indices = Vec::new();
133
134 for i in 0..n_nodes {
135 if rng.random::<f64>() < self.config.mask_rate {
136 mask_indices.push(i);
137 for d in 0..self.feature_dim {
138 masked[[i, d]] = self.mask_token[d];
139 }
140 }
141 }
142
143 mask_indices.sort_unstable();
144 (masked, mask_indices)
145 }
146
147 pub fn encode(&self, features: &Array2<f64>) -> Array2<f64> {
155 let n_nodes = features.dim().0;
156 let enc_dim = self.config.encoder_dim;
157
158 let mut z = Array2::zeros((n_nodes, enc_dim));
159 for i in 0..n_nodes {
160 for k in 0..enc_dim {
161 let mut val = 0.0;
162 for d in 0..self.feature_dim {
163 val += features[[i, d]] * self.encoder_weight[[d, k]];
164 }
165 z[[i, k]] = if val > 0.0 { val } else { 0.0 }; }
167 }
168 z
169 }
170
171 pub fn decode(&self, encoded: &Array2<f64>) -> Array2<f64> {
179 let n_nodes = encoded.dim().0;
180
181 let mut out = Array2::zeros((n_nodes, self.feature_dim));
182 for i in 0..n_nodes {
183 for d in 0..self.feature_dim {
184 let mut val = 0.0;
185 for k in 0..self.config.encoder_dim {
186 val += encoded[[i, k]] * self.decoder_weight[[k, d]];
187 }
188 out[[i, d]] = val;
189 }
190 }
191 out
192 }
193
194 pub fn sce_loss(
208 &self,
209 original: &Array2<f64>,
210 reconstructed: &Array2<f64>,
211 mask_indices: &[usize],
212 gamma: f64,
213 ) -> f64 {
214 if mask_indices.is_empty() {
215 return 0.0;
216 }
217
218 let mut total = 0.0;
219 let d = self.feature_dim;
220
221 for &i in mask_indices {
222 let mut dot = 0.0;
224 let mut norm_r = 0.0;
225 let mut norm_o = 0.0;
226 for k in 0..d {
227 let r = reconstructed[[i, k]];
228 let o = original[[i, k]];
229 dot += r * o;
230 norm_r += r * r;
231 norm_o += o * o;
232 }
233 let denom = norm_r.sqrt().max(1e-12) * norm_o.sqrt().max(1e-12);
234 let cos_sim = (dot / denom).clamp(-1.0, 1.0);
235 let term = (1.0 - cos_sim).powf(gamma);
237 total += term;
238 }
239
240 total / mask_indices.len() as f64
241 }
242
243 pub fn forward(&self, features: &Array2<f64>, seed: u64) -> (Array2<f64>, f64) {
257 let (masked, mask_indices) = self.mask_features(features, seed);
258 let encoded = self.encode(&masked);
259 let reconstructed = self.decode(&encoded);
260 let loss = self.sce_loss(features, &reconstructed, &mask_indices, 2.0);
261 (reconstructed, loss)
262 }
263
264 pub fn mask_token(&self) -> &Array1<f64> {
266 &self.mask_token
267 }
268
269 pub fn feature_dim(&self) -> usize {
271 self.feature_dim
272 }
273
274 pub fn encoder_dim(&self) -> usize {
276 self.config.encoder_dim
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a small model with a 16-wide encoder and a fixed seed.
    fn make_mae(feature_dim: usize, mask_rate: f64) -> GraphMae {
        GraphMae::new(
            feature_dim,
            GraphMaeConfig {
                mask_rate,
                encoder_dim: 16,
                decoder_dim: feature_dim,
                replace_token_scale: 0.1,
            },
            42,
        )
    }

    #[test]
    fn test_mask_features_approximate_rate() {
        let model = make_mae(8, 0.5);
        let features = Array2::ones((100, 8));
        let (_, masked_nodes) = model.mask_features(&features, 0);
        let frac = masked_nodes.len() as f64 / 100.0;
        assert!(
            (frac - 0.5).abs() < 0.2,
            "masking fraction {frac} too far from 0.5"
        );
    }

    #[test]
    fn test_encode_output_shape() {
        let model = make_mae(8, 0.25);
        let embeddings = model.encode(&Array2::ones((10, 8)));
        assert_eq!(embeddings.dim(), (10, 16));
    }

    #[test]
    fn test_decode_output_shape_matches_feature_dim() {
        let model = make_mae(8, 0.25);
        let reconstruction = model.decode(&Array2::ones((10, 16)));
        assert_eq!(reconstruction.dim(), (10, 8));
    }

    #[test]
    fn test_sce_loss_identical_is_zero() {
        let model = make_mae(4, 0.25);
        let data = Array2::from_shape_fn((6, 4), |(i, j)| (i + j + 1) as f64);
        let loss = model.sce_loss(&data, &data, &[0, 1, 2, 3, 4, 5], 2.0);
        assert!(
            loss.abs() < 1e-9,
            "SCE loss for identical tensors should be ~0, got {loss}"
        );
    }

    #[test]
    fn test_sce_loss_orthogonal_positive() {
        let model = make_mae(4, 0.25);
        let src = Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0])
            .expect("ok");
        let rec = Array2::from_shape_vec((2, 4), vec![0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
            .expect("ok");
        let loss = model.sce_loss(&src, &rec, &[0, 1], 2.0);
        assert!(
            (loss - 1.0).abs() < 1e-9,
            "SCE loss for orthogonal vectors should be 1.0, got {loss}"
        );
    }

    #[test]
    fn test_forward_output_shape_consistency() {
        let model = make_mae(8, 0.25);
        let (recon, _loss) = model.forward(&Array2::ones((12, 8)), 0);
        assert_eq!(recon.dim(), (12, 8));
    }

    #[test]
    fn test_mask_rate_zero_nothing_masked() {
        let model = make_mae(4, 0.0);
        let features = Array2::ones((20, 4));
        let (_, idx) = model.mask_features(&features, 0);
        assert!(idx.is_empty(), "mask_rate=0 should mask no nodes");
        let recon = model.decode(&model.encode(&features));
        assert_eq!(model.sce_loss(&features, &recon, &idx, 2.0), 0.0);
    }

    #[test]
    fn test_mask_rate_one_all_masked() {
        let model = make_mae(4, 1.0);
        let (_, idx) = model.mask_features(&Array2::ones((10, 4)), 0);
        assert_eq!(idx.len(), 10, "mask_rate=1 should mask all nodes");
    }

    #[test]
    fn test_forward_loss_is_finite() {
        let model = make_mae(8, 0.3);
        let data = Array2::from_shape_fn((20, 8), |(i, j)| (i as f64 * 0.1) + (j as f64 * 0.01));
        let (_recon, loss) = model.forward(&data, 7);
        assert!(loss.is_finite(), "forward loss must be finite, got {loss}");
        assert!(loss >= 0.0, "SCE loss must be non-negative, got {loss}");
    }
}