1#![allow(clippy::clone_on_copy)]
2#![allow(clippy::unnecessary_cast)]
3#![allow(clippy::blocks_in_conditions)]
4
5use super::array::MultiBroadcastTo;
6use super::binary::TypedBinOp;
7use crate::internal::*;
8use crate::ops::quant::scale_by;
9use num_traits::bounds::Bounded;
10use num_traits::int::PrimInt;
11use num_traits::{Float, Zero};
12use tract_data::internal::ClampCast;
13use tract_data::itertools::Itertools;
14pub use tract_data::prelude::round_ties_to_even;
15use tract_linalg::{ScaleShiftAndRound, Scaler};
16use tract_num_traits::AsPrimitive;
17
18#[cfg(feature = "complex")]
19mod complex;
20#[cfg(feature = "complex")]
21pub use complex::{ComplexToInnerDim, InnerDimToComplex};
22
// Element-wise addition, declared through `bin_to_super_type!` as the `add` / `Add` pair.
// The last arm is the per-element kernel (`c = a + b`) for the listed types, which
// include `TDim` and `String`. The `q:` arm hands the listed quantized storage types to
// `add_quant`, and `q_op_on_f32` gives the same operation expressed on `f32`.
// NOTE(review): the precise meaning of `linalg`, `neutral_element` and `validation` is
// defined by the macro, which is not visible from this file.
bin_to_super_type!(add, Add,
    linalg: Add,
    neutral_element: 0,
    validation: Validation::Rounding,
    q: [i8, u8, i32, i32] => add_quant;
    q_op_on_f32: |a: f32, b: f32| -> f32 {a+b},
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim, String] => |c, a, b| *c = a.clone() + b);
30
31fn add_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
32where
33 T: PrimInt + Bounded + AsPrimitive<i64> + Datum,
34 i64: AsPrimitive<T>,
35{
36 *c = (a.as_() + b.as_() - zp as i64).clamp_cast()
37}
38
// Element-wise subtraction `a - b`, declared through `bin_to_super_type!` as `sub` / `Sub`.
// The op is flagged non-commutative. The `q:` arm hands the listed quantized storage
// types to `sub_quant`; `q_op_on_f32` gives the same operation expressed on `f32`.
bin_to_super_type!(sub, Sub,
    linalg:Sub,
    is_commutative: false,
    neutral_element: 0,
    q: [i8, u8, i32, i32] => sub_quant;
    q_op_on_f32: |a: f32, b: f32| -> f32 {a-b},
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() - b);
46
// "Flipped" element-wise subtraction: computes `b - a`, declared as `subf` / `SubF`.
// Mirrors `sub` with the operands swapped; quantized storage types go to `subf_quant`.
bin_to_super_type!(subf, SubF,
    linalg:SubF,
    is_commutative: false,
    neutral_element: 0,
    q: [i8, u8, i32, i32] => subf_quant;
    q_op_on_f32: |a: f32, b: f32| -> f32 {b - a},
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = b.clone() - a);
54
55fn sub_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
56where
57 T: PrimInt + Bounded + AsPrimitive<i16> + Datum,
58 i16: AsPrimitive<T>,
59{
60 *c = (a.as_() - b.as_() + zp as i16).clamp_cast()
61}
62
63fn subf_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
64where
65 T: PrimInt + Bounded + AsPrimitive<i16> + Datum,
66 i16: AsPrimitive<T>,
67{
68 *c = (b.as_() - a.as_() + zp as i16).clamp_cast()
69}
70
71bin_to_super_type!(mul, Mul,
72 cost: |dt| tvec!((Cost::FMA(dt), 1)),
73 declutter: declutter_mul,
74 eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
75 if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
77 DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
78 DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
79 (a.datum_type(), b.datum_type(), c_dt)
80 {
81 let multiplier = a_scale * b_scale * (1.0/ c_scale);
82 let a = a.to_plain_array_view::<u8>()?;
83 let b = b.to_plain_array_view::<u8>()?;
84 let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
85 let mut c = Tensor::zero_dt(c_dt, &c_shape)?;
86 let mut c_plain = c.try_as_plain_mut()?;
87 let view = c_plain.to_array_view_mut::<u8>()?;
88 crate::ndarray::Zip::from(view)
89 .and_broadcast(a)
90 .and_broadcast(b)
91 .for_each(|c,a,b| *c = (scale_by((*a as i32 - a_zp as i32) * (*b as i32 - b_zp as i32), multiplier) + c_zp as i32).clamp_cast());
92 Ok(c)
93 } else {
94 Mul.generic_eval(a, b, c_dt)
95 }
96 },
97 linalg: Mul,
98 neutral_element: 1,
99 absorbing_element: 0,
100 out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
101 if c.datum_type() == TDim::datum_type() &&
102 a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
103 let a = a.to_plain_array_view::<TDim>()?;
104 let b = b.cast_to::<i32>()?;
105 let b = b.to_plain_array_view::<i32>()?;
106 let mut c_plain = c.try_as_plain_mut()?;
107 let c = c_plain.to_array_view_mut::<TDim>()?;
108 crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() * *b);
109 Ok(true)
110 }
111 else {
112 match c.datum_type() {
113 DatumType::QI8(params) => {
114 let (zp, scale) = params.zp_scale();
115 let a = a.to_plain_array_view::<i8>()?;
116 let b = b.to_plain_array_view::<i8>()?;
117 let mut c_plain = c.try_as_plain_mut()?;
118 let c = c_plain.to_array_view_mut::<i8>()?;
119 crate::ndarray::Zip::from(c)
120 .and_broadcast(a)
121 .and_broadcast(b)
122 .for_each(|c,a,b| *c = (scale_by((*a as i16 - zp as i16) * (*b as i16 - zp as i16), scale) + zp as i16).clamp_cast());
123 Ok(true)
124 }
125 DatumType::QU8(params) => {
126 let (zp, scale) = params.zp_scale();
127 let a = a.to_plain_array_view::<u8>()?;
128 let b = b.to_plain_array_view::<u8>()?;
129 let mut c_plain = c.try_as_plain_mut()?;
130 let c = c_plain.to_array_view_mut::<u8>()?;
131 crate::ndarray::Zip::from(c)
132 .and_broadcast(a)
133 .and_broadcast(b)
134 .for_each(|c,a,b| *c = (scale_by((*a as i32 - zp as i32) * (*b as i32 - zp as i32), scale) + zp as i32).clamp_cast());
135 Ok(true)
136 }
137 _ => Ok(false)
138 }
139 }
140 },
141 q: [i8, u8, i32] => |c, a, b, zp, scale| {
142 *c = (scale_by((a.clone() as i32 - zp as i32) * (*b as i32 - zp as i32) , scale) + zp as i32).clamp_cast()
143 };
144 q_op_on_f32: |a: f32, b: f32| a * b,
145 [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = a.wrapping_mul(*b),
146 [f32, f16, f64] => |c, a, b| *c = a * b,
147 [TDim] => |c, a, b| *c = a.clone() * b
148);
149
// Element-wise division `a / b`, declared through `bin_to_super_type!` as `div` / `Div`.
//
// * `eval_override`: dedicated paths for TDim / TDim and for all-QU8 (ZpScale) operands;
//   everything else goes to `Div.generic_eval`.
// * `out_of_place`: handles the all-`TDim` case (divisor cast to `i32`) and any
//   quantized operand (computed through `f32`); returns `Ok(false)` otherwise.
// * The last arm is the plain per-element kernel for the listed numeric types.
bin_to_super_type!(div, Div,
cost: |dt| tvec!((Cost::Div(dt), 1)),
declutter: declutter_div,
eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
    if
        a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
        let a = a.to_plain_array_view::<TDim>()?;
        let b = b.to_plain_array_view::<TDim>()?;
        let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
        // NOTE(review): the `unsafe` block covers `Tensor::uninitialized_dt`; every
        // coordinate of the output is assigned in the loop below. Confirm that assigning
        // into that freshly created TDim storage is sound.
        unsafe {
            let a = a.broadcast(&*c_shape).unwrap();
            let b = b.broadcast(&*c_shape).unwrap();
            let mut c = Tensor::uninitialized_dt(DatumType::TDim, &c_shape)?;
            let mut c_plain = c.try_as_plain_mut()?;
            let mut view = c_plain.to_array_view_mut::<TDim>()?;
            for coords in crate::ndarray::indices(&*c_shape) {
                // `maybe_div` yields a (numerator, denominator) pair or an error.
                let (p, q) = a[&coords].maybe_div(&b[&coords])?;
                view[&coords] = p/q;
            }
            Ok(c)
        }
    } else if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
                   DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
                   DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
        (a.datum_type(), b.datum_type(), c_dt) {

        // Combined rescaling factor applied to the quotient of zero-point-corrected values.
        let multiplier = a_scale / (b_scale * c_scale);
        let a = a.to_plain_array_view::<u8>()?;
        let b = b.to_plain_array_view::<u8>()?;
        let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
        let mut c = Tensor::zero_dt(c_dt, &c_shape)?;
        let mut c_plain = c.try_as_plain_mut()?;
        let view = c_plain.to_array_view_mut::<u8>()?;
        crate::ndarray::Zip::from(view)
            .and_broadcast(a)
            .and_broadcast(b)
            .for_each(|c,a,b| *c = (
                scale_by(
                    (*a as i32 - a_zp as i32) as f32 / (*b as i32 - b_zp as i32) as f32, multiplier
                ) as i32 + c_zp as i32
            ).clamp_cast());
        Ok(c)
    } else {
        Div.generic_eval(a, b, c_dt)
    }
},
is_commutative: false,
neutral_element: 1,
out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
    if c.datum_type() == TDim::datum_type() &&
        a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
        let a = a.to_plain_array_view::<TDim>()?;
        // The divisor is cast to i32; the cast's `?` propagates any failure.
        let b = b.cast_to::<i32>()?;
        let b = b.to_plain_array_view::<i32>()?;
        let mut c_plain = c.try_as_plain_mut()?;
        let c = c_plain.to_array_view_mut::<TDim>()?;
        crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() / *b);
        Ok(true)
    } else if c.datum_type().is_quantized() || b.datum_type().is_quantized() || a.datum_type().is_quantized() {
        // Any quantized operand: cast both inputs to f32, divide, then cast the result
        // back to the output datum type.
        let a_f32 = a.cast_to::<f32>()?;
        let a_f32 = a_f32.to_plain_array_view::<f32>()?;
        let b_f32 = b.cast_to::<f32>()?;
        let b_f32 = b_f32.to_plain_array_view::<f32>()?;
        let c_f32 = &a_f32 / &b_f32;
        *c = c_f32.into_tensor().cast_to_dt(c.datum_type())?.into_owned();
        Ok(true)
    } else {
        Ok(false)
    }
},
q_op_on_f32: |a: f32, b: f32| a / b,
[f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64] => |c, a, b| *c = a.clone() / b
);
224
// Element-wise remainder `a % b`, declared through `bin_to_super_type!` as `rem` / `Rem`.
//
// Both `eval_override` and `out_of_place` special-case TDim operands by casting the
// right-hand side to `i32` and computing `TDim % i32`; every other case falls back to
// `Rem.generic_eval` (resp. `Ok(false)`), which ends up in the per-element kernel of
// the last arm.
bin_to_super_type!(rem, Rem,
    eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
        if
            a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
            let a = a.to_plain_array_view::<TDim>()?;
            let b = b.cast_to::<i32>()?;
            let b = b.to_plain_array_view::<i32>()?;
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
            // NOTE(review): the `unsafe` block covers `Tensor::uninitialized_dt`; the Zip
            // below assigns every output element. Confirm this is sound for TDim storage.
            unsafe {
                let mut c = Tensor::uninitialized_dt(DatumType::TDim, &c_shape)?;
                let mut c_plain = c.try_as_plain_mut()?;
                let view = c_plain.to_array_view_mut::<TDim>()?;
                crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() % *b);
                Ok(c)
            }
        } else {
            Rem.generic_eval(a,b, c_dt)
        }
    },
    out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
        if c.datum_type() == TDim::datum_type() &&
            a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
            let a = a.to_plain_array_view::<TDim>()?;
            let b = b.cast_to::<i32>()?;
            let b = b.to_plain_array_view::<i32>()?;
            let mut c_plain = c.try_as_plain_mut()?;
            let c = c_plain.to_array_view_mut::<TDim>()?;
            crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() % *b);
            Ok(true)
        } else {
            Ok(false)
        }
    },
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64] => |c, a, b| *c = a.clone() % b);
259
// Element-wise minimum, declared through `bin_to_super_type!` as `min` / `Min`.
// The quantized kernel compares the raw stored values and ignores the zero-point and
// scale arguments; floats use `min`, `TDim` uses `mini`, integers use `Ord::min`.
bin_to_super_type!(min, Min, linalg:Min,
    q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *a } else { *b };
    q_op_on_f32: |a: f32, b: f32| a.min(b),
    [f16, f32, f64] => |c,a,b| *c = a.min(*b),
    [TDim] => |c,a,b| *c = a.clone().mini(b.clone()),
    [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.min(b));
