semaphore/
proof.rs

1#[cfg(feature = "serde")]
2use crate::error::SemaphoreError;
3use crate::{
4    MAX_TREE_DEPTH, MIN_TREE_DEPTH,
5    group::{EMPTY_ELEMENT, Element, Group, MerkleProof},
6    identity::Identity,
7    utils::{download_zkey, hash, to_big_uint, to_element},
8    witness::dispatch_witness,
9};
10use anyhow::{Ok, Result, bail};
11use circom_prover::{
12    CircomProver,
13    prover::{
14        CircomProof, ProofLib, PublicInputs,
15        circom::{self, CURVE_BN254, G1, G2, PROTOCOL_GROTH16},
16    },
17    witness::WitnessFn,
18};
19use num_bigint::BigUint;
20use num_traits::{Zero, identities::One};
21use std::{collections::HashMap, str::FromStr};
22
23pub type PackedGroth16Proof = [BigUint; 8];
24
25pub enum GroupOrMerkleProof {
26    Group(Group),
27    MerkleProof(MerkleProof),
28}
29
30impl GroupOrMerkleProof {
31    fn merkle_proof(&self, leaf: &Element) -> MerkleProof {
32        match self {
33            GroupOrMerkleProof::Group(group) => {
34                let idx = group.index_of(*leaf).expect("The identity does not exist");
35                group.generate_proof(idx).unwrap()
36            }
37            GroupOrMerkleProof::MerkleProof(proof) => proof.clone(),
38        }
39    }
40}
41
42#[derive(Debug, Clone, PartialEq)]
43pub struct SemaphoreProof {
44    pub merkle_tree_depth: u16,
45    pub merkle_tree_root: BigUint,
46    pub message: BigUint,
47    pub nullifier: BigUint,
48    pub scope: BigUint,
49    pub points: PackedGroth16Proof,
50}
51
52#[cfg(feature = "serde")]
53impl SemaphoreProof {
54    pub fn export(&self) -> Result<String, SemaphoreError> {
55        let mut json = serde_json::Map::new();
56        json.insert(
57            "merkle_tree_depth".to_string(),
58            self.merkle_tree_depth.into(),
59        );
60        json.insert(
61            "merkle_tree_root".to_string(),
62            self.merkle_tree_root.to_string().into(),
63        );
64        json.insert("message".to_string(), self.message.to_string().into());
65        json.insert("nullifier".to_string(), self.nullifier.to_string().into());
66        json.insert("scope".to_string(), self.scope.to_string().into());
67        json.insert(
68            "points".to_string(),
69            self.points
70                .to_vec()
71                .into_iter()
72                .map(|p| p.to_string())
73                .collect::<Vec<String>>()
74                .into(),
75        );
76        serde_json::to_string(&json).map_err(|e| SemaphoreError::SerializationError(e.to_string()))
77    }
78
79    pub fn import(json: &str) -> Result<Self, SemaphoreError> {
80        let json: serde_json::Map<String, serde_json::Value> = serde_json::from_str(json)
81            .map_err(|e| SemaphoreError::SerializationError(e.to_string()))?;
82        Ok(SemaphoreProof {
83            merkle_tree_depth: json.get("merkle_tree_depth").unwrap().as_u64().unwrap() as u16,
84            merkle_tree_root: BigUint::from_str(
85                json.get("merkle_tree_root").unwrap().as_str().unwrap(),
86            )
87            .unwrap(),
88            message: BigUint::from_str(json.get("message").unwrap().as_str().unwrap()).unwrap(),
89            nullifier: BigUint::from_str(json.get("nullifier").unwrap().as_str().unwrap()).unwrap(),
90            scope: BigUint::from_str(json.get("scope").unwrap().as_str().unwrap()).unwrap(),
91            points: json
92                .get("points")
93                .unwrap()
94                .as_array()
95                .unwrap()
96                .iter()
97                .map(|p| BigUint::from_str(p.as_str().unwrap()).unwrap())
98                .collect::<Vec<BigUint>>()
99                .try_into()
100                .unwrap(),
101        })
102        .map_err(|e| SemaphoreError::SerializationError(e.to_string()))
103    }
104}
105
106pub struct Proof {}
107
108impl Proof {
109    pub fn generate_proof(
110        identity: Identity,
111        group: GroupOrMerkleProof,
112        message: String,
113        scope: String,
114        merkle_tree_depth: u16,
115    ) -> Result<SemaphoreProof> {
116        // check tree depth
117        if !(MIN_TREE_DEPTH..=MAX_TREE_DEPTH).contains(&merkle_tree_depth) {
118            bail!(format!(
119                "The tree depth must be a number between {} and {}",
120                MIN_TREE_DEPTH, MAX_TREE_DEPTH
121            ));
122        }
123
124        let merkle_proof = group.merkle_proof(&to_element(*identity.commitment()));
125        let merkle_proof_length = merkle_proof.siblings.len();
126
127        let mut merkle_proof_siblings = Vec::<Element>::new();
128        for i in 0..merkle_tree_depth {
129            if let Some(sibling) = merkle_proof.siblings.get(i as usize) {
130                merkle_proof_siblings.push(*sibling);
131            } else {
132                merkle_proof_siblings.push(EMPTY_ELEMENT);
133            }
134        }
135
136        let scope_uint = to_big_uint(&scope);
137        let message_uint = to_big_uint(&message);
138        let inputs = HashMap::from([
139            (
140                "secret".to_string(),
141                vec![identity.secret_scalar().to_string()],
142            ),
143            (
144                "merkleProofLength".to_string(),
145                vec![merkle_proof_length.to_string()],
146            ),
147            (
148                "merkleProofIndex".to_string(),
149                vec![merkle_proof.index.to_string()],
150            ),
151            (
152                "merkleProofSiblings".to_string(),
153                merkle_proof_siblings
154                    .iter()
155                    .map(|s| BigUint::from_bytes_le(s.to_vec().as_ref()).to_string())
156                    .collect(),
157            ),
158            ("scope".to_string(), vec![hash(scope_uint.clone())]),
159            ("message".to_string(), vec![hash(message_uint.clone())]),
160        ]);
161
162        let zkey_path = download_zkey(merkle_tree_depth).expect("Failed to download zkey");
163        let witness_fn = dispatch_witness(merkle_tree_depth);
164
165        let circom_proof = CircomProver::prove(
166            ProofLib::Arkworks,
167            WitnessFn::CircomWitnessCalc(witness_fn),
168            serde_json::to_string(&inputs).unwrap(),
169            zkey_path,
170        )?;
171
172        Ok(SemaphoreProof {
173            merkle_tree_depth,
174            merkle_tree_root: BigUint::from_bytes_le(merkle_proof.root.as_ref()),
175            message: message_uint,
176            nullifier: circom_proof.pub_inputs.0.get(1).unwrap().clone(),
177            scope: scope_uint,
178            points: Self::pack_groth16_proof(circom_proof.proof),
179        })
180    }
181
182    pub fn verify_proof(proof: SemaphoreProof) -> bool {
183        // check tree depth
184        if proof.merkle_tree_depth < MIN_TREE_DEPTH || proof.merkle_tree_depth > MAX_TREE_DEPTH {
185            panic!("The tree depth must be a number between and");
186        }
187
188        let scope = BigUint::from_str(hash(proof.scope).as_str()).unwrap();
189        let message = BigUint::from_str(hash(proof.message).as_str()).unwrap();
190        let pub_inputs = PublicInputs(vec![
191            proof.merkle_tree_root,
192            proof.nullifier,
193            message,
194            scope,
195        ]);
196        let p = CircomProof {
197            proof: Self::unpack_groth16_proof(proof.points),
198            pub_inputs,
199        };
200
201        let zkey_path = download_zkey(proof.merkle_tree_depth).expect("Failed to download zkey");
202        CircomProver::verify(ProofLib::Arkworks, p, zkey_path).unwrap()
203    }
204
205    pub fn pack_groth16_proof(p: circom::Proof) -> PackedGroth16Proof {
206        [
207            p.a.x,
208            p.a.y,
209            p.b.x[1].clone(),
210            p.b.x[0].clone(),
211            p.b.y[1].clone(),
212            p.b.y[0].clone(),
213            p.c.x,
214            p.c.y,
215        ]
216    }
217
218    pub fn unpack_groth16_proof(packed: PackedGroth16Proof) -> circom::Proof {
219        let a = G1 {
220            x: packed[0].clone(),
221            y: packed[1].clone(),
222            z: BigUint::one(),
223        };
224        let b = G2 {
225            x: [packed[3].clone(), packed[2].clone()],
226            y: [packed[5].clone(), packed[4].clone()],
227            z: [BigUint::one(), BigUint::zero()],
228        };
229        let c = G1 {
230            x: packed[6].clone(),
231            y: packed[7].clone(),
232            z: BigUint::one(),
233        };
234
235        circom::Proof {
236            a,
237            b,
238            c,
239            protocol: PROTOCOL_GROTH16.to_string(),
240            curve: CURVE_BN254.to_string(),
241        }
242    }
243}
244
// Tests for proof generation, verification and (with the `serde` feature)
// export/import. Most tests produce real Groth16 proofs, which calls
// `download_zkey` for the requested depth (presumably network access on first
// use — confirm before running offline).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        group::{Element, Group},
        identity::Identity,
        proof::SemaphoreProof,
    };
    use num_bigint::BigUint;
    use std::str::FromStr;

    const TREE_DEPTH: usize = 10;
    const MESSAGE: &str = "Hello world";
    const SCOPE: &str = "Scope";

    // Placeholder group members that are not the prover's identity.
    const MEMBER1: Element = [1; 32];
    const MEMBER2: Element = [2; 32];

    // Tests for `Proof::generate_proof`.
    #[cfg(test)]
    mod gen_proof {
        use super::*;
        use std::panic::{self, AssertUnwindSafe};

        #[test]
        fn test_proof() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();
            let root = group.root().unwrap();

            let proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::Group(group),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();

            // The proof stores the un-hashed message/scope and the group's root.
            assert_eq!(proof.merkle_tree_root, BigUint::from_bytes_le(&root));
            assert_eq!(proof.message, to_big_uint(&MESSAGE.to_string()));
            assert_eq!(proof.scope, to_big_uint(&SCOPE.to_string()));
        }

        // Single-member group: the identity is the only leaf.
        #[test]
        fn test_proof_1_member() {
            let identity = Identity::new("secret".as_bytes());
            let group = Group::new(&[to_element(*identity.commitment())]).unwrap();
            let root = group.root().unwrap();

            let proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::Group(group),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();

            assert_eq!(proof.merkle_tree_root, BigUint::from_bytes_le(&root));
        }

        // Uses a precomputed Merkle proof; the identity is the leaf at index 2.
        #[test]
        fn test_proof_with_semaphore_proof() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();
            let root = group.root().unwrap();

            let proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::MerkleProof(group.generate_proof(2).unwrap()),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();

            assert_eq!(proof.merkle_tree_root, BigUint::from_bytes_le(&root));
        }

        // A depth of 33 is expected to be rejected with an error, not a panic.
        #[test]
        fn test_error_invalid_tree_depth() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

            let result = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::Group(group),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                33u16,
            );

            assert!(result.is_err());
            // The message is only compared when the error wraps a `String`.
            if let Err(err) = result {
                if let Some(msg) = err.downcast_ref::<String>() {
                    assert_eq!(msg, "The tree depth must be a number between 1 and 32");
                }
            }
        }

        // The identity is not in the group, so generation must panic.
        #[test]
        fn test_panic_id_not_in_group() {
            let identity = Identity::new("secret".as_bytes());
            let group = Group::new(&[MEMBER1, MEMBER2]).unwrap();

            let err = panic::catch_unwind(AssertUnwindSafe(|| {
                Proof::generate_proof(
                    identity,
                    GroupOrMerkleProof::Group(group),
                    MESSAGE.to_string(),
                    SCOPE.to_string(),
                    TREE_DEPTH as u16,
                )
                .unwrap()
            }));

            assert!(err.is_err());
            // The message is only compared when the panic payload is a `String`;
            // a `&str` payload skips the check.
            if let Err(err) = err {
                if let Some(msg) = err.downcast_ref::<String>() {
                    assert_eq!(msg, "The identity does not exist");
                }
            }
        }

        // A message longer than 32 bytes is expected to panic during conversion.
        #[test]
        fn test_panic_message_over_32bytes() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

            let err = panic::catch_unwind(AssertUnwindSafe(|| {
                Proof::generate_proof(
                    identity,
                    GroupOrMerkleProof::Group(group),
                    "This message is over 32 bytes long!!".to_string(),
                    SCOPE.to_string(),
                    TREE_DEPTH as u16,
                )
                .unwrap()
            }));

            assert!(err.is_err());
            // Message check runs only for `String` panic payloads.
            if let Err(err) = err {
                if let Some(msg) = err.downcast_ref::<String>() {
                    assert_eq!(msg, "BigUint too large: exceeds 32 bytes");
                }
            }
        }

        // A scope longer than 32 bytes is expected to panic during conversion.
        #[test]
        fn test_panic_scope_over_32bytes() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

            let err = panic::catch_unwind(AssertUnwindSafe(|| {
                Proof::generate_proof(
                    identity,
                    GroupOrMerkleProof::Group(group),
                    MESSAGE.to_string(),
                    "This scope is over 32 bytes long!!".to_string(),
                    TREE_DEPTH as u16,
                )
                .unwrap()
            }));

            assert!(err.is_err());
            // Message check runs only for `String` panic payloads.
            if let Err(err) = err {
                if let Some(msg) = err.downcast_ref::<String>() {
                    assert_eq!(msg, "BigUint too large: exceeds 32 bytes");
                }
            }
        }
    }

    // Tests for `Proof::verify_proof` and proof (de)serialization.
    #[cfg(test)]
    mod verify_proof {
        use super::*;
        use std::panic::{self, AssertUnwindSafe};

        #[test]
        fn test_verify_proof() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

            let proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::Group(group),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();

            assert!(Proof::verify_proof(proof))
        }

        // Round-trips a proof at every supported depth. Each depth needs its own
        // zkey, so this is the slowest test here.
        #[test]
        fn test_verify_proof_with_different_depth() {
            for depth in MIN_TREE_DEPTH..=MAX_TREE_DEPTH {
                let identity = Identity::new("secret".as_bytes());
                let group =
                    Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

                let proof = Proof::generate_proof(
                    identity,
                    GroupOrMerkleProof::Group(group),
                    MESSAGE.to_string(),
                    SCOPE.to_string(),
                    depth as u16,
                )
                .unwrap();

                assert!(Proof::verify_proof(proof));
            }
        }

        // Tampering the depth to 40 must make verification panic.
        #[test]
        fn test_panic_verify_invalid_tree_depth() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

            let mut proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::Group(group),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();
            proof.merkle_tree_depth = 40;

            let err = panic::catch_unwind(AssertUnwindSafe(|| Proof::verify_proof(proof)));
            assert!(err.is_err());
            // Message check runs only for `String` panic payloads; a literal
            // `panic!("...")` produces a `&str` payload and skips it.
            if let Err(err) = err {
                if let Some(msg) = err.downcast_ref::<String>() {
                    assert_eq!(msg, "The tree depth must be a number between 1 and 32");
                }
            }
        }

        // The Merkle proof is for index 0 (MEMBER1), not the prover's own leaf
        // (index 2), so the resulting proof should not verify.
        #[test]
        fn test_error_verify_invalid_proof() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();

            let proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::MerkleProof(group.generate_proof(0).unwrap()),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();

            assert_eq!(Proof::verify_proof(proof), false)
        }

        // This test case is to test a semaphore-js proof can be verified by semaphore-rs verifier.
        #[test]
        fn test_semaphore_js_proof() {
            let points = [
                // Proof generated from `Semaphore-js`
                "2448901300518098096993075752654536134313649038239216706400667219963346227679",
                "11383357624181217239434984412545229801919536849542936327488167664579097021171",
                "4740704242184999702574958393302343834384154042177684026319208048433986938524",
                "2103898499672759617084297744151588687300569178309824227315704845907524437637",
                "18126651739688030584140960766793516019865850111238360168731489534891060767936",
                "13293264290162772264887787723520088518667325866686508255341288441681546077334",
                "13860303418198054644271827809984867757526756615344099647083475463061491185143",
                "7750331146056656453454308267328134694500438800080743301030181391570997944788",
            ]
            .iter()
            .map(|&p| BigUint::from_str(p).unwrap())
            .collect::<Vec<BigUint>>()
            .try_into()
            .expect("Expected exactly 8 elements");

            let proof = SemaphoreProof {
                merkle_tree_depth: 10,
                merkle_tree_root: BigUint::from_str(
                    "4990292586352433503726012711155167179034286198473030768981544541070532815155",
                )
                .unwrap(),
                nullifier: BigUint::from_str(
                    "17540473064543782218297133630279824063352907908315494138425986188962403570231",
                )
                .unwrap(),
                message: BigUint::from_str(
                    "32745724963520510550185023804391900974863477733501474067656557556163468591104",
                )
                .unwrap(),
                scope: BigUint::from_str(
                    "37717653415819232215590989865455204849443869931268328771929128739472152723456",
                )
                .unwrap(),
                points,
            };

            assert!(Proof::verify_proof(proof));
        }

        // Export -> import must be lossless, and the imported proof must still verify.
        #[cfg(feature = "serde")]
        #[test]
        fn test_proof_export_import() {
            let identity = Identity::new("secret".as_bytes());
            let group =
                Group::new(&[MEMBER1, MEMBER2, to_element(*identity.commitment())]).unwrap();
            let proof = Proof::generate_proof(
                identity,
                GroupOrMerkleProof::Group(group),
                MESSAGE.to_string(),
                SCOPE.to_string(),
                TREE_DEPTH as u16,
            )
            .unwrap();
            let proof_json = proof.export().unwrap();
            let proof_imported = SemaphoreProof::import(&proof_json).unwrap();
            assert_eq!(proof, proof_imported);
            let valid = Proof::verify_proof(proof_imported);
            assert!(valid);
        }
    }
}