sql_cli/cli/refactoring.rs

1use std::io;
2
3use crate::data::csv_datasource;
4use crate::refactoring::banding::CaseGenerator;
5
6/// Handle --generate-bands flag to generate banding CASE statements
7pub fn handle_banding_generation(args: &[String]) -> io::Result<()> {
8    // Parse column and bands arguments
9    let column_pos = args.iter().position(|arg| arg == "--column");
10    let bands_pos = args.iter().position(|arg| arg == "--bands");
11
12    let column = column_pos.and_then(|pos| args.get(pos + 1));
13    let bands = bands_pos.and_then(|pos| args.get(pos + 1));
14
15    match (column, bands) {
16        (Some(col), Some(bands_spec)) => {
17            // Generate the CASE statement
18            let case_sql = generate_banding_case(col, bands_spec);
19            println!("{}", case_sql);
20            Ok(())
21        }
22        _ => {
23            eprintln!("Error: --generate-bands requires --column <name> and --bands <spec>");
24            eprintln!("Example: --generate-bands --column age --bands \"0-24,25-49,50-74,75+\"");
25            std::process::exit(1);
26        }
27    }
28}
29
30/// Handle --generate-case flag to generate CASE statements from data analysis
31pub fn handle_case_generation(args: &[String]) -> io::Result<()> {
32    // Find file argument position
33    let case_pos = args
34        .iter()
35        .position(|arg| arg == "--generate-case")
36        .unwrap();
37    let file_path = args.get(case_pos + 1);
38
39    // Parse other arguments
40    let column_pos = args.iter().position(|arg| arg == "--column");
41    let column = column_pos.and_then(|pos| args.get(pos + 1));
42
43    let style_pos = args.iter().position(|arg| arg == "--style");
44    let style = style_pos
45        .and_then(|pos| args.get(pos + 1))
46        .map(|s| s.as_str())
47        .unwrap_or("values");
48
49    let labels_pos = args.iter().position(|arg| arg == "--labels");
50    let labels = labels_pos.and_then(|pos| args.get(pos + 1)).map(|l| {
51        l.split(',')
52            .map(|s| s.trim().to_string())
53            .collect::<Vec<_>>()
54    });
55
56    match (file_path, column) {
57        (Some(path), Some(col)) => {
58            // Load the data file
59            let datasource = match csv_datasource::CsvDataSource::load_from_file(path, "data") {
60                Ok(ds) => ds,
61                Err(e) => {
62                    eprintln!("Error loading file {}: {}", path, e);
63                    std::process::exit(1);
64                }
65            };
66
67            let datatable = datasource.to_datatable();
68
69            // Find the column
70            let col_index = datatable
71                .columns
72                .iter()
73                .position(|c| c.name.eq_ignore_ascii_case(col))
74                .ok_or_else(|| {
75                    io::Error::new(
76                        io::ErrorKind::NotFound,
77                        format!("Column '{}' not found", col),
78                    )
79                })?;
80
81            // Get distinct values or analyze range
82            match style {
83                "values" => {
84                    // Get distinct values
85                    let mut distinct_values = std::collections::BTreeSet::new();
86                    for row in &datatable.rows {
87                        if let Some(value) = row.get(col_index) {
88                            if !value.is_null() {
89                                distinct_values.insert(value.to_string());
90                            }
91                        }
92                    }
93
94                    // Generate value mappings
95                    let value_mappings: Vec<(String, String)> = if let Some(ref labels) = labels {
96                        distinct_values
97                            .into_iter()
98                            .zip(labels.iter())
99                            .map(|(v, l)| (v, l.clone()))
100                            .collect()
101                    } else {
102                        distinct_values
103                            .into_iter()
104                            .map(|v| {
105                                let label = v.replace('_', " ").replace('-', " ");
106                                let label = label
107                                    .split_whitespace()
108                                    .map(|word| {
109                                        let mut chars = word.chars();
110                                        match chars.next() {
111                                            None => String::new(),
112                                            Some(first) => {
113                                                first.to_uppercase().collect::<String>()
114                                                    + chars.as_str()
115                                            }
116                                        }
117                                    })
118                                    .collect::<Vec<_>>()
119                                    .join(" ");
120                                (v, label)
121                            })
122                            .collect()
123                    };
124
125                    let generator = CaseGenerator::from_values(col, value_mappings);
126                    println!("{}", generator.to_sql());
127                }
128                "ranges" => {
129                    // Analyze numeric range
130                    let mut min_val = f64::MAX;
131                    let mut max_val = f64::MIN;
132                    let mut count = 0;
133
134                    for row in &datatable.rows {
135                        if let Some(value) = row.get(col_index) {
136                            if let Ok(num) = value.to_string().parse::<f64>() {
137                                min_val = min_val.min(num);
138                                max_val = max_val.max(num);
139                                count += 1;
140                            }
141                        }
142                    }
143
144                    if count == 0 {
145                        eprintln!("No numeric values found in column '{}'", col);
146                        std::process::exit(1);
147                    }
148
149                    // Generate smart ranges based on data distribution
150                    let range = max_val - min_val;
151                    let bands_spec = if range <= 100.0 {
152                        // Small range - use 10-unit bands
153                        let num_bands = ((range / 10.0).ceil() as usize).min(10);
154                        let mut bands = Vec::new();
155                        for i in 0..num_bands {
156                            let start = min_val + (i as f64 * 10.0);
157                            let end = (min_val + ((i + 1) as f64 * 10.0)).min(max_val);
158                            if i == num_bands - 1 {
159                                bands.push(format!("{:.0}+", start));
160                            } else {
161                                bands.push(format!("{:.0}-{:.0}", start, end));
162                            }
163                        }
164                        bands.join(",")
165                    } else {
166                        // Large range - use quartiles or quintiles
167                        let step = range / 5.0;
168                        let mut bands = Vec::new();
169                        for i in 0..5 {
170                            let start = min_val + (i as f64 * step);
171                            let end = min_val + ((i + 1) as f64 * step);
172                            if i == 4 {
173                                bands.push(format!("{:.0}+", start));
174                            } else {
175                                bands.push(format!("{:.0}-{:.0}", start, end));
176                            }
177                        }
178                        bands.join(",")
179                    };
180
181                    let generator = CaseGenerator::from_ranges(col, &bands_spec, labels).unwrap();
182                    println!("{}", generator.to_sql());
183                }
184                _ => {
185                    eprintln!("Unknown style: {}. Use 'values' or 'ranges'", style);
186                    std::process::exit(1);
187                }
188            }
189
190            Ok(())
191        }
192        _ => {
193            eprintln!("Error: --generate-case requires a file path and --column <name>");
194            eprintln!("Example: --generate-case data.csv --column ocean_proximity --style values");
195            std::process::exit(1);
196        }
197    }
198}
199
200/// Handle --generate-case-range flag to generate CASE statements for numeric ranges
201pub fn handle_case_range_generation(args: &[String]) -> io::Result<()> {
202    // Parse arguments
203    let column_pos = args.iter().position(|arg| arg == "--column");
204    let column = column_pos.and_then(|pos| args.get(pos + 1));
205
206    let min_pos = args.iter().position(|arg| arg == "--min");
207    let min_val = min_pos
208        .and_then(|pos| args.get(pos + 1))
209        .and_then(|s| s.parse::<f64>().ok())
210        .unwrap_or(0.0);
211
212    let max_pos = args.iter().position(|arg| arg == "--max");
213    let max_val = max_pos
214        .and_then(|pos| args.get(pos + 1))
215        .and_then(|s| s.parse::<f64>().ok())
216        .unwrap_or(100.0);
217
218    let bands_pos = args.iter().position(|arg| arg == "--bands");
219    let num_bands = bands_pos
220        .and_then(|pos| args.get(pos + 1))
221        .and_then(|s| s.parse::<usize>().ok())
222        .unwrap_or(5);
223
224    let labels_pos = args.iter().position(|arg| arg == "--labels");
225    let labels = labels_pos.and_then(|pos| args.get(pos + 1)).map(|l| {
226        l.split(',')
227            .map(|s| s.trim().to_string())
228            .collect::<Vec<_>>()
229    });
230
231    match column {
232        Some(col) => {
233            // Generate equal-width bands
234            let width = (max_val - min_val) / num_bands as f64;
235            let mut bands_spec = Vec::new();
236
237            for i in 0..num_bands {
238                let start = min_val + (i as f64 * width);
239                let end = if i == num_bands - 1 {
240                    max_val
241                } else {
242                    min_val + ((i + 1) as f64 * width)
243                };
244
245                if i == num_bands - 1 {
246                    bands_spec.push(format!("{:.0}+", start));
247                } else {
248                    bands_spec.push(format!("{:.0}-{:.0}", start, end));
249                }
250            }
251
252            let bands_str = bands_spec.join(",");
253            let generator = CaseGenerator::from_ranges(col, &bands_str, labels).unwrap();
254            println!("{}", generator.to_sql());
255
256            Ok(())
257        }
258        _ => {
259            eprintln!("Error: --generate-case-range requires --column <name>");
260            eprintln!("Example: --generate-case-range --column value --min 0 --max 100 --bands 5");
261            std::process::exit(1);
262        }
263    }
264}
265
/// Generate a SQL `CASE` expression that labels `column` with the band its value falls in.
///
/// `bands_spec` is a comma-separated list. Each entry is either:
/// * `lo-hi` — an inclusive range; negative bounds such as `-20--11` are allowed.
/// * `lo+` — everything at or above `lo`.
///
/// If the first entry of the spec is a range, it is emitted as `column <= hi`, so values
/// below its lower bound are still captured. Later ranges use `BETWEEN lo AND hi`. Each
/// entry's own text becomes its label, with single quotes doubled for SQL. Entries that
/// match neither form (or have an empty bound) are skipped. The expression is aliased
/// as `<column>_band`.
fn generate_banding_case(column: &str, bands_spec: &str) -> String {
    let mut sql = String::from("CASE");

    for (i, band) in bands_spec.split(',').map(|s| s.trim()).enumerate() {
        // Labels are emitted as SQL string literals, so embedded quotes must be doubled.
        let label = band.replace('\'', "''");

        let clause = if band.ends_with('+') {
            // "75+" format: everything at or above the minimum.
            let min = band.trim_end_matches('+').trim();
            if min.is_empty() {
                None
            } else {
                Some(format!("WHEN {} >= {} THEN '{}'", column, min, label))
            }
        } else if let Some((min, max)) = parse_band_range(band) {
            if i == 0 {
                // First band: from the start up to its max.
                Some(format!("WHEN {} <= {} THEN '{}'", column, max, label))
            } else {
                Some(format!(
                    "WHEN {} BETWEEN {} AND {} THEN '{}'",
                    column, min, max, label
                ))
            }
        } else {
            None
        };

        // Only emit a line for recognised bands, so malformed entries leave no blank lines.
        if let Some(clause) = clause {
            sql.push_str("\n    ");
            sql.push_str(&clause);
        }
    }

    sql.push_str(&format!("\nEND AS {}_band", column));
    sql
}

/// Split a `lo-hi` band into trimmed `(lo, hi)` bounds, allowing a leading minus sign on
/// either bound. Returns `None` when there is no separator, a bound is empty, or extra
/// dashes remain (e.g. `1-2-3`).
fn parse_band_range(band: &str) -> Option<(&str, &str)> {
    // Search from the second character so a leading minus on `lo` is not mistaken for
    // the separator. `char_indices` keeps the slice indices on char boundaries.
    let sep = band
        .char_indices()
        .skip(1)
        .find(|&(_, c)| c == '-')
        .map(|(idx, _)| idx)?;
    let min = band[..sep].trim();
    let max = band[sep + 1..].trim();
    let max_digits = max.strip_prefix('-').unwrap_or(max);
    if min.is_empty() || min == "-" || max_digits.is_empty() || max_digits.contains('-') {
        return None;
    }
    Some((min, max))
}