use crate::error::{DatasetsError, Result};
use crate::utils::Dataset;
use ndarray::{Array1, Array2, Axis};
use rand::{rng, Rng};
use rand_distr::Uniform;

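/// Configuration for generating adversarial examples.
///
/// A minimal construction sketch; the field values are illustrative, and the
/// block is marked `ignore` so it is not compiled as a doctest:
///
/// ```ignore
/// let config = AdversarialConfig {
///     epsilon: 0.05,
///     attack_method: AttackMethod::PGD,
///     ..Default::default()
/// };
/// ```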
#[derive(Debug, Clone)]
pub struct AdversarialConfig {
    /// Maximum perturbation magnitude (epsilon budget)
    pub epsilon: f64,
    /// Attack method used to craft the perturbations
    pub attack_method: AttackMethod,
    /// Target class for targeted attacks (`None` for untargeted)
    pub target_class: Option<usize>,
    /// Number of iterations for iterative attacks such as PGD
    pub iterations: usize,
    /// Step size per iteration for iterative attacks
    pub step_size: f64,
    /// Random seed for reproducibility
    pub random_state: Option<u64>,
}

/// Supported adversarial attack methods.
#[derive(Debug, Clone, PartialEq)]
pub enum AttackMethod {
    /// Fast Gradient Sign Method
    FGSM,
    /// Projected Gradient Descent
    PGD,
    /// Carlini-Wagner attack
    CW,
    /// DeepFool attack
    DeepFool,
    /// Uniform random noise baseline
    RandomNoise,
}

impl Default for AdversarialConfig {
    fn default() -> Self {
        Self {
            epsilon: 0.1,
            attack_method: AttackMethod::FGSM,
            target_class: None,
            iterations: 10,
            step_size: 0.01,
            random_state: None,
        }
    }
}

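/// Configuration for generating anomaly detection datasets.
///
/// A minimal construction sketch (values are illustrative; not compiled as a
/// doctest):
///
/// ```ignore
/// let config = AnomalyConfig {
///     anomaly_fraction: 0.05,
///     anomaly_type: AnomalyType::Contextual,
///     ..Default::default()
/// };
/// ```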
#[derive(Debug, Clone)]
pub struct AnomalyConfig {
    /// Fraction of samples that are anomalous
    pub anomaly_fraction: f64,
    /// Type of anomalies to generate
    pub anomaly_type: AnomalyType,
    /// Severity of the anomalies, in standard deviations from the mean
    pub severity: f64,
    /// Whether to mix several anomaly types
    pub mixed_anomalies: bool,
    /// Cluster spread of the normal data
    pub clustering_factor: f64,
    /// Random seed for reproducibility
    pub random_state: Option<u64>,
}

/// Types of anomalies that can be generated.
#[derive(Debug, Clone, PartialEq)]
pub enum AnomalyType {
    /// Individual points far from the normal distribution
    Point,
    /// Samples with plausible feature values occurring in the wrong context
    Contextual,
    /// Groups of related anomalous samples
    Collective,
    /// Adversarially crafted anomalies
    Adversarial,
    /// A mixture of the above types
    Mixed,
}

impl Default for AnomalyConfig {
    fn default() -> Self {
        Self {
            anomaly_fraction: 0.1,
            anomaly_type: AnomalyType::Point,
            severity: 2.0,
            mixed_anomalies: false,
            clustering_factor: 1.0,
            random_state: None,
        }
    }
}

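/// Configuration for generating multi-task learning datasets.
///
/// `task_types` and `task_noise` should contain one entry per task. A
/// construction sketch (values are illustrative; not compiled as a doctest):
///
/// ```ignore
/// let config = MultiTaskConfig {
///     n_tasks: 2,
///     task_types: vec![TaskType::Classification(3), TaskType::Regression],
///     task_noise: vec![0.1, 0.2],
///     ..Default::default()
/// };
/// ```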
#[derive(Debug, Clone)]
pub struct MultiTaskConfig {
    /// Number of tasks
    pub n_tasks: usize,
    /// Type of each task (one entry per task)
    pub task_types: Vec<TaskType>,
    /// Number of features shared across all tasks
    pub shared_features: usize,
    /// Number of features specific to each task
    pub task_specific_features: usize,
    /// Correlation between features and task targets
    pub task_correlation: f64,
    /// Noise level for each task (one entry per task)
    pub task_noise: Vec<f64>,
    /// Random seed for reproducibility
    pub random_state: Option<u64>,
}

/// Types of learning tasks in a multi-task dataset.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskType {
    /// Classification with the given number of classes
    Classification(usize),
    /// Continuous regression
    Regression,
    /// Ordinal regression with the given number of levels
    Ordinal(usize),
}

impl Default for MultiTaskConfig {
    fn default() -> Self {
        Self {
            n_tasks: 3,
            task_types: vec![
                TaskType::Classification(3),
                TaskType::Regression,
                TaskType::Classification(5),
            ],
            shared_features: 10,
            task_specific_features: 5,
            task_correlation: 0.5,
            task_noise: vec![0.1, 0.1, 0.1],
            random_state: None,
        }
    }
}

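/// Configuration for generating domain adaptation datasets.
///
/// A construction sketch enabling concept drift (illustrative; not compiled
/// as a doctest):
///
/// ```ignore
/// let config = DomainAdaptationConfig {
///     n_source_domains: 3,
///     concept_drift: true,
///     ..Default::default()
/// };
/// ```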
#[derive(Debug, Clone)]
pub struct DomainAdaptationConfig {
    /// Number of source domains
    pub n_source_domains: usize,
    /// Shift applied to each additional source domain
    pub domain_shifts: Vec<DomainShift>,
    /// Whether class proportions shift between domains
    pub label_shift: bool,
    /// Whether feature distributions shift between domains
    pub feature_shift: bool,
    /// Whether the feature-label relationship drifts between domains
    pub concept_drift: bool,
    /// Random seed for reproducibility
    pub random_state: Option<u64>,
}

/// A distribution shift applied to a domain.
#[derive(Debug, Clone)]
pub struct DomainShift {
    /// Per-feature shift of the mean
    pub mean_shift: Array1<f64>,
    /// Optional shift of the covariance structure
    pub covariance_shift: Option<Array2<f64>>,
    /// Overall strength of the shift
    pub shift_strength: f64,
}

impl Default for DomainAdaptationConfig {
    fn default() -> Self {
        Self {
            n_source_domains: 2,
            domain_shifts: vec![],
            label_shift: true,
            feature_shift: true,
            concept_drift: false,
            random_state: None,
        }
    }
}

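/// Generator for advanced synthetic datasets: adversarial examples, anomaly
/// detection, multi-task, domain adaptation, few-shot, and continual
/// learning data.
///
/// A usage sketch (illustrative; not compiled as a doctest):
///
/// ```ignore
/// let generator = AdvancedGenerator::new(Some(42));
/// let dataset = generator.make_anomaly_dataset(1_000, 16, AnomalyConfig::default())?;
/// ```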
pub struct AdvancedGenerator {
    /// Seed forwarded to the underlying generators where they accept one
    random_state: Option<u64>,
}

impl AdvancedGenerator {
    /// Creates a new generator with an optional random seed.
    pub fn new(random_state: Option<u64>) -> Self {
        Self { random_state }
    }

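    /// Generates adversarial examples by perturbing `base_dataset`.
    ///
    /// Perturbed values are clipped to `[-5.0, 5.0]`. A usage sketch (the
    /// `make_classification` call mirrors this crate's generator; marked
    /// `ignore` so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let base = make_classification(100, 10, 3, 2, 8, Some(42))?;
    /// let generator = AdvancedGenerator::new(Some(42));
    /// let adversarial =
    ///     generator.make_adversarial_examples(&base, AdversarialConfig::default())?;
    /// assert_eq!(adversarial.n_samples(), base.n_samples());
    /// ```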
    pub fn make_adversarial_examples(
        &self,
        base_dataset: &Dataset,
        config: AdversarialConfig,
    ) -> Result<Dataset> {
        let n_samples = base_dataset.n_samples();

        println!(
            "Generating adversarial examples using {:?}",
            config.attack_method
        );

        // Craft perturbations according to the configured attack method
        let perturbations = self.generate_perturbations(&base_dataset.data, &config)?;

        // Apply the perturbations to the original data
        let adversarial_data = &base_dataset.data + &perturbations;

        // Clip to a reasonable value range
        let clipped_data = adversarial_data.mapv(|x| x.clamp(-5.0, 5.0));

        // Targeted attacks relabel every sample with the target class;
        // untargeted attacks keep the original labels
        let adversarial_target = if let Some(target) = &base_dataset.target {
            match config.target_class {
                Some(target_class) => Some(Array1::from_elem(n_samples, target_class as f64)),
                None => Some(target.clone()),
            }
        } else {
            None
        };

        let mut metadata = base_dataset.metadata.clone();
        let old_name = metadata.get("name").cloned().unwrap_or_default();

        metadata.insert(
            "description".to_string(),
            format!(
                "Adversarial examples generated using {:?}",
                config.attack_method
            ),
        );
        metadata.insert("name".to_string(), format!("{old_name} (Adversarial)"));

        Ok(Dataset {
            data: clipped_data,
            target: adversarial_target,
            targetnames: base_dataset.targetnames.clone(),
            featurenames: base_dataset.featurenames.clone(),
            feature_descriptions: base_dataset.feature_descriptions.clone(),
            description: base_dataset.description.clone(),
            metadata,
        })
    }

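    /// Generates a labeled anomaly detection dataset, with `0.0` marking
    /// normal samples and `1.0` marking anomalies, shuffled together.
    ///
    /// A usage sketch (illustrative; not compiled as a doctest):
    ///
    /// ```ignore
    /// let generator = AdvancedGenerator::new(Some(42));
    /// let dataset = generator.make_anomaly_dataset(1_000, 16, AnomalyConfig::default())?;
    /// let target = dataset.target.expect("targets are always set");
    /// let n_anomalies = target.iter().filter(|&&y| y == 1.0).count();
    /// ```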
    pub fn make_anomaly_dataset(
        &self,
        n_samples: usize,
        n_features: usize,
        config: AnomalyConfig,
    ) -> Result<Dataset> {
        let n_anomalies = (n_samples as f64 * config.anomaly_fraction) as usize;
        let n_normal = n_samples - n_anomalies;

        println!("Generating anomaly dataset: {n_normal} normal, {n_anomalies} anomalous");

        // Generate the normal (inlier) portion of the data
        let normal_data =
            self.generate_normal_data(n_normal, n_features, config.clustering_factor)?;

        // Generate the anomalous portion relative to the normal data
        let anomalous_data =
            self.generate_anomalous_data(n_anomalies, n_features, &normal_data, &config)?;

        // Stack normal and anomalous samples
        let mut combined_data = Array2::zeros((n_samples, n_features));
        combined_data
            .slice_mut(ndarray::s![..n_normal, ..])
            .assign(&normal_data);
        combined_data
            .slice_mut(ndarray::s![n_normal.., ..])
            .assign(&anomalous_data);

        // Label normal samples 0.0 and anomalies 1.0
        let mut target = Array1::zeros(n_samples);
        target.slice_mut(ndarray::s![n_normal..]).fill(1.0);

        // Shuffle so the anomalies are not grouped at the end
        let shuffled_indices = self.generate_shuffle_indices(n_samples)?;
        let shuffled_data = self.shuffle_by_indices(&combined_data, &shuffled_indices);
        let shuffled_target = self.shuffle_array_by_indices(&target, &shuffled_indices);

        let metadata = crate::registry::DatasetMetadata {
            name: "Anomaly Detection Dataset".to_string(),
            description: format!(
                "Synthetic anomaly detection dataset with {:.1}% anomalies",
                config.anomaly_fraction * 100.0
            ),
            n_samples,
            n_features,
            task_type: "anomaly_detection".to_string(),
            targetnames: Some(vec!["normal".to_string(), "anomaly".to_string()]),
            ..Default::default()
        };

        Ok(Dataset::from_metadata(
            shuffled_data,
            Some(shuffled_target),
            metadata,
        ))
    }

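    /// Generates a multi-task dataset in which every task shares
    /// `shared_features` features and adds `task_specific_features` of its own.
    ///
    /// A usage sketch (illustrative; not compiled as a doctest):
    ///
    /// ```ignore
    /// let generator = AdvancedGenerator::new(None);
    /// let multitask = generator.make_multitask_dataset(200, MultiTaskConfig::default())?;
    /// assert_eq!(multitask.tasks.len(), 3);
    /// ```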
    pub fn make_multitask_dataset(
        &self,
        n_samples: usize,
        config: MultiTaskConfig,
    ) -> Result<MultiTaskDataset> {
        let total_features =
            config.shared_features + config.task_specific_features * config.n_tasks;

        println!(
            "Generating multi-task dataset: {} tasks, {} samples, {} features",
            config.n_tasks, n_samples, total_features
        );

        // Features shared by every task
        let shared_data = self.generate_shared_features(n_samples, config.shared_features)?;

        let mut task_datasets = Vec::new();

        for (task_id, task_type) in config.task_types.iter().enumerate() {
            // Features unique to this task
            let task_specific_data = self.generate_task_specific_features(
                n_samples,
                config.task_specific_features,
                task_id,
            )?;

            // Concatenate shared and task-specific features column-wise
            let task_data = self.combine_features(&shared_data, &task_specific_data);

            // Derive the task target from the combined features
            let task_target = self.generate_task_target(
                &task_data,
                task_type,
                config.task_correlation,
                config.task_noise.get(task_id).unwrap_or(&0.1),
            )?;

            let task_metadata = crate::registry::DatasetMetadata {
                name: format!("Task {task_id}"),
                description: format!("Multi-task learning task {task_id} ({task_type:?})"),
                n_samples,
                n_features: task_data.ncols(),
                task_type: match task_type {
                    TaskType::Classification(_) => "classification".to_string(),
                    TaskType::Regression => "regression".to_string(),
                    TaskType::Ordinal(_) => "ordinal_regression".to_string(),
                },
                ..Default::default()
            };

            task_datasets.push(Dataset::from_metadata(
                task_data,
                Some(task_target),
                task_metadata,
            ));
        }

        Ok(MultiTaskDataset {
            tasks: task_datasets,
            shared_features: config.shared_features,
            task_correlation: config.task_correlation,
        })
    }

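    /// Generates a domain adaptation dataset: one base source domain, the
    /// remaining source domains shifted from it, and a more strongly shifted
    /// target domain.
    ///
    /// A usage sketch (illustrative; not compiled as a doctest):
    ///
    /// ```ignore
    /// let generator = AdvancedGenerator::new(Some(7));
    /// let dataset = generator.make_domain_adaptation_dataset(
    ///     500,
    ///     20,
    ///     4,
    ///     DomainAdaptationConfig::default(),
    /// )?;
    /// // n_source_domains source domains plus one target domain
    /// assert_eq!(dataset.domains.len(), dataset.n_source_domains + 1);
    /// ```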
    pub fn make_domain_adaptation_dataset(
        &self,
        n_samples_per_domain: usize,
        n_features: usize,
        n_classes: usize,
        config: DomainAdaptationConfig,
    ) -> Result<DomainAdaptationDataset> {
        let total_domains = config.n_source_domains + 1; // source domains plus one target domain
        println!(
            "Generating domain adaptation dataset: {total_domains} domains, {n_samples_per_domain} samples each"
        );

        let mut domain_datasets = Vec::new();

        // Base source domain from which all shifted domains derive
        let source_dataset =
            self.generate_base_domain_dataset(n_samples_per_domain, n_features, n_classes)?;

        domain_datasets.push(("source".to_string(), source_dataset.clone()));

        // Remaining source domains: apply the configured shift, or a default one
        for domain_id in 1..config.n_source_domains {
            let default_shift = DomainShift {
                mean_shift: Array1::from_elem(n_features, 0.5),
                covariance_shift: None,
                shift_strength: 1.0,
            };
            let shift = config
                .domain_shifts
                .get(domain_id - 1)
                .unwrap_or(&default_shift);

            let shifted_dataset = self.apply_domain_shift(&source_dataset, shift)?;
            domain_datasets.push((format!("source_{domain_id}"), shifted_dataset));
        }

        // Target domain with a stronger shift than the source domains
        let target_shift = DomainShift {
            mean_shift: Array1::from_elem(n_features, 1.0),
            covariance_shift: None,
            shift_strength: 1.5,
        };

        let target_dataset = self.apply_domain_shift(&source_dataset, &target_shift)?;
        domain_datasets.push(("target".to_string(), target_dataset));

        Ok(DomainAdaptationDataset {
            domains: domain_datasets,
            n_source_domains: config.n_source_domains,
        })
    }

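    /// Generates episodic few-shot data: `n_episodes` episodes, each with an
    /// `n_way * k_shot` support set and an `n_way * n_query` query set.
    ///
    /// A usage sketch (illustrative; not compiled as a doctest):
    ///
    /// ```ignore
    /// let generator = AdvancedGenerator::new(Some(42));
    /// // 5-way 3-shot with 10 query samples per class, 2 episodes
    /// let dataset = generator.make_few_shot_dataset(5, 3, 10, 2, 20)?;
    /// assert_eq!(dataset.episodes.len(), 2);
    /// ```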
    pub fn make_few_shot_dataset(
        &self,
        n_way: usize,
        k_shot: usize,
        n_query: usize,
        n_episodes: usize,
        n_features: usize,
    ) -> Result<FewShotDataset> {
        println!(
            "Generating few-shot dataset: {n_way}-way {k_shot}-shot, {n_episodes} episodes"
        );

        let mut episodes = Vec::new();

        for episode_id in 0..n_episodes {
            let support_set = self.generate_support_set(n_way, k_shot, n_features, episode_id)?;
            let query_set =
                self.generate_query_set(n_way, n_query, n_features, &support_set, episode_id)?;

            episodes.push(FewShotEpisode {
                support_set,
                query_set,
                n_way,
                k_shot,
            });
        }

        Ok(FewShotDataset {
            episodes,
            n_way,
            k_shot,
            n_query,
        })
    }

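    /// Generates a sequence of classification tasks whose class centers drift
    /// between tasks, for continual learning experiments.
    ///
    /// A usage sketch (illustrative; not compiled as a doctest):
    ///
    /// ```ignore
    /// let generator = AdvancedGenerator::new(Some(42));
    /// let dataset = generator.make_continual_learning_dataset(3, 100, 10, 5, 0.5)?;
    /// assert_eq!(dataset.tasks.len(), 3);
    /// ```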
    pub fn make_continual_learning_dataset(
        &self,
        n_tasks: usize,
        n_samples_per_task: usize,
        n_features: usize,
        n_classes: usize,
        concept_drift_strength: f64,
    ) -> Result<ContinualLearningDataset> {
        println!("Generating continual learning dataset: {n_tasks} tasks with concept drift");

        let mut task_datasets = Vec::new();
        let mut base_centers = self.generate_class_centers(n_classes, n_features)?;

        for task_id in 0..n_tasks {
            // Drift the class centers between tasks to simulate concept drift
            if task_id > 0 {
                let drift = Array2::from_shape_fn((n_classes, n_features), |_| {
                    rng().random::<f64>() * concept_drift_strength
                });
                base_centers = base_centers + drift;
            }

            let task_dataset = self.generate_classification_from_centers(
                n_samples_per_task,
                &base_centers,
                1.0, // cluster standard deviation
                task_id as u64,
            )?;

            let mut metadata = task_dataset.metadata.clone();
            metadata.insert(
                "name".to_string(),
                format!("Continual Learning Task {task_id}"),
            );
            metadata.insert(
                "description".to_string(),
                format!("Task {task_id} with concept drift strength {concept_drift_strength:.2}"),
            );

            task_datasets.push(Dataset {
                data: task_dataset.data,
                target: task_dataset.target,
                targetnames: task_dataset.targetnames,
                featurenames: task_dataset.featurenames,
                feature_descriptions: task_dataset.feature_descriptions,
                description: task_dataset.description,
                metadata,
            });
        }

        Ok(ContinualLearningDataset {
            tasks: task_datasets,
            concept_drift_strength,
        })
    }

    /// Generates perturbations for the configured attack method.
    ///
    /// No model gradients are available here, so gradient-based methods are
    /// approximated with random directions.
    fn generate_perturbations(
        &self,
        data: &Array2<f64>,
        config: &AdversarialConfig,
    ) -> Result<Array2<f64>> {
        let (n_samples, n_features) = data.dim();

        match config.attack_method {
            AttackMethod::FGSM => {
                // FGSM-style: epsilon times a sign, with a random sign standing
                // in for the gradient sign
                let mut perturbations = Array2::zeros((n_samples, n_features));
                for i in 0..n_samples {
                    for j in 0..n_features {
                        let sign = if rng().random::<f64>() > 0.5 {
                            1.0
                        } else {
                            -1.0
                        };
                        perturbations[[i, j]] = config.epsilon * sign;
                    }
                }
                Ok(perturbations)
            }
            AttackMethod::PGD => {
                // PGD-style: iterative signed steps, projected back into the
                // epsilon ball after each step
                let mut perturbations: Array2<f64> = Array2::zeros((n_samples, n_features));
                for _iter in 0..config.iterations {
                    for i in 0..n_samples {
                        for j in 0..n_features {
                            let gradient = rng().random::<f64>() * 2.0 - 1.0;
                            perturbations[[i, j]] += config.step_size * gradient.signum();
                            perturbations[[i, j]] =
                                perturbations[[i, j]].clamp(-config.epsilon, config.epsilon);
                        }
                    }
                }
                Ok(perturbations)
            }
            AttackMethod::RandomNoise => {
                // Uniform noise in [-epsilon, epsilon)
                let perturbations = Array2::from_shape_fn((n_samples, n_features), |_| {
                    (rng().random::<f64>() * 2.0 - 1.0) * config.epsilon
                });
                Ok(perturbations)
            }
            _ => {
                // CW and DeepFool are not implemented; fall back to uniform noise
                let mut perturbations = Array2::zeros(data.dim());
                for i in 0..data.nrows() {
                    for j in 0..data.ncols() {
                        let noise = rng().random::<f64>() * 2.0 - 1.0;
                        perturbations[[i, j]] = config.epsilon * noise;
                    }
                }
                Ok(perturbations)
            }
        }
    }

    /// Generates normal (inlier) data as Gaussian blobs.
    fn generate_normal_data(
        &self,
        n_samples: usize,
        n_features: usize,
        clustering_factor: f64,
    ) -> Result<Array2<f64>> {
        use crate::generators::make_blobs;
        let n_clusters = ((n_features as f64).sqrt() as usize).max(2);
        let dataset = make_blobs(
            n_samples,
            n_features,
            n_clusters,
            clustering_factor,
            self.random_state,
        )?;
        Ok(dataset.data)
    }

    /// Generates anomalous samples relative to the given normal data.
    fn generate_anomalous_data(
        &self,
        n_anomalies: usize,
        n_features: usize,
        normal_data: &Array2<f64>,
        config: &AnomalyConfig,
    ) -> Result<Array2<f64>> {
        let mut rng = rng();

        match config.anomaly_type {
            AnomalyType::Point => {
                // Place anomalies `severity` standard deviations away from the
                // per-feature means of the normal data
                let normal_mean = normal_data.mean_axis(Axis(0)).ok_or_else(|| {
                    DatasetsError::ComputationError(
                        "Failed to compute mean for normal data".to_string(),
                    )
                })?;
                let normal_std = normal_data.std_axis(Axis(0), 0.0);

                let mut anomalies = Array2::zeros((n_anomalies, n_features));
                for i in 0..n_anomalies {
                    for j in 0..n_features {
                        let direction = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
                        anomalies[[i, j]] =
                            normal_mean[j] + direction * config.severity * normal_std[j];
                    }
                }
                Ok(anomalies)
            }
            AnomalyType::Contextual => {
                // Start from a real normal sample and permute a fraction of its
                // features: marginals stay plausible, but the context is wrong
                let mut anomalies: Array2<f64> = Array2::zeros((n_anomalies, n_features));
                for i in 0..n_anomalies {
                    let base_idx = rng.sample(Uniform::new(0, normal_data.nrows()).unwrap());
                    let mut anomaly = normal_data.row(base_idx).to_owned();

                    let n_permute = (n_features as f64 * 0.3) as usize;
                    for _ in 0..n_permute {
                        let j = rng.sample(Uniform::new(0, n_features).unwrap());
                        let k = rng.sample(Uniform::new(0, n_features).unwrap());
                        anomaly.swap(j, k);
                    }

                    anomalies.row_mut(i).assign(&anomaly);
                }
                Ok(anomalies)
            }
            _ => {
                // Collective, Adversarial, and Mixed currently fall back to
                // point-style anomalies
                let normal_mean = normal_data.mean_axis(Axis(0)).ok_or_else(|| {
                    DatasetsError::ComputationError(
                        "Failed to compute mean for normal data".to_string(),
                    )
                })?;
                let normal_std = normal_data.std_axis(Axis(0), 0.0);

                let mut anomalies = Array2::zeros((n_anomalies, n_features));
                for i in 0..n_anomalies {
                    for j in 0..n_features {
                        let direction = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
                        anomalies[[i, j]] =
                            normal_mean[j] + direction * config.severity * normal_std[j];
                    }
                }
                Ok(anomalies)
            }
        }
    }

    /// Generates a random permutation of `0..n_samples` using Fisher-Yates.
    fn generate_shuffle_indices(&self, n_samples: usize) -> Result<Vec<usize>> {
        let mut rng = rng();
        let mut indices: Vec<usize> = (0..n_samples).collect();

        // Fisher-Yates: j must be drawn from 0..=i (inclusive); otherwise no
        // element can stay in place and the permutation is biased
        for i in (1..n_samples).rev() {
            let j = rng.sample(Uniform::new_inclusive(0, i).unwrap());
            indices.swap(i, j);
        }

        Ok(indices)
    }

    fn shuffle_by_indices(&self, data: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
        let mut shuffled = Array2::zeros(data.dim());
        for (new_idx, &old_idx) in indices.iter().enumerate() {
            shuffled.row_mut(new_idx).assign(&data.row(old_idx));
        }
        shuffled
    }

    fn shuffle_array_by_indices(&self, array: &Array1<f64>, indices: &[usize]) -> Array1<f64> {
        let mut shuffled = Array1::zeros(array.len());
        for (new_idx, &old_idx) in indices.iter().enumerate() {
            shuffled[new_idx] = array[old_idx];
        }
        shuffled
    }

    fn generate_shared_features(&self, n_samples: usize, n_features: usize) -> Result<Array2<f64>> {
        // Shared features are uniform in [-1, 1)
        let data = Array2::from_shape_fn((n_samples, n_features), |_| {
            rng().random::<f64>() * 2.0 - 1.0
        });
        Ok(data)
    }

    fn generate_task_specific_features(
        &self,
        n_samples: usize,
        n_features: usize,
        task_id: usize,
    ) -> Result<Array2<f64>> {
        // Bias the features by task id so tasks are distinguishable
        let task_bias = task_id as f64 * 0.1;
        let data = Array2::from_shape_fn((n_samples, n_features), |_| {
            rng().random::<f64>() * 2.0 - 1.0 + task_bias
        });
        Ok(data)
    }

    fn combine_features(&self, shared: &Array2<f64>, task_specific: &Array2<f64>) -> Array2<f64> {
        let n_samples = shared.nrows();
        let total_features = shared.ncols() + task_specific.ncols();
        let mut combined = Array2::zeros((n_samples, total_features));

        combined
            .slice_mut(ndarray::s![.., ..shared.ncols()])
            .assign(shared);
        combined
            .slice_mut(ndarray::s![.., shared.ncols()..])
            .assign(task_specific);

        combined
    }

    fn generate_task_target(
        &self,
        data: &Array2<f64>,
        task_type: &TaskType,
        correlation: f64,
        noise: &f64,
    ) -> Result<Array1<f64>> {
        let n_samples = data.nrows();

        match task_type {
            TaskType::Classification(n_classes) => {
                // Class index derived from a correlation-scaled feature sum
                let target = Array1::from_shape_fn(n_samples, |i| {
                    let feature_sum = data.row(i).sum();
                    let class = ((feature_sum * correlation).abs() as usize) % n_classes;
                    class as f64
                });
                Ok(target)
            }
            TaskType::Regression => {
                // Weighted feature sum plus uniform noise
                let target = Array1::from_shape_fn(n_samples, |i| {
                    let weighted_sum = data
                        .row(i)
                        .iter()
                        .enumerate()
                        .map(|(j, &x)| x * (j as f64 + 1.0) * correlation)
                        .sum::<f64>();
                    weighted_sum + rng().random::<f64>() * noise
                });
                Ok(target)
            }
            TaskType::Ordinal(n_levels) => {
                // Ordinal level derived the same way as the classification case
                let target = Array1::from_shape_fn(n_samples, |i| {
                    let feature_sum = data.row(i).sum();
                    let level = ((feature_sum * correlation).abs() as usize) % n_levels;
                    level as f64
                });
                Ok(target)
            }
        }
    }

    fn generate_base_domain_dataset(
        &self,
        n_samples: usize,
        n_features: usize,
        n_classes: usize,
    ) -> Result<Dataset> {
        use crate::generators::make_classification;
        make_classification(
            n_samples,
            n_features,
            n_classes,
            2,
            n_features / 2,
            self.random_state,
        )
    }

    fn apply_domain_shift(&self, base_dataset: &Dataset, shift: &DomainShift) -> Result<Dataset> {
        // Only the mean shift is applied here; covariance_shift and
        // shift_strength are currently unused
        let shifted_data = &base_dataset.data + &shift.mean_shift;

        let mut metadata = base_dataset.metadata.clone();
        let old_description = metadata.get("description").cloned().unwrap_or_default();
        metadata.insert(
            "description".to_string(),
            format!("{old_description} (Domain Shifted)"),
        );

        Ok(Dataset {
            data: shifted_data,
            target: base_dataset.target.clone(),
            targetnames: base_dataset.targetnames.clone(),
            featurenames: base_dataset.featurenames.clone(),
            feature_descriptions: base_dataset.feature_descriptions.clone(),
            description: base_dataset.description.clone(),
            metadata,
        })
    }

    fn generate_support_set(
        &self,
        n_way: usize,
        k_shot: usize,
        n_features: usize,
        episode_id: usize,
    ) -> Result<Dataset> {
        let n_samples = n_way * k_shot;
        use crate::generators::make_classification;
        make_classification(
            n_samples,
            n_features,
            n_way,
            1,
            n_features / 2,
            Some(episode_id as u64),
        )
    }

    fn generate_query_set(
        &self,
        n_way: usize,
        n_query: usize,
        n_features: usize,
        _support_set: &Dataset,
        episode_id: usize,
    ) -> Result<Dataset> {
        let n_samples = n_way * n_query;
        use crate::generators::make_classification;
        make_classification(
            n_samples,
            n_features,
            n_way,
            1,
            n_features / 2,
            // Offset the seed so the query set differs from the support set
            Some(episode_id as u64 + 1000),
        )
    }

    fn generate_class_centers(&self, n_classes: usize, n_features: usize) -> Result<Array2<f64>> {
        // Centers drawn uniformly from [-2, 2)
        let centers = Array2::from_shape_fn((n_classes, n_features), |_| {
            rng().random::<f64>() * 4.0 - 2.0
        });
        Ok(centers)
    }

    fn generate_classification_from_centers(
        &self,
        n_samples: usize,
        centers: &Array2<f64>,
        cluster_std: f64,
        seed: u64,
    ) -> Result<Dataset> {
        use crate::generators::make_blobs;
        // NOTE: make_blobs draws its own centers, so `centers` only determines
        // the number of classes and features; the drift is reflected through
        // the per-task seed rather than the center positions themselves
        make_blobs(
            n_samples,
            centers.ncols(),
            centers.nrows(),
            cluster_std,
            Some(seed),
        )
    }
}

/// A multi-task learning dataset.
#[derive(Debug)]
pub struct MultiTaskDataset {
    /// One dataset per task
    pub tasks: Vec<Dataset>,
    /// Number of features shared across tasks
    pub shared_features: usize,
    /// Correlation between features and task targets
    pub task_correlation: f64,
}

/// A domain adaptation dataset composed of named domains.
#[derive(Debug)]
pub struct DomainAdaptationDataset {
    /// (domain name, dataset) pairs, sources first and the target last
    pub domains: Vec<(String, Dataset)>,
    /// Number of source domains
    pub n_source_domains: usize,
}

/// A single few-shot learning episode.
#[derive(Debug)]
pub struct FewShotEpisode {
    /// Labeled support samples (n_way * k_shot)
    pub support_set: Dataset,
    /// Query samples to classify (n_way * n_query)
    pub query_set: Dataset,
    /// Number of classes in the episode
    pub n_way: usize,
    /// Number of support samples per class
    pub k_shot: usize,
}

/// An episodic few-shot learning dataset.
#[derive(Debug)]
pub struct FewShotDataset {
    /// Generated episodes
    pub episodes: Vec<FewShotEpisode>,
    /// Number of classes per episode
    pub n_way: usize,
    /// Number of support samples per class
    pub k_shot: usize,
    /// Number of query samples per class
    pub n_query: usize,
}

/// A continual learning dataset with drifting tasks.
#[derive(Debug)]
pub struct ContinualLearningDataset {
    /// Tasks in the order they should be learned
    pub tasks: Vec<Dataset>,
    /// Strength of the concept drift between consecutive tasks
    pub concept_drift_strength: f64,
}

/// Convenience wrapper for [`AdvancedGenerator::make_adversarial_examples`].
#[allow(dead_code)]
pub fn make_adversarial_examples(
    base_dataset: &Dataset,
    config: AdversarialConfig,
) -> Result<Dataset> {
    let generator = AdvancedGenerator::new(config.random_state);
    generator.make_adversarial_examples(base_dataset, config)
}

/// Convenience wrapper for [`AdvancedGenerator::make_anomaly_dataset`].
#[allow(dead_code)]
pub fn make_anomaly_dataset(
    n_samples: usize,
    n_features: usize,
    config: AnomalyConfig,
) -> Result<Dataset> {
    let generator = AdvancedGenerator::new(config.random_state);
    generator.make_anomaly_dataset(n_samples, n_features, config)
}

/// Convenience wrapper for [`AdvancedGenerator::make_multitask_dataset`].
#[allow(dead_code)]
pub fn make_multitask_dataset(
    n_samples: usize,
    config: MultiTaskConfig,
) -> Result<MultiTaskDataset> {
    let generator = AdvancedGenerator::new(config.random_state);
    generator.make_multitask_dataset(n_samples, config)
}

/// Convenience wrapper for [`AdvancedGenerator::make_domain_adaptation_dataset`].
#[allow(dead_code)]
pub fn make_domain_adaptation_dataset(
    n_samples_per_domain: usize,
    n_features: usize,
    n_classes: usize,
    config: DomainAdaptationConfig,
) -> Result<DomainAdaptationDataset> {
    let generator = AdvancedGenerator::new(config.random_state);
    generator.make_domain_adaptation_dataset(n_samples_per_domain, n_features, n_classes, config)
}

/// Convenience wrapper for [`AdvancedGenerator::make_few_shot_dataset`],
/// using a fixed seed for reproducible episodes.
#[allow(dead_code)]
pub fn make_few_shot_dataset(
    n_way: usize,
    k_shot: usize,
    n_query: usize,
    n_episodes: usize,
    n_features: usize,
) -> Result<FewShotDataset> {
    let generator = AdvancedGenerator::new(Some(42));
    generator.make_few_shot_dataset(n_way, k_shot, n_query, n_episodes, n_features)
}

/// Convenience wrapper for [`AdvancedGenerator::make_continual_learning_dataset`],
/// using a fixed seed.
#[allow(dead_code)]
pub fn make_continual_learning_dataset(
    n_tasks: usize,
    n_samples_per_task: usize,
    n_features: usize,
    n_classes: usize,
    concept_drift_strength: f64,
) -> Result<ContinualLearningDataset> {
    let generator = AdvancedGenerator::new(Some(42));
    generator.make_continual_learning_dataset(
        n_tasks,
        n_samples_per_task,
        n_features,
        n_classes,
        concept_drift_strength,
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    #[test]
    fn test_adversarial_config() {
        let config = AdversarialConfig::default();
        assert_eq!(config.epsilon, 0.1);
        assert_eq!(config.attack_method, AttackMethod::FGSM);
        assert_eq!(config.iterations, 10);
    }

    #[test]
    fn test_anomaly_dataset_generation() {
        let config = AnomalyConfig {
            anomaly_fraction: 0.2,
            anomaly_type: AnomalyType::Point,
            severity: 2.0,
            ..Default::default()
        };

        let dataset = make_anomaly_dataset(100, 10, config).unwrap();

        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 10);
        assert!(dataset.target.is_some());

        let target = dataset.target.unwrap();
        let anomalies = target.iter().filter(|&&x| x == 1.0).count();
        assert!(anomalies > 0);
        assert!(anomalies < 100);
    }

    #[test]
    fn test_multitask_dataset_generation() {
        let config = MultiTaskConfig {
            n_tasks: 2,
            task_types: vec![TaskType::Classification(3), TaskType::Regression],
            shared_features: 5,
            task_specific_features: 3,
            ..Default::default()
        };

        let dataset = make_multitask_dataset(50, config).unwrap();

        assert_eq!(dataset.tasks.len(), 2);
        assert_eq!(dataset.shared_features, 5);

        for task in &dataset.tasks {
            assert_eq!(task.n_samples(), 50);
            assert!(task.target.is_some());
        }
    }

    #[test]
    fn test_adversarial_examples_generation() {
        let base_dataset = make_classification(100, 10, 3, 2, 8, Some(42)).unwrap();
        let config = AdversarialConfig {
            epsilon: 0.1,
            attack_method: AttackMethod::FGSM,
            ..Default::default()
        };

        let adversarial_dataset = make_adversarial_examples(&base_dataset, config).unwrap();

        assert_eq!(adversarial_dataset.n_samples(), base_dataset.n_samples());
        assert_eq!(adversarial_dataset.n_features(), base_dataset.n_features());

        let original_mean = base_dataset.data.mean().unwrap_or(0.0);
        let adversarial_mean = adversarial_dataset.data.mean().unwrap_or(0.0);
        assert!((original_mean - adversarial_mean).abs() > 1e-6);
    }

    #[test]
    fn test_few_shot_dataset() {
        let dataset = make_few_shot_dataset(5, 3, 10, 2, 20).unwrap();

        assert_eq!(dataset.n_way, 5);
        assert_eq!(dataset.k_shot, 3);
        assert_eq!(dataset.n_query, 10);
        assert_eq!(dataset.episodes.len(), 2);

        for episode in &dataset.episodes {
            assert_eq!(episode.n_way, 5);
            assert_eq!(episode.k_shot, 3);
            assert_eq!(episode.support_set.n_samples(), 5 * 3);
            assert_eq!(episode.query_set.n_samples(), 5 * 10);
        }
    }

    #[test]
    fn test_continual_learning_dataset() {
        let dataset = make_continual_learning_dataset(3, 100, 10, 5, 0.5).unwrap();

        assert_eq!(dataset.tasks.len(), 3);
        assert_eq!(dataset.concept_drift_strength, 0.5);

        for task in &dataset.tasks {
            assert_eq!(task.n_samples(), 100);
            assert_eq!(task.n_features(), 10);
            assert!(task.target.is_some());
        }
    }
}