1use ark_bn254::{Bn254, Fr, G1Affine, G2Affine, Fq, Fq2};
4use ark_groth16::Proof;
5use ark_ff::PrimeField;
6use serde::{Deserialize, Serialize};
7
8use crate::error::SdkError;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Groth16ProofJson {
13 pub pi_a: Vec<String>,
14 pub pi_b: Vec<Vec<String>>,
15 pub pi_c: Vec<String>,
16 pub protocol: String,
17 pub curve: String,
18}
19
20#[derive(Clone)]
22pub struct Groth16Proof {
23 pub(crate) inner: Proof<Bn254>,
24}
25
26impl Groth16Proof {
27 pub fn from_json(json: &str) -> Result<Self, SdkError> {
29 let proof_json: Groth16ProofJson = serde_json::from_str(json)?;
30 Self::from_snarkjs(&proof_json)
31 }
32
33 pub fn from_snarkjs(proof: &Groth16ProofJson) -> Result<Self, SdkError> {
35 let a = parse_g1(&proof.pi_a)?;
36 let b = parse_g2(&proof.pi_b)?;
37 let c = parse_g1(&proof.pi_c)?;
38
39 Ok(Self {
40 inner: Proof { a, b, c }
41 })
42 }
43}
44
45#[derive(Debug, Clone)]
47pub struct PublicInputs {
48 pub(crate) values: Vec<Fr>,
49 pub npub_hex: String,
51 pub merkle_root: String,
53 pub session_binding: Option<String>,
55}
56
57impl PublicInputs {
58 pub fn from_json(json: &str) -> Result<Self, SdkError> {
60 let strings: Vec<String> = serde_json::from_str(json)?;
61 Self::from_strings(&strings)
62 }
63
64 pub fn from_strings(strings: &[String]) -> Result<Self, SdkError> {
66 if strings.len() < 3 {
73 return Err(SdkError::InvalidProof(
74 "Public inputs must have at least 3 elements (npub_x, npub_y_parity, merkle_root)".into()
75 ));
76 }
77
78 let values: Vec<Fr> = strings
79 .iter()
80 .map(|s| parse_fr(s))
81 .collect::<Result<Vec<_>, _>>()?;
82
83 let npub_x = &strings[0];
86 let npub_y_parity = &strings[1];
87
88 let x_bytes = parse_bigint_to_bytes(npub_x, 32)?;
90 let parity: u8 = npub_y_parity.parse()
91 .map_err(|_| SdkError::InvalidProof("Invalid y parity".into()))?;
92 let prefix = if parity == 0 { 0x02 } else { 0x03 };
93
94 let mut compressed = vec![prefix];
95 compressed.extend_from_slice(&x_bytes);
96
97 let npub_hex = hex::encode(&x_bytes);
99
100 let merkle_root = format_as_hex(&strings[2])?;
101
102 let session_binding = if strings.len() > 3 {
103 Some(format_as_hex(&strings[3])?)
104 } else {
105 None
106 };
107
108 Ok(Self {
109 values,
110 npub_hex,
111 merkle_root,
112 session_binding,
113 })
114 }
115
116 pub fn npub_bech32(&self) -> Result<String, SdkError> {
118 crate::hex_to_npub(&self.npub_hex)
119 }
120}
121
/// Outcome of verifying a proof, together with the decoded public identity fields.
#[derive(Clone, Debug)]
pub struct VerificationResult {
    /// `true` if the proof verified.
    pub valid: bool,
    /// Hex-encoded public key (presumably copied from `PublicInputs::npub_hex` — verify at the
    /// construction site).
    pub npub_hex: String,
    /// Encoded public key (presumably from `PublicInputs::npub_bech32`).
    pub npub: String,
    /// Hex-encoded Merkle root public signal.
    pub merkle_root: String,
    /// Hex-encoded session-binding public signal, if the proof had one.
    pub session_binding: Option<String>,
}
136
137fn parse_g1(coords: &[String]) -> Result<G1Affine, SdkError> {
140 if coords.len() < 2 {
141 return Err(SdkError::InvalidProof("G1 needs at least 2 coordinates".into()));
142 }
143
144 let x = parse_fq(&coords[0])?;
145 let y = parse_fq(&coords[1])?;
146
147 Ok(G1Affine::new(x, y))
148}
149
150fn parse_g2(coords: &[Vec<String>]) -> Result<G2Affine, SdkError> {
151 if coords.len() < 2 {
152 return Err(SdkError::InvalidProof("G2 needs at least 2 coordinate pairs".into()));
153 }
154
155 let x = Fq2::new(parse_fq(&coords[0][0])?, parse_fq(&coords[0][1])?);
156 let y = Fq2::new(parse_fq(&coords[1][0])?, parse_fq(&coords[1][1])?);
157
158 Ok(G2Affine::new(x, y))
159}
160
161fn parse_fq(s: &str) -> Result<Fq, SdkError> {
162 use num_bigint::BigUint;
163 use std::str::FromStr;
164
165 let n = BigUint::from_str(s)
166 .map_err(|e| SdkError::InvalidProof(format!("Invalid field element: {}", e)))?;
167
168 let bytes = n.to_bytes_le();
169 let mut padded = [0u8; 32];
170 let len = bytes.len().min(32);
171 padded[..len].copy_from_slice(&bytes[..len]);
172
173 Ok(Fq::from_le_bytes_mod_order(&padded))
174}
175
176fn parse_fr(s: &str) -> Result<Fr, SdkError> {
177 use num_bigint::BigUint;
178 use std::str::FromStr;
179
180 let n = BigUint::from_str(s)
181 .map_err(|e| SdkError::InvalidProof(format!("Invalid field element: {}", e)))?;
182
183 let bytes = n.to_bytes_le();
184 let mut padded = [0u8; 32];
185 let len = bytes.len().min(32);
186 padded[..len].copy_from_slice(&bytes[..len]);
187
188 Ok(Fr::from_le_bytes_mod_order(&padded))
189}
190
191fn parse_bigint_to_bytes(s: &str, len: usize) -> Result<Vec<u8>, SdkError> {
192 use num_bigint::BigUint;
193 use std::str::FromStr;
194
195 let n = BigUint::from_str(s)
196 .map_err(|e| SdkError::InvalidProof(format!("Invalid number: {}", e)))?;
197
198 let mut bytes = n.to_bytes_be();
199
200 if bytes.len() < len {
202 let mut padded = vec![0u8; len - bytes.len()];
203 padded.extend(bytes);
204 bytes = padded;
205 } else if bytes.len() > len {
206 bytes = bytes[bytes.len() - len..].to_vec();
207 }
208
209 Ok(bytes)
210}
211
212fn format_as_hex(decimal_str: &str) -> Result<String, SdkError> {
213 let bytes = parse_bigint_to_bytes(decimal_str, 32)?;
214 Ok(hex::encode(bytes))
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `Vec<String>` from string literals.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_parse_fr() {
        let parsed = parse_fr("123456789").expect("small decimal must parse");
        assert_eq!(parsed, Fr::from(123456789u64));
    }

    #[test]
    fn test_parse_fr_rejects_non_decimal() {
        assert!(parse_fr("").is_err());
        assert!(parse_fr("0x1f").is_err());
        assert!(parse_fr("-1").is_err());
    }

    #[test]
    fn test_format_as_hex_pads_to_32_bytes() {
        let hex = format_as_hex("255").expect("255 must format");
        assert_eq!(hex, format!("{:0>64}", "ff"));
        assert_eq!(format_as_hex("0").expect("0 must format"), "0".repeat(64));
    }

    #[test]
    fn test_public_inputs_from_strings() {
        let inputs = PublicInputs::from_strings(&owned(&["1", "0", "2"]))
            .expect("well-formed inputs must parse");
        assert_eq!(inputs.values.len(), 3);
        assert_eq!(inputs.npub_hex, format!("{:0>64}", "1"));
        assert_eq!(inputs.merkle_root, format!("{:0>64}", "2"));
        assert!(inputs.session_binding.is_none());

        let with_binding = PublicInputs::from_strings(&owned(&["1", "1", "2", "3"]))
            .expect("inputs with a session binding must parse");
        assert_eq!(with_binding.session_binding, Some(format!("{:0>64}", "3")));
    }

    #[test]
    fn test_public_inputs_rejects_bad_input() {
        assert!(PublicInputs::from_strings(&owned(&["1", "0"])).is_err());
        assert!(PublicInputs::from_strings(&owned(&["1", "x", "2"])).is_err());
        assert!(PublicInputs::from_strings(&owned(&["1", "0", "abc"])).is_err());
    }
}