spacetimedb_query_builder/
lib.rs1pub mod expr;
2pub mod join;
3pub mod table;
4
5pub use expr::*;
6pub use join::*;
7use spacetimedb_lib::{sats::impl_st, AlgebraicType, SpacetimeType};
8pub use table::*;
9
/// Field name of the single-field product type that `RawQuery<T>` reports as its
/// algebraic type (see the `impl_st!` invocation for `RawQuery`).
///
/// NOTE(review): presumably lets the host recognise a view that returns a query
/// rather than rows — confirm against the consumer of this tag.
const QUERY_VIEW_RETURN_TAG: &str = "__query__";
11
/// A query that can be rendered to a SQL string.
///
/// `T` is a type-level marker only; it does not appear in any method signature.
/// Presumably it names the row type the query yields — confirm against callers.
pub trait Query<T> {
    /// Consumes the query and returns its SQL text.
    fn into_sql(self) -> String;
}
17
/// A SQL string tagged at the type level with `T`.
///
/// No value of `T` is stored; the parameter is carried only through
/// [`PhantomData`](std::marker::PhantomData), so the struct is the size of its `String`.
pub struct RawQuery<T> {
    /// The rendered SQL text.
    pub(crate) sql: String,
    /// Zero-sized marker that ties this query to `T`.
    _marker: std::marker::PhantomData<T>,
}
23
24impl<T> RawQuery<T> {
25 pub fn new(sql: String) -> Self {
26 Self {
27 sql,
28 _marker: std::marker::PhantomData,
29 }
30 }
31
32 pub fn sql(&self) -> &str {
33 &self.sql
34 }
35}
36
37impl<T> Query<T> for RawQuery<T> {
38 fn into_sql(self) -> String {
39 self.sql
40 }
41}
42
43impl_st!(
44 [T: SpacetimeType] RawQuery<T>,
45 ts => AlgebraicType::product([(QUERY_VIEW_RETURN_TAG, T::make_type(ts))])
46);
47
#[cfg(test)]
mod tests {
    use spacetimedb_lib::{sats::i256, TimeDuration};

    use super::*;

    // Marker types standing in for table row types; they carry no data.
    struct User;
    struct Other;

    #[derive(Clone)]
    struct UserCols {
        pub id: Col<User, i32>,
        pub name: Col<User, String>,
        pub age: Col<User, i32>,
        pub online: Col<User, bool>,
    }

    impl UserCols {
        fn new(table_name: &'static str) -> Self {
            Self {
                id: Col::new(table_name, "id"),
                name: Col::new(table_name, "name"),
                age: Col::new(table_name, "age"),
                online: Col::new(table_name, "online"),
            }
        }
    }

    impl HasCols for User {
        type Cols = UserCols;
        fn cols(table_name: &'static str) -> Self::Cols {
            UserCols::new(table_name)
        }
    }

    fn users() -> Table<User> {
        Table::new("users")
    }

    fn other() -> Table<Other> {
        Table::new("other")
    }

    struct OtherCols {
        pub uid: Col<Other, i32>,
    }

    impl HasCols for Other {
        type Cols = OtherCols;
        fn cols(table: &'static str) -> Self::Cols {
            OtherCols {
                uid: Col::new(table, "uid"),
            }
        }
    }

    // Indexed-column views used by the semijoin tests.
    struct IxUserCols {
        pub id: IxCol<User, i32>,
    }

    impl HasIxCols for User {
        type IxCols = IxUserCols;
        fn ix_cols(table_name: &'static str) -> Self::IxCols {
            IxUserCols {
                id: IxCol::new(table_name, "id"),
            }
        }
    }

    #[derive(Clone)]
    struct IxOtherCols {
        pub uid: IxCol<Other, i32>,
    }

    impl HasIxCols for Other {
        type IxCols = IxOtherCols;
        fn ix_cols(table_name: &'static str) -> Self::IxCols {
            IxOtherCols {
                uid: IxCol::new(table_name, "uid"),
            }
        }
    }

    impl CanBeLookupTable for User {}
    impl CanBeLookupTable for Other {}

