Skip to main content

sp1_core_executor/
hook.rs

1use core::fmt::Debug;
2
3use std::sync::{Arc, RwLock, RwLockWriteGuard};
4
5use hashbrown::HashMap;
6use sp1_curves::{
7    edwards::ed25519::ed25519_sqrt, params::FieldParameters, BigUint, Integer, One, Zero,
8};
9
/// A runtime hook, wrapped in a smart pointer.
///
/// The `Arc` lets one hook instance be shared (for example between clones of a
/// `HookRegistry`), and the `RwLock` provides the exclusive `&mut` access that
/// `Hook::invoke_hook` requires.
pub type BoxedHook<'a> = Arc<RwLock<dyn Hook + Send + Sync + 'a>>;
12
13pub use sp1_primitives::consts::fd::*;
14
/// A runtime hook. May be called during execution by writing to a specified file descriptor,
/// accepting and returning arbitrary data.
///
/// Any `FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>>` implements this trait via a blanket
/// implementation in this module, so plain functions and closures can be used as hooks.
pub trait Hook {
    /// Invoke the runtime hook with a standard environment and arbitrary data.
    ///
    /// `buf` is the raw input handed to the hook. The returned list of byte vectors is the
    /// hook's output; each hook defines its own input and output encoding.
    fn invoke_hook(&mut self, env: HookEnv, buf: &[u8]) -> Vec<Vec<u8>>;
}
22
23impl<F: FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>>> Hook for F {
24    /// Invokes the function `self` as a hook.
25    fn invoke_hook(&mut self, env: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
26        self(env, buf)
27    }
28}
29
30/// Wrap a function in a smart pointer so it may be placed in a `HookRegistry`.
31///
32/// Note: the Send + Sync requirement may be logically extraneous. Requires further investigation.
33pub fn hookify<'a>(
34    f: impl FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>> + Send + Sync + 'a,
35) -> BoxedHook<'a> {
36    Arc::new(RwLock::new(f))
37}
38
/// A registry of hooks to call, indexed by the file descriptors through which they are accessed.
///
/// Cloning the registry clones the `Arc` handles in the table, so clones share the same hook
/// instances rather than copying them.
#[derive(Clone)]
pub struct HookRegistry<'a> {
    /// Table of registered hooks, keyed by file descriptor. Prefer using `Runtime::hook`,
    /// `Runtime::hook_env`, and `HookRegistry::get` over interacting with this field directly.
    pub(crate) table: HashMap<u32, BoxedHook<'a>>,
}
46
47impl<'a> HookRegistry<'a> {
48    /// Create a default [`HookRegistry`].
49    #[must_use]
50    pub fn new() -> Self {
51        HookRegistry::default()
52    }
53
54    /// Create an empty [`HookRegistry`].
55    #[must_use]
56    pub fn empty() -> Self {
57        Self { table: HashMap::default() }
58    }
59
60    /// Get a hook with exclusive write access, if it exists.
61    ///
62    /// Note: This function should not be called in async contexts, unless you know what you are
63    /// doing.
64    #[must_use]
65    pub fn get(&self, fd: u32) -> Option<RwLockWriteGuard<'_, dyn Hook + Send + Sync + 'a>> {
66        // Calling `.unwrap()` panics on a poisoned lock. Should never happen normally.
67        self.table.get(&fd).map(|x| x.write().unwrap())
68    }
69}
70
71impl Default for HookRegistry<'_> {
72    fn default() -> Self {
73        // When `LazyCell` gets stabilized (1.81.0), we can use it to avoid unnecessary allocations.
74        let table = HashMap::from([
75            // Note: To ensure any `fd` value is synced with `zkvm/precompiles/src/io.rs`,
76            // add an assertion to the test `hook_fds_match` below.
77            (FD_ECRECOVER_HOOK, hookify(hook_ecrecover)),
78            (FD_EDDECOMPRESS, hookify(hook_ed_decompress)),
79            (FD_RSA_MUL_MOD, hookify(hook_rsa_mul_mod)),
80            (FD_BLS12_381_SQRT, hookify(bls::hook_bls12_381_sqrt)),
81            (FD_BLS12_381_INVERSE, hookify(bls::hook_bls12_381_inverse)),
82            (FD_FP_SQRT, hookify(fp_ops::hook_fp_sqrt)),
83            (FD_FP_INV, hookify(fp_ops::hook_fp_inverse)),
84        ]);
85
86        Self { table }
87    }
88}
89
90impl Debug for HookRegistry<'_> {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        let mut keys = self.table.keys().collect::<Vec<_>>();
93        keys.sort_unstable();
94        f.debug_struct("HookRegistry")
95            .field(
96                "table",
97                &format_args!("{{{} hooks registered at {:?}}}", self.table.len(), keys),
98            )
99            .finish()
100    }
101}
102
/// Environment that a hook may read from.
///
/// It is passed by value to every `Hook::invoke_hook` call. None of the hooks defined in this
/// module read it.
///
/// Note: This struct is currently empty but exists for backwards compatibility
/// and potential future extensions.
pub struct HookEnv {}
108
109/// The hook for the `ecrecover` patches.
110///
111/// The input should be of the form [(`curve_id_u8` | `r_is_y_odd_u8` << 7) || `r` || `alpha`]
112/// where:
113/// * `curve_id` is 1 for secp256k1 and 2 for secp256r1
114/// * `r_is_y_odd` is 0 if r is even and 1 if r is is odd
115/// * r is the x-coordinate of the point, which should be 32 bytes,
116/// * alpha := r * r * r * (a * r) + b, which should be 32 bytes.
117///
118/// Returns vec![vec![1], `y`, `r_inv`] if the point is decompressable
119/// and vec![vec![0],`nqr_hint`] if not.
120#[must_use]
121pub fn hook_ecrecover(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
122    assert!(buf.len() == 64 + 1, "ecrecover should have length 65");
123
124    let curve_id = buf[0] & 0b0111_1111;
125    let r_is_y_odd = buf[0] & 0b1000_0000 != 0;
126
127    let r_bytes: [u8; 32] = buf[1..33].try_into().unwrap();
128    let alpha_bytes: [u8; 32] = buf[33..65].try_into().unwrap();
129
130    match curve_id {
131        1 => ecrecover::handle_secp256k1(r_bytes, alpha_bytes, r_is_y_odd),
132        2 => ecrecover::handle_secp256r1(r_bytes, alpha_bytes, r_is_y_odd),
133        _ => unimplemented!("Unsupported curve id: {}", curve_id),
134    }
135}
136
/// Curve-specific handlers for the `hook_ecrecover` hook.
mod ecrecover {
    use sp1_curves::{k256, p256};

