1#![allow(clippy::clone_on_copy)]
2#![allow(clippy::unnecessary_cast)]
3#![allow(clippy::blocks_in_conditions)]
4
5use super::array::MultiBroadcastTo;
6use super::binary::TypedBinOp;
7use crate::internal::*;
8use crate::ops::quant::scale_by;
9use num_traits::bounds::Bounded;
10use num_traits::int::PrimInt;
11use num_traits::{Float, Zero};
12use tract_data::internal::ClampCast;
13use tract_data::itertools::Itertools;
14pub use tract_data::prelude::round_ties_to_even;
15use tract_linalg::{ScaleShiftAndRound, Scaler};
16use tract_num_traits::AsPrimitive;
17
18#[cfg(feature = "complex")]
19mod complex;
20#[cfg(feature = "complex")]
21pub use complex::{ComplexToInnerDim, InnerDimToComplex};
22
// Element-wise addition `c = a + b`, mapped to the linalg `Add` kernel, with 0 as neutral element.
// Quantized i8/u8/i32 operands go through `add_quant`; `q_op_on_f32` gives the f32 reference used
// for quantized evaluation. Strings are supported (concatenation through `Add`).
bin_to_super_type!(add, Add,
    linalg: Add,
    neutral_element: 0,
    validation: Validation::Rounding,
    q: [i8, u8, i32, i32] => add_quant;
    q_op_on_f32: |a: f32, b: f32| -> f32 {a+b},
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim, String] => |c, a, b| *c = a.clone() + b);
30
31fn add_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
32where
33 T: PrimInt + Bounded + AsPrimitive<i64> + Datum,
34 i64: AsPrimitive<T>,
35{
36 *c = (a.as_() + b.as_() - zp as i64).clamp_cast()
37}
38
// Element-wise subtraction `c = a - b` (not commutative), mapped to the linalg `Sub` kernel.
// Quantized i8/u8/i32 operands go through `sub_quant`.
bin_to_super_type!(sub, Sub,
    linalg:Sub,
    is_commutative: false,
    neutral_element: 0,
    q: [i8, u8, i32, i32] => sub_quant;
    q_op_on_f32: |a: f32, b: f32| -> f32 {a-b},
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() - b);
46
// "Flipped" subtraction: `c = b - a`, i.e. `Sub` with its operands swapped, mapped to the linalg
// `SubF` kernel. Quantized i8/u8/i32 operands go through `subf_quant`.
bin_to_super_type!(subf, SubF,
    linalg:SubF,
    is_commutative: false,
    neutral_element: 0,
    q: [i8, u8, i32, i32] => subf_quant;
    q_op_on_f32: |a: f32, b: f32| -> f32 {b - a},
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = b.clone() - a);
54
55fn sub_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
56where
57 T: PrimInt + Bounded + AsPrimitive<i16> + Datum,
58 i16: AsPrimitive<T>,
59{
60 *c = (a.as_() - b.as_() + zp as i16).clamp_cast()
61}
62
63fn subf_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
64where
65 T: PrimInt + Bounded + AsPrimitive<i16> + Datum,
66 i16: AsPrimitive<T>,
67{
68 *c = (b.as_() - a.as_() + zp as i16).clamp_cast()
69}
70
71bin_to_super_type!(mul, Mul,
72 cost: |dt| tvec!((Cost::FMA(dt), 1)),
73 declutter: declutter_mul,
74 eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
75 if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
77 DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
78 DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
79 (a.datum_type(), b.datum_type(), c_dt)
80 {
81 let multiplier = a_scale * b_scale * (1.0/ c_scale);
82 let a = a.to_plain_array_view::<u8>()?;
83 let b = b.to_plain_array_view::<u8>()?;
84 let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
85 let mut c = Tensor::zero_dt(c_dt, &c_shape)?;
86 let mut c_plain = c.try_as_plain_mut()?;
87 let view = c_plain.to_array_view_mut::<u8>()?;
88 crate::ndarray::Zip::from(view)
89 .and_broadcast(a)
90 .and_broadcast(b)
91 .for_each(|c,a,b| *c = (scale_by((*a as i32 - a_zp as i32) * (*b as i32 - b_zp as i32), multiplier) + c_zp as i32).clamp_cast());
92 Ok(c)
93 } else {
94 Mul.generic_eval(a, b, c_dt)
95 }
96 },
97 linalg: Mul,
98 neutral_element: 1,
99 out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
100 if c.datum_type() == TDim::datum_type() &&
101 a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
102 let a = a.to_plain_array_view::<TDim>()?;
103 let b = b.cast_to::<i32>()?;
104 let b = b.to_plain_array_view::<i32>()?;
105 let mut c_plain = c.try_as_plain_mut()?;
106 let c = c_plain.to_array_view_mut::<TDim>()?;
107 crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() * *b);
108 Ok(true)
109 }
110 else {
111 match c.datum_type() {
112 DatumType::QI8(params) => {
113 let (zp, scale) = params.zp_scale();
114 let a = a.to_plain_array_view::<i8>()?;
115 let b = b.to_plain_array_view::<i8>()?;
116 let mut c_plain = c.try_as_plain_mut()?;
117 let c = c_plain.to_array_view_mut::<i8>()?;
118 crate::ndarray::Zip::from(c)
119 .and_broadcast(a)
120 .and_broadcast(b)
121 .for_each(|c,a,b| *c = (scale_by((*a as i16 - zp as i16) * (*b as i16 - zp as i16), scale) + zp as i16).clamp_cast());
122 Ok(true)
123 }
124 DatumType::QU8(params) => {
125 let (zp, scale) = params.zp_scale();
126 let a = a.to_plain_array_view::<u8>()?;
127 let b = b.to_plain_array_view::<u8>()?;
128 let mut c_plain = c.try_as_plain_mut()?;
129 let c = c_plain.to_array_view_mut::<u8>()?;
130 crate::ndarray::Zip::from(c)
131 .and_broadcast(a)
132 .and_broadcast(b)
133 .for_each(|c,a,b| *c = (scale_by((*a as i32 - zp as i32) * (*b as i32 - zp as i32), scale) + zp as i32).clamp_cast());
134 Ok(true)
135 }
136 _ => Ok(false)
137 }
138 }
139 },
140 q: [i8, u8, i32] => |c, a, b, zp, scale| {
141 *c = (scale_by((a.clone() as i32 - zp as i32) * (*b as i32 - zp as i32) , scale) + zp as i32).clamp_cast()
142 };
143 q_op_on_f32: |a: f32, b: f32| a * b,
144 [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = a.wrapping_mul(*b),
145 [f32, f16, f64] => |c, a, b| *c = a * b,
146 [TDim] => |c, a, b| *c = a.clone() * b
147);
148
// Element-wise division `c = a / b` (not commutative).
//
// * `eval_override`:
//   - TDim / TDim: each pair is reduced with `maybe_div`, and the result stores `p / q`.
//   - QU8 / QU8 -> QU8: the quotient of the zero-point-shifted values is computed in f32,
//     scaled by a_scale / (b_scale * c_scale), cast to i32, then offset by c's zero point.
//     NOTE(review): when `b == b_zp` the f32 quotient is inf/NaN. The result then depends on
//     `as i32` saturation — confirm this is acceptable.
//   - anything else: the generic kernels listed at the end.
// * `out_of_place`: TDim / TDim with the divisor cast to i32. If any operand or the output is
//   quantized, both operands are cast to f32, divided, and cast back to c's datum type.
bin_to_super_type!(div, Div,
cost: |dt| tvec!((Cost::Div(dt), 1)),
declutter: declutter_div,
eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
    if
        a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
            let a = a.to_plain_array_view::<TDim>()?;
            let b = b.to_plain_array_view::<TDim>()?;
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
            // NOTE(review): every element of `c` is assigned through `view[&coords] = ...`, which
            // drops the previous value. This presumably relies on `uninitialized_dt` yielding
            // valid TDim values — confirm.
            unsafe {
                // Cannot fail: c_shape is the broadcast of both shapes.
                let a = a.broadcast(&*c_shape).unwrap();
                let b = b.broadcast(&*c_shape).unwrap();
                let mut c = Tensor::uninitialized_dt(DatumType::TDim, &c_shape)?;
                let mut c_plain = c.try_as_plain_mut()?;
                let mut view = c_plain.to_array_view_mut::<TDim>()?;
                for coords in crate::ndarray::indices(&*c_shape) {
                    let (p, q) = a[&coords].maybe_div(&b[&coords])?;
                    view[&coords] = p/q;
                }
                Ok(c)
            }
        } else if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
                       DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
                       DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
            (a.datum_type(), b.datum_type(), c_dt) {

                let multiplier = a_scale / (b_scale * c_scale);
                let a = a.to_plain_array_view::<u8>()?;
                let b = b.to_plain_array_view::<u8>()?;
                let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
                let mut c = Tensor::zero_dt(c_dt, &c_shape)?;
                let mut c_plain = c.try_as_plain_mut()?;
                let view = c_plain.to_array_view_mut::<u8>()?;
                crate::ndarray::Zip::from(view)
                    .and_broadcast(a)
                    .and_broadcast(b)
                    .for_each(|c,a,b| *c = (
                            scale_by(
                                (*a as i32 - a_zp as i32) as f32 / (*b as i32 - b_zp as i32) as f32, multiplier
                            ) as i32 + c_zp as i32
                        ).clamp_cast());
                Ok(c)
            } else {
                Div.generic_eval(a, b, c_dt)
            }
},
is_commutative: false,
neutral_element: 1,
out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
    if c.datum_type() == TDim::datum_type() &&
        a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
            let a = a.to_plain_array_view::<TDim>()?;
            let b = b.cast_to::<i32>()?;
            let b = b.to_plain_array_view::<i32>()?;
            let mut c_plain = c.try_as_plain_mut()?;
            let c = c_plain.to_array_view_mut::<TDim>()?;
            crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() / *b);
            Ok(true)
        } else if c.datum_type().is_quantized() || b.datum_type().is_quantized() || a.datum_type().is_quantized() {
            // Quantized fallback: compute in f32, then convert to the output datum type.
            let a_f32 = a.cast_to::<f32>()?;
            let a_f32 = a_f32.to_plain_array_view::<f32>()?;
            let b_f32 = b.cast_to::<f32>()?;
            let b_f32 = b_f32.to_plain_array_view::<f32>()?;
            let c_f32 = &a_f32 / &b_f32;
            *c = c_f32.into_tensor().cast_to_dt(c.datum_type())?.into_owned();
            Ok(true)
        } else {
            Ok(false)
        }
},
q_op_on_f32: |a: f32, b: f32| a / b,
[f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64] => |c, a, b| *c = a.clone() / b
);
223
// Element-wise remainder `c = a % b`.
// TDim % TDim is handled in both `eval_override` and `out_of_place`: the divisor is cast to i32,
// which fails for non-numeric dims. Every other type uses the generic `%` kernels.
bin_to_super_type!(rem, Rem,
    eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
        if
            a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
                let a = a.to_plain_array_view::<TDim>()?;
                let b = b.cast_to::<i32>()?;
                let b = b.to_plain_array_view::<i32>()?;
                let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()]).context("no broadcast solution")?;
                // NOTE(review): the Zip assigns `*c = ...`, which drops the previous element.
                // This presumably relies on `uninitialized_dt` yielding valid TDim values — confirm.
                unsafe {
                    let mut c = Tensor::uninitialized_dt(DatumType::TDim, &c_shape)?;
                    let mut c_plain = c.try_as_plain_mut()?;
                    let view = c_plain.to_array_view_mut::<TDim>()?;
                    crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() % *b);
                    Ok(c)
                }
            } else {
                Rem.generic_eval(a,b, c_dt)
            }
    },
    out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
        if c.datum_type() == TDim::datum_type() &&
            a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
                let a = a.to_plain_array_view::<TDim>()?;
                let b = b.cast_to::<i32>()?;
                let b = b.to_plain_array_view::<i32>()?;
                let mut c_plain = c.try_as_plain_mut()?;
                let c = c_plain.to_array_view_mut::<TDim>()?;
                crate::ndarray::Zip::from(c).and_broadcast(a).and_broadcast(b).for_each(|c,a,b| *c = a.clone() % *b);
                Ok(true)
            } else {
                Ok(false)
            }
    },
    [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64] => |c, a, b| *c = a.clone() % b);
