1use std::collections::HashMap;
20use std::str::FromStr;
21
22use anyhow::Result;
23use num_bigint::BigInt;
24
/// Function-pointer type for a witness generator.
///
/// Takes a map from `String` keys to vectors of `BigInt` values and returns a flat
/// `Vec<BigInt>`, the same shape that [`parse_bigints_to_witness`] serializes.
/// NOTE(review): keys are presumably circuit input-signal names — confirm against the
/// functions assigned to this type.
pub type WtnsFn = fn(HashMap<String, Vec<BigInt>>) -> Vec<BigInt>;
28
/// Output of a successful Groth16 proof generation.
///
/// Both fields hold the text the prover wrote into its output buffers
/// (presumably JSON — confirm against the rapidsnark version in use).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProofResult {
    /// The serialized proof.
    pub proof: String,
    /// The serialized public signals produced alongside the proof.
    pub public_signals: String,
}
35
/// Raw bindings to the rapidsnark Groth16 prover/verifier, linked statically from the
/// `rapidsnark` library.
///
/// NOTE(review): the size parameters are declared as `c_ulong`. Some rapidsnark releases
/// declare them as `unsigned long long` in their C headers; the two only agree on LP64
/// targets (e.g. 64-bit Linux/macOS), not on Windows — confirm against the linked version.
#[link(name = "rapidsnark", kind = "static")]
extern "C" {
    /// Generates a Groth16 proof from the `.zkey` file at `zkey_file_path`
    /// (a NUL-terminated path) and a serialized witness of `wtns_size` bytes.
    ///
    /// The proof and the public signals are written into the caller-provided buffers as
    /// C text. On entry, `proof_size` / `public_size` carry each buffer's capacity
    /// (presumably updated by the library on return — confirm). On failure a message
    /// may be written into `error_msg`, whose capacity is `error_msg_maxsize`.
    ///
    /// Returns a status code; `groth16_prover_zkey_file_wrapper` treats any non-zero
    /// value as failure.
    pub fn groth16_prover_zkey_file(
        zkey_file_path: *const std::os::raw::c_char,
        wtns_buffer: *const std::os::raw::c_void,
        wtns_size: std::ffi::c_ulong,
        proof_buffer: *mut std::os::raw::c_char,
        proof_size: *mut std::ffi::c_ulong,
        public_buffer: *mut std::os::raw::c_char,
        public_size: *mut std::ffi::c_ulong,
        error_msg: *mut std::os::raw::c_char,
        error_msg_maxsize: std::ffi::c_ulong,
    ) -> i32;

    /// Verifies a Groth16 proof.
    ///
    /// `proof`, `inputs` and `verification_key` are NUL-terminated C strings. An error
    /// message may be written into `error_msg`, whose capacity is `error_msg_maxsize`.
    ///
    /// Returns a status code; `groth16_verify_wrapper` maps 0 to "valid", 2 to an error,
    /// and every other value to "invalid".
    pub fn groth16_verify(
        proof: *const std::os::raw::c_char,
        inputs: *const std::os::raw::c_char,
        verification_key: *const std::os::raw::c_char,
        error_msg: *mut std::os::raw::c_char,
        error_msg_maxsize: std::ffi::c_ulong,
    ) -> i32;
}
58
59use num_traits::ops::bytes::ToBytes;
60use std::io::{self};
61
62pub fn parse_bigints_to_witness(bigints: Vec<BigInt>) -> io::Result<Vec<u8>> {
65 let mut buffer = Vec::new();
66 let version: u32 = 2;
67 let n_sections: u32 = 2;
68 let n8: u32 = 32;
69 let q = BigInt::from_str(
70 "21888242871839275222246405745257275088548364400416034343698204186575808495617",
71 )
72 .unwrap();
73 let n_witness_values: u32 = bigints.len() as u32;
74
75 let wtns_format = "wtns".as_bytes();
77 buffer.extend_from_slice(wtns_format);
78
79 buffer.extend_from_slice(&version.to_le_bytes());
81
82 buffer.extend_from_slice(&n_sections.to_le_bytes());
84
85 let section_id_1: u32 = 1;
88 let section_length_1: u64 = 8 + n8 as u64;
89 buffer.extend_from_slice(§ion_id_1.to_le_bytes());
90 buffer.extend_from_slice(§ion_length_1.to_le_bytes());
91
92 buffer.extend_from_slice(&n8.to_le_bytes());
94 buffer.extend_from_slice(&q.to_signed_bytes_le());
95 buffer.extend_from_slice(&n_witness_values.to_le_bytes());
96
97 let section_id_2: u32 = 2;
99 let section_length_2: u64 = bigints.len() as u64 * n8 as u64; buffer.extend_from_slice(§ion_id_2.to_le_bytes());
101 buffer.extend_from_slice(§ion_length_2.to_le_bytes());
102
103 for bigint in bigints {
105 let mut bytes = bigint.to_le_bytes();
106 bytes.resize(n8 as usize, 0); buffer.extend_from_slice(&bytes);
108 }
109
110 Ok(buffer)
112}
113
114pub fn groth16_prover_zkey_file_wrapper(
116 zkey_path: &str,
117 wtns_buffer: Vec<u8>,
118) -> Result<ProofResult> {
119 let formatted_zkey_path = std::ffi::CString::new(zkey_path).unwrap();
120 let wtns_size = wtns_buffer.len() as u64;
121
122 let mut proof_buffer = vec![0u8; 4 * 1024 * 1024]; let mut proof_size: u64 = 4 * 1024 * 1024;
124 let proof_ptr = proof_buffer.as_mut_ptr() as *mut std::ffi::c_char;
125
126 let mut public_buffer = vec![0u8; 4 * 1024 * 1024]; let mut public_size: u64 = 4 * 1024 * 1024;
128 let public_ptr = public_buffer.as_mut_ptr() as *mut std::ffi::c_char;
129
130 let mut error_msg = vec![0u8; 256]; let error_msg_ptr = error_msg.as_mut_ptr() as *mut std::ffi::c_char;
132
133 unsafe {
134 let result = groth16_prover_zkey_file(
135 formatted_zkey_path.as_ptr() as *const std::ffi::c_char,
136 wtns_buffer.as_ptr() as *const std::os::raw::c_void, wtns_size,
138 proof_ptr,
139 &mut proof_size,
140 public_ptr,
141 &mut public_size,
142 error_msg_ptr,
143 error_msg.len() as u64,
144 );
145 if result != 0 {
146 let error_string = std::ffi::CStr::from_ptr(error_msg_ptr)
147 .to_string_lossy()
148 .into_owned();
149 return Err(anyhow::anyhow!("Proof generation failed: {}", error_string));
150 }
151 let proof = std::ffi::CStr::from_ptr(proof_ptr)
153 .to_string_lossy()
154 .into_owned();
155 let public_signals = std::ffi::CStr::from_ptr(public_ptr)
156 .to_string_lossy()
157 .into_owned();
158 Ok(ProofResult {
159 proof,
160 public_signals,
161 })
162 }
163}
164
165pub fn groth16_verify_wrapper(proof: &str, inputs: &str, verification_key: &str) -> Result<bool> {
167 let proof_cstr = std::ffi::CString::new(proof).unwrap();
168 let inputs_cstr = std::ffi::CString::new(inputs).unwrap();
169 let verification_key_cstr = std::ffi::CString::new(verification_key).unwrap();
170
171 let mut error_msg = vec![0u8; 256]; let error_msg_ptr = error_msg.as_mut_ptr() as *mut std::ffi::c_char;
173 unsafe {
174 let result = groth16_verify(
175 proof_cstr.as_ptr() as *const std::ffi::c_char,
176 inputs_cstr.as_ptr() as *const std::ffi::c_char,
177 verification_key_cstr.as_ptr() as *const std::ffi::c_char,
178 error_msg_ptr,
179 error_msg.len() as u64,
180 );
181 if result == 2 {
182 let error_string = std::ffi::CStr::from_ptr(error_msg_ptr)
183 .to_string_lossy()
184 .into_owned();
185 return Err(anyhow::anyhow!(
186 "Proof verification failed: {}",
187 error_string
188 ));
189 }
190 Ok(result == 0)
191 }
192}