//! Advanced synthetic data generators
//!
//! This module provides sophisticated synthetic data generation capabilities
//! for complex scenarios including adversarial examples, anomaly detection,
//! multi-task learning, and domain adaptation.
7use crate::error::{DatasetsError, Result};
8use crate::utils::Dataset;
9use scirs2_core::ndarray::{Array1, Array2, Axis};
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::Uniform;
12
/// Configuration for adversarial example generation.
#[derive(Debug, Clone)]
pub struct AdversarialConfig {
    /// Perturbation strength (epsilon): maximum per-feature perturbation magnitude.
    pub epsilon: f64,
    /// Attack method used to craft the perturbations.
    pub attack_method: AttackMethod,
    /// Target class for targeted attacks; `None` means an untargeted attack.
    pub target_class: Option<usize>,
    /// Number of attack iterations (used by iterative methods such as PGD).
    pub iterations: usize,
    /// Step size applied on each iteration of an iterative attack.
    pub step_size: f64,
    /// Random seed for reproducibility; `None` uses a nondeterministic seed.
    pub random_state: Option<u64>,
}
29
/// Adversarial attack methods.
///
/// Only `FGSM`, `PGD`, and `RandomNoise` have distinct implementations in
/// `generate_perturbations`; `CW` and `DeepFool` currently fall back to
/// epsilon-scaled random noise.
#[derive(Debug, Clone, PartialEq)]
pub enum AttackMethod {
    /// Fast Gradient Sign Method
    FGSM,
    /// Projected Gradient Descent
    PGD,
    /// Carlini & Wagner attack (currently approximated by random noise)
    CW,
    /// DeepFool attack (currently approximated by random noise)
    DeepFool,
    /// Random noise baseline
    RandomNoise,
}
44
45impl Default for AdversarialConfig {
46    fn default() -> Self {
47        Self {
48            epsilon: 0.1,
49            attack_method: AttackMethod::FGSM,
50            target_class: None,
51            iterations: 10,
52            step_size: 0.01,
53            random_state: None,
54        }
55    }
56}
57
/// Configuration for anomaly detection datasets.
#[derive(Debug, Clone)]
pub struct AnomalyConfig {
    /// Fraction of anomalous samples, expected in [0, 1].
    pub anomaly_fraction: f64,
    /// Type of anomalies to generate.
    pub anomaly_type: AnomalyType,
    /// Severity of anomalies; for point anomalies this is the distance from
    /// the per-feature mean, in standard deviations.
    pub severity: f64,
    /// Whether to include multiple anomaly types.
    /// NOTE(review): not consulted by the visible generator code — confirm before relying on it.
    pub mixed_anomalies: bool,
    /// Cluster spread used when generating the normal (non-anomalous) data.
    pub clustering_factor: f64,
    /// Random seed; `None` uses a nondeterministic seed.
    pub random_state: Option<u64>,
}
74
/// Types of anomalies.
///
/// Only `Point` and `Contextual` have dedicated implementations in
/// `generate_anomalous_data`; the other variants fall back to point anomalies.
#[derive(Debug, Clone, PartialEq)]
pub enum AnomalyType {
    /// Point anomalies (outliers)
    Point,
    /// Contextual anomalies
    Contextual,
    /// Collective anomalies
    Collective,
    /// Adversarial anomalies
    Adversarial,
    /// Mixed anomaly types
    Mixed,
}
89
90impl Default for AnomalyConfig {
91    fn default() -> Self {
92        Self {
93            anomaly_fraction: 0.1,
94            anomaly_type: AnomalyType::Point,
95            severity: 2.0,
96            mixed_anomalies: false,
97            clustering_factor: 1.0,
98            random_state: None,
99        }
100    }
101}
102
/// Configuration for multi-task learning datasets.
#[derive(Debug, Clone)]
pub struct MultiTaskConfig {
    /// Number of tasks.
    pub n_tasks: usize,
    /// Task type for each task; the generator iterates over this list.
    pub task_types: Vec<TaskType>,
    /// Number of feature columns shared by every task.
    pub shared_features: usize,
    /// Number of feature columns unique to each task.
    pub task_specific_features: usize,
    /// Correlation factor used when deriving targets from features.
    pub task_correlation: f64,
    /// Noise level per task (the generator falls back to 0.1 for missing entries).
    pub task_noise: Vec<f64>,
    /// Random seed; `None` uses a nondeterministic seed.
    pub random_state: Option<u64>,
}
121
/// Task types for multi-task learning.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskType {
    /// Classification task with the specified number of classes.
    Classification(usize),
    /// Regression task (continuous target).
    Regression,
    /// Ordinal regression with the specified number of ordered levels.
    Ordinal(usize),
}
132
133impl Default for MultiTaskConfig {
134    fn default() -> Self {
135        Self {
136            n_tasks: 3,
137            task_types: vec![
138                TaskType::Classification(3),
139                TaskType::Regression,
140                TaskType::Classification(5),
141            ],
142            shared_features: 10,
143            task_specific_features: 5,
144            task_correlation: 0.5,
145            task_noise: vec![0.1, 0.1, 0.1],
146            random_state: None,
147        }
148    }
149}
150
/// Configuration for domain adaptation datasets.
#[derive(Debug, Clone)]
pub struct DomainAdaptationConfig {
    /// Number of source domains (the generator adds one target domain on top).
    pub n_source_domains: usize,
    /// Per-domain shift parameters; missing entries get a default shift.
    pub domain_shifts: Vec<DomainShift>,
    /// Label shift (different class distributions).
    /// NOTE(review): not consulted by the visible generator code — confirm.
    pub label_shift: bool,
    /// Feature shift (different feature distributions).
    /// NOTE(review): not consulted by the visible generator code — confirm.
    pub feature_shift: bool,
    /// Concept drift over time.
    /// NOTE(review): not consulted by the visible generator code — confirm.
    pub concept_drift: bool,
    /// Random seed; `None` uses a nondeterministic seed.
    pub random_state: Option<u64>,
}
167
/// Domain shift parameters.
#[derive(Debug, Clone)]
pub struct DomainShift {
    /// Per-feature shift added to the data (broadcast across rows).
    pub mean_shift: Array1<f64>,
    /// Optional shift of the feature covariances.
    /// NOTE(review): not applied by `apply_domain_shift` in the visible code.
    pub covariance_shift: Option<Array2<f64>>,
    /// Shift strength.
    pub shift_strength: f64,
}
178
179impl Default for DomainAdaptationConfig {
180    fn default() -> Self {
181        Self {
182            n_source_domains: 2,
183            domain_shifts: vec![],
184            label_shift: true,
185            feature_shift: true,
186            concept_drift: false,
187            random_state: None,
188        }
189    }
190}
191
/// Advanced data generator.
///
/// Holds an optional RNG seed that is forwarded to the seedable sub-generators
/// (`make_blobs` / `make_classification`); other randomness uses `thread_rng`.
pub struct AdvancedGenerator {
    // Seed forwarded to seedable sub-generators; `None` = nondeterministic.
    random_state: Option<u64>,
}
196
197impl AdvancedGenerator {
198    /// Create a new advanced generator
199    pub fn new(_random_state: Option<u64>) -> Self {
200        Self {
201            random_state: _random_state,
202        }
203    }
204
205    /// Generate adversarial examples
206    pub fn make_adversarial_examples(
207        &self,
208        base_dataset: &Dataset,
209        config: AdversarialConfig,
210    ) -> Result<Dataset> {
211        let n_samples = base_dataset.n_samples();
212        let _n_features = base_dataset.n_features();
213
214        println!(
215            "Generating adversarial examples using {:?}",
216            config.attack_method
217        );
218
219        // Create adversarial perturbations
220        let perturbations = self.generate_perturbations(&base_dataset.data, &config)?;
221
222        // Apply perturbations
223        let adversarial_data = &base_dataset.data + &perturbations;
224
225        // Clip to valid range if needed
226        let clipped_data = adversarial_data.mapv(|x| x.clamp(-5.0, 5.0));
227
228        // Create adversarial labels
229        let adversarial_target = if let Some(target) = &base_dataset.target {
230            match config.target_class {
231                Some(target_class) => {
232                    // Targeted attack - change labels to target class
233                    Some(Array1::from_elem(n_samples, target_class as f64))
234                }
235                None => {
236                    // Untargeted attack - keep original labels but mark as adversarial
237                    Some(target.clone())
238                }
239            }
240        } else {
241            None
242        };
243
244        let mut metadata = base_dataset.metadata.clone();
245        let _old_description = metadata.get("description").cloned().unwrap_or_default();
246        let oldname = metadata.get("name").cloned().unwrap_or_default();
247
248        metadata.insert(
249            "description".to_string(),
250            format!(
251                "Adversarial examples generated using {:?}",
252                config.attack_method
253            ),
254        );
255        metadata.insert("name".to_string(), format!("{oldname} (Adversarial)"));
256
257        Ok(Dataset {
258            data: clipped_data,
259            target: adversarial_target,
260            targetnames: base_dataset.targetnames.clone(),
261            featurenames: base_dataset.featurenames.clone(),
262            feature_descriptions: base_dataset.feature_descriptions.clone(),
263            description: base_dataset.description.clone(),
264            metadata,
265        })
266    }
267
268    /// Generate anomaly detection dataset
269    pub fn make_anomaly_dataset(
270        &self,
271        n_samples: usize,
272        n_features: usize,
273        config: AnomalyConfig,
274    ) -> Result<Dataset> {
275        let n_anomalies = (n_samples as f64 * config.anomaly_fraction) as usize;
276        let n_normal = n_samples - n_anomalies;
277
278        println!("Generating anomaly dataset: {n_normal} normal, {n_anomalies} anomalous");
279
280        // Generate normal data
281        let normal_data =
282            self.generate_normal_data(n_normal, n_features, config.clustering_factor)?;
283
284        // Generate anomalous data
285        let anomalous_data =
286            self.generate_anomalous_data(n_anomalies, n_features, &normal_data, &config)?;
287
288        // Combine data
289        let mut combined_data = Array2::zeros((n_samples, n_features));
290        combined_data
291            .slice_mut(scirs2_core::ndarray::s![..n_normal, ..])
292            .assign(&normal_data);
293        combined_data
294            .slice_mut(scirs2_core::ndarray::s![n_normal.., ..])
295            .assign(&anomalous_data);
296
297        // Create labels (0 = normal, 1 = anomaly)
298        let mut target = Array1::zeros(n_samples);
299        target
300            .slice_mut(scirs2_core::ndarray::s![n_normal..])
301            .fill(1.0);
302
303        // Shuffle the data
304        let shuffled_indices = self.generate_shuffle_indices(n_samples)?;
305        let shuffled_data = self.shuffle_by_indices(&combined_data, &shuffled_indices);
306        let shuffled_target = self.shuffle_array_by_indices(&target, &shuffled_indices);
307
308        let metadata = crate::registry::DatasetMetadata {
309            name: "Anomaly Detection Dataset".to_string(),
310            description: format!(
311                "Synthetic anomaly detection dataset with {:.1}% anomalies",
312                config.anomaly_fraction * 100.0
313            ),
314            n_samples,
315            n_features,
316            task_type: "anomaly_detection".to_string(),
317            targetnames: Some(vec!["normal".to_string(), "anomaly".to_string()]),
318            ..Default::default()
319        };
320
321        Ok(Dataset::from_metadata(
322            shuffled_data,
323            Some(shuffled_target),
324            metadata,
325        ))
326    }
327
328    /// Generate multi-task learning dataset
329    pub fn make_multitask_dataset(
330        &self,
331        n_samples: usize,
332        config: MultiTaskConfig,
333    ) -> Result<MultiTaskDataset> {
334        let total_features =
335            config.shared_features + config.task_specific_features * config.n_tasks;
336
337        println!(
338            "Generating multi-task dataset: {} tasks, {} samples, {} features",
339            config.n_tasks, n_samples, total_features
340        );
341
342        // Generate shared features
343        let shared_data = self.generate_shared_features(n_samples, config.shared_features)?;
344
345        // Generate task-specific features and targets
346        let mut task_datasets = Vec::new();
347
348        for (task_id, task_type) in config.task_types.iter().enumerate() {
349            let task_specific_data = self.generate_task_specific_features(
350                n_samples,
351                config.task_specific_features,
352                task_id,
353            )?;
354
355            // Combine shared and task-specific features
356            let task_data = self.combine_features(&shared_data, &task_specific_data);
357
358            // Generate task target based on task type
359            let task_target = self.generate_task_target(
360                &task_data,
361                task_type,
362                config.task_correlation,
363                config.task_noise.get(task_id).unwrap_or(&0.1),
364            )?;
365
366            let task_metadata = crate::registry::DatasetMetadata {
367                name: format!("Task {task_id}"),
368                description: format!("Multi-task learning task {task_id} ({task_type:?})"),
369                n_samples,
370                n_features: task_data.ncols(),
371                task_type: match task_type {
372                    TaskType::Classification(_) => "classification".to_string(),
373                    TaskType::Regression => "regression".to_string(),
374                    TaskType::Ordinal(_) => "ordinal_regression".to_string(),
375                },
376                ..Default::default()
377            };
378
379            task_datasets.push(Dataset::from_metadata(
380                task_data,
381                Some(task_target),
382                task_metadata,
383            ));
384        }
385
386        Ok(MultiTaskDataset {
387            tasks: task_datasets,
388            shared_features: config.shared_features,
389            task_correlation: config.task_correlation,
390        })
391    }
392
393    /// Generate domain adaptation dataset
394    pub fn make_domain_adaptation_dataset(
395        &self,
396        n_samples_per_domain: usize,
397        n_features: usize,
398        n_classes: usize,
399        config: DomainAdaptationConfig,
400    ) -> Result<DomainAdaptationDataset> {
401        let total_domains = config.n_source_domains + 1; // +1 for target _domain
402
403        println!(
404            "Generating _domain adaptation dataset: {total_domains} domains, {n_samples_per_domain} samples each"
405        );
406
407        let mut domain_datasets = Vec::new();
408
409        // Generate source _domain (reference)
410        let source_dataset =
411            self.generate_base_domain_dataset(n_samples_per_domain, n_features, n_classes)?;
412
413        domain_datasets.push(("source".to_string(), source_dataset.clone()));
414
415        // Generate additional source domains with shifts
416        for domain_id in 1..config.n_source_domains {
417            let shift = if domain_id - 1 < config.domain_shifts.len() {
418                &config.domain_shifts[domain_id - 1]
419            } else {
420                // Generate default shift
421                &DomainShift {
422                    mean_shift: Array1::from_elem(n_features, 0.5),
423                    covariance_shift: None,
424                    shift_strength: 1.0,
425                }
426            };
427
428            let shifted_dataset = self.apply_domain_shift(&source_dataset, shift)?;
429            domain_datasets.push((format!("source_{domain_id}"), shifted_dataset));
430        }
431
432        // Generate target _domain with different shift
433        let target_shift = DomainShift {
434            mean_shift: Array1::from_elem(n_features, 1.0),
435            covariance_shift: None,
436            shift_strength: 1.5,
437        };
438
439        let target_dataset = self.apply_domain_shift(&source_dataset, &target_shift)?;
440        domain_datasets.push(("target".to_string(), target_dataset));
441
442        Ok(DomainAdaptationDataset {
443            domains: domain_datasets,
444            n_source_domains: config.n_source_domains,
445        })
446    }
447
448    /// Generate few-shot learning dataset
449    pub fn make_few_shot_dataset(
450        &self,
451        n_way: usize,
452        k_shot: usize,
453        n_query: usize,
454        n_episodes: usize,
455        n_features: usize,
456    ) -> Result<FewShotDataset> {
457        println!(
458            "Generating few-_shot dataset: {n_way}-_way {k_shot}-_shot, {n_episodes} _episodes"
459        );
460
461        let mut episodes = Vec::new();
462
463        for episode_id in 0..n_episodes {
464            let support_set = self.generate_support_set(n_way, k_shot, n_features, episode_id)?;
465            let query_set =
466                self.generate_query_set(n_way, n_query, n_features, &support_set, episode_id)?;
467
468            episodes.push(FewShotEpisode {
469                support_set,
470                query_set,
471                n_way,
472                k_shot,
473            });
474        }
475
476        Ok(FewShotDataset {
477            episodes,
478            n_way,
479            k_shot,
480            n_query,
481        })
482    }
483
484    /// Generate continual learning dataset with concept drift
485    pub fn make_continual_learning_dataset(
486        &self,
487        n_tasks: usize,
488        n_samples_per_task: usize,
489        n_features: usize,
490        n_classes: usize,
491        concept_drift_strength: f64,
492    ) -> Result<ContinualLearningDataset> {
493        println!("Generating continual learning dataset: {n_tasks} _tasks with concept drift");
494
495        let mut task_datasets = Vec::new();
496        let mut base_centers = self.generate_class_centers(n_classes, n_features)?;
497
498        for task_id in 0..n_tasks {
499            // Apply concept drift
500            if task_id > 0 {
501                let drift = Array2::from_shape_fn((n_classes, n_features), |_| {
502                    thread_rng().random::<f64>() * concept_drift_strength
503                });
504                base_centers = base_centers + drift;
505            }
506
507            let task_dataset = self.generate_classification_from_centers(
508                n_samples_per_task,
509                &base_centers,
510                1.0, // cluster_std
511                task_id as u64,
512            )?;
513
514            let mut metadata = task_dataset.metadata.clone();
515            metadata.insert(
516                "name".to_string(),
517                format!("Continual Learning Task {task_id}"),
518            );
519            metadata.insert(
520                "description".to_string(),
521                format!("Task {task_id} with concept drift _strength {concept_drift_strength:.2}"),
522            );
523
524            task_datasets.push(Dataset {
525                data: task_dataset.data,
526                target: task_dataset.target,
527                targetnames: task_dataset.targetnames,
528                featurenames: task_dataset.featurenames,
529                feature_descriptions: task_dataset.feature_descriptions,
530                description: task_dataset.description,
531                metadata,
532            });
533        }
534
535        Ok(ContinualLearningDataset {
536            tasks: task_datasets,
537            concept_drift_strength,
538        })
539    }
540
541    // Private helper methods
542
543    fn generate_perturbations(
544        &self,
545        data: &Array2<f64>,
546        config: &AdversarialConfig,
547    ) -> Result<Array2<f64>> {
548        let (n_samples, n_features) = data.dim();
549
550        match config.attack_method {
551            AttackMethod::FGSM => {
552                // Fast Gradient Sign Method
553                let mut perturbations = Array2::zeros((n_samples, n_features));
554                for i in 0..n_samples {
555                    for j in 0..n_features {
556                        let sign = if thread_rng().random::<f64>() > 0.5 {
557                            1.0
558                        } else {
559                            -1.0
560                        };
561                        perturbations[[i, j]] = config.epsilon * sign;
562                    }
563                }
564                Ok(perturbations)
565            }
566            AttackMethod::PGD => {
567                // Projected Gradient Descent (simplified)
568                let mut perturbations: Array2<f64> = Array2::zeros((n_samples, n_features));
569                for _iter in 0..config.iterations {
570                    for i in 0..n_samples {
571                        for j in 0..n_features {
572                            let gradient = thread_rng().random::<f64>() * 2.0 - 1.0; // Simulated gradient
573                            perturbations[[i, j]] += config.step_size * gradient.signum();
574                            // Clip to epsilon ball
575                            perturbations[[i, j]] =
576                                perturbations[[i, j]].clamp(-config.epsilon, config.epsilon);
577                        }
578                    }
579                }
580                Ok(perturbations)
581            }
582            AttackMethod::RandomNoise => {
583                // Random noise baseline
584                let perturbations = Array2::from_shape_fn((n_samples, n_features), |_| {
585                    (thread_rng().random::<f64>() * 2.0 - 1.0) * config.epsilon
586                });
587                Ok(perturbations)
588            }
589            _ => {
590                // For other methods, use random noise directly
591                let mut perturbations = Array2::zeros(data.dim());
592                for i in 0..data.nrows() {
593                    for j in 0..data.ncols() {
594                        let noise = thread_rng().random::<f64>() * 2.0 - 1.0;
595                        perturbations[[i, j]] = config.epsilon * noise;
596                    }
597                }
598                Ok(perturbations)
599            }
600        }
601    }
602
603    fn generate_normal_data(
604        &self,
605        n_samples: usize,
606        n_features: usize,
607        clustering_factor: f64,
608    ) -> Result<Array2<f64>> {
609        // Generate clustered normal data
610        use crate::generators::make_blobs;
611        let n_clusters = ((n_features as f64).sqrt() as usize).max(2);
612        let dataset = make_blobs(
613            n_samples,
614            n_features,
615            n_clusters,
616            clustering_factor,
617            self.random_state,
618        )?;
619        Ok(dataset.data)
620    }
621
622    fn generate_anomalous_data(
623        &self,
624        n_anomalies: usize,
625        n_features: usize,
626        normal_data: &Array2<f64>,
627        config: &AnomalyConfig,
628    ) -> Result<Array2<f64>> {
629        use scirs2_core::random::{Rng, RngExt};
630        let mut rng = thread_rng();
631
632        match config.anomaly_type {
633            AnomalyType::Point => {
634                // Point _anomalies - outliers far from normal distribution
635                let normal_mean = normal_data.mean_axis(Axis(0)).ok_or_else(|| {
636                    DatasetsError::ComputationError(
637                        "Failed to compute mean for normal data".to_string(),
638                    )
639                })?;
640                let normal_std = normal_data.std_axis(Axis(0), 0.0);
641
642                let mut anomalies = Array2::zeros((n_anomalies, n_features));
643                for i in 0..n_anomalies {
644                    for j in 0..n_features {
645                        let direction = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
646                        anomalies[[i, j]] =
647                            normal_mean[j] + direction * config.severity * normal_std[j];
648                    }
649                }
650                Ok(anomalies)
651            }
652            AnomalyType::Contextual => {
653                // Contextual _anomalies - normal values but in wrong context
654                let mut anomalies: Array2<f64> = Array2::zeros((n_anomalies, n_features));
655                for i in 0..n_anomalies {
656                    // Pick a random normal sample and permute some _features
657                    let base_idx =
658                        rng.sample(Uniform::new(0, normal_data.nrows()).expect("Operation failed"));
659                    let mut anomaly = normal_data.row(base_idx).to_owned();
660
661                    // Permute random _features
662                    let n_permute = (n_features as f64 * 0.3) as usize;
663                    for _ in 0..n_permute {
664                        let j = rng.sample(Uniform::new(0, n_features).expect("Operation failed"));
665                        let k = rng.sample(Uniform::new(0, n_features).expect("Operation failed"));
666                        let temp = anomaly[j];
667                        anomaly[j] = anomaly[k];
668                        anomaly[k] = temp;
669                    }
670
671                    anomalies.row_mut(i).assign(&anomaly);
672                }
673                Ok(anomalies)
674            }
675            _ => {
676                // Default to point _anomalies implementation
677                let normal_mean = normal_data.mean_axis(Axis(0)).ok_or_else(|| {
678                    DatasetsError::ComputationError(
679                        "Failed to compute mean for normal data".to_string(),
680                    )
681                })?;
682                let normal_std = normal_data.std_axis(Axis(0), 0.0);
683
684                let mut anomalies = Array2::zeros((n_anomalies, n_features));
685                for i in 0..n_anomalies {
686                    for j in 0..n_features {
687                        let direction = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
688                        anomalies[[i, j]] =
689                            normal_mean[j] + direction * config.severity * normal_std[j];
690                    }
691                }
692                Ok(anomalies)
693            }
694        }
695    }
696
697    fn generate_shuffle_indices(&self, n_samples: usize) -> Result<Vec<usize>> {
698        use scirs2_core::random::{Rng, RngExt};
699        let mut rng = thread_rng();
700        let mut indices: Vec<usize> = (0..n_samples).collect();
701
702        // Simple shuffle using Fisher-Yates
703        for i in (1..n_samples).rev() {
704            let j = rng.sample(Uniform::new(0, i).expect("Operation failed"));
705            indices.swap(i, j);
706        }
707
708        Ok(indices)
709    }
710
711    fn shuffle_by_indices(&self, data: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
712        let mut shuffled = Array2::zeros(data.dim());
713        for (new_idx, &old_idx) in indices.iter().enumerate() {
714            shuffled.row_mut(new_idx).assign(&data.row(old_idx));
715        }
716        shuffled
717    }
718
719    fn shuffle_array_by_indices(&self, array: &Array1<f64>, indices: &[usize]) -> Array1<f64> {
720        let mut shuffled = Array1::zeros(array.len());
721        for (new_idx, &old_idx) in indices.iter().enumerate() {
722            shuffled[new_idx] = array[old_idx];
723        }
724        shuffled
725    }
726
727    fn generate_shared_features(&self, n_samples: usize, n_features: usize) -> Result<Array2<f64>> {
728        // Generate shared _features using multivariate normal distribution
729        let data = Array2::from_shape_fn((n_samples, n_features), |_| {
730            thread_rng().random::<f64>() * 2.0 - 1.0 // Standard normal approximation
731        });
732        Ok(data)
733    }
734
735    fn generate_task_specific_features(
736        &self,
737        n_samples: usize,
738        n_features: usize,
739        task_id: usize,
740    ) -> Result<Array2<f64>> {
741        // Generate task-specific _features with slight bias per task
742        let task_bias = task_id as f64 * 0.1;
743        let data = Array2::from_shape_fn((n_samples, n_features), |_| {
744            thread_rng().random::<f64>() * 2.0 - 1.0 + task_bias
745        });
746        Ok(data)
747    }
748
749    fn combine_features(&self, shared: &Array2<f64>, task_specific: &Array2<f64>) -> Array2<f64> {
750        let n_samples = shared.nrows();
751        let total_features = shared.ncols() + task_specific.ncols();
752        let mut combined = Array2::zeros((n_samples, total_features));
753
754        combined
755            .slice_mut(scirs2_core::ndarray::s![.., ..shared.ncols()])
756            .assign(shared);
757        combined
758            .slice_mut(scirs2_core::ndarray::s![.., shared.ncols()..])
759            .assign(task_specific);
760
761        combined
762    }
763
764    fn generate_task_target(
765        &self,
766        data: &Array2<f64>,
767        task_type: &TaskType,
768        correlation: f64,
769        noise: &f64,
770    ) -> Result<Array1<f64>> {
771        let n_samples = data.nrows();
772
773        match task_type {
774            TaskType::Classification(n_classes) => {
775                // Generate classification target based on data
776                let target = Array1::from_shape_fn(n_samples, |i| {
777                    let feature_sum = data.row(i).sum();
778                    let class = ((feature_sum * correlation).abs() as usize) % n_classes;
779                    class as f64
780                });
781                Ok(target)
782            }
783            TaskType::Regression => {
784                // Generate regression target as linear combination of features
785                let target = Array1::from_shape_fn(n_samples, |i| {
786                    let weighted_sum = data
787                        .row(i)
788                        .iter()
789                        .enumerate()
790                        .map(|(j, &x)| x * (j as f64 + 1.0) * correlation)
791                        .sum::<f64>();
792                    weighted_sum + thread_rng().random::<f64>() * noise
793                });
794                Ok(target)
795            }
796            TaskType::Ordinal(n_levels) => {
797                // Generate ordinal target
798                let target = Array1::from_shape_fn(n_samples, |i| {
799                    let feature_sum = data.row(i).sum();
800                    let level = ((feature_sum * correlation).abs() as usize) % n_levels;
801                    level as f64
802                });
803                Ok(target)
804            }
805        }
806    }
807
808    fn generate_base_domain_dataset(
809        &self,
810        n_samples: usize,
811        n_features: usize,
812        n_classes: usize,
813    ) -> Result<Dataset> {
814        use crate::generators::make_classification;
815        make_classification(
816            n_samples,
817            n_features,
818            n_classes,
819            2,
820            n_features / 2,
821            self.random_state,
822        )
823    }
824
825    fn apply_domain_shift(&self, base_dataset: &Dataset, shift: &DomainShift) -> Result<Dataset> {
826        let shifted_data = &base_dataset.data + &shift.mean_shift;
827
828        let mut metadata = base_dataset.metadata.clone();
829        let old_description = metadata.get("description").cloned().unwrap_or_default();
830        metadata.insert(
831            "description".to_string(),
832            format!("{old_description} (Domain Shifted)"),
833        );
834
835        Ok(Dataset {
836            data: shifted_data,
837            target: base_dataset.target.clone(),
838            targetnames: base_dataset.targetnames.clone(),
839            featurenames: base_dataset.featurenames.clone(),
840            feature_descriptions: base_dataset.feature_descriptions.clone(),
841            description: base_dataset.description.clone(),
842            metadata,
843        })
844    }
845
846    fn generate_support_set(
847        &self,
848        n_way: usize,
849        k_shot: usize,
850        n_features: usize,
851        episode_id: usize,
852    ) -> Result<Dataset> {
853        let n_samples = n_way * k_shot;
854        use crate::generators::make_classification;
855        make_classification(
856            n_samples,
857            n_features,
858            n_way,
859            1,
860            n_features / 2,
861            Some(episode_id as u64),
862        )
863    }
864
865    fn generate_query_set(
866        &self,
867        n_way: usize,
868        n_query: usize,
869        n_features: usize,
870        _set: &Dataset,
871        episode_id: usize,
872    ) -> Result<Dataset> {
873        let n_samples = n_way * n_query;
874        use crate::generators::make_classification;
875        make_classification(
876            n_samples,
877            n_features,
878            n_way,
879            1,
880            n_features / 2,
881            Some(episode_id as u64 + 1000),
882        )
883    }
884
885    fn generate_class_centers(&self, n_classes: usize, n_features: usize) -> Result<Array2<f64>> {
886        let centers = Array2::from_shape_fn((n_classes, n_features), |_| {
887            thread_rng().random::<f64>() * 4.0 - 2.0
888        });
889        Ok(centers)
890    }
891
892    fn generate_classification_from_centers(
893        &self,
894        n_samples: usize,
895        centers: &Array2<f64>,
896        cluster_std: f64,
897        seed: u64,
898    ) -> Result<Dataset> {
899        use crate::generators::make_blobs;
900        make_blobs(
901            n_samples,
902            centers.ncols(),
903            centers.nrows(),
904            cluster_std,
905            Some(seed),
906        )
907    }
908}
909
/// Multi-task learning dataset container
///
/// Returned by `make_multitask_dataset`; bundles the per-task datasets
/// together with the configuration values used to generate them.
#[derive(Debug)]
pub struct MultiTaskDataset {
    /// Individual task datasets (one `Dataset` per generated task)
    pub tasks: Vec<Dataset>,
    /// Number of shared features
    pub shared_features: usize,
    /// Correlation between tasks
    pub task_correlation: f64,
}
920
/// Domain adaptation dataset container
///
/// Returned by `make_domain_adaptation_dataset`; holds one labelled dataset
/// per domain along with how many of those domains act as sources.
#[derive(Debug)]
pub struct DomainAdaptationDataset {
    /// Datasets for each domain (name, dataset)
    pub domains: Vec<(String, Dataset)>,
    /// Number of source domains
    pub n_source_domains: usize,
}
929
/// Few-shot learning episode
///
/// One `n_way`-way, `k_shot`-shot task: the support set holds
/// `n_way * k_shot` training samples and the query set is used to
/// evaluate the adapted model.
#[derive(Debug)]
pub struct FewShotEpisode {
    /// Support set for learning
    pub support_set: Dataset,
    /// Query set for evaluation
    pub query_set: Dataset,
    /// Number of classes (ways)
    pub n_way: usize,
    /// Number of examples per class (shots)
    pub k_shot: usize,
}
942
/// Few-shot learning dataset
///
/// Returned by `make_few_shot_dataset`; a collection of episodes that all
/// share the same way/shot/query configuration.
#[derive(Debug)]
pub struct FewShotDataset {
    /// Training/evaluation episodes
    pub episodes: Vec<FewShotEpisode>,
    /// Number of classes per episode
    pub n_way: usize,
    /// Number of shots per class
    pub k_shot: usize,
    /// Number of query samples per class
    pub n_query: usize,
}
955
/// Continual learning dataset
///
/// Returned by `make_continual_learning_dataset`; tasks are ordered and
/// intended to be learned sequentially.
#[derive(Debug)]
pub struct ContinualLearningDataset {
    /// Sequential tasks
    pub tasks: Vec<Dataset>,
    /// Strength of concept drift between tasks
    pub concept_drift_strength: f64,
}
964
// Convenience functions for advanced data generation.

