1use scirs2_core::ndarray::{Array1, Array2};
13use scirs2_core::random::essentials::Normal;
14use scirs2_core::random::rngs::StdRng;
15use scirs2_core::random::{Distribution, Rng, SeedableRng};
16use serde::{Deserialize, Serialize};
17use sklears_core::types::Float;
18use std::collections::{HashMap, VecDeque};
19
/// Strategy for transferring optimization knowledge from previously
/// solved source tasks to a new target task.
#[derive(Debug, Clone)]
pub enum TransferStrategy {
    /// Reuse source-task parameters directly as a weighted average.
    DirectTransfer {
        /// Weight applied to each source task's contribution.
        source_weight: Float,
        /// Whether to apply learned per-parameter domain-adaptation factors.
        adapt_domain: bool,
    },
    /// Transfer via a learned shared feature representation.
    FeatureTransfer {
        /// Dimensionality of the learned feature space.
        feature_dim: usize,
        /// Scaling applied when adapting features to the target task.
        adaptation_rate: Float,
    },
    /// Transfer a surrogate model and fine-tune it on the target task.
    ModelTransfer {
        /// Number of fine-tuning steps.
        adaptation_steps: usize,
        /// Regularization strength during fine-tuning.
        regularization: Float,
    },
    /// Transfer individual source instances, reweighted by importance.
    InstanceTransfer {
        /// Number of nearest source tasks to draw instances from.
        k_neighbors: usize,
        /// Method used to estimate importance weights.
        weighting_method: ImportanceWeightingMethod,
    },
    /// Learn a representation shared across several related tasks.
    MultiTaskTransfer {
        /// Maximum number of related tasks to consider.
        n_tasks: usize,
        /// Minimum similarity for a task to count as related.
        similarity_threshold: Float,
    },
}
63
/// Density-ratio estimation method used to weight source instances
/// during instance transfer.
#[derive(Debug, Clone)]
pub enum ImportanceWeightingMethod {
    /// Kernel mean matching with the given kernel bandwidth.
    KernelMeanMatching { kernel_bandwidth: Float },
    /// Kullback-Leibler importance estimation procedure.
    KLIEP { regularization: Float },
    /// Unconstrained least-squares importance fitting over a sigma grid.
    ULSIF { sigma_list: Vec<Float> },
    /// Ratio of Gaussian density estimates with a shared bandwidth.
    RatioOfGaussians { bandwidth: Float },
}
76
/// Configuration for [`TransferLearningOptimizer`].
#[derive(Debug, Clone)]
pub struct TransferLearningConfig {
    /// Which transfer strategy to apply.
    pub strategy: TransferStrategy,
    /// Historical experiences from previously optimized source tasks.
    pub source_task_data: Vec<OptimizationExperience>,
    /// Number of initial samples before transfer kicks in.
    pub n_init_samples: usize,
    /// Minimum confidence required to accept a transferred value.
    pub confidence_threshold: Float,
    /// Upper bound on transfer iterations.
    pub max_transfer_iterations: usize,
    /// Seed for reproducible randomness; `None` uses a fixed default.
    pub random_state: Option<u64>,
}
87
impl Default for TransferLearningConfig {
    /// Defaults to equally blended direct transfer with domain adaptation
    /// enabled and no source data.
    fn default() -> Self {
        Self {
            strategy: TransferStrategy::DirectTransfer {
                source_weight: 0.5,
                adapt_domain: true,
            },
            source_task_data: Vec::new(),
            n_init_samples: 5,
            confidence_threshold: 0.7,
            max_transfer_iterations: 50,
            random_state: None,
        }
    }
}
103
/// Optimizer that bootstraps a target task using knowledge transferred
/// from source tasks.
pub struct TransferLearningOptimizer {
    config: TransferLearningConfig,
    // NOTE(review): these two fields are populated nowhere in this file's
    // visible code paths — presumably filled by callers or future logic.
    transferred_knowledge: HashMap<String, ParameterDistribution>,
    domain_adaptation_params: HashMap<String, Float>,
    transfer_performance: Vec<Float>,
}
111
112impl TransferLearningOptimizer {
113 pub fn new(config: TransferLearningConfig) -> Self {
114 Self {
115 config,
116 transferred_knowledge: HashMap::new(),
117 domain_adaptation_params: HashMap::new(),
118 transfer_performance: Vec::new(),
119 }
120 }
121
    /// Transfer knowledge to the target task, dispatching on the
    /// configured [`TransferStrategy`].
    ///
    /// # Errors
    /// Propagates any error from the selected strategy implementation.
    pub fn transfer_knowledge(
        &mut self,
        target_task_characteristics: &TaskCharacteristics,
    ) -> Result<TransferResult, Box<dyn std::error::Error>> {
        // Clone the strategy up front so the arms can call `&mut self`
        // methods without also borrowing `self.config`.
        let strategy = self.config.strategy.clone();
        match strategy {
            TransferStrategy::DirectTransfer {
                source_weight,
                adapt_domain,
            } => self.direct_transfer(source_weight, adapt_domain, target_task_characteristics),
            TransferStrategy::FeatureTransfer {
                feature_dim,
                adaptation_rate,
            } => self.feature_transfer(feature_dim, adaptation_rate, target_task_characteristics),
            TransferStrategy::ModelTransfer {
                adaptation_steps,
                regularization,
            } => self.model_transfer(
                adaptation_steps,
                regularization,
                target_task_characteristics,
            ),
            TransferStrategy::InstanceTransfer {
                k_neighbors,
                weighting_method,
            } => {
                self.instance_transfer(k_neighbors, &weighting_method, target_task_characteristics)
            }
            TransferStrategy::MultiTaskTransfer {
                n_tasks,
                similarity_threshold,
            } => {
                self.multi_task_transfer(n_tasks, similarity_threshold, target_task_characteristics)
            }
        }
    }
160
    /// Direct transfer: for each target parameter, compute a
    /// similarity-weighted average of the best values found on similar
    /// source tasks, optionally adjusted by domain adaptation.
    fn direct_transfer(
        &mut self,
        source_weight: Float,
        adapt_domain: bool,
        target_task: &TaskCharacteristics,
    ) -> Result<TransferResult, Box<dyn std::error::Error>> {
        let mut transferred_params = HashMap::new();
        let mut transfer_confidence = HashMap::new();

        // Use up to the 5 most similar source tasks as donors.
        let similar_tasks = self.find_similar_tasks(target_task, 5)?;

        for param_name in target_task.parameter_space.keys() {
            let mut weighted_sum = 0.0;
            let mut total_weight = 0.0;

            // Accumulate each donor's best value, weighted by its
            // similarity scaled by the global source weight.
            for (task_idx, similarity) in &similar_tasks {
                if let Some(source_task) = self.config.source_task_data.get(*task_idx) {
                    if let Some(value) = source_task.best_parameters.get(param_name) {
                        let weight = similarity * source_weight;
                        weighted_sum += value * weight;
                        total_weight += weight;
                    }
                }
            }

            // Skip parameters no donor had a value for.
            if total_weight > 0.0 {
                let transferred_value = weighted_sum / total_weight;

                let final_value = if adapt_domain {
                    self.apply_domain_adaptation(param_name, transferred_value, target_task)?
                } else {
                    transferred_value
                };

                transferred_params.insert(param_name.clone(), final_value);
                // Confidence is the mean donor weight for this parameter.
                transfer_confidence.insert(
                    param_name.clone(),
                    total_weight / similar_tasks.len() as Float,
                );
            }
        }

        Ok(TransferResult {
            transferred_parameters: transferred_params,
            confidence_scores: transfer_confidence,
            source_tasks_used: similar_tasks.iter().map(|(idx, _)| *idx).collect(),
            adaptation_applied: adapt_domain,
            expected_improvement: self.estimate_transfer_improvement(&similar_tasks),
        })
    }
214
    /// Feature-based transfer: learn a shared feature representation from
    /// the source tasks, map the target into it, adapt, and decode back to
    /// parameter values. The helper stages are currently placeholders.
    fn feature_transfer(
        &mut self,
        feature_dim: usize,
        adaptation_rate: Float,
        target_task: &TaskCharacteristics,
    ) -> Result<TransferResult, Box<dyn std::error::Error>> {
        let feature_matrix = self.learn_feature_representation(feature_dim)?;

        let target_features = self.map_to_feature_space(target_task, &feature_matrix)?;

        let adapted_features = self.adapt_features(&target_features, adaptation_rate)?;

        let transferred_params = self.decode_features(&adapted_features)?;

        Ok(TransferResult {
            transferred_parameters: transferred_params.clone(),
            // Fixed placeholder confidence per transferred parameter.
            confidence_scores: transferred_params
                .keys()
                .map(|k| (k.clone(), 0.8))
                .collect(),
            source_tasks_used: (0..self.config.source_task_data.len()).collect(),
            adaptation_applied: true,
            expected_improvement: 0.15,
        })
    }
245
    /// Model-based transfer: build a surrogate from source data, fine-tune
    /// it on the target task, and generate parameter recommendations.
    /// The surrogate helpers are currently placeholders.
    fn model_transfer(
        &mut self,
        adaptation_steps: usize,
        regularization: Float,
        target_task: &TaskCharacteristics,
    ) -> Result<TransferResult, Box<dyn std::error::Error>> {
        let surrogate = self.build_transfer_surrogate()?;

        let adapted_surrogate =
            self.fine_tune_surrogate(&surrogate, target_task, adaptation_steps, regularization)?;

        let transferred_params = self.generate_recommendations(&adapted_surrogate, target_task)?;

        Ok(TransferResult {
            transferred_parameters: transferred_params.clone(),
            // Fixed placeholder confidence per transferred parameter.
            confidence_scores: transferred_params
                .keys()
                .map(|k| (k.clone(), 0.75))
                .collect(),
            source_tasks_used: (0..self.config.source_task_data.len()).collect(),
            adaptation_applied: true,
            expected_improvement: 0.20,
        })
    }
274
    /// Instance-based transfer: pick the `k_neighbors` most similar source
    /// tasks, compute importance weights for them, and average their best
    /// parameter values under those weights.
    fn instance_transfer(
        &mut self,
        k_neighbors: usize,
        weighting_method: &ImportanceWeightingMethod,
        target_task: &TaskCharacteristics,
    ) -> Result<TransferResult, Box<dyn std::error::Error>> {
        let similar_instances = self.find_similar_tasks(target_task, k_neighbors)?;

        let weights = self.compute_importance_weights(&similar_instances, weighting_method)?;

        let mut transferred_params = HashMap::new();
        for param_name in target_task.parameter_space.keys() {
            let mut weighted_sum = 0.0;
            let mut total_weight = 0.0;

            // Pair each selected instance with its importance weight.
            for ((task_idx, _), weight) in similar_instances.iter().zip(weights.iter()) {
                if let Some(source_task) = self.config.source_task_data.get(*task_idx) {
                    if let Some(value) = source_task.best_parameters.get(param_name) {
                        weighted_sum += value * weight;
                        total_weight += weight;
                    }
                }
            }

            // Only keep parameters at least one instance contributed to.
            if total_weight > 0.0 {
                transferred_params.insert(param_name.clone(), weighted_sum / total_weight);
            }
        }

        Ok(TransferResult {
            transferred_parameters: transferred_params.clone(),
            // Fixed placeholder confidence per transferred parameter.
            confidence_scores: transferred_params
                .keys()
                .map(|k| (k.clone(), 0.7))
                .collect(),
            source_tasks_used: similar_instances.iter().map(|(idx, _)| *idx).collect(),
            adaptation_applied: true,
            expected_improvement: 0.12,
        })
    }
319
    /// Multi-task transfer: restrict to source tasks whose similarity
    /// meets the threshold, learn a shared representation from them, and
    /// adapt it to the target. Representation helpers are placeholders.
    ///
    /// # Errors
    /// Fails when no source task clears `similarity_threshold`.
    fn multi_task_transfer(
        &mut self,
        n_tasks: usize,
        similarity_threshold: Float,
        target_task: &TaskCharacteristics,
    ) -> Result<TransferResult, Box<dyn std::error::Error>> {
        let related_tasks: Vec<_> = self
            .find_similar_tasks(target_task, n_tasks)?
            .into_iter()
            .filter(|(_, sim)| *sim >= similarity_threshold)
            .collect();

        if related_tasks.is_empty() {
            return Err("No sufficiently similar tasks found for transfer".into());
        }

        let shared_params = self.learn_shared_representation(&related_tasks)?;

        let adapted_params = self.task_specific_adaptation(&shared_params, target_task)?;

        Ok(TransferResult {
            transferred_parameters: adapted_params.clone(),
            confidence_scores: adapted_params.keys().map(|k| (k.clone(), 0.85)).collect(),
            source_tasks_used: related_tasks.iter().map(|(idx, _)| *idx).collect(),
            adaptation_applied: true,
            expected_improvement: 0.25,
        })
    }
352
353 fn find_similar_tasks(
356 &self,
357 target_task: &TaskCharacteristics,
358 k: usize,
359 ) -> Result<Vec<(usize, Float)>, Box<dyn std::error::Error>> {
360 let mut similarities = Vec::new();
361
362 for (idx, source_task) in self.config.source_task_data.iter().enumerate() {
363 let similarity =
364 self.compute_task_similarity(target_task, &source_task.task_characteristics)?;
365 similarities.push((idx, similarity));
366 }
367
368 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
370 similarities.truncate(k);
371
372 Ok(similarities)
373 }
374
375 fn compute_task_similarity(
376 &self,
377 task_a: &TaskCharacteristics,
378 task_b: &TaskCharacteristics,
379 ) -> Result<Float, Box<dyn std::error::Error>> {
380 let size_sim = self.compute_size_similarity(task_a, task_b);
382 let complexity_sim = self.compute_complexity_similarity(task_a, task_b);
383 let domain_sim = self.compute_domain_similarity(task_a, task_b);
384
385 Ok(0.4 * size_sim + 0.3 * complexity_sim + 0.3 * domain_sim)
387 }
388
389 fn compute_size_similarity(
390 &self,
391 task_a: &TaskCharacteristics,
392 task_b: &TaskCharacteristics,
393 ) -> Float {
394 let ratio = task_a.n_samples as Float / task_b.n_samples as Float;
395
396 if ratio > 1.0 {
397 1.0 / ratio
398 } else {
399 ratio
400 } }
402
403 fn compute_complexity_similarity(
404 &self,
405 task_a: &TaskCharacteristics,
406 task_b: &TaskCharacteristics,
407 ) -> Float {
408 let feat_ratio = task_a.n_features as Float / task_b.n_features as Float;
410
411 if feat_ratio > 1.0 {
412 1.0 / feat_ratio
413 } else {
414 feat_ratio
415 }
416 }
417
418 fn compute_domain_similarity(
419 &self,
420 task_a: &TaskCharacteristics,
421 task_b: &TaskCharacteristics,
422 ) -> Float {
423 if task_a.task_type == task_b.task_type {
425 0.8
426 } else {
427 0.3
428 }
429 }
430
431 fn apply_domain_adaptation(
432 &mut self,
433 param_name: &str,
434 value: Float,
435 _target_task: &TaskCharacteristics,
436 ) -> Result<Float, Box<dyn std::error::Error>> {
437 let adaptation_factor = self
439 .domain_adaptation_params
440 .get(param_name)
441 .cloned()
442 .unwrap_or(1.0);
443
444 Ok(value * adaptation_factor)
445 }
446
447 fn estimate_transfer_improvement(&self, similar_tasks: &[(usize, Float)]) -> Float {
448 if similar_tasks.is_empty() {
449 return 0.0;
450 }
451
452 let avg_similarity: Float =
454 similar_tasks.iter().map(|(_, sim)| sim).sum::<Float>() / similar_tasks.len() as Float;
455
456 avg_similarity * 0.3 }
459
460 fn learn_feature_representation(
461 &self,
462 feature_dim: usize,
463 ) -> Result<Array2<Float>, Box<dyn std::error::Error>> {
464 let n_source_tasks = self.config.source_task_data.len();
466 let mut rng = StdRng::seed_from_u64(self.config.random_state.unwrap_or(42));
467 let normal = Normal::new(0.0, 0.1)
468 .map_err(|e| format!("Failed to create normal distribution: {}", e))?;
469
470 let feature_matrix =
471 Array2::from_shape_fn((n_source_tasks, feature_dim), |_| normal.sample(&mut rng));
472
473 Ok(feature_matrix)
474 }
475
476 fn map_to_feature_space(
477 &self,
478 _target_task: &TaskCharacteristics,
479 _feature_matrix: &Array2<Float>,
480 ) -> Result<Array1<Float>, Box<dyn std::error::Error>> {
481 Ok(Array1::zeros(10))
483 }
484
485 fn adapt_features(
486 &self,
487 features: &Array1<Float>,
488 adaptation_rate: Float,
489 ) -> Result<Array1<Float>, Box<dyn std::error::Error>> {
490 Ok(features * adaptation_rate)
492 }
493
494 fn decode_features(
495 &self,
496 _features: &Array1<Float>,
497 ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
498 Ok(HashMap::new())
500 }
501
502 fn build_transfer_surrogate(&self) -> Result<TransferSurrogate, Box<dyn std::error::Error>> {
503 Ok(TransferSurrogate::default())
504 }
505
506 fn fine_tune_surrogate(
507 &self,
508 _surrogate: &TransferSurrogate,
509 _target_task: &TaskCharacteristics,
510 _steps: usize,
511 _regularization: Float,
512 ) -> Result<TransferSurrogate, Box<dyn std::error::Error>> {
513 Ok(TransferSurrogate::default())
514 }
515
516 fn generate_recommendations(
517 &self,
518 _surrogate: &TransferSurrogate,
519 _target_task: &TaskCharacteristics,
520 ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
521 Ok(HashMap::new())
522 }
523
524 fn compute_importance_weights(
525 &self,
526 _similar_instances: &[(usize, Float)],
527 weighting_method: &ImportanceWeightingMethod,
528 ) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
529 let weights = match weighting_method {
531 ImportanceWeightingMethod::KernelMeanMatching { .. } => {
532 vec![1.0; _similar_instances.len()]
533 }
534 ImportanceWeightingMethod::KLIEP { .. } => {
535 vec![1.0; _similar_instances.len()]
536 }
537 ImportanceWeightingMethod::ULSIF { .. } => {
538 vec![1.0; _similar_instances.len()]
539 }
540 ImportanceWeightingMethod::RatioOfGaussians { .. } => {
541 vec![1.0; _similar_instances.len()]
542 }
543 };
544
545 Ok(weights)
546 }
547
548 fn learn_shared_representation(
549 &self,
550 _related_tasks: &[(usize, Float)],
551 ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
552 Ok(HashMap::new())
553 }
554
555 fn task_specific_adaptation(
556 &self,
557 shared_params: &HashMap<String, Float>,
558 _target_task: &TaskCharacteristics,
559 ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
560 Ok(shared_params.clone())
561 }
562}
563
/// Placeholder surrogate model used by model-based transfer; currently
/// stateless — fields will be added when a real surrogate is implemented.
#[derive(Debug, Clone, Default)]
struct TransferSurrogate {
}
568
/// Outcome of a knowledge-transfer attempt.
#[derive(Debug, Clone)]
pub struct TransferResult {
    /// Suggested parameter values for the target task.
    pub transferred_parameters: HashMap<String, Float>,
    /// Per-parameter confidence in the transferred value.
    pub confidence_scores: HashMap<String, Float>,
    /// Indices (into the source data) of the tasks that contributed.
    pub source_tasks_used: Vec<usize>,
    /// Whether domain adaptation was applied to the values.
    pub adaptation_applied: bool,
    /// Heuristic estimate of the improvement transfer should yield.
    pub expected_improvement: Float,
}
578
/// Configuration for [`FewShotOptimizer`].
#[derive(Debug, Clone)]
pub struct FewShotConfig {
    /// Minimum number of support examples required for adaptation.
    pub n_support: usize,
    /// Number of query examples per episode.
    pub n_query: usize,
    /// Few-shot algorithm to run.
    pub algorithm: FewShotAlgorithm,
    /// Number of meta-training episodes.
    pub n_meta_episodes: usize,
    /// Inner-loop (task-level) learning rate.
    pub inner_lr: Float,
    /// Outer-loop (meta-level) learning rate.
    pub outer_lr: Float,
    /// Seed for reproducible randomness; `None` uses a fixed default.
    pub random_state: Option<u64>,
}
600
impl Default for FewShotConfig {
    /// Defaults to 5-shot MAML with 5 adaptation steps and conventional
    /// inner/outer learning rates.
    fn default() -> Self {
        Self {
            n_support: 5,
            n_query: 10,
            algorithm: FewShotAlgorithm::MAML {
                adaptation_steps: 5,
            },
            n_meta_episodes: 100,
            inner_lr: 0.01,
            outer_lr: 0.001,
            random_state: None,
        }
    }
}
616
/// Few-shot learning algorithm family.
#[derive(Debug, Clone)]
pub enum FewShotAlgorithm {
    /// Model-agnostic meta-learning with the given inner-loop step count.
    MAML { adaptation_steps: usize },
    /// Prototypical networks with the given embedding dimension.
    ProtoNet { embedding_dim: usize },
    /// Matching networks, optionally with an attention mechanism.
    MatchingNet { attention_mechanism: bool },
    /// Relation networks with the given relation-module layer sizes.
    RelationNet { relation_module_layers: Vec<usize> },
}
629
/// Optimizer that meta-learns initial parameters from a collection of
/// tasks and adapts them from a small support set.
pub struct FewShotOptimizer {
    config: FewShotConfig,
    // Meta-learned initialization, shared across tasks.
    meta_parameters: HashMap<String, Float>,
    // NOTE(review): never written in this file's visible code — presumably
    // intended for per-episode bookkeeping.
    episode_history: Vec<FewShotEpisode>,
}
636
637impl FewShotOptimizer {
638 pub fn new(config: FewShotConfig) -> Self {
639 Self {
640 config,
641 meta_parameters: HashMap::new(),
642 episode_history: Vec::new(),
643 }
644 }
645
    /// Meta-train on a set of tasks, dispatching on the configured
    /// [`FewShotAlgorithm`].
    pub fn meta_train(
        &mut self,
        tasks: &[OptimizationTask],
    ) -> Result<FewShotResult, Box<dyn std::error::Error>> {
        // Clone so the arms can call `&mut self` training methods without
        // also borrowing `self.config`.
        let algorithm = self.config.algorithm.clone();
        match algorithm {
            FewShotAlgorithm::MAML { adaptation_steps } => self.train_maml(tasks, adaptation_steps),
            FewShotAlgorithm::ProtoNet { embedding_dim } => {
                self.train_protonet(tasks, embedding_dim)
            }
            FewShotAlgorithm::MatchingNet {
                attention_mechanism,
            } => self.train_matchingnet(tasks, attention_mechanism),
            FewShotAlgorithm::RelationNet {
                relation_module_layers,
            } => self.train_relationnet(tasks, &relation_module_layers),
        }
    }
666
    /// Adapt the meta-learned parameters to a new task from a support set
    /// of `(parameters, score)` examples.
    ///
    /// # Errors
    /// Fails when fewer than `config.n_support` examples are supplied.
    pub fn adapt(
        &self,
        support_set: &[(HashMap<String, Float>, Float)],
    ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
        if support_set.len() < self.config.n_support {
            return Err(format!(
                "Insufficient support examples: {} < {}",
                support_set.len(),
                self.config.n_support
            )
            .into());
        }

        match &self.config.algorithm {
            FewShotAlgorithm::MAML { adaptation_steps } => {
                self.adapt_maml(support_set, *adaptation_steps)
            }
            FewShotAlgorithm::ProtoNet { .. } => self.adapt_protonet(support_set),
            FewShotAlgorithm::MatchingNet { .. } => self.adapt_matchingnet(support_set),
            FewShotAlgorithm::RelationNet { .. } => self.adapt_relationnet(support_set),
        }
    }
690
691 fn train_maml(
693 &mut self,
694 tasks: &[OptimizationTask],
695 adaptation_steps: usize,
696 ) -> Result<FewShotResult, Box<dyn std::error::Error>> {
697 for param_name in &["learning_rate", "momentum", "batch_size"] {
699 self.meta_parameters.insert(param_name.to_string(), 0.5);
700 }
701
702 let mut meta_loss_history = Vec::new();
703
704 for _episode in 0..self.config.n_meta_episodes {
705 let mut episode_loss = 0.0;
706
707 let n_tasks = tasks.len().min(5);
709
710 for task in tasks.iter().take(n_tasks) {
711 let mut adapted_params = self.meta_parameters.clone();
713
714 for _ in 0..adaptation_steps {
715 let gradient =
717 self.compute_inner_gradient(&adapted_params, &task.support_examples)?;
718 for (param_name, grad) in gradient {
719 if let Some(param) = adapted_params.get_mut(¶m_name) {
720 *param -= self.config.inner_lr * grad;
721 }
722 }
723 }
724
725 let query_loss = self.compute_query_loss(&adapted_params, &task.query_examples)?;
727 episode_loss += query_loss;
728 }
729
730 let meta_gradient = episode_loss / n_tasks as Float;
732 for (_param_name, param_value) in self.meta_parameters.iter_mut() {
733 *param_value -= self.config.outer_lr * meta_gradient;
734 }
735
736 meta_loss_history.push(episode_loss / n_tasks as Float);
737 }
738
739 let final_perf = meta_loss_history.last().cloned().unwrap_or(0.0);
740
741 Ok(FewShotResult {
742 meta_parameters: self.meta_parameters.clone(),
743 meta_loss_history,
744 n_episodes: self.config.n_meta_episodes,
745 final_performance: final_perf,
746 })
747 }
748
749 fn adapt_maml(
750 &self,
751 support_set: &[(HashMap<String, Float>, Float)],
752 adaptation_steps: usize,
753 ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
754 let mut adapted_params = self.meta_parameters.clone();
755
756 for _ in 0..adaptation_steps {
757 let gradient = self.compute_support_gradient(&adapted_params, support_set)?;
758 for (param_name, grad) in gradient {
759 if let Some(param) = adapted_params.get_mut(¶m_name) {
760 *param -= self.config.inner_lr * grad;
761 }
762 }
763 }
764
765 Ok(adapted_params)
766 }
767
    /// Prototypical-network meta-training (placeholder).
    ///
    /// NOTE(review): the computed `prototypes` map is built and then
    /// dropped — there is no field to store it in yet; confirm whether it
    /// should be persisted when ProtoNet is implemented for real.
    fn train_protonet(
        &mut self,
        tasks: &[OptimizationTask],
        embedding_dim: usize,
    ) -> Result<FewShotResult, Box<dyn std::error::Error>> {
        let mut prototypes = HashMap::new();

        for task in tasks {
            let prototype = self.compute_prototype(&task.support_examples, embedding_dim)?;
            prototypes.insert(task.task_id.clone(), prototype);
        }

        Ok(FewShotResult {
            meta_parameters: HashMap::new(),
            meta_loss_history: vec![0.0],
            n_episodes: self.config.n_meta_episodes,
            final_performance: 0.0,
        })
    }
790
791 fn adapt_protonet(
792 &self,
793 support_set: &[(HashMap<String, Float>, Float)],
794 ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
795 if support_set.is_empty() {
797 return Ok(HashMap::new());
798 }
799
800 let best = support_set
802 .iter()
803 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
804 .unwrap();
805
806 Ok(best.0.clone())
807 }
808
809 fn train_matchingnet(
811 &mut self,
812 _tasks: &[OptimizationTask],
813 _attention: bool,
814 ) -> Result<FewShotResult, Box<dyn std::error::Error>> {
815 Ok(FewShotResult {
816 meta_parameters: HashMap::new(),
817 meta_loss_history: vec![0.0],
818 n_episodes: self.config.n_meta_episodes,
819 final_performance: 0.0,
820 })
821 }
822
    /// Matching-network adaptation. Currently delegates to the ProtoNet
    /// best-of-support heuristic.
    fn adapt_matchingnet(
        &self,
        support_set: &[(HashMap<String, Float>, Float)],
    ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
        self.adapt_protonet(support_set)
    }
829
830 fn train_relationnet(
832 &mut self,
833 _tasks: &[OptimizationTask],
834 _layers: &[usize],
835 ) -> Result<FewShotResult, Box<dyn std::error::Error>> {
836 Ok(FewShotResult {
837 meta_parameters: HashMap::new(),
838 meta_loss_history: vec![0.0],
839 n_episodes: self.config.n_meta_episodes,
840 final_performance: 0.0,
841 })
842 }
843
    /// Relation-network adaptation. Currently delegates to the ProtoNet
    /// best-of-support heuristic.
    fn adapt_relationnet(
        &self,
        support_set: &[(HashMap<String, Float>, Float)],
    ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
        self.adapt_protonet(support_set)
    }
850
851 fn compute_inner_gradient(
853 &self,
854 _params: &HashMap<String, Float>,
855 _support: &[(HashMap<String, Float>, Float)],
856 ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
857 Ok(HashMap::new())
859 }
860
861 fn compute_query_loss(
862 &self,
863 _params: &HashMap<String, Float>,
864 _query: &[(HashMap<String, Float>, Float)],
865 ) -> Result<Float, Box<dyn std::error::Error>> {
866 Ok(0.1) }
868
869 fn compute_support_gradient(
870 &self,
871 _params: &HashMap<String, Float>,
872 _support: &[(HashMap<String, Float>, Float)],
873 ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
874 Ok(HashMap::new())
875 }
876
877 fn compute_prototype(
878 &self,
879 _examples: &[(HashMap<String, Float>, Float)],
880 _embedding_dim: usize,
881 ) -> Result<Array1<Float>, Box<dyn std::error::Error>> {
882 Ok(Array1::zeros(10))
884 }
885}
886
/// Outcome of few-shot meta-training.
#[derive(Debug, Clone)]
pub struct FewShotResult {
    /// Meta-learned parameter initialization.
    pub meta_parameters: HashMap<String, Float>,
    /// Average loss recorded per meta-episode.
    pub meta_loss_history: Vec<Float>,
    /// Number of meta-episodes run.
    pub n_episodes: usize,
    /// Loss of the final episode (lower is better).
    pub final_performance: Float,
}
895
/// Record of a single few-shot training episode.
/// NOTE(review): constructed nowhere in this file's visible code —
/// presumably intended for `FewShotOptimizer::episode_history`.
#[derive(Debug, Clone)]
struct FewShotEpisode {
    task_id: String,
    support_loss: Float,
    query_loss: Float,
}
903
/// Configuration for [`LearnedOptimizer`] (learning-to-optimize).
#[derive(Debug, Clone)]
pub struct Learn2OptimizeConfig {
    /// Neural architecture of the learned optimizer.
    pub optimizer_architecture: OptimizerArchitecture,
    /// Maximum number of tasks to train on.
    pub n_training_tasks: usize,
    /// Optimization steps taken per `optimize` call.
    pub max_optimization_steps: usize,
    /// Learning rate for meta-updates.
    pub meta_learning_rate: Float,
    /// Whether the optimizer carries recurrent state across steps.
    pub use_recurrent: bool,
    /// Seed for reproducible randomness; `None` uses a fixed default.
    pub random_state: Option<u64>,
}
918
impl Default for Learn2OptimizeConfig {
    /// Defaults to a small recurrent (RNN) optimizer with 100 tasks and
    /// 100 steps per optimization run.
    fn default() -> Self {
        Self {
            optimizer_architecture: OptimizerArchitecture::RNN { hidden_size: 20 },
            n_training_tasks: 100,
            max_optimization_steps: 100,
            meta_learning_rate: 0.001,
            use_recurrent: true,
            random_state: None,
        }
    }
}
931
/// Neural architecture choices for the learned optimizer.
#[derive(Debug, Clone)]
pub enum OptimizerArchitecture {
    /// Plain recurrent network.
    RNN { hidden_size: usize },
    /// Stacked LSTM.
    LSTM { hidden_size: usize, n_layers: usize },
    /// Transformer encoder.
    Transformer { n_heads: usize, d_model: usize },
    /// Graph neural network with message passing.
    GNN { n_message_passing_steps: usize },
}
944
/// Optimizer whose update rule is itself learned from training tasks.
pub struct LearnedOptimizer {
    config: Learn2OptimizeConfig,
    // Recurrent state threaded through optimization steps.
    optimizer_state: OptimizerState,
    // One entry per task trained on.
    training_history: Vec<TrainingEpisode>,
}
951
952impl LearnedOptimizer {
953 pub fn new(config: Learn2OptimizeConfig) -> Self {
954 Self {
955 config,
956 optimizer_state: OptimizerState::default(),
957 training_history: Vec::new(),
958 }
959 }
960
961 pub fn train(
963 &mut self,
964 training_tasks: &[OptimizationTask],
965 ) -> Result<Learn2OptimizeResult, Box<dyn std::error::Error>> {
966 let mut total_reward = 0.0;
967
968 for task in training_tasks.iter().take(self.config.n_training_tasks) {
969 let episode_reward = self.train_on_task(task)?;
970 total_reward += episode_reward;
971
972 self.training_history.push(TrainingEpisode {
973 task_id: task.task_id.clone(),
974 reward: episode_reward,
975 n_steps: self.config.max_optimization_steps,
976 });
977 }
978
979 Ok(Learn2OptimizeResult {
980 final_performance: total_reward / self.config.n_training_tasks as Float,
981 training_curve: self.training_history.iter().map(|e| e.reward).collect(),
982 n_tasks_trained: training_tasks.len().min(self.config.n_training_tasks),
983 })
984 }
985
    /// Run the learned update rule for `max_optimization_steps` starting
    /// from `initial_params`, threading recurrent state between steps.
    ///
    /// `objective_fn` is passed through to the update computation (the
    /// current placeholder update ignores it).
    pub fn optimize(
        &self,
        objective_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
        initial_params: &HashMap<String, Float>,
    ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
        let mut current_params = initial_params.clone();
        let mut state = self.optimizer_state.clone();

        for _ in 0..self.config.max_optimization_steps {
            let update = self.compute_update(&current_params, &state, objective_fn)?;

            // Apply the proposed delta to each known parameter.
            for (param_name, delta) in update {
                if let Some(param) = current_params.get_mut(&param_name) {
                    *param += delta;
                }
            }

            state = self.update_state(&state, &current_params)?;
        }

        Ok(current_params)
    }
1012
1013 fn train_on_task(
1014 &mut self,
1015 task: &OptimizationTask,
1016 ) -> Result<Float, Box<dyn std::error::Error>> {
1017 let mut reward = 0.0;
1019
1020 for example in &task.support_examples {
1021 reward += example.1; }
1023
1024 Ok(reward / task.support_examples.len() as Float)
1025 }
1026
1027 fn compute_update(
1028 &self,
1029 current_params: &HashMap<String, Float>,
1030 _state: &OptimizerState,
1031 _objective_fn: &dyn Fn(&HashMap<String, Float>) -> Float,
1032 ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
1033 let mut updates = HashMap::new();
1035 for param_name in current_params.keys() {
1036 updates.insert(param_name.clone(), 0.01); }
1038 Ok(updates)
1039 }
1040
1041 fn update_state(
1042 &self,
1043 state: &OptimizerState,
1044 _params: &HashMap<String, Float>,
1045 ) -> Result<OptimizerState, Box<dyn std::error::Error>> {
1046 Ok(state.clone())
1047 }
1048}
1049
/// Recurrent state of the learned optimizer (e.g. LSTM hidden/cell
/// vectors). Both start empty.
#[derive(Debug, Clone, Default)]
struct OptimizerState {
    hidden: Vec<Float>,
    cell: Vec<Float>,
}
1055
/// Record of one training episode on a task.
#[derive(Debug, Clone)]
struct TrainingEpisode {
    task_id: String,
    reward: Float,
    n_steps: usize,
}
1062
/// Outcome of learning-to-optimize training.
#[derive(Debug, Clone)]
pub struct Learn2OptimizeResult {
    /// Average episode reward over the tasks trained on.
    pub final_performance: Float,
    /// Reward per training episode, in order.
    pub training_curve: Vec<Float>,
    /// Number of tasks actually trained on.
    pub n_tasks_trained: usize,
}
1070
/// Configuration for [`ExperienceReplayBuffer`].
#[derive(Debug, Clone)]
pub struct ExperienceReplayConfig {
    /// Maximum number of experiences retained (FIFO eviction).
    pub buffer_size: usize,
    /// Mini-batch size for replay updates.
    pub batch_size: usize,
    /// How priorities are assigned to new experiences.
    pub prioritization: PrioritizationStrategy,
    /// How batches are drawn from the buffer.
    pub sampling_strategy: SamplingStrategy,
    /// Number of mini-batch updates per `replay_update` call.
    pub n_replay_updates: usize,
    /// Seed for reproducible sampling; `None` uses a fixed default.
    pub random_state: Option<u64>,
}
1085
impl Default for ExperienceReplayConfig {
    /// Defaults to a 10k-capacity buffer with uniform priorities and
    /// random sampling of 32-element batches.
    fn default() -> Self {
        Self {
            buffer_size: 10000,
            batch_size: 32,
            prioritization: PrioritizationStrategy::Uniform,
            sampling_strategy: SamplingStrategy::Random,
            n_replay_updates: 10,
            random_state: None,
        }
    }
}
1098
/// How replay priorities are assigned to stored experiences.
#[derive(Debug, Clone)]
pub enum PrioritizationStrategy {
    /// All experiences equally likely.
    Uniform,
    /// Priority proportional to |TD error|^alpha.
    TDError { alpha: Float },
    /// Softmax of improvement with the given temperature.
    Improvement { temperature: Float },
    /// Exponential decay favoring recent experiences.
    Recency { decay_rate: Float },
    /// Favor experiences far from already-stored ones.
    Diversity { distance_threshold: Float },
}
1113
/// How batches are drawn from the replay buffer.
#[derive(Debug, Clone)]
pub enum SamplingStrategy {
    /// Uniform random sampling.
    Random,
    /// Reservoir sampling (uniform without replacement).
    Reservoir,
    /// Stratified by reward into `n_strata` bins.
    Stratified { n_strata: usize },
    /// k-determinantal point process for diverse batches.
    KDPP { k: usize },
}
1126
/// Bounded FIFO buffer of optimization experiences with configurable
/// prioritization and sampling.
pub struct ExperienceReplayBuffer {
    config: ExperienceReplayConfig,
    // Experiences in insertion order; oldest at the front.
    buffer: VecDeque<Experience>,
    // Parallel to `buffer`, index-for-index.
    priorities: Vec<Float>,
    rng: StdRng,
}
1134
1135impl ExperienceReplayBuffer {
1136 pub fn new(config: ExperienceReplayConfig) -> Self {
1137 let rng = StdRng::seed_from_u64(config.random_state.unwrap_or(42));
1138 Self {
1139 config,
1140 buffer: VecDeque::new(),
1141 priorities: Vec::new(),
1142 rng,
1143 }
1144 }
1145
1146 pub fn add(&mut self, experience: Experience) {
1148 if self.buffer.len() >= self.config.buffer_size {
1149 self.buffer.pop_front();
1150 if !self.priorities.is_empty() {
1151 self.priorities.remove(0);
1152 }
1153 }
1154
1155 let priority = self.compute_priority(&experience);
1156 self.buffer.push_back(experience);
1157 self.priorities.push(priority);
1158 }
1159
    /// Draw up to `batch_size` experiences using the configured sampling
    /// strategy.
    ///
    /// # Errors
    /// Fails when the buffer is empty.
    pub fn sample(
        &mut self,
        batch_size: usize,
    ) -> Result<Vec<Experience>, Box<dyn std::error::Error>> {
        if self.buffer.is_empty() {
            return Err("Buffer is empty".into());
        }

        // Never request more than the buffer holds.
        let sample_size = batch_size.min(self.buffer.len());

        match &self.config.sampling_strategy {
            SamplingStrategy::Random => self.sample_random(sample_size),
            SamplingStrategy::Reservoir => self.sample_reservoir(sample_size),
            SamplingStrategy::Stratified { n_strata } => {
                self.sample_stratified(sample_size, *n_strata)
            }
            SamplingStrategy::KDPP { k } => self.sample_kdpp(sample_size, *k),
        }
    }
1180
1181 pub fn replay_update(
1183 &mut self,
1184 optimizer: &mut dyn OptimizationLearner,
1185 ) -> Result<ReplayResult, Box<dyn std::error::Error>> {
1186 let mut total_loss = 0.0;
1187 let mut n_updates = 0;
1188
1189 for _ in 0..self.config.n_replay_updates {
1190 let batch = self.sample(self.config.batch_size)?;
1191 let loss = optimizer.update_from_batch(&batch)?;
1192 total_loss += loss;
1193 n_updates += 1;
1194 }
1195
1196 Ok(ReplayResult {
1197 average_loss: total_loss / n_updates as Float,
1198 n_updates,
1199 buffer_size: self.buffer.len(),
1200 })
1201 }
1202
1203 fn sample_random(&mut self, n: usize) -> Result<Vec<Experience>, Box<dyn std::error::Error>> {
1206 let mut sampled = Vec::new();
1207 let buffer_vec: Vec<_> = self.buffer.iter().collect();
1208
1209 for _ in 0..n {
1210 let idx = self.rng.gen_range(0..self.buffer.len());
1211 sampled.push(buffer_vec[idx].clone());
1212 }
1213
1214 Ok(sampled)
1215 }
1216
1217 fn sample_reservoir(
1218 &mut self,
1219 n: usize,
1220 ) -> Result<Vec<Experience>, Box<dyn std::error::Error>> {
1221 self.sample_random(n) }
1223
1224 fn sample_stratified(
1225 &mut self,
1226 n: usize,
1227 n_strata: usize,
1228 ) -> Result<Vec<Experience>, Box<dyn std::error::Error>> {
1229 let mut strata: Vec<Vec<Experience>> = vec![Vec::new(); n_strata];
1231
1232 for exp in &self.buffer {
1233 let stratum_idx = ((exp.reward * n_strata as Float) as usize).min(n_strata - 1);
1234 strata[stratum_idx].push(exp.clone());
1235 }
1236
1237 let mut sampled = Vec::new();
1239 let per_stratum = n / n_strata;
1240
1241 for stratum in &strata {
1242 if stratum.is_empty() {
1243 continue;
1244 }
1245
1246 for _ in 0..per_stratum.min(stratum.len()) {
1247 let idx = self.rng.gen_range(0..stratum.len());
1248 sampled.push(stratum[idx].clone());
1249 }
1250 }
1251
1252 Ok(sampled)
1253 }
1254
    /// k-DPP (diverse) sampling.
    /// Placeholder: falls back to uniform random sampling; `_k` is unused
    /// until a real determinantal point process is implemented.
    fn sample_kdpp(
        &mut self,
        n: usize,
        _k: usize,
    ) -> Result<Vec<Experience>, Box<dyn std::error::Error>> {
        self.sample_random(n)
    }
1263
    /// Assign a replay priority to a new experience according to the
    /// configured strategy.
    fn compute_priority(&self, experience: &Experience) -> Float {
        match &self.config.prioritization {
            PrioritizationStrategy::Uniform => 1.0,
            PrioritizationStrategy::TDError { alpha } => {
                // |improvement| stands in for the TD error; the epsilon
                // keeps zero-improvement experiences sampleable.
                (experience.improvement.abs() + 1e-6).powf(*alpha)
            }
            PrioritizationStrategy::Improvement { temperature } => {
                // Unnormalized softmax weight over improvement.
                (experience.improvement / temperature).exp()
            }
            PrioritizationStrategy::Recency { decay_rate } => {
                // NOTE(review): decays with current buffer length, not the
                // experience's own age — confirm this is intended.
                (-decay_rate * self.buffer.len() as Float).exp()
            }
            // Diversity-based prioritization not implemented yet.
            PrioritizationStrategy::Diversity { .. } => 1.0,
        }
    }
1280}
1281
/// A single optimization transition stored for replay.
#[derive(Debug, Clone)]
pub struct Experience {
    /// Parameter configuration before the action.
    pub state: HashMap<String, Float>,
    /// Parameter change that was applied.
    pub action: HashMap<String, Float>,
    /// Reward observed after the action.
    pub reward: Float,
    /// Parameter configuration after the action.
    pub next_state: HashMap<String, Float>,
    /// Objective improvement attributed to the action.
    pub improvement: Float,
    /// Insertion-order timestamp.
    pub timestamp: usize,
}
1292
/// Outcome of a batch of replay updates.
#[derive(Debug, Clone)]
pub struct ReplayResult {
    /// Mean loss across the replay updates performed.
    pub average_loss: Float,
    /// Number of updates performed.
    pub n_updates: usize,
    /// Buffer size at the time of the update.
    pub buffer_size: usize,
}
1300
/// A learner that can be trained from batches of replayed experiences.
pub trait OptimizationLearner {
    /// Update the learner from one batch of experiences.
    ///
    /// Returns the batch loss on success, or a boxed error on failure.
    fn update_from_batch(
        &mut self,
        batch: &[Experience],
    ) -> Result<Float, Box<dyn std::error::Error>>;
}
1308
/// Descriptive metadata about an optimization task, used to compare
/// source and target tasks when transferring knowledge.
#[derive(Debug, Clone)]
pub struct TaskCharacteristics {
    // Free-form task category label.
    pub task_type: String,
    // Number of samples in the task's dataset.
    pub n_samples: usize,
    // Number of input features.
    pub n_features: usize,
    // Per-parameter search space. Not included in the serialized form:
    // the manual Serialize/Deserialize impls for this type skip it and
    // restore it as an empty map.
    pub parameter_space: HashMap<String, ParameterRange>,
    // Scalar complexity estimate of the task.
    pub complexity: Float,
}
1322
/// Bounds and scale for one tunable parameter.
#[derive(Debug, Clone)]
pub struct ParameterRange {
    // Lower bound of the parameter's range.
    pub min: Float,
    // Upper bound of the parameter's range.
    pub max: Float,
    // How values are spaced/sampled within the range.
    pub scale: ParameterScale,
}
1329
/// Sampling scale of a parameter range.
#[derive(Debug, Clone)]
pub enum ParameterScale {
    // Uniform spacing over the range.
    Linear,
    // Logarithmic spacing (e.g. for learning-rate-like parameters).
    Log,
    // Discrete, unordered choices.
    Categorical,
}
1336
/// Gaussian summary (mean and standard deviation) of a parameter's
/// distribution, used as transferred prior knowledge.
#[derive(Debug, Clone)]
pub struct ParameterDistribution {
    // Distribution mean.
    pub mean: Float,
    // Distribution standard deviation.
    pub std: Float,
}
1342
/// Record of a completed optimization run on a source task, stored as
/// transferable knowledge for future tasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationExperience {
    // Metadata describing the task the run was performed on.
    pub task_characteristics: TaskCharacteristics,
    // Best parameter values found during the run.
    pub best_parameters: HashMap<String, Float>,
    // Final performance value achieved.
    pub performance: Float,
    // Number of optimization iterations performed.
    pub n_iterations: usize,
    // Per-iteration performance trace.
    pub convergence_curve: Vec<Float>,
}
1352
1353impl Serialize for TaskCharacteristics {
1355 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1356 where
1357 S: serde::Serializer,
1358 {
1359 use serde::ser::SerializeStruct;
1360 let mut state = serializer.serialize_struct("TaskCharacteristics", 5)?;
1361 state.serialize_field("task_type", &self.task_type)?;
1362 state.serialize_field("n_samples", &self.n_samples)?;
1363 state.serialize_field("n_features", &self.n_features)?;
1364 state.serialize_field("complexity", &self.complexity)?;
1365 state.end()
1366 }
1367}
1368
impl<'de> Deserialize<'de> for TaskCharacteristics {
    /// Manual `Deserialize` impl mirroring the manual `Serialize`:
    /// `parameter_space` is not part of the wire format and is restored
    /// as an empty map.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Shadow struct containing only the serialized fields.
        #[derive(Deserialize)]
        struct Helper {
            task_type: String,
            n_samples: usize,
            n_features: usize,
            complexity: Float,
        }

