scirs2_graph/embeddings/types.rs

//! Core types and configurations for graph embeddings

use crate::base::Node;
use std::collections::HashMap;

/// Configuration for Node2Vec embedding algorithm
#[derive(Debug, Clone)]
pub struct Node2VecConfig {
    /// Dimensions of the embedding vectors
    pub dimensions: usize,
    /// Length of each random walk
    pub walk_length: usize,
    /// Number of random walks per node
    pub num_walks: usize,
    /// Window size for skip-gram model
    pub window_size: usize,
    /// Return parameter p (higher values make the walk less likely to immediately revisit the previous node)
    pub p: f64,
    /// In-out parameter q (q > 1 keeps the walk local and BFS-like, q < 1 favors outward, DFS-like exploration)
    pub q: f64,
    /// Number of training epochs
    pub epochs: usize,
    /// Learning rate for gradient descent
    pub learning_rate: f64,
    /// Number of negative samples for training
    pub negative_samples: usize,
}

impl Default for Node2VecConfig {
    fn default() -> Self {
        Node2VecConfig {
            dimensions: 128,
            walk_length: 80,
            num_walks: 10,
            window_size: 10,
            p: 1.0,
            q: 1.0,
            epochs: 1,
            learning_rate: 0.025,
            negative_samples: 5,
        }
    }
}
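
// A minimal usage sketch: override selected hyperparameters and keep the rest
// of the defaults via struct-update syntax. Illustrative only; the embedding
// routine that consumes this config lives elsewhere in the crate.
//
//     let config = Node2VecConfig {
//         dimensions: 64,
//         p: 4.0, // discourage immediately revisiting the previous node
//         q: 0.5, // favor outward, DFS-like exploration
//         ..Node2VecConfig::default()
//     };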
44
45/// Configuration for DeepWalk embedding algorithm
46#[derive(Debug, Clone)]
47pub struct DeepWalkConfig {
48    /// Dimensions of the embedding vectors
49    pub dimensions: usize,
50    /// Length of each random walk
51    pub walk_length: usize,
52    /// Number of random walks per node
53    pub num_walks: usize,
54    /// Window size for skip-gram model
55    pub window_size: usize,
56    /// Number of training epochs
57    pub epochs: usize,
58    /// Learning rate
59    pub learning_rate: f64,
60    /// Number of negative samples
61    pub negative_samples: usize,
62}
63
64impl Default for DeepWalkConfig {
65    fn default() -> Self {
66        DeepWalkConfig {
67            dimensions: 128,
68            walk_length: 40,
69            num_walks: 80,
70            window_size: 5,
71            epochs: 1,
72            learning_rate: 0.025,
73            negative_samples: 5,
74        }
75    }
76}
77
78/// A random walk on a graph
79#[derive(Debug, Clone)]
80pub struct RandomWalk<N: Node> {
81    /// The sequence of nodes in the walk
82    pub nodes: Vec<N>,
83}
84
85/// Advanced optimization techniques for embeddings
86#[derive(Debug, Clone)]
87pub struct OptimizationConfig {
88    /// Learning rate schedule type
89    pub lr_schedule: LearningRateSchedule,
90    /// Initial learning rate
91    pub initial_lr: f64,
92    /// Final learning rate
93    pub final_lr: f64,
94    /// Use momentum optimization
95    pub use_momentum: bool,
96    /// Momentum factor (0.9 is typical)
97    pub momentum: f64,
98    /// Use Adam optimizer
99    pub use_adam: bool,
100    /// Adam beta1 parameter
101    pub adam_beta1: f64,
102    /// Adam beta2 parameter
103    pub adam_beta2: f64,
104    /// Adam epsilon parameter
105    pub adam_epsilon: f64,
106    /// L2 regularization strength
107    pub l2_regularization: f64,
108    /// Gradient clipping threshold
109    pub gradient_clip: Option<f64>,
110    /// Use hierarchical softmax instead of negative sampling
111    pub use_hierarchical_softmax: bool,
112}
113
114impl Default for OptimizationConfig {
115    fn default() -> Self {
116        OptimizationConfig {
117            lr_schedule: LearningRateSchedule::Linear,
118            initial_lr: 0.025,
119            final_lr: 0.0001,
120            use_momentum: false,
121            momentum: 0.9,
122            use_adam: false,
123            adam_beta1: 0.9,
124            adam_beta2: 0.999,
125            adam_epsilon: 1e-8,
126            l2_regularization: 0.0,
127            gradient_clip: Some(1.0),
128            use_hierarchical_softmax: false,
129        }
130    }
131}
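
// A sketch of how the `gradient_clip` threshold is typically applied: if the
// gradient's L2 norm exceeds the threshold, the whole gradient is rescaled to
// that norm. This mirrors common practice; the crate's actual training loop
// lives elsewhere and may differ.
//
//     fn clip_gradient(gradient: &mut [f64], clip: Option<f64>) {
//         if let Some(max_norm) = clip {
//             let norm = gradient.iter().map(|g| g * g).sum::<f64>().sqrt();
//             if norm > max_norm {
//                 let scale = max_norm / norm;
//                 for g in gradient.iter_mut() {
//                     *g *= scale;
//                 }
//             }
//         }
//     }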
132
133/// Learning rate scheduling strategies
134#[derive(Debug, Clone, Copy, PartialEq)]
135pub enum LearningRateSchedule {
136    /// Constant learning rate
137    Constant,
138    /// Linear decay from initial to final
139    Linear,
140    /// Exponential decay
141    Exponential,
142    /// Cosine annealing
143    Cosine,
144    /// Step decay (reduce by factor at specific epochs)
145    Step,
146}
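
// An illustrative sketch of how these schedules map training progress
// `t` in [0, 1] to a learning rate. The Step constants (halve the rate every
// quarter of training) are assumptions for illustration only; the crate's
// scheduler may use different factors.
//
//     fn scheduled_lr(schedule: LearningRateSchedule, initial: f64, final_lr: f64, t: f64) -> f64 {
//         match schedule {
//             LearningRateSchedule::Constant => initial,
//             LearningRateSchedule::Linear => initial + (final_lr - initial) * t,
//             LearningRateSchedule::Exponential => initial * (final_lr / initial).powf(t),
//             LearningRateSchedule::Cosine => {
//                 final_lr + 0.5 * (initial - final_lr) * (1.0 + (std::f64::consts::PI * t).cos())
//             }
//             LearningRateSchedule::Step => initial * 0.5f64.powi((t * 4.0).floor() as i32),
//         }
//     }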
147
148/// Enhanced training metrics and monitoring
149#[derive(Debug, Clone)]
150pub struct TrainingMetrics {
151    /// Current epoch
152    pub epoch: usize,
153    /// Total training steps
154    pub steps: usize,
155    /// Current learning rate
156    pub learning_rate: f64,
157    /// Training loss (negative log likelihood)
158    pub loss: f64,
159    /// Loss moving average
160    pub loss_avg: f64,
161    /// Gradient norm
162    pub gradient_norm: f64,
163    /// Processing speed (steps per second)
164    pub steps_per_second: f64,
165    /// Memory usage in bytes
166    pub memory_usage: usize,
167    /// Convergence indicator (rate of loss change)
168    pub convergence_rate: f64,
169    /// Training accuracy on positive samples
170    pub positive_accuracy: f64,
171    /// Training accuracy on negative samples
172    pub negative_accuracy: f64,
173}
174
175impl Default for TrainingMetrics {
176    fn default() -> Self {
177        TrainingMetrics {
178            epoch: 0,
179            steps: 0,
180            learning_rate: 0.025,
181            loss: 0.0,
182            loss_avg: 0.0,
183            gradient_norm: 0.0,
184            steps_per_second: 0.0,
185            memory_usage: 0,
186            convergence_rate: 0.0,
187            positive_accuracy: 0.0,
188            negative_accuracy: 0.0,
189        }
190    }
191}
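
// An illustrative sketch of how a training loop might keep `loss_avg` and
// `convergence_rate` up to date after each step. The smoothing factor 0.99 is
// an assumption for illustration, not a value mandated by this crate.
//
//     fn record_step(metrics: &mut TrainingMetrics, step_loss: f64) {
//         metrics.steps += 1;
//         let previous_avg = metrics.loss_avg;
//         metrics.loss = step_loss;
//         metrics.loss_avg = if metrics.steps == 1 {
//             step_loss
//         } else {
//             0.99 * previous_avg + 0.01 * step_loss
//         };
//         // Rate of change of the smoothed loss; near zero once training converges.
//         metrics.convergence_rate = metrics.loss_avg - previous_avg;
//     }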
192
193/// Adaptive negative sampling strategies
194#[derive(Debug, Clone)]
195pub enum NegativeSamplingStrategy {
196    /// Uniform random sampling
197    Uniform,
198    /// Frequency-based sampling (more frequent nodes sampled more often)
199    Frequency,
200    /// Degree-based sampling (higher degree nodes sampled more often)
201    Degree,
202    /// Adaptive sampling based on embedding quality
203    Adaptive,
204    /// Hierarchical sampling using word2vec-style tree
205    Hierarchical,
206}
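
// A sketch of what the `Degree` strategy amounts to: negatives are drawn with
// probability proportional to node degree, commonly smoothed with the
// word2vec-style 0.75 exponent. The exponent is an assumption for
// illustration; the crate's sampler may weight degrees differently.
//
//     fn degree_sampling_weights(degrees: &[f64]) -> Vec<f64> {
//         let smoothed: Vec<f64> = degrees.iter().map(|d| d.powf(0.75)).collect();
//         let total: f64 = smoothed.iter().sum();
//         smoothed.iter().map(|w| w / total).collect()
//     }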
207
208/// Advanced optimizer state for Adam/momentum
209#[derive(Debug, Clone)]
210pub struct OptimizerState {
211    /// Momentum buffers for each parameter
212    pub momentum_buffers: HashMap<String, Vec<f64>>,
213    /// Adam first moment estimates
214    pub adam_m: HashMap<String, Vec<f64>>,
215    /// Adam second moment estimates
216    pub adam_v: HashMap<String, Vec<f64>>,
217    /// Time step for bias correction
218    pub time_step: usize,
219}
220
221impl Default for OptimizerState {
222    fn default() -> Self {
223        Self::new()
224    }
225}
226
227impl OptimizerState {
228    pub fn new() -> Self {
229        OptimizerState {
230            momentum_buffers: HashMap::new(),
231            adam_m: HashMap::new(),
232            adam_v: HashMap::new(),
233            time_step: 0,
234        }
235    }
236}
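
// A sketch of the Adam update these fields support: first/second moment
// estimates per parameter key, with bias correction driven by `time_step`.
// The formulation is standard Adam; the crate's actual update routine lives
// elsewhere, and calling this once per step for a single parameter block is an
// assumption of the sketch.
//
//     fn adam_step(
//         state: &mut OptimizerState,
//         key: &str,
//         gradient: &[f64],
//         params: &mut [f64],
//         lr: f64,
//         beta1: f64,
//         beta2: f64,
//         epsilon: f64,
//     ) {
//         // Advance the shared time step (assumes one parameter block per step).
//         state.time_step += 1;
//         let t = state.time_step as i32;
//         let m = state.adam_m.entry(key.to_string()).or_insert_with(|| vec![0.0; gradient.len()]);
//         let v = state.adam_v.entry(key.to_string()).or_insert_with(|| vec![0.0; gradient.len()]);
//         for i in 0..gradient.len() {
//             m[i] = beta1 * m[i] + (1.0 - beta1) * gradient[i];
//             v[i] = beta2 * v[i] + (1.0 - beta2) * gradient[i] * gradient[i];
//             let m_hat = m[i] / (1.0 - beta1.powi(t));
//             let v_hat = v[i] / (1.0 - beta2.powi(t));
//             params[i] -= lr * m_hat / (v_hat.sqrt() + epsilon);
//         }
//     }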
237
238/// Skip-gram training context pair
239#[derive(Debug, Clone)]
240pub struct ContextPair<N: Node> {
241    /// Target node
242    pub target: N,
243    /// Context node
244    pub context: N,
245}
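
// A sketch of how skip-gram context pairs are typically derived from a random
// walk: every node within `window_size` positions of a target becomes its
// context. Illustrative only; the crate's walk-to-pair conversion lives
// elsewhere and may differ in details such as randomized window shrinking.
//
//     fn context_pairs_from_walk<N: Node + Clone>(
//         walk: &RandomWalk<N>,
//         window_size: usize,
//     ) -> Vec<ContextPair<N>> {
//         let mut pairs = Vec::new();
//         for (i, target) in walk.nodes.iter().enumerate() {
//             let start = i.saturating_sub(window_size);
//             let end = (i + window_size + 1).min(walk.nodes.len());
//             for j in start..end {
//                 if j != i {
//                     pairs.push(ContextPair {
//                         target: target.clone(),
//                         context: walk.nodes[j].clone(),
//                     });
//                 }
//             }
//         }
//         pairs
//     }
//
// For example, with window_size = 2 the walk [a, b, c, d] yields pairs such as
// (a, b), (a, c), (b, a), (b, c), (b, d), and so on.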