1use crate::schema::{ColumnId, ColumnType, TableSchema};
7use ahash::AHashSet;
8use smallvec::SmallVec;
9
/// A single component of a primary-key or foreign-key tuple.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PkValue {
    /// Integer that fits in 64 bits.
    Int(i64),
    /// Integer that needs up to 128 bits.
    BigInt(i128),
    /// Any value kept in textual form.
    Text(Box<str>),
    /// SQL `NULL`.
    Null,
}

impl PkValue {
    /// Returns `true` only for the [`PkValue::Null`] variant.
    pub fn is_null(&self) -> bool {
        match self {
            PkValue::Null => true,
            PkValue::Int(_) | PkValue::BigInt(_) | PkValue::Text(_) => false,
        }
    }
}
29
/// Ordered key components of one row; up to two components are stored inline
/// before the `SmallVec` spills to the heap.
pub type PkTuple = SmallVec<[PkValue; 2]>;

/// Set of complete key tuples.
pub type PkSet = AHashSet<PkTuple>;

/// Set of 64-bit values.
// NOTE(review): presumably holds hashes produced by `hash_pk_tuple` — confirm against callers.
pub type PkHashSet = AHashSet<u64>;
40
41pub fn hash_pk_tuple(pk: &PkTuple) -> u64 {
44 use std::hash::{Hash, Hasher};
45 let mut hasher = ahash::AHasher::default();
46
47 (pk.len() as u8).hash(&mut hasher);
49
50 for v in pk {
51 match v {
52 PkValue::Int(i) => {
53 0u8.hash(&mut hasher);
54 i.hash(&mut hasher);
55 }
56 PkValue::BigInt(i) => {
57 1u8.hash(&mut hasher);
58 i.hash(&mut hasher);
59 }
60 PkValue::Text(s) => {
61 2u8.hash(&mut hasher);
62 s.hash(&mut hasher);
63 }
64 PkValue::Null => {
65 3u8.hash(&mut hasher);
66 }
67 }
68 }
69
70 hasher.finish()
71}
72
/// Identifies one foreign-key constraint of one table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FkRef {
    /// Numeric id of the table that declares the foreign key.
    pub table_id: u32,
    /// Position of the foreign key within that table's `foreign_keys` list.
    pub fk_index: u16,
}
81
/// One parenthesised row of an INSERT statement together with the key data
/// extracted from it.
#[derive(Debug, Clone)]
pub struct ParsedRow {
    /// Bytes of the row as consumed by the parser, starting at its opening `(`.
    pub raw: Vec<u8>,
    /// Values of the row in statement order.
    pub values: Vec<ParsedValue>,
    /// Primary-key tuple; `None` without a schema, when no primary-key value
    /// was found, or when any primary-key component is NULL.
    pub pk: Option<PkTuple>,
    /// Foreign-key tuples found in the row, each tagged with the constraint it
    /// belongs to. Empty without a schema.
    pub fk_values: Vec<(FkRef, PkTuple)>,
    /// Every value of the row converted to a `PkValue`, in statement order.
    /// Empty without a schema.
    pub all_values: Vec<PkValue>,
    /// Maps a schema column index to the index of its value in `all_values`;
    /// `None` for columns the statement does not supply. Empty without a schema.
    pub column_map: Vec<Option<usize>>,
}
100
101impl ParsedRow {
102 pub fn get_column_value(&self, schema_col_index: usize) -> Option<&PkValue> {
104 self.column_map
105 .get(schema_col_index)
106 .and_then(|v| *v)
107 .and_then(|val_idx| self.all_values.get(val_idx))
108 }
109}
110
/// Cursor-based parser over the raw bytes of a single INSERT statement.
pub struct InsertParser<'a> {
    /// The statement being parsed.
    stmt: &'a [u8],
    /// Current byte offset into `stmt`.
    pos: usize,
    /// Schema of the target table; enables primary/foreign-key extraction when set.
    table_schema: Option<&'a TableSchema>,
    /// Schema column for each value position of a row; `None` where the
    /// column name could not be resolved.
    column_order: Vec<Option<ColumnId>>,
}
119
120impl<'a> InsertParser<'a> {
121 pub fn new(stmt: &'a [u8]) -> Self {
123 Self {
124 stmt,
125 pos: 0,
126 table_schema: None,
127 column_order: Vec::new(),
128 }
129 }
130
131 pub fn with_schema(mut self, schema: &'a TableSchema) -> Self {
133 self.table_schema = Some(schema);
134 self
135 }
136
137 pub fn parse_rows(&mut self) -> anyhow::Result<Vec<ParsedRow>> {
139 let values_pos = self.find_values_keyword()?;
141 self.pos = values_pos;
142
143 self.parse_column_list();
145
146 let mut rows = Vec::new();
148 while self.pos < self.stmt.len() {
149 self.skip_whitespace();
150
151 if self.pos >= self.stmt.len() {
152 break;
153 }
154
155 if self.stmt[self.pos] == b'(' {
156 if let Some(row) = self.parse_row()? {
157 rows.push(row);
158 }
159 } else if self.stmt[self.pos] == b',' {
160 self.pos += 1;
161 } else if self.stmt[self.pos] == b';' {
162 break;
163 } else {
164 self.pos += 1;
165 }
166 }
167
168 Ok(rows)
169 }
170
171 fn find_values_keyword(&self) -> anyhow::Result<usize> {
173 let stmt_str = String::from_utf8_lossy(self.stmt);
174 let upper = stmt_str.to_uppercase();
175
176 if let Some(pos) = upper.find("VALUES") {
177 Ok(pos + 6) } else {
179 anyhow::bail!("INSERT statement missing VALUES keyword")
180 }
181 }
182
183 fn parse_column_list(&mut self) {
185 if self.table_schema.is_none() {
186 return;
187 }
188
189 let schema = self.table_schema.unwrap();
190
191 let before_values = &self.stmt[..self.pos.saturating_sub(6)];
194 let stmt_str = String::from_utf8_lossy(before_values);
195
196 if let Some(close_paren) = stmt_str.rfind(')') {
198 if let Some(open_paren) = stmt_str[..close_paren].rfind('(') {
199 let col_list = &stmt_str[open_paren + 1..close_paren];
200 if !col_list.to_uppercase().contains("SELECT") {
202 let cols: Vec<&str> = col_list.split(',').collect();
203 self.column_order = cols
204 .iter()
205 .map(|c| {
206 let name = c
208 .trim()
209 .trim_matches('`')
210 .trim_matches('"')
211 .trim_matches('[')
212 .trim_matches(']');
213 schema.get_column_id(name)
214 })
215 .collect();
216 return;
217 }
218 }
219 }
220
221 self.column_order = schema.columns.iter().map(|c| Some(c.ordinal)).collect();
223 }
224
225 fn parse_row(&mut self) -> anyhow::Result<Option<ParsedRow>> {
227 self.skip_whitespace();
228
229 if self.pos >= self.stmt.len() || self.stmt[self.pos] != b'(' {
230 return Ok(None);
231 }
232
233 let start = self.pos;
234 self.pos += 1; let mut values: Vec<ParsedValue> = Vec::new();
237 let mut depth = 1;
238
239 while self.pos < self.stmt.len() && depth > 0 {
240 self.skip_whitespace();
241
242 if self.pos >= self.stmt.len() {
243 break;
244 }
245
246 match self.stmt[self.pos] {
247 b'(' => {
248 depth += 1;
249 self.pos += 1;
250 }
251 b')' => {
252 depth -= 1;
253 self.pos += 1;
254 }
255 b',' if depth == 1 => {
256 self.pos += 1;
257 }
258 _ if depth == 1 => {
259 values.push(self.parse_value()?);
260 }
261 _ => {
262 self.pos += 1;
263 }
264 }
265 }
266
267 let end = self.pos;
268 let raw = self.stmt[start..end].to_vec();
269
270 let (pk, fk_values, all_values, column_map) = if let Some(schema) = self.table_schema {
272 let (pk, fk_values, all_values) = self.extract_pk_fk(&values, schema);
273 let column_map = self.build_column_map(schema);
274 (pk, fk_values, all_values, column_map)
275 } else {
276 (None, Vec::new(), Vec::new(), Vec::new())
277 };
278
279 Ok(Some(ParsedRow {
280 raw,
281 values,
282 pk,
283 fk_values,
284 all_values,
285 column_map,
286 }))
287 }
288
289 fn parse_value(&mut self) -> anyhow::Result<ParsedValue> {
291 self.skip_whitespace();
292
293 if self.pos >= self.stmt.len() {
294 return Ok(ParsedValue::Null);
295 }
296
297 let b = self.stmt[self.pos];
298
299 if self.pos + 4 <= self.stmt.len() {
301 let word = &self.stmt[self.pos..self.pos + 4];
302 if word.eq_ignore_ascii_case(b"NULL") {
303 self.pos += 4;
304 return Ok(ParsedValue::Null);
305 }
306 }
307
308 if b == b'\'' {
310 return self.parse_string_value();
311 }
312
313 if (b == b'N' || b == b'n')
315 && self.pos + 1 < self.stmt.len()
316 && self.stmt[self.pos + 1] == b'\''
317 {
318 self.pos += 1; return self.parse_string_value();
320 }
321
322 if b == b'0' && self.pos + 1 < self.stmt.len() {
324 let next = self.stmt[self.pos + 1];
325 if next == b'x' || next == b'X' {
326 return self.parse_hex_value();
327 }
328 }
329
330 self.parse_number_value()
332 }
333
334 fn parse_string_value(&mut self) -> anyhow::Result<ParsedValue> {
336 self.pos += 1; let mut value = Vec::new();
339 let mut escape_next = false;
340
341 while self.pos < self.stmt.len() {
342 let b = self.stmt[self.pos];
343
344 if escape_next {
345 let escaped = match b {
347 b'n' => b'\n',
348 b'r' => b'\r',
349 b't' => b'\t',
350 b'0' => 0,
351 _ => b, };
353 value.push(escaped);
354 escape_next = false;
355 self.pos += 1;
356 } else if b == b'\\' {
357 escape_next = true;
358 self.pos += 1;
359 } else if b == b'\'' {
360 if self.pos + 1 < self.stmt.len() && self.stmt[self.pos + 1] == b'\'' {
362 value.push(b'\'');
363 self.pos += 2;
364 } else {
365 self.pos += 1; break;
367 }
368 } else {
369 value.push(b);
370 self.pos += 1;
371 }
372 }
373
374 let text = String::from_utf8_lossy(&value).into_owned();
375
376 Ok(ParsedValue::String { value: text })
377 }
378
379 fn parse_hex_value(&mut self) -> anyhow::Result<ParsedValue> {
381 let start = self.pos;
382 self.pos += 2; while self.pos < self.stmt.len() {
385 let b = self.stmt[self.pos];
386 if b.is_ascii_hexdigit() {
387 self.pos += 1;
388 } else {
389 break;
390 }
391 }
392
393 let raw = self.stmt[start..self.pos].to_vec();
394 Ok(ParsedValue::Hex(raw))
395 }
396
397 fn parse_number_value(&mut self) -> anyhow::Result<ParsedValue> {
399 let start = self.pos;
400 let mut has_dot = false;
401
402 if self.pos < self.stmt.len() && self.stmt[self.pos] == b'-' {
404 self.pos += 1;
405 }
406
407 while self.pos < self.stmt.len() {
408 let b = self.stmt[self.pos];
409 if b.is_ascii_digit() {
410 self.pos += 1;
411 } else if b == b'.' && !has_dot {
412 has_dot = true;
413 self.pos += 1;
414 } else if b == b'e' || b == b'E' {
415 self.pos += 1;
417 if self.pos < self.stmt.len()
418 && (self.stmt[self.pos] == b'+' || self.stmt[self.pos] == b'-')
419 {
420 self.pos += 1;
421 }
422 } else if b == b',' || b == b')' || b.is_ascii_whitespace() {
423 break;
424 } else {
425 while self.pos < self.stmt.len() {
427 let c = self.stmt[self.pos];
428 if c == b',' || c == b')' {
429 break;
430 }
431 self.pos += 1;
432 }
433 break;
434 }
435 }
436
437 let raw = self.stmt[start..self.pos].to_vec();
438 let value_str = String::from_utf8_lossy(&raw);
439
440 if !has_dot {
442 if let Ok(n) = value_str.parse::<i64>() {
443 return Ok(ParsedValue::Integer(n));
444 }
445 if let Ok(n) = value_str.parse::<i128>() {
446 return Ok(ParsedValue::BigInteger(n));
447 }
448 }
449
450 Ok(ParsedValue::Other(raw))
452 }
453
454 fn skip_whitespace(&mut self) {
456 while self.pos < self.stmt.len() {
457 let b = self.stmt[self.pos];
458 if b.is_ascii_whitespace() {
459 self.pos += 1;
460 } else {
461 break;
462 }
463 }
464 }
465
466 fn extract_pk_fk(
468 &self,
469 values: &[ParsedValue],
470 schema: &TableSchema,
471 ) -> (Option<PkTuple>, Vec<(FkRef, PkTuple)>, Vec<PkValue>) {
472 let mut pk_values = PkTuple::new();
473 let mut fk_values = Vec::new();
474
475 let all_values: Vec<PkValue> = values
477 .iter()
478 .enumerate()
479 .map(|(idx, v)| {
480 let col = self
481 .column_order
482 .get(idx)
483 .and_then(|c| *c)
484 .and_then(|id| schema.column(id));
485 self.value_to_pk(v, col)
486 })
487 .collect();
488
489 for (idx, col_id_opt) in self.column_order.iter().enumerate() {
491 if let Some(col_id) = col_id_opt {
492 if schema.is_pk_column(*col_id) {
493 if let Some(value) = values.get(idx) {
494 let pk_val = self.value_to_pk(value, schema.column(*col_id));
495 pk_values.push(pk_val);
496 }
497 }
498 }
499 }
500
501 for (fk_idx, fk) in schema.foreign_keys.iter().enumerate() {
503 if fk.referenced_table_id.is_none() {
504 continue;
505 }
506
507 let mut fk_tuple = PkTuple::new();
508 let mut all_non_null = true;
509
510 for &col_id in &fk.columns {
511 if let Some(idx) = self.column_order.iter().position(|&c| c == Some(col_id)) {
513 if let Some(value) = values.get(idx) {
514 let pk_val = self.value_to_pk(value, schema.column(col_id));
515 if pk_val.is_null() {
516 all_non_null = false;
517 break;
518 }
519 fk_tuple.push(pk_val);
520 }
521 }
522 }
523
524 if all_non_null && !fk_tuple.is_empty() {
525 fk_values.push((
526 FkRef {
527 table_id: schema.id.0,
528 fk_index: fk_idx as u16,
529 },
530 fk_tuple,
531 ));
532 }
533 }
534
535 let pk = if pk_values.is_empty() || pk_values.iter().any(|v| v.is_null()) {
536 None
537 } else {
538 Some(pk_values)
539 };
540
541 (pk, fk_values, all_values)
542 }
543
544 fn build_column_map(&self, schema: &TableSchema) -> Vec<Option<usize>> {
547 let mut map = vec![None; schema.columns.len()];
549
550 for (val_idx, col_id_opt) in self.column_order.iter().enumerate() {
551 if let Some(col_id) = col_id_opt {
552 let ordinal = col_id.0 as usize;
553 if ordinal < map.len() {
554 map[ordinal] = Some(val_idx);
555 }
556 }
557 }
558
559 map
560 }
561
562 fn value_to_pk(&self, value: &ParsedValue, col: Option<&crate::schema::Column>) -> PkValue {
564 match value {
565 ParsedValue::Null => PkValue::Null,
566 ParsedValue::Integer(n) => PkValue::Int(*n),
567 ParsedValue::BigInteger(n) => PkValue::BigInt(*n),
568 ParsedValue::String { value } => {
569 if let Some(col) = col {
571 match col.col_type {
572 ColumnType::Int => {
573 if let Ok(n) = value.parse::<i64>() {
574 return PkValue::Int(n);
575 }
576 }
577 ColumnType::BigInt => {
578 if let Ok(n) = value.parse::<i128>() {
579 return PkValue::BigInt(n);
580 }
581 }
582 _ => {}
583 }
584 }
585 PkValue::Text(value.clone().into_boxed_str())
586 }
587 ParsedValue::Hex(raw) => {
588 PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
589 }
590 ParsedValue::Other(raw) => {
591 PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
592 }
593 }
594 }
595}
596
/// A single value as it appears in an INSERT row.
#[derive(Debug, Clone)]
pub enum ParsedValue {
    /// The `NULL` keyword (also produced when input ends where a value was expected).
    Null,
    /// Unquoted integer that fits in `i64`.
    Integer(i64),
    /// Unquoted integer that only fits in `i128`.
    BigInteger(i128),
    /// Quoted string with escapes decoded, converted to text lossily.
    String { value: String },
    /// Hex literal as written, including its `0x` prefix.
    Hex(Vec<u8>),
    /// Raw bytes of any other token (decimals, keywords, expressions, ...).
    Other(Vec<u8>),
}
615
616pub fn parse_mysql_insert_rows(
618 stmt: &[u8],
619 schema: &TableSchema,
620) -> anyhow::Result<Vec<ParsedRow>> {
621 let mut parser = InsertParser::new(stmt).with_schema(schema);
622 parser.parse_rows()
623}
624
625pub fn parse_mysql_insert_rows_raw(stmt: &[u8]) -> anyhow::Result<Vec<ParsedRow>> {
627 let mut parser = InsertParser::new(stmt);
628 parser.parse_rows()
629}
630
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` with `parse_insert_for_bulk` and checks the table name,
    /// the column list and the number of rows.
    fn assert_bulk(sql: &[u8], table: &str, columns: Option<Vec<&str>>, row_count: usize) {
        let parsed = parse_insert_for_bulk(sql).unwrap();
        let expected_columns: Option<Vec<String>> =
            columns.map(|names| names.into_iter().map(|name| name.to_string()).collect());

        assert_eq!(parsed.table, table);
        assert_eq!(parsed.columns, expected_columns);
        assert_eq!(parsed.rows.len(), row_count);
    }

    #[test]
    fn test_parse_insert_for_bulk_simple() {
        assert_bulk(b"INSERT INTO users VALUES (1, 'Alice')", "users", None, 1);
    }

    #[test]
    fn test_parse_insert_for_bulk_with_columns() {
        assert_bulk(
            b"INSERT INTO users (name, id) VALUES ('Alice', 1)",
            "users",
            Some(vec!["name", "id"]),
            1,
        );
    }

    #[test]
    fn test_parse_insert_for_bulk_mssql() {
        assert_bulk(
            b"INSERT INTO [dbo].[users] ([email], [name]) VALUES (N'alice@example.com', N'Alice')",
            "users",
            Some(vec!["email", "name"]),
            1,
        );
    }

    #[test]
    fn test_parse_insert_for_bulk_mysql() {
        assert_bulk(
            b"INSERT INTO `users` (`id`, `name`) VALUES (1, 'Bob')",
            "users",
            Some(vec!["id", "name"]),
            1,
        );
    }
}
681
/// Schema-free view of an INSERT statement: target table, optional column
/// list and the parsed values of every row.
#[derive(Debug, Clone)]
pub struct InsertValues {
    /// Table name with quoting and any schema/database qualifier removed.
    pub table: String,
    /// Column names as listed in the statement; `None` when no column list was found.
    pub columns: Option<Vec<String>>,
    /// Values of each row, in statement order.
    pub rows: Vec<Vec<ParsedValue>>,
}
692
693pub fn parse_insert_for_bulk(stmt: &[u8]) -> anyhow::Result<InsertValues> {
699 let stmt_str = String::from_utf8_lossy(stmt);
700 let upper = stmt_str.to_uppercase();
701
702 let table = extract_insert_table_name(&stmt_str, &upper)?;
704
705 let columns = extract_column_list(&stmt_str, &upper);
707
708 let mut parser = InsertParser::new(stmt);
710 let parsed_rows = parser.parse_rows()?;
711
712 let rows = parsed_rows.into_iter().map(|r| r.values).collect();
713
714 Ok(InsertValues {
715 table,
716 columns,
717 rows,
718 })
719}
720
721fn extract_insert_table_name(stmt: &str, upper: &str) -> anyhow::Result<String> {
723 let start_pos = if let Some(pos) = upper.find("INSERT INTO") {
725 pos + 11 } else if let Some(pos) = upper.find("INSERT") {
727 pos + 6 } else {
729 anyhow::bail!("Not an INSERT statement");
730 };
731
732 let remaining = stmt[start_pos..].trim_start();
734
735 let table_ref = extract_table_reference(remaining)?;
737
738 if let Some(dot_pos) = table_ref.rfind('.') {
740 let table_part = &table_ref[dot_pos + 1..];
741 Ok(strip_identifier_quotes(table_part))
742 } else {
743 Ok(strip_identifier_quotes(&table_ref))
744 }
745}
746
747fn extract_table_reference(s: &str) -> anyhow::Result<String> {
749 let s = s.trim();
750
751 if s.is_empty() {
752 anyhow::bail!("Empty table reference");
753 }
754
755 let mut result = String::new();
756 let mut chars = s.chars().peekable();
757
758 while let Some(&c) = chars.peek() {
759 match c {
760 '[' => {
761 chars.next();
763 result.push('[');
764 while let Some(&inner) = chars.peek() {
765 chars.next();
766 result.push(inner);
767 if inner == ']' {
768 break;
769 }
770 }
771 }
772 '`' => {
773 chars.next();
775 result.push('`');
776 while let Some(&inner) = chars.peek() {
777 chars.next();
778 result.push(inner);
779 if inner == '`' {
780 break;
781 }
782 }
783 }
784 '"' => {
785 chars.next();
787 result.push('"');
788 while let Some(&inner) = chars.peek() {
789 chars.next();
790 result.push(inner);
791 if inner == '"' {
792 break;
793 }
794 }
795 }
796 '.' => {
797 chars.next();
799 result.push('.');
800 }
801 c if c.is_whitespace() || c == '(' || c == ',' => {
802 break;
804 }
805 _ => {
806 chars.next();
808 result.push(c);
809 }
810 }
811 }
812
813 if result.is_empty() {
814 anyhow::bail!("Empty table reference");
815 }
816
817 Ok(result)
818}
819
/// Removes identifier quoting from both ends of `s`.
///
/// The quote characters are stripped one kind after another — backticks, then
/// double quotes, then `[`, then `]` — so a quote that is only exposed by a
/// later step is left in place.
fn strip_identifier_quotes(s: &str) -> String {
    ['`', '"', '[', ']']
        .iter()
        .fold(s, |remaining, &quote| remaining.trim_matches(quote))
        .to_string()
}
828
/// Extracts the explicit column list of an INSERT statement, if it has one.
///
/// `upper` is expected to be `stmt.to_uppercase()`. Its byte offsets are only
/// used when `stmt` is pure ASCII. For any other input an ASCII-only
/// upper-cased copy is searched instead, because full Unicode upper-casing can
/// change byte lengths and would make the `VALUES` offset point at the wrong
/// place in `stmt`.
///
/// The column list is the last parenthesised group before the first `VALUES`.
/// Each name is trimmed and stripped of backtick, double-quote and bracket
/// quoting.
///
/// Returns `None` when there is no `VALUES` keyword, no parenthesised group
/// before it, or the group is blank or contains `SELECT` or `VALUES`.
fn extract_column_list(stmt: &str, upper: &str) -> Option<Vec<String>> {
    let ascii_upper;
    let haystack: &str = if stmt.is_ascii() {
        upper
    } else {
        // ASCII-only upper-casing keeps every byte offset identical to `stmt`.
        ascii_upper = stmt.to_ascii_uppercase();
        &ascii_upper
    };

    let values_pos = haystack.find("VALUES")?;
    // `values_pos` follows ASCII text in a same-length copy, so it is a valid
    // character boundary in `stmt`; `get` keeps this panic-free regardless.
    let before_values = stmt.get(..values_pos)?;

    let close_paren = before_values.rfind(')')?;
    let open_paren = before_values[..close_paren].rfind('(')?;
    let col_list = &before_values[open_paren + 1..close_paren];

    let upper_cols = col_list.to_uppercase();
    if col_list.trim().is_empty() || upper_cols.contains("SELECT") || upper_cols.contains("VALUES")
    {
        return None;
    }

    // `split` always yields at least one item, so the list is never empty here.
    let columns = col_list
        .split(',')
        .map(|column| {
            column
                .trim()
                .trim_matches('`')
                .trim_matches('"')
                .trim_matches('[')
                .trim_matches(']')
                .to_string()
        })
        .collect();

    Some(columns)
}
865}