//! tract_core/ops/nn/softmax/mod.rs — Softmax and LogSoftmax operators (float and quantized).

1mod fixedpoint;
2pub mod math;
3
4use math::{
5    convert_scale_to_mult_shift, exp_on_negative_values, get_reciprocal, rescale,
6    rounding_divide_by_pot, saturating_rounding_doubling_high_mul,
7    saturating_rounding_multiply_by_pot,
8};
9use num_traits::Float;
10use std::fmt::Debug;
11use tract_num_traits::Zero;
12
13use crate::internal::*;
14use ndarray::prelude::*;
15
/// Which flavour of softmax this operator computes.
#[derive(Debug, Copy, Clone, Hash, PartialEq)]
pub enum SoftmaxKind {
    /// Standard softmax; carries the exponential implementation to use.
    Softmax(SoftmaxExp),
    /// log(softmax(x)) (not supported by the quantized path).
    LogSoftmax,
}
21
22impl Default for SoftmaxKind {
23    fn default() -> Self {
24        SoftmaxKind::Softmax(SoftmaxExp::default())
25    }
26}
27
/// Strategy used to evaluate the exponentials inside softmax.
#[derive(Debug, Copy, Clone, Hash, Default, PartialEq)]
pub enum SoftmaxExp {
    /// Exact exponential via the standard `exp` implementation.
    #[default]
    Libc,
    // https://nic.schraudolph.org/pubs/Schraudolph99.pdf
    /// Fast approximate exponential (see Schraudolph's paper above).
    FastCompact,
}
35
/// Softmax / LogSoftmax operator normalizing over a set of axes.
#[derive(Debug, Clone, new, Hash, Default)]
pub struct Softmax {
    // Axes over which the normalization is performed.
    pub axes: TVec<usize>,
    // Output datum type for the quantized case; must be Some(quantized) iff the
    // input is quantized, None for float inputs (enforced in output_facts).
    pub quant_output_dt: Option<DatumType>,
    // Softmax or LogSoftmax, plus the exp implementation for the former.
    pub kind: SoftmaxKind,
}
42
43impl Op for Softmax {
44    fn name(&self) -> StaticName {
45        match self.kind {
46            SoftmaxKind::Softmax(_) => "Softmax".into(),
47            SoftmaxKind::LogSoftmax => "LogSoftmax".into(),
48        }
49    }
50
51    fn info(&self) -> TractResult<Vec<String>> {
52        let mut infos = vec![format!("Axis: {:?}", self.axes)];
53        if let SoftmaxKind::Softmax(exp) = self.kind {
54            infos.push(format!("Exp impl: {exp:?}"))
55        };
56        Ok(infos)
57    }
58
59    op_as_typed_op!();
60}
61
62impl TypedOp for Softmax {
63    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
64        let dt = inputs[0].datum_type;
65        if dt.is_float() {
66            ensure!(
67                self.quant_output_dt.is_none(),
68                "Float softmax should not have quant_output_dt, have {:?}",
69                self.quant_output_dt
70            );
71        } else if dt.is_quantized() {
72            ensure!(
73                self.quant_output_dt.map(|q| q.is_quantized()).unwrap_or(false),
74                "Quantized softmax should have a quantized output type (got {:?})",
75                self.quant_output_dt
76            );
77        } else {
78            bail!(
79                "Unsupported datum type in softmax: input type {:?}, output type {:?}",
80                dt,
81                self.quant_output_dt
82            );
83        }
84
85        let fact = self.quant_output_dt.unwrap_or(dt).fact(inputs[0].shape.clone());
86        Ok(tvec!(fact))
87    }
88
89    fn axes_mapping(
90        &self,
91        inputs: &[&TypedFact],
92        outputs: &[&TypedFact],
93    ) -> TractResult<AxesMapping> {
94        AxesMapping::natural(inputs, outputs)
95    }
96
97    fn change_axes(
98        &self,
99        model: &TypedModel,
100        node: &TypedNode,
101        _io: InOut,
102        change: &AxisOp,
103    ) -> TractResult<Option<AxisChangeConsequence>> {
104        let axes: Option<TVec<usize>> =
105            self.axes.iter().map(|it| change.transform_axis(*it)).collect();
106        if let Some(axes) = axes {
107            Ok(Some(AxisChangeConsequence::new(
108                model,
109                node,
110                Some(Box::new(Softmax { axes, ..self.clone() })),
111                change,
112            )))
113        } else {
114            Ok(None)
115        }
116    }
117
118    as_op!();
119}
120
121impl EvalOp for Softmax {
122    fn is_stateless(&self) -> bool {
123        true
124    }
125
126    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
127        let input = args_1!(inputs);
128        let dt = input.datum_type();
129
130        let output = match dt {
131            DatumType::F64 => self.eval_t::<f64>(input)?,
132            DatumType::F32 => self.eval_t::<f32>(input)?,
133            DatumType::F16 => self.eval_t::<f16>(input)?,
134            DatumType::QI8(_) | DatumType::QU8(_) => self.eval_quant(input)?,
135            dt => bail!("Unsupported type {dt:?}"),
136        };
137        Ok(output)
138    }
139}
140
141impl Softmax {
142    fn eval_t<T>(&self, input: TValue) -> TractResult<TVec<TValue>>
143    where
144        T: Float + Datum + std::iter::Sum,
145    {
146        let mut iterating_shape: TVec<usize> = input.shape().into();
147
148        for i in 0..iterating_shape.len() {
149            if self.axes.contains(&i) {
150                iterating_shape[i] = 1
151            }
152        }
153
154        let mut output = input.into_tensor();
155        let mut output_dense = output.try_as_dense_mut()?;
156        let mut view = output_dense.to_array_view_mut::<T>()?;
157
158        for it_coords in tract_ndarray::indices(&*iterating_shape) {
159            let mut view = view.view_mut();
160            for ix in 0..iterating_shape.len() {
161                if !self.axes.contains(&ix) {
162                    view.collapse_axis(Axis(ix), it_coords[ix]);
163                }
164            }
165            if let Some(slice) =
166                view.as_slice_mut().filter(|_| T::datum_type() == f32::datum_type())
167            {
168                let slice: &mut [f32] = unsafe { std::mem::transmute(slice) };
169                self.softmax_inner_slice_f32(slice, self.kind)?;
170            } else if let Some(slice) =
171                view.as_slice_mut().filter(|_| T::datum_type() == f16::datum_type())
172            {
173                let slice: &mut [f16] = unsafe { std::mem::transmute(slice) };
174                self.softmax_inner_slice_f16(slice, self.kind)?;
175            } else {
176                softmax_inner(view, self.kind);
177            }
178        }
179
180        Ok(tvec!(output.into_tvalue()))
181    }
182
183    fn eval_quant(&self, input: TValue) -> TractResult<TVec<TValue>> {
184        if self.kind == SoftmaxKind::LogSoftmax {
185            bail!("Quantized LogSoftmax is not supported")
186        }
187        let mut iterating_shape: TVec<usize> = input.shape().into();
188        let output_dt =
189            self.quant_output_dt.context("Quandized softmax eval with no output type")?;
190
191        for i in 0..iterating_shape.len() {
192            if self.axes.contains(&i) {
193                iterating_shape[i] = 1
194            }
195        }
196
197        // All operations will be done in u8, we will cast the result appropriately afterward.
198        let src_is_signed = input.datum_type().is_signed();
199        let out_is_signed = output_dt.is_signed();
200        let in_qp = input.datum_type().qparams().unwrap(); // Checked as we are in the quant case
201        let out_qp = output_dt.qparams().unwrap(); // Checked as we are in the quant case
202        let mut output = unsafe { input.into_tensor().into_array_unchecked::<u8>() };
203
204        for it_coords in tract_ndarray::indices(&*iterating_shape) {
205            let mut view = output.view_mut();
206            for ix in 0..iterating_shape.len() {
207                if !self.axes.contains(&ix) {
208                    view.collapse_axis(Axis(ix), it_coords[ix]);
209                }
210            }
211            softmax_quant_inner(view, src_is_signed, in_qp, out_is_signed, out_qp);
212        }
213
214        let mut output_tensor = output.into_tensor();
215        unsafe { output_tensor.set_datum_type(output_dt) };
216        Ok(tvec!(output_tensor.into_tvalue()))
217    }
218
219    fn softmax_inner_slice_f16(&self, slice: &mut [f16], kind: SoftmaxKind) -> TractResult<()> {
220        let max = (tract_linalg::ops().max_f16)().run(slice)?;
221        match kind {
222            SoftmaxKind::Softmax(exp_impl) => {
223                let sum = match exp_impl {
224                    SoftmaxExp::Libc => {
225                        let mut s = f16::zero();
226                        slice.iter_mut().for_each(|x| {
227                            *x = (*x - max).exp();
228                            s += *x;
229                        });
230                        s
231                    }
232                    SoftmaxExp::FastCompact => (tract_linalg::ops().softmax2_fastcompact_f16)()
233                        .run_with_params(slice, max)?,
234                };
235                let rsum = sum.recip();
236                (tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, rsum)?;
237            }
238            SoftmaxKind::LogSoftmax => {
239                let mut exp_sum = f16::zero();
240                slice.iter_mut().for_each(|x| {
241                    *x -= max;
242                    exp_sum += x.exp();
243                });
244                let log_sum = exp_sum.ln();
245                slice.iter_mut().for_each(|x| *x -= log_sum);
246            }
247        }
248        Ok(())
249    }
250
251    fn softmax_inner_slice_f32(&self, slice: &mut [f32], kind: SoftmaxKind) -> TractResult<()> {
252        let max = (tract_linalg::ops().max_f32)().run(slice)?;
253        match kind {
254            SoftmaxKind::Softmax(exp_impl) => {
255                let sum = match exp_impl {
256                    SoftmaxExp::Libc => {
257                        let mut s = f32::zero();
258                        slice.iter_mut().for_each(|x| {
259                            *x = (*x - max).exp();
260                            s += *x;
261                        });
262                        s
263                    }
264                    SoftmaxExp::FastCompact => (tract_linalg::ops().softmax2_fastcompact_f32)()
265                        .run_with_params(slice, max)?,
266                };
267                let rsum = sum.recip();
268                (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?;
269            }
270            SoftmaxKind::LogSoftmax => {
271                let mut exp_sum = f32::zero();
272                slice.iter_mut().for_each(|x| {
273                    *x -= max;
274                    exp_sum += x.exp();
275                });
276                let log_sum = exp_sum.ln();
277                slice.iter_mut().for_each(|x| *x -= log_sum);
278            }
279        }
280        Ok(())
281    }
282}
283
284fn softmax_inner<T: Float + Datum + std::iter::Sum, D: Dimension>(
285    mut view: ArrayViewMut<T, D>,
286    kind: SoftmaxKind,
287) {
288    let max =
289        *view.iter().max_by(|i, j| i.partial_cmp(j).unwrap_or(std::cmp::Ordering::Less)).unwrap();
290    view.mapv_inplace(|x| x - max);
291    let exp_sum = view.iter().map(|&x| x.exp()).sum();
292    match kind {
293        SoftmaxKind::Softmax(_) => {
294            view.mapv_inplace(|x| x.exp() / exp_sum);
295        }
296        SoftmaxKind::LogSoftmax => {
297            let log_sum = exp_sum.ln();
298            view.mapv_inplace(|x| x - log_sum);
299        }
300    }
301}
302
/// Fixed-point softmax over a single lane, stored as raw u8 bytes.
///
/// `src_is_signed`/`out_is_signed` tell whether the bytes actually hold i8
/// values reinterpreted as u8; `in_qp`/`out_qp` give the quantization scales.
/// The computation follows the gemmlowp-style fixed-point pipeline:
/// exp in Q5.26 -> Q0.31, sum in Q12.19, reciprocal in Q0.31, output rescale.
fn softmax_quant_inner<D: Dimension>(
    mut view: ArrayViewMut<u8, D>,
    src_is_signed: bool,
    in_qp: QParams,
    out_is_signed: bool,
    out_qp: QParams,
) {
    // Decompose both scales into multiplier + power-of-two shift.
    let (_, in_scale) = in_qp.zp_scale();
    let (scale_in_multiplier, scale_in_shift) = convert_scale_to_mult_shift(in_scale).unwrap();
    let (_, out_scale) = out_qp.zp_scale();
    let (scale_out_multiplier, scale_out_shift) = convert_scale_to_mult_shift(out_scale).unwrap();
    // Shift needed to land the scaled difference in Q5.26.
    let shift = 26 - scale_in_shift;

    // Compute the exponentials of x - max
    let mut buffer = vec![0_i32; view.len()];

    // Handle the case where we considered an i8 as an u8 and still get the right x - max:
    // adding 128 (wrapping) maps i8 order onto u8 order.
    let safe_u8 = if src_is_signed { |x: &u8| x.wrapping_add(128) } else { |x: &u8| *x };

    let max = view.iter().map(safe_u8).max().unwrap();
    view.iter().zip(buffer.iter_mut()).for_each(|(x, exp)| {
        // Always <= 0, so exp_on_negative_values applies.
        let input_diff = safe_u8(x) as i32 - max as i32;

        // We scale the input to be in Q5_26 (multiplier == 0 means the scale
        // is an exact power of two, so only the shift is needed).
        let scaled_input_diff = if scale_in_multiplier != 0 {
            saturating_rounding_multiply_by_pot(
                saturating_rounding_doubling_high_mul(input_diff, scale_in_multiplier),
                shift as i32,
            )
        } else {
            saturating_rounding_multiply_by_pot(input_diff, shift as i32)
        };

        // It expects an input in Q5_26 and returns an output in Q0_31
        *exp = exp_on_negative_values(scaled_input_diff);
    });

    // Compute sum of exp
    // The sum is stored as a Q12_19, that's why we need to rescale from Q0_31 to Q12_19 before summing.
    let sum_of_exp = buffer.iter().map(|it| rescale(*it, 0, 12)).sum();

    // Compute 1/sum_of_exp
    // The result of this function is in Q0_31
    let (inv_sum_of_exp, num_bits_over_unit) = get_reciprocal(sum_of_exp, 12);

    // Compute the exponent value needed to be in Q24_8 before the final rescaling
    let exponent = num_bits_over_unit as isize + 31 - 8;

    view.iter_mut().zip(buffer.iter()).for_each(|(it, exp)| {
        // Compute the product of exp * 1/sum_of_exp and scale the result in Q24_8
        let unsat_output = rounding_divide_by_pot(
            saturating_rounding_doubling_high_mul(inv_sum_of_exp, *exp),
            exponent as i32,
        );

        // Scale the final result into the output scale range (again, a zero
        // multiplier means the output scale is an exact power of two).
        let unsat_scaled_output = {
            if scale_out_multiplier != 0 {
                let (inv_multiplier, num_bits) = get_reciprocal(scale_out_multiplier, 1);
                rounding_divide_by_pot(
                    saturating_rounding_doubling_high_mul(unsat_output, inv_multiplier),
                    (8 - scale_out_shift - 1 - num_bits as isize) as i32,
                )
            } else {
                rounding_divide_by_pot(unsat_output, (8 - scale_out_shift) as i32)
            }
        };

        // Return the final result by clipping the computed value within its range
        // and casting it to u8 in any case.
        #[allow(unknown_lints, unnecessary_transmutes)]
        if out_is_signed {
            // SAFETY: transmuting i8 -> u8 is a bit-preserving reinterpretation.
            *it = unsafe {
                std::mem::transmute::<i8, u8>(i32::max(
                    i32::min(unsat_scaled_output, i8::MAX as i32),
                    i8::MIN as i32,
                ) as i8)
            };
        } else {
            *it = i32::max(i32::min(unsat_scaled_output, u8::MAX as i32), u8::MIN as i32) as u8;
        }
    });
}
386
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::nn::DataFormat::NCHW;
    use anyhow::Result;
    use num_traits::PrimInt;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use tract_data::internal::QParams::ZpScale;

    // Tolerance is the sum of both quantization scales: one step of error on
    // each side of the quantize/dequantize round-trip.
    fn assert_is_close(found: f32, expected: f32, in_dt: DatumType, out_dt: DatumType) {
        let (_, in_epsilon) = in_dt.zp_scale();
        let (_, out_epsilon) = out_dt.zp_scale();
        let epsilon = in_epsilon + out_epsilon;
        let error = (found - expected).abs();
        assert!(
            error <= epsilon,
            "epsilon eq failed: |{found:?}-{expected:?}|={error} should be <= {epsilon}"
        );
    }

    // Generate a random tensor with a quantized datum type
    fn qtensor<T: PrimInt + Datum + Arbitrary>(shape: Vec<usize>) -> BoxedStrategy<Tensor> {
        let len = shape.iter().product::<usize>();
        let dt = q_datum::<T>((0.0001f32..0.1).boxed());
        (vec(any::<T>(), len..=len), dt)
            .prop_map(move |(vec, dt)| (ArrayD::from_shape_vec(shape.clone(), vec).unwrap(), dt))
            .prop_map(move |(array, dt)| {
                let mut tensor = array.into_tensor();
                unsafe { tensor.set_datum_type(dt) };
                tensor
            })
            .boxed()
    }

    // Generate a random quantized datum type: either an exact power-of-two
    // scale (the fixed-point fast path) or an arbitrary scale from `range`.
    fn q_datum<T: PrimInt + Datum>(range: BoxedStrategy<f32>) -> BoxedStrategy<DatumType> {
        let max_integer_bits = std::mem::size_of::<T>() * 8 - T::datum_type().is_signed() as usize;
        prop_oneof![
            (1usize..max_integer_bits).prop_map(|fixed_point| { 2f32.powi(-(fixed_point as i32)) }),
            range
        ]
        .prop_map(|scale| {
            if T::datum_type().is_signed() {
                DatumType::QI8(ZpScale { zero_point: 0, scale })
            } else {
                DatumType::QU8(ZpScale { zero_point: 0, scale })
            }
        })
        .boxed()
    }

    // End-to-end problem: quantized softmax output vs float reference.
    #[derive(Debug)]
    struct SoftmaxProblem {
        data: Tensor,
        axes: TVec<usize>,
        output_dt: DatumType,
    }

    impl SoftmaxProblem {
        fn check(&self) -> Result<()> {
            let inputs = tvec!(self.data.clone().into_tvalue());
            let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float());
            let softmax =
                Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() };

            // Compute quantized output
            let result = softmax.eval(inputs)?;
            let result = args_1!(result);
            let result_float = result.cast_to::<f32>()?;

            // Compute reference output
            let input_float = self.data.cast_to::<f32>()?;
            let inputs_float = tvec!(input_float.into_owned().into_tvalue());
            let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() };
            let reference_float = softmax_float.eval(inputs_float)?;
            let reference_array = args_1!(reference_float);
            let reference = reference_array.to_dense_array_view::<f32>()?;

            result_float
                .to_dense_array_view::<f32>()?
                .iter()
                .zip(reference.iter())
                .for_each(|(a, b)| assert_is_close(*a, *b, self.data.datum_type(), self.output_dt));
            Ok(())
        }
    }

    impl Arbitrary for SoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<SoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            // Small NCHW shapes, one softmax axis, mixed i8/u8 in and out.
            (1usize..2, 1usize..2, 1usize..5, 1usize..5, 0usize..4)
                .prop_flat_map(|(n, c, h, w, axis)| {
                    let shape_in: Vec<usize> =
                        NCHW.from_n_c_hw(n, c, [h, w]).unwrap().shape.to_vec();
                    (
                        prop_oneof![qtensor::<i8>(shape_in.clone()), qtensor::<u8>(shape_in)],
                        Just(tvec![axis]),
                        prop_oneof![
                            q_datum::<u8>((0.008f32..0.1).boxed()),
                            q_datum::<i8>((0.008f32..0.1).boxed())
                        ],
                    )
                })
                .prop_map(|(data, axes, output_dt)| SoftmaxProblem { data, axes, output_dt })
                .boxed()
        }
    }

    // Lane-level problem: softmax_quant_inner vs float softmax_inner,
    // requantized, allowing at most one quantization step of difference.
    #[derive(Debug)]
    pub struct InnerSoftmaxProblem {
        in_qp: QParams,
        out_qp: QParams,
        data: Vec<i8>,
    }

    impl InnerSoftmaxProblem {
        fn check(&self) -> Result<()> {
            let quantized = self.quantized();
            let reference = self.reference();
            assert!(quantized.iter().zip(reference.iter()).all(|(quantized, expected)| {
                let abs_diff = if *quantized > *expected {
                    quantized - *expected
                } else {
                    expected - *quantized
                };
                abs_diff <= 1
            }));
            Ok(())
        }

        // Float reference: dequantize, run float softmax, requantize to u8.
        fn reference(&self) -> Vec<u8> {
            let (in_zero_point, in_scale) = self.in_qp.zp_scale();
            let (out_zero_point, out_scale) = self.out_qp.zp_scale();
            let in_float =
                self.data.iter().map(|it| (*it as f32 - in_zero_point as f32) * in_scale).collect();
            let mut in_float_array = Array1::from_vec(in_float);
            softmax_inner(in_float_array.view_mut(), SoftmaxKind::default());
            let rescaled_output = in_float_array
                .iter()
                .map(|it| {
                    ((*it / out_scale).round() as i32 + out_zero_point)
                        .max(u8::MIN as i32)
                        .min(u8::MAX as i32) as u8
                })
                .collect();
            rescaled_output
        }

        // Fixed-point path under test, i8 bytes reinterpreted as u8.
        fn quantized(&self) -> Vec<u8> {
            let in_data: Vec<u8> = unsafe { std::mem::transmute(self.data.clone()) };
            let mut in_array = Array1::from_vec(in_data);
            softmax_quant_inner(in_array.view_mut(), true, self.in_qp, false, self.out_qp);
            in_array.to_vec()
        }
    }

    impl Arbitrary for InnerSoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<InnerSoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (
                prop_oneof![
                    q_datum::<i8>((0.0001f32..0.01).boxed()),
                    q_datum::<u8>((0.0001f32..0.01).boxed())
                ],
                prop_oneof![
                    q_datum::<u8>((0.008f32..0.1).boxed()),
                    q_datum::<i8>((0.008f32..0.1).boxed())
                ],
                vec(any::<i8>(), 1..10),
            )
                .prop_map(|(in_qp, out_qp, data)| InnerSoftmaxProblem {
                    in_qp: in_qp.qparams().unwrap(),
                    out_qp: out_qp.qparams().unwrap(),
                    data,
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_inner_prop(pb in any::<InnerSoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_prop(pb in any::<SoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    #[test]
    // We test QU8 -> QU8
    fn test_softmax_trivial_0() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 }); // Q3_5
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 }); // Q0_8;
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QI8 -> QU8
    fn test_softmax_trivial_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 }); // Q3_4
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 }); // Q0_8;
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QI8 -> QI8
    fn test_softmax_trivial_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 }); // Q3_4
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 }); // Q0_7;
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, -4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QU8 -> QI8
    fn test_softmax_trivial_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 }); // Q3_5
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 }); // Q0_7;
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![2], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.5 }); // Q6_1
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5 }); // Q7_1
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0001 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.008 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.6220956 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5187921 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[13_u8, 218])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let data = vec![0_i8, 1];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_not_pow_2_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.7298456 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    // Fails but the difference is quite low and the sum still give exactly one:
    // quantized: 110(0.88), 15(0.12)
    // expected: 112(0.896), 13(0.104)
    fn test_inner_softmax_not_pow_2_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.2123116 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.008 };
        let data = vec![118i8, 108];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    // Fails but the difference is quite low and the sum still give exactly one:
    // quantized: 40(0.625), 24(0.375)
    // expected: 42(0.65625), 22(0.34375)
    fn test_inner_softmax_not_pow_2_3() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.33034274 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.015625 };
        let data = vec![45i8, 43];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }
}