1use rayon::prelude::*;
2use serde::{Deserialize, Serialize};
3use slop_air::BaseAir;
4use std::{
5 collections::{BTreeMap, BTreeSet},
6 future::Future,
7 ops::{Deref, DerefMut},
8 sync::Arc,
9};
10use tracing::Instrument;
11
12use slop_algebra::Field;
13use slop_alloc::{Backend, CanCopyFrom, CpuBackend, GLOBAL_CPU_BACKEND};
14use slop_multilinear::{Mle, PaddedMle};
15use slop_tensor::Tensor;
16use tokio::sync::oneshot;
17
18use crate::{air::MachineAir, Machine, MachineRecord};
19
20use super::{MainTraceData, PreprocessedTraceData, ProverSemaphore, TraceData};
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(bound(serialize = "Tensor<F, B>: Serialize, F: Serialize, B: Serialize, "))]
25#[serde(bound(
26 deserialize = "Tensor<F, B>: Deserialize<'de>, F: Deserialize<'de>, B: Deserialize<'de>, "
27))]
28pub struct Traces<F, B: Backend> {
29 pub named_traces: BTreeMap<String, PaddedMle<F, B>>,
31}
32
33impl<F, B: Backend> IntoIterator for Traces<F, B> {
34 type Item = (String, PaddedMle<F, B>);
35 type IntoIter = <BTreeMap<String, PaddedMle<F, B>> as IntoIterator>::IntoIter;
36
37 fn into_iter(self) -> Self::IntoIter {
38 self.named_traces.into_iter()
39 }
40}
41
42impl<F, B: Backend> Deref for Traces<F, B> {
43 type Target = BTreeMap<String, PaddedMle<F, B>>;
44
45 fn deref(&self) -> &Self::Target {
46 &self.named_traces
47 }
48}
49
50impl<F, B: Backend> DerefMut for Traces<F, B> {
51 fn deref_mut(&mut self) -> &mut Self::Target {
52 &mut self.named_traces
53 }
54}
55
56pub trait TraceGenerator<F: Field, A: MachineAir<F>, B: Backend>: 'static + Send + Sync {
61 fn machine(&self) -> &Machine<F, A>;
63
64 fn allocator(&self) -> &B;
66
67 fn generate_preprocessed_traces(
69 &self,
70 program: Arc<A::Program>,
71 max_log_row_count: usize,
72 setup_permits: ProverSemaphore,
73 ) -> impl Future<Output = PreprocessedTraceData<F, B>> + Send;
74
75 fn generate_main_traces(
77 &self,
78 record: A::Record,
79 max_log_row_count: usize,
80 prover_permits: ProverSemaphore,
81 ) -> impl Future<Output = MainTraceData<F, A, B>> + Send;
82
83 fn generate_traces(
85 &self,
86 program: Arc<A::Program>,
87 record: A::Record,
88 max_log_row_count: usize,
89 prover_permits: ProverSemaphore,
90 ) -> impl Future<Output = TraceData<F, A, B>> + Send;
91}
92
93pub struct DefaultTraceGenerator<F: Field, A, B = CpuBackend> {
95 machine: Machine<F, A>,
96 trace_allocator: B,
97}
98
99impl<F: Field, A: MachineAir<F>, B: Backend> DefaultTraceGenerator<F, A, B> {
100 #[must_use]
102 pub fn new_in(machine: Machine<F, A>, trace_allocator: B) -> Self {
103 Self { machine, trace_allocator }
104 }
105}
106
107impl<F: Field, A: MachineAir<F>> DefaultTraceGenerator<F, A, CpuBackend> {
108 #[must_use]
110 pub fn new(machine: Machine<F, A>) -> Self {
111 Self { machine, trace_allocator: GLOBAL_CPU_BACKEND }
112 }
113}
114
115impl<F: Field, A: MachineAir<F>> TraceGenerator<F, A, CpuBackend>
116 for DefaultTraceGenerator<F, A, CpuBackend>
117{
118 fn machine(&self) -> &Machine<F, A> {
119 &self.machine
120 }
121
122 fn allocator(&self) -> &CpuBackend {
123 &self.trace_allocator
124 }
125
126 async fn generate_main_traces(
127 &self,
128 record: A::Record,
129 max_log_row_count: usize,
130 prover_permits: ProverSemaphore,
131 ) -> MainTraceData<F, A, CpuBackend> {
132 let airs = self.machine.chips().to_vec();
133 let (tx, rx) = oneshot::channel();
134 slop_futures::rayon::spawn(move || {
136 let chips_and_traces = airs
137 .into_par_iter()
138 .filter(|air| air.included(&record))
139 .map(|air| {
140 let trace = air.generate_trace(&record, &mut A::Record::default());
141 let trace = Mle::from(trace);
142 (air, trace)
143 })
144 .collect::<BTreeMap<_, _>>();
145
146 let public_values = record.public_values::<F>();
148
149 tx.send((chips_and_traces, public_values)).ok().unwrap();
150 drop(record);
152 });
153 let (chips_and_traces, public_values) = rx.await.unwrap();
155
156 let chip_set = chips_and_traces.keys().cloned().collect::<BTreeSet<_>>();
157 let shard_chips = self.machine.smallest_cluster(&chip_set).unwrap().clone();
158
159 let permit = prover_permits
161 .acquire()
162 .instrument(tracing::debug_span!("acquire prover"))
163 .await
164 .unwrap();
165 let padded_traces = shard_chips
169 .iter()
170 .filter(|chip| !chips_and_traces.contains_key(chip))
171 .map(|chip| {
172 let num_polynomials = chip.width();
173 (
174 chip.name().to_string(),
175 PaddedMle::zeros(num_polynomials, max_log_row_count as u32),
176 )
177 })
178 .collect::<BTreeMap<_, _>>();
179
180 let real_traces = chips_and_traces
182 .into_iter()
183 .map(|(chip, trace)| {
184 let trace = self.trace_allocator.copy_into(trace).unwrap();
185 let mle = Arc::new(trace);
186 (
187 chip.name().to_string(),
188 PaddedMle::padded_with_zeros(mle, max_log_row_count as u32),
189 )
190 })
191 .collect::<Vec<_>>();
192
193 let mut traces = padded_traces;
194
195 for (name, trace) in real_traces {
196 traces.insert(name, trace);
197 }
198
199 let traces = Traces { named_traces: traces };
200
201 MainTraceData { traces, public_values, shard_chips, permit }
202 }
203
204 async fn generate_preprocessed_traces(
205 &self,
206 program: Arc<A::Program>,
207 max_log_row_count: usize,
208 setup_permits: ProverSemaphore,
209 ) -> PreprocessedTraceData<F, CpuBackend> {
210 let airs = self.machine.chips().iter().map(|chip| chip.air.clone()).collect::<Vec<_>>();
212 let (tx, rx) = oneshot::channel();
213 slop_futures::rayon::spawn(move || {
215 let named_preprocessed_traces = airs
216 .par_iter()
217 .filter_map(|air| {
218 let name = air.name().to_string();
219 let trace = air.generate_preprocessed_trace(&program);
220 trace.map(Mle::from).map(|tr| (name, tr))
221 })
222 .collect::<BTreeMap<_, _>>();
223 tx.send(named_preprocessed_traces).ok().unwrap();
224 });
225
226 let named_preprocessed_traces = rx.await.unwrap();
229
230 let permit = setup_permits
232 .acquire()
233 .instrument(tracing::debug_span!("acquire setup"))
234 .await
235 .unwrap();
236
237 let named_traces = named_preprocessed_traces
239 .into_iter()
240 .map(|(name, trace)| {
241 let trace = self.trace_allocator.copy_into(trace).unwrap();
242 let padded_mle =
243 PaddedMle::padded_with_zeros(Arc::new(trace), max_log_row_count as u32);
244 (name, padded_mle)
245 })
246 .collect::<BTreeMap<_, _>>();
247
248 let traces = Traces { named_traces };
249
250 PreprocessedTraceData { preprocessed_traces: traces, permit }
251 }
252
253 async fn generate_traces(
254 &self,
255 program: Arc<A::Program>,
256 record: A::Record,
257 max_log_row_count: usize,
258 prover_permits: ProverSemaphore,
259 ) -> TraceData<F, A, CpuBackend> {
260 let airs = self.machine.chips().to_vec();
261 let (tx, rx) = oneshot::channel();
262 slop_futures::rayon::spawn(move || {
264 let named_preprocessed_traces = airs
265 .par_iter()
266 .filter_map(|air| {
267 let name = air.name().to_string();
268 let trace = air.generate_preprocessed_trace(&program);
269 trace.map(Mle::from).map(|tr| (name, tr))
270 })
271 .collect::<BTreeMap<_, _>>();
272
273 let chips_and_traces = airs
274 .into_par_iter()
275 .filter(|air| air.included(&record))
276 .map(|air| {
277 let trace = air.generate_trace(&record, &mut A::Record::default());
278 let trace = Mle::from(trace);
279 (air, trace)
280 })
281 .collect::<BTreeMap<_, _>>();
282
283 let public_values = record.public_values::<F>();
285 tx.send((named_preprocessed_traces, chips_and_traces, public_values)).ok().unwrap();
286 drop(record);
288 });
289 let (named_preprocessed_traces, chips_and_traces, public_values) = rx.await.unwrap();
291
292 let chip_set = chips_and_traces.keys().cloned().collect::<BTreeSet<_>>();
293 let shard_chips = self.machine.smallest_cluster(&chip_set).unwrap().clone();
294
295 let padded_traces = shard_chips
297 .iter()
298 .filter(|chip| !chips_and_traces.contains_key(chip))
299 .map(|chip| {
300 let num_polynomials = chip.width();
301 (
302 chip.name().to_string(),
303 PaddedMle::zeros(num_polynomials, max_log_row_count as u32),
304 )
305 })
306 .collect::<BTreeMap<_, _>>();
307
308 let permit = prover_permits
310 .acquire()
311 .instrument(tracing::debug_span!("acquire prover"))
312 .await
313 .unwrap();
314
315 let preprocessed_traces = named_preprocessed_traces
317 .into_iter()
318 .map(|(name, trace)| {
319 let trace = self.trace_allocator.copy_into(trace).unwrap();
320 let padded_mle =
321 PaddedMle::padded_with_zeros(Arc::new(trace), max_log_row_count as u32);
322 (name, padded_mle)
323 })
324 .collect::<BTreeMap<_, _>>();
325
326 let preprocessed_traces = Traces { named_traces: preprocessed_traces };
327
328 let real_traces = chips_and_traces
330 .into_iter()
331 .map(|(chip, trace)| {
332 let trace = self.trace_allocator.copy_into(trace).unwrap();
333 let mle = Arc::new(trace);
334 (
335 chip.name().to_string(),
336 PaddedMle::padded_with_zeros(mle, max_log_row_count as u32),
337 )
338 })
339 .collect::<Vec<_>>();
340
341 let mut traces = padded_traces;
342
343 for (name, trace) in real_traces {
344 traces.insert(name, trace);
345 }
346
347 let traces = Traces { named_traces: traces };
348
349 let main_trace_data = MainTraceData { traces, public_values, shard_chips, permit };
350
351 TraceData { preprocessed_traces, main_trace_data }
352 }
353}