//! Softmax and LogSoftmax operators: float paths plus a fixed-point
//! implementation for quantized 8-bit tensors (tract_core/ops/nn/softmax).

1mod fixedpoint;
2pub mod math;
3
4use math::{
5    convert_scale_to_mult_shift, exp_on_negative_values, get_reciprocal, rescale,
6    rounding_divide_by_pot, saturating_rounding_doubling_high_mul,
7    saturating_rounding_multiply_by_pot,
8};
9use num_traits::Float;
10use std::fmt::Debug;
11use tract_num_traits::Zero;
12
13use crate::internal::*;
14use ndarray::prelude::*;
15
/// Selects which normalization variant a [`Softmax`] operator computes.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub enum SoftmaxKind {
    /// Standard softmax; carries the exponential implementation to use.
    Softmax(SoftmaxExp),
    /// Log-softmax (numerically stable log of the softmax).
    LogSoftmax,
}
21
// Manual impl: #[derive(Default)] can only pick a unit variant as default,
// and Softmax(SoftmaxExp) carries a payload, so we spell it out.
impl Default for SoftmaxKind {
    fn default() -> Self {
        SoftmaxKind::Softmax(SoftmaxExp::default())
    }
}
27
/// Strategy used to evaluate the exponentials inside softmax.
#[derive(Debug, Copy, Clone, Hash, Default, PartialEq, Eq)]
pub enum SoftmaxExp {
    /// Plain libm/libc `exp` calls (reference accuracy).
    #[default]
    Libc,
    // https://nic.schraudolph.org/pubs/Schraudolph99.pdf
    /// Fast approximate exponential (Schraudolph-style, see paper above).
    FastCompact,
}
35
/// Softmax / LogSoftmax operator over one or more axes.
#[derive(Debug, Clone, new, Hash, Default, PartialEq, Eq)]
pub struct Softmax {
    // Axes over which the normalization is performed.
    pub axes: TVec<usize>,
    // Output datum type for the quantized path; must be None for float inputs
    // (enforced in output_facts).
    pub quant_output_dt: Option<DatumType>,
    // Softmax or LogSoftmax, plus the exp implementation for Softmax.
    pub kind: SoftmaxKind,
}
42
43impl Op for Softmax {
44    fn name(&self) -> StaticName {
45        match self.kind {
46            SoftmaxKind::Softmax(_) => "Softmax".into(),
47            SoftmaxKind::LogSoftmax => "LogSoftmax".into(),
48        }
49    }
50
51    fn info(&self) -> TractResult<Vec<String>> {
52        let mut infos = vec![format!("Axis: {:?}", self.axes)];
53        if let SoftmaxKind::Softmax(exp) = self.kind {
54            infos.push(format!("Exp impl: {exp:?}"))
55        };
56        Ok(infos)
57    }
58
59    op_as_typed_op!();
60}
61
impl TypedOp for Softmax {
    /// Validates the input/quantized-output type combination, then returns a
    /// fact with the input shape and the effective output datum type.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let dt = inputs[0].datum_type;
        if dt.is_float() {
            // Float input: a quantized output type would be inconsistent.
            ensure!(
                self.quant_output_dt.is_none(),
                "Float softmax should not have quant_output_dt, have {:?}",
                self.quant_output_dt
            );
        } else if dt.is_quantized() {
            // Quantized input: a quantized output type is mandatory.
            ensure!(
                self.quant_output_dt.map(|q| q.is_quantized()).unwrap_or(false),
                "Quantized softmax should have a quantized output type (got {:?})",
                self.quant_output_dt
            );
        } else {
            bail!(
                "Unsupported datum type in softmax: input type {:?}, output type {:?}",
                dt,
                self.quant_output_dt
            );
        }

