// tract_core/broadcast.rs

//! N-way tensor shape broadcasting.
2use tract_data::internal::*;
3
4/// Computes a shape, if any, to which all shapes can be broadcasted.
5pub fn multi_broadcast<D>(shapes: &[impl AsRef<[D]>]) -> TractResult<TVec<D>>
6where
7    D: DimLike,
8{
9    let one = D::one();
10    let Some(len) = shapes.iter().map(|shape| shape.as_ref().len()).max() else {
11        return Ok(tvec!());
12    };
13    let mut shape: TVec<D> = tvec!();
14    for i in 0..len {
15        let mut wanted_size = D::one();
16        for shape in shapes {
17            let len = shape.as_ref().len();
18            let dim = if i < len { &shape.as_ref()[len - i - 1] } else { &one };
19            wanted_size = wanted_size.broadcast(dim.clone())?;
20        }
21        shape.push(wanted_size)
22    }
23    shape.reverse();
24    Ok(shape)
25}
26
#[cfg(test)]
mod tests {
    use super::*;

    // `onnx_1` .. `onnx_5` mirror the multidirectional-broadcasting examples
    // from the ONNX `Broadcasting.md` specification.

    #[test]
    fn onnx_1() {
        assert_eq!(multi_broadcast(&tvec![tvec![2, 3, 4, 5], tvec![]]).unwrap(), tvec![2, 3, 4, 5])
    }

    #[test]
    fn onnx_2() {
        assert_eq!(multi_broadcast(&tvec![tvec![2, 3, 4, 5], tvec![5]]).unwrap(), tvec![2, 3, 4, 5])
    }

    #[test]
    fn onnx_3() {
        assert_eq!(
            multi_broadcast(&tvec![tvec![4, 5], tvec![2, 3, 4, 5]]).unwrap(),
            tvec![2, 3, 4, 5]
        )
    }

    #[test]
    fn onnx_4() {
        assert_eq!(
            multi_broadcast(&tvec![tvec![1, 4, 5], tvec![2, 3, 4, 1]]).unwrap(),
            tvec![2, 3, 4, 5]
        )
    }

    #[test]
    fn onnx_5() {
        assert_eq!(
            multi_broadcast(&tvec![tvec![3, 4, 5], tvec![2, 1, 1, 1]]).unwrap(),
            tvec![2, 3, 4, 5]
        )
    }

    /// With no input shapes at all, the common shape is the rank-0 shape.
    #[test]
    fn no_shapes() {
        let shapes: TVec<TVec<usize>> = tvec![];
        let expected: TVec<usize> = tvec![];
        assert_eq!(multi_broadcast(&shapes).unwrap(), expected);
    }

    /// Broadcasting scalars together yields a scalar.
    #[test]
    fn scalars_only() {
        let shapes: TVec<TVec<usize>> = tvec![tvec![], tvec![]];
        let expected: TVec<usize> = tvec![];
        assert_eq!(multi_broadcast(&shapes).unwrap(), expected);
    }

    /// A single shape broadcasts to itself, including its unit dimensions.
    #[test]
    fn single_shape_is_identity() {
        assert_eq!(multi_broadcast(&tvec![tvec![2usize, 1, 3]]).unwrap(), tvec![2usize, 1, 3]);
    }

    /// More than two shapes, with mixed ranks, are combined axis by axis.
    #[test]
    fn three_way() {
        assert_eq!(
            multi_broadcast(&tvec![tvec![1usize, 3], tvec![2, 1], tvec![3]]).unwrap(),
            tvec![2usize, 3]
        );
    }

    /// Two distinct non-unit dimensions on the same axis have no common shape.
    #[test]
    fn incompatible_dims_fail() {
        assert!(multi_broadcast(&tvec![tvec![2usize], tvec![3]]).is_err());
    }
}