ruvector_gnn/training.rs

1//! Training utilities for GNN models.
2//!
3//! Provides optimizers, loss functions, contrastive losses, and training configuration types.
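//!
//! A minimal end-to-end sketch (illustrative values): compute a loss gradient and apply
//! one optimizer step.
//!
//! ```
//! use ndarray::Array2;
//! use ruvector_gnn::training::{Loss, LossType, Optimizer, OptimizerType};
//!
//! let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
//! let mut pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.5]).unwrap();
//! let mut optimizer = Optimizer::new(OptimizerType::Sgd {
//!     learning_rate: 0.1,
//!     momentum: 0.0,
//! });
//!
//! let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
//! optimizer.step(&mut pred, &grad).unwrap();
//! assert!(Loss::compute(LossType::Mse, &pred, &target).unwrap() < 0.25);
//! ```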
4
5use crate::error::{GnnError, Result};
6use crate::search::cosine_similarity;
7use ndarray::Array2;
8
9/// Optimizer types
10#[derive(Debug, Clone)]
11pub enum OptimizerType {
12    /// Stochastic Gradient Descent
13    Sgd {
14        /// Learning rate
15        learning_rate: f32,
16        /// Momentum coefficient (0.0 = no momentum, 0.9 = standard)
17        momentum: f32,
18    },
19    /// Adam optimizer
20    Adam {
21        /// Learning rate
22        learning_rate: f32,
23        /// Beta1 parameter (exponential decay rate for first moment)
24        beta1: f32,
25        /// Beta2 parameter (exponential decay rate for second moment)
26        beta2: f32,
27        /// Epsilon for numerical stability
28        epsilon: f32,
29    },
30}
31
32/// Optimizer state storage
33#[derive(Debug)]
34enum OptimizerState {
35    /// SGD with momentum state
36    Sgd {
37        /// Momentum buffer (velocity)
38        velocity: Option<Array2<f32>>,
39    },
40    /// Adam optimizer state
41    Adam {
42        /// First moment estimate (mean of gradients)
43        m: Option<Array2<f32>>,
44        /// Second moment estimate (uncentered variance of gradients)
45        v: Option<Array2<f32>>,
46        /// Timestep counter
47        t: usize,
48    },
49}
50
51/// Optimizer for parameter updates
52pub struct Optimizer {
53    optimizer_type: OptimizerType,
54    state: OptimizerState,
55}
56
57impl Optimizer {
58    /// Create a new optimizer
59    pub fn new(optimizer_type: OptimizerType) -> Self {
60        let state = match &optimizer_type {
61            OptimizerType::Sgd { .. } => OptimizerState::Sgd { velocity: None },
62            OptimizerType::Adam { .. } => OptimizerState::Adam {
63                m: None,
64                v: None,
65                t: 0,
66            },
67        };
68
69        Self {
70            optimizer_type,
71            state,
72        }
73    }
74
75    /// Perform optimization step
76    ///
77    /// Updates parameters in-place based on gradients using the configured optimizer.
78    ///
79    /// # Arguments
80    /// * `params` - Parameters to update (modified in-place)
81    /// * `grads` - Gradients for the parameters
82    ///
83    /// # Returns
84    /// * `Ok(())` on success
85    /// * `Err(GnnError)` if shapes don't match or other errors occur
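    ///
    /// # Example
    ///
    /// A minimal sketch of a single update with plain SGD (illustrative values):
    ///
    /// ```
    /// use ndarray::Array2;
    /// use ruvector_gnn::training::{Optimizer, OptimizerType};
    ///
    /// let mut optimizer = Optimizer::new(OptimizerType::Sgd {
    ///     learning_rate: 0.1,
    ///     momentum: 0.0,
    /// });
    /// let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    /// let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
    ///
    /// optimizer.step(&mut params, &grads).unwrap();
    /// // params[i] -= learning_rate * grads[i]
    /// assert!((params[[0, 0]] - 0.99).abs() < 1e-6);
    /// ```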
86    pub fn step(&mut self, params: &mut Array2<f32>, grads: &Array2<f32>) -> Result<()> {
87        // Validate shapes match
88        if params.shape() != grads.shape() {
89            return Err(GnnError::dimension_mismatch(
90                format!("{:?}", params.shape()),
91                format!("{:?}", grads.shape()),
92            ));
93        }
94
95        match (&self.optimizer_type, &mut self.state) {
96            (
97                OptimizerType::Sgd {
98                    learning_rate,
99                    momentum,
100                },
101                OptimizerState::Sgd { velocity },
102            ) => Self::sgd_step_with_momentum(params, grads, *learning_rate, *momentum, velocity),
103            (
104                OptimizerType::Adam {
105                    learning_rate,
106                    beta1,
107                    beta2,
108                    epsilon,
109                },
110                OptimizerState::Adam { m, v, t },
111            ) => Self::adam_step(
112                params,
113                grads,
114                *learning_rate,
115                *beta1,
116                *beta2,
117                *epsilon,
118                m,
119                v,
120                t,
121            ),
122            _ => Err(GnnError::invalid_input("Optimizer type and state mismatch")),
123        }
124    }
125
126    /// SGD optimization step with momentum
127    ///
128    /// Implements: v_t = momentum * v_{t-1} + learning_rate * grad
129    ///             params = params - v_t
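    ///
    /// Worked example (learning_rate = 0.1, momentum = 0.9, constant grad = 0.1):
    /// step 1 gives v = 0.9 * 0 + 0.1 * 0.1 = 0.01, step 2 gives v = 0.9 * 0.01 + 0.01 = 0.019,
    /// so successive updates grow while the gradient keeps pointing in the same direction.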
130    fn sgd_step_with_momentum(
131        params: &mut Array2<f32>,
132        grads: &Array2<f32>,
133        learning_rate: f32,
134        momentum: f32,
135        velocity: &mut Option<Array2<f32>>,
136    ) -> Result<()> {
137        if momentum == 0.0 {
138            // Simple SGD without momentum
139            *params -= &(grads * learning_rate);
140        } else {
141            // SGD with momentum
142            if velocity.is_none() {
143                // Initialize velocity buffer
144                *velocity = Some(Array2::zeros(params.dim()));
145            }
146
147            if let Some(v) = velocity {
148                // Update velocity: v = momentum * v + learning_rate * grad
149                let new_velocity = v.mapv(|x| x * momentum) + grads * learning_rate;
150                *v = new_velocity;
151
152                // Update parameters: params = params - v
153                *params -= &*v;
154            }
155        }
156
157        Ok(())
158    }
159
160    /// Adam optimization step
161    ///
162    /// Implements the Adam algorithm:
163    /// 1. m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
164    /// 2. v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
165    /// 3. m_hat = m_t / (1 - beta1^t)
166    /// 4. v_hat = v_t / (1 - beta2^t)
167    /// 5. params = params - lr * m_hat / (sqrt(v_hat) + epsilon)
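    ///
    /// Worked example (first step, g = 0.1, beta1 = 0.9, beta2 = 0.999): m_1 = 0.01 and
    /// m_hat = 0.1; v_1 = 1e-5 and v_hat = 0.01, so sqrt(v_hat) = 0.1 and the update is
    /// roughly lr * 0.1 / 0.1 = lr. Bias correction makes the first step's magnitude close
    /// to the learning rate regardless of the gradient's scale.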
168    #[allow(clippy::too_many_arguments)]
169    fn adam_step(
170        params: &mut Array2<f32>,
171        grads: &Array2<f32>,
172        learning_rate: f32,
173        beta1: f32,
174        beta2: f32,
175        epsilon: f32,
176        m: &mut Option<Array2<f32>>,
177        v: &mut Option<Array2<f32>>,
178        t: &mut usize,
179    ) -> Result<()> {
180        // Initialize moment buffers if needed
181        if m.is_none() {
182            *m = Some(Array2::zeros(params.dim()));
183        }
184        if v.is_none() {
185            *v = Some(Array2::zeros(params.dim()));
186        }
187
188        // Increment timestep
189        *t += 1;
190        let timestep = *t as f32;
191
192        if let (Some(m_buf), Some(v_buf)) = (m, v) {
193            // Update biased first moment estimate
194            // m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
195            let new_m = m_buf.mapv(|x| x * beta1) + grads * (1.0 - beta1);
196            *m_buf = new_m;
197
198            // Update biased second raw moment estimate
199            // v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
200            let grads_squared = grads.mapv(|x| x * x);
201            let new_v = v_buf.mapv(|x| x * beta2) + grads_squared * (1.0 - beta2);
202            *v_buf = new_v;
203
204            // Compute bias-corrected first moment estimate
205            // m_hat = m_t / (1 - beta1^t)
206            let bias_correction1 = 1.0 - beta1.powi(*t as i32);
207            let m_hat = m_buf.mapv(|x| x / bias_correction1);
208
209            // Compute bias-corrected second raw moment estimate
210            // v_hat = v_t / (1 - beta2^t)
211            let bias_correction2 = 1.0 - beta2.powi(*t as i32);
212            let v_hat = v_buf.mapv(|x| x / bias_correction2);
213
214            // Update parameters
215            // params = params - lr * m_hat / (sqrt(v_hat) + epsilon)
216            let update = m_hat
217                .iter()
218                .zip(v_hat.iter())
219                .map(|(&m_val, &v_val)| learning_rate * m_val / (v_val.sqrt() + epsilon));
220
221            for (param, upd) in params.iter_mut().zip(update) {
222                *param -= upd;
223            }
224        }
225
226        Ok(())
227    }
228}
229
230/// Loss function types
231#[derive(Debug, Clone, Copy)]
232pub enum LossType {
233    /// Mean Squared Error
234    Mse,
235    /// Cross Entropy
236    CrossEntropy,
237    /// Binary Cross Entropy
238    BinaryCrossEntropy,
239}
240
241/// Loss function implementations for neural network training.
242///
243/// Provides forward (loss computation) and backward (gradient computation) passes
244/// for common loss functions used in GNN training.
245///
246/// # Numerical Stability
247///
248/// Cross-entropy and binary cross-entropy clamp predictions with a small epsilon and clip
249/// gradients to prevent numerical instability with extreme prediction values (near 0 or 1).
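///
/// # Example
///
/// A short sketch of the clamping behavior: predictions at exactly 0.0 and 1.0 still yield a
/// finite binary cross-entropy.
///
/// ```
/// use ndarray::Array2;
/// use ruvector_gnn::training::{Loss, LossType};
///
/// let pred = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap();
/// let target = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap();
/// let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
/// assert!(loss.is_finite());
/// ```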
250pub struct Loss;
251
252impl Loss {
253    /// Small epsilon value for numerical stability in logarithms and divisions.
254    const EPS: f32 = 1e-7;
255
256    /// Maximum absolute gradient value to prevent explosion.
257    const MAX_GRAD: f32 = 1e6;
258
259    /// Compute the loss value between predictions and targets.
260    ///
261    /// # Arguments
262    /// * `loss_type` - The type of loss function to use
263    /// * `predictions` - Model predictions as a 2D array
264    /// * `targets` - Ground truth targets as a 2D array (same shape as predictions)
265    ///
266    /// # Returns
267    /// * `Ok(f32)` - The computed scalar loss value
268    /// * `Err(GnnError)` - If shapes don't match or computation fails
269    ///
270    /// # Example
271    /// ```
272    /// use ndarray::Array2;
273    /// use ruvector_gnn::training::{Loss, LossType};
274    ///
275    /// let predictions = Array2::from_shape_vec((2, 2), vec![0.9, 0.1, 0.2, 0.8]).unwrap();
276    /// let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
277    /// let loss = Loss::compute(LossType::Mse, &predictions, &targets).unwrap();
278    /// assert!(loss >= 0.0);
279    /// ```
280    pub fn compute(
281        loss_type: LossType,
282        predictions: &Array2<f32>,
283        targets: &Array2<f32>,
284    ) -> Result<f32> {
285        // Validate shapes match
286        if predictions.shape() != targets.shape() {
287            return Err(GnnError::dimension_mismatch(
288                format!("{:?}", predictions.shape()),
289                format!("{:?}", targets.shape()),
290            ));
291        }
292
293        if predictions.is_empty() {
294            return Err(GnnError::invalid_input(
295                "Cannot compute loss on empty arrays",
296            ));
297        }
298
299        match loss_type {
300            LossType::Mse => Self::mse_forward(predictions, targets),
301            LossType::CrossEntropy => Self::cross_entropy_forward(predictions, targets),
302            LossType::BinaryCrossEntropy => Self::bce_forward(predictions, targets),
303        }
304    }
305
306    /// Compute the gradient of the loss with respect to predictions.
307    ///
308    /// # Arguments
309    /// * `loss_type` - The type of loss function to use
310    /// * `predictions` - Model predictions as a 2D array
311    /// * `targets` - Ground truth targets as a 2D array (same shape as predictions)
312    ///
313    /// # Returns
314    /// * `Ok(Array2<f32>)` - Gradient array with same shape as predictions
315    /// * `Err(GnnError)` - If shapes don't match or computation fails
316    ///
317    /// # Example
318    /// ```
319    /// use ndarray::Array2;
320    /// use ruvector_gnn::training::{Loss, LossType};
321    ///
322    /// let predictions = Array2::from_shape_vec((2, 2), vec![0.9, 0.1, 0.2, 0.8]).unwrap();
323    /// let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
324    /// let grad = Loss::gradient(LossType::Mse, &predictions, &targets).unwrap();
325    /// assert_eq!(grad.shape(), predictions.shape());
326    /// ```
327    pub fn gradient(
328        loss_type: LossType,
329        predictions: &Array2<f32>,
330        targets: &Array2<f32>,
331    ) -> Result<Array2<f32>> {
332        // Validate shapes match
333        if predictions.shape() != targets.shape() {
334            return Err(GnnError::dimension_mismatch(
335                format!("{:?}", predictions.shape()),
336                format!("{:?}", targets.shape()),
337            ));
338        }
339
340        if predictions.is_empty() {
341            return Err(GnnError::invalid_input(
342                "Cannot compute gradient on empty arrays",
343            ));
344        }
345
346        match loss_type {
347            LossType::Mse => Self::mse_backward(predictions, targets),
348            LossType::CrossEntropy => Self::cross_entropy_backward(predictions, targets),
349            LossType::BinaryCrossEntropy => Self::bce_backward(predictions, targets),
350        }
351    }
352
353    /// Mean Squared Error: MSE = mean((predictions - targets)^2)
354    fn mse_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
355        let diff = predictions - targets;
356        let squared = diff.mapv(|x| x * x);
357        Ok(squared.mean().unwrap_or(0.0))
358    }
359
360    /// MSE gradient: d(MSE)/d(pred) = 2 * (predictions - targets) / n
361    fn mse_backward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<Array2<f32>> {
362        let n = predictions.len() as f32;
363        let diff = predictions - targets;
364        Ok(diff.mapv(|x| 2.0 * x / n))
365    }
366
367    /// Cross Entropy: CE = -mean(sum(targets * log(predictions), axis=1))
368    ///
369    /// Used for multi-class classification where targets are one-hot encoded
370    /// and predictions are softmax probabilities.
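    ///
    /// Worked example: for a single row with predictions [0.7, 0.2, 0.1] and one-hot target
    /// [1, 0, 0], the loss is -ln(0.7), approximately 0.357.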
371    fn cross_entropy_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
372        let log_pred = predictions.mapv(|x| (x.max(Self::EPS)).ln());
373        let elementwise = targets * &log_pred;
374        let loss = -elementwise.sum() / predictions.nrows() as f32;
375        Ok(loss)
376    }
377
378    /// Cross Entropy gradient: d(CE)/d(pred) = -targets / predictions / n
379    ///
380    /// Gradients are clipped to [-MAX_GRAD, MAX_GRAD] to prevent explosion.
381    fn cross_entropy_backward(
382        predictions: &Array2<f32>,
383        targets: &Array2<f32>,
384    ) -> Result<Array2<f32>> {
385        let n = predictions.nrows() as f32;
386        // Clamp predictions to avoid division by zero
387        let safe_pred = predictions.mapv(|x| x.max(Self::EPS));
388        let grad = targets / &safe_pred;
389        // Apply gradient clipping
390        Ok(grad.mapv(|x| (-x / n).clamp(-Self::MAX_GRAD, Self::MAX_GRAD)))
391    }
392
393    /// Binary Cross Entropy: BCE = -mean(targets * log(pred) + (1 - targets) * log(1 - pred))
394    ///
395    /// Used for binary classification or multi-label classification.
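    ///
    /// Worked example: a single prediction p = 0.9 with target t = 1 contributes -ln(0.9),
    /// approximately 0.105; with t = 0 it contributes -ln(0.1), approximately 2.303.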
396    fn bce_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
397        let n = predictions.len() as f32;
398        let loss: f32 = predictions
399            .iter()
400            .zip(targets.iter())
401            .map(|(&p, &t)| {
402                // Clamp predictions to (eps, 1-eps) for numerical stability
403                let p_safe = p.clamp(Self::EPS, 1.0 - Self::EPS);
404                -(t * p_safe.ln() + (1.0 - t) * (1.0 - p_safe).ln())
405            })
406            .sum();
407        Ok(loss / n)
408    }
409
410    /// BCE gradient: d(BCE)/d(pred) = (-targets/pred + (1-targets)/(1-pred)) / n
411    ///
412    /// Gradients are clipped to [-MAX_GRAD, MAX_GRAD] to prevent explosion.
413    fn bce_backward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<Array2<f32>> {
414        let n = predictions.len() as f32;
415        let grad_vec: Vec<f32> = predictions
416            .iter()
417            .zip(targets.iter())
418            .map(|(&p, &t)| {
419                // Clamp predictions for numerical stability
420                let p_safe = p.clamp(Self::EPS, 1.0 - Self::EPS);
421                let grad = (-t / p_safe + (1.0 - t) / (1.0 - p_safe)) / n;
422                // Clip gradient to prevent explosion
423                grad.clamp(-Self::MAX_GRAD, Self::MAX_GRAD)
424            })
425            .collect();
426
427        Array2::from_shape_vec(predictions.dim(), grad_vec)
428            .map_err(|e| GnnError::training(format!("Failed to reshape gradient: {}", e)))
429    }
430}
431
432/// Configuration for a training run (epochs, batching, loss, and optimizer settings).
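///
/// # Example
///
/// A minimal sketch: start from the defaults and override a field (illustrative values).
///
/// ```
/// use ruvector_gnn::training::TrainingConfig;
///
/// let config = TrainingConfig { epochs: 10, ..TrainingConfig::default() };
/// assert_eq!(config.batch_size, 32);
/// ```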
433#[derive(Debug, Clone)]
434pub struct TrainingConfig {
435    /// Number of epochs
436    pub epochs: usize,
437    /// Batch size
438    pub batch_size: usize,
439    /// Learning rate
440    pub learning_rate: f32,
441    /// Loss type
442    pub loss_type: LossType,
443    /// Optimizer type
444    pub optimizer_type: OptimizerType,
445}
446
447impl Default for TrainingConfig {
448    fn default() -> Self {
449        Self {
450            epochs: 100,
451            batch_size: 32,
452            learning_rate: 0.001,
453            loss_type: LossType::Mse,
454            optimizer_type: OptimizerType::Adam {
455                learning_rate: 0.001,
456                beta1: 0.9,
457                beta2: 0.999,
458                epsilon: 1e-8,
459            },
460        }
461    }
462}
463
464/// Configuration for contrastive learning training
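///
/// # Example
///
/// A minimal sketch: wire the configured temperature into the contrastive loss
/// (vectors are illustrative).
///
/// ```
/// use ruvector_gnn::training::{info_nce_loss, TrainConfig};
///
/// let config = TrainConfig::default();
/// let anchor = vec![1.0, 0.0];
/// let positive = vec![0.9, 0.1];
/// let negative = vec![0.0, 1.0];
/// let loss = info_nce_loss(&anchor, &[&positive], &[&negative], config.temperature);
/// assert!(loss.is_finite());
/// ```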
465#[derive(Debug, Clone)]
466pub struct TrainConfig {
467    /// Batch size for training
468    pub batch_size: usize,
469    /// Number of negative samples per positive
470    pub n_negatives: usize,
471    /// Temperature parameter for contrastive loss
472    pub temperature: f32,
473    /// Learning rate for optimization
474    pub learning_rate: f32,
475    /// Number of updates before flushing to storage
476    pub flush_threshold: usize,
477}
478
479impl Default for TrainConfig {
480    fn default() -> Self {
481        Self {
482            batch_size: 256,
483            n_negatives: 64,
484            temperature: 0.07,
485            learning_rate: 0.001,
486            flush_threshold: 1000,
487        }
488    }
489}
490
491/// Configuration for online learning
492#[derive(Debug, Clone)]
493pub struct OnlineConfig {
494    /// Number of local optimization steps
495    pub local_steps: usize,
496    /// Whether to propagate updates to neighbors
497    pub propagate_updates: bool,
498}
499
500impl Default for OnlineConfig {
501    fn default() -> Self {
502        Self {
503            local_steps: 5,
504            propagate_updates: true,
505        }
506    }
507}
508
509/// Compute InfoNCE contrastive loss
510///
511/// InfoNCE (Information Noise-Contrastive Estimation) loss is used for contrastive learning.
512/// It maximizes agreement between anchor and positive samples while minimizing agreement
513/// with negative samples.
514///
515/// # Arguments
516/// * `anchor` - The anchor embedding vector
517/// * `positives` - Positive example embeddings (similar to anchor)
518/// * `negatives` - Negative example embeddings (dissimilar to anchor)
519/// * `temperature` - Temperature scaling parameter (lower = sharper distinctions)
520///
521/// # Returns
522/// * The computed loss value (lower is better)
523///
524/// # Example
525/// ```
526/// use ruvector_gnn::training::info_nce_loss;
527///
528/// let anchor = vec![1.0, 0.0, 0.0];
529/// let positive = vec![0.9, 0.1, 0.0];
530/// let negative1 = vec![0.0, 1.0, 0.0];
531/// let negative2 = vec![0.0, 0.0, 1.0];
532///
533/// let loss = info_nce_loss(
534///     &anchor,
535///     &[&positive],
536///     &[&negative1, &negative2],
537///     0.07
538/// );
539/// assert!(loss > 0.0);
540/// ```
541pub fn info_nce_loss(
542    anchor: &[f32],
543    positives: &[&[f32]],
544    negatives: &[&[f32]],
545    temperature: f32,
546) -> f32 {
547    if positives.is_empty() {
548        return 0.0;
549    }
550
551    // Compute similarities with positives (scaled by temperature)
552    let pos_sims: Vec<f32> = positives
553        .iter()
554        .map(|pos| cosine_similarity(anchor, pos) / temperature)
555        .collect();
556
557    // Compute similarities with negatives (scaled by temperature)
558    let neg_sims: Vec<f32> = negatives
559        .iter()
560        .map(|neg| cosine_similarity(anchor, neg) / temperature)
561        .collect();
562
563    // For each positive, compute the InfoNCE loss using log-sum-exp trick for numerical stability
564    let mut total_loss = 0.0;
565    for &pos_sim in &pos_sims {
566        // Use log-sum-exp trick to avoid overflow
567        // log(exp(pos_sim) / (exp(pos_sim) + sum(exp(neg_sim))))
568        // = pos_sim - log(exp(pos_sim) + sum(exp(neg_sim)))
569        // = pos_sim - log_sum_exp([pos_sim, neg_sims...])
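        //
        // Numeric sketch: with pos_sim = 1.0 / 0.07 (about 14.29) and neg_sims = [0.0, 0.0],
        // max_logit is about 14.29, so the largest exponentiated value is exp(0) = 1 and
        // log_sum_exp is about 14.29 + ln(1 + 1.25e-6), giving a near-zero loss for that positive.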
570
571        // Collect all logits for log-sum-exp
572        let mut all_logits = vec![pos_sim];
573        all_logits.extend(&neg_sims);
574
575        // Compute log-sum-exp with numerical stability
576        let max_logit = all_logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
577        let log_sum_exp = max_logit
578            + all_logits
579                .iter()
580                .map(|&x| (x - max_logit).exp())
581                .sum::<f32>()
582                .ln();
583
584        // Loss = -log(exp(pos_sim) / sum_exp) = -(pos_sim - log_sum_exp)
585        total_loss -= pos_sim - log_sum_exp;
586    }
587
588    // Average over positives
589    total_loss / positives.len() as f32
590}
591
592/// Compute local contrastive loss for graph structures
593///
594/// This loss encourages node embeddings to be similar to their neighbors
595/// and dissimilar to non-neighbors in the graph.
596///
597/// # Arguments
598/// * `node_embedding` - The embedding of the target node
599/// * `neighbor_embeddings` - Embeddings of neighbor nodes
600/// * `non_neighbor_embeddings` - Embeddings of non-neighbor nodes
601/// * `temperature` - Temperature scaling parameter
602///
603/// # Returns
604/// * The computed loss value (lower is better)
605///
606/// # Example
607/// ```
608/// use ruvector_gnn::training::local_contrastive_loss;
609///
610/// let node = vec![1.0, 0.0, 0.0];
611/// let neighbor = vec![0.9, 0.1, 0.0];
612/// let non_neighbor1 = vec![0.0, 1.0, 0.0];
613/// let non_neighbor2 = vec![0.0, 0.0, 1.0];
614///
615/// let loss = local_contrastive_loss(
616///     &node,
617///     &[neighbor],
618///     &[non_neighbor1, non_neighbor2],
619///     0.07
620/// );
621/// assert!(loss > 0.0);
622/// ```
623pub fn local_contrastive_loss(
624    node_embedding: &[f32],
625    neighbor_embeddings: &[Vec<f32>],
626    non_neighbor_embeddings: &[Vec<f32>],
627    temperature: f32,
628) -> f32 {
629    if neighbor_embeddings.is_empty() {
630        return 0.0;
631    }
632
633    // Convert to slices for info_nce_loss
634    let positives: Vec<&[f32]> = neighbor_embeddings.iter().map(|v| v.as_slice()).collect();
635    let negatives: Vec<&[f32]> = non_neighbor_embeddings
636        .iter()
637        .map(|v| v.as_slice())
638        .collect();
639
640    info_nce_loss(node_embedding, &positives, &negatives, temperature)
641}
642
643/// Perform a single SGD (Stochastic Gradient Descent) optimization step
644///
645/// Updates the embedding in-place by subtracting the scaled gradient.
646///
647/// # Arguments
648/// * `embedding` - The embedding to update (modified in-place)
649/// * `grad` - The gradient vector
650/// * `learning_rate` - The learning rate (step size)
651///
652/// # Example
653/// ```
654/// use ruvector_gnn::training::sgd_step;
655///
656/// let mut embedding = vec![1.0, 2.0, 3.0];
657/// let gradient = vec![0.1, -0.2, 0.3];
658/// let learning_rate = 0.01;
659///
660/// sgd_step(&mut embedding, &gradient, learning_rate);
661///
662/// // Embedding is now updated: embedding[i] -= learning_rate * grad[i]
663/// assert!((embedding[0] - 0.999).abs() < 1e-6);
664/// assert!((embedding[1] - 2.002).abs() < 1e-6);
665/// assert!((embedding[2] - 2.997).abs() < 1e-6);
666/// ```
667pub fn sgd_step(embedding: &mut [f32], grad: &[f32], learning_rate: f32) {
668    assert_eq!(
669        embedding.len(),
670        grad.len(),
671        "Embedding and gradient must have the same length"
672    );
673
674    for (emb, &g) in embedding.iter_mut().zip(grad.iter()) {
675        *emb -= learning_rate * g;
676    }
677}
678
679#[cfg(test)]
680mod tests {
681    use super::*;
682
683    #[test]
684    fn test_train_config_default() {
685        let config = TrainConfig::default();
686        assert_eq!(config.batch_size, 256);
687        assert_eq!(config.n_negatives, 64);
688        assert_eq!(config.temperature, 0.07);
689        assert_eq!(config.learning_rate, 0.001);
690        assert_eq!(config.flush_threshold, 1000);
691    }
692
693    #[test]
694    fn test_online_config_default() {
695        let config = OnlineConfig::default();
696        assert_eq!(config.local_steps, 5);
697        assert!(config.propagate_updates);
698    }
699
700    #[test]
701    fn test_info_nce_loss_basic() {
702        // Anchor and positive are similar
703        let anchor = vec![1.0, 0.0, 0.0];
704        let positive = vec![0.9, 0.1, 0.0];
705
706        // Negatives are orthogonal
707        let negative1 = vec![0.0, 1.0, 0.0];
708        let negative2 = vec![0.0, 0.0, 1.0];
709
710        let loss = info_nce_loss(&anchor, &[&positive], &[&negative1, &negative2], 0.07);
711
712        // Loss should be positive
713        assert!(loss > 0.0);
714
715        // Loss should be reasonable (not infinite or NaN)
716        assert!(loss.is_finite());
717    }
718
719    #[test]
720    fn test_info_nce_loss_perfect_match() {
721        // Anchor and positive are identical
722        let anchor = vec![1.0, 0.0, 0.0];
723        let positive = vec![1.0, 0.0, 0.0];
724
725        // Negatives are very different
726        let negative1 = vec![0.0, 1.0, 0.0];
727        let negative2 = vec![0.0, 0.0, 1.0];
728
729        let loss = info_nce_loss(&anchor, &[&positive], &[&negative1, &negative2], 0.07);
730
731        // Loss should be lower for perfect match
732        assert!(loss < 1.0);
733        assert!(loss.is_finite());
734    }
735
736    #[test]
737    fn test_info_nce_loss_no_positives() {
738        let anchor = vec![1.0, 0.0, 0.0];
739        let negative1 = vec![0.0, 1.0, 0.0];
740
741        let loss = info_nce_loss(&anchor, &[], &[&negative1], 0.07);
742
743        // Should return 0.0 when no positives
744        assert_eq!(loss, 0.0);
745    }
746
747    #[test]
748    fn test_info_nce_loss_temperature_effect() {
749        let anchor = vec![1.0, 0.0, 0.0];
750        let positive = vec![0.9, 0.1, 0.0];
751        let negative = vec![0.0, 1.0, 0.0];
752
753        // Test with reasonable temperature values
754        // Very low temperatures can cause numerical issues, so we use 0.07 (standard) and 1.0
755        let loss_low_temp = info_nce_loss(&anchor, &[&positive], &[&negative], 0.07);
756        let loss_high_temp = info_nce_loss(&anchor, &[&positive], &[&negative], 1.0);
757
758        // Both should be positive and finite
759        assert!(
760            loss_low_temp > 0.0 && loss_low_temp.is_finite(),
761            "Low temp loss should be positive and finite, got: {}",
762            loss_low_temp
763        );
764        assert!(
765            loss_high_temp > 0.0 && loss_high_temp.is_finite(),
766            "High temp loss should be positive and finite, got: {}",
767            loss_high_temp
768        );
769
770        // With standard temperature, the loss should be reasonable
771        assert!(loss_low_temp < 10.0, "Loss should not be too large");
772        assert!(loss_high_temp < 10.0, "Loss should not be too large");
773    }
774
775    #[test]
776    fn test_local_contrastive_loss_basic() {
777        let node = vec![1.0, 0.0, 0.0];
778        let neighbor = vec![0.9, 0.1, 0.0];
779        let non_neighbor1 = vec![0.0, 1.0, 0.0];
780        let non_neighbor2 = vec![0.0, 0.0, 1.0];
781
782        let loss =
783            local_contrastive_loss(&node, &[neighbor], &[non_neighbor1, non_neighbor2], 0.07);
784
785        // Loss should be positive and finite
786        assert!(loss > 0.0);
787        assert!(loss.is_finite());
788    }
789
790    #[test]
791    fn test_local_contrastive_loss_multiple_neighbors() {
792        let node = vec![1.0, 0.0, 0.0];
793        let neighbor1 = vec![0.9, 0.1, 0.0];
794        let neighbor2 = vec![0.95, 0.05, 0.0];
795        let non_neighbor = vec![0.0, 1.0, 0.0];
796
797        let loss = local_contrastive_loss(&node, &[neighbor1, neighbor2], &[non_neighbor], 0.07);
798
799        assert!(loss > 0.0);
800        assert!(loss.is_finite());
801    }
802
803    #[test]
804    fn test_local_contrastive_loss_no_neighbors() {
805        let node = vec![1.0, 0.0, 0.0];
806        let non_neighbor = vec![0.0, 1.0, 0.0];
807
808        let loss = local_contrastive_loss(&node, &[], &[non_neighbor], 0.07);
809
810        // Should return 0.0 when no neighbors
811        assert_eq!(loss, 0.0);
812    }
813
814    #[test]
815    fn test_sgd_step_basic() {
816        let mut embedding = vec![1.0, 2.0, 3.0];
817        let gradient = vec![0.1, -0.2, 0.3];
818        let learning_rate = 0.01;
819
820        sgd_step(&mut embedding, &gradient, learning_rate);
821
822        // Expected: embedding[i] -= learning_rate * grad[i]
823        assert!((embedding[0] - 0.999).abs() < 1e-6); // 1.0 - 0.01 * 0.1
824        assert!((embedding[1] - 2.002).abs() < 1e-6); // 2.0 - 0.01 * (-0.2)
825        assert!((embedding[2] - 2.997).abs() < 1e-6); // 3.0 - 0.01 * 0.3
826    }
827
828    #[test]
829    fn test_sgd_step_zero_gradient() {
830        let mut embedding = vec![1.0, 2.0, 3.0];
831        let original = embedding.clone();
832        let gradient = vec![0.0, 0.0, 0.0];
833        let learning_rate = 0.01;
834
835        sgd_step(&mut embedding, &gradient, learning_rate);
836
837        // Embedding should not change with zero gradient
838        assert_eq!(embedding, original);
839    }
840
841    #[test]
842    fn test_sgd_step_zero_learning_rate() {
843        let mut embedding = vec![1.0, 2.0, 3.0];
844        let original = embedding.clone();
845        let gradient = vec![0.1, 0.2, 0.3];
846        let learning_rate = 0.0;
847
848        sgd_step(&mut embedding, &gradient, learning_rate);
849
850        // Embedding should not change with zero learning rate
851        assert_eq!(embedding, original);
852    }
853
854    #[test]
855    fn test_sgd_step_large_learning_rate() {
856        let mut embedding = vec![10.0, 20.0, 30.0];
857        let gradient = vec![1.0, 2.0, 3.0];
858        let learning_rate = 5.0;
859
860        sgd_step(&mut embedding, &gradient, learning_rate);
861
862        // Expected: embedding[i] -= learning_rate * grad[i]
863        assert!((embedding[0] - 5.0).abs() < 1e-5); // 10.0 - 5.0 * 1.0
864        assert!((embedding[1] - 10.0).abs() < 1e-5); // 20.0 - 5.0 * 2.0
865        assert!((embedding[2] - 15.0).abs() < 1e-5); // 30.0 - 5.0 * 3.0
866    }
867
868    #[test]
869    #[should_panic(expected = "Embedding and gradient must have the same length")]
870    fn test_sgd_step_mismatched_lengths() {
871        let mut embedding = vec![1.0, 2.0, 3.0];
872        let gradient = vec![0.1, 0.2]; // Wrong length
873
874        sgd_step(&mut embedding, &gradient, 0.01);
875    }
876
877    #[test]
878    fn test_info_nce_loss_multiple_positives() {
879        let anchor = vec![1.0, 0.0, 0.0];
880        let positive1 = vec![0.9, 0.1, 0.0];
881        let positive2 = vec![0.95, 0.05, 0.0];
882        let negative = vec![0.0, 1.0, 0.0];
883
884        let loss = info_nce_loss(&anchor, &[&positive1, &positive2], &[&negative], 0.07);
885
886        // Loss should be positive and finite
887        assert!(loss > 0.0);
888        assert!(loss.is_finite());
889    }
890
891    #[test]
892    fn test_contrastive_loss_gradient_property() {
893        // Test that loss decreases when positive becomes more similar
894        let anchor = vec![1.0, 0.0, 0.0];
895        let positive_far = vec![0.5, 0.5, 0.0];
896        let positive_close = vec![0.9, 0.1, 0.0];
897        let negative = vec![0.0, 1.0, 0.0];
898
899        let loss_far = info_nce_loss(&anchor, &[&positive_far], &[&negative], 0.07);
900        let loss_close = info_nce_loss(&anchor, &[&positive_close], &[&negative], 0.07);
901
902        // Loss should be lower when positive is closer to anchor
903        assert!(loss_close < loss_far);
904    }
905
906    #[test]
907    fn test_sgd_optimizer_basic() {
908        let optimizer_type = OptimizerType::Sgd {
909            learning_rate: 0.1,
910            momentum: 0.0,
911        };
912        let mut optimizer = Optimizer::new(optimizer_type);
913
914        let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
915        let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
916
917        let result = optimizer.step(&mut params, &grads);
918        assert!(result.is_ok());
919
920        // Expected: params[i] -= learning_rate * grads[i]
921        assert!((params[[0, 0]] - 0.99).abs() < 1e-6); // 1.0 - 0.1 * 0.1
922        assert!((params[[0, 1]] - 1.98).abs() < 1e-6); // 2.0 - 0.1 * 0.2
923        assert!((params[[1, 0]] - 2.97).abs() < 1e-6); // 3.0 - 0.1 * 0.3
924        assert!((params[[1, 1]] - 3.96).abs() < 1e-6); // 4.0 - 0.1 * 0.4
925    }
926
927    #[test]
928    fn test_sgd_optimizer_with_momentum() {
929        let optimizer_type = OptimizerType::Sgd {
930            learning_rate: 0.1,
931            momentum: 0.9,
932        };
933        let mut optimizer = Optimizer::new(optimizer_type);
934
935        let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
936        let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
937
938        // First step
939        let result = optimizer.step(&mut params, &grads);
940        assert!(result.is_ok());
941
942        // First step should be the same as SGD without momentum (velocity starts at 0)
943        assert!((params[[0, 0]] - 0.99).abs() < 1e-6);
944
945        // Second step should use accumulated momentum
946        let result = optimizer.step(&mut params, &grads);
947        assert!(result.is_ok());
948
949        // With momentum, the update should be larger
950        assert!(params[[0, 0]] < 0.99);
951    }
952
953    #[test]
954    fn test_adam_optimizer_basic() {
955        let optimizer_type = OptimizerType::Adam {
956            learning_rate: 0.001,
957            beta1: 0.9,
958            beta2: 0.999,
959            epsilon: 1e-8,
960        };
961        let mut optimizer = Optimizer::new(optimizer_type);
962
963        let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
964        let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
965
966        let original_params = params.clone();
967        let result = optimizer.step(&mut params, &grads);
968        assert!(result.is_ok());
969
970        // Parameters should be updated (decreased in the direction of gradients)
971        assert!(params[[0, 0]] < original_params[[0, 0]]);
972        assert!(params[[0, 1]] < original_params[[0, 1]]);
973        assert!(params[[1, 0]] < original_params[[1, 0]]);
974        assert!(params[[1, 1]] < original_params[[1, 1]]);
975
976        // Check that all values are finite
977        assert!(params.iter().all(|&x| x.is_finite()));
978    }
979
980    #[test]
981    fn test_adam_optimizer_multiple_steps() {
982        let optimizer_type = OptimizerType::Adam {
983            learning_rate: 0.01,
984            beta1: 0.9,
985            beta2: 0.999,
986            epsilon: 1e-8,
987        };
988        let mut optimizer = Optimizer::new(optimizer_type);
989
990        let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
991        let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
992        let initial_params = params.clone();
993
994        // Perform multiple steps
995        for _ in 0..10 {
996            let result = optimizer.step(&mut params, &grads);
997            assert!(result.is_ok());
998            assert!(params.iter().all(|&x| x.is_finite()));
999        }
1000
1001        // After multiple steps, parameters should have decreased (gradients are positive)
1002        assert!(params[[0, 0]] < initial_params[[0, 0]]);
1003        assert!(params[[1, 1]] < initial_params[[1, 1]]);
1004        // All parameters should have moved
1005        for i in 0..2 {
1006            for j in 0..2 {
1007                assert!(params[[i, j]] < initial_params[[i, j]]);
1008            }
1009        }
1010    }
1011
1012    #[test]
1013    fn test_adam_bias_correction() {
1014        let optimizer_type = OptimizerType::Adam {
1015            learning_rate: 0.001,
1016            beta1: 0.9,
1017            beta2: 0.999,
1018            epsilon: 1e-8,
1019        };
1020        let mut optimizer = Optimizer::new(optimizer_type.clone());
1021
1022        let mut params = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
1023        let grads = Array2::from_shape_vec((1, 1), vec![0.1]).unwrap();
1024
1025        // First step should have strong bias correction
1026        let result = optimizer.step(&mut params, &grads);
1027        assert!(result.is_ok());
1028        let first_update = 1.0 - params[[0, 0]];
1029
1030        // Reset optimizer
1031        let mut optimizer = Optimizer::new(optimizer_type);
1032        let mut params = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
1033
1034        // Perform 100 steps; later steps are affected less by bias correction
1035        for _ in 0..100 {
1036            let _ = optimizer.step(&mut params, &grads);
1037        }
1038
1039        // Bias correction diminishes over time; here we just sanity-check the first update
1040        assert!(first_update > 0.0);
1041    }
1042
1043    #[test]
1044    fn test_optimizer_shape_mismatch() {
1045        let optimizer_type = OptimizerType::Adam {
1046            learning_rate: 0.001,
1047            beta1: 0.9,
1048            beta2: 0.999,
1049            epsilon: 1e-8,
1050        };
1051        let mut optimizer = Optimizer::new(optimizer_type);
1052
1053        let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1054        let grads = Array2::from_shape_vec((3, 2), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).unwrap();
1055
1056        let result = optimizer.step(&mut params, &grads);
1057        assert!(result.is_err());
1058        if let Err(GnnError::DimensionMismatch { expected, actual }) = result {
1059            assert!(expected.contains("2, 2"));
1060            assert!(actual.contains("3, 2"));
1061        } else {
1062            panic!("Expected DimensionMismatch error");
1063        }
1064    }
1065
1066    #[test]
1067    fn test_adam_convergence() {
1068        // Test that Adam can minimize a simple quadratic function
1069        let optimizer_type = OptimizerType::Adam {
1070            learning_rate: 0.5,
1071            beta1: 0.9,
1072            beta2: 0.999,
1073            epsilon: 1e-8,
1074        };
1075        let mut optimizer = Optimizer::new(optimizer_type);
1076
1077        // Start with params far from optimum (0, 0)
1078        let mut params = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();
1079
1080        // Gradient of f(x, y) = x^2 + y^2 is (2x, 2y)
1081        for _ in 0..200 {
1082            let grads =
1083                Array2::from_shape_vec((1, 2), vec![2.0 * params[[0, 0]], 2.0 * params[[0, 1]]])
1084                    .unwrap();
1085            let _ = optimizer.step(&mut params, &grads);
1086        }
1087
1088        // Should converge close to (0, 0)
1089        assert!(params[[0, 0]].abs() < 0.5);
1090        assert!(params[[0, 1]].abs() < 0.5);
1091    }
1092
1093    #[test]
1094    fn test_sgd_momentum_convergence() {
1095        // Test that SGD with momentum can minimize a simple quadratic function
1096        let optimizer_type = OptimizerType::Sgd {
1097            learning_rate: 0.01,
1098            momentum: 0.9,
1099        };
1100        let mut optimizer = Optimizer::new(optimizer_type);
1101
1102        // Start with params far from optimum (0, 0)
1103        let mut params = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();
1104
1105        // Gradient of f(x, y) = x^2 + y^2 is (2x, 2y)
1106        for _ in 0..200 {
1107            let grads =
1108                Array2::from_shape_vec((1, 2), vec![2.0 * params[[0, 0]], 2.0 * params[[0, 1]]])
1109                    .unwrap();
1110            let _ = optimizer.step(&mut params, &grads);
1111        }
1112
1113        // Should converge close to (0, 0)
1114        assert!(params[[0, 0]].abs() < 0.5);
1115        assert!(params[[0, 1]].abs() < 0.5);
1116    }
1117
1118    // ==================== Loss Function Tests ====================
1119
1120    #[test]
1121    fn test_mse_loss_zero_when_equal() {
1122        let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1123        let target = pred.clone();
1124        let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
1125        assert!(
1126            (loss - 0.0).abs() < 1e-6,
1127            "MSE should be 0 when pred == target"
1128        );
1129    }
1130
1131    #[test]
1132    fn test_mse_loss_positive() {
1133        let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1134        let target = Array2::from_shape_vec((2, 2), vec![2.0, 3.0, 4.0, 5.0]).unwrap();
1135        let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
1136        // Each element differs by 1, so squared diff = 1, mean = 1
1137        assert!((loss - 1.0).abs() < 1e-6, "MSE should be 1.0, got {}", loss);
1138    }
1139
1140    #[test]
1141    fn test_mse_loss_varying_diffs() {
1142        let pred = Array2::from_shape_vec((1, 4), vec![0.0, 0.0, 0.0, 0.0]).unwrap();
1143        let target = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1144        let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
1145        // Squared diffs: 1, 4, 9, 16. Mean = 30/4 = 7.5
1146        assert!((loss - 7.5).abs() < 1e-6, "MSE should be 7.5, got {}", loss);
1147    }
1148
1149    #[test]
1150    fn test_mse_gradient_shape() {
1151        let pred = Array2::from_shape_vec((2, 3), vec![0.0; 6]).unwrap();
1152        let target = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
1153        let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
1154        assert_eq!(grad.shape(), pred.shape());
1155    }
1156
1157    #[test]
1158    fn test_mse_gradient_direction() {
1159        let pred = Array2::from_shape_vec((1, 2), vec![0.0, 2.0]).unwrap();
1160        let target = Array2::from_shape_vec((1, 2), vec![1.0, 1.0]).unwrap();
1161        let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
1162        // grad = 2*(pred - target)/n = 2*(-1, 1)/2 = (-1, 1)
1163        assert!(
1164            grad[[0, 0]] < 0.0,
1165            "Gradient should be negative when pred < target"
1166        );
1167        assert!(
1168            grad[[0, 1]] > 0.0,
1169            "Gradient should be positive when pred > target"
1170        );
1171    }
1172
1173    #[test]
1174    fn test_mse_gradient_zero_when_equal() {
1175        let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1176        let target = pred.clone();
1177        let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
1178        assert!(
1179            grad.iter().all(|&x| x.abs() < 1e-6),
1180            "Gradient should be zero when pred == target"
1181        );
1182    }
1183
1184    #[test]
1185    fn test_bce_loss_perfect_predictions() {
1186        let pred = Array2::from_shape_vec((1, 2), vec![0.999, 0.001]).unwrap();
1187        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
1188        let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
1189        // Near-perfect predictions should have low loss
1190        assert!(
1191            loss < 0.1,
1192            "BCE should be low for good predictions, got {}",
1193            loss
1194        );
1195    }
1196
1197    #[test]
1198    fn test_bce_loss_bad_predictions() {
1199        let pred = Array2::from_shape_vec((1, 2), vec![0.001, 0.999]).unwrap();
1200        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
1201        let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
1202        // Bad predictions should have high loss
1203        assert!(
1204            loss > 1.0,
1205            "BCE should be high for bad predictions, got {}",
1206            loss
1207        );
1208    }
1209
1210    #[test]
1211    fn test_bce_loss_numerical_stability() {
1212        // Test with extreme values that could cause numerical issues
1213        let pred = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap();
1214        let target = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap();
1215        let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
1216        assert!(
1217            loss.is_finite(),
1218            "BCE should be finite even with extreme values"
1219        );
1220    }
1221
1222    #[test]
1223    fn test_bce_gradient_shape() {
1224        let pred = Array2::from_shape_vec((3, 2), vec![0.5; 6]).unwrap();
1225        let target = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]).unwrap();
1226        let grad = Loss::gradient(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
1227        assert_eq!(grad.shape(), pred.shape());
1228    }
1229
1230    #[test]
1231    fn test_bce_gradient_direction() {
1232        let pred = Array2::from_shape_vec((1, 2), vec![0.3, 0.7]).unwrap();
1233        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
1234        let grad = Loss::gradient(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
1235        // When target=1 and pred<1, gradient should push pred up (negative gradient)
1236        assert!(
1237            grad[[0, 0]] < 0.0,
1238            "Gradient should be negative to increase pred towards 1"
1239        );
1240        // When target=0 and pred>0, gradient should push pred down (positive gradient)
1241        assert!(
1242            grad[[0, 1]] > 0.0,
1243            "Gradient should be positive to decrease pred towards 0"
1244        );
1245    }
1246
1247    #[test]
1248    fn test_cross_entropy_one_hot() {
1249        // Softmax-like predictions (sum to 1 per row)
1250        let pred = Array2::from_shape_vec((2, 3), vec![0.7, 0.2, 0.1, 0.1, 0.8, 0.1]).unwrap();
1251        let target = Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).unwrap();
1252        let loss = Loss::compute(LossType::CrossEntropy, &pred, &target).unwrap();
1253        // Good predictions should have reasonable loss
1254        assert!(
1255            loss > 0.0 && loss < 1.0,
1256            "CE should be reasonable for good predictions, got {}",
1257            loss
1258        );
1259    }
1260
1261    #[test]
1262    fn test_cross_entropy_wrong_class() {
1263        let pred = Array2::from_shape_vec((1, 3), vec![0.1, 0.1, 0.8]).unwrap();
1264        let target = Array2::from_shape_vec((1, 3), vec![1.0, 0.0, 0.0]).unwrap();
1265        let loss = Loss::compute(LossType::CrossEntropy, &pred, &target).unwrap();
1266        // Predicting wrong class should have high loss
1267        assert!(
1268            loss > 1.0,
1269            "CE should be high for wrong predictions, got {}",
1270            loss
1271        );
1272    }
1273
1274    #[test]
1275    fn test_cross_entropy_gradient_shape() {
1276        let pred = Array2::from_shape_vec((2, 4), vec![0.25; 8]).unwrap();
1277        let target =
1278            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();
1279        let grad = Loss::gradient(LossType::CrossEntropy, &pred, &target).unwrap();
1280        assert_eq!(grad.shape(), pred.shape());
1281    }
1282
1283    #[test]
1284    fn test_loss_dimension_mismatch_error() {
1285        let pred = Array2::from_shape_vec((2, 2), vec![1.0; 4]).unwrap();
1286        let target = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
1287
1288        let result = Loss::compute(LossType::Mse, &pred, &target);
1289        assert!(result.is_err(), "Should error on dimension mismatch");
1290
1291        let result = Loss::gradient(LossType::Mse, &pred, &target);
1292        assert!(
1293            result.is_err(),
1294            "Gradient should error on dimension mismatch"
1295        );
1296    }
1297
1298    #[test]
1299    fn test_loss_empty_array_error() {
1300        let pred = Array2::from_shape_vec((0, 2), vec![]).unwrap();
1301        let target = Array2::from_shape_vec((0, 2), vec![]).unwrap();
1302
1303        let result = Loss::compute(LossType::Mse, &pred, &target);
1304        assert!(result.is_err(), "Should error on empty arrays");
1305
1306        let result = Loss::gradient(LossType::Mse, &pred, &target);
1307        assert!(result.is_err(), "Gradient should error on empty arrays");
1308    }
1309
1310    #[test]
1311    fn test_loss_gradient_numerical_check() {
1312        // Numerical gradient check for MSE
1313        let pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.8]).unwrap();
1314        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
1315
1316        let analytical_grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
1317
1318        // Compute numerical gradient
1319        let eps = 1e-5;
1320        for i in 0..2 {
1321            let mut pred_plus = pred.clone();
1322            let mut pred_minus = pred.clone();
1323            pred_plus[[0, i]] += eps;
1324            pred_minus[[0, i]] -= eps;
1325
1326            let loss_plus = Loss::compute(LossType::Mse, &pred_plus, &target).unwrap();
1327            let loss_minus = Loss::compute(LossType::Mse, &pred_minus, &target).unwrap();
1328
1329            let numerical_grad = (loss_plus - loss_minus) / (2.0 * eps);
1330            let error = (analytical_grad[[0, i]] - numerical_grad).abs();
1331
1332            assert!(
1333                error < 1e-3,
1334                "Numerical gradient check failed: analytical={}, numerical={}",
1335                analytical_grad[[0, i]],
1336                numerical_grad
1337            );
1338        }
1339    }
1340
1341    #[test]
1342    fn test_training_loop_integration() {
1343        // Integration test: use Loss with Optimizer
1344        let mut optimizer = Optimizer::new(OptimizerType::Sgd {
1345            learning_rate: 0.1,
1346            momentum: 0.0,
1347        });
1348
1349        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
1350        let mut pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.5]).unwrap();
1351
1352        let initial_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
1353
1354        // Perform a few optimization steps
1355        for _ in 0..10 {
1356            let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
1357            optimizer.step(&mut pred, &grad).unwrap();
1358        }
1359
1360        let final_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
1361
1362        assert!(
1363            final_loss < initial_loss,
1364            "Loss should decrease during training"
1365        );
1366    }
1367}