Skip to main content

spacetimedb_query_builder/
lib.rs

1pub mod expr;
2pub mod join;
3pub mod table;
4
5pub use expr::*;
6pub use join::*;
7use spacetimedb_lib::{sats::impl_st, AlgebraicType, SpacetimeType};
8pub use table::*;
9
/// Name of the single field in the product type that `RawQuery<T>` reports
/// through its `SpacetimeType` impl; that field's type is `T`'s own type.
/// NOTE(review): presumably the host uses this tag to recognise views that
/// return a query rather than rows — confirm against the consumer.
const QUERY_VIEW_RETURN_TAG: &str = "__query__";
11
/// Common interface shared by every query-builder type.
///
/// Write `impl Query<T>` as the return type of view functions and helpers so
/// that any builder (or a finished [`RawQuery`]) can be returned.
/// `T` is a type-level tag only; presumably it is the row type the query
/// yields — confirm against the builders.
pub trait Query<T> {
    /// Consumes the builder and renders it as a SQL string.
    fn into_sql(self) -> String;
}
17
/// A finished SQL query, as returned by `.build()` on a builder.
///
/// The type parameter `T` is carried only at the type level through a
/// zero-sized `PhantomData`; no value of type `T` is stored.
pub struct RawQuery<T> {
    /// Zero-sized marker tying this query to `T`.
    _marker: std::marker::PhantomData<T>,
    /// The rendered SQL text.
    pub(crate) sql: String,
}

impl<T> RawQuery<T> {
    /// Wraps already-rendered SQL text as a query tagged with `T`.
    ///
    /// The string is stored exactly as given; it is not validated here.
    pub fn new(sql: String) -> Self {
        Self {
            _marker: std::marker::PhantomData,
            sql,
        }
    }

    /// Borrows the SQL text held by this query.
    pub fn sql(&self) -> &str {
        self.sql.as_str()
    }
}
36
37impl<T> Query<T> for RawQuery<T> {
38    fn into_sql(self) -> String {
39        self.sql
40    }
41}
42
43impl_st!(
44    [T: SpacetimeType] RawQuery<T>,
45    ts => AlgebraicType::product([(QUERY_VIEW_RETURN_TAG, T::make_type(ts))])
46);
47
#[cfg(test)]
mod tests {
    use spacetimedb_lib::{sats::i256, TimeDuration};

    use super::*;

    // ---------------------------------------------------------------------
    // Fixtures: two marker "tables" (`User` -> "users", `Other` -> "other")
    // with hand-written column and index-column accessors.
    // ---------------------------------------------------------------------

    /// Marker type for rows of the `users` table.
    struct User;

    /// Column accessors for [`User`].
    #[derive(Clone)]
    struct UserCols {
        pub id: Col<User, i32>,
        pub name: Col<User, String>,
        pub age: Col<User, i32>,
        pub online: Col<User, bool>,
    }

    impl UserCols {
        fn new(table_name: &'static str) -> Self {
            Self {
                id: Col::new(table_name, "id"),
                name: Col::new(table_name, "name"),
                age: Col::new(table_name, "age"),
                online: Col::new(table_name, "online"),
            }
        }
    }

    impl HasCols for User {
        type Cols = UserCols;
        fn cols(table_name: &'static str) -> Self::Cols {
            UserCols::new(table_name)
        }
    }

    /// Index-column accessors for [`User`]; only `id` is exposed.
    struct IxUserCols {
        pub id: IxCol<User, i32>,
    }

    impl HasIxCols for User {
        type IxCols = IxUserCols;
        fn ix_cols(table_name: &'static str) -> Self::IxCols {
            IxUserCols {
                id: IxCol::new(table_name, "id"),
            }
        }
    }

    /// Marker type for rows of the `other` table.
    struct Other;

    /// Column accessors for [`Other`].
    struct OtherCols {
        pub uid: Col<Other, i32>,
    }

    impl HasCols for Other {
        type Cols = OtherCols;
        fn cols(table: &'static str) -> Self::Cols {
            OtherCols {
                uid: Col::new(table, "uid"),
            }
        }
    }

    /// Index-column accessors for [`Other`]; only `uid` is exposed.
    #[derive(Clone)]
    struct IxOtherCols {
        pub uid: IxCol<Other, i32>,
    }

    impl HasIxCols for Other {
        type IxCols = IxOtherCols;
        fn ix_cols(table_name: &'static str) -> Self::IxCols {
            IxOtherCols {
                uid: IxCol::new(table_name, "uid"),
            }
        }
    }

