sklears_neural/
autoencoder.rs

//! Autoencoder implementation for unsupervised feature learning.
//!
//! This module provides a single-hidden-layer autoencoder for dimensionality
//! reduction. [`AutoencoderConfig`] additionally carries settings intended for
//! denoising (`noise_factor`), sparse (`sparsity`) and deep (`encoder_layers`)
//! variants. Consult the `fit` implementation for which of these settings the
//! training loop currently honors.
8
9use crate::activation::Activation;
10use crate::SklearsError;
11use scirs2_core::ndarray::{Array1, Array2, Axis};
12use scirs2_core::random::{Rng, SeedableRng};
13use sklears_core::{
14    error::Result,
15    traits::{Estimator, Fit, Trained, Transform, Untrained},
16    types::Float,
17};
18use std::marker::PhantomData;
19
/// Optimizer used to apply gradient updates during training.
///
/// Besides `Copy`, the type implements `PartialEq`, `Eq` and `Hash`. This lets callers
/// compare optimizer choices (e.g. `cfg.optimizer == OptimizerType::SGD`) and use them as
/// map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizerType {
    /// Plain (mini-batch) stochastic gradient descent.
    SGD,
    /// Adam optimizer (adaptive moment estimation).
    Adam,
    /// RMSprop optimizer.
    RMSprop,
}
27
28/// Autoencoder configuration
29#[derive(Debug, Clone)]
30pub struct AutoencoderConfig {
31    /// Size of hidden layer (encoding dimension) - for simple autoencoder
32    pub encoding_dim: usize,
33    /// Encoder layer sizes - for deep autoencoder
34    pub encoder_layers: Option<Vec<usize>>,
35    /// Activation function
36    pub activation: Activation,
37    /// Learning rate
38    pub learning_rate: Float,
39    /// Number of epochs
40    pub n_epochs: usize,
41    /// Batch size
42    pub batch_size: usize,
43    /// Random state
44    pub random_state: Option<u64>,
45    /// Noise factor for denoising autoencoder
46    pub noise_factor: Float,
47    /// L2 regularization parameter
48    pub l2_reg: Float,
49    /// Sparsity parameters (rho, beta) for sparse autoencoder
50    pub sparsity: Option<(Float, Float)>,
51    /// Optimizer type
52    pub optimizer: OptimizerType,
53}
54
55impl Default for AutoencoderConfig {
56    fn default() -> Self {
57        Self {
58            encoding_dim: 32,
59            encoder_layers: None,
60            activation: Activation::Relu,
61            learning_rate: 0.01,
62            n_epochs: 100,
63            batch_size: 32,
64            random_state: None,
65            noise_factor: 0.0,
66            l2_reg: 0.0,
67            sparsity: None,
68            optimizer: OptimizerType::SGD,
69        }
70    }
71}
72
73/// Simple Autoencoder model
74#[derive(Debug, Clone)]
75pub struct Autoencoder<State = Untrained> {
76    config: AutoencoderConfig,
77    state: PhantomData<State>,
78    // Trained parameters
79    encoder_weights_: Option<Array2<Float>>,
80    encoder_bias_: Option<Array1<Float>>,
81    decoder_weights_: Option<Array2<Float>>,
82    decoder_bias_: Option<Array1<Float>>,
83    _n_features: Option<usize>,
84}
85
86impl Autoencoder<Untrained> {
87    /// Create a new autoencoder
88    pub fn new() -> Self {
89        Self {
90            config: AutoencoderConfig::default(),
91            state: PhantomData,
92            encoder_weights_: None,
93            encoder_bias_: None,
94            decoder_weights_: None,
95            decoder_bias_: None,
96            _n_features: None,
97        }
98    }
99
100    /// Set encoding dimension
101    pub fn encoding_dim(mut self, dim: usize) -> Self {
102        self.config.encoding_dim = dim;
103        self
104    }
105
106    /// Set activation function
107    pub fn activation(mut self, activation: Activation) -> Self {
108        self.config.activation = activation;
109        self
110    }
111
112    /// Set learning rate
113    pub fn learning_rate(mut self, lr: Float) -> Self {
114        self.config.learning_rate = lr;
115        self
116    }
117
118    /// Set number of epochs
119    pub fn n_epochs(mut self, epochs: usize) -> Self {
120        self.config.n_epochs = epochs;
121        self
122    }
123
124    /// Set batch size
125    pub fn batch_size(mut self, size: usize) -> Self {
126        self.config.batch_size = size;
127        self
128    }
129
130    /// Set random state
131    pub fn random_state(mut self, seed: u64) -> Self {
132        self.config.random_state = Some(seed);
133        self
134    }
135
136    /// Set encoder layers for deep autoencoder
137    pub fn encoder_layers(mut self, layers: Vec<usize>) -> Self {
138        self.config.encoder_layers = Some(layers);
139        self
140    }
141
142    /// Set noise factor for denoising autoencoder
143    pub fn noise_factor(mut self, factor: Float) -> Self {
144        self.config.noise_factor = factor;
145        self
146    }
147
148    /// Set L2 regularization parameter
149    pub fn l2_reg(mut self, reg: Float) -> Self {
150        self.config.l2_reg = reg;
151        self
152    }
153
154    /// Set sparsity parameters (rho, beta)
155    pub fn sparsity(mut self, rho: Float, beta: Float) -> Self {
156        self.config.sparsity = Some((rho, beta));
157        self
158    }
159
160    /// Set optimizer type
161    pub fn optimizer(mut self, opt: OptimizerType) -> Self {
162        self.config.optimizer = opt;
163        self
164    }
165}
166
167impl Default for Autoencoder<Untrained> {
168    fn default() -> Self {
169        Self::new()
170    }
171}
172
173impl Estimator for Autoencoder<Untrained> {
174    type Config = AutoencoderConfig;
175    type Error = SklearsError;
176    type Float = Float;
177
178    fn config(&self) -> &Self::Config {
179        &self.config
180    }
181}
182
183impl Fit<Array2<Float>, ()> for Autoencoder<Untrained> {
184    type Fitted = Autoencoder<Trained>;
185
186    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
187        let n_samples = x.nrows();
188        let n_features = x.ncols();
189        let encoding_dim = self.config.encoding_dim;
190
191        // Initialize weights with Xavier initialization
192        let mut rng = if let Some(seed) = self.config.random_state {
193            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
194        } else {
195            scirs2_core::random::rngs::StdRng::seed_from_u64(42) // Use default seed if none provided
196        };
197
198        let limit_enc = (6.0 / (n_features + encoding_dim) as Float).sqrt();
199        let limit_dec = (6.0 / (encoding_dim + n_features) as Float).sqrt();
200
201        let mut encoder_weights = Array2::zeros((n_features, encoding_dim));
202        let mut decoder_weights = Array2::zeros((encoding_dim, n_features));
203
204        for i in 0..n_features {
205            for j in 0..encoding_dim {
206                encoder_weights[[i, j]] = rng.gen_range(-limit_enc..limit_enc);
207            }
208        }
209
210        for i in 0..encoding_dim {
211            for j in 0..n_features {
212                decoder_weights[[i, j]] = rng.gen_range(-limit_dec..limit_dec);
213            }
214        }
215
216        let mut encoder_bias = Array1::zeros(encoding_dim);
217        let mut decoder_bias = Array1::zeros(n_features);
218
219        // Training loop
220        let batch_size = self.config.batch_size.min(n_samples);
221        let n_batches = n_samples.div_ceil(batch_size);
222
223        for epoch in 0..self.config.n_epochs {
224            let mut epoch_loss = 0.0;
225
226            // Shuffle indices
227            let mut indices: Vec<usize> = (0..n_samples).collect();
228            if self.config.random_state.is_some() {
229                use scirs2_core::random::seq::SliceRandom;
230                indices.shuffle(&mut rng);
231            }
232
233            for batch_idx in 0..n_batches {
234                let start = batch_idx * batch_size;
235                let end = ((batch_idx + 1) * batch_size).min(n_samples);
236                let batch_indices = &indices[start..end];
237                let actual_batch_size = batch_indices.len();
238
239                // Get batch
240                let mut batch = Array2::zeros((actual_batch_size, n_features));
241                for (i, &idx) in batch_indices.iter().enumerate() {
242                    batch.row_mut(i).assign(&x.row(idx));
243                }
244
245                // Forward pass
246                let z_enc = batch.dot(&encoder_weights) + &encoder_bias;
247                let a_enc = self.config.activation.apply(&z_enc);
248
249                let z_dec = a_enc.dot(&decoder_weights) + &decoder_bias;
250                let a_dec = Activation::Identity.apply(&z_dec); // Linear output
251
252                // Compute loss (MSE)
253                let diff = &batch - &a_dec;
254                let loss = diff.mapv(|x| x * x).sum() / (actual_batch_size * n_features) as Float;
255                epoch_loss += loss * actual_batch_size as Float;
256
257                // Backward pass
258                let d_output =
259                    (&a_dec - &batch) * (2.0 / (actual_batch_size * n_features) as Float);
260
261                // Decoder gradients
262                let d_w_dec = a_enc.t().dot(&d_output);
263                let d_b_dec = d_output.sum_axis(Axis(0));
264
265                // Propagate through decoder
266                let d_enc = d_output.dot(&decoder_weights.t());
267                let d_enc = &d_enc * &self.config.activation.derivative(&a_enc);
268
269                // Encoder gradients
270                let d_w_enc = batch.t().dot(&d_enc);
271                let d_b_enc = d_enc.sum_axis(Axis(0));
272
273                // Update weights (simple gradient descent)
274                encoder_weights = &encoder_weights - &d_w_enc * self.config.learning_rate;
275                encoder_bias = &encoder_bias - &d_b_enc * self.config.learning_rate;
276                decoder_weights = &decoder_weights - &d_w_dec * self.config.learning_rate;
277                decoder_bias = &decoder_bias - &d_b_dec * self.config.learning_rate;
278            }
279
280            epoch_loss /= n_samples as Float;
281
282            if epoch % 10 == 0 {
283                println!("Epoch {epoch}: Loss = {epoch_loss:.6}");
284            }
285        }
286
287        Ok(Autoencoder {
288            config: self.config,
289            state: PhantomData,
290            encoder_weights_: Some(encoder_weights),
291            encoder_bias_: Some(encoder_bias),
292            decoder_weights_: Some(decoder_weights),
293            decoder_bias_: Some(decoder_bias),
294            _n_features: Some(n_features),
295        })
296    }
297}
298
299impl Transform<Array2<Float>, Array2<Float>> for Autoencoder<Trained> {
300    /// Transform data to encoded representation
301    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
302        let encoder_weights =
303            self.encoder_weights_
304                .as_ref()
305                .ok_or_else(|| SklearsError::NotFitted {
306                    operation: "transform".to_string(),
307                })?;
308        let encoder_bias = self
309            .encoder_bias_
310            .as_ref()
311            .ok_or_else(|| SklearsError::NotFitted {
312                operation: "transform".to_string(),
313            })?;
314
315        let z = x.dot(encoder_weights) + encoder_bias;
316        Ok(self.config.activation.apply(&z))
317    }
318}
319
320impl Autoencoder<Trained> {
321    /// Reconstruct data (encode then decode)
322    pub fn reconstruct(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
323        let encoded = self.transform(x)?;
324
325        let decoder_weights =
326            self.decoder_weights_
327                .as_ref()
328                .ok_or_else(|| SklearsError::NotFitted {
329                    operation: "reconstruct".to_string(),
330                })?;
331        let decoder_bias = self
332            .decoder_bias_
333            .as_ref()
334            .ok_or_else(|| SklearsError::NotFitted {
335                operation: "reconstruct".to_string(),
336            })?;
337
338        let z = encoded.dot(decoder_weights) + decoder_bias;
339        Ok(Activation::Identity.apply(&z))
340    }
341}
342
// NOTE(review): none of the tests below use non-snake-case identifiers; this allow looks removable.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder setters must be reflected in the stored configuration.
    #[test]
    fn test_autoencoder_construction() {
        let model = Autoencoder::new()
            .n_epochs(10)
            .learning_rate(0.01)
            .activation(Activation::Relu)
            .encoding_dim(16);

        let cfg = &model.config;
        assert_eq!(cfg.encoding_dim, 16);
        assert_eq!(cfg.learning_rate, 0.01);
        assert_eq!(cfg.n_epochs, 10);
    }

    /// Fitting on a small ramp dataset yields encodings and reconstructions of the
    /// expected shapes.
    #[test]
    fn test_autoencoder_fit_transform() {
        let ramp: Vec<Float> = (0..50).map(|i| i as Float / 10.0).collect();
        let data = Array2::from_shape_vec((10, 5), ramp).unwrap();

        let model = Autoencoder::new()
            .encoding_dim(3)
            .n_epochs(20)
            .learning_rate(0.1)
            .random_state(42)
            .fit(&data, &())
            .unwrap();

        let codes = model.transform(&data).unwrap();
        assert_eq!(codes.shape(), &[10, 3]);

        let rebuilt = model.reconstruct(&data).unwrap();
        assert_eq!(rebuilt.shape(), data.shape());
    }
}
382}