1use std::fmt;
2
3use nom::bytes::complete::tag_no_case;
4use nom::character::complete::{multispace0, multispace1};
5use nom::combinator::opt;
6use nom::sequence::{delimited, terminated, tuple};
7use nom::IResult;
8
9use base::column::Column;
10use base::condition::ConditionExpression;
11use base::error::ParseSQLError;
12use base::Literal;
13
/// A searched `CASE` expression with a single `WHEN` arm:
/// `CASE WHEN <condition> THEN <then_expr> [ELSE <else_expr>] END`.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct CaseWhenExpression {
    /// The predicate tested by the `WHEN` clause.
    pub condition: ConditionExpression,
    /// The operand named in the `THEN` clause.
    pub then_expr: ColumnOrLiteral,
    /// The operand named in the `ELSE` clause, or `None` when no `ELSE` was given.
    pub else_expr: Option<ColumnOrLiteral>,
}
27
28impl CaseWhenExpression {
29 pub fn parse(i: &str) -> IResult<&str, CaseWhenExpression, ParseSQLError<&str>> {
30 let (input, (_, _, _, _, condition, _, _, _, column, _, else_val, _)) = tuple((
31 tag_no_case("CASE"),
32 multispace1,
33 tag_no_case("WHEN"),
34 multispace0,
35 ConditionExpression::condition_expr,
36 multispace0,
37 tag_no_case("THEN"),
38 multispace0,
39 Column::without_alias,
40 multispace0,
41 opt(delimited(
42 terminated(tag_no_case("ELSE"), multispace0),
43 Literal::parse,
44 multispace0,
45 )),
46 tag_no_case("END"),
47 ))(i)?;
48
49 let then_expr = ColumnOrLiteral::Column(column);
50 let else_expr = else_val.map(ColumnOrLiteral::Literal);
51
52 Ok((
53 input,
54 CaseWhenExpression {
55 condition,
56 then_expr,
57 else_expr,
58 },
59 ))
60 }
61}
62
63impl fmt::Display for CaseWhenExpression {
64 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65 write!(f, "CASE WHEN {} THEN {}", self.condition, self.then_expr)?;
66 if let Some(ref expr) = self.else_expr {
67 write!(f, " ELSE {}", expr)?;
68 }
69 Ok(())
70 }
71}
72
/// An operand that is either a column reference or a literal value; used for
/// the `THEN` and `ELSE` operands of `CaseWhenExpression`.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum ColumnOrLiteral {
    /// A column reference.
    Column(Column),
    /// A literal value.
    Literal(Literal),
}
78
79impl fmt::Display for ColumnOrLiteral {
80 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81 match *self {
82 ColumnOrLiteral::Column(ref c) => write!(f, "{}", c)?,
83 ColumnOrLiteral::Literal(ref l) => write!(f, "{}", l)?,
84 }
85 Ok(())
86 }
87}
88
#[cfg(test)]
mod tests {
    use base::condition::ConditionBase::{Field, Literal};
    use base::condition::ConditionExpression::{Base, ComparisonOp};
    use base::condition::ConditionTree;
    use base::Literal::Integer;
    use base::Operator::Greater;
    use base::{CaseWhenExpression, Column, ColumnOrLiteral};

    /// Builds an unqualified, unaliased column reference with the given name.
    fn plain_column(name: &str) -> Column {
        Column {
            name: name.to_string(),
            alias: None,
            table: None,
            function: None,
        }
    }

    /// Builds the expected parse of the condition `age > 10`.
    fn age_greater_than_ten() -> base::condition::ConditionExpression {
        ComparisonOp(ConditionTree {
            operator: Greater,
            left: Box::new(Base(Field(plain_column("age")))),
            right: Box::new(Base(Literal(Integer(10)))),
        })
    }

    #[test]
    fn parse_case() {
        let input = "CASE WHEN age > 10 THEN col_name ELSE 22 END;";
        let (remaining, parsed) =
            CaseWhenExpression::parse(input).expect("CASE WHEN with ELSE should parse");

        let expected = CaseWhenExpression {
            condition: age_greater_than_ten(),
            then_expr: ColumnOrLiteral::Column(plain_column("col_name")),
            else_expr: Some(ColumnOrLiteral::Literal(Integer(22))),
        };

        assert_eq!(parsed, expected);
        // Parsing stops right after `END`; the statement terminator is left for the caller.
        assert_eq!(remaining, ";");
    }

    #[test]
    fn parse_case_without_else() {
        let input = "CASE WHEN age > 10 THEN col_name END";
        let (remaining, parsed) =
            CaseWhenExpression::parse(input).expect("CASE WHEN without ELSE should parse");

        let expected = CaseWhenExpression {
            condition: age_greater_than_ten(),
            then_expr: ColumnOrLiteral::Column(plain_column("col_name")),
            else_expr: None,
        };

        assert_eq!(parsed, expected);
        assert_eq!(remaining, "");
    }
}