1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use std::fmt;

use nom::bytes::complete::tag_no_case;
use nom::character::complete::{multispace0, multispace1};
use nom::combinator::opt;
use nom::sequence::{delimited, terminated, tuple};
use nom::IResult;

use base::column::Column;
use base::condition::ConditionExpression;
use base::error::ParseSQLError;
use base::Literal;

/// A searched `CASE` expression with a single `WHEN` branch.
///
/// The general SQL form of `CASE` is:
///
/// ```sql
/// CASE expression
///     WHEN {value1 | condition1} THEN result1
///     ...
///     ELSE resultN
/// END
/// ```
///
/// but [`CaseWhenExpression::parse`] currently only accepts the narrower shape
///
/// ```sql
/// CASE WHEN condition THEN column [ELSE literal] END
/// ```
///
/// i.e. no operand between `CASE` and `WHEN`, exactly one `WHEN ... THEN` pair,
/// a column as the `THEN` result, and an optional literal as the `ELSE` result.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct CaseWhenExpression {
    /// The condition tested by the single `WHEN` branch.
    pub condition: ConditionExpression,
    /// Result of the `WHEN` branch; `parse` always produces the `Column` variant.
    pub then_expr: ColumnOrLiteral,
    /// Result of the optional `ELSE` branch; `parse` produces the `Literal` variant when present.
    pub else_expr: Option<ColumnOrLiteral>,
}

impl CaseWhenExpression {
    /// Parses `CASE WHEN <condition> THEN <column> [ELSE <literal>] END`.
    ///
    /// Keywords are matched case-insensitively. At least one whitespace character is
    /// required between `CASE` and `WHEN`; whitespace elsewhere is optional. Input
    /// following `END` (e.g. a statement-terminating `;`) is returned as the remainder.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the first sub-parser that fails.
    pub fn parse(i: &str) -> IResult<&str, CaseWhenExpression, ParseSQLError<&str>> {
        // `ELSE <literal>` plus any whitespace that follows it; absent => `None`.
        let else_clause = opt(delimited(
            terminated(tag_no_case("ELSE"), multispace0),
            Literal::parse,
            multispace0,
        ));

        tuple((
            tag_no_case("CASE"),
            multispace1,
            tag_no_case("WHEN"),
            multispace0,
            ConditionExpression::condition_expr,
            multispace0,
            tag_no_case("THEN"),
            multispace0,
            Column::without_alias,
            multispace0,
            else_clause,
            tag_no_case("END"),
        ))(i)
        .map(|(remaining, parts)| {
            let (_, _, _, _, condition, _, _, _, then_column, _, else_literal, _) = parts;
            let case_when = CaseWhenExpression {
                condition,
                then_expr: ColumnOrLiteral::Column(then_column),
                else_expr: else_literal.map(ColumnOrLiteral::Literal),
            };
            (remaining, case_when)
        })
    }
}

impl fmt::Display for CaseWhenExpression {
    /// Renders the expression as SQL: `CASE WHEN <cond> THEN <expr> [ELSE <expr>] END`.
    ///
    /// The closing `END` is mandatory in SQL and is required by
    /// [`CaseWhenExpression::parse`], so it is always emitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "CASE WHEN {} THEN {}", self.condition, self.then_expr)?;
        if let Some(ref else_expr) = self.else_expr {
            write!(f, " ELSE {}", else_expr)?;
        }
        // Without `END` the rendered SQL is unterminated and cannot be parsed back.
        f.write_str(" END")
    }
}

/// Either a column reference or a literal value.
///
/// In this module it holds the `THEN` and `ELSE` results of a [`CaseWhenExpression`].
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum ColumnOrLiteral {
    /// A column reference.
    Column(Column),
    /// A literal value.
    Literal(Literal),
}

impl fmt::Display for ColumnOrLiteral {
    /// Formats the wrapped column or literal using that value's own `Display` impl.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Pick the wrapped value once, then format it through a single `write!`.
        let inner: &dyn fmt::Display = match self {
            ColumnOrLiteral::Column(column) => column,
            ColumnOrLiteral::Literal(literal) => literal,
        };
        write!(f, "{}", inner)
    }
}

#[cfg(test)]
mod tests {
    use base::condition::ConditionBase::{Field, Literal};
    use base::condition::ConditionExpression;
    use base::condition::ConditionExpression::{Base, ComparisonOp};
    use base::condition::ConditionTree;
    use base::Literal::Integer;
    use base::Operator::Greater;
    use base::{CaseWhenExpression, Column, ColumnOrLiteral};

    /// Builds an unqualified, unaliased, non-function column reference named `name`.
    fn plain_column(name: &str) -> Column {
        Column {
            name: name.to_string(),
            alias: None,
            table: None,
            function: None,
        }
    }

    /// The `age > 10` condition shared by the tests below.
    fn age_over_ten() -> ConditionExpression {
        ComparisonOp(ConditionTree {
            operator: Greater,
            left: Box::new(Base(Field(plain_column("age")))),
            right: Box::new(Base(Literal(Integer(10)))),
        })
    }

    #[test]
    fn parse_case() {
        let sql = "CASE WHEN age > 10 THEN col_name ELSE 22 END;";
        let res = CaseWhenExpression::parse(sql);

        let expected = CaseWhenExpression {
            condition: age_over_ten(),
            then_expr: ColumnOrLiteral::Column(plain_column("col_name")),
            else_expr: Some(ColumnOrLiteral::Literal(Integer(22))),
        };

        assert!(res.is_ok());
        let (remaining, parsed) = res.unwrap();
        // Parsing stops right after `END`; the terminator is left for the caller.
        assert_eq!(remaining, ";");
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_case_without_else() {
        let sql = "CASE WHEN age > 10 THEN col_name END;";
        let res = CaseWhenExpression::parse(sql);

        let expected = CaseWhenExpression {
            condition: age_over_ten(),
            then_expr: ColumnOrLiteral::Column(plain_column("col_name")),
            else_expr: None,
        };

        assert!(res.is_ok());
        let (remaining, parsed) = res.unwrap();
        assert_eq!(remaining, ";");
        assert_eq!(parsed, expected);
    }
}