/// Generate adversarial examples from a base dataset
///
/// Convenience wrapper: builds an `AdvancedGenerator` seeded from
/// `config.random_state` and delegates to its `make_adversarial_examples`.
#[allow(dead_code)]
pub fn make_adversarial_examples(
    base_dataset: &Dataset,
    config: AdversarialConfig,
) -> Result<Dataset> {
    let generator = AdvancedGenerator::new(config.random_state);
    generator.make_adversarial_examples(base_dataset, config)
}
976
977/// Generate anomaly detection dataset
978#[allow(dead_code)]
979pub fn make_anomaly_dataset(
980    n_samples: usize,
981    n_features: usize,
982    config: AnomalyConfig,
983) -> Result<Dataset> {
984    let generator = AdvancedGenerator::new(config.random_state);
985    generator.make_anomaly_dataset(n_samples, n_features, config)
986}
987
988/// Generate multi-task learning dataset
989#[allow(dead_code)]
990pub fn make_multitask_dataset(
991    n_samples: usize,
992    config: MultiTaskConfig,
993) -> Result<MultiTaskDataset> {
994    let generator = AdvancedGenerator::new(config.random_state);
995    generator.make_multitask_dataset(n_samples, config)
996}
997
998/// Generate domain adaptation dataset
999#[allow(dead_code)]
1000pub fn make_domain_adaptation_dataset(
1001    n_samples_per_domain: usize,
1002    n_features: usize,
1003    n_classes: usize,
1004    config: DomainAdaptationConfig,
1005) -> Result<DomainAdaptationDataset> {
1006    let generator = AdvancedGenerator::new(config.random_state);
1007    generator.make_domain_adaptation_dataset(n_samples_per_domain, n_features, n_classes, config)
1008}
1009
1010/// Generate few-shot learning dataset
1011#[allow(dead_code)]
1012pub fn make_few_shot_dataset(
1013    n_way: usize,
1014    k_shot: usize,
1015    n_query: usize,
1016    n_episodes: usize,
1017    n_features: usize,
1018) -> Result<FewShotDataset> {
1019    let generator = AdvancedGenerator::new(Some(42));
1020    generator.make_few_shot_dataset(n_way, k_shot, n_query, n_episodes, n_features)
1021}
1022
1023/// Generate continual learning dataset
1024#[allow(dead_code)]
1025pub fn make_continual_learning_dataset(
1026    n_tasks: usize,
1027    n_samples_per_task: usize,
1028    n_features: usize,
1029    n_classes: usize,
1030    concept_drift_strength: f64,
1031) -> Result<ContinualLearningDataset> {
1032    let generator = AdvancedGenerator::new(Some(42));
1033    generator.make_continual_learning_dataset(
1034        n_tasks,
1035        n_samples_per_task,
1036        n_features,
1037        n_classes,
1038        concept_drift_strength,
1039    )
1040}
1041
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    // Defaults of `AdversarialConfig` must stay in sync with the values
    // documented on the struct.
    #[test]
    fn test_adversarial_config() {
        let config = AdversarialConfig::default();
        assert_eq!(config.epsilon, 0.1);
        assert_eq!(config.attack_method, AttackMethod::FGSM);
        assert_eq!(config.iterations, 10);
    }

    // An anomaly dataset should have the requested shape, carry labels, and
    // contain a non-trivial mix of normal (0.0) and anomalous (1.0) samples.
    #[test]
    fn test_anomaly_dataset_generation() {
        let config = AnomalyConfig {
            anomaly_fraction: 0.2,
            anomaly_type: AnomalyType::Point,
            severity: 2.0,
            ..Default::default()
        };

        let dataset = make_anomaly_dataset(100, 10, config).expect("Operation failed");

        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 10);
        assert!(dataset.target.is_some());

        // Check that we have both normal and anomalous samples
        // (labels are 1.0 for anomalies; neither class may be empty).
        let target = dataset.target.expect("Test: target required");
        let anomalies = target.iter().filter(|&&x| x == 1.0).count();
        assert!(anomalies > 0);
        assert!(anomalies < 100);
    }

    // A mixed classification/regression multi-task request should yield one
    // labelled dataset per task, all with the requested sample count.
    #[test]
    fn test_multitask_dataset_generation() {
        let config = MultiTaskConfig {
            n_tasks: 2,
            task_types: vec![TaskType::Classification(3), TaskType::Regression],
            shared_features: 5,
            task_specific_features: 3,
            ..Default::default()
        };

        let dataset = make_multitask_dataset(50, config).expect("Operation failed");

        assert_eq!(dataset.tasks.len(), 2);
        assert_eq!(dataset.shared_features, 5);

        for task in &dataset.tasks {
            assert_eq!(task.n_samples(), 50);
            assert!(task.target.is_some());
        }
    }

    // FGSM perturbation must preserve the dataset's shape while actually
    // moving at least one feature value.
    #[test]
    fn test_adversarial_examples_generation() {
        let base_dataset =
            make_classification(100, 10, 3, 2, 8, Some(42)).expect("Operation failed");
        let config = AdversarialConfig {
            epsilon: 0.1,
            attack_method: AttackMethod::FGSM,
            ..Default::default()
        };

        let adversarial_dataset =
            make_adversarial_examples(&base_dataset, config).expect("Operation failed");

        assert_eq!(adversarial_dataset.n_samples(), base_dataset.n_samples());
        assert_eq!(adversarial_dataset.n_features(), base_dataset.n_features());

        // Check that the data has been perturbed (mean difference can cancel; use max absolute diff)
        let diff = &adversarial_dataset.data - &base_dataset.data;
        let mut max_abs = 0.0_f64;
        for v in diff.iter() {
            let a = v.abs();
            if a > max_abs {
                max_abs = a;
            }
        }
        assert!(
            max_abs > 1e-9,
            "Adversarial perturbation appears to be zero (max abs diff = {max_abs})"
        );
    }

    // Episode structure: every episode carries the configured way/shot counts
    // and support/query sets of size n_way*k_shot and n_way*n_query.
    #[test]
    fn test_few_shot_dataset() {
        let dataset = make_few_shot_dataset(5, 3, 10, 2, 20).expect("Operation failed");

        assert_eq!(dataset.n_way, 5);
        assert_eq!(dataset.k_shot, 3);
        assert_eq!(dataset.n_query, 10);
        assert_eq!(dataset.episodes.len(), 2);

        for episode in &dataset.episodes {
            assert_eq!(episode.n_way, 5);
            assert_eq!(episode.k_shot, 3);
            assert_eq!(episode.support_set.n_samples(), 5 * 3); // n_way * k_shot
            assert_eq!(episode.query_set.n_samples(), 5 * 10); // n_way * n_query
        }
    }

    // Each sequential task must have the requested shape and labels, and the
    // configured drift strength must be recorded verbatim (exact float match
    // is intentional: 0.5 is stored, not computed).
    #[test]
    fn test_continual_learning_dataset() {
        let dataset =
            make_continual_learning_dataset(3, 100, 10, 5, 0.5).expect("Operation failed");

        assert_eq!(dataset.tasks.len(), 3);
        assert_eq!(dataset.concept_drift_strength, 0.5);

        for task in &dataset.tasks {
            assert_eq!(task.n_samples(), 100);
            assert_eq!(task.n_features(), 10);
            assert!(task.target.is_some());
        }
    }
}