Skip to main content

sp1_gpu_utils/
jagged.rs

1use std::iter::once;
2
3use slop_alloc::{Backend, Buffer, HasBackend};
4use slop_tensor::{Dimensions, Tensor};
5use sp1_gpu_cudart::TaskScope;
6
/// A jagged MLE: a set of columns with individual heights whose values live in one dense
/// storage object (`dense_data`), plus the per-column metadata needed to locate them.
///
/// `col_index` and `start_indices` are stored on the backend `A`, while `column_heights` is
/// kept in a host-side `Vec`.
#[derive(Clone, Debug)]
#[repr(C)]
pub struct JaggedMle<D: DenseData<A>, A: Backend> {
    /// col_index[i / 2] is the column that the i'th element of the dense data belongs to.
    pub col_index: Buffer<u32, A>,
    /// start_indices[i] is half of the dense index of the first element of the i'th column.
    pub start_indices: Buffer<u32, A>,
    /// column_heights[i] is half of the height of the i'th column.
    pub column_heights: Vec<u32>,
    /// The dense storage holding the values of all columns.
    pub dense_data: D,
}
18
/// A non-owning description of tensor storage: a raw data pointer, the shape of the data, and
/// the associated backend.
///
/// Nothing ties the lifetime of `data` to the storage it points into. If that storage is freed,
/// the pointer dangles; keeping the storage alive is the caller's responsibility.
pub struct VirtualTensor<T, B: Backend> {
    /// Raw pointer to the data (taken from `Tensor::as_ptr` when built via `from_tensor`).
    pub data: *const T,
    /// Shape of the data.
    pub sizes: Dimensions,
    /// Backend associated with the data.
    pub backend: B,
}
24
25impl<T, B: Backend> VirtualTensor<T, B> {
26    pub fn new(data: *const T, sizes: Dimensions, backend: B) -> Self {
27        Self { data, sizes, backend }
28    }
29
30    pub fn sizes(&self) -> &[usize] {
31        self.sizes.sizes()
32    }
33
34    pub fn backend(&self) -> &B {
35        &self.backend
36    }
37
38    pub fn as_ptr(&self) -> *const T {
39        self.data
40    }
41
42    pub fn from_tensor(tensor: &Tensor<T, B>) -> Self {
43        Self {
44            data: tensor.as_ptr(),
45            sizes: tensor.shape().clone(),
46            backend: tensor.backend().clone(),
47        }
48    }
49}
50
/// Dense storage that can back a [`JaggedMle`] on backend `A`.
///
/// Implementors expose a raw representation (`DenseDataRaw`) that [`JaggedMle::as_raw`]
/// embeds into [`JaggedMleRaw`].
pub trait DenseData<A: Backend> {
    /// Raw, read-only representation of the dense data.
    type DenseDataRaw;
    /// Returns the raw representation of this dense data.
    fn as_ptr(&self) -> Self::DenseDataRaw;
}
55
/// Dense storage that additionally provides a mutable raw representation.
///
/// [`JaggedMle::as_mut_raw`] embeds the `DenseDataMutRaw` value into [`JaggedMleMutRaw`].
pub trait DenseDataMut<A: Backend>: DenseData<A> {
    /// Raw, mutable representation of the dense data.
    type DenseDataMutRaw;
    /// Returns the mutable raw representation of this dense data.
    fn as_mut_ptr(&mut self) -> Self::DenseDataMutRaw;
}
60
/// The raw pointer equivalent of [`JaggedMle`] for use in cuda kernels.
///
/// `#[repr(C)]` fixes the field order and layout. NOTE(review): this presumably has to match
/// a kernel-side struct definition; confirm before reordering or adding fields. Unlike
/// [`JaggedMle`], it carries no `column_heights`. It holds no lifetime, so the pointers are
/// only valid while the originating [`JaggedMle`] and its buffers are alive.
#[repr(C)]
pub struct JaggedMleRaw<D: DenseData<A>, A: Backend> {
    /// Pointer to the `col_index` buffer.
    col_index: *const u32,
    /// Pointer to the `start_indices` buffer.
    start_indices: *const u32,
    /// Raw representation of the dense data, as returned by `DenseData::as_ptr`.
    dense_data: D::DenseDataRaw,
}
68
/// The mutable raw pointer equivalent of [`JaggedMle`] for use in cuda kernels.
///
/// `#[repr(C)]` fixes the field order and layout. NOTE(review): this presumably has to match
/// a kernel-side struct definition; confirm before reordering or adding fields. It holds no
/// lifetime, so the pointers are only valid while the originating [`JaggedMle`] and its
/// buffers are alive.
#[repr(C)]
pub struct JaggedMleMutRaw<D: DenseDataMut<A>, A: Backend> {
    /// Mutable pointer to the `col_index` buffer.
    col_index: *mut u32,
    /// Mutable pointer to the `start_indices` buffer.
    start_indices: *mut u32,
    /// Mutable raw representation of the dense data, as returned by `DenseDataMut::as_mut_ptr`.
    dense_data: D::DenseDataMutRaw,
}
76
77impl<D: DenseData<A>, A: Backend> JaggedMle<D, A> {
78    pub fn as_raw(&self) -> JaggedMleRaw<D, A> {
79        JaggedMleRaw {
80            col_index: self.col_index.as_ptr(),
81            start_indices: self.start_indices.as_ptr(),
82            dense_data: self.dense_data.as_ptr(),
83        }
84    }
85
86    pub fn as_mut_raw(&mut self) -> JaggedMleMutRaw<D, A>
87    where
88        D: DenseDataMut<A>,
89    {
90        JaggedMleMutRaw {
91            col_index: self.col_index.as_mut_ptr(),
92            start_indices: self.start_indices.as_mut_ptr(),
93            dense_data: self.dense_data.as_mut_ptr(),
94        }
95    }
96
97    pub fn new(
98        dense_data: D,
99        col_index: Buffer<u32, A>,
100        start_indices: Buffer<u32, A>,
101        column_heights: Vec<u32>,
102    ) -> Self {
103        Self { dense_data, col_index, start_indices, column_heights }
104    }
105
106    pub fn dense(&self) -> &D {
107        &self.dense_data
108    }
109
110    pub fn dense_mut(&mut self) -> &mut D {
111        &mut self.dense_data
112    }
113
114    pub fn col_index(&self) -> &Buffer<u32, A> {
115        &self.col_index
116    }
117
118    pub fn col_index_mut(&mut self) -> &mut Buffer<u32, A> {
119        &mut self.col_index
120    }
121
122    pub fn start_indices(&self) -> &Buffer<u32, A> {
123        &self.start_indices
124    }
125
126    pub fn start_indices_mut(&mut self) -> &mut Buffer<u32, A> {
127        &mut self.start_indices
128    }
129
130    pub fn into_parts(self) -> (D, Buffer<u32, A>, Buffer<u32, A>) {
131        (self.dense_data, self.col_index, self.start_indices)
132    }
133}
134
135impl<D: DenseData<TaskScope>> JaggedMle<D, TaskScope> {
136    /// Computes the next start indices and column heights for use in jagged fix last variable.
137    ///
138    /// TODO: ignore all of the padding stuff.
139    pub fn next_start_indices_and_column_heights(&self) -> (Buffer<u32>, Vec<u32>) {
140        let output_heights =
141            self.column_heights.iter().map(|height| height.div_ceil(4) * 2).collect::<Vec<u32>>();
142
143        let new_start_idx = once(0)
144            .chain(output_heights.iter().scan(0u32, |acc, x| {
145                *acc += x;
146                Some(*acc)
147            }))
148            .collect::<Vec<_>>();
149        let buffer_start_idx = Buffer::from(new_start_idx);
150        (buffer_start_idx, output_heights)
151    }
152}
153
154impl<D: DenseData<A>, A: Backend> HasBackend for JaggedMle<D, A> {
155    type Backend = A;
156    fn backend(&self) -> &A {
157        self.col_index.backend()
158    }
159}