1use scirs2_core::ndarray::{Array1, Array2, Axis};
8use sklears_core::error::{Result, SklearsError};
9use sklears_core::prelude::Predict;
10use sklears_core::traits::{Estimator, Fit, Trained, Untrained};
11use sklears_core::types::Float;
12use std::collections::VecDeque;
13use std::marker::PhantomData;
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
/// Tuning knobs for [`MemoryEfficientEnsemble`].
#[derive(Debug, Clone)]
pub struct MemoryEfficientConfig {
    /// Maximum number of models kept in memory; the oldest is evicted when
    /// this cap is exceeded.
    pub max_estimators_in_memory: usize,
    /// Number of samples processed per batch during `fit`.
    pub batch_size: usize,
    /// Sliding-window length for the internal data buffer (`None` = unbounded).
    pub window_size: Option<usize>,
    /// When true, `predict_lazy` consults a per-input prediction cache.
    pub lazy_evaluation: bool,
    /// Memory budget in megabytes; exceeding it triggers model eviction.
    pub memory_threshold_mb: usize,
    /// Whether to compress stored models.
    /// NOTE(review): not consulted anywhere in this file — confirm intended use.
    pub compress_models: bool,
    /// Whether to spill models to a disk cache.
    /// NOTE(review): not consulted anywhere in this file — confirm intended use.
    pub use_disk_cache: bool,
    /// Directory for the disk cache, when enabled.
    pub cache_dir: Option<String>,
    /// Multiplicative decay applied to the spawn-time learning rate after
    /// each `partial_fit` call.
    pub learning_rate_decay: Float,
    /// Weighting factor intended to down-weight older data.
    /// NOTE(review): not consulted anywhere in this file — confirm intended use.
    pub forgetting_factor: Float,
    /// When true, `fit` grows/shrinks the batch size based on memory usage.
    pub adaptive_batch_size: bool,
}
44
45impl Default for MemoryEfficientConfig {
46 fn default() -> Self {
47 Self {
48 max_estimators_in_memory: 50,
49 batch_size: 1000,
50 window_size: Some(10000),
51 lazy_evaluation: true,
52 memory_threshold_mb: 512,
53 compress_models: false,
54 use_disk_cache: false,
55 cache_dir: None,
56 learning_rate_decay: 0.999,
57 forgetting_factor: 0.95,
58 adaptive_batch_size: true,
59 }
60 }
61}
62
/// Memory-bounded online ensemble that trains simple incremental models over
/// streamed samples, evicting the oldest models once the configured cap or
/// memory threshold is exceeded.
///
/// `State` is a typestate marker (`Untrained`/`Trained`) controlling which
/// operations are available.
pub struct MemoryEfficientEnsemble<State = Untrained> {
    // User-supplied behavior knobs.
    config: MemoryEfficientConfig,
    // Typestate marker; carries no runtime data.
    state: PhantomData<State>,
    // Models currently held in memory (oldest first; evicted from the front).
    active_models_: Option<Vec<Box<dyn IncrementalModel>>>,
    // Per-model averaging weights; kept uniform by `update_model_weights`.
    model_weights_: Option<Array1<Float>>,
    // Estimated bytes consumed by the active models.
    memory_usage_: usize,
    // Lifetime count of models created, including evicted ones.
    total_models_created_: usize,
    // Sliding window of recent (features, target) samples.
    data_buffer_: Option<VecDeque<(Array1<Float>, Float)>>,
    // Input-hash -> prediction cache used by `predict_lazy`; present only
    // when `lazy_evaluation` is enabled at fit time.
    prediction_cache_: Option<std::collections::HashMap<u64, Float>>,
    // Quality scores consulted by `should_add_model`.
    performance_history_: Vec<Float>,
    // Learning rate assigned to newly spawned models; decays each `partial_fit`.
    current_learning_rate_: Float,
}
82
/// Contract for models that can learn one sample at a time, as required by
/// [`MemoryEfficientEnsemble`].
pub trait IncrementalModel: Send + Sync {
    /// Updates the model in place with a single `(x, y)` sample.
    fn partial_fit(&mut self, x: &Array1<Float>, y: Float) -> Result<()>;

