//! GPU-accelerated trace generation (`sp1_gpu_tracegen` crate, `lib.rs`).
1mod recursion;
2mod riscv;
3
4use core::future::{ready, Future};
5use core::pin::pin;
6use std::collections::BTreeSet;
7use std::{collections::BTreeMap, sync::Arc};
8
9use futures::stream::FuturesUnordered;
10use futures::{join, StreamExt};
11use rayon::prelude::*;
12use slop_air::BaseAir;
13use slop_algebra::Field;
14use slop_alloc::mem::CopyError;
15use slop_multilinear::{Mle, PaddedMle};
16use sp1_gpu_cudart::{DeviceMle, DeviceTransposeKernel, TaskScope};
17use sp1_hypercube::prover::{MainTraceData, PreprocessedTraceData, ProverSemaphore, TraceData};
18use sp1_hypercube::{
19    air::MachineAir,
20    prover::{TraceGenerator, Traces},
21    Machine,
22};
23
24use sp1_hypercube::{Chip, MachineRecord};
25use sp1_primitives::SP1Field;
26use tracing::{debug_span, instrument, Instrument};
27
/// The base field used throughout this crate.
///
/// We currently only link to KoalaBear-specialized trace generation FFI.
pub(crate) type F = SP1Field;
30
/// A trace generator that is GPU accelerated.
pub struct CudaTraceGenerator<F: Field, A> {
    /// The machine describing the chips/AIRs we generate traces for.
    machine: Machine<F, A>,
    /// The CUDA task scope used as the allocator for device-resident traces.
    trace_allocator: TaskScope,
}
36
37impl<A: MachineAir<F>> CudaTraceGenerator<F, A> {
38    /// Create a new trace generator.
39    #[must_use]
40    pub fn new_in(machine: Machine<F, A>, trace_allocator: TaskScope) -> Self {
41        Self { machine, trace_allocator }
42    }
43}
44
/// Output of the host (CPU) phase of trace generation.
///
/// Produced by `host_preprocessed_tracegen`/`host_main_tracegen` and consumed
/// by the corresponding `device_*_tracegen` method.
struct HostPhaseTracegen<F, A> {
    /// AIRs whose traces will be generated on the device in a later phase.
    pub device_airs: Vec<Arc<A>>,
    /// Receiving end of an unbounded channel of `(air_name, trace)` pairs
    /// generated on the CPU by a background rayon task.
    pub host_traces: futures::channel::mpsc::UnboundedReceiver<(String, Mle<F>)>,
}
50
/// Shape information derived on the host for main trace generation.
///
/// Produced by `host_main_tracegen`.
struct HostPhaseShapePadding<F: Field, A> {
    /// The smallest chip cluster of the machine containing the set of chips
    /// that have events in the record.
    pub shard_chips: BTreeSet<Chip<F, A>>,
    /// All-zeros (virtual) padded traces for the cluster chips that have no
    /// events and therefore get no generated trace.
    pub padded_traces: BTreeMap<String, PaddedMle<F, TaskScope>>,
}
56
impl<F, A> CudaTraceGenerator<F, A>
where
    F: Field,
    A: CudaTracegenAir<F>,
    TaskScope: DeviceTransposeKernel<F>,
{
    /// Begin the host (CPU) phase of preprocessed trace generation.
    ///
    /// Partitions the machine's AIRs by whether they support device-side
    /// preprocessed tracegen. For those that do not, a rayon task is spawned
    /// that generates their preprocessed traces on the CPU and sends them over
    /// an unbounded channel. The returned [`HostPhaseTracegen`] carries that
    /// channel's receiver together with the device-capable AIRs, to be
    /// consumed by `device_preprocessed_tracegen`.
    #[instrument(skip_all, level = "debug")]
    fn host_preprocessed_tracegen(
        &self,
        program: Arc<<A as MachineAir<F>>::Program>,
    ) -> HostPhaseTracegen<F, A> {
        // Split chips based on where we will generate their traces.
        let (device_airs, host_airs): (Vec<_>, Vec<_>) = self
            .machine
            .chips()
            .iter()
            .map(|chip| chip.air.clone())
            .partition(|air| air.supports_device_preprocessed_tracegen());

        // Spawn a rayon task to generate the traces on the CPU.
        // `host_traces` is a futures Stream that will immediately begin buffering traces.
        let (host_traces_tx, host_traces) = futures::channel::mpsc::unbounded();
        slop_futures::rayon::spawn(move || {
            host_airs.into_par_iter().for_each_with(host_traces_tx, |tx, air| {
                // Not every AIR has a preprocessed trace; skip those that return `None`.
                if let Some(trace) = air.generate_preprocessed_trace(&program) {
                    // The channel is unbounded, so this only errors if the receiver was dropped.
                    tx.unbounded_send((air.name().to_string(), Mle::from(trace))).unwrap();
                }
            });
            // Make this explicit.
            // If we are the last users of the program, this will expensively drop it.
            drop(program);
        });
        HostPhaseTracegen { device_airs, host_traces }
    }

    /// Finish preprocessed trace generation on the device.
    ///
    /// Host-generated traces from `host_phase_tracegen` are copied to the
    /// device as they arrive, while the device-capable AIRs generate their
    /// preprocessed traces directly on the device. Every resulting trace is
    /// zero-padded to `2^max_log_row_count` rows and collected by AIR name.
    #[instrument(skip_all, level = "debug")]
    async fn device_preprocessed_tracegen(
        &self,
        program: Arc<<A as MachineAir<F>>::Program>,
        max_log_row_count: usize,
        host_phase_tracegen: HostPhaseTracegen<F, A>,
    ) -> Traces<F, TaskScope> {
        let HostPhaseTracegen { device_airs, host_traces } = host_phase_tracegen;

        // Stream that, when polled, copies the host traces to the device.
        let copied_host_traces = pin!(host_traces.then(|(name, trace)| async move {
            (name, DeviceMle::from_host(&trace, &self.trace_allocator).unwrap().into())
        }));
        // Stream that, when polled, copies events to the device and generates traces.
        let device_traces = device_airs
            .into_iter()
            .map(|air| {
                // We want to borrow the program and move the air.
                let program = program.as_ref();
                async move {
                    let maybe_trace = air
                        .generate_preprocessed_trace_device(program, &self.trace_allocator)
                        .await
                        .unwrap();
                    (air, maybe_trace)
                }
            })
            .collect::<FuturesUnordered<_>>()
            // Drop AIRs that have no preprocessed trace.
            .filter_map(|(air, maybe_trace)| {
                ready(maybe_trace.map(|trace| (air.name().to_string(), trace.into())))
            });

        // Interleave both streams and pad every trace to the uniform height.
        let named_traces = futures::stream_select!(copied_host_traces, device_traces)
            .map(|(name, trace)| {
                (name, PaddedMle::padded_with_zeros(Arc::new(trace), max_log_row_count as u32))
            })
            .collect::<BTreeMap<_, _>>()
            .await;

        // If we're the last users of the program, expensively drop it in a separate task.
        // TODO: in general, figure out the best way to drop expensive-to-drop things.
        rayon::spawn(move || drop(program));

        Traces { named_traces }
    }

    /// Begin the host (CPU) phase of main trace generation for a shard.
    ///
    /// Determines the chips included in `record`, spawns a rayon task that
    /// generates the host-side traces on the CPU, and computes the shard's
    /// chip cluster together with all-zeros padded traces for the cluster
    /// chips that have no events. The [`HostPhaseTracegen`] half is consumed
    /// by `device_main_tracegen`.
    #[instrument(skip_all, level = "debug")]
    fn host_main_tracegen(
        &self,
        record: Arc<<A as MachineAir<F>>::Record>,
        max_log_row_count: usize,
    ) -> (HostPhaseTracegen<F, A>, HostPhaseShapePadding<F, A>)
    where
        F: Field,
        A: CudaTracegenAir<F>,
    {
        // Set of chips we need to generate traces for.
        let chip_set = self
            .machine
            .chips()
            .iter()
            .filter(|chip| chip.included(&record))
            .cloned()
            .collect::<BTreeSet<_>>();

        // Split chips based on where we will generate their traces.
        let (device_airs, host_airs): (Vec<_>, Vec<_>) = chip_set
            .iter()
            .map(|chip| chip.air.clone())
            .partition(|c| c.supports_device_main_tracegen());

        // Spawn a rayon task to generate the traces on the CPU.
        // `host_traces` is a futures Stream that will immediately begin buffering traces.
        let (host_traces_tx, host_traces) = futures::channel::mpsc::unbounded();
        slop_futures::rayon::spawn(move || {
            host_airs.into_par_iter().for_each_with(host_traces_tx, |tx, air| {
                let trace = Mle::from(air.generate_trace(&record, &mut A::Record::default()));
                // Since it's unbounded, it will only error if the receiver is disconnected.
                tx.unbounded_send((air.name().to_string(), trace)).unwrap();
            });
            // Make this explicit.
            // If we are the last users of the record, this will expensively drop it.
            drop(record);
        });

        // Get the smallest cluster containing our tracegen chip set.
        let shard_chips = self.machine.smallest_cluster(&chip_set).unwrap().clone();
        // For every cluster AIR we are NOT generating a trace for, make an
        // all-zeros (virtual) padded trace of the chip's width.
        let padded_traces = shard_chips
            .iter()
            .filter(|chip| !chip_set.contains(chip))
            .map(|chip| {
                let num_polynomials = chip.width();
                (
                    chip.name().to_string(),
                    PaddedMle::zeros_in(
                        num_polynomials,
                        max_log_row_count as u32,
                        self.trace_allocator.clone(),
                    ),
                )
            })
            .collect::<BTreeMap<_, _>>();

        (
            HostPhaseTracegen { device_airs, host_traces },
            HostPhaseShapePadding { shard_chips, padded_traces },
        )
    }

    /// Finish main trace generation on the device.
    ///
    /// Copies host-generated traces to the device as they arrive and, in
    /// parallel, generates the device-capable traces directly on the device.
    /// All traces are zero-padded to `2^max_log_row_count` rows and merged
    /// into `padded_traces`. Returns the complete trace set together with the
    /// record's public values (read only after all tracegen finished, so the
    /// global cumulative sum is final).
    #[instrument(skip_all, level = "debug")]
    async fn device_main_tracegen(
        &self,
        max_log_row_count: usize,
        record: Arc<<A as MachineAir<F>>::Record>,
        host_phase_tracegen: HostPhaseTracegen<F, A>,
        padded_traces: BTreeMap<String, PaddedMle<F, TaskScope>>,
    ) -> (Traces<F, TaskScope>, Vec<F>)
    where
        F: Field,
        A: CudaTracegenAir<F>,
    {
        let HostPhaseTracegen { device_airs, host_traces } = host_phase_tracegen;

        // Stream that, when polled, copies the host traces to the device.
        let copied_host_traces = pin!(host_traces.then(|(name, trace)| async move {
            (name, DeviceMle::from_host(&trace, &self.trace_allocator).unwrap().into())
        }));
        // Stream that, when polled, copies events to the device and generates traces.
        let device_traces = device_airs
            .into_iter()
            .map(|air| {
                // We want to borrow the record and move the chip.
                let record = record.as_ref();
                async move {
                    let trace = air
                        .generate_trace_device(
                            record,
                            &mut A::Record::default(),
                            &self.trace_allocator,
                        )
                        .await
                        .unwrap();
                    (air.name().to_string(), trace.into())
                }
            })
            .collect::<FuturesUnordered<_>>();

        // Start from the virtual padding traces and overlay generated ones.
        let mut all_traces = padded_traces;

        // Combine the host and device trace streams and insert them into `all_traces`.
        futures::stream_select!(copied_host_traces, device_traces)
            .for_each(|(name, trace)| {
                all_traces.insert(
                    name,
                    PaddedMle::padded_with_zeros(Arc::new(trace), max_log_row_count as u32),
                );
                ready(())
            })
            .await;

        // All traces are now generated, so the public values are ready.
        // That is, this value will have the correct global cumulative sum.
        let public_values = record.public_values::<F>();

        // If we're the last users of the record, expensively drop it in a separate task.
        // TODO: in general, figure out the best way to drop expensive-to-drop things.
        rayon::spawn(move || drop(record));

        let traces = Traces { named_traces: all_traces };
        (traces, public_values)
    }
}
267
impl<F, A> TraceGenerator<F, A, TaskScope> for CudaTraceGenerator<F, A>
where
    F: Field,
    A: CudaTracegenAir<F>,
    TaskScope: DeviceTransposeKernel<F>,
{
    /// The machine whose chips this generator produces traces for.
    fn machine(&self) -> &Machine<F, A> {
        &self.machine
    }

    /// The device allocator used for trace memory.
    fn allocator(&self) -> &TaskScope {
        &self.trace_allocator
    }

    /// Generate the preprocessed traces for `program`.
    ///
    /// Host-side tracegen is kicked off immediately (it buffers into a
    /// channel), then a prover permit is awaited before any device work
    /// begins, so GPU memory is only touched once a prover slot is held.
    async fn generate_preprocessed_traces(
        &self,
        program: Arc<<A as MachineAir<F>>::Program>,
        max_log_row_count: usize,
        prover_permits: ProverSemaphore,
    ) -> PreprocessedTraceData<F, TaskScope> {
        let host_phase_tracegen = self.host_preprocessed_tracegen(Arc::clone(&program));

        // Wait for a prover to be available.
        let permit = prover_permits.acquire().instrument(debug_span!("acquire")).await.unwrap();

        // Now that the permit is acquired, we can begin the following two tasks:
        // - Copying host traces to the device.
        // - Generating traces on the device.

        let preprocessed_traces = self
            .device_preprocessed_tracegen(program, max_log_row_count, host_phase_tracegen)
            .await;
        PreprocessedTraceData { preprocessed_traces, permit }
    }

    /// Generate the main traces for a shard `record`.
    ///
    /// Same scheduling as `generate_preprocessed_traces`: CPU tracegen starts
    /// right away, device work waits for a prover permit.
    async fn generate_main_traces(
        &self,
        record: <A as MachineAir<F>>::Record,
        max_log_row_count: usize,
        prover_permits: ProverSemaphore,
    ) -> MainTraceData<F, A, TaskScope> {
        let record = Arc::new(record);

        let (host_phase_tracegen, HostPhaseShapePadding { shard_chips, padded_traces }) =
            self.host_main_tracegen(Arc::clone(&record), max_log_row_count);

        // Wait for a prover to be available.
        let permit = prover_permits.acquire().instrument(debug_span!("acquire")).await.unwrap();

        // Now that the permit is acquired, we can begin the following two tasks:
        // - Copying host traces to the device.
        // - Generating traces on the device.

        let (traces, public_values) = self
            .device_main_tracegen(max_log_row_count, record, host_phase_tracegen, padded_traces)
            .await;

        MainTraceData { traces, public_values, permit, shard_chips }
    }

    /// Generate both preprocessed and main traces in one pass.
    ///
    /// Both host phases start immediately; after the prover permit is
    /// acquired, the two device phases run concurrently via `join!`.
    async fn generate_traces(
        &self,
        program: Arc<<A as MachineAir<F>>::Program>,
        record: <A as MachineAir<F>>::Record,
        max_log_row_count: usize,
        prover_permits: sp1_hypercube::prover::ProverSemaphore,
    ) -> TraceData<F, A, TaskScope> {
        let record = Arc::new(record);

        let prep_host_phase_tracegen = self.host_preprocessed_tracegen(Arc::clone(&program));

        let (main_host_phase_tracegen, HostPhaseShapePadding { shard_chips, padded_traces }) =
            self.host_main_tracegen(Arc::clone(&record), max_log_row_count);

        // Wait for a prover to be available.
        let permit = prover_permits.acquire().instrument(debug_span!("acquire")).await.unwrap();

        // Now that the permit is acquired, we can begin the following two tasks:
        // - Copying host traces to the device.
        // - Generating traces on the device.

        let (preprocessed_traces, (traces, public_values)) = join!(
            self.device_preprocessed_tracegen(program, max_log_row_count, prep_host_phase_tracegen),
            self.device_main_tracegen(
                max_log_row_count,
                record,
                main_host_phase_tracegen,
                padded_traces,
            )
        );

        TraceData {
            preprocessed_traces,
            main_trace_data: MainTraceData { traces, public_values, permit, shard_chips },
        }
    }
}
365
/// An AIR that potentially supports device trace generation over the given field.
pub trait CudaTracegenAir<F: Field>: MachineAir<F> {
    /// Whether this AIR supports preprocessed trace generation on the device.
    ///
    /// Defaults to `false`; implementors opt in by overriding this together
    /// with [`CudaTracegenAir::generate_preprocessed_trace_device`].
    fn supports_device_preprocessed_tracegen(&self) -> bool {
        false
    }

