1use std::{fmt, str};
2
3use nom::{
4 branch::alt,
5 bytes::complete::{tag, tag_no_case},
6 character::complete::{multispace0, multispace1},
7 combinator::{map, opt},
8 lib::std::fmt::Formatter,
9 multi::many0,
10 sequence::{delimited, pair, preceded, separated_pair, terminated, tuple},
11 Err::Error,
12 IResult,
13};
14
15use base::Column;
16use base::ParseSQLErrorKind;
17use base::{CommonParser, DataType, Literal, ParseSQLError};
18
19#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
20pub enum ArithmeticOperator {
21 Add,
22 Subtract,
23 Multiply,
24 Divide,
25}
26
27impl ArithmeticOperator {
28 fn add_sub_operator(i: &str) -> IResult<&str, ArithmeticOperator, ParseSQLError<&str>> {
29 alt((
30 map(tag("+"), |_| ArithmeticOperator::Add),
31 map(tag("-"), |_| ArithmeticOperator::Subtract),
32 ))(i)
33 }
34
35 fn mul_div_operator(i: &str) -> IResult<&str, ArithmeticOperator, ParseSQLError<&str>> {
36 alt((
37 map(tag("*"), |_| ArithmeticOperator::Multiply),
38 map(tag("/"), |_| ArithmeticOperator::Divide),
39 ))(i)
40 }
41}
42
43impl fmt::Display for ArithmeticOperator {
44 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
45 match *self {
46 ArithmeticOperator::Add => write!(f, "+"),
47 ArithmeticOperator::Subtract => write!(f, "-"),
48 ArithmeticOperator::Multiply => write!(f, "*"),
49 ArithmeticOperator::Divide => write!(f, "/"),
50 }
51 }
52}
53
54#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
55pub enum ArithmeticBase {
56 Column(Column),
57 Scalar(Literal),
58 Bracketed(Box<Arithmetic>),
59}
60
61impl ArithmeticBase {
62 fn parse(i: &str) -> IResult<&str, ArithmeticBase, ParseSQLError<&str>> {
64 alt((
65 map(Literal::integer_literal, ArithmeticBase::Scalar),
66 map(Column::without_alias, ArithmeticBase::Column),
67 map(
68 delimited(
69 terminated(tag("("), multispace0),
70 Arithmetic::parse,
71 preceded(multispace0, tag(")")),
72 ),
73 |ari| ArithmeticBase::Bracketed(Box::new(ari)),
74 ),
75 ))(i)
76 }
77}
78
79impl fmt::Display for ArithmeticBase {
80 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81 match *self {
82 ArithmeticBase::Column(ref col) => write!(f, "{}", col),
83 ArithmeticBase::Scalar(ref lit) => write!(f, "{}", lit),
84 ArithmeticBase::Bracketed(ref ari) => write!(f, "({})", ari),
85 }
86 }
87}
88
89#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
90pub enum ArithmeticItem {
91 Base(ArithmeticBase),
92 Expr(Box<Arithmetic>),
93}
94
95impl ArithmeticItem {
96 fn term(i: &str) -> IResult<&str, ArithmeticItem, ParseSQLError<&str>> {
97 map(
98 pair(Self::arithmetic_cast, many0(Self::term_rest)),
99 |(b, rs)| {
100 rs.into_iter()
101 .fold(ArithmeticItem::Base(b.0), |acc, (o, r)| {
102 ArithmeticItem::Expr(Box::new(Arithmetic {
103 op: o,
104 left: acc,
105 right: r,
106 }))
107 })
108 },
109 )(i)
110 }
111
112 fn term_rest(
113 i: &str,
114 ) -> IResult<&str, (ArithmeticOperator, ArithmeticItem), ParseSQLError<&str>> {
115 separated_pair(
116 preceded(multispace0, ArithmeticOperator::mul_div_operator),
117 multispace0,
118 map(Self::arithmetic_cast, |b| ArithmeticItem::Base(b.0)),
119 )(i)
120 }
121
122 fn expr(i: &str) -> IResult<&str, ArithmeticItem, ParseSQLError<&str>> {
123 map(
124 pair(ArithmeticItem::term, many0(Self::expr_rest)),
125 |(item, rs)| {
126 rs.into_iter().fold(item, |acc, (o, r)| {
127 ArithmeticItem::Expr(Box::new(Arithmetic {
128 op: o,
129 left: acc,
130 right: r,
131 }))
132 })
133 },
134 )(i)
135 }
136
137 fn expr_rest(
138 i: &str,
139 ) -> IResult<&str, (ArithmeticOperator, ArithmeticItem), ParseSQLError<&str>> {
140 separated_pair(
141 preceded(multispace0, ArithmeticOperator::add_sub_operator),
142 multispace0,
143 ArithmeticItem::term,
144 )(i)
145 }
146
147 fn arithmetic_cast(
148 i: &str,
149 ) -> IResult<&str, (ArithmeticBase, Option<DataType>), ParseSQLError<&str>> {
150 alt((
151 Self::arithmetic_cast_helper,
152 map(ArithmeticBase::parse, |v| (v, None)),
153 ))(i)
154 }
155
156 fn arithmetic_cast_helper(
157 i: &str,
158 ) -> IResult<&str, (ArithmeticBase, Option<DataType>), ParseSQLError<&str>> {
159 let (remaining_input, (_, _, _, _, a_base, _, _, _, _sign, sql_type, _, _)) = tuple((
160 tag_no_case("CAST"),
161 multispace0,
162 tag("("),
163 multispace0,
164 ArithmeticBase::parse,
166 multispace1,
167 tag_no_case("AS"),
168 multispace1,
169 opt(terminated(tag_no_case("SIGNED"), multispace1)),
170 DataType::type_identifier,
171 multispace0,
172 tag(")"),
173 ))(i)?;
174
175 Ok((remaining_input, (a_base, Some(sql_type))))
176 }
177}
178
179impl fmt::Display for ArithmeticItem {
180 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
181 match *self {
182 ArithmeticItem::Base(ref b) => write!(f, "{}", b),
183 ArithmeticItem::Expr(ref expr) => write!(f, "{}", expr),
184 }
185 }
186}
187
188#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
189pub struct Arithmetic {
190 pub op: ArithmeticOperator,
191 pub left: ArithmeticItem,
192 pub right: ArithmeticItem,
193}
194
195impl Arithmetic {
196 fn parse(i: &str) -> IResult<&str, Arithmetic, ParseSQLError<&str>> {
197 let res = ArithmeticItem::expr(i)?;
198 match res.1 {
199 ArithmeticItem::Base(ArithmeticBase::Column(_))
200 | ArithmeticItem::Base(ArithmeticBase::Scalar(_)) => {
201 let mut error: ParseSQLError<&str> = ParseSQLError { errors: vec![] };
202 error.errors.push((i, ParseSQLErrorKind::Context("Tag")));
203 Err(Error(error))
204 } ArithmeticItem::Base(ArithmeticBase::Bracketed(expr)) => Ok((res.0, *expr)),
206 ArithmeticItem::Expr(expr) => Ok((res.0, *expr)),
207 }
208 }
209 pub fn new(op: ArithmeticOperator, left: ArithmeticBase, right: ArithmeticBase) -> Self {
210 Self {
211 op,
212 left: ArithmeticItem::Base(left),
213 right: ArithmeticItem::Base(right),
214 }
215 }
216}
217
218impl fmt::Display for Arithmetic {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 write!(f, "{} {} {}", self.left, self.op, self.right)
221 }
222}
223
224#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
225pub struct ArithmeticExpression {
226 pub ari: Arithmetic,
227 pub alias: Option<String>,
228}
229
230impl ArithmeticExpression {
231 pub fn parse(i: &str) -> IResult<&str, ArithmeticExpression, ParseSQLError<&str>> {
232 map(
233 pair(Arithmetic::parse, opt(CommonParser::as_alias)),
234 |(ari, opt_alias)| ArithmeticExpression {
235 ari,
236 alias: opt_alias.map(String::from),
237 },
238 )(i)
239 }
240}
241
242impl ArithmeticExpression {
243 pub fn new(
244 op: ArithmeticOperator,
245 left: ArithmeticBase,
246 right: ArithmeticBase,
247 alias: Option<String>,
248 ) -> Self {
249 Self {
250 ari: Arithmetic {
251 op,
252 left: ArithmeticItem::Base(left),
253 right: ArithmeticItem::Base(right),
254 },
255 alias,
256 }
257 }
258}
259
260impl fmt::Display for ArithmeticExpression {
261 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
262 match self.alias {
263 Some(ref alias) => write!(f, "{} AS {}", self.ari, alias),
264 None => write!(f, "{}", self.ari),
265 }
266 }
267}
268
#[cfg(test)]
mod tests {
    use base::arithmetic::ArithmeticBase::Scalar;
    use base::arithmetic::ArithmeticOperator::{Add, Divide, Multiply, Subtract};
    use base::column::{Column, FunctionArgument, FunctionExpression};

