sql_splitter/convert/
copy_to_insert.rs1use once_cell::sync::Lazy;
10use regex::Regex;
11
/// Upper bound on the number of data rows placed into one generated
/// `INSERT` statement; `copy_to_inserts` chunks the parsed rows by this size.
12const MAX_ROWS_PER_INSERT: usize = 1000;
14
/// Target of a `COPY ... FROM stdin` statement as extracted by
/// `parse_copy_header`.
///
/// Derives `PartialEq`/`Eq` so that parsed headers can be compared directly
/// (for example in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CopyHeader {
    /// Schema qualifier, if the statement named one (`schema.table`).
    pub schema: Option<String>,
    /// Unquoted table name.
    pub table: String,
    /// Unquoted column names in statement order; empty when the statement
    /// carried no explicit column list.
    pub columns: Vec<String>,
}
25
26pub fn parse_copy_header(stmt: &str) -> Option<CopyHeader> {
29 let stmt = strip_leading_comments(stmt);
31
32 static RE_COPY: Lazy<Regex> = Lazy::new(|| {
33 Regex::new(
36 r#"(?i)^\s*COPY\s+(?:ONLY\s+)?(?:"?(\w+)"?\.)?["]?(\w+)["]?\s*(?:\(([^)]+)\))?\s+FROM\s+stdin"#
37 ).unwrap()
38 });
39
40 let caps = RE_COPY.captures(&stmt)?;
41
42 let schema = caps.get(1).map(|m| m.as_str().to_string());
43 let table = caps.get(2)?.as_str().to_string();
44 let columns = caps
45 .get(3)
46 .map(|m| {
47 m.as_str()
48 .split(',')
49 .map(|c| c.trim().trim_matches('"').trim_matches('`').to_string())
50 .collect()
51 })
52 .unwrap_or_default();
53
54 Some(CopyHeader {
55 schema,
56 table,
57 columns,
58 })
59}
60
/// Returns `stmt` with surrounding whitespace and any leading `--` line
/// comments or `/* ... */` block comments removed.
///
/// If a leading comment is never terminated (no newline after `--`, no `*/`
/// after `/*`), the whole input is treated as comment and an empty string is
/// returned.
fn strip_leading_comments(stmt: &str) -> String {
    let mut rest = stmt.trim();

    loop {
        // Decide which token closes the comment we are looking at, if any.
        let terminator = if rest.starts_with("--") {
            "\n"
        } else if rest.starts_with("/*") {
            "*/"
        } else {
            return rest.to_string();
        };

        match rest.split_once(terminator) {
            Some((_comment, tail)) => rest = tail.trim(),
            None => return String::new(),
        }
    }
}
85
86pub fn copy_to_inserts(
96 header: &CopyHeader,
97 data: &[u8],
98 target_dialect: crate::parser::SqlDialect,
99) -> Vec<Vec<u8>> {
100 let mut inserts = Vec::new();
101 let rows = parse_copy_data(data);
102
103 if rows.is_empty() {
104 return inserts;
105 }
106
107 let quote_char = match target_dialect {
109 crate::parser::SqlDialect::MySql => '`',
110 _ => '"',
111 };
112
113 let table_ref = if let Some(ref schema) = header.schema {
114 if target_dialect == crate::parser::SqlDialect::MySql {
115 format!("{}{}{}", quote_char, header.table, quote_char)
117 } else if schema == "public" || schema == "pg_catalog" {
118 format!("{}{}{}", quote_char, header.table, quote_char)
120 } else {
121 format!(
122 "{}{}{}.{}{}{}",
123 quote_char, schema, quote_char, quote_char, header.table, quote_char
124 )
125 }
126 } else {
127 format!("{}{}{}", quote_char, header.table, quote_char)
128 };
129
130 let columns_str = if header.columns.is_empty() {
131 String::new()
132 } else {
133 let cols: Vec<String> = header
134 .columns
135 .iter()
136 .map(|c| format!("{}{}{}", quote_char, c, quote_char))
137 .collect();
138 format!(" ({})", cols.join(", "))
139 };
140
141 for chunk in rows.chunks(MAX_ROWS_PER_INSERT) {
143 let mut insert = format!("INSERT INTO {}{} VALUES\n", table_ref, columns_str);
144
145 for (i, row) in chunk.iter().enumerate() {
146 if i > 0 {
147 insert.push_str(",\n");
148 }
149 insert.push('(');
150
151 for (j, value) in row.iter().enumerate() {
152 if j > 0 {
153 insert.push_str(", ");
154 }
155 insert.push_str(&format_value(value, target_dialect));
156 }
157
158 insert.push(')');
159 }
160
161 insert.push(';');
162 inserts.push(insert.into_bytes());
163 }
164
165 inserts
166}
167
/// A single field decoded from a COPY text-format row.
///
/// Derives `PartialEq`/`Eq` so that parsed rows can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CopyValue {
    /// The field was the `\N` marker.
    Null,
    /// Any other field, with backslash escapes already decoded.
    Text(String),
}
174
175pub fn parse_copy_data(data: &[u8]) -> Vec<Vec<CopyValue>> {
177 let mut rows = Vec::new();
178 let mut pos = 0;
179
180 while pos < data.len() {
181 let line_end = data[pos..]
183 .iter()
184 .position(|&b| b == b'\n')
185 .map(|p| pos + p)
186 .unwrap_or(data.len());
187
188 let line = &data[pos..line_end];
189
190 if line == b"\\." || line.is_empty() {
192 pos = line_end + 1;
193 continue;
194 }
195
196 let row = parse_row(line);
198 if !row.is_empty() {
199 rows.push(row);
200 }
201
202 pos = line_end + 1;
203 }
204
205 rows
206}
207
208fn parse_row(line: &[u8]) -> Vec<CopyValue> {
210 let mut values = Vec::new();
211 let mut start = 0;
212
213 for (i, &b) in line.iter().enumerate() {
214 if b == b'\t' {
215 values.push(parse_value(&line[start..i]));
216 start = i + 1;
217 }
218 }
219 if start <= line.len() {
221 values.push(parse_value(&line[start..]));
222 }
223
224 values
225}
226
227fn parse_value(value: &[u8]) -> CopyValue {
229 if value == b"\\N" {
231 return CopyValue::Null;
232 }
233
234 let decoded = decode_escapes(value);
236 CopyValue::Text(decoded)
237}
238
/// Decodes the backslash escapes of a COPY text-format field into a `String`.
///
/// Recognised escapes: `\n`, `\r`, `\t`, `\\`, `\b`, `\f`, `\v`, and `\` followed
/// by one to three octal digits. An octal value is reduced to its low 8 bits
/// (so `\777` no longer overflows the accumulator) and pushed as the character
/// with that code point. Any other backslash sequence, and a lone trailing
/// backslash, is kept verbatim.
///
/// Bytes outside escapes are interpreted as UTF-8. Each byte that is not part
/// of a valid sequence becomes one U+FFFD; valid characters are preserved even
/// when invalid bytes appear later in the same field.
fn decode_escapes(value: &[u8]) -> String {
    let mut result = String::with_capacity(value.len());
    let mut i = 0;

    while i < value.len() {
        if value[i] == b'\\' && i + 1 < value.len() {
            let next = value[i + 1];

            let simple = match next {
                b'n' => Some('\n'),
                b'r' => Some('\r'),
                b't' => Some('\t'),
                b'\\' => Some('\\'),
                b'b' => Some('\x08'),
                b'f' => Some('\x0C'),
                b'v' => Some('\x0B'),
                _ => None,
            };
            if let Some(ch) = simple {
                result.push(ch);
                i += 2;
                continue;
            }

            // Up to three octal digits directly after the backslash.
            let digits = value[i + 1..]
                .iter()
                .take(3)
                .take_while(|&&d| (b'0'..=b'7').contains(&d))
                .count();
            if digits > 0 {
                // Accumulate in u32: three octal digits reach 0o777 = 511,
                // which does not fit in a u8.
                let code = value[i + 1..i + 1 + digits]
                    .iter()
                    .fold(0u32, |acc, &d| acc * 8 + u32::from(d - b'0'));
                // Intentional truncation: only the low 8 bits are kept.
                result.push(char::from((code & 0xFF) as u8));
                i += 1 + digits;
                continue;
            }

            // Unknown escape: keep the backslash and let the following byte
            // go through the literal path, so a multibyte character after it
            // is decoded as UTF-8 instead of being cast byte-wise.
            result.push('\\');
            i += 1;
            continue;
        }

        // Literal run: everything up to (not including) the next backslash.
        // The search starts at i + 1 so the run is never empty, which also
        // covers a lone trailing backslash.
        let run_len = value[i + 1..]
            .iter()
            .position(|&b| b == b'\\')
            .map_or(value.len() - i, |p| p + 1);
        push_utf8_lossy(&mut result, &value[i..i + run_len]);
        i += run_len;
    }

    result
}

