// sp1_gpu_cudart/tensor/reduce.rs

1use std::ffi::c_void;
2
3use slop_tensor::{Dimensions, Tensor, TensorViewMut};
4use sp1_gpu_sys::{
5    reduce::{
6        koala_bear_extension_sum_block_reduce_kernel,
7        koala_bear_extension_sum_partial_block_reduce_kernel, koala_bear_sum_block_reduce_kernel,
8        koala_bear_sum_partial_block_reduce_kernel,
9    },
10    runtime::{Dim3, KernelPtr},
11};
12use sp1_primitives::{SP1ExtensionField, SP1Field};
13
14use crate::{args, DeviceCopy, DeviceTensor};
15
16use super::TaskScope;
17
/// Largest reduced extent, measured in multiples of `BLOCK_SIZE`, that is handed straight to the
/// final block-sum kernel. Anything longer goes through another partial-sum pass first.
const MAX_NUM_FINAL_BLOCKS: usize = 2;
19
20/// Kernels for performing a sum over a block or a partial sum over a grid to block sums.
21///
22/// # Safety
23///
24/// The implementor must ensure that the arguments of the kernels are laid out as expected by the
25/// functions [block_sum] and [partial_sum_reduction] below.
26pub unsafe trait DeviceSumKernel<T> {
27    fn partial_sum_kernel() -> KernelPtr;
28
29    fn block_sum_kernel() -> KernelPtr;
30}
31
32fn block_sum<T: DeviceCopy, const BLOCK_SIZE: usize, const INTIAL_STRIDE: usize>(
33    src: &Tensor<T, TaskScope>,
34    mut dst: TensorViewMut<T, TaskScope>,
35    dim: usize,
36) where
37    TaskScope: DeviceSumKernel<T>,
38{
39    let height = src.sizes()[dim];
40    let width = src.total_len() / height;
41
42    let block_dim: Dim3 = BLOCK_SIZE.into();
43    let num_reduce_blocks = height.div_ceil(block_dim.x as usize).div_ceil(INTIAL_STRIDE);
44    let grid_dim: Dim3 = (num_reduce_blocks, width, 1).into();
45
46    // If the height is small enough, we can use one kernel for the sum.
47    let args = args!(src.as_ptr(), dst.as_mut_ptr(), width, height);
48    let shared_mem = 0;
49    unsafe {
50        src.backend()
51            .launch_kernel(TaskScope::block_sum_kernel(), grid_dim, block_dim, &args, shared_mem)
52            .unwrap();
53    }
54}
55
56/// A general sum based reduction that allows a generic first step.
57///
58/// # Safety
59#[inline]
60pub unsafe fn partial_sum_reduction_into<
61    T: DeviceCopy,
62    const BLOCK_SIZE: usize,
63    const INTIAL_STRIDE: usize,
64    const NUM_ARGS: usize,
65>(
66    dst: TensorViewMut<T, TaskScope>,
67    partial_reduction_kernel: KernelPtr,
68    mut partial_args: [*mut c_void; NUM_ARGS],
69    partial_shared_mem: usize,
70    reduction_shape: &Dimensions,
71    dim: usize,
72    scope: &TaskScope,
73) where
74    TaskScope: DeviceSumKernel<T>,
75{
76    let height = reduction_shape.sizes()[dim];
77    let width = reduction_shape.total_len() / height;
78
79    let block_dim: Dim3 = BLOCK_SIZE.into();
80    let num_reduce_blocks = height.div_ceil(block_dim.x as usize).div_ceil(INTIAL_STRIDE);
81    let grid_dim: Dim3 = (num_reduce_blocks, width, 1).into();
82
83    let mut sizes = reduction_shape.sizes().to_vec();
84    sizes[dim] = grid_dim.x as usize;
85    let mut partial_sums = Tensor::<T, _>::with_sizes_in(sizes.clone(), scope.clone());
86    let num_tiles = block_dim.x.checked_div(32).unwrap_or(1);
87    let shared_mem = num_tiles * block_dim.y * (std::mem::size_of::<T>() as u32);
88    let partial_args_ptr = &partial_sums.as_mut_ptr() as *const _ as *mut c_void;
89    partial_args[0] = partial_args_ptr;
90    let args = partial_args;
91    unsafe {
92        partial_sums.assume_init();
93        scope
94            .launch_kernel(
95                partial_reduction_kernel,
96                grid_dim,
97                block_dim,
98                &args,
99                shared_mem as usize + partial_shared_mem,
100            )
101            .unwrap();
102    }
103
104    // Now we need to sum the partial sums. We will do it in an iterative manner until the length
105    // is small enough to do the final summation in one kernel.
106    let mut partial_sums = partial_sums;
107    while sizes[dim] > MAX_NUM_FINAL_BLOCKS * BLOCK_SIZE {
108        let height = sizes[dim];
109        let block_dim: Dim3 = BLOCK_SIZE.into();
110        let num_reduce_blocks = height.div_ceil(block_dim.x as usize).div_ceil(INTIAL_STRIDE);
111        let grid_dim: Dim3 = (num_reduce_blocks, width, 1).into();
112
113        sizes[dim] = grid_dim.x as usize;
114        let mut current = Tensor::<T, _>::with_sizes_in(sizes.clone(), scope.clone());
115        let args = args!(current.as_mut_ptr(), partial_sums.as_ptr(), width, height);
116        let num_tiles = block_dim.x.checked_div(32).unwrap_or(1);
117        let shared_mem = num_tiles * block_dim.y * (std::mem::size_of::<T>() as u32);
118        unsafe {
119            current.assume_init();
120            scope
121                .launch_kernel(
122                    TaskScope::partial_sum_kernel(),
123                    grid_dim,
124                    block_dim,
125                    &args,
126                    shared_mem as usize,
127                )
128                .unwrap();
129        }
130        // sizes[dim] = num_reduce_blocks;
131        partial_sums = current;
132    }
133
134    // Now we need to sum the partial sums so we will use the block sum function.
135    block_sum::<T, BLOCK_SIZE, INTIAL_STRIDE>(&partial_sums, dst, dim);
136}
137
138/// # Safety
139pub unsafe fn partial_sum_reduction<
140    T: DeviceCopy,
141    const BLOCK_SIZE: usize,
142    const INTIAL_STRIDE: usize,
143    const NUM_ARGS: usize,
144>(
145    partial_reduction_kernel: KernelPtr,
146    partial_args: [*mut c_void; NUM_ARGS],
147    partial_shared_mem: usize,
148    reduction_shape: &Dimensions,
149    scope: &TaskScope,
150    dim: usize,
151) -> Tensor<T, TaskScope>
152where
153    TaskScope: DeviceSumKernel<T>,
154{
155    let mut sizes = reduction_shape.sizes().to_vec();
156    sizes.remove(dim);
157    let mut dst = Tensor::zeros_in(sizes, scope.clone());
158    partial_sum_reduction_into::<T, BLOCK_SIZE, INTIAL_STRIDE, NUM_ARGS>(
159        dst.as_view_mut(),
160        partial_reduction_kernel,
161        partial_args,
162        partial_shared_mem,
163        reduction_shape,
164        dim,
165        scope,
166    );
167    dst
168}
169
170impl<T: DeviceCopy> DeviceTensor<T>
171where
172    TaskScope: DeviceSumKernel<T>,
173{
174    pub fn sum_dim(&self, dim: usize) -> DeviceTensor<T> {
175        let src = &self.raw;
176        let mut sizes = src.sizes().to_vec();
177        sizes.remove(dim);
178        let mut dst = Tensor::zeros_in(sizes, src.backend().clone());
179        const BLOCK_SIZE: usize = 512;
180        const INTIAL_STRIDE: usize = 8;
181        assert!(dim == src.sizes().len() - 1, "only summing over the last dimension is supported");
182
183        let height = src.sizes()[dim];
184        let width = src.total_len() / height;
185
186        if height <= BLOCK_SIZE {
187            block_sum::<T, BLOCK_SIZE, INTIAL_STRIDE>(src, dst.as_view_mut(), dim);
188            return DeviceTensor { raw: dst };
189        }
190
191        // If the number of elements to sum is bigger than the block size, we need to use a two
192        // step reduction.
193        // 1. Partial sum: sum the elements in blocks of size BLOCK_SIZE
194        // 2. Block sum: sum the partial sums in blocks of size BLOCK_SIZE
195
196        let null_ptr = std::ptr::null::<c_void>();
197        let partial_args = args!(null_ptr, src.as_ptr(), width, height);
198        unsafe {
199            partial_sum_reduction_into::<T, BLOCK_SIZE, INTIAL_STRIDE, 4>(
200                dst.as_view_mut(),
201                TaskScope::partial_sum_kernel(),
202                partial_args,
203                0,
204                &src.dimensions,
205                dim,
206                src.backend(),
207            );
208        }
209        DeviceTensor { raw: dst }
210    }
211}
212
213unsafe impl DeviceSumKernel<SP1Field> for TaskScope {
214    fn partial_sum_kernel() -> KernelPtr {
215        unsafe { koala_bear_sum_partial_block_reduce_kernel() }
216    }
217
218    fn block_sum_kernel() -> KernelPtr {
219        unsafe { koala_bear_sum_block_reduce_kernel() }
220    }
221}
222
223unsafe impl DeviceSumKernel<SP1ExtensionField> for TaskScope {
224    fn partial_sum_kernel() -> KernelPtr {
225        unsafe { koala_bear_extension_sum_partial_block_reduce_kernel() }
226    }
227
228    fn block_sum_kernel() -> KernelPtr {
229        unsafe { koala_bear_extension_sum_block_reduce_kernel() }
230    }
231}
232
#[cfg(test)]
mod tests {
    use slop_tensor::Tensor;
    use sp1_primitives::{SP1ExtensionField, SP1Field};

