use crate::{
    collections::{BTreeMap, VecDeque},
    error::{Error, Result},
    merge::{hash_leaf, merge},
    traits::Hasher,
    vec::Vec,
    H256, TREE_HEIGHT,
};
/// Half-open range of leaf indices (positions in the key-sorted leaf list) that a
/// compiled sub-program consumes.
type Range = core::ops::Range<usize>;
/// A Merkle proof for one or more leaves, in its uncompiled form.
///
/// It can be checked directly with [`MerkleProof::verify`] or turned into a compact
/// bytecode program with [`MerkleProof::compile`].
#[derive(Debug, Clone)]
pub struct MerkleProof {
    // leaves_path[i]: heights (bottom-up) at which the i-th leaf, in ascending key
    // order, is merged with a sibling node.
    leaves_path: Vec<Vec<u8>>,
    // (sibling_hash, height) pairs, consumed front to back while rebuilding the tree.
    proof: Vec<(H256, u8)>,
}
impl MerkleProof {
    
    
    
    pub fn new(leaves_path: Vec<Vec<u8>>, proof: Vec<(H256, u8)>) -> Self {
        MerkleProof { leaves_path, proof }
    }
    
    pub fn take(self) -> (Vec<Vec<u8>>, Vec<(H256, u8)>) {
        let MerkleProof { leaves_path, proof } = self;
        (leaves_path, proof)
    }
    
    pub fn leaves_count(&self) -> usize {
        self.leaves_path.len()
    }
    
    pub fn leaves_path(&self) -> &Vec<Vec<u8>> {
        &self.leaves_path
    }
    
    pub fn proof(&self) -> &Vec<(H256, u8)> {
        &self.proof
    }
    pub fn compile(self, mut leaves: Vec<(H256, H256)>) -> Result<CompiledMerkleProof> {
        if leaves.is_empty() {
            return Err(Error::EmptyKeys);
        } else if leaves.len() != self.leaves_count() {
            return Err(Error::IncorrectNumberOfLeaves {
                expected: self.leaves_count(),
                actual: leaves.len(),
            });
        }
        let (leaves_path, proof) = self.take();
        let mut leaves_path: Vec<VecDeque<_>> = leaves_path.into_iter().map(Into::into).collect();
        let mut proof: VecDeque<_> = proof.into();
        
        leaves.sort_unstable_by_key(|(k, _v)| *k);
        
        let mut tree_buf: BTreeMap<_, _> = leaves
            .into_iter()
            .enumerate()
            .map(|(i, (k, _v))| ((0, k), (i, leaf_program(i))))
            .collect();
        
        while !tree_buf.is_empty() {
            
            let &(mut height, key) = tree_buf.keys().next().unwrap();
            let (leaf_index, program) = tree_buf.remove(&(height, key)).unwrap();
            if proof.is_empty() && tree_buf.is_empty() {
                return Ok(CompiledMerkleProof(program.0));
            } else if height == TREE_HEIGHT {
                if !proof.is_empty() {
                    return Err(Error::CorruptedProof);
                }
                return Ok(CompiledMerkleProof(program.0));
            }
            let mut sibling_key = key.parent_path(height as u8);
            if !key.get_bit(height as u8) {
                sibling_key.set_bit(height as u8)
            }
            let (parent_key, parent_program, height) =
                if Some(&(height, sibling_key)) == tree_buf.keys().next() {
                    let (_leaf_index, sibling_program) = tree_buf
                        .remove(&(height, sibling_key))
                        .expect("pop sibling");
                    let parent_key = key.parent_path(height as u8);
                    let parent_program = merge_program(&program, &sibling_program, height as u8)?;
                    (parent_key, parent_program, height)
                } else {
                    let merge_height = leaves_path[leaf_index]
                        .front()
                        .map(|h| *h as usize)
                        .unwrap_or(height);
                    if height != merge_height {
                        debug_assert!(height < merge_height);
                        let parent_key = key.copy_bits(merge_height as u8..);
                        
                        tree_buf.insert((merge_height, parent_key), (leaf_index, program));
                        continue;
                    }
                    let (proof, proof_height) = proof.pop_front().expect("pop proof");
                    debug_assert_eq!(proof_height, leaves_path[leaf_index][0]);
                    let proof_height = proof_height as usize;
                    debug_assert!(height <= proof_height);
                    if height < proof_height {
                        height = proof_height;
                    }
                    let parent_key = key.parent_path(height as u8);
                    let parent_program = proof_program(&program, proof, height as u8);
                    (parent_key, parent_program, height)
                };
            leaves_path[leaf_index].pop_front();
            tree_buf.insert((height + 1, parent_key), (leaf_index, parent_program));
        }
        Err(Error::CorruptedProof)
    }
    
    
    
    
    
