sklears_semi_supervised/contrastive_learning/
mod.rs

//! Contrastive learning methods for semi-supervised learning.
//!
//! This module implements several contrastive learning approaches that can be used
//! for semi-supervised learning: they learn representations that pull similar
//! samples together while pushing dissimilar samples apart.
6
7mod contrastive_predictive_coding;
8mod momentum_contrast;
9mod simclr;
10mod supervised_contrastive;
11
12pub use contrastive_predictive_coding::*;
13pub use momentum_contrast::*;
14pub use simclr::*;
15pub use supervised_contrastive::*;
16
17use scirs2_core::ndarray_ext::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
18use scirs2_core::random::Random;
19// use scirs2_core::random::rand::seq::SliceRandom;
20use sklears_core::error::{Result, SklearsError};
21use sklears_core::traits::{Estimator, Fit, Predict, PredictProba};
22use thiserror::Error;
23
24#[derive(Error, Debug)]
25pub enum ContrastiveLearningError {
26    #[error("Invalid temperature parameter: {0}")]
27    InvalidTemperature(f64),
28    #[error("Invalid augmentation strength: {0}")]
29    InvalidAugmentationStrength(f64),
30    #[error("Invalid batch size: {0}")]
31    InvalidBatchSize(usize),
32    #[error("Insufficient labeled samples for contrastive learning")]
33    InsufficientLabeledSamples,
34    #[error("Embedding dimension mismatch: expected {expected}, got {actual}")]
35    EmbeddingDimensionMismatch { expected: usize, actual: usize },
36    #[error("Matrix operation failed: {0}")]
37    MatrixOperationFailed(String),
38}
39
40impl From<ContrastiveLearningError> for SklearsError {
41    fn from(err: ContrastiveLearningError) -> Self {
42        SklearsError::FitError(err.to_string())
43    }
44}