use nom::branch::alt;
use nom::bytes::complete::tag;
use nom::character::complete::{anychar, digit1, multispace0, multispace1};
use nom::combinator::{map, map_opt, opt, recognize};
use nom::multi::{many1, separated_list1};
use nom::sequence::{delimited, preceded, terminated, tuple};
use nom::IResult;
use std::fmt::{write, Display, Formatter};

use base::error::ParseSQLError;
use base::{CommonParser, OrderType};
12
/// One element of a parenthesized key-part list (see `KeyPart::parse`):
/// what is being keyed on, plus an optional sort order.
///
/// NOTE(review): presumably models MySQL's `key_part` in index/key
/// definitions — confirm against callers.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct KeyPart {
    /// The keyed item: a column (optionally with a length) or a
    /// parenthesized expression.
    pub r#type: KeyPartType,
    /// Sort order parsed after the item via `OrderType::parse`; `None` when
    /// no order was written.
    pub order: Option<OrderType>,
}
19
20impl Display for KeyPart {
21 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22 write!(f, "{}", self.r#type);
23 if let Some(order) = &self.order {
24 write!(f, " {}", order);
25 }
26 Ok(())
27 }
28}
29
30impl KeyPart {
31 pub fn parse(i: &str) -> IResult<&str, Vec<KeyPart>, ParseSQLError<&str>> {
33 map(
34 tuple((
35 multispace0,
36 delimited(
37 tag("("),
38 delimited(
39 multispace0,
40 many1(map(
41 terminated(Self::parse_item, opt(CommonParser::ws_sep_comma)),
42 |e| e,
43 )),
44 multispace0,
45 ),
46 tag(")"),
47 ),
48 )),
49 |(_, val)| val,
50 )(i)
51 }
52
53 fn parse_item(i: &str) -> IResult<&str, KeyPart, ParseSQLError<&str>> {
55 map(
56 tuple((
57 KeyPartType::parse,
58 opt(map(
59 tuple((multispace1, OrderType::parse, multispace0)),
60 |(_, order, _)| order,
61 )),
62 )),
63 |(r#type, order)| KeyPart { r#type, order },
64 )(i)
65 }
66
67 pub fn format_list(key_parts: &[KeyPart]) -> String {
68 let key_parts = key_parts
69 .iter()
70 .map(|x| x.to_string())
71 .collect::<Vec<String>>()
72 .join(", ");
73 format!("({})", key_parts)
74 }
75}
76
/// What a `KeyPart` keys on.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum KeyPartType {
    /// A column identifier with an optional length, written `col` or
    /// `col(len)`.
    ///
    /// NOTE(review): `length` presumably is the index prefix length — confirm.
    ColumnNameWithLength {
        col_name: String,
        length: Option<usize>,
    },
    /// A parenthesized expression; `expr` holds the text between the outer
    /// parentheses, and `Display` wraps it back as `(expr)`.
    Expr {
        expr: String,
    },
}
88
89impl Display for KeyPartType {
90 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
91 match *self {
92 KeyPartType::ColumnNameWithLength {
93 ref col_name,
94 ref length,
95 } => {
96 if let Some(length) = length {
97 write!(f, "{}({})", col_name, length)
98 } else {
99 write!(f, "{}", col_name)
100 }
101 }
102 KeyPartType::Expr { ref expr } => write!(f, "({})", expr),
103 }
104 }
105}
106
107impl KeyPartType {
108 fn parse(i: &str) -> IResult<&str, KeyPartType, ParseSQLError<&str>> {
109 let col_name_with_length = tuple((
111 CommonParser::sql_identifier,
112 multispace0,
113 opt(delimited(
114 tag("("),
115 map(digit1, |digit_str: &str| {
116 digit_str.parse::<usize>().unwrap()
117 }),
118 tag(")"),
119 )),
120 ));
121
122 let expr = preceded(
123 multispace0,
124 delimited(tag("("), recognize(many1(anychar)), tag(")")),
125 );
126
127 alt((
128 map(col_name_with_length, |(col_name, _, length)| {
129 KeyPartType::ColumnNameWithLength {
130 col_name: String::from(col_name),
131 length,
132 }
133 }),
134 map(expr, |expr| KeyPartType::Expr {
135 expr: String::from(expr),
136 }),
137 ))(i)
138 }
139}
140
#[cfg(test)]
mod tests {
    use base::{KeyPart, KeyPartType};

    /// `col(len)` parses into a column with a length.
    #[test]
    fn parse_key_part_type() {
        let res1 = KeyPartType::parse("column_name(10)");
        let exp = KeyPartType::ColumnNameWithLength {
            col_name: "column_name".to_string(),
            length: Some(10),
        };
        assert!(res1.is_ok());
        assert_eq!(res1.unwrap().1, exp);
    }

    /// A bare column name parses with no length.
    #[test]
    fn parse_key_part_type_without_length() {
        let res = KeyPartType::parse("column_name");
        let exp = KeyPartType::ColumnNameWithLength {
            col_name: "column_name".to_string(),
            length: None,
        };
        assert!(res.is_ok());
        assert_eq!(res.unwrap().1, exp);
    }

    /// A single-item parenthesized list parses into one `KeyPart`.
    #[test]
    fn parse_key_part() {
        let res1 = KeyPart::parse("(column_name(10))");

        let key_part = KeyPartType::ColumnNameWithLength {
            col_name: "column_name".to_string(),
            length: Some(10),
        };
        let exp = vec![KeyPart {
            r#type: key_part,
            order: None,
        }];
        assert!(res1.is_ok());
        assert_eq!(res1.unwrap().1, exp);
    }

    /// `format_list` renders each variant and joins them with `", "`.
    #[test]
    fn format_key_part_list() {
        let parts = vec![
            KeyPart {
                r#type: KeyPartType::ColumnNameWithLength {
                    col_name: "a".to_string(),
                    length: Some(10),
                },
                order: None,
            },
            KeyPart {
                r#type: KeyPartType::ColumnNameWithLength {
                    col_name: "b".to_string(),
                    length: None,
                },
                order: None,
            },
            KeyPart {
                r#type: KeyPartType::Expr {
                    expr: "a + b".to_string(),
                },
                order: None,
            },
        ];
        assert_eq!(KeyPart::format_list(&parts), "(a(10), b, (a + b))");
        assert_eq!(KeyPart::format_list(&[]), "()");
    }
}