scirs2_datasets/
advanced_generators.rs

1//! Advanced synthetic data generators
2//!
3//! This module provides sophisticated synthetic data generation capabilities
4//! for complex scenarios including adversarial examples, anomaly detection,
5//! multi-task learning, and domain adaptation.
6
7use crate::error::{DatasetsError, Result};
8use crate::utils::Dataset;
9use scirs2_core::ndarray::{Array1, Array2, Axis};
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::Uniform;
12
13/// Configuration for adversarial example generation
14#[derive(Debug, Clone)]
15pub struct AdversarialConfig {
16    /// Perturbation strength (epsilon)
17    pub epsilon: f64,
18    /// Attack method
19    pub attack_method: AttackMethod,
20    /// Target class for targeted attacks
21    pub target_class: Option<usize>,
22    /// Number of attack iterations
23    pub iterations: usize,
24    /// Step size for iterative attacks
25    pub step_size: f64,
26    /// Random seed for reproducibility
27    pub random_state: Option<u64>,
28}
29
/// Adversarial attack methods.
///
/// Field-less, so it is cheap to copy and usable as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AttackMethod {
    /// Fast Gradient Sign Method
    FGSM,
    /// Projected Gradient Descent
    PGD,
    /// Carlini & Wagner attack
    CW,
    /// DeepFool attack
    DeepFool,
    /// Random noise baseline
    RandomNoise,
}
44
45impl Default for AdversarialConfig {
46    fn default() -> Self {
47        Self {
48            epsilon: 0.1,
49            attack_method: AttackMethod::FGSM,
50            target_class: None,
51            iterations: 10,
52            step_size: 0.01,
53            random_state: None,
54        }
55    }
56}
57
58/// Configuration for anomaly detection datasets
59#[derive(Debug, Clone)]
60pub struct AnomalyConfig {
61    /// Fraction of anomalous samples
62    pub anomaly_fraction: f64,
63    /// Type of anomalies to generate
64    pub anomaly_type: AnomalyType,
65    /// Severity of anomalies
66    pub severity: f64,
67    /// Whether to include multiple anomaly types
68    pub mixed_anomalies: bool,
69    /// Clustering factor for normal data
70    pub clustering_factor: f64,
71    /// Random seed
72    pub random_state: Option<u64>,
73}
74
/// Types of anomalies.
///
/// Field-less, so it is cheap to copy and usable as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnomalyType {
    /// Point anomalies (outliers)
    Point,
    /// Contextual anomalies
    Contextual,
    /// Collective anomalies
    Collective,
    /// Adversarial anomalies
    Adversarial,
    /// Mixed anomaly types
    Mixed,
}
89
90impl Default for AnomalyConfig {
91    fn default() -> Self {
92        Self {
93            anomaly_fraction: 0.1,
94            anomaly_type: AnomalyType::Point,
95            severity: 2.0,
96            mixed_anomalies: false,
97            clustering_factor: 1.0,
98            random_state: None,
99        }
100    }
101}
102
103/// Configuration for multi-task learning datasets
104#[derive(Debug, Clone)]
105pub struct MultiTaskConfig {
106    /// Number of tasks
107    pub n_tasks: usize,
108    /// Task types (classification or regression)
109    pub task_types: Vec<TaskType>,
110    /// Shared feature dimensions
111    pub shared_features: usize,
112    /// Task-specific feature dimensions
113    pub task_specific_features: usize,
114    /// Correlation between tasks
115    pub task_correlation: f64,
116    /// Noise level for each task
117    pub task_noise: Vec<f64>,
118    /// Random seed
119    pub random_state: Option<u64>,
120}
121
/// Task types for multi-task learning.
///
/// Only carries `usize` payloads, so it is cheap to copy and hashable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskType {
    /// Classification task with the specified number of classes
    Classification(usize),
    /// Regression task
    Regression,
    /// Ordinal regression with the specified number of levels
    Ordinal(usize),
}
132
133impl Default for MultiTaskConfig {
134    fn default() -> Self {
135        Self {
136            n_tasks: 3,
137            task_types: vec![
138                TaskType::Classification(3),
139                TaskType::Regression,
140                TaskType::Classification(5),
141            ],
142            shared_features: 10,
143            task_specific_features: 5,
144            task_correlation: 0.5,
145            task_noise: vec![0.1, 0.1, 0.1],
146            random_state: None,
147        }
148    }
149}
150
151/// Configuration for domain adaptation datasets
152#[derive(Debug, Clone)]
153pub struct DomainAdaptationConfig {
154    /// Number of source domains
155    pub n_source_domains: usize,
156    /// Domain shift parameters
157    pub domain_shifts: Vec<DomainShift>,
158    /// Label shift (different class distributions)
159    pub label_shift: bool,
160    /// Feature shift (different feature distributions)
161    pub feature_shift: bool,
162    /// Concept drift over time
163    pub concept_drift: bool,
164    /// Random seed
165    pub random_state: Option<u64>,
166}
167
168/// Domain shift types
169#[derive(Debug, Clone)]
170pub struct DomainShift {
171    /// Shift in feature means
172    pub mean_shift: Array1<f64>,
173    /// Shift in feature covariances
174    pub covariance_shift: Option<Array2<f64>>,
175    /// Shift strength
176    pub shift_strength: f64,
177}
178
179impl Default for DomainAdaptationConfig {
180    fn default() -> Self {
181        Self {
182            n_source_domains: 2,
183            domain_shifts: vec![],
184            label_shift: true,
185            feature_shift: true,
186            concept_drift: false,
187            random_state: None,
188        }
189    }
190}
191
/// Advanced synthetic data generator.
///
/// Holds an optional seed that is forwarded to the base-data generators it calls.
#[derive(Debug, Clone)]
pub struct AdvancedGenerator {
    random_state: Option<u64>,
}
196
197impl AdvancedGenerator {
198    /// Create a new advanced generator
199    pub fn new(_random_state: Option<u64>) -> Self {
200        Self {
201            random_state: _random_state,
202        }
203    }
204
205    /// Generate adversarial examples
206    pub fn make_adversarial_examples(
207        &self,
208        base_dataset: &Dataset,
209        config: AdversarialConfig,
210    ) -> Result<Dataset> {
211        let n_samples = base_dataset.n_samples();
212        let _n_features = base_dataset.n_features();
213
214        println!(
215            "Generating adversarial examples using {:?}",
216            config.attack_method
217        );
218
219        // Create adversarial perturbations
220        let perturbations = self.generate_perturbations(&base_dataset.data, &config)?;
221
222        // Apply perturbations
223        let adversarial_data = &base_dataset.data + &perturbations;
224
225        // Clip to valid range if needed
226        let clipped_data = adversarial_data.mapv(|x| x.clamp(-5.0, 5.0));
227
228        // Create adversarial labels
229        let adversarial_target = if let Some(target) = &base_dataset.target {
230            match config.target_class {
231                Some(target_class) => {
232                    // Targeted attack - change labels to target class
233                    Some(Array1::from_elem(n_samples, target_class as f64))
234                }
235                None => {
236                    // Untargeted attack - keep original labels but mark as adversarial
237                    Some(target.clone())
238                }
239            }
240        } else {
241            None
242        };
243
244        let mut metadata = base_dataset.metadata.clone();
245        let _old_description = metadata.get("description").cloned().unwrap_or_default();
246        let oldname = metadata.get("name").cloned().unwrap_or_default();
247
248        metadata.insert(
249            "description".to_string(),
250            format!(
251                "Adversarial examples generated using {:?}",
252                config.attack_method
253            ),
254        );
255        metadata.insert("name".to_string(), format!("{oldname} (Adversarial)"));
256
257        Ok(Dataset {
258            data: clipped_data,
259            target: adversarial_target,
260            targetnames: base_dataset.targetnames.clone(),
261            featurenames: base_dataset.featurenames.clone(),
262            feature_descriptions: base_dataset.feature_descriptions.clone(),
263            description: base_dataset.description.clone(),
264            metadata,
265        })
266    }
267
268    /// Generate anomaly detection dataset
269    pub fn make_anomaly_dataset(
270        &self,
271        n_samples: usize,
272        n_features: usize,
273        config: AnomalyConfig,
274    ) -> Result<Dataset> {
275        let n_anomalies = (n_samples as f64 * config.anomaly_fraction) as usize;
276        let n_normal = n_samples - n_anomalies;
277
278        println!("Generating anomaly dataset: {n_normal} normal, {n_anomalies} anomalous");
279
280        // Generate normal data
281        let normal_data =
282            self.generate_normal_data(n_normal, n_features, config.clustering_factor)?;
283
284        // Generate anomalous data
285        let anomalous_data =
286            self.generate_anomalous_data(n_anomalies, n_features, &normal_data, &config)?;
287
288        // Combine data
289        let mut combined_data = Array2::zeros((n_samples, n_features));
290        combined_data
291            .slice_mut(scirs2_core::ndarray::s![..n_normal, ..])
292            .assign(&normal_data);
293        combined_data
294            .slice_mut(scirs2_core::ndarray::s![n_normal.., ..])
295            .assign(&anomalous_data);
296
297        // Create labels (0 = normal, 1 = anomaly)
298        let mut target = Array1::zeros(n_samples);
299        target
300            .slice_mut(scirs2_core::ndarray::s![n_normal..])
301            .fill(1.0);
302
303        // Shuffle the data
304        let shuffled_indices = self.generate_shuffle_indices(n_samples)?;
305        let shuffled_data = self.shuffle_by_indices(&combined_data, &shuffled_indices);
306        let shuffled_target = self.shuffle_array_by_indices(&target, &shuffled_indices);
307
308        let metadata = crate::registry::DatasetMetadata {
309            name: "Anomaly Detection Dataset".to_string(),
310            description: format!(
311                "Synthetic anomaly detection dataset with {:.1}% anomalies",
312                config.anomaly_fraction * 100.0
313            ),
314            n_samples,
315            n_features,
316            task_type: "anomaly_detection".to_string(),
317            targetnames: Some(vec!["normal".to_string(), "anomaly".to_string()]),
318            ..Default::default()
319        };
320
321        Ok(Dataset::from_metadata(
322            shuffled_data,
323            Some(shuffled_target),
324            metadata,
325        ))
326    }
327
328    /// Generate multi-task learning dataset
329    pub fn make_multitask_dataset(
330        &self,
331        n_samples: usize,
332        config: MultiTaskConfig,
333    ) -> Result<MultiTaskDataset> {
334        let total_features =
335            config.shared_features + config.task_specific_features * config.n_tasks;
336
337        println!(
338            "Generating multi-task dataset: {} tasks, {} samples, {} features",
339            config.n_tasks, n_samples, total_features
340        );
341
342        // Generate shared features
343        let shared_data = self.generate_shared_features(n_samples, config.shared_features)?;
344
345        // Generate task-specific features and targets
346        let mut task_datasets = Vec::new();
347
348        for (task_id, task_type) in config.task_types.iter().enumerate() {
349            let task_specific_data = self.generate_task_specific_features(
350                n_samples,
351                config.task_specific_features,
352                task_id,
353            )?;
354
355            // Combine shared and task-specific features
356            let task_data = self.combine_features(&shared_data, &task_specific_data);
357
358            // Generate task target based on task type
359            let task_target = self.generate_task_target(
360                &task_data,
361                task_type,
362                config.task_correlation,
363                config.task_noise.get(task_id).unwrap_or(&0.1),
364            )?;
365
366            let task_metadata = crate::registry::DatasetMetadata {
367                name: format!("Task {task_id}"),
368                description: format!("Multi-task learning task {task_id} ({task_type:?})"),
369                n_samples,
370                n_features: task_data.ncols(),
371                task_type: match task_type {
372                    TaskType::Classification(_) => "classification".to_string(),
373                    TaskType::Regression => "regression".to_string(),
374                    TaskType::Ordinal(_) => "ordinal_regression".to_string(),
375                },
376                ..Default::default()
377            };
378
379            task_datasets.push(Dataset::from_metadata(
380                task_data,
381                Some(task_target),
382                task_metadata,
383            ));
384        }
385
386        Ok(MultiTaskDataset {
387            tasks: task_datasets,
388            shared_features: config.shared_features,
389            task_correlation: config.task_correlation,
390        })
391    }
392
393    /// Generate domain adaptation dataset
394    pub fn make_domain_adaptation_dataset(
395        &self,
396        n_samples_per_domain: usize,
397        n_features: usize,
398        n_classes: usize,
399        config: DomainAdaptationConfig,
400    ) -> Result<DomainAdaptationDataset> {
401        let total_domains = config.n_source_domains + 1; // +1 for target _domain
402
403        println!(
404            "Generating _domain adaptation dataset: {total_domains} domains, {n_samples_per_domain} samples each"
405        );
406
407        let mut domain_datasets = Vec::new();
408
409        // Generate source _domain (reference)
410        let source_dataset =
411            self.generate_base_domain_dataset(n_samples_per_domain, n_features, n_classes)?;
412
413        domain_datasets.push(("source".to_string(), source_dataset.clone()));
414
415        // Generate additional source domains with shifts
416        for domain_id in 1..config.n_source_domains {
417            let shift = if domain_id - 1 < config.domain_shifts.len() {
418                &config.domain_shifts[domain_id - 1]
419            } else {
420                // Generate default shift
421                &DomainShift {
422                    mean_shift: Array1::from_elem(n_features, 0.5),
423                    covariance_shift: None,
424                    shift_strength: 1.0,
425                }
426            };
427
428            let shifted_dataset = self.apply_domain_shift(&source_dataset, shift)?;
429            domain_datasets.push((format!("source_{domain_id}"), shifted_dataset));
430        }
431
432        // Generate target _domain with different shift
433        let target_shift = DomainShift {
434            mean_shift: Array1::from_elem(n_features, 1.0),
435            covariance_shift: None,
436            shift_strength: 1.5,
437        };
438
439        let target_dataset = self.apply_domain_shift(&source_dataset, &target_shift)?;
440        domain_datasets.push(("target".to_string(), target_dataset));
441
442        Ok(DomainAdaptationDataset {
443            domains: domain_datasets,
444            n_source_domains: config.n_source_domains,
445        })
446    }
447
448    /// Generate few-shot learning dataset
449    pub fn make_few_shot_dataset(
450        &self,
451        n_way: usize,
452        k_shot: usize,
453        n_query: usize,
454        n_episodes: usize,
455        n_features: usize,
456    ) -> Result<FewShotDataset> {
457        println!(
458            "Generating few-_shot dataset: {n_way}-_way {k_shot}-_shot, {n_episodes} _episodes"
459        );
460
461        let mut episodes = Vec::new();
462
463        for episode_id in 0..n_episodes {
464            let support_set = self.generate_support_set(n_way, k_shot, n_features, episode_id)?;
465            let query_set =
466                self.generate_query_set(n_way, n_query, n_features, &support_set, episode_id)?;
467
468            episodes.push(FewShotEpisode {
469                support_set,
470                query_set,
471                n_way,
472                k_shot,
473            });
474        }
475
476        Ok(FewShotDataset {
477            episodes,
478            n_way,
479            k_shot,
480            n_query,
481        })
482    }
483
484    /// Generate continual learning dataset with concept drift
485    pub fn make_continual_learning_dataset(
486        &self,
487        n_tasks: usize,
488        n_samples_per_task: usize,
489        n_features: usize,
490        n_classes: usize,
491        concept_drift_strength: f64,
492    ) -> Result<ContinualLearningDataset> {
493        println!("Generating continual learning dataset: {n_tasks} _tasks with concept drift");
494
495        let mut task_datasets = Vec::new();
496        let mut base_centers = self.generate_class_centers(n_classes, n_features)?;
497
498        for task_id in 0..n_tasks {
499            // Apply concept drift
500            if task_id > 0 {
501                let drift = Array2::from_shape_fn((n_classes, n_features), |_| {
502                    thread_rng().random::<f64>() * concept_drift_strength
503                });
504                base_centers = base_centers + drift;
505            }
506
507            let task_dataset = self.generate_classification_from_centers(
508                n_samples_per_task,
509                &base_centers,
510                1.0, // cluster_std
511                task_id as u64,
512            )?;
513
514            let mut metadata = task_dataset.metadata.clone();
515            metadata.insert(
516                "name".to_string(),
517                format!("Continual Learning Task {task_id}"),
518            );
519            metadata.insert(
520                "description".to_string(),
521                format!("Task {task_id} with concept drift _strength {concept_drift_strength:.2}"),
522            );
523
524            task_datasets.push(Dataset {
525                data: task_dataset.data,
526                target: task_dataset.target,
527                targetnames: task_dataset.targetnames,
528                featurenames: task_dataset.featurenames,
529                feature_descriptions: task_dataset.feature_descriptions,
530                description: task_dataset.description,
531                metadata,
532            });
533        }
534
535        Ok(ContinualLearningDataset {
536            tasks: task_datasets,
537            concept_drift_strength,
538        })
539    }
540
541    // Private helper methods
542
543    fn generate_perturbations(
544        &self,
545        data: &Array2<f64>,
546        config: &AdversarialConfig,
547    ) -> Result<Array2<f64>> {
548        let (n_samples, n_features) = data.dim();
549
550        match config.attack_method {
551            AttackMethod::FGSM => {
552                // Fast Gradient Sign Method
553                let mut perturbations = Array2::zeros((n_samples, n_features));
554                for i in 0..n_samples {
555                    for j in 0..n_features {
556                        let sign = if thread_rng().random::<f64>() > 0.5 {
557                            1.0
558                        } else {
559                            -1.0
560                        };
561                        perturbations[[i, j]] = config.epsilon * sign;
562                    }
563                }
564                Ok(perturbations)
565            }
566            AttackMethod::PGD => {
567                // Projected Gradient Descent (simplified)
568                let mut perturbations: Array2<f64> = Array2::zeros((n_samples, n_features));
569                for _iter in 0..config.iterations {
570                    for i in 0..n_samples {
571                        for j in 0..n_features {
572                            let gradient = thread_rng().random::<f64>() * 2.0 - 1.0; // Simulated gradient
573                            perturbations[[i, j]] += config.step_size * gradient.signum();
574                            // Clip to epsilon ball
575                            perturbations[[i, j]] =
576                                perturbations[[i, j]].clamp(-config.epsilon, config.epsilon);
577                        }
578                    }
579                }
580                Ok(perturbations)
581            }
582            AttackMethod::RandomNoise => {
583                // Random noise baseline
584                let perturbations = Array2::from_shape_fn((n_samples, n_features), |_| {
585                    (thread_rng().random::<f64>() * 2.0 - 1.0) * config.epsilon
586                });
587                Ok(perturbations)
588            }
589            _ => {
590                // For other methods, use random noise directly
591                let mut perturbations = Array2::zeros(data.dim());
592                for i in 0..data.nrows() {
593                    for j in 0..data.ncols() {
594                        let noise = thread_rng().random::<f64>() * 2.0 - 1.0;
595                        perturbations[[i, j]] = config.epsilon * noise;
596                    }
597                }
598                Ok(perturbations)
599            }
600        }
601    }
602
603    fn generate_normal_data(
604        &self,
605        n_samples: usize,
606        n_features: usize,
607        clustering_factor: f64,
608    ) -> Result<Array2<f64>> {
609        // Generate clustered normal data
610        use crate::generators::make_blobs;
611        let n_clusters = ((n_features as f64).sqrt() as usize).max(2);
612        let dataset = make_blobs(
613            n_samples,
614            n_features,
615            n_clusters,
616            clustering_factor,
617            self.random_state,
618        )?;
619        Ok(dataset.data)
620    }
621
622    fn generate_anomalous_data(
623        &self,
624        n_anomalies: usize,
625        n_features: usize,
626        normal_data: &Array2<f64>,
627        config: &AnomalyConfig,
628    ) -> Result<Array2<f64>> {
629        use scirs2_core::random::Rng;
630        let mut rng = thread_rng();
631
632        match config.anomaly_type {
633            AnomalyType::Point => {
634                // Point _anomalies - outliers far from normal distribution
635                let normal_mean = normal_data.mean_axis(Axis(0)).ok_or_else(|| {
636                    DatasetsError::ComputationError(
637                        "Failed to compute mean for normal data".to_string(),
638                    )
639                })?;
640                let normal_std = normal_data.std_axis(Axis(0), 0.0);
641
642                let mut anomalies = Array2::zeros((n_anomalies, n_features));
643                for i in 0..n_anomalies {
644                    for j in 0..n_features {
645                        let direction = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
646                        anomalies[[i, j]] =
647                            normal_mean[j] + direction * config.severity * normal_std[j];
648                    }
649                }
650                Ok(anomalies)
651            }
652            AnomalyType::Contextual => {
653                // Contextual _anomalies - normal values but in wrong context
654                let mut anomalies: Array2<f64> = Array2::zeros((n_anomalies, n_features));
655                for i in 0..n_anomalies {
656                    // Pick a random normal sample and permute some _features
657                    let base_idx = rng.sample(Uniform::new(0, normal_data.nrows()).unwrap());
658                    let mut anomaly = normal_data.row(base_idx).to_owned();
659
660                    // Permute random _features
661                    let n_permute = (n_features as f64 * 0.3) as usize;
662                    for _ in 0..n_permute {
663                        let j = rng.sample(Uniform::new(0, n_features).unwrap());
664                        let k = rng.sample(Uniform::new(0, n_features).unwrap());
665                        let temp = anomaly[j];
666                        anomaly[j] = anomaly[k];
667                        anomaly[k] = temp;
668                    }
669
670                    anomalies.row_mut(i).assign(&anomaly);
671                }
672                Ok(anomalies)
673            }
674            _ => {
675                // Default to point _anomalies implementation
676                let normal_mean = normal_data.mean_axis(Axis(0)).ok_or_else(|| {
677                    DatasetsError::ComputationError(
678                        "Failed to compute mean for normal data".to_string(),
679                    )
680                })?;
681                let normal_std = normal_data.std_axis(Axis(0), 0.0);
682
683                let mut anomalies = Array2::zeros((n_anomalies, n_features));
684                for i in 0..n_anomalies {
685                    for j in 0..n_features {
686                        let direction = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
687                        anomalies[[i, j]] =
688                            normal_mean[j] + direction * config.severity * normal_std[j];
689                    }
690                }
691                Ok(anomalies)
692            }
693        }
694    }
695
696    fn generate_shuffle_indices(&self, n_samples: usize) -> Result<Vec<usize>> {
697        use scirs2_core::random::Rng;
698        let mut rng = thread_rng();
699        let mut indices: Vec<usize> = (0..n_samples).collect();
700
701        // Simple shuffle using Fisher-Yates
702        for i in (1..n_samples).rev() {
703            let j = rng.sample(Uniform::new(0, i).unwrap());
704            indices.swap(i, j);
705        }
706
707        Ok(indices)
708    }
709
710    fn shuffle_by_indices(&self, data: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
711        let mut shuffled = Array2::zeros(data.dim());
712        for (new_idx, &old_idx) in indices.iter().enumerate() {
713            shuffled.row_mut(new_idx).assign(&data.row(old_idx));
714        }
715        shuffled
716    }
717
718    fn shuffle_array_by_indices(&self, array: &Array1<f64>, indices: &[usize]) -> Array1<f64> {
719        let mut shuffled = Array1::zeros(array.len());
720        for (new_idx, &old_idx) in indices.iter().enumerate() {
721            shuffled[new_idx] = array[old_idx];
722        }
723        shuffled
724    }
725
726    fn generate_shared_features(&self, n_samples: usize, n_features: usize) -> Result<Array2<f64>> {
727        // Generate shared _features using multivariate normal distribution
728        let data = Array2::from_shape_fn((n_samples, n_features), |_| {
729            thread_rng().random::<f64>() * 2.0 - 1.0 // Standard normal approximation
730        });
731        Ok(data)
732    }
733
734    fn generate_task_specific_features(
735        &self,
736        n_samples: usize,
737        n_features: usize,
738        task_id: usize,
739    ) -> Result<Array2<f64>> {
740        // Generate task-specific _features with slight bias per task
741        let task_bias = task_id as f64 * 0.1;
742        let data = Array2::from_shape_fn((n_samples, n_features), |_| {
743            thread_rng().random::<f64>() * 2.0 - 1.0 + task_bias
744        });
745        Ok(data)
746    }
747
748    fn combine_features(&self, shared: &Array2<f64>, task_specific: &Array2<f64>) -> Array2<f64> {
749        let n_samples = shared.nrows();
750        let total_features = shared.ncols() + task_specific.ncols();
751        let mut combined = Array2::zeros((n_samples, total_features));
752
753        combined
754            .slice_mut(scirs2_core::ndarray::s![.., ..shared.ncols()])
755            .assign(shared);
756        combined
757            .slice_mut(scirs2_core::ndarray::s![.., shared.ncols()..])
758            .assign(task_specific);
759
760        combined
761    }
762
763    fn generate_task_target(
764        &self,
765        data: &Array2<f64>,
766        task_type: &TaskType,
767        correlation: f64,
768        noise: &f64,
769    ) -> Result<Array1<f64>> {
770        let n_samples = data.nrows();
771
772        match task_type {
773            TaskType::Classification(n_classes) => {
774                // Generate classification target based on data
775                let target = Array1::from_shape_fn(n_samples, |i| {
776                    let feature_sum = data.row(i).sum();
777                    let class = ((feature_sum * correlation).abs() as usize) % n_classes;
778                    class as f64
779                });
780                Ok(target)
781            }
782            TaskType::Regression => {
783                // Generate regression target as linear combination of features
784                let target = Array1::from_shape_fn(n_samples, |i| {
785                    let weighted_sum = data
786                        .row(i)
787                        .iter()
788                        .enumerate()
789                        .map(|(j, &x)| x * (j as f64 + 1.0) * correlation)
790                        .sum::<f64>();
791                    weighted_sum + thread_rng().random::<f64>() * noise
792                });
793                Ok(target)
794            }
795            TaskType::Ordinal(n_levels) => {
796                // Generate ordinal target
797                let target = Array1::from_shape_fn(n_samples, |i| {
798                    let feature_sum = data.row(i).sum();
799                    let level = ((feature_sum * correlation).abs() as usize) % n_levels;
800                    level as f64
801                });
802                Ok(target)
803            }
804        }
805    }
806
807    fn generate_base_domain_dataset(
808        &self,
809        n_samples: usize,
810        n_features: usize,
811        n_classes: usize,
812    ) -> Result<Dataset> {
813        use crate::generators::make_classification;
814        make_classification(
815            n_samples,
816            n_features,
817            n_classes,
818            2,
819            n_features / 2,
820            self.random_state,
821        )
822    }
823
824    fn apply_domain_shift(&self, base_dataset: &Dataset, shift: &DomainShift) -> Result<Dataset> {
825        let shifted_data = &base_dataset.data + &shift.mean_shift;
826
827        let mut metadata = base_dataset.metadata.clone();
828        let old_description = metadata.get("description").cloned().unwrap_or_default();
829        metadata.insert(
830            "description".to_string(),
831            format!("{old_description} (Domain Shifted)"),
832        );
833
834        Ok(Dataset {
835            data: shifted_data,
836            target: base_dataset.target.clone(),
837            targetnames: base_dataset.targetnames.clone(),
838            featurenames: base_dataset.featurenames.clone(),
839            feature_descriptions: base_dataset.feature_descriptions.clone(),
840            description: base_dataset.description.clone(),
841            metadata,
842        })
843    }
844
845    fn generate_support_set(
846        &self,
847        n_way: usize,
848        k_shot: usize,
849        n_features: usize,
850        episode_id: usize,
851    ) -> Result<Dataset> {
852        let n_samples = n_way * k_shot;
853        use crate::generators::make_classification;
854        make_classification(
855            n_samples,
856            n_features,
857            n_way,
858            1,
859            n_features / 2,
860            Some(episode_id as u64),
861        )
862    }
863
864    fn generate_query_set(
865        &self,
866        n_way: usize,
867        n_query: usize,
868        n_features: usize,
869        _set: &Dataset,
870        episode_id: usize,
871    ) -> Result<Dataset> {
872        let n_samples = n_way * n_query;
873        use crate::generators::make_classification;
874        make_classification(
875            n_samples,
876            n_features,
877            n_way,
878            1,
879            n_features / 2,
880            Some(episode_id as u64 + 1000),
881        )
882    }
883
884    fn generate_class_centers(&self, n_classes: usize, n_features: usize) -> Result<Array2<f64>> {
885        let centers = Array2::from_shape_fn((n_classes, n_features), |_| {
886            thread_rng().random::<f64>() * 4.0 - 2.0
887        });
888        Ok(centers)
889    }
890
891    fn generate_classification_from_centers(
892        &self,
893        n_samples: usize,
894        centers: &Array2<f64>,
895        cluster_std: f64,
896        seed: u64,
897    ) -> Result<Dataset> {
898        use crate::generators::make_blobs;
899        make_blobs(
900            n_samples,
901            centers.ncols(),
902            centers.nrows(),
903            cluster_std,
904            Some(seed),
905        )
906    }
907}
908
/// Multi-task learning dataset container.
///
/// Returned by [`make_multitask_dataset`]; holds one [`Dataset`] per generated task.
#[derive(Debug)]
pub struct MultiTaskDataset {
    /// Individual task datasets, one entry per task
    pub tasks: Vec<Dataset>,
    /// Number of shared features (presumably copied from `MultiTaskConfig::shared_features`)
    pub shared_features: usize,
    /// Correlation between tasks
    pub task_correlation: f64,
}
919
/// Domain adaptation dataset container.
///
/// Returned by [`make_domain_adaptation_dataset`].
#[derive(Debug)]
pub struct DomainAdaptationDataset {
    /// Datasets for each domain as `(name, dataset)` pairs
    pub domains: Vec<(String, Dataset)>,
    /// Number of source domains (presumably the leading entries of `domains`, with the rest
    /// being target domains — TODO confirm against the generator)
    pub n_source_domains: usize,
}
928
/// A single few-shot learning episode: a labelled support set plus a query set.
#[derive(Debug)]
pub struct FewShotEpisode {
    /// Support set for learning (`n_way * k_shot` samples)
    pub support_set: Dataset,
    /// Query set for evaluation (`n_way * n_query` samples)
    pub query_set: Dataset,
    /// Number of classes (ways)
    pub n_way: usize,
    /// Number of examples per class (shots)
    pub k_shot: usize,
}
941
/// Few-shot learning dataset: a sequence of episodes sharing the same N-way / K-shot layout.
///
/// Returned by [`make_few_shot_dataset`].
#[derive(Debug)]
pub struct FewShotDataset {
    /// Training/evaluation episodes
    pub episodes: Vec<FewShotEpisode>,
    /// Number of classes per episode
    pub n_way: usize,
    /// Number of shots per class
    pub k_shot: usize,
    /// Number of query samples per class
    pub n_query: usize,
}
954
/// Continual learning dataset: a sequence of tasks meant to be learned in order.
///
/// Returned by [`make_continual_learning_dataset`].
#[derive(Debug)]
pub struct ContinualLearningDataset {
    /// Sequential tasks, in presentation order
    pub tasks: Vec<Dataset>,
    /// Strength of concept drift between tasks (the value passed by the caller)
    pub concept_drift_strength: f64,
}
963
964/// Convenience functions for advanced data generation
965///
966/// Generate adversarial examples from a base dataset
967#[allow(dead_code)]
968pub fn make_adversarial_examples(
969    base_dataset: &Dataset,
970    config: AdversarialConfig,
971) -> Result<Dataset> {
972    let generator = AdvancedGenerator::new(config.random_state);
973    generator.make_adversarial_examples(base_dataset, config)
974}
975
976/// Generate anomaly detection dataset
977#[allow(dead_code)]
978pub fn make_anomaly_dataset(
979    n_samples: usize,
980    n_features: usize,
981    config: AnomalyConfig,
982) -> Result<Dataset> {
983    let generator = AdvancedGenerator::new(config.random_state);
984    generator.make_anomaly_dataset(n_samples, n_features, config)
985}
986
987/// Generate multi-task learning dataset
988#[allow(dead_code)]
989pub fn make_multitask_dataset(
990    n_samples: usize,
991    config: MultiTaskConfig,
992) -> Result<MultiTaskDataset> {
993    let generator = AdvancedGenerator::new(config.random_state);
994    generator.make_multitask_dataset(n_samples, config)
995}
996
997/// Generate domain adaptation dataset
998#[allow(dead_code)]
999pub fn make_domain_adaptation_dataset(
1000    n_samples_per_domain: usize,
1001    n_features: usize,
1002    n_classes: usize,
1003    config: DomainAdaptationConfig,
1004) -> Result<DomainAdaptationDataset> {
1005    let generator = AdvancedGenerator::new(config.random_state);
1006    generator.make_domain_adaptation_dataset(n_samples_per_domain, n_features, n_classes, config)
1007}
1008
1009/// Generate few-shot learning dataset
1010#[allow(dead_code)]
1011pub fn make_few_shot_dataset(
1012    n_way: usize,
1013    k_shot: usize,
1014    n_query: usize,
1015    n_episodes: usize,
1016    n_features: usize,
1017) -> Result<FewShotDataset> {
1018    let generator = AdvancedGenerator::new(Some(42));
1019    generator.make_few_shot_dataset(n_way, k_shot, n_query, n_episodes, n_features)
1020}
1021
1022/// Generate continual learning dataset
1023#[allow(dead_code)]
1024pub fn make_continual_learning_dataset(
1025    n_tasks: usize,
1026    n_samples_per_task: usize,
1027    n_features: usize,
1028    n_classes: usize,
1029    concept_drift_strength: f64,
1030) -> Result<ContinualLearningDataset> {
1031    let generator = AdvancedGenerator::new(Some(42));
1032    generator.make_continual_learning_dataset(
1033        n_tasks,
1034        n_samples_per_task,
1035        n_features,
1036        n_classes,
1037        concept_drift_strength,
1038    )
1039}
1040
#[cfg(test)]
mod tests {
    // Shape/sanity checks for the convenience wrappers defined in this module.
    use super::*;
    use crate::generators::make_classification;