266
// Element-wise maximum, declared through `bin_to_super_type!` as `max` / `Max`.
//
// `eval_override` has a dedicated path for all-QU8 (ZpScale) operands when at least
// one input is uniform: the uniform value is rescaled into the other operand's scale,
// compared element by element, and the winner is rescaled to the output parameters.
// All other cases go to `Max.generic_eval`.
bin_to_super_type!(max, Max,
    eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
        if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
                DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
                DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
            (a.datum_type(), b.datum_type(), c_dt)
            && (a.is_uniform() || b.is_uniform()) {
            // `d` is the operand that is iterated over, `e` the uniform one. When only
            // `a` is uniform the roles are swapped so that `e` is `a`.
            let (d, d_zp, d_scale, e, e_zp, e_scale) = if a.is_uniform() && !b.is_uniform() {
                (&b, &b_zp, &b_scale, &a, &a_zp, &a_scale)
            } else {
                (&a, &a_zp, &a_scale, &b, &b_zp, &b_scale)
            };
            if e.is_uniform() {
                // Single stored value of the uniform operand.
                let e = e.cast_to::<u8>()?.try_as_plain()?.as_slice::<u8>()?[0];
                // Uniform value, zero-point removed and expressed in `d`'s scale.
                let e_val_as_d_aligned: i32 = scale_by(e as i32 - e_zp, e_scale / d_scale);
                let multiplier = d_scale * (1.0/ c_scale);
                let d = d.to_plain_array_view::<u8>()?;
                // NOTE(review): the output takes `d`'s shape, not a broadcast of both
                // input shapes — confirm this is intended when `e` has a higher rank.
                let mut c = Tensor::zero_dt(c_dt, d.shape())?;
                let mut c_plain = c.try_as_plain_mut()?;
                let view = c_plain.to_array_view_mut::<u8>()?;
                crate::ndarray::Zip::from(view)
                    .and_broadcast(d)
                    .for_each(|c,d| {
                        let d_min_zp = *d as i32 - *d_zp as i32;
                        let c_val: i32 = if d_min_zp < e_val_as_d_aligned {
                            e_val_as_d_aligned
                        } else {
                            d_min_zp
                        };
                        *c = (scale_by(c_val, multiplier) + c_zp as i32).clamp_cast();
                    });
                return Ok(c)
            }
        }
        Max.generic_eval(a, b, c_dt)
    },
    linalg:Max,
    q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *b } else { *a };
    q_op_on_f32: |a: f32, b: f32| -> f32 {a.max(b)},
    [f16, f32, f64] => |c,a,b| *c = a.max(*b),
    [TDim] => |c,a,b| *c = a.clone().maxi(b.clone()),
    [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.max(b));
