sqlparser_mysql/base/
join.rs

1use std::fmt;
2use std::str;
3
4use nom::branch::alt;
5use nom::bytes::complete::{tag, tag_no_case};
6use nom::character::complete::{multispace0, multispace1};
7use nom::combinator::{map, opt};
8use nom::sequence::{delimited, preceded, terminated, tuple};
9use nom::IResult;
10
11use base::column::Column;
12use base::condition::ConditionExpression;
13use base::error::ParseSQLError;
14use base::table::Table;
15use base::CommonParser;
16use dms::SelectStatement;
17
/// A single join clause: `<operator> <right side> <constraint>`,
/// e.g. `INNER JOIN tagging ON tags.id = tagging.tag_id`.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct JoinClause {
    /// The join keyword(s), e.g. `INNER JOIN`.
    pub operator: JoinOperator,
    /// What is being joined: a table, a table list, a nested select or a nested join.
    pub right: JoinRightSide,
    /// The `ON` / `USING` constraint of the join.
    pub constraint: JoinConstraint,
}
25
26impl JoinClause {
27    pub fn parse(i: &str) -> IResult<&str, JoinClause, ParseSQLError<&str>> {
28        let (remaining_input, (_, _natural, operator, _, right, _, constraint)) = tuple((
29            multispace0,
30            opt(terminated(tag_no_case("NATURAL"), multispace1)),
31            JoinOperator::parse,
32            multispace1,
33            JoinRightSide::parse,
34            multispace1,
35            JoinConstraint::parse,
36        ))(i)?;
37
38        Ok((
39            remaining_input,
40            JoinClause {
41                operator,
42                right,
43                constraint,
44            },
45        ))
46    }
47}
48
49impl fmt::Display for JoinClause {
50    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
51        write!(f, "{}", self.operator)?;
52        write!(f, " {}", self.right)?;
53        write!(f, " {}", self.constraint)?;
54        Ok(())
55    }
56}
57
/// Right-hand side of a [`JoinClause`], i.e. what follows the [`JoinOperator`].
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum JoinRightSide {
    /// A single table reference.
    Table(Table),
    /// A parenthesized, comma-separated (and implicitly joined) sequence of tables.
    Tables(Vec<Table>),
    /// A parenthesized nested selection with an optional alias, represented as (query, alias).
    NestedSelect(Box<SelectStatement>, Option<String>),
    /// A parenthesized nested join clause.
    NestedJoin(Box<JoinClause>),
}
70
71impl JoinRightSide {
72    pub fn parse(i: &str) -> IResult<&str, JoinRightSide, ParseSQLError<&str>> {
73        let nested_select = map(
74            tuple((
75                delimited(tag("("), SelectStatement::nested_selection, tag(")")),
76                opt(CommonParser::as_alias),
77            )),
78            |t| JoinRightSide::NestedSelect(Box::new(t.0), t.1.map(String::from)),
79        );
80        let nested_join = map(delimited(tag("("), JoinClause::parse, tag(")")), |nj| {
81            JoinRightSide::NestedJoin(Box::new(nj))
82        });
83        let table = map(Table::table_reference, JoinRightSide::Table);
84        let tables = map(delimited(tag("("), Table::table_list, tag(")")), |tables| {
85            JoinRightSide::Tables(tables)
86        });
87        alt((nested_select, nested_join, table, tables))(i)
88    }
89}
90
91impl fmt::Display for JoinRightSide {
92    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93        match *self {
94            JoinRightSide::Table(ref t) => write!(f, "{}", t)?,
95            JoinRightSide::NestedSelect(ref q, ref a) => {
96                write!(f, "({})", q)?;
97                if a.is_some() {
98                    write!(f, " AS {}", a.as_ref().unwrap())?;
99                }
100            }
101            JoinRightSide::NestedJoin(ref jc) => write!(f, "({})", jc)?,
102            _ => unimplemented!(),
103        }
104        Ok(())
105    }
106}
107
/// Join operator keywords, matched case-insensitively by [`JoinOperator::parse`]:
/// - `JOIN`
/// - `LEFT JOIN`
/// - `LEFT OUTER JOIN`
/// - `RIGHT JOIN`
/// - `INNER JOIN`
/// - `CROSS JOIN`
/// - `STRAIGHT_JOIN`
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum JoinOperator {
    /// `JOIN`
    Join,
    /// `LEFT JOIN`
    LeftJoin,
    /// `LEFT OUTER JOIN`
    LeftOuterJoin,
    /// `RIGHT JOIN`
    RightJoin,
    /// `INNER JOIN`
    InnerJoin,
    /// `CROSS JOIN`
    CrossJoin,
    /// `STRAIGHT_JOIN` (a single keyword in MySQL)
    StraightJoin,
}
126
127impl JoinOperator {
128    // Parse binary comparison operators
129    pub fn parse(i: &str) -> IResult<&str, JoinOperator, ParseSQLError<&str>> {
130        alt((
131            map(tag_no_case("JOIN"), |_| JoinOperator::Join),
132            map(
133                tuple((tag_no_case("LEFT"), multispace1, tag_no_case("JOIN"))),
134                |_| JoinOperator::LeftJoin,
135            ),
136            map(
137                tuple((
138                    tag_no_case("LEFT"),
139                    multispace1,
140                    tag_no_case("OUTER"),
141                    multispace1,
142                    tag_no_case("JOIN"),
143                )),
144                |_| JoinOperator::LeftOuterJoin,
145            ),
146            map(
147                tuple((tag_no_case("RIGHT"), multispace1, tag_no_case("JOIN"))),
148                |_| JoinOperator::RightJoin,
149            ),
150            map(
151                tuple((tag_no_case("INNER"), multispace1, tag_no_case("JOIN"))),
152                |_| JoinOperator::InnerJoin,
153            ),
154            map(
155                tuple((tag_no_case("CROSS"), multispace1, tag_no_case("JOIN"))),
156                |_| JoinOperator::CrossJoin,
157            ),
158            map(tag_no_case("STRAIGHT_JOIN"), |_| JoinOperator::StraightJoin),
159        ))(i)
160    }
161}
162
163impl fmt::Display for JoinOperator {
164    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
165        match *self {
166            JoinOperator::Join => write!(f, "JOIN")?,
167            JoinOperator::LeftJoin => write!(f, "LEFT JOIN")?,
168            JoinOperator::LeftOuterJoin => write!(f, "LEFT OUTER JOIN")?,
169            JoinOperator::RightJoin => write!(f, "RIGHT JOIN")?,
170            JoinOperator::InnerJoin => write!(f, "INNER JOIN")?,
171            JoinOperator::CrossJoin => write!(f, "CROSS JOIN")?,
172            JoinOperator::StraightJoin => write!(f, "STRAIGHT JOIN")?,
173        }
174        Ok(())
175    }
176}
177
/// Constraint attached to a join:
/// - `ON <condition>`
/// - `USING (<column>, ...)`
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum JoinConstraint {
    /// `ON <condition>`; the parser also accepts the condition wrapped in parentheses.
    On(ConditionExpression),
    /// `USING (<column list>)`.
    Using(Vec<Column>),
}
186
187impl JoinConstraint {
188    pub fn parse(i: &str) -> IResult<&str, JoinConstraint, ParseSQLError<&str>> {
189        let using_clause = map(
190            tuple((
191                tag_no_case("USING"),
192                multispace1,
193                delimited(
194                    terminated(tag("("), multispace0),
195                    Column::field_list,
196                    preceded(multispace0, tag(")")),
197                ),
198            )),
199            |t| JoinConstraint::Using(t.2),
200        );
201
202        let on_condition = alt((
203            delimited(
204                terminated(tag("("), multispace0),
205                ConditionExpression::condition_expr,
206                preceded(multispace0, tag(")")),
207            ),
208            ConditionExpression::condition_expr,
209        ));
210        let on_clause = map(tuple((tag_no_case("ON"), multispace1, on_condition)), |t| {
211            JoinConstraint::On(t.2)
212        });
213
214        alt((using_clause, on_clause))(i)
215    }
216}
217
218impl fmt::Display for JoinConstraint {
219    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
220        match *self {
221            JoinConstraint::On(ref ce) => write!(f, "ON {}", ce)?,
222            JoinConstraint::Using(ref columns) => write!(
223                f,
224                "USING ({})",
225                columns
226                    .iter()
227                    .map(|c| format!("{}", c))
228                    .collect::<Vec<_>>()
229                    .join(", ")
230            )?,
231        }
232        Ok(())
233    }
234}
235
#[cfg(test)]
mod tests {
    use base::condition::ConditionBase::Field;
    use base::condition::ConditionExpression::Base;
    use base::condition::{ConditionExpression, ConditionTree};
    use base::Operator;

    use super::*;

    /// `INNER JOIN ... ON a = b` parses into the expected AST and
    /// renders back to exactly the same text.
    #[test]
    fn parse_join() {
        let sql = "INNER JOIN tagging ON tags.id = tagging.tag_id";

        let expected = JoinClause {
            operator: JoinOperator::InnerJoin,
            right: JoinRightSide::Table(Table::from("tagging")),
            constraint: JoinConstraint::On(ConditionExpression::ComparisonOp(ConditionTree {
                left: Box::new(Base(Field(Column::from("tags.id")))),
                right: Box::new(Base(Field(Column::from("tagging.tag_id")))),
                operator: Operator::Equal,
            })),
        };

        let (_, parsed) = JoinClause::parse(sql).unwrap();
        assert_eq!(parsed, expected);
        assert_eq!(parsed.to_string(), sql);
    }
}