//! rust_randomx/lib.rs — safe Rust wrapper around the RandomX proof-of-work FFI bindings.

1#![allow(non_upper_case_globals)]
2#![allow(non_camel_case_types)]
3#![allow(non_snake_case)]
4
5use std::os::raw::{c_ulong, c_void};
6use std::sync::Arc;
7use std::thread;
8
9#[allow(dead_code)]
10mod bindings;
11use bindings::*;
12
/// Compact mining difficulty.
///
/// Encoding (big-endian view of the `u32`): the high byte is the number of
/// leading zero *bytes* a satisfying hash must have, and the low 24 bits are
/// a mantissa ("postfix") placed immediately after those zeros in the target.
#[derive(Debug, Clone, Copy)]
pub struct Difficulty(u32);
15
/// Approximates `a / b` as an `f64` for 128-bit operands.
///
/// Starting at the first byte position where *either* operand is nonzero,
/// up to eight aligned big-endian bytes of each operand are packed into
/// `u64`s, and those are divided as floats. Returns `NaN` when both
/// operands are zero, mirroring `0.0 / 0.0`.
fn div_128(a: u128, b: u128) -> f64 {
    let num = a.to_be_bytes();
    let den = b.to_be_bytes();

    // First byte index at which either operand becomes nonzero; 16 (past the
    // end) when both operands are zero, which leaves both accumulators at 0.
    let start = num
        .iter()
        .zip(den.iter())
        .position(|(&x, &y)| x > 0 || y > 0)
        .unwrap_or(16);

    let mut hi_a = 0u64;
    let mut hi_b = 0u64;
    for i in start..(start + 8).min(16) {
        hi_a = (hi_a << 8) | num[i] as u64;
        hi_b = (hi_b << 8) | den[i] as u64;
    }

    hi_a as f64 / hi_b as f64
}
40
impl Difficulty {
    /// Raw compact `u32` encoding of this difficulty.
    pub fn to_u32(&self) -> u32 {
        self.0
    }
    /// Wraps a raw compact difficulty value.
    pub fn new(d: u32) -> Self {
        Difficulty(d)
    }
    /// Number of leading zero bytes a satisfying hash must have (top byte).
    pub fn zeros(&self) -> usize {
        (self.0 >> 24) as usize
    }
    /// 24-bit mantissa that follows the zero bytes in the target (low 3 bytes).
    pub fn postfix(&self) -> u32 {
        self.0 & 0x00ffffff
    }
    /// Approximate expected number of hash attempts ("work") for this
    /// difficulty: 2^(8*zeros), scaled up by how far the mantissa sits
    /// below the maximum 0xffffff (smaller mantissa = harder target).
    /// NOTE(review): divides by `postfix()`; a zero mantissa yields +inf.
    pub fn powerf(&self) -> f64 {
        2f64.powf(self.zeros() as f64 * 8f64) * (0xffffff as f64 / self.postfix() as f64)
    }
    /// `powerf()` truncated to an integer work estimate (the `as` cast
    /// saturates at `u128::MAX` if the float exceeds the range).
    pub fn power(&self) -> u128 {
        self.powerf() as u128
    }
    /// Builds the difficulty whose work estimate approximates `target`:
    /// start at minimum difficulty (power == 1), keep doubling while the
    /// target is more than 2x away, then apply an exact fractional scale.
    /// NOTE(review): assumes `target >= 1`; a target of 0 feeds 0.0 into
    /// `scale`, whose log2 is -inf — confirm callers never pass 0.
    pub fn from_power(target: u128) -> Self {
        let mut result = Self::new(0x00ffffff);
        loop {
            let mul = target / result.power();
            if mul > 2 {
                result = result.scale(2.0);
            } else {
                result = result.scale(div_128(target, result.power()) as f32);
                break;
            }
        }
        result
    }
    /// Multiplies this difficulty's work estimate by `s`, re-normalizing
    /// the (zeros, mantissa) pair so the mantissa stays within 24 bits.
    pub fn scale(&self, s: f32) -> Self {
        // Whole zero-bytes contributed by the scale factor (each factor of
        // 256 in `s` adds one leading zero byte to the target).
        let mut zeros_add = s.log2() as i32 / 8;
        // Residual (sub-256x) part of the scale factor.
        let rem = s / 256f32.powf(zeros_add as f32);
        // A larger work factor means a SMALLER mantissa (harder target).
        let mut new_postfix = self.postfix() as f32 / rem;

        // Fold any whole zero-bytes hidden in the mantissa's magnitude back
        // into the zero-byte count, keeping the mantissa well-scaled.
        let postfix_power = 0xffffff as f32 / new_postfix;
        let postfix_power_zeros = postfix_power.log2() as i32 / 8;
        zeros_add += postfix_power_zeros;
        new_postfix *= 256f32.powf(postfix_power_zeros as f32);

        // Normalize: the mantissa must fit inside its 24-bit field.
        while new_postfix as u32 > 0xffffff {
            new_postfix /= 256f32;
            zeros_add -= 1;
        }

        // Scaling below the minimum difficulty clamps to the floor encoding.
        if self.zeros() as i32 + zeros_add < 0 {
            return Self::new(0x00ffffff);
        }

        let new_postfix = (new_postfix as u32).to_le_bytes();

        // Reassemble: three mantissa bytes plus the updated zero-byte count
        // in the high byte.
        Difficulty(u32::from_le_bytes([
            new_postfix[0],
            new_postfix[1],
            new_postfix[2],
            (self.zeros() as i32 + zeros_add) as u8,
        ]))
    }
}
102
/// A 32-byte RandomX hash output; also used to encode a difficulty's
/// target hash for comparison (see `From<Difficulty>`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Output([u8; RANDOMX_HASH_SIZE as usize]);
105
106impl From<Difficulty> for Output {
107    fn from(d: Difficulty) -> Self {
108        let mut output = [0u8; 32];
109        let zeros = d.zeros();
110        let postfix = d.postfix();
111        output[zeros..zeros + 3].copy_from_slice(&postfix.to_be_bytes()[1..4]);
112        Self(output)
113    }
114}
115
116impl AsRef<[u8]> for Output {
117    fn as_ref(&self) -> &[u8] {
118        &self.0
119    }
120}
121
122impl Output {
123    pub fn meets_difficulty(&self, d: Difficulty) -> bool {
124        for (a, b) in self.0.iter().zip(Output::from(d).0.iter()) {
125            if a > b {
126                return false;
127            }
128            if a < b {
129                return true;
130            }
131        }
132        true
133    }
134
135    pub fn leading_zeros(&self) -> u32 {
136        let mut zeros = 0;
137        for limb in self.0.iter() {
138            let limb_zeros = limb.leading_zeros();
139            zeros += limb_zeros;
140            if limb_zeros != 8 {
141                break;
142            }
143        }
144        zeros
145    }
146}
147
/// Thin wrapper that allows a raw pointer to be moved into spawned threads.
#[derive(Clone)]
struct Sendable<T>(*mut T);
// SAFETY: NOTE(review) — only used to pass the RandomX cache/dataset pointers
// into dataset-initialization worker threads; assumed sound because each
// worker calls randomx_init_dataset on a disjoint item range — confirm
// against the RandomX API's thread-safety guarantees.
unsafe impl<T> Send for Sendable<T> {}
151
/// Owns the RandomX memory for one key: the light-mode cache, or the
/// full fast-mode dataset (built from the cache, which is then released).
pub struct Context {
    /// Key the cache/dataset was initialized with.
    key: Vec<u8>,
    /// Flags used at allocation time; includes FULL_MEM when `fast` is true.
    flags: randomx_flags,
    /// True when the full dataset was built (fast/mining mode).
    fast: bool,
    /// Light-mode cache; set to null in fast mode after dataset construction.
    cache: *mut randomx_cache,
    /// Fast-mode dataset; null in light mode.
    dataset: *mut randomx_dataset,
}