    use super::ArithmeticBase::Column as ArithmeticBaseColumn;
    use super::*;

    /// Literal-only and column/literal expressions parse into the expected trees.
    #[test]
    fn parses_arithmetic_expressions() {
        let literal_cases = [
            (
                "5 + 42",
                ArithmeticExpression::new(Add, Scalar(5.into()), Scalar(42.into()), None),
            ),
            (
                "5+42",
                ArithmeticExpression::new(Add, Scalar(5.into()), Scalar(42.into()), None),
            ),
            (
                "5 * 42",
                ArithmeticExpression::new(Multiply, Scalar(5.into()), Scalar(42.into()), None),
            ),
            (
                "5 - 42",
                ArithmeticExpression::new(Subtract, Scalar(5.into()), Scalar(42.into()), None),
            ),
            (
                "5 / 42",
                ArithmeticExpression::new(Divide, Scalar(5.into()), Scalar(42.into()), None),
            ),
            (
                "2 * 10 AS twenty ",
                ArithmeticExpression::new(
                    Multiply,
                    Scalar(2.into()),
                    Scalar(10.into()),
                    Some(String::from("twenty")),
                ),
            ),
        ];

        let column_cases = [
            (
                "foo+5",
                ArithmeticExpression::new(
                    Add,
                    ArithmeticBaseColumn("foo".into()),
                    Scalar(5.into()),
                    None,
                ),
            ),
            (
                "foo + 5",
                ArithmeticExpression::new(
                    Add,
                    ArithmeticBaseColumn("foo".into()),
                    Scalar(5.into()),
                    None,
                ),
            ),
            (
                "5 + foo ",
                ArithmeticExpression::new(
                    Add,
                    Scalar(5.into()),
                    ArithmeticBaseColumn("foo".into()),
                    None,
                ),
            ),
            (
                "foo * bar AS foobar",
                ArithmeticExpression::new(
                    Multiply,
                    ArithmeticBaseColumn("foo".into()),
                    ArithmeticBaseColumn("bar".into()),
                    Some(String::from("foobar")),
                ),
            ),
            (
                "MAX(foo)-3333",
                ArithmeticExpression::new(
                    Subtract,
                    ArithmeticBaseColumn(Column {
                        name: String::from("max(foo)"),
                        alias: None,
                        table: None,
                        function: Some(Box::new(FunctionExpression::Max(
                            FunctionArgument::Column("foo".into()),
                        ))),
                    }),
                    Scalar(3333.into()),
                    None,
                ),
            ),
        ];

        for (input, expected) in literal_cases.iter().chain(column_cases.iter()) {
            let res = ArithmeticExpression::parse(input);
            assert!(res.is_ok());
            assert_eq!(res.unwrap().1, *expected);
        }
    }

