// sql_splitter/duckdb/batch.rs
//! Batch manager for DuckDB Appender-based bulk loading.
//!
//! This module provides efficient batched insertion of rows into DuckDB
//! using the Appender API instead of individual INSERT statement execution.

use crate::parser::ParsedValue;
use ahash::AHashMap;
use anyhow::Result;
use duckdb::Connection;

use super::ImportStats;
12
/// Maximum number of rows to accumulate per batch before it is flushed.
pub const MAX_ROWS_PER_BATCH: usize = 10_000;
15
/// A batch of rows accumulated for a single table.
#[derive(Debug)]
pub struct InsertBatch {
    /// Target table name
    pub table: String,
    /// Column list if explicitly specified in the original INSERT
    pub columns: Option<Vec<String>>,
    /// Accumulated rows (each row is a Vec of ParsedValue)
    pub rows: Vec<Vec<ParsedValue>>,
    /// Original SQL statements, kept so a failed batch can fall back to
    /// per-statement execution
    pub statements: Vec<String>,
    /// Number of rows contributed by each statement (parallel to
    /// `statements`, same length)
    pub rows_per_statement: Vec<usize>,
}
30
31impl InsertBatch {
32    /// Create a new batch for a table
33    pub fn new(table: String, columns: Option<Vec<String>>) -> Self {
34        Self {
35            table,
36            columns,
37            rows: Vec::new(),
38            statements: Vec::new(),
39            rows_per_statement: Vec::new(),
40        }
41    }
42
43    /// Total number of rows in batch
44    pub fn row_count(&self) -> usize {
45        self.rows.len()
46    }
47
48    /// Clear the batch
49    pub fn clear(&mut self) {
50        self.rows.clear();
51        self.statements.clear();
52        self.rows_per_statement.clear();
53    }
54}
55
/// Batch key: (table_name, column_layout).
///
/// Including the optional column list in the key keeps rows with different
/// explicit column orderings for the same table in separate batches.
type BatchKey = (String, Option<Vec<String>>);
60
/// Manages batched INSERT operations for multiple tables.
pub struct BatchManager {
    /// Active batches keyed by (table, columns)
    batches: AHashMap<BatchKey, InsertBatch>,
    /// Flush threshold: a batch with at least this many rows is ready
    max_rows_per_batch: usize,
}
68
69impl BatchManager {
70    /// Create a new batch manager
71    pub fn new(max_rows_per_batch: usize) -> Self {
72        Self {
73            batches: AHashMap::new(),
74            max_rows_per_batch,
75        }
76    }
77
78    /// Queue rows for insertion, returning a batch if it's ready to flush
79    pub fn queue_insert(
80        &mut self,
81        table: &str,
82        columns: Option<Vec<String>>,
83        rows: Vec<Vec<ParsedValue>>,
84        original_sql: String,
85    ) -> Option<InsertBatch> {
86        let row_count = rows.len();
87        let key = (table.to_string(), columns.clone());
88
89        let batch = self
90            .batches
91            .entry(key)
92            .or_insert_with(|| InsertBatch::new(table.to_string(), columns));
93
94        batch.rows.extend(rows);
95        batch.statements.push(original_sql);
96        batch.rows_per_statement.push(row_count);
97
98        // Check if we need to flush
99        if batch.rows.len() >= self.max_rows_per_batch {
100            // Take the batch out and return it
101            let key = (table.to_string(), batch.columns.clone());
102            self.batches.remove(&key)
103        } else {
104            None
105        }
106    }
107
108    /// Get any batches that are ready to flush
109    pub fn get_ready_batches(&mut self) -> Vec<InsertBatch> {
110        let mut ready = Vec::new();
111        let mut to_remove = Vec::new();
112
113        for (key, batch) in &self.batches {
114            if batch.rows.len() >= self.max_rows_per_batch {
115                to_remove.push(key.clone());
116            }
117        }
118
119        for key in to_remove {
120            if let Some(batch) = self.batches.remove(&key) {
121                ready.push(batch);
122            }
123        }
124
125        ready
126    }
127
128    /// Flush all remaining batches
129    pub fn drain_all(&mut self) -> Vec<InsertBatch> {
130        self.batches.drain().map(|(_, batch)| batch).collect()
131    }
132
133    /// Check if there are any pending batches
134    pub fn has_pending(&self) -> bool {
135        !self.batches.is_empty()
136    }
137}
138
139/// Format a ParsedValue for SQL insertion
140fn format_value_for_sql(value: &ParsedValue) -> String {
141    match value {
142        ParsedValue::Null => "NULL".to_string(),
143        ParsedValue::Integer(n) => n.to_string(),
144        ParsedValue::BigInteger(n) => n.to_string(),
145        ParsedValue::String { value } => {
146            // Escape single quotes by doubling them (SQL standard)
147            let escaped = value.replace('\'', "''");
148            format!("'{}'", escaped)
149        }
150        ParsedValue::Hex(bytes) => {
151            // Convert to hex string for DuckDB
152            let hex: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
153            format!("x'{}'", hex)
154        }
155        ParsedValue::Other(raw) => {
156            let s = String::from_utf8_lossy(raw);
157            // Try to parse as float
158            if s.parse::<f64>().is_ok() {
159                s.to_string()
160            } else {
161                // Treat as text
162                let escaped = s.replace('\'', "''");
163                format!("'{}'", escaped)
164            }
165        }
166    }
167}
168
169/// Generate a batched INSERT statement from parsed values
170fn generate_batch_insert(
171    table: &str,
172    columns: &Option<Vec<String>>,
173    rows: &[Vec<ParsedValue>],
174) -> String {
175    if rows.is_empty() {
176        return String::new();
177    }
178
179    let mut sql = format!("INSERT INTO \"{}\"", table);
180
181    // Add column list if specified
182    if let Some(cols) = columns {
183        sql.push_str(" (");
184        for (i, col) in cols.iter().enumerate() {
185            if i > 0 {
186                sql.push_str(", ");
187            }
188            sql.push('"');
189            sql.push_str(col);
190            sql.push('"');
191        }
192        sql.push(')');
193    }
194
195    sql.push_str(" VALUES\n");
196
197    for (i, row) in rows.iter().enumerate() {
198        if i > 0 {
199            sql.push_str(",\n");
200        }
201        sql.push('(');
202        for (j, value) in row.iter().enumerate() {
203            if j > 0 {
204                sql.push_str(", ");
205            }
206            sql.push_str(&format_value_for_sql(value));
207        }
208        sql.push(')');
209    }
210    sql.push(';');
211
212    sql
213}
214
215/// Flush a batch using DuckDB's Appender API with transactional fallback
216pub fn flush_batch(
217    conn: &Connection,
218    batch: &mut InsertBatch,
219    stats: &mut ImportStats,
220    failed_tables: &mut std::collections::HashSet<String>,
221) -> Result<()> {
222    if batch.rows.is_empty() {
223        return Ok(());
224    }
225
226    // Skip tables we know don't exist
227    if failed_tables.contains(&batch.table) {
228        batch.clear();
229        return Ok(());
230    }
231
232    // Try the fast path with batched INSERT
233    match try_batch_insert(conn, batch, stats) {
234        Ok(true) => {
235            // Success via batched INSERT
236            batch.clear();
237            Ok(())
238        }
239        Ok(false) => {
240            // Table doesn't exist or other non-recoverable error
241            failed_tables.insert(batch.table.clone());
242            batch.clear();
243            Ok(())
244        }
245        Err(_) => {
246            // Batched INSERT failed (constraint violation, type mismatch, etc.)
247            // Fall back to per-statement execution
248            fallback_execute(conn, batch, stats)?;
249            batch.clear();
250            Ok(())
251        }
252    }
253}
254
255/// Try to insert using batched SQL execution, returns Ok(true) on success,
256/// Ok(false) if table doesn't exist, Err on constraint/type errors
257fn try_batch_insert(
258    conn: &Connection,
259    batch: &InsertBatch,
260    stats: &mut ImportStats,
261) -> Result<bool> {
262    // Generate a single batched INSERT statement
263    let batch_sql = generate_batch_insert(&batch.table, &batch.columns, &batch.rows);
264    if batch_sql.is_empty() {
265        return Ok(true);
266    }
267
268    // Execute the batched INSERT (within the loader's transaction context)
269    match conn.execute(&batch_sql, []) {
270        Ok(_) => {
271            stats.insert_statements += batch.statements.len();
272            stats.rows_inserted += batch.rows.len() as u64;
273            Ok(true)
274        }
275        Err(e) => {
276            let err_str = e.to_string();
277            // Check if it's a "table not found" error
278            if err_str.contains("does not exist") || err_str.contains("not found") {
279                return Ok(false);
280            }
281            Err(e.into())
282        }
283    }
284}
285
286/// Fallback: execute original SQL statements one by one
287fn fallback_execute(conn: &Connection, batch: &InsertBatch, stats: &mut ImportStats) -> Result<()> {
288    for stmt in &batch.statements {
289        match conn.execute(stmt, []) {
290            Ok(_) => {
291                stats.insert_statements += 1;
292                stats.rows_inserted += count_insert_rows(stmt);
293            }
294            Err(e) => {
295                if stats.warnings.len() < 100 {
296                    stats.warnings.push(format!(
297                        "Failed INSERT for {} in fallback: {}",
298                        batch.table, e
299                    ));
300                }
301                stats.statements_skipped += 1;
302            }
303        }
304    }
305    Ok(())
306}
307
/// Count rows in an INSERT statement (simple heuristic).
///
/// Finds the `VALUES` keyword case-insensitively and counts top-level
/// parenthesised groups after it, skipping over single-quoted string
/// literals (honoring backslash escapes as produced by MySQL dumps; SQL
/// `''` doubling also works since it closes and reopens the string).
/// Statements without a `VALUES` clause are assumed to affect one row.
fn count_insert_rows(sql: &str) -> u64 {
    // Locate VALUES without allocating an uppercased copy. Searching the
    // original bytes with an ASCII case-insensitive comparison guarantees
    // the offset is valid in `sql`; `to_uppercase()` can change byte
    // lengths for non-ASCII characters and misalign (or panic) the slice.
    let values_pos = match sql
        .as_bytes()
        .windows(6)
        .position(|w| w.eq_ignore_ascii_case(b"VALUES"))
    {
        Some(pos) => pos,
        None => return 1,
    };

    let after_values = &sql[values_pos + 6..];
    let mut count = 0u64;
    let mut depth: i32 = 0;
    let mut in_string = false;
    // True when the previous in-string char was an unconsumed backslash;
    // the escaped character (including `\'` and `\\`) is then skipped.
    let mut escaped = false;

    for c in after_values.chars() {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '\'' {
                in_string = false;
            }
            continue;
        }
        match c {
            '\'' => in_string = true,
            '(' => {
                // A top-level opening paren starts a new row tuple.
                if depth == 0 {
                    count += 1;
                }
                depth += 1;
            }
            ')' => depth = depth.saturating_sub(1),
            _ => {}
        }
    }
    count
}
338
#[cfg(test)]
mod tests {
    use super::*;

