1use std::collections::BTreeMap;
2use std::ops::{Deref, DerefMut, Range};
3
4use slop_algebra::Field;
5use slop_alloc::{Backend, Buffer, CpuBackend, HasBackend};
6use slop_tensor::{Dimensions, Tensor, TensorView};
7use sp1_gpu_cudart::{DeviceBuffer, TaskScope};
8
9use crate::jagged::JaggedMle;
10use crate::{DenseData, DenseDataMut};
11
/// Location and shape of one chip's columns inside a [`TraceDenseData`] dense buffer.
#[derive(Clone, Debug)]
pub struct TraceOffset {
    /// Element range of `dense` occupied by this chip's columns in one region. For main-region
    /// entries built by `from_chip_layout` the range is absolute: it already includes the
    /// length of the (padded) preprocessed region that precedes it.
    pub dense_offset: Range<usize>,
    /// Height of every column of this chip (number of elements per column).
    pub poly_size: usize,
    /// Number of columns this chip contributes to the region.
    pub num_polys: usize,
}
21
/// A jagged MLE whose dense storage is a [`TraceDenseData`] (preprocessed + main trace regions).
///
/// Thin newtype over `JaggedMle`; it dereferences to the inner `JaggedMle`, so that type's
/// fields (e.g. `dense_data`) and methods are reachable directly.
#[derive(Clone)]
pub struct JaggedTraceMle<F: Field, B: Backend>(pub JaggedMle<TraceDenseData<F, B>, B>);
24
25impl<F: Field, B: Backend> HasBackend for JaggedTraceMle<F, B> {
26 type Backend = B;
27 fn backend(&self) -> &B {
28 self.0.backend()
29 }
30}
31
32impl<F: Field, B: Backend> Deref for JaggedTraceMle<F, B> {
33 type Target = JaggedMle<TraceDenseData<F, B>, B>;
34
35 fn deref(&self) -> &Self::Target {
36 &self.0
37 }
38}
39
40impl<F: Field, B: Backend> DerefMut for JaggedTraceMle<F, B> {
41 fn deref_mut(&mut self) -> &mut Self::Target {
42 &mut self.0
43 }
44}
45
/// Dense storage for a jagged trace.
///
/// It is one flat buffer holding a preprocessed region followed by a main region, together
/// with per-chip lookup tables into both regions.
#[derive(Clone, Debug)]
pub struct TraceDenseData<F: Field, B: Backend> {
    /// Flat buffer. `[0, preprocessed_offset)` is the preprocessed region and
    /// `[preprocessed_offset, dense.len())` is the main region.
    pub dense: Buffer<F, B>,
    /// Start of the main region in `dense`, i.e. the padded length of the preprocessed region.
    pub preprocessed_offset: usize,
    /// Total number of preprocessed columns over all chips, not counting any padding column.
    pub preprocessed_cols: usize,
    /// Number of padding elements appended after the preprocessed chip data.
    pub preprocessed_padding: usize,
    /// Number of padding elements appended after the main chip data.
    pub main_padding: usize,
    /// Number of padding columns in the preprocessed region. `from_chip_layout` sets this to
    /// 1 when `preprocessed_padding > 0` and to 0 otherwise.
    pub prep_padding_col_count: usize,
    /// Number of padding columns in the main region. `from_chip_layout` sets this to 1 when
    /// `main_padding > 0` and to 0 otherwise.
    pub main_padding_col_count: usize,
    /// Chip name -> location of that chip's preprocessed columns in `dense`.
    pub preprocessed_table_index: BTreeMap<String, TraceOffset>,
    /// Chip name -> location of that chip's main columns in `dense`. As built by
    /// `from_chip_layout`, these are absolute offsets that already include `preprocessed_offset`.
    pub main_table_index: BTreeMap<String, TraceOffset>,
}
76
77impl<F: Field, B: Backend> TraceDenseData<F, B> {
78 pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
79 let ptr = unsafe { self.dense.as_ptr().add(self.preprocessed_offset) };
80 let sizes = Dimensions::try_from([
81 self.main_size() / (1 << log_stacking_height),
82 1 << log_stacking_height,
83 ])
84 .unwrap();
85 unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
87 }
88
89 pub fn main_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
91 let mut tensor = Tensor::with_sizes_in(
92 [self.main_size() / (1 << log_stacking_height), 1 << log_stacking_height],
93 self.backend().clone(),
94 );
95 let backend = self.dense.backend();
96 unsafe {
97 tensor.assume_init();
98 tensor
99 .as_mut_buffer()
100 .copy_from_slice(&self.dense[self.preprocessed_offset..], backend)
101 .unwrap();
102 }
103 tensor
104 }
105
106 pub fn preprocessed_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
107 let ptr = self.dense.as_ptr();
108 let sizes = Dimensions::try_from([
109 self.preprocessed_offset / (1 << log_stacking_height),
110 1 << log_stacking_height,
111 ])
112 .unwrap();
113 unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
114 }
115
116 pub fn preprocessed_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
118 let mut tensor = Tensor::with_sizes_in(
119 [self.preprocessed_offset / (1 << log_stacking_height), 1 << log_stacking_height],
120 self.backend().clone(),
121 );
122 let backend = self.dense.backend();
123 unsafe {
124 tensor.assume_init();
125 tensor
126 .as_mut_buffer()
127 .copy_from_slice(&self.dense[..self.preprocessed_offset], backend)
128 .unwrap();
129 }
130 tensor
131 }
132
133 #[inline]
135 pub fn main_poly_height(&self, name: &str) -> Option<usize> {
136 self.main_table_index.get(name).map(|offset| offset.poly_size)
137 }
138
139 #[inline]
141 pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
142 self.preprocessed_table_index.get(name).map(|offset| offset.poly_size)
143 }
144
145 #[inline]
147 pub fn main_num_polys(&self, name: &str) -> Option<usize> {
148 self.main_table_index.get(name).map(|offset| offset.num_polys)
149 }
150
151 #[inline]
153 pub fn main_size(&self) -> usize {
154 self.dense.len() - self.preprocessed_offset
155 }
156
157 #[inline]
159 pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
160 self.preprocessed_table_index.get(name).map(|offset| offset.num_polys)
161 }
162}
163
/// An ordered list of per-chip layout entries without row heights.
///
/// NOTE(review): nothing in this file reads the tuple fields. They presumably mirror
/// [`AbstractChipLayoutWithHeights`] without the height, i.e. `(name, preprocessed_width,
/// main_width)` — confirm against callers.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AbstractChipLayout(Vec<(String, usize, usize)>);