    /// Generate the preprocessed trace on the device.
    ///
    /// Returns `Ok(None)` if the AIR has no preprocessed trace.
    ///
    /// # Panics
    /// Panics if unsupported. See [`CudaTracegenAir::supports_device_preprocessed_tracegen`].
    #[allow(unused_variables)]
    fn generate_preprocessed_trace_device(
        &self,
        program: &Self::Program,
        scope: &TaskScope,
    ) -> impl Future<Output = Result<Option<DeviceMle<F>>, CopyError>> + Send {
        // `unimplemented!()` is evaluated eagerly as `ready`'s argument, so the
        // default panics as soon as the method is called (not on first poll).
        #[allow(unreachable_code)]
        ready(unimplemented!())
    }

    /// Whether this AIR supports main trace generation on the device.
    ///
    /// Defaults to `false`; implementors opt in by overriding this together
    /// with [`CudaTracegenAir::generate_trace_device`].
    fn supports_device_main_tracegen(&self) -> bool {
        false
    }

    /// Generate the main trace on the device.
    ///
    /// # Panics
    /// Panics if unsupported. See [`CudaTracegenAir::supports_device_main_tracegen`].
    #[allow(unused_variables)]
    fn generate_trace_device(
        &self,
        input: &Self::Record,
        output: &mut Self::Record,
        scope: &TaskScope,
    ) -> impl Future<Output = Result<DeviceMle<F>, CopyError>> + Send {
        // Same eager-panic behavior as the preprocessed default above.
        #[allow(unreachable_code)]
        ready(unimplemented!())
    }
}
407
#[cfg(test)]
pub(crate) mod tests {
    use super::{CudaTracegenAir, F};
    use rand::{rngs::StdRng, SeedableRng};
    use slop_tensor::Tensor;
    use sp1_gpu_cudart::TaskScope;
    use sp1_hypercube::air::MachineAir;
    use std::collections::BTreeSet;

