Skip to main content

spacetimedb_query_builder/
lib.rs

1pub mod expr;
2pub mod join;
3pub mod table;
4
5pub use expr::*;
6pub use join::*;
7use spacetimedb_lib::{sats::impl_st, AlgebraicType, SpacetimeType};
8pub use table::*;
9
/// Shared interface of every query builder type.
///
/// Use `impl Query<T>` as the return type of view functions and helpers, so
/// they can return any builder without naming its concrete type.
///
/// `T` is only a type-level parameter; `into_sql` itself does not depend on it.
pub trait Query<T> {
    /// Consumes the builder and yields its SQL text.
    fn into_sql(self) -> String;
}
15
/// The concrete SQL query produced by calling `.build()` on a builder.
///
/// `T` is carried only as a type-level marker (through `PhantomData`); a value
/// of this type stores nothing except the rendered SQL text.
pub struct RawQuery<T> {
    pub(crate) sql: String,
    _marker: std::marker::PhantomData<T>,
}

impl<T> RawQuery<T> {
    /// Wraps already-rendered SQL text as a query tagged with `T`.
    #[must_use]
    pub fn new(sql: String) -> Self {
        Self {
            sql,
            _marker: std::marker::PhantomData,
        }
    }

    /// Borrows the rendered SQL text.
    #[must_use]
    pub fn sql(&self) -> &str {
        &self.sql
    }
}

// Hand-written instead of `#[derive(Debug)]`: the derive would add a
// `T: Debug` bound, but `T` is only a marker and is frequently a unit struct
// with no `Debug` impl. Only the SQL text is shown.
impl<T> std::fmt::Debug for RawQuery<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RawQuery").field("sql", &self.sql).finish()
    }
}
34
35impl<T> Query<T> for RawQuery<T> {
36    fn into_sql(self) -> String {
37        self.sql
38    }
39}
40
41impl_st!([T: SpacetimeType] RawQuery<T>, ts => AlgebraicType::option(T::make_type(ts)));
42
#[cfg(test)]
mod tests {
    use spacetimedb_lib::{sats::i256, TimeDuration};

    use super::*;

    // ---- Fixtures ----------------------------------------------------------
    //
    // `User` and `Other` are marker row types. The `*Cols` / `*IxCols` structs
    // hand out typed column handles bound to the given table name.

    struct User;
    #[derive(Clone)]
    struct UserCols {
        pub id: Col<User, i32>,
        pub name: Col<User, String>,
        pub age: Col<User, i32>,
        pub online: Col<User, bool>,
    }
    impl UserCols {
        fn new(table_name: &'static str) -> Self {
            Self {
                id: Col::new(table_name, "id"),
                name: Col::new(table_name, "name"),
                age: Col::new(table_name, "age"),
                online: Col::new(table_name, "online"),
            }
        }
    }
    impl HasCols for User {
        type Cols = UserCols;
        fn cols(table_name: &'static str) -> Self::Cols {
            UserCols::new(table_name)
        }
    }
    fn users() -> Table<User> {
        Table::new("users")
    }
    fn other() -> Table<Other> {
        Table::new("other")
    }
    struct OtherCols {
        pub uid: Col<Other, i32>,
    }

    impl HasCols for Other {
        type Cols = OtherCols;
        fn cols(table: &'static str) -> Self::Cols {
            OtherCols {
                uid: Col::new(table, "uid"),
            }
        }
    }
    struct IxUserCols {
        pub id: IxCol<User, i32>,
    }
    impl HasIxCols for User {
        type IxCols = IxUserCols;
        fn ix_cols(table_name: &'static str) -> Self::IxCols {
            IxUserCols {
                id: IxCol::new(table_name, "id"),
            }
        }
    }
    struct Other;
    #[derive(Clone)]
    struct IxOtherCols {
        pub uid: IxCol<Other, i32>,
    }
    impl HasIxCols for Other {
        type IxCols = IxOtherCols;
        fn ix_cols(table_name: &'static str) -> Self::IxCols {
            IxOtherCols {
                uid: IxCol::new(table_name, "uid"),
            }
        }
    }
    impl CanBeLookupTable for User {}
    impl CanBeLookupTable for Other {}