    /// Predicts the target for a single feature vector.
    fn predict_single(&self, x: &Array1<Float>) -> Result<Float>;

    /// Approximate memory footprint of the model, in bytes.
    fn complexity(&self) -> usize;

    /// Encodes the model into a byte buffer.
    fn serialize(&self) -> Result<Vec<u8>>;

    /// Produces an owned boxed copy of the model.
    fn clone_model(&self) -> Box<dyn IncrementalModel>;
}
100
/// Reconstructs an incremental model from bytes produced by
/// [`IncrementalModel::serialize`].
///
/// NOTE(review): only `IncrementalLinearRegression` is supported; the byte
/// format carries no type tag, so other model kinds cannot round-trip here.
pub fn deserialize_incremental_model(data: &[u8]) -> Result<Box<dyn IncrementalModel>> {
    IncrementalLinearRegression::deserialize(data)
}
105
/// Linear regression trained by plain stochastic gradient descent with L2
/// regularization, one sample at a time.
#[derive(Debug, Clone)]
pub struct IncrementalLinearRegression {
    // Per-feature coefficients, length `n_features`.
    weights: Array1<Float>,
    // Intercept term.
    bias: Float,
    // Expected input dimensionality; validated on every call.
    n_features: usize,
    // SGD step size.
    learning_rate: Float,
    // L2 (ridge) regularization strength.
    l2_reg: Float,
}
115
116impl IncrementalLinearRegression {
117 pub fn new(n_features: usize, learning_rate: Float, l2_reg: Float) -> Self {
118 Self {
119 weights: Array1::zeros(n_features),
120 bias: 0.0,
121 n_features,
122 learning_rate,
123 l2_reg,
124 }
125 }
126}
127
128impl IncrementalModel for IncrementalLinearRegression {
129 fn partial_fit(&mut self, x: &Array1<Float>, y: Float) -> Result<()> {
130 if x.len() != self.n_features {
131 return Err(SklearsError::ShapeMismatch {
132 expected: format!("{} features", self.n_features),
133 actual: format!("{} features", x.len()),
134 });
135 }
136
137 let y_pred = self.weights.dot(x) + self.bias;
139
140 let error = y - y_pred;
142
143 for i in 0..self.n_features {
145 self.weights[i] += self.learning_rate * (error * x[i] - self.l2_reg * self.weights[i]);
146 }
147
148 self.bias += self.learning_rate * error;
150
151 Ok(())
152 }
153
154 fn predict_single(&self, x: &Array1<Float>) -> Result<Float> {
155 if x.len() != self.n_features {
156 return Err(SklearsError::ShapeMismatch {
157 expected: format!("{} features", self.n_features),
158 actual: format!("{} features", x.len()),
159 });
160 }
161
162 Ok(self.weights.dot(x) + self.bias)
163 }
164
165 fn complexity(&self) -> usize {
166 self.n_features * 8 + 8 + 32
168 }
169
170 fn serialize(&self) -> Result<Vec<u8>> {
171 let mut data = Vec::new();
174
175 data.extend_from_slice(&self.n_features.to_le_bytes());
177
178 data.extend_from_slice(&self.learning_rate.to_le_bytes());
180
181 data.extend_from_slice(&self.l2_reg.to_le_bytes());
183
184 data.extend_from_slice(&self.bias.to_le_bytes());
186
187 for &weight in self.weights.iter() {
189 data.extend_from_slice(&weight.to_le_bytes());
190 }
191
192 Ok(data)
193 }
194
195 fn clone_model(&self) -> Box<dyn IncrementalModel> {
196 Box::new(self.clone())
197 }
198}
199
200impl IncrementalLinearRegression {
201 pub fn deserialize(data: &[u8]) -> Result<Box<dyn IncrementalModel>> {
202 if data.len() < 32 {
203 return Err(SklearsError::InvalidInput(
204 "Insufficient data for deserialization".to_string(),
205 ));
206 }
207
208 let mut offset = 0;
209
210 let n_features = usize::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
212 offset += 8;
213
214 let learning_rate = Float::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
216 offset += 8;
217
218 let l2_reg = Float::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
220 offset += 8;
221
222 let bias = Float::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
224 offset += 8;
225
226 let mut weights = Array1::zeros(n_features);
228 for i in 0..n_features {
229 weights[i] = Float::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
230 offset += 8;
231 }
232
233 Ok(Box::new(IncrementalLinearRegression {
234 weights,
235 bias,
236 n_features,
237 learning_rate,
238 l2_reg,
239 }))
240 }
241}
242
243impl<State> MemoryEfficientEnsemble<State> {
244 pub fn memory_usage(&self) -> usize {
246 self.memory_usage_
247 }
248
249 pub fn active_model_count(&self) -> usize {
251 self.active_models_
252 .as_ref()
253 .map_or(0, |models| models.len())
254 }
255
256 pub fn total_models_created(&self) -> usize {
258 self.total_models_created_
259 }
260
261 pub fn needs_cleanup(&self) -> bool {
263 self.memory_usage_ > self.config.memory_threshold_mb * 1024 * 1024
264 }
265}
266
impl MemoryEfficientEnsemble<Untrained> {
    /// Creates an untrained ensemble with default configuration and an
    /// initial learning rate of 0.01.
    pub fn new() -> Self {
        Self {
            config: MemoryEfficientConfig::default(),
            state: PhantomData,
            active_models_: None,
            model_weights_: None,
            memory_usage_: 0,
            total_models_created_: 0,
            data_buffer_: None,
            prediction_cache_: None,
            performance_history_: Vec::new(),
            current_learning_rate_: 0.01,
        }
    }

