// scirs2_graph/ssl/pretrain.rs
1//! Graph pre-training strategies for self-supervised learning.
2//!
3//! This module provides three complementary pre-training objectives:
4//!
5//! | Struct | Strategy | Reference |
6//! |---|---|---|
7//! | [`NodeMaskingPretrainer`] | BERT-style node attribute masking | Hu et al. 2020 (Strategies for Pre-training Graph Neural Networks) |
8//! | [`GraphContextPretrainer`] | Subgraph-context InfoNCE contrastive | Hu et al. 2020 |
9//! | [`AttributeReconstructionObjective`] | MAE-style attribute reconstruction | Hou et al. 2022 (GraphMAE) |
10
11use crate::error::{GraphError, Result};
12
13// ── Tiny LCG RNG (no external crate) ─────────────────────────────────────────
14
/// Minimal 64-bit linear congruential generator (Knuth's MMIX constants),
/// used so this module needs no external RNG crate.
///
/// Not cryptographically secure; intended only for reproducible sampling.
struct Lcg(u64);

impl Lcg {
    /// Create a generator; the seed is xor-scrambled so that small seeds
    /// (0, 1, 2, …) still start from a well-mixed state.
    fn new(seed: u64) -> Self {
        Self(seed ^ 0xdeadbeefcafe1234)
    }

