sqlparser_mysql/base/
case.rs

use std::fmt;

use nom::branch::alt;
use nom::bytes::complete::tag_no_case;
use nom::character::complete::{multispace0, multispace1};
use nom::combinator::{map, opt};
use nom::sequence::{delimited, terminated, tuple};
use nom::IResult;

use base::column::Column;
use base::condition::ConditionExpression;
use base::error::ParseSQLError;
use base::Literal;
13
/// A searched `CASE` expression with exactly one `WHEN` branch:
///
/// ```sql
/// CASE WHEN condition THEN result [ELSE result] END
/// ```
///
/// NOTE(review): SQL also has the simple form (`CASE value WHEN ...`) and
/// allows several `WHEN ... THEN ...` branches; this type holds a single
/// condition and a single `THEN` result, so it can represent neither.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct CaseWhenExpression {
    /// The predicate that follows `WHEN`.
    pub condition: ConditionExpression,
    /// The result that follows `THEN`.
    pub then_expr: ColumnOrLiteral,
    /// The result that follows `ELSE`, or `None` when there is no `ELSE` clause.
    pub else_expr: Option<ColumnOrLiteral>,
}
27
28impl CaseWhenExpression {
29    pub fn parse(i: &str) -> IResult<&str, CaseWhenExpression, ParseSQLError<&str>> {
30        let (input, (_, _, _, _, condition, _, _, _, column, _, else_val, _)) = tuple((
31            tag_no_case("CASE"),
32            multispace1,
33            tag_no_case("WHEN"),
34            multispace0,
35            ConditionExpression::condition_expr,
36            multispace0,
37            tag_no_case("THEN"),
38            multispace0,
39            Column::without_alias,
40            multispace0,
41            opt(delimited(
42                terminated(tag_no_case("ELSE"), multispace0),
43                Literal::parse,
44                multispace0,
45            )),
46            tag_no_case("END"),
47        ))(i)?;
48
49        let then_expr = ColumnOrLiteral::Column(column);
50        let else_expr = else_val.map(ColumnOrLiteral::Literal);
51
52        Ok((
53            input,
54            CaseWhenExpression {
55                condition,
56                then_expr,
57                else_expr,
58            },
59        ))
60    }
61}
62
63impl fmt::Display for CaseWhenExpression {
64    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65        write!(f, "CASE WHEN {} THEN {}", self.condition, self.then_expr)?;
66        if let Some(ref expr) = self.else_expr {
67            write!(f, " ELSE {}", expr)?;
68        }
69        Ok(())
70    }
71}
72
/// An operand that is either a column reference or a literal value.
///
/// Used for the `THEN` and `ELSE` results of a `CaseWhenExpression`.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum ColumnOrLiteral {
    /// A column reference, e.g. `col_name`.
    Column(Column),
    /// A literal value, e.g. `22`.
    Literal(Literal),
}
78
79impl fmt::Display for ColumnOrLiteral {
80    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81        match *self {
82            ColumnOrLiteral::Column(ref c) => write!(f, "{}", c)?,
83            ColumnOrLiteral::Literal(ref l) => write!(f, "{}", l)?,
84        }
85        Ok(())
86    }
87}
88
#[cfg(test)]
mod tests {
    use base::condition::ConditionBase::{Field, Literal};
    use base::condition::ConditionExpression::{Base, ComparisonOp};
    use base::condition::ConditionTree;
    use base::Literal::Integer;
    use base::Operator::Greater;
    use base::{CaseWhenExpression, Column, ColumnOrLiteral};

    /// Builds an unqualified, un-aliased column reference.
    fn column(name: &str) -> Column {
        Column {
            name: name.to_string(),
            alias: None,
            table: None,
            function: None,
        }
    }

    /// Expected tree for `CASE WHEN age > 10 THEN col_name [ELSE ...] END`.
    fn expected(else_expr: Option<ColumnOrLiteral>) -> CaseWhenExpression {
        CaseWhenExpression {
            condition: ComparisonOp(ConditionTree {
                operator: Greater,
                left: Box::new(Base(Field(column("age")))),
                right: Box::new(Base(Literal(Integer(10)))),
            }),
            then_expr: ColumnOrLiteral::Column(column("col_name")),
            else_expr,
        }
    }

    #[test]
    fn parse_case() {
        let (remaining, parsed) =
            CaseWhenExpression::parse("CASE WHEN age > 10 THEN col_name ELSE 22 END;").unwrap();
        // Parsing stops right after `END`, leaving the statement terminator.
        assert_eq!(remaining, ";");
        assert_eq!(
            parsed,
            expected(Some(ColumnOrLiteral::Literal(Integer(22))))
        );
    }

    #[test]
    fn parse_case_without_else() {
        let (remaining, parsed) =
            CaseWhenExpression::parse("CASE WHEN age > 10 THEN col_name END").unwrap();
        assert_eq!(remaining, "");
        assert_eq!(parsed, expected(None));
    }
}
126}