    // A single queued row below the threshold stays pending.
    #[test]
    fn test_batch_manager_queue() {
        let mut mgr = BatchManager::new(100);

        let rows = vec![vec![
            ParsedValue::Integer(1),
            ParsedValue::String {
                value: "test".to_string(),
            },
        ]];

        let result = mgr.queue_insert(
            "users",
            None,
            rows,
            "INSERT INTO users VALUES (1, 'test')".to_string(),
        );
        assert!(result.is_none()); // Not ready yet
        assert!(mgr.has_pending());
    }

    // Crossing the row threshold returns the accumulated batch.
    #[test]
    fn test_batch_manager_flush_threshold() {
        let mut mgr = BatchManager::new(2);

        let rows1 = vec![vec![ParsedValue::Integer(1)]];
        let rows2 = vec![vec![ParsedValue::Integer(2)], vec![ParsedValue::Integer(3)]];

        mgr.queue_insert("test", None, rows1, "SQL1".to_string());
        let result = mgr.queue_insert("test", None, rows2, "SQL2".to_string());

        assert!(result.is_some());
        let batch = result.unwrap();
        assert_eq!(batch.row_count(), 3);
    }

    // Row counting handles multi-row tuples and parens inside strings.
    #[test]
    fn test_count_insert_rows() {
        assert_eq!(count_insert_rows("INSERT INTO t VALUES (1)"), 1);
        assert_eq!(count_insert_rows("INSERT INTO t VALUES (1), (2), (3)"), 3);
        assert_eq!(
            count_insert_rows("INSERT INTO t VALUES (1, 'a(b)'), (2, 'c')"),
            2
        );
    }

    // Explicit column lists are quoted and placed before VALUES.
    #[test]
    fn test_generate_batch_insert_with_columns() {
        let rows = vec![
            vec![
                ParsedValue::String {
                    value: "alice".to_string(),
                },
                ParsedValue::Integer(1),
            ],
            vec![
                ParsedValue::String {
                    value: "bob".to_string(),
                },
                ParsedValue::Integer(2),
            ],
        ];
        let columns = Some(vec!["name".to_string(), "id".to_string()]);
        let sql = generate_batch_insert("users", &columns, &rows);
        assert!(sql.contains("INSERT INTO \"users\" (\"name\", \"id\") VALUES"));
        assert!(sql.contains("'alice'"));
        assert!(sql.contains("'bob'"));
    }

    // Without a column list the statement goes straight to VALUES.
    #[test]
    fn test_generate_batch_insert_without_columns() {
        let rows = vec![vec![
            ParsedValue::Integer(1),
            ParsedValue::String {
                value: "test".to_string(),
            },
        ]];
        let sql = generate_batch_insert("test", &None, &rows);
        assert_eq!(sql, "INSERT INTO \"test\" VALUES\n(1, 'test');");
    }
}