sqlparser_mysql/base/
key_part.rs

use nom::branch::alt;
use nom::bytes::complete::tag;
use nom::character::complete::{anychar, digit1, multispace0, multispace1};
use nom::combinator::{map, map_opt, opt, recognize};
use nom::error::{ErrorKind, ParseError};
use nom::multi::many1;
use nom::sequence::{delimited, preceded, terminated, tuple};
use nom::IResult;
use std::fmt::{write, Display, Formatter};

use base::error::ParseSQLError;
use base::{CommonParser, OrderType};
12
13/// parse `key_part: {col_name [(length)] | (expr)} [ASC | DESC]`
14#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
15pub struct KeyPart {
16    pub r#type: KeyPartType,
17    pub order: Option<OrderType>,
18}
19
20impl Display for KeyPart {
21    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22        write!(f, "{}", self.r#type);
23        if let Some(order) = &self.order {
24            write!(f, " {}", order);
25        }
26        Ok(())
27    }
28}
29
30impl KeyPart {
31    ///parse list of key_part `(key_part,...)`
32    pub fn parse(i: &str) -> IResult<&str, Vec<KeyPart>, ParseSQLError<&str>> {
33        map(
34            tuple((
35                multispace0,
36                delimited(
37                    tag("("),
38                    delimited(
39                        multispace0,
40                        many1(map(
41                            terminated(Self::parse_item, opt(CommonParser::ws_sep_comma)),
42                            |e| e,
43                        )),
44                        multispace0,
45                    ),
46                    tag(")"),
47                ),
48            )),
49            |(_, val)| val,
50        )(i)
51    }
52
53    /// parse `key_part: {col_name [(length)] | (expr)} [ASC | DESC]`
54    fn parse_item(i: &str) -> IResult<&str, KeyPart, ParseSQLError<&str>> {
55        map(
56            tuple((
57                KeyPartType::parse,
58                opt(map(
59                    tuple((multispace1, OrderType::parse, multispace0)),
60                    |(_, order, _)| order,
61                )),
62            )),
63            |(r#type, order)| KeyPart { r#type, order },
64        )(i)
65    }
66
67    pub fn format_list(key_parts: &[KeyPart]) -> String {
68        let key_parts = key_parts
69            .iter()
70            .map(|x| x.to_string())
71            .collect::<Vec<String>>()
72            .join(", ");
73        format!("({})", key_parts)
74    }
75}
76
77/// parse `{col_name [(length)] | (expr)}`
78#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
79pub enum KeyPartType {
80    ColumnNameWithLength {
81        col_name: String,
82        length: Option<usize>,
83    },
84    Expr {
85        expr: String,
86    },
87}
88
89impl Display for KeyPartType {
90    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
91        match *self {
92            KeyPartType::ColumnNameWithLength {
93                ref col_name,
94                ref length,
95            } => {
96                if let Some(length) = length {
97                    write!(f, "{}({})", col_name, length)
98                } else {
99                    write!(f, "{}", col_name)
100                }
101            }
102            KeyPartType::Expr { ref expr } => write!(f, "({})", expr),
103        }
104    }
105}
106
107impl KeyPartType {
108    fn parse(i: &str) -> IResult<&str, KeyPartType, ParseSQLError<&str>> {
109        // {col_name [(length)]
110        let col_name_with_length = tuple((
111            CommonParser::sql_identifier,
112            multispace0,
113            opt(delimited(
114                tag("("),
115                map(digit1, |digit_str: &str| {
116                    digit_str.parse::<usize>().unwrap()
117                }),
118                tag(")"),
119            )),
120        ));
121
122        let expr = preceded(
123            multispace0,
124            delimited(tag("("), recognize(many1(anychar)), tag(")")),
125        );
126
127        alt((
128            map(col_name_with_length, |(col_name, _, length)| {
129                KeyPartType::ColumnNameWithLength {
130                    col_name: String::from(col_name),
131                    length,
132                }
133            }),
134            map(expr, |expr| KeyPartType::Expr {
135                expr: String::from(expr),
136            }),
137        ))(i)
138    }
139}
140
#[cfg(test)]
mod tests {
    use base::{KeyPart, KeyPartType};

    /// Expected parse of `column_name(10)`, shared by both tests.
    fn column_name_len_10() -> KeyPartType {
        KeyPartType::ColumnNameWithLength {
            col_name: "column_name".to_string(),
            length: Some(10),
        }
    }

    #[test]
    fn parse_key_part_type() {
        let res1 = KeyPartType::parse("column_name(10)");
        assert!(res1.is_ok());
        let (_, parsed) = res1.unwrap();
        assert_eq!(parsed, column_name_len_10());
    }

    #[test]
    fn parse_key_part() {
        let res1 = KeyPart::parse("(column_name(10))");
        assert!(res1.is_ok());
        let (_, parsed) = res1.unwrap();
        let expected = vec![KeyPart {
            r#type: column_name_len_10(),
            order: None,
        }];
        assert_eq!(parsed, expected);
    }
}