1use std::fmt;
2use std::fmt::Display;
3
4use nom::branch::alt;
5use nom::bytes::complete::tag;
6use nom::character::complete::multispace0;
7use nom::combinator::{map, opt};
8use nom::multi::{many0, many1};
9use nom::sequence::{delimited, separated_pair, terminated};
10use nom::IResult;
11
12use base::arithmetic::ArithmeticExpression;
13use base::column::Column;
14use base::error::ParseSQLError;
15use base::literal::LiteralExpression;
16use base::table::Table;
17use base::{CommonParser, DisplayUtil, Literal};
18
19#[derive(Default, Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
20pub enum FieldDefinitionExpression {
21 #[default]
22 All,
23 AllInTable(String),
24 Col(Column),
25 Value(FieldValueExpression),
26}
27
28impl FieldDefinitionExpression {
29 pub fn parse(i: &str) -> IResult<&str, Vec<FieldDefinitionExpression>, ParseSQLError<&str>> {
31 many0(terminated(
32 alt((
33 map(tag("*"), |_| FieldDefinitionExpression::All),
34 map(terminated(Table::table_reference, tag(".*")), |t| {
35 FieldDefinitionExpression::AllInTable(t.name.clone())
36 }),
37 map(ArithmeticExpression::parse, |expr| {
38 FieldDefinitionExpression::Value(FieldValueExpression::Arithmetic(expr))
39 }),
40 map(LiteralExpression::parse, |lit| {
41 FieldDefinitionExpression::Value(FieldValueExpression::Literal(lit))
42 }),
43 map(Column::parse, FieldDefinitionExpression::Col),
44 )),
45 opt(CommonParser::ws_sep_comma),
46 ))(i)
47 }
48
49 pub fn from_column_str(cols: &[&str]) -> Vec<FieldDefinitionExpression> {
50 cols.iter()
51 .map(|c| FieldDefinitionExpression::Col(Column::from(*c)))
52 .collect()
53 }
54}
55
56impl Display for FieldDefinitionExpression {
57 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
58 match *self {
59 FieldDefinitionExpression::All => write!(f, "*"),
60 FieldDefinitionExpression::AllInTable(ref table) => {
61 write!(f, "{}.*", DisplayUtil::escape_if_keyword(table))
62 }
63 FieldDefinitionExpression::Col(ref col) => write!(f, "{}", col),
64 FieldDefinitionExpression::Value(ref val) => write!(f, "{}", val),
65 }
66 }
67}
68
69#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
70pub enum FieldValueExpression {
71 Arithmetic(ArithmeticExpression),
72 Literal(LiteralExpression),
73}
74
75impl FieldValueExpression {
76 fn parse(i: &str) -> IResult<&str, FieldValueExpression, ParseSQLError<&str>> {
77 alt((
78 map(Literal::parse, |l| {
79 FieldValueExpression::Literal(LiteralExpression {
80 value: l,
81 alias: None,
82 })
83 }),
84 map(ArithmeticExpression::parse, |ae| {
85 FieldValueExpression::Arithmetic(ae)
86 }),
87 ))(i)
88 }
89
90 fn assignment_expr(
91 i: &str,
92 ) -> IResult<&str, (Column, FieldValueExpression), ParseSQLError<&str>> {
93 separated_pair(
94 Column::without_alias,
95 delimited(multispace0, tag("="), multispace0),
96 Self::parse,
97 )(i)
98 }
99
100 pub fn assignment_expr_list(
101 i: &str,
102 ) -> IResult<&str, Vec<(Column, FieldValueExpression)>, ParseSQLError<&str>> {
103 many1(terminated(
104 Self::assignment_expr,
105 opt(CommonParser::ws_sep_comma),
106 ))(i)
107 }
108}
109
110impl Display for FieldValueExpression {
111 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
112 match *self {
113 FieldValueExpression::Arithmetic(ref expr) => write!(f, "{}", expr),
114 FieldValueExpression::Literal(ref lit) => write!(f, "{}", lit),
115 }
116 }
117}
118
#[cfg(test)]
mod tests {
    use base::arithmetic::ArithmeticBase;
    use base::arithmetic::ArithmeticExpression;
    use base::arithmetic::ArithmeticOperator::{Add, Multiply};
    use base::{FieldDefinitionExpression, FieldValueExpression, Literal};

    /// Wraps an arithmetic expression the way the field parser reports it.
    fn arithmetic(expr: ArithmeticExpression) -> FieldDefinitionExpression {
        FieldDefinitionExpression::Value(FieldValueExpression::Arithmetic(expr))
    }

    #[test]
    fn parse_field_definition_expression() {
        // Each entry pairs an input with the exact field list it must parse into.
        let cases: Vec<(&str, Vec<FieldDefinitionExpression>)> = vec![
            ("*", vec![FieldDefinitionExpression::All]),
            (
                "tbl_name.*",
                vec![FieldDefinitionExpression::AllInTable(
                    "tbl_name".to_string(),
                )],
            ),
            (
                "age, name, score",
                vec![
                    FieldDefinitionExpression::Col("age".into()),
                    FieldDefinitionExpression::Col("name".into()),
                    FieldDefinitionExpression::Col("score".into()),
                ],
            ),
            (
                "1+2, price * count as total_count",
                vec![
                    arithmetic(ArithmeticExpression::new(
                        Add,
                        ArithmeticBase::Scalar(Literal::Integer(1)),
                        ArithmeticBase::Scalar(Literal::Integer(2)),
                        None,
                    )),
                    arithmetic(ArithmeticExpression::new(
                        Multiply,
                        ArithmeticBase::Column("price".into()),
                        ArithmeticBase::Column("count".into()),
                        Some(String::from("total_count")),
                    )),
                ],
            ),
        ];

        for (input, expected) in cases {
            let parsed = FieldDefinitionExpression::parse(input);
            assert!(parsed.is_ok(), "failed to parse {:?}", input);
            let (_, fields) = parsed.unwrap();
            assert_eq!(fields, expected, "unexpected fields for {:?}", input);
        }
    }
}