1use tract_data::internal::*;
3
4pub fn multi_broadcast<D>(shapes: &[impl AsRef<[D]>]) -> TractResult<TVec<D>>
6where
7 D: DimLike,
8{
9 let one = D::one();
10 let Some(len) = shapes.iter().map(|shape| shape.as_ref().len()).max() else {
11 return Ok(tvec!());
12 };
13 let mut shape: TVec<D> = tvec!();
14 for i in 0..len {
15 let mut wanted_size = D::one();
16 for shape in shapes {
17 let len = shape.as_ref().len();
18 let dim = if i < len { &shape.as_ref()[len - i - 1] } else { &one };
19 wanted_size = wanted_size.broadcast(dim.clone())?;
20 }
21 shape.push(wanted_size)
22 }
23 shape.reverse();
24 Ok(shape)
25}
26
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn onnx_1() {
        assert_eq!(multi_broadcast(&tvec![tvec![2, 3, 4, 5], tvec![]]).unwrap(), tvec![2, 3, 4, 5])
    }

    #[test]
    fn onnx_2() {
        assert_eq!(multi_broadcast(&tvec![tvec![2, 3, 4, 5], tvec![5]]).unwrap(), tvec![2, 3, 4, 5])
    }

    #[test]
    fn onnx_3() {
        assert_eq!(
            multi_broadcast(&tvec![tvec![4, 5], tvec![2, 3, 4, 5]]).unwrap(),
            tvec![2, 3, 4, 5]
        )
    }

    #[test]
    fn onnx_4() {
        assert_eq!(
            multi_broadcast(&tvec![tvec![1, 4, 5], tvec![2, 3, 4, 1]]).unwrap(),
            tvec![2, 3, 4, 5]
        )
    }

    #[test]
    fn onnx_5() {
        assert_eq!(
            multi_broadcast(&tvec![tvec![3, 4, 5], tvec![2, 1, 1, 1]]).unwrap(),
            tvec![2, 3, 4, 5]
        )
    }

    /// With no input shapes at all, the early-return path yields an empty shape.
    #[test]
    fn no_shapes() {
        let shapes: &[TVec<usize>] = &[];
        assert!(multi_broadcast(shapes).unwrap().is_empty());
    }

    /// A single shape broadcasts to itself, including its `1` dimensions.
    #[test]
    fn single_shape_is_identity() {
        assert_eq!(multi_broadcast(&tvec![tvec![3, 1, 2]]).unwrap(), tvec![3, 1, 2])
    }

    /// More than two shapes, each contributing a different non-unit axis.
    #[test]
    fn three_shapes() {
        assert_eq!(
            multi_broadcast(&tvec![tvec![1, 3, 1], tvec![2, 1, 1], tvec![5]]).unwrap(),
            tvec![2, 3, 5]
        )
    }

    /// Trailing dimensions 3 and 4 cannot be broadcast together, so an error is expected.
    #[test]
    fn incompatible_dims_fail() {
        assert!(multi_broadcast(&tvec![tvec![2, 3], tvec![4]]).is_err())
    }
}