#![allow(clippy::clone_on_copy)]
#![allow(clippy::unnecessary_cast)]
#![allow(clippy::blocks_in_conditions)]

use super::array::MultiBroadcastTo;
use super::binary::TypedBinOp;
use crate::internal::*;
use crate::ops::quant::scale_by;
use num_traits::bounds::Bounded;
use num_traits::int::PrimInt;
use num_traits::{Float, Zero};
use tract_data::internal::ClampCast;
use tract_data::itertools::Itertools;
pub use tract_data::prelude::round_ties_to_even;
use tract_linalg::{ScaleShiftAndRound, Scaler};
use tract_num_traits::AsPrimitive;

#[cfg(feature = "complex")]
mod complex;
#[cfg(feature = "complex")]
pub use complex::{ComplexToInnerDim, InnerDimToComplex};

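// Add: bound to the linalg Add kernel, neutral element 0. Quantized i8/u8/i32
// storage goes through `add_quant` below; `q_op_on_f32` gives the operation's
// plain f32 form.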
bin_to_super_type!(add, Add,
    linalg: Add,
    neutral_element: 0,
    validation: Validation::Rounding,
    q: [i8, u8, i32, i32] => add_quant;
    q_op_on_f32: |a: f32, b: f32| -> f32 { a + b },
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim, String] => |c, a, b| *c = a.clone() + b);

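/// Quantized addition on the integer storage type: widen to i64, add, then
/// subtract the zero point once (both operands carry it) and clamp back to T.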
fn add_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
where
    T: PrimInt + Bounded + AsPrimitive<i64> + Datum,
    i64: AsPrimitive<T>,
{
    *c = (a.as_() + b.as_() - zp as i64).clamp_cast()
}

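// Sub computes a - b, SubF the flipped b - a; both are non-commutative with
// neutral element 0.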
bin_to_super_type!(sub, Sub,
    linalg: Sub,
    is_commutative: false,
    neutral_element: 0,
    q: [i8, u8, i32, i32] => sub_quant;
    q_op_on_f32: |a: f32, b: f32| -> f32 { a - b },
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() - b);

bin_to_super_type!(subf, SubF,
    linalg: SubF,
    is_commutative: false,
    neutral_element: 0,
    q: [i8, u8, i32, i32] => subf_quant;
    q_op_on_f32: |a: f32, b: f32| -> f32 { b - a },
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = b.clone() - a);

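/// Quantized subtraction on the integer storage type: the zero points cancel
/// in a - b, so only the output zero point is added back before clamping.
/// `subf_quant` below is the operand-flipped variant (b - a).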
fn sub_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
where
    T: PrimInt + Bounded + AsPrimitive<i16> + Datum,
    i16: AsPrimitive<T>,
{
    *c = (a.as_() - b.as_() + zp as i16).clamp_cast()
}

fn subf_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
where
    T: PrimInt + Bounded + AsPrimitive<i16> + Datum,
    i16: AsPrimitive<T>,
{
    *c = (b.as_() - a.as_() + zp as i16).clamp_cast()
}

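// Mul: costed as one FMA per element. The eval override has a fast path for
// QU8 * QU8 -> QU8 with asymmetric zero points: the integer product of the
// re-centered operands is rescaled by a_scale * b_scale / c_scale (scale_by),
// shifted back onto the output zero point and clamped.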
bin_to_super_type!(mul, Mul,
    cost: |dt| tvec!((Cost::FMA(dt), 1)),
    declutter: declutter_mul,
    eval_override: |a: TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
        if let (DatumType::QU8(QParams::ZpScale { zero_point: a_zp, scale: a_scale }),
                DatumType::QU8(QParams::ZpScale { zero_point: b_zp, scale: b_scale }),
                DatumType::QU8(QParams::ZpScale { zero_point: c_zp, scale: c_scale })) =
            (a.datum_type(), b.datum_type(), c_dt)
        {
            let multiplier = a_scale * b_scale * (1.0 / c_scale);
            let a = a.to_array_view::<u8>()?;
            let b = b.to_array_view::<u8>()?;
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
            let mut c = Tensor::zero_dt(c_dt, &c_shape)?;
            let view = c.to_array_view_mut::<u8>()?;
            crate::ndarray::Zip::from(view)
                .and_broadcast(a)
                .and_broadcast(b)
                .for_each(|c, a, b| *c = (scale_by((*a as i32 - a_zp as i32) * (*b as i32 - b_zp as i32), multiplier) + c_zp as i32).clamp_cast());
            Ok(c)
        } else {
            Mul.generic_eval(a, b, c_dt)
        }
    },
    linalg: Mul,
    neutral_element: 1,
    out_of_place: |c: &mut Tensor, a: &Tensor, b: &Tensor| -> TractResult<bool> {
        if c.datum_type() == TDim::datum_type()
            && a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
            let a = a.to_array_view::<TDim>()?;
            let b = b.cast_to::<i32>()?;
            let b = b.to_array_view::<i32>()?;
            let c = c.to_array_view_mut::<TDim>()?;
            crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| *c = a.clone() * *b);
            Ok(true)
        } else {
            match c.datum_type() {
                DatumType::QI8(params) => {
                    let (zp, scale) = params.zp_scale();
                    let a = a.to_array_view::<i8>()?;
                    let b = b.to_array_view::<i8>()?;
                    let c = c.to_array_view_mut::<i8>()?;
                    crate::ndarray::Zip::from(c)
                        .and_broadcast(a)
                        .and_broadcast(b)
                        .for_each(|c, a, b| *c = (scale_by((*a as i16 - zp as i16) * (*b as i16 - zp as i16), scale) + zp as i16).clamp_cast());
                    Ok(true)
                }
                DatumType::QU8(params) => {
                    let (zp, scale) = params.zp_scale();
                    let a = a.to_array_view::<u8>()?;
                    let b = b.to_array_view::<u8>()?;
                    let c = c.to_array_view_mut::<u8>()?;
                    crate::ndarray::Zip::from(c)
                        .and_broadcast(a)
                        .and_broadcast(b)
                        .for_each(|c, a, b| *c = (scale_by((*a as i32 - zp as i32) * (*b as i32 - zp as i32), scale) + zp as i32).clamp_cast());
                    Ok(true)
                }
                _ => Ok(false)
            }
        }
    },
    q: [i8, u8, i32] => |c, a, b, zp, scale| {
        *c = (scale_by((a.clone() as i32 - zp as i32) * (*b as i32 - zp as i32), scale) + zp as i32).clamp_cast()
    };
    q_op_on_f32: |a: f32, b: f32| a * b,
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() * b
);

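// Div: the eval override handles TDim / TDim through `TDim::maybe_div`, and
// QU8 / QU8 by dividing the re-centered values in f32 and rescaling by
// a_scale / (b_scale * c_scale). `declutter_div` turns division by a power of
// two into a right shift and float division into a multiplication by the
// reciprocal.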
bin_to_super_type!(div, Div,
    cost: |dt| tvec!((Cost::Div(dt), 1)),
    declutter: declutter_div,
    eval_override: |a: TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
        if a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
            let a = a.to_array_view::<TDim>()?;
            let b = b.to_array_view::<TDim>()?;
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
            unsafe {
                let a = a.broadcast(&*c_shape).unwrap();
                let b = b.broadcast(&*c_shape).unwrap();
                let mut c = Tensor::uninitialized_dt(DatumType::TDim, &c_shape)?;
                let mut view = c.to_array_view_mut::<TDim>()?;
                for coords in crate::ndarray::indices(&*c_shape) {
                    let (p, q) = a[&coords].maybe_div(&b[&coords])?;
                    view[&coords] = p / q;
                }
                Ok(c)
            }
        } else if let (DatumType::QU8(QParams::ZpScale { zero_point: a_zp, scale: a_scale }),
                DatumType::QU8(QParams::ZpScale { zero_point: b_zp, scale: b_scale }),
                DatumType::QU8(QParams::ZpScale { zero_point: c_zp, scale: c_scale })) =
            (a.datum_type(), b.datum_type(), c_dt) {
            let multiplier = a_scale / (b_scale * c_scale);
            let a = a.to_array_view::<u8>()?;
            let b = b.to_array_view::<u8>()?;
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
            let mut c = Tensor::zero_dt(c_dt, &c_shape)?;
            let view = c.to_array_view_mut::<u8>()?;
            crate::ndarray::Zip::from(view)
                .and_broadcast(a)
                .and_broadcast(b)
                .for_each(|c, a, b| *c = (
                    scale_by(
                        (*a as i32 - a_zp as i32) as f32 / (*b as i32 - b_zp as i32) as f32, multiplier
                    ) as i32 + c_zp as i32
                ).clamp_cast());
            Ok(c)
        } else {
            Div.generic_eval(a, b, c_dt)
        }
    },
    is_commutative: false,
    neutral_element: 1,
    out_of_place: |c: &mut Tensor, a: &Tensor, b: &Tensor| -> TractResult<bool> {
        if c.datum_type() == TDim::datum_type()
            && a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
            let a = a.to_array_view::<TDim>()?;
            let b = b.cast_to::<i32>()?;
            let b = b.to_array_view::<i32>()?;
            let c = c.to_array_view_mut::<TDim>()?;
            crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| *c = a.clone() / *b);
            Ok(true)
        } else if c.datum_type().is_quantized() || b.datum_type().is_quantized() || a.datum_type().is_quantized() {
            let a_f32 = a.cast_to::<f32>()?;
            let a_f32 = a_f32.to_array_view::<f32>()?;
            let b_f32 = b.cast_to::<f32>()?;
            let b_f32 = b_f32.to_array_view::<f32>()?;
            let c_f32 = &a_f32 / &b_f32;
            *c = c_f32.into_tensor().cast_to_dt(c.datum_type())?.into_owned();
            Ok(true)
        } else {
            Ok(false)
        }
    },
    q_op_on_f32: |a: f32, b: f32| a / b,
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64] => |c, a, b| *c = a.clone() / b
);

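// Rem: remainder. For TDim operands the divisor is cast to i32 first, both in
// the eval override and in the out-of-place kernel.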
bin_to_super_type!(rem, Rem,
    eval_override: |a: TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
        if a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
            let a = a.to_array_view::<TDim>()?;
            let b = b.cast_to::<i32>()?;
            let b = b.to_array_view::<i32>()?;
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
            unsafe {
                let mut c = Tensor::uninitialized_dt(DatumType::TDim, &c_shape)?;
                let view = c.to_array_view_mut::<TDim>()?;
                crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| *c = a.clone() % *b);
                Ok(c)
            }
        } else {
            Rem.generic_eval(a, b, c_dt)
        }
    },
    out_of_place: |c: &mut Tensor, a: &Tensor, b: &Tensor| -> TractResult<bool> {
        if c.datum_type() == TDim::datum_type()
            && a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
            let a = a.to_array_view::<TDim>()?;
            let b = b.cast_to::<i32>()?;
            let b = b.to_array_view::<i32>()?;
            let c = c.to_array_view_mut::<TDim>()?;
            crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| *c = a.clone() % *b);
            Ok(true)
        } else {
            Ok(false)
        }
    },
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64] => |c, a, b| *c = a.clone() % b);

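// Min: quantized storage types compare the stored integers directly (the zero
// point and scale arguments are ignored), floats use `min`, TDim uses `mini`.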
bin_to_super_type!(min, Min, linalg: Min,
    q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *a } else { *b };
    q_op_on_f32: |a: f32, b: f32| a.min(b),
    [f16, f32, f64] => |c, a, b| *c = a.min(*b),
    [TDim] => |c, a, b| *c = a.clone().mini(b.clone()),
    [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.min(b));

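// Max: same structure as Min, plus an eval-override fast path for QU8 inputs
// when one side is a uniform (single-valued) tensor: that scalar is realigned
// onto the other operand's scale once, the element-wise max is taken on the
// re-centered values, and the result is rescaled onto the output parameters.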
bin_to_super_type!(max, Max,
    eval_override: |a: TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
        if let (DatumType::QU8(QParams::ZpScale { zero_point: a_zp, scale: a_scale }),
                DatumType::QU8(QParams::ZpScale { zero_point: b_zp, scale: b_scale }),
                DatumType::QU8(QParams::ZpScale { zero_point: c_zp, scale: c_scale })) =
            (a.datum_type(), b.datum_type(), c_dt)
        {
            if a.is_uniform() || b.is_uniform() {
                let (d, d_zp, d_scale, e, e_zp, e_scale) = if a.is_uniform() && !b.is_uniform() {
                    (&b, &b_zp, &b_scale, &a, &a_zp, &a_scale)
                } else {
                    (&a, &a_zp, &a_scale, &b, &b_zp, &b_scale)
                };
                if e.is_uniform() {
                    let e = e.cast_to::<u8>()?.as_slice::<u8>()?[0];
                    let e_val_as_d_aligned: i32 = scale_by(e as i32 - e_zp, e_scale / d_scale);
                    let multiplier = d_scale * (1.0 / c_scale);
                    let d = d.to_array_view::<u8>()?;
                    let mut c = Tensor::zero_dt(c_dt, d.shape())?;
                    let view = c.to_array_view_mut::<u8>()?;
                    crate::ndarray::Zip::from(view)
                        .and_broadcast(d)
                        .for_each(|c, d| {
                            let d_min_zp = *d as i32 - *d_zp as i32;
                            let c_val: i32 = if d_min_zp < e_val_as_d_aligned {
                                e_val_as_d_aligned
                            } else {
                                d_min_zp
                            };
                            *c = (scale_by(c_val, multiplier) + c_zp as i32).clamp_cast();
                        });
                    return Ok(c)
                }
            }
        }
        Max.generic_eval(a, b, c_dt)
    },
    linalg: Max,
    q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *b } else { *a };
    q_op_on_f32: |a: f32, b: f32| -> f32 { a.max(b) },
    [f16, f32, f64] => |c, a, b| *c = a.max(*b),
    [TDim] => |c, a, b| *c = a.clone().maxi(b.clone()),
    [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.max(b));

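// Pow: non-commutative, neutral element 1. `declutter_pow` specializes a
// uniform exponent of 2.0 into `square()` and of 0.5 into `sqrt()`.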
bin_to_super_type!(pow, Pow,
    declutter: declutter_pow,
    is_commutative: false,
    neutral_element: 1,
    q_op_on_f32: |a: f32, b: f32| -> f32 { a.powf(b) },
    [f16, f32, f64] => |c, a, b| *c = a.powf(*b),
    [i32, i64] => |c, a, b| *c = a.pow(*b as u32));
bin_to_super_type!(shift_left, ShiftLeft,
    is_commutative: false,
    [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a << *b);
bin_to_super_type!(shift_right, ShiftRight,
    is_commutative: false,
    [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a >> *b);

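/// Declutter rules for Mul:
/// * `x * x` (same outlet on both inputs, non-quantized output) becomes `square(x)`;
/// * `x * 0` with a uniform zero becomes a broadcast constant zero;
/// * `x * 2^k` on integer types becomes `x << k`;
/// * two consecutive multiplications by constants are folded into one
///   (see `declutter_mul_const_mul_const`).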
fn declutter_mul(
    _op: &Mul,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    if node.inputs[0] == node.inputs[1] && !node.outputs[0].fact.datum_type.is_quantized() {
        return Ok(Some(TypedModelPatch::replace_single_op(
            model,
            node,
            &node.inputs[0..1],
            square(),
        )?));
    }

    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
        let var_fact = model.outlet_fact(uniform.var)?;
        if uniform.uni.cast_to_scalar::<f64>()? == 0.0 {
            let shapes =
                model.node_input_facts(node.id)?.iter().map(|f| &f.shape).collect::<TVec<_>>();
            let shape: ShapeFact =
                crate::broadcast::multi_broadcast(&shapes).context("Failed to broadcast")?.into();
            return Ok(Some(TypedModelPatch::rewire(
                model,
                &[],
                &[node.id.into()],
                &|patch, _| {
                    let scalar = patch.add_const(
                        format!("{}.zero", node.name),
                        if uniform.uni.datum_type().is_quantized() {
                            let output_dt = node.outputs[0].fact.datum_type;
                            Arc::new(uniform.uni.clone().cast_to_dt(output_dt)?.into_owned())
                        } else {
                            uniform.uni.clone()
                        },
                    )?;
                    let op = MultiBroadcastTo::new(shape.clone());
                    patch.wire_node(&node.name, op, &[scalar])
                },
            )?));
        }
        let dt = uniform.uni.datum_type();
        if !dt.is_quantized() {
            let integer = uniform.uni.cast_to_scalar::<i64>()?;
            if tensor0(integer)
                .cast_to_dt(uniform.uni.datum_type())?
                .close_enough(&uniform.uni, false)
                .is_ok()
                && uniform.uni.cast_to_scalar::<i64>()?.count_ones() == 1
                && dt.is_integer()
            {
                let shift = integer.trailing_zeros();
                return Ok(Some(TypedModelPatch::rewire(
                    model,
                    &[uniform.var],
                    &[node.id.into()],
                    &|patch, taps| {
                        let shift = patch.add_const(
                            format!("{}.shift", node.name),
                            tensor0(shift)
                                .cast_to_dt(dt)?
                                .into_owned()
                                .broadcast_into_rank(var_fact.rank())?,
                        )?;
                        patch.wire_node(&node.name, shift_left(), &[taps[0], shift])
                    },
                )?));
            }
        }
    }
    if let Some(patch) = declutter_mul_const_mul_const(model, node)? {
        return Ok(Some(patch));
    }
    Ok(None)
}

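/// Folds `(x * k1) * k2` into `x * (k1 * k2)` when the outer node has a constant
/// input, its other input comes from a single-use Mul with a constant input of
/// its own, at least one of the two constants is a scalar, and the constants are
/// floats.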
fn declutter_mul_const_mul_const(
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let input_facts = model.node_input_facts(node.id)?;
    let Some(const_slot) = input_facts.iter().position(|f| f.konst.is_some()) else {
        return Ok(None);
    };
    let prec = model.node(node.inputs[1 - const_slot].node);
    let Some(prec_mul) = prec.op_as::<TypedBinOp>() else {
        return Ok(None);
    };
    if prec.outputs[0].successors.len() > 1 {
        return Ok(None);
    };
    if !prec_mul.0.is::<Mul>() {
        return Ok(None);
    }
    let prec_input_facts = model.node_input_facts(prec.id)?;
    let Some(prec_const_slot) = prec_input_facts.iter().position(|f| f.konst.is_some()) else {
        return Ok(None);
    };

    let const_fact = model.outlet_fact(node.inputs[const_slot])?;
    let prec_const_fact = model.outlet_fact(prec.inputs[prec_const_slot])?;
    if !const_fact.shape.volume().is_one() && !prec_const_fact.shape.volume().is_one() {
        return Ok(None);
    }
    if !const_fact.datum_type.is_float() {
        return Ok(None);
    }
    let result = mul()
        .eval(tvec!(
            const_fact.konst.clone().unwrap().into_tvalue(),
            prec_const_fact.konst.clone().unwrap().into_tvalue()
        ))?
        .remove(0)
        .into_arc_tensor();
    let mut patch = TypedModelPatch::default();
    let konst = patch.add_const(&prec.name, result)?;
    let input_tap = patch.tap_model(model, prec.inputs[1 - prec_const_slot])?;
    let wire = patch.wire_node(&node.name, mul(), &[konst, input_tap])?;
    patch.shunt_outside(model, node.id.into(), wire[0])?;
    Ok(Some(patch))
}

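/// Declutter rules for Div: an integer division by a uniform power of two
/// becomes a right shift, and a float division becomes a multiplication by the
/// reciprocal of the divisor (one `recip` node feeding a `mul`).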
fn declutter_div(
    _op: &Div,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    if let &[p, q] = &*model.node_input_facts(node.id)? {
        let dt = q.datum_type;
        if let Some(q) = &q.uniform {
            if let Ok(integer) = q.cast_to_scalar::<i64>() {
                if tensor0(integer).cast_to_dt(dt)?.close_enough(q, false).is_ok()
                    && dt.is_integer()
                    && q.cast_to_scalar::<i64>()?.count_ones() == 1
                {
                    let shift = integer.trailing_zeros();
                    return Ok(Some(TypedModelPatch::rewire(
                        model,
                        &[node.inputs[0]],
                        &[node.id.into()],
                        &|patch, taps| {
                            let shift = patch.add_const(
                                format!("{}.shift", node.name),
                                tensor0(shift)
                                    .cast_to_dt(dt)?
                                    .into_owned()
                                    .broadcast_into_rank(p.rank())?,
                            )?;
                            patch.wire_node(&node.name, shift_right(), &[taps[0], shift])
                        },
                    )?));
                }
            }
        }
        if dt.is_float() {
            return Ok(Some(TypedModelPatch::rewire(
                model,
                &node.inputs,
                &[node.id.into()],
                &|patch, taps| {
                    let q =
                        patch.wire_node(format!("{}-recip", node.name), recip(), &[taps[1]])?[0];
                    patch.wire_node(&node.name, mul(), &[taps[0], q])
                },
            )?));
        }
    }
    Ok(None)
}

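/// Declutter rule for Pow: a uniform exponent equal to 2.0 replaces the node
/// with `square()`, and 0.5 replaces it with `sqrt()`.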
fn declutter_pow(
    _op: &Pow,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let b = model.outlet_fact(node.inputs[1])?;
    if let Some(b) = &b.uniform {
        let b = b.cast_to_scalar::<f32>()?;
        if b == 2.0 {
            return Ok(Some(TypedModelPatch::replace_single_op(
                model,
                node,
                &[node.inputs[0]],
                square(),
            )?));
        } else if b == 0.5 {
            return Ok(Some(TypedModelPatch::replace_single_op(
                model,
                node,
                &[node.inputs[0]],
                sqrt(),
            )?));
        }
    }
    Ok(None)
}

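// Unary element-wise ops. The `q:` arms supply an f32 formula for the quantized
// storage types; `abs` maps TDim inputs to i64 via `operating_datum_type`.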
element_wise!(abs, Abs, [i8, i16, i32, i64, f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.abs());
    Ok(())
};
q: [i8, u8, i32, i32] => f32::abs;
operating_datum_type: |dt| if dt == TDim::datum_type() { i64::datum_type() } else { dt }
);

element_wise!(exp, Exp, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.exp());
    Ok(())
};
q: [i8, u8, i32, i32] => f32::exp;
validation: Validation::Rounding
);

element_wise!(ln, Ln, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.ln());
    Ok(())
};
q: [i8, u8, i32, i32] => f32::ln;
validation: Validation::Rounding
);

element_wise!(square, Square, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.powi(2));
    Ok(())
};
q: [i8, u8, i32, i32] => |f: f32| f.powi(2);
validation: Validation::Rounding
);

element_wise!(sqrt, Sqrt, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.sqrt());
    Ok(())
};
q: [i8, u8, i32, i32] => f32::sqrt;
validation: Validation::Rounding
);

element_wise!(recip, Recip, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.recip());
    Ok(())
};
q: [i8, u8, i32, i32] => f32::recip;
cost: |dt| {tvec!((Cost::Div(dt), 1))};
declutter: declutter_recip;
validation: Validation::Rounding
);

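/// Declutter rule for Recip: rewrites `recip(sqrt(x))` as `rsqrt(x)` and
/// `recip(rsqrt(x))` as `sqrt(x)` when the direct predecessor is the matching
/// element-wise op.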
fn declutter_recip(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
    use super::element_wise::*;
    if let Some(prec) = model.linear_prec(node.id)? {
        if let Some(ew) = prec.op_as::<ElementWiseOp>() {
            let repl = if ew.0.is::<Sqrt>() {
                Some(rsqrt())
            } else if ew.0.is::<Rsqrt>() {
                Some(sqrt())
            } else {
                None
            };
            if let Some(repl) = repl {
                let mut patch = TypedModelPatch::default();
                let mut wire = patch.tap_model(model, prec.inputs[0])?;
                wire = patch.wire_node(&node.name, repl, &[wire])?[0];
                patch.shunt_outside(model, node.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
    }
    Ok(None)
}

element_wise!(rsqrt, Rsqrt, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.sqrt().recip());
    Ok(())
};
q: [i8, u8, i32] => |x: f32| x.sqrt().recip();
validation: Validation::Rounding
);

element_wise!(ceil, Ceil, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.ceil());
    Ok(())
}, [i8, i16, i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
q: [i8, u8, i32] => f32::ceil);

element_wise!(floor, Floor, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.floor());
    Ok(())
}, [i8, i16, i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
q: [i8, u8, i32] => f32::floor);

element_wise!(round, Round, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.round());
    Ok(())
}, [i8, i16, i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
q: [i8, u8, i32] => f32::round);

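// QScale rescales i32 values with a precomputed tract-linalg `Scaler`
// (multiplier and rounding policy), as used when requantizing accumulators.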
element_wise!(q_scale, QScale { scaler: Scaler }, [i32] => |op, xs| {
    xs.iter_mut().for_each(|x| *x = x.q_scale(op.scaler));
    Ok(())
});

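// Banker's rounding: ties go to the nearest even value, via `round_ties_to_even`.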
element_wise!(round_half_to_even, RoundHalfToEven,
[f32] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = round_ties_to_even(*x));
    Ok(())
},
[f16] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = f16::from_f32(round_ties_to_even(x.to_f32())));
    Ok(())
};
q: [i8, u8, i32] => round_ties_to_even);

element_wise!(cos, Cos, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.cos());
    Ok(())
};
q: [i8, u8, i32] => f32::cos);

element_wise!(sin, Sin, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.sin());
    Ok(())
};
q: [i8, u8, i32] => f32::sin);

element_wise!(tan, Tan, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.tan());
    Ok(())
};
q: [i8, u8, i32] => f32::tan);

element_wise!(acos, Acos, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.acos());
    Ok(())
};
q: [i8, u8, i32] => f32::acos);

element_wise!(asin, Asin, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.asin());
    Ok(())
};
q: [i8, u8, i32] => f32::asin);

element_wise!(atan, Atan, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.atan());
    Ok(())
};
q: [i8, u8, i32] => f32::atan);

element_wise!(cosh, Cosh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.cosh());
    Ok(())
};
q: [i8, u8, i32] => f32::cosh);

element_wise!(sinh, Sinh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.sinh());
    Ok(())
};
q: [i8, u8, i32] => f32::sinh);

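// tanh runs the tract-linalg kernels for f16 and f32 and falls back to the
// scalar std function for f64; erf runs the f32 kernel, converting f16 slices
// through f32.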
element_wise!(tanh, Tanh,
    [f16] => |_, xs| { (tract_linalg::ops().tanh_f16)().run(xs) },
    [f32] => |_, xs| { (tract_linalg::ops().tanh_f32)().run(xs) },
    [f64] => |_, xs| { xs.iter_mut().for_each(|x| *x = x.tanh()); Ok(()) };
    q: [i8, u8, i32] => f32::tanh;
    cost: |dt| {tvec!((Cost::FMA(dt), 11), (Cost::Div(dt), 1))}
);

element_wise!(erf, Erf,
    [f32] => |_, xs| { (tract_linalg::ops().erf_f32)().run(xs) },
    [f16] => |_, xs| {
        let mut f32s = xs.iter().map(|x| x.to_f32()).collect_vec();
        (tract_linalg::ops().erf_f32)().run(&mut f32s)?;
        xs.iter_mut().zip(f32s.into_iter()).for_each(|(x, f)| *x = f16::from_f32(f));
        Ok(())
    };
    cost: |dt| {tvec!((Cost::FMA(dt), 11), (Cost::Div(dt), 1))}
);

element_wise!(acosh, Acosh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.acosh());
    Ok(())
};
q: [i8, u8, i32] => f32::acosh);
element_wise!(asinh, Asinh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.asinh());
    Ok(())
};
q: [i8, u8, i32] => f32::asinh);
element_wise!(atanh, Atanh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.atanh());
    Ok(())
};
q: [i8, u8, i32] => f32::atanh);

element_wise!(neg, Neg, [i8, i16, i32, i64, f16, f32, f64, TDim] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = -x.clone());
    Ok(())
};
q: [i8, u8, i32] => |x: f32| -x);

element_wise!(sign, Sign, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = if x.is_zero() { *x } else { x.signum() });
    Ok(())
};
q: [i8, u8, i32] => f32::signum);

#[cfg(test)]
mod tests {
    use crate::ops::binary::TypedBinOp;

    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_mul() {
        let a = arr2(&[[1., 2.], [3., 4.]]);
        let b = arr2(&[[1., 0.], [0., 0.]]);
        assert_eq!(a * b, arr2(&[[1., 0.], [0., 0.]]));
    }

    #[test]
    fn dot() {
        let a = arr2(&[[1., 2.], [3., 4.]]);
        let b = arr2(&[[1., 0.], [0., 0.]]);
        assert_eq!(a.dot(&b), arr2(&[[1., 0.], [3., 0.]]));
    }

    #[test]
    fn mul_as_shift_left() -> TractResult<()> {
        let mut model = TypedModel::default();
        let x = model.add_source("x", i32::fact([2usize, 2]))?;
        let a = model.add_const("a", tensor0(4i32).broadcast_into_rank(2)?.into_arc_tensor())?;
        let y = model.wire_node("y", mul(), &[x, a])?[0];
        model.set_output_outlets(&[y])?;
        let result = SimplePlan::new(&model)?.run(tvec!(tensor2(&[[1, 2], [3, 4]]).into()))?;
        assert_eq!(*result[0], tensor2(&[[4, 8], [12, 16]]));
        let decluttered = model.into_decluttered()?;
        let result =
            SimplePlan::new(&decluttered)?.run(tvec!(tensor2(&[[1, 2], [3, 4]]).into()))?;
        assert_eq!(*result[0], tensor2(&[[4, 8], [12, 16]]));
        let op = decluttered
            .node(decluttered.output_outlets()?[0].node)
            .op()
            .downcast_ref::<TypedBinOp>()
            .unwrap();
        assert!(op.0.downcast_ref::<ShiftLeft>().is_some());
        Ok(())
    }

    #[test]
    fn div_as_shift() -> TractResult<()> {
        let mut model = TypedModel::default();
        let x = model.add_source("a", i32::fact([2usize, 2]))?;
        let s = model.add_const("shift", tensor2(&[[4]]))?;
        let y = model.wire_node("c", div(), [x, s].as_ref())?[0];
        model.set_output_outlets(&[y])?;
        let result = SimplePlan::new(&model)?.run(tvec!(tensor2(&[[16, 32], [64, 68]]).into()))?;
        assert_eq!(*result[0], tensor2(&[[4, 8], [16, 17]]));
        let decluttered = model.into_decluttered()?;
        let result =
            SimplePlan::new(&decluttered)?.run(tvec!(tensor2(&[[16, 32], [64, 68]]).into()))?;
        assert_eq!(*result[0], tensor2(&[[4, 8], [16, 17]]));
        let op = decluttered
            .node(decluttered.output_outlets()?[0].node)
            .op()
            .downcast_ref::<TypedBinOp>()
            .unwrap();
        assert!(op.0.downcast_ref::<ShiftRight>().is_some());
        Ok(())
    }
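
    // A hedged sketch of a third declutter test, mirroring the two above: it
    // assumes (consistently with `declutter_pow`) that a uniform exponent of
    // 2.0 is rewritten into the element-wise `square()` op. The test name and
    // the `two` constant are illustrative additions, not taken from the
    // original file.
    #[test]
    fn pow_two_as_square() -> TractResult<()> {
        use crate::ops::element_wise::ElementWiseOp;
        let mut model = TypedModel::default();
        let x = model.add_source("x", f32::fact([2usize, 2]))?;
        let two =
            model.add_const("two", tensor0(2f32).broadcast_into_rank(2)?.into_arc_tensor())?;
        let y = model.wire_node("y", pow(), &[x, two])?[0];
        model.set_output_outlets(&[y])?;
        let decluttered = model.into_decluttered()?;
        let result =
            SimplePlan::new(&decluttered)?.run(tvec!(tensor2(&[[1f32, 2.], [3., 4.]]).into()))?;
        assert_eq!(*result[0], tensor2(&[[1f32, 4.], [9., 16.]]));
        let op = decluttered
            .node(decluttered.output_outlets()?[0].node)
            .op()
            .downcast_ref::<ElementWiseOp>()
            .unwrap();
        assert!(op.0.is::<Square>());
        Ok(())
    }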
789}