sql_splitter/parser/
mysql_insert.rs1use crate::schema::{ColumnId, ColumnType, TableSchema};
7use ahash::AHashSet;
8use smallvec::SmallVec;
9
/// One component of a primary-key or foreign-key tuple extracted from an `INSERT` value.
///
/// Variants are compared structurally by the derived `PartialEq`, so `Int(1)` and
/// `BigInt(1)` are *different* keys.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum PkValue {
    /// Integer key stored as `i64`.
    Int(i64),
    /// Integer key stored as `i128` (e.g. one too large for `i64`).
    BigInt(i128),
    /// Any other non-NULL key (strings, hex literals, other raw tokens) as text.
    Text(Box<str>),
    /// SQL `NULL`.
    Null,
}
22
23impl PkValue {
24 pub fn is_null(&self) -> bool {
26 matches!(self, PkValue::Null)
27 }
28}
29
30pub type PkTuple = SmallVec<[PkValue; 2]>;
32
33pub type PkSet = AHashSet<PkTuple>;
35
36pub type PkHashSet = AHashSet<u64>;
40
41pub fn hash_pk_tuple(pk: &PkTuple) -> u64 {
44 use std::hash::{Hash, Hasher};
45 let mut hasher = ahash::AHasher::default();
46
47 (pk.len() as u8).hash(&mut hasher);
49
50 for v in pk {
51 match v {
52 PkValue::Int(i) => {
53 0u8.hash(&mut hasher);
54 i.hash(&mut hasher);
55 }
56 PkValue::BigInt(i) => {
57 1u8.hash(&mut hasher);
58 i.hash(&mut hasher);
59 }
60 PkValue::Text(s) => {
61 2u8.hash(&mut hasher);
62 s.hash(&mut hasher);
63 }
64 PkValue::Null => {
65 3u8.hash(&mut hasher);
66 }
67 }
68 }
69
70 hasher.finish()
71}
72
/// Identifies one foreign key: the table declaring it and its position in that table's
/// `foreign_keys` list.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct FkRef {
    /// Id of the table that declares the foreign key.
    pub table_id: u32,
    /// Index of the foreign key within the declaring table's `foreign_keys`.
    pub fk_index: u16,
}
81
82#[derive(Debug, Clone)]
84pub struct ParsedRow {
85 pub raw: Vec<u8>,
87 pub pk: Option<PkTuple>,
89 pub fk_values: Vec<(FkRef, PkTuple)>,
92 pub all_values: Vec<PkValue>,
94}
95
96pub struct InsertParser<'a> {
98 stmt: &'a [u8],
99 pos: usize,
100 table_schema: Option<&'a TableSchema>,
101 column_order: Vec<Option<ColumnId>>,
103}
104
105impl<'a> InsertParser<'a> {
106 pub fn new(stmt: &'a [u8]) -> Self {
108 Self {
109 stmt,
110 pos: 0,
111 table_schema: None,
112 column_order: Vec::new(),
113 }
114 }
115
116 pub fn with_schema(mut self, schema: &'a TableSchema) -> Self {
118 self.table_schema = Some(schema);
119 self
120 }
121
122 pub fn parse_rows(&mut self) -> anyhow::Result<Vec<ParsedRow>> {
124 let values_pos = self.find_values_keyword()?;
126 self.pos = values_pos;
127
128 self.parse_column_list();
130
131 let mut rows = Vec::new();
133 while self.pos < self.stmt.len() {
134 self.skip_whitespace();
135
136 if self.pos >= self.stmt.len() {
137 break;
138 }
139
140 if self.stmt[self.pos] == b'(' {
141 if let Some(row) = self.parse_row()? {
142 rows.push(row);
143 }
144 } else if self.stmt[self.pos] == b',' {
145 self.pos += 1;
146 } else if self.stmt[self.pos] == b';' {
147 break;
148 } else {
149 self.pos += 1;
150 }
151 }
152
153 Ok(rows)
154 }
155
156 fn find_values_keyword(&self) -> anyhow::Result<usize> {
158 let stmt_str = String::from_utf8_lossy(self.stmt);
159 let upper = stmt_str.to_uppercase();
160
161 if let Some(pos) = upper.find("VALUES") {
162 Ok(pos + 6) } else {
164 anyhow::bail!("INSERT statement missing VALUES keyword")
165 }
166 }
167
168 fn parse_column_list(&mut self) {
170 if self.table_schema.is_none() {
171 return;
172 }
173
174 let schema = self.table_schema.unwrap();
175
176 let before_values = &self.stmt[..self.pos.saturating_sub(6)];
179 let stmt_str = String::from_utf8_lossy(before_values);
180
181 if let Some(close_paren) = stmt_str.rfind(')') {
183 if let Some(open_paren) = stmt_str[..close_paren].rfind('(') {
184 let col_list = &stmt_str[open_paren + 1..close_paren];
185 if !col_list.to_uppercase().contains("SELECT") {
187 let cols: Vec<&str> = col_list.split(',').collect();
188 self.column_order = cols
189 .iter()
190 .map(|c| {
191 let name = c.trim().trim_matches('`').trim_matches('"');
192 schema.get_column_id(name)
193 })
194 .collect();
195 return;
196 }
197 }
198 }
199
200 self.column_order = schema.columns.iter().map(|c| Some(c.ordinal)).collect();
202 }
203
204 fn parse_row(&mut self) -> anyhow::Result<Option<ParsedRow>> {
206 self.skip_whitespace();
207
208 if self.pos >= self.stmt.len() || self.stmt[self.pos] != b'(' {
209 return Ok(None);
210 }
211
212 let start = self.pos;
213 self.pos += 1; let mut values: Vec<ParsedValue> = Vec::new();
216 let mut depth = 1;
217
218 while self.pos < self.stmt.len() && depth > 0 {
219 self.skip_whitespace();
220
221 if self.pos >= self.stmt.len() {
222 break;
223 }
224
225 match self.stmt[self.pos] {
226 b'(' => {
227 depth += 1;
228 self.pos += 1;
229 }
230 b')' => {
231 depth -= 1;
232 self.pos += 1;
233 }
234 b',' if depth == 1 => {
235 self.pos += 1;
236 }
237 _ if depth == 1 => {
238 values.push(self.parse_value()?);
239 }
240 _ => {
241 self.pos += 1;
242 }
243 }
244 }
245
246 let end = self.pos;
247 let raw = self.stmt[start..end].to_vec();
248
249 let (pk, fk_values, all_values) = if let Some(schema) = self.table_schema {
251 self.extract_pk_fk(&values, schema)
252 } else {
253 (None, Vec::new(), Vec::new())
254 };
255
256 Ok(Some(ParsedRow {
257 raw,
258 pk,
259 fk_values,
260 all_values,
261 }))
262 }
263
264 fn parse_value(&mut self) -> anyhow::Result<ParsedValue> {
266 self.skip_whitespace();
267
268 if self.pos >= self.stmt.len() {
269 return Ok(ParsedValue::Null);
270 }
271
272 let b = self.stmt[self.pos];
273
274 if self.pos + 4 <= self.stmt.len() {
276 let word = &self.stmt[self.pos..self.pos + 4];
277 if word.eq_ignore_ascii_case(b"NULL") {
278 self.pos += 4;
279 return Ok(ParsedValue::Null);
280 }
281 }
282
283 if b == b'\'' {
285 return self.parse_string_value();
286 }
287
288 if b == b'0' && self.pos + 1 < self.stmt.len() {
290 let next = self.stmt[self.pos + 1];
291 if next == b'x' || next == b'X' {
292 return self.parse_hex_value();
293 }
294 }
295
296 self.parse_number_value()
298 }
299
300 fn parse_string_value(&mut self) -> anyhow::Result<ParsedValue> {
302 self.pos += 1; let mut value = Vec::new();
305 let mut escape_next = false;
306
307 while self.pos < self.stmt.len() {
308 let b = self.stmt[self.pos];
309
310 if escape_next {
311 let escaped = match b {
313 b'n' => b'\n',
314 b'r' => b'\r',
315 b't' => b'\t',
316 b'0' => 0,
317 _ => b, };
319 value.push(escaped);
320 escape_next = false;
321 self.pos += 1;
322 } else if b == b'\\' {
323 escape_next = true;
324 self.pos += 1;
325 } else if b == b'\'' {
326 if self.pos + 1 < self.stmt.len() && self.stmt[self.pos + 1] == b'\'' {
328 value.push(b'\'');
329 self.pos += 2;
330 } else {
331 self.pos += 1; break;
333 }
334 } else {
335 value.push(b);
336 self.pos += 1;
337 }
338 }
339
340 let text = String::from_utf8_lossy(&value).into_owned();
341
342 Ok(ParsedValue::String { value: text })
343 }
344
345 fn parse_hex_value(&mut self) -> anyhow::Result<ParsedValue> {
347 let start = self.pos;
348 self.pos += 2; while self.pos < self.stmt.len() {
351 let b = self.stmt[self.pos];
352 if b.is_ascii_hexdigit() {
353 self.pos += 1;
354 } else {
355 break;
356 }
357 }
358
359 let raw = self.stmt[start..self.pos].to_vec();
360 Ok(ParsedValue::Hex(raw))
361 }
362
363 fn parse_number_value(&mut self) -> anyhow::Result<ParsedValue> {
365 let start = self.pos;
366 let mut has_dot = false;
367
368 if self.pos < self.stmt.len() && self.stmt[self.pos] == b'-' {
370 self.pos += 1;
371 }
372
373 while self.pos < self.stmt.len() {
374 let b = self.stmt[self.pos];
375 if b.is_ascii_digit() {
376 self.pos += 1;
377 } else if b == b'.' && !has_dot {
378 has_dot = true;
379 self.pos += 1;
380 } else if b == b'e' || b == b'E' {
381 self.pos += 1;
383 if self.pos < self.stmt.len()
384 && (self.stmt[self.pos] == b'+' || self.stmt[self.pos] == b'-')
385 {
386 self.pos += 1;
387 }
388 } else if b == b',' || b == b')' || b.is_ascii_whitespace() {
389 break;
390 } else {
391 while self.pos < self.stmt.len() {
393 let c = self.stmt[self.pos];
394 if c == b',' || c == b')' {
395 break;
396 }
397 self.pos += 1;
398 }
399 break;
400 }
401 }
402
403 let raw = self.stmt[start..self.pos].to_vec();
404 let value_str = String::from_utf8_lossy(&raw);
405
406 if !has_dot {
408 if let Ok(n) = value_str.parse::<i64>() {
409 return Ok(ParsedValue::Integer(n));
410 }
411 if let Ok(n) = value_str.parse::<i128>() {
412 return Ok(ParsedValue::BigInteger(n));
413 }
414 }
415
416 Ok(ParsedValue::Other(raw))
418 }
419
420 fn skip_whitespace(&mut self) {
422 while self.pos < self.stmt.len() {
423 let b = self.stmt[self.pos];
424 if b.is_ascii_whitespace() {
425 self.pos += 1;
426 } else {
427 break;
428 }
429 }
430 }
431
432 fn extract_pk_fk(
434 &self,
435 values: &[ParsedValue],
436 schema: &TableSchema,
437 ) -> (Option<PkTuple>, Vec<(FkRef, PkTuple)>, Vec<PkValue>) {
438 let mut pk_values = PkTuple::new();
439 let mut fk_values = Vec::new();
440
441 let all_values: Vec<PkValue> = values
443 .iter()
444 .enumerate()
445 .map(|(idx, v)| {
446 let col = self
447 .column_order
448 .get(idx)
449 .and_then(|c| *c)
450 .and_then(|id| schema.column(id));
451 self.value_to_pk(v, col)
452 })
453 .collect();
454
455 for (idx, col_id_opt) in self.column_order.iter().enumerate() {
457 if let Some(col_id) = col_id_opt {
458 if schema.is_pk_column(*col_id) {
459 if let Some(value) = values.get(idx) {
460 let pk_val = self.value_to_pk(value, schema.column(*col_id));
461 pk_values.push(pk_val);
462 }
463 }
464 }
465 }
466
467 for (fk_idx, fk) in schema.foreign_keys.iter().enumerate() {
469 if fk.referenced_table_id.is_none() {
470 continue;
471 }
472
473 let mut fk_tuple = PkTuple::new();
474 let mut all_non_null = true;
475
476 for &col_id in &fk.columns {
477 if let Some(idx) = self.column_order.iter().position(|&c| c == Some(col_id)) {
479 if let Some(value) = values.get(idx) {
480 let pk_val = self.value_to_pk(value, schema.column(col_id));
481 if pk_val.is_null() {
482 all_non_null = false;
483 break;
484 }
485 fk_tuple.push(pk_val);
486 }
487 }
488 }
489
490 if all_non_null && !fk_tuple.is_empty() {
491 fk_values.push((
492 FkRef {
493 table_id: schema.id.0,
494 fk_index: fk_idx as u16,
495 },
496 fk_tuple,
497 ));
498 }
499 }
500
501 let pk = if pk_values.is_empty() || pk_values.iter().any(|v| v.is_null()) {
502 None
503 } else {
504 Some(pk_values)
505 };
506
507 (pk, fk_values, all_values)
508 }
509
510 fn value_to_pk(&self, value: &ParsedValue, col: Option<&crate::schema::Column>) -> PkValue {
512 match value {
513 ParsedValue::Null => PkValue::Null,
514 ParsedValue::Integer(n) => PkValue::Int(*n),
515 ParsedValue::BigInteger(n) => PkValue::BigInt(*n),
516 ParsedValue::String { value } => {
517 if let Some(col) = col {
519 match col.col_type {
520 ColumnType::Int => {
521 if let Ok(n) = value.parse::<i64>() {
522 return PkValue::Int(n);
523 }
524 }
525 ColumnType::BigInt => {
526 if let Ok(n) = value.parse::<i128>() {
527 return PkValue::BigInt(n);
528 }
529 }
530 _ => {}
531 }
532 }
533 PkValue::Text(value.clone().into_boxed_str())
534 }
535 ParsedValue::Hex(raw) => {
536 PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
537 }
538 ParsedValue::Other(raw) => {
539 PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
540 }
541 }
542 }
543}
544
/// A single value token from a row tuple, before conversion to a [`PkValue`].
#[derive(Clone, Debug)]
enum ParsedValue {
    /// The `NULL` keyword.
    Null,
    /// Unquoted integer that fits in `i64`.
    Integer(i64),
    /// Unquoted integer that only fits in `i128`.
    BigInteger(i128),
    /// Quoted string literal with escapes resolved.
    String { value: String },
    /// `0x...` hex literal, raw bytes including the prefix.
    Hex(Vec<u8>),
    /// Any other token (decimals, exponents, keywords, expressions), raw bytes.
    Other(Vec<u8>),
}
555
556pub fn parse_mysql_insert_rows(
558 stmt: &[u8],
559 schema: &TableSchema,
560) -> anyhow::Result<Vec<ParsedRow>> {
561 let mut parser = InsertParser::new(stmt).with_schema(schema);
562 parser.parse_rows()
563}
564
565pub fn parse_mysql_insert_rows_raw(stmt: &[u8]) -> anyhow::Result<Vec<ParsedRow>> {
567 let mut parser = InsertParser::new(stmt);
568 parser.parse_rows()
569}