    use super::DeviceTensor;

    /// Device row sums of a random base-field tensor must match sums computed on the host.
    #[test]
    fn test_koala_bear_sum() {
        let mut rng = rand::thread_rng();
        let rows = 100;

        // Row lengths 10 and 100 fit in a single block; 1 << 16 exceeds the block size and so
        // goes through the partial-sum path.
        for row_len in [10, 100, 1 << 16] {
            let host = Tensor::<SP1Field>::rand(&mut rng, [rows, row_len]);

            let device_sums = crate::run_sync_in_place(|scope| {
                let on_device = DeviceTensor::from_host(&host, &scope).unwrap();
                let summed = on_device.sum_dim(1);
                summed.to_host().unwrap()
            })
            .unwrap();

            assert_eq!(device_sums.sizes(), [rows]);
            for row in 0..rows {
                let host_sum: SP1Field = host.get(row).unwrap().as_slice().iter().copied().sum();
                assert_eq!(host_sum, *device_sums[[row]]);
            }
        }
    }

    /// Same check as above, for extension-field elements on the partial-sum path.
    #[test]
    fn test_koala_bear_ext_sum() {
        type EF = SP1ExtensionField;

        let mut rng = rand::thread_rng();
        let rows = 128;
        let row_len = 1 << 16;

        let host = Tensor::<EF>::rand(&mut rng, [rows, row_len]);

        let device_sums = crate::run_sync_in_place(|scope| {
            let on_device = DeviceTensor::from_host(&host, &scope).unwrap();
            let summed = on_device.sum_dim(1);
            summed.to_host().unwrap()
        })
        .unwrap();

        assert_eq!(device_sums.sizes(), [rows]);
        for row in 0..rows {
            let host_sum: EF = host.get(row).unwrap().as_slice().iter().copied().sum();
            assert_eq!(host_sum, *device_sums[[row]]);
        }
    }
}