Skip to main content

sp1_gpu_utils/
jagged.rs

1use std::iter::once;
2
3use slop_alloc::{Backend, Buffer, CpuBackend, HasBackend};
4use slop_tensor::{Dimensions, Tensor};
5use sp1_gpu_cudart::{args, TaskScope};
6
/// Dense data together with the index buffers that describe its jagged
/// (per-column, variable-height) layout.
///
/// All three index buffers hold `u32` values and live on backend `A`.
#[derive(Clone, Debug)]
#[repr(C)]
pub struct JaggedMle<D: DenseData<A>, A: Backend> {
    /// `col_index[i / 2]` is the column that the `i`'th element of the dense data belongs to.
    pub col_index: Buffer<u32, A>,
    /// `start_indices[i]` is half of the dense index of the first element of the `i`'th column.
    pub start_indices: Buffer<u32, A>,
    /// `column_heights[i]` is half of the height of the `i`'th column. Device-
    /// resident — every fold runs on the GPU, and zerocheck consumes it
    /// directly on device to derive per-chip layouts without a host round-trip.
    pub column_heights: Buffer<u32, A>,
    /// The dense payload that the index buffers above refer to.
    pub dense_data: D,
}
20
/// A tensor-shaped view described by a raw pointer, its dimensions and a
/// backend value.
///
/// The struct stores only a `*const T`, so it does not by itself keep the
/// pointed-to storage alive; whoever creates the view is responsible for
/// ensuring the pointer stays valid while the view is in use.
pub struct VirtualTensor<T, B: Backend> {
    /// Raw pointer to the first element of the viewed data.
    pub data: *const T,
    /// Dimensions of the viewed data.
    pub sizes: Dimensions,
    /// Backend value associated with the viewed data.
    pub backend: B,
}
26
27impl<T, B: Backend> VirtualTensor<T, B> {
28    pub fn new(data: *const T, sizes: Dimensions, backend: B) -> Self {
29        Self { data, sizes, backend }
30    }
31
32    pub fn sizes(&self) -> &[usize] {
33        self.sizes.sizes()
34    }
35
36    pub fn backend(&self) -> &B {
37        &self.backend
38    }
39
40    pub fn as_ptr(&self) -> *const T {
41        self.data
42    }
43
44    pub fn from_tensor(tensor: &Tensor<T, B>) -> Self {
45        Self {
46            data: tensor.as_ptr(),
47            sizes: tensor.shape().clone(),
48            backend: tensor.backend().clone(),
49        }
50    }
51}
52
/// Dense storage on backend `A` that can hand out a raw, read-only handle to
/// itself.
///
/// [`JaggedMleRaw`] stores the handle returned by [`DenseData::as_ptr`] as its
/// `dense_data` field.
pub trait DenseData<A: Backend> {
    /// The raw handle type produced by [`DenseData::as_ptr`]; chosen by the implementor.
    type DenseDataRaw;
    /// Returns the raw, read-only handle for this data.
    fn as_ptr(&self) -> Self::DenseDataRaw;
}
57
/// Extension of [`DenseData`] for storage that can also hand out a raw,
/// mutable handle to itself.
///
/// [`JaggedMleMutRaw`] stores the handle returned by
/// [`DenseDataMut::as_mut_ptr`] as its `dense_data` field.
pub trait DenseDataMut<A: Backend>: DenseData<A> {
    /// The raw mutable handle type produced by [`DenseDataMut::as_mut_ptr`]; chosen by the
    /// implementor.
    type DenseDataMutRaw;
    /// Returns the raw, mutable handle for this data.
    fn as_mut_ptr(&mut self) -> Self::DenseDataMutRaw;
}
62
/// The raw pointer equivalent of [`JaggedMle`] for use in cuda kernels.
///
/// Only `col_index`, `start_indices` and the dense data are mirrored here;
/// `column_heights` has no counterpart in this struct. The layout is
/// `#[repr(C)]`, so the field order below is significant.
#[repr(C)]
pub struct JaggedMleRaw<D: DenseData<A>, A: Backend> {
    /// Pointer to the first element of [`JaggedMle::col_index`].
    col_index: *const u32,
    /// Pointer to the first element of [`JaggedMle::start_indices`].
    start_indices: *const u32,
    /// Raw read-only handle obtained from [`DenseData::as_ptr`].
    dense_data: D::DenseDataRaw,
}
70
/// The mutable raw pointer equivalent of [`JaggedMle`] for use in cuda kernels.
///
/// Only `col_index`, `start_indices` and the dense data are mirrored here;
/// `column_heights` has no counterpart in this struct. The layout is
/// `#[repr(C)]`, so the field order below is significant.
#[repr(C)]
pub struct JaggedMleMutRaw<D: DenseDataMut<A>, A: Backend> {
    /// Mutable pointer to the first element of [`JaggedMle::col_index`].
    col_index: *mut u32,
    /// Mutable pointer to the first element of [`JaggedMle::start_indices`].
    start_indices: *mut u32,
    /// Raw mutable handle obtained from [`DenseDataMut::as_mut_ptr`].
    dense_data: D::DenseDataMutRaw,
}
78
79impl<D: DenseData<A>, A: Backend> JaggedMle<D, A> {
80    pub fn as_raw(&self) -> JaggedMleRaw<D, A> {
81        JaggedMleRaw {
82            col_index: self.col_index.as_ptr(),
83            start_indices: self.start_indices.as_ptr(),
84            dense_data: self.dense_data.as_ptr(),
85        }
86    }
87
88    pub fn as_mut_raw(&mut self) -> JaggedMleMutRaw<D, A>
89    where
90        D: DenseDataMut<A>,
91    {
92        JaggedMleMutRaw {
93            col_index: self.col_index.as_mut_ptr(),
94            start_indices: self.start_indices.as_mut_ptr(),
95            dense_data: self.dense_data.as_mut_ptr(),
96        }
97    }
98
99    pub fn new(
100        dense_data: D,
101        col_index: Buffer<u32, A>,
102        start_indices: Buffer<u32, A>,
103        column_heights: Buffer<u32, A>,
104    ) -> Self {
105        Self { dense_data, col_index, start_indices, column_heights }
106    }
107
108    pub fn column_heights(&self) -> &Buffer<u32, A> {
109        &self.column_heights
110    }
111
112    pub fn dense(&self) -> &D {
113        &self.dense_data
114    }
115
116    pub fn dense_mut(&mut self) -> &mut D {
117        &mut self.dense_data
118    }
119
120    pub fn col_index(&self) -> &Buffer<u32, A> {
121        &self.col_index
122    }
123
124    pub fn col_index_mut(&mut self) -> &mut Buffer<u32, A> {
125        &mut self.col_index
126    }
127
128    pub fn start_indices(&self) -> &Buffer<u32, A> {
129        &self.start_indices
130    }
131
132    pub fn start_indices_mut(&mut self) -> &mut Buffer<u32, A> {
133        &mut self.start_indices
134    }
135
136    pub fn into_parts(self) -> (D, Buffer<u32, A>, Buffer<u32, A>) {
137        (self.dense_data, self.col_index, self.start_indices)
138    }
139}
140
141impl<D: DenseData<TaskScope>> JaggedMle<D, TaskScope> {
142    /// Computes the next start indices, column heights and *input* total
143    /// length for use in jagged fix last variable.
144    ///
145    /// Returns host buffers; the caller uploads device copies as needed. We
146    /// download `column_heights` once and compute on host because the per-
147    /// round fold's hot work happens on device — this metadata derivation is
148    /// O(n_columns) and the round trip is cheap. The input length is returned
149    /// alongside so callers don't re-download `column_heights` to sum it.
150    ///
151    /// TODO: ignore all of the padding stuff.
152    pub fn next_start_indices_and_column_heights(
153        &self,
154    ) -> (Buffer<u32, CpuBackend>, Vec<u32>, u32) {
155        // SAFETY: `column_heights` was populated via `extend_from_host_slice`
156        // (or the equivalent during fold), so the device range is fully
157        // initialised up to `len()`.
158        let host_column_heights: Vec<u32> = unsafe { self.column_heights.copy_into_host_vec() };
159        let input_length = host_column_heights.iter().sum::<u32>();
160        let output_heights =
161            host_column_heights.iter().map(|height| height.div_ceil(4) * 2).collect::<Vec<u32>>();
162
163        let new_start_idx = once(0)
164            .chain(output_heights.iter().scan(0u32, |acc, x| {
165                *acc += x;
166                Some(*acc)
167            }))
168            .collect::<Vec<_>>();
169        let buffer_start_idx = Buffer::from(new_start_idx);
170        (buffer_start_idx, output_heights, input_length)
171    }
172
173    /// Device-resident counterpart of [`Self::next_start_indices_and_column_heights`].
174    ///
175    /// Runs the `jagged_fold_metadata` kernel to compute the new `column_heights`
176    /// and `start_indices` on device (no host download of the input
177    /// `column_heights`, no host upload of the derived metadata). Reads back
178    /// only the final `output_height` scalar (last element of new
179    /// `start_indices`, ~4 bytes) since downstream callers need it to size
180    /// host-allocated output tensors.
181    ///
182    /// Replaces the bulk `D2H column_heights + 2× H2D start_idx/heights`
183    /// pattern with a single kernel launch + tiny D2H — on `v6/rsp` this
184    /// drops ~50 k of `cudaMemcpyAsync` calls per prove across the 4
185    /// logup_gkr callers (`execution::layer_transition`,
186    /// `execution::first_layer_transition`, `sumcheck::fix_and_sum_first_layer`,
187    /// `sumcheck::fix_and_sum_layer_transition`).
188    ///
189    /// Allocates the scan bookkeeping (`block_counter`, `flags`,
190    /// `scan_values`) inline. The bookkeeping is small (≤ `n_blocks + 1`
191    /// `u32` each, where `n_blocks = ceil(n_columns / SECTION_SIZE)` —
192    /// typically 1 for shards with a few hundred columns), so the extra
193    /// allocs are cheap compared to the bulk transfers eliminated.
194    ///
195    /// Returns `(new_start_indices_dev, new_column_heights_dev, output_height)`.
196    pub fn next_start_indices_and_column_heights_dev(
197        &self,
198    ) -> (Buffer<u32, TaskScope>, Buffer<u32, TaskScope>, u32) {
199        let backend = self.column_heights.backend();
200        let n_columns = self.column_heights.len();
201        let section_size =
202            unsafe { sp1_gpu_cudart::sys::kernels::jagged_fold_metadata_section_size() } as usize;
203        let block_dim = unsafe { sp1_gpu_cudart::sys::kernels::jagged_fold_metadata_block_dim() };
204        let n_blocks: usize = n_columns.div_ceil(section_size).max(1);
205
206        let mut new_column_heights =
207            Buffer::<u32, TaskScope>::with_capacity_in(n_columns, backend.clone());
208        let mut new_start_indices =
209            Buffer::<u32, TaskScope>::with_capacity_in(n_columns + 1, backend.clone());
210        // SAFETY: the fold-metadata kernel writes all `n_columns` +
211        // `n_columns + 1` slots before any downstream read.
212        unsafe {
213            new_column_heights.assume_init();
214            new_start_indices.assume_init();
215        }
216
217        // Decoupled-lookback scan bookkeeping. Per the contract in
218        // `fold_metadata.cuh`: `block_counter[0] = 0`, `flags[0] = 1` so
219        // the first block doesn't wait, `flags[1..]` and `scan_values[..]`
220        // start at zero.
221        let u32_bytes = std::mem::size_of::<u32>();
222        let mut block_counter = Buffer::<u32, TaskScope>::with_capacity_in(1, backend.clone());
223        let mut flags = Buffer::<u32, TaskScope>::with_capacity_in(n_blocks + 1, backend.clone());
224        let mut scan_values =
225            Buffer::<u32, TaskScope>::with_capacity_in(n_blocks + 1, backend.clone());
226        block_counter.write_bytes(0, u32_bytes).unwrap();
227        flags.write_bytes(1, u32_bytes).unwrap();
228        flags.write_bytes(0, n_blocks * u32_bytes).unwrap();
229        scan_values.write_bytes(0, (n_blocks + 1) * u32_bytes).unwrap();
230
231        // SAFETY: `args!` tuple matches `jagged_fold_metadata`'s C signature
232        // in `sys/include/jagged_assist/fold_metadata.cuh`; every pointer
233        // borrows from a Buffer owned for the launch's lifetime.
234        unsafe {
235            let a = args!(
236                self.column_heights.as_ptr(),
237                n_columns as u32,
238                new_column_heights.as_mut_ptr(),
239                new_start_indices.as_mut_ptr(),
240                block_counter.as_mut_ptr(),
241                flags.as_mut_ptr(),
242                scan_values.as_mut_ptr()
243            );
244            backend
245                .launch_kernel(
246                    sp1_gpu_cudart::sys::kernels::jagged_fold_metadata_kernel(),
247                    (n_blocks as u32, 1u32, 1u32),
248                    (block_dim, 1u32, 1u32),
249                    &a,
250                    0,
251                )
252                .unwrap();
253        }
254
255        // Read back `output_height = new_start_indices[n_columns]`. The
256        // downstream caller needs this scalar to size the next layer's
257        // host-allocated output tensors. We download the whole
258        // `new_start_indices` buffer (n_columns + 1 u32 ≈ a few KB,
259        // *much* smaller than the bulk transfers this path replaces) and
260        // grab the last element. A future optimization could maintain
261        // `output_height` as a host-tracked scalar via the same
262        // recurrence the kernel runs, eliminating this final D2H.
263        // SAFETY: kernel above fully wrote `new_start_indices`; the
264        // download synchronizes on `backend`'s stream.
265        let host_start_idx: Vec<u32> = unsafe { new_start_indices.copy_into_host_vec() };
266        let output_height = *host_start_idx.last().unwrap();
267
268        (new_start_indices, new_column_heights, output_height)
269    }
270}
271
272impl<D: DenseData<A>, A: Backend> HasBackend for JaggedMle<D, A> {
273    type Backend = A;
274    fn backend(&self) -> &A {
275        self.col_index.backend()
276    }
277}