1use core::fmt::Debug;
2
3use std::sync::{Arc, RwLock, RwLockWriteGuard};
4
5use hashbrown::HashMap;
6use sp1_curves::{
7 edwards::ed25519::ed25519_sqrt, params::FieldParameters, BigUint, Integer, One, Zero,
8};
9
/// A shared, lockable, type-erased [`Hook`].
///
/// Hooks live behind `Arc<RwLock<..>>`: cloning the handle is cheap, and a caller obtains
/// `&mut` access to the hook by taking the write lock (see [`HookRegistry::get`]).
pub type BoxedHook<'a> = Arc<RwLock<dyn Hook + Send + Sync + 'a>>;
12
13pub use sp1_primitives::consts::fd::*;
14
/// A callback that can be registered in a [`HookRegistry`] under a `u32` fd number.
pub trait Hook {
    /// Runs the hook on the input bytes `buf`.
    ///
    /// Returns a list of byte vectors; each inner vector is one separate output of the hook.
    fn invoke_hook(&mut self, env: HookEnv, buf: &[u8]) -> Vec<Vec<u8>>;
}
22
/// Every `FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>>` (closure or plain function) is a [`Hook`].
impl<F: FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>>> Hook for F {
    /// Forwards the invocation directly to the wrapped callable.
    fn invoke_hook(&mut self, env: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
        self(env, buf)
    }
}
29
30pub fn hookify<'a>(
34 f: impl FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>> + Send + Sync + 'a,
35) -> BoxedHook<'a> {
36 Arc::new(RwLock::new(f))
37}
38
/// A table of hooks keyed by fd number.
///
/// Cloning the registry clones the `Arc` handles, so clones share the same hook instances.
#[derive(Clone)]
pub struct HookRegistry<'a> {
    /// Maps an fd number to the hook registered for it.
    pub(crate) table: HashMap<u32, BoxedHook<'a>>,
}
46
47impl<'a> HookRegistry<'a> {
48 #[must_use]
50 pub fn new() -> Self {
51 HookRegistry::default()
52 }
53
54 #[must_use]
56 pub fn empty() -> Self {
57 Self { table: HashMap::default() }
58 }
59
60 #[must_use]
65 pub fn get(&self, fd: u32) -> Option<RwLockWriteGuard<'_, dyn Hook + Send + Sync + 'a>> {
66 self.table.get(&fd).map(|x| x.write().unwrap())
68 }
69}
70
71impl Default for HookRegistry<'_> {
72 fn default() -> Self {
73 let table = HashMap::from([
75 (FD_ECRECOVER_HOOK, hookify(hook_ecrecover)),
78 (FD_EDDECOMPRESS, hookify(hook_ed_decompress)),
79 (FD_RSA_MUL_MOD, hookify(hook_rsa_mul_mod)),
80 (FD_BLS12_381_SQRT, hookify(bls::hook_bls12_381_sqrt)),
81 (FD_BLS12_381_INVERSE, hookify(bls::hook_bls12_381_inverse)),
82 (FD_FP_SQRT, hookify(fp_ops::hook_fp_sqrt)),
83 (FD_FP_INV, hookify(fp_ops::hook_fp_inverse)),
84 ]);
85
86 Self { table }
87 }
88}
89
90impl Debug for HookRegistry<'_> {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 let mut keys = self.table.keys().collect::<Vec<_>>();
93 keys.sort_unstable();
94 f.debug_struct("HookRegistry")
95 .field(
96 "table",
97 &format_args!("{{{} hooks registered at {:?}}}", self.table.len(), keys),
98 )
99 .finish()
100 }
101}
102
/// Environment value passed to every hook invocation.
///
/// It currently carries no data. NOTE(review): presumably kept as an extension point for
/// future context; confirm before relying on that.
pub struct HookEnv {}
108
109#[must_use]
121pub fn hook_ecrecover(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
122 assert!(buf.len() == 64 + 1, "ecrecover should have length 65");
123
124 let curve_id = buf[0] & 0b0111_1111;
125 let r_is_y_odd = buf[0] & 0b1000_0000 != 0;
126
127 let r_bytes: [u8; 32] = buf[1..33].try_into().unwrap();
128 let alpha_bytes: [u8; 32] = buf[33..65].try_into().unwrap();
129
130 match curve_id {
131 1 => ecrecover::handle_secp256k1(r_bytes, alpha_bytes, r_is_y_odd),
132 2 => ecrecover::handle_secp256r1(r_bytes, alpha_bytes, r_is_y_odd),
133 _ => unimplemented!("Unsupported curve id: {}", curve_id),
134 }
135}
136
/// Curve-specific handlers used by [`hook_ecrecover`].
mod ecrecover {
    use sp1_curves::{k256, p256};