    /// A non-quadratic residue (the value 3) used for both the secp256k1 and secp256r1 base
    /// fields.
    ///
    /// NOTE(review): the value is stored in the last byte, i.e. a big endian encoding; this
    /// assumes `FieldElement::from_bytes` reads big endian bytes for both curves. Confirm.
    const NQR: [u8; 32] = {
        let mut nqr = [0; 32];
        nqr[31] = 3;
        nqr
    };

    /// Computes the decompression hint for a secp256k1 point with x-coordinate `r`.
    ///
    /// `alpha` is the curve equation's right-hand side evaluated at `r`; `r_y_is_odd` selects
    /// the parity of the returned y-coordinate.
    ///
    /// Returns `[1], y, r^-1` (with `r^-1` computed in the scalar field) when `alpha` is a
    /// square, and `[0], sqrt(alpha * NQR)` otherwise.
    ///
    /// # Panics
    /// - If `r` or `alpha` is not a valid field element encoding.
    /// - If `alpha` is zero (an `assert!`, so this is also checked in release builds).
    /// - If `r`'s bytes are not a valid scalar encoding.
    /// - If `alpha * NQR` has no square root.
    ///
    /// NOTE(review): the zero check on `alpha` is an `assert!` here but a `debug_assert!` in
    /// `handle_secp256r1`. Confirm which behavior is intended.
    pub(super) fn handle_secp256k1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<Vec<u8>> {
        use k256::{
            elliptic_curve::ff::PrimeField, FieldElement as K256FieldElement, Scalar as K256Scalar,
        };

        let r = K256FieldElement::from_bytes(&r.into()).unwrap();
        debug_assert!(!bool::from(r.is_zero()), "r should not be zero");

        let alpha = K256FieldElement::from_bytes(&alpha.into()).unwrap();
        assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");

        // Normalize the y-coordinate always to be consistent.
        if let Some(mut y_coord) = alpha.sqrt().into_option().map(|y| y.normalize()) {
            // Reinterpret the x-coordinate's bytes as a scalar and invert it for the hint.
            let r = K256Scalar::from_repr(r.to_bytes()).unwrap();
            let r_inv = r.invert().expect("Non zero r scalar");

            // Switch to the other root (-y) when the parity does not match the request, then
            // renormalize so the parity and byte encoding stay consistent.
            if r_y_is_odd != bool::from(y_coord.is_odd()) {
                y_coord = y_coord.negate(1);
                y_coord = y_coord.normalize();
            }

            vec![vec![1], y_coord.to_bytes().to_vec(), r_inv.to_bytes().to_vec()]
        } else {
            // `alpha` is a non-square, so `alpha * NQR` is expected to be a square; its root is
            // hinted back so the failure case can be constrained.
            let nqr_field = K256FieldElement::from_bytes(&NQR.into()).unwrap();
            let qr = alpha * nqr_field;
            let root = qr.sqrt().expect("if alpha is not a square, then qr should be a square");

            vec![vec![0], root.to_bytes().to_vec()]
        }
    }

