use crate::benchmarks::data_generator::BenchmarkDataGenerator;
use crate::benchmarks::metrics::{BenchmarkResult, MetricsCollector};
use crate::benchmarks::query_suite::{BenchmarkQuery, QuerySuite};
use crate::data::datatable::DataTable;
use crate::data::query_engine::QueryEngine;
use crate::sql::recursive_parser::Parser;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::sync::Arc;
11
/// Orchestrates benchmark runs: generates test tables, executes queries
/// from the benchmark suite against them, and accumulates per-query results.
pub struct BenchmarkRunner {
    // One entry per executed benchmark query, including failed ones
    // (failures carry an error message in `BenchmarkResult::error`).
    results: Vec<BenchmarkResult>,
    // Generated tables keyed by "{table_type}_{row_count}", e.g. "mixed_1000".
    tables: HashMap<String, DataTable>,
}
16
17impl BenchmarkRunner {
18 pub fn new() -> Self {
19 BenchmarkRunner {
20 results: Vec::new(),
21 tables: HashMap::new(),
22 }
23 }
24
25 pub fn prepare_benchmark_data(&mut self, sizes: &[usize]) {
26 println!("=== Preparing Benchmark Data ===");
27 for &size in sizes {
28 println!("Generating tables with {} rows...", size);
29 let tables = BenchmarkDataGenerator::generate_all_benchmark_tables(size);
30
31 for (name, table) in tables {
32 let key = format!("{}_{}", name, size);
33 println!(
34 " - Generated {} table: {} rows, {} columns",
35 key,
36 table.rows.len(),
37 table.columns.len()
38 );
39 self.tables.insert(key, table);
40 }
41 }
42 println!("Total tables generated: {}\n", self.tables.len());
43 }
44
45 pub fn run_query_benchmark(
46 &mut self,
47 query: &BenchmarkQuery,
48 table: &DataTable,
49 table_name: &str,
50 row_count: usize,
51 ) -> BenchmarkResult {
52 let mut result = BenchmarkResult::new(
53 query.name.clone(),
54 query.category.as_str().to_string(),
55 table_name.to_string(),
56 row_count,
57 );
58
59 let mut collector = MetricsCollector::new();
60 collector.start_total();
61
62 collector.start_phase();
63 let parse_result = Parser::new(&query.sql).parse();
64 collector.end_parse_phase();
65
66 match parse_result {
67 Ok(statement) => {
68 collector.start_phase();
69
70 let engine = QueryEngine::new();
71 let table_arc = Arc::new(table.clone());
72
73 collector.start_phase();
74 match engine.execute_statement(table_arc, statement) {
75 Ok(result_view) => {
76 collector.end_execute_phase();
77 collector.set_rows(table.rows.len(), result_view.row_count());
78 }
79 Err(e) => {
80 result.error = Some(format!("Execution error: {}", e));
81 }
82 }
83 }
84 Err(e) => {
85 result.error = Some(format!("Parse error: {}", e));
86 }
87 }
88
89 collector.end_total();
90 result.metrics = collector.get_metrics();
91 result
92 }
93
94 pub fn run_progressive_benchmarks(&mut self, increment: usize, max_rows: usize) {
95 println!("=== Running Progressive Benchmarks ===");
96 println!("Increment: {} rows, Max: {} rows\n", increment, max_rows);
97
98 let sizes: Vec<usize> = (1..=max_rows / increment).map(|i| i * increment).collect();
99
100 self.prepare_benchmark_data(&sizes);
101
102 let queries = QuerySuite::get_progressive_queries();
103
104 for size in &sizes {
105 if let Some(query_sqls) = queries.get(size) {
106 println!("Running benchmarks for {} rows:", size);
107
108 let table_key = format!("mixed_{}", size);
109 if let Some(table) = self.tables.get(&table_key).cloned() {
110 for (i, sql) in query_sqls.iter().enumerate() {
111 let query = BenchmarkQuery::new(
112 format!("prog_query_{}", i),
113 crate::benchmarks::query_suite::QueryCategory::BasicOperations,
114 sql.clone(),
115 format!("Progressive query {}", i),
116 "mixed",
117 );
118
119 let result = self.run_query_benchmark(&query, &table, &table_key, *size);
120
121 println!(
122 " - {}: {:.2}ms, {} rows/sec",
123 query.name,
124 result.metrics.total_time.as_secs_f64() * 1000.0,
125 result.metrics.rows_per_second as u64
126 );
127
128 self.results.push(result);
129 }
130 }
131 println!();
132 }
133 }
134 }
135
136 pub fn run_comprehensive_benchmarks(&mut self, sizes: &[usize]) {
137 println!("=== Running Comprehensive Benchmarks ===");
138 self.prepare_benchmark_data(sizes);
139
140 let all_queries = QuerySuite::get_all_queries();
141 let total_benchmarks = sizes.len() * all_queries.len();
142 let mut completed = 0;
143
144 for &size in sizes {
145 println!("\n--- Testing with {} rows ---", size);
146
147 for query in &all_queries {
148 let table_key = format!("{}_{}", query.table_type, size);
149
150 let table = if query.table_type == "all" {
151 self.tables.get(&format!("mixed_{}", size)).cloned()
152 } else {
153 self.tables.get(&table_key).cloned()
154 };
155
156 if let Some(table) = table {
157 let result = self.run_query_benchmark(&query, &table, &query.table_type, size);
158
159 completed += 1;
160 let progress = (completed as f64 / total_benchmarks as f64) * 100.0;
161
162 if result.error.is_none() {
163 println!(
164 "[{:.1}%] {} ({}): {:.2}ms",
165 progress,
166 query.name,
167 query.category.as_str(),
168 result.metrics.total_time.as_secs_f64() * 1000.0
169 );
170 } else {
171 println!(
172 "[{:.1}%] {} ({}): ERROR - {}",
173 progress,
174 query.name,
175 query.category.as_str(),
176 result.error.as_ref().unwrap()
177 );
178 }
179
180 self.results.push(result);
181 }
182 }
183 }
184 }
185
186 pub fn run_category_benchmarks(
187 &mut self,
188 category: crate::benchmarks::query_suite::QueryCategory,
189 sizes: &[usize],
190 ) {
191 println!("=== Running {} Benchmarks ===", category.as_str());
192 self.prepare_benchmark_data(sizes);
193
194 let queries: Vec<_> = QuerySuite::get_all_queries()
195 .into_iter()
196 .filter(|q| q.category == category)
197 .collect();
198
199 for &size in sizes {
200 println!(
201 "\nTesting {} queries with {} rows:",
202 category.as_str(),
203 size
204 );
205
206 for query in &queries {
207 let table_key = format!("{}_{}", query.table_type, size);
208 let table = self
209 .tables
210 .get(&table_key)
211 .cloned()
212 .or_else(|| self.tables.get(&format!("mixed_{}", size)).cloned());
213
214 if let Some(table) = table {
215 let result = self.run_query_benchmark(&query, &table, &query.table_type, size);
216
217 println!(
218 " {}: {:.2}ms ({})",
219 query.name,
220 result.metrics.total_time.as_secs_f64() * 1000.0,
221 if result.error.is_none() {
222 "OK"
223 } else {
224 "FAILED"
225 }
226 );
227
228 self.results.push(result);
229 }
230 }
231 }
232 }
233
234 pub fn generate_report(&self) -> String {
235 let mut report = String::new();
236
237 report.push_str("# SQL CLI Benchmark Report\n\n");
238 report.push_str(&format!("Generated: {}\n", chrono::Local::now()));
239 report.push_str(&format!("Total benchmarks run: {}\n\n", self.results.len()));
240
241 report.push_str("## Summary Statistics\n\n");
242
243 let successful: Vec<_> = self.results.iter().filter(|r| r.error.is_none()).collect();
244
245 if !successful.is_empty() {
246 let avg_time: f64 = successful
247 .iter()
248 .map(|r| r.metrics.total_time.as_secs_f64())
249 .sum::<f64>()
250 / successful.len() as f64;
251
252 let avg_throughput: f64 = successful
253 .iter()
254 .map(|r| r.metrics.rows_per_second)
255 .sum::<f64>()
256 / successful.len() as f64;
257
258 report.push_str(&format!("- Successful: {}\n", successful.len()));
259 report.push_str(&format!(
260 "- Failed: {}\n",
261 self.results.len() - successful.len()
262 ));
263 report.push_str(&format!(
264 "- Average execution time: {:.2}ms\n",
265 avg_time * 1000.0
266 ));
267 report.push_str(&format!(
268 "- Average throughput: {:.0} rows/sec\n\n",
269 avg_throughput
270 ));
271 }
272
273 report.push_str("## Results by Category\n\n");
274
275 let categories = vec!["basic", "aggregation", "sorting", "window", "complex"];
276
277 for category in categories {
278 let category_results: Vec<_> = self
279 .results
280 .iter()
281 .filter(|r| r.query_category == category && r.error.is_none())
282 .collect();
283
284 if !category_results.is_empty() {
285 report.push_str(&format!(
286 "### {} Operations\n\n",
287 category.chars().next().unwrap().to_uppercase().to_string() + &category[1..]
288 ));
289
290 for result in category_results {
291 report.push_str(&format!(
292 "- {} ({} rows): {:.2}ms, {:.0} rows/sec\n",
293 result.query_name,
294 result.row_count,
295 result.metrics.total_time.as_secs_f64() * 1000.0,
296 result.metrics.rows_per_second
297 ));
298 }
299 report.push_str("\n");
300 }
301 }
302
303 report.push_str("## Performance by Data Size\n\n");
304
305 let mut size_map: HashMap<usize, Vec<&BenchmarkResult>> = HashMap::new();
306 for result in &self.results {
307 if result.error.is_none() {
308 size_map
309 .entry(result.row_count)
310 .or_insert_with(Vec::new)
311 .push(result);
312 }
313 }
314
315 let mut sizes: Vec<_> = size_map.keys().cloned().collect();
316 sizes.sort();
317
318 for size in sizes {
319 if let Some(results) = size_map.get(&size) {
320 let avg_time: f64 = results
321 .iter()
322 .map(|r| r.metrics.total_time.as_secs_f64())
323 .sum::<f64>()
324 / results.len() as f64;
325
326 report.push_str(&format!(
327 "- {} rows: avg {:.2}ms ({} queries)\n",
328 size,
329 avg_time * 1000.0,
330 results.len()
331 ));
332 }
333 }
334
335 report
336 }
337
338 pub fn save_results_csv(&self, filename: &str) -> Result<(), String> {
339 let mut file =
340 File::create(filename).map_err(|e| format!("Failed to create CSV file: {}", e))?;
341
342 writeln!(file, "query_name,table,row_count,parse_ms,plan_ms,execute_ms,total_ms,rows_processed,rows_returned,rows_per_sec,status")
343 .map_err(|e| format!("Failed to write CSV header: {}", e))?;
344
345 for result in &self.results {
346 writeln!(file, "{}", result.to_csv_row())
347 .map_err(|e| format!("Failed to write CSV row: {}", e))?;
348 }
349
350 Ok(())
351 }
352
353 pub fn print_summary(&self) {
354 println!("\n=== Benchmark Summary ===");
355
356 let successful = self.results.iter().filter(|r| r.error.is_none()).count();
357 let failed = self.results.len() - successful;
358
359 println!(
360 "Total: {}, Successful: {}, Failed: {}",
361 self.results.len(),
362 successful,
363 failed
364 );
365
366 if successful > 0 {
367 let total_time: f64 = self
368 .results
369 .iter()
370 .filter(|r| r.error.is_none())
371 .map(|r| r.metrics.total_time.as_secs_f64())
372 .sum();
373
374 println!("Total benchmark time: {:.2}s", total_time);
375
376 let fastest = self
377 .results
378 .iter()
379 .filter(|r| r.error.is_none())
380 .min_by_key(|r| r.metrics.total_time)
381 .unwrap();
382
383 let slowest = self
384 .results
385 .iter()
386 .filter(|r| r.error.is_none())
387 .max_by_key(|r| r.metrics.total_time)
388 .unwrap();
389
390 println!(
391 "\nFastest query: {} ({:.2}ms)",
392 fastest.query_name,
393 fastest.metrics.total_time.as_secs_f64() * 1000.0
394 );
395
396 println!(
397 "Slowest query: {} ({:.2}ms)",
398 slowest.query_name,
399 slowest.metrics.total_time.as_secs_f64() * 1000.0
400 );
401 }
402
403 if failed > 0 {
404 println!("\nFailed queries:");
405 for result in self.results.iter().filter(|r| r.error.is_some()) {
406 println!(
407 " - {}: {}",
408 result.query_name,
409 result.error.as_ref().unwrap()
410 );
411 }
412 }
413 }
414}