sql_cli/refactoring/
extraction.rs1use anyhow::Result;
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7
8#[derive(Debug, Serialize, Deserialize)]
9pub struct ExtractionSuggestion {
10 pub expression: String,
11 pub reason: ExtractionReason,
12 pub suggested_cte_name: String,
13 pub cte_query: String,
14 pub replacement: String,
15 pub complexity_score: u32,
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
19pub enum ExtractionReason {
20 ComplexCalculation,
21 RepeatedExpression,
22 WindowFunction,
23 Subquery,
24 StringManipulation,
25 CaseStatement,
26 AggregateInWhere,
27}
28
29#[derive(Debug, Serialize, Deserialize)]
30pub struct CTEChain {
31 pub ctes: Vec<CTEDefinition>,
32 pub main_query: String,
33}
34
35#[derive(Debug, Serialize, Deserialize)]
36pub struct CTEDefinition {
37 pub name: String,
38 pub query: String,
39 pub dependencies: Vec<String>,
40 pub columns: Vec<String>,
41}
42
43pub struct ExtractionAnalyzer;
45
46impl ExtractionAnalyzer {
47 pub fn analyze(sql: &str) -> Vec<ExtractionSuggestion> {
49 let mut suggestions = Vec::new();
50
51 if sql.contains(" * ") || sql.contains(" / ") {
53 if let Some(expr) = Self::find_complex_calculation(sql) {
54 suggestions.push(ExtractionSuggestion {
55 expression: expr.clone(),
56 reason: ExtractionReason::ComplexCalculation,
57 suggested_cte_name: "calculated".to_string(),
58 cte_query: Self::generate_cte_for_calculation(&expr),
59 replacement: "calculated_value".to_string(),
60 complexity_score: 10,
61 });
62 }
63 }
64
65 if sql.to_uppercase().contains("CASE WHEN") {
67 if let Some(case_expr) = Self::find_case_statement(sql) {
68 suggestions.push(ExtractionSuggestion {
69 expression: case_expr.clone(),
70 reason: ExtractionReason::CaseStatement,
71 suggested_cte_name: "categorized".to_string(),
72 cte_query: Self::generate_cte_for_case(&case_expr),
73 replacement: "category".to_string(),
74 complexity_score: 15,
75 });
76 }
77 }
78
79 if sql.contains("SUBSTRING") || sql.contains("CONTAINS") {
81 if let Some(str_expr) = Self::find_string_manipulation(sql) {
82 suggestions.push(ExtractionSuggestion {
83 expression: str_expr.clone(),
84 reason: ExtractionReason::StringManipulation,
85 suggested_cte_name: "parsed".to_string(),
86 cte_query: Self::generate_cte_for_string(&str_expr),
87 replacement: "parsed_value".to_string(),
88 complexity_score: 12,
89 });
90 }
91 }
92
93 if sql.contains("OVER (") {
95 if let Some(window_expr) = Self::find_window_function(sql) {
96 suggestions.push(ExtractionSuggestion {
97 expression: window_expr.clone(),
98 reason: ExtractionReason::WindowFunction,
99 suggested_cte_name: "windowed".to_string(),
100 cte_query: Self::generate_cte_for_window(&window_expr),
101 replacement: "window_result".to_string(),
102 complexity_score: 20,
103 });
104 }
105 }
106
107 suggestions.sort_by_key(|s| std::cmp::Reverse(s.complexity_score));
109
110 suggestions
111 }
112
113 fn find_complex_calculation(sql: &str) -> Option<String> {
114 if sql.contains("price * quantity") {
116 return Some("price * quantity".to_string());
117 }
118 if sql.contains("amount * rate") {
119 return Some("amount * rate".to_string());
120 }
121 None
122 }
123
124 fn find_case_statement(sql: &str) -> Option<String> {
125 let upper = sql.to_uppercase();
127 if let Some(start) = upper.find("CASE") {
128 if let Some(end) = upper[start..].find("END") {
129 return Some(sql[start..start + end + 3].to_string());
130 }
131 }
132 None
133 }
134
135 fn find_string_manipulation(sql: &str) -> Option<String> {
136 if sql.contains("SUBSTRING_AFTER") {
138 if let Some(start) = sql.find("SUBSTRING_AFTER") {
140 if let Some(end) = Self::find_matching_paren(&sql[start..]) {
141 return Some(sql[start..start + end + 1].to_string());
142 }
143 }
144 }
145 None
146 }
147
148 fn find_window_function(sql: &str) -> Option<String> {
149 if let Some(start) = sql.find("ROW_NUMBER()") {
151 if let Some(over_pos) = sql[start..].find("OVER") {
152 if let Some(end) = Self::find_matching_paren(&sql[start + over_pos + 4..]) {
153 return Some(sql[start..start + over_pos + 5 + end].to_string());
154 }
155 }
156 }
157 None
158 }
159
160 fn find_matching_paren(s: &str) -> Option<usize> {
161 let mut depth = 0;
162 let mut in_paren = false;
163
164 for (i, ch) in s.char_indices() {
165 match ch {
166 '(' => {
167 depth += 1;
168 in_paren = true;
169 }
170 ')' => {
171 depth -= 1;
172 if depth == 0 && in_paren {
173 return Some(i);
174 }
175 }
176 _ => {}
177 }
178 }
179 None
180 }
181
182 fn generate_cte_for_calculation(expr: &str) -> String {
183 format!("SELECT *, {} as calculated_value FROM source_table", expr)
184 }
185
186 fn generate_cte_for_case(expr: &str) -> String {
187 format!("SELECT *, {} as category FROM source_table", expr)
188 }
189
190 fn generate_cte_for_string(expr: &str) -> String {
191 format!("SELECT *, {} as parsed_value FROM source_table", expr)
192 }
193
194 fn generate_cte_for_window(expr: &str) -> String {
195 format!("SELECT *, {} as window_result FROM source_table", expr)
196 }
197}
198
199pub struct CTEOptimizer;
201
202impl CTEOptimizer {
203 pub fn optimize_chain(chain: &CTEChain) -> Vec<String> {
205 let mut suggestions = Vec::new();
206
207 for i in 0..chain.ctes.len() {
209 for j in i + 1..chain.ctes.len() {
210 if Self::can_combine(&chain.ctes[i], &chain.ctes[j]) {
211 suggestions.push(format!(
212 "CTEs '{}' and '{}' could be combined to reduce complexity",
213 chain.ctes[i].name, chain.ctes[j].name
214 ));
215 }
216 }
217 }
218
219 let used_ctes = Self::find_used_ctes(&chain.main_query, &chain.ctes);
221 for cte in &chain.ctes {
222 if !used_ctes.contains(&cte.name) {
223 suggestions.push(format!("CTE '{}' appears to be unused", cte.name));
224 }
225 }
226
227 if Self::is_linear_chain(&chain.ctes) {
229 suggestions.push("This linear CTE chain could potentially be flattened".to_string());
230 }
231
232 suggestions
233 }
234
235 fn can_combine(cte1: &CTEDefinition, cte2: &CTEDefinition) -> bool {
236 cte1.dependencies.contains(&cte2.name) || cte2.dependencies.contains(&cte1.name)
238 }
239
240 fn find_used_ctes(query: &str, ctes: &[CTEDefinition]) -> HashSet<String> {
241 let mut used = HashSet::new();
242 for cte in ctes {
243 if query.contains(&cte.name) {
244 used.insert(cte.name.clone());
245 }
246 }
247 used
248 }
249
250 fn is_linear_chain(ctes: &[CTEDefinition]) -> bool {
251 for i in 1..ctes.len() {
253 if ctes[i].dependencies.len() != 1 {
254 return false;
255 }
256 if !ctes[i].dependencies.contains(&ctes[i - 1].name) {
257 return false;
258 }
259 }
260 true
261 }
262}
263
264pub fn suggest_extraction(sql: &str) -> Result<serde_json::Value> {
266 let suggestions = ExtractionAnalyzer::analyze(sql);
267
268 Ok(serde_json::json!({
269 "original": sql,
270 "suggestions": suggestions,
271 "recommendation": if !suggestions.is_empty() {
272 format!("Consider extracting {} expressions to CTEs", suggestions.len())
273 } else {
274 "No extraction opportunities found".to_string()
275 }
276 }))
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Arithmetic in a WHERE clause should be the top-ranked suggestion.
    #[test]
    fn test_extraction_detection() {
        let suggestions =
            ExtractionAnalyzer::analyze("SELECT * FROM orders WHERE price * quantity > 1000");

        assert!(!suggestions.is_empty());
        let top = &suggestions[0];
        assert!(matches!(top.reason, ExtractionReason::ComplexCalculation));
    }

    /// A searched CASE expression should be flagged somewhere in the results.
    #[test]
    fn test_case_extraction() {
        let suggestions = ExtractionAnalyzer::analyze(
            "SELECT CASE WHEN age <= 20 THEN 'young' ELSE 'old' END FROM users",
        );

        let has_case = suggestions
            .iter()
            .any(|s| matches!(s.reason, ExtractionReason::CaseStatement));
        assert!(has_case);
    }
}