Skip to main content

sp1_gpu_utils/
traces.rs

1use std::collections::BTreeMap;
2use std::ops::{Deref, DerefMut, Range};
3
4use slop_algebra::Field;
5use slop_alloc::{Backend, Buffer, CpuBackend, HasBackend};
6use slop_tensor::{Dimensions, Tensor, TensorView};
7use sp1_gpu_cudart::{DeviceBuffer, TaskScope};
8
9use crate::jagged::JaggedMle;
10use crate::{DenseData, DenseDataMut};
11
/// Placement and shape of a single chip's trace within the dense buffer.
#[derive(Clone, Debug)]
pub struct TraceOffset {
    /// Range of dense-buffer indices occupied by this trace.
    pub dense_offset: Range<usize>,
    /// Number of entries in each polynomial of this trace.
    pub poly_size: usize,
    /// How many polynomials this trace contains.
    pub num_polys: usize,
}
21
22#[derive(Clone)]
23pub struct JaggedTraceMle<F: Field, B: Backend>(pub JaggedMle<TraceDenseData<F, B>, B>);
24
25impl<F: Field, B: Backend> HasBackend for JaggedTraceMle<F, B> {
26    type Backend = B;
27    fn backend(&self) -> &B {
28        self.0.backend()
29    }
30}
31
32impl<F: Field, B: Backend> Deref for JaggedTraceMle<F, B> {
33    type Target = JaggedMle<TraceDenseData<F, B>, B>;
34
35    fn deref(&self) -> &Self::Target {
36        &self.0
37    }
38}
39
40impl<F: Field, B: Backend> DerefMut for JaggedTraceMle<F, B> {
41    fn deref_mut(&mut self) -> &mut Self::Target {
42        &mut self.0
43    }
44}
45
46/// Jagged representation of the traces.
47#[derive(Clone, Debug)]
48pub struct TraceDenseData<F: Field, B: Backend> {
49    /// The dense representation of the traces.
50    pub dense: Buffer<F, B>,
51    /// The dense offset of the preprocessed traces.
52    pub preprocessed_offset: usize,
53    /// The total number of columns in the preprocessed traces.
54    pub preprocessed_cols: usize,
55    /// The amount of preprocessed padding, to the next multiple of 2^log_stacking_height.
56    pub preprocessed_padding: usize,
57    /// The amount of main padding, to the next multiple of 2^log_stacking_height.
58    pub main_padding: usize,
59    /// A mapping from chip name to the range of dense data it occupies for preprocessed traces.
60    pub preprocessed_table_index: BTreeMap<String, TraceOffset>,
61    /// A mapping from chip name to the range of dense data it occupies for main traces.
62    pub main_table_index: BTreeMap<String, TraceOffset>,
63}
64
65impl<F: Field, B: Backend> TraceDenseData<F, B> {
66    pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
67        let ptr = unsafe { self.dense.as_ptr().add(self.preprocessed_offset) };
68        let sizes = Dimensions::try_from([
69            self.main_size() / (1 << log_stacking_height),
70            1 << log_stacking_height,
71        ])
72        .unwrap();
73        // This is safe because we inherit the lifetime of self and the offset should be valid.
74        unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
75    }
76
77    /// Copies the correct data from dense to a new tensor for main traces.
78    pub fn main_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
79        let mut tensor = Tensor::with_sizes_in(
80            [self.main_size() / (1 << log_stacking_height), 1 << log_stacking_height],
81            self.backend().clone(),
82        );
83        let backend = self.dense.backend();
84        unsafe {
85            tensor.assume_init();
86            tensor
87                .as_mut_buffer()
88                .copy_from_slice(&self.dense[self.preprocessed_offset..], backend)
89                .unwrap();
90        }
91        tensor
92    }
93
94    pub fn preprocessed_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
95        let ptr = self.dense.as_ptr();
96        let sizes = Dimensions::try_from([
97            self.preprocessed_offset / (1 << log_stacking_height),
98            1 << log_stacking_height,
99        ])
100        .unwrap();
101        unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
102    }
103
104    /// Copies the correct data from dense to a new tensor for preprocessed traces.
105    pub fn preprocessed_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
106        let mut tensor = Tensor::with_sizes_in(
107            [self.preprocessed_offset / (1 << log_stacking_height), 1 << log_stacking_height],
108            self.backend().clone(),
109        );
110        let backend = self.dense.backend();
111        unsafe {
112            tensor.assume_init();
113            tensor
114                .as_mut_buffer()
115                .copy_from_slice(&self.dense[..self.preprocessed_offset], backend)
116                .unwrap();
117        }
118        tensor
119    }
120
121    /// The size of the main polynomial.
122    #[inline]
123    pub fn main_poly_height(&self, name: &str) -> Option<usize> {
124        self.main_table_index.get(name).map(|offset| offset.poly_size)
125    }
126
127    /// The size of the preprocessed polynomial.
128    #[inline]
129    pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
130        self.preprocessed_table_index.get(name).map(|offset| offset.poly_size)
131    }
132
133    /// The number of polynomials in the main trace.
134    #[inline]
135    pub fn main_num_polys(&self, name: &str) -> Option<usize> {
136        self.main_table_index.get(name).map(|offset| offset.num_polys)
137    }
138
139    /// The size of the main trace dense data, including padding.
140    #[inline]
141    pub fn main_size(&self) -> usize {
142        self.dense.len() - self.preprocessed_offset
143    }
144
145    /// The number of polynomials in the preprocessed trace.
146    #[inline]
147    pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
148        self.preprocessed_table_index.get(name).map(|offset| offset.num_polys)
149    }
150}
151
/// Abstract description of a chip layout used to build [`TraceDenseData`] / [`JaggedTraceMle`].
/// Each tuple is `(chip_name, preprocessed_width, main_width)` for one chip.
#[derive(Clone, Debug)]
pub struct AbstractChipLayout(Vec<(String, usize, usize)>);

