Skip to main content

sp1_gpu_tracegen/recursion/
prefix_sum_checks.rs

1use slop_air::BaseAir;
2use slop_alloc::mem::CopyError;
3use slop_alloc::Buffer;
4use slop_tensor::Tensor;
5use sp1_gpu_cudart::TracegenRecursionPrefixSumChecksKernel;
6use sp1_gpu_cudart::{args, DeviceMle, TaskScope};
7use sp1_hypercube::air::MachineAir;
8use sp1_recursion_machine::chips::prefix_sum_checks::PrefixSumChecksChip;
9
10use crate::{CudaTracegenAir, F};
11
12impl CudaTracegenAir<F> for PrefixSumChecksChip {
13    fn supports_device_main_tracegen(&self) -> bool {
14        true
15    }
16
17    async fn generate_trace_device(
18        &self,
19        input: &Self::Record,
20        _: &mut Self::Record,
21        scope: &TaskScope,
22    ) -> Result<DeviceMle<F>, CopyError> {
23        let events = &input.prefix_sum_checks_events;
24
25        let events_device = {
26            let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
27            buf.extend_from_host_slice(events)?;
28            buf
29        };
30
31        let width = <Self as BaseAir<F>>::width(self);
32
33        let height = <Self as MachineAir<F>>::num_rows(self, input)
34            .expect("num_rows(...) should be Some(_)");
35
36        let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
37
38        unsafe {
39            const BLOCK_DIM: usize = 64;
40            let grid_dim = height.div_ceil(BLOCK_DIM);
41            // args:
42            // T *trace,
43            // uintptr_t trace_height,
44            // const sp1_gpu_sys::PrefixSumChecksEvent<T> *events,
45            // uintptr_t nb_events
46            let args = args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len());
47            scope
48                .launch_kernel(
49                    TaskScope::tracegen_recursion_prefix_sum_checks_kernel(),
50                    grid_dim,
51                    BLOCK_DIM,
52                    &args,
53                    0,
54                )
55                .unwrap();
56        }
57
58        Ok(DeviceMle::from(trace))
59    }
60}
61
#[cfg(test)]
mod tests {
    // `Rng` must be in scope for the `.gen()` calls on the generator below.
    use rand::Rng;
    use sp1_recursion_executor::{Block, ExecutionRecord, PrefixSumChecksEvent};
    use sp1_recursion_machine::chips::prefix_sum_checks::PrefixSumChecksChip;

    /// Runs the shared main-tracegen test helper for `PrefixSumChecksChip`
    /// using randomly generated events.
    #[tokio::test]
    async fn test_prefix_sum_checks_generate_main_trace() {
        let outcome = sp1_gpu_cudart::spawn(move |scope| {
            crate::tests::test_main_tracegen(
                PrefixSumChecksChip,
                // Builds one event with every field drawn from the generator.
                |r| PrefixSumChecksEvent {
                    x1: r.gen(),
                    x2: Block(r.gen()),
                    zero: r.gen(),
                    one: Block(r.gen()),
                    acc: Block(r.gen()),
                    new_acc: Block(r.gen()),
                    field_acc: r.gen(),
                    new_field_acc: r.gen(),
                },
                // Wraps the generated events in an otherwise-default record.
                |events| ExecutionRecord {
                    prefix_sum_checks_events: events,
                    ..Default::default()
                },
                scope,
            )
        })
        .await;

        outcome.unwrap();
    }
}