1use crate::{Dataset, Result};
7use scirs2_core::random::rngs::StdRng;
8use scirs2_core::random::{Rng, RngExt, SeedableRng};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex};
12use tenflowers_core::{Tensor, TensorError};
13
/// Process-wide seed manager handed out by [`SeedManager::global`] (and used by
/// [`DeterministicOps`]). Lazily initialized on first access.
static GLOBAL_SEED_MANAGER: std::sync::OnceLock<Arc<Mutex<SeedManager>>> =
    std::sync::OnceLock::new();

/// Derives seeds for named components and for a sequence of operations from a single
/// master seed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SeedManager {
    /// Root seed from which every other seed is derived.
    master_seed: u64,
    /// Cache of the seeds already derived, keyed by component name.
    component_seeds: HashMap<String, u64>,
    /// Number of operation seeds handed out so far; mixed into each new operation seed.
    operation_counter: u64,
}
28
29impl SeedManager {
30 pub fn new(master_seed: u64) -> Self {
32 Self {
33 master_seed,
34 component_seeds: HashMap::new(),
35 operation_counter: 0,
36 }
37 }
38
39 pub fn master_seed(&self) -> u64 {
41 self.master_seed
42 }
43
44 pub fn get_component_seed(&mut self, component: &str) -> u64 {
46 if let Some(&seed) = self.component_seeds.get(component) {
47 seed
48 } else {
49 let mut hasher = std::collections::hash_map::DefaultHasher::new();
51 use std::hash::{Hash, Hasher};
52 self.master_seed.hash(&mut hasher);
53 component.hash(&mut hasher);
54 let seed = hasher.finish();
55 self.component_seeds.insert(component.to_string(), seed);
56 seed
57 }
58 }
59
60 pub fn next_operation_seed(&mut self) -> u64 {
62 self.operation_counter += 1;
63 let mut hasher = std::collections::hash_map::DefaultHasher::new();
64 use std::hash::{Hash, Hasher};
65 self.master_seed.hash(&mut hasher);
66 self.operation_counter.hash(&mut hasher);
67 hasher.finish()
68 }
69
70 pub fn create_rng(&mut self, component: &str) -> StdRng {
72 let seed = self.get_component_seed(component);
73 StdRng::seed_from_u64(seed)
74 }
75
76 pub fn set_global(manager: SeedManager) {
78 let _ = GLOBAL_SEED_MANAGER.set(Arc::new(Mutex::new(manager)));
79 }
80
81 pub fn global() -> Arc<Mutex<SeedManager>> {
83 GLOBAL_SEED_MANAGER
84 .get_or_init(|| Arc::new(Mutex::new(SeedManager::new(42))))
85 .clone()
86 }
87
88 pub fn reset(&mut self) {
90 self.component_seeds.clear();
91 self.operation_counter = 0;
92 }
93}
94
/// Snapshot of the environment an experiment ran in, recorded alongside its seeds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvironmentInfo {
    /// Rust toolchain version; [`EnvironmentInfo::capture`] currently records `"unknown"`.
    pub rust_version: String,
    /// Operating system name (`std::env::consts::OS` when captured).
    pub os: String,
    /// CPU architecture (`std::env::consts::ARCH` when captured).
    pub arch: String,
    /// Value of `num_cpus::get()` at capture time.
    pub num_cpus: usize,
    /// Capture time in whole seconds since the Unix epoch (0 if the clock reads earlier).
    pub timestamp: u64,
    /// The tracked build/logging environment variables that were set at capture time.
    pub env_vars: HashMap<String, String>,
    /// Seed state of the manager passed to [`EnvironmentInfo::capture`].
    pub seed_info: SeedInfo,
}
113
/// Seed state copied from a [`SeedManager`] at capture time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SeedInfo {
    /// Master seed of the captured manager.
    pub master_seed: u64,
    /// Component seeds the manager had already derived (and cached) at capture time.
    pub component_seeds: HashMap<String, u64>,
}
122
123impl EnvironmentInfo {
124 pub fn capture(seed_manager: &SeedManager) -> Self {
126 let timestamp = std::time::SystemTime::now()
127 .duration_since(std::time::UNIX_EPOCH)
128 .unwrap_or_default()
129 .as_secs();
130
131 let mut env_vars = HashMap::new();
133 for var in ["RUST_LOG", "CARGO_TARGET_DIR", "RUSTFLAGS"] {
134 if let Ok(value) = std::env::var(var) {
135 env_vars.insert(var.to_string(), value);
136 }
137 }
138
139 Self {
140 rust_version: "unknown".to_string(), os: std::env::consts::OS.to_string(),
142 arch: std::env::consts::ARCH.to_string(),
143 num_cpus: num_cpus::get(),
144 timestamp,
145 env_vars,
146 seed_info: SeedInfo {
147 master_seed: seed_manager.master_seed,
148 component_seeds: seed_manager.component_seeds.clone(),
149 },
150 }
151 }
152}
153
/// Wraps a dataset and serves its samples through a fixed index permutation.
#[derive(Debug)]
pub struct DeterministicDataset<T, D> {
    /// The wrapped dataset.
    dataset: D,
    /// `indices[i]` is the index into `dataset` served at position `i`.
    indices: Vec<usize>,
    /// Ties the element type `T` to the wrapper without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
161
162impl<T, D> DeterministicDataset<T, D>
163where
164 D: Dataset<T>,
165 T: Clone + Default + Send + Sync + 'static,
166{
167 pub fn new(dataset: D, seed: u64) -> Self {
169 let len = dataset.len();
170 let mut indices: Vec<usize> = (0..len).collect();
171
172 let mut rng = StdRng::seed_from_u64(seed);
174 Self::fisher_yates_shuffle(&mut indices, &mut rng);
175
176 Self {
177 dataset,
178 indices,
179 _phantom: std::marker::PhantomData,
180 }
181 }
182
183 pub fn sequential(dataset: D) -> Self {
185 let len = dataset.len();
186 let indices: Vec<usize> = (0..len).collect();
187
188 Self {
189 dataset,
190 indices,
191 _phantom: std::marker::PhantomData,
192 }
193 }
194
195 pub fn reverse(dataset: D) -> Self {
197 let len = dataset.len();
198 let indices: Vec<usize> = (0..len).rev().collect();
199
200 Self {
201 dataset,
202 indices,
203 _phantom: std::marker::PhantomData,
204 }
205 }
206
207 pub fn inner(&self) -> &D {
209 &self.dataset
210 }
211
212 pub fn indices(&self) -> &[usize] {
214 &self.indices
215 }
216
217 pub fn reshuffle(&mut self, seed: u64) {
219 let mut rng = StdRng::seed_from_u64(seed);
220 Self::fisher_yates_shuffle(&mut self.indices, &mut rng);
221 }
222
223 fn fisher_yates_shuffle<R: Rng>(indices: &mut [usize], rng: &mut R) {
224 for i in (1..indices.len()).rev() {
225 let j = rng.random_range(0..i);
226 indices.swap(i, j);
227 }
228 }
229}
230
231impl<T, D> Dataset<T> for DeterministicDataset<T, D>
232where
233 D: Dataset<T>,
234 T: Clone + Default + Send + Sync + 'static,
235{
236 fn len(&self) -> usize {
237 self.dataset.len()
238 }
239
240 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
241 if index >= self.indices.len() {
242 return Err(TensorError::invalid_argument(format!(
243 "Index {} out of bounds for dataset of length {}",
244 index,
245 self.indices.len()
246 )));
247 }
248
249 let actual_index = self.indices[index];
250 self.dataset.get(actual_index)
251 }
252}
253
/// Serializable description of an experiment run: its seed, dataset settings and
/// environment snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentConfig {
    /// Human-readable experiment name.
    pub name: String,
    /// Experiment-level seed.
    pub seed: u64,
    /// Ordering, sampling and transform settings for the dataset.
    pub dataset_config: DatasetConfig,
    /// Environment snapshot recorded for this experiment.
    pub environment: EnvironmentInfo,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}