    /// Advance the LCG by one step and return the new raw state.
    ///
    /// Shared by `next_f64` and `next_usize` so the update constants live in
    /// exactly one place.
    fn step(&mut self) -> u64 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.0
    }

    /// Return next pseudo-random f64 in [0, 1).
    fn next_f64(&mut self) -> f64 {
        // Use high 53 bits so the value fits a double's mantissa exactly.
        let bits = self.step() >> 11;
        bits as f64 / (1u64 << 53) as f64
    }

    /// Return next usize in 0..bound.
    ///
    /// `bound` must be non-zero (callers always pass `i + 1`); a zero bound
    /// would divide by zero in the modulo below.
    fn next_usize(&mut self, bound: usize) -> usize {
        debug_assert!(bound > 0, "next_usize bound must be non-zero");
        // High bits of an LCG have better statistical quality than low bits.
        ((self.step() >> 33) as usize) % bound
    }
}
40
41// ─────────────────────────────────────────────────────────────────────────────
42// A.  NodeMaskingPretrainer
43// ─────────────────────────────────────────────────────────────────────────────
44
/// Configuration for BERT-style node attribute masking.
///
/// `mask_rate` is validated to lie in (0, 1] by
/// [`NodeMaskingPretrainer::mask_nodes`]; the other fields are not validated.
#[derive(Debug, Clone)]
pub struct NodeMaskingConfig {
    /// Fraction of nodes to mask. Default 0.15.
    pub mask_rate: f64,
    /// Of the masked nodes, fraction replaced with a random feature vector
    /// (sampled from U(-1, 1)) instead of a zero mask vector. Default 0.1.
    pub replace_rate: f64,
    /// Number of BFS hops used for context features (informational only;
    /// not read by `mask_nodes`). Default 2.
    pub n_neighbors: usize,
    /// Feature dimensionality. Default 64.
    /// NOTE(review): `mask_nodes` infers the actual dimension from the input
    /// matrix; this field is informational only.
    pub feature_dim: usize,
}
59
impl Default for NodeMaskingConfig {
    /// Defaults follow the BERT-style recipe: mask 15% of nodes and replace
    /// 10% of the masked ones with random vectors.
    fn default() -> Self {
        Self {
            mask_rate: 0.15,
            replace_rate: 0.1,
            n_neighbors: 2,
            feature_dim: 64,
        }
    }
}
70
/// BERT-style node attribute masking pre-trainer.
///
/// Masks a random subset of node feature vectors (zeroing them, or replacing
/// them with random vectors) so a model can be trained to reconstruct the
/// original attributes.
///
/// # Usage
/// ```rust,no_run
/// use scirs2_graph::ssl::pretrain::{NodeMaskingPretrainer, NodeMaskingConfig};
///
/// let cfg = NodeMaskingConfig { feature_dim: 4, ..Default::default() };
/// let pretrainer = NodeMaskingPretrainer::new(cfg);
/// let features = vec![vec![1.0, 2.0, 3.0, 4.0]; 10];
/// let (masked, indices) = pretrainer.mask_nodes(&features, 42).unwrap();
/// ```
pub struct NodeMaskingPretrainer {
    /// Masking hyper-parameters (see [`NodeMaskingConfig`]).
    config: NodeMaskingConfig,
}
85
86impl NodeMaskingPretrainer {
87    /// Create a new pretrainer with the given configuration.
88    pub fn new(config: NodeMaskingConfig) -> Self {
89        Self { config }
90    }
91
92    /// Apply attribute masking to the node feature matrix.
93    ///
94    /// # Arguments
95    /// * `features` – `n_nodes × feature_dim` feature matrix.
96    /// * `rng_seed` – seed for reproducibility.
97    ///
98    /// # Returns
99    /// `(masked_features, masked_indices)` where `masked_features` has the
100    /// selected nodes zeroed out (or replaced with random vectors), and
101    /// `masked_indices` is the sorted list of masked node indices.
102    ///
103    /// # Errors
104    /// Returns `GraphError::InvalidParameter` if feature vectors have
105    /// inconsistent lengths or if `mask_rate` is outside (0, 1].
106    pub fn mask_nodes(
107        &self,
108        features: &[Vec<f64>],
109        rng_seed: u64,
110    ) -> Result<(Vec<Vec<f64>>, Vec<usize>)> {
111        let n = features.len();
112        if n == 0 {
113            return Ok((vec![], vec![]));
114        }
115        let dim = features[0].len();
116        if dim == 0 {
117            return Err(GraphError::InvalidParameter {
118                param: "features".to_string(),
119                value: "empty feature vectors".to_string(),
120                expected: "non-empty feature vectors".to_string(),
121                context: "NodeMaskingPretrainer::mask_nodes".to_string(),
122            });
123        }
124        for (i, f) in features.iter().enumerate() {
125            if f.len() != dim {
126                return Err(GraphError::InvalidParameter {
127                    param: format!("features[{i}]"),
128                    value: format!("length {}", f.len()),
129                    expected: format!("length {dim}"),
130                    context: "NodeMaskingPretrainer::mask_nodes".to_string(),
131                });
132            }
133        }
134        if !(0.0 < self.config.mask_rate && self.config.mask_rate <= 1.0) {
135            return Err(GraphError::InvalidParameter {
136                param: "mask_rate".to_string(),
137                value: format!("{}", self.config.mask_rate),
138                expected: "value in (0, 1]".to_string(),
139                context: "NodeMaskingPretrainer::mask_nodes".to_string(),
140            });
141        }
142
143        let k = ((n as f64 * self.config.mask_rate).ceil() as usize).min(n);
144        let mut rng = Lcg::new(rng_seed);
145
146        // Randomly sample k distinct node indices (partial Fisher-Yates)
147        let mut indices: Vec<usize> = (0..n).collect();
148        for i in (n - k..n).rev() {
149            let j = rng.next_usize(i + 1);
150            indices.swap(i, j);
151        }
152        let mut masked_indices: Vec<usize> = indices[n - k..].to_vec();
153        masked_indices.sort_unstable();
154
155        // Build masked feature matrix
156        let mut masked = features.to_vec();
157        let masked_set: std::collections::HashSet<usize> = masked_indices.iter().cloned().collect();
158        for &node in &masked_indices {
159            let replace = rng.next_f64() < self.config.replace_rate;
160            masked[node] = if replace {
161                // Replace with random vector sampled from U(-1, 1)
162                (0..dim).map(|_| rng.next_f64() * 2.0 - 1.0).collect()
163            } else {
164                // Standard BERT masking: zero out
165                vec![0.0; dim]
166            };
167        }
168        // Suppress unused variable warning
169        let _ = masked_set;
170
171        Ok((masked, masked_indices))
172    }
173
174    /// Compute MSE reconstruction loss only on the masked nodes.
175    ///
176    /// # Arguments
177    /// * `predicted`     – reconstructed features (same shape as `original`).
178    /// * `original`      – ground-truth features.
179    /// * `masked_indices`– which node indices to include in the loss.
180    ///
181    /// # Errors
182    /// Returns `GraphError::InvalidParameter` if any index is out of range or
183    /// vectors have mismatched lengths.
184    pub fn reconstruction_loss(
185        &self,
186        predicted: &[Vec<f64>],
187        original: &[Vec<f64>],
188        masked_indices: &[usize],
189    ) -> Result<f64> {
190        if predicted.len() != original.len() {
191            return Err(GraphError::InvalidParameter {
192                param: "predicted / original".to_string(),
193                value: format!("lengths {} vs {}", predicted.len(), original.len()),
194                expected: "equal lengths".to_string(),
195                context: "NodeMaskingPretrainer::reconstruction_loss".to_string(),
196            });
197        }
198        if masked_indices.is_empty() {
199            return Ok(0.0);
200        }
201        let n = predicted.len();
202        let mut total = 0.0_f64;
203        let mut count = 0usize;
204        for &idx in masked_indices {
205            if idx >= n {
206                return Err(GraphError::InvalidParameter {
207                    param: "masked_indices".to_string(),
208                    value: format!("{idx}"),
209                    expected: format!("index < {n}"),
210                    context: "NodeMaskingPretrainer::reconstruction_loss".to_string(),
211                });
212            }
213            let p = &predicted[idx];
214            let o = &original[idx];
215            if p.len() != o.len() {
216                return Err(GraphError::InvalidParameter {
217                    param: format!("predicted[{idx}]"),
218                    value: format!("length {}", p.len()),
219                    expected: format!("length {}", o.len()),
220                    context: "NodeMaskingPretrainer::reconstruction_loss".to_string(),
221                });
222            }
223            for (a, b) in p.iter().zip(o.iter()) {
224                let diff = a - b;
225                total += diff * diff;
226                count += 1;
227            }
228        }
229        Ok(if count > 0 { total / count as f64 } else { 0.0 })
230    }
231}
232
233// ─────────────────────────────────────────────────────────────────────────────
234// B.  GraphContextPretrainer
235// ─────────────────────────────────────────────────────────────────────────────
236
/// Configuration for subgraph-context contrastive pre-training.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct GraphContextConfig {
    /// Maximum number of nodes to include in the context subgraph (BFS size).
    /// A value of 0 is treated as 1 by the sampler. Default 8.
    pub context_size: usize,
    /// Number of negative context samples per positive pair. Default 4.
    /// NOTE(review): not read by the methods in this module — presumably
    /// consumed by the surrounding training loop; confirm against callers.
    pub negative_samples: usize,
    /// Feature dimensionality. Default 64.
    /// NOTE(review): not read by the methods in this module.
    pub feature_dim: usize,
    /// Temperature for InfoNCE loss. Default 0.07.
    /// Note: [`GraphContextPretrainer::contrastive_loss`] takes temperature
    /// as an explicit argument; pass this value through.
    pub temperature: f64,
}
251
impl Default for GraphContextConfig {
    /// Defaults: 8-node contexts, 4 negatives, 64-dim features, τ = 0.07.
    fn default() -> Self {
        Self {
            context_size: 8,
            negative_samples: 4,
            feature_dim: 64,
            temperature: 0.07,
        }
    }
}
262
/// Graph-context contrastive pre-trainer.
///
/// Samples context subgraphs via BFS and maximises InfoNCE between a center
/// node embedding and its positive context while pushing away negative samples.
pub struct GraphContextPretrainer {
    /// Sampling / loss hyper-parameters (see [`GraphContextConfig`]).
    config: GraphContextConfig,
}
270
271impl GraphContextPretrainer {
272    /// Create a new pretrainer.
273    pub fn new(config: GraphContextConfig) -> Self {
274        Self { config }
275    }
276
277    /// Sample a context subgraph around `center` using BFS, limited to
278    /// `config.context_size` nodes (including the center).
279    ///
280    /// # Arguments
281    /// * `adj`    – undirected edge list as `(src, dst)` pairs.
282    /// * `center` – starting node.
283    /// * `n_nodes`– total number of nodes.
284    /// * `seed`   – RNG seed (used to break BFS frontier ties randomly).
285    ///
286    /// # Returns
287    /// A vector of node indices (sorted) in the context subgraph.
288    ///
289    /// # Errors
290    /// Returns `GraphError::InvalidParameter` if `center >= n_nodes`.
291    pub fn sample_context_subgraph(
292        &self,
293        adj: &[(usize, usize)],
294        center: usize,
295        n_nodes: usize,
296        seed: u64,
297    ) -> Result<Vec<usize>> {
298        if n_nodes == 0 {
299            return Ok(vec![]);
300        }
301        if center >= n_nodes {
302            return Err(GraphError::InvalidParameter {
303                param: "center".to_string(),
304                value: format!("{center}"),
305                expected: format!("index < {n_nodes}"),
306                context: "GraphContextPretrainer::sample_context_subgraph".to_string(),
307            });
308        }
309
310        // Build adjacency lists
311        let mut lists: Vec<Vec<usize>> = vec![Vec::new(); n_nodes];
312        for &(u, v) in adj {
313            if u < n_nodes && v < n_nodes && u != v {
314                lists[u].push(v);
315                lists[v].push(u);
316            }
317        }
318
319        let max_ctx = self.config.context_size.max(1);
320        let mut visited = vec![false; n_nodes];
321        let mut result = Vec::with_capacity(max_ctx);
322        let mut queue = std::collections::VecDeque::new();
323        let mut rng = Lcg::new(seed);
324
325        visited[center] = true;
326        queue.push_back(center);
327        result.push(center);
328
329        while let Some(v) = queue.pop_front() {
330            if result.len() >= max_ctx {
331                break;
332            }
333            // Shuffle neighbors to avoid deterministic bias
334            let mut nbrs = lists[v].clone();
335            for i in (1..nbrs.len()).rev() {
336                let j = rng.next_usize(i + 1);
337                nbrs.swap(i, j);
338            }
339            for w in nbrs {
340                if result.len() >= max_ctx {
341                    break;
342                }
343                if !visited[w] {
344                    visited[w] = true;
345                    result.push(w);
346                    queue.push_back(w);
347                }
348            }
349        }
350
351        result.sort_unstable();
352        Ok(result)
353    }
354
355    /// Compute InfoNCE (noise-contrastive estimation) loss.
356    ///
357    /// ```text
358    /// L = -log [ exp(sim(a, p) / τ) / (exp(sim(a, p) / τ) + Σ_i exp(sim(a, nᵢ) / τ)) ]
359    /// ```
360    ///
361    /// where `sim` is cosine similarity.
362    ///
363    /// # Arguments
364    /// * `anchor`      – anchor embedding vector.
365    /// * `positive`    – positive (context) embedding vector.
366    /// * `negatives`   – list of negative embedding vectors.
367    /// * `temperature` – temperature τ.
368    ///
369    /// # Errors
370    /// Returns `GraphError::InvalidParameter` if vectors have mismatched dims.
371    pub fn contrastive_loss(
372        &self,
373        anchor: &[f64],
374        positive: &[f64],
375        negatives: &[Vec<f64>],
376        temperature: f64,
377    ) -> Result<f64> {
378        infonce_loss(anchor, positive, negatives, temperature)
379    }
380}
381
/// Cosine similarity between two equal-length vectors.
///
/// Returns 0.0 whenever either vector has zero norm, so callers never see a
/// NaN from dividing by zero.
fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    let mut dot = 0.0_f64;
    let mut sq_a = 0.0_f64;
    let mut sq_b = 0.0_f64;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
