sklears_semi_supervised/deep_learning/energy_based_models.rs

//! Energy-based models for semi-supervised learning
//!
//! This module implements energy-based models (EBMs) that learn data distributions
//! by associating low energy values with data samples and high energy values with
//! unlikely samples. For semi-supervised learning, EBMs can incorporate both
//! labeled and unlabeled data through energy minimization and contrastive learning.
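//!
//! # Example
//!
//! A minimal usage sketch of the builder API defined in this module (marked
//! `ignore`; the `array!` macro is assumed from `scirs2_core`, as in the
//! tests below, and the data is purely illustrative):
//!
//! ```ignore
//! use scirs2_core::array;
//! use sklears_core::traits::{Fit, Predict};
//!
//! let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
//! let y = array![0, 1, -1]; // -1 marks an unlabeled sample
//!
//! let model = EnergyBasedModel::new()
//!     .n_classes(2)
//!     .input_dim(2)
//!     .epochs(20);
//! let fitted = model.fit(&x.view(), &y.view()).unwrap();
//! let predictions = fitted.predict(&x.view()).unwrap();
//! ```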

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::random::Random;
use sklears_core::error::SklearsError;
use sklears_core::traits::{Fit, Predict, PredictProba};

/// Energy-based model for semi-supervised learning
///
/// This implements an energy-based model that learns to assign low energy
/// to data samples and high energy to unlikely samples. The model combines
/// energy minimization with classification for semi-supervised learning.
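///
/// During fitting, samples labeled `-1` are treated as unlabeled and
/// contribute only to the contrastive energy term, while labeled samples
/// additionally contribute a cross-entropy classification loss.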
#[derive(Debug, Clone)]
pub struct EnergyBasedModel {
    /// Hidden layer dimensions for energy network
    hidden_dims: Vec<usize>,
    /// Number of classes for classification
    n_classes: usize,
    /// Input dimension
    input_dim: usize,
    /// Learning rate for gradient descent
    learning_rate: f64,
    /// Number of training epochs
    epochs: usize,
    /// Regularization parameter
    regularization: f64,
    /// Number of negative samples for contrastive learning
    n_negative_samples: usize,
    /// Temperature for Boltzmann distribution
    temperature: f64,
    /// Weight for classification loss vs energy loss
    classification_weight: f64,
    /// Contrastive learning margin
    margin: f64,
    /// Energy network parameters
    energy_weights: Vec<Array2<f64>>,
    energy_biases: Vec<Array1<f64>>,
    /// Classification head parameters
    class_weights: Array2<f64>,
    class_bias: Array1<f64>,
    /// Whether the model has been fitted
    fitted: bool,
}

impl Default for EnergyBasedModel {
    fn default() -> Self {
        Self::new()
    }
}

impl EnergyBasedModel {
    /// Create a new energy-based model with default hyperparameters
    pub fn new() -> Self {
        Self {
            hidden_dims: vec![64, 32, 16],
            n_classes: 2,
            input_dim: 10,
            learning_rate: 0.001,
            epochs: 100,
            regularization: 0.01,
            n_negative_samples: 10,
            temperature: 1.0,
            classification_weight: 1.0,
            margin: 1.0,
            energy_weights: Vec::new(),
            energy_biases: Vec::new(),
            class_weights: Array2::zeros((0, 0)),
            class_bias: Array1::zeros(0),
            fitted: false,
        }
    }

    /// Set the hidden layer dimensions
    pub fn hidden_dims(mut self, dims: Vec<usize>) -> Self {
        self.hidden_dims = dims;
        self
    }

    /// Set the number of classes
    pub fn n_classes(mut self, n_classes: usize) -> Self {
        self.n_classes = n_classes;
        self
    }

    /// Set the input dimension
    pub fn input_dim(mut self, input_dim: usize) -> Self {
        self.input_dim = input_dim;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the number of epochs
    pub fn epochs(mut self, epochs: usize) -> Self {
        self.epochs = epochs;
        self
    }

    /// Set the regularization parameter
    pub fn regularization(mut self, reg: f64) -> Self {
        self.regularization = reg;
        self
    }

    /// Set the number of negative samples
    pub fn n_negative_samples(mut self, n_samples: usize) -> Self {
        self.n_negative_samples = n_samples;
        self
    }

    /// Set the temperature parameter
    pub fn temperature(mut self, temp: f64) -> Self {
        self.temperature = temp;
        self
    }

    /// Set the classification weight
    pub fn classification_weight(mut self, weight: f64) -> Self {
        self.classification_weight = weight;
        self
    }

    /// Set the contrastive margin
    pub fn margin(mut self, margin: f64) -> Self {
        self.margin = margin;
        self
    }

    /// Initialize the model parameters
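    ///
    /// Weights are drawn with Xavier/Glorot uniform initialization: each
    /// entry is sampled uniformly from `[-s, s]` with
    /// `s = sqrt(6 / (fan_in + fan_out))`, and all biases start at zero.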
    fn initialize_parameters(&mut self) -> Result<(), SklearsError> {
        let mut layer_dims = vec![self.input_dim];
        layer_dims.extend_from_slice(&self.hidden_dims);
        layer_dims.push(1); // Output a single energy value

        self.energy_weights.clear();
        self.energy_biases.clear();

        let mut rng = Random::default();

        // Initialize energy network weights using Xavier initialization
        for layer in 0..layer_dims.len() - 1 {
            let fan_in = layer_dims[layer];
            let fan_out = layer_dims[layer + 1];
            let scale = (6.0 / (fan_in + fan_out) as f64).sqrt();

            let mut weight = Array2::<f64>::zeros((fan_in, fan_out));
            for i in 0..fan_in {
                for j in 0..fan_out {
                    // Uniformly distributed in [-scale, scale]
                    let u: f64 = rng.random_range(0.0..1.0);
                    weight[(i, j)] = u * (2.0 * scale) - scale;
                }
            }
            let bias = Array1::zeros(fan_out);

            self.energy_weights.push(weight);
            self.energy_biases.push(bias);
        }

        // Initialize the classification head (fed from the last hidden layer)
        let last_hidden_dim = self.hidden_dims.last().unwrap_or(&self.input_dim);
        let class_scale = (6.0 / (last_hidden_dim + self.n_classes) as f64).sqrt();
        let mut class_weights = Array2::<f64>::zeros((*last_hidden_dim, self.n_classes));
        for i in 0..*last_hidden_dim {
            for j in 0..self.n_classes {
                // Uniformly distributed in [-class_scale, class_scale]
                let u: f64 = rng.random_range(0.0..1.0);
                class_weights[(i, j)] = u * (2.0 * class_scale) - class_scale;
            }
        }
        self.class_weights = class_weights;
        self.class_bias = Array1::zeros(self.n_classes);

        Ok(())
    }

    /// Apply the ReLU activation function
    #[allow(dead_code)] // exercised directly in the tests below
    fn relu(&self, x: &Array1<f64>) -> Array1<f64> {
        x.mapv(|v| v.max(0.0))
    }

    /// Apply the leaky ReLU activation function
    fn leaky_relu(&self, x: &Array1<f64>, alpha: f64) -> Array1<f64> {
        x.mapv(|v| if v > 0.0 { v } else { alpha * v })
    }

    /// Apply the softmax activation function
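    ///
    /// This is a temperature-scaled softmax: logits are shifted by their
    /// maximum for numerical stability and divided by `self.temperature`
    /// before exponentiation, so lower temperatures produce sharper
    /// distributions.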
    fn softmax(&self, x: &Array1<f64>) -> Array1<f64> {
        let max_val = x.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_x = x.mapv(|v| ((v - max_val) / self.temperature).exp());
        let sum_exp = exp_x.sum();
        exp_x / sum_exp
    }

    /// Compute energy for a given input
    fn compute_energy(&self, input: &ArrayView1<f64>) -> Result<f64, SklearsError> {
        let mut activation = input.to_owned();

        // Forward pass through the energy network
        for (i, (weight, bias)) in self
            .energy_weights
            .iter()
            .zip(self.energy_biases.iter())
            .enumerate()
        {
            let linear = activation.dot(weight) + bias;

            if i < self.energy_weights.len() - 1 {
                // Leaky ReLU for hidden layers
                activation = self.leaky_relu(&linear, 0.01);
            } else {
                // Linear output for the scalar energy
                return Ok(linear[0]);
            }
        }

        Err(SklearsError::NumericalError(
            "Energy network has no layers; call initialize_parameters first".to_string(),
        ))
    }

    /// Get hidden features from the energy network
    fn get_hidden_features(&self, input: &ArrayView1<f64>) -> Result<Array1<f64>, SklearsError> {
        let mut activation = input.to_owned();

        // Forward pass through the energy network (excluding the final layer)
        for i in 0..self.energy_weights.len() - 1 {
            let weight = &self.energy_weights[i];
            let bias = &self.energy_biases[i];
            let linear = activation.dot(weight) + bias;
            activation = self.leaky_relu(&linear, 0.01);
        }

        Ok(activation)
    }

    /// Compute classification probabilities using hidden features
    fn compute_classification_probs(
        &self,
        input: &ArrayView1<f64>,
    ) -> Result<Array1<f64>, SklearsError> {
        let features = self.get_hidden_features(input)?;
        let logits = features.dot(&self.class_weights) + &self.class_bias;
        Ok(self.softmax(&logits))
    }

    /// Generate negative samples using noise
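    ///
    /// Standard-normal noise is drawn via the Box-Muller transform and each
    /// column is then rescaled to the mean and standard deviation of the
    /// corresponding feature in the positive samples.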
    fn generate_negative_samples(
        &self,
        positive_samples: &ArrayView2<f64>,
    ) -> Result<Array2<f64>, SklearsError> {
        let input_dim = positive_samples.ncols();

        // Draw standard-normal noise via the Box-Muller transform
        let mut rng = Random::default();
        let mut negative_samples = Array2::<f64>::zeros((self.n_negative_samples, input_dim));
        for i in 0..self.n_negative_samples {
            for j in 0..input_dim {
                let u1: f64 = rng.random_range(0.0..1.0).max(f64::MIN_POSITIVE); // guard against ln(0)
                let u2: f64 = rng.random_range(0.0..1.0);
                negative_samples[(i, j)] =
                    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
            }
        }

        // Rescale each feature to match the positive-sample statistics
        for j in 0..input_dim {
            let column = positive_samples.column(j);
            let mean = column.mean().unwrap_or(0.0);
            let std = column.std(0.0);

            for i in 0..self.n_negative_samples {
                negative_samples[[i, j]] = negative_samples[[i, j]] * std + mean;
            }
        }

        Ok(negative_samples)
    }

    /// Compute contrastive loss between positive and negative samples
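    ///
    /// The loss averages the raw energies of positive samples (pushing them
    /// down) and the hinge terms of negative samples (pushing them up to at
    /// least `margin`):
    /// `L = (sum(E_pos) + sum(max(0, margin - E_neg))) / (n_pos + n_neg)`.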
    fn contrastive_loss(
        &self,
        positive_energies: &Array1<f64>,
        negative_energies: &Array1<f64>,
    ) -> f64 {
        let mut loss = 0.0;

        // Positive samples should have low energy
        for &energy in positive_energies.iter() {
            loss += energy;
        }

        // Negative samples should have high energy (contrastive hinge)
        for &energy in negative_energies.iter() {
            loss += (self.margin - energy).max(0.0);
        }

        loss / (positive_energies.len() + negative_energies.len()) as f64
    }

    /// Compute the Boltzmann probability from an energy value
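    ///
    /// Under a Boltzmann distribution, a sample with energy `E` at
    /// temperature `T` has unnormalized probability `exp(-E / T)`;
    /// normalizing would require dividing by the partition function `Z`
    /// (see [`Self::log_partition_function`]).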
    pub fn energy_to_probability(&self, energy: f64) -> f64 {
        (-energy / self.temperature).exp()
    }

    /// Sample from the model using Langevin dynamics
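    ///
    /// Each step applies the (unadjusted) Langevin update
    /// `x <- x - step_size * grad E(x) + sqrt(2 * step_size) * eps`,
    /// where `eps` is standard Gaussian noise and the gradient is estimated
    /// with central finite differences.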
    pub fn langevin_sample(
        &self,
        initial_sample: &ArrayView1<f64>,
        n_steps: usize,
        step_size: f64,
    ) -> Result<Array1<f64>, SklearsError> {
        if !self.fitted {
            return Err(SklearsError::NotFitted {
                operation: "sampling".to_string(),
            });
        }

        let mut sample = initial_sample.to_owned();
        let mut rng = Random::default();
        let noise_std = (2.0 * step_size).sqrt();

        for _ in 0..n_steps {
            // Estimate the energy gradient with central finite differences
            let mut gradient = Array1::zeros(sample.len());
            let epsilon = 1e-6;

            for i in 0..sample.len() {
                sample[i] += epsilon;
                let energy_plus = self.compute_energy(&sample.view())?;
                sample[i] -= 2.0 * epsilon;
                let energy_minus = self.compute_energy(&sample.view())?;
                sample[i] += epsilon; // Reset

                gradient[i] = (energy_plus - energy_minus) / (2.0 * epsilon);
            }

            // Langevin update with Gaussian noise (Box-Muller transform)
            let mut noise = Array1::zeros(sample.len());
            for i in 0..sample.len() {
                let u1: f64 = rng.random_range(0.0..1.0).max(f64::MIN_POSITIVE); // guard against ln(0)
                let u2: f64 = rng.random_range(0.0..1.0);
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                noise[i] = z * noise_std;
            }
            sample = &sample - step_size * &gradient + &noise;
        }

        Ok(sample)
    }

    /// Approximate the log partition function
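    ///
    /// Returns a crude Monte Carlo estimate over `n_samples` standard-normal
    /// proposals, `log Z ≈ logsumexp_i(-E(x_i) / T) - log(n_samples)`,
    /// accumulated with a running log-sum-exp for numerical stability. The
    /// proposal density is not corrected for, so this is only a rough
    /// approximation.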
    pub fn log_partition_function(&self, n_samples: usize) -> Result<f64, SklearsError> {
        if !self.fitted {
            return Err(SklearsError::NotFitted {
                operation: "computing partition function".to_string(),
            });
        }

        let mut rng = Random::default();
        let mut log_sum = f64::NEG_INFINITY;

        // Monte Carlo approximation with standard-normal proposals
        for _ in 0..n_samples {
            let mut sample = Array1::zeros(self.input_dim);
            for i in 0..self.input_dim {
                // Standard normal via the Box-Muller transform
                let u1: f64 = rng.random_range(0.0..1.0).max(f64::MIN_POSITIVE); // guard against ln(0)
                let u2: f64 = rng.random_range(0.0..1.0);
                sample[i] = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
            }
            let energy = self.compute_energy(&sample.view())?;
            let log_prob = -energy / self.temperature;

            // Running log-sum-exp for numerical stability
            if log_prob > log_sum {
                log_sum = log_prob + (1.0f64 + (log_sum - log_prob).exp()).ln();
            } else {
                log_sum += (1.0f64 + (log_prob - log_sum).exp()).ln();
            }
        }

        Ok(log_sum - (n_samples as f64).ln())
    }
}

impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>> for EnergyBasedModel {
    type Fitted = EnergyBasedModel;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<f64>, y: &ArrayView1<i32>) -> Result<Self::Fitted, SklearsError> {
        if X.nrows() != y.len() {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        let mut model = self;
        model.input_dim = X.ncols();
        model.initialize_parameters()?;

        let n_samples = X.nrows();
        let labeled_mask: Vec<bool> = y.iter().map(|&label| label != -1).collect();
        let n_labeled = labeled_mask.iter().filter(|&&labeled| labeled).count();

        if n_labeled == 0 {
            return Err(SklearsError::InvalidInput(
                "At least one labeled sample required".to_string(),
            ));
        }

        // Training loop
        for epoch in 0..model.epochs {
            let mut total_class_loss = 0.0;

            // Generate negative samples for contrastive learning
            let negative_samples = model.generate_negative_samples(X)?;

            // Compute positive energies
            let mut positive_energies = Array1::zeros(n_samples);
            for i in 0..n_samples {
                positive_energies[i] = model.compute_energy(&X.row(i))?;
            }

            // Compute negative energies
            let mut negative_energies = Array1::zeros(model.n_negative_samples);
            for i in 0..model.n_negative_samples {
                negative_energies[i] = model.compute_energy(&negative_samples.row(i))?;
            }

            // Contrastive energy loss
            let energy_loss = model.contrastive_loss(&positive_energies, &negative_energies);

            // Classification loss for labeled samples
            for i in 0..n_samples {
                if labeled_mask[i] {
                    let sample = X.row(i);
                    let label = y[i];

                    let class_probs = model.compute_classification_probs(&sample)?;
                    let target_class = label as usize;

                    if target_class >= model.n_classes {
                        return Err(SklearsError::InvalidInput(format!(
                            "Label {} exceeds number of classes {}",
                            target_class, model.n_classes
                        )));
                    }

                    // Cross-entropy loss
                    let class_loss = -class_probs[target_class].ln();
                    total_class_loss += model.classification_weight * class_loss;
                }
            }

            // Note: this is a simplified training scheme. The losses above are
            // monitored but not backpropagated; the only parameter update is
            // the weight-decay shrinkage below.
            if epoch % 10 == 0 {
                println!(
                    "Epoch {}: Energy loss = {:.4}, Class loss = {:.4}",
                    epoch,
                    energy_loss,
                    total_class_loss / n_labeled as f64
                );
            }

            // Apply L2 regularization as weight decay
            for weight in &mut model.energy_weights {
                weight.mapv_inplace(|w| w * (1.0 - model.learning_rate * model.regularization));
            }
        }

        model.fitted = true;
        Ok(model)
    }
}

impl Predict<ArrayView2<'_, f64>, Array1<i32>> for EnergyBasedModel {
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<f64>) -> Result<Array1<i32>, SklearsError> {
        if !self.fitted {
            return Err(SklearsError::NotFitted {
                operation: "making predictions".to_string(),
            });
        }

        let mut predictions = Array1::zeros(X.nrows());

        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
            let class_probs = self.compute_classification_probs(&sample)?;
            let predicted_class = class_probs
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            predictions[i] = predicted_class as i32;
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, f64>, Array2<f64>> for EnergyBasedModel {
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<f64>) -> Result<Array2<f64>, SklearsError> {
        if !self.fitted {
            return Err(SklearsError::NotFitted {
                operation: "making predictions".to_string(),
            });
        }

        let mut probabilities = Array2::zeros((X.nrows(), self.n_classes));

        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
            let class_probs = self.compute_classification_probs(&sample)?;
            probabilities.row_mut(i).assign(&class_probs);
        }

        Ok(probabilities)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_energy_based_model_creation() {
        let model = EnergyBasedModel::new()
            .hidden_dims(vec![32, 16, 8])
            .n_classes(3)
            .input_dim(5)
            .learning_rate(0.01)
            .epochs(50)
            .regularization(0.1)
            .n_negative_samples(5)
            .temperature(0.8)
            .classification_weight(2.0)
            .margin(2.0);

        assert_eq!(model.hidden_dims, vec![32, 16, 8]);
        assert_eq!(model.n_classes, 3);
        assert_eq!(model.input_dim, 5);
        assert_eq!(model.learning_rate, 0.01);
        assert_eq!(model.epochs, 50);
        assert_eq!(model.regularization, 0.1);
        assert_eq!(model.n_negative_samples, 5);
        assert_eq!(model.temperature, 0.8);
        assert_eq!(model.classification_weight, 2.0);
        assert_eq!(model.margin, 2.0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_energy_based_model_fit_predict() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0]
        ];
        let y = array![0, 1, -1, 0]; // -1 indicates unlabeled

        let model = EnergyBasedModel::new()
            .n_classes(2)
            .input_dim(3)
            .epochs(10)
            .learning_rate(0.01)
            .n_negative_samples(3);

        let fitted_model = model.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted_model.predict(&X.view()).unwrap();
        let probabilities = fitted_model.predict_proba(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert_eq!(probabilities.dim(), (4, 2));

        // Check that probabilities sum to 1
        for i in 0..4 {
            let sum: f64 = probabilities.row(i).sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_energy_based_model_insufficient_labeled_samples() {
        let X = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        let y = array![-1, -1]; // All unlabeled

        let model = EnergyBasedModel::new().n_classes(2).input_dim(3).epochs(10);

        let result = model.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_energy_based_model_invalid_dimensions() {
        let X = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        let y = array![0]; // Mismatched dimensions

        let model = EnergyBasedModel::new();
        let result = model.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_energy_computation() {
        let mut model = EnergyBasedModel::new().input_dim(3).hidden_dims(vec![4, 2]);
        model.initialize_parameters().unwrap();

        let input = array![1.0, 2.0, 3.0];
        let energy = model.compute_energy(&input.view()).unwrap();

        assert!(energy.is_finite());
    }

    #[test]
    fn test_energy_to_probability() {
        let model = EnergyBasedModel::new().temperature(1.0);
        let energy = 2.0;
        let prob = model.energy_to_probability(energy);

        assert!(prob > 0.0);
        assert!(prob <= 1.0);
        assert!((prob - (-2.0f64).exp()).abs() < 1e-10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_negative_sample_generation() {
        let X = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]];

        let model = EnergyBasedModel::new().input_dim(3).n_negative_samples(5);

        let negative_samples = model.generate_negative_samples(&X.view()).unwrap();

        assert_eq!(negative_samples.dim(), (5, 3));
    }

    #[test]
    fn test_contrastive_loss_computation() {
        let model = EnergyBasedModel::new().margin(1.0);

        let positive_energies = array![0.5, 1.0, 0.8];
        let negative_energies = array![2.0, 1.5, 2.5];

        let loss = model.contrastive_loss(&positive_energies, &negative_energies);

        assert!(loss >= 0.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_softmax_computation() {
        let model = EnergyBasedModel::new().temperature(1.0);
        let logits = array![1.0, 2.0, 3.0];
        let probs = model.softmax(&logits);

        let sum: f64 = probs.sum();
        assert!((sum - 1.0).abs() < 1e-10);

        // Probabilities should be in ascending order of logits
        assert!(probs[0] < probs[1]);
        assert!(probs[1] < probs[2]);
    }

    #[test]
    fn test_relu_activation() {
        let model = EnergyBasedModel::new();
        let input = array![-1.0, 0.0, 1.0, 2.0];
        let output = model.relu(&input);

        assert_eq!(output, array![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_leaky_relu_activation() {
        let model = EnergyBasedModel::new();
        let input = array![-1.0, 0.0, 1.0, 2.0];
        let output = model.leaky_relu(&input, 0.1);

        assert_eq!(output, array![-0.1, 0.0, 1.0, 2.0]);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_energy_based_model_not_fitted_error() {
        let model = EnergyBasedModel::new();
        let X = array![[1.0, 2.0, 3.0]];

        let result = model.predict(&X.view());
        assert!(result.is_err());

        let result = model.predict_proba(&X.view());
        assert!(result.is_err());

        let sample = array![1.0, 2.0, 3.0];
        let result = model.langevin_sample(&sample.view(), 10, 0.01);
        assert!(result.is_err());

        let result = model.log_partition_function(100);
        assert!(result.is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_energy_based_model_with_different_parameters() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0]
        ];
        let y = array![0, 1, 2];

        let model = EnergyBasedModel::new()
            .hidden_dims(vec![8, 4])
            .n_classes(3)
            .input_dim(4)
            .learning_rate(0.1)
            .epochs(3)
            .regularization(0.01)
            .n_negative_samples(2)
            .temperature(0.5)
            .classification_weight(0.5)
            .margin(0.5);

        let fitted_model = model.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted_model.predict(&X.view()).unwrap();
        let probabilities = fitted_model.predict_proba(&X.view()).unwrap();

        assert_eq!(predictions.len(), 3);
        assert_eq!(probabilities.dim(), (3, 3));
    }

    #[test]
    fn test_hidden_features_extraction() {
        let mut model = EnergyBasedModel::new().input_dim(3).hidden_dims(vec![4, 2]);
        model.initialize_parameters().unwrap();

        let input = array![1.0, 2.0, 3.0];
        let features = model.get_hidden_features(&input.view()).unwrap();

        assert_eq!(features.len(), 2); // Last hidden layer dimension
        assert!(features.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_classification_probabilities() {
        let mut model = EnergyBasedModel::new()
            .input_dim(3)
            .n_classes(2)
            .hidden_dims(vec![4]);
        model.initialize_parameters().unwrap();

        let input = array![1.0, 2.0, 3.0];
        let probs = model.compute_classification_probs(&input.view()).unwrap();

        assert_eq!(probs.len(), 2);
        let sum: f64 = probs.sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_langevin_sampling() {
        let X = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        let y = array![0, 1];

        let model = EnergyBasedModel::new().n_classes(2).input_dim(3).epochs(5);

        let fitted_model = model.fit(&X.view(), &y.view()).unwrap();
        let initial_sample = array![1.0, 2.0, 3.0];
        let sample = fitted_model
            .langevin_sample(&initial_sample.view(), 5, 0.01)
            .unwrap();

        assert_eq!(sample.len(), 3);
        assert!(sample.iter().all(|&x| x.is_finite()));
    }
}