// sqlmodel_query/cte.rs

//! Common Table Expressions (CTEs) for SQL queries.
//!
//! Provides support for `WITH` clauses, including recursive CTEs.
//!
//! # Example
//!
//! ```ignore
//! use sqlmodel_query::{Cte, WithQuery};
//!
//! // Basic CTE
//! let active_users = Cte::new("active_users")
//!     .as_select("SELECT * FROM users WHERE active = true");
//!
//! // Query using the CTE
//! let (sql, params) = WithQuery::new()
//!     .with_cte(active_users)
//!     .select("SELECT name, email FROM active_users")
//!     .build();
//!
//! // Recursive CTE for hierarchical data
//! let hierarchy = Cte::recursive("hierarchy")
//!     .columns(&["id", "name", "manager_id", "level"])
//!     .as_select("SELECT id, name, manager_id, 0 FROM employees WHERE manager_id IS NULL")
//!     .union_all("SELECT e.id, e.name, e.manager_id, h.level + 1 FROM employees e JOIN hierarchy h ON e.manager_id = h.id");
//! ```
30
31use crate::expr::{Dialect, Expr};
32use sqlmodel_core::Value;
33
34/// A Common Table Expression (WITH clause).
35#[derive(Debug, Clone)]
36pub struct Cte {
37    /// Name of the CTE
38    name: String,
39    /// Column aliases (optional)
40    columns: Vec<String>,
41    /// Whether this is a recursive CTE
42    recursive: bool,
43    /// The SQL query for the CTE (pre-built)
44    query_sql: String,
45    /// Parameters for the CTE query
46    query_params: Vec<Value>,
47    /// For recursive CTEs: the UNION part
48    union_sql: Option<String>,
49    /// Parameters for the UNION part
50    union_params: Vec<Value>,
51}
52
53impl Cte {
54    /// Create a new non-recursive CTE.
55    ///
56    /// # Arguments
57    ///
58    /// * `name` - The name of the CTE (used to reference it in the main query)
59    ///
60    /// # Example
61    ///
62    /// ```ignore
63    /// let recent_orders = Cte::new("recent_orders")
64    ///     .as_select("SELECT * FROM orders WHERE created_at > NOW() - INTERVAL '7 days'");
65    /// ```
66    pub fn new(name: impl Into<String>) -> Self {
67        Self {
68            name: name.into(),
69            columns: Vec::new(),
70            recursive: false,
71            query_sql: String::new(),
72            query_params: Vec::new(),
73            union_sql: None,
74            union_params: Vec::new(),
75        }
76    }
77
78    /// Create a new recursive CTE.
79    ///
80    /// Recursive CTEs require an initial (anchor) term and a recursive term
81    /// joined with UNION ALL.
82    ///
83    /// # Example
84    ///
85    /// ```ignore
86    /// // Traverse an employee hierarchy
87    /// let hierarchy = Cte::recursive("org_chart")
88    ///     .columns(&["id", "name", "level"])
89    ///     .as_select("SELECT id, name, 0 FROM employees WHERE manager_id IS NULL")
90    ///     .union_all("SELECT e.id, e.name, h.level + 1 FROM employees e JOIN org_chart h ON e.manager_id = h.id");
91    /// ```
92    pub fn recursive(name: impl Into<String>) -> Self {
93        Self {
94            name: name.into(),
95            columns: Vec::new(),
96            recursive: true,
97            query_sql: String::new(),
98            query_params: Vec::new(),
99            union_sql: None,
100            union_params: Vec::new(),
101        }
102    }
103
104    /// Specify column aliases for the CTE.
105    ///
106    /// # Example
107    ///
108    /// ```ignore
109    /// Cte::new("totals")
110    ///     .columns(&["category", "total_amount"])
111    ///     .as_select("SELECT category, SUM(amount) FROM orders GROUP BY category");
112    /// ```
113    pub fn columns(mut self, cols: &[&str]) -> Self {
114        self.columns = cols.iter().map(|&s| s.to_string()).collect();
115        self
116    }
117
118    /// Set the CTE query from a raw SQL string.
119    ///
120    /// # Arguments
121    ///
122    /// * `sql` - The SQL query for the CTE
123    pub fn as_select(mut self, sql: impl Into<String>) -> Self {
124        self.query_sql = sql.into();
125        self
126    }
127
128    /// Set the CTE query from SQL with parameters.
129    ///
130    /// # Arguments
131    ///
132    /// * `sql` - The SQL query for the CTE
133    /// * `params` - Parameters to bind
134    pub fn as_select_with_params(mut self, sql: impl Into<String>, params: Vec<Value>) -> Self {
135        self.query_sql = sql.into();
136        self.query_params = params;
137        self
138    }
139
140    /// Add a UNION ALL clause for recursive CTEs.
141    ///
142    /// # Arguments
143    ///
144    /// * `sql` - The recursive term SQL
145    pub fn union_all(mut self, sql: impl Into<String>) -> Self {
146        self.union_sql = Some(sql.into());
147        self
148    }
149
150    /// Add a UNION ALL clause with parameters.
151    pub fn union_all_with_params(mut self, sql: impl Into<String>, params: Vec<Value>) -> Self {
152        self.union_sql = Some(sql.into());
153        self.union_params = params;
154        self
155    }
156
157    /// Get the name of this CTE.
158    pub fn name(&self) -> &str {
159        &self.name
160    }
161
162    /// Check if this is a recursive CTE.
163    pub fn is_recursive(&self) -> bool {
164        self.recursive
165    }
166
167    /// Create a reference to this CTE for use in queries.
168    ///
169    /// # Example
170    ///
171    /// ```ignore
172    /// let cte = Cte::new("active_users").as_select("...");
173    /// let cte_ref = cte.as_ref();
174    ///
175    /// // Use in expressions
176    /// let expr = cte_ref.col("name").eq("Alice");
177    /// ```
178    pub fn as_ref(&self) -> CteRef {
179        CteRef {
180            name: self.name.clone(),
181        }
182    }
183
184    /// Build the CTE definition SQL.
185    ///
186    /// Returns the SQL for use in a WITH clause and the parameters.
187    pub fn build(&self, dialect: Dialect) -> (String, Vec<Value>) {
188        let mut sql = String::new();
189        let mut params = Vec::new();
190
191        // CTE name and optional column list
192        sql.push_str(&dialect.quote_identifier(&self.name));
193
194        if !self.columns.is_empty() {
195            sql.push_str(" (");
196            let quoted_cols: Vec<_> = self
197                .columns
198                .iter()
199                .map(|c| dialect.quote_identifier(c))
200                .collect();
201            sql.push_str(&quoted_cols.join(", "));
202            sql.push(')');
203        }
204
205        sql.push_str(" AS (");
206
207        // Main query
208        sql.push_str(&self.query_sql);
209        params.extend(self.query_params.clone());
210
211        // UNION ALL for recursive CTEs
212        if let Some(union) = &self.union_sql {
213            sql.push_str(" UNION ALL ");
214            sql.push_str(union);
215            params.extend(self.union_params.clone());
216        }
217
218        sql.push(')');
219
220        (sql, params)
221    }
222}
223
/// A reference to a CTE for use in expressions.
///
/// Holds only the CTE's name. Two references compare equal (and hash alike)
/// exactly when their names are equal, so a `CteRef` can be used as a map or
/// set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CteRef {
    /// Name of the referenced CTE.
    name: String,
}
229
230impl CteRef {
231    /// Create a new CTE reference.
232    pub fn new(name: impl Into<String>) -> Self {
233        Self { name: name.into() }
234    }
235
236    /// Reference a column in this CTE.
237    ///
238    /// # Example
239    ///
240    /// ```ignore
241    /// let cte_ref = CteRef::new("active_users");
242    /// let expr = cte_ref.col("email").like("%@example.com");
243    /// ```
244    pub fn col(&self, column: impl Into<String>) -> Expr {
245        Expr::qualified(&self.name, column)
246    }
247
248    /// Get the CTE name for use in FROM clauses.
249    pub fn name(&self) -> &str {
250        &self.name
251    }
252}
253
254/// A query with one or more CTEs.
255#[derive(Debug, Clone)]
256pub struct WithQuery {
257    /// List of CTEs in order of definition
258    ctes: Vec<Cte>,
259    /// The main query SQL
260    main_sql: String,
261    /// Parameters for the main query
262    main_params: Vec<Value>,
263}
264
265impl WithQuery {
266    /// Create a new query with CTEs.
267    pub fn new() -> Self {
268        Self {
269            ctes: Vec::new(),
270            main_sql: String::new(),
271            main_params: Vec::new(),
272        }
273    }
274
275    /// Add a CTE to this query.
276    ///
277    /// CTEs are added in order and can reference previously defined CTEs.
278    pub fn with_cte(mut self, cte: Cte) -> Self {
279        self.ctes.push(cte);
280        self
281    }
282
283    /// Add multiple CTEs to this query.
284    pub fn with_ctes(mut self, ctes: Vec<Cte>) -> Self {
285        self.ctes.extend(ctes);
286        self
287    }
288
289    /// Set the main query.
290    pub fn select(mut self, sql: impl Into<String>) -> Self {
291        self.main_sql = sql.into();
292        self
293    }
294
295    /// Set the main query with parameters.
296    pub fn select_with_params(mut self, sql: impl Into<String>, params: Vec<Value>) -> Self {
297        self.main_sql = sql.into();
298        self.main_params = params;
299        self
300    }
301
302    /// Build the complete SQL with WITH clause.
303    pub fn build(&self) -> (String, Vec<Value>) {
304        self.build_with_dialect(Dialect::Postgres)
305    }
306
307    /// Build the complete SQL with a specific dialect.
308    pub fn build_with_dialect(&self, dialect: Dialect) -> (String, Vec<Value>) {
309        let mut sql = String::new();
310        let mut params = Vec::new();
311
312        if !self.ctes.is_empty() {
313            // Check if any CTE is recursive
314            let has_recursive = self.ctes.iter().any(|c| c.recursive);
315
316            if has_recursive {
317                sql.push_str("WITH RECURSIVE ");
318            } else {
319                sql.push_str("WITH ");
320            }
321
322            // Build each CTE
323            let cte_sqls: Vec<String> = self
324                .ctes
325                .iter()
326                .map(|cte| {
327                    let (cte_sql, cte_params) = cte.build(dialect);
328                    params.extend(cte_params);
329                    cte_sql
330                })
331                .collect();
332
333            sql.push_str(&cte_sqls.join(", "));
334            sql.push(' ');
335        }
336
337        // Main query
338        sql.push_str(&self.main_sql);
339        params.extend(self.main_params.clone());
340
341        (sql, params)
342    }
343}
344
345impl Default for WithQuery {
346    fn default() -> Self {
347        Self::new()
348    }
349}
350
351// ==================== Tests ====================
352
#[cfg(test)]
mod tests {
    use super::*;