268
/// Dataset pipeline settings recorded as part of an [`ExperimentConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetConfig {
    /// Order in which sample indices are visited.
    pub ordering: OrderingStrategy,
    /// Sampling settings.
    pub sampling: SamplingConfig,
    /// Transform settings, in the order they are listed.
    pub transforms: Vec<TransformConfig>,
}
279
/// How sample indices are ordered; turned into concrete indices by
/// `DeterministicOrdering::create_indices`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OrderingStrategy {
    /// `0, 1, ..., len - 1`.
    Sequential,
    /// `len - 1, ..., 1, 0`.
    Reverse,
    /// A permutation of `0..len` produced by an RNG seeded with `seed`.
    Shuffled { seed: u64 },
    /// Caller-supplied indices; `create_indices` clamps entries beyond the valid range.
    Custom { indices: Vec<usize> },
}
292
/// Sampling settings recorded with an experiment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingConfig {
    /// Name of the sampling strategy (free-form string, e.g. `"random"`).
    pub strategy: String,
    /// Seed recorded for the sampler.
    pub seed: u64,
    /// Numeric strategy parameters keyed by name.
    pub parameters: HashMap<String, f64>,
}
303
/// A single data transform recorded with an experiment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformConfig {
    /// Transform name.
    pub name: String,
    /// Seed recorded for this transform.
    pub seed: u64,
    /// Transform parameters as arbitrary JSON values, keyed by name.
    pub parameters: HashMap<String, serde_json::Value>,
}
314
315pub struct DeterministicOrdering;
317
318impl DeterministicOrdering {
319 pub fn create_indices(len: usize, strategy: &OrderingStrategy) -> Vec<usize> {
321 match strategy {
322 OrderingStrategy::Sequential => (0..len).collect(),
323 OrderingStrategy::Reverse => (0..len).rev().collect(),
324 OrderingStrategy::Shuffled { seed } => {
325 let mut indices: Vec<usize> = (0..len).collect();
326 let mut rng = StdRng::seed_from_u64(*seed);
327 Self::shuffle_indices(&mut indices, &mut rng);
328 indices
329 }
330 OrderingStrategy::Custom { indices } => {
331 indices
333 .iter()
334 .map(|&i| i.min(len.saturating_sub(1)))
335 .collect()
336 }
337 }
338 }
339
340 pub fn shuffle_indices<R: Rng>(indices: &mut [usize], rng: &mut R) {
342 for i in (1..indices.len()).rev() {
343 let j = rng.random_range(0..i);
344 indices.swap(i, j);
345 }
346 }
347
348 pub fn create_stratified_indices_f32(
350 dataset: &dyn Dataset<f32>,
351 seed: u64,
352 num_classes: usize,
353 ) -> Result<Vec<usize>> {
354 let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); num_classes];
356
357 for i in 0..dataset.len() {
358 let (_, labels) = dataset.get(i)?;
359
360 let class = if labels.is_scalar() {
362 labels.get(&[]).unwrap_or(0.0) as usize
363 } else if let Some(slice) = labels.as_slice() {
364 slice.first().copied().unwrap_or(0.0) as usize
365 } else {
366 0
367 };
368
369 if class < num_classes {
370 class_indices[class].push(i);
371 }
372 }
373
374 let mut rng = StdRng::seed_from_u64(seed);
376 let mut result = Vec::new();
377
378 for class_samples in &mut class_indices {
379 Self::shuffle_indices(class_samples, &mut rng);
380 result.extend_from_slice(class_samples);
381 }
382
383 Ok(result)
384 }
385}
386
387pub trait ReproducibilityExt<T>: Dataset<T> + Sized
389where
390 T: Clone + Default + Send + Sync + 'static,
391{
392 fn deterministic(self, seed: u64) -> DeterministicDataset<T, Self> {
394 DeterministicDataset::new(self, seed)
395 }
396
397 fn sequential(self) -> DeterministicDataset<T, Self> {
399 DeterministicDataset::sequential(self)
400 }
401
402 fn reverse(self) -> DeterministicDataset<T, Self> {
404 DeterministicDataset::reverse(self)
405 }
406}
407
408impl<T, D: Dataset<T>> ReproducibilityExt<T> for D where T: Clone + Default + Send + Sync + 'static {}
409
/// Records an experiment's configuration and timed operations, and persists them as JSON.
#[derive(Debug)]
pub struct ExperimentTracker {
    /// Configuration the experiment was started with.
    config: ExperimentConfig,
    /// When tracking started (construction or load time); saved files report time since then.
    start_time: std::time::Instant,
    /// Operations recorded so far, in recording order.
    operations: Vec<OperationRecord>,
}
417
/// One timed operation logged by an [`ExperimentTracker`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationRecord {
    /// Operation name.
    pub name: String,
    /// Wall-clock time at which the record was made, in whole seconds since the Unix epoch.
    pub timestamp: u64,
    /// Duration of the operation in milliseconds.
    pub duration_ms: u64,
    /// Seed supplied when the operation was recorded.
    pub seed: u64,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}
