zenu_matrix/operation/broadcast.rs

1use crate::{
2    device::Device,
3    dim::{DimDyn, DimTrait},
4    index::Index0D,
5    matrix::{Matrix, Ref, Repr},
6    num::Num,
7};
8
9impl<T: Num, D: Device> Matrix<Ref<&mut T>, DimDyn, D> {
10    #[expect(clippy::missing_panics_doc)]
11    pub fn broadcast<R: Repr<Item = T>>(&self, source: &Matrix<R, DimDyn, D>) {
12        let source = source.to_ref();
13        if !(self.shape().is_include(source.shape())
14            || self.shape().is_include_bradcast(source.shape()))
15        {
16            panic!("!self.shape().is_include(source.shape())");
17        }
18        if self.shape() == source.shape() {
19            self.copy_from(&source);
20            return;
21        }
22        if !source.shape().is_empty() && source.shape()[0] == 1 {
23            let source = source.index_axis_dyn(Index0D::new(0));
24            self.broadcast(&source);
25            return;
26        }
27
28        let diff_len = self.shape().len() - source.shape().len();
29
30        if diff_len == 1 {
31            for i in 0..self.shape()[0] {
32                let to = self.index_axis_mut_dyn(Index0D::new(i));
33                to.copy_from(&source);
34            }
35        } else {
36            for i in 0..self.shape()[0] {
37                let to = self.index_axis_mut_dyn(Index0D::new(i));
38                to.broadcast(&source);
39            }
40        }
41    }
42}
43
#[cfg(test)]
mod broadcast_test {
    #![expect(clippy::float_cmp)]
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    /// Broadcasts `source` into `target` (in place) and hands `target` back.
    fn broadcast_into<D: Device>(
        source: &Matrix<Owned<f32>, DimDyn, D>,
        mut target: Matrix<Owned<f32>, DimDyn, D>,
    ) -> Matrix<Owned<f32>, DimDyn, D> {
        target.to_ref_mut().broadcast(&source.to_ref());
        target
    }

    /// Asserts that `actual` matches `expected` by requiring the `asum` of their
    /// difference to be exactly zero.
    fn assert_matrix_matches<D: Device>(
        expected: &Matrix<Owned<f32>, DimDyn, D>,
        actual: &Matrix<Owned<f32>, DimDyn, D>,
    ) {
        let residual = expected.to_ref() - actual.to_ref();
        let total = residual.to_ref().asum();
        assert_eq!(total, 0.);
    }

    fn broadcast_1d_0d<D: Device>() {
        let scalar: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![1.], []);
        let got = broadcast_into(&scalar, Matrix::zeros([3]));
        let want: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![1.0_f32; 3], [3]);
        assert_matrix_matches(&want, &got);
    }
    #[test]
    fn broadcast_1d_0d_cpu() {
        broadcast_1d_0d::<crate::device::cpu::Cpu>();
    }
    #[test]
    #[cfg(feature = "nvidia")]
    fn broadcast_1d_0d_nvidia() {
        broadcast_1d_0d::<crate::device::nvidia::Nvidia>();
    }

    fn broadcast_2d_0d<D: Device>() {
        let scalar: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![1.], []);
        let got = broadcast_into(&scalar, Matrix::zeros([2, 3]));
        let want: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![1.0_f32; 6], [2, 3]);
        assert_matrix_matches(&want, &got);
    }
    #[test]
    fn broadcast_2d_0d_cpu() {
        broadcast_2d_0d::<crate::device::cpu::Cpu>();
    }
    #[test]
    #[cfg(feature = "nvidia")]
    fn broadcast_2d_0d_nvidia() {
        broadcast_2d_0d::<crate::device::nvidia::Nvidia>();
    }

    fn broadcast_2d_1d<D: Device>() {
        let row: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![1., 2., 3.], [3]);
        let got = broadcast_into(&row, Matrix::zeros([2, 3]));
        let want: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec([1.0_f32, 2., 3.].repeat(2), [2, 3]);
        assert_matrix_matches(&want, &got);
    }
    #[test]
    fn broadcast_2d_1d_cpu() {
        broadcast_2d_1d::<crate::device::cpu::Cpu>();
    }
    #[test]
    #[cfg(feature = "nvidia")]
    fn broadcast_2d_1d_nvidia() {
        broadcast_2d_1d::<crate::device::nvidia::Nvidia>();
    }

    fn broadcast_4d_2d<D: Device>() {
        let pair: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![1., 2.], [1, 2]);
        let got = broadcast_into(&pair, Matrix::zeros([2, 3, 1, 2]));
        let want: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec([1.0_f32, 2.].repeat(2 * 3), [2, 3, 1, 2]);
        assert_matrix_matches(&want, &got);
    }
    #[test]
    fn broadcast_4d_2d_cpu() {
        broadcast_4d_2d::<crate::device::cpu::Cpu>();
    }
    #[test]
    #[cfg(feature = "nvidia")]
    fn broadcast_4d_2d_nvidia() {
        broadcast_4d_2d::<crate::device::nvidia::Nvidia>();
    }

    fn broadcast_4d_4d<D: Device>() {
        let quad: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1., 2., 3., 4.], [1, 1, 1, 4]);
        let got = broadcast_into(&quad, Matrix::zeros([2, 3, 4, 4]));
        // 2 * 3 * 4 rows, each a copy of [1, 2, 3, 4].
        let want: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec([1.0_f32, 2., 3., 4.].repeat(2 * 3 * 4), [2, 3, 4, 4]);
        assert_matrix_matches(&want, &got);
    }
    #[test]
    fn broadcast_4d_4d_cpu() {
        broadcast_4d_4d::<crate::device::cpu::Cpu>();
    }
    #[test]
    #[cfg(feature = "nvidia")]
    fn broadcast_4d_4d_nvidia() {
        broadcast_4d_4d::<crate::device::nvidia::Nvidia>();
    }
}