
slop_tensor/dot.rs

use rayon::prelude::*;

use slop_algebra::{AbstractExtensionField, AbstractField};
use slop_alloc::{buffer, Buffer, CpuBackend};

use crate::{Dimensions, Tensor};

/// Compute the dot product of a tensor with a scalar tensor along a given dimension.
///
/// The scalar tensor is assumed to be 1D, i.e. any tensor of shape
/// `[len, 1, 1, 1, ..]`.
///
/// # Panics
///
/// Panics if the scalar tensor is not 1D, if `dim` is neither `0` nor the last
/// dimension, or if `dim == 0` and `src` has more than two dimensions.
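///
/// # Examples
///
/// A minimal usage sketch mirroring the tests below (marked `ignore` since it
/// depends on the `slop_baby_bear` and `rand` crates; the field type and shapes
/// are illustrative):
///
/// ```ignore
/// use slop_baby_bear::BabyBear;
///
/// let mut rng = rand::thread_rng();
/// // Dot a `[4, 3]` tensor with 4 scalars along dim 0; dim 0 is removed,
/// // leaving a tensor of shape `[3]`.
/// let tensor = Tensor::<BabyBear, CpuBackend>::rand(&mut rng, [4, 3]);
/// let scalars = Tensor::<BabyBear, CpuBackend>::rand(&mut rng, [4]);
/// let dot = dot_along_dim(&tensor, &scalars, 0);
/// assert_eq!(dot.total_len(), 3);
/// ```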
pub fn dot_along_dim<T, U>(
    src: &Tensor<T, CpuBackend>,
    scalars: &Tensor<U, CpuBackend>,
    dim: usize,
) -> Tensor<U, CpuBackend>
where
    T: AbstractField + 'static + Sync,
    U: AbstractExtensionField<T> + 'static + Send + Sync,
{
    // The destination has the same shape as `src` with `dim` removed.
    let mut sizes = src.sizes().to_vec();
    sizes.remove(dim);
    let dimensions = Dimensions::try_from(sizes).unwrap();
    let mut dst = Tensor { storage: buffer![], dimensions };
    // A 1D tensor has a single non-trivial dimension, so its largest size must
    // account for every element.
    let max_scalar_dim = *scalars.sizes().iter().max().unwrap();
    assert_eq!(max_scalar_dim, scalars.total_len(), "The scalar tensor must be a 1D tensor");
    match dim {
        0 => {
            assert!(
                src.sizes().len() <= 2,
                "Only 1D and 2D tensors are supported for dim 0"
            );
            let total_len = dst.total_len();
            // Scale each contiguous chunk of `strides()[0]` elements by its scalar,
            // then sum the scaled chunks element-wise in parallel.
            let dot_products = src
                .as_buffer()
                .par_chunks_exact(src.strides()[0])
                .zip(scalars.as_buffer().par_iter())
                .map(|(chunk, scalar)| chunk.iter().map(|a| scalar.clone() * a.clone()).collect())
                .reduce(
                    || vec![U::zero(); total_len],
                    |mut a, b| {
                        a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a += b.clone());
                        a
                    },
                );

            let dot_products = Buffer::from(dot_products);
            dst.storage = dot_products;
        }
        dim if dim == src.sizes().len() - 1 => {
            // Each contiguous chunk of `strides()[dim - 1]` elements is a slice
            // along the last dimension; dot each one with the scalars in parallel.
            let mut dst_storage = Vec::<U>::with_capacity(dst.total_len());
            src.as_buffer()
                .par_chunks_exact(src.strides()[dim - 1])
                .map(|chunk| {
                    scalars
                        .as_buffer()
                        .iter()
                        .zip(chunk.iter())
                        .map(|(a, b)| a.clone() * b.clone())
                        .sum::<U>()
                })
                .collect_into_vec(&mut dst_storage);
            dst.storage = Buffer::from(dst_storage);
        }
        _ => {
            panic!("Unsupported dot product dimension {} for tensor sizes: {:?}", dim, src.sizes())
        }
    }
    dst
}

#[cfg(test)]
mod tests {
    use slop_algebra::AbstractField;
    use slop_baby_bear::BabyBear;

    use super::*;

    #[test]
    fn test_dot_along_dim_0() {
        let mut rng = rand::thread_rng();
        let tensor = Tensor::<BabyBear, CpuBackend>::rand(&mut rng, [1500, 10]);
        let scalars = Tensor::<BabyBear, CpuBackend>::rand(&mut rng, [1500]);
        let dot = dot_along_dim(&tensor, &scalars, 0);
        // Check against a naive double loop over the same data.
        for j in 0..10 {
            let mut dot_product = BabyBear::zero();
            for i in 0..1500 {
                dot_product += *scalars[[i]] * *tensor[[i, j]];
            }
            assert_eq!(*dot[[j]], dot_product);
        }
    }

    #[test]
    fn test_dot_along_dim_last() {
        let mut rng = rand::thread_rng();
        let tensor = Tensor::<BabyBear, CpuBackend>::rand(&mut rng, [10, 1500, 10]);
        let scalars = Tensor::<BabyBear, CpuBackend>::rand(&mut rng, [10]);
        let dot = dot_along_dim(&tensor, &scalars, 2);
        // Check against a naive triple loop over the same data.
        for k in 0..10 {
            for i in 0..1500 {
                let mut dot_product = BabyBear::zero();
                for j in 0..10 {
                    dot_product += *scalars[[j]] * *tensor[[k, i, j]];
                }
                assert_eq!(*dot[[k, i]], dot_product);
            }
        }
    }
}