sqlparser_mysql/base/
arithmetic.rs

1use std::{fmt, str};
2
3use nom::{
4    branch::alt,
5    bytes::complete::{tag, tag_no_case},
6    character::complete::{multispace0, multispace1},
7    combinator::{map, opt},
8    lib::std::fmt::Formatter,
9    multi::many0,
10    sequence::{delimited, pair, preceded, separated_pair, terminated, tuple},
11    Err::Error,
12    IResult,
13};
14
15use base::Column;
16use base::ParseSQLErrorKind;
17use base::{CommonParser, DataType, Literal, ParseSQLError};
18
19#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
20pub enum ArithmeticOperator {
21    Add,
22    Subtract,
23    Multiply,
24    Divide,
25}
26
27impl ArithmeticOperator {
28    fn add_sub_operator(i: &str) -> IResult<&str, ArithmeticOperator, ParseSQLError<&str>> {
29        alt((
30            map(tag("+"), |_| ArithmeticOperator::Add),
31            map(tag("-"), |_| ArithmeticOperator::Subtract),
32        ))(i)
33    }
34
35    fn mul_div_operator(i: &str) -> IResult<&str, ArithmeticOperator, ParseSQLError<&str>> {
36        alt((
37            map(tag("*"), |_| ArithmeticOperator::Multiply),
38            map(tag("/"), |_| ArithmeticOperator::Divide),
39        ))(i)
40    }
41}
42
43impl fmt::Display for ArithmeticOperator {
44    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
45        match *self {
46            ArithmeticOperator::Add => write!(f, "+"),
47            ArithmeticOperator::Subtract => write!(f, "-"),
48            ArithmeticOperator::Multiply => write!(f, "*"),
49            ArithmeticOperator::Divide => write!(f, "/"),
50        }
51    }
52}
53
54#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
55pub enum ArithmeticBase {
56    Column(Column),
57    Scalar(Literal),
58    Bracketed(Box<Arithmetic>),
59}
60
61impl ArithmeticBase {
62    // Base case for nested arithmetic expressions: column name or literal.
63    fn parse(i: &str) -> IResult<&str, ArithmeticBase, ParseSQLError<&str>> {
64        alt((
65            map(Literal::integer_literal, ArithmeticBase::Scalar),
66            map(Column::without_alias, ArithmeticBase::Column),
67            map(
68                delimited(
69                    terminated(tag("("), multispace0),
70                    Arithmetic::parse,
71                    preceded(multispace0, tag(")")),
72                ),
73                |ari| ArithmeticBase::Bracketed(Box::new(ari)),
74            ),
75        ))(i)
76    }
77}
78
79impl fmt::Display for ArithmeticBase {
80    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81        match *self {
82            ArithmeticBase::Column(ref col) => write!(f, "{}", col),
83            ArithmeticBase::Scalar(ref lit) => write!(f, "{}", lit),
84            ArithmeticBase::Bracketed(ref ari) => write!(f, "({})", ari),
85        }
86    }
87}
88
89#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
90pub enum ArithmeticItem {
91    Base(ArithmeticBase),
92    Expr(Box<Arithmetic>),
93}
94
95impl ArithmeticItem {
96    fn term(i: &str) -> IResult<&str, ArithmeticItem, ParseSQLError<&str>> {
97        map(
98            pair(Self::arithmetic_cast, many0(Self::term_rest)),
99            |(b, rs)| {
100                rs.into_iter()
101                    .fold(ArithmeticItem::Base(b.0), |acc, (o, r)| {
102                        ArithmeticItem::Expr(Box::new(Arithmetic {
103                            op: o,
104                            left: acc,
105                            right: r,
106                        }))
107                    })
108            },
109        )(i)
110    }
111
112    fn term_rest(
113        i: &str,
114    ) -> IResult<&str, (ArithmeticOperator, ArithmeticItem), ParseSQLError<&str>> {
115        separated_pair(
116            preceded(multispace0, ArithmeticOperator::mul_div_operator),
117            multispace0,
118            map(Self::arithmetic_cast, |b| ArithmeticItem::Base(b.0)),
119        )(i)
120    }
121
122    fn expr(i: &str) -> IResult<&str, ArithmeticItem, ParseSQLError<&str>> {
123        map(
124            pair(ArithmeticItem::term, many0(Self::expr_rest)),
125            |(item, rs)| {
126                rs.into_iter().fold(item, |acc, (o, r)| {
127                    ArithmeticItem::Expr(Box::new(Arithmetic {
128                        op: o,
129                        left: acc,
130                        right: r,
131                    }))
132                })
133            },
134        )(i)
135    }
136
137    fn expr_rest(
138        i: &str,
139    ) -> IResult<&str, (ArithmeticOperator, ArithmeticItem), ParseSQLError<&str>> {
140        separated_pair(
141            preceded(multispace0, ArithmeticOperator::add_sub_operator),
142            multispace0,
143            ArithmeticItem::term,
144        )(i)
145    }
146
147    fn arithmetic_cast(
148        i: &str,
149    ) -> IResult<&str, (ArithmeticBase, Option<DataType>), ParseSQLError<&str>> {
150        alt((
151            Self::arithmetic_cast_helper,
152            map(ArithmeticBase::parse, |v| (v, None)),
153        ))(i)
154    }
155
156    fn arithmetic_cast_helper(
157        i: &str,
158    ) -> IResult<&str, (ArithmeticBase, Option<DataType>), ParseSQLError<&str>> {
159        let (remaining_input, (_, _, _, _, a_base, _, _, _, _sign, sql_type, _, _)) = tuple((
160            tag_no_case("CAST"),
161            multispace0,
162            tag("("),
163            multispace0,
164            // TODO(malte): should be arbitrary expr
165            ArithmeticBase::parse,
166            multispace1,
167            tag_no_case("AS"),
168            multispace1,
169            opt(terminated(tag_no_case("SIGNED"), multispace1)),
170            DataType::type_identifier,
171            multispace0,
172            tag(")"),
173        ))(i)?;
174
175        Ok((remaining_input, (a_base, Some(sql_type))))
176    }
177}
178
179impl fmt::Display for ArithmeticItem {
180    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
181        match *self {
182            ArithmeticItem::Base(ref b) => write!(f, "{}", b),
183            ArithmeticItem::Expr(ref expr) => write!(f, "{}", expr),
184        }
185    }
186}
187
188#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
189pub struct Arithmetic {
190    pub op: ArithmeticOperator,
191    pub left: ArithmeticItem,
192    pub right: ArithmeticItem,
193}
194
195impl Arithmetic {
196    fn parse(i: &str) -> IResult<&str, Arithmetic, ParseSQLError<&str>> {
197        let res = ArithmeticItem::expr(i)?;
198        match res.1 {
199            ArithmeticItem::Base(ArithmeticBase::Column(_))
200            | ArithmeticItem::Base(ArithmeticBase::Scalar(_)) => {
201                let mut error: ParseSQLError<&str> = ParseSQLError { errors: vec![] };
202                error.errors.push((i, ParseSQLErrorKind::Context("Tag")));
203                Err(Error(error))
204            } // no operator
205            ArithmeticItem::Base(ArithmeticBase::Bracketed(expr)) => Ok((res.0, *expr)),
206            ArithmeticItem::Expr(expr) => Ok((res.0, *expr)),
207        }
208    }
209    pub fn new(op: ArithmeticOperator, left: ArithmeticBase, right: ArithmeticBase) -> Self {
210        Self {
211            op,
212            left: ArithmeticItem::Base(left),
213            right: ArithmeticItem::Base(right),
214        }
215    }
216}
217
218impl fmt::Display for Arithmetic {
219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220        write!(f, "{} {} {}", self.left, self.op, self.right)
221    }
222}
223
224#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
225pub struct ArithmeticExpression {
226    pub ari: Arithmetic,
227    pub alias: Option<String>,
228}
229
230impl ArithmeticExpression {
231    pub fn parse(i: &str) -> IResult<&str, ArithmeticExpression, ParseSQLError<&str>> {
232        map(
233            pair(Arithmetic::parse, opt(CommonParser::as_alias)),
234            |(ari, opt_alias)| ArithmeticExpression {
235                ari,
236                alias: opt_alias.map(String::from),
237            },
238        )(i)
239    }
240}
241
242impl ArithmeticExpression {
243    pub fn new(
244        op: ArithmeticOperator,
245        left: ArithmeticBase,
246        right: ArithmeticBase,
247        alias: Option<String>,
248    ) -> Self {
249        Self {
250            ari: Arithmetic {
251                op,
252                left: ArithmeticItem::Base(left),
253                right: ArithmeticItem::Base(right),
254            },
255            alias,
256        }
257    }
258}
259
260impl fmt::Display for ArithmeticExpression {
261    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
262        match self.alias {
263            Some(ref alias) => write!(f, "{} AS {}", self.ari, alias),
264            None => write!(f, "{}", self.ari),
265        }
266    }
267}
268
#[cfg(test)]
mod tests {
    use base::arithmetic::ArithmeticBase::Column as ArithmeticBaseColumn;
    use base::arithmetic::ArithmeticBase::Scalar;
    use base::arithmetic::ArithmeticOperator::{Add, Divide, Multiply, Subtract};
    use base::column::{Column, FunctionArgument, FunctionExpression};

