1use slop_tensor::{Tensor, TensorView};
2use sp1_gpu_sys::{
3 reduce::{
4 dot_along_short_dimension_kernel_koala_bear_base_base,
5 dot_along_short_dimension_kernel_koala_bear_base_extension,
6 dot_along_short_dimension_kernel_koala_bear_extension_extension,
7 partial_dot_koala_bear_base_extension_kernel, partial_dot_koala_bear_extension_kernel,
8 partial_dot_koala_bear_kernel,
9 },
10 runtime::KernelPtr,
11};
12use sp1_primitives::{SP1ExtensionField, SP1Field};
13
14use crate::{args, reduce::partial_sum_reduction_into, DeviceCopy, DeviceTensor, TaskScope};
15
16use super::reduce::DeviceSumKernel;
17
18pub unsafe trait DotKernel<T: DeviceCopy, U: DeviceCopy>: DeviceSumKernel<U> {
21 fn partial_dot_kernel_last_dim() -> KernelPtr;
22
23 fn dot_along_short_dimension_kernel() -> KernelPtr;
24}
25
26pub fn dot_along_dim_view<'a, T: DeviceCopy, U: DeviceCopy>(
27 src: TensorView<'a, T, TaskScope>,
28 scalars: TensorView<'a, U, TaskScope>,
29 dim: usize,
30) -> Tensor<U, TaskScope>
31where
32 TaskScope: DotKernel<T, U>,
33{
34 let mut sizes = src.sizes().to_vec();
35 sizes.remove(dim);
36 let mut dst = Tensor::with_sizes_in(sizes, src.backend().clone());
37 assert_eq!(src.sizes().len(), 2, "Dot product only supported for 2D tensors",);
38 let max_scalar_dim = *scalars.sizes().iter().max().unwrap();
39 assert_eq!(max_scalar_dim, scalars.total_len(), "The scalar tensor must be a 1D tensor");
40 match dim {
41 dim if dim == src.sizes().len() - 1 => {
42 let height = src.sizes()[dim];
43 let width = src.total_len() / height;
44
45 let null_ptr = std::ptr::null::<std::ffi::c_void>();
46 let partial_args = args!(null_ptr, src.as_ptr(), scalars.as_ptr(), width, height);
47 const BLOCK_SIZE: usize = 256;
48 const INTIAL_STRIDE: usize = 4;
49 dst.storage.write_bytes(0, dst.total_len() * std::mem::size_of::<U>()).unwrap();
50 unsafe {
51 partial_sum_reduction_into::<U, BLOCK_SIZE, INTIAL_STRIDE, 5>(
52 dst.as_view_mut(),
53 TaskScope::partial_dot_kernel_last_dim(),
54 partial_args,
55 0,
56 src.shape(),
57 dim,
58 src.backend(),
59 );
60 }
61 }
62 0 => {
63 let height = src.sizes()[1];
64 let width = src.total_len() / height;
65
66 const BLOCK_SIZE: usize = 256;
67 let args = args!(dst.as_mut_ptr(), src.as_ptr(), scalars.as_ptr(), width, height);
68 let grid_dim = height.div_ceil(BLOCK_SIZE);
69 unsafe {
70 dst.assume_init();
71 src.backend()
72 .launch_kernel(
73 TaskScope::dot_along_short_dimension_kernel(),
74 grid_dim,
75 BLOCK_SIZE,
76 &args,
77 0,
78 )
79 .unwrap();
80 }
81 }
82 _ => panic!(
83 "Dot product is not supported along dimension {} for tensor of sizes {:?}",
84 dim,
85 src.sizes()
86 ),
87 }
88 dst
89}
90
91impl<T: DeviceCopy> DeviceTensor<T> {
92 pub fn dot_along_dim<U: DeviceCopy>(
93 &self,
94 scalars: &DeviceTensor<U>,
95 dim: usize,
96 ) -> DeviceTensor<U>
97 where
98 TaskScope: DotKernel<T, U>,
99 {
100 let raw = dot_along_dim_view(self.raw.as_view(), scalars.raw.as_view(), dim);
101 DeviceTensor { raw }
102 }
103}
104
105unsafe impl DotKernel<SP1Field, SP1Field> for TaskScope {
106 fn partial_dot_kernel_last_dim() -> KernelPtr {
107 unsafe { partial_dot_koala_bear_kernel() }
108 }
109
110 fn dot_along_short_dimension_kernel() -> KernelPtr {
111 unsafe { dot_along_short_dimension_kernel_koala_bear_base_base() }
112 }
113}
114
115unsafe impl DotKernel<SP1ExtensionField, SP1ExtensionField> for TaskScope {
116 fn partial_dot_kernel_last_dim() -> KernelPtr {
117 unsafe { partial_dot_koala_bear_extension_kernel() }
118 }
119
120 fn dot_along_short_dimension_kernel() -> KernelPtr {
121 unsafe { dot_along_short_dimension_kernel_koala_bear_extension_extension() }
122 }
123}
124
125unsafe impl DotKernel<SP1Field, SP1ExtensionField> for TaskScope {
126 fn partial_dot_kernel_last_dim() -> KernelPtr {
127 unsafe { partial_dot_koala_bear_base_extension_kernel() }
128 }
129
130 fn dot_along_short_dimension_kernel() -> KernelPtr {
131 unsafe { dot_along_short_dimension_kernel_koala_bear_base_extension() }
132 }
133}
134
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use slop_algebra::AbstractField;
    use slop_tensor::Tensor;
    use sp1_primitives::{SP1ExtensionField, SP1Field};

    use super::DeviceTensor;

    type SP1FieldExt = SP1ExtensionField;

    /// Base x base contraction along the last dimension matches a host-side dot product
    /// computed row by row.
    #[test]
    fn test_koala_bear_dot() {
        let mut rng = rand::thread_rng();
        let num_summands = 100;

        for size in [10, 100, 1 << 16] {
            let host_matrix = Tensor::<SP1Field>::rand(&mut rng, [num_summands, size]);
            let host_scalars = Tensor::<SP1Field>::rand(&mut rng, [size]);

            let result = crate::run_sync_in_place(|t| {
                let matrix = DeviceTensor::from_host(&host_matrix, &t).unwrap();
                let scalars = DeviceTensor::from_host(&host_scalars, &t).unwrap();
                let dot = matrix.dot_along_dim(&scalars, 1);
                dot.to_host().unwrap()
            })
            .unwrap();

            assert_eq!(result.sizes(), [num_summands]);
            let scalar_values = host_scalars.as_buffer();
            for row in 0..num_summands {
                let expected: SP1Field = host_matrix
                    .get(row)
                    .unwrap()
                    .as_slice()
                    .iter()
                    .zip_eq(scalar_values.iter())
                    .map(|(&a, &b)| a * b)
                    .sum();
                assert_eq!(expected, *result[[row]]);
            }
        }
    }

    /// Extension x extension contraction along the last dimension matches the host result.
    #[test]
    fn test_koala_bear_extension_dot() {
        type EF = SP1ExtensionField;

        let mut rng = rand::thread_rng();
        let num_summands = 100;

        for size in [10, 100, 1 << 16] {
            let host_matrix = Tensor::<EF>::rand(&mut rng, [num_summands, size]);
            let host_scalars = Tensor::<EF>::rand(&mut rng, [size]);

            let result = crate::run_sync_in_place(|t| {
                let matrix = DeviceTensor::from_host(&host_matrix, &t).unwrap();
                let scalars = DeviceTensor::from_host(&host_scalars, &t).unwrap();
                let dot = matrix.dot_along_dim(&scalars, 1);
                dot.to_host().unwrap()
            })
            .unwrap();

            assert_eq!(result.sizes(), [num_summands]);
            let scalar_values = host_scalars.as_buffer();
            for row in 0..num_summands {
                let expected: EF = host_matrix
                    .get(row)
                    .unwrap()
                    .as_slice()
                    .iter()
                    .zip_eq(scalar_values.iter())
                    .map(|(&a, &b)| a * b)
                    .sum();
                assert_eq!(expected, *result[[row]]);
            }
        }
    }

    /// Base-field matrix x extension-field scalars along the last dimension matches the host
    /// result. Device time for each shape is logged via `tracing`.
    #[test]
    fn test_koala_bear_base_extension_dot() {
        type F = SP1Field;
        type EF = SP1ExtensionField;

        let mut rng = rand::thread_rng();

        for size in [10, 100, 1 << 10, 1 << 12, 1 << 16] {
            for num_summands in [64, 128] {
                let host_matrix = Tensor::<F>::rand(&mut rng, [num_summands, size]);
                let host_scalars = Tensor::<EF>::rand(&mut rng, [size]);

                let result = crate::run_sync_in_place(|t| {
                    let matrix = DeviceTensor::from_host(&host_matrix, &t).unwrap();
                    let scalars = DeviceTensor::from_host(&host_scalars, &t).unwrap();
                    // Drain pending uploads so the timer only measures the dot product.
                    t.synchronize_blocking().unwrap();
                    let start = std::time::Instant::now();
                    let dot = matrix.dot_along_dim(&scalars, 1);
                    t.synchronize_blocking().unwrap();
                    tracing::info!(
                        "Dot time for size {}, num_summands: {}, time: {:?}",
                        size,
                        num_summands,
                        start.elapsed()
                    );
                    dot.to_host().unwrap()
                })
                .unwrap();

                assert_eq!(result.sizes(), [num_summands]);
                let scalar_values = host_scalars.as_buffer();
                for row in 0..num_summands {
                    let expected: EF = host_matrix
                        .get(row)
                        .unwrap()
                        .as_slice()
                        .iter()
                        .zip_eq(scalar_values.iter())
                        .map(|(&a, &b)| b * a)
                        .sum();
                    assert_eq!(expected, *result[[row]]);
                }
            }
        }
    }

    /// Base x base contraction along the leading dimension matches a column-wise host sum.
    #[test]
    fn test_dot_along_dim_0_base_base() {
        let mut rng = rand::thread_rng();
        let (width, height) = (10, 1500);

        let host_tensor = Tensor::<SP1Field>::rand(&mut rng, [width, height]);
        let host_scalars = Tensor::<SP1Field>::rand(&mut rng, [width]);

        let dot = crate::run_sync_in_place(|t| {
            let tensor = DeviceTensor::from_host(&host_tensor, &t).unwrap();
            let scalars = DeviceTensor::from_host(&host_scalars, &t).unwrap();
            let contracted = tensor.dot_along_dim(&scalars, 0);
            contracted.to_host().unwrap()
        })
        .unwrap();

        assert_eq!(dot.sizes(), [height]);
        for i in 0..height {
            let expected = (0..width).fold(SP1Field::zero(), |acc, j| {
                acc + *host_scalars[[j]] * *host_tensor[[j, i]]
            });
            assert_eq!(*dot[[i]], expected, "Dot product at index {i} is incorrect");
        }
    }

    /// Base-field matrix x extension scalars along the leading dimension matches the host.
    #[test]
    fn test_dot_along_dim_0_base_ext() {
        let mut rng = rand::thread_rng();
        let (width, height) = (10, 1500);

        let host_tensor = Tensor::<SP1Field>::rand(&mut rng, [width, height]);
        let host_scalars = Tensor::<SP1FieldExt>::rand(&mut rng, [width]);

        let dot = crate::run_sync_in_place(|t| {
            let tensor = DeviceTensor::from_host(&host_tensor, &t).unwrap();
            let scalars = DeviceTensor::from_host(&host_scalars, &t).unwrap();
            let contracted = tensor.dot_along_dim(&scalars, 0);
            contracted.to_host().unwrap()
        })
        .unwrap();

        assert_eq!(dot.sizes(), [height]);
        for i in 0..height {
            let expected = (0..width).fold(SP1FieldExt::zero(), |acc, j| {
                acc + *host_scalars[[j]] * *host_tensor[[j, i]]
            });
            assert_eq!(*dot[[i]], expected, "Dot product at index {i} is incorrect");
        }
    }

    /// Extension x extension contraction along the leading dimension matches the host.
    #[test]
    fn test_dot_along_dim_0_ext_ext() {
        let mut rng = rand::thread_rng();
        let (width, height) = (10, 1500);

        let host_tensor = Tensor::<SP1FieldExt>::rand(&mut rng, [width, height]);
        let host_scalars = Tensor::<SP1FieldExt>::rand(&mut rng, [width]);

        let dot = crate::run_sync_in_place(|t| {
            let tensor = DeviceTensor::from_host(&host_tensor, &t).unwrap();
            let scalars = DeviceTensor::from_host(&host_scalars, &t).unwrap();
            let contracted = tensor.dot_along_dim(&scalars, 0);
            contracted.to_host().unwrap()
        })
        .unwrap();

        assert_eq!(dot.sizes(), [height]);
        for i in 0..height {
            let expected = (0..width).fold(SP1FieldExt::zero(), |acc, j| {
                acc + *host_scalars[[j]] * *host_tensor[[j, i]]
            });
            assert_eq!(*dot[[i]], expected, "Dot product at index {i} is incorrect");
        }
    }
}