Skip to main content

sp1_gpu_cudart/tensor/
dot.rs

1use slop_tensor::{Tensor, TensorView};
2use sp1_gpu_sys::{
3    reduce::{
4        dot_along_short_dimension_kernel_koala_bear_base_base,
5        dot_along_short_dimension_kernel_koala_bear_base_extension,
6        dot_along_short_dimension_kernel_koala_bear_extension_extension,
7        partial_dot_koala_bear_base_extension_kernel, partial_dot_koala_bear_extension_kernel,
8        partial_dot_koala_bear_kernel,
9    },
10    runtime::KernelPtr,
11};
12use sp1_primitives::{SP1ExtensionField, SP1Field};
13
14use crate::{args, reduce::partial_sum_reduction_into, DeviceCopy, DeviceTensor, TaskScope};
15
16use super::reduce::DeviceSumKernel;
17
18/// # Safety
19///
20pub unsafe trait DotKernel<T: DeviceCopy, U: DeviceCopy>: DeviceSumKernel<U> {
21    fn partial_dot_kernel_last_dim() -> KernelPtr;
22
23    fn dot_along_short_dimension_kernel() -> KernelPtr;
24}
25
26pub fn dot_along_dim_view<'a, T: DeviceCopy, U: DeviceCopy>(
27    src: TensorView<'a, T, TaskScope>,
28    scalars: TensorView<'a, U, TaskScope>,
29    dim: usize,
30) -> Tensor<U, TaskScope>
31where
32    TaskScope: DotKernel<T, U>,
33{
34    let mut sizes = src.sizes().to_vec();
35    sizes.remove(dim);
36    let mut dst = Tensor::with_sizes_in(sizes, src.backend().clone());
37    assert_eq!(src.sizes().len(), 2, "Dot product only supported for 2D tensors",);
38    let max_scalar_dim = *scalars.sizes().iter().max().unwrap();
39    assert_eq!(max_scalar_dim, scalars.total_len(), "The scalar tensor must be a 1D tensor");
40    match dim {
41        dim if dim == src.sizes().len() - 1 => {
42            let height = src.sizes()[dim];
43            let width = src.total_len() / height;
44
45            let null_ptr = std::ptr::null::<std::ffi::c_void>();
46            let partial_args = args!(null_ptr, src.as_ptr(), scalars.as_ptr(), width, height);
47            const BLOCK_SIZE: usize = 256;
48            const INTIAL_STRIDE: usize = 4;
49            dst.storage.write_bytes(0, dst.total_len() * std::mem::size_of::<U>()).unwrap();
50            unsafe {
51                partial_sum_reduction_into::<U, BLOCK_SIZE, INTIAL_STRIDE, 5>(
52                    dst.as_view_mut(),
53                    TaskScope::partial_dot_kernel_last_dim(),
54                    partial_args,
55                    0,
56                    src.shape(),
57                    dim,
58                    src.backend(),
59                );
60            }
61        }
62        0 => {
63            let height = src.sizes()[1];
64            let width = src.total_len() / height;
65
66            const BLOCK_SIZE: usize = 256;
67            let args = args!(dst.as_mut_ptr(), src.as_ptr(), scalars.as_ptr(), width, height);
68            let grid_dim = height.div_ceil(BLOCK_SIZE);
69            unsafe {
70                dst.assume_init();
71                src.backend()
72                    .launch_kernel(
73                        TaskScope::dot_along_short_dimension_kernel(),
74                        grid_dim,
75                        BLOCK_SIZE,
76                        &args,
77                        0,
78                    )
79                    .unwrap();
80            }
81        }
82        _ => panic!(
83            "Dot product is not supported along dimension {} for tensor of sizes {:?}",
84            dim,
85            src.sizes()
86        ),
87    }
88    dst
89}
90
91impl<T: DeviceCopy> DeviceTensor<T> {
92    pub fn dot_along_dim<U: DeviceCopy>(
93        &self,
94        scalars: &DeviceTensor<U>,
95        dim: usize,
96    ) -> DeviceTensor<U>
97    where
98        TaskScope: DotKernel<T, U>,
99    {
100        let raw = dot_along_dim_view(self.raw.as_view(), scalars.raw.as_view(), dim);
101        DeviceTensor { raw }
102    }
103}
104
105unsafe impl DotKernel<SP1Field, SP1Field> for TaskScope {
106    fn partial_dot_kernel_last_dim() -> KernelPtr {
107        unsafe { partial_dot_koala_bear_kernel() }
108    }
109
110    fn dot_along_short_dimension_kernel() -> KernelPtr {
111        unsafe { dot_along_short_dimension_kernel_koala_bear_base_base() }
112    }
113}
114
115unsafe impl DotKernel<SP1ExtensionField, SP1ExtensionField> for TaskScope {
116    fn partial_dot_kernel_last_dim() -> KernelPtr {
117        unsafe { partial_dot_koala_bear_extension_kernel() }
118    }
119
120    fn dot_along_short_dimension_kernel() -> KernelPtr {
121        unsafe { dot_along_short_dimension_kernel_koala_bear_extension_extension() }
122    }
123}
124
125unsafe impl DotKernel<SP1Field, SP1ExtensionField> for TaskScope {
126    fn partial_dot_kernel_last_dim() -> KernelPtr {
127        unsafe { partial_dot_koala_bear_base_extension_kernel() }
128    }
129
130    fn dot_along_short_dimension_kernel() -> KernelPtr {
131        unsafe { dot_along_short_dimension_kernel_koala_bear_base_extension() }
132    }
133}
134
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use slop_algebra::AbstractField;
    use slop_tensor::Tensor;
    use sp1_primitives::{SP1ExtensionField, SP1Field};

    use super::DeviceTensor;

    type SP1FieldExt = SP1ExtensionField;

    /// Base-field tensor against base-field scalars, reduced along the last dimension.
    #[test]
    fn test_koala_bear_dot() {
        let mut rng = rand::thread_rng();
        let num_summands = 100;

        for size in [10, 100, 1 << 16] {
            let host_tensor = Tensor::<SP1Field>::rand(&mut rng, [num_summands, size]);
            let host_scalars = Tensor::<SP1Field>::rand(&mut rng, [size]);

            let result = crate::run_sync_in_place(|t| {
                let tensor = DeviceTensor::from_host(&host_tensor, &t).unwrap();
                let scalars = DeviceTensor::from_host(&host_scalars, &t).unwrap();
                tensor.dot_along_dim(&scalars, 1).to_host().unwrap()
            })
            .unwrap();

            assert_eq!(result.sizes(), [num_summands]);
            for row in 0..num_summands {
                let row_view = host_tensor.get(row).unwrap();
                let expected: SP1Field = row_view
                    .as_slice()
                    .iter()
                    .zip_eq(host_scalars.as_buffer().iter())
                    .map(|(&x, &s)| x * s)
                    .sum();
                assert_eq!(expected, *result[[row]]);
            }
        }
    }

    /// Extension-field tensor against extension-field scalars, along the last dimension.
    #[test]
    fn test_koala_bear_extension_dot() {
        type EF = SP1ExtensionField;

        let mut rng = rand::thread_rng();
        let num_summands = 100;

        for size in [10, 100, 1 << 16] {
            let host_tensor = Tensor::<EF>::rand(&mut rng, [num_summands, size]);
            let host_scalars = Tensor::<EF>::rand(&mut rng, [size]);

            let result = crate::run_sync_in_place(|t| {
                let tensor = DeviceTensor::from_host(&host_tensor, &t).unwrap();
                let scalars = DeviceTensor::from_host(&host_scalars, &t).unwrap();
                tensor.dot_along_dim(&scalars, 1).to_host().unwrap()
            })
            .unwrap();

            assert_eq!(result.sizes(), [num_summands]);
            for row in 0..num_summands {
                let row_view = host_tensor.get(row).unwrap();
                let expected: EF = row_view
                    .as_slice()
                    .iter()
                    .zip_eq(host_scalars.as_buffer().iter())
                    .map(|(&x, &s)| x * s)
                    .sum();
                assert_eq!(expected, *result[[row]]);
            }
        }
    }

    /// Base-field tensor against extension-field scalars along the last dimension, with the
    /// device time of the dot product logged for each shape.
    #[test]
    fn test_koala_bear_base_extension_dot() {
        type F = SP1Field;
        type EF = SP1ExtensionField;

        let mut rng = rand::thread_rng();

        for size in [10, 100, 1 << 10, 1 << 12, 1 << 16] {
            for num_summands in [64, 128] {
                let host_tensor = Tensor::<F>::rand(&mut rng, [num_summands, size]);
                let host_scalars = Tensor::<EF>::rand(&mut rng, [size]);

                let result = crate::run_sync_in_place(|t| {
                    let tensor = DeviceTensor::from_host(&host_tensor, &t).unwrap();
                    let scalars = DeviceTensor::from_host(&host_scalars, &t).unwrap();
                    // Drain the uploads so only the dot product is timed.
                    t.synchronize_blocking().unwrap();
                    let start = std::time::Instant::now();
                    let dot = tensor.dot_along_dim(&scalars, 1);
                    t.synchronize_blocking().unwrap();
                    tracing::info!(
                        "Dot time for size {}, num_summands: {}, time: {:?}",
                        size,
                        num_summands,
                        start.elapsed()
                    );
                    dot.to_host().unwrap()
                })
                .unwrap();

                assert_eq!(result.sizes(), [num_summands]);
                for row in 0..num_summands {
                    let row_view = host_tensor.get(row).unwrap();
                    // Scalar (extension) on the left: `EF * F`.
                    let expected: EF = row_view
                        .as_slice()
                        .iter()
                        .zip_eq(host_scalars.as_buffer().iter())
                        .map(|(&x, &s)| s * x)
                        .sum();
                    assert_eq!(expected, *result[[row]]);
                }
            }
        }
    }

    /// Column-wise (`dim == 0`) dot product: base tensor, base scalars.
    #[test]
    fn test_dot_along_dim_0_base_base() {
        let mut rng = rand::thread_rng();
        let (rows, cols) = (10, 1500);

        let host_tensor = Tensor::<SP1Field>::rand(&mut rng, [rows, cols]);
        let host_scalars = Tensor::<SP1Field>::rand(&mut rng, [rows]);

        let dot = crate::run_sync_in_place(|t| {
            let tensor = DeviceTensor::from_host(&host_tensor, &t).unwrap();
            let scalars = DeviceTensor::from_host(&host_scalars, &t).unwrap();
            tensor.dot_along_dim(&scalars, 0).to_host().unwrap()
        })
        .unwrap();

        assert_eq!(dot.sizes(), [cols]);
        for i in 0..cols {
            let dot_product = (0..rows).fold(SP1Field::zero(), |acc, j| {
                acc + *host_scalars[[j]] * *host_tensor[[j, i]]
            });
            assert_eq!(*dot[[i]], dot_product, "Dot product at index {i} is incorrect");
        }
    }

    /// Column-wise (`dim == 0`) dot product: base tensor, extension scalars.
    #[test]
    fn test_dot_along_dim_0_base_ext() {
        let mut rng = rand::thread_rng();
        let (rows, cols) = (10, 1500);

        let host_tensor = Tensor::<SP1Field>::rand(&mut rng, [rows, cols]);
        let host_scalars = Tensor::<SP1FieldExt>::rand(&mut rng, [rows]);

        let dot = crate::run_sync_in_place(|t| {
            let tensor = DeviceTensor::from_host(&host_tensor, &t).unwrap();
            let scalars = DeviceTensor::from_host(&host_scalars, &t).unwrap();
            tensor.dot_along_dim(&scalars, 0).to_host().unwrap()
        })
        .unwrap();

        assert_eq!(dot.sizes(), [cols]);
        for i in 0..cols {
            let dot_product = (0..rows).fold(SP1FieldExt::zero(), |acc, j| {
                acc + *host_scalars[[j]] * *host_tensor[[j, i]]
            });
            assert_eq!(*dot[[i]], dot_product, "Dot product at index {i} is incorrect");
        }
    }

    /// Column-wise (`dim == 0`) dot product: extension tensor, extension scalars.
    #[test]
    fn test_dot_along_dim_0_ext_ext() {
        let mut rng = rand::thread_rng();
        let (rows, cols) = (10, 1500);

        let host_tensor = Tensor::<SP1FieldExt>::rand(&mut rng, [rows, cols]);
        let host_scalars = Tensor::<SP1FieldExt>::rand(&mut rng, [rows]);

        let dot = crate::run_sync_in_place(|t| {
            let tensor = DeviceTensor::from_host(&host_tensor, &t).unwrap();
            let scalars = DeviceTensor::from_host(&host_scalars, &t).unwrap();
            tensor.dot_along_dim(&scalars, 0).to_host().unwrap()
        })
        .unwrap();

        assert_eq!(dot.sizes(), [cols]);
        for i in 0..cols {
            let dot_product = (0..rows).fold(SP1FieldExt::zero(), |acc, j| {
                acc + *host_scalars[[j]] * *host_tensor[[j, i]]
            });
            assert_eq!(*dot[[i]], dot_product, "Dot product at index {i} is incorrect");
        }
    }
}