use crate::infer::*;
use crate::internal::*;
/// Reshape operator, implemented through the `Expansion` trait.
///
/// The first input is the tensor to reshape; the second input carries the
/// requested shape. Entries equal to `0` or `-1` in the requested shape are
/// resolved against the input shape by `compute_shape`.
#[derive(Debug, Clone, new, Default, Hash)]
pub struct Reshape {}
impl Expansion for Reshape {
    /// Name under which this operator is reported.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("Reshape")
    }

    /// Registers the inference rules of the operator on the solver.
    ///
    /// The output datum type is tied to the one of the first input. Once the
    /// first input shape and the second input value are both known, the output
    /// shape is derived with `compute_shape`.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        let result = &outputs[0];
        let data = &inputs[0];
        s.equals(&result.datum_type, &data.datum_type)?;
        let spec = &inputs[1];
        s.given_2(&data.shape, &spec.value, move |s, ishape, shape| {
            // The shape tensor may hold any integer type: normalize it to TDim first.
            let dims = shape.cast_to::<TDim>()?;
            let shape = dims.as_slice::<TDim>()?;
            let oshape = compute_shape(&ishape, shape)
                .with_context(|| format!("Reshaping {ishape:?} to {shape:?}"))?;
            s.equals(&result.shape, ShapeFactoid::from(oshape))
        })
    }

    /// Wires the reshape into `model` as a chain of axis operations.
    ///
    /// The shape input (`inputs[1]`) must be a constant; otherwise an error is
    /// returned. Each operation produced by `to_axis_ops` becomes one node named
    /// `{prefix}.{index}`, each one fed by the output of the previous one,
    /// starting from `inputs[0]`.
    fn wire(
        &self,
        prefix: &str,
        model: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let spec = match model.outlet_fact(inputs[1])?.konst {
            Some(ref konst) => konst,
            None => bail!("shape input is variable"),
        };
        let source_dims: TVec<TDim> = model.outlet_fact(inputs[0])?.shape.to_tvec();
        let target = spec.cast_to::<TDim>()?;
        let ops = to_axis_ops(&source_dims, target.as_slice::<TDim>()?)?;
        // Thread the current outlet through every axis op, stopping at the first failure.
        ops.into_iter().enumerate().try_fold(tvec!(inputs[0]), |wire, (ix, op)| {
            model.wire_node(format!("{prefix}.{ix}"), op, &wire)
        })
    }
}
/// Resolves a reshape specification against an input shape.
///
/// Two special values are supported in `shape_spec`:
/// * `0` is replaced by the input dimension it lines up with,
/// * `-1` is replaced by the input volume divided by the product of all the
///   other (resolved) output dimensions.
///
/// Zeros are resolved in two passes: left to right up to the first `-1`, then
/// right to left (on the reversed spec and reversed input) up to the first `-1`.
///
/// # Errors
///
/// Fails when a `0` can not be lined up with an input axis, when the input
/// dimensions run out while matching a spec entry, or when the `-1` entry does
/// not resolve to an exact quotient.
fn compute_shape(input: &[TDim], shape_spec: &[TDim]) -> TractResult<TVec<TDim>> {
    // Walks `spec` in order, stopping at the first -1, and substitutes each 0 with
    // the input dimension it is aligned with. `dims` yields the input dimensions in
    // the same direction as `spec` is being walked.
    fn resolve_zeros<'a>(
        mut dims: std::iter::Peekable<impl Iterator<Item = &'a TDim>>,
        spec: &mut [TDim],
    ) -> TractResult<()> {
        // Product of the input dims consumed so far and not yet absorbed by a spec entry.
        let mut pending = 1.to_dim();
        for slot in spec.iter_mut() {
            if *slot == (-1).into() {
                break;
            }
            if *slot == 0.into() {
                // A zero only makes sense when no partial input volume is left over.
                if pending != TDim::one() {
                    bail!("Invalid remaining dim");
                }
                // Peek only: the dimension itself is consumed by the loop below.
                *slot = (*dims.peek().context("Invalid")?).clone();
            }
            // Pull input dims until the pending volume is exactly divisible by the
            // slot, then keep the quotient as the new pending volume.
            pending = loop {
                match pending.maybe_div(slot) {
                    Ok((quotient, denom)) if denom == 1 => break quotient,
                    _ => pending *= dims.next().context("Invalid")?,
                }
            };
        }
        Ok(())
    }

    let mut shape: TVec<TDim> = shape_spec.into();

    // Forward pass, then the same pass from the trailing end.
    resolve_zeros(input.iter().peekable(), &mut shape)?;
    shape.reverse();
    resolve_zeros(input.iter().rev().peekable(), &mut shape)?;
    shape.reverse();

    let wildcard = shape.iter().position(|d| *d == (-1).into());
    if let Some(pos) = wildcard {
        let input_vol: TDim = input.iter().product();
        let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product();
        let (quotient, denom) = input_vol.maybe_div(&shape_vol)?;
        if denom != 1 {
            bail!("invalid")
        }
        shape[pos] = quotient;
    }
    Ok(shape)
}
pub fn to_axis_ops(input_orig: &[TDim], output_spec: &[TDim]) -> TractResult<TVec<AxisOp>> {
let final_output = compute_shape(input_orig, output_spec)?;
let mut stack: TVec<AxisOp> = tvec!();
'top: loop {
let current_input =
stack.iter().try_fold(TVec::from(input_orig), |mut shape, op| -> TractResult<_> {
op.change_shape_array(&mut shape, false)?;
Ok(shape)
})?;
if current_input == final_output {
return Ok(stack);
}
if let Some(common) =
current_input.iter().zip(final_output.iter()).position(|(a, b)| a != b)
{
if current_input[common].is_one() {
stack.push(AxisOp::Rm(common));
} else if final_output[common].is_one() {
stack.push(AxisOp::Add(common));
} else {
for i in common..current_input.len() {
let i_group = ¤t_input[common..i + 1];
let i_volume: TDim = i_group.iter().product();
for o in common..final_output.len() {
let o_group = &final_output[common..o + 1];
let o_volume: TDim = o_group.iter().product();
if i_volume == o_volume {
stack.push(AxisOp::Reshape(common, i_group.into(), o_group.into()));
continue 'top;
}
}
}
todo!()
}
} else if final_output.len() > current_input.len() {
stack.push(AxisOp::Add(current_input.len()));
} else {
stack.push(AxisOp::Rm(current_input.len() - 1));
}
}
}
// Unit tests for `compute_shape` and `to_axis_ops`.
#[cfg(test)]
mod tests {
use super::*;
use AxisOp::*;
// Builds a reference to an array, converting each element with `.clone().into()`
// (the expected type, e.g. `TDim`, is inferred from the call site).
macro_rules! s {
($($a:expr),*) => {&[ $($a.clone().into()),* ]}
}
// Builds an `AxisOp::Reshape` at axis `$at`, from the dims before `=>` to the dims after it.
macro_rules! r {
($at: expr ; $($from:expr),* => $($to:expr),*) => {
AxisOp::Reshape($at, tvec!($($from.into()),*), tvec!($($to.into()),*))
}
}
// A spec whose single dim can not be matched by the input dims is rejected.
#[test]
fn compute_invalid() {
assert!(compute_shape(s![3, 4, 5], s!(100)).is_err());
}
// Leading zeros copy the corresponding input dims.
#[test]
fn compute_with_leading_zero() {
assert_eq!(&*compute_shape(s![3, 4, 5], s!(0, 0, 5)).unwrap(), s![3, 4, 5])
}
// A zero followed by a dim merging the two trailing input axes.
#[test]
fn compute_with_leading_zero_with_flatten() {
assert_eq!(&*compute_shape(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(), s![2, 3, 35])
}
// A zero after a -1 is resolved from the end of the input shape.
#[test]
fn compute_with_trailing_zero() {
assert_eq!(&*compute_shape(s![3, 4, 5], s!(3, -1, 0)).unwrap(), s![3, 4, 5])
}
// Regression case: zeros aligned with a symbol-based dim and a dim of 1, then a -1.
#[test]
fn compute_bug_1() {
let table = SymbolTable::default();
let s = table.new_with_prefix("S");
assert_eq!(&*compute_shape(s![s, 1, 2, 128], s!(0, 0, -1)).unwrap(), s![s, 1, 256])
}
// Regression case: zeros aligned with two symbol-based dims, then a -1.
#[test]
fn compute_bug_2() {
let table = SymbolTable::default();
let b = table.new_with_prefix("B");
let s = table.new_with_prefix("S");
assert_eq!(&*compute_shape(s![s, b, 2, 128], s!(0, 0, -1)).unwrap(), s![s, b, 256])
}
// Dropping a leading axis of size 1 yields a single `Rm`.
#[test]
fn axis_op_rm_begin() {
assert_eq!(&*to_axis_ops(s![1, 2, 3], s!(2, 3)).unwrap(), &[Rm(0)])
}
// Dropping a trailing axis of size 1 yields a single `Rm`.
#[test]
fn axis_op_rm_end() {
assert_eq!(&*to_axis_ops(s![2, 3, 1], s!(2, 3)).unwrap(), &[Rm(2)])
}
// Inserting a leading axis of size 1 yields a single `Add`.
#[test]
fn axis_op_insert_begin() {
assert_eq!(&*to_axis_ops(s![2, 3], s!(1, 2, 3)).unwrap(), &[Add(0)])
}
// Inserting a trailing axis of size 1 yields a single `Add`.
#[test]
fn axis_op_insert_end() {
assert_eq!(&*to_axis_ops(s![2, 3], s!(2, 3, 1)).unwrap(), &[Add(2)])
}
// Merging the two trailing axes yields a single `Reshape` at axis 2.
#[test]
fn axis_op_merge() {
assert_eq!(&*to_axis_ops(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(), &[r!(2 ; 5,7 => 35 )])
}
// Mix of axis removal, axis insertions and a merge in one reshape.
#[test]
fn axis_op_complex() {
assert_eq!(
&*to_axis_ops(s![1, 2, 3, 5, 7], s!(2, 1, 3, 35, 1)).unwrap(),
&[Rm(0), Add(1), r!(3 ; 5,7 => 35 ), Add(4)]
)
}
}