tract_core/ops/array/
broadcast.rs1use crate::internal::*;
2
3#[derive(Debug, Clone, new, Hash)]
4pub struct MultiBroadcastTo {
5 pub shape: ShapeFact,
6}
7
8impl Op for MultiBroadcastTo {
9 fn name(&self) -> StaticName {
10 "MultiBroadcastTo".into()
11 }
12
13 op_as_typed_op!();
14}
15
16impl EvalOp for MultiBroadcastTo {
17 fn is_stateless(&self) -> bool {
18 true
19 }
20
21 fn eval_with_session(
22 &self,
23 _node_id: usize,
24 session: &SessionState,
25 inputs: TVec<TValue>,
26 ) -> TractResult<TVec<TValue>> {
27 let shape = self.shape.eval_to_usize(&session.resolved_symbols)?;
28 Ok(tvec!(inputs[0].broadcast_to_shape(&shape)?.into_tvalue()))
29 }
30}
31
32impl TypedOp for MultiBroadcastTo {
33 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
34 ensure!(inputs.len() == 1);
35 let mut fact = inputs[0].datum_type.fact(self.shape.clone());
36 fact.uniform.clone_from(&inputs[0].uniform);
37 Ok(tvec!(fact))
38 }
39
40 fn concretize_dims(
41 &self,
42 _source: &TypedModel,
43 node: &TypedNode,
44 target: &mut TypedModel,
45 mapping: &HashMap<OutletId, OutletId>,
46 values: &SymbolValues,
47 ) -> TractResult<TVec<OutletId>> {
48 let input = mapping[&node.inputs[0]];
49 let op =
50 Self { shape: self.shape.iter().map(|d| d.eval(values)).collect::<TVec<_>>().into() };
51 target.wire_node(&node.name, op, &[input])
52 }
53
54 fn declutter(
55 &self,
56 model: &TypedModel,
57 node: &TypedNode,
58 ) -> TractResult<Option<TypedModelPatch>> {
59 let input_fact = model.outlet_fact(node.inputs[0])?;
60 if input_fact.shape == self.shape {
61 TypedModelPatch::shunt_one_op(model, node)
62 } else {
63 Ok(None)
64 }
65 }
66
67 as_op!();
68}