pub fn make_continual_learning_dataset(
n_tasks: usize,
n_samples_per_task: usize,
n_features: usize,
n_classes: usize,
concept_drift_strength: f64,
) -> Result<ContinualLearningDataset>
Generate a continual learning dataset: a sequence of n_tasks classification tasks, each containing n_samples_per_task samples with n_features features and n_classes classes. concept_drift_strength controls how strongly the data distribution shifts between consecutive tasks; larger values produce stronger drift.
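A minimal usage sketch (hedged: the tasks field and the n_samples()/n_features() accessors mirror the repository example below, and conversion of the error type to Box<dyn std::error::Error> via ? is likewise taken from that example):

use scirs2_datasets::make_continual_learning_dataset;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 3 tasks of 200 samples each, 10 features, 3 classes, moderate drift.
    let dataset = make_continual_learning_dataset(3, 200, 10, 3, 0.5)?;

    // One Dataset per task; inspect the progression.
    for (i, task) in dataset.tasks.iter().enumerate() {
        println!(
            "task {}: {} samples x {} features",
            i + 1,
            task.n_samples(),
            task.n_features()
        );
    }
    Ok(())
}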
Examples found in repository
examples/advanced_generators_demo.rs (line 403)
390fn demonstrate_continual_learning() -> Result<(), Box<dyn std::error::Error>> {
391 println!("📚 CONTINUAL LEARNING DATASETS");
392 println!("{}", "-".repeat(35));
393
394 let drift_strengths = vec![
395 ("Mild drift", 0.2),
396 ("Moderate drift", 0.5),
397 ("Severe drift", 1.0),
398 ];
399
400 for (name, drift_strength) in drift_strengths {
401 println!("\nGenerating {name} scenario:");
402
403 let dataset = make_continual_learning_dataset(5, 500, 15, 4, drift_strength)?;
404
405 println!(" 📊 Continual learning structure:");
406 println!(" Number of tasks: {}", dataset.tasks.len());
407 println!(
408 " Concept drift strength: {:.1}",
409 dataset.concept_drift_strength
410 );
411
412 // Analyze concept drift between tasks
413 analyze_concept_drift(&dataset);
414
415 // Recommend continual learning strategies
416 println!(
417 " 💡 Recommended strategies: {}",
418 get_continual_learning_strategies(drift_strength)
419 );
420 }
421
422 // Catastrophic forgetting simulation
423 println!("\nCatastrophic forgetting analysis:");
424 simulate_catastrophic_forgetting()?;
425
426 println!();
427 Ok(())
428}
429
430#[allow(dead_code)]
431fn demonstrate_advanced_applications() -> Result<(), Box<dyn std::error::Error>> {
432 println!("🚀 ADVANCED APPLICATIONS");
433 println!("{}", "-".repeat(25));
434
435 // Meta-learning scenario
436 println!("Meta-learning scenario:");
437 demonstrate_meta_learning_setup()?;
438
439 // Robust machine learning
440 println!("\nRobust ML scenario:");
441 demonstrate_robust_ml_setup()?;
442
443 // Federated learning simulation
444 println!("\nFederated learning scenario:");
445 demonstrate_federated_learning_setup()?;
446
447 Ok(())
448}
449
450// Helper functions for analysis
451
452#[allow(dead_code)]
453fn calculate_perturbation_norm(
454 original: &scirs2_datasets::Dataset,
455 adversarial: &scirs2_datasets::Dataset,
456) -> f64 {
457 let diff = &adversarial.data - &original.data;
458 let norm = diff.iter().map(|&x| x * x).sum::<f64>().sqrt();
459 norm / (original.n_samples() * original.n_features()) as f64
460}
461
462#[allow(dead_code)]
463fn calculate_anomaly_separation(dataset: &scirs2_datasets::Dataset) -> f64 {
464 // Simplified separation metric
465 if let Some(target) = &dataset.target {
466 let normal_indices: Vec<usize> = target
467 .iter()
468 .enumerate()
469 .filter_map(|(i, &label)| if label == 0.0 { Some(i) } else { None })
470 .collect();
471 let anomaly_indices: Vec<usize> = target
472 .iter()
473 .enumerate()
474 .filter_map(|(i, &label)| if label == 1.0 { Some(i) } else { None })
475 .collect();
476
477 if normal_indices.is_empty() || anomaly_indices.is_empty() {
478 return 0.0;
479 }
480
481 // Calculate average distances
482 let normal_center = calculate_centroid(&dataset.data, &normal_indices);
483 let anomaly_center = calculate_centroid(&dataset.data, &anomaly_indices);
484
485 let distance = (&normal_center - &anomaly_center)
486 .iter()
487 .map(|&x| x * x)
488 .sum::<f64>()
489 .sqrt();
490 distance / dataset.n_features() as f64
491 } else {
492 0.0
493 }
494}
495
496#[allow(dead_code)]
497fn calculate_centroid(
498 data: &scirs2_core::ndarray::Array2<f64>,
499 indices: &[usize],
500) -> scirs2_core::ndarray::Array1<f64> {
501 let mut centroid = scirs2_core::ndarray::Array1::zeros(data.ncols());
502 for &idx in indices {
503 centroid = centroid + data.row(idx);
504 }
505 centroid / indices.len() as f64
506}
507
508#[allow(dead_code)]
509fn get_recommended_anomaly_algorithms(anomaly_type: &AnomalyType) -> &'static str {
510 match anomaly_type {
511 AnomalyType::Point => "Isolation Forest, Local Outlier Factor, One-Class SVM",
512 AnomalyType::Contextual => "LSTM Autoencoders, Hidden Markov Models",
513 AnomalyType::Collective => "Graph-based methods, Sequential pattern mining",
514 AnomalyType::Mixed => "Ensemble methods, Deep anomaly detection",
515 AnomalyType::Adversarial => "Robust statistical methods, Adversarial training",
516 }
517}
518
519#[allow(dead_code)]
520fn analyze_classification_target(target: &scirs2_core::ndarray::Array1<f64>) -> usize {
521 let mut classes = std::collections::HashSet::new();
522 for &label in target.iter() {
523 classes.insert(label as i32);
524 }
525 classes.len()
526}
527
528#[allow(dead_code)]
529fn analyze_regression_target(target: &scirs2_core::ndarray::Array1<f64>) -> (f64, f64) {
530 let mean = target.mean().unwrap_or(0.0);
531 let std = target.std(0.0);
532 (mean, std)
533}
534
535#[allow(dead_code)]
536fn analyze_ordinal_target(target: &scirs2_core::ndarray::Array1<f64>) -> usize {
537 let max_level = target.iter().fold(0.0f64, |a, &b| a.max(b)) as usize;
538 max_level + 1
539}
540
541#[allow(dead_code)]
542fn analyze_task_relationships(multitask_dataset: &MultiTaskDataset) {
543 println!(" 🔗 Task relationship analysis:");
544 println!(
545 " Shared feature ratio: {:.1}%",
546 (multitask_dataset.shared_features as f64 / multitask_dataset.tasks[0].n_features() as f64)
547 * 100.0
548 );
549 println!(
550 " Task correlation: {:.2}",
551 multitask_dataset.task_correlation
552 );
553
554 if multitask_dataset.task_correlation > 0.7 {
555 println!(" 💡 High correlation suggests strong transfer learning potential");
556 } else if multitask_dataset.task_correlation > 0.3 {
557 println!(" 💡 Moderate correlation indicates selective transfer benefits");
558 } else {
559 println!(" 💡 Low correlation requires careful negative transfer mitigation");
560 }
561}
562
563#[allow(dead_code)]
564fn analyze_class_distribution(target: &scirs2_core::ndarray::Array1<f64>) -> HashMap<i32, usize> {
565 let mut distribution = HashMap::new();
566 for &label in target.iter() {
567 *distribution.entry(label as i32).or_insert(0) += 1;
568 }
569 distribution
570}
571
572#[allow(dead_code)]
573fn calculate_domain_statistics(data: &scirs2_core::ndarray::Array2<f64>) -> (f64, f64) {
574 let mean = data.mean().unwrap_or(0.0);
575 let std = data.std(0.0);
576 (mean, std)
577}
578
579#[allow(dead_code)]
580fn analyze_domain_shifts(domain_dataset: &DomainAdaptationDataset) {
581 if domain_dataset.domains.len() >= 2 {
582 let source_stats = calculate_domain_statistics(&domain_dataset.domains[0].1.data);
583 let target_stats =
584 calculate_domain_statistics(&domain_dataset.domains.last().unwrap().1.data);
585
586 let mean_shift = (target_stats.0 - source_stats.0).abs();
587 let std_shift = (target_stats.1 - source_stats.1).abs();
588
589 println!(" Mean shift magnitude: {mean_shift:.3}");
590 println!(" Std shift magnitude: {std_shift:.3}");
591
592 if mean_shift > 0.5 || std_shift > 0.3 {
593 println!(" 💡 Significant domain shift detected - adaptation needed");
594 } else {
595 println!(" 💡 Mild domain shift - simple adaptation may suffice");
596 }
597 }
598}
599
600#[allow(dead_code)]
601fn calculate_class_balance(target: &scirs2_core::ndarray::Array1<f64>, n_classes: usize) -> f64 {
602 let mut class_counts = vec![0; n_classes];
603 for &label in target.iter() {
604 let class_idx = label as usize;
605 if class_idx < n_classes {
606 class_counts[class_idx] += 1;
607 }
608 }
609
610 let total = target.len() as f64;
611 let expected_per_class = total / n_classes as f64;
612
613 let balance_score = class_counts
614 .iter()
615 .map(|&count| (count as f64 - expected_per_class).abs())
616 .sum::<f64>()
617 / (n_classes as f64 * expected_per_class);
618
619 1.0 - balance_score.min(1.0) // Higher score = better balance
620}
621
622#[allow(dead_code)]
623fn get_few_shot_use_case(n_way: usize, k_shot: usize) -> &'static str {
624 match (n_way, k_shot) {
625 (5, 1) => "Image classification with minimal examples",
626 (5, 5) => "Balanced few-shot learning benchmark",
627 (10, _) => "Multi-class few-shot classification",
628 (_, 1) => "One-shot learning scenario",
629 _ => "General few-shot learning",
630 }
631}
632
633#[allow(dead_code)]
634fn analyze_concept_drift(dataset: &scirs2_datasets::ContinualLearningDataset) {
635 println!(" Task progression analysis:");
636
637 for i in 1..dataset.tasks.len() {
638 let prev_stats = calculate_domain_statistics(&dataset.tasks[i - 1].data);
639 let curr_stats = calculate_domain_statistics(&dataset.tasks[i].data);
640
641 let drift_magnitude =
642 ((curr_stats.0 - prev_stats.0).powi(2) + (curr_stats.1 - prev_stats.1).powi(2)).sqrt();
643
644 println!(
645 " Task {} → {}: drift = {:.3}",
646 i,
647 i + 1,
648 drift_magnitude
649 );
650 }
651}
652
653#[allow(dead_code)]
654fn get_continual_learning_strategies(drift_strength: f64) -> &'static str {
655 if drift_strength < 0.3 {
656 "Fine-tuning, Elastic Weight Consolidation"
657 } else if drift_strength < 0.7 {
658 "Progressive Neural Networks, Learning without Forgetting"
659 } else {
660 "Memory replay, Meta-learning approaches, Dynamic architectures"
661 }
662}
663
664#[allow(dead_code)]
665fn simulate_catastrophic_forgetting() -> Result<(), Box<dyn std::error::Error>> {
666 let _dataset = make_continual_learning_dataset(3, 200, 10, 3, 0.8)?; // underscore: generated but not consumed in this demo
667
668 println!(" Simulating catastrophic forgetting:");
669 println!(" 📉 Task 1 performance after Task 2: ~60% (typical drop)");
670 println!(" 📉 Task 1 performance after Task 3: ~40% (severe forgetting)");
671 println!(" 💡 Recommendation: Use rehearsal or regularization techniques");
672
673 Ok(())
674}
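For intuition on the calculate_class_balance helper above, a small hedged check (it assumes the helper is in scope; the expected value follows directly from the helper's arithmetic):

use scirs2_core::ndarray::Array1;

// 90 samples of class 0 and 10 of class 1 with n_classes = 2:
// expected_per_class = 50, |90 - 50| + |10 - 50| = 80,
// balance = 1.0 - 80 / (2 * 50) = 0.2; a uniform 50/50 split scores 1.0.
let target: Array1<f64> = std::iter::repeat(0.0)
    .take(90)
    .chain(std::iter::repeat(1.0).take(10))
    .collect();
let score = calculate_class_balance(&target, 2);
assert!((score - 0.2).abs() < 1e-12);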