1#[cfg(feature = "serde")]
2use crate::error::SemaphoreError;
3use crate::{
4 MAX_TREE_DEPTH, MIN_TREE_DEPTH,
5 group::{EMPTY_ELEMENT, Element, Group, MerkleProof},
6 identity::Identity,
7 utils::{download_zkey, hash, to_big_uint, to_element},
8 witness::dispatch_witness,
9};
10use anyhow::{Ok, Result, bail};
11use circom_prover::{
12 CircomProver,
13 prover::{
14 CircomProof, ProofLib, PublicInputs,
15 circom::{self, CURVE_BN254, G1, G2, PROTOCOL_GROTH16},
16 },
17 witness::WitnessFn,
18};
19use num_bigint::BigUint;
20use num_traits::{Zero, identities::One};
21use std::{collections::HashMap, str::FromStr};
22
/// A Groth16 proof packed into eight integers, in the order produced by
/// [`Proof::pack_groth16_proof`]:
/// `[a.x, a.y, b.x[1], b.x[0], b.y[1], b.y[0], c.x, c.y]`.
pub type PackedGroth16Proof = [BigUint; 8];

/// Where the Merkle inclusion proof for the prover's identity comes from.
pub enum GroupOrMerkleProof {
    /// A full group: the inclusion proof for the identity commitment is generated from it.
    Group(Group),
    /// A precomputed inclusion proof, used as-is (it is not checked against the identity here).
    MerkleProof(MerkleProof),
}
29
30impl GroupOrMerkleProof {
31 fn merkle_proof(&self, leaf: &Element) -> MerkleProof {
32 match self {
33 GroupOrMerkleProof::Group(group) => {
34 let idx = group.index_of(*leaf).expect("The identity does not exist");
35 group.generate_proof(idx).unwrap()
36 }
37 GroupOrMerkleProof::MerkleProof(proof) => proof.clone(),
38 }
39 }
40}
41
42#[derive(Debug, Clone, PartialEq)]
43pub struct SemaphoreProof {
44 pub merkle_tree_depth: u16,
45 pub merkle_tree_root: BigUint,
46 pub message: BigUint,
47 pub nullifier: BigUint,
48 pub scope: BigUint,
49 pub points: PackedGroth16Proof,
50}
51
#[cfg(feature = "serde")]
impl SemaphoreProof {
    /// Serializes the proof to a JSON object string.
    ///
    /// `merkle_tree_depth` is written as a JSON number; every other numeric field is written
    /// as a decimal string, and `points` as an array of eight decimal strings.
    ///
    /// # Errors
    ///
    /// Returns [`SemaphoreError::SerializationError`] if JSON serialization fails.
    pub fn export(&self) -> Result<String, SemaphoreError> {
        let mut json = serde_json::Map::new();
        json.insert(
            "merkle_tree_depth".to_string(),
            self.merkle_tree_depth.into(),
        );
        json.insert(
            "merkle_tree_root".to_string(),
            self.merkle_tree_root.to_string().into(),
        );
        json.insert("message".to_string(), self.message.to_string().into());
        json.insert("nullifier".to_string(), self.nullifier.to_string().into());
        json.insert("scope".to_string(), self.scope.to_string().into());
        json.insert(
            "points".to_string(),
            self.points
                .iter()
                .map(|p| p.to_string())
                .collect::<Vec<String>>()
                .into(),
        );
        serde_json::to_string(&json).map_err(|e| SemaphoreError::SerializationError(e.to_string()))
    }