    /// 32-byte encoding whose last byte is `3` (the value 3 in the byte order `from_bytes`
    /// reads), used as the multiplier when `alpha` is not a square.
    ///
    /// NOTE(review): relies on 3 being a quadratic non-residue in both the secp256k1 and the
    /// secp256r1 base fields; otherwise the `expect` calls in the handlers would fire.
    const NQR: [u8; 32] = {
        let mut nqr = [0; 32];
        nqr[31] = 3;
        nqr
    };

    /// secp256k1 handler.
    ///
    /// If `alpha` has a square root in the base field, returns `[[1], y, r_inv]`:
    /// - `y` is the root whose parity matches `r_y_is_odd`;
    /// - `r_inv` is the inverse of `r` reinterpreted as a scalar.
    ///
    /// Otherwise returns `[[0], sqrt(alpha * 3)]`.
    ///
    /// # Panics
    ///
    /// Panics if `r` or `alpha` do not decode as field elements, if `alpha` is zero, if `r` is
    /// not a valid scalar, or if the `r` scalar is zero.
    pub(super) fn handle_secp256k1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<Vec<u8>> {
        use k256::{
            elliptic_curve::ff::PrimeField, FieldElement as K256FieldElement, Scalar as K256Scalar,
        };

        let r = K256FieldElement::from_bytes(&r.into()).unwrap();
        debug_assert!(!bool::from(r.is_zero()), "r should not be zero");

        let alpha = K256FieldElement::from_bytes(&alpha.into()).unwrap();
        assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");

        // The root is normalized before its parity is inspected.
        if let Some(mut y_coord) = alpha.sqrt().into_option().map(|y| y.normalize()) {
            // Re-encode the field element `r` as a scalar so it can be inverted.
            let r = K256Scalar::from_repr(r.to_bytes()).unwrap();
            let r_inv = r.invert().expect("Non zero r scalar");

            // Pick the root with the requested parity; `negate(1)` takes the input's magnitude,
            // and the negated value is normalized again before use.
            if r_y_is_odd != bool::from(y_coord.is_odd()) {
                y_coord = y_coord.negate(1);
                y_coord = y_coord.normalize();
            }

            vec![vec![1], y_coord.to_bytes().to_vec(), r_inv.to_bytes().to_vec()]
        } else {
            // `alpha` is not a square, so `alpha * NQR` is; return its root instead.
            let nqr_field = K256FieldElement::from_bytes(&NQR.into()).unwrap();
            let qr = alpha * nqr_field;
            let root = qr.sqrt().expect("if alpha is not a square, then qr should be a square");

            vec![vec![0], root.to_bytes().to_vec()]
        }
    }