        let helper = Helper::deserialize(deserializer)?;
        Ok(TaskCharacteristics {
            task_type: helper.task_type,
            n_samples: helper.n_samples,
            n_features: helper.n_features,
            // Not serialized; callers must rebuild the search space.
            parameter_space: HashMap::new(),
            complexity: helper.complexity,
        })
    }
}
1392
/// A few-shot optimization episode: support examples for adaptation and
/// query examples for evaluation.
#[derive(Debug, Clone)]
pub struct OptimizationTask {
    // Unique identifier of this task.
    pub task_id: String,
    // (parameter map, score) pairs used for adaptation.
    pub support_examples: Vec<(HashMap<String, Float>, Float)>,
    // (parameter map, score) pairs held out for evaluation.
    pub query_examples: Vec<(HashMap<String, Float>, Float)>,
}
1400
#[cfg(test)]
mod tests {
    use super::*;

    // Default transfer-learning config exposes the documented defaults.
    #[test]
    fn test_transfer_learning_config() {
        let config = TransferLearningConfig::default();
        assert_eq!(config.n_init_samples, 5);
        assert_eq!(config.max_transfer_iterations, 50);
    }

    // Default few-shot episode sizes.
    #[test]
    fn test_few_shot_config() {
        let config = FewShotConfig::default();
        assert_eq!(config.n_support, 5);
        assert_eq!(config.n_query, 10);
    }

