Skip to main content

sql_orm_query/
lib.rs

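//! Query AST foundations for the ORM.
//!
//! This crate describes queries as typed values (select, aggregate, exists, insert, update,
//! delete and count) and provides [`CompiledQuery`], a container pairing SQL text with its
//! ordered parameter values. It does not generate SQL itself; SQL text is produced elsewhere
//! and only carried here.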
2
3mod aggregate;
4mod delete;
5mod expr;
6mod insert;
7mod join;
8mod order;
9mod pagination;
10mod predicate;
11mod select;
12mod update;
13
14use sql_orm_core::{CrateIdentity, SqlValue};
15
16pub use aggregate::{
17    AggregateExpr, AggregateOrderBy, AggregatePredicate, AggregateProjection, AggregateQuery,
18};
19pub use delete::DeleteQuery;
20pub use expr::{BinaryOp, ColumnRef, Expr, SqlFunction, TableRef, UnaryOp};
21pub use insert::InsertQuery;
22pub use join::{Join, JoinType};
23pub use order::{OrderBy, SortDirection};
24pub use pagination::Pagination;
25pub use predicate::Predicate;
26pub use select::{CountQuery, ExistsQuery, SelectProjection, SelectQuery};
27pub use update::UpdateQuery;
28
/// SQL text paired with its parameter values and an execution tag.
///
/// This crate only stores the SQL; it does not produce it (its declared responsibility is
/// "typed AST and query builder primitives without SQL generation").
#[derive(Debug, Clone, PartialEq)]
pub struct CompiledQuery {
    /// The statement text.
    pub sql: String,
    /// Parameter values, in the order they were supplied.
    // NOTE(review): presumably bound positionally to placeholders such as `@P1`, `@P2`, …
    // by the executor — confirm against the driver crate.
    pub params: Vec<SqlValue>,
    /// How the statement is tagged for execution; see [`QueryExecution`].
    pub execution: QueryExecution,
}
35
36impl CompiledQuery {
37    pub fn new(sql: impl Into<String>, params: Vec<SqlValue>) -> Self {
38        Self::with_execution(sql, params, QueryExecution::RawNoRetry)
39    }
40
41    pub fn read_only(sql: impl Into<String>, params: Vec<SqlValue>) -> Self {
42        Self::with_execution(sql, params, QueryExecution::ReadOnly)
43    }
44
45    pub fn write(sql: impl Into<String>, params: Vec<SqlValue>) -> Self {
46        Self::with_execution(sql, params, QueryExecution::Write)
47    }
48
49    pub fn migration(sql: impl Into<String>, params: Vec<SqlValue>) -> Self {
50        Self::with_execution(sql, params, QueryExecution::Migration)
51    }
52
53    pub fn raw_no_retry(sql: impl Into<String>, params: Vec<SqlValue>) -> Self {
54        Self::with_execution(sql, params, QueryExecution::RawNoRetry)
55    }
56
57    pub fn with_execution(
58        sql: impl Into<String>,
59        params: Vec<SqlValue>,
60        execution: QueryExecution,
61    ) -> Self {
62        Self {
63            sql: sql.into(),
64            params,
65            execution,
66        }
67    }
68}
69
/// Execution tag attached to a [`CompiledQuery`].
///
/// NOTE(review): the code that acts on this tag (retry policy, routing, …) is not in this
/// file; the per-variant notes below only record where each variant is set here. Confirm
/// the runtime semantics against the executor.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare variants by
/// declaration order, so reordering them changes comparison results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QueryExecution {
    /// Set by [`CompiledQuery::read_only`].
    ReadOnly,
    /// Set by [`CompiledQuery::write`].
    Write,
    /// Set by [`CompiledQuery::migration`].
    Migration,
    /// Set by [`CompiledQuery::raw_no_retry`] and by the default [`CompiledQuery::new`].
    // NOTE(review): the name suggests the executor must not retry these automatically —
    // TODO confirm.
    RawNoRetry,
}
77
/// Any query description this crate can represent, one variant per statement kind.
///
/// `Aggregate` and `Exists` hold their payload in a `Box`.
// NOTE(review): presumably to keep `Query` small (clippy `large_enum_variant`) — confirm.
#[derive(Debug, Clone, PartialEq)]
pub enum Query {
    /// A row-returning select ([`SelectQuery`]).
    Select(SelectQuery),
    /// A grouped/aggregated select ([`AggregateQuery`]).
    Aggregate(Box<AggregateQuery>),
    /// An existence check ([`ExistsQuery`]).
    Exists(Box<ExistsQuery>),
    /// An insert ([`InsertQuery`]).
    Insert(InsertQuery),
    /// An update ([`UpdateQuery`]).
    Update(UpdateQuery),
    /// A delete ([`DeleteQuery`]).
    Delete(DeleteQuery),
    /// A row count ([`CountQuery`]).
    Count(CountQuery),
}
88
/// Identity of this crate: its name and the responsibility it owns.
///
/// The unit test `keeps_query_layer_sql_free` asserts that `responsibility` contains
/// "without SQL generation", so rewording it must keep that phrase.
pub const CRATE_IDENTITY: CrateIdentity = CrateIdentity {
    name: "sql-orm-query",
    responsibility: "typed AST and query builder primitives without SQL generation",
};
93
#[cfg(test)]
mod tests {
    // Unit tests for the query AST. The fixtures hand-write entity metadata, column
    // accessors, an `Insertable` and a `Changeset` for two entities (`sales.customers` and
    // `sales.orders`) so the builders can be exercised without any SQL rendering.

    use super::{
        AggregateExpr, AggregateOrderBy, AggregatePredicate, AggregateProjection, AggregateQuery,
        BinaryOp, CRATE_IDENTITY, ColumnRef, CompiledQuery, CountQuery, DeleteQuery, ExistsQuery,
        Expr, InsertQuery, Join, JoinType, OrderBy, Pagination, Predicate, Query, QueryExecution,
        SelectProjection, SelectQuery, SortDirection, SqlFunction, TableRef, UpdateQuery,
    };
    use sql_orm_core::{
        Changeset, ColumnMetadata, ColumnValue, Entity, EntityColumn, EntityMetadata,
        IdentityMetadata, Insertable, PrimaryKeyMetadata, SqlServerType, SqlValue,
    };

    // Marker types: never constructed, only used as type parameters and as hosts for the
    // column constants below, so they would otherwise trip `dead_code`.
    #[allow(dead_code)]
    struct Customer;

    #[allow(dead_code)]
    struct Order;

    // `sales.customers`: identity primary key `id`, plus `email`, `active` (default `1`)
    // and `created_at` (default `SYSUTCDATETIME()`).
    static CUSTOMER_COLUMNS: [ColumnMetadata; 4] = [
        ColumnMetadata {
            rust_field: "id",
            column_name: "id",
            renamed_from: None,
            sql_type: SqlServerType::BigInt,
            nullable: false,
            primary_key: true,
            identity: Some(IdentityMetadata::new(1, 1)),
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: false,
            updatable: false,
            max_length: None,
            precision: None,
            scale: None,
        },
        ColumnMetadata {
            rust_field: "email",
            column_name: "email",
            renamed_from: None,
            sql_type: SqlServerType::NVarChar,
            nullable: false,
            primary_key: false,
            identity: None,
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: true,
            max_length: Some(160),
            precision: None,
            scale: None,
        },
        ColumnMetadata {
            rust_field: "active",
            column_name: "active",
            renamed_from: None,
            sql_type: SqlServerType::Bit,
            nullable: false,
            primary_key: false,
            identity: None,
            default_sql: Some("1"),
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: true,
            max_length: None,
            precision: None,
            scale: None,
        },
        ColumnMetadata {
            rust_field: "created_at",
            column_name: "created_at",
            renamed_from: None,
            sql_type: SqlServerType::DateTime2,
            nullable: false,
            primary_key: false,
            identity: None,
            default_sql: Some("SYSUTCDATETIME()"),
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: true,
            max_length: None,
            precision: None,
            scale: None,
        },
    ];

    static CUSTOMER_METADATA: EntityMetadata = EntityMetadata {
        rust_name: "Customer",
        schema: "sales",
        table: "customers",
        renamed_from: None,
        columns: &CUSTOMER_COLUMNS,
        primary_key: PrimaryKeyMetadata::new(Some("pk_customers"), &["id"]),
        indexes: &[],
        foreign_keys: &[],
        navigations: &[],
    };

    impl Entity for Customer {
        fn metadata() -> &'static EntityMetadata {
            &CUSTOMER_METADATA
        }
    }

    // `sales.orders`: identity primary key `id`, plus `customer_id` and `total_cents`.
    static ORDER_COLUMNS: [ColumnMetadata; 3] = [
        ColumnMetadata {
            rust_field: "id",
            column_name: "id",
            renamed_from: None,
            sql_type: SqlServerType::BigInt,
            nullable: false,
            primary_key: true,
            identity: Some(IdentityMetadata::new(1, 1)),
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: false,
            updatable: false,
            max_length: None,
            precision: None,
            scale: None,
        },
        ColumnMetadata {
            rust_field: "customer_id",
            column_name: "customer_id",
            renamed_from: None,
            sql_type: SqlServerType::BigInt,
            nullable: false,
            primary_key: false,
            identity: None,
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: true,
            max_length: None,
            precision: None,
            scale: None,
        },
        ColumnMetadata {
            rust_field: "total_cents",
            column_name: "total_cents",
            renamed_from: None,
            sql_type: SqlServerType::BigInt,
            nullable: false,
            primary_key: false,
            identity: None,
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: true,
            max_length: None,
            precision: None,
            scale: None,
        },
    ];

    static ORDER_METADATA: EntityMetadata = EntityMetadata {
        rust_name: "Order",
        schema: "sales",
        table: "orders",
        renamed_from: None,
        columns: &ORDER_COLUMNS,
        primary_key: PrimaryKeyMetadata::new(Some("pk_orders"), &["id"]),
        indexes: &[],
        foreign_keys: &[],
        navigations: &[],
    };

    impl Entity for Order {
        fn metadata() -> &'static EntityMetadata {
            &ORDER_METADATA
        }
    }

    // Column constants are named after the Rust field (snake_case) so call sites read as
    // `Customer::email`; that naming needs `non_upper_case_globals`.
    #[allow(non_upper_case_globals)]
    impl Customer {
        const id: EntityColumn<Customer> = EntityColumn::new("id", "id");
        const email: EntityColumn<Customer> = EntityColumn::new("email", "email");
        const active: EntityColumn<Customer> = EntityColumn::new("active", "active");
        const created_at: EntityColumn<Customer> = EntityColumn::new("created_at", "created_at");
    }

    #[allow(non_upper_case_globals)]
    impl Order {
        const customer_id: EntityColumn<Order> = EntityColumn::new("customer_id", "customer_id");
        const total_cents: EntityColumn<Order> = EntityColumn::new("total_cents", "total_cents");
    }

    // Insert payload: supplies only `email` and `active`. It omits `id` (identity,
    // `insertable: false`) and `created_at` (has a `default_sql`).
    struct NewCustomer {
        email: String,
        active: bool,
    }

    impl Insertable<Customer> for NewCustomer {
        fn values(&self) -> Vec<ColumnValue> {
            vec![
                ColumnValue::new("email", SqlValue::String(self.email.clone())),
                ColumnValue::new("active", SqlValue::Bool(self.active)),
            ]
        }
    }

    // Partial update payload: yields no changes at all when `email` is `None`.
    struct UpdateCustomer {
        email: Option<String>,
    }

    impl Changeset<Customer> for UpdateCustomer {
        fn changes(&self) -> Vec<ColumnValue> {
            self.email
                .clone()
                .map(|email| vec![ColumnValue::new("email", SqlValue::String(email))])
                .unwrap_or_default()
        }
    }

    #[test]
    fn keeps_query_layer_sql_free() {
        assert!(
            CRATE_IDENTITY
                .responsibility
                .contains("without SQL generation")
        );
    }

    #[test]
    fn entity_columns_become_table_aware_column_refs() {
        let column = ColumnRef::for_entity_column(Customer::email);

        assert_eq!(column.table, TableRef::new("sales", "customers"));
        assert_eq!(column.rust_field, "email");
        assert_eq!(column.column_name, "email");
    }

    #[test]
    fn expr_supports_columns_values_functions_and_operations() {
        let expr = Expr::binary(
            Expr::function(SqlFunction::Lower, vec![Expr::from(Customer::email)]),
            BinaryOp::Add,
            Expr::value(SqlValue::String("@example.com".to_string())),
        );

        match expr {
            Expr::Binary { left, op, right } => {
                assert_eq!(op, BinaryOp::Add);
                assert!(matches!(*left, Expr::Function { .. }));
                assert_eq!(
                    *right,
                    Expr::Value(SqlValue::String("@example.com".to_string()))
                );
            }
            other => panic!("unexpected expr shape: {other:?}"),
        }
    }

    #[test]
    fn predicates_can_be_composed_without_sql_rendering() {
        let predicate = Predicate::and(vec![
            Predicate::eq(
                Expr::from(Customer::active),
                Expr::value(SqlValue::Bool(true)),
            ),
            Predicate::like(
                Expr::from(Customer::email),
                Expr::value(SqlValue::String("%@example.com".to_string())),
            ),
            Predicate::like_escaped(
                Expr::from(Customer::email),
                Expr::value(SqlValue::String(r"%literal\%%".to_string())),
                '\\',
            ),
        ]);

        match predicate {
            Predicate::And(parts) => assert_eq!(parts.len(), 3),
            other => panic!("unexpected predicate shape: {other:?}"),
        }
    }

    #[test]
    fn select_query_captures_projection_filters_order_and_pagination() {
        let query = SelectQuery::from_entity::<Customer>()
            .select(vec![Expr::from(Customer::id), Expr::from(Customer::email)])
            .filter(Predicate::eq(
                Expr::from(Customer::active),
                Expr::value(SqlValue::Bool(true)),
            ))
            .filter(Predicate::like(
                Expr::from(Customer::email),
                Expr::value(SqlValue::String("%@example.com".to_string())),
            ))
            .order_by(OrderBy::desc(Customer::created_at))
            .paginate(Pagination::page(2, 20));

        assert_eq!(query.from, TableRef::new("sales", "customers"));
        assert!(query.joins.is_empty());
        assert_eq!(
            query.projection,
            vec![
                SelectProjection::column(Customer::id),
                SelectProjection::column(Customer::email)
            ]
        );
        assert_eq!(
            query.order_by,
            vec![OrderBy::new(
                TableRef::new("sales", "customers"),
                "created_at",
                SortDirection::Desc,
            )]
        );
        // Page 2 of size 20 is expected to normalize to `Pagination::new(20, 20)`.
        assert_eq!(query.pagination, Some(Pagination::new(20, 20)));
        // Two successive `filter` calls are expected to combine into a single `And`.
        assert!(matches!(query.predicate, Some(Predicate::And(_))));
    }

    #[test]
    fn select_query_captures_explicit_joins_without_sql_rendering() {
        // The second join reuses `sales.orders` without an alias; this test only checks
        // that the AST records both joins in order.
        let query = SelectQuery::from_entity::<Customer>()
            .inner_join::<Order>(Predicate::eq(
                Expr::from(Customer::id),
                Expr::from(Order::customer_id),
            ))
            .join(Join::left(
                TableRef::new("sales", "orders"),
                Predicate::gt(
                    Expr::from(Order::total_cents),
                    Expr::value(SqlValue::I64(0)),
                ),
            ));

        assert_eq!(query.joins.len(), 2);
        assert_eq!(query.joins[0].join_type, JoinType::Inner);
        assert_eq!(query.joins[0].table, TableRef::new("sales", "orders"));
        assert!(matches!(query.joins[0].on, Predicate::Eq(_, _)));
        assert_eq!(query.joins[1].join_type, JoinType::Left);
        assert_eq!(query.joins[1].table, TableRef::new("sales", "orders"));
        assert!(matches!(query.joins[1].on, Predicate::Gt(_, _)));
    }

    #[test]
    fn table_refs_capture_optional_aliases_without_sql_rendering() {
        let table = TableRef::for_entity_as::<Customer>("root");
        let column = ColumnRef::for_entity_column_as(Customer::email, "root");
        let expr = Expr::column_as(Customer::id, "root");

        assert_eq!(table.schema, "sales");
        assert_eq!(table.table, "customers");
        assert_eq!(table.alias, Some("root"));
        assert_eq!(table.reference_name(), "root");
        assert_eq!(table.without_alias(), TableRef::new("sales", "customers"));
        assert_eq!(column.table, table);

        match expr {
            Expr::Column(column) => {
                assert_eq!(column.table.alias, Some("root"));
                assert_eq!(column.column_name, "id");
            }
            other => panic!("unexpected expr shape: {other:?}"),
        }
    }

    #[test]
    fn select_query_captures_aliased_sources_and_repeated_joins() {
        let query = SelectQuery::from_entity_as::<Customer>("c")
            .inner_join_as::<Order>(
                "created_orders",
                Predicate::eq(
                    Expr::column_as(Customer::id, "c"),
                    Expr::column_as(Order::customer_id, "created_orders"),
                ),
            )
            .left_join_as::<Order>(
                "completed_orders",
                Predicate::gt(
                    Expr::column_as(Order::total_cents, "completed_orders"),
                    Expr::value(SqlValue::I64(0)),
                ),
            );

        assert_eq!(query.from, TableRef::with_alias("sales", "customers", "c"));
        assert_eq!(query.joins.len(), 2);
        assert_eq!(
            query.joins[0].table,
            TableRef::with_alias("sales", "orders", "created_orders")
        );
        assert_eq!(
            query.joins[1].table,
            TableRef::with_alias("sales", "orders", "completed_orders")
        );
        // Same underlying table, distinguished only by alias.
        assert_ne!(query.joins[0].table, query.joins[1].table);
    }

    #[test]
    fn select_projection_captures_default_and_explicit_aliases() {
        let column_projection = SelectProjection::column(Customer::email);
        assert_eq!(column_projection.alias.as_deref(), Some("email"));
        assert_eq!(column_projection.expr, Expr::from(Customer::email));

        let expression_projection = SelectProjection::expr_as(
            Expr::function(SqlFunction::Lower, vec![Expr::from(Customer::email)]),
            "email_lower",
        );
        assert_eq!(expression_projection.alias.as_deref(), Some("email_lower"));

        let owned_alias_projection =
            SelectProjection::expr_as(Expr::from(Customer::email), "owned_email".to_string());
        assert_eq!(owned_alias_projection.alias.as_deref(), Some("owned_email"));

        let unaliased_expression = SelectProjection::expr(Expr::function(
            SqlFunction::Lower,
            vec![Expr::from(Customer::email)],
        ));
        assert_eq!(unaliased_expression.alias.as_deref(), None);
    }

    #[test]
    fn aggregate_projection_requires_alias_without_changing_select_projection() {
        let group_key = AggregateProjection::group_key(Order::customer_id);
        let expression_group_key = AggregateProjection::group_key_as(
            Expr::function(SqlFunction::Year, vec![Expr::from(Customer::created_at)]),
            "created_year",
        );
        let aggregate = AggregateProjection::sum_as(Order::total_cents, "total_cents");

        assert_eq!(group_key.alias, "customer_id");
        assert_eq!(expression_group_key.alias, "created_year");
        assert_eq!(aggregate.alias, "total_cents");
        assert_eq!(
            aggregate,
            AggregateProjection::expr_as(
                AggregateExpr::Sum(Expr::from(Order::total_cents)),
                "total_cents"
            )
        );

        let ordinary_projection = SelectProjection::expr(Expr::function(
            SqlFunction::Lower,
            vec![Expr::from(Customer::email)],
        ));
        assert_eq!(ordinary_projection.alias.as_deref(), None);
    }

    #[test]
    fn aggregate_query_captures_grouping_having_and_projection_without_sql_rendering() {
        let query = AggregateQuery::from_entity::<Order>()
            .inner_join::<Customer>(Predicate::eq(
                Expr::from(Order::customer_id),
                Expr::from(Customer::id),
            ))
            .filter(Predicate::eq(
                Expr::from(Customer::active),
                Expr::value(SqlValue::Bool(true)),
            ))
            .group_by(vec![Expr::from(Order::customer_id)])
            .project(vec![
                AggregateProjection::group_key(Order::customer_id),
                AggregateProjection::count_as("order_count"),
                AggregateProjection::sum_as(Order::total_cents, "total_cents"),
            ])
            .having(AggregatePredicate::gt(
                AggregateExpr::count_all(),
                Expr::value(SqlValue::I64(1)),
            ))
            .order_by(AggregateOrderBy::desc(AggregateExpr::sum(Expr::from(
                Order::total_cents,
            ))))
            .paginate(Pagination::page(1, 10));

        assert_eq!(query.from, TableRef::new("sales", "orders"));
        assert_eq!(query.joins.len(), 1);
        assert!(query.predicate.is_some());
        assert_eq!(query.group_by, vec![Expr::from(Order::customer_id)]);
        assert_eq!(
            query.projection,
            vec![
                AggregateProjection::group_key(Order::customer_id),
                AggregateProjection::count_as("order_count"),
                AggregateProjection::sum_as(Order::total_cents, "total_cents")
            ]
        );
        assert!(matches!(query.having, Some(AggregatePredicate::Gt(_, _))));
        assert_eq!(
            query.order_by,
            vec![AggregateOrderBy::desc(AggregateExpr::sum(Expr::from(
                Order::total_cents
            )))]
        );
        // Page 1 of size 10 is expected to normalize to `Pagination::new(0, 10)`.
        assert_eq!(query.pagination, Some(Pagination::new(0, 10)));
        assert!(matches!(
            Query::Aggregate(Box::new(query)),
            Query::Aggregate(_)
        ));
    }

    #[test]
    fn insert_update_delete_and_count_queries_capture_operation_data() {
        let insert = InsertQuery::for_entity::<Customer, _>(&NewCustomer {
            email: "ana@example.com".to_string(),
            active: true,
        });
        let update = UpdateQuery::for_entity::<Customer, _>(&UpdateCustomer {
            email: Some("ana.maria@example.com".to_string()),
        })
        .filter(Predicate::eq(
            Expr::from(Customer::id),
            Expr::value(SqlValue::I64(7)),
        ));
        let delete = DeleteQuery::from_entity::<Customer>().filter(Predicate::eq(
            Expr::from(Customer::id),
            Expr::value(SqlValue::I64(7)),
        ));
        let count = CountQuery::from_entity::<Customer>().filter(Predicate::eq(
            Expr::from(Customer::active),
            Expr::value(SqlValue::Bool(true)),
        ));
        let exists = ExistsQuery::from_entity::<Customer>()
            .inner_join::<Order>(Predicate::eq(
                Expr::from(Customer::id),
                Expr::from(Order::customer_id),
            ))
            .filter(Predicate::eq(
                Expr::from(Customer::active),
                Expr::value(SqlValue::Bool(true)),
            ));

        assert_eq!(insert.into, TableRef::new("sales", "customers"));
        assert_eq!(insert.values.len(), 2);
        assert_eq!(insert.entity, Some(Customer::metadata()));
        assert_eq!(update.table, TableRef::new("sales", "customers"));
        assert_eq!(update.changes.len(), 1);
        assert!(update.predicate.is_some());
        // Filtered updates/deletes must not be flagged as whole-table operations; the flag
        // is only set by an explicit `allow_all_rows()` call.
        assert!(!update.allow_all_rows);
        assert_eq!(update.entity, Some(Customer::metadata()));
        assert!(
            UpdateQuery::for_entity::<Customer, _>(&UpdateCustomer {
                email: Some("ana.maria@example.com".to_string()),
            })
            .allow_all_rows()
            .allow_all_rows
        );
        assert_eq!(delete.from, TableRef::new("sales", "customers"));
        assert!(delete.predicate.is_some());
        assert!(!delete.allow_all_rows);
        assert!(
            DeleteQuery::from_entity::<Customer>()
                .allow_all_rows()
                .allow_all_rows
        );
        assert_eq!(count.from, TableRef::new("sales", "customers"));
        assert!(count.predicate.is_some());
        assert_eq!(exists.from, TableRef::new("sales", "customers"));
        assert_eq!(exists.joins.len(), 1);
        assert!(exists.predicate.is_some());

        assert!(matches!(Query::Insert(insert.clone()), Query::Insert(_)));
        assert!(matches!(Query::Update(update.clone()), Query::Update(_)));
        assert!(matches!(Query::Delete(delete.clone()), Query::Delete(_)));
        assert!(matches!(Query::Count(count.clone()), Query::Count(_)));
        assert!(matches!(Query::Exists(Box::new(exists)), Query::Exists(_)));
    }

    #[test]
    fn compiled_query_keeps_sql_and_parameter_order() {
        let compiled = CompiledQuery::new(
            "SELECT [id] FROM [sales].[customers] WHERE [active] = @P1 AND [email] LIKE @P2",
            vec![
                SqlValue::Bool(true),
                SqlValue::String("%@example.com".to_string()),
            ],
        );

        assert_eq!(
            compiled.sql,
            "SELECT [id] FROM [sales].[customers] WHERE [active] = @P1 AND [email] LIKE @P2"
        );
        assert_eq!(
            compiled.params,
            vec![
                SqlValue::Bool(true),
                SqlValue::String("%@example.com".to_string()),
            ]
        );
    }

    #[test]
    fn compiled_query_constructors_tag_execution_kind() {
        // Fresh params per call: each constructor takes ownership of its `Vec`.
        let params = || vec![SqlValue::I64(1)];

        // `new` is the untyped default and must match `raw_no_retry`.
        assert_eq!(
            CompiledQuery::new("SELECT @P1", params()).execution,
            QueryExecution::RawNoRetry
        );
        assert_eq!(
            CompiledQuery::new("SELECT @P1", params()),
            CompiledQuery::raw_no_retry("SELECT @P1", params())
        );
        assert_eq!(
            CompiledQuery::read_only("SELECT @P1", params()).execution,
            QueryExecution::ReadOnly
        );
        assert_eq!(
            CompiledQuery::write("DELETE FROM [sales].[orders] WHERE [id] = @P1", params())
                .execution,
            QueryExecution::Write
        );
        assert_eq!(
            CompiledQuery::migration("ALTER TABLE [sales].[orders] ADD [note] INT", Vec::new())
                .execution,
            QueryExecution::Migration
        );

        // `with_execution` is the general form the named constructors build on.
        let explicit =
            CompiledQuery::with_execution("SELECT @P1", params(), QueryExecution::ReadOnly);
        assert_eq!(explicit, CompiledQuery::read_only("SELECT @P1", params()));
        assert_eq!(explicit.sql, "SELECT @P1");
        assert_eq!(explicit.params, params());
    }
}