1use ark_bn254::Bn254;
5use ark_ff::AdditiveGroup as _;
6use ark_ff::Field as _;
7use ark_ff::LegendreSymbol;
8use ark_ff::UniformRand as _;
9use ark_serialize::CanonicalDeserialize;
10use ark_serialize::CanonicalSerialize;
11use circom_witness_rs::Graph;
12use groth16::CircomReduction;
13use groth16::Groth16;
14use rand::{CryptoRng, Rng};
15use ruint::aliases::U256;
16use sha2::Digest as _;
17use std::io::Write as _;
18use std::ops::Shr;
19use std::sync::Arc;
20use std::{collections::HashMap, path::Path};
21
22use crate::Groth16Error;
23
24pub use ark_groth16::Proof;
25pub use ark_serialize::Compress;
26pub use ark_serialize::Validate;
27pub use circom_types::groth16::ArkZkey;
28pub use circom_witness_rs::BlackBoxFunction;
29
30pub trait ProofInput {
35 fn prepare_input(&self) -> HashMap<String, Vec<U256>>;
60}
61
62impl ProofInput for HashMap<String, Vec<U256>> {
63 fn prepare_input(&self) -> HashMap<String, Vec<U256>> {
64 self.to_owned()
65 }
66}
67
68#[derive(Debug, thiserror::Error)]
70pub enum ZkeyError {
71 #[error("invalid zkey - wrong sha256 fingerprint: {0}")]
73 ZkeyFingerprintMismatch(String),
74 #[error("invalid graph - wrong sha256 fingerprint: {0}")]
76 GraphFingerprintMismatch(String),
77 #[error("Could not parse zkey - see wrapped error")]
79 ZkeyInvalid(#[source] eyre::Report),
80 #[error(transparent)]
82 GraphInvalid(#[from] eyre::Report),
83 #[error(transparent)]
85 IoError(#[from] std::io::Error),
86}
87
88#[derive(Debug, thiserror::Error)]
90pub enum MaterialSerializationError {
91 #[error("could not serialize zkey - see wrapped error")]
93 ZkeySerialization(#[source] ark_serialize::SerializationError),
94 #[error("could not serialize graph - see wrapped error")]
96 GraphSerialization(#[source] postcard::Error),
97 #[error(transparent)]
99 IoError(#[from] std::io::Error),
100}
101
#[cfg(any(feature = "reqwest", feature = "reqwest-blocking"))]
impl From<reqwest::Error> for ZkeyError {
    /// Wraps an HTTP client error in an [`std::io::Error`] of kind
    /// `InvalidData` and reports it as [`ZkeyError::IoError`].
    fn from(value: reqwest::Error) -> Self {
        let io_error = std::io::Error::new(std::io::ErrorKind::InvalidData, value);
        Self::IoError(io_error)
    }
}
108
109impl From<circom_types::ZkeyParserError> for ZkeyError {
110 fn from(value: circom_types::ZkeyParserError) -> Self {
111 Self::ZkeyInvalid(eyre::eyre!(value))
112 }
113}
114
115impl From<ark_serialize::SerializationError> for ZkeyError {
116 fn from(value: ark_serialize::SerializationError) -> Self {
117 Self::ZkeyInvalid(eyre::eyre!(value))
118 }
119}
120
121#[derive(Clone)]
128pub struct CircomGroth16Material {
129 zkey: ArkZkey<Bn254>,
130 graph: Graph,
132 bbfs: HashMap<String, BlackBoxFunction>,
134}
135
136pub struct CircomGroth16MaterialBuilder {
152 compress: Compress,
153 validate: Validate,
154 fingerprint_zkey: Option<String>,
155 fingerprint_graph: Option<String>,
156 bbfs: HashMap<String, BlackBoxFunction>,
157}
158
159pub struct CircomGroth16MaterialSerializer<'a> {
161 material: &'a CircomGroth16Material,
162 compress: Compress,
163}
164
165impl Default for CircomGroth16MaterialBuilder {
166 fn default() -> Self {
167 Self {
168 compress: Compress::No,
169 validate: Validate::Yes,
170 fingerprint_zkey: None,
171 fingerprint_graph: None,
172 bbfs: HashMap::default(),
173 }
174 }
175}
176
177impl CircomGroth16MaterialBuilder {
178 pub fn new() -> Self {
187 Self::default()
188 }
189
190 pub fn compress(mut self, compress: Compress) -> Self {
192 self.compress = compress;
193 self
194 }
195
196 pub fn validate(mut self, validate: Validate) -> Self {
198 self.validate = validate;
199 self
200 }
201
202 pub fn fingerprint_zkey(mut self, fingerprint_zkey: String) -> Self {
204 self.fingerprint_zkey = Some(fingerprint_zkey);
205 self
206 }
207
208 pub fn fingerprint_graph(mut self, fingerprint_graph: String) -> Self {
210 self.fingerprint_graph = Some(fingerprint_graph);
211 self
212 }
213
214 pub fn add_bbfs(mut self, bbfs: HashMap<String, BlackBoxFunction>) -> Self {
216 self.bbfs.extend(bbfs);
217 self
218 }
219
220 pub fn bbf_inv(mut self) -> Self {
229 self.bbfs.insert(
230 "bbf_inv".to_string(),
231 Arc::new(move |args: &[ark_bn254::Fr]| -> ark_bn254::Fr {
232 args[0].inverse().unwrap_or(ark_bn254::Fr::ZERO)
236 }),
237 );
238
239 self
240 }
241
242 pub fn bbf_legendre(mut self) -> Self {
251 self.bbfs.insert(
252 "bbf_legendre".to_string(),
253 Arc::new(move |args: &[ark_bn254::Fr]| -> ark_bn254::Fr {
254 match args[0].legendre() {
255 LegendreSymbol::Zero => ark_bn254::Fr::from(0u64),
256 LegendreSymbol::QuadraticResidue => ark_bn254::Fr::from(1u64),
257 LegendreSymbol::QuadraticNonResidue => -ark_bn254::Fr::from(1u64),
258 }
259 }),
260 );
261
262 self
263 }
264
265 pub fn bbf_sqrt_unchecked(mut self) -> Self {
267 self.bbfs.insert(
268 "bbf_sqrt_unchecked".to_string(),
269 Arc::new(move |args: &[ark_bn254::Fr]| -> ark_bn254::Fr {
270 args[0].sqrt().unwrap_or(ark_bn254::Fr::ZERO)
271 }),
272 );
273 self
274 }
275
276 pub fn bbf_sqrt_input(mut self) -> Self {
289 self.bbfs.insert(
290 "bbf_sqrt_input".to_string(),
291 Arc::new(move |args: &[ark_bn254::Fr]| -> ark_bn254::Fr {
292 if args[0] != -ark_bn254::Fr::ONE {
300 args[1]
301 } else {
302 args[2]
303 }
304 }),
305 );
306 self
307 }
308
309 pub fn bbf_num_2_bits_helper(mut self) -> Self {
318 self.bbfs.insert(
319 "bbf_num_2_bits_helper".to_string(),
320 Arc::new(move |args: &[ark_bn254::Fr]| -> ark_bn254::Fr {
321 let a: U256 = args[0].into();
325 let b: U256 = args[1].into();
326 let ls_limb = b.as_limbs()[0];
327 ark_bn254::Fr::new((a.shr(ls_limb as usize) & U256::from(1)).into())
328 }),
329 );
330 self
331 }
332
333 pub fn build_from_paths(
335 self,
336 zkey_path: impl AsRef<Path>,
337 graph_path: impl AsRef<Path>,
338 ) -> Result<CircomGroth16Material, ZkeyError> {
339 let zkey_bytes = std::fs::read(zkey_path)?;
340 let graph_bytes = std::fs::read(graph_path)?;
341 self.build_from_bytes(&zkey_bytes, &graph_bytes)
342 }
343
344 pub fn build_from_reader(
346 self,
347 mut zkey_reader: impl std::io::Read,
348 mut graph_reader: impl std::io::Read,
349 ) -> Result<CircomGroth16Material, ZkeyError> {
350 let mut zkey_bytes = Vec::new();
351 zkey_reader.read_to_end(&mut zkey_bytes)?;
352 let mut graph_bytes = Vec::new();
353 graph_reader.read_to_end(&mut graph_bytes)?;
354 self.build_from_bytes(&zkey_bytes, &graph_bytes)
355 }
356
357 pub fn build_from_bytes(
359 self,
360 zkey_bytes: &[u8],
361 graph_bytes: &[u8],
362 ) -> Result<CircomGroth16Material, ZkeyError> {
363 let validate = if let Some(should_fingerprint) = self.fingerprint_zkey {
364 let is_fingerprint = hex::encode(sha2::Sha256::digest(zkey_bytes));
365 if is_fingerprint != should_fingerprint {
366 return Err(ZkeyError::ZkeyFingerprintMismatch(is_fingerprint));
367 }
368 Validate::No
369 } else {
370 self.validate
371 };
372
373 let zkey = circom_types::groth16::ArkZkey::deserialize_with_mode(
374 zkey_bytes,
375 self.compress,
376 validate,
377 )?;
378 if let Some(should_fingerprint) = self.fingerprint_graph {
379 let is_fingerprint = hex::encode(sha2::Sha256::digest(graph_bytes));
380 if is_fingerprint != should_fingerprint {
381 return Err(ZkeyError::GraphFingerprintMismatch(is_fingerprint));
382 }
383 }
384 let graph = circom_witness_rs::init_graph(graph_bytes).map_err(ZkeyError::GraphInvalid)?;
385 Ok(CircomGroth16Material {
386 zkey,
387 graph,
388 bbfs: self.bbfs,
389 })
390 }
391
392 #[cfg(feature = "reqwest")]
394 pub async fn build_from_urls(
395 self,
396 zkey_url: impl reqwest::IntoUrl,
397 graph_url: impl reqwest::IntoUrl,
398 ) -> Result<CircomGroth16Material, ZkeyError> {
399 let zkey_bytes = reqwest::get(zkey_url).await?.bytes().await?;
400 let graph_bytes = reqwest::get(graph_url).await?.bytes().await?;
401 self.build_from_bytes(&zkey_bytes, &graph_bytes)
402 }
403
404 #[cfg(feature = "reqwest-blocking")]
406 pub fn build_from_urls_blocking(
407 self,
408 zkey_url: impl reqwest::IntoUrl,
409 graph_url: impl reqwest::IntoUrl,
410 ) -> Result<CircomGroth16Material, ZkeyError> {
411 let zkey_bytes = reqwest::blocking::get(zkey_url)?.bytes()?;
412 let graph_bytes = reqwest::blocking::get(graph_url)?.bytes()?;
413 self.build_from_bytes(&zkey_bytes, &graph_bytes)
414 }
415}
416
417impl CircomGroth16Material {
418 pub fn serializer(&self) -> CircomGroth16MaterialSerializer<'_> {
434 CircomGroth16MaterialSerializer {
435 material: self,
436 compress: Compress::No,
437 }
438 }
439
440 pub fn zkey(&self) -> &ArkZkey<Bn254> {
442 &self.zkey
443 }
444
445 pub fn generate_witness(
447 &self,
448 inputs: &impl ProofInput,
449 ) -> Result<Vec<ark_bn254::Fr>, Groth16Error> {
450 let witness = circom_witness_rs::calculate_witness(
451 inputs.prepare_input(),
452 &self.graph,
453 Some(&self.bbfs),
454 )
455 .map_err(Groth16Error::WitnessGeneration)?
456 .into_iter()
457 .map(|v| ark_bn254::Fr::from(ark_ff::BigInt(v.into_limbs())))
458 .collect::<Vec<_>>();
459 Ok(witness)
460 }
461
462 pub fn generate_proof_from_witness<R: Rng + CryptoRng>(
466 &self,
467 witness: &[ark_bn254::Fr],
468 rng: &mut R,
469 ) -> Result<(Proof<Bn254>, Vec<ark_bn254::Fr>), Groth16Error> {
470 let r = ark_bn254::Fr::rand(rng);
471 let s = ark_bn254::Fr::rand(rng);
472
473 let (matrices, pk) = self.zkey.as_inner();
474 let proof = Groth16::prove::<CircomReduction>(pk, r, s, matrices, witness)
475 .map_err(Groth16Error::ProofGeneration)?;
476
477 let inputs = witness[1..matrices.num_instance_variables].to_vec();
478 Ok((proof, inputs))
479 }
480
481 pub fn generate_proof<R: Rng + CryptoRng>(
485 &self,
486 inputs: &impl ProofInput,
487 rng: &mut R,
488 ) -> Result<(Proof<Bn254>, Vec<ark_bn254::Fr>), Groth16Error> {
489 let witness = self.generate_witness(inputs)?;
490 self.generate_proof_from_witness(&witness, rng)
491 }
492
493 pub fn verify_proof(
495 &self,
496 proof: &Proof<Bn254>,
497 public_inputs: &[ark_bn254::Fr],
498 ) -> Result<(), Groth16Error> {
499 Groth16::verify(&self.zkey.pk.vk, proof, public_inputs)
500 .map_err(|_| Groth16Error::InvalidProof)
501 }
502}
503
504impl<'a> CircomGroth16MaterialSerializer<'a> {
505 pub fn compress(mut self, compress: Compress) -> Self {
507 self.compress = compress;
508 self
509 }
510
511 pub fn to_bytes(self) -> Result<(Vec<u8>, Vec<u8>), MaterialSerializationError> {
526 let mut zkey_bytes = Vec::new();
527 let mut graph_bytes = Vec::new();
528 self.to_writer(&mut zkey_bytes, &mut graph_bytes)?;
529 Ok((zkey_bytes, graph_bytes))
530 }
531
532 pub fn to_writer(
534 self,
535 mut zkey_writer: impl std::io::Write,
536 mut graph_writer: impl std::io::Write,
537 ) -> Result<(), MaterialSerializationError> {
538 self.material
539 .zkey
540 .serialize_with_mode(&mut zkey_writer, self.compress)
541 .map_err(MaterialSerializationError::ZkeySerialization)?;
542 postcard::to_io(
543 &(
544 &self.material.graph.nodes,
545 &self.material.graph.signals,
546 &self.material.graph.input_mapping,
547 ),
548 &mut graph_writer,
549 )
550 .map_err(MaterialSerializationError::GraphSerialization)?;
551 Ok(())
552 }
553
554 pub fn to_paths(
556 self,
557 zkey_path: impl AsRef<Path>,
558 graph_path: impl AsRef<Path>,
559 ) -> Result<(), MaterialSerializationError> {
560 let zkey_file = std::fs::File::create(zkey_path)?;
561 let graph_file = std::fs::File::create(graph_path)?;
562 let mut zkey_writer = std::io::BufWriter::new(zkey_file);
563 let mut graph_writer = std::io::BufWriter::new(graph_file);
564 self.to_writer(&mut zkey_writer, &mut graph_writer)?;
565 zkey_writer.flush()?;
566 graph_writer.flush()?;
567 Ok(())
568 }
569}