solana_zk_sdk/encryption/
discrete_log.rs1#[cfg(not(target_arch = "wasm32"))]
18use std::thread;
19use {
20 crate::RISTRETTO_POINT_LEN,
21 curve25519_dalek::{
22 constants::RISTRETTO_BASEPOINT_POINT as G,
23 ristretto::RistrettoPoint,
24 scalar::Scalar,
25 traits::{Identity, IsIdentity},
26 },
27 itertools::Itertools,
28 serde::{Deserialize, Serialize},
29 std::{collections::HashMap, num::NonZeroUsize},
30 thiserror::Error,
31};
32
/// 2^16: the number of low-order candidates walked by the online search and the number of
/// entries generated for the precomputed table.
const TWO16: u64 = 1 << 16;

/// 2^17: scalar applied to the generator to obtain the step between consecutive entries of the
/// precomputed table.
const TWO17: u64 = 1 << 17;

/// Largest thread count accepted by `DiscreteLog::num_threads`.
#[cfg(not(target_arch = "wasm32"))]
const MAX_THREAD: usize = 1 << 16;
39
40#[derive(Error, Clone, Debug, Eq, PartialEq)]
41pub enum DiscreteLogError {
42 #[error("discrete log number of threads not power-of-two")]
43 DiscreteLogThreads,
44 #[error("discrete log batch size too large")]
45 DiscreteLogBatchSize,
46}
47
48#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
52pub struct DiscreteLog {
53 pub generator: RistrettoPoint,
55 pub target: RistrettoPoint,
57 num_threads: Option<NonZeroUsize>,
59 range_bound: NonZeroUsize,
62 step_point: RistrettoPoint,
64 compression_batch_size: NonZeroUsize,
66}
67
68#[derive(Serialize, Deserialize, Default)]
69pub struct DecodePrecomputation(HashMap<[u8; RISTRETTO_POINT_LEN], u16>);
70
71#[allow(dead_code)]
73fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
74 let mut hashmap = HashMap::new();
75
76 let two17_scalar = Scalar::from(TWO17);
77 let identity = RistrettoPoint::identity(); let generator = two17_scalar * generator; let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
82 for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
83 let key = point.compress().to_bytes();
84 hashmap.insert(key, x_hi as u16);
85 }
86
87 DecodePrecomputation(hashmap)
88}
89
90pub static DECODE_PRECOMPUTATION_FOR_G: std::sync::LazyLock<DecodePrecomputation> =
92 std::sync::LazyLock::new(|| {
93 static DECODE_PRECOMPUTATION_FOR_G_BINCODE: &[u8] =
94 include_bytes!("decode_u32_precomputation_for_G.bincode");
95 bincode::deserialize(DECODE_PRECOMPUTATION_FOR_G_BINCODE).unwrap_or_default()
96 });
97
98impl DiscreteLog {
100 pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
104 Self {
105 generator,
106 target,
107 num_threads: None,
108 range_bound: (TWO16 as usize).try_into().unwrap(),
109 step_point: G,
110 compression_batch_size: 32.try_into().unwrap(),
111 }
112 }
113
114 #[cfg(not(target_arch = "wasm32"))]
116 pub fn num_threads(&mut self, num_threads: NonZeroUsize) -> Result<(), DiscreteLogError> {
117 if !num_threads.is_power_of_two() || num_threads.get() > MAX_THREAD {
119 return Err(DiscreteLogError::DiscreteLogThreads);
120 }
121
122 self.num_threads = Some(num_threads);
123 self.range_bound = (TWO16 as usize)
124 .checked_div(num_threads.get())
125 .and_then(|range_bound| range_bound.try_into().ok())
126 .unwrap(); self.step_point = Scalar::from(num_threads.get() as u64) * G;
128
129 Ok(())
130 }
131
132 pub fn set_compression_batch_size(
134 &mut self,
135 compression_batch_size: NonZeroUsize,
136 ) -> Result<(), DiscreteLogError> {
137 if compression_batch_size.get() >= TWO16 as usize {
138 return Err(DiscreteLogError::DiscreteLogBatchSize);
139 }
140 self.compression_batch_size = compression_batch_size;
141
142 Ok(())
143 }
144
145 pub fn decode_u32(self) -> Option<u64> {
148 #[allow(unused_variables)]
149 if let Some(num_threads) = self.num_threads {
150 #[cfg(not(target_arch = "wasm32"))]
151 {
152 let mut starting_point = self.target;
153 let handles = (0..num_threads.get())
154 .map(|i| {
155 let ristretto_iterator = RistrettoIterator::new(
156 (starting_point, i as u64),
157 (-(&self.step_point), num_threads.get() as u64),
158 );
159
160 let handle = thread::spawn(move || {
161 Self::decode_range(
162 ristretto_iterator,
163 self.range_bound,
164 self.compression_batch_size,
165 )
166 });
167
168 starting_point -= G;
169 handle
170 })
171 .collect::<Vec<_>>();
172
173 handles
174 .into_iter()
175 .map_while(|h| h.join().ok())
176 .find(|x| x.is_some())
177 .flatten()
178 }
179 #[cfg(target_arch = "wasm32")]
180 unreachable!() } else {
182 let ristretto_iterator =
183 RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1u64));
184
185 Self::decode_range(
186 ristretto_iterator,
187 self.range_bound,
188 self.compression_batch_size,
189 )
190 }
191 }
192
193 fn decode_range(
194 ristretto_iterator: RistrettoIterator,
195 range_bound: NonZeroUsize,
196 compression_batch_size: NonZeroUsize,
197 ) -> Option<u64> {
198 let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
199 let mut decoded = None;
200
201 for batch in &ristretto_iterator
202 .take(range_bound.get())
203 .chunks(compression_batch_size.get())
204 {
205 let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
207 .filter(|(point, index)| {
208 if point.is_identity() {
209 decoded = Some(*index);
210 return false;
211 }
212 true
213 })
214 .unzip();
215
216 let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
217
218 for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
219 let key = point.to_bytes();
220 if hashmap.0.contains_key(&key) {
221 let x_hi = hashmap.0[&key];
222 decoded = Some(x_lo + TWO16 * x_hi as u64);
223 }
224 }
225 }
226
227 decoded
228 }
229}
230
231struct RistrettoIterator {
236 pub current: (RistrettoPoint, u64),
237 pub step: (RistrettoPoint, u64),
238}
239
240impl RistrettoIterator {
241 fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
242 RistrettoIterator { current, step }
243 }
244}
245
246impl Iterator for RistrettoIterator {
247 type Item = (RistrettoPoint, u64);
248
249 fn next(&mut self) -> Option<Self::Item> {
250 let r = self.current;
251 self.current = (self.current.0 + self.step.0, self.current.1 + self.step.1);
252 Some(r)
253 }
254}
255
#[cfg(test)]
mod tests {
    use {super::*, std::time::Instant};