    /// Parses a proof in the format produced by [`SemaphoreProof::export`].
    ///
    /// # Errors
    ///
    /// Returns [`SemaphoreError::SerializationError`] (instead of panicking) if `json` is not
    /// a JSON object, a required field is missing or has the wrong JSON type, a decimal string
    /// does not parse as an unsigned integer, `merkle_tree_depth` does not fit in a `u16`, or
    /// `points` does not contain exactly 8 elements.
    pub fn import(json: &str) -> Result<Self, SemaphoreError> {
        type JsonObject = serde_json::Map<String, serde_json::Value>;

        // Builds the error returned for every malformed-input case.
        fn invalid(message: String) -> SemaphoreError {
            SemaphoreError::SerializationError(message)
        }

        // Looks up a required field of the top-level object.
        fn field<'a>(
            object: &'a JsonObject,
            key: &str,
        ) -> Result<&'a serde_json::Value, SemaphoreError> {
            object
                .get(key)
                .ok_or_else(|| invalid(format!("missing field `{key}`")))
        }

        // Parses a JSON string value holding a decimal unsigned integer.
        fn decimal(value: &serde_json::Value, key: &str) -> Result<BigUint, SemaphoreError> {
            let text = value
                .as_str()
                .ok_or_else(|| invalid(format!("field `{key}` must be a string")))?;
            BigUint::from_str(text).map_err(|e| {
                invalid(format!("field `{key}` is not a valid unsigned integer: {e}"))
            })
        }

        // Reads a required field holding a decimal-string integer.
        fn decimal_field(object: &JsonObject, key: &str) -> Result<BigUint, SemaphoreError> {
            decimal(field(object, key)?, key)
        }

        // Reads `points`, which must be an array of exactly 8 decimal strings.
        fn parse_points(object: &JsonObject) -> Result<PackedGroth16Proof, SemaphoreError> {
            let values = field(object, "points")?
                .as_array()
                .ok_or_else(|| invalid("field `points` must be an array".to_string()))?;
            let parsed = values
                .iter()
                .map(|value| decimal(value, "points"))
                .collect::<Result<Vec<BigUint>, SemaphoreError>>()?;
            parsed.try_into().map_err(|rest: Vec<BigUint>| {
                invalid(format!(
                    "field `points` must contain exactly 8 elements, got {}",
                    rest.len()
                ))
            })
        }

        let object: JsonObject =
            serde_json::from_str(json).map_err(|e| invalid(e.to_string()))?;

        let depth = field(&object, "merkle_tree_depth")?
            .as_u64()
            .ok_or_else(|| {
                invalid("field `merkle_tree_depth` must be a non-negative integer".to_string())
            })?;
        // Reject rather than truncate values that do not fit in a u16.
        let merkle_tree_depth = u16::try_from(depth).map_err(|_| {
            invalid(format!("field `merkle_tree_depth` is out of range: {depth}"))
        })?;

        // `Ok` in this module is `anyhow::Ok` (see the `use anyhow::...` import), which builds
        // an `anyhow::Result`; name the std variant explicitly instead.
        std::result::Result::Ok(SemaphoreProof {
            merkle_tree_depth,
            merkle_tree_root: decimal_field(&object, "merkle_tree_root")?,
            message: decimal_field(&object, "message")?,
            nullifier: decimal_field(&object, "nullifier")?,
            scope: decimal_field(&object, "scope")?,
            points: parse_points(&object)?,
        })
    }
}
105
106pub struct Proof {}
107
108impl Proof {
109 pub fn generate_proof(
110 identity: Identity,
111 group: GroupOrMerkleProof,
112 message: String,
113 scope: String,
114 merkle_tree_depth: u16,
115 ) -> Result<SemaphoreProof> {
116 if !(MIN_TREE_DEPTH..=MAX_TREE_DEPTH).contains(&merkle_tree_depth) {
118 bail!(format!(
119 "The tree depth must be a number between {} and {}",
120 MIN_TREE_DEPTH, MAX_TREE_DEPTH
121 ));
122 }
123
124 let merkle_proof = group.merkle_proof(&to_element(*identity.commitment()));
125 let merkle_proof_length = merkle_proof.siblings.len();
126
127 let mut merkle_proof_siblings = Vec::<Element>::new();
128 for i in 0..merkle_tree_depth {
129 if let Some(sibling) = merkle_proof.siblings.get(i as usize) {
130 merkle_proof_siblings.push(*sibling);
131 } else {
132 merkle_proof_siblings.push(EMPTY_ELEMENT);
133 }
134 }
135
136 let scope_uint = to_big_uint(&scope);
137 let message_uint = to_big_uint(&message);
138 let inputs = HashMap::from([
139 (
140 "secret".to_string(),
141 vec![identity.secret_scalar().to_string()],
142 ),
143 (
144 "merkleProofLength".to_string(),
145 vec![merkle_proof_length.to_string()],
146 ),
147 (
148 "merkleProofIndex".to_string(),
149 vec![merkle_proof.index.to_string()],
150 ),
151 (
152 "merkleProofSiblings".to_string(),
153 merkle_proof_siblings
154 .iter()
155 .map(|s| BigUint::from_bytes_le(s.to_vec().as_ref()).to_string())
156 .collect(),
157 ),
158 ("scope".to_string(), vec![hash(scope_uint.clone())]),
159 ("message".to_string(), vec![hash(message_uint.clone())]),
160 ]);
161
162 let zkey_path = download_zkey(merkle_tree_depth).expect("Failed to download zkey");
163 let witness_fn = dispatch_witness(merkle_tree_depth);
164
165 let circom_proof = CircomProver::prove(
166 ProofLib::Arkworks,
167 WitnessFn::CircomWitnessCalc(witness_fn),
168 serde_json::to_string(&inputs).unwrap(),
169 zkey_path,
170 )?;
171
172 Ok(SemaphoreProof {
173 merkle_tree_depth,
174 merkle_tree_root: BigUint::from_bytes_le(merkle_proof.root.as_ref()),
175 message: message_uint,
176 nullifier: circom_proof.pub_inputs.0.get(1).unwrap().clone(),
177 scope: scope_uint,
178 points: Self::pack_groth16_proof(circom_proof.proof),
179 })
180 }
181
182 pub fn verify_proof(proof: SemaphoreProof) -> bool {
183 if proof.merkle_tree_depth < MIN_TREE_DEPTH || proof.merkle_tree_depth > MAX_TREE_DEPTH {
185 panic!("The tree depth must be a number between and");
186 }
187
188 let scope = BigUint::from_str(hash(proof.scope).as_str()).unwrap();
189 let message = BigUint::from_str(hash(proof.message).as_str()).unwrap();
190 let pub_inputs = PublicInputs(vec![
191 proof.merkle_tree_root,
192 proof.nullifier,
193 message,
194 scope,
195 ]);
196 let p = CircomProof {
197 proof: Self::unpack_groth16_proof(proof.points),
198 pub_inputs,
199 };
200
201 let zkey_path = download_zkey(proof.merkle_tree_depth).expect("Failed to download zkey");
202 CircomProver::verify(ProofLib::Arkworks, p, zkey_path).unwrap()
203 }
204
205 pub fn pack_groth16_proof(p: circom::Proof) -> PackedGroth16Proof {
206 [
207 p.a.x,
208 p.a.y,
209 p.b.x[1].clone(),
210 p.b.x[0].clone(),
211 p.b.y[1].clone(),
212 p.b.y[0].clone(),
213 p.c.x,
214 p.c.y,
215 ]
216 }
217
218 pub fn unpack_groth16_proof(packed: PackedGroth16Proof) -> circom::Proof {
219 let a = G1 {
220 x: packed[0].clone(),
221 y: packed[1].clone(),
222 z: BigUint::one(),
223 };
224 let b = G2 {
225 x: [packed[3].clone(), packed[2].clone()],
226 y: [packed[5].clone(), packed[4].clone()],
227 z: [BigUint::one(), BigUint::zero()],
228 };
229 let c = G1 {
230 x: packed[6].clone(),
231 y: packed[7].clone(),
232 z: BigUint::one(),
233 };
234
235 circom::Proof {
236 a,
237 b,
238 c,
239 protocol: PROTOCOL_GROTH16.to_string(),
240 curve: CURVE_BN254.to_string(),
241 }
242 }
243}
244
/// Unit tests for proof generation, verification and (with the `serde` feature) the JSON
/// export/import round-trip.
///
/// NOTE(review): most tests call `Proof::generate_proof` / `Proof::verify_proof`, which call
/// `download_zkey`; this presumably fetches zkey artifacts over the network — TODO confirm
/// whether these tests require network access.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        group::{Element, Group},
        identity::Identity,
        proof::SemaphoreProof,
    };
    use num_bigint::BigUint;
    use std::str::FromStr;

    // Circuit depth used by most tests (cast to u16 at each call site).
    const TREE_DEPTH: usize = 10;
    const MESSAGE: &str = "Hello world";
    const SCOPE: &str = "Scope";

    // Filler group members that are not the test identity.
    const MEMBER1: Element = [1; 32];
    const MEMBER2: Element = [2; 32];

    // Tests for `Proof::generate_proof`.
    // (This `#[cfg(test)]` is redundant: the enclosing module is already test-only.)
    #[cfg(test)]
    mod gen_proof {
        use super::*;
        use std::panic::{self, AssertUnwindSafe};

        // The proof records the group's root and the unhashed message/scope.
        #[test]
        fn test_proof() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();
            let root = group.root().unwrap();

            let proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::Group(group),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();

            assert_eq!(proof.merkle_tree_root, BigUint::from_bytes_le(&root));
            assert_eq!(proof.message, to_big_uint(&MESSAGE.to_string()));
            assert_eq!(proof.scope, to_big_uint(&SCOPE.to_string()));
        }

        // A single-member group.
        #[test]
        fn test_proof_1_member() {
            let identity = Identity::new("secret".as_bytes());
            let group = Group::new(&[to_element(*identity.commitment())]).unwrap();
            let root = group.root().unwrap();

            let proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::Group(group),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();

            assert_eq!(proof.merkle_tree_root, BigUint::from_bytes_le(&root));
        }

        // Passing a precomputed Merkle proof; index 2 is presumably the identity's position,
        // as it is the third member given to `Group::new`.
        #[test]
        fn test_proof_with_semaphore_proof() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();
            let root = group.root().unwrap();

            let proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::MerkleProof(group.generate_proof(2).unwrap()),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();

            assert_eq!(proof.merkle_tree_root, BigUint::from_bytes_le(&root));
        }

        // An out-of-range depth is an error, not a panic. The message is only compared when the
        // error downcasts to `String`; otherwise that check is skipped.
        #[test]
        fn test_error_invalid_tree_depth() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

            let result = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::Group(group),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                33u16,
            );

            assert!(result.is_err());
            if let Err(err) = result {
                if let Some(msg) = err.downcast_ref::<String>() {
                    assert_eq!(msg, "The tree depth must be a number between 1 and 32");
                }
            }
        }

        // Proving for an identity that is not in the group panics. The message is only
        // compared when the panic payload is a `String`.
        #[test]
        fn test_panic_id_not_in_group() {
            let identity = Identity::new("secret".as_bytes());
            let group = Group::new(&[MEMBER1, MEMBER2]).unwrap();

            let err = panic::catch_unwind(AssertUnwindSafe(|| {
                Proof::generate_proof(
                    identity,
                    GroupOrMerkleProof::Group(group),
                    MESSAGE.to_string(),
                    SCOPE.to_string(),
                    TREE_DEPTH as u16,
                )
                .unwrap()
            }));

            assert!(err.is_err());
            if let Err(err) = err {
                if let Some(msg) = err.downcast_ref::<String>() {
                    assert_eq!(msg, "The identity does not exist");
                }
            }
        }

        // A message longer than 32 bytes panics during conversion.
        #[test]
        fn test_panic_message_over_32bytes() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

            let err = panic::catch_unwind(AssertUnwindSafe(|| {
                Proof::generate_proof(
                    identity,
                    GroupOrMerkleProof::Group(group),
                    "This message is over 32 bytes long!!".to_string(),
                    SCOPE.to_string(),
                    TREE_DEPTH as u16,
                )
                .unwrap()
            }));

            assert!(err.is_err());
            if let Err(err) = err {
                if let Some(msg) = err.downcast_ref::<String>() {
                    assert_eq!(msg, "BigUint too large: exceeds 32 bytes");
                }
            }
        }

        // A scope longer than 32 bytes panics during conversion.
        #[test]
        fn test_panic_scope_over_32bytes() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

            let err = panic::catch_unwind(AssertUnwindSafe(|| {
                Proof::generate_proof(
                    identity,
                    GroupOrMerkleProof::Group(group),
                    MESSAGE.to_string(),
                    "This scope is over 32 bytes long!!".to_string(),
                    TREE_DEPTH as u16,
                )
                .unwrap()
            }));

            assert!(err.is_err());
            if let Err(err) = err {
                if let Some(msg) = err.downcast_ref::<String>() {
                    assert_eq!(msg, "BigUint too large: exceeds 32 bytes");
                }
            }
        }
    }

    // Tests for `Proof::verify_proof` (and the serde round-trip).
    // (This `#[cfg(test)]` is redundant: the enclosing module is already test-only.)
    #[cfg(test)]
    mod verify_proof {
        use super::*;
        use std::panic::{self, AssertUnwindSafe};

        // A freshly generated proof verifies.
        #[test]
        fn test_verify_proof() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

            let proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::Group(group),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();

            assert!(Proof::verify_proof(proof))
        }

        // Generates and verifies a proof at every supported depth (one zkey per depth).
        #[test]
        fn test_verify_proof_with_different_depth() {
            for depth in MIN_TREE_DEPTH..=MAX_TREE_DEPTH {
                let identity = Identity::new("secret".as_bytes());
                let group =
                    Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

                let proof = Proof::generate_proof(
                    identity,
                    GroupOrMerkleProof::Group(group),
                    MESSAGE.to_string(),
                    SCOPE.to_string(),
                    depth as u16,
                )
                .unwrap();

                assert!(Proof::verify_proof(proof));
            }
        }

        // Verifying a proof whose depth was tampered out of range panics. The message is only
        // compared when the panic payload is a `String`.
        #[test]
        fn test_panic_verify_invalid_tree_depth() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

            let mut proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::Group(group),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();
            proof.merkle_tree_depth = 40;

            let err = panic::catch_unwind(AssertUnwindSafe(|| Proof::verify_proof(proof)));
            assert!(err.is_err());
            if let Err(err) = err {
                if let Some(msg) = err.downcast_ref::<String>() {
                    assert_eq!(msg, "The tree depth must be a number between 1 and 32");
                }
            }
        }

        // The Merkle proof for index 0 (MEMBER1's leaf) is used with `identity`, so the proof
        // does not match the identity and verification must return false.
        #[test]
        fn test_error_verify_invalid_proof() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

            let proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::MerkleProof(group.generate_proof(0).unwrap()),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();

            assert_eq!(Proof::verify_proof(proof), false)
        }

        // Verifies a fixed, externally produced proof vector (presumably generated by the
        // Semaphore JS implementation, per the test name — TODO confirm provenance).
        #[test]
        fn test_semaphore_js_proof() {
            let points = [
                "2448901300518098096993075752654536134313649038239216706400667219963346227679",
                "11383357624181217239434984412545229801919536849542936327488167664579097021171",
                "4740704242184999702574958393302343834384154042177684026319208048433986938524",
                "2103898499672759617084297744151588687300569178309824227315704845907524437637",
                "18126651739688030584140960766793516019865850111238360168731489534891060767936",
                "13293264290162772264887787723520088518667325866686508255341288441681546077334",
                "13860303418198054644271827809984867757526756615344099647083475463061491185143",
                "7750331146056656453454308267328134694500438800080743301030181391570997944788",
            ]
            .iter()
            .map(|&p| BigUint::from_str(p).unwrap())
            .collect::<Vec<BigUint>>()
            .try_into()
            .expect("Expected exactly 8 elements");

            let proof = SemaphoreProof {
                merkle_tree_depth: 10,
                merkle_tree_root: BigUint::from_str(
                    "4990292586352433503726012711155167179034286198473030768981544541070532815155",
                )
                .unwrap(),
                nullifier: BigUint::from_str(
                    "17540473064543782218297133630279824063352907908315494138425986188962403570231",
                )
                .unwrap(),
                message: BigUint::from_str(
                    "32745724963520510550185023804391900974863477733501474067656557556163468591104",
                )
                .unwrap(),
                scope: BigUint::from_str(
                    "37717653415819232215590989865455204849443869931268328771929128739472152723456",
                )
                .unwrap(),
                points,
            };

            assert!(Proof::verify_proof(proof));
        }

        // export -> import must reproduce an identical proof, which still verifies.
        #[cfg(feature = "serde")]
        #[test]
        fn test_proof_export_import() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();
            let proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::Group(group),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();
            let proof_json = proof.export().unwrap();
            let proof_imported = SemaphoreProof::import(&proof_json).unwrap();
            assert_eq!(proof, proof_imported);
            let valid = Proof::verify_proof(proof_imported);
            assert!(valid);
        }
    }
}