/// Appends `bytes` to `out` as UTF-8, replacing each undecodable byte with
/// one U+FFFD and resuming decoding right after it.
fn push_utf8_lossy(out: &mut String, mut bytes: &[u8]) {
    loop {
        match std::str::from_utf8(bytes) {
            Ok(text) => {
                out.push_str(text);
                return;
            }
            Err(err) => {
                let (valid, rest) = bytes.split_at(err.valid_up_to());
                if let Ok(text) = std::str::from_utf8(valid) {
                    out.push_str(text);
                }
                out.push('\u{FFFD}');
                // `rest` is non-empty: the error lies at `valid_up_to()`,
                // which is strictly inside `bytes`.
                bytes = &rest[1..];
            }
        }
    }
}
313
314fn format_value(value: &CopyValue, dialect: crate::parser::SqlDialect) -> String {
316 match value {
317 CopyValue::Null => "NULL".to_string(),
318 CopyValue::Text(s) => {
319 let escaped = match dialect {
321 crate::parser::SqlDialect::MySql => {
322 s.replace('\\', "\\\\")
324 .replace('\'', "\\'")
325 .replace('\n', "\\n")
326 .replace('\r', "\\r")
327 .replace('\t', "\\t")
328 .replace('\0', "\\0")
329 }
330 _ => {
331 s.replace('\'', "''")
333 }
334 };
335 format!("'{}'", escaped)
336 }
337 }
338}