1use crate::{
2 device::Device,
3 dim::{DimDyn, DimTrait},
4 index::Index0D,
5 matrix::{Matrix, Ref, Repr},
6 num::Num,
7};
8
9impl<T: Num, D: Device> Matrix<Ref<&mut T>, DimDyn, D> {
10 #[expect(clippy::missing_panics_doc)]
11 pub fn broadcast<R: Repr<Item = T>>(&self, source: &Matrix<R, DimDyn, D>) {
12 let source = source.to_ref();
13 if !(self.shape().is_include(source.shape())
14 || self.shape().is_include_bradcast(source.shape()))
15 {
16 panic!("!self.shape().is_include(source.shape())");
17 }
18 if self.shape() == source.shape() {
19 self.copy_from(&source);
20 return;
21 }
22 if !source.shape().is_empty() && source.shape()[0] == 1 {
23 let source = source.index_axis_dyn(Index0D::new(0));
24 self.broadcast(&source);
25 return;
26 }
27
28 let diff_len = self.shape().len() - source.shape().len();
29
30 if diff_len == 1 {
31 for i in 0..self.shape()[0] {
32 let to = self.index_axis_mut_dyn(Index0D::new(i));
33 to.copy_from(&source);
34 }
35 } else {
36 for i in 0..self.shape()[0] {
37 let to = self.index_axis_mut_dyn(Index0D::new(i));
38 to.broadcast(&source);
39 }
40 }
41 }
42}
43
#[cfg(test)]
mod broadcast_test {
    #![expect(clippy::float_cmp)]
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    /// Owned `f32` matrix with a dynamic shape on device `D`.
    type Mat<D> = Matrix<Owned<f32>, DimDyn, D>;

    /// Broadcasts `source` into `target` in place and hands `target` back.
    fn broadcast_into<D: Device>(source: &Mat<D>, mut target: Mat<D>) -> Mat<D> {
        target.to_ref_mut().broadcast(&source.to_ref());
        target
    }

    /// Asserts that `expected` and `actual` agree element-wise: the sum of the
    /// absolute differences must be exactly zero.
    fn assert_matches<D: Device>(expected: &Mat<D>, actual: &Mat<D>) {
        let diff = expected.to_ref() - actual.to_ref();
        let diff_sum = diff.to_ref().asum();
        assert_eq!(diff_sum, 0.);
    }

    // 0-d scalar repeated along a single axis.
    fn broadcast_1d_0d<D: Device>() {
        let scalar: Mat<D> = Matrix::from_vec(vec![1.], []);
        let out = broadcast_into(&scalar, Matrix::zeros([3]));
        let want: Mat<D> = Matrix::from_vec(vec![1.; 3], [3]);
        assert_matches(&want, &out);
    }

    #[test]
    fn broadcast_1d_0d_cpu() {
        broadcast_1d_0d::<crate::device::cpu::Cpu>();
    }

    #[test]
    #[cfg(feature = "nvidia")]
    fn broadcast_1d_0d_nvidia() {
        broadcast_1d_0d::<crate::device::nvidia::Nvidia>();
    }

    // 0-d scalar filling a 2-d target.
    fn broadcast_2d_0d<D: Device>() {
        let scalar: Mat<D> = Matrix::from_vec(vec![1.], []);
        let out = broadcast_into(&scalar, Matrix::zeros([2, 3]));
        let want: Mat<D> = Matrix::from_vec(vec![1.; 6], [2, 3]);
        assert_matches(&want, &out);
    }

    #[test]
    fn broadcast_2d_0d_cpu() {
        broadcast_2d_0d::<crate::device::cpu::Cpu>();
    }

    #[test]
    #[cfg(feature = "nvidia")]
    fn broadcast_2d_0d_nvidia() {
        broadcast_2d_0d::<crate::device::nvidia::Nvidia>();
    }

    // A row vector copied into every row of a 2-d target.
    fn broadcast_2d_1d<D: Device>() {
        let row: Mat<D> = Matrix::from_vec(vec![1., 2., 3.], [3]);
        let out = broadcast_into(&row, Matrix::zeros([2, 3]));
        let want: Mat<D> = Matrix::from_vec([1., 2., 3.].repeat(2), [2, 3]);
        assert_matches(&want, &out);
    }

    #[test]
    fn broadcast_2d_1d_cpu() {
        broadcast_2d_1d::<crate::device::cpu::Cpu>();
    }

    #[test]
    #[cfg(feature = "nvidia")]
    fn broadcast_2d_1d_nvidia() {
        broadcast_2d_1d::<crate::device::nvidia::Nvidia>();
    }

    // A [1, 2] source with a leading unit axis spread over a 4-d target.
    fn broadcast_4d_2d<D: Device>() {
        let pair: Mat<D> = Matrix::from_vec(vec![1., 2.], [1, 2]);
        let out = broadcast_into(&pair, Matrix::zeros([2, 3, 1, 2]));
        let want: Mat<D> = Matrix::from_vec([1., 2.].repeat(6), [2, 3, 1, 2]);
        assert_matches(&want, &out);
    }

    #[test]
    fn broadcast_4d_2d_cpu() {
        broadcast_4d_2d::<crate::device::cpu::Cpu>();
    }

    #[test]
    #[cfg(feature = "nvidia")]
    fn broadcast_4d_2d_nvidia() {
        broadcast_4d_2d::<crate::device::nvidia::Nvidia>();
    }

    // Same-rank source whose three leading axes all have extent 1.
    fn broadcast_4d_4d<D: Device>() {
        let quad: Mat<D> = Matrix::from_vec(vec![1., 2., 3., 4.], [1, 1, 1, 4]);
        let out = broadcast_into(&quad, Matrix::zeros([2, 3, 4, 4]));
        // 2 * 3 * 4 rows of [1, 2, 3, 4].
        let want: Mat<D> = Matrix::from_vec([1., 2., 3., 4.].repeat(24), [2, 3, 4, 4]);
        assert_matches(&want, &out);
    }

    #[test]
    fn broadcast_4d_4d_cpu() {
        broadcast_4d_4d::<crate::device::cpu::Cpu>();
    }

    #[test]
    #[cfg(feature = "nvidia")]
    fn broadcast_4d_4d_nvidia() {
        broadcast_4d_4d::<crate::device::nvidia::Nvidia>();
    }
}