        // Shape is unchanged; only the datum type may differ (quantized case).
        let fact = self.quant_output_dt.unwrap_or(dt).fact(inputs[0].shape.clone());
        Ok(tvec!(fact))
    }

    /// Region-of-interest propagation is delegated to the generic optimizer
    /// helper (the op is elementwise along non-reduced axes).
    fn input_roi(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TVec<Option<TDim>>>> {
        crate::optim::propagate_roi::bubble_roi(model, node)
    }

    /// Every input axis maps to the same output axis (shape is preserved).
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        AxesMapping::natural(inputs, outputs)
    }

    /// Translates the reduced axes through an axis change. Returns Ok(None)
    /// when any reduced axis does not survive the transformation (e.g. it is
    /// removed), meaning the change cannot be applied to this op.
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        // collect() over Option short-circuits to None if any axis fails.
        let axes: Option<TVec<usize>> =
            self.axes.iter().map(|it| change.transform_axis(*it)).collect();
        if let Some(axes) = axes {
            Ok(Some(AxisChangeConsequence::new(
                model,
                node,
                Some(Box::new(Softmax { axes, ..self.clone() })),
                change,
            )))
        } else {
            Ok(None)
        }
    }

    as_op!();
}
128
impl EvalOp for Softmax {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Dispatches on the input datum type: one generic float path per element
    /// type, and a shared fixed-point path for quantized 8-bit inputs.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let dt = input.datum_type();

