sklears_semi_supervised/contrastive_learning/
mod.rs1mod contrastive_predictive_coding;
8mod momentum_contrast;
9mod simclr;
10mod supervised_contrastive;
11
12pub use contrastive_predictive_coding::*;
13pub use momentum_contrast::*;
14pub use simclr::*;
15pub use supervised_contrastive::*;
16
17use scirs2_core::ndarray_ext::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
18use scirs2_core::random::Random;
19use sklears_core::error::{Result, SklearsError};
21use sklears_core::traits::{Estimator, Fit, Predict, PredictProba};
22use thiserror::Error;
23
24#[derive(Error, Debug)]
25pub enum ContrastiveLearningError {
26 #[error("Invalid temperature parameter: {0}")]
27 InvalidTemperature(f64),
28 #[error("Invalid augmentation strength: {0}")]
29 InvalidAugmentationStrength(f64),
30 #[error("Invalid batch size: {0}")]
31 InvalidBatchSize(usize),
32 #[error("Insufficient labeled samples for contrastive learning")]
33 InsufficientLabeledSamples,
34 #[error("Embedding dimension mismatch: expected {expected}, got {actual}")]
35 EmbeddingDimensionMismatch { expected: usize, actual: usize },
36 #[error("Matrix operation failed: {0}")]
37 MatrixOperationFailed(String),
38}
39
40impl From<ContrastiveLearningError> for SklearsError {
41 fn from(err: ContrastiveLearningError) -> Self {
42 SklearsError::FitError(err.to_string())
43 }
44}