mod fixedpoint;
pub mod math;

use math::{
    convert_scale_to_mult_shift, exp_on_negative_values, get_reciprocal, rescale,
    rounding_divide_by_pot, saturating_rounding_doubling_high_mul,
    saturating_rounding_multiply_by_pot,
};
use num_traits::Float;
use std::fmt::Debug;
use tract_num_traits::Zero;

use crate::internal::*;
use ndarray::prelude::*;

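/// Selects the reduction computed by the operator: a regular softmax or a log-softmax.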
#[derive(Debug, Copy, Clone, Hash, PartialEq)]
pub enum SoftmaxKind {
    Softmax(SoftmaxExp),
    LogSoftmax,
}

impl Default for SoftmaxKind {
    fn default() -> Self {
        SoftmaxKind::Softmax(SoftmaxExp::default())
    }
}

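/// Exponential implementation used by the float kernels: `Libc` calls the standard `exp`,
/// while `FastCompact` dispatches to tract-linalg's fused exp-and-sum kernel
/// (`softmax2_fastcompact_*`).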
#[derive(Debug, Copy, Clone, Hash, Default, PartialEq)]
pub enum SoftmaxExp {
    #[default]
    Libc,
    FastCompact,
}

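/// Softmax / LogSoftmax operator.
///
/// `axes` lists the axes that are normalized together. `quant_output_dt` must be set to a
/// quantized type when (and only when) the input itself is quantized; float inputs keep their
/// datum type.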
#[derive(Debug, Clone, new, Hash, Default)]
pub struct Softmax {
    pub axes: TVec<usize>,
    pub quant_output_dt: Option<DatumType>,
    pub kind: SoftmaxKind,
}

impl Op for Softmax {
    fn name(&self) -> StaticName {
        match self.kind {
            SoftmaxKind::Softmax(_) => "Softmax".into(),
            SoftmaxKind::LogSoftmax => "LogSoftmax".into(),
        }
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let mut infos = vec![format!("Axes: {:?}", self.axes)];
        if let SoftmaxKind::Softmax(exp) = self.kind {
            infos.push(format!("Exp impl: {exp:?}"))
        };
        Ok(infos)
    }

    op_as_typed_op!();
}

impl TypedOp for Softmax {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let dt = inputs[0].datum_type;
        if dt.is_float() {
            ensure!(
                self.quant_output_dt.is_none(),
                "Float softmax should not have quant_output_dt, got {:?}",
                self.quant_output_dt
            );
        } else if dt.is_quantized() {
            ensure!(
                self.quant_output_dt.map(|q| q.is_quantized()).unwrap_or(false),
                "Quantized softmax should have a quantized output type (got {:?})",
                self.quant_output_dt
            );
        } else {
            bail!(
                "Unsupported datum type in softmax: input type {:?}, output type {:?}",
                dt,
                self.quant_output_dt
            );
        }

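        // Float softmax keeps the input datum type; quantized softmax switches to the declared
        // quantized output type.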
        let fact = self.quant_output_dt.unwrap_or(dt).fact(inputs[0].shape.clone());
        Ok(tvec!(fact))
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        AxesMapping::natural(inputs, outputs)
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
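        // Follow the axis change only if every softmax axis can be tracked through it; otherwise
        // report that this op has no way to apply the change.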
        let axes: Option<TVec<usize>> =
            self.axes.iter().map(|it| change.transform_axis(*it)).collect();
        if let Some(axes) = axes {
            Ok(Some(AxisChangeConsequence::new(
                model,
                node,
                Some(Box::new(Softmax { axes, ..self.clone() })),
                change,
            )))
        } else {
            Ok(None)
        }
    }

    as_op!();
}

impl EvalOp for Softmax {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let dt = input.datum_type();

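        // Floats are handled generically per element type; quantized i8/u8 inputs take the
        // integer-only path.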
        let output = match dt {
            DatumType::F64 => self.eval_t::<f64>(input)?,
            DatumType::F32 => self.eval_t::<f32>(input)?,
            DatumType::F16 => self.eval_t::<f16>(input)?,
            DatumType::QI8(_) | DatumType::QU8(_) => self.eval_quant(input)?,
            dt => bail!("Unsupported type {dt:?}"),
        };
        Ok(output)
    }
}

impl Softmax {
    fn eval_t<T>(&self, input: TValue) -> TractResult<TVec<TValue>>
    where
        T: Float + Datum + std::iter::Sum,
    {
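        // Collapse the softmax axes to length 1: iterating over this shape visits exactly one
        // coordinate per slice that has to be normalized.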
        let mut iterating_shape: TVec<usize> = input.shape().into();

        for i in 0..iterating_shape.len() {
            if self.axes.contains(&i) {
                iterating_shape[i] = 1
            }
        }

        let mut output = input.into_tensor();
        let mut view = output.to_array_view_mut::<T>()?;

        for it_coords in tract_ndarray::indices(&*iterating_shape) {
            let mut view = view.view_mut();
            for ix in 0..iterating_shape.len() {
                if !self.axes.contains(&ix) {
                    view.collapse_axis(Axis(ix), it_coords[ix]);
                }
            }
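            // Contiguous f32/f16 slices go through the dedicated tract-linalg kernels (the
            // transmutes are guarded by the datum-type filters); anything else falls back to the
            // generic ndarray implementation.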
            if let Some(slice) =
                view.as_slice_mut().filter(|_| T::datum_type() == f32::datum_type())
            {
                let slice: &mut [f32] = unsafe { std::mem::transmute(slice) };
                self.softmax_inner_slice_f32(slice, self.kind)?;
            } else if let Some(slice) =
                view.as_slice_mut().filter(|_| T::datum_type() == f16::datum_type())
            {
                let slice: &mut [f16] = unsafe { std::mem::transmute(slice) };
                self.softmax_inner_slice_f16(slice, self.kind)?;
            } else {
                softmax_inner(view, self.kind);
            }
        }

        Ok(tvec!(output.into_tvalue()))
    }

    fn eval_quant(&self, input: TValue) -> TractResult<TVec<TValue>> {
        if self.kind == SoftmaxKind::LogSoftmax {
            bail!("Quantized LogSoftmax is not supported")
        }
        let mut iterating_shape: TVec<usize> = input.shape().into();
        let output_dt =
            self.quant_output_dt.context("Quantized softmax eval with no output type")?;

        for i in 0..iterating_shape.len() {
            if self.axes.contains(&i) {
                iterating_shape[i] = 1
            }
        }

        let src_is_signed = input.datum_type().is_signed();
        let out_is_signed = output_dt.is_signed();
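        // Work on the raw bytes: signed inputs are shifted by 128 inside the inner kernel so i8
        // and u8 data share the same unsigned code path.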
        let in_qp = input.datum_type().qparams().unwrap();
        let out_qp = output_dt.qparams().unwrap();
        let mut output = unsafe { input.into_tensor().into_array_unchecked::<u8>() };

        for it_coords in tract_ndarray::indices(&*iterating_shape) {
            let mut view = output.view_mut();
            for ix in 0..iterating_shape.len() {
                if !self.axes.contains(&ix) {
                    view.collapse_axis(Axis(ix), it_coords[ix]);
                }
            }
            softmax_quant_inner(view, src_is_signed, in_qp, out_is_signed, out_qp);
        }

        let mut output_tensor = output.into_tensor();
        unsafe { output_tensor.set_datum_type(output_dt) };
        Ok(tvec!(output_tensor.into_tvalue()))
    }

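    // f16 slice kernel: find the max with tract-linalg, exponentiate the shifted values (libc exp
    // or the fast-compact kernel), then normalize by the reciprocal of the sum (or subtract the
    // log-sum for LogSoftmax).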
    fn softmax_inner_slice_f16(&self, slice: &mut [f16], kind: SoftmaxKind) -> TractResult<()> {
        let max = (tract_linalg::ops().max_f16)().run(slice)?;
        match kind {
            SoftmaxKind::Softmax(exp_impl) => {
                let sum = match exp_impl {
                    SoftmaxExp::Libc => {
                        let mut s = f16::zero();
                        slice.iter_mut().for_each(|x| {
                            *x = (*x - max).exp();
                            s += *x;
                        });
                        s
                    }
                    SoftmaxExp::FastCompact => (tract_linalg::ops().softmax2_fastcompact_f16)()
                        .run_with_params(slice, max)?,
                };
                let rsum = sum.recip();
                (tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, rsum)?;
            }
            SoftmaxKind::LogSoftmax => {
                let mut exp_sum = f16::zero();
                slice.iter_mut().for_each(|x| {
                    *x -= max;
                    exp_sum += x.exp();
                });
                let log_sum = exp_sum.ln();
                slice.iter_mut().for_each(|x| *x -= log_sum);
            }
        }
        Ok(())
    }

    fn softmax_inner_slice_f32(&self, slice: &mut [f32], kind: SoftmaxKind) -> TractResult<()> {
        let max = (tract_linalg::ops().max_f32)().run(slice)?;
        match kind {
            SoftmaxKind::Softmax(exp_impl) => {
                let sum = match exp_impl {
                    SoftmaxExp::Libc => {
                        let mut s = f32::zero();
                        slice.iter_mut().for_each(|x| {
                            *x = (*x - max).exp();
                            s += *x;
                        });
                        s
                    }
                    SoftmaxExp::FastCompact => (tract_linalg::ops().softmax2_fastcompact_f32)()
                        .run_with_params(slice, max)?,
                };
                let rsum = sum.recip();
                (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?;
            }
            SoftmaxKind::LogSoftmax => {
                let mut exp_sum = f32::zero();
                slice.iter_mut().for_each(|x| {
                    *x -= max;
                    exp_sum += x.exp();
                });
                let log_sum = exp_sum.ln();
                slice.iter_mut().for_each(|x| *x -= log_sum);
            }
        }
        Ok(())
    }
}

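/// Generic fallback used when the data is not a contiguous f32/f16 slice: numerically stable
/// (log-)softmax on an ndarray view, subtracting the max before exponentiating.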
fn softmax_inner<T: Float + Datum + std::iter::Sum, D: Dimension>(
    mut view: ArrayViewMut<T, D>,
    kind: SoftmaxKind,
) {
    let max =
        *view.iter().max_by(|i, j| i.partial_cmp(j).unwrap_or(std::cmp::Ordering::Less)).unwrap();
    view.mapv_inplace(|x| x - max);
    let exp_sum = view.iter().map(|&x| x.exp()).sum();
    match kind {
        SoftmaxKind::Softmax(_) => {
            view.mapv_inplace(|x| x.exp() / exp_sum);
        }
        SoftmaxKind::LogSoftmax => {
            let log_sum = exp_sum.ln();
            view.mapv_inplace(|x| x - log_sum);
        }
    }
}

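/// Integer-only softmax on one slice of quantized data: the difference to the slice max is
/// rescaled into fixed point, exponentiated with `exp_on_negative_values`, divided by the sum of
/// exponentials through a fixed-point reciprocal, and finally rescaled to the output quantization
/// (the helpers in `math` mirror gemmlowp-style fixed-point primitives).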
fn softmax_quant_inner<D: Dimension>(
    mut view: ArrayViewMut<u8, D>,
    src_is_signed: bool,
    in_qp: QParams,
    out_is_signed: bool,
    out_qp: QParams,
) {
    let (_, in_scale) = in_qp.zp_scale();
    let (scale_in_multiplier, scale_in_shift) = convert_scale_to_mult_shift(in_scale).unwrap();
    let (_, out_scale) = out_qp.zp_scale();
    let (scale_out_multiplier, scale_out_shift) = convert_scale_to_mult_shift(out_scale).unwrap();
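    // Presumably aligns the scaled input difference with the fixed-point format expected by
    // `exp_on_negative_values` (26 fractional bits).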
    let shift = 26 - scale_in_shift;

    let mut buffer = vec![0_i32; view.len()];

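    // Map i8 bit patterns into u8 ordering by adding 128 so signed and unsigned inputs share the
    // same code path below.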
    let safe_u8 = if src_is_signed { |x: &u8| x.wrapping_add(128) } else { |x: &u8| *x };

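    // First pass: compute a fixed-point exp(x - max) for every element into the i32 scratch
    // buffer.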
    let max = view.iter().map(safe_u8).max().unwrap();
    view.iter().zip(buffer.iter_mut()).for_each(|(x, exp)| {
        let input_diff = safe_u8(x) as i32 - max as i32;

        let scaled_input_diff = if scale_in_multiplier != 0 {
            saturating_rounding_multiply_by_pot(
                saturating_rounding_doubling_high_mul(input_diff, scale_in_multiplier),
                shift as i32,
            )
        } else {
            saturating_rounding_multiply_by_pot(input_diff, shift as i32)
        };

        *exp = exp_on_negative_values(scaled_input_diff);
    });

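    // Sum the rescaled exponentials and take the fixed-point reciprocal of the sum, keeping track
    // of the number of bits over unit reported by `get_reciprocal`.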
    let sum_of_exp = buffer.iter().map(|it| rescale(*it, 0, 12)).sum();

    let (inv_sum_of_exp, num_bits_over_unit) = get_reciprocal(sum_of_exp, 12);

    let exponent = num_bits_over_unit as isize + 31 - 8;

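    // Second pass: multiply each exponential by the reciprocal of the sum, rescale to the output
    // quantization and clamp to the 8-bit range.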
    view.iter_mut().zip(buffer.iter()).for_each(|(it, exp)| {
        let unsat_output = rounding_divide_by_pot(
            saturating_rounding_doubling_high_mul(inv_sum_of_exp, *exp),
            exponent as i32,
        );

        let unsat_scaled_output = {
            if scale_out_multiplier != 0 {
                let (inv_multiplier, num_bits) = get_reciprocal(scale_out_multiplier, 1);
                rounding_divide_by_pot(
                    saturating_rounding_doubling_high_mul(unsat_output, inv_multiplier),
                    (8 - scale_out_shift - 1 - num_bits as isize) as i32,
                )
            } else {
                rounding_divide_by_pot(unsat_output, (8 - scale_out_shift) as i32)
            }
        };

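        // The backing array is u8-typed: for signed outputs, clamp to the i8 range and
        // reinterpret the bits as u8.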
        #[allow(unknown_lints, unnecessary_transmutes)]
        if out_is_signed {
            *it = unsafe {
                std::mem::transmute::<i8, u8>(i32::max(
                    i32::min(unsat_scaled_output, i8::MAX as i32),
                    i8::MIN as i32,
                ) as i8)
            };
        } else {
            *it = i32::max(i32::min(unsat_scaled_output, u8::MAX as i32), u8::MIN as i32) as u8;
        }
    });
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::nn::DataFormat::NCHW;
    use anyhow::Result;
    use num_traits::PrimInt;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use tract_data::internal::QParams::ZpScale;

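    // Quantized results are compared with a tolerance of one quantization step on each side
    // (input scale + output scale).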
    fn assert_is_close(found: f32, expected: f32, in_dt: DatumType, out_dt: DatumType) {
        let (_, in_epsilon) = in_dt.zp_scale();
        let (_, out_epsilon) = out_dt.zp_scale();
        let epsilon = in_epsilon + out_epsilon;
        let error = (found - expected).abs();
        assert!(
            error <= epsilon,
            "epsilon eq failed: |{found:?}-{expected:?}|={error} should be <= {epsilon}"
        );
    }

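    // Proptest strategy: a random integer tensor of the requested shape, tagged with a random
    // quantized datum type.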
    fn qtensor<T: PrimInt + Datum + Arbitrary>(shape: Vec<usize>) -> BoxedStrategy<Tensor> {
        let len = shape.iter().product::<usize>();
        let dt = q_datum::<T>((0.0001f32..0.1).boxed());
        (vec(any::<T>(), len..=len), dt)
            .prop_map(move |(vec, dt)| (ArrayD::from_shape_vec(shape.clone(), vec).unwrap(), dt))
            .prop_map(move |(array, dt)| {
                let mut tensor = array.into_tensor();
                unsafe { tensor.set_datum_type(dt) };
                tensor
            })
            .boxed()
    }

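    // Proptest strategy for a quantized datum type with a zero zero-point: the scale is either an
    // exact power of two or drawn from the provided range.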
    fn q_datum<T: PrimInt + Datum>(range: BoxedStrategy<f32>) -> BoxedStrategy<DatumType> {
        let max_integer_bits = std::mem::size_of::<T>() * 8 - T::datum_type().is_signed() as usize;
        prop_oneof![
            (1usize..max_integer_bits).prop_map(|fixed_point| { 2f32.powi(-(fixed_point as i32)) }),
            range
        ]
        .prop_map(|scale| {
            if T::datum_type().is_signed() {
                DatumType::QI8(ZpScale { zero_point: 0, scale })
            } else {
                DatumType::QU8(ZpScale { zero_point: 0, scale })
            }
        })
        .boxed()
    }

    #[derive(Debug)]
    struct SoftmaxProblem {
        data: Tensor,
        axes: TVec<usize>,
        output_dt: DatumType,
    }

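    // Runs the quantized softmax and compares it, after dequantization, against the float
    // reference evaluated on the dequantized input.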
    impl SoftmaxProblem {
        fn check(&self) -> Result<()> {
            let inputs = tvec!(self.data.clone().into_tvalue());
            let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float());
            let softmax =
                Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() };

            let result = softmax.eval(inputs)?;
            let result = args_1!(result);
            let result_float = result.cast_to::<f32>()?;

            let input_float = self.data.cast_to::<f32>()?;
            let inputs_float = tvec!(input_float.into_owned().into_tvalue());
            let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() };
            let reference_float = softmax_float.eval(inputs_float)?;
            let reference_array = args_1!(reference_float);
            let reference = reference_array.to_array_view::<f32>()?;

            result_float
                .to_array_view::<f32>()?
                .iter()
                .zip(reference.iter())
                .for_each(|(a, b)| assert_is_close(*a, *b, self.data.datum_type(), self.output_dt));
            Ok(())
        }
    }

    impl Arbitrary for SoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<SoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (1usize..2, 1usize..2, 1usize..5, 1usize..5, 0usize..4)
                .prop_flat_map(|(n, c, h, w, axis)| {
                    let shape_in: Vec<usize> =
                        NCHW.from_n_c_hw(n, c, [h, w]).unwrap().shape.to_vec();
                    (
                        prop_oneof![qtensor::<i8>(shape_in.clone()), qtensor::<u8>(shape_in)],
                        Just(tvec![axis]),
                        prop_oneof![
                            q_datum::<u8>((0.008f32..0.1).boxed()),
                            q_datum::<i8>((0.008f32..0.1).boxed())
                        ],
                    )
                })
                .prop_map(|(data, axes, output_dt)| SoftmaxProblem { data, axes, output_dt })
                .boxed()
        }
    }

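    // Exercises `softmax_quant_inner` directly on an i8 slice and checks it against the float
    // reference, allowing a difference of at most one quantized step.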
    #[derive(Debug)]
    pub struct InnerSoftmaxProblem {
        in_qp: QParams,
        out_qp: QParams,
        data: Vec<i8>,
    }

    impl InnerSoftmaxProblem {
        fn check(&self) -> Result<()> {
            let quantized = self.quantized();
            let reference = self.reference();
            assert!(quantized.iter().zip(reference.iter()).all(|(quantized, expected)| {
                let abs_diff = if *quantized > *expected {
                    quantized - *expected
                } else {
                    expected - *quantized
                };
                abs_diff <= 1
            }));
            Ok(())
        }

        fn reference(&self) -> Vec<u8> {
            let (in_zero_point, in_scale) = self.in_qp.zp_scale();
            let (out_zero_point, out_scale) = self.out_qp.zp_scale();
            let in_float =
                self.data.iter().map(|it| (*it as f32 - in_zero_point as f32) * in_scale).collect();
            let mut in_float_array = Array1::from_vec(in_float);
            softmax_inner(in_float_array.view_mut(), SoftmaxKind::default());
            let rescaled_output = in_float_array
                .iter()
                .map(|it| {
                    ((*it / out_scale).round() as i32 + out_zero_point)
                        .max(u8::MIN as i32)
                        .min(u8::MAX as i32) as u8
                })
                .collect();
            rescaled_output
        }

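        // Reinterpret the i8 test data as raw u8 bytes, since the kernel operates on u8 storage.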
        fn quantized(&self) -> Vec<u8> {
            let in_data: Vec<u8> = unsafe { std::mem::transmute(self.data.clone()) };
            let mut in_array = Array1::from_vec(in_data);
            softmax_quant_inner(in_array.view_mut(), true, self.in_qp, false, self.out_qp);
            in_array.to_vec()
        }
    }

    impl Arbitrary for InnerSoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<InnerSoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (
                prop_oneof![
                    q_datum::<i8>((0.0001f32..0.01).boxed()),
                    q_datum::<u8>((0.0001f32..0.01).boxed())
                ],
                prop_oneof![
                    q_datum::<u8>((0.008f32..0.1).boxed()),
                    q_datum::<i8>((0.008f32..0.1).boxed())
                ],
                vec(any::<i8>(), 1..10),
            )
                .prop_map(|(in_qp, out_qp, data)| InnerSoftmaxProblem {
                    in_qp: in_qp.qparams().unwrap(),
                    out_qp: out_qp.qparams().unwrap(),
                    data,
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_inner_prop(pb in any::<InnerSoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_prop(pb in any::<SoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

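    // The "trivial" cases below use power-of-two scales (e.g. 0.03125 = 2^-5, 0.00390625 = 2^-8),
    // which convert_scale_to_mult_shift can express as a pure shift (multiplier 0).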
    #[test]
    fn test_softmax_trivial_0() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 });
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_trivial_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 });
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_trivial_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 });
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 });
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, -4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_trivial_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 });
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 });
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![2], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.5 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0001 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.008 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.6220956 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5187921 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[13_u8, 218])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let data = vec![0_i8, 1];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_not_pow_2_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.7298456 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    fn test_inner_softmax_not_pow_2_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.2123116 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.008 };
        let data = vec![118i8, 108];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    fn test_inner_softmax_not_pow_2_3() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.33034274 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.015625 };
        let data = vec![45i8, 43];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }
}