mod fixedpoint;
pub mod math;
use math::{
convert_scale_to_mult_shift, exp_on_negative_values, get_reciprocal, rescale,
rounding_divide_by_pot, saturating_rounding_doubling_high_mul,
saturating_rounding_multiply_by_pot,
};
use num_traits::Float;
use std::fmt::Debug;
use tract_num_traits::Zero;
use crate::internal::*;
use ndarray::prelude::*;
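/// Which exponential implementation the float kernels use: `Libc` goes
/// through the standard library `exp`, while `FastCompact` dispatches to the
/// vectorized approximation provided by tract-linalg.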
#[derive(Debug, Copy, Clone, Hash, Default, PartialEq)]
pub enum SoftmaxExp {
#[default]
Libc,
FastCompact,
}
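/// Softmax over the given axes. `quant_output_dt` must be set (and quantized)
/// if and only if the input is quantized; see `output_facts`.
///
/// A minimal usage sketch (not compiled here; `tensor2` and `tvec!` from the
/// tract prelude are assumed):
///
/// ```ignore
/// let softmax = Softmax { axes: tvec![1], ..Softmax::default() };
/// let output = softmax.eval(tvec!(tensor2(&[[1f32, 2., 3.]]).into_tvalue()))?;
/// ```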
#[derive(Debug, Clone, new, Hash, Default)]
pub struct Softmax {
pub axes: TVec<usize>,
pub quant_output_dt: Option<DatumType>,
pub exp: SoftmaxExp,
}
impl Op for Softmax {
fn name(&self) -> Cow<str> {
"Softmax".into()
}
fn info(&self) -> TractResult<Vec<String>> {
Ok(vec![format!("Axis: {:?}", self.axes), format!("Exp impl: {:?}", self.exp)])
}
op_as_typed_op!();
}
impl TypedOp for Softmax {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let dt = inputs[0].datum_type;
if dt.is_float() {
ensure!(
self.quant_output_dt.is_none(),
"Float softmax should not have quant_output_dt, have {:?}",
self.quant_output_dt
);
} else if dt.is_quantized() {
ensure!(
self.quant_output_dt.map(|q| q.is_quantized()).unwrap_or(false),
"Quantized softmax should have a quantized output type (got {:?})",
self.quant_output_dt
);
} else {
bail!(
"Unsupported datum type in softmax: input type {:?}, output type {:?}",
dt,
self.quant_output_dt
);
}
let fact = self.quant_output_dt.unwrap_or(dt).fact(inputs[0].shape.clone());
Ok(tvec!(fact))
}
fn axes_mapping(
&self,
inputs: &[&TypedFact],
outputs: &[&TypedFact],
) -> TractResult<AxesMapping> {
AxesMapping::natural(inputs, outputs)
}
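/// Propagate an axis change through the op: if every softmax axis survives
/// the change, rebuild the op with the translated axes, otherwise veto it.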
fn change_axes(
&self,
model: &TypedModel,
node: &TypedNode,
_io: InOut,
change: &AxisOp,
) -> TractResult<Option<AxisChangeConsequence>> {
let axes: Option<TVec<usize>> =
self.axes.iter().map(|it| change.transform_axis(*it)).collect();
if let Some(axes) = axes {
Ok(Some(AxisChangeConsequence::new(
model,
node,
Some(Box::new(Softmax { axes, ..self.clone() })),
change,
)))
} else {
Ok(None)
}
}
as_op!();
}
impl EvalOp for Softmax {
fn is_stateless(&self) -> bool {
true
}
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let input = args_1!(inputs);
let dt = input.datum_type();
let output = match dt {
DatumType::F64 => self.eval_t::<f64>(input)?,
DatumType::F32 => self.eval_t::<f32>(input)?,
DatumType::F16 => self.eval_t::<f16>(input)?,
DatumType::QI8(_) | DatumType::QU8(_) => self.eval_quant(input)?,
dt => bail!("Unsupported type {dt:?}"),
};
Ok(output)
}
}
impl Softmax {
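/// Float path: for every combination of coordinates on the non-softmax axes,
/// normalize the remaining lane in place. Contiguous f32/f16 lanes go through
/// the optimized tract-linalg kernels, anything else falls back to the
/// generic `softmax_inner`.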
fn eval_t<T>(&self, input: TValue) -> TractResult<TVec<TValue>>
where
T: Float + Datum + std::iter::Sum,
{
let mut iterating_shape: TVec<usize> = input.shape().into();
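// Collapse the softmax axes to 1 so that `indices` below enumerates one
// coordinate set per lane to normalize.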
for i in 0..iterating_shape.len() {
if self.axes.contains(&i) {
iterating_shape[i] = 1;
}
}
let mut output = input.into_tensor();
let mut view = output.to_array_view_mut::<T>()?;
for it_coords in tract_ndarray::indices(&*iterating_shape) {
let mut view = view.view_mut();
for ix in 0..iterating_shape.len() {
if !self.axes.contains(&ix) {
view.collapse_axis(Axis(ix), it_coords[ix]);
}
}
if let Some(slice) =
view.as_slice_mut().filter(|_| T::datum_type() == f32::datum_type())
{
let slice: &mut [f32] = unsafe { std::mem::transmute(slice) };
self.softmax_inner_slice_f32(slice)?;
} else if let Some(slice) =
view.as_slice_mut().filter(|_| T::datum_type() == f16::datum_type())
{
let slice: &mut [f16] = unsafe { std::mem::transmute(slice) };
self.softmax_inner_slice_f16(slice)?;
} else {
softmax_inner(view);
}
}
Ok(tvec!(output.into_tvalue()))
}
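/// Quantized path: operates on the raw byte representation, hands each lane
/// to `softmax_quant_inner`, then retags the buffer with the requested
/// quantized output type.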
fn eval_quant(&self, input: TValue) -> TractResult<TVec<TValue>> {
let mut iterating_shape: TVec<usize> = input.shape().into();
let output_dt =
self.quant_output_dt.context("Quantized softmax eval with no output type")?;
for i in 0..iterating_shape.len() {
if self.axes.contains(&i) {
iterating_shape[i] = 1;
}
}
let src_is_signed = input.datum_type().is_signed();
let out_is_signed = output_dt.is_signed();
let in_qp = input.datum_type().qparams().unwrap();
let out_qp = output_dt.qparams().unwrap();
let mut output = unsafe { input.into_tensor().into_array_unchecked::<u8>() };
for it_coords in tract_ndarray::indices(&*iterating_shape) {
let mut view = output.view_mut();
for ix in 0..iterating_shape.len() {
if !self.axes.contains(&ix) {
view.collapse_axis(Axis(ix), it_coords[ix]);
}
}
softmax_quant_inner(view, src_is_signed, in_qp, out_is_signed, out_qp);
}
let mut output_tensor = output.into_tensor();
unsafe { output_tensor.set_datum_type(output_dt) };
Ok(tvec!(output_tensor.into_tvalue()))
}
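/// Normalize one contiguous f16 lane: subtract the max, exponentiate (libc
/// `exp` or the tract-linalg fast-compact kernel), then multiply by the
/// reciprocal of the sum.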
fn softmax_inner_slice_f16(&self, slice: &mut [f16]) -> TractResult<()> {
let max = (tract_linalg::ops().max_f16)().run(slice)?;
let sum = match self.exp {
SoftmaxExp::Libc => {
let mut s = f16::zero();
for x in slice.iter_mut() {
let y = (*x - max).exp();
s += y;
*x = y;
}
s
}
SoftmaxExp::FastCompact => {
(tract_linalg::ops().softmax2_fastcompact_f16)().run_with_params(slice, max)?
}
};
let rsum = sum.recip();
(tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, rsum)?;
Ok(())
}
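/// f32 twin of `softmax_inner_slice_f16`, dispatching to the f32
/// tract-linalg kernels.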
fn softmax_inner_slice_f32(&self, slice: &mut [f32]) -> TractResult<()> {
let max = (tract_linalg::ops().max_f32)().run(slice)?;
let sum = match self.exp {
SoftmaxExp::Libc => {
let mut s = 0f32;
for x in slice.iter_mut() {
let y = (*x - max).exp();
s += y;
*x = y;
}
s
}
SoftmaxExp::FastCompact => {
(tract_linalg::ops().softmax2_fastcompact_f32)().run_with_params(slice, max)?
}
};
let rsum = sum.recip();
(tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?;
Ok(())
}
}
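/// Generic fallback: numerically stable softmax (the max is subtracted before
/// exponentiating) over a mutable ndarray view.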
fn softmax_inner<T: Float + Datum + std::iter::Sum, D: Dimension>(mut view: ArrayViewMut<T, D>) {
let max =
*view.iter().max_by(|i, j| i.partial_cmp(j).unwrap_or(std::cmp::Ordering::Less)).unwrap();
view.mapv_inplace(|x| (x - max).exp());
let exp_sum = view.iter().copied().sum();
view.mapv_inplace(|x| x / exp_sum);
}
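/// Fixed-point softmax over one lane of quantized bytes, in the spirit of the
/// gemmlowp/TFLite reference implementation: inputs are shifted so the max
/// maps to zero, exponentials are computed with `exp_on_negative_values`,
/// summed, inverted with `get_reciprocal`, and rescaled to the output
/// quantization, all in 32-bit fixed point.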
fn softmax_quant_inner<D: Dimension>(
mut view: ArrayViewMut<u8, D>,
src_is_signed: bool,
in_qp: QParams,
out_is_signed: bool,
out_qp: QParams,
) {
let (_, in_scale) = in_qp.zp_scale();
let (scale_in_multiplier, scale_in_shift) = convert_scale_to_mult_shift(in_scale).unwrap();
let (_, out_scale) = out_qp.zp_scale();
let (scale_out_multiplier, scale_out_shift) = convert_scale_to_mult_shift(out_scale).unwrap();
let shift = 26 - scale_in_shift;
let mut buffer = vec![0_i32; view.len()];
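// i8 lanes arrive as raw u8 bits; wrapping_add(128) maps them onto the
// unsigned scale so the max and the differences below are computed uniformly
// for both signednesses.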
let safe_u8 = if src_is_signed { |x: &u8| x.wrapping_add(128) } else { |x: &u8| *x };
let max = view.iter().map(safe_u8).max().unwrap();
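// Compute exp(scale * (x - max)) for every element, entirely in 32-bit fixed
// point. When `convert_scale_to_mult_shift` reports a zero multiplier (the
// scale is an exact power of two), only the shift is applied.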
view.iter().zip(buffer.iter_mut()).for_each(|(x, exp)| {
let input_diff = safe_u8(x) as i32 - max as i32;
let scaled_input_diff = if scale_in_multiplier != 0 {
saturating_rounding_multiply_by_pot(
saturating_rounding_doubling_high_mul(input_diff, scale_in_multiplier),
shift as i32,
)
} else {
saturating_rounding_multiply_by_pot(input_diff, shift as i32)
};
*exp = exp_on_negative_values(scaled_input_diff);
});
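// Sum the exponentials (rescaled to Q12) and take the fixed-point reciprocal;
// `num_bits_over_unit` feeds the final rescaling exponent.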
let sum_of_exp = buffer.iter().map(|it| rescale(*it, 0, 12)).sum();
let (inv_sum_of_exp, num_bits_over_unit) = get_reciprocal(sum_of_exp, 12);
let exponent = num_bits_over_unit as isize + 31 - 8;
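// Multiply each exponential by the reciprocal of the sum, rescale to the
// output quantization, and saturate into the destination byte.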
view.iter_mut().zip(buffer.iter()).for_each(|(it, exp)| {
let unsat_output = rounding_divide_by_pot(
saturating_rounding_doubling_high_mul(inv_sum_of_exp, *exp),
exponent as i32,
);
let unsat_scaled_output = {
if scale_out_multiplier != 0 {
let (inv_multiplier, num_bits) = get_reciprocal(scale_out_multiplier, 1);
rounding_divide_by_pot(
saturating_rounding_doubling_high_mul(unsat_output, inv_multiplier),
(8 - scale_out_shift - 1 - num_bits as isize) as i32,
)
} else {
rounding_divide_by_pot(unsat_output, (8 - scale_out_shift) as i32)
}
};
// Saturate to the destination range; signed outputs are stored as their u8
// bit pattern.
if out_is_signed {
*it = unsat_scaled_output.clamp(i8::MIN as i32, i8::MAX as i32) as i8 as u8;
} else {
*it = unsat_scaled_output.clamp(u8::MIN as i32, u8::MAX as i32) as u8;
}
});
}
#[cfg(test)]
mod test {
use super::*;
use crate::ops::nn::DataFormat::NCHW;
use anyhow::Result;
use num_traits::PrimInt;
use proptest::collection::vec;
use proptest::prelude::*;
use tract_data::internal::QParams::ZpScale;
fn assert_is_close(found: f32, expected: f32, in_dt: DatumType, out_dt: DatumType) {
let (_, in_epsilon) = in_dt.zp_scale();
let (_, out_epsilon) = out_dt.zp_scale();
let epsilon = f32::max(in_epsilon, out_epsilon);
let error = (found - expected).abs();
assert!(
error <= epsilon,
"epsilon eq failed: |{found:?}-{expected:?}|={error} should be <= {epsilon}"
);
}
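// Arbitrary quantized tensor of the given shape: random i8/u8 payload with a
// datum type drawn from `q_datum`.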
fn qtensor<T: PrimInt + Datum + Arbitrary>(shape: Vec<usize>) -> BoxedStrategy<Tensor> {
let len = shape.iter().product::<usize>();
let dt = q_datum::<T>((0.0001f32..0.1).boxed());
(vec(any::<T>(), len..=len), dt)
.prop_map(move |(vec, dt)| (ArrayD::from_shape_vec(shape.clone(), vec).unwrap(), dt))
.prop_map(move |(array, dt)| {
let mut tensor = array.into_tensor();
unsafe { tensor.set_datum_type(dt) };
tensor
})
.boxed()
}
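// Arbitrary quantized i8/u8 datum type with zero point 0: the scale is either
// an exact power of two (exercising the shift-only path in the fixed-point
// kernel) or drawn from `range`.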
fn q_datum<T: PrimInt + Datum>(range: BoxedStrategy<f32>) -> BoxedStrategy<DatumType> {
let max_integer_bits = std::mem::size_of::<T>() * 8 - T::datum_type().is_signed() as usize;
prop_oneof![
(1usize..max_integer_bits).prop_map(|fixed_point| { 2f32.powi(-(fixed_point as i32)) }),
range
]
.prop_map(|scale| {
if T::datum_type().is_signed() {
DatumType::QI8(ZpScale { zero_point: 0, scale })
} else {
DatumType::QU8(ZpScale { zero_point: 0, scale })
}
})
.boxed()
}
#[derive(Debug)]
struct SoftmaxProblem {
data: Tensor,
axes: TVec<usize>,
output_dt: DatumType,
}
impl SoftmaxProblem {
fn check(&self) -> Result<()> {
let inputs = tvec!(self.data.clone().into_tvalue());
let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float());
let softmax =
Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() };
let result = softmax.eval(inputs)?;
let result = args_1!(result);
let result_float = result.cast_to::<f32>()?;
let input_float = self.data.cast_to::<f32>()?;
let inputs_float = tvec!(input_float.into_owned().into_tvalue());
let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() };
let reference_float = softmax_float.eval(inputs_float)?;
let reference_array = args_1!(reference_float);
let reference = reference_array.to_array_view::<f32>()?;
result_float
.to_array_view::<f32>()?
.iter()
.zip(reference.iter())
.for_each(|(a, b)| assert_is_close(*a, *b, self.data.datum_type(), self.output_dt));
Ok(())
}
}
impl Arbitrary for SoftmaxProblem {
type Parameters = ();
type Strategy = BoxedStrategy<SoftmaxProblem>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
(1usize..2, 1usize..2, 1usize..5, 1usize..5, 0usize..4)
.prop_flat_map(|(n, c, h, w, axis)| {
let shape_in: Vec<usize> =
NCHW.from_n_c_hw(n, c, [h, w]).unwrap().shape.to_vec();
(
prop_oneof![qtensor::<i8>(shape_in.clone()), qtensor::<u8>(shape_in)],
Just(tvec![axis]),
prop_oneof![
q_datum::<u8>((0.008f32..0.1).boxed()),
q_datum::<i8>((0.008f32..0.1).boxed())
],
)
})
.prop_map(|(data, axes, output_dt)| SoftmaxProblem { data, axes, output_dt })
.boxed()
}
}
#[derive(Debug)]
pub struct InnerSoftmaxProblem {
in_qp: QParams,
out_qp: QParams,
data: Vec<i8>,
}
impl InnerSoftmaxProblem {
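// The fixed-point result may differ from the float reference by at most one
// quantum.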
fn check(&self) -> Result<()> {
let quantized = self.quantized();
let reference = self.reference();
assert!(quantized.iter().zip(reference.iter()).all(|(quantized, expected)| {
let abs_diff = if *quantized > *expected {
quantized - *expected
} else {
expected - *quantized
};
abs_diff <= 1
}));
Ok(())
}
fn reference(&self) -> Vec<u8> {
let (in_zero_point, in_scale) = self.in_qp.zp_scale();
let (out_zero_point, out_scale) = self.out_qp.zp_scale();
let in_float =
self.data.iter().map(|it| (*it as f32 - in_zero_point as f32) * in_scale).collect();
let mut in_float_array = Array1::from_vec(in_float);
softmax_inner(in_float_array.view_mut());
let rescaled_output = in_float_array
.iter()
.map(|it| {
((*it / out_scale).round() as i32 + out_zero_point)
.max(u8::MIN as i32)
.min(u8::MAX as i32) as u8
})
.collect();
rescaled_output
}
fn quantized(&self) -> Vec<u8> {
// Reinterpret each i8 as its u8 bit pattern (element-wise cast rather than a
// layout-sensitive Vec transmute).
let in_data: Vec<u8> = self.data.iter().map(|&x| x as u8).collect();
let mut in_array = Array1::from_vec(in_data);
softmax_quant_inner(in_array.view_mut(), true, self.in_qp, false, self.out_qp);
in_array.to_vec()
}
}
impl Arbitrary for InnerSoftmaxProblem {
type Parameters = ();
type Strategy = BoxedStrategy<InnerSoftmaxProblem>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
(
prop_oneof![
q_datum::<i8>((0.0001f32..0.01).boxed()),
q_datum::<u8>((0.0001f32..0.01).boxed())
],
prop_oneof![
q_datum::<u8>((0.008f32..0.1).boxed()),
q_datum::<i8>((0.008f32..0.1).boxed())
],
vec(any::<i8>(), 1..10),
)
.prop_map(|(in_qp, out_qp, data)| InnerSoftmaxProblem {
in_qp: in_qp.qparams().unwrap(),
out_qp: out_qp.qparams().unwrap(),
data,
})
.boxed()
}
}
proptest::proptest! {
#![proptest_config(ProptestConfig::with_cases(1000))]
#[test]
fn test_softmax_inner_prop(pb in any::<InnerSoftmaxProblem>()) {
pb.check().unwrap()
}
}
proptest::proptest! {
#![proptest_config(ProptestConfig::with_cases(1000))]
#[test]
fn test_softmax_prop(pb in any::<SoftmaxProblem>()) {
pb.check().unwrap()
}
}
#[test]
fn test_softmax_trivial_0() -> Result<()> {
let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 });
let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 });
let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
unsafe { data.set_datum_type(input_dt) };
let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
prob.check()?;
Ok(())
}
#[test]
fn test_softmax_trivial_1() -> Result<()> {
let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 });
let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.00390625 });
let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, 4])?;
unsafe { data.set_datum_type(input_dt) };
let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
prob.check()?;
Ok(())
}
#[test]
fn test_softmax_trivial_2() -> Result<()> {
let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0625 });
let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 });
let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_i8, 0, 0, -4])?;
unsafe { data.set_datum_type(input_dt) };
let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
prob.check()?;
Ok(())
}
#[test]
fn test_softmax_trivial_3() -> Result<()> {
let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.03125 });
let output_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0078125 });
let mut data = Tensor::from_shape(&[1, 1, 2, 2], &[0_u8, 0, 0, 4])?;
unsafe { data.set_datum_type(input_dt) };
let prob = SoftmaxProblem { data, axes: tvec![2], output_dt };
prob.check()?;
Ok(())
}
#[test]
fn test_softmax_1() -> Result<()> {
let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.5 });
let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5 });
let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
unsafe { data.set_datum_type(input_dt) };
let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
prob.check()?;
Ok(())
}
#[test]
fn test_softmax_2() -> Result<()> {
let input_dt = DatumType::QI8(ZpScale { zero_point: 0, scale: 0.0001 });
let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.008 });
let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[115_i8, 115])?;
unsafe { data.set_datum_type(input_dt) };
let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
prob.check()?;
Ok(())
}
#[test]
fn test_softmax_3() -> Result<()> {
let input_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.6220956 });
let output_dt = DatumType::QU8(ZpScale { zero_point: 0, scale: 0.5187921 });
let mut data = Tensor::from_shape(&[1, 1, 1, 2], &[13_u8, 218])?;
unsafe { data.set_datum_type(input_dt) };
let prob = SoftmaxProblem { data, axes: tvec![3], output_dt };
prob.check()?;
Ok(())
}
#[test]
fn test_inner_softmax_1() -> Result<()> {
let in_qp = ZpScale { zero_point: 0, scale: 0.03125 };
let out_qp = ZpScale { zero_point: 0, scale: 0.5 };
let data = vec![0_i8, 1];
let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
prob.check()?;
Ok(())
}
#[test]
fn test_inner_softmax_2() -> Result<()> {
let in_qp = ZpScale { zero_point: 0, scale: 0.5 };
let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
let data = vec![100i8, -28];
let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
prob.check()?;
Ok(())
}
#[test]
fn test_inner_softmax_not_pow_2_1() -> Result<()> {
let in_qp = ZpScale { zero_point: 0, scale: 0.7298456 };
let out_qp = ZpScale { zero_point: 0, scale: 0.03125 };
let data = vec![100i8, -28];
let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
prob.check()?;
Ok(())
}
#[test]
#[ignore]
fn test_inner_softmax_not_pow_2_2() -> Result<()> {
let in_qp = ZpScale { zero_point: 0, scale: 0.2123116 };
let out_qp = ZpScale { zero_point: 0, scale: 0.008 };
let data = vec![118i8, 108];
let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
prob.check()?;
Ok(())
}
#[test]
#[ignore]
fn test_inner_softmax_not_pow_2_3() -> Result<()> {
let in_qp = ZpScale { zero_point: 0, scale: 0.33034274 };
let out_qp = ZpScale { zero_point: 0, scale: 0.015625 };
let data = vec![45i8, 43];
let prob = InnerSoftmaxProblem { in_qp, out_qp, data };
prob.check()?;
Ok(())
}
}