sql_splitter/convert/
copy_to_insert.rs1use once_cell::sync::Lazy;
10use regex::Regex;
11
12const MAX_ROWS_PER_INSERT: usize = 100;
14
/// Parsed header of a `COPY ... FROM stdin` statement.
///
/// Derives `PartialEq`/`Eq` so parsed headers can be compared directly,
/// e.g. in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CopyHeader {
    /// Schema qualifier of the target table, if the statement had one.
    pub schema: Option<String>,
    /// Target table name.
    pub table: String,
    /// Column names from the optional column list; empty when the list is
    /// absent.
    pub columns: Vec<String>,
}
25
26pub fn parse_copy_header(stmt: &str) -> Option<CopyHeader> {
29 let stmt = strip_leading_comments(stmt);
31
32 static RE_COPY: Lazy<Regex> = Lazy::new(|| {
33 Regex::new(
36 r#"(?i)^\s*COPY\s+(?:ONLY\s+)?(?:"?(\w+)"?\.)?["]?(\w+)["]?\s*(?:\(([^)]+)\))?\s+FROM\s+stdin"#
37 ).unwrap()
38 });
39
40 let caps = RE_COPY.captures(&stmt)?;
41
42 let schema = caps.get(1).map(|m| m.as_str().to_string());
43 let table = caps.get(2)?.as_str().to_string();
44 let columns = caps
45 .get(3)
46 .map(|m| {
47 m.as_str()
48 .split(',')
49 .map(|c| c.trim().trim_matches('"').trim_matches('`').to_string())
50 .collect()
51 })
52 .unwrap_or_default();
53
54 Some(CopyHeader {
55 schema,
56 table,
57 columns,
58 })
59}
60
/// Removes leading whitespace and SQL comments from `stmt`.
///
/// Both `-- line` comments and `/* block */` comments are skipped, repeatedly,
/// until the text starts with something else. The result is trimmed on both
/// ends. If a leading comment is never terminated (a `--` comment without a
/// following newline, or a `/*` without a closing `*/`), nothing but comment
/// text remains and an empty string is returned.
fn strip_leading_comments(stmt: &str) -> String {
    let mut rest = stmt.trim();
    loop {
        let tail = if rest.starts_with("--") {
            // A line comment runs up to and including the next newline.
            rest.split_once('\n').map(|(_, tail)| tail)
        } else if let Some(body) = rest.strip_prefix("/*") {
            // Search for the terminator only *after* the opener; otherwise
            // `/*/` would be treated as a complete comment because the `*`
            // of the opener would be reused as the `*` of the closer.
            body.split_once("*/").map(|(_, tail)| tail)
        } else {
            return rest.to_string();
        };

        match tail {
            Some(tail) => rest = tail.trim(),
            // Unterminated comment: everything left is comment text.
            None => return String::new(),
        }
    }
}
85
86pub fn copy_to_inserts(
96 header: &CopyHeader,
97 data: &[u8],
98 target_dialect: crate::parser::SqlDialect,
99) -> Vec<Vec<u8>> {
100 let mut inserts = Vec::new();
101 let rows = parse_copy_data(data);
102
103 if rows.is_empty() {
104 return inserts;
105 }
106
107 let quote_char = match target_dialect {
109 crate::parser::SqlDialect::MySql => '`',
110 _ => '"',
111 };
112
113 let table_ref = if let Some(ref schema) = header.schema {
114 if target_dialect == crate::parser::SqlDialect::MySql {
115 format!("{}{}{}", quote_char, header.table, quote_char)
117 } else {
118 format!(
119 "{}{}{}.{}{}{}",
120 quote_char, schema, quote_char, quote_char, header.table, quote_char
121 )
122 }
123 } else {
124 format!("{}{}{}", quote_char, header.table, quote_char)
125 };
126
127 let columns_str = if header.columns.is_empty() {
128 String::new()
129 } else {
130 let cols: Vec<String> = header
131 .columns
132 .iter()
133 .map(|c| format!("{}{}{}", quote_char, c, quote_char))
134 .collect();
135 format!(" ({})", cols.join(", "))
136 };
137
138 for chunk in rows.chunks(MAX_ROWS_PER_INSERT) {
140 let mut insert = format!("INSERT INTO {}{} VALUES\n", table_ref, columns_str);
141
142 for (i, row) in chunk.iter().enumerate() {
143 if i > 0 {
144 insert.push_str(",\n");
145 }
146 insert.push('(');
147
148 for (j, value) in row.iter().enumerate() {
149 if j > 0 {
150 insert.push_str(", ");
151 }
152 insert.push_str(&format_value(value, target_dialect));
153 }
154
155 insert.push(')');
156 }
157
158 insert.push(';');
159 inserts.push(insert.into_bytes());
160 }
161
162 inserts
163}
164
/// A single field value from a COPY data row.
///
/// Derives `PartialEq`/`Eq` so that parsed rows can be compared directly;
/// in particular `Null` and `Text(String::new())` are distinct values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CopyValue {
    /// SQL NULL.
    Null,
    /// A text value.
    Text(String),
}
171
172pub fn parse_copy_data(data: &[u8]) -> Vec<Vec<CopyValue>> {
174 let mut rows = Vec::new();
175 let mut pos = 0;
176
177 while pos < data.len() {
178 let line_end = data[pos..]
180 .iter()
181 .position(|&b| b == b'\n')
182 .map(|p| pos + p)
183 .unwrap_or(data.len());
184
185 let line = &data[pos..line_end];
186
187 if line == b"\\." || line.is_empty() {
189 pos = line_end + 1;
190 continue;
191 }
192
193 let row = parse_row(line);
195 if !row.is_empty() {
196 rows.push(row);
197 }
198
199 pos = line_end + 1;
200 }
201
202 rows
203}
204
205fn parse_row(line: &[u8]) -> Vec<CopyValue> {
207 let mut values = Vec::new();
208 let mut start = 0;
209
210 for (i, &b) in line.iter().enumerate() {
211 if b == b'\t' {
212 values.push(parse_value(&line[start..i]));
213 start = i + 1;
214 }
215 }
216 if start <= line.len() {
218 values.push(parse_value(&line[start..]));
219 }
220
221 values
222}
223
224fn parse_value(value: &[u8]) -> CopyValue {
226 if value == b"\\N" {
228 return CopyValue::Null;
229 }
230
231 let decoded = decode_escapes(value);
233 CopyValue::Text(decoded)
234}
235
/// Decodes the backslash escapes of a COPY text-format field into a `String`.
///
/// Recognised escapes:
///
/// * `\n`, `\r`, `\t`, `\\`, `\b`, `\f`, `\v` — the corresponding control or
///   backslash character;
/// * `\` followed by one to three octal digits — the byte with that code
///   (only the low 8 bits are kept, so `\777` yields `0xFF`).
///
/// Any other backslash sequence, and a lone trailing backslash, is kept
/// verbatim. Decoding happens at the byte level and the result is converted
/// to UTF-8 once at the end, so:
///
/// * escaped bytes and raw bytes combine correctly (e.g. `\303\251` is `é`);
/// * an invalid byte is replaced by U+FFFD without affecting valid
///   characters around it;
/// * the input is validated once, not once per non-ASCII character.
fn decode_escapes(value: &[u8]) -> String {
    let mut bytes: Vec<u8> = Vec::with_capacity(value.len());
    let mut i = 0;

    while i < value.len() {
        let current = value[i];

        // Plain byte, or a backslash with nothing after it.
        if current != b'\\' || i + 1 >= value.len() {
            bytes.push(current);
            i += 1;
            continue;
        }

        let next = value[i + 1];
        let simple = match next {
            b'n' => Some(b'\n'),
            b'r' => Some(b'\r'),
            b't' => Some(b'\t'),
            b'\\' => Some(b'\\'),
            b'b' => Some(0x08),
            b'f' => Some(0x0C),
            b'v' => Some(0x0B),
            _ => None,
        };
        if let Some(decoded) = simple {
            bytes.push(decoded);
            i += 2;
            continue;
        }

        // Up to three octal digits directly after the backslash.
        let digit_count = value[i + 1..]
            .iter()
            .take(3)
            .take_while(|d| (b'0'..=b'7').contains(*d))
            .count();

        if digit_count > 0 {
            // Accumulate in u32: three octal digits reach 511, which does not
            // fit in a u8. Only the low byte is kept.
            let code = value[i + 1..i + 1 + digit_count]
                .iter()
                .fold(0u32, |acc, d| acc * 8 + u32::from(d - b'0'));
            bytes.push((code & 0xFF) as u8);
            i += 1 + digit_count;
        } else {
            // Unknown escape: keep both bytes as they are. Pushing the raw
            // byte (rather than a char) keeps a following multi-byte UTF-8
            // character intact.
            bytes.push(b'\\');
            bytes.push(next);
            i += 2;
        }
    }

    String::from_utf8_lossy(&bytes).into_owned()
}
310
311fn format_value(value: &CopyValue, dialect: crate::parser::SqlDialect) -> String {
313 match value {
314 CopyValue::Null => "NULL".to_string(),
315 CopyValue::Text(s) => {
316 let escaped = match dialect {
318 crate::parser::SqlDialect::MySql => {
319 s.replace('\\', "\\\\")
321 .replace('\'', "\\'")
322 .replace('\n', "\\n")
323 .replace('\r', "\\r")
324 .replace('\t', "\\t")
325 .replace('\0', "\\0")
326 }
327 _ => {
328 s.replace('\'', "''")
330 }
331 };
332 format!("'{}'", escaped)
333 }
334 }
335}