Skip to main content

sqlmodel_query/
set_ops.rs

1//! Set operations for combining query results.
2//!
3//! Provides UNION, UNION ALL, INTERSECT, INTERSECT ALL, EXCEPT, and EXCEPT ALL
4//! operations for combining multiple SELECT queries.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use sqlmodel_query::{select, union, union_all, SetOperation};
10//!
11//! // UNION - removes duplicates
12//! let admins = select!(User).filter(Expr::col("role").eq("admin"));
13//! let managers = select!(User).filter(Expr::col("role").eq("manager"));
14//! let query = admins.union(managers);
15//!
//! // UNION ALL - keeps duplicates (the free functions return `None` for empty input)
//! let query = union_all([query1, query2, query3]).expect("at least one query");
18//!
19//! // With ORDER BY on final result
20//! let query = select!(User)
21//!     .filter(Expr::col("active").eq(true))
22//!     .union(select!(User).filter(Expr::col("premium").eq(true)))
23//!     .order_by(Expr::col("name").asc());
24//! ```
25
26use crate::clause::OrderBy;
27use crate::expr::Dialect;
28use sqlmodel_core::Value;
29
/// The operator placed between two consecutive queries of a [`SetOperation`].
///
/// Plain variants deduplicate the combined rows; the `*All` variants keep
/// duplicates.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SetOpType {
    /// `UNION`: rows produced by either query, duplicates removed.
    Union,
    /// `UNION ALL`: rows produced by either query, duplicates kept.
    UnionAll,
    /// `INTERSECT`: rows produced by both queries, duplicates removed.
    Intersect,
    /// `INTERSECT ALL`: rows produced by both queries, duplicates kept.
    IntersectAll,
    /// `EXCEPT`: rows of the left query absent from the right, duplicates removed.
    Except,
    /// `EXCEPT ALL`: rows of the left query absent from the right, duplicates kept.
    ExceptAll,
}

