1use std::fmt;
2use std::str;
3
4use nom::branch::alt;
5use nom::bytes::complete::{tag, tag_no_case};
6use nom::character::complete::{multispace0, multispace1};
7use nom::combinator::{map, opt};
8use nom::sequence::{delimited, preceded, terminated, tuple};
9use nom::IResult;
10
11use base::column::Column;
12use base::condition::ConditionExpression;
13use base::error::ParseSQLError;
14use base::table::Table;
15use base::CommonParser;
16use dms::SelectStatement;
17
/// One `JOIN` clause of a query: which kind of join, what is joined on the
/// right-hand side, and the `ON` / `USING` constraint that links the two sides.
///
/// NOTE(review): `JoinClause::parse` accepts an optional leading `NATURAL`
/// keyword, but there is no field recording it, so it is lost after parsing.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct JoinClause {
    /// The join keyword(s), e.g. `INNER JOIN` or `LEFT OUTER JOIN`.
    pub operator: JoinOperator,
    /// The table, table list, subquery or nested join being joined.
    pub right: JoinRightSide,
    /// The `ON <condition>` or `USING (<columns>)` part of the clause.
    pub constraint: JoinConstraint,
}
25
26impl JoinClause {
27 pub fn parse(i: &str) -> IResult<&str, JoinClause, ParseSQLError<&str>> {
28 let (remaining_input, (_, _natural, operator, _, right, _, constraint)) = tuple((
29 multispace0,
30 opt(terminated(tag_no_case("NATURAL"), multispace1)),
31 JoinOperator::parse,
32 multispace1,
33 JoinRightSide::parse,
34 multispace1,
35 JoinConstraint::parse,
36 ))(i)?;
37
38 Ok((
39 remaining_input,
40 JoinClause {
41 operator,
42 right,
43 constraint,
44 },
45 ))
46 }
47}
48
49impl fmt::Display for JoinClause {
50 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
51 write!(f, "{}", self.operator)?;
52 write!(f, " {}", self.right)?;
53 write!(f, " {}", self.constraint)?;
54 Ok(())
55 }
56}
57
/// The right-hand side of a `JOIN` clause.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum JoinRightSide {
    /// A single table reference, as parsed by `Table::table_reference`.
    Table(Table),
    /// A parenthesised table list such as `(t1, t2)`, as parsed by `Table::table_list`.
    Tables(Vec<Table>),
    /// A parenthesised subquery with an optional alias, e.g. `(SELECT ...) AS x`.
    NestedSelect(Box<SelectStatement>, Option<String>),
    /// A parenthesised join clause, e.g. `(JOIN t ON ...)`.
    NestedJoin(Box<JoinClause>),
}
70
71impl JoinRightSide {
72 pub fn parse(i: &str) -> IResult<&str, JoinRightSide, ParseSQLError<&str>> {
73 let nested_select = map(
74 tuple((
75 delimited(tag("("), SelectStatement::nested_selection, tag(")")),
76 opt(CommonParser::as_alias),
77 )),
78 |t| JoinRightSide::NestedSelect(Box::new(t.0), t.1.map(String::from)),
79 );
80 let nested_join = map(delimited(tag("("), JoinClause::parse, tag(")")), |nj| {
81 JoinRightSide::NestedJoin(Box::new(nj))
82 });
83 let table = map(Table::table_reference, JoinRightSide::Table);
84 let tables = map(delimited(tag("("), Table::table_list, tag(")")), |tables| {
85 JoinRightSide::Tables(tables)
86 });
87 alt((nested_select, nested_join, table, tables))(i)
88 }
89}
90
91impl fmt::Display for JoinRightSide {
92 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93 match *self {
94 JoinRightSide::Table(ref t) => write!(f, "{}", t)?,
95 JoinRightSide::NestedSelect(ref q, ref a) => {
96 write!(f, "({})", q)?;
97 if a.is_some() {
98 write!(f, " AS {}", a.as_ref().unwrap())?;
99 }
100 }
101 JoinRightSide::NestedJoin(ref jc) => write!(f, "({})", jc)?,
102 _ => unimplemented!(),
103 }
104 Ok(())
105 }
106}
107
/// The join keyword(s) introducing a `JOIN` clause.
///
/// Only the spellings recognised by `JoinOperator::parse` are representable
/// (keyword matching is case-insensitive there).
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum JoinOperator {
    /// `JOIN`
    Join,
    /// `LEFT JOIN`
    LeftJoin,
    /// `LEFT OUTER JOIN`
    LeftOuterJoin,
    /// `RIGHT JOIN`
    RightJoin,
    /// `INNER JOIN`
    InnerJoin,
    /// `CROSS JOIN`
    CrossJoin,
    /// `STRAIGHT_JOIN` (a single keyword containing an underscore)
    StraightJoin,
}
126
127impl JoinOperator {
128 pub fn parse(i: &str) -> IResult<&str, JoinOperator, ParseSQLError<&str>> {
130 alt((
131 map(tag_no_case("JOIN"), |_| JoinOperator::Join),
132 map(
133 tuple((tag_no_case("LEFT"), multispace1, tag_no_case("JOIN"))),
134 |_| JoinOperator::LeftJoin,
135 ),
136 map(
137 tuple((
138 tag_no_case("LEFT"),
139 multispace1,
140 tag_no_case("OUTER"),
141 multispace1,
142 tag_no_case("JOIN"),
143 )),
144 |_| JoinOperator::LeftOuterJoin,
145 ),
146 map(
147 tuple((tag_no_case("RIGHT"), multispace1, tag_no_case("JOIN"))),
148 |_| JoinOperator::RightJoin,
149 ),
150 map(
151 tuple((tag_no_case("INNER"), multispace1, tag_no_case("JOIN"))),
152 |_| JoinOperator::InnerJoin,
153 ),
154 map(
155 tuple((tag_no_case("CROSS"), multispace1, tag_no_case("JOIN"))),
156 |_| JoinOperator::CrossJoin,
157 ),
158 map(tag_no_case("STRAIGHT_JOIN"), |_| JoinOperator::StraightJoin),
159 ))(i)
160 }
161}
162
163impl fmt::Display for JoinOperator {
164 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
165 match *self {
166 JoinOperator::Join => write!(f, "JOIN")?,
167 JoinOperator::LeftJoin => write!(f, "LEFT JOIN")?,
168 JoinOperator::LeftOuterJoin => write!(f, "LEFT OUTER JOIN")?,
169 JoinOperator::RightJoin => write!(f, "RIGHT JOIN")?,
170 JoinOperator::InnerJoin => write!(f, "INNER JOIN")?,
171 JoinOperator::CrossJoin => write!(f, "CROSS JOIN")?,
172 JoinOperator::StraightJoin => write!(f, "STRAIGHT JOIN")?,
173 }
174 Ok(())
175 }
176}
177
/// The condition linking the two sides of a join.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum JoinConstraint {
    /// `ON <condition>`. The parser also accepts the condition wrapped in parentheses.
    On(ConditionExpression),
    /// `USING (<column>, ...)`
    Using(Vec<Column>),
}
186
187impl JoinConstraint {
188 pub fn parse(i: &str) -> IResult<&str, JoinConstraint, ParseSQLError<&str>> {
189 let using_clause = map(
190 tuple((
191 tag_no_case("USING"),
192 multispace1,
193 delimited(
194 terminated(tag("("), multispace0),
195 Column::field_list,
196 preceded(multispace0, tag(")")),
197 ),
198 )),
199 |t| JoinConstraint::Using(t.2),
200 );
201
202 let on_condition = alt((
203 delimited(
204 terminated(tag("("), multispace0),
205 ConditionExpression::condition_expr,
206 preceded(multispace0, tag(")")),
207 ),
208 ConditionExpression::condition_expr,
209 ));
210 let on_clause = map(tuple((tag_no_case("ON"), multispace1, on_condition)), |t| {
211 JoinConstraint::On(t.2)
212 });
213
214 alt((using_clause, on_clause))(i)
215 }
216}
217
218impl fmt::Display for JoinConstraint {
219 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
220 match *self {
221 JoinConstraint::On(ref ce) => write!(f, "ON {}", ce)?,
222 JoinConstraint::Using(ref columns) => write!(
223 f,
224 "USING ({})",
225 columns
226 .iter()
227 .map(|c| format!("{}", c))
228 .collect::<Vec<_>>()
229 .join(", ")
230 )?,
231 }
232 Ok(())
233 }
234}
235
#[cfg(test)]
mod tests {
    use base::condition::ConditionBase::Field;
    use base::condition::ConditionExpression::Base;
    use base::condition::{ConditionExpression, ConditionTree};
    use base::Operator;

    use super::*;

    /// An `INNER JOIN ... ON a = b` clause parses into the expected AST and
    /// formats back to the exact input text.
    #[test]
    fn parse_join() {
        let sql = "INNER JOIN tagging ON tags.id = tagging.tag_id";

        let (_, parsed) = JoinClause::parse(sql).unwrap();

        let on_condition = ConditionExpression::ComparisonOp(ConditionTree {
            left: Box::new(Base(Field(Column::from("tags.id")))),
            right: Box::new(Base(Field(Column::from("tagging.tag_id")))),
            operator: Operator::Equal,
        });
        let expected = JoinClause {
            operator: JoinOperator::InnerJoin,
            right: JoinRightSide::Table(Table::from("tagging")),
            constraint: JoinConstraint::On(on_condition),
        };

        assert_eq!(parsed, expected);
        assert_eq!(sql, parsed.to_string());
    }
}