sklears_semi_supervised/deep_learning/
deep_belief_networks.rs

1//! Deep Belief Networks for semi-supervised learning
2//!
3//! This module implements Deep Belief Networks (DBNs) which are generative models
4//! consisting of multiple layers of Restricted Boltzmann Machines (RBMs).
5//! DBNs can be used for semi-supervised learning by pre-training on unlabeled data
6//! and fine-tuning with labeled data.
7
8use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
9use scirs2_core::random::Random;
10use sklears_core::error::{Result, SklearsError};
11use sklears_core::traits::{Estimator, Fit, Predict, PredictProba};
12use thiserror::Error;
13
/// Errors raised while configuring or training a Deep Belief Network.
#[derive(Error, Debug)]
pub enum DeepBeliefNetworkError {
    /// A visible or hidden layer was configured with zero units.
    #[error("Invalid layer size: {0}")]
    InvalidLayerSize(usize),
    /// The learning rate was not strictly positive.
    #[error("Invalid learning rate: {0}")]
    InvalidLearningRate(f64),
    /// An epoch count of zero was requested.
    #[error("Invalid number of epochs: {0}")]
    InvalidEpochs(usize),
    /// A mini-batch size of zero was requested.
    #[error("Invalid batch size: {0}")]
    InvalidBatchSize(usize),
    /// Zero Gibbs sampling steps were requested for CD-k.
    #[error("Invalid number of gibbs steps: {0}")]
    InvalidGibbsSteps(usize),
    /// The hidden-layer list for the DBN was empty.
    #[error("Empty hidden layers")]
    EmptyHiddenLayers,
    /// Fewer than two labeled samples were available for fine-tuning.
    #[error("Insufficient labeled samples")]
    InsufficientLabeledSamples,
    /// A matrix construction or reshape operation failed.
    #[error("Matrix operation failed: {0}")]
    MatrixOperationFailed(String),
    /// An RBM layer failed to train.
    #[error("RBM training failed: {0}")]
    RBMTrainingFailed(String),
}
35
/// Let `?` convert DBN-specific errors into the crate-wide error type.
impl From<DeepBeliefNetworkError> for SklearsError {
    fn from(err: DeepBeliefNetworkError) -> Self {
        // All DBN errors surface as fit-time failures.
        SklearsError::FitError(err.to_string())
    }
}
41
/// Restricted Boltzmann Machine (RBM) component
///
/// An RBM is a two-layer neural network with visible and hidden units
/// that can learn probability distributions over its inputs.
#[derive(Debug, Clone)]
pub struct RestrictedBoltzmannMachine {
    /// Number of visible (input) units.
    pub n_visible: usize,
    /// Number of hidden (latent) units.
    pub n_hidden: usize,
    /// Step size for contrastive-divergence parameter updates.
    pub learning_rate: f64,
    /// Number of passes over the training data.
    pub n_epochs: usize,
    /// Mini-batch size used during training.
    pub batch_size: usize,
    /// Gibbs sampling steps per contrastive-divergence update (CD-k).
    pub n_gibbs_steps: usize,
    /// Optional seed for reproducible initialization and sampling.
    pub random_state: Option<u64>,
    // Learned parameters: weight matrix has shape (n_visible, n_hidden).
    weights: Array2<f64>,
    visible_bias: Array1<f64>,
    hidden_bias: Array1<f64>,
}
66
67impl RestrictedBoltzmannMachine {
68    pub fn new(n_visible: usize, n_hidden: usize) -> Result<Self> {
69        if n_visible == 0 {
70            return Err(DeepBeliefNetworkError::InvalidLayerSize(n_visible).into());
71        }
72        if n_hidden == 0 {
73            return Err(DeepBeliefNetworkError::InvalidLayerSize(n_hidden).into());
74        }
75
76        Ok(Self {
77            n_visible,
78            n_hidden,
79            learning_rate: 0.01,
80            n_epochs: 10,
81            batch_size: 32,
82            n_gibbs_steps: 1,
83            random_state: None,
84            weights: Array2::zeros((n_visible, n_hidden)),
85            visible_bias: Array1::zeros(n_visible),
86            hidden_bias: Array1::zeros(n_hidden),
87        })
88    }
89
90    pub fn learning_rate(mut self, learning_rate: f64) -> Result<Self> {
91        if learning_rate <= 0.0 {
92            return Err(DeepBeliefNetworkError::InvalidLearningRate(learning_rate).into());
93        }
94        self.learning_rate = learning_rate;
95        Ok(self)
96    }
97
98    pub fn n_epochs(mut self, n_epochs: usize) -> Result<Self> {
99        if n_epochs == 0 {
100            return Err(DeepBeliefNetworkError::InvalidEpochs(n_epochs).into());
101        }
102        self.n_epochs = n_epochs;
103        Ok(self)
104    }
105
106    pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
107        if batch_size == 0 {
108            return Err(DeepBeliefNetworkError::InvalidBatchSize(batch_size).into());
109        }
110        self.batch_size = batch_size;
111        Ok(self)
112    }
113
114    pub fn n_gibbs_steps(mut self, n_gibbs_steps: usize) -> Result<Self> {
115        if n_gibbs_steps == 0 {
116            return Err(DeepBeliefNetworkError::InvalidGibbsSteps(n_gibbs_steps).into());
117        }
118        self.n_gibbs_steps = n_gibbs_steps;
119        Ok(self)
120    }
121
122    pub fn random_state(mut self, random_state: u64) -> Self {
123        self.random_state = Some(random_state);
124        self
125    }
126
127    fn initialize_weights(&mut self) -> Result<()> {
128        let mut rng = match self.random_state {
129            Some(seed) => Random::seed(seed),
130            None => Random::seed(42),
131        };
132
133        // Initialize weights with small random values manually
134        let mut weights = Array2::<f64>::zeros((self.n_visible, self.n_hidden));
135        for i in 0..self.n_visible {
136            for j in 0..self.n_hidden {
137                // Generate normal distributed random number (mean=0.0, std=0.01)
138                let u1: f64 = rng.random_range(0.0..1.0);
139                let u2: f64 = rng.random_range(0.0..1.0);
140                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
141                weights[(i, j)] = z * 0.01;
142            }
143        }
144        self.weights = weights;
145        self.visible_bias = Array1::zeros(self.n_visible);
146        self.hidden_bias = Array1::zeros(self.n_hidden);
147
148        Ok(())
149    }
150
151    fn sigmoid(&self, x: f64) -> f64 {
152        1.0 / (1.0 + (-x).exp())
153    }
154
155    fn sample_hidden<R>(
156        &self,
157        visible: &ArrayView1<f64>,
158        rng: &mut Random<R>,
159    ) -> Result<Array1<f64>>
160    where
161        R: scirs2_core::random::Rng,
162    {
163        let mut hidden_probs = Array1::zeros(self.n_hidden);
164
165        for j in 0..self.n_hidden {
166            let mut activation = self.hidden_bias[j];
167            for i in 0..self.n_visible {
168                activation += visible[i] * self.weights[[i, j]];
169            }
170            hidden_probs[j] = self.sigmoid(activation);
171        }
172
173        // Sample from Bernoulli distribution
174        let mut hidden_sample = Array1::zeros(self.n_hidden);
175        for j in 0..self.n_hidden {
176            let random_val = rng.random_range(0.0..1.0);
177            hidden_sample[j] = if random_val < hidden_probs[j] {
178                1.0
179            } else {
180                0.0
181            };
182        }
183
184        Ok(hidden_sample)
185    }
186
187    fn sample_visible<R>(
188        &self,
189        hidden: &ArrayView1<f64>,
190        rng: &mut Random<R>,
191    ) -> Result<Array1<f64>>
192    where
193        R: scirs2_core::random::Rng,
194    {
195        let mut visible_probs = Array1::zeros(self.n_visible);
196
197        for i in 0..self.n_visible {
198            let mut activation = self.visible_bias[i];
199            for j in 0..self.n_hidden {
200                activation += hidden[j] * self.weights[[i, j]];
201            }
202            visible_probs[i] = self.sigmoid(activation);
203        }
204
205        // Sample from Bernoulli distribution
206        let mut visible_sample = Array1::zeros(self.n_visible);
207        for i in 0..self.n_visible {
208            let random_val = rng.random_range(0.0..1.0);
209            visible_sample[i] = if random_val < visible_probs[i] {
210                1.0
211            } else {
212                0.0
213            };
214        }
215
216        Ok(visible_sample)
217    }
218
219    fn contrastive_divergence(&mut self, data: &ArrayView2<f64>) -> Result<f64> {
220        let n_samples = data.dim().0;
221        let mut rng = match self.random_state {
222            Some(seed) => Random::seed(seed),
223            None => Random::seed(42),
224        };
225
226        let mut total_error = 0.0;
227
228        // Process data in batches
229        for batch_start in (0..n_samples).step_by(self.batch_size) {
230            let batch_end = std::cmp::min(batch_start + self.batch_size, n_samples);
231            let batch_size = batch_end - batch_start;
232
233            if batch_size == 0 {
234                continue;
235            }
236
237            let mut pos_weights_grad: Array2<f64> = Array2::zeros((self.n_visible, self.n_hidden));
238            let mut neg_weights_grad: Array2<f64> = Array2::zeros((self.n_visible, self.n_hidden));
239            let mut pos_visible_grad: Array1<f64> = Array1::zeros(self.n_visible);
240            let mut neg_visible_grad: Array1<f64> = Array1::zeros(self.n_visible);
241            let mut pos_hidden_grad: Array1<f64> = Array1::zeros(self.n_hidden);
242            let mut neg_hidden_grad: Array1<f64> = Array1::zeros(self.n_hidden);
243
244            for sample_idx in batch_start..batch_end {
245                let visible_data = data.row(sample_idx);
246
247                // Positive phase: clamp visible units to data
248                let hidden_probs_pos = self.compute_hidden_probs(&visible_data)?;
249
250                // Negative phase: Gibbs sampling
251                let mut visible_sample = visible_data.to_owned();
252                let mut hidden_sample = self.sample_hidden(&visible_sample.view(), &mut rng)?;
253
254                for _ in 0..self.n_gibbs_steps {
255                    visible_sample = self.sample_visible(&hidden_sample.view(), &mut rng)?;
256                    hidden_sample = self.sample_hidden(&visible_sample.view(), &mut rng)?;
257                }
258
259                let hidden_probs_neg = self.compute_hidden_probs(&visible_sample.view())?;
260
261                // Accumulate gradients
262                for i in 0..self.n_visible {
263                    for j in 0..self.n_hidden {
264                        pos_weights_grad[[i, j]] += visible_data[i] * hidden_probs_pos[j];
265                        neg_weights_grad[[i, j]] += visible_sample[i] * hidden_probs_neg[j];
266                    }
267                    pos_visible_grad[i] += visible_data[i];
268                    neg_visible_grad[i] += visible_sample[i];
269                }
270
271                for j in 0..self.n_hidden {
272                    pos_hidden_grad[j] += hidden_probs_pos[j];
273                    neg_hidden_grad[j] += hidden_probs_neg[j];
274                }
275
276                // Compute reconstruction error
277                let error: f64 = visible_data
278                    .iter()
279                    .zip(visible_sample.iter())
280                    .map(|(a, b)| (a - b).powi(2))
281                    .sum();
282                total_error += error;
283            }
284
285            // Update parameters
286            let lr = self.learning_rate / batch_size as f64;
287
288            self.weights = &self.weights + &((pos_weights_grad - neg_weights_grad) * lr);
289            self.visible_bias = &self.visible_bias + &((pos_visible_grad - neg_visible_grad) * lr);
290            self.hidden_bias = &self.hidden_bias + &((pos_hidden_grad - neg_hidden_grad) * lr);
291        }
292
293        Ok(total_error / n_samples as f64)
294    }
295
296    fn compute_hidden_probs(&self, visible: &ArrayView1<f64>) -> Result<Array1<f64>> {
297        let mut hidden_probs = Array1::zeros(self.n_hidden);
298
299        for j in 0..self.n_hidden {
300            let mut activation = self.hidden_bias[j];
301            for i in 0..self.n_visible {
302                activation += visible[i] * self.weights[[i, j]];
303            }
304            hidden_probs[j] = self.sigmoid(activation);
305        }
306
307        Ok(hidden_probs)
308    }
309
310    pub fn fit(&mut self, data: &ArrayView2<f64>) -> Result<()> {
311        self.initialize_weights()?;
312
313        for epoch in 0..self.n_epochs {
314            let error = self.contrastive_divergence(data)?;
315
316            if epoch % 10 == 0 {
317                println!("RBM Epoch {}: Reconstruction Error = {:.6}", epoch, error);
318            }
319        }
320
321        Ok(())
322    }
323
324    pub fn transform(&self, data: &ArrayView2<f64>) -> Result<Array2<f64>> {
325        let n_samples = data.dim().0;
326        let mut hidden_features = Array2::zeros((n_samples, self.n_hidden));
327
328        for i in 0..n_samples {
329            let hidden_probs = self.compute_hidden_probs(&data.row(i))?;
330            hidden_features.row_mut(i).assign(&hidden_probs);
331        }
332
333        Ok(hidden_features)
334    }
335
336    pub fn reconstruct(&self, data: &ArrayView2<f64>) -> Result<Array2<f64>> {
337        let n_samples = data.dim().0;
338        let mut reconstructed = Array2::zeros((n_samples, self.n_visible));
339        let mut rng = match self.random_state {
340            Some(seed) => Random::seed(seed),
341            None => Random::seed(42),
342        };
343
344        for i in 0..n_samples {
345            let hidden_sample = self.sample_hidden(&data.row(i), &mut rng)?;
346            let visible_sample = self.sample_visible(&hidden_sample.view(), &mut rng)?;
347            reconstructed.row_mut(i).assign(&visible_sample);
348        }
349
350        Ok(reconstructed)
351    }
352}
353
/// Deep Belief Network for semi-supervised learning
///
/// A DBN consists of multiple RBM layers stacked on top of each other.
/// It uses unsupervised pre-training followed by supervised fine-tuning.
#[derive(Debug, Clone)]
pub struct DeepBeliefNetwork {
    /// Sizes of the stacked hidden layers, input-side layer first.
    pub hidden_layers: Vec<usize>,
    /// Learning rate used for both RBM pre-training and fine-tuning.
    pub learning_rate: f64,
    /// Epochs of unsupervised RBM pre-training per layer.
    pub pretraining_epochs: usize,
    /// Epochs of supervised classifier fine-tuning.
    pub finetuning_epochs: usize,
    /// Mini-batch size for RBM pre-training.
    pub batch_size: usize,
    /// Gibbs sampling steps per contrastive-divergence update (CD-k).
    pub n_gibbs_steps: usize,
    /// Optional seed for reproducible training.
    pub random_state: Option<u64>,
}
375
376impl Default for DeepBeliefNetwork {
377    fn default() -> Self {
378        Self {
379            hidden_layers: vec![100, 50],
380            learning_rate: 0.01,
381            pretraining_epochs: 50,
382            finetuning_epochs: 100,
383            batch_size: 32,
384            n_gibbs_steps: 1,
385            random_state: None,
386        }
387    }
388}
389
390impl DeepBeliefNetwork {
391    pub fn new() -> Self {
392        Self::default()
393    }
394
395    pub fn hidden_layers(mut self, hidden_layers: Vec<usize>) -> Result<Self> {
396        if hidden_layers.is_empty() {
397            return Err(DeepBeliefNetworkError::EmptyHiddenLayers.into());
398        }
399        for &size in hidden_layers.iter() {
400            if size == 0 {
401                return Err(DeepBeliefNetworkError::InvalidLayerSize(size).into());
402            }
403        }
404        self.hidden_layers = hidden_layers;
405        Ok(self)
406    }
407
408    pub fn learning_rate(mut self, learning_rate: f64) -> Result<Self> {
409        if learning_rate <= 0.0 {
410            return Err(DeepBeliefNetworkError::InvalidLearningRate(learning_rate).into());
411        }
412        self.learning_rate = learning_rate;
413        Ok(self)
414    }
415
416    pub fn pretraining_epochs(mut self, pretraining_epochs: usize) -> Result<Self> {
417        if pretraining_epochs == 0 {
418            return Err(DeepBeliefNetworkError::InvalidEpochs(pretraining_epochs).into());
419        }
420        self.pretraining_epochs = pretraining_epochs;
421        Ok(self)
422    }
423
424    pub fn finetuning_epochs(mut self, finetuning_epochs: usize) -> Result<Self> {
425        if finetuning_epochs == 0 {
426            return Err(DeepBeliefNetworkError::InvalidEpochs(finetuning_epochs).into());
427        }
428        self.finetuning_epochs = finetuning_epochs;
429        Ok(self)
430    }
431
432    pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
433        if batch_size == 0 {
434            return Err(DeepBeliefNetworkError::InvalidBatchSize(batch_size).into());
435        }
436        self.batch_size = batch_size;
437        Ok(self)
438    }
439
440    pub fn n_gibbs_steps(mut self, n_gibbs_steps: usize) -> Result<Self> {
441        if n_gibbs_steps == 0 {
442            return Err(DeepBeliefNetworkError::InvalidGibbsSteps(n_gibbs_steps).into());
443        }
444        self.n_gibbs_steps = n_gibbs_steps;
445        Ok(self)
446    }
447
448    pub fn random_state(mut self, random_state: u64) -> Self {
449        self.random_state = Some(random_state);
450        self
451    }
452}
453
/// Fitted Deep Belief Network model
#[derive(Debug, Clone)]
pub struct FittedDeepBeliefNetwork {
    /// Configuration the model was trained with.
    pub base_model: DeepBeliefNetwork,
    /// Pre-trained RBM layers, ordered from the input layer inward.
    pub rbm_layers: Vec<RestrictedBoltzmannMachine>,
    /// Softmax classifier weights, shape (feature_dim, n_classes).
    pub classifier_weights: Array2<f64>,
    /// Softmax classifier bias, one entry per class.
    pub classifier_bias: Array1<f64>,
    /// Class labels, in the column order used by `predict_proba`.
    pub classes: Array1<i32>,
    /// Number of distinct classes.
    pub n_classes: usize,
}
470
/// The builder-style struct doubles as its own configuration object.
impl Estimator for DeepBeliefNetwork {
    type Config = DeepBeliefNetwork;
    type Error = DeepBeliefNetworkError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        self
    }
}
480
481impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>> for DeepBeliefNetwork {
482    type Fitted = FittedDeepBeliefNetwork;
483
484    fn fit(self, X: &ArrayView2<'_, f64>, y: &ArrayView1<'_, i32>) -> Result<Self::Fitted> {
485        let (n_samples, n_features) = X.dim();
486
487        // Check for sufficient labeled samples
488        let labeled_count = y.iter().filter(|&&label| label != -1).count();
489        if labeled_count < 2 {
490            return Err(DeepBeliefNetworkError::InsufficientLabeledSamples.into());
491        }
492
493        // Get unique classes
494        let unique_classes: Vec<i32> = y
495            .iter()
496            .cloned()
497            .filter(|&label| label != -1)
498            .collect::<std::collections::HashSet<_>>()
499            .into_iter()
500            .collect();
501        let n_classes = unique_classes.len();
502
503        println!(
504            "Starting DBN pre-training with {} layers",
505            self.hidden_layers.len()
506        );
507
508        // Phase 1: Unsupervised pre-training of RBM layers
509        let mut rbm_layers = Vec::new();
510        let mut current_input = X.to_owned();
511
512        for (layer_idx, &layer_size) in self.hidden_layers.iter().enumerate() {
513            let input_size = current_input.dim().1;
514
515            println!(
516                "Pre-training RBM layer {} ({} -> {})",
517                layer_idx + 1,
518                input_size,
519                layer_size
520            );
521
522            let mut rbm = RestrictedBoltzmannMachine::new(input_size, layer_size)?
523                .learning_rate(self.learning_rate)?
524                .n_epochs(self.pretraining_epochs)?
525                .batch_size(self.batch_size)?
526                .n_gibbs_steps(self.n_gibbs_steps)?;
527
528            if let Some(seed) = self.random_state {
529                rbm = rbm.random_state(seed + layer_idx as u64);
530            }
531
532            rbm.fit(&current_input.view())?;
533
534            // Transform current input for next layer
535            current_input = rbm.transform(&current_input.view())?;
536
537            rbm_layers.push(rbm);
538        }
539
540        println!("Pre-training completed. Starting fine-tuning...");
541
542        // Phase 2: Supervised fine-tuning with labeled data
543        let labeled_indices: Vec<usize> = y
544            .iter()
545            .enumerate()
546            .filter(|(_, &label)| label != -1)
547            .map(|(i, _)| i)
548            .collect();
549
550        if labeled_indices.is_empty() {
551            return Err(DeepBeliefNetworkError::InsufficientLabeledSamples.into());
552        }
553
554        // Extract labeled data
555        let labeled_X = Array2::from_shape_vec(
556            (labeled_indices.len(), n_features),
557            labeled_indices
558                .iter()
559                .flat_map(|&i| X.row(i).to_vec())
560                .collect(),
561        )
562        .map_err(|e| {
563            DeepBeliefNetworkError::MatrixOperationFailed(format!("Array creation failed: {}", e))
564        })?;
565
566        let labeled_y: Vec<i32> = labeled_indices.iter().map(|&i| y[i]).collect();
567
568        // Forward pass through all RBM layers to get final features
569        let mut features = labeled_X.clone();
570        for rbm in rbm_layers.iter() {
571            features = rbm.transform(&features.view())?;
572        }
573
574        // Initialize classifier weights
575        let feature_dim = features.dim().1;
576        let mut rng = match self.random_state {
577            Some(seed) => Random::seed(seed),
578            None => Random::seed(42),
579        };
580
581        // Initialize classifier weights manually
582        let mut classifier_weights = Array2::<f64>::zeros((feature_dim, n_classes));
583        for i in 0..feature_dim {
584            for j in 0..n_classes {
585                // Generate normal distributed random number (mean=0.0, std=0.1)
586                let u1: f64 = rng.random_range(0.0..1.0);
587                let u2: f64 = rng.random_range(0.0..1.0);
588                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
589                classifier_weights[(i, j)] = z * 0.1;
590            }
591        }
592        let mut classifier_bias = Array1::zeros(n_classes);
593
594        // Simple gradient descent for classifier fine-tuning
595        let lr = self.learning_rate;
596        for epoch in 0..self.finetuning_epochs {
597            let mut total_loss = 0.0;
598            let mut correct_predictions = 0;
599
600            for (sample_idx, &label) in labeled_y.iter().enumerate() {
601                let class_idx = unique_classes.iter().position(|&c| c == label).unwrap();
602                let feature_vec = features.row(sample_idx);
603
604                // Forward pass
605                let mut logits = Array1::zeros(n_classes);
606                for j in 0..n_classes {
607                    logits[j] = classifier_bias[j] + feature_vec.dot(&classifier_weights.column(j));
608                }
609
610                // Softmax
611                let max_logit = logits.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
612                let exp_logits: Array1<f64> =
613                    logits.iter().map(|&x| (x - max_logit).exp()).collect();
614                let sum_exp: f64 = exp_logits.sum();
615                let probabilities: Array1<f64> = exp_logits.iter().map(|&x| x / sum_exp).collect();
616
617                // Cross-entropy loss
618                let loss = -probabilities[class_idx].ln();
619                total_loss += loss;
620
621                // Check prediction
622                let predicted_class = probabilities
623                    .iter()
624                    .enumerate()
625                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
626                    .map(|(i, _)| i)
627                    .unwrap();
628
629                if predicted_class == class_idx {
630                    correct_predictions += 1;
631                }
632
633                // Backward pass
634                let mut target = Array1::zeros(n_classes);
635                target[class_idx] = 1.0;
636                let error = &probabilities - &target;
637
638                // Update weights and bias
639                for j in 0..n_classes {
640                    classifier_bias[j] -= lr * error[j];
641                    for k in 0..feature_dim {
642                        classifier_weights[[k, j]] -= lr * error[j] * feature_vec[k];
643                    }
644                }
645            }
646
647            if epoch % 10 == 0 {
648                let accuracy = correct_predictions as f64 / labeled_y.len() as f64;
649                println!(
650                    "Fine-tuning Epoch {}: Loss = {:.6}, Accuracy = {:.3}",
651                    epoch,
652                    total_loss / labeled_y.len() as f64,
653                    accuracy
654                );
655            }
656        }
657
658        println!("DBN training completed");
659
660        Ok(FittedDeepBeliefNetwork {
661            base_model: self.clone(),
662            rbm_layers,
663            classifier_weights,
664            classifier_bias,
665            classes: Array1::from_vec(unique_classes),
666            n_classes,
667        })
668    }
669}
670
671impl Predict<ArrayView2<'_, f64>, Array1<i32>> for FittedDeepBeliefNetwork {
672    fn predict(&self, X: &ArrayView2<'_, f64>) -> Result<Array1<i32>> {
673        let probabilities = self.predict_proba(X)?;
674        let n_samples = X.dim().0;
675        let mut predictions = Array1::zeros(n_samples);
676
677        for i in 0..n_samples {
678            let predicted_class_idx = probabilities
679                .row(i)
680                .iter()
681                .enumerate()
682                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
683                .map(|(i, _)| i)
684                .unwrap();
685            predictions[i] = self.classes[predicted_class_idx];
686        }
687
688        Ok(predictions)
689    }
690}
691
692impl PredictProba<ArrayView2<'_, f64>, Array2<f64>> for FittedDeepBeliefNetwork {
693    fn predict_proba(&self, X: &ArrayView2<'_, f64>) -> Result<Array2<f64>> {
694        let n_samples = X.dim().0;
695
696        // Forward pass through all RBM layers
697        let mut features = X.to_owned();
698        for rbm in self.rbm_layers.iter() {
699            features = rbm.transform(&features.view())?;
700        }
701
702        let mut probabilities = Array2::zeros((n_samples, self.n_classes));
703
704        for i in 0..n_samples {
705            let feature_vec = features.row(i);
706
707            // Compute logits
708            let mut logits = Array1::zeros(self.n_classes);
709            for j in 0..self.n_classes {
710                logits[j] =
711                    self.classifier_bias[j] + feature_vec.dot(&self.classifier_weights.column(j));
712            }
713
714            // Softmax
715            let max_logit = logits.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
716            let exp_logits: Array1<f64> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
717            let sum_exp: f64 = exp_logits.sum();
718
719            for j in 0..self.n_classes {
720                probabilities[[i, j]] = exp_logits[j] / sum_exp;
721            }
722        }
723
724        Ok(probabilities)
725    }
726}
727
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::array;

    /// Constructor stores dimensions and applies documented defaults.
    #[test]
    fn test_rbm_creation() {
        let rbm = RestrictedBoltzmannMachine::new(10, 5).unwrap();
        assert_eq!(rbm.n_visible, 10);
        assert_eq!(rbm.n_hidden, 5);
        assert_eq!(rbm.learning_rate, 0.01);
        assert_eq!(rbm.n_epochs, 10);
    }

    /// Every RBM builder rejects its degenerate (zero / non-positive) value.
    #[test]
    fn test_rbm_invalid_parameters() {
        assert!(RestrictedBoltzmannMachine::new(0, 5).is_err());
        assert!(RestrictedBoltzmannMachine::new(5, 0).is_err());

        let rbm = RestrictedBoltzmannMachine::new(5, 3).unwrap();
        assert!(rbm.clone().learning_rate(0.0).is_err());
        assert!(rbm.clone().learning_rate(-0.1).is_err());
        assert!(rbm.clone().n_epochs(0).is_err());
        assert!(rbm.clone().batch_size(0).is_err());
        assert!(rbm.clone().n_gibbs_steps(0).is_err());
    }

    /// Sigmoid is 0.5 at zero and saturates toward 0/1 in the tails.
    #[test]
    fn test_rbm_sigmoid() {
        let rbm = RestrictedBoltzmannMachine::new(3, 2).unwrap();
        assert_abs_diff_eq!(rbm.sigmoid(0.0), 0.5, epsilon = 1e-10);
        assert!(rbm.sigmoid(10.0) > 0.9);
        assert!(rbm.sigmoid(-10.0) < 0.1);
    }

    /// After fitting, `transform` yields one probability per hidden unit.
    #[test]
    #[allow(non_snake_case)]
    fn test_rbm_fit_and_transform() {
        let X = array![
            [1.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
            [1.0, 1.0, 0.0],
            [0.0, 0.0, 1.0]
        ];

        let mut rbm = RestrictedBoltzmannMachine::new(3, 2)
            .unwrap()
            .learning_rate(0.1)
            .unwrap()
            .n_epochs(5)
            .unwrap()
            .batch_size(2)
            .unwrap()
            .random_state(42);

        rbm.fit(&X.view()).unwrap();

        let transformed = rbm.transform(&X.view()).unwrap();
        assert_eq!(transformed.dim(), (4, 2));

        // Check that transformed values are probabilities (between 0 and 1)
        for value in transformed.iter() {
            assert!(*value >= 0.0 && *value <= 1.0);
        }
    }

    /// `reconstruct` returns binary samples of the original shape.
    #[test]
    #[allow(non_snake_case)]
    fn test_rbm_reconstruct() {
        let X = array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]];

        let mut rbm = RestrictedBoltzmannMachine::new(3, 2)
            .unwrap()
            .learning_rate(0.1)
            .unwrap()
            .n_epochs(3)
            .unwrap()
            .random_state(42);

        rbm.fit(&X.view()).unwrap();

        let reconstructed = rbm.reconstruct(&X.view()).unwrap();
        assert_eq!(reconstructed.dim(), (2, 3));

        // Check that reconstructed values are binary (0 or 1)
        for value in reconstructed.iter() {
            assert!(*value == 0.0 || *value == 1.0);
        }
    }

    /// DBN builder chain stores every configured value.
    #[test]
    fn test_dbn_creation() {
        let dbn = DeepBeliefNetwork::new()
            .hidden_layers(vec![10, 5])
            .unwrap()
            .learning_rate(0.01)
            .unwrap()
            .pretraining_epochs(5)
            .unwrap()
            .finetuning_epochs(5)
            .unwrap()
            .batch_size(16)
            .unwrap()
            .random_state(42);

        assert_eq!(dbn.hidden_layers, vec![10, 5]);
        assert_eq!(dbn.learning_rate, 0.01);
        assert_eq!(dbn.pretraining_epochs, 5);
        assert_eq!(dbn.finetuning_epochs, 5);
        assert_eq!(dbn.batch_size, 16);
        assert_eq!(dbn.random_state, Some(42));
    }

    /// Every DBN builder rejects its degenerate value.
    #[test]
    fn test_dbn_invalid_parameters() {
        assert!(DeepBeliefNetwork::new().hidden_layers(vec![]).is_err());
        assert!(DeepBeliefNetwork::new().hidden_layers(vec![0, 5]).is_err());
        assert!(DeepBeliefNetwork::new().learning_rate(0.0).is_err());
        assert!(DeepBeliefNetwork::new().pretraining_epochs(0).is_err());
        assert!(DeepBeliefNetwork::new().finetuning_epochs(0).is_err());
        assert!(DeepBeliefNetwork::new().batch_size(0).is_err());
        assert!(DeepBeliefNetwork::new().n_gibbs_steps(0).is_err());
    }

    /// End-to-end: semi-supervised fit (label -1 = unlabeled), then
    /// predictions are valid labels and probabilities form a distribution.
    #[test]
    #[allow(non_snake_case)]
    fn test_dbn_fit_predict() {
        let X = array![
            [1.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 1.0],
            [1.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 1.0, 0.0]
        ];
        let y = array![0, 1, 0, 1, -1, -1]; // Last two are unlabeled

        let dbn = DeepBeliefNetwork::new()
            .hidden_layers(vec![3, 2])
            .unwrap()
            .learning_rate(0.1)
            .unwrap()
            .pretraining_epochs(3)
            .unwrap()
            .finetuning_epochs(3)
            .unwrap()
            .batch_size(2)
            .unwrap()
            .random_state(42);

        let fitted = dbn.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 6);

        // Check that predictions are valid class labels
        for &pred in predictions.iter() {
            assert!(pred == 0 || pred == 1);
        }

        let probabilities = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probabilities.dim(), (6, 2));

        // Check that probabilities sum to 1
        for i in 0..6 {
            let sum: f64 = probabilities.row(i).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
        }

        // Check that probabilities are between 0 and 1
        for value in probabilities.iter() {
            assert!(*value >= 0.0 && *value <= 1.0);
        }
    }

    /// Fitting with no labeled samples at all must fail.
    #[test]
    #[allow(non_snake_case)]
    fn test_dbn_insufficient_labeled_samples() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![-1, -1]; // All unlabeled

        let dbn = DeepBeliefNetwork::new().hidden_layers(vec![2]).unwrap();

        let result = dbn.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    /// Hidden activation probabilities are valid probabilities.
    #[test]
    fn test_rbm_hidden_probs_computation() {
        let mut rbm = RestrictedBoltzmannMachine::new(3, 2)
            .unwrap()
            .random_state(42);
        rbm.initialize_weights().unwrap();

        let visible = array![1.0, 0.0, 1.0];
        let hidden_probs = rbm.compute_hidden_probs(&visible.view()).unwrap();

        assert_eq!(hidden_probs.len(), 2);
        for prob in hidden_probs.iter() {
            assert!(*prob >= 0.0 && *prob <= 1.0);
        }
    }
}