        let output = match dt {
            DatumType::F64 => self.eval_t::<f64>(input)?,
            DatumType::F32 => self.eval_t::<f32>(input)?,
            DatumType::F16 => self.eval_t::<f16>(input)?,
            DatumType::QI8(_) | DatumType::QU8(_) => self.eval_quant(input)?,
            dt => bail!("Unsupported type {dt:?}"),
        };
        Ok(output)
    }
}
148
149impl Softmax {
150    fn eval_t<T>(&self, input: TValue) -> TractResult<TVec<TValue>>
151    where
152        T: Float + Datum + std::iter::Sum,
153    {
154        let mut iterating_shape: TVec<usize> = input.shape().into();
155
156        for i in 0..iterating_shape.len() {
157            if self.axes.contains(&i) {
158                iterating_shape[i] = 1
159            }
160        }
161
162        let mut output = input.into_tensor();
163        let mut output_plain = output.try_as_plain_mut()?;
164        let mut view = output_plain.to_array_view_mut::<T>()?;
165
166        for it_coords in tract_ndarray::indices(&*iterating_shape) {
167            let mut view = view.view_mut();
168            for ix in 0..iterating_shape.len() {
169                if !self.axes.contains(&ix) {
170                    view.collapse_axis(Axis(ix), it_coords[ix]);
171                }
172            }
173            if let Some(slice) =
174                view.as_slice_mut().filter(|_| T::datum_type() == f32::datum_type())
175            {
176                let slice: &mut [f32] = unsafe { std::mem::transmute(slice) };
177                self.softmax_inner_slice_f32(slice, self.kind)?;
178            } else if let Some(slice) =
179                view.as_slice_mut().filter(|_| T::datum_type() == f16::datum_type())
180            {
181                let slice: &mut [f16] = unsafe { std::mem::transmute(slice) };
182                self.softmax_inner_slice_f16(slice, self.kind)?;
183            } else {
184                softmax_inner(view, self.kind);
185            }
186        }
187
188        Ok(tvec!(output.into_tvalue()))
189    }
190
191    fn eval_quant(&self, input: TValue) -> TractResult<TVec<TValue>> {
192        if self.kind == SoftmaxKind::LogSoftmax {
193            bail!("Quantized LogSoftmax is not supported")
194        }
195        let mut iterating_shape: TVec<usize> = input.shape().into();
196        let output_dt =
197            self.quant_output_dt.context("Quandized softmax eval with no output type")?;
198
199        for i in 0..iterating_shape.len() {
200            if self.axes.contains(&i) {
201                iterating_shape[i] = 1
202            }
203        }
204
205        // All operations will be done in u8, we will cast the result appropriately afterward.
206        let src_is_signed = input.datum_type().is_signed();
207        let out_is_signed = output_dt.is_signed();
208        let in_qp = input.datum_type().qparams().unwrap(); // Checked as we are in the quant case
209        let out_qp = output_dt.qparams().unwrap(); // Checked as we are in the quant case
210        let mut output = unsafe { input.into_tensor().into_array_unchecked::<u8>() };
211
212        for it_coords in tract_ndarray::indices(&*iterating_shape) {
213            let mut view = output.view_mut();
214            for ix in 0..iterating_shape.len() {
215                if !self.axes.contains(&ix) {
216                    view.collapse_axis(Axis(ix), it_coords[ix]);
217                }
218            }
219            softmax_quant_inner(view, src_is_signed, in_qp, out_is_signed, out_qp);
220        }
221
222        let mut output_tensor = output.into_tensor();
223        unsafe { output_tensor.set_datum_type(output_dt) };
224        Ok(tvec!(output_tensor.into_tvalue()))
225    }
226
227    fn softmax_inner_slice_f16(&self, slice: &mut [f16], kind: SoftmaxKind) -> TractResult<()> {
228        let max = (tract_linalg::ops().max_f16)().run(slice)?;
229        match kind {
230            SoftmaxKind::Softmax(exp_impl) => {
231                let sum = match exp_impl {
232                    SoftmaxExp::Libc => {
233                        let mut s = f16::zero();
234                        slice.iter_mut().for_each(|x| {
235                            *x = (*x - max).exp();
236                            s += *x;
237                        });
238                        s
239                    }
240                    SoftmaxExp::FastCompact => (tract_linalg::ops().softmax2_fastcompact_f16)()
241                        .run_with_params(slice, max)?,
242                };
243                let rsum = sum.recip();
244                (tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, rsum)?;
245            }
246            SoftmaxKind::LogSoftmax => {
247                let mut exp_sum = f16::zero();
248                slice.iter_mut().for_each(|x| {
249                    *x -= max;
250                    exp_sum += x.exp();
251                });
252                let log_sum = exp_sum.ln();
253                slice.iter_mut().for_each(|x| *x -= log_sum);
254            }
255        }
256        Ok(())
257    }
258
259    fn softmax_inner_slice_f32(&self, slice: &mut [f32], kind: SoftmaxKind) -> TractResult<()> {
260        let max = (tract_linalg::ops().max_f32)().run(slice)?;
261        match kind {
262            SoftmaxKind::Softmax(exp_impl) => {
263                let sum = match exp_impl {
264                    SoftmaxExp::Libc => {
265                        let mut s = f32::zero();
266                        slice.iter_mut().for_each(|x| {
267                            *x = (*x - max).exp();
268                            s += *x;
269                        });
270                        s
271                    }
272                    SoftmaxExp::FastCompact => (tract_linalg::ops().softmax2_fastcompact_f32)()
273                        .run_with_params(slice, max)?,
274                };
275                let rsum = sum.recip();
276                (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?;
277            }
278            SoftmaxKind::LogSoftmax => {
279                let mut exp_sum = f32::zero();
280                slice.iter_mut().for_each(|x| {
281                    *x -= max;
282                    exp_sum += x.exp();
283                });
284                let log_sum = exp_sum.ln();
285                slice.iter_mut().for_each(|x| *x -= log_sum);
286            }
287        }
288        Ok(())
289    }
290}
291
292fn softmax_inner<T: Float + Datum + std::iter::Sum, D: Dimension>(
293    mut view: ArrayViewMut<T, D>,
294    kind: SoftmaxKind,
295) {
296    let max =
297        *view.iter().max_by(|i, j| i.partial_cmp(j).unwrap_or(std::cmp::Ordering::Less)).unwrap();
298    view.mapv_inplace(|x| x - max);
299    let exp_sum = view.iter().map(|&x| x.exp()).sum();
300    match kind {
301        SoftmaxKind::Softmax(_) => {
302            view.mapv_inplace(|x| x.exp() / exp_sum);
303        }
304        SoftmaxKind::LogSoftmax => {
305            let log_sum = exp_sum.ln();
306            view.mapv_inplace(|x| x - log_sum);
307        }
308    }
309}
310
/// In-place fixed-point softmax over one lane stored as raw u8 bytes.
///
/// `src_is_signed` / `out_is_signed` indicate whether the bytes actually hold
/// i8 values (two's-complement). Only the scales of `in_qp` / `out_qp` are
/// used; both zero points are discarded below. The input zero point is
/// irrelevant (softmax is invariant under an input shift); the output zero
/// point appears to be assumed 0 — NOTE(review): confirm against callers.
fn softmax_quant_inner<D: Dimension>(
    mut view: ArrayViewMut<u8, D>,
    src_is_signed: bool,
    in_qp: QParams,
    out_is_signed: bool,
    out_qp: QParams,
) {
    // Decompose each scale into (multiplier, shift). A multiplier of 0
    // signals a pure power-of-two scale: only the shift applies (see the
    // branches below).
    let (_, in_scale) = in_qp.zp_scale();
    let (scale_in_multiplier, scale_in_shift) = convert_scale_to_mult_shift(in_scale).unwrap();
    let (_, out_scale) = out_qp.zp_scale();
    let (scale_out_multiplier, scale_out_shift) = convert_scale_to_mult_shift(out_scale).unwrap();
    let shift = 26 - scale_in_shift;

    // Compute the exponentials x - max
    let mut buffer = vec![0_i32; view.len()];

    // Handle the case were we considered an i8 as an u8 and still get the right x - max.
    let safe_u8 = if src_is_signed { |x: &u8| x.wrapping_add(128) } else { |x: &u8| *x };

    let max = view.iter().map(safe_u8).max().unwrap();
    view.iter().zip(buffer.iter_mut()).for_each(|(x, exp)| {
        // x - max is always <= 0, as required by exp_on_negative_values.
        let input_diff = safe_u8(x) as i32 - max as i32;

        // We scale the input to be in Q5_26
        let scaled_input_diff = if scale_in_multiplier != 0 {
            saturating_rounding_multiply_by_pot(
                saturating_rounding_doubling_high_mul(input_diff, scale_in_multiplier),
                shift as i32,
            )
        } else {
            saturating_rounding_multiply_by_pot(input_diff, shift as i32)
        };

        // It expects an input from Q5_26 and returns an output in Q0_31
        *exp = exp_on_negative_values(scaled_input_diff);
    });

    // Compute sum of exp
    // The sum is stored as an Q12_19 that's why we need to recale from Q0_31 to Q12_19 before summing.
    let sum_of_exp = buffer.iter().map(|it| rescale(*it, 0, 12)).sum();

    // Compute 1/sum_of_exp
    // The result of this function is in Q0_31
    let (inv_sum_of_exp, num_bits_over_unit) = get_reciprocal(sum_of_exp, 12);

    // Compute the exponent value needed to be in Q24_8 before the final rescaling
    let exponent = num_bits_over_unit as isize + 31 - 8;

    view.iter_mut().zip(buffer.iter()).for_each(|(it, exp)| {
        // Compute the product of exp * 1/sum_of_exp and scale the result in Q24_8
        let unsat_output = rounding_divide_by_pot(
            saturating_rounding_doubling_high_mul(inv_sum_of_exp, *exp),
            exponent as i32,
        );

        // Scale the final result in the output scale range
        let unsat_scaled_output = {
            if scale_out_multiplier != 0 {
                // Non power-of-two output scale: divide by the multiplier via
                // its reciprocal, folding the extra bits into the shift.
                let (inv_multiplier, num_bits) = get_reciprocal(scale_out_multiplier, 1);
                rounding_divide_by_pot(
                    saturating_rounding_doubling_high_mul(unsat_output, inv_multiplier),
                    (8 - scale_out_shift - 1 - num_bits as isize) as i32,
                )
            } else {
                rounding_divide_by_pot(unsat_output, (8 - scale_out_shift) as i32)
            }
        };

        // Return the final result by clipping the computed value within its range
        // and casting it to u8 in any case.
        #[allow(unknown_lints, unnecessary_transmutes)]
        if out_is_signed {
            // SAFETY: i8 -> u8 transmute is a plain bit reinterpretation.
            *it = unsafe {
                std::mem::transmute::<i8, u8>(i32::max(
                    i32::min(unsat_scaled_output, i8::MAX as i32),
                    i8::MIN as i32,
                ) as i8)
            };
        } else {
            *it = i32::max(i32::min(unsat_scaled_output, u8::MAX as i32), u8::MIN as i32) as u8;
        }
    });
}
394
// Tests: proptest harnesses plus hand-picked cases comparing the quantized
// fixed-point softmax against a float reference computation.
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::nn::DataFormat::NCHW;
    use anyhow::Result;
    use num_traits::PrimInt;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use tract_data::internal::QParams::ZpScale;

    // Tolerance is one quantum of the input scale plus one of the output scale.
    fn assert_is_close(found: f32, expected: f32, in_dt: DatumType, out_dt: DatumType) {
        let (_, in_epsilon) = in_dt.zp_scale();
        let (_, out_epsilon) = out_dt.zp_scale();
        let epsilon = in_epsilon + out_epsilon;
        let error = (found - expected).abs();
        assert!(
            error <= epsilon,
            "epsilon eq failed: |{found:?}-{expected:?}|={error} should be <= {epsilon}"
        );
    }

    // Generate a random tensor with a quantized datum type
    fn qtensor<T: PrimInt + Datum + Arbitrary>(shape: Vec<usize>) -> BoxedStrategy<Tensor> {
        let len = shape.iter().product::<usize>();
        let dt = q_datum::<T>((0.0001f32..0.1).boxed());
        (vec(any::<T>(), len..=len), dt)
            .prop_map(move |(vec, dt)| (ArrayD::from_shape_vec(shape.clone(), vec).unwrap(), dt))
            .prop_map(move |(array, dt)| {
                let mut tensor = array.into_tensor();
                unsafe { tensor.set_datum_type(dt) };
                tensor
            })
            .boxed()
    }

    // Generate a random quantized datum type
    // Half of the strategies use an exact power-of-two scale (exercising the
    // multiplier == 0 fast path), the other half the provided float range.
    fn q_datum<T: PrimInt + Datum>(range: BoxedStrategy<f32>) -> BoxedStrategy<DatumType> {
        let max_integer_bits = std::mem::size_of::<T>() * 8 - T::datum_type().is_signed() as usize;
        prop_oneof![
            (1usize..max_integer_bits).prop_map(|fixed_point| { 2f32.powi(-(fixed_point as i32)) }),
            range
        ]
        .prop_map(|scale| {
            if T::datum_type().is_signed() {
                DatumType::QI8(ZpScale { zero_point: 0, scale })
            } else {
                DatumType::QU8(ZpScale { zero_point: 0, scale })
            }
        })
        .boxed()
    }

    // End-to-end problem: full Softmax op on a quantized tensor, checked
    // against the float Softmax op on the dequantized input.
    #[derive(Debug)]
    struct SoftmaxProblem {
        data: Tensor,
        axes: TVec<usize>,
        output_dt: DatumType,
    }

    impl SoftmaxProblem {
        fn check(&self) -> Result<()> {
            let inputs = tvec!(self.data.clone().into_tvalue());
            let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float());
            let softmax =
                Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() };

            // Compute quantized output
            let result = softmax.eval(inputs)?;
            let result = args_1!(result);
            let result_float = result.cast_to::<f32>()?;

            // Compute reference output
            let input_float = self.data.cast_to::<f32>()?;
            let inputs_float = tvec!(input_float.into_owned().into_tvalue());
            let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() };
            let reference_float = softmax_float.eval(inputs_float)?;
            let reference_array = args_1!(reference_float);
            let reference = reference_array.to_plain_array_view::<f32>()?;

            result_float
                .to_plain_array_view::<f32>()?
                .iter()
                .zip(reference.iter())
                .for_each(|(a, b)| assert_is_close(*a, *b, self.data.datum_type(), self.output_dt));
            Ok(())
        }
    }

    impl Arbitrary for SoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<SoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            // Small NCHW shapes, a random single softmax axis, and random
            // signed/unsigned input and output quantizations.
            (1usize..2, 1usize..2, 1usize..5, 1usize..5, 0usize..4)
                .prop_flat_map(|(n, c, h, w, axis)| {
                    let shape_in: Vec<usize> =
                        NCHW.from_n_c_hw(n, c, [h, w]).unwrap().shape.to_vec();
                    (
                        prop_oneof![qtensor::<i8>(shape_in.clone()), qtensor::<u8>(shape_in)],
                        Just(tvec![axis]),
                        prop_oneof![
                            q_datum::<u8>((0.008f32..0.1).boxed()),
                            q_datum::<i8>((0.008f32..0.1).boxed())
                        ],
                    )
                })
                .prop_map(|(data, axes, output_dt)| SoftmaxProblem { data, axes, output_dt })
                .boxed()
        }
    }

    // Unit-level problem: softmax_quant_inner on a 1-D lane, checked against
    // the float softmax_inner on the dequantized data (tolerance: 1 ulp).
    #[derive(Debug)]
    pub struct InnerSoftmaxProblem {
        in_qp: QParams,
        out_qp: QParams,
        data: Vec<i8>,
    }

    impl InnerSoftmaxProblem {
        fn check(&self) -> Result<()> {
            let quantized = self.quantized();
            let reference = self.reference();
            assert!(quantized.iter().zip(reference.iter()).all(|(quantized, expected)| {
                let abs_diff = if *quantized > *expected {
                    quantized - *expected
                } else {
                    expected - *quantized
                };
                abs_diff <= 1
            }));
            Ok(())
        }

        // Float reference: dequantize, run softmax_inner, requantize with clamping.
        fn reference(&self) -> Vec<u8> {
            let (in_zero_point, in_scale) = self.in_qp.zp_scale();
            let (out_zero_point, out_scale) = self.out_qp.zp_scale();
            let in_float =
                self.data.iter().map(|it| (*it as f32 - in_zero_point as f32) * in_scale).collect();
            let mut in_float_array = Array1::from_vec(in_float);
            softmax_inner(in_float_array.view_mut(), SoftmaxKind::default());
            let rescaled_output = in_float_array
                .iter()
                .map(|it| {
                    ((*it / out_scale).round() as i32 + out_zero_point)
                        .max(u8::MIN as i32)
                        .min(u8::MAX as i32) as u8
                })
                .collect();
            rescaled_output
        }

        // Fixed-point path: reinterpret the i8 data as u8 bytes (src_is_signed
        // = true) and run softmax_quant_inner in place.
        fn quantized(&self) -> Vec<u8> {
            let in_data: Vec<u8> = unsafe { std::mem::transmute(self.data.clone()) };
            let mut in_array = Array1::from_vec(in_data);
            softmax_quant_inner(in_array.view_mut(), true, self.in_qp, false, self.out_qp);
            in_array.to_vec()
        }
    }

    impl Arbitrary for InnerSoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<InnerSoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (
                prop_oneof![
                    q_datum::<i8>((0.0001f32..0.01).boxed()),
                    q_datum::<u8>((0.0001f32..0.01).boxed())
                ],
                prop_oneof![
                    q_datum::<u8>((0.008f32..0.1).boxed()),
                    q_datum::<i8>((0.008f32..0.1).boxed())
                ],
                vec(any::<i8>(), 1..10),
            )
                .prop_map(|(in_qp, out_qp, data)| InnerSoftmaxProblem {
                    in_qp: in_qp.qparams().unwrap(),
                    out_qp: out_qp.qparams().unwrap(),
                    data,
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_inner_prop(pb in any::<InnerSoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_prop(pb in any::<SoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    #[test]
    // We test QU8 -> QU8
    fn test_softmax_trivial_0() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 }); // Q3_5
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 }); // Q0_8;
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QI8 -> QU8
    fn test_softmax_trivial_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 }); // Q3_4
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 }); // Q0_8;
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QI8 -> QI8
    fn test_softmax_trivial_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 }); // Q3_4
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 }); // Q0_7;
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, -4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QU8 -> QI8
    fn test_softmax_trivial_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 }); // Q3_5
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 }); // Q0_7;
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![2], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.5 }); // Q6_1
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5 }); // Q7_1
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0001 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.008 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.6220956 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5187921 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[13_u8, 218])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let data = vec![0_i8, 1];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_not_pow_2_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.7298456 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    // Fails but the difference is quite low and the sum still give exactly one:
    // quantized: 110(0.88), 15(0.12)
    // expected: 112(0.896), 13(0.104)
    fn test_inner_softmax_not_pow_2_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.2123116 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.008 };
        let data = vec![118i8, 108];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    // Fails but the difference is quite low and the sum still give exactly one:
    // quantized: 40(0.625), 24(0.375)
    // expected: 42(0.65625), 22(0.34375)
    fn test_inner_softmax_not_pow_2_3() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.33034274 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.015625 };
        let data = vec![45i8, 43];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }
}