258
// Element-wise minimum, mapped to the linalg `Min` kernel. The quantized kernel compares raw
// storage values. That is presumably valid only when both operands share quantization parameters
// — confirm against the quantized dispatch.
bin_to_super_type!(min, Min, linalg:Min,
    q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *a } else { *b };
    q_op_on_f32: |a: f32, b: f32| a.min(b),
    [f16, f32, f64] => |c,a,b| *c = a.min(*b),
    [TDim] => |c,a,b| *c = a.clone().mini(b.clone()),
    [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.min(b));
265
266bin_to_super_type!(max, Max,
267 eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
268 if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
270 DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
271 DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
272 (a.datum_type(), b.datum_type(), c_dt)
273 && (a.is_uniform() || b.is_uniform()) {
274 let (d, d_zp, d_scale, e, e_zp, e_scale) = if a.is_uniform() && !b.is_uniform() {
277 (&b, &b_zp, &b_scale, &a, &a_zp, &a_scale)
278 } else {
279 (&a, &a_zp, &a_scale, &b, &b_zp, &b_scale)
280 };
281 if e.is_uniform() { let e = e.cast_to::<u8>()?.try_as_plain()?.as_slice::<u8>()?[0];
283 let e_val_as_d_aligned: i32 = scale_by(e as i32 - e_zp, e_scale / d_scale);
284 let multiplier = d_scale * (1.0/ c_scale);
285 let d = d.to_plain_array_view::<u8>()?;
286 let mut c = Tensor::zero_dt(c_dt, d.shape())?;
287 let mut c_plain = c.try_as_plain_mut()?;
288 let view = c_plain.to_array_view_mut::<u8>()?;
289 crate::ndarray::Zip::from(view)
290 .and_broadcast(d)
291 .for_each(|c,d| {
292 let d_min_zp = *d as i32 - *d_zp as i32;
293 let c_val: i32 = if d_min_zp < e_val_as_d_aligned {
294 e_val_as_d_aligned
295 } else {
296 d_min_zp
297 };
298 *c = (scale_by(c_val, multiplier) + c_zp as i32).clamp_cast();
299 });
300 return Ok(c)
301 }
302 }
303 Max.generic_eval(a, b, c_dt)
304 },
305 linalg:Max,
306 q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *b } else { *a };
307 q_op_on_f32: |a: f32, b: f32| -> f32 {a.max(b)},
308 [f16, f32, f64] => |c,a,b| *c = a.max(*b),
309 [TDim] => |c,a,b| *c = a.clone().maxi(b.clone()),
310 [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.max(b));
311
// Element-wise power `c = a ^ b` (not commutative). Float types use `powf`. For i32/i64 the
// exponent is converted with `as u32`.
// NOTE(review): a negative integer exponent therefore wraps to a huge u32 — confirm callers never
// pass one. `declutter_pow` rewrites some uniform exponents (2, 0.5) into `square`/`sqrt`.
bin_to_super_type!(pow, Pow,
    declutter: declutter_pow,
    is_commutative: false,
    neutral_element: 1,
    q_op_on_f32: |a: f32, b: f32| -> f32 {a.powf(b)},
    [f16, f32, f64] => |c,a,b| *c = a.powf(*b),
    [i32, i64] => |c,a,b| *c = a.pow(*b as u32));