    /// Computes the decompression hint for a secp256r1 point with x-coordinate `r`.
    ///
    /// Same contract as `handle_secp256k1`: returns `[1], y, r^-1` when `alpha` is a square
    /// (with `y` of the requested parity), and `[0], sqrt(alpha * NQR)` otherwise.
    ///
    /// # Panics
    /// - If `r` or `alpha` is not a valid field element encoding.
    /// - If `r`'s bytes are not a valid scalar encoding.
    /// - If `alpha * NQR` has no square root.
    ///
    /// The zero checks on `r` and `alpha` are `debug_assert!`s and only run in debug builds.
    pub(super) fn handle_secp256r1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<Vec<u8>> {
        use p256::{
            elliptic_curve::ff::PrimeField, FieldElement as P256FieldElement, Scalar as P256Scalar,
        };

        let r = P256FieldElement::from_bytes(&r.into()).unwrap();
        debug_assert!(!bool::from(r.is_zero()), "r should not be zero");

        let alpha = P256FieldElement::from_bytes(&alpha.into()).unwrap();
        debug_assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");

        if let Some(mut y_coord) = alpha.sqrt().into_option() {
            // Reinterpret the x-coordinate's bytes as a scalar and invert it for the hint.
            let r = P256Scalar::from_repr(r.to_bytes()).unwrap();
            let r_inv = r.invert().expect("Non zero r scalar");

            // Switch to the other root (-y) when the parity does not match the request.
            if r_y_is_odd != bool::from(y_coord.is_odd()) {
                y_coord = -y_coord;
            }

            vec![vec![1], y_coord.to_bytes().to_vec(), r_inv.to_bytes().to_vec()]
        } else {
            // `alpha` is a non-square, so `alpha * NQR` is expected to be a square; its root is
            // hinted back so the failure case can be constrained.
            let nqr_field = P256FieldElement::from_bytes(&NQR.into()).unwrap();
            let qr = alpha * nqr_field;
            let root = qr.sqrt().expect("if alpha is not a square, then qr should be a square");

            vec![vec![0], root.to_bytes().to_vec()]
        }
    }
}
207
208/// Field operation hooks for computing inverses and square roots.
209pub mod fp_ops {
210    use super::{pad_to_be, BigUint, HookEnv, One, Zero};
211
212    /// Compute the inverse of a field element.
213    ///
214    /// # Arguments:
215    /// * `buf` - The buffer containing the data needed to compute the inverse.
216    ///     - [ len || Element || Modulus ]
217    ///     - len is the u32 length of the element and modulus in big endian.
218    ///     - Element is the field element to compute the inverse of, interpreted as a big endian
219    ///       integer of `len` bytes.
220    ///
221    /// # Returns:
222    /// A single 32 byte vector containing the inverse.
223    ///
224    /// # Panics:
225    /// - If the buffer length is not valid.
226    /// - If the element is zero.
227    #[must_use]
228    pub fn hook_fp_inverse(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
229        let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
230
231        assert!(buf.len() == 4 + 2 * len, "FpOp: Invalid buffer length");
232
233        let buf = &buf[4..];
234        let element = BigUint::from_bytes_be(&buf[..len]);
235        let modulus = BigUint::from_bytes_be(&buf[len..2 * len]);
236
237        assert!(!element.is_zero(), "FpOp: Inverse called with zero");
238
239        let inverse = element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
240
241        vec![pad_to_be(&inverse, len)]
242    }
243
244    /// Compute the square root of a field element.
245    ///
246    /// # Arguments:
247    /// * `buf` - The buffer containing the data needed to compute the square root.
248    ///     - [ len || Element || Modulus || NQR ]
249    ///     - len is the length of the element, modulus, and nqr in big endian.
250    ///     - Element is the field element to compute the square root of, interpreted as a big
251    ///       endian integer of `len` bytes.
252    ///     - Modulus is the modulus of the field, interpreted as a big endian integer of `len`
253    ///       bytes.
254    ///     - NQR is the non-quadratic residue of the field, interpreted as a big endian integer of
255    ///       `len` bytes.
256    ///
257    /// # Assumptions
258    /// - NQR is a non-quadratic residue of the field.
259    ///
260    /// # Returns:
261    /// [ `status_u8` || `root_bytes` ]
262    ///
263    /// If the status is 0, this is the root of NQR * element.
264    /// If the status is 1, this is the root of element.
265    ///
266    /// # Panics:
267    /// - If the buffer length is not valid.
268    /// - If the element is not less than the modulus.
269    /// - If the nqr is not less than the modulus.
270    /// - If the element is zero.
271    #[must_use]
272    pub fn hook_fp_sqrt(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
273        let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
274
275        assert!(buf.len() == 4 + 3 * len, "FpOp: Invalid buffer length");
276
277        let buf = &buf[4..];
278        let element = BigUint::from_bytes_be(&buf[..len]);
279        let modulus = BigUint::from_bytes_be(&buf[len..2 * len]);
280        let nqr = BigUint::from_bytes_be(&buf[2 * len..3 * len]);
281
282        assert!(
283            element < modulus,
284            "Element is not less than modulus, the hook only accepts canonical representations"
285        );
286        assert!(
287            nqr < modulus,
288            "NQR is zero or non-canonical, the hook only accepts canonical representations"
289        );
290
291        // The sqrt of zero is zero.
292        if element.is_zero() {
293            return vec![vec![1], vec![0; len]];
294        }
295
296        // Compute the square root of the element using the general Tonelli-Shanks algorithm.
297        // The implementation can be used for any field as it is field-agnostic.
298        if let Some(root) = sqrt_fp(&element, &modulus, &nqr) {
299            vec![vec![1], pad_to_be(&root, len)]
300        } else {
301            let qr = (&nqr * &element) % &modulus;
302            let root = sqrt_fp(&qr, &modulus, &nqr).unwrap();
303
304            vec![vec![0], pad_to_be(&root, len)]
305        }
306    }
307
308    /// Compute the square root of a field element for some modulus.
309    ///
310    /// Requires a known non-quadratic residue of the field.
311    fn sqrt_fp(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
312        // If the prime field is of the form p = 3 mod 4, and `x` is a quadratic residue modulo `p`,
313        // then one square root of `x` is given by `x^(p+1 / 4) mod p`.
314        if modulus % BigUint::from(4u64) == BigUint::from(3u64) {
315            let maybe_root =
316                element.modpow(&((modulus + BigUint::from(1u64)) / BigUint::from(4u64)), modulus);
317
318            return Some(maybe_root).filter(|root| root * root % modulus == *element);
319        }
320
321        tonelli_shanks(element, modulus, nqr)
322    }
323
324    /// Compute the square root of a field element using the Tonelli-Shanks algorithm.
325    ///
326    /// # Arguments:
327    /// * `element` - The field element to compute the square root of.
328    /// * `modulus` - The modulus of the field.
329    /// * `nqr` - The non-quadratic residue of the field.
330    ///
331    /// # Assumptions:
332    /// - The element is a quadratic residue modulo the modulus.
333    ///
334    /// Ref: <https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm>
335    #[allow(clippy::many_single_char_names)]
336    fn tonelli_shanks(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
337        // First, compute the Legendre symbol of the element.
338        // If the symbol is not 1, then the element is not a quadratic residue.
339        if legendre_symbol(element, modulus) != BigUint::one() {
340            return None;
341        }
342
343        // Find the values of Q and S such that modulus - 1 = Q * 2^S.
344        let mut s = BigUint::zero();
345        let mut q = modulus - BigUint::one();
346        while &q % &BigUint::from(2u64) == BigUint::zero() {
347            s += BigUint::from(1u64);
348            q /= BigUint::from(2u64);
349        }
350
351        let z = nqr;
352        let mut c = z.modpow(&q, modulus);
353        let mut r = element.modpow(&((&q + BigUint::from(1u64)) / BigUint::from(2u64)), modulus);
354        let mut t = element.modpow(&q, modulus);
355        let mut m = s;
356
357        while t != BigUint::one() {
358            let mut i = BigUint::zero();
359            let mut tt = t.clone();
360            while tt != BigUint::one() {
361                tt = &tt * &tt % modulus;
362                i += BigUint::from(1u64);
363
364                if i == m {
365                    return None;
366                }
367            }
368
369            let b_pow =
370                BigUint::from(2u64).pow((&m - &i - BigUint::from(1u64)).try_into().unwrap());
371            let b = c.modpow(&b_pow, modulus);
372
373            r = &r * &b % modulus;
374            c = &b * &b % modulus;
375            t = &t * &c % modulus;
376            m = i;
377        }
378
379        Some(r)
380    }
381
382    /// Compute the Legendre symbol of a field element.
383    ///
384    /// This indicates if the element is a quadratic in the prime field.
385    ///
386    /// Ref: <https://en.wikipedia.org/wiki/Legendre_symbol>
387    fn legendre_symbol(element: &BigUint, modulus: &BigUint) -> BigUint {
388        assert!(!element.is_zero(), "FpOp: Legendre symbol of zero called.");
389
390        element.modpow(&((modulus - BigUint::one()) / BigUint::from(2u64)), modulus)
391    }
392
393    #[cfg(test)]
394    mod test {
395        use super::*;
396        use std::str::FromStr;
397
398        #[test]
399        fn test_legendre_symbol() {
400            // The modulus of the secp256k1 base field.
401            let modulus = BigUint::from_str(
402                "115792089237316195423570985008687907853269984665640564039457584007908834671663",
403            )
404            .unwrap();
405            let neg_1 = &modulus - BigUint::one();
406
407            let fixtures = [
408                (BigUint::from(4u64), BigUint::from(1u64)),
409                (BigUint::from(2u64), BigUint::from(1u64)),
410                (BigUint::from(3u64), neg_1.clone()),
411            ];
412
413            for (element, expected) in fixtures {
414                let result = legendre_symbol(&element, &modulus);
415                assert_eq!(result, expected);
416            }
417        }
418
419        #[test]
420        fn test_tonelli_shanks() {
421            // The modulus of the secp256k1 base field.
422            let p = BigUint::from_str(
423                "115792089237316195423570985008687907853269984665640564039457584007908834671663",
424            )
425            .unwrap();
426
427            let nqr = BigUint::from_str("3").unwrap();
428
429            let large_element = &p - BigUint::from(u16::MAX);
430            let square = &large_element * &large_element % &p;
431
432            let fixtures = [
433                (BigUint::from(2u64), true),
434                (BigUint::from(3u64), false),
435                (BigUint::from(4u64), true),
436                (square, true),
437            ];
438
439            for (element, expected) in fixtures {
440                let result = tonelli_shanks(&element, &p, &nqr);
441                if expected {
442                    assert!(result.is_some());
443
444                    let result = result.unwrap();
445                    assert!((&result * &result) % &p == element);
446                } else {
447                    assert!(result.is_none());
448                }
449            }
450        }
451    }
452}
453
454/// Checks if a compressed Edwards point can be decompressed.
455///
456/// # Arguments
457/// * `env` - The environment in which the hook is invoked.
458/// * `buf` - The buffer containing the compressed Edwards point.
459///    - The compressed Edwards point is 32 bytes.
460///    - The high bit of the last byte is the sign bit.
461///
462/// Returns vec![vec![1]] if the point is decompressable.
463/// Returns vec![vec![0], `v_inv`, `nqr_hint`] if the point is not decompressable.
464///
465/// WARNING: This function merely hints at the validity of the compressed point. These values must
466/// be constrained by the zkVM for correctness.
467#[must_use]
468pub fn hook_ed_decompress(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
469    const NQR_CURVE_25519: u8 = 2;
470    let modulus = sp1_curves::edwards::ed25519::Ed25519BaseField::modulus();
471
472    let mut bytes: [u8; 32] = buf[..32].try_into().unwrap();
473    // Mask the sign bit.
474    bytes[31] &= 0b0111_1111;
475
476    // The AIR asserts canon inputs, so hint here if it cant be satisfied.
477    let y = BigUint::from_bytes_le(&bytes);
478    if y >= modulus {
479        return vec![vec![0]];
480    }
481
482    let v = BigUint::from_bytes_le(&buf[32..]);
483    // This is computed as dy^2 - 1
484    // so it should always be in the field.
485    assert!(v < modulus, "V is not a valid field element");
486
487    // For a point to be decompressable, (yy - 1) / (yy * d + 1) must be a quadratic residue.
488    let v_inv = v.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
489    let u = (&y * &y + &modulus - BigUint::one()) % &modulus;
490    let u_div_v = (&u * &v_inv) % &modulus;
491
492    // Note: Our sqrt impl doesnt care about canon representation,
493    // however we have already checked that were less than the modulus.
494    if ed25519_sqrt(&u_div_v).is_some() {
495        vec![vec![1]]
496    } else {
497        let qr = (u_div_v * NQR_CURVE_25519) % &modulus;
498        let root = ed25519_sqrt(&qr).unwrap();
499
500        // Pad the results, since this may not be a full 32 bytes.
501        let v_inv_bytes = v_inv.to_bytes_le();
502        let mut v_inv_padded = [0_u8; 32];
503        v_inv_padded[..v_inv_bytes.len()].copy_from_slice(&v_inv.to_bytes_le());
504
505        let root_bytes = root.to_bytes_le();
506        let mut root_padded = [0_u8; 32];
507        root_padded[..root_bytes.len()].copy_from_slice(&root.to_bytes_le());
508
509        vec![vec![0], v_inv_padded.to_vec(), root_padded.to_vec()]
510    }
511}
512
513/// BLS12-381 field operation hooks.
514pub mod bls {
515    use super::{pad_to_be, BigUint, HookEnv};
516    use sp1_curves::{params::FieldParameters, weierstrass::bls12_381::Bls12381BaseField, Zero};
517
518    /// A non-quadratic residue for the `12_381` base field in big endian.
519    pub const NQR_BLS12_381: [u8; 48] = {
520        let mut nqr = [0; 48];
521        nqr[47] = 2;
522        nqr
523    };
524
525    /// The base field modulus for the `12_381` curve, in little endian.
526    pub const BLS12_381_MODULUS: &[u8] = Bls12381BaseField::MODULUS;
527
528    /// Given a field element, in big endian, this function computes the square root.
529    ///
530    /// - If the field element is the additive identity, this function returns `vec![vec![1],
531    ///   vec![0; 48]]`.
532    /// - If the field element is a quadratic residue, this function returns `vec![vec![1],
533    ///   vec![sqrt(fe)]  ]`.
534    /// - If the field element (fe) is not a quadratic residue, this function returns `vec![vec![0],
535    ///   vec![sqrt(``NQR_BLS12_381`` * fe)]]`.
536    #[must_use]
537    pub fn hook_bls12_381_sqrt(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
538        let field_element = BigUint::from_bytes_be(&buf[..48]);
539
540        // This should be checked in the VM as its easier than dispatching a hook call.
541        // But for completeness we include this happy path also.
542        if field_element.is_zero() {
543            return vec![vec![1], vec![0; 48]];
544        }
545
546        let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
547
548        // Since `BLS12_381_MODULUS` == 3 mod 4,. we can use shanks methods.
549        // This means we only need to exponentiate by `(modulus + 1) / 4`.
550        let exp = (&modulus + BigUint::from(1u64)) / BigUint::from(4u64);
551        let sqrt = field_element.modpow(&exp, &modulus);
552
553        // Shanks methods only works if the field element is a quadratic residue.
554        // So we need to check if the square of the sqrt is equal to the field element.
555        let square = (&sqrt * &sqrt) % &modulus;
556        if square != field_element {
557            let nqr = BigUint::from_bytes_be(&NQR_BLS12_381);
558            let qr = (&nqr * &field_element) % &modulus;
559
560            // By now, the product of two non-quadratic residues is a quadratic residue.
561            // So we can use shanks methods again to get its square root.
562            //
563            // We pass this root back to the VM to constrain the "failure" case.
564            let root = qr.modpow(&exp, &modulus);
565
566            assert!((&root * &root) % &modulus == qr, "NQR sanity check failed, this is a bug.");
567
568            return vec![vec![0], pad_to_be(&root, 48)];
569        }
570
571        vec![vec![1], pad_to_be(&sqrt, 48)]
572    }
573
574    /// Given a field element, in big endian, this function computes the inverse.
575    ///
576    /// This functions will panic if the additive identity is passed in.
577    #[must_use]
578    pub fn hook_bls12_381_inverse(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
579        let field_element = BigUint::from_bytes_be(&buf[..48]);
580
581        // Zero is not invertible, and we dont want to have to return a status from here.
582        assert!(!field_element.is_zero(), "Field element is the additive identity");
583
584        let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
585
586        // Compute the inverse using Fermat's little theorem, ie, a^(p-2) = a^-1 mod p.
587        let inverse = field_element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
588
589        vec![pad_to_be(&inverse, 48)]
590    }
591}
592
593/// Given the product of some 256-byte numbers and a modulus, this function does a modular
594/// reduction and hints back the values to the vm in order to constrain it.
595///
596/// # Arguments
597///
598/// * `env` - The environment in which the hook is invoked.
599/// * `buf` - The buffer containing the le bytes of the 512 byte product and the 256 byte modulus.
600///
601/// Returns The le bytes of the product % modulus (512 bytes)
602/// and the quotient floor(product/modulus) (256 bytes).
603///
604/// WANRING: This function is used to perform a modular reduction outside of the zkVM context.
605/// These values must be constrained by the zkVM for correctness.
606#[must_use]
607pub fn hook_rsa_mul_mod(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
608    assert!(
609        buf.len() == 256 * 3 || buf.len() == 384 * 3 || buf.len() == 512 * 3,
610        "rsa_mul_mod input should have length key_size * 3, this is a bug."
611    );
612
613    let len = buf.len() / 3;
614    let prod = BigUint::from_bytes_le(&buf[..2 * len]);
615    let m = BigUint::from_bytes_le(&buf[2 * len..]);
616
617    let (q, rem) = prod.div_rem(&m);
618
619    let mut rem = rem.to_bytes_le();
620    rem.resize(len, 0);
621
622    let mut q = q.to_bytes_le();
623    q.resize(len, 0);
624
625    vec![rem, q]
626}
627
628/// Pads a big uint to the given length in big endian.
629fn pad_to_be(val: &BigUint, len: usize) -> Vec<u8> {
630    // First take the byes in little endian
631    let mut bytes = val.to_bytes_le();
632    // Resize so we get the full padding correctly.
633    bytes.resize(len, 0);
634    // Convert back to big endian.
635    bytes.reverse();
636
637    bytes
638}
639
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use super::*;

    /// The default registry ships with built-in hooks, and its `Debug` summary can be printed.
    #[test]
    pub fn registry_new_is_inhabited() {
        let registry = HookRegistry::new();
        assert!(!registry.table.is_empty());
        println!("{registry:?}");
    }

    /// An explicitly empty registry has no hooks registered.
    #[test]
    pub fn registry_empty_is_empty() {
        assert!(HookRegistry::empty().table.is_empty());
    }
}