sqlparser_mysql/base/
field.rs

1use std::fmt;
2use std::fmt::Display;
3
4use nom::branch::alt;
5use nom::bytes::complete::tag;
6use nom::character::complete::multispace0;
7use nom::combinator::{map, opt};
8use nom::multi::{many0, many1};
9use nom::sequence::{delimited, separated_pair, terminated};
10use nom::IResult;
11
12use base::arithmetic::ArithmeticExpression;
13use base::column::Column;
14use base::error::ParseSQLError;
15use base::literal::LiteralExpression;
16use base::table::Table;
17use base::{CommonParser, DisplayUtil, Literal};
18
/// A single entry in a field list, as produced by [`FieldDefinitionExpression::parse`].
#[derive(Default, Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum FieldDefinitionExpression {
    /// A bare `*`. This is also the `Default` value.
    #[default]
    All,
    /// `table.*`; holds only the table name.
    AllInTable(String),
    /// A column reference.
    Col(Column),
    /// A computed value: an arithmetic expression or a literal.
    Value(FieldValueExpression),
}
27
28impl FieldDefinitionExpression {
29    /// Parse list of column/field definitions.
30    pub fn parse(i: &str) -> IResult<&str, Vec<FieldDefinitionExpression>, ParseSQLError<&str>> {
31        many0(terminated(
32            alt((
33                map(tag("*"), |_| FieldDefinitionExpression::All),
34                map(terminated(Table::table_reference, tag(".*")), |t| {
35                    FieldDefinitionExpression::AllInTable(t.name.clone())
36                }),
37                map(ArithmeticExpression::parse, |expr| {
38                    FieldDefinitionExpression::Value(FieldValueExpression::Arithmetic(expr))
39                }),
40                map(LiteralExpression::parse, |lit| {
41                    FieldDefinitionExpression::Value(FieldValueExpression::Literal(lit))
42                }),
43                map(Column::parse, FieldDefinitionExpression::Col),
44            )),
45            opt(CommonParser::ws_sep_comma),
46        ))(i)
47    }
48
49    pub fn from_column_str(cols: &[&str]) -> Vec<FieldDefinitionExpression> {
50        cols.iter()
51            .map(|c| FieldDefinitionExpression::Col(Column::from(*c)))
52            .collect()
53    }
54}
55
56impl Display for FieldDefinitionExpression {
57    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
58        match *self {
59            FieldDefinitionExpression::All => write!(f, "*"),
60            FieldDefinitionExpression::AllInTable(ref table) => {
61                write!(f, "{}.*", DisplayUtil::escape_if_keyword(table))
62            }
63            FieldDefinitionExpression::Col(ref col) => write!(f, "{}", col),
64            FieldDefinitionExpression::Value(ref val) => write!(f, "{}", val),
65        }
66    }
67}
68
/// A value-producing expression. It is used as the right-hand side of a `column = value`
/// assignment and inside [`FieldDefinitionExpression::Value`].
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum FieldValueExpression {
    /// An arithmetic expression such as `a + 1`.
    Arithmetic(ArithmeticExpression),
    /// A literal value.
    Literal(LiteralExpression),
}
74
75impl FieldValueExpression {
76    fn parse(i: &str) -> IResult<&str, FieldValueExpression, ParseSQLError<&str>> {
77        alt((
78            map(Literal::parse, |l| {
79                FieldValueExpression::Literal(LiteralExpression {
80                    value: l,
81                    alias: None,
82                })
83            }),
84            map(ArithmeticExpression::parse, |ae| {
85                FieldValueExpression::Arithmetic(ae)
86            }),
87        ))(i)
88    }
89
90    fn assignment_expr(
91        i: &str,
92    ) -> IResult<&str, (Column, FieldValueExpression), ParseSQLError<&str>> {
93        separated_pair(
94            Column::without_alias,
95            delimited(multispace0, tag("="), multispace0),
96            Self::parse,
97        )(i)
98    }
99
100    pub fn assignment_expr_list(
101        i: &str,
102    ) -> IResult<&str, Vec<(Column, FieldValueExpression)>, ParseSQLError<&str>> {
103        many1(terminated(
104            Self::assignment_expr,
105            opt(CommonParser::ws_sep_comma),
106        ))(i)
107    }
108}
109
110impl Display for FieldValueExpression {
111    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
112        match *self {
113            FieldValueExpression::Arithmetic(ref expr) => write!(f, "{}", expr),
114            FieldValueExpression::Literal(ref lit) => write!(f, "{}", lit),
115        }
116    }
117}
118
#[cfg(test)]
mod tests {
    use base::arithmetic::ArithmeticBase;
    use base::arithmetic::ArithmeticExpression;
    use base::arithmetic::ArithmeticOperator::{Add, Multiply};
    use base::{FieldDefinitionExpression, FieldValueExpression, Literal};

    #[test]
    fn parse_field_definition_expression() {
        // `*`
        let (rest, fields) = FieldDefinitionExpression::parse("*").unwrap();
        assert_eq!(rest, "");
        assert_eq!(fields, vec![FieldDefinitionExpression::All]);

        // `table.*`
        let (rest, fields) = FieldDefinitionExpression::parse("tbl_name.*").unwrap();
        assert_eq!(rest, "");
        assert_eq!(
            fields,
            vec![FieldDefinitionExpression::AllInTable(
                "tbl_name".to_string()
            )]
        );

        // Plain comma-separated columns.
        let (rest, fields) = FieldDefinitionExpression::parse("age, name, score").unwrap();
        let expected = vec![
            FieldDefinitionExpression::Col("age".into()),
            FieldDefinitionExpression::Col("name".into()),
            FieldDefinitionExpression::Col("score".into()),
        ];
        assert_eq!(rest, "");
        assert_eq!(fields, expected);

        // Arithmetic expressions, with and without an alias.
        let (rest, fields) =
            FieldDefinitionExpression::parse("1+2, price * count as total_count").unwrap();
        let expected = vec![
            FieldDefinitionExpression::Value(FieldValueExpression::Arithmetic(
                ArithmeticExpression::new(
                    Add,
                    ArithmeticBase::Scalar(Literal::Integer(1)),
                    ArithmeticBase::Scalar(Literal::Integer(2)),
                    None,
                ),
            )),
            FieldDefinitionExpression::Value(FieldValueExpression::Arithmetic(
                ArithmeticExpression::new(
                    Multiply,
                    ArithmeticBase::Column("price".into()),
                    ArithmeticBase::Column("count".into()),
                    Some(String::from("total_count")),
                ),
            )),
        ];
        assert_eq!(rest, "");
        assert_eq!(fields, expected);
    }

    #[test]
    fn from_column_str_builds_plain_columns() {
        let fields = FieldDefinitionExpression::from_column_str(&["age", "name"]);
        assert_eq!(
            fields,
            vec![
                FieldDefinitionExpression::Col("age".into()),
                FieldDefinitionExpression::Col("name".into()),
            ]
        );
        assert!(FieldDefinitionExpression::from_column_str(&[]).is_empty());
    }

    #[test]
    fn display_field_definition_expression() {
        assert_eq!(FieldDefinitionExpression::All.to_string(), "*");
        assert_eq!(
            FieldDefinitionExpression::AllInTable("tbl_name".to_string()).to_string(),
            "tbl_name.*"
        );
    }
}