    /// Query body shared by the tests that need a simple, parameterless CTE.
    const ACTIVE_USERS_SQL: &str = "SELECT * FROM users WHERE active = true";

    /// The `active_users` CTE used by several tests.
    fn active_users_cte() -> Cte {
        Cte::new("active_users").as_select(ACTIVE_USERS_SQL)
    }

    #[test]
    fn test_basic_cte() {
        let (rendered, bound) = active_users_cte().build(Dialect::Postgres);

        assert_eq!(
            rendered,
            "\"active_users\" AS (SELECT * FROM users WHERE active = true)"
        );
        assert!(bound.is_empty());
    }

    #[test]
    fn test_cte_with_columns() {
        let totals = Cte::new("user_totals")
            .columns(&["user_id", "total"])
            .as_select("SELECT user_id, SUM(amount) FROM orders GROUP BY user_id");

        let (rendered, bound) = totals.build(Dialect::Postgres);

        assert_eq!(
            rendered,
            "\"user_totals\" (\"user_id\", \"total\") AS (SELECT user_id, SUM(amount) FROM orders GROUP BY user_id)"
        );
        assert!(bound.is_empty());
    }

    #[test]
    fn test_cte_with_params() {
        let threshold = vec![Value::Int(100)];
        let recent = Cte::new("recent_orders")
            .as_select_with_params("SELECT * FROM orders WHERE amount > $1", threshold);

        let (rendered, bound) = recent.build(Dialect::Postgres);

        assert_eq!(
            rendered,
            "\"recent_orders\" AS (SELECT * FROM orders WHERE amount > $1)"
        );
        assert_eq!(bound, vec![Value::Int(100)]);
    }