319
// Bitwise shifts on integer tensors: `a << b` and `a >> b` (`>>` is arithmetic for signed types).
// These are the targets of the power-of-two rewrites in `declutter_mul` / `declutter_div`.
// NOTE(review): a shift amount >= the bit width overflows (panics in debug builds) — confirm
// inputs are bounded.
bin_to_super_type!(shift_left, ShiftLeft,
    is_commutative: false,
    [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a << *b);
bin_to_super_type!(shift_right, ShiftRight,
    is_commutative: false,
    [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a >> *b);
326
327fn declutter_mul(
328 _op: &Mul,
329 model: &TypedModel,
330 node: &TypedNode,
331) -> TractResult<Option<TypedModelPatch>> {
332 if node.inputs[0] == node.inputs[1] && !node.outputs[0].fact.datum_type.is_quantized() {
333 return Ok(Some(TypedModelPatch::replace_single_op(
334 model,
335 node,
336 &node.inputs[0..1],
337 square(),
338 )?));
339 }
340
341 if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
342 let var_fact = model.outlet_fact(uniform.var)?;
343 if uniform.uni.cast_to_scalar::<f64>()? == 0.0 {
344 let shapes =
345 model.node_input_facts(node.id)?.iter().map(|f| &f.shape).collect::<TVec<_>>();
346 let shape: ShapeFact =
347 crate::broadcast::multi_broadcast(&shapes).context("Failed to broadcast")?.into();
348 return Ok(Some(TypedModelPatch::rewire(
349 model,
350 &[],
351 &[node.id.into()],
352 &|patch, _| {
353 let scalar = patch.add_const(
354 format!("{}.zero", node.name),
355 if uniform.uni.datum_type().is_quantized() {
356 let output_dt = node.outputs[0].fact.datum_type;
357 Arc::new(uniform.uni.clone().cast_to_dt(output_dt)?.into_owned())
358 } else {
359 uniform.uni.clone()
360 },
361 )?;
362 let op = MultiBroadcastTo::new(shape.clone());
363 patch.wire_node(&node.name, op, &[scalar])
364 },
365 )?));
366 }
367 let dt = uniform.uni.datum_type();
368 if !dt.is_quantized() {
369 let integer = uniform.uni.cast_to_scalar::<i64>()?;
371 if tensor0(integer)
372 .cast_to_dt(uniform.uni.datum_type())?
373 .close_enough(&uniform.uni, false)
374 .is_ok()
375 && uniform.uni.cast_to_scalar::<i64>()?.count_ones() == 1
376 && dt.is_integer()
377 {
378 let shift = integer.trailing_zeros();
379 return Ok(Some(TypedModelPatch::rewire(
380 model,
381 &[uniform.var],
382 &[node.id.into()],
383 &|patch, taps| {
384 let shift = patch.add_const(
385 format!("{}.shift", node.name),
386 tensor0(shift)
387 .cast_to_dt(dt)?
388 .into_owned()
389 .broadcast_into_rank(var_fact.rank())?,
390 )?;
391 patch.wire_node(&node.name, shift_left(), &[taps[0], shift])
392 },
393 )?));
394 }
395 }
396 }
397 if let Some(patch) = declutter_mul_const_mul_const(model, node)? {
398 return Ok(Some(patch));
399 }
400 Ok(None)
401}
402
403fn declutter_mul_const_mul_const(
404 model: &TypedModel,
405 node: &TypedNode,
406) -> TractResult<Option<TypedModelPatch>> {
407 let input_facts = model.node_input_facts(node.id)?;
408 rule_if_some!(const_slot = input_facts.iter().position(|f| f.konst.is_some()));
409 let prec = model.node(node.inputs[1 - const_slot].node);
410 rule_if_some!(prec_mul = prec.op_as::<TypedBinOp>());
411 rule_if!(prec.outputs[0].successors.len() <= 1);
412 rule_if!(prec_mul.0.is::<Mul>());
413 let prec_input_facts = model.node_input_facts(prec.id)?;
414 rule_if_some!(prec_const_slot = prec_input_facts.iter().position(|f| f.konst.is_some()));
415
416 let const_fact = model.outlet_fact(node.inputs[const_slot])?;
417 let prec_const_fact = model.outlet_fact(prec.inputs[prec_const_slot])?;
418 rule_if!(const_fact.shape.volume().is_one() || prec_const_fact.shape.volume().is_one());
420 rule_if!(const_fact.datum_type.is_float());
421 let result = mul()
422 .eval(tvec!(
423 const_fact.konst.clone().unwrap().into_tvalue(),
424 prec_const_fact.konst.clone().unwrap().into_tvalue()
425 ))?
426 .remove(0)
427 .into_arc_tensor();
428 let mut patch = TypedModelPatch::default();
429 let konst = patch.add_const(&prec.name, result)?;
430 let input_tap = patch.tap_model(model, prec.inputs[1 - prec_const_slot])?;
431 let wire = patch.wire_node(&node.name, mul(), &[konst, input_tap])?;
432 patch.shunt_outside(model, node.id.into(), wire[0])?;
433 Ok(Some(patch))
434}
435
436fn declutter_div(
437 _op: &Div,
438 model: &TypedModel,
439 node: &TypedNode,
440) -> TractResult<Option<TypedModelPatch>> {
441 if let &[p, q] = &*model.node_input_facts(node.id)? {
442 let dt = q.datum_type;
443 if let Some(q) = &q.uniform
444 && let Ok(integer) = q.cast_to_scalar::<i64>()
445 && tensor0(integer).cast_to_dt(dt)?.close_enough(q, false).is_ok()
446 && dt.is_integer()
447 && q.cast_to_scalar::<i64>()?.count_ones() == 1
448 {
449 let shift = integer.trailing_zeros();
450 return Ok(Some(TypedModelPatch::rewire(
451 model,
452 &[node.inputs[0]],
453 &[node.id.into()],
454 &|patch, taps| {
455 let shift = patch.add_const(
456 format!("{}.shift", node.name),
457 tensor0(shift)
458 .cast_to_dt(dt)?
459 .into_owned()
460 .broadcast_into_rank(p.rank())?,
461 )?;
462 patch.wire_node(&node.name, shift_right(), &[taps[0], shift])
463 },
464 )?));
465 }
466 if dt.is_float() {
467 return Ok(Some(TypedModelPatch::rewire(
468 model,
469 &node.inputs,
470 &[node.id.into()],
471 &|patch, taps| {
472 let q =
473 patch.wire_node(format!("{}-recip", node.name), recip(), &[taps[1]])?[0];
474 patch.wire_node(&node.name, mul(), &[taps[0], q])
475 },
476 )?));
477 }
478 }
479 Ok(None)
480}
481
482fn declutter_pow(
483 _op: &Pow,
484 model: &TypedModel,
485 node: &TypedNode,
486) -> TractResult<Option<TypedModelPatch>> {
487 let b = model.outlet_fact(node.inputs[1])?;
488 if let Some(b) = &b.uniform {
489 let b = b.cast_to_scalar::<f32>()?;
490 if b == 2.0 {
491 return Ok(Some(TypedModelPatch::replace_single_op(
492 model,
493 node,
494 &[node.inputs[0]],
495 square(),
496 )?));
497 } else if b == 0.5 {
498 return Ok(Some(TypedModelPatch::replace_single_op(
499 model,
500 node,
501 &[node.inputs[0]],
502 sqrt(),
503 )?));
504 }
505 }
506 crate::ops::nn::gelu_approximate::detect_gelu_approx(_op, model, node)
507}
508
509element_wise!(abs, Abs, [i8, i16, i32, i64, f16, f32, i32] => |_, xs| {
510 xs.iter_mut().for_each(|x| *x = x.abs());
511 Ok(())
512};
513q: [i8, u8, i32, i32] => f32::abs;
514operating_datum_type: |dt| if dt == TDim::datum_type() { i64::datum_type() } else { dt }
515);
516
// Element-wise natural exponential, for float types. Quantized inputs use `f32::exp`. Results are
// validated with rounding tolerance.
element_wise!(exp, Exp, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.exp());
    Ok(())
};
q: [i8, u8, i32, i32] => f32::exp;
validation: Validation::Rounding
);

