use fallible_iterator::FallibleIterator;
use memchr::memchr;
pub use crate::dialect::TokenType;
use crate::dialect::TokenType::*;
use crate::dialect::{
    is_identifier_continue, is_identifier_start, keyword_token, sentinel, MAX_KEYWORD_LEN,
};
use crate::parser::ast::Cmd;
use crate::parser::parse::{yyParser, YYCODETYPE};
use crate::parser::Context;
mod error;
#[cfg(test)]
mod test;
use crate::lexer::scan::ScanError;
use crate::lexer::scan::Splitter;
use crate::lexer::Scanner;
pub use crate::parser::ParserError;
pub use error::Error;
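/// SQL parser: reads tokens from a [`Scanner`] driven by the [`Tokenizer`]
/// and feeds them to the generated `yyParser`.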
pub struct Parser<'input> {
    input: &'input [u8],
    scanner: Scanner<Tokenizer>,
    parser: yyParser<'input>,
}
impl<'input> Parser<'input> {
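    /// Creates a parser for the given `input` bytes.
    ///
    /// A minimal usage sketch (marked `ignore` since the exact import path
    /// depends on how this crate is used):
    ///
    /// ```ignore
    /// use fallible_iterator::FallibleIterator;
    ///
    /// let mut parser = Parser::new(b"SELECT 1; SELECT 2;");
    /// while let Some(cmd) = parser.next()? {
    ///     // each `cmd` is one parsed SQL statement
    /// }
    /// ```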
    pub fn new(input: &'input [u8]) -> Parser<'input> {
        let lexer = Tokenizer::new();
        let scanner = Scanner::new(lexer);
        let ctx = Context::new(input);
        let parser = yyParser::new(ctx);
        Parser {
            input,
            scanner,
            parser,
        }
    }
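    /// Resets the parser to read a new `input` (rewinds the scanner; the
    /// parse context itself is reset at the start of each `next` call).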
    pub fn reset(&mut self, input: &'input [u8]) {
        self.input = input;
        self.scanner.reset();
    }
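    /// Current line position of the scanner in the input.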
    pub fn line(&self) -> u64 {
        self.scanner.line()
    }
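    /// Current column position of the scanner in the input.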
    pub fn column(&self) -> usize {
        self.scanner.column()
    }
}
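/// Reads the next token for the lookahead helpers below, folding every token
/// that may also act as an identifier into `TK_ID` (the same folding SQLite's
/// `getToken` performs).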
fn get_token(scanner: &mut Scanner<Tokenizer>, input: &[u8]) -> Result<TokenType, Error> {
    let mut t = match scanner.scan(input)? {
        (_, None, _) => return Ok(TK_EOF),
        (_, Some((_, token_type)), _) => token_type,
    };
    if t == TK_ID
        || t == TK_STRING
        || t == TK_JOIN_KW
        || t == TK_WINDOW
        || t == TK_OVER
        || yyParser::parse_fallback(t as YYCODETYPE) == TK_ID as YYCODETYPE
    {
        t = TK_ID;
    }
    Ok(t)
}
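/// `WINDOW` is a keyword only when followed by an identifier and then `AS`
/// (as in `WINDOW win AS (...)`); otherwise it is treated as an identifier.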
fn analyze_window_keyword(
    scanner: &mut Scanner<Tokenizer>,
    input: &[u8],
) -> Result<TokenType, Error> {
    let t = get_token(scanner, input)?;
    if t != TK_ID {
        return Ok(TK_ID);
    }
    let t = get_token(scanner, input)?;
    if t != TK_AS {
        return Ok(TK_ID);
    }
    Ok(TK_WINDOW)
}
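/// `OVER` is a keyword only when the previous token was `)` and the next
/// token is `(` or an identifier (a window name); otherwise it is treated as
/// an identifier.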
fn analyze_over_keyword(
    scanner: &mut Scanner<Tokenizer>,
    input: &[u8],
    last_token: TokenType,
) -> Result<TokenType, Error> {
    if last_token == TK_RP {
        let t = get_token(scanner, input)?;
        if t == TK_LP || t == TK_ID {
            return Ok(TK_OVER);
        }
    }
    Ok(TK_ID)
}
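/// `FILTER` is a keyword only when the previous token was `)` and the next
/// token is `(`; otherwise it is treated as an identifier.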
fn analyze_filter_keyword(
    scanner: &mut Scanner<Tokenizer>,
    input: &[u8],
    last_token: TokenType,
) -> Result<TokenType, Error> {
    if last_token == TK_RP && get_token(scanner, input)? == TK_LP {
        return Ok(TK_FILTER);
    }
    Ok(TK_ID)
}
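/// Forwards `Ok` values and decorates errors with the current scanner
/// position (line, column).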
macro_rules! try_with_position {
    ($scanner:expr, $expr:expr) => {
        match $expr {
            Ok(val) => val,
            Err(err) => {
                let mut err = Error::from(err);
                err.position($scanner.line(), $scanner.column());
                return Err(err);
            }
        }
    };
}
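/// Parses one SQL command per call to `next`; returns `Ok(None)` once the
/// input is exhausted.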
impl<'input> FallibleIterator for Parser<'input> {
    type Item = Cmd;
    type Error = Error;
    fn next(&mut self) -> Result<Option<Cmd>, Error> {
        self.parser.ctx.reset();
        let mut last_token_parsed = TK_EOF;
        let mut eof = false;
        loop {
            let (start, (value, mut token_type), end) = match self.scanner.scan(self.input)? {
                (_, None, _) => {
                    eof = true;
                    break;
                }
                (start, Some(tuple), end) => (start, tuple, end),
            };
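            // `WINDOW`, `OVER` and `FILTER` are context-sensitive keywords:
            // peek at the following tokens (then rewind) to decide between
            // keyword and plain identifier.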
            if token_type >= TK_WINDOW {
                debug_assert!(
                    token_type == TK_OVER || token_type == TK_FILTER || token_type == TK_WINDOW
                );
                self.scanner.mark();
                if token_type == TK_WINDOW {
                    token_type = analyze_window_keyword(&mut self.scanner, self.input)?;
                } else if token_type == TK_OVER {
                    token_type =
                        analyze_over_keyword(&mut self.scanner, self.input, last_token_parsed)?;
                } else if token_type == TK_FILTER {
                    token_type =
                        analyze_filter_keyword(&mut self.scanner, self.input, last_token_parsed)?;
                }
                self.scanner.reset_to_mark();
            }
            let token = token_type.to_token(start, value, end);
            try_with_position!(self.scanner, self.parser.sqlite3Parser(token_type, token));
            last_token_parsed = token_type;
            if self.parser.ctx.done() {
                break;
            }
        }
        if last_token_parsed == TK_EOF {
            return Ok(None);
        }
        if eof && self.parser.ctx.is_ok() {
            if last_token_parsed != TK_SEMI {
                try_with_position!(
                    self.scanner,
                    self.parser
                        .sqlite3Parser(TK_SEMI, sentinel(self.input.len()))
                );
            }
            try_with_position!(
                self.scanner,
                self.parser
                    .sqlite3Parser(TK_EOF, sentinel(self.input.len()))
            );
        }
        self.parser.sqlite3ParserFinalize();
        if let Some(e) = self.parser.ctx.error() {
            let err = Error::ParserError(e, Some((self.scanner.line(), self.scanner.column())));
            return Err(err);
        }
        let cmd = self.parser.ctx.cmd();
        Ok(cmd)
    }
}
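/// Token: its value (a byte slice of the input) and its type.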
pub type Token<'input> = (&'input [u8], TokenType);
#[derive(Default)]
pub struct Tokenizer {}
impl Tokenizer {
    pub fn new() -> Tokenizer {
        Tokenizer {}
    }
}
impl Splitter for Tokenizer {
    type Error = Error;
    type TokenType = TokenType;
    fn split<'input>(
        &mut self,
        data: &'input [u8],
    ) -> Result<(Option<Token<'input>>, usize), Error> {
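        // `data` is assumed to be non-empty here (the scanner is expected
        // never to call `split` with an empty slice)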
        if data[0].is_ascii_whitespace() {
            return Ok((
                None,
                match data.iter().skip(1).position(|&b| !b.is_ascii_whitespace()) {
                    Some(i) => i + 1,
                    _ => data.len(),
                },
            ));
        }
        match data[0] {
            b'-' => {
                if let Some(b) = data.get(1) {
                    if *b == b'-' {
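                        // "--" line comment: runs to the end of the line (or input)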
                        if let Some(i) = memchr(b'\n', data) {
                            Ok((None, i + 1))
                        } else {
                            Ok((None, data.len()))
                        }
                    } else if *b == b'>' {
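                        // "->" and "->>" (JSON pointer) operators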
                        if let Some(b) = data.get(2) {
                            if *b == b'>' {
                                return Ok((Some((&data[..3], TK_PTR)), 3));
                            }
                        }
                        Ok((Some((&data[..2], TK_PTR)), 2))
                    } else {
                        Ok((Some((&data[..1], TK_MINUS)), 1))
                    }
                } else {
                    Ok((Some((&data[..1], TK_MINUS)), 1))
                }
            }
            b'(' => Ok((Some((&data[..1], TK_LP)), 1)),
            b')' => Ok((Some((&data[..1], TK_RP)), 1)),
            b';' => Ok((Some((&data[..1], TK_SEMI)), 1)),
            b'+' => Ok((Some((&data[..1], TK_PLUS)), 1)),
            b'*' => Ok((Some((&data[..1], TK_STAR)), 1)),
            b'/' => {
                if let Some(b) = data.get(1) {
                    if *b == b'*' {
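                        // "/* ... */" block comment: an unterminated comment is an error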
                        let mut pb = 0;
                        let mut end = None;
                        for (i, b) in data.iter().enumerate().skip(2) {
                            if *b == b'/' && pb == b'*' {
                                end = Some(i);
                                break;
                            }
                            pb = *b;
                        }
                        if let Some(i) = end {
                            Ok((None, i + 1))
                        } else {
                            Err(Error::UnterminatedBlockComment(None))
                        }
                    } else {
                        Ok((Some((&data[..1], TK_SLASH)), 1))
                    }
                } else {
                    Ok((Some((&data[..1], TK_SLASH)), 1))
                }
            }
            b'%' => Ok((Some((&data[..1], TK_REM)), 1)),
            b'=' => {
                if let Some(b) = data.get(1) {
                    Ok(if *b == b'=' {
                        (Some((&data[..2], TK_EQ)), 2)
                    } else {
                        (Some((&data[..1], TK_EQ)), 1)
                    })
                } else {
                    Ok((Some((&data[..1], TK_EQ)), 1))
                }
            }
            b'<' => {
                if let Some(b) = data.get(1) {
                    Ok(match *b {
                        b'=' => (Some((&data[..2], TK_LE)), 2),
                        b'>' => (Some((&data[..2], TK_NE)), 2),
                        b'<' => (Some((&data[..2], TK_LSHIFT)), 2),
                        _ => (Some((&data[..1], TK_LT)), 1),
                    })
                } else {
                    Ok((Some((&data[..1], TK_LT)), 1))
                }
            }
            b'>' => {
                if let Some(b) = data.get(1) {
                    Ok(match *b {
                        b'=' => (Some((&data[..2], TK_GE)), 2),
                        b'>' => (Some((&data[..2], TK_RSHIFT)), 2),
                        _ => (Some((&data[..1], TK_GT)), 1),
                    })
                } else {
                    Ok((Some((&data[..1], TK_GT)), 1))
                }
            }
            b'!' => {
                if let Some(b) = data.get(1) {
                    if *b == b'=' {
                        Ok((Some((&data[..2], TK_NE)), 2))
                    } else {
                        Err(Error::ExpectedEqualsSign(None))
                    }
                } else {
                    Err(Error::ExpectedEqualsSign(None))
                }
            }
            b'|' => {
                if let Some(b) = data.get(1) {
                    Ok(if *b == b'|' {
                        (Some((&data[..2], TK_CONCAT)), 2)
                    } else {
                        (Some((&data[..1], TK_BITOR)), 1)
                    })
                } else {
                    Ok((Some((&data[..1], TK_BITOR)), 1))
                }
            }
            b',' => Ok((Some((&data[..1], TK_COMMA)), 1)),
            b'&' => Ok((Some((&data[..1], TK_BITAND)), 1)),
            b'~' => Ok((Some((&data[..1], TK_BITNOT)), 1)),
            quote @ b'`' | quote @ b'\'' | quote @ b'"' => literal(data, quote),
            b'.' => {
                if let Some(b) = data.get(1) {
                    if b.is_ascii_digit() {
                        fractional_part(data, 0)
                    } else {
                        Ok((Some((&data[..1], TK_DOT)), 1))
                    }
                } else {
                    Ok((Some((&data[..1], TK_DOT)), 1))
                }
            }
            b'0'..=b'9' => number(data),
            b'[' => {
                if let Some(i) = memchr(b']', data) {
                    Ok((Some((&data[0..i + 1], TK_ID)), i + 1))
                } else {
                    Err(Error::UnterminatedBracket(None))
                }
            }
            b'?' => match data.iter().skip(1).position(|&b| !b.is_ascii_digit()) {
                // `position` is relative to the skipped `?`: `data[i + 1]` is
                // the first non-digit, so the token value is the digits only.
                Some(i) => Ok((Some((&data[1..=i], TK_VARIABLE)), i + 1)),
                None => Ok((Some((&data[1..], TK_VARIABLE)), data.len())),
            },
            b'$' | b'@' | b'#' | b':' => {
                match data
                    .iter()
                    .skip(1)
                    .position(|&b| !is_identifier_continue(b))
                {
                    // the prefix character must be followed by a name
                    Some(0) => Err(Error::BadVariableName(None)),
                    // the prefix character is kept in the token value
                    Some(i) => Ok((Some((&data[..=i], TK_VARIABLE)), i + 1)),
                    None => {
                        if data.len() == 1 {
                            return Err(Error::BadVariableName(None));
                        }
                        Ok((Some((data, TK_VARIABLE)), data.len()))
                    }
                }
            }
            b if is_identifier_start(b) => {
                if b == b'x' || b == b'X' {
                    if let Some(&b'\'') = data.get(1) {
                        blob_literal(data)
                    } else {
                        Ok(self.identifierish(data))
                    }
                } else {
                    Ok(self.identifierish(data))
                }
            }
            _ => Err(Error::UnrecognizedToken(None)),
        }
    }
}
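/// Scans a quoted token: `'...'` yields a string literal (`TK_STRING`), while
/// backquoted and `"..."` forms yield quoted identifiers (`TK_ID`). A doubled
/// quote character is an escape; the returned token keeps its quotes.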
fn literal(data: &[u8], quote: u8) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert_eq!(data[0], quote);
    let tt = if quote == b'\'' { TK_STRING } else { TK_ID };
    let mut pb = 0;
    let mut end = None;
    for (i, b) in data.iter().enumerate().skip(1) {
        if *b == quote {
            if pb == quote {
                // doubled quote: an escaped quote character, keep scanning
                pb = 0;
                continue;
            }
        } else if pb == quote {
            // previous byte was the (unescaped) closing quote
            end = Some(i);
            break;
        }
        pb = *b;
    }
    if end.is_some() || pb == quote {
        let i = match end {
            Some(i) => i,
            _ => data.len(),
        };
        Ok((Some((&data[0..i], tt)), i))
    } else {
        Err(Error::UnterminatedLiteral(None))
    }
}
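/// Scans a blob literal `x'...'`/`X'...'`: the token is the hex digits only,
/// which must be even in number and terminated by a closing `'`.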
fn blob_literal(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert!(data[0] == b'x' || data[0] == b'X');
    debug_assert_eq!(data[1], b'\'');
    if let Some((i, b)) = data
        .iter()
        .enumerate()
        .skip(2)
        .find(|&(_, &b)| !b.is_ascii_hexdigit())
    {
        if *b != b'\'' || i % 2 != 0 {
            return Err(Error::MalformedBlobLiteral(None));
        }
        Ok((Some((&data[2..i], TK_BLOB)), i + 1))
    } else {
        Err(Error::MalformedBlobLiteral(None))
    }
}
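/// Scans a numeric literal: a decimal or hex (`0x`) integer, or a float with
/// a fractional and/or exponential part.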
fn number(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert!(data[0].is_ascii_digit());
    if data[0] == b'0' {
        if let Some(b) = data.get(1) {
            if *b == b'x' || *b == b'X' {
                return hex_integer(data);
            }
        } else {
            return Ok((Some((data, TK_INTEGER)), data.len()));
        }
    }
    if let Some((i, b)) = data
        .iter()
        .enumerate()
        .skip(1)
        .find(|&(_, &b)| !b.is_ascii_digit())
    {
        if *b == b'.' {
            return fractional_part(data, i);
        } else if *b == b'e' || *b == b'E' {
            return exponential_part(data, i);
        } else if is_identifier_start(*b) {
            return Err(Error::BadNumber(None));
        }
        Ok((Some((&data[..i], TK_INTEGER)), i))
    } else {
        Ok((Some((data, TK_INTEGER)), data.len()))
    }
}
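/// Scans the remainder of a `0x`/`0X` hex integer: at least one hex digit is
/// required, and the literal must not run into an identifier character.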
fn hex_integer(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert_eq!(data[0], b'0');
    debug_assert!(data[1] == b'x' || data[1] == b'X');
    if let Some((i, b)) = data
        .iter()
        .enumerate()
        .skip(2)
        .find(|&(_, &b)| !b.is_ascii_hexdigit())
    {
        if i == 2 || is_identifier_start(*b) {
            return Err(Error::MalformedHexInteger(None));
        }
        Ok((Some((&data[..i], TK_INTEGER)), i))
    } else {
        if data.len() == 2 {
            return Err(Error::MalformedHexInteger(None));
        }
        Ok((Some((data, TK_INTEGER)), data.len()))
    }
}
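/// Scans the digits following the decimal point at `data[i]`, possibly
/// followed by an exponent.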
fn fractional_part(data: &[u8], i: usize) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert_eq!(data[i], b'.');
    if let Some((i, b)) = data
        .iter()
        .enumerate()
        .skip(i + 1)
        .find(|&(_, &b)| !b.is_ascii_digit())
    {
        if *b == b'e' || *b == b'E' {
            return exponential_part(data, i);
        } else if is_identifier_start(*b) {
            return Err(Error::BadNumber(None));
        }
        Ok((Some((&data[..i], TK_FLOAT)), i))
    } else {
        Ok((Some((data, TK_FLOAT)), data.len()))
    }
}
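/// Scans the exponent at `data[i]` (`e`/`E`), with an optional sign, followed
/// by digits.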
fn exponential_part(data: &[u8], i: usize) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert!(data[i] == b'e' || data[i] == b'E');
    if let Some(b) = data.get(i + 1) {
        let i = if *b == b'+' || *b == b'-' { i + 1 } else { i };
        if let Some((i, b)) = data
            .iter()
            .enumerate()
            .skip(i + 1)
            .find(|&(_, &b)| !b.is_ascii_digit())
        {
            if is_identifier_start(*b) {
                return Err(Error::BadNumber(None));
            }
            Ok((Some((&data[..i], TK_FLOAT)), i))
        } else {
            if data.len() == i + 1 {
                return Err(Error::BadNumber(None));
            }
            Ok((Some((data, TK_FLOAT)), data.len()))
        }
    } else {
        Err(Error::BadNumber(None))
    }
}
impl Tokenizer {
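    /// Scans an identifier or keyword: the longest run of identifier
    /// characters, matched against the keyword table when its length fits
    /// within keyword bounds.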
    fn identifierish<'input>(&mut self, data: &'input [u8]) -> (Option<Token<'input>>, usize) {
        debug_assert!(is_identifier_start(data[0]));
        let end = data
            .iter()
            .skip(1)
            .position(|&b| !is_identifier_continue(b));
        let i = match end {
            Some(i) => i + 1,
            _ => data.len(),
        };
        let word = &data[..i];
        let tt = if word.len() >= 2 && word.len() <= MAX_KEYWORD_LEN && word.is_ascii() {
            keyword_token(word).unwrap_or(TK_ID)
        } else {
            TK_ID
        };
        (Some((word, tt)), i)
    }
}
#[cfg(test)]
mod tests {
    use super::Tokenizer;
    use crate::dialect::TokenType;
    use crate::lexer::Scanner;
    #[test]
    fn fallible_iterator() {
        let tokenizer = Tokenizer::new();
        let input = "PRAGMA parser_trace=ON;".as_bytes();
        let mut s = Scanner::new(tokenizer);
        let (token1, token_type1) = s.scan(input).unwrap().1.unwrap();
        assert!(b"PRAGMA".eq_ignore_ascii_case(token1));
        assert_eq!(TokenType::TK_PRAGMA, token_type1);
        let (token2, token_type2) = s.scan(input).unwrap().1.unwrap();
        assert_eq!("parser_trace".as_bytes(), token2);
        assert_eq!(TokenType::TK_ID, token_type2);
    }
}