tract_core/ops/array/
broadcast.rs1use crate::internal::*;
2
3#[derive(Debug, Clone, new, Hash)]
4pub struct MultiBroadcastTo {
5 pub shape: ShapeFact,
6}
7
8impl Op for MultiBroadcastTo {
9 fn name(&self) -> Cow<str> {
10 "MultiBroadcastTo".into()
11 }
12
13 op_as_typed_op!();
14}
15
16impl EvalOp for MultiBroadcastTo {
17 fn is_stateless(&self) -> bool {
18 true
19 }
20
21 fn eval_with_session(
22 &self,
23 session: &SessionState,
24 inputs: TVec<TValue>,
25 ) -> TractResult<TVec<TValue>> {
26 let shape = self.shape.eval_to_usize(&session.resolved_symbols)?;
27 Ok(tvec!(inputs[0].broadcast_to_shape(&shape)?.into_tvalue()))
28 }
29}
30
31impl TypedOp for MultiBroadcastTo {
32 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
33 ensure!(inputs.len() == 1);
34 let mut fact = inputs[0].datum_type.fact(self.shape.clone());
35 fact.uniform.clone_from(&inputs[0].uniform);
36 Ok(tvec!(fact))
37 }
38
39 fn concretize_dims(
40 &self,
41 _source: &TypedModel,
42 node: &TypedNode,
43 target: &mut TypedModel,
44 mapping: &HashMap<OutletId, OutletId>,
45 values: &SymbolValues,
46 ) -> TractResult<TVec<OutletId>> {
47 let input = mapping[&node.inputs[0]];
48 let op =
49 Self { shape: self.shape.iter().map(|d| d.eval(values)).collect::<TVec<_>>().into() };
50 target.wire_node(&node.name, op, &[input])
51 }
52
53 fn declutter(
54 &self,
55 model: &TypedModel,
56 node: &TypedNode,
57 ) -> TractResult<Option<TypedModelPatch>> {
58 let input_fact = model.outlet_fact(node.inputs[0])?;
59 if input_fact.shape == self.shape {
60 TypedModelPatch::shunt_one_op(model, node)
61 } else {
62 Ok(None)
63 }
64 }
65
66 as_op!();
67}