// world_id_primitives/proof.rs

use std::fmt;

use ark_bn254::Bn254;
use ark_groth16::Proof;
use ruint::aliases::U256;
use serde::{
    Deserialize, Deserializer, Serialize, Serializer,
    de::{self, Error as _, SeqAccess, Visitor},
};

use crate::FieldElement;

8/// Encoded World ID Proof.
9///
10/// Internally, the first 4 elements are a compressed Groth16 proof
11/// [a (G1), b (G2), b (G2), c (G1)]. Proofs also require the root hash of the Merkle tree
12/// in the `WorldIDRegistry` as a public input. To simplify transmission, that root is encoded as the last element
13/// with the proof.
14///
15/// The `WorldIDVerifier.sol` contract handles the decoding and verification.
16#[derive(Debug, Default, Clone, PartialEq, Eq)]
17pub struct ZeroKnowledgeProof {
18    /// Array of 5 U256 values: first 4 are compressed Groth16 proof, last is Merkle root.
19    inner: [U256; 5],
20}
21
22impl ZeroKnowledgeProof {
23    /// Initialize a new proof from a Groth16 proof and Merkle root.
24    #[must_use]
25    pub fn from_groth16_proof(groth16_proof: &Proof<Bn254>, merkle_root: FieldElement) -> Self {
26        let compressed_proof = taceo_groth16_sol::prepare_compressed_proof(groth16_proof);
27        Self {
28            inner: [
29                compressed_proof[0],
30                compressed_proof[1],
31                compressed_proof[2],
32                compressed_proof[3],
33                merkle_root.into(),
34            ],
35        }
36    }
37
38    /// Outputs the proof as a Solidity-friendly representation.
39    #[must_use]
40    pub const fn as_ethereum_representation(&self) -> [U256; 5] {
41        self.inner
42    }
43
44    /// Initializes a proof from an encoded Solidity-friendly representation.
45    #[must_use]
46    pub const fn from_ethereum_representation(value: [U256; 5]) -> Self {
47        Self { inner: value }
48    }
49
50    /// Converts the proof to compressed bytes (160 bytes total: 5 × 32 bytes).
51    #[must_use]
52    pub fn to_compressed_bytes(&self) -> Vec<u8> {
53        self.inner
54            .iter()
55            .flat_map(U256::to_be_bytes::<32>)
56            .collect()
57    }
58
59    /// Constructs a proof from compressed bytes (must be exactly 160 bytes).
60    ///
61    /// # Errors
62    /// Returns an error if the input is not exactly 160 bytes or if bytes cannot be parsed.
63    pub fn from_compressed_bytes(bytes: &[u8]) -> Result<Self, String> {
64        if bytes.len() != 160 {
65            return Err(format!(
66                "Invalid length: expected 160 bytes, got {}",
67                bytes.len()
68            ));
69        }
70
71        let mut inner = [U256::ZERO; 5];
72        for (i, chunk) in bytes.chunks_exact(32).enumerate() {
73            let mut arr = [0u8; 32];
74            arr.copy_from_slice(chunk);
75            inner[i] = U256::from_be_bytes(arr);
76        }
77
78        Ok(Self { inner })
79    }
80}
81
82impl Serialize for ZeroKnowledgeProof {
83    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
84    where
85        S: Serializer,
86    {
87        let bytes = self.to_compressed_bytes();
88        if serializer.is_human_readable() {
89            serializer.serialize_str(&hex::encode(bytes))
90        } else {
91            serializer.serialize_bytes(&bytes)
92        }
93    }
94}
95
96impl<'de> Deserialize<'de> for ZeroKnowledgeProof {
97    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
98    where
99        D: Deserializer<'de>,
100    {
101        let bytes = if deserializer.is_human_readable() {
102            let hex_str = String::deserialize(deserializer)?;
103            hex::decode(hex_str).map_err(D::Error::custom)?
104        } else {
105            Vec::deserialize(deserializer)?
106        };
107
108        Self::from_compressed_bytes(&bytes).map_err(D::Error::custom)
109    }
110}
111
112impl From<ZeroKnowledgeProof> for [U256; 5] {
113    fn from(value: ZeroKnowledgeProof) -> Self {
114        value.inner
115    }
116}
117
#[cfg(test)]
mod tests {
    use ruint::uint;

