1use crate::activation::Activation;
10use crate::SklearsError;
11use scirs2_core::ndarray::{Array1, Array2, Axis};
12use scirs2_core::random::{Rng, SeedableRng};
13use sklears_core::{
14 error::Result,
15 traits::{Estimator, Fit, Trained, Transform, Untrained},
16 types::Float,
17};
18use std::marker::PhantomData;
19
/// Gradient-based optimizer requested for training an autoencoder.
///
/// NOTE(review): the `fit` implementation in this file performs plain gradient-descent
/// updates and does not dispatch on this value; `Adam` and `RMSprop` are currently only
/// recorded in the configuration.
// `SGD` is part of the public API, so clippy's acronym-casing lint is silenced rather than
// renaming the variant.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum OptimizerType {
    /// Stochastic gradient descent (the default).
    #[default]
    SGD,
    /// Adam optimizer.
    Adam,
    /// RMSprop optimizer.
    RMSprop,
}
27
/// Hyper-parameters for `Autoencoder`.
///
/// Populated by `AutoencoderConfig::default` and adjusted through the builder methods on
/// `Autoencoder<Untrained>`.
#[derive(Debug, Clone)]
pub struct AutoencoderConfig {
    /// Width of the latent (hidden) layer, i.e. the number of columns `transform` returns.
    pub encoding_dim: usize,
    /// Optional sizes of intermediate encoder layers.
    /// NOTE(review): the single-hidden-layer `fit` in this file does not read this — confirm.
    pub encoder_layers: Option<Vec<usize>>,
    /// Activation applied to the hidden layer; the decoder output uses `Activation::Identity`.
    pub activation: Activation,
    /// Step size used for the gradient-descent parameter updates.
    pub learning_rate: Float,
    /// Number of passes over the training data.
    pub n_epochs: usize,
    /// Mini-batch size; `fit` clips it to the number of samples.
    pub batch_size: usize,
    /// Seed for weight initialisation and shuffling. When `None`, `fit` seeds with 42 and
    /// does not shuffle the samples.
    pub random_state: Option<u64>,
    /// Input-noise level, presumably for denoising-autoencoder training — confirm in `fit`.
    pub noise_factor: Float,
    /// L2 regularisation strength, presumably weight decay — confirm in `fit`.
    pub l2_reg: Float,
    /// `(rho, beta)` pair set via `sparsity(rho, beta)`; presumably target mean activation
    /// and penalty weight. NOTE(review): not read by the `fit` in this file — confirm.
    pub sparsity: Option<(Float, Float)>,
    /// Requested optimizer. NOTE(review): the `fit` in this file does not dispatch on it.
    pub optimizer: OptimizerType,
}
54
55impl Default for AutoencoderConfig {
56 fn default() -> Self {
57 Self {
58 encoding_dim: 32,
59 encoder_layers: None,
60 activation: Activation::Relu,
61 learning_rate: 0.01,
62 n_epochs: 100,
63 batch_size: 32,
64 random_state: None,
65 noise_factor: 0.0,
66 l2_reg: 0.0,
67 sparsity: None,
68 optimizer: OptimizerType::SGD,
69 }
70 }
71}
72
/// Single-hidden-layer autoencoder estimator.
///
/// `State` is a type-state marker: `Autoencoder<Untrained>` exposes the builder methods and
/// `fit`, while `Autoencoder<Trained>` exposes `transform` and `reconstruct`.
#[derive(Debug, Clone)]
pub struct Autoencoder<State = Untrained> {
    /// Hyper-parameters used for training.
    config: AutoencoderConfig,
    /// Zero-sized marker recording the training state.
    state: PhantomData<State>,
    /// Encoder weights, shape `(n_features, encoding_dim)`; `None` until fitted.
    encoder_weights_: Option<Array2<Float>>,
    /// Encoder bias, length `encoding_dim`; `None` until fitted.
    encoder_bias_: Option<Array1<Float>>,
    /// Decoder weights, shape `(encoding_dim, n_features)`; `None` until fitted.
    decoder_weights_: Option<Array2<Float>>,
    /// Decoder bias, length `n_features`; `None` until fitted.
    decoder_bias_: Option<Array1<Float>>,
    /// Number of input columns seen by `fit`.
    _n_features: Option<usize>,
}
85
86impl Autoencoder<Untrained> {
87 pub fn new() -> Self {
89 Self {
90 config: AutoencoderConfig::default(),
91 state: PhantomData,
92 encoder_weights_: None,
93 encoder_bias_: None,
94 decoder_weights_: None,
95 decoder_bias_: None,
96 _n_features: None,
97 }
98 }
99
100 pub fn encoding_dim(mut self, dim: usize) -> Self {
102 self.config.encoding_dim = dim;
103 self
104 }
105
106 pub fn activation(mut self, activation: Activation) -> Self {
108 self.config.activation = activation;
109 self
110 }
111
112 pub fn learning_rate(mut self, lr: Float) -> Self {
114 self.config.learning_rate = lr;
115 self
116 }
117
118 pub fn n_epochs(mut self, epochs: usize) -> Self {
120 self.config.n_epochs = epochs;
121 self
122 }
123
124 pub fn batch_size(mut self, size: usize) -> Self {
126 self.config.batch_size = size;
127 self
128 }
129
130 pub fn random_state(mut self, seed: u64) -> Self {
132 self.config.random_state = Some(seed);
133 self
134 }
135
136 pub fn encoder_layers(mut self, layers: Vec<usize>) -> Self {
138 self.config.encoder_layers = Some(layers);
139 self
140 }
141
142 pub fn noise_factor(mut self, factor: Float) -> Self {
144 self.config.noise_factor = factor;
145 self
146 }
147
148 pub fn l2_reg(mut self, reg: Float) -> Self {
150 self.config.l2_reg = reg;
151 self
152 }
153
154 pub fn sparsity(mut self, rho: Float, beta: Float) -> Self {
156 self.config.sparsity = Some((rho, beta));
157 self
158 }
159
160 pub fn optimizer(mut self, opt: OptimizerType) -> Self {
162 self.config.optimizer = opt;
163 self
164 }
165}
166
167impl Default for Autoencoder<Untrained> {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173impl Estimator for Autoencoder<Untrained> {
174 type Config = AutoencoderConfig;
175 type Error = SklearsError;
176 type Float = Float;
177
178 fn config(&self) -> &Self::Config {
179 &self.config
180 }
181}
182
183impl Fit<Array2<Float>, ()> for Autoencoder<Untrained> {
184 type Fitted = Autoencoder<Trained>;
185
186 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
187 let n_samples = x.nrows();
188 let n_features = x.ncols();
189 let encoding_dim = self.config.encoding_dim;
190
191 let mut rng = if let Some(seed) = self.config.random_state {
193 scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
194 } else {
195 scirs2_core::random::rngs::StdRng::seed_from_u64(42) };
197
198 let limit_enc = (6.0 / (n_features + encoding_dim) as Float).sqrt();
199 let limit_dec = (6.0 / (encoding_dim + n_features) as Float).sqrt();
200
201 let mut encoder_weights = Array2::zeros((n_features, encoding_dim));
202 let mut decoder_weights = Array2::zeros((encoding_dim, n_features));
203
204 for i in 0..n_features {
205 for j in 0..encoding_dim {
206 encoder_weights[[i, j]] = rng.gen_range(-limit_enc..limit_enc);
207 }
208 }
209
210 for i in 0..encoding_dim {
211 for j in 0..n_features {
212 decoder_weights[[i, j]] = rng.gen_range(-limit_dec..limit_dec);
213 }
214 }
215
216 let mut encoder_bias = Array1::zeros(encoding_dim);
217 let mut decoder_bias = Array1::zeros(n_features);
218
219 let batch_size = self.config.batch_size.min(n_samples);
221 let n_batches = n_samples.div_ceil(batch_size);
222
223 for epoch in 0..self.config.n_epochs {
224 let mut epoch_loss = 0.0;
225
226 let mut indices: Vec<usize> = (0..n_samples).collect();
228 if self.config.random_state.is_some() {
229 use scirs2_core::random::seq::SliceRandom;
230 indices.shuffle(&mut rng);
231 }
232
233 for batch_idx in 0..n_batches {
234 let start = batch_idx * batch_size;
235 let end = ((batch_idx + 1) * batch_size).min(n_samples);
236 let batch_indices = &indices[start..end];
237 let actual_batch_size = batch_indices.len();
238
239 let mut batch = Array2::zeros((actual_batch_size, n_features));
241 for (i, &idx) in batch_indices.iter().enumerate() {
242 batch.row_mut(i).assign(&x.row(idx));
243 }
244
245 let z_enc = batch.dot(&encoder_weights) + &encoder_bias;
247 let a_enc = self.config.activation.apply(&z_enc);
248
249 let z_dec = a_enc.dot(&decoder_weights) + &decoder_bias;
250 let a_dec = Activation::Identity.apply(&z_dec); let diff = &batch - &a_dec;
254 let loss = diff.mapv(|x| x * x).sum() / (actual_batch_size * n_features) as Float;
255 epoch_loss += loss * actual_batch_size as Float;
256
257 let d_output =
259 (&a_dec - &batch) * (2.0 / (actual_batch_size * n_features) as Float);
260
261 let d_w_dec = a_enc.t().dot(&d_output);
263 let d_b_dec = d_output.sum_axis(Axis(0));
264
265 let d_enc = d_output.dot(&decoder_weights.t());
267 let d_enc = &d_enc * &self.config.activation.derivative(&a_enc);
268
269 let d_w_enc = batch.t().dot(&d_enc);
271 let d_b_enc = d_enc.sum_axis(Axis(0));
272
273 encoder_weights = &encoder_weights - &d_w_enc * self.config.learning_rate;
275 encoder_bias = &encoder_bias - &d_b_enc * self.config.learning_rate;
276 decoder_weights = &decoder_weights - &d_w_dec * self.config.learning_rate;
277 decoder_bias = &decoder_bias - &d_b_dec * self.config.learning_rate;
278 }
279
280 epoch_loss /= n_samples as Float;
281
282 if epoch % 10 == 0 {
283 println!("Epoch {epoch}: Loss = {epoch_loss:.6}");
284 }
285 }
286
287 Ok(Autoencoder {
288 config: self.config,
289 state: PhantomData,
290 encoder_weights_: Some(encoder_weights),
291 encoder_bias_: Some(encoder_bias),
292 decoder_weights_: Some(decoder_weights),
293 decoder_bias_: Some(decoder_bias),
294 _n_features: Some(n_features),
295 })
296 }
297}
298
299impl Transform<Array2<Float>, Array2<Float>> for Autoencoder<Trained> {
300 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
302 let encoder_weights =
303 self.encoder_weights_
304 .as_ref()
305 .ok_or_else(|| SklearsError::NotFitted {
306 operation: "transform".to_string(),
307 })?;
308 let encoder_bias = self
309 .encoder_bias_
310 .as_ref()
311 .ok_or_else(|| SklearsError::NotFitted {
312 operation: "transform".to_string(),
313 })?;
314
315 let z = x.dot(encoder_weights) + encoder_bias;
316 Ok(self.config.activation.apply(&z))
317 }
318}
319
320impl Autoencoder<Trained> {
321 pub fn reconstruct(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
323 let encoded = self.transform(x)?;
324
325 let decoder_weights =
326 self.decoder_weights_
327 .as_ref()
328 .ok_or_else(|| SklearsError::NotFitted {
329 operation: "reconstruct".to_string(),
330 })?;
331 let decoder_bias = self
332 .decoder_bias_
333 .as_ref()
334 .ok_or_else(|| SklearsError::NotFitted {
335 operation: "reconstruct".to_string(),
336 })?;
337
338 let z = encoded.dot(decoder_weights) + decoder_bias;
339 Ok(Activation::Identity.apply(&z))
340 }
341}
342
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder setters must be reflected in the stored configuration.
    #[test]
    fn test_autoencoder_construction() {
        let cfg = Autoencoder::new()
            .n_epochs(10)
            .learning_rate(0.01)
            .activation(Activation::Relu)
            .encoding_dim(16)
            .config;

        assert_eq!(cfg.encoding_dim, 16);
        assert_eq!(cfg.learning_rate, 0.01);
        assert_eq!(cfg.n_epochs, 10);
    }

    /// A fitted model yields codes of width `encoding_dim` and input-shaped reconstructions.
    #[test]
    fn test_autoencoder_fit_transform() {
        // 10 samples x 5 features holding 0.0, 0.1, ..., 4.9 in row-major order.
        let (n_rows, n_cols) = (10, 5);
        let x = Array2::from_shape_fn((n_rows, n_cols), |(r, c)| {
            (r * n_cols + c) as Float / 10.0
        });

        let fitted = Autoencoder::new()
            .encoding_dim(3)
            .n_epochs(20)
            .learning_rate(0.1)
            .random_state(42)
            .fit(&x, &())
            .expect("fit should succeed on a small dense matrix");

        let encoded = fitted.transform(&x).expect("transform after fit");
        assert_eq!(encoded.shape(), &[n_rows, 3]);

        let reconstructed = fitted.reconstruct(&x).expect("reconstruct after fit");
        assert_eq!(reconstructed.shape(), x.shape());
    }
}