394
395/// Standalone InfoNCE loss function.
396///
397/// # Errors
398/// Returns [`GraphError::InvalidParameter`] if dimension mismatches are found
399/// or `temperature <= 0`.
400pub fn infonce_loss(
401    anchor: &[f64],
402    positive: &[f64],
403    negatives: &[Vec<f64>],
404    temperature: f64,
405) -> Result<f64> {
406    let dim = anchor.len();
407    if dim == 0 {
408        return Err(GraphError::InvalidParameter {
409            param: "anchor".to_string(),
410            value: "empty".to_string(),
411            expected: "non-empty embedding vector".to_string(),
412            context: "infonce_loss".to_string(),
413        });
414    }
415    if positive.len() != dim {
416        return Err(GraphError::InvalidParameter {
417            param: "positive".to_string(),
418            value: format!("length {}", positive.len()),
419            expected: format!("length {dim}"),
420            context: "infonce_loss".to_string(),
421        });
422    }
423    if temperature <= 0.0 {
424        return Err(GraphError::InvalidParameter {
425            param: "temperature".to_string(),
426            value: format!("{temperature}"),
427            expected: "positive value".to_string(),
428            context: "infonce_loss".to_string(),
429        });
430    }
431    for (i, neg) in negatives.iter().enumerate() {
432        if neg.len() != dim {
433            return Err(GraphError::InvalidParameter {
434                param: format!("negatives[{i}]"),
435                value: format!("length {}", neg.len()),
436                expected: format!("length {dim}"),
437                context: "infonce_loss".to_string(),
438            });
439        }
440    }
441
442    let sim_pos = cosine_similarity(anchor, positive) / temperature;
443    // Numerically stable: subtract max before exp
444    let mut sims: Vec<f64> = std::iter::once(sim_pos)
445        .chain(
446            negatives
447                .iter()
448                .map(|n| cosine_similarity(anchor, n) / temperature),
449        )
450        .collect();
451    let max_sim = sims.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
452    for s in sims.iter_mut() {
453        *s = (*s - max_sim).exp();
454    }
455    let denom: f64 = sims.iter().sum();
456    let loss = -(sims[0].ln() - denom.ln());
457    Ok(loss)
458}
459
460// ─────────────────────────────────────────────────────────────────────────────
461// C.  AttributeReconstructionObjective
462// ─────────────────────────────────────────────────────────────────────────────
463
/// Configuration for the attribute reconstruction MLP.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct AttrReconConfig {
    /// Number of MLP layers. Default 2.
    pub n_layers: usize,
    /// Hidden layer dimension; a value of 0 is treated as 1 by
    /// [`AttributeReconstructionObjective::new`]. Default 128.
    pub hidden_dim: usize,
    /// Dropout rate (stored for future training use; not applied in the
    /// forward pass implemented in this module). Default 0.1.
    pub dropout: f64,
}
475
impl Default for AttrReconConfig {
    /// Defaults: a 2-layer MLP with 128 hidden units and 0.1 dropout.
    fn default() -> Self {
        Self {
            n_layers: 2,
            hidden_dim: 128,
            dropout: 0.1,
        }
    }
}
485
/// A weight matrix and bias vector for a single linear layer.
#[derive(Debug, Clone)]
struct LinearLayer {
    /// Weight matrix stored row-major: `out_dim × in_dim` (one row per
    /// output unit).
    weights: Vec<Vec<f64>>,
    /// Bias: `out_dim`.
    bias: Vec<f64>,
}
494
495impl LinearLayer {
496    /// Initialise with Xavier-uniform-like weights using the given seed.
497    fn new(in_dim: usize, out_dim: usize, seed: u64) -> Self {
498        let mut rng = Lcg::new(seed);
499        let scale = (6.0 / (in_dim + out_dim) as f64).sqrt();
500        let weights = (0..out_dim)
501            .map(|_| {
502                (0..in_dim)
503                    .map(|_| (rng.next_f64() * 2.0 - 1.0) * scale)
504                    .collect()
505            })
506            .collect();
507        let bias = vec![0.0f64; out_dim];
508        Self { weights, bias }
509    }
510
511    /// Apply this layer with tanh activation.  Returns `out_dim` values.
512    fn forward_tanh(&self, x: &[f64]) -> Vec<f64> {
513        self.weights
514            .iter()
515            .zip(self.bias.iter())
516            .map(|(row, b)| {
517                let pre: f64 = row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum::<f64>() + b;
518                pre.tanh()
519            })
520            .collect()
521    }
522
523    /// Apply this layer with no activation (output layer).
524    fn forward_linear(&self, x: &[f64]) -> Vec<f64> {
525        self.weights
526            .iter()
527            .zip(self.bias.iter())
528            .map(|(row, b)| row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum::<f64>() + b)
529            .collect()
530    }
531}
532
/// A simple multi-layer perceptron for attribute reconstruction.
///
/// Architecture: `input_dim → hidden_dim → … → input_dim` (tanh between
/// hidden layers, linear output).  No back-propagation is implemented here;
/// this struct provides the forward pass for inference / loss computation.
pub struct AttributeReconstructionObjective {
    /// MLP hyper-parameters (exposed read-only via [`Self::config`]).
    config: AttrReconConfig,
    /// Linear layers in application order.
    layers: Vec<LinearLayer>,
    /// Expected length of every input feature vector.
    input_dim: usize,
}
543
544impl AttributeReconstructionObjective {
545    /// Build a new MLP for `input_dim`-dimensional features.
546    ///
547    /// Layers are randomly initialised.  Use the same `seed` for
548    /// reproducibility.
549    ///
550    /// # Errors
551    /// Returns `GraphError::InvalidParameter` if `input_dim == 0` or
552    /// `n_layers == 0`.
553    pub fn new(input_dim: usize, config: AttrReconConfig, seed: u64) -> Result<Self> {
554        if input_dim == 0 {
555            return Err(GraphError::InvalidParameter {
556                param: "input_dim".to_string(),
557                value: "0".to_string(),
558                expected: "positive dimension".to_string(),
559                context: "AttributeReconstructionObjective::new".to_string(),
560            });
561        }
562        if config.n_layers == 0 {
563            return Err(GraphError::InvalidParameter {
564                param: "n_layers".to_string(),
565                value: "0".to_string(),
566                expected: "at least 1 layer".to_string(),
567                context: "AttributeReconstructionObjective::new".to_string(),
568            });
569        }
570        let hidden = config.hidden_dim.max(1);
571        let mut layers = Vec::with_capacity(config.n_layers);
572
573        // First layer: input → hidden
574        layers.push(LinearLayer::new(input_dim, hidden, seed));
575
576        // Intermediate hidden layers
577        for i in 1..config.n_layers.saturating_sub(1) {
578            layers.push(LinearLayer::new(
579                hidden,
580                hidden,
581                seed.wrapping_add(i as u64),
582            ));
583        }
584
585        // Final layer: hidden → input (reconstruction)
586        if config.n_layers > 1 {
587            layers.push(LinearLayer::new(
588                hidden,
589                input_dim,
590                seed.wrapping_add(config.n_layers as u64),
591            ));
592        }
593
594        Ok(Self {
595            config,
596            layers,
597            input_dim,
598        })
599    }
600
601    /// Run a forward pass for each node's feature vector.
602    ///
603    /// # Arguments
604    /// * `features` – `n_nodes × input_dim` feature matrix.
605    ///
606    /// # Returns
607    /// Reconstructed feature matrix of the same shape.
608    ///
609    /// # Errors
610    /// Returns `GraphError::InvalidParameter` if any feature vector has
611    /// length ≠ `input_dim`.
612    pub fn forward(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
613        features
614            .iter()
615            .enumerate()
616            .map(|(i, f)| {
617                if f.len() != self.input_dim {
618                    return Err(GraphError::InvalidParameter {
619                        param: format!("features[{i}]"),
620                        value: format!("length {}", f.len()),
621                        expected: format!("length {}", self.input_dim),
622                        context: "AttributeReconstructionObjective::forward".to_string(),
623                    });
624                }
625                let mut h = f.clone();
626                let last = self.layers.len().saturating_sub(1);
627                for (j, layer) in self.layers.iter().enumerate() {
628                    h = if j < last {
629                        layer.forward_tanh(&h)
630                    } else {
631                        layer.forward_linear(&h)
632                    };
633                }
634                Ok(h)
635            })
636            .collect()
637    }
638
639    /// Mean squared error between `predicted` and `target`.
640    ///
641    /// # Errors
642    /// Returns `GraphError::InvalidParameter` if shapes differ.
643    pub fn mse_loss(&self, predicted: &[Vec<f64>], target: &[Vec<f64>]) -> Result<f64> {
644        if predicted.len() != target.len() {
645            return Err(GraphError::InvalidParameter {
646                param: "predicted".to_string(),
647                value: format!("length {}", predicted.len()),
648                expected: format!("length {}", target.len()),
649                context: "AttributeReconstructionObjective::mse_loss".to_string(),
650            });
651        }
652        if predicted.is_empty() {
653            return Ok(0.0);
654        }
655        let mut total = 0.0_f64;
656        let mut count = 0usize;
657        for (p_row, t_row) in predicted.iter().zip(target.iter()) {
658            if p_row.len() != t_row.len() {
659                return Err(GraphError::InvalidParameter {
660                    param: "predicted row".to_string(),
661                    value: format!("length {}", p_row.len()),
662                    expected: format!("length {}", t_row.len()),
663                    context: "AttributeReconstructionObjective::mse_loss".to_string(),
664                });
665            }
666            for (a, b) in p_row.iter().zip(t_row.iter()) {
667                let diff = a - b;
668                total += diff * diff;
669                count += 1;
670            }
671        }
672        Ok(if count > 0 { total / count as f64 } else { 0.0 })
673    }
674
675    /// Access the underlying config.
676    pub fn config(&self) -> &AttrReconConfig {
677        &self.config
678    }
679}
680
681// ── Tests ─────────────────────────────────────────────────────────────────────
682
#[cfg(test)]
mod tests {
    use super::*;

