senax_mysql_parser/
common.rs

1use nom::IResult;
2use nom::branch::alt;
3use nom::character::complete::{digit1, line_ending, multispace0, multispace1};
4use nom::combinator::{map, not, peek};
5use nom::{AsChar, Parser};
6use serde::Deserialize;
7use serde::Serialize;
8use std::fmt;
9use std::str;
10use std::str::FromStr;
11
12use super::column::Column;
13use super::keywords::{escape, sql_keyword};
14use nom::bytes::complete::{is_not, tag, tag_no_case, take, take_while1};
15use nom::combinator::opt;
16use nom::error::{ErrorKind, ParseError};
17use nom::multi::{fold_many0, many0};
18use nom::sequence::{delimited, pair, preceded, terminated};
19
/// A column data type, as produced by [`type_identifier`].
///
/// Variants carrying a number hold the value written in parentheses after the
/// type keyword.
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum SqlType {
    Bool,
    /// `CHAR(n)`; the payload is the parenthesized length.
    Char(u32),
    /// `VARCHAR(n)`; the payload is the parenthesized length.
    Varchar(u32),
    // Integer types: an optional `(n)` after the keyword is parsed but not stored.
    // The `Unsigned*` variants are produced when the keyword is followed by `UNSIGNED`.
    Int,
    UnsignedInt,
    Smallint,
    UnsignedSmallint,
    Bigint,
    UnsignedBigint,
    Tinyint,
    UnsignedTinyint,
    Blob,
    Longblob,
    Mediumblob,
    Tinyblob,
    Double,
    Float,
    Real,
    Tinytext,
    Mediumtext,
    Longtext,
    Text,
    Date,
    Time,
    /// `DATETIME[(n)]`; `0` when no parenthesized number was given.
    DateTime(u16),
    /// `TIMESTAMP[(n)]`; `0` when no parenthesized number was given.
    Timestamp(u16),
    /// `BINARY(n)`.
    Binary(u16),
    /// `VARBINARY(n)`.
    Varbinary(u16),
    /// `ENUM(...)` with its list of member literals.
    Enum(Vec<Literal>),
    /// `SET(...)` with its list of member literals.
    Set(Vec<Literal>),
    /// `DECIMAL(m, d)` / `NUMERIC(m, d)`; both spellings map to this variant.
    Decimal(u16, u16),
    Json,
    Point,
    Geometry,
}
57
58impl fmt::Display for SqlType {
59    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
60        match *self {
61            SqlType::Bool => write!(f, "BOOL"),
62            SqlType::Char(len) => write!(f, "CHAR({})", len),
63            SqlType::Varchar(len) => write!(f, "VARCHAR({})", len),
64            SqlType::Int => write!(f, "INT"),
65            SqlType::UnsignedInt => write!(f, "INT UNSIGNED"),
66            SqlType::Smallint => write!(f, "SMALLINT"),
67            SqlType::UnsignedSmallint => write!(f, "SMALLINT UNSIGNED"),
68            SqlType::Bigint => write!(f, "BIGINT"),
69            SqlType::UnsignedBigint => write!(f, "BIGINT UNSIGNED"),
70            SqlType::Tinyint => write!(f, "TINYINT"),
71            SqlType::UnsignedTinyint => write!(f, "TINYINT UNSIGNED"),
72            SqlType::Blob => write!(f, "BLOB"),
73            SqlType::Longblob => write!(f, "LONGBLOB"),
74            SqlType::Mediumblob => write!(f, "MEDIUMBLOB"),
75            SqlType::Tinyblob => write!(f, "TINYBLOB"),
76            SqlType::Double => write!(f, "DOUBLE"),
77            SqlType::Float => write!(f, "FLOAT"),
78            SqlType::Real => write!(f, "REAL"),
79            SqlType::Tinytext => write!(f, "TINYTEXT"),
80            SqlType::Mediumtext => write!(f, "MEDIUMTEXT"),
81            SqlType::Longtext => write!(f, "LONGTEXT"),
82            SqlType::Text => write!(f, "TEXT"),
83            SqlType::Date => write!(f, "DATE"),
84            SqlType::Time => write!(f, "TIME"),
85            SqlType::DateTime(len) => {
86                if len > 0 {
87                    write!(f, "DATETIME({})", len)
88                } else {
89                    write!(f, "DATETIME")
90                }
91            }
92            SqlType::Timestamp(len) => {
93                if len > 0 {
94                    write!(f, "TIMESTAMP({})", len)
95                } else {
96                    write!(f, "TIMESTAMP")
97                }
98            }
99            SqlType::Binary(len) => write!(f, "BINARY({})", len),
100            SqlType::Varbinary(len) => write!(f, "VARBINARY({})", len),
101            SqlType::Enum(ref v) => write!(
102                f,
103                "ENUM({})",
104                v.iter()
105                    .map(|v| v.to_string())
106                    .collect::<Vec<String>>()
107                    .join(",")
108            ),
109            SqlType::Set(ref v) => write!(
110                f,
111                "SET({})",
112                v.iter()
113                    .map(|v| v.to_string())
114                    .collect::<Vec<String>>()
115                    .join(",")
116            ),
117            SqlType::Decimal(m, d) => write!(f, "DECIMAL({}, {})", m, d),
118            SqlType::Json => write!(f, "JSON"),
119            SqlType::Point => write!(f, "POINT"),
120            SqlType::Geometry => write!(f, "GEOMETRY"),
121        }
122    }
123}
124
/// A decimal literal split at the `.` into two integers.
///
/// NOTE(review): because the fraction is stored as a plain integer, leading
/// zeros of the fractional digits cannot be represented (`1.05` and `1.5`
/// both end up with `fractional == 5`).
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct Real {
    /// Digits before the `.`; negated when the literal had a leading `-`.
    pub integral: i32,
    /// Digits after the `.`, read as an integer.
    pub fractional: i32,
}
130
/// A bind-parameter placeholder appearing where a literal value is expected.
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum ItemPlaceholder {
    /// `?`
    QuestionMark,
    /// `$N`
    DollarNumber(i32),
    /// `:N`
    ColonNumber(i32),
}
137
138#[allow(clippy::to_string_trait_impl)]
139impl ToString for ItemPlaceholder {
140    fn to_string(&self) -> String {
141        match *self {
142            ItemPlaceholder::QuestionMark => "?".to_string(),
143            ItemPlaceholder::DollarNumber(ref i) => format!("${}", i),
144            ItemPlaceholder::ColonNumber(ref i) => format!(":{}", i),
145        }
146    }
147}
148
/// A literal value as recognized by [`literal`].
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum Literal {
    /// The `NULL` keyword.
    Null,
    /// A signed integer.
    Integer(i64),
    /// An unsigned integer.
    UnsignedInteger(u64),
    /// A number written with a decimal point.
    FixedPoint(Real),
    /// A quoted string whose unescaped bytes are valid UTF-8.
    String(String),
    /// A quoted string whose unescaped bytes are not valid UTF-8.
    Blob(Vec<u8>),
    /// The `CURRENT_TIME` keyword.
    CurrentTime,
    /// The `CURRENT_DATE` keyword.
    CurrentDate,
    /// The `CURRENT_TIMESTAMP` keyword.
    CurrentTimestamp,
    /// A bind-parameter placeholder (`?`, `:N`, `$N`).
    Placeholder(ItemPlaceholder),
}
162
163impl From<i64> for Literal {
164    fn from(i: i64) -> Self {
165        Literal::Integer(i)
166    }
167}
168
169impl From<u64> for Literal {
170    fn from(i: u64) -> Self {
171        Literal::UnsignedInteger(i)
172    }
173}
174
175impl From<i32> for Literal {
176    fn from(i: i32) -> Self {
177        Literal::Integer(i.into())
178    }
179}
180
181impl From<u32> for Literal {
182    fn from(i: u32) -> Self {
183        Literal::UnsignedInteger(i.into())
184    }
185}
186
187impl From<String> for Literal {
188    fn from(s: String) -> Self {
189        Literal::String(s)
190    }
191}
192
193impl<'a> From<&'a str> for Literal {
194    fn from(s: &'a str) -> Self {
195        Literal::String(String::from(s))
196    }
197}
198
199#[allow(clippy::to_string_trait_impl)]
200impl ToString for Literal {
201    fn to_string(&self) -> String {
202        match *self {
203            Literal::Null => "NULL".to_string(),
204            Literal::Integer(ref i) => format!("{}", i),
205            Literal::UnsignedInteger(ref i) => format!("{}", i),
206            Literal::FixedPoint(ref f) => format!("{}.{}", f.integral, f.fractional),
207            Literal::String(ref s) => format!("'{}'", s.replace('\'', "''")),
208            Literal::Blob(ref bv) => bv
209                .iter()
210                .map(|v| format!("{:x}", v))
211                .collect::<Vec<String>>()
212                .join(" "),
213            Literal::CurrentTime => "CURRENT_TIME".to_string(),
214            Literal::CurrentDate => "CURRENT_DATE".to_string(),
215            Literal::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
216            Literal::Placeholder(ref item) => item.to_string(),
217        }
218    }
219}
220
221impl Literal {
222    pub fn to_raw_string(&self) -> String {
223        match *self {
224            Literal::Integer(ref i) => format!("{}", i),
225            Literal::UnsignedInteger(ref i) => format!("{}", i),
226            Literal::FixedPoint(ref f) => format!("{}.{}", f.integral, f.fractional),
227            Literal::String(ref s) => s.clone(),
228            _ => "".to_string(),
229        }
230    }
231}
232
/// An index or constraint clause of a table definition.
///
/// Unless noted otherwise the fields are `(key name, indexed columns)`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TableKey {
    /// `PRIMARY KEY (columns)`.
    PrimaryKey(Vec<Column>),
    /// `UNIQUE KEY name (columns)`.
    UniqueKey(String, Vec<Column>),
    /// `FULLTEXT KEY name (columns)`; the optional third field is the
    /// `WITH PARSER` name.
    FulltextKey(String, Vec<Column>, Option<String>),
    /// `KEY name (columns)`.
    Key(String, Vec<Column>),
    /// `SPATIAL KEY name (columns)`.
    SpatialKey(String, Vec<Column>),
    /// `CONSTRAINT name FOREIGN KEY (columns) REFERENCES table (foreign columns)`
    /// followed by the optional `ON DELETE` and `ON UPDATE` actions, in that order.
    Constraint(
        String,
        Vec<Column>,
        String,
        Vec<Column>,
        Option<ReferenceOption>,
        Option<ReferenceOption>,
    ),
}
249
/// The action of a foreign key's `ON DELETE` / `ON UPDATE` clause.
///
/// `Display` prints the upper-case SQL keyword(s) given in each attribute.
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, derive_more::Display)]
pub enum ReferenceOption {
    #[display("RESTRICT")]
    Restrict,
    #[display("CASCADE")]
    Cascade,
    #[display("SET NULL")]
    SetNull,
    #[display("NO ACTION")]
    NoAction,
    #[display("SET DEFAULT")]
    SetDefault,
}
263
264pub fn reference_option(i: &[u8]) -> IResult<&[u8], ReferenceOption> {
265    alt((
266        map(tag_no_case("RESTRICT"), |_| ReferenceOption::Restrict),
267        map(tag_no_case("CASCADE"), |_| ReferenceOption::Cascade),
268        map(tag_no_case("SET NULL"), |_| ReferenceOption::SetNull),
269        map(tag_no_case("NO ACTION"), |_| ReferenceOption::NoAction),
270        map(tag_no_case("SET DEFAULT"), |_| ReferenceOption::SetDefault),
271    ))
272    .parse(i)
273}
274
275impl fmt::Display for TableKey {
276    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
277        match *self {
278            TableKey::PrimaryKey(ref columns) => {
279                write!(f, "PRIMARY KEY ")?;
280                write!(
281                    f,
282                    "({})",
283                    columns
284                        .iter()
285                        .map(|c| c.to_string())
286                        .collect::<Vec<_>>()
287                        .join(", ")
288                )
289            }
290            TableKey::UniqueKey(ref name, ref columns) => {
291                write!(f, "UNIQUE KEY {} ", escape(name))?;
292                write!(
293                    f,
294                    "({})",
295                    columns
296                        .iter()
297                        .map(|c| c.to_string())
298                        .collect::<Vec<_>>()
299                        .join(", ")
300                )
301            }
302            TableKey::FulltextKey(ref name, ref columns, ref parser) => {
303                write!(f, "FULLTEXT KEY {} ", escape(name))?;
304                write!(
305                    f,
306                    "({})",
307                    columns
308                        .iter()
309                        .map(|c| c.to_string())
310                        .collect::<Vec<_>>()
311                        .join(", ")
312                )?;
313                if let Some(parser) = parser {
314                    write!(f, "/*!50100 WITH PARSER `{}` */", parser)?;
315                }
316                Ok(())
317            }
318            TableKey::Key(ref name, ref columns) => {
319                write!(f, "KEY {} ", escape(name))?;
320                write!(
321                    f,
322                    "({})",
323                    columns
324                        .iter()
325                        .map(|c| c.to_string())
326                        .collect::<Vec<_>>()
327                        .join(", ")
328                )
329            }
330            TableKey::SpatialKey(ref name, ref columns) => {
331                write!(f, "SPATIAL KEY {} ", escape(name))?;
332                write!(
333                    f,
334                    "({})",
335                    columns
336                        .iter()
337                        .map(|c| c.to_string())
338                        .collect::<Vec<_>>()
339                        .join(", ")
340                )
341            }
342            TableKey::Constraint(
343                ref name,
344                ref columns,
345                ref table,
346                ref foreign,
347                ref on_delete,
348                ref on_update,
349            ) => {
350                write!(f, "CONSTRAINT {} FOREIGN KEY ", escape(name))?;
351                write!(
352                    f,
353                    "({})",
354                    columns
355                        .iter()
356                        .map(|c| c.to_string())
357                        .collect::<Vec<_>>()
358                        .join(", ")
359                )?;
360                write!(f, " REFERENCES {} ", escape(table))?;
361                write!(
362                    f,
363                    "({})",
364                    foreign
365                        .iter()
366                        .map(|c| c.to_string())
367                        .collect::<Vec<_>>()
368                        .join(", ")
369                )?;
370                if let Some(on_delete) = on_delete {
371                    write!(f, " ON DELETE {}", &on_delete.to_string())?;
372                }
373                if let Some(on_update) = on_update {
374                    write!(f, " ON UPDATE {}", &on_update.to_string())?;
375                }
376                Ok(())
377            }
378        }
379    }
380}
381
382#[inline]
383pub fn is_sql_identifier(chr: u8) -> bool {
384    AsChar::is_alphanum(chr) || chr == b'_' || chr == b'@'
385}
386
/// Returns `true` for bytes allowed inside a quoted identifier.
///
/// Accepted: every byte above the space character except the backtick,
/// square brackets, comma, parentheses and DEL (0x7f).
#[inline]
pub fn is_quoted_sql_identifier(chr: u8) -> bool {
    let excluded = matches!(chr, b'`' | b'[' | b']' | b',' | b'(' | b')' | 0x7f);
    chr > b' ' && !excluded
}
398
/// Converts a run of ASCII digits into a `u16`.
///
/// # Panics
///
/// Panics (with the underlying error as payload) when the bytes are not valid
/// UTF-8 or do not form a number that fits in a `u16`.
#[inline]
fn len_as_u16(len: &[u8]) -> u16 {
    let text = str::from_utf8(len).unwrap_or_else(|e| std::panic::panic_any(e));
    u16::from_str(text).unwrap_or_else(|e| std::panic::panic_any(e))
}
409
/// Converts a run of ASCII digits into a `u32`.
///
/// # Panics
///
/// Panics (with the underlying error as payload) when the bytes are not valid
/// UTF-8 or do not form a number that fits in a `u32`.
pub fn len_as_u32(len: &[u8]) -> u32 {
    let text = str::from_utf8(len).unwrap_or_else(|e| std::panic::panic_any(e));
    u32::from_str(text).unwrap_or_else(|e| std::panic::panic_any(e))
}
419
420fn precision_helper(i: &[u8]) -> IResult<&[u8], (u16, Option<u16>)> {
421    let (remaining_input, (m, d)) = (
422        digit1,
423        opt(preceded(tag(","), preceded(multispace0, digit1))),
424    )
425        .parse(i)?;
426
427    Ok((remaining_input, (len_as_u16(m), d.map(len_as_u16))))
428}
429
430pub fn precision(i: &[u8]) -> IResult<&[u8], (u16, Option<u16>)> {
431    delimited(tag("("), precision_helper, tag(")")).parse(i)
432}
433
434fn opt_signed(i: &[u8]) -> IResult<&[u8], Option<&[u8]>> {
435    opt(alt((tag_no_case("unsigned"), tag_no_case("signed")))).parse(i)
436}
437
438fn opt_unsigned(i: &[u8]) -> IResult<&[u8], Option<&[u8]>> {
439    opt(tag_no_case("unsigned")).parse(i)
440}
441
442fn delim_digit(i: &[u8]) -> IResult<&[u8], &[u8]> {
443    delimited(tag("("), digit1, tag(")")).parse(i)
444}
445
446// TODO: rather than copy paste these functions, should create a function that returns a parser
447// based on the sql int type, just like nom does
448fn tiny_int(i: &[u8]) -> IResult<&[u8], SqlType> {
449    let (remaining_input, (_, _len, _, signed)) = (
450        tag_no_case("tinyint"),
451        opt(delim_digit),
452        multispace0,
453        opt_signed,
454    )
455        .parse(i)?;
456
457    match signed {
458        Some(sign) => {
459            if str::from_utf8(sign)
460                .unwrap()
461                .eq_ignore_ascii_case("unsigned")
462            {
463                Ok((remaining_input, SqlType::UnsignedTinyint))
464            } else {
465                Ok((remaining_input, SqlType::Tinyint))
466            }
467        }
468        None => Ok((remaining_input, SqlType::Tinyint)),
469    }
470}
471
472// TODO: rather than copy paste these functions, should create a function that returns a parser
473// based on the sql int type, just like nom does
474fn big_int(i: &[u8]) -> IResult<&[u8], SqlType> {
475    let (remaining_input, (_, _len, _, signed)) = (
476        tag_no_case("bigint"),
477        opt(delim_digit),
478        multispace0,
479        opt_signed,
480    )
481        .parse(i)?;
482
483    match signed {
484        Some(sign) => {
485            if str::from_utf8(sign)
486                .unwrap()
487                .eq_ignore_ascii_case("unsigned")
488            {
489                Ok((remaining_input, SqlType::UnsignedBigint))
490            } else {
491                Ok((remaining_input, SqlType::Bigint))
492            }
493        }
494        None => Ok((remaining_input, SqlType::Bigint)),
495    }
496}
497
498// TODO: rather than copy paste these functions, should create a function that returns a parser
499// based on the sql int type, just like nom does
500fn sql_int_type(i: &[u8]) -> IResult<&[u8], SqlType> {
501    let (remaining_input, (_, _len, _, signed)) = (
502        alt((tag_no_case("integer"), tag_no_case("int"))),
503        opt(delim_digit),
504        multispace0,
505        opt_signed,
506    )
507        .parse(i)?;
508
509    match signed {
510        Some(sign) => {
511            if str::from_utf8(sign)
512                .unwrap()
513                .eq_ignore_ascii_case("unsigned")
514            {
515                Ok((remaining_input, SqlType::UnsignedInt))
516            } else {
517                Ok((remaining_input, SqlType::Int))
518            }
519        }
520        None => Ok((remaining_input, SqlType::Int)),
521    }
522}
523fn small_int_type(i: &[u8]) -> IResult<&[u8], SqlType> {
524    let (remaining_input, (_, _len, _, signed)) = (
525        tag_no_case("smallint"),
526        opt(delim_digit),
527        multispace0,
528        opt_signed,
529    )
530        .parse(i)?;
531
532    match signed {
533        Some(sign) => {
534            if str::from_utf8(sign)
535                .unwrap()
536                .eq_ignore_ascii_case("unsigned")
537            {
538                Ok((remaining_input, SqlType::UnsignedSmallint))
539            } else {
540                Ok((remaining_input, SqlType::Smallint))
541            }
542        }
543        None => Ok((remaining_input, SqlType::Smallint)),
544    }
545}
546
547// TODO(malte): not strictly ok to treat DECIMAL and NUMERIC as identical; the
548// former has "at least" M precision, the latter "exactly".
549// See https://dev.mysql.com/doc/refman/5.7/en/precision-math-decimal-characteristics.html
550fn decimal_or_numeric(i: &[u8]) -> IResult<&[u8], SqlType> {
551    let (remaining_input, (_, precision, _, _unsigned)) = (
552        alt((tag_no_case("decimal"), tag_no_case("numeric"))),
553        opt(precision),
554        multispace0,
555        opt_unsigned,
556    )
557        .parse(i)?;
558
559    match precision {
560        None => Ok((remaining_input, SqlType::Decimal(32, 0))),
561        Some((m, None)) => Ok((remaining_input, SqlType::Decimal(m, 0))),
562        Some((m, Some(d))) => Ok((remaining_input, SqlType::Decimal(m, d))),
563    }
564}
565
566fn type_identifier_first_half(i: &[u8]) -> IResult<&[u8], SqlType> {
567    alt((
568        tiny_int,
569        big_int,
570        sql_int_type,
571        small_int_type,
572        map(tag_no_case("bool"), |_| SqlType::Bool),
573        map(
574            (
575                tag_no_case("char"),
576                delim_digit,
577                multispace0,
578                opt(tag_no_case("binary")),
579            ),
580            |t| SqlType::Char(len_as_u32(t.1)),
581        ),
582        map(preceded(tag_no_case("datetime"), opt(delim_digit)), |fsp| {
583            SqlType::DateTime(match fsp {
584                Some(fsp) => len_as_u16(fsp),
585                None => 0_u16,
586            })
587        }),
588        map(tag_no_case("date"), |_| SqlType::Date),
589        map(
590            preceded(tag_no_case("timestamp"), opt(delim_digit)),
591            |fsp| {
592                SqlType::Timestamp(match fsp {
593                    Some(fsp) => len_as_u16(fsp),
594                    None => 0_u16,
595                })
596            },
597        ),
598        map(tag_no_case("time"), |_| SqlType::Time),
599        map((tag_no_case("double"), multispace0, opt_unsigned), |_| {
600            SqlType::Double
601        }),
602        map(
603            terminated(
604                preceded(
605                    tag_no_case("enum"),
606                    delimited(tag("("), value_list, tag(")")),
607                ),
608                multispace0,
609            ),
610            SqlType::Enum,
611        ),
612        map(
613            terminated(
614                preceded(
615                    tag_no_case("set"),
616                    delimited(tag("("), value_list, tag(")")),
617                ),
618                multispace0,
619            ),
620            SqlType::Set,
621        ),
622        map(
623            (
624                tag_no_case("float"),
625                multispace0,
626                opt(precision),
627                multispace0,
628                opt_unsigned,
629            ),
630            |_| SqlType::Float,
631        ),
632        map((tag_no_case("real"), multispace0, opt_unsigned), |_| {
633            SqlType::Real
634        }),
635        map(tag_no_case("text"), |_| SqlType::Text),
636        map(
637            (
638                tag_no_case("varchar"),
639                delim_digit,
640                multispace0,
641                opt(tag_no_case("binary")),
642            ),
643            |t| SqlType::Varchar(len_as_u32(t.1)),
644        ),
645        map(tag_no_case("json"), |_| SqlType::Json),
646        map(tag_no_case("point"), |_| SqlType::Point),
647        map(tag_no_case("geometry"), |_| SqlType::Geometry),
648        decimal_or_numeric,
649    ))
650    .parse(i)
651}
652
653fn type_identifier_second_half(i: &[u8]) -> IResult<&[u8], SqlType> {
654    alt((
655        map((tag_no_case("binary"), delim_digit, multispace0), |t| {
656            SqlType::Binary(len_as_u16(t.1))
657        }),
658        map(tag_no_case("blob"), |_| SqlType::Blob),
659        map(tag_no_case("longblob"), |_| SqlType::Longblob),
660        map(tag_no_case("mediumblob"), |_| SqlType::Mediumblob),
661        map(tag_no_case("mediumtext"), |_| SqlType::Mediumtext),
662        map(tag_no_case("longtext"), |_| SqlType::Longtext),
663        map(tag_no_case("tinyblob"), |_| SqlType::Tinyblob),
664        map(tag_no_case("tinytext"), |_| SqlType::Tinytext),
665        map((tag_no_case("varbinary"), delim_digit, multispace0), |t| {
666            SqlType::Varbinary(len_as_u16(t.1))
667        }),
668    ))
669    .parse(i)
670}
671
672// A SQL type specifier.
673pub fn type_identifier(i: &[u8]) -> IResult<&[u8], SqlType> {
674    alt((type_identifier_first_half, type_identifier_second_half)).parse(i)
675}
676
677// Parses a SQL column identifier in the table.column format
678pub fn column_identifier_no_alias(i: &[u8]) -> IResult<&[u8], Column> {
679    let (remaining_input, (column, len)) =
680        (sql_identifier, opt(delimited(tag("("), digit1, tag(")")))).parse(i)?;
681    Ok((
682        remaining_input,
683        Column {
684            name: str::from_utf8(column).unwrap().to_string(),
685            query: None,
686            len: len.map(|l| u32::from_str(str::from_utf8(l).unwrap()).unwrap()),
687            desc: false,
688        },
689    ))
690}
691pub fn column_identifier_query(i: &[u8]) -> IResult<&[u8], Column> {
692    let (remaining_input, query) =
693        delimited(tag("("), take_until_unbalanced('(', ')'), tag(")")).parse(i)?;
694    Ok((
695        remaining_input,
696        Column {
697            name: "".to_string(),
698            query: Some(str::from_utf8(query).unwrap().to_string()),
699            len: None,
700            desc: false,
701        },
702    ))
703}
704
705// Parses a SQL identifier (alphanumeric1 and "_").
706pub fn sql_identifier(i: &[u8]) -> IResult<&[u8], &[u8]> {
707    alt((
708        preceded(not(peek(sql_keyword)), take_while1(is_sql_identifier)),
709        delimited(tag("`"), take_while1(is_quoted_sql_identifier), tag("`")),
710        delimited(tag("["), take_while1(is_quoted_sql_identifier), tag("]")),
711    ))
712    .parse(i)
713}
714pub fn take_until_unbalanced(
715    opening_bracket: char,
716    closing_bracket: char,
717) -> impl Fn(&[u8]) -> IResult<&[u8], &[u8]> {
718    move |i: &[u8]| {
719        let mut index = 0;
720        let mut bracket_counter = 0;
721        while index < i.len() {
722            match i[index] {
723                b'\\' => {
724                    index += 1;
725                }
726                c if c == opening_bracket as u8 => {
727                    bracket_counter += 1;
728                }
729                c if c == closing_bracket as u8 => {
730                    bracket_counter -= 1;
731                }
732                _ => {}
733            };
734            if bracket_counter == -1 {
735                return Ok((&i[index..], &i[0..index]));
736            };
737            index += 1;
738        }
739
740        if bracket_counter == 0 {
741            Ok(("".as_bytes(), i))
742        } else {
743            Err(nom::Err::Error(nom::error::Error::from_error_kind(
744                i,
745                ErrorKind::TakeUntil,
746            )))
747        }
748    }
749}
750
751pub(crate) fn eof<I: Copy + nom::Input, E: ParseError<I>>(input: I) -> IResult<I, I, E> {
752    if input.input_len() == 0 {
753        Ok((input, input))
754    } else {
755        Err(nom::Err::Error(E::from_error_kind(input, ErrorKind::Eof)))
756    }
757}
758
759// Parse a terminator that ends a SQL statement.
760pub fn statement_terminator(i: &[u8]) -> IResult<&[u8], ()> {
761    let (remaining_input, _) =
762        delimited(multispace0, alt((tag(";"), line_ending, eof)), multispace0).parse(i)?;
763
764    Ok((remaining_input, ()))
765}
766
767pub(crate) fn ws_sep_comma(i: &[u8]) -> IResult<&[u8], &[u8]> {
768    delimited(multispace0, tag(","), multispace0).parse(i)
769}
770
771pub(crate) fn ws_sep_equals<'a, I>(i: I) -> IResult<I, I>
772where
773    I: nom::Input + nom::Compare<&'a str>,
774    // Compare required by tag
775    <I as nom::Input>::Item: nom::AsChar + Clone,
776    // AsChar and Clone required by multispace0
777{
778    delimited(multispace0, tag("="), multispace0).parse(i)
779}
780
781// Integer literal value
782pub fn integer_literal(i: &[u8]) -> IResult<&[u8], Literal> {
783    map(pair(opt(tag("-")), digit1), |tup| {
784        let mut intval = i64::from_str(str::from_utf8(tup.1).unwrap()).unwrap();
785        if (tup.0).is_some() {
786            intval *= -1;
787        }
788        Literal::Integer(intval)
789    })
790    .parse(i)
791}
792
793fn unpack(v: &[u8]) -> i32 {
794    i32::from_str(str::from_utf8(v).unwrap()).unwrap()
795}
796
797// Floating point literal value
798pub fn float_literal(i: &[u8]) -> IResult<&[u8], Literal> {
799    map((opt(tag("-")), digit1, tag("."), digit1), |tup| {
800        Literal::FixedPoint(Real {
801            integral: if (tup.0).is_some() {
802                -unpack(tup.1)
803            } else {
804                unpack(tup.1)
805            },
806            fractional: unpack(tup.3),
807        })
808    })
809    .parse(i)
810}
811
812/// String literal value
813fn raw_string_quoted(input: &[u8], is_single_quote: bool) -> IResult<&[u8], Vec<u8>> {
814    // TODO: clean up these assignments. lifetimes and temporary values made it difficult
815    let quote_slice: &[u8] = if is_single_quote { b"\'" } else { b"\"" };
816    let double_quote_slice: &[u8] = if is_single_quote { b"\'\'" } else { b"\"\"" };
817    let backslash_quote: &[u8] = if is_single_quote { b"\\\'" } else { b"\\\"" };
818    delimited(
819        tag(quote_slice),
820        fold_many0(
821            alt((
822                is_not(backslash_quote),
823                map(tag(double_quote_slice), |_| -> &[u8] {
824                    if is_single_quote { b"\'" } else { b"\"" }
825                }),
826                map(tag("\\\\"), |_| &b"\\"[..]),
827                map(tag("\\b"), |_| &b"\x7f"[..]),
828                map(tag("\\r"), |_| &b"\r"[..]),
829                map(tag("\\n"), |_| &b"\n"[..]),
830                map(tag("\\t"), |_| &b"\t"[..]),
831                map(tag("\\0"), |_| &b"\0"[..]),
832                map(tag("\\Z"), |_| &b"\x1A"[..]),
833                preceded(tag("\\"), take(1usize)),
834            )),
835            Vec::new,
836            |mut acc: Vec<u8>, bytes: &[u8]| {
837                acc.extend(bytes);
838                acc
839            },
840        ),
841        tag(quote_slice),
842    )
843    .parse(input)
844}
845
846fn raw_string_single_quoted(i: &[u8]) -> IResult<&[u8], Vec<u8>> {
847    raw_string_quoted(i, true)
848}
849
850fn raw_string_double_quoted(i: &[u8]) -> IResult<&[u8], Vec<u8>> {
851    raw_string_quoted(i, false)
852}
853
854pub fn string_literal(i: &[u8]) -> IResult<&[u8], Literal> {
855    map(
856        alt((raw_string_single_quoted, raw_string_double_quoted)),
857        |bytes| match String::from_utf8(bytes) {
858            Ok(s) => Literal::String(s),
859            Err(err) => Literal::Blob(err.into_bytes()),
860        },
861    )
862    .parse(i)
863}
864
865// Any literal value.
866pub fn literal(i: &[u8]) -> IResult<&[u8], Literal> {
867    alt((
868        float_literal,
869        integer_literal,
870        string_literal,
871        map(tag_no_case("null"), |_| Literal::Null),
872        map(tag_no_case("current_timestamp"), |_| {
873            Literal::CurrentTimestamp
874        }),
875        map(tag_no_case("current_date"), |_| Literal::CurrentDate),
876        map(tag_no_case("current_time"), |_| Literal::CurrentTime),
877        map(tag("?"), |_| {
878            Literal::Placeholder(ItemPlaceholder::QuestionMark)
879        }),
880        map(preceded(tag(":"), digit1), |num| {
881            let value = i32::from_str(str::from_utf8(num).unwrap()).unwrap();
882            Literal::Placeholder(ItemPlaceholder::ColonNumber(value))
883        }),
884        map(preceded(tag("$"), digit1), |num| {
885            let value = i32::from_str(str::from_utf8(num).unwrap()).unwrap();
886            Literal::Placeholder(ItemPlaceholder::DollarNumber(value))
887        }),
888    ))
889    .parse(i)
890}
891
892// Parse a list of values (e.g., for INSERT syntax).
893pub fn value_list(i: &[u8]) -> IResult<&[u8], Vec<Literal>> {
894    many0(delimited(multispace0, literal, opt(ws_sep_comma))).parse(i)
895}
896
897// Parse a reference to a named schema.table, with an optional alias
898pub fn schema_table_reference(i: &[u8]) -> IResult<&[u8], String> {
899    map(sql_identifier, |tup| {
900        String::from(str::from_utf8(tup).unwrap())
901    })
902    .parse(i)
903}
904
905// Parse rule for a comment part.
906pub fn parse_comment(i: &[u8]) -> IResult<&[u8], Literal> {
907    preceded(
908        delimited(multispace0, tag_no_case("comment"), multispace1),
909        string_literal,
910    )
911    .parse(i)
912}