make_domain_adaptation_dataset

Function make_domain_adaptation_dataset 

Source
pub fn make_domain_adaptation_dataset(
    n_samples_per_domain: usize,
    n_features: usize,
    n_classes: usize,
    config: DomainAdaptationConfig,
) -> Result<DomainAdaptationDataset>
Expand description

Generates a domain adaptation dataset.

Examples found in repository?
examples/advanced_generators_demo.rs (line 308)
293fn demonstrate_domain_adaptation() -> Result<(), Box<dyn std::error::Error>> {
294    println!("🌐 DOMAIN ADAPTATION DATASETS");
295    println!("{}", "-".repeat(35));
296
297    println!("Domain adaptation scenario: Cross-domain sentiment analysis");
298
299    let config = DomainAdaptationConfig {
300        n_source_domains: 3,
301        domain_shifts: vec![], // Will use default shifts
302        label_shift: true,
303        feature_shift: true,
304        concept_drift: false,
305        random_state: Some(42),
306    };
307
308    let domaindataset = make_domain_adaptation_dataset(800, 25, 3, config)?;
309
310    println!("  📊 Domain adaptation structure:");
311    println!("    Total domains: {}", domaindataset.domains.len());
312    println!("    Source domains: {}", domaindataset.n_source_domains);
313
314    for (domainname, dataset) in &domaindataset.domains {
315        println!(
316            "    {}: {} samples, {} features",
317            domainname,
318            dataset.n_samples(),
319            dataset.n_features()
320        );
321
322        // Analyze domain characteristics
323        if let Some(target) = &dataset.target {
324            let class_distribution = analyze_class_distribution(target);
325            println!("      Class distribution: {class_distribution:?}");
326        }
327
328        // Calculate domain statistics
329        let feature_stats = calculate_domain_statistics(&dataset.data);
330        println!(
331            "      Feature mean: {:.3}, std: {:.3}",
332            feature_stats.0, feature_stats.1
333        );
334    }
335
336    // Domain shift analysis
337    println!("\n  🔄 Domain shift analysis:");
338    analyze_domain_shifts(&domaindataset);
339
340    println!();
341    Ok(())
342}
343
344#[allow(dead_code)]
345fn demonstrate_few_shot_learning() -> Result<(), Box<dyn std::error::Error>> {
346    println!("🎯 FEW-SHOT LEARNING DATASETS");
347    println!("{}", "-".repeat(35));
348
349    let few_shot_scenarios = vec![
350        ("5-way 1-shot", 5, 1, 15),
351        ("5-way 5-shot", 5, 5, 10),
352        ("10-way 3-shot", 10, 3, 12),
353    ];
354
355    for (name, n_way, k_shot, n_query) in few_shot_scenarios {
356        println!("\nGenerating {name} dataset:");
357
358        let dataset = make_few_shot_dataset(n_way, k_shot, n_query, 5, 20)?;
359
360        println!("  📊 Few-shot configuration:");
361        println!("    Ways (classes): {}", dataset.n_way);
362        println!("    Shots per class: {}", dataset.k_shot);
363        println!("    Query samples per class: {}", dataset.n_query);
364        println!("    Episodes: {}", dataset.episodes.len());
365
366        // Analyze episode characteristics
367        for (i, episode) in dataset.episodes.iter().enumerate().take(2) {
368            println!("    Episode {}:", i + 1);
369            println!(
370                "      Support set: {} samples",
371                episode.support_set.n_samples()
372            );
373            println!("      Query set: {} samples", episode.query_set.n_samples());
374
375            // Calculate class balance in support set
376            if let Some(support_target) = &episode.support_set.target {
377                let balance = calculate_class_balance(support_target, n_way);
378                println!("      Support balance: {balance:.2}");
379            }
380        }
381
382        println!("  💡 Use case: {}", get_few_shot_use_case(n_way, k_shot));
383    }
384
385    println!();
386    Ok(())
387}
388
389#[allow(dead_code)]
390fn demonstrate_continual_learning() -> Result<(), Box<dyn std::error::Error>> {
391    println!("📚 CONTINUAL LEARNING DATASETS");
392    println!("{}", "-".repeat(35));
393
394    let drift_strengths = vec![
395        ("Mild drift", 0.2),
396        ("Moderate drift", 0.5),
397        ("Severe drift", 1.0),
398    ];
399
400    for (name, drift_strength) in drift_strengths {
401        println!("\nGenerating {name} scenario:");
402
403        let dataset = make_continual_learning_dataset(5, 500, 15, 4, drift_strength)?;
404
405        println!("  📊 Continual learning structure:");
406        println!("    Number of tasks: {}", dataset.tasks.len());
407        println!(
408            "    Concept drift strength: {:.1}",
409            dataset.concept_drift_strength
410        );
411
412        // Analyze concept drift between tasks
413        analyze_concept_drift(&dataset);
414
415        // Recommend continual learning strategies
416        println!(
417            "  💡 Recommended strategies: {}",
418            get_continual_learning_strategies(drift_strength)
419        );
420    }
421
422    // Catastrophic forgetting simulation
423    println!("\nCatastrophic forgetting analysis:");
424    simulate_catastrophic_forgetting()?;
425
426    println!();
427    Ok(())
428}
429
430#[allow(dead_code)]
431fn demonstrate_advanced_applications() -> Result<(), Box<dyn std::error::Error>> {
432    println!("🚀 ADVANCED APPLICATIONS");
433    println!("{}", "-".repeat(25));
434
435    // Meta-learning scenario
436    println!("Meta-learning scenario:");
437    demonstrate_meta_learning_setup()?;
438
439    // Robust machine learning
440    println!("\nRobust ML scenario:");
441    demonstrate_robust_ml_setup()?;
442
443    // Federated learning simulation
444    println!("\nFederated learning scenario:");
445    demonstrate_federated_learning_setup()?;
446
447    Ok(())
448}
449
450// Helper functions for analysis
451
452#[allow(dead_code)]
453fn calculate_perturbation_norm(
454    original: &scirs2_datasets::Dataset,
455    adversarial: &scirs2_datasets::Dataset,
456) -> f64 {
457    let diff = &adversarial.data - &original.data;
458    let norm = diff.iter().map(|&x| x * x).sum::<f64>().sqrt();
459    norm / (original.n_samples() * original.n_features()) as f64
460}
461
462#[allow(dead_code)]
463fn calculate_anomaly_separation(dataset: &scirs2_datasets::Dataset) -> f64 {
464    // Simplified separation metric
465    if let Some(target) = &dataset.target {
466        let normal_indices: Vec<usize> = target
467            .iter()
468            .enumerate()
469            .filter_map(|(i, &label)| if label == 0.0 { Some(i) } else { None })
470            .collect();
471        let anomaly_indices: Vec<usize> = target
472            .iter()
473            .enumerate()
474            .filter_map(|(i, &label)| if label == 1.0 { Some(i) } else { None })
475            .collect();
476
477        if normal_indices.is_empty() || anomaly_indices.is_empty() {
478            return 0.0;
479        }
480
481        // Calculate average distances
482        let normal_center = calculate_centroid(&dataset.data, &normal_indices);
483        let anomaly_center = calculate_centroid(&dataset.data, &anomaly_indices);
484
485        let distance = (&normal_center - &anomaly_center)
486            .iter()
487            .map(|&x| x * x)
488            .sum::<f64>()
489            .sqrt();
490        distance / dataset.n_features() as f64
491    } else {
492        0.0
493    }
494}
495
496#[allow(dead_code)]
497fn calculate_centroid(
498    data: &scirs2_core::ndarray::Array2<f64>,
499    indices: &[usize],
500) -> scirs2_core::ndarray::Array1<f64> {
501    let mut centroid = scirs2_core::ndarray::Array1::zeros(data.ncols());
502    for &idx in indices {
503        centroid = centroid + data.row(idx);
504    }
505    centroid / indices.len() as f64
506}
507
508#[allow(dead_code)]
509fn get_recommended_anomaly_algorithms(_anomalytype: &AnomalyType) -> &'static str {
510    match _anomalytype {
511        AnomalyType::Point => "Isolation Forest, Local Outlier Factor, One-Class SVM",
512        AnomalyType::Contextual => "LSTM Autoencoders, Hidden Markov Models",
513        AnomalyType::Collective => "Graph-based methods, Sequential pattern mining",
514        AnomalyType::Mixed => "Ensemble methods, Deep anomaly detection",
515        AnomalyType::Adversarial => "Robust statistical methods, Adversarial training",
516    }
517}
518
519#[allow(dead_code)]
520fn analyze_classification_target(target: &scirs2_core::ndarray::Array1<f64>) -> usize {
521    let mut classes = std::collections::HashSet::new();
522    for &label in target.iter() {
523        classes.insert(label as i32);
524    }
525    classes.len()
526}
527
528#[allow(dead_code)]
529fn analyze_regression_target(target: &scirs2_core::ndarray::Array1<f64>) -> (f64, f64) {
530    let mean = target.mean().unwrap_or(0.0);
531    let std = target.std(0.0);
532    (mean, std)
533}
534
535#[allow(dead_code)]
536fn analyze_ordinal_target(target: &scirs2_core::ndarray::Array1<f64>) -> usize {
537    let max_level = target.iter().fold(0.0f64, |a, &b| a.max(b)) as usize;
538    max_level + 1
539}
540
541#[allow(dead_code)]
542fn analyze_task_relationships(multitaskdataset: &MultiTaskDataset) {
543    println!("  🔗 Task relationship analysis:");
544    println!(
545        "    Shared feature ratio: {:.1}%",
546        (multitaskdataset.shared_features as f64 / multitaskdataset.tasks[0].n_features() as f64)
547            * 100.0
548    );
549    println!(
550        "    Task correlation: {:.2}",
551        multitaskdataset.task_correlation
552    );
553
554    if multitaskdataset.task_correlation > 0.7 {
555        println!("    💡 High correlation suggests strong transfer learning potential");
556    } else if multitaskdataset.task_correlation > 0.3 {
557        println!("    💡 Moderate correlation indicates selective transfer benefits");
558    } else {
559        println!("    💡 Low correlation requires careful negative transfer mitigation");
560    }
561}
562
563#[allow(dead_code)]
564fn analyze_class_distribution(target: &scirs2_core::ndarray::Array1<f64>) -> HashMap<i32, usize> {
565    let mut distribution = HashMap::new();
566    for &label in target.iter() {
567        *distribution.entry(label as i32).or_insert(0) += 1;
568    }
569    distribution
570}
571
572#[allow(dead_code)]
573fn calculate_domain_statistics(data: &scirs2_core::ndarray::Array2<f64>) -> (f64, f64) {
574    let mean = data.mean().unwrap_or(0.0);
575    let std = data.std(0.0);
576    (mean, std)
577}
578
579#[allow(dead_code)]
580fn analyze_domain_shifts(domaindataset: &DomainAdaptationDataset) {
581    if domaindataset.domains.len() >= 2 {
582        let source_stats = calculate_domain_statistics(&domaindataset.domains[0].1.data);
583        let target_stats =
584            calculate_domain_statistics(&domaindataset.domains.last().unwrap().1.data);
585
586        let mean_shift = (target_stats.0 - source_stats.0).abs();
587        let std_shift = (target_stats.1 - source_stats.1).abs();
588
589        println!("    Mean shift magnitude: {mean_shift:.3}");
590        println!("    Std shift magnitude: {std_shift:.3}");
591
592        if mean_shift > 0.5 || std_shift > 0.3 {
593            println!("    💡 Significant domain shift detected - adaptation needed");
594        } else {
595            println!("    💡 Mild domain shift - simple adaptation may suffice");
596        }
597    }
598}
599
600#[allow(dead_code)]
601fn calculate_class_balance(target: &scirs2_core::ndarray::Array1<f64>, nclasses: usize) -> f64 {
602    let mut class_counts = vec![0; nclasses];
603    for &label in target.iter() {
604        let class_idx = label as usize;
605        if class_idx < nclasses {
606            class_counts[class_idx] += 1;
607        }
608    }
609
610    let total = target.len() as f64;
611    let expected_per_class = total / nclasses as f64;
612
613    let balance_score = class_counts
614        .iter()
615        .map(|&count| (count as f64 - expected_per_class).abs())
616        .sum::<f64>()
617        / (nclasses as f64 * expected_per_class);
618
619    1.0 - balance_score.min(1.0) // Higher score = better balance
620}
621
/// Returns a short description of a typical use case for an
/// `_n_way`-way / `kshot`-shot configuration.
///
/// Rules are tried in order: 5-way 1-shot, 5-way 5-shot, any 10-way setup,
/// any 1-shot setup, then a generic fallback.
#[allow(dead_code)]
fn get_few_shot_use_case(_n_way: usize, kshot: usize) -> &'static str {
    if _n_way == 5 && kshot == 1 {
        return "Image classification with minimal examples";
    }
    if _n_way == 5 && kshot == 5 {
        return "Balanced few-shot learning benchmark";
    }
    // The 10-way rule is checked before the 1-shot rule, so 10-way 1-shot lands here.
    if _n_way == 10 {
        return "Multi-class few-shot classification";
    }
    if kshot == 1 {
        return "One-shot learning scenario";
    }
    "General few-shot learning"
}
632
633#[allow(dead_code)]
634fn analyze_concept_drift(dataset: &scirs2_datasets::ContinualLearningDataset) {
635    println!("    Task progression analysis:");
636
637    for i in 1..dataset.tasks.len() {
638        let prev_stats = calculate_domain_statistics(&dataset.tasks[i - 1].data);
639        let curr_stats = calculate_domain_statistics(&dataset.tasks[i].data);
640
641        let drift_magnitude =
642            ((curr_stats.0 - prev_stats.0).powi(2) + (curr_stats.1 - prev_stats.1).powi(2)).sqrt();
643
644        println!(
645            "      Task {} → {}: drift = {:.3}",
646            i,
647            i + 1,
648            drift_magnitude
649        );
650    }
651}
652
/// Suggests continual learning strategies for a given drift strength.
///
/// Bands: below `0.3`, below `0.7`, and everything else (which includes NaN,
/// since both comparisons are false for it).
#[allow(dead_code)]
fn get_continual_learning_strategies(_driftstrength: f64) -> &'static str {
    match _driftstrength {
        mild if mild < 0.3 => "Fine-tuning, Elastic Weight Consolidation",
        moderate if moderate < 0.7 => "Progressive Neural Networks, Learning without Forgetting",
        _ => "Memory replay, Meta-learning approaches, Dynamic architectures",
    }
}
663
664#[allow(dead_code)]
665fn simulate_catastrophic_forgetting() -> Result<(), Box<dyn std::error::Error>> {
666    let dataset = make_continual_learning_dataset(3, 200, 10, 3, 0.8)?;
667
668    println!("  Simulating catastrophic forgetting:");
669    println!("    📉 Task 1 performance after Task 2: ~60% (typical drop)");
670    println!("    📉 Task 1 performance after Task 3: ~40% (severe forgetting)");
671    println!("    💡 Recommendation: Use rehearsal or regularization techniques");
672
673    Ok(())
674}
675
676#[allow(dead_code)]
677fn demonstrate_meta_learning_setup() -> Result<(), Box<dyn std::error::Error>> {
678    let few_shotdata = make_few_shot_dataset(5, 3, 10, 20, 15)?;
679
680    println!("  🧠 Meta-learning (MAML) setup:");
681    println!(
682        "    Meta-training episodes: {}",
683        few_shotdata.episodes.len()
684    );
685    println!(
686        "    Support/Query split per episode: {}/{} samples per class",
687        few_shotdata.k_shot, few_shotdata.n_query
688    );
689    println!("    💡 Goal: Learn to learn quickly from few examples");
690
691    Ok(())
692}
693
694#[allow(dead_code)]
695fn demonstrate_robust_ml_setup() -> Result<(), Box<dyn std::error::Error>> {
696    let basedataset = make_classification(500, 15, 3, 2, 10, Some(42))?;
697
698    // Generate multiple adversarial versions
699    let attacks = vec![
700        ("FGSM", AttackMethod::FGSM, 0.1),
701        ("PGD", AttackMethod::PGD, 0.05),
702    ];
703
704    println!("  🛡️ Robust ML training setup:");
705    println!("    Clean samples: {}", basedataset.n_samples());
706
707    for (name, method, epsilon) in attacks {
708        let config = AdversarialConfig {
709            attack_method: method,
710            epsilon,
711            ..Default::default()
712        };
713
714        let advdataset = make_adversarial_examples(&basedataset, config)?;
715        println!(
716            "    {} adversarial samples: {}",
717            name,
718            advdataset.n_samples()
719        );
720    }
721
722    println!("    💡 Goal: Train models robust to adversarial perturbations");
723
724    Ok(())
725}
726
727#[allow(dead_code)]
728fn demonstrate_federated_learning_setup() -> Result<(), Box<dyn std::error::Error>> {
729    let domaindata = make_domain_adaptation_dataset(
730        300,
731        20,
732        4,
733        DomainAdaptationConfig {
734            n_source_domains: 4, // 4 clients + 1 server
735            ..Default::default()
736        },
737    )?;
738
739    println!("  🌐 Federated learning simulation:");
740    println!("    Participating clients: {}", domaindata.n_source_domains);
741
742    for (i, (_domainname, dataset)) in domaindata.domains.iter().enumerate() {
743        if i < domaindata.n_source_domains {
744            println!(
745                "    Client {}: {} samples (private data)",
746                i + 1,
747                dataset.n_samples()
748            );
749        } else {
750            println!("    Global test set: {} samples", dataset.n_samples());
751        }
752    }
753
754    println!("    💡 Goal: Collaborative learning without data sharing");
755
756    Ok(())
757}