1mod fixedpoint;
2pub mod math;
3
4use math::{
5 convert_scale_to_mult_shift, exp_on_negative_values, get_reciprocal, rescale,
6 rounding_divide_by_pot, saturating_rounding_doubling_high_mul,
7 saturating_rounding_multiply_by_pot,
8};
9use num_traits::Float;
10use std::fmt::Debug;
11use tract_num_traits::Zero;
12
13use crate::internal::*;
14use ndarray::prelude::*;
15
/// Selects between the plain softmax and its log-space variant.
#[derive(Debug, Copy, Clone, Hash, PartialEq)]
pub enum SoftmaxKind {
    /// Standard softmax; carries the exponential implementation to use.
    Softmax(SoftmaxExp),
    /// log(softmax(x)). Only supported on the float evaluation path.
    LogSoftmax,
}
21
22impl Default for SoftmaxKind {
23 fn default() -> Self {
24 SoftmaxKind::Softmax(SoftmaxExp::default())
25 }
26}
27
/// Which exponential implementation the float softmax kernels use.
#[derive(Debug, Copy, Clone, Hash, Default, PartialEq)]
pub enum SoftmaxExp {
    /// Per-element `exp()` from the standard math library (default).
    #[default]
    Libc,
    /// Fused fast-compact approximation kernel from tract-linalg.
    FastCompact,
}
35
/// Softmax / LogSoftmax operator, normalizing over `axes`.
#[derive(Debug, Clone, new, Hash, Default)]
pub struct Softmax {
    /// Axes the normalization is computed over.
    pub axes: TVec<usize>,
    /// Output type for the quantized path; must be `None` for float inputs.
    pub quant_output_dt: Option<DatumType>,
    /// Softmax vs LogSoftmax, plus the exp implementation for the former.
    pub kind: SoftmaxKind,
}
42
43impl Op for Softmax {
44 fn name(&self) -> StaticName {
45 match self.kind {
46 SoftmaxKind::Softmax(_) => "Softmax".into(),
47 SoftmaxKind::LogSoftmax => "LogSoftmax".into(),
48 }
49 }
50
51 fn info(&self) -> TractResult<Vec<String>> {
52 let mut infos = vec![format!("Axis: {:?}", self.axes)];
53 if let SoftmaxKind::Softmax(exp) = self.kind {
54 infos.push(format!("Exp impl: {exp:?}"))
55 };
56 Ok(infos)
57 }
58
59 op_as_typed_op!();
60}
61
62impl TypedOp for Softmax {
63 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
64 let dt = inputs[0].datum_type;
65 if dt.is_float() {
66 ensure!(
67 self.quant_output_dt.is_none(),
68 "Float softmax should not have quant_output_dt, have {:?}",
69 self.quant_output_dt
70 );
71 } else if dt.is_quantized() {
72 ensure!(
73 self.quant_output_dt.map(|q| q.is_quantized()).unwrap_or(false),
74 "Quantized softmax should have a quantized output type (got {:?})",
75 self.quant_output_dt
76 );
77 } else {
78 bail!(
79 "Unsupported datum type in softmax: input type {:?}, output type {:?}",
80 dt,
81 self.quant_output_dt
82 );
83 }
84
85 let fact = self.quant_output_dt.unwrap_or(dt).fact(inputs[0].shape.clone());
86 Ok(tvec!(fact))
87 }
88
89 fn axes_mapping(
90 &self,
91 inputs: &[&TypedFact],
92 outputs: &[&TypedFact],
93 ) -> TractResult<AxesMapping> {
94 AxesMapping::natural(inputs, outputs)
95 }
96
97 fn change_axes(
98 &self,
99 model: &TypedModel,
100 node: &TypedNode,
101 _io: InOut,
102 change: &AxisOp,
103 ) -> TractResult<Option<AxisChangeConsequence>> {
104 let axes: Option<TVec<usize>> =
105 self.axes.iter().map(|it| change.transform_axis(*it)).collect();
106 if let Some(axes) = axes {
107 Ok(Some(AxisChangeConsequence::new(
108 model,
109 node,
110 Some(Box::new(Softmax { axes, ..self.clone() })),
111 change,
112 )))
113 } else {
114 Ok(None)
115 }
116 }
117
118 as_op!();
119}
120
121impl EvalOp for Softmax {
122 fn is_stateless(&self) -> bool {
123 true
124 }
125
126 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
127 let input = args_1!(inputs);
128 let dt = input.datum_type();
129
130 let output = match dt {
131 DatumType::F64 => self.eval_t::<f64>(input)?,
132 DatumType::F32 => self.eval_t::<f32>(input)?,
133 DatumType::F16 => self.eval_t::<f16>(input)?,
134 DatumType::QI8(_) | DatumType::QU8(_) => self.eval_quant(input)?,
135 dt => bail!("Unsupported type {dt:?}"),
136 };
137 Ok(output)
138 }
139}
140
141impl Softmax {
142 fn eval_t<T>(&self, input: TValue) -> TractResult<TVec<TValue>>
143 where
144 T: Float + Datum + std::iter::Sum,
145 {
146 let mut iterating_shape: TVec<usize> = input.shape().into();
147
148 for i in 0..iterating_shape.len() {
149 if self.axes.contains(&i) {
150 iterating_shape[i] = 1
151 }
152 }
153
154 let mut output = input.into_tensor();
155 let mut output_dense = output.try_as_dense_mut()?;
156 let mut view = output_dense.to_array_view_mut::<T>()?;
157
158 for it_coords in tract_ndarray::indices(&*iterating_shape) {
159 let mut view = view.view_mut();
160 for ix in 0..iterating_shape.len() {
161 if !self.axes.contains(&ix) {
162 view.collapse_axis(Axis(ix), it_coords[ix]);
163 }
164 }
165 if let Some(slice) =
166 view.as_slice_mut().filter(|_| T::datum_type() == f32::datum_type())
167 {
168 let slice: &mut [f32] = unsafe { std::mem::transmute(slice) };
169 self.softmax_inner_slice_f32(slice, self.kind)?;
170 } else if let Some(slice) =
171 view.as_slice_mut().filter(|_| T::datum_type() == f16::datum_type())
172 {
173 let slice: &mut [f16] = unsafe { std::mem::transmute(slice) };
174 self.softmax_inner_slice_f16(slice, self.kind)?;
175 } else {
176 softmax_inner(view, self.kind);
177 }
178 }
179
180 Ok(tvec!(output.into_tvalue()))
181 }
182
183 fn eval_quant(&self, input: TValue) -> TractResult<TVec<TValue>> {
184 if self.kind == SoftmaxKind::LogSoftmax {
185 bail!("Quantized LogSoftmax is not supported")
186 }
187 let mut iterating_shape: TVec<usize> = input.shape().into();
188 let output_dt =
189 self.quant_output_dt.context("Quandized softmax eval with no output type")?;
190
191 for i in 0..iterating_shape.len() {
192 if self.axes.contains(&i) {
193 iterating_shape[i] = 1
194 }
195 }
196
197 let src_is_signed = input.datum_type().is_signed();
199 let out_is_signed = output_dt.is_signed();
200 let in_qp = input.datum_type().qparams().unwrap(); let out_qp = output_dt.qparams().unwrap(); let mut output = unsafe { input.into_tensor().into_array_unchecked::<u8>() };
203
204 for it_coords in tract_ndarray::indices(&*iterating_shape) {
205 let mut view = output.view_mut();
206 for ix in 0..iterating_shape.len() {
207 if !self.axes.contains(&ix) {
208 view.collapse_axis(Axis(ix), it_coords[ix]);
209 }
210 }
211 softmax_quant_inner(view, src_is_signed, in_qp, out_is_signed, out_qp);
212 }
213
214 let mut output_tensor = output.into_tensor();
215 unsafe { output_tensor.set_datum_type(output_dt) };
216 Ok(tvec!(output_tensor.into_tvalue()))
217 }
218
219 fn softmax_inner_slice_f16(&self, slice: &mut [f16], kind: SoftmaxKind) -> TractResult<()> {
220 let max = (tract_linalg::ops().max_f16)().run(slice)?;
221 match kind {
222 SoftmaxKind::Softmax(exp_impl) => {
223 let sum = match exp_impl {
224 SoftmaxExp::Libc => {
225 let mut s = f16::zero();
226 slice.iter_mut().for_each(|x| {
227 *x = (*x - max).exp();
228 s += *x;
229 });
230 s
231 }
232 SoftmaxExp::FastCompact => (tract_linalg::ops().softmax2_fastcompact_f16)()
233 .run_with_params(slice, max)?,
234 };
235 let rsum = sum.recip();
236 (tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, rsum)?;
237 }
238 SoftmaxKind::LogSoftmax => {
239 let mut exp_sum = f16::zero();
240 slice.iter_mut().for_each(|x| {
241 *x -= max;
242 exp_sum += x.exp();
243 });
244 let log_sum = exp_sum.ln();
245 slice.iter_mut().for_each(|x| *x -= log_sum);
246 }
247 }
248 Ok(())
249 }
250
251 fn softmax_inner_slice_f32(&self, slice: &mut [f32], kind: SoftmaxKind) -> TractResult<()> {
252 let max = (tract_linalg::ops().max_f32)().run(slice)?;
253 match kind {
254 SoftmaxKind::Softmax(exp_impl) => {
255 let sum = match exp_impl {
256 SoftmaxExp::Libc => {
257 let mut s = f32::zero();
258 slice.iter_mut().for_each(|x| {
259 *x = (*x - max).exp();
260 s += *x;
261 });
262 s
263 }
264 SoftmaxExp::FastCompact => (tract_linalg::ops().softmax2_fastcompact_f32)()
265 .run_with_params(slice, max)?,
266 };
267 let rsum = sum.recip();
268 (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?;
269 }
270 SoftmaxKind::LogSoftmax => {
271 let mut exp_sum = f32::zero();
272 slice.iter_mut().for_each(|x| {
273 *x -= max;
274 exp_sum += x.exp();
275 });
276 let log_sum = exp_sum.ln();
277 slice.iter_mut().for_each(|x| *x -= log_sum);
278 }
279 }
280 Ok(())
281 }
282}
283
284fn softmax_inner<T: Float + Datum + std::iter::Sum, D: Dimension>(
285 mut view: ArrayViewMut<T, D>,
286 kind: SoftmaxKind,
287) {
288 let max =
289 *view.iter().max_by(|i, j| i.partial_cmp(j).unwrap_or(std::cmp::Ordering::Less)).unwrap();
290 view.mapv_inplace(|x| x - max);
291 let exp_sum = view.iter().map(|&x| x.exp()).sum();
292 match kind {
293 SoftmaxKind::Softmax(_) => {
294 view.mapv_inplace(|x| x.exp() / exp_sum);
295 }
296 SoftmaxKind::LogSoftmax => {
297 let log_sum = exp_sum.ln();
298 view.mapv_inplace(|x| x - log_sum);
299 }
300 }
301}
302
/// Fixed-point softmax over a lane of raw quantized bytes, in place.
///
/// The view holds the raw u8 storage of either i8 or u8 data; when
/// `src_is_signed`, values are re-biased by +128 so all comparisons and
/// differences can be done in unsigned space. The output is written back in
/// the representation selected by `out_is_signed`.
///
/// NOTE(review): the arithmetic appears to follow the gemmlowp-style
/// fixed-point softmax (exp on negative values, reciprocal of the sum) — the
/// exact Q-formats below are inferred from the helper names, confirm against
/// the `math` module.
fn softmax_quant_inner<D: Dimension>(
    mut view: ArrayViewMut<u8, D>,
    src_is_signed: bool,
    in_qp: QParams,
    out_is_signed: bool,
    out_qp: QParams,
) {
    // Decompose each scale into multiplier + power-of-two shift; a zero
    // multiplier means the scale is an exact power of two.
    let (_, in_scale) = in_qp.zp_scale();
    let (scale_in_multiplier, scale_in_shift) = convert_scale_to_mult_shift(in_scale).unwrap();
    let (_, out_scale) = out_qp.zp_scale();
    let (scale_out_multiplier, scale_out_shift) = convert_scale_to_mult_shift(out_scale).unwrap();
    // Shift bringing the scaled input difference into the fixed-point format
    // expected by exp_on_negative_values (presumably Q5.26 — TODO confirm).
    let shift = 26 - scale_in_shift;

    // One fixed-point exp value per element of the lane.
    let mut buffer = vec![0_i32; view.len()];

    // Re-bias signed bytes into unsigned space so max/diff work uniformly.
    let safe_u8 = if src_is_signed { |x: &u8| x.wrapping_add(128) } else { |x: &u8| *x };

    // Stabilize: work on (x - max), which is always <= 0.
    let max = view.iter().map(safe_u8).max().unwrap();
    view.iter().zip(buffer.iter_mut()).for_each(|(x, exp)| {
        let input_diff = safe_u8(x) as i32 - max as i32;

        // Apply the input scale: multiplier path for arbitrary scales,
        // shift-only path for exact power-of-two scales.
        let scaled_input_diff = if scale_in_multiplier != 0 {
            saturating_rounding_multiply_by_pot(
                saturating_rounding_doubling_high_mul(input_diff, scale_in_multiplier),
                shift as i32,
            )
        } else {
            saturating_rounding_multiply_by_pot(input_diff, shift as i32)
        };

        *exp = exp_on_negative_values(scaled_input_diff);
    });

    // Sum the exponentials after rescaling each term down by 12 bits to avoid
    // overflowing the i32 accumulator.
    let sum_of_exp = buffer.iter().map(|it| rescale(*it, 0, 12)).sum();

    // 1 / sum(exp), as a fixed-point value plus its headroom bit count.
    let (inv_sum_of_exp, num_bits_over_unit) = get_reciprocal(sum_of_exp, 12);

    // Final right-shift bringing the normalized value down to 8-bit range.
    let exponent = num_bits_over_unit as isize + 31 - 8;

    view.iter_mut().zip(buffer.iter()).for_each(|(it, exp)| {
        // exp / sum(exp), in fixed point.
        let unsat_output = rounding_divide_by_pot(
            saturating_rounding_doubling_high_mul(inv_sum_of_exp, *exp),
            exponent as i32,
        );

        // Apply the output scale (again: multiplier path vs power-of-two path).
        let unsat_scaled_output = {
            if scale_out_multiplier != 0 {
                let (inv_multiplier, num_bits) = get_reciprocal(scale_out_multiplier, 1);
                rounding_divide_by_pot(
                    saturating_rounding_doubling_high_mul(unsat_output, inv_multiplier),
                    (8 - scale_out_shift - 1 - num_bits as isize) as i32,
                )
            } else {
                rounding_divide_by_pot(unsat_output, (8 - scale_out_shift) as i32)
            }
        };

        // Saturate to the destination range and store as raw bytes; the i8
        // case is bit-cast into the u8 backing storage.
        #[allow(unknown_lints, unnecessary_transmutes)]
        if out_is_signed {
            *it = unsafe {
                std::mem::transmute::<i8, u8>(i32::max(
                    i32::min(unsat_scaled_output, i8::MAX as i32),
                    i8::MIN as i32,
                ) as i8)
            };
        } else {
            *it = i32::max(i32::min(unsat_scaled_output, u8::MAX as i32), u8::MIN as i32) as u8;
        }
    });
}
386
#[cfg(test)]
mod test {
    //! Proptest + fixed-case coverage for the quantized softmax, checked
    //! against the float reference implementation.
    use super::*;
    use crate::ops::nn::DataFormat::NCHW;
    use anyhow::Result;
    use num_traits::PrimInt;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use tract_data::internal::QParams::ZpScale;

    /// Asserts |found - expected| <= in_scale + out_scale, i.e. the error is
    /// within one quantization step on each side.
    fn assert_is_close(found: f32, expected: f32, in_dt: DatumType, out_dt: DatumType) {
        let (_, in_epsilon) = in_dt.zp_scale();
        let (_, out_epsilon) = out_dt.zp_scale();
        let epsilon = in_epsilon + out_epsilon;
        let error = (found - expected).abs();
        assert!(
            error <= epsilon,
            "epsilon eq failed: |{found:?}-{expected:?}|={error} should be <= {epsilon}"
        );
    }

    /// Strategy producing a random quantized tensor of the given shape.
    fn qtensor<T: PrimInt + Datum + Arbitrary>(shape: Vec<usize>) -> BoxedStrategy<Tensor> {
        let len = shape.iter().product::<usize>();
        let dt = q_datum::<T>((0.0001f32..0.1).boxed());
        (vec(any::<T>(), len..=len), dt)
            .prop_map(move |(vec, dt)| (ArrayD::from_shape_vec(shape.clone(), vec).unwrap(), dt))
            .prop_map(move |(array, dt)| {
                let mut tensor = array.into_tensor();
                unsafe { tensor.set_datum_type(dt) };
                tensor
            })
            .boxed()
    }

    /// Strategy producing a QI8/QU8 datum type with zero_point 0 and either a
    /// power-of-two scale or one drawn from `range`.
    fn q_datum<T: PrimInt + Datum>(range: BoxedStrategy<f32>) -> BoxedStrategy<DatumType> {
        let max_integer_bits = std::mem::size_of::<T>() * 8 - T::datum_type().is_signed() as usize;
        prop_oneof![
            (1usize..max_integer_bits).prop_map(|fixed_point| { 2f32.powi(-(fixed_point as i32)) }),
            range
        ]
        .prop_map(|scale| {
            if T::datum_type().is_signed() {
                DatumType::QI8(ZpScale { zero_point: 0, scale })
            } else {
                DatumType::QU8(ZpScale { zero_point: 0, scale })
            }
        })
        .boxed()
    }

    /// A full quantized-softmax case: input tensor, axes, and output type.
    #[derive(Debug)]
    struct SoftmaxProblem {
        data: Tensor,
        axes: TVec<usize>,
        output_dt: DatumType,
    }

    impl SoftmaxProblem {
        /// Runs the quantized op and compares against the float reference,
        /// element by element, within the combined quantization tolerance.
        fn check(&self) -> Result<()> {
            let inputs = tvec!(self.data.clone().into_tvalue());
            let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float());
            let softmax =
                Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() };

            // Quantized run.
            let result = softmax.eval(inputs)?;
            let result = args_1!(result);
            let result_float = result.cast_to::<f32>()?;

            // Float reference run on the dequantized input.
            let input_float = self.data.cast_to::<f32>()?;
            let inputs_float = tvec!(input_float.into_owned().into_tvalue());
            let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() };
            let reference_float = softmax_float.eval(inputs_float)?;
            let reference_array = args_1!(reference_float);
            let reference = reference_array.to_dense_array_view::<f32>()?;

            result_float
                .to_dense_array_view::<f32>()?
                .iter()
                .zip(reference.iter())
                .for_each(|(a, b)| assert_is_close(*a, *b, self.data.datum_type(), self.output_dt));
            Ok(())
        }
    }

    impl Arbitrary for SoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<SoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (1usize..2, 1usize..2, 1usize..5, 1usize..5, 0usize..4)
                .prop_flat_map(|(n, c, h, w, axis)| {
                    let shape_in: Vec<usize> =
                        NCHW.from_n_c_hw(n, c, [h, w]).unwrap().shape.to_vec();
                    (
                        prop_oneof![qtensor::<i8>(shape_in.clone()), qtensor::<u8>(shape_in)],
                        Just(tvec![axis]),
                        prop_oneof![
                            q_datum::<u8>((0.008f32..0.1).boxed()),
                            q_datum::<i8>((0.008f32..0.1).boxed())
                        ],
                    )
                })
                .prop_map(|(data, axes, output_dt)| SoftmaxProblem { data, axes, output_dt })
                .boxed()
        }
    }

    /// A single-lane fixed-point case exercised directly against
    /// `softmax_quant_inner`, always i8 input / u8 output.
    #[derive(Debug)]
    pub struct InnerSoftmaxProblem {
        in_qp: QParams,
        out_qp: QParams,
        data: Vec<i8>,
    }

    impl InnerSoftmaxProblem {
        /// Quantized result must be within 1 unit of the float reference.
        fn check(&self) -> Result<()> {
            let quantized = self.quantized();
            let reference = self.reference();
            assert!(quantized.iter().zip(reference.iter()).all(|(quantized, expected)| {
                let abs_diff = if *quantized > *expected {
                    quantized - *expected
                } else {
                    expected - *quantized
                };
                abs_diff <= 1
            }));
            Ok(())
        }

        /// Float reference: dequantize, float softmax, requantize to u8.
        fn reference(&self) -> Vec<u8> {
            let (in_zero_point, in_scale) = self.in_qp.zp_scale();
            let (out_zero_point, out_scale) = self.out_qp.zp_scale();
            let in_float =
                self.data.iter().map(|it| (*it as f32 - in_zero_point as f32) * in_scale).collect();
            let mut in_float_array = Array1::from_vec(in_float);
            softmax_inner(in_float_array.view_mut(), SoftmaxKind::default());
            let rescaled_output = in_float_array
                .iter()
                .map(|it| {
                    ((*it / out_scale).round() as i32 + out_zero_point)
                        .max(u8::MIN as i32)
                        .min(u8::MAX as i32) as u8
                })
                .collect();
            rescaled_output
        }

        /// Fixed-point run on the raw bytes of the i8 data.
        fn quantized(&self) -> Vec<u8> {
            let in_data: Vec<u8> = unsafe { std::mem::transmute(self.data.clone()) };
            let mut in_array = Array1::from_vec(in_data);
            softmax_quant_inner(in_array.view_mut(), true, self.in_qp, false, self.out_qp);
            in_array.to_vec()
        }
    }

    impl Arbitrary for InnerSoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<InnerSoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (
                prop_oneof![
                    q_datum::<i8>((0.0001f32..0.01).boxed()),
                    q_datum::<u8>((0.0001f32..0.01).boxed())
                ],
                prop_oneof![
                    q_datum::<u8>((0.008f32..0.1).boxed()),
                    q_datum::<i8>((0.008f32..0.1).boxed())
                ],
                vec(any::<i8>(), 1..10),
            )
                .prop_map(|(in_qp, out_qp, data)| InnerSoftmaxProblem {
                    in_qp: in_qp.qparams().unwrap(),
                    out_qp: out_qp.qparams().unwrap(),
                    data,
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_inner_prop(pb in any::<InnerSoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_prop(pb in any::<SoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    #[test]
    fn test_softmax_trivial_0() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 });
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_trivial_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 });
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_trivial_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 });
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 });
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, -4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_trivial_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 });
        let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 });
        let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![2], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_1() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.5 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_2() -> Result<()> {
        let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0001 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.008 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_softmax_3() -> Result<()> {
        let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.6220956 });
        let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5187921 });
        let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[13_u8, 218])?;
        unsafe { data.set_datum_type(input_dt) };

        let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let data = vec![0_i8, 1];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_not_pow_2_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.7298456 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    fn test_inner_softmax_not_pow_2_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.2123116 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.008 };
        let data = vec![118i8, 108];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    fn test_inner_softmax_not_pow_2_3() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.33034274 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.015625 };
        let data = vec![45i8, 43];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }
}