impl SetOpType {
    /// Returns the SQL keyword(s) emitted for this operator, e.g. `"UNION ALL"`.
    ///
    /// This is a `const fn`, so it can be used in constant contexts.
    pub const fn as_sql(&self) -> &'static str {
        let keyword = match *self {
            Self::Union => "UNION",
            Self::UnionAll => "UNION ALL",
            Self::Intersect => "INTERSECT",
            Self::IntersectAll => "INTERSECT ALL",
            Self::Except => "EXCEPT",
            Self::ExceptAll => "EXCEPT ALL",
        };
        keyword
    }
}
60
61/// A set operation combining multiple queries.
62#[derive(Debug, Clone)]
63pub struct SetOperation {
64    /// The queries to combine (in order)
65    queries: Vec<(String, Vec<Value>)>,
66    /// The type of set operation between consecutive queries
67    op_types: Vec<SetOpType>,
68    /// Optional ORDER BY on the final result
69    order_by: Vec<OrderBy>,
70    /// Optional LIMIT on the final result
71    limit: Option<u64>,
72    /// Optional OFFSET on the final result
73    offset: Option<u64>,
74}
75
76impl SetOperation {
77    /// Create a new set operation from a single query.
78    pub fn new(query_sql: impl Into<String>, params: Vec<Value>) -> Self {
79        Self {
80            queries: vec![(query_sql.into(), params)],
81            op_types: Vec::new(),
82            order_by: Vec::new(),
83            limit: None,
84            offset: None,
85        }
86    }
87
88    /// Add a UNION operation with another query.
89    pub fn union(self, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
90        self.add_op(SetOpType::Union, query_sql, params)
91    }
92
93    /// Add a UNION ALL operation with another query.
94    pub fn union_all(self, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
95        self.add_op(SetOpType::UnionAll, query_sql, params)
96    }
97
98    /// Add an INTERSECT operation with another query.
99    pub fn intersect(self, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
100        self.add_op(SetOpType::Intersect, query_sql, params)
101    }
102
103    /// Add an INTERSECT ALL operation with another query.
104    pub fn intersect_all(self, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
105        self.add_op(SetOpType::IntersectAll, query_sql, params)
106    }
107
108    /// Add an EXCEPT operation with another query.
109    pub fn except(self, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
110        self.add_op(SetOpType::Except, query_sql, params)
111    }
112
113    /// Add an EXCEPT ALL operation with another query.
114    pub fn except_all(self, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
115        self.add_op(SetOpType::ExceptAll, query_sql, params)
116    }
117
118    fn add_op(mut self, op: SetOpType, query_sql: impl Into<String>, params: Vec<Value>) -> Self {
119        self.op_types.push(op);
120        self.queries.push((query_sql.into(), params));
121        self
122    }
123
124    /// Add ORDER BY to the final result.
125    pub fn order_by(mut self, order: OrderBy) -> Self {
126        self.order_by.push(order);
127        self
128    }
129
130    /// Add multiple ORDER BY clauses.
131    pub fn order_by_many(mut self, orders: Vec<OrderBy>) -> Self {
132        self.order_by.extend(orders);
133        self
134    }
135
136    /// Set LIMIT on the final result.
137    pub fn limit(mut self, limit: u64) -> Self {
138        self.limit = Some(limit);
139        self
140    }
141
142    /// Set OFFSET on the final result.
143    pub fn offset(mut self, offset: u64) -> Self {
144        self.offset = Some(offset);
145        self
146    }
147
148    /// Build the SQL query with default dialect (PostgreSQL).
149    pub fn build(&self) -> (String, Vec<Value>) {
150        self.build_with_dialect(Dialect::Postgres)
151    }
152
153    /// Build the SQL query with a specific dialect.
154    pub fn build_with_dialect(&self, dialect: Dialect) -> (String, Vec<Value>) {
155        let mut sql = String::new();
156        let mut params = Vec::new();
157
158        // Build each query with set operations between them
159        for (i, (query_sql, query_params)) in self.queries.iter().enumerate() {
160            if i > 0 {
161                // Add the set operation before this query
162                let op = &self.op_types[i - 1];
163                sql.push(' ');
164                sql.push_str(op.as_sql());
165                sql.push(' ');
166            }
167
168            // Wrap each query in parentheses for clarity
169            sql.push('(');
170            sql.push_str(query_sql);
171            sql.push(')');
172
173            params.extend(query_params.clone());
174        }
175
176        // ORDER BY on final result
177        if !self.order_by.is_empty() {
178            sql.push_str(" ORDER BY ");
179            let order_strs: Vec<String> = self
180                .order_by
181                .iter()
182                .map(|o| {
183                    let expr_sql = o.expr.build_with_dialect(dialect, &mut params, 0);
184                    let dir = match o.direction {
185                        crate::clause::OrderDirection::Asc => "ASC",
186                        crate::clause::OrderDirection::Desc => "DESC",
187                    };
188                    let nulls = match o.nulls {
189                        Some(crate::clause::NullsOrder::First) => " NULLS FIRST",
190                        Some(crate::clause::NullsOrder::Last) => " NULLS LAST",
191                        None => "",
192                    };
193                    format!("{expr_sql} {dir}{nulls}")
194                })
195                .collect();
196            sql.push_str(&order_strs.join(", "));
197        }
198
199        // LIMIT
200        if let Some(limit) = self.limit {
201            sql.push_str(" LIMIT ");
202            sql.push_str(&limit.to_string());
203        }
204
205        // OFFSET
206        if let Some(offset) = self.offset {
207            sql.push_str(" OFFSET ");
208            sql.push_str(&offset.to_string());
209        }
210
211        (sql, params)
212    }
213}
214
215/// Create a UNION of multiple queries.
216///
217/// Returns `None` if the iterator is empty.
218///
219/// # Example
220///
221/// ```ignore
222/// let query = union([
223///     ("SELECT * FROM users WHERE role = 'admin'", vec![]),
224///     ("SELECT * FROM users WHERE role = 'manager'", vec![]),
225/// ]).expect("at least one query required");
226/// ```
227pub fn union<I, S>(queries: I) -> Option<SetOperation>
228where
229    I: IntoIterator<Item = (S, Vec<Value>)>,
230    S: Into<String>,
231{
232    combine_queries(SetOpType::Union, queries)
233}
234
235/// Create a UNION ALL of multiple queries.
236///
237/// Returns `None` if the iterator is empty.
238///
239/// # Example
240///
241/// ```ignore
242/// let query = union_all([
243///     ("SELECT id FROM table1", vec![]),
244///     ("SELECT id FROM table2", vec![]),
245///     ("SELECT id FROM table3", vec![]),
246/// ]).expect("at least one query required");
247/// ```
248pub fn union_all<I, S>(queries: I) -> Option<SetOperation>
249where
250    I: IntoIterator<Item = (S, Vec<Value>)>,
251    S: Into<String>,
252{
253    combine_queries(SetOpType::UnionAll, queries)
254}
255
256/// Create an INTERSECT of multiple queries.
257///
258/// Returns `None` if the iterator is empty.
259pub fn intersect<I, S>(queries: I) -> Option<SetOperation>
260where
261    I: IntoIterator<Item = (S, Vec<Value>)>,
262    S: Into<String>,
263{
264    combine_queries(SetOpType::Intersect, queries)
265}
266
267/// Create an INTERSECT ALL of multiple queries.
268///
269/// Returns `None` if the iterator is empty.
270pub fn intersect_all<I, S>(queries: I) -> Option<SetOperation>
271where
272    I: IntoIterator<Item = (S, Vec<Value>)>,
273    S: Into<String>,
274{
275    combine_queries(SetOpType::IntersectAll, queries)
276}
277
278/// Create an EXCEPT of multiple queries.
279///
280/// Returns `None` if the iterator is empty.
281pub fn except<I, S>(queries: I) -> Option<SetOperation>
282where
283    I: IntoIterator<Item = (S, Vec<Value>)>,
284    S: Into<String>,
285{
286    combine_queries(SetOpType::Except, queries)
287}
288
289/// Create an EXCEPT ALL of multiple queries.
290///
291/// Returns `None` if the iterator is empty.
292pub fn except_all<I, S>(queries: I) -> Option<SetOperation>
293where
294    I: IntoIterator<Item = (S, Vec<Value>)>,
295    S: Into<String>,
296{
297    combine_queries(SetOpType::ExceptAll, queries)
298}
299
300fn combine_queries<I, S>(op: SetOpType, queries: I) -> Option<SetOperation>
301where
302    I: IntoIterator<Item = (S, Vec<Value>)>,
303    S: Into<String>,
304{
305    let mut iter = queries.into_iter();
306
307    // Get the first query, return None if empty
308    let (first_sql, first_params) = iter.next()?;
309
310    let mut result = SetOperation::new(first_sql, first_params);
311
312    // Add remaining queries with the set operation
313    for (sql, params) in iter {
314        result = result.add_op(op, sql, params);
315    }
316
317    Some(result)
318}
319
// ==================== Tests ====================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::Expr;

    #[test]
    fn test_union_basic() {
        let query = SetOperation::new("SELECT * FROM users WHERE role = 'admin'", vec![])
            .union("SELECT * FROM users WHERE role = 'manager'", vec![]);

        let (sql, params) = query.build();
        assert_eq!(
            sql,
            "(SELECT * FROM users WHERE role = 'admin') UNION (SELECT * FROM users WHERE role = 'manager')"
        );
        assert!(params.is_empty());
    }

    #[test]
    fn test_single_query_has_no_operator() {
        let (sql, params) = SetOperation::new("SELECT 1", vec![]).build();
        assert_eq!(sql, "(SELECT 1)");
        assert!(params.is_empty());
    }

    #[test]
    fn test_union_all_basic() {
        let query = SetOperation::new("SELECT id FROM table1", vec![])
            .union_all("SELECT id FROM table2", vec![]);

        let (sql, _) = query.build();
        assert_eq!(
            sql,
            "(SELECT id FROM table1) UNION ALL (SELECT id FROM table2)"
        );
    }

    #[test]
    fn test_union_with_params() {
        let query = SetOperation::new(
            "SELECT * FROM users WHERE role = $1",
            vec![Value::Text("admin".to_string())],
        )
        .union(
            "SELECT * FROM users WHERE role = $2",
            vec![Value::Text("manager".to_string())],
        );

        let (sql, params) = query.build();
        // Member SQL is copied verbatim; parameters follow member order.
        assert_eq!(
            sql,
            "(SELECT * FROM users WHERE role = $1) UNION (SELECT * FROM users WHERE role = $2)"
        );
        assert_eq!(
            params,
            vec![
                Value::Text("admin".to_string()),
                Value::Text("manager".to_string()),
            ]
        );
    }

    #[test]
    fn test_union_function() {
        let query = union([
            ("SELECT * FROM admins", vec![]),
            ("SELECT * FROM managers", vec![]),
            ("SELECT * FROM employees", vec![]),
        ])
        .expect("non-empty iterator");

        let (sql, _) = query.build();
        assert_eq!(
            sql,
            "(SELECT * FROM admins) UNION (SELECT * FROM managers) UNION (SELECT * FROM employees)"
        );
    }

    #[test]
    fn test_union_all_function() {
        let query = union_all([
            ("SELECT 1", vec![]),
            ("SELECT 2", vec![]),
            ("SELECT 3", vec![]),
        ])
        .expect("non-empty iterator");

        let (sql, _) = query.build();
        assert_eq!(sql, "(SELECT 1) UNION ALL (SELECT 2) UNION ALL (SELECT 3)");
    }

    #[test]
    fn test_single_input_function_has_no_operator() {
        let query = union([("SELECT 1", vec![])]).expect("non-empty iterator");
        let (sql, params) = query.build();
        assert_eq!(sql, "(SELECT 1)");
        assert!(params.is_empty());
    }

    #[test]
    fn test_union_empty_returns_none() {
        fn empty() -> Vec<(&'static str, Vec<Value>)> {
            Vec::new()
        }
        assert!(union(empty()).is_none());
        assert!(union_all(empty()).is_none());
        assert!(intersect(empty()).is_none());
        assert!(intersect_all(empty()).is_none());
        assert!(except(empty()).is_none());
        assert!(except_all(empty()).is_none());
    }

    #[test]
    fn test_union_with_order_by() {
        let query = SetOperation::new("SELECT name FROM users WHERE active = true", vec![])
            .union("SELECT name FROM users WHERE premium = true", vec![])
            .order_by(Expr::col("name").asc());

        let (sql, _) = query.build();
        assert_eq!(
            sql,
            "(SELECT name FROM users WHERE active = true) UNION (SELECT name FROM users WHERE premium = true) ORDER BY \"name\" ASC"
        );
    }

    #[test]
    fn test_order_by_many_preserves_order() {
        let query = SetOperation::new("SELECT a, b FROM t1", vec![])
            .union("SELECT a, b FROM t2", vec![])
            .order_by_many(vec![Expr::col("a").asc(), Expr::col("b").desc()]);

        let (sql, params) = query.build();
        assert_eq!(
            sql,
            "(SELECT a, b FROM t1) UNION (SELECT a, b FROM t2) ORDER BY \"a\" ASC, \"b\" DESC"
        );
        assert!(params.is_empty());
    }

    #[test]
    fn test_union_with_limit_offset() {
        let query = SetOperation::new("SELECT * FROM t1", vec![])
            .union("SELECT * FROM t2", vec![])
            .limit(10)
            .offset(5);

        let (sql, _) = query.build();
        assert_eq!(sql, "(SELECT * FROM t1) UNION (SELECT * FROM t2) LIMIT 10 OFFSET 5");
    }

    #[test]
    fn test_intersect() {
        let query = SetOperation::new("SELECT id FROM users WHERE active = true", vec![])
            .intersect("SELECT id FROM users WHERE premium = true", vec![]);

        let (sql, _) = query.build();
        assert_eq!(
            sql,
            "(SELECT id FROM users WHERE active = true) INTERSECT (SELECT id FROM users WHERE premium = true)"
        );
    }

    #[test]
    fn test_intersect_all() {
        let query = intersect_all([("SELECT id FROM t1", vec![]), ("SELECT id FROM t2", vec![])])
            .expect("non-empty iterator");

        let (sql, _) = query.build();
        assert_eq!(sql, "(SELECT id FROM t1) INTERSECT ALL (SELECT id FROM t2)");
    }

    #[test]
    fn test_except() {
        let query = SetOperation::new("SELECT id FROM all_users", vec![])
            .except("SELECT id FROM banned_users", vec![]);

        let (sql, _) = query.build();
        assert_eq!(
            sql,
            "(SELECT id FROM all_users) EXCEPT (SELECT id FROM banned_users)"
        );
    }

    #[test]
    fn test_except_all() {
        let query = except_all([("SELECT id FROM t1", vec![]), ("SELECT id FROM t2", vec![])])
            .expect("non-empty iterator");

        let (sql, _) = query.build();
        assert_eq!(sql, "(SELECT id FROM t1) EXCEPT ALL (SELECT id FROM t2)");
    }

    #[test]
    fn test_chained_operations() {
        let query = SetOperation::new("SELECT id FROM t1", vec![])
            .union("SELECT id FROM t2", vec![])
            .union_all("SELECT id FROM t3", vec![]);

        let (sql, _) = query.build();
        // Each operator is placed in front of the member it was added with.
        assert_eq!(
            sql,
            "(SELECT id FROM t1) UNION (SELECT id FROM t2) UNION ALL (SELECT id FROM t3)"
        );
    }

    #[test]
    fn test_complex_query() {
        let query = SetOperation::new(
            "SELECT name, email FROM users WHERE role = $1",
            vec![Value::Text("admin".to_string())],
        )
        .union_all(
            "SELECT name, email FROM users WHERE department = $2",
            vec![Value::Text("engineering".to_string())],
        )
        .order_by(Expr::col("name").asc())
        .order_by(Expr::col("email").desc())
        .limit(100)
        .offset(0);

        let (sql, params) = query.build();

        assert!(sql.starts_with(
            "(SELECT name, email FROM users WHERE role = $1) UNION ALL (SELECT name, email FROM users WHERE department = $2)"
        ));
        assert!(sql.ends_with(" ORDER BY \"name\" ASC, \"email\" DESC LIMIT 100 OFFSET 0"));
        assert_eq!(
            params,
            vec![
                Value::Text("admin".to_string()),
                Value::Text("engineering".to_string()),
            ]
        );
    }

    #[test]
    fn test_set_op_type_sql() {
        assert_eq!(SetOpType::Union.as_sql(), "UNION");
        assert_eq!(SetOpType::UnionAll.as_sql(), "UNION ALL");
        assert_eq!(SetOpType::Intersect.as_sql(), "INTERSECT");
        assert_eq!(SetOpType::IntersectAll.as_sql(), "INTERSECT ALL");
        assert_eq!(SetOpType::Except.as_sql(), "EXCEPT");
        assert_eq!(SetOpType::ExceptAll.as_sql(), "EXCEPT ALL");
    }
}