1use crate::parser::ParsedValue;
7use ahash::AHashMap;
8use anyhow::Result;
9use duckdb::Connection;
10
11use super::ImportStats;
12
/// Row-count threshold for a single insert batch.
///
/// NOTE(review): presumably the value callers pass to `BatchManager::new`; confirm at the call site.
pub const MAX_ROWS_PER_BATCH: usize = 10_000;
15
/// Rows accumulated for one target table (and optional column list) that are written
/// together as a single multi-row `INSERT`.
#[derive(Debug)]
pub struct InsertBatch {
    /// Target table name; emitted as a double-quoted identifier in the generated SQL.
    pub table: String,
    /// Explicit column list shared by all queued statements, or `None` when the generated
    /// `INSERT` should omit the column list.
    pub columns: Option<Vec<String>>,
    /// Every queued row, in the order the statements were queued.
    pub rows: Vec<Vec<ParsedValue>>,
    /// Original SQL text of each queued statement, kept so the statements can be replayed
    /// one by one if the combined insert fails.
    pub statements: Vec<String>,
    /// Number of rows each entry of `statements` contributed (parallel to `statements`).
    pub rows_per_statement: Vec<usize>,
}
30
31impl InsertBatch {
32 pub fn new(table: String, columns: Option<Vec<String>>) -> Self {
34 Self {
35 table,
36 columns,
37 rows: Vec::new(),
38 statements: Vec::new(),
39 rows_per_statement: Vec::new(),
40 }
41 }
42
43 pub fn row_count(&self) -> usize {
45 self.rows.len()
46 }
47
48 pub fn clear(&mut self) {
50 self.rows.clear();
51 self.statements.clear();
52 self.rows_per_statement.clear();
53 }
54}
55
/// Key under which pending batches are grouped: table name plus optional column list, so
/// statements with different column lists for the same table land in separate batches.
type BatchKey = (String, Option<Vec<String>>);
60
/// Collects queued INSERT rows per `(table, columns)` key and releases a batch once it
/// reaches the configured row threshold.
pub struct BatchManager {
    /// Pending batches, one per distinct key.
    batches: AHashMap<BatchKey, InsertBatch>,
    /// A batch holding at least this many rows is considered ready to flush.
    max_rows_per_batch: usize,
}
68
69impl BatchManager {
70 pub fn new(max_rows_per_batch: usize) -> Self {
72 Self {
73 batches: AHashMap::new(),
74 max_rows_per_batch,
75 }
76 }
77
78 pub fn queue_insert(
80 &mut self,
81 table: &str,
82 columns: Option<Vec<String>>,
83 rows: Vec<Vec<ParsedValue>>,
84 original_sql: String,
85 ) -> Option<InsertBatch> {
86 let row_count = rows.len();
87 let key = (table.to_string(), columns.clone());
88
89 let batch = self
90 .batches
91 .entry(key)
92 .or_insert_with(|| InsertBatch::new(table.to_string(), columns));
93
94 batch.rows.extend(rows);
95 batch.statements.push(original_sql);
96 batch.rows_per_statement.push(row_count);
97
98 if batch.rows.len() >= self.max_rows_per_batch {
100 let key = (table.to_string(), batch.columns.clone());
102 self.batches.remove(&key)
103 } else {
104 None
105 }
106 }
107
108 pub fn get_ready_batches(&mut self) -> Vec<InsertBatch> {
110 let mut ready = Vec::new();
111 let mut to_remove = Vec::new();
112
113 for (key, batch) in &self.batches {
114 if batch.rows.len() >= self.max_rows_per_batch {
115 to_remove.push(key.clone());
116 }
117 }
118
119 for key in to_remove {
120 if let Some(batch) = self.batches.remove(&key) {
121 ready.push(batch);
122 }
123 }
124
125 ready
126 }
127
128 pub fn drain_all(&mut self) -> Vec<InsertBatch> {
130 self.batches.drain().map(|(_, batch)| batch).collect()
131 }
132
133 pub fn has_pending(&self) -> bool {
135 !self.batches.is_empty()
136 }
137}
138
139fn format_value_for_sql(value: &ParsedValue) -> String {
141 match value {
142 ParsedValue::Null => "NULL".to_string(),
143 ParsedValue::Integer(n) => n.to_string(),
144 ParsedValue::BigInteger(n) => n.to_string(),
145 ParsedValue::String { value } => {
146 let escaped = value.replace('\'', "''");
148 format!("'{}'", escaped)
149 }
150 ParsedValue::Hex(bytes) => {
151 let hex: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
153 format!("x'{}'", hex)
154 }
155 ParsedValue::Other(raw) => {
156 let s = String::from_utf8_lossy(raw);
157 if s.parse::<f64>().is_ok() {
159 s.to_string()
160 } else {
161 let escaped = s.replace('\'', "''");
163 format!("'{}'", escaped)
164 }
165 }
166 }
167}
168
169fn generate_batch_insert(
171 table: &str,
172 columns: &Option<Vec<String>>,
173 rows: &[Vec<ParsedValue>],
174) -> String {
175 if rows.is_empty() {
176 return String::new();
177 }
178
179 let mut sql = format!("INSERT INTO \"{}\"", table);
180
181 if let Some(cols) = columns {
183 sql.push_str(" (");
184 for (i, col) in cols.iter().enumerate() {
185 if i > 0 {
186 sql.push_str(", ");
187 }
188 sql.push('"');
189 sql.push_str(col);
190 sql.push('"');
191 }
192 sql.push(')');
193 }
194
195 sql.push_str(" VALUES\n");
196
197 for (i, row) in rows.iter().enumerate() {
198 if i > 0 {
199 sql.push_str(",\n");
200 }
201 sql.push('(');
202 for (j, value) in row.iter().enumerate() {
203 if j > 0 {
204 sql.push_str(", ");
205 }
206 sql.push_str(&format_value_for_sql(value));
207 }
208 sql.push(')');
209 }
210 sql.push(';');
211
212 sql
213}
214
215pub fn flush_batch(
217 conn: &Connection,
218 batch: &mut InsertBatch,
219 stats: &mut ImportStats,
220 failed_tables: &mut std::collections::HashSet<String>,
221) -> Result<()> {
222 if batch.rows.is_empty() {
223 return Ok(());
224 }
225
226 if failed_tables.contains(&batch.table) {
228 batch.clear();
229 return Ok(());
230 }
231
232 match try_batch_insert(conn, batch, stats) {
234 Ok(true) => {
235 batch.clear();
237 Ok(())
238 }
239 Ok(false) => {
240 failed_tables.insert(batch.table.clone());
242 batch.clear();
243 Ok(())
244 }
245 Err(_) => {
246 fallback_execute(conn, batch, stats)?;
249 batch.clear();
250 Ok(())
251 }
252 }
253}
254
255fn try_batch_insert(
258 conn: &Connection,
259 batch: &InsertBatch,
260 stats: &mut ImportStats,
261) -> Result<bool> {
262 let batch_sql = generate_batch_insert(&batch.table, &batch.columns, &batch.rows);
264 if batch_sql.is_empty() {
265 return Ok(true);
266 }
267
268 match conn.execute(&batch_sql, []) {
270 Ok(_) => {
271 stats.insert_statements += batch.statements.len();
272 stats.rows_inserted += batch.rows.len() as u64;
273 Ok(true)
274 }
275 Err(e) => {
276 let err_str = e.to_string();
277 if err_str.contains("does not exist") || err_str.contains("not found") {
279 return Ok(false);
280 }
281 Err(e.into())
282 }
283 }
284}
285
286fn fallback_execute(conn: &Connection, batch: &InsertBatch, stats: &mut ImportStats) -> Result<()> {
288 for stmt in &batch.statements {
289 match conn.execute(stmt, []) {
290 Ok(_) => {
291 stats.insert_statements += 1;
292 stats.rows_inserted += count_insert_rows(stmt);
293 }
294 Err(e) => {
295 if stats.warnings.len() < 100 {
296 stats.warnings.push(format!(
297 "Failed INSERT for {} in fallback: {}",
298 batch.table, e
299 ));
300 }
301 stats.statements_skipped += 1;
302 }
303 }
304 }
305 Ok(())
306}
307
/// Estimates how many rows an `INSERT ... VALUES` statement inserts.
///
/// It counts the top-level parenthesised tuples after the first standalone `VALUES` keyword.
/// Single-quoted strings are skipped, honouring both `''` doubling and backslash escapes, so
/// parentheses inside literals are not counted. A statement without a `VALUES` keyword
/// (e.g. `INSERT ... SELECT`) counts as one row.
fn count_insert_rows(sql: &str) -> u64 {
    // Locate the keyword on the original bytes rather than in `sql.to_uppercase()`: Unicode
    // uppercasing can change byte lengths, which made the offset invalid for slicing `sql`.
    let Some(values_end) = find_values_keyword(sql) else {
        return 1;
    };

    let mut count = 0u64;
    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    for c in sql[values_end..].chars() {
        if in_string {
            if escaped {
                // The previous backslash consumed this character (handles `\\` and `\'`).
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '\'' {
                // A doubled `''` closes and immediately reopens the string, which is harmless.
                in_string = false;
            }
            continue;
        }
        match c {
            '\'' => in_string = true,
            '(' => {
                if depth == 0 {
                    count += 1;
                }
                depth += 1;
            }
            ')' => depth = depth.saturating_sub(1),
            _ => {}
        }
    }
    count
}

/// Returns the byte offset just past the first standalone `VALUES` keyword.
///
/// Matching is ASCII case-insensitive. Occurrences embedded in longer or quoted identifiers
/// (`values_log`, `"values"`, `` `values` ``) are skipped.
fn find_values_keyword(sql: &str) -> Option<usize> {
    const KEYWORD: &[u8] = b"VALUES";
    let bytes = sql.as_bytes();
    // Bytes that can belong to an identifier next to the keyword; non-ASCII bytes are part
    // of multi-byte identifier characters.
    let is_word_byte = |b: u8| {
        b.is_ascii_alphanumeric() || matches!(b, b'_' | b'`' | b'"') || !b.is_ascii()
    };

    bytes
        .windows(KEYWORD.len())
        .enumerate()
        .find_map(|(start, window)| {
            let end = start + KEYWORD.len();
            let standalone = (start == 0 || !is_word_byte(bytes[start - 1]))
                && bytes.get(end).map_or(true, |&b| !is_word_byte(b));
            // `end` follows an ASCII byte, so it is always a valid char boundary.
            (window.eq_ignore_ascii_case(KEYWORD) && standalone).then_some(end)
        })
}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a parsed string value.
    fn text(s: &str) -> ParsedValue {
        ParsedValue::String {
            value: s.to_string(),
        }
    }

    #[test]
    fn test_batch_manager_queue() {
        let mut manager = BatchManager::new(100);

        let released = manager.queue_insert(
            "users",
            None,
            vec![vec![ParsedValue::Integer(1), text("test")]],
            "INSERT INTO users VALUES (1, 'test')".to_string(),
        );

        assert!(released.is_none());
        assert!(manager.has_pending());
    }

    #[test]
    fn test_batch_manager_flush_threshold() {
        let mut manager = BatchManager::new(2);

        manager.queue_insert(
            "test",
            None,
            vec![vec![ParsedValue::Integer(1)]],
            "SQL1".to_string(),
        );
        let released = manager.queue_insert(
            "test",
            None,
            vec![vec![ParsedValue::Integer(2)], vec![ParsedValue::Integer(3)]],
            "SQL2".to_string(),
        );

        assert!(released.is_some());
        let batch = released.unwrap();
        assert_eq!(batch.row_count(), 3);
    }

    #[test]
    fn test_count_insert_rows() {
        let cases = [
            ("INSERT INTO t VALUES (1)", 1),
            ("INSERT INTO t VALUES (1), (2), (3)", 3),
            ("INSERT INTO t VALUES (1, 'a(b)'), (2, 'c')", 2),
        ];
        for (sql, expected) in cases {
            assert_eq!(count_insert_rows(sql), expected, "{}", sql);
        }
    }

    #[test]
    fn test_generate_batch_insert_with_columns() {
        let rows = vec![
            vec![text("alice"), ParsedValue::Integer(1)],
            vec![text("bob"), ParsedValue::Integer(2)],
        ];
        let columns = Some(vec!["name".to_string(), "id".to_string()]);

        let sql = generate_batch_insert("users", &columns, &rows);

        for fragment in [
            "INSERT INTO \"users\" (\"name\", \"id\") VALUES",
            "'alice'",
            "'bob'",
        ] {
            assert!(sql.contains(fragment), "missing {:?} in {:?}", fragment, sql);
        }
    }

    #[test]
    fn test_generate_batch_insert_without_columns() {
        let rows = vec![vec![ParsedValue::Integer(1), text("test")]];

        let sql = generate_batch_insert("test", &None, &rows);

        assert_eq!(sql, "INSERT INTO \"test\" VALUES\n(1, 'test');");
    }
}
423}