use std::fmt::Debug;
use crate::internal::*;
use crate::ops::array::Slice;
use crate::tract_data::itertools::Itertools;
mod eval;
use super::array::TypedConcat;
use super::math::add;
mod as_matmul;
mod codegen;
#[cfg(test)]
mod proptest;
pub use as_matmul::{rewrite_einsums_as_matmul, BasicMatMul};
#[derive(Clone, Hash)]
pub struct EinSum {
    pub axes: AxesMapping,
    pub operating_dt: DatumType,
    pub q_params: Option<DatumType>,
}
impl EinSum {
    pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: None }
    }
    pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: Some(output_type) }
    }
    #[allow(unused_variables)]
    pub(crate) fn propagate_axis(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        axis: usize,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut new_axis = self.axes.axis((io, axis))?.clone();
        let repr = new_axis.repr;
        let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
        let mut taps = tvec!();
        for (ix, input) in node.inputs.iter().enumerate() {
            let mut tap = patch.tap_model(model, *input)?;
            if new_axis.inputs[ix].len() > 1 {
                return Ok(None); } else if new_axis.inputs[ix].is_empty() {
                let insert_at = self.axes.rank(InOut::In(ix));
                tap = patch.wire_node(
                    format!("{}.prop_axis.{}.input_{}", &node.name, new_axis.repr, ix),
                    AxisOp::Add(insert_at),
                    &[tap],
                )?[0];
                new_axis.inputs[ix].push(insert_at);
            }
            taps.push(tap);
        }
        let must_rm_axis: Option<usize> = if new_axis.outputs[0].len() == 0 {
            let insert_at = self.axes.rank(InOut::Out(0));
            new_axis.outputs[0].push(insert_at);
            Some(insert_at)
        } else {
            None
        };
        let new_expr = self
            .axes
            .iter_all_axes()
            .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
            .collect_vec();
        let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
        let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
        if let Some(position) = must_rm_axis {
            wire = patch.wire_node(
                format!("{}.prop_axis.{}.output", &node.name, repr),
                AxisOp::Rm(position),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }
    #[allow(clippy::comparison_chain)]
    fn declutter_after_concat(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.q_params.is_some() {
            return Ok(None);
        }
        'outer: for (slot, input) in node.inputs.iter().enumerate() {
            let precursor = model.node(input.node);
            if let Some(concat) = precursor.op_as::<TypedConcat>() {
                let offsets = concat.offsets(&model.node_input_facts(precursor.id)?)?;
                let axis_info = self.axes.axis((InOut::In(slot), concat.axis))?;
                if axis_info.outputs[0].len() > 0 {
                    continue;
                }
                let mut patch = TypedModelPatch::new(format!(
                    "Split Einsum for concat on axis {}",
                    axis_info.repr
                ));
                let mut inputs: TVec<TVec<OutletId>> = tvec!();
                for (slot, input) in node.inputs.iter().enumerate() {
                    let tap = patch.tap_model(model, *input)?;
                    if axis_info.inputs[slot].len() > 1 {
                        continue 'outer;
                    } else if axis_info.inputs[slot].len() == 1 {
                        let mut slices = tvec!();
                        for (start, end) in offsets.iter().cloned().tuple_windows() {
                            let wire = patch.wire_node(
                                format!(
                                    "{}.concat-einsum-slice-{}.{}.{}..{}",
                                    node.name, axis_info.repr, slot, start, end
                                ),
                                Slice { axis: axis_info.inputs[slot][0], start, end },
                                &[tap],
                            )?;
                            slices.push(wire[0]);
                        }
                        inputs.push(slices);
                    } else {
                        inputs.push(tvec!(tap)); };
                }
                let mut einsums = tvec!();
                for (ix, (start, end)) in offsets.iter().tuple_windows().enumerate() {
                    let mut einsum_inputs = tvec!();
                    for input_ix in 0..node.inputs.len() {
                        einsum_inputs
                            .push(inputs[input_ix].get(ix).cloned().unwrap_or(inputs[input_ix][0]));
                    }
                    let einsum = patch.wire_node(
                        format!(
                            "{}.concat-einsum-{}.{}..{}",
                            node.name, axis_info.repr, start, end
                        ),
                        self.clone(),
                        &einsum_inputs,
                    )?[0];
                    einsums.push(einsum);
                }
                let wire = if let Some(axis) = axis_info.outputs[0].first().cloned() {
                    patch.wire_node(
                        format!("{}.concat-einsum-{}.concat", node.name, axis_info.repr),
                        TypedConcat { axis },
                        &einsums,
                    )?[0]
                } else {
                    let mut wire = einsums[0];
                    for ix in 1..einsums.len() {
                        wire = patch.wire_node(
                            format!("{}.concat-einsum-{}.add-{}", node.name, axis_info.repr, ix),
                            add(),
                            &[wire, einsums[ix]],
                        )?[0]
                    }
                    wire
                };
                patch.shunt_outside(model, node.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
}
impl Debug for EinSum {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
    }
}
impl Op for EinSum {
    fn name(&self) -> Cow<str> {
        "EinSum".into()
    }
    fn info(&self) -> TractResult<Vec<String>> {
        let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
        if let Some(qp) = self.q_params {
            info.push(format!("Quantized output: {qp:?}"));
        }
        Ok(info)
    }
    op_as_typed_op!();
}
impl EvalOp for EinSum {
    fn is_stateless(&self) -> bool {
        true
    }
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let output = if let Some(qp) = self.q_params {
            eval::eval_q(&self.axes, qp, inputs)
        } else {
            dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
        }?;
        Ok(tvec!(output.into_tvalue()))
    }
}
impl TypedOp for EinSum {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(inputs.len() == self.axes.input_count());
        ensure!(inputs
            .iter()
            .enumerate()
            .all(|(ix, fact)| fact.rank() == self.axes.rank(InOut::In(ix))));
        let shapes: TVec<&[TDim]> = inputs.iter().map(|t| &*t.shape).collect();
        if let Some(qp) = self.q_params {
            ensure!(inputs.len() == 9);
            Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2]))))
        } else {
            Ok(tvec!(TypedFact::dt_shape(
                self.operating_dt,
                eval::output_shape(&self.axes, &shapes)
            )))
        }
    }
    fn axes_mapping(
        &self,
        _inputs: &[&TypedFact],
        _outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        Ok(self.axes.clone())
    }
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let shapes: TVec<&[TDim]> = inputs.iter().map(|t| &*t.shape).collect();
        let oshape = eval::output_shape(&self.axes, &shapes);
        let ks = self
            .axes
            .iter_all_axes()
            .filter(|axis| axis.outputs[0].len() == 0)
            .map(|axis| {
                axis.inputs
                    .iter()
                    .enumerate()
                    .flat_map(|(ix, axes)| {
                        axes.iter()
                            .map(|axis| shapes[ix][*axis].clone())
                            .collect::<TVec<_>>()
                            .into_iter()
                    })
                    .max()
                    .unwrap()
            })
            .product::<TDim>();
        Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
    }
    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        prefix: &str,
        inputs: &[OutletId],
        _output_axis: usize,
        _start: usize,
        _end: usize,
    ) -> TractResult<Option<TVec<OutletId>>> {
        patch.wire_node(prefix, self.clone(), inputs).map(Some)
    }
    #[allow(unused_variables)]
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let (mut inputs, mut outputs) = self.axes.to_strs();
        let interface: &mut String = match io {
            InOut::In(i) => &mut inputs[i],
            InOut::Out(o) => &mut outputs[o],
        };
        let mut axes: Vec<char> = interface.chars().collect();
        match change {
            AxisOp::Rm(rm) => {
                axes.remove(*rm);
            }
            AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
            AxisOp::Move(from, to) => {
                let c = axes.remove(*from);
                axes.insert(*to, c);
            }
            _ => return Ok(None),
        };
        *interface = axes.into_iter().collect();
        let axes = AxesMapping::from_strs(&inputs, &outputs)?;
        Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
            wire_changes: tvec!((io, change.clone())),
        }))
    }
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        self.declutter_after_concat(model, node)
    }
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        codegen::codegen(self, model, node)
    }
    as_op!();
}