advanced_generators_demo/
advanced_generators_demo.rs

1//! Advanced synthetic data generators demonstration
2//!
3//! This example demonstrates sophisticated synthetic data generation capabilities
4//! for modern machine learning scenarios including adversarial examples, anomaly detection,
5//! multi-task learning, domain adaptation, few-shot learning, and continual learning.
6//!
7//! Usage:
8//!   cargo run --example advanced_generators_demo --release
9
10use scirs2_datasets::{
11    make_adversarial_examples, make_anomaly_dataset, make_classification,
12    make_continual_learning_dataset, make_domain_adaptation_dataset, make_few_shot_dataset,
13    make_multitask_dataset, AdversarialConfig, AnomalyConfig, AnomalyType, AttackMethod,
14    DomainAdaptationConfig, DomainAdaptationDataset, MultiTaskConfig, MultiTaskDataset, TaskType,
15};
16use statrs::statistics::Statistics;
17use std::collections::HashMap;
18
19#[allow(dead_code)]
20fn main() -> Result<(), Box<dyn std::error::Error>> {
21    println!("🧬 Advanced Synthetic Data Generators Demonstration");
22    println!("===================================================\n");
23
24    // Adversarial examples generation
25    demonstrate_adversarial_examples()?;
26
27    // Anomaly detection datasets
28    demonstrate_anomaly_detection()?;
29
30    // Multi-task learning datasets
31    demonstrate_multitask_learning()?;
32
33    // Domain adaptation scenarios
34    demonstrate_domain_adaptation()?;
35
36    // Few-shot learning datasets
37    demonstrate_few_shot_learning()?;
38
39    // Continual learning with concept drift
40    demonstrate_continual_learning()?;
41
42    // Advanced analysis and applications
43    demonstrate_advanced_applications()?;
44
45    println!("\nšŸŽ‰ Advanced generators demonstration completed!");
46    Ok(())
47}
48
49#[allow(dead_code)]
50fn demonstrate_adversarial_examples() -> Result<(), Box<dyn std::error::Error>> {
51    println!("šŸ›”ļø ADVERSARIAL EXAMPLES GENERATION");
52    println!("{}", "-".repeat(45));
53
54    // Create a base classification dataset
55    let basedataset = make_classification(1000, 20, 5, 2, 15, Some(42))?;
56    println!(
57        "Base dataset: {} samples, {} features, {} classes",
58        basedataset.n_samples(),
59        basedataset.n_features(),
60        5
61    );
62
63    // Test different attack methods
64    let attack_methods = vec![
65        ("FGSM", AttackMethod::FGSM, 0.1),
66        ("PGD", AttackMethod::PGD, 0.05),
67        ("Random Noise", AttackMethod::RandomNoise, 0.2),
68    ];
69
70    for (name, method, epsilon) in attack_methods {
71        println!("\nGenerating {name} adversarial examples:");
72
73        let config = AdversarialConfig {
74            epsilon,
75            attack_method: method,
76            target_class: None, // Untargeted attack
77            iterations: 10,
78            step_size: 0.01,
79            random_state: Some(42),
80        };
81
82        let adversarialdataset = make_adversarial_examples(&basedataset, config)?;
83
84        // Analyze perturbation strength
85        let perturbation_norm = calculate_perturbation_norm(&basedataset, &adversarialdataset);
86
87        println!(
88            "  āœ… Generated {} adversarial examples",
89            adversarialdataset.n_samples()
90        );
91        println!("  šŸ“Š Perturbation strength: {perturbation_norm:.4}");
92        println!("  šŸŽÆ Attack budget (ε): {epsilon:.2}");
93        println!(
94            "  šŸ“ˆ Expected robustness impact: {:.1}%",
95            (1.0 - perturbation_norm) * 100.0
96        );
97    }
98
99    // Targeted attack example
100    println!("\nTargeted adversarial attack:");
101    let targeted_config = AdversarialConfig {
102        epsilon: 0.1,
103        attack_method: AttackMethod::FGSM,
104        target_class: Some(2), // Target class 2
105        iterations: 5,
106        random_state: Some(42),
107        ..Default::default()
108    };
109
110    let targeted_adversarial = make_adversarial_examples(&basedataset, targeted_config)?;
111
112    if let Some(target) = &targeted_adversarial.target {
113        let target_class_count = target.iter().filter(|&&x| x == 2.0).count();
114        println!(
115            "  šŸŽÆ Targeted to class 2: {}/{} samples",
116            target_class_count,
117            target.len()
118        );
119    }
120
121    println!();
122    Ok(())
123}
124
125#[allow(dead_code)]
126fn demonstrate_anomaly_detection() -> Result<(), Box<dyn std::error::Error>> {
127    println!("šŸ” ANOMALY DETECTION DATASETS");
128    println!("{}", "-".repeat(35));
129
130    let anomaly_scenarios = vec![
131        ("Point Anomalies", AnomalyType::Point, 0.05, 3.0),
132        ("Contextual Anomalies", AnomalyType::Contextual, 0.08, 2.0),
133        ("Mixed Anomalies", AnomalyType::Mixed, 0.10, 2.5),
134    ];
135
136    for (name, anomaly_type, fraction, severity) in anomaly_scenarios {
137        println!("\nGenerating {name} dataset:");
138
139        let config = AnomalyConfig {
140            anomaly_fraction: fraction,
141            anomaly_type: anomaly_type.clone(),
142            severity,
143            mixed_anomalies: false,
144            clustering_factor: 1.0,
145            random_state: Some(42),
146        };
147
148        let dataset = make_anomaly_dataset(2000, 15, config)?;
149
150        // Analyze the generated dataset
151        if let Some(target) = &dataset.target {
152            let anomaly_count = target.iter().filter(|&&x| x == 1.0).count();
153            let normal_count = target.len() - anomaly_count;
154
155            println!("  šŸ“Š Dataset composition:");
156            println!(
157                "    Normal samples: {} ({:.1}%)",
158                normal_count,
159                (normal_count as f64 / target.len() as f64) * 100.0
160            );
161            println!(
162                "    Anomalous samples: {} ({:.1}%)",
163                anomaly_count,
164                (anomaly_count as f64 / target.len() as f64) * 100.0
165            );
166
167            // Calculate separation metrics
168            let separation = calculate_anomaly_separation(&dataset);
169            println!("  šŸŽÆ Anomaly characteristics:");
170            println!(
171                "    Expected detection difficulty: {}",
172                if separation > 2.0 {
173                    "Easy"
174                } else if separation > 1.0 {
175                    "Medium"
176                } else {
177                    "Hard"
178                }
179            );
180            println!("    Separation score: {separation:.2}");
181            println!(
182                "    Recommended algorithms: {}",
183                get_recommended_anomaly_algorithms(&anomaly_type)
184            );
185        }
186    }
187
188    // Real-world scenario simulation
189    println!("\nReal-world anomaly detection scenario:");
190    let realistic_config = AnomalyConfig {
191        anomaly_fraction: 0.02, // 2% anomalies (realistic)
192        anomaly_type: AnomalyType::Mixed,
193        severity: 1.5, // Subtle anomalies
194        mixed_anomalies: true,
195        clustering_factor: 0.8,
196        random_state: Some(42),
197    };
198
199    let realisticdataset = make_anomaly_dataset(10000, 50, realistic_config)?;
200
201    if let Some(target) = &realisticdataset.target {
202        let anomaly_count = target.iter().filter(|&&x| x == 1.0).count();
203        println!(
204            "  šŸŒ Realistic scenario: {}/{} anomalies in {} samples",
205            anomaly_count,
206            realisticdataset.n_samples(),
207            realisticdataset.n_samples()
208        );
209        println!("  šŸ’” Challenge: Low anomaly rate mimics production environments");
210    }
211
212    println!();
213    Ok(())
214}
215
216#[allow(dead_code)]
217fn demonstrate_multitask_learning() -> Result<(), Box<dyn std::error::Error>> {
218    println!("šŸŽÆ MULTI-TASK LEARNING DATASETS");
219    println!("{}", "-".repeat(35));
220
221    // Basic multi-task scenario
222    println!("Multi-task scenario: Healthcare prediction");
223    let config = MultiTaskConfig {
224        n_tasks: 4,
225        task_types: vec![
226            TaskType::Classification(3), // Disease classification
227            TaskType::Regression,        // Risk score prediction
228            TaskType::Classification(2), // Treatment response
229            TaskType::Ordinal(5),        // Severity rating
230        ],
231        shared_features: 20,        // Common patient features
232        task_specific_features: 10, // Task-specific biomarkers
233        task_correlation: 0.7,      // High correlation between tasks
234        task_noise: vec![0.05, 0.1, 0.08, 0.12],
235        random_state: Some(42),
236    };
237
238    let multitaskdataset = make_multitask_dataset(1500, config)?;
239
240    println!("  šŸ“Š Multi-task dataset structure:");
241    println!("    Number of tasks: {}", multitaskdataset.tasks.len());
242    println!("    Shared features: {}", multitaskdataset.shared_features);
243    println!(
244        "    Task correlation: {:.1}",
245        multitaskdataset.task_correlation
246    );
247
248    for (i, task) in multitaskdataset.tasks.iter().enumerate() {
249        println!(
250            "    Task {}: {} samples, {} features ({})",
251            i + 1,
252            task.n_samples(),
253            task.n_features(),
254            task.metadata
255                .get("task_type")
256                .unwrap_or(&"unknown".to_string())
257        );
258
259        // Analyze task characteristics
260        if let Some(target) = &task.target {
261            match task
262                .metadata
263                .get("task_type")
264                .map(|s| s.as_str())
265                .unwrap_or("unknown")
266            {
267                "classification" => {
268                    let n_classes = analyze_classification_target(target);
269                    println!("      Classes: {n_classes}");
270                }
271                "regression" => {
272                    let (mean, std) = analyze_regression_target(target);
273                    println!("      Target range: {mean:.2} ± {std:.2}");
274                }
275                "ordinal_regression" => {
276                    let levels = analyze_ordinal_target(target);
277                    println!("      Ordinal levels: {levels}");
278                }
279                _ => {}
280            }
281        }
282    }
283
284    // Transfer learning scenario
285    println!("\nTransfer learning analysis:");
286    analyze_task_relationships(&multitaskdataset);
287
288    println!();
289    Ok(())
290}
291
292#[allow(dead_code)]
293fn demonstrate_domain_adaptation() -> Result<(), Box<dyn std::error::Error>> {
294    println!("🌐 DOMAIN ADAPTATION DATASETS");
295    println!("{}", "-".repeat(35));
296
297    println!("Domain adaptation scenario: Cross-domain sentiment analysis");
298
299    let config = DomainAdaptationConfig {
300        n_source_domains: 3,
301        domain_shifts: vec![], // Will use default shifts
302        label_shift: true,
303        feature_shift: true,
304        concept_drift: false,
305        random_state: Some(42),
306    };
307
308    let domaindataset = make_domain_adaptation_dataset(800, 25, 3, config)?;
309
310    println!("  šŸ“Š Domain adaptation structure:");
311    println!("    Total domains: {}", domaindataset.domains.len());
312    println!("    Source domains: {}", domaindataset.n_source_domains);
313
314    for (domainname, dataset) in &domaindataset.domains {
315        println!(
316            "    {}: {} samples, {} features",
317            domainname,
318            dataset.n_samples(),
319            dataset.n_features()
320        );
321
322        // Analyze domain characteristics
323        if let Some(target) = &dataset.target {
324            let class_distribution = analyze_class_distribution(target);
325            println!("      Class distribution: {class_distribution:?}");
326        }
327
328        // Calculate domain statistics
329        let feature_stats = calculate_domain_statistics(&dataset.data);
330        println!(
331            "      Feature mean: {:.3}, std: {:.3}",
332            feature_stats.0, feature_stats.1
333        );
334    }
335
336    // Domain shift analysis
337    println!("\n  šŸ”„ Domain shift analysis:");
338    analyze_domain_shifts(&domaindataset);
339
340    println!();
341    Ok(())
342}
343
344#[allow(dead_code)]
345fn demonstrate_few_shot_learning() -> Result<(), Box<dyn std::error::Error>> {
346    println!("šŸŽÆ FEW-SHOT LEARNING DATASETS");
347    println!("{}", "-".repeat(35));
348
349    let few_shot_scenarios = vec![
350        ("5-way 1-shot", 5, 1, 15),
351        ("5-way 5-shot", 5, 5, 10),
352        ("10-way 3-shot", 10, 3, 12),
353    ];
354
355    for (name, n_way, k_shot, n_query) in few_shot_scenarios {
356        println!("\nGenerating {name} dataset:");
357
358        let dataset = make_few_shot_dataset(n_way, k_shot, n_query, 5, 20)?;
359
360        println!("  šŸ“Š Few-shot configuration:");
361        println!("    Ways (classes): {}", dataset.n_way);
362        println!("    Shots per class: {}", dataset.k_shot);
363        println!("    Query samples per class: {}", dataset.n_query);
364        println!("    Episodes: {}", dataset.episodes.len());
365
366        // Analyze episode characteristics
367        for (i, episode) in dataset.episodes.iter().enumerate().take(2) {
368            println!("    Episode {}:", i + 1);
369            println!(
370                "      Support set: {} samples",
371                episode.support_set.n_samples()
372            );
373            println!("      Query set: {} samples", episode.query_set.n_samples());
374
375            // Calculate class balance in support set
376            if let Some(support_target) = &episode.support_set.target {
377                let balance = calculate_class_balance(support_target, n_way);
378                println!("      Support balance: {balance:.2}");
379            }
380        }
381
382        println!("  šŸ’” Use case: {}", get_few_shot_use_case(n_way, k_shot));
383    }
384
385    println!();
386    Ok(())
387}
388
389#[allow(dead_code)]
390fn demonstrate_continual_learning() -> Result<(), Box<dyn std::error::Error>> {
391    println!("šŸ“š CONTINUAL LEARNING DATASETS");
392    println!("{}", "-".repeat(35));
393
394    let drift_strengths = vec![
395        ("Mild drift", 0.2),
396        ("Moderate drift", 0.5),
397        ("Severe drift", 1.0),
398    ];
399
400    for (name, drift_strength) in drift_strengths {
401        println!("\nGenerating {name} scenario:");
402
403        let dataset = make_continual_learning_dataset(5, 500, 15, 4, drift_strength)?;
404
405        println!("  šŸ“Š Continual learning structure:");
406        println!("    Number of tasks: {}", dataset.tasks.len());
407        println!(
408            "    Concept drift strength: {:.1}",
409            dataset.concept_drift_strength
410        );
411
412        // Analyze concept drift between tasks
413        analyze_concept_drift(&dataset);
414
415        // Recommend continual learning strategies
416        println!(
417            "  šŸ’” Recommended strategies: {}",
418            get_continual_learning_strategies(drift_strength)
419        );
420    }
421
422    // Catastrophic forgetting simulation
423    println!("\nCatastrophic forgetting analysis:");
424    simulate_catastrophic_forgetting()?;
425
426    println!();
427    Ok(())
428}
429
430#[allow(dead_code)]
431fn demonstrate_advanced_applications() -> Result<(), Box<dyn std::error::Error>> {
432    println!("šŸš€ ADVANCED APPLICATIONS");
433    println!("{}", "-".repeat(25));
434
435    // Meta-learning scenario
436    println!("Meta-learning scenario:");
437    demonstrate_meta_learning_setup()?;
438
439    // Robust machine learning
440    println!("\nRobust ML scenario:");
441    demonstrate_robust_ml_setup()?;
442
443    // Federated learning simulation
444    println!("\nFederated learning scenario:");
445    demonstrate_federated_learning_setup()?;
446
447    Ok(())
448}
449
450// Helper functions for analysis
451
452#[allow(dead_code)]
453fn calculate_perturbation_norm(
454    original: &scirs2_datasets::Dataset,
455    adversarial: &scirs2_datasets::Dataset,
456) -> f64 {
457    let diff = &adversarial.data - &original.data;
458    let norm = diff.iter().map(|&x| x * x).sum::<f64>().sqrt();
459    norm / (original.n_samples() * original.n_features()) as f64
460}
461
462#[allow(dead_code)]
463fn calculate_anomaly_separation(dataset: &scirs2_datasets::Dataset) -> f64 {
464    // Simplified separation metric
465    if let Some(target) = &dataset.target {
466        let normal_indices: Vec<usize> = target
467            .iter()
468            .enumerate()
469            .filter_map(|(i, &label)| if label == 0.0 { Some(i) } else { None })
470            .collect();
471        let anomaly_indices: Vec<usize> = target
472            .iter()
473            .enumerate()
474            .filter_map(|(i, &label)| if label == 1.0 { Some(i) } else { None })
475            .collect();
476
477        if normal_indices.is_empty() || anomaly_indices.is_empty() {
478            return 0.0;
479        }
480
481        // Calculate average distances
482        let normal_center = calculate_centroid(&dataset.data, &normal_indices);
483        let anomaly_center = calculate_centroid(&dataset.data, &anomaly_indices);
484
485        let distance = (&normal_center - &anomaly_center)
486            .iter()
487            .map(|&x| x * x)
488            .sum::<f64>()
489            .sqrt();
490        distance / dataset.n_features() as f64
491    } else {
492        0.0
493    }
494}
495
496#[allow(dead_code)]
497fn calculate_centroid(data: &ndarray::Array2<f64>, indices: &[usize]) -> ndarray::Array1<f64> {
498    let mut centroid = ndarray::Array1::zeros(data.ncols());
499    for &idx in indices {
500        centroid = centroid + data.row(idx);
501    }
502    centroid / indices.len() as f64
503}
504
505#[allow(dead_code)]
506fn get_recommended_anomaly_algorithms(_anomalytype: &AnomalyType) -> &'static str {
507    match _anomalytype {
508        AnomalyType::Point => "Isolation Forest, Local Outlier Factor, One-Class SVM",
509        AnomalyType::Contextual => "LSTM Autoencoders, Hidden Markov Models",
510        AnomalyType::Collective => "Graph-based methods, Sequential pattern mining",
511        AnomalyType::Mixed => "Ensemble methods, Deep anomaly detection",
512        AnomalyType::Adversarial => "Robust statistical methods, Adversarial training",
513    }
514}
515
516#[allow(dead_code)]
517fn analyze_classification_target(target: &ndarray::Array1<f64>) -> usize {
518    let mut classes = std::collections::HashSet::new();
519    for &label in target.iter() {
520        classes.insert(label as i32);
521    }
522    classes.len()
523}
524
525#[allow(dead_code)]
526fn analyze_regression_target(target: &ndarray::Array1<f64>) -> (f64, f64) {
527    let mean = target.mean().unwrap_or(0.0);
528    let std = target.std(0.0);
529    (mean, std)
530}
531
532#[allow(dead_code)]
533fn analyze_ordinal_target(target: &ndarray::Array1<f64>) -> usize {
534    let max_level = target.iter().fold(0.0f64, |a, &b| a.max(b)) as usize;
535    max_level + 1
536}
537
538#[allow(dead_code)]
539fn analyze_task_relationships(multitaskdataset: &MultiTaskDataset) {
540    println!("  šŸ”— Task relationship analysis:");
541    println!(
542        "    Shared feature ratio: {:.1}%",
543        (multitaskdataset.shared_features as f64 / multitaskdataset.tasks[0].n_features() as f64)
544            * 100.0
545    );
546    println!(
547        "    Task correlation: {:.2}",
548        multitaskdataset.task_correlation
549    );
550
551    if multitaskdataset.task_correlation > 0.7 {
552        println!("    šŸ’” High correlation suggests strong transfer learning potential");
553    } else if multitaskdataset.task_correlation > 0.3 {
554        println!("    šŸ’” Moderate correlation indicates selective transfer benefits");
555    } else {
556        println!("    šŸ’” Low correlation requires careful negative transfer mitigation");
557    }
558}
559
560#[allow(dead_code)]
561fn analyze_class_distribution(target: &ndarray::Array1<f64>) -> HashMap<i32, usize> {
562    let mut distribution = HashMap::new();
563    for &label in target.iter() {
564        *distribution.entry(label as i32).or_insert(0) += 1;
565    }
566    distribution
567}
568
569#[allow(dead_code)]
570fn calculate_domain_statistics(data: &ndarray::Array2<f64>) -> (f64, f64) {
571    let mean = data.mean().unwrap_or(0.0);
572    let std = data.std(0.0);
573    (mean, std)
574}
575
576#[allow(dead_code)]
577fn analyze_domain_shifts(domaindataset: &DomainAdaptationDataset) {
578    if domaindataset.domains.len() >= 2 {
579        let source_stats = calculate_domain_statistics(&domaindataset.domains[0].1.data);
580        let target_stats =
581            calculate_domain_statistics(&domaindataset.domains.last().unwrap().1.data);
582
583        let mean_shift = (target_stats.0 - source_stats.0).abs();
584        let std_shift = (target_stats.1 - source_stats.1).abs();
585
586        println!("    Mean shift magnitude: {mean_shift:.3}");
587        println!("    Std shift magnitude: {std_shift:.3}");
588
589        if mean_shift > 0.5 || std_shift > 0.3 {
590            println!("    šŸ’” Significant domain shift detected - adaptation needed");
591        } else {
592            println!("    šŸ’” Mild domain shift - simple adaptation may suffice");
593        }
594    }
595}
596
597#[allow(dead_code)]
598fn calculate_class_balance(target: &ndarray::Array1<f64>, nclasses: usize) -> f64 {
599    let mut class_counts = vec![0; nclasses];
600    for &label in target.iter() {
601        let class_idx = label as usize;
602        if class_idx < nclasses {
603            class_counts[class_idx] += 1;
604        }
605    }
606
607    let total = target.len() as f64;
608    let expected_per_class = total / nclasses as f64;
609
610    let balance_score = class_counts
611        .iter()
612        .map(|&count| (count as f64 - expected_per_class).abs())
613        .sum::<f64>()
614        / (nclasses as f64 * expected_per_class);
615
616    1.0 - balance_score.min(1.0) // Higher score = better balance
617}
618
/// Describes a typical use case for an `n_way`-class, `k_shot`-example few-shot setup.
///
/// Arms are checked in order: the exact 5-way cases come first, then any 10-way
/// setup, then any one-shot setup, and finally a generic description.
#[allow(dead_code)]
fn get_few_shot_use_case(n_way: usize, k_shot: usize) -> &'static str {
    match (n_way, k_shot) {
        (5, 1) => "Image classification with minimal examples",
        (5, 5) => "Balanced few-shot learning benchmark",
        (10, _) => "Multi-class few-shot classification",
        (_, 1) => "One-shot learning scenario",
        _ => "General few-shot learning",
    }
}
629
630#[allow(dead_code)]
631fn analyze_concept_drift(dataset: &scirs2_datasets::ContinualLearningDataset) {
632    println!("    Task progression analysis:");
633
634    for i in 1..dataset.tasks.len() {
635        let prev_stats = calculate_domain_statistics(&dataset.tasks[i - 1].data);
636        let curr_stats = calculate_domain_statistics(&dataset.tasks[i].data);
637
638        let drift_magnitude =
639            ((curr_stats.0 - prev_stats.0).powi(2) + (curr_stats.1 - prev_stats.1).powi(2)).sqrt();
640
641        println!(
642            "      Task {} → {}: drift = {:.3}",
643            i,
644            i + 1,
645            drift_magnitude
646        );
647    }
648}
649
/// Suggests continual-learning strategies for a given concept-drift strength.
///
/// Strengths below `0.3` are treated as mild and those below `0.7` as moderate.
/// Anything else, including NaN, falls through to the severe-drift recommendation.
#[allow(dead_code)]
fn get_continual_learning_strategies(drift_strength: f64) -> &'static str {
    if drift_strength < 0.3 {
        "Fine-tuning, Elastic Weight Consolidation"
    } else if drift_strength < 0.7 {
        "Progressive Neural Networks, Learning without Forgetting"
    } else {
        "Memory replay, Meta-learning approaches, Dynamic architectures"
    }
}
660
661#[allow(dead_code)]
662fn simulate_catastrophic_forgetting() -> Result<(), Box<dyn std::error::Error>> {
663    let dataset = make_continual_learning_dataset(3, 200, 10, 3, 0.8)?;
664
665    println!("  Simulating catastrophic forgetting:");
666    println!("    šŸ“‰ Task 1 performance after Task 2: ~60% (typical drop)");
667    println!("    šŸ“‰ Task 1 performance after Task 3: ~40% (severe forgetting)");
668    println!("    šŸ’” Recommendation: Use rehearsal or regularization techniques");
669
670    Ok(())
671}
672
673#[allow(dead_code)]
674fn demonstrate_meta_learning_setup() -> Result<(), Box<dyn std::error::Error>> {
675    let few_shotdata = make_few_shot_dataset(5, 3, 10, 20, 15)?;
676
677    println!("  🧠 Meta-learning (MAML) setup:");
678    println!(
679        "    Meta-training episodes: {}",
680        few_shotdata.episodes.len()
681    );
682    println!(
683        "    Support/Query split per episode: {}/{} samples per class",
684        few_shotdata.k_shot, few_shotdata.n_query
685    );
686    println!("    šŸ’” Goal: Learn to learn quickly from few examples");
687
688    Ok(())
689}
690
691#[allow(dead_code)]
692fn demonstrate_robust_ml_setup() -> Result<(), Box<dyn std::error::Error>> {
693    let basedataset = make_classification(500, 15, 3, 2, 10, Some(42))?;
694
695    // Generate multiple adversarial versions
696    let attacks = vec![
697        ("FGSM", AttackMethod::FGSM, 0.1),
698        ("PGD", AttackMethod::PGD, 0.05),
699    ];
700
701    println!("  šŸ›”ļø Robust ML training setup:");
702    println!("    Clean samples: {}", basedataset.n_samples());
703
704    for (name, method, epsilon) in attacks {
705        let config = AdversarialConfig {
706            attack_method: method,
707            epsilon,
708            ..Default::default()
709        };
710
711        let advdataset = make_adversarial_examples(&basedataset, config)?;
712        println!(
713            "    {} adversarial samples: {}",
714            name,
715            advdataset.n_samples()
716        );
717    }
718
719    println!("    šŸ’” Goal: Train models robust to adversarial perturbations");
720
721    Ok(())
722}
723
724#[allow(dead_code)]
725fn demonstrate_federated_learning_setup() -> Result<(), Box<dyn std::error::Error>> {
726    let domaindata = make_domain_adaptation_dataset(
727        300,
728        20,
729        4,
730        DomainAdaptationConfig {
731            n_source_domains: 4, // 4 clients + 1 server
732            ..Default::default()
733        },
734    )?;
735
736    println!("  🌐 Federated learning simulation:");
737    println!("    Participating clients: {}", domaindata.n_source_domains);
738
739    for (i, (_domainname, dataset)) in domaindata.domains.iter().enumerate() {
740        if i < domaindata.n_source_domains {
741            println!(
742                "    Client {}: {} samples (private data)",
743                i + 1,
744                dataset.n_samples()
745            );
746        } else {
747            println!("    Global test set: {} samples", dataset.n_samples());
748        }
749    }
750
751    println!("    šŸ’” Goal: Collaborative learning without data sharing");
752
753    Ok(())
754}