sklears_kernel_approximation/
memory_efficient.rs

1//! Memory-efficient kernel approximation methods
2//!
3//! This module provides memory-efficient implementations of kernel approximation methods
4//! that can handle large datasets that don't fit entirely in memory.
5
6use crate::nystroem::Kernel;
7use crate::{Nystroem, RBFSampler, SamplingStrategy};
8use rayon::prelude::*;
9use scirs2_core::ndarray::{s, Array1, Array2};
10use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
11use scirs2_core::random::rngs::StdRng;
12use scirs2_core::random::Rng;
13use scirs2_core::random::{thread_rng, SeedableRng};
14use sklears_core::{
15    error::{Result, SklearsError},
16    traits::{Fit, Trained, Transform},
17};
18use std::sync::{Arc, Mutex};
19
/// Configuration for memory-efficient operations
///
/// Shared by the chunked RBF sampler and the Nyström variants in this
/// module; controls chunking, parallelism and (optional) disk caching.
#[derive(Debug, Clone)]
pub struct MemoryConfig {
    /// Maximum memory usage in bytes
    pub max_memory_bytes: usize,
    /// Chunk size for processing (rows per chunk)
    pub chunk_size: usize,
    /// Number of parallel workers
    /// NOTE(review): not read by the chunked code paths in this module
    /// (rayon uses its global pool) — confirm intended use.
    pub n_workers: usize,
    /// Enable disk caching for intermediate results
    /// NOTE(review): no code in this module reads this flag yet — confirm.
    pub enable_disk_cache: bool,
    /// Temporary directory for disk cache
    pub temp_dir: String,
}
35
36impl Default for MemoryConfig {
37    fn default() -> Self {
38        Self {
39            max_memory_bytes: 1024 * 1024 * 1024, // 1GB
40            chunk_size: 10000,
41            n_workers: num_cpus::get(),
42            enable_disk_cache: false,
43            temp_dir: "/tmp".to_string(),
44        }
45    }
46}
47
/// Memory-efficient RBF sampler with chunked processing
///
/// Builder-style front end: configure with `gamma`, `config` and
/// `random_seed`, then call `fit` to draw the random Fourier features.
#[derive(Debug, Clone)]
pub struct MemoryEfficientRBFSampler {
    /// Number of random Fourier features to generate.
    n_components: usize,
    /// RBF kernel bandwidth parameter.
    gamma: f64,
    /// Chunking/parallelism configuration.
    config: MemoryConfig,
    /// Optional seed for reproducible weight generation in `fit`.
    random_seed: Option<u64>,
}
57
58impl MemoryEfficientRBFSampler {
59    /// Create a new memory-efficient RBF sampler
60    pub fn new(n_components: usize) -> Self {
61        Self {
62            n_components,
63            gamma: 1.0,
64            config: MemoryConfig::default(),
65            random_seed: None,
66        }
67    }
68
69    /// Set gamma parameter
70    pub fn gamma(mut self, gamma: f64) -> Self {
71        self.gamma = gamma;
72        self
73    }
74
75    /// Set memory configuration
76    pub fn config(mut self, config: MemoryConfig) -> Self {
77        self.config = config;
78        self
79    }
80
81    /// Set random seed
82    pub fn random_seed(mut self, seed: u64) -> Self {
83        self.random_seed = Some(seed);
84        self
85    }
86
87    /// Process data in chunks
88    pub fn transform_chunked(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
89        let n_samples = x.nrows();
90        let _n_features = x.ncols();
91        let chunk_size = self.config.chunk_size.min(n_samples);
92
93        // Initialize output array
94        let mut output = Array2::zeros((n_samples, self.n_components));
95
96        // Create RBF sampler for consistent random features
97        let rbf_sampler = RBFSampler::new(self.n_components).gamma(self.gamma);
98        let fitted_sampler = rbf_sampler.fit(x, &())?;
99
100        // Process in chunks
101        for chunk_start in (0..n_samples).step_by(chunk_size) {
102            let chunk_end = (chunk_start + chunk_size).min(n_samples);
103            let chunk = x.slice(s![chunk_start..chunk_end, ..]);
104
105            // Transform chunk
106            let chunk_transformed = fitted_sampler.transform(&chunk.to_owned())?;
107
108            // Store result
109            output
110                .slice_mut(s![chunk_start..chunk_end, ..])
111                .assign(&chunk_transformed);
112        }
113
114        Ok(output)
115    }
116
117    /// Parallel chunked processing
118    pub fn transform_chunked_parallel(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
119        let n_samples = x.nrows();
120        let chunk_size = self.config.chunk_size.min(n_samples);
121
122        // Create RBF sampler for consistent random features
123        let rbf_sampler = RBFSampler::new(self.n_components).gamma(self.gamma);
124        let fitted_sampler = Arc::new(rbf_sampler.fit(x, &())?);
125
126        // Create chunks
127        let chunks: Vec<_> = (0..n_samples)
128            .step_by(chunk_size)
129            .map(|start| {
130                let end = (start + chunk_size).min(n_samples);
131                (start, end)
132            })
133            .collect();
134
135        // Process chunks in parallel
136        let results: Result<Vec<_>> = chunks
137            .par_iter()
138            .map(|&(start, end)| {
139                let chunk = x.slice(s![start..end, ..]).to_owned();
140                fitted_sampler
141                    .transform(&chunk)
142                    .map(|result| (start, result))
143            })
144            .collect();
145
146        let results = results?;
147
148        // Combine results
149        let mut output = Array2::zeros((n_samples, self.n_components));
150        for (start, chunk_result) in results {
151            let end = start + chunk_result.nrows();
152            output.slice_mut(s![start..end, ..]).assign(&chunk_result);
153        }
154
155        Ok(output)
156    }
157}
158
/// Fitted memory-efficient RBF sampler
///
/// Holds the random Fourier feature parameters drawn during `fit`; the
/// `Transform` impl maps new data into the approximate feature space.
pub struct FittedMemoryEfficientRBFSampler {
    /// Projection matrix W, shape (n_components, n_features).
    random_weights: Array2<f64>,
    /// Phase offsets b drawn uniformly from [0, 2*pi), length n_components.
    random_offset: Array1<f64>,
    /// Bandwidth used at fit time (kept for reference; the transform
    /// itself only uses the sampled weights and offsets).
    gamma: f64,
    /// Chunking configuration carried over from the unfitted sampler.
    config: MemoryConfig,
}
166
167impl Fit<Array2<f64>, ()> for MemoryEfficientRBFSampler {
168    type Fitted = FittedMemoryEfficientRBFSampler;
169
170    fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
171        let n_features = x.ncols();
172
173        let mut rng = if let Some(seed) = self.random_seed {
174            StdRng::seed_from_u64(seed)
175        } else {
176            StdRng::from_seed(thread_rng().gen())
177        };
178
179        // Generate random weights and offsets
180        let random_weights = Array2::from_shape_fn((self.n_components, n_features), |_| {
181            rng.sample(RandNormal::new(0.0, (2.0 * self.gamma).sqrt()).unwrap())
182        });
183
184        let random_offset = Array1::from_shape_fn(self.n_components, |_| {
185            rng.sample(RandUniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap())
186        });
187
188        Ok(FittedMemoryEfficientRBFSampler {
189            random_weights,
190            random_offset,
191            gamma: self.gamma,
192            config: self.config.clone(),
193        })
194    }
195}
196
197impl Transform<Array2<f64>, Array2<f64>> for FittedMemoryEfficientRBFSampler {
198    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
199        let n_samples = x.nrows();
200        let chunk_size = self.config.chunk_size.min(n_samples);
201
202        if n_samples <= chunk_size {
203            // Small dataset, process normally
204            self.transform_small(x)
205        } else {
206            // Large dataset, use chunked processing
207            self.transform_chunked(x)
208        }
209    }
210}
211
212impl FittedMemoryEfficientRBFSampler {
213    fn transform_small(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
214        let projection = x.dot(&self.random_weights.t());
215        let scaled_projection = projection + &self.random_offset;
216
217        let normalization = (2.0 / self.random_weights.nrows() as f64).sqrt();
218        Ok(scaled_projection.mapv(|v| v.cos() * normalization))
219    }
220
221    fn transform_chunked(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
222        let n_samples = x.nrows();
223        let chunk_size = self.config.chunk_size;
224        let mut output = Array2::zeros((n_samples, self.random_weights.nrows()));
225
226        for chunk_start in (0..n_samples).step_by(chunk_size) {
227            let chunk_end = (chunk_start + chunk_size).min(n_samples);
228            let chunk = x.slice(s![chunk_start..chunk_end, ..]);
229
230            let chunk_transformed = self.transform_small(&chunk.to_owned())?;
231            output
232                .slice_mut(s![chunk_start..chunk_end, ..])
233                .assign(&chunk_transformed);
234        }
235
236        Ok(output)
237    }
238}
239
/// Memory-efficient Nyström approximation
///
/// Builder-style front end around `Nystroem` that adds chunked transforms
/// and incremental (out-of-core) fitting.
#[derive(Debug, Clone)]
pub struct MemoryEfficientNystroem {
    /// Number of landmark components.
    n_components: usize,
    /// Kernel name: "rbf", "linear" or "polynomial"; anything else falls
    /// back to RBF with gamma = 1.0 at fit time.
    kernel: String,
    /// Optional kernel bandwidth/scale; defaults to 1.0 when unset.
    gamma: Option<f64>,
    /// Polynomial kernel degree; defaults to 3 when unset.
    degree: Option<i32>,
    /// Polynomial kernel independent term; defaults to 1.0 when unset.
    coef0: Option<f64>,
    /// Landmark sampling strategy.
    sampling: SamplingStrategy,
    /// Chunking configuration.
    config: MemoryConfig,
    /// NOTE(review): never read anywhere in this module — confirm whether
    /// the seed should be forwarded to the underlying `Nystroem`.
    random_seed: Option<u64>,
}
253
254impl MemoryEfficientNystroem {
255    /// Create a new memory-efficient Nyström approximation
256    pub fn new(n_components: usize) -> Self {
257        Self {
258            n_components,
259            kernel: "rbf".to_string(),
260            gamma: None,
261            degree: None,
262            coef0: None,
263            sampling: SamplingStrategy::Random,
264            config: MemoryConfig::default(),
265            random_seed: None,
266        }
267    }
268
269    /// Set kernel type
270    pub fn kernel(mut self, kernel: &str) -> Self {
271        self.kernel = kernel.to_string();
272        self
273    }
274
275    /// Set gamma parameter
276    pub fn gamma(mut self, gamma: f64) -> Self {
277        self.gamma = Some(gamma);
278        self
279    }
280
281    /// Set sampling strategy
282    pub fn sampling(mut self, sampling: SamplingStrategy) -> Self {
283        self.sampling = sampling;
284        self
285    }
286
287    /// Set memory configuration
288    pub fn config(mut self, config: MemoryConfig) -> Self {
289        self.config = config;
290        self
291    }
292
293    /// Out-of-core training for large datasets
294    pub fn fit_incremental(
295        &self,
296        x_chunks: Vec<Array2<f64>>,
297    ) -> Result<FittedMemoryEfficientNystroem> {
298        // Collect representative samples from all chunks
299        let mut representative_samples = Vec::new();
300        let samples_per_chunk = self.n_components / x_chunks.len().max(1);
301
302        for chunk in &x_chunks {
303            let n_samples = chunk.nrows().min(samples_per_chunk);
304            if n_samples > 0 {
305                let indices: Vec<usize> = (0..chunk.nrows()).collect();
306                let selected_indices = &indices[..n_samples];
307
308                for &idx in selected_indices {
309                    representative_samples.push(chunk.row(idx).to_owned());
310                }
311            }
312        }
313
314        if representative_samples.is_empty() {
315            return Err(SklearsError::InvalidInput(
316                "No samples found in chunks".to_string(),
317            ));
318        }
319
320        // Create combined dataset from representative samples
321        let n_selected = representative_samples.len().min(self.n_components);
322        let n_features = representative_samples[0].len();
323        let mut combined_data = Array2::zeros((n_selected, n_features));
324
325        for (i, sample) in representative_samples.iter().take(n_selected).enumerate() {
326            combined_data.row_mut(i).assign(sample);
327        }
328
329        // Fit standard Nyström on representative samples
330        let kernel = match self.kernel.as_str() {
331            "rbf" => Kernel::Rbf {
332                gamma: self.gamma.unwrap_or(1.0),
333            },
334            "linear" => Kernel::Linear,
335            "polynomial" => Kernel::Polynomial {
336                gamma: self.gamma.unwrap_or(1.0),
337                degree: self.degree.unwrap_or(3) as u32,
338                coef0: self.coef0.unwrap_or(1.0),
339            },
340            _ => Kernel::Rbf { gamma: 1.0 }, // default
341        };
342        let nystroem = Nystroem::new(kernel, n_selected).sampling_strategy(self.sampling.clone());
343
344        let fitted_nystroem = nystroem.fit(&combined_data, &())?;
345
346        Ok(FittedMemoryEfficientNystroem {
347            fitted_nystroem,
348            config: self.config.clone(),
349        })
350    }
351}
352
/// Fitted memory-efficient Nyström approximation
///
/// Wraps a fitted `Nystroem` model together with the chunking
/// configuration used by the `Transform` impl.
pub struct FittedMemoryEfficientNystroem {
    /// The underlying fitted Nyström model.
    fitted_nystroem: crate::nystroem::Nystroem<Trained>,
    /// Chunking configuration carried over from the builder.
    config: MemoryConfig,
}
358
359impl Fit<Array2<f64>, ()> for MemoryEfficientNystroem {
360    type Fitted = FittedMemoryEfficientNystroem;
361
362    fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
363        let kernel = match self.kernel.as_str() {
364            "rbf" => Kernel::Rbf {
365                gamma: self.gamma.unwrap_or(1.0),
366            },
367            "linear" => Kernel::Linear,
368            "polynomial" => Kernel::Polynomial {
369                gamma: self.gamma.unwrap_or(1.0),
370                degree: self.degree.unwrap_or(3) as u32,
371                coef0: self.coef0.unwrap_or(1.0),
372            },
373            _ => Kernel::Rbf { gamma: 1.0 }, // default
374        };
375        let nystroem =
376            Nystroem::new(kernel, self.n_components).sampling_strategy(self.sampling.clone());
377
378        let fitted_nystroem = nystroem.fit(x, &())?;
379
380        Ok(FittedMemoryEfficientNystroem {
381            fitted_nystroem,
382            config: self.config.clone(),
383        })
384    }
385}
386
387impl Transform<Array2<f64>, Array2<f64>> for FittedMemoryEfficientNystroem {
388    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
389        let n_samples = x.nrows();
390        let chunk_size = self.config.chunk_size;
391
392        if n_samples <= chunk_size {
393            // Small dataset, process normally
394            self.fitted_nystroem.transform(x)
395        } else {
396            // Large dataset, use chunked processing
397            self.transform_chunked(x)
398        }
399    }
400}
401
402impl FittedMemoryEfficientNystroem {
403    fn transform_chunked(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
404        let n_samples = x.nrows();
405        let chunk_size = self.config.chunk_size;
406        let n_components = self
407            .fitted_nystroem
408            .transform(&x.slice(s![0..1, ..]).to_owned())?
409            .ncols();
410
411        let mut output = Array2::zeros((n_samples, n_components));
412
413        for chunk_start in (0..n_samples).step_by(chunk_size) {
414            let chunk_end = (chunk_start + chunk_size).min(n_samples);
415            let chunk = x.slice(s![chunk_start..chunk_end, ..]);
416
417            let chunk_transformed = self.fitted_nystroem.transform(&chunk.to_owned())?;
418            output
419                .slice_mut(s![chunk_start..chunk_end, ..])
420                .assign(&chunk_transformed);
421        }
422
423        Ok(output)
424    }
425}
426
/// Memory usage monitoring utilities
///
/// Tracks a running byte count against a fixed budget so callers can
/// decide whether further allocations should be allowed. The counter is
/// behind an `Arc<Mutex<_>>` and can be shared across threads.
#[derive(Debug)] // public types should be Debug-printable
pub struct MemoryMonitor {
    /// Hard upper bound on tracked memory, in bytes.
    max_memory_bytes: usize,
    /// Currently tracked usage, in bytes.
    current_usage: Arc<Mutex<usize>>,
}
432
433impl MemoryMonitor {
434    /// Create a new memory monitor
435    pub fn new(max_memory_bytes: usize) -> Self {
436        Self {
437            max_memory_bytes,
438            current_usage: Arc::new(Mutex::new(0)),
439        }
440    }
441
442    /// Check if we can allocate more memory
443    pub fn can_allocate(&self, bytes: usize) -> bool {
444        let current = *self.current_usage.lock().unwrap();
445        current + bytes <= self.max_memory_bytes
446    }
447
448    /// Allocate memory (tracking purposes)
449    pub fn allocate(&self, bytes: usize) -> Result<()> {
450        let mut current = self.current_usage.lock().unwrap();
451        if *current + bytes > self.max_memory_bytes {
452            return Err(SklearsError::InvalidInput(format!(
453                "Memory limit exceeded: {} + {} > {}",
454                *current, bytes, self.max_memory_bytes
455            )));
456        }
457        *current += bytes;
458        Ok(())
459    }
460
461    /// Deallocate memory
462    pub fn deallocate(&self, bytes: usize) {
463        let mut current = self.current_usage.lock().unwrap();
464        *current = current.saturating_sub(bytes);
465    }
466
467    /// Get current memory usage
468    pub fn current_usage(&self) -> usize {
469        *self.current_usage.lock().unwrap()
470    }
471
472    /// Get memory usage percentage
473    pub fn usage_percentage(&self) -> f64 {
474        let current = *self.current_usage.lock().unwrap();
475        (current as f64 / self.max_memory_bytes as f64) * 100.0
476    }
477}
478
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array2;

    /// Shape check plus a determinism check: chunked transform of the
    /// full matrix must equal a direct transform of a row prefix.
    #[test]
    fn test_memory_efficient_rbf_sampler() {
        let x = Array2::from_shape_vec((100, 10), (0..1000).map(|i| i as f64).collect()).unwrap();

        let sampler = MemoryEfficientRBFSampler::new(50)
            .gamma(0.1)
            .config(MemoryConfig {
                chunk_size: 30,
                ..Default::default()
            });

        let fitted = sampler.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.shape(), &[100, 50]);

        // Test chunked processing gives same results as small dataset
        let small_x = x.slice(s![0..10, ..]).to_owned();
        let small_transformed = fitted.transform(&small_x).unwrap();
        let chunked_transformed = transformed.slice(s![0..10, ..]);

        assert_abs_diff_eq!(small_transformed, chunked_transformed, epsilon = 1e-10);
    }

    /// The parallel path should produce the right shape and a non-trivial
    /// (centered, non-constant) feature matrix.
    #[test]
    fn test_memory_efficient_rbf_chunked_parallel() {
        let x =
            Array2::from_shape_vec((200, 5), (0..1000).map(|i| i as f64 * 0.1).collect()).unwrap();

        let sampler = MemoryEfficientRBFSampler::new(30)
            .gamma(1.0)
            .config(MemoryConfig {
                chunk_size: 50,
                n_workers: 2,
                ..Default::default()
            });

        let result = sampler.transform_chunked_parallel(&x).unwrap();
        assert_eq!(result.shape(), &[200, 30]);

        // Verify output is reasonable (not all zeros, not all same values)
        let mean_val = result.mean().unwrap();
        let std_val = result.std(0.0);
        assert!(mean_val.abs() < 0.5); // Should be roughly centered
        assert!(std_val > 0.1); // Should have some variance
    }

    /// End-to-end fit + transform of the chunked Nyström wrapper.
    #[test]
    fn test_memory_efficient_nystroem() {
        let x =
            Array2::from_shape_vec((80, 6), (0..480).map(|i| i as f64 * 0.01).collect()).unwrap();

        let nystroem = MemoryEfficientNystroem::new(20)
            .kernel("rbf")
            .gamma(0.5)
            .config(MemoryConfig {
                chunk_size: 25,
                ..Default::default()
            });

        let fitted = nystroem.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.shape(), &[80, 20]);
    }

    /// Incremental (out-of-core) fitting from several chunks, then
    /// transforming one of the original chunks.
    #[test]
    fn test_memory_efficient_nystroem_incremental() {
        // Create multiple chunks
        let chunk1 =
            Array2::from_shape_vec((30, 4), (0..120).map(|i| i as f64 * 0.1).collect()).unwrap();
        let chunk2 =
            Array2::from_shape_vec((40, 4), (120..280).map(|i| i as f64 * 0.1).collect()).unwrap();
        let chunk3 =
            Array2::from_shape_vec((30, 4), (280..400).map(|i| i as f64 * 0.1).collect()).unwrap();

        let chunks = vec![chunk1, chunk2.clone(), chunk3];

        let nystroem = MemoryEfficientNystroem::new(15)
            .kernel("rbf")
            .config(MemoryConfig {
                chunk_size: 20,
                ..Default::default()
            });

        let fitted = nystroem.fit_incremental(chunks).unwrap();
        let transformed = fitted.transform(&chunk2).unwrap();

        assert_eq!(transformed.shape(), &[40, 15]);
    }

    /// Exercises the full allocate / reject / deallocate lifecycle of the
    /// byte-budget tracker.
    #[test]
    fn test_memory_monitor() {
        let monitor = MemoryMonitor::new(1000);

        assert!(monitor.can_allocate(500));
        assert!(monitor.allocate(500).is_ok());
        assert_eq!(monitor.current_usage(), 500);
        assert_eq!(monitor.usage_percentage(), 50.0);

        assert!(!monitor.can_allocate(600)); // Would exceed limit
        assert!(monitor.allocate(400).is_ok()); // Total = 900, still OK

        assert!(monitor.allocate(200).is_err()); // Would exceed limit

        monitor.deallocate(300);
        assert_eq!(monitor.current_usage(), 600);
        assert!(monitor.can_allocate(300));
    }

    /// Defaults and builder propagation of `MemoryConfig`.
    #[test]
    fn test_memory_config() {
        let config = MemoryConfig::default();
        assert_eq!(config.max_memory_bytes, 1024 * 1024 * 1024);
        assert_eq!(config.chunk_size, 10000);
        assert!(config.n_workers > 0);

        let custom_config = MemoryConfig {
            max_memory_bytes: 512 * 1024 * 1024,
            chunk_size: 5000,
            n_workers: 4,
            enable_disk_cache: true,
            temp_dir: "/custom/temp".to_string(),
        };

        let sampler = MemoryEfficientRBFSampler::new(50).config(custom_config.clone());
        assert_eq!(sampler.config.max_memory_bytes, 512 * 1024 * 1024);
        assert_eq!(sampler.config.chunk_size, 5000);
        assert_eq!(sampler.config.n_workers, 4);
        assert!(sampler.config.enable_disk_cache);
        assert_eq!(sampler.config.temp_dir, "/custom/temp");
    }

    /// Two samplers fitted with the same seed must produce identical
    /// feature maps.
    #[test]
    fn test_reproducibility() {
        let x =
            Array2::from_shape_vec((50, 8), (0..400).map(|i| i as f64 * 0.05).collect()).unwrap();

        let sampler1 = MemoryEfficientRBFSampler::new(20)
            .gamma(0.2)
            .random_seed(42);

        let sampler2 = MemoryEfficientRBFSampler::new(20)
            .gamma(0.2)
            .random_seed(42);

        let fitted1 = sampler1.fit(&x, &()).unwrap();
        let fitted2 = sampler2.fit(&x, &()).unwrap();

        let result1 = fitted1.transform(&x).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        assert_abs_diff_eq!(result1, result2, epsilon = 1e-10);
    }
}