spacetimedb_query_builder/
join.rs1use crate::TableNameStr;
2
3use super::{
4 expr::{format_expr, BoolExpr},
5 table::{CanBeLookupTable, ColumnRef, HasCols, HasIxCols, Table},
6 Query, RawQuery,
7};
8use std::marker::PhantomData;
9
10pub struct IxCol<T, V> {
15 pub(super) col: ColumnRef<T>,
16 _marker: PhantomData<V>,
17}
18
19impl<T, V> IxCol<T, V> {
20 pub fn new(table_name: TableNameStr, column: &'static str) -> Self {
21 Self {
22 col: ColumnRef::new(table_name, column),
23 _marker: PhantomData,
24 }
25 }
26}
27
28impl<T, V> Copy for IxCol<T, V> {}
29impl<T, V> Clone for IxCol<T, V> {
30 fn clone(&self) -> Self {
31 *self
32 }
33}
34
35pub struct IxJoinEq<L, R, V> {
36 pub(super) lhs_col: ColumnRef<L>,
37 pub(super) rhs_col: ColumnRef<R>,
38 _marker: PhantomData<V>,
39}
40
41impl<T, V> IxCol<T, V> {
42 pub fn eq<R: HasIxCols>(self, rhs: IxCol<R, V>) -> IxJoinEq<T, R, V> {
43 IxJoinEq {
44 lhs_col: self.col,
45 rhs_col: rhs.col,
46 _marker: PhantomData,
47 }
48 }
49}
50
51pub struct LeftSemiJoin<L> {
53 pub(super) left_col: ColumnRef<L>,
54 pub(super) right_table: &'static str,
55 pub(super) right_col: &'static str,
56 pub(super) where_expr: Option<BoolExpr<L>>,
57}
58
59pub struct RightSemiJoin<R, L> {
61 pub(super) left_col: ColumnRef<L>,
62 pub(super) right_col: ColumnRef<R>,
63 pub(super) left_where_expr: Option<BoolExpr<L>>,
64 pub(super) right_where_expr: Option<BoolExpr<R>>,
65 _left_marker: PhantomData<L>,
66}
67
68impl<L: HasIxCols> Table<L> {
69 pub fn left_semijoin<R: CanBeLookupTable, V>(
70 self,
71 right: Table<R>,
72 on: impl Fn(&L::IxCols, &R::IxCols) -> IxJoinEq<L, R, V>,
73 ) -> LeftSemiJoin<L> {
74 let join = on(&L::ix_cols(self.name()), &R::ix_cols(right.name()));
75 LeftSemiJoin {
76 left_col: join.lhs_col,
77 right_table: right.name(),
78 right_col: join.rhs_col.column_name(),
79 where_expr: None,
80 }
81 }
82
83 pub fn right_semijoin<R: CanBeLookupTable, V>(
84 self,
85 right: Table<R>,
86 on: impl Fn(&L::IxCols, &R::IxCols) -> IxJoinEq<L, R, V>,
87 ) -> RightSemiJoin<R, L> {
88 let join = on(&L::ix_cols(self.name()), &R::ix_cols(right.name()));
89 RightSemiJoin {
90 left_col: join.lhs_col,
91 right_col: join.rhs_col,
92 left_where_expr: None,
93 right_where_expr: None,
94 _left_marker: PhantomData,
95 }
96 }
97}
98
99impl<L: HasIxCols> super::FromWhere<L> {
100 pub fn left_semijoin<R: CanBeLookupTable, V>(
101 self,
102 right: Table<R>,
103 on: impl Fn(&L::IxCols, &R::IxCols) -> IxJoinEq<L, R, V>,
104 ) -> LeftSemiJoin<L> {
105 let join = on(&L::ix_cols(self.table_name), &R::ix_cols(right.name()));
106 LeftSemiJoin {
107 left_col: join.lhs_col,
108 right_table: right.name(),
109 right_col: join.rhs_col.column_name(),
110 where_expr: Some(self.expr),
111 }
112 }
113
114 pub fn right_semijoin<R: CanBeLookupTable, V>(
115 self,
116 right: Table<R>,
117 on: impl Fn(&L::IxCols, &R::IxCols) -> IxJoinEq<L, R, V>,
118 ) -> RightSemiJoin<R, L> {
119 let join = on(&L::ix_cols(self.table_name), &R::ix_cols(right.name()));
120 RightSemiJoin {
121 left_col: join.lhs_col,
122 right_col: join.rhs_col,
123 left_where_expr: Some(self.expr),
124 right_where_expr: None,
125 _left_marker: PhantomData,
126 }
127 }
128}
129
130impl<L: HasCols> Query<L> for LeftSemiJoin<L> {
131 fn into_sql(self) -> String {
132 self.build().into_sql()
133 }
134}
135
136impl<R: HasCols, L: HasCols> Query<R> for RightSemiJoin<R, L> {
137 fn into_sql(self) -> String {
138 self.build().into_sql()
139 }
140}
141
142impl<L: HasCols> LeftSemiJoin<L> {
144 pub fn r#where<F>(self, f: F) -> Self
145 where
146 F: Fn(&L::Cols) -> BoolExpr<L>,
147 {
148 let extra = f(&L::cols(self.left_col.table_name()));
149 let new = match self.where_expr {
150 Some(existing) => Some(existing.and(extra)),
151 None => Some(extra),
152 };
153 Self {
154 left_col: self.left_col,
155 right_table: self.right_table,
156 right_col: self.right_col,
157 where_expr: new,
158 }
159 }
160
161 pub fn filter<F>(self, f: F) -> Self
163 where
164 F: Fn(&L::Cols) -> BoolExpr<L>,
165 {
166 self.r#where(f)
167 }
168
169 pub fn build(self) -> RawQuery<L> {
170 let where_clause = self
171 .where_expr
172 .map(|e| format!(" WHERE {}", format_expr(&e)))
173 .unwrap_or_default();
174
175 let sql = format!(
176 r#"SELECT "{}".* FROM "{}" JOIN "{}" ON "{}"."{}" = "{}"."{}"{}"#,
177 self.left_col.table_name(),
178 self.left_col.table_name(),
179 self.right_table,
180 self.left_col.table_name(),
181 self.left_col.column_name(),
182 self.right_table,
183 self.right_col,
184 where_clause
185 );
186 RawQuery::new(sql)
187 }
188}
189
190impl<R: HasCols, L: HasCols> RightSemiJoin<R, L> {
192 pub fn r#where<F>(self, f: F) -> Self
193 where
194 F: Fn(&R::Cols) -> BoolExpr<R>,
195 {
196 let extra = f(&R::cols(self.right_col.table_name()));
197 let new = match self.right_where_expr {
198 Some(existing) => Some(existing.and(extra)),
199 None => Some(extra),
200 };
201 Self {
202 left_col: self.left_col,
203 right_col: self.right_col,
204 left_where_expr: self.left_where_expr,
205 right_where_expr: new,
206 _left_marker: PhantomData,
207 }
208 }
209
210 pub fn filter<F>(self, f: F) -> Self
212 where
213 F: Fn(&R::Cols) -> BoolExpr<R>,
214 {
215 self.r#where(f)
216 }
217
218 pub fn build(self) -> RawQuery<R> {
219 let mut where_parts = Vec::new();
220
221 if let Some(left_expr) = self.left_where_expr {
222 where_parts.push(format_expr(&left_expr));
223 }
224
225 if let Some(right_expr) = self.right_where_expr {
226 where_parts.push(format_expr(&right_expr));
227 }
228
229 let where_clause = if !where_parts.is_empty() {
230 format!(" WHERE {}", where_parts.join(" AND "))
231 } else {
232 String::new()
233 };
234
235 let sql = format!(
236 r#"SELECT "{}".* FROM "{}" JOIN "{}" ON "{}"."{}" = "{}"."{}"{}"#,
237 self.right_col.table_name(),
238 self.left_col.table_name(),
239 self.right_col.table_name(),
240 self.left_col.table_name(),
241 self.left_col.column_name(),
242 self.right_col.table_name(),
243 self.right_col.column_name(),
244 where_clause
245 );
246 RawQuery::new(sql)
247 }
248}