    // Learn-to-optimize defaults: task count and recurrent flag.
    #[test]
    fn test_learn2optimize_config() {
        let config = Learn2OptimizeConfig::default();
        assert_eq!(config.n_training_tasks, 100);
        assert!(config.use_recurrent);
    }

    // Adding one experience grows the replay buffer to length 1.
    #[test]
    fn test_experience_replay_buffer() {
        let config = ExperienceReplayConfig::default();
        let mut buffer = ExperienceReplayBuffer::new(config);

        let experience = Experience {
            state: HashMap::new(),
            action: HashMap::new(),
            reward: 0.8,
            next_state: HashMap::new(),
            improvement: 0.1,
            timestamp: 0,
        };

        buffer.add(experience.clone());
        assert_eq!(buffer.buffer.len(), 1);
    }

    // A fresh transfer-learning optimizer starts with no transferred
    // knowledge and no recorded transfer performance.
    #[test]
    fn test_transfer_learning_optimizer() {
        let config = TransferLearningConfig::default();
        let optimizer = TransferLearningOptimizer::new(config);

        assert!(optimizer.transferred_knowledge.is_empty());
        assert!(optimizer.transfer_performance.is_empty());
    }

    // A fresh few-shot optimizer starts with empty meta-parameters and
    // episode history.
    #[test]
    fn test_few_shot_optimizer() {
        let config = FewShotConfig::default();
        let optimizer = FewShotOptimizer::new(config);

        assert!(optimizer.meta_parameters.is_empty());
        assert!(optimizer.episode_history.is_empty());
    }

    // A fresh learned optimizer starts with empty training history.
    #[test]
    fn test_learned_optimizer() {
        let config = Learn2OptimizeConfig::default();
        let optimizer = LearnedOptimizer::new(config);

        assert!(optimizer.training_history.is_empty());
    }
}