1use crate::error::GnnError;
16use crate::layer::{LayerNorm, Linear};
17use rand::seq::SliceRandom;
18use rand::Rng;
19
/// Reconstruction loss used when scoring decoded features against the
/// original (pre-mask) node features.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LossFn {
    /// Scaled cosine error: mean over masked nodes of `(1 - cos_sim)^gamma`.
    Sce { gamma: f32 },
    /// Plain mean squared error over all reconstructed feature elements.
    Mse,
}

impl Default for LossFn {
    /// GraphMAE's default objective: SCE with `gamma = 2.0`.
    fn default() -> Self { Self::Sce { gamma: 2.0 } }
}
33
/// Hyper-parameters for [`GraphMAE`].
#[derive(Debug, Clone)]
pub struct GraphMAEConfig {
    /// Fraction of nodes to mask per training step; must lie in [0.0, 1.0].
    pub mask_ratio: f32,
    /// Number of GAT layers in the encoder.
    pub num_layers: usize,
    /// Encoder hidden width; must be divisible by `num_heads`.
    pub hidden_dim: usize,
    /// Attention heads per GAT layer.
    pub num_heads: usize,
    /// Number of linear layers in the decoder (clamped to at least 1).
    pub decoder_layers: usize,
    /// Fraction of latent dimensions zeroed again before decoding (0 disables).
    pub re_mask_ratio: f32,
    /// Reconstruction objective.
    pub loss_fn: LossFn,
    /// Dimensionality of the raw node features.
    pub input_dim: usize,
}

impl Default for GraphMAEConfig {
    fn default() -> Self {
        Self {
            mask_ratio: 0.5, num_layers: 2, hidden_dim: 64, num_heads: 4,
            decoder_layers: 1, re_mask_ratio: 0.0, loss_fn: LossFn::default(), input_dim: 64,
        }
    }
}
63
/// A graph given as per-node feature vectors plus adjacency lists.
#[derive(Debug, Clone)]
pub struct GraphData {
    /// One feature vector per node; indexed by node id.
    pub node_features: Vec<Vec<f32>>,
    /// `adjacency[i]` lists the neighbor node ids of node `i`.
    pub adjacency: Vec<Vec<usize>>,
    /// Node count. NOTE(review): not validated against the two Vec lengths
    /// anywhere in this file — confirm callers keep them in sync.
    pub num_nodes: usize,
}
74
/// Output of a masking pass over a graph's node features.
#[derive(Debug, Clone)]
pub struct MaskResult {
    /// Copy of the input features with masked rows replaced by the mask token.
    pub masked_features: Vec<Vec<f32>>,
    /// Node indices that were masked (order is the selection order).
    pub mask_indices: Vec<usize>,
}
83
/// Replaces selected node features with a shared, randomly-initialized
/// `[MASK]` token vector.
pub struct FeatureMasking {
    // Shared mask token; each component drawn uniformly from [-0.01, 0.01).
    mask_token: Vec<f32>,
}
88
89impl FeatureMasking {
90 pub fn new(dim: usize) -> Self {
92 let mut rng = rand::thread_rng();
93 Self { mask_token: (0..dim).map(|_| rng.gen::<f32>() * 0.02 - 0.01).collect() }
94 }
95
96 pub fn mask_nodes(&self, features: &[Vec<f32>], mask_ratio: f32) -> MaskResult {
98 let n = features.len();
99 let num_mask = ((n as f32) * mask_ratio.clamp(0.0, 1.0)).round() as usize;
100 let mut rng = rand::thread_rng();
101 let mut indices: Vec<usize> = (0..n).collect();
102 indices.shuffle(&mut rng);
103 let mask_indices = indices[..num_mask.min(n)].to_vec();
104 let mut masked = features.to_vec();
105 for &i in &mask_indices { masked[i] = self.mask_token.clone(); }
106 MaskResult { masked_features: masked, mask_indices }
107 }
108
109 pub fn mask_by_degree(
111 &self, features: &[Vec<f32>], adjacency: &[Vec<usize>], mask_ratio: f32,
112 ) -> MaskResult {
113 let n = features.len();
114 let num_mask = ((n as f32) * mask_ratio.clamp(0.0, 1.0)).round() as usize;
115 let degrees: Vec<f32> = adjacency.iter().map(|a| a.len() as f32 + 1.0).collect();
116 let total: f32 = degrees.iter().sum();
117 let probs: Vec<f32> = degrees.iter().map(|d| d / total).collect();
118 let mut rng = rand::thread_rng();
119 let mut avail: Vec<usize> = (0..n).collect();
120 let mut mask_indices = Vec::with_capacity(num_mask);
121 for _ in 0..num_mask.min(n) {
122 if avail.is_empty() { break; }
123 let rp: Vec<f32> = avail.iter().map(|&i| probs[i]).collect();
124 let s: f32 = rp.iter().sum();
125 if s <= 0.0 { break; }
126 let thr = rng.gen::<f32>() * s;
127 let mut cum = 0.0;
128 let mut chosen = 0;
129 for (pos, &p) in rp.iter().enumerate() {
130 cum += p;
131 if cum >= thr { chosen = pos; break; }
132 }
133 mask_indices.push(avail[chosen]);
134 avail.swap_remove(chosen);
135 }
136 let mut masked = features.to_vec();
137 for &i in &mask_indices { masked[i] = self.mask_token.clone(); }
138 MaskResult { masked_features: masked, mask_indices }
139 }
140}
141
/// One graph-attention layer: linear projection, per-head additive
/// attention over neighbors, residual (when shapes match), and LayerNorm.
struct GATLayer {
    linear: Linear,
    // Per-head attention vectors of length output_dim / num_heads.
    // NOTE(review): the same pair is reused by every head — only the
    // projected feature slice differs per head. Confirm intentional.
    attn_src: Vec<f32>,
    attn_dst: Vec<f32>,
    norm: LayerNorm,
    num_heads: usize,
}
150
impl GATLayer {
    /// Builds a layer whose `output_dim` is split evenly across `num_heads`
    /// heads (`hd = output_dim / num_heads`); attention weights are drawn
    /// uniformly from [0, 0.1).
    fn new(input_dim: usize, output_dim: usize, num_heads: usize) -> Self {
        let mut rng = rand::thread_rng();
        let hd = output_dim / num_heads.max(1);
        Self {
            linear: Linear::new(input_dim, output_dim),
            attn_src: (0..hd).map(|_| rng.gen::<f32>() * 0.1).collect(),
            attn_dst: (0..hd).map(|_| rng.gen::<f32>() * 0.1).collect(),
            norm: LayerNorm::new(output_dim, 1e-5),
            num_heads,
        }
    }

