1use crate::infer::*;
2use crate::internal::*;
3
4use tract_core::broadcast::multi_broadcast;
5use tract_core::ops as mir;
6pub use tract_core::ops::cast::wire_cast;
7pub use tract_core::ops::change_axes::wire_rank_broadcast;
8use tract_core::ops::binary::BinMiniOp;
9
/// Inference-level wrapper around a boxed binary mini-op.
///
/// The `Expansion` implementation in this file delegates naming, validation and
/// datum type computation to the wrapped mini-op, and wires it into a typed
/// model as a `TypedBinOp`.
#[derive(Debug, Clone)]
pub struct InferenceBinOp(pub Box<dyn BinMiniOp>);
12
13impl Expansion for InferenceBinOp {
14 fn name(&self) -> StaticName {
15 self.0.name().into()
16 }
17
18 fn validation(&self) -> Validation {
19 self.0.validation()
20 }
21
22 fn rules<'r, 'p: 'r, 's: 'r>(
23 &'s self,
24 s: &mut Solver<'r>,
25 inputs: &'p [TensorProxy],
26 outputs: &'p [TensorProxy],
27 ) -> InferenceResult {
28 rules(s, inputs, outputs, move |typa, typb| self.0.result_datum_type(typa, typb))
29 }
30
31 fn wire(
32 &self,
33 prefix: &str,
34 target: &mut TypedModel,
35 inputs: &[OutletId],
36 ) -> TractResult<TVec<OutletId>> {
37 let operating_datum_type = self.0.operating_datum_type(
38 target.outlet_fact(inputs[0])?.datum_type,
39 target.outlet_fact(inputs[1])?.datum_type,
40 )?;
41 let wires = wire_rank_broadcast(prefix, target, inputs)?;
42 let wires = wire_cast(prefix, target, &wires, operating_datum_type)?;
43 target.wire_node(prefix, mir::binary::TypedBinOp(self.0.clone(), None), &wires)
44 }
45}
46
47pub fn rules<'r, 'p: 'r, 's: 'r, DT: Fn(DatumType, DatumType) -> TractResult<DatumType> + 'p>(
48 s: &mut Solver<'r>,
49 inputs: &'p [TensorProxy],
50 outputs: &'p [TensorProxy],
51 dt: DT,
52) -> InferenceResult {
53 check_input_arity(inputs, 2)?;
54 check_output_arity(outputs, 1)?;
55
56 s.given_2(&inputs[0].shape, &inputs[1].shape, move |s, a, b| {
76 s.equals(&outputs[0].shape, multi_broadcast(&[a, b])?)
77 })?;
78 s.given_2(&inputs[0].datum_type, &inputs[1].datum_type, move |s, typa, typb| {
79 s.equals(&outputs[0].datum_type, dt(typa, typb)?)
80 })?;
81 Ok(())
82}
83
/// Conversion of a value into a boxed inference operator.
///
/// A blanket implementation in this file covers every `BinMiniOp`.
pub trait BinIntoHir {
    /// Consumes `self` and returns the corresponding boxed `InferenceOp`.
    fn into_hir(self) -> Box<dyn InferenceOp>;
}
87
88impl<B: BinMiniOp> BinIntoHir for B {
89 fn into_hir(self) -> Box<dyn InferenceOp> {
90 expand(InferenceBinOp(Box::new(self) as _))
91 }
92}
93
/// Variadic operator built by chaining a binary mini-op over all its inputs.
///
/// - Field `0` is the binary mini-op applied pairwise, left to right.
/// - Field `1`, when `true`, makes the accumulated result be divided by the
///   number of inputs (see `eval` and `to_typed`).
#[derive(Debug, Clone)]
pub struct Nary(pub Box<dyn mir::binary::BinMiniOp>, pub bool);
96
97impl Nary {
98 fn normalize_t<T>(t: &mut Tensor, n: usize) -> TractResult<()>
99 where
100 T: Datum + std::ops::DivAssign<T> + Copy,
101 usize: tract_num_traits::AsPrimitive<T>,
102 {
103 use tract_num_traits::AsPrimitive;
104 let mut t = t.to_array_view_mut::<T>()?;
105 let n: T = n.as_();
106 t /= &tract_ndarray::arr0(n);
107 Ok(())
108 }
109}
110
111impl Op for Nary {
112 fn name(&self) -> StaticName {
113 format!("{}Nary", self.0.name()).into()
114 }
115
116 fn validation(&self) -> Validation {
117 self.0.validation()
118 }
119
120 not_a_typed_op!();
121}
122
123impl EvalOp for Nary {
124 fn is_stateless(&self) -> bool {
125 true
126 }
127
128 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
129 let mut t = inputs[0].clone().into_tensor();
130 for i in inputs[1..].iter() {
131 let mut i = i.clone().into_tensor();
132 let operating_datum_type =
133 self.0.operating_datum_type(t.datum_type(), i.datum_type())?;
134 if i.datum_type() != operating_datum_type {
135 i = i.cast_to_dt(operating_datum_type)?.into_owned();
136 }
137 if t.datum_type() != operating_datum_type {
138 t = t.cast_to_dt(operating_datum_type)?.into_owned();
139 }
140 t = self.0.eval(t.into_tvalue(), i.into_tvalue(), operating_datum_type)?;
141 }
142 if self.1 {
143 dispatch_numbers!(Self::normalize_t(t.datum_type())(&mut t, inputs.len()))?;
144 }
145 Ok(tvec!(t.into_tvalue()))
146 }
147}
148
149impl InferenceRulesOp for Nary {
150 fn rules<'r, 'p: 'r, 's: 'r>(
151 &'s self,
152 s: &mut Solver<'r>,
153 inputs: &'p [TensorProxy],
154 outputs: &'p [TensorProxy],
155 ) -> InferenceResult {
156 check_output_arity(outputs, 1)?;
157 let n = inputs.len();
158 s.given_all(
159 (0..n).map(|i| (&inputs[i].datum_type).bex()),
160 move |s, types: Vec<DatumType>| {
161 let dt = DatumType::super_type_for(&types)
162 .with_context(|| format!("No super type for {types:?}"))?;
163 let dt = self.0.operating_datum_type(dt, dt)?;
164 let result = self.0.result_datum_type(dt, dt)?;
165 s.equals(&outputs[0].datum_type, result)
166 },
167 )?;
168 s.given_all(inputs.iter().map(|i| &i.shape), move |s, shapes: Vec<TVec<TDim>>| {
169 let out = tract_core::broadcast::multi_broadcast(&shapes)?;
170 s.equals(&outputs[0].shape, ShapeFactoid::from(out))
171 })
172 }
173
174 fn to_typed(
175 &self,
176 _source: &InferenceModel,
177 node: &InferenceNode,
178 target: &mut TypedModel,
179 mapping: &HashMap<OutletId, OutletId>,
180 ) -> TractResult<TVec<OutletId>> {
181 let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<Vec<_>>();
182 let types = inputs
183 .iter()
184 .map(|i| Ok(target.outlet_fact(*i)?.datum_type))
185 .collect::<TractResult<Vec<_>>>()?;
186 let dt = DatumType::super_type_for(&types)
187 .with_context(|| format!("No super type for {types:?}"))?;
188 let operating = self.0.operating_datum_type(dt, dt)?;
189 let inputs = wire_cast(&node.name, target, &inputs, operating)?;
190 let mut wire = inputs[0];
191 for (ix, i) in inputs[1..].iter().enumerate() {
192 let wires = wire_rank_broadcast(format!("{}.{}", node.name, ix), target, &[wire, *i])?;
193 wire = target.wire_node(
194 format!("{}.{}", node.name, ix),
195 mir::binary::TypedBinOp(self.0.clone(), None),
196 &wires,
197 )?[0];
198 }
199 if self.1 {
200 let n = tensor0(inputs.len() as i32)
201 .cast_to_dt(node.outputs[0].fact.datum_type.concretize().unwrap())?
202 .into_owned()
203 .broadcast_into_rank(target.outlet_fact(inputs[0])?.rank())?;
204 let n = target.add_const(format!("{}.n", node.name), n.into_arc_tensor())?;
205 wire = target.wire_node(
206 format!("{}.norm", node.name),
207 crate::ops::math::div(),
208 &[wire, n],
209 )?[0];
210 }
211 Ok(tvec!(wire))
212 }
213
214 as_op!();
215}