    #[test]
    fn test_recursive_cte() {
        let anchor = "SELECT id, name, 0 FROM employees WHERE manager_id IS NULL";
        let step = "SELECT e.id, e.name, h.level + 1 FROM employees e JOIN hierarchy h ON e.manager_id = h.id";

        let hierarchy = Cte::recursive("hierarchy")
            .columns(&["id", "name", "level"])
            .as_select(anchor)
            .union_all(step);

        let (rendered, _) = hierarchy.build(Dialect::Postgres);

        assert!(rendered.contains("UNION ALL"));
        assert!(hierarchy.is_recursive());
    }

    #[test]
    fn test_cte_ref_column() {
        let column_expr = CteRef::new("my_cte").col("name");

        let mut params = Vec::new();
        let rendered = column_expr.build(&mut params, 0);

        assert_eq!(rendered, "\"my_cte\".\"name\"");
    }

    #[test]
    fn test_with_query_single_cte() {
        let query = WithQuery::new()
            .with_cte(active_users_cte())
            .select("SELECT * FROM active_users");

        let (rendered, bound) = query.build();

        assert_eq!(
            rendered,
            "WITH \"active_users\" AS (SELECT * FROM users WHERE active = true) SELECT * FROM active_users"
        );
        assert!(bound.is_empty());
    }