impl AbstractChipLayout {
    /// Wraps `entries` as a layout, preserving their order.
    pub fn new(entries: Vec<(String, usize, usize)>) -> Self {
        Self(entries)
    }

    /// Returns the layout entries in the order they were given.
    pub fn entries(&self) -> &[(String, usize, usize)] {
        &self.0
    }
}
177
/// An ordered list of per-chip layout entries including row heights.
///
/// Each entry is read by the `from_chip_layout` constructors as
/// `(name, preprocessed_width, main_width, height)`. Entry order determines the order in
/// which chips are packed into the dense buffer. `JaggedTraceMle::from_chip_layout`
/// additionally requires every `height` to be even.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AbstractChipLayoutWithHeights(Vec<(String, usize, usize, usize)>);

impl AbstractChipLayoutWithHeights {
    /// Wraps `entries` as a layout, preserving their order.
    pub fn new(entries: Vec<(String, usize, usize, usize)>) -> Self {
        Self(entries)
    }

    /// Returns the layout entries in the order they were given.
    pub fn entries(&self) -> &[(String, usize, usize, usize)] {
        &self.0
    }

    /// Iterates over the chip names in layout order.
    pub fn chip_names(&self) -> impl Iterator<Item = &str> {
        self.0.iter().map(|(name, _, _, _)| name.as_str())
    }
}
196
197impl<F: Field> TraceDenseData<F, CpuBackend> {
198 pub fn from_chip_layout(
205 dense: Buffer<F, CpuBackend>,
206 layout: &AbstractChipLayoutWithHeights,
207 log_stacking_height: u32,
208 ) -> Self {
209 let stacking = 1usize << log_stacking_height;
210
211 let total_preprocessed: usize = layout.0.iter().map(|(_, p, _, h)| p * h).sum();
212 let total_main: usize = layout.0.iter().map(|(_, _, m, h)| m * h).sum();
213
214 let padded_preprocessed = total_preprocessed.next_multiple_of(stacking).max(stacking);
216 let padded_main = total_main.next_multiple_of(stacking).max(stacking);
217
218 assert_eq!(
219 dense.len(),
220 padded_preprocessed + padded_main,
221 "dense buffer length must equal padded_preprocessed + padded_main",
222 );
223
224 let preprocessed_cols: usize = layout.0.iter().map(|(_, p, _, _)| p).sum();
225
226 let mut preprocessed_table_index = BTreeMap::new();
227 let mut main_table_index = BTreeMap::new();
228 let mut preprocessed_ptr = 0usize;
229 let mut main_ptr = padded_preprocessed;
230 for (name, prep_w, main_w, h) in layout.0.iter() {
231 let prep_lo = preprocessed_ptr;
232 let prep_hi = prep_lo + h * prep_w;
233 preprocessed_table_index.insert(
234 name.clone(),
235 TraceOffset { dense_offset: prep_lo..prep_hi, poly_size: *h, num_polys: *prep_w },
236 );
237 preprocessed_ptr = prep_hi;
238
239 let main_lo = main_ptr;
240 let main_hi = main_lo + h * main_w;
241 main_table_index.insert(
242 name.clone(),
243 TraceOffset { dense_offset: main_lo..main_hi, poly_size: *h, num_polys: *main_w },
244 );
245 main_ptr = main_hi;
246 }
247
248 let preprocessed_padding = padded_preprocessed - total_preprocessed;
249 let main_padding = padded_main - total_main;
250 TraceDenseData {
251 dense,
252 preprocessed_offset: padded_preprocessed,
253 preprocessed_cols,
254 preprocessed_padding,
255 main_padding,
256 prep_padding_col_count: (preprocessed_padding > 0) as usize,
260 main_padding_col_count: (main_padding > 0) as usize,
261 preprocessed_table_index,
262 main_table_index,
263 }
264 }
265}
266
267impl<F: Field, B: Backend> HasBackend for TraceDenseData<F, B> {
268 type Backend = B;
269 fn backend(&self) -> &B {
270 self.dense.backend()
271 }
272}
273
274impl<F: Field, B: Backend> JaggedTraceMle<F, B> {
275 pub fn new(
276 dense_data: TraceDenseData<F, B>,
277 col_index: Buffer<u32, B>,
278 start_indices: Buffer<u32, B>,
279 column_heights: Buffer<u32, B>,
280 ) -> Self {
281 JaggedTraceMle(JaggedMle::new(dense_data, col_index, start_indices, column_heights))
282 }
283}
284
285impl<F: Field> JaggedTraceMle<F, CpuBackend> {
286 pub fn from_chip_layout(
295 dense: Buffer<F, CpuBackend>,
296 layout: &AbstractChipLayoutWithHeights,
297 log_stacking_height: u32,
298 ) -> Self {
299 assert!(layout.0.iter().all(|(_, _, _, h)| h % 2 == 0), "heights must be even");
300
301 let dense_data = TraceDenseData::from_chip_layout(dense, layout, log_stacking_height);
302
303 let total_dense = dense_data.dense.len();
304 let preprocessed_padding = dense_data.preprocessed_padding;
305 let main_padding = dense_data.main_padding;
306
307 let num_data_cols: usize = layout.0.iter().map(|(_, p, m, _)| p + m).sum();
308 let num_cols =
309 num_data_cols + (preprocessed_padding > 0) as usize + (main_padding > 0) as usize;
310
311 let mut col_index = vec![0u32; total_dense / 2];
312 let mut start_idx = vec![0u32; num_cols + 1];
313 let mut column_heights: Vec<u32> = Vec::with_capacity(num_cols);
314
315 let mut col: u32 = 0;
316 let mut cnt: usize = 0;
317
318 let mut emit = |w: usize, h: usize, col: &mut u32, cnt: &mut usize| {
319 let half = h / 2;
320 for _ in 0..w {
321 col_index[*cnt..*cnt + half].fill(*col);
322 *cnt += half;
323 start_idx[*col as usize + 1] = start_idx[*col as usize] + half as u32;
324 column_heights.push(half as u32);
325 *col += 1;
326 }
327 };
328
329 for (_, prep_w, _, h) in layout.0.iter() {
330 emit(*prep_w, *h, &mut col, &mut cnt);
331 }
332 if preprocessed_padding > 0 {
333 emit(1, preprocessed_padding, &mut col, &mut cnt);
334 }
335 for (_, _, main_w, h) in layout.0.iter() {
336 emit(*main_w, *h, &mut col, &mut cnt);
337 }
338 if main_padding > 0 {
339 emit(1, main_padding, &mut col, &mut cnt);
340 }
341
342 debug_assert_eq!(cnt, total_dense / 2);
343 debug_assert_eq!(col as usize, num_cols);
344
345 Self::new(
346 dense_data,
347 Buffer::from(col_index),
348 Buffer::from(start_idx),
349 Buffer::from(column_heights),
350 )
351 }
352}
353
354impl<F: Field> JaggedTraceMle<F, TaskScope> {
355 pub fn preprocessed_virtual_tensor(
356 &'_ self,
357 log_stacking_height: u32,
358 ) -> TensorView<'_, F, TaskScope> {
359 self.dense_data.preprocessed_virtual_tensor(log_stacking_height)
360 }
361
362 pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, TaskScope> {
363 self.dense_data.main_virtual_tensor(log_stacking_height)
364 }
365
366 pub fn main_poly_height(&self, name: &str) -> Option<usize> {
367 self.dense_data.main_poly_height(name)
368 }
369
370 pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
371 self.dense_data.preprocessed_poly_height(name)
372 }
373
374 pub fn main_num_polys(&self, name: &str) -> Option<usize> {
375 self.dense_data.main_num_polys(name)
376 }
377
378 pub fn main_size(&self) -> usize {
379 self.dense_data.main_size()
380 }
381
382 pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
383 self.dense_data.preprocessed_num_polys(name)
384 }
385}
386
/// `#[repr(C)]` raw, read-only pointer view of a [`TraceDenseData`]'s dense buffer.
///
/// This is produced by `DenseData::as_ptr`. Only the buffer's base pointer is carried; region
/// offsets and table indices are not.
/// NOTE(review): the C layout suggests this crosses an FFI/kernel boundary — confirm before
/// changing fields.
#[repr(C)]
pub struct TraceDenseDataRaw<F> {
    /// Base pointer of `TraceDenseData::dense`.
    dense: *const F,
}
392
/// `#[repr(C)]` raw, mutable pointer view of a [`TraceDenseData`]'s dense buffer.
///
/// This is produced by `DenseDataMut::as_mut_ptr`. Only the buffer's base pointer is carried.
/// NOTE(review): the C layout suggests this crosses an FFI/kernel boundary — confirm before
/// changing fields.
#[repr(C)]
pub struct TraceDenseDataMutRaw<F> {
    /// Mutable base pointer of `TraceDenseData::dense`.
    dense: *mut F,
}
398
399impl<F: Field, B: Backend> DenseData<B> for TraceDenseData<F, B> {
400 type DenseDataRaw = TraceDenseDataRaw<F>;
401
402 fn as_ptr(&self) -> TraceDenseDataRaw<F> {
403 TraceDenseDataRaw { dense: self.dense.as_ptr() }
404 }
405}
406
407impl<F: Field, B: Backend> DenseDataMut<B> for TraceDenseData<F, B> {
408 type DenseDataMutRaw = TraceDenseDataMutRaw<F>;
409
410 fn as_mut_ptr(&mut self) -> TraceDenseDataMutRaw<F> {
411 TraceDenseDataMutRaw { dense: self.dense.as_mut_ptr() }
412 }
413}
414
415impl<F: Field> JaggedTraceMle<F, CpuBackend> {
416 pub fn into_device(self, t: &TaskScope) -> JaggedTraceMle<F, TaskScope> {
417 let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
418 JaggedTraceMle::new(
419 dense_data.into_device_in(t),
420 DeviceBuffer::from_host(&col_index, t).unwrap().into_inner(),
421 DeviceBuffer::from_host(&start_indices, t).unwrap().into_inner(),
422 DeviceBuffer::from_host(&column_heights, t).unwrap().into_inner(),
423 )
424 }
425}
426
427impl<F: Field> TraceDenseData<F, CpuBackend> {
428 pub fn into_device_in(self, t: &TaskScope) -> TraceDenseData<F, TaskScope> {
429 TraceDenseData {
430 dense: DeviceBuffer::from_host(&self.dense, t).unwrap().into_inner(),
431 preprocessed_offset: self.preprocessed_offset,
432 preprocessed_cols: self.preprocessed_cols,
433 preprocessed_table_index: self.preprocessed_table_index,
434 main_table_index: self.main_table_index,
435 preprocessed_padding: self.preprocessed_padding,
436 main_padding: self.main_padding,
437 prep_padding_col_count: self.prep_padding_col_count,
438 main_padding_col_count: self.main_padding_col_count,
439 }
440 }
441}
442
443impl<F: Field> JaggedTraceMle<F, TaskScope> {
444 pub fn into_host(self) -> JaggedTraceMle<F, CpuBackend> {
445 let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
446 let host_dense = dense_data.into_host();
447 let col_index_host = DeviceBuffer::from_raw(col_index).to_host().unwrap().into();
449 let start_indices_host = DeviceBuffer::from_raw(start_indices).to_host().unwrap().into();
450 let column_heights_host = DeviceBuffer::from_raw(column_heights).to_host().unwrap().into();
451 JaggedTraceMle::new(host_dense, col_index_host, start_indices_host, column_heights_host)
452 }
453}
454
455impl<F: Field> TraceDenseData<F, TaskScope> {
456 pub fn into_host(self) -> TraceDenseData<F, CpuBackend> {
457 let host_dense = DeviceBuffer::from_raw(self.dense).to_host().unwrap().into();
458 TraceDenseData {
459 dense: host_dense,
460 preprocessed_offset: self.preprocessed_offset,
461 preprocessed_cols: self.preprocessed_cols,
462 preprocessed_table_index: self.preprocessed_table_index,
463 main_table_index: self.main_table_index,
464 preprocessed_padding: self.preprocessed_padding,
465 main_padding: self.main_padding,
466 prep_padding_col_count: self.prep_padding_col_count,
467 main_padding_col_count: self.main_padding_col_count,
468 }
469 }
470}