    /// Expressions render back to canonical SQL text, including aliases.
    #[test]
    fn displays_arithmetic_expressions() {
        let cases = [
            (
                ArithmeticExpression::new(
                    Add,
                    ArithmeticBaseColumn("foo".into()),
                    Scalar(5.into()),
                    None,
                ),
                "foo + 5",
            ),
            (
                ArithmeticExpression::new(
                    Subtract,
                    Scalar(5.into()),
                    ArithmeticBaseColumn("foo".into()),
                    None,
                ),
                "5 - foo",
            ),
            (
                ArithmeticExpression::new(
                    Multiply,
                    ArithmeticBaseColumn("foo".into()),
                    ArithmeticBaseColumn("bar".into()),
                    None,
                ),
                "foo * bar",
            ),
            (
                ArithmeticExpression::new(Divide, Scalar(10.into()), Scalar(2.into()), None),
                "10 / 2",
            ),
            (
                ArithmeticExpression::new(
                    Add,
                    Scalar(10.into()),
                    Scalar(2.into()),
                    Some(String::from("bob")),
                ),
                "10 + 2 AS bob",
            ),
        ];

        for (expr, expected) in &cases {
            assert_eq!(*expected, expr.to_string());
        }
    }

    /// `CAST(... AS type)` operands parse to their inner operand.
    #[test]
    fn parses_arithmetic_casts() {
        let cases = [
            (
                "CAST(`t`.`foo` AS signed int) + CAST(`t`.`bar` AS signed int) ",
                ArithmeticExpression::new(
                    Add,
                    ArithmeticBaseColumn(Column::from("t.foo")),
                    ArithmeticBaseColumn(Column::from("t.bar")),
                    None,
                ),
            ),
            (
                "CAST(5 AS bigint) - foo ",
                ArithmeticExpression::new(
                    Subtract,
                    Scalar(5.into()),
                    ArithmeticBaseColumn("foo".into()),
                    None,
                ),
            ),
            (
                "CAST(5 AS bigint) - foo AS `5_minus_foo`",
                ArithmeticExpression::new(
                    Subtract,
                    Scalar(5.into()),
                    ArithmeticBaseColumn("foo".into()),
                    Some("5_minus_foo".into()),
                ),
            ),
        ];

        for (input, expected) in &cases {
            let res = ArithmeticExpression::parse(input);
            assert!(res.is_ok(), "{} failed to parse", input);
            assert_eq!(res.unwrap().1, *expected);
        }
    }