    #[test]
    fn test_with_query_multiple_ctes() {
        let user_orders = Cte::new("user_orders")
            .as_select("SELECT u.id, COUNT(*) as order_count FROM active_users u JOIN orders o ON u.id = o.user_id GROUP BY u.id");

        let query = WithQuery::new()
            .with_cte(active_users_cte())
            .with_cte(user_orders)
            .select("SELECT * FROM user_orders WHERE order_count > 5");

        let (rendered, _) = query.build();

        assert!(rendered.starts_with("WITH "));
        for definition in ["\"active_users\" AS", "\"user_orders\" AS"] {
            assert!(rendered.contains(definition));
        }
    }

    #[test]
    fn test_with_query_recursive() {
        let numbers = Cte::recursive("numbers")
            .columns(&["n"])
            .as_select("SELECT 1")
            .union_all("SELECT n + 1 FROM numbers WHERE n < 10");

        let (rendered, _) = WithQuery::new()
            .with_cte(numbers)
            .select("SELECT * FROM numbers")
            .build();

        assert!(rendered.starts_with("WITH RECURSIVE "));
    }

    #[test]
    fn test_cte_mysql_dialect() {
        let temp = Cte::new("temp")
            .columns(&["col1", "col2"])
            .as_select("SELECT a, b FROM t");

        let (rendered, _) = temp.build(Dialect::Mysql);

        assert_eq!(rendered, "`temp` (`col1`, `col2`) AS (SELECT a, b FROM t)");
    }

    #[test]
    fn test_cte_sqlite_dialect() {
        let (rendered, _) = Cte::new("temp").as_select("SELECT 1").build(Dialect::Sqlite);

        assert_eq!(rendered, "\"temp\" AS (SELECT 1)");
    }

    #[test]
    fn test_with_query_params_aggregation() {
        let filtered = Cte::new("filtered")
            .as_select_with_params("SELECT * FROM items WHERE price > $1", vec![Value::Int(50)]);

        let query = WithQuery::new().with_cte(filtered).select_with_params(
            "SELECT * FROM filtered WHERE category = $2",
            vec![Value::Text("electronics".to_string())],
        );

        let (rendered, bound) = query.build();

        // CTE parameters come first, then the main query's.
        assert_eq!(bound.len(), 2);
        assert_eq!(bound[0], Value::Int(50));
        assert_eq!(bound[1], Value::Text("electronics".to_string()));
        assert!(rendered.contains("$1"));
        assert!(rendered.contains("$2"));
    }

    #[test]
    fn test_recursive_cte_hierarchy_example() {
        // Classic organizational hierarchy: roots first, then each level below.
        let org_chart = Cte::recursive("org_chart")
            .columns(&["id", "name", "manager_id", "level"])
            .as_select("SELECT id, name, manager_id, 0 AS level FROM employees WHERE manager_id IS NULL")
            .union_all("SELECT e.id, e.name, e.manager_id, oc.level + 1 FROM employees e INNER JOIN org_chart oc ON e.manager_id = oc.id");

        let (rendered, _) = WithQuery::new()
            .with_cte(org_chart)
            .select("SELECT * FROM org_chart ORDER BY level, name")
            .build();

        assert!(rendered.starts_with("WITH RECURSIVE "));
        assert!(rendered.contains("\"org_chart\""));
        assert!(rendered.contains("UNION ALL"));
        assert!(rendered.contains("ORDER BY level, name"));
    }

    #[test]
    fn test_cte_chained_references() {
        // The second CTE reads from the first, so definition order matters.
        let base_data =
            Cte::new("base_data").as_select("SELECT id, value FROM raw_data WHERE valid = true");
        let aggregated = Cte::new("aggregated")
            .as_select("SELECT COUNT(*) as cnt, SUM(value) as total FROM base_data");

        let (rendered, _) = WithQuery::new()
            .with_cte(base_data)
            .with_cte(aggregated)
            .select("SELECT * FROM aggregated")
            .build();

        // Verify both CTEs are present and in order
        let base_pos = rendered.find("\"base_data\"").unwrap();
        let agg_pos = rendered.find("\"aggregated\"").unwrap();
        assert!(
            base_pos < agg_pos,
            "base_data should come before aggregated"
        );
    }
}