tract_tensorflow/ops/random/
philox.rs1use tract_hir::internal::*;
4
/// State of a Philox4x32-10 counter-based random generator.
///
/// `key` packs the two 32-bit key words (word 0 in the low 32 bits) and `counter` packs the
/// four 32-bit counter words (word 0 in the lowest 32 bits). The generator is `Copy`: a copy
/// is an independent generator that will replay the same stream.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Philox4x32x10 {
    key: u64,
    counter: u128,
}
10
/// Full 32x32 -> 64-bit multiplication of `a` and `b`, returned as `(high word, low word)`.
fn mul_hilo(a: u32, b: u32) -> (u32, u32) {
    let product = u64::from(a) * u64::from(b);
    let high = (product >> 32) as u32;
    let low = product as u32;
    (high, low)
}
14
15#[allow(non_upper_case_globals)]
16impl Philox4x32x10 {
17 pub fn weird_tf_constructor(seed_lo: u64, seed_hi: u64) -> Philox4x32x10 {
18 let mut ph = Self::for_seed(seed_lo);
19 ph.skip_fast((seed_hi as u128) << 64);
20 ph
21 }
22
23 #[allow(unused)]
24 pub fn for_seeds(seed1: u32, seed2: u32) -> Philox4x32x10 {
25 Self::for_seed(((seed2 as u64) << 32) | seed1 as u64)
26 }
27
28 pub fn for_seed(seed: u64) -> Philox4x32x10 {
29 Philox4x32x10 { key: seed, counter: 0 }
30 }
31
32 pub fn skip_fast(&mut self, n: u128) {
33 self.counter = self.counter.wrapping_add(n);
34 }
35
36 #[allow(unused)]
37 pub fn next_as_u32s(&mut self) -> [u32; 4] {
38 let v = self.next();
39 [v as u32, (v >> 32) as u32, (v >> 64) as u32, (v >> 96) as u32]
40 }
41
42 pub fn next(&mut self) -> u128 {
43 let mut key = self.key;
44 let mut counter = self.counter;
45
46 Self::compute_one(&mut counter, key);
48 Self::raise_key(&mut key);
49 Self::compute_one(&mut counter, key);
51 Self::raise_key(&mut key);
52 Self::compute_one(&mut counter, key);
54 Self::raise_key(&mut key);
55 Self::compute_one(&mut counter, key);
57 Self::raise_key(&mut key);
58 Self::compute_one(&mut counter, key);
60 Self::raise_key(&mut key);
61 Self::compute_one(&mut counter, key);
63 Self::raise_key(&mut key);
64 Self::compute_one(&mut counter, key);
66 Self::raise_key(&mut key);
67 Self::compute_one(&mut counter, key);
69 Self::raise_key(&mut key);
70 Self::compute_one(&mut counter, key);
72 Self::raise_key(&mut key);
73 Self::compute_one(&mut counter, key);
75
76 self.counter = self.counter.wrapping_add(1);
77 counter
78 }
79
80 fn raise_key(key: &mut u64) {
81 const kPhiloxW32A: u32 = 0x9E3779B9;
82 const kPhiloxW32B: u32 = 0xBB67AE85;
83
84 let k0 = *key as u32;
85 let k1 = (*key >> 32) as u32;
86 let k0 = k0.wrapping_add(kPhiloxW32A) as u64;
87 let k1 = k1.wrapping_add(kPhiloxW32B) as u64;
88
89 *key = (k1 << 32) | k0;
90 }
91
92 fn compute_one(counter: &mut u128, key: u64) {
93 const kPhiloxM4x32A: u32 = 0xD2511F53;
94 const kPhiloxM4x32B: u32 = 0xCD9E8D57;
95
96 let c0 = *counter as u32;
97 let c1 = (*counter >> 32) as u32;
98 let c2 = (*counter >> 64) as u32;
99 let c3 = (*counter >> 96) as u32;
100
101 let (hi0, lo0) = mul_hilo(kPhiloxM4x32A, c0);
102 let (hi1, lo1) = mul_hilo(kPhiloxM4x32B, c2);
103
104 let r0 = (hi1 ^ c1 ^ (key as u32)) as u128;
105 let r1 = lo1 as u128;
106 let r2 = (hi0 ^ c3 ^ ((key >> 32) as u32)) as u128;
107 let r3 = lo0 as u128;
108
109 *counter = (r3 << 96) | (r2 << 64) | (r1 << 32) | r0
110 }
111
112 pub fn u32_iter(self) -> impl Iterator<Item = u32> {
113 self.flat_map(|big| {
114 tvec![big as u32, (big >> 32) as u32, (big >> 64) as u32, (big >> 96) as u32]
115 .into_iter()
116 })
117 }
118}
119
120impl Iterator for Philox4x32x10 {
121 type Item = u128;
122 fn next(&mut self) -> Option<u128> {
123 Some(Philox4x32x10::next(self))
124 }
125}
126
#[cfg(test)]
mod test {
    use super::*;

    /// Seeds a generator with key words `(seed1, seed2)`, advances its counter by `skip`,
    /// then asserts that the next block equals `expected` (words lowest first).
    #[track_caller]
    fn assert_first_block(seed1: u32, seed2: u32, skip: u128, expected: [u32; 4]) {
        let mut ph = Philox4x32x10::for_seeds(seed1, seed2);
        ph.skip_fast(skip);
        assert_eq!(ph.next_as_u32s(), expected);
    }

    #[test]
    fn seed() {
        assert_first_block(1, 2, 0, [0x598de3a, 0x98d2802e, 0x270f8f9e, 0xeab709d3]);
    }

    // NOTE(review): the three vectors below look like the philox4x32-10 known-answer tests
    // distributed with Random123 — confirm against its kat_vectors file.
    #[test]
    fn zeros() {
        assert_first_block(0, 0, 0, [0x6627e8d5, 0xe169c58d, 0xbc57ac4c, 0x9b00dbd8]);
    }

    #[test]
    fn ffff() {
        assert_first_block(
            0xffffffff,
            0xffffffff,
            0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff,
            [0x408f276d, 0x41c83b0e, 0xa20bc7c6, 0x6d5451fd],
        );
    }

    #[test]
    fn x243f6a88() {
        assert_first_block(
            0xa4093822,
            0x299f31d0,
            0x0370_7344_1319_8a2e_85a3_08d3_243f_6a88,
            [0xd16cfe09, 0x94fdcceb, 0x5001e420, 0x24126ea1],
        );
    }
}