rust_rapidsnark/
lib.rs

1//! Rust bindings for rapidsnark proving.
2//!
3//! Prebuilt binaries are provided for the following platforms:
4//! - aarch64-apple-ios
5//! - aarch64-apple-ios-sim
6//! - x86_64-apple-ios
7//! - aarch64-apple-darwin
8//! - x86_64-apple-darwin
9//! - aarch64-linux-android
10//! - x86_64-linux-android
11//! - x86_64 linux
12//! - arm64 linux
13//!
//! If a specific target is not included, the system will fall back to
//! the generic architecture, which may cause problems. E.g. if you compile
//! for aarch64-linux-generic, the system will fall back to aarch64.
17//!
18
19use std::collections::HashMap;
20use std::str::FromStr;
21
22use anyhow::Result;
23use num_bigint::BigInt;
24
/// A function that converts named inputs to a full witness. This should be generated using e.g.
/// [rust-witness](https://crates.io/crates/rust-witness).
///
/// The argument maps each input name to the list of values supplied for it; the
/// returned vector has the type accepted by [`parse_bigints_to_witness`].
pub type WtnsFn = fn(HashMap<String, Vec<BigInt>>) -> Vec<BigInt>;
28
/// A structure representing a proof and public signals.
///
/// Both fields hold the text that the native prover wrote into its output
/// buffers, decoded lossily as UTF-8.
// NOTE(review): the payloads are presumably JSON documents — confirm against the
// rapidsnark documentation before parsing them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProofResult {
    /// The proof, as text produced by the native prover.
    pub proof: String,
    /// The public signals, as text produced by the native prover.
    pub public_signals: String,
}
35
// Raw declarations of the C entry points exported by the statically linked
// `rapidsnark` library. Prefer the safe wrappers defined later in this file
// (`groth16_prover_zkey_file_wrapper`, `groth16_verify_wrapper`).
#[link(name = "rapidsnark", kind = "static")]
extern "C" {
    /// Native Groth16 prover that reads the proving key from a file.
    ///
    /// As used by `groth16_prover_zkey_file_wrapper` in this file:
    /// - `zkey_file_path` is a NUL-terminated C string.
    /// - `wtns_buffer` / `wtns_size` describe the witness bytes and their length.
    /// - `proof_buffer` / `public_buffer` are caller-allocated output buffers;
    ///   `proof_size` / `public_size` are initialised to the buffer capacities
    ///   before the call.
    /// - `error_msg` / `error_msg_maxsize` describe a caller-allocated buffer
    ///   that is read as a C string when the call fails.
    ///
    /// The wrapper treats a return value of `0` as success and anything else
    /// as failure.
    ///
    /// NOTE(review): the values the native side writes back through
    /// `proof_size` / `public_size`, and the meaning of the individual non-zero
    /// codes, are not visible here — confirm against the rapidsnark headers.
    pub fn groth16_prover_zkey_file(
        zkey_file_path: *const std::os::raw::c_char,
        wtns_buffer: *const std::os::raw::c_void,
        wtns_size: std::ffi::c_ulong,
        proof_buffer: *mut std::os::raw::c_char,
        proof_size: *mut std::ffi::c_ulong,
        public_buffer: *mut std::os::raw::c_char,
        public_size: *mut std::ffi::c_ulong,
        error_msg: *mut std::os::raw::c_char,
        error_msg_maxsize: std::ffi::c_ulong,
    ) -> i32;

    /// Native Groth16 verifier.
    ///
    /// As used by `groth16_verify_wrapper` in this file, `proof`, `inputs` and
    /// `verification_key` are NUL-terminated C strings, and `error_msg` /
    /// `error_msg_maxsize` describe a caller-allocated buffer that is read as a
    /// C string when the call reports an error.
    ///
    /// The wrapper interprets a return value of `0` as "proof is valid", `2` as
    /// an error (with a message in `error_msg`), and any other value as "proof
    /// is not valid".
    pub fn groth16_verify(
        proof: *const std::os::raw::c_char,
        inputs: *const std::os::raw::c_char,
        verification_key: *const std::os::raw::c_char,
        error_msg: *mut std::os::raw::c_char,
        error_msg_maxsize: std::ffi::c_ulong,
    ) -> i32;
}
58
59use num_traits::ops::bytes::ToBytes;
60use std::io::{self};
61
62/// Parse bigints to `wtns` format.<br/>
63/// Reference: [witnesscalc/src/witnesscalc.cpp](https://github.com/0xPolygonID/witnesscalc/blob/4a789880727aa0df50f1c4ef78ec295f5a30a15e/src/witnesscalc.cpp)
64pub fn parse_bigints_to_witness(bigints: Vec<BigInt>) -> io::Result<Vec<u8>> {
65    let mut buffer = Vec::new();
66    let version: u32 = 2;
67    let n_sections: u32 = 2;
68    let n8: u32 = 32;
69    let q = BigInt::from_str(
70        "21888242871839275222246405745257275088548364400416034343698204186575808495617",
71    )
72    .unwrap();
73    let n_witness_values: u32 = bigints.len() as u32;
74
75    // Write the format bytes (4 bytes)
76    let wtns_format = "wtns".as_bytes();
77    buffer.extend_from_slice(wtns_format);
78
79    // Write version (4 bytes)
80    buffer.extend_from_slice(&version.to_le_bytes());
81
82    // Write number of sections (4 bytes)
83    buffer.extend_from_slice(&n_sections.to_le_bytes());
84
85    // Iterate through sections to write the data
86    // Section 1 (Field parameters)
87    let section_id_1: u32 = 1;
88    let section_length_1: u64 = 8 + n8 as u64;
89    buffer.extend_from_slice(&section_id_1.to_le_bytes());
90    buffer.extend_from_slice(&section_length_1.to_le_bytes());
91
92    // Write n8 (4 bytes), q (32 bytes), and n_witness_values (4 bytes)
93    buffer.extend_from_slice(&n8.to_le_bytes());
94    buffer.extend_from_slice(&q.to_signed_bytes_le());
95    buffer.extend_from_slice(&n_witness_values.to_le_bytes());
96
97    // Section 2 (Witness data)
98    let section_id_2: u32 = 2;
99    let section_length_2: u64 = bigints.len() as u64 * n8 as u64; // Witness data size
100    buffer.extend_from_slice(&section_id_2.to_le_bytes());
101    buffer.extend_from_slice(&section_length_2.to_le_bytes());
102
103    // Write the witness data (each BigInt to n8 bytes)
104    for bigint in bigints {
105        let mut bytes = bigint.to_le_bytes();
106        bytes.resize(n8 as usize, 0); // Ensure each BigInt is padded to n8 bytes
107        buffer.extend_from_slice(&bytes);
108    }
109
110    // Return the buffer containing the complete witness data
111    Ok(buffer)
112}
113
114/// Wrapper for `groth16_prover_zkey_file`
115pub fn groth16_prover_zkey_file_wrapper(
116    zkey_path: &str,
117    wtns_buffer: Vec<u8>,
118) -> Result<ProofResult> {
119    let formatted_zkey_path = std::ffi::CString::new(zkey_path).unwrap();
120    let wtns_size = wtns_buffer.len() as u64;
121
122    let mut proof_buffer = vec![0u8; 4 * 1024 * 1024]; // Adjust size as needed
123    let mut proof_size: u64 = 4 * 1024 * 1024;
124    let proof_ptr = proof_buffer.as_mut_ptr() as *mut std::ffi::c_char;
125
126    let mut public_buffer = vec![0u8; 4 * 1024 * 1024]; // Adjust size as needed
127    let mut public_size: u64 = 4 * 1024 * 1024;
128    let public_ptr = public_buffer.as_mut_ptr() as *mut std::ffi::c_char;
129
130    let mut error_msg = vec![0u8; 256]; // Error message buffer
131    let error_msg_ptr = error_msg.as_mut_ptr() as *mut std::ffi::c_char;
132
133    unsafe {
134        let result = groth16_prover_zkey_file(
135            formatted_zkey_path.as_ptr() as *const std::ffi::c_char,
136            wtns_buffer.as_ptr() as *const std::os::raw::c_void, // Witness buffer
137            wtns_size,
138            proof_ptr,
139            &mut proof_size,
140            public_ptr,
141            &mut public_size,
142            error_msg_ptr,
143            error_msg.len() as u64,
144        );
145        if result != 0 {
146            let error_string = std::ffi::CStr::from_ptr(error_msg_ptr)
147                .to_string_lossy()
148                .into_owned();
149            return Err(anyhow::anyhow!("Proof generation failed: {}", error_string));
150        }
151        // Convert both strings
152        let proof = std::ffi::CStr::from_ptr(proof_ptr)
153            .to_string_lossy()
154            .into_owned();
155        let public_signals = std::ffi::CStr::from_ptr(public_ptr)
156            .to_string_lossy()
157            .into_owned();
158        Ok(ProofResult {
159            proof,
160            public_signals,
161        })
162    }
163}
164
165/// Wrapper for `groth16_verify`
166pub fn groth16_verify_wrapper(proof: &str, inputs: &str, verification_key: &str) -> Result<bool> {
167    let proof_cstr = std::ffi::CString::new(proof).unwrap();
168    let inputs_cstr = std::ffi::CString::new(inputs).unwrap();
169    let verification_key_cstr = std::ffi::CString::new(verification_key).unwrap();
170
171    let mut error_msg = vec![0u8; 256]; // Error message buffer
172    let error_msg_ptr = error_msg.as_mut_ptr() as *mut std::ffi::c_char;
173    unsafe {
174        let result = groth16_verify(
175            proof_cstr.as_ptr() as *const std::ffi::c_char,
176            inputs_cstr.as_ptr() as *const std::ffi::c_char,
177            verification_key_cstr.as_ptr() as *const std::ffi::c_char,
178            error_msg_ptr,
179            error_msg.len() as u64,
180        );
181        if result == 2 {
182            let error_string = std::ffi::CStr::from_ptr(error_msg_ptr)
183                .to_string_lossy()
184                .into_owned();
185            return Err(anyhow::anyhow!(
186                "Proof verification failed: {}",
187                error_string
188            ));
189        }
190        Ok(result == 0)
191    }
192}