impl AbstractChipLayout {
    /// Wraps `entries`, keeping them in the order given.
    pub fn new(entries: Vec<(String, usize, usize)>) -> Self {
        Self(entries)
    }

    /// All `(chip_name, preprocessed_width, main_width)` entries, in layout order.
    pub fn entries(&self) -> &[(String, usize, usize)] {
        &self.0
    }

    /// Chip names in layout order.
    pub fn chip_names(&self) -> impl Iterator<Item = &str> {
        self.0.iter().map(|(name, _, _)| name.as_str())
    }
}
165
/// Like [`AbstractChipLayout`], but with a per-chip row count attached to each entry.
/// Each tuple is `(chip_name, preprocessed_width, main_width, height)` for one chip.
#[derive(Clone, Debug)]
pub struct AbstractChipLayoutWithHeights(Vec<(String, usize, usize, usize)>);

impl AbstractChipLayoutWithHeights {
    /// Wraps `entries`, keeping them in the order given.
    pub fn new(entries: Vec<(String, usize, usize, usize)>) -> Self {
        Self(entries)
    }

    /// All `(chip_name, preprocessed_width, main_width, height)` entries, in layout order.
    pub fn entries(&self) -> &[(String, usize, usize, usize)] {
        &self.0
    }

    /// Chip names in layout order.
    pub fn chip_names(&self) -> impl Iterator<Item = &str> {
        self.0.iter().map(|(name, _, _, _)| name.as_str())
    }
}
184
185impl<F: Field> TraceDenseData<F, CpuBackend> {
186    /// Build a `TraceDenseData` over a pre-allocated `dense` buffer using an
187    /// [`AbstractChipLayoutWithHeights`].
188    ///
189    /// The `dense` buffer must be sized as `padded_preprocessed + padded_main`, where
190    /// each section is the unpadded total rounded up to the next multiple of
191    /// `2^log_stacking_height`.
192    pub fn from_chip_layout(
193        dense: Buffer<F, CpuBackend>,
194        layout: &AbstractChipLayoutWithHeights,
195        log_stacking_height: u32,
196    ) -> Self {
197        let stacking = 1usize << log_stacking_height;
198
199        let total_preprocessed: usize = layout.0.iter().map(|(_, p, _, h)| p * h).sum();
200        let total_main: usize = layout.0.iter().map(|(_, _, m, h)| m * h).sum();
201
202        // note that this makes sure there is always at least one main and one preprocessed column
203        let padded_preprocessed = total_preprocessed.next_multiple_of(stacking).max(stacking);
204        let padded_main = total_main.next_multiple_of(stacking).max(stacking);
205
206        assert_eq!(
207            dense.len(),
208            padded_preprocessed + padded_main,
209            "dense buffer length must equal padded_preprocessed + padded_main",
210        );
211
212        let preprocessed_cols: usize = layout.0.iter().map(|(_, p, _, _)| p).sum();
213
214        let mut preprocessed_table_index = BTreeMap::new();
215        let mut main_table_index = BTreeMap::new();
216        let mut preprocessed_ptr = 0usize;
217        let mut main_ptr = padded_preprocessed;
218        for (name, prep_w, main_w, h) in layout.0.iter() {
219            let prep_lo = preprocessed_ptr;
220            let prep_hi = prep_lo + h * prep_w;
221            preprocessed_table_index.insert(
222                name.clone(),
223                TraceOffset { dense_offset: prep_lo..prep_hi, poly_size: *h, num_polys: *prep_w },
224            );
225            preprocessed_ptr = prep_hi;
226
227            let main_lo = main_ptr;
228            let main_hi = main_lo + h * main_w;
229            main_table_index.insert(
230                name.clone(),
231                TraceOffset { dense_offset: main_lo..main_hi, poly_size: *h, num_polys: *main_w },
232            );
233            main_ptr = main_hi;
234        }
235
236        TraceDenseData {
237            dense,
238            preprocessed_offset: padded_preprocessed,
239            preprocessed_cols,
240            preprocessed_padding: padded_preprocessed - total_preprocessed,
241            main_padding: padded_main - total_main,
242            preprocessed_table_index,
243            main_table_index,
244        }
245    }
246}
247
248impl<F: Field, B: Backend> HasBackend for TraceDenseData<F, B> {
249    type Backend = B;
250    fn backend(&self) -> &B {
251        self.dense.backend()
252    }
253}
254
255impl<F: Field, B: Backend> JaggedTraceMle<F, B> {
256    pub fn new(
257        dense_data: TraceDenseData<F, B>,
258        col_index: Buffer<u32, B>,
259        start_indices: Buffer<u32, B>,
260        column_heights: Vec<u32>,
261    ) -> Self {
262        JaggedTraceMle(JaggedMle::new(dense_data, col_index, start_indices, column_heights))
263    }
264}
265
266impl<F: Field> JaggedTraceMle<F, CpuBackend> {
267    /// Build a `JaggedTraceMle` over a pre-allocated `dense` buffer using a chip-layout
268    /// description as parallel slices. Constructs the inner [`TraceDenseData`] with the
269    /// same layout as [`TraceDenseData::from_chip_layout`], plus the jagged column
270    /// metadata: one logical column per chip column for both preprocessed and main,
271    /// plus one padding column per section that has nonzero padding.
272    ///
273    /// All heights must be even, since column heights and column-index entries
274    /// are stored at half-element granularity.
275    pub fn from_chip_layout(
276        dense: Buffer<F, CpuBackend>,
277        layout: &AbstractChipLayoutWithHeights,
278        log_stacking_height: u32,
279    ) -> Self {
280        assert!(layout.0.iter().all(|(_, _, _, h)| h % 2 == 0), "heights must be even");
281
282        let dense_data = TraceDenseData::from_chip_layout(dense, layout, log_stacking_height);
283
284        let total_dense = dense_data.dense.len();
285        let preprocessed_padding = dense_data.preprocessed_padding;
286        let main_padding = dense_data.main_padding;
287
288        let num_data_cols: usize = layout.0.iter().map(|(_, p, m, _)| p + m).sum();
289        let num_cols =
290            num_data_cols + (preprocessed_padding > 0) as usize + (main_padding > 0) as usize;
291
292        let mut col_index = vec![0u32; total_dense / 2];
293        let mut start_idx = vec![0u32; num_cols + 1];
294        let mut column_heights: Vec<u32> = Vec::with_capacity(num_cols);
295
296        let mut col: u32 = 0;
297        let mut cnt: usize = 0;
298
299        let mut emit = |w: usize, h: usize, col: &mut u32, cnt: &mut usize| {
300            let half = h / 2;
301            for _ in 0..w {
302                col_index[*cnt..*cnt + half].fill(*col);
303                *cnt += half;
304                start_idx[*col as usize + 1] = start_idx[*col as usize] + half as u32;
305                column_heights.push(half as u32);
306                *col += 1;
307            }
308        };
309
310        for (_, prep_w, _, h) in layout.0.iter() {
311            emit(*prep_w, *h, &mut col, &mut cnt);
312        }
313        if preprocessed_padding > 0 {
314            emit(1, preprocessed_padding, &mut col, &mut cnt);
315        }
316        for (_, _, main_w, h) in layout.0.iter() {
317            emit(*main_w, *h, &mut col, &mut cnt);
318        }
319        if main_padding > 0 {
320            emit(1, main_padding, &mut col, &mut cnt);
321        }
322
323        debug_assert_eq!(cnt, total_dense / 2);
324        debug_assert_eq!(col as usize, num_cols);
325
326        Self::new(dense_data, Buffer::from(col_index), Buffer::from(start_idx), column_heights)
327    }
328}
329
330impl<F: Field> JaggedTraceMle<F, TaskScope> {
331    pub fn preprocessed_virtual_tensor(
332        &'_ self,
333        log_stacking_height: u32,
334    ) -> TensorView<'_, F, TaskScope> {
335        self.dense_data.preprocessed_virtual_tensor(log_stacking_height)
336    }
337
338    pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, TaskScope> {
339        self.dense_data.main_virtual_tensor(log_stacking_height)
340    }
341
342    pub fn main_poly_height(&self, name: &str) -> Option<usize> {
343        self.dense_data.main_poly_height(name)
344    }
345
346    pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
347        self.dense_data.preprocessed_poly_height(name)
348    }
349
350    pub fn main_num_polys(&self, name: &str) -> Option<usize> {
351        self.dense_data.main_num_polys(name)
352    }
353
354    pub fn main_size(&self) -> usize {
355        self.dense_data.main_size()
356    }
357
358    pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
359        self.dense_data.preprocessed_num_polys(name)
360    }
361}
362
/// Read-only raw pointer to the dense data, laid out `repr(C)` for use in CUDA FFI calls.
#[repr(C)]
pub struct TraceDenseDataRaw<F> {
    /// Start of the dense buffer.
    dense: *const F,
}
368
/// Mutable raw pointer to the dense data, laid out `repr(C)` for use in CUDA FFI calls.
#[repr(C)]
pub struct TraceDenseDataMutRaw<F> {
    /// Start of the dense buffer.
    dense: *mut F,
}
374
375impl<F: Field, B: Backend> DenseData<B> for TraceDenseData<F, B> {
376    type DenseDataRaw = TraceDenseDataRaw<F>;
377
378    fn as_ptr(&self) -> TraceDenseDataRaw<F> {
379        TraceDenseDataRaw { dense: self.dense.as_ptr() }
380    }
381}
382
383impl<F: Field, B: Backend> DenseDataMut<B> for TraceDenseData<F, B> {
384    type DenseDataMutRaw = TraceDenseDataMutRaw<F>;
385
386    fn as_mut_ptr(&mut self) -> TraceDenseDataMutRaw<F> {
387        TraceDenseDataMutRaw { dense: self.dense.as_mut_ptr() }
388    }
389}
390
391impl<F: Field> JaggedTraceMle<F, CpuBackend> {
392    pub fn into_device(self, t: &TaskScope) -> JaggedTraceMle<F, TaskScope> {
393        let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
394        JaggedTraceMle::new(
395            dense_data.into_device_in(t),
396            DeviceBuffer::from_host(&col_index, t).unwrap().into_inner(),
397            DeviceBuffer::from_host(&start_indices, t).unwrap().into_inner(),
398            column_heights,
399        )
400    }
401}
402
403impl<F: Field> TraceDenseData<F, CpuBackend> {
404    pub fn into_device_in(self, t: &TaskScope) -> TraceDenseData<F, TaskScope> {
405        TraceDenseData {
406            dense: DeviceBuffer::from_host(&self.dense, t).unwrap().into_inner(),
407            preprocessed_offset: self.preprocessed_offset,
408            preprocessed_cols: self.preprocessed_cols,
409            preprocessed_table_index: self.preprocessed_table_index,
410            main_table_index: self.main_table_index,
411            preprocessed_padding: self.preprocessed_padding,
412            main_padding: self.main_padding,
413        }
414    }
415}
416
417impl<F: Field> JaggedTraceMle<F, TaskScope> {
418    pub fn into_host(self) -> JaggedTraceMle<F, CpuBackend> {
419        let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
420        let host_dense = dense_data.into_host();
421        // Convert device buffers to host using DeviceBuffer wrapper
422        let col_index_host = DeviceBuffer::from_raw(col_index).to_host().unwrap().into();
423        let start_indices_host = DeviceBuffer::from_raw(start_indices).to_host().unwrap().into();
424        JaggedTraceMle::new(host_dense, col_index_host, start_indices_host, column_heights)
425    }
426}
427
428impl<F: Field> TraceDenseData<F, TaskScope> {
429    pub fn into_host(self) -> TraceDenseData<F, CpuBackend> {
430        let host_dense = DeviceBuffer::from_raw(self.dense).to_host().unwrap().into();
431        TraceDenseData {
432            dense: host_dense,
433            preprocessed_offset: self.preprocessed_offset,
434            preprocessed_cols: self.preprocessed_cols,
435            preprocessed_table_index: self.preprocessed_table_index,
436            main_table_index: self.main_table_index,
437            preprocessed_padding: self.preprocessed_padding,
438            main_padding: self.main_padding,
439        }
440    }
441}