Skip to main content

spacetimedb_query_builder/
join.rs

1use crate::TableNameStr;
2
3use super::{
4    expr::{format_expr, BoolExpr},
5    table::{CanBeLookupTable, ColumnRef, HasCols, HasIxCols, Table},
6    Query, RawQuery,
7};
8use std::marker::PhantomData;
9
10/// Indexed columns for joins
11///
12/// Joins are performed on indexed columns, Tables that implement `HasIxCols`
13/// provide access to their indexed columns.
14pub struct IxCol<T, V> {
15    pub(super) col: ColumnRef<T>,
16    _marker: PhantomData<V>,
17}
18
19impl<T, V> IxCol<T, V> {
20    pub fn new(table_name: TableNameStr, column: &'static str) -> Self {
21        Self {
22            col: ColumnRef::new(table_name, column),
23            _marker: PhantomData,
24        }
25    }
26}
27
28impl<T, V> Copy for IxCol<T, V> {}
29impl<T, V> Clone for IxCol<T, V> {
30    fn clone(&self) -> Self {
31        *self
32    }
33}
34
35pub struct IxJoinEq<L, R, V> {
36    pub(super) lhs_col: ColumnRef<L>,
37    pub(super) rhs_col: ColumnRef<R>,
38    _marker: PhantomData<V>,
39}
40
41impl<T, V> IxCol<T, V> {
42    pub fn eq<R: HasIxCols>(self, rhs: IxCol<R, V>) -> IxJoinEq<T, R, V> {
43        IxJoinEq {
44            lhs_col: self.col,
45            rhs_col: rhs.col,
46            _marker: PhantomData,
47        }
48    }
49}
50
51// Left semijoin: filters and returns left table rows
52pub struct LeftSemiJoin<L> {
53    pub(super) left_col: ColumnRef<L>,
54    pub(super) right_table: &'static str,
55    pub(super) right_col: &'static str,
56    pub(super) where_expr: Option<BoolExpr<L>>,
57}
58
59// Right semijoin: returns right table rows, but remembers left conditions
60pub struct RightSemiJoin<R, L> {
61    pub(super) left_col: ColumnRef<L>,
62    pub(super) right_col: ColumnRef<R>,
63    pub(super) left_where_expr: Option<BoolExpr<L>>,
64    pub(super) right_where_expr: Option<BoolExpr<R>>,
65    _left_marker: PhantomData<L>,
66}
67
68impl<L: HasIxCols> Table<L> {
69    pub fn left_semijoin<R: CanBeLookupTable, V>(
70        self,
71        right: Table<R>,
72        on: impl Fn(&L::IxCols, &R::IxCols) -> IxJoinEq<L, R, V>,
73    ) -> LeftSemiJoin<L> {
74        let join = on(&L::ix_cols(self.name()), &R::ix_cols(right.name()));
75        LeftSemiJoin {
76            left_col: join.lhs_col,
77            right_table: right.name(),
78            right_col: join.rhs_col.column_name(),
79            where_expr: None,
80        }
81    }
82
83    pub fn right_semijoin<R: CanBeLookupTable, V>(
84        self,
85        right: Table<R>,
86        on: impl Fn(&L::IxCols, &R::IxCols) -> IxJoinEq<L, R, V>,
87    ) -> RightSemiJoin<R, L> {
88        let join = on(&L::ix_cols(self.name()), &R::ix_cols(right.name()));
89        RightSemiJoin {
90            left_col: join.lhs_col,
91            right_col: join.rhs_col,
92            left_where_expr: None,
93            right_where_expr: None,
94            _left_marker: PhantomData,
95        }
96    }
97}
98
99impl<L: HasIxCols> super::FromWhere<L> {
100    pub fn left_semijoin<R: CanBeLookupTable, V>(
101        self,
102        right: Table<R>,
103        on: impl Fn(&L::IxCols, &R::IxCols) -> IxJoinEq<L, R, V>,
104    ) -> LeftSemiJoin<L> {
105        let join = on(&L::ix_cols(self.table_name), &R::ix_cols(right.name()));
106        LeftSemiJoin {
107            left_col: join.lhs_col,
108            right_table: right.name(),
109            right_col: join.rhs_col.column_name(),
110            where_expr: Some(self.expr),
111        }
112    }
113
114    pub fn right_semijoin<R: CanBeLookupTable, V>(
115        self,
116        right: Table<R>,
117        on: impl Fn(&L::IxCols, &R::IxCols) -> IxJoinEq<L, R, V>,
118    ) -> RightSemiJoin<R, L> {
119        let join = on(&L::ix_cols(self.table_name), &R::ix_cols(right.name()));
120        RightSemiJoin {
121            left_col: join.lhs_col,
122            right_col: join.rhs_col,
123            left_where_expr: Some(self.expr),
124            right_where_expr: None,
125            _left_marker: PhantomData,
126        }
127    }
128}
129
130impl<L: HasCols> Query<L> for LeftSemiJoin<L> {
131    fn into_sql(self) -> String {
132        self.build().into_sql()
133    }
134}
135
136impl<R: HasCols, L: HasCols> Query<R> for RightSemiJoin<R, L> {
137    fn into_sql(self) -> String {
138        self.build().into_sql()
139    }
140}
141
142// LeftSemiJoin where() operates on L
143impl<L: HasCols> LeftSemiJoin<L> {
144    pub fn r#where<F, E>(self, f: F) -> Self
145    where
146        F: Fn(&L::Cols) -> E,
147        E: Into<BoolExpr<L>>,
148    {
149        let extra = f(&L::cols(self.left_col.table_name())).into();
150        let new = match self.where_expr {
151            Some(existing) => Some(existing.and(extra)),
152            None => Some(extra),
153        };
154        Self {
155            left_col: self.left_col,
156            right_table: self.right_table,
157            right_col: self.right_col,
158            where_expr: new,
159        }
160    }
161
162    // Filter is an alias for where
163    pub fn filter<F, E>(self, f: F) -> Self
164    where
165        F: Fn(&L::Cols) -> E,
166        E: Into<BoolExpr<L>>,
167    {
168        self.r#where(f)
169    }
170
171    pub fn build(self) -> RawQuery<L> {
172        let where_clause = self
173            .where_expr
174            .map(|e| format!(" WHERE {}", format_expr(&e)))
175            .unwrap_or_default();
176
177        let sql = format!(
178            r#"SELECT "{}".* FROM "{}" JOIN "{}" ON "{}"."{}" = "{}"."{}"{}"#,
179            self.left_col.table_name(),
180            self.left_col.table_name(),
181            self.right_table,
182            self.left_col.table_name(),
183            self.left_col.column_name(),
184            self.right_table,
185            self.right_col,
186            where_clause
187        );
188        RawQuery::new(sql)
189    }
190}
191
192// RightSemiJoin where() operates on R
193impl<R: HasCols, L: HasCols> RightSemiJoin<R, L> {
194    pub fn r#where<F, E>(self, f: F) -> Self
195    where
196        F: Fn(&R::Cols) -> E,
197        E: Into<BoolExpr<R>>,
198    {
199        let extra = f(&R::cols(self.right_col.table_name())).into();
200        let new = match self.right_where_expr {
201            Some(existing) => Some(existing.and(extra)),
202            None => Some(extra),
203        };
204        Self {
205            left_col: self.left_col,
206            right_col: self.right_col,
207            left_where_expr: self.left_where_expr,
208            right_where_expr: new,
209            _left_marker: PhantomData,
210        }
211    }
212
213    // Filter is an alias for where
214    pub fn filter<F, E>(self, f: F) -> Self
215    where
216        F: Fn(&R::Cols) -> E,
217        E: Into<BoolExpr<R>>,
218    {
219        self.r#where(f)
220    }
221
222    pub fn build(self) -> RawQuery<R> {
223        let mut where_parts = Vec::new();
224
225        if let Some(left_expr) = self.left_where_expr {
226            where_parts.push(format_expr(&left_expr));
227        }
228
229        if let Some(right_expr) = self.right_where_expr {
230            where_parts.push(format_expr(&right_expr));
231        }
232
233        let where_clause = if !where_parts.is_empty() {
234            format!(" WHERE {}", where_parts.join(" AND "))
235        } else {
236            String::new()
237        };
238
239        let sql = format!(
240            r#"SELECT "{}".* FROM "{}" JOIN "{}" ON "{}"."{}" = "{}"."{}"{}"#,
241            self.right_col.table_name(),
242            self.left_col.table_name(),
243            self.right_col.table_name(),
244            self.left_col.table_name(),
245            self.left_col.column_name(),
246            self.right_col.table_name(),
247            self.right_col.column_name(),
248            where_clause
249        );
250        RawQuery::new(sql)
251    }
252}