1use std::ffi::c_void;
2
3use slop_tensor::{Dimensions, Tensor, TensorViewMut};
4use sp1_gpu_sys::{
5 reduce::{
6 koala_bear_extension_sum_block_reduce_kernel,
7 koala_bear_extension_sum_partial_block_reduce_kernel, koala_bear_sum_block_reduce_kernel,
8 koala_bear_sum_partial_block_reduce_kernel,
9 },
10 runtime::{Dim3, KernelPtr},
11};
12use sp1_primitives::{SP1ExtensionField, SP1Field};
13
14use crate::{args, DeviceCopy, DeviceTensor};
15
16use super::TaskScope;
17
18const MAX_NUM_FINAL_BLOCKS: usize = 2;
19
20pub unsafe trait DeviceSumKernel<T> {
27 fn partial_sum_kernel() -> KernelPtr;
28
29 fn block_sum_kernel() -> KernelPtr;
30}
31
32fn block_sum<T: DeviceCopy, const BLOCK_SIZE: usize, const INTIAL_STRIDE: usize>(
33 src: &Tensor<T, TaskScope>,
34 mut dst: TensorViewMut<T, TaskScope>,
35 dim: usize,
36) where
37 TaskScope: DeviceSumKernel<T>,
38{
39 let height = src.sizes()[dim];
40 let width = src.total_len() / height;
41
42 let block_dim: Dim3 = BLOCK_SIZE.into();
43 let num_reduce_blocks = height.div_ceil(block_dim.x as usize).div_ceil(INTIAL_STRIDE);
44 let grid_dim: Dim3 = (num_reduce_blocks, width, 1).into();
45
46 let args = args!(src.as_ptr(), dst.as_mut_ptr(), width, height);
48 let shared_mem = 0;
49 unsafe {
50 src.backend()
51 .launch_kernel(TaskScope::block_sum_kernel(), grid_dim, block_dim, &args, shared_mem)
52 .unwrap();
53 }
54}
55
56#[inline]
60pub unsafe fn partial_sum_reduction_into<
61 T: DeviceCopy,
62 const BLOCK_SIZE: usize,
63 const INTIAL_STRIDE: usize,
64 const NUM_ARGS: usize,
65>(
66 dst: TensorViewMut<T, TaskScope>,
67 partial_reduction_kernel: KernelPtr,
68 mut partial_args: [*mut c_void; NUM_ARGS],
69 partial_shared_mem: usize,
70 reduction_shape: &Dimensions,
71 dim: usize,
72 scope: &TaskScope,
73) where
74 TaskScope: DeviceSumKernel<T>,
75{
76 let height = reduction_shape.sizes()[dim];
77 let width = reduction_shape.total_len() / height;
78
79 let block_dim: Dim3 = BLOCK_SIZE.into();
80 let num_reduce_blocks = height.div_ceil(block_dim.x as usize).div_ceil(INTIAL_STRIDE);
81 let grid_dim: Dim3 = (num_reduce_blocks, width, 1).into();
82
83 let mut sizes = reduction_shape.sizes().to_vec();
84 sizes[dim] = grid_dim.x as usize;
85 let mut partial_sums = Tensor::<T, _>::with_sizes_in(sizes.clone(), scope.clone());
86 let num_tiles = block_dim.x.checked_div(32).unwrap_or(1);
87 let shared_mem = num_tiles * block_dim.y * (std::mem::size_of::<T>() as u32);
88 let partial_args_ptr = &partial_sums.as_mut_ptr() as *const _ as *mut c_void;
89 partial_args[0] = partial_args_ptr;
90 let args = partial_args;
91 unsafe {
92 partial_sums.assume_init();
93 scope
94 .launch_kernel(
95 partial_reduction_kernel,
96 grid_dim,
97 block_dim,
98 &args,
99 shared_mem as usize + partial_shared_mem,
100 )
101 .unwrap();
102 }
103
104 let mut partial_sums = partial_sums;
107 while sizes[dim] > MAX_NUM_FINAL_BLOCKS * BLOCK_SIZE {
108 let height = sizes[dim];
109 let block_dim: Dim3 = BLOCK_SIZE.into();
110 let num_reduce_blocks = height.div_ceil(block_dim.x as usize).div_ceil(INTIAL_STRIDE);
111 let grid_dim: Dim3 = (num_reduce_blocks, width, 1).into();
112
113 sizes[dim] = grid_dim.x as usize;
114 let mut current = Tensor::<T, _>::with_sizes_in(sizes.clone(), scope.clone());
115 let args = args!(current.as_mut_ptr(), partial_sums.as_ptr(), width, height);
116 let num_tiles = block_dim.x.checked_div(32).unwrap_or(1);
117 let shared_mem = num_tiles * block_dim.y * (std::mem::size_of::<T>() as u32);
118 unsafe {
119 current.assume_init();
120 scope
121 .launch_kernel(
122 TaskScope::partial_sum_kernel(),
123 grid_dim,
124 block_dim,
125 &args,
126 shared_mem as usize,
127 )
128 .unwrap();
129 }
130 partial_sums = current;
132 }
133
134 block_sum::<T, BLOCK_SIZE, INTIAL_STRIDE>(&partial_sums, dst, dim);
136}
137
138pub unsafe fn partial_sum_reduction<
140 T: DeviceCopy,
141 const BLOCK_SIZE: usize,
142 const INTIAL_STRIDE: usize,
143 const NUM_ARGS: usize,
144>(
145 partial_reduction_kernel: KernelPtr,
146 partial_args: [*mut c_void; NUM_ARGS],
147 partial_shared_mem: usize,
148 reduction_shape: &Dimensions,
149 scope: &TaskScope,
150 dim: usize,
151) -> Tensor<T, TaskScope>
152where
153 TaskScope: DeviceSumKernel<T>,
154{
155 let mut sizes = reduction_shape.sizes().to_vec();
156 sizes.remove(dim);
157 let mut dst = Tensor::zeros_in(sizes, scope.clone());
158 partial_sum_reduction_into::<T, BLOCK_SIZE, INTIAL_STRIDE, NUM_ARGS>(
159 dst.as_view_mut(),
160 partial_reduction_kernel,
161 partial_args,
162 partial_shared_mem,
163 reduction_shape,
164 dim,
165 scope,
166 );
167 dst
168}
169
170impl<T: DeviceCopy> DeviceTensor<T>
171where
172 TaskScope: DeviceSumKernel<T>,
173{
174 pub fn sum_dim(&self, dim: usize) -> DeviceTensor<T> {
175 let src = &self.raw;
176 let mut sizes = src.sizes().to_vec();
177 sizes.remove(dim);
178 let mut dst = Tensor::zeros_in(sizes, src.backend().clone());
179 const BLOCK_SIZE: usize = 512;
180 const INTIAL_STRIDE: usize = 8;
181 assert!(dim == src.sizes().len() - 1, "only summing over the last dimension is supported");
182
183 let height = src.sizes()[dim];
184 let width = src.total_len() / height;
185
186 if height <= BLOCK_SIZE {
187 block_sum::<T, BLOCK_SIZE, INTIAL_STRIDE>(src, dst.as_view_mut(), dim);
188 return DeviceTensor { raw: dst };
189 }
190
191 let null_ptr = std::ptr::null::<c_void>();
197 let partial_args = args!(null_ptr, src.as_ptr(), width, height);
198 unsafe {
199 partial_sum_reduction_into::<T, BLOCK_SIZE, INTIAL_STRIDE, 4>(
200 dst.as_view_mut(),
201 TaskScope::partial_sum_kernel(),
202 partial_args,
203 0,
204 &src.dimensions,
205 dim,
206 src.backend(),
207 );
208 }
209 DeviceTensor { raw: dst }
210 }
211}
212
213unsafe impl DeviceSumKernel<SP1Field> for TaskScope {
214 fn partial_sum_kernel() -> KernelPtr {
215 unsafe { koala_bear_sum_partial_block_reduce_kernel() }
216 }
217
218 fn block_sum_kernel() -> KernelPtr {
219 unsafe { koala_bear_sum_block_reduce_kernel() }
220 }
221}
222
223unsafe impl DeviceSumKernel<SP1ExtensionField> for TaskScope {
224 fn partial_sum_kernel() -> KernelPtr {
225 unsafe { koala_bear_extension_sum_partial_block_reduce_kernel() }
226 }
227
228 fn block_sum_kernel() -> KernelPtr {
229 unsafe { koala_bear_extension_sum_block_reduce_kernel() }
230 }
231}
232
#[cfg(test)]
mod tests {
    use slop_tensor::Tensor;
    use sp1_primitives::{SP1ExtensionField, SP1Field};

