// tract_core/ops/math/mod.rs

1#![allow(clippy::clone_on_copy)]
2#![allow(clippy::unnecessary_cast)]
3#![allow(clippy::blocks_in_conditions)]
4
5use super::array::MultiBroadcastTo;
6use super::binary::TypedBinOp;
7use crate::internal::*;
8use crate::ops::quant::scale_by;
9use num_traits::bounds::Bounded;
10use num_traits::int::PrimInt;
11use num_traits::{Float, Zero};
12use tract_data::internal::ClampCast;
13use tract_data::itertools::Itertools;
14pub use tract_data::prelude::round_ties_to_even;
15use tract_linalg::{ScaleShiftAndRound, Scaler};
16use tract_num_traits::AsPrimitive;
17
18#[cfg(feature = "complex")]
19mod complex;
20#[cfg(feature = "complex")]
21pub use complex::{ComplexToInnerDim, InnerDimToComplex};
22
// Element-wise addition op. Declares a linalg-accelerated kernel (`linalg: Add`),
// a neutral element of 0, quantized kernels via `add_quant`, and a plain `a + b`
// closure for the listed types (including TDim symbolic dims and String, where
// `+` concatenates).
bin_to_super_type!(add, Add,
                   linalg: Add,
                   neutral_element: 0,
                   validation: Validation::Rounding,
                   // quantized i8/u8/i32 paths: zero-point-compensated integer add
                   q: [i8, u8, i32, i32] => add_quant;
                   // reference float semantics used for quantized-over-f32 evaluation
                   q_op_on_f32: |a: f32, b: f32| -> f32 {a+b},
                   [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim, String] => |c, a, b| *c = a.clone() + b);
