use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::Random;
use sklears_core::error::{Result, SklearsError};
use sklears_core::traits::{Estimator, Fit, Predict, PredictProba};
use thiserror::Error;

#[derive(Error, Debug)]
pub enum DeepBeliefNetworkError {
    #[error("Invalid layer size: {0}")]
    InvalidLayerSize(usize),
    #[error("Invalid learning rate: {0}")]
    InvalidLearningRate(f64),
    #[error("Invalid number of epochs: {0}")]
    InvalidEpochs(usize),
    #[error("Invalid batch size: {0}")]
    InvalidBatchSize(usize),
    #[error("Invalid number of Gibbs steps: {0}")]
    InvalidGibbsSteps(usize),
    #[error("Empty hidden layers")]
    EmptyHiddenLayers,
    #[error("Insufficient labeled samples")]
    InsufficientLabeledSamples,
    #[error("Matrix operation failed: {0}")]
    MatrixOperationFailed(String),
    #[error("RBM training failed: {0}")]
    RBMTrainingFailed(String),
}

impl From<DeepBeliefNetworkError> for SklearsError {
    fn from(err: DeepBeliefNetworkError) -> Self {
        SklearsError::FitError(err.to_string())
    }
}

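/// Restricted Boltzmann Machine (RBM) with binary visible and hidden units,
/// trained by contrastive divergence (CD-k). Serves as the building block for
/// layer-wise pre-training of the deep belief network defined below.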
#[derive(Debug, Clone)]
pub struct RestrictedBoltzmannMachine {
    pub n_visible: usize,
    pub n_hidden: usize,
    pub learning_rate: f64,
    pub n_epochs: usize,
    pub batch_size: usize,
    pub n_gibbs_steps: usize,
    pub random_state: Option<u64>,
    weights: Array2<f64>,
    visible_bias: Array1<f64>,
    hidden_bias: Array1<f64>,
}

impl RestrictedBoltzmannMachine {
    pub fn new(n_visible: usize, n_hidden: usize) -> Result<Self> {
        if n_visible == 0 {
            return Err(DeepBeliefNetworkError::InvalidLayerSize(n_visible).into());
        }
        if n_hidden == 0 {
            return Err(DeepBeliefNetworkError::InvalidLayerSize(n_hidden).into());
        }

        Ok(Self {
            n_visible,
            n_hidden,
            learning_rate: 0.01,
            n_epochs: 10,
            batch_size: 32,
            n_gibbs_steps: 1,
            random_state: None,
            weights: Array2::zeros((n_visible, n_hidden)),
            visible_bias: Array1::zeros(n_visible),
            hidden_bias: Array1::zeros(n_hidden),
        })
    }

    pub fn learning_rate(mut self, learning_rate: f64) -> Result<Self> {
        if learning_rate <= 0.0 {
            return Err(DeepBeliefNetworkError::InvalidLearningRate(learning_rate).into());
        }
        self.learning_rate = learning_rate;
        Ok(self)
    }

    pub fn n_epochs(mut self, n_epochs: usize) -> Result<Self> {
        if n_epochs == 0 {
            return Err(DeepBeliefNetworkError::InvalidEpochs(n_epochs).into());
        }
        self.n_epochs = n_epochs;
        Ok(self)
    }

    pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
        if batch_size == 0 {
            return Err(DeepBeliefNetworkError::InvalidBatchSize(batch_size).into());
        }
        self.batch_size = batch_size;
        Ok(self)
    }

    pub fn n_gibbs_steps(mut self, n_gibbs_steps: usize) -> Result<Self> {
        if n_gibbs_steps == 0 {
            return Err(DeepBeliefNetworkError::InvalidGibbsSteps(n_gibbs_steps).into());
        }
        self.n_gibbs_steps = n_gibbs_steps;
        Ok(self)
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

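    /// Initialize the weights with small Gaussian noise (standard deviation 0.01)
    /// generated via the Box-Muller transform; both bias vectors start at zero.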
    fn initialize_weights(&mut self) -> Result<()> {
        let mut rng = match self.random_state {
            Some(seed) => Random::seed(seed),
            None => Random::seed(42),
        };

        let mut weights = Array2::<f64>::zeros((self.n_visible, self.n_hidden));
        for i in 0..self.n_visible {
            for j in 0..self.n_hidden {
                let u1: f64 = rng.random_range(0.0..1.0);
                let u2: f64 = rng.random_range(0.0..1.0);
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                weights[(i, j)] = z * 0.01;
            }
        }
        self.weights = weights;
        self.visible_bias = Array1::zeros(self.n_visible);
        self.hidden_bias = Array1::zeros(self.n_hidden);

        Ok(())
    }

    fn sigmoid(&self, x: f64) -> f64 {
        1.0 / (1.0 + (-x).exp())
    }

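    /// Sample binary hidden units given a visible configuration,
    /// using p(h[j] = 1 | v) = sigmoid(hidden_bias[j] + sum_i v[i] * w[i, j]).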
    fn sample_hidden<R>(
        &self,
        visible: &ArrayView1<f64>,
        rng: &mut Random<R>,
    ) -> Result<Array1<f64>>
    where
        R: scirs2_core::random::Rng,
    {
        let mut hidden_probs = Array1::zeros(self.n_hidden);

        for j in 0..self.n_hidden {
            let mut activation = self.hidden_bias[j];
            for i in 0..self.n_visible {
                activation += visible[i] * self.weights[[i, j]];
            }
            hidden_probs[j] = self.sigmoid(activation);
        }

        let mut hidden_sample = Array1::zeros(self.n_hidden);
        for j in 0..self.n_hidden {
            let random_val = rng.random_range(0.0..1.0);
            hidden_sample[j] = if random_val < hidden_probs[j] {
                1.0
            } else {
                0.0
            };
        }

        Ok(hidden_sample)
    }

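    /// Sample binary visible units given a hidden configuration,
    /// using p(v[i] = 1 | h) = sigmoid(visible_bias[i] + sum_j h[j] * w[i, j]).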
    fn sample_visible<R>(
        &self,
        hidden: &ArrayView1<f64>,
        rng: &mut Random<R>,
    ) -> Result<Array1<f64>>
    where
        R: scirs2_core::random::Rng,
    {
        let mut visible_probs = Array1::zeros(self.n_visible);

        for i in 0..self.n_visible {
            let mut activation = self.visible_bias[i];
            for j in 0..self.n_hidden {
                activation += hidden[j] * self.weights[[i, j]];
            }
            visible_probs[i] = self.sigmoid(activation);
        }

        let mut visible_sample = Array1::zeros(self.n_visible);
        for i in 0..self.n_visible {
            let random_val = rng.random_range(0.0..1.0);
            visible_sample[i] = if random_val < visible_probs[i] {
                1.0
            } else {
                0.0
            };
        }

        Ok(visible_sample)
    }

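    /// One pass of mini-batch contrastive divergence (CD-k) over `data`.
    /// Returns the average squared reconstruction error per sample for monitoring.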
    fn contrastive_divergence(&mut self, data: &ArrayView2<f64>) -> Result<f64> {
        let n_samples = data.dim().0;
        let mut rng = match self.random_state {
            Some(seed) => Random::seed(seed),
            None => Random::seed(42),
        };

        let mut total_error = 0.0;

        for batch_start in (0..n_samples).step_by(self.batch_size) {
            let batch_end = std::cmp::min(batch_start + self.batch_size, n_samples);
            let batch_size = batch_end - batch_start;

            if batch_size == 0 {
                continue;
            }

            let mut pos_weights_grad: Array2<f64> = Array2::zeros((self.n_visible, self.n_hidden));
            let mut neg_weights_grad: Array2<f64> = Array2::zeros((self.n_visible, self.n_hidden));
            let mut pos_visible_grad: Array1<f64> = Array1::zeros(self.n_visible);
            let mut neg_visible_grad: Array1<f64> = Array1::zeros(self.n_visible);
            let mut pos_hidden_grad: Array1<f64> = Array1::zeros(self.n_hidden);
            let mut neg_hidden_grad: Array1<f64> = Array1::zeros(self.n_hidden);

            for sample_idx in batch_start..batch_end {
                let visible_data = data.row(sample_idx);

                // Positive phase: hidden probabilities driven by the data.
                let hidden_probs_pos = self.compute_hidden_probs(&visible_data)?;

                // Negative phase: run k steps of Gibbs sampling starting from the data.
                let mut visible_sample = visible_data.to_owned();
                let mut hidden_sample = self.sample_hidden(&visible_sample.view(), &mut rng)?;

                for _ in 0..self.n_gibbs_steps {
                    visible_sample = self.sample_visible(&hidden_sample.view(), &mut rng)?;
                    hidden_sample = self.sample_hidden(&visible_sample.view(), &mut rng)?;
                }

                let hidden_probs_neg = self.compute_hidden_probs(&visible_sample.view())?;

                // Accumulate positive and negative statistics.
                for i in 0..self.n_visible {
                    for j in 0..self.n_hidden {
                        pos_weights_grad[[i, j]] += visible_data[i] * hidden_probs_pos[j];
                        neg_weights_grad[[i, j]] += visible_sample[i] * hidden_probs_neg[j];
                    }
                    pos_visible_grad[i] += visible_data[i];
                    neg_visible_grad[i] += visible_sample[i];
                }

                for j in 0..self.n_hidden {
                    pos_hidden_grad[j] += hidden_probs_pos[j];
                    neg_hidden_grad[j] += hidden_probs_neg[j];
                }

                // Squared reconstruction error for monitoring.
                let error: f64 = visible_data
                    .iter()
                    .zip(visible_sample.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum();
                total_error += error;
            }

            // Update parameters with the batch-averaged gradient.
            let lr = self.learning_rate / batch_size as f64;

            self.weights = &self.weights + &((pos_weights_grad - neg_weights_grad) * lr);
            self.visible_bias = &self.visible_bias + &((pos_visible_grad - neg_visible_grad) * lr);
            self.hidden_bias = &self.hidden_bias + &((pos_hidden_grad - neg_hidden_grad) * lr);
        }

        Ok(total_error / n_samples as f64)
    }

    fn compute_hidden_probs(&self, visible: &ArrayView1<f64>) -> Result<Array1<f64>> {
        let mut hidden_probs = Array1::zeros(self.n_hidden);

        for j in 0..self.n_hidden {
            let mut activation = self.hidden_bias[j];
            for i in 0..self.n_visible {
                activation += visible[i] * self.weights[[i, j]];
            }
            hidden_probs[j] = self.sigmoid(activation);
        }

        Ok(hidden_probs)
    }

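    /// Train the RBM with contrastive divergence for `n_epochs` epochs.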
    pub fn fit(&mut self, data: &ArrayView2<f64>) -> Result<()> {
        self.initialize_weights()?;

        for epoch in 0..self.n_epochs {
            let error = self.contrastive_divergence(data)?;

            if epoch % 10 == 0 {
                println!("RBM Epoch {}: Reconstruction Error = {:.6}", epoch, error);
            }
        }

        Ok(())
    }

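    /// Map inputs to hidden-unit activation probabilities (the learned features).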
    pub fn transform(&self, data: &ArrayView2<f64>) -> Result<Array2<f64>> {
        let n_samples = data.dim().0;
        let mut hidden_features = Array2::zeros((n_samples, self.n_hidden));

        for i in 0..n_samples {
            let hidden_probs = self.compute_hidden_probs(&data.row(i))?;
            hidden_features.row_mut(i).assign(&hidden_probs);
        }

        Ok(hidden_features)
    }

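    /// Reconstruct inputs by sampling the hidden layer and then the visible layer once.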
    pub fn reconstruct(&self, data: &ArrayView2<f64>) -> Result<Array2<f64>> {
        let n_samples = data.dim().0;
        let mut reconstructed = Array2::zeros((n_samples, self.n_visible));
        let mut rng = match self.random_state {
            Some(seed) => Random::seed(seed),
            None => Random::seed(42),
        };

        for i in 0..n_samples {
            let hidden_sample = self.sample_hidden(&data.row(i), &mut rng)?;
            let visible_sample = self.sample_visible(&hidden_sample.view(), &mut rng)?;
            reconstructed.row_mut(i).assign(&visible_sample);
        }

        Ok(reconstructed)
    }
}

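/// Deep Belief Network for semi-supervised classification: a stack of RBMs is
/// pre-trained layer by layer on all samples, then a softmax output layer is
/// fine-tuned on the labeled samples (label `-1` marks an unlabeled sample).
///
/// A minimal usage sketch (mirroring the builder calls exercised in the tests below):
///
/// ```ignore
/// let dbn = DeepBeliefNetwork::new()
///     .hidden_layers(vec![10, 5])?
///     .learning_rate(0.01)?
///     .pretraining_epochs(5)?
///     .finetuning_epochs(5)?
///     .batch_size(16)?
///     .random_state(42);
/// let fitted = dbn.fit(&X.view(), &y.view())?;
/// let predictions = fitted.predict(&X.view())?;
/// ```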
#[derive(Debug, Clone)]
pub struct DeepBeliefNetwork {
    pub hidden_layers: Vec<usize>,
    pub learning_rate: f64,
    pub pretraining_epochs: usize,
    pub finetuning_epochs: usize,
    pub batch_size: usize,
    pub n_gibbs_steps: usize,
    pub random_state: Option<u64>,
}

impl Default for DeepBeliefNetwork {
    fn default() -> Self {
        Self {
            hidden_layers: vec![100, 50],
            learning_rate: 0.01,
            pretraining_epochs: 50,
            finetuning_epochs: 100,
            batch_size: 32,
            n_gibbs_steps: 1,
            random_state: None,
        }
    }
}

impl DeepBeliefNetwork {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn hidden_layers(mut self, hidden_layers: Vec<usize>) -> Result<Self> {
        if hidden_layers.is_empty() {
            return Err(DeepBeliefNetworkError::EmptyHiddenLayers.into());
        }
        for &size in hidden_layers.iter() {
            if size == 0 {
                return Err(DeepBeliefNetworkError::InvalidLayerSize(size).into());
            }
        }
        self.hidden_layers = hidden_layers;
        Ok(self)
    }

    pub fn learning_rate(mut self, learning_rate: f64) -> Result<Self> {
        if learning_rate <= 0.0 {
            return Err(DeepBeliefNetworkError::InvalidLearningRate(learning_rate).into());
        }
        self.learning_rate = learning_rate;
        Ok(self)
    }

    pub fn pretraining_epochs(mut self, pretraining_epochs: usize) -> Result<Self> {
        if pretraining_epochs == 0 {
            return Err(DeepBeliefNetworkError::InvalidEpochs(pretraining_epochs).into());
        }
        self.pretraining_epochs = pretraining_epochs;
        Ok(self)
    }

    pub fn finetuning_epochs(mut self, finetuning_epochs: usize) -> Result<Self> {
        if finetuning_epochs == 0 {
            return Err(DeepBeliefNetworkError::InvalidEpochs(finetuning_epochs).into());
        }
        self.finetuning_epochs = finetuning_epochs;
        Ok(self)
    }

    pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
        if batch_size == 0 {
            return Err(DeepBeliefNetworkError::InvalidBatchSize(batch_size).into());
        }
        self.batch_size = batch_size;
        Ok(self)
    }

    pub fn n_gibbs_steps(mut self, n_gibbs_steps: usize) -> Result<Self> {
        if n_gibbs_steps == 0 {
            return Err(DeepBeliefNetworkError::InvalidGibbsSteps(n_gibbs_steps).into());
        }
        self.n_gibbs_steps = n_gibbs_steps;
        Ok(self)
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

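/// A trained deep belief network: the pre-trained RBM stack plus the
/// fine-tuned softmax classifier and the class labels seen during fitting.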
#[derive(Debug, Clone)]
pub struct FittedDeepBeliefNetwork {
    pub base_model: DeepBeliefNetwork,
    pub rbm_layers: Vec<RestrictedBoltzmannMachine>,
    pub classifier_weights: Array2<f64>,
    pub classifier_bias: Array1<f64>,
    pub classes: Array1<i32>,
    pub n_classes: usize,
}

impl Estimator for DeepBeliefNetwork {
    type Config = DeepBeliefNetwork;
    type Error = DeepBeliefNetworkError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        self
    }
}

impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>> for DeepBeliefNetwork {
    type Fitted = FittedDeepBeliefNetwork;

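    /// Two-phase training: greedy layer-wise RBM pre-training on all samples,
    /// followed by supervised fine-tuning of a softmax output layer on the
    /// labeled subset (labels equal to -1 are treated as unlabeled).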
    fn fit(self, X: &ArrayView2<'_, f64>, y: &ArrayView1<'_, i32>) -> Result<Self::Fitted> {
        let (_n_samples, n_features) = X.dim();

        let labeled_count = y.iter().filter(|&&label| label != -1).count();
        if labeled_count < 2 {
            return Err(DeepBeliefNetworkError::InsufficientLabeledSamples.into());
        }

        // Collect the distinct class labels, sorted for a deterministic ordering.
        let mut unique_classes: Vec<i32> = y
            .iter()
            .cloned()
            .filter(|&label| label != -1)
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();
        unique_classes.sort_unstable();
        let n_classes = unique_classes.len();

        println!(
            "Starting DBN pre-training with {} layers",
            self.hidden_layers.len()
        );

        let mut rbm_layers = Vec::new();
        let mut current_input = X.to_owned();

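        // Greedy layer-wise pre-training: each RBM learns to model the hidden
        // activations produced by the previous layer.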
        for (layer_idx, &layer_size) in self.hidden_layers.iter().enumerate() {
            let input_size = current_input.dim().1;

            println!(
                "Pre-training RBM layer {} ({} -> {})",
                layer_idx + 1,
                input_size,
                layer_size
            );

            let mut rbm = RestrictedBoltzmannMachine::new(input_size, layer_size)?
                .learning_rate(self.learning_rate)?
                .n_epochs(self.pretraining_epochs)?
                .batch_size(self.batch_size)?
                .n_gibbs_steps(self.n_gibbs_steps)?;

            if let Some(seed) = self.random_state {
                rbm = rbm.random_state(seed + layer_idx as u64);
            }

            rbm.fit(&current_input.view())?;

            current_input = rbm.transform(&current_input.view())?;

            rbm_layers.push(rbm);
        }


        println!("Pre-training completed. Starting fine-tuning...");

        let labeled_indices: Vec<usize> = y
            .iter()
            .enumerate()
            .filter(|(_, &label)| label != -1)
            .map(|(i, _)| i)
            .collect();

        if labeled_indices.is_empty() {
            return Err(DeepBeliefNetworkError::InsufficientLabeledSamples.into());
        }

        let labeled_X = Array2::from_shape_vec(
            (labeled_indices.len(), n_features),
            labeled_indices
                .iter()
                .flat_map(|&i| X.row(i).to_vec())
                .collect(),
        )
        .map_err(|e| {
            DeepBeliefNetworkError::MatrixOperationFailed(format!("Array creation failed: {}", e))
        })?;

        let labeled_y: Vec<i32> = labeled_indices.iter().map(|&i| y[i]).collect();

        // Propagate the labeled samples through the stacked RBMs to obtain classifier features.
        let mut features = labeled_X.clone();
        for rbm in rbm_layers.iter() {
            features = rbm.transform(&features.view())?;
        }

        let feature_dim = features.dim().1;
        let mut rng = match self.random_state {
            Some(seed) => Random::seed(seed),
            None => Random::seed(42),
        };

        // Initialize the softmax classifier weights with small Gaussian values (Box-Muller).
        let mut classifier_weights = Array2::<f64>::zeros((feature_dim, n_classes));
        for i in 0..feature_dim {
            for j in 0..n_classes {
                let u1: f64 = rng.random_range(0.0..1.0);
                let u2: f64 = rng.random_range(0.0..1.0);
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                classifier_weights[(i, j)] = z * 0.1;
            }
        }
        let mut classifier_bias = Array1::zeros(n_classes);

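        // Fine-tune the softmax output layer on the RBM features of the labeled
        // samples using per-sample gradient descent on the cross-entropy loss.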
        let lr = self.learning_rate;
        for epoch in 0..self.finetuning_epochs {
            let mut total_loss = 0.0;
            let mut correct_predictions = 0;

            for (sample_idx, &label) in labeled_y.iter().enumerate() {
                let class_idx = unique_classes.iter().position(|&c| c == label).unwrap();
                let feature_vec = features.row(sample_idx);

                let mut logits = Array1::zeros(n_classes);
                for j in 0..n_classes {
                    logits[j] = classifier_bias[j] + feature_vec.dot(&classifier_weights.column(j));
                }

                // Numerically stable softmax: subtract the max logit before exponentiating.
                let max_logit = logits.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let exp_logits: Array1<f64> =
                    logits.iter().map(|&x| (x - max_logit).exp()).collect();
                let sum_exp: f64 = exp_logits.sum();
                let probabilities: Array1<f64> = exp_logits.iter().map(|&x| x / sum_exp).collect();

                let loss = -probabilities[class_idx].ln();
                total_loss += loss;

                let predicted_class = probabilities
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(i, _)| i)
                    .unwrap();

                if predicted_class == class_idx {
                    correct_predictions += 1;
                }

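                // Gradient of the cross-entropy loss w.r.t. the logits: probabilities minus one-hot target.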
                let mut target = Array1::zeros(n_classes);
                target[class_idx] = 1.0;
                let error = &probabilities - &target;

                for j in 0..n_classes {
                    classifier_bias[j] -= lr * error[j];
                    for k in 0..feature_dim {
                        classifier_weights[[k, j]] -= lr * error[j] * feature_vec[k];
                    }
                }
            }

            if epoch % 10 == 0 {
                let accuracy = correct_predictions as f64 / labeled_y.len() as f64;
                println!(
                    "Fine-tuning Epoch {}: Loss = {:.6}, Accuracy = {:.3}",
                    epoch,
                    total_loss / labeled_y.len() as f64,
                    accuracy
                );
            }
        }

        println!("DBN training completed");

        Ok(FittedDeepBeliefNetwork {
            base_model: self.clone(),
            rbm_layers,
            classifier_weights,
            classifier_bias,
            classes: Array1::from_vec(unique_classes),
            n_classes,
        })
    }
}

impl Predict<ArrayView2<'_, f64>, Array1<i32>> for FittedDeepBeliefNetwork {
    fn predict(&self, X: &ArrayView2<'_, f64>) -> Result<Array1<i32>> {
        let probabilities = self.predict_proba(X)?;
        let n_samples = X.dim().0;
        let mut predictions = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let predicted_class_idx = probabilities
                .row(i)
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(i, _)| i)
                .unwrap();
            predictions[i] = self.classes[predicted_class_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, f64>, Array2<f64>> for FittedDeepBeliefNetwork {
    fn predict_proba(&self, X: &ArrayView2<'_, f64>) -> Result<Array2<f64>> {
        let n_samples = X.dim().0;

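        // Propagate the input through the stacked RBMs, then apply the softmax output layer.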
        let mut features = X.to_owned();
        for rbm in self.rbm_layers.iter() {
            features = rbm.transform(&features.view())?;
        }

        let mut probabilities = Array2::zeros((n_samples, self.n_classes));

        for i in 0..n_samples {
            let feature_vec = features.row(i);

            let mut logits = Array1::zeros(self.n_classes);
            for j in 0..self.n_classes {
                logits[j] =
                    self.classifier_bias[j] + feature_vec.dot(&self.classifier_weights.column(j));
            }

            let max_logit = logits.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            let exp_logits: Array1<f64> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
            let sum_exp: f64 = exp_logits.sum();

            for j in 0..self.n_classes {
                probabilities[[i, j]] = exp_logits[j] / sum_exp;
            }
        }

        Ok(probabilities)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::array;

    #[test]
    fn test_rbm_creation() {
        let rbm = RestrictedBoltzmannMachine::new(10, 5).unwrap();
        assert_eq!(rbm.n_visible, 10);
        assert_eq!(rbm.n_hidden, 5);
        assert_eq!(rbm.learning_rate, 0.01);
        assert_eq!(rbm.n_epochs, 10);
    }

    #[test]
    fn test_rbm_invalid_parameters() {
        assert!(RestrictedBoltzmannMachine::new(0, 5).is_err());
        assert!(RestrictedBoltzmannMachine::new(5, 0).is_err());

        let rbm = RestrictedBoltzmannMachine::new(5, 3).unwrap();
        assert!(rbm.clone().learning_rate(0.0).is_err());
        assert!(rbm.clone().learning_rate(-0.1).is_err());
        assert!(rbm.clone().n_epochs(0).is_err());
        assert!(rbm.clone().batch_size(0).is_err());
        assert!(rbm.clone().n_gibbs_steps(0).is_err());
    }

    #[test]
    fn test_rbm_sigmoid() {
        let rbm = RestrictedBoltzmannMachine::new(3, 2).unwrap();
        assert_abs_diff_eq!(rbm.sigmoid(0.0), 0.5, epsilon = 1e-10);
        assert!(rbm.sigmoid(10.0) > 0.9);
        assert!(rbm.sigmoid(-10.0) < 0.1);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_rbm_fit_and_transform() {
        let X = array![
            [1.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
            [1.0, 1.0, 0.0],
            [0.0, 0.0, 1.0]
        ];

        let mut rbm = RestrictedBoltzmannMachine::new(3, 2)
            .unwrap()
            .learning_rate(0.1)
            .unwrap()
            .n_epochs(5)
            .unwrap()
            .batch_size(2)
            .unwrap()
            .random_state(42);

        rbm.fit(&X.view()).unwrap();

        let transformed = rbm.transform(&X.view()).unwrap();
        assert_eq!(transformed.dim(), (4, 2));

        for value in transformed.iter() {
            assert!(*value >= 0.0 && *value <= 1.0);
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_rbm_reconstruct() {
        let X = array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]];

        let mut rbm = RestrictedBoltzmannMachine::new(3, 2)
            .unwrap()
            .learning_rate(0.1)
            .unwrap()
            .n_epochs(3)
            .unwrap()
            .random_state(42);

        rbm.fit(&X.view()).unwrap();

        let reconstructed = rbm.reconstruct(&X.view()).unwrap();
        assert_eq!(reconstructed.dim(), (2, 3));

        for value in reconstructed.iter() {
            assert!(*value == 0.0 || *value == 1.0);
        }
    }

    #[test]
    fn test_dbn_creation() {
        let dbn = DeepBeliefNetwork::new()
            .hidden_layers(vec![10, 5])
            .unwrap()
            .learning_rate(0.01)
            .unwrap()
            .pretraining_epochs(5)
            .unwrap()
            .finetuning_epochs(5)
            .unwrap()
            .batch_size(16)
            .unwrap()
            .random_state(42);

        assert_eq!(dbn.hidden_layers, vec![10, 5]);
        assert_eq!(dbn.learning_rate, 0.01);
        assert_eq!(dbn.pretraining_epochs, 5);
        assert_eq!(dbn.finetuning_epochs, 5);
        assert_eq!(dbn.batch_size, 16);
        assert_eq!(dbn.random_state, Some(42));
    }

    #[test]
    fn test_dbn_invalid_parameters() {
        assert!(DeepBeliefNetwork::new().hidden_layers(vec![]).is_err());
        assert!(DeepBeliefNetwork::new().hidden_layers(vec![0, 5]).is_err());
        assert!(DeepBeliefNetwork::new().learning_rate(0.0).is_err());
        assert!(DeepBeliefNetwork::new().pretraining_epochs(0).is_err());
        assert!(DeepBeliefNetwork::new().finetuning_epochs(0).is_err());
        assert!(DeepBeliefNetwork::new().batch_size(0).is_err());
        assert!(DeepBeliefNetwork::new().n_gibbs_steps(0).is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_dbn_fit_predict() {
        let X = array![
            [1.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 1.0],
            [1.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 1.0, 0.0]
        ];
        let y = array![0, 1, 0, 1, -1, -1]; // -1 marks unlabeled samples

        let dbn = DeepBeliefNetwork::new()
            .hidden_layers(vec![3, 2])
            .unwrap()
            .learning_rate(0.1)
            .unwrap()
            .pretraining_epochs(3)
            .unwrap()
            .finetuning_epochs(3)
            .unwrap()
            .batch_size(2)
            .unwrap()
            .random_state(42);

        let fitted = dbn.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 6);

        for &pred in predictions.iter() {
            assert!(pred == 0 || pred == 1);
        }

        let probabilities = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probabilities.dim(), (6, 2));

        for i in 0..6 {
            let sum: f64 = probabilities.row(i).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
        }

        for value in probabilities.iter() {
            assert!(*value >= 0.0 && *value <= 1.0);
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_dbn_insufficient_labeled_samples() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![-1, -1]; // all samples are unlabeled

        let dbn = DeepBeliefNetwork::new().hidden_layers(vec![2]).unwrap();

        let result = dbn.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_rbm_hidden_probs_computation() {
        let mut rbm = RestrictedBoltzmannMachine::new(3, 2)
            .unwrap()
            .random_state(42);
        rbm.initialize_weights().unwrap();

        let visible = array![1.0, 0.0, 1.0];
        let hidden_probs = rbm.compute_hidden_probs(&visible.view()).unwrap();

        assert_eq!(hidden_probs.len(), 2);
        for prob in hidden_probs.iter() {
            assert!(*prob >= 0.0 && *prob <= 1.0);
        }
    }
}