use crate::error::{DatasetsError, Result};
use crate::utils::Dataset;
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::Uniform;

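/// Configuration for generating adversarial examples.
///
/// A minimal construction sketch (not compiled as a doctest; the field
/// values are illustrative only):
///
/// ```ignore
/// let config = AdversarialConfig {
///     epsilon: 0.05,                    // smaller perturbation budget
///     attack_method: AttackMethod::PGD, // iterative attack
///     iterations: 20,
///     step_size: 0.005,
///     ..Default::default()
/// };
/// ```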
#[derive(Debug, Clone)]
pub struct AdversarialConfig {
    /// Maximum perturbation magnitude (L-infinity budget).
    pub epsilon: f64,
    /// Attack method used to craft the perturbations.
    pub attack_method: AttackMethod,
    /// Target class for targeted attacks (`None` for untargeted attacks).
    pub target_class: Option<usize>,
    /// Number of iterations for iterative attacks such as PGD.
    pub iterations: usize,
    /// Step size per iteration for iterative attacks.
    pub step_size: f64,
    /// Random seed for reproducibility.
    pub random_state: Option<u64>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum AttackMethod {
    /// Fast Gradient Sign Method.
    FGSM,
    /// Projected Gradient Descent.
    PGD,
    /// Carlini & Wagner attack.
    CW,
    /// DeepFool attack.
    DeepFool,
    /// Uniform random noise baseline.
    RandomNoise,
}

impl Default for AdversarialConfig {
    fn default() -> Self {
        Self {
            epsilon: 0.1,
            attack_method: AttackMethod::FGSM,
            target_class: None,
            iterations: 10,
            step_size: 0.01,
            random_state: None,
        }
    }
}

/// Configuration for generating anomaly detection datasets.
#[derive(Debug, Clone)]
pub struct AnomalyConfig {
    /// Fraction of samples that are anomalous.
    pub anomaly_fraction: f64,
    /// Kind of anomalies to generate.
    pub anomaly_type: AnomalyType,
    /// Severity of anomalies, in standard deviations from the normal data.
    pub severity: f64,
    /// Whether to mix several anomaly types in one dataset.
    pub mixed_anomalies: bool,
    /// Cluster spread factor for the normal data.
    pub clustering_factor: f64,
    /// Random seed for reproducibility.
    pub random_state: Option<u64>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum AnomalyType {
    /// Individual points far from the normal distribution.
    Point,
    /// Points anomalous only in context (feature correlations broken).
    Contextual,
    /// Groups of points that are anomalous together.
    Collective,
    /// Adversarially crafted anomalies.
    Adversarial,
    /// A mixture of the other anomaly types.
    Mixed,
}

impl Default for AnomalyConfig {
    fn default() -> Self {
        Self {
            anomaly_fraction: 0.1,
            anomaly_type: AnomalyType::Point,
            severity: 2.0,
            mixed_anomalies: false,
            clustering_factor: 1.0,
            random_state: None,
        }
    }
}

/// Configuration for generating multi-task learning datasets.
#[derive(Debug, Clone)]
pub struct MultiTaskConfig {
    /// Number of tasks.
    pub n_tasks: usize,
    /// Type of each task; should have `n_tasks` entries.
    pub task_types: Vec<TaskType>,
    /// Number of features shared across all tasks.
    pub shared_features: usize,
    /// Number of features specific to each task.
    pub task_specific_features: usize,
    /// Strength of the feature-target correlation per task.
    pub task_correlation: f64,
    /// Noise level per task; should have `n_tasks` entries.
    pub task_noise: Vec<f64>,
    /// Random seed for reproducibility.
    pub random_state: Option<u64>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum TaskType {
    /// Classification with the given number of classes.
    Classification(usize),
    /// Continuous regression.
    Regression,
    /// Ordinal regression with the given number of levels.
    Ordinal(usize),
}

impl Default for MultiTaskConfig {
    fn default() -> Self {
        Self {
            n_tasks: 3,
            task_types: vec![
                TaskType::Classification(3),
                TaskType::Regression,
                TaskType::Classification(5),
            ],
            shared_features: 10,
            task_specific_features: 5,
            task_correlation: 0.5,
            task_noise: vec![0.1, 0.1, 0.1],
            random_state: None,
        }
    }
}

/// Configuration for generating domain adaptation datasets.
#[derive(Debug, Clone)]
pub struct DomainAdaptationConfig {
    /// Number of source domains (the target domain is generated in addition).
    pub n_source_domains: usize,
    /// Explicit shifts for the extra source domains; a default shift is used
    /// when fewer shifts than domains are provided.
    pub domain_shifts: Vec<DomainShift>,
    /// Whether to simulate label distribution shift.
    pub label_shift: bool,
    /// Whether to simulate feature distribution shift.
    pub feature_shift: bool,
    /// Whether to simulate concept drift.
    pub concept_drift: bool,
    /// Random seed for reproducibility.
    pub random_state: Option<u64>,
}

/// A distribution shift applied to a base domain.
#[derive(Debug, Clone)]
pub struct DomainShift {
    /// Additive shift applied to the feature means.
    pub mean_shift: Array1<f64>,
    /// Optional transformation of the feature covariance.
    pub covariance_shift: Option<Array2<f64>>,
    /// Overall strength of the shift.
    pub shift_strength: f64,
}

impl Default for DomainAdaptationConfig {
    fn default() -> Self {
        Self {
            n_source_domains: 2,
            domain_shifts: vec![],
            label_shift: true,
            feature_shift: true,
            concept_drift: false,
            random_state: None,
        }
    }
}

/// Generator for advanced synthetic datasets (adversarial, anomaly,
/// multi-task, domain adaptation, few-shot, and continual learning data).
pub struct AdvancedGenerator {
    random_state: Option<u64>,
}

impl AdvancedGenerator {
    /// Creates a new generator with an optional random seed.
    pub fn new(random_state: Option<u64>) -> Self {
        Self { random_state }
    }

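    /// Generates adversarial examples by perturbing a base dataset.
    ///
    /// A minimal usage sketch (not compiled as a doctest; the
    /// `make_classification` call and its arguments are illustrative):
    ///
    /// ```ignore
    /// let base = crate::generators::make_classification(100, 10, 3, 2, 5, Some(42))?;
    /// let generator = AdvancedGenerator::new(Some(42));
    /// let adv = generator.make_adversarial_examples(&base, AdversarialConfig::default())?;
    /// assert_eq!(adv.n_samples(), base.n_samples());
    /// ```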
    pub fn make_adversarial_examples(
        &self,
        base_dataset: &Dataset,
        config: AdversarialConfig,
    ) -> Result<Dataset> {
        let n_samples = base_dataset.n_samples();

        println!(
            "Generating adversarial examples using {:?}",
            config.attack_method
        );

        let perturbations = self.generate_perturbations(&base_dataset.data, &config)?;

        let adversarial_data = &base_dataset.data + &perturbations;

        // Keep perturbed values within a plausible feature range.
        let clipped_data = adversarial_data.mapv(|x| x.clamp(-5.0, 5.0));

        // For targeted attacks, relabel every sample with the target class;
        // otherwise keep the original labels.
        let adversarial_target = if let Some(target) = &base_dataset.target {
            match config.target_class {
                Some(target_class) => Some(Array1::from_elem(n_samples, target_class as f64)),
                None => Some(target.clone()),
            }
        } else {
            None
        };

        let mut metadata = base_dataset.metadata.clone();
        let old_name = metadata.get("name").cloned().unwrap_or_default();

        metadata.insert(
            "description".to_string(),
            format!(
                "Adversarial examples generated using {:?}",
                config.attack_method
            ),
        );
        metadata.insert("name".to_string(), format!("{old_name} (Adversarial)"));

        Ok(Dataset {
            data: clipped_data,
            target: adversarial_target,
            targetnames: base_dataset.targetnames.clone(),
            featurenames: base_dataset.featurenames.clone(),
            feature_descriptions: base_dataset.feature_descriptions.clone(),
            description: base_dataset.description.clone(),
            metadata,
        })
    }

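    /// Generates a labeled anomaly detection dataset (`0.0` = normal,
    /// `1.0` = anomaly).
    ///
    /// A minimal usage sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let generator = AdvancedGenerator::new(Some(42));
    /// let dataset = generator.make_anomaly_dataset(1_000, 10, AnomalyConfig::default())?;
    /// // With the default 10% anomaly fraction, 100 samples are labeled 1.0.
    /// ```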
    pub fn make_anomaly_dataset(
        &self,
        n_samples: usize,
        n_features: usize,
        config: AnomalyConfig,
    ) -> Result<Dataset> {
        let n_anomalies = (n_samples as f64 * config.anomaly_fraction) as usize;
        let n_normal = n_samples - n_anomalies;

        println!("Generating anomaly dataset: {n_normal} normal, {n_anomalies} anomalous");

        let normal_data =
            self.generate_normal_data(n_normal, n_features, config.clustering_factor)?;

        let anomalous_data =
            self.generate_anomalous_data(n_anomalies, n_features, &normal_data, &config)?;

        // Stack normal samples first, anomalous samples last.
        let mut combined_data = Array2::zeros((n_samples, n_features));
        combined_data
            .slice_mut(scirs2_core::ndarray::s![..n_normal, ..])
            .assign(&normal_data);
        combined_data
            .slice_mut(scirs2_core::ndarray::s![n_normal.., ..])
            .assign(&anomalous_data);

        // Label anomalies with 1.0.
        let mut target = Array1::zeros(n_samples);
        target
            .slice_mut(scirs2_core::ndarray::s![n_normal..])
            .fill(1.0);

        // Shuffle so the anomalies are not grouped at the end.
        let shuffled_indices = self.generate_shuffle_indices(n_samples)?;
        let shuffled_data = self.shuffle_by_indices(&combined_data, &shuffled_indices);
        let shuffled_target = self.shuffle_array_by_indices(&target, &shuffled_indices);

        let metadata = crate::registry::DatasetMetadata {
            name: "Anomaly Detection Dataset".to_string(),
            description: format!(
                "Synthetic anomaly detection dataset with {:.1}% anomalies",
                config.anomaly_fraction * 100.0
            ),
            n_samples,
            n_features,
            task_type: "anomaly_detection".to_string(),
            targetnames: Some(vec!["normal".to_string(), "anomaly".to_string()]),
            ..Default::default()
        };

        Ok(Dataset::from_metadata(
            shuffled_data,
            Some(shuffled_target),
            metadata,
        ))
    }

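    /// Generates a multi-task dataset where every task shares a common block
    /// of features plus a block of task-specific features.
    ///
    /// A minimal usage sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let generator = AdvancedGenerator::new(Some(42));
    /// let multitask = generator.make_multitask_dataset(200, MultiTaskConfig::default())?;
    /// assert_eq!(multitask.tasks.len(), 3); // the default config has 3 tasks
    /// ```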
    pub fn make_multitask_dataset(
        &self,
        n_samples: usize,
        config: MultiTaskConfig,
    ) -> Result<MultiTaskDataset> {
        let total_features =
            config.shared_features + config.task_specific_features * config.n_tasks;

        println!(
            "Generating multi-task dataset: {} tasks, {} samples, {} features",
            config.n_tasks, n_samples, total_features
        );

        // Features shared by every task.
        let shared_data = self.generate_shared_features(n_samples, config.shared_features)?;

        let mut task_datasets = Vec::new();

        for (task_id, task_type) in config.task_types.iter().enumerate() {
            let task_specific_data = self.generate_task_specific_features(
                n_samples,
                config.task_specific_features,
                task_id,
            )?;

            let task_data = self.combine_features(&shared_data, &task_specific_data);

            let task_target = self.generate_task_target(
                &task_data,
                task_type,
                config.task_correlation,
                config.task_noise.get(task_id).unwrap_or(&0.1),
            )?;

            let task_metadata = crate::registry::DatasetMetadata {
                name: format!("Task {task_id}"),
                description: format!("Multi-task learning task {task_id} ({task_type:?})"),
                n_samples,
                n_features: task_data.ncols(),
                task_type: match task_type {
                    TaskType::Classification(_) => "classification".to_string(),
                    TaskType::Regression => "regression".to_string(),
                    TaskType::Ordinal(_) => "ordinal_regression".to_string(),
                },
                ..Default::default()
            };

            task_datasets.push(Dataset::from_metadata(
                task_data,
                Some(task_target),
                task_metadata,
            ));
        }

        Ok(MultiTaskDataset {
            tasks: task_datasets,
            shared_features: config.shared_features,
            task_correlation: config.task_correlation,
        })
    }

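    /// Generates one base source domain, additional shifted source domains,
    /// and a more strongly shifted target domain.
    ///
    /// A minimal usage sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let generator = AdvancedGenerator::new(Some(42));
    /// let config = DomainAdaptationConfig::default(); // 2 source domains
    /// let dataset = generator.make_domain_adaptation_dataset(500, 10, 3, config)?;
    /// assert_eq!(dataset.domains.len(), 3); // 2 source domains + 1 target
    /// ```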
    pub fn make_domain_adaptation_dataset(
        &self,
        n_samples_per_domain: usize,
        n_features: usize,
        n_classes: usize,
        config: DomainAdaptationConfig,
    ) -> Result<DomainAdaptationDataset> {
        let total_domains = config.n_source_domains + 1;
        println!(
            "Generating domain adaptation dataset: {total_domains} domains, {n_samples_per_domain} samples each"
        );

        let mut domain_datasets = Vec::new();

        // The first source domain is the unshifted base distribution.
        let source_dataset =
            self.generate_base_domain_dataset(n_samples_per_domain, n_features, n_classes)?;

        domain_datasets.push(("source".to_string(), source_dataset.clone()));

        // Additional source domains are shifted copies of the base domain.
        // The default shift is hoisted out of the loop so we can borrow it.
        let default_shift = DomainShift {
            mean_shift: Array1::from_elem(n_features, 0.5),
            covariance_shift: None,
            shift_strength: 1.0,
        };

        for domain_id in 1..config.n_source_domains {
            let shift = config
                .domain_shifts
                .get(domain_id - 1)
                .unwrap_or(&default_shift);

            let shifted_dataset = self.apply_domain_shift(&source_dataset, shift)?;
            domain_datasets.push((format!("source_{domain_id}"), shifted_dataset));
        }

        // The target domain gets a stronger shift than the source domains.
        let target_shift = DomainShift {
            mean_shift: Array1::from_elem(n_features, 1.0),
            covariance_shift: None,
            shift_strength: 1.5,
        };

        let target_dataset = self.apply_domain_shift(&source_dataset, &target_shift)?;
        domain_datasets.push(("target".to_string(), target_dataset));

        Ok(DomainAdaptationDataset {
            domains: domain_datasets,
            n_source_domains: config.n_source_domains,
        })
    }

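    /// Generates episodic few-shot data: each episode has an N-way K-shot
    /// support set and `n_way * n_query` query samples.
    ///
    /// A minimal usage sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let generator = AdvancedGenerator::new(Some(42));
    /// // 5-way 1-shot, 15 query samples per class, 10 episodes, 20 features.
    /// let dataset = generator.make_few_shot_dataset(5, 1, 15, 10, 20)?;
    /// assert_eq!(dataset.episodes[0].support_set.n_samples(), 5);
    /// ```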
    pub fn make_few_shot_dataset(
        &self,
        n_way: usize,
        k_shot: usize,
        n_query: usize,
        n_episodes: usize,
        n_features: usize,
    ) -> Result<FewShotDataset> {
        println!(
            "Generating few-shot dataset: {n_way}-way {k_shot}-shot, {n_episodes} episodes"
        );

        let mut episodes = Vec::new();

        for episode_id in 0..n_episodes {
            let support_set = self.generate_support_set(n_way, k_shot, n_features, episode_id)?;
            let query_set =
                self.generate_query_set(n_way, n_query, n_features, &support_set, episode_id)?;

            episodes.push(FewShotEpisode {
                support_set,
                query_set,
                n_way,
                k_shot,
            });
        }

        Ok(FewShotDataset {
            episodes,
            n_way,
            k_shot,
            n_query,
        })
    }

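    /// Generates a sequence of classification tasks whose class centers drift
    /// between tasks, simulating concept drift for continual learning.
    ///
    /// A minimal usage sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let generator = AdvancedGenerator::new(Some(42));
    /// let dataset = generator.make_continual_learning_dataset(5, 200, 10, 3, 0.5)?;
    /// assert_eq!(dataset.tasks.len(), 5);
    /// ```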
    pub fn make_continual_learning_dataset(
        &self,
        n_tasks: usize,
        n_samples_per_task: usize,
        n_features: usize,
        n_classes: usize,
        concept_drift_strength: f64,
    ) -> Result<ContinualLearningDataset> {
        println!("Generating continual learning dataset: {n_tasks} tasks with concept drift");

        let mut task_datasets = Vec::new();
        let mut base_centers = self.generate_class_centers(n_classes, n_features)?;

        for task_id in 0..n_tasks {
            if task_id > 0 {
                // Drift the class centers between tasks; the drift is drawn
                // from [0, concept_drift_strength), so it is directional.
                let drift = Array2::from_shape_fn((n_classes, n_features), |_| {
                    thread_rng().random::<f64>() * concept_drift_strength
                });
                base_centers = base_centers + drift;
            }

            let task_dataset = self.generate_classification_from_centers(
                n_samples_per_task,
                &base_centers,
                1.0,
                task_id as u64,
            )?;

            let mut metadata = task_dataset.metadata.clone();
            metadata.insert(
                "name".to_string(),
                format!("Continual Learning Task {task_id}"),
            );
            metadata.insert(
                "description".to_string(),
                format!("Task {task_id} with concept drift strength {concept_drift_strength:.2}"),
            );

            task_datasets.push(Dataset {
                data: task_dataset.data,
                target: task_dataset.target,
                targetnames: task_dataset.targetnames,
                featurenames: task_dataset.featurenames,
                feature_descriptions: task_dataset.feature_descriptions,
                description: task_dataset.description,
                metadata,
            });
        }

        Ok(ContinualLearningDataset {
            tasks: task_datasets,
            concept_drift_strength,
        })
    }

    /// Generates perturbations for the requested attack method. No model
    /// gradients are available here, so gradient signs are simulated with
    /// random values; CW and DeepFool fall back to random noise.
    fn generate_perturbations(
        &self,
        data: &Array2<f64>,
        config: &AdversarialConfig,
    ) -> Result<Array2<f64>> {
        let (n_samples, n_features) = data.dim();

        match config.attack_method {
            AttackMethod::FGSM => {
                // FGSM-style: one epsilon-sized step in a random sign direction.
                let mut perturbations = Array2::zeros((n_samples, n_features));
                for i in 0..n_samples {
                    for j in 0..n_features {
                        let sign = if thread_rng().random::<f64>() > 0.5 {
                            1.0
                        } else {
                            -1.0
                        };
                        perturbations[[i, j]] = config.epsilon * sign;
                    }
                }
                Ok(perturbations)
            }
            AttackMethod::PGD => {
                // PGD-style: iterative signed steps, projected back into the
                // epsilon ball after every step.
                let mut perturbations: Array2<f64> = Array2::zeros((n_samples, n_features));
                for _iter in 0..config.iterations {
                    for i in 0..n_samples {
                        for j in 0..n_features {
                            let gradient = thread_rng().random::<f64>() * 2.0 - 1.0;
                            perturbations[[i, j]] += config.step_size * gradient.signum();
                            perturbations[[i, j]] =
                                perturbations[[i, j]].clamp(-config.epsilon, config.epsilon);
                        }
                    }
                }
                Ok(perturbations)
            }
            AttackMethod::RandomNoise => {
                // Uniform noise in [-epsilon, epsilon].
                let perturbations = Array2::from_shape_fn((n_samples, n_features), |_| {
                    (thread_rng().random::<f64>() * 2.0 - 1.0) * config.epsilon
                });
                Ok(perturbations)
            }
            // CW and DeepFool are not implemented yet; fall back to noise.
            _ => {
                let perturbations = Array2::from_shape_fn(data.dim(), |_| {
                    (thread_rng().random::<f64>() * 2.0 - 1.0) * config.epsilon
                });
                Ok(perturbations)
            }
        }
    }

    fn generate_normal_data(
        &self,
        n_samples: usize,
        n_features: usize,
        clustering_factor: f64,
    ) -> Result<Array2<f64>> {
        use crate::generators::make_blobs;
        let n_clusters = ((n_features as f64).sqrt() as usize).max(2);
        let dataset = make_blobs(
            n_samples,
            n_features,
            n_clusters,
            clustering_factor,
            self.random_state,
        )?;
        Ok(dataset.data)
    }

    fn generate_anomalous_data(
        &self,
        n_anomalies: usize,
        n_features: usize,
        normal_data: &Array2<f64>,
        config: &AnomalyConfig,
    ) -> Result<Array2<f64>> {
        use scirs2_core::random::Rng;
        let mut rng = thread_rng();

        match config.anomaly_type {
            AnomalyType::Point => {
                // Point anomalies: offset each feature from the normal mean by
                // `severity` standard deviations in a random direction.
                let normal_mean = normal_data.mean_axis(Axis(0)).ok_or_else(|| {
                    DatasetsError::ComputationError(
                        "Failed to compute mean for normal data".to_string(),
                    )
                })?;
                let normal_std = normal_data.std_axis(Axis(0), 0.0);

                let mut anomalies = Array2::zeros((n_anomalies, n_features));
                for i in 0..n_anomalies {
                    for j in 0..n_features {
                        let direction = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
                        anomalies[[i, j]] =
                            normal_mean[j] + direction * config.severity * normal_std[j];
                    }
                }
                Ok(anomalies)
            }
            AnomalyType::Contextual => {
                // Contextual anomalies: copy a normal sample and permute about
                // 30% of its features, breaking feature correlations while
                // keeping marginal values plausible.
                let mut anomalies: Array2<f64> = Array2::zeros((n_anomalies, n_features));
                for i in 0..n_anomalies {
                    let base_idx = rng.sample(Uniform::new(0, normal_data.nrows()).unwrap());
                    let mut anomaly = normal_data.row(base_idx).to_owned();

                    let n_permute = (n_features as f64 * 0.3) as usize;
                    for _ in 0..n_permute {
                        let j = rng.sample(Uniform::new(0, n_features).unwrap());
                        let k = rng.sample(Uniform::new(0, n_features).unwrap());
                        anomaly.swap(j, k);
                    }

                    anomalies.row_mut(i).assign(&anomaly);
                }
                Ok(anomalies)
            }
            // Collective, Adversarial, and Mixed anomalies are not implemented
            // separately yet; fall back to point anomalies.
            _ => {
                let normal_mean = normal_data.mean_axis(Axis(0)).ok_or_else(|| {
                    DatasetsError::ComputationError(
                        "Failed to compute mean for normal data".to_string(),
                    )
                })?;
                let normal_std = normal_data.std_axis(Axis(0), 0.0);

                let mut anomalies = Array2::zeros((n_anomalies, n_features));
                for i in 0..n_anomalies {
                    for j in 0..n_features {
                        let direction = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
                        anomalies[[i, j]] =
                            normal_mean[j] + direction * config.severity * normal_std[j];
                    }
                }
                Ok(anomalies)
            }
        }
    }

    /// Generates a random permutation of `0..n_samples` using a Fisher-Yates
    /// shuffle.
    fn generate_shuffle_indices(&self, n_samples: usize) -> Result<Vec<usize>> {
        use scirs2_core::random::Rng;
        let mut rng = thread_rng();
        let mut indices: Vec<usize> = (0..n_samples).collect();

        for i in (1..n_samples).rev() {
            // Draw j from 0..=i so every permutation is equally likely.
            let j = rng.sample(Uniform::new(0, i + 1).unwrap());
            indices.swap(i, j);
        }

        Ok(indices)
    }

    fn shuffle_by_indices(&self, data: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
        let mut shuffled = Array2::zeros(data.dim());
        for (new_idx, &old_idx) in indices.iter().enumerate() {
            shuffled.row_mut(new_idx).assign(&data.row(old_idx));
        }
        shuffled
    }

    fn shuffle_array_by_indices(&self, array: &Array1<f64>, indices: &[usize]) -> Array1<f64> {
        let mut shuffled = Array1::zeros(array.len());
        for (new_idx, &old_idx) in indices.iter().enumerate() {
            shuffled[new_idx] = array[old_idx];
        }
        shuffled
    }

    fn generate_shared_features(&self, n_samples: usize, n_features: usize) -> Result<Array2<f64>> {
        // Uniform features in [-1, 1].
        let data = Array2::from_shape_fn((n_samples, n_features), |_| {
            thread_rng().random::<f64>() * 2.0 - 1.0
        });
        Ok(data)
    }

    fn generate_task_specific_features(
        &self,
        n_samples: usize,
        n_features: usize,
        task_id: usize,
    ) -> Result<Array2<f64>> {
        // Shift each task's features slightly so tasks are distinguishable.
        let task_bias = task_id as f64 * 0.1;
        let data = Array2::from_shape_fn((n_samples, n_features), |_| {
            thread_rng().random::<f64>() * 2.0 - 1.0 + task_bias
        });
        Ok(data)
    }

    /// Concatenates shared and task-specific features column-wise.
    fn combine_features(&self, shared: &Array2<f64>, task_specific: &Array2<f64>) -> Array2<f64> {
        let n_samples = shared.nrows();
        let total_features = shared.ncols() + task_specific.ncols();
        let mut combined = Array2::zeros((n_samples, total_features));

        combined
            .slice_mut(scirs2_core::ndarray::s![.., ..shared.ncols()])
            .assign(shared);
        combined
            .slice_mut(scirs2_core::ndarray::s![.., shared.ncols()..])
            .assign(task_specific);

        combined
    }

    fn generate_task_target(
        &self,
        data: &Array2<f64>,
        task_type: &TaskType,
        correlation: f64,
        noise: &f64,
    ) -> Result<Array1<f64>> {
        let n_samples = data.nrows();

        match task_type {
            TaskType::Classification(n_classes) => {
                // Class is a deterministic function of the feature sum.
                let target = Array1::from_shape_fn(n_samples, |i| {
                    let feature_sum = data.row(i).sum();
                    let class = ((feature_sum * correlation).abs() as usize) % n_classes;
                    class as f64
                });
                Ok(target)
            }
            TaskType::Regression => {
                // Weighted feature sum plus uniform noise.
                let target = Array1::from_shape_fn(n_samples, |i| {
                    let weighted_sum = data
                        .row(i)
                        .iter()
                        .enumerate()
                        .map(|(j, &x)| x * (j as f64 + 1.0) * correlation)
                        .sum::<f64>();
                    weighted_sum + thread_rng().random::<f64>() * noise
                });
                Ok(target)
            }
            TaskType::Ordinal(n_levels) => {
                // Ordinal level derived the same way as the classification case.
                let target = Array1::from_shape_fn(n_samples, |i| {
                    let feature_sum = data.row(i).sum();
                    let level = ((feature_sum * correlation).abs() as usize) % n_levels;
                    level as f64
                });
                Ok(target)
            }
        }
    }

    fn generate_base_domain_dataset(
        &self,
        n_samples: usize,
        n_features: usize,
        n_classes: usize,
    ) -> Result<Dataset> {
        use crate::generators::make_classification;
        make_classification(
            n_samples,
            n_features,
            n_classes,
            2,
            n_features / 2,
            self.random_state,
        )
    }

    fn apply_domain_shift(&self, base_dataset: &Dataset, shift: &DomainShift) -> Result<Dataset> {
        // Broadcast the mean shift across all rows; the covariance shift is
        // not applied yet.
        let shifted_data = &base_dataset.data + &shift.mean_shift;

        let mut metadata = base_dataset.metadata.clone();
        let old_description = metadata.get("description").cloned().unwrap_or_default();
        metadata.insert(
            "description".to_string(),
            format!("{old_description} (Domain Shifted)"),
        );

        Ok(Dataset {
            data: shifted_data,
            target: base_dataset.target.clone(),
            targetnames: base_dataset.targetnames.clone(),
            featurenames: base_dataset.featurenames.clone(),
            feature_descriptions: base_dataset.feature_descriptions.clone(),
            description: base_dataset.description.clone(),
            metadata,
        })
    }

    fn generate_support_set(
        &self,
        n_way: usize,
        k_shot: usize,
        n_features: usize,
        episode_id: usize,
    ) -> Result<Dataset> {
        let n_samples = n_way * k_shot;
        use crate::generators::make_classification;
        make_classification(
            n_samples,
            n_features,
            n_way,
            1,
            n_features / 2,
            Some(episode_id as u64),
        )
    }

    fn generate_query_set(
        &self,
        n_way: usize,
        n_query: usize,
        n_features: usize,
        _support_set: &Dataset,
        episode_id: usize,
    ) -> Result<Dataset> {
        let n_samples = n_way * n_query;
        use crate::generators::make_classification;
        make_classification(
            n_samples,
            n_features,
            n_way,
            1,
            n_features / 2,
            // Offset the seed so the query set differs from the support set.
            Some(episode_id as u64 + 1000),
        )
    }

    fn generate_class_centers(&self, n_classes: usize, n_features: usize) -> Result<Array2<f64>> {
        // Uniform centers in [-2, 2].
        let centers = Array2::from_shape_fn((n_classes, n_features), |_| {
            thread_rng().random::<f64>() * 4.0 - 2.0
        });
        Ok(centers)
    }

    fn generate_classification_from_centers(
        &self,
        n_samples: usize,
        centers: &Array2<f64>,
        cluster_std: f64,
        seed: u64,
    ) -> Result<Dataset> {
        use crate::generators::make_blobs;
        // Note: make_blobs draws its own centers; only the dimensions of
        // `centers` are used here.
        make_blobs(
            n_samples,
            centers.ncols(),
            centers.nrows(),
            cluster_std,
            Some(seed),
        )
    }
}

/// A collection of related tasks that share features.
#[derive(Debug)]
pub struct MultiTaskDataset {
    /// One dataset per task.
    pub tasks: Vec<Dataset>,
    /// Number of features shared across tasks.
    pub shared_features: usize,
    /// Correlation strength used when generating the targets.
    pub task_correlation: f64,
}

/// A collection of source domains plus one target domain.
#[derive(Debug)]
pub struct DomainAdaptationDataset {
    /// Named domains in generation order; the last entry is the target domain.
    pub domains: Vec<(String, Dataset)>,
    /// Number of source domains.
    pub n_source_domains: usize,
}

/// A single few-shot learning episode.
#[derive(Debug)]
pub struct FewShotEpisode {
    /// Labeled support samples (`n_way * k_shot` in total).
    pub support_set: Dataset,
    /// Query samples to classify (`n_way * n_query` in total).
    pub query_set: Dataset,
    /// Number of classes in the episode.
    pub n_way: usize,
    /// Number of support samples per class.
    pub k_shot: usize,
}

/// A collection of few-shot learning episodes.
#[derive(Debug)]
pub struct FewShotDataset {
    /// The generated episodes.
    pub episodes: Vec<FewShotEpisode>,
    /// Number of classes per episode.
    pub n_way: usize,
    /// Number of support samples per class.
    pub k_shot: usize,
    /// Number of query samples per class.
    pub n_query: usize,
}

/// A sequence of tasks with concept drift between them.
#[derive(Debug)]
pub struct ContinualLearningDataset {
    /// Tasks in their presentation order.
    pub tasks: Vec<Dataset>,
    /// Strength of the concept drift between consecutive tasks.
    pub concept_drift_strength: f64,
}

/// Convenience wrapper around [`AdvancedGenerator::make_adversarial_examples`].
#[allow(dead_code)]
pub fn make_adversarial_examples(
    base_dataset: &Dataset,
    config: AdversarialConfig,
) -> Result<Dataset> {
    let generator = AdvancedGenerator::new(config.random_state);
    generator.make_adversarial_examples(base_dataset, config)
}

/// Convenience wrapper around [`AdvancedGenerator::make_anomaly_dataset`].
#[allow(dead_code)]
pub fn make_anomaly_dataset(
    n_samples: usize,
    n_features: usize,
    config: AnomalyConfig,
) -> Result<Dataset> {
    let generator = AdvancedGenerator::new(config.random_state);
    generator.make_anomaly_dataset(n_samples, n_features, config)
}

/// Convenience wrapper around [`AdvancedGenerator::make_multitask_dataset`].
#[allow(dead_code)]
pub fn make_multitask_dataset(
    n_samples: usize,
    config: MultiTaskConfig,
) -> Result<MultiTaskDataset> {
    let generator = AdvancedGenerator::new(config.random_state);
    generator.make_multitask_dataset(n_samples, config)
}

/// Convenience wrapper around [`AdvancedGenerator::make_domain_adaptation_dataset`].
#[allow(dead_code)]
pub fn make_domain_adaptation_dataset(
    n_samples_per_domain: usize,
    n_features: usize,
    n_classes: usize,
    config: DomainAdaptationConfig,
) -> Result<DomainAdaptationDataset> {
    let generator = AdvancedGenerator::new(config.random_state);
    generator.make_domain_adaptation_dataset(n_samples_per_domain, n_features, n_classes, config)
}

/// Convenience wrapper around [`AdvancedGenerator::make_few_shot_dataset`]
/// with a fixed seed of 42.
#[allow(dead_code)]
pub fn make_few_shot_dataset(
    n_way: usize,
    k_shot: usize,
    n_query: usize,
    n_episodes: usize,
    n_features: usize,
) -> Result<FewShotDataset> {
    let generator = AdvancedGenerator::new(Some(42));
    generator.make_few_shot_dataset(n_way, k_shot, n_query, n_episodes, n_features)
}

/// Convenience wrapper around [`AdvancedGenerator::make_continual_learning_dataset`]
/// with a fixed seed of 42.
#[allow(dead_code)]
pub fn make_continual_learning_dataset(
    n_tasks: usize,
    n_samples_per_task: usize,
    n_features: usize,
    n_classes: usize,
    concept_drift_strength: f64,
) -> Result<ContinualLearningDataset> {
    let generator = AdvancedGenerator::new(Some(42));
    generator.make_continual_learning_dataset(
        n_tasks,
        n_samples_per_task,
        n_features,
        n_classes,
        concept_drift_strength,
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    #[test]
    fn test_adversarial_config() {
        let config = AdversarialConfig::default();
        assert_eq!(config.epsilon, 0.1);
        assert_eq!(config.attack_method, AttackMethod::FGSM);
        assert_eq!(config.iterations, 10);
    }

    #[test]
    fn test_anomaly_dataset_generation() {
        let config = AnomalyConfig {
            anomaly_fraction: 0.2,
            anomaly_type: AnomalyType::Point,
            severity: 2.0,
            ..Default::default()
        };

        let dataset = make_anomaly_dataset(100, 10, config).unwrap();

        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 10);
        assert!(dataset.target.is_some());

        let target = dataset.target.unwrap();
        let anomalies = target.iter().filter(|&&x| x == 1.0).count();
        assert!(anomalies > 0);
        assert!(anomalies < 100);
    }

    #[test]
    fn test_multitask_dataset_generation() {
        let config = MultiTaskConfig {
            n_tasks: 2,
            task_types: vec![TaskType::Classification(3), TaskType::Regression],
            shared_features: 5,
            task_specific_features: 3,
            ..Default::default()
        };

        let dataset = make_multitask_dataset(50, config).unwrap();

        assert_eq!(dataset.tasks.len(), 2);
        assert_eq!(dataset.shared_features, 5);

        for task in &dataset.tasks {
            assert_eq!(task.n_samples(), 50);
            assert!(task.target.is_some());
        }
    }

    #[test]
    fn test_adversarial_examples_generation() {
        let base_dataset = make_classification(100, 10, 3, 2, 8, Some(42)).unwrap();
        let config = AdversarialConfig {
            epsilon: 0.1,
            attack_method: AttackMethod::FGSM,
            ..Default::default()
        };

        let adversarial_dataset = make_adversarial_examples(&base_dataset, config).unwrap();

        assert_eq!(adversarial_dataset.n_samples(), base_dataset.n_samples());
        assert_eq!(adversarial_dataset.n_features(), base_dataset.n_features());

        // The perturbation should be non-zero somewhere.
        let diff = &adversarial_dataset.data - &base_dataset.data;
        let max_abs = diff.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
        assert!(
            max_abs > 1e-9,
            "Adversarial perturbation appears to be zero (max abs diff = {max_abs})"
        );
    }

    #[test]
    fn test_few_shot_dataset() {
        let dataset = make_few_shot_dataset(5, 3, 10, 2, 20).unwrap();

        assert_eq!(dataset.n_way, 5);
        assert_eq!(dataset.k_shot, 3);
        assert_eq!(dataset.n_query, 10);
        assert_eq!(dataset.episodes.len(), 2);

        for episode in &dataset.episodes {
            assert_eq!(episode.n_way, 5);
            assert_eq!(episode.k_shot, 3);
            assert_eq!(episode.support_set.n_samples(), 5 * 3);
            assert_eq!(episode.query_set.n_samples(), 5 * 10);
        }
    }

    #[test]
    fn test_continual_learning_dataset() {
        let dataset = make_continual_learning_dataset(3, 100, 10, 5, 0.5).unwrap();

        assert_eq!(dataset.tasks.len(), 3);
        assert_eq!(dataset.concept_drift_strength, 0.5);

        for task in &dataset.tasks {
            assert_eq!(task.n_samples(), 100);
            assert_eq!(task.n_features(), 10);
            assert!(task.target.is_some());
        }
    }
}