sql_cli/refactoring/
mod.rs

1// SQL Refactoring Tools for IDE Integration
2// Provides programmatic transformations of SQL queries
3
4use anyhow::Result;
5use serde::{Deserialize, Serialize};
6
7pub mod banding;
8pub mod conditional_agg;
9pub mod extraction;
10
11pub use banding::{CaseCondition, CaseGenerator, CaseStyle};
12
13#[derive(Debug, Serialize, Deserialize)]
14pub struct RefactoringResult {
15    pub original: String,
16    pub transformed: String,
17    pub description: String,
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub suggestions: Option<Vec<String>>,
20}
21
22#[derive(Debug, Serialize, Deserialize)]
23pub struct BandingConfig {
24    pub column: String,
25    pub bands: Vec<Band>,
26    pub else_label: Option<String>,
27    pub alias: Option<String>,
28}
29
30#[derive(Debug, Serialize, Deserialize, PartialEq)]
31pub struct Band {
32    pub min: Option<f64>,
33    pub max: Option<f64>,
34    pub label: String,
35}
36
37impl BandingConfig {
38    /// Parse bands from string format like "0-10,11-20,21-30,30+"
39    pub fn from_string(column: &str, bands_str: &str) -> Result<Self> {
40        let mut bands = Vec::new();
41
42        for band_str in bands_str.split(',') {
43            let band_str = band_str.trim();
44
45            if band_str.ends_with('+') {
46                // Handle "30+" format
47                let min_str = band_str.trim_end_matches('+');
48                let min = min_str.parse::<f64>()?;
49                bands.push(Band {
50                    min: Some(min),
51                    max: None,
52                    label: band_str.to_string(),
53                });
54            } else if band_str.contains('-') {
55                // Handle "0-10" format
56                let parts: Vec<&str> = band_str.split('-').collect();
57                if parts.len() == 2 {
58                    let min = parts[0].parse::<f64>()?;
59                    let max = parts[1].parse::<f64>()?;
60                    bands.push(Band {
61                        min: Some(min),
62                        max: Some(max),
63                        label: band_str.to_string(),
64                    });
65                }
66            }
67        }
68
69        Ok(BandingConfig {
70            column: column.to_string(),
71            bands,
72            else_label: None,
73            alias: Some(format!("{}_band", column)),
74        })
75    }
76
77    /// Generate SQL CASE statement for banding
78    pub fn to_sql(&self) -> String {
79        let mut sql = String::from("CASE");
80
81        for band in &self.bands {
82            sql.push('\n');
83            if let Some(min) = band.min {
84                if let Some(max) = band.max {
85                    if band.min == Some(0.0) || self.bands.iter().position(|b| b == band) == Some(0)
86                    {
87                        sql.push_str(&format!(
88                            "    WHEN {} <= {} THEN '{}'",
89                            self.column, max, band.label
90                        ));
91                    } else {
92                        sql.push_str(&format!(
93                            "    WHEN {} > {} AND {} <= {} THEN '{}'",
94                            self.column, min, self.column, max, band.label
95                        ));
96                    }
97                } else {
98                    sql.push_str(&format!(
99                        "    WHEN {} > {} THEN '{}'",
100                        self.column, min, band.label
101                    ));
102                }
103            } else if let Some(max) = band.max {
104                sql.push_str(&format!(
105                    "    WHEN {} <= {} THEN '{}'",
106                    self.column, max, band.label
107                ));
108            }
109        }
110
111        if let Some(else_label) = &self.else_label {
112            sql.push_str(&format!("\n    ELSE '{}'", else_label));
113        }
114
115        sql.push_str("\nEND");
116
117        if let Some(alias) = &self.alias {
118            sql.push_str(&format!(" AS {}", alias));
119        }
120
121        sql
122    }
123}
124
125/// Generate automatic bands for a numeric range
126pub fn generate_auto_bands(min: f64, max: f64, num_buckets: usize) -> Vec<Band> {
127    let range = max - min;
128    let bucket_size = range / num_buckets as f64;
129
130    let mut bands = Vec::new();
131
132    for i in 0..num_buckets {
133        let band_min = min + (i as f64 * bucket_size);
134        let band_max = if i == num_buckets - 1 {
135            max
136        } else {
137            min + ((i + 1) as f64 * bucket_size)
138        };
139
140        let label = if i == num_buckets - 1 {
141            format!("{:.0}+", band_min)
142        } else {
143            format!("{:.0}-{:.0}", band_min, band_max)
144        };
145
146        bands.push(Band {
147            min: if i == 0 { None } else { Some(band_min) },
148            max: if i == num_buckets - 1 {
149                None
150            } else {
151                Some(band_max)
152            },
153            label,
154        });
155    }
156
157    bands
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// A parsed band specification renders the expected `CASE` fragments.
    #[test]
    fn test_banding_generation() {
        let sql = BandingConfig::from_string("age", "0-10,11-20,21-30,30+")
            .unwrap()
            .to_sql();

        let expected_fragments = [
            "WHEN age <= 10 THEN '0-10'",
            "WHEN age > 30 THEN '30+'",
            "AS age_band",
        ];
        for fragment in expected_fragments.iter() {
            assert!(sql.contains(*fragment));
        }
    }

    /// Auto-generated bands come out in the requested number with edge-based labels.
    #[test]
    fn test_auto_bands() {
        let bands = generate_auto_bands(0.0, 100.0, 4);
        let labels: Vec<&str> = bands.iter().map(|band| band.label.as_str()).collect();

        assert_eq!(labels.len(), 4);
        assert_eq!(labels[0], "0-25");
        assert_eq!(labels[3], "75+");
    }
}