1use core::fmt::Debug;
2
3use std::sync::{Arc, RwLock, RwLockWriteGuard};
4
5use hashbrown::HashMap;
6use sp1_curves::{
7 edwards::ed25519::ed25519_sqrt, params::FieldParameters, BigUint, Integer, One, Zero,
8};
9
10use crate::Executor;
11
/// A shared, lockable, type-erased [`Hook`].
///
/// The `Arc` lets clones of a [`HookRegistry`] share the same hook instance, and the `RwLock`
/// provides the `&mut` access that [`Hook::invoke_hook`] requires.
pub type BoxedHook<'a> = Arc<RwLock<dyn Hook + Send + Sync + 'a>>;
14
15pub use sp1_primitives::consts::fd::*;
16
/// A host-side routine that consumes a byte buffer and produces a list of byte buffers.
pub trait Hook {
    /// Runs the hook on `buf`, with read access to the executor through `env`.
    ///
    /// NOTE(review): the returned buffers are presumably handed back to the guest as hints —
    /// confirm at the executor call site.
    fn invoke_hook(&mut self, env: HookEnv, buf: &[u8]) -> Vec<Vec<u8>>;
}
24
/// Every `FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>>` (closure or plain function) is a [`Hook`].
impl<F: FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>>> Hook for F {
    fn invoke_hook(&mut self, env: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
        // Forward directly to the wrapped callable.
        self(env, buf)
    }
}
31
32pub fn hookify<'a>(
36 f: impl FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>> + Send + Sync + 'a,
37) -> BoxedHook<'a> {
38 Arc::new(RwLock::new(f))
39}
40
/// A table of hooks keyed by file descriptor.
///
/// Cloning a registry is shallow: every entry is an `Arc`, so clones share the same hooks.
#[derive(Clone)]
pub struct HookRegistry<'a> {
    /// Map from file descriptor to the hook registered on it.
    pub(crate) table: HashMap<u32, BoxedHook<'a>>,
}
48
49impl<'a> HookRegistry<'a> {
50 #[must_use]
52 pub fn new() -> Self {
53 HookRegistry::default()
54 }
55
56 #[must_use]
58 pub fn empty() -> Self {
59 Self { table: HashMap::default() }
60 }
61
62 #[must_use]
67 pub fn get(&self, fd: u32) -> Option<RwLockWriteGuard<dyn Hook + Send + Sync + 'a>> {
68 self.table.get(&fd).map(|x| x.write().unwrap())
70 }
71}
72
73impl Default for HookRegistry<'_> {
74 fn default() -> Self {
75 let table = HashMap::from([
77 (FD_ECRECOVER_HOOK, hookify(hook_ecrecover)),
80 (FD_EDDECOMPRESS, hookify(hook_ed_decompress)),
81 (FD_RSA_MUL_MOD, hookify(hook_rsa_mul_mod)),
82 (FD_BLS12_381_SQRT, hookify(bls::hook_bls12_381_sqrt)),
83 (FD_BLS12_381_INVERSE, hookify(bls::hook_bls12_381_inverse)),
84 (FD_FP_SQRT, hookify(fp_ops::hook_fp_sqrt)),
85 (FD_FP_INV, hookify(fp_ops::hook_fp_inverse)),
86 ]);
87
88 Self { table }
89 }
90}
91
92impl Debug for HookRegistry<'_> {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 let mut keys = self.table.keys().collect::<Vec<_>>();
95 keys.sort_unstable();
96 f.debug_struct("HookRegistry")
97 .field(
98 "table",
99 &format_args!("{{{} hooks registered at {:?}}}", self.table.len(), keys),
100 )
101 .finish()
102 }
103}
104
/// Context passed to every hook invocation.
pub struct HookEnv<'a, 'b: 'a> {
    /// Shared (read-only) reference to the executor; NOTE(review): presumably the executor
    /// that dispatched the hook — confirm at the call site.
    pub runtime: &'a Executor<'b>,
}
110
111#[must_use]
123pub fn hook_ecrecover(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
124 assert!(buf.len() == 64 + 1, "ecrecover should have length 65");
125
126 let curve_id = buf[0] & 0b0111_1111;
127 let r_is_y_odd = buf[0] & 0b1000_0000 != 0;
128
129 let r_bytes: [u8; 32] = buf[1..33].try_into().unwrap();
130 let alpha_bytes: [u8; 32] = buf[33..65].try_into().unwrap();
131
132 match curve_id {
133 1 => ecrecover::handle_secp256k1(r_bytes, alpha_bytes, r_is_y_odd),
134 2 => ecrecover::handle_secp256r1(r_bytes, alpha_bytes, r_is_y_odd),
135 _ => unimplemented!("Unsupported curve id: {}", curve_id),
136 }
137}
138
139mod ecrecover {
140 use sp1_curves::{k256, p256};
141
142 const NQR: [u8; 32] = {
144 let mut nqr = [0; 32];
145 nqr[31] = 3;
146 nqr
147 };
148
149 pub(super) fn handle_secp256k1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<Vec<u8>> {
150 use k256::{
151 elliptic_curve::ff::PrimeField, FieldBytes as K256FieldBytes,
152 FieldElement as K256FieldElement, Scalar as K256Scalar,
153 };
154
155 let r = K256FieldElement::from_bytes(K256FieldBytes::from_slice(&r)).unwrap();
156 debug_assert!(!bool::from(r.is_zero()), "r should not be zero");
157
158 let alpha = K256FieldElement::from_bytes(K256FieldBytes::from_slice(&alpha)).unwrap();
159 assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");
160
161 if let Some(mut y_coord) = alpha.sqrt().into_option().map(|y| y.normalize()) {
163 let r = K256Scalar::from_repr(r.to_bytes()).unwrap();
164 let r_inv = r.invert().expect("Non zero r scalar");
165
166 if r_y_is_odd != bool::from(y_coord.is_odd()) {
167 y_coord = y_coord.negate(1);
168 y_coord = y_coord.normalize();
169 }
170
171 vec![vec![1], y_coord.to_bytes().to_vec(), r_inv.to_bytes().to_vec()]
172 } else {
173 let nqr_field = K256FieldElement::from_bytes(K256FieldBytes::from_slice(&NQR)).unwrap();
174 let qr = alpha * nqr_field;
175 let root = qr.sqrt().expect("if alpha is not a square, then qr should be a square");
176
177 vec![vec![0], root.to_bytes().to_vec()]
178 }
179 }
180
181 pub(super) fn handle_secp256r1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<Vec<u8>> {
182 use p256::{
183 elliptic_curve::ff::PrimeField, FieldBytes as P256FieldBytes,
184 FieldElement as P256FieldElement, Scalar as P256Scalar,
185 };
186
187 let r = P256FieldElement::from_bytes(P256FieldBytes::from_slice(&r)).unwrap();
188 debug_assert!(!bool::from(r.is_zero()), "r should not be zero");
189
190 let alpha = P256FieldElement::from_bytes(P256FieldBytes::from_slice(&alpha)).unwrap();
191 debug_assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");
192
193 if let Some(mut y_coord) = alpha.sqrt().into_option() {
194 let r = P256Scalar::from_repr(r.to_bytes()).unwrap();
195 let r_inv = r.invert().expect("Non zero r scalar");
196
197 if r_y_is_odd != bool::from(y_coord.is_odd()) {
198 y_coord = -y_coord;
199 }
200
201 vec![vec![1], y_coord.to_bytes().to_vec(), r_inv.to_bytes().to_vec()]
202 } else {
203 let nqr_field = P256FieldElement::from_bytes(P256FieldBytes::from_slice(&NQR)).unwrap();
204 let qr = alpha * nqr_field;
205 let root = qr.sqrt().expect("if alpha is not a square, then qr should be a square");
206
207 vec![vec![0], root.to_bytes().to_vec()]
208 }
209 }
210}
211
212mod fp_ops {
213 use super::{pad_to_be, BigUint, HookEnv, One, Zero};
214
215 pub fn hook_fp_inverse(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
231 let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
232
233 assert!(buf.len() == 4 + 2 * len, "FpOp: Invalid buffer length");
234
235 let buf = &buf[4..];
236 let element = BigUint::from_bytes_be(&buf[..len]);
237 let modulus = BigUint::from_bytes_be(&buf[len..2 * len]);
238
239 assert!(!element.is_zero(), "FpOp: Inverse called with zero");
240
241 let inverse = element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
242
243 vec![pad_to_be(&inverse, len)]
244 }
245
246 pub fn hook_fp_sqrt(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
274 let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
275
276 assert!(buf.len() == 4 + 3 * len, "FpOp: Invalid buffer length");
277
278 let buf = &buf[4..];
279 let element = BigUint::from_bytes_be(&buf[..len]);
280 let modulus = BigUint::from_bytes_be(&buf[len..2 * len]);
281 let nqr = BigUint::from_bytes_be(&buf[2 * len..3 * len]);
282
283 assert!(
284 element < modulus,
285 "Element is not less than modulus, the hook only accepts canonical representations"
286 );
287 assert!(
288 nqr < modulus,
289 "NQR is zero or non-canonical, the hook only accepts canonical representations"
290 );
291
292 if element.is_zero() {
294 return vec![vec![1], vec![0; len]];
295 }
296
297 if let Some(root) = sqrt_fp(&element, &modulus, &nqr) {
300 vec![vec![1], pad_to_be(&root, len)]
301 } else {
302 let qr = (&nqr * &element) % &modulus;
303 let root = sqrt_fp(&qr, &modulus, &nqr).unwrap();
304
305 vec![vec![0], pad_to_be(&root, len)]
306 }
307 }
308
309 fn sqrt_fp(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
313 if modulus % BigUint::from(4u64) == BigUint::from(3u64) {
316 let maybe_root =
317 element.modpow(&((modulus + BigUint::from(1u64)) / BigUint::from(4u64)), modulus);
318
319 return Some(maybe_root).filter(|root| root * root % modulus == *element);
320 }
321
322 tonelli_shanks(element, modulus, nqr)
323 }
324
325 #[allow(clippy::many_single_char_names)]
337 fn tonelli_shanks(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
338 if legendre_symbol(element, modulus) != BigUint::one() {
341 return None;
342 }
343
344 let mut s = BigUint::zero();
346 let mut q = modulus - BigUint::one();
347 while &q % &BigUint::from(2u64) == BigUint::zero() {
348 s += BigUint::from(1u64);
349 q /= BigUint::from(2u64);
350 }
351
352 let z = nqr;
353 let mut c = z.modpow(&q, modulus);
354 let mut r = element.modpow(&((&q + BigUint::from(1u64)) / BigUint::from(2u64)), modulus);
355 let mut t = element.modpow(&q, modulus);
356 let mut m = s;
357
358 while t != BigUint::one() {
359 let mut i = BigUint::zero();
360 let mut tt = t.clone();
361 while tt != BigUint::one() {
362 tt = &tt * &tt % modulus;
363 i += BigUint::from(1u64);
364
365 if i == m {
366 return None;
367 }
368 }
369
370 let b_pow =
371 BigUint::from(2u64).pow((&m - &i - BigUint::from(1u64)).try_into().unwrap());
372 let b = c.modpow(&b_pow, modulus);
373
374 r = &r * &b % modulus;
375 c = &b * &b % modulus;
376 t = &t * &c % modulus;
377 m = i;
378 }
379
380 Some(r)
381 }
382
383 fn legendre_symbol(element: &BigUint, modulus: &BigUint) -> BigUint {
389 assert!(!element.is_zero(), "FpOp: Legendre symbol of zero called.");
390
391 element.modpow(&((modulus - BigUint::one()) / BigUint::from(2u64)), modulus)
392 }
393
394 #[cfg(test)]
395 mod test {
396 use super::*;
397 use std::str::FromStr;
398
399 #[test]
400 fn test_legendre_symbol() {
401 let modulus = BigUint::from_str(
403 "115792089237316195423570985008687907853269984665640564039457584007908834671663",
404 )
405 .unwrap();
406 let neg_1 = &modulus - BigUint::one();
407
408 let fixtures = [
409 (BigUint::from(4u64), BigUint::from(1u64)),
410 (BigUint::from(2u64), BigUint::from(1u64)),
411 (BigUint::from(3u64), neg_1.clone()),
412 ];
413
414 for (element, expected) in fixtures {
415 let result = legendre_symbol(&element, &modulus);
416 assert_eq!(result, expected);
417 }
418 }
419
420 #[test]
421 fn test_tonelli_shanks() {
422 let p = BigUint::from_str(
424 "115792089237316195423570985008687907853269984665640564039457584007908834671663",
425 )
426 .unwrap();
427
428 let nqr = BigUint::from_str("3").unwrap();
429
430 let large_element = &p - BigUint::from(u16::MAX);
431 let square = &large_element * &large_element % &p;
432
433 let fixtures = [
434 (BigUint::from(2u64), true),
435 (BigUint::from(3u64), false),
436 (BigUint::from(4u64), true),
437 (square, true),
438 ];
439
440 for (element, expected) in fixtures {
441 let result = tonelli_shanks(&element, &p, &nqr);
442 if expected {
443 assert!(result.is_some());
444
445 let result = result.unwrap();
446 assert!((&result * &result) % &p == element);
447 } else {
448 assert!(result.is_none());
449 }
450 }
451 }
452 }
453}
454
455#[must_use]
469pub fn hook_ed_decompress(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
470 const NQR_CURVE_25519: u8 = 2;
471 let modulus = sp1_curves::edwards::ed25519::Ed25519BaseField::modulus();
472
473 let mut bytes: [u8; 32] = buf[..32].try_into().unwrap();
474 bytes[31] &= 0b0111_1111;
476
477 let y = BigUint::from_bytes_le(&bytes);
479 if y >= modulus {
480 return vec![vec![0]];
481 }
482
483 let v = BigUint::from_bytes_le(&buf[32..]);
484 assert!(v < modulus, "V is not a valid field element");
487
488 let v_inv = v.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
490 let u = (&y * &y + &modulus - BigUint::one()) % &modulus;
491 let u_div_v = (&u * &v_inv) % &modulus;
492
493 if ed25519_sqrt(&u_div_v).is_some() {
496 vec![vec![1]]
497 } else {
498 let qr = (u_div_v * NQR_CURVE_25519) % &modulus;
499 let root = ed25519_sqrt(&qr).unwrap();
500
501 let v_inv_bytes = v_inv.to_bytes_le();
503 let mut v_inv_padded = [0_u8; 32];
504 v_inv_padded[..v_inv_bytes.len()].copy_from_slice(&v_inv.to_bytes_le());
505
506 let root_bytes = root.to_bytes_le();
507 let mut root_padded = [0_u8; 32];
508 root_padded[..root_bytes.len()].copy_from_slice(&root.to_bytes_le());
509
510 vec![vec![0], v_inv_padded.to_vec(), root_padded.to_vec()]
511 }
512}
513
514mod bls {
515 use super::{pad_to_be, BigUint, HookEnv};
516 use sp1_curves::{params::FieldParameters, weierstrass::bls12_381::Bls12381BaseField, Zero};
517
518 pub const NQR_BLS12_381: [u8; 48] = {
520 let mut nqr = [0; 48];
521 nqr[47] = 2;
522 nqr
523 };
524
525 pub const BLS12_381_MODULUS: &[u8] = Bls12381BaseField::MODULUS;
527
528 pub fn hook_bls12_381_sqrt(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
537 let field_element = BigUint::from_bytes_be(&buf[..48]);
538
539 if field_element.is_zero() {
542 return vec![vec![1], vec![0; 48]];
543 }
544
545 let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
546
547 let exp = (&modulus + BigUint::from(1u64)) / BigUint::from(4u64);
550 let sqrt = field_element.modpow(&exp, &modulus);
551
552 let square = (&sqrt * &sqrt) % &modulus;
555 if square != field_element {
556 let nqr = BigUint::from_bytes_be(&NQR_BLS12_381);
557 let qr = (&nqr * &field_element) % &modulus;
558
559 let root = qr.modpow(&exp, &modulus);
564
565 assert!((&root * &root) % &modulus == qr, "NQR sanity check failed, this is a bug.");
566
567 return vec![vec![0], pad_to_be(&root, 48)];
568 }
569
570 vec![vec![1], pad_to_be(&sqrt, 48)]
571 }
572
573 pub fn hook_bls12_381_inverse(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
577 let field_element = BigUint::from_bytes_be(&buf[..48]);
578
579 assert!(!field_element.is_zero(), "Field element is the additive identity");
581
582 let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
583
584 let inverse = field_element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
586
587 vec![pad_to_be(&inverse, 48)]
588 }
589}
590
591#[must_use]
605pub fn hook_rsa_mul_mod(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
606 assert_eq!(
607 buf.len(),
608 256 + 256 + 256,
609 "rsa_mul_mod input should have length 256 + 256 + 256, this is a bug."
610 );
611
612 let prod: &[u8; 512] = buf[..512].try_into().unwrap();
613 let m: &[u8; 256] = buf[512..].try_into().unwrap();
614
615 let prod = BigUint::from_bytes_le(prod);
616 let m = BigUint::from_bytes_le(m);
617
618 let (q, rem) = prod.div_rem(&m);
619
620 let mut rem = rem.to_bytes_le();
621 rem.resize(256, 0);
622
623 let mut q = q.to_bytes_le();
624 q.resize(256, 0);
625
626 vec![rem, q]
627}
628
629pub(crate) mod deprecated_hooks {
630 use super::HookEnv;
631 use sp1_curves::{
632 k256::{
633 ecdsa::{RecoveryId, Signature, VerifyingKey},
634 elliptic_curve::ops::Invert,
635 },
636 p256::ecdsa::Signature as p256Signature,
637 };
638
639 #[must_use]
655 pub fn hook_ecrecover(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
656 assert_eq!(buf.len(), 65 + 32, "ecrecover input should have length 65 + 32");
657 let (sig, msg_hash) = buf.split_at(65);
658 let sig: &[u8; 65] = sig.try_into().unwrap();
659 let msg_hash: &[u8; 32] = msg_hash.try_into().unwrap();
660
661 let mut recovery_id = sig[64];
662 let mut sig = Signature::from_slice(&sig[..64]).unwrap();
663
664 if let Some(sig_normalized) = sig.normalize_s() {
665 sig = sig_normalized;
666 recovery_id ^= 1;
667 }
668 let recid = RecoveryId::from_byte(recovery_id).expect("Computed recovery ID is invalid!");
669
670 let recovered_key = VerifyingKey::recover_from_prehash(&msg_hash[..], &sig, recid).unwrap();
671 let bytes = recovered_key.to_sec1_bytes();
672
673 let (_, s) = sig.split_scalars();
674 let s_inverse = s.invert();
675
676 vec![bytes.to_vec(), s_inverse.to_bytes().to_vec()]
677 }
678
679 #[must_use]
689 pub fn hook_r1_ecrecover(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
690 assert_eq!(buf.len(), 64, "ecrecover input should have length 64");
691 let sig: &[u8; 64] = buf.try_into().unwrap();
692 let sig = p256Signature::from_slice(sig).unwrap();
693
694 let (_, s) = sig.split_scalars();
695 let s_inverse = s.invert();
696
697 vec![s_inverse.to_bytes().to_vec()]
698 }
699
700 #[must_use]
718 pub fn hook_ecrecover_v2(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
719 assert_eq!(
720 buf.len(),
721 65 + 32,
722 "ecrecover input should have length 65 + 32, this is a bug."
723 );
724 let (sig, msg_hash) = buf.split_at(65);
725 let sig: &[u8; 65] = sig.try_into().unwrap();
726 let msg_hash: &[u8; 32] = msg_hash.try_into().unwrap();
727
728 let mut recovery_id = sig[64];
729 let mut sig = Signature::from_slice(&sig[..64]).unwrap();
730
731 if let Some(sig_normalized) = sig.normalize_s() {
732 sig = sig_normalized;
733 recovery_id ^= 1;
734 }
735 let recid = RecoveryId::from_byte(recovery_id)
736 .expect("Computed recovery ID is invalid, this is a bug.");
737
738 let Ok(recovered_key) = VerifyingKey::recover_from_prehash(&msg_hash[..], &sig, recid)
740 else {
741 return vec![vec![0]];
742 };
743
744 let bytes = recovered_key.to_sec1_bytes();
745
746 let (_, s) = sig.split_scalars();
747 let s_inverse = s.invert();
748
749 vec![vec![1], bytes.to_vec(), s_inverse.to_bytes().to_vec()]
750 }
751
752 #[must_use]
765 pub fn hook_ed_decompress(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
766 let Ok(point) = sp1_curves::curve25519_dalek::CompressedEdwardsY::from_slice(buf) else {
767 return vec![vec![0]];
768 };
769
770 if sp1_curves::edwards::ed25519::decompress(&point).is_some() {
771 vec![vec![1]]
772 } else {
773 vec![vec![0]]
774 }
775 }
776}
777
778fn pad_to_be(val: &BigUint, len: usize) -> Vec<u8> {
780 let mut bytes = val.to_bytes_le();
782 bytes.resize(len, 0);
784 bytes.reverse();
786
787 bytes
788}
789
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use super::*;

    /// The default registry ships with built-in hooks, and its `Debug` output must not panic.
    #[test]
    pub fn registry_new_is_inhabited() {
        let registry = HookRegistry::new();
        assert_ne!(registry.table.len(), 0);
        println!("{registry:?}");
    }

    /// `HookRegistry::empty` registers nothing.
    #[test]
    pub fn registry_empty_is_empty() {
        let registry = HookRegistry::empty();
        assert_eq!(registry.table.len(), 0);
    }
}
806}