    // ── NodeMaskingPretrainer ─────────────────────────────────────────────────

    #[test]
    fn test_masking_correct_fraction() {
        let n = 100;
        let dim = 8;
        let features: Vec<Vec<f64>> = (0..n).map(|i| vec![i as f64; dim]).collect();
        let cfg = NodeMaskingConfig {
            mask_rate: 0.15,
            replace_rate: 0.0,
            feature_dim: dim,
            ..Default::default()
        };
        let pretrainer = NodeMaskingPretrainer::new(cfg);
        let (_, indices) = pretrainer.mask_nodes(&features, 7).unwrap();
        // ceil(100 * 0.15) = 15
        assert_eq!(indices.len(), 15, "should mask exactly 15 nodes");
    }

    #[test]
    fn test_masking_features_differ() {
        let n = 20;
        let dim = 4;
        let features: Vec<Vec<f64>> = (0..n).map(|i| vec![(i + 1) as f64; dim]).collect();
        let cfg = NodeMaskingConfig {
            mask_rate: 0.5,
            replace_rate: 0.0,
            feature_dim: dim,
            ..Default::default()
        };
        let pretrainer = NodeMaskingPretrainer::new(cfg);
        let (masked, indices) = pretrainer.mask_nodes(&features, 99).unwrap();
        // With replace_rate 0, masked nodes should be all zeros
        for &idx in &indices {
            assert_eq!(masked[idx], vec![0.0; dim], "node {idx} should be zeroed");
        }
        // Unmasked nodes should be identical to originals
        for i in 0..n {
            if !indices.contains(&i) {
                assert_eq!(masked[i], features[i], "node {i} should be unchanged");
            }
        }
    }

