Skip to main content

tract_tensorflow/ops/random/
philox.rs

// Port of TensorFlow's Philox-4x32-10 counter-based generator, from
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/random/philox_random.h
2
3use tract_hir::internal::*;
4
/// State of a Philox-4x32-10 counter-based random number generator.
///
/// The whole state is a 64-bit key (set from the seed) and a 128-bit counter.
/// Each output block is computed from the current counter and key, after which
/// the counter is incremented. Deriving `Debug`/`PartialEq` lets callers inspect
/// and compare generator states, which matters when reproducing TensorFlow
/// random streams.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Philox4x32x10 {
    /// Round key, used as two `u32` halves (low half first).
    key: u64,
    /// Block counter, used as four `u32` words (least-significant word first).
    counter: u128,
}
10
/// Multiplies two `u32` values into their full 64-bit product and returns it
/// split as `(high 32 bits, low 32 bits)`.
///
/// The product of two `u32`s always fits in a `u64`, so this never overflows.
fn mul_hilo(a: u32, b: u32) -> (u32, u32) {
    let product = u64::from(a) * u64::from(b);
    let high = (product >> 32) as u32;
    let low = product as u32;
    (high, low)
}
14
15#[allow(non_upper_case_globals)]
16impl Philox4x32x10 {
17    pub fn weird_tf_constructor(seed_lo: u64, seed_hi: u64) -> Philox4x32x10 {
18        let mut ph = Self::for_seed(seed_lo);
19        ph.skip_fast((seed_hi as u128) << 64);
20        ph
21    }
22
23    #[allow(unused)]
24    pub fn for_seeds(seed1: u32, seed2: u32) -> Philox4x32x10 {
25        Self::for_seed(((seed2 as u64) << 32) | seed1 as u64)
26    }
27
28    pub fn for_seed(seed: u64) -> Philox4x32x10 {
29        Philox4x32x10 { key: seed, counter: 0 }
30    }
31
32    pub fn skip_fast(&mut self, n: u128) {
33        self.counter = self.counter.wrapping_add(n);
34    }
35
36    #[allow(unused)]
37    pub fn next_as_u32s(&mut self) -> [u32; 4] {
38        let v = self.next();
39        [v as u32, (v >> 32) as u32, (v >> 64) as u32, (v >> 96) as u32]
40    }
41
42    pub fn next(&mut self) -> u128 {
43        let mut key = self.key;
44        let mut counter = self.counter;
45
46        // 0
47        Self::compute_one(&mut counter, key);
48        Self::raise_key(&mut key);
49        // 1
50        Self::compute_one(&mut counter, key);
51        Self::raise_key(&mut key);
52        // 2
53        Self::compute_one(&mut counter, key);
54        Self::raise_key(&mut key);
55        // 3
56        Self::compute_one(&mut counter, key);
57        Self::raise_key(&mut key);
58        // 4
59        Self::compute_one(&mut counter, key);
60        Self::raise_key(&mut key);
61        // 5
62        Self::compute_one(&mut counter, key);
63        Self::raise_key(&mut key);
64        // 6
65        Self::compute_one(&mut counter, key);
66        Self::raise_key(&mut key);
67        // 7
68        Self::compute_one(&mut counter, key);
69        Self::raise_key(&mut key);
70        // 8
71        Self::compute_one(&mut counter, key);
72        Self::raise_key(&mut key);
73        // 9
74        Self::compute_one(&mut counter, key);
75
76        self.counter = self.counter.wrapping_add(1);
77        counter
78    }
79
80    fn raise_key(key: &mut u64) {
81        const kPhiloxW32A: u32 = 0x9E3779B9;
82        const kPhiloxW32B: u32 = 0xBB67AE85;
83
84        let k0 = *key as u32;
85        let k1 = (*key >> 32) as u32;
86        let k0 = k0.wrapping_add(kPhiloxW32A) as u64;
87        let k1 = k1.wrapping_add(kPhiloxW32B) as u64;
88
89        *key = (k1 << 32) | k0;
90    }
91
92    fn compute_one(counter: &mut u128, key: u64) {
93        const kPhiloxM4x32A: u32 = 0xD2511F53;
94        const kPhiloxM4x32B: u32 = 0xCD9E8D57;
95
96        let c0 = *counter as u32;
97        let c1 = (*counter >> 32) as u32;
98        let c2 = (*counter >> 64) as u32;
99        let c3 = (*counter >> 96) as u32;
100
101        let (hi0, lo0) = mul_hilo(kPhiloxM4x32A, c0);
102        let (hi1, lo1) = mul_hilo(kPhiloxM4x32B, c2);
103
104        let r0 = (hi1 ^ c1 ^ (key as u32)) as u128;
105        let r1 = lo1 as u128;
106        let r2 = (hi0 ^ c3 ^ ((key >> 32) as u32)) as u128;
107        let r3 = lo0 as u128;
108
109        *counter = (r3 << 96) | (r2 << 64) | (r1 << 32) | r0
110    }
111
112    pub fn u32_iter(self) -> impl Iterator<Item = u32> {
113        self.flat_map(|big| {
114            tvec![big as u32, (big >> 32) as u32, (big >> 64) as u32, (big >> 96) as u32]
115                .into_iter()
116        })
117    }
118}
119
120impl Iterator for Philox4x32x10 {
121    type Item = u128;
122    fn next(&mut self) -> Option<u128> {
123        Some(Philox4x32x10::next(self))
124    }
125}
126
#[cfg(test)]
mod test {
    use super::*;