312
// Element-wise power `a ^ b`, declared through `bin_to_super_type!` as `pow` / `Pow`.
// Floats use `powf`; `i32` / `i64` use the integer `pow` with the exponent cast to `u32`.
// NOTE(review): a negative integer exponent becomes a very large `u32` through the
// `as u32` cast — confirm callers never pass one.
bin_to_super_type!(pow, Pow,
    declutter: declutter_pow,
    is_commutative: false,
    neutral_element: 1,
    q_op_on_f32: |a: f32, b: f32| -> f32 {a.powf(b)},
    [f16, f32, f64] => |c,a,b| *c = a.powf(*b),
    [i32, i64] => |c,a,b| *c = a.pow(*b as u32));
320
// Element-wise left shift `a << b` for the listed integer types, declared as
// `shift_left` / `ShiftLeft`. Both operands share the same integer type.
bin_to_super_type!(shift_left, ShiftLeft,
    is_commutative: false,
    [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a << *b);
// Element-wise right shift `a >> b` for the listed integer types, declared as
// `shift_right` / `ShiftRight`. Both operands share the same integer type.
bin_to_super_type!(shift_right, ShiftRight,
    is_commutative: false,
    [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a >> *b);
327
328fn declutter_mul(
329 _op: &Mul,
330 model: &TypedModel,
331 node: &TypedNode,
332) -> TractResult<Option<TypedModelPatch>> {
333 if node.inputs[0] == node.inputs[1] && !node.outputs[0].fact.datum_type.is_quantized() {
334 return Ok(Some(TypedModelPatch::replace_single_op(
335 model,
336 node,
337 &node.inputs[0..1],
338 square(),
339 )?));
340 }
341
342 if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
343 let var_fact = model.outlet_fact(uniform.var)?;
344 if uniform.uni.cast_to_scalar::<f64>()? == 0.0 {
345 let shapes =
346 model.node_input_facts(node.id)?.iter().map(|f| &f.shape).collect::<TVec<_>>();
347 let shape: ShapeFact =
348 crate::broadcast::multi_broadcast(&shapes).context("Failed to broadcast")?.into();
349 return Ok(Some(TypedModelPatch::rewire(
350 model,
351 &[],
352 &[node.id.into()],
353 &|patch, _| {
354 let scalar = patch.add_const(
355 format!("{}.zero", node.name),
356 if uniform.uni.datum_type().is_quantized() {
357 let output_dt = node.outputs[0].fact.datum_type;
358 Arc::new(uniform.uni.clone().cast_to_dt(output_dt)?.into_owned())
359 } else {
360 uniform.uni.clone()
361 },
362 )?;
363 let op = MultiBroadcastTo::new(shape.clone());
364 patch.wire_node(&node.name, op, &[scalar])
365 },
366 )?));
367 }
368 let dt = uniform.uni.datum_type();
369 if !dt.is_quantized() {
370 let integer = uniform.uni.cast_to_scalar::<i64>()?;
372 if tensor0(integer)
373 .cast_to_dt(uniform.uni.datum_type())?
374 .close_enough(&uniform.uni, false)
375 .is_ok()
376 && uniform.uni.cast_to_scalar::<i64>()?.count_ones() == 1
377 && dt.is_integer()
378 {
379 let shift = integer.trailing_zeros();
380 return Ok(Some(TypedModelPatch::rewire(
381 model,
382 &[uniform.var],
383 &[node.id.into()],
384 &|patch, taps| {
385 let shift = patch.add_const(
386 format!("{}.shift", node.name),
387 tensor0(shift)
388 .cast_to_dt(dt)?
389 .into_owned()
390 .broadcast_into_rank(var_fact.rank())?,
391 )?;
392 patch.wire_node(&node.name, shift_left(), &[taps[0], shift])
393 },
394 )?));
395 }
396 }
397 }
398 if let Some(patch) = declutter_mul_const_mul_const(model, node)? {
399 return Ok(Some(patch));
400 }
401 Ok(None)
402}
403
404fn declutter_mul_const_mul_const(
405 model: &TypedModel,
406 node: &TypedNode,
407) -> TractResult<Option<TypedModelPatch>> {
408 let input_facts = model.node_input_facts(node.id)?;
409 rule_if_some!(const_slot = input_facts.iter().position(|f| f.konst.is_some()));
410 let prec = model.node(node.inputs[1 - const_slot].node);
411 rule_if_some!(prec_mul = prec.op_as::<TypedBinOp>());
412 rule_if!(prec.outputs[0].successors.len() <= 1);
413 rule_if!(prec_mul.0.is::<Mul>());
414 let prec_input_facts = model.node_input_facts(prec.id)?;
415 rule_if_some!(prec_const_slot = prec_input_facts.iter().position(|f| f.konst.is_some()));
416
417 let const_fact = model.outlet_fact(node.inputs[const_slot])?;
418 let prec_const_fact = model.outlet_fact(prec.inputs[prec_const_slot])?;
419 rule_if!(const_fact.shape.volume().is_one() || prec_const_fact.shape.volume().is_one());
421 rule_if!(const_fact.datum_type.is_float());
422 let result = mul()
423 .eval(tvec!(
424 const_fact.konst.clone().unwrap().into_tvalue(),
425 prec_const_fact.konst.clone().unwrap().into_tvalue()
426 ))?
427 .remove(0)
428 .into_arc_tensor();
429 let mut patch = TypedModelPatch::default();
430 let konst = patch.add_const(&prec.name, result)?;
431 let input_tap = patch.tap_model(model, prec.inputs[1 - prec_const_slot])?;
432 let wire = patch.wire_node(&node.name, mul(), &[konst, input_tap])?;
433 patch.shunt_outside(model, node.id.into(), wire[0])?;
434 Ok(Some(patch))
435}
436
437fn declutter_div(
438 _op: &Div,
439 model: &TypedModel,
440 node: &TypedNode,
441) -> TractResult<Option<TypedModelPatch>> {
442 if let &[p, q] = &*model.node_input_facts(node.id)? {
443 let dt = q.datum_type;
444 if let Some(q) = &q.uniform
445 && let Ok(integer) = q.cast_to_scalar::<i64>()
446 && tensor0(integer).cast_to_dt(dt)?.close_enough(q, false).is_ok()
447 && dt.is_integer()
448 && q.cast_to_scalar::<i64>()?.count_ones() == 1
449 {
450 let shift = integer.trailing_zeros();
451 return Ok(Some(TypedModelPatch::rewire(
452 model,
453 &[node.inputs[0]],
454 &[node.id.into()],
455 &|patch, taps| {
456 let shift = patch.add_const(
457 format!("{}.shift", node.name),
458 tensor0(shift)
459 .cast_to_dt(dt)?
460 .into_owned()
461 .broadcast_into_rank(p.rank())?,
462 )?;
463 patch.wire_node(&node.name, shift_right(), &[taps[0], shift])
464 },
465 )?));
466 }
467 if dt.is_float() {
468 return Ok(Some(TypedModelPatch::rewire(
469 model,
470 &node.inputs,
471 &[node.id.into()],
472 &|patch, taps| {
473 let q =
474 patch.wire_node(format!("{}-recip", node.name), recip(), &[taps[1]])?[0];
475 patch.wire_node(&node.name, mul(), &[taps[0], q])
476 },
477 )?));
478 }
479 }
480 Ok(None)
481}
482
483fn declutter_pow(
484 _op: &Pow,
485 model: &TypedModel,
486 node: &TypedNode,
487) -> TractResult<Option<TypedModelPatch>> {
488 let b = model.outlet_fact(node.inputs[1])?;
489 if let Some(b) = &b.uniform {
490 let b = b.cast_to_scalar::<f32>()?;
491 if b == 2.0 {
492 return Ok(Some(TypedModelPatch::replace_single_op(
493 model,
494 node,
495 &[node.inputs[0]],
496 square(),
497 )?));
498 } else if b == 0.5 {
499 return Ok(Some(TypedModelPatch::replace_single_op(
500 model,
501 node,
502 &[node.inputs[0]],
503 sqrt(),
504 )?));
505 }
506 }
507 crate::ops::nn::gelu_approximate::detect_gelu_approx(_op, model, node)
508}
509
510element_wise!(abs, Abs, [i8, i16, i32, i64, f16, f32, i32] => |_, xs| {
511 xs.iter_mut().for_each(|x| *x = x.abs());
512 Ok(())
513};
514q: [i8, u8, i32, i32] => f32::abs;
515operating_datum_type: |dt| if dt == TDim::datum_type() { i64::datum_type() } else { dt }
516);
517
518element_wise!(exp, Exp, [f16, f32, f64] => |_, xs| {
519 xs.iter_mut().for_each(|x| *x = x.exp());
520 Ok(())
521};
522q: [i8, u8, i32, i32] => f32::exp;
523validation: Validation::Rounding
524);
525
526element_wise!(ln, Ln, [f16, f32, f64] => |_, xs| {
527 xs.iter_mut().for_each(|x| *x = x.ln());
528 Ok(())
529};
530q: [i8, u8, i32, i32] => f32::ln;
531validation: Validation::Rounding
532);
533
534element_wise!(square, Square, [f16, f32, f64] => |_, xs| {
535 xs.iter_mut().for_each(|x| *x = x.powi(2));
536 Ok(())
537};
538q: [i8, u8, i32, i32] => |f : f32| f.powi(2);
539declutter: declutter_square;
540validation: Validation::Rounding
541);
542
543fn declutter_square(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
544 use super::element_wise::*;
545 if let Some(prec) = model.linear_prec(node.id)?
547 && let Some(ew) = prec.op_as::<ElementWiseOp>()
548 && ew.0.is::<Sqrt>()
549 {
550 let mut patch = TypedModelPatch::default();
551 let tap = patch.tap_model(model, prec.inputs[0])?;
552 patch.shunt_outside(model, node.id.into(), tap)?;
553 return Ok(Some(patch));
554 }
555 Ok(None)
556}
557
558element_wise!(sqrt, Sqrt, [f16, f32, f64] => |_, xs| {
559 xs.iter_mut().for_each(|x| *x = x.sqrt());
560 Ok(())
561};
562q: [i8, u8, i32, i32] => f32::sqrt;
563validation: Validation::Rounding
564);
565
566element_wise!(recip, Recip, [f16, f32, f64] => |_, xs| {
567 xs.iter_mut().for_each(|x| *x = x.recip());
568 Ok(())
569};
570q: [i8, u8, i32, i32] => f32::recip;
571cost: |dt| {tvec!((Cost::Div(dt), 1))};
572declutter: declutter_recip;
573validation: Validation::Rounding
574);
575
576fn declutter_recip(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
577 use super::element_wise::*;
578 if let Some(prec) = model.linear_prec(node.id)?
579 && let Some(ew) = prec.op_as::<ElementWiseOp>()
580 {
581 let repl = if ew.0.is::<Sqrt>() {
582 Some(rsqrt())
583 } else if ew.0.is::<Rsqrt>() {
584 Some(sqrt())
585 } else {
586 None
587 };
588 if let Some(repl) = repl {
589 let mut patch = TypedModelPatch::default();
590 let mut wire = patch.tap_model(model, prec.inputs[0])?;
591 wire = patch.wire_node(&node.name, repl, &[wire])?[0];
592 patch.shunt_outside(model, node.id.into(), wire)?;
593 return Ok(Some(patch));
594 }
595 }
596 Ok(None)
597}
598
599element_wise!(rsqrt, Rsqrt, [f16, f32, f64] => |_, xs| {
600 xs.iter_mut().for_each(|x| *x = x.sqrt().recip());
601 Ok(())
602};
603q: [i8, u8, i32] => |x : f32| x.sqrt().recip();
604validation: Validation::Rounding
605);
606
607element_wise!(ceil, Ceil, [f16, f32, f64] => |_, xs| {
608 xs.iter_mut().for_each(|x| *x = x.ceil());
609 Ok(())
610}, [i8, i16,i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
611q: [i8, u8, i32] => f32::recip);
612
613element_wise!(floor, Floor, [f16, f32, f64] => |_, xs| {
614 xs.iter_mut().for_each(|x| *x = x.floor());
615 Ok(())
616}, [i8, i16,i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
617q: [i8, u8, i32] => f32::floor);
618
619element_wise!(round, Round, [f16, f32, f64] => |_, xs| {
620 xs.iter_mut().for_each(|x| *x = x.round());
621 Ok(())
622}, [i8, i16,i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
623q: [i8, u8, i32] => f32::round);
624
625element_wise!(q_scale, QScale{scaler: Scaler},[i32] => |op, xs| {
626 xs.iter_mut().for_each(|x| *x = x.q_scale(op.scaler));
627 Ok(())
628});
629
630element_wise!(round_half_to_even, RoundHalfToEven,
631[f32] => |_, xs| {
632 xs.iter_mut().for_each(|x| *x = round_ties_to_even(*x));
633 Ok(())
634},
635[f16] => |_, xs| {
636 xs.iter_mut().for_each(|x| *x = f16::from_f32(round_ties_to_even(x.to_f32())));
637 Ok(())
638};
639q: [i8, u8, i32] => round_ties_to_even);
640
641element_wise!(cos, Cos, [f16, f32, f64] => |_, xs| {
642 xs.iter_mut().for_each(|x| *x = x.cos());
643 Ok(())
644};
645q: [i8, u8, i32] => f32::cos);
646
647element_wise!(sin, Sin, [f16, f32, f64] => |_, xs| {
648 xs.iter_mut().for_each(|x| *x = x.sin());
649 Ok(())
650};
651q: [i8, u8, i32] => f32::sin);
652
653element_wise!(tan, Tan, [f16, f32, f64] => |_, xs| {
654 xs.iter_mut().for_each(|x| *x = x.tan());
655 Ok(())
656};
657q: [i8, u8, i32] => f32::tan);
658
659element_wise!(acos, Acos, [f16, f32, f64] => |_, xs| {
660 xs.iter_mut().for_each(|x| *x = x.acos());
661 Ok(())
662};
663q: [i8, u8, i32] => f32::acos);
664
665element_wise!(asin, Asin, [f16, f32, f64] => |_, xs| {
666 xs.iter_mut().for_each(|x| *x = x.asin());
667 Ok(())
668};
669q: [i8, u8, i32] => f32::asin);
670
671element_wise!(atan, Atan, [f16, f32, f64] => |_, xs| {
672 xs.iter_mut().for_each(|x| *x = x.atan());
673 Ok(())
674};
675q: [i8, u8, i32] => f32::atan);
676
677element_wise!(cosh, Cosh, [f16, f32, f64] => |_, xs| {
678 xs.iter_mut().for_each(|x| *x = x.cosh());
679 Ok(())
680};
681q: [i8, u8, i32] => f32::cosh);
682
683element_wise!(sinh, Sinh, [f16, f32, f64] => |_, xs| {
684 xs.iter_mut().for_each(|x| *x = x.sinh());
685 Ok(())
686};
687q: [i8, u8, i32] => f32::sinh);
688
689element_wise!(tanh, Tanh,
690 [f16] => |_, xs| { (tract_linalg::ops().tanh_f16)().run(xs) },
691 [f32] => |_, xs| { (tract_linalg::ops().tanh_f32)().run(xs) },
692 [f64] => |_, xs| { xs.iter_mut().for_each(|x| *x = x.tanh()); Ok(()) };
693 q: [i8, u8, i32] => f32::tanh;
694 cost: |dt| {tvec!((Cost::FMA(dt), 11), (Cost::Div(dt), 1))}
695);
696
697element_wise!(erf, Erf,
698 [f32] => |_, xs| { (tract_linalg::ops().erf_f32)().run(xs) },
699 [f16] => |_, xs| {
700 let mut f32s = xs.iter().map(|x| x.to_f32()).collect_vec();
701 (tract_linalg::ops().erf_f32)().run(&mut f32s)?;
702 xs.iter_mut().zip(f32s.into_iter()).for_each(|(x, f)| *x = f16::from_f32(f));
703 Ok(())
704};
705 cost: |dt| {tvec!((Cost::FMA(dt), 11), (Cost::Div(dt), 1))}
706);
707
708element_wise!(acosh, Acosh, [f16, f32, f64] => |_, xs| {
709 xs.iter_mut().for_each(|x| *x = x.acosh());
710 Ok(())
711};
712q: [i8, u8, i32] => f32::acosh);
713element_wise!(asinh, Asinh, [f16, f32, f64] => |_, xs| {
714 xs.iter_mut().for_each(|x| *x = x.asinh());
715 Ok(())
716};
717q: [i8, u8, i32] => f32::asinh);
718element_wise!(atanh, Atanh, [f16, f32, f64] => |_, xs| {
719 xs.iter_mut().for_each(|x| *x = x.atanh());
720 Ok(())
721};
722q: [i8, u8, i32] => f32::atanh);
723
724element_wise!(neg, Neg, [i8, i16, i32, i64, f16, f32, f64, TDim] => |_, xs| {
725 xs.iter_mut().for_each(|x| *x = -x.clone());
726 Ok(())
727};
728q: [i8, u8, i32] => |x: f32| -x);
729
730element_wise!(sign, Sign, [f16, f32, f64] => |_, xs| {
731 xs.iter_mut().for_each(|x| *x = if x.is_zero() { *x } else { x.signum() });
732 Ok(())
733};
734q: [i8, u8, i32] => f32::signum);
735
736element_wise_oop!(is_inf, IsInf { detect_positive: bool, detect_negative: bool },
737 [f32] => bool |op, xs, ys| {
738 xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
739 *y = (op.detect_positive && *x == f32::INFINITY) || (op.detect_negative && *x == f32::NEG_INFINITY)
740 );
741 Ok(())
742 },
743 [f16] => bool |op, xs, ys| {
744 xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
745 *y = (op.detect_positive && *x == f16::INFINITY) || (op.detect_negative && *x == f16::NEG_INFINITY)
746 );
747 Ok(())
748 }
749);
750
751element_wise_oop!(is_nan, IsNan,
752 [f16, f32] => bool |_, xs, ys| {
753 xs.iter().zip(ys.iter_mut()).for_each(|(x,y)| *y = x.is_nan());
754 Ok(())
755 }
756);
757
#[cfg(test)]
mod tests {
    use crate::ops::binary::TypedBinOp;

    use super::*;
    use ndarray::arr2;

    /// ndarray's `*` is an element-wise product.
    #[test]
    fn test_mul() {
        let lhs = arr2(&[[1., 2.], [3., 4.]]);
        let rhs = arr2(&[[1., 0.], [0., 0.]]);
        let expected = arr2(&[[1., 0.], [0., 0.]]);
        assert_eq!(lhs * rhs, expected);
    }

    /// ndarray's `dot` is a matrix product.
    #[test]
    fn dot() {
        let lhs = arr2(&[[1., 2.], [3., 4.]]);
        let rhs = arr2(&[[1., 0.], [0., 0.]]);
        let expected = arr2(&[[1., 0.], [3., 0.]]);
        assert_eq!(lhs.dot(&rhs), expected);
    }

    /// Multiplying an i32 tensor by the constant 4 gives the same values before and
    /// after decluttering, and the decluttered output node is a `ShiftLeft`.
    #[test]
    fn mul_as_shift_left() -> TractResult<()> {
        let input = || tensor2(&[[1, 2], [3, 4]]);
        let expected = tensor2(&[[4, 8], [12, 16]]);

        let mut model = TypedModel::default();
        let x = model.add_source("x", i32::fact([2usize, 2]))?;
        let a = model.add_const("a", tensor0(4i32).broadcast_into_rank(2)?.into_arc_tensor())?;
        let y = model.wire_node("y", mul(), &[x, a])?[0];
        model.select_output_outlets(&[y])?;

        let raw_outputs = SimplePlan::new(model.clone())?.run(tvec!(input().into()))?;
        assert_eq!(*raw_outputs[0], expected);

        let decluttered = model.into_decluttered()?;
        let outputs = SimplePlan::new(decluttered.clone())?.run(tvec!(input().into()))?;
        assert_eq!(*outputs[0], expected);

        let output_node = decluttered.node(decluttered.output_outlets()?[0].node);
        let op = output_node.op().downcast_ref::<TypedBinOp>().unwrap();
        assert!(op.0.downcast_ref::<ShiftLeft>().is_some());
        Ok(())
    }

    /// Dividing an i32 tensor by the constant 4 gives the same values before and
    /// after decluttering, and the decluttered output node is a `ShiftRight`.
    #[test]
    fn div_as_shift() -> TractResult<()> {
        let input = || tensor2(&[[16, 32], [64, 68]]);
        let expected = tensor2(&[[4, 8], [16, 17]]);

        let mut model = TypedModel::default();
        let x = model.add_source("a", i32::fact([2usize, 2]))?;
        let s = model.add_const("shift", tensor2(&[[4]]))?;
        let y = model.wire_node("c", div(), [x, s].as_ref())?[0];
        model.select_output_outlets(&[y])?;

        let raw_outputs = SimplePlan::new(model.clone())?.run(tvec!(input().into()))?;
        assert_eq!(*raw_outputs[0], expected);

        let decluttered = model.into_decluttered()?;
        let outputs = SimplePlan::new(decluttered.clone())?.run(tvec!(input().into()))?;
        assert_eq!(*outputs[0], expected);

        let output_node = decluttered.node(decluttered.output_outlets()?[0].node);
        let op = output_node.op().downcast_ref::<TypedBinOp>().unwrap();
        assert!(op.0.downcast_ref::<ShiftRight>().is_some());
        Ok(())
    }
}