pub struct GraphMae { /* private fields */ }
Graph Masked Autoencoder.
Maintains:
- A learnable mask token of shape [feature_dim], used to replace masked node features.
- An encoder weight matrix of shape [feature_dim × encoder_dim].
- A decoder weight matrix of shape [encoder_dim × feature_dim].
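As a rough illustration of how these shapes compose, the sketch below pushes an [n_nodes × feature_dim] matrix through an encoder matrix and then a decoder matrix. It assumes a purely linear encoder and decoder (no bias, no activation), which the description above does not confirm. The `matmul` helper and the plain `Vec<Vec<f64>>` matrices are hypothetical stand-ins for ndarray's `Array2<f64>`.

```rust
/// Row-major dense matrix product: [n × k] · [k × m] -> [n × m].
/// Hypothetical helper; panics if the inner dimensions disagree.
fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let k = b.len();
    let m = if k == 0 { 0 } else { b[0].len() };
    a.iter()
        .map(|row| {
            assert_eq!(row.len(), k, "inner dimensions must match");
            (0..m)
                .map(|j| (0..k).map(|t| row[t] * b[t][j]).sum::<f64>())
                .collect()
        })
        .collect()
}

fn main() {
    let (n_nodes, feature_dim, encoder_dim) = (4, 6, 3);
    let x = vec![vec![1.0; feature_dim]; n_nodes]; // [n_nodes × feature_dim]
    let w_enc = vec![vec![0.1; encoder_dim]; feature_dim]; // [feature_dim × encoder_dim]
    let w_dec = vec![vec![0.1; feature_dim]; encoder_dim]; // [encoder_dim × feature_dim]

    let h = matmul(&x, &w_enc); // latent codes: [n_nodes × encoder_dim]
    let x_hat = matmul(&h, &w_dec); // reconstruction: [n_nodes × feature_dim]

    println!("h: {} x {}", h.len(), h[0].len());
    println!("x_hat: {} x {}", x_hat.len(), x_hat[0].len());
}
```

The encoder narrows each row to encoder_dim values and the decoder maps it back to feature_dim, which is why the decoder output can be compared row by row with the original features.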
Implementations
impl GraphMae
pub fn new(feature_dim: usize, config: GraphMaeConfig, seed: u64) -> Self
Construct a new GraphMAE.
Arguments
- feature_dim – dimension of each node’s input features
- config – MAE hyper-parameters
- seed – RNG seed for reproducible initialisation
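The page does not say how the weights are initialised. The following is a minimal sketch of seeded, reproducible initialisation under two labelled assumptions: Glorot/Xavier-uniform weights and a zero-initialised mask token. `SplitMix64`, `Params`, `glorot` and `init_params` are hypothetical names, and `encoder_dim` is passed directly because the fields of GraphMaeConfig are not documented here.

```rust
/// SplitMix64: a small, well-known 64-bit PRNG, used here only so the
/// example needs no external crates.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform f64 in [0, 1) built from the top 53 bits.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Hypothetical parameter bundle mirroring what GraphMae is documented
/// to maintain (the real fields are private).
struct Params {
    mask_token: Vec<f64>,   // [feature_dim]
    encoder: Vec<Vec<f64>>, // [feature_dim × encoder_dim]
    decoder: Vec<Vec<f64>>, // [encoder_dim × feature_dim]
}

/// Assumed Glorot-uniform init: U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out)).
fn glorot(rng: &mut SplitMix64, fan_in: usize, fan_out: usize) -> Vec<Vec<f64>> {
    let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
    (0..fan_in)
        .map(|_| (0..fan_out).map(|_| (2.0 * rng.next_f64() - 1.0) * limit).collect())
        .collect()
}

fn init_params(feature_dim: usize, encoder_dim: usize, seed: u64) -> Params {
    let mut rng = SplitMix64(seed);
    Params {
        mask_token: vec![0.0; feature_dim], // assumption: zero-initialised token
        encoder: glorot(&mut rng, feature_dim, encoder_dim), // drawn first
        decoder: glorot(&mut rng, encoder_dim, feature_dim), // drawn second
    }
}

fn main() {
    let a = init_params(8, 4, 42);
    let b = init_params(8, 4, 42);
    assert_eq!(a.encoder, b.encoder); // same seed -> identical weights
    println!("token len = {}, decoder rows = {}", a.mask_token.len(), a.decoder.len());
    println!("encoder[0][0] = {}", a.encoder[0][0]);
}
```

Drawing every parameter from a single generator seeded once is what makes the whole model reproducible from `seed`: the same seed replays the same sequence of draws in the same order.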
pub fn mask_features(&self, features: &Array2<f64>, seed: u64) -> (Array2<f64>, Vec<usize>)
Apply random feature masking.
Selects a random subset of nodes (fraction ≈ config.mask_rate) and
replaces their feature vectors with the learnable mask token.
Arguments
- features – node feature matrix [n_nodes × feature_dim]
- seed – RNG seed for masking (kept separate from the model seed, so each call can produce a different mask)
Returns
(masked_features, mask_indices) where mask_indices contains the
row indices of the masked nodes (sorted ascending).
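A minimal sketch of the masking step, written as a hypothetical free function over `Vec<Vec<f64>>` rather than the real method. It assumes that "fraction ≈ mask_rate" means round(mask_rate × n_nodes) nodes, and it picks them with a partial Fisher-Yates shuffle driven by a small seeded PRNG (SplitMix64), which may differ from the RNG the crate uses.

```rust
/// SplitMix64: tiny seeded PRNG so the example needs no external crates.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

/// Hypothetical sketch: the real method reads mask_rate from its config
/// and the token from `self`.
fn mask_features(
    features: &[Vec<f64>],
    mask_token: &[f64],
    mask_rate: f64,
    seed: u64,
) -> (Vec<Vec<f64>>, Vec<usize>) {
    let n = features.len();
    // Assumption: the masked fraction is realised as round(mask_rate * n).
    let k = ((mask_rate * n as f64).round() as usize).min(n);
    let mut order: Vec<usize> = (0..n).collect();
    let mut rng = SplitMix64(seed);
    // Partial Fisher-Yates: after k swaps, order[..k] is a random k-subset
    // (the modulo bias is negligible for realistic graph sizes).
    for i in 0..k {
        let j = i + (rng.next_u64() % (n - i) as u64) as usize;
        order.swap(i, j);
    }
    let mut mask_indices = order[..k].to_vec();
    mask_indices.sort_unstable(); // documented: sorted ascending

    // Overwrite each selected row with the mask token; other rows are copied.
    let mut masked = features.to_vec();
    for &i in &mask_indices {
        masked[i] = mask_token.to_vec();
    }
    (masked, mask_indices)
}

fn main() {
    let features: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64; 3]).collect();
    let token = vec![-1.0; 3];
    let (masked, idx) = mask_features(&features, &token, 0.5, 123);
    println!("masked rows: {:?}", idx);
    println!("first masked row: {:?}", masked[idx[0]]);
}
```

Because the subset depends only on `seed`, calling it twice with the same seed yields the same mask, while a fresh seed per training step gives a fresh mask.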
pub fn sce_loss(&self, original: &Array2<f64>, reconstructed: &Array2<f64>, mask_indices: &[usize], gamma: f64) -> f64
Scaled Cosine Error (SCE) reconstruction loss on masked nodes.
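$$L = \frac{1}{|M|} \sum_{i \in M} \left(1 - \cos\left(\hat{x}_i, x_i\right)\right)^{\gamma}$$
where M is the set of masked node indices (mask_indices), x_i is row i of original, and \hat{x}_i is row i of reconstructed.
If mask_indices is empty, returns 0.0.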
Arguments
- original – original node features [n_nodes × feature_dim]
- reconstructed – decoder output [n_nodes × feature_dim]
- mask_indices – indices of masked nodes
- gamma – exponent ≥ 1 (typical: 2 or 3)
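The formula can be sketched directly over plain `Vec<Vec<f64>>` rows. This is an illustration, not the crate's code: the epsilon guard for zero-norm rows and the clamp of 1 − cos at 0 (to absorb floating-point overshoot before a fractional power) are assumptions, since the page does not say how those edge cases are handled.

```rust
/// Cosine similarity; the 1e-12 floor on the norm product is an assumed
/// guard against division by zero for all-zero rows.
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f64>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f64>().sqrt();
    dot / (na * nb).max(1e-12)
}

/// Scaled cosine error averaged over the masked rows only.
fn sce_loss(
    original: &[Vec<f64>],
    reconstructed: &[Vec<f64>],
    mask_indices: &[usize],
    gamma: f64,
) -> f64 {
    if mask_indices.is_empty() {
        return 0.0; // documented behaviour for an empty mask
    }
    let total: f64 = mask_indices
        .iter()
        .map(|&i| {
            let err = (1.0 - cosine_sim(&reconstructed[i], &original[i])).max(0.0);
            err.powf(gamma)
        })
        .sum();
    total / mask_indices.len() as f64
}

fn main() {
    let orig = vec![vec![1.0, 0.0], vec![1.0, 0.0], vec![1.0, 0.0]];
    let recon = vec![vec![2.0, 0.0], vec![0.0, 1.0], vec![-1.0, 0.0]];
    // Row 0 points the same way (error 0), row 1 is orthogonal (error 1),
    // row 2 points the opposite way (error 2, so 4 after squaring).
    println!("{}", sce_loss(&orig, &recon, &[0, 1, 2], 2.0));
}
```

Each term lies in [0, 2^γ]. Because cosine similarity ignores vector length, row 0 scores a perfect 0 even though it is twice as long as the original, and raising the error to γ > 1 shrinks already-small errors so training focuses on the worst reconstructions.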
pub fn forward(&self, features: &Array2<f64>, seed: u64) -> (Array2<f64>, f64)
Full GraphMAE forward pass.
- Mask features randomly.
- Encode masked features.
- Decode back to feature space.
- Compute SCE loss over masked nodes (γ = 2).
Arguments
- features – original node feature matrix [n_nodes × feature_dim]
- seed – RNG seed for the masking step
Returns
(reconstructed_features, sce_loss)
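The four steps can be strung together as below. This is a simplified sketch, not the crate's implementation: to stay short it takes the mask indices from the caller instead of drawing them from `seed`, and it assumes linear encoder and decoder layers with no bias or activation. `forward_with_mask`, `matmul` and `sce` are hypothetical names.

```rust
/// Dense product [n × k] · [k × m] -> [n × m] on row-major Vec<Vec<f64>>.
fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let m = b.first().map_or(0, |row| row.len());
    a.iter()
        .map(|row| {
            assert_eq!(row.len(), b.len(), "inner dimensions must match");
            (0..m)
                .map(|j| row.iter().zip(b).map(|(x, b_row)| x * b_row[j]).sum::<f64>())
                .collect()
        })
        .collect()
}

/// SCE over masked rows (same formula as sce_loss above).
fn sce(orig: &[Vec<f64>], recon: &[Vec<f64>], mask: &[usize], gamma: f64) -> f64 {
    if mask.is_empty() {
        return 0.0;
    }
    let norm = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let total: f64 = mask
        .iter()
        .map(|&i| {
            let dot: f64 = recon[i].iter().zip(&orig[i]).map(|(a, b)| a * b).sum();
            let cos = dot / (norm(&recon[i]) * norm(&orig[i])).max(1e-12);
            (1.0 - cos).max(0.0).powf(gamma)
        })
        .sum();
    total / mask.len() as f64
}

fn forward_with_mask(
    features: &[Vec<f64>],
    mask_token: &[f64],
    w_enc: &[Vec<f64>],
    w_dec: &[Vec<f64>],
    mask_indices: &[usize],
) -> (Vec<Vec<f64>>, f64) {
    // 1. Mask: replace the selected rows with the mask token.
    let mut masked = features.to_vec();
    for &i in mask_indices {
        masked[i] = mask_token.to_vec();
    }
    // 2. Encode and 3. decode (assumed linear).
    let h = matmul(&masked, w_enc);
    let recon = matmul(&h, w_dec);
    // 4. Loss against the *original* features, gamma = 2 as documented.
    let loss = sce(features, &recon, mask_indices, 2.0);
    (recon, loss)
}

fn main() {
    let eye = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let (recon, loss) = forward_with_mask(&features, &[0.0, 1.0], &eye, &eye, &[0]);
    println!("recon = {:?}, loss = {}", recon, loss);
}
```

Note that the loss compares the decoder output with the unmasked input, so the model is scored on how well it recovers features it never saw for the masked rows.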
pub fn mask_token(&self) -> &Array1<f64>
The learnable mask token vector [feature_dim].
pub fn feature_dim(&self) -> usize
Input / output feature dimension.
pub fn encoder_dim(&self) -> usize
Encoder output dimension.
Auto Trait Implementations
impl Freeze for GraphMae
impl RefUnwindSafe for GraphMae
impl Send for GraphMae
impl Sync for GraphMae
impl Unpin for GraphMae
impl UnsafeUnpin for GraphMae
impl UnwindSafe for GraphMae
Blanket Implementations
impl<T> BorrowMut<T> for T where T: ?Sized
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.