Skip to main content

signedby_sdk/
proof.rs

1//! Groth16 proof types and parsing
2
3use ark_bn254::{Bn254, Fr, G1Affine, G2Affine, Fq, Fq2};
4use ark_groth16::Proof;
5use ark_ff::PrimeField;
6use serde::{Deserialize, Serialize};
7
8use crate::error::SdkError;
9
10/// A Groth16 proof in snarkjs JSON format
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Groth16ProofJson {
13    pub pi_a: Vec<String>,
14    pub pi_b: Vec<Vec<String>>,
15    pub pi_c: Vec<String>,
16    pub protocol: String,
17    pub curve: String,
18}
19
20/// Parsed Groth16 proof ready for verification
21#[derive(Clone)]
22pub struct Groth16Proof {
23    pub(crate) inner: Proof<Bn254>,
24}
25
26impl Groth16Proof {
27    /// Parse proof from snarkjs JSON format
28    pub fn from_json(json: &str) -> Result<Self, SdkError> {
29        let proof_json: Groth16ProofJson = serde_json::from_str(json)?;
30        Self::from_snarkjs(&proof_json)
31    }
32    
33    /// Parse from snarkjs proof object
34    pub fn from_snarkjs(proof: &Groth16ProofJson) -> Result<Self, SdkError> {
35        let a = parse_g1(&proof.pi_a)?;
36        let b = parse_g2(&proof.pi_b)?;
37        let c = parse_g1(&proof.pi_c)?;
38        
39        Ok(Self {
40            inner: Proof { a, b, c }
41        })
42    }
43}
44
45/// Public inputs to the proof
46#[derive(Debug, Clone)]
47pub struct PublicInputs {
48    pub(crate) values: Vec<Fr>,
49    /// The npub (NOSTR public key) extracted from public outputs
50    pub npub_hex: String,
51    /// Merkle root hash
52    pub merkle_root: String,
53    /// Session binding (if present)
54    pub session_binding: Option<String>,
55}
56
57impl PublicInputs {
58    /// Parse public inputs from snarkjs JSON format (array of decimal strings)
59    pub fn from_json(json: &str) -> Result<Self, SdkError> {
60        let strings: Vec<String> = serde_json::from_str(json)?;
61        Self::from_strings(&strings)
62    }
63    
64    /// Parse from string array
65    pub fn from_strings(strings: &[String]) -> Result<Self, SdkError> {
66        // SignedByMe circuit public outputs:
67        // [0]: npub_x (secp256k1 x-coordinate, 256 bits)
68        // [1]: npub_y_parity (0 or 1)
69        // [2]: merkle_root
70        // [3]: session_binding (optional)
71        
72        if strings.len() < 3 {
73            return Err(SdkError::InvalidProof(
74                "Public inputs must have at least 3 elements (npub_x, npub_y_parity, merkle_root)".into()
75            ));
76        }
77        
78        let values: Vec<Fr> = strings
79            .iter()
80            .map(|s| parse_fr(s))
81            .collect::<Result<Vec<_>, _>>()?;
82        
83        // Extract npub from first two public inputs
84        // npub_x is the x-coordinate, npub_y_parity determines even/odd y
85        let npub_x = &strings[0];
86        let npub_y_parity = &strings[1];
87        
88        // Convert to compressed public key format (33 bytes: prefix + x)
89        let x_bytes = parse_bigint_to_bytes(npub_x, 32)?;
90        let parity: u8 = npub_y_parity.parse()
91            .map_err(|_| SdkError::InvalidProof("Invalid y parity".into()))?;
92        let prefix = if parity == 0 { 0x02 } else { 0x03 };
93        
94        let mut compressed = vec![prefix];
95        compressed.extend_from_slice(&x_bytes);
96        
97        // For NOSTR, we use the x-only format (32 bytes, no prefix)
98        let npub_hex = hex::encode(&x_bytes);
99        
100        let merkle_root = format_as_hex(&strings[2])?;
101        
102        let session_binding = if strings.len() > 3 {
103            Some(format_as_hex(&strings[3])?)
104        } else {
105            None
106        };
107        
108        Ok(Self {
109            values,
110            npub_hex,
111            merkle_root,
112            session_binding,
113        })
114    }
115    
116    /// Get the npub in bech32 format
117    pub fn npub_bech32(&self) -> Result<String, SdkError> {
118        crate::hex_to_npub(&self.npub_hex)
119    }
120}
121
/// Outcome of verifying a proof, together with the public values it attests to.
#[derive(Debug, Clone)]
pub struct VerificationResult {
    /// `true` when the proof verified successfully.
    pub valid: bool,
    /// NOSTR public key (npub) as hex.
    pub npub_hex: String,
    /// NOSTR public key (npub) in bech32 form.
    pub npub: String,
    /// Merkle root that the prover demonstrated membership in.
    pub merkle_root: String,
    /// Session binding, when present, intended to tie the proof to one session
    /// and prevent replay.
    pub session_binding: Option<String>,
}
136
137// Helper functions
138
139fn parse_g1(coords: &[String]) -> Result<G1Affine, SdkError> {
140    if coords.len() < 2 {
141        return Err(SdkError::InvalidProof("G1 needs at least 2 coordinates".into()));
142    }
143    
144    let x = parse_fq(&coords[0])?;
145    let y = parse_fq(&coords[1])?;
146    
147    Ok(G1Affine::new(x, y))
148}
149
150fn parse_g2(coords: &[Vec<String>]) -> Result<G2Affine, SdkError> {
151    if coords.len() < 2 {
152        return Err(SdkError::InvalidProof("G2 needs at least 2 coordinate pairs".into()));
153    }
154    
155    let x = Fq2::new(parse_fq(&coords[0][0])?, parse_fq(&coords[0][1])?);
156    let y = Fq2::new(parse_fq(&coords[1][0])?, parse_fq(&coords[1][1])?);
157    
158    Ok(G2Affine::new(x, y))
159}
160
161fn parse_fq(s: &str) -> Result<Fq, SdkError> {
162    use num_bigint::BigUint;
163    use std::str::FromStr;
164    
165    let n = BigUint::from_str(s)
166        .map_err(|e| SdkError::InvalidProof(format!("Invalid field element: {}", e)))?;
167    
168    let bytes = n.to_bytes_le();
169    let mut padded = [0u8; 32];
170    let len = bytes.len().min(32);
171    padded[..len].copy_from_slice(&bytes[..len]);
172    
173    Ok(Fq::from_le_bytes_mod_order(&padded))
174}
175
176fn parse_fr(s: &str) -> Result<Fr, SdkError> {
177    use num_bigint::BigUint;
178    use std::str::FromStr;
179    
180    let n = BigUint::from_str(s)
181        .map_err(|e| SdkError::InvalidProof(format!("Invalid field element: {}", e)))?;
182    
183    let bytes = n.to_bytes_le();
184    let mut padded = [0u8; 32];
185    let len = bytes.len().min(32);
186    padded[..len].copy_from_slice(&bytes[..len]);
187    
188    Ok(Fr::from_le_bytes_mod_order(&padded))
189}
190
191fn parse_bigint_to_bytes(s: &str, len: usize) -> Result<Vec<u8>, SdkError> {
192    use num_bigint::BigUint;
193    use std::str::FromStr;
194    
195    let n = BigUint::from_str(s)
196        .map_err(|e| SdkError::InvalidProof(format!("Invalid number: {}", e)))?;
197    
198    let mut bytes = n.to_bytes_be();
199    
200    // Pad or truncate to desired length
201    if bytes.len() < len {
202        let mut padded = vec![0u8; len - bytes.len()];
203        padded.extend(bytes);
204        bytes = padded;
205    } else if bytes.len() > len {
206        bytes = bytes[bytes.len() - len..].to_vec();
207    }
208    
209    Ok(bytes)
210}
211
212fn format_as_hex(decimal_str: &str) -> Result<String, SdkError> {
213    let bytes = parse_bigint_to_bytes(decimal_str, 32)?;
214    Ok(hex::encode(bytes))
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_fr() {
        assert_eq!(parse_fr("123456789").ok(), Some(Fr::from(123456789u64)));
        assert!(parse_fr("not a number").is_err());
    }

    #[test]
    fn test_parse_bigint_to_bytes_pads_big_endian() {
        assert_eq!(parse_bigint_to_bytes("258", 4).ok(), Some(vec![0, 0, 1, 2]));
    }

    #[test]
    fn test_format_as_hex_is_32_bytes() {
        let expected = format!("{}ff", "0".repeat(62));
        assert_eq!(format_as_hex("255").ok(), Some(expected));
    }

    #[test]
    fn test_public_inputs_layout() {
        let strings: Vec<String> = ["1", "0", "2"].iter().map(|s| s.to_string()).collect();
        let inputs = match PublicInputs::from_strings(&strings) {
            Ok(inputs) => inputs,
            Err(_) => panic!("valid public inputs were rejected"),
        };
        assert_eq!(inputs.npub_hex, format!("{}01", "0".repeat(62)));
        assert_eq!(inputs.merkle_root, format!("{}02", "0".repeat(62)));
        assert!(inputs.session_binding.is_none());
        assert_eq!(inputs.values.len(), 3);
    }

    #[test]
    fn test_public_inputs_too_short() {
        let strings = vec!["1".to_string(), "0".to_string()];
        assert!(PublicInputs::from_strings(&strings).is_err());
    }

    #[test]
    fn test_parse_g1_generator() {
        // (1, 2) satisfies y^2 = x^3 + 3 and is the standard BN254 G1 generator.
        assert!(parse_g1(&["1".to_string(), "2".to_string()]).is_ok());
    }
}