sql_cli/refactoring/
extraction.rs

// CTE extraction and hoisting tools.
// Analyzes SQL and suggests CTE extractions; it reports opportunities but does not rewrite the query.
3
4use anyhow::Result;
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7
8#[derive(Debug, Serialize, Deserialize)]
9pub struct ExtractionSuggestion {
10    pub expression: String,
11    pub reason: ExtractionReason,
12    pub suggested_cte_name: String,
13    pub cte_query: String,
14    pub replacement: String,
15    pub complexity_score: u32,
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
19pub enum ExtractionReason {
20    ComplexCalculation,
21    RepeatedExpression,
22    WindowFunction,
23    Subquery,
24    StringManipulation,
25    CaseStatement,
26    AggregateInWhere,
27}
28
29#[derive(Debug, Serialize, Deserialize)]
30pub struct CTEChain {
31    pub ctes: Vec<CTEDefinition>,
32    pub main_query: String,
33}
34
35#[derive(Debug, Serialize, Deserialize)]
36pub struct CTEDefinition {
37    pub name: String,
38    pub query: String,
39    pub dependencies: Vec<String>,
40    pub columns: Vec<String>,
41}
42
43/// Analyzes a query for potential CTE extractions
44pub struct ExtractionAnalyzer;
45
46impl ExtractionAnalyzer {
47    /// Analyze a SQL query for extraction opportunities
48    pub fn analyze(sql: &str) -> Vec<ExtractionSuggestion> {
49        let mut suggestions = Vec::new();
50
51        // Pattern 1: Complex calculations (multiplication, division, functions)
52        if sql.contains(" * ") || sql.contains(" / ") {
53            if let Some(expr) = Self::find_complex_calculation(sql) {
54                suggestions.push(ExtractionSuggestion {
55                    expression: expr.clone(),
56                    reason: ExtractionReason::ComplexCalculation,
57                    suggested_cte_name: "calculated".to_string(),
58                    cte_query: Self::generate_cte_for_calculation(&expr),
59                    replacement: "calculated_value".to_string(),
60                    complexity_score: 10,
61                });
62            }
63        }
64
65        // Pattern 2: CASE statements in WHERE or complex CASE
66        if sql.to_uppercase().contains("CASE WHEN") {
67            if let Some(case_expr) = Self::find_case_statement(sql) {
68                suggestions.push(ExtractionSuggestion {
69                    expression: case_expr.clone(),
70                    reason: ExtractionReason::CaseStatement,
71                    suggested_cte_name: "categorized".to_string(),
72                    cte_query: Self::generate_cte_for_case(&case_expr),
73                    replacement: "category".to_string(),
74                    complexity_score: 15,
75                });
76            }
77        }
78
79        // Pattern 3: String manipulation (SUBSTRING, CONTAINS, etc.)
80        if sql.contains("SUBSTRING") || sql.contains("CONTAINS") {
81            if let Some(str_expr) = Self::find_string_manipulation(sql) {
82                suggestions.push(ExtractionSuggestion {
83                    expression: str_expr.clone(),
84                    reason: ExtractionReason::StringManipulation,
85                    suggested_cte_name: "parsed".to_string(),
86                    cte_query: Self::generate_cte_for_string(&str_expr),
87                    replacement: "parsed_value".to_string(),
88                    complexity_score: 12,
89                });
90            }
91        }
92
93        // Pattern 4: Window functions that could be pre-computed
94        if sql.contains("OVER (") {
95            if let Some(window_expr) = Self::find_window_function(sql) {
96                suggestions.push(ExtractionSuggestion {
97                    expression: window_expr.clone(),
98                    reason: ExtractionReason::WindowFunction,
99                    suggested_cte_name: "windowed".to_string(),
100                    cte_query: Self::generate_cte_for_window(&window_expr),
101                    replacement: "window_result".to_string(),
102                    complexity_score: 20,
103                });
104            }
105        }
106
107        // Sort by complexity score (higher = more beneficial to extract)
108        suggestions.sort_by_key(|s| std::cmp::Reverse(s.complexity_score));
109
110        suggestions
111    }
112
113    fn find_complex_calculation(sql: &str) -> Option<String> {
114        // Simplified pattern matching - in real implementation would use parser
115        if sql.contains("price * quantity") {
116            return Some("price * quantity".to_string());
117        }
118        if sql.contains("amount * rate") {
119            return Some("amount * rate".to_string());
120        }
121        None
122    }
123
124    fn find_case_statement(sql: &str) -> Option<String> {
125        // Find CASE...END blocks
126        let upper = sql.to_uppercase();
127        if let Some(start) = upper.find("CASE") {
128            if let Some(end) = upper[start..].find("END") {
129                return Some(sql[start..start + end + 3].to_string());
130            }
131        }
132        None
133    }
134
135    fn find_string_manipulation(sql: &str) -> Option<String> {
136        // Find string functions
137        if sql.contains("SUBSTRING_AFTER") {
138            // Extract the full function call
139            if let Some(start) = sql.find("SUBSTRING_AFTER") {
140                if let Some(end) = Self::find_matching_paren(&sql[start..]) {
141                    return Some(sql[start..start + end + 1].to_string());
142                }
143            }
144        }
145        None
146    }
147
148    fn find_window_function(sql: &str) -> Option<String> {
149        // Find window function expressions
150        if let Some(start) = sql.find("ROW_NUMBER()") {
151            if let Some(over_pos) = sql[start..].find("OVER") {
152                if let Some(end) = Self::find_matching_paren(&sql[start + over_pos + 4..]) {
153                    return Some(sql[start..start + over_pos + 5 + end].to_string());
154                }
155            }
156        }
157        None
158    }
159
160    fn find_matching_paren(s: &str) -> Option<usize> {
161        let mut depth = 0;
162        let mut in_paren = false;
163
164        for (i, ch) in s.char_indices() {
165            match ch {
166                '(' => {
167                    depth += 1;
168                    in_paren = true;
169                }
170                ')' => {
171                    depth -= 1;
172                    if depth == 0 && in_paren {
173                        return Some(i);
174                    }
175                }
176                _ => {}
177            }
178        }
179        None
180    }
181
182    fn generate_cte_for_calculation(expr: &str) -> String {
183        format!("SELECT *, {} as calculated_value FROM source_table", expr)
184    }
185
186    fn generate_cte_for_case(expr: &str) -> String {
187        format!("SELECT *, {} as category FROM source_table", expr)
188    }
189
190    fn generate_cte_for_string(expr: &str) -> String {
191        format!("SELECT *, {} as parsed_value FROM source_table", expr)
192    }
193
194    fn generate_cte_for_window(expr: &str) -> String {
195        format!("SELECT *, {} as window_result FROM source_table", expr)
196    }
197}
198
199/// Optimizes CTE chains by analyzing dependencies and suggesting combinations
200pub struct CTEOptimizer;
201
202impl CTEOptimizer {
203    /// Analyze a CTE chain and suggest optimizations
204    pub fn optimize_chain(chain: &CTEChain) -> Vec<String> {
205        let mut suggestions = Vec::new();
206
207        // Check for CTEs that could be combined
208        for i in 0..chain.ctes.len() {
209            for j in i + 1..chain.ctes.len() {
210                if Self::can_combine(&chain.ctes[i], &chain.ctes[j]) {
211                    suggestions.push(format!(
212                        "CTEs '{}' and '{}' could be combined to reduce complexity",
213                        chain.ctes[i].name, chain.ctes[j].name
214                    ));
215                }
216            }
217        }
218
219        // Check for unused CTEs
220        let used_ctes = Self::find_used_ctes(&chain.main_query, &chain.ctes);
221        for cte in &chain.ctes {
222            if !used_ctes.contains(&cte.name) {
223                suggestions.push(format!("CTE '{}' appears to be unused", cte.name));
224            }
225        }
226
227        // Check for linear chains that could be flattened
228        if Self::is_linear_chain(&chain.ctes) {
229            suggestions.push("This linear CTE chain could potentially be flattened".to_string());
230        }
231
232        suggestions
233    }
234
235    fn can_combine(cte1: &CTEDefinition, cte2: &CTEDefinition) -> bool {
236        // Simple heuristic: if one depends on the other and doesn't add much complexity
237        cte1.dependencies.contains(&cte2.name) || cte2.dependencies.contains(&cte1.name)
238    }
239
240    fn find_used_ctes(query: &str, ctes: &[CTEDefinition]) -> HashSet<String> {
241        let mut used = HashSet::new();
242        for cte in ctes {
243            if query.contains(&cte.name) {
244                used.insert(cte.name.clone());
245            }
246        }
247        used
248    }
249
250    fn is_linear_chain(ctes: &[CTEDefinition]) -> bool {
251        // Check if each CTE depends only on the previous one
252        for i in 1..ctes.len() {
253            if ctes[i].dependencies.len() != 1 {
254                return false;
255            }
256            if !ctes[i].dependencies.contains(&ctes[i - 1].name) {
257                return false;
258            }
259        }
260        true
261    }
262}
263
264/// Generates SQL transformation suggestions
265pub fn suggest_extraction(sql: &str) -> Result<serde_json::Value> {
266    let suggestions = ExtractionAnalyzer::analyze(sql);
267
268    Ok(serde_json::json!({
269        "original": sql,
270        "suggestions": suggestions,
271        "recommendation": if !suggestions.is_empty() {
272            format!("Consider extracting {} expressions to CTEs", suggestions.len())
273        } else {
274            "No extraction opportunities found".to_string()
275        }
276    }))
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// An arithmetic expression should be the top-ranked suggestion.
    #[test]
    fn test_extraction_detection() {
        let query = "SELECT * FROM orders WHERE price * quantity > 1000";
        let found = ExtractionAnalyzer::analyze(query);

        let top = found.first().expect("expected at least one suggestion");
        assert!(matches!(top.reason, ExtractionReason::ComplexCalculation));
    }

    /// A CASE expression should produce a CaseStatement suggestion.
    #[test]
    fn test_case_extraction() {
        let query = "SELECT CASE WHEN age <= 20 THEN 'young' ELSE 'old' END FROM users";
        let has_case = ExtractionAnalyzer::analyze(query)
            .into_iter()
            .any(|s| matches!(s.reason, ExtractionReason::CaseStatement));

        assert!(has_case);
    }
}