solana_zk_sdk/encryption/
discrete_log.rs

1//! The discrete log implementation for the twisted ElGamal decryption.
2//!
3//! The implementation uses the baby-step giant-step method, which consists of a precomputation
4//! step and an online step. The precomputation step involves computing a hash table of a number
5//! of Ristretto points that is independent of a discrete log instance. The online phase computes
6//! the final discrete log solution using the discrete log instance and the pre-computed hash
7//! table. More details on the baby-step giant-step algorithm and the implementation can be found
8//! in the [spl documentation](https://spl.solana.com).
9//!
10//! The implementation is NOT intended to run in constant-time. There are some measures to prevent
11//! straightforward timing attacks. For instance, it does not short-circuit the search when a
//! solution is found. However, the use of hashtables, batching, and threads makes the
13//! implementation inherently not constant-time. This may theoretically allow an adversary to gain
14//! information on a discrete log solution depending on the execution time of the implementation.
15//!
16
17#[cfg(not(target_arch = "wasm32"))]
18use std::thread;
19use {
20    crate::RISTRETTO_POINT_LEN,
21    curve25519_dalek::{
22        constants::RISTRETTO_BASEPOINT_POINT as G,
23        ristretto::RistrettoPoint,
24        scalar::Scalar,
25        traits::{Identity, IsIdentity},
26    },
27    itertools::Itertools,
28    serde::{Deserialize, Serialize},
29    std::{collections::HashMap, num::NonZeroUsize},
30    thiserror::Error,
31};
32
/// 2^16: number of entries in the precomputed table and size of the online search range.
const TWO16: u64 = 1 << 16;
/// 2^17: scalar by which the generator is multiplied when the precomputed table is built.
const TWO17: u64 = 1 << 17;

/// Upper limit on the thread count accepted for discrete log computation (2^16)
#[cfg(not(target_arch = "wasm32"))]
const MAX_THREAD: usize = 1 << 16;
39
40#[derive(Error, Clone, Debug, Eq, PartialEq)]
41pub enum DiscreteLogError {
42    #[error("discrete log number of threads not power-of-two")]
43    DiscreteLogThreads,
44    #[error("discrete log batch size too large")]
45    DiscreteLogBatchSize,
46}
47
48/// Type that captures a discrete log challenge.
49///
50/// The goal of discrete log is to find x such that x * generator = target.
51#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
52pub struct DiscreteLog {
53    /// Generator point for discrete log
54    pub generator: RistrettoPoint,
55    /// Target point for discrete log
56    pub target: RistrettoPoint,
57    /// Number of threads used for discrete log computation
58    num_threads: Option<NonZeroUsize>,
59    /// Range bound for discrete log search derived from the max value to search for and
60    /// `num_threads`
61    range_bound: NonZeroUsize,
62    /// Ristretto point representing each step of the discrete log search
63    step_point: RistrettoPoint,
64    /// Ristretto point compression batch size
65    compression_batch_size: NonZeroUsize,
66}
67
68#[derive(Serialize, Deserialize, Default)]
69pub struct DecodePrecomputation(HashMap<[u8; RISTRETTO_POINT_LEN], u16>);
70
71/// Builds a HashMap of 2^16 elements
72#[allow(dead_code)]
73fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
74    let mut hashmap = HashMap::new();
75
76    let two17_scalar = Scalar::from(TWO17);
77    let identity = RistrettoPoint::identity(); // 0 * G
78    let generator = two17_scalar * generator; // 2^17 * G
79
80    // iterator for 2^17*0G , 2^17*1G, 2^17*2G, ...
81    let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
82    for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
83        let key = point.compress().to_bytes();
84        hashmap.insert(key, x_hi as u16);
85    }
86
87    DecodePrecomputation(hashmap)
88}
89
90/// Pre-computed HashMap needed for decryption. The HashMap is independent of (works for) any key.
91pub static DECODE_PRECOMPUTATION_FOR_G: std::sync::LazyLock<DecodePrecomputation> =
92    std::sync::LazyLock::new(|| {
93        static DECODE_PRECOMPUTATION_FOR_G_BINCODE: &[u8] =
94            include_bytes!("decode_u32_precomputation_for_G.bincode");
95        bincode::deserialize(DECODE_PRECOMPUTATION_FOR_G_BINCODE).unwrap_or_default()
96    });
97
98/// Solves the discrete log instance using a 16/16 bit offline/online split
99impl DiscreteLog {
100    /// Discrete log instance constructor.
101    ///
102    /// Default number of threads set to 1.
103    pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
104        Self {
105            generator,
106            target,
107            num_threads: None,
108            range_bound: (TWO16 as usize).try_into().unwrap(),
109            step_point: G,
110            compression_batch_size: 32.try_into().unwrap(),
111        }
112    }
113
114    /// Adjusts number of threads in a discrete log instance.
115    #[cfg(not(target_arch = "wasm32"))]
116    pub fn num_threads(&mut self, num_threads: NonZeroUsize) -> Result<(), DiscreteLogError> {
117        // number of threads must be a positive power-of-two integer
118        if !num_threads.is_power_of_two() || num_threads.get() > MAX_THREAD {
119            return Err(DiscreteLogError::DiscreteLogThreads);
120        }
121
122        self.num_threads = Some(num_threads);
123        self.range_bound = (TWO16 as usize)
124            .checked_div(num_threads.get())
125            .and_then(|range_bound| range_bound.try_into().ok())
126            .unwrap(); // `num_threads` cannot exceed `TWO16`, so `range_bound` always non-zero
127        self.step_point = Scalar::from(num_threads.get() as u64) * G;
128
129        Ok(())
130    }
131
132    /// Adjusts inversion batch size in a discrete log instance.
133    pub fn set_compression_batch_size(
134        &mut self,
135        compression_batch_size: NonZeroUsize,
136    ) -> Result<(), DiscreteLogError> {
137        if compression_batch_size.get() >= TWO16 as usize {
138            return Err(DiscreteLogError::DiscreteLogBatchSize);
139        }
140        self.compression_batch_size = compression_batch_size;
141
142        Ok(())
143    }
144
145    /// Solves the discrete log problem under the assumption that the solution
146    /// is a positive 32-bit number.
147    pub fn decode_u32(self) -> Option<u64> {
148        #[allow(unused_variables)]
149        if let Some(num_threads) = self.num_threads {
150            #[cfg(not(target_arch = "wasm32"))]
151            {
152                let mut starting_point = self.target;
153                let handles = (0..num_threads.get())
154                    .map(|i| {
155                        let ristretto_iterator = RistrettoIterator::new(
156                            (starting_point, i as u64),
157                            (-(&self.step_point), num_threads.get() as u64),
158                        );
159
160                        let handle = thread::spawn(move || {
161                            Self::decode_range(
162                                ristretto_iterator,
163                                self.range_bound,
164                                self.compression_batch_size,
165                            )
166                        });
167
168                        starting_point -= G;
169                        handle
170                    })
171                    .collect::<Vec<_>>();
172
173                handles
174                    .into_iter()
175                    .map_while(|h| h.join().ok())
176                    .find(|x| x.is_some())
177                    .flatten()
178            }
179            #[cfg(target_arch = "wasm32")]
180            unreachable!() // `self.num_threads` always `None` on wasm target
181        } else {
182            let ristretto_iterator =
183                RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1u64));
184
185            Self::decode_range(
186                ristretto_iterator,
187                self.range_bound,
188                self.compression_batch_size,
189            )
190        }
191    }
192
193    fn decode_range(
194        ristretto_iterator: RistrettoIterator,
195        range_bound: NonZeroUsize,
196        compression_batch_size: NonZeroUsize,
197    ) -> Option<u64> {
198        let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
199        let mut decoded = None;
200
201        for batch in &ristretto_iterator
202            .take(range_bound.get())
203            .chunks(compression_batch_size.get())
204        {
205            // batch compression currently errors if any point in the batch is the identity point
206            let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
207                .filter(|(point, index)| {
208                    if point.is_identity() {
209                        decoded = Some(*index);
210                        return false;
211                    }
212                    true
213                })
214                .unzip();
215
216            let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
217
218            for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
219                let key = point.to_bytes();
220                if hashmap.0.contains_key(&key) {
221                    let x_hi = hashmap.0[&key];
222                    decoded = Some(x_lo + TWO16 * x_hi as u64);
223                }
224            }
225        }
226
227        decoded
228    }
229}
230
231/// Hashable Ristretto iterator.
232///
233/// Given an initial point X and a stepping point P, the iterator iterates through
234/// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ...
235struct RistrettoIterator {
236    pub current: (RistrettoPoint, u64),
237    pub step: (RistrettoPoint, u64),
238}
239
240impl RistrettoIterator {
241    fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
242        RistrettoIterator { current, step }
243    }
244}
245
246impl Iterator for RistrettoIterator {
247    type Item = (RistrettoPoint, u64);
248
249    fn next(&mut self) -> Option<Self::Item> {
250        let r = self.current;
251        self.current = (self.current.0 + self.step.0, self.current.1 + self.step.1);
252        Some(r)
253    }
254}
255
#[cfg(test)]
mod tests {
    use {super::*, std::time::Instant};

