1use std::collections::BTreeMap;
2use std::ops::{Deref, DerefMut, Range};
3
4use slop_algebra::Field;
5use slop_alloc::{Backend, Buffer, CpuBackend, HasBackend};
6use slop_tensor::{Dimensions, Tensor, TensorView};
7use sp1_gpu_cudart::{DeviceBuffer, TaskScope};
8
9use crate::jagged::JaggedMle;
10use crate::{DenseData, DenseDataMut};
11
/// Location and shape of one table inside the flat dense buffer.
///
/// As populated by `TraceDenseData::from_chip_layout`, `dense_offset` spans
/// `poly_size * num_polys` elements.
#[derive(Clone, Debug)]
pub struct TraceOffset {
    /// Half-open range of element indices the table occupies in the dense buffer.
    pub dense_offset: Range<usize>,
    /// Height of each polynomial (column) of the table.
    pub poly_size: usize,
    /// Number of polynomials (columns) in the table.
    pub num_polys: usize,
}
21
/// Newtype around a [`JaggedMle`] whose dense payload is a [`TraceDenseData`].
///
/// The wrapper dereferences to the inner [`JaggedMle`], so its fields and methods are
/// reachable directly on a `JaggedTraceMle`.
#[derive(Clone)]
pub struct JaggedTraceMle<F: Field, B: Backend>(pub JaggedMle<TraceDenseData<F, B>, B>);
24
25impl<F: Field, B: Backend> HasBackend for JaggedTraceMle<F, B> {
26 type Backend = B;
27 fn backend(&self) -> &B {
28 self.0.backend()
29 }
30}
31
32impl<F: Field, B: Backend> Deref for JaggedTraceMle<F, B> {
33 type Target = JaggedMle<TraceDenseData<F, B>, B>;
34
35 fn deref(&self) -> &Self::Target {
36 &self.0
37 }
38}
39
40impl<F: Field, B: Backend> DerefMut for JaggedTraceMle<F, B> {
41 fn deref_mut(&mut self) -> &mut Self::Target {
42 &mut self.0
43 }
44}
45
/// Flat dense trace buffer plus the bookkeeping needed to locate tables inside it.
///
/// The buffer holds a preprocessed region followed by a main region; the main region
/// starts at element index `preprocessed_offset`. Field meanings below describe how
/// `from_chip_layout` populates them.
#[derive(Clone, Debug)]
pub struct TraceDenseData<F: Field, B: Backend> {
    /// Backing storage for both the preprocessed and the main region.
    pub dense: Buffer<F, B>,
    /// Element index at which the main region starts (the padded preprocessed size).
    pub preprocessed_offset: usize,
    /// Total number of preprocessed columns across all chips.
    pub preprocessed_cols: usize,
    /// Number of padding elements appended to the preprocessed region.
    pub preprocessed_padding: usize,
    /// Number of padding elements appended to the main region.
    pub main_padding: usize,
    /// Per-chip location of the preprocessed table, keyed by chip name.
    pub preprocessed_table_index: BTreeMap<String, TraceOffset>,
    /// Per-chip location of the main table, keyed by chip name.
    pub main_table_index: BTreeMap<String, TraceOffset>,
}
64
65impl<F: Field, B: Backend> TraceDenseData<F, B> {
66 pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
67 let ptr = unsafe { self.dense.as_ptr().add(self.preprocessed_offset) };
68 let sizes = Dimensions::try_from([
69 self.main_size() / (1 << log_stacking_height),
70 1 << log_stacking_height,
71 ])
72 .unwrap();
73 unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
75 }
76
77 pub fn main_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
79 let mut tensor = Tensor::with_sizes_in(
80 [self.main_size() / (1 << log_stacking_height), 1 << log_stacking_height],
81 self.backend().clone(),
82 );
83 let backend = self.dense.backend();
84 unsafe {
85 tensor.assume_init();
86 tensor
87 .as_mut_buffer()
88 .copy_from_slice(&self.dense[self.preprocessed_offset..], backend)
89 .unwrap();
90 }
91 tensor
92 }
93
94 pub fn preprocessed_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
95 let ptr = self.dense.as_ptr();
96 let sizes = Dimensions::try_from([
97 self.preprocessed_offset / (1 << log_stacking_height),
98 1 << log_stacking_height,
99 ])
100 .unwrap();
101 unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
102 }
103
104 pub fn preprocessed_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
106 let mut tensor = Tensor::with_sizes_in(
107 [self.preprocessed_offset / (1 << log_stacking_height), 1 << log_stacking_height],
108 self.backend().clone(),
109 );
110 let backend = self.dense.backend();
111 unsafe {
112 tensor.assume_init();
113 tensor
114 .as_mut_buffer()
115 .copy_from_slice(&self.dense[..self.preprocessed_offset], backend)
116 .unwrap();
117 }
118 tensor
119 }
120
121 #[inline]
123 pub fn main_poly_height(&self, name: &str) -> Option<usize> {
124 self.main_table_index.get(name).map(|offset| offset.poly_size)
125 }
126
127 #[inline]
129 pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
130 self.preprocessed_table_index.get(name).map(|offset| offset.poly_size)
131 }
132
133 #[inline]
135 pub fn main_num_polys(&self, name: &str) -> Option<usize> {
136 self.main_table_index.get(name).map(|offset| offset.num_polys)
137 }
138
139 #[inline]
141 pub fn main_size(&self) -> usize {
142 self.dense.len() - self.preprocessed_offset
143 }
144
145 #[inline]
147 pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
148 self.preprocessed_table_index.get(name).map(|offset| offset.num_polys)
149 }
150}
151
/// Ordered list of chip descriptions, one `(name, usize, usize)` tuple per chip.
///
/// NOTE(review): the two numbers are presumably the preprocessed and main column counts,
/// mirroring `AbstractChipLayoutWithHeights` without the height; confirm against callers.
///
/// Derives `Clone`, `Debug`, `PartialEq` and `Eq` so layouts can be logged, copied and
/// compared in tests.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AbstractChipLayout(Vec<(String, usize, usize)>);