    /// Sets the maximum number of models kept in memory at once.
    pub fn max_estimators_in_memory(mut self, max_estimators: usize) -> Self {
        self.config.max_estimators_in_memory = max_estimators;
        self
    }

    /// Sets the number of samples processed per batch during `fit`.
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.config.batch_size = batch_size;
        self
    }

    /// Sets the sliding-window length for buffered samples (`None` = unbounded).
    pub fn window_size(mut self, window_size: Option<usize>) -> Self {
        self.config.window_size = window_size;
        self
    }

    /// Enables or disables the cached lazy-prediction path.
    pub fn lazy_evaluation(mut self, enabled: bool) -> Self {
        self.config.lazy_evaluation = enabled;
        self
    }

    /// Sets the memory budget (in megabytes) that triggers model eviction.
    pub fn memory_threshold_mb(mut self, threshold: usize) -> Self {
        self.config.memory_threshold_mb = threshold;
        self
    }

    /// Enables model compression.
    /// NOTE(review): this flag is stored but not consulted in this file.
    pub fn compress_models(mut self, enabled: bool) -> Self {
        self.config.compress_models = enabled;
        self
    }

    /// Enables disk caching with an optional cache directory.
    /// NOTE(review): these settings are stored but not consulted in this file.
    pub fn use_disk_cache(mut self, enabled: bool, cache_dir: Option<String>) -> Self {
        self.config.use_disk_cache = enabled;
        self.config.cache_dir = cache_dir;
        self
    }

    /// Sets the multiplicative decay applied to the learning rate per sample.
    pub fn learning_rate_decay(mut self, decay: Float) -> Self {
        self.config.learning_rate_decay = decay;
        self
    }

    /// Sets the forgetting factor for older data.
    /// NOTE(review): this value is stored but not consulted in this file.
    pub fn forgetting_factor(mut self, factor: Float) -> Self {
        self.config.forgetting_factor = factor;
        self
    }

    /// Enables or disables memory-pressure-driven batch resizing in `fit`.
    pub fn adaptive_batch_size(mut self, enabled: bool) -> Self {
        self.config.adaptive_batch_size = enabled;
        self
    }

    /// Preset tuned for large offline datasets: fewer resident models, large
    /// batches and window, compression and disk cache enabled, 1 GB budget.
    pub fn for_large_datasets() -> Self {
        Self::new()
            .max_estimators_in_memory(20)
            .batch_size(5000)
            .window_size(Some(50000))
            .lazy_evaluation(true)
            .memory_threshold_mb(1024)
            .compress_models(true)
            .use_disk_cache(true, Some("/tmp/sklears_cache".to_string()))
            .learning_rate_decay(0.995)
            .adaptive_batch_size(true)
    }

    /// Preset tuned for streaming: small batches and window, tight memory
    /// budget, stronger forgetting of old data.
    pub fn for_streaming() -> Self {
        Self::new()
            .max_estimators_in_memory(10)
            .batch_size(100)
            .window_size(Some(5000))
            .lazy_evaluation(true)
            .memory_threshold_mb(256)
            .forgetting_factor(0.9)
            .adaptive_batch_size(true)
    }
}
371
372impl MemoryEfficientEnsemble<Trained> {
373 pub fn predict_lazy(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
375 if !self.config.lazy_evaluation {
376 return self.predict(x);
377 }
378
379 let mut predictions = Array1::zeros(x.nrows());
380
381 if let Some(cache) = &self.prediction_cache_ {
383 for (i, row) in x.axis_iter(Axis(0)).enumerate() {
384 let hash = self.hash_input(&row.to_owned())?;
385 if let Some(&cached_pred) = cache.get(&hash) {
386 predictions[i] = cached_pred;
387 } else {
388 predictions[i] = self.predict_single_internal(&row.to_owned())?;
389 }
390 }
391 } else {
392 for (i, row) in x.axis_iter(Axis(0)).enumerate() {
393 predictions[i] = self.predict_single_internal(&row.to_owned())?;
394 }
395 }
396
397 Ok(predictions)
398 }
399
400 fn predict_single_internal(&self, x: &Array1<Float>) -> Result<Float> {
402 if let Some(models) = &self.active_models_ {
403 if models.is_empty() {
404 return Err(SklearsError::NotFitted {
405 operation: "prediction".to_string(),
406 });
407 }
408
409 let mut weighted_sum = 0.0;
410 let mut total_weight = 0.0;
411
412 for (i, model) in models.iter().enumerate() {
413 let prediction = model.predict_single(x)?;
414 let weight = self.model_weights_.as_ref().map(|w| w[i]).unwrap_or(1.0);
415
416 weighted_sum += prediction * weight;
417 total_weight += weight;
418 }
419
420 if total_weight > 0.0 {
421 Ok(weighted_sum / total_weight)
422 } else {
423 Ok(0.0)
424 }
425 } else {
426 Err(SklearsError::NotFitted {
427 operation: "prediction".to_string(),
428 })
429 }
430 }
431
432 pub fn partial_fit(&mut self, x: &Array1<Float>, y: Float) -> Result<()> {
434 if let Some(buffer) = &mut self.data_buffer_ {
436 buffer.push_back((x.clone(), y));
437
438 if let Some(window_size) = self.config.window_size {
440 while buffer.len() > window_size {
441 buffer.pop_front();
442 }
443 }
444 } else {
445 self.data_buffer_ = Some(VecDeque::new());
446 self.data_buffer_
447 .as_mut()
448 .unwrap()
449 .push_back((x.clone(), y));
450 }
451
452 if let Some(models) = &mut self.active_models_ {
454 for model in models.iter_mut() {
455 model.partial_fit(x, y)?;
456 }
457 }
458
459 if self.should_add_model() {
461 self.add_new_model(x.len())?;
462 }
463
464 self.current_learning_rate_ *= self.config.learning_rate_decay;
466
467 if self.needs_cleanup() {
469 self.cleanup_memory()?;
470 }
471
472 Ok(())
473 }
474
475 fn should_add_model(&self) -> bool {
477 if self.performance_history_.len() < 10 {
479 return true;
480 }
481
482 let recent_performance = self
483 .performance_history_
484 .iter()
485 .rev()
486 .take(5)
487 .sum::<Float>()
488 / 5.0;
489 let older_performance = self
490 .performance_history_
491 .iter()
492 .rev()
493 .skip(5)
494 .take(5)
495 .sum::<Float>()
496 / 5.0;
497
498 recent_performance < older_performance * 0.95 }
500
501 fn add_new_model(&mut self, n_features: usize) -> Result<()> {
503 if let Some(models) = &mut self.active_models_ {
504 let new_model = Box::new(IncrementalLinearRegression::new(
506 n_features,
507 self.current_learning_rate_,
508 0.001, ));
510
511 models.push(new_model as Box<dyn IncrementalModel>);
512 self.total_models_created_ += 1;
513
514 self.memory_usage_ += n_features * 8 + 64; if models.len() > self.config.max_estimators_in_memory && !models.is_empty() {
519 let removed_model = models.remove(0);
520 self.memory_usage_ = self
521 .memory_usage_
522 .saturating_sub(removed_model.complexity());
523 }
524
525 self.update_model_weights()?;
527 } else {
528 let mut models = Vec::new();
530 let new_model = Box::new(IncrementalLinearRegression::new(
531 n_features,
532 self.current_learning_rate_,
533 0.001,
534 ));
535 models.push(new_model as Box<dyn IncrementalModel>);
536 self.active_models_ = Some(models);
537 self.total_models_created_ = 1;
538 self.memory_usage_ = n_features * 8 + 64;
539 self.update_model_weights()?;
540 }
541
542 Ok(())
543 }
544
545 fn update_model_weights(&mut self) -> Result<()> {
547 if let Some(models) = &self.active_models_ {
548 let n_models = models.len();
549 if n_models == 0 {
550 return Ok(());
551 }
552
553 let weights = Array1::from_elem(n_models, 1.0 / n_models as Float);
556 self.model_weights_ = Some(weights);
557 }
558
559 Ok(())
560 }
561
562 fn cleanup_memory(&mut self) -> Result<()> {
564 if let Some(models) = &mut self.active_models_ {
565 let target_size = self.config.max_estimators_in_memory / 2;
566
567 while models.len() > target_size {
568 if !models.is_empty() {
569 let removed_model = models.remove(0);
570 self.memory_usage_ = self
571 .memory_usage_
572 .saturating_sub(removed_model.complexity());
573 }
574 }
575
576 self.update_model_weights()?;
578 }
579
580 if let Some(cache) = &mut self.prediction_cache_ {
582 cache.clear();
583 }
584
585 Ok(())
586 }
587
588 fn hash_input(&self, x: &Array1<Float>) -> Result<u64> {
590 use std::collections::hash_map::DefaultHasher;
591 use std::hash::{Hash, Hasher};
592
593 let mut hasher = DefaultHasher::new();
594 for &val in x.iter() {
595 val.to_bits().hash(&mut hasher);
596 }
597 Ok(hasher.finish())
598 }
599}
600
/// `Estimator` marker implementation exposing the configuration type.
impl Estimator for MemoryEfficientEnsemble<Untrained> {
    type Config = MemoryEfficientConfig;
    type Error = SklearsError;
    type Float = Float;

