use crate::internal::*;
use crate::ops::array::MultiBroadcastTo;
use crate::ops::cnn::wire_reshape_bias_for_bin;
use crate::ops::cnn::KernelFormat;
use crate::ops::cnn::PoolSpec;
use crate::ops::einsum::EinSum;
/// Deconvolution operator description.
///
/// The operator takes three inputs (data, kernel, bias). It is not computed
/// directly: it gets wired as an `EinSum` followed by a `DeconvSum`.
#[derive(Clone, Debug, new, Hash)]
pub struct Deconv {
    /// Geometry of the operation. This file reads its `data_format`,
    /// `input_channels`, `output_channels`, `padding` and strides.
    pub pool_spec: PoolSpec,
    /// Layout of the kernel tensor; used to locate its input-channel and
    /// spatial dimensions.
    pub kernel_format: KernelFormat,
    /// One value per spatial axis, forwarded to the output shape computation.
    // NOTE(review): presumably the extra output size per axis ("output padding") — confirm.
    pub adjustments: TVec<usize>,
    /// Number of groups; `1` means a regular, non-grouped deconvolution.
    pub group: usize,
}
impl Deconv {
    /// Builds the einsum expression that contracts the prepared kernel with
    /// the prepared input.
    ///
    /// The kernel axes are `g m k` and the output axes are `N g m n`. The
    /// input axes are `N g n k` when the data format is channel-last and
    /// `N g k n` otherwise. The `N` axis is removed when the data format has
    /// no batch axis, and the `g` axis is removed when `group == 1`.
    fn einsum_expr(&self) -> String {
        let data_format = &self.pool_spec.data_format;
        let mut expr = String::from(if data_format.c_is_last() {
            "gmk,Ngnk->Ngmn"
        } else {
            "gmk,Ngkn->Ngmn"
        });
        if !data_format.has_n() {
            expr.retain(|c| c != 'N');
        }
        if self.group == 1 {
            expr.retain(|c| c != 'g');
        }
        expr
    }