    /// Collapses every run of whitespace to a single space so that assertions
    /// are insensitive to the builder's spacing.
    fn norm(s: &str) -> String {
        s.split_whitespace().collect::<Vec<_>>().join(" ")
    }

    #[test]
    fn test_simple_select() {
        let q = users().build();
        assert_eq!(q.sql(), r#"SELECT * FROM "users""#);
    }

    #[test]
    fn test_where_literal() {
        let q = users().r#where(|c| c.id.eq(10)).build();
        let expected = r#"SELECT * FROM "users" WHERE ("users"."id" = 10)"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_where_multiple_predicates() {
        let q = users().r#where(|c| c.id.eq(10)).r#where(|c| c.age.gt(18)).build();
        let expected = r#"SELECT * FROM "users" WHERE (("users"."id" = 10) AND ("users"."age" > 18))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_where_bool_column_directly() {
        let q = users().r#where(|c| c.online).build();
        let expected = r#"SELECT * FROM "users" WHERE ("users"."online" = TRUE)"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_where_gte_lte() {
        let q = users().r#where(|c| c.age.gte(18)).r#where(|c| c.age.lte(30)).build();
        let expected = r#"SELECT * FROM "users" WHERE (("users"."age" >= 18) AND ("users"."age" <= 30))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_column_column_comparison() {
        let q = users().r#where(|c| c.age.gt(c.id)).build();
        let expected = r#"SELECT * FROM "users" WHERE ("users"."age" > "users"."id")"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_ne_comparison() {
        let q = users().r#where(|c| c.name.ne("Shub".to_string())).build();
        // Pin the full rendering rather than substrings, so a malformed operand or
        // missing quoting is caught (same `<>` format as asserted in test_or_comparison).
        let expected = r#"SELECT * FROM "users" WHERE ("users"."name" <> 'Shub')"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_not_comparison() {
        let q = users().r#where(|c| c.name.eq("Alice".to_string()).not()).build();
        let expected = r#"SELECT * FROM "users" WHERE (NOT ("users"."name" = 'Alice'))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_not_with_and() {
        let q = users()
            .r#where(|c| c.name.eq("Alice".to_string()).not().and(c.age.gt(18)))
            .build();
        let expected = r#"SELECT * FROM "users" WHERE ((NOT ("users"."name" = 'Alice')) AND ("users"."age" > 18))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_filter_alias() {
        let q = users().filter(|c| c.id.eq(5)).filter(|c| c.age.lt(30)).build();
        let expected = r#"SELECT * FROM "users" WHERE (("users"."id" = 5) AND ("users"."age" < 30))"#;
        assert_eq!(norm(q.sql()), norm(expected));
    }

    #[test]
    fn test_or_comparison() {
        let q = users()
            .r#where(|c| c.name.ne("Shub".to_string()).or(c.name.ne("Pop".to_string())))
            .build();

        let expected = r#"SELECT * FROM "users" WHERE (("users"."name" <> 'Shub') OR ("users"."name" <> 'Pop'))"#;
        assert_eq!(q.sql, expected);
    }

    #[test]
    fn test_format_expr_column_literal() {
        let expr = BoolExpr::Eq(
            Operand::Column(ColumnRef::<User>::new("user", "id")),
            Operand::Literal(LiteralValue::new("42".to_string())),
        );
        let sql = format_expr(&expr);
        assert!(sql.contains("id"), "Missing col");
        assert!(sql.contains("42"), "Missing literal");
    }

    #[test]
    fn test_format_semi_join_expr() {
        let user = users();
        let other = other();
        let sql = user.left_semijoin(other, |u, o| u.id.eq(o.uid)).build().sql;
        let expected = r#"SELECT "users".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid""#;
        assert_eq!(sql, expected);
    }

    #[test]
    fn test_left_semijoin_with_where_expr() {
        let user = users();
        let o = other();
        let sql = user
            .left_semijoin(o, |u, o| u.id.eq(o.uid))
            .r#where(|u| u.id.eq(1i32))
            .r#where(|u| u.id.gt(10))
            .build()
            .sql;
        let expected = r#"SELECT "users".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid" WHERE (("users"."id" = 1) AND ("users"."id" > 10))"#;
        assert_eq!(sql, expected);

        // Filters applied before the join must render identically to filters applied after.
        let user = users();
        let other = other();
        let sql2 = user
            .r#where(|u| u.id.eq(1))
            .r#where(|u| u.id.gt(10))
            .left_semijoin(other, |u, o| u.id.eq(o.uid))
            .build()
            .sql;
        assert_eq!(sql2, expected);
    }

    #[test]
    fn test_right_semijoin_with_where_expr() {
        let user = users();
        let o = other();
        let sql = user
            .right_semijoin(o, |u, o| u.id.eq(o.uid))
            .r#where(|o| o.uid.eq(1))
            .r#where(|o| o.uid.gt(10))
            .build()
            .sql;
        let expected = r#"SELECT "other".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid" WHERE (("other"."uid" = 1) AND ("other"."uid" > 10))"#;
        assert_eq!(sql, expected);
    }

    #[test]
    fn test_right_semijoin_with_left_and_right_where_expr() {
        let user = users();
        let o = other();
        let sql = user
            .r#where(|u| u.id.eq(1))
            .right_semijoin(o, |u, o| u.id.eq(o.uid))
            .r#where(|o| o.uid.gt(10))
            .build()
            .sql;
        let expected = r#"SELECT "other".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid" WHERE ("users"."id" = 1) AND ("other"."uid" > 10)"#;
        assert_eq!(sql, expected);
    }

    #[test]
    fn test_literals() {
        use spacetimedb_lib::{ConnectionId, Identity};

        struct Player;
        struct PlayerCols {
            score: Col<Player, i32>,
            name: Col<Player, String>,
            active: Col<Player, bool>,
            connection_id: Col<Player, ConnectionId>,
            cells: Col<Player, i256>,
            identity: Col<Player, Identity>,
            ts: Col<Player, spacetimedb_lib::Timestamp>,
            bytes: Col<Player, Vec<u8>>,
        }

        impl HasCols for Player {
            type Cols = PlayerCols;
            fn cols(table_name: &'static str) -> Self::Cols {
                PlayerCols {
                    score: Col::new(table_name, "score"),
                    name: Col::new(table_name, "name"),
                    active: Col::new(table_name, "active"),
                    connection_id: Col::new(table_name, "connection_id"),
                    cells: Col::new(table_name, "cells"),
                    identity: Col::new(table_name, "identity"),
                    ts: Col::new(table_name, "ts"),
                    bytes: Col::new(table_name, "bytes"),
                }
            }
        }

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.score.eq(100)).build();

        assert_eq!(q.sql, r#"SELECT * FROM "player" WHERE ("player"."score" = 100)"#);

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.name.ne("Alice".to_string())).build();

        assert_eq!(q.sql, r#"SELECT * FROM "player" WHERE ("player"."name" <> 'Alice')"#);

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.active.eq(true)).build();

        assert_eq!(q.sql, r#"SELECT * FROM "player" WHERE ("player"."active" = TRUE)"#);

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.connection_id.eq(ConnectionId::ZERO)).build();

        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."connection_id" = 0x00000000000000000000000000000000)"#
        );

        // -(2^120): exercises a value that does not fit in any primitive integer.
        let big_int: i256 = (i256::ONE << 120) * i256::from(-1);

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.cells.gt(big_int)).build();

        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."cells" > -1329227995784915872903807060280344576)"#,
        );

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.identity.ne(Identity::ONE)).build();

        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."identity" <> 0x0000000000000000000000000000000000000000000000000000000000000001)"#
        );

        let ts = spacetimedb_lib::Timestamp::UNIX_EPOCH + TimeDuration::from_micros(1000);

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.ts.eq(ts)).build();
        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."ts" = '1970-01-01T00:00:00.001+00:00')"#
        );

        let table = Table::<Player>::new("player");
        let q = table.r#where(|c| c.bytes.eq(vec![1, 2, 3, 4, 255])).build();

        assert_eq!(
            q.sql,
            r#"SELECT * FROM "player" WHERE ("player"."bytes" = 0x01020304ff)"#
        );
    }
}