1use crate::quoting::{
2 quote_column_ref, quote_identifier, quote_table_ref, quote_table_reference, quote_table_source,
3};
4use sql_orm_core::{ColumnValue, OrmError, SqlValue};
5use sql_orm_query::{
6 AggregateExpr, AggregateOrderBy, AggregatePredicate, AggregateProjection, AggregateQuery,
7 BinaryOp, CompiledQuery, CountQuery, DeleteQuery, ExistsQuery, Expr, InsertQuery, Join,
8 JoinType, OrderBy, Pagination, Predicate, Query, SelectProjection, SelectQuery, SortDirection,
9 TableRef, UnaryOp, UpdateQuery,
10};
11use std::collections::BTreeSet;
12
13#[derive(Debug, Default)]
14struct ParameterBuilder {
15 params: Vec<SqlValue>,
16}
17
18impl ParameterBuilder {
19 fn push(&mut self, value: SqlValue) -> String {
20 self.params.push(value);
21 format!("@P{}", self.params.len())
22 }
23
24 fn finish(self, sql: String) -> CompiledQuery {
25 CompiledQuery::new(sql, self.params)
26 }
27}
28
29impl crate::SqlServerCompiler {
30 pub fn compile_query(query: &Query) -> Result<CompiledQuery, OrmError> {
31 match query {
32 Query::Select(query) => Self::compile_select(query),
33 Query::Aggregate(query) => Self::compile_aggregate(query),
34 Query::Exists(query) => Self::compile_exists(query),
35 Query::Insert(query) => Self::compile_insert(query),
36 Query::Update(query) => Self::compile_update(query),
37 Query::Delete(query) => Self::compile_delete(query),
38 Query::Count(query) => Self::compile_count(query),
39 }
40 }
41
42 pub fn compile_select(query: &SelectQuery) -> Result<CompiledQuery, OrmError> {
43 let mut parameters = ParameterBuilder::default();
44 let projection = compile_projection(&query.projection, &mut parameters)?;
45 let mut sql = format!(
46 "SELECT {projection} FROM {}",
47 quote_table_source(&query.from)?
48 );
49 sql.push_str(&compile_joins(&query.from, &query.joins, &mut parameters)?);
50
51 if let Some(predicate) = &query.predicate {
52 let predicate = compile_predicate(predicate, &mut parameters)?;
53 sql.push_str(" WHERE ");
54 sql.push_str(&predicate);
55 }
56
57 if !query.order_by.is_empty() {
58 sql.push_str(" ORDER BY ");
59 sql.push_str(&compile_order_by(&query.order_by)?);
60 }
61
62 if let Some(pagination) = query.pagination {
63 if query.order_by.is_empty() {
64 return Err(OrmError::new(
65 "SQL Server pagination requires ORDER BY before OFFSET/FETCH",
66 ));
67 }
68
69 sql.push(' ');
70 sql.push_str(&compile_pagination(pagination, &mut parameters));
71 }
72
73 Ok(parameters.finish(sql))
74 }
75
76 pub fn compile_insert(query: &InsertQuery) -> Result<CompiledQuery, OrmError> {
77 if query.values.is_empty() {
78 return Err(OrmError::new(
79 "SQL Server insert compilation requires at least one value",
80 ));
81 }
82
83 let mut parameters = ParameterBuilder::default();
84 let (columns, values) = compile_column_values(&query.values, &mut parameters)?;
85 let sql = format!(
86 "INSERT INTO {} ({columns}) OUTPUT INSERTED.* VALUES ({values})",
87 quote_table_ref(&query.into)?,
88 );
89
90 Ok(parameters.finish(sql))
91 }
92
93 pub fn compile_update(query: &UpdateQuery) -> Result<CompiledQuery, OrmError> {
94 if query.changes.is_empty() {
95 return Err(OrmError::new(
96 "SQL Server update compilation requires at least one change",
97 ));
98 }
99
100 let mut parameters = ParameterBuilder::default();
101 let assignments = compile_assignments(&query.changes, &mut parameters)?;
102 let mut sql = format!(
103 "UPDATE {} SET {assignments} OUTPUT INSERTED.*",
104 quote_table_ref(&query.table)?,
105 );
106
107 if let Some(predicate) = &query.predicate {
108 let predicate = compile_predicate(predicate, &mut parameters)?;
109 sql.push_str(" WHERE ");
110 sql.push_str(&predicate);
111 }
112
113 Ok(parameters.finish(sql))
114 }
115
116 pub fn compile_delete(query: &DeleteQuery) -> Result<CompiledQuery, OrmError> {
117 let mut parameters = ParameterBuilder::default();
118 let mut sql = format!("DELETE FROM {}", quote_table_ref(&query.from)?);
119
120 if let Some(predicate) = &query.predicate {
121 let predicate = compile_predicate(predicate, &mut parameters)?;
122 sql.push_str(" WHERE ");
123 sql.push_str(&predicate);
124 }
125
126 Ok(parameters.finish(sql))
127 }
128
129 pub fn compile_count(query: &CountQuery) -> Result<CompiledQuery, OrmError> {
130 let mut parameters = ParameterBuilder::default();
131 let mut sql = format!(
132 "SELECT COUNT(*) AS {} FROM {}",
133 quote_identifier("count")?,
134 quote_table_source(&query.from)?,
135 );
136
137 if let Some(predicate) = &query.predicate {
138 let predicate = compile_predicate(predicate, &mut parameters)?;
139 sql.push_str(" WHERE ");
140 sql.push_str(&predicate);
141 }
142
143 Ok(parameters.finish(sql))
144 }
145
146 pub fn compile_exists(query: &ExistsQuery) -> Result<CompiledQuery, OrmError> {
147 let mut parameters = ParameterBuilder::default();
148 let mut subquery = format!("SELECT 1 FROM {}", quote_table_source(&query.from)?);
149 subquery.push_str(&compile_joins(&query.from, &query.joins, &mut parameters)?);
150
151 if let Some(predicate) = &query.predicate {
152 let predicate = compile_predicate(predicate, &mut parameters)?;
153 subquery.push_str(" WHERE ");
154 subquery.push_str(&predicate);
155 }
156
157 let sql = format!(
158 "SELECT CASE WHEN EXISTS ({subquery}) THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END AS {}",
159 quote_identifier("exists")?
160 );
161
162 Ok(parameters.finish(sql))
163 }
164
165 pub fn compile_aggregate(query: &AggregateQuery) -> Result<CompiledQuery, OrmError> {
166 validate_aggregate_query(query)?;
167
168 let mut parameters = ParameterBuilder::default();
169 let projection =
170 compile_aggregate_projection(&query.projection, &query.group_by, &mut parameters)?;
171 let mut sql = format!(
172 "SELECT {projection} FROM {}",
173 quote_table_source(&query.from)?
174 );
175 sql.push_str(&compile_joins(&query.from, &query.joins, &mut parameters)?);
176
177 if let Some(predicate) = &query.predicate {
178 let predicate = compile_predicate(predicate, &mut parameters)?;
179 sql.push_str(" WHERE ");
180 sql.push_str(&predicate);
181 }
182
183 if !query.group_by.is_empty() {
184 sql.push_str(" GROUP BY ");
185 sql.push_str(&compile_group_by(&query.group_by, &mut parameters)?);
186 }
187
188 if let Some(having) = &query.having {
189 let having = compile_aggregate_predicate(having, &query.group_by, &mut parameters)?;
190 sql.push_str(" HAVING ");
191 sql.push_str(&having);
192 }
193
194 if !query.order_by.is_empty() {
195 sql.push_str(" ORDER BY ");
196 sql.push_str(&compile_aggregate_order_by(
197 &query.order_by,
198 &query.group_by,
199 &mut parameters,
200 )?);
201 }
202
203 if let Some(pagination) = query.pagination {
204 if query.order_by.is_empty() {
205 return Err(OrmError::new(
206 "SQL Server aggregate pagination requires ORDER BY before OFFSET/FETCH",
207 ));
208 }
209
210 sql.push(' ');
211 sql.push_str(&compile_pagination(pagination, &mut parameters));
212 }
213
214 Ok(parameters.finish(sql))
215 }
216}
217
218fn validate_aggregate_query(query: &AggregateQuery) -> Result<(), OrmError> {
219 if query.projection.is_empty() {
220 return Err(OrmError::new(
221 "SQL Server aggregate query compilation requires at least one projection",
222 ));
223 }
224
225 validate_aggregate_projection(&query.projection, &query.group_by)?;
226
227 if let Some(having) = &query.having {
228 validate_aggregate_predicate(having, &query.group_by)?;
229 }
230
231 for order in &query.order_by {
232 validate_aggregate_expr(&order.expr, &query.group_by)?;
233 }
234
235 Ok(())
236}
237
238fn compile_joins(
239 from: &TableRef,
240 joins: &[Join],
241 parameters: &mut ParameterBuilder,
242) -> Result<String, OrmError> {
243 let mut compiled = String::new();
244 let mut seen_tables = vec![*from];
245
246 for join in joins {
247 if seen_tables.contains(&join.table) {
248 return Err(OrmError::new(
249 "SQL Server join compilation requires aliases for repeated table sources",
250 ));
251 }
252
253 seen_tables.push(join.table);
254 compiled.push(' ');
255 compiled.push_str(match join.join_type {
256 JoinType::Inner => "INNER JOIN ",
257 JoinType::Left => "LEFT JOIN ",
258 });
259 compiled.push_str("e_table_source(&join.table)?);
260 compiled.push_str(" ON ");
261 compiled.push_str(&compile_predicate(&join.on, parameters)?);
262 }
263
264 Ok(compiled)
265}
266
267fn compile_projection(
268 projection: &[SelectProjection],
269 parameters: &mut ParameterBuilder,
270) -> Result<String, OrmError> {
271 if projection.is_empty() {
272 return Ok("*".to_string());
273 }
274
275 let mut aliases = BTreeSet::new();
276 let parts = projection
277 .iter()
278 .map(|projection| {
279 let alias = projection.alias.ok_or_else(|| {
280 OrmError::new("SQL Server projection expressions require an explicit alias")
281 })?;
282 if alias.trim().is_empty() {
283 return Err(OrmError::new("SQL Server projection alias cannot be empty"));
284 }
285 if !aliases.insert(alias) {
286 return Err(OrmError::new(format!(
287 "SQL Server projection alias `{alias}` is duplicated"
288 )));
289 }
290
291 Ok(format!(
292 "{} AS {}",
293 compile_expr(&projection.expr, parameters)?,
294 quote_identifier(alias)?
295 ))
296 })
297 .collect::<Result<Vec<_>, _>>()?;
298 Ok(parts.join(", "))
299}
300
301fn compile_aggregate_projection(
302 projection: &[AggregateProjection],
303 group_by: &[Expr],
304 parameters: &mut ParameterBuilder,
305) -> Result<String, OrmError> {
306 let mut aliases = BTreeSet::new();
307 let parts = projection
308 .iter()
309 .map(|projection| {
310 if projection.alias.trim().is_empty() {
311 return Err(OrmError::new(
312 "SQL Server aggregate projection alias cannot be empty",
313 ));
314 }
315 if !aliases.insert(projection.alias) {
316 return Err(OrmError::new(format!(
317 "SQL Server aggregate projection alias `{}` is duplicated",
318 projection.alias
319 )));
320 }
321
322 Ok(format!(
323 "{} AS {}",
324 compile_aggregate_expr(&projection.expr, group_by, parameters)?,
325 quote_identifier(projection.alias)?
326 ))
327 })
328 .collect::<Result<Vec<_>, OrmError>>()?;
329 Ok(parts.join(", "))
330}
331
332fn validate_aggregate_projection(
333 projection: &[AggregateProjection],
334 group_by: &[Expr],
335) -> Result<(), OrmError> {
336 let mut aliases = BTreeSet::new();
337
338 for projection in projection {
339 if projection.alias.trim().is_empty() {
340 return Err(OrmError::new(
341 "SQL Server aggregate projection alias cannot be empty",
342 ));
343 }
344 if !aliases.insert(projection.alias) {
345 return Err(OrmError::new(format!(
346 "SQL Server aggregate projection alias `{}` is duplicated",
347 projection.alias
348 )));
349 }
350
351 validate_aggregate_expr(&projection.expr, group_by)?;
352 }
353
354 Ok(())
355}
356
357fn validate_aggregate_expr(expr: &AggregateExpr, group_by: &[Expr]) -> Result<(), OrmError> {
358 match expr {
359 AggregateExpr::GroupKey(expr) => validate_group_key(expr, group_by),
360 AggregateExpr::CountAll
361 | AggregateExpr::Count(_)
362 | AggregateExpr::Sum(_)
363 | AggregateExpr::Avg(_)
364 | AggregateExpr::Min(_)
365 | AggregateExpr::Max(_) => Ok(()),
366 }
367}
368
369fn validate_aggregate_predicate(
370 predicate: &AggregatePredicate,
371 group_by: &[Expr],
372) -> Result<(), OrmError> {
373 match predicate {
374 AggregatePredicate::Eq(left, right)
375 | AggregatePredicate::Ne(left, right)
376 | AggregatePredicate::Gt(left, right)
377 | AggregatePredicate::Gte(left, right)
378 | AggregatePredicate::Lt(left, right)
379 | AggregatePredicate::Lte(left, right) => {
380 validate_aggregate_expr(left, group_by)?;
381 validate_non_aggregate_expr_in_grouped_context(right, group_by)
382 }
383 AggregatePredicate::And(predicates) | AggregatePredicate::Or(predicates) => {
384 if predicates.is_empty() {
385 return Err(OrmError::new(
386 "aggregate logical predicate compilation requires at least one child predicate",
387 ));
388 }
389
390 for predicate in predicates {
391 validate_aggregate_predicate(predicate, group_by)?;
392 }
393 Ok(())
394 }
395 AggregatePredicate::Not(predicate) => validate_aggregate_predicate(predicate, group_by),
396 }
397}
398
399fn validate_non_aggregate_expr_in_grouped_context(
400 expr: &Expr,
401 group_by: &[Expr],
402) -> Result<(), OrmError> {
403 match expr {
404 Expr::Column(_) => validate_group_key(expr, group_by),
405 Expr::Value(_) => Ok(()),
406 Expr::Binary { left, right, .. } => {
407 validate_non_aggregate_expr_in_grouped_context(left, group_by)?;
408 validate_non_aggregate_expr_in_grouped_context(right, group_by)
409 }
410 Expr::Unary { expr, .. } => validate_non_aggregate_expr_in_grouped_context(expr, group_by),
411 Expr::Function { args, .. } => {
412 for arg in args {
413 validate_non_aggregate_expr_in_grouped_context(arg, group_by)?;
414 }
415 Ok(())
416 }
417 }
418}
419
420fn compile_group_by(
421 group_by: &[Expr],
422 parameters: &mut ParameterBuilder,
423) -> Result<String, OrmError> {
424 let parts = group_by
425 .iter()
426 .map(|expr| compile_expr(expr, parameters))
427 .collect::<Result<Vec<_>, _>>()?;
428 Ok(parts.join(", "))
429}
430
431fn compile_aggregate_expr(
432 expr: &AggregateExpr,
433 group_by: &[Expr],
434 parameters: &mut ParameterBuilder,
435) -> Result<String, OrmError> {
436 match expr {
437 AggregateExpr::GroupKey(expr) => {
438 validate_group_key(expr, group_by)?;
439 compile_expr(expr, parameters)
440 }
441 AggregateExpr::CountAll => Ok("COUNT(*)".to_string()),
442 AggregateExpr::Count(expr) => Ok(format!("COUNT({})", compile_expr(expr, parameters)?)),
443 AggregateExpr::Sum(expr) => Ok(format!("SUM({})", compile_expr(expr, parameters)?)),
444 AggregateExpr::Avg(expr) => Ok(format!("AVG({})", compile_expr(expr, parameters)?)),
445 AggregateExpr::Min(expr) => Ok(format!("MIN({})", compile_expr(expr, parameters)?)),
446 AggregateExpr::Max(expr) => Ok(format!("MAX({})", compile_expr(expr, parameters)?)),
447 }
448}
449
450fn validate_group_key(expr: &Expr, group_by: &[Expr]) -> Result<(), OrmError> {
451 if group_by.iter().any(|group_key| group_key == expr) {
452 return Ok(());
453 }
454
455 Err(OrmError::new(
456 "SQL Server aggregate group key projection must appear in GROUP BY",
457 ))
458}
459
460fn compile_expr(expr: &Expr, parameters: &mut ParameterBuilder) -> Result<String, OrmError> {
461 match expr {
462 Expr::Column(column) => quote_column_ref(column),
463 Expr::Value(value) => Ok(parameters.push(value.clone())),
464 Expr::Binary { left, op, right } => Ok(format!(
465 "({} {} {})",
466 compile_expr(left, parameters)?,
467 compile_binary_op(*op),
468 compile_expr(right, parameters)?,
469 )),
470 Expr::Unary { op, expr } => Ok(format!(
471 "({} {})",
472 compile_unary_op(*op),
473 compile_expr(expr, parameters)?,
474 )),
475 Expr::Function { name, args } => {
476 if name.trim().is_empty() {
477 return Err(OrmError::new("SQL function name cannot be empty"));
478 }
479
480 let args = args
481 .iter()
482 .map(|arg| compile_expr(arg, parameters))
483 .collect::<Result<Vec<_>, _>>()?;
484
485 Ok(format!("{name}({})", args.join(", ")))
486 }
487 }
488}
489
490fn compile_predicate(
491 predicate: &Predicate,
492 parameters: &mut ParameterBuilder,
493) -> Result<String, OrmError> {
494 match predicate {
495 Predicate::Eq(left, right) => compile_comparison(left, "=", right, parameters),
496 Predicate::Ne(left, right) => compile_comparison(left, "<>", right, parameters),
497 Predicate::Gt(left, right) => compile_comparison(left, ">", right, parameters),
498 Predicate::Gte(left, right) => compile_comparison(left, ">=", right, parameters),
499 Predicate::Lt(left, right) => compile_comparison(left, "<", right, parameters),
500 Predicate::Lte(left, right) => compile_comparison(left, "<=", right, parameters),
501 Predicate::Like(left, right) => compile_comparison(left, "LIKE", right, parameters),
502 Predicate::IsNull(expr) => Ok(format!("({} IS NULL)", compile_expr(expr, parameters)?)),
503 Predicate::IsNotNull(expr) => {
504 Ok(format!("({} IS NOT NULL)", compile_expr(expr, parameters)?))
505 }
506 Predicate::And(predicates) => compile_logical("AND", predicates, parameters),
507 Predicate::Or(predicates) => compile_logical("OR", predicates, parameters),
508 Predicate::Not(predicate) => Ok(format!(
509 "(NOT {})",
510 compile_predicate(predicate, parameters)?
511 )),
512 }
513}
514
515fn compile_aggregate_predicate(
516 predicate: &AggregatePredicate,
517 group_by: &[Expr],
518 parameters: &mut ParameterBuilder,
519) -> Result<String, OrmError> {
520 match predicate {
521 AggregatePredicate::Eq(left, right) => {
522 compile_aggregate_comparison(left, "=", right, group_by, parameters)
523 }
524 AggregatePredicate::Ne(left, right) => {
525 compile_aggregate_comparison(left, "<>", right, group_by, parameters)
526 }
527 AggregatePredicate::Gt(left, right) => {
528 compile_aggregate_comparison(left, ">", right, group_by, parameters)
529 }
530 AggregatePredicate::Gte(left, right) => {
531 compile_aggregate_comparison(left, ">=", right, group_by, parameters)
532 }
533 AggregatePredicate::Lt(left, right) => {
534 compile_aggregate_comparison(left, "<", right, group_by, parameters)
535 }
536 AggregatePredicate::Lte(left, right) => {
537 compile_aggregate_comparison(left, "<=", right, group_by, parameters)
538 }
539 AggregatePredicate::And(predicates) => {
540 compile_aggregate_logical("AND", predicates, group_by, parameters)
541 }
542 AggregatePredicate::Or(predicates) => {
543 compile_aggregate_logical("OR", predicates, group_by, parameters)
544 }
545 AggregatePredicate::Not(predicate) => Ok(format!(
546 "(NOT {})",
547 compile_aggregate_predicate(predicate, group_by, parameters)?
548 )),
549 }
550}
551
552fn compile_aggregate_comparison(
553 left: &AggregateExpr,
554 operator: &str,
555 right: &Expr,
556 group_by: &[Expr],
557 parameters: &mut ParameterBuilder,
558) -> Result<String, OrmError> {
559 Ok(format!(
560 "({} {operator} {})",
561 compile_aggregate_expr(left, group_by, parameters)?,
562 compile_expr(right, parameters)?,
563 ))
564}
565
566fn compile_aggregate_logical(
567 operator: &str,
568 predicates: &[AggregatePredicate],
569 group_by: &[Expr],
570 parameters: &mut ParameterBuilder,
571) -> Result<String, OrmError> {
572 if predicates.is_empty() {
573 return Err(OrmError::new(
574 "aggregate logical predicate compilation requires at least one child predicate",
575 ));
576 }
577
578 let compiled = predicates
579 .iter()
580 .map(|predicate| compile_aggregate_predicate(predicate, group_by, parameters))
581 .collect::<Result<Vec<_>, _>>()?;
582
583 Ok(format!("({})", compiled.join(&format!(" {operator} "))))
584}
585
586fn compile_comparison(
587 left: &Expr,
588 operator: &str,
589 right: &Expr,
590 parameters: &mut ParameterBuilder,
591) -> Result<String, OrmError> {
592 Ok(format!(
593 "({} {operator} {})",
594 compile_expr(left, parameters)?,
595 compile_expr(right, parameters)?,
596 ))
597}
598
599fn compile_logical(
600 operator: &str,
601 predicates: &[Predicate],
602 parameters: &mut ParameterBuilder,
603) -> Result<String, OrmError> {
604 if predicates.is_empty() {
605 return Err(OrmError::new(
606 "logical predicate compilation requires at least one child predicate",
607 ));
608 }
609
610 let compiled = predicates
611 .iter()
612 .map(|predicate| compile_predicate(predicate, parameters))
613 .collect::<Result<Vec<_>, _>>()?;
614
615 Ok(format!("({})", compiled.join(&format!(" {operator} "))))
616}
617
618fn compile_order_by(order_by: &[OrderBy]) -> Result<String, OrmError> {
619 let parts = order_by
620 .iter()
621 .map(|order| {
622 Ok(format!(
623 "{}.{} {}",
624 quote_table_reference(&order.table)?,
625 quote_identifier(order.column_name)?,
626 match order.direction {
627 SortDirection::Asc => "ASC",
628 SortDirection::Desc => "DESC",
629 },
630 ))
631 })
632 .collect::<Result<Vec<_>, OrmError>>()?;
633
634 Ok(parts.join(", "))
635}
636
637fn compile_aggregate_order_by(
638 order_by: &[AggregateOrderBy],
639 group_by: &[Expr],
640 parameters: &mut ParameterBuilder,
641) -> Result<String, OrmError> {
642 let parts = order_by
643 .iter()
644 .map(|order| {
645 Ok(format!(
646 "{} {}",
647 compile_aggregate_expr(&order.expr, group_by, parameters)?,
648 match order.direction {
649 SortDirection::Asc => "ASC",
650 SortDirection::Desc => "DESC",
651 },
652 ))
653 })
654 .collect::<Result<Vec<_>, OrmError>>()?;
655
656 Ok(parts.join(", "))
657}
658
659fn compile_pagination(pagination: Pagination, parameters: &mut ParameterBuilder) -> String {
660 let offset = parameters.push(SqlValue::I64(pagination.offset as i64));
661 let limit = parameters.push(SqlValue::I64(pagination.limit as i64));
662
663 format!("OFFSET {offset} ROWS FETCH NEXT {limit} ROWS ONLY")
664}
665
666fn compile_column_values(
667 values: &[ColumnValue],
668 parameters: &mut ParameterBuilder,
669) -> Result<(String, String), OrmError> {
670 let mut columns = Vec::with_capacity(values.len());
671 let mut placeholders = Vec::with_capacity(values.len());
672
673 for value in values {
674 columns.push(quote_identifier(value.column_name)?);
675 placeholders.push(parameters.push(value.value.clone()));
676 }
677
678 Ok((columns.join(", "), placeholders.join(", ")))
679}
680
681fn compile_assignments(
682 changes: &[ColumnValue],
683 parameters: &mut ParameterBuilder,
684) -> Result<String, OrmError> {
685 let assignments = changes
686 .iter()
687 .map(|change| {
688 Ok(format!(
689 "{} = {}",
690 quote_identifier(change.column_name)?,
691 parameters.push(change.value.clone()),
692 ))
693 })
694 .collect::<Result<Vec<_>, OrmError>>()?;
695
696 Ok(assignments.join(", "))
697}
698
699fn compile_binary_op(op: BinaryOp) -> &'static str {
700 match op {
701 BinaryOp::Add => "+",
702 BinaryOp::Subtract => "-",
703 BinaryOp::Multiply => "*",
704 BinaryOp::Divide => "/",
705 }
706}
707
708fn compile_unary_op(op: UnaryOp) -> &'static str {
709 match op {
710 UnaryOp::Negate => "-",
711 }
712}
713
714#[cfg(test)]
715mod tests {
716 use super::super::SqlServerCompiler;
717 use sql_orm_core::{
718 Changeset, ColumnMetadata, ColumnValue, Entity, EntityColumn, EntityMetadata,
719 IdentityMetadata, Insertable, PrimaryKeyMetadata, SqlServerType, SqlValue,
720 };
721 use sql_orm_query::{
722 AggregateExpr, AggregateOrderBy, AggregatePredicate, AggregateProjection, AggregateQuery,
723 BinaryOp, CountQuery, DeleteQuery, ExistsQuery, Expr, InsertQuery, OrderBy, Pagination,
724 Predicate, Query, SelectProjection, SelectQuery, TableRef, UnaryOp, UpdateQuery,
725 };
726
727 #[allow(dead_code)]
728 struct Customer;
729
730 #[allow(dead_code)]
731 struct Order;
732
733 static CUSTOMER_COLUMNS: [ColumnMetadata; 4] = [
734 ColumnMetadata {
735 rust_field: "id",
736 column_name: "id",
737 renamed_from: None,
738 sql_type: SqlServerType::BigInt,
739 nullable: false,
740 primary_key: true,
741 identity: Some(IdentityMetadata::new(1, 1)),
742 default_sql: None,
743 computed_sql: None,
744 rowversion: false,
745 insertable: false,
746 updatable: false,
747 max_length: None,
748 precision: None,
749 scale: None,
750 },
751 ColumnMetadata {
752 rust_field: "email",
753 column_name: "email",
754 renamed_from: None,
755 sql_type: SqlServerType::NVarChar,
756 nullable: false,
757 primary_key: false,
758 identity: None,
759 default_sql: None,
760 computed_sql: None,
761 rowversion: false,
762 insertable: true,
763 updatable: true,
764 max_length: Some(160),
765 precision: None,
766 scale: None,
767 },
768 ColumnMetadata {
769 rust_field: "active",
770 column_name: "active",
771 renamed_from: None,
772 sql_type: SqlServerType::Bit,
773 nullable: false,
774 primary_key: false,
775 identity: None,
776 default_sql: Some("1"),
777 computed_sql: None,
778 rowversion: false,
779 insertable: true,
780 updatable: true,
781 max_length: None,
782 precision: None,
783 scale: None,
784 },
785 ColumnMetadata {
786 rust_field: "created_at",
787 column_name: "created_at",
788 renamed_from: None,
789 sql_type: SqlServerType::DateTime2,
790 nullable: false,
791 primary_key: false,
792 identity: None,
793 default_sql: Some("SYSUTCDATETIME()"),
794 computed_sql: None,
795 rowversion: false,
796 insertable: true,
797 updatable: true,
798 max_length: None,
799 precision: None,
800 scale: None,
801 },
802 ];
803
804 static CUSTOMER_METADATA: EntityMetadata = EntityMetadata {
805 rust_name: "Customer",
806 schema: "sales",
807 table: "customers",
808 renamed_from: None,
809 columns: &CUSTOMER_COLUMNS,
810 primary_key: PrimaryKeyMetadata::new(Some("pk_customers"), &["id"]),
811 indexes: &[],
812 foreign_keys: &[],
813 navigations: &[],
814 };
815
816 impl Entity for Customer {
817 fn metadata() -> &'static EntityMetadata {
818 &CUSTOMER_METADATA
819 }
820 }
821
822 static ORDER_COLUMNS: [ColumnMetadata; 3] = [
823 ColumnMetadata {
824 rust_field: "id",
825 column_name: "id",
826 renamed_from: None,
827 sql_type: SqlServerType::BigInt,
828 nullable: false,
829 primary_key: true,
830 identity: Some(IdentityMetadata::new(1, 1)),
831 default_sql: None,
832 computed_sql: None,
833 rowversion: false,
834 insertable: false,
835 updatable: false,
836 max_length: None,
837 precision: None,
838 scale: None,
839 },
840 ColumnMetadata {
841 rust_field: "customer_id",
842 column_name: "customer_id",
843 renamed_from: None,
844 sql_type: SqlServerType::BigInt,
845 nullable: false,
846 primary_key: false,
847 identity: None,
848 default_sql: None,
849 computed_sql: None,
850 rowversion: false,
851 insertable: true,
852 updatable: true,
853 max_length: None,
854 precision: None,
855 scale: None,
856 },
857 ColumnMetadata {
858 rust_field: "total_cents",
859 column_name: "total_cents",
860 renamed_from: None,
861 sql_type: SqlServerType::BigInt,
862 nullable: false,
863 primary_key: false,
864 identity: None,
865 default_sql: None,
866 computed_sql: None,
867 rowversion: false,
868 insertable: true,
869 updatable: true,
870 max_length: None,
871 precision: None,
872 scale: None,
873 },
874 ];
875
876 static ORDER_METADATA: EntityMetadata = EntityMetadata {
877 rust_name: "Order",
878 schema: "sales",
879 table: "orders",
880 renamed_from: None,
881 columns: &ORDER_COLUMNS,
882 primary_key: PrimaryKeyMetadata::new(Some("pk_orders"), &["id"]),
883 indexes: &[],
884 foreign_keys: &[],
885 navigations: &[],
886 };
887
888 impl Entity for Order {
889 fn metadata() -> &'static EntityMetadata {
890 &ORDER_METADATA
891 }
892 }
893
894 #[allow(non_upper_case_globals)]
895 impl Customer {
896 const id: EntityColumn<Customer> = EntityColumn::new("id", "id");
897 const email: EntityColumn<Customer> = EntityColumn::new("email", "email");
898 const active: EntityColumn<Customer> = EntityColumn::new("active", "active");
899 const created_at: EntityColumn<Customer> = EntityColumn::new("created_at", "created_at");
900 }
901
902 #[allow(non_upper_case_globals)]
903 impl Order {
904 const customer_id: EntityColumn<Order> = EntityColumn::new("customer_id", "customer_id");
905 const total_cents: EntityColumn<Order> = EntityColumn::new("total_cents", "total_cents");
906 }
907
908 struct NewCustomer {
909 email: String,
910 active: bool,
911 }
912
913 impl Insertable<Customer> for NewCustomer {
914 fn values(&self) -> Vec<ColumnValue> {
915 vec![
916 ColumnValue::new("email", SqlValue::String(self.email.clone())),
917 ColumnValue::new("active", SqlValue::Bool(self.active)),
918 ]
919 }
920 }
921
922 struct UpdateCustomer {
923 email: Option<String>,
924 active: Option<bool>,
925 }
926
927 impl Changeset<Customer> for UpdateCustomer {
928 fn changes(&self) -> Vec<ColumnValue> {
929 let mut changes = Vec::new();
930
931 if let Some(email) = &self.email {
932 changes.push(ColumnValue::new("email", SqlValue::String(email.clone())));
933 }
934
935 if let Some(active) = self.active {
936 changes.push(ColumnValue::new("active", SqlValue::Bool(active)));
937 }
938
939 changes
940 }
941 }
942
943 #[test]
944 fn compiles_select_with_predicates_order_and_pagination() {
945 let query = SelectQuery::from_entity::<Customer>()
946 .select(vec![Expr::from(Customer::id), Expr::from(Customer::email)])
947 .filter(Predicate::eq(
948 Expr::from(Customer::active),
949 Expr::value(SqlValue::Bool(true)),
950 ))
951 .filter(Predicate::like(
952 Expr::from(Customer::email),
953 Expr::value(SqlValue::String("%@example.com".to_string())),
954 ))
955 .order_by(OrderBy::desc(Customer::created_at))
956 .paginate(Pagination::page(2, 20));
957
958 let compiled = SqlServerCompiler::compile_select(&query).unwrap();
959
960 assert_eq!(
961 compiled.sql,
962 "SELECT [sales].[customers].[id] AS [id], [sales].[customers].[email] AS [email] FROM [sales].[customers] WHERE (([sales].[customers].[active] = @P1) AND ([sales].[customers].[email] LIKE @P2)) ORDER BY [sales].[customers].[created_at] DESC OFFSET @P3 ROWS FETCH NEXT @P4 ROWS ONLY"
963 );
964 assert_eq!(
965 compiled.params,
966 vec![
967 SqlValue::Bool(true),
968 SqlValue::String("%@example.com".to_string()),
969 SqlValue::I64(20),
970 SqlValue::I64(20),
971 ]
972 );
973 }
974
975 #[test]
976 fn compiles_select_without_projection_as_star() {
977 let compiled =
978 SqlServerCompiler::compile_select(&SelectQuery::from_entity::<Customer>()).unwrap();
979
980 assert_eq!(compiled.sql, "SELECT * FROM [sales].[customers]");
981 assert!(compiled.params.is_empty());
982 }
983
984 #[test]
985 fn rejects_pagination_without_order_by() {
986 let error = SqlServerCompiler::compile_select(
987 &SelectQuery::from_entity::<Customer>().paginate(Pagination::page(1, 10)),
988 )
989 .unwrap_err();
990
991 assert_eq!(
992 error.message(),
993 "SQL Server pagination requires ORDER BY before OFFSET/FETCH"
994 );
995 }
996
997 #[test]
998 fn compiles_explicit_joins_to_sql() {
999 let query = SelectQuery::from_entity::<Customer>()
1000 .select(vec![
1001 Expr::from(Customer::email),
1002 Expr::from(Order::total_cents),
1003 ])
1004 .inner_join::<Order>(Predicate::eq(
1005 Expr::from(Customer::id),
1006 Expr::from(Order::customer_id),
1007 ))
1008 .filter(Predicate::gt(
1009 Expr::from(Order::total_cents),
1010 Expr::value(SqlValue::I64(1000)),
1011 ))
1012 .order_by(OrderBy::desc(Order::total_cents))
1013 .paginate(Pagination::page(1, 10));
1014
1015 let compiled = SqlServerCompiler::compile_select(&query).unwrap();
1016
1017 assert_eq!(
1018 compiled.sql,
1019 "SELECT [sales].[customers].[email] AS [email], [sales].[orders].[total_cents] AS [total_cents] FROM [sales].[customers] INNER JOIN [sales].[orders] ON ([sales].[customers].[id] = [sales].[orders].[customer_id]) WHERE ([sales].[orders].[total_cents] > @P1) ORDER BY [sales].[orders].[total_cents] DESC OFFSET @P2 ROWS FETCH NEXT @P3 ROWS ONLY"
1020 );
1021 assert_eq!(
1022 compiled.params,
1023 vec![SqlValue::I64(1000), SqlValue::I64(0), SqlValue::I64(10)]
1024 );
1025 }
1026
1027 #[test]
1028 fn rejects_duplicate_unaliased_joined_tables() {
1029 let error = SqlServerCompiler::compile_select(
1030 &SelectQuery::from_entity::<Customer>().inner_join::<Customer>(Predicate::eq(
1031 Expr::from(Customer::id),
1032 Expr::from(Customer::id),
1033 )),
1034 )
1035 .unwrap_err();
1036
1037 assert_eq!(
1038 error.message(),
1039 "SQL Server join compilation requires aliases for repeated table sources"
1040 );
1041 }
1042
/// Aliases let `orders` be joined twice (`created_orders` and
/// `completed_orders`). Every column must render through its alias.
/// Parameters are numbered in textual order: the LEFT JOIN's ON value first,
/// then the WHERE value, then OFFSET and FETCH.
#[test]
fn compiles_aliased_selects_with_repeated_joined_tables() {
    let created_join_on = Predicate::eq(
        Expr::column_as(Customer::id, "c"),
        Expr::column_as(Order::customer_id, "created_orders"),
    );
    let completed_join_on = Predicate::gte(
        Expr::column_as(Order::total_cents, "completed_orders"),
        Expr::value(SqlValue::I64(5000)),
    );
    let minimum_created_total = Predicate::gt(
        Expr::column_as(Order::total_cents, "created_orders"),
        Expr::value(SqlValue::I64(1000)),
    );
    let largest_completed_first = OrderBy::new(
        TableRef::for_entity_as::<Order>("completed_orders"),
        "total_cents",
        sql_orm_query::SortDirection::Desc,
    );

    let query = SelectQuery::from_entity_as::<Customer>("c")
        .select(vec![
            Expr::column_as(Customer::email, "c"),
            Expr::column_as(Order::total_cents, "created_orders"),
        ])
        .inner_join_as::<Order>("created_orders", created_join_on)
        .left_join_as::<Order>("completed_orders", completed_join_on)
        .filter(minimum_created_total)
        .order_by(largest_completed_first)
        // Page 1 of size 10 is expected to become OFFSET 0 / FETCH 10 (see params below).
        .paginate(Pagination::page(1, 10));

    let compiled =
        SqlServerCompiler::compile_select(&query).expect("aliased self-join select should compile");

    let expected_sql = "SELECT [c].[email] AS [email], [created_orders].[total_cents] AS [total_cents] FROM [sales].[customers] AS [c] INNER JOIN [sales].[orders] AS [created_orders] ON ([c].[id] = [created_orders].[customer_id]) LEFT JOIN [sales].[orders] AS [completed_orders] ON ([completed_orders].[total_cents] >= @P1) WHERE ([created_orders].[total_cents] > @P2) ORDER BY [completed_orders].[total_cents] DESC OFFSET @P3 ROWS FETCH NEXT @P4 ROWS ONLY";
    let expected_params = vec![
        SqlValue::I64(5000),
        SqlValue::I64(1000),
        SqlValue::I64(0),
        SqlValue::I64(10),
    ];

    assert_eq!(compiled.sql, expected_sql);
    assert_eq!(compiled.params, expected_params);
}
1091
1092 #[test]
1093 fn compiles_aliased_count_query() {
1094 let query = CountQuery::from_entity_as::<Customer>("c").filter(Predicate::eq(
1095 Expr::column_as(Customer::active, "c"),
1096 Expr::value(SqlValue::Bool(true)),
1097 ));
1098
1099 let compiled = SqlServerCompiler::compile_count(&query).unwrap();
1100
1101 assert_eq!(
1102 compiled.sql,
1103 "SELECT COUNT(*) AS [count] FROM [sales].[customers] AS [c] WHERE ([c].[active] = @P1)"
1104 );
1105 assert_eq!(compiled.params, vec![SqlValue::Bool(true)]);
1106 }
1107
/// An empty string used as a table alias must be rejected with the generic
/// empty-identifier error.
#[test]
fn rejects_empty_table_aliases() {
    let blank_alias = "";
    let join_on = Predicate::eq(
        Expr::column_as(Customer::id, blank_alias),
        Expr::column_as(Order::customer_id, "o"),
    );
    let query =
        SelectQuery::from_entity_as::<Customer>(blank_alias).inner_join_as::<Order>("o", join_on);

    let error = SqlServerCompiler::compile_select(&query)
        .expect_err("an empty table alias should be rejected");

    assert_eq!(error.message(), "SQL Server identifier cannot be empty");
}
1123
/// INSERT lists the entity's columns, returns the new row via
/// `OUTPUT INSERTED.*`, and binds values as @P1, @P2 in column order.
#[test]
fn compiles_insert_with_output_inserted_and_parameter_order() {
    let new_customer = NewCustomer {
        email: "ana@example.com".to_string(),
        active: true,
    };
    let query = InsertQuery::for_entity::<Customer, _>(&new_customer);

    let compiled = SqlServerCompiler::compile_insert(&query).expect("insert should compile");

    let expected_sql =
        "INSERT INTO [sales].[customers] ([email], [active]) OUTPUT INSERTED.* VALUES (@P1, @P2)";
    let expected_params = vec![
        SqlValue::String("ana@example.com".to_string()),
        SqlValue::Bool(true),
    ];
    assert_eq!(compiled.sql, expected_sql);
    assert_eq!(compiled.params, expected_params);
}
1145
/// UPDATE binds the SET values first (@P1, @P2), then the WHERE value (@P3).
/// `OUTPUT INSERTED.*` sits between SET and WHERE.
#[test]
fn compiles_update_with_output_inserted_and_where_clause() {
    let changes = UpdateCustomer {
        email: Some("ana.maria@example.com".to_string()),
        active: Some(false),
    };
    let by_id = Predicate::eq(Expr::from(Customer::id), Expr::value(SqlValue::I64(7)));
    let query = UpdateQuery::for_entity::<Customer, _>(&changes).filter(by_id);

    let compiled = SqlServerCompiler::compile_update(&query).expect("update should compile");

    let expected_sql = "UPDATE [sales].[customers] SET [email] = @P1, [active] = @P2 OUTPUT INSERTED.* WHERE ([sales].[customers].[id] = @P3)";
    let expected_params = vec![
        SqlValue::String("ana.maria@example.com".to_string()),
        SqlValue::Bool(false),
        SqlValue::I64(7),
    ];
    assert_eq!(compiled.sql, expected_sql);
    assert_eq!(compiled.params, expected_params);
}
1172
/// DELETE and COUNT each compile a single-predicate WHERE clause whose value
/// is bound as @P1.
#[test]
fn compiles_delete_and_count_queries() {
    // DELETE by primary key.
    let delete_by_id = DeleteQuery::from_entity::<Customer>().filter(Predicate::eq(
        Expr::from(Customer::id),
        Expr::value(SqlValue::I64(7)),
    ));
    let compiled_delete =
        SqlServerCompiler::compile_delete(&delete_by_id).expect("delete should compile");
    assert_eq!(
        compiled_delete.sql,
        "DELETE FROM [sales].[customers] WHERE ([sales].[customers].[id] = @P1)"
    );
    assert_eq!(compiled_delete.params, vec![SqlValue::I64(7)]);

    // COUNT of active customers.
    let count_active = CountQuery::from_entity::<Customer>().filter(Predicate::eq(
        Expr::from(Customer::active),
        Expr::value(SqlValue::Bool(true)),
    ));
    let compiled_count =
        SqlServerCompiler::compile_count(&count_active).expect("count should compile");
    assert_eq!(
        compiled_count.sql,
        "SELECT COUNT(*) AS [count] FROM [sales].[customers] WHERE ([sales].[customers].[active] = @P1)"
    );
    assert_eq!(compiled_count.params, vec![SqlValue::Bool(true)]);
}
1198
/// EXISTS wraps a `SELECT 1` subquery in a CASE that yields a `bit`.
/// Successive `filter` calls are combined with AND, and their values are bound
/// in call order.
#[test]
fn compiles_exists_query_with_join_and_parameter_order() {
    let customer_orders = Predicate::eq(Expr::from(Customer::id), Expr::from(Order::customer_id));
    let is_active = Predicate::eq(
        Expr::from(Customer::active),
        Expr::value(SqlValue::Bool(true)),
    );
    let large_order = Predicate::gt(
        Expr::from(Order::total_cents),
        Expr::value(SqlValue::I64(1000)),
    );

    let query = ExistsQuery::from_entity::<Customer>()
        .inner_join::<Order>(customer_orders)
        .filter(is_active)
        .filter(large_order);

    let compiled = SqlServerCompiler::compile_exists(&query).expect("exists query should compile");

    let expected_sql = "SELECT CASE WHEN EXISTS (SELECT 1 FROM [sales].[customers] INNER JOIN [sales].[orders] ON ([sales].[customers].[id] = [sales].[orders].[customer_id]) WHERE (([sales].[customers].[active] = @P1) AND ([sales].[orders].[total_cents] > @P2))) THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END AS [exists]";
    assert_eq!(compiled.sql, expected_sql);
    assert_eq!(
        compiled.params,
        vec![SqlValue::Bool(true), SqlValue::I64(1000)]
    );
}
1226
/// `compile_query` dispatches each `Query` variant to its dedicated compiler.
/// This is checked here for the `Count` and `Exists` variants.
#[test]
fn compiles_query_enum_through_single_entry_point() {
    // Both variants share the same filter, so build it on demand.
    let active_only = || {
        Predicate::eq(
            Expr::from(Customer::active),
            Expr::value(SqlValue::Bool(true)),
        )
    };

    let count_query = Query::Count(CountQuery::from_entity::<Customer>().filter(active_only()));
    let compiled_count =
        SqlServerCompiler::compile_query(&count_query).expect("count variant should compile");
    assert_eq!(
        compiled_count.sql,
        "SELECT COUNT(*) AS [count] FROM [sales].[customers] WHERE ([sales].[customers].[active] = @P1)"
    );
    assert_eq!(compiled_count.params, vec![SqlValue::Bool(true)]);

    let exists_query = Query::Exists(Box::new(
        ExistsQuery::from_entity::<Customer>().filter(active_only()),
    ));
    let compiled_exists =
        SqlServerCompiler::compile_query(&exists_query).expect("exists variant should compile");
    assert_eq!(
        compiled_exists.sql,
        "SELECT CASE WHEN EXISTS (SELECT 1 FROM [sales].[customers] WHERE ([sales].[customers].[active] = @P1)) THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END AS [exists]"
    );
    assert_eq!(compiled_exists.params, vec![SqlValue::Bool(true)]);
}
1255
/// The `Aggregate` variant of `Query` is routed through `compile_query` and
/// produces an ungrouped `COUNT(*)` with the requested alias.
#[test]
fn compiles_aggregate_query_through_single_entry_point() {
    let order_count = AggregateQuery::from_entity::<Order>()
        .project(vec![AggregateProjection::count_as("order_count")])
        .filter(Predicate::gt(
            Expr::from(Order::total_cents),
            Expr::value(SqlValue::I64(1000)),
        ));
    let query = Query::Aggregate(Box::new(order_count));

    let compiled =
        SqlServerCompiler::compile_query(&query).expect("aggregate variant should compile");

    let expected_sql = "SELECT COUNT(*) AS [order_count] FROM [sales].[orders] WHERE ([sales].[orders].[total_cents] > @P1)";
    assert_eq!(compiled.sql, expected_sql);
    assert_eq!(compiled.params, vec![SqlValue::I64(1000)]);
}
1275
/// A grouped aggregate with a join, WHERE, GROUP BY, HAVING, ORDER BY and
/// pagination. Parameters are numbered in clause order: WHERE, HAVING, then
/// OFFSET and FETCH.
#[test]
fn compiles_grouped_aggregate_query_with_having_and_parameter_order() {
    let join_customers = Predicate::eq(Expr::from(Order::customer_id), Expr::from(Customer::id));
    let active_customers = Predicate::eq(
        Expr::from(Customer::active),
        Expr::value(SqlValue::Bool(true)),
    );
    let projections = vec![
        AggregateProjection::group_key(Order::customer_id),
        AggregateProjection::count_as("order_count"),
        AggregateProjection::sum_as(Order::total_cents, "total_cents"),
        AggregateProjection::avg_as(Order::total_cents, "average_cents"),
        AggregateProjection::min_as(Order::total_cents, "min_cents"),
        AggregateProjection::max_as(Order::total_cents, "max_cents"),
    ];
    let more_than_one_order = AggregatePredicate::gt(
        AggregateExpr::count_all(),
        Expr::value(SqlValue::I64(1)),
    );
    let highest_total_first =
        AggregateOrderBy::desc(AggregateExpr::sum(Expr::from(Order::total_cents)));

    let query = AggregateQuery::from_entity::<Order>()
        .inner_join::<Customer>(join_customers)
        .filter(active_customers)
        .group_by(vec![Expr::from(Order::customer_id)])
        .project(projections)
        .having(more_than_one_order)
        .order_by(highest_total_first)
        .paginate(Pagination::page(1, 10));

    let compiled =
        SqlServerCompiler::compile_aggregate(&query).expect("grouped aggregate should compile");

    let expected_sql = "SELECT [sales].[orders].[customer_id] AS [customer_id], COUNT(*) AS [order_count], SUM([sales].[orders].[total_cents]) AS [total_cents], AVG([sales].[orders].[total_cents]) AS [average_cents], MIN([sales].[orders].[total_cents]) AS [min_cents], MAX([sales].[orders].[total_cents]) AS [max_cents] FROM [sales].[orders] INNER JOIN [sales].[customers] ON ([sales].[orders].[customer_id] = [sales].[customers].[id]) WHERE ([sales].[customers].[active] = @P1) GROUP BY [sales].[orders].[customer_id] HAVING (COUNT(*) > @P2) ORDER BY SUM([sales].[orders].[total_cents]) DESC OFFSET @P3 ROWS FETCH NEXT @P4 ROWS ONLY";
    let expected_params = vec![
        SqlValue::Bool(true),
        SqlValue::I64(1),
        SqlValue::I64(0),
        SqlValue::I64(10),
    ];
    assert_eq!(compiled.sql, expected_sql);
    assert_eq!(compiled.params, expected_params);
}
1321
/// Each malformed aggregate query must be rejected with its specific message.
#[test]
fn rejects_invalid_aggregate_queries() {
    // Compiles `query`, requires a failure, and checks the error text.
    let assert_rejected = |query: &AggregateQuery, expected: &str| {
        let error = SqlServerCompiler::compile_aggregate(query)
            .expect_err("invalid aggregate query should be rejected");
        assert_eq!(error.message(), expected);
    };

    // No projections at all.
    assert_rejected(
        &AggregateQuery::from_entity::<Order>(),
        "SQL Server aggregate query compilation requires at least one projection",
    );

    // Two projections share the alias `value`.
    assert_rejected(
        &AggregateQuery::from_entity::<Order>().project(vec![
            AggregateProjection::count_as("value"),
            AggregateProjection::sum_as(Order::total_cents, "value"),
        ]),
        "SQL Server aggregate projection alias `value` is duplicated",
    );

    // A group-key projection with no GROUP BY clause.
    assert_rejected(
        &AggregateQuery::from_entity::<Order>()
            .project(vec![AggregateProjection::group_key(Order::customer_id)]),
        "SQL Server aggregate group key projection must appear in GROUP BY",
    );

    // A whitespace-only alias is treated as empty.
    assert_rejected(
        &AggregateQuery::from_entity::<Order>().project(vec![AggregateProjection::expr_as(
            AggregateExpr::count_all(),
            " ",
        )]),
        "SQL Server aggregate projection alias cannot be empty",
    );

    // HAVING compares against a column that is not in GROUP BY. This is
    // reported with the same message as an ungrouped group-key projection.
    assert_rejected(
        &AggregateQuery::from_entity::<Order>()
            .group_by(vec![Expr::from(Order::customer_id)])
            .project(vec![
                AggregateProjection::group_key(Order::customer_id),
                AggregateProjection::count_as("order_count"),
            ])
            .having(AggregatePredicate::gt(
                AggregateExpr::count_all(),
                Expr::from(Order::total_cents),
            )),
        "SQL Server aggregate group key projection must appear in GROUP BY",
    );

    // ORDER BY uses a group key that is not in GROUP BY.
    assert_rejected(
        &AggregateQuery::from_entity::<Order>()
            .group_by(vec![Expr::from(Order::customer_id)])
            .project(vec![
                AggregateProjection::group_key(Order::customer_id),
                AggregateProjection::count_as("order_count"),
            ])
            .order_by(AggregateOrderBy::asc(AggregateExpr::group_key(
                Order::total_cents,
            ))),
        "SQL Server aggregate group key projection must appear in GROUP BY",
    );
}
1401
/// Function calls, binary and unary operators, IS NULL / IS NOT NULL and NOT
/// each compile to a parenthesised form. Literal values are bound as parameters
/// in the order they appear: the projection first, then the WHERE clause.
#[test]
fn compiles_functions_null_checks_and_unary_binary_exprs() {
    let email_with_domain = Expr::binary(
        Expr::from(Customer::email),
        BinaryOp::Add,
        Expr::value(SqlValue::String("@example.com".to_string())),
    );
    let lowered_projection = SelectProjection::expr_as(
        Expr::function("LOWER", vec![email_with_domain]),
        "email_lower",
    );
    let negated_one_not_null = Predicate::negate(Predicate::is_null(Expr::unary(
        UnaryOp::Negate,
        Expr::value(SqlValue::I64(1)),
    )));
    let both_conditions = Predicate::and(vec![
        Predicate::is_not_null(Expr::from(Customer::email)),
        negated_one_not_null,
    ]);

    // Built as a struct literal rather than via the builder API.
    let query = SelectQuery {
        from: TableRef::new("sales", "customers"),
        joins: vec![],
        projection: vec![lowered_projection],
        predicate: Some(both_conditions),
        order_by: vec![],
        pagination: None,
    };

    let compiled =
        SqlServerCompiler::compile_select(&query).expect("expression select should compile");

    let expected_sql = "SELECT LOWER(([sales].[customers].[email] + @P1)) AS [email_lower] FROM [sales].[customers] WHERE (([sales].[customers].[email] IS NOT NULL) AND (NOT ((- @P2) IS NULL)))";
    let expected_params = vec![
        SqlValue::String("@example.com".to_string()),
        SqlValue::I64(1),
    ];
    assert_eq!(compiled.sql, expected_sql);
    assert_eq!(compiled.params, expected_params);
}
1443
/// A computed (non-column) projection without an alias must be rejected.
#[test]
fn rejects_projection_expression_without_alias() {
    let lowered_email = Expr::function("LOWER", vec![Expr::from(Customer::email)]);
    let query =
        SelectQuery::from_entity::<Customer>().select(vec![SelectProjection::expr(lowered_email)]);

    let error = SqlServerCompiler::compile_select(&query)
        .expect_err("an unaliased expression projection should be rejected");

    assert_eq!(
        error.message(),
        "SQL Server projection expressions require an explicit alias"
    );
}
1458
/// Projection aliases must be non-empty and unique. A plain column projection
/// reserves its column name (`id`), so an explicit alias `id` on another
/// projection collides with it.
#[test]
fn rejects_empty_or_duplicate_projection_aliases() {
    let blank_alias_query = SelectQuery::from_entity::<Customer>()
        .select(vec![SelectProjection::expr_as(Expr::from(Customer::email), "")]);
    let blank_alias_error = SqlServerCompiler::compile_select(&blank_alias_query)
        .expect_err("an empty projection alias should be rejected");
    assert_eq!(
        blank_alias_error.message(),
        "SQL Server projection alias cannot be empty"
    );

    let clashing_alias_query = SelectQuery::from_entity::<Customer>().select(vec![
        SelectProjection::column(Customer::id),
        SelectProjection::expr_as(Expr::from(Customer::email), "id"),
    ]);
    let clashing_alias_error = SqlServerCompiler::compile_select(&clashing_alias_query)
        .expect_err("a duplicated projection alias should be rejected");
    assert_eq!(
        clashing_alias_error.message(),
        "SQL Server projection alias `id` is duplicated"
    );
}
1485
/// An UPDATE with no changed fields, and an AND predicate with no children,
/// are both compile-time errors.
#[test]
fn rejects_empty_updates_and_empty_logical_predicates() {
    let no_changes = UpdateCustomer {
        email: None,
        active: None,
    };
    let empty_update = UpdateQuery::for_entity::<Customer, _>(&no_changes);
    let update_error = SqlServerCompiler::compile_update(&empty_update)
        .expect_err("an update without changes should be rejected");
    assert_eq!(
        update_error.message(),
        "SQL Server update compilation requires at least one change"
    );

    let childless_and = Predicate::and(vec![]);
    let select_with_empty_and = SelectQuery::from_entity::<Customer>().filter(childless_and);
    let predicate_error = SqlServerCompiler::compile_select(&select_with_empty_and)
        .expect_err("an empty logical predicate should be rejected");
    assert_eq!(
        predicate_error.message(),
        "logical predicate compilation requires at least one child predicate"
    );
}
1509}