rust_kzg_bn254_prover/srs.rs
1use ark_bn254::G1Affine;
2use crossbeam_channel::{bounded, Receiver};
3use rust_kzg_bn254_primitives::errors::KzgError;
4use rust_kzg_bn254_primitives::traits::ReadPointFromBytes;
5use std::fs::File;
6use std::io::{self, BufReader, Read};
7
8/// Represents the Structured Reference String (SRS) used in KZG commitments.
9#[derive(Debug, PartialEq, Clone)]
10pub struct SRS {
11 // SRS points are stored in monomial form, ready to be used for commitments with polynomials
12 // in coefficient form. To commit against a polynomial in evaluation form, we need to transform
13 // the SRS points to lagrange form using IFFT.
14 pub g1: Vec<G1Affine>,
15 /// The order of the SRS.
16 pub order: u32,
17}
18
19impl SRS {
20 /// Initializes the SRS by loading G1 points from a file.
21 ///
22 /// # Arguments
23 ///
24 /// * `path_to_g1_points` - The file path to load G1 points from.
25 /// * `order` - The total order of the SRS.
26 /// * `points_to_load` - The number of SRS points to load.
27 ///
28 /// # Returns
29 ///
30 /// * `Result<SRS, KzgError>` - The initialized SRS or an error.
31 pub fn new(path_to_g1_points: &str, order: u32, points_to_load: u32) -> Result<Self, KzgError> {
32 if points_to_load > order {
33 return Err(KzgError::GenericError(
34 "Number of points to load exceeds SRS order.".to_string(),
35 ));
36 }
37
38 let g1_points =
39 Self::parallel_read_g1_points(path_to_g1_points.to_owned(), points_to_load, false)
40 .map_err(|e| KzgError::SerializationError(e.to_string()))?;
41
42 Ok(Self {
43 g1: g1_points,
44 order,
45 })
46 }
47
48 pub fn process_chunks<T>(receiver: Receiver<(Vec<u8>, usize, bool)>) -> Vec<(T, usize)>
49 where
50 T: ReadPointFromBytes,
51 {
52 // TODO: should we use rayon to process this in parallel?
53 receiver
54 .iter()
55 .map(|(chunk, position, is_native)| {
56 let point: T = if is_native {
57 T::read_point_from_bytes_native_compressed(&chunk)
58 .expect("Failed to read point from bytes")
59 } else {
60 T::read_point_from_bytes_be(&chunk).expect("Failed to read point from bytes")
61 };
62 (point, position)
63 })
64 .collect()
65 }
66
67 /// Reads G1 points in parallel from a file.
68 ///
69 /// # Arguments
70 ///
71 /// * `file_path` - The path to the file containing G1 points.
72 /// * `points_to_load` - The number of points to load.
73 /// * `is_native` - Whether the points are in native Arkworks format.
74 ///
75 /// # Returns
76 ///
77 /// * `Result<Vec<G1Affine>, KzgError>` - The loaded G1 points or an error.
78 fn parallel_read_g1_points(
79 file_path: String,
80 points_to_load: u32,
81 is_native: bool,
82 ) -> Result<Vec<G1Affine>, KzgError> {
83 let (sender, receiver) = bounded::<(Vec<u8>, usize, bool)>(1000);
84
85 // Spawn the reader thread
86 let reader_handle = std::thread::spawn(
87 move || -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
88 Self::read_file_chunks(&file_path, sender, 32, points_to_load, is_native)
89 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
90 },
91 );
92
93 let num_workers = num_cpus::get();
94
95 let workers: Vec<_> = (0..num_workers)
96 .map(|_| {
97 let receiver = receiver.clone();
98 std::thread::spawn(move || Self::process_chunks::<G1Affine>(receiver))
99 })
100 .collect();
101
102 // Wait for the reader thread to finish
103 match reader_handle.join() {
104 Ok(result) => match result {
105 Ok(_) => {},
106 Err(e) => return Err(KzgError::GenericError(e.to_string())),
107 },
108 Err(_) => {
109 return Err(KzgError::GenericError(
110 "Reader thread panicked.".to_string(),
111 ))
112 },
113 }
114
115 // Collect and sort the results
116 let mut all_points = Vec::new();
117 for worker in workers {
118 let points = worker.join().expect("Worker thread panicked.");
119 all_points.extend(points);
120 }
121
122 // Sort by original position to maintain order
123 all_points.sort_by_key(|&(_, position)| position);
124
125 // Extract the G1Affine points
126 Ok(all_points.iter().map(|(point, _)| *point).collect())
127 }
128
129 /// Reads file chunks and sends them through a channel.
130 ///
131 /// # Arguments
132 ///
133 /// * `file_path` - Path to the file.
134 /// * `sender` - Channel sender to send read chunks.
135 /// * `point_size` - Size of each point in bytes.
136 /// * `num_points` - Number of points to read.
137 /// * `is_native` - Whether the points are in native format.
138 ///
139 /// # Returns
140 ///
141 /// * `io::Result<()>` - Ok if successful, or an I/O error.
142 /// TODO: chunks seems misleading here, since we read one field element at a time.
143 fn read_file_chunks(
144 file_path: &str,
145 sender: crossbeam_channel::Sender<(Vec<u8>, usize, bool)>,
146 point_size: usize,
147 num_points: u32,
148 is_native: bool,
149 ) -> io::Result<()> {
150 let file = File::open(file_path)?;
151 let mut reader = BufReader::new(file);
152 let mut position = 0;
153 let mut buffer = vec![0u8; point_size];
154
155 let mut i = 0;
156 // We are making one syscall per field element, which is super inefficient.
157 // FIXME: Read the entire file (or large segments) into memory and then split it
158 // into field elements. Entire G1 file might be ~8GiB, so might not fit
159 // in RAM. But we can only read the subset of the file that we need.
160 // For eg. for fault proof usage, only need to read 32MiB if our blob size is
161 // that large.
162 while let Ok(bytes_read) = reader.read(&mut buffer) {
163 if bytes_read == 0 {
164 break;
165 }
166 sender
167 .send((buffer[..bytes_read].to_vec(), position, is_native))
168 .unwrap();
169 position += bytes_read;
170 buffer.resize(point_size, 0); // Ensure the buffer is always the correct size
171 i += 1;
172 if num_points == i {
173 break;
174 }
175 }
176 Ok(())
177 }
178
179 /// read G1 points in parallel, by creating one reader thread, which reads
180 /// bytes from the file, and fans them out to worker threads (one per
181 /// cpu) which parse the bytes into G1Affine points. The worker threads
182 /// then fan in the parsed points to the main thread, which sorts them by
183 /// their original position in the file to maintain order. Not used anywhere
184 /// but kept as a reference.
185 ///
186 /// # Arguments
187 /// * `file_path` - The path to the file containing the G1 points
188 /// * `points_to_load` - The number of points to load from the file
189 /// * `is_native` - Whether the points are in native arkworks format or not
190 ///
191 /// # Returns
192 /// * `Ok(Vec<G1Affine>)` - The G1 points read from the file
193 /// * `Err(KzgError)` - An error occurred while reading the file
194 pub fn parallel_read_g1_points_native(
195 file_path: String,
196 points_to_load: u32,
197 is_native: bool,
198 ) -> Result<Vec<G1Affine>, KzgError> {
199 // Channel contains (bytes, position, is_native) tuples. The position is used to
200 // reorder the points after processing them.
201 let (sender, receiver) = bounded::<(Vec<u8>, usize, bool)>(1000);
202
203 // Spawning the reader thread
204 let reader_thread = std::thread::spawn(
205 move || -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
206 Self::read_file_chunks(&file_path, sender, 32, points_to_load, is_native)
207 .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })
208 },
209 );
210
211 let num_workers = num_cpus::get();
212
213 let workers: Vec<_> = (0..num_workers)
214 .map(|_| {
215 let receiver = receiver.clone();
216 std::thread::spawn(move || Self::process_chunks::<G1Affine>(receiver))
217 })
218 .collect();
219
220 // Wait for the reader thread to finish
221 match reader_thread.join() {
222 Ok(result) => match result {
223 Ok(_) => {},
224 Err(e) => return Err(KzgError::GenericError(e.to_string())),
225 },
226 Err(_) => return Err(KzgError::GenericError("Thread panicked".to_string())),
227 }
228
229 // Collect and sort results
230 let mut all_points = Vec::new();
231 for worker in workers {
232 let points = worker.join().expect("Worker thread panicked");
233 all_points.extend(points);
234 }
235
236 // Sort by original position to maintain order
237 all_points.sort_by_key(|&(_, position)| position);
238
239 Ok(all_points.iter().map(|(point, _)| *point).collect())
240 }
241}