use starknet_crypto_codegen::poseidon_consts;
use starknet_ff::FieldElement;
poseidon_consts!();
/// Incremental Poseidon hasher that absorbs [`FieldElement`] messages one at a time.
///
/// Messages are consumed in pairs: the first message of a pair is parked in `buffer`
/// and absorbed into `state` together with its partner on the following `update` call.
#[derive(Debug, Default)]
pub struct PoseidonHasher {
    // Three-element permutation state; initialised through the derived `Default`.
    state: [FieldElement; 3],
    // First message of a pair that has not been absorbed yet, if any.
    buffer: Option<FieldElement>,
}
impl PoseidonHasher {
    /// Creates a hasher in its default state with no pending message.
    pub fn new() -> Self {
        Default::default()
    }

    /// Feeds one message into the hasher.
    ///
    /// The first message of every pair is only buffered. When the second one arrives,
    /// both are added to the first two state slots and the permutation is applied.
    pub fn update(&mut self, msg: FieldElement) {
        if let Some(pending) = self.buffer.take() {
            self.state[0] += pending;
            self.state[1] += msg;
            poseidon_permute_comp(&mut self.state);
        } else {
            self.buffer = Some(msg);
        }
    }

    /// Consumes the hasher and returns the digest.
    ///
    /// A still-buffered message is added to slot 0 first. A padding `ONE` is then added
    /// to the next free slot (slot 1 if a message was pending, slot 0 otherwise) before
    /// the final permutation. The digest is the first state element.
    pub fn finalize(mut self) -> FieldElement {
        // Index of the slot that receives the padding element.
        let pad_slot = match self.buffer.take() {
            Some(tail) => {
                self.state[0] += tail;
                1
            }
            None => 0,
        };
        self.state[pad_slot] += FieldElement::ONE;

        poseidon_permute_comp(&mut self.state);
        let [digest, _, _] = self.state;
        digest
    }
}
/// Computes the Poseidon hash of two field elements.
///
/// The permutation is applied once to the state `[x, y, 2]` and the first element of
/// the permuted state is returned.
pub fn poseidon_hash(x: FieldElement, y: FieldElement) -> FieldElement {
    let mut sponge = [x, y, FieldElement::TWO];
    poseidon_permute_comp(&mut sponge);

    let [digest, _, _] = sponge;
    digest
}
/// Computes the Poseidon hash of a single field element.
///
/// The permutation is applied once to the state `[x, 0, 1]` and the first element of
/// the permuted state is returned.
pub fn poseidon_hash_single(x: FieldElement) -> FieldElement {
    let mut sponge = [x, FieldElement::ZERO, FieldElement::ONE];
    poseidon_permute_comp(&mut sponge);

    let [digest, _, _] = sponge;
    digest
}
/// Computes the Poseidon hash of an arbitrary-length slice of field elements.
///
/// Starting from an all-zero state, messages are absorbed two at a time into slots 0
/// and 1, with one permutation per pair. A leftover message (odd-length input) is added
/// to slot 0, then a padding `ONE` is added to the next free slot (slot 1 if there was
/// a leftover, slot 0 otherwise) and a final permutation is applied. The digest is the
/// first state element.
pub fn poseidon_hash_many(msgs: &[FieldElement]) -> FieldElement {
    let mut sponge = [FieldElement::ZERO, FieldElement::ZERO, FieldElement::ZERO];

    // Peel complete pairs off the front of the input.
    let mut rest = msgs;
    while let [first, second, tail @ ..] = rest {
        sponge[0] += *first;
        sponge[1] += *second;
        poseidon_permute_comp(&mut sponge);
        rest = tail;
    }

    // `rest` now holds zero or one message; pad right after whatever is left.
    match rest {
        [last] => {
            sponge[0] += *last;
            sponge[1] += FieldElement::ONE;
        }
        _ => {
            sponge[0] += FieldElement::ONE;
        }
    }
    poseidon_permute_comp(&mut sponge);

    sponge[0]
}
/// Applies the Poseidon permutation to `state` in place.
///
/// The schedule is `FULL_ROUNDS / 2` full rounds, then `PARTIAL_ROUNDS` partial rounds,
/// then another `FULL_ROUNDS / 2` full rounds. Round constants are read sequentially
/// from the constants table: `round_comp` consumes three entries per full round and a
/// single entry per partial round, so the offset advances accordingly.
pub fn poseidon_permute_comp(state: &mut [FieldElement; 3]) {
    let half_full = FULL_ROUNDS / 2;

    // Offset of the next unread round constant.
    let mut offset = 0;
    let mut run_round = |full: bool| {
        round_comp(state, offset, full);
        offset += if full { 3 } else { 1 };
    };

    for _ in 0..half_full {
        run_round(true);
    }
    for _ in 0..PARTIAL_ROUNDS {
        run_round(false);
    }
    for _ in 0..half_full {
        run_round(true);
    }
}
/// Linear mixing layer of a round.
///
/// With `sum = a + b + c`, the state `[a, b, c]` becomes
/// `[sum + 2a, sum - 2b, sum - 3c]`, i.e. `[3a + b + c, a - b + c, a + b - 2c]`.
#[inline(always)]
fn mix(state: &mut [FieldElement; 3]) {
    let [a, b, c] = *state;
    let sum = a + b + c;

    *state = [
        sum + FieldElement::TWO * a,
        sum - FieldElement::TWO * b,
        sum - FieldElement::THREE * c,
    ];
}
/// Runs a single round on `state`, reading round constants starting at `idx`.
///
/// A full round adds the constants at `idx`, `idx + 1` and `idx + 2` to the three state
/// elements and cubes each of them. A partial round adds the constant at `idx` to the
/// last element only and cubes just that one. Both variants end with the `mix` layer.
#[inline]
fn round_comp(state: &mut [FieldElement; 3], idx: usize, full: bool) {
    let cube = |v: FieldElement| v * v * v;

    if full {
        // Each lane only depends on its own value and its own constant.
        for (lane, cell) in state.iter_mut().enumerate() {
            *cell = cube(*cell + POSEIDON_COMP_CONSTS[idx + lane]);
        }
    } else {
        state[2] = cube(state[2] + POSEIDON_COMP_CONSTS[idx]);
    }

    mix(state);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a big-endian hex literal into a field element, panicking on invalid input.
    fn felt(hex: &str) -> FieldElement {
        FieldElement::from_hex_be(hex).unwrap()
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_poseidon_hash() {
        // Each case is (x, y, expected hash).
        let cases = [
            (
                felt("0xb662f9017fa7956fd70e26129b1833e10ad000fd37b4d9f4e0ce6884b7bbe"),
                felt("0x1fe356bf76102cdae1bfbdc173602ead228b12904c00dad9cf16e035468bea"),
                felt("0x75540825a6ecc5dc7d7c2f5f868164182742227f1367d66c43ee51ec7937a81"),
            ),
            (
                felt("0xf4e01b2032298f86b539e3d3ac05ced20d2ef275273f9325f8827717156529"),
                felt("0x587bc46f5f58e0511b93c31134652a689d761a9e7f234f0f130c52e4679f3a"),
                felt("0xbdb3180fdcfd6d6f172beb401af54dd71b6569e6061767234db2b777adf98b"),
            ),
        ];

        for (x, y, expected) in cases {
            assert_eq!(poseidon_hash(x, y), expected);
        }
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_poseidon_hash_single() {
        // Each case is (x, expected hash).
        let cases = [
            (
                felt("0x9dad5d6f502ccbcb6d34ede04f0337df3b98936aaf782f4cc07d147e3a4fd6"),
                felt("0x11222854783f17f1c580ff64671bc3868de034c236f956216e8ed4ab7533455"),
            ),
            (
                felt("0x3164a8e2181ff7b83391b4a86bc8967f145c38f10f35fc74e9359a0c78f7b6"),
                felt("0x79ad7aa7b98d47705446fa01865942119026ac748d67a5840f06948bce2306b"),
            ),
        ];

        for (x, expected) in cases {
            assert_eq!(poseidon_hash_single(x), expected);
        }
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_poseidon_hash_many() {
        // Each case is (messages, expected hash); one odd-length and one even-length input.
        let cases = [
            (
                vec![
                    felt("0x9bf52404586087391c5fbb42538692e7ca2149bac13c145ae4230a51a6fc47"),
                    felt("0x40304159ee9d2d611120fbd7c7fb8020cc8f7a599bfa108e0e085222b862c0"),
                    felt("0x46286e4f3c450761d960d6a151a9c0988f9e16f8a48d4c0a85817c009f806a"),
                ],
                felt("0x1ec38b38dc88bac7b0ed6ff6326f975a06a59ac601b417745fd412a5d38e4f7"),
            ),
            (
                vec![
                    felt("0xbdace8883922662601b2fd197bb660b081fcf383ede60725bd080d4b5f2fd3"),
                    felt("0x1eb1daaf3fdad326b959dec70ced23649cdf8786537cee0c5758a1a4229097"),
                    felt("0x869ca04071b779d6f940cdf33e62d51521e19223ab148ef571856ff3a44ff1"),
                    felt("0x533e6df8d7c4b634b1f27035c8676a7439c635e1fea356484de7f0de677930"),
                ],
                felt("0x2520b8f910174c3e650725baacad4efafaae7623c69a0b5513d75e500f36624"),
            ),
        ];

        for (input, expected) in cases {
            // One-shot API.
            assert_eq!(poseidon_hash_many(&input), expected);

            // The streaming API must produce the same digest for the same messages.
            let mut hasher = PoseidonHasher::new();
            for msg in &input {
                hasher.update(*msg);
            }
            assert_eq!(hasher.finalize(), expected);
        }
    }
}