1use crate::error::{DatasetsError, Result};
8use crate::utils::Dataset;
9use scirs2_core::ndarray::{Array1, Array2, Axis};
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::Uniform;
12
/// Configuration for generating adversarial examples from a base dataset.
#[derive(Debug, Clone)]
pub struct AdversarialConfig {
    /// Maximum perturbation magnitude per feature (L-infinity budget).
    pub epsilon: f64,
    /// Attack algorithm used to craft the perturbations.
    pub attack_method: AttackMethod,
    /// Target class for targeted attacks; `None` keeps the original labels.
    pub target_class: Option<usize>,
    /// Number of optimization iterations (used by iterative attacks such as PGD).
    pub iterations: usize,
    /// Per-iteration step size (used by iterative attacks such as PGD).
    pub step_size: f64,
    /// Optional RNG seed for reproducibility.
    pub random_state: Option<u64>,
}
29
/// Supported adversarial attack algorithms.
#[derive(Debug, Clone, PartialEq)]
pub enum AttackMethod {
    /// Fast Gradient Sign Method (approximated here with random sign directions).
    FGSM,
    /// Projected Gradient Descent: iterative signed steps projected onto the epsilon ball.
    PGD,
    /// Carlini–Wagner attack.
    /// NOTE(review): not implemented; currently falls back to uniform random noise.
    CW,
    /// DeepFool attack.
    /// NOTE(review): not implemented; currently falls back to uniform random noise.
    DeepFool,
    /// Uniform random noise within the epsilon budget.
    RandomNoise,
}
44
45impl Default for AdversarialConfig {
46 fn default() -> Self {
47 Self {
48 epsilon: 0.1,
49 attack_method: AttackMethod::FGSM,
50 target_class: None,
51 iterations: 10,
52 step_size: 0.01,
53 random_state: None,
54 }
55 }
56}
57
/// Configuration for synthetic anomaly-detection datasets.
#[derive(Debug, Clone)]
pub struct AnomalyConfig {
    /// Fraction of samples that are anomalous, in `0.0..=1.0`.
    pub anomaly_fraction: f64,
    /// Kind of anomalies to inject.
    pub anomaly_type: AnomalyType,
    /// Anomaly strength, expressed in standard deviations from the normal mean.
    pub severity: f64,
    /// Whether to mix several anomaly types in one dataset.
    /// NOTE(review): not consulted by the visible generator code — confirm intended use.
    pub mixed_anomalies: bool,
    /// Spread of the normal clusters (forwarded to the blob generator).
    pub clustering_factor: f64,
    /// Optional RNG seed for reproducibility.
    pub random_state: Option<u64>,
}
74
/// Categories of injected anomalies.
#[derive(Debug, Clone, PartialEq)]
pub enum AnomalyType {
    /// Global outliers offset from the normal mean by `severity` standard deviations.
    Point,
    /// Samples whose per-row feature values are permuted, breaking feature context.
    Contextual,
    /// Groups of collectively anomalous samples.
    /// NOTE(review): currently generated identically to `Point`.
    Collective,
    /// Adversarially perturbed samples.
    /// NOTE(review): currently generated identically to `Point`.
    Adversarial,
    /// A mixture of the above.
    /// NOTE(review): currently generated identically to `Point`.
    Mixed,
}
89
90impl Default for AnomalyConfig {
91 fn default() -> Self {
92 Self {
93 anomaly_fraction: 0.1,
94 anomaly_type: AnomalyType::Point,
95 severity: 2.0,
96 mixed_anomalies: false,
97 clustering_factor: 1.0,
98 random_state: None,
99 }
100 }
101}
102
/// Configuration for multi-task learning datasets.
#[derive(Debug, Clone)]
pub struct MultiTaskConfig {
    /// Number of tasks.
    /// NOTE(review): generation iterates `task_types`, so that vector's length
    /// determines how many tasks are actually produced — keep the two in sync.
    pub n_tasks: usize,
    /// Per-task prediction type, one entry per task.
    pub task_types: Vec<TaskType>,
    /// Number of features shared across all tasks.
    pub shared_features: usize,
    /// Number of additional features private to each task.
    pub task_specific_features: usize,
    /// Strength of the feature-target relationship applied to every task.
    pub task_correlation: f64,
    /// Per-task noise levels; missing entries default to 0.1 during generation.
    pub task_noise: Vec<f64>,
    /// Optional RNG seed for reproducibility.
    pub random_state: Option<u64>,
}
121
/// Prediction type of a single task in a multi-task dataset.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskType {
    /// Classification with the given number of classes.
    Classification(usize),
    /// Continuous regression.
    Regression,
    /// Ordinal regression with the given number of ordered levels.
    Ordinal(usize),
}
132
133impl Default for MultiTaskConfig {
134 fn default() -> Self {
135 Self {
136 n_tasks: 3,
137 task_types: vec![
138 TaskType::Classification(3),
139 TaskType::Regression,
140 TaskType::Classification(5),
141 ],
142 shared_features: 10,
143 task_specific_features: 5,
144 task_correlation: 0.5,
145 task_noise: vec![0.1, 0.1, 0.1],
146 random_state: None,
147 }
148 }
149}
150
/// Configuration for domain-adaptation datasets.
#[derive(Debug, Clone)]
pub struct DomainAdaptationConfig {
    /// Number of source domains; a single target domain is always added on top.
    pub n_source_domains: usize,
    /// Explicit shifts for the extra source domains; when fewer shifts than
    /// domains are supplied, a default shift is used for the remainder.
    pub domain_shifts: Vec<DomainShift>,
    /// Whether label distributions shift between domains.
    /// NOTE(review): not consulted by the visible generator code — confirm intended use.
    pub label_shift: bool,
    /// Whether feature distributions shift between domains.
    /// NOTE(review): not consulted by the visible generator code — confirm intended use.
    pub feature_shift: bool,
    /// Whether the feature-label relationship drifts between domains.
    /// NOTE(review): not consulted by the visible generator code — confirm intended use.
    pub concept_drift: bool,
    /// Optional RNG seed for reproducibility.
    pub random_state: Option<u64>,
}
167
/// Describes how one domain's feature distribution differs from the source domain.
#[derive(Debug, Clone)]
pub struct DomainShift {
    /// Per-feature additive offset applied to the domain mean.
    pub mean_shift: Array1<f64>,
    /// Optional change to the feature covariance.
    /// NOTE(review): not applied by `apply_domain_shift` in the visible code — confirm.
    pub covariance_shift: Option<Array2<f64>>,
    /// Scale factor for the shift magnitude.
    pub shift_strength: f64,
}
178
179impl Default for DomainAdaptationConfig {
180 fn default() -> Self {
181 Self {
182 n_source_domains: 2,
183 domain_shifts: vec![],
184 label_shift: true,
185 feature_shift: true,
186 concept_drift: false,
187 random_state: None,
188 }
189 }
190}
191
/// Generator for specialized ML evaluation datasets: adversarial examples,
/// anomaly detection, multi-task, domain adaptation, few-shot episodes and
/// continual-learning task sequences.
pub struct AdvancedGenerator {
    /// Optional seed forwarded to the delegated dataset generators.
    random_state: Option<u64>,
}
196
197impl AdvancedGenerator {
198 pub fn new(_random_state: Option<u64>) -> Self {
200 Self {
201 random_state: _random_state,
202 }
203 }
204
205 pub fn make_adversarial_examples(
207 &self,
208 base_dataset: &Dataset,
209 config: AdversarialConfig,
210 ) -> Result<Dataset> {
211 let n_samples = base_dataset.n_samples();
212 let _n_features = base_dataset.n_features();
213
214 println!(
215 "Generating adversarial examples using {:?}",
216 config.attack_method
217 );
218
219 let perturbations = self.generate_perturbations(&base_dataset.data, &config)?;
221
222 let adversarial_data = &base_dataset.data + &perturbations;
224
225 let clipped_data = adversarial_data.mapv(|x| x.clamp(-5.0, 5.0));
227
228 let adversarial_target = if let Some(target) = &base_dataset.target {
230 match config.target_class {
231 Some(target_class) => {
232 Some(Array1::from_elem(n_samples, target_class as f64))
234 }
235 None => {
236 Some(target.clone())
238 }
239 }
240 } else {
241 None
242 };
243
244 let mut metadata = base_dataset.metadata.clone();
245 let _old_description = metadata.get("description").cloned().unwrap_or_default();
246 let oldname = metadata.get("name").cloned().unwrap_or_default();
247
248 metadata.insert(
249 "description".to_string(),
250 format!(
251 "Adversarial examples generated using {:?}",
252 config.attack_method
253 ),
254 );
255 metadata.insert("name".to_string(), format!("{oldname} (Adversarial)"));
256
257 Ok(Dataset {
258 data: clipped_data,
259 target: adversarial_target,
260 targetnames: base_dataset.targetnames.clone(),
261 featurenames: base_dataset.featurenames.clone(),
262 feature_descriptions: base_dataset.feature_descriptions.clone(),
263 description: base_dataset.description.clone(),
264 metadata,
265 })
266 }
267
268 pub fn make_anomaly_dataset(
270 &self,
271 n_samples: usize,
272 n_features: usize,
273 config: AnomalyConfig,
274 ) -> Result<Dataset> {
275 let n_anomalies = (n_samples as f64 * config.anomaly_fraction) as usize;
276 let n_normal = n_samples - n_anomalies;
277
278 println!("Generating anomaly dataset: {n_normal} normal, {n_anomalies} anomalous");
279
280 let normal_data =
282 self.generate_normal_data(n_normal, n_features, config.clustering_factor)?;
283
284 let anomalous_data =
286 self.generate_anomalous_data(n_anomalies, n_features, &normal_data, &config)?;
287
288 let mut combined_data = Array2::zeros((n_samples, n_features));
290 combined_data
291 .slice_mut(scirs2_core::ndarray::s![..n_normal, ..])
292 .assign(&normal_data);
293 combined_data
294 .slice_mut(scirs2_core::ndarray::s![n_normal.., ..])
295 .assign(&anomalous_data);
296
297 let mut target = Array1::zeros(n_samples);
299 target
300 .slice_mut(scirs2_core::ndarray::s![n_normal..])
301 .fill(1.0);
302
303 let shuffled_indices = self.generate_shuffle_indices(n_samples)?;
305 let shuffled_data = self.shuffle_by_indices(&combined_data, &shuffled_indices);
306 let shuffled_target = self.shuffle_array_by_indices(&target, &shuffled_indices);
307
308 let metadata = crate::registry::DatasetMetadata {
309 name: "Anomaly Detection Dataset".to_string(),
310 description: format!(
311 "Synthetic anomaly detection dataset with {:.1}% anomalies",
312 config.anomaly_fraction * 100.0
313 ),
314 n_samples,
315 n_features,
316 task_type: "anomaly_detection".to_string(),
317 targetnames: Some(vec!["normal".to_string(), "anomaly".to_string()]),
318 ..Default::default()
319 };
320
321 Ok(Dataset::from_metadata(
322 shuffled_data,
323 Some(shuffled_target),
324 metadata,
325 ))
326 }
327
328 pub fn make_multitask_dataset(
330 &self,
331 n_samples: usize,
332 config: MultiTaskConfig,
333 ) -> Result<MultiTaskDataset> {
334 let total_features =
335 config.shared_features + config.task_specific_features * config.n_tasks;
336
337 println!(
338 "Generating multi-task dataset: {} tasks, {} samples, {} features",
339 config.n_tasks, n_samples, total_features
340 );
341
342 let shared_data = self.generate_shared_features(n_samples, config.shared_features)?;
344
345 let mut task_datasets = Vec::new();
347
348 for (task_id, task_type) in config.task_types.iter().enumerate() {
349 let task_specific_data = self.generate_task_specific_features(
350 n_samples,
351 config.task_specific_features,
352 task_id,
353 )?;
354
355 let task_data = self.combine_features(&shared_data, &task_specific_data);
357
358 let task_target = self.generate_task_target(
360 &task_data,
361 task_type,
362 config.task_correlation,
363 config.task_noise.get(task_id).unwrap_or(&0.1),
364 )?;
365
366 let task_metadata = crate::registry::DatasetMetadata {
367 name: format!("Task {task_id}"),
368 description: format!("Multi-task learning task {task_id} ({task_type:?})"),
369 n_samples,
370 n_features: task_data.ncols(),
371 task_type: match task_type {
372 TaskType::Classification(_) => "classification".to_string(),
373 TaskType::Regression => "regression".to_string(),
374 TaskType::Ordinal(_) => "ordinal_regression".to_string(),
375 },
376 ..Default::default()
377 };
378
379 task_datasets.push(Dataset::from_metadata(
380 task_data,
381 Some(task_target),
382 task_metadata,
383 ));
384 }
385
386 Ok(MultiTaskDataset {
387 tasks: task_datasets,
388 shared_features: config.shared_features,
389 task_correlation: config.task_correlation,
390 })
391 }
392
393 pub fn make_domain_adaptation_dataset(
395 &self,
396 n_samples_per_domain: usize,
397 n_features: usize,
398 n_classes: usize,
399 config: DomainAdaptationConfig,
400 ) -> Result<DomainAdaptationDataset> {
401 let total_domains = config.n_source_domains + 1; println!(
404 "Generating _domain adaptation dataset: {total_domains} domains, {n_samples_per_domain} samples each"
405 );
406
407 let mut domain_datasets = Vec::new();
408
409 let source_dataset =
411 self.generate_base_domain_dataset(n_samples_per_domain, n_features, n_classes)?;
412
413 domain_datasets.push(("source".to_string(), source_dataset.clone()));
414
415 for domain_id in 1..config.n_source_domains {
417 let shift = if domain_id - 1 < config.domain_shifts.len() {
418 &config.domain_shifts[domain_id - 1]
419 } else {
420 &DomainShift {
422 mean_shift: Array1::from_elem(n_features, 0.5),
423 covariance_shift: None,
424 shift_strength: 1.0,
425 }
426 };
427
428 let shifted_dataset = self.apply_domain_shift(&source_dataset, shift)?;
429 domain_datasets.push((format!("source_{domain_id}"), shifted_dataset));
430 }
431
432 let target_shift = DomainShift {
434 mean_shift: Array1::from_elem(n_features, 1.0),
435 covariance_shift: None,
436 shift_strength: 1.5,
437 };
438
439 let target_dataset = self.apply_domain_shift(&source_dataset, &target_shift)?;
440 domain_datasets.push(("target".to_string(), target_dataset));
441
442 Ok(DomainAdaptationDataset {
443 domains: domain_datasets,
444 n_source_domains: config.n_source_domains,
445 })
446 }
447
448 pub fn make_few_shot_dataset(
450 &self,
451 n_way: usize,
452 k_shot: usize,
453 n_query: usize,
454 n_episodes: usize,
455 n_features: usize,
456 ) -> Result<FewShotDataset> {
457 println!(
458 "Generating few-_shot dataset: {n_way}-_way {k_shot}-_shot, {n_episodes} _episodes"
459 );
460
461 let mut episodes = Vec::new();
462
463 for episode_id in 0..n_episodes {
464 let support_set = self.generate_support_set(n_way, k_shot, n_features, episode_id)?;
465 let query_set =
466 self.generate_query_set(n_way, n_query, n_features, &support_set, episode_id)?;
467
468 episodes.push(FewShotEpisode {
469 support_set,
470 query_set,
471 n_way,
472 k_shot,
473 });
474 }
475
476 Ok(FewShotDataset {
477 episodes,
478 n_way,
479 k_shot,
480 n_query,
481 })
482 }
483
484 pub fn make_continual_learning_dataset(
486 &self,
487 n_tasks: usize,
488 n_samples_per_task: usize,
489 n_features: usize,
490 n_classes: usize,
491 concept_drift_strength: f64,
492 ) -> Result<ContinualLearningDataset> {
493 println!("Generating continual learning dataset: {n_tasks} _tasks with concept drift");
494
495 let mut task_datasets = Vec::new();
496 let mut base_centers = self.generate_class_centers(n_classes, n_features)?;
497
498 for task_id in 0..n_tasks {
499 if task_id > 0 {
501 let drift = Array2::from_shape_fn((n_classes, n_features), |_| {
502 thread_rng().random::<f64>() * concept_drift_strength
503 });
504 base_centers = base_centers + drift;
505 }
506
507 let task_dataset = self.generate_classification_from_centers(
508 n_samples_per_task,
509 &base_centers,
510 1.0, task_id as u64,
512 )?;
513
514 let mut metadata = task_dataset.metadata.clone();
515 metadata.insert(
516 "name".to_string(),
517 format!("Continual Learning Task {task_id}"),
518 );
519 metadata.insert(
520 "description".to_string(),
521 format!("Task {task_id} with concept drift _strength {concept_drift_strength:.2}"),
522 );
523
524 task_datasets.push(Dataset {
525 data: task_dataset.data,
526 target: task_dataset.target,
527 targetnames: task_dataset.targetnames,
528 featurenames: task_dataset.featurenames,
529 feature_descriptions: task_dataset.feature_descriptions,
530 description: task_dataset.description,
531 metadata,
532 });
533 }
534
535 Ok(ContinualLearningDataset {
536 tasks: task_datasets,
537 concept_drift_strength,
538 })
539 }
540
541 fn generate_perturbations(
544 &self,
545 data: &Array2<f64>,
546 config: &AdversarialConfig,
547 ) -> Result<Array2<f64>> {
548 let (n_samples, n_features) = data.dim();
549
550 match config.attack_method {
551 AttackMethod::FGSM => {
552 let mut perturbations = Array2::zeros((n_samples, n_features));
554 for i in 0..n_samples {
555 for j in 0..n_features {
556 let sign = if thread_rng().random::<f64>() > 0.5 {
557 1.0
558 } else {
559 -1.0
560 };
561 perturbations[[i, j]] = config.epsilon * sign;
562 }
563 }
564 Ok(perturbations)
565 }
566 AttackMethod::PGD => {
567 let mut perturbations: Array2<f64> = Array2::zeros((n_samples, n_features));
569 for _iter in 0..config.iterations {
570 for i in 0..n_samples {
571 for j in 0..n_features {
572 let gradient = thread_rng().random::<f64>() * 2.0 - 1.0; perturbations[[i, j]] += config.step_size * gradient.signum();
574 perturbations[[i, j]] =
576 perturbations[[i, j]].clamp(-config.epsilon, config.epsilon);
577 }
578 }
579 }
580 Ok(perturbations)
581 }
582 AttackMethod::RandomNoise => {
583 let perturbations = Array2::from_shape_fn((n_samples, n_features), |_| {
585 (thread_rng().random::<f64>() * 2.0 - 1.0) * config.epsilon
586 });
587 Ok(perturbations)
588 }
589 _ => {
590 let mut perturbations = Array2::zeros(data.dim());
592 for i in 0..data.nrows() {
593 for j in 0..data.ncols() {
594 let noise = thread_rng().random::<f64>() * 2.0 - 1.0;
595 perturbations[[i, j]] = config.epsilon * noise;
596 }
597 }
598 Ok(perturbations)
599 }
600 }
601 }
602
603 fn generate_normal_data(
604 &self,
605 n_samples: usize,
606 n_features: usize,
607 clustering_factor: f64,
608 ) -> Result<Array2<f64>> {
609 use crate::generators::make_blobs;
611 let n_clusters = ((n_features as f64).sqrt() as usize).max(2);
612 let dataset = make_blobs(
613 n_samples,
614 n_features,
615 n_clusters,
616 clustering_factor,
617 self.random_state,
618 )?;
619 Ok(dataset.data)
620 }
621
622 fn generate_anomalous_data(
623 &self,
624 n_anomalies: usize,
625 n_features: usize,
626 normal_data: &Array2<f64>,
627 config: &AnomalyConfig,
628 ) -> Result<Array2<f64>> {
629 use scirs2_core::random::{Rng, RngExt};
630 let mut rng = thread_rng();
631
632 match config.anomaly_type {
633 AnomalyType::Point => {
634 let normal_mean = normal_data.mean_axis(Axis(0)).ok_or_else(|| {
636 DatasetsError::ComputationError(
637 "Failed to compute mean for normal data".to_string(),
638 )
639 })?;
640 let normal_std = normal_data.std_axis(Axis(0), 0.0);
641
642 let mut anomalies = Array2::zeros((n_anomalies, n_features));
643 for i in 0..n_anomalies {
644 for j in 0..n_features {
645 let direction = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
646 anomalies[[i, j]] =
647 normal_mean[j] + direction * config.severity * normal_std[j];
648 }
649 }
650 Ok(anomalies)
651 }
652 AnomalyType::Contextual => {
653 let mut anomalies: Array2<f64> = Array2::zeros((n_anomalies, n_features));
655 for i in 0..n_anomalies {
656 let base_idx =
658 rng.sample(Uniform::new(0, normal_data.nrows()).expect("Operation failed"));
659 let mut anomaly = normal_data.row(base_idx).to_owned();
660
661 let n_permute = (n_features as f64 * 0.3) as usize;
663 for _ in 0..n_permute {
664 let j = rng.sample(Uniform::new(0, n_features).expect("Operation failed"));
665 let k = rng.sample(Uniform::new(0, n_features).expect("Operation failed"));
666 let temp = anomaly[j];
667 anomaly[j] = anomaly[k];
668 anomaly[k] = temp;
669 }
670
671 anomalies.row_mut(i).assign(&anomaly);
672 }
673 Ok(anomalies)
674 }
675 _ => {
676 let normal_mean = normal_data.mean_axis(Axis(0)).ok_or_else(|| {
678 DatasetsError::ComputationError(
679 "Failed to compute mean for normal data".to_string(),
680 )
681 })?;
682 let normal_std = normal_data.std_axis(Axis(0), 0.0);
683
684 let mut anomalies = Array2::zeros((n_anomalies, n_features));
685 for i in 0..n_anomalies {
686 for j in 0..n_features {
687 let direction = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
688 anomalies[[i, j]] =
689 normal_mean[j] + direction * config.severity * normal_std[j];
690 }
691 }
692 Ok(anomalies)
693 }
694 }
695 }
696
697 fn generate_shuffle_indices(&self, n_samples: usize) -> Result<Vec<usize>> {
698 use scirs2_core::random::{Rng, RngExt};
699 let mut rng = thread_rng();
700 let mut indices: Vec<usize> = (0..n_samples).collect();
701
702 for i in (1..n_samples).rev() {
704 let j = rng.sample(Uniform::new(0, i).expect("Operation failed"));
705 indices.swap(i, j);
706 }
707
708 Ok(indices)
709 }
710
711 fn shuffle_by_indices(&self, data: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
712 let mut shuffled = Array2::zeros(data.dim());
713 for (new_idx, &old_idx) in indices.iter().enumerate() {
714 shuffled.row_mut(new_idx).assign(&data.row(old_idx));
715 }
716 shuffled
717 }
718
719 fn shuffle_array_by_indices(&self, array: &Array1<f64>, indices: &[usize]) -> Array1<f64> {
720 let mut shuffled = Array1::zeros(array.len());
721 for (new_idx, &old_idx) in indices.iter().enumerate() {
722 shuffled[new_idx] = array[old_idx];
723 }
724 shuffled
725 }
726
727 fn generate_shared_features(&self, n_samples: usize, n_features: usize) -> Result<Array2<f64>> {
728 let data = Array2::from_shape_fn((n_samples, n_features), |_| {
730 thread_rng().random::<f64>() * 2.0 - 1.0 });
732 Ok(data)
733 }
734
735 fn generate_task_specific_features(
736 &self,
737 n_samples: usize,
738 n_features: usize,
739 task_id: usize,
740 ) -> Result<Array2<f64>> {
741 let task_bias = task_id as f64 * 0.1;
743 let data = Array2::from_shape_fn((n_samples, n_features), |_| {
744 thread_rng().random::<f64>() * 2.0 - 1.0 + task_bias
745 });
746 Ok(data)
747 }
748
749 fn combine_features(&self, shared: &Array2<f64>, task_specific: &Array2<f64>) -> Array2<f64> {
750 let n_samples = shared.nrows();
751 let total_features = shared.ncols() + task_specific.ncols();
752 let mut combined = Array2::zeros((n_samples, total_features));
753
754 combined
755 .slice_mut(scirs2_core::ndarray::s![.., ..shared.ncols()])
756 .assign(shared);
757 combined
758 .slice_mut(scirs2_core::ndarray::s![.., shared.ncols()..])
759 .assign(task_specific);
760
761 combined
762 }
763
764 fn generate_task_target(
765 &self,
766 data: &Array2<f64>,
767 task_type: &TaskType,
768 correlation: f64,
769 noise: &f64,
770 ) -> Result<Array1<f64>> {
771 let n_samples = data.nrows();
772
773 match task_type {
774 TaskType::Classification(n_classes) => {
775 let target = Array1::from_shape_fn(n_samples, |i| {
777 let feature_sum = data.row(i).sum();
778 let class = ((feature_sum * correlation).abs() as usize) % n_classes;
779 class as f64
780 });
781 Ok(target)
782 }
783 TaskType::Regression => {
784 let target = Array1::from_shape_fn(n_samples, |i| {
786 let weighted_sum = data
787 .row(i)
788 .iter()
789 .enumerate()
790 .map(|(j, &x)| x * (j as f64 + 1.0) * correlation)
791 .sum::<f64>();
792 weighted_sum + thread_rng().random::<f64>() * noise
793 });
794 Ok(target)
795 }
796 TaskType::Ordinal(n_levels) => {
797 let target = Array1::from_shape_fn(n_samples, |i| {
799 let feature_sum = data.row(i).sum();
800 let level = ((feature_sum * correlation).abs() as usize) % n_levels;
801 level as f64
802 });
803 Ok(target)
804 }
805 }
806 }
807
808 fn generate_base_domain_dataset(
809 &self,
810 n_samples: usize,
811 n_features: usize,
812 n_classes: usize,
813 ) -> Result<Dataset> {
814 use crate::generators::make_classification;
815 make_classification(
816 n_samples,
817 n_features,
818 n_classes,
819 2,
820 n_features / 2,
821 self.random_state,
822 )
823 }
824
825 fn apply_domain_shift(&self, base_dataset: &Dataset, shift: &DomainShift) -> Result<Dataset> {
826 let shifted_data = &base_dataset.data + &shift.mean_shift;
827
828 let mut metadata = base_dataset.metadata.clone();
829 let old_description = metadata.get("description").cloned().unwrap_or_default();
830 metadata.insert(
831 "description".to_string(),
832 format!("{old_description} (Domain Shifted)"),
833 );
834
835 Ok(Dataset {
836 data: shifted_data,
837 target: base_dataset.target.clone(),
838 targetnames: base_dataset.targetnames.clone(),
839 featurenames: base_dataset.featurenames.clone(),
840 feature_descriptions: base_dataset.feature_descriptions.clone(),
841 description: base_dataset.description.clone(),
842 metadata,
843 })
844 }
845
846 fn generate_support_set(
847 &self,
848 n_way: usize,
849 k_shot: usize,
850 n_features: usize,
851 episode_id: usize,
852 ) -> Result<Dataset> {
853 let n_samples = n_way * k_shot;
854 use crate::generators::make_classification;
855 make_classification(
856 n_samples,
857 n_features,
858 n_way,
859 1,
860 n_features / 2,
861 Some(episode_id as u64),
862 )
863 }
864
865 fn generate_query_set(
866 &self,
867 n_way: usize,
868 n_query: usize,
869 n_features: usize,
870 _set: &Dataset,
871 episode_id: usize,
872 ) -> Result<Dataset> {
873 let n_samples = n_way * n_query;
874 use crate::generators::make_classification;
875 make_classification(
876 n_samples,
877 n_features,
878 n_way,
879 1,
880 n_features / 2,
881 Some(episode_id as u64 + 1000),
882 )
883 }
884
885 fn generate_class_centers(&self, n_classes: usize, n_features: usize) -> Result<Array2<f64>> {
886 let centers = Array2::from_shape_fn((n_classes, n_features), |_| {
887 thread_rng().random::<f64>() * 4.0 - 2.0
888 });
889 Ok(centers)
890 }
891
892 fn generate_classification_from_centers(
893 &self,
894 n_samples: usize,
895 centers: &Array2<f64>,
896 cluster_std: f64,
897 seed: u64,
898 ) -> Result<Dataset> {
899 use crate::generators::make_blobs;
900 make_blobs(
901 n_samples,
902 centers.ncols(),
903 centers.nrows(),
904 cluster_std,
905 Some(seed),
906 )
907 }
908}
909
/// Bundle of related tasks produced by [`AdvancedGenerator::make_multitask_dataset`].
#[derive(Debug)]
pub struct MultiTaskDataset {
    /// One dataset per task, in the order of the configured task types.
    pub tasks: Vec<Dataset>,
    /// Number of leading features shared by all tasks.
    pub shared_features: usize,
    /// Feature-target correlation used during generation.
    pub task_correlation: f64,
}
920
/// Named domains produced by [`AdvancedGenerator::make_domain_adaptation_dataset`].
#[derive(Debug)]
pub struct DomainAdaptationDataset {
    /// `(domain_name, dataset)` pairs: `"source"`, then `"source_<i>"` for the
    /// shifted sources, with `"target"` last.
    pub domains: Vec<(String, Dataset)>,
    /// Number of source domains (the single target domain is not counted).
    pub n_source_domains: usize,
}
929
/// One N-way K-shot episode: a labelled support set plus a query set.
#[derive(Debug)]
pub struct FewShotEpisode {
    /// Labelled examples used for adaptation (`n_way * k_shot` samples).
    pub support_set: Dataset,
    /// Examples used for evaluation within the episode.
    pub query_set: Dataset,
    /// Number of classes in the episode.
    pub n_way: usize,
    /// Number of support samples per class.
    pub k_shot: usize,
}
942
/// Collection of few-shot episodes sharing the same N-way/K-shot configuration.
#[derive(Debug)]
pub struct FewShotDataset {
    /// Independently generated episodes.
    pub episodes: Vec<FewShotEpisode>,
    /// Number of classes per episode.
    pub n_way: usize,
    /// Support samples per class per episode.
    pub k_shot: usize,
    /// Query samples per class per episode.
    pub n_query: usize,
}
955
/// Ordered sequence of tasks with drifting class distributions.
#[derive(Debug)]
pub struct ContinualLearningDataset {
    /// Tasks in presentation order; later tasks have drifted class centers.
    pub tasks: Vec<Dataset>,
    /// Magnitude of the per-task concept drift applied during generation.
    pub concept_drift_strength: f64,
}
964
965#[allow(dead_code)]
969pub fn make_adversarial_examples(
970 base_dataset: &Dataset,
971 config: AdversarialConfig,
972) -> Result<Dataset> {
973 let generator = AdvancedGenerator::new(config.random_state);
974 generator.make_adversarial_examples(base_dataset, config)
975}
976
977#[allow(dead_code)]
979pub fn make_anomaly_dataset(
980 n_samples: usize,
981 n_features: usize,
982 config: AnomalyConfig,
983) -> Result<Dataset> {
984 let generator = AdvancedGenerator::new(config.random_state);
985 generator.make_anomaly_dataset(n_samples, n_features, config)
986}
987
988#[allow(dead_code)]
990pub fn make_multitask_dataset(
991 n_samples: usize,
992 config: MultiTaskConfig,
993) -> Result<MultiTaskDataset> {
994 let generator = AdvancedGenerator::new(config.random_state);
995 generator.make_multitask_dataset(n_samples, config)
996}
997
998#[allow(dead_code)]
1000pub fn make_domain_adaptation_dataset(
1001 n_samples_per_domain: usize,
1002 n_features: usize,
1003 n_classes: usize,
1004 config: DomainAdaptationConfig,
1005) -> Result<DomainAdaptationDataset> {
1006 let generator = AdvancedGenerator::new(config.random_state);
1007 generator.make_domain_adaptation_dataset(n_samples_per_domain, n_features, n_classes, config)
1008}
1009
1010#[allow(dead_code)]
1012pub fn make_few_shot_dataset(
1013 n_way: usize,
1014 k_shot: usize,
1015 n_query: usize,
1016 n_episodes: usize,
1017 n_features: usize,
1018) -> Result<FewShotDataset> {
1019 let generator = AdvancedGenerator::new(Some(42));
1020 generator.make_few_shot_dataset(n_way, k_shot, n_query, n_episodes, n_features)
1021}
1022
1023#[allow(dead_code)]
1025pub fn make_continual_learning_dataset(
1026 n_tasks: usize,
1027 n_samples_per_task: usize,
1028 n_features: usize,
1029 n_classes: usize,
1030 concept_drift_strength: f64,
1031) -> Result<ContinualLearningDataset> {
1032 let generator = AdvancedGenerator::new(Some(42));
1033 generator.make_continual_learning_dataset(
1034 n_tasks,
1035 n_samples_per_task,
1036 n_features,
1037 n_classes,
1038 concept_drift_strength,
1039 )
1040}
1041
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    // The default config should match the documented default values.
    #[test]
    fn test_adversarial_config() {
        let config = AdversarialConfig::default();
        assert_eq!(config.epsilon, 0.1);
        assert_eq!(config.attack_method, AttackMethod::FGSM);
        assert_eq!(config.iterations, 10);
    }

    // End-to-end anomaly generation: correct shape, labels present, and both
    // classes represented after shuffling.
    #[test]
    fn test_anomaly_dataset_generation() {
        let config = AnomalyConfig {
            anomaly_fraction: 0.2,
            anomaly_type: AnomalyType::Point,
            severity: 2.0,
            ..Default::default()
        };

        let dataset = make_anomaly_dataset(100, 10, config).expect("Operation failed");

        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 10);
        assert!(dataset.target.is_some());

        // Anomalies are labelled 1.0; with a 0.2 fraction there should be
        // some, but not all, anomalous rows.
        let target = dataset.target.expect("Test: target required");
        let anomalies = target.iter().filter(|&&x| x == 1.0).count();
        assert!(anomalies > 0);
        assert!(anomalies < 100);
    }

    // Multi-task generation produces one labelled dataset per configured
    // task type, each with the requested sample count.
    #[test]
    fn test_multitask_dataset_generation() {
        let config = MultiTaskConfig {
            n_tasks: 2,
            task_types: vec![TaskType::Classification(3), TaskType::Regression],
            shared_features: 5,
            task_specific_features: 3,
            ..Default::default()
        };

        let dataset = make_multitask_dataset(50, config).expect("Operation failed");

        assert_eq!(dataset.tasks.len(), 2);
        assert_eq!(dataset.shared_features, 5);

        for task in &dataset.tasks {
            assert_eq!(task.n_samples(), 50);
            assert!(task.target.is_some());
        }
    }

    // Adversarial generation preserves dataset shape and actually perturbs
    // the data (the max absolute difference must be non-zero).
    #[test]
    fn test_adversarial_examples_generation() {
        let base_dataset =
            make_classification(100, 10, 3, 2, 8, Some(42)).expect("Operation failed");
        let config = AdversarialConfig {
            epsilon: 0.1,
            attack_method: AttackMethod::FGSM,
            ..Default::default()
        };

        let adversarial_dataset =
            make_adversarial_examples(&base_dataset, config).expect("Operation failed");

        assert_eq!(adversarial_dataset.n_samples(), base_dataset.n_samples());
        assert_eq!(adversarial_dataset.n_features(), base_dataset.n_features());

        let diff = &adversarial_dataset.data - &base_dataset.data;
        let mut max_abs = 0.0_f64;
        for v in diff.iter() {
            let a = v.abs();
            if a > max_abs {
                max_abs = a;
            }
        }
        assert!(
            max_abs > 1e-9,
            "Adversarial perturbation appears to be zero (max abs diff = {max_abs})"
        );
    }

    // Few-shot generation: episode count and per-episode support/query sizes
    // must follow the n_way/k_shot/n_query parameters.
    #[test]
    fn test_few_shot_dataset() {
        let dataset = make_few_shot_dataset(5, 3, 10, 2, 20).expect("Operation failed");

        assert_eq!(dataset.n_way, 5);
        assert_eq!(dataset.k_shot, 3);
        assert_eq!(dataset.n_query, 10);
        assert_eq!(dataset.episodes.len(), 2);

        for episode in &dataset.episodes {
            assert_eq!(episode.n_way, 5);
            assert_eq!(episode.k_shot, 3);
            // Support: n_way * k_shot samples; query: n_way * n_query samples.
            assert_eq!(episode.support_set.n_samples(), 5 * 3);
            assert_eq!(episode.query_set.n_samples(), 5 * 10);
        }
    }

    // Continual-learning generation: one labelled task per requested task,
    // each with the requested shape.
    #[test]
    fn test_continual_learning_dataset() {
        let dataset =
            make_continual_learning_dataset(3, 100, 10, 5, 0.5).expect("Operation failed");

        assert_eq!(dataset.tasks.len(), 3);
        assert_eq!(dataset.concept_drift_strength, 0.5);

        for task in &dataset.tasks {
            assert_eq!(task.n_samples(), 100);
            assert_eq!(task.n_features(), 10);
            assert!(task.target.is_some());
        }
    }
}