Skip to main content

sp1_gpu_tracegen/riscv/
global.rs

1use core::mem;
2use std::sync::Arc;
3
4use futures::future::join_all;
5use slop_algebra::PrimeField32;
6use slop_alloc::mem::CopyError;
7use slop_alloc::Buffer;
8use slop_tensor::Tensor;
9use sp1_core_machine::global::{GlobalChip, GlobalCols, GLOBAL_INITIAL_DIGEST_POS};
10use sp1_gpu_cudart::sys::runtime::Dim3;
11use sp1_gpu_cudart::transpose::DeviceTransposeKernel;
12use sp1_gpu_cudart::{args, DeviceMle, ScanKernel, TaskScope};
13use sp1_hypercube::air::MachineAir;
14use sp1_hypercube::septic_curve::SepticCurve;
15use sp1_hypercube::septic_digest::SepticDigest;
16use sp1_hypercube::septic_extension::{SepticBlock, SepticExtension};
17
18use sp1_gpu_cudart::TracegenRiscvGlobalKernel;
19
20use crate::{CudaTracegenAir, F};
21
22impl CudaTracegenAir<F> for GlobalChip {
23    fn supports_device_main_tracegen(&self) -> bool {
24        true
25    }
26
27    async fn generate_trace_device(
28        &self,
29        input: &Self::Record,
30        output: &mut Self::Record,
31        scope: &TaskScope,
32    ) -> Result<DeviceMle<F>, CopyError> {
33        let events = &input.global_interaction_events;
34        let events_len = events.len();
35
36        let events_device = {
37            let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
38            buf.extend_from_host_slice(events)?;
39            buf
40        };
41
42        const NUM_GLOBAL_COLS: usize = size_of::<GlobalCols<u8>>();
43
44        let height = <Self as MachineAir<F>>::num_rows(self, input)
45            .expect("num_rows(...) should be Some(_)");
46
47        let mut trace = Tensor::<F, TaskScope>::zeros_in([NUM_GLOBAL_COLS, height], scope.clone());
48
49        // "Round 1": call the decompress kernel.
50        unsafe {
51            const BLOCK_DIM: usize = 64;
52            let grid_dim = height.div_ceil(BLOCK_DIM);
53            // args:
54            // kb31_t *trace,
55            // uintptr_t trace_height,
56            // const sp1_gpu_sys::GlobalInteractionEvent *events,
57            // uintptr_t nb_events
58            let tracegen_riscv_global_args =
59                args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len());
60            scope
61                .launch_kernel(
62                    TaskScope::tracegen_riscv_global_decompress_kernel(),
63                    grid_dim,
64                    BLOCK_DIM,
65                    &tracegen_riscv_global_args,
66                    0,
67                )
68                .unwrap();
69        }
70
71        // "Round 2": do some munging and then call the scan kernel.
72
73        // The curve is over a degree 7 extension of the base field F.
74        const CURVE_FIELD_EXT_DEGREE: usize = 7;
75        assert_eq!(CURVE_FIELD_EXT_DEGREE * mem::size_of::<F>(), mem::size_of::<SepticBlock<F>>());
76        // A point on the curve is described by two coordinates.
77        const CURVE_POINT_WIDTH: usize = 2 * CURVE_FIELD_EXT_DEGREE;
78        assert_eq!(mem::size_of::<[SepticBlock<F>; 2]>(), CURVE_POINT_WIDTH * mem::size_of::<F>());
79        assert_eq!(mem::size_of::<SepticCurve<F>>(), CURVE_POINT_WIDTH * mem::size_of::<F>());
80        // Output of the parallel prefix sum (scan).
81        let mut cumulative_sums =
82            Buffer::<SepticCurve<F>, _>::with_capacity_in(height, scope.clone());
83        // Destination of the transpose.
84        let mut accumulation_initial_digest_row_major =
85            Buffer::<SepticCurve<F>, _>::with_capacity_in(height, scope.clone());
86
87        // Transpose the event.accumulation.initial_digest columns into row-major form.
88        {
89            let accumulation_initial_digest_col_major = &trace.as_buffer()[(height
90                * GLOBAL_INITIAL_DIGEST_POS)
91                ..(height * (GLOBAL_INITIAL_DIGEST_POS + CURVE_POINT_WIDTH))];
92            // Call the transpose kernel manually.
93            // Existing APIs don't support "tensor slices" so we have to do this.
94            let src_sizes = [CURVE_POINT_WIDTH, height];
95            let src_ptr = accumulation_initial_digest_col_major.as_ptr();
96            assert_eq!(
97                src_sizes.into_iter().product::<usize>(),
98                accumulation_initial_digest_col_major.len()
99            );
100            let dst_sizes = [height, CURVE_POINT_WIDTH];
101            let dst_mut_ptr = accumulation_initial_digest_row_major.as_mut_ptr();
102            let num_dims = src_sizes.len();
103
104            let dim_x = src_sizes[num_dims - 2];
105            let dim_y = src_sizes[num_dims - 1];
106            let dim_z: usize = src_sizes.iter().take(num_dims - 2).product();
107            assert_eq!(dim_x, dst_sizes[num_dims - 1]);
108            assert_eq!(dim_y, dst_sizes[num_dims - 2]);
109
110            let block_dim: Dim3 = (32u32, 32u32, 1u32).into();
111            let grid_dim: Dim3 = (
112                dim_x.div_ceil(block_dim.x as usize),
113                dim_y.div_ceil(block_dim.y as usize),
114                dim_z.div_ceil(block_dim.z as usize),
115            )
116                .into();
117            let args = args!(src_ptr, dst_mut_ptr, dim_x, dim_y, dim_z);
118            unsafe {
119                scope
120                    .launch_kernel(
121                        <TaskScope as DeviceTransposeKernel<F>>::transpose_kernel(),
122                        grid_dim,
123                        block_dim,
124                        &args,
125                        0,
126                    )
127                    .unwrap();
128            }
129        }
130
131        // Call the scan kernel.
132        // TODO: make a nice scan API with a trait.
133        {
134            const SCAN_KERNEL_LARGE_SECTION_SIZE: usize = 512;
135            let d_out = cumulative_sums.as_mut_ptr();
136            let d_in = accumulation_initial_digest_row_major.as_ptr();
137            let n = height;
138            if (2 * n) <= SCAN_KERNEL_LARGE_SECTION_SIZE {
139                let args = args!(d_out, d_in, n);
140                unsafe {
141                    scope
142                        .launch_kernel(
143                            <TaskScope as ScanKernel<F>>::single_block_scan_kernel_large_bb31_septic_curve(
144                            ),
145                            1,
146                            n,
147                            &args,
148                            0,
149                        )
150                        .unwrap()
151                };
152            } else {
153                let block_dim = SCAN_KERNEL_LARGE_SECTION_SIZE / 2;
154                let num_blocks = n.div_ceil(block_dim);
155                // Create `scan_values` as an array consisting of a single zero cell followed by
156                // `num_blocks` uninitialized cells.
157                let mut scan_values =
158                    Buffer::<SepticCurve<F>, _>::with_capacity_in(num_blocks + 1, scope.clone());
159                scan_values.write_bytes(0, mem::size_of::<SepticCurve<F>>()).unwrap();
160                // Create `block_counter` as a an array consisting of a single zero cell.
161                let mut block_counter = Buffer::<u32, _>::with_capacity_in(1, scope.clone());
162                block_counter.write_bytes(0, mem::size_of::<u32>()).unwrap();
163                // Create `flags` as an array consisting of a single one cell followed by
164                // `num_blocks` zero cells.
165                let mut flags = Buffer::<u32, _>::with_capacity_in(num_blocks + 1, scope.clone());
166                flags.write_bytes(1, size_of::<u32>()).unwrap();
167                flags.write_bytes(0, num_blocks * size_of::<u32>()).unwrap();
168                debug_assert_eq!(flags.len(), num_blocks + 1);
169                let args = args!(
170                    d_out,
171                    d_in,
172                    n,
173                    scan_values.as_mut_ptr(),
174                    block_counter.as_mut_ptr(),
175                    flags.as_mut_ptr()
176                );
177                unsafe {
178                    scope
179                        .launch_kernel(
180                            <TaskScope as ScanKernel<F>>::scan_kernel_large_bb31_septic_curve(),
181                            num_blocks,
182                            block_dim,
183                            &args,
184                            0,
185                        )
186                        .unwrap()
187                };
188            }
189        }
190        // This transposed version was only needed for the scan operation.
191        drop(accumulation_initial_digest_row_major);
192
193        // "Round 3": call the finalize kernel.
194        unsafe {
195            const BLOCK_DIM: usize = 64;
196            let grid_dim = height.div_ceil(BLOCK_DIM);
197            // args:
198            // kb31_t *trace,
199            // uintptr_t trace_height,
200            // const bb31_septic_curve_t *cumulative_sums,
201            // uintptr_t nb_events
202            let tracegen_riscv_global_args =
203                args!(trace.as_mut_ptr(), height, cumulative_sums.as_ptr(), events.len());
204            scope
205                .launch_kernel(
206                    TaskScope::tracegen_riscv_global_finalize_kernel(),
207                    grid_dim,
208                    BLOCK_DIM,
209                    &tracegen_riscv_global_args,
210                    0,
211                )
212                .unwrap();
213        }
214
215        // Modify the records passed as arguments.
216        output.global_interaction_event_count =
217            events.len().try_into().expect("number of Global events should fit in a u32");
218        // Wrap the trace so we can use it in concurrent tasks.
219        let trace = Arc::new(trace);
220
221        let global_sum = if height == 0 {
222            SepticDigest(SepticCurve::convert(SepticDigest::<F>::zero().0, |x| {
223                F::as_canonical_u32(&x)
224            }))
225        } else {
226            // // Copy the last digest of the last `CURVE_POINT_WIDTH` columns, which are the global digest columns.
227            const CUMULATIVE_SUM_COL_START: usize =
228                mem::offset_of!(GlobalCols<u8>, accumulation.cumulative_sum);
229            assert_eq!(CUMULATIVE_SUM_COL_START + CURVE_POINT_WIDTH, NUM_GLOBAL_COLS);
230            let copied_sum = join_all((CUMULATIVE_SUM_COL_START..NUM_GLOBAL_COLS).map(|i| {
231                let trace = Arc::clone(&trace);
232                let scope = scope.clone();
233                tokio::task::spawn_blocking(move || {
234                    // No need to synchronize, since the host memory is not pinned.
235                    trace[[i, events_len - 1]].copy_into_host(&scope)
236                })
237            }))
238            .await;
239            SepticDigest(SepticCurve {
240                x: SepticExtension(core::array::from_fn(|i| {
241                    copied_sum[i].as_ref().unwrap().as_canonical_u32()
242                })),
243                y: SepticExtension(core::array::from_fn(|i| {
244                    copied_sum[CURVE_FIELD_EXT_DEGREE + i].as_ref().unwrap().as_canonical_u32()
245                })),
246            })
247        };
248
249        *input.global_cumulative_sum.lock().unwrap() = global_sum;
250
251        let trace =
252            Arc::into_inner(trace).expect("trace Arc should have exactly one strong reference");
253
254        Ok(DeviceMle::from(trace))
255    }
256}
257
#[cfg(test)]
mod tests {
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use slop_algebra::PrimeField32;
    use slop_tensor::Tensor;
    use sp1_core_executor::{events::GlobalInteractionEvent, ExecutionRecord};
    use sp1_core_machine::global::GlobalChip;
    use sp1_gpu_cudart::TaskScope;
    use sp1_hypercube::air::MachineAir;
    use sp1_hypercube::MachineRecord;

