Skip to main content

tract_core/ops/nn/softmax/
mod.rs

1mod fixedpoint;
2pub mod math;
3
4use math::{
5    convert_scale_to_mult_shift, exp_on_negative_values, get_reciprocal, rescale,
6    rounding_divide_by_pot, saturating_rounding_doubling_high_mul,
7    saturating_rounding_multiply_by_pot,
8};
9use num_traits::Float;
10use std::fmt::Debug;
11use tract_num_traits::Zero;
12
13use crate::internal::*;
14use ndarray::prelude::*;
15
/// Implementation of the exponential used by the contiguous f32/f16 softmax kernels.
#[derive(Debug, Copy, Clone, Hash, Default, PartialEq)]
pub enum SoftmaxExp {
    /// The float type's own `exp` (the default).
    #[default]
    Libc,
    /// Fast approximate exponential from tract_linalg's "fastcompact" kernels, after
    /// <https://nic.schraudolph.org/pubs/Schraudolph99.pdf>
    FastCompact,
}
23
/// Softmax over one or several axes of its input.
///
/// Every combination of indices on the non-softmax axes is normalized independently.
/// NOTE: field order defines the argument order of the derived `Softmax::new`.
#[derive(Debug, Clone, new, Hash, Default)]
pub struct Softmax {
    /// Axes the softmax is computed over.
    pub axes: TVec<usize>,
    /// Output type for quantized inputs (must itself be quantized); must be `None` for
    /// float inputs.
    pub quant_output_dt: Option<DatumType>,
    /// Exponential implementation; only the contiguous f32/f16 paths consult it.
    pub exp: SoftmaxExp,
}
30
31impl Op for Softmax {
32    fn name(&self) -> Cow<str> {
33        "Softmax".into()
34    }
35
36    fn info(&self) -> TractResult<Vec<String>> {
37        Ok(vec![format!("Axis: {:?}", self.axes), format!("Exp impl: {:?}", self.exp)])
38    }
39
40    op_as_typed_op!();
41}
42
43impl TypedOp for Softmax {
44    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
45        let dt = inputs[0].datum_type;
46        if dt.is_float() {
47            ensure!(
48                self.quant_output_dt.is_none(),
49                "Float softmax should not have quant_output_dt, have {:?}",
50                self.quant_output_dt
51            );
52        } else if dt.is_quantized() {
53            ensure!(
54                self.quant_output_dt.map(|q| q.is_quantized()).unwrap_or(false),
55                "Quantized softmax should have a quantized output type (got {:?})",
56                self.quant_output_dt
57            );
58        } else {
59            bail!(
60                "Unsupported datum type in softmax: input type {:?}, output type {:?}",
61                dt,
62                self.quant_output_dt
63            );
64        }
65
66        let fact = self.quant_output_dt.unwrap_or(dt).fact(inputs[0].shape.clone());
67        Ok(tvec!(fact))
68    }
69
70    fn axes_mapping(
71        &self,
72        inputs: &[&TypedFact],
73        outputs: &[&TypedFact],
74    ) -> TractResult<AxesMapping> {
75        AxesMapping::natural(inputs, outputs)
76    }
77
78    fn change_axes(
79        &self,
80        model: &TypedModel,
81        node: &TypedNode,
82        _io: InOut,
83        change: &AxisOp,
84    ) -> TractResult<Option<AxisChangeConsequence>> {
85        let axes: Option<TVec<usize>> =
86            self.axes.iter().map(|it| change.transform_axis(*it)).collect();
87        if let Some(axes) = axes {
88            Ok(Some(AxisChangeConsequence::new(
89                model,
90                node,
91                Some(Box::new(Softmax { axes, ..self.clone() })),
92                change,
93            )))
94        } else {
95            Ok(None)
96        }
97    }
98
99    as_op!();
100}
101
102impl EvalOp for Softmax {
103    fn is_stateless(&self) -> bool {
104        true
105    }
106
107    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
108        let input = args_1!(inputs);
109        let dt = input.datum_type();
110
111        let output = match dt {
112            DatumType::F64 => self.eval_t::<f64>(input)?,
113            DatumType::F32 => self.eval_t::<f32>(input)?,
114            DatumType::F16 => self.eval_t::<f16>(input)?,
115            DatumType::QI8(_) | DatumType::QU8(_) => self.eval_quant(input)?,
116            dt => bail!("Unsupported type {dt:?}"),
117        };
118        Ok(output)
119    }
120}
121
122impl Softmax {
123    fn eval_t<T>(&self, input: TValue) -> TractResult<TVec<TValue>>
124    where
125        T: Float + Datum + std::iter::Sum,
126    {
127        let mut iterating_shape: TVec<usize> = input.shape().into();
128
129        for i in 0..iterating_shape.len() {
130            if self.axes.contains(&i) {
131                iterating_shape[i] = 1
132            }
133        }
134
135        let mut output = input.into_tensor();
136        let mut view = output.to_array_view_mut::<T>()?;
137
138        for it_coords in tract_ndarray::indices(&*iterating_shape) {
139            let mut view = view.view_mut();
140            for ix in 0..iterating_shape.len() {
141                if !self.axes.contains(&ix) {
142                    view.collapse_axis(Axis(ix), it_coords[ix]);
143                }
144            }
145            if let Some(slice) =
146                view.as_slice_mut().filter(|_| T::datum_type() == f32::datum_type())
147            {
148                let slice: &mut [f32] = unsafe { std::mem::transmute(slice) };
149                self.softmax_inner_slice_f32(slice)?;
150            } else if let Some(slice) =
151                view.as_slice_mut().filter(|_| T::datum_type() == f16::datum_type())
152            {
153                let slice: &mut [f16] = unsafe { std::mem::transmute(slice) };
154                self.softmax_inner_slice_f16(slice)?;
155            } else {
156                softmax_inner(view);
157            }
158        }
159
160        Ok(tvec!(output.into_tvalue()))
161    }
162
163    fn eval_quant(&self, input: TValue) -> TractResult<TVec<TValue>> {
164        let mut iterating_shape: TVec<usize> = input.shape().into();
165        let output_dt =
166            self.quant_output_dt.context("Quandized softmax eval with no output type")?;
167
168        for i in 0..iterating_shape.len() {
169            if self.axes.contains(&i) {
170                iterating_shape[i] = 1
171            }
172        }
173
174        // All operations will be done in u8, we will cast the result appropriately afterward.
175        let src_is_signed = input.datum_type().is_signed();
176        let out_is_signed = output_dt.is_signed();
177        let in_qp = input.datum_type().qparams().unwrap(); // Checked as we are in the quant case
178        let out_qp = output_dt.qparams().unwrap(); // Checked as we are in the quant case
179        let mut output = unsafe { input.into_tensor().into_array_unchecked::<u8>() };
180
181        for it_coords in tract_ndarray::indices(&*iterating_shape) {
182            let mut view = output.view_mut();
183            for ix in 0..iterating_shape.len() {
184                if !self.axes.contains(&ix) {
185                    view.collapse_axis(Axis(ix), it_coords[ix]);
186                }
187            }
188            softmax_quant_inner(view, src_is_signed, in_qp, out_is_signed, out_qp);
189        }
190
191        let mut output_tensor = output.into_tensor();
192        unsafe { output_tensor.set_datum_type(output_dt) };
193        Ok(tvec!(output_tensor.into_tvalue()))
194    }
195
196    fn softmax_inner_slice_f16(&self, slice: &mut [f16]) -> TractResult<()> {
197        let max = (tract_linalg::ops().max_f16)().run(slice)?;
198        let sum = match self.exp {
199            SoftmaxExp::Libc => {
200                let mut s = f16::zero();
201                for x in slice.iter_mut() {
202                    let y = (*x - max).exp();
203                    s += y;
204                    *x = y;
205                }
206                s
207            }
208            SoftmaxExp::FastCompact => {
209                (tract_linalg::ops().softmax2_fastcompact_f16)().run_with_params(slice, max)?
210            }
211        };
212        let rsum = sum.recip();
213        (tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, rsum)?;
214        Ok(())
215    }
216
217    fn softmax_inner_slice_f32(&self, slice: &mut [f32]) -> TractResult<()> {
218        let max = (tract_linalg::ops().max_f32)().run(slice)?;
219        let sum = match self.exp {
220            SoftmaxExp::Libc => {
221                let mut s = 0f32;
222                for x in slice.iter_mut() {
223                    let y = (*x - max).exp();
224                    s += y;
225                    *x = y;
226                }
227                s
228            }
229            SoftmaxExp::FastCompact => {
230                (tract_linalg::ops().softmax2_fastcompact_f32)().run_with_params(slice, max)?
231            }
232        };
233        let rsum = sum.recip();
234        (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?;
235        Ok(())
236    }
237}
238
239fn softmax_inner<T: Float + Datum + std::iter::Sum, D: Dimension>(mut view: ArrayViewMut<T, D>) {
240    let max =
241        *view.iter().max_by(|i, j| i.partial_cmp(j).unwrap_or(std::cmp::Ordering::Less)).unwrap();
242    view.mapv_inplace(|x| (x - max).exp());
243    let exp_sum = view.iter().copied().sum();
244    view.mapv_inplace(|x| x / exp_sum);
245}
246
247fn softmax_quant_inner<D: Dimension>(
248    mut view: ArrayViewMut<u8, D>,
249    src_is_signed: bool,
250    in_qp: QParams,
251    out_is_signed: bool,
252    out_qp: QParams,
253) {
254    let (_, in_scale) = in_qp.zp_scale();
255    let (scale_in_multiplier, scale_in_shift) = convert_scale_to_mult_shift(in_scale).unwrap();
256    let (_, out_scale) = out_qp.zp_scale();
257    let (scale_out_multiplier, scale_out_shift) = convert_scale_to_mult_shift(out_scale).unwrap();
258    let shift = 26 - scale_in_shift;
259
260    // Compute the exponentials x - max
261    let mut buffer = vec![0_i32; view.len()];
262
263    // Handle the case were we considered an i8 as an u8 and still get the right x - max.
264    let safe_u8 = if src_is_signed { |x: &u8| x.wrapping_add(128) } else { |x: &u8| *x };
265
266    let max = view.iter().map(safe_u8).max().unwrap();
267    view.iter().zip(buffer.iter_mut()).for_each(|(x, exp)| {
268        let input_diff = safe_u8(x) as i32 - max as i32;
269
270        // We scale the input to be in Q5_26
271        let scaled_input_diff = if scale_in_multiplier != 0 {
272            saturating_rounding_multiply_by_pot(
273                saturating_rounding_doubling_high_mul(input_diff, scale_in_multiplier),
274                shift as i32,
275            )
276        } else {
277            saturating_rounding_multiply_by_pot(input_diff, shift as i32)
278        };
279
280        // It expects an input from Q5_26 and returns an output in Q0_31
281        *exp = exp_on_negative_values(scaled_input_diff);
282    });
283
284    // Compute sum of exp
285    // The sum is stored as an Q12_19 that's why we need to recale from Q0_31 to Q12_19 before summing.
286    let sum_of_exp = buffer.iter().map(|it| rescale(*it, 0, 12)).sum();
287
288    // Compute 1/sum_of_exp
289    // The result of this function is in Q0_31
290    let (inv_sum_of_exp, num_bits_over_unit) = get_reciprocal(sum_of_exp, 12);
291
292    // Compute the exponent value needed to be in Q24_8 before the final rescaling
293    let exponent = num_bits_over_unit as isize + 31 - 8;
294
295    view.iter_mut().zip(buffer.iter()).for_each(|(it, exp)| {
296        // Compute the product of exp * 1/sum_of_exp and scale the result in Q24_8
297        let unsat_output = rounding_divide_by_pot(
298            saturating_rounding_doubling_high_mul(inv_sum_of_exp, *exp),
299            exponent as i32,
300        );
301
302        // Scale the final result in the output scale range
303        let unsat_scaled_output = {
304            if scale_out_multiplier != 0 {
305                let (inv_multiplier, num_bits) = get_reciprocal(scale_out_multiplier, 1);
306                rounding_divide_by_pot(
307                    saturating_rounding_doubling_high_mul(unsat_output, inv_multiplier),
308                    (8 - scale_out_shift - 1 - num_bits as isize) as i32,
309                )
310            } else {
311                rounding_divide_by_pot(unsat_output, (8 - scale_out_shift) as i32)
312            }
313        };
314
315        // Return the final result by clipping the computed value within its range
316        // and casting it to u8 in any case.
317        #[allow(unknown_lints, unnecessary_transmutes)]
318        if out_is_signed {
319            *it = unsafe {
320                std::mem::transmute::<i8, u8>(i32::max(
321                    i32::min(unsat_scaled_output, i8::MAX as i32),
322                    i8::MIN as i32,
323                ) as i8)
324            };
325        } else {
326            *it = i32::max(i32::min(unsat_scaled_output, u8::MAX as i32), u8::MIN as i32) as u8;
327        }
328    });
329}
330
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::nn::DataFormat::NCHW;
    use anyhow::Result;
    use num_traits::PrimInt;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use tract_data::internal::QParams::ZpScale;

    /// Asserts that `found` is within one quantization step of `expected`. The step is the
    /// larger of the input and output scales, with 0.5% slack.
    fn assert_is_close(found: f32, expected: f32, in_dt: DatumType, out_dt: DatumType) {
        let (_, in_epsilon) = in_dt.zp_scale();
        let (_, out_epsilon) = out_dt.zp_scale();
        let epsilon = f32::max(in_epsilon, out_epsilon) * 1.005;
        let error = (found - expected).abs();
        assert!(
            error <= epsilon,
            "epsilon eq failed: |{found:?}-{expected:?}|={error} should be <= {epsilon}"
        );
    }

    // Generate a random tensor with a quantized datum type: arbitrary raw values of `T`,
    // relabelled with a random quantized type from `q_datum`.
    fn qtensor<T: PrimInt + Datum + Arbitrary>(shape: Vec<usize>) -> BoxedStrategy<Tensor> {
        let len = shape.iter().product::<usize>();
        let dt = q_datum::<T>((0.0001f32..0.1).boxed());
        (vec(any::<T>(), len..=len), dt)
            .prop_map(move |(vec, dt)| (ArrayD::from_shape_vec(shape.clone(), vec).unwrap(), dt))
            .prop_map(move |(array, dt)| {
                let mut tensor = array.into_tensor();
                unsafe { tensor.set_datum_type(dt) };
                tensor
            })
            .boxed()
    }

    // Generate a random quantized datum type: QI8 if `T` is signed, QU8 otherwise, always
    // with zero point 0. The scale is either an exact power of two 2^-k
    // (1 <= k < value bits of `T`) or drawn from `range`.
    fn q_datum<T: PrimInt + Datum>(range: BoxedStrategy<f32>) -> BoxedStrategy<DatumType> {
        let max_integer_bits = std::mem::size_of::<T>() * 8 - T::datum_type().is_signed() as usize;
        prop_oneof![
            (1usize..max_integer_bits).prop_map(|fixed_point| { 2f32.powi(-(fixed_point as i32)) }),
            range
        ]
        .prop_map(|scale| {
            if T::datum_type().is_signed() {
                DatumType::QI8(ZpScale { zero_point: 0, scale })
            } else {
                DatumType::QU8(ZpScale { zero_point: 0, scale })
            }
        })
        .boxed()
    }

    /// End-to-end check of the op: quantized eval vs. float eval on the dequantized input.
    #[derive(Debug)]
    struct SoftmaxProblem {
        data: Tensor,
        axes: TVec<usize>,
        output_dt: DatumType,
    }

    impl SoftmaxProblem {
        fn check(&self) -> Result<()> {
            let inputs = tvec!(self.data.clone().into_tvalue());
            let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float());
            let softmax =
                Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() };

            // Compute quantized output
            let result = softmax.eval(inputs)?;
            let result = args_1!(result);
            let result_float = result.cast_to::<f32>()?;

            // Compute reference output
            let input_float = self.data.cast_to::<f32>()?;
            let inputs_float = tvec!(input_float.into_owned().into_tvalue());
            let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() };
            let reference_float = softmax_float.eval(inputs_float)?;
            let reference_array = args_1!(reference_float);
            let reference = reference_array.to_array_view::<f32>()?;

            result_float
                .to_array_view::<f32>()?
                .iter()
                .zip(reference.iter())
                .for_each(|(a, b)| assert_is_close(*a, *b, self.data.datum_type(), self.output_dt));

            Ok(())
        }
    }

    impl Arbitrary for SoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<SoftmaxProblem>;
        // Small NCHW tensors (n = c = 1, h and w in 1..5), softmax over one random axis.
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (1usize..2, 1usize..2, 1usize..5, 1usize..5, 0usize..4)
                .prop_flat_map(|(n, c, h, w, axis)| {
                    let shape_in: Vec<usize> =
                        NCHW.from_n_c_hw(n, c, [h, w]).unwrap().shape.to_vec();
                    (
                        prop_oneof![qtensor::<i8>(shape_in.clone()), qtensor::<u8>(shape_in)],
                        Just(tvec![axis]),
                        prop_oneof![
                            q_datum::<u8>((0.008f32..0.1).boxed()),
                            q_datum::<i8>((0.008f32..0.1).boxed())
                        ],
                    )
                })
                .prop_map(|(data, axes, output_dt)| SoftmaxProblem { data, axes, output_dt })
                .boxed()
        }
    }

    /// Direct check of `softmax_quant_inner` on a 1-D lane of i8 inputs producing u8
    /// outputs, against a float reference; results may differ by at most 1.
    #[derive(Debug)]
    pub struct InnerSoftmaxProblem {
        in_qp: QParams,
        out_qp: QParams,
        data: Vec<i8>,
    }

    impl InnerSoftmaxProblem {
        fn check(&self) -> Result<()> {
            let quantized = self.quantized();
            let reference = self.reference();
            assert!(quantized.iter().zip(reference.iter()).all(|(quantized, expected)| {
                let abs_diff = if *quantized > *expected {
                    quantized - *expected
                } else {
                    expected - *quantized
                };
                abs_diff <= 1
            }));
            Ok(())
        }

        // Dequantize, run the float softmax, then requantize (with zero point) into u8.
        fn reference(&self) -> Vec<u8> {
            let (in_zero_point, in_scale) = self.in_qp.zp_scale();
            let (out_zero_point, out_scale) = self.out_qp.zp_scale();
            let in_float =
                self.data.iter().map(|it| (*it as f32 - in_zero_point as f32) * in_scale).collect();
            let mut in_float_array = Array1::from_vec(in_float);
            softmax_inner(in_float_array.view_mut());
            let rescaled_output = in_float_array
                .iter()
                .map(|it| {
                    ((*it / out_scale).round() as i32 + out_zero_point)
                        .max(u8::MIN as i32)
                        .min(u8::MAX as i32) as u8
                })
                .collect();
            rescaled_output
        }

        // Reinterpret the i8 data as u8 (same layout) and run the fixed-point kernel with a
        // signed input and an unsigned output.
        fn quantized(&self) -> Vec<u8> {
            let in_data: Vec<u8> = unsafe { std::mem::transmute(self.data.clone()) };
            let mut in_array = Array1::from_vec(in_data);
            softmax_quant_inner(in_array.view_mut(), true, self.in_qp, false, self.out_qp);
            in_array.to_vec()
        }
    }

    impl Arbitrary for InnerSoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<InnerSoftmaxProblem>;
        // Only the qparams of the generated datum types are used; the data is always i8.
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (
                prop_oneof![
                    q_datum::<i8>((0.0001f32..0.01).boxed()),
                    q_datum::<u8>((0.0001f32..0.01).boxed())
                ],
                prop_oneof![
                    q_datum::<u8>((0.008f32..0.1).boxed()),
                    q_datum::<i8>((0.008f32..0.1).boxed())
                ],
                vec(any::<i8>(), 1..10),
            )
                .prop_map(|(in_qp, out_qp, data)| InnerSoftmaxProblem {
                    in_qp: in_qp.qparams().unwrap(),
                    out_qp: out_qp.qparams().unwrap(),
                    data,
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_inner_prop(pb in any::<InnerSoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_prop(pb in any::<SoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    #[test]
    // We test QU8 -> QU8
    fn test_softmax_trivial_0() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 }); // Q3_5
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 }); // Q0_8;
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QI8 -> QU8
    fn test_softmax_trivial_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 }); // Q3_4
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 }); // Q0_8;
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QI8 -> QI8
    fn test_softmax_trivial_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 }); // Q3_4
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 }); // Q0_7;
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, -4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    // We test QU8 -> QI8 (softmax over axis 2 rather than 3)
    fn test_softmax_trivial_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 }); // Q3_5
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 }); // Q0_7;
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![2], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.5 }); // Q6_1
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5 }); // Q7_1
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0001 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.008 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.6220956 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5187921 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[13_u8, 218])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let data = vec![0_i8, 1];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_not_pow_2_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.7298456 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    // Fails, but the difference is quite small and the sum still gives exactly one:
    // quantized: 110(0.88), 15(0.12)
    // expected: 112(0.896), 13(0.104)
    fn test_inner_softmax_not_pow_2_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.2123116 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.008 };
        let data = vec![118i8, 108];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    // Fails, but the difference is quite small and the sum still gives exactly one:
    // quantized: 40(0.625), 24(0.375)
    // expected: 42(0.65625), 22(0.34375)
    fn test_inner_softmax_not_pow_2_3() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.33034274 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.015625 };
        let data = vec![45i8, 43];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }
}