    /// Decodes `amount * G` with the search split across `num_threads` worker threads.
    #[cfg(not(target_arch = "wasm32"))]
    fn decode_threaded(amount: u64, num_threads: usize) -> Option<u64> {
        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance
            .num_threads(num_threads.try_into().unwrap())
            .unwrap();
        instance.decode_u32()
    }

    /// Checks the bundled table against a freshly computed one. On mismatch the bincode file is
    /// rewritten and the test fails so that the crate gets rebuilt with the new file.
    #[test]
    #[allow(non_snake_case)]
    fn test_serialize_decode_u32_precomputation_for_G() {
        let decode_u32_precomputation_for_G = decode_u32_precomputation(G);

        if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
            use std::{fs::File, io::Write, path::PathBuf};
            let mut f = File::create(PathBuf::from(
                "src/encryption/decode_u32_precomputation_for_G.bincode",
            ))
            .unwrap();
            f.write_all(&bincode::serialize(&decode_u32_precomputation_for_G).unwrap())
                .unwrap();
            panic!("Rebuild and run this test again");
        }
    }

    #[test]
    fn test_decode_correctness() {
        // general case
        let amount: u64 = 4294967295;

        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);

        // Very informal measurements for now
        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!("single thread discrete log computation secs: {computation_secs:?} sec");
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_decode_correctness_threaded() {
        // general case
        let amount: u64 = 55;

        // Very informal measurements for now
        let start_computation = Instant::now();
        let decoded = decode_threaded(amount, 4);
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!("4 thread discrete log computation: {computation_secs:?} sec");

        // edge cases, each decoded with threads enabled: amounts 0 through 3 land on a
        // different worker each, and the last one is the largest 32-bit amount
        for amount in [0_u64, 1, 2, 3, (1_u64 << 32) - 1] {
            let decoded = decode_threaded(amount, 4);
            assert_eq!(amount, decoded.unwrap());
        }
    }
}
353}