    /// Wires this deconvolution into `target` as an `EinSum` followed by a
    /// `DeconvSum`.
    ///
    /// * `name` - prefix used for the names of all nodes created here.
    /// * `target` - model (or patch) receiving the new nodes.
    /// * `inputs` - data, kernel and bias outlets, in that order. They are
    ///   read as `inputs[0]`, `inputs[1]` and `inputs[2]`, so indexing panics
    ///   if fewer than three outlets are given.
    ///
    /// Returns the output outlets of the `{name}.deconv_sum` node.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while resolving shapes, parsing the einsum
    /// expression or wiring a node.
    fn wire_with_deconv_sum(
        &self,
        name: &str,
        target: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let input_shape = target.outlet_fact(inputs[0])?.shape.clone();
        let shape = self.pool_spec.data_format.shape(input_shape.to_tvec())?;
        let geo_dim = shape.hw_dims().iter().product();

        // Collapse all spatial axes into a single one: (N) I HW or (N) HW I.
        let mut input = target.wire_node(
            format!("{name}.reshaped_input"),
            AxisOp::Reshape(shape.h_axis(), shape.hw_dims().into(), tvec!(geo_dim)),
            &[inputs[0]],
        )?;

        // Grouped case: split the channel axis into (G, I/G) and make sure the
        // group axis sits right after the optional batch axis.
        if self.group != 1 {
            let batch_axes = self.pool_spec.data_format.has_n() as usize;
            // With the spatial axes collapsed, channels come right after the
            // batch axis, or one axis further when the format is channel-last.
            let i_axis = batch_axes + self.pool_spec.data_format.c_is_last() as usize;
            let i_dim = target.outlet_fact(input[0])?.shape[i_axis].clone();
            input = target.wire_node(
                format!("{name}.reshaped_input_for_group"),
                AxisOp::Reshape(
                    i_axis,
                    tvec![i_dim.clone()],
                    tvec!(self.group.to_dim(), i_dim / self.group),
                ),
                &input,
            )?;
            if self.pool_spec.data_format.c_is_last() {
                // (N) HW G I/G -> (N) G HW I/G
                input = target.wire_node(
                    format!("{name}.group_axis_left"),
                    AxisOp::Move(batch_axes + 1, batch_axes),
                    &input,
                )?;
            }
        }

        // Bring the kernel to the layout expected by the einsum.
        // NOTE(review): going by the helper's name, the ops below yield a
        // G, O, I, HW kernel — confirm against `KernelFormat`.
        let mut kernel = tvec!(inputs[1]);
        let kernel_fact = target.outlet_fact(kernel[0])?.clone();
        for (ix, op) in self
            .kernel_format
            .kernel_as_group_o_i_hw_ops(&kernel_fact.shape, self.group)
            .into_iter()
            .enumerate()
        {
            kernel = target.wire_node(format!("{name}.kernel.{ix}"), op, &kernel)?;
        }
        // Swap axes 2 and 3, then collapse at axis 1, so that the kernel ends
        // up with the three axes `g m k` used by the einsum expression.
        kernel = target.wire_node(format!("{name}.kernel.mv_i"), AxisOp::Move(2, 3), &kernel)?;
        kernel =
            AxisOp::wire_collapse_axis(target, format!("{name}.kernel.col_ohw"), kernel[0], 1)?;
        if self.group == 1 {
            // No group axis in the einsum expression when there is one group.
            kernel = target.wire_node(format!("{name}.kernel.rm_g"), AxisOp::Rm(0), &kernel)?;
        }

        let einsum = target.wire_node(
            format!("{name}.einsum"),
            EinSum {
                axes: self.einsum_expr().parse()?,
                operating_dt: kernel_fact.datum_type,
                q_params: None,
            },
            &[kernel[0], input[0]],
        )?;

        // Reshape the bias against the channel axis, then broadcast it to the
        // full output shape.
        let mut bias = wire_reshape_bias_for_bin(
            target,
            format!("{name}.reshape_bias"),
            inputs[2],
            shape.rank(),
            shape.c_axis(),
            self.pool_spec.output_channels,
        )?[0];
        let output_shape = super::output_shape(&self.pool_spec, &shape.shape, &self.adjustments)?;
        bias = target.wire_node(
            format!("{name}.broadcast_bias"),
            MultiBroadcastTo { shape: output_shape.into() },
            &[bias],
        )?[0];

        // NOTE(review): `DeconvSum` presumably accumulates the einsum result
        // into the output positions on top of the broadcast bias — confirm.
        let deconv_sum = target.wire_node(
            format!("{name}.deconv_sum"),
            super::deconv_sum::DeconvSum::new(
                self.pool_spec.clone(),
                self.kernel_format,
                input_shape,
                self.adjustments.clone(),
                self.group,
            ),
            &[einsum[0], bias],
        )?;
        Ok(deconv_sum)
    }
}
impl Op for Deconv {
    /// Static operator name.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("Deconv")
    }

    /// Human-readable details: the debug rendering of the pool spec.
    fn info(&self) -> TractResult<Vec<String>> {
        let pool_spec = format!("{:?}", self.pool_spec);
        Ok(vec![pool_spec])
    }

    op_as_typed_op!();
}
impl EvalOp for Deconv {
    /// The operator reports itself as stateless.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Evaluates the deconvolution by building and running a throw-away model.
    ///
    /// The three inputs are added as constants named `s0`, `s1` and `s2`, the
    /// operator is wired on top of them under the `adhoc` name prefix, and the
    /// resulting model is run without any runtime input.
    ///
    /// # Errors
    ///
    /// Fails if `inputs` does not hold exactly three values, or if building or
    /// running the ad-hoc model fails.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        ensure!(inputs.len() == 3);
        let mut model = TypedModel::default();
        let mut consts: TVec<OutletId> = tvec![];
        for (ix, value) in inputs.into_iter().enumerate() {
            let outlet = model.add_const(format!("s{ix}"), value.into_tensor())?;
            consts.push(outlet);
        }
        let outputs = self.wire_with_deconv_sum("adhoc", &mut model, &consts)?;
        model.set_output_outlets(&outputs)?;
        let runnable = model.into_runnable()?;
        runnable.run(tvec![]).context("In adhoc deconvolution eval")
    }
}
impl TypedOp for Deconv {
    /// Computes the single output fact.
    ///
    /// Checks that exactly three inputs are given and that the configured
    /// input channel count matches both the data input and the kernel. The
    /// output keeps the data input's datum type; its shape comes from
    /// `super::output_shape`.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(inputs.len() == 3);
        let data_fact = inputs[0];
        let k_fact = inputs[1];
        ensure!(
            &self.pool_spec.input_channels.to_dim()
                == self.pool_spec.data_format.shape(&inputs[0].shape)?.c()
        );
        ensure!(
            self.pool_spec.input_channels.to_dim()
                == *self.kernel_format.input_channels(&k_fact.shape, self.group)
        );
        let output_shape =
            super::output_shape(&self.pool_spec, &data_fact.shape, &self.adjustments)?;
        let output_fact = data_fact.datum_type.fact(&output_shape);
        Ok(tvec!(output_fact))
    }

    /// Describes how input and output axes relate.
    ///
    /// Starting from a fully disconnected mapping:
    /// * the channel axis is named `I` on the data input and `O` on the output
    ///   (the two are not linked);
    /// * the batch axis, when the data format has one, is named `N` and linked
    ///   between data input and output;
    /// * a spatial axis is linked between data input and output only when the
    ///   kernel dimension is one, the stride is 1, `padding.valid_dim(ix, true)`
    ///   holds and the adjustment is 0 for that axis.
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let data_fact = inputs[0];
        let kernel_fact = inputs[1];
        let shape = self.pool_spec.data_format.shape(&data_fact.shape)?;
        let c_axis = shape.c_axis();
        let mut mapping = AxesMapping::disconnected(inputs, outputs)?
            .renaming((InOut::In(0), c_axis), 'I')?
            .renaming((InOut::Out(0), c_axis), 'O')?;
        if let Some(n_axis) = shape.n_axis() {
            mapping = mapping.renaming((InOut::In(0), n_axis), 'N')?;
            mapping = mapping.linking('N', (InOut::Out(0), n_axis))?;
        }
        let first_spatial_axis = shape.h_axis();
        // Spatial axis labels: H, W, X, Y, Z, then lowercase letters onwards.
        let labels = "HWXYZ".chars().chain('a'..);
        let kernel_spatial = self.kernel_format.spatial_shape(&kernel_fact.shape);
        for ((ix, dim), label) in kernel_spatial.iter().enumerate().zip(labels) {
            let untouched = dim.is_one()
                && self.pool_spec.stride(ix) == 1
                && self.pool_spec.padding.valid_dim(ix, true)
                && self.adjustments[ix] == 0;
            if !untouched {
                continue;
            }
            let axis = ix + first_spatial_axis;
            mapping = mapping.renaming((InOut::In(0), axis), label)?;
            mapping = mapping.linking((InOut::In(0), axis), (InOut::Out(0), axis))?;
        }
        Ok(mapping)
    }

    /// Replaces this node with its `EinSum` + `DeconvSum` expansion.
    ///
    /// Always returns a patch: the node inputs are tapped, the expansion is
    /// wired under the node's name, and the node output is shunted to the
    /// first wired outlet.
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut patch = TypedModelPatch::default();
        let taps = patch.taps(model, &node.inputs)?;
        let wired = self
            .wire_with_deconv_sum(&node.name, &mut patch, &taps)
            .context("In wire_with_deconv_sum")?;
        patch.shunt_outside(model, node.id.into(), wired[0])?;
        Ok(Some(patch))
    }

    as_op!();
}