sql_splitter/parser/
mysql_insert.rs1use crate::schema::{ColumnId, ColumnType, TableSchema};
7use ahash::AHashSet;
8use smallvec::SmallVec;
9
/// One component of a primary-key or foreign-key tuple extracted from an
/// INSERT row.
///
/// The type is hashable and comparable so tuples of it can be stored in sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PkValue {
    /// An integer held as `i64`.
    Int(i64),
    /// An integer held as `i128`.
    BigInt(i128),
    /// Any value kept in textual form.
    Text(Box<str>),
    /// SQL `NULL`.
    Null,
}
22
23impl PkValue {
24 pub fn is_null(&self) -> bool {
26 matches!(self, PkValue::Null)
27 }
28}
29
30pub type PkTuple = SmallVec<[PkValue; 2]>;
32
33pub type PkSet = AHashSet<PkTuple>;
35
/// Identifies one foreign key by table id and position.
///
/// The parser in this file fills `table_id` with the id of the schema the row
/// is parsed against and `fk_index` with the position of the foreign key in
/// that schema's `foreign_keys` list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FkRef {
    /// Numeric id of the table.
    pub table_id: u32,
    /// Index of the foreign key within its table's foreign-key list.
    pub fk_index: u16,
}
44
45#[derive(Debug, Clone)]
47pub struct ParsedRow {
48 pub raw: Vec<u8>,
50 pub pk: Option<PkTuple>,
52 pub fk_values: Vec<(FkRef, PkTuple)>,
55}
56
57pub struct InsertParser<'a> {
59 stmt: &'a [u8],
60 pos: usize,
61 table_schema: Option<&'a TableSchema>,
62 column_order: Vec<Option<ColumnId>>,
64}
65
66impl<'a> InsertParser<'a> {
67 pub fn new(stmt: &'a [u8]) -> Self {
69 Self {
70 stmt,
71 pos: 0,
72 table_schema: None,
73 column_order: Vec::new(),
74 }
75 }
76
77 pub fn with_schema(mut self, schema: &'a TableSchema) -> Self {
79 self.table_schema = Some(schema);
80 self
81 }
82
83 pub fn parse_rows(&mut self) -> anyhow::Result<Vec<ParsedRow>> {
85 let values_pos = self.find_values_keyword()?;
87 self.pos = values_pos;
88
89 self.parse_column_list();
91
92 let mut rows = Vec::new();
94 while self.pos < self.stmt.len() {
95 self.skip_whitespace();
96
97 if self.pos >= self.stmt.len() {
98 break;
99 }
100
101 if self.stmt[self.pos] == b'(' {
102 if let Some(row) = self.parse_row()? {
103 rows.push(row);
104 }
105 } else if self.stmt[self.pos] == b',' {
106 self.pos += 1;
107 } else if self.stmt[self.pos] == b';' {
108 break;
109 } else {
110 self.pos += 1;
111 }
112 }
113
114 Ok(rows)
115 }
116
117 fn find_values_keyword(&self) -> anyhow::Result<usize> {
119 let stmt_str = String::from_utf8_lossy(self.stmt);
120 let upper = stmt_str.to_uppercase();
121
122 if let Some(pos) = upper.find("VALUES") {
123 Ok(pos + 6) } else {
125 anyhow::bail!("INSERT statement missing VALUES keyword")
126 }
127 }
128
129 fn parse_column_list(&mut self) {
131 if self.table_schema.is_none() {
132 return;
133 }
134
135 let schema = self.table_schema.unwrap();
136
137 let before_values = &self.stmt[..self.pos.saturating_sub(6)];
140 let stmt_str = String::from_utf8_lossy(before_values);
141
142 if let Some(close_paren) = stmt_str.rfind(')') {
144 if let Some(open_paren) = stmt_str[..close_paren].rfind('(') {
145 let col_list = &stmt_str[open_paren + 1..close_paren];
146 if !col_list.to_uppercase().contains("SELECT") {
148 let cols: Vec<&str> = col_list.split(',').collect();
149 self.column_order = cols
150 .iter()
151 .map(|c| {
152 let name = c.trim().trim_matches('`').trim_matches('"');
153 schema.get_column_id(name)
154 })
155 .collect();
156 return;
157 }
158 }
159 }
160
161 self.column_order = schema.columns.iter().map(|c| Some(c.ordinal)).collect();
163 }
164
165 fn parse_row(&mut self) -> anyhow::Result<Option<ParsedRow>> {
167 self.skip_whitespace();
168
169 if self.pos >= self.stmt.len() || self.stmt[self.pos] != b'(' {
170 return Ok(None);
171 }
172
173 let start = self.pos;
174 self.pos += 1; let mut values: Vec<ParsedValue> = Vec::new();
177 let mut depth = 1;
178
179 while self.pos < self.stmt.len() && depth > 0 {
180 self.skip_whitespace();
181
182 if self.pos >= self.stmt.len() {
183 break;
184 }
185
186 match self.stmt[self.pos] {
187 b'(' => {
188 depth += 1;
189 self.pos += 1;
190 }
191 b')' => {
192 depth -= 1;
193 self.pos += 1;
194 }
195 b',' if depth == 1 => {
196 self.pos += 1;
197 }
198 _ if depth == 1 => {
199 values.push(self.parse_value()?);
200 }
201 _ => {
202 self.pos += 1;
203 }
204 }
205 }
206
207 let end = self.pos;
208 let raw = self.stmt[start..end].to_vec();
209
210 let (pk, fk_values) = if let Some(schema) = self.table_schema {
212 self.extract_pk_fk(&values, schema)
213 } else {
214 (None, Vec::new())
215 };
216
217 Ok(Some(ParsedRow { raw, pk, fk_values }))
218 }
219
220 fn parse_value(&mut self) -> anyhow::Result<ParsedValue> {
222 self.skip_whitespace();
223
224 if self.pos >= self.stmt.len() {
225 return Ok(ParsedValue::Null);
226 }
227
228 let b = self.stmt[self.pos];
229
230 if self.pos + 4 <= self.stmt.len() {
232 let word = &self.stmt[self.pos..self.pos + 4];
233 if word.eq_ignore_ascii_case(b"NULL") {
234 self.pos += 4;
235 return Ok(ParsedValue::Null);
236 }
237 }
238
239 if b == b'\'' {
241 return self.parse_string_value();
242 }
243
244 if b == b'0' && self.pos + 1 < self.stmt.len() {
246 let next = self.stmt[self.pos + 1];
247 if next == b'x' || next == b'X' {
248 return self.parse_hex_value();
249 }
250 }
251
252 self.parse_number_value()
254 }
255
256 fn parse_string_value(&mut self) -> anyhow::Result<ParsedValue> {
258 self.pos += 1; let mut value = Vec::new();
261 let mut escape_next = false;
262
263 while self.pos < self.stmt.len() {
264 let b = self.stmt[self.pos];
265
266 if escape_next {
267 let escaped = match b {
269 b'n' => b'\n',
270 b'r' => b'\r',
271 b't' => b'\t',
272 b'0' => 0,
273 _ => b, };
275 value.push(escaped);
276 escape_next = false;
277 self.pos += 1;
278 } else if b == b'\\' {
279 escape_next = true;
280 self.pos += 1;
281 } else if b == b'\'' {
282 if self.pos + 1 < self.stmt.len() && self.stmt[self.pos + 1] == b'\'' {
284 value.push(b'\'');
285 self.pos += 2;
286 } else {
287 self.pos += 1; break;
289 }
290 } else {
291 value.push(b);
292 self.pos += 1;
293 }
294 }
295
296 let text = String::from_utf8_lossy(&value).into_owned();
297
298 Ok(ParsedValue::String { value: text })
299 }
300
301 fn parse_hex_value(&mut self) -> anyhow::Result<ParsedValue> {
303 let start = self.pos;
304 self.pos += 2; while self.pos < self.stmt.len() {
307 let b = self.stmt[self.pos];
308 if b.is_ascii_hexdigit() {
309 self.pos += 1;
310 } else {
311 break;
312 }
313 }
314
315 let raw = self.stmt[start..self.pos].to_vec();
316 Ok(ParsedValue::Hex(raw))
317 }
318
319 fn parse_number_value(&mut self) -> anyhow::Result<ParsedValue> {
321 let start = self.pos;
322 let mut has_dot = false;
323
324 if self.pos < self.stmt.len() && self.stmt[self.pos] == b'-' {
326 self.pos += 1;
327 }
328
329 while self.pos < self.stmt.len() {
330 let b = self.stmt[self.pos];
331 if b.is_ascii_digit() {
332 self.pos += 1;
333 } else if b == b'.' && !has_dot {
334 has_dot = true;
335 self.pos += 1;
336 } else if b == b'e' || b == b'E' {
337 self.pos += 1;
339 if self.pos < self.stmt.len()
340 && (self.stmt[self.pos] == b'+' || self.stmt[self.pos] == b'-')
341 {
342 self.pos += 1;
343 }
344 } else if b == b',' || b == b')' || b.is_ascii_whitespace() {
345 break;
346 } else {
347 while self.pos < self.stmt.len() {
349 let c = self.stmt[self.pos];
350 if c == b',' || c == b')' {
351 break;
352 }
353 self.pos += 1;
354 }
355 break;
356 }
357 }
358
359 let raw = self.stmt[start..self.pos].to_vec();
360 let value_str = String::from_utf8_lossy(&raw);
361
362 if !has_dot {
364 if let Ok(n) = value_str.parse::<i64>() {
365 return Ok(ParsedValue::Integer(n));
366 }
367 if let Ok(n) = value_str.parse::<i128>() {
368 return Ok(ParsedValue::BigInteger(n));
369 }
370 }
371
372 Ok(ParsedValue::Other(raw))
374 }
375
376 fn skip_whitespace(&mut self) {
378 while self.pos < self.stmt.len() {
379 let b = self.stmt[self.pos];
380 if b.is_ascii_whitespace() {
381 self.pos += 1;
382 } else {
383 break;
384 }
385 }
386 }
387
388 fn extract_pk_fk(
390 &self,
391 values: &[ParsedValue],
392 schema: &TableSchema,
393 ) -> (Option<PkTuple>, Vec<(FkRef, PkTuple)>) {
394 let mut pk_values = PkTuple::new();
395 let mut fk_values = Vec::new();
396
397 for (idx, col_id_opt) in self.column_order.iter().enumerate() {
399 if let Some(col_id) = col_id_opt {
400 if schema.is_pk_column(*col_id) {
401 if let Some(value) = values.get(idx) {
402 let pk_val = self.value_to_pk(value, schema.column(*col_id));
403 pk_values.push(pk_val);
404 }
405 }
406 }
407 }
408
409 for (fk_idx, fk) in schema.foreign_keys.iter().enumerate() {
411 if fk.referenced_table_id.is_none() {
412 continue;
413 }
414
415 let mut fk_tuple = PkTuple::new();
416 let mut all_non_null = true;
417
418 for &col_id in &fk.columns {
419 if let Some(idx) = self.column_order.iter().position(|&c| c == Some(col_id)) {
421 if let Some(value) = values.get(idx) {
422 let pk_val = self.value_to_pk(value, schema.column(col_id));
423 if pk_val.is_null() {
424 all_non_null = false;
425 break;
426 }
427 fk_tuple.push(pk_val);
428 }
429 }
430 }
431
432 if all_non_null && !fk_tuple.is_empty() {
433 fk_values.push((
434 FkRef {
435 table_id: schema.id.0,
436 fk_index: fk_idx as u16,
437 },
438 fk_tuple,
439 ));
440 }
441 }
442
443 let pk = if pk_values.is_empty() || pk_values.iter().any(|v| v.is_null()) {
444 None
445 } else {
446 Some(pk_values)
447 };
448
449 (pk, fk_values)
450 }
451
452 fn value_to_pk(&self, value: &ParsedValue, col: Option<&crate::schema::Column>) -> PkValue {
454 match value {
455 ParsedValue::Null => PkValue::Null,
456 ParsedValue::Integer(n) => PkValue::Int(*n),
457 ParsedValue::BigInteger(n) => PkValue::BigInt(*n),
458 ParsedValue::String { value } => {
459 if let Some(col) = col {
461 match col.col_type {
462 ColumnType::Int => {
463 if let Ok(n) = value.parse::<i64>() {
464 return PkValue::Int(n);
465 }
466 }
467 ColumnType::BigInt => {
468 if let Ok(n) = value.parse::<i128>() {
469 return PkValue::BigInt(n);
470 }
471 }
472 _ => {}
473 }
474 }
475 PkValue::Text(value.clone().into_boxed_str())
476 }
477 ParsedValue::Hex(raw) => {
478 PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
479 }
480 ParsedValue::Other(raw) => {
481 PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
482 }
483 }
484 }
485}
486
/// A single value lexed from a row tuple, before conversion to a key value.
#[derive(Debug, Clone)]
enum ParsedValue {
    /// The `NULL` keyword.
    Null,
    /// An integer literal that fits in `i64`.
    Integer(i64),
    /// An integer literal that only fits in `i128`.
    BigInteger(i128),
    /// A quoted string literal, already unescaped.
    String { value: String },
    /// A `0x...` hex literal, kept as its source bytes.
    Hex(Vec<u8>),
    /// Any other value, kept as its source bytes.
    Other(Vec<u8>),
}
497
498pub fn parse_mysql_insert_rows(
500 stmt: &[u8],
501 schema: &TableSchema,
502) -> anyhow::Result<Vec<ParsedRow>> {
503 let mut parser = InsertParser::new(stmt).with_schema(schema);
504 parser.parse_rows()
505}
506
507pub fn parse_mysql_insert_rows_raw(stmt: &[u8]) -> anyhow::Result<Vec<ParsedRow>> {
509 let mut parser = InsertParser::new(stmt);
510 parser.parse_rows()
511}