    use super::DeviceTensor;

    /// Checks `sum_dim` over the last axis of base-field tensors against a host-side sum.
    ///
    /// The sizes straddle the thresholds `sum_dim` uses (a 512-thread block with an initial
    /// stride of 8):
    /// - up to 512 elements take a single block-reduction launch;
    /// - 513 is the smallest extent that takes the multi-pass path;
    /// - 4097 (= 512 * 8 + 1) is the smallest that needs two first-pass blocks per row.
    #[test]
    fn test_koala_bear_sum() {
        let num_summands = 100;
        let mut rng = rand::thread_rng();

        for size in [1, 10, 100, 512, 513, 4097, 1 << 16] {
            let tensor = Tensor::<SP1Field>::rand(&mut rng, [num_summands, size]);

            let sum_tensor = crate::run_sync_in_place(|t| {
                let device_tensor = DeviceTensor::from_host(&tensor, &t).unwrap();
                let sums = device_tensor.sum_dim(1);
                sums.to_host().unwrap()
            })
            .unwrap();

            assert_eq!(sum_tensor.sizes(), [num_summands]);
            for i in 0..num_summands {
                let expected_sum: SP1Field =
                    tensor.get(i).unwrap().as_slice().iter().copied().sum();
                assert_eq!(expected_sum, *sum_tensor[[i]], "size {size}, row {i}");
            }
        }
    }

    /// Checks `sum_dim` over the last axis of extension-field tensors against a host-side sum.
    #[test]
    fn test_koala_bear_ext_sum() {
        let num_summands = 128;
        let size = 1 << 16;
        let mut rng = rand::thread_rng();

        type EF = SP1ExtensionField;

        let tensor = Tensor::<EF>::rand(&mut rng, [num_summands, size]);

        let sum_tensor = crate::run_sync_in_place(|t| {
            let device_tensor = DeviceTensor::from_host(&tensor, &t).unwrap();
            let sums = device_tensor.sum_dim(1);
            sums.to_host().unwrap()
        })
        .unwrap();

        assert_eq!(sum_tensor.sizes(), [num_summands]);
        for i in 0..num_summands {
            let expected_sum: EF = tensor.get(i).unwrap().as_slice().iter().copied().sum();
            assert_eq!(expected_sum, *sum_tensor[[i]], "row {i}");
        }
    }
}