    use crate::{CudaTracegenAir, F};

    /// How many random global interaction events the comparison test generates.
    const NUM_EVENTS: usize = 1000;

    /// Runs the CPU-vs-GPU `GlobalChip` comparison inside a device task scope.
    #[tokio::test]
    async fn test_global_generate_trace() {
        sp1_gpu_cudart::spawn(inner_test_global_generate_trace).await.unwrap();
    }

    /// Generates `NUM_EVENTS` random events and checks that the device trace, the recorded
    /// global cumulative sum, and the public values all match the CPU reference.
    async fn inner_test_global_generate_trace(scope: TaskScope) {
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let events: Vec<_> = (0..NUM_EVENTS)
            .map(|_| GlobalInteractionEvent {
                // Canonical field elements and a `kind` below 2^6 appear to be the bounds that
                // keep a `GlobalInteractionEvent` valid.
                message: core::array::from_fn(|_| rng.gen::<F>().as_canonical_u32()),
                is_receive: rng.gen(),
                kind: rng.gen_range(0..(1 << 6)),
            })
            .collect();

        // Two identical records: one for the CPU reference, one for the device path.
        let new_record =
            || ExecutionRecord { global_interaction_events: events.clone(), ..Default::default() };
        let cpu_record = new_record();
        let gpu_record = new_record();

        let chip = GlobalChip;

        let expected_trace =
            Tensor::<F>::from(chip.generate_trace(&cpu_record, &mut ExecutionRecord::default()));

        let device_trace = chip
            .generate_trace_device(&gpu_record, &mut ExecutionRecord::default(), &scope)
            .await
            .expect("should copy events to device successfully");
        let actual_trace =
            device_trace.to_host().expect("should copy trace to host successfully").into_guts();

        crate::tests::test_traces_eq(&expected_trace, &actual_trace, &events);

        assert_eq!(
            *gpu_record.global_cumulative_sum.lock().unwrap(),
            *cpu_record.global_cumulative_sum.lock().unwrap()
        );

        assert_eq!(gpu_record.public_values::<F>(), cpu_record.public_values::<F>());
    }
}