    /// Assert that a CPU-generated trace and a GPU-generated trace are equal,
    /// logging a per-row/per-column diff before the final assertion so
    /// mismatches are easy to localize.
    ///
    /// `events` is used for diagnostics only, under the assumption of one
    /// event per row (rows beyond `events.len()` are treated as padding).
    pub(crate) fn test_traces_eq(
        trace: &Tensor<F>,
        gpu_trace: &Tensor<F>,
        events: &[impl core::fmt::Debug],
    ) {
        assert_eq!(gpu_trace.dimensions, trace.dimensions);

        tracing::info!("{:?}", trace.dimensions);

        // Track which columns mismatch in event-backed rows vs padding rows.
        let mut eventful_mismatched_columns = BTreeSet::new();
        let mut padding_mismatched_columns = BTreeSet::new();
        for row_idx in 0..trace.sizes()[0] {
            let mut col_mismatches = BTreeSet::new();
            for col_idx in 0..trace.sizes()[1] {
                let actual = gpu_trace[[row_idx, col_idx]];
                let expected = trace[[row_idx, col_idx]];
                if actual != expected {
                    tracing::error!(
                        "mismatch on row {} col {}. actual: {:?} expected: {:?}",
                        row_idx,
                        col_idx,
                        *actual,
                        *expected
                    );
                    col_mismatches.insert(col_idx);
                }
            }
            let event = events.get(row_idx);
            if col_mismatches.is_empty() {
                tracing::info!(
                    "row {row_idx} matches   . event (assuming events/row = 1): {event:?}"
                );
            } else {
                tracing::error!(
                    "row {row_idx} MISMATCHES. event (assuming events/row = 1): {event:?}"
                );
                tracing::error!("mismatched columns: {col_mismatches:?}");
            }
            // A row with no corresponding event is padding.
            if event.is_some() {
                eventful_mismatched_columns.extend(col_mismatches);
            } else {
                padding_mismatched_columns.extend(col_mismatches);
            }
        }
        tracing::info!("eventful mismatched columns: {eventful_mismatched_columns:?}");
        tracing::info!("padding mismatched columns: {padding_mismatched_columns:?}");

        assert_eq!(gpu_trace, trace);
    }

