scirs2_datasets/
advanced_generators.rs

1//! Advanced synthetic data generators
2//!
3//! This module provides sophisticated synthetic data generation capabilities
4//! for complex scenarios including adversarial examples, anomaly detection,
5//! multi-task learning, and domain adaptation.
6
7use crate::error::{DatasetsError, Result};
8use crate::utils::Dataset;
9use ndarray::{Array1, Array2, Axis};
10use rand::{rng, Rng};
11use rand_distr::Uniform;
12
/// Configuration for adversarial example generation
///
/// Consumed by `AdvancedGenerator::make_adversarial_examples`, which forwards
/// it to the perturbation helper selected by `attack_method`.
#[derive(Debug, Clone)]
pub struct AdversarialConfig {
    /// Perturbation strength (epsilon): the per-element magnitude for FGSM,
    /// the clamp bound for PGD, and the scale of the uniform noise for every
    /// other method.
    pub epsilon: f64,
    /// Attack method
    pub attack_method: AttackMethod,
    /// Target class for targeted attacks. When `Some` and the base dataset
    /// has labels, every output label is replaced by this class; when `None`
    /// the original labels are kept.
    pub target_class: Option<usize>,
    /// Number of attack iterations (only read by the PGD branch)
    pub iterations: usize,
    /// Step size for iterative attacks (only read by the PGD branch)
    pub step_size: f64,
    /// Random seed for reproducibility.
    ///
    /// NOTE(review): the perturbation code in this file draws from the
    /// thread-local RNG and does not appear to read this field — confirm.
    pub random_state: Option<u64>,
}
29
/// Adversarial attack methods
///
/// The perturbation helper in this file has dedicated branches for `FGSM`,
/// `PGD` and `RandomNoise`; `CW` and `DeepFool` currently fall back to uniform
/// random noise.
///
/// The enum is a plain tag, so it is `Copy` and can be compared, hashed and
/// used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AttackMethod {
    /// Fast Gradient Sign Method
    FGSM,
    /// Projected Gradient Descent
    PGD,
    /// Carlini & Wagner attack
    CW,
    /// DeepFool attack
    DeepFool,
    /// Random noise baseline
    RandomNoise,
}
44
45impl Default for AdversarialConfig {
46    fn default() -> Self {
47        Self {
48            epsilon: 0.1,
49            attack_method: AttackMethod::FGSM,
50            target_class: None,
51            iterations: 10,
52            step_size: 0.01,
53            random_state: None,
54        }
55    }
56}
57
/// Configuration for anomaly detection datasets
///
/// Consumed by `AdvancedGenerator::make_anomaly_dataset`.
#[derive(Debug, Clone)]
pub struct AnomalyConfig {
    /// Fraction of anomalous samples. The anomaly count is
    /// `n_samples * anomaly_fraction` truncated towards zero.
    pub anomaly_fraction: f64,
    /// Type of anomalies to generate
    pub anomaly_type: AnomalyType,
    /// Severity of anomalies: for point anomalies, the number of per-feature
    /// standard deviations by which a value is displaced from the mean of the
    /// normal data.
    pub severity: f64,
    /// Whether to include multiple anomaly types.
    ///
    /// NOTE(review): not read by the generator code visible in this file —
    /// confirm whether it is meant to take effect.
    pub mixed_anomalies: bool,
    /// Clustering factor for normal data; forwarded to `make_blobs` when the
    /// normal samples are generated.
    pub clustering_factor: f64,
    /// Random seed.
    ///
    /// NOTE(review): the normal data is seeded from the generator's own
    /// `random_state`, not from this field — confirm the intended source.
    pub random_state: Option<u64>,
}
74
/// Types of anomalies
///
/// The generator in this file has dedicated branches for `Point` and
/// `Contextual`; the remaining variants currently fall back to the point
/// anomaly implementation.
///
/// The enum is a plain tag, so it is `Copy` and can be compared, hashed and
/// used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnomalyType {
    /// Point anomalies (outliers)
    Point,
    /// Contextual anomalies
    Contextual,
    /// Collective anomalies
    Collective,
    /// Adversarial anomalies
    Adversarial,
    /// Mixed anomaly types
    Mixed,
}
89
90impl Default for AnomalyConfig {
91    fn default() -> Self {
92        Self {
93            anomaly_fraction: 0.1,
94            anomaly_type: AnomalyType::Point,
95            severity: 2.0,
96            mixed_anomalies: false,
97            clustering_factor: 1.0,
98            random_state: None,
99        }
100    }
101}
102
/// Configuration for multi-task learning datasets
///
/// Consumed by `AdvancedGenerator::make_multitask_dataset`.
#[derive(Debug, Clone)]
pub struct MultiTaskConfig {
    /// Number of tasks. The generator uses this for its log line only; one
    /// dataset is produced per entry of `task_types`.
    pub n_tasks: usize,
    /// Task types (classification or regression); one task dataset is
    /// generated per entry, in order.
    pub task_types: Vec<TaskType>,
    /// Shared feature dimensions: the leading columns of every task dataset
    pub shared_features: usize,
    /// Task-specific feature dimensions: the trailing columns of every task
    /// dataset
    pub task_specific_features: usize,
    /// Correlation between tasks; used as a scaling factor on the feature
    /// sums when targets are derived.
    pub task_correlation: f64,
    /// Noise level for each task, indexed by task position. Tasks without an
    /// entry use `0.1`; only regression targets add noise.
    pub task_noise: Vec<f64>,
    /// Random seed.
    ///
    /// NOTE(review): the multi-task helpers in this file draw from the
    /// thread-local RNG and do not appear to read this field — confirm.
    pub random_state: Option<u64>,
}
121
/// Task types for multi-task learning
///
/// The payload of `Classification` and `Ordinal` is the number of distinct
/// label values generated for the task.
///
/// The enum only carries small integers, so it is `Copy` and can be compared,
/// hashed and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskType {
    /// Classification task with specified number of classes
    Classification(usize),
    /// Regression task
    Regression,
    /// Ordinal regression with the specified number of levels
    Ordinal(usize),
}
132
133impl Default for MultiTaskConfig {
134    fn default() -> Self {
135        Self {
136            n_tasks: 3,
137            task_types: vec![
138                TaskType::Classification(3),
139                TaskType::Regression,
140                TaskType::Classification(5),
141            ],
142            shared_features: 10,
143            task_specific_features: 5,
144            task_correlation: 0.5,
145            task_noise: vec![0.1, 0.1, 0.1],
146            random_state: None,
147        }
148    }
149}
150
/// Configuration for domain adaptation datasets
///
/// Consumed by `AdvancedGenerator::make_domain_adaptation_dataset`.
#[derive(Debug, Clone)]
pub struct DomainAdaptationConfig {
    /// Number of source domains, counting the unshifted reference domain
    /// named `"source"`; one target domain is always generated in addition.
    pub n_source_domains: usize,
    /// Domain shift parameters. Entry `i` is applied to the additional source
    /// domain `source_{i + 1}`; domains without an entry use a default shift
    /// of `0.5` on every feature mean.
    pub domain_shifts: Vec<DomainShift>,
    /// Label shift (different class distributions).
    ///
    /// NOTE(review): not read by the generator code visible in this file.
    pub label_shift: bool,
    /// Feature shift (different feature distributions).
    ///
    /// NOTE(review): not read by the generator code visible in this file.
    pub feature_shift: bool,
    /// Concept drift over time.
    ///
    /// NOTE(review): not read by the generator code visible in this file.
    pub concept_drift: bool,
    /// Random seed.
    ///
    /// NOTE(review): the base domain is seeded from the generator's own
    /// `random_state`, not from this field — confirm the intended source.
    pub random_state: Option<u64>,
}
167
/// Domain shift types
///
/// Describes how a derived domain differs from the reference source domain.
#[derive(Debug, Clone)]
pub struct DomainShift {
    /// Shift in feature means; added to every row of the source data when the
    /// shift is applied.
    pub mean_shift: Array1<f64>,
    /// Shift in feature covariances.
    ///
    /// NOTE(review): not read when the shift is applied in this file —
    /// confirm whether it is meant to take effect.
    pub covariance_shift: Option<Array2<f64>>,
    /// Shift strength.
    ///
    /// NOTE(review): not read when the shift is applied in this file —
    /// confirm whether `mean_shift` is meant to be scaled by it.
    pub shift_strength: f64,
}
178
179impl Default for DomainAdaptationConfig {
180    fn default() -> Self {
181        Self {
182            n_source_domains: 2,
183            domain_shifts: vec![],
184            label_shift: true,
185            feature_shift: true,
186            concept_drift: false,
187            random_state: None,
188        }
189    }
190}
191
/// Advanced data generator
///
/// Entry point for the synthetic dataset builders in this module
/// (adversarial, anomaly, multi-task, domain adaptation, few-shot and
/// continual learning datasets).
#[derive(Debug, Clone)]
pub struct AdvancedGenerator {
    /// Optional seed forwarded to `make_blobs` / `make_classification` by the
    /// helpers that build normal data and base domain datasets.
    random_state: Option<u64>,
}
196
197impl AdvancedGenerator {
198    /// Create a new advanced generator
199    pub fn new(_random_state: Option<u64>) -> Self {
200        Self {
201            random_state: _random_state,
202        }
203    }
204
205    /// Generate adversarial examples
206    pub fn make_adversarial_examples(
207        &self,
208        base_dataset: &Dataset,
209        config: AdversarialConfig,
210    ) -> Result<Dataset> {
211        let n_samples = base_dataset.n_samples();
212        let _n_features = base_dataset.n_features();
213
214        println!(
215            "Generating adversarial examples using {:?}",
216            config.attack_method
217        );
218
219        // Create adversarial perturbations
220        let perturbations = self.generate_perturbations(&base_dataset.data, &config)?;
221
222        // Apply perturbations
223        let adversarial_data = &base_dataset.data + &perturbations;
224
225        // Clip to valid range if needed
226        let clipped_data = adversarial_data.mapv(|x| x.clamp(-5.0, 5.0));
227
228        // Create adversarial labels
229        let adversarial_target = if let Some(target) = &base_dataset.target {
230            match config.target_class {
231                Some(target_class) => {
232                    // Targeted attack - change labels to target class
233                    Some(Array1::from_elem(n_samples, target_class as f64))
234                }
235                None => {
236                    // Untargeted attack - keep original labels but mark as adversarial
237                    Some(target.clone())
238                }
239            }
240        } else {
241            None
242        };
243
244        let mut metadata = base_dataset.metadata.clone();
245        let _old_description = metadata.get("description").cloned().unwrap_or_default();
246        let oldname = metadata.get("name").cloned().unwrap_or_default();
247
248        metadata.insert(
249            "description".to_string(),
250            format!(
251                "Adversarial examples generated using {:?}",
252                config.attack_method
253            ),
254        );
255        metadata.insert("name".to_string(), format!("{oldname} (Adversarial)"));
256
257        Ok(Dataset {
258            data: clipped_data,
259            target: adversarial_target,
260            targetnames: base_dataset.targetnames.clone(),
261            featurenames: base_dataset.featurenames.clone(),
262            feature_descriptions: base_dataset.feature_descriptions.clone(),
263            description: base_dataset.description.clone(),
264            metadata,
265        })
266    }
267
268    /// Generate anomaly detection dataset
269    pub fn make_anomaly_dataset(
270        &self,
271        n_samples: usize,
272        n_features: usize,
273        config: AnomalyConfig,
274    ) -> Result<Dataset> {
275        let n_anomalies = (n_samples as f64 * config.anomaly_fraction) as usize;
276        let n_normal = n_samples - n_anomalies;
277
278        println!("Generating anomaly dataset: {n_normal} normal, {n_anomalies} anomalous");
279
280        // Generate normal data
281        let normal_data =
282            self.generate_normal_data(n_normal, n_features, config.clustering_factor)?;
283
284        // Generate anomalous data
285        let anomalous_data =
286            self.generate_anomalous_data(n_anomalies, n_features, &normal_data, &config)?;
287
288        // Combine data
289        let mut combined_data = Array2::zeros((n_samples, n_features));
290        combined_data
291            .slice_mut(ndarray::s![..n_normal, ..])
292            .assign(&normal_data);
293        combined_data
294            .slice_mut(ndarray::s![n_normal.., ..])
295            .assign(&anomalous_data);
296
297        // Create labels (0 = normal, 1 = anomaly)
298        let mut target = Array1::zeros(n_samples);
299        target.slice_mut(ndarray::s![n_normal..]).fill(1.0);
300
301        // Shuffle the data
302        let shuffled_indices = self.generate_shuffle_indices(n_samples)?;
303        let shuffled_data = self.shuffle_by_indices(&combined_data, &shuffled_indices);
304        let shuffled_target = self.shuffle_array_by_indices(&target, &shuffled_indices);
305
306        let metadata = crate::registry::DatasetMetadata {
307            name: "Anomaly Detection Dataset".to_string(),
308            description: format!(
309                "Synthetic anomaly detection dataset with {:.1}% anomalies",
310                config.anomaly_fraction * 100.0
311            ),
312            n_samples,
313            n_features,
314            task_type: "anomaly_detection".to_string(),
315            targetnames: Some(vec!["normal".to_string(), "anomaly".to_string()]),
316            ..Default::default()
317        };
318
319        Ok(Dataset::from_metadata(
320            shuffled_data,
321            Some(shuffled_target),
322            metadata,
323        ))
324    }
325
326    /// Generate multi-task learning dataset
327    pub fn make_multitask_dataset(
328        &self,
329        n_samples: usize,
330        config: MultiTaskConfig,
331    ) -> Result<MultiTaskDataset> {
332        let total_features =
333            config.shared_features + config.task_specific_features * config.n_tasks;
334
335        println!(
336            "Generating multi-task dataset: {} tasks, {} samples, {} features",
337            config.n_tasks, n_samples, total_features
338        );
339
340        // Generate shared features
341        let shared_data = self.generate_shared_features(n_samples, config.shared_features)?;
342
343        // Generate task-specific features and targets
344        let mut task_datasets = Vec::new();
345
346        for (task_id, task_type) in config.task_types.iter().enumerate() {
347            let task_specific_data = self.generate_task_specific_features(
348                n_samples,
349                config.task_specific_features,
350                task_id,
351            )?;
352
353            // Combine shared and task-specific features
354            let task_data = self.combine_features(&shared_data, &task_specific_data);
355
356            // Generate task target based on task type
357            let task_target = self.generate_task_target(
358                &task_data,
359                task_type,
360                config.task_correlation,
361                config.task_noise.get(task_id).unwrap_or(&0.1),
362            )?;
363
364            let task_metadata = crate::registry::DatasetMetadata {
365                name: format!("Task {task_id}"),
366                description: format!("Multi-task learning task {task_id} ({task_type:?})"),
367                n_samples,
368                n_features: task_data.ncols(),
369                task_type: match task_type {
370                    TaskType::Classification(_) => "classification".to_string(),
371                    TaskType::Regression => "regression".to_string(),
372                    TaskType::Ordinal(_) => "ordinal_regression".to_string(),
373                },
374                ..Default::default()
375            };
376
377            task_datasets.push(Dataset::from_metadata(
378                task_data,
379                Some(task_target),
380                task_metadata,
381            ));
382        }
383
384        Ok(MultiTaskDataset {
385            tasks: task_datasets,
386            shared_features: config.shared_features,
387            task_correlation: config.task_correlation,
388        })
389    }
390
391    /// Generate domain adaptation dataset
392    pub fn make_domain_adaptation_dataset(
393        &self,
394        n_samples_per_domain: usize,
395        n_features: usize,
396        n_classes: usize,
397        config: DomainAdaptationConfig,
398    ) -> Result<DomainAdaptationDataset> {
399        let total_domains = config.n_source_domains + 1; // +1 for target _domain
400
401        println!(
402            "Generating _domain adaptation dataset: {total_domains} domains, {n_samples_per_domain} samples each"
403        );
404
405        let mut domain_datasets = Vec::new();
406
407        // Generate source _domain (reference)
408        let source_dataset =
409            self.generate_base_domain_dataset(n_samples_per_domain, n_features, n_classes)?;
410
411        domain_datasets.push(("source".to_string(), source_dataset.clone()));
412
413        // Generate additional source domains with shifts
414        for domain_id in 1..config.n_source_domains {
415            let shift = if domain_id - 1 < config.domain_shifts.len() {
416                &config.domain_shifts[domain_id - 1]
417            } else {
418                // Generate default shift
419                &DomainShift {
420                    mean_shift: Array1::from_elem(n_features, 0.5),
421                    covariance_shift: None,
422                    shift_strength: 1.0,
423                }
424            };
425
426            let shifted_dataset = self.apply_domain_shift(&source_dataset, shift)?;
427            domain_datasets.push((format!("source_{domain_id}"), shifted_dataset));
428        }
429
430        // Generate target _domain with different shift
431        let target_shift = DomainShift {
432            mean_shift: Array1::from_elem(n_features, 1.0),
433            covariance_shift: None,
434            shift_strength: 1.5,
435        };
436
437        let target_dataset = self.apply_domain_shift(&source_dataset, &target_shift)?;
438        domain_datasets.push(("target".to_string(), target_dataset));
439
440        Ok(DomainAdaptationDataset {
441            domains: domain_datasets,
442            n_source_domains: config.n_source_domains,
443        })
444    }
445
446    /// Generate few-shot learning dataset
447    pub fn make_few_shot_dataset(
448        &self,
449        n_way: usize,
450        k_shot: usize,
451        n_query: usize,
452        n_episodes: usize,
453        n_features: usize,
454    ) -> Result<FewShotDataset> {
455        println!(
456            "Generating few-_shot dataset: {n_way}-_way {k_shot}-_shot, {n_episodes} _episodes"
457        );
458
459        let mut episodes = Vec::new();
460
461        for episode_id in 0..n_episodes {
462            let support_set = self.generate_support_set(n_way, k_shot, n_features, episode_id)?;
463            let query_set =
464                self.generate_query_set(n_way, n_query, n_features, &support_set, episode_id)?;
465
466            episodes.push(FewShotEpisode {
467                support_set,
468                query_set,
469                n_way,
470                k_shot,
471            });
472        }
473
474        Ok(FewShotDataset {
475            episodes,
476            n_way,
477            k_shot,
478            n_query,
479        })
480    }
481
482    /// Generate continual learning dataset with concept drift
483    pub fn make_continual_learning_dataset(
484        &self,
485        n_tasks: usize,
486        n_samples_per_task: usize,
487        n_features: usize,
488        n_classes: usize,
489        concept_drift_strength: f64,
490    ) -> Result<ContinualLearningDataset> {
491        println!("Generating continual learning dataset: {n_tasks} _tasks with concept drift");
492
493        let mut task_datasets = Vec::new();
494        let mut base_centers = self.generate_class_centers(n_classes, n_features)?;
495
496        for task_id in 0..n_tasks {
497            // Apply concept drift
498            if task_id > 0 {
499                let drift = Array2::from_shape_fn((n_classes, n_features), |_| {
500                    rng().random::<f64>() * concept_drift_strength
501                });
502                base_centers = base_centers + drift;
503            }
504
505            let task_dataset = self.generate_classification_from_centers(
506                n_samples_per_task,
507                &base_centers,
508                1.0, // cluster_std
509                task_id as u64,
510            )?;
511
512            let mut metadata = task_dataset.metadata.clone();
513            metadata.insert(
514                "name".to_string(),
515                format!("Continual Learning Task {task_id}"),
516            );
517            metadata.insert(
518                "description".to_string(),
519                format!("Task {task_id} with concept drift _strength {concept_drift_strength:.2}"),
520            );
521
522            task_datasets.push(Dataset {
523                data: task_dataset.data,
524                target: task_dataset.target,
525                targetnames: task_dataset.targetnames,
526                featurenames: task_dataset.featurenames,
527                feature_descriptions: task_dataset.feature_descriptions,
528                description: task_dataset.description,
529                metadata,
530            });
531        }
532
533        Ok(ContinualLearningDataset {
534            tasks: task_datasets,
535            concept_drift_strength,
536        })
537    }
538
539    // Private helper methods
540
541    fn generate_perturbations(
542        &self,
543        data: &Array2<f64>,
544        config: &AdversarialConfig,
545    ) -> Result<Array2<f64>> {
546        let (n_samples, n_features) = data.dim();
547
548        match config.attack_method {
549            AttackMethod::FGSM => {
550                // Fast Gradient Sign Method
551                let mut perturbations = Array2::zeros((n_samples, n_features));
552                for i in 0..n_samples {
553                    for j in 0..n_features {
554                        let sign = if rng().random::<f64>() > 0.5 {
555                            1.0
556                        } else {
557                            -1.0
558                        };
559                        perturbations[[i, j]] = config.epsilon * sign;
560                    }
561                }
562                Ok(perturbations)
563            }
564            AttackMethod::PGD => {
565                // Projected Gradient Descent (simplified)
566                let mut perturbations: Array2<f64> = Array2::zeros((n_samples, n_features));
567                for _iter in 0..config.iterations {
568                    for i in 0..n_samples {
569                        for j in 0..n_features {
570                            let gradient = rng().random::<f64>() * 2.0 - 1.0; // Simulated gradient
571                            perturbations[[i, j]] += config.step_size * gradient.signum();
572                            // Clip to epsilon ball
573                            perturbations[[i, j]] =
574                                perturbations[[i, j]].clamp(-config.epsilon, config.epsilon);
575                        }
576                    }
577                }
578                Ok(perturbations)
579            }
580            AttackMethod::RandomNoise => {
581                // Random noise baseline
582                let perturbations = Array2::from_shape_fn((n_samples, n_features), |_| {
583                    (rng().random::<f64>() * 2.0 - 1.0) * config.epsilon
584                });
585                Ok(perturbations)
586            }
587            _ => {
588                // For other methods, use random noise directly
589                let mut perturbations = Array2::zeros(data.dim());
590                for i in 0..data.nrows() {
591                    for j in 0..data.ncols() {
592                        let noise = rng().random::<f64>() * 2.0 - 1.0;
593                        perturbations[[i, j]] = config.epsilon * noise;
594                    }
595                }
596                Ok(perturbations)
597            }
598        }
599    }
600
601    fn generate_normal_data(
602        &self,
603        n_samples: usize,
604        n_features: usize,
605        clustering_factor: f64,
606    ) -> Result<Array2<f64>> {
607        // Generate clustered normal data
608        use crate::generators::make_blobs;
609        let n_clusters = ((n_features as f64).sqrt() as usize).max(2);
610        let dataset = make_blobs(
611            n_samples,
612            n_features,
613            n_clusters,
614            clustering_factor,
615            self.random_state,
616        )?;
617        Ok(dataset.data)
618    }
619
620    fn generate_anomalous_data(
621        &self,
622        n_anomalies: usize,
623        n_features: usize,
624        normal_data: &Array2<f64>,
625        config: &AnomalyConfig,
626    ) -> Result<Array2<f64>> {
627        use rand::Rng;
628        let mut rng = rng();
629
630        match config.anomaly_type {
631            AnomalyType::Point => {
632                // Point _anomalies - outliers far from normal distribution
633                let normal_mean = normal_data.mean_axis(Axis(0)).ok_or_else(|| {
634                    DatasetsError::ComputationError(
635                        "Failed to compute mean for normal data".to_string(),
636                    )
637                })?;
638                let normal_std = normal_data.std_axis(Axis(0), 0.0);
639
640                let mut anomalies = Array2::zeros((n_anomalies, n_features));
641                for i in 0..n_anomalies {
642                    for j in 0..n_features {
643                        let direction = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
644                        anomalies[[i, j]] =
645                            normal_mean[j] + direction * config.severity * normal_std[j];
646                    }
647                }
648                Ok(anomalies)
649            }
650            AnomalyType::Contextual => {
651                // Contextual _anomalies - normal values but in wrong context
652                let mut anomalies: Array2<f64> = Array2::zeros((n_anomalies, n_features));
653                for i in 0..n_anomalies {
654                    // Pick a random normal sample and permute some _features
655                    let base_idx = rng.sample(Uniform::new(0, normal_data.nrows()).unwrap());
656                    let mut anomaly = normal_data.row(base_idx).to_owned();
657
658                    // Permute random _features
659                    let n_permute = (n_features as f64 * 0.3) as usize;
660                    for _ in 0..n_permute {
661                        let j = rng.sample(Uniform::new(0, n_features).unwrap());
662                        let k = rng.sample(Uniform::new(0, n_features).unwrap());
663                        let temp = anomaly[j];
664                        anomaly[j] = anomaly[k];
665                        anomaly[k] = temp;
666                    }
667
668                    anomalies.row_mut(i).assign(&anomaly);
669                }
670                Ok(anomalies)
671            }
672            _ => {
673                // Default to point _anomalies implementation
674                let normal_mean = normal_data.mean_axis(Axis(0)).ok_or_else(|| {
675                    DatasetsError::ComputationError(
676                        "Failed to compute mean for normal data".to_string(),
677                    )
678                })?;
679                let normal_std = normal_data.std_axis(Axis(0), 0.0);
680
681                let mut anomalies = Array2::zeros((n_anomalies, n_features));
682                for i in 0..n_anomalies {
683                    for j in 0..n_features {
684                        let direction = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
685                        anomalies[[i, j]] =
686                            normal_mean[j] + direction * config.severity * normal_std[j];
687                    }
688                }
689                Ok(anomalies)
690            }
691        }
692    }
693
694    fn generate_shuffle_indices(&self, n_samples: usize) -> Result<Vec<usize>> {
695        use rand::Rng;
696        let mut rng = rng();
697        let mut indices: Vec<usize> = (0..n_samples).collect();
698
699        // Simple shuffle using Fisher-Yates
700        for i in (1..n_samples).rev() {
701            let j = rng.sample(Uniform::new(0, i).unwrap());
702            indices.swap(i, j);
703        }
704
705        Ok(indices)
706    }
707
708    fn shuffle_by_indices(&self, data: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
709        let mut shuffled = Array2::zeros(data.dim());
710        for (new_idx, &old_idx) in indices.iter().enumerate() {
711            shuffled.row_mut(new_idx).assign(&data.row(old_idx));
712        }
713        shuffled
714    }
715
716    fn shuffle_array_by_indices(&self, array: &Array1<f64>, indices: &[usize]) -> Array1<f64> {
717        let mut shuffled = Array1::zeros(array.len());
718        for (new_idx, &old_idx) in indices.iter().enumerate() {
719            shuffled[new_idx] = array[old_idx];
720        }
721        shuffled
722    }
723
724    fn generate_shared_features(&self, n_samples: usize, n_features: usize) -> Result<Array2<f64>> {
725        // Generate shared _features using multivariate normal distribution
726        let data = Array2::from_shape_fn((n_samples, n_features), |_| {
727            rng().random::<f64>() * 2.0 - 1.0 // Standard normal approximation
728        });
729        Ok(data)
730    }
731
732    fn generate_task_specific_features(
733        &self,
734        n_samples: usize,
735        n_features: usize,
736        task_id: usize,
737    ) -> Result<Array2<f64>> {
738        // Generate task-specific _features with slight bias per task
739        let task_bias = task_id as f64 * 0.1;
740        let data = Array2::from_shape_fn((n_samples, n_features), |_| {
741            rng().random::<f64>() * 2.0 - 1.0 + task_bias
742        });
743        Ok(data)
744    }
745
746    fn combine_features(&self, shared: &Array2<f64>, task_specific: &Array2<f64>) -> Array2<f64> {
747        let n_samples = shared.nrows();
748        let total_features = shared.ncols() + task_specific.ncols();
749        let mut combined = Array2::zeros((n_samples, total_features));
750
751        combined
752            .slice_mut(ndarray::s![.., ..shared.ncols()])
753            .assign(shared);
754        combined
755            .slice_mut(ndarray::s![.., shared.ncols()..])
756            .assign(task_specific);
757
758        combined
759    }
760
761    fn generate_task_target(
762        &self,
763        data: &Array2<f64>,
764        task_type: &TaskType,
765        correlation: f64,
766        noise: &f64,
767    ) -> Result<Array1<f64>> {
768        let n_samples = data.nrows();
769
770        match task_type {
771            TaskType::Classification(n_classes) => {
772                // Generate classification target based on data
773                let target = Array1::from_shape_fn(n_samples, |i| {
774                    let feature_sum = data.row(i).sum();
775                    let class = ((feature_sum * correlation).abs() as usize) % n_classes;
776                    class as f64
777                });
778                Ok(target)
779            }
780            TaskType::Regression => {
781                // Generate regression target as linear combination of features
782                let target = Array1::from_shape_fn(n_samples, |i| {
783                    let weighted_sum = data
784                        .row(i)
785                        .iter()
786                        .enumerate()
787                        .map(|(j, &x)| x * (j as f64 + 1.0) * correlation)
788                        .sum::<f64>();
789                    weighted_sum + rng().random::<f64>() * noise
790                });
791                Ok(target)
792            }
793            TaskType::Ordinal(n_levels) => {
794                // Generate ordinal target
795                let target = Array1::from_shape_fn(n_samples, |i| {
796                    let feature_sum = data.row(i).sum();
797                    let level = ((feature_sum * correlation).abs() as usize) % n_levels;
798                    level as f64
799                });
800                Ok(target)
801            }
802        }
803    }
804
805    fn generate_base_domain_dataset(
806        &self,
807        n_samples: usize,
808        n_features: usize,
809        n_classes: usize,
810    ) -> Result<Dataset> {
811        use crate::generators::make_classification;
812        make_classification(
813            n_samples,
814            n_features,
815            n_classes,
816            2,
817            n_features / 2,
818            self.random_state,
819        )
820    }
821
822    fn apply_domain_shift(&self, base_dataset: &Dataset, shift: &DomainShift) -> Result<Dataset> {
823        let shifted_data = &base_dataset.data + &shift.mean_shift;
824
825        let mut metadata = base_dataset.metadata.clone();
826        let old_description = metadata.get("description").cloned().unwrap_or_default();
827        metadata.insert(
828            "description".to_string(),
829            format!("{old_description} (Domain Shifted)"),
830        );
831
832        Ok(Dataset {
833            data: shifted_data,
834            target: base_dataset.target.clone(),
835            targetnames: base_dataset.targetnames.clone(),
836            featurenames: base_dataset.featurenames.clone(),
837            feature_descriptions: base_dataset.feature_descriptions.clone(),
838            description: base_dataset.description.clone(),
839            metadata,
840        })
841    }
842
843    fn generate_support_set(
844        &self,
845        n_way: usize,
846        k_shot: usize,
847        n_features: usize,
848        episode_id: usize,
849    ) -> Result<Dataset> {
850        let n_samples = n_way * k_shot;
851        use crate::generators::make_classification;
852        make_classification(
853            n_samples,
854            n_features,
855            n_way,
856            1,
857            n_features / 2,
858            Some(episode_id as u64),
859        )
860    }
861
862    fn generate_query_set(
863        &self,
864        n_way: usize,
865        n_query: usize,
866        n_features: usize,
867        _set: &Dataset,
868        episode_id: usize,
869    ) -> Result<Dataset> {
870        let n_samples = n_way * n_query;
871        use crate::generators::make_classification;
872        make_classification(
873            n_samples,
874            n_features,
875            n_way,
876            1,
877            n_features / 2,
878            Some(episode_id as u64 + 1000),
879        )
880    }
881
882    fn generate_class_centers(&self, n_classes: usize, n_features: usize) -> Result<Array2<f64>> {
883        let centers = Array2::from_shape_fn((n_classes, n_features), |_| {
884            rng().random::<f64>() * 4.0 - 2.0
885        });
886        Ok(centers)
887    }
888
889    fn generate_classification_from_centers(
890        &self,
891        n_samples: usize,
892        centers: &Array2<f64>,
893        cluster_std: f64,
894        seed: u64,
895    ) -> Result<Dataset> {
896        use crate::generators::make_blobs;
897        make_blobs(
898            n_samples,
899            centers.ncols(),
900            centers.nrows(),
901            cluster_std,
902            Some(seed),
903        )
904    }
905}
906
907/// Multi-task learning dataset container
908#[derive(Debug)]
909pub struct MultiTaskDataset {
910    /// Individual task datasets
911    pub tasks: Vec<Dataset>,
912    /// Number of shared features
913    pub shared_features: usize,
914    /// Correlation between tasks
915    pub task_correlation: f64,
916}
917
918/// Domain adaptation dataset container
919#[derive(Debug)]
920pub struct DomainAdaptationDataset {
921    /// Datasets for each domain (name, dataset)
922    pub domains: Vec<(String, Dataset)>,
923    /// Number of source domains
924    pub n_source_domains: usize,
925}
926
927/// Few-shot learning episode
928#[derive(Debug)]
929pub struct FewShotEpisode {
930    /// Support set for learning
931    pub support_set: Dataset,
932    /// Query set for evaluation
933    pub query_set: Dataset,
934    /// Number of classes (ways)
935    pub n_way: usize,
936    /// Number of examples per class (shots)
937    pub k_shot: usize,
938}
939
940/// Few-shot learning dataset
941#[derive(Debug)]
942pub struct FewShotDataset {
943    /// Training/evaluation episodes
944    pub episodes: Vec<FewShotEpisode>,
945    /// Number of classes per episode
946    pub n_way: usize,
947    /// Number of shots per class
948    pub k_shot: usize,
949    /// Number of query samples per class
950    pub n_query: usize,
951}
952
953/// Continual learning dataset
954#[derive(Debug)]
955pub struct ContinualLearningDataset {
956    /// Sequential tasks
957    pub tasks: Vec<Dataset>,
958    /// Strength of concept drift between tasks
959    pub concept_drift_strength: f64,
960}
961
962/// Convenience functions for advanced data generation
963///
964/// Generate adversarial examples from a base dataset
965#[allow(dead_code)]
966pub fn make_adversarial_examples(
967    base_dataset: &Dataset,
968    config: AdversarialConfig,
969) -> Result<Dataset> {
970    let generator = AdvancedGenerator::new(config.random_state);
971    generator.make_adversarial_examples(base_dataset, config)
972}
973
974/// Generate anomaly detection dataset
975#[allow(dead_code)]
976pub fn make_anomaly_dataset(
977    n_samples: usize,
978    n_features: usize,
979    config: AnomalyConfig,
980) -> Result<Dataset> {
981    let generator = AdvancedGenerator::new(config.random_state);
982    generator.make_anomaly_dataset(n_samples, n_features, config)
983}
984
985/// Generate multi-task learning dataset
986#[allow(dead_code)]
987pub fn make_multitask_dataset(
988    n_samples: usize,
989    config: MultiTaskConfig,
990) -> Result<MultiTaskDataset> {
991    let generator = AdvancedGenerator::new(config.random_state);
992    generator.make_multitask_dataset(n_samples, config)
993}
994
995/// Generate domain adaptation dataset
996#[allow(dead_code)]
997pub fn make_domain_adaptation_dataset(
998    n_samples_per_domain: usize,
999    n_features: usize,
1000    n_classes: usize,
1001    config: DomainAdaptationConfig,
1002) -> Result<DomainAdaptationDataset> {
1003    let generator = AdvancedGenerator::new(config.random_state);
1004    generator.make_domain_adaptation_dataset(n_samples_per_domain, n_features, n_classes, config)
1005}
1006
1007/// Generate few-shot learning dataset
1008#[allow(dead_code)]
1009pub fn make_few_shot_dataset(
1010    n_way: usize,
1011    k_shot: usize,
1012    n_query: usize,
1013    n_episodes: usize,
1014    n_features: usize,
1015) -> Result<FewShotDataset> {
1016    let generator = AdvancedGenerator::new(Some(42));
1017    generator.make_few_shot_dataset(n_way, k_shot, n_query, n_episodes, n_features)
1018}
1019
1020/// Generate continual learning dataset
1021#[allow(dead_code)]
1022pub fn make_continual_learning_dataset(
1023    n_tasks: usize,
1024    n_samples_per_task: usize,
1025    n_features: usize,
1026    n_classes: usize,
1027    concept_drift_strength: f64,
1028) -> Result<ContinualLearningDataset> {
1029    let generator = AdvancedGenerator::new(Some(42));
1030    generator.make_continual_learning_dataset(
1031        n_tasks,
1032        n_samples_per_task,
1033        n_features,
1034        n_classes,
1035        concept_drift_strength,
1036    )
1037}
1038
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    /// The default adversarial configuration uses FGSM with epsilon 0.1 and 10 iterations.
    #[test]
    fn test_adversarial_config() {
        let config = AdversarialConfig::default();
        assert_eq!(config.epsilon, 0.1);
        assert_eq!(config.attack_method, AttackMethod::FGSM);
        assert_eq!(config.iterations, 10);
    }

    /// An anomaly dataset has the requested shape and contains both normal
    /// and anomalous labels.
    #[test]
    fn test_anomaly_dataset_generation() {
        let config = AnomalyConfig {
            anomaly_fraction: 0.2,
            anomaly_type: AnomalyType::Point,
            severity: 2.0,
            ..Default::default()
        };

        let dataset = make_anomaly_dataset(100, 10, config).unwrap();

        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 10);
        assert!(dataset.target.is_some());

        // Check that we have both normal and anomalous samples
        let target = dataset.target.unwrap();
        let anomalies = target.iter().filter(|&&x| x == 1.0).count();
        assert!(anomalies > 0);
        assert!(anomalies < 100);
    }

    /// A multi-task dataset contains one fully-labelled dataset per task.
    #[test]
    fn test_multitask_dataset_generation() {
        let config = MultiTaskConfig {
            n_tasks: 2,
            task_types: vec![TaskType::Classification(3), TaskType::Regression],
            shared_features: 5,
            task_specific_features: 3,
            ..Default::default()
        };

        let dataset = make_multitask_dataset(50, config).unwrap();

        assert_eq!(dataset.tasks.len(), 2);
        assert_eq!(dataset.shared_features, 5);

        for task in &dataset.tasks {
            assert_eq!(task.n_samples(), 50);
            assert!(task.target.is_some());
        }
    }

    /// Adversarial examples keep the dataset shape and actually change the data.
    #[test]
    fn test_adversarial_examples_generation() {
        let base_dataset = make_classification(100, 10, 3, 2, 8, Some(42)).unwrap();
        let config = AdversarialConfig {
            epsilon: 0.1,
            attack_method: AttackMethod::FGSM,
            // Fixed seed so the test does not depend on the thread-local RNG.
            random_state: Some(42),
            ..Default::default()
        };

        let adversarial_dataset = make_adversarial_examples(&base_dataset, config).unwrap();

        assert_eq!(adversarial_dataset.n_samples(), base_dataset.n_samples());
        assert_eq!(adversarial_dataset.n_features(), base_dataset.n_features());

        // Check that the data has been perturbed. Compare element-wise rather
        // than via the global mean: perturbations of opposite sign can cancel
        // out in the mean even though every sample was modified.
        let max_abs_change = adversarial_dataset
            .data
            .iter()
            .zip(base_dataset.data.iter())
            .map(|(adversarial, original)| (adversarial - original).abs())
            .fold(0.0_f64, f64::max);
        assert!(max_abs_change > 1e-6);
    }

    /// A few-shot dataset records its parameters and sizes each episode's
    /// support and query sets accordingly.
    #[test]
    fn test_few_shot_dataset() {
        let dataset = make_few_shot_dataset(5, 3, 10, 2, 20).unwrap();

        assert_eq!(dataset.n_way, 5);
        assert_eq!(dataset.k_shot, 3);
        assert_eq!(dataset.n_query, 10);
        assert_eq!(dataset.episodes.len(), 2);

        for episode in &dataset.episodes {
            assert_eq!(episode.n_way, 5);
            assert_eq!(episode.k_shot, 3);
            assert_eq!(episode.support_set.n_samples(), 5 * 3); // n_way * k_shot
            assert_eq!(episode.query_set.n_samples(), 5 * 10); // n_way * n_query
        }
    }

    /// A continual learning dataset has one labelled dataset of the requested
    /// shape per task.
    #[test]
    fn test_continual_learning_dataset() {
        let dataset = make_continual_learning_dataset(3, 100, 10, 5, 0.5).unwrap();

        assert_eq!(dataset.tasks.len(), 3);
        assert_eq!(dataset.concept_drift_strength, 0.5);

        for task in &dataset.tasks {
            assert_eq!(task.n_samples(), 100);
            assert_eq!(task.n_features(), 10);
            assert!(task.target.is_some());
        }
    }
}