sp1_core_executor/
hook.rs

1use core::fmt::Debug;
2
3use std::sync::{Arc, RwLock, RwLockWriteGuard};
4
5use hashbrown::HashMap;
6use sp1_curves::{
7    edwards::ed25519::ed25519_sqrt, params::FieldParameters, BigUint, Integer, One, Zero,
8};
9
10use crate::Executor;
11
12/// A runtime hook, wrapped in a smart pointer.
13pub type BoxedHook<'a> = Arc<RwLock<dyn Hook + Send + Sync + 'a>>;
14
15pub use sp1_primitives::consts::fd::*;
16
17/// A runtime hook. May be called during execution by writing to a specified file descriptor,
18/// accepting and returning arbitrary data.
19pub trait Hook {
20    /// Invoke the runtime hook with a standard environment and arbitrary data.
21    /// Returns the computed data.
22    fn invoke_hook(&mut self, env: HookEnv, buf: &[u8]) -> Vec<Vec<u8>>;
23}
24
25impl<F: FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>>> Hook for F {
26    /// Invokes the function `self` as a hook.
27    fn invoke_hook(&mut self, env: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
28        self(env, buf)
29    }
30}
31
32/// Wrap a function in a smart pointer so it may be placed in a `HookRegistry`.
33///
34/// Note: the Send + Sync requirement may be logically extraneous. Requires further investigation.
35pub fn hookify<'a>(
36    f: impl FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>> + Send + Sync + 'a,
37) -> BoxedHook<'a> {
38    Arc::new(RwLock::new(f))
39}
40
41/// A registry of hooks to call, indexed by the file descriptors through which they are accessed.
42#[derive(Clone)]
43pub struct HookRegistry<'a> {
44    /// Table of registered hooks. Prefer using `Runtime::hook`, ` Runtime::hook_env`,
45    /// and `HookRegistry::get` over interacting with this field directly.
46    pub(crate) table: HashMap<u32, BoxedHook<'a>>,
47}
48
49impl<'a> HookRegistry<'a> {
50    /// Create a default [`HookRegistry`].
51    #[must_use]
52    pub fn new() -> Self {
53        HookRegistry::default()
54    }
55
56    /// Create an empty [`HookRegistry`].
57    #[must_use]
58    pub fn empty() -> Self {
59        Self { table: HashMap::default() }
60    }
61
62    /// Get a hook with exclusive write access, if it exists.
63    ///
64    /// Note: This function should not be called in async contexts, unless you know what you are
65    /// doing.
66    #[must_use]
67    pub fn get(&self, fd: u32) -> Option<RwLockWriteGuard<dyn Hook + Send + Sync + 'a>> {
68        // Calling `.unwrap()` panics on a poisoned lock. Should never happen normally.
69        self.table.get(&fd).map(|x| x.write().unwrap())
70    }
71}
72
73impl Default for HookRegistry<'_> {
74    fn default() -> Self {
75        // When `LazyCell` gets stabilized (1.81.0), we can use it to avoid unnecessary allocations.
76        let table = HashMap::from([
77            // Note: To ensure any `fd` value is synced with `zkvm/precompiles/src/io.rs`,
78            // add an assertion to the test `hook_fds_match` below.
79            (FD_ECRECOVER_HOOK, hookify(hook_ecrecover)),
80            (FD_EDDECOMPRESS, hookify(hook_ed_decompress)),
81            (FD_RSA_MUL_MOD, hookify(hook_rsa_mul_mod)),
82            (FD_BLS12_381_SQRT, hookify(bls::hook_bls12_381_sqrt)),
83            (FD_BLS12_381_INVERSE, hookify(bls::hook_bls12_381_inverse)),
84            (FD_FP_SQRT, hookify(fp_ops::hook_fp_sqrt)),
85            (FD_FP_INV, hookify(fp_ops::hook_fp_inverse)),
86        ]);
87
88        Self { table }
89    }
90}
91
92impl Debug for HookRegistry<'_> {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        let mut keys = self.table.keys().collect::<Vec<_>>();
95        keys.sort_unstable();
96        f.debug_struct("HookRegistry")
97            .field(
98                "table",
99                &format_args!("{{{} hooks registered at {:?}}}", self.table.len(), keys),
100            )
101            .finish()
102    }
103}
104
105/// Environment that a hook may read from.
106pub struct HookEnv<'a, 'b: 'a> {
107    /// The runtime.
108    pub runtime: &'a Executor<'b>,
109}
110
111/// The hook for the `ecrecover` patches.
112///
113/// The input should be of the form [(`curve_id_u8` | `r_is_y_odd_u8` << 7) || `r` || `alpha`]
114/// where:
115/// * `curve_id` is 1 for secp256k1 and 2 for secp256r1
116/// * `r_is_y_odd` is 0 if r is even and 1 if r is is odd
117/// * r is the x-coordinate of the point, which should be 32 bytes,
118/// * alpha := r * r * r * (a * r) + b, which should be 32 bytes.
119///
120/// Returns vec![vec![1], `y`, `r_inv`] if the point is decompressable
121/// and vec![vec![0],`nqr_hint`] if not.
122#[must_use]
123pub fn hook_ecrecover(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
124    assert!(buf.len() == 64 + 1, "ecrecover should have length 65");
125
126    let curve_id = buf[0] & 0b0111_1111;
127    let r_is_y_odd = buf[0] & 0b1000_0000 != 0;
128
129    let r_bytes: [u8; 32] = buf[1..33].try_into().unwrap();
130    let alpha_bytes: [u8; 32] = buf[33..65].try_into().unwrap();
131
132    match curve_id {
133        1 => ecrecover::handle_secp256k1(r_bytes, alpha_bytes, r_is_y_odd),
134        2 => ecrecover::handle_secp256r1(r_bytes, alpha_bytes, r_is_y_odd),
135        _ => unimplemented!("Unsupported curve id: {}", curve_id),
136    }
137}
138
139mod ecrecover {
140    use sp1_curves::{k256, p256};
141
142    /// The non-quadratic residue for the curve for secp256k1 and secp256r1.
143    const NQR: [u8; 32] = {
144        let mut nqr = [0; 32];
145        nqr[31] = 3;
146        nqr
147    };
148
149    pub(super) fn handle_secp256k1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<Vec<u8>> {
150        use k256::{
151            elliptic_curve::ff::PrimeField, FieldBytes as K256FieldBytes,
152            FieldElement as K256FieldElement, Scalar as K256Scalar,
153        };
154
155        let r = K256FieldElement::from_bytes(K256FieldBytes::from_slice(&r)).unwrap();
156        debug_assert!(!bool::from(r.is_zero()), "r should not be zero");
157
158        let alpha = K256FieldElement::from_bytes(K256FieldBytes::from_slice(&alpha)).unwrap();
159        assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");
160
161        // nomralize the y-coordinate always to be consistent.
162        if let Some(mut y_coord) = alpha.sqrt().into_option().map(|y| y.normalize()) {
163            let r = K256Scalar::from_repr(r.to_bytes()).unwrap();
164            let r_inv = r.invert().expect("Non zero r scalar");
165
166            if r_y_is_odd != bool::from(y_coord.is_odd()) {
167                y_coord = y_coord.negate(1);
168                y_coord = y_coord.normalize();
169            }
170
171            vec![vec![1], y_coord.to_bytes().to_vec(), r_inv.to_bytes().to_vec()]
172        } else {
173            let nqr_field = K256FieldElement::from_bytes(K256FieldBytes::from_slice(&NQR)).unwrap();
174            let qr = alpha * nqr_field;
175            let root = qr.sqrt().expect("if alpha is not a square, then qr should be a square");
176
177            vec![vec![0], root.to_bytes().to_vec()]
178        }
179    }
180
181    pub(super) fn handle_secp256r1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<Vec<u8>> {
182        use p256::{
183            elliptic_curve::ff::PrimeField, FieldBytes as P256FieldBytes,
184            FieldElement as P256FieldElement, Scalar as P256Scalar,
185        };
186
187        let r = P256FieldElement::from_bytes(P256FieldBytes::from_slice(&r)).unwrap();
188        debug_assert!(!bool::from(r.is_zero()), "r should not be zero");
189
190        let alpha = P256FieldElement::from_bytes(P256FieldBytes::from_slice(&alpha)).unwrap();
191        debug_assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");
192
193        if let Some(mut y_coord) = alpha.sqrt().into_option() {
194            let r = P256Scalar::from_repr(r.to_bytes()).unwrap();
195            let r_inv = r.invert().expect("Non zero r scalar");
196
197            if r_y_is_odd != bool::from(y_coord.is_odd()) {
198                y_coord = -y_coord;
199            }
200
201            vec![vec![1], y_coord.to_bytes().to_vec(), r_inv.to_bytes().to_vec()]
202        } else {
203            let nqr_field = P256FieldElement::from_bytes(P256FieldBytes::from_slice(&NQR)).unwrap();
204            let qr = alpha * nqr_field;
205            let root = qr.sqrt().expect("if alpha is not a square, then qr should be a square");
206
207            vec![vec![0], root.to_bytes().to_vec()]
208        }
209    }
210}
211
212mod fp_ops {
213    use super::{pad_to_be, BigUint, HookEnv, One, Zero};
214
215    /// Compute the inverse of a field element.
216    ///
217    /// # Arguments:
218    /// * `buf` - The buffer containing the data needed to compute the inverse.
219    ///     - [ len || Element || Modulus ]
220    ///     - len is the u32 length of the element and modulus in big endian.
221    ///     - Element is the field element to compute the inverse of, interpreted as a big endian
222    ///       integer of `len` bytes.
223    ///
224    /// # Returns:
225    /// A single 32 byte vector containing the inverse.
226    ///
227    /// # Panics:
228    /// - If the buffer length is not valid.
229    /// - If the element is zero.
230    pub fn hook_fp_inverse(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
231        let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
232
233        assert!(buf.len() == 4 + 2 * len, "FpOp: Invalid buffer length");
234
235        let buf = &buf[4..];
236        let element = BigUint::from_bytes_be(&buf[..len]);
237        let modulus = BigUint::from_bytes_be(&buf[len..2 * len]);
238
239        assert!(!element.is_zero(), "FpOp: Inverse called with zero");
240
241        let inverse = element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
242
243        vec![pad_to_be(&inverse, len)]
244    }
245
246    /// Compute the square root of a field element.
247    ///
248    /// # Arguments:
249    /// * `buf` - The buffer containing the data needed to compute the square root.
250    ///     - [ len || Element || Modulus || NQR ]
251    ///     - len is the length of the element, modulus, and nqr in big endian.
252    ///     - Element is the field element to compute the square root of, interpreted as a big
253    ///       endian integer of `len` bytes.
254    ///     - Modulus is the modulus of the field, interpreted as a big endian integer of `len`
255    ///       bytes.
256    ///     - NQR is the non-quadratic residue of the field, interpreted as a big endian integer of
257    ///       `len` bytes.
258    ///
259    /// # Assumptions
260    /// - NQR is a non-quadratic residue of the field.
261    ///
262    /// # Returns:
263    /// [ `status_u8` || `root_bytes` ]
264    ///
265    /// If the status is 0, this is the root of NQR * element.
266    /// If the status is 1, this is the root of element.
267    ///
268    /// # Panics:
269    /// - If the buffer length is not valid.
270    /// - If the element is not less than the modulus.
271    /// - If the nqr is not less than the modulus.
272    /// - If the element is zero.
273    pub fn hook_fp_sqrt(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
274        let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
275
276        assert!(buf.len() == 4 + 3 * len, "FpOp: Invalid buffer length");
277
278        let buf = &buf[4..];
279        let element = BigUint::from_bytes_be(&buf[..len]);
280        let modulus = BigUint::from_bytes_be(&buf[len..2 * len]);
281        let nqr = BigUint::from_bytes_be(&buf[2 * len..3 * len]);
282
283        assert!(
284            element < modulus,
285            "Element is not less than modulus, the hook only accepts canonical representations"
286        );
287        assert!(
288            nqr < modulus,
289            "NQR is zero or non-canonical, the hook only accepts canonical representations"
290        );
291
292        // The sqrt of zero is zero.
293        if element.is_zero() {
294            return vec![vec![1], vec![0; len]];
295        }
296
297        // Compute the square root of the element using the general Tonelli-Shanks algorithm.
298        // The implementation can be used for any field as it is field-agnostic.
299        if let Some(root) = sqrt_fp(&element, &modulus, &nqr) {
300            vec![vec![1], pad_to_be(&root, len)]
301        } else {
302            let qr = (&nqr * &element) % &modulus;
303            let root = sqrt_fp(&qr, &modulus, &nqr).unwrap();
304
305            vec![vec![0], pad_to_be(&root, len)]
306        }
307    }
308
309    /// Compute the square root of a field element for some modulus.
310    ///
311    /// Requires a known non-quadratic residue of the field.
312    fn sqrt_fp(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
313        // If the prime field is of the form p = 3 mod 4, and `x` is a quadratic residue modulo `p`,
314        // then one square root of `x` is given by `x^(p+1 / 4) mod p`.
315        if modulus % BigUint::from(4u64) == BigUint::from(3u64) {
316            let maybe_root =
317                element.modpow(&((modulus + BigUint::from(1u64)) / BigUint::from(4u64)), modulus);
318
319            return Some(maybe_root).filter(|root| root * root % modulus == *element);
320        }
321
322        tonelli_shanks(element, modulus, nqr)
323    }
324
325    /// Compute the square root of a field element using the Tonelli-Shanks algorithm.
326    ///
327    /// # Arguments:
328    /// * `element` - The field element to compute the square root of.
329    /// * `modulus` - The modulus of the field.
330    /// * `nqr` - The non-quadratic residue of the field.
331    ///
332    /// # Assumptions:
333    /// - The element is a quadratic residue modulo the modulus.
334    ///
335    /// Ref: <https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm>
336    #[allow(clippy::many_single_char_names)]
337    fn tonelli_shanks(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
338        // First, compute the Legendre symbol of the element.
339        // If the symbol is not 1, then the element is not a quadratic residue.
340        if legendre_symbol(element, modulus) != BigUint::one() {
341            return None;
342        }
343
344        // Find the values of Q and S such that modulus - 1 = Q * 2^S.
345        let mut s = BigUint::zero();
346        let mut q = modulus - BigUint::one();
347        while &q % &BigUint::from(2u64) == BigUint::zero() {
348            s += BigUint::from(1u64);
349            q /= BigUint::from(2u64);
350        }
351
352        let z = nqr;
353        let mut c = z.modpow(&q, modulus);
354        let mut r = element.modpow(&((&q + BigUint::from(1u64)) / BigUint::from(2u64)), modulus);
355        let mut t = element.modpow(&q, modulus);
356        let mut m = s;
357
358        while t != BigUint::one() {
359            let mut i = BigUint::zero();
360            let mut tt = t.clone();
361            while tt != BigUint::one() {
362                tt = &tt * &tt % modulus;
363                i += BigUint::from(1u64);
364
365                if i == m {
366                    return None;
367                }
368            }
369
370            let b_pow =
371                BigUint::from(2u64).pow((&m - &i - BigUint::from(1u64)).try_into().unwrap());
372            let b = c.modpow(&b_pow, modulus);
373
374            r = &r * &b % modulus;
375            c = &b * &b % modulus;
376            t = &t * &c % modulus;
377            m = i;
378        }
379
380        Some(r)
381    }
382
383    /// Compute the Legendre symbol of a field element.
384    ///
385    /// This indicates if the element is a quadratic in the prime field.
386    ///
387    /// Ref: <https://en.wikipedia.org/wiki/Legendre_symbol>
388    fn legendre_symbol(element: &BigUint, modulus: &BigUint) -> BigUint {
389        assert!(!element.is_zero(), "FpOp: Legendre symbol of zero called.");
390
391        element.modpow(&((modulus - BigUint::one()) / BigUint::from(2u64)), modulus)
392    }
393
394    #[cfg(test)]
395    mod test {
396        use super::*;
397        use std::str::FromStr;
398
399        #[test]
400        fn test_legendre_symbol() {
401            // The modulus of the secp256k1 base field.
402            let modulus = BigUint::from_str(
403                "115792089237316195423570985008687907853269984665640564039457584007908834671663",
404            )
405            .unwrap();
406            let neg_1 = &modulus - BigUint::one();
407
408            let fixtures = [
409                (BigUint::from(4u64), BigUint::from(1u64)),
410                (BigUint::from(2u64), BigUint::from(1u64)),
411                (BigUint::from(3u64), neg_1.clone()),
412            ];
413
414            for (element, expected) in fixtures {
415                let result = legendre_symbol(&element, &modulus);
416                assert_eq!(result, expected);
417            }
418        }
419
420        #[test]
421        fn test_tonelli_shanks() {
422            // The modulus of the secp256k1 base field.
423            let p = BigUint::from_str(
424                "115792089237316195423570985008687907853269984665640564039457584007908834671663",
425            )
426            .unwrap();
427
428            let nqr = BigUint::from_str("3").unwrap();
429
430            let large_element = &p - BigUint::from(u16::MAX);
431            let square = &large_element * &large_element % &p;
432
433            let fixtures = [
434                (BigUint::from(2u64), true),
435                (BigUint::from(3u64), false),
436                (BigUint::from(4u64), true),
437                (square, true),
438            ];
439
440            for (element, expected) in fixtures {
441                let result = tonelli_shanks(&element, &p, &nqr);
442                if expected {
443                    assert!(result.is_some());
444
445                    let result = result.unwrap();
446                    assert!((&result * &result) % &p == element);
447                } else {
448                    assert!(result.is_none());
449                }
450            }
451        }
452    }
453}
454
455/// Checks if a compressed Edwards point can be decompressed.
456///
457/// # Arguments
458/// * `env` - The environment in which the hook is invoked.
459/// * `buf` - The buffer containing the compressed Edwards point.
460///    - The compressed Edwards point is 32 bytes.
461///    - The high bit of the last byte is the sign bit.
462///
463/// Returns vec![vec![1]] if the point is decompressable.
464/// Returns vec![vec![0], `v_inv`, `nqr_hint`] if the point is not decompressable.
465///
466/// WARNING: This function merely hints at the validity of the compressed point. These values must
467/// be constrained by the zkVM for correctness.
468#[must_use]
469pub fn hook_ed_decompress(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
470    const NQR_CURVE_25519: u8 = 2;
471    let modulus = sp1_curves::edwards::ed25519::Ed25519BaseField::modulus();
472
473    let mut bytes: [u8; 32] = buf[..32].try_into().unwrap();
474    // Mask the sign bit.
475    bytes[31] &= 0b0111_1111;
476
477    // The AIR asserts canon inputs, so hint here if it cant be satisfied.
478    let y = BigUint::from_bytes_le(&bytes);
479    if y >= modulus {
480        return vec![vec![0]];
481    }
482
483    let v = BigUint::from_bytes_le(&buf[32..]);
484    // This is computed as dy^2 - 1
485    // so it should always be in the field.
486    assert!(v < modulus, "V is not a valid field element");
487
488    // For a point to be decompressable, (yy - 1) / (yy * d + 1) must be a quadratic residue.
489    let v_inv = v.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
490    let u = (&y * &y + &modulus - BigUint::one()) % &modulus;
491    let u_div_v = (&u * &v_inv) % &modulus;
492
493    // Note: Our sqrt impl doesnt care about canon representation,
494    // however we have already checked that were less than the modulus.
495    if ed25519_sqrt(&u_div_v).is_some() {
496        vec![vec![1]]
497    } else {
498        let qr = (u_div_v * NQR_CURVE_25519) % &modulus;
499        let root = ed25519_sqrt(&qr).unwrap();
500
501        // Pad the results, since this may not be a full 32 bytes.
502        let v_inv_bytes = v_inv.to_bytes_le();
503        let mut v_inv_padded = [0_u8; 32];
504        v_inv_padded[..v_inv_bytes.len()].copy_from_slice(&v_inv.to_bytes_le());
505
506        let root_bytes = root.to_bytes_le();
507        let mut root_padded = [0_u8; 32];
508        root_padded[..root_bytes.len()].copy_from_slice(&root.to_bytes_le());
509
510        vec![vec![0], v_inv_padded.to_vec(), root_padded.to_vec()]
511    }
512}
513
514mod bls {
515    use super::{pad_to_be, BigUint, HookEnv};
516    use sp1_curves::{params::FieldParameters, weierstrass::bls12_381::Bls12381BaseField, Zero};
517
518    /// A non-quadratic residue for the `12_381` base field in big endian.
519    pub const NQR_BLS12_381: [u8; 48] = {
520        let mut nqr = [0; 48];
521        nqr[47] = 2;
522        nqr
523    };
524
525    /// The base field modulus for the `12_381` curve, in little endian.
526    pub const BLS12_381_MODULUS: &[u8] = Bls12381BaseField::MODULUS;
527
528    /// Given a field element, in big endian, this function computes the square root.
529    ///
530    /// - If the field element is the additive identity, this function returns `vec![vec![1],
531    ///   vec![0; 48]]`.
532    /// - If the field element is a quadratic residue, this function returns `vec![vec![1],
533    ///   vec![sqrt(fe)]  ]`.
534    /// - If the field element (fe) is not a quadratic residue, this function returns `vec![vec![0],
535    ///   vec![sqrt(``NQR_BLS12_381`` * fe)]]`.
536    pub fn hook_bls12_381_sqrt(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
537        let field_element = BigUint::from_bytes_be(&buf[..48]);
538
539        // This should be checked in the VM as its easier than dispatching a hook call.
540        // But for completeness we include this happy path also.
541        if field_element.is_zero() {
542            return vec![vec![1], vec![0; 48]];
543        }
544
545        let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
546
547        // Since `BLS12_381_MODULUS` == 3 mod 4,. we can use shanks methods.
548        // This means we only need to exponentiate by `(modulus + 1) / 4`.
549        let exp = (&modulus + BigUint::from(1u64)) / BigUint::from(4u64);
550        let sqrt = field_element.modpow(&exp, &modulus);
551
552        // Shanks methods only works if the field element is a quadratic residue.
553        // So we need to check if the square of the sqrt is equal to the field element.
554        let square = (&sqrt * &sqrt) % &modulus;
555        if square != field_element {
556            let nqr = BigUint::from_bytes_be(&NQR_BLS12_381);
557            let qr = (&nqr * &field_element) % &modulus;
558
559            // By now, the product of two non-quadratic residues is a quadratic residue.
560            // So we can use shanks methods again to get its square root.
561            //
562            // We pass this root back to the VM to constrain the "failure" case.
563            let root = qr.modpow(&exp, &modulus);
564
565            assert!((&root * &root) % &modulus == qr, "NQR sanity check failed, this is a bug.");
566
567            return vec![vec![0], pad_to_be(&root, 48)];
568        }
569
570        vec![vec![1], pad_to_be(&sqrt, 48)]
571    }
572
573    /// Given a field element, in big endian, this function computes the inverse.
574    ///
575    /// This functions will panic if the additive identity is passed in.
576    pub fn hook_bls12_381_inverse(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
577        let field_element = BigUint::from_bytes_be(&buf[..48]);
578
579        // Zero is not invertible, and we dont want to have to return a status from here.
580        assert!(!field_element.is_zero(), "Field element is the additive identity");
581
582        let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
583
584        // Compute the inverse using Fermat's little theorem, ie, a^(p-2) = a^-1 mod p.
585        let inverse = field_element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
586
587        vec![pad_to_be(&inverse, 48)]
588    }
589}
590
591/// Given the product of some 256-byte numbers and a modulus, this function does a modular
592/// reduction and hints back the values to the vm in order to constrain it.
593///
594/// # Arguments
595///
596/// * `env` - The environment in which the hook is invoked.
597/// * `buf` - The buffer containing the le bytes of the 512 byte product and the 256 byte modulus.
598///
599/// Returns The le bytes of the product % modulus (512 bytes)
600/// and the quotient floor(product/modulus) (256 bytes).
601///
602/// WANRING: This function is used to perform a modular reduction outside of the zkVM context.
603/// These values must be constrained by the zkVM for correctness.
604#[must_use]
605pub fn hook_rsa_mul_mod(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
606    assert_eq!(
607        buf.len(),
608        256 + 256 + 256,
609        "rsa_mul_mod input should have length 256 + 256 + 256, this is a bug."
610    );
611
612    let prod: &[u8; 512] = buf[..512].try_into().unwrap();
613    let m: &[u8; 256] = buf[512..].try_into().unwrap();
614
615    let prod = BigUint::from_bytes_le(prod);
616    let m = BigUint::from_bytes_le(m);
617
618    let (q, rem) = prod.div_rem(&m);
619
620    let mut rem = rem.to_bytes_le();
621    rem.resize(256, 0);
622
623    let mut q = q.to_bytes_le();
624    q.resize(256, 0);
625
626    vec![rem, q]
627}
628
629pub(crate) mod deprecated_hooks {
630    use super::HookEnv;
631    use sp1_curves::{
632        k256::{
633            ecdsa::{RecoveryId, Signature, VerifyingKey},
634            elliptic_curve::ops::Invert,
635        },
636        p256::ecdsa::Signature as p256Signature,
637    };
638
639    /// Recovers the public key from the signature and message hash using the k256 crate.
640    ///
641    /// # Arguments
642    ///
643    /// * `env` - The environment in which the hook is invoked.
644    /// * `buf` - The buffer containing the signature and message hash.
645    ///     - The signature is 65 bytes, the first 64 bytes are the signature and the last byte is
646    ///       the recovery ID.
647    ///     - The message hash is 32 bytes.
648    ///
649    /// The result is returned as a pair of bytes, where the first 32 bytes are the X coordinate
650    /// and the second 32 bytes are the Y coordinate of the decompressed point.
651    ///
652    /// WARNING: This function is used to recover the public key outside of the zkVM context. These
653    /// values must be constrained by the zkVM for correctness.
654    #[must_use]
655    pub fn hook_ecrecover(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
656        assert_eq!(buf.len(), 65 + 32, "ecrecover input should have length 65 + 32");
657        let (sig, msg_hash) = buf.split_at(65);
658        let sig: &[u8; 65] = sig.try_into().unwrap();
659        let msg_hash: &[u8; 32] = msg_hash.try_into().unwrap();
660
661        let mut recovery_id = sig[64];
662        let mut sig = Signature::from_slice(&sig[..64]).unwrap();
663
664        if let Some(sig_normalized) = sig.normalize_s() {
665            sig = sig_normalized;
666            recovery_id ^= 1;
667        }
668        let recid = RecoveryId::from_byte(recovery_id).expect("Computed recovery ID is invalid!");
669
670        let recovered_key = VerifyingKey::recover_from_prehash(&msg_hash[..], &sig, recid).unwrap();
671        let bytes = recovered_key.to_sec1_bytes();
672
673        let (_, s) = sig.split_scalars();
674        let s_inverse = s.invert();
675
676        vec![bytes.to_vec(), s_inverse.to_bytes().to_vec()]
677    }
678
679    /// Recovers s inverse from the signature using the secp256r1 crate.
680    ///
681    /// # Arguments
682    ///
683    /// * `env` - The environment in which the hook is invoked.
684    /// * `buf` - The buffer containing the signature.
685    ///     - The signature is 64 bytes.
686    ///
687    /// The result is a single 32 byte vector containing s inverse.
688    #[must_use]
689    pub fn hook_r1_ecrecover(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
690        assert_eq!(buf.len(), 64, "ecrecover input should have length 64");
691        let sig: &[u8; 64] = buf.try_into().unwrap();
692        let sig = p256Signature::from_slice(sig).unwrap();
693
694        let (_, s) = sig.split_scalars();
695        let s_inverse = s.invert();
696
697        vec![s_inverse.to_bytes().to_vec()]
698    }
699
700    /// Recovers the public key from the signature and message hash using the k256 crate.
701    ///
702    /// # Arguments
703    ///
704    /// * `env` - The environment in which the hook is invoked.
705    /// * `buf` - The buffer containing the signature and message hash.
706    ///     - The signature is 65 bytes, the first 64 bytes are the signature and the last byte is
707    ///       the recovery ID.
708    ///     - The message hash is 32 bytes.
709    ///
710    /// The result is returned as a status and a pair of bytes, where the first 32 bytes are the X
711    /// coordinate and the second 32 bytes are the Y coordinate of the decompressed point.
712    ///
713    /// A status of 0 indicates that the public key could not be recovered.
714    ///
715    /// WARNING: This function is used to recover the public key outside of the zkVM context. These
716    /// values must be constrained by the zkVM for correctness.
717    #[must_use]
718    pub fn hook_ecrecover_v2(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
719        assert_eq!(
720            buf.len(),
721            65 + 32,
722            "ecrecover input should have length 65 + 32, this is a bug."
723        );
724        let (sig, msg_hash) = buf.split_at(65);
725        let sig: &[u8; 65] = sig.try_into().unwrap();
726        let msg_hash: &[u8; 32] = msg_hash.try_into().unwrap();
727
728        let mut recovery_id = sig[64];
729        let mut sig = Signature::from_slice(&sig[..64]).unwrap();
730
731        if let Some(sig_normalized) = sig.normalize_s() {
732            sig = sig_normalized;
733            recovery_id ^= 1;
734        }
735        let recid = RecoveryId::from_byte(recovery_id)
736            .expect("Computed recovery ID is invalid, this is a bug.");
737
738        // Attempting to recvover the public key has failed, write a 0 to indicate to the caller.
739        let Ok(recovered_key) = VerifyingKey::recover_from_prehash(&msg_hash[..], &sig, recid)
740        else {
741            return vec![vec![0]];
742        };
743
744        let bytes = recovered_key.to_sec1_bytes();
745
746        let (_, s) = sig.split_scalars();
747        let s_inverse = s.invert();
748
749        vec![vec![1], bytes.to_vec(), s_inverse.to_bytes().to_vec()]
750    }
751
752    /// Checks if a compressed Edwards point can be decompressed.
753    ///
754    /// # Arguments
755    /// * `env` - The environment in which the hook is invoked.
756    /// * `buf` - The buffer containing the compressed Edwards point.
757    ///    - The compressed Edwards point is 32 bytes.
758    ///    - The high bit of the last byte is the sign bit.
759    ///
760    /// The result is either `0` if the point cannot be decompressed, or `1` if it can.
761    ///
762    /// WARNING: This function merely hints at the validity of the compressed point. These values
763    /// must be constrained by the zkVM for correctness.
764    #[must_use]
765    pub fn hook_ed_decompress(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
766        let Ok(point) = sp1_curves::curve25519_dalek::CompressedEdwardsY::from_slice(buf) else {
767            return vec![vec![0]];
768        };
769
770        if sp1_curves::edwards::ed25519::decompress(&point).is_some() {
771            vec![vec![1]]
772        } else {
773            vec![vec![0]]
774        }
775    }
776}
777
778/// Pads a big uint to the given length in big endian.
779fn pad_to_be(val: &BigUint, len: usize) -> Vec<u8> {
780    // First take the byes in little endian
781    let mut bytes = val.to_bytes_le();
782    // Resize so we get the full padding correctly.
783    bytes.resize(len, 0);
784    // Convert back to big endian.
785    bytes.reverse();
786
787    bytes
788}
789
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use super::*;

    /// The default registry should ship with at least one built-in hook.
    #[test]
    pub fn registry_new_is_inhabited() {
        let registry = HookRegistry::new();
        assert!(!registry.table.is_empty());
        println!("{registry:?}");
    }

    /// An explicitly empty registry should contain no hooks.
    #[test]
    pub fn registry_empty_is_empty() {
        assert!(HookRegistry::empty().table.is_empty());
    }
}