1use crate::{
2 Dataset, DynQuery, Expression, TableRef,
3 writer::{Context, SqlWriter},
4};
5use proc_macro2::{TokenStream, TokenTree};
6use quote::{ToTokens, TokenStreamExt, quote};
7use syn::{
8 Ident,
9 parse::{Parse, ParseStream},
10};
11
12#[derive(Debug)]
16pub struct Join<L: Dataset, R: Dataset, E: Expression> {
17 pub join: JoinType,
19 pub lhs: L,
21 pub rhs: R,
23 pub on: Option<E>,
25}
26
/// Kind of join between two datasets.
///
/// The keyword sequences listed on each variant are the ones accepted by the
/// `syn::parse::Parse` implementation of this type.
///
/// Besides being `Copy` and `Debug`, values can be compared for equality and
/// hashed, so they can be used in assertions and as map / set keys.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Bare `JOIN` with no qualifier. This is the default value.
    #[default]
    Default,
    /// `INNER JOIN`.
    Inner,
    /// `FULL OUTER JOIN` or `OUTER JOIN`.
    Outer,
    /// `LEFT OUTER JOIN` or `LEFT JOIN`.
    Left,
    /// `RIGHT OUTER JOIN` or `RIGHT JOIN`.
    Right,
    /// `CROSS JOIN`.
    Cross,
    /// `NATURAL JOIN`.
    Natural,
}
40
41impl<L: Dataset, R: Dataset, E: Expression> Dataset for Join<L, R, E> {
42 fn qualified_columns() -> bool
43 where
44 Self: Sized,
45 {
46 true
47 }
48 fn write_query(&self, writer: &dyn SqlWriter, context: &mut Context, out: &mut DynQuery) {
49 writer.write_join(
50 context,
51 out,
52 &Join {
53 join: self.join,
54 lhs: &self.lhs,
55 rhs: &self.rhs,
56 on: self.on.as_ref().map(|v| v as &dyn Expression),
57 },
58 );
59 }
60
61 fn table_ref(&self) -> TableRef {
62 let mut result = self.lhs.table_ref();
63 let other = self.rhs.table_ref();
64 result.name = Default::default();
65 if result.schema != other.schema {
66 result.schema = Default::default();
67 }
68 result.alias = Default::default();
69 result
70 }
71}
72
73impl Parse for JoinType {
74 fn parse(input: ParseStream) -> syn::Result<Self> {
75 let tokens = input.cursor().token_stream().into_iter().map(|t| match t {
76 TokenTree::Ident(ident) => ident.to_string(),
77 _ => "".to_string(),
78 });
79 let patterns: &[(&[&str], JoinType)] = &[
80 (&["JOIN"], JoinType::Default),
81 (&["INNER", "JOIN"], JoinType::Inner),
82 (&["FULL", "OUTER", "JOIN"], JoinType::Outer),
83 (&["OUTER", "JOIN"], JoinType::Outer),
84 (&["LEFT", "OUTER", "JOIN"], JoinType::Left),
85 (&["LEFT", "JOIN"], JoinType::Left),
86 (&["RIGHT", "OUTER", "JOIN"], JoinType::Right),
87 (&["RIGHT", "JOIN"], JoinType::Right),
88 (&["CROSS", "JOIN"], JoinType::Cross),
89 (&["NATURAL", "JOIN"], JoinType::Natural),
90 ];
91 for (keywords, join_type) in patterns {
92 let it = tokens.clone().take(keywords.len());
93 if it.eq(keywords.iter().copied()) {
94 for _ in 0..keywords.len() {
95 input.parse::<Ident>().expect(&format!(
96 "Unexpected error, the input should contain {:?} as next Ident tokens at this point",
97 keywords
98 ));
99 }
100 return Ok(*join_type);
101 }
102 }
103 Err(syn::Error::new(input.span(), "Not a join keyword"))
104 }
105}
106
107impl ToTokens for JoinType {
108 fn to_tokens(&self, tokens: &mut TokenStream) {
109 tokens.append_all(match self {
110 JoinType::Default => quote! { ::tank::JoinType::Default },
111 JoinType::Inner => quote! { ::tank::JoinType::Inner },
112 JoinType::Outer => quote! { ::tank::JoinType::Outer },
113 JoinType::Left => quote! { ::tank::JoinType::Left },
114 JoinType::Right => quote! { ::tank::JoinType::Right },
115 JoinType::Cross => quote! { ::tank::JoinType::Cross },
116 JoinType::Natural => quote! { ::tank::JoinType::Natural },
117 });
118 }
119}