1use vibesql_storage::{
29 statistics::{CostEstimator, TableIndexInfo, TableStatistics},
30 Database,
31};
32
/// Tuning knobs for cost-based DML decisions. Values are heuristics used by
/// `DmlOptimizer`; they gate batch sizing, delete chunking, and compaction.
pub mod thresholds {
    /// Per-row INSERT cost above which the optimizer prefers small batches.
    pub const HIGH_COST_INSERT_THRESHOLD: f64 = 0.5;

    /// Batch size used for expensive (index-heavy) inserts.
    pub const SMALL_BATCH_SIZE: usize = 100;

    /// Batch size used for cheap inserts; also the default when no index
    /// metadata is available.
    pub const LARGE_BATCH_SIZE: usize = 1000;

    /// Fraction of logically-deleted rows above which a table is considered
    /// fragmented (triggers chunked deletes and early compaction).
    pub const HIGH_DELETED_RATIO_THRESHOLD: f64 = 0.4;

    /// Minimum number of rows in a DELETE before chunking is even considered.
    pub const CHUNK_DELETE_ROW_THRESHOLD: usize = 10000;

    /// Number of rows per chunk when a DELETE is executed in chunks.
    pub const DELETE_CHUNK_SIZE: usize = 1000;
}
55
/// Cost-based optimizer for DML statements (INSERT/UPDATE/DELETE) against a
/// single table. Snapshots index metadata and statistics at construction and
/// uses them to pick batch sizes, chunking strategies, and cost estimates.
pub struct DmlOptimizer<'a> {
    // Cost model used to price insert/update/delete operations.
    cost_estimator: CostEstimator,
    // Index metadata for the target table; None when the table is unknown.
    index_info: Option<TableIndexInfo>,
    // Stored statistics snapshot; None when statistics were never collected.
    table_stats: Option<TableStatistics>,
    // NOTE(review): `db` is read by fallback-stats and per-column index
    // probes below, so this allow may be stale — confirm and remove if so.
    #[allow(dead_code)]
    db: &'a Database,
    // Name of the table this optimizer was built for.
    table_name: &'a str,
}
72
73impl<'a> DmlOptimizer<'a> {
74 pub fn new(db: &'a Database, table_name: &'a str) -> Self {
83 let index_info = db.get_table_index_info(table_name);
84 let table_stats = db.get_table(table_name).and_then(|t| t.get_statistics().cloned());
85
86 Self { cost_estimator: CostEstimator::default(), index_info, table_stats, db, table_name }
87 }
88
89 pub fn get_stats_with_fallback(&self) -> TableStatistics {
94 if let Some(ref stats) = self.table_stats {
95 stats.clone()
96 } else {
97 self.create_fallback_stats()
99 }
100 }
101
102 fn create_fallback_stats(&self) -> TableStatistics {
104 let row_count = self.db.get_table(self.table_name).map(|t| t.row_count()).unwrap_or(0);
106
107 TableStatistics {
109 row_count,
110 columns: std::collections::HashMap::new(),
111 last_updated: instant::SystemTime::now(),
112 is_stale: true, sample_metadata: None,
114 avg_row_bytes: None, }
116 }
117
118 pub fn optimal_insert_batch_size(&self, total_rows: usize) -> usize {
130 let index_info = match &self.index_info {
132 Some(info) => info,
133 None => return thresholds::LARGE_BATCH_SIZE.min(total_rows),
134 };
135
136 let stats = self.get_stats_with_fallback();
138 let single_row_cost = self.cost_estimator.estimate_insert(1, &stats, index_info);
139
140 if std::env::var("DML_COST_DEBUG").is_ok() {
142 eprintln!(
143 "DML_COST_DEBUG: INSERT on {} - cost_per_row={:.3}, hash_indexes={}, btree_indexes={}",
144 self.table_name,
145 single_row_cost,
146 index_info.hash_index_count,
147 index_info.btree_index_count
148 );
149 }
150
151 if single_row_cost > thresholds::HIGH_COST_INSERT_THRESHOLD {
153 thresholds::SMALL_BATCH_SIZE.min(total_rows)
154 } else {
155 thresholds::LARGE_BATCH_SIZE.min(total_rows)
156 }
157 }
158
159 pub fn should_chunk_delete(&self, rows_to_delete: usize) -> bool {
171 if rows_to_delete < thresholds::CHUNK_DELETE_ROW_THRESHOLD {
173 return false;
174 }
175
176 let index_info = match &self.index_info {
177 Some(info) => info,
178 None => return false,
179 };
180
181 let stats = self.get_stats_with_fallback();
183 let delete_cost = self.cost_estimator.estimate_delete(rows_to_delete, &stats, index_info);
184
185 if std::env::var("DML_COST_DEBUG").is_ok() {
187 eprintln!(
188 "DML_COST_DEBUG: DELETE on {} - rows={}, cost={:.3}, deleted_ratio={:.2}",
189 self.table_name, rows_to_delete, delete_cost, index_info.deleted_ratio
190 );
191 }
192
193 let high_deleted_ratio =
197 index_info.deleted_ratio > thresholds::HIGH_DELETED_RATIO_THRESHOLD;
198 let many_indexes = index_info.btree_index_count >= 3;
199
200 high_deleted_ratio || many_indexes
201 }
202
203 pub fn delete_chunk_size(&self) -> usize {
205 thresholds::DELETE_CHUNK_SIZE
206 }
207
208 pub fn should_trigger_early_compaction(&self) -> bool {
216 let index_info = match &self.index_info {
217 Some(info) => info,
218 None => return false,
219 };
220
221 index_info.deleted_ratio > thresholds::HIGH_DELETED_RATIO_THRESHOLD
224 && index_info.btree_index_count >= 2
225 }
226
227 pub fn compute_indexes_affected_ratio(
239 &self,
240 changed_columns: &std::collections::HashSet<usize>,
241 schema: &vibesql_catalog::TableSchema,
242 ) -> f64 {
243 let index_info = match &self.index_info {
244 Some(info) => info,
245 None => return 0.0,
246 };
247
248 let total_indexes = index_info.hash_index_count + index_info.btree_index_count;
249 if total_indexes == 0 {
250 return 0.0;
251 }
252
253 let mut affected_indexes = 0;
254
255 if let Some(pk_indices) = schema.get_primary_key_indices() {
257 if pk_indices.iter().any(|i| changed_columns.contains(i)) {
258 affected_indexes += 1;
259 }
260 }
261
262 for unique_cols in &schema.unique_constraints {
264 let unique_indices: Vec<usize> =
265 unique_cols.iter().filter_map(|name| schema.get_column_index(name)).collect();
266 if unique_indices.iter().any(|i| changed_columns.contains(i)) {
267 affected_indexes += 1;
268 }
269 }
270
271 let changed_column_names: Vec<String> = changed_columns
274 .iter()
275 .filter_map(|&i| schema.columns.get(i).map(|c| c.name.clone()))
276 .collect();
277
278 for col_name in &changed_column_names {
279 if self.db.has_index_on_column(self.table_name, col_name) {
280 affected_indexes += 1;
281 break; }
283 }
284
285 affected_indexes as f64 / total_indexes as f64
286 }
287
288 pub fn estimate_update_cost(&self, row_count: usize, indexes_affected_ratio: f64) -> f64 {
297 let index_info = match &self.index_info {
298 Some(info) => info,
299 None => return 0.0,
300 };
301
302 let stats = self.get_stats_with_fallback();
303 let cost = self.cost_estimator.estimate_update(
304 row_count,
305 &stats,
306 index_info,
307 indexes_affected_ratio,
308 );
309
310 if std::env::var("DML_COST_DEBUG").is_ok() {
311 eprintln!(
312 "DML_COST_DEBUG: UPDATE on {} - rows={}, affected_ratio={:.2}, cost={:.3}",
313 self.table_name, row_count, indexes_affected_ratio, cost
314 );
315 }
316
317 cost
318 }
319
320 pub fn index_info(&self) -> Option<&TableIndexInfo> {
322 self.index_info.as_ref()
323 }
324}
325
#[cfg(test)]
mod tests {
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::DataType;

