use crate::internal::*;
use tract_core::num_traits::Zero;
use tract_core::ops::cnn::{MaxPool, PaddingSpec, PoolSpec, SumPool};
// Associates each pooling op with the pulsifier function defined in this file:
// `MaxPool` -> `pulsify_max_pool`, `SumPool` -> `pulsify_sum_pool`.
// NOTE(review): `register_all!` is defined elsewhere in the crate; confirm how the
// registration is consumed.
register_all!(MaxPool: pulsify_max_pool, SumPool: pulsify_sum_pool);
/// Pulsifier for [`MaxPool`].
///
/// The streamed input is prepared by [`pulsify_pooled_input`], using the minimum value of the
/// input datum type as the padding filler. A `MaxPool` carrying the adjusted pool spec is then
/// wired into `target` under the original node name.
///
/// Returns `Ok(None)` whenever `pulsify_pooled_input` returns `Ok(None)`.
fn pulsify_max_pool(
    op: &MaxPool,
    source: &TypedModel,
    node: &TypedNode,
    target: &mut PulsedModel,
    mapping: &HashMap<OutletId, OutletId>,
    _symbol: &Symbol,
    _pulse: &TDim,
) -> TractResult<Option<TVec<OutletId>>> {
    /// Rank-0 tensor holding the minimum value of `D`.
    fn lowest<D: Datum + tract_core::num_traits::Bounded>() -> Tensor {
        tensor0(D::min_value())
    }

    let input_dt = target.outlet_fact(mapping[&node.inputs[0]])?.datum_type;
    let filler = dispatch_numbers!(lowest(input_dt)());
    let prepared =
        pulsify_pooled_input(&op.pool_spec, source, node, target, mapping, Some(filler))?;
    match prepared {
        Some((wire, pool_spec)) => {
            let pulsed_op = MaxPool { pool_spec, ..op.clone() };
            let outlets = target.wire_node(&node.name, pulsed_op, &[wire])?;
            Ok(Some(outlets))
        }
        None => Ok(None),
    }
}
/// Pulsifier for [`SumPool`].
///
/// Delegates the input preparation to [`pulsify_pooled_input`] without supplying any padding
/// value, then wires a `SumPool` carrying the adjusted pool spec into `target` under the
/// original node name.
///
/// Returns `Ok(None)` whenever `pulsify_pooled_input` returns `Ok(None)`.
fn pulsify_sum_pool(
    op: &SumPool,
    source: &TypedModel,
    node: &TypedNode,
    target: &mut PulsedModel,
    mapping: &HashMap<OutletId, OutletId>,
    _symbol: &Symbol,
    _pulse: &TDim,
) -> TractResult<Option<TVec<OutletId>>> {
    let prepared = pulsify_pooled_input(&op.pool_spec, source, node, target, mapping, None)?;
    match prepared {
        Some((wire, pool_spec)) => {
            let pulsed_op = SumPool { pool_spec, ..op.clone() };
            let outlets = target.wire_node(&node.name, pulsed_op, &[wire])?;
            Ok(Some(outlets))
        }
        None => Ok(None),
    }
}
impl PulsedOp for SumPool {
    /// Output facts of a pulsed sum pool: the shared pooling computation, with the output
    /// keeping the datum type of the first input.
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
        let output_dt = inputs[0].datum_type;
        pulsed_output_facts(&self.pool_spec, inputs, output_dt)
    }

    as_op!();
    pulsed_op_to_typed_op!();
}
impl PulsedOp for MaxPool {
    /// Output facts of a pulsed max pool.
    ///
    /// The first fact comes from the shared pooling computation and keeps the datum type of
    /// the first input. When `with_index_outputs` is set, a second fact is appended: a copy of
    /// the first one whose datum type is the requested index type.
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
        let value_dt = inputs[0].datum_type;
        let mut facts = pulsed_output_facts(&self.pool_spec, inputs, value_dt)?;
        if let Some(index_dt) = self.with_index_outputs {
            let mut index_fact = facts[0].clone();
            index_fact.datum_type = index_dt;
            facts.push(index_fact);
        }
        Ok(facts)
    }

    as_op!();
    pulsed_op_to_typed_op!();
}
pub fn pulsed_output_facts(
spec: &PoolSpec,
inputs: &[&PulsedFact],
output_dt: DatumType,
) -> TractResult<TVec<PulsedFact>> {
let ishape = spec.data_format.shape(&inputs[0].shape)?;
let computed = spec.padding.compute(
ishape.hw_dims(),
&spec.kernel_shape,
&spec.dilations(),
&spec.strides(),
);
let spatial_dims = computed.into_iter().map(|d| d.convoluted).collect::<TVec<TDim>>();
let oshape = spec.data_format.from_n_c_hw(
ishape.n().cloned().unwrap_or_else(|| 1.to_dim()),
spec.output_channel_override.map(|d| d.to_dim()).unwrap_or_else(|| ishape.c().clone()),
spatial_dims,
)?;
let mut fact = inputs[0].clone();
let stream = fact.stream.as_mut().unwrap();
let input_shape = spec.data_format.shape(&*fact.shape)?;
let geo_axis = stream.axis - input_shape.h_axis();
let dilation = spec.dilations.as_ref().map(|d| d[geo_axis]).unwrap_or(1);
let kernel_len = (spec.kernel_shape[geo_axis] - 1) * dilation;
let stride = spec.strides.as_ref().and_then(|v| v.get(geo_axis).cloned()).unwrap_or(1);
stream.delay /= stride;
stream.dim = (stream.dim.clone() - kernel_len.to_dim()).div_ceil(stride as _);
fact.shape = oshape.shape.into();
fact.datum_type = output_dt;
Ok(tvec!(fact))
}
/// Prepares the streamed input of a pooling-like node so a regular pooling op can consume it
/// pulse by pulse.
///
/// Starting from the outlet mapped to `node.inputs[0]`, the function:
/// 1. returns `Ok(None)` when the stream runs along the batch axis of `spec.data_format`;
/// 2. checks that the pulse is a multiple of the stride on the streaming axis;
/// 3. wires a `Delay` node (named `<node>.delay`) when an overlap between successive pulses
///    or an extra delay is required;
/// 4. when padding is needed on the streaming axis, wires a `PulsePad` node (named
///    `<node>.pulse-pad`) filling with `padding_value`, and returns a pool spec whose padding
///    is explicit: zero on the streaming axis, the computed padding on every other spatial
///    axis.
///
/// `_source` is unused.
///
/// # Returns
///
/// `Ok(Some((wire, spec)))` with the outlet to feed to the pooling op and the pool spec that
/// op must use (a clone of `spec` when no padding is involved).
///
/// # Errors
///
/// * the input fact carries no stream or pulse information (this used to panic),
/// * the stream runs along the channel axis,
/// * the pulse is not a multiple of the stride,
/// * padding is required but `padding_value` is `None`,
/// * a padding amount can not be converted to `usize`, or wiring a node fails.
pub fn pulsify_pooled_input(
    spec: &PoolSpec,
    _source: &TypedModel,
    node: &TypedNode,
    target: &mut PulsedModel,
    mapping: &HashMap<OutletId, OutletId>,
    padding_value: Option<Tensor>,
) -> TractResult<Option<(OutletId, PoolSpec)>> {
    let mut wire = mapping[&node.inputs[0]];
    let input_fact: PulsedFact = target.outlet_fact(wire)?.clone();
    let input_stream = match input_fact.stream.as_ref() {
        Some(stream) => stream,
        None => bail!("Can not pulsify cnn pooling op {}: input is not streaming", node.name),
    };
    let input_shape = spec.data_format.shape(input_fact.shape.clone())?;
    if Some(input_stream.axis) == input_shape.n_axis() {
        return Ok(None);
    }
    if input_stream.axis == input_shape.c_axis() {
        bail!("Can not pulsify cnn pooling ops along the input channel axis");
    }
    // Index of the streaming axis among the spatial axes.
    let geo_axis = input_stream.axis - input_shape.h_axis();
    let stride = spec.strides.as_ref().and_then(|v| v.get(geo_axis).cloned()).unwrap_or(1);
    let pulse = match input_fact.pulse() {
        Some(pulse) => pulse,
        None => bail!("Can not pulsify cnn pooling op {}: input has no pulse", node.name),
    };
    if !(pulse.to_owned() % (stride as i64)).is_zero() {
        bail!("Pulsification requires pulse ({}) to be a stride ({}) multiple", pulse, stride)
    }
    let dilation = spec.dilations.as_ref().map(|d| d[geo_axis]).unwrap_or(1);
    let kernel_len = (spec.kernel_shape[geo_axis] - 1) * dilation;
    // Number of input frames shared by two successive output windows.
    let overlap = (kernel_len + 1).saturating_sub(stride);
    let computed_padding = spec.padding.compute_one(
        geo_axis,
        &input_stream.dim,
        spec.kernel_shape[geo_axis],
        spec.dilation(geo_axis),
        spec.stride(geo_axis),
    );
    let before = computed_padding.pad_before.to_usize()?;
    // A negative `early` means the left padding exceeds the current delay plus the overlap:
    // compensate with an extra delay so the subtraction below can not underflow.
    let early = input_stream.delay as isize + overlap as isize - before as isize;
    let mut extra_delay = if early < 0 { (-early) as usize } else { 0 };
    // Round the extra delay up so that the delayed input is a multiple of the stride.
    let delayed_input = input_stream.delay + overlap + extra_delay - before;
    let misalignment = delayed_input % stride;
    if misalignment > 0 {
        extra_delay += stride - misalignment;
    }
    if overlap > 0 || extra_delay > 0 {
        wire = target.wire_node(
            format!("{}.delay", node.name),
            tract_pulse_opl::ops::Delay::new_typed(
                &(&input_fact).into(),
                input_stream.axis,
                extra_delay,
                overlap,
            ),
            &[wire],
        )?[0];
    }
    let has_padding =
        !computed_padding.pad_before.is_zero() || !computed_padding.pad_after.is_zero();
    if !has_padding {
        return Ok(Some((wire, spec.clone())));
    }

    // From here on, padding on the streaming axis is handled by a dedicated PulsePad node.
    use tract_core::ops::array::PadMode;
    let value = match padding_value {
        Some(tensor) => tensor.into_arc_tensor(),
        None => bail!("No padding value for streaming pool operation"),
    };
    let op = tract_pulse_opl::ops::PulsePad {
        axis: input_stream.axis,
        before,
        after: computed_padding.pad_after,
        begin_input: input_stream.delay + extra_delay + overlap,
        end_input: input_stream.dim.clone()
            + input_stream.delay
            + extra_delay
            + overlap.to_dim(),
        mode: PadMode::Constant(value),
        overlap,
    };
    wire = target.wire_node(format!("{}.pulse-pad", node.name), op, &[wire])?[0];

    // The pooling op must no longer pad the streaming axis itself: give it explicit paddings,
    // zero on the streaming axis and the computed values on the other spatial axes.
    let mut bef = tvec!();
    let mut aft = tvec!();
    for ix in 0..input_shape.hw_rank() {
        if ix == geo_axis {
            bef.push(0);
            aft.push(0);
        } else {
            let c = spec.padding.compute_one(
                ix,
                &input_shape.hw_dims()[ix],
                spec.kernel_shape[ix],
                spec.dilations()[ix],
                spec.strides()[ix],
            );
            bef.push(c.pad_before.to_usize()?);
            aft.push(c.pad_after.to_usize()?);
        };
    }
    Ok(Some((
        wire,
        PoolSpec { padding: PaddingSpec::ExplicitOnnxPool(bef, aft, false), ..spec.clone() },
    )))
}
#[cfg(test)]
mod test {
    use tract_pulse_opl::tract_core::ops::cnn::{ConvUnary, PoolSpec};
    use tract_pulse_opl::tract_nnef::internal::*;
    use crate::model::{PulsedModel, PulsedModelExt};