30
31fn add_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
32where
33    T: PrimInt + Bounded + AsPrimitive<i64> + Datum,
34    i64: AsPrimitive<T>,
35{
36    *c = (a.as_() + b.as_() - zp as i64).clamp_cast()
37}
38
// Element-wise subtraction (a - b). Not commutative; neutral element 0 (on the
// right). Quantized paths go through `sub_quant`.
bin_to_super_type!(sub, Sub,
                   linalg:Sub,
                   is_commutative: false,
                   neutral_element: 0,
                   q: [i8, u8, i32, i32] => sub_quant;
                   q_op_on_f32: |a: f32, b: f32| -> f32 {a-b},
                   [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() - b);
46
// "Flipped" subtraction (b - a): same as Sub but with the operands swapped in
// every kernel. Useful when broadcasting rules pin the operand order.
bin_to_super_type!(subf, SubF,
                   linalg:SubF,
                   is_commutative: false,
                   neutral_element: 0,
                   q: [i8, u8, i32, i32] => subf_quant;
                   q_op_on_f32: |a: f32, b: f32| -> f32 {b - a},
                   [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = b.clone() - a);
54
55fn sub_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
56where
57    T: PrimInt + Bounded + AsPrimitive<i16> + Datum,
58    i16: AsPrimitive<T>,
59{
60    *c = (a.as_() - b.as_() + zp as i16).clamp_cast()
61}
62
63fn subf_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
64where
65    T: PrimInt + Bounded + AsPrimitive<i16> + Datum,
66    i16: AsPrimitive<T>,
67{
68    *c = (b.as_() - a.as_() + zp as i16).clamp_cast()
69}
70
// Element-wise multiplication. Besides the generic kernels, this declares:
// - `eval_override`: a fast path when all three datum types are QU8 with
//   zp/scale params — does the whole multiply + rescale in one integer pass;
// - `out_of_place`: specialized loops for TDim (symbolic) and QI8/QU8 outputs;
// - `declutter_mul` for graph-level simplifications (square, x*0, x*2^k, const folding).
bin_to_super_type!(mul, Mul,
                   cost: |dt| tvec!((Cost::FMA(dt), 1)),
                   declutter: declutter_mul,
                   eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
                    // we apply only if type is QU8 zp_scale datum type
                    if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
                            DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
                            DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
                        (a.datum_type(), b.datum_type(), c_dt)
                    {
                           // combined requantization factor: a_scale*b_scale/c_scale
                           let multiplier = a_scale  * b_scale * (1.0/ c_scale);
                           let a = a.to_array_view::<u8>()?;
                           let b = b.to_array_view::<u8>()?;
                           let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
                           let mut c = Tensor::zero_dt(c_dt, &c_shape)?;
                           let view = c.to_array_view_mut::<u8>()?;
                           // de-zero-point both operands, multiply in i32, rescale, re-add output zp, saturate
                           crate::ndarray::Zip::from(view)
                               .and_broadcast(a)
                               .and_broadcast(b)
                               .for_each(|c,a,b| *c = (scale_by((*a as i32 - a_zp as i32) * (*b as i32 - b_zp as i32), multiplier) + c_zp as i32).clamp_cast());
                           Ok(c)
                        } else {
                            // anything else: fall back to the macro-generated evaluation
                            Mul.generic_eval(a, b, c_dt)
                        }
                    },
                   linalg: Mul,
                   neutral_element: 1,
                   out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
                       // TDim * TDim: multiply symbolic dims by an i32-cast right operand
                       if c.datum_type() == TDim::datum_type() &&
                           a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
                               let a = a.to_array_view::<TDim>()?;
                               let b = b.cast_to::<i32>()?;
                               let b = b.to_array_view::<i32>()?;
                               let c = c.to_array_view_mut::<TDim>()?;
                               crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() * *b);
                               Ok(true)
                           }
                       else {
                           match c.datum_type() {
                               // QI8: both operands assumed to share the output's zp/scale
                               DatumType::QI8(params) => {
                                   let (zp, scale) = params.zp_scale();
                                   let a = a.to_array_view::<i8>()?;
                                   let b = b.to_array_view::<i8>()?;
                                   let c = c.to_array_view_mut::<i8>()?;
                                   crate::ndarray::Zip::from(c)
                                       .and_broadcast(a)
                                       .and_broadcast(b)
                                       .for_each(|c,a,b| *c = (scale_by((*a as i16 - zp as i16) * (*b as i16 - zp as i16), scale) + zp as i16).clamp_cast());
                                   Ok(true)
                               }
                               DatumType::QU8(params) => {
                                   let (zp, scale) = params.zp_scale();
                                   let a = a.to_array_view::<u8>()?;
                                   let b = b.to_array_view::<u8>()?;
                                   let c = c.to_array_view_mut::<u8>()?;
                                   crate::ndarray::Zip::from(c)
                                       .and_broadcast(a)
                                       .and_broadcast(b)
                                       .for_each(|c,a,b| *c = (scale_by((*a as i32 - zp as i32) * (*b as i32 - zp as i32), scale) + zp as i32).clamp_cast());
                                   Ok(true)
                               }
                               // signal "not handled here" so the caller uses another path
                               _ => Ok(false)
                           }
                       }
                   },
                   q: [i8, u8, i32] => |c, a, b, zp, scale| {
                    *c = (scale_by((a.clone() as i32 - zp as i32) * (*b as i32 - zp as i32) , scale) + zp as i32).clamp_cast()
                   };
                   q_op_on_f32: |a: f32, b: f32| a * b,
                   // integer multiply wraps on overflow by design (no panic in release or debug)
                   [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = a.wrapping_mul(*b),
                   [f32, f16, f64] => |c, a, b| *c = a * b,
                   [TDim] => |c, a, b| *c = a.clone() * b
);
144
// Element-wise division. Special cases:
// - TDim / TDim uses `maybe_div`, which yields a (quotient, divisor) pair so the
//   symbolic division can fail cleanly when not exact;
// - QU8/QU8 -> QU8 divides de-zero-pointed values in f32 (for accuracy) then rescales;
// - `out_of_place` handles TDim and a generic quantized fallback through f32;
// - `declutter_div` rewrites /2^k as shift-right and float division as mul-by-recip.
bin_to_super_type!(div, Div,
cost: |dt| tvec!((Cost::Div(dt), 1)),
declutter: declutter_div,
eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
    if
        a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
            let a = a.to_array_view::<TDim>()?;
            let b = b.to_array_view::<TDim>()?;
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
            // SAFETY: every cell of the uninitialized tensor is written below
            // before the tensor is returned.
            unsafe {
                let a = a.broadcast(&*c_shape).unwrap();
                let b = b.broadcast(&*c_shape).unwrap();
                let mut c = Tensor::uninitialized_dt(DatumType::TDim, &c_shape)?;
                let mut view = c.to_array_view_mut::<TDim>()?;
                for coords in crate::ndarray::indices(&*c_shape) {
                    // maybe_div errors out if the symbolic division is not representable
                    let (p, q) = a[&coords].maybe_div(&b[&coords])?;
                    view[&coords] = p/q;
                }
                Ok(c)
            }
        } else if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
                       DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
                       DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
                (a.datum_type(), b.datum_type(), c_dt) {

               // combined requantization factor for a/b expressed in c's scale
               let multiplier = a_scale / (b_scale * c_scale);
                let a = a.to_array_view::<u8>()?;
                let b = b.to_array_view::<u8>()?;
                let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
                let mut c = Tensor::zero_dt(c_dt, &c_shape)?;
                let view = c.to_array_view_mut::<u8>()?;
                crate::ndarray::Zip::from(view)
                    .and_broadcast(a)
                    .and_broadcast(b)
                    // maintain division in f32 before rescale to maintain high accuracy
                    .for_each(|c,a,b| *c = (
                            scale_by(
                                (*a as i32 - a_zp as i32) as f32 / (*b as i32 - b_zp as i32) as f32, multiplier
                            ) as i32 + c_zp as i32
                        ).clamp_cast());
                Ok(c)
        } else {
            Div.generic_eval(a, b, c_dt)
        }
},
is_commutative: false,
neutral_element: 1,
out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
    // TDim / TDim with an i32-cast divisor
    if c.datum_type() == TDim::datum_type() &&
        a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
            let a = a.to_array_view::<TDim>()?;
            let b = b.cast_to::<i32>()?;
            let b = b.to_array_view::<i32>()?;
            let c = c.to_array_view_mut::<TDim>()?;
            crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() / *b);
            Ok(true)
        } else if c.datum_type().is_quantized() || b.datum_type().is_quantized() || a.datum_type().is_quantized() {
            // generic quantized fallback: divide in f32, cast back to the output type
            let a_f32 = a.cast_to::<f32>()?;
            let a_f32 = a_f32.to_array_view::<f32>()?;
            let b_f32 = b.cast_to::<f32>()?;
            let b_f32 = b_f32.to_array_view::<f32>()?;
            let c_f32 = &a_f32 / &b_f32;
            *c = c_f32.into_tensor().cast_to_dt(c.datum_type())?.into_owned();
            Ok(true)
        } else {
            Ok(false)
        }
},
q_op_on_f32: |a: f32, b: f32| a / b,
[f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64] => |c, a, b| *c = a.clone() / b
);
216
// Element-wise remainder (Rust `%` semantics). TDim gets dedicated paths in both
// `eval_override` and `out_of_place`, with the divisor cast to i32.
bin_to_super_type!(rem, Rem,
                                      eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
                                          if
                                              a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
                                                  let a = a.to_array_view::<TDim>()?;
                                                  let b = b.cast_to::<i32>()?;
                                                  let b = b.to_array_view::<i32>()?;
                                                  let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
                                                  // SAFETY: Zip writes every cell of the
                                                  // uninitialized tensor before it is returned.
                                                  unsafe {
                                                      let mut c = Tensor::uninitialized_dt(DatumType::TDim, &c_shape)?;
                                                      let view = c.to_array_view_mut::<TDim>()?;
                                                      crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() % *b);
                                                      Ok(c)
                                                  }
                                              } else {
                                                  Rem.generic_eval(a,b, c_dt)
                                              }
                                      },
                                      out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
                                          if c.datum_type() == TDim::datum_type() &&
                                              a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
                                                  let a = a.to_array_view::<TDim>()?;
                                                  let b = b.cast_to::<i32>()?;
                                                  let b = b.to_array_view::<i32>()?;
                                                  let c = c.to_array_view_mut::<TDim>()?;
                                                  crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() % *b);
                                                  Ok(true)
                                              } else {
                                                  // non-TDim: let the generic kernels handle it
                                                  Ok(false)
                                              }
                                      },
                                      [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64] => |c, a, b| *c = a.clone() % b);
