1use crate::clause::OrderBy;
27use crate::expr::Dialect;
28use sqlmodel_core::Value;
29
/// The set operator placed between two `SELECT` statements of a compound query.
///
/// Note: not every database accepts the `ALL` forms of `INTERSECT` / `EXCEPT`;
/// check the target backend before using them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SetOpType {
    /// Renders as `UNION`.
    Union,
    /// Renders as `UNION ALL`.
    UnionAll,
    /// Renders as `INTERSECT`.
    Intersect,
    /// Renders as `INTERSECT ALL`.
    IntersectAll,
    /// Renders as `EXCEPT`.
    Except,
    /// Renders as `EXCEPT ALL`.
    ExceptAll,
}

impl SetOpType {
    /// Returns the SQL keyword(s) for this operator, e.g. `"UNION ALL"`.
    ///
    /// Usable in `const` contexts.
    pub const fn as_sql(&self) -> &'static str {
        match *self {
            Self::Union => "UNION",
            Self::UnionAll => "UNION ALL",
            Self::Intersect => "INTERSECT",
            Self::IntersectAll => "INTERSECT ALL",
            Self::Except => "EXCEPT",
            Self::ExceptAll => "EXCEPT ALL",
        }
    }
}
60
61#[derive(Debug, Clone)]
63pub struct SetOperation {
64 queries: Vec<(String, Vec<Value>)>,
66 op_types: Vec<SetOpType>,
68 order_by: Vec<OrderBy>,
70 limit: Option<u64>,
72 offset: Option<u64>,
74}
75
76impl SetOperation {
77 pub fn new(query_sql: impl Into<String>, params: Vec<Value>) -> Self {
79 Self {
80 queries: vec![(query_sql.into(), params)],
81 op_types: Vec::new(),
82 order_by: Vec::new(),
83 limit: None,
84 offset: None,
85 }
86 }
87
88 pub fn union(self, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
90 self.add_op(SetOpType::Union, query_sql, params)
91 }
92
93 pub fn union_all(self, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
95 self.add_op(SetOpType::UnionAll, query_sql, params)
96 }
97
98 pub fn intersect(self, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
100 self.add_op(SetOpType::Intersect, query_sql, params)
101 }
102
103 pub fn intersect_all(self, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
105 self.add_op(SetOpType::IntersectAll, query_sql, params)
106 }
107
108 pub fn except(self, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
110 self.add_op(SetOpType::Except, query_sql, params)
111 }
112
113 pub fn except_all(self, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
115 self.add_op(SetOpType::ExceptAll, query_sql, params)
116 }
117
118 fn add_op(mut self, op: SetOpType, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
119 self.op_types.push(op);
120 self.queries.push((query_sql.into(), params));
121 self
122 }
123
124 pub fn order_by(mut self, order: OrderBy) -> Self {
126 self.order_by.push(order);
127 self
128 }
129
130 pub fn order_by_many(mut self, orders: Vec<OrderBy>) -> Self {
132 self.order_by.extend(orders);
133 self
134 }
135
136 pub fn limit(mut self, limit: u64) -> Self {
138 self.limit = Some(limit);
139 self
140 }
141
142 pub fn offset(mut self, offset: u64) -> Self {
144 self.offset = Some(offset);
145 self
146 }
147
148 pub fn build(&self) -> (String, Vec<Value>) {
150 self.build_with_dialect(Dialect::Postgres)
151 }
152
153 pub fn build_with_dialect(&self, dialect: Dialect) -> (String, Vec<Value>) {
155 let mut sql = String::new();
156 let mut params = Vec::new();
157
158 for (i, (query_sql, query_params)) in self.queries.iter().enumerate() {
160 if i > 0 {
161 let op = &self.op_types[i - 1];
163 sql.push(' ');
164 sql.push_str(op.as_sql());
165 sql.push(' ');
166 }
167
168 sql.push('(');
170 sql.push_str(query_sql);
171 sql.push(')');
172
173 params.extend(query_params.clone());
174 }
175
176 if !self.order_by.is_empty() {
178 sql.push_str(" ORDER BY ");
179 let order_strs: Vec<String> = self
180 .order_by
181 .iter()
182 .map(|o| {
183 let expr_sql = o.expr.build_with_dialect(dialect, &mut params, 0);
184 let dir = match o.direction {
185 crate::clause::OrderDirection::Asc => "ASC",
186 crate::clause::OrderDirection::Desc => "DESC",
187 };
188 let nulls = match o.nulls {
189 Some(crate::clause::NullsOrder::First) => " NULLS FIRST",
190 Some(crate::clause::NullsOrder::Last) => " NULLS LAST",
191 None => "",
192 };
193 format!("{expr_sql} {dir}{nulls}")
194 })
195 .collect();
196 sql.push_str(&order_strs.join(", "));
197 }
198
199 if let Some(limit) = self.limit {
201 sql.push_str(" LIMIT ");
202 sql.push_str(&limit.to_string());
203 }
204
205 if let Some(offset) = self.offset {
207 sql.push_str(" OFFSET ");
208 sql.push_str(&offset.to_string());
209 }
210
211 (sql, params)
212 }
213}
214
215pub fn union<I, S>(queries: I) -> Option<SetOperation>
228where
229 I: IntoIterator<Item = (S, Vec<Value>)>,
230 S: Into<String>,
231{
232 combine_queries(SetOpType::Union, queries)
233}
234
235pub fn union_all<I, S>(queries: I) -> Option<SetOperation>
249where
250 I: IntoIterator<Item = (S, Vec<Value>)>,
251 S: Into<String>,
252{
253 combine_queries(SetOpType::UnionAll, queries)
254}
255
256pub fn intersect<I, S>(queries: I) -> Option<SetOperation>
260where
261 I: IntoIterator<Item = (S, Vec<Value>)>,
262 S: Into<String>,
263{
264 combine_queries(SetOpType::Intersect, queries)
265}
266
267pub fn intersect_all<I, S>(queries: I) -> Option<SetOperation>
271where
272 I: IntoIterator<Item = (S, Vec<Value>)>,
273 S: Into<String>,
274{
275 combine_queries(SetOpType::IntersectAll, queries)
276}
277
278pub fn except<I, S>(queries: I) -> Option<SetOperation>
282where
283 I: IntoIterator<Item = (S, Vec<Value>)>,
284 S: Into<String>,
285{
286 combine_queries(SetOpType::Except, queries)
287}
288
289pub fn except_all<I, S>(queries: I) -> Option<SetOperation>
293where
294 I: IntoIterator<Item = (S, Vec<Value>)>,
295 S: Into<String>,
296{
297 combine_queries(SetOpType::ExceptAll, queries)
298}
299
300fn combine_queries<I, S>(op: SetOpType, queries: I) -> Option<SetOperation>
301where
302 I: IntoIterator<Item = (S, Vec<Value>)>,
303 S: Into<String>,
304{
305 let mut iter = queries.into_iter();
306
307 let (first_sql, first_params) = iter.next()?;
309
310 let mut result = SetOperation::new(first_sql, first_params);
311
312 for (sql, params) in iter {
314 result = result.add_op(op, sql, params);
315 }
316
317 Some(result)
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::Expr;

    /// Shorthand for a text parameter value.
    fn text(s: &str) -> Value {
        Value::Text(s.to_string())
    }

    /// A typed, empty input for the free combinator functions.
    fn no_queries() -> Vec<(&'static str, Vec<Value>)> {
        Vec::new()
    }

    #[test]
    fn test_single_query_has_no_operator() {
        let (sql, params) = SetOperation::new("SELECT 1", vec![]).build();
        assert_eq!(sql, "(SELECT 1)");
        assert!(params.is_empty());
    }

    #[test]
    fn test_union_basic() {
        let query = SetOperation::new("SELECT * FROM users WHERE role = 'admin'", vec![])
            .union("SELECT * FROM users WHERE role = 'manager'", vec![]);

        let (sql, params) = query.build();
        assert_eq!(
            sql,
            "(SELECT * FROM users WHERE role = 'admin') UNION (SELECT * FROM users WHERE role = 'manager')"
        );
        assert!(params.is_empty());
    }

    #[test]
    fn test_union_all_basic() {
        let query = SetOperation::new("SELECT id FROM table1", vec![])
            .union_all("SELECT id FROM table2", vec![]);

        let (sql, _) = query.build();
        assert_eq!(
            sql,
            "(SELECT id FROM table1) UNION ALL (SELECT id FROM table2)"
        );
    }

    #[test]
    fn test_union_with_params() {
        let query = SetOperation::new("SELECT * FROM users WHERE role = $1", vec![text("admin")])
            .union("SELECT * FROM users WHERE role = $2", vec![text("manager")]);

        let (sql, params) = query.build();
        assert_eq!(
            sql,
            "(SELECT * FROM users WHERE role = $1) UNION (SELECT * FROM users WHERE role = $2)"
        );
        assert_eq!(params, vec![text("admin"), text("manager")]);
    }

    #[test]
    fn test_params_concatenated_in_query_order() {
        let query = SetOperation::new("SELECT id FROM a WHERE x = $1", vec![text("a")])
            .except(
                "SELECT id FROM b WHERE x = $2 AND y = $3",
                vec![text("b1"), text("b2")],
            )
            .intersect("SELECT id FROM c", vec![])
            .union_all("SELECT id FROM d WHERE x = $4", vec![text("d")]);

        let (sql, params) = query.build();
        assert_eq!(
            sql,
            "(SELECT id FROM a WHERE x = $1) EXCEPT (SELECT id FROM b WHERE x = $2 AND y = $3) \
             INTERSECT (SELECT id FROM c) UNION ALL (SELECT id FROM d WHERE x = $4)"
        );
        assert_eq!(params, vec![text("a"), text("b1"), text("b2"), text("d")]);
    }

    #[test]
    fn test_union_function() {
        let query = union([
            ("SELECT * FROM admins", vec![]),
            ("SELECT * FROM managers", vec![]),
            ("SELECT * FROM employees", vec![]),
        ])
        .expect("non-empty iterator");

        let (sql, _) = query.build();
        assert_eq!(
            sql,
            "(SELECT * FROM admins) UNION (SELECT * FROM managers) UNION (SELECT * FROM employees)"
        );
    }

    #[test]
    fn test_union_all_function() {
        let query = union_all([
            ("SELECT 1", vec![]),
            ("SELECT 2", vec![]),
            ("SELECT 3", vec![]),
        ])
        .expect("non-empty iterator");

        let (sql, _) = query.build();
        assert_eq!(sql.matches("UNION ALL").count(), 2);
        assert_eq!(sql, "(SELECT 1) UNION ALL (SELECT 2) UNION ALL (SELECT 3)");
    }

    #[test]
    fn test_single_item_combinator() {
        let query = union_all([("SELECT 1", vec![])]).expect("non-empty iterator");
        assert_eq!(query.build().0, "(SELECT 1)");
    }

    #[test]
    fn test_union_empty_returns_none() {
        assert!(union(no_queries()).is_none());
    }

    #[test]
    fn test_every_combinator_empty_returns_none() {
        assert!(union_all(no_queries()).is_none());
        assert!(intersect(no_queries()).is_none());
        assert!(intersect_all(no_queries()).is_none());
        assert!(except(no_queries()).is_none());
        assert!(except_all(no_queries()).is_none());
    }

    #[test]
    fn test_union_with_order_by() {
        let query = SetOperation::new("SELECT name FROM users WHERE active = true", vec![])
            .union("SELECT name FROM users WHERE premium = true", vec![])
            .order_by(Expr::col("name").asc());

        let (sql, _) = query.build();
        assert!(sql.ends_with("ORDER BY \"name\" ASC"));
    }

    #[test]
    fn test_order_by_many() {
        let query = SetOperation::new("SELECT a, b FROM t1", vec![])
            .union("SELECT a, b FROM t2", vec![])
            .order_by_many(vec![Expr::col("a").asc(), Expr::col("b").asc()]);

        let (sql, _) = query.build();
        assert!(sql.ends_with("ORDER BY \"a\" ASC, \"b\" ASC"));
    }

    #[test]
    fn test_union_with_limit_offset() {
        let query = SetOperation::new("SELECT * FROM t1", vec![])
            .union("SELECT * FROM t2", vec![])
            .limit(10)
            .offset(5);

        let (sql, _) = query.build();
        assert_eq!(
            sql,
            "(SELECT * FROM t1) UNION (SELECT * FROM t2) LIMIT 10 OFFSET 5"
        );
    }

    #[test]
    fn test_intersect() {
        let query = SetOperation::new("SELECT id FROM users WHERE active = true", vec![])
            .intersect("SELECT id FROM users WHERE premium = true", vec![]);

        let (sql, _) = query.build();
        assert_eq!(
            sql,
            "(SELECT id FROM users WHERE active = true) INTERSECT (SELECT id FROM users WHERE premium = true)"
        );
    }

    #[test]
    fn test_intersect_all() {
        let query = intersect_all([("SELECT id FROM t1", vec![]), ("SELECT id FROM t2", vec![])])
            .expect("non-empty iterator");

        let (sql, _) = query.build();
        assert_eq!(sql, "(SELECT id FROM t1) INTERSECT ALL (SELECT id FROM t2)");
    }

    #[test]
    fn test_except() {
        let query = SetOperation::new("SELECT id FROM all_users", vec![])
            .except("SELECT id FROM banned_users", vec![]);

        let (sql, _) = query.build();
        assert_eq!(
            sql,
            "(SELECT id FROM all_users) EXCEPT (SELECT id FROM banned_users)"
        );
    }

    #[test]
    fn test_except_all() {
        let query = except_all([("SELECT id FROM t1", vec![]), ("SELECT id FROM t2", vec![])])
            .expect("non-empty iterator");

        let (sql, _) = query.build();
        assert_eq!(sql, "(SELECT id FROM t1) EXCEPT ALL (SELECT id FROM t2)");
    }

    #[test]
    fn test_chained_operations() {
        let query = SetOperation::new("SELECT id FROM t1", vec![])
            .union("SELECT id FROM t2", vec![])
            .union_all("SELECT id FROM t3", vec![]);

        let (sql, _) = query.build();
        assert_eq!(
            sql,
            "(SELECT id FROM t1) UNION (SELECT id FROM t2) UNION ALL (SELECT id FROM t3)"
        );
    }

    #[test]
    fn test_complex_query() {
        let query = SetOperation::new(
            "SELECT name, email FROM users WHERE role = $1",
            vec![text("admin")],
        )
        .union_all(
            "SELECT name, email FROM users WHERE department = $2",
            vec![text("engineering")],
        )
        .order_by(Expr::col("name").asc())
        .order_by(Expr::col("email").desc())
        .limit(100)
        .offset(0);

        let (sql, params) = query.build();

        assert!(sql.contains("UNION ALL"));
        assert!(sql.contains("ORDER BY"));
        assert!(sql.ends_with("LIMIT 100 OFFSET 0"));
        assert_eq!(params, vec![text("admin"), text("engineering")]);
    }

    #[test]
    fn test_set_op_type_sql() {
        assert_eq!(SetOpType::Union.as_sql(), "UNION");
        assert_eq!(SetOpType::UnionAll.as_sql(), "UNION ALL");
        assert_eq!(SetOpType::Intersect.as_sql(), "INTERSECT");
        assert_eq!(SetOpType::IntersectAll.as_sql(), "INTERSECT ALL");
        assert_eq!(SetOpType::Except.as_sql(), "EXCEPT");
        assert_eq!(SetOpType::ExceptAll.as_sql(), "EXCEPT ALL");
    }
}