//! Transfer learning estimators: cross-task transfer with a learned transfer
//! matrix, adversarial domain adaptation, progressive transfer, continual
//! learning with an EWC-style quadratic penalty, and knowledge distillation.
//! All models here are linear and trained with plain batch gradient descent.

use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
use scirs2_core::random::thread_rng;
use scirs2_core::random::RandNormal;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Untrained},
    types::Float,
};

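/// Cross-task transfer learning with a learned task-to-task transfer matrix.
///
/// Source weights, target weights, and the transfer matrix are all linear
/// maps fitted jointly by batch gradient descent (see `fit`).
///
/// A minimal usage sketch; the inputs are illustrative toy arrays, not a
/// real dataset:
///
/// ```rust,ignore
/// use scirs2_core::ndarray::array;
///
/// let source_X = array![[1.0, 2.0], [2.0, 3.0]];
/// let source_y = array![[1.0, 0.0], [0.0, 1.0]];
/// let target_X = array![[1.1, 2.1], [2.1, 3.1]];
/// let target_y = array![[1.0, 0.0], [0.0, 1.0]];
///
/// let trained = CrossTaskTransferLearning::new()
///     .transfer_strength(0.5)
///     .max_iter(200)
///     .fit(&source_X.view(), &source_y.view(), &target_X.view(), &target_y.view())?;
/// let predictions = trained.predict(&target_X.view())?;
/// ```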
#[derive(Debug, Clone)]
pub struct CrossTaskTransferLearning<S = Untrained> {
    state: S,
    transfer_strength: Float,
    learning_rate: Float,
    max_iter: usize,
    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct CrossTaskTransferLearningTrained {
    source_weights: Array2<Float>,
    target_weights: Array2<Float>,
    transfer_matrix: Array2<Float>,
    n_features: usize,
    n_source_tasks: usize,
    n_target_tasks: usize,
}

impl CrossTaskTransferLearning<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            transfer_strength: 0.5,
            learning_rate: 0.01,
            max_iter: 1000,
            random_state: None,
        }
    }

    pub fn transfer_strength(mut self, strength: Float) -> Self {
        self.transfer_strength = strength;
        self
    }

    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn random_state(mut self, seed: Option<u64>) -> Self {
        self.random_state = seed;
        self
    }

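    /// Fits the source weights `W_s`, the target weights `W_t`, and the
    /// transfer matrix `T` by batch gradient descent. Up to the heuristic
    /// coupling step noted in the loop body, the updates descend the
    /// combined squared loss (notation inferred from the code, with
    /// `lambda = transfer_strength`):
    ///
    /// ```text
    /// L = ||X_s W_s - Y_s||^2 / n_s
    ///   + ||X_t W_t - Y_t||^2 / n_t
    ///   + lambda * ||X_t W_s T - Y_t||^2 / n_t
    /// ```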
    #[allow(non_snake_case)]
    pub fn fit(
        &self,
        source_X: &ArrayView2<Float>,
        source_y: &ArrayView2<Float>,
        target_X: &ArrayView2<Float>,
        target_y: &ArrayView2<Float>,
    ) -> SklResult<CrossTaskTransferLearning<CrossTaskTransferLearningTrained>> {
        let n_source_samples = source_X.nrows();
        let n_target_samples = target_X.nrows();
        let n_features = source_X.ncols();
        let n_source_tasks = source_y.ncols();
        let n_target_tasks = target_y.ncols();

        if source_X.ncols() != target_X.ncols() {
            return Err(SklearsError::InvalidInput(
                "Source and target data must have the same number of features".to_string(),
            ));
        }

        if n_source_samples != source_y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of source samples must match source labels".to_string(),
            ));
        }

        if n_target_samples != target_y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of target samples must match target labels".to_string(),
            ));
        }

        // NOTE: `random_state` is stored for API compatibility but is not yet
        // wired into the RNG; initialization currently uses `thread_rng()`.
        let mut rng = thread_rng();

        // Small Gaussian initialization; N(0, 0.1) is always a valid
        // distribution, so the unwrap cannot fail.
        let normal_dist = RandNormal::new(0.0, 0.1).unwrap();

        let mut source_weights = Array2::<Float>::zeros((n_features, n_source_tasks));
        for i in 0..n_features {
            for j in 0..n_source_tasks {
                source_weights[[i, j]] = rng.sample(normal_dist);
            }
        }

        let mut target_weights = Array2::<Float>::zeros((n_features, n_target_tasks));
        for i in 0..n_features {
            for j in 0..n_target_tasks {
                target_weights[[i, j]] = rng.sample(normal_dist);
            }
        }

        let mut transfer_matrix = Array2::<Float>::zeros((n_source_tasks, n_target_tasks));
        for i in 0..n_source_tasks {
            for j in 0..n_target_tasks {
                transfer_matrix[[i, j]] = rng.sample(normal_dist);
            }
        }

        for _ in 0..self.max_iter {
            // Gradient step on the source-task loss.
            let source_pred = source_X.dot(&source_weights);
            let source_error = &source_pred - source_y;
            let source_grad = source_X.t().dot(&source_error) / n_source_samples as Float;
            source_weights -= &(source_grad * self.learning_rate);

            // Target loss plus the transferred-prediction loss X_t W_s T.
            let target_pred = target_X.dot(&target_weights);
            let transferred_pred = target_X.dot(&source_weights).dot(&transfer_matrix);
            let target_error = &target_pred - target_y;
            let transfer_error = &transferred_pred - target_y;

            let target_grad = target_X.t().dot(&target_error) / n_target_samples as Float;
            let transfer_grad = target_X.t().dot(&transfer_error) / n_target_samples as Float;

            target_weights -= &(target_grad * self.learning_rate);
            // Heuristic coupling: also nudge the target weights in the
            // direction that reduces the transferred-prediction error.
            target_weights -= &(transfer_grad * self.learning_rate * self.transfer_strength);

            // Gradient of the transfer loss with respect to T.
            let transfer_matrix_grad =
                target_X.dot(&source_weights).t().dot(&transfer_error) / n_target_samples as Float;
            transfer_matrix -=
                &(transfer_matrix_grad * self.learning_rate * self.transfer_strength);
        }

        Ok(CrossTaskTransferLearning {
            state: CrossTaskTransferLearningTrained {
                source_weights,
                target_weights,
                transfer_matrix,
                n_features,
                n_source_tasks,
                n_target_tasks,
            },
            transfer_strength: self.transfer_strength,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            random_state: self.random_state,
        })
    }
}

impl CrossTaskTransferLearning<CrossTaskTransferLearningTrained> {
    /// Predicts target-task outputs with the learned target weights.
    #[allow(non_snake_case)]
    pub fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        let target_pred = X.dot(&self.state.target_weights);
        Ok(target_pred)
    }

    /// Predicts target-task outputs by routing through the source weights
    /// and the learned transfer matrix (`X W_s T`).
    #[allow(non_snake_case)]
    pub fn predict_from_source(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        let source_pred = X.dot(&self.state.source_weights);
        let transferred_pred = source_pred.dot(&self.state.transfer_matrix);
        Ok(transferred_pred)
    }

    /// Returns the learned source-to-target transfer matrix.
    pub fn transfer_matrix(&self) -> &Array2<Float> {
        &self.state.transfer_matrix
    }

    /// Returns the learned source-task weights.
    pub fn source_weights(&self) -> &Array2<Float> {
        &self.state.source_weights
    }

    /// Returns the learned target-task weights.
    pub fn target_weights(&self) -> &Array2<Float> {
        &self.state.target_weights
    }
}

impl Default for CrossTaskTransferLearning<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for CrossTaskTransferLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Estimator for CrossTaskTransferLearning<CrossTaskTransferLearningTrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

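/// Adversarial domain adaptation with a shared linear feature extractor, a
/// task classifier trained on labelled source data, and a domain
/// discriminator trained to separate source from target samples. The
/// extractor is pushed against the discriminator's gradient, encouraging
/// domain-invariant features (a linear, DANN-style scheme).
///
/// A minimal usage sketch with illustrative toy arrays:
///
/// ```rust,ignore
/// use scirs2_core::ndarray::array;
///
/// let source_X = array![[1.0, 2.0], [2.0, 3.0]];
/// let source_y = array![[1.0], [0.0]];
/// let target_X = array![[1.1, 2.1], [2.1, 3.1]];
/// let target_y = array![[1.0], [0.0]];
///
/// let trained = DomainAdaptation::new()
///     .adaptation_strength(0.3)
///     .fit(&source_X.view(), &source_y.view(), &target_X.view(), &target_y.view())?;
/// let preds = trained.predict(&target_X.view())?;
/// let domain_scores = trained.predict_domain(&target_X.view())?;
/// ```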
#[derive(Debug, Clone)]
pub struct DomainAdaptation<S = Untrained> {
    state: S,
    adaptation_strength: Float,
    learning_rate: Float,
    max_iter: usize,
    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct DomainAdaptationTrained {
    feature_extractor: Array2<Float>,
    classifier: Array2<Float>,
    domain_discriminator: Array2<Float>,
    n_features: usize,
    n_tasks: usize,
}

impl DomainAdaptation<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            adaptation_strength: 0.3,
            learning_rate: 0.01,
            max_iter: 1000,
            random_state: None,
        }
    }

    pub fn adaptation_strength(mut self, strength: Float) -> Self {
        self.adaptation_strength = strength;
        self
    }

    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn random_state(mut self, seed: Option<u64>) -> Self {
        self.random_state = seed;
        self
    }

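    /// Trains the extractor, classifier, and discriminator jointly. Each
    /// iteration takes one gradient step on the source classification loss,
    /// one on the domain-discrimination loss (source = 0, target = 1), and
    /// then updates the feature extractor to minimize the former while
    /// maximizing the latter, scaled by `adaptation_strength`.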
    #[allow(non_snake_case)]
    pub fn fit(
        &self,
        source_X: &ArrayView2<Float>,
        source_y: &ArrayView2<Float>,
        target_X: &ArrayView2<Float>,
        target_y: &ArrayView2<Float>,
    ) -> SklResult<DomainAdaptation<DomainAdaptationTrained>> {
        let n_source_samples = source_X.nrows();
        let n_target_samples = target_X.nrows();
        let n_features = source_X.ncols();
        let n_tasks = source_y.ncols();

        if source_X.ncols() != target_X.ncols() {
            return Err(SklearsError::InvalidInput(
                "Source and target data must have the same number of features".to_string(),
            ));
        }

        if n_source_samples != source_y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of source samples must match source labels".to_string(),
            ));
        }

        if n_target_samples != target_y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of target samples must match target labels".to_string(),
            ));
        }

        // NOTE: `random_state` is stored for API compatibility but is not yet
        // wired into the RNG; initialization currently uses `thread_rng()`.
        let mut rng = thread_rng();

        // Small Gaussian initialization for all three linear maps.
        let hidden_dim = (n_features + n_tasks) / 2;
        let normal_dist = RandNormal::new(0.0, 0.1).unwrap();
        let mut feature_extractor = Array2::<Float>::zeros((n_features, hidden_dim));
        for i in 0..n_features {
            for j in 0..hidden_dim {
                feature_extractor[[i, j]] = rng.sample(normal_dist);
            }
        }
        let mut classifier = Array2::<Float>::zeros((hidden_dim, n_tasks));
        for i in 0..hidden_dim {
            for j in 0..n_tasks {
                classifier[[i, j]] = rng.sample(normal_dist);
            }
        }
        let mut domain_discriminator = Array2::<Float>::zeros((hidden_dim, 1));
        for i in 0..hidden_dim {
            domain_discriminator[[i, 0]] = rng.sample(normal_dist);
        }

        // Domain labels: 0 for source samples, 1 for target samples.
        let mut domain_labels = Array2::<Float>::zeros((n_source_samples + n_target_samples, 1));
        for i in n_source_samples..(n_source_samples + n_target_samples) {
            domain_labels[(i, 0)] = 1.0;
        }

        // Stack source and target rows into one design matrix.
        let mut combined_X =
            Array2::<Float>::zeros((n_source_samples + n_target_samples, n_features));
        combined_X
            .slice_mut(s![..n_source_samples, ..])
            .assign(source_X);
        combined_X
            .slice_mut(s![n_source_samples.., ..])
            .assign(target_X);

        for _ in 0..self.max_iter {
            // Shared feature extraction over the combined data.
            let features = combined_X.dot(&feature_extractor);
            let source_features = features.slice(s![..n_source_samples, ..]);
            // Target slice kept for symmetry; not used by the updates below.
            let _target_features = features.slice(s![n_source_samples.., ..]);

            // Classifier step on the labelled source portion.
            let source_pred = source_features.dot(&classifier);
            let classification_error = &source_pred - source_y;
            let classifier_grad =
                source_features.t().dot(&classification_error) / n_source_samples as Float;
            classifier -= &(&classifier_grad * self.learning_rate);

            // Discriminator step on the source-vs-target labels.
            let domain_pred = features.dot(&domain_discriminator);
            let domain_error = &domain_pred - &domain_labels;
            let discriminator_grad =
                features.t().dot(&domain_error) / (n_source_samples + n_target_samples) as Float;
            domain_discriminator -= &(&discriminator_grad * self.learning_rate);

            // Feature-extractor step: back-propagate both losses through the
            // linear layers via the chain rule, descending on the
            // classification loss and ascending on the domain loss
            // (gradient-reversal style). Both gradients have the extractor's
            // (n_features, hidden_dim) shape.
            let feat_class_grad = combined_X
                .slice(s![..n_source_samples, ..])
                .t()
                .dot(&classification_error.dot(&classifier.t()))
                / n_source_samples as Float;
            let feat_domain_grad = combined_X
                .t()
                .dot(&domain_error.dot(&domain_discriminator.t()))
                / (n_source_samples + n_target_samples) as Float;

            feature_extractor -= &(feat_class_grad * self.learning_rate);
            feature_extractor +=
                &(feat_domain_grad * self.learning_rate * self.adaptation_strength);
        }

        Ok(DomainAdaptation {
            state: DomainAdaptationTrained {
                feature_extractor,
                classifier,
                domain_discriminator,
                n_features,
                n_tasks,
            },
            adaptation_strength: self.adaptation_strength,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            random_state: self.random_state,
        })
    }
}

impl DomainAdaptation<DomainAdaptationTrained> {
    /// Predicts task outputs from the adapted feature space.
    #[allow(non_snake_case)]
    pub fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        let features = X.dot(&self.state.feature_extractor);
        let predictions = features.dot(&self.state.classifier);
        Ok(predictions)
    }

    /// Maps inputs into the learned (domain-invariant) feature space.
    #[allow(non_snake_case)]
    pub fn extract_features(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        let features = X.dot(&self.state.feature_extractor);
        Ok(features)
    }

    /// Returns the discriminator's raw linear scores (not probabilities);
    /// larger values indicate the target domain.
    #[allow(non_snake_case)]
    pub fn predict_domain(&self, X: &ArrayView2<Float>) -> SklResult<Array1<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        let features = X.dot(&self.state.feature_extractor);
        let domain_pred = features.dot(&self.state.domain_discriminator);
        Ok(domain_pred.column(0).to_owned())
    }
}

impl Default for DomainAdaptation<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for DomainAdaptation<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Estimator for DomainAdaptation<DomainAdaptationTrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

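/// Progressive transfer learning: tasks are fitted one at a time in a given
/// order, each with its own head on top of a shared linear transformation.
/// From the second task onward the shared weights are also updated, so
/// structure learned on earlier tasks transfers to later ones.
///
/// A minimal usage sketch with illustrative toy arrays:
///
/// ```rust,ignore
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let y = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
///
/// // `None` trains the tasks in natural order 0, 1, ...
/// let trained = ProgressiveTransferLearning::new()
///     .transfer_strength(0.4)
///     .fit(&X.view(), &y.view(), None)?;
/// let predictions = trained.predict(&X.view())?;
/// ```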
#[derive(Debug, Clone)]
pub struct ProgressiveTransferLearning<S = Untrained> {
    state: S,
    transfer_strength: Float,
    learning_rate: Float,
    max_iter: usize,
    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct ProgressiveTransferLearningTrained {
    task_weights: Vec<Array2<Float>>,
    shared_weights: Array2<Float>,
    task_order: Vec<usize>,
    n_features: usize,
    n_tasks: usize,
}

impl ProgressiveTransferLearning<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            transfer_strength: 0.4,
            learning_rate: 0.01,
            max_iter: 500,
            random_state: None,
        }
    }

    pub fn transfer_strength(mut self, strength: Float) -> Self {
        self.transfer_strength = strength;
        self
    }

    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn random_state(mut self, seed: Option<u64>) -> Self {
        self.random_state = seed;
        self
    }

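    /// Fits one head per column of `y`, visiting tasks in `task_order`
    /// (natural order if `None`). The shared weights are frozen for the
    /// first task and thereafter receive transfer updates scaled by
    /// `transfer_strength`. Note that updating the shared weights can shift
    /// predictions for already-fitted heads; this simplified scheme accepts
    /// that drift.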
    #[allow(non_snake_case)]
    pub fn fit(
        &self,
        X: &ArrayView2<Float>,
        y: &ArrayView2<Float>,
        task_order: Option<Vec<usize>>,
    ) -> SklResult<ProgressiveTransferLearning<ProgressiveTransferLearningTrained>> {
        let n_samples = X.nrows();
        let n_features = X.ncols();
        let n_tasks = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of samples must match number of labels".to_string(),
            ));
        }

        // NOTE: `random_state` is stored for API compatibility but is not yet
        // wired into the RNG; initialization currently uses `thread_rng()`.
        let mut rng = thread_rng();

        // Default to the natural task order; reject out-of-range indices,
        // which would otherwise panic when slicing `y`.
        let task_order = task_order.unwrap_or_else(|| (0..n_tasks).collect());
        if task_order.iter().any(|&t| t >= n_tasks) {
            return Err(SklearsError::InvalidInput(
                "task_order contains an index outside 0..n_tasks".to_string(),
            ));
        }

        let normal_dist = RandNormal::new(0.0, 0.1).unwrap();
        let mut shared_weights = Array2::<Float>::zeros((n_features, n_features));
        for i in 0..n_features {
            for j in 0..n_features {
                shared_weights[[i, j]] = rng.sample(normal_dist);
            }
        }

        let mut task_weights = Vec::with_capacity(n_tasks);

        for &task_idx in &task_order {
            let task_y = y.column(task_idx);

            let mut task_weight = Array2::<Float>::zeros((n_features, 1));
            for i in 0..n_features {
                task_weight[[i, 0]] = rng.sample(normal_dist);
            }

            for _ in 0..self.max_iter {
                let shared_features = X.dot(&shared_weights);

                let task_pred = shared_features.dot(&task_weight);
                let task_error = &task_pred.column(0) - &task_y;

                let task_error_2d = task_error.insert_axis(Axis(1));
                let task_grad = shared_features.t().dot(&task_error_2d) / n_samples as Float;
                task_weight -= &(&task_grad * self.learning_rate);

                // From the second task onward, also adapt the shared weights
                // so that structure learned on earlier tasks transfers.
                if !task_weights.is_empty() {
                    let shared_grad =
                        X.t().dot(&task_error_2d.dot(&task_weight.t())) / n_samples as Float;
                    shared_weights -= &(shared_grad * self.learning_rate * self.transfer_strength);
                }
            }

            task_weights.push(task_weight);
        }

        Ok(ProgressiveTransferLearning {
            state: ProgressiveTransferLearningTrained {
                task_weights,
                shared_weights,
                task_order,
                n_features,
                n_tasks,
            },
            transfer_strength: self.transfer_strength,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            random_state: self.random_state,
        })
    }
}

impl ProgressiveTransferLearning<ProgressiveTransferLearningTrained> {
    /// Predicts all tasks; each head is applied to the shared features and
    /// written back to its original column position.
    #[allow(non_snake_case)]
    pub fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        let n_samples = X.nrows();
        let shared_features = X.dot(&self.state.shared_weights);
        let mut predictions = Array2::<Float>::zeros((n_samples, self.state.n_tasks));

        for (i, &task_idx) in self.state.task_order.iter().enumerate() {
            let task_pred = shared_features.dot(&self.state.task_weights[i]);
            predictions
                .column_mut(task_idx)
                .assign(&task_pred.column(0));
        }

        Ok(predictions)
    }

    /// Returns the shared transformation learned across tasks.
    pub fn shared_weights(&self) -> &Array2<Float> {
        &self.state.shared_weights
    }

    /// Returns the per-task heads, in training (`task_order`) order.
    pub fn task_weights(&self) -> &Vec<Array2<Float>> {
        &self.state.task_weights
    }

    /// Returns the order in which the tasks were trained.
    pub fn task_order(&self) -> &Vec<usize> {
        &self.state.task_order
    }
}

impl Default for ProgressiveTransferLearning<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for ProgressiveTransferLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Estimator for ProgressiveTransferLearning<ProgressiveTransferLearningTrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

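/// Continual learning over a sequence of tasks with an EWC-style quadratic
/// penalty: after each task, a Fisher-information estimate marks which
/// weights mattered, and subsequent tasks are penalized for moving them.
///
/// A minimal usage sketch with illustrative toy arrays:
///
/// ```rust,ignore
/// use scirs2_core::ndarray::array;
///
/// let X1 = array![[1.0, 2.0], [2.0, 3.0]];
/// let y1 = array![[1.0, 0.0], [0.0, 1.0]];
/// let X2 = array![[3.0, 1.0], [1.0, 3.0]];
/// let y2 = array![[1.0, 1.0], [0.0, 0.0]];
///
/// let trained = ContinualLearning::new()
///     .importance_weight(1000.0)
///     .fit(&[X1.view(), X2.view()], &[y1.view(), y2.view()])?;
/// let predictions = trained.predict(&X1.view())?;
/// ```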
#[derive(Debug, Clone)]
pub struct ContinualLearning<S = Untrained> {
    state: S,
    importance_weight: Float,
    learning_rate: Float,
    max_iter: usize,
    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct ContinualLearningTrained {
    task_weights: Vec<Array2<Float>>,
    fisher_information: Array2<Float>,
    optimal_weights: Array2<Float>,
    n_features: usize,
    n_tasks: usize,
}

impl Default for ContinualLearning<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl ContinualLearning<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            importance_weight: 1000.0,
            learning_rate: 0.01,
            max_iter: 1000,
            random_state: None,
        }
    }

    pub fn importance_weight(mut self, weight: Float) -> Self {
        self.importance_weight = weight;
        self
    }

    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn random_state(mut self, seed: Option<u64>) -> Self {
        self.random_state = seed;
        self
    }

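    /// Fits the tasks sequentially. For task `k > 0` the gradient step adds
    /// an elastic-weight-consolidation-style penalty term
    ///
    /// ```text
    /// grad_penalty = importance_weight * F ⊙ (W - W_prev)
    /// ```
    ///
    /// where `F` is the accumulated Fisher-information estimate, `⊙` the
    /// element-wise product, and `W_prev` the weights after the previous
    /// task. The Fisher estimate here is a simplified squared-residual
    /// statistic, not the exact Fisher information.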
    #[allow(non_snake_case)]
    pub fn fit(
        &self,
        tasks_X: &[ArrayView2<Float>],
        tasks_y: &[ArrayView2<Float>],
    ) -> SklResult<ContinualLearning<ContinualLearningTrained>> {
        if tasks_X.len() != tasks_y.len() {
            return Err(SklearsError::InvalidInput(
                "Number of X and y task arrays must match".to_string(),
            ));
        }

        if tasks_X.is_empty() {
            return Err(SklearsError::InvalidInput("No tasks provided".to_string()));
        }

        let n_features = tasks_X[0].ncols();
        let n_tasks = tasks_y[0].ncols();

        // NOTE: `random_state` is stored for API compatibility but is not yet
        // wired into the RNG; initialization currently uses `thread_rng()`.
        let mut rng = thread_rng();

        let normal_dist = RandNormal::new(0.0, 0.1).unwrap();
        let mut weights = Array2::<Float>::zeros((n_features, n_tasks));
        for i in 0..n_features {
            for j in 0..n_tasks {
                weights[[i, j]] = rng.sample(normal_dist);
            }
        }
        let mut fisher_information = Array2::<Float>::zeros((n_features, n_tasks));
        let mut task_weights = Vec::new();

        for (task_idx, (X, y)) in tasks_X.iter().zip(tasks_y.iter()).enumerate() {
            if X.nrows() != y.nrows() {
                return Err(SklearsError::InvalidInput(
                    "Number of samples in X and y must match".to_string(),
                ));
            }
            if X.ncols() != n_features || y.ncols() != n_tasks {
                return Err(SklearsError::InvalidInput(
                    "All tasks must share the same feature and output dimensions".to_string(),
                ));
            }

            // Anchor point for the penalty: the weights as they were after
            // the previous task.
            let old_weights = weights.clone();

            for _ in 0..self.max_iter {
                let predictions = X.dot(&weights);
                let errors = &predictions - y;
                let gradient = X.t().dot(&errors) / X.nrows() as Float;

                if task_idx > 0 {
                    // Penalize moving weights that the Fisher estimate marks
                    // as important for earlier tasks.
                    let penalty =
                        &fisher_information * (&weights - &old_weights) * self.importance_weight;
                    weights = &weights - self.learning_rate * (&gradient + penalty);
                } else {
                    weights = &weights - self.learning_rate * &gradient;
                }
            }

            // Accumulate a crude Fisher-information estimate from squared
            // residuals; this stands in for the exact Fisher matrix.
            let predictions = X.dot(&weights);
            let errors = &predictions - y;
            let grad_squared = X.t().dot(&errors.mapv(|x| x * x)) / X.nrows() as Float;
            fisher_information = &fisher_information + grad_squared;

            task_weights.push(weights.clone());
        }

        Ok(ContinualLearning {
            state: ContinualLearningTrained {
                task_weights,
                fisher_information,
                optimal_weights: weights,
                n_features,
                n_tasks,
            },
            importance_weight: self.importance_weight,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            random_state: self.random_state,
        })
    }
}

impl ContinualLearning<ContinualLearningTrained> {
    /// Predicts with the final weights, i.e. those after the last task.
    #[allow(non_snake_case)]
    pub fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        Ok(X.dot(&self.state.optimal_weights))
    }

    /// Returns a snapshot of the weights taken after each task.
    pub fn task_weights(&self) -> &[Array2<Float>] {
        &self.state.task_weights
    }

    /// Returns the accumulated Fisher-information estimate.
    pub fn fisher_information(&self) -> &Array2<Float> {
        &self.state.fisher_information
    }
}

impl Estimator for ContinualLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Estimator for ContinualLearning<ContinualLearningTrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

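/// Knowledge distillation for a linear student: the student is trained on a
/// blend of the hard labels and the teacher's temperature-scaled soft
/// predictions, weighted by `alpha`.
///
/// A minimal usage sketch with illustrative toy arrays:
///
/// ```rust,ignore
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0]];
/// let y = array![[1.0, 0.0], [0.0, 1.0]];
/// let teacher_predictions = array![[0.9, 0.1], [0.1, 0.9]];
///
/// let trained = KnowledgeDistillation::new()
///     .temperature(3.0)
///     .alpha(0.7)
///     .fit(&X.view(), &y.view(), &teacher_predictions.view())?;
/// let predictions = trained.predict(&X.view())?;
/// ```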
#[derive(Debug, Clone)]
pub struct KnowledgeDistillation<S = Untrained> {
    state: S,
    temperature: Float,
    alpha: Float,
    learning_rate: Float,
    max_iter: usize,
    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct KnowledgeDistillationTrained {
    student_weights: Array2<Float>,
    teacher_weights: Array2<Float>,
    n_features: usize,
    n_tasks: usize,
}

impl Default for KnowledgeDistillation<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl KnowledgeDistillation<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            temperature: 3.0,
            alpha: 0.7,
            learning_rate: 0.01,
            max_iter: 1000,
            random_state: None,
        }
    }

    pub fn temperature(mut self, temp: Float) -> Self {
        self.temperature = temp;
        self
    }

    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn random_state(mut self, seed: Option<u64>) -> Self {
        self.random_state = seed;
        self
    }

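    /// Trains the student by gradient descent on the blended residual
    ///
    /// ```text
    /// loss = (1 - alpha) * (X W - Y) + alpha * (X W - P_teacher) / T
    /// ```
    ///
    /// where `T` is the temperature and `P_teacher` the teacher predictions.
    /// Because the model is linear (no softmax), temperature scaling simply
    /// down-weights the soft residual; the classic `T^2` correction from
    /// softmax distillation is not applied in this simplified version.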
    #[allow(non_snake_case)]
    pub fn fit(
        &self,
        X: &ArrayView2<Float>,
        y: &ArrayView2<Float>,
        teacher_predictions: &ArrayView2<Float>,
    ) -> SklResult<KnowledgeDistillation<KnowledgeDistillationTrained>> {
        if X.nrows() != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        if X.nrows() != teacher_predictions.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and teacher predictions must match".to_string(),
            ));
        }

        if teacher_predictions.ncols() != y.ncols() {
            return Err(SklearsError::InvalidInput(
                "Teacher predictions must have the same number of outputs as y".to_string(),
            ));
        }

        let n_features = X.ncols();
        let n_tasks = y.ncols();

        // NOTE: `random_state` is stored for API compatibility but is not yet
        // wired into the RNG; initialization currently uses `thread_rng()`.
        let mut rng = thread_rng();

        let normal_dist = RandNormal::new(0.0, 0.1).unwrap();
        let mut student_weights = Array2::<Float>::zeros((n_features, n_tasks));
        for i in 0..n_features {
            for j in 0..n_tasks {
                student_weights[[i, j]] = rng.sample(normal_dist);
            }
        }
        // The teacher is represented only by its predictions; these weights
        // are a randomly initialized placeholder and are never fitted.
        let mut teacher_weights = Array2::<Float>::zeros((n_features, n_tasks));
        for i in 0..n_features {
            for j in 0..n_tasks {
                teacher_weights[[i, j]] = rng.sample(normal_dist);
            }
        }

        for _ in 0..self.max_iter {
            let student_predictions = X.dot(&student_weights);

            // Temperature-scaled residual against the teacher ("soft" loss)
            // alongside the residual against the labels ("hard" loss).
            let soft_targets = teacher_predictions / self.temperature;
            let student_soft = &student_predictions / self.temperature;

            let hard_loss = &student_predictions - y;
            let soft_loss = &student_soft - &soft_targets;

            let combined_loss = (1.0 - self.alpha) * hard_loss + self.alpha * soft_loss;
            let gradient = X.t().dot(&combined_loss) / X.nrows() as Float;

            student_weights = &student_weights - self.learning_rate * &gradient;
        }

        Ok(KnowledgeDistillation {
            state: KnowledgeDistillationTrained {
                student_weights,
                teacher_weights,
                n_features,
                n_tasks,
            },
            temperature: self.temperature,
            alpha: self.alpha,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            random_state: self.random_state,
        })
    }
}

impl KnowledgeDistillation<KnowledgeDistillationTrained> {
    /// Predicts with the distilled student weights.
    #[allow(non_snake_case)]
    pub fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        Ok(X.dot(&self.state.student_weights))
    }

    /// Returns the trained student weights.
    pub fn student_weights(&self) -> &Array2<Float> {
        &self.state.student_weights
    }

    /// Returns the placeholder teacher weights (see `fit`; never trained).
    pub fn teacher_weights(&self) -> &Array2<Float> {
        &self.state.teacher_weights
    }
}

impl Estimator for KnowledgeDistillation<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Estimator for KnowledgeDistillation<KnowledgeDistillationTrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_cross_task_transfer_learning_basic() {
        let source_X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 3.0]];
        let source_y = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]];
        let target_X = array![[1.1, 2.1], [2.1, 3.1]];
        let target_y = array![[1.0, 0.0], [0.0, 1.0]];

        let transfer = CrossTaskTransferLearning::new()
            .transfer_strength(0.5)
            .learning_rate(0.01)
            .max_iter(100)
            .random_state(Some(42));

        let trained = transfer
            .fit(
                &source_X.view(),
                &source_y.view(),
                &target_X.view(),
                &target_y.view(),
            )
            .unwrap();

        let predictions = trained.predict(&target_X.view()).unwrap();
        assert_eq!(predictions.dim(), (2, 2));

        let source_predictions = trained.predict_from_source(&target_X.view()).unwrap();
        assert_eq!(source_predictions.dim(), (2, 2));
    }

    #[test]
    fn test_cross_task_transfer_learning_validation() {
        let source_X = array![[1.0, 2.0], [2.0, 3.0]];
        let source_y = array![[1.0, 0.0], [0.0, 1.0]];
        // target_X has 3 features while source_X has 2, so fit must fail.
        let target_X = array![[1.1, 2.1, 3.1]];
        let target_y = array![[1.0, 0.0]];

        let transfer = CrossTaskTransferLearning::new();

        assert!(transfer
            .fit(
                &source_X.view(),
                &source_y.view(),
                &target_X.view(),
                &target_y.view()
            )
            .is_err());
    }

    #[test]
    fn test_domain_adaptation_basic() {
        let source_X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 3.0]];
        let source_y = array![[1.0], [0.0], [1.0], [0.0]];
        let target_X = array![[1.1, 2.1], [2.1, 3.1]];
        let target_y = array![[1.0], [0.0]];

        let adaptation = DomainAdaptation::new()
            .adaptation_strength(0.3)
            .learning_rate(0.01)
            .max_iter(100)
            .random_state(Some(42));

        let trained = adaptation
            .fit(
                &source_X.view(),
                &source_y.view(),
                &target_X.view(),
                &target_y.view(),
            )
            .unwrap();

        let predictions = trained.predict(&target_X.view()).unwrap();
        assert_eq!(predictions.dim(), (2, 1));

        // hidden_dim = (n_features + n_tasks) / 2 = (2 + 1) / 2 = 1.
        let features = trained.extract_features(&target_X.view()).unwrap();
        assert_eq!(features.ncols(), 1);

        let domain_pred = trained.predict_domain(&target_X.view()).unwrap();
        assert_eq!(domain_pred.len(), 2);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_progressive_transfer_learning_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 3.0]];
        let y = array![
            [1.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
            [1.0, 1.0, 1.0],
            [0.0, 0.0, 0.0]
        ];

        let transfer = ProgressiveTransferLearning::new()
            .transfer_strength(0.4)
            .learning_rate(0.01)
            .max_iter(100)
            .random_state(Some(42));

        let trained = transfer.fit(&X.view(), &y.view(), None).unwrap();

        let predictions = trained.predict(&X.view()).unwrap();
        assert_eq!(predictions.dim(), (4, 3));

        assert_eq!(trained.task_weights().len(), 3);
        assert_eq!(trained.task_order().len(), 3);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_progressive_transfer_learning_custom_order() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
        let y = array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]];

        let transfer = ProgressiveTransferLearning::new().random_state(Some(42));

        let custom_order = vec![2, 0, 1];
        let trained = transfer
            .fit(&X.view(), &y.view(), Some(custom_order.clone()))
            .unwrap();

        assert_eq!(trained.task_order(), &custom_order);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_transfer_learning_error_handling() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        // y has 3 rows but X has only 2 samples, so fit must fail.
        let y = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];

        let transfer = ProgressiveTransferLearning::new();
        assert!(transfer.fit(&X.view(), &y.view(), None).is_err());
    }

    #[test]
    fn test_continual_learning_basic() {
        let X1 = array![[1.0, 2.0], [2.0, 3.0]];
        let y1 = array![[1.0, 0.0], [0.0, 1.0]];
        let X2 = array![[3.0, 1.0], [1.0, 3.0]];
        let y2 = array![[1.0, 1.0], [0.0, 0.0]];

        let tasks_X = vec![X1.view(), X2.view()];
        let tasks_y = vec![y1.view(), y2.view()];

        let continual = ContinualLearning::new()
            .importance_weight(1000.0)
            .learning_rate(0.01)
            .max_iter(100)
            .random_state(Some(42));

        let trained = continual.fit(&tasks_X, &tasks_y).unwrap();

        let predictions = trained.predict(&X1.view()).unwrap();
        assert_eq!(predictions.dim(), (2, 2));

        assert_eq!(trained.task_weights().len(), 2);
        assert_eq!(trained.fisher_information().dim(), (2, 2));
    }

    #[test]
    fn test_continual_learning_error_handling() {
        let X1 = array![[1.0, 2.0], [2.0, 3.0]];
        let y1 = array![[1.0, 0.0], [0.0, 1.0]];
        // X2 has 1 row but y2 has 2, so fit must fail.
        let X2 = array![[3.0, 1.0]];
        let y2 = array![[1.0, 1.0], [0.0, 0.0]];

        let tasks_X = vec![X1.view(), X2.view()];
        let tasks_y = vec![y1.view(), y2.view()];

        let continual = ContinualLearning::new();
        assert!(continual.fit(&tasks_X, &tasks_y).is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_knowledge_distillation_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
        let y = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let teacher_predictions = array![[0.9, 0.1], [0.1, 0.9], [0.8, 0.8]];

        let distillation = KnowledgeDistillation::new()
            .temperature(3.0)
            .alpha(0.7)
            .learning_rate(0.01)
            .max_iter(100)
            .random_state(Some(42));

        let trained = distillation
            .fit(&X.view(), &y.view(), &teacher_predictions.view())
            .unwrap();

        let predictions = trained.predict(&X.view()).unwrap();
        assert_eq!(predictions.dim(), (3, 2));

        assert_eq!(trained.student_weights().dim(), (2, 2));
        assert_eq!(trained.teacher_weights().dim(), (2, 2));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_knowledge_distillation_error_handling() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1.0, 0.0], [0.0, 1.0]];
        // 3 teacher rows vs 2 samples in X, so fit must fail.
        let teacher_predictions = array![[0.9, 0.1], [0.1, 0.9], [0.8, 0.8]];

        let distillation = KnowledgeDistillation::new();
        assert!(distillation
            .fit(&X.view(), &y.view(), &teacher_predictions.view())
            .is_err());
    }

    #[test]
    fn test_knowledge_distillation_configuration() {
        let distillation = KnowledgeDistillation::new()
            .temperature(5.0)
            .alpha(0.5)
            .learning_rate(0.001)
            .max_iter(2000)
            .random_state(Some(123));

        assert_eq!(distillation.temperature, 5.0);
        assert_eq!(distillation.alpha, 0.5);
        assert_eq!(distillation.learning_rate, 0.001);
        assert_eq!(distillation.max_iter, 2000);
        assert_eq!(distillation.random_state, Some(123));
    }
}