    /// Collapses every run of whitespace into a single space, so expected SQL
    /// can be compared regardless of spacing.
    fn norm(s: &str) -> String {
        s.split_whitespace().collect::<Vec<_>>().join(" ")
    }

    // ---- Single-table SELECT / WHERE ---------------------------------------

    #[test]
    fn test_simple_select() {
        let q = users().build();
        assert_eq!(q.sql(), r#"SELECT * FROM "users""#);
    }
    #[test]
    fn test_where_literal() {
        let q = users().r#where(|c| c.id.eq(10)).build();
        let expected = r#"SELECT * FROM "users" WHERE ("users"."id" = 10)"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }
    #[test]
    fn test_where_multiple_predicates() {
        let q = users().r#where(|c| c.id.eq(10)).r#where(|c| c.age.gt(18)).build();
        let expected = r#"SELECT * FROM "users" WHERE (("users"."id" = 10) AND ("users"."age" > 18))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_where_bool_column_directly() {
        let q = users().r#where(|c| c.online).build();
        let expected = r#"SELECT * FROM "users" WHERE ("users"."online" = TRUE)"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_where_gte_lte() {
        let q = users().r#where(|c| c.age.gte(18)).r#where(|c| c.age.lte(30)).build();
        let expected = r#"SELECT * FROM "users" WHERE (("users"."age" >= 18) AND ("users"."age" <= 30))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_column_column_comparison() {
        let q = users().r#where(|c| c.age.gt(c.id)).build();
        let expected = r#"SELECT * FROM "users" WHERE ("users"."age" > "users"."id")"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }
    #[test]
    fn test_ne_comparison() {
        let q = users().r#where(|c| c.name.ne("Shub".to_string())).build();
        // Pin the exact rendering; a substring check on "name" / "<>" would
        // also accept malformed SQL.
        let expected = r#"SELECT * FROM "users" WHERE ("users"."name" <> 'Shub')"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_not_comparison() {
        let q = users().r#where(|c| c.name.eq("Alice".to_string()).not()).build();
        let expected = r#"SELECT * FROM "users" WHERE (NOT ("users"."name" = 'Alice'))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_not_with_and() {
        let q = users()
            .r#where(|c| c.name.eq("Alice".to_string()).not().and(c.age.gt(18)))
            .build();
        let expected = r#"SELECT * FROM "users" WHERE ((NOT ("users"."name" = 'Alice')) AND ("users"."age" > 18))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_filter_alias() {
        let q = users().filter(|c| c.id.eq(5)).filter(|c| c.age.lt(30)).build();
        let expected = r#"SELECT * FROM "users" WHERE (("users"."id" = 5) AND ("users"."age" < 30))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_or_comparison() {
        let q = users()
            .r#where(|c| c.name.ne("Shub".to_string()).or(c.name.ne("Pop".to_string())))
            .build();

        let expected = r#"SELECT * FROM "users" WHERE (("users"."name" <> 'Shub') OR ("users"."name" <> 'Pop'))"#;
        assert_eq!(q.sql, expected);
    }

    #[test]
    fn test_format_expr_column_literal() {
        let expr = BoolExpr::Eq(
            Operand::Column(ColumnRef::<User>::new("user", "id")),
            Operand::Literal(LiteralValue::new("42".to_string())),
        );
        let sql = format_expr(&expr);
        assert!(sql.contains("id"), "Missing col");
        assert!(sql.contains("42"), "Missing literal");
    }

    // ---- Semi-joins ----------------------------------------------------------

    #[test]
    fn test_format_semi_join_expr() {
        let user = users();
        let other = other();
        let sql = user.left_semijoin(other, |u, o| u.id.eq(o.uid)).build().sql;
        let expected = r#"SELECT "users".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid""#;
        assert_eq!(sql, expected);
    }

    #[test]
    fn test_left_semijoin_with_where_expr() {
        let user = users();
        let o = other();
        let sql = user
            .left_semijoin(o, |u, o| u.id.eq(o.uid))
            .r#where(|u| u.id.eq(1i32))
            .r#where(|u| u.id.gt(10))
            .build()
            .sql;
        let expected = r#"SELECT "users".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid" WHERE (("users"."id" = 1) AND ("users"."id" > 10))"#;
        assert_eq!(sql, expected);
        // Applying the same filters before the join must render identical SQL.
        let user = users();
        let other = other();
        let sql2 = user
            .r#where(|u| u.id.eq(1))
            .r#where(|u| u.id.gt(10))
            .left_semijoin(other, |u, o| u.id.eq(o.uid))
            .build()
            .sql;
        assert_eq!(sql2, expected);
    }
    #[test]
    fn test_right_semijoin_with_where_expr() {
        let user = users();
        let o = other();
        let sql = user
            .right_semijoin(o, |u, o| u.id.eq(o.uid))
            .r#where(|o| o.uid.eq(1))
            .r#where(|o| o.uid.gt(10))
            .build()
            .sql;
        let expected = r#"SELECT "other".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid" WHERE (("other"."uid" = 1) AND ("other"."uid" > 10))"#;
        assert_eq!(sql, expected);
    }

    #[test]
    fn test_right_semijoin_with_left_and_right_where_expr() {
        let user = users();
        let o = other();
        let sql = user
            .r#where(|u| u.id.eq(1))
            .right_semijoin(o, |u, o| u.id.eq(o.uid))
            .r#where(|o| o.uid.gt(10))
            .build()
            .sql;
        let expected = r#"SELECT "other".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid" WHERE ("users"."id" = 1) AND ("other"."uid" > 10)"#;
        assert_eq!(sql, expected);
    }

    // ---- Literal rendering ---------------------------------------------------

    #[test]
    fn test_literals() {
        use spacetimedb_lib::{ConnectionId, Identity};

        struct Player;
        struct PlayerCols {
            score: Col<Player, i32>,
            name: Col<Player, String>,
            active: Col<Player, bool>,
            connection_id: Col<Player, ConnectionId>,
            cells: Col<Player, i256>,
            identity: Col<Player, Identity>,
            ts: Col<Player, spacetimedb_lib::Timestamp>,
            bytes: Col<Player, Vec<u8>>,
        }

        impl HasCols for Player {
            type Cols = PlayerCols;
            fn cols(table_name: &'static str) -> Self::Cols {
                PlayerCols {
                    score: Col::new(table_name, "score"),
                    name: Col::new(table_name, "name"),
                    active: Col::new(table_name, "active"),
                    connection_id: Col::new(table_name, "connection_id"),
                    cells: Col::new(table_name, "cells"),
                    identity: Col::new(table_name, "identity"),
                    ts: Col::new(table_name, "ts"),
                    bytes: Col::new(table_name, "bytes"),
                }
            }
        }

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.score.eq(100)).build();

        assert_eq!(q.sql, r#"SELECT * FROM "player" WHERE ("player"."score" = 100)"#);

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.name.ne("Alice".to_string())).build();

        assert_eq!(q.sql, r#"SELECT * FROM "player" WHERE ("player"."name" <> 'Alice')"#);

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.active.eq(true)).build();

        assert_eq!(q.sql, r#"SELECT * FROM "player" WHERE ("player"."active" = TRUE)"#);

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.connection_id.eq(ConnectionId::ZERO)).build();

        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."connection_id" = 0x00000000000000000000000000000000)"#
        );

        // -(2^120): exercises a negative value outside the 64-bit range.
        let big_int: i256 = (i256::ONE << 120) * i256::from(-1);

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.cells.gt(big_int)).build();

        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."cells" > -1329227995784915872903807060280344576)"#,
        );

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.identity.ne(Identity::ONE)).build();

        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."identity" <> 0x0000000000000000000000000000000000000000000000000000000000000001)"#
        );

        let ts = spacetimedb_lib::Timestamp::UNIX_EPOCH + TimeDuration::from_micros(1000);

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.ts.eq(ts)).build();
        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."ts" = '1970-01-01T00:00:00.001+00:00')"#
        );

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.bytes.eq(vec![1, 2, 3, 4, 255])).build();

        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."bytes" = 0x01020304ff)"#
        );
    }
}