sql_cli/cli/
refactoring.rs1use std::io;
2
3use crate::data::csv_datasource;
4use crate::refactoring::banding::CaseGenerator;
5
6pub fn handle_banding_generation(args: &[String]) -> io::Result<()> {
8 let column_pos = args.iter().position(|arg| arg == "--column");
10 let bands_pos = args.iter().position(|arg| arg == "--bands");
11
12 let column = column_pos.and_then(|pos| args.get(pos + 1));
13 let bands = bands_pos.and_then(|pos| args.get(pos + 1));
14
15 match (column, bands) {
16 (Some(col), Some(bands_spec)) => {
17 let case_sql = generate_banding_case(col, bands_spec);
19 println!("{}", case_sql);
20 Ok(())
21 }
22 _ => {
23 eprintln!("Error: --generate-bands requires --column <name> and --bands <spec>");
24 eprintln!("Example: --generate-bands --column age --bands \"0-24,25-49,50-74,75+\"");
25 std::process::exit(1);
26 }
27 }
28}
29
30pub fn handle_case_generation(args: &[String]) -> io::Result<()> {
32 let case_pos = args
34 .iter()
35 .position(|arg| arg == "--generate-case")
36 .unwrap();
37 let file_path = args.get(case_pos + 1);
38
39 let column_pos = args.iter().position(|arg| arg == "--column");
41 let column = column_pos.and_then(|pos| args.get(pos + 1));
42
43 let style_pos = args.iter().position(|arg| arg == "--style");
44 let style = style_pos
45 .and_then(|pos| args.get(pos + 1))
46 .map(|s| s.as_str())
47 .unwrap_or("values");
48
49 let labels_pos = args.iter().position(|arg| arg == "--labels");
50 let labels = labels_pos.and_then(|pos| args.get(pos + 1)).map(|l| {
51 l.split(',')
52 .map(|s| s.trim().to_string())
53 .collect::<Vec<_>>()
54 });
55
56 match (file_path, column) {
57 (Some(path), Some(col)) => {
58 let datasource = match csv_datasource::CsvDataSource::load_from_file(path, "data") {
60 Ok(ds) => ds,
61 Err(e) => {
62 eprintln!("Error loading file {}: {}", path, e);
63 std::process::exit(1);
64 }
65 };
66
67 let datatable = datasource.to_datatable();
68
69 let col_index = datatable
71 .columns
72 .iter()
73 .position(|c| c.name.eq_ignore_ascii_case(col))
74 .ok_or_else(|| {
75 io::Error::new(
76 io::ErrorKind::NotFound,
77 format!("Column '{}' not found", col),
78 )
79 })?;
80
81 match style {
83 "values" => {
84 let mut distinct_values = std::collections::BTreeSet::new();
86 for row in &datatable.rows {
87 if let Some(value) = row.get(col_index) {
88 if !value.is_null() {
89 distinct_values.insert(value.to_string());
90 }
91 }
92 }
93
94 let value_mappings: Vec<(String, String)> = if let Some(ref labels) = labels {
96 distinct_values
97 .into_iter()
98 .zip(labels.iter())
99 .map(|(v, l)| (v, l.clone()))
100 .collect()
101 } else {
102 distinct_values
103 .into_iter()
104 .map(|v| {
105 let label = v.replace('_', " ").replace('-', " ");
106 let label = label
107 .split_whitespace()
108 .map(|word| {
109 let mut chars = word.chars();
110 match chars.next() {
111 None => String::new(),
112 Some(first) => {
113 first.to_uppercase().collect::<String>()
114 + chars.as_str()
115 }
116 }
117 })
118 .collect::<Vec<_>>()
119 .join(" ");
120 (v, label)
121 })
122 .collect()
123 };
124
125 let generator = CaseGenerator::from_values(col, value_mappings);
126 println!("{}", generator.to_sql());
127 }
128 "ranges" => {
129 let mut min_val = f64::MAX;
131 let mut max_val = f64::MIN;
132 let mut count = 0;
133
134 for row in &datatable.rows {
135 if let Some(value) = row.get(col_index) {
136 if let Ok(num) = value.to_string().parse::<f64>() {
137 min_val = min_val.min(num);
138 max_val = max_val.max(num);
139 count += 1;
140 }
141 }
142 }
143
144 if count == 0 {
145 eprintln!("No numeric values found in column '{}'", col);
146 std::process::exit(1);
147 }
148
149 let range = max_val - min_val;
151 let bands_spec = if range <= 100.0 {
152 let num_bands = ((range / 10.0).ceil() as usize).min(10);
154 let mut bands = Vec::new();
155 for i in 0..num_bands {
156 let start = min_val + (i as f64 * 10.0);
157 let end = (min_val + ((i + 1) as f64 * 10.0)).min(max_val);
158 if i == num_bands - 1 {
159 bands.push(format!("{:.0}+", start));
160 } else {
161 bands.push(format!("{:.0}-{:.0}", start, end));
162 }
163 }
164 bands.join(",")
165 } else {
166 let step = range / 5.0;
168 let mut bands = Vec::new();
169 for i in 0..5 {
170 let start = min_val + (i as f64 * step);
171 let end = min_val + ((i + 1) as f64 * step);
172 if i == 4 {
173 bands.push(format!("{:.0}+", start));
174 } else {
175 bands.push(format!("{:.0}-{:.0}", start, end));
176 }
177 }
178 bands.join(",")
179 };
180
181 let generator = CaseGenerator::from_ranges(col, &bands_spec, labels).unwrap();
182 println!("{}", generator.to_sql());
183 }
184 _ => {
185 eprintln!("Unknown style: {}. Use 'values' or 'ranges'", style);
186 std::process::exit(1);
187 }
188 }
189
190 Ok(())
191 }
192 _ => {
193 eprintln!("Error: --generate-case requires a file path and --column <name>");
194 eprintln!("Example: --generate-case data.csv --column ocean_proximity --style values");
195 std::process::exit(1);
196 }
197 }
198}
199
200pub fn handle_case_range_generation(args: &[String]) -> io::Result<()> {
202 let column_pos = args.iter().position(|arg| arg == "--column");
204 let column = column_pos.and_then(|pos| args.get(pos + 1));
205
206 let min_pos = args.iter().position(|arg| arg == "--min");
207 let min_val = min_pos
208 .and_then(|pos| args.get(pos + 1))
209 .and_then(|s| s.parse::<f64>().ok())
210 .unwrap_or(0.0);
211
212 let max_pos = args.iter().position(|arg| arg == "--max");
213 let max_val = max_pos
214 .and_then(|pos| args.get(pos + 1))
215 .and_then(|s| s.parse::<f64>().ok())
216 .unwrap_or(100.0);
217
218 let bands_pos = args.iter().position(|arg| arg == "--bands");
219 let num_bands = bands_pos
220 .and_then(|pos| args.get(pos + 1))
221 .and_then(|s| s.parse::<usize>().ok())
222 .unwrap_or(5);
223
224 let labels_pos = args.iter().position(|arg| arg == "--labels");
225 let labels = labels_pos.and_then(|pos| args.get(pos + 1)).map(|l| {
226 l.split(',')
227 .map(|s| s.trim().to_string())
228 .collect::<Vec<_>>()
229 });
230
231 match column {
232 Some(col) => {
233 let width = (max_val - min_val) / num_bands as f64;
235 let mut bands_spec = Vec::new();
236
237 for i in 0..num_bands {
238 let start = min_val + (i as f64 * width);
239 let end = if i == num_bands - 1 {
240 max_val
241 } else {
242 min_val + ((i + 1) as f64 * width)
243 };
244
245 if i == num_bands - 1 {
246 bands_spec.push(format!("{:.0}+", start));
247 } else {
248 bands_spec.push(format!("{:.0}-{:.0}", start, end));
249 }
250 }
251
252 let bands_str = bands_spec.join(",");
253 let generator = CaseGenerator::from_ranges(col, &bands_str, labels).unwrap();
254 println!("{}", generator.to_sql());
255
256 Ok(())
257 }
258 _ => {
259 eprintln!("Error: --generate-case-range requires --column <name>");
260 eprintln!("Example: --generate-case-range --column value --min 0 --max 100 --bands 5");
261 std::process::exit(1);
262 }
263 }
264}
265
/// Builds a SQL `CASE` expression that labels `column` according to `bands_spec`.
///
/// `bands_spec` is a comma-separated list of bands, each in one of two forms:
/// - `min-max`: emitted as `column BETWEEN min AND max` (inclusive). The first
///   listed band is emitted as `column <= max` instead, so values below its
///   lower bound still receive a label. For that first band a bare `-max` is
///   also accepted as shorthand for `column <= max`.
/// - `min+`: emitted as `column >= min`.
///
/// Bounds may carry a leading minus sign (e.g. `-20--1`). Each band's trimmed
/// text is used as its label, and the expression is aliased `<column>_band`.
/// Entries matching neither form, or whose bounds are empty or malformed, are
/// skipped without emitting anything.
///
/// Because `BETWEEN` is inclusive, integer-style specs such as `0-24,25-49`
/// leave gaps for fractional values: 24.5 matches no band, and the `CASE`
/// yields NULL for it.
fn generate_banding_case(column: &str, bands_spec: &str) -> String {
    let mut sql = String::from("CASE");

    for (i, band) in bands_spec.split(',').map(str::trim).enumerate() {
        let condition = if band.ends_with('+') {
            band_bound(band.trim_end_matches('+')).map(|min| format!("{} >= {}", column, min))
        } else if i == 0 {
            // The first band is open-ended below, so only its upper bound is used.
            split_band_range(band)
                .map(|(_, max)| max)
                .or_else(|| band.strip_prefix('-').and_then(band_bound))
                .map(|max| format!("{} <= {}", column, max))
        } else {
            split_band_range(band)
                .map(|(min, max)| format!("{} BETWEEN {} AND {}", column, min, max))
        };

        // Only emit a line for recognised bands, so invalid entries leave no blank lines.
        if let Some(condition) = condition {
            sql.push_str(&format!("\n WHEN {} THEN '{}'", condition, band));
        }
    }

    sql.push_str(&format!("\nEND AS {}_band", column));
    sql
}

/// Splits a `min-max` band into its two trimmed bounds.
///
/// The separator is the first `-` after the band's first character, so a
/// leading minus on `min` is kept (`-10-5` -> `("-10", "5")`,
/// `-10--5` -> `("-10", "-5")`). Returns `None` when there is no separator or
/// either bound is rejected by `band_bound`.
fn split_band_range(band: &str) -> Option<(&str, &str)> {
    let sep = band
        .char_indices()
        .skip(1)
        .find(|&(_, c)| c == '-')
        .map(|(idx, _)| idx)?;
    Some((band_bound(&band[..sep])?, band_bound(&band[sep + 1..])?))
}

/// Trims a band bound and accepts it only if it is non-empty and contains no
/// `-` other than an optional leading minus sign.
fn band_bound(bound: &str) -> Option<&str> {
    let bound = bound.trim();
    let magnitude = bound.strip_prefix('-').unwrap_or(bound);
    if magnitude.is_empty() || magnitude.contains('-') {
        None
    } else {
        Some(bound)
    }
}