Skip to main content

sp1_gpu_utils/
traces.rs

1use std::collections::BTreeMap;
2use std::ops::{Deref, DerefMut, Range};
3
4use slop_algebra::Field;
5use slop_alloc::{Backend, Buffer, CpuBackend, HasBackend};
6use slop_tensor::{Dimensions, Tensor, TensorView};
7use sp1_gpu_cudart::{DeviceBuffer, TaskScope};
8
9use crate::jagged::JaggedMle;
10use crate::{DenseData, DenseDataMut};
11
/// Location and shape of one trace inside the dense buffer.
#[derive(Clone, Debug)]
pub struct TraceOffset {
    /// Range of dense-data indices occupied by this trace.
    pub dense_offset: Range<usize>,
    /// Size of each polynomial in this trace.
    pub poly_size: usize,
    /// How many polynomials this trace contains.
    pub num_polys: usize,
}
21
22#[derive(Clone)]
23pub struct JaggedTraceMle<F: Field, B: Backend>(pub JaggedMle<TraceDenseData<F, B>, B>);
24
25impl<F: Field, B: Backend> HasBackend for JaggedTraceMle<F, B> {
26    type Backend = B;
27    fn backend(&self) -> &B {
28        self.0.backend()
29    }
30}
31
32impl<F: Field, B: Backend> Deref for JaggedTraceMle<F, B> {
33    type Target = JaggedMle<TraceDenseData<F, B>, B>;
34
35    fn deref(&self) -> &Self::Target {
36        &self.0
37    }
38}
39
40impl<F: Field, B: Backend> DerefMut for JaggedTraceMle<F, B> {
41    fn deref_mut(&mut self) -> &mut Self::Target {
42        &mut self.0
43    }
44}
45
46/// Jagged representation of the traces.
47#[derive(Clone, Debug)]
48pub struct TraceDenseData<F: Field, B: Backend> {
49    /// The dense representation of the traces.
50    pub dense: Buffer<F, B>,
51    /// The dense offset of the preprocessed traces.
52    pub preprocessed_offset: usize,
53    /// The total number of columns in the preprocessed traces.
54    pub preprocessed_cols: usize,
55    /// The amount of preprocessed padding, to the next multiple of 2^log_stacking_height.
56    pub preprocessed_padding: usize,
57    /// The amount of main padding, to the next multiple of 2^log_stacking_height.
58    pub main_padding: usize,
59    /// A mapping from chip name to the range of dense data it occupies for preprocessed traces.
60    pub preprocessed_table_index: BTreeMap<String, TraceOffset>,
61    /// A mapping from chip name to the range of dense data it occupies for main traces.
62    pub main_table_index: BTreeMap<String, TraceOffset>,
63}
64
65impl<F: Field, B: Backend> TraceDenseData<F, B> {
66    pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
67        let ptr = unsafe { self.dense.as_ptr().add(self.preprocessed_offset) };
68        let sizes = Dimensions::try_from([
69            self.main_size() / (1 << log_stacking_height),
70            1 << log_stacking_height,
71        ])
72        .unwrap();
73        // This is safe because we inherit the lifetime of self and the offset should be valid.
74        unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
75    }
76
77    /// Copies the correct data from dense to a new tensor for main traces.
78    pub fn main_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
79        let mut tensor = Tensor::with_sizes_in(
80            [self.main_size() / (1 << log_stacking_height), 1 << log_stacking_height],
81            self.backend().clone(),
82        );
83        let backend = self.dense.backend();
84        unsafe {
85            tensor.assume_init();
86            tensor
87                .as_mut_buffer()
88                .copy_from_slice(&self.dense[self.preprocessed_offset..], backend)
89                .unwrap();
90        }
91        tensor
92    }
93
94    pub fn preprocessed_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
95        let ptr = self.dense.as_ptr();
96        let sizes = Dimensions::try_from([
97            self.preprocessed_offset / (1 << log_stacking_height),
98            1 << log_stacking_height,
99        ])
100        .unwrap();
101        unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
102    }
103
104    /// Copies the correct data from dense to a new tensor for preprocessed traces.
105    pub fn preprocessed_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
106        let mut tensor = Tensor::with_sizes_in(
107            [self.preprocessed_offset / (1 << log_stacking_height), 1 << log_stacking_height],
108            self.backend().clone(),
109        );
110        let backend = self.dense.backend();
111        unsafe {
112            tensor.assume_init();
113            tensor
114                .as_mut_buffer()
115                .copy_from_slice(&self.dense[..self.preprocessed_offset], backend)
116                .unwrap();
117        }
118        tensor
119    }
120
121    /// The size of the main polynomial.
122    #[inline]
123    pub fn main_poly_height(&self, name: &str) -> Option<usize> {
124        self.main_table_index.get(name).map(|offset| offset.poly_size)
125    }
126
127    /// The size of the preprocessed polynomial.
128    #[inline]
129    pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
130        self.preprocessed_table_index.get(name).map(|offset| offset.poly_size)
131    }
132
133    /// The number of polynomials in the main trace.
134    #[inline]
135    pub fn main_num_polys(&self, name: &str) -> Option<usize> {
136        self.main_table_index.get(name).map(|offset| offset.num_polys)
137    }
138
139    /// The size of the main trace dense data, including padding.
140    #[inline]
141    pub fn main_size(&self) -> usize {
142        self.dense.len() - self.preprocessed_offset
143    }
144
145    /// The number of polynomials in the preprocessed trace.
146    #[inline]
147    pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
148        self.preprocessed_table_index.get(name).map(|offset| offset.num_polys)
149    }
150}
151
152impl<F: Field, B: Backend> HasBackend for TraceDenseData<F, B> {
153    type Backend = B;
154    fn backend(&self) -> &B {
155        self.dense.backend()
156    }
157}
158
159impl<F: Field, B: Backend> JaggedTraceMle<F, B> {
160    pub fn new(
161        dense_data: TraceDenseData<F, B>,
162        col_index: Buffer<u32, B>,
163        start_indices: Buffer<u32, B>,
164        column_heights: Vec<u32>,
165    ) -> Self {
166        JaggedTraceMle(JaggedMle::new(dense_data, col_index, start_indices, column_heights))
167    }
168}
169
170impl<F: Field> JaggedTraceMle<F, TaskScope> {
171    pub fn preprocessed_virtual_tensor(
172        &'_ self,
173        log_stacking_height: u32,
174    ) -> TensorView<'_, F, TaskScope> {
175        self.dense_data.preprocessed_virtual_tensor(log_stacking_height)
176    }
177
178    pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, TaskScope> {
179        self.dense_data.main_virtual_tensor(log_stacking_height)
180    }
181
182    pub fn main_poly_height(&self, name: &str) -> Option<usize> {
183        self.dense_data.main_poly_height(name)
184    }
185
186    pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
187        self.dense_data.preprocessed_poly_height(name)
188    }
189
190    pub fn main_num_polys(&self, name: &str) -> Option<usize> {
191        self.dense_data.main_num_polys(name)
192    }
193
194    pub fn main_size(&self) -> usize {
195        self.dense_data.main_size()
196    }
197
198    pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
199        self.dense_data.preprocessed_num_polys(name)
200    }
201}
202
/// `#[repr(C)]` wrapper holding a read-only raw pointer to the dense data, for
/// use in CUDA FFI calls.
#[repr(C)]
pub struct TraceDenseDataRaw<F> {
    dense: *const F,
}
208
/// `#[repr(C)]` wrapper holding a mutable raw pointer to the dense data, for
/// use in CUDA FFI calls.
#[repr(C)]
pub struct TraceDenseDataMutRaw<F> {
    dense: *mut F,
}
214
215impl<F: Field, B: Backend> DenseData<B> for TraceDenseData<F, B> {
216    type DenseDataRaw = TraceDenseDataRaw<F>;
217
218    fn as_ptr(&self) -> TraceDenseDataRaw<F> {
219        TraceDenseDataRaw { dense: self.dense.as_ptr() }
220    }
221}
222
223impl<F: Field, B: Backend> DenseDataMut<B> for TraceDenseData<F, B> {
224    type DenseDataMutRaw = TraceDenseDataMutRaw<F>;
225
226    fn as_mut_ptr(&mut self) -> TraceDenseDataMutRaw<F> {
227        TraceDenseDataMutRaw { dense: self.dense.as_mut_ptr() }
228    }
229}
230
231impl<F: Field> JaggedTraceMle<F, CpuBackend> {
232    pub fn into_device(self, t: &TaskScope) -> JaggedTraceMle<F, TaskScope> {
233        let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
234        JaggedTraceMle::new(
235            dense_data.into_device_in(t),
236            DeviceBuffer::from_host(&col_index, t).unwrap().into_inner(),
237            DeviceBuffer::from_host(&start_indices, t).unwrap().into_inner(),
238            column_heights,
239        )
240    }
241}
242
243impl<F: Field> TraceDenseData<F, CpuBackend> {
244    pub fn into_device_in(self, t: &TaskScope) -> TraceDenseData<F, TaskScope> {
245        TraceDenseData {
246            dense: DeviceBuffer::from_host(&self.dense, t).unwrap().into_inner(),
247            preprocessed_offset: self.preprocessed_offset,
248            preprocessed_cols: self.preprocessed_cols,
249            preprocessed_table_index: self.preprocessed_table_index,
250            main_table_index: self.main_table_index,
251            preprocessed_padding: self.preprocessed_padding,
252            main_padding: self.main_padding,
253        }
254    }
255}
256
257impl<F: Field> JaggedTraceMle<F, TaskScope> {
258    pub fn into_host(self) -> JaggedTraceMle<F, CpuBackend> {
259        let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
260        let host_dense = dense_data.into_host();
261        // Convert device buffers to host using DeviceBuffer wrapper
262        let col_index_host = DeviceBuffer::from_raw(col_index).to_host().unwrap().into();
263        let start_indices_host = DeviceBuffer::from_raw(start_indices).to_host().unwrap().into();
264        JaggedTraceMle::new(host_dense, col_index_host, start_indices_host, column_heights)
265    }
266}
267
268impl<F: Field> TraceDenseData<F, TaskScope> {
269    pub fn into_host(self) -> TraceDenseData<F, CpuBackend> {
270        let host_dense = DeviceBuffer::from_raw(self.dense).to_host().unwrap().into();
271        TraceDenseData {
272            dense: host_dense,
273            preprocessed_offset: self.preprocessed_offset,
274            preprocessed_cols: self.preprocessed_cols,
275            preprocessed_table_index: self.preprocessed_table_index,
276            main_table_index: self.main_table_index,
277            preprocessed_padding: self.preprocessed_padding,
278            main_padding: self.main_padding,
279        }
280    }
281}