1use crate::nystroem::Kernel;
7use crate::{Nystroem, RBFSampler, SamplingStrategy};
8use rayon::prelude::*;
9use scirs2_core::ndarray::{s, Array1, Array2};
10use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
11use scirs2_core::random::rngs::StdRng;
12use scirs2_core::random::Rng;
13use scirs2_core::random::{thread_rng, SeedableRng};
14use sklears_core::{
15 error::{Result, SklearsError},
16 traits::{Fit, Trained, Transform},
17};
18use std::sync::{Arc, Mutex};
19
/// Tuning knobs shared by the memory-efficient approximators in this module.
///
/// The type is plain data, so it also derives `PartialEq`/`Eq` to let callers
/// and tests compare configurations directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryConfig {
    /// Memory budget, in bytes.
    pub max_memory_bytes: usize,
    /// Maximum number of input rows processed per chunk.
    pub chunk_size: usize,
    /// Requested number of workers.
    pub n_workers: usize,
    /// Flag requesting an on-disk cache.
    pub enable_disk_cache: bool,
    /// Directory for temporary files, stored as a plain string.
    pub temp_dir: String,
}
35
36impl Default for MemoryConfig {
37 fn default() -> Self {
38 Self {
39 max_memory_bytes: 1024 * 1024 * 1024, chunk_size: 10000,
41 n_workers: num_cpus::get(),
42 enable_disk_cache: false,
43 temp_dir: "/tmp".to_string(),
44 }
45 }
46}
47
48#[derive(Debug, Clone)]
50pub struct MemoryEfficientRBFSampler {
52 n_components: usize,
53 gamma: f64,
54 config: MemoryConfig,
55 random_seed: Option<u64>,
56}
57
58impl MemoryEfficientRBFSampler {
59 pub fn new(n_components: usize) -> Self {
61 Self {
62 n_components,
63 gamma: 1.0,
64 config: MemoryConfig::default(),
65 random_seed: None,
66 }
67 }
68
69 pub fn gamma(mut self, gamma: f64) -> Self {
71 self.gamma = gamma;
72 self
73 }
74
75 pub fn config(mut self, config: MemoryConfig) -> Self {
77 self.config = config;
78 self
79 }
80
81 pub fn random_seed(mut self, seed: u64) -> Self {
83 self.random_seed = Some(seed);
84 self
85 }
86
87 pub fn transform_chunked(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
89 let n_samples = x.nrows();
90 let _n_features = x.ncols();
91 let chunk_size = self.config.chunk_size.min(n_samples);
92
93 let mut output = Array2::zeros((n_samples, self.n_components));
95
96 let rbf_sampler = RBFSampler::new(self.n_components).gamma(self.gamma);
98 let fitted_sampler = rbf_sampler.fit(x, &())?;
99
100 for chunk_start in (0..n_samples).step_by(chunk_size) {
102 let chunk_end = (chunk_start + chunk_size).min(n_samples);
103 let chunk = x.slice(s![chunk_start..chunk_end, ..]);
104
105 let chunk_transformed = fitted_sampler.transform(&chunk.to_owned())?;
107
108 output
110 .slice_mut(s![chunk_start..chunk_end, ..])
111 .assign(&chunk_transformed);
112 }
113
114 Ok(output)
115 }
116
117 pub fn transform_chunked_parallel(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
119 let n_samples = x.nrows();
120 let chunk_size = self.config.chunk_size.min(n_samples);
121
122 let rbf_sampler = RBFSampler::new(self.n_components).gamma(self.gamma);
124 let fitted_sampler = Arc::new(rbf_sampler.fit(x, &())?);
125
126 let chunks: Vec<_> = (0..n_samples)
128 .step_by(chunk_size)
129 .map(|start| {
130 let end = (start + chunk_size).min(n_samples);
131 (start, end)
132 })
133 .collect();
134
135 let results: Result<Vec<_>> = chunks
137 .par_iter()
138 .map(|&(start, end)| {
139 let chunk = x.slice(s![start..end, ..]).to_owned();
140 fitted_sampler
141 .transform(&chunk)
142 .map(|result| (start, result))
143 })
144 .collect();
145
146 let results = results?;
147
148 let mut output = Array2::zeros((n_samples, self.n_components));
150 for (start, chunk_result) in results {
151 let end = start + chunk_result.nrows();
152 output.slice_mut(s![start..end, ..]).assign(&chunk_result);
153 }
154
155 Ok(output)
156 }
157}
158
159pub struct FittedMemoryEfficientRBFSampler {
161 random_weights: Array2<f64>,
162 random_offset: Array1<f64>,
163 gamma: f64,
164 config: MemoryConfig,
165}
166
167impl Fit<Array2<f64>, ()> for MemoryEfficientRBFSampler {
168 type Fitted = FittedMemoryEfficientRBFSampler;
169
170 fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
171 let n_features = x.ncols();
172
173 let mut rng = if let Some(seed) = self.random_seed {
174 StdRng::seed_from_u64(seed)
175 } else {
176 StdRng::from_seed(thread_rng().gen())
177 };
178
179 let random_weights = Array2::from_shape_fn((self.n_components, n_features), |_| {
181 rng.sample(RandNormal::new(0.0, (2.0 * self.gamma).sqrt()).unwrap())
182 });
183
184 let random_offset = Array1::from_shape_fn(self.n_components, |_| {
185 rng.sample(RandUniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap())
186 });
187
188 Ok(FittedMemoryEfficientRBFSampler {
189 random_weights,
190 random_offset,
191 gamma: self.gamma,
192 config: self.config.clone(),
193 })
194 }
195}
196
197impl Transform<Array2<f64>, Array2<f64>> for FittedMemoryEfficientRBFSampler {
198 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
199 let n_samples = x.nrows();
200 let chunk_size = self.config.chunk_size.min(n_samples);
201
202 if n_samples <= chunk_size {
203 self.transform_small(x)
205 } else {
206 self.transform_chunked(x)
208 }
209 }
210}
211
212impl FittedMemoryEfficientRBFSampler {
213 fn transform_small(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
214 let projection = x.dot(&self.random_weights.t());
215 let scaled_projection = projection + &self.random_offset;
216
217 let normalization = (2.0 / self.random_weights.nrows() as f64).sqrt();
218 Ok(scaled_projection.mapv(|v| v.cos() * normalization))
219 }
220
221 fn transform_chunked(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
222 let n_samples = x.nrows();
223 let chunk_size = self.config.chunk_size;
224 let mut output = Array2::zeros((n_samples, self.random_weights.nrows()));
225
226 for chunk_start in (0..n_samples).step_by(chunk_size) {
227 let chunk_end = (chunk_start + chunk_size).min(n_samples);
228 let chunk = x.slice(s![chunk_start..chunk_end, ..]);
229
230 let chunk_transformed = self.transform_small(&chunk.to_owned())?;
231 output
232 .slice_mut(s![chunk_start..chunk_end, ..])
233 .assign(&chunk_transformed);
234 }
235
236 Ok(output)
237 }
238}
239
240#[derive(Debug, Clone)]
242pub struct MemoryEfficientNystroem {
244 n_components: usize,
245 kernel: String,
246 gamma: Option<f64>,
247 degree: Option<i32>,
248 coef0: Option<f64>,
249 sampling: SamplingStrategy,
250 config: MemoryConfig,
251 random_seed: Option<u64>,
252}
253
254impl MemoryEfficientNystroem {
255 pub fn new(n_components: usize) -> Self {
257 Self {
258 n_components,
259 kernel: "rbf".to_string(),
260 gamma: None,
261 degree: None,
262 coef0: None,
263 sampling: SamplingStrategy::Random,
264 config: MemoryConfig::default(),
265 random_seed: None,
266 }
267 }
268
269 pub fn kernel(mut self, kernel: &str) -> Self {
271 self.kernel = kernel.to_string();
272 self
273 }
274
275 pub fn gamma(mut self, gamma: f64) -> Self {
277 self.gamma = Some(gamma);
278 self
279 }
280
281 pub fn sampling(mut self, sampling: SamplingStrategy) -> Self {
283 self.sampling = sampling;
284 self
285 }
286
287 pub fn config(mut self, config: MemoryConfig) -> Self {
289 self.config = config;
290 self
291 }
292
293 pub fn fit_incremental(
295 &self,
296 x_chunks: Vec<Array2<f64>>,
297 ) -> Result<FittedMemoryEfficientNystroem> {
298 let mut representative_samples = Vec::new();
300 let samples_per_chunk = self.n_components / x_chunks.len().max(1);
301
302 for chunk in &x_chunks {
303 let n_samples = chunk.nrows().min(samples_per_chunk);
304 if n_samples > 0 {
305 let indices: Vec<usize> = (0..chunk.nrows()).collect();
306 let selected_indices = &indices[..n_samples];
307
308 for &idx in selected_indices {
309 representative_samples.push(chunk.row(idx).to_owned());
310 }
311 }
312 }
313
314 if representative_samples.is_empty() {
315 return Err(SklearsError::InvalidInput(
316 "No samples found in chunks".to_string(),
317 ));
318 }
319
320 let n_selected = representative_samples.len().min(self.n_components);
322 let n_features = representative_samples[0].len();
323 let mut combined_data = Array2::zeros((n_selected, n_features));
324
325 for (i, sample) in representative_samples.iter().take(n_selected).enumerate() {
326 combined_data.row_mut(i).assign(sample);
327 }
328
329 let kernel = match self.kernel.as_str() {
331 "rbf" => Kernel::Rbf {
332 gamma: self.gamma.unwrap_or(1.0),
333 },
334 "linear" => Kernel::Linear,
335 "polynomial" => Kernel::Polynomial {
336 gamma: self.gamma.unwrap_or(1.0),
337 degree: self.degree.unwrap_or(3) as u32,
338 coef0: self.coef0.unwrap_or(1.0),
339 },
340 _ => Kernel::Rbf { gamma: 1.0 }, };
342 let nystroem = Nystroem::new(kernel, n_selected).sampling_strategy(self.sampling.clone());
343
344 let fitted_nystroem = nystroem.fit(&combined_data, &())?;
345
346 Ok(FittedMemoryEfficientNystroem {
347 fitted_nystroem,
348 config: self.config.clone(),
349 })
350 }
351}
352
353pub struct FittedMemoryEfficientNystroem {
355 fitted_nystroem: crate::nystroem::Nystroem<Trained>,
356 config: MemoryConfig,
357}
358
359impl Fit<Array2<f64>, ()> for MemoryEfficientNystroem {
360 type Fitted = FittedMemoryEfficientNystroem;
361
362 fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
363 let kernel = match self.kernel.as_str() {
364 "rbf" => Kernel::Rbf {
365 gamma: self.gamma.unwrap_or(1.0),
366 },
367 "linear" => Kernel::Linear,
368 "polynomial" => Kernel::Polynomial {
369 gamma: self.gamma.unwrap_or(1.0),
370 degree: self.degree.unwrap_or(3) as u32,
371 coef0: self.coef0.unwrap_or(1.0),
372 },
373 _ => Kernel::Rbf { gamma: 1.0 }, };
375 let nystroem =
376 Nystroem::new(kernel, self.n_components).sampling_strategy(self.sampling.clone());
377
378 let fitted_nystroem = nystroem.fit(x, &())?;
379
380 Ok(FittedMemoryEfficientNystroem {
381 fitted_nystroem,
382 config: self.config.clone(),
383 })
384 }
385}
386
387impl Transform<Array2<f64>, Array2<f64>> for FittedMemoryEfficientNystroem {
388 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
389 let n_samples = x.nrows();
390 let chunk_size = self.config.chunk_size;
391
392 if n_samples <= chunk_size {
393 self.fitted_nystroem.transform(x)
395 } else {
396 self.transform_chunked(x)
398 }
399 }
400}
401
402impl FittedMemoryEfficientNystroem {
403 fn transform_chunked(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
404 let n_samples = x.nrows();
405 let chunk_size = self.config.chunk_size;
406 let n_components = self
407 .fitted_nystroem
408 .transform(&x.slice(s![0..1, ..]).to_owned())?
409 .ncols();
410
411 let mut output = Array2::zeros((n_samples, n_components));
412
413 for chunk_start in (0..n_samples).step_by(chunk_size) {
414 let chunk_end = (chunk_start + chunk_size).min(n_samples);
415 let chunk = x.slice(s![chunk_start..chunk_end, ..]);
416
417 let chunk_transformed = self.fitted_nystroem.transform(&chunk.to_owned())?;
418 output
419 .slice_mut(s![chunk_start..chunk_end, ..])
420 .assign(&chunk_transformed);
421 }
422
423 Ok(output)
424 }
425}
426
/// Bookkeeping counter that tracks bytes accounted for against a fixed budget.
///
/// The counter sits behind `Arc<Mutex<_>>`, so cloning a monitor yields a
/// handle to the *same* counter; clones therefore observe each other's
/// updates.
#[derive(Debug, Clone)]
pub struct MemoryMonitor {
    /// Budget, in bytes, that the tracked usage may not exceed.
    max_memory_bytes: usize,
    /// Bytes currently accounted for; shared between clones.
    current_usage: Arc<Mutex<usize>>,
}
432
433impl MemoryMonitor {
434 pub fn new(max_memory_bytes: usize) -> Self {
436 Self {
437 max_memory_bytes,
438 current_usage: Arc::new(Mutex::new(0)),
439 }
440 }
441
442 pub fn can_allocate(&self, bytes: usize) -> bool {
444 let current = *self.current_usage.lock().unwrap();
445 current + bytes <= self.max_memory_bytes
446 }
447
448 pub fn allocate(&self, bytes: usize) -> Result<()> {
450 let mut current = self.current_usage.lock().unwrap();
451 if *current + bytes > self.max_memory_bytes {
452 return Err(SklearsError::InvalidInput(format!(
453 "Memory limit exceeded: {} + {} > {}",
454 *current, bytes, self.max_memory_bytes
455 )));
456 }
457 *current += bytes;
458 Ok(())
459 }
460
461 pub fn deallocate(&self, bytes: usize) {
463 let mut current = self.current_usage.lock().unwrap();
464 *current = current.saturating_sub(bytes);
465 }
466
467 pub fn current_usage(&self) -> usize {
469 *self.current_usage.lock().unwrap()
470 }
471
472 pub fn usage_percentage(&self) -> f64 {
474 let current = *self.current_usage.lock().unwrap();
475 (current as f64 / self.max_memory_bytes as f64) * 100.0
476 }
477}
478
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array2;

    /// Builds a `rows x cols` matrix filled in row-major order with
    /// `(first + k) * scale` for `k = 0, 1, 2, ...`.
    fn ramp(rows: usize, cols: usize, first: usize, scale: f64) -> Array2<f64> {
        let values: Vec<f64> = (first..first + rows * cols)
            .map(|i| i as f64 * scale)
            .collect();
        Array2::from_shape_vec((rows, cols), values).unwrap()
    }

    /// Chunked transformation yields the expected shape and agrees with the
    /// single-pass path on the leading rows.
    #[test]
    fn test_memory_efficient_rbf_sampler() {
        let x = ramp(100, 10, 0, 1.0);

        let config = MemoryConfig {
            chunk_size: 30,
            ..Default::default()
        };
        let sampler = MemoryEfficientRBFSampler::new(50).gamma(0.1).config(config);

        let fitted = sampler.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        assert_eq!(transformed.shape(), &[100, 50]);

        // Ten rows fit in one chunk, so this goes through the small path.
        let small_x = x.slice(s![0..10, ..]).to_owned();
        let small_transformed = fitted.transform(&small_x).unwrap();
        let chunked_transformed = transformed.slice(s![0..10, ..]);

        assert_abs_diff_eq!(small_transformed, chunked_transformed, epsilon = 1e-10);
    }

    /// The parallel chunked path produces the expected shape and values with
    /// a small mean and a non-trivial spread.
    #[test]
    fn test_memory_efficient_rbf_chunked_parallel() {
        let x = ramp(200, 5, 0, 0.1);

        let config = MemoryConfig {
            chunk_size: 50,
            n_workers: 2,
            ..Default::default()
        };
        let sampler = MemoryEfficientRBFSampler::new(30).gamma(1.0).config(config);

        let result = sampler.transform_chunked_parallel(&x).unwrap();
        assert_eq!(result.shape(), &[200, 30]);

        let mean_val = result.mean().unwrap();
        let std_val = result.std(0.0);
        assert!(mean_val.abs() < 0.5);
        assert!(std_val > 0.1);
    }

    /// Fitting on the full data and transforming in chunks gives one column
    /// per requested component.
    #[test]
    fn test_memory_efficient_nystroem() {
        let x = ramp(80, 6, 0, 0.01);

        let config = MemoryConfig {
            chunk_size: 25,
            ..Default::default()
        };
        let nystroem = MemoryEfficientNystroem::new(20)
            .kernel("rbf")
            .gamma(0.5)
            .config(config);

        let fitted = nystroem.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.shape(), &[80, 20]);
    }

    /// Incremental fitting from three chunks yields a model with the requested
    /// number of components.
    #[test]
    fn test_memory_efficient_nystroem_incremental() {
        let chunk1 = ramp(30, 4, 0, 0.1);
        let chunk2 = ramp(40, 4, 120, 0.1);
        let chunk3 = ramp(30, 4, 280, 0.1);

        let chunks = vec![chunk1, chunk2.clone(), chunk3];

        let config = MemoryConfig {
            chunk_size: 20,
            ..Default::default()
        };
        let nystroem = MemoryEfficientNystroem::new(15)
            .kernel("rbf")
            .config(config);

        let fitted = nystroem.fit_incremental(chunks).unwrap();
        let transformed = fitted.transform(&chunk2).unwrap();

        assert_eq!(transformed.shape(), &[40, 15]);
    }

    /// Walks the monitor through allocate / reject / deallocate transitions.
    #[test]
    fn test_memory_monitor() {
        let monitor = MemoryMonitor::new(1000);

        assert!(monitor.can_allocate(500));
        assert!(monitor.allocate(500).is_ok());
        assert_eq!(monitor.current_usage(), 500);
        assert_eq!(monitor.usage_percentage(), 50.0);

        // 500 + 600 is over budget.
        assert!(!monitor.can_allocate(600));
        // 500 + 400 fills the budget exactly.
        assert!(monitor.allocate(400).is_ok());
        // Anything more is rejected.
        assert!(monitor.allocate(200).is_err());

        monitor.deallocate(300);
        assert_eq!(monitor.current_usage(), 600);
        assert!(monitor.can_allocate(300));
    }

    /// Checks the default configuration and that a custom one is stored as-is.
    #[test]
    fn test_memory_config() {
        let config = MemoryConfig::default();
        assert_eq!(config.max_memory_bytes, 1024 * 1024 * 1024);
        assert_eq!(config.chunk_size, 10000);
        assert!(config.n_workers > 0);

        let custom_config = MemoryConfig {
            max_memory_bytes: 512 * 1024 * 1024,
            chunk_size: 5000,
            n_workers: 4,
            enable_disk_cache: true,
            temp_dir: "/custom/temp".to_string(),
        };

        let sampler = MemoryEfficientRBFSampler::new(50).config(custom_config.clone());
        let stored = &sampler.config;
        assert_eq!(stored.max_memory_bytes, 512 * 1024 * 1024);
        assert_eq!(stored.chunk_size, 5000);
        assert_eq!(stored.n_workers, 4);
        assert!(stored.enable_disk_cache);
        assert_eq!(stored.temp_dir, "/custom/temp");
    }

    /// Two samplers built with the same seed produce the same features.
    #[test]
    fn test_reproducibility() {
        let x = ramp(50, 8, 0, 0.05);

        let make_sampler = || {
            MemoryEfficientRBFSampler::new(20)
                .gamma(0.2)
                .random_seed(42)
        };

        let fitted1 = make_sampler().fit(&x, &()).unwrap();
        let fitted2 = make_sampler().fit(&x, &()).unwrap();

        let result1 = fitted1.transform(&x).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        assert_abs_diff_eq!(result1, result2, epsilon = 1e-10);
    }
}