// sp1_gpu_tracegen/recursion/prefix_sum_checks.rs
use slop_air::BaseAir;
use slop_alloc::mem::CopyError;
use slop_alloc::Buffer;
use slop_tensor::Tensor;
use sp1_gpu_cudart::TracegenRecursionPrefixSumChecksKernel;
use sp1_gpu_cudart::{args, DeviceMle, TaskScope};
use sp1_hypercube::air::MachineAir;
use sp1_recursion_machine::chips::prefix_sum_checks::PrefixSumChecksChip;

use crate::{CudaTracegenAir, F};

12impl CudaTracegenAir<F> for PrefixSumChecksChip {
13 fn supports_device_main_tracegen(&self) -> bool {
14 true
15 }
16
17 async fn generate_trace_device(
18 &self,
19 input: &Self::Record,
20 _: &mut Self::Record,
21 scope: &TaskScope,
22 ) -> Result<DeviceMle<F>, CopyError> {
23 let events = &input.prefix_sum_checks_events;
24
25 let events_device = {
26 let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
27 buf.extend_from_host_slice(events)?;
28 buf
29 };
30
31 let width = <Self as BaseAir<F>>::width(self);
32
33 let height = <Self as MachineAir<F>>::num_rows(self, input)
34 .expect("num_rows(...) should be Some(_)");
35
36 let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
37
38 unsafe {
39 const BLOCK_DIM: usize = 64;
40 let grid_dim = height.div_ceil(BLOCK_DIM);
41 let args = args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len());
47 scope
48 .launch_kernel(
49 TaskScope::tracegen_recursion_prefix_sum_checks_kernel(),
50 grid_dim,
51 BLOCK_DIM,
52 &args,
53 0,
54 )
55 .unwrap();
56 }
57
58 Ok(DeviceMle::from(trace))
59 }
60}
61
#[cfg(test)]
mod tests {
    use rand::Rng;

    use sp1_recursion_executor::{Block, ExecutionRecord, PrefixSumChecksEvent};
    use sp1_recursion_machine::chips::prefix_sum_checks::PrefixSumChecksChip;

    /// Drives the shared `test_main_tracegen` harness for `PrefixSumChecksChip`.
    ///
    /// The first closure builds one event with every field drawn from the
    /// supplied RNG; the second wraps the generated events in an otherwise
    /// default `ExecutionRecord`.
    #[tokio::test]
    async fn test_prefix_sum_checks_generate_main_trace() {
        let outcome = sp1_gpu_cudart::spawn(move |scope| {
            crate::tests::test_main_tracegen(
                PrefixSumChecksChip,
                // Field order is kept as-is: it fixes the order of the RNG draws.
                |r| PrefixSumChecksEvent {
                    x1: r.gen(),
                    x2: Block(r.gen()),
                    zero: r.gen(),
                    one: Block(r.gen()),
                    acc: Block(r.gen()),
                    new_acc: Block(r.gen()),
                    field_acc: r.gen(),
                    new_field_acc: r.gen(),
                },
                |events| ExecutionRecord {
                    prefix_sum_checks_events: events,
                    ..Default::default()
                },
                scope,
            )
        })
        .await;

        outcome.unwrap();
    }
}