// Element-wise natural logarithm, for float types. Quantized inputs use `f32::ln`. Results are
// validated with rounding tolerance.
element_wise!(ln, Ln, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.ln());
    Ok(())
};
q: [i8, u8, i32, i32] => f32::ln;
validation: Validation::Rounding
);
532
// Element-wise square `x * x` (via `powi(2)`), for float and quantized types.
// `declutter_square` collapses `square(sqrt(x))` back to `x`.
element_wise!(square, Square, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.powi(2));
    Ok(())
};
q: [i8, u8, i32, i32] => |f : f32| f.powi(2);
declutter: declutter_square;
validation: Validation::Rounding
);
541
542fn declutter_square(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
543 use super::element_wise::*;
544 if let Some(prec) = model.linear_prec(node.id)?
546 && let Some(ew) = prec.op_as::<ElementWiseOp>()
547 && ew.0.is::<Sqrt>()
548 {
549 let mut patch = TypedModelPatch::default();
550 let tap = patch.tap_model(model, prec.inputs[0])?;
551 patch.shunt_outside(model, node.id.into(), tap)?;
552 return Ok(Some(patch));
553 }
554 Ok(None)
555}
556
// Element-wise square root, for float types. Quantized inputs use `f32::sqrt`.
element_wise!(sqrt, Sqrt, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.sqrt());
    Ok(())
};
q: [i8, u8, i32, i32] => f32::sqrt;
validation: Validation::Rounding
);
564
// Element-wise reciprocal `1 / x`, costed as one division per element.
// `declutter_recip` rewrites `recip(sqrt(x))` and `recip(rsqrt(x))`.
element_wise!(recip, Recip, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.recip());
    Ok(())
};
q: [i8, u8, i32, i32] => f32::recip;
cost: |dt| {tvec!((Cost::Div(dt), 1))};
declutter: declutter_recip;
validation: Validation::Rounding
);
574
575fn declutter_recip(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
576 use super::element_wise::*;
577 if let Some(prec) = model.linear_prec(node.id)?
578 && let Some(ew) = prec.op_as::<ElementWiseOp>()
579 {
580 let repl = if ew.0.is::<Sqrt>() {
581 Some(rsqrt())
582 } else if ew.0.is::<Rsqrt>() {
583 Some(sqrt())
584 } else {
585 None
586 };
587 if let Some(repl) = repl {
588 let mut patch = TypedModelPatch::default();
589 let mut wire = patch.tap_model(model, prec.inputs[0])?;
590 wire = patch.wire_node(&node.name, repl, &[wire])?[0];
591 patch.shunt_outside(model, node.id.into(), wire)?;
592 return Ok(Some(patch));
593 }
594 }
595 Ok(None)
596}
597
// Element-wise reciprocal square root `1 / sqrt(x)`, for float and quantized types.
element_wise!(rsqrt, Rsqrt, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.sqrt().recip());
    Ok(())
};
q: [i8, u8, i32] => |x : f32| x.sqrt().recip();
validation: Validation::Rounding
);
605
606element_wise!(ceil, Ceil, [f16, f32, f64] => |_, xs| {
607 xs.iter_mut().for_each(|x| *x = x.ceil());
608 Ok(())
609}, [i8, i16,i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
610q: [i8, u8, i32] => f32::recip);
611
// Element-wise floor. Integer and TDim inputs are already integral, so their kernel is a no-op.
element_wise!(floor, Floor, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.floor());
    Ok(())
}, [i8, i16,i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
q: [i8, u8, i32] => f32::floor);