249
// Element-wise minimum. Quantized path compares raw stored values directly
// (valid when both sides share the same zp/scale); floats use `f*::min`
// (NaN-ignoring), TDim uses its own `mini`.
bin_to_super_type!(min, Min, linalg:Min,
                   q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *a } else { *b };
                   q_op_on_f32: |a: f32, b: f32| a.min(b),
                   [f16, f32, f64] => |c,a,b| *c = a.min(*b),
                   [TDim] => |c,a,b| *c = a.clone().mini(b.clone()),
                   [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.min(b));
256
// Element-wise maximum. `eval_override` adds a fast path for QU8 max against a
// uniform (scalar) operand — the common quantized-relu pattern: the scalar is
// aligned into the variable operand's scale once, then a single pass compares,
// rescales and re-zero-points.
bin_to_super_type!(max, Max,
                   eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
                   // Attempt to optimize relu case
                    if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
                            DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
                            DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
                        (a.datum_type(), b.datum_type(), c_dt)
                    {
                        if a.is_uniform() || b.is_uniform() {
                            // select e between a and b as uniform if exist
                            // and d remaining a or b
                            let (d, d_zp, d_scale, e, e_zp, e_scale) = if a.is_uniform() && !b.is_uniform() {
                                (&b, &b_zp, &b_scale, &a, &a_zp, &a_scale)
                            } else {
                                (&a, &a_zp, &a_scale, &b, &b_zp, &b_scale)
                            };
                            if e.is_uniform() { // may be relu or any scalar
                                let e = e.cast_to::<u8>()?.as_slice::<u8>()?[0];
                                // express the scalar threshold in d's (zero-point-free) scale
                                let e_val_as_d_aligned: i32 = scale_by(e as i32 - e_zp, e_scale / d_scale);
                                let multiplier = d_scale  * (1.0/ c_scale);
                                let d = d.to_array_view::<u8>()?;
                                let mut c = Tensor::zero_dt(c_dt, d.shape())?;
                                let view = c.to_array_view_mut::<u8>()?;
                                crate::ndarray::Zip::from(view)
                                    .and_broadcast(d)
                                    .for_each(|c,d| {
                                        let d_min_zp = *d as i32 - *d_zp as i32;
                                        // max(d, threshold) in the de-zero-pointed domain
                                        let c_val: i32 = if d_min_zp < e_val_as_d_aligned {
                                            e_val_as_d_aligned
                                        } else {
                                            d_min_zp
                                        };
                                        *c = (scale_by(c_val, multiplier) + c_zp as i32).clamp_cast();
                                    });
                                return Ok(c)
                            }
                        }
                    }
                    Max.generic_eval(a, b, c_dt)
                   },
                   linalg:Max,
                   // quantized path compares raw stored values (same zp/scale on both sides)
                   q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *b } else { *a };
                   q_op_on_f32: |a: f32, b: f32| -> f32 {a.max(b)},
                   [f16, f32, f64] => |c,a,b| *c = a.max(*b),
                   [TDim] => |c,a,b| *c = a.clone().maxi(b.clone()),
                   [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.max(b));
