sklears_semi_supervised/deep_learning/flow_based_models.rs

//! Flow-Based Models for Semi-Supervised Learning
//!
//! This module provides normalizing flow implementations for semi-supervised learning.
//! Flow-based models learn invertible transformations between the data distribution and
//! a simple prior distribution (e.g., a standard Gaussian), allowing for both density
//! estimation and generation while supporting semi-supervised classification.
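//!
//! # Example
//!
//! A minimal usage sketch mirroring the unit tests below (`-1` marks unlabeled
//! samples); marked `ignore` because it assumes the surrounding crate context:
//!
//! ```ignore
//! use scirs2_core::array;
//!
//! let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
//! let y = array![0, 1, -1, -1];
//!
//! let flow = NormalizingFlow::new().n_layers(2).hidden_dims(vec![4]).max_iter(10);
//! let fitted = flow.fit(&x.view(), &y.view()).unwrap();
//! let predictions = fitted.predict(&x.view()).unwrap();
//! let samples = fitted.generate_samples(5).unwrap();
//! ```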

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::Random;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
    types::Float,
};
use std::f64::consts::PI;

/// Affine coupling layer for normalizing flows
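///
/// Given a binary mask splitting the input into a pass-through part `x_a`
/// (mask = `true`) and a transformed part `x_b` (mask = `false`), the layer
/// computes `y_b = x_b * exp(s(x_a)) + t(x_a)` and copies `x_a` unchanged,
/// where `s` and `t` are small feed-forward networks. Because `x_a` is
/// untouched, the transformation is trivially invertible and the
/// log-determinant of its Jacobian is just the sum of the `s(x_a)` outputs.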
#[derive(Debug, Clone)]
pub struct AffineCouplingLayer {
    /// Scale network weights
    pub scale_weights: Vec<Array2<f64>>,
    /// Scale network biases
    pub scale_biases: Vec<Array1<f64>>,
    /// Translation network weights
    pub translation_weights: Vec<Array2<f64>>,
    /// Translation network biases
    pub translation_biases: Vec<Array1<f64>>,
    /// Coupling mask: `true` marks pass-through (conditioning) dimensions,
    /// `false` marks the dimensions that get transformed
    pub mask: Array1<bool>,
    /// Hidden dimensions
    pub hidden_dims: Vec<usize>,
}

impl AffineCouplingLayer {
    /// Create a new affine coupling layer
    pub fn new(input_dim: usize, hidden_dims: Vec<usize>, mask: Array1<bool>) -> Self {
        let input_masked_dim = mask.iter().filter(|&&x| x).count();
        let output_dim = input_dim - input_masked_dim;

        // Build scale network architecture
        let mut scale_arch = vec![input_masked_dim];
        scale_arch.extend(hidden_dims.clone());
        scale_arch.push(output_dim);

        // Build translation network architecture
        let mut translation_arch = vec![input_masked_dim];
        translation_arch.extend(hidden_dims.clone());
        translation_arch.push(output_dim);

        let mut scale_weights = Vec::new();
        let mut scale_biases = Vec::new();
        let mut translation_weights = Vec::new();
        let mut translation_biases = Vec::new();

        // Initialize scale network (scaled-uniform weights with a roughly
        // Xavier/Glorot magnitude)
        for i in 0..scale_arch.len() - 1 {
            let input_size = scale_arch[i];
            let output_size = scale_arch[i + 1];

            let scale = (2.0 / (input_size + output_size) as f64).sqrt();
            let w = {
                let mut rng = Random::default();
                let mut w = Array2::zeros((output_size, input_size));
                for row in 0..output_size {
                    for col in 0..input_size {
                        w[[row, col]] = rng.random_range(-3.0..3.0) / 3.0 * scale;
                    }
                }
                w
            };
            let b = Array1::zeros(output_size);

            scale_weights.push(w);
            scale_biases.push(b);
        }

        // Initialize translation network (same scheme as the scale network)
        for i in 0..translation_arch.len() - 1 {
            let input_size = translation_arch[i];
            let output_size = translation_arch[i + 1];

            let scale = (2.0 / (input_size + output_size) as f64).sqrt();
            let w = {
                let mut rng = Random::default();
                let mut w = Array2::zeros((output_size, input_size));
                for row in 0..output_size {
                    for col in 0..input_size {
                        w[[row, col]] = rng.random_range(-3.0..3.0) / 3.0 * scale;
                    }
                }
                w
            };
            let b = Array1::zeros(output_size);

            translation_weights.push(w);
            translation_biases.push(b);
        }

        Self {
            scale_weights,
            scale_biases,
            translation_weights,
            translation_biases,
            mask,
            hidden_dims,
        }
    }

    /// Forward pass through coupling layer
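    ///
    /// Returns the transformed vector together with `log|det J|`, which for an
    /// affine coupling layer equals the sum of the raw (pre-`exp`) scale
    /// outputs over the transformed dimensions.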
    pub fn forward(&self, x: &ArrayView1<f64>) -> SklResult<(Array1<f64>, f64)> {
        let mut result = x.to_owned();
        let mut log_det_jacobian = 0.0;

        // Split input based on mask
        let x_masked: Array1<f64> = x
            .iter()
            .zip(self.mask.iter())
            .filter(|(_, &mask)| mask)
            .map(|(&val, _)| val)
            .collect();

        if x_masked.is_empty() {
            return Ok((result, log_det_jacobian));
        }

        // Compute scale and translation
        let scale = self.compute_scale(&x_masked.view())?;
        let translation = self.compute_translation(&x_masked.view())?;

        // Apply transformation to unmasked elements
        let mut output_idx = 0;
        for i in 0..x.len() {
            if !self.mask[i] && output_idx < scale.len() {
                let exp_scale = scale[output_idx].exp();
                result[i] = result[i] * exp_scale + translation[output_idx];
                log_det_jacobian += scale[output_idx];
                output_idx += 1;
            }
        }

        Ok((result, log_det_jacobian))
    }

    /// Inverse pass through coupling layer
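    ///
    /// Applies `x_b = (z_b - t(z_a)) * exp(-s(z_a))`, exactly undoing the
    /// forward pass; pass-through dimensions are copied unchanged.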
    pub fn inverse(&self, z: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
        let mut result = z.to_owned();

        // Split input based on mask
        let z_masked: Array1<f64> = z
            .iter()
            .zip(self.mask.iter())
            .filter(|(_, &mask)| mask)
            .map(|(&val, _)| val)
            .collect();

        if z_masked.is_empty() {
            return Ok(result);
        }

        // Compute scale and translation
        let scale = self.compute_scale(&z_masked.view())?;
        let translation = self.compute_translation(&z_masked.view())?;

        // Apply inverse transformation to unmasked elements
        let mut output_idx = 0;
        for i in 0..z.len() {
            if !self.mask[i] && output_idx < scale.len() {
                let exp_scale = scale[output_idx].exp();
                result[i] = (result[i] - translation[output_idx]) / exp_scale;
                output_idx += 1;
            }
        }

        Ok(result)
    }

    /// Compute scale using neural network
    fn compute_scale(&self, x: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
        let mut current = x.to_owned();

        for (i, (weights, biases)) in self
            .scale_weights
            .iter()
            .zip(self.scale_biases.iter())
            .enumerate()
        {
            let linear = weights.dot(&current) + biases;

            // Use ReLU for hidden layers, tanh for output (bounds the
            // log-scale to [-1, 1] for numerical stability)
            current = if i < self.scale_weights.len() - 1 {
                linear.mapv(|x| x.max(0.0))
            } else {
                linear.mapv(|x| x.tanh())
            };
        }

        Ok(current)
    }

    /// Compute translation using neural network
    fn compute_translation(&self, x: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
        let mut current = x.to_owned();

        for (i, (weights, biases)) in self
            .translation_weights
            .iter()
            .zip(self.translation_biases.iter())
            .enumerate()
        {
            let linear = weights.dot(&current) + biases;

            // Use ReLU for hidden layers, linear for output
            current = if i < self.translation_weights.len() - 1 {
                linear.mapv(|x| x.max(0.0))
            } else {
                linear
            };
        }

        Ok(current)
    }
}

/// Normalizing flow model for semi-supervised learning
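///
/// Training combines a supervised cross-entropy term over the labeled samples
/// with an unsupervised negative log-likelihood (density) term over all
/// samples, weighted by `reg_param`; unlabeled samples are marked with a
/// negative label.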
#[derive(Debug, Clone)]
pub struct NormalizingFlow<S = Untrained> {
    state: S,
    /// Coupling layers
    layers: Vec<AffineCouplingLayer>,
    /// Classification network weights
    classifier_weights: Option<Array2<f64>>,
    /// Classification network biases
    classifier_biases: Option<Array1<f64>>,
    /// Number of flow layers
    n_layers: usize,
    /// Number of classes
    n_classes: usize,
    /// Hidden dimensions for coupling layers
    hidden_dims: Vec<usize>,
    /// Learning rate
    learning_rate: f64,
    /// Maximum number of iterations
    max_iter: usize,
    /// Regularization parameter
    reg_param: f64,
    /// Random state for reproducibility
    random_state: Option<u64>,
}

impl Default for NormalizingFlow<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl NormalizingFlow<Untrained> {
    /// Create a new normalizing flow
    pub fn new() -> Self {
        Self {
            state: Untrained,
            layers: Vec::new(),
            classifier_weights: None,
            classifier_biases: None,
            n_layers: 4,
            n_classes: 2,
            hidden_dims: vec![64, 32],
            learning_rate: 0.001,
            max_iter: 100,
            reg_param: 0.01,
            random_state: None,
        }
    }

    /// Set number of flow layers
    pub fn n_layers(mut self, n_layers: usize) -> Self {
        self.n_layers = n_layers;
        self
    }

    /// Set hidden dimensions
    pub fn hidden_dims(mut self, hidden_dims: Vec<usize>) -> Self {
        self.hidden_dims = hidden_dims;
        self
    }

    /// Set learning rate
    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set maximum iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set regularization parameter
    pub fn reg_param(mut self, reg_param: f64) -> Self {
        self.reg_param = reg_param;
        self
    }

    /// Set random state
    ///
    /// Note: the seed is stored but not yet threaded through weight
    /// initialization, which currently uses `Random::default()`.
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Initialize flow layers
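    ///
    /// Masks alternate between consecutive layers (even-indexed dimensions
    /// pass through in one layer, odd-indexed in the next), so with at least
    /// two layers every dimension is transformed by some coupling layer.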
    fn initialize_layers(&mut self, input_dim: usize) {
        self.layers.clear();

        for i in 0..self.n_layers {
            // Alternate mask pattern for each layer
            let mut mask = Array1::from(vec![false; input_dim]);
            for j in 0..input_dim {
                mask[j] = (j + i) % 2 == 0;
            }

            let layer = AffineCouplingLayer::new(input_dim, self.hidden_dims.clone(), mask);
            self.layers.push(layer);
        }
    }

    /// Initialize classifier
    fn initialize_classifier(&mut self, input_dim: usize, n_classes: usize) {
        self.classifier_weights = Some({
            let mut rng = Random::default();
            let mut w = Array2::zeros((n_classes, input_dim));
            for i in 0..n_classes {
                for j in 0..input_dim {
                    w[[i, j]] = rng.random_range(-3.0..3.0) / 3.0 * 0.1;
                }
            }
            w
        });
        self.classifier_biases = Some(Array1::zeros(n_classes));
    }

    /// Compute log likelihood of data
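    ///
    /// Uses the change-of-variables formula
    /// `log p(x) = log N(z; 0, I) + log|det dz/dx|`,
    /// where `z` is the image of `x` under the forward flow.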
    fn log_likelihood(&self, x: &ArrayView1<f64>) -> SklResult<f64> {
        let (z, log_det_jacobian) = self.forward_impl(x)?;

        // Standard normal log probability
        let log_prob_z = -0.5 * (z.len() as f64 * (2.0 * PI).ln() + z.mapv(|x| x * x).sum());

        Ok(log_prob_z + log_det_jacobian)
    }

    /// Forward pass implementation
    fn forward_impl(&self, x: &ArrayView1<f64>) -> SklResult<(Array1<f64>, f64)> {
        let mut current = x.to_owned();
        let mut total_log_det = 0.0;

        for layer in &self.layers {
            let (transformed, log_det) = layer.forward(&current.view())?;
            current = transformed;
            total_log_det += log_det;
        }

        Ok((current, total_log_det))
    }

    /// Classify using classifier network
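    ///
    /// Note: in this simplified implementation the classifier operates on the
    /// raw input features rather than the flow-transformed representation.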
    fn classify(&self, x: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
        match (&self.classifier_weights, &self.classifier_biases) {
            (Some(weights), Some(biases)) => {
                let logits = weights.dot(x) + biases;
                Ok(self.softmax_impl(&logits.view()))
            }
            _ => Err(SklearsError::InvalidInput(
                "Classifier not initialized".to_string(),
            )),
        }
    }

    /// Softmax activation implementation
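    ///
    /// Computes `exp(x_i - max(x)) / sum_j exp(x_j - max(x))`; subtracting the
    /// maximum keeps the exponentials numerically stable without changing the
    /// result.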
    fn softmax_impl(&self, x: &ArrayView1<f64>) -> Array1<f64> {
        let max_val = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exp_x = x.mapv(|v| (v - max_val).exp());
        let sum_exp = exp_x.sum();
        exp_x / sum_exp
    }

    /// Train the model
    fn train(&mut self, x: &ArrayView2<f64>, y: &ArrayView1<i32>) -> SklResult<()> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        // Initialize layers and classifier
        self.initialize_layers(n_features);
        self.initialize_classifier(n_features, self.n_classes);

        // Separate labeled (label >= 0) from unlabeled (label < 0) samples.
        // The unlabeled set is tracked for clarity, although the density term
        // below already runs over every sample.
        let mut labeled_indices = Vec::new();
        let mut _unlabeled_indices = Vec::new();

        for (i, &label) in y.iter().enumerate() {
            if label >= 0 {
                labeled_indices.push(i);
            } else {
                _unlabeled_indices.push(i);
            }
        }

        // Training loop (simplified: the combined loss is evaluated, but no
        // gradient updates are applied; a full implementation would
        // backpropagate through the coupling networks and the classifier)
        for iteration in 0..self.max_iter {
            // Supervised cross-entropy loss on labeled data
            // (labels are assumed to be the indices 0..n_classes)
            let mut supervised_loss = 0.0;
            for &idx in &labeled_indices {
                // Flow features are computed here for completeness; the
                // simplified classifier below operates on the raw input
                let _features = self.forward_impl(&x.row(idx))?;
                let probs = self.classify(&x.row(idx))?;

                let label_idx = y[idx] as usize;
                if label_idx < probs.len() {
                    supervised_loss -= (probs[label_idx] + 1e-15).ln();
                }
            }

            if !labeled_indices.is_empty() {
                supervised_loss /= labeled_indices.len() as f64;
            }

            // Unsupervised loss on all data (negative log-likelihood under the flow)
            let mut unsupervised_loss = 0.0;
            for i in 0..n_samples {
                let log_likelihood = self.log_likelihood(&x.row(i))?;
                unsupervised_loss -= log_likelihood;
            }
            unsupervised_loss /= n_samples as f64;

            let total_loss = supervised_loss + self.reg_param * unsupervised_loss;

            if iteration % 10 == 0 {
                println!("Iteration {}: Loss = {:.4}", iteration, total_loss);
            }

            // Early stopping
            if total_loss < 1e-6 {
                break;
            }
        }

        Ok(())
    }
}

/// Trained state for Normalizing Flow
#[derive(Debug, Clone)]
pub struct NormalizingFlowTrained {
    /// Fitted coupling layers
    pub layers: Vec<AffineCouplingLayer>,
    /// Classification network weights
    pub classifier_weights: Array2<f64>,
    /// Classification network biases
    pub classifier_biases: Array1<f64>,
    /// Class labels seen during fitting
    pub classes: Array1<i32>,
    /// Number of flow layers
    pub n_layers: usize,
    /// Number of classes
    pub n_classes: usize,
    /// Hidden dimensions for coupling layers
    pub hidden_dims: Vec<usize>,
    /// Learning rate used during training
    pub learning_rate: f64,
}

impl NormalizingFlow<NormalizingFlowTrained> {
    /// Generate samples from the flow
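    ///
    /// Draws latent vectors from (an approximation of) the standard-normal
    /// prior and maps them to data space through the inverse flow.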
    pub fn generate_samples(&self, n_samples: usize) -> SklResult<Array2<f64>> {
        if self.state.layers.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Model not trained yet".to_string(),
            ));
        }

        let latent_dim = self.state.layers[0].mask.len();
        let mut samples = Array2::zeros((n_samples, latent_dim));
        let mut rng = Random::default();

        for i in 0..n_samples {
            // Draw an approximate standard-normal latent vector
            // (a scaled uniform is used here as a simple stand-in)
            let mut z = Array1::zeros(latent_dim);
            for j in 0..latent_dim {
                z[j] = rng.random_range(-3.0..3.0) / 3.0;
            }

            // Transform through inverse flow
            let x = self.inverse(&z.view())?;
            samples.row_mut(i).assign(&x);
        }

        Ok(samples)
    }

    /// Forward pass through the flow (data space to latent space)
    pub fn forward(&self, x: &ArrayView1<f64>) -> SklResult<(Array1<f64>, f64)> {
        let mut current = x.to_owned();
        let mut total_log_det = 0.0;

        for layer in &self.state.layers {
            let (transformed, log_det) = layer.forward(&current.view())?;
            current = transformed;
            total_log_det += log_det;
        }

        Ok((current, total_log_det))
    }

    /// Inverse pass through the flow (latent space to data space)
    fn inverse(&self, z: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
        let mut current = z.to_owned();

        // Invert the layers in reverse order
        for layer in self.state.layers.iter().rev() {
            current = layer.inverse(&current.view())?;
        }

        Ok(current)
    }

    /// Softmax activation
    fn softmax(&self, x: &ArrayView1<f64>) -> Array1<f64> {
        let max_val = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exp_x = x.mapv(|v| (v - max_val).exp());
        let sum_exp = exp_x.sum();
        exp_x / sum_exp
    }
}

impl Estimator for NormalizingFlow<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for NormalizingFlow<Untrained> {
    type Fitted = NormalizingFlow<NormalizingFlowTrained>;

    fn fit(self, x: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let x = x.to_owned();
        let y = y.to_owned();

        if x.nrows() != y.len() {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        if x.nrows() == 0 {
            return Err(SklearsError::InvalidInput(
                "No samples provided".to_string(),
            ));
        }

        // Check if we have any labeled samples
        let labeled_count = y.iter().filter(|&&label| label >= 0).count();
        if labeled_count == 0 {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Get unique classes
        let mut unique_classes: Vec<i32> = y.iter().filter(|&&label| label >= 0).cloned().collect();
        unique_classes.sort_unstable();
        unique_classes.dedup();

        let mut model = self.clone();
        model.n_classes = unique_classes.len();

        // Train the model
        model.train(&x.view(), &y.view())?;

        // The fitted state carries the learned parameters; the builder-side
        // fields are reset to placeholder values
        Ok(NormalizingFlow {
            state: NormalizingFlowTrained {
                layers: model.layers,
                classifier_weights: model.classifier_weights.unwrap(),
                classifier_biases: model.classifier_biases.unwrap(),
                classes: Array1::from(unique_classes),
                n_layers: model.n_layers,
                n_classes: model.n_classes,
                hidden_dims: model.hidden_dims,
                learning_rate: model.learning_rate,
            },
            layers: Vec::new(),
            classifier_weights: None,
            classifier_biases: None,
            n_layers: 0,
            n_classes: 0,
            hidden_dims: Vec::new(),
            learning_rate: 0.0,
            max_iter: 0,
            reg_param: 0.0,
            random_state: None,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>> for NormalizingFlow<NormalizingFlowTrained> {
    fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let x = x.to_owned();
        let mut predictions = Array1::zeros(x.nrows());

        for i in 0..x.nrows() {
            let logits =
                self.state.classifier_weights.dot(&x.row(i)) + &self.state.classifier_biases;
            let probs = self.softmax(&logits.view());

            // Map the argmax over class probabilities back to the class label
            let max_idx = probs
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx)
                .unwrap_or(0);

            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>> for NormalizingFlow<NormalizingFlowTrained> {
    fn predict_proba(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let x = x.to_owned();
        let mut probabilities = Array2::zeros((x.nrows(), self.state.n_classes));

        for i in 0..x.nrows() {
            let logits =
                self.state.classifier_weights.dot(&x.row(i)) + &self.state.classifier_biases;
            let probs = self.softmax(&logits.view());
            probabilities.row_mut(i).assign(&probs);
        }

        Ok(probabilities)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_affine_coupling_layer_creation() {
        let mask = array![true, false, true, false];
        let layer = AffineCouplingLayer::new(4, vec![8, 4], mask.clone());

        assert_eq!(layer.mask, mask);
        assert_eq!(layer.hidden_dims, vec![8, 4]);
        assert!(!layer.scale_weights.is_empty());
        assert!(!layer.translation_weights.is_empty());
    }

    #[test]
    fn test_affine_coupling_layer_forward() {
        let mask = array![true, false, true, false];
        let layer = AffineCouplingLayer::new(4, vec![4], mask);
        let x = array![1.0, 2.0, 3.0, 4.0];

        let result = layer.forward(&x.view());
        assert!(result.is_ok());

        let (output, _log_det) = result.unwrap();
        assert_eq!(output.len(), 4);
        // Check that masked (pass-through) elements are unchanged
        assert_eq!(output[0], x[0]);
        assert_eq!(output[2], x[2]);
    }

    #[test]
    fn test_affine_coupling_layer_inverse() {
        let mask = array![true, false, true, false];
        let layer = AffineCouplingLayer::new(4, vec![4], mask);
        let x = array![1.0, 2.0, 3.0, 4.0];

        let (z, _) = layer.forward(&x.view()).unwrap();
        let x_reconstructed = layer.inverse(&z.view()).unwrap();

        // Check reconstruction (should be close to original)
        for i in 0..4 {
            assert!((x_reconstructed[i] - x[i]).abs() < 1e-10);
        }
    }

    #[test]
    fn test_normalizing_flow_creation() {
        let flow = NormalizingFlow::new()
            .n_layers(6)
            .hidden_dims(vec![32, 16])
            .learning_rate(0.002)
            .max_iter(50);

        assert_eq!(flow.n_layers, 6);
        assert_eq!(flow.hidden_dims, vec![32, 16]);
        assert_eq!(flow.learning_rate, 0.002);
        assert_eq!(flow.max_iter, 50);
    }

    #[test]
    fn test_normalizing_flow_fit_predict() {
        let X = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0]
        ];
        let y = array![0, 1, 0, 1, -1, -1]; // -1 indicates unlabeled

        let flow = NormalizingFlow::new()
            .n_layers(2)
            .hidden_dims(vec![4])
            .learning_rate(0.01)
            .max_iter(5);

        let result = flow.fit(&X.view(), &y.view());
        assert!(result.is_ok());

        let fitted = result.unwrap();
        assert_eq!(fitted.state.classes.len(), 2);

        let predictions = fitted.predict(&X.view());
        assert!(predictions.is_ok());

        let pred = predictions.unwrap();
        assert_eq!(pred.len(), 6);

        let probabilities = fitted.predict_proba(&X.view());
        assert!(probabilities.is_ok());

        let proba = probabilities.unwrap();
        assert_eq!(proba.dim(), (6, 2));

        // Check probabilities sum to 1
        for i in 0..6 {
            let sum: f64 = proba.row(i).sum();
            assert!((sum - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_normalizing_flow_insufficient_labeled_samples() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![-1, -1]; // All unlabeled

        let flow = NormalizingFlow::new();
        let result = flow.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_normalizing_flow_invalid_dimensions() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![0]; // Wrong number of labels

        let flow = NormalizingFlow::new();
        let result = flow.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_normalizing_flow_generate_samples() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0]
        ];
        let y = array![0, 1, 0, -1]; // Mixed labeled and unlabeled

        let flow = NormalizingFlow::new().n_layers(2).max_iter(3);

        let fitted = flow.fit(&X.view(), &y.view()).unwrap();

        let generated = fitted.generate_samples(5);
        assert!(generated.is_ok());

        let samples = generated.unwrap();
        assert_eq!(samples.dim(), (5, 3));
    }

    #[test]
    fn test_affine_coupling_with_empty_mask() {
        let mask = array![false, false, false, false];
        let layer = AffineCouplingLayer::new(4, vec![4], mask);
        let x = array![1.0, 2.0, 3.0, 4.0];

        let (output, log_det) = layer.forward(&x.view()).unwrap();

        // With no pass-through (conditioning) dimensions, the layer returns
        // the input unchanged and a zero log-determinant
        assert_eq!(output, x);
        assert_eq!(log_det, 0.0);
    }

    #[test]
    fn test_normalizing_flow_with_different_parameters() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0],
            [4.0, 5.0, 6.0, 7.0]
        ];
        let y = array![0, 1, 0, -1]; // Mixed labeled and unlabeled

        let flow = NormalizingFlow::new()
            .n_layers(3)
            .hidden_dims(vec![8, 4])
            .learning_rate(0.005)
            .max_iter(2)
            .reg_param(0.1);

        let result = flow.fit(&X.view(), &y.view());
        assert!(result.is_ok());

        let fitted = result.unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);
    }
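
    // A minimal round-trip sketch: composing the fitted flow's forward and
    // inverse passes should reconstruct the input up to numerical error.
    #[test]
    fn test_full_flow_round_trip() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, 0, -1];

        let flow = NormalizingFlow::new().n_layers(2).max_iter(2);
        let fitted = flow.fit(&X.view(), &y.view()).unwrap();

        let x = array![1.5, 2.5];
        let (z, _log_det) = fitted.forward(&x.view()).unwrap();
        let x_reconstructed = fitted.inverse(&z.view()).unwrap();

        for i in 0..x.len() {
            assert!((x_reconstructed[i] - x[i]).abs() < 1e-8);
        }
    }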
}