// SAFETY: NOTE(review) — assumed sound because the pointers are only handed
// to RandomX calls and released once in Drop; the RandomX cache/dataset are
// read-only after initialization — confirm against the RandomX docs.
unsafe impl Send for Context {}
unsafe impl Sync for Context {}
162
163impl Context {
164    pub fn key(&self) -> &[u8] {
165        &self.key
166    }
167    pub fn new(key: &[u8], fast: bool) -> Self {
168        unsafe {
169            let mut flags = randomx_get_flags();
170            let mut cache = randomx_alloc_cache(flags);
171            randomx_init_cache(cache, key.as_ptr() as *const c_void, key.len());
172            let mut dataset = std::ptr::null_mut();
173            if fast {
174                flags |= randomx_flags_RANDOMX_FLAG_FULL_MEM;
175                dataset = randomx_alloc_dataset(flags);
176                let num_threads = thread::available_parallelism().expect("Failed to determine available parallelism").get();
177                let length = randomx_dataset_item_count() as usize / num_threads;
178                let mut threads = Vec::new();
179                for i in 0..num_threads {
180                    let sendable_cache = Sendable(cache);
181                    let sendable_dataset = Sendable(dataset);
182                    threads.push(thread::spawn(move || {
183                        let cache = sendable_cache.clone();
184                        let dataset = sendable_dataset.clone();
185                        randomx_init_dataset(
186                            dataset.0,
187                            cache.0,
188                            (i * length) as c_ulong,
189                            length as c_ulong,
190                        );
191                    }));
192                }
193                for t in threads {
194                    t.join()
195                        .expect("Error while initializing the RandomX dataset!");
196                }
197
198                randomx_release_cache(cache);
199                cache = std::ptr::null_mut();
200            }
201
202            Self {
203                key: key.to_vec(),
204                flags,
205                fast,
206                cache,
207                dataset,
208            }
209        }
210    }
211}
212
213impl Drop for Context {
214    fn drop(&mut self) {
215        unsafe {
216            if self.fast {
217                randomx_release_dataset(self.dataset);
218            } else {
219                randomx_release_cache(self.cache);
220            }
221        }
222    }
223}
224
/// A RandomX virtual machine plus the context that keeps its memory alive.
pub struct Hasher {
    /// Held so the cache/dataset the VM points at outlives the VM.
    context: Arc<Context>,
    /// Owned RandomX VM handle; destroyed in `Drop`.
    vm: *mut randomx_vm,
}

