1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
14use scirs2_core::random::{prelude::*, thread_rng, Distribution, Rng};
15#[cfg(feature = "serde")]
16use serde::{Deserialize, Serialize};
17use sklears_core::error::SklearsError;
18use sklears_core::traits::{Fit, Predict};
19use std::collections::HashMap;
20
/// Prediction strategy for the few-shot baseline estimators.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FewShotStrategy {
    /// Predict using the nearest per-class mean (prototype).
    NearestPrototype,
    /// Majority vote (classifier) / mean (regressor) over the `k` nearest
    /// stored support samples.
    KNearestNeighbors { k: usize },
    /// Inverse-distance-weighted vote/average over all stored support samples.
    SupportBased,
    /// Centroid rule (classifier: same as `NearestPrototype`; regressor:
    /// global mean of the training targets).
    Centroid,
    /// Probabilistic prediction sampled from exp(-distance) weights using the
    /// estimator's RNG.
    Probabilistic,
}
36
/// Strategy for combining source- and target-domain information in the
/// transfer-learning baseline.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum TransferStrategy {
    /// Blend source and target means with a fixed `source_weight`.
    SourcePrior { source_weight: f64 },
    /// Per-feature weights from target/source variance ratios, interpolated
    /// by `adaptation_rate`.
    FeatureBased { adaptation_rate: f64 },
    /// Equal-weight source/target blend (the threshold is currently unused by
    /// the fitted predictor in this file).
    InstanceBased { similarity_threshold: f64 },
    /// Fixed 70/30 target/source blend (the threshold is currently unused by
    /// the fitted predictor in this file).
    ModelBased { confidence_threshold: f64 },
    /// Blend by the first entry of `domain_weights`; falls back to the target
    /// mean when the list is empty.
    EnsembleTransfer { domain_weights: Vec<f64> },
}
52
/// Strategy used by the domain-adaptation baseline.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum DomainAdaptationStrategy {
    /// Rescale each feature by the target/source standard-deviation ratio.
    FeatureAlignment,
    /// Reweight instances by their relative distance to the source and target
    /// feature means, scaled by `adaptation_strength`.
    InstanceReweighting { adaptation_strength: f64 },
    /// Gradient-reversal-style blending controlled by `lambda`.
    GradientReversal { lambda: f64 },
    /// Couple the first `subspace_dim` features via small off-diagonal terms
    /// in the adaptation matrix.
    SubspaceAlignment { subspace_dim: usize },
    /// Down-weight features whose source/target means diverge, via an
    /// exponential kernel with the given `bandwidth`.
    MMDMinimization { bandwidth: f64 },
}
68
/// Strategy for the continual-learning baseline.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ContinualStrategy {
    /// EWC-style consolidation; `importance_weight` scales per-feature drift
    /// across tasks.
    ElasticWeightConsolidation { importance_weight: f64 },
    /// FIFO replay buffer capped at `memory_size` samples.
    Rehearsal { memory_size: usize },
    /// Progressive-network-style averaging over per-task statistics.
    Progressive { column_capacity: usize },
    /// Distill previous-task predictions with weight `distillation_weight`.
    LearningWithoutForgetting { distillation_weight: f64 },
    /// Averaged-GEM-style episodic memory with strength `memory_strength`.
    AGEM { memory_strength: f64 },
}
84
/// Unfitted few-shot classification baseline.
#[derive(Debug, Clone)]
pub struct FewShotBaselineClassifier {
    /// Prediction strategy used after fitting.
    strategy: FewShotStrategy,
    /// Optional RNG seed for reproducible probabilistic predictions.
    random_state: Option<u64>,
}
91
/// Fitted state produced by [`FewShotBaselineClassifier`].
#[derive(Debug, Clone)]
pub struct FittedFewShotClassifier {
    /// Strategy copied from the unfitted estimator.
    strategy: FewShotStrategy,
    /// Per-class mean feature vector (prototype).
    prototypes: HashMap<i32, Array1<f64>>,
    /// Full training set, stored only for neighbor-based strategies.
    support_samples: Option<(Array2<f64>, Array1<i32>)>,
    /// Number of training samples seen per class.
    class_counts: HashMap<i32, usize>,
    /// Optional RNG seed for reproducible probabilistic predictions.
    random_state: Option<u64>,
}
101
/// Unfitted few-shot regression baseline.
#[derive(Debug, Clone)]
pub struct FewShotBaselineRegressor {
    /// Prediction strategy used after fitting.
    strategy: FewShotStrategy,
    /// Optional RNG seed for reproducible probabilistic predictions.
    random_state: Option<u64>,
}
108
/// Fitted state produced by [`FewShotBaselineRegressor`].
#[derive(Debug, Clone)]
pub struct FittedFewShotRegressor {
    /// Strategy copied from the unfitted estimator.
    strategy: FewShotStrategy,
    /// One `(sample, target)` prototype per training row.
    prototypes: Vec<(Array1<f64>, f64)>,
    /// Full training set, stored only for neighbor-based strategies.
    support_samples: Option<(Array2<f64>, Array1<f64>)>,
    /// Optional RNG seed for reproducible probabilistic predictions.
    random_state: Option<u64>,
}
117
/// Unfitted transfer-learning baseline.
#[derive(Debug, Clone)]
pub struct TransferLearningBaseline {
    /// Strategy for mixing source- and target-domain information.
    strategy: TransferStrategy,
    /// Precomputed source-domain statistics; when `None`, fitting falls back
    /// to the target statistics as the "source".
    source_statistics: Option<SourceDomainStats>,
    /// Optional RNG seed for reproducibility.
    random_state: Option<u64>,
}
125
/// Summary statistics of a domain; reused for both source and target sides.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SourceDomainStats {
    /// Per-feature mean.
    feature_means: Array1<f64>,
    /// Per-feature standard deviation (computed with ddof = 0 in this file).
    feature_stds: Array1<f64>,
    /// Class prior probabilities (left empty by the fitters in this file).
    class_priors: HashMap<i32, f64>,
    /// Mean of the target variable.
    target_mean: f64,
    /// Standard deviation of the target variable.
    target_std: f64,
}
136
/// Fitted state produced by [`TransferLearningBaseline`].
#[derive(Debug, Clone)]
pub struct FittedTransferBaseline {
    /// Strategy copied from the unfitted estimator.
    strategy: TransferStrategy,
    /// Source-domain statistics (or a copy of the target statistics when no
    /// source statistics were supplied).
    source_stats: SourceDomainStats,
    /// Target-domain statistics computed at fit time.
    target_stats: SourceDomainStats,
    /// Per-feature adaptation weights derived from the strategy.
    adaptation_weights: Array1<f64>,
    /// Optional RNG seed for reproducibility.
    random_state: Option<u64>,
}
146
/// Unfitted domain-adaptation baseline.
#[derive(Debug, Clone)]
pub struct DomainAdaptationBaseline {
    /// Adaptation strategy.
    strategy: DomainAdaptationStrategy,
    /// Optional source-domain `(features, targets)` data to adapt from.
    source_domain_data: Option<(Array2<f64>, Array1<f64>)>,
    /// Optional RNG seed for reproducibility.
    random_state: Option<u64>,
}
154
/// Fitted state produced by [`DomainAdaptationBaseline`].
#[derive(Debug, Clone)]
pub struct FittedDomainAdaptationBaseline {
    /// Strategy copied from the unfitted estimator.
    strategy: DomainAdaptationStrategy,
    /// Source-domain statistics (or a copy of the target statistics when no
    /// source data was supplied).
    source_stats: SourceDomainStats,
    /// Target-domain statistics computed at fit time.
    target_stats: SourceDomainStats,
    /// Per-feature adaptation matrix derived from the strategy.
    adaptation_matrix: Array2<f64>,
    /// Per-training-sample weights (uniform except for instance reweighting).
    instance_weights: Array1<f64>,
    /// Optional RNG seed for reproducibility.
    random_state: Option<u64>,
}
165
166impl FewShotBaselineClassifier {
167 pub fn new(strategy: FewShotStrategy) -> Self {
169 Self {
170 strategy,
171 random_state: None,
172 }
173 }
174
175 pub fn with_random_state(mut self, seed: u64) -> Self {
177 self.random_state = Some(seed);
178 self
179 }
180}
181
182impl Fit<Array2<f64>, Array1<i32>, FittedFewShotClassifier> for FewShotBaselineClassifier {
183 type Fitted = FittedFewShotClassifier;
184 fn fit(
185 self,
186 x: &Array2<f64>,
187 y: &Array1<i32>,
188 ) -> Result<FittedFewShotClassifier, SklearsError> {
189 if x.nrows() != y.len() {
190 return Err(SklearsError::ShapeMismatch {
191 expected: format!("{} samples", x.nrows()),
192 actual: format!("{} labels", y.len()),
193 });
194 }
195
196 let rng = self.random_state.map_or_else(
197 || Box::new(thread_rng()) as Box<dyn RngCore>,
198 |seed| Box::new(StdRng::seed_from_u64(seed)),
199 );
200
201 let mut class_counts = HashMap::new();
203 for &class in y.iter() {
204 *class_counts.entry(class).or_insert(0) += 1;
205 }
206
207 let mut prototypes = HashMap::new();
209 for &class in class_counts.keys() {
210 let class_indices: Vec<usize> = y
211 .iter()
212 .enumerate()
213 .filter(|(_, &label)| label == class)
214 .map(|(i, _)| i)
215 .collect();
216
217 if !class_indices.is_empty() {
218 let class_data: Vec<f64> = class_indices
219 .iter()
220 .flat_map(|&i| x.row(i).to_vec())
221 .collect();
222 let class_samples: Array2<f64> =
223 Array2::from_shape_vec((class_indices.len(), x.ncols()), class_data)?;
224
225 let prototype = class_samples.mean_axis(Axis(0)).unwrap();
226 prototypes.insert(class, prototype);
227 }
228 }
229
230 let support_samples = match self.strategy {
232 FewShotStrategy::KNearestNeighbors { .. } | FewShotStrategy::SupportBased => {
233 Some((x.clone(), y.clone()))
234 }
235 _ => None,
236 };
237
238 Ok(FittedFewShotClassifier {
239 strategy: self.strategy.clone(),
240 prototypes,
241 support_samples,
242 class_counts,
243 random_state: self.random_state,
244 })
245 }
246}
247
248impl Predict<Array2<f64>, Array1<i32>> for FittedFewShotClassifier {
249 fn predict(&self, x: &Array2<f64>) -> Result<Array1<i32>, SklearsError> {
250 let mut predictions = Vec::with_capacity(x.nrows());
251 let mut rng = self.random_state.map_or_else(
252 || Box::new(thread_rng()) as Box<dyn RngCore>,
253 |seed| Box::new(StdRng::seed_from_u64(seed)),
254 );
255
256 for sample in x.rows() {
257 let prediction = match &self.strategy {
258 FewShotStrategy::NearestPrototype => self.predict_nearest_prototype(&sample)?,
259 FewShotStrategy::KNearestNeighbors { k } => self.predict_knn(&sample, *k)?,
260 FewShotStrategy::SupportBased => self.predict_support_based(&sample)?,
261 FewShotStrategy::Centroid => self.predict_centroid(&sample)?,
262 FewShotStrategy::Probabilistic => self.predict_probabilistic(&sample, &mut *rng)?,
263 };
264 predictions.push(prediction);
265 }
266
267 Ok(Array1::from_vec(predictions))
268 }
269}
270
impl FittedFewShotClassifier {
    /// Return the class whose prototype (per-class mean) is nearest to
    /// `sample` in Euclidean distance.
    ///
    /// Falls back to class `0` when no prototypes exist; ties follow
    /// `HashMap` iteration order, which is not deterministic across runs.
    fn predict_nearest_prototype(&self, sample: &ArrayView1<f64>) -> Result<i32, SklearsError> {
        let mut min_distance = f64::INFINITY;
        let mut best_class = 0;

        for (&class, prototype) in &self.prototypes {
            // Euclidean distance to this class's prototype.
            let distance: f64 = sample
                .iter()
                .zip(prototype.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f64>()
                .sqrt();

            if distance < min_distance {
                min_distance = distance;
                best_class = class;
            }
        }

        Ok(best_class)
    }

    /// Majority vote over the `k` nearest stored support samples; falls back
    /// to the nearest-prototype rule when no support set was stored.
    fn predict_knn(&self, sample: &ArrayView1<f64>, k: usize) -> Result<i32, SklearsError> {
        if let Some((ref support_x, ref support_y)) = self.support_samples {
            let mut distances: Vec<(f64, i32)> = Vec::new();

            for (i, support_sample) in support_x.rows().into_iter().enumerate() {
                let distance: f64 = sample
                    .iter()
                    .zip(support_sample.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f64>()
                    .sqrt();
                distances.push((distance, support_y[i]));
            }

            // NOTE(review): partial_cmp().unwrap() panics if any distance is
            // NaN (i.e. NaN in the inputs) — assumes finite features.
            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

            let k_neighbors = distances.into_iter().take(k).collect::<Vec<_>>();
            let mut class_votes: HashMap<i32, usize> = HashMap::new();

            for (_, class) in k_neighbors {
                *class_votes.entry(class).or_insert(0) += 1;
            }

            // Ties follow HashMap iteration order; class 0 when there are no
            // votes (k == 0 or an empty support set).
            let best_class = class_votes
                .into_iter()
                .max_by_key(|(_, count)| *count)
                .map(|(class, _)| class)
                .unwrap_or(0);

            Ok(best_class)
        } else {
            self.predict_nearest_prototype(sample)
        }
    }

    /// Distance-weighted vote over all stored support samples, with weight
    /// 1 / (1 + distance); falls back to the nearest-prototype rule when no
    /// support set was stored.
    fn predict_support_based(&self, sample: &ArrayView1<f64>) -> Result<i32, SklearsError> {
        if let Some((ref support_x, ref support_y)) = self.support_samples {
            let mut weighted_votes: HashMap<i32, f64> = HashMap::new();

            for (i, support_sample) in support_x.rows().into_iter().enumerate() {
                let distance: f64 = sample
                    .iter()
                    .zip(support_sample.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f64>()
                    .sqrt();

                // Closer support samples vote with more weight.
                let weight = 1.0 / (1.0 + distance);
                let class = support_y[i];
                *weighted_votes.entry(class).or_insert(0.0) += weight;
            }

            // Ties follow HashMap iteration order; class 0 with no votes.
            let best_class = weighted_votes
                .into_iter()
                .max_by(|(_, weight_a), (_, weight_b)| weight_a.partial_cmp(weight_b).unwrap())
                .map(|(class, _)| class)
                .unwrap_or(0);

            Ok(best_class)
        } else {
            self.predict_nearest_prototype(sample)
        }
    }

    /// Centroid strategy: identical to the nearest-prototype rule, since the
    /// stored prototypes are already the per-class centroids.
    fn predict_centroid(&self, sample: &ArrayView1<f64>) -> Result<i32, SklearsError> {
        self.predict_nearest_prototype(sample)
    }

    /// Sample a class from a distribution proportional to exp(-distance) to
    /// each class prototype, using the supplied RNG.
    ///
    /// NOTE(review): the cumulative scan iterates a `HashMap`, so the mapping
    /// from the random draw to a class is not stable across runs even with a
    /// fixed seed. The final fallback (highest-weight class) only triggers if
    /// the scan fails to select, e.g. due to floating-point rounding.
    fn predict_probabilistic(
        &self,
        sample: &ArrayView1<f64>,
        rng: &mut dyn RngCore,
    ) -> Result<i32, SklearsError> {
        let mut class_probs: HashMap<i32, f64> = HashMap::new();
        let mut total_weight = 0.0;

        for (&class, prototype) in &self.prototypes {
            let distance: f64 = sample
                .iter()
                .zip(prototype.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f64>()
                .sqrt();

            // Exponential kernel: closer prototypes get larger weights.
            let weight = (-distance).exp();
            class_probs.insert(class, weight);
            total_weight += weight;
        }

        // Normalize the weights into probabilities.
        for (_, prob) in class_probs.iter_mut() {
            *prob /= total_weight;
        }

        // Inverse-CDF sampling over the (unordered) class map.
        let rand_val: f64 = rng.gen();
        let mut cumulative_prob = 0.0;

        for (&class, &prob) in &class_probs {
            cumulative_prob += prob;
            if rand_val <= cumulative_prob {
                return Ok(class);
            }
        }

        // Fallback: most probable class (class 0 if no prototypes).
        let best_class = class_probs
            .into_iter()
            .max_by(|(_, prob_a), (_, prob_b)| prob_a.partial_cmp(prob_b).unwrap())
            .map(|(class, _)| class)
            .unwrap_or(0);

        Ok(best_class)
    }
}
412
413impl FewShotBaselineRegressor {
414 pub fn new(strategy: FewShotStrategy) -> Self {
416 Self {
417 strategy,
418 random_state: None,
419 }
420 }
421
422 pub fn with_random_state(mut self, seed: u64) -> Self {
424 self.random_state = Some(seed);
425 self
426 }
427}
428
429impl Fit<Array2<f64>, Array1<f64>, FittedFewShotRegressor> for FewShotBaselineRegressor {
430 type Fitted = FittedFewShotRegressor;
431 fn fit(self, x: &Array2<f64>, y: &Array1<f64>) -> Result<FittedFewShotRegressor, SklearsError> {
432 if x.nrows() != y.len() {
433 return Err(SklearsError::ShapeMismatch {
434 expected: format!("{} samples", x.nrows()),
435 actual: format!("{} labels", y.len()),
436 });
437 }
438
439 let mut prototypes = Vec::new();
441 for (i, sample) in x.rows().into_iter().enumerate() {
442 prototypes.push((sample.to_owned(), y[i]));
443 }
444
445 let support_samples = match self.strategy {
447 FewShotStrategy::KNearestNeighbors { .. } | FewShotStrategy::SupportBased => {
448 Some((x.clone(), y.clone()))
449 }
450 _ => None,
451 };
452
453 Ok(FittedFewShotRegressor {
454 strategy: self.strategy.clone(),
455 prototypes,
456 support_samples,
457 random_state: self.random_state,
458 })
459 }
460}
461
462impl Predict<Array2<f64>, Array1<f64>> for FittedFewShotRegressor {
463 fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>, SklearsError> {
464 let mut predictions = Vec::with_capacity(x.nrows());
465 let mut rng = self.random_state.map_or_else(
466 || Box::new(thread_rng()) as Box<dyn RngCore>,
467 |seed| Box::new(StdRng::seed_from_u64(seed)),
468 );
469
470 for sample in x.rows() {
471 let prediction = match &self.strategy {
472 FewShotStrategy::NearestPrototype => self.predict_nearest_prototype(&sample)?,
473 FewShotStrategy::KNearestNeighbors { k } => self.predict_knn(&sample, *k)?,
474 FewShotStrategy::SupportBased => self.predict_support_based(&sample)?,
475 FewShotStrategy::Centroid => self.predict_centroid(&sample)?,
476 FewShotStrategy::Probabilistic => self.predict_probabilistic(&sample, &mut *rng)?,
477 };
478 predictions.push(prediction);
479 }
480
481 Ok(Array1::from_vec(predictions))
482 }
483}
484
impl FittedFewShotRegressor {
    /// Return the target value of the prototype nearest to `sample` in
    /// Euclidean distance.
    ///
    /// Falls back to `0.0` when there are no prototypes (fit on empty data).
    fn predict_nearest_prototype(&self, sample: &ArrayView1<f64>) -> Result<f64, SklearsError> {
        let mut min_distance = f64::INFINITY;
        let mut best_value = 0.0;

        for (prototype, value) in &self.prototypes {
            // Euclidean distance to this prototype.
            let distance: f64 = sample
                .iter()
                .zip(prototype.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f64>()
                .sqrt();

            if distance < min_distance {
                min_distance = distance;
                best_value = *value;
            }
        }

        Ok(best_value)
    }

    /// Mean target over the `k` nearest stored support samples; falls back to
    /// the nearest-prototype rule when no support set was stored.
    ///
    /// NOTE(review): divides by `k_neighbors.len()` — yields NaN when the
    /// stored support set is empty; assumes a non-empty fit.
    fn predict_knn(&self, sample: &ArrayView1<f64>, k: usize) -> Result<f64, SklearsError> {
        if let Some((ref support_x, ref support_y)) = self.support_samples {
            let mut distances: Vec<(f64, f64)> = Vec::new();

            for (i, support_sample) in support_x.rows().into_iter().enumerate() {
                let distance: f64 = sample
                    .iter()
                    .zip(support_sample.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f64>()
                    .sqrt();
                distances.push((distance, support_y[i]));
            }

            // NOTE(review): partial_cmp().unwrap() panics if any distance is
            // NaN (i.e. NaN in the inputs) — assumes finite features.
            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

            let k_neighbors = distances.into_iter().take(k).collect::<Vec<_>>();
            let mean_value =
                k_neighbors.iter().map(|(_, value)| value).sum::<f64>() / k_neighbors.len() as f64;

            Ok(mean_value)
        } else {
            self.predict_nearest_prototype(sample)
        }
    }

    /// Inverse-distance-weighted mean over all stored support samples, with
    /// weight 1 / (1 + distance); falls back to the nearest-prototype rule
    /// when no support set was stored.
    fn predict_support_based(&self, sample: &ArrayView1<f64>) -> Result<f64, SklearsError> {
        if let Some((ref support_x, ref support_y)) = self.support_samples {
            let mut weighted_sum = 0.0;
            let mut total_weight = 0.0;

            for (i, support_sample) in support_x.rows().into_iter().enumerate() {
                let distance: f64 = sample
                    .iter()
                    .zip(support_sample.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f64>()
                    .sqrt();

                // Closer support samples contribute more.
                let weight = 1.0 / (1.0 + distance);
                weighted_sum += weight * support_y[i];
                total_weight += weight;
            }

            // NOTE(review): NaN when the stored support set has zero rows.
            Ok(weighted_sum / total_weight)
        } else {
            self.predict_nearest_prototype(sample)
        }
    }

    /// Centroid strategy: the global mean of all training targets,
    /// independent of the input sample.
    ///
    /// NOTE(review): NaN when fitted on empty data (division by zero).
    fn predict_centroid(&self, _sample: &ArrayView1<f64>) -> Result<f64, SklearsError> {
        let mean_target = self.prototypes.iter().map(|(_, value)| value).sum::<f64>()
            / self.prototypes.len() as f64;
        Ok(mean_target)
    }

    /// Draw a prediction from a normal distribution whose mean and variance
    /// are exp(-distance)-weighted statistics of the training targets; the
    /// standard deviation is floored at 0.1 to keep the distribution valid.
    fn predict_probabilistic(
        &self,
        sample: &ArrayView1<f64>,
        rng: &mut dyn RngCore,
    ) -> Result<f64, SklearsError> {
        let mut weighted_sum = 0.0;
        let mut total_weight = 0.0;
        let mut variance_sum = 0.0;

        for (prototype, value) in &self.prototypes {
            let distance: f64 = sample
                .iter()
                .zip(prototype.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f64>()
                .sqrt();

            // Exponential kernel: closer prototypes get larger weights.
            let weight = (-distance).exp();
            weighted_sum += weight * value;
            total_weight += weight;
            variance_sum += weight * value * value;
        }

        // Weighted mean and (possibly slightly negative, due to rounding)
        // weighted variance; sqrt of a negative yields NaN, which the 0.1
        // floor below replaces.
        let mean = weighted_sum / total_weight;
        let variance = (variance_sum / total_weight) - mean * mean;
        let std_dev = variance.sqrt().max(0.1);

        use scirs2_core::random::essentials::Normal;
        let normal = Normal::new(mean, std_dev).map_err(|_| SklearsError::InvalidParameter {
            name: "normal_distribution".to_string(),
            reason: "Invalid parameters for normal distribution".to_string(),
        })?;
        let sample_value = normal.sample(rng);

        Ok(sample_value)
    }
}
603
604impl TransferLearningBaseline {
605 pub fn new(strategy: TransferStrategy) -> Self {
607 Self {
608 strategy,
609 source_statistics: None,
610 random_state: None,
611 }
612 }
613
614 pub fn with_source_statistics(mut self, stats: SourceDomainStats) -> Self {
616 self.source_statistics = Some(stats);
617 self
618 }
619
620 pub fn with_random_state(mut self, seed: u64) -> Self {
622 self.random_state = Some(seed);
623 self
624 }
625}
626
627impl Fit<Array2<f64>, Array1<f64>, FittedTransferBaseline> for TransferLearningBaseline {
628 type Fitted = FittedTransferBaseline;
629 fn fit(self, x: &Array2<f64>, y: &Array1<f64>) -> Result<FittedTransferBaseline, SklearsError> {
630 if x.nrows() != y.len() {
631 return Err(SklearsError::ShapeMismatch {
632 expected: format!("{} samples", x.nrows()),
633 actual: format!("{} labels", y.len()),
634 });
635 }
636
637 let feature_means = x.mean_axis(Axis(0)).unwrap();
639 let feature_stds = x.std_axis(Axis(0), 0.0);
640 let target_mean = y.mean().unwrap();
641 let target_std = y.std(0.0);
642
643 let target_stats = SourceDomainStats {
644 feature_means,
645 feature_stds,
646 class_priors: HashMap::new(), target_mean,
648 target_std,
649 };
650
651 let source_stats = self
652 .source_statistics
653 .clone()
654 .unwrap_or_else(|| target_stats.clone());
655
656 let adaptation_weights = match &self.strategy {
658 TransferStrategy::SourcePrior { source_weight } => {
659 Array1::from_elem(x.ncols(), *source_weight)
660 }
661 TransferStrategy::FeatureBased { adaptation_rate } => {
662 let mut weights = Array1::zeros(x.ncols());
664 for i in 0..x.ncols() {
665 let source_var = source_stats.feature_stds[i].powi(2);
666 let target_var = target_stats.feature_stds[i].powi(2);
667 let ratio = (target_var / (source_var + 1e-10)).min(1.0);
668 weights[i] = adaptation_rate * ratio + (1.0 - adaptation_rate);
669 }
670 weights
671 }
672 TransferStrategy::InstanceBased {
673 similarity_threshold: _,
674 } => Array1::from_elem(x.ncols(), 0.5),
675 TransferStrategy::ModelBased {
676 confidence_threshold: _,
677 } => Array1::from_elem(x.ncols(), 0.7),
678 TransferStrategy::EnsembleTransfer { domain_weights } => {
679 if domain_weights.is_empty() {
680 Array1::from_elem(x.ncols(), 1.0)
681 } else {
682 Array1::from_elem(x.ncols(), domain_weights[0])
683 }
684 }
685 };
686
687 Ok(FittedTransferBaseline {
688 strategy: self.strategy.clone(),
689 source_stats,
690 target_stats,
691 adaptation_weights,
692 random_state: self.random_state,
693 })
694 }
695}
696
697impl Predict<Array2<f64>, Array1<f64>> for FittedTransferBaseline {
698 fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>, SklearsError> {
699 let mut predictions = Vec::with_capacity(x.nrows());
700
701 for sample in x.rows() {
702 let mut prediction = 0.0;
704
705 match &self.strategy {
706 TransferStrategy::SourcePrior { source_weight } => {
707 prediction = source_weight * self.source_stats.target_mean
708 + (1.0 - source_weight) * self.target_stats.target_mean;
709 }
710 TransferStrategy::FeatureBased { adaptation_rate: _ } => {
711 let source_contrib = self.source_stats.target_mean;
713 let target_contrib = self.target_stats.target_mean;
714 let avg_weight = self.adaptation_weights.mean().unwrap();
715 prediction = avg_weight * target_contrib + (1.0 - avg_weight) * source_contrib;
716 }
717 TransferStrategy::InstanceBased {
718 similarity_threshold: _,
719 } => {
720 prediction =
722 0.5 * self.source_stats.target_mean + 0.5 * self.target_stats.target_mean;
723 }
724 TransferStrategy::ModelBased {
725 confidence_threshold: _,
726 } => {
727 prediction =
729 0.7 * self.target_stats.target_mean + 0.3 * self.source_stats.target_mean;
730 }
731 TransferStrategy::EnsembleTransfer { domain_weights } => {
732 if domain_weights.is_empty() {
733 prediction = self.target_stats.target_mean;
734 } else {
735 prediction = domain_weights[0] * self.source_stats.target_mean
736 + (1.0 - domain_weights[0]) * self.target_stats.target_mean;
737 }
738 }
739 }
740
741 predictions.push(prediction);
742 }
743
744 Ok(Array1::from_vec(predictions))
745 }
746}
747
748impl DomainAdaptationBaseline {
749 pub fn new(strategy: DomainAdaptationStrategy) -> Self {
751 Self {
752 strategy,
753 source_domain_data: None,
754 random_state: None,
755 }
756 }
757
758 pub fn with_source_domain_data(mut self, source_x: Array2<f64>, source_y: Array1<f64>) -> Self {
760 self.source_domain_data = Some((source_x, source_y));
761 self
762 }
763
764 pub fn with_random_state(mut self, seed: u64) -> Self {
766 self.random_state = Some(seed);
767 self
768 }
769}
770
impl Fit<Array2<f64>, Array1<f64>, FittedDomainAdaptationBaseline> for DomainAdaptationBaseline {
    type Fitted = FittedDomainAdaptationBaseline;

    /// Fit the domain-adaptation baseline: compute target-domain statistics,
    /// source-domain statistics (from the attached source data, or a copy of
    /// the target statistics when none were provided), then derive the
    /// adaptation matrix and per-instance weights for the strategy.
    ///
    /// # Errors
    /// Returns `SklearsError::ShapeMismatch` when `x` and `y` disagree on the
    /// number of samples.
    fn fit(
        self,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<FittedDomainAdaptationBaseline, SklearsError> {
        if x.nrows() != y.len() {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("{} samples", x.nrows()),
                actual: format!("{} labels", y.len()),
            });
        }

        // Target-domain summary statistics (population std, ddof = 0).
        let target_feature_means = x.mean_axis(Axis(0)).unwrap();
        let target_feature_stds = x.std_axis(Axis(0), 0.0);
        let target_mean = y.mean().unwrap();
        let target_std = y.std(0.0);

        let target_stats = SourceDomainStats {
            feature_means: target_feature_means,
            feature_stds: target_feature_stds,
            class_priors: HashMap::new(),
            target_mean,
            target_std,
        };

        // Source statistics from the attached data, if any; otherwise the
        // adaptation degenerates to the target domain itself.
        let source_stats = if let Some((ref source_x, ref source_y)) = self.source_domain_data {
            let source_feature_means = source_x.mean_axis(Axis(0)).unwrap();
            let source_feature_stds = source_x.std_axis(Axis(0), 0.0);
            let source_mean = source_y.mean().unwrap();
            let source_std = source_y.std(0.0);

            SourceDomainStats {
                feature_means: source_feature_means,
                feature_stds: source_feature_stds,
                class_priors: HashMap::new(),
                target_mean: source_mean,
                target_std: source_std,
            }
        } else {
            target_stats.clone()
        };

        let adaptation_matrix = self.compute_adaptation_matrix(&source_stats, &target_stats);
        let instance_weights = self.compute_instance_weights(x, &source_stats, &target_stats);

        Ok(FittedDomainAdaptationBaseline {
            strategy: self.strategy,
            source_stats,
            target_stats,
            adaptation_matrix,
            instance_weights,
            random_state: self.random_state,
        })
    }
}
832
impl DomainAdaptationBaseline {
    /// Build the per-feature adaptation matrix for the configured strategy,
    /// starting from the identity and perturbing it per strategy.
    fn compute_adaptation_matrix(
        &self,
        source_stats: &SourceDomainStats,
        target_stats: &SourceDomainStats,
    ) -> Array2<f64> {
        let n_features = source_stats.feature_means.len();
        let mut adaptation_matrix = Array2::eye(n_features);

        match &self.strategy {
            DomainAdaptationStrategy::FeatureAlignment => {
                // Diagonal rescaling by the target/source std ratio
                // (stds floored at 1e-8 to avoid division by zero).
                for i in 0..n_features {
                    let source_std = source_stats.feature_stds[i].max(1e-8);
                    let target_std = target_stats.feature_stds[i].max(1e-8);
                    adaptation_matrix[[i, i]] = target_std / source_std;
                }
            }
            DomainAdaptationStrategy::InstanceReweighting {
                adaptation_strength,
            } => {
                // Uniform scaling; the real work happens in the instance weights.
                adaptation_matrix *= *adaptation_strength;
            }
            DomainAdaptationStrategy::GradientReversal { lambda } => {
                // Shrink the diagonal by lambda.
                for i in 0..n_features {
                    adaptation_matrix[[i, i]] = 1.0 - *lambda;
                }
            }
            DomainAdaptationStrategy::SubspaceAlignment { subspace_dim } => {
                // Couple the first `subspace_dim` features via a small fixed
                // off-diagonal term.
                let dim = (*subspace_dim).min(n_features);
                for i in 0..dim {
                    for j in 0..dim {
                        if i != j {
                            adaptation_matrix[[i, j]] = 0.1;
                        }
                    }
                }
            }
            DomainAdaptationStrategy::MMDMinimization { bandwidth } => {
                // Exponentially down-weight features whose source/target
                // means diverge relative to the bandwidth.
                for i in 0..n_features {
                    let distance =
                        (source_stats.feature_means[i] - target_stats.feature_means[i]).abs();
                    let weight = (-distance / bandwidth).exp();
                    adaptation_matrix[[i, i]] = weight;
                }
            }
        }

        adaptation_matrix
    }

    /// Compute one weight per training sample. Only `InstanceReweighting`
    /// produces non-uniform weights; every other strategy returns all ones.
    fn compute_instance_weights(
        &self,
        x: &Array2<f64>,
        source_stats: &SourceDomainStats,
        target_stats: &SourceDomainStats,
    ) -> Array1<f64> {
        let mut weights = Array1::ones(x.nrows());

        match &self.strategy {
            DomainAdaptationStrategy::InstanceReweighting {
                adaptation_strength,
            } => {
                for (i, sample) in x.rows().into_iter().enumerate() {
                    // Euclidean distance from the sample to each domain's mean.
                    let source_distance: f64 = sample
                        .iter()
                        .zip(source_stats.feature_means.iter())
                        .map(|(x_val, mean_val)| (x_val - mean_val).powi(2))
                        .sum::<f64>()
                        .sqrt();

                    let target_distance: f64 = sample
                        .iter()
                        .zip(target_stats.feature_means.iter())
                        .map(|(x_val, mean_val)| (x_val - mean_val).powi(2))
                        .sum::<f64>()
                        .sqrt();

                    // Weight grows with relative distance to the target mean;
                    // a sample exactly at the source mean gets the full strength.
                    let weight = if source_distance > 0.0 {
                        adaptation_strength * target_distance / (source_distance + target_distance)
                    } else {
                        *adaptation_strength
                    };

                    // Keep weights in [0.1, 10.0].
                    weights[i] = weight.max(0.1).min(10.0);
                }
            }
            _ => {
                // All other strategies use uniform instance weights.
            }
        }

        weights
    }
}
933
impl Predict<Array2<f64>, Array1<f64>> for FittedDomainAdaptationBaseline {
    /// Predict for each row of `x` by blending the stored source/target means
    /// according to the strategy.
    ///
    /// NOTE(review): `adaptation_matrix` is computed at fit time but never
    /// consulted here, and only `MMDMinimization` looks at the (weighted)
    /// feature values at all — confirm this is intended.
    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>, SklearsError> {
        let mut predictions = Vec::with_capacity(x.nrows());

        for (i, sample) in x.rows().into_iter().enumerate() {
            // Scale the row by its stored instance weight when one exists;
            // rows beyond the training size fall back to the raw sample.
            let adapted_sample = if i < self.instance_weights.len() {
                let weight = self.instance_weights[i];
                sample.mapv(|val| val * weight)
            } else {
                sample.to_owned()
            };

            let prediction = match &self.strategy {
                DomainAdaptationStrategy::FeatureAlignment => {
                    // Constant prediction: the target-domain mean.
                    self.target_stats.target_mean
                }
                DomainAdaptationStrategy::InstanceReweighting { .. } => {
                    // Blend target/source means by the instance weight
                    // (0.5 for rows without a stored weight). Weights above
                    // 1.0 extrapolate beyond the target mean.
                    let source_contrib = self.source_stats.target_mean;
                    let target_contrib = self.target_stats.target_mean;
                    let weight = if i < self.instance_weights.len() {
                        self.instance_weights[i]
                    } else {
                        0.5
                    };
                    weight * target_contrib + (1.0 - weight) * source_contrib
                }
                DomainAdaptationStrategy::GradientReversal { lambda } => {
                    // lambda shifts mass from the target mean to the source mean.
                    let adversarial_weight = 1.0 - lambda;
                    adversarial_weight * self.target_stats.target_mean
                        + lambda * self.source_stats.target_mean
                }
                DomainAdaptationStrategy::SubspaceAlignment { .. } => {
                    // Midpoint of the two domain means.
                    (self.source_stats.target_mean + self.target_stats.target_mean) / 2.0
                }
                DomainAdaptationStrategy::MMDMinimization { .. } => {
                    // Average of the weighted-feature mean and the target mean.
                    let feature_weighted_prediction = adapted_sample.mean().unwrap_or(0.0);
                    (feature_weighted_prediction + self.target_stats.target_mean) / 2.0
                }
            };

            predictions.push(prediction);
        }

        Ok(Array1::from_vec(predictions))
    }
}
987
/// Unfitted continual-learning baseline.
#[derive(Debug, Clone)]
pub struct ContinualLearningBaseline {
    /// Strategy controlling memory and consolidation behavior.
    strategy: ContinualStrategy,
    /// Replay memory of `(sample, target)` pairs carried across tasks.
    memory_buffer: Vec<(Array1<f64>, f64)>,
    /// Summary statistics of each previously fitted task.
    task_statistics: Vec<SourceDomainStats>,
    /// Optional RNG seed for reproducibility.
    random_state: Option<u64>,
}
996
/// Fitted state produced by [`ContinualLearningBaseline`].
#[derive(Debug, Clone)]
pub struct FittedContinualLearningBaseline {
    /// Strategy copied from the unfitted estimator.
    strategy: ContinualStrategy,
    /// Replay memory of `(sample, target)` pairs after fitting this task.
    memory_buffer: Vec<(Array1<f64>, f64)>,
    /// Summary statistics of every task fitted so far (current task last).
    task_statistics: Vec<SourceDomainStats>,
    /// Per-feature consolidation weights derived from the strategy.
    consolidation_weights: Array1<f64>,
    /// Index of the most recently fitted task in `task_statistics`.
    current_task_id: usize,
    /// Optional RNG seed for reproducibility.
    random_state: Option<u64>,
}
1007
1008impl ContinualLearningBaseline {
1009 pub fn new(strategy: ContinualStrategy) -> Self {
1011 Self {
1012 strategy,
1013 memory_buffer: Vec::new(),
1014 task_statistics: Vec::new(),
1015 random_state: None,
1016 }
1017 }
1018
1019 pub fn with_random_state(mut self, seed: u64) -> Self {
1021 self.random_state = Some(seed);
1022 self
1023 }
1024
1025 pub fn with_task_memory(mut self, memory: Vec<(Array1<f64>, f64)>) -> Self {
1027 self.memory_buffer = memory;
1028 self
1029 }
1030}
1031
1032impl Fit<Array2<f64>, Array1<f64>, FittedContinualLearningBaseline> for ContinualLearningBaseline {
1033 type Fitted = FittedContinualLearningBaseline;
1034
1035 fn fit(
1036 self,
1037 x: &Array2<f64>,
1038 y: &Array1<f64>,
1039 ) -> Result<FittedContinualLearningBaseline, SklearsError> {
1040 if x.nrows() != y.len() {
1041 return Err(SklearsError::ShapeMismatch {
1042 expected: format!("{} samples", x.nrows()),
1043 actual: format!("{} labels", y.len()),
1044 });
1045 }
1046
1047 let feature_means = x.mean_axis(Axis(0)).unwrap();
1049 let feature_stds = x.std_axis(Axis(0), 0.0);
1050 let target_mean = y.mean().unwrap();
1051 let target_std = y.std(0.0);
1052
1053 let current_task_stats = SourceDomainStats {
1054 feature_means,
1055 feature_stds,
1056 class_priors: HashMap::new(),
1057 target_mean,
1058 target_std,
1059 };
1060
1061 let mut task_statistics = self.task_statistics.clone();
1062 task_statistics.push(current_task_stats.clone());
1063
1064 let mut memory_buffer = self.memory_buffer.clone();
1066 match &self.strategy {
1067 ContinualStrategy::ElasticWeightConsolidation { .. } => {
1068 for i in 0..x.nrows().min(100) {
1070 memory_buffer.push((x.row(i).to_owned(), y[i]));
1072 }
1073 }
1074 ContinualStrategy::Rehearsal { memory_size } => {
1075 for i in 0..x.nrows() {
1077 memory_buffer.push((x.row(i).to_owned(), y[i]));
1078 if memory_buffer.len() > *memory_size {
1079 memory_buffer.remove(0); }
1081 }
1082 }
1083 ContinualStrategy::Progressive { .. } => {
1084 for i in 0..x.nrows() {
1086 memory_buffer.push((x.row(i).to_owned(), y[i]));
1087 }
1088 }
1089 ContinualStrategy::LearningWithoutForgetting { .. } => {
1090 let samples_to_store = x.nrows().min(50);
1092 for i in 0..samples_to_store {
1093 memory_buffer.push((x.row(i).to_owned(), y[i]));
1094 }
1095 }
1096 ContinualStrategy::AGEM { .. } => {
1097 let memory_samples = x.nrows().min(20);
1099 for i in 0..memory_samples {
1100 memory_buffer.push((x.row(i).to_owned(), y[i]));
1101 }
1102 }
1103 }
1104
1105 let consolidation_weights =
1107 self.compute_consolidation_weights(&task_statistics, ¤t_task_stats);
1108
1109 let current_task_id = task_statistics.len() - 1;
1110
1111 Ok(FittedContinualLearningBaseline {
1112 strategy: self.strategy,
1113 memory_buffer,
1114 task_statistics,
1115 consolidation_weights,
1116 current_task_id,
1117 random_state: self.random_state,
1118 })
1119 }
1120}
1121
1122impl ContinualLearningBaseline {
1123 fn compute_consolidation_weights(
1124 &self,
1125 task_stats: &[SourceDomainStats],
1126 current_stats: &SourceDomainStats,
1127 ) -> Array1<f64> {
1128 let n_features = current_stats.feature_means.len();
1129 let mut weights = Array1::ones(n_features);
1130
1131 match &self.strategy {
1132 ContinualStrategy::ElasticWeightConsolidation { importance_weight } => {
1133 for i in 0..n_features {
1135 let feature_variance: f64 = task_stats
1136 .iter()
1137 .map(|stats| stats.feature_means[i])
1138 .map(|mean| (mean - current_stats.feature_means[i]).powi(2))
1139 .sum::<f64>()
1140 / task_stats.len().max(1) as f64;
1141
1142 weights[i] = importance_weight * feature_variance.sqrt();
1143 }
1144 }
1145 ContinualStrategy::Rehearsal { memory_size: _ } => {
1146 weights.fill(1.0);
1148 }
1149 ContinualStrategy::Progressive { column_capacity } => {
1150 let capacity_weight = 1.0 / (*column_capacity as f64);
1152 weights.fill(capacity_weight);
1153 }
1154 ContinualStrategy::LearningWithoutForgetting {
1155 distillation_weight,
1156 } => {
1157 weights.fill(*distillation_weight);
1159 }
1160 ContinualStrategy::AGEM { memory_strength } => {
1161 weights.fill(*memory_strength);
1163 }
1164 }
1165
1166 weights
1167 }
1168}
1169
1170impl Predict<Array2<f64>, Array1<f64>> for FittedContinualLearningBaseline {
1171 fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>, SklearsError> {
1172 let mut predictions = Vec::with_capacity(x.nrows());
1173
1174 for sample in x.rows() {
1175 let prediction = match &self.strategy {
1176 ContinualStrategy::ElasticWeightConsolidation { .. } => {
1177 let mut weighted_sum = 0.0;
1179 let mut total_weight = 0.0;
1180
1181 for (task_id, task_stats) in self.task_statistics.iter().enumerate() {
1182 let task_weight = if task_id < self.consolidation_weights.len() {
1183 self.consolidation_weights[task_id]
1184 } else {
1185 1.0
1186 };
1187
1188 weighted_sum += task_weight * task_stats.target_mean;
1189 total_weight += task_weight;
1190 }
1191
1192 if total_weight > 0.0 {
1193 weighted_sum / total_weight
1194 } else {
1195 self.task_statistics
1196 .last()
1197 .map_or(0.0, |stats| stats.target_mean)
1198 }
1199 }
1200 ContinualStrategy::Rehearsal { .. } => {
1201 if !self.memory_buffer.is_empty() {
1203 let memory_prediction: f64 =
1204 self.memory_buffer.iter().map(|(_, y)| *y).sum::<f64>()
1205 / self.memory_buffer.len() as f64;
1206
1207 let current_prediction = self
1208 .task_statistics
1209 .last()
1210 .map_or(0.0, |stats| stats.target_mean);
1211
1212 (memory_prediction + current_prediction) / 2.0
1213 } else {
1214 self.task_statistics
1215 .last()
1216 .map_or(0.0, |stats| stats.target_mean)
1217 }
1218 }
1219 ContinualStrategy::Progressive { .. } => {
1220 let task_predictions: f64 = self
1222 .task_statistics
1223 .iter()
1224 .map(|stats| stats.target_mean)
1225 .sum();
1226
1227 if !self.task_statistics.is_empty() {
1228 task_predictions / self.task_statistics.len() as f64
1229 } else {
1230 0.0
1231 }
1232 }
1233 ContinualStrategy::LearningWithoutForgetting {
1234 distillation_weight,
1235 } => {
1236 let current_prediction = self
1238 .task_statistics
1239 .last()
1240 .map_or(0.0, |stats| stats.target_mean);
1241
1242 if self.task_statistics.len() > 1 {
1243 let previous_predictions: f64 = self
1244 .task_statistics
1245 .iter()
1246 .take(self.task_statistics.len() - 1)
1247 .map(|stats| stats.target_mean)
1248 .sum::<f64>()
1249 / (self.task_statistics.len() - 1) as f64;
1250
1251 distillation_weight * previous_predictions
1252 + (1.0 - distillation_weight) * current_prediction
1253 } else {
1254 current_prediction
1255 }
1256 }
1257 ContinualStrategy::AGEM { .. } => {
1258 if !self.memory_buffer.is_empty() {
1260 let mut min_distance = f64::INFINITY;
1262 let mut nearest_value = 0.0;
1263
1264 for (memory_sample, memory_value) in &self.memory_buffer {
1265 let distance: f64 = sample
1266 .iter()
1267 .zip(memory_sample.iter())
1268 .map(|(a, b)| (a - b).powi(2))
1269 .sum::<f64>()
1270 .sqrt();
1271
1272 if distance < min_distance {
1273 min_distance = distance;
1274 nearest_value = *memory_value;
1275 }
1276 }
1277
1278 nearest_value
1279 } else {
1280 self.task_statistics
1281 .last()
1282 .map_or(0.0, |stats| stats.target_mean)
1283 }
1284 }
1285 };
1286
1287 predictions.push(prediction);
1288 }
1289
1290 Ok(Array1::from_vec(predictions))
1291 }
1292}
1293
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_few_shot_classifier() {
        // Three well-separated classes with two support samples each.
        let features = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [5.0, 5.0],
            [5.1, 5.1],
            [3.0, 3.0],
            [3.1, 3.1]
        ];
        let labels = array![0, 0, 1, 1, 2, 2];

        let fitted = FewShotBaselineClassifier::new(FewShotStrategy::NearestPrototype)
            .fit(&features, &labels)
            .unwrap();

        assert_eq!(fitted.predict(&features).unwrap().len(), 6);
    }

    #[test]
    fn test_few_shot_regressor() {
        let features = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]];
        let targets = array![1.0, 2.0, 3.0, 4.0];

        let fitted = FewShotBaselineRegressor::new(FewShotStrategy::KNearestNeighbors { k: 2 })
            .fit(&features, &targets)
            .unwrap();

        assert_eq!(fitted.predict(&features).unwrap().len(), 4);
    }

    #[test]
    fn test_transfer_learning() {
        let features = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]];
        let targets = array![1.0, 2.0, 3.0, 4.0];

        // Statistics gathered on a (synthetic) related source domain.
        let source_stats = SourceDomainStats {
            feature_means: array![2.0, 2.0],
            feature_stds: array![1.0, 1.0],
            class_priors: HashMap::new(),
            target_mean: 2.0,
            target_std: 1.0,
        };

        let fitted =
            TransferLearningBaseline::new(TransferStrategy::SourcePrior { source_weight: 0.3 })
                .with_source_statistics(source_stats)
                .fit(&features, &targets)
                .unwrap();

        assert_eq!(fitted.predict(&features).unwrap().len(), 4);
    }

    #[test]
    fn test_few_shot_strategies() {
        let features = array![[1.0, 1.0], [2.0, 2.0], [5.0, 5.0], [6.0, 6.0]];
        let labels = array![0, 0, 1, 1];

        // Every classification strategy must fit and yield one prediction per row.
        for strategy in [
            FewShotStrategy::NearestPrototype,
            FewShotStrategy::KNearestNeighbors { k: 2 },
            FewShotStrategy::SupportBased,
            FewShotStrategy::Centroid,
            FewShotStrategy::Probabilistic,
        ] {
            let fitted = FewShotBaselineClassifier::new(strategy)
                .with_random_state(42)
                .fit(&features, &labels)
                .unwrap();
            assert_eq!(fitted.predict(&features).unwrap().len(), 4);
        }
    }

    #[test]
    fn test_transfer_strategies() {
        let features = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]];
        let targets = array![1.0, 2.0, 3.0, 4.0];