impl AbstractChipLayout {
    /// Wraps `entries` as a layout, preserving their order.
    pub fn new(entries: Vec<(String, usize, usize)>) -> Self {
        Self(entries)
    }

    /// Returns the layout entries in the order they were supplied.
    pub fn entries(&self) -> &[(String, usize, usize)] {
        &self.0
    }
}
165
/// Ordered list of chip descriptions with heights.
///
/// Each entry is `(name, preprocessed width, main width, height)`, which is how the
/// `from_chip_layout` constructors in this file interpret the tuple.
///
/// Derives `Clone`, `Debug`, `PartialEq` and `Eq` so layouts can be logged, copied and
/// compared in tests.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AbstractChipLayoutWithHeights(Vec<(String, usize, usize, usize)>);

impl AbstractChipLayoutWithHeights {
    /// Wraps `entries` as a layout, preserving their order.
    pub fn new(entries: Vec<(String, usize, usize, usize)>) -> Self {
        Self(entries)
    }

    /// Returns the layout entries in the order they were supplied.
    pub fn entries(&self) -> &[(String, usize, usize, usize)] {
        &self.0
    }

    /// Iterates over the chip names in layout order.
    pub fn chip_names(&self) -> impl Iterator<Item = &str> {
        self.0.iter().map(|(name, _, _, _)| name.as_str())
    }
}
184
185impl<F: Field> TraceDenseData<F, CpuBackend> {
186 pub fn from_chip_layout(
193 dense: Buffer<F, CpuBackend>,
194 layout: &AbstractChipLayoutWithHeights,
195 log_stacking_height: u32,
196 ) -> Self {
197 let stacking = 1usize << log_stacking_height;
198
199 let total_preprocessed: usize = layout.0.iter().map(|(_, p, _, h)| p * h).sum();
200 let total_main: usize = layout.0.iter().map(|(_, _, m, h)| m * h).sum();
201
202 let padded_preprocessed = total_preprocessed.next_multiple_of(stacking).max(stacking);
204 let padded_main = total_main.next_multiple_of(stacking).max(stacking);
205
206 assert_eq!(
207 dense.len(),
208 padded_preprocessed + padded_main,
209 "dense buffer length must equal padded_preprocessed + padded_main",
210 );
211
212 let preprocessed_cols: usize = layout.0.iter().map(|(_, p, _, _)| p).sum();
213
214 let mut preprocessed_table_index = BTreeMap::new();
215 let mut main_table_index = BTreeMap::new();
216 let mut preprocessed_ptr = 0usize;
217 let mut main_ptr = padded_preprocessed;
218 for (name, prep_w, main_w, h) in layout.0.iter() {
219 let prep_lo = preprocessed_ptr;
220 let prep_hi = prep_lo + h * prep_w;
221 preprocessed_table_index.insert(
222 name.clone(),
223 TraceOffset { dense_offset: prep_lo..prep_hi, poly_size: *h, num_polys: *prep_w },
224 );
225 preprocessed_ptr = prep_hi;
226
227 let main_lo = main_ptr;
228 let main_hi = main_lo + h * main_w;
229 main_table_index.insert(
230 name.clone(),
231 TraceOffset { dense_offset: main_lo..main_hi, poly_size: *h, num_polys: *main_w },
232 );
233 main_ptr = main_hi;
234 }
235
236 TraceDenseData {
237 dense,
238 preprocessed_offset: padded_preprocessed,
239 preprocessed_cols,
240 preprocessed_padding: padded_preprocessed - total_preprocessed,
241 main_padding: padded_main - total_main,
242 preprocessed_table_index,
243 main_table_index,
244 }
245 }
246}
247
248impl<F: Field, B: Backend> HasBackend for TraceDenseData<F, B> {
249 type Backend = B;
250 fn backend(&self) -> &B {
251 self.dense.backend()
252 }
253}
254
255impl<F: Field, B: Backend> JaggedTraceMle<F, B> {
256 pub fn new(
257 dense_data: TraceDenseData<F, B>,
258 col_index: Buffer<u32, B>,
259 start_indices: Buffer<u32, B>,
260 column_heights: Vec<u32>,
261 ) -> Self {
262 JaggedTraceMle(JaggedMle::new(dense_data, col_index, start_indices, column_heights))
263 }
264}
265
266impl<F: Field> JaggedTraceMle<F, CpuBackend> {
267 pub fn from_chip_layout(
276 dense: Buffer<F, CpuBackend>,
277 layout: &AbstractChipLayoutWithHeights,
278 log_stacking_height: u32,
279 ) -> Self {
280 assert!(layout.0.iter().all(|(_, _, _, h)| h % 2 == 0), "heights must be even");
281
282 let dense_data = TraceDenseData::from_chip_layout(dense, layout, log_stacking_height);
283
284 let total_dense = dense_data.dense.len();
285 let preprocessed_padding = dense_data.preprocessed_padding;
286 let main_padding = dense_data.main_padding;
287
288 let num_data_cols: usize = layout.0.iter().map(|(_, p, m, _)| p + m).sum();
289 let num_cols =
290 num_data_cols + (preprocessed_padding > 0) as usize + (main_padding > 0) as usize;
291
292 let mut col_index = vec![0u32; total_dense / 2];
293 let mut start_idx = vec![0u32; num_cols + 1];
294 let mut column_heights: Vec<u32> = Vec::with_capacity(num_cols);
295
296 let mut col: u32 = 0;
297 let mut cnt: usize = 0;
298
299 let mut emit = |w: usize, h: usize, col: &mut u32, cnt: &mut usize| {
300 let half = h / 2;
301 for _ in 0..w {
302 col_index[*cnt..*cnt + half].fill(*col);
303 *cnt += half;
304 start_idx[*col as usize + 1] = start_idx[*col as usize] + half as u32;
305 column_heights.push(half as u32);
306 *col += 1;
307 }
308 };
309
310 for (_, prep_w, _, h) in layout.0.iter() {
311 emit(*prep_w, *h, &mut col, &mut cnt);
312 }
313 if preprocessed_padding > 0 {
314 emit(1, preprocessed_padding, &mut col, &mut cnt);
315 }
316 for (_, _, main_w, h) in layout.0.iter() {
317 emit(*main_w, *h, &mut col, &mut cnt);
318 }
319 if main_padding > 0 {
320 emit(1, main_padding, &mut col, &mut cnt);
321 }
322
323 debug_assert_eq!(cnt, total_dense / 2);
324 debug_assert_eq!(col as usize, num_cols);
325
326 Self::new(dense_data, Buffer::from(col_index), Buffer::from(start_idx), column_heights)
327 }
328}
329
330impl<F: Field> JaggedTraceMle<F, TaskScope> {
331 pub fn preprocessed_virtual_tensor(
332 &'_ self,
333 log_stacking_height: u32,
334 ) -> TensorView<'_, F, TaskScope> {
335 self.dense_data.preprocessed_virtual_tensor(log_stacking_height)
336 }
337
338 pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, TaskScope> {
339 self.dense_data.main_virtual_tensor(log_stacking_height)
340 }
341
342 pub fn main_poly_height(&self, name: &str) -> Option<usize> {
343 self.dense_data.main_poly_height(name)
344 }
345
346 pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
347 self.dense_data.preprocessed_poly_height(name)
348 }
349
350 pub fn main_num_polys(&self, name: &str) -> Option<usize> {
351 self.dense_data.main_num_polys(name)
352 }
353
354 pub fn main_size(&self) -> usize {
355 self.dense_data.main_size()
356 }
357
358 pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
359 self.dense_data.preprocessed_num_polys(name)
360 }
361}
362
/// Read-only raw view of a [`TraceDenseData`], as returned by its `DenseData::as_ptr`.
///
/// `#[repr(C)]` fixes the layout to a single pointer.
/// NOTE(review): presumably passed across an FFI/kernel boundary; confirm against callers.
#[repr(C)]
pub struct TraceDenseDataRaw<F> {
    /// Pointer to the first element of the dense buffer.
    dense: *const F,
}
368
/// Mutable raw view of a [`TraceDenseData`], as returned by its `DenseDataMut::as_mut_ptr`.
///
/// `#[repr(C)]` fixes the layout to a single pointer.
/// NOTE(review): presumably passed across an FFI/kernel boundary; confirm against callers.
#[repr(C)]
pub struct TraceDenseDataMutRaw<F> {
    /// Mutable pointer to the first element of the dense buffer.
    dense: *mut F,
}
374
375impl<F: Field, B: Backend> DenseData<B> for TraceDenseData<F, B> {
376 type DenseDataRaw = TraceDenseDataRaw<F>;
377
378 fn as_ptr(&self) -> TraceDenseDataRaw<F> {
379 TraceDenseDataRaw { dense: self.dense.as_ptr() }
380 }
381}
382
383impl<F: Field, B: Backend> DenseDataMut<B> for TraceDenseData<F, B> {
384 type DenseDataMutRaw = TraceDenseDataMutRaw<F>;
385
386 fn as_mut_ptr(&mut self) -> TraceDenseDataMutRaw<F> {
387 TraceDenseDataMutRaw { dense: self.dense.as_mut_ptr() }
388 }
389}
390
391impl<F: Field> JaggedTraceMle<F, CpuBackend> {
392 pub fn into_device(self, t: &TaskScope) -> JaggedTraceMle<F, TaskScope> {
393 let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
394 JaggedTraceMle::new(
395 dense_data.into_device_in(t),
396 DeviceBuffer::from_host(&col_index, t).unwrap().into_inner(),
397 DeviceBuffer::from_host(&start_indices, t).unwrap().into_inner(),
398 column_heights,
399 )
400 }
401}
402
403impl<F: Field> TraceDenseData<F, CpuBackend> {
404 pub fn into_device_in(self, t: &TaskScope) -> TraceDenseData<F, TaskScope> {
405 TraceDenseData {
406 dense: DeviceBuffer::from_host(&self.dense, t).unwrap().into_inner(),
407 preprocessed_offset: self.preprocessed_offset,
408 preprocessed_cols: self.preprocessed_cols,
409 preprocessed_table_index: self.preprocessed_table_index,
410 main_table_index: self.main_table_index,
411 preprocessed_padding: self.preprocessed_padding,
412 main_padding: self.main_padding,
413 }
414 }
415}
416
417impl<F: Field> JaggedTraceMle<F, TaskScope> {
418 pub fn into_host(self) -> JaggedTraceMle<F, CpuBackend> {
419 let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
420 let host_dense = dense_data.into_host();
421 let col_index_host = DeviceBuffer::from_raw(col_index).to_host().unwrap().into();
423 let start_indices_host = DeviceBuffer::from_raw(start_indices).to_host().unwrap().into();
424 JaggedTraceMle::new(host_dense, col_index_host, start_indices_host, column_heights)
425 }
426}
427
428impl<F: Field> TraceDenseData<F, TaskScope> {
429 pub fn into_host(self) -> TraceDenseData<F, CpuBackend> {
430 let host_dense = DeviceBuffer::from_raw(self.dense).to_host().unwrap().into();
431 TraceDenseData {
432 dense: host_dense,
433 preprocessed_offset: self.preprocessed_offset,
434 preprocessed_cols: self.preprocessed_cols,
435 preprocessed_table_index: self.preprocessed_table_index,
436 main_table_index: self.main_table_index,
437 preprocessed_padding: self.preprocessed_padding,
438 main_padding: self.main_padding,
439 }
440 }
441}