1use std::iter::once;
2
3use slop_alloc::{Backend, Buffer, HasBackend};
4use slop_tensor::{Dimensions, Tensor};
5use sp1_gpu_cudart::TaskScope;
6
/// A jagged multilinear extension: dense backing storage plus per-column
/// indexing metadata, where columns may have different heights.
///
/// `repr(C)` pins the declaration-order field layout; the raw mirror types
/// ([`JaggedMleRaw`] / [`JaggedMleMutRaw`]) expose pointer views of the
/// buffer fields (but not `column_heights`, which stays host-side).
#[derive(Clone, Debug)]
#[repr(C)]
pub struct JaggedMle<D: DenseData<A>, A: Backend> {
    // Column index per dense entry.
    // NOTE(review): exact semantics not visible in this file — confirm
    // against the kernels that consume the raw view.
    pub col_index: Buffer<u32, A>,
    // Per-column start offsets into the dense data; appears to be a
    // 0-prefixed running sum of column heights, judging by
    // `next_start_indices_and_column_heights` below.
    pub start_indices: Buffer<u32, A>,
    // Height (entry count) of each column. Host-side metadata only: it is
    // not included in `as_raw` / `as_mut_raw`.
    pub column_heights: Vec<u32>,
    // The dense backing storage.
    pub dense_data: D,
}
18
/// A non-owning view over tensor storage: a raw data pointer plus the
/// dimension sizes and a backend handle.
///
/// NOTE(review): there is no lifetime tying this view to its source, so the
/// caller must keep the underlying storage alive while the view is in use
/// (see [`VirtualTensor::from_tensor`]) — confirm usage at call sites.
pub struct VirtualTensor<T, B: Backend> {
    /// Pointer to the first element of the viewed data.
    pub data: *const T,
    /// Dimension sizes of the viewed data.
    pub sizes: Dimensions,
    /// Backend the viewed data lives on.
    pub backend: B,
}
24
25impl<T, B: Backend> VirtualTensor<T, B> {
26 pub fn new(data: *const T, sizes: Dimensions, backend: B) -> Self {
27 Self { data, sizes, backend }
28 }
29
30 pub fn sizes(&self) -> &[usize] {
31 self.sizes.sizes()
32 }
33
34 pub fn backend(&self) -> &B {
35 &self.backend
36 }
37
38 pub fn as_ptr(&self) -> *const T {
39 self.data
40 }
41
42 pub fn from_tensor(tensor: &Tensor<T, B>) -> Self {
43 Self {
44 data: tensor.as_ptr(),
45 sizes: tensor.shape().clone(),
46 backend: tensor.backend().clone(),
47 }
48 }
49}
50
/// Read-only access to a dense backing store in a raw, pointer-like form.
pub trait DenseData<A: Backend> {
    /// The raw representation of the dense data.
    /// NOTE(review): the concrete type is chosen by implementors elsewhere —
    /// presumably a raw pointer, given the `*const u32` siblings in
    /// `JaggedMleRaw`; confirm per implementation.
    type DenseDataRaw;
    /// Returns the raw (const) view of the dense data.
    fn as_ptr(&self) -> Self::DenseDataRaw;
}
55
/// Mutable counterpart of [`DenseData`]: raw, writable access to the dense
/// backing store.
pub trait DenseDataMut<A: Backend>: DenseData<A> {
    /// The raw mutable representation of the dense data.
    type DenseDataMutRaw;
    /// Returns the raw mutable view of the dense data.
    fn as_mut_ptr(&mut self) -> Self::DenseDataMutRaw;
}
60
/// Borrowed const raw view of a [`JaggedMle`]: pointers into the index
/// buffers plus the dense data's raw form.
///
/// `repr(C)` gives a stable, declaration-order layout — presumably so the
/// struct can be handed to GPU kernels; confirm against call sites.
/// `column_heights` is deliberately absent (see `JaggedMle::as_raw`).
#[repr(C)]
pub struct JaggedMleRaw<D: DenseData<A>, A: Backend> {
    col_index: *const u32,
    start_indices: *const u32,
    dense_data: D::DenseDataRaw,
}
68
/// Borrowed mutable raw view of a [`JaggedMle`]; the writable counterpart of
/// [`JaggedMleRaw`], with the same `repr(C)` declaration-order layout and the
/// same omission of the host-side `column_heights`.
#[repr(C)]
pub struct JaggedMleMutRaw<D: DenseDataMut<A>, A: Backend> {
    col_index: *mut u32,
    start_indices: *mut u32,
    dense_data: D::DenseDataMutRaw,
}
76
77impl<D: DenseData<A>, A: Backend> JaggedMle<D, A> {
78 pub fn as_raw(&self) -> JaggedMleRaw<D, A> {
79 JaggedMleRaw {
80 col_index: self.col_index.as_ptr(),
81 start_indices: self.start_indices.as_ptr(),
82 dense_data: self.dense_data.as_ptr(),
83 }
84 }
85
86 pub fn as_mut_raw(&mut self) -> JaggedMleMutRaw<D, A>
87 where
88 D: DenseDataMut<A>,
89 {
90 JaggedMleMutRaw {
91 col_index: self.col_index.as_mut_ptr(),
92 start_indices: self.start_indices.as_mut_ptr(),
93 dense_data: self.dense_data.as_mut_ptr(),
94 }
95 }
96
97 pub fn new(
98 dense_data: D,
99 col_index: Buffer<u32, A>,
100 start_indices: Buffer<u32, A>,
101 column_heights: Vec<u32>,
102 ) -> Self {
103 Self { dense_data, col_index, start_indices, column_heights }
104 }
105
106 pub fn dense(&self) -> &D {
107 &self.dense_data
108 }
109
110 pub fn dense_mut(&mut self) -> &mut D {
111 &mut self.dense_data
112 }
113
114 pub fn col_index(&self) -> &Buffer<u32, A> {
115 &self.col_index
116 }
117
118 pub fn col_index_mut(&mut self) -> &mut Buffer<u32, A> {
119 &mut self.col_index
120 }
121
122 pub fn start_indices(&self) -> &Buffer<u32, A> {
123 &self.start_indices
124 }
125
126 pub fn start_indices_mut(&mut self) -> &mut Buffer<u32, A> {
127 &mut self.start_indices
128 }
129
130 pub fn into_parts(self) -> (D, Buffer<u32, A>, Buffer<u32, A>) {
131 (self.dense_data, self.col_index, self.start_indices)
132 }
133}
134
135impl<D: DenseData<TaskScope>> JaggedMle<D, TaskScope> {
136 pub fn next_start_indices_and_column_heights(&self) -> (Buffer<u32>, Vec<u32>) {
140 let output_heights =
141 self.column_heights.iter().map(|height| height.div_ceil(4) * 2).collect::<Vec<u32>>();
142
143 let new_start_idx = once(0)
144 .chain(output_heights.iter().scan(0u32, |acc, x| {
145 *acc += x;
146 Some(*acc)
147 }))
148 .collect::<Vec<_>>();
149 let buffer_start_idx = Buffer::from(new_start_idx);
150 (buffer_start_idx, output_heights)
151 }
152}
153
impl<D: DenseData<A>, A: Backend> HasBackend for JaggedMle<D, A> {
    type Backend = A;
    /// Returns the backend of the `col_index` buffer, which stands in for
    /// the backend of the whole structure.
    fn backend(&self) -> &A {
        self.col_index.backend()
    }
}