1#![allow(clippy::unnecessary_cast)]
2
3use crate::internal::*;
4use crate::ops::element_wise::ElementWiseOp;
5use crate::ops::math::QScale;
6use num_traits::AsPrimitive;
7use tract_linalg::Scaler;
8use tract_linalg::lut::Lut;
9use tract_linalg::mmm::RoundingPolicy;
10
11use super::binary::TypedBinOp;
12use super::math::round_ties_to_even;
13
/// Quantizes `x` to `u8` as `clamp(round(x * scale) + zero_point, 0, 255)`.
///
/// Note that `scale` multiplies `x` (it is not a divisor). Rounding uses `f32::round`
/// (ties away from zero). The float-to-int cast saturates (NaN becomes 0), and the zero
/// point is added with saturating arithmetic. Very large or infinite inputs therefore
/// clamp to `u8::MAX` / `u8::MIN` instead of overflowing `i32`.
pub fn quantize_linear_f32_u8(x: f32, scale: f32, zero_point: i32) -> u8 {
    let rounded = (x * scale).round() as i32;
    rounded.saturating_add(zero_point).clamp(u8::MIN as i32, u8::MAX as i32) as u8
}
17
/// Quantizes `x` to `i8` as `clamp(round(x * scale) + zero_point, -128, 127)`.
///
/// Note that `scale` multiplies `x` (it is not a divisor). Rounding uses `f32::round`
/// (ties away from zero). The float-to-int cast saturates (NaN becomes 0), and the zero
/// point is added with saturating arithmetic. Very large or infinite inputs therefore
/// clamp to `i8::MAX` / `i8::MIN` instead of overflowing `i32`, including when the
/// zero point is negative.
pub fn quantize_linear_f32_i8(x: f32, scale: f32, zero_point: i32) -> i8 {
    let rounded = (x * scale).round() as i32;
    rounded.saturating_add(zero_point).clamp(i8::MIN as i32, i8::MAX as i32) as i8
}
21
22element_wise_oop!(quantize_linear_u8,
23 QuantizeLinearU8 {
24 scale: f32,
25 zero_point: u8
26 },
27 [f16] => u8 |op, xs, ys| {
28 xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
29 *y = quantize_linear_f32_u8(x.to_f32(), op.scale, op.zero_point as i32)
30 );
31 Ok(())
32 },
33 [f32,i32] => u8 |op, xs, ys| {
34 xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
35 *y = quantize_linear_f32_u8(*x as f32, op.scale, op.zero_point as i32)
36 );
37 Ok(())
38 };
39 info: info_quantize_linear_u8
40);
41
42fn info_quantize_linear_u8(q: &QuantizeLinearU8) -> TractResult<Vec<String>> {
43 Ok(vec![format!(
44 "scale: {} zero_point: {} 1/scale: {}",
45 q.scale,
46 q.zero_point,
47 q.scale.recip()
48 )])
49}
50
51element_wise_oop!(quantize_linear_i8,
52 QuantizeLinearI8 {
53 scale: f32,
54 zero_point: i8
55 },
56 [f32,i32] => i8 |op, xs, ys| {
57 xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
58 *y = quantize_linear_f32_i8(*x as f32, op.scale, op.zero_point as i32)
59 );
60 Ok(())
61 };
62 info: info_quantize_linear_i8
63);
64
65fn info_quantize_linear_i8(q: &QuantizeLinearI8) -> TractResult<Vec<String>> {
66 Ok(vec![format!(
67 "scale: {} zero_point: {} 1/scale: {}",
68 q.scale,
69 q.zero_point,
70 q.scale.recip()
71 )])
72}
73
74#[derive(Clone, Debug, new)]
75pub struct DequantizeLinearF32 {
76 pub scale: f32,
77 pub zero_point: i32,
78}
79
80impl DequantizeLinearF32 {
81 fn eval_t<T: Datum + AsPrimitive<i32>>(&self, input: &Tensor) -> TractResult<Tensor> {
82 let mut output = unsafe { Tensor::uninitialized::<f32>(input.shape())? };
83 input
84 .try_as_dense()?
85 .as_slice::<T>()?
86 .iter()
87 .zip(output.try_as_dense_mut()?.as_slice_mut::<f32>()?.iter_mut())
88 .for_each(|(x, y)| *y = (x.as_() - self.zero_point) as f32 * self.scale);
89 Ok(output)
90 }
91}
92
93impl Op for DequantizeLinearF32 {
94 fn name(&self) -> StaticName {
95 "DequantizeLinearF32".into()
96 }
97
98 fn info(&self) -> TractResult<Vec<String>> {
99 Ok(vec![format!("scale: {} zero_point: {}", self.scale, self.zero_point)])
100 }
101
102 fn validation(&self) -> Validation {
103 Validation::Accurate
104 }
105
106 op_as_typed_op!();
107}
108
109impl EvalOp for DequantizeLinearF32 {
110 fn is_stateless(&self) -> bool {
111 true
112 }
113 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
114 let output = match inputs[0].datum_type() {
115 DatumType::I8 => self.eval_t::<i8>(&inputs[0])?,
116 DatumType::I32 => self.eval_t::<i32>(&inputs[0])?,
117 DatumType::U8 => self.eval_t::<u8>(&inputs[0])?,
118 dt => bail!("Unsupported type {:?}", dt),
119 };
120 Ok(tvec!(output.into_tvalue()))
121 }
122}
123
124impl TypedOp for DequantizeLinearF32 {
125 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
126 let mut fact = inputs[0].clone();
127 fact.datum_type = f32::datum_type();
128 Ok(tvec!(fact))
129 }
130
131 fn axes_mapping(
132 &self,
133 inputs: &[&TypedFact],
134 outputs: &[&TypedFact],
135 ) -> TractResult<AxesMapping> {
136 AxesMapping::natural(inputs, outputs)
137 }
138
139 fn change_axes(
140 &self,
141 model: &TypedModel,
142 node: &TypedNode,
143 _io: InOut,
144 change: &AxisOp,
145 ) -> TractResult<Option<AxisChangeConsequence>> {
146 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
147 }
148
149 fn declutter(
150 &self,
151 model: &TypedModel,
152 dequant: &TypedNode,
153 ) -> TractResult<Option<TypedModelPatch>> {
154 let mut current = dequant;
155 let incoming_dt = model.node_input_facts(dequant.id)?[0].datum_type;
156 while let Some(quant) = model.single_succ(current.id)? {
157 let q_params = if let Some(op) = quant.op_as::<ElementWiseOp>() {
158 if let Some(mop) = op.0.downcast_ref::<QuantizeLinearU8>() {
159 Some((mop.scale, mop.zero_point as i32, u8::datum_type()))
160 } else {
161 op.0.downcast_ref::<QuantizeLinearI8>()
162 .map(|mop| (mop.scale, mop.zero_point as i32, i8::datum_type()))
163 }
164 } else {
165 None
166 };
167 if let Some((scale, zero_point, dt)) = q_params {
168 let mut patch = TypedModelPatch::default();
170 let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
171 let mut next = model.single_succ(dequant.id)?.unwrap();
172 loop {
173 if let Some(op) = next
174 .op
175 .quantize(model, dequant, dt, scale, zero_point)
176 .with_context(|| format!("Quantizing {next}"))?
177 {
178 wire = patch.wire_node(&*next.name, op, [wire].as_ref())?[0];
179 } else {
180 break;
181 }
182 if next.id == current.id {
183 patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
184 return Ok(Some(patch));
185 } else {
186 next = model.single_succ(next.id)?.unwrap();
187 }
188 }
189 if incoming_dt == DatumType::I8 || incoming_dt == DatumType::U8 {
191 let mut adhoc_model = TypedModel::default();
192 let mut wire = adhoc_model.add_source("ad-hoc", dt.fact([256]))?;
193 let mut next = model.single_succ(dequant.id)?.unwrap();
194 let mut name = None;
195 wire = adhoc_model.wire_node(
197 &*dequant.name,
198 dequant.op.clone(),
199 [wire].as_ref(),
200 )?[0];
201 while next.id != quant.id {
202 name.get_or_insert(&*next.name);
203 wire =
204 adhoc_model.wire_node(&*next.name, next.op.clone(), [wire].as_ref())?
205 [0];
206 next = model.single_succ(next.id)?.unwrap();
207 }
208 wire =
210 adhoc_model.wire_node(&*quant.name, quant.op.clone(), [wire].as_ref())?[0];
211 adhoc_model.set_output_outlets(&[wire])?;
212 let input = (0u8..=255).collect::<Vec<u8>>();
213 let input = match dt {
214 DatumType::I8 => unsafe {
215 tensor1(std::mem::transmute::<&[u8], &[i8]>(&*input))
216 },
217 DatumType::U8 => tensor1(&input),
218 _ => unreachable!(),
219 };
220 let output =
221 SimplePlan::new(adhoc_model)?.run(tvec!(input.into_tvalue()))?.remove(0);
222 let table: &[u8] = match dt {
223 DatumType::I8 => unsafe {
224 std::mem::transmute::<&[i8], &[u8]>(
225 output.try_as_dense()?.as_slice::<i8>()?,
226 )
227 },
228 DatumType::U8 => output.try_as_dense()?.as_slice::<u8>()?,
229 _ => unreachable!(),
230 };
231 let op = lookup_table((tract_linalg::ops().lut_u8)(table));
232 let mut patch = TypedModelPatch::default();
233 let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
234
235 wire = patch.wire_node(name.unwrap_or(&*dequant.name), op, [wire].as_ref())?[0];
236 patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
237 return Ok(Some(patch));
238 }
239 }
240 let (input_facts, output_facts) = model.node_facts(quant.id)?;
241 let invariants = quant
242 .op
243 .axes_mapping(&input_facts, &output_facts)
244 .with_context(|| format!("Querying invariants for {quant}"))?;
245 if invariants.is_element_wise_unary() {
246 current = quant;
247 } else {
248 break;
249 }
250 }
251 Ok(None)
252 }
253
254 as_op!();
255}
256
257element_wise_oop!(lookup_table,
258 LookupTable {
259 table: Box<dyn Lut>
260 },
261 [i8] => i8 |op, xs, ys| {
262 ys.copy_from_slice(xs);
263 unsafe {
264 let casted = std::slice::from_raw_parts_mut(ys.as_mut_ptr() as *mut u8, ys.len());
265 op.table.run(casted);
266 }
267 Ok(())
268 },
269 [u8] => u8 |op, xs, ys| {
270 ys.copy_from_slice(xs);
271 op.table.run(ys);
272 Ok(())
273 }
274);
275
/// Binary mini-op scaling its right operand `b` by the float factor `a`:
/// `round_ties_to_even(|b| * a) * signum(b)`, produced in `b`'s datum type.
#[derive(Clone, Debug, Hash)]
pub struct Scale;
278
279impl crate::ops::binary::BinMiniOp for Scale {
280 fn name(&self) -> &'static str {
281 "Scale"
282 }
283 fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
284 if !a.is_float() {
285 bail!("Scale left operand must be float, got {:?}", a);
286 }
287 Ok(b)
288 }
289
290 fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
291 if !a.is_float() {
292 bail!("Scale left operand must be float, got {:?}", a);
293 }
294 Ok(b)
295 }
296
297 fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
298 let a = a.cast_to::<f32>()?;
299 let a = a.to_dense_array_view::<f32>()?;
300 unsafe fn eval_out_of_place_t<T: Datum + AsPrimitive<f32>>(
301 c: &mut Tensor,
302 a: &ndarray::ArrayViewD<f32>,
303 b: &Tensor,
304 ) where
305 f32: AsPrimitive<T>,
306 {
307 let b = unsafe { b.to_array_view_unchecked::<T>() };
308 let mut c = unsafe { c.to_array_view_mut_unchecked::<T>() };
309 ndarray::Zip::from(&mut c)
310 .and_broadcast(a)
311 .and_broadcast(b)
312 .for_each(|c, a, b| *c = scale_by(*b, *a))
313 }
314 unsafe { dispatch_numbers!(eval_out_of_place_t(b.datum_type())(c, &a, b)) }
315 Ok(())
316 }
317
318 fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
319 let mut a_dense = a.try_as_dense_mut()?;
320 let a = a_dense.to_array_view_mut::<f32>()?;
321 let b = b.to_dense_array_view::<f32>()?;
322 ndarray::Zip::from(a).and_broadcast(b).for_each(|a, b| *a = scale_by(*b, *a));
323 Ok(())
324 }
325
326 fn is_commutative(&self) -> bool {
327 false
328 }
329
330 fn declutter(
331 &self,
332 model: &TypedModel,
333 node: &TypedNode,
334 ) -> TractResult<Option<TypedModelPatch>> {
335 let a = model.outlet_fact(node.inputs[0])?;
336 if let Some(a) = &a.uniform {
337 if a.cast_to_scalar::<f32>()? == 1. {
338 return Ok(Some(TypedModelPatch::rewire(
339 model,
340 &node.inputs[1..2],
341 &[node.id.into()],
342 &|_p, x| Ok(x.into()),
343 )?));
344 } else if node.outputs[0].fact.datum_type == DatumType::I32 {
345 let factor = a.cast_to_scalar::<f32>()?;
346 let scaler = Scaler::new(factor, RoundingPolicy::Even);
347
348 let op = ElementWiseOp(Box::new(QScale { scaler }), None);
349 let patch =
350 TypedModelPatch::replace_single_op(model, node, &node.inputs[1..2], op)?;
351
352 return Ok(Some(patch));
353 }
354 }
355 Ok(None)
356 }
357}
358
359#[inline]
360pub(crate) fn scale_by<T: Datum + AsPrimitive<f32>>(b: T, a: f32) -> T
361where
362 f32: AsPrimitive<T>,
363{
364 let b = b.as_();
365 (round_ties_to_even(b.abs() * a) * b.signum()).as_()
366}
367
368pub fn scale() -> TypedBinOp {
369 TypedBinOp(Box::new(Scale), None)
370}
371
/// Shifts an i8 into the u8 range by adding 128 (so -128 -> 0 and 127 -> 255).
/// Modulo 256, adding 128 only flips the top bit, so this is an xor with 0x80.
pub(crate) fn offset_i8_as_u8_elementwise(x: i8) -> u8 {
    (x as u8) ^ 0x80
}
376
/// Element-wise mini-op reinterpreting i8 data as u8 by adding 128 to every value.
#[derive(Clone, Debug)]
pub struct OffsetI8asU8 {}
379impl ElementWiseMiniOp for OffsetI8asU8 {
380 fn name(&self) -> String {
381 format!("{}{}", self.prefix(), stringify!(OffsetI8asU8))
382 }
383 fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
384 Some(if let DatumType::QI8(qp) = input_type {
385 let (zp, scale) = qp.zp_scale();
386 DatumType::QU8(QParams::ZpScale { zero_point: zp + 128, scale })
387 } else if input_type == DatumType::I8 {
388 DatumType::U8
389 } else {
390 input_type
391 })
392 }
393 fn eval_out_of_place(&self, t: &Tensor, out_dt: Option<DatumType>) -> TractResult<Tensor> {
394 let output_type = out_dt.unwrap_or(self.output_type(t.datum_type()).unwrap());
395 let mut dst = unsafe { Tensor::uninitialized_dt(output_type, t.shape())? };
396 if t.datum_type().unquantized() == i8::datum_type() {
397 t.try_as_dense()?
398 .as_slice::<i8>()?
399 .iter()
400 .zip(dst.try_as_dense_mut()?.as_slice_mut::<u8>()?.iter_mut())
401 .for_each(|(x, y)| *y = offset_i8_as_u8_elementwise(*x));
402 return Ok(dst);
403 }
404
405 bail!("{} does not support {:?}", self.name(), t.datum_type());
406 }
407}
408
409pub fn offset_i8_as_u8() -> ElementWiseOp {
410 ElementWiseOp(Box::new(OffsetI8asU8 {}), None)
411}
412
/// Shifts a u8 into the i8 range by subtracting 128 (so 0 -> -128 and 255 -> 127).
/// Modulo 256, subtracting 128 only flips the top bit, so this is an xor with 0x80.
pub(crate) fn offset_u8_as_i8_elementwise(x: u8) -> i8 {
    (x ^ 0x80) as i8
}
417
/// Element-wise mini-op reinterpreting u8 data as i8 by subtracting 128 from every value.
#[derive(Clone, Debug)]
pub struct OffsetU8asI8 {}
420impl ElementWiseMiniOp for OffsetU8asI8 {
421 fn name(&self) -> String {
422 format!("{}{}", self.prefix(), stringify!(OffsetU8asI8))
423 }
424 fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
425 Some(if let DatumType::QU8(qp) = input_type {
426 let (zp, scale) = qp.zp_scale();
427 DatumType::QI8(QParams::ZpScale { zero_point: zp - 128, scale })
428 } else if input_type == DatumType::U8 {
429 DatumType::I8
430 } else {
431 input_type
432 })
433 }
434 fn eval_out_of_place(&self, t: &Tensor, out_dt: Option<DatumType>) -> TractResult<Tensor> {
435 let output_type = out_dt.unwrap_or(self.output_type(t.datum_type()).unwrap());
436 let mut dst = unsafe { Tensor::uninitialized_dt(output_type, t.shape())? };
437 if t.datum_type().unquantized() == u8::datum_type() {
438 t.try_as_dense()?
439 .as_slice::<u8>()?
440 .iter()
441 .zip(dst.try_as_dense_mut()?.as_slice_mut::<i8>()?.iter_mut())
442 .for_each(|(x, y)| *y = offset_u8_as_i8_elementwise(*x));
443 return Ok(dst);
444 }
445
446 bail!("{} does not support {:?}", self.name(), t.datum_type());
447 }
448}
449pub fn offset_u8_as_i8() -> ElementWiseOp {
450 ElementWiseOp(Box::new(OffsetU8asI8 {}), None)
451}
452
#[cfg(test)]
pub mod scale {
    use crate::internal::*;
    use crate::ops::einsum::EinSum;
    use crate::ops::math::round_ties_to_even;
    use proptest::prelude::*;