// Element-wise rounding with the standard `round` (ties away from zero); see
// `round_half_to_even` for banker's rounding. Integer and TDim inputs are unchanged.
element_wise!(round, Round, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.round());
    Ok(())
}, [i8, i16,i32, i64, u8, u16, u32, u64, TDim] => |_, _| Ok(());
q: [i8, u8, i32] => f32::round);
623
// Rescales i32 values in place with the op's linalg `Scaler` (via `q_scale`). This is presumably
// used for requantization — confirm with tract_linalg's `ScaleShiftAndRound`.
element_wise!(q_scale, QScale{scaler: Scaler},[i32] => |op, xs| {
    xs.iter_mut().for_each(|x| *x = x.q_scale(op.scaler));
    Ok(())
});
628
// Element-wise rounding with ties to even (banker's rounding). f16 values are rounded through f32.
element_wise!(round_half_to_even, RoundHalfToEven,
[f32] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = round_ties_to_even(*x));
    Ok(())
},
[f16] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = f16::from_f32(round_ties_to_even(x.to_f32())));
    Ok(())
};
q: [i8, u8, i32] => round_ties_to_even);
639
// Trigonometric functions, element-wise on float tensors (radians). Quantized inputs use the
// matching `f32` function.

// Cosine.
element_wise!(cos, Cos, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.cos());
    Ok(())
};
q: [i8, u8, i32] => f32::cos);

