Skip to main content

tenflowers_core/
deterministic.rs

//! Deterministic Mode for Reproducible Training
//!
//! This module provides infrastructure for deterministic execution, ensuring that
//! training runs produce identical results when using the same random seed. This is
//! critical for debugging, comparing experiments, and scientific reproducibility.
//!
//! ## Features
//!
//! - **Global Seed Management**: Centralized seed control across all operations
//! - **Operation-Local Seeds**: Each operation gets a deterministic subseed derived from
//!   the global seed, a running operation counter, and the operation's name
//! - **State Checkpointing**: Snapshot and restore the seed/operation-counter state
//! - **GPU Determinism**: Control non-deterministic GPU operations
//! - **Reproducibility Validation**: Verify that operations are truly deterministic
//!
//! ## Usage
//!
//! ```rust,ignore
//! use tenflowers_core::deterministic::{set_deterministic_mode, set_global_seed};
//!
//! // Enable deterministic mode with a specific seed
//! set_global_seed(42);
//! set_deterministic_mode(true);
//!
//! // All operations will now use deterministic algorithms
//! let tensor = Tensor::<f32>::randn(&[10, 10]); // Seeded from global seed 42
//!
//! // Operations get unique subseeds
//! let dropout_output = dropout(&tensor, 0.5); // Uses derived subseed
//! ```
//!
//! ## Important Notes
//!
//! - Deterministic mode may be slower than non-deterministic mode
//! - Some GPU operations may fall back to CPU for determinism
//! - Subseeds depend on call order, so parallel execution order must be controlled for
//!   full reproducibility
36use crate::{Result, TensorError};
37use std::sync::{Arc, Mutex, OnceLock};
38
39/// Global deterministic mode state
40#[derive(Debug, Clone)]
41pub struct DeterministicState {
42    /// Whether deterministic mode is enabled
43    pub enabled: bool,
44    /// Global random seed
45    pub global_seed: u64,
46    /// Current operation counter for subseed generation
47    pub operation_counter: u64,
48    /// Whether to enforce determinism strictly (fail on non-deterministic ops)
49    pub strict_mode: bool,
50    /// Whether to use deterministic algorithms even if slower
51    pub prefer_deterministic_algorithms: bool,
52    /// Track which operations have been executed for reproducibility
53    pub operation_log: Vec<String>,
54    /// Maximum size of operation log
55    pub max_log_size: usize,
56}
57
58impl Default for DeterministicState {
59    fn default() -> Self {
60        Self {
61            enabled: false,
62            global_seed: 0,
63            operation_counter: 0,
64            strict_mode: false,
65            prefer_deterministic_algorithms: true,
66            operation_log: Vec::new(),
67            max_log_size: 1000,
68        }
69    }
70}
71
72impl DeterministicState {
73    /// Create a new deterministic state with a seed
74    pub fn new(seed: u64) -> Self {
75        Self {
76            enabled: true,
77            global_seed: seed,
78            ..Default::default()
79        }
80    }
81
82    /// Get the next subseed for an operation
83    pub fn next_subseed(&mut self, operation_name: &str) -> u64 {
84        // Generate deterministic subseed based on global seed and counter
85        let subseed = self
86            .global_seed
87            .wrapping_mul(6364136223846793005)
88            .wrapping_add(self.operation_counter)
89            .wrapping_add(hash_string(operation_name));
90
91        self.operation_counter += 1;
92
93        // Log operation if enabled
94        if self.operation_log.len() < self.max_log_size {
95            self.operation_log
96                .push(format!("{}: seed={}", operation_name, subseed));
97        }
98
99        subseed
100    }
101
102    /// Reset the operation counter
103    pub fn reset_counter(&mut self) {
104        self.operation_counter = 0;
105    }
106
107    /// Clear the operation log
108    pub fn clear_log(&mut self) {
109        self.operation_log.clear();
110    }
111
112    /// Get a snapshot of the current state for reproducibility
113    pub fn snapshot(&self) -> DeterministicSnapshot {
114        DeterministicSnapshot {
115            global_seed: self.global_seed,
116            operation_counter: self.operation_counter,
117            enabled: self.enabled,
118        }
119    }
120
121    /// Restore from a snapshot
122    pub fn restore(&mut self, snapshot: &DeterministicSnapshot) {
123        self.global_seed = snapshot.global_seed;
124        self.operation_counter = snapshot.operation_counter;
125        self.enabled = snapshot.enabled;
126    }
127}
128
/// Copyable checkpoint of the state needed to replay operation seeds: the global seed,
/// the operation counter, and whether deterministic mode was enabled.
///
/// Strict mode, the algorithm preference, and the operation log are not captured.
/// Snapshots compare by value, so callers can check that a scope or a reproducibility
/// check left the global state exactly as it found it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeterministicSnapshot {
    /// Global seed at the time of the snapshot
    pub global_seed: u64,
    /// Number of subseeds already issued at the time of the snapshot
    pub operation_counter: u64,
    /// Whether deterministic mode was enabled
    pub enabled: bool,
}
136
/// 64-bit FNV-1a hash of the UTF-8 bytes of `s`, used to fold an operation name into
/// its subseed.
///
/// Implemented by hand so the value is fixed across runs, platforms, and toolchains
/// (std's `DefaultHasher` algorithm is unspecified and may change).
fn hash_string(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    s.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
146
// ============================================================================
// Global State Management
// ============================================================================
150
151static GLOBAL_STATE: OnceLock<Arc<Mutex<DeterministicState>>> = OnceLock::new();
152
153/// Get the global deterministic state
154fn get_global_state() -> &'static Arc<Mutex<DeterministicState>> {
155    GLOBAL_STATE.get_or_init(|| Arc::new(Mutex::new(DeterministicState::default())))
156}
157
158/// Enable or disable deterministic mode
159///
160/// When enabled, all operations will use deterministic algorithms and RNG seeding.
161pub fn set_deterministic_mode(enabled: bool) {
162    let state = get_global_state();
163    state.lock().expect("lock should not be poisoned").enabled = enabled;
164}
165
166/// Check if deterministic mode is enabled
167pub fn is_deterministic_mode() -> bool {
168    let state = get_global_state();
169    state.lock().expect("lock should not be poisoned").enabled
170}
171
172/// Set the global random seed
173///
174/// This seed is used to derive subseeds for all random operations.
175pub fn set_global_seed(seed: u64) {
176    let state = get_global_state();
177    let mut s = state.lock().expect("lock should not be poisoned");
178    s.global_seed = seed;
179    s.operation_counter = 0;
180    s.clear_log();
181}
182
183/// Get the current global seed
184pub fn get_global_seed() -> u64 {
185    let state = get_global_state();
186    state
187        .lock()
188        .expect("lock should not be poisoned")
189        .global_seed
190}
191
192/// Enable strict mode (fail on non-deterministic operations)
193pub fn set_strict_mode(strict: bool) {
194    let state = get_global_state();
195    state
196        .lock()
197        .expect("lock should not be poisoned")
198        .strict_mode = strict;
199}
200
201/// Check if strict mode is enabled
202pub fn is_strict_mode() -> bool {
203    let state = get_global_state();
204    state
205        .lock()
206        .expect("lock should not be poisoned")
207        .strict_mode
208}
209
210/// Get a subseed for a specific operation
211///
212/// This ensures that each operation gets a unique, deterministic seed
213/// derived from the global seed and operation sequence.
214pub fn get_operation_seed(operation_name: &str) -> u64 {
215    let state = get_global_state();
216    let mut s = state.lock().expect("lock should not be poisoned");
217
218    if !s.enabled {
219        // In non-deterministic mode, use system time
220        use std::time::{SystemTime, UNIX_EPOCH};
221        SystemTime::now()
222            .duration_since(UNIX_EPOCH)
223            .expect("system time should be after UNIX_EPOCH")
224            .as_nanos() as u64
225    } else {
226        s.next_subseed(operation_name)
227    }
228}
229
230/// Reset the operation counter
231///
232/// Useful when you want to restart from a known state while keeping
233/// the same global seed.
234pub fn reset_operation_counter() {
235    let state = get_global_state();
236    state
237        .lock()
238        .expect("lock should not be poisoned")
239        .reset_counter();
240}
241
242/// Get a snapshot of the current deterministic state
243///
244/// Useful for checkpointing and restoring state.
245pub fn get_state_snapshot() -> DeterministicSnapshot {
246    let state = get_global_state();
247    state
248        .lock()
249        .expect("lock should not be poisoned")
250        .snapshot()
251}
252
253/// Restore deterministic state from a snapshot
254pub fn restore_state_snapshot(snapshot: &DeterministicSnapshot) {
255    let state = get_global_state();
256    state
257        .lock()
258        .expect("lock should not be poisoned")
259        .restore(snapshot);
260}
261
262/// Get the operation log for debugging
263pub fn get_operation_log() -> Vec<String> {
264    let state = get_global_state();
265    state
266        .lock()
267        .expect("lock should not be poisoned")
268        .operation_log
269        .clone()
270}
271
272/// Clear the operation log
273pub fn clear_operation_log() {
274    let state = get_global_state();
275    state
276        .lock()
277        .expect("lock should not be poisoned")
278        .clear_log();
279}
280
281/// Enable operation logging with default max size
282pub fn enable_operation_logging() {
283    let state = get_global_state();
284    let mut s = state.lock().expect("lock should not be poisoned");
285    s.max_log_size = 1000;
286}
287
288/// Reset all deterministic state to defaults (for testing)
289#[doc(hidden)]
290pub fn reset_to_defaults() {
291    let state = get_global_state();
292    let mut s = state.lock().expect("lock should not be poisoned");
293    *s = DeterministicState::default();
294}
295
296/// Scoped deterministic mode
297///
298/// Temporarily enable deterministic mode with a specific seed,
299/// then restore the previous state when dropped.
300pub struct DeterministicScope {
301    previous_state: DeterministicSnapshot,
302}
303
304impl DeterministicScope {
305    /// Create a new deterministic scope with a seed
306    pub fn new(seed: u64) -> Self {
307        let previous_state = get_state_snapshot();
308
309        set_deterministic_mode(true);
310        set_global_seed(seed);
311
312        Self { previous_state }
313    }
314
315    /// Create a scope that only affects the mode, not the seed
316    pub fn with_mode(enabled: bool) -> Self {
317        let previous_state = get_state_snapshot();
318        set_deterministic_mode(enabled);
319        Self { previous_state }
320    }
321}
322
323impl Drop for DeterministicScope {
324    fn drop(&mut self) {
325        restore_state_snapshot(&self.previous_state);
326    }
327}
328
329/// Configuration for deterministic execution
330#[derive(Debug, Clone)]
331pub struct DeterministicConfig {
332    /// Global seed
333    pub seed: u64,
334    /// Enable strict mode
335    pub strict: bool,
336    /// Prefer deterministic algorithms even if slower
337    pub prefer_deterministic: bool,
338    /// Enable operation logging
339    pub log_operations: bool,
340}
341
342impl Default for DeterministicConfig {
343    fn default() -> Self {
344        Self {
345            seed: 42,
346            strict: false,
347            prefer_deterministic: true,
348            log_operations: false,
349        }
350    }
351}
352
353impl DeterministicConfig {
354    /// Apply this configuration globally
355    pub fn apply(&self) {
356        set_global_seed(self.seed);
357        set_deterministic_mode(true);
358        set_strict_mode(self.strict);
359
360        let state = get_global_state();
361        let mut s = state.lock().expect("lock should not be poisoned");
362        s.prefer_deterministic_algorithms = self.prefer_deterministic;
363
364        if !self.log_operations {
365            s.clear_log();
366            s.max_log_size = 0;
367        } else {
368            s.max_log_size = 1000;
369        }
370    }
371}
372
373/// Verify that an operation is reproducible
374///
375/// Runs the operation twice with the same seed and checks if results match.
376pub fn verify_reproducibility<F, T>(operation_name: &str, mut operation: F) -> Result<bool>
377where
378    F: FnMut() -> T,
379    T: PartialEq,
380{
381    let snapshot = get_state_snapshot();
382
383    // First run
384    set_global_seed(snapshot.global_seed);
385    reset_operation_counter();
386    let result1 = operation();
387
388    // Second run with same seed
389    set_global_seed(snapshot.global_seed);
390    reset_operation_counter();
391    let result2 = operation();
392
393    // Restore original state
394    restore_state_snapshot(&snapshot);
395
396    Ok(result1 == result2)
397}
398
// ============================================================================
// Utilities
// ============================================================================
402
403/// Mark an operation as potentially non-deterministic
404///
405/// In strict mode, this will return an error. Otherwise, it logs a warning.
406pub fn mark_non_deterministic(operation_name: &str) -> Result<()> {
407    if is_deterministic_mode() && is_strict_mode() {
408        Err(TensorError::invalid_operation_simple(format!(
409            "Operation '{}' is non-deterministic but strict deterministic mode is enabled",
410            operation_name
411        )))
412    } else {
413        // In non-strict mode, just log it
414        if is_deterministic_mode() {
415            eprintln!(
416                "Warning: Operation '{}' may not be fully deterministic",
417                operation_name
418            );
419        }
420        Ok(())
421    }
422}
423
424/// Helper to check if GPU operations should use deterministic algorithms
425pub fn should_use_deterministic_gpu_ops() -> bool {
426    let state = get_global_state();
427    let s = state.lock().expect("lock should not be poisoned");
428    s.enabled && s.prefer_deterministic_algorithms
429}
430
// ============================================================================
// Tests
// ============================================================================
434
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    // Serializes tests that mutate the process-wide deterministic state. `Mutex::new` is
    // `const`, so a plain `static` works without pulling in `lazy_static`.
    static TEST_MUTEX: Mutex<()> = Mutex::new(());

    /// Take the serialization lock, recovering it if an earlier test panicked while
    /// holding it, so one failure does not cascade into every later test.
    fn serial() -> MutexGuard<'static, ()> {
        TEST_MUTEX.lock().unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn test_deterministic_mode_toggle() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        assert!(is_deterministic_mode());

        set_deterministic_mode(false);
        assert!(!is_deterministic_mode());
    }

    #[test]
    fn test_global_seed() {
        let _guard = serial();
        reset_to_defaults();
        set_global_seed(12345);
        assert_eq!(get_global_seed(), 12345);

        set_global_seed(67890);
        assert_eq!(get_global_seed(), 67890);
    }

    #[test]
    fn test_operation_seed_generation() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        set_global_seed(42);

        let seed1 = get_operation_seed("test_op");
        let seed2 = get_operation_seed("test_op");

        // Seeds should be different due to counter increment
        assert_ne!(seed1, seed2);

        // Reset and verify reproducibility
        reset_operation_counter();
        let seed3 = get_operation_seed("test_op");
        assert_eq!(seed1, seed3);
    }

    #[test]
    fn test_operation_seed_uniqueness() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        set_global_seed(42);
        reset_operation_counter();

        let seed_a = get_operation_seed("operation_a");
        let seed_b = get_operation_seed("operation_b");

        // Different operations should get different seeds
        assert_ne!(seed_a, seed_b);
    }

    #[test]
    fn test_snapshot_and_restore() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        set_global_seed(100);

        let _ = get_operation_seed("op1");
        let _ = get_operation_seed("op2");

        let snapshot = get_state_snapshot();
        let seed_before_restore = get_operation_seed("op3");

        restore_state_snapshot(&snapshot);
        let seed_after_restore = get_operation_seed("op3");

        // Restoring rewinds the counter, so op3 is handed the very same seed again.
        assert_eq!(seed_before_restore, seed_after_restore);
    }

    #[test]
    fn test_deterministic_scope() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(false);
        set_global_seed(100);

        {
            let _scope = DeterministicScope::new(200);
            assert!(is_deterministic_mode());
            assert_eq!(get_global_seed(), 200);
        }

        // After scope ends, state should be restored
        assert!(!is_deterministic_mode());
        assert_eq!(get_global_seed(), 100);
    }

    #[test]
    fn test_scope_with_mode_keeps_seed() {
        let _guard = serial();
        reset_to_defaults();
        set_global_seed(5);

        {
            let _scope = DeterministicScope::with_mode(true);
            assert!(is_deterministic_mode());
            assert_eq!(get_global_seed(), 5);
        }

        assert!(!is_deterministic_mode());
        assert_eq!(get_global_seed(), 5);
    }

    #[test]
    fn test_strict_mode() {
        let _guard = serial();
        reset_to_defaults();
        set_strict_mode(true);
        assert!(is_strict_mode());

        set_strict_mode(false);
        assert!(!is_strict_mode());
    }

    #[test]
    fn test_mark_non_deterministic() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        set_strict_mode(false);

        // Should succeed in non-strict mode
        assert!(mark_non_deterministic("test_op").is_ok());

        set_strict_mode(true);
        // Should fail in strict mode
        assert!(mark_non_deterministic("test_op").is_err());
    }

    #[test]
    fn test_config_apply() {
        let _guard = serial();
        reset_to_defaults();
        let config = DeterministicConfig {
            seed: 777,
            strict: true,
            prefer_deterministic: true,
            log_operations: false,
        };

        config.apply();

        assert_eq!(get_global_seed(), 777);
        assert!(is_deterministic_mode());
        assert!(is_strict_mode());
    }

    #[test]
    fn test_config_without_logging_records_nothing() {
        let _guard = serial();
        reset_to_defaults();
        let config = DeterministicConfig {
            log_operations: false,
            ..Default::default()
        };
        config.apply();

        let _ = get_operation_seed("quiet_op");
        assert!(get_operation_log().is_empty());
    }

    #[test]
    fn test_operation_log() {
        let _guard = serial();
        reset_to_defaults();
        enable_operation_logging();
        set_deterministic_mode(true);
        set_global_seed(42);

        let _ = get_operation_seed("op1");
        let _ = get_operation_seed("op2");

        let log = get_operation_log();
        assert_eq!(log.len(), 2);
        assert!(log[0].contains("op1"));
        assert!(log[1].contains("op2"));
    }

    #[test]
    fn test_hash_string_deterministic() {
        // Same string should always produce same hash
        let hash1 = hash_string("test");
        let hash2 = hash_string("test");
        assert_eq!(hash1, hash2);

        // Different strings should produce different hashes
        let hash3 = hash_string("different");
        assert_ne!(hash1, hash3);

        // Pin the algorithm (64-bit FNV-1a) so subseeds stay stable across releases.
        assert_eq!(hash_string(""), 0xcbf29ce484222325);
        assert_eq!(hash_string("a"), 0xaf63dc4c8601ec8c);
    }

    #[test]
    fn test_reproducibility_with_counter_reset() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        set_global_seed(42);

        // First sequence
        reset_operation_counter();
        let seeds1: Vec<u64> = (0..5)
            .map(|i| get_operation_seed(&format!("op{}", i)))
            .collect();

        // Second sequence with same seed
        reset_operation_counter();
        let seeds2: Vec<u64> = (0..5)
            .map(|i| get_operation_seed(&format!("op{}", i)))
            .collect();

        assert_eq!(seeds1, seeds2);
    }

    #[test]
    fn test_verify_reproducibility_restores_state() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(true);
        set_global_seed(9);
        let _ = get_operation_seed("warmup");
        let before = get_state_snapshot();

        let reproducible = verify_reproducibility("seed_pair", || {
            (get_operation_seed("a"), get_operation_seed("b"))
        })
        .expect("verification should not error");
        assert!(reproducible);

        let after = get_state_snapshot();
        assert_eq!(after.global_seed, before.global_seed);
        assert_eq!(after.operation_counter, before.operation_counter);
        assert_eq!(after.enabled, before.enabled);
    }

    #[test]
    fn test_non_deterministic_mode_uses_system_time() {
        let _guard = serial();
        reset_to_defaults();
        set_deterministic_mode(false);

        // Seeds come from the clock here, so their values are not asserted; this only
        // checks that the non-deterministic path runs without panicking.
        let _ = get_operation_seed("test");
        let _ = get_operation_seed("test");
    }
}