    use super::*;

    /// Parses each input with `ArithmeticExpression::parse` and checks it
    /// against the expected value at the same position.
    fn assert_parses_to(inputs: &[&str], expected: &[ArithmeticExpression]) {
        for (input, want) in inputs.iter().zip(expected) {
            let (_, got) = ArithmeticExpression::parse(input)
                .unwrap_or_else(|e| panic!("{} failed to parse: {:?}", input, e));
            assert_eq!(&got, want);
        }
    }

    #[test]
    fn parses_arithmetic_expressions() {
        let lit_ae = [
            "5 + 42",
            "5+42",
            "5 * 42",
            "5 - 42",
            "5 / 42",
            "2 * 10 AS twenty ",
        ];

        // N.B. trailing space in "5 + foo " is required because `sql_identifier`'s keyword
        // detection requires a follow-up character (in practice, there always is one because we
        // use semicolon-terminated queries).
        let col_lit_ae = [
            "foo+5",
            "foo + 5",
            "5 + foo ",
            "foo * bar AS foobar",
            "MAX(foo)-3333",
        ];

        let expected_lit_ae = [
            ArithmeticExpression::new(Add, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(Add, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(Multiply, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(Subtract, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(Divide, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(
                Multiply,
                Scalar(2.into()),
                Scalar(10.into()),
                Some(String::from("twenty")),
            ),
        ];
        let expected_col_lit_ae = [
            ArithmeticExpression::new(
                Add,
                ArithmeticBaseColumn("foo".into()),
                Scalar(5.into()),
                None,
            ),
            ArithmeticExpression::new(
                Add,
                ArithmeticBaseColumn("foo".into()),
                Scalar(5.into()),
                None,
            ),
            ArithmeticExpression::new(
                Add,
                Scalar(5.into()),
                ArithmeticBaseColumn("foo".into()),
                None,
            ),
            ArithmeticExpression::new(
                Multiply,
                ArithmeticBaseColumn("foo".into()),
                ArithmeticBaseColumn("bar".into()),
                Some(String::from("foobar")),
            ),
            ArithmeticExpression::new(
                Subtract,
                ArithmeticBaseColumn(Column {
                    name: String::from("max(foo)"),
                    alias: None,
                    table: None,
                    function: Some(Box::new(FunctionExpression::Max(FunctionArgument::Column(
                        "foo".into(),
                    )))),
                }),
                Scalar(3333.into()),
                None,
            ),
        ];

        assert_parses_to(&lit_ae, &expected_lit_ae);
        assert_parses_to(&col_lit_ae, &expected_col_lit_ae);
    }

    #[test]
    fn displays_arithmetic_expressions() {
        let expressions = [
            ArithmeticExpression::new(
                Add,
                ArithmeticBaseColumn("foo".into()),
                Scalar(5.into()),
                None,
            ),
            ArithmeticExpression::new(
                Subtract,
                Scalar(5.into()),
                ArithmeticBaseColumn("foo".into()),
                None,
            ),
            ArithmeticExpression::new(
                Multiply,
                ArithmeticBaseColumn("foo".into()),
                ArithmeticBaseColumn("bar".into()),
                None,
            ),
            ArithmeticExpression::new(Divide, Scalar(10.into()), Scalar(2.into()), None),
            ArithmeticExpression::new(
                Add,
                Scalar(10.into()),
                Scalar(2.into()),
                Some(String::from("bob")),
            ),
        ];

        let expected_strings = ["foo + 5", "5 - foo", "foo * bar", "10 / 2", "10 + 2 AS bob"];
        for (expression, want) in expressions.iter().zip(expected_strings.iter()) {
            assert_eq!(expression.to_string(), *want);
        }
    }

    #[test]
    fn parses_arithmetic_casts() {
        let exprs = [
            "CAST(`t`.`foo` AS signed int) + CAST(`t`.`bar` AS signed int) ",
            "CAST(5 AS bigint) - foo ",
            "CAST(5 AS bigint) - foo AS `5_minus_foo`",
        ];

        // XXX(malte): currently discards the cast and type information!
        let expected = [
            ArithmeticExpression::new(
                Add,
                ArithmeticBaseColumn(Column::from("t.foo")),
                ArithmeticBaseColumn(Column::from("t.bar")),
                None,
            ),
            ArithmeticExpression::new(
                Subtract,
                Scalar(5.into()),
                ArithmeticBaseColumn("foo".into()),
                None,
            ),
            ArithmeticExpression::new(
                Subtract,
                Scalar(5.into()),
                ArithmeticBaseColumn("foo".into()),
                Some("5_minus_foo".into()),
            ),
        ];

        assert_parses_to(&exprs, &expected);
    }

    #[test]
    fn parse_nested_arithmetic() {
        let qs = [
            "1 + 1",
            "1 + 2 - 3",
            "1 + 2 * 3",
            "2 * 3 - 1 / 3",
            "3 * (1 + 2)",
        ];

        let expects = [
            Arithmetic::new(Add, Scalar(1.into()), Scalar(1.into())),
            Arithmetic {
                op: Subtract,
                left: ArithmeticItem::Expr(Box::new(Arithmetic::new(
                    Add,
                    Scalar(1.into()),
                    Scalar(2.into()),
                ))),
                right: ArithmeticItem::Base(Scalar(3.into())),
            },
            Arithmetic {
                op: Add,
                left: ArithmeticItem::Base(Scalar(1.into())),
                right: ArithmeticItem::Expr(Box::new(Arithmetic::new(
                    Multiply,
                    Scalar(2.into()),
                    Scalar(3.into()),
                ))),
            },
            Arithmetic {
                op: Subtract,
                left: ArithmeticItem::Expr(Box::new(Arithmetic::new(
                    Multiply,
                    Scalar(2.into()),
                    Scalar(3.into()),
                ))),
                right: ArithmeticItem::Expr(Box::new(Arithmetic::new(
                    Divide,
                    Scalar(1.into()),
                    Scalar(3.into()),
                ))),
            },
            Arithmetic {
                op: Multiply,
                left: ArithmeticItem::Base(Scalar(3.into())),
                right: ArithmeticItem::Base(ArithmeticBase::Bracketed(Box::new(
                    Arithmetic::new(Add, Scalar(1.into()), Scalar(2.into())),
                ))),
            },
        ];

        // Each query must parse to the expected tree and print back verbatim.
        for (query, want) in qs.iter().zip(expects.iter()) {
            let (_, ari) = Arithmetic::parse(query).unwrap();
            assert_eq!(&ari, want);
            assert_eq!(ari.to_string(), *query);
        }
    }

    #[test]
    fn parse_arithmetic_scalar() {
        // A lone literal has no operator and must be rejected.
        assert!(Arithmetic::parse("56").is_err());
    }
}
520}