    use super::*;

    /// Builds a database holding one table (optionally with an `id` primary
    /// key) plus `btree_index_count` secondary b-tree indexes on `name`.
    fn create_test_db_with_table(
        table_name: &str,
        with_pk: bool,
        btree_index_count: usize,
    ) -> Database {
        let mut db = Database::new();

        let id_col = ColumnSchema::new("id".to_string(), DataType::Integer, false);
        let name_col = ColumnSchema::new(
            "name".to_string(),
            DataType::Varchar { max_length: Some(100) },
            false,
        );

        let schema = if with_pk {
            // PK variant carries an extra nullable `value` column.
            let value_col = ColumnSchema::new("value".to_string(), DataType::Integer, true);
            TableSchema::with_primary_key(
                table_name.to_string(),
                vec![id_col, name_col, value_col],
                vec!["id".to_string()],
            )
        } else {
            TableSchema::new(table_name.to_string(), vec![id_col, name_col])
        };
        db.create_table(schema).unwrap();

        for idx in 0..btree_index_count {
            db.create_index(
                format!("idx_{}_{}", table_name, idx),
                table_name.to_string(),
                false,
                vec![vibesql_ast::IndexColumn::Column {
                    column_name: "name".to_string(),
                    direction: vibesql_ast::OrderDirection::Asc,
                    prefix_length: None,
                }],
            )
            .unwrap();
        }

        db
    }