    use super::*;

    /// Five distinct, non-zero words so that ordering or endianness mistakes cannot cancel out.
    fn sample_words() -> [U256; 5] {
        [
            uint!(0x0000000000000000000000000000000000000000000000000000000000000001_U256),
            uint!(0x0000000000000000000000000000000000000000000000000000000000000002_U256),
            uint!(0x0000000000000000000000000000000000000000000000000000000000000003_U256),
            uint!(0x0000000000000000000000000000000000000000000000000000000000000004_U256),
            uint!(0x11d223ce7b91ac212f42cf50f0a3439ae3fcdba4ea32acb7f194d1051ed324c2_U256),
        ]
    }

    #[test]
    fn test_encoding_round_trip() {
        let proof = ZeroKnowledgeProof::default();
        let compressed_bytes = proof.to_compressed_bytes();

        assert_eq!(compressed_bytes.len(), 160);

        // 160 zero bytes hex-encode to 320 '0' characters, wrapped in JSON string quotes.
        let encoded = serde_json::to_string(&proof).unwrap();
        let expected = format!("\"{}\"", "0".repeat(2 * 160));
        assert_eq!(encoded, expected);

        let proof_from = ZeroKnowledgeProof::from_compressed_bytes(&compressed_bytes).unwrap();

        assert_eq!(proof.inner, proof_from.inner);
    }

    #[test]
    fn test_json_deserialization() {
        // The all-zero default and a non-trivial value must both survive a JSON round trip.
        let proofs = [
            ZeroKnowledgeProof::default(),
            ZeroKnowledgeProof::from_ethereum_representation(sample_words()),
        ];

        for proof in proofs {
            let json_str = serde_json::to_string(&proof).unwrap();
            let deserialized_proof: ZeroKnowledgeProof = serde_json::from_str(&json_str).unwrap();

            assert_eq!(proof.inner, deserialized_proof.inner);
        }
    }

    #[test]
    fn test_json_rejects_malformed_input() {
        // Not valid hex.
        assert!(serde_json::from_str::<ZeroKnowledgeProof>("\"zz\"").is_err());
        // Valid hex, but only one byte instead of 160.
        assert!(serde_json::from_str::<ZeroKnowledgeProof>("\"00\"").is_err());
    }

    #[test]
    fn test_from_ethereum_representation() {
        let values = sample_words();

        let proof = ZeroKnowledgeProof::from_ethereum_representation(values);
        assert_eq!(proof.as_ethereum_representation(), values);

        // Test serialization roundtrip
        let bytes = proof.to_compressed_bytes();
        assert_eq!(bytes.len(), 160);

        // Each word occupies 32 big-endian bytes, in array order.
        assert!(bytes[..31].iter().all(|&b| b == 0));
        assert_eq!(bytes[31], 1);
        assert_eq!(bytes[63], 2);
        assert_eq!(&bytes[128..], values[4].to_be_bytes::<32>().as_slice());

        let proof_from_bytes = ZeroKnowledgeProof::from_compressed_bytes(&bytes).unwrap();
        assert_eq!(proof.inner, proof_from_bytes.inner);
    }

    #[test]
    fn test_invalid_bytes_length() {
        let too_short = vec![0u8; 159];
        let result = ZeroKnowledgeProof::from_compressed_bytes(&too_short);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid length"));

        let too_long = vec![0u8; 161];
        let result = ZeroKnowledgeProof::from_compressed_bytes(&too_long);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid length"));
    }
}