    /// Operator precedence, left associativity and explicit grouping are
    /// captured in the tree, and the tree renders back to the input text.
    #[test]
    fn parse_nested_arithmetic() {
        let one_plus_two = || Arithmetic::new(Add, Scalar(1.into()), Scalar(2.into()));

        let cases = [
            ("1 + 1", Arithmetic::new(Add, Scalar(1.into()), Scalar(1.into()))),
            (
                "1 + 2 - 3",
                Arithmetic {
                    op: Subtract,
                    left: ArithmeticItem::Expr(Box::new(one_plus_two())),
                    right: ArithmeticItem::Base(Scalar(3.into())),
                },
            ),
            (
                "1 + 2 * 3",
                Arithmetic {
                    op: Add,
                    left: ArithmeticItem::Base(Scalar(1.into())),
                    right: ArithmeticItem::Expr(Box::new(Arithmetic::new(
                        Multiply,
                        Scalar(2.into()),
                        Scalar(3.into()),
                    ))),
                },
            ),
            (
                "2 * 3 - 1 / 3",
                Arithmetic {
                    op: Subtract,
                    left: ArithmeticItem::Expr(Box::new(Arithmetic::new(
                        Multiply,
                        Scalar(2.into()),
                        Scalar(3.into()),
                    ))),
                    right: ArithmeticItem::Expr(Box::new(Arithmetic::new(
                        Divide,
                        Scalar(1.into()),
                        Scalar(3.into()),
                    ))),
                },
            ),
            (
                "3 * (1 + 2)",
                Arithmetic {
                    op: Multiply,
                    left: ArithmeticItem::Base(Scalar(3.into())),
                    right: ArithmeticItem::Base(ArithmeticBase::Bracketed(Box::new(
                        one_plus_two(),
                    ))),
                },
            ),
        ];

        for (input, expected) in &cases {
            let (_, ari) = Arithmetic::parse(input).unwrap();
            assert_eq!(ari, *expected);
            assert_eq!(ari.to_string(), *input);
        }
    }

    /// A lone scalar is not an arithmetic expression.
    #[test]
    fn parse_arithmetic_scalar() {
        assert!(Arithmetic::parse("56").is_err());
    }
}