tenflowers_core/
deterministic.rs1use crate::{Result, TensorError};
37use std::sync::{Arc, Mutex, OnceLock};
38
/// Global bookkeeping for deterministic execution: the mode flags, the root seed from which
/// every per-operation seed is derived, and an optional log of the seeds handed out.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeterministicState {
    /// When `true`, operation seeds are derived from `global_seed`; when `false`,
    /// `get_operation_seed` falls back to a clock-based seed.
    pub enabled: bool,
    /// Root seed from which every per-operation seed is derived.
    pub global_seed: u64,
    /// Number of deterministic seeds drawn so far (reset to 0 by `set_global_seed` and
    /// `reset_counter`); mixed into each derived seed so repeated operations differ.
    pub operation_counter: u64,
    /// When `true` (and `enabled`), `mark_non_deterministic` returns an error instead of
    /// printing a warning.
    pub strict_mode: bool,
    /// Whether deterministic algorithm variants are preferred; consulted by
    /// `should_use_deterministic_gpu_ops`.
    pub prefer_deterministic_algorithms: bool,
    /// `"<operation>: seed=<seed>"` entries recorded by `next_subseed`, oldest first.
    pub operation_log: Vec<String>,
    /// Maximum number of entries kept in `operation_log`; once reached, further entries are
    /// dropped. `0` disables logging.
    pub max_log_size: usize,
}
57
58impl Default for DeterministicState {
59 fn default() -> Self {
60 Self {
61 enabled: false,
62 global_seed: 0,
63 operation_counter: 0,
64 strict_mode: false,
65 prefer_deterministic_algorithms: true,
66 operation_log: Vec::new(),
67 max_log_size: 1000,
68 }
69 }
70}
71
72impl DeterministicState {
73 pub fn new(seed: u64) -> Self {
75 Self {
76 enabled: true,
77 global_seed: seed,
78 ..Default::default()
79 }
80 }
81
82 pub fn next_subseed(&mut self, operation_name: &str) -> u64 {
84 let subseed = self
86 .global_seed
87 .wrapping_mul(6364136223846793005)
88 .wrapping_add(self.operation_counter)
89 .wrapping_add(hash_string(operation_name));
90
91 self.operation_counter += 1;
92
93 if self.operation_log.len() < self.max_log_size {
95 self.operation_log
96 .push(format!("{}: seed={}", operation_name, subseed));
97 }
98
99 subseed
100 }
101
102 pub fn reset_counter(&mut self) {
104 self.operation_counter = 0;
105 }
106
107 pub fn clear_log(&mut self) {
109 self.operation_log.clear();
110 }
111
112 pub fn snapshot(&self) -> DeterministicSnapshot {
114 DeterministicSnapshot {
115 global_seed: self.global_seed,
116 operation_counter: self.operation_counter,
117 enabled: self.enabled,
118 }
119 }
120
121 pub fn restore(&mut self, snapshot: &DeterministicSnapshot) {
123 self.global_seed = snapshot.global_seed;
124 self.operation_counter = snapshot.operation_counter;
125 self.enabled = snapshot.enabled;
126 }
127}
128
/// Copy of the parts of [`DeterministicState`] needed to replay seeds: the global seed, the
/// operation counter and the mode flag. Strict mode, the algorithm preference and the
/// operation log are not captured.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeterministicSnapshot {
    /// Global seed at capture time.
    pub global_seed: u64,
    /// Operation counter at capture time.
    pub operation_counter: u64,
    /// Whether deterministic mode was enabled at capture time.
    pub enabled: bool,
}
136
/// Hashes `s` with 64-bit FNV-1a.
///
/// A fixed, hand-written hash is used so operation-name offsets stay identical across runs
/// and toolchains (`std`'s `DefaultHasher` output is not guaranteed stable across releases).
fn hash_string(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    s.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
146
147static GLOBAL_STATE: OnceLock<Arc<Mutex<DeterministicState>>> = OnceLock::new();
152
153fn get_global_state() -> &'static Arc<Mutex<DeterministicState>> {
155 GLOBAL_STATE.get_or_init(|| Arc::new(Mutex::new(DeterministicState::default())))
156}
157
158pub fn set_deterministic_mode(enabled: bool) {
162 let state = get_global_state();
163 state.lock().expect("lock should not be poisoned").enabled = enabled;
164}
165
166pub fn is_deterministic_mode() -> bool {
168 let state = get_global_state();
169 state.lock().expect("lock should not be poisoned").enabled
170}
171
172pub fn set_global_seed(seed: u64) {
176 let state = get_global_state();
177 let mut s = state.lock().expect("lock should not be poisoned");
178 s.global_seed = seed;
179 s.operation_counter = 0;
180 s.clear_log();
181}
182
183pub fn get_global_seed() -> u64 {
185 let state = get_global_state();
186 state
187 .lock()
188 .expect("lock should not be poisoned")
189 .global_seed
190}
191
192pub fn set_strict_mode(strict: bool) {
194 let state = get_global_state();
195 state
196 .lock()
197 .expect("lock should not be poisoned")
198 .strict_mode = strict;
199}
200
201pub fn is_strict_mode() -> bool {
203 let state = get_global_state();
204 state
205 .lock()
206 .expect("lock should not be poisoned")
207 .strict_mode
208}
209
210pub fn get_operation_seed(operation_name: &str) -> u64 {
215 let state = get_global_state();
216 let mut s = state.lock().expect("lock should not be poisoned");
217
218 if !s.enabled {
219 use std::time::{SystemTime, UNIX_EPOCH};
221 SystemTime::now()
222 .duration_since(UNIX_EPOCH)
223 .expect("system time should be after UNIX_EPOCH")
224 .as_nanos() as u64
225 } else {
226 s.next_subseed(operation_name)
227 }
228}
229
230pub fn reset_operation_counter() {
235 let state = get_global_state();
236 state
237 .lock()
238 .expect("lock should not be poisoned")
239 .reset_counter();
240}
241
242pub fn get_state_snapshot() -> DeterministicSnapshot {
246 let state = get_global_state();
247 state
248 .lock()
249 .expect("lock should not be poisoned")
250 .snapshot()
251}
252
253pub fn restore_state_snapshot(snapshot: &DeterministicSnapshot) {
255 let state = get_global_state();
256 state
257 .lock()
258 .expect("lock should not be poisoned")
259 .restore(snapshot);
260}
261
262pub fn get_operation_log() -> Vec<String> {
264 let state = get_global_state();
265 state
266 .lock()
267 .expect("lock should not be poisoned")
268 .operation_log
269 .clone()
270}
271
272pub fn clear_operation_log() {
274 let state = get_global_state();
275 state
276 .lock()
277 .expect("lock should not be poisoned")
278 .clear_log();
279}
280
281pub fn enable_operation_logging() {
283 let state = get_global_state();
284 let mut s = state.lock().expect("lock should not be poisoned");
285 s.max_log_size = 1000;
286}
287
288#[doc(hidden)]
290pub fn reset_to_defaults() {
291 let state = get_global_state();
292 let mut s = state.lock().expect("lock should not be poisoned");
293 *s = DeterministicState::default();
294}
295
296pub struct DeterministicScope {
301 previous_state: DeterministicSnapshot,
302}
303
304impl DeterministicScope {
305 pub fn new(seed: u64) -> Self {
307 let previous_state = get_state_snapshot();
308
309 set_deterministic_mode(true);
310 set_global_seed(seed);
311
312 Self { previous_state }
313 }
314
315 pub fn with_mode(enabled: bool) -> Self {
317 let previous_state = get_state_snapshot();
318 set_deterministic_mode(enabled);
319 Self { previous_state }
320 }
321}
322
323impl Drop for DeterministicScope {
324 fn drop(&mut self) {
325 restore_state_snapshot(&self.previous_state);
326 }
327}
328
/// Bundle of deterministic-execution settings that can be installed globally with
/// [`DeterministicConfig::apply`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeterministicConfig {
    /// Global seed installed by `apply`.
    pub seed: u64,
    /// Whether strict mode is enabled, turning `mark_non_deterministic` warnings into errors.
    pub strict: bool,
    /// Value stored into `DeterministicState::prefer_deterministic_algorithms`.
    pub prefer_deterministic: bool,
    /// Whether `apply` leaves operation logging on (1000-entry cap) or disables it.
    pub log_operations: bool,
}
341
342impl Default for DeterministicConfig {
343 fn default() -> Self {
344 Self {
345 seed: 42,
346 strict: false,
347 prefer_deterministic: true,
348 log_operations: false,
349 }
350 }
351}
352
353impl DeterministicConfig {
354 pub fn apply(&self) {
356 set_global_seed(self.seed);
357 set_deterministic_mode(true);
358 set_strict_mode(self.strict);
359
360 let state = get_global_state();
361 let mut s = state.lock().expect("lock should not be poisoned");
362 s.prefer_deterministic_algorithms = self.prefer_deterministic;
363
364 if !self.log_operations {
365 s.clear_log();
366 s.max_log_size = 0;
367 } else {
368 s.max_log_size = 1000;
369 }
370 }
371}
372
373pub fn verify_reproducibility<F, T>(operation_name: &str, mut operation: F) -> Result<bool>
377where
378 F: FnMut() -> T,
379 T: PartialEq,
380{
381 let snapshot = get_state_snapshot();
382
383 set_global_seed(snapshot.global_seed);
385 reset_operation_counter();
386 let result1 = operation();
387
388 set_global_seed(snapshot.global_seed);
390 reset_operation_counter();
391 let result2 = operation();
392
393 restore_state_snapshot(&snapshot);
395
396 Ok(result1 == result2)
397}
398
399pub fn mark_non_deterministic(operation_name: &str) -> Result<()> {
407 if is_deterministic_mode() && is_strict_mode() {
408 Err(TensorError::invalid_operation_simple(format!(
409 "Operation '{}' is non-deterministic but strict deterministic mode is enabled",
410 operation_name
411 )))
412 } else {
413 if is_deterministic_mode() {
415 eprintln!(
416 "Warning: Operation '{}' may not be fully deterministic",
417 operation_name
418 );
419 }
420 Ok(())
421 }
422}
423
424pub fn should_use_deterministic_gpu_ops() -> bool {
426 let state = get_global_state();
427 let s = state.lock().expect("lock should not be poisoned");
428 s.enabled && s.prefer_deterministic_algorithms
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    // Serialises tests that mutate the process-wide deterministic state. `Mutex::new` is
    // `const`, so no lazy initialisation (and no `lazy_static` dependency) is needed.
    static TEST_MUTEX: Mutex<()> = Mutex::new(());

    /// Acquires the test lock, recovering it if an earlier test panicked while holding it,
    /// so one failing test does not make every later test fail with a poison error.
    fn serial() -> MutexGuard<'static, ()> {
        TEST_MUTEX.lock().unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn test_deterministic_mode_toggle() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        assert!(is_deterministic_mode());

        set_deterministic_mode(false);
        assert!(!is_deterministic_mode());
    }

    #[test]
    fn test_global_seed() {
        let _guard = serial();
        reset_to_defaults();
        set_global_seed(12345);
        assert_eq!(get_global_seed(), 12345);

        set_global_seed(67890);
        assert_eq!(get_global_seed(), 67890);
    }

    #[test]
    fn test_operation_seed_generation() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        set_global_seed(42);

        let seed1 = get_operation_seed("test_op");
        let seed2 = get_operation_seed("test_op");

        // The counter advances between calls, so repeated operations get distinct seeds.
        assert_ne!(seed1, seed2);

        // Rewinding the counter replays the sequence.
        reset_operation_counter();
        let seed3 = get_operation_seed("test_op");
        assert_eq!(seed1, seed3);
    }

    #[test]
    fn test_operation_seed_uniqueness() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        set_global_seed(42);
        reset_operation_counter();

        let seed_a = get_operation_seed("operation_a");
        let seed_b = get_operation_seed("operation_b");

        assert_ne!(seed_a, seed_b);
    }

    #[test]
    fn test_snapshot_and_restore() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        set_global_seed(100);

        let _ = get_operation_seed("op1");
        let _ = get_operation_seed("op2");

        let snapshot = get_state_snapshot();

        let _ = get_operation_seed("op3");

        restore_state_snapshot(&snapshot);
        let seed_after_restore = get_operation_seed("op3");

        // Restoring again must replay the exact same seed.
        restore_state_snapshot(&snapshot);
        let seed_repeat = get_operation_seed("op3");

        assert_eq!(seed_after_restore, seed_repeat);
    }

    #[test]
    fn test_deterministic_scope() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(false);
        set_global_seed(100);

        {
            let _scope = DeterministicScope::new(200);
            assert!(is_deterministic_mode());
            assert_eq!(get_global_seed(), 200);
        }

        // Leaving the scope restores both the mode and the seed.
        assert!(!is_deterministic_mode());
        assert_eq!(get_global_seed(), 100);
    }

    #[test]
    fn test_strict_mode() {
        let _guard = serial();
        reset_to_defaults();
        set_strict_mode(true);
        assert!(is_strict_mode());

        set_strict_mode(false);
        assert!(!is_strict_mode());
    }

    #[test]
    fn test_mark_non_deterministic() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        set_strict_mode(false);

        // Non-strict: warning only.
        assert!(mark_non_deterministic("test_op").is_ok());

        // Strict: hard error.
        set_strict_mode(true);
        assert!(mark_non_deterministic("test_op").is_err());
    }

    #[test]
    fn test_config_apply() {
        let _guard = serial();
        reset_to_defaults();
        let config = DeterministicConfig {
            seed: 777,
            strict: true,
            prefer_deterministic: true,
            log_operations: false,
        };

        config.apply();

        assert_eq!(get_global_seed(), 777);
        assert!(is_deterministic_mode());
        assert!(is_strict_mode());
    }

    #[test]
    fn test_config_apply_disables_logging() {
        let _guard = serial();
        reset_to_defaults();
        let config = DeterministicConfig {
            log_operations: false,
            ..Default::default()
        };
        config.apply();

        let _ = get_operation_seed("op");
        assert!(get_operation_log().is_empty());
    }

    #[test]
    fn test_operation_log() {
        let _guard = serial();
        reset_to_defaults();
        enable_operation_logging();
        set_deterministic_mode(true);
        set_global_seed(42);

        let _ = get_operation_seed("op1");
        let _ = get_operation_seed("op2");

        let log = get_operation_log();
        assert_eq!(log.len(), 2);
        assert!(log[0].contains("op1"));
        assert!(log[1].contains("op2"));
    }

    #[test]
    fn test_hash_string_deterministic() {
        let hash1 = hash_string("test");
        let hash2 = hash_string("test");
        assert_eq!(hash1, hash2);

        let hash3 = hash_string("different");
        assert_ne!(hash1, hash3);
    }

    #[test]
    fn test_reproducibility_with_counter_reset() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        set_global_seed(42);

        reset_operation_counter();
        let seeds1: Vec<u64> = (0..5)
            .map(|i| get_operation_seed(&format!("op{}", i)))
            .collect();

        reset_operation_counter();
        let seeds2: Vec<u64> = (0..5)
            .map(|i| get_operation_seed(&format!("op{}", i)))
            .collect();

        assert_eq!(seeds1, seeds2);
    }

    #[test]
    fn test_verify_reproducibility() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        set_global_seed(7);
        let _ = get_operation_seed("warmup");
        let before = get_state_snapshot();

        let outcome = verify_reproducibility("seeded_op", || get_operation_seed("seeded_op"));
        assert!(matches!(outcome, Ok(true)));

        // The caller's seed, counter and mode survive the verification runs.
        let after = get_state_snapshot();
        assert_eq!(after.global_seed, before.global_seed);
        assert_eq!(after.operation_counter, before.operation_counter);
        assert_eq!(after.enabled, before.enabled);
    }

    #[test]
    fn test_should_use_deterministic_gpu_ops() {
        let _guard = serial();
        reset_to_defaults();
        assert!(!should_use_deterministic_gpu_ops());

        set_deterministic_mode(true);
        assert!(should_use_deterministic_gpu_ops());
    }

    #[test]
    fn test_non_deterministic_mode_uses_system_time() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(false);

        let _ = get_operation_seed("test");
        let _ = get_operation_seed("test");

        // The clock-based path must not consume deterministic counter values or log entries.
        assert_eq!(get_state_snapshot().operation_counter, 0);
        assert!(get_operation_log().is_empty());
    }
}