    #[test]
    fn test_reconstruction_loss_finite_positive() {
        let n = 10;
        let dim = 6;
        let original: Vec<Vec<f64>> = (0..n).map(|i| vec![i as f64; dim]).collect();
        let cfg = NodeMaskingConfig {
            mask_rate: 0.3,
            feature_dim: dim,
            ..Default::default()
        };
        let pretrainer = NodeMaskingPretrainer::new(cfg);
        let (masked, indices) = pretrainer.mask_nodes(&original, 11).unwrap();
        let loss = pretrainer
            .reconstruction_loss(&masked, &original, &indices)
            .unwrap();
        assert!(loss.is_finite(), "loss should be finite");
        assert!(loss >= 0.0, "loss should be non-negative");
    }

    // ── GraphContextPretrainer ────────────────────────────────────────────────

    #[test]
    fn test_context_subgraph_bounded() {
        // Path graph 0-1-2-…-9
        let edges: Vec<(usize, usize)> = (0..9).map(|i| (i, i + 1)).collect();
        let cfg = GraphContextConfig {
            context_size: 4,
            ..Default::default()
        };
        let pretrainer = GraphContextPretrainer::new(cfg.clone());
        let ctx = pretrainer
            .sample_context_subgraph(&edges, 5, 10, 42)
            .unwrap();
        assert!(
            ctx.len() <= cfg.context_size,
            "context size {} should be ≤ {}",
            ctx.len(),
            cfg.context_size
        );
    }