    /// Builds a 1x1 quantized i8 EinSum computing `a * b`. Input scales are 1, zero
    /// points are 0, and the output scale is `scale`. Checks that both the plain and the
    /// optimized model return `a * b / scale`, rounded half-to-even and clamped to i8.
    fn test_scale(a: i8, b: i8, scale: f32) {
        let product = ((a as i32) * (b as i32)) as f32;
        let scaled = product / scale;
        let rounded = round_ties_to_even(scaled.abs()) * scaled.signum();
        let clamped = (rounded as i32).clamp(-128, 127);
        let expected = tensor2(&[[clamped as i8]]);

        let input = tvec!(tensor2(&[[b]]).into_tvalue());
        let mut model = TypedModel::default();
        let a = model.add_const("a", tensor2(&[[a]])).unwrap();
        let b = model.add_source("b", i8::fact([1, 1])).unwrap();
        let bias = model.add_const("bias", tensor0(0i32)).unwrap();
        let a0 = model.add_const("a0", tensor0(0i8)).unwrap();
        let a_scale = model.add_const("a_scale", tensor0(1f32)).unwrap();
        let b0 = model.add_const("b0", tensor0(0i8)).unwrap();
        let b_scale = model.add_const("b_scale", tensor0(1f32)).unwrap();
        let c0 = model.add_const("c0", tensor0(0i8)).unwrap();
        let c_scale = model.add_const("c_scale", tensor0(scale)).unwrap();
        let einsum = EinSum {
            axes: "mk,kn,,,,,,,->mn".parse().unwrap(),
            operating_dt: i32::datum_type(),
            q_params: Some(i8::datum_type()),
        };
        let outlets = model
            .wire_node("mmm", einsum, &[a, b, bias, a0, a_scale, b0, b_scale, c0, c_scale])
            .unwrap();
        model.set_output_outlets(&outlets).unwrap();

        // Reference run on the unoptimized model.
        let reference = model.clone().into_runnable().unwrap().run(input.clone()).unwrap();
        assert_eq!(*reference[0], expected);

        // Optimization must not change the result.
        let optimized =
            model.into_optimized().unwrap().into_runnable().unwrap().run(input).unwrap();
        assert_eq!(*optimized[0], expected);
    }

    proptest! {
        #[test]
        fn prop(a in any::<i8>(), b in any::<i8>(), scale in 0.00001f32..1000.) {
            test_scale(a, b, scale);
        }
    }

    #[test]
    fn t1() {
        test_scale(-117, 15, 37.753822);
    }

    #[test]
    fn t2() {
        test_scale(-4, -60, 475.21674);
    }
}