tract_hir/ops/
binary.rs

1use crate::infer::*;
2use crate::internal::*;
3
4use tract_core::broadcast::multi_broadcast;
5use tract_core::ops as mir;
6pub use tract_core::ops::cast::wire_cast;
7pub use tract_core::ops::change_axes::wire_rank_broadcast;
8use tract_core::ops::binary::BinMiniOp;
9
10#[derive(Debug, Clone)]
11pub struct InferenceBinOp(pub Box<dyn BinMiniOp>);
12
13impl Expansion for InferenceBinOp {
14    fn name(&self) -> StaticName {
15        self.0.name().into()
16    }
17
18    fn validation(&self) -> Validation {
19        self.0.validation()
20    }
21
22    fn rules<'r, 'p: 'r, 's: 'r>(
23        &'s self,
24        s: &mut Solver<'r>,
25        inputs: &'p [TensorProxy],
26        outputs: &'p [TensorProxy],
27    ) -> InferenceResult {
28        rules(s, inputs, outputs, move |typa, typb| self.0.result_datum_type(typa, typb))
29    }
30
31    fn wire(
32        &self,
33        prefix: &str,
34        target: &mut TypedModel,
35        inputs: &[OutletId],
36    ) -> TractResult<TVec<OutletId>> {
37        let operating_datum_type = self.0.operating_datum_type(
38            target.outlet_fact(inputs[0])?.datum_type,
39            target.outlet_fact(inputs[1])?.datum_type,
40        )?;
41        let wires = wire_rank_broadcast(prefix, target, inputs)?;
42        let wires = wire_cast(prefix, target, &wires, operating_datum_type)?;
43        target.wire_node(prefix, mir::binary::TypedBinOp(self.0.clone(), None), &wires)
44    }
45}
46
47pub fn rules<'r, 'p: 'r, 's: 'r, DT: Fn(DatumType, DatumType) -> TractResult<DatumType> + 'p>(
48    s: &mut Solver<'r>,
49    inputs: &'p [TensorProxy],
50    outputs: &'p [TensorProxy],
51    dt: DT,
52) -> InferenceResult {
53    check_input_arity(inputs, 2)?;
54    check_output_arity(outputs, 1)?;
55
56    /*
57    s.with(&inputs[0].shape, move |s, a_shape| {
58        s.with(&inputs[1].shape, move |s, b_shape| {
59            /*
60            if let Some(c_shape) =
61                crate::infer::helpers::infer_shape_broadcasting(&[&a_shape, &b_shape])
62                    .with_context(|| {
63                        format!(
64                            "Matching {a_shape:?} and {b_shape:?} with numpy/onnx broadcast rules"
65                        )
66                    })?
67            {
68                s.equals(&outputs[0].shape, c_shape)?;
69            }
70            Ok(())
71        })
72        */
73    })?;
74    */
75    s.given_2(&inputs[0].shape, &inputs[1].shape, move |s, a, b| {
76        s.equals(&outputs[0].shape, multi_broadcast(&[a, b])?)
77    })?;
78    s.given_2(&inputs[0].datum_type, &inputs[1].datum_type, move |s, typa, typb| {
79        s.equals(&outputs[0].datum_type, dt(typa, typb)?)
80    })?;
81    Ok(())
82}
83
84pub trait BinIntoHir {
85    fn into_hir(self) -> Box<dyn InferenceOp>;
86}
87
88impl<B: BinMiniOp> BinIntoHir for B {
89    fn into_hir(self) -> Box<dyn InferenceOp> {
90        expand(InferenceBinOp(Box::new(self) as _))
91    }
92}
93
94#[derive(Debug, Clone)]
95pub struct Nary(pub Box<dyn mir::binary::BinMiniOp>, pub bool);
96
97impl Nary {
98    fn normalize_t<T>(t: &mut Tensor, n: usize) -> TractResult<()>
99    where
100        T: Datum + std::ops::DivAssign<T> + Copy,
101        usize: tract_num_traits::AsPrimitive<T>,
102    {
103        use tract_num_traits::AsPrimitive;
104        let mut t = t.to_array_view_mut::<T>()?;
105        let n: T = n.as_();
106        t /= &tract_ndarray::arr0(n);
107        Ok(())
108    }
109}
110
111impl Op for Nary {
112    fn name(&self) -> StaticName {
113        format!("{}Nary", self.0.name()).into()
114    }
115
116    fn validation(&self) -> Validation {
117        self.0.validation()
118    }
119
120    not_a_typed_op!();
121}
122
123impl EvalOp for Nary {
124    fn is_stateless(&self) -> bool {
125        true
126    }
127
128    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
129        let mut t = inputs[0].clone().into_tensor();
130        for i in inputs[1..].iter() {
131            let mut i = i.clone().into_tensor();
132            let operating_datum_type =
133                self.0.operating_datum_type(t.datum_type(), i.datum_type())?;
134            if i.datum_type() != operating_datum_type {
135                i = i.cast_to_dt(operating_datum_type)?.into_owned();
136            }
137            if t.datum_type() != operating_datum_type {
138                t = t.cast_to_dt(operating_datum_type)?.into_owned();
139            }
140            t = self.0.eval(t.into_tvalue(), i.into_tvalue(), operating_datum_type)?;
141        }
142        if self.1 {
143            dispatch_numbers!(Self::normalize_t(t.datum_type())(&mut t, inputs.len()))?;
144        }
145        Ok(tvec!(t.into_tvalue()))
146    }
147}
148
149impl InferenceRulesOp for Nary {
150    fn rules<'r, 'p: 'r, 's: 'r>(
151        &'s self,
152        s: &mut Solver<'r>,
153        inputs: &'p [TensorProxy],
154        outputs: &'p [TensorProxy],
155    ) -> InferenceResult {
156        check_output_arity(outputs, 1)?;
157        let n = inputs.len();
158        s.given_all(
159            (0..n).map(|i| (&inputs[i].datum_type).bex()),
160            move |s, types: Vec<DatumType>| {
161                let dt = DatumType::super_type_for(&types)
162                    .with_context(|| format!("No super type for {types:?}"))?;
163                let dt = self.0.operating_datum_type(dt, dt)?;
164                let result = self.0.result_datum_type(dt, dt)?;
165                s.equals(&outputs[0].datum_type, result)
166            },
167        )?;
168        s.given_all(inputs.iter().map(|i| &i.shape), move |s, shapes: Vec<TVec<TDim>>| {
169            let out = tract_core::broadcast::multi_broadcast(&shapes)?;
170            s.equals(&outputs[0].shape, ShapeFactoid::from(out))
171        })
172    }
173
174    fn to_typed(
175        &self,
176        _source: &InferenceModel,
177        node: &InferenceNode,
178        target: &mut TypedModel,
179        mapping: &HashMap<OutletId, OutletId>,
180    ) -> TractResult<TVec<OutletId>> {
181        let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<Vec<_>>();
182        let types = inputs
183            .iter()
184            .map(|i| Ok(target.outlet_fact(*i)?.datum_type))
185            .collect::<TractResult<Vec<_>>>()?;
186        let dt = DatumType::super_type_for(&types)
187            .with_context(|| format!("No super type for {types:?}"))?;
188        let operating = self.0.operating_datum_type(dt, dt)?;
189        let inputs = wire_cast(&node.name, target, &inputs, operating)?;
190        let mut wire = inputs[0];
191        for (ix, i) in inputs[1..].iter().enumerate() {
192            let wires = wire_rank_broadcast(format!("{}.{}", node.name, ix), target, &[wire, *i])?;
193            wire = target.wire_node(
194                format!("{}.{}", node.name, ix),
195                mir::binary::TypedBinOp(self.0.clone(), None),
196                &wires,
197            )?[0];
198        }
199        if self.1 {
200            let n = tensor0(inputs.len() as i32)
201                .cast_to_dt(node.outputs[0].fact.datum_type.concretize().unwrap())?
202                .into_owned()
203                .broadcast_into_rank(target.outlet_fact(inputs[0])?.rank())?;
204            let n = target.add_const(format!("{}.n", node.name), n.into_arc_tensor())?;
205            wire = target.wire_node(
206                format!("{}.norm", node.name),
207                crate::ops::math::div(),
208                &[wire, n],
209            )?[0];
210        }
211        Ok(tvec!(wire))
212    }
213
214    as_op!();
215}