tract_core/ops/nn/softmax/mod.rs

1mod fixedpoint;
2pub mod math;
3
4use math::{
5    convert_scale_to_mult_shift, exp_on_negative_values, get_reciprocal, rescale,
6    rounding_divide_by_pot, saturating_rounding_doubling_high_mul,
7    saturating_rounding_multiply_by_pot,
8};
9use num_traits::Float;
10use std::fmt::Debug;
11use tract_num_traits::Zero;
12
13use crate::internal::*;
14use ndarray::prelude::*;
15
/// Selects how the exponential is evaluated by the contiguous f32/f16 softmax kernels.
///
/// Other paths (f64 input, non-contiguous lanes) always use the float type's `exp`.
#[derive(Debug, Clone, Copy, PartialEq, Hash, Default)]
pub enum SoftmaxExp {
    /// Use the float type's own `exp` method. This is the default.
    #[default]
    Libc,
    /// Use the fast approximate exponential kernels from tract-linalg, based on
    /// https://nic.schraudolph.org/pubs/Schraudolph99.pdf
    FastCompact,
}
23
/// Softmax operator: normalizes the input along `axes` so that each lane sums to one.
// NOTE: field order is significant — it fixes the argument order of the derived `new`
// constructor and the derived `Hash` output.
#[derive(Debug, Clone, new, Hash, Default)]
pub struct Softmax {
    /// Axes the softmax is computed over; every combination of the remaining axes is an
    /// independent softmax lane.
    pub axes: TVec<usize>,
    /// Output type for quantized inputs. `output_facts` requires it to be `None` for float
    /// inputs and a quantized type for quantized inputs.
    pub quant_output_dt: Option<DatumType>,
    /// Exponential implementation used by the contiguous f32/f16 float paths.
    pub exp: SoftmaxExp,
}
30
31impl Op for Softmax {
32    fn name(&self) -> Cow<str> {
33        "Softmax".into()
34    }
35
36    fn info(&self) -> TractResult<Vec<String>> {
37        Ok(vec![format!("Axis: {:?}", self.axes), format!("Exp impl: {:?}", self.exp)])
38    }
39
40    op_as_typed_op!();
41}
42
43impl TypedOp for Softmax {
44    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
45        let dt = inputs[0].datum_type;
46        if dt.is_float() {
47            ensure!(
48                self.quant_output_dt.is_none(),
49                "Float softmax should not have quant_output_dt, have {:?}",
50                self.quant_output_dt
51            );
52        } else if dt.is_quantized() {
53            ensure!(
54                self.quant_output_dt.map(|q| q.is_quantized()).unwrap_or(false),
55                "Quantized softmax should have a quantized output type (got {:?})",
56                self.quant_output_dt
57            );
58        } else {
59            bail!(
60                "Unsupported datum type in softmax: input type {:?}, output type {:?}",
61                dt,
62                self.quant_output_dt
63            );
64        }
65
66        let fact = self.quant_output_dt.unwrap_or(dt).fact(inputs[0].shape.clone());
67        Ok(tvec!(fact))
68    }
69
70    fn axes_mapping(
71        &self,
72        inputs: &[&TypedFact],
73        outputs: &[&TypedFact],
74    ) -> TractResult<AxesMapping> {
75        AxesMapping::natural(inputs, outputs)
76    }
77
78    fn change_axes(
79        &self,
80        model: &TypedModel,
81        node: &TypedNode,
82        _io: InOut,
83        change: &AxisOp,
84    ) -> TractResult<Option<AxisChangeConsequence>> {
85        let axes: Option<TVec<usize>> =
86            self.axes.iter().map(|it| change.transform_axis(*it)).collect();
87        if let Some(axes) = axes {
88            Ok(Some(AxisChangeConsequence::new(
89                model,
90                node,
91                Some(Box::new(Softmax { axes, ..self.clone() })),
92                change,
93            )))
94        } else {
95            Ok(None)
96        }
97    }
98
99    as_op!();
100}
101
102impl EvalOp for Softmax {
103    fn is_stateless(&self) -> bool {
104        true
105    }
106
107    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
108        let input = args_1!(inputs);
109        let dt = input.datum_type();
110
111        let output = match dt {
112            DatumType::F64 => self.eval_t::<f64>(input)?,
113            DatumType::F32 => self.eval_t::<f32>(input)?,
114            DatumType::F16 => self.eval_t::<f16>(input)?,
115            DatumType::QI8(_) | DatumType::QU8(_) => self.eval_quant(input)?,
116            dt => bail!("Unsupported type {dt:?}"),
117        };
118        Ok(output)
119    }
120}
121
122impl Softmax {
123    fn eval_t<T>(&self, input: TValue) -> TractResult<TVec<TValue>>
124    where
125        T: Float + Datum + std::iter::Sum,
126    {
127        let mut iterating_shape: TVec<usize> = input.shape().into();
128
129        for i in 0..iterating_shape.len() {
130            if self.axes.contains(&i) {
131                iterating_shape[i] = 1
132            }
133        }
134
135        let mut output = input.into_tensor();
136        let mut view = output.to_array_view_mut::<T>()?;
137
138        for it_coords in tract_ndarray::indices(&*iterating_shape) {
139            let mut view = view.view_mut();
140            for ix in 0..iterating_shape.len() {
141                if !self.axes.contains(&ix) {
142                    view.collapse_axis(Axis(ix), it_coords[ix]);
143                }
144            }
145            if let Some(slice) =
146                view.as_slice_mut().filter(|_| T::datum_type() == f32::datum_type())
147            {
148                let slice: &mut [f32] = unsafe { std::mem::transmute(slice) };
149                self.softmax_inner_slice_f32(slice)?;
150            } else if let Some(slice) =
151                view.as_slice_mut().filter(|_| T::datum_type() == f16::datum_type())
152            {
153                let slice: &mut [f16] = unsafe { std::mem::transmute(slice) };
154                self.softmax_inner_slice_f16(slice)?;
155            } else {
156                softmax_inner(view);
157            }
158        }
159
160        Ok(tvec!(output.into_tvalue()))
161    }
162
163    fn eval_quant(&self, input: TValue) -> TractResult<TVec<TValue>> {
164        let mut iterating_shape: TVec<usize> = input.shape().into();
165        let output_dt =
166            self.quant_output_dt.context("Quandized softmax eval with no output type")?;
167
168        for i in 0..iterating_shape.len() {
169            if self.axes.contains(&i) {
170                iterating_shape[i] = 1
171            }
172        }
173
174        // All operations will be done in u8, we will cast the result appropriately afterward.
175        let src_is_signed = input.datum_type().is_signed();
176        let out_is_signed = output_dt.is_signed();
177        let in_qp = input.datum_type().qparams().unwrap(); // Checked as we are in the quant case
178        let out_qp = output_dt.qparams().unwrap(); // Checked as we are in the quant case
179        let mut output = unsafe { input.into_tensor().into_array_unchecked::<u8>() };
180
181        for it_coords in tract_ndarray::indices(&*iterating_shape) {
182            let mut view = output.view_mut();
183            for ix in 0..iterating_shape.len() {
184                if !self.axes.contains(&ix) {
185                    view.collapse_axis(Axis(ix), it_coords[ix]);
186                }
187            }
188            softmax_quant_inner(view, src_is_signed, in_qp, out_is_signed, out_qp);
189        }
190
191        let mut output_tensor = output.into_tensor();
192        unsafe { output_tensor.set_datum_type(output_dt) };
193        Ok(tvec!(output_tensor.into_tvalue()))
194    }
195
196    fn softmax_inner_slice_f16(&self, slice: &mut [f16]) -> TractResult<()> {
197        let max = (tract_linalg::ops().max_f16)().run(slice)?;
198        let sum = match self.exp {
199            SoftmaxExp::Libc => {
200                let mut s = f16::zero();
201                for x in slice.iter_mut() {
202                    let y = (*x - max).exp();
203                    s += y;
204                    *x = y;
205                }
206                s
207            }
208            SoftmaxExp::FastCompact => {
209                (tract_linalg::ops().softmax2_fastcompact_f16)().run_with_params(slice, max)?
210            }
211        };
212        let rsum = sum.recip();
213        (tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, rsum)?;
214        Ok(())
215    }
216
217    fn softmax_inner_slice_f32(&self, slice: &mut [f32]) -> TractResult<()> {
218        let max = (tract_linalg::ops().max_f32)().run(slice)?;
219        let sum = match self.exp {
220            SoftmaxExp::Libc => {
221                let mut s = 0f32;
222                for x in slice.iter_mut() {
223                    let y = (*x - max).exp();
224                    s += y;
225                    *x = y;
226                }
227                s
228            }
229            SoftmaxExp::FastCompact => {
230                (tract_linalg::ops().softmax2_fastcompact_f32)().run_with_params(slice, max)?
231            }
232        };
233        let rsum = sum.recip();
234        (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?;
235        Ok(())
236    }
237}
238
239fn softmax_inner<T: Float + Datum + std::iter::Sum, D: Dimension>(mut view: ArrayViewMut<T, D>) {
240    let max =
241        *view.iter().max_by(|i, j| i.partial_cmp(j).unwrap_or(std::cmp::Ordering::Less)).unwrap();
242    view.mapv_inplace(|x| (x - max).exp());
243    let exp_sum = view.iter().copied().sum();
244    view.mapv_inplace(|x| x / exp_sum);
245}
246
247fn softmax_quant_inner<D: Dimension>(
248    mut view: ArrayViewMut<u8, D>,
249    src_is_signed: bool,
250    in_qp: QParams,
251    out_is_signed: bool,
252    out_qp: QParams,
253) {
254    let (_, in_scale) = in_qp.zp_scale();
255    let (scale_in_multiplier, scale_in_shift) = convert_scale_to_mult_shift(in_scale).unwrap();
256    let (_, out_scale) = out_qp.zp_scale();
257    let (scale_out_multiplier, scale_out_shift) = convert_scale_to_mult_shift(out_scale).unwrap();
258    let shift = 26 - scale_in_shift;
259
260    // Compute the exponentials x - max
261    let mut buffer = vec![0_i32; view.len()];
262
263    // Handle the case were we considered an i8 as an u8 and still get the right x - max.
264    let safe_u8 = if src_is_signed { |x: &u8| x.wrapping_add(128) } else { |x: &u8| *x };
265
266    let max = view.iter().map(safe_u8).max().unwrap();
267    view.iter().zip(buffer.iter_mut()).for_each(|(x, exp)| {
268        let input_diff = safe_u8(x) as i32 - max as i32;
269
270        // We scale the input to be in Q5_26
271        let scaled_input_diff = if scale_in_multiplier != 0 {
272            saturating_rounding_multiply_by_pot(
273                saturating_rounding_doubling_high_mul(input_diff, scale_in_multiplier),
274                shift as i32,
275            )
276        } else {
277            saturating_rounding_multiply_by_pot(input_diff, shift as i32)
278        };
279
280        // It expects an input from Q5_26 and returns an output in Q0_31
281        *exp = exp_on_negative_values(scaled_input_diff);
282    });
283
284    // Compute sum of exp
285    // The sum is stored as an Q12_19 that's why we need to recale from Q0_31 to Q12_19 before summing.
286    let sum_of_exp = buffer.iter().map(|it| rescale(*it, 0, 12)).sum();
287
288    // Compute 1/sum_of_exp
289    // The result of this function is in Q0_31
290    let (inv_sum_of_exp, num_bits_over_unit) = get_reciprocal(sum_of_exp, 12);
291
292    // Compute the exponent value needed to be in Q24_8 before the final rescaling
293    let exponent = num_bits_over_unit as isize + 31 - 8;
294
295    view.iter_mut().zip(buffer.iter()).for_each(|(it, exp)| {
296        // Compute the product of exp * 1/sum_of_exp and scale the result in Q24_8
297        let unsat_output = rounding_divide_by_pot(
298            saturating_rounding_doubling_high_mul(inv_sum_of_exp, *exp),
299            exponent as i32,
300        );
301
302        // Scale the final result in the output scale range
303        let unsat_scaled_output = {
304            if scale_out_multiplier != 0 {
305                let (inv_multiplier, num_bits) = get_reciprocal(scale_out_multiplier, 1);
306                rounding_divide_by_pot(
307                    saturating_rounding_doubling_high_mul(unsat_output, inv_multiplier),
308                    (8 - scale_out_shift - 1 - num_bits as isize) as i32,
309                )
310            } else {
311                rounding_divide_by_pot(unsat_output, (8 - scale_out_shift) as i32)
312            }
313        };
314
315        // Return the final result by clipping the computed value within its range
316        // and casting it to u8 in any case.
317        if out_is_signed {
318            *it = unsafe {
319                std::mem::transmute::<i8, u8>(i32::max(
320                    i32::min(unsat_scaled_output, i8::MAX as i32),
321                    i8::MIN as i32,
322                ) as i8)
323            };
324        } else {
325            *it = i32::max(
326                i32::min(unsat_scaled_output, u8::MAX as i32),
327                u8::MIN as i32,
328            ) as u8;
329        }
330    });
331}
332
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::nn::DataFormat::NCHW;
    use anyhow::Result;
    use num_traits::PrimInt;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use tract_data::internal::QParams::ZpScale;

    /// Asserts that `found` is within one quantization step of `expected`, where the step is
    /// the larger of the input and output scales.
    fn assert_is_close(found: f32, expected: f32, in_dt: DatumType, out_dt: DatumType) {
        let (_, in_epsilon) = in_dt.zp_scale();
        let (_, out_epsilon) = out_dt.zp_scale();
        let epsilon = f32::max(in_epsilon, out_epsilon);
        let error = (found - expected).abs();
        assert!(
            error <= epsilon,
            "epsilon eq failed: |{found:?}-{expected:?}|={error} should be <= {epsilon}"
        );
    }

    // Generate a random tensor of the given shape with a quantized datum type
    fn qtensor<T: PrimInt + Datum + Arbitrary>(shape: Vec<usize>) -> BoxedStrategy<Tensor> {
        let len = shape.iter().product::<usize>();
        let dt = q_datum::<T>((0.0001f32..0.1).boxed());
        (vec(any::<T>(), len..=len), dt)
            .prop_map(move |(vec, dt)| (ArrayD::from_shape_vec(shape.clone(), vec).unwrap(), dt))
            .prop_map(move |(array, dt)| {
                let mut tensor = array.into_tensor();
                unsafe { tensor.set_datum_type(dt) };
                tensor
            })
            .boxed()
    }

    // Generate a random quantized datum type (QI8 for signed T, QU8 otherwise) with a zero
    // point of 0 and either a power-of-two scale 2^-k or a scale drawn from `range`.
    fn q_datum<T: PrimInt + Datum>(range: BoxedStrategy<f32>) -> BoxedStrategy<DatumType> {
        let max_integer_bits = std::mem::size_of::<T>() * 8 - T::datum_type().is_signed() as usize;
        prop_oneof![
            (1usize..max_integer_bits).prop_map(|fixed_point| { 2f32.powi(-(fixed_point as i32)) }),
            range
        ]
        .prop_map(|scale| {
            if T::datum_type().is_signed() {
                DatumType::QI8(ZpScale { zero_point: 0, scale })
            } else {
                DatumType::QU8(ZpScale { zero_point: 0, scale })
            }
        })
        .boxed()
    }

    /// End-to-end check of the quantized op against the float op on the same input.
    #[derive(Debug)]
    struct SoftmaxProblem {
        data: Tensor,
        axes: TVec<usize>,
        output_dt: DatumType,
    }

    impl SoftmaxProblem {
        fn check(&self) -> Result<()> {
            let inputs = tvec!(self.data.clone().into_tvalue());
            let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float());
            let softmax =
                Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() };

            // Compute quantized output
            let result = softmax.eval(inputs)?;
            let result = args_1!(result);
            let result_float = result.cast_to::<f32>()?;

            // Compute reference output (float softmax on the input cast to f32)
            let input_float = self.data.cast_to::<f32>()?;
            let inputs_float = tvec!(input_float.into_owned().into_tvalue());
            let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() };
            let reference_float = softmax_float.eval(inputs_float)?;
            let reference_array = args_1!(reference_float);
            let reference = reference_array.to_array_view::<f32>()?;

            result_float
                .to_array_view::<f32>()?
                .iter()
                .zip(reference.iter())
                .for_each(|(a, b)| assert_is_close(*a, *b, self.data.datum_type(), self.output_dt));

            Ok(())
        }
    }

    impl Arbitrary for SoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<SoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (1usize..2, 1usize..2, 1usize..5, 1usize..5, 0usize..4)
                .prop_flat_map(|(n, c, h, w, axis)| {
                    let shape_in: Vec<usize> =
                        NCHW.from_n_c_hw(n, c, [h, w]).unwrap().shape.to_vec();
                    (
                        prop_oneof![qtensor::<i8>(shape_in.clone()), qtensor::<u8>(shape_in)],
                        Just(tvec![axis]),
                        prop_oneof![
                            q_datum::<u8>((0.008f32..0.1).boxed()),
                            q_datum::<i8>((0.008f32..0.1).boxed())
                        ],
                    )
                })
                .prop_map(|(data, axes, output_dt)| SoftmaxProblem { data, axes, output_dt })
                .boxed()
        }
    }

    /// Checks `softmax_quant_inner` on a single lane of i8 data against a float reference
    /// requantized to u8.
    #[derive(Debug)]
    pub struct InnerSoftmaxProblem {
        in_qp: QParams,
        out_qp: QParams,
        data: Vec<i8>,
    }

    impl InnerSoftmaxProblem {
        fn check(&self) -> Result<()> {
            let quantized = self.quantized();
            let reference = self.reference();
            for (&found, &expected) in quantized.iter().zip(reference.iter()) {
                assert!(
                    found.abs_diff(expected) <= 1,
                    "quantized {quantized:?} vs reference {reference:?}: {found} and {expected} differ by more than 1"
                );
            }
            Ok(())
        }

        fn reference(&self) -> Vec<u8> {
            let (in_zero_point, in_scale) = self.in_qp.zp_scale();
            let (out_zero_point, out_scale) = self.out_qp.zp_scale();
            let in_float =
                self.data.iter().map(|it| (*it as f32 - in_zero_point as f32) * in_scale).collect();
            let mut in_float_array = Array1::from_vec(in_float);
            softmax_inner(in_float_array.view_mut());
            let rescaled_output = in_float_array
                .iter()
                .map(|it| {
                    ((*it / out_scale).round() as i32 + out_zero_point)
                        .max(u8::MIN as i32)
                        .min(u8::MAX as i32) as u8
                })
                .collect();
            rescaled_output
        }

        fn quantized(&self) -> Vec<u8> {
            let in_data: Vec<u8> = unsafe { std::mem::transmute(self.data.clone()) };
            let mut in_array = Array1::from_vec(in_data);
            softmax_quant_inner(in_array.view_mut(), true, self.in_qp, false, self.out_qp);
            in_array.to_vec()
        }
    }

    impl Arbitrary for InnerSoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<InnerSoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (
                prop_oneof![
                    q_datum::<i8>((0.0001f32..0.01).boxed()),
                    q_datum::<u8>((0.0001f32..0.01).boxed())
                ],
                prop_oneof![
                    q_datum::<u8>((0.008f32..0.1).boxed()),
                    q_datum::<i8>((0.008f32..0.1).boxed())
                ],
                vec(any::<i8>(), 1..10),
            )
                .prop_map(|(in_qp, out_qp, data)| InnerSoftmaxProblem {
                    in_qp: in_qp.qparams().unwrap(),
                    out_qp: out_qp.qparams().unwrap(),
                    data,
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_inner_prop(pb in any::<InnerSoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_prop(pb in any::<SoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    #[test]
    // We test QU8 -> QU8
    fn test_softmax_trivial_0() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 }); // Q3_5
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 }); // Q0_8
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QI8 -> QU8
    fn test_softmax_trivial_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 }); // Q3_4
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 }); // Q0_8
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QI8 -> QI8
    fn test_softmax_trivial_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 }); // Q3_4
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 }); // Q0_7
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, -4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QU8 -> QI8
    fn test_softmax_trivial_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 }); // Q3_5
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 }); // Q0_7
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![2], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.5 }); // Q6_1
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5 }); // Q7_1
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0001 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.008 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.6220956 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5187921 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[13_u8, 218])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let data = vec![0_i8, 1];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_not_pow_2_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.7298456 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    // Fails, but the difference is quite low and the sum still gives exactly one:
    // quantized: 110(0.88), 15(0.12)
    // expected: 112(0.896), 13(0.104)
    fn test_inner_softmax_not_pow_2_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.2123116 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.008 };
        let data = vec![118i8, 108];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    // Fails, but the difference is quite low and the sum still gives exactly one:
    // quantized: 40(0.625), 24(0.375)
    // expected: 42(0.65625), 22(0.34375)
    fn test_inner_softmax_not_pow_2_3() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.33034274 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.015625 };
        let data = vec![45i8, 43];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }
}