    #[test]
    fn test_context_subgraph_contains_center() {
        let edges = vec![(0, 1), (1, 2), (2, 3)];
        let cfg = GraphContextConfig {
            context_size: 3,
            ..Default::default()
        };
        let pretrainer = GraphContextPretrainer::new(cfg);
        let ctx = pretrainer.sample_context_subgraph(&edges, 1, 4, 0).unwrap();
        assert!(ctx.contains(&1), "context should include center node 1");
    }

    #[test]
    fn test_contrastive_loss_pos_closer_lower_loss() {
        // Anchor is [1,0], positive is also [1,0] (cos sim = 1).
        // Negatives are [-1, 0] (cos sim = -1). Loss should be very small.
        let anchor = vec![1.0_f64, 0.0];
        let positive = vec![1.0_f64, 0.0];
        let negatives = vec![vec![-1.0_f64, 0.0]; 4];
        let cfg = GraphContextConfig {
            temperature: 0.07,
            ..Default::default()
        };
        let pretrainer = GraphContextPretrainer::new(cfg.clone());
        let low_loss = pretrainer
            .contrastive_loss(&anchor, &positive, &negatives, cfg.temperature)
            .unwrap();
        // Hard negatives: identical to the anchor (cos sim = 1).
        // (Previously misnamed `far_negatives` even though these are the
        // *closest* possible negatives.)
        let close_negatives = vec![vec![1.0_f64, 0.0]; 4];
        let high_loss = pretrainer
            .contrastive_loss(&anchor, &positive, &close_negatives, cfg.temperature)
            .unwrap();
        assert!(
            low_loss < high_loss,
            "loss with dissimilar negatives ({low_loss}) should be lower than loss with anchor-like negatives ({high_loss})"
        );
    }