    /// A 1D convolution (kernel of 2, one frame of explicit left padding, no right padding)
    /// pulsified with a pulse of 1 must report a zero delay on its output stream.
    #[test]
    fn left_padded_conv_wo_delay() -> TractResult<()> {
        let mut typed = TypedModel::default();
        let sym = typed.symbol_table.sym("S");
        let streaming_len = sym.to_dim();
        let input = typed.add_source("source", f32::fact(dims!(1, streaming_len)))?;

        let pool_spec = PoolSpec {
            data_format: tract_core::ops::nn::DataFormat::CHW,
            dilations: None,
            strides: None,
            kernel_shape: tvec![2],
            padding: tract_core::ops::cnn::PaddingSpec::ExplicitOnnxPool(
                tvec![1],
                tvec![0],
                false,
            ),
            output_channel_override: Some(1),
        };
        let conv_op = ConvUnary {
            pool_spec,
            kernel_fmt: tract_core::ops::cnn::KernelFormat::OIHW,
            kernel: rctensor3(&[[[1f32, 2f32]]]),
            group: 1,
            bias: None,
            q_params: None,
        };
        let outputs = typed.wire_node("conv", conv_op, &[input])?;
        typed.set_output_outlets(&outputs)?;

        let pulsed = PulsedModel::new(&typed, sym, &1.to_dim())?;
        let delay = pulsed.output_fact(0)?.stream.as_ref().unwrap().delay;
        assert_eq!(delay, 0);
        Ok(())
    }
}