use std::convert::TryFrom;
use std::convert::TryInto;
use std::fmt::{self, Debug, Display};
use std::iter;
use tract_nnef::internal::*;
use tract_ndarray::{
    Array1, Array2, ArrayD, ArrayView1, ArrayView2, ArrayViewD, ArrayViewMut1, Axis, Ix1, Ix2,
};
use tract_num_traits::AsPrimitive;
macro_rules! ensure {
    ($cond: expr, $($rest: expr),* $(,)?) => {
        if !$cond {
            bail!($($rest),*)
        }
    }
}
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Cmp {
    Equal = 1,
    NotEqual = 2,
    Less = 3,
    Greater = 4,
    LessEqual = 5,
    GreaterEqual = 6,
}
impl Cmp {
    pub fn compare<T>(&self, x: T, y: T) -> bool
    where
        T: PartialOrd + Copy,
    {
        match self {
            Cmp::LessEqual => x <= y,
            Cmp::Less => x < y,
            Cmp::GreaterEqual => x >= y,
            Cmp::Greater => x > y,
            Cmp::Equal => x == y,
            Cmp::NotEqual => x != y,
        }
    }
    pub fn to_u8(&self) -> u8 {
        unsafe { std::mem::transmute(*self) }
    }
}
impl TryFrom<u8> for Cmp {
    type Error = TractError;
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        if value >= 1 && value <= 5 {
            unsafe { Ok(std::mem::transmute(value)) }
        } else {
            bail!("Invalid value for Cmp: {}", value);
        }
    }
}
impl Display for Cmp {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match self {
            Cmp::LessEqual => "<=",
            Cmp::Less => "<",
            Cmp::GreaterEqual => ">=",
            Cmp::Greater => ">",
            Cmp::Equal => "==",
            Cmp::NotEqual => "!=",
        })
    }
}
#[derive(Clone, Debug, Hash)]
pub struct TreeEnsembleData {
    
    pub trees: Arc<Tensor>,
    
    
    
    
    
    
    pub nodes: Arc<Tensor>,
    
