1use rayon::prelude::*;
2
3use slop_algebra::{AbstractExtensionField, AbstractField};
4use slop_alloc::{buffer, Buffer, CpuBackend};
5
6use crate::{Dimensions, Tensor};
7
8pub fn dot_along_dim<T, U>(
13 src: &Tensor<T, CpuBackend>,
14 scalars: &Tensor<U, CpuBackend>,
15 dim: usize,
16) -> Tensor<U, CpuBackend>
17where
18 T: AbstractField + 'static + Sync,
19 U: AbstractExtensionField<T> + 'static + Send + Sync,
20{
21 let mut sizes = src.sizes().to_vec();
22 sizes.remove(dim);
23 let dimensions = Dimensions::try_from(sizes).unwrap();
24 let mut dst = Tensor { storage: buffer![], dimensions };
25 let max_scalar_dim = *scalars.sizes().iter().max().unwrap();
26 assert_eq!(max_scalar_dim, scalars.total_len(), "The scalar tensor must be a 1D tensor");
27 match dim {
28 0 => {
29 assert!(
30 src.sizes().len() <= 2,
31 "Only 1D and 2D dimensional tensors are supported for dim 0"
32 );
33 let total_len = dst.total_len();
34 let dot_products = src
35 .as_buffer()
36 .par_chunks_exact(src.strides()[0])
37 .zip(scalars.as_buffer().par_iter())
38 .map(|(chunk, scalar)| chunk.iter().map(|a| scalar.clone() * a.clone()).collect())
39 .reduce(
40 || vec![U::zero(); total_len],
41 |mut a, b| {
42 a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a += b.clone());
43 a
44 },
45 );
46
47 let dot_products = Buffer::from(dot_products);
48 dst.storage = dot_products;
49 }
50 dim if dim == src.sizes().len() - 1 => {
51 let mut dst_storage = Vec::<U>::with_capacity(dst.total_len());
52 src.as_buffer()
53 .par_chunks_exact(src.strides()[dim - 1])
54 .map(|chunk| {
55 scalars
56 .as_buffer()
57 .iter()
58 .zip(chunk.iter())
59 .map(|(a, b)| a.clone() * b.clone())
60 .sum::<U>()
61 })
62 .collect_into_vec(&mut dst_storage);
63 dst.storage = Buffer::from(dst_storage);
64 }
65 _ => {
66 panic!("Unsupported dot product dimension {} for tensor sizes: {:?}", dim, src.sizes())
67 }
68 }
69 dst
70}
71
#[cfg(test)]
mod tests {
    use slop_algebra::AbstractField;
    use slop_baby_bear::BabyBear;

    use super::*;

    /// Contracting along dim 0 of a [1500, 10] tensor must match a naive
    /// serial weighted sum of rows.
    #[test]
    fn test_dot_along_dim_0() {
        let mut rng = rand::thread_rng();
        let tensor = Tensor::<BabyBear, CpuBackend>::rand(&mut rng, [1500, 10]);
        let scalars = Tensor::<BabyBear, CpuBackend>::rand(&mut rng, [1500]);
        let dot = dot_along_dim(&tensor, &scalars, 0);
        for col in 0..10 {
            let expected = (0..1500).fold(BabyBear::zero(), |acc, row| {
                acc + *scalars[[row]] * *tensor[[row, col]]
            });
            assert_eq!(*dot[[col]], expected);
        }
    }

    /// Contracting along the last dim of a [10, 1500, 10] tensor must match a
    /// naive serial dot product per (outer, mid) pair.
    #[test]
    fn test_dot_along_dim_last() {
        let mut rng = rand::thread_rng();
        let tensor = Tensor::<BabyBear, CpuBackend>::rand(&mut rng, [10, 1500, 10]);
        let scalars = Tensor::<BabyBear, CpuBackend>::rand(&mut rng, [10]);
        let dot = dot_along_dim(&tensor, &scalars, 2);
        for outer in 0..10 {
            for mid in 0..1500 {
                let expected = (0..10).fold(BabyBear::zero(), |acc, inner| {
                    acc + *scalars[[inner]] * *tensor[[outer, mid, inner]]
                });
                assert_eq!(*dot[[outer, mid]], expected);
            }
        }
    }
}