432
433impl ExperimentTracker {
434 pub fn new(config: ExperimentConfig) -> Self {
436 Self {
437 config,
438 start_time: std::time::Instant::now(),
439 operations: Vec::new(),
440 }
441 }
442
443 pub fn record_operation(
445 &mut self,
446 name: String,
447 duration: std::time::Duration,
448 seed: u64,
449 metadata: HashMap<String, String>,
450 ) {
451 let timestamp = std::time::SystemTime::now()
452 .duration_since(std::time::UNIX_EPOCH)
453 .unwrap_or_default()
454 .as_secs();
455
456 let record = OperationRecord {
457 name,
458 timestamp,
459 duration_ms: duration.as_millis() as u64,
460 seed,
461 metadata,
462 };
463
464 self.operations.push(record);
465 }
466
467 pub fn config(&self) -> &ExperimentConfig {
469 &self.config
470 }
471
472 pub fn operations(&self) -> &[OperationRecord] {
474 &self.operations
475 }
476
477 pub fn save_to_file<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
479 let experiment_data = ExperimentData {
480 config: self.config.clone(),
481 operations: self.operations.clone(),
482 total_duration_ms: self.start_time.elapsed().as_millis() as u64,
483 };
484
485 let json_data = serde_json::to_string_pretty(&experiment_data).map_err(|e| {
486 TensorError::invalid_argument(format!("Failed to serialize experiment data: {e}"))
487 })?;
488
489 std::fs::write(path, json_data).map_err(|e| {
490 TensorError::invalid_argument(format!("Failed to write experiment file: {e}"))
491 })?;
492
493 Ok(())
494 }
495
496 pub fn load_from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
498 let json_data = std::fs::read_to_string(path).map_err(|e| {
499 TensorError::invalid_argument(format!("Failed to read experiment file: {e}"))
500 })?;
501
502 let experiment_data: ExperimentData = serde_json::from_str(&json_data).map_err(|e| {
503 TensorError::invalid_argument(format!("Failed to parse experiment JSON: {e}"))
504 })?;
505
506 Ok(Self {
507 config: experiment_data.config,
508 start_time: std::time::Instant::now(), operations: experiment_data.operations,
510 })
511 }
512}
513
/// JSON layout written by `ExperimentTracker::save_to_file` and read by `load_from_file`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ExperimentData {
    /// Experiment configuration.
    config: ExperimentConfig,
    /// All recorded operations.
    operations: Vec<OperationRecord>,
    /// Milliseconds elapsed on the tracker's clock when the file was saved.
    total_duration_ms: u64,
}
521
522pub struct DeterministicOps;
524
525impl DeterministicOps {
526 pub fn set_global_seed(seed: u64) {
528 SeedManager::set_global(SeedManager::new(seed));
529 }
530
531 pub fn get_rng(component: &str) -> StdRng {
533 let manager = SeedManager::global();
534 let mut manager = manager.lock().unwrap_or_else(|e| e.into_inner());
535 manager.create_rng(component)
536 }
537
538 pub fn next_operation_seed() -> u64 {
540 let manager = SeedManager::global();
541 let mut manager = manager.lock().unwrap_or_else(|e| e.into_inner());
542 manager.next_operation_seed()
543 }
544
545 pub fn capture_environment() -> EnvironmentInfo {
547 let manager = SeedManager::global();
548 let manager = manager.lock().unwrap_or_else(|e| e.into_inner());
549 EnvironmentInfo::capture(&manager)
550 }
551}
552
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tempfile::TempDir;

    #[test]
    fn test_seed_manager() {
        let mut manager = SeedManager::new(42);

        assert_eq!(manager.master_seed(), 42);

        // Component seeds are cached per name and differ between names.
        let seed1 = manager.get_component_seed("test");
        let seed2 = manager.get_component_seed("test");
        assert_eq!(seed1, seed2);

        let seed3 = manager.get_component_seed("other");
        assert_ne!(seed1, seed3);

        // Each operation seed is fresh.
        let op1 = manager.next_operation_seed();
        let op2 = manager.next_operation_seed();
        assert_ne!(op1, op2);
    }

    #[test]
    fn test_deterministic_dataset() {
        // Enough samples that two seeds cannot plausibly yield the same permutation; with only
        // 3 samples a uniform shuffle collides 1 time in 6.
        const N: usize = 10;
        let features_data: Vec<f64> = (0..2 * N).map(|v| v as f64).collect();
        let labels_data: Vec<f64> = (0..N).map(|v| (v % 2) as f64).collect();
        let features =
            Tensor::from_vec(features_data, &[N, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[N]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let det_dataset = DeterministicDataset::new(dataset, 42);

        assert_eq!(det_dataset.len(), N);

        // The order must be a permutation of 0..N.
        let mut sorted = det_dataset.indices().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..N).collect::<Vec<_>>());

        let det_dataset2 = DeterministicDataset::new(det_dataset.inner().clone(), 42);
        assert_eq!(det_dataset.indices(), det_dataset2.indices());

        let det_dataset3 = DeterministicDataset::new(det_dataset.inner().clone(), 123);
        assert_ne!(det_dataset.indices(), det_dataset3.indices());
    }

    #[test]
    fn test_ordering_strategies() {
        let len = 5;

        let seq_indices = DeterministicOrdering::create_indices(len, &OrderingStrategy::Sequential);
        assert_eq!(seq_indices, vec![0, 1, 2, 3, 4]);

        let rev_indices = DeterministicOrdering::create_indices(len, &OrderingStrategy::Reverse);
        assert_eq!(rev_indices, vec![4, 3, 2, 1, 0]);

        let shuffled1 =
            DeterministicOrdering::create_indices(len, &OrderingStrategy::Shuffled { seed: 42 });
        let shuffled2 =
            DeterministicOrdering::create_indices(len, &OrderingStrategy::Shuffled { seed: 42 });
        assert_eq!(shuffled1, shuffled2);

        // A shuffle must be a permutation of 0..len.
        let mut sorted = shuffled1.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 1, 2, 3, 4]);

        let shuffled3 =
            DeterministicOrdering::create_indices(len, &OrderingStrategy::Shuffled { seed: 123 });
        assert_ne!(shuffled1, shuffled3);

        let custom_indices = DeterministicOrdering::create_indices(
            len,
            &OrderingStrategy::Custom {
                indices: vec![2, 0, 4, 1, 3],
            },
        );
        assert_eq!(custom_indices, vec![2, 0, 4, 1, 3]);

        // Out-of-range custom indices are clamped to the last valid index.
        let clamped = DeterministicOrdering::create_indices(
            len,
            &OrderingStrategy::Custom {
                indices: vec![7, 0],
            },
        );
        assert_eq!(clamped, vec![4, 0]);
    }

    #[test]
    fn test_environment_capture() {
        let manager = SeedManager::new(42);
        let env = EnvironmentInfo::capture(&manager);

        assert!(!env.rust_version.is_empty());
        assert!(!env.os.is_empty());
        assert!(!env.arch.is_empty());
        assert!(env.num_cpus > 0);
        assert_eq!(env.seed_info.master_seed, 42);
    }

    #[test]
    fn test_experiment_tracker() {
        let config = ExperimentConfig {
            name: "test_experiment".to_string(),
            seed: 42,
            dataset_config: DatasetConfig {
                ordering: OrderingStrategy::Shuffled { seed: 42 },
                sampling: SamplingConfig {
                    strategy: "random".to_string(),
                    seed: 42,
                    parameters: HashMap::new(),
                },
                transforms: Vec::new(),
            },
            environment: EnvironmentInfo::capture(&SeedManager::new(42)),
            metadata: HashMap::new(),
        };

        let mut tracker = ExperimentTracker::new(config);

        tracker.record_operation(
            "data_loading".to_string(),
            std::time::Duration::from_millis(100),
            42,
            HashMap::new(),
        );

        assert_eq!(tracker.operations().len(), 1);
        assert_eq!(tracker.operations()[0].name, "data_loading");
        assert_eq!(tracker.operations()[0].duration_ms, 100);

        let temp_dir = TempDir::new().expect("test: temp dir creation should succeed");
        let file_path = temp_dir.path().join("experiment.json");

        tracker
            .save_to_file(&file_path)
            .expect("test: save to file should succeed");
        let loaded_tracker = ExperimentTracker::load_from_file(&file_path)
            .expect("test: load from file should succeed");

        assert_eq!(loaded_tracker.config().name, "test_experiment");
        assert_eq!(loaded_tracker.operations().len(), 1);
    }

    #[test]
    fn test_reproducibility_ext() {
        let features_data = vec![1.0, 2.0, 3.0, 4.0];
        let labels_data = vec![0.0, 1.0];
        let features =
            Tensor::from_vec(features_data, &[2, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[2]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let det_dataset = dataset.deterministic(42);
        assert_eq!(det_dataset.len(), 2);

        let seq_dataset = det_dataset.inner().clone().sequential();
        assert_eq!(seq_dataset.indices(), &[0, 1]);

        let rev_dataset = det_dataset.inner().clone().reverse();
        assert_eq!(rev_dataset.indices(), &[1, 0]);
    }

    #[test]
    fn test_deterministic_ops() {
        DeterministicOps::set_global_seed(12345);

        // The same component always gets the same stream.
        let mut rng1 = DeterministicOps::get_rng("test_component");
        let val1: f64 = rng1.random();

        let mut rng2 = DeterministicOps::get_rng("test_component");
        let val2: f64 = rng2.random();
        assert_eq!(val1, val2);

        let mut rng3 = DeterministicOps::get_rng("other_component");
        let val3: f64 = rng3.random();
        assert_ne!(val1, val3);

        let op1 = DeterministicOps::next_operation_seed();
        let op2 = DeterministicOps::next_operation_seed();
        assert_ne!(op1, op2);

        let env = DeterministicOps::capture_environment();
        assert_eq!(env.seed_info.master_seed, 12345);
    }
}