// Sine.
element_wise!(sin, Sin, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.sin());
    Ok(())
};
q: [i8, u8, i32] => f32::sin);

// Tangent.
element_wise!(tan, Tan, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.tan());
    Ok(())
};
q: [i8, u8, i32] => f32::tan);

// Arc cosine.
element_wise!(acos, Acos, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.acos());
    Ok(())
};
q: [i8, u8, i32] => f32::acos);

// Arc sine.
element_wise!(asin, Asin, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.asin());
    Ok(())
};
q: [i8, u8, i32] => f32::asin);

// Arc tangent.
element_wise!(atan, Atan, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.atan());
    Ok(())
};
q: [i8, u8, i32] => f32::atan);
675
// Hyperbolic cosine, element-wise on float tensors; quantized inputs use `f32::cosh`.
element_wise!(cosh, Cosh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.cosh());
    Ok(())
};
q: [i8, u8, i32] => f32::cosh);

// Hyperbolic sine, element-wise on float tensors; quantized inputs use `f32::sinh`.
element_wise!(sinh, Sinh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.sinh());
    Ok(())
};
q: [i8, u8, i32] => f32::sinh);
687
// Hyperbolic tangent. f16 and f32 run the tract_linalg kernels; f64 uses the standard `tanh`.
// Costed as 11 FMAs plus one division per element.
element_wise!(tanh, Tanh,
    [f16] => |_, xs| { (tract_linalg::ops().tanh_f16)().run(xs) },
    [f32] => |_, xs| { (tract_linalg::ops().tanh_f32)().run(xs) },
    [f64] => |_, xs| { xs.iter_mut().for_each(|x| *x = x.tanh()); Ok(()) };
    q: [i8, u8, i32] => f32::tanh;
    cost: |dt| {tvec!((Cost::FMA(dt), 11), (Cost::Div(dt), 1))}
);
695
// Gauss error function via the tract_linalg f32 kernel. f16 inputs are widened to an f32 buffer,
// processed, and narrowed back. No f64 or quantized kernel is declared here.
element_wise!(erf, Erf,
    [f32] => |_, xs| { (tract_linalg::ops().erf_f32)().run(xs) },
    [f16] => |_, xs| {
        let mut f32s = xs.iter().map(|x| x.to_f32()).collect_vec();
        (tract_linalg::ops().erf_f32)().run(&mut f32s)?;
        xs.iter_mut().zip(f32s.into_iter()).for_each(|(x, f)| *x = f16::from_f32(f));
        Ok(())
    };
    cost: |dt| {tvec!((Cost::FMA(dt), 11), (Cost::Div(dt), 1))}
);
706
// Inverse hyperbolic functions, element-wise on float tensors. Quantized inputs use the matching
// `f32` function.

