rust_kzg_bn254_prover/srs.rs

use ark_bn254::G1Affine;
use crossbeam_channel::{bounded, Receiver};
use rust_kzg_bn254_primitives::errors::KzgError;
use rust_kzg_bn254_primitives::traits::ReadPointFromBytes;
use std::fs::File;
use std::io::{self, BufReader, Read};

/// Represents the Structured Reference String (SRS) used in KZG commitments.
#[derive(Debug, PartialEq, Clone)]
pub struct SRS {
    /// SRS points stored in monomial form, ready to be used for commitments with polynomials
    /// in coefficient form. To commit against a polynomial in evaluation form, the points
    /// must first be transformed into Lagrange form using an IFFT (see the sketch below).
    pub g1: Vec<G1Affine>,
    /// The order of the SRS.
    pub order: u32,
}
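
// A minimal sketch, not part of this crate's API: how the monomial-form points
// above could be transformed into Lagrange form. It assumes `ark-poly` and
// `ark-ec` are available as dependencies; ark-poly's (I)FFTs work over any
// `DomainCoeff` type, which includes `G1Projective`, and the IFFT of the
// monomial points {tau^i * G} is exactly the Lagrange-basis points {L_i(tau) * G}.
#[allow(dead_code)]
fn monomial_to_lagrange(g1_monomial: &[G1Affine]) -> Vec<G1Affine> {
    use ark_bn254::{Fr, G1Projective};
    use ark_ec::CurveGroup;
    use ark_poly::{EvaluationDomain, GeneralEvaluationDomain};

    // Pick the smallest supported FFT domain that fits all points.
    let domain = GeneralEvaluationDomain::<Fr>::new(g1_monomial.len())
        .expect("unsupported domain size");
    let monomial: Vec<G1Projective> = g1_monomial.iter().map(|p| G1Projective::from(*p)).collect();
    // IFFT over the group elements yields the Lagrange-form SRS points.
    let lagrange = domain.ifft(&monomial);
    // Batch-normalize back to affine coordinates.
    G1Projective::normalize_batch(&lagrange)
}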

impl SRS {
    /// Initializes the SRS by loading G1 points from a file.
    ///
    /// # Arguments
    ///
    /// * `path_to_g1_points` - The file path to load G1 points from.
    /// * `order` - The total order of the SRS.
    /// * `points_to_load` - The number of SRS points to load.
    ///
    /// # Returns
    ///
    /// * `Result<SRS, KzgError>` - The initialized SRS or an error.
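    ///
    /// # Example
    ///
    /// A minimal usage sketch; the file name, order, and point count below are
    /// placeholders that must match a real SRS points file on disk (hence
    /// `no_run`), and the crate path assumes this module is exposed as
    /// `rust_kzg_bn254_prover::srs`:
    ///
    /// ```no_run
    /// use rust_kzg_bn254_prover::srs::SRS;
    ///
    /// // Load the first 1024 of 3000 available G1 points.
    /// let srs = SRS::new("g1.point", 3000, 1024).expect("failed to load SRS");
    /// assert_eq!(srs.g1.len(), 1024);
    /// assert_eq!(srs.order, 3000);
    /// ```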
    pub fn new(path_to_g1_points: &str, order: u32, points_to_load: u32) -> Result<Self, KzgError> {
        if points_to_load > order {
            return Err(KzgError::GenericError(
                "Number of points to load exceeds SRS order.".to_string(),
            ));
        }

        let g1_points =
            Self::parallel_read_g1_points(path_to_g1_points.to_owned(), points_to_load, false)
                .map_err(|e| KzgError::SerializationError(e.to_string()))?;

        Ok(Self {
            g1: g1_points,
            order,
        })
    }

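    /// Deserializes raw point bytes received over a channel.
    ///
    /// Each worker drains `(bytes, position, is_native)` tuples from the shared
    /// `receiver`, parses the bytes into a point of type `T` (native compressed
    /// format or big-endian encoding, depending on `is_native`), and returns the
    /// points tagged with their original byte position in the file so the caller
    /// can restore the on-disk ordering.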
    pub fn process_chunks<T>(receiver: Receiver<(Vec<u8>, usize, bool)>) -> Vec<(T, usize)>
    where
        T: ReadPointFromBytes,
    {
        // TODO: should we use rayon to process this in parallel?
        receiver
            .iter()
            .map(|(chunk, position, is_native)| {
                let point: T = if is_native {
                    T::read_point_from_bytes_native_compressed(&chunk)
                        .expect("Failed to read point from bytes")
                } else {
                    T::read_point_from_bytes_be(&chunk).expect("Failed to read point from bytes")
                };
                (point, position)
            })
            .collect()
    }

    /// Reads G1 points in parallel from a file.
    ///
    /// # Arguments
    ///
    /// * `file_path` - The path to the file containing G1 points.
    /// * `points_to_load` - The number of points to load.
    /// * `is_native` - Whether the points are in native Arkworks format.
    ///
    /// # Returns
    ///
    /// * `Result<Vec<G1Affine>, KzgError>` - The loaded G1 points or an error.
    fn parallel_read_g1_points(
        file_path: String,
        points_to_load: u32,
        is_native: bool,
    ) -> Result<Vec<G1Affine>, KzgError> {
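        // Channel holds (bytes, position, is_native) tuples; the position (byte
        // offset in the file) is used to reorder the points after parallel parsing.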
        let (sender, receiver) = bounded::<(Vec<u8>, usize, bool)>(1000);

        // Spawn the reader thread
        let reader_handle = std::thread::spawn(
            move || -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                Self::read_file_chunks(&file_path, sender, 32, points_to_load, is_native)
                    .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
            },
        );

        let num_workers = num_cpus::get();

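        // Fan out: one parsing worker per logical CPU, all sharing the same receiver.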
        let workers: Vec<_> = (0..num_workers)
            .map(|_| {
                let receiver = receiver.clone();
                std::thread::spawn(move || Self::process_chunks::<G1Affine>(receiver))
            })
            .collect();

        // Wait for the reader thread to finish
        match reader_handle.join() {
            Ok(result) => match result {
                Ok(_) => {},
                Err(e) => return Err(KzgError::GenericError(e.to_string())),
            },
            Err(_) => {
                return Err(KzgError::GenericError(
                    "Reader thread panicked.".to_string(),
                ))
            },
        }

        // Collect and sort the results
        let mut all_points = Vec::new();
        for worker in workers {
            let points = worker.join().expect("Worker thread panicked.");
            all_points.extend(points);
        }

        // Sort by original position to maintain order
        all_points.sort_by_key(|&(_, position)| position);

        // Extract the G1Affine points
        Ok(all_points.into_iter().map(|(point, _)| point).collect())
    }

    /// Reads file chunks and sends them through a channel.
    ///
    /// # Arguments
    ///
    /// * `file_path` - Path to the file.
    /// * `sender` - Channel sender to send read chunks.
    /// * `point_size` - Size of each point in bytes.
    /// * `num_points` - Number of points to read.
    /// * `is_native` - Whether the points are in native format.
    ///
    /// # Returns
    ///
    /// * `io::Result<()>` - Ok if successful, or an I/O error.
    // TODO: "chunks" seems misleading here, since we read one field element at a time.
    fn read_file_chunks(
        file_path: &str,
        sender: crossbeam_channel::Sender<(Vec<u8>, usize, bool)>,
        point_size: usize,
        num_points: u32,
        is_native: bool,
    ) -> io::Result<()> {
        let file = File::open(file_path)?;
        let mut reader = BufReader::new(file);
        let mut position = 0;
        let mut buffer = vec![0u8; point_size];

        let mut i = 0;
        // We are making one syscall per field element, which is super inefficient.
        // FIXME: Read the entire file (or large segments) into memory and then split it
        // into field elements. The entire G1 file might be ~8GiB, so it might not fit
        // in RAM, but we only need to read the subset of the file that we actually use.
        // E.g. for fault proof usage, we only need to read 32MiB if our blob size is
        // that large.
        while i < num_points {
            // `read_exact` guards against partial reads, which could otherwise send a
            // truncated point downstream; a clean EOF simply ends the loop.
            match reader.read_exact(&mut buffer) {
                Ok(()) => {},
                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
                Err(e) => return Err(e),
            }
            sender
                .send((buffer.clone(), position, is_native))
                .unwrap();
            position += point_size;
            i += 1;
        }
        Ok(())
    }

    /// Reads G1 points in parallel by creating one reader thread, which reads
    /// bytes from the file and fans them out to worker threads (one per CPU)
    /// that parse the bytes into G1Affine points. The workers then fan the
    /// parsed points back in to the main thread, which sorts them by their
    /// original position in the file to restore order. Not used anywhere, but
    /// kept as a reference.
    ///
    /// # Arguments
    /// * `file_path` - The path to the file containing the G1 points
    /// * `points_to_load` - The number of points to load from the file
    /// * `is_native` - Whether the points are in native Arkworks format or not
    ///
    /// # Returns
    /// * `Ok(Vec<G1Affine>)` - The G1 points read from the file
    /// * `Err(KzgError)` - An error occurred while reading the file
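    ///
    /// # Example
    ///
    /// A minimal sketch; the file name and point count are placeholders that
    /// must match a real points file on disk (hence `no_run`):
    ///
    /// ```no_run
    /// use rust_kzg_bn254_prover::srs::SRS;
    ///
    /// let points = SRS::parallel_read_g1_points_native("g1.point".to_string(), 1024, true)
    ///     .expect("failed to read G1 points");
    /// assert_eq!(points.len(), 1024);
    /// ```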
    pub fn parallel_read_g1_points_native(
        file_path: String,
        points_to_load: u32,
        is_native: bool,
    ) -> Result<Vec<G1Affine>, KzgError> {
        // Channel contains (bytes, position, is_native) tuples. The position is used to
        // reorder the points after processing them.
        let (sender, receiver) = bounded::<(Vec<u8>, usize, bool)>(1000);

        // Spawning the reader thread
        let reader_thread = std::thread::spawn(
            move || -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                Self::read_file_chunks(&file_path, sender, 32, points_to_load, is_native)
                    .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })
            },
        );

        let num_workers = num_cpus::get();

        let workers: Vec<_> = (0..num_workers)
            .map(|_| {
                let receiver = receiver.clone();
                std::thread::spawn(move || Self::process_chunks::<G1Affine>(receiver))
            })
            .collect();

        // Wait for the reader thread to finish
        match reader_thread.join() {
            Ok(result) => match result {
                Ok(_) => {},
                Err(e) => return Err(KzgError::GenericError(e.to_string())),
            },
            Err(_) => return Err(KzgError::GenericError("Thread panicked".to_string())),
        }

        // Collect and sort results
        let mut all_points = Vec::new();
        for worker in workers {
            let points = worker.join().expect("Worker thread panicked");
            all_points.extend(points);
        }

        // Sort by original position to maintain order
        all_points.sort_by_key(|&(_, position)| position);

        Ok(all_points.into_iter().map(|(point, _)| point).collect())
    }
}