Skip to main content

tract_core/ops/math/
mod.rs

1#![allow(clippy::clone_on_copy)]
2#![allow(clippy::unnecessary_cast)]
3#![allow(clippy::blocks_in_conditions)]
4
5use super::array::MultiBroadcastTo;
6use super::binary::TypedBinOp;
7use crate::internal::*;
8use crate::ops::quant::scale_by;
9use num_traits::bounds::Bounded;
10use num_traits::int::PrimInt;
11use num_traits::{Float, Zero};
12use tract_data::internal::ClampCast;
13use tract_data::itertools::Itertools;
14pub use tract_data::prelude::round_ties_to_even;
15use tract_linalg::{ScaleShiftAndRound, Scaler};
16use tract_num_traits::AsPrimitive;
17
18#[cfg(feature = "complex")]
19mod complex;
20#[cfg(feature = "complex")]
21pub use complex::{ComplexToInnerDim, InnerDimToComplex};
22
23bin_to_super_type!(add, Add,
24                   linalg: Add,
25                   neutral_element: 0,
26                   validation: Validation::Rounding,
27                   q: [i8, u8, i32, i32] => add_quant;
28                   q_op_on_f32: |a: f32, b: f32| -> f32 {a+b},
29                   [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim, String] => |c, a, b| *c = a.clone() + b);
30
31fn add_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
32where
33    T: PrimInt + Bounded + AsPrimitive<i64> + Datum,
34    i64: AsPrimitive<T>,
35{
36    *c = (a.as_() + b.as_() - zp as i64).clamp_cast()
37}
38
39bin_to_super_type!(sub, Sub,
40                   linalg:Sub,
41                   is_commutative: false,
42                   neutral_element: 0,
43                   q: [i8, u8, i32, i32] => sub_quant;
44                   q_op_on_f32: |a: f32, b: f32| -> f32 {a-b},
45                   [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() - b);
46
47bin_to_super_type!(subf, SubF,
48                   linalg:SubF,
49                   is_commutative: false,
50                   neutral_element: 0,
51                   q: [i8, u8, i32, i32] => subf_quant;
52                   q_op_on_f32: |a: f32, b: f32| -> f32 {b - a},
53                   [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = b.clone() - a);
54
55fn sub_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
56where
57    T: PrimInt + Bounded + AsPrimitive<i16> + Datum,
58    i16: AsPrimitive<T>,
59{
60    *c = (a.as_() - b.as_() + zp as i16).clamp_cast()
61}
62
63fn subf_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
64where
65    T: PrimInt + Bounded + AsPrimitive<i16> + Datum,
66    i16: AsPrimitive<T>,
67{
68    *c = (b.as_() - a.as_() + zp as i16).clamp_cast()
69}
70
71bin_to_super_type!(mul, Mul,
72                   cost: |dt| tvec!((Cost::FMA(dt), 1)),
73                   declutter: declutter_mul,
74                   eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
75                    // we apply only if type is QU8 zp_scale datum type
76                    if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
77                            DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
78                            DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
79                        (a.datum_type(), b.datum_type(), c_dt)
80                    {
81                           let multiplier = a_scale  * b_scale * (1.0/ c_scale);
82                           let a = a.to_dense_array_view::<u8>()?;
83                           let b = b.to_dense_array_view::<u8>()?;
84                           let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
85                           let mut c = Tensor::zero_dt(c_dt, &c_shape)?;
86                           let mut c_dense = c.try_as_dense_mut()?;
87                           let view = c_dense.to_array_view_mut::<u8>()?;
88                           crate::ndarray::Zip::from(view)
89                               .and_broadcast(a)
90                               .and_broadcast(b)
91                               .for_each(|c,a,b| *c = (scale_by((*a as i32 - a_zp as i32) * (*b as i32 - b_zp as i32), multiplier) + c_zp as i32).clamp_cast());
92                           Ok(c)
93                        } else {
94                            Mul.generic_eval(a, b, c_dt)
95                        }
96                    },
97                   linalg: Mul,
98                   neutral_element: 1,
99                   out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
100                       if c.datum_type() == TDim::datum_type() &&
101                           a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
102                               let a = a.to_dense_array_view::<TDim>()?;
103                               let b = b.cast_to::<i32>()?;
104                               let b = b.to_dense_array_view::<i32>()?;
105                               let mut c_dense = c.try_as_dense_mut()?;
106                               let c = c_dense.to_array_view_mut::<TDim>()?;
107                               crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() * *b);
108                               Ok(true)
109                           }
110                       else {
111                           match c.datum_type() {
112                               DatumType::QI8(params) => {
113                                   let (zp, scale) = params.zp_scale();
114                                   let a = a.to_dense_array_view::<i8>()?;
115                                   let b = b.to_dense_array_view::<i8>()?;
116                                   let mut c_dense = c.try_as_dense_mut()?;
117                                   let c = c_dense.to_array_view_mut::<i8>()?;
118                                   crate::ndarray::Zip::from(c)
119                                       .and_broadcast(a)
120                                       .and_broadcast(b)
121                                       .for_each(|c,a,b| *c = (scale_by((*a as i16 - zp as i16) * (*b as i16 - zp as i16), scale) + zp as i16).clamp_cast());
122                                   Ok(true)
123                               }
124                               DatumType::QU8(params) => {
125                                   let (zp, scale) = params.zp_scale();
126                                   let a = a.to_dense_array_view::<u8>()?;
127                                   let b = b.to_dense_array_view::<u8>()?;
128                                   let mut c_dense = c.try_as_dense_mut()?;
129                                   let c = c_dense.to_array_view_mut::<u8>()?;
130                                   crate::ndarray::Zip::from(c)
131                                       .and_broadcast(a)
132                                       .and_broadcast(b)
133                                       .for_each(|c,a,b| *c = (scale_by((*a as i32 - zp as i32) * (*b as i32 - zp as i32), scale) + zp as i32).clamp_cast());
134                                   Ok(true)
135                               }
136                               _ => Ok(false)
137                           }
138                       }
139                   },
140                   q: [i8, u8, i32] => |c, a, b, zp, scale| {
141                    *c = (scale_by((a.clone() as i32 - zp as i32) * (*b as i32 - zp as i32) , scale) + zp as i32).clamp_cast()
142                   };
143                   q_op_on_f32: |a: f32, b: f32| a * b,
144                   [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = a.wrapping_mul(*b),
145                   [f32, f16, f64] => |c, a, b| *c = a * b,
146                   [TDim] => |c, a, b| *c = a.clone() * b
147);
148
149bin_to_super_type!(div, Div,
150cost: |dt| tvec!((Cost::Div(dt), 1)),
151declutter: declutter_div,
152eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
153    if
154        a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
155            let a = a.to_dense_array_view::<TDim>()?;
156            let b = b.to_dense_array_view::<TDim>()?;
157            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
158            unsafe {
159                let a = a.broadcast(&*c_shape).unwrap();
160                let b = b.broadcast(&*c_shape).unwrap();
161                let mut c = Tensor::uninitialized_dt(DatumType::TDim, &c_shape)?;
162                let mut c_dense = c.try_as_dense_mut()?;
163                let mut view = c_dense.to_array_view_mut::<TDim>()?;
164                for coords in crate::ndarray::indices(&*c_shape) {
165                    let (p, q) = a[&coords].maybe_div(&b[&coords])?;
166                    view[&coords] = p/q;
167                }
168                Ok(c)
169            }
170        } else if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
171                       DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
172                       DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
173                (a.datum_type(), b.datum_type(), c_dt) {
174
175               let multiplier = a_scale / (b_scale * c_scale);
176                let a = a.to_dense_array_view::<u8>()?;
177                let b = b.to_dense_array_view::<u8>()?;
178                let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
179                let mut c = Tensor::zero_dt(c_dt, &c_shape)?;
180                let mut c_dense = c.try_as_dense_mut()?;
181                let view = c_dense.to_array_view_mut::<u8>()?;
182                crate::ndarray::Zip::from(view)
183                    .and_broadcast(a)
184                    .and_broadcast(b)
185                    // maintain division in f32 before rescale to maintain high accuracy
186                    .for_each(|c,a,b| *c = (
187                            scale_by(
188                                (*a as i32 - a_zp as i32) as f32 / (*b as i32 - b_zp as i32) as f32, multiplier
189                            ) as i32 + c_zp as i32
190                        ).clamp_cast());
191                Ok(c)
192        } else {
193            Div.generic_eval(a, b, c_dt)
194        }
195},
196is_commutative: false,
197neutral_element: 1,
198out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
199    if c.datum_type() == TDim::datum_type() &&
200        a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
201            let a = a.to_dense_array_view::<TDim>()?;
202            let b = b.cast_to::<i32>()?;
203            let b = b.to_dense_array_view::<i32>()?;
204            let mut c_dense = c.try_as_dense_mut()?;
205            let c = c_dense.to_array_view_mut::<TDim>()?;
206            crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() / *b);
207            Ok(true)
208        } else if c.datum_type().is_quantized() || b.datum_type().is_quantized() || a.datum_type().is_quantized() {
209            let a_f32 = a.cast_to::<f32>()?;
210            let a_f32 = a_f32.to_dense_array_view::<f32>()?;
211            let b_f32 = b.cast_to::<f32>()?;
212            let b_f32 = b_f32.to_dense_array_view::<f32>()?;
213            let c_f32 = &a_f32 / &b_f32;
214            *c = c_f32.into_tensor().cast_to_dt(c.datum_type())?.into_owned();
215            Ok(true)
216        } else {
217            Ok(false)
218        }
219},
220q_op_on_f32: |a: f32, b: f32| a / b,
221[f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64] => |c, a, b| *c = a.clone() / b
222);
223
224bin_to_super_type!(rem, Rem,
225                                      eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
226                                          if
227                                              a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
228                                                  let a = a.to_dense_array_view::<TDim>()?;
229                                                  let b = b.cast_to::<i32>()?;
230                                                  let b = b.to_dense_array_view::<i32>()?;
231                                                  let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
232                                                  unsafe {
233                                                      let mut c = Tensor::uninitialized_dt(DatumType::TDim, &c_shape)?;
234                                                      let mut c_dense = c.try_as_dense_mut()?;
235                                                      let view = c_dense.to_array_view_mut::<TDim>()?;
236                                                      crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() % *b);
237                                                      Ok(c)
238                                                  }
239                                              } else {
240                                                  Rem.generic_eval(a,b, c_dt)
241                                              }
242                                      },
243                                      out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
244                                          if c.datum_type() == TDim::datum_type() &&
245                                              a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
246                                                  let a = a.to_dense_array_view::<TDim>()?;
247                                                  let b = b.cast_to::<i32>()?;
248                                                  let b = b.to_dense_array_view::<i32>()?;
249                                                  let mut c_dense = c.try_as_dense_mut()?;
250                                                  let c = c_dense.to_array_view_mut::<TDim>()?;
251                                                  crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() % *b);
252                                                  Ok(true)
253                                              } else {
254                                                  Ok(false)
255                                              }
256                                      },
257                                      [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64] => |c, a, b| *c = a.clone() % b);
258
259bin_to_super_type!(min, Min, linalg:Min,
260                   q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *a } else { *b };
261                   q_op_on_f32: |a: f32, b: f32| a.min(b),
262                   [f16, f32, f64] => |c,a,b| *c = a.min(*b),
263                   [TDim] => |c,a,b| *c = a.clone().mini(b.clone()),
264                   [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.min(b));
265
266bin_to_super_type!(max, Max,
267                   eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
268                   // Attempt to optimize relu case
269                    if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
270                            DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
271                            DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
272                        (a.datum_type(), b.datum_type(), c_dt)
273                    {
274                        if a.is_uniform() || b.is_uniform() {
275                            // select e between a and b as uniform if exist
276                            // and d remaining a or b
277                            let (d, d_zp, d_scale, e, e_zp, e_scale) = if a.is_uniform() && !b.is_uniform() {
278                                (&b, &b_zp, &b_scale, &a, &a_zp, &a_scale)
279                            } else {
280                                (&a, &a_zp, &a_scale, &b, &b_zp, &b_scale)
281                            };
282                            if e.is_uniform() { // may be relu or any scalar
283                                let e = e.cast_to::<u8>()?.try_as_dense()?.as_slice::<u8>()?[0];
284                                let e_val_as_d_aligned: i32 = scale_by(e as i32 - e_zp, e_scale / d_scale);
285                                let multiplier = d_scale  * (1.0/ c_scale);
286                                let d = d.to_dense_array_view::<u8>()?;
287                                let mut c = Tensor::zero_dt(c_dt, d.shape())?;
288                                let mut c_dense = c.try_as_dense_mut()?;
289                                let view = c_dense.to_array_view_mut::<u8>()?;
290                                crate::ndarray::Zip::from(view)
291                                    .and_broadcast(d)
292                                    .for_each(|c,d| {
293                                        let d_min_zp = *d as i32 - *d_zp as i32;
294                                        let c_val: i32 = if d_min_zp < e_val_as_d_aligned {
295                                            e_val_as_d_aligned
296                                        } else {
297                                            d_min_zp
298                                        };
299                                        *c = (scale_by(c_val, multiplier) + c_zp as i32).clamp_cast();
300                                    });
301                                return Ok(c)
302                            }
303                        }
304                    }
305                    Max.generic_eval(a, b, c_dt)
306                   },
307                   linalg:Max,
308                   q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *b } else { *a };
309                   q_op_on_f32: |a: f32, b: f32| -> f32 {a.max(b)},
310                   [f16, f32, f64] => |c,a,b| *c = a.max(*b),
311                   [TDim] => |c,a,b| *c = a.clone().maxi(b.clone()),
312                   [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.max(b));
313
314bin_to_super_type!(pow, Pow,
315                   declutter: declutter_pow,
316                   is_commutative: false,
317                   neutral_element: 1,
318                   q_op_on_f32: |a: f32, b: f32| -> f32 {a.powf(b)},
319                   [f16, f32, f64] => |c,a,b| *c = a.powf(*b),
320                   [i32, i64] => |c,a,b| *c = a.pow(*b as u32));
321
322bin_to_super_type!(shift_left, ShiftLeft,
323                   is_commutative: false,
324                   [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a << *b);
325bin_to_super_type!(shift_right, ShiftRight,
326                   is_commutative: false,
327                   [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a >> *b);
328
329fn declutter_mul(
330    _op: &Mul,
331    model: &TypedModel,
332    node: &TypedNode,
333) -> TractResult<Option<TypedModelPatch>> {
334    if node.inputs[0] == node.inputs[1] && !node.outputs[0].fact.datum_type.is_quantized() {
335        return Ok(Some(TypedModelPatch::replace_single_op(
336            model,
337            node,
338            &node.inputs[0..1],
339            square(),
340        )?));
341    }
342
343    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
344        let var_fact = model.outlet_fact(uniform.var)?;
345        if uniform.uni.cast_to_scalar::<f64>()? == 0.0 {
346            let shapes =
347                model.node_input_facts(node.id)?.iter().map(|f| &f.shape).collect::<TVec<_>>();
348            let shape: ShapeFact =
349                crate::broadcast::multi_broadcast(&shapes).context("Failed to broadcast")?.into();
350            return Ok(Some(TypedModelPatch::rewire(
351                model,
352                &[],
353                &[node.id.into()],
354                &|patch, _| {
355                    let scalar = patch.add_const(
356                        format!("{}.zero", node.name),
357                        if uniform.uni.datum_type().is_quantized() {
358                            let output_dt = node.outputs[0].fact.datum_type;
359                            Arc::new(uniform.uni.clone().cast_to_dt(output_dt)?.into_owned())
360                        } else {
361                            uniform.uni.clone()
362                        },
363                    )?;
364                    let op = MultiBroadcastTo::new(shape.clone());
365                    patch.wire_node(&node.name, op, &[scalar])
366                },
367            )?));
368        }
369        let dt = uniform.uni.datum_type();
370        if !dt.is_quantized() {
371            // avoid cast potential with Q tensor
372            let integer = uniform.uni.cast_to_scalar::<i64>()?;
373            if tensor0(integer)
374                .cast_to_dt(uniform.uni.datum_type())?
375                .close_enough(&uniform.uni, false)
376                .is_ok()
377                && uniform.uni.cast_to_scalar::<i64>()?.count_ones() == 1
378                && dt.is_integer()
379            {
380                let shift = integer.trailing_zeros();
381                return Ok(Some(TypedModelPatch::rewire(
382                    model,
383                    &[uniform.var],
384                    &[node.id.into()],
385                    &|patch, taps| {
386                        let shift = patch.add_const(
387                            format!("{}.shift", node.name),
388                            tensor0(shift)
389                                .cast_to_dt(dt)?
390                                .into_owned()
391                                .broadcast_into_rank(var_fact.rank())?,
392                        )?;
393                        patch.wire_node(&node.name, shift_left(), &[taps[0], shift])
394                    },
395                )?));
396            }
397        }
398    }
399    if let Some(patch) = declutter_mul_const_mul_const(model, node)? {
400        return Ok(Some(patch));
401    }
402    Ok(None)
403}
404
405fn declutter_mul_const_mul_const(
406    model: &TypedModel,
407    node: &TypedNode,
408) -> TractResult<Option<TypedModelPatch>> {
409    let input_facts = model.node_input_facts(node.id)?;
410    rule_if_some!(const_slot = input_facts.iter().position(|f| f.konst.is_some()));
411    let prec = model.node(node.inputs[1 - const_slot].node);
412    rule_if_some!(prec_mul = prec.op_as::<TypedBinOp>());
413    rule_if!(prec.outputs[0].successors.len() <= 1);
414    rule_if!(prec_mul.0.is::<Mul>());
415    let prec_input_facts = model.node_input_facts(prec.id)?;
416    rule_if_some!(prec_const_slot = prec_input_facts.iter().position(|f| f.konst.is_some()));
417
418    let const_fact = model.outlet_fact(node.inputs[const_slot])?;
419    let prec_const_fact = model.outlet_fact(prec.inputs[prec_const_slot])?;
420    // todo: extend to anything broadcast compatible
421    rule_if!(const_fact.shape.volume().is_one() || prec_const_fact.shape.volume().is_one());
422    rule_if!(const_fact.datum_type.is_float());
423    let result = mul()
424        .eval(tvec!(
425            const_fact.konst.clone().unwrap().into_tvalue(),
426            prec_const_fact.konst.clone().unwrap().into_tvalue()
427        ))?
428        .remove(0)
429        .into_arc_tensor();
430    let mut patch = TypedModelPatch::default();
431    let konst = patch.add_const(&prec.name, result)?;
432    let input_tap = patch.tap_model(model, prec.inputs[1 - prec_const_slot])?;
433    let wire = patch.wire_node(&node.name, mul(), &[konst, input_tap])?;
434    patch.shunt_outside(model, node.id.into(), wire[0])?;
435    Ok(Some(patch))
436}
437
438fn declutter_div(
439    _op: &Div,
440    model: &TypedModel,
441    node: &TypedNode,
442) -> TractResult<Option<TypedModelPatch>> {
443    if let &[p, q] = &*model.node_input_facts(node.id)? {
444        let dt = q.datum_type;
445        if let Some(q) = &q.uniform {
446            if let Ok(integer) = q.cast_to_scalar::<i64>() {
447                if tensor0(integer).cast_to_dt(dt)?.close_enough(q, false).is_ok()
448                    && dt.is_integer()
449                    && q.cast_to_scalar::<i64>()?.count_ones() == 1
450                {
451                    let shift = integer.trailing_zeros();
452                    return Ok(Some(TypedModelPatch::rewire(
453                        model,
454                        &[node.inputs[0]],
455                        &[node.id.into()],
456                        &|patch, taps| {
457                            let shift = patch.add_const(
458                                format!("{}.shift", node.name),
459                                tensor0(shift)
460                                    .cast_to_dt(dt)?
461                                    .into_owned()
462                                    .broadcast_into_rank(p.rank())?,
463                            )?;
464                            patch.wire_node(&node.name, shift_right(), &[taps[0], shift])
465                        },
466                    )?));
467                }
468            }
469        }
470        if dt.is_float() {
471            return Ok(Some(TypedModelPatch::rewire(
472                model,
473                &node.inputs,
474                &[node.id.into()],
475                &|patch, taps| {
476                    let q =
477                        patch.wire_node(format!("{}-recip", node.name), recip(), &[taps[1]])?[0];
478                    patch.wire_node(&node.name, mul(), &[taps[0], q])
479                },
480            )?));
481        }
482    }
483    Ok(None)
484}
485
486fn declutter_pow(
487    _op: &Pow,
488    model: &TypedModel,
489    node: &TypedNode,
490) -> TractResult<Option<TypedModelPatch>> {
491    let b = model.outlet_fact(node.inputs[1])?;
492    if let Some(b) = &b.uniform {
493        let b = b.cast_to_scalar::<f32>()?;
494        if b == 2.0 {
495            return Ok(Some(TypedModelPatch::replace_single_op(
496                model,
497                node,
498                &[node.inputs[0]],
499                square(),
500            )?));
501        } else if b == 0.5 {
502            return Ok(Some(TypedModelPatch::replace_single_op(
503                model,
504                node,
505                &[node.inputs[0]],
506                sqrt(),
507            )?));
508        }
509    }
510    crate::ops::nn::gelu_approximate::detect_gelu_approx(_op, model, node)
511}
512
513element_wise!(abs, Abs, [i8, i16, i32, i64, f16, f32, i32] => |_, xs| {
514    xs.iter_mut().for_each(|x| *x = x.abs());
515    Ok(())
516};
517q: [i8, u8, i32, i32] => f32::abs;
518operating_datum_type: |dt| if dt == TDim::datum_type() { i64::datum_type() } else { dt }
519);
520
521element_wise!(exp, Exp, [f16, f32, f64] => |_, xs| {
522    xs.iter_mut().for_each(|x| *x = x.exp());
523    Ok(())
524};
525q: [i8, u8, i32, i32] => f32::exp;
526validation: Validation::Rounding
527);
528
529element_wise!(ln, Ln, [f16, f32, f64] => |_, xs| {
530    xs.iter_mut().for_each(|x| *x = x.ln());
531    Ok(())
532};
533q: [i8, u8, i32, i32] => f32::ln;
534validation: Validation::Rounding
535);
536
537element_wise!(square, Square, [f16, f32, f64] => |_, xs| {
538    xs.iter_mut().for_each(|x| *x = x.powi(2));
539    Ok(())
540};
541q: [i8, u8, i32, i32] => |f : f32| f.powi(2);
542validation: Validation::Rounding
543);
544
545element_wise!(sqrt, Sqrt, [f16, f32, f64] => |_, xs| {
546    xs.iter_mut().for_each(|x| *x = x.sqrt());
547    Ok(())
548};
549q: [i8, u8, i32, i32] => f32::sqrt;
550validation: Validation::Rounding
551);
552
553element_wise!(recip, Recip, [f16, f32, f64] => |_, xs| {
554    xs.iter_mut().for_each(|x| *x = x.recip());
555    Ok(())
556};
557q: [i8, u8, i32, i32] => f32::recip;
558cost: |dt| {tvec!((Cost::Div(dt), 1))};
559declutter: declutter_recip;
560validation: Validation::Rounding
561);
562
563fn declutter_recip(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
564    use super::element_wise::*;
565    if let Some(prec) = model.linear_prec(node.id)? {
566        if let Some(ew) = prec.op_as::<ElementWiseOp>() {
567            let repl = if ew.0.is::<Sqrt>() {
568                Some(rsqrt())
569            } else if ew.0.is::<Rsqrt>() {
570                Some(sqrt())
571            } else {
572                None
573            };
574            if let Some(repl) = repl {
575                let mut patch = TypedModelPatch::default();
576                let mut wire = patch.tap_model(model, prec.inputs[0])?;
577                wire = patch.wire_node(&node.name, repl, &[wire])?[0];
578                patch.shunt_outside(model, node.id.into(), wire)?;
579                return Ok(Some(patch));
580            }
581        }
582    }
583    Ok(None)
584}
585
586element_wise!(rsqrt, Rsqrt, [f16, f32, f64] => |_, xs| {
587    xs.iter_mut().for_each(|x| *x = x.sqrt().recip());
588    Ok(())
589};
590q: [i8, u8, i32] => |x : f32| x.sqrt().recip();
591validation: Validation::Rounding
592);
593
594element_wise!(ceil, Ceil, [f16, f32, f64] => |_, xs| {
595    xs.iter_mut().for_each(|x| *x = x.ceil());
596    Ok(())
597}, [i8, i16,i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
598q: [i8, u8, i32] => f32::recip);
599
600element_wise!(floor, Floor, [f16, f32, f64] => |_, xs| {
601    xs.iter_mut().for_each(|x| *x = x.floor());
602    Ok(())
603}, [i8, i16,i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
604q: [i8, u8, i32] => f32::floor);
605
606element_wise!(round, Round, [f16, f32, f64] => |_, xs| {
607    xs.iter_mut().for_each(|x| *x = x.round());
608    Ok(())
609}, [i8, i16,i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
610q: [i8, u8, i32] => f32::round);
611
612element_wise!(q_scale, QScale{scaler: Scaler},[i32] => |op, xs| {
613    xs.iter_mut().for_each(|x| *x = x.q_scale(op.scaler));
614    Ok(())
615});
616
617element_wise!(round_half_to_even, RoundHalfToEven,
618[f32] => |_, xs| {
619    xs.iter_mut().for_each(|x| *x = round_ties_to_even(*x));
620    Ok(())
621},
622[f16] => |_, xs| {
623    xs.iter_mut().for_each(|x| *x = f16::from_f32(round_ties_to_even(x.to_f32())));
624    Ok(())
625};
626q: [i8, u8, i32] => round_ties_to_even);
627
628element_wise!(cos, Cos, [f16, f32, f64] => |_, xs| {
629    xs.iter_mut().for_each(|x| *x = x.cos());
630    Ok(())
631};
632q: [i8, u8, i32] => f32::cos);
633
634element_wise!(sin, Sin, [f16, f32, f64] => |_, xs| {
635    xs.iter_mut().for_each(|x| *x = x.sin());
636    Ok(())
637};
638q: [i8, u8, i32] => f32::sin);
639
640element_wise!(tan, Tan, [f16, f32, f64] => |_, xs| {
641    xs.iter_mut().for_each(|x| *x = x.tan());
642    Ok(())
643};
644q: [i8, u8, i32] => f32::tan);
645
646element_wise!(acos, Acos, [f16, f32, f64] => |_, xs| {
647    xs.iter_mut().for_each(|x| *x = x.acos());
648    Ok(())
649};
650q: [i8, u8, i32] => f32::acos);
651
652element_wise!(asin, Asin, [f16, f32, f64] => |_, xs| {
653    xs.iter_mut().for_each(|x| *x = x.asin());
654    Ok(())
655};
656q: [i8, u8, i32] => f32::asin);
657
658element_wise!(atan, Atan, [f16, f32, f64] => |_, xs| {
659    xs.iter_mut().for_each(|x| *x = x.atan());
660    Ok(())
661};
662q: [i8, u8, i32] => f32::atan);
663
664element_wise!(cosh, Cosh, [f16, f32, f64] => |_, xs| {
665    xs.iter_mut().for_each(|x| *x = x.cosh());
666    Ok(())
667};
668q: [i8, u8, i32] => f32::cosh);
669
670element_wise!(sinh, Sinh, [f16, f32, f64] => |_, xs| {
671    xs.iter_mut().for_each(|x| *x = x.sinh());
672    Ok(())
673};
674q: [i8, u8, i32] => f32::sinh);
675
676element_wise!(tanh, Tanh,
677 [f16] => |_, xs| { (tract_linalg::ops().tanh_f16)().run(xs) },
678 [f32] => |_, xs| { (tract_linalg::ops().tanh_f32)().run(xs) },
679 [f64] => |_, xs| { xs.iter_mut().for_each(|x| *x = x.tanh()); Ok(()) };
680 q: [i8, u8, i32] => f32::tanh;
681 cost: |dt| {tvec!((Cost::FMA(dt), 11), (Cost::Div(dt), 1))}
682);
683
684element_wise!(erf, Erf,
685 [f32] => |_, xs| { (tract_linalg::ops().erf_f32)().run(xs) },
686 [f16] => |_, xs| {
687     let mut f32s = xs.iter().map(|x| x.to_f32()).collect_vec();
688     (tract_linalg::ops().erf_f32)().run(&mut f32s)?;
689     xs.iter_mut().zip(f32s.into_iter()).for_each(|(x, f)| *x = f16::from_f32(f));
690     Ok(())
691};
692 cost: |dt| {tvec!((Cost::FMA(dt), 11), (Cost::Div(dt), 1))}
693);
694
695element_wise!(acosh, Acosh, [f16, f32, f64] => |_, xs| {
696    xs.iter_mut().for_each(|x| *x = x.acosh());
697    Ok(())
698};
699q: [i8, u8, i32] => f32::acosh);
700element_wise!(asinh, Asinh, [f16, f32, f64] => |_, xs| {
701    xs.iter_mut().for_each(|x| *x = x.asinh());
702    Ok(())
703};
704q: [i8, u8, i32] => f32::asinh);
705element_wise!(atanh, Atanh, [f16, f32, f64] => |_, xs| {
706    xs.iter_mut().for_each(|x| *x = x.atanh());
707    Ok(())
708};
709q: [i8, u8, i32] => f32::atanh);
710
711element_wise!(neg, Neg, [i8, i16, i32, i64, f16, f32, f64, TDim] => |_, xs| {
712    xs.iter_mut().for_each(|x| *x = -x.clone());
713    Ok(())
714};
715q: [i8, u8, i32] => |x: f32| -x);
716
717element_wise!(sign, Sign, [f16, f32, f64] => |_, xs| {
718    xs.iter_mut().for_each(|x| *x = if x.is_zero() { *x } else { x.signum() });
719    Ok(())
720};
721q: [i8, u8, i32] => f32::signum);
722
723element_wise_oop!(is_inf, IsInf { detect_positive: bool, detect_negative: bool },
724    [f32] => bool |op, xs, ys| {
725        xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
726            *y = (op.detect_positive && *x == f32::INFINITY) || (op.detect_negative && *x == f32::NEG_INFINITY)
727        );
728        Ok(())
729    },
730    [f16] => bool |op, xs, ys| {
731        xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
732            *y = (op.detect_positive && *x == f16::INFINITY) || (op.detect_negative && *x == f16::NEG_INFINITY)
733        );
734        Ok(())
735    }
736);
737
738element_wise_oop!(is_nan, IsNan,
739    [f16, f32] => bool |_, xs, ys| {
740        xs.iter().zip(ys.iter_mut()).for_each(|(x,y)| *y = x.is_nan());
741        Ok(())
742    }
743);
744
#[cfg(test)]
mod tests {
    use crate::ops::binary::TypedBinOp;

    use super::*;
    use ndarray::arr2;

    // Checks that `*` on two 2x2 ndarray matrices multiplies element by element.
    #[test]
    fn test_mul() {
        let lhs = arr2(&[[1., 2.], [3., 4.]]);
        let rhs = arr2(&[[1., 0.], [0., 0.]]);
        assert_eq!(lhs * rhs, arr2(&[[1., 0.], [0., 0.]]));
    }

    // Checks that `dot` on two 2x2 ndarray matrices is a matrix product.
    #[test]
    fn dot() {
        let lhs = arr2(&[[1., 2.], [3., 4.]]);
        let rhs = arr2(&[[1., 0.], [0., 0.]]);
        assert_eq!(lhs.dot(&rhs), arr2(&[[1., 0.], [3., 0.]]));
    }

    // Multiplying an i32 input by the constant 4 must give the same result
    // before and after decluttering, and the decluttered output node must be
    // a `TypedBinOp` wrapping `ShiftLeft`.
    #[test]
    fn mul_as_shift_left() -> TractResult<()> {
        let mut graph = TypedModel::default();
        let source = graph.add_source("x", i32::fact([2usize, 2]))?;
        let factor = tensor0(4i32).broadcast_into_rank(2)?.into_arc_tensor();
        let factor = graph.add_const("a", factor)?;
        let product = graph.wire_node("y", mul(), &[source, factor])?[0];
        graph.set_output_outlets(&[product])?;

        let expected = tensor2(&[[4, 8], [12, 16]]);

        // Reference run on the model as built.
        let plan = SimplePlan::new(graph.clone())?;
        let outputs = plan.run(tvec!(tensor2(&[[1, 2], [3, 4]]).into()))?;
        assert_eq!(*outputs[0], expected);

        // Same inputs on the decluttered model.
        let decluttered = graph.into_decluttered()?;
        let plan = SimplePlan::new(decluttered.clone())?;
        let outputs = plan.run(tvec!(tensor2(&[[1, 2], [3, 4]]).into()))?;
        assert_eq!(*outputs[0], expected);

        // Inspect the operator that produces the model output.
        let output_node = decluttered.output_outlets()?[0].node;
        let bin_op = decluttered.node(output_node).op().downcast_ref::<TypedBinOp>().unwrap();
        assert!(bin_op.0.downcast_ref::<ShiftLeft>().is_some());
        Ok(())
    }

    // Dividing an i32 input by the constant 4 must give the same result
    // before and after decluttering, and the decluttered output node must be
    // a `TypedBinOp` wrapping `ShiftRight`.
    #[test]
    fn div_as_shift() -> TractResult<()> {
        let mut graph = TypedModel::default();
        let source = graph.add_source("a", i32::fact([2usize, 2]))?;
        let divisor = graph.add_const("shift", tensor2(&[[4]]))?;
        let quotient = graph.wire_node("c", div(), &[source, divisor])?[0];
        graph.set_output_outlets(&[quotient])?;

        let expected = tensor2(&[[4, 8], [16, 17]]);

        // Reference run on the model as built.
        let plan = SimplePlan::new(graph.clone())?;
        let outputs = plan.run(tvec!(tensor2(&[[16, 32], [64, 68]]).into()))?;
        assert_eq!(*outputs[0], expected);

        // Same inputs on the decluttered model.
        let decluttered = graph.into_decluttered()?;
        let plan = SimplePlan::new(decluttered.clone())?;
        let outputs = plan.run(tvec!(tensor2(&[[16, 32], [64, 68]]).into()))?;
        assert_eq!(*outputs[0], expected);

        // Inspect the operator that produces the model output.
        let output_node = decluttered.output_outlets()?[0].node;
        let bin_op = decluttered.node(output_node).op().downcast_ref::<TypedBinOp>().unwrap();
        assert!(bin_op.0.downcast_ref::<ShiftRight>().is_some());
        Ok(())
    }
}