    impl CanBeLookupTable for User {}
    impl CanBeLookupTable for Other {}

    fn users() -> Table<User> {
        Table::new("users")
    }

    fn other() -> Table<Other> {
        Table::new("other")
    }

    /// Collapses every whitespace run to a single space so that SQL
    /// comparisons ignore formatting differences.
    fn norm(s: &str) -> String {
        s.split_whitespace().collect::<Vec<_>>().join(" ")
    }

    // ---------------------------------------------------------------------
    // Single-table SELECT / WHERE rendering
    // ---------------------------------------------------------------------

    #[test]
    fn test_simple_select() {
        let q = users().build();
        assert_eq!(q.sql(), r#"SELECT * FROM "users""#);
    }

    #[test]
    fn test_where_literal() {
        let q = users().r#where(|c| c.id.eq(10)).build();
        let expected = r#"SELECT * FROM "users" WHERE ("users"."id" = 10)"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_where_multiple_predicates() {
        // Chained `where` calls are combined with AND.
        let q = users().r#where(|c| c.id.eq(10)).r#where(|c| c.age.gt(18)).build();
        let expected = r#"SELECT * FROM "users" WHERE (("users"."id" = 10) AND ("users"."age" > 18))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_where_bool_column_directly() {
        // A bare bool column used as a predicate renders as `= TRUE`.
        let q = users().r#where(|c| c.online).build();
        let expected = r#"SELECT * FROM "users" WHERE ("users"."online" = TRUE)"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_where_gte_lte() {
        let q = users().r#where(|c| c.age.gte(18)).r#where(|c| c.age.lte(30)).build();
        let expected = r#"SELECT * FROM "users" WHERE (("users"."age" >= 18) AND ("users"."age" <= 30))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_column_column_comparison() {
        let q = users().r#where(|c| c.age.gt(c.id)).build();
        let expected = r#"SELECT * FROM "users" WHERE ("users"."age" > "users"."id")"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_ne_comparison() {
        // Pin the full rendering (qualified column, `<>` operator, quoted
        // string literal) instead of only checking for substrings.
        let q = users().r#where(|c| c.name.ne("Shub".to_string())).build();
        let expected = r#"SELECT * FROM "users" WHERE ("users"."name" <> 'Shub')"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_not_comparison() {
        let q = users().r#where(|c| c.name.eq("Alice".to_string()).not()).build();
        let expected = r#"SELECT * FROM "users" WHERE (NOT ("users"."name" = 'Alice'))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_not_with_and() {
        let q = users()
            .r#where(|c| c.name.eq("Alice".to_string()).not().and(c.age.gt(18)))
            .build();
        let expected = r#"SELECT * FROM "users" WHERE ((NOT ("users"."name" = 'Alice')) AND ("users"."age" > 18))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_filter_alias() {
        // `filter` must render exactly like `where`.
        let q = users().filter(|c| c.id.eq(5)).filter(|c| c.age.lt(30)).build();
        let expected = r#"SELECT * FROM "users" WHERE (("users"."id" = 5) AND ("users"."age" < 30))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_or_comparison() {
        let q = users()
            .r#where(|c| c.name.ne("Shub".to_string()).or(c.name.ne("Pop".to_string())))
            .build();

        let expected = r#"SELECT * FROM "users" WHERE (("users"."name" <> 'Shub') OR ("users"."name" <> 'Pop'))"#;
        assert_eq!(q.sql, expected);
    }

    #[test]
    fn test_format_expr_column_literal() {
        let expr = BoolExpr::Eq(
            Operand::Column(ColumnRef::<User>::new("user", "id")),
            Operand::Literal(LiteralValue::new("42".to_string())),
        );
        let sql = format_expr(&expr);
        assert!(sql.contains("id"), "Missing col");
        assert!(sql.contains("42"), "Missing literal");
    }

    // ---------------------------------------------------------------------
    // Semijoins
    // ---------------------------------------------------------------------

    #[test]
    fn test_format_semi_join_expr() {
        let user = users();
        let other = other();
        let sql = user.left_semijoin(other, |u, o| u.id.eq(o.uid)).build().sql;
        let expected = r#"SELECT "users".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid""#;
        assert_eq!(sql, expected);
    }

    #[test]
    fn test_left_semijoin_with_where_expr() {
        let user = users();
        let o = other();
        let sql = user
            .left_semijoin(o, |u, o| u.id.eq(o.uid))
            .r#where(|u| u.id.eq(1i32))
            .r#where(|u| u.id.gt(10))
            .build()
            .sql;
        let expected = r#"SELECT "users".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid" WHERE (("users"."id" = 1) AND ("users"."id" > 10))"#;
        assert_eq!(sql, expected);

        // Filtering before the join must produce the same SQL as filtering
        // after it.
        let user = users();
        let other = other();
        let sql2 = user
            .r#where(|u| u.id.eq(1))
            .r#where(|u| u.id.gt(10))
            .left_semijoin(other, |u, o| u.id.eq(o.uid))
            .build()
            .sql;
        assert_eq!(sql2, expected);
    }

    #[test]
    fn test_right_semijoin_with_where_expr() {
        let user = users();
        let o = other();
        let sql = user
            .right_semijoin(o, |u, o| u.id.eq(o.uid))
            .r#where(|o| o.uid.eq(1))
            .r#where(|o| o.uid.gt(10))
            .build()
            .sql;
        let expected = r#"SELECT "other".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid" WHERE (("other"."uid" = 1) AND ("other"."uid" > 10))"#;
        assert_eq!(sql, expected);
    }

    #[test]
    fn test_right_semijoin_with_left_and_right_where_expr() {
        let user = users();
        let o = other();
        let sql = user
            .r#where(|u| u.id.eq(1))
            .right_semijoin(o, |u, o| u.id.eq(o.uid))
            .r#where(|o| o.uid.gt(10))
            .build()
            .sql;
        // Pins the current rendering: when filters come from both sides, the
        // conjunction is not wrapped in an outer pair of parentheses (unlike
        // the single-side cases above).
        let expected = r#"SELECT "other".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid" WHERE ("users"."id" = 1) AND ("other"."uid" > 10)"#;
        assert_eq!(sql, expected);
    }

    // ---------------------------------------------------------------------
    // Literal rendering for the supported column types
    // ---------------------------------------------------------------------

    #[test]
    fn test_literals() {
        use spacetimedb_lib::{ConnectionId, Identity};

        struct Player;
        struct PlayerCols {
            score: Col<Player, i32>,
            name: Col<Player, String>,
            active: Col<Player, bool>,
            connection_id: Col<Player, ConnectionId>,
            cells: Col<Player, i256>,
            identity: Col<Player, Identity>,
            ts: Col<Player, spacetimedb_lib::Timestamp>,
            bytes: Col<Player, Vec<u8>>,
        }

        impl HasCols for Player {
            type Cols = PlayerCols;
            fn cols(table_name: &'static str) -> Self::Cols {
                PlayerCols {
                    score: Col::new(table_name, "score"),
                    name: Col::new(table_name, "name"),
                    active: Col::new(table_name, "active"),
                    connection_id: Col::new(table_name, "connection_id"),
                    cells: Col::new(table_name, "cells"),
                    identity: Col::new(table_name, "identity"),
                    ts: Col::new(table_name, "ts"),
                    bytes: Col::new(table_name, "bytes"),
                }
            }
        }

        // Integer literal.
        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.score.eq(100)).build();

        assert_eq!(q.sql, r#"SELECT * FROM "player" WHERE ("player"."score" = 100)"#);

        // String literal: single-quoted.
        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.name.ne("Alice".to_string())).build();

        assert_eq!(q.sql, r#"SELECT * FROM "player" WHERE ("player"."name" <> 'Alice')"#);

        // Bool literal: upper-case keyword.
        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.active.eq(true)).build();

        assert_eq!(q.sql, r#"SELECT * FROM "player" WHERE ("player"."active" = TRUE)"#);

        // ConnectionId: 0x-prefixed hex.
        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.connection_id.eq(ConnectionId::ZERO)).build();

        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."connection_id" = 0x00000000000000000000000000000000)"#
        );

        // Negative 256-bit integer: -(2^120) in decimal.
        let big_int: i256 = (i256::ONE << 120) * i256::from(-1);

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.cells.gt(big_int)).build();

        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."cells" > -1329227995784915872903807060280344576)"#,
        );

        // Identity: 0x-prefixed hex.
        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.identity.ne(Identity::ONE)).build();

        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."identity" <> 0x0000000000000000000000000000000000000000000000000000000000000001)"#
        );

        // Timestamp: quoted RFC 3339 string.
        let ts = spacetimedb_lib::Timestamp::UNIX_EPOCH + TimeDuration::from_micros(1000);

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.ts.eq(ts)).build();
        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."ts" = '1970-01-01T00:00:00.001+00:00')"#
        );

        // Byte vector: 0x-prefixed lowercase hex.
        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.bytes.eq(vec![1, 2, 3, 4, 255])).build();

        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."bytes" = 0x01020304ff)"#
        );
    }
}