// Inverse hyperbolic cosine.
element_wise!(acosh, Acosh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.acosh());
    Ok(())
};
q: [i8, u8, i32] => f32::acosh);
// Inverse hyperbolic sine.
element_wise!(asinh, Asinh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.asinh());
    Ok(())
};
q: [i8, u8, i32] => f32::asinh);
// Inverse hyperbolic tangent.
element_wise!(atanh, Atanh, [f16, f32, f64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = x.atanh());
    Ok(())
};
q: [i8, u8, i32] => f32::atanh);
722
// Element-wise negation for signed integers, floats and TDim. Quantized inputs are negated in f32.
element_wise!(neg, Neg, [i8, i16, i32, i64, f16, f32, f64, TDim] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = -x.clone());
    Ok(())
};
q: [i8, u8, i32] => |x: f32| -x);
728
729element_wise!(sign, Sign, [f16, f32, f64] => |_, xs| {
730 xs.iter_mut().for_each(|x| *x = if x.is_zero() { *x } else { x.signum() });
731 Ok(())
732};
733q: [i8, u8, i32] => f32::signum);
734
// Boolean mask of infinite values. `detect_positive` / `detect_negative` select which signs of
// infinity are reported.
element_wise_oop!(is_inf, IsInf { detect_positive: bool, detect_negative: bool },
    [f32] => bool |op, xs, ys| {
        xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
            *y = (op.detect_positive && *x == f32::INFINITY) || (op.detect_negative && *x == f32::NEG_INFINITY)
        );
        Ok(())
    },
    [f16] => bool |op, xs, ys| {
        xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
            *y = (op.detect_positive && *x == f16::INFINITY) || (op.detect_negative && *x == f16::NEG_INFINITY)
        );
        Ok(())
    }
);