        for strategy in [
            TransferStrategy::SourcePrior { source_weight: 0.5 },
            TransferStrategy::FeatureBased {
                adaptation_rate: 0.3,
            },
            TransferStrategy::InstanceBased {
                similarity_threshold: 0.8,
            },
            TransferStrategy::ModelBased {
                confidence_threshold: 0.7,
            },
            TransferStrategy::EnsembleTransfer {
                domain_weights: vec![0.6],
            },
        ] {
            let fitted = TransferLearningBaseline::new(strategy)
                .fit(&features, &targets)
                .unwrap();
            assert_eq!(fitted.predict(&features).unwrap().len(), 4);
        }
    }

    #[test]
    fn test_domain_adaptation() {
        let target_x = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]];
        let target_y = array![1.0, 2.0, 3.0, 4.0];

        let source_x = array![[0.5, 0.5], [1.5, 1.5], [2.5, 2.5]];
        let source_y = array![0.5, 1.5, 2.5];

        let fitted = DomainAdaptationBaseline::new(DomainAdaptationStrategy::FeatureAlignment)
            .with_source_domain_data(source_x, source_y)
            .fit(&target_x, &target_y)
            .unwrap();

        assert_eq!(fitted.predict(&target_x).unwrap().len(), 4);
        assert_eq!(fitted.adaptation_matrix.shape(), &[2, 2]);
    }

    #[test]
    fn test_domain_adaptation_strategies() {
        let features = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]];
        let targets = array![1.0, 2.0, 3.0, 4.0];

        for strategy in [
            DomainAdaptationStrategy::FeatureAlignment,
            DomainAdaptationStrategy::InstanceReweighting {
                adaptation_strength: 0.8,
            },
            DomainAdaptationStrategy::GradientReversal { lambda: 0.1 },
            DomainAdaptationStrategy::SubspaceAlignment { subspace_dim: 2 },
            DomainAdaptationStrategy::MMDMinimization { bandwidth: 1.0 },
        ] {
            let fitted = DomainAdaptationBaseline::new(strategy)
                .fit(&features, &targets)
                .unwrap();
            assert_eq!(fitted.predict(&features).unwrap().len(), 4);
        }
    }

    #[test]
    fn test_continual_learning() {
        let features = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]];
        let targets = array![1.0, 2.0, 3.0, 4.0];

        let fitted =
            ContinualLearningBaseline::new(ContinualStrategy::ElasticWeightConsolidation {
                importance_weight: 1.0,
            })
            .fit(&features, &targets)
            .unwrap();

        assert_eq!(fitted.predict(&features).unwrap().len(), 4);
        // A single fit call registers exactly one task.
        assert_eq!(fitted.task_statistics.len(), 1);
    }

    #[test]
    fn test_continual_strategies() {
        let features = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]];
        let targets = array![1.0, 2.0, 3.0, 4.0];

        for strategy in [
            ContinualStrategy::ElasticWeightConsolidation {
                importance_weight: 1.0,
            },
            ContinualStrategy::Rehearsal { memory_size: 100 },
            ContinualStrategy::Progressive {
                column_capacity: 10,
            },
            ContinualStrategy::LearningWithoutForgetting {
                distillation_weight: 0.8,
            },
            ContinualStrategy::AGEM {
                memory_strength: 0.5,
            },
        ] {
            let fitted = ContinualLearningBaseline::new(strategy)
                .with_random_state(42)
                .fit(&features, &targets)
                .unwrap();
            assert_eq!(fitted.predict(&features).unwrap().len(), 4);
        }
    }
}