1use crate::nystroem::Kernel;
7use crate::{Nystroem, RBFSampler, SamplingStrategy};
8use rayon::prelude::*;
9use scirs2_core::ndarray::{s, Array1, Array2};
10use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
11use scirs2_core::random::rngs::StdRng;
12use scirs2_core::random::RngExt;
13use scirs2_core::random::{thread_rng, SeedableRng};
14use sklears_core::{
15 error::{Result, SklearsError},
16 traits::{Fit, Trained, Transform},
17};
18use std::sync::{Arc, Mutex};
19
/// Resource settings shared by the memory-efficient approximators in this module.
///
/// Plain data with public fields; build custom values with struct-update syntax over
/// [`MemoryConfig::default`], e.g. `MemoryConfig { chunk_size: 500, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryConfig {
    /// Memory budget in bytes.
    /// NOTE(review): not read by any transform visible in this module — confirm before relying on it.
    pub max_memory_bytes: usize,
    /// Maximum number of rows processed per step by the chunked transforms.
    pub chunk_size: usize,
    /// Requested worker-thread count for parallel processing.
    pub n_workers: usize,
    /// Whether intermediate results may be cached on disk.
    /// NOTE(review): not read by any code visible in this module.
    pub enable_disk_cache: bool,
    /// Directory intended for on-disk cache files.
    /// NOTE(review): not read by any code visible in this module.
    pub temp_dir: String,
}
35
36impl Default for MemoryConfig {
37 fn default() -> Self {
38 Self {
39 max_memory_bytes: 1024 * 1024 * 1024, chunk_size: 10000,
41 n_workers: num_cpus::get(),
42 enable_disk_cache: false,
43 temp_dir: "/tmp".to_string(),
44 }
45 }
46}
47
/// Builder for a random-Fourier-feature approximation of the RBF kernel whose
/// transforms process the input in row chunks of at most `config.chunk_size` rows.
#[derive(Debug, Clone)]
pub struct MemoryEfficientRBFSampler {
    /// Number of random features in the output.
    n_components: usize,
    /// RBF coefficient; `fit` draws projection weights from N(0, 2 * gamma).
    gamma: f64,
    /// Chunking and worker settings.
    config: MemoryConfig,
    /// Seed for the projection RNG used by `fit`; `None` seeds from the thread RNG.
    random_seed: Option<u64>,
}
57
58impl MemoryEfficientRBFSampler {
59 pub fn new(n_components: usize) -> Self {
61 Self {
62 n_components,
63 gamma: 1.0,
64 config: MemoryConfig::default(),
65 random_seed: None,
66 }
67 }
68
69 pub fn gamma(mut self, gamma: f64) -> Self {
71 self.gamma = gamma;
72 self
73 }
74
75 pub fn config(mut self, config: MemoryConfig) -> Self {
77 self.config = config;
78 self
79 }
80
81 pub fn random_seed(mut self, seed: u64) -> Self {
83 self.random_seed = Some(seed);
84 self
85 }
86
87 pub fn transform_chunked(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
89 let n_samples = x.nrows();
90 let _n_features = x.ncols();
91 let chunk_size = self.config.chunk_size.min(n_samples);
92
93 let mut output = Array2::zeros((n_samples, self.n_components));
95
96 let rbf_sampler = RBFSampler::new(self.n_components).gamma(self.gamma);
98 let fitted_sampler = rbf_sampler.fit(x, &())?;
99
100 for chunk_start in (0..n_samples).step_by(chunk_size) {
102 let chunk_end = (chunk_start + chunk_size).min(n_samples);
103 let chunk = x.slice(s![chunk_start..chunk_end, ..]);
104
105 let chunk_transformed = fitted_sampler.transform(&chunk.to_owned())?;
107
108 output
110 .slice_mut(s![chunk_start..chunk_end, ..])
111 .assign(&chunk_transformed);
112 }
113
114 Ok(output)
115 }
116
117 pub fn transform_chunked_parallel(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
119 let n_samples = x.nrows();
120 let chunk_size = self.config.chunk_size.min(n_samples);
121
122 let rbf_sampler = RBFSampler::new(self.n_components).gamma(self.gamma);
124 let fitted_sampler = Arc::new(rbf_sampler.fit(x, &())?);
125
126 let chunks: Vec<_> = (0..n_samples)
128 .step_by(chunk_size)
129 .map(|start| {
130 let end = (start + chunk_size).min(n_samples);
131 (start, end)
132 })
133 .collect();
134
135 let results: Result<Vec<_>> = chunks
137 .par_iter()
138 .map(|&(start, end)| {
139 let chunk = x.slice(s![start..end, ..]).to_owned();
140 fitted_sampler
141 .transform(&chunk)
142 .map(|result| (start, result))
143 })
144 .collect();
145
146 let results = results?;
147
148 let mut output = Array2::zeros((n_samples, self.n_components));
150 for (start, chunk_result) in results {
151 let end = start + chunk_result.nrows();
152 output.slice_mut(s![start..end, ..]).assign(&chunk_result);
153 }
154
155 Ok(output)
156 }
157}
158
159pub struct FittedMemoryEfficientRBFSampler {
161 random_weights: Array2<f64>,
162 random_offset: Array1<f64>,
163 gamma: f64,
164 config: MemoryConfig,
165}
166
167impl Fit<Array2<f64>, ()> for MemoryEfficientRBFSampler {
168 type Fitted = FittedMemoryEfficientRBFSampler;
169
170 fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
171 let n_features = x.ncols();
172
173 let mut rng = if let Some(seed) = self.random_seed {
174 StdRng::seed_from_u64(seed)
175 } else {
176 StdRng::from_seed(thread_rng().random())
177 };
178
179 let random_weights = Array2::from_shape_fn((self.n_components, n_features), |_| {
181 rng.sample(
182 RandNormal::new(0.0, (2.0 * self.gamma).sqrt()).expect("operation should succeed"),
183 )
184 });
185
186 let random_offset = Array1::from_shape_fn(self.n_components, |_| {
187 rng.sample(
188 RandUniform::new(0.0, 2.0 * std::f64::consts::PI)
189 .expect("operation should succeed"),
190 )
191 });
192
193 Ok(FittedMemoryEfficientRBFSampler {
194 random_weights,
195 random_offset,
196 gamma: self.gamma,
197 config: self.config.clone(),
198 })
199 }
200}
201
202impl Transform<Array2<f64>, Array2<f64>> for FittedMemoryEfficientRBFSampler {
203 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
204 let n_samples = x.nrows();
205 let chunk_size = self.config.chunk_size.min(n_samples);
206
207 if n_samples <= chunk_size {
208 self.transform_small(x)
210 } else {
211 self.transform_chunked(x)
213 }
214 }
215}
216
217impl FittedMemoryEfficientRBFSampler {
218 fn transform_small(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
219 let projection = x.dot(&self.random_weights.t());
220 let scaled_projection = projection + &self.random_offset;
221
222 let normalization = (2.0 / self.random_weights.nrows() as f64).sqrt();
223 Ok(scaled_projection.mapv(|v| v.cos() * normalization))
224 }
225
226 fn transform_chunked(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
227 let n_samples = x.nrows();
228 let chunk_size = self.config.chunk_size;
229 let mut output = Array2::zeros((n_samples, self.random_weights.nrows()));
230
231 for chunk_start in (0..n_samples).step_by(chunk_size) {
232 let chunk_end = (chunk_start + chunk_size).min(n_samples);
233 let chunk = x.slice(s![chunk_start..chunk_end, ..]);
234
235 let chunk_transformed = self.transform_small(&chunk.to_owned())?;
236 output
237 .slice_mut(s![chunk_start..chunk_end, ..])
238 .assign(&chunk_transformed);
239 }
240
241 Ok(output)
242 }
243}
244
/// Builder for a Nyström kernel approximation that can be fitted on a list of
/// chunks (`fit_incremental`) and transforms large inputs in row chunks.
#[derive(Debug, Clone)]
pub struct MemoryEfficientNystroem {
    /// Number of landmark samples / output components requested.
    n_components: usize,
    /// Kernel name matched at fit time: "rbf", "linear" or "polynomial".
    kernel: String,
    /// Kernel gamma; `None` falls back to 1.0 at fit time.
    gamma: Option<f64>,
    /// Polynomial degree; `None` falls back to 3 at fit time.
    degree: Option<i32>,
    /// Polynomial independent term; `None` falls back to 1.0 at fit time.
    coef0: Option<f64>,
    /// Landmark sampling strategy forwarded to `Nystroem::sampling_strategy`.
    sampling: SamplingStrategy,
    /// Chunking settings copied into the fitted model.
    config: MemoryConfig,
    /// NOTE(review): never read by code visible in this module.
    random_seed: Option<u64>,
}
258
259impl MemoryEfficientNystroem {
260 pub fn new(n_components: usize) -> Self {
262 Self {
263 n_components,
264 kernel: "rbf".to_string(),
265 gamma: None,
266 degree: None,
267 coef0: None,
268 sampling: SamplingStrategy::Random,
269 config: MemoryConfig::default(),
270 random_seed: None,
271 }
272 }
273
274 pub fn kernel(mut self, kernel: &str) -> Self {
276 self.kernel = kernel.to_string();
277 self
278 }
279
280 pub fn gamma(mut self, gamma: f64) -> Self {
282 self.gamma = Some(gamma);
283 self
284 }
285
286 pub fn sampling(mut self, sampling: SamplingStrategy) -> Self {
288 self.sampling = sampling;
289 self
290 }
291
292 pub fn config(mut self, config: MemoryConfig) -> Self {
294 self.config = config;
295 self
296 }
297
298 pub fn fit_incremental(
300 &self,
301 x_chunks: Vec<Array2<f64>>,
302 ) -> Result<FittedMemoryEfficientNystroem> {
303 let mut representative_samples = Vec::new();
305 let samples_per_chunk = self.n_components / x_chunks.len().max(1);
306
307 for chunk in &x_chunks {
308 let n_samples = chunk.nrows().min(samples_per_chunk);
309 if n_samples > 0 {
310 let indices: Vec<usize> = (0..chunk.nrows()).collect();
311 let selected_indices = &indices[..n_samples];
312
313 for &idx in selected_indices {
314 representative_samples.push(chunk.row(idx).to_owned());
315 }
316 }
317 }
318
319 if representative_samples.is_empty() {
320 return Err(SklearsError::InvalidInput(
321 "No samples found in chunks".to_string(),
322 ));
323 }
324
325 let n_selected = representative_samples.len().min(self.n_components);
327 let n_features = representative_samples[0].len();
328 let mut combined_data = Array2::zeros((n_selected, n_features));
329
330 for (i, sample) in representative_samples.iter().take(n_selected).enumerate() {
331 combined_data.row_mut(i).assign(sample);
332 }
333
334 let kernel = match self.kernel.as_str() {
336 "rbf" => Kernel::Rbf {
337 gamma: self.gamma.unwrap_or(1.0),
338 },
339 "linear" => Kernel::Linear,
340 "polynomial" => Kernel::Polynomial {
341 gamma: self.gamma.unwrap_or(1.0),
342 degree: self.degree.unwrap_or(3) as u32,
343 coef0: self.coef0.unwrap_or(1.0),
344 },
345 _ => Kernel::Rbf { gamma: 1.0 }, };
347 let nystroem = Nystroem::new(kernel, n_selected).sampling_strategy(self.sampling.clone());
348
349 let fitted_nystroem = nystroem.fit(&combined_data, &())?;
350
351 Ok(FittedMemoryEfficientNystroem {
352 fitted_nystroem,
353 config: self.config.clone(),
354 })
355 }
356}
357
/// Fitted Nyström approximation produced by [`MemoryEfficientNystroem`].
pub struct FittedMemoryEfficientNystroem {
    /// The trained Nyström model that performs the feature mapping.
    fitted_nystroem: crate::nystroem::Nystroem<Trained>,
    /// Configuration copied from the builder; its `chunk_size` drives chunked transforms.
    config: MemoryConfig,
}
363
364impl Fit<Array2<f64>, ()> for MemoryEfficientNystroem {
365 type Fitted = FittedMemoryEfficientNystroem;
366
367 fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
368 let kernel = match self.kernel.as_str() {
369 "rbf" => Kernel::Rbf {
370 gamma: self.gamma.unwrap_or(1.0),
371 },
372 "linear" => Kernel::Linear,
373 "polynomial" => Kernel::Polynomial {
374 gamma: self.gamma.unwrap_or(1.0),
375 degree: self.degree.unwrap_or(3) as u32,
376 coef0: self.coef0.unwrap_or(1.0),
377 },
378 _ => Kernel::Rbf { gamma: 1.0 }, };
380 let nystroem =
381 Nystroem::new(kernel, self.n_components).sampling_strategy(self.sampling.clone());
382
383 let fitted_nystroem = nystroem.fit(x, &())?;
384
385 Ok(FittedMemoryEfficientNystroem {
386 fitted_nystroem,
387 config: self.config.clone(),
388 })
389 }
390}
391
392impl Transform<Array2<f64>, Array2<f64>> for FittedMemoryEfficientNystroem {
393 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
394 let n_samples = x.nrows();
395 let chunk_size = self.config.chunk_size;
396
397 if n_samples <= chunk_size {
398 self.fitted_nystroem.transform(x)
400 } else {
401 self.transform_chunked(x)
403 }
404 }
405}
406
407impl FittedMemoryEfficientNystroem {
408 fn transform_chunked(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
409 let n_samples = x.nrows();
410 let chunk_size = self.config.chunk_size;
411 let n_components = self
412 .fitted_nystroem
413 .transform(&x.slice(s![0..1, ..]).to_owned())?
414 .ncols();
415
416 let mut output = Array2::zeros((n_samples, n_components));
417
418 for chunk_start in (0..n_samples).step_by(chunk_size) {
419 let chunk_end = (chunk_start + chunk_size).min(n_samples);
420 let chunk = x.slice(s![chunk_start..chunk_end, ..]);
421
422 let chunk_transformed = self.fitted_nystroem.transform(&chunk.to_owned())?;
423 output
424 .slice_mut(s![chunk_start..chunk_end, ..])
425 .assign(&chunk_transformed);
426 }
427
428 Ok(output)
429 }
430}
431
/// Byte-budget tracker whose running total is guarded by a mutex.
///
/// Callers report allocations and deallocations explicitly. The monitor keeps a
/// running total and rejects reports that would exceed `max_memory_bytes`.
/// NOTE(review): nothing here measures real memory usage; it trusts the reported sizes.
#[derive(Debug)]
pub struct MemoryMonitor {
    /// Budget in bytes.
    max_memory_bytes: usize,
    /// Running total of reported bytes.
    current_usage: Arc<Mutex<usize>>,
}
437
438impl MemoryMonitor {
439 pub fn new(max_memory_bytes: usize) -> Self {
441 Self {
442 max_memory_bytes,
443 current_usage: Arc::new(Mutex::new(0)),
444 }
445 }
446
447 pub fn can_allocate(&self, bytes: usize) -> bool {
449 let current = *self.current_usage.lock().expect("operation should succeed");
450 current + bytes <= self.max_memory_bytes
451 }
452
453 pub fn allocate(&self, bytes: usize) -> Result<()> {
455 let mut current = self.current_usage.lock().expect("operation should succeed");
456 if *current + bytes > self.max_memory_bytes {
457 return Err(SklearsError::InvalidInput(format!(
458 "Memory limit exceeded: {} + {} > {}",
459 *current, bytes, self.max_memory_bytes
460 )));
461 }
462 *current += bytes;
463 Ok(())
464 }
465
466 pub fn deallocate(&self, bytes: usize) {
468 let mut current = self.current_usage.lock().expect("operation should succeed");
469 *current = current.saturating_sub(bytes);
470 }
471
472 pub fn current_usage(&self) -> usize {
474 *self.current_usage.lock().expect("operation should succeed")
475 }
476
477 pub fn usage_percentage(&self) -> f64 {
479 let current = *self.current_usage.lock().expect("operation should succeed");
480 (current as f64 / self.max_memory_bytes as f64) * 100.0
481 }
482}
483
// Unit tests for the memory-efficient approximators and the memory monitor.
// NOTE(review): `non_snake_case` is allowed module-wide, but no identifier below
// appears to need it — consider removing the attribute.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array2;

    // The chunked path (100 rows with chunk_size 30) must agree with the single-pass
    // path (10 rows) on the same rows.
    #[test]
    fn test_memory_efficient_rbf_sampler() {
        let x = Array2::from_shape_vec((100, 10), (0..1000).map(|i| i as f64).collect())
            .expect("operation should succeed");

        let sampler = MemoryEfficientRBFSampler::new(50)
            .gamma(0.1)
            .config(MemoryConfig {
                chunk_size: 30,
                ..Default::default()
            });

        let fitted = sampler.fit(&x, &()).expect("operation should succeed");
        // 100 rows > chunk_size (30): dispatched to the chunked transform.
        let transformed = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(transformed.shape(), &[100, 50]);

        // 10 rows <= chunk_size: dispatched to the single-pass transform.
        let small_x = x.slice(s![0..10, ..]).to_owned();
        let small_transformed = fitted
            .transform(&small_x)
            .expect("operation should succeed");
        let chunked_transformed = transformed.slice(s![0..10, ..]);

        assert_abs_diff_eq!(small_transformed, chunked_transformed, epsilon = 1e-10);
    }

    // Parallel chunked transform keeps the output shape; the mean/std checks are
    // loose sanity bounds on the generated features.
    #[test]
    fn test_memory_efficient_rbf_chunked_parallel() {
        let x = Array2::from_shape_vec((200, 5), (0..1000).map(|i| i as f64 * 0.1).collect())
            .expect("operation should succeed");

        let sampler = MemoryEfficientRBFSampler::new(30)
            .gamma(1.0)
            .config(MemoryConfig {
                chunk_size: 50,
                n_workers: 2,
                ..Default::default()
            });

        let result = sampler
            .transform_chunked_parallel(&x)
            .expect("operation should succeed");
        assert_eq!(result.shape(), &[200, 30]);

        let mean_val = result.mean().expect("operation should succeed");
        let std_val = result.std(0.0);
        assert!(mean_val.abs() < 0.5); // features should be roughly centred
        assert!(std_val > 0.1); // ...and must not collapse to a constant
    }

    // Fit + chunked transform (80 rows, chunk_size 25) yields n_components columns.
    #[test]
    fn test_memory_efficient_nystroem() {
        let x = Array2::from_shape_vec((80, 6), (0..480).map(|i| i as f64 * 0.01).collect())
            .expect("operation should succeed");

        let nystroem = MemoryEfficientNystroem::new(20)
            .kernel("rbf")
            .gamma(0.5)
            .config(MemoryConfig {
                chunk_size: 25,
                ..Default::default()
            });

        let fitted = nystroem.fit(&x, &()).expect("operation should succeed");
        let transformed = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(transformed.shape(), &[80, 20]);
    }

    // Incremental fitting over three chunks selects 15 landmarks in total, and
    // transforming one 40-row chunk (chunk_size 20) gives 15 output columns.
    #[test]
    fn test_memory_efficient_nystroem_incremental() {
        let chunk1 = Array2::from_shape_vec((30, 4), (0..120).map(|i| i as f64 * 0.1).collect())
            .expect("operation should succeed");
        let chunk2 = Array2::from_shape_vec((40, 4), (120..280).map(|i| i as f64 * 0.1).collect())
            .expect("operation should succeed");
        let chunk3 = Array2::from_shape_vec((30, 4), (280..400).map(|i| i as f64 * 0.1).collect())
            .expect("operation should succeed");

        let chunks = vec![chunk1, chunk2.clone(), chunk3];

        let nystroem = MemoryEfficientNystroem::new(15)
            .kernel("rbf")
            .config(MemoryConfig {
                chunk_size: 20,
                ..Default::default()
            });

        let fitted = nystroem
            .fit_incremental(chunks)
            .expect("operation should succeed");
        let transformed = fitted.transform(&chunk2).expect("operation should succeed");

        assert_eq!(transformed.shape(), &[40, 15]);
    }

    // Budget accounting: allocations up to the limit succeed, beyond it fail,
    // and deallocation frees budget again.
    #[test]
    fn test_memory_monitor() {
        let monitor = MemoryMonitor::new(1000);

        assert!(monitor.can_allocate(500));
        assert!(monitor.allocate(500).is_ok());
        assert_eq!(monitor.current_usage(), 500);
        assert_eq!(monitor.usage_percentage(), 50.0);

        assert!(!monitor.can_allocate(600)); // 500 + 600 > 1000
        assert!(monitor.allocate(400).is_ok()); // reaches the limit exactly (1000)
        assert!(monitor.allocate(200).is_err()); // 1000 + 200 > 1000
        monitor.deallocate(300);
        assert_eq!(monitor.current_usage(), 600);
        assert!(monitor.can_allocate(300));
    }

    // Default values, and a custom config passed through the builder unchanged.
    #[test]
    fn test_memory_config() {
        let config = MemoryConfig::default();
        assert_eq!(config.max_memory_bytes, 1024 * 1024 * 1024);
        assert_eq!(config.chunk_size, 10000);
        assert!(config.n_workers > 0);

        let custom_config = MemoryConfig {
            max_memory_bytes: 512 * 1024 * 1024,
            chunk_size: 5000,
            n_workers: 4,
            enable_disk_cache: true,
            temp_dir: "/custom/temp".to_string(),
        };

        let sampler = MemoryEfficientRBFSampler::new(50).config(custom_config.clone());
        assert_eq!(sampler.config.max_memory_bytes, 512 * 1024 * 1024);
        assert_eq!(sampler.config.chunk_size, 5000);
        assert_eq!(sampler.config.n_workers, 4);
        assert!(sampler.config.enable_disk_cache);
        assert_eq!(sampler.config.temp_dir, "/custom/temp");
    }

    // Two samplers with the same seed must produce identical features.
    #[test]
    fn test_reproducibility() {
        let x = Array2::from_shape_vec((50, 8), (0..400).map(|i| i as f64 * 0.05).collect())
            .expect("operation should succeed");

        let sampler1 = MemoryEfficientRBFSampler::new(20)
            .gamma(0.2)
            .random_seed(42);

        let sampler2 = MemoryEfficientRBFSampler::new(20)
            .gamma(0.2)
            .random_seed(42);

        let fitted1 = sampler1.fit(&x, &()).expect("operation should succeed");
        let fitted2 = sampler2.fit(&x, &()).expect("operation should succeed");

        let result1 = fitted1.transform(&x).expect("operation should succeed");
        let result2 = fitted2.transform(&x).expect("operation should succeed");

        assert_abs_diff_eq!(result1, result2, epsilon = 1e-10);
    }
}