    /// Returns the ensemble's configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
611
612impl Fit<Array2<Float>, Array1<Float>> for MemoryEfficientEnsemble<Untrained> {
613 type Fitted = MemoryEfficientEnsemble<Trained>;
614
615 fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
616 let n_samples = x.nrows();
617 let n_features = x.ncols();
618
619 if n_samples != y.len() {
620 return Err(SklearsError::ShapeMismatch {
621 expected: format!("{} samples", n_samples),
622 actual: format!("{} samples", y.len()),
623 });
624 }
625
626 let config = self.config.clone();
628 let mut ensemble = MemoryEfficientEnsemble::<Trained> {
629 config: config.clone(),
630 state: PhantomData,
631 active_models_: Some(Vec::new()),
632 model_weights_: None,
633 memory_usage_: 0,
634 total_models_created_: 0,
635 data_buffer_: Some(VecDeque::new()),
636 prediction_cache_: if self.config.lazy_evaluation {
637 Some(std::collections::HashMap::new())
638 } else {
639 None
640 },
641 performance_history_: Vec::new(),
642 current_learning_rate_: 0.01,
643 };
644
645 let batch_size = config.batch_size;
647 let mut current_batch_size = batch_size;
648
649 for start_idx in (0..n_samples).step_by(current_batch_size) {
650 let end_idx = (start_idx + current_batch_size).min(n_samples);
651
652 for i in start_idx..end_idx {
654 let x_sample = x.row(i).to_owned();
655 let y_sample = y[i];
656
657 if ensemble.active_models_.as_ref().unwrap().is_empty() {
659 ensemble.add_new_model(n_features)?;
660 }
661
662 ensemble.partial_fit(&x_sample, y_sample)?;
664 }
665
666 if config.adaptive_batch_size {
668 if ensemble.memory_usage_ > config.memory_threshold_mb * 1024 * 1024 / 2 {
669 current_batch_size /= 2;
670 } else if ensemble.memory_usage_ < config.memory_threshold_mb * 1024 * 1024 / 4 {
671 current_batch_size = (current_batch_size * 3 / 2).min(batch_size * 2);
672 }
673 current_batch_size = current_batch_size.max(100); }
675 }
676
677 Ok(ensemble)
678 }
679}
680
681impl Predict<Array2<Float>, Array1<Float>> for MemoryEfficientEnsemble<Trained> {
682 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
683 let mut predictions = Array1::zeros(x.nrows());
684
685 for (i, row) in x.axis_iter(Axis(0)).enumerate() {
686 predictions[i] = self.predict_single_internal(&row.to_owned())?;
687 }
688
689 Ok(predictions)
690 }
691}
692
impl Default for MemoryEfficientEnsemble<Untrained> {
    /// Equivalent to [`MemoryEfficientEnsemble::new`].
    fn default() -> Self {
        Self::new()
    }
}
698
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// End-to-end fit/predict smoke test on a small linear dataset.
    #[test]
    fn test_memory_efficient_ensemble_basic() {
        let x = Array2::from_shape_vec(
            (10, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
                9.0, 10.0, 10.0, 11.0,
            ],
        )
        .unwrap();

        let y = array![3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0];

        let ensemble = MemoryEfficientEnsemble::new()
            .max_estimators_in_memory(5)
            .batch_size(3);

        let trained = ensemble.fit(&x, &y).unwrap();

        assert!(trained.active_model_count() > 0);
        assert!(trained.memory_usage() > 0);

        let predictions = trained.predict(&x).unwrap();
        assert_eq!(predictions.len(), x.nrows());
    }

    /// Verifies the trained ensemble stays usable after `partial_fit`.
    #[test]
    fn test_incremental_learning() {
        // `mut` removed: the builder is consumed by `fit`, never mutated here.
        let ensemble = MemoryEfficientEnsemble::new()
            .max_estimators_in_memory(3)
            .batch_size(2);

        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let mut trained = ensemble.fit(&x, &y).unwrap();

        let x_new = array![5.0, 6.0];
        trained.partial_fit(&x_new, 11.0).unwrap();

        assert!(trained.active_model_count() > 0);

        let test_x = Array2::from_shape_vec((1, 2), vec![3.0, 4.0]).unwrap();
        let predictions = trained.predict(&test_x).unwrap();
        assert_eq!(predictions.len(), 1);
    }

    /// Checks the in-memory model cap is honored under a tight memory budget.
    #[test]
    fn test_memory_management() {
        let ensemble = MemoryEfficientEnsemble::new()
            .max_estimators_in_memory(2)
            .memory_threshold_mb(1);

        let x = Array2::from_shape_vec((20, 5), (0..100).map(|i| i as Float).collect()).unwrap();
        let y = Array1::from_shape_vec(20, (0..20).map(|i| i as Float).collect()).unwrap();

        let trained = ensemble.fit(&x, &y).unwrap();

        assert!(trained.active_model_count() <= 2);
        assert!(trained.total_models_created() >= trained.active_model_count());
    }

    /// Ensures lazy and eager prediction paths both produce full-length output.
    #[test]
    fn test_lazy_evaluation() {
        let ensemble = MemoryEfficientEnsemble::new()
            .lazy_evaluation(true)
            .max_estimators_in_memory(3);

        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0],
        )
        .unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0, 13.0];

        let trained = ensemble.fit(&x, &y).unwrap();

        let predictions = trained.predict_lazy(&x).unwrap();
        assert_eq!(predictions.len(), x.nrows());

        let regular_predictions = trained.predict(&x).unwrap();
        assert_eq!(regular_predictions.len(), x.nrows());
    }
}