303
// Element-wise power. `declutter_pow` rewrites x^2 as square and x^0.5 as sqrt.
// NOTE(review): the integer kernel casts the exponent `as u32`, so negative
// exponents wrap to huge values — confirm callers never feed negative exponents.
bin_to_super_type!(pow, Pow,
                   declutter: declutter_pow,
                   is_commutative: false,
                   neutral_element: 1,
                   q_op_on_f32: |a: f32, b: f32| -> f32 {a.powf(b)},
                   [f16, f32, f64] => |c,a,b| *c = a.powf(*b),
                   [i32, i64] => |c,a,b| *c = a.pow(*b as u32));
311
// Bitwise shifts over all integer types. Used by the declutter passes to
// replace multiplication/division by a power of two.
bin_to_super_type!(shift_left, ShiftLeft,
                   is_commutative: false,
                   [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a << *b);
bin_to_super_type!(shift_right, ShiftRight,
                   is_commutative: false,
                   [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a >> *b);
318
319fn declutter_mul(
320    _op: &Mul,
321    model: &TypedModel,
322    node: &TypedNode,
323) -> TractResult<Option<TypedModelPatch>> {
324    if node.inputs[0] == node.inputs[1] && !node.outputs[0].fact.datum_type.is_quantized() {
325        return Ok(Some(TypedModelPatch::replace_single_op(
326            model,
327            node,
328            &node.inputs[0..1],
329            square(),
330        )?));
331    }
332
333    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
334        let var_fact = model.outlet_fact(uniform.var)?;
335        if uniform.uni.cast_to_scalar::<f64>()? == 0.0 {
336            let shapes =
337                model.node_input_facts(node.id)?.iter().map(|f| &f.shape).collect::<TVec<_>>();
338            let shape: ShapeFact =
339                crate::broadcast::multi_broadcast(&shapes).context("Failed to broadcast")?.into();
340            return Ok(Some(TypedModelPatch::rewire(
341                model,
342                &[],
343                &[node.id.into()],
344                &|patch, _| {
345                    let scalar = patch.add_const(
346                        format!("{}.zero", node.name),
347                        if uniform.uni.datum_type().is_quantized() {
348                            let output_dt = node.outputs[0].fact.datum_type;
349                            Arc::new(uniform.uni.clone().cast_to_dt(output_dt)?.into_owned())
350                        } else {
351                            uniform.uni.clone()
352                        },
353                    )?;
354                    let op = MultiBroadcastTo::new(shape.clone());
355                    patch.wire_node(&node.name, op, &[scalar])
356                },
357            )?));
358        }
359        let dt = uniform.uni.datum_type();
360        if !dt.is_quantized() {
361            // avoid cast potential with Q tensor
362            let integer = uniform.uni.cast_to_scalar::<i64>()?;
363            if tensor0(integer)
364                .cast_to_dt(uniform.uni.datum_type())?
365                .close_enough(&uniform.uni, false)
366                .is_ok()
367                && uniform.uni.cast_to_scalar::<i64>()?.count_ones() == 1
368                && dt.is_integer()
369            {
370                let shift = integer.trailing_zeros();
371                return Ok(Some(TypedModelPatch::rewire(
372                    model,
373                    &[uniform.var],
374                    &[node.id.into()],
375                    &|patch, taps| {
376                        let shift = patch.add_const(
377                            format!("{}.shift", node.name),
378                            tensor0(shift)
379                                .cast_to_dt(dt)?
380                                .into_owned()
381                                .broadcast_into_rank(var_fact.rank())?,
382                        )?;
383                        patch.wire_node(&node.name, shift_left(), &[taps[0], shift])
384                    },
385                )?));
386            }
387        }
388    }
389    if let Some(patch) = declutter_mul_const_mul_const(model, node)? {
390        return Ok(Some(patch));
391    }
392    Ok(None)
393}
394
/// Folds `(x * c1) * c2` into `x * (c1 * c2)` when both constants are present,
/// the inner Mul has no other consumer, at least one constant is a scalar, and
/// the constant is a float (so the eager multiplication below is exact enough).
/// The `rule_if*` macros bail out with `Ok(None)` when a condition fails.
fn declutter_mul_const_mul_const(
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let input_facts = model.node_input_facts(node.id)?;
    rule_if_some!(const_slot = input_facts.iter().position(|f| f.konst.is_some()));
    let prec = model.node(node.inputs[1 - const_slot].node);
    rule_if_some!(prec_mul = prec.op_as::<TypedBinOp>());
    // the inner Mul must not feed anything else, or we would duplicate work
    rule_if!(prec.outputs[0].successors.len() <= 1);
    rule_if!(prec_mul.0.is::<Mul>());
    let prec_input_facts = model.node_input_facts(prec.id)?;
    rule_if_some!(prec_const_slot = prec_input_facts.iter().position(|f| f.konst.is_some()));

    let const_fact = model.outlet_fact(node.inputs[const_slot])?;
    let prec_const_fact = model.outlet_fact(prec.inputs[prec_const_slot])?;
    // todo: extend to anything broadcast compatible
    rule_if!(const_fact.shape.volume().is_one() || prec_const_fact.shape.volume().is_one());
    rule_if!(const_fact.datum_type.is_float());
    // pre-compute c1 * c2 with the mul op itself
    let result = mul()
        .eval(tvec!(
            const_fact.konst.clone().unwrap().into_tvalue(),
            prec_const_fact.konst.clone().unwrap().into_tvalue()
        ))?
        .remove(0)
        .into_arc_tensor();
    let mut patch = TypedModelPatch::default();
    let konst = patch.add_const(&prec.name, result)?;
    let input_tap = patch.tap_model(model, prec.inputs[1 - prec_const_slot])?;
    let wire = patch.wire_node(&node.name, mul(), &[konst, input_tap])?;
    patch.shunt_outside(model, node.id.into(), wire[0])?;
    Ok(Some(patch))
}
427
428fn declutter_div(
429    _op: &Div,
430    model: &TypedModel,
431    node: &TypedNode,
432) -> TractResult<Option<TypedModelPatch>> {
433    if let &[p, q] = &*model.node_input_facts(node.id)? {
434        let dt = q.datum_type;
435        if let Some(q) = &q.uniform {
436            if let Ok(integer) = q.cast_to_scalar::<i64>() {
437                if tensor0(integer).cast_to_dt(dt)?.close_enough(q, false).is_ok()
438                    && dt.is_integer()
439                    && q.cast_to_scalar::<i64>()?.count_ones() == 1
440                {
441                    let shift = integer.trailing_zeros();
442                    return Ok(Some(TypedModelPatch::rewire(
443                        model,
444                        &[node.inputs[0]],
445                        &[node.id.into()],
446                        &|patch, taps| {
447                            let shift = patch.add_const(
448                                format!("{}.shift", node.name),
449                                tensor0(shift)
450                                    .cast_to_dt(dt)?
451                                    .into_owned()
452                                    .broadcast_into_rank(p.rank())?,
453                            )?;
454                            patch.wire_node(&node.name, shift_right(), &[taps[0], shift])
455                        },
456                    )?));
457                }
458            }
459        }
460        if dt.is_float() {
461            return Ok(Some(TypedModelPatch::rewire(
462                model,
463                &node.inputs,
464                &[node.id.into()],
465                &|patch, taps| {
466                    let q =
467                        patch.wire_node(format!("{}-recip", node.name), recip(), &[taps[1]])?[0];
468                    patch.wire_node(&node.name, mul(), &[taps[0], q])
469                },
470            )?));
471        }
472    }
473    Ok(None)
474}
475
476fn declutter_pow(
477    _op: &Pow,
478    model: &TypedModel,
479    node: &TypedNode,
480) -> TractResult<Option<TypedModelPatch>> {
481    let b = model.outlet_fact(node.inputs[1])?;
482    if let Some(b) = &b.uniform {
483        let b = b.cast_to_scalar::<f32>()?;
484        if b == 2.0 {
485            return Ok(Some(TypedModelPatch::replace_single_op(
486                model,
487                node,
488                &[node.inputs[0]],
489                square(),
490            )?));
491        } else if b == 0.5 {
492            return Ok(Some(TypedModelPatch::replace_single_op(
493                model,
494                node,
495                &[node.inputs[0]],
496                sqrt(),
497            )?));
498        }
499    }
500    Ok(None)
501}
502
503element_wise!(abs, Abs, [i8, i16, i32, i64, f16, f32, i32] => |_, xs| {
504    xs.iter_mut().for_each(|x| *x = x.abs());
505    Ok(())
506};
507q: [i8, u8, i32, i32] => f32::abs;
508operating_datum_type: |dt| if dt == TDim::datum_type() { i64::datum_type() } else { dt }
509);
510
// Element-wise exponential; quantized types evaluate through f32::exp.
element_wise!(exp, Exp, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.exp());
    Ok(())
};
q: [i8, u8, i32, i32] => f32::exp;
validation: Validation::Rounding
);
518
// Element-wise natural logarithm; quantized types evaluate through f32::ln.
element_wise!(ln, Ln, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.ln());
    Ok(())
};
q: [i8, u8, i32, i32] => f32::ln;
validation: Validation::Rounding
);
526
// Element-wise square (x^2); target of the `x*x` and `pow(x, 2)` declutter rules.
element_wise!(square, Square, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.powi(2));
    Ok(())
};
q: [i8, u8, i32, i32] => |f : f32| f.powi(2);
validation: Validation::Rounding
);
534
// Element-wise square root; target of the `pow(x, 0.5)` declutter rule.
element_wise!(sqrt, Sqrt, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.sqrt());
    Ok(())
};
q: [i8, u8, i32, i32] => f32::sqrt;
validation: Validation::Rounding
);
542
// Element-wise reciprocal (1/x). `declutter_recip` cancels it against a
// preceding sqrt/rsqrt. Cost is modeled as one division.
element_wise!(recip, Recip, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.recip());
    Ok(())
};
q: [i8, u8, i32, i32] => f32::recip;
cost: |dt| {tvec!((Cost::Div(dt), 1))};
declutter: declutter_recip;
validation: Validation::Rounding
);
552
553fn declutter_recip(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
554    use super::element_wise::*;
555    if let Some(prec) = model.linear_prec(node.id)? {
556        if let Some(ew) = prec.op_as::<ElementWiseOp>() {
557            let repl = if ew.0.is::<Sqrt>() {
558                Some(rsqrt())
559            } else if ew.0.is::<Rsqrt>() {
560                Some(sqrt())
561            } else {
562                None
563            };
564            if let Some(repl) = repl {
565                let mut patch = TypedModelPatch::default();
566                let mut wire = patch.tap_model(model, prec.inputs[0])?;
567                wire = patch.wire_node(&node.name, repl, &[wire])?[0];
568                patch.shunt_outside(model, node.id.into(), wire)?;
569                return Ok(Some(patch));
570            }
571        }
572    }
573    Ok(None)
574}
575
// Element-wise reciprocal square root (1/sqrt(x)).
element_wise!(rsqrt, Rsqrt, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.sqrt().recip());
    Ok(())
};
q: [i8, u8, i32] => |x : f32| x.sqrt().recip();
validation: Validation::Rounding
);
583
584element_wise!(ceil, Ceil, [f16, f32, f64] => |_, xs| {
585    xs.iter_mut().for_each(|x| *x = x.ceil());
586    Ok(())
587}, [i8, i16,i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
588q: [i8, u8, i32] => f32::recip);
589
// Element-wise floor; a no-op on integer/TDim types.
element_wise!(floor, Floor, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.floor());
    Ok(())
}, [i8, i16,i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
q: [i8, u8, i32] => f32::floor);
595
// Element-wise rounding, half-away-from-zero (Rust `round` semantics);
// a no-op on integer/TDim types. See `round_half_to_even` for banker's rounding.
element_wise!(round, Round, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.round());
    Ok(())
}, [i8, i16,i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
q: [i8, u8, i32] => f32::round);
601
// Requantization helper: applies the embedded `Scaler` (from tract-linalg) to
// i32 accumulators in place.
element_wise!(q_scale, QScale{scaler: Scaler},[i32] => |op, xs| {
    xs.iter_mut().for_each(|x| *x = x.q_scale(op.scaler));
    Ok(())
});
606
// Element-wise banker's rounding (ties to even). f16 values round-trip
// through f32 since `round_ties_to_even` operates on f32.
element_wise!(round_half_to_even, RoundHalfToEven,
[f32] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = round_ties_to_even(*x));
    Ok(())
},
[f16] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = f16::from_f32(round_ties_to_even(x.to_f32())));
    Ok(())
};
q: [i8, u8, i32] => round_ties_to_even);
617
// Element-wise cosine.
element_wise!(cos, Cos, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.cos());
    Ok(())
};
q: [i8, u8, i32] => f32::cos);
623
// Element-wise sine.
element_wise!(sin, Sin, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.sin());
    Ok(())
};
q: [i8, u8, i32] => f32::sin);
629
// Element-wise tangent.
element_wise!(tan, Tan, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.tan());
    Ok(())
};
q: [i8, u8, i32] => f32::tan);
635
// Element-wise arc-cosine.
element_wise!(acos, Acos, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.acos());
    Ok(())
};
q: [i8, u8, i32] => f32::acos);
641
// Element-wise arc-sine.
element_wise!(asin, Asin, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.asin());
    Ok(())
};
q: [i8, u8, i32] => f32::asin);
647
// Element-wise arc-tangent.
element_wise!(atan, Atan, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.atan());
    Ok(())
};
q: [i8, u8, i32] => f32::atan);
653
// Element-wise hyperbolic cosine.
element_wise!(cosh, Cosh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.cosh());
    Ok(())
};
q: [i8, u8, i32] => f32::cosh);
659
// Element-wise hyperbolic sine.
element_wise!(sinh, Sinh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.sinh());
    Ok(())
};
q: [i8, u8, i32] => f32::sinh);
665
// Element-wise hyperbolic tangent. f16/f32 dispatch to vectorized
// tract-linalg kernels; f64 uses the scalar std implementation.
element_wise!(tanh, Tanh,
 [f16] => |_, xs| { (tract_linalg::ops().tanh_f16)().run(xs) },
 [f32] => |_, xs| { (tract_linalg::ops().tanh_f32)().run(xs) },
 [f64] => |_, xs| { xs.iter_mut().for_each(|x| *x = x.tanh()); Ok(()) };
 q: [i8, u8, i32] => f32::tanh;
 cost: |dt| {tvec!((Cost::FMA(dt), 11), (Cost::Div(dt), 1))}
);
673
// Element-wise error function via the tract-linalg f32 kernel. There is no
// f16 kernel, so f16 slices are widened to a temporary f32 buffer, processed,
// then narrowed back in place.
element_wise!(erf, Erf,
 [f32] => |_, xs| { (tract_linalg::ops().erf_f32)().run(xs) },
 [f16] => |_, xs| {
     let mut f32s = xs.iter().map(|x| x.to_f32()).collect_vec();
     (tract_linalg::ops().erf_f32)().run(&mut f32s)?;
     xs.iter_mut().zip(f32s.into_iter()).for_each(|(x, f)| *x = f16::from_f32(f));
     Ok(())
};
 cost: |dt| {tvec!((Cost::FMA(dt), 11), (Cost::Div(dt), 1))}
);
684
// Element-wise inverse hyperbolic cosine.
element_wise!(acosh, Acosh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.acosh());
    Ok(())
};
q: [i8, u8, i32] => f32::acosh);
// Element-wise inverse hyperbolic sine.
element_wise!(asinh, Asinh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.asinh());
    Ok(())
};
q: [i8, u8, i32] => f32::asinh);
695element_wise!(atanh, Atanh, [f16, f32, f64] => |_, xs| {
696    xs.iter_mut().for_each(|x| *x = x.atanh());
697    Ok(())
698};
699q: [i8, u8, i32] => f32::atanh);
700
701element_wise!(neg, Neg, [i8, i16, i32, i64, f16, f32, f64, TDim] => |_, xs| {
702    xs.iter_mut().for_each(|x| *x = -x.clone());
703    Ok(())
704};
705q: [i8, u8, i32] => |x: f32| -x);
706
707element_wise!(sign, Sign, [f16, f32, f64] => |_, xs| {
708    xs.iter_mut().for_each(|x| *x = if x.is_zero() { *x } else { x.signum() });
709    Ok(())
710};
711q: [i8, u8, i32] => f32::signum);
712
713element_wise_oop!(is_inf, IsInf { detect_positive: bool, detect_negative: bool },
714    [f32] => bool |op, xs, ys| {
715        xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
716            *y = (op.detect_positive && *x == f32::INFINITY) || (op.detect_negative && *x == f32::NEG_INFINITY)
717        );
718        Ok(())
719    },
720    [f16] => bool |op, xs, ys| {
721        xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
722            *y = (op.detect_positive && *x == f16::INFINITY) || (op.detect_negative && *x == f16::NEG_INFINITY)
723        );
724        Ok(())
725    }
726);
727
728element_wise_oop!(is_nan, IsNan,
729    [f16, f32] => bool |_, xs, ys| {
730        xs.iter().zip(ys.iter_mut()).for_each(|(x,y)| *y = x.is_nan());
731        Ok(())
732    }
733);
734
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::binary::TypedBinOp;
    use ndarray::arr2;

    // Elementwise product of two 2x2 ndarray arrays.
    #[test]
    fn test_mul() {
        let lhs = arr2(&[[1., 2.], [3., 4.]]);
        let rhs = arr2(&[[1., 0.], [0., 0.]]);
        assert_eq!(lhs * rhs, arr2(&[[1., 0.], [0., 0.]]));
    }

    // Matrix product, as opposed to the elementwise product above.
    #[test]
    fn dot() {
        let lhs = arr2(&[[1., 2.], [3., 4.]]);
        let rhs = arr2(&[[1., 0.], [0., 0.]]);
        assert_eq!(lhs.dot(&rhs), arr2(&[[1., 0.], [3., 0.]]));
    }

    // Decluttering rewrites a multiplication by a power-of-two constant into
    // a ShiftLeft, and the numerical result is unchanged.
    #[test]
    fn mul_as_shift_left() -> TractResult<()> {
        let mut model = TypedModel::default();
        let x = model.add_source("x", i32::fact([2usize, 2]))?;
        let a = model.add_const("a", tensor0(4i32).broadcast_into_rank(2)?.into_arc_tensor())?;
        let y = model.wire_node("y", mul(), &[x, a])?[0];
        model.set_output_outlets(&[y])?;
        let input = tensor2(&[[1, 2], [3, 4]]);
        let expected = tensor2(&[[4, 8], [12, 16]]);
        // Run before decluttering...
        let outputs = SimplePlan::new(model.clone())?.run(tvec!(input.clone().into()))?;
        assert_eq!(*outputs[0], expected);
        // ...and after: same values, but the op must now be a ShiftLeft.
        let decluttered = model.into_decluttered()?;
        let outputs = SimplePlan::new(decluttered.clone())?.run(tvec!(input.into()))?;
        assert_eq!(*outputs[0], expected);
        let op = decluttered
            .node(decluttered.output_outlets()?[0].node)
            .op()
            .downcast_ref::<TypedBinOp>()
            .unwrap();
        assert!(op.0.downcast_ref::<ShiftLeft>().is_some());
        Ok(())
    }

    // Decluttering rewrites an integer division by a power-of-two constant
    // into a ShiftRight, and the numerical result is unchanged.
    #[test]
    fn div_as_shift() -> TractResult<()> {
        let mut model = TypedModel::default();
        let x = model.add_source("a", i32::fact([2usize, 2]))?;
        let s = model.add_const("shift", tensor2(&[[4]]))?;
        let y = model.wire_node("c", div(), [x, s].as_ref())?[0];
        model.set_output_outlets(&[y])?;
        let input = tensor2(&[[16, 32], [64, 68]]);
        let expected = tensor2(&[[4, 8], [16, 17]]);
        // Run before decluttering...
        let outputs = SimplePlan::new(model.clone())?.run(tvec!(input.clone().into()))?;
        assert_eq!(*outputs[0], expected);
        // ...and after: same values, but the op must now be a ShiftRight.
        let decluttered = model.into_decluttered()?;
        let outputs = SimplePlan::new(decluttered.clone())?.run(tvec!(input.into()))?;
        assert_eq!(*outputs[0], expected);
        let op = decluttered
            .node(decluttered.output_outlets()?[0].node)
            .op()
            .downcast_ref::<TypedBinOp>()
            .unwrap();
        assert!(op.0.downcast_ref::<ShiftRight>().is_some());
        Ok(())
    }
}