// SAFETY: NOTE(review) — `Sync` here allows concurrent `hash(&self)` calls
// on the same VM from multiple threads; RandomX VM calls mutate internal
// scratchpad state, so this looks unsound unless callers externally
// serialize access — confirm against the RandomX thread-safety notes.
unsafe impl Send for Hasher {}
unsafe impl Sync for Hasher {}
232
233impl Hasher {
234    pub fn new(context: Arc<Context>) -> Self {
235        unsafe {
236            Hasher {
237                context: Arc::clone(&context),
238                vm: randomx_create_vm(context.flags, context.cache, context.dataset),
239            }
240        }
241    }
242    pub fn update(&mut self, context: Arc<Context>) {
243        unsafe {
244            if context.fast {
245                randomx_vm_set_dataset(self.vm, context.dataset);
246            } else {
247                randomx_vm_set_cache(self.vm, context.cache);
248            }
249        }
250        self.context = context;
251    }
252    pub fn context(&self) -> &Context {
253        &self.context
254    }
255
256    pub fn hash(&self, inp: &[u8]) -> Output {
257        let mut hash = [0u8; RANDOMX_HASH_SIZE as usize];
258        unsafe {
259            randomx_calculate_hash(
260                self.vm,
261                inp.as_ptr() as *const c_void,
262                inp.len(),
263                hash.as_mut_ptr() as *mut c_void,
264            );
265        }
266        Output(hash)
267    }
268
269    pub fn hash_first(&mut self, inp: &[u8]) {
270        unsafe {
271            randomx_calculate_hash_first(self.vm, inp.as_ptr() as *const c_void, inp.len());
272        }
273    }
274    pub fn hash_next(&mut self, next_inp: &[u8]) -> Output {
275        let mut hash = [0u8; RANDOMX_HASH_SIZE as usize];
276        unsafe {
277            randomx_calculate_hash_next(
278                self.vm,
279                next_inp.as_ptr() as *const c_void,
280                next_inp.len(),
281                hash.as_mut_ptr() as *mut c_void,
282            );
283        }
284        Output(hash)
285    }
286    pub fn hash_last(&mut self) -> Output {
287        let mut hash = [0u8; RANDOMX_HASH_SIZE as usize];
288        unsafe {
289            randomx_calculate_hash_last(self.vm, hash.as_mut_ptr() as *mut c_void);
290        }
291        Output(hash)
292    }
293}
294
impl Drop for Hasher {
    fn drop(&mut self) {
        // SAFETY: `vm` was created by randomx_create_vm in `Hasher::new`
        // and is destroyed exactly once here.
        unsafe {
            randomx_destroy_vm(self.vm);
        }
    }
}
302
#[cfg(test)]
mod tests {
    use super::*;

    // Example key/input pair (trailing NUL byte included in both).
    const KEY: &[u8] = b"RandomX example key\x00";
    const INPUT: &[u8] = b"RandomX example input\x00";
    // Expected hash of INPUT under KEY — NOTE(review): presumably the
    // reference vector from the RandomX repository; confirm against its
    // published test data.
    const EXPECTED: Output = Output([
        138, 72, 229, 249, 219, 69, 171, 121, 217, 8, 5, 116, 196, 216, 25, 84, 254, 106, 198, 56,
        66, 33, 74, 255, 115, 194, 68, 178, 99, 48, 183, 201,
    ]);

    // Light (cache-only) mode must reproduce the reference hash.
    #[test]
    fn test_slow_hasher() {
        let slow = Hasher::new(Arc::new(Context::new(KEY, false)));
        assert_eq!(slow.hash(INPUT), EXPECTED);
    }

    // Fast (full-dataset) mode must produce the same hash as light mode.
    #[test]
    fn test_fast_hasher() {
        let fast = Hasher::new(Arc::new(Context::new(KEY, true)));
        assert_eq!(fast.hash(INPUT), EXPECTED);
    }

    // Scaling by 3 three times multiplies the work estimate by 27, and
    // scaling back by 1/3 three times restores the original power.
    #[test]
    fn test_difficulty_scaling() {
        let d1 = Difficulty::new(0x011fffff);
        let d2 = d1.scale(3f32).scale(3f32).scale(3f32);
        let d3 = d2.scale(1f32 / 3f32).scale(1f32 / 3f32).scale(1f32 / 3f32);
        assert_eq!(d1.power(), 2048);
        assert_eq!(d2.power(), 2048 * 27);
        assert_eq!(d3.power(), 2048);
    }
}
335}