    #[test]
    fn test_adversarial_config() {
        let AdversarialConfig {
            epsilon,
            attack_method,
            iterations,
            ..
        } = AdversarialConfig::default();
        assert_eq!(epsilon, 0.1);
        assert_eq!(attack_method, AttackMethod::FGSM);
        assert_eq!(iterations, 10);
    }

    #[test]
    fn test_anomaly_dataset_generation() {
        let config = AnomalyConfig {
            anomaly_fraction: 0.2,
            anomaly_type: AnomalyType::Point,
            severity: 2.0,
            ..Default::default()
        };
        let dataset = make_anomaly_dataset(100, 10, config).unwrap();

        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 10);
        assert!(dataset.target.is_some());

        // Label 1.0 marks an anomaly: expect some anomalies, but not every sample.
        let anomalies = dataset
            .target
            .unwrap()
            .iter()
            .filter(|&&label| label == 1.0)
            .count();
        assert!((1..100).contains(&anomalies));
    }

    #[test]
    fn test_multitask_dataset_generation() {
        let n_samples = 50;
        let config = MultiTaskConfig {
            n_tasks: 2,
            task_types: vec![TaskType::Classification(3), TaskType::Regression],
            shared_features: 5,
            task_specific_features: 3,
            ..Default::default()
        };
        let dataset = make_multitask_dataset(n_samples, config).unwrap();

        assert_eq!(dataset.tasks.len(), 2);
        assert_eq!(dataset.shared_features, 5);
        for task in dataset.tasks.iter() {
            assert_eq!(task.n_samples(), n_samples);
            assert!(task.target.is_some());
        }
    }

    #[test]
    fn test_adversarial_examples_generation() {
        let base_dataset = make_classification(100, 10, 3, 2, 8, Some(42)).unwrap();
        let config = AdversarialConfig {
            epsilon: 0.1,
            attack_method: AttackMethod::FGSM,
            ..Default::default()
        };
        let adversarial_dataset = make_adversarial_examples(&base_dataset, config).unwrap();

        assert_eq!(adversarial_dataset.n_samples(), base_dataset.n_samples());
        assert_eq!(adversarial_dataset.n_features(), base_dataset.n_features());

        // A mean difference can cancel out, so check the largest element-wise change instead.
        let max_abs = (&adversarial_dataset.data - &base_dataset.data)
            .iter()
            .fold(0.0_f64, |largest, delta| largest.max(delta.abs()));
        assert!(
            max_abs > 1e-9,
            "Adversarial perturbation appears to be zero (max abs diff = {max_abs})"
        );
    }

    #[test]
    fn test_few_shot_dataset() {
        let (n_way, k_shot, n_query) = (5, 3, 10);
        let dataset = make_few_shot_dataset(n_way, k_shot, n_query, 2, 20).unwrap();

        assert_eq!(dataset.n_way, n_way);
        assert_eq!(dataset.k_shot, k_shot);
        assert_eq!(dataset.n_query, n_query);
        assert_eq!(dataset.episodes.len(), 2);

        for episode in dataset.episodes.iter() {
            assert_eq!(episode.n_way, n_way);
            assert_eq!(episode.k_shot, k_shot);
            assert_eq!(episode.support_set.n_samples(), n_way * k_shot);
            assert_eq!(episode.query_set.n_samples(), n_way * n_query);
        }
    }

    #[test]
    fn test_continual_learning_dataset() {
        let dataset = make_continual_learning_dataset(3, 100, 10, 5, 0.5).unwrap();

        assert_eq!(dataset.tasks.len(), 3);
        assert_eq!(dataset.concept_drift_strength, 0.5);

        for current in dataset.tasks.iter() {
            assert_eq!(current.n_samples(), 100);
            assert_eq!(current.n_features(), 10);
            assert!(current.target.is_some());
        }
    }
}