tract_core/ops/nn/softmax/
mod.rs

mod fixedpoint;
pub mod math;

use math::{
    convert_scale_to_mult_shift, exp_on_negative_values, get_reciprocal, rescale,
    rounding_divide_by_pot, saturating_rounding_doubling_high_mul,
    saturating_rounding_multiply_by_pot,
};
use num_traits::Float;
use std::fmt::Debug;
use tract_num_traits::Zero;

use crate::internal::*;
use ndarray::prelude::*;

#[derive(Debug, Copy, Clone, Hash, PartialEq)]
pub enum SoftmaxKind {
    Softmax(SoftmaxExp),
    LogSoftmax,
}

impl Default for SoftmaxKind {
    fn default() -> Self {
        SoftmaxKind::Softmax(SoftmaxExp::default())
    }
}

#[derive(Debug, Copy, Clone, Hash, Default, PartialEq)]
pub enum SoftmaxExp {
    #[default]
    Libc,
    // https://nic.schraudolph.org/pubs/Schraudolph99.pdf
    FastCompact,
}
#[derive(Debug, Clone, new, Hash, Default)]
pub struct Softmax {
    pub axes: TVec<usize>,
    pub quant_output_dt: Option<DatumType>,
    pub kind: SoftmaxKind,
}

impl Op for Softmax {
    fn name(&self) -> StaticName {
        match self.kind {
            SoftmaxKind::Softmax(_) => "Softmax".into(),
            SoftmaxKind::LogSoftmax => "LogSoftmax".into(),
        }
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let mut infos = vec![format!("Axis: {:?}", self.axes)];
        if let SoftmaxKind::Softmax(exp) = self.kind {
            infos.push(format!("Exp impl: {exp:?}"))
        };
        Ok(infos)
    }

    op_as_typed_op!();
}

impl TypedOp for Softmax {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let dt = inputs[0].datum_type;
        if dt.is_float() {
            ensure!(
                self.quant_output_dt.is_none(),
                "Float softmax should not have quant_output_dt, have {:?}",
                self.quant_output_dt
            );
        } else if dt.is_quantized() {
            ensure!(
                self.quant_output_dt.map(|q| q.is_quantized()).unwrap_or(false),
                "Quantized softmax should have a quantized output type (got {:?})",
                self.quant_output_dt
            );
        } else {
            bail!(
                "Unsupported datum type in softmax: input type {:?}, output type {:?}",
                dt,
                self.quant_output_dt
            );
        }

        let fact = self.quant_output_dt.unwrap_or(dt).fact(inputs[0].shape.clone());
        Ok(tvec!(fact))
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        AxesMapping::natural(inputs, outputs)
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let axes: Option<TVec<usize>> =
            self.axes.iter().map(|it| change.transform_axis(*it)).collect();
        if let Some(axes) = axes {
            Ok(Some(AxisChangeConsequence::new(
                model,
                node,
                Some(Box::new(Softmax { axes, ..self.clone() })),
                change,
            )))
        } else {
            Ok(None)
        }
    }

    as_op!();
}

impl EvalOp for Softmax {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let dt = input.datum_type();

        let output = match dt {
            DatumType::F64 => self.eval_t::<f64>(input)?,
            DatumType::F32 => self.eval_t::<f32>(input)?,
            DatumType::F16 => self.eval_t::<f16>(input)?,
            DatumType::QI8(_) | DatumType::QU8(_) => self.eval_quant(input)?,
            dt => bail!("Unsupported type {dt:?}"),
        };
        Ok(output)
    }
}

impl Softmax {
    fn eval_t<T>(&self, input: TValue) -> TractResult<TVec<TValue>>
    where
        T: Float + Datum + std::iter::Sum,
    {
        let mut iterating_shape: TVec<usize> = input.shape().into();

        for i in 0..iterating_shape.len() {
            if self.axes.contains(&i) {
                iterating_shape[i] = 1
            }
        }
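        // Example: for an input of shape [2, 3, 4] with axes = [2], the
        // iterating_shape is [2, 3, 1]: each of the 2 * 3 index tuples below
        // selects one length-4 lane that gets normalized independently.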

        let mut output = input.into_tensor();
        let mut view = output.to_array_view_mut::<T>()?;

        for it_coords in tract_ndarray::indices(&*iterating_shape) {
            let mut view = view.view_mut();
            for ix in 0..iterating_shape.len() {
                if !self.axes.contains(&ix) {
                    view.collapse_axis(Axis(ix), it_coords[ix]);
                }
            }
            if let Some(slice) =
                view.as_slice_mut().filter(|_| T::datum_type() == f32::datum_type())
            {
                let slice: &mut [f32] = unsafe { std::mem::transmute(slice) };
                self.softmax_inner_slice_f32(slice, self.kind)?;
            } else if let Some(slice) =
                view.as_slice_mut().filter(|_| T::datum_type() == f16::datum_type())
            {
                let slice: &mut [f16] = unsafe { std::mem::transmute(slice) };
                self.softmax_inner_slice_f16(slice, self.kind)?;
            } else {
                softmax_inner(view, self.kind);
            }
        }

        Ok(tvec!(output.into_tvalue()))
    }

    fn eval_quant(&self, input: TValue) -> TractResult<TVec<TValue>> {
        if self.kind == SoftmaxKind::LogSoftmax {
            bail!("Quantized LogSoftmax is not supported")
        }
        let mut iterating_shape: TVec<usize> = input.shape().into();
        let output_dt =
            self.quant_output_dt.context("Quantized softmax eval with no output type")?;

        for i in 0..iterating_shape.len() {
            if self.axes.contains(&i) {
                iterating_shape[i] = 1
            }
        }

        // All operations are done on u8 storage; the result is cast back to the
        // appropriate type afterward.
        let src_is_signed = input.datum_type().is_signed();
        let out_is_signed = output_dt.is_signed();
        let in_qp = input.datum_type().qparams().unwrap(); // Checked as we are in the quant case
        let out_qp = output_dt.qparams().unwrap(); // Checked as we are in the quant case
        let mut output = unsafe { input.into_tensor().into_array_unchecked::<u8>() };

        for it_coords in tract_ndarray::indices(&*iterating_shape) {
            let mut view = output.view_mut();
            for ix in 0..iterating_shape.len() {
                if !self.axes.contains(&ix) {
                    view.collapse_axis(Axis(ix), it_coords[ix]);
                }
            }
            softmax_quant_inner(view, src_is_signed, in_qp, out_is_signed, out_qp);
        }

        let mut output_tensor = output.into_tensor();
        unsafe { output_tensor.set_datum_type(output_dt) };
        Ok(tvec!(output_tensor.into_tvalue()))
    }

    fn softmax_inner_slice_f16(&self, slice: &mut [f16], kind: SoftmaxKind) -> TractResult<()> {
        let max = (tract_linalg::ops().max_f16)().run(slice)?;
        match kind {
            SoftmaxKind::Softmax(exp_impl) => {
                let sum = match exp_impl {
                    SoftmaxExp::Libc => {
                        let mut s = f16::zero();
                        slice.iter_mut().for_each(|x| {
                            *x = (*x - max).exp();
                            s += *x;
                        });
                        s
                    }
                    SoftmaxExp::FastCompact => (tract_linalg::ops().softmax2_fastcompact_f16)()
                        .run_with_params(slice, max)?,
                };
                let rsum = sum.recip();
                (tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, rsum)?;
            }
            SoftmaxKind::LogSoftmax => {
                let mut exp_sum = f16::zero();
                slice.iter_mut().for_each(|x| {
                    *x -= max;
                    exp_sum += x.exp();
                });
                let log_sum = exp_sum.ln();
                slice.iter_mut().for_each(|x| *x -= log_sum);
            }
        }
        Ok(())
    }

    fn softmax_inner_slice_f32(&self, slice: &mut [f32], kind: SoftmaxKind) -> TractResult<()> {
        let max = (tract_linalg::ops().max_f32)().run(slice)?;
        match kind {
            SoftmaxKind::Softmax(exp_impl) => {
                let sum = match exp_impl {
                    SoftmaxExp::Libc => {
                        let mut s = f32::zero();
                        slice.iter_mut().for_each(|x| {
                            *x = (*x - max).exp();
                            s += *x;
                        });
                        s
                    }
                    SoftmaxExp::FastCompact => (tract_linalg::ops().softmax2_fastcompact_f32)()
                        .run_with_params(slice, max)?,
                };
                let rsum = sum.recip();
                (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?;
            }
            SoftmaxKind::LogSoftmax => {
                let mut exp_sum = f32::zero();
                slice.iter_mut().for_each(|x| {
                    *x -= max;
                    exp_sum += x.exp();
                });
                let log_sum = exp_sum.ln();
                slice.iter_mut().for_each(|x| *x -= log_sum);
            }
        }
        Ok(())
    }
}

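/// Generic fallback used when the lane is neither a contiguous f32 nor f16
/// slice (e.g. f64 or strided data). Computes
/// softmax(x)_i = exp(x_i - max) / sum_j exp(x_j - max), or for LogSoftmax
/// (x_i - max) - ln(sum_j exp(x_j - max)). Subtracting the max first keeps
/// every exp() argument non-positive, which avoids overflow.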
fn softmax_inner<T: Float + Datum + std::iter::Sum, D: Dimension>(
    mut view: ArrayViewMut<T, D>,
    kind: SoftmaxKind,
) {
    let max =
        *view.iter().max_by(|i, j| i.partial_cmp(j).unwrap_or(std::cmp::Ordering::Less)).unwrap();
    view.mapv_inplace(|x| x - max);
    let exp_sum = view.iter().map(|&x| x.exp()).sum();
    match kind {
        SoftmaxKind::Softmax(_) => {
            view.mapv_inplace(|x| x.exp() / exp_sum);
        }
        SoftmaxKind::LogSoftmax => {
            let log_sum = exp_sum.ln();
            view.mapv_inplace(|x| x - log_sum);
        }
    }
}

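/// Fixed-point softmax over a u8 view (i8 data is handled via the
/// `wrapping_add(128)` remapping below). Following the step comments in the
/// body: x - max is rescaled into Q5_26, exponentiated into Q0_31, summed in
/// Q12_19, inverted back to Q0_31, and the per-element product is finally
/// rescaled into the output quantization range.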
fn softmax_quant_inner<D: Dimension>(
    mut view: ArrayViewMut<u8, D>,
    src_is_signed: bool,
    in_qp: QParams,
    out_is_signed: bool,
    out_qp: QParams,
) {
    let (_, in_scale) = in_qp.zp_scale();
    let (scale_in_multiplier, scale_in_shift) = convert_scale_to_mult_shift(in_scale).unwrap();
    let (_, out_scale) = out_qp.zp_scale();
    let (scale_out_multiplier, scale_out_shift) = convert_scale_to_mult_shift(out_scale).unwrap();
    let shift = 26 - scale_in_shift;
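    // Why 26 - scale_in_shift: a raw difference of quantized values counts
    // units of in_scale, which convert_scale_to_mult_shift factors as
    // multiplier * 2^-scale_in_shift (with, presumably, a zero multiplier for
    // exact powers of two, cf. the `scale_in_multiplier != 0` test below).
    // Applying the multiplier and then shifting left by 26 - scale_in_shift
    // thus expresses the real-valued difference in Q5_26, the format
    // exp_on_negative_values expects.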

    // Compute the exponentials of x - max
    let mut buffer = vec![0_i32; view.len()];

    // Handle the case where we treat an i8 buffer as u8 and still get the right
    // x - max: adding 128 maps the i8 range monotonically onto the u8 range.
    let safe_u8 = if src_is_signed { |x: &u8| x.wrapping_add(128) } else { |x: &u8| *x };

    let max = view.iter().map(safe_u8).max().unwrap();
    view.iter().zip(buffer.iter_mut()).for_each(|(x, exp)| {
        let input_diff = safe_u8(x) as i32 - max as i32;

        // We scale the input to be in Q5_26
        let scaled_input_diff = if scale_in_multiplier != 0 {
            saturating_rounding_multiply_by_pot(
                saturating_rounding_doubling_high_mul(input_diff, scale_in_multiplier),
                shift as i32,
            )
        } else {
            saturating_rounding_multiply_by_pot(input_diff, shift as i32)
        };

        // exp_on_negative_values expects an input in Q5_26 and returns an
        // output in Q0_31.
        *exp = exp_on_negative_values(scaled_input_diff);
    });

    // Compute the sum of exps.
    // The sum is stored in Q12_19, which is why each term is rescaled from
    // Q0_31 to Q12_19 before summing.
    let sum_of_exp = buffer.iter().map(|it| rescale(*it, 0, 12)).sum();

    // Compute 1/sum_of_exp
    // The result of this function is in Q0_31
    let (inv_sum_of_exp, num_bits_over_unit) = get_reciprocal(sum_of_exp, 12);

    // Compute the exponent value needed to be in Q24_8 before the final rescaling
    let exponent = num_bits_over_unit as isize + 31 - 8;
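    // Sanity check on the arithmetic (assuming inv_sum_of_exp is the Q0_31
    // representation of (1/sum) * 2^num_bits_over_unit): the doubling-high-mul
    // below yields the raw product divided by 2^31, so dividing the product of
    // inv_sum_of_exp and a Q0_31 exp by 2^(num_bits_over_unit + 31 - 8) leaves
    // exp/sum scaled by 2^8, i.e. in Q24_8.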

    view.iter_mut().zip(buffer.iter()).for_each(|(it, exp)| {
        // Compute the product of exp * 1/sum_of_exp and scale the result in Q24_8
        let unsat_output = rounding_divide_by_pot(
            saturating_rounding_doubling_high_mul(inv_sum_of_exp, *exp),
            exponent as i32,
        );

        // Scale the final result into the output scale range
        let unsat_scaled_output = {
            if scale_out_multiplier != 0 {
                let (inv_multiplier, num_bits) = get_reciprocal(scale_out_multiplier, 1);
                rounding_divide_by_pot(
                    saturating_rounding_doubling_high_mul(unsat_output, inv_multiplier),
                    (8 - scale_out_shift - 1 - num_bits as isize) as i32,
                )
            } else {
                rounding_divide_by_pot(unsat_output, (8 - scale_out_shift) as i32)
            }
        };

        // Return the final result by clamping the computed value to its range
        // and casting it to u8 in any case.
        if out_is_signed {
            *it = unsat_scaled_output.clamp(i8::MIN as i32, i8::MAX as i32) as i8 as u8;
        } else {
            *it = unsat_scaled_output.clamp(u8::MIN as i32, u8::MAX as i32) as u8;
        }
    });
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::nn::DataFormat::NCHW;
    use anyhow::Result;
    use num_traits::PrimInt;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use tract_data::internal::QParams::ZpScale;

    fn assert_is_close(found: f32, expected: f32, in_dt: DatumType, out_dt: DatumType) {
        let (_, in_epsilon) = in_dt.zp_scale();
        let (_, out_epsilon) = out_dt.zp_scale();
        let epsilon = in_epsilon + out_epsilon;
        let error = (found - expected).abs();
        assert!(
            error <= epsilon,
            "epsilon eq failed: |{found:?}-{expected:?}|={error} should be <= {epsilon}"
        );
    }

    // Generate a random tensor with a quantized datum type
    fn qtensor<T: PrimInt + Datum + Arbitrary>(shape: Vec<usize>) -> BoxedStrategy<Tensor> {
        let len = shape.iter().product::<usize>();
        let dt = q_datum::<T>((0.0001f32..0.1).boxed());
        (vec(any::<T>(), len..=len), dt)
            .prop_map(move |(vec, dt)| (ArrayD::from_shape_vec(shape.clone(), vec).unwrap(), dt))
            .prop_map(move |(array, dt)| {
                let mut tensor = array.into_tensor();
                unsafe { tensor.set_datum_type(dt) };
                tensor
            })
            .boxed()
    }

    // Generate a random quantized datum type
    fn q_datum<T: PrimInt + Datum>(range: BoxedStrategy<f32>) -> BoxedStrategy<DatumType> {
        let max_integer_bits = std::mem::size_of::<T>() * 8 - T::datum_type().is_signed() as usize;
        prop_oneof![
            (1usize..max_integer_bits).prop_map(|fixed_point| { 2f32.powi(-(fixed_point as i32)) }),
            range
        ]
        .prop_map(|scale| {
            if T::datum_type().is_signed() {
                DatumType::QI8(ZpScale { zero_point: 0, scale })
            } else {
                DatumType::QU8(ZpScale { zero_point: 0, scale })
            }
        })
        .boxed()
    }
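
    // Note: the power-of-two arm above exercises softmax_quant_inner's
    // pure-shift path (convert_scale_to_mult_shift presumably returns a zero
    // multiplier for exact powers of two), while the `range` arm exercises
    // the multiplier-and-shift path.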

    #[derive(Debug)]
    struct SoftmaxProblem {
        data: Tensor,
        axes: TVec<usize>,
        output_dt: DatumType,
    }

    impl SoftmaxProblem {
        fn check(&self) -> Result<()> {
            let inputs = tvec!(self.data.clone().into_tvalue());
            let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float());
            let softmax =
                Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() };

            // Compute quantized output
            let result = softmax.eval(inputs)?;
            let result = args_1!(result);
            let result_float = result.cast_to::<f32>()?;

            // Compute reference output
            let input_float = self.data.cast_to::<f32>()?;
            let inputs_float = tvec!(input_float.into_owned().into_tvalue());
            let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() };
            let reference_float = softmax_float.eval(inputs_float)?;
            let reference_array = args_1!(reference_float);
            let reference = reference_array.to_array_view::<f32>()?;

            result_float
                .to_array_view::<f32>()?
                .iter()
                .zip(reference.iter())
                .for_each(|(a, b)| assert_is_close(*a, *b, self.data.datum_type(), self.output_dt));
            Ok(())
        }
    }

    impl Arbitrary for SoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<SoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (1usize..2, 1usize..2, 1usize..5, 1usize..5, 0usize..4)
                .prop_flat_map(|(n, c, h, w, axis)| {
                    let shape_in: Vec<usize> =
                        NCHW.from_n_c_hw(n, c, [h, w]).unwrap().shape.to_vec();
                    (
                        prop_oneof![qtensor::<i8>(shape_in.clone()), qtensor::<u8>(shape_in)],
                        Just(tvec![axis]),
                        prop_oneof![
                            q_datum::<u8>((0.008f32..0.1).boxed()),
                            q_datum::<i8>((0.008f32..0.1).boxed())
                        ],
                    )
                })
                .prop_map(|(data, axes, output_dt)| SoftmaxProblem { data, axes, output_dt })
                .boxed()
        }
    }

    #[derive(Debug)]
    pub struct InnerSoftmaxProblem {
        in_qp: QParams,
        out_qp: QParams,
        data: Vec<i8>,
    }

    impl InnerSoftmaxProblem {
        fn check(&self) -> Result<()> {
            let quantized = self.quantized();
            let reference = self.reference();
            assert!(quantized.iter().zip(reference.iter()).all(|(quantized, expected)| {
                let abs_diff = if *quantized > *expected {
                    quantized - *expected
                } else {
                    expected - *quantized
                };
                abs_diff <= 1
            }));
            Ok(())
        }

        fn reference(&self) -> Vec<u8> {
            let (in_zero_point, in_scale) = self.in_qp.zp_scale();
            let (out_zero_point, out_scale) = self.out_qp.zp_scale();
            let in_float =
                self.data.iter().map(|it| (*it as f32 - in_zero_point as f32) * in_scale).collect();
            let mut in_float_array = Array1::from_vec(in_float);
            softmax_inner(in_float_array.view_mut(), SoftmaxKind::default());
            let rescaled_output = in_float_array
                .iter()
                .map(|it| {
                    ((*it / out_scale).round() as i32 + out_zero_point)
                        .max(u8::MIN as i32)
                        .min(u8::MAX as i32) as u8
                })
                .collect();
            rescaled_output
        }

        fn quantized(&self) -> Vec<u8> {
            let in_data: Vec<u8> = self.data.iter().map(|&x| x as u8).collect();
            let mut in_array = Array1::from_vec(in_data);
            softmax_quant_inner(in_array.view_mut(), true, self.in_qp, false, self.out_qp);
            in_array.to_vec()
        }
    }

    impl Arbitrary for InnerSoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<InnerSoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (
                prop_oneof![
                    q_datum::<i8>((0.0001f32..0.01).boxed()),
                    q_datum::<u8>((0.0001f32..0.01).boxed())
                ],
                prop_oneof![
                    q_datum::<u8>((0.008f32..0.1).boxed()),
                    q_datum::<i8>((0.008f32..0.1).boxed())
                ],
                vec(any::<i8>(), 1..10),
            )
                .prop_map(|(in_qp, out_qp, data)| InnerSoftmaxProblem {
                    in_qp: in_qp.qparams().unwrap(),
                    out_qp: out_qp.qparams().unwrap(),
                    data,
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_inner_prop(pb in any::<InnerSoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_prop(pb in any::<SoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    #[test]
    // We test QU8 -> QU8
    fn test_softmax_trivial_0() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 }); // Q3_5
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 }); // Q0_8
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }
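
    // For the [0, 4] lane above, the inputs dequantize to [0.0, 0.125];
    // softmax gives roughly [0.469, 0.531], which requantizes at scale 2^-8
    // to about [120, 136] (the [0, 0] lane yields [128, 128]).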

    #[test]
    // We test QI8 -> QU8
    fn test_softmax_trivial_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 }); // Q3_4
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 }); // Q0_8
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QI8 -> QI8
    fn test_softmax_trivial_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 }); // Q3_4
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 }); // Q0_7
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, -4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QU8 -> QI8
    fn test_softmax_trivial_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 }); // Q3_5
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 }); // Q0_7
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![2], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.5 }); // Q6_1
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5 }); // Q7_1
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0001 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.008 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.6220956 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5187921 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[13_u8, 218])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let data = vec![0_i8, 1];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_not_pow_2_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.7298456 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    // Fails, but the difference is small and the sum still adds up to exactly one:
    // quantized: 110(0.88), 15(0.12)
    // expected: 112(0.896), 13(0.104)
    fn test_inner_softmax_not_pow_2_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.2123116 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.008 };
        let data = vec![118i8, 108];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    // Fails, but the difference is small and the sum still adds up to exactly one:
    // quantized: 40(0.625), 24(0.375)
    // expected: 42(0.65625), 22(0.34375)
    fn test_inner_softmax_not_pow_2_3() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.33034274 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.015625 };
        let data = vec![45i8, 43];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }
}