1use std::collections::BTreeMap;
2use std::ops::{Deref, DerefMut, Range};
3
4use slop_algebra::Field;
5use slop_alloc::{Backend, Buffer, CpuBackend, HasBackend};
6use slop_tensor::{Dimensions, Tensor, TensorView};
7use sp1_gpu_cudart::{DeviceBuffer, TaskScope};
8
9use crate::jagged::JaggedMle;
10use crate::{DenseData, DenseDataMut};
11
/// Location and shape of one named entry inside a [`TraceDenseData`] buffer.
///
/// Values of this type are stored in `TraceDenseData::preprocessed_table_index` and
/// `TraceDenseData::main_table_index`, keyed by name.
#[derive(Clone, Debug)]
pub struct TraceOffset {
    /// Range of indices in the dense buffer occupied by this entry.
    /// NOTE(review): presumably relative to the start of `dense`; confirm whether main-table
    /// ranges already include `preprocessed_offset`.
    pub dense_offset: Range<usize>,
    /// Height of each polynomial of this entry; returned by `main_poly_height` /
    /// `preprocessed_poly_height`.
    pub poly_size: usize,
    /// Number of polynomials of this entry; returned by `main_num_polys` /
    /// `preprocessed_num_polys`.
    pub num_polys: usize,
}
21
/// A [`JaggedMle`] whose dense storage is a [`TraceDenseData`], i.e. preprocessed and main
/// data packed into a single buffer.
///
/// Derefs to the wrapped [`JaggedMle`], so its fields (such as `dense_data`) are reachable
/// directly through a `JaggedTraceMle`.
#[derive(Clone)]
pub struct JaggedTraceMle<F: Field, B: Backend>(pub JaggedMle<TraceDenseData<F, B>, B>);
24
25impl<F: Field, B: Backend> HasBackend for JaggedTraceMle<F, B> {
26 type Backend = B;
27 fn backend(&self) -> &B {
28 self.0.backend()
29 }
30}
31
32impl<F: Field, B: Backend> Deref for JaggedTraceMle<F, B> {
33 type Target = JaggedMle<TraceDenseData<F, B>, B>;
34
35 fn deref(&self) -> &Self::Target {
36 &self.0
37 }
38}
39
40impl<F: Field, B: Backend> DerefMut for JaggedTraceMle<F, B> {
41 fn deref_mut(&mut self) -> &mut Self::Target {
42 &mut self.0
43 }
44}
45
/// Preprocessed and main data packed back to back into one dense buffer, together with the
/// per-name layout metadata used to locate entries inside it.
///
/// `dense[..preprocessed_offset]` is the preprocessed region and
/// `dense[preprocessed_offset..]` is the main region.
#[derive(Clone, Debug)]
pub struct TraceDenseData<F: Field, B: Backend> {
    /// The packed buffer holding the preprocessed region followed by the main region.
    pub dense: Buffer<F, B>,
    /// Number of leading elements of `dense` that form the preprocessed region; the main
    /// region starts at this index.
    pub preprocessed_offset: usize,
    /// Count of preprocessed columns.
    /// NOTE(review): not read anywhere in this file; meaning inferred from the name — confirm.
    pub preprocessed_cols: usize,
    /// Padding associated with the preprocessed region.
    /// NOTE(review): units (elements vs. rows/columns) are not visible here — confirm.
    pub preprocessed_padding: usize,
    /// Padding associated with the main region (same unit caveat as `preprocessed_padding`).
    pub main_padding: usize,
    /// Layout of each preprocessed entry, keyed by name.
    pub preprocessed_table_index: BTreeMap<String, TraceOffset>,
    /// Layout of each main entry, keyed by name.
    pub main_table_index: BTreeMap<String, TraceOffset>,
}
64
65impl<F: Field, B: Backend> TraceDenseData<F, B> {
66 pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
67 let ptr = unsafe { self.dense.as_ptr().add(self.preprocessed_offset) };
68 let sizes = Dimensions::try_from([
69 self.main_size() / (1 << log_stacking_height),
70 1 << log_stacking_height,
71 ])
72 .unwrap();
73 unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
75 }
76
77 pub fn main_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
79 let mut tensor = Tensor::with_sizes_in(
80 [self.main_size() / (1 << log_stacking_height), 1 << log_stacking_height],
81 self.backend().clone(),
82 );
83 let backend = self.dense.backend();
84 unsafe {
85 tensor.assume_init();
86 tensor
87 .as_mut_buffer()
88 .copy_from_slice(&self.dense[self.preprocessed_offset..], backend)
89 .unwrap();
90 }
91 tensor
92 }
93
94 pub fn preprocessed_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
95 let ptr = self.dense.as_ptr();
96 let sizes = Dimensions::try_from([
97 self.preprocessed_offset / (1 << log_stacking_height),
98 1 << log_stacking_height,
99 ])
100 .unwrap();
101 unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
102 }
103
104 pub fn preprocessed_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
106 let mut tensor = Tensor::with_sizes_in(
107 [self.preprocessed_offset / (1 << log_stacking_height), 1 << log_stacking_height],
108 self.backend().clone(),
109 );
110 let backend = self.dense.backend();
111 unsafe {
112 tensor.assume_init();
113 tensor
114 .as_mut_buffer()
115 .copy_from_slice(&self.dense[..self.preprocessed_offset], backend)
116 .unwrap();
117 }
118 tensor
119 }
120
121 #[inline]
123 pub fn main_poly_height(&self, name: &str) -> Option<usize> {
124 self.main_table_index.get(name).map(|offset| offset.poly_size)
125 }
126
127 #[inline]
129 pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
130 self.preprocessed_table_index.get(name).map(|offset| offset.poly_size)
131 }
132
133 #[inline]
135 pub fn main_num_polys(&self, name: &str) -> Option<usize> {
136 self.main_table_index.get(name).map(|offset| offset.num_polys)
137 }
138
139 #[inline]
141 pub fn main_size(&self) -> usize {
142 self.dense.len() - self.preprocessed_offset
143 }
144
145 #[inline]
147 pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
148 self.preprocessed_table_index.get(name).map(|offset| offset.num_polys)
149 }
150}
151
152impl<F: Field, B: Backend> HasBackend for TraceDenseData<F, B> {
153 type Backend = B;
154 fn backend(&self) -> &B {
155 self.dense.backend()
156 }
157}
158
159impl<F: Field, B: Backend> JaggedTraceMle<F, B> {
160 pub fn new(
161 dense_data: TraceDenseData<F, B>,
162 col_index: Buffer<u32, B>,
163 start_indices: Buffer<u32, B>,
164 column_heights: Vec<u32>,
165 ) -> Self {
166 JaggedTraceMle(JaggedMle::new(dense_data, col_index, start_indices, column_heights))
167 }
168}
169
170impl<F: Field> JaggedTraceMle<F, TaskScope> {
171 pub fn preprocessed_virtual_tensor(
172 &'_ self,
173 log_stacking_height: u32,
174 ) -> TensorView<'_, F, TaskScope> {
175 self.dense_data.preprocessed_virtual_tensor(log_stacking_height)
176 }
177
178 pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, TaskScope> {
179 self.dense_data.main_virtual_tensor(log_stacking_height)
180 }
181
182 pub fn main_poly_height(&self, name: &str) -> Option<usize> {
183 self.dense_data.main_poly_height(name)
184 }
185
186 pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
187 self.dense_data.preprocessed_poly_height(name)
188 }
189
190 pub fn main_num_polys(&self, name: &str) -> Option<usize> {
191 self.dense_data.main_num_polys(name)
192 }
193
194 pub fn main_size(&self) -> usize {
195 self.dense_data.main_size()
196 }
197
198 pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
199 self.dense_data.preprocessed_num_polys(name)
200 }
201}
202
/// C-layout, read-only raw form of a [`TraceDenseData`]: just the base pointer of its dense
/// buffer. Produced by `DenseData::as_ptr`.
///
/// NOTE(review): the `#[repr(C)]` layout suggests this is handed across an FFI/kernel boundary;
/// keep it in sync with any foreign-side definition.
#[repr(C)]
pub struct TraceDenseDataRaw<F> {
    /// Base pointer of `TraceDenseData::dense`.
    dense: *const F,
}
208
/// C-layout, mutable raw form of a [`TraceDenseData`]: just the mutable base pointer of its
/// dense buffer. Produced by `DenseDataMut::as_mut_ptr`.
///
/// NOTE(review): the `#[repr(C)]` layout suggests this is handed across an FFI/kernel boundary;
/// keep it in sync with any foreign-side definition.
#[repr(C)]
pub struct TraceDenseDataMutRaw<F> {
    /// Mutable base pointer of `TraceDenseData::dense`.
    dense: *mut F,
}
214
215impl<F: Field, B: Backend> DenseData<B> for TraceDenseData<F, B> {
216 type DenseDataRaw = TraceDenseDataRaw<F>;
217
218 fn as_ptr(&self) -> TraceDenseDataRaw<F> {
219 TraceDenseDataRaw { dense: self.dense.as_ptr() }
220 }
221}
222
223impl<F: Field, B: Backend> DenseDataMut<B> for TraceDenseData<F, B> {
224 type DenseDataMutRaw = TraceDenseDataMutRaw<F>;
225
226 fn as_mut_ptr(&mut self) -> TraceDenseDataMutRaw<F> {
227 TraceDenseDataMutRaw { dense: self.dense.as_mut_ptr() }
228 }
229}
230
231impl<F: Field> JaggedTraceMle<F, CpuBackend> {
232 pub fn into_device(self, t: &TaskScope) -> JaggedTraceMle<F, TaskScope> {
233 let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
234 JaggedTraceMle::new(
235 dense_data.into_device_in(t),
236 DeviceBuffer::from_host(&col_index, t).unwrap().into_inner(),
237 DeviceBuffer::from_host(&start_indices, t).unwrap().into_inner(),
238 column_heights,
239 )
240 }
241}
242
243impl<F: Field> TraceDenseData<F, CpuBackend> {
244 pub fn into_device_in(self, t: &TaskScope) -> TraceDenseData<F, TaskScope> {
245 TraceDenseData {
246 dense: DeviceBuffer::from_host(&self.dense, t).unwrap().into_inner(),
247 preprocessed_offset: self.preprocessed_offset,
248 preprocessed_cols: self.preprocessed_cols,
249 preprocessed_table_index: self.preprocessed_table_index,
250 main_table_index: self.main_table_index,
251 preprocessed_padding: self.preprocessed_padding,
252 main_padding: self.main_padding,
253 }
254 }
255}
256
257impl<F: Field> JaggedTraceMle<F, TaskScope> {
258 pub fn into_host(self) -> JaggedTraceMle<F, CpuBackend> {
259 let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
260 let host_dense = dense_data.into_host();
261 let col_index_host = DeviceBuffer::from_raw(col_index).to_host().unwrap().into();
263 let start_indices_host = DeviceBuffer::from_raw(start_indices).to_host().unwrap().into();
264 JaggedTraceMle::new(host_dense, col_index_host, start_indices_host, column_heights)
265 }
266}
267
268impl<F: Field> TraceDenseData<F, TaskScope> {
269 pub fn into_host(self) -> TraceDenseData<F, CpuBackend> {
270 let host_dense = DeviceBuffer::from_raw(self.dense).to_host().unwrap().into();
271 TraceDenseData {
272 dense: host_dense,
273 preprocessed_offset: self.preprocessed_offset,
274 preprocessed_cols: self.preprocessed_cols,
275 preprocessed_table_index: self.preprocessed_table_index,
276 main_table_index: self.main_table_index,
277 preprocessed_padding: self.preprocessed_padding,
278 main_padding: self.main_padding,
279 }
280 }
281}