// Boolean mask of NaN values for f16 and f32 inputs.
element_wise_oop!(is_nan, IsNan,
    [f16, f32] => bool |_, xs, ys| {
        xs.iter().zip(ys.iter_mut()).for_each(|(x,y)| *y = x.is_nan());
        Ok(())
    }
);
756
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::binary::TypedBinOp;
    use ndarray::arr2;

    /// Runs `model` once on a single input and returns its first output.
    fn run_first(model: &TypedModel, input: Tensor) -> TractResult<TValue> {
        let mut outputs = SimplePlan::new(model.clone())?.run(tvec!(input.into()))?;
        Ok(outputs.remove(0))
    }

    /// Returns the binary op computing `model`'s first output (panics if it is not one).
    fn output_bin_op(model: &TypedModel) -> TractResult<&TypedBinOp> {
        let out_node = model.output_outlets()?[0].node;
        Ok(model.node(out_node).op().downcast_ref::<TypedBinOp>().unwrap())
    }

    #[test]
    fn test_mul() {
        let lhs = arr2(&[[1., 2.], [3., 4.]]);
        let rhs = arr2(&[[1., 0.], [0., 0.]]);
        assert_eq!(lhs * rhs, arr2(&[[1., 0.], [0., 0.]]));
    }

    #[test]
    fn dot() {
        let lhs = arr2(&[[1., 2.], [3., 4.]]);
        let rhs = arr2(&[[1., 0.], [0., 0.]]);
        assert_eq!(lhs.dot(&rhs), arr2(&[[1., 0.], [3., 0.]]));
    }

    #[test]
    fn mul_as_shift_left() -> TractResult<()> {
        let mut model = TypedModel::default();
        let x = model.add_source("x", i32::fact([2usize, 2]))?;
        let four = model.add_const("a", tensor0(4i32).broadcast_into_rank(2)?.into_arc_tensor())?;
        let y = model.wire_node("y", mul(), &[x, four])?[0];
        model.select_output_outlets(&[y])?;
        let expected = tensor2(&[[4, 8], [12, 16]]);
        assert_eq!(*run_first(&model, tensor2(&[[1, 2], [3, 4]]))?, expected);

        let decluttered = model.into_decluttered()?;
        assert_eq!(*run_first(&decluttered, tensor2(&[[1, 2], [3, 4]]))?, expected);
        assert!(output_bin_op(&decluttered)?.0.downcast_ref::<ShiftLeft>().is_some());
        Ok(())
    }

    #[test]
    fn div_as_shift() -> TractResult<()> {
        let mut model = TypedModel::default();
        let x = model.add_source("a", i32::fact([2usize, 2]))?;
        let divisor = model.add_const("shift", tensor2(&[[4]]))?;
        let y = model.wire_node("c", div(), [x, divisor].as_ref())?[0];
        model.select_output_outlets(&[y])?;
        let expected = tensor2(&[[4, 8], [16, 17]]);
        assert_eq!(*run_first(&model, tensor2(&[[16, 32], [64, 68]]))?, expected);

        let decluttered = model.into_decluttered()?;
        assert_eq!(*run_first(&decluttered, tensor2(&[[16, 32], [64, 68]]))?, expected);
        assert!(output_bin_op(&decluttered)?.0.downcast_ref::<ShiftRight>().is_some());
        Ok(())
    }
}
823}