    /// Boundary amounts checked on both the single-threaded and the threaded search path.
    const EDGE_CASE_AMOUNTS: [u64; 5] = [0, 1, 2, 3, (1_u64 << 32) - 1];

    /// Checks that the embedded bincode table matches a freshly computed one.
    ///
    /// On a mismatch the test rewrites the bincode file with the fresh table and then fails,
    /// so that a rebuild picks up the regenerated file.
    #[test]
    #[allow(non_snake_case)]
    fn test_serialize_decode_u32_precomputation_for_G() {
        let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
        if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
            use std::{fs::File, io::Write, path::PathBuf};
            let mut f = File::create(PathBuf::from(
                "src/encryption/decode_u32_precomputation_for_G.bincode",
            ))
            .unwrap();
            f.write_all(&bincode::serialize(&decode_u32_precomputation_for_G).unwrap())
                .unwrap();
            panic!("Rebuild and run this test again");
        }
    }

    /// Single-threaded search: a timed run on the largest amount, then the edge cases.
    #[test]
    fn test_decode_correctness() {
        let amount: u64 = 4294967295;

        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);

        // Very informal timing, printed for manual inspection only.
        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!("single thread discrete log computation secs: {computation_secs:?} sec");

        for amount in EDGE_CASE_AMOUNTS {
            let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
            assert_eq!(Some(amount), instance.decode_u32());
        }
    }

    /// Threaded search: a timed run on a general amount, then the edge cases.
    ///
    /// Every instance is configured with four threads, so the edge cases exercise the
    /// threaded path rather than the single-threaded default.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_decode_correctness_threaded() {
        let amount: u64 = 55;

        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance.num_threads(4.try_into().unwrap()).unwrap();

        // Very informal timing, printed for manual inspection only.
        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!("4 thread discrete log computation: {computation_secs:?} sec");

        for amount in EDGE_CASE_AMOUNTS {
            let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
            instance.num_threads(4.try_into().unwrap()).unwrap();
            assert_eq!(Some(amount), instance.decode_u32());
        }
    }
}