    /// Forward pass: for each node, score every neighbor per head with
    /// LeakyReLU(0.2) of (attn_src . h_i[slice] + attn_dst . h_j[slice]),
    /// softmax the scores, and accumulate the neighbors' projected slices.
    /// Connected nodes then get an identity residual (only when the input
    /// width equals the output width), LayerNorm, and ELU.
    fn forward(&self, features: &[Vec<f32>], adj: &[Vec<usize>]) -> Vec<Vec<f32>> {
        let proj: Vec<Vec<f32>> = features.iter().map(|f| self.linear.forward(f)).collect();
        let od = proj.first().map_or(0, |v| v.len());
        let hd = od / self.num_heads.max(1);
        let mut output = Vec::with_capacity(features.len());
        for i in 0..features.len() {
            // Isolated node: nothing to attend over, pass the projection
            // through ELU directly.
            // NOTE(review): this path skips the residual and LayerNorm that
            // connected nodes receive — confirm the asymmetry is intended.
            if adj[i].is_empty() {
                output.push(elu_vec(&proj[i]));
                continue;
            }
            let mut agg = vec![0.0f32; od];
            for h in 0..self.num_heads {
                // Each head owns the disjoint slice [s, e) of the output.
                let (s, e) = (h * hd, (h + 1) * hd);
                let ss: f32 = proj[i][s..e].iter().zip(&self.attn_src).map(|(a, b)| a * b).sum();
                // Raw attention logits with LeakyReLU (negative slope 0.2).
                let mut scores: Vec<f32> = adj[i].iter().map(|&j| {
                    let ds: f32 = proj[j][s..e].iter().zip(&self.attn_dst).map(|(a, b)| a * b).sum();
                    let v = ss + ds;
                    if v >= 0.0 { v } else { 0.2 * v } }).collect();
                // Numerically stable softmax: subtract the max logit.
                let mx = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let exp: Vec<f32> = scores.iter_mut().map(|v| (*v - mx).exp()).collect();
                let sm = exp.iter().sum::<f32>().max(1e-10);
                for (k, &j) in adj[i].iter().enumerate() {
                    let w = exp[k] / sm;
                    for d in s..e { agg[d] += w * proj[j][d]; }
                }
            }
            // NOTE(review): heads write disjoint slices, so this division
            // uniformly scales every head's output by 1/num_heads rather
            // than averaging overlapping head outputs. LayerNorm below
            // partially compensates, but the residual is added at this
            // reduced scale — confirm this is the intended weighting.
            for v in &mut agg { *v /= self.num_heads as f32; }
            if features[i].len() == od {
                for (a, &f) in agg.iter_mut().zip(features[i].iter()) { *a += f; }
            }
            output.push(elu_vec(&self.norm.forward(&agg)));
        }
        output
    }
}
200
201pub struct GATEncoder { layers: Vec<GATLayer> }
203
204impl GATEncoder {
205 pub fn new(input_dim: usize, hidden_dim: usize, num_layers: usize, num_heads: usize) -> Self {
207 let layers = (0..num_layers).map(|i| {
208 GATLayer::new(if i == 0 { input_dim } else { hidden_dim }, hidden_dim, num_heads)
209 }).collect();
210 Self { layers }
211 }
212
213 pub fn encode(&self, features: &[Vec<f32>], adj: &[Vec<usize>]) -> Vec<Vec<f32>> {
215 self.layers.iter().fold(features.to_vec(), |h, l| l.forward(&h, adj))
216 }
217}
218
219pub struct GraphMAEDecoder { layers: Vec<Linear>, norm: LayerNorm }
221
222impl GraphMAEDecoder {
223 pub fn new(hidden_dim: usize, output_dim: usize, num_layers: usize) -> Self {
225 let n = num_layers.max(1);
226 let layers = (0..n).map(|i| {
227 let out = if i == n - 1 { output_dim } else { hidden_dim };
228 Linear::new(if i == 0 { hidden_dim } else { hidden_dim }, out)
229 }).collect();
230 Self { layers, norm: LayerNorm::new(output_dim, 1e-5) }
231 }
232
233 pub fn decode(&self, latent: &[Vec<f32>], mask_idx: &[usize], re_mask: f32) -> Vec<Vec<f32>> {
235 let mut rng = rand::thread_rng();
236 mask_idx.iter().map(|&idx| {
237 let mut h = latent[idx].clone();
238 if re_mask > 0.0 {
239 let nz = ((h.len() as f32) * re_mask).round() as usize;
240 let mut dims: Vec<usize> = (0..h.len()).collect();
241 dims.shuffle(&mut rng);
242 for &d in dims.iter().take(nz) { h[d] = 0.0; }
243 }
244 for layer in &self.layers { h = elu_vec(&layer.forward(&h)); }
245 self.norm.forward(&h)
246 }).collect()
247 }
248}
249
/// Scaled cosine error: the mean over prediction/target pairs of
/// `(1 - cos_sim)^gamma`, with norms floored at 1e-8 and the cosine
/// clamped to [-1, 1]. Returns 0.0 for empty input.
pub fn sce_loss(preds: &[Vec<f32>], targets: &[Vec<f32>], gamma: f32) -> f32 {
    if preds.is_empty() {
        return 0.0;
    }
    let mut total = 0.0f32;
    for (p, t) in preds.iter().zip(targets.iter()) {
        let mut dot = 0.0f32;
        for (&a, &b) in p.iter().zip(t.iter()) {
            dot += a * b;
        }
        let mut p_sq = 0.0f32;
        for &x in p {
            p_sq += x * x;
        }
        let mut t_sq = 0.0f32;
        for &x in t {
            t_sq += x * x;
        }
        let denom = p_sq.sqrt().max(1e-8) * t_sq.sqrt().max(1e-8);
        let cos = (dot / denom).clamp(-1.0, 1.0);
        total += (1.0 - cos).powf(gamma);
    }
    total / preds.len() as f32
}
260
/// Mean squared error averaged over the total element count of `preds`.
/// Returns 0.0 when there are no predictions or no elements.
pub fn mse_loss(preds: &[Vec<f32>], targets: &[Vec<f32>]) -> f32 {
    if preds.is_empty() {
        return 0.0;
    }
    let mut count = 0usize;
    for p in preds {
        count += p.len();
    }
    if count == 0 {
        return 0.0;
    }
    let mut sum = 0.0f32;
    for (p, t) in preds.iter().zip(targets.iter()) {
        for (&a, &b) in p.iter().zip(t.iter()) {
            let d = a - b;
            sum += d * d;
        }
    }
    sum / count as f32
}
270
/// Masked graph autoencoder: masks node features, encodes with a GAT
/// stack, and reconstructs the masked features with an MLP decoder.
pub struct GraphMAE {
    config: GraphMAEConfig,
    masking: FeatureMasking,
    encoder: GATEncoder,
    decoder: GraphMAEDecoder,
}
278
279impl GraphMAE {
280 pub fn new(config: GraphMAEConfig) -> Result<Self, GnnError> {
285 if config.hidden_dim % config.num_heads != 0 {
286 return Err(GnnError::layer_config(format!(
287 "hidden_dim ({}) must be divisible by num_heads ({})",
288 config.hidden_dim, config.num_heads
289 )));
290 }
291 if !(0.0..=1.0).contains(&config.mask_ratio) {
292 return Err(GnnError::layer_config("mask_ratio must be in [0.0, 1.0]"));
293 }
294 let masking = FeatureMasking::new(config.input_dim);
295 let encoder = GATEncoder::new(config.input_dim, config.hidden_dim, config.num_layers, config.num_heads);
296 let decoder = GraphMAEDecoder::new(config.hidden_dim, config.input_dim, config.decoder_layers);
297 Ok(Self { config, masking, encoder, decoder })
298 }
299
300 pub fn train_step(&self, graph: &GraphData) -> f32 {
303 let mr = self.masking.mask_nodes(&graph.node_features, self.config.mask_ratio);
304 let latent = self.encoder.encode(&mr.masked_features, &graph.adjacency);
305 let recon = self.decoder.decode(&latent, &mr.mask_indices, self.config.re_mask_ratio);
306 let targets: Vec<Vec<f32>> = mr.mask_indices.iter().map(|&i| graph.node_features[i].clone()).collect();
307 match self.config.loss_fn {
308 LossFn::Sce { gamma } => sce_loss(&recon, &targets, gamma),
309 LossFn::Mse => mse_loss(&recon, &targets),
310 }
311 }
312
313 pub fn encode(&self, graph: &GraphData) -> Vec<Vec<f32>> {
315 self.encoder.encode(&graph.node_features, &graph.adjacency)
316 }
317
318 pub fn get_embeddings(&self, graph: &GraphData) -> Vec<Vec<f32>> { self.encode(graph) }
320}
321
/// Element-wise ELU (alpha = 1): `x` for `x >= 0`, `e^x - 1` otherwise.
fn elu_vec(v: &[f32]) -> Vec<f32> {
    // exp_m1 computes e^x - 1 directly, avoiding the cancellation that
    // `x.exp() - 1.0` suffers for x near zero.
    v.iter().map(|&x| if x >= 0.0 { x } else { x.exp_m1() }).collect()
}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a path graph 0-1-2-...-(n-1) with deterministic features
    /// (node i, dim j gets value (i*d + j) * 0.1).
    fn graph(n: usize, d: usize) -> GraphData {
        let feats: Vec<Vec<f32>> = (0..n)
            .map(|i| (0..d).map(|j| (i * d + j) as f32 * 0.1).collect()).collect();
        let adj: Vec<Vec<usize>> = (0..n).map(|i| {
            let mut nb = Vec::new();
            if i > 0 { nb.push(i - 1); }
            if i + 1 < n { nb.push(i + 1); }
            nb
        }).collect();
        GraphData { node_features: feats, adjacency: adj, num_nodes: n }
    }

    /// Small valid config (hidden_dim 16 divisible by 4 heads).
    fn cfg(dim: usize) -> GraphMAEConfig {
        GraphMAEConfig {
            input_dim: dim, hidden_dim: 16, num_heads: 4, num_layers: 2,
            decoder_layers: 1, mask_ratio: 0.5, re_mask_ratio: 0.0, loss_fn: LossFn::default(),
        }
    }

    // 30% of 100 nodes should mask 30 (±1 for rounding slack).
    #[test]
    fn test_masking_ratio() {
        let feats: Vec<Vec<f32>> = (0..100).map(|i| vec![i as f32; 8]).collect();
        let m = FeatureMasking::new(8);
        let r = m.mask_nodes(&feats, 0.3);
        assert!((r.mask_indices.len() as i32 - 30).unsigned_abs() <= 1);
    }

    // Encoder preserves node count and emits hidden_dim-wide vectors.
    #[test]
    fn test_encoder_forward() {
        let g = graph(5, 16);
        let enc = GATEncoder::new(16, 16, 2, 4);
        let out = enc.encode(&g.node_features, &g.adjacency);
        assert_eq!(out.len(), 5);
        assert_eq!(out[0].len(), 16);
    }

    // Decoder returns one output_dim-wide vector per masked index.
    #[test]
    fn test_decoder_reconstruction_shape() {
        let dec = GraphMAEDecoder::new(16, 8, 1);
        let lat: Vec<Vec<f32>> = (0..5).map(|_| vec![0.5; 16]).collect();
        let r = dec.decode(&lat, &[0, 2, 4], 0.0);
        assert_eq!(r.len(), 3);
        assert_eq!(r[0].len(), 8);
    }

    #[test]
    fn test_sce_loss_identical() {
        let loss = sce_loss(&[vec![1.0, 0.0, 0.0]], &[vec![1.0, 0.0, 0.0]], 2.0);
        assert!(loss < 1e-6, "SCE identical should be ~0, got {loss}");
    }

    #[test]
    fn test_sce_loss_orthogonal() {
        let loss = sce_loss(&[vec![1.0, 0.0]], &[vec![0.0, 1.0]], 2.0);
        assert!((loss - 1.0).abs() < 1e-5, "SCE orthogonal should be 1.0, got {loss}");
    }

    #[test]
    fn test_mse_loss() {
        assert!(mse_loss(&[vec![1.0, 2.0]], &[vec![1.0, 2.0]]) < 1e-8);
        assert!((mse_loss(&[vec![0.0, 0.0]], &[vec![1.0, 1.0]]) - 1.0).abs() < 1e-6);
    }

    // End-to-end: one train step must yield a finite, non-negative loss.
    #[test]
    fn test_train_step_returns_finite_loss() {
        let model = GraphMAE::new(cfg(16)).unwrap();
        let loss = model.train_step(&graph(10, 16));
        assert!(loss.is_finite() && loss >= 0.0, "bad loss: {loss}");
    }

    // Zeroing 80% of latent dims must perturb the decoder output.
    // NOTE(review): re-masking picks random dims, so with probability ~0
    // the chosen dims could leave the output unchanged; inputs here make
    // that practically impossible.
    #[test]
    fn test_re_masking() {
        let dec = GraphMAEDecoder::new(16, 8, 1);
        let lat = vec![vec![1.0; 16]; 3];
        let a = dec.decode(&lat, &[0, 1, 2], 0.0);
        let b = dec.decode(&lat, &[0, 1, 2], 0.8);
        let diff: f32 = a[0].iter().zip(&b[0]).map(|(x, y)| (x - y).abs()).sum();
        assert!(diff > 1e-6, "re-masking should change output");
    }

    // Star graph (node 0 connected to all): degree-weighted masking still
    // selects exactly round(10 * 0.5) = 5 distinct nodes.
    #[test]
    fn test_degree_based_masking() {
        let feats: Vec<Vec<f32>> = (0..10).map(|_| vec![1.0; 8]).collect();
        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); 10];
        for i in 1..10 { adj[0].push(i); adj[i].push(0); }
        let r = FeatureMasking::new(8).mask_by_degree(&feats, &adj, 0.5);
        assert_eq!(r.mask_indices.len(), 5);
    }

    // A single isolated node must not crash the pipeline.
    #[test]
    fn test_single_node_graph() {
        let g = GraphData { node_features: vec![vec![1.0; 16]], adjacency: vec![vec![]], num_nodes: 1 };
        assert!(GraphMAE::new(cfg(16)).unwrap().train_step(&g).is_finite());
    }

    // Embeddings for downstream use: right shape, all finite.
    #[test]
    fn test_encode_for_downstream() {
        let model = GraphMAE::new(cfg(16)).unwrap();
        let emb = model.get_embeddings(&graph(8, 16));
        assert_eq!(emb.len(), 8);
        assert_eq!(emb[0].len(), 16);
        for e in &emb { for &v in e { assert!(v.is_finite()); } }
    }

    // Invalid configs (non-divisible heads, out-of-range mask_ratio) must Err.
    #[test]
    fn test_invalid_config() {
        assert!(GraphMAE::new(GraphMAEConfig { hidden_dim: 15, num_heads: 4, ..cfg(16) }).is_err());
        assert!(GraphMAE::new(GraphMAEConfig { mask_ratio: 1.5, ..cfg(16) }).is_err());
    }
}