1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
9use scirs2_core::random::Random;
10use sklears_core::error::SklearsError;
11use sklears_core::traits::{Fit, Predict, PredictProba};
12
/// Energy-based model for semi-supervised classification.
///
/// A feed-forward network maps an input to a scalar energy (trained
/// contrastively so data samples get low energy), and a linear
/// classification head on the last hidden layer produces class
/// probabilities.
#[derive(Debug, Clone)]
pub struct EnergyBasedModel {
    /// Sizes of the hidden layers of the energy network.
    hidden_dims: Vec<usize>,
    /// Number of output classes of the classification head.
    n_classes: usize,
    /// Input feature dimensionality (overwritten from `X` during `fit`).
    input_dim: usize,
    /// Step size used in the weight-decay update during training.
    learning_rate: f64,
    /// Number of training epochs.
    epochs: usize,
    /// Weight-decay (regularization) coefficient.
    regularization: f64,
    /// Number of negative samples drawn per epoch for the contrastive loss.
    n_negative_samples: usize,
    /// Softmax / Boltzmann temperature.
    temperature: f64,
    /// Weight applied to the per-sample classification loss.
    classification_weight: f64,
    /// Hinge margin of the contrastive energy loss.
    margin: f64,
    /// Per-layer weight matrices of the energy network (empty until init).
    energy_weights: Vec<Array2<f64>>,
    /// Per-layer bias vectors of the energy network (empty until init).
    energy_biases: Vec<Array1<f64>>,
    /// Classification-head weights (last hidden dim x n_classes).
    class_weights: Array2<f64>,
    /// Classification-head bias (length n_classes).
    class_bias: Array1<f64>,
    /// Set to `true` once `fit` completes successfully.
    fitted: bool,
}
49
impl Default for EnergyBasedModel {
    /// Equivalent to [`EnergyBasedModel::new`].
    fn default() -> Self {
        Self::new()
    }
}
55
56impl EnergyBasedModel {
57 pub fn new() -> Self {
59 Self {
60 hidden_dims: vec![64, 32, 16],
61 n_classes: 2,
62 input_dim: 10,
63 learning_rate: 0.001,
64 epochs: 100,
65 regularization: 0.01,
66 n_negative_samples: 10,
67 temperature: 1.0,
68 classification_weight: 1.0,
69 margin: 1.0,
70 energy_weights: Vec::new(),
71 energy_biases: Vec::new(),
72 class_weights: Array2::zeros((0, 0)),
73 class_bias: Array1::zeros(0),
74 fitted: false,
75 }
76 }
77
    /// Set the hidden layer sizes of the energy network (builder-style).
    pub fn hidden_dims(mut self, dims: Vec<usize>) -> Self {
        self.hidden_dims = dims;
        self
    }

    /// Set the number of output classes (builder-style).
    pub fn n_classes(mut self, n_classes: usize) -> Self {
        self.n_classes = n_classes;
        self
    }

    /// Set the expected input dimensionality (builder-style).
    /// Note: `fit` overwrites this with `X.ncols()`.
    pub fn input_dim(mut self, input_dim: usize) -> Self {
        self.input_dim = input_dim;
        self
    }

    /// Set the learning rate used in the weight-decay update (builder-style).
    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the number of training epochs (builder-style).
    pub fn epochs(mut self, epochs: usize) -> Self {
        self.epochs = epochs;
        self
    }

    /// Set the weight-decay coefficient (builder-style).
    pub fn regularization(mut self, reg: f64) -> Self {
        self.regularization = reg;
        self
    }

    /// Set the number of negative samples drawn per epoch (builder-style).
    pub fn n_negative_samples(mut self, n_samples: usize) -> Self {
        self.n_negative_samples = n_samples;
        self
    }

    /// Set the softmax/Boltzmann temperature (builder-style).
    pub fn temperature(mut self, temp: f64) -> Self {
        self.temperature = temp;
        self
    }

    /// Set the weight of the classification loss (builder-style).
    pub fn classification_weight(mut self, weight: f64) -> Self {
        self.classification_weight = weight;
        self
    }

    /// Set the contrastive hinge margin (builder-style).
    pub fn margin(mut self, margin: f64) -> Self {
        self.margin = margin;
        self
    }
137
138 fn initialize_parameters(&mut self) -> Result<(), SklearsError> {
140 let mut layer_dims = vec![self.input_dim];
141 layer_dims.extend_from_slice(&self.hidden_dims);
142 layer_dims.push(1); self.energy_weights.clear();
145 self.energy_biases.clear();
146
147 for i in 0..layer_dims.len() - 1 {
149 let fan_in = layer_dims[i];
150 let fan_out = layer_dims[i + 1];
151 let scale = (6.0 / (fan_in + fan_out) as f64).sqrt();
152
153 let mut rng = Random::default();
155 let mut weight = Array2::<f64>::zeros((fan_in, fan_out));
156 for i in 0..fan_in {
157 for j in 0..fan_out {
158 let u: f64 = rng.random_range(0.0..1.0);
160 weight[(i, j)] = u * (2.0 * scale) - scale;
161 }
162 }
163 let bias = Array1::zeros(fan_out);
164
165 self.energy_weights.push(weight);
166 self.energy_biases.push(bias);
167 }
168
169 let last_hidden_dim = self.hidden_dims.last().unwrap_or(&self.input_dim);
171 let class_scale = (6.0 / (last_hidden_dim + self.n_classes) as f64).sqrt();
172 let mut rng = Random::default();
174 let mut class_weights = Array2::<f64>::zeros((*last_hidden_dim, self.n_classes));
175 for i in 0..*last_hidden_dim {
176 for j in 0..self.n_classes {
177 let u: f64 = rng.random_range(0.0..1.0);
179 class_weights[(i, j)] = u * (2.0 * class_scale) - class_scale;
180 }
181 }
182 self.class_weights = class_weights;
183 self.class_bias = Array1::zeros(self.n_classes);
184
185 Ok(())
186 }
187
188 fn relu(&self, x: &Array1<f64>) -> Array1<f64> {
190 x.mapv(|v| v.max(0.0))
191 }
192
193 fn leaky_relu(&self, x: &Array1<f64>, alpha: f64) -> Array1<f64> {
195 x.mapv(|v| if v > 0.0 { v } else { alpha * v })
196 }
197
198 fn softmax(&self, x: &Array1<f64>) -> Array1<f64> {
200 let max_val = x.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
201 let exp_x = x.mapv(|v| ((v - max_val) / self.temperature).exp());
202 let sum_exp = exp_x.sum();
203 exp_x / sum_exp
204 }
205
206 fn compute_energy(&self, input: &ArrayView1<f64>) -> Result<f64, SklearsError> {
208 let mut activation = input.to_owned();
209 let mut hidden_features = Vec::new();
210
211 for (i, (weight, bias)) in self
213 .energy_weights
214 .iter()
215 .zip(self.energy_biases.iter())
216 .enumerate()
217 {
218 let linear = activation.dot(weight) + bias;
219
220 if i < self.energy_weights.len() - 1 {
221 activation = self.leaky_relu(&linear, 0.01);
223 hidden_features.push(activation.clone());
224 } else {
225 return Ok(linear[0]);
227 }
228 }
229
230 Err(SklearsError::NumericalError(
231 "Energy computation failed".to_string(),
232 ))
233 }
234
235 fn get_hidden_features(&self, input: &ArrayView1<f64>) -> Result<Array1<f64>, SklearsError> {
237 let mut activation = input.to_owned();
238
239 for i in 0..self.energy_weights.len() - 1 {
241 let weight = &self.energy_weights[i];
242 let bias = &self.energy_biases[i];
243 let linear = activation.dot(weight) + bias;
244 activation = self.leaky_relu(&linear, 0.01);
245 }
246
247 Ok(activation)
248 }
249
250 fn compute_classification_probs(
252 &self,
253 input: &ArrayView1<f64>,
254 ) -> Result<Array1<f64>, SklearsError> {
255 let features = self.get_hidden_features(input)?;
256 let logits = features.dot(&self.class_weights) + &self.class_bias;
257 Ok(self.softmax(&logits))
258 }
259
260 fn generate_negative_samples(
262 &self,
263 positive_samples: &ArrayView2<f64>,
264 ) -> Result<Array2<f64>, SklearsError> {
265 let n_samples = positive_samples.nrows();
266 let input_dim = positive_samples.ncols();
267
268 let mut rng = Random::default();
270 let mut negative_samples = Array2::<f64>::zeros((self.n_negative_samples, input_dim));
271 for i in 0..self.n_negative_samples {
272 for j in 0..input_dim {
273 let u1: f64 = rng.random_range(0.0..1.0);
275 let u2: f64 = rng.random_range(0.0..1.0);
276 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
277 negative_samples[(i, j)] = z; }
279 }
280
281 for j in 0..input_dim {
283 let column = positive_samples.column(j);
284 let mean = column.mean().unwrap_or(0.0);
285 let std = column.std(0.0);
286
287 for i in 0..self.n_negative_samples {
288 negative_samples[[i, j]] = negative_samples[[i, j]] * std + mean;
289 }
290 }
291
292 Ok(negative_samples)
293 }
294
295 fn contrastive_loss(
297 &self,
298 positive_energies: &Array1<f64>,
299 negative_energies: &Array1<f64>,
300 ) -> f64 {
301 let mut loss = 0.0;
302
303 for &energy in positive_energies.iter() {
305 loss += energy;
306 }
307
308 for &energy in negative_energies.iter() {
310 loss += (self.margin - energy).max(0.0);
311 }
312
313 loss / (positive_energies.len() + negative_energies.len()) as f64
314 }
315
316 pub fn energy_to_probability(&self, energy: f64) -> f64 {
318 (-energy / self.temperature).exp()
319 }
320
321 pub fn langevin_sample(
323 &self,
324 initial_sample: &ArrayView1<f64>,
325 n_steps: usize,
326 step_size: f64,
327 ) -> Result<Array1<f64>, SklearsError> {
328 if !self.fitted {
329 return Err(SklearsError::NotFitted {
330 operation: "sampling".to_string(),
331 });
332 }
333
334 let mut sample = initial_sample.to_owned();
335
336 for _ in 0..n_steps {
337 let mut gradient = Array1::zeros(sample.len());
339 let epsilon = 1e-6;
340
341 for i in 0..sample.len() {
342 sample[i] += epsilon;
344 let energy_plus = self.compute_energy(&sample.view())?;
345 sample[i] -= 2.0 * epsilon;
346 let energy_minus = self.compute_energy(&sample.view())?;
347 sample[i] += epsilon; gradient[i] = (energy_plus - energy_minus) / (2.0 * epsilon);
350 }
351
352 let mut rng = Random::default();
354 let noise_std = (2.0 * step_size).sqrt();
355 let mut noise = Array1::zeros(sample.len());
356 for i in 0..sample.len() {
357 noise[i] = rng.random_range(-3.0..3.0) * noise_std / 3.0; }
360 sample = &sample - step_size * &gradient + &noise;
361 }
362
363 Ok(sample)
364 }
365
366 pub fn log_partition_function(&self, n_samples: usize) -> Result<f64, SklearsError> {
368 if !self.fitted {
369 return Err(SklearsError::NotFitted {
370 operation: "computing partition function".to_string(),
371 });
372 }
373
374 let mut log_sum = f64::NEG_INFINITY;
375
376 for _ in 0..n_samples {
378 let mut rng = Random::default();
379 let mut sample = Array1::zeros(self.input_dim);
380 for i in 0..self.input_dim {
381 sample[i] = rng.random_range(-3.0..3.0) / 3.0; }
384 let energy = self.compute_energy(&sample.view())?;
385 let log_prob = -energy / self.temperature;
386
387 if log_prob > log_sum {
389 log_sum = log_prob + (1.0f64 + (log_sum - log_prob).exp()).ln();
390 } else {
391 log_sum = log_sum + (1.0f64 + (log_prob - log_sum).exp()).ln();
392 }
393 }
394
395 Ok(log_sum - (n_samples as f64).ln())
396 }
397}
398
399impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>> for EnergyBasedModel {
400 type Fitted = EnergyBasedModel;
401
402 fn fit(self, X: &ArrayView2<f64>, y: &ArrayView1<i32>) -> Result<Self::Fitted, SklearsError> {
403 if X.nrows() != y.len() {
404 return Err(SklearsError::InvalidInput(
405 "Number of samples in X and y must match".to_string(),
406 ));
407 }
408
409 let mut model = self;
410 model.input_dim = X.ncols();
411 model.initialize_parameters()?;
412
413 let n_samples = X.nrows();
414 let labeled_mask: Vec<bool> = y.iter().map(|&label| label != -1).collect();
415 let n_labeled = labeled_mask.iter().filter(|&&labeled| labeled).count();
416
417 if n_labeled == 0 {
418 return Err(SklearsError::InvalidInput(
419 "At least one labeled sample required".to_string(),
420 ));
421 }
422
423 for epoch in 0..model.epochs {
425 let mut total_energy_loss = 0.0;
426 let mut total_class_loss = 0.0;
427 let mut n_processed = 0;
428
429 let negative_samples = model.generate_negative_samples(X)?;
431
432 let mut positive_energies = Array1::zeros(n_samples);
434 for i in 0..n_samples {
435 positive_energies[i] = model.compute_energy(&X.row(i))?;
436 }
437
438 let mut negative_energies = Array1::zeros(model.n_negative_samples);
440 for i in 0..model.n_negative_samples {
441 negative_energies[i] = model.compute_energy(&negative_samples.row(i))?;
442 }
443
444 let energy_loss = model.contrastive_loss(&positive_energies, &negative_energies);
446 total_energy_loss += energy_loss;
447
448 for i in 0..n_samples {
450 if labeled_mask[i] {
451 let sample = X.row(i);
452 let label = y[i];
453
454 let class_probs = model.compute_classification_probs(&sample)?;
455 let target_class = label as usize;
456
457 if target_class >= model.n_classes {
458 return Err(SklearsError::InvalidInput(format!(
459 "Label {} exceeds number of classes {}",
460 target_class, model.n_classes
461 )));
462 }
463
464 let class_loss = -class_probs[target_class].ln();
466 total_class_loss += model.classification_weight * class_loss;
467 }
468
469 n_processed += 1;
470 }
471
472 if epoch % 10 == 0 {
475 println!(
476 "Epoch {}: Energy loss = {:.4}, Class loss = {:.4}",
477 epoch,
478 total_energy_loss,
479 total_class_loss / n_labeled as f64
480 );
481 }
482
483 for weight in &mut model.energy_weights {
485 weight.mapv_inplace(|w| w * (1.0 - model.learning_rate * model.regularization));
486 }
487 }
488
489 model.fitted = true;
490 Ok(model)
491 }
492}
493
494impl Predict<ArrayView2<'_, f64>, Array1<i32>> for EnergyBasedModel {
495 fn predict(&self, X: &ArrayView2<f64>) -> Result<Array1<i32>, SklearsError> {
496 if !self.fitted {
497 return Err(SklearsError::NotFitted {
498 operation: "making predictions".to_string(),
499 });
500 }
501
502 let mut predictions = Array1::zeros(X.nrows());
503
504 for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
505 let class_probs = self.compute_classification_probs(&sample)?;
506 let predicted_class = class_probs
507 .iter()
508 .enumerate()
509 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
510 .unwrap()
511 .0;
512 predictions[i] = predicted_class as i32;
513 }
514
515 Ok(predictions)
516 }
517}
518
519impl PredictProba<ArrayView2<'_, f64>, Array2<f64>> for EnergyBasedModel {
520 fn predict_proba(&self, X: &ArrayView2<f64>) -> Result<Array2<f64>, SklearsError> {
521 if !self.fitted {
522 return Err(SklearsError::NotFitted {
523 operation: "making predictions".to_string(),
524 });
525 }
526
527 let mut probabilities = Array2::zeros((X.nrows(), self.n_classes));
528
529 for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
530 let class_probs = self.compute_classification_probs(&sample)?;
531 probabilities.row_mut(i).assign(&class_probs);
532 }
533
534 Ok(probabilities)
535 }
536}
537
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    /// The builder chain should record every hyperparameter verbatim.
    #[test]
    fn test_energy_based_model_creation() {
        let model = EnergyBasedModel::new()
            .hidden_dims(vec![32, 16, 8])
            .n_classes(3)
            .input_dim(5)
            .learning_rate(0.01)
            .epochs(50)
            .regularization(0.1)
            .n_negative_samples(5)
            .temperature(0.8)
            .classification_weight(2.0)
            .margin(2.0);

        assert_eq!(model.hidden_dims, vec![32, 16, 8]);
        assert_eq!(model.n_classes, 3);
        assert_eq!(model.input_dim, 5);
        assert_eq!(model.learning_rate, 0.01);
        assert_eq!(model.epochs, 50);
        assert_eq!(model.regularization, 0.1);
        assert_eq!(model.n_negative_samples, 5);
        assert_eq!(model.temperature, 0.8);
        assert_eq!(model.classification_weight, 2.0);
        assert_eq!(model.margin, 2.0);
    }

    /// End-to-end fit/predict/predict_proba on a tiny semi-supervised
    /// dataset; label -1 marks the unlabeled sample.
    #[test]
    #[allow(non_snake_case)]
    fn test_energy_based_model_fit_predict() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0]
        ];
        let y = array![0, 1, -1, 0];
        let model = EnergyBasedModel::new()
            .n_classes(2)
            .input_dim(3)
            .epochs(10)
            .learning_rate(0.01)
            .n_negative_samples(3);

        let fitted_model = model.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted_model.predict(&X.view()).unwrap();
        let probabilities = fitted_model.predict_proba(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert_eq!(probabilities.dim(), (4, 2));

        // Every probability row must be a valid distribution.
        for i in 0..4 {
            let sum: f64 = probabilities.row(i).sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }
    }

    /// Fitting with only unlabeled samples (-1) must be rejected.
    #[test]
    #[allow(non_snake_case)]
    fn test_energy_based_model_insufficient_labeled_samples() {
        let X = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        let y = array![-1, -1];
        let model = EnergyBasedModel::new().n_classes(2).input_dim(3).epochs(10);

        let result = model.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    /// Mismatched X/y lengths must be rejected.
    #[test]
    #[allow(non_snake_case)]
    fn test_energy_based_model_invalid_dimensions() {
        let X = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        let y = array![0];
        let model = EnergyBasedModel::new();
        let result = model.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    /// The energy of an arbitrary input is a finite scalar once the
    /// parameters are initialized.
    #[test]
    fn test_energy_computation() {
        let model = EnergyBasedModel::new().input_dim(3).hidden_dims(vec![4, 2]);

        let mut model = model.clone();
        model.initialize_parameters().unwrap();

        let input = array![1.0, 2.0, 3.0];
        let energy = model.compute_energy(&input.view()).unwrap();

        assert!(energy.is_finite());
    }

    /// energy_to_probability implements exp(-E / T).
    #[test]
    fn test_energy_to_probability() {
        let model = EnergyBasedModel::new().temperature(1.0);
        let energy = 2.0;
        let prob = model.energy_to_probability(energy);

        assert!(prob > 0.0);
        assert!(prob <= 1.0);
        assert!((prob - (-2.0f64).exp()).abs() < 1e-10);
    }

    /// Negative samples come out with shape (n_negative_samples, input_dim).
    #[test]
    #[allow(non_snake_case)]
    fn test_negative_sample_generation() {
        let X = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]];

        let model = EnergyBasedModel::new().input_dim(3).n_negative_samples(5);

        let negative_samples = model.generate_negative_samples(&X.view()).unwrap();

        assert_eq!(negative_samples.dim(), (5, 3));
    }

    /// The contrastive loss is finite and non-negative for typical energies.
    #[test]
    fn test_contrastive_loss_computation() {
        let model = EnergyBasedModel::new().margin(1.0);

        let positive_energies = array![0.5, 1.0, 0.8];
        let negative_energies = array![2.0, 1.5, 2.5];

        let loss = model.contrastive_loss(&positive_energies, &negative_energies);

        assert!(loss >= 0.0);
        assert!(loss.is_finite());
    }

    /// Softmax output sums to 1 and preserves the ordering of the logits.
    #[test]
    fn test_softmax_computation() {
        let model = EnergyBasedModel::new().temperature(1.0);
        let logits = array![1.0, 2.0, 3.0];
        let probs = model.softmax(&logits);

        let sum: f64 = probs.sum();
        assert!((sum - 1.0).abs() < 1e-10);

        assert!(probs[0] < probs[1]);
        assert!(probs[1] < probs[2]);
    }

    /// ReLU zeroes negatives and passes non-negatives through.
    #[test]
    fn test_relu_activation() {
        let model = EnergyBasedModel::new();
        let input = array![-1.0, 0.0, 1.0, 2.0];
        let output = model.relu(&input);

        assert_eq!(output, array![0.0, 0.0, 1.0, 2.0]);
    }

    /// Leaky ReLU scales negatives by alpha and keeps positives.
    #[test]
    fn test_leaky_relu_activation() {
        let model = EnergyBasedModel::new();
        let input = array![-1.0, 0.0, 1.0, 2.0];
        let output = model.leaky_relu(&input, 0.1);

        assert_eq!(output, array![-0.1, 0.0, 1.0, 2.0]);
    }

    /// Every inference-style entry point must fail before fit().
    #[test]
    #[allow(non_snake_case)]
    fn test_energy_based_model_not_fitted_error() {
        let model = EnergyBasedModel::new();
        let X = array![[1.0, 2.0, 3.0]];

        let result = model.predict(&X.view());
        assert!(result.is_err());

        let result = model.predict_proba(&X.view());
        assert!(result.is_err());

        let sample = array![1.0, 2.0, 3.0];
        let result = model.langevin_sample(&sample.view(), 10, 0.01);
        assert!(result.is_err());

        let result = model.log_partition_function(100);
        assert!(result.is_err());
    }

    /// Full pipeline with a non-default hyperparameter set and 3 classes.
    #[test]
    #[allow(non_snake_case)]
    fn test_energy_based_model_with_different_parameters() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0]
        ];
        let y = array![0, 1, 2];

        let model = EnergyBasedModel::new()
            .hidden_dims(vec![8, 4])
            .n_classes(3)
            .input_dim(4)
            .learning_rate(0.1)
            .epochs(3)
            .regularization(0.01)
            .n_negative_samples(2)
            .temperature(0.5)
            .classification_weight(0.5)
            .margin(0.5);

        let fitted_model = model.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted_model.predict(&X.view()).unwrap();
        let probabilities = fitted_model.predict_proba(&X.view()).unwrap();

        assert_eq!(predictions.len(), 3);
        assert_eq!(probabilities.dim(), (3, 3));
    }

    /// Hidden features have the size of the last hidden layer and are finite.
    #[test]
    fn test_hidden_features_extraction() {
        let model = EnergyBasedModel::new().input_dim(3).hidden_dims(vec![4, 2]);

        let mut model = model.clone();
        model.initialize_parameters().unwrap();

        let input = array![1.0, 2.0, 3.0];
        let features = model.get_hidden_features(&input.view()).unwrap();

        assert_eq!(features.len(), 2);
        assert!(features.iter().all(|&x| x.is_finite()));
    }

    /// Classification probabilities form a distribution over n_classes.
    #[test]
    fn test_classification_probabilities() {
        let model = EnergyBasedModel::new()
            .input_dim(3)
            .n_classes(2)
            .hidden_dims(vec![4]);

        let mut model = model.clone();
        model.initialize_parameters().unwrap();

        let input = array![1.0, 2.0, 3.0];
        let probs = model.compute_classification_probs(&input.view()).unwrap();

        assert_eq!(probs.len(), 2);
        let sum: f64 = probs.sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    /// Langevin sampling from a fitted model yields a finite sample of the
    /// right dimension.
    #[test]
    #[allow(non_snake_case)]
    fn test_langevin_sampling() {
        let X = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        let y = array![0, 1];

        let model = EnergyBasedModel::new().n_classes(2).input_dim(3).epochs(5);

        let fitted_model = model.fit(&X.view(), &y.view()).unwrap();
        let initial_sample = array![1.0, 2.0, 3.0];
        let sample = fitted_model
            .langevin_sample(&initial_sample.view(), 5, 0.01)
            .unwrap();

        assert_eq!(sample.len(), 3);
        assert!(sample.iter().all(|&x| x.is_finite()));
    }
}