sql_cli/benchmarks/
runner.rs

1use crate::benchmarks::data_generator::BenchmarkDataGenerator;
2use crate::benchmarks::metrics::{BenchmarkResult, MetricsCollector};
3use crate::benchmarks::query_suite::{BenchmarkQuery, QuerySuite};
4use crate::data::datatable::DataTable;
5use crate::data::query_engine::QueryEngine;
6use crate::sql::recursive_parser::Parser;
7use std::collections::HashMap;
8use std::fs::File;
9use std::io::Write;
10use std::sync::Arc;
11
/// Drives benchmark execution: generates test data, runs query suites,
/// and accumulates per-query timing results for reporting/export.
pub struct BenchmarkRunner {
    // Every benchmark outcome recorded so far (successes and failures alike).
    results: Vec<BenchmarkResult>,
    // Generated tables keyed by "<table_name>_<row_count>", e.g. "mixed_1000".
    tables: HashMap<String, DataTable>,
}
16
17impl BenchmarkRunner {
18    pub fn new() -> Self {
19        BenchmarkRunner {
20            results: Vec::new(),
21            tables: HashMap::new(),
22        }
23    }
24
25    pub fn prepare_benchmark_data(&mut self, sizes: &[usize]) {
26        println!("=== Preparing Benchmark Data ===");
27        for &size in sizes {
28            println!("Generating tables with {} rows...", size);
29            let tables = BenchmarkDataGenerator::generate_all_benchmark_tables(size);
30
31            for (name, table) in tables {
32                let key = format!("{}_{}", name, size);
33                println!(
34                    "  - Generated {} table: {} rows, {} columns",
35                    key,
36                    table.rows.len(),
37                    table.columns.len()
38                );
39                self.tables.insert(key, table);
40            }
41        }
42        println!("Total tables generated: {}\n", self.tables.len());
43    }
44
45    pub fn run_query_benchmark(
46        &mut self,
47        query: &BenchmarkQuery,
48        table: &DataTable,
49        table_name: &str,
50        row_count: usize,
51    ) -> BenchmarkResult {
52        let mut result = BenchmarkResult::new(
53            query.name.clone(),
54            query.category.as_str().to_string(),
55            table_name.to_string(),
56            row_count,
57        );
58
59        let mut collector = MetricsCollector::new();
60        collector.start_total();
61
62        collector.start_phase();
63        let parse_result = Parser::new(&query.sql).parse();
64        collector.end_parse_phase();
65
66        match parse_result {
67            Ok(statement) => {
68                collector.start_phase();
69
70                let engine = QueryEngine::new();
71                let table_arc = Arc::new(table.clone());
72
73                collector.start_phase();
74                match engine.execute_statement(table_arc, statement) {
75                    Ok(result_view) => {
76                        collector.end_execute_phase();
77                        collector.set_rows(table.rows.len(), result_view.row_count());
78                    }
79                    Err(e) => {
80                        result.error = Some(format!("Execution error: {}", e));
81                    }
82                }
83            }
84            Err(e) => {
85                result.error = Some(format!("Parse error: {}", e));
86            }
87        }
88
89        collector.end_total();
90        result.metrics = collector.get_metrics();
91        result
92    }
93
94    pub fn run_progressive_benchmarks(&mut self, increment: usize, max_rows: usize) {
95        println!("=== Running Progressive Benchmarks ===");
96        println!("Increment: {} rows, Max: {} rows\n", increment, max_rows);
97
98        let sizes: Vec<usize> = (1..=max_rows / increment).map(|i| i * increment).collect();
99
100        self.prepare_benchmark_data(&sizes);
101
102        let queries = QuerySuite::get_progressive_queries();
103
104        for size in &sizes {
105            if let Some(query_sqls) = queries.get(size) {
106                println!("Running benchmarks for {} rows:", size);
107
108                let table_key = format!("mixed_{}", size);
109                if let Some(table) = self.tables.get(&table_key).cloned() {
110                    for (i, sql) in query_sqls.iter().enumerate() {
111                        let query = BenchmarkQuery::new(
112                            format!("prog_query_{}", i),
113                            crate::benchmarks::query_suite::QueryCategory::BasicOperations,
114                            sql.clone(),
115                            format!("Progressive query {}", i),
116                            "mixed",
117                        );
118
119                        let result = self.run_query_benchmark(&query, &table, &table_key, *size);
120
121                        println!(
122                            "  - {}: {:.2}ms, {} rows/sec",
123                            query.name,
124                            result.metrics.total_time.as_secs_f64() * 1000.0,
125                            result.metrics.rows_per_second as u64
126                        );
127
128                        self.results.push(result);
129                    }
130                }
131                println!();
132            }
133        }
134    }
135
136    pub fn run_comprehensive_benchmarks(&mut self, sizes: &[usize]) {
137        println!("=== Running Comprehensive Benchmarks ===");
138        self.prepare_benchmark_data(sizes);
139
140        let all_queries = QuerySuite::get_all_queries();
141        let total_benchmarks = sizes.len() * all_queries.len();
142        let mut completed = 0;
143
144        for &size in sizes {
145            println!("\n--- Testing with {} rows ---", size);
146
147            for query in &all_queries {
148                let table_key = format!("{}_{}", query.table_type, size);
149
150                let table = if query.table_type == "all" {
151                    self.tables.get(&format!("mixed_{}", size)).cloned()
152                } else {
153                    self.tables.get(&table_key).cloned()
154                };
155
156                if let Some(table) = table {
157                    let result = self.run_query_benchmark(&query, &table, &query.table_type, size);
158
159                    completed += 1;
160                    let progress = (completed as f64 / total_benchmarks as f64) * 100.0;
161
162                    if result.error.is_none() {
163                        println!(
164                            "[{:.1}%] {} ({}): {:.2}ms",
165                            progress,
166                            query.name,
167                            query.category.as_str(),
168                            result.metrics.total_time.as_secs_f64() * 1000.0
169                        );
170                    } else {
171                        println!(
172                            "[{:.1}%] {} ({}): ERROR - {}",
173                            progress,
174                            query.name,
175                            query.category.as_str(),
176                            result.error.as_ref().unwrap()
177                        );
178                    }
179
180                    self.results.push(result);
181                }
182            }
183        }
184    }
185
186    pub fn run_category_benchmarks(
187        &mut self,
188        category: crate::benchmarks::query_suite::QueryCategory,
189        sizes: &[usize],
190    ) {
191        println!("=== Running {} Benchmarks ===", category.as_str());
192        self.prepare_benchmark_data(sizes);
193
194        let queries: Vec<_> = QuerySuite::get_all_queries()
195            .into_iter()
196            .filter(|q| q.category == category)
197            .collect();
198
199        for &size in sizes {
200            println!(
201                "\nTesting {} queries with {} rows:",
202                category.as_str(),
203                size
204            );
205
206            for query in &queries {
207                let table_key = format!("{}_{}", query.table_type, size);
208                let table = self
209                    .tables
210                    .get(&table_key)
211                    .cloned()
212                    .or_else(|| self.tables.get(&format!("mixed_{}", size)).cloned());
213
214                if let Some(table) = table {
215                    let result = self.run_query_benchmark(&query, &table, &query.table_type, size);
216
217                    println!(
218                        "  {}: {:.2}ms ({})",
219                        query.name,
220                        result.metrics.total_time.as_secs_f64() * 1000.0,
221                        if result.error.is_none() {
222                            "OK"
223                        } else {
224                            "FAILED"
225                        }
226                    );
227
228                    self.results.push(result);
229                }
230            }
231        }
232    }
233
234    pub fn generate_report(&self) -> String {
235        let mut report = String::new();
236
237        report.push_str("# SQL CLI Benchmark Report\n\n");
238        report.push_str(&format!("Generated: {}\n", chrono::Local::now()));
239        report.push_str(&format!("Total benchmarks run: {}\n\n", self.results.len()));
240
241        report.push_str("## Summary Statistics\n\n");
242
243        let successful: Vec<_> = self.results.iter().filter(|r| r.error.is_none()).collect();
244
245        if !successful.is_empty() {
246            let avg_time: f64 = successful
247                .iter()
248                .map(|r| r.metrics.total_time.as_secs_f64())
249                .sum::<f64>()
250                / successful.len() as f64;
251
252            let avg_throughput: f64 = successful
253                .iter()
254                .map(|r| r.metrics.rows_per_second)
255                .sum::<f64>()
256                / successful.len() as f64;
257
258            report.push_str(&format!("- Successful: {}\n", successful.len()));
259            report.push_str(&format!(
260                "- Failed: {}\n",
261                self.results.len() - successful.len()
262            ));
263            report.push_str(&format!(
264                "- Average execution time: {:.2}ms\n",
265                avg_time * 1000.0
266            ));
267            report.push_str(&format!(
268                "- Average throughput: {:.0} rows/sec\n\n",
269                avg_throughput
270            ));
271        }
272
273        report.push_str("## Results by Category\n\n");
274
275        let categories = vec!["basic", "aggregation", "sorting", "window", "complex"];
276
277        for category in categories {
278            let category_results: Vec<_> = self
279                .results
280                .iter()
281                .filter(|r| r.query_category == category && r.error.is_none())
282                .collect();
283
284            if !category_results.is_empty() {
285                report.push_str(&format!(
286                    "### {} Operations\n\n",
287                    category.chars().next().unwrap().to_uppercase().to_string() + &category[1..]
288                ));
289
290                for result in category_results {
291                    report.push_str(&format!(
292                        "- {} ({} rows): {:.2}ms, {:.0} rows/sec\n",
293                        result.query_name,
294                        result.row_count,
295                        result.metrics.total_time.as_secs_f64() * 1000.0,
296                        result.metrics.rows_per_second
297                    ));
298                }
299                report.push_str("\n");
300            }
301        }
302
303        report.push_str("## Performance by Data Size\n\n");
304
305        let mut size_map: HashMap<usize, Vec<&BenchmarkResult>> = HashMap::new();
306        for result in &self.results {
307            if result.error.is_none() {
308                size_map
309                    .entry(result.row_count)
310                    .or_insert_with(Vec::new)
311                    .push(result);
312            }
313        }
314
315        let mut sizes: Vec<_> = size_map.keys().cloned().collect();
316        sizes.sort();
317
318        for size in sizes {
319            if let Some(results) = size_map.get(&size) {
320                let avg_time: f64 = results
321                    .iter()
322                    .map(|r| r.metrics.total_time.as_secs_f64())
323                    .sum::<f64>()
324                    / results.len() as f64;
325
326                report.push_str(&format!(
327                    "- {} rows: avg {:.2}ms ({} queries)\n",
328                    size,
329                    avg_time * 1000.0,
330                    results.len()
331                ));
332            }
333        }
334
335        report
336    }
337
338    pub fn save_results_csv(&self, filename: &str) -> Result<(), String> {
339        let mut file =
340            File::create(filename).map_err(|e| format!("Failed to create CSV file: {}", e))?;
341
342        writeln!(file, "query_name,table,row_count,parse_ms,plan_ms,execute_ms,total_ms,rows_processed,rows_returned,rows_per_sec,status")
343            .map_err(|e| format!("Failed to write CSV header: {}", e))?;
344
345        for result in &self.results {
346            writeln!(file, "{}", result.to_csv_row())
347                .map_err(|e| format!("Failed to write CSV row: {}", e))?;
348        }
349
350        Ok(())
351    }
352
353    pub fn print_summary(&self) {
354        println!("\n=== Benchmark Summary ===");
355
356        let successful = self.results.iter().filter(|r| r.error.is_none()).count();
357        let failed = self.results.len() - successful;
358
359        println!(
360            "Total: {}, Successful: {}, Failed: {}",
361            self.results.len(),
362            successful,
363            failed
364        );
365
366        if successful > 0 {
367            let total_time: f64 = self
368                .results
369                .iter()
370                .filter(|r| r.error.is_none())
371                .map(|r| r.metrics.total_time.as_secs_f64())
372                .sum();
373
374            println!("Total benchmark time: {:.2}s", total_time);
375
376            let fastest = self
377                .results
378                .iter()
379                .filter(|r| r.error.is_none())
380                .min_by_key(|r| r.metrics.total_time)
381                .unwrap();
382
383            let slowest = self
384                .results
385                .iter()
386                .filter(|r| r.error.is_none())
387                .max_by_key(|r| r.metrics.total_time)
388                .unwrap();
389
390            println!(
391                "\nFastest query: {} ({:.2}ms)",
392                fastest.query_name,
393                fastest.metrics.total_time.as_secs_f64() * 1000.0
394            );
395
396            println!(
397                "Slowest query: {} ({:.2}ms)",
398                slowest.query_name,
399                slowest.metrics.total_time.as_secs_f64() * 1000.0
400            );
401        }
402
403        if failed > 0 {
404            println!("\nFailed queries:");
405            for result in self.results.iter().filter(|r| r.error.is_some()) {
406                println!(
407                    "  - {}: {}",
408                    result.query_name,
409                    result.error.as_ref().unwrap()
410                );
411            }
412        }
413    }
414}