use std::cell::UnsafeCell;
use std::rc::Rc;
use bumpalo::Bump;
use rust_decimal::Decimal;
use rust_decimal_macros::dec;
use thiserror::Error;
use crate::ast::Node;
use crate::compiler::CompilerError::{
    ArgumentNotFound, UnknownBinaryOperator, UnknownBuiltIn, UnknownUnaryOperator,
};
use crate::opcodes::{Opcode, TypeConversionKind, Variable};
type Bytecode<'a> = Rc<UnsafeCell<Vec<&'a Opcode<'a>>>>;
#[derive(Debug, Error)]
pub enum CompilerError {
    #[error("Unknown unary operator: {operator}")]
    UnknownUnaryOperator { operator: String },
    #[error("Unknown binary operator: {operator}")]
    UnknownBinaryOperator { operator: String },
    #[error("Unknown builtin: {builtin}")]
    UnknownBuiltIn { builtin: String },
    #[error("Argument not found for builtin {builtin} at index {index}")]
    ArgumentNotFound { builtin: String, index: usize },
}
pub struct Compiler<'a> {
    root: &'a Node<'a>,
    bytecode: Bytecode<'a>,
    bump: &'a Bump,
}
impl<'a> Compiler<'a> {
    pub fn new(root: &'a Node<'a>, bytecode: Bytecode<'a>, bump: &'a Bump) -> Self {
        Self {
            root,
            bytecode,
            bump,
        }
    }
    pub fn compile(&self) -> Result<(), CompilerError> {
        self.compile_node(self.root)?;
        Ok(())
    }
    fn emit(&self, op: Opcode<'a>) -> usize {
        let bc = unsafe { &mut *self.bytecode.get() };
        bc.push(self.bump.alloc(op));
        bc.len()
    }
    fn emit_loop<F>(&self, mut body: F) -> Result<(), CompilerError>
    where
        F: FnMut() -> Result<(), CompilerError>,
    {
        let begin = unsafe { (*self.bytecode.get()).len() };
        let end = self.emit(Opcode::JumpIfEnd(0));
        body()?;
        self.emit(Opcode::IncrementIt);
        let e = self.emit(Opcode::JumpBackward(self.calc_backward_jump(begin)));
        self.replace(end, Opcode::JumpIfEnd(e - end));
        Ok(())
    }
    fn emit_cond<F>(&self, mut body: F)
    where
        F: FnMut(),
    {
        let noop = self.emit(Opcode::JumpIfFalse(0));
        self.emit(Opcode::Pop);
        body();
        let jmp = self.emit(Opcode::Jump(0));
        self.replace(noop, Opcode::JumpIfFalse(jmp - noop));
        let e = self.emit(Opcode::Pop);
        self.replace(jmp, Opcode::Jump(e - jmp));
    }
    fn replace(&self, at: usize, op: Opcode<'a>) {
        let bytecode = unsafe { &mut *self.bytecode.get() };
        let _ = std::mem::replace(&mut bytecode[at - 1], self.bump.alloc(op));
    }
    fn calc_backward_jump(&self, to: usize) -> usize {
        unsafe { (*self.bytecode.get()).len() + 1 - to }
    }
    fn compile_argument(
        &self,
        name: &str,
        arguments: &&[&'a Node<'a>],
        index: usize,
    ) -> Result<usize, CompilerError> {
        let arg = arguments.get(index).ok_or_else(|| ArgumentNotFound {
            index,
            builtin: name.to_string(),
        })?;
        self.compile_node(arg)
    }
    fn compile_node(&self, node: &'a Node<'a>) -> Result<usize, CompilerError> {
        match node {
            Node::Null => Ok(self.emit(Opcode::Push(Variable::Null))),
            Node::Bool(v) => Ok(self.emit(Opcode::Push(Variable::Bool(*v)))),
            Node::Number(v) => Ok(self.emit(Opcode::Push(Variable::Number(*v)))),
            Node::String(v) => Ok(self.emit(Opcode::Push(Variable::String(v)))),
            Node::Pointer => Ok(self.emit(Opcode::Pointer)),
            Node::Array(v) => {
                v.iter()
                    .try_for_each(|&n| self.compile_node(n).map(|_| ()))?;
                self.emit(Opcode::Push(Variable::Number(Decimal::from(v.len()))));
                Ok(self.emit(Opcode::Array))
            }
            Node::Identifier(v) => Ok(self.emit(Opcode::FetchEnv(v))),
            Node::Closure(v) => self.compile_node(v),
            Node::Member { node, property } => {
                self.compile_node(node)?;
                self.compile_node(property)?;
                Ok(self.emit(Opcode::Fetch))
            }
            Node::Slice { node, to, from } => {
                self.compile_node(node)?;
                if let Some(t) = to {
                    self.compile_node(t)?;
                } else {
                    self.emit(Opcode::Len);
                    self.emit(Opcode::Push(Variable::Number(dec!(1))));
                    self.emit(Opcode::Subtract);
                }
                if let Some(f) = from {
                    self.compile_node(f)?;
                } else {
                    self.emit(Opcode::Push(Variable::Number(dec!(0))));
                }
                Ok(self.emit(Opcode::Slice))
            }
            Node::Interval {
                left,
                right,
                left_bracket,
                right_bracket,
            } => {
                self.compile_node(left)?;
                self.compile_node(right)?;
                Ok(self.emit(Opcode::Interval {
                    left_bracket,
                    right_bracket,
                }))
            }
            Node::Conditional {
                condition,
                on_true,
                on_false,
            } => {
                self.compile_node(condition)?;
                let otherwise = self.emit(Opcode::JumpIfFalse(0));
                self.emit(Opcode::Pop);
                self.compile_node(on_true)?;
                let end = self.emit(Opcode::Jump(0));
                self.replace(otherwise, Opcode::JumpIfFalse(end - otherwise));
                self.emit(Opcode::Pop);
                let b = self.compile_node(on_false)?;
                self.replace(end, Opcode::Jump(b - end));
                Ok(b)
            }
            Node::Unary { node, operator } => {
                let curr = self.compile_node(node)?;
                match *operator {
                    "+" => Ok(curr),
                    "!" | "not" => Ok(self.emit(Opcode::Not)),
                    "-" => Ok(self.emit(Opcode::Negate)),
                    _ => Err(UnknownUnaryOperator {
                        operator: operator.to_string(),
                    }),
                }
            }
            Node::Binary {
                left,
                right,
                operator,
            } => match *operator {
                "==" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    Ok(self.emit(Opcode::Equal))
                }
                "!=" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    self.emit(Opcode::Equal);
                    Ok(self.emit(Opcode::Not))
                }
                "or" => {
                    self.compile_node(left)?;
                    let end = self.emit(Opcode::JumpIfTrue(0));
                    self.emit(Opcode::Pop);
                    let r = self.compile_node(right)?;
                    self.replace(end, Opcode::JumpIfTrue(r - end));
                    Ok(r)
                }
                "and" => {
                    self.compile_node(left)?;
                    let end = self.emit(Opcode::JumpIfFalse(0));
                    self.emit(Opcode::Pop);
                    let r = self.compile_node(right)?;
                    self.replace(end, Opcode::JumpIfFalse(r - end));
                    Ok(r)
                }
                "in" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    Ok(self.emit(Opcode::In))
                }
                "not in" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    self.emit(Opcode::In);
                    Ok(self.emit(Opcode::Not))
                }
                "<" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    Ok(self.emit(Opcode::Less))
                }
                "<=" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    Ok(self.emit(Opcode::LessOrEqual))
                }
                ">" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    Ok(self.emit(Opcode::More))
                }
                ">=" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    Ok(self.emit(Opcode::MoreOrEqual))
                }
                "+" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    Ok(self.emit(Opcode::Add))
                }
                "-" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    Ok(self.emit(Opcode::Subtract))
                }
                "*" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    Ok(self.emit(Opcode::Multiply))
                }
                "/" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    Ok(self.emit(Opcode::Divide))
                }
                "%" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    Ok(self.emit(Opcode::Modulo))
                }
                "^" => {
                    self.compile_node(left)?;
                    self.compile_node(right)?;
                    Ok(self.emit(Opcode::Exponent))
                }
                _ => Err(UnknownBinaryOperator {
                    operator: operator.to_string(),
                }),
            },
            Node::BuiltIn { name, arguments } => match *name {
                "len" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.emit(Opcode::Len);
                    self.emit(Opcode::Rot);
                    Ok(self.emit(Opcode::Pop))
                }
                "date" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::ParseDateTime))
                }
                "time" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::ParseTime))
                }
                "duration" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::ParseDuration))
                }
                "startsWith" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.compile_argument(name, arguments, 1)?;
                    Ok(self.emit(Opcode::StartsWith))
                }
                "endsWith" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.compile_argument(name, arguments, 1)?;
                    Ok(self.emit(Opcode::EndsWith))
                }
                "contains" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.compile_argument(name, arguments, 1)?;
                    Ok(self.emit(Opcode::Contains))
                }
                "matches" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.compile_argument(name, arguments, 1)?;
                    Ok(self.emit(Opcode::Matches))
                }
                "extract" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.compile_argument(name, arguments, 1)?;
                    Ok(self.emit(Opcode::Extract))
                }
                "flatten" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Flatten))
                }
                "upper" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Uppercase))
                }
                "lower" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Lowercase))
                }
                "abs" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Abs))
                }
                "avg" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Average))
                }
                "median" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Median))
                }
                "mode" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Mode))
                }
                "max" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Max))
                }
                "min" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Min))
                }
                "sum" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Sum))
                }
                "floor" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Floor))
                }
                "ceil" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Ceil))
                }
                "round" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Round))
                }
                "rand" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::Random))
                }
                "string" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::TypeConversion(TypeConversionKind::String)))
                }
                "number" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::TypeConversion(TypeConversionKind::Number)))
                }
                "startOf" | "endOf" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.compile_argument(name, arguments, 1)?;
                    Ok(self.emit(Opcode::DateFunction(name)))
                }
                "dayOfWeek" | "dayOfMonth" | "dayOfYear" | "weekOfYear" | "monthOfYear"
                | "monthString" | "weekdayString" | "year" | "dateString" => {
                    self.compile_argument(name, arguments, 0)?;
                    Ok(self.emit(Opcode::DateManipulation(name)))
                }
                "all" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.emit(Opcode::Begin);
                    let mut loop_break: usize = 0;
                    self.emit_loop(|| {
                        self.compile_argument(name, arguments, 1)?;
                        loop_break = self.emit(Opcode::JumpIfFalse(0));
                        self.emit(Opcode::Pop);
                        Ok(())
                    })?;
                    let e = self.emit(Opcode::Push(Variable::Bool(true)));
                    self.replace(loop_break, Opcode::JumpIfFalse(e - loop_break));
                    Ok(self.emit(Opcode::End))
                }
                "none" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.emit(Opcode::Begin);
                    let mut loop_break: usize = 0;
                    self.emit_loop(|| {
                        self.compile_argument(name, arguments, 1)?;
                        self.emit(Opcode::Not);
                        loop_break = self.emit(Opcode::JumpIfFalse(0));
                        self.emit(Opcode::Pop);
                        Ok(())
                    })?;
                    let e = self.emit(Opcode::Push(Variable::Bool(true)));
                    self.replace(loop_break, Opcode::JumpIfFalse(e - loop_break));
                    Ok(self.emit(Opcode::End))
                }
                "some" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.emit(Opcode::Begin);
                    let mut loop_break: usize = 0;
                    self.emit_loop(|| {
                        self.compile_argument(name, arguments, 1)?;
                        loop_break = self.emit(Opcode::JumpIfTrue(0));
                        self.emit(Opcode::Pop);
                        Ok(())
                    })?;
                    let e = self.emit(Opcode::Push(Variable::Bool(false)));
                    self.replace(loop_break, Opcode::JumpIfTrue(e - loop_break));
                    Ok(self.emit(Opcode::End))
                }
                "one" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.emit(Opcode::Begin);
                    self.emit_loop(|| {
                        self.compile_argument(name, arguments, 1)?;
                        self.emit_cond(|| {
                            self.emit(Opcode::IncrementCount);
                        });
                        Ok(())
                    })?;
                    self.emit(Opcode::GetCount);
                    self.emit(Opcode::Push(Variable::Number(dec!(1))));
                    self.emit(Opcode::Equal);
                    Ok(self.emit(Opcode::End))
                }
                "filter" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.emit(Opcode::Begin);
                    self.emit_loop(|| {
                        self.compile_argument(name, arguments, 1)?;
                        self.emit_cond(|| {
                            self.emit(Opcode::IncrementCount);
                            self.emit(Opcode::Pointer);
                        });
                        Ok(())
                    })?;
                    self.emit(Opcode::GetCount);
                    self.emit(Opcode::End);
                    Ok(self.emit(Opcode::Array))
                }
                "map" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.emit(Opcode::Begin);
                    self.emit_loop(|| {
                        self.compile_argument(name, arguments, 1)?;
                        Ok(())
                    })?;
                    self.emit(Opcode::GetLen);
                    self.emit(Opcode::End);
                    Ok(self.emit(Opcode::Array))
                }
                "flatMap" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.emit(Opcode::Begin);
                    self.emit_loop(|| {
                        self.compile_argument(name, arguments, 1)?;
                        Ok(())
                    })?;
                    self.emit(Opcode::GetLen);
                    self.emit(Opcode::End);
                    self.emit(Opcode::Array);
                    Ok(self.emit(Opcode::Flatten))
                }
                "count" => {
                    self.compile_argument(name, arguments, 0)?;
                    self.emit(Opcode::Begin);
                    self.emit_loop(|| {
                        self.compile_argument(name, arguments, 1)?;
                        self.emit_cond(|| {
                            self.emit(Opcode::IncrementCount);
                        });
                        Ok(())
                    })?;
                    self.emit(Opcode::GetCount);
                    Ok(self.emit(Opcode::End))
                }
                _ => Err(UnknownBuiltIn {
                    builtin: name.to_string(),
                }),
            },
        }
    }
}