semaphore_rs_proof/
compression.rs

1//! Groth16 proof compression
2//!
3//! Ported from https://github.com/worldcoin/world-id-state-bridge/blob/main/src/SemaphoreVerifier.sol
4//!
5//! Based upon work in https://xn--2-umb.com/23/bn254-compression/
6
7use ruint::aliases::U256;
8use ruint::uint;
9use serde::{Deserialize, Serialize};
10
11use super::{Proof, G1, G2};
12use lazy_static::lazy_static;
13
14/// Base field Fp order P
15pub const P: U256 =
16    uint! { 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47_U256 };
17
18// A helper for a frequently used constants
19pub const ONE: U256 = uint! { 1_U256 };
20pub const TWO: U256 = uint! { 2_U256 };
21pub const THREE: U256 = uint! { 3_U256 };
22pub const FOUR: U256 = uint! { 4_U256 };
23
24lazy_static! {
25    /// Exponent for the square root in Fp
26    pub static ref EXP_SQRT_FP: U256 = (P + ONE) / FOUR;
27
28    /// Exponent for the inverse in Fp
29    pub static ref EXP_INVERSE_FP: U256 = P - TWO;
30}
31
32#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
33pub struct CompressedProof(pub U256, pub (U256, U256), pub U256);
34
35impl CompressedProof {
36    pub const fn from_flat(flat: [U256; 4]) -> Self {
37        let [a, b0, b1, c] = flat;
38
39        Self(a, (b0, b1), c)
40    }
41
42    pub const fn flatten(self) -> [U256; 4] {
43        let Self(a, (b0, b1), c) = self;
44        [a, b0, b1, c]
45    }
46}
47
48pub fn compress_proof(proof: Proof) -> Option<CompressedProof> {
49    let Proof(g1a, g2, g1b) = proof;
50
51    // NOTE: Order of real and imaginary parts in the proof data is flipped
52    let ([x0, x1], [y0, y1]) = g2;
53    let g2 = ([x1, x0], [y1, y0]);
54
55    let a = compress_g1(g1a)?;
56    // NOTE: G2 compressed repr is flipped
57    let (c0, c1) = compress_g2(g2)?;
58    let c = (c1, c0);
59
60    let b = compress_g1(g1b)?;
61
62    Some(CompressedProof(a, c, b))
63}
64
65pub fn decompress_proof(compressed: CompressedProof) -> Option<Proof> {
66    let CompressedProof(a, c, b) = compressed;
67
68    let g1a = decompress_g1(a)?;
69
70    // NOTE: G2 compressed repr is flipped
71    let (c1, c0) = c;
72    let c = (c0, c1);
73    let g2 = decompress_g2(c)?;
74
75    let g1b = decompress_g1(b)?;
76
77    // Unswap
78    let ([x1, x0], [y1, y0]) = g2;
79    let g2 = ([x0, x1], [y0, y1]);
80
81    Some(Proof(g1a, g2, g1b))
82}
83
84pub fn compress_g1((x, y): G1) -> Option<U256> {
85    if x >= P || y >= P {
86        return None; // Point not in field
87    }
88    if x == U256::ZERO && y == U256::ZERO {
89        return Some(U256::ZERO); // Point at infinity
90    }
91    let y_pos = sqrt_fp(x.pow_mod(THREE, P).add_mod(THREE, P))?;
92    if y == y_pos {
93        Some(x << 1)
94    } else if y == neg_fp(y_pos) {
95        Some(x << 1 | ONE)
96    } else {
97        None
98    }
99}
100
101pub fn decompress_g1(c: U256) -> Option<G1> {
102    if c == U256::ZERO {
103        return Some((U256::ZERO, U256::ZERO)); // Point at infinity
104    }
105
106    let negate = c & ONE == ONE;
107    let x: U256 = c >> 1;
108    if x >= P {
109        return None;
110    }
111
112    let y2 = x.pow_mod(THREE, P).add_mod(THREE, P);
113    let mut y = sqrt_fp(y2)?;
114
115    if negate {
116        y = neg_fp(y);
117    }
118    Some((x, y))
119}
120
121/// Compresses the
122pub fn compress_g2(([x0, x1], [y0, y1]): G2) -> Option<(U256, U256)> {
123    if x0 >= P || x1 >= P || y0 >= P || y1 >= P {
124        return None; // Point not in field
125    }
126    if (x0 | x1 | y0 | y1) == U256::ZERO {
127        return Some((U256::ZERO, U256::ZERO)); // Point at infinity
128    }
129
130    // Compute y^2
131    let n3ab = x0.mul_mod(x1, P).mul_mod(P - THREE, P);
132    let a_3 = x0.pow_mod(THREE, P);
133    let b_3 = x1.pow_mod(THREE, P);
134
135    let y0_pos = U256::from(27)
136        .mul_mod(U256::from(82).inv_mod(P).unwrap(), P)
137        .add_mod(a_3.add_mod(n3ab.mul_mod(x1, P), P), P);
138
139    let y1_pos = neg_fp(
140        THREE
141            .mul_mod(U256::from(82).inv_mod(P).unwrap(), P)
142            .add_mod(b_3.add_mod(n3ab.mul_mod(x0, P), P), P),
143    );
144
145    // Determine hint bit
146    let d = sqrt_fp(
147        y0_pos
148            .mul_mod(y0_pos, P)
149            .add_mod(y1_pos.mul_mod(y1_pos, P), P),
150    )?;
151    let hint = !is_square_fp(y0_pos.add_mod(d, P).mul_mod(TWO.inv_mod(P).unwrap(), P));
152
153    // Recover y
154    let (new_y0_pos, new_y1_pos) = sqrt_fp2(y0_pos, y1_pos, hint)?;
155
156    let hint = if hint { TWO } else { U256::ZERO };
157    if y0 == new_y0_pos && y1 == new_y1_pos {
158        Some(((x0 << 2) | hint, x1))
159    } else if y0 == neg_fp(new_y0_pos) && y1 == neg_fp(new_y1_pos) {
160        Some(((x0 << 2) | hint | ONE, x1))
161    } else {
162        None
163    }
164}
165
166pub fn decompress_g2((c0, c1): (U256, U256)) -> Option<G2> {
167    if c0 == U256::ZERO && c1 == U256::ZERO {
168        return Some(([U256::ZERO, U256::ZERO], [U256::ZERO, U256::ZERO])); // Point at infinity
169    }
170
171    let negate = c0 & ONE == ONE;
172    let hint = c0 & TWO == TWO;
173
174    let x0: U256 = c0 >> 2;
175    let x1 = c1;
176
177    if x0 >= P || x1 >= P {
178        return None;
179    }
180
181    let n3ab = x0.mul_mod(x1, P).mul_mod(P - THREE, P);
182    let a_3 = x0.pow_mod(THREE, P);
183    let b_3 = x1.pow_mod(THREE, P);
184
185    let y0 = U256::from(27)
186        .mul_mod(U256::from(82).inv_mod(P)?, P)
187        .add_mod(a_3.add_mod(n3ab.mul_mod(x1, P), P), P);
188    let y1 = neg_fp(
189        THREE
190            .mul_mod(U256::from(82).inv_mod(P)?, P)
191            .add_mod(b_3.add_mod(n3ab.mul_mod(x0, P), P), P),
192    );
193
194    let (mut y0, mut y1) = sqrt_fp2(y0, y1, hint)?;
195    if negate {
196        y0 = neg_fp(y0);
197        y1 = neg_fp(y1);
198    }
199
200    Some(([x0, x1], [y0, y1]))
201}
202
203fn sqrt_fp(a: U256) -> Option<U256> {
204    let x = a.pow_mod(*EXP_SQRT_FP, P);
205    if x.mul_mod(x, P) == a {
206        Some(x)
207    } else {
208        None
209    }
210}
211
212fn sqrt_fp2(a0: U256, a1: U256, hint: bool) -> Option<(U256, U256)> {
213    let mut d = sqrt_fp(a0.pow_mod(TWO, P).add_mod(a1.pow_mod(TWO, P), P))?;
214
215    if hint {
216        d = neg_fp(d);
217    }
218
219    let frac_1_2 = ONE.mul_mod(TWO.inv_mod(P)?, P);
220    let x0 = sqrt_fp(a0.add_mod(d, P).mul_mod(frac_1_2, P))?;
221    let x1 = a1.mul_mod(invert_fp(x0.mul_mod(TWO, P))?, P);
222
223    if a0 != x0.pow_mod(TWO, P).add_mod(neg_fp(x1.pow_mod(TWO, P)), P)
224        || a1 != TWO.mul_mod(x0.mul_mod(x1, P), P)
225    {
226        return None;
227    }
228
229    Some((x0, x1))
230}
231
232fn is_square_fp(a: U256) -> bool {
233    let x = a.pow_mod(*EXP_SQRT_FP, P);
234    x.mul_mod(x, P) == a
235}
236
237/// Inversion in Fp
238///
239/// Returns a number x such that a * x = 1 in Fp
240/// Returns None if the inverse does not exist
241fn invert_fp(a: U256) -> Option<U256> {
242    let x = a.pow_mod(*EXP_INVERSE_FP, P);
243
244    if a.mul_mod(x, P) != ONE {
245        return None;
246    }
247
248    Some(x)
249}
250
251fn neg_fp(a: U256) -> U256 {
252    P.wrapping_sub(a % P) % P
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// `invert_fp` matches a known inverse modulo `P`.
    #[test]
    fn inversion() {
        let v = uint! { 4598362786468342265918458423096940256393720972438048893356218087518821823203_U256 };
        let inverted = invert_fp(v).unwrap();
        let exp_inverted = uint! { 4182222526301715069940346543278816173622053692765626450942898397518664864041_U256 };

        assert_eq!(exp_inverted, inverted);
    }

    /// `sqrt_fp` returns the expected root for a known quadratic residue.
    #[test]
    fn square_root_fp() {
        let v = uint! { 14471043194638943579446425262583282548539507047061604313953794288955195726209_U256 };
        let exp_sqrt = uint! { 13741342543520938546471415319044405232187715299443307089577869276344592329757_U256 };

        let sqrt = sqrt_fp(v).unwrap();
        assert_eq!(exp_sqrt, sqrt);
    }

    /// `sqrt_fp2` with `hint == false` returns the expected (real, imaginary) root.
    #[test]
    fn square_root_fp_2() {
        let (a, b) = sqrt_fp2(uint!{17473058728477435457299093362519578563618705081729024467362715416915525458528_U256}, uint!{17683468329848516541101685027677188007795188556813329975791177956431310972350_U256}, false).unwrap();

        let exp_a = uint! {10193706077588260514783319931179623845729747565730309463634080055351233087269_U256};
        let exp_b = uint! {2911435556167431587172450261242327574185987927358833959334220021362478804490_U256};

        assert_eq!(exp_a, a);
        assert_eq!(exp_b, b);
    }

    // The literal values below are taken from the proof in the following tx:
    // https://etherscan.io/tx/0x53309842294be8c2b9fd694c4e86a5ab031c0d58750978fb3d6f60de16eaa897
    // The raw proof data is:
    // 20565048055856194013099208963146657799256893353279242520150547463020687826541
    // 16286013012747852737396822706018267259565592188907848191354824303311847109059
    // 4348608846293503080802796983494208797681981448804902149317789801083784587558
    // 6172488348732750834133346196464201580503416389945891763609808290085997580078
    // 3229429189805934086496276224876305383924675874777054942516982958483565949767
    // 944252930093106871283598150477854448876343937304805759422971930315581301659
    // 18318130744212307125672524358864792312717149086464333958791498157127232409959
    // 8256141885907329266852096557308020923997215847794048916749940281741155521604
    //
    // Note that in the standalone G2 compression test the real and imaginary
    // parts are in the flipped order relative to the raw proof data above.
    //
    // The expected compressed data was generated with the SemaphoreVerifier
    // implementation in world-id-state-bridge using chisel.
    //
    // `compress_g1` and `compress_g2` are `internal` in that contract, so
    // regenerating the fixtures needs a small hack:
    // 1. Change `internal` to `public` in `SemaphoreVerifier.sol`.
    // 2. Start `chisel`.
    // 3. Execute the following in the chisel REPL:
    //    ```
    //    > import {SemaphoreVerifier} from "src/SemaphoreVerifier.sol";
    //    > SemaphoreVerifier ve = new SemaphoreVerifier();
    //    ```
    // 4. Generate the expected fixtures with calls such as:
    //    ```
    //    > ve.compress_g1(0x19ded61ab5c58fdb12367526c6bc04b9186d0980c4b6fd48a44093e80f9b4206, 0x2e619a034be10e9aab294f1c77a480378e84782c8519449aef0c8f6952382bda)
    //    ```
    // chisel does not display multiple return values well, so you may need to
    // bind them to variables first, e.g.:
    // ```
    // > (uint256 a, uint256 b) = ve.compress_g2(...);
    // > a;
    // Type: uint256
    // ├ Hex: 0x1dd212f101a320736a9662cac57929556777fad3e7882b022d4ba3261cf14db6
    // ├ Hex (full word): 0x1dd212f101a320736a9662cac57929556777fad3e7882b022d4ba3261cf14db6
    // └ Decimal: 13488241221471993734368286196608381596836013455766665997449768358320614231478
    // ```

    /// Compressing the mainnet proof matches the Solidity output, and
    /// decompressing it round-trips back to the original proof.
    #[test]
    fn proof_compression() {
        let flat_proof: [U256; 8] = uint! { [
            20565048055856194013099208963146657799256893353279242520150547463020687826541_U256,
            16286013012747852737396822706018267259565592188907848191354824303311847109059_U256,
            4348608846293503080802796983494208797681981448804902149317789801083784587558_U256,
            6172488348732750834133346196464201580503416389945891763609808290085997580078_U256,
            3229429189805934086496276224876305383924675874777054942516982958483565949767_U256,
            944252930093106871283598150477854448876343937304805759422971930315581301659_U256,
            18318130744212307125672524358864792312717149086464333958791498157127232409959_U256,
            8256141885907329266852096557308020923997215847794048916749940281741155521604_U256,
        ]};
        let proof = Proof::from_flat(flat_proof);

        let compressed = compress_proof(proof).unwrap();
        let exp_flat_compressed: [U256; 4] = uint! {[
            41130096111712388026198417926293315598513786706558485040301094926041375653083_U256,
            4348608846293503080802796983494208797681981448804902149317789801083784587558_U256,
            24689953394931003336533384785856806322013665559783567054439233160343990320315_U256,
            36636261488424614251345048717729584625434298172928667917582996314254464819918_U256,
        ]};

        assert_eq!(exp_flat_compressed, compressed.flatten());

        let decompressed = decompress_proof(compressed).unwrap();

        assert_eq!(proof, decompressed);
    }

    /// A single G1 point compresses to the Solidity-generated value and round-trips.
    #[test]
    fn g1_compression() {
        let point: G1 = uint! {
            (
                0x19ded61ab5c58fdb12367526c6bc04b9186d0980c4b6fd48a44093e80f9b4206_U256,
                0x2e619a034be10e9aab294f1c77a480378e84782c8519449aef0c8f6952382bda_U256
            )
        };
        let exp_compressed =
            uint! { 0x33bdac356b8b1fb6246cea4d8d78097230da1301896dfa91488127d01f36840c_U256 };

        let compressed = compress_g1(point).unwrap();
        assert_eq!(exp_compressed, compressed);

        let decompressed = decompress_g1(compressed).unwrap();
        assert_eq!(point, decompressed);
    }

    /// A single G2 point (given in `compress_g2`'s coordinate order) compresses
    /// to the Solidity-generated pair and round-trips.
    #[test]
    fn g2_compression() {
        let point: G2 = uint! {
            (
                [
                    0x077484BC4068C81CDAA598B2B15E4A5559DDFEB4F9E20AC08B52E8C9873C536D_U256,
                    0x25E744163329AABFB40086C09E0B54D09DFBD302CE975E71150133E46E75F0AA_U256,
                ],
                [
                    0x20AF3E3AFED950A86937F4319100B19A1141FF59DA42B9670CFA57E5D83BE618_U256,
                    0x089C901AA5603652F8CC748F04907233C63A75302244D67FF974B05AF09948D2_U256,
                ]
            )
        };

        let compressed = compress_g2(point).unwrap();
        let exp_compressed = uint! { (0x1dd212f101a320736a9662cac57929556777fad3e7882b022d4ba3261cf14db6_U256, 0x25e744163329aabfb40086c09e0b54d09dfbd302ce975e71150133e46e75f0aa_U256) };

        assert_eq!(exp_compressed, compressed);
        let decompressed = decompress_g2(compressed).unwrap();

        assert_eq!(point, decompressed);
    }

    /// `CompressedProof` round-trips through JSON as a nested hex-string array.
    #[test]
    fn deser() {
        let s = r#"["0x1",["0x2","0x3"],"0x4"]"#;

        let deserialized: CompressedProof = serde_json::from_str(s).unwrap();
        let reserialized = serde_json::to_string(&deserialized).unwrap();

        assert_eq!(s, reserialized);
    }
}