    /// Shared harness for main-tracegen chip tests: generate 1000 random
    /// events from a fixed seed, build two identical records, run CPU and
    /// GPU tracegen on them, and assert the traces match.
    pub async fn test_main_tracegen<A, Event, Record>(
        chip: A,
        mut make_event: impl FnMut(&mut StdRng) -> Event,
        mut insert_events: impl FnMut(Vec<Event>) -> Record,
        scope: TaskScope,
    ) where
        A: CudaTracegenAir<F> + MachineAir<F, Record = Record>,
        Record: Default,
        Event: Clone + core::fmt::Debug,
    {
        // Fixed seed so failures are reproducible.
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);

        let events =
            core::iter::repeat_with(|| make_event(&mut rng)).take(1000).collect::<Vec<_>>();

        // Two records built from clones of the same events: one per backend.
        let [shard, gpu_shard] = core::array::from_fn(|_| insert_events(events.clone()));

        let trace = Tensor::<F>::from(chip.generate_trace(&shard, &mut Record::default()));

        let gpu_trace = chip
            .generate_trace_device(&gpu_shard, &mut Record::default(), &scope)
            .await
            .expect("should copy events to device successfully")
            .to_host()
            .expect("should copy trace to host successfully")
            .into_guts();

        crate::tests::test_traces_eq(&trace, &gpu_trace, &events);
    }
}
496}