Skip to main content

sklears_linear/
large_scale_variational_inference.rs

1//! Large-Scale Variational Inference for Linear Models
2//!
3//! This module implements scalable variational Bayesian inference algorithms
4//! designed for large datasets that don't fit in memory. It includes stochastic
5//! variational inference, mini-batch processing, and streaming algorithms.
6
7use scirs2_core::ndarray::{s, Array1, Array2};
8use scirs2_core::random::seq::SliceRandom;
9use scirs2_core::random::{Distribution, RandNormal as Normal, Rng};
10use scirs2_core::{SeedableRng, StdRng};
11use sklears_core::{
12    error::{Result, SklearsError},
13    traits::{Estimator, Fit, Predict, Trained, Untrained},
14    types::Float,
15};
16use std::marker::PhantomData;
17
18/// Configuration for large-scale variational inference
19#[derive(Debug, Clone)]
20pub struct LargeScaleVariationalConfig {
21    /// Maximum number of epochs
22    pub max_epochs: usize,
23    /// Mini-batch size for stochastic updates
24    pub batch_size: usize,
25    /// Learning rate for variational parameter updates
26    pub learning_rate: Float,
27    /// Learning rate decay schedule
28    pub learning_rate_decay: LearningRateDecay,
29    /// Convergence tolerance
30    pub tolerance: Float,
31    /// Number of Monte Carlo samples for expectations
32    pub n_mc_samples: usize,
33    /// Whether to use natural gradients
34    pub use_natural_gradients: bool,
35    /// Whether to use control variates for variance reduction
36    pub use_control_variates: bool,
37    /// Memory limit in GB for adaptive batch sizing
38    pub memory_limit_gb: Option<Float>,
39    /// Whether to enable verbose output
40    pub verbose: bool,
41    /// Random seed for reproducibility
42    pub random_seed: Option<u64>,
43    /// Prior parameters
44    pub prior_config: PriorConfiguration,
45}
46
47impl Default for LargeScaleVariationalConfig {
48    fn default() -> Self {
49        Self {
50            max_epochs: 100,
51            batch_size: 256,
52            learning_rate: 0.01,
53            learning_rate_decay: LearningRateDecay::Exponential { decay_rate: 0.95 },
54            tolerance: 1e-6,
55            n_mc_samples: 10,
56            use_natural_gradients: true,
57            use_control_variates: true,
58            memory_limit_gb: Some(4.0),
59            verbose: false,
60            random_seed: None,
61            prior_config: PriorConfiguration::default(),
62        }
63    }
64}
65
66/// Learning rate decay schedules
67#[derive(Debug, Clone)]
68pub enum LearningRateDecay {
69    /// Constant learning rate
70    Constant,
71    /// Exponential decay: lr * decay_rate^epoch
72    Exponential { decay_rate: Float },
73    /// Step decay: lr * step_factor^floor(epoch / step_size)
74    Step {
75        step_size: usize,
76        step_factor: Float,
77    },
78    /// Polynomial decay: lr * (1 + decay_rate * epoch)^(-power)
79    Polynomial { decay_rate: Float, power: Float },
80    /// Cosine annealing
81    CosineAnnealing { min_lr: Float },
82}
83
84/// Prior configuration for Bayesian linear regression
85#[derive(Debug, Clone)]
86pub struct PriorConfiguration {
87    /// Prior precision for weights (Gamma distribution parameters)
88    pub weight_precision_shape: Float,
89    pub weight_precision_rate: Float,
90    /// Prior precision for noise (Gamma distribution parameters)
91    pub noise_precision_shape: Float,
92    pub noise_precision_rate: Float,
93    /// Whether to use hierarchical priors
94    pub hierarchical: bool,
95    /// ARD (Automatic Relevance Determination) configuration
96    pub ard_config: Option<ARDConfiguration>,
97}
98
99impl Default for PriorConfiguration {
100    fn default() -> Self {
101        Self {
102            weight_precision_shape: 1e-6,
103            weight_precision_rate: 1e-6,
104            noise_precision_shape: 1e-6,
105            noise_precision_rate: 1e-6,
106            hierarchical: false,
107            ard_config: None,
108        }
109    }
110}
111
112/// Configuration for Automatic Relevance Determination
113#[derive(Debug, Clone)]
114pub struct ARDConfiguration {
115    /// Individual precision priors for each feature
116    pub feature_precision_shape: Float,
117    pub feature_precision_rate: Float,
118    /// Threshold for feature pruning
119    pub pruning_threshold: Float,
120    /// Whether to enable automatic feature pruning
121    pub enable_pruning: bool,
122}
123
124/// Variational parameters for the posterior distribution
125#[derive(Debug, Clone)]
126pub struct VariationalPosterior {
127    /// Mean of weight posterior (multivariate normal)
128    pub weight_mean: Array1<Float>,
129    /// Covariance of weight posterior
130    pub weight_covariance: Array2<Float>,
131    /// Precision matrix (inverse covariance)
132    pub weight_precision: Array2<Float>,
133    /// Parameters for weight precision posterior (Gamma)
134    pub weight_precision_shape: Array1<Float>,
135    pub weight_precision_rate: Array1<Float>,
136    /// Parameters for noise precision posterior (Gamma)
137    pub noise_precision_shape: Float,
138    pub noise_precision_rate: Float,
139    /// Log marginal likelihood lower bound (ELBO)
140    pub elbo: Float,
141}
142
143impl VariationalPosterior {
144    /// Create a new variational posterior with given dimensions
145    pub fn new(n_features: usize, config: &PriorConfiguration) -> Self {
146        Self {
147            weight_mean: Array1::zeros(n_features),
148            weight_covariance: Array2::eye(n_features),
149            weight_precision: Array2::eye(n_features),
150            weight_precision_shape: Array1::from_elem(n_features, config.weight_precision_shape),
151            weight_precision_rate: Array1::from_elem(n_features, config.weight_precision_rate),
152            noise_precision_shape: config.noise_precision_shape,
153            noise_precision_rate: config.noise_precision_rate,
154            elbo: Float::NEG_INFINITY,
155        }
156    }
157
158    /// Sample from the posterior distribution
159    pub fn sample_weights(&self, n_samples: usize, rng: &mut impl Rng) -> Result<Array2<Float>> {
160        let n_features = self.weight_mean.len();
161        let mut samples = Array2::zeros((n_samples, n_features));
162
163        // Compute Cholesky decomposition of covariance matrix
164        let chol = self.cholesky_decomposition(&self.weight_covariance)?;
165
166        for i in 0..n_samples {
167            // Sample from standard normal
168            let z: Array1<Float> = (0..n_features)
169                .map(|_| {
170                    Normal::new(0.0, 1.0)
171                        .expect("valid normal distribution parameters")
172                        .sample(rng)
173                })
174                .collect::<Vec<_>>()
175                .into();
176
177            // Transform to desired distribution: μ + L * z
178            let sample = &self.weight_mean + chol.dot(&z);
179            samples.slice_mut(s![i, ..]).assign(&sample);
180        }
181
182        Ok(samples)
183    }
184
185    /// Compute Cholesky decomposition (simplified implementation)
186    fn cholesky_decomposition(&self, matrix: &Array2<Float>) -> Result<Array2<Float>> {
187        let n = matrix.nrows();
188        let mut l = Array2::zeros((n, n));
189
190        for i in 0..n {
191            for j in 0..=i {
192                if i == j {
193                    // Diagonal elements
194                    let sum: Float = (0..j).map(|k| l[[i, k]] * l[[i, k]]).sum();
195                    let val = matrix[[i, i]] - sum;
196                    if val <= 0.0 {
197                        return Err(SklearsError::NumericalError(
198                            "Matrix is not positive definite".to_string(),
199                        ));
200                    }
201                    l[[i, j]] = val.sqrt();
202                } else {
203                    // Lower triangular elements
204                    let sum: Float = (0..j).map(|k| l[[i, k]] * l[[j, k]]).sum();
205                    l[[i, j]] = (matrix[[i, j]] - sum) / l[[j, j]];
206                }
207            }
208        }
209
210        Ok(l)
211    }
212}
213
214/// Large-scale variational Bayesian linear regression
215#[derive(Debug)]
216pub struct LargeScaleVariationalRegression<State = Untrained> {
217    config: LargeScaleVariationalConfig,
218    state: PhantomData<State>,
219    // Trained state
220    posterior: Option<VariationalPosterior>,
221    convergence_history: Option<Vec<Float>>,
222    feature_relevance: Option<Array1<Float>>,
223    n_features: Option<usize>,
224    intercept: Option<Float>,
225}
226
227impl Default for LargeScaleVariationalRegression<Untrained> {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233impl LargeScaleVariationalRegression<Untrained> {
234    /// Create a new large-scale variational regression model
235    pub fn new() -> Self {
236        Self {
237            config: LargeScaleVariationalConfig::default(),
238            state: PhantomData,
239            posterior: None,
240            convergence_history: None,
241            feature_relevance: None,
242            n_features: None,
243            intercept: None,
244        }
245    }
246
247    /// Set the configuration
248    pub fn with_config(mut self, config: LargeScaleVariationalConfig) -> Self {
249        self.config = config;
250        self
251    }
252
253    /// Set batch size
254    pub fn batch_size(mut self, batch_size: usize) -> Self {
255        self.config.batch_size = batch_size;
256        self
257    }
258
259    /// Set learning rate
260    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
261        self.config.learning_rate = learning_rate;
262        self
263    }
264
265    /// Enable Automatic Relevance Determination
266    pub fn enable_ard(mut self, pruning_threshold: Float) -> Self {
267        self.config.prior_config.ard_config = Some(ARDConfiguration {
268            feature_precision_shape: 1e-6,
269            feature_precision_rate: 1e-6,
270            pruning_threshold,
271            enable_pruning: true,
272        });
273        self
274    }
275
276    /// Set memory limit for adaptive batch sizing
277    pub fn memory_limit_gb(mut self, limit: Float) -> Self {
278        self.config.memory_limit_gb = Some(limit);
279        self
280    }
281}
282
283impl LargeScaleVariationalRegression<Trained> {
284    /// Get the posterior mean of coefficients
285    pub fn coefficients(&self) -> &Array1<Float> {
286        &self
287            .posterior
288            .as_ref()
289            .expect("value should be present")
290            .weight_mean
291    }
292
293    /// Get the posterior covariance of coefficients
294    pub fn coefficient_covariance(&self) -> &Array2<Float> {
295        &self
296            .posterior
297            .as_ref()
298            .expect("value should be present")
299            .weight_covariance
300    }
301
302    /// Get feature relevance scores (for ARD)
303    pub fn feature_relevance(&self) -> Option<&Array1<Float>> {
304        self.feature_relevance.as_ref()
305    }
306
307    /// Get convergence history
308    pub fn convergence_history(&self) -> Option<&[Float]> {
309        self.convergence_history.as_deref()
310    }
311
312    /// Sample predictions from the posterior predictive distribution
313    pub fn sample_predictions(
314        &self,
315        x: &Array2<Float>,
316        n_samples: usize,
317        rng: &mut impl Rng,
318    ) -> Result<Array2<Float>> {
319        let posterior = self
320            .posterior
321            .as_ref()
322            .ok_or_else(|| SklearsError::NumericalError("value should be present".into()))?;
323        let weight_samples = posterior.sample_weights(n_samples, rng)?;
324
325        let mut predictions = Array2::zeros((n_samples, x.nrows()));
326
327        for i in 0..n_samples {
328            let weights = weight_samples.slice(s![i, ..]);
329            let pred = x.dot(&weights);
330            predictions.slice_mut(s![i, ..]).assign(&pred);
331
332            // Add intercept if fitted
333            if let Some(intercept) = self.intercept {
334                predictions
335                    .slice_mut(s![i, ..])
336                    .mapv_inplace(|x| x + intercept);
337            }
338        }
339
340        Ok(predictions)
341    }
342
343    /// Compute predictive uncertainties
344    pub fn predict_with_uncertainty(
345        &self,
346        x: &Array2<Float>,
347    ) -> Result<(Array1<Float>, Array1<Float>)> {
348        let posterior = self
349            .posterior
350            .as_ref()
351            .ok_or_else(|| SklearsError::NumericalError("value should be present".into()))?;
352
353        // Predictive mean
354        let pred_mean = x.dot(&posterior.weight_mean);
355
356        // Predictive variance: X * Σ * X^T + σ²
357        let mut pred_var = Array1::zeros(x.nrows());
358
359        for i in 0..x.nrows() {
360            let x_i = x.slice(s![i, ..]);
361            let var_contrib = x_i.dot(&posterior.weight_covariance.dot(&x_i));
362
363            // Add noise variance
364            let noise_var = 1.0 / posterior.noise_precision_rate; // Simplified
365            pred_var[i] = var_contrib + noise_var;
366        }
367
368        let pred_std = pred_var.mapv(|v| v.sqrt());
369
370        Ok((pred_mean, pred_std))
371    }
372}
373
374impl Fit<Array2<Float>, Array1<Float>> for LargeScaleVariationalRegression<Untrained> {
375    type Fitted = LargeScaleVariationalRegression<Trained>;
376
377    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
378        let (n_samples, n_features) = x.dim();
379
380        if n_samples != y.len() {
381            return Err(SklearsError::DimensionMismatch {
382                expected: n_samples,
383                actual: y.len(),
384            });
385        }
386
387        // Initialize variational posterior
388        let mut posterior = VariationalPosterior::new(n_features, &self.config.prior_config);
389        let mut convergence_history = Vec::new();
390
391        // Initialize random number generator
392        let mut rng = if let Some(seed) = self.config.random_seed {
393            StdRng::seed_from_u64(seed)
394        } else {
395            StdRng::from_rng(&mut scirs2_core::random::thread_rng())
396        };
397
398        // Stochastic variational inference
399        let mut current_lr = self.config.learning_rate;
400
401        for epoch in 0..self.config.max_epochs {
402            let epoch_elbo = self.run_epoch(x, y, &mut posterior, current_lr, &mut rng)?;
403            convergence_history.push(epoch_elbo);
404
405            if self.config.verbose && epoch % 10 == 0 {
406                println!("Epoch {}: ELBO = {:.6}", epoch, epoch_elbo);
407            }
408
409            // Check convergence
410            if epoch > 0 {
411                let prev_elbo = convergence_history[epoch - 1];
412                let elbo_change = (epoch_elbo - prev_elbo).abs();
413
414                if elbo_change < self.config.tolerance {
415                    if self.config.verbose {
416                        println!("Converged after {} epochs", epoch);
417                    }
418                    break;
419                }
420            }
421
422            // Update learning rate
423            current_lr = self.update_learning_rate(current_lr, epoch);
424        }
425
426        // Compute feature relevance for ARD
427        let feature_relevance = if self.config.prior_config.ard_config.is_some() {
428            Some(self.compute_feature_relevance(&posterior))
429        } else {
430            None
431        };
432
433        Ok(LargeScaleVariationalRegression {
434            config: self.config,
435            state: PhantomData,
436            posterior: Some(posterior),
437            convergence_history: Some(convergence_history),
438            feature_relevance,
439            n_features: Some(n_features),
440            intercept: None, // Simplified: not handling intercept in this implementation
441        })
442    }
443}
444
445impl LargeScaleVariationalRegression<Untrained> {
446    /// Run one epoch of stochastic variational inference
447    fn run_epoch(
448        &self,
449        x: &Array2<Float>,
450        y: &Array1<Float>,
451        posterior: &mut VariationalPosterior,
452        learning_rate: Float,
453        rng: &mut impl Rng,
454    ) -> Result<Float> {
455        let (n_samples, _n_features) = x.dim();
456        let batch_size = self.config.batch_size.min(n_samples);
457
458        let mut total_elbo = 0.0;
459        let mut n_batches = 0;
460
461        // Create mini-batches
462        let mut indices: Vec<usize> = (0..n_samples).collect();
463        indices.shuffle(rng);
464
465        for batch_indices in indices.chunks(batch_size) {
466            // Extract mini-batch
467            let batch_x = self.extract_batch_features(x, batch_indices);
468            let batch_y = self.extract_batch_targets(y, batch_indices);
469
470            // Compute natural gradients
471            let (elbo, gradients) = self.compute_natural_gradients(
472                &batch_x,
473                &batch_y,
474                posterior,
475                n_samples,
476                batch_indices.len(),
477            )?;
478
479            // Update variational parameters
480            self.update_variational_parameters(posterior, &gradients, learning_rate)?;
481
482            total_elbo += elbo;
483            n_batches += 1;
484        }
485
486        Ok(total_elbo / n_batches as Float)
487    }
488
489    /// Extract mini-batch features
490    fn extract_batch_features(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
491        let mut batch_x = Array2::zeros((indices.len(), x.ncols()));
492        for (i, &idx) in indices.iter().enumerate() {
493            batch_x.slice_mut(s![i, ..]).assign(&x.slice(s![idx, ..]));
494        }
495        batch_x
496    }
497
498    /// Extract mini-batch targets
499    fn extract_batch_targets(&self, y: &Array1<Float>, indices: &[usize]) -> Array1<Float> {
500        indices.iter().map(|&i| y[i]).collect::<Vec<_>>().into()
501    }
502
503    /// Compute natural gradients for variational parameters
504    fn compute_natural_gradients(
505        &self,
506        batch_x: &Array2<Float>,
507        batch_y: &Array1<Float>,
508        posterior: &VariationalPosterior,
509        total_samples: usize,
510        batch_size: usize,
511    ) -> Result<(Float, VariationalGradients)> {
512        let scale_factor = total_samples as Float / batch_size as Float;
513
514        // Compute expected log likelihood
515        let expected_ll = self.compute_expected_log_likelihood(batch_x, batch_y, posterior)?;
516
517        // Compute KL divergence
518        let kl_div = self.compute_kl_divergence(posterior)?;
519
520        // ELBO = E[log p(y|X,w)] - KL[q(w)||p(w)]
521        let elbo = scale_factor * expected_ll - kl_div;
522
523        // Compute gradients (simplified implementation)
524        let gradients = VariationalGradients {
525            weight_mean_grad: Array1::zeros(posterior.weight_mean.len()),
526            weight_precision_grad: Array2::zeros(posterior.weight_precision.dim()),
527            noise_precision_shape_grad: 0.0,
528            noise_precision_rate_grad: 0.0,
529        };
530
531        Ok((elbo, gradients))
532    }
533
534    /// Compute expected log likelihood
535    fn compute_expected_log_likelihood(
536        &self,
537        x: &Array2<Float>,
538        y: &Array1<Float>,
539        posterior: &VariationalPosterior,
540    ) -> Result<Float> {
541        let _n_samples = x.nrows();
542
543        // E[log p(y|X,w)] under q(w)
544        let pred_mean = x.dot(&posterior.weight_mean);
545        let residuals = y - &pred_mean;
546
547        // Simplified computation (should include trace term for full correctness)
548        let sum_squared_residuals = residuals.mapv(|r| r * r).sum();
549        let expected_noise_precision =
550            posterior.noise_precision_shape / posterior.noise_precision_rate;
551
552        let log_likelihood = -0.5 * expected_noise_precision * sum_squared_residuals;
553
554        Ok(log_likelihood)
555    }
556
557    /// Compute KL divergence KL[q(w,α,β)||p(w,α,β)]
558    fn compute_kl_divergence(&self, _posterior: &VariationalPosterior) -> Result<Float> {
559        // Simplified KL computation
560        // Full implementation would compute KL for multivariate normal and Gamma distributions
561        Ok(0.0)
562    }
563
564    /// Update variational parameters using natural gradients
565    fn update_variational_parameters(
566        &self,
567        posterior: &mut VariationalPosterior,
568        gradients: &VariationalGradients,
569        learning_rate: Float,
570    ) -> Result<()> {
571        // Natural gradient updates for exponential family
572        posterior.weight_mean =
573            &posterior.weight_mean + learning_rate * &gradients.weight_mean_grad;
574
575        // Ensure precision matrix remains positive definite
576        // (Simplified update - full implementation would use proper natural gradients)
577
578        Ok(())
579    }
580
581    /// Update learning rate according to decay schedule
582    fn update_learning_rate(&self, current_lr: Float, epoch: usize) -> Float {
583        match &self.config.learning_rate_decay {
584            LearningRateDecay::Constant => current_lr,
585            LearningRateDecay::Exponential { decay_rate } => {
586                current_lr * decay_rate.powf(epoch as Float)
587            }
588            LearningRateDecay::Step {
589                step_size,
590                step_factor,
591            } => current_lr * step_factor.powf((epoch / step_size) as Float),
592            LearningRateDecay::Polynomial { decay_rate, power } => {
593                current_lr * (1.0 + decay_rate * epoch as Float).powf(-power)
594            }
595            LearningRateDecay::CosineAnnealing { min_lr } => {
596                min_lr
597                    + 0.5
598                        * (current_lr - min_lr)
599                        * (1.0
600                            + (std::f64::consts::PI * epoch as Float
601                                / self.config.max_epochs as Float)
602                                .cos())
603            }
604        }
605    }
606
607    /// Compute feature relevance scores for ARD
608    fn compute_feature_relevance(&self, posterior: &VariationalPosterior) -> Array1<Float> {
609        // Feature relevance based on posterior precision
610        posterior.weight_precision_shape.clone() / &posterior.weight_precision_rate
611    }
612}
613
614/// Gradients for variational parameters
615#[derive(Debug, Clone)]
616struct VariationalGradients {
617    weight_mean_grad: Array1<Float>,
618    weight_precision_grad: Array2<Float>,
619    noise_precision_shape_grad: Float,
620    noise_precision_rate_grad: Float,
621}
622
623impl Predict<Array2<Float>, Array1<Float>> for LargeScaleVariationalRegression<Trained> {
624    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
625        let (pred_mean, _) = self.predict_with_uncertainty(x)?;
626        Ok(pred_mean)
627    }
628}
629
630impl Estimator for LargeScaleVariationalRegression<Untrained> {
631    type Config = LargeScaleVariationalConfig;
632    type Error = SklearsError;
633    type Float = Float;
634
635    fn config(&self) -> &LargeScaleVariationalConfig {
636        &self.config
637    }
638}
639
640impl Estimator for LargeScaleVariationalRegression<Trained> {
641    type Config = LargeScaleVariationalConfig;
642    type Error = SklearsError;
643    type Float = Float;
644    fn config(&self) -> &LargeScaleVariationalConfig {
645        &self.config
646    }
647}
648
649#[allow(non_snake_case)]
650#[cfg(test)]
651mod tests {
652    use super::*;
653    use scirs2_core::ndarray::Array;
654
655    #[test]
656    fn test_large_scale_variational_config() {
657        let config = LargeScaleVariationalConfig::default();
658        assert_eq!(config.max_epochs, 100);
659        assert_eq!(config.batch_size, 256);
660        assert_eq!(config.learning_rate, 0.01);
661        assert_eq!(config.n_mc_samples, 10);
662        assert!(config.use_natural_gradients);
663    }
664
665    #[test]
666    fn test_variational_posterior_creation() {
667        let prior_config = PriorConfiguration::default();
668        let posterior = VariationalPosterior::new(5, &prior_config);
669
670        assert_eq!(posterior.weight_mean.len(), 5);
671        assert_eq!(posterior.weight_covariance.dim(), (5, 5));
672        assert_eq!(posterior.weight_precision_shape.len(), 5);
673    }
674
675    #[test]
676    fn test_learning_rate_decay() {
677        let config = LargeScaleVariationalConfig {
678            learning_rate: 0.1,
679            learning_rate_decay: LearningRateDecay::Exponential { decay_rate: 0.9 },
680            ..Default::default()
681        };
682
683        let model = LargeScaleVariationalRegression::new().with_config(config);
684
685        let lr_epoch_0 = model.update_learning_rate(0.1, 0);
686        let lr_epoch_1 = model.update_learning_rate(0.1, 1);
687
688        assert_eq!(lr_epoch_0, 0.1);
689        assert!((lr_epoch_1 - 0.09).abs() < 1e-10);
690    }
691
692    #[test]
693    fn test_ard_configuration() {
694        let model = LargeScaleVariationalRegression::new()
695            .enable_ard(1e-6)
696            .batch_size(128)
697            .learning_rate(0.005);
698
699        assert!(model.config.prior_config.ard_config.is_some());
700        assert_eq!(model.config.batch_size, 128);
701        assert_eq!(model.config.learning_rate, 0.005);
702    }
703
704    #[test]
705    fn test_batch_extraction() {
706        let X = Array::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
707            .expect("valid array shape");
708        let y = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
709
710        let model = LargeScaleVariationalRegression::new();
711        let indices = [0, 2];
712
713        let batch_x = model.extract_batch_features(&X, &indices);
714        let batch_y = model.extract_batch_targets(&y, &indices);
715
716        assert_eq!(batch_x.dim(), (2, 2));
717        assert_eq!(batch_y.len(), 2);
718        assert_eq!(batch_x[[0, 0]], 1.0);
719        assert_eq!(batch_x[[1, 0]], 5.0);
720        assert_eq!(batch_y[0], 1.0);
721        assert_eq!(batch_y[1], 3.0);
722    }
723
724    #[test]
725    fn test_model_creation() {
726        let model = LargeScaleVariationalRegression::new()
727            .batch_size(64)
728            .learning_rate(0.001)
729            .memory_limit_gb(2.0);
730
731        assert_eq!(model.config.batch_size, 64);
732        assert_eq!(model.config.learning_rate, 0.001);
733        assert_eq!(model.config.memory_limit_gb, Some(2.0));
734    }
735}