sql_cli/refactoring/
mod.rs1use anyhow::Result;
5use serde::{Deserialize, Serialize};
6
7pub mod banding;
8pub mod conditional_agg;
9pub mod extraction;
10
11pub use banding::{CaseCondition, CaseGenerator, CaseStyle};
12
13#[derive(Debug, Serialize, Deserialize)]
14pub struct RefactoringResult {
15 pub original: String,
16 pub transformed: String,
17 pub description: String,
18 #[serde(skip_serializing_if = "Option::is_none")]
19 pub suggestions: Option<Vec<String>>,
20}
21
22#[derive(Debug, Serialize, Deserialize)]
23pub struct BandingConfig {
24 pub column: String,
25 pub bands: Vec<Band>,
26 pub else_label: Option<String>,
27 pub alias: Option<String>,
28}
29
30#[derive(Debug, Serialize, Deserialize, PartialEq)]
31pub struct Band {
32 pub min: Option<f64>,
33 pub max: Option<f64>,
34 pub label: String,
35}
36
37impl BandingConfig {
38 pub fn from_string(column: &str, bands_str: &str) -> Result<Self> {
40 let mut bands = Vec::new();
41
42 for band_str in bands_str.split(',') {
43 let band_str = band_str.trim();
44
45 if band_str.ends_with('+') {
46 let min_str = band_str.trim_end_matches('+');
48 let min = min_str.parse::<f64>()?;
49 bands.push(Band {
50 min: Some(min),
51 max: None,
52 label: band_str.to_string(),
53 });
54 } else if band_str.contains('-') {
55 let parts: Vec<&str> = band_str.split('-').collect();
57 if parts.len() == 2 {
58 let min = parts[0].parse::<f64>()?;
59 let max = parts[1].parse::<f64>()?;
60 bands.push(Band {
61 min: Some(min),
62 max: Some(max),
63 label: band_str.to_string(),
64 });
65 }
66 }
67 }
68
69 Ok(BandingConfig {
70 column: column.to_string(),
71 bands,
72 else_label: None,
73 alias: Some(format!("{}_band", column)),
74 })
75 }
76
77 pub fn to_sql(&self) -> String {
79 let mut sql = String::from("CASE");
80
81 for band in &self.bands {
82 sql.push('\n');
83 if let Some(min) = band.min {
84 if let Some(max) = band.max {
85 if band.min == Some(0.0) || self.bands.iter().position(|b| b == band) == Some(0)
86 {
87 sql.push_str(&format!(
88 " WHEN {} <= {} THEN '{}'",
89 self.column, max, band.label
90 ));
91 } else {
92 sql.push_str(&format!(
93 " WHEN {} > {} AND {} <= {} THEN '{}'",
94 self.column, min, self.column, max, band.label
95 ));
96 }
97 } else {
98 sql.push_str(&format!(
99 " WHEN {} > {} THEN '{}'",
100 self.column, min, band.label
101 ));
102 }
103 } else if let Some(max) = band.max {
104 sql.push_str(&format!(
105 " WHEN {} <= {} THEN '{}'",
106 self.column, max, band.label
107 ));
108 }
109 }
110
111 if let Some(else_label) = &self.else_label {
112 sql.push_str(&format!("\n ELSE '{}'", else_label));
113 }
114
115 sql.push_str("\nEND");
116
117 if let Some(alias) = &self.alias {
118 sql.push_str(&format!(" AS {}", alias));
119 }
120
121 sql
122 }
123}
124
125pub fn generate_auto_bands(min: f64, max: f64, num_buckets: usize) -> Vec<Band> {
127 let range = max - min;
128 let bucket_size = range / num_buckets as f64;
129
130 let mut bands = Vec::new();
131
132 for i in 0..num_buckets {
133 let band_min = min + (i as f64 * bucket_size);
134 let band_max = if i == num_buckets - 1 {
135 max
136 } else {
137 min + ((i + 1) as f64 * bucket_size)
138 };
139
140 let label = if i == num_buckets - 1 {
141 format!("{:.0}+", band_min)
142 } else {
143 format!("{:.0}-{:.0}", band_min, band_max)
144 };
145
146 bands.push(Band {
147 min: if i == 0 { None } else { Some(band_min) },
148 max: if i == num_buckets - 1 {
149 None
150 } else {
151 Some(band_max)
152 },
153 label,
154 });
155 }
156
157 bands
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// A parsed band specification renders the expected `WHEN` branches and alias.
    #[test]
    fn test_banding_generation() {
        let rendered = BandingConfig::from_string("age", "0-10,11-20,21-30,30+")
            .unwrap()
            .to_sql();

        let expected_fragments = [
            "WHEN age <= 10 THEN '0-10'",
            "WHEN age > 30 THEN '30+'",
            "AS age_band",
        ];
        for fragment in expected_fragments.iter() {
            assert!(rendered.contains(fragment));
        }
    }

    /// Four automatic buckets over 0..100 get the expected first and last labels.
    #[test]
    fn test_auto_bands() {
        let generated = generate_auto_bands(0.0, 100.0, 4);

        assert_eq!(generated.len(), 4);
        assert_eq!(generated[0].label, "0-25");
        assert_eq!(generated[3].label, "75+");
    }
}