    // checked against https://github.com/dominikwerder/philox
    // https://github.com/dominikwerder/philox/blob/master/src/test.rs#L62
    #[test]
    fn seed() {
        let mut ph = Philox4x32x10::for_seeds(1, 2);
        assert_eq!(ph.next_as_u32s(), [0x598de3a, 0x98d2802e, 0x270f8f9e, 0xeab709d3]);
    }

    #[test]
    fn zeros() {
        let mut ph = Philox4x32x10::for_seeds(0, 0);
        assert_eq!(ph.next_as_u32s(), [0x6627e8d5, 0xe169c58d, 0xbc57ac4c, 0x9b00dbd8]);
    }

    #[test]
    fn ffff() {
        let mut ph = Philox4x32x10::for_seeds(0xffffffff, 0xffffffff);
        ph.skip_fast(0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff);
        assert_eq!(ph.next_as_u32s(), [0x408f276d, 0x41c83b0e, 0xa20bc7c6, 0x6d5451fd]);
    }

    #[test]
    fn x243f6a88() {
        let mut ph = Philox4x32x10::for_seeds(0xa4093822, 0x299f31d0);
        ph.skip_fast(0x0370_7344_1319_8a2e_85a3_08d3_243f_6a88);
        assert_eq!(ph.next_as_u32s(), [0xd16cfe09, 0x94fdcceb, 0x5001e420, 0x24126ea1]);
    }

    // The tests below pin behavior that the known-answer vectors above do not
    // cover directly: seeding layout, counter arithmetic and the iterator APIs.

    #[test]
    fn mul_hilo_splits_full_product() {
        assert_eq!(mul_hilo(0xffff_ffff, 0xffff_ffff), (0xffff_fffe, 0x0000_0001));
        assert_eq!(mul_hilo(0x1_0000, 0x1_0000), (1, 0));
    }

    #[test]
    fn skip_fast_matches_stepping() {
        let mut stepped = Philox4x32x10::for_seeds(1, 2);
        let mut skipped = stepped;
        for _ in 0..5 {
            Philox4x32x10::next(&mut stepped);
        }
        skipped.skip_fast(5);
        assert_eq!(stepped.next_as_u32s(), skipped.next_as_u32s());
    }

    #[test]
    fn counter_wraps_around() {
        let mut wrapped = Philox4x32x10::for_seeds(7, 9);
        wrapped.skip_fast(u128::MAX);
        // Consumes the block at counter u128::MAX; the counter then wraps to 0.
        Philox4x32x10::next(&mut wrapped);
        let mut fresh = Philox4x32x10::for_seeds(7, 9);
        assert_eq!(wrapped.next_as_u32s(), fresh.next_as_u32s());
    }

    #[test]
    fn weird_tf_constructor_layout() {
        // seed_lo is the key: (2 << 32) | 1 is the key produced by for_seeds(1, 2).
        let mut tf = Philox4x32x10::weird_tf_constructor(0x0000_0002_0000_0001, 0);
        assert_eq!(tf.next_as_u32s(), [0x598de3a, 0x98d2802e, 0x270f8f9e, 0xeab709d3]);

        // seed_hi lands in the upper 64 bits of the counter.
        let mut tf = Philox4x32x10::weird_tf_constructor(42, 3);
        let mut manual = Philox4x32x10::for_seed(42);
        manual.skip_fast(3u128 << 64);
        assert_eq!(tf.next_as_u32s(), manual.next_as_u32s());
    }

    #[test]
    fn u32_iter_flattens_blocks_low_word_first() {
        let words: Vec<u32> = Philox4x32x10::for_seeds(1, 2).u32_iter().take(4).collect();
        assert_eq!(words, [0x598de3a, 0x98d2802e, 0x270f8f9e, 0xeab709d3]);
    }

    #[test]
    fn iterator_yields_inherent_blocks() {
        let mut via_iter = Philox4x32x10::for_seeds(0, 0);
        let mut direct = via_iter;
        assert_eq!(Iterator::next(&mut via_iter), Some(Philox4x32x10::next(&mut direct)));
    }
}