Skip to main content

sql_splitter/convert/
copy_to_insert.rs

1//! Convert PostgreSQL COPY FROM stdin statements to INSERT statements.
2//!
3//! Handles:
4//! - Tab-separated value parsing
5//! - NULL handling (\N → NULL)
6//! - Escape sequence conversion (\t, \n, \\)
7//! - Batched INSERT generation for efficiency
8
9use once_cell::sync::Lazy;
10use regex::Regex;
11
/// Upper bound on the number of rows emitted in a single `INSERT` statement.
///
/// Larger COPY blocks are split into several statements of at most this many
/// rows, which keeps each statement readable and bounds transaction size.
const MAX_ROWS_PER_INSERT: usize = 1_000;
14
/// Result of parsing a COPY header.
///
/// Produced by `parse_copy_header` from a statement such as
/// `COPY schema.table (col1, col2) FROM stdin;`.
///
/// Implements `PartialEq`/`Eq` so parsed headers can be compared directly
/// (for example in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CopyHeader {
    /// Schema name (e.g., "public"); `None` when the table was not schema-qualified.
    pub schema: Option<String>,
    /// Table name, without surrounding quotes.
    pub table: String,
    /// Column list, without surrounding quotes (empty if no list was specified).
    pub columns: Vec<String>,
}
25
26/// Parse a COPY header to extract table and columns
27/// Input: "COPY schema.table (col1, col2) FROM stdin;"
28pub fn parse_copy_header(stmt: &str) -> Option<CopyHeader> {
29    // Strip comments from the beginning
30    let stmt = strip_leading_comments(stmt);
31
32    static RE_COPY: Lazy<Regex> = Lazy::new(|| {
33        // Pattern: COPY [ONLY] [schema.]table [(columns)] FROM stdin
34        // Schema and table can be quoted with double quotes
35        Regex::new(
36            r#"(?i)^\s*COPY\s+(?:ONLY\s+)?(?:"?(\w+)"?\.)?["]?(\w+)["]?\s*(?:\(([^)]+)\))?\s+FROM\s+stdin"#
37        ).unwrap()
38    });
39
40    let caps = RE_COPY.captures(&stmt)?;
41
42    let schema = caps.get(1).map(|m| m.as_str().to_string());
43    let table = caps.get(2)?.as_str().to_string();
44    let columns = caps
45        .get(3)
46        .map(|m| {
47            m.as_str()
48                .split(',')
49                .map(|c| c.trim().trim_matches('"').trim_matches('`').to_string())
50                .collect()
51        })
52        .unwrap_or_default();
53
54    Some(CopyHeader {
55        schema,
56        table,
57        columns,
58    })
59}
60
/// Strip leading SQL comments (and surrounding whitespace) from a statement.
///
/// Both `-- line` comments and `/* block */` comments are removed, repeatedly,
/// until the text no longer starts with a comment. If a leading comment is
/// never terminated (no newline after `--`, no `*/` after `/*`), the whole
/// input is considered to be a comment and an empty string is returned.
fn strip_leading_comments(stmt: &str) -> String {
    let mut rest = stmt.trim();
    loop {
        if rest.starts_with("--") {
            match rest.find('\n') {
                Some(pos) => rest = rest[pos + 1..].trim(),
                None => return String::new(),
            }
        } else if let Some(body) = rest.strip_prefix("/*") {
            // Look for the terminator only *after* the opening "/*"; searching
            // from the start would let "/*/" count as a complete comment
            // because its "*/" overlaps the opener.
            match body.find("*/") {
                Some(pos) => rest = body[pos + 2..].trim(),
                None => return String::new(),
            }
        } else {
            return rest.to_string();
        }
    }
}
85
86/// Convert a COPY data block to INSERT statements
87///
88/// # Arguments
89/// * `header` - Parsed COPY header with table/column info
90/// * `data` - The data block (tab-separated rows ending with \.)
91/// * `target_dialect` - Target SQL dialect for quoting
92///
93/// # Returns
94/// Vector of INSERT statements as bytes
95pub fn copy_to_inserts(
96    header: &CopyHeader,
97    data: &[u8],
98    target_dialect: crate::parser::SqlDialect,
99) -> Vec<Vec<u8>> {
100    let mut inserts = Vec::new();
101    let rows = parse_copy_data(data);
102
103    if rows.is_empty() {
104        return inserts;
105    }
106
107    // Build INSERT prefix
108    let quote_char = match target_dialect {
109        crate::parser::SqlDialect::MySql => '`',
110        _ => '"',
111    };
112
113    let table_ref = if let Some(ref schema) = header.schema {
114        if target_dialect == crate::parser::SqlDialect::MySql {
115            // MySQL: just use table name without schema
116            format!("{}{}{}", quote_char, header.table, quote_char)
117        } else if schema == "public" || schema == "pg_catalog" {
118            // Common PostgreSQL schemas - strip for DuckDB compatibility
119            format!("{}{}{}", quote_char, header.table, quote_char)
120        } else {
121            format!(
122                "{}{}{}.{}{}{}",
123                quote_char, schema, quote_char, quote_char, header.table, quote_char
124            )
125        }
126    } else {
127        format!("{}{}{}", quote_char, header.table, quote_char)
128    };
129
130    let columns_str = if header.columns.is_empty() {
131        String::new()
132    } else {
133        let cols: Vec<String> = header
134            .columns
135            .iter()
136            .map(|c| format!("{}{}{}", quote_char, c, quote_char))
137            .collect();
138        format!(" ({})", cols.join(", "))
139    };
140
141    // Generate batched INSERTs
142    for chunk in rows.chunks(MAX_ROWS_PER_INSERT) {
143        let mut insert = format!("INSERT INTO {}{} VALUES\n", table_ref, columns_str);
144
145        for (i, row) in chunk.iter().enumerate() {
146            if i > 0 {
147                insert.push_str(",\n");
148            }
149            insert.push('(');
150
151            for (j, value) in row.iter().enumerate() {
152                if j > 0 {
153                    insert.push_str(", ");
154                }
155                insert.push_str(&format_value(value, target_dialect));
156            }
157
158            insert.push(')');
159        }
160
161        insert.push(';');
162        inserts.push(insert.into_bytes());
163    }
164
165    inserts
166}
167
/// A single field parsed from a COPY data row.
///
/// Implements `PartialEq`/`Eq` so parsed rows can be compared directly
/// (for example in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CopyValue {
    /// The field was the COPY NULL marker (`\N`).
    Null,
    /// Any other field, with COPY escape sequences already decoded.
    Text(String),
}
174
175/// Parse COPY data block into rows of values
176pub fn parse_copy_data(data: &[u8]) -> Vec<Vec<CopyValue>> {
177    let mut rows = Vec::new();
178    let mut pos = 0;
179
180    while pos < data.len() {
181        // Find end of line
182        let line_end = data[pos..]
183            .iter()
184            .position(|&b| b == b'\n')
185            .map(|p| pos + p)
186            .unwrap_or(data.len());
187
188        let line = &data[pos..line_end];
189
190        // Check for terminator
191        if line == b"\\." || line.is_empty() {
192            pos = line_end + 1;
193            continue;
194        }
195
196        // Parse the row
197        let row = parse_row(line);
198        if !row.is_empty() {
199            rows.push(row);
200        }
201
202        pos = line_end + 1;
203    }
204
205    rows
206}
207
208/// Parse a single tab-separated row
209fn parse_row(line: &[u8]) -> Vec<CopyValue> {
210    let mut values = Vec::new();
211    let mut start = 0;
212
213    for (i, &b) in line.iter().enumerate() {
214        if b == b'\t' {
215            values.push(parse_value(&line[start..i]));
216            start = i + 1;
217        }
218    }
219    // Last value
220    if start <= line.len() {
221        values.push(parse_value(&line[start..]));
222    }
223
224    values
225}
226
227/// Parse a single COPY value
228fn parse_value(value: &[u8]) -> CopyValue {
229    // Check for NULL marker
230    if value == b"\\N" {
231        return CopyValue::Null;
232    }
233
234    // Decode escape sequences
235    let decoded = decode_escapes(value);
236    CopyValue::Text(decoded)
237}
238
/// Decode PostgreSQL COPY (text format) escape sequences in a single field.
///
/// Recognised escapes:
/// - `\n`, `\r`, `\t`, `\\`, `\b`, `\f`, `\v`
/// - `\` followed by one to three octal digits, giving a raw byte value
///   (only the low 8 bits are kept, so `\777` does not overflow)
///
/// Any other backslash sequence, and a lone trailing backslash, is kept
/// unchanged.
///
/// Decoding happens at the byte level and the result is converted to a
/// `String` once at the end. This means octal escapes can spell out multi-byte
/// UTF-8 characters (e.g. `\303\251` is "é"), and an invalid byte only replaces
/// itself with U+FFFD rather than corrupting valid characters around it.
fn decode_escapes(value: &[u8]) -> String {
    let mut bytes: Vec<u8> = Vec::with_capacity(value.len());
    let mut i = 0;

    while i < value.len() {
        let current = value[i];

        // Ordinary byte, or a backslash with nothing after it: copy through.
        if current != b'\\' || i + 1 >= value.len() {
            bytes.push(current);
            i += 1;
            continue;
        }

        let next = value[i + 1];
        let simple = match next {
            b'n' => Some(b'\n'),
            b'r' => Some(b'\r'),
            b't' => Some(b'\t'),
            b'\\' => Some(b'\\'),
            b'b' => Some(0x08), // backspace
            b'f' => Some(0x0C), // form feed
            b'v' => Some(0x0B), // vertical tab
            _ => None,
        };
        if let Some(decoded) = simple {
            bytes.push(decoded);
            i += 2;
            continue;
        }

        // Octal escape: up to three digits in 0..=7 directly after the backslash.
        let digits = value[i + 1..]
            .iter()
            .take(3)
            .take_while(|d| (b'0'..=b'7').contains(*d))
            .count();
        if digits > 0 {
            // Accumulate in u32 so three digits (max 0o777) cannot overflow,
            // then keep the low byte.
            let code = value[i + 1..i + 1 + digits]
                .iter()
                .fold(0u32, |acc, d| acc * 8 + u32::from(d - b'0'));
            bytes.push((code & 0xFF) as u8);
            i += 1 + digits;
            continue;
        }

        // Unknown escape: keep both bytes as-is.
        bytes.push(b'\\');
        bytes.push(next);
        i += 2;
    }

    // Single validation pass; invalid sequences become U+FFFD.
    String::from_utf8_lossy(&bytes).into_owned()
}
313
314/// Format a value for SQL INSERT
315fn format_value(value: &CopyValue, dialect: crate::parser::SqlDialect) -> String {
316    match value {
317        CopyValue::Null => "NULL".to_string(),
318        CopyValue::Text(s) => {
319            // Escape quotes based on dialect
320            let escaped = match dialect {
321                crate::parser::SqlDialect::MySql => {
322                    // MySQL: escape with backslash
323                    s.replace('\\', "\\\\")
324                        .replace('\'', "\\'")
325                        .replace('\n', "\\n")
326                        .replace('\r', "\\r")
327                        .replace('\t', "\\t")
328                        .replace('\0', "\\0")
329                }
330                _ => {
331                    // PostgreSQL/SQLite: escape by doubling
332                    s.replace('\'', "''")
333                }
334            };
335            format!("'{}'", escaped)
336        }
337    }
338}