tenrso_exec/executor/vectorized_broadcast.rs

//! Vectorized broadcasting operations
//!
//! This module provides high-performance broadcasting for aligned tensor shapes.
//! When tensor shapes are compatible and memory is properly aligned, the
//! kernels below compile to SIMD-friendly loops for significant speedups.
//!
//! # Optimization Strategy
//!
//! 1. Detect common broadcasting patterns (scalar, 1-D broadcast, etc.)
//! 2. Use specialized kernels for each pattern
//! 3. Leverage SIMD for aligned, contiguous data
//! 4. Fall back to standard broadcasting for complex cases
//!
//! # Common Patterns
//!
//! - Scalar broadcast: (1,) + (N,M,K) → (N,M,K)
//! - Vector broadcast: (M,) + (N,M) → (N,M)
//! - Matrix broadcast: (M,K) + (N,M,K) → (N,M,K)
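//!
//! # Example (sketch)
//!
//! Pattern detection drives kernel selection. The items here are
//! crate-private, so this snippet is illustrative rather than a compiled
//! doctest (it mirrors the assertions in the tests at the bottom of this
//! file):
//!
//! ```ignore
//! assert_eq!(detect_broadcast_pattern(&[3, 4], &[3, 4]), BroadcastPattern::SameShape);
//! assert_eq!(detect_broadcast_pattern(&[1], &[3, 4]), BroadcastPattern::Scalar);
//! assert_eq!(detect_broadcast_pattern(&[3, 1], &[3, 5]), BroadcastPattern::LastDim);
//! ```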

#![allow(dead_code)]

use anyhow::Result;
use scirs2_core::ndarray_ext::Zip;
use scirs2_core::numeric::{Float, Num};
use tenrso_core::DenseND;

/// Broadcasting patterns we can optimize
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum BroadcastPattern {
    /// Exact same shape - no broadcasting needed
    SameShape,
    /// One operand is a scalar
    Scalar,
    /// Broadcasting along the last dimension
    LastDim,
    /// Broadcasting along the first dimension
    FirstDim,
    /// General broadcasting (fall back to ndarray)
    General,
}

/// Detect the broadcasting pattern between two shapes
pub(crate) fn detect_broadcast_pattern(shape_a: &[usize], shape_b: &[usize]) -> BroadcastPattern {
    // Same shape - no broadcasting
    if shape_a == shape_b {
        return BroadcastPattern::SameShape;
    }

    // Scalar cases: a one-element `[1]` shape or a 0-d (empty) shape
    if shape_a.len() == 1 && shape_a[0] == 1 {
        return BroadcastPattern::Scalar;
    }
    if shape_b.len() == 1 && shape_b[0] == 1 {
        return BroadcastPattern::Scalar;
    }
    if shape_a.is_empty() || shape_b.is_empty() {
        return BroadcastPattern::Scalar;
    }

    if shape_a.len() == shape_b.len() {
        let n = shape_a.len();
        // Last dimension broadcast: shapes agree on every axis except a
        // size-1 last axis on one side
        if shape_a[..n - 1] == shape_b[..n - 1]
            && (shape_a[n - 1] == 1 || shape_b[n - 1] == 1)
        {
            return BroadcastPattern::LastDim;
        }
        // First dimension broadcast: shapes agree on every axis except a
        // size-1 first axis on one side
        if shape_a[1..] == shape_b[1..] && (shape_a[0] == 1 || shape_b[0] == 1) {
            return BroadcastPattern::FirstDim;
        }
    }

    BroadcastPattern::General
}

/// Vectorized binary operation with optimized broadcasting
///
/// # Performance
///
/// - SameShape: direct parallel element-wise traversal, no broadcast overhead
/// - Scalar: optimized scalar broadcast loops
/// - LastDim/FirstDim: cache-friendly strided operations
/// - General: falls back to ndarray-style broadcasting
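///
/// # Example
///
/// A minimal sketch (marked `ignore` because these crate-private items are
/// not visible to doctests):
///
/// ```ignore
/// // (3,1) + (3,2) is detected as a LastDim broadcast.
/// let col = DenseND::from_vec(vec![1.0, 2.0, 3.0], &[3, 1])?;
/// let mat = DenseND::from_vec(vec![1.0; 6], &[3, 2])?;
/// let sum = vectorized_binary_op(&col, &mat, |x, y| x + y)?;
/// assert_eq!(sum.shape(), &[3, 2]);
/// ```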
pub(crate) fn vectorized_binary_op<T, F>(
    a: &DenseND<T>,
    b: &DenseND<T>,
    op: F,
) -> Result<DenseND<T>>
where
    T: Clone + Num + Float + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync,
{
    let pattern = detect_broadcast_pattern(a.shape(), b.shape());

    match pattern {
        BroadcastPattern::SameShape => vectorized_same_shape(a, b, op),
        BroadcastPattern::Scalar => vectorized_scalar_broadcast(a, b, op),
        BroadcastPattern::LastDim => vectorized_last_dim_broadcast(a, b, op),
        BroadcastPattern::FirstDim => vectorized_first_dim_broadcast(a, b, op),
        BroadcastPattern::General => vectorized_general_broadcast(a, b, op),
    }
}

/// Optimized same-shape operation (no broadcasting)
fn vectorized_same_shape<T, F>(a: &DenseND<T>, b: &DenseND<T>, op: F) -> Result<DenseND<T>>
where
    T: Clone + Num + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync,
{
    let a_view = a.view();
    let b_view = b.view();

    // Zip walks both arrays in lockstep; `par_map_collect` parallelizes the
    // traversal, and the tight inner loop is amenable to auto-vectorization.
    let result = Zip::from(&a_view)
        .and(&b_view)
        .par_map_collect(|a_val, b_val| op(a_val.clone(), b_val.clone()));

    Ok(DenseND::from_array(result))
}

/// Optimized scalar broadcast
#[allow(dead_code)]
fn vectorized_scalar_broadcast<T, F>(a: &DenseND<T>, b: &DenseND<T>, op: F) -> Result<DenseND<T>>
where
    T: Clone + Num + Float,
    F: Fn(T, T) -> T,
{
    let a_view = a.view();
    let b_view = b.view();

    // Determine which operand is the scalar. A 0-d array and a `[1]` array
    // both contain exactly one element, so `len() == 1` covers both cases.
    // `op_flipped` records that the scalar is the right-hand operand, so the
    // argument order of `op` is preserved for non-commutative operations.
    let (scalar_val, tensor_view, op_flipped) = if a_view.len() == 1 {
        let scalar = a_view.iter().next().cloned().expect("scalar operand is empty");
        (scalar, b_view, false)
    } else {
        let scalar = b_view.iter().next().cloned().expect("scalar operand is empty");
        (scalar, a_view, true)
    };

    // Vectorized element-wise application against the constant scalar
    let result = if op_flipped {
        tensor_view.mapv(|v| op(v, scalar_val))
    } else {
        tensor_view.mapv(|v| op(scalar_val, v))
    };

    Ok(DenseND::from_array(result))
}

/// Optimized last dimension broadcast
#[allow(dead_code)]
fn vectorized_last_dim_broadcast<T, F>(a: &DenseND<T>, b: &DenseND<T>, op: F) -> Result<DenseND<T>>
where
    T: Clone + Num + Float,
    F: Fn(T, T) -> T,
{
    let a_view = a.view();
    let b_view = b.view();

    // `Zip` requires equal shapes, so expand the size-1 axis first. The
    // shapes have equal rank here and differ only in the last dimension.
    let out_shape: Vec<usize> = a
        .shape()
        .iter()
        .zip(b.shape())
        .map(|(&x, &y)| x.max(y))
        .collect();
    let a_bc = a_view.broadcast(out_shape.clone())
        .ok_or_else(|| anyhow::anyhow!("incompatible shapes for last-dim broadcast"))?;
    let b_bc = b_view.broadcast(out_shape)
        .ok_or_else(|| anyhow::anyhow!("incompatible shapes for last-dim broadcast"))?;

    let result = Zip::from(&a_bc)
        .and(&b_bc)
        .map_collect(|&a_val, &b_val| op(a_val, b_val));

    Ok(DenseND::from_array(result))
}

/// Optimized first dimension broadcast
#[allow(dead_code)]
fn vectorized_first_dim_broadcast<T, F>(a: &DenseND<T>, b: &DenseND<T>, op: F) -> Result<DenseND<T>>
where
    T: Clone + Num + Float,
    F: Fn(T, T) -> T,
{
    let a_view = a.view();
    let b_view = b.view();

    // `Zip` requires equal shapes, so expand the size-1 axis first. The
    // shapes have equal rank here and differ only in the first dimension.
    let out_shape: Vec<usize> = a
        .shape()
        .iter()
        .zip(b.shape())
        .map(|(&x, &y)| x.max(y))
        .collect();
    let a_bc = a_view.broadcast(out_shape.clone())
        .ok_or_else(|| anyhow::anyhow!("incompatible shapes for first-dim broadcast"))?;
    let b_bc = b_view.broadcast(out_shape)
        .ok_or_else(|| anyhow::anyhow!("incompatible shapes for first-dim broadcast"))?;

    let result = Zip::from(&a_bc)
        .and(&b_bc)
        .map_collect(|&a_val, &b_val| op(a_val, b_val));

    Ok(DenseND::from_array(result))
}

/// General broadcasting fallback
#[allow(dead_code)]
fn vectorized_general_broadcast<T, F>(a: &DenseND<T>, b: &DenseND<T>, op: F) -> Result<DenseND<T>>
where
    T: Clone + Num + Float,
    F: Fn(T, T) -> T,
{
    let a_view = a.view();
    let b_view = b.view();

    // `Zip` requires equal shapes, so align the shapes at the trailing axes
    // (numpy-style) and broadcast both views to the common output shape.
    // Ranks may differ here; missing leading axes are treated as size 1.
    let (sa, sb) = (a.shape(), b.shape());
    let ndim = sa.len().max(sb.len());
    let out_shape: Vec<usize> = (0..ndim)
        .map(|i| {
            let da = *sa.get((i + sa.len()).wrapping_sub(ndim)).unwrap_or(&1);
            let db = *sb.get((i + sb.len()).wrapping_sub(ndim)).unwrap_or(&1);
            da.max(db)
        })
        .collect();
    // `broadcast` returns `None` when a shape cannot expand to the target
    let a_bc = a_view.broadcast(out_shape.clone())
        .ok_or_else(|| anyhow::anyhow!("incompatible shapes: {sa:?} vs {sb:?}"))?;
    let b_bc = b_view.broadcast(out_shape)
        .ok_or_else(|| anyhow::anyhow!("incompatible shapes: {sa:?} vs {sb:?}"))?;

    let result = Zip::from(&a_bc)
        .and(&b_bc)
        .map_collect(|&a_val, &b_val| op(a_val, b_val));

    Ok(DenseND::from_array(result))
}

/// Specialized addition with broadcasting
#[allow(dead_code)]
pub(crate) fn vectorized_add<T>(a: &DenseND<T>, b: &DenseND<T>) -> Result<DenseND<T>>
where
    T: Clone + Num + Float + Send + Sync + std::ops::Add<Output = T>,
{
    let pattern = detect_broadcast_pattern(a.shape(), b.shape());

    match pattern {
        BroadcastPattern::SameShape => {
            // Direct element-wise addition
            let result = &a.view() + &b.view();
            Ok(DenseND::from_array(result))
        }
        BroadcastPattern::Scalar => {
            // Scalar addition
            vectorized_scalar_broadcast(a, b, |x, y| x + y)
        }
        _ => {
            // General case: rely on ndarray's co-broadcasting `+` operator
            let result = &a.view() + &b.view();
            Ok(DenseND::from_array(result))
        }
    }
}

/// Specialized multiplication with broadcasting
#[allow(dead_code)]
pub(crate) fn vectorized_mul<T>(a: &DenseND<T>, b: &DenseND<T>) -> Result<DenseND<T>>
where
    T: Clone + Num + Float + Send + Sync + std::ops::Mul<Output = T>,
{
    let pattern = detect_broadcast_pattern(a.shape(), b.shape());

    match pattern {
        BroadcastPattern::SameShape => {
            // Direct element-wise multiplication
            let result = &a.view() * &b.view();
            Ok(DenseND::from_array(result))
        }
        BroadcastPattern::Scalar => {
            // Scalar multiplication
            vectorized_scalar_broadcast(a, b, |x, y| x * y)
        }
        _ => {
            // General case: rely on ndarray's co-broadcasting `*` operator
            let result = &a.view() * &b.view();
            Ok(DenseND::from_array(result))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_broadcast_pattern_same_shape() {
        assert_eq!(
            detect_broadcast_pattern(&[3, 4], &[3, 4]),
            BroadcastPattern::SameShape
        );
    }

    #[test]
    fn test_detect_broadcast_pattern_scalar() {
        assert_eq!(
            detect_broadcast_pattern(&[1], &[3, 4]),
            BroadcastPattern::Scalar
        );
        assert_eq!(
            detect_broadcast_pattern(&[3, 4], &[1]),
            BroadcastPattern::Scalar
        );
    }

    #[test]
    fn test_detect_broadcast_pattern_last_dim() {
        assert_eq!(
            detect_broadcast_pattern(&[3, 4, 1], &[3, 4, 5]),
            BroadcastPattern::LastDim
        );
    }

    #[test]
    fn test_detect_broadcast_pattern_first_dim() {
        assert_eq!(
            detect_broadcast_pattern(&[1, 4, 5], &[3, 4, 5]),
            BroadcastPattern::FirstDim
        );
    }

    #[test]
    fn test_vectorized_add_same_shape() {
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let b = DenseND::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[4]).unwrap();

        let result = vectorized_add(&a, &b).unwrap();
        let result_view = result.view();

        assert_eq!(result_view[[0]], 6.0);
        assert_eq!(result_view[[1]], 8.0);
        assert_eq!(result_view[[2]], 10.0);
        assert_eq!(result_view[[3]], 12.0);
    }

    #[test]
    fn test_vectorized_add_scalar() {
        let a = DenseND::from_vec(vec![5.0], &[1]).unwrap();
        let b = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let result = vectorized_add(&a, &b).unwrap();
        let result_view = result.view();

        assert_eq!(result_view[[0]], 6.0);
        assert_eq!(result_view[[1]], 7.0);
        assert_eq!(result_view[[2]], 8.0);
        assert_eq!(result_view[[3]], 9.0);
    }

    #[test]
    fn test_vectorized_mul_same_shape() {
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let b = DenseND::from_vec(vec![2.0, 3.0, 4.0, 5.0], &[4]).unwrap();

        let result = vectorized_mul(&a, &b).unwrap();
        let result_view = result.view();

        assert_eq!(result_view[[0]], 2.0);
        assert_eq!(result_view[[1]], 6.0);
        assert_eq!(result_view[[2]], 12.0);
        assert_eq!(result_view[[3]], 20.0);
    }

    #[test]
    fn test_vectorized_mul_scalar() {
        let a = DenseND::from_vec(vec![2.0], &[1]).unwrap();
        let b = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let result = vectorized_mul(&a, &b).unwrap();
        let result_view = result.view();

        assert_eq!(result_view[[0]], 2.0);
        assert_eq!(result_view[[1]], 4.0);
        assert_eq!(result_view[[2]], 6.0);
        assert_eq!(result_view[[3]], 8.0);
    }

    #[test]
    fn test_vectorized_binary_op_custom() {
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = DenseND::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        let result = vectorized_binary_op(&a, &b, |x, y| x * x + y).unwrap();
        let result_view = result.view();

        // 1*1+4=5, 2*2+5=9, 3*3+6=15
        assert_eq!(result_view[[0]], 5.0);
        assert_eq!(result_view[[1]], 9.0);
        assert_eq!(result_view[[2]], 15.0);
    }
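
    // Additional coverage (sketch): the tests below exercise the LastDim,
    // FirstDim, and General dispatch paths of `vectorized_binary_op`. They
    // use the same `DenseND` constructors as the tests above and assume
    // `from_vec` fills the shape in row-major order.

    #[test]
    fn test_vectorized_binary_op_last_dim() {
        // (3,1) + (3,2): the size-1 last axis of `a` is expanded
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0], &[3, 1]).unwrap();
        let b = DenseND::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[3, 2]).unwrap();

        let result = vectorized_binary_op(&a, &b, |x, y| x + y).unwrap();
        let result_view = result.view();

        assert_eq!(result_view[[0, 0]], 11.0);
        assert_eq!(result_view[[0, 1]], 21.0);
        assert_eq!(result_view[[2, 1]], 63.0);
    }

    #[test]
    fn test_vectorized_binary_op_first_dim() {
        // (1,3) + (2,3): the size-1 first axis of `a` is expanded
        let a = DenseND::from_vec(vec![10.0, 20.0, 30.0], &[1, 3]).unwrap();
        let b = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let result = vectorized_binary_op(&a, &b, |x, y| x + y).unwrap();
        let result_view = result.view();

        assert_eq!(result_view[[0, 0]], 11.0);
        assert_eq!(result_view[[1, 2]], 36.0);
    }

    #[test]
    fn test_vectorized_binary_op_general_rank_mismatch() {
        // (4,) + (2,4): ranks differ, so this routes through the general path
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let b = DenseND::from_vec(vec![10.0; 8], &[2, 4]).unwrap();

        let result = vectorized_binary_op(&a, &b, |x, y| x + y).unwrap();
        let result_view = result.view();

        assert_eq!(result_view[[0, 0]], 11.0);
        assert_eq!(result_view[[1, 3]], 14.0);
    }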
}