sql_splitter/convert/
copy_to_insert.rs

1//! Convert PostgreSQL COPY FROM stdin statements to INSERT statements.
2//!
3//! Handles:
4//! - Tab-separated value parsing
5//! - NULL handling (\N → NULL)
6//! - Escape sequence conversion (\t, \n, \\)
7//! - Batched INSERT generation for efficiency
8
9use once_cell::sync::Lazy;
10use regex::Regex;
11
/// Upper bound on the number of row tuples emitted in a single generated
/// `INSERT` statement; larger data blocks are split into several statements.
const MAX_ROWS_PER_INSERT: usize = 100;
14
/// Table reference and column list extracted from a `COPY ... FROM stdin`
/// header line.
///
/// Derives `PartialEq`/`Eq` so parsed headers can be compared directly
/// (for example in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CopyHeader {
    /// Schema name (e.g., "public"); `None` when the header did not qualify
    /// the table with a schema.
    pub schema: Option<String>,
    /// Table name, without surrounding quotes.
    pub table: String,
    /// Column names in header order (empty if the header had no column list).
    pub columns: Vec<String>,
}
25
26/// Parse a COPY header to extract table and columns
27/// Input: "COPY schema.table (col1, col2) FROM stdin;"
28pub fn parse_copy_header(stmt: &str) -> Option<CopyHeader> {
29    // Strip comments from the beginning
30    let stmt = strip_leading_comments(stmt);
31
32    static RE_COPY: Lazy<Regex> = Lazy::new(|| {
33        // Pattern: COPY [ONLY] [schema.]table [(columns)] FROM stdin
34        // Schema and table can be quoted with double quotes
35        Regex::new(
36            r#"(?i)^\s*COPY\s+(?:ONLY\s+)?(?:"?(\w+)"?\.)?["]?(\w+)["]?\s*(?:\(([^)]+)\))?\s+FROM\s+stdin"#
37        ).unwrap()
38    });
39
40    let caps = RE_COPY.captures(&stmt)?;
41
42    let schema = caps.get(1).map(|m| m.as_str().to_string());
43    let table = caps.get(2)?.as_str().to_string();
44    let columns = caps
45        .get(3)
46        .map(|m| {
47            m.as_str()
48                .split(',')
49                .map(|c| c.trim().trim_matches('"').trim_matches('`').to_string())
50                .collect()
51        })
52        .unwrap_or_default();
53
54    Some(CopyHeader {
55        schema,
56        table,
57        columns,
58    })
59}
60
/// Remove any run of `-- line` and `/* block */` comments from the start of
/// `stmt`, returning what follows.
///
/// Surrounding whitespace is trimmed from the result. If a leading comment is
/// never terminated (no newline after `--`, no closing `*/`), the whole input
/// is considered a comment and an empty string is returned.
fn strip_leading_comments(stmt: &str) -> String {
    let mut rest = stmt.trim();
    loop {
        // Pick the token that ends whichever comment kind we are looking at.
        let terminator = if rest.starts_with("--") {
            "\n"
        } else if rest.starts_with("/*") {
            "*/"
        } else {
            // No comment in front any more: this is the statement itself.
            return rest.to_string();
        };

        match rest.find(terminator) {
            Some(at) => rest = rest[at + terminator.len()..].trim(),
            None => return String::new(),
        }
    }
}
85
86/// Convert a COPY data block to INSERT statements
87///
88/// # Arguments
89/// * `header` - Parsed COPY header with table/column info
90/// * `data` - The data block (tab-separated rows ending with \.)
91/// * `target_dialect` - Target SQL dialect for quoting
92///
93/// # Returns
94/// Vector of INSERT statements as bytes
95pub fn copy_to_inserts(
96    header: &CopyHeader,
97    data: &[u8],
98    target_dialect: crate::parser::SqlDialect,
99) -> Vec<Vec<u8>> {
100    let mut inserts = Vec::new();
101    let rows = parse_copy_data(data);
102
103    if rows.is_empty() {
104        return inserts;
105    }
106
107    // Build INSERT prefix
108    let quote_char = match target_dialect {
109        crate::parser::SqlDialect::MySql => '`',
110        _ => '"',
111    };
112
113    let table_ref = if let Some(ref schema) = header.schema {
114        if target_dialect == crate::parser::SqlDialect::MySql {
115            // MySQL: just use table name without schema
116            format!("{}{}{}", quote_char, header.table, quote_char)
117        } else {
118            format!(
119                "{}{}{}.{}{}{}",
120                quote_char, schema, quote_char, quote_char, header.table, quote_char
121            )
122        }
123    } else {
124        format!("{}{}{}", quote_char, header.table, quote_char)
125    };
126
127    let columns_str = if header.columns.is_empty() {
128        String::new()
129    } else {
130        let cols: Vec<String> = header
131            .columns
132            .iter()
133            .map(|c| format!("{}{}{}", quote_char, c, quote_char))
134            .collect();
135        format!(" ({})", cols.join(", "))
136    };
137
138    // Generate batched INSERTs
139    for chunk in rows.chunks(MAX_ROWS_PER_INSERT) {
140        let mut insert = format!("INSERT INTO {}{} VALUES\n", table_ref, columns_str);
141
142        for (i, row) in chunk.iter().enumerate() {
143            if i > 0 {
144                insert.push_str(",\n");
145            }
146            insert.push('(');
147
148            for (j, value) in row.iter().enumerate() {
149                if j > 0 {
150                    insert.push_str(", ");
151                }
152                insert.push_str(&format_value(value, target_dialect));
153            }
154
155            insert.push(')');
156        }
157
158        insert.push(';');
159        inserts.push(insert.into_bytes());
160    }
161
162    inserts
163}
164
/// A single field parsed from a COPY data row.
///
/// Derives `PartialEq`/`Eq` so parsed rows can be compared directly
/// (for example in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CopyValue {
    /// The field was the NULL marker (`\N`).
    Null,
    /// Any other field, with escape sequences already decoded.
    Text(String),
}
171
172/// Parse COPY data block into rows of values
173pub fn parse_copy_data(data: &[u8]) -> Vec<Vec<CopyValue>> {
174    let mut rows = Vec::new();
175    let mut pos = 0;
176
177    while pos < data.len() {
178        // Find end of line
179        let line_end = data[pos..]
180            .iter()
181            .position(|&b| b == b'\n')
182            .map(|p| pos + p)
183            .unwrap_or(data.len());
184
185        let line = &data[pos..line_end];
186
187        // Check for terminator
188        if line == b"\\." || line.is_empty() {
189            pos = line_end + 1;
190            continue;
191        }
192
193        // Parse the row
194        let row = parse_row(line);
195        if !row.is_empty() {
196            rows.push(row);
197        }
198
199        pos = line_end + 1;
200    }
201
202    rows
203}
204
205/// Parse a single tab-separated row
206fn parse_row(line: &[u8]) -> Vec<CopyValue> {
207    let mut values = Vec::new();
208    let mut start = 0;
209
210    for (i, &b) in line.iter().enumerate() {
211        if b == b'\t' {
212            values.push(parse_value(&line[start..i]));
213            start = i + 1;
214        }
215    }
216    // Last value
217    if start <= line.len() {
218        values.push(parse_value(&line[start..]));
219    }
220
221    values
222}
223
224/// Parse a single COPY value
225fn parse_value(value: &[u8]) -> CopyValue {
226    // Check for NULL marker
227    if value == b"\\N" {
228        return CopyValue::Null;
229    }
230
231    // Decode escape sequences
232    let decoded = decode_escapes(value);
233    CopyValue::Text(decoded)
234}
235
/// Decode PostgreSQL COPY (text format) escape sequences in a single field.
///
/// Recognised sequences:
/// * `\n`, `\r`, `\t`, `\\`, `\b`, `\f`, `\v` - the usual control characters
/// * `\` followed by one to three octal digits - the byte with that value
///   (only the low 8 bits are kept, so `\777` yields byte `0xFF`)
///
/// Any other backslash sequence, and a lone trailing backslash, is kept
/// unchanged.
///
/// Decoding is done on bytes and the result is converted to a `String` once
/// at the end. This means octal escapes can spell out multi-byte UTF-8
/// characters (`\303\251` is "é"), valid characters are never damaged by an
/// unrelated invalid byte elsewhere in the field, and the input is validated
/// in a single pass. Byte sequences that are not valid UTF-8 are replaced
/// with U+FFFD.
fn decode_escapes(value: &[u8]) -> String {
    let mut bytes: Vec<u8> = Vec::with_capacity(value.len());
    let mut i = 0;

    while i < value.len() {
        let current = value[i];

        // Plain byte, or a backslash with nothing after it: copy through.
        if current != b'\\' || i + 1 >= value.len() {
            bytes.push(current);
            i += 1;
            continue;
        }

        let next = value[i + 1];
        let simple = match next {
            b'n' => Some(b'\n'),
            b'r' => Some(b'\r'),
            b't' => Some(b'\t'),
            b'\\' => Some(b'\\'),
            b'b' => Some(0x08), // backspace
            b'f' => Some(0x0C), // form feed
            b'v' => Some(0x0B), // vertical tab
            _ => None,
        };
        if let Some(byte) = simple {
            bytes.push(byte);
            i += 2;
            continue;
        }

        // Octal escape: up to three digits in 0..=7. Accumulate in a u32 so
        // that values above 0o377 cannot overflow, then keep the low byte.
        let mut code: u32 = 0;
        let mut digits = 0;
        for &d in value[i + 1..].iter().take(3) {
            if !(b'0'..=b'7').contains(&d) {
                break;
            }
            code = code * 8 + u32::from(d - b'0');
            digits += 1;
        }

        if digits > 0 {
            bytes.push((code & 0xFF) as u8);
            i += 1 + digits;
        } else {
            // Unknown escape: keep the backslash and the following byte as-is.
            bytes.push(b'\\');
            bytes.push(next);
            i += 2;
        }
    }

    // Fast path avoids a copy when the decoded bytes are valid UTF-8.
    match String::from_utf8(bytes) {
        Ok(text) => text,
        Err(err) => String::from_utf8_lossy(err.as_bytes()).into_owned(),
    }
}
310
311/// Format a value for SQL INSERT
312fn format_value(value: &CopyValue, dialect: crate::parser::SqlDialect) -> String {
313    match value {
314        CopyValue::Null => "NULL".to_string(),
315        CopyValue::Text(s) => {
316            // Escape quotes based on dialect
317            let escaped = match dialect {
318                crate::parser::SqlDialect::MySql => {
319                    // MySQL: escape with backslash
320                    s.replace('\\', "\\\\")
321                        .replace('\'', "\\'")
322                        .replace('\n', "\\n")
323                        .replace('\r', "\\r")
324                        .replace('\t', "\\t")
325                        .replace('\0', "\\0")
326                }
327                _ => {
328                    // PostgreSQL/SQLite: escape by doubling
329                    s.replace('\'', "''")
330                }
331            };
332            format!("'{}'", escaped)
333        }
334    }
335}