    pub leaves: Arc<Tensor>,
}
impl TreeEnsembleData {
    unsafe fn get_unchecked(&self, node: usize) -> TreeNode {
        let row = &self.nodes.as_slice_unchecked::<u32>()[node * 5..][..5];
        if let Some(cmp) = ((row[4] & 0xFF) as u8).try_into().ok() {
            let feature_id = row[0];
            let true_id = row[1];
            let false_id = row[2];
            let value = f32::from_bits(row[3]);
            let nan_is_true = (row[4] & 0x0100) != 0;
            TreeNode::Branch(BranchNode { cmp, feature_id, value, true_id, false_id, nan_is_true })
        } else {
            TreeNode::Leaf(LeafNode { start_id: row[0] as usize, end_id: row[1] as usize })
        }
    }
    unsafe fn get_leaves_unchecked<T>(&self, tree: usize, input: &ArrayView1<T>) -> LeafNode
    where
        T: AsPrimitive<f32>,
    {
        let mut node_id = self.trees.as_slice_unchecked::<u32>()[tree] as usize;
        loop {
            let node = self.get_unchecked(node_id);
            match node {
                TreeNode::Branch(ref b) => {
                    let feature = *input.uget(b.feature_id as usize);
                    node_id = b.get_child_id(feature.as_());
                }
                TreeNode::Leaf(l) => return l,
            }
        }
    }
    unsafe fn eval_unchecked<A, T>(
        &self,
        tree: usize,
        input: &ArrayView1<T>,
        output: &mut ArrayViewMut1<f32>,
        aggs: &mut [A],
    ) where
        A: AggregateFn,
        T: AsPrimitive<f32>,
    {
        let leaf = self.get_leaves_unchecked(tree, input);
        for leaf in self
            .leaves
            .to_array_view_unchecked::<u32>()
            .outer_iter()
            .skip(leaf.start_id)
            .take(leaf.end_id - leaf.start_id)
        {
            let class_id = leaf[0] as usize;
            let weight = f32::from_bits(leaf[1]);
            let agg_fn = aggs.get_unchecked_mut(class_id);
            agg_fn.aggregate(weight, output.uget_mut(class_id));
        }
    }
}
#[derive(Copy, Clone)]
struct BranchNode {
    pub cmp: Cmp, 
    pub feature_id: u32,
    pub value: f32,
    pub true_id: u32,
    pub false_id: u32,
    pub nan_is_true: bool,
}
impl std::fmt::Debug for BranchNode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "if feat({}) {} {} then {} else {}",
            self.feature_id, self.cmp, self.value, self.true_id, self.false_id
        )
    }
}
impl BranchNode {
    pub fn get_child_id(&self, feature: f32) -> usize {
        let condition =
            if feature.is_nan() { self.nan_is_true } else { self.cmp.compare(feature, self.value) };
        if condition {
            self.true_id as usize
        } else {
            self.false_id as usize
        }
    }
}
#[derive(Copy, Clone, Debug, Hash)]
struct LeafNode {
    pub start_id: usize,
    pub end_id: usize,
}
#[derive(Copy, Clone, Debug)]
enum TreeNode {
    Branch(BranchNode),
    Leaf(LeafNode),
}
#[derive(Copy, Clone, Debug)]
struct Leaf {
    class_id: u32,
    weight: f32,
}
pub trait AggregateFn: Default {
    fn aggregate(&mut self, score: f32, total: &mut f32);
    fn post_aggregate(&mut self, _total: &mut f32) {}
}
#[derive(Clone, Copy, Default, Debug)]
pub struct SumFn;
impl AggregateFn for SumFn {
    fn aggregate(&mut self, score: f32, total: &mut f32) {
        *total += score;
    }
}
#[derive(Clone, Copy, Default, Debug)]
pub struct AvgFn {
    count: usize,
}
impl AggregateFn for AvgFn {
    fn aggregate(&mut self, score: f32, total: &mut f32) {
        *total += score;
        self.count += 1;
    }
    fn post_aggregate(&mut self, total: &mut f32) {
        if self.count > 1 {
            *total /= self.count as f32;
        }
        self.count = 0;
    }
}
#[derive(Clone, Copy, Default, Debug)]
pub struct MaxFn;
impl AggregateFn for MaxFn {
    fn aggregate(&mut self, score: f32, total: &mut f32) {
        *total = total.max(score);
    }
}
#[derive(Clone, Copy, Default, Debug)]
pub struct MinFn;
impl AggregateFn for MinFn {
    fn aggregate(&mut self, score: f32, total: &mut f32) {
        *total = total.min(score);
    }
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Aggregate {
    Sum,
    Avg,
    Max,
    Min,
}
impl Default for Aggregate {
    fn default() -> Self {
        Aggregate::Sum
    }
}
#[derive(Clone, Debug, Hash)]
pub struct TreeEnsemble {
    pub data: TreeEnsembleData,
    pub n_features: usize,
    pub n_classes: usize,
    pub aggregate_fn: Aggregate, 
}
impl TreeEnsemble {
    pub fn build(
        data: TreeEnsembleData,
        n_features: usize,
        n_classes: usize,
        aggregate_fn: Aggregate,
    ) -> TractResult<Self> {
        Ok(Self { data, n_features, n_classes, aggregate_fn })
    }
    pub fn n_classes(&self) -> usize {
        self.n_classes
    }
    unsafe fn eval_one_unchecked<A, T>(
        &self,
        input: &ArrayView1<T>,
        output: &mut ArrayViewMut1<f32>,
        aggs: &mut [A],
    ) where
        A: AggregateFn,
        T: AsPrimitive<f32>,
    {
        for t in 0..self.data.trees.len() {
            self.data.eval_unchecked(t, input, output, aggs)
        }
        for i in 0..self.n_classes {
            aggs.get_unchecked_mut(i).post_aggregate(output.uget_mut(i));
        }
    }
    pub fn check_n_features(&self, n_features: usize) -> TractResult<()> {
        Ok(ensure!(
            n_features == self.n_features,
            "Invalid input shape: got {} features, expected {}",
            n_features,
            self.n_features
        ))
    }
    fn eval_2d<'i, A, T>(&self, input: &ArrayView2<T>) -> TractResult<Array2<f32>>
    where
        A: AggregateFn,
        T: AsPrimitive<f32>,
    {
        self.check_n_features(input.shape()[1])?;
        let n = input.shape()[0];
        let mut output = Array2::zeros((n, self.n_classes));
        let mut aggs: tract_smallvec::SmallVec<[A; 16]> =
            iter::repeat_with(Default::default).take(self.n_classes).collect();
        for i in 0..n {
            unsafe {
                self.eval_one_unchecked::<A, T>(
                    &input.index_axis(Axis(0), i),
                    &mut output.index_axis_mut(Axis(0), i),
                    &mut aggs,
                );
            }
        }
        Ok(output)
    }
    fn eval_1d<'i, A, T>(&self, input: &ArrayView1<T>) -> TractResult<Array1<f32>>
    where
        A: AggregateFn,
        T: AsPrimitive<f32>,
    {
        self.check_n_features(input.len())?;
        let mut output = Array1::zeros(self.n_classes);
        let mut aggs: tract_smallvec::SmallVec<[A; 16]> =
            iter::repeat_with(Default::default).take(self.n_classes).collect();
        unsafe {
            self.eval_one_unchecked::<A, T>(input, &mut output.view_mut(), &mut aggs);
        }
        Ok(output)
    }
    pub fn eval<'i, I, T>(&self, input: I) -> TractResult<ArrayD<f32>>
    where
        I: Into<ArrayViewD<'i, T>>, 
        T: AsPrimitive<f32>,
    {
        let input = input.into();
        if let Ok(input) = input.view().into_dimensionality::<Ix1>() {
            Ok(match self.aggregate_fn {
                Aggregate::Sum => self.eval_1d::<SumFn, T>(&input),
                Aggregate::Avg => self.eval_1d::<AvgFn, T>(&input),
                Aggregate::Min => self.eval_1d::<MinFn, T>(&input),
                Aggregate::Max => self.eval_1d::<MaxFn, T>(&input),
            }?
            .into_dyn())
        } else if let Ok(input) = input.view().into_dimensionality::<Ix2>() {
            Ok(match self.aggregate_fn {
                Aggregate::Sum => self.eval_2d::<SumFn, T>(&input),
                Aggregate::Avg => self.eval_2d::<AvgFn, T>(&input),
                Aggregate::Min => self.eval_2d::<MinFn, T>(&input),
                Aggregate::Max => self.eval_2d::<MaxFn, T>(&input),
            }?
            .into_dyn())
        } else {
            bail!("Invalid input dimensionality for tree ensemble: {:?}", input.shape());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tract_ndarray::prelude::*;
    fn b(
        node_offset: usize,
        cmp: Cmp,
        feat: usize,
        v: f32,
        left: usize,
        right: usize,
        nan_is_true: bool,
    ) -> [u32; 5] {
        [
            feat as u32,
            (node_offset + left) as u32,
            (node_offset + right) as u32,
            v.to_bits(),
            cmp as u32 | if nan_is_true { 0x100 } else { 0 },
        ]
    }
    fn l(leaf_offset: usize, start_id: usize, end_id: usize) -> [u32; 5] {
        [(leaf_offset + start_id) as u32, (leaf_offset + end_id) as u32, 0, 0, 0]
    }
    fn w(categ: usize, weight: f32) -> [u32; 2] {
        [categ as u32, weight.to_bits()]
    }
    fn generate_gbm_trees() -> TreeEnsembleData {
        let trees = rctensor1(&[0u32, 5u32, 14, 21, 30, 41]);
        let nodes = rctensor2(&[
            b(0, Cmp::LessEqual, 2, 3.15, 1, 2, true),
            b(0, Cmp::LessEqual, 1, 3.35, 3, 4, true),
            l(0, 0, 1),
            l(0, 1, 2),
            l(0, 2, 3),
            
            b(5, Cmp::LessEqual, 2, 1.8, 1, 2, true),
            l(3, 0, 1),
            b(5, Cmp::LessEqual, 3, 1.65, 3, 4, true),
            b(5, Cmp::LessEqual, 2, 4.45, 5, 6, true),
            b(5, Cmp::LessEqual, 2, 5.35, 7, 8, true),
            l(3, 1, 2),
            l(3, 2, 3),
            l(3, 3, 4),
            l(3, 4, 5),
            
            b(14, Cmp::LessEqual, 3, 1.65, 1, 2, true),
            b(14, Cmp::LessEqual, 2, 4.45, 3, 4, true),
            b(14, Cmp::LessEqual, 2, 5.35, 5, 6, true),
            l(8, 0, 1),
            l(8, 1, 2),
            l(8, 2, 3),
            l(8, 3, 4),
            
            b(21, Cmp::LessEqual, 2, 3.15, 1, 2, true),
            b(21, Cmp::LessEqual, 1, 3.35, 3, 4, true),
            b(21, Cmp::LessEqual, 2, 4.45, 5, 6, true),
            l(12, 0, 1),
            l(12, 1, 2),
            l(12, 2, 3),
            b(21, Cmp::LessEqual, 2, 5.35, 7, 8, true),
            l(12, 3, 4),
            l(12, 4, 5),
            
            b(30, Cmp::LessEqual, 3, 0.45, 1, 2, true),
            b(30, Cmp::LessEqual, 2, 1.45, 3, 4, true),
            b(30, Cmp::LessEqual, 3, 1.65, 5, 6, true),
            l(17, 0, 1),
            l(17, 1, 2),
            b(30, Cmp::LessEqual, 2, 4.45, 7, 8, true),
            b(30, Cmp::LessEqual, 2, 5.35, 9, 10, true),
            l(17, 2, 3),
            l(17, 3, 4),
            l(17, 4, 5),
            l(17, 5, 6),
            
            b(41, Cmp::LessEqual, 2, 4.75, 1, 2, true),
            b(41, Cmp::LessEqual, 1, 2.75, 3, 4, true),
            b(41, Cmp::LessEqual, 2, 5.15, 7, 8, true),
            l(23, 0, 1),
            b(41, Cmp::LessEqual, 2, 4.15, 5, 6, true),
            l(23, 1, 2),
            l(23, 2, 3),
            l(23, 3, 4),
            l(23, 4, 5),
        ]);
        assert_eq!(nodes.shape(), &[50, 5]);
        let leaves = rctensor2(&[
            w(0, -0.075),
            w(0, 0.13928571),
            w(0, 0.15),
            
            w(1, -0.075),
            w(1, 0.13548388),
            w(1, 0.110869564),
            w(1, -0.052500002),
            w(1, -0.075),
            
            w(2, -0.075),
            w(2, -0.035869565),
            w(2, 0.1275),
            w(2, 0.15),
            
            w(0, 0.12105576),
            w(0, 0.1304589),
            w(0, -0.07237862),
            w(0, -0.07226522),
            w(0, -0.07220469),
            
            w(1, -0.07226842),
            w(1, -0.07268012),
            w(1, 0.119391434),
            w(1, 0.097440675),
            w(1, -0.049815115),
            w(1, -0.07219931),
            
            w(2, -0.061642267),
            w(2, -0.0721846),
            w(2, -0.07319043),
            w(2, 0.076814815),
            w(2, 0.1315959),
        ]);
        assert_eq!(leaves.shape(), &[28, 2]);
        TreeEnsembleData { nodes, trees, leaves }
    }
    fn generate_gbm_ensemble() -> TreeEnsemble {
        
        let trees = generate_gbm_trees();
        TreeEnsemble::build(trees, 4, 3, Aggregate::Sum).unwrap()
    }
    fn generate_gbm_input() -> Array2<f32> {
        arr2(&[
            [5.1, 3.5, 1.4, 0.2],
            [5.4, 3.7, 1.5, 0.2],
            [5.4, 3.4, 1.7, 0.2],
            [4.8, 3.1, 1.6, 0.2],
            [5.0, 3.5, 1.3, 0.3],
            [7.0, 3.2, 4.7, 1.4],
            [5.0, 2.0, 3.5, 1.0],
            [5.9, 3.2, 4.8, 1.8],
            [5.5, 2.4, 3.8, 1.1],
            [5.5, 2.6, 4.4, 1.2],
            [6.3, 3.3, 6.0, 2.5],
            [6.5, 3.2, 5.1, 2.0],
            [6.9, 3.2, 5.7, 2.3],
            [7.4, 2.8, 6.1, 1.9],
            [6.7, 3.1, 5.6, 2.4],
        ])
    }
    fn generate_gbm_raw_output() -> Array2<f32> {
        arr2(&[
            [0.28045893, -0.14726841, -0.14718461],
            [0.28045893, -0.14768013, -0.14718461],
            [0.28045893, -0.14768013, -0.14718461],
            [0.26034147, -0.14768013, -0.14718461],
            [0.28045893, -0.14726841, -0.14718461],
            [-0.14726523, 0.20831025, -0.10905999],
            [-0.14737862, 0.25487530, -0.13664228],
            [-0.14726523, -0.10231511, 0.20431481],
            [-0.14737862, 0.25487530, -0.13664228],
            [-0.14737862, 0.25487530, -0.13664228],
            [-0.14720470, -0.14719930, 0.28159590],
            [-0.14726523, -0.10231511, 0.20431481],
            [-0.14720470, -0.14719930, 0.28159590],
            [-0.14720470, -0.14719930, 0.28159590],
            [-0.14720470, -0.14719930, 0.28159590],
        ])
    }
    #[test]
    fn test_tree_ensemble() {
        let ensemble = generate_gbm_ensemble();
        let input = generate_gbm_input();
        let output = ensemble.eval(&input.view().into_dyn()).unwrap();
        assert_eq!(output, generate_gbm_raw_output().into_dyn());
    }
}