1mod fixedpoint;
2pub mod math;
3
4use math::{
5 convert_scale_to_mult_shift, exp_on_negative_values, get_reciprocal, rescale,
6 rounding_divide_by_pot, saturating_rounding_doubling_high_mul,
7 saturating_rounding_multiply_by_pot,
8};
9use num_traits::Float;
10use std::fmt::Debug;
11use tract_num_traits::Zero;
12
13use crate::internal::*;
14use ndarray::prelude::*;
15
16#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
17pub enum SoftmaxKind {
18 Softmax(SoftmaxExp),
19 LogSoftmax,
20}
21
22impl Default for SoftmaxKind {
23 fn default() -> Self {
24 SoftmaxKind::Softmax(SoftmaxExp::default())
25 }
26}
27
/// Strategy used to evaluate the exponentials of a (non-log) softmax.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum SoftmaxExp {
    /// Element-wise `exp` of the scalar float type. This is the default.
    #[default]
    Libc,
    /// Uses the `softmax2_fastcompact_*` kernels from `tract_linalg` on contiguous f32/f16 lanes.
    FastCompact,
}
35
36#[derive(Debug, Clone, new, Hash, Default, PartialEq, Eq)]
37pub struct Softmax {
38 pub axes: TVec<usize>,
39 pub quant_output_dt: Option<DatumType>,
40 pub kind: SoftmaxKind,
41}
42
43impl Op for Softmax {
44 fn name(&self) -> StaticName {
45 match self.kind {
46 SoftmaxKind::Softmax(_) => "Softmax".into(),
47 SoftmaxKind::LogSoftmax => "LogSoftmax".into(),
48 }
49 }
50
51 fn info(&self) -> TractResult<Vec<String>> {
52 let mut infos = vec![format!("Axis: {:?}", self.axes)];
53 if let SoftmaxKind::Softmax(exp) = self.kind {
54 infos.push(format!("Exp impl: {exp:?}"))
55 };
56 Ok(infos)
57 }
58
59 op_as_typed_op!();
60}
61
62impl TypedOp for Softmax {
63 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
64 let dt = inputs[0].datum_type;
65 if dt.is_float() {
66 ensure!(
67 self.quant_output_dt.is_none(),
68 "Float softmax should not have quant_output_dt, have {:?}",
69 self.quant_output_dt
70 );
71 } else if dt.is_quantized() {
72 ensure!(
73 self.quant_output_dt.map(|q| q.is_quantized()).unwrap_or(false),
74 "Quantized softmax should have a quantized output type (got {:?})",
75 self.quant_output_dt
76 );
77 } else {
78 bail!(
79 "Unsupported datum type in softmax: input type {:?}, output type {:?}",
80 dt,
81 self.quant_output_dt
82 );
83 }
84
85 let fact = self.quant_output_dt.unwrap_or(dt).fact(inputs[0].shape.clone());
86 Ok(tvec!(fact))
87 }
88
89 fn input_roi(
90 &self,
91 model: &TypedModel,
92 node: &TypedNode,
93 ) -> TractResult<Option<TVec<Option<TDim>>>> {
94 crate::optim::propagate_roi::bubble_roi(model, node)
95 }
96
97 fn axes_mapping(
98 &self,
99 inputs: &[&TypedFact],
100 outputs: &[&TypedFact],
101 ) -> TractResult<AxesMapping> {
102 AxesMapping::natural(inputs, outputs)
103 }
104
105 fn change_axes(
106 &self,
107 model: &TypedModel,
108 node: &TypedNode,
109 _io: InOut,
110 change: &AxisOp,
111 ) -> TractResult<Option<AxisChangeConsequence>> {
112 let axes: Option<TVec<usize>> =
113 self.axes.iter().map(|it| change.transform_axis(*it)).collect();
114 if let Some(axes) = axes {
115 Ok(Some(AxisChangeConsequence::new(
116 model,
117 node,
118 Some(Box::new(Softmax { axes, ..self.clone() })),
119 change,
120 )))
121 } else {
122 Ok(None)
123 }
124 }
125
126 as_op!();
127}
128
129impl EvalOp for Softmax {
130 fn is_stateless(&self) -> bool {
131 true
132 }
133
134 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
135 let input = args_1!(inputs);
136 let dt = input.datum_type();
137
138 let output = match dt {
139 DatumType::F64 => self.eval_t::<f64>(input)?,
140 DatumType::F32 => self.eval_t::<f32>(input)?,
141 DatumType::F16 => self.eval_t::<f16>(input)?,
142 DatumType::QI8(_) | DatumType::QU8(_) => self.eval_quant(input)?,
143 dt => bail!("Unsupported type {dt:?}"),
144 };
145 Ok(output)
146 }
147}
148
149impl Softmax {
150 fn eval_t<T>(&self, input: TValue) -> TractResult<TVec<TValue>>
151 where
152 T: Float + Datum + std::iter::Sum,
153 {
154 let mut iterating_shape: TVec<usize> = input.shape().into();
155
156 for i in 0..iterating_shape.len() {
157 if self.axes.contains(&i) {
158 iterating_shape[i] = 1
159 }
160 }
161
162 let mut output = input.into_tensor();
163 let mut output_plain = output.try_as_plain_mut()?;
164 let mut view = output_plain.to_array_view_mut::<T>()?;
165
166 for it_coords in tract_ndarray::indices(&*iterating_shape) {
167 let mut view = view.view_mut();
168 for ix in 0..iterating_shape.len() {
169 if !self.axes.contains(&ix) {
170 view.collapse_axis(Axis(ix), it_coords[ix]);
171 }
172 }
173 if let Some(slice) =
174 view.as_slice_mut().filter(|_| T::datum_type() == f32::datum_type())
175 {
176 let slice: &mut [f32] = unsafe { std::mem::transmute(slice) };
177 self.softmax_inner_slice_f32(slice, self.kind)?;
178 } else if let Some(slice) =
179 view.as_slice_mut().filter(|_| T::datum_type() == f16::datum_type())
180 {
181 let slice: &mut [f16] = unsafe { std::mem::transmute(slice) };
182 self.softmax_inner_slice_f16(slice, self.kind)?;
183 } else {
184 softmax_inner(view, self.kind);
185 }
186 }
187
188 Ok(tvec!(output.into_tvalue()))
189 }
190
191 fn eval_quant(&self, input: TValue) -> TractResult<TVec<TValue>> {
192 if self.kind == SoftmaxKind::LogSoftmax {
193 bail!("Quantized LogSoftmax is not supported")
194 }
195 let mut iterating_shape: TVec<usize> = input.shape().into();
196 let output_dt =
197 self.quant_output_dt.context("Quandized softmax eval with no output type")?;
198
199 for i in 0..iterating_shape.len() {
200 if self.axes.contains(&i) {
201 iterating_shape[i] = 1
202 }
203 }
204
205 let src_is_signed = input.datum_type().is_signed();
207 let out_is_signed = output_dt.is_signed();
208 let in_qp = input.datum_type().qparams().unwrap(); let out_qp = output_dt.qparams().unwrap(); let mut output = unsafe { input.into_tensor().into_array_unchecked::<u8>() };
211
212 for it_coords in tract_ndarray::indices(&*iterating_shape) {
213 let mut view = output.view_mut();
214 for ix in 0..iterating_shape.len() {
215 if !self.axes.contains(&ix) {
216 view.collapse_axis(Axis(ix), it_coords[ix]);
217 }
218 }
219 softmax_quant_inner(view, src_is_signed, in_qp, out_is_signed, out_qp);
220 }
221
222 let mut output_tensor = output.into_tensor();
223 unsafe { output_tensor.set_datum_type(output_dt) };
224 Ok(tvec!(output_tensor.into_tvalue()))
225 }
226
227 fn softmax_inner_slice_f16(&self, slice: &mut [f16], kind: SoftmaxKind) -> TractResult<()> {
228 let max = (tract_linalg::ops().max_f16)().run(slice)?;
229 match kind {
230 SoftmaxKind::Softmax(exp_impl) => {
231 let sum = match exp_impl {
232 SoftmaxExp::Libc => {
233 let mut s = f16::zero();
234 slice.iter_mut().for_each(|x| {
235 *x = (*x - max).exp();
236 s += *x;
237 });
238 s
239 }
240 SoftmaxExp::FastCompact => (tract_linalg::ops().softmax2_fastcompact_f16)()
241 .run_with_params(slice, max)?,
242 };
243 let rsum = sum.recip();
244 (tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, rsum)?;
245 }
246 SoftmaxKind::LogSoftmax => {
247 let mut exp_sum = f16::zero();
248 slice.iter_mut().for_each(|x| {
249 *x -= max;
250 exp_sum += x.exp();
251 });
252 let log_sum = exp_sum.ln();
253 slice.iter_mut().for_each(|x| *x -= log_sum);
254 }
255 }
256 Ok(())
257 }
258
259 fn softmax_inner_slice_f32(&self, slice: &mut [f32], kind: SoftmaxKind) -> TractResult<()> {
260 let max = (tract_linalg::ops().max_f32)().run(slice)?;
261 match kind {
262 SoftmaxKind::Softmax(exp_impl) => {
263 let sum = match exp_impl {
264 SoftmaxExp::Libc => {
265 let mut s = f32::zero();
266 slice.iter_mut().for_each(|x| {
267 *x = (*x - max).exp();
268 s += *x;
269 });
270 s
271 }
272 SoftmaxExp::FastCompact => (tract_linalg::ops().softmax2_fastcompact_f32)()
273 .run_with_params(slice, max)?,
274 };
275 let rsum = sum.recip();
276 (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?;
277 }
278 SoftmaxKind::LogSoftmax => {
279 let mut exp_sum = f32::zero();
280 slice.iter_mut().for_each(|x| {
281 *x -= max;
282 exp_sum += x.exp();
283 });
284 let log_sum = exp_sum.ln();
285 slice.iter_mut().for_each(|x| *x -= log_sum);
286 }
287 }
288 Ok(())
289 }
290}
291
292fn softmax_inner<T: Float + Datum + std::iter::Sum, D: Dimension>(
293 mut view: ArrayViewMut<T, D>,
294 kind: SoftmaxKind,
295) {
296 let max =
297 *view.iter().max_by(|i, j| i.partial_cmp(j).unwrap_or(std::cmp::Ordering::Less)).unwrap();
298 view.mapv_inplace(|x| x - max);
299 let exp_sum = view.iter().map(|&x| x.exp()).sum();
300 match kind {
301 SoftmaxKind::Softmax(_) => {
302 view.mapv_inplace(|x| x.exp() / exp_sum);
303 }
304 SoftmaxKind::LogSoftmax => {
305 let log_sum = exp_sum.ln();
306 view.mapv_inplace(|x| x - log_sum);
307 }
308 }
309}
310
311fn softmax_quant_inner<D: Dimension>(
312 mut view: ArrayViewMut<u8, D>,
313 src_is_signed: bool,
314 in_qp: QParams,
315 out_is_signed: bool,
316 out_qp: QParams,
317) {
318 let (_, in_scale) = in_qp.zp_scale();
319 let (scale_in_multiplier, scale_in_shift) = convert_scale_to_mult_shift(in_scale).unwrap();
320 let (_, out_scale) = out_qp.zp_scale();
321 let (scale_out_multiplier, scale_out_shift) = convert_scale_to_mult_shift(out_scale).unwrap();
322 let shift = 26 - scale_in_shift;
323
324 let mut buffer = vec![0_i32; view.len()];
326
327 let safe_u8 = if src_is_signed { |x: &u8| x.wrapping_add(128) } else { |x: &u8| *x };
329
330 let max = view.iter().map(safe_u8).max().unwrap();
331 view.iter().zip(buffer.iter_mut()).for_each(|(x, exp)| {
332 let input_diff = safe_u8(x) as i32 - max as i32;
333
334 let scaled_input_diff = if scale_in_multiplier != 0 {
336 saturating_rounding_multiply_by_pot(
337 saturating_rounding_doubling_high_mul(input_diff, scale_in_multiplier),
338 shift as i32,
339 )
340 } else {
341 saturating_rounding_multiply_by_pot(input_diff, shift as i32)
342 };
343
344 *exp = exp_on_negative_values(scaled_input_diff);
346 });
347
348 let sum_of_exp = buffer.iter().map(|it| rescale(*it, 0, 12)).sum();
351
352 let (inv_sum_of_exp, num_bits_over_unit) = get_reciprocal(sum_of_exp, 12);
355
356 let exponent = num_bits_over_unit as isize + 31 - 8;
358
359 view.iter_mut().zip(buffer.iter()).for_each(|(it, exp)| {
360 let unsat_output = rounding_divide_by_pot(
362 saturating_rounding_doubling_high_mul(inv_sum_of_exp, *exp),
363 exponent as i32,
364 );
365
366 let unsat_scaled_output = {
368 if scale_out_multiplier != 0 {
369 let (inv_multiplier, num_bits) = get_reciprocal(scale_out_multiplier, 1);
370 rounding_divide_by_pot(
371 saturating_rounding_doubling_high_mul(unsat_output, inv_multiplier),
372 (8 - scale_out_shift - 1 - num_bits as isize) as i32,
373 )
374 } else {
375 rounding_divide_by_pot(unsat_output, (8 - scale_out_shift) as i32)
376 }
377 };
378
379 #[allow(unknown_lints, unnecessary_transmutes)]
382 if out_is_signed {
383 *it = unsafe {
384 std::mem::transmute::<i8, u8>(i32::max(
385 i32::min(unsat_scaled_output, i8::MAX as i32),
386 i8::MIN as i32,
387 ) as i8)
388 };
389 } else {
390 *it = i32::max(i32::min(unsat_scaled_output, u8::MAX as i32), u8::MIN as i32) as u8;
391 }
392 });
393}
394
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::nn::DataFormat::NCHW;
    use anyhow::Result;
    use num_traits::PrimInt;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use tract_data::internal::QParams::ZpScale;

    /// Asserts `found` is within one input quantization step plus one output quantization
    /// step of `expected`.
    fn assert_is_close(found: f32, expected: f32, in_dt: DatumType, out_dt: DatumType) {
        let (_, in_epsilon) = in_dt.zp_scale();
        let (_, out_epsilon) = out_dt.zp_scale();
        let epsilon = in_epsilon + out_epsilon;
        let error = (found - expected).abs();
        assert!(
            error <= epsilon,
            "epsilon eq failed: |{found:?}-{expected:?}|={error} should be <= {epsilon}"
        );
    }

    /// Strategy for a tensor of `shape` filled with arbitrary `T` values and labelled with an
    /// arbitrary zero-point-0 quantized type of the same signedness.
    fn qtensor<T: PrimInt + Datum + Arbitrary>(shape: Vec<usize>) -> BoxedStrategy<Tensor> {
        let len = shape.iter().product::<usize>();
        let dt = q_datum::<T>((0.0001f32..0.1).boxed());
        (vec(any::<T>(), len..=len), dt)
            .prop_map(move |(vec, dt)| (ArrayD::from_shape_vec(shape.clone(), vec).unwrap(), dt))
            .prop_map(move |(array, dt)| {
                let mut tensor = array.into_tensor();
                unsafe { tensor.set_datum_type(dt) };
                tensor
            })
            .boxed()
    }

    /// Strategy for a zero-point-0 `QI8`/`QU8` type (picked by the signedness of `T`). The
    /// scale is either a power of two or drawn from `range`.
    fn q_datum<T: PrimInt + Datum>(range: BoxedStrategy<f32>) -> BoxedStrategy<DatumType> {
        let max_integer_bits = std::mem::size_of::<T>() * 8 - T::datum_type().is_signed() as usize;
        prop_oneof![
            (1usize..max_integer_bits).prop_map(|fixed_point| { 2f32.powi(-(fixed_point as i32)) }),
            range
        ]
        .prop_map(|scale| {
            if T::datum_type().is_signed() {
                DatumType::QI8(ZpScale { zero_point: 0, scale })
            } else {
                DatumType::QU8(ZpScale { zero_point: 0, scale })
            }
        })
        .boxed()
    }

    /// Full-op check: the quantized softmax must match the float softmax of the dequantized
    /// input within `assert_is_close` tolerance.
    #[derive(Debug)]
    struct SoftmaxProblem {
        data: Tensor,
        axes: TVec<usize>,
        output_dt: DatumType,
    }

    impl SoftmaxProblem {
        fn check(&self) -> Result<()> {
            let inputs = tvec!(self.data.clone().into_tvalue());
            let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float());
            let softmax =
                Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() };

            let result = softmax.eval(inputs)?;
            let result = args_1!(result);
            let result_float = result.cast_to::<f32>()?;

            let input_float = self.data.cast_to::<f32>()?;
            let inputs_float = tvec!(input_float.into_owned().into_tvalue());
            let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() };
            let reference_float = softmax_float.eval(inputs_float)?;
            let reference_array = args_1!(reference_float);
            let reference = reference_array.to_plain_array_view::<f32>()?;

            result_float
                .to_plain_array_view::<f32>()?
                .iter()
                .zip(reference.iter())
                .for_each(|(a, b)| assert_is_close(*a, *b, self.data.datum_type(), self.output_dt));
            Ok(())
        }
    }

    impl Arbitrary for SoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<SoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (1usize..2, 1usize..2, 1usize..5, 1usize..5, 0usize..4)
                .prop_flat_map(|(n, c, h, w, axis)| {
                    let shape_in: Vec<usize> =
                        NCHW.from_n_c_hw(n, c, [h, w]).unwrap().shape.to_vec();
                    (
                        prop_oneof![qtensor::<i8>(shape_in.clone()), qtensor::<u8>(shape_in)],
                        Just(tvec![axis]),
                        prop_oneof![
                            q_datum::<u8>((0.008f32..0.1).boxed()),
                            q_datum::<i8>((0.008f32..0.1).boxed())
                        ],
                    )
                })
                .prop_map(|(data, axes, output_dt)| SoftmaxProblem { data, axes, output_dt })
                .boxed()
        }
    }

    /// Builds a `SoftmaxProblem` from raw storage `values` labelled `input_dt` and checks it
    /// with a softmax over `axis`.
    fn check_quantized<T: Datum + Copy>(
        input_dt: DatumType,
        output_dt: DatumType,
        shape: &[usize],
        values: &[T],
        axis: usize,
    ) -> Result<()> {
        let mut data = Tensor::from_shape(shape, values)?;
        unsafe { data.set_datum_type(input_dt) };
        SoftmaxProblem { data, axes: tvec![axis], output_dt }.check()
    }

    /// Kernel-level check: `softmax_quant_inner` on signed input bytes must be within one
    /// output step of the float softmax, requantized to u8.
    #[derive(Debug)]
    pub struct InnerSoftmaxProblem {
        in_qp: QParams,
        out_qp: QParams,
        data: Vec<i8>,
    }

    impl InnerSoftmaxProblem {
        fn check(&self) -> Result<()> {
            let quantized = self.quantized();
            let reference = self.reference();
            assert!(
                quantized.iter().zip(&reference).all(|(q, r)| q.abs_diff(*r) <= 1),
                "quantized {quantized:?} differs from reference {reference:?} by more than 1"
            );
            Ok(())
        }

        fn reference(&self) -> Vec<u8> {
            let (in_zero_point, in_scale) = self.in_qp.zp_scale();
            let (out_zero_point, out_scale) = self.out_qp.zp_scale();
            let in_float: Vec<f32> =
                self.data.iter().map(|it| (*it as f32 - in_zero_point as f32) * in_scale).collect();
            let mut in_float_array = Array1::from_vec(in_float);
            softmax_inner(in_float_array.view_mut(), SoftmaxKind::default());
            in_float_array
                .iter()
                .map(|it| {
                    ((*it / out_scale).round() as i32 + out_zero_point)
                        .max(u8::MIN as i32)
                        .min(u8::MAX as i32) as u8
                })
                .collect()
        }

        fn quantized(&self) -> Vec<u8> {
            // Element-wise `as u8` keeps each i8 bit pattern. Transmuting a Vec<i8> into a
            // Vec<u8> is not sanctioned by the Vec layout guarantees.
            let in_data: Vec<u8> = self.data.iter().map(|&x| x as u8).collect();
            let mut in_array = Array1::from_vec(in_data);
            softmax_quant_inner(in_array.view_mut(), true, self.in_qp, false, self.out_qp);
            in_array.to_vec()
        }
    }

    impl Arbitrary for InnerSoftmaxProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<InnerSoftmaxProblem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (
                prop_oneof![
                    q_datum::<i8>((0.0001f32..0.01).boxed()),
                    q_datum::<u8>((0.0001f32..0.01).boxed())
                ],
                prop_oneof![
                    q_datum::<u8>((0.008f32..0.1).boxed()),
                    q_datum::<i8>((0.008f32..0.1).boxed())
                ],
                vec(any::<i8>(), 1..10),
            )
                .prop_map(|(in_qp, out_qp, data)| InnerSoftmaxProblem {
                    in_qp: in_qp.qparams().unwrap(),
                    out_qp: out_qp.qparams().unwrap(),
                    data,
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_inner_prop(pb in any::<InnerSoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    proptest::proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn test_softmax_prop(pb in any::<SoftmaxProblem>()) {
            pb.check().unwrap()
        }
    }

    #[test]
    fn test_softmax_trivial_0() -> Result<()> {
        check_quantized(
            DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 }),
            DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 }),
            &[1, 1, 2, 2],
            &[0_u8, 0, 0, 4],
            3,
        )
    }

    #[test]
    fn test_softmax_trivial_1() -> Result<()> {
        check_quantized(
            DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 }),
            DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 }),
            &[1, 1, 2, 2],
            &[0_i8, 0, 0, 4],
            3,
        )
    }

    #[test]
    fn test_softmax_trivial_2() -> Result<()> {
        check_quantized(
            DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 }),
            DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 }),
            &[1, 1, 2, 2],
            &[0_i8, 0, 0, -4],
            3,
        )
    }

    #[test]
    fn test_softmax_trivial_3() -> Result<()> {
        check_quantized(
            DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 }),
            DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 }),
            &[1, 1, 2, 2],
            &[0_u8, 0, 0, 4],
            2,
        )
    }

    #[test]
    fn test_softmax_1() -> Result<()> {
        check_quantized(
            DatumType::QI8(ZpScale { zero_point: 0, scale: 0.5 }),
            DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5 }),
            &[1, 1, 1, 2],
            &[115_i8, 115],
            3,
        )
    }

    #[test]
    fn test_softmax_2() -> Result<()> {
        check_quantized(
            DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0001 }),
            DatumType::QU8(ZpScale { zero_point: 0, scale: 0.008 }),
            &[1, 1, 1, 2],
            &[115_i8, 115],
            3,
        )
    }

    #[test]
    fn test_softmax_3() -> Result<()> {
        check_quantized(
            DatumType::QU8(ZpScale { zero_point: 0, scale: 0.6220956 }),
            DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5187921 }),
            &[1, 1, 1, 2],
            &[13_u8, 218],
            3,
        )
    }

    #[test]
    fn test_inner_softmax_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let data = vec![0_i8, 1];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.5 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    fn test_inner_softmax_not_pow_2_1() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.7298456 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
        let data = vec![100i8, -28];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    fn test_inner_softmax_not_pow_2_2() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.2123116 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.008 };
        let data = vec![118i8, 108];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }

    #[test]
    #[ignore]
    fn test_inner_softmax_not_pow_2_3() -> Result<()> {
        let in_qp = ZpScale { zero_point: 0, scale: 0.33034274 };
        let out_qp = ZpScale { zero_point: 0, scale: 0.015625 };
        let data = vec![45i8, 43];

        let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
        prob.check()?;
        Ok(())
    }
}