    #[test]
    fn test_contrastive_loss_finite() {
        let anchor = vec![0.5, 0.3, 0.2];
        let positive = vec![0.4, 0.4, 0.2];
        let negatives = vec![vec![0.1, 0.1, 0.8], vec![-0.1, 0.5, 0.4]];
        let loss = infonce_loss(&anchor, &positive, &negatives, 0.1).unwrap();
        assert!(loss.is_finite(), "InfoNCE loss should be finite");
        assert!(loss >= 0.0, "InfoNCE loss should be non-negative");
    }

    // ── AttributeReconstructionObjective ─────────────────────────────────────

    #[test]
    fn test_attr_recon_forward_shape() {
        let cfg = AttrReconConfig {
            n_layers: 2,
            hidden_dim: 16,
            dropout: 0.0,
        };
        let obj = AttributeReconstructionObjective::new(8, cfg, 123).unwrap();
        let features: Vec<Vec<f64>> = (0..5).map(|_| vec![1.0; 8]).collect();
        let out = obj.forward(&features).unwrap();
        assert_eq!(out.len(), 5, "output should have same number of nodes");
        for row in &out {
            assert_eq!(row.len(), 8, "each output vector should have dim 8");
        }
    }

    #[test]
    fn test_config_defaults() {
        let pr = NodeMaskingConfig::default();
        assert!((pr.mask_rate - 0.15).abs() < 1e-9);
        assert!((pr.replace_rate - 0.1).abs() < 1e-9);
        assert_eq!(pr.n_neighbors, 2);
        assert_eq!(pr.feature_dim, 64);

        let gc = GraphContextConfig::default();
        assert_eq!(gc.context_size, 8);
        assert_eq!(gc.negative_samples, 4);
        assert!((gc.temperature - 0.07).abs() < 1e-9);

        let ar = AttrReconConfig::default();
        assert_eq!(ar.n_layers, 2);
        assert_eq!(ar.hidden_dim, 128);
        assert!((ar.dropout - 0.1).abs() < 1e-9);
    }

    #[test]
    fn test_empty_graph_handling() {
        // NodeMaskingPretrainer on empty features
        let cfg = NodeMaskingConfig::default();
        let pretrainer = NodeMaskingPretrainer::new(cfg);
        let (masked, indices) = pretrainer.mask_nodes(&[], 0).unwrap();
        assert!(masked.is_empty());
        assert!(indices.is_empty());

        // GraphContextPretrainer on empty graph
        let cfg2 = GraphContextConfig::default();
        let pretrainer2 = GraphContextPretrainer::new(cfg2);
        let ctx = pretrainer2.sample_context_subgraph(&[], 0, 0, 0).unwrap();
        assert!(ctx.is_empty());
    }
}