    pub fn compute_root<H: Hasher + Default>(self, mut leaves: Vec<(H256, H256)>) -> Result<H256> {
        if leaves.is_empty() {
            return Err(Error::EmptyKeys);
        } else if leaves.len() != self.leaves_count() {
            return Err(Error::IncorrectNumberOfLeaves {
                expected: self.leaves_count(),
                actual: leaves.len(),
            });
        }
        let (leaves_path, proof) = self.take();
        let mut leaves_path: Vec<VecDeque<_>> = leaves_path.into_iter().map(Into::into).collect();
        let mut proof: VecDeque<_> = proof.into();
        
        leaves.sort_unstable_by_key(|(k, _v)| *k);
        
        let mut tree_buf: BTreeMap<_, _> = leaves
            .into_iter()
            .enumerate()
            .map(|(i, (k, v))| ((0, k), (i, hash_leaf::<H>(&k, &v))))
            .collect();
        
        while !tree_buf.is_empty() {
            
            let (&(mut height, key), &(leaf_index, node)) = tree_buf.iter().next().unwrap();
            tree_buf.remove(&(height, key));
            if proof.is_empty() && tree_buf.is_empty() {
                return Ok(node);
            } else if height == TREE_HEIGHT {
                if !proof.is_empty() {
                    return Err(Error::CorruptedProof);
                }
                return Ok(node);
            }
            let mut sibling_key = key.parent_path(height as u8);
            if !key.get_bit(height as u8) {
                sibling_key.set_bit(height as u8)
            }
            let (sibling, sibling_height) =
                if Some(&(height, sibling_key)) == tree_buf.keys().next() {
                    let (_leaf_index, sibling) = tree_buf
                        .remove(&(height, sibling_key))
                        .expect("pop sibling");
                    (sibling, height)
                } else {
                    let merge_height = leaves_path[leaf_index]
                        .front()
                        .map(|h| *h as usize)
                        .unwrap_or(height);
                    if height != merge_height {
                        debug_assert!(height < merge_height);
                        let parent_key = key.copy_bits(merge_height as u8..);
                        
                        tree_buf.insert((merge_height, parent_key), (leaf_index, node));
                        continue;
                    }
                    let (node, height) = proof.pop_front().expect("pop proof");
                    debug_assert_eq!(height, leaves_path[leaf_index][0]);
                    (node, height as usize)
                };
            debug_assert!(height <= sibling_height);
            if height < sibling_height {
                height = sibling_height;
            }
            
            let parent_key = key.parent_path(height as u8);
            let parent = if key.get_bit(height as u8) {
                merge::<H>(&sibling, &node)
            } else {
                merge::<H>(&node, &sibling)
            };
            leaves_path[leaf_index].pop_front();
            tree_buf.insert((height + 1, parent_key), (leaf_index, parent));
        }
        Err(Error::CorruptedProof)
    }
    
    
    pub fn verify<H: Hasher + Default>(
        self,
        root: &H256,
        leaves: Vec<(H256, H256)>,
    ) -> Result<bool> {
        let calculated_root = self.compute_root::<H>(leaves)?;
        Ok(&calculated_root == root)
    }
}
fn leaf_program(leaf_index: usize) -> (Vec<u8>, Option<Range>) {
    let mut program = Vec::with_capacity(1);
    program.push(0x4C);
    (
        program,
        Some(Range {
            start: leaf_index,
            end: leaf_index + 1,
        }),
    )
}
/// Extend `child` with a merge against a proof hash at `height`.
///
/// Emits `<child program> 0x50 <height> <32-byte proof hash>`; the covered leaf
/// range is inherited unchanged from the child.
fn proof_program(
    child: &(Vec<u8>, Option<Range>),
    proof: H256,
    height: u8,
) -> (Vec<u8>, Option<Range>) {
    let (code, covered) = child;
    let mut out = Vec::with_capacity(code.len() + 34);
    out.extend_from_slice(code);
    out.push(0x50);
    out.push(height);
    out.extend_from_slice(proof.as_slice());
    (out, covered.clone())
}
fn merge_program(
    a: &(Vec<u8>, Option<Range>),
    b: &(Vec<u8>, Option<Range>),
    height: u8,
) -> Result<(Vec<u8>, Option<Range>)> {
    let (a_program, a_range) = a;
    let (b_program, b_range) = b;
    let (a_comes_first, range) = if a_range.is_none() || b_range.is_none() {
        let range = if a_range.is_none() { b_range } else { a_range }
            .clone()
            .unwrap();
        (true, range)
    } else {
        let a_range = a_range.clone().unwrap();
        let b_range = b_range.clone().unwrap();
        if a_range.end == b_range.start {
            (
                true,
                Range {
                    start: a_range.start,
                    end: b_range.end,
                },
            )
        } else {
            return Err(Error::NonMergableRange);
        }
    };
    let mut program = Vec::new();
    program.resize(2 + a_program.len() + b_program.len(), 0x48);
    if a_comes_first {
        program[..a_program.len()].copy_from_slice(a_program);
        program[a_program.len()..a_program.len() + b_program.len()].copy_from_slice(b_program);
    } else {
        program[..b_program.len()].copy_from_slice(b_program);
        program[b_program.len()..a_program.len() + b_program.len()].copy_from_slice(a_program);
    }
    program[a_program.len() + b_program.len() + 1] = height;
    Ok((program, Some(range)))
}
/// A Merkle proof compiled into a bytecode program (produced by
/// [`MerkleProof::compile`]).
///
/// The wrapped bytes are the raw program run by [`CompiledMerkleProof::compute_root`].
#[derive(Debug, Clone)]
pub struct CompiledMerkleProof(pub Vec<u8>);
impl CompiledMerkleProof {
    pub fn compute_root<H: Hasher + Default>(&self, mut leaves: Vec<(H256, H256)>) -> Result<H256> {
        leaves.sort_unstable_by_key(|(k, _v)| *k);
        let mut program_index = 0;
        let mut leave_index = 0;
        let mut stack = Vec::new();
        while program_index < self.0.len() {
            let code = self.0[program_index];
            program_index += 1;
            match code {
                
                0x4C => {
                    if leave_index >= leaves.len() {
                        return Err(Error::CorruptedStack);
                    }
                    let (k, v) = leaves[leave_index];
                    stack.push((k, hash_leaf::<H>(&k, &v)));
                    leave_index += 1;
                }
                
                0x50 => {
                    if stack.is_empty() {
                        return Err(Error::CorruptedStack);
                    }
                    if program_index + 33 > self.0.len() {
                        return Err(Error::CorruptedProof);
                    }
                    let height = self.0[program_index];
                    program_index += 1;
                    let mut data = [0u8; 32];
                    data.copy_from_slice(&self.0[program_index..program_index + 32]);
                    program_index += 32;
                    let proof = H256::from(data);
                    let (key, value) = stack.pop().unwrap();
                    let parent_key = key.parent_path(height);
                    let parent = if key.get_bit(height) {
                        merge::<H>(&proof, &value)
                    } else {
                        merge::<H>(&value, &proof)
                    };
                    stack.push((parent_key, parent));
                }
                
                0x48 => {
                    if stack.len() < 2 {
                        return Err(Error::CorruptedStack);
                    }
                    if program_index >= self.0.len() {
                        return Err(Error::CorruptedProof);
                    }
                    let height = self.0[program_index];
                    program_index += 1;
                    let (key_b, value_b) = stack.pop().unwrap();
                    let (key_a, value_a) = stack.pop().unwrap();
                    let parent_key_a = key_a.copy_bits(height..);
                    let parent_key_b = key_b.copy_bits(height..);
                    let a_set = key_a.get_bit(height);
                    let b_set = key_b.get_bit(height);
                    let mut sibling_key_a = parent_key_a;
                    if !a_set {
                        sibling_key_a.set_bit(height);
                    }
                    
                    if !(sibling_key_a == parent_key_b && (a_set ^ b_set)) {
                        return Err(Error::NonSiblings);
                    }
                    let parent = if key_a.get_bit(height) {
                        merge::<H>(&value_b, &value_a)
                    } else {
                        merge::<H>(&value_a, &value_b)
                    };
                    stack.push((parent_key_a, parent));
                }
                _ => return Err(Error::InvalidCode(code)),
            }
        }
        if stack.len() != 1 {
            return Err(Error::CorruptedStack);
        }
        Ok(stack[0].1)
    }
    pub fn verify<H: Hasher + Default>(
        &self,
        root: &H256,
        leaves: Vec<(H256, H256)>,
    ) -> Result<bool> {
        let calculated_root = self.compute_root::<H>(leaves)?;
        Ok(&calculated_root == root)
    }
}