    /// secp256r1 handler with the same output format as [`handle_secp256k1`].
    ///
    /// NOTE(review): unlike the secp256k1 handler, a zero `alpha` is only rejected by a
    /// `debug_assert!`, i.e. not in release builds.
    ///
    /// # Panics
    ///
    /// Panics if `r` or `alpha` do not decode as field elements, if `r` is not a valid scalar,
    /// or if the `r` scalar is zero.
    pub(super) fn handle_secp256r1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<Vec<u8>> {
        use p256::{
            elliptic_curve::ff::PrimeField, FieldElement as P256FieldElement, Scalar as P256Scalar,
        };

        let r = P256FieldElement::from_bytes(&r.into()).unwrap();
        debug_assert!(!bool::from(r.is_zero()), "r should not be zero");

        let alpha = P256FieldElement::from_bytes(&alpha.into()).unwrap();
        debug_assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");

        if let Some(mut y_coord) = alpha.sqrt().into_option() {
            // Re-encode the field element `r` as a scalar so it can be inverted.
            let r = P256Scalar::from_repr(r.to_bytes()).unwrap();
            let r_inv = r.invert().expect("Non zero r scalar");

            // Pick the root with the requested parity.
            if r_y_is_odd != bool::from(y_coord.is_odd()) {
                y_coord = -y_coord;
            }

            vec![vec![1], y_coord.to_bytes().to_vec(), r_inv.to_bytes().to_vec()]
        } else {
            // `alpha` is not a square, so `alpha * NQR` is; return its root instead.
            let nqr_field = P256FieldElement::from_bytes(&NQR.into()).unwrap();
            let qr = alpha * nqr_field;
            let root = qr.sqrt().expect("if alpha is not a square, then qr should be a square");

            vec![vec![0], root.to_bytes().to_vec()]
        }
    }
}
207
208pub mod fp_ops {
210 use super::{pad_to_be, BigUint, HookEnv, One, Zero};
211
212 #[must_use]
228 pub fn hook_fp_inverse(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
229 let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
230
231 assert!(buf.len() == 4 + 2 * len, "FpOp: Invalid buffer length");
232
233 let buf = &buf[4..];
234 let element = BigUint::from_bytes_be(&buf[..len]);
235 let modulus = BigUint::from_bytes_be(&buf[len..2 * len]);
236
237 assert!(!element.is_zero(), "FpOp: Inverse called with zero");
238
239 let inverse = element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
240
241 vec![pad_to_be(&inverse, len)]
242 }
243
244 #[must_use]
272 pub fn hook_fp_sqrt(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
273 let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
274
275 assert!(buf.len() == 4 + 3 * len, "FpOp: Invalid buffer length");
276
277 let buf = &buf[4..];
278 let element = BigUint::from_bytes_be(&buf[..len]);
279 let modulus = BigUint::from_bytes_be(&buf[len..2 * len]);
280 let nqr = BigUint::from_bytes_be(&buf[2 * len..3 * len]);
281
282 assert!(
283 element < modulus,
284 "Element is not less than modulus, the hook only accepts canonical representations"
285 );
286 assert!(
287 nqr < modulus,
288 "NQR is zero or non-canonical, the hook only accepts canonical representations"
289 );
290
291 if element.is_zero() {
293 return vec![vec![1], vec![0; len]];
294 }
295
296 if let Some(root) = sqrt_fp(&element, &modulus, &nqr) {
299 vec![vec![1], pad_to_be(&root, len)]
300 } else {
301 let qr = (&nqr * &element) % &modulus;
302 let root = sqrt_fp(&qr, &modulus, &nqr).unwrap();
303
304 vec![vec![0], pad_to_be(&root, len)]
305 }
306 }
307
308 fn sqrt_fp(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
312 if modulus % BigUint::from(4u64) == BigUint::from(3u64) {
315 let maybe_root =
316 element.modpow(&((modulus + BigUint::from(1u64)) / BigUint::from(4u64)), modulus);
317
318 return Some(maybe_root).filter(|root| root * root % modulus == *element);
319 }
320
321 tonelli_shanks(element, modulus, nqr)
322 }
323
324 #[allow(clippy::many_single_char_names)]
336 fn tonelli_shanks(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
337 if legendre_symbol(element, modulus) != BigUint::one() {
340 return None;
341 }
342
343 let mut s = BigUint::zero();
345 let mut q = modulus - BigUint::one();
346 while &q % &BigUint::from(2u64) == BigUint::zero() {
347 s += BigUint::from(1u64);
348 q /= BigUint::from(2u64);
349 }
350
351 let z = nqr;
352 let mut c = z.modpow(&q, modulus);
353 let mut r = element.modpow(&((&q + BigUint::from(1u64)) / BigUint::from(2u64)), modulus);
354 let mut t = element.modpow(&q, modulus);
355 let mut m = s;
356
357 while t != BigUint::one() {
358 let mut i = BigUint::zero();
359 let mut tt = t.clone();
360 while tt != BigUint::one() {
361 tt = &tt * &tt % modulus;
362 i += BigUint::from(1u64);
363
364 if i == m {
365 return None;
366 }
367 }
368
369 let b_pow =
370 BigUint::from(2u64).pow((&m - &i - BigUint::from(1u64)).try_into().unwrap());
371 let b = c.modpow(&b_pow, modulus);
372
373 r = &r * &b % modulus;
374 c = &b * &b % modulus;
375 t = &t * &c % modulus;
376 m = i;
377 }
378
379 Some(r)
380 }
381
382 fn legendre_symbol(element: &BigUint, modulus: &BigUint) -> BigUint {
388 assert!(!element.is_zero(), "FpOp: Legendre symbol of zero called.");
389
390 element.modpow(&((modulus - BigUint::one()) / BigUint::from(2u64)), modulus)
391 }
392
393 #[cfg(test)]
394 mod test {
395 use super::*;
396 use std::str::FromStr;
397
398 #[test]
399 fn test_legendre_symbol() {
400 let modulus = BigUint::from_str(
402 "115792089237316195423570985008687907853269984665640564039457584007908834671663",
403 )
404 .unwrap();
405 let neg_1 = &modulus - BigUint::one();
406
407 let fixtures = [
408 (BigUint::from(4u64), BigUint::from(1u64)),
409 (BigUint::from(2u64), BigUint::from(1u64)),
410 (BigUint::from(3u64), neg_1.clone()),
411 ];
412
413 for (element, expected) in fixtures {
414 let result = legendre_symbol(&element, &modulus);
415 assert_eq!(result, expected);
416 }
417 }
418
419 #[test]
420 fn test_tonelli_shanks() {
421 let p = BigUint::from_str(
423 "115792089237316195423570985008687907853269984665640564039457584007908834671663",
424 )
425 .unwrap();
426
427 let nqr = BigUint::from_str("3").unwrap();
428
429 let large_element = &p - BigUint::from(u16::MAX);
430 let square = &large_element * &large_element % &p;
431
432 let fixtures = [
433 (BigUint::from(2u64), true),
434 (BigUint::from(3u64), false),
435 (BigUint::from(4u64), true),
436 (square, true),
437 ];
438
439 for (element, expected) in fixtures {
440 let result = tonelli_shanks(&element, &p, &nqr);
441 if expected {
442 assert!(result.is_some());
443
444 let result = result.unwrap();
445 assert!((&result * &result) % &p == element);
446 } else {
447 assert!(result.is_none());
448 }
449 }
450 }
451 }
452}
453
454#[must_use]
468pub fn hook_ed_decompress(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
469 const NQR_CURVE_25519: u8 = 2;
470 let modulus = sp1_curves::edwards::ed25519::Ed25519BaseField::modulus();
471
472 let mut bytes: [u8; 32] = buf[..32].try_into().unwrap();
473 bytes[31] &= 0b0111_1111;
475
476 let y = BigUint::from_bytes_le(&bytes);
478 if y >= modulus {
479 return vec![vec![0]];
480 }
481
482 let v = BigUint::from_bytes_le(&buf[32..]);
483 assert!(v < modulus, "V is not a valid field element");
486
487 let v_inv = v.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
489 let u = (&y * &y + &modulus - BigUint::one()) % &modulus;
490 let u_div_v = (&u * &v_inv) % &modulus;
491
492 if ed25519_sqrt(&u_div_v).is_some() {
495 vec![vec![1]]
496 } else {
497 let qr = (u_div_v * NQR_CURVE_25519) % &modulus;
498 let root = ed25519_sqrt(&qr).unwrap();
499
500 let v_inv_bytes = v_inv.to_bytes_le();
502 let mut v_inv_padded = [0_u8; 32];
503 v_inv_padded[..v_inv_bytes.len()].copy_from_slice(&v_inv.to_bytes_le());
504
505 let root_bytes = root.to_bytes_le();
506 let mut root_padded = [0_u8; 32];
507 root_padded[..root_bytes.len()].copy_from_slice(&root.to_bytes_le());
508
509 vec![vec![0], v_inv_padded.to_vec(), root_padded.to_vec()]
510 }
511}
512
513pub mod bls {
515 use super::{pad_to_be, BigUint, HookEnv};
516 use sp1_curves::{params::FieldParameters, weierstrass::bls12_381::Bls12381BaseField, Zero};
517
518 pub const NQR_BLS12_381: [u8; 48] = {
520 let mut nqr = [0; 48];
521 nqr[47] = 2;
522 nqr
523 };
524
525 pub const BLS12_381_MODULUS: &[u8] = Bls12381BaseField::MODULUS;
527
528 #[must_use]
537 pub fn hook_bls12_381_sqrt(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
538 let field_element = BigUint::from_bytes_be(&buf[..48]);
539
540 if field_element.is_zero() {
543 return vec![vec![1], vec![0; 48]];
544 }
545
546 let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
547
548 let exp = (&modulus + BigUint::from(1u64)) / BigUint::from(4u64);
551 let sqrt = field_element.modpow(&exp, &modulus);
552
553 let square = (&sqrt * &sqrt) % &modulus;
556 if square != field_element {
557 let nqr = BigUint::from_bytes_be(&NQR_BLS12_381);
558 let qr = (&nqr * &field_element) % &modulus;
559
560 let root = qr.modpow(&exp, &modulus);
565
566 assert!((&root * &root) % &modulus == qr, "NQR sanity check failed, this is a bug.");
567
568 return vec![vec![0], pad_to_be(&root, 48)];
569 }
570
571 vec![vec![1], pad_to_be(&sqrt, 48)]
572 }
573
574 #[must_use]
578 pub fn hook_bls12_381_inverse(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
579 let field_element = BigUint::from_bytes_be(&buf[..48]);
580
581 assert!(!field_element.is_zero(), "Field element is the additive identity");
583
584 let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
585
586 let inverse = field_element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
588
589 vec![pad_to_be(&inverse, 48)]
590 }
591}
592
593#[must_use]
607pub fn hook_rsa_mul_mod(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
608 assert!(
609 buf.len() == 256 * 3 || buf.len() == 384 * 3 || buf.len() == 512 * 3,
610 "rsa_mul_mod input should have length key_size * 3, this is a bug."
611 );
612
613 let len = buf.len() / 3;
614 let prod = BigUint::from_bytes_le(&buf[..2 * len]);
615 let m = BigUint::from_bytes_le(&buf[2 * len..]);
616
617 let (q, rem) = prod.div_rem(&m);
618
619 let mut rem = rem.to_bytes_le();
620 rem.resize(len, 0);
621
622 let mut q = q.to_bytes_le();
623 q.resize(len, 0);
624
625 vec![rem, q]
626}
627
628fn pad_to_be(val: &BigUint, len: usize) -> Vec<u8> {
630 let mut bytes = val.to_bytes_le();
632 bytes.resize(len, 0);
634 bytes.reverse();
636
637 bytes
638}
639
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use super::*;

    /// The default registry ships with hooks and can be printed via `Debug`.
    #[test]
    pub fn registry_new_is_inhabited() {
        let registry = HookRegistry::new();
        assert!(!registry.table.is_empty());
        println!("{:?}", registry);
    }

    /// The empty registry has no hooks at all.
    #[test]
    pub fn registry_empty_is_empty() {
        let registry = HookRegistry::empty();
        assert!(registry.table.is_empty());
    }
}