    #[test]
    fn test_optimal_insert_batch_size_low_cost() {
        let db = create_test_db_with_table("test_table", true, 0);
        let optimizer = DmlOptimizer::new(&db, "test_table");

        // The chosen batch size must be one of the two known tiers.
        let chosen = optimizer.optimal_insert_batch_size(5000);
        let tiers = [thresholds::SMALL_BATCH_SIZE, thresholds::LARGE_BATCH_SIZE];
        assert!(tiers.contains(&chosen));
    }

    #[test]
    fn test_optimal_insert_batch_size_high_cost() {
        let db = create_test_db_with_table("test_table", true, 5);
        let optimizer = DmlOptimizer::new(&db, "test_table");

        // With many indexes the result still lands between the two tiers.
        let chosen = optimizer.optimal_insert_batch_size(5000);
        assert!(
            (thresholds::SMALL_BATCH_SIZE..=thresholds::LARGE_BATCH_SIZE).contains(&chosen)
        );
    }

    #[test]
    fn test_optimal_insert_batch_size_more_indexes_smaller_batch() {
        // More indexes should never produce a larger batch than fewer indexes.
        let db_few = create_test_db_with_table("table_few", true, 1);
        let db_many = create_test_db_with_table("table_many", true, 5);

        let optimizer_few = DmlOptimizer::new(&db_few, "table_few");
        let optimizer_many = DmlOptimizer::new(&db_many, "table_many");

        let batch_few = optimizer_few.optimal_insert_batch_size(5000);
        let batch_many = optimizer_many.optimal_insert_batch_size(5000);

        assert!(batch_many <= batch_few);
    }

    #[test]
    fn test_should_chunk_delete_small() {
        let db = create_test_db_with_table("test_table", true, 0);
        let optimizer = DmlOptimizer::new(&db, "test_table");

        // Far below CHUNK_DELETE_ROW_THRESHOLD — never chunked.
        assert!(!optimizer.should_chunk_delete(100));
    }

    #[test]
    fn test_compute_indexes_affected_ratio_no_indexes() {
        let db = create_test_db_with_table("test_table", false, 0);
        let optimizer = DmlOptimizer::new(&db, "test_table");

        let schema = db.catalog.get_table("test_table").unwrap();
        let changed: std::collections::HashSet<usize> = [1].into_iter().collect();

        // No indexes at all → ratio is exactly zero.
        assert_eq!(optimizer.compute_indexes_affected_ratio(&changed, schema), 0.0);
    }

    #[test]
    fn test_compute_indexes_affected_ratio_pk_affected() {
        let db = create_test_db_with_table("test_table", true, 0);
        let optimizer = DmlOptimizer::new(&db, "test_table");

        let schema = db.catalog.get_table("test_table").unwrap();
        // Column 0 is `id`, the primary key.
        let changed: std::collections::HashSet<usize> = [0].into_iter().collect();

        let ratio = optimizer.compute_indexes_affected_ratio(&changed, schema);
        assert!(ratio > 0.0, "PK update should affect at least one index");
    }

    #[test]
    fn test_fallback_stats() {
        let db = create_test_db_with_table("test_table", true, 0);
        let optimizer = DmlOptimizer::new(&db, "test_table");

        // No ANALYZE has run and the table is empty → fallback reports 0 rows.
        let stats = optimizer.get_stats_with_fallback();
        assert_eq!(stats.row_count, 0);
    }

    #[test]
    fn test_estimate_update_cost() {
        let db = create_test_db_with_table("test_table", true, 2);
        let optimizer = DmlOptimizer::new(&db, "test_table");

        // Touching every index must be priced above touching none.
        let full_cost = optimizer.estimate_update_cost(100, 1.0);
        let selective_cost = optimizer.estimate_update_cost(100, 0.0);

        assert!(full_cost > selective_cost, "Full update should cost more than selective");
    }
}