Skip to main content

sql_orm_query/
lib.rs

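//! Query AST foundations for the ORM.
//!
//! Queries (select, aggregate, exists, insert, update, delete and count) are
//! modelled as typed values; this crate does not generate SQL itself.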
2
3mod aggregate;
4mod delete;
5mod expr;
6mod insert;
7mod join;
8mod order;
9mod pagination;
10mod predicate;
11mod select;
12mod update;
13
14use sql_orm_core::{CrateIdentity, SqlValue};
15
16pub use aggregate::{
17    AggregateExpr, AggregateOrderBy, AggregatePredicate, AggregateProjection, AggregateQuery,
18};
19pub use delete::DeleteQuery;
20pub use expr::{BinaryOp, ColumnRef, Expr, TableRef, UnaryOp};
21pub use insert::InsertQuery;
22pub use join::{Join, JoinType};
23pub use order::{OrderBy, SortDirection};
24pub use pagination::Pagination;
25pub use predicate::Predicate;
26pub use select::{CountQuery, ExistsQuery, SelectProjection, SelectQuery};
27pub use update::UpdateQuery;
28
29#[derive(Debug, Clone, PartialEq)]
30pub struct CompiledQuery {
31    pub sql: String,
32    pub params: Vec<SqlValue>,
33}
34
35impl CompiledQuery {
36    pub fn new(sql: impl Into<String>, params: Vec<SqlValue>) -> Self {
37        Self {
38            sql: sql.into(),
39            params,
40        }
41    }
42}
43
44#[derive(Debug, Clone, PartialEq)]
45pub enum Query {
46    Select(SelectQuery),
47    Aggregate(Box<AggregateQuery>),
48    Exists(Box<ExistsQuery>),
49    Insert(InsertQuery),
50    Update(UpdateQuery),
51    Delete(DeleteQuery),
52    Count(CountQuery),
53}
54
55pub const CRATE_IDENTITY: CrateIdentity = CrateIdentity {
56    name: "sql-orm-query",
57    responsibility: "typed AST and query builder primitives without SQL generation",
58};
59
#[cfg(test)]
mod tests {
    use super::{
        AggregateExpr, AggregateOrderBy, AggregatePredicate, AggregateProjection, AggregateQuery,
        BinaryOp, CRATE_IDENTITY, ColumnRef, CompiledQuery, CountQuery, DeleteQuery, ExistsQuery,
        Expr, InsertQuery, Join, JoinType, OrderBy, Pagination, Predicate, Query, SelectProjection,
        SelectQuery, SortDirection, TableRef, UpdateQuery,
    };
    use sql_orm_core::{
        Changeset, ColumnMetadata, ColumnValue, Entity, EntityColumn, EntityMetadata,
        IdentityMetadata, Insertable, PrimaryKeyMetadata, SqlServerType, SqlValue,
    };

    // Test entities are type-level markers. Only their `Entity` metadata and
    // associated column constants are used; no value of the type is ever built,
    // which would otherwise trigger the dead_code lint.
    #[allow(dead_code)]
    struct Customer;

    // Marker entity for `sales.orders`; never instantiated (same reason as `Customer`).
    #[allow(dead_code)]
    struct Order;

    // Column metadata for the `sales.customers` fixture table.
    static CUSTOMER_COLUMNS: [ColumnMetadata; 4] = [
        ColumnMetadata {
            rust_field: "id",
            column_name: "id",
            renamed_from: None,
            sql_type: SqlServerType::BigInt,
            nullable: false,
            primary_key: true,
            identity: Some(IdentityMetadata::new(1, 1)),
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: false,
            updatable: false,
            max_length: None,
            precision: None,
            scale: None,
        },
        ColumnMetadata {
            rust_field: "email",
            column_name: "email",
            renamed_from: None,
            sql_type: SqlServerType::NVarChar,
            nullable: false,
            primary_key: false,
            identity: None,
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: true,
            max_length: Some(160),
            precision: None,
            scale: None,
        },
        ColumnMetadata {
            rust_field: "active",
            column_name: "active",
            renamed_from: None,
            sql_type: SqlServerType::Bit,
            nullable: false,
            primary_key: false,
            identity: None,
            default_sql: Some("1"),
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: true,
            max_length: None,
            precision: None,
            scale: None,
        },
        ColumnMetadata {
            rust_field: "created_at",
            column_name: "created_at",
            renamed_from: None,
            sql_type: SqlServerType::DateTime2,
            nullable: false,
            primary_key: false,
            identity: None,
            default_sql: Some("SYSUTCDATETIME()"),
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: true,
            max_length: None,
            precision: None,
            scale: None,
        },
    ];

    // Table-level metadata for the `Customer` fixture entity.
    static CUSTOMER_METADATA: EntityMetadata = EntityMetadata {
        rust_name: "Customer",
        schema: "sales",
        table: "customers",
        renamed_from: None,
        columns: &CUSTOMER_COLUMNS,
        primary_key: PrimaryKeyMetadata::new(Some("pk_customers"), &["id"]),
        indexes: &[],
        foreign_keys: &[],
        navigations: &[],
    };

    impl Entity for Customer {
        fn metadata() -> &'static EntityMetadata {
            &CUSTOMER_METADATA
        }
    }

    // Column metadata for the `sales.orders` fixture table.
    static ORDER_COLUMNS: [ColumnMetadata; 3] = [
        ColumnMetadata {
            rust_field: "id",
            column_name: "id",
            renamed_from: None,
            sql_type: SqlServerType::BigInt,
            nullable: false,
            primary_key: true,
            identity: Some(IdentityMetadata::new(1, 1)),
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: false,
            updatable: false,
            max_length: None,
            precision: None,
            scale: None,
        },
        ColumnMetadata {
            rust_field: "customer_id",
            column_name: "customer_id",
            renamed_from: None,
            sql_type: SqlServerType::BigInt,
            nullable: false,
            primary_key: false,
            identity: None,
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: true,
            max_length: None,
            precision: None,
            scale: None,
        },
        ColumnMetadata {
            rust_field: "total_cents",
            column_name: "total_cents",
            renamed_from: None,
            sql_type: SqlServerType::BigInt,
            nullable: false,
            primary_key: false,
            identity: None,
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: true,
            max_length: None,
            precision: None,
            scale: None,
        },
    ];

    // Table-level metadata for the `Order` fixture entity.
    static ORDER_METADATA: EntityMetadata = EntityMetadata {
        rust_name: "Order",
        schema: "sales",
        table: "orders",
        renamed_from: None,
        columns: &ORDER_COLUMNS,
        primary_key: PrimaryKeyMetadata::new(Some("pk_orders"), &["id"]),
        indexes: &[],
        foreign_keys: &[],
        navigations: &[],
    };

    impl Entity for Order {
        fn metadata() -> &'static EntityMetadata {
            &ORDER_METADATA
        }
    }

    // Lower-case associated constants mirror the Rust field names, so tests can
    // write `Customer::email` the way generated entity code would.
    #[allow(non_upper_case_globals)]
    impl Customer {
        const id: EntityColumn<Customer> = EntityColumn::new("id", "id");
        const email: EntityColumn<Customer> = EntityColumn::new("email", "email");
        const active: EntityColumn<Customer> = EntityColumn::new("active", "active");
        const created_at: EntityColumn<Customer> = EntityColumn::new("created_at", "created_at");
    }

    // Same naming convention as `Customer`'s column constants.
    #[allow(non_upper_case_globals)]
    impl Order {
        const customer_id: EntityColumn<Order> = EntityColumn::new("customer_id", "customer_id");
        const total_cents: EntityColumn<Order> = EntityColumn::new("total_cents", "total_cents");
    }

    // Insert payload fixture: always supplies `email` and `active`.
    struct NewCustomer {
        email: String,
        active: bool,
    }

    impl Insertable<Customer> for NewCustomer {
        fn values(&self) -> Vec<ColumnValue> {
            vec![
                ColumnValue::new("email", SqlValue::String(self.email.clone())),
                ColumnValue::new("active", SqlValue::Bool(self.active)),
            ]
        }
    }

    // Partial-update fixture: `None` means "leave email unchanged".
    struct UpdateCustomer {
        email: Option<String>,
    }

    impl Changeset<Customer> for UpdateCustomer {
        fn changes(&self) -> Vec<ColumnValue> {
            // Yields one change when an email is present and none otherwise. Only
            // the inner `String` is cloned, not the whole `Option`.
            self.email
                .iter()
                .map(|email| ColumnValue::new("email", SqlValue::String(email.clone())))
                .collect()
        }
    }

    #[test]
    fn keeps_query_layer_sql_free() {
        assert!(
            CRATE_IDENTITY
                .responsibility
                .contains("without SQL generation")
        );
    }

    #[test]
    fn entity_columns_become_table_aware_column_refs() {
        let column = ColumnRef::for_entity_column(Customer::email);

        assert_eq!(column.table, TableRef::new("sales", "customers"));
        assert_eq!(column.rust_field, "email");
        assert_eq!(column.column_name, "email");
    }

    #[test]
    fn expr_supports_columns_values_functions_and_operations() {
        let expr = Expr::binary(
            Expr::function("LOWER", vec![Expr::from(Customer::email)]),
            BinaryOp::Add,
            Expr::value(SqlValue::String("@example.com".to_string())),
        );

        match expr {
            Expr::Binary { left, op, right } => {
                assert_eq!(op, BinaryOp::Add);
                assert!(matches!(*left, Expr::Function { .. }));
                assert_eq!(
                    *right,
                    Expr::Value(SqlValue::String("@example.com".to_string()))
                );
            }
            other => panic!("unexpected expr shape: {other:?}"),
        }
    }

    #[test]
    fn predicates_can_be_composed_without_sql_rendering() {
        let predicate = Predicate::and(vec![
            Predicate::eq(
                Expr::from(Customer::active),
                Expr::value(SqlValue::Bool(true)),
            ),
            Predicate::like(
                Expr::from(Customer::email),
                Expr::value(SqlValue::String("%@example.com".to_string())),
            ),
        ]);

        match predicate {
            Predicate::And(parts) => assert_eq!(parts.len(), 2),
            other => panic!("unexpected predicate shape: {other:?}"),
        }
    }

    #[test]
    fn select_query_captures_projection_filters_order_and_pagination() {
        let query = SelectQuery::from_entity::<Customer>()
            .select(vec![Expr::from(Customer::id), Expr::from(Customer::email)])
            .filter(Predicate::eq(
                Expr::from(Customer::active),
                Expr::value(SqlValue::Bool(true)),
            ))
            .filter(Predicate::like(
                Expr::from(Customer::email),
                Expr::value(SqlValue::String("%@example.com".to_string())),
            ))
            .order_by(OrderBy::desc(Customer::created_at))
            .paginate(Pagination::page(2, 20));

        assert_eq!(query.from, TableRef::new("sales", "customers"));
        assert!(query.joins.is_empty());
        assert_eq!(
            query.projection,
            vec![
                SelectProjection::column(Customer::id),
                SelectProjection::column(Customer::email)
            ]
        );
        assert_eq!(
            query.order_by,
            vec![OrderBy::new(
                TableRef::new("sales", "customers"),
                "created_at",
                SortDirection::Desc,
            )]
        );
        // Page 2 with 20 rows per page starts after the first 20 rows.
        assert_eq!(query.pagination, Some(Pagination::new(20, 20)));
        // Two `filter` calls are combined into a single AND predicate.
        assert!(matches!(query.predicate, Some(Predicate::And(_))));
    }

    #[test]
    fn select_query_captures_explicit_joins_without_sql_rendering() {
        let query = SelectQuery::from_entity::<Customer>()
            .inner_join::<Order>(Predicate::eq(
                Expr::from(Customer::id),
                Expr::from(Order::customer_id),
            ))
            .join(Join::left(
                TableRef::new("sales", "orders"),
                Predicate::gt(
                    Expr::from(Order::total_cents),
                    Expr::value(SqlValue::I64(0)),
                ),
            ));

        assert_eq!(query.joins.len(), 2);
        assert_eq!(query.joins[0].join_type, JoinType::Inner);
        assert_eq!(query.joins[0].table, TableRef::new("sales", "orders"));
        assert!(matches!(query.joins[0].on, Predicate::Eq(_, _)));
        assert_eq!(query.joins[1].join_type, JoinType::Left);
        assert_eq!(query.joins[1].table, TableRef::new("sales", "orders"));
        assert!(matches!(query.joins[1].on, Predicate::Gt(_, _)));
    }

    #[test]
    fn table_refs_capture_optional_aliases_without_sql_rendering() {
        let table = TableRef::for_entity_as::<Customer>("root");
        let column = ColumnRef::for_entity_column_as(Customer::email, "root");
        let expr = Expr::column_as(Customer::id, "root");

        assert_eq!(table.schema, "sales");
        assert_eq!(table.table, "customers");
        assert_eq!(table.alias, Some("root"));
        assert_eq!(table.reference_name(), "root");
        assert_eq!(table.without_alias(), TableRef::new("sales", "customers"));
        assert_eq!(column.table, table);

        match expr {
            Expr::Column(column) => {
                assert_eq!(column.table.alias, Some("root"));
                assert_eq!(column.column_name, "id");
            }
            other => panic!("unexpected expr shape: {other:?}"),
        }
    }

    #[test]
    fn select_query_captures_aliased_sources_and_repeated_joins() {
        let query = SelectQuery::from_entity_as::<Customer>("c")
            .inner_join_as::<Order>(
                "created_orders",
                Predicate::eq(
                    Expr::column_as(Customer::id, "c"),
                    Expr::column_as(Order::customer_id, "created_orders"),
                ),
            )
            .left_join_as::<Order>(
                "completed_orders",
                Predicate::gt(
                    Expr::column_as(Order::total_cents, "completed_orders"),
                    Expr::value(SqlValue::I64(0)),
                ),
            );

        assert_eq!(query.from, TableRef::with_alias("sales", "customers", "c"));
        assert_eq!(query.joins.len(), 2);
        assert_eq!(
            query.joins[0].table,
            TableRef::with_alias("sales", "orders", "created_orders")
        );
        assert_eq!(
            query.joins[1].table,
            TableRef::with_alias("sales", "orders", "completed_orders")
        );
        // The same table joined twice stays distinguishable through its alias.
        assert_ne!(query.joins[0].table, query.joins[1].table);
    }

    #[test]
    fn select_projection_captures_default_and_explicit_aliases() {
        let column_projection = SelectProjection::column(Customer::email);
        assert_eq!(column_projection.alias, Some("email"));
        assert_eq!(column_projection.expr, Expr::from(Customer::email));

        let expression_projection = SelectProjection::expr_as(
            Expr::function("LOWER", vec![Expr::from(Customer::email)]),
            "email_lower",
        );
        assert_eq!(expression_projection.alias, Some("email_lower"));

        let unaliased_expression =
            SelectProjection::expr(Expr::function("LOWER", vec![Expr::from(Customer::email)]));
        assert_eq!(unaliased_expression.alias, None);
    }

    #[test]
    fn aggregate_projection_requires_alias_without_changing_select_projection() {
        let group_key = AggregateProjection::group_key(Order::customer_id);
        let expression_group_key = AggregateProjection::group_key_as(
            Expr::function("YEAR", vec![Expr::from(Customer::created_at)]),
            "created_year",
        );
        let aggregate = AggregateProjection::sum_as(Order::total_cents, "total_cents");

        assert_eq!(group_key.alias, "customer_id");
        assert_eq!(expression_group_key.alias, "created_year");
        assert_eq!(aggregate.alias, "total_cents");
        assert_eq!(
            aggregate,
            AggregateProjection::expr_as(
                AggregateExpr::Sum(Expr::from(Order::total_cents)),
                "total_cents"
            )
        );

        let ordinary_projection =
            SelectProjection::expr(Expr::function("LOWER", vec![Expr::from(Customer::email)]));
        assert_eq!(ordinary_projection.alias, None);
    }

    #[test]
    fn aggregate_query_captures_grouping_having_and_projection_without_sql_rendering() {
        let query = AggregateQuery::from_entity::<Order>()
            .inner_join::<Customer>(Predicate::eq(
                Expr::from(Order::customer_id),
                Expr::from(Customer::id),
            ))
            .filter(Predicate::eq(
                Expr::from(Customer::active),
                Expr::value(SqlValue::Bool(true)),
            ))
            .group_by(vec![Expr::from(Order::customer_id)])
            .project(vec![
                AggregateProjection::group_key(Order::customer_id),
                AggregateProjection::count_as("order_count"),
                AggregateProjection::sum_as(Order::total_cents, "total_cents"),
            ])
            .having(AggregatePredicate::gt(
                AggregateExpr::count_all(),
                Expr::value(SqlValue::I64(1)),
            ))
            .order_by(AggregateOrderBy::desc(AggregateExpr::sum(Expr::from(
                Order::total_cents,
            ))))
            .paginate(Pagination::page(1, 10));

        assert_eq!(query.from, TableRef::new("sales", "orders"));
        assert_eq!(query.joins.len(), 1);
        assert!(query.predicate.is_some());
        assert_eq!(query.group_by, vec![Expr::from(Order::customer_id)]);
        assert_eq!(
            query.projection,
            vec![
                AggregateProjection::group_key(Order::customer_id),
                AggregateProjection::count_as("order_count"),
                AggregateProjection::sum_as(Order::total_cents, "total_cents")
            ]
        );
        assert!(matches!(query.having, Some(AggregatePredicate::Gt(_, _))));
        assert_eq!(
            query.order_by,
            vec![AggregateOrderBy::desc(AggregateExpr::sum(Expr::from(
                Order::total_cents
            )))]
        );
        // The first page has no offset.
        assert_eq!(query.pagination, Some(Pagination::new(0, 10)));
        assert!(matches!(
            Query::Aggregate(Box::new(query)),
            Query::Aggregate(_)
        ));
    }

    #[test]
    fn insert_update_delete_and_count_queries_capture_operation_data() {
        let insert = InsertQuery::for_entity::<Customer, _>(&NewCustomer {
            email: "ana@example.com".to_string(),
            active: true,
        });
        let update = UpdateQuery::for_entity::<Customer, _>(&UpdateCustomer {
            email: Some("ana.maria@example.com".to_string()),
        })
        .filter(Predicate::eq(
            Expr::from(Customer::id),
            Expr::value(SqlValue::I64(7)),
        ));
        let delete = DeleteQuery::from_entity::<Customer>().filter(Predicate::eq(
            Expr::from(Customer::id),
            Expr::value(SqlValue::I64(7)),
        ));
        let count = CountQuery::from_entity::<Customer>().filter(Predicate::eq(
            Expr::from(Customer::active),
            Expr::value(SqlValue::Bool(true)),
        ));
        let exists = ExistsQuery::from_entity::<Customer>()
            .inner_join::<Order>(Predicate::eq(
                Expr::from(Customer::id),
                Expr::from(Order::customer_id),
            ))
            .filter(Predicate::eq(
                Expr::from(Customer::active),
                Expr::value(SqlValue::Bool(true)),
            ));

        assert_eq!(insert.into, TableRef::new("sales", "customers"));
        assert_eq!(insert.values.len(), 2);
        assert_eq!(update.table, TableRef::new("sales", "customers"));
        assert_eq!(update.changes.len(), 1);
        assert!(update.predicate.is_some());
        assert_eq!(delete.from, TableRef::new("sales", "customers"));
        assert!(delete.predicate.is_some());
        assert_eq!(count.from, TableRef::new("sales", "customers"));
        assert!(count.predicate.is_some());
        assert_eq!(exists.from, TableRef::new("sales", "customers"));
        assert_eq!(exists.joins.len(), 1);
        assert!(exists.predicate.is_some());

        // The queries are not used again, so move them into `Query` instead of cloning.
        assert!(matches!(Query::Insert(insert), Query::Insert(_)));
        assert!(matches!(Query::Update(update), Query::Update(_)));
        assert!(matches!(Query::Delete(delete), Query::Delete(_)));
        assert!(matches!(Query::Count(count), Query::Count(_)));
        assert!(matches!(Query::Exists(Box::new(exists)), Query::Exists(_)));
    }

    #[test]
    fn compiled_query_keeps_sql_and_parameter_order() {
        let compiled = CompiledQuery::new(
            "SELECT [id] FROM [sales].[customers] WHERE [active] = @P1 AND [email] LIKE @P2",
            vec![
                SqlValue::Bool(true),
                SqlValue::String("%@example.com".to_string()),
            ],
        );

        assert_eq!(
            compiled.sql,
            "SELECT [id] FROM [sales].[customers] WHERE [active] = @P1 AND [email] LIKE @P2"
        );
        assert_eq!(
            compiled.params,
            vec![
                SqlValue::Bool(true),
                SqlValue::String("%@example.com".to_string()),
            ]
        );
    }
}