1#![allow(dead_code)]
21
22use anyhow::Result;
23use scirs2_core::ndarray_ext::Zip;
24use scirs2_core::numeric::{Float, Num};
25use tenrso_core::DenseND;
26
/// How the shapes of two operands relate, as classified by
/// `detect_broadcast_pattern`; used to pick an element-wise kernel.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum BroadcastPattern {
    /// Both shapes are identical; no broadcasting is needed.
    SameShape,
    /// One operand has shape `[1]` or is rank-0.
    Scalar,
    /// Equal rank; shapes agree on every axis but the last, where one extent is `1`.
    LastDim,
    /// Equal rank; shapes agree on every axis but the first, where one extent is `1`.
    FirstDim,
    /// Any other combination of shapes.
    General,
}
41
42pub(crate) fn detect_broadcast_pattern(shape_a: &[usize], shape_b: &[usize]) -> BroadcastPattern {
44 if shape_a == shape_b {
46 return BroadcastPattern::SameShape;
47 }
48
49 if shape_a.len() == 1 && shape_a[0] == 1 {
51 return BroadcastPattern::Scalar;
52 }
53 if shape_b.len() == 1 && shape_b[0] == 1 {
54 return BroadcastPattern::Scalar;
55 }
56 if shape_a.is_empty() || shape_b.is_empty() {
57 return BroadcastPattern::Scalar;
58 }
59
60 if shape_a.len() == shape_b.len() {
62 let mut differs_only_in_last = true;
63 for i in 0..shape_a.len() - 1 {
64 if shape_a[i] != shape_b[i] {
65 differs_only_in_last = false;
66 break;
67 }
68 }
69 if differs_only_in_last
70 && (shape_a[shape_a.len() - 1] == 1 || shape_b[shape_b.len() - 1] == 1)
71 {
72 return BroadcastPattern::LastDim;
73 }
74 }
75
76 if shape_a.len() == shape_b.len() {
78 let mut differs_only_in_first = true;
79 for i in 1..shape_a.len() {
80 if shape_a[i] != shape_b[i] {
81 differs_only_in_first = false;
82 break;
83 }
84 }
85 if differs_only_in_first && (shape_a[0] == 1 || shape_b[0] == 1) {
86 return BroadcastPattern::FirstDim;
87 }
88 }
89
90 BroadcastPattern::General
91}
92
93pub(crate) fn vectorized_binary_op<T, F>(
102 a: &DenseND<T>,
103 b: &DenseND<T>,
104 op: F,
105) -> Result<DenseND<T>>
106where
107 T: Clone + Num + Float + Send + Sync,
108 F: Fn(T, T) -> T + Send + Sync,
109{
110 let pattern = detect_broadcast_pattern(a.shape(), b.shape());
111
112 match pattern {
113 BroadcastPattern::SameShape => vectorized_same_shape(a, b, op),
114 BroadcastPattern::Scalar => vectorized_scalar_broadcast(a, b, op),
115 BroadcastPattern::LastDim => vectorized_last_dim_broadcast(a, b, op),
116 BroadcastPattern::FirstDim => vectorized_first_dim_broadcast(a, b, op),
117 BroadcastPattern::General => vectorized_general_broadcast(a, b, op),
118 }
119}
120
121fn vectorized_same_shape<T, F>(a: &DenseND<T>, b: &DenseND<T>, op: F) -> Result<DenseND<T>>
123where
124 T: Clone + Num + Send + Sync,
125 F: Fn(T, T) -> T + Send + Sync,
126{
127 let a_view = a.view();
128 let b_view = b.view();
129
130 let result = Zip::from(&a_view)
132 .and(&b_view)
133 .par_map_collect(|a_val, b_val| op(a_val.clone(), b_val.clone()));
134
135 Ok(DenseND::from_array(result))
136}
137
138#[allow(dead_code)]
140fn vectorized_scalar_broadcast<T, F>(a: &DenseND<T>, b: &DenseND<T>, op: F) -> Result<DenseND<T>>
141where
142 T: Clone + Num + Float,
143 F: Fn(T, T) -> T,
144{
145 let a_view = a.view();
146 let b_view = b.view();
147
148 let (scalar_val, tensor_view, op_flipped) = if a.view().len() == 1 || a.shape().is_empty() {
150 let scalar = if a.view().len() == 1 {
151 a_view.iter().next().cloned().unwrap()
152 } else {
153 T::zero()
154 };
155 (scalar, b_view, false)
156 } else {
157 let scalar = if b.view().len() == 1 {
158 b_view.iter().next().cloned().unwrap()
159 } else {
160 T::zero()
161 };
162 (scalar, a_view, true)
163 };
164
165 let result = if op_flipped {
167 tensor_view.mapv(|v| op(v, scalar_val))
168 } else {
169 tensor_view.mapv(|v| op(scalar_val, v))
170 };
171
172 Ok(DenseND::from_array(result))
173}
174
175#[allow(dead_code)]
177fn vectorized_last_dim_broadcast<T, F>(a: &DenseND<T>, b: &DenseND<T>, op: F) -> Result<DenseND<T>>
178where
179 T: Clone + Num + Float,
180 F: Fn(T, T) -> T,
181{
182 let a_view = a.view();
183 let b_view = b.view();
184
185 let result = Zip::from(&a_view)
187 .and(&b_view)
188 .map_collect(|&a_val, &b_val| op(a_val, b_val));
189
190 Ok(DenseND::from_array(result))
191}
192
193#[allow(dead_code)]
195fn vectorized_first_dim_broadcast<T, F>(a: &DenseND<T>, b: &DenseND<T>, op: F) -> Result<DenseND<T>>
196where
197 T: Clone + Num + Float,
198 F: Fn(T, T) -> T,
199{
200 let a_view = a.view();
201 let b_view = b.view();
202
203 let result = Zip::from(&a_view)
205 .and(&b_view)
206 .map_collect(|&a_val, &b_val| op(a_val, b_val));
207
208 Ok(DenseND::from_array(result))
209}
210
211#[allow(dead_code)]
213fn vectorized_general_broadcast<T, F>(a: &DenseND<T>, b: &DenseND<T>, op: F) -> Result<DenseND<T>>
214where
215 T: Clone + Num + Float,
216 F: Fn(T, T) -> T,
217{
218 let a_view = a.view();
219 let b_view = b.view();
220
221 let result = Zip::from(&a_view)
223 .and(&b_view)
224 .map_collect(|&a_val, &b_val| op(a_val, b_val));
225
226 Ok(DenseND::from_array(result))
227}
228
229#[allow(dead_code)]
231pub(crate) fn vectorized_add<T>(a: &DenseND<T>, b: &DenseND<T>) -> Result<DenseND<T>>
232where
233 T: Clone + Num + Float + Send + Sync + std::ops::Add<Output = T>,
234{
235 let pattern = detect_broadcast_pattern(a.shape(), b.shape());
236
237 match pattern {
238 BroadcastPattern::SameShape => {
239 let result = &a.view() + &b.view();
241 Ok(DenseND::from_array(result))
242 }
243 BroadcastPattern::Scalar => {
244 vectorized_scalar_broadcast(a, b, |x, y| x + y)
246 }
247 _ => {
248 let result = &a.view() + &b.view();
250 Ok(DenseND::from_array(result))
251 }
252 }
253}
254
255#[allow(dead_code)]
257pub(crate) fn vectorized_mul<T>(a: &DenseND<T>, b: &DenseND<T>) -> Result<DenseND<T>>
258where
259 T: Clone + Num + Float + Send + Sync + std::ops::Mul<Output = T>,
260{
261 let pattern = detect_broadcast_pattern(a.shape(), b.shape());
262
263 match pattern {
264 BroadcastPattern::SameShape => {
265 let result = &a.view() * &b.view();
267 Ok(DenseND::from_array(result))
268 }
269 BroadcastPattern::Scalar => {
270 vectorized_scalar_broadcast(a, b, |x, y| x * y)
272 }
273 _ => {
274 let result = &a.view() * &b.view();
276 Ok(DenseND::from_array(result))
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a rank-1 tensor holds `expected[i]` at every index `i`
    /// of `expected`.
    fn assert_elements(tensor: &DenseND<f64>, expected: &[f64]) {
        let view = tensor.view();
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(view[[idx]], want, "element {} differs", idx);
        }
    }

    #[test]
    fn test_detect_broadcast_pattern_same_shape() {
        let pattern = detect_broadcast_pattern(&[3, 4], &[3, 4]);
        assert_eq!(pattern, BroadcastPattern::SameShape);
    }

    #[test]
    fn test_detect_broadcast_pattern_scalar() {
        let cases: [(&[usize], &[usize]); 2] = [(&[1], &[3, 4]), (&[3, 4], &[1])];
        for &(lhs, rhs) in cases.iter() {
            assert_eq!(detect_broadcast_pattern(lhs, rhs), BroadcastPattern::Scalar);
        }
    }

    #[test]
    fn test_detect_broadcast_pattern_last_dim() {
        let pattern = detect_broadcast_pattern(&[3, 4, 1], &[3, 4, 5]);
        assert_eq!(pattern, BroadcastPattern::LastDim);
    }

    #[test]
    fn test_detect_broadcast_pattern_first_dim() {
        let pattern = detect_broadcast_pattern(&[1, 4, 5], &[3, 4, 5]);
        assert_eq!(pattern, BroadcastPattern::FirstDim);
    }

    #[test]
    fn test_vectorized_add_same_shape() {
        let lhs = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rhs = DenseND::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[4]).unwrap();

        let sum = vectorized_add(&lhs, &rhs).unwrap();
        assert_elements(&sum, &[6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_vectorized_add_scalar() {
        let scalar = DenseND::from_vec(vec![5.0], &[1]).unwrap();
        let tensor = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let sum = vectorized_add(&scalar, &tensor).unwrap();
        assert_elements(&sum, &[6.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_vectorized_mul_same_shape() {
        let lhs = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rhs = DenseND::from_vec(vec![2.0, 3.0, 4.0, 5.0], &[4]).unwrap();

        let product = vectorized_mul(&lhs, &rhs).unwrap();
        assert_elements(&product, &[2.0, 6.0, 12.0, 20.0]);
    }

    #[test]
    fn test_vectorized_mul_scalar() {
        let scalar = DenseND::from_vec(vec![2.0], &[1]).unwrap();
        let tensor = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let product = vectorized_mul(&scalar, &tensor).unwrap();
        assert_elements(&product, &[2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_vectorized_binary_op_custom() {
        let lhs = DenseND::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let rhs = DenseND::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        // x^2 + y: 1+4, 4+5, 9+6
        let out = vectorized_binary_op(&lhs, &rhs, |x, y| x * x + y).unwrap();
        assert_elements(&out, &[5.0, 9.0, 15.0]);
    }
}