sql_splitter/parser/
mysql_insert.rs

1//! MySQL INSERT statement row parser.
2//!
3//! Parses INSERT INTO ... VALUES statements to extract individual rows
4//! and optionally extract PK/FK column values for dependency tracking.
5
6use crate::schema::{ColumnId, ColumnType, TableSchema};
7use ahash::AHashSet;
8use smallvec::SmallVec;
9
/// Value of a key column (primary or foreign) extracted from an INSERT row.
///
/// Serves as the element type of key tuples, so it is comparable and hashable.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum PkValue {
    /// Signed 64-bit integer — the representation most keys end up in.
    Int(i64),
    /// Wider integer value, held as `i128`.
    BigInt(i128),
    /// Textual value (string literals and other non-integer values).
    Text(Box<str>),
    /// SQL `NULL`; for a foreign key this typically means "no dependency".
    Null,
}
22
23impl PkValue {
24    /// Check if this is a NULL value
25    pub fn is_null(&self) -> bool {
26        matches!(self, PkValue::Null)
27    }
28}
29
30/// Tuple of PK values for composite primary keys
31pub type PkTuple = SmallVec<[PkValue; 2]>;
32
33/// Set of primary key values for a table (stores full tuples)
34pub type PkSet = AHashSet<PkTuple>;
35
36/// Compact hash-based set of primary keys for memory efficiency.
37/// Uses 64-bit hashes instead of full values - suitable for large datasets
38/// where collision risk is acceptable (sampling, validation).
39pub type PkHashSet = AHashSet<u64>;
40
41/// Hash a PK tuple into a compact 64-bit hash for memory-efficient storage.
42/// Uses AHash for fast, high-quality hashing.
43pub fn hash_pk_tuple(pk: &PkTuple) -> u64 {
44    use std::hash::{Hash, Hasher};
45    let mut hasher = ahash::AHasher::default();
46
47    // Include arity (number of columns) in the hash
48    (pk.len() as u8).hash(&mut hasher);
49
50    for v in pk {
51        match v {
52            PkValue::Int(i) => {
53                0u8.hash(&mut hasher);
54                i.hash(&mut hasher);
55            }
56            PkValue::BigInt(i) => {
57                1u8.hash(&mut hasher);
58                i.hash(&mut hasher);
59            }
60            PkValue::Text(s) => {
61                2u8.hash(&mut hasher);
62                s.hash(&mut hasher);
63            }
64            PkValue::Null => {
65                3u8.hash(&mut hasher);
66            }
67        }
68    }
69
70    hasher.finish()
71}
72
/// Identifies one foreign key: the declaring table plus the key's position in
/// that table's `foreign_keys` list.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct FkRef {
    /// Id of the table that declares the foreign key.
    pub table_id: u32,
    /// Position of the key within the table's `foreign_keys` vector.
    pub fk_index: u16,
}
81
82/// A parsed row from an INSERT statement
83#[derive(Debug, Clone)]
84pub struct ParsedRow {
85    /// Raw bytes of the row value list: "(val1, val2, ...)"
86    pub raw: Vec<u8>,
87    /// Extracted primary key values (if table has PK and values are non-NULL)
88    pub pk: Option<PkTuple>,
89    /// Extracted foreign key values with their references
90    /// Only includes FKs where all columns are non-NULL
91    pub fk_values: Vec<(FkRef, PkTuple)>,
92    /// All column values (for data diff comparison)
93    pub all_values: Vec<PkValue>,
94}
95
96/// Parser for MySQL INSERT statements
97pub struct InsertParser<'a> {
98    stmt: &'a [u8],
99    pos: usize,
100    table_schema: Option<&'a TableSchema>,
101    /// Column order in the INSERT (maps value index -> column ID)
102    column_order: Vec<Option<ColumnId>>,
103}
104
105impl<'a> InsertParser<'a> {
106    /// Create a new parser for an INSERT statement
107    pub fn new(stmt: &'a [u8]) -> Self {
108        Self {
109            stmt,
110            pos: 0,
111            table_schema: None,
112            column_order: Vec::new(),
113        }
114    }
115
116    /// Set the table schema for PK/FK extraction
117    pub fn with_schema(mut self, schema: &'a TableSchema) -> Self {
118        self.table_schema = Some(schema);
119        self
120    }
121
122    /// Parse all rows from the INSERT statement
123    pub fn parse_rows(&mut self) -> anyhow::Result<Vec<ParsedRow>> {
124        // Find the VALUES keyword
125        let values_pos = self.find_values_keyword()?;
126        self.pos = values_pos;
127
128        // Parse column list if present
129        self.parse_column_list();
130
131        // Parse each row
132        let mut rows = Vec::new();
133        while self.pos < self.stmt.len() {
134            self.skip_whitespace();
135
136            if self.pos >= self.stmt.len() {
137                break;
138            }
139
140            if self.stmt[self.pos] == b'(' {
141                if let Some(row) = self.parse_row()? {
142                    rows.push(row);
143                }
144            } else if self.stmt[self.pos] == b',' {
145                self.pos += 1;
146            } else if self.stmt[self.pos] == b';' {
147                break;
148            } else {
149                self.pos += 1;
150            }
151        }
152
153        Ok(rows)
154    }
155
156    /// Find the VALUES keyword and return position after it
157    fn find_values_keyword(&self) -> anyhow::Result<usize> {
158        let stmt_str = String::from_utf8_lossy(self.stmt);
159        let upper = stmt_str.to_uppercase();
160
161        if let Some(pos) = upper.find("VALUES") {
162            Ok(pos + 6) // Length of "VALUES"
163        } else {
164            anyhow::bail!("INSERT statement missing VALUES keyword")
165        }
166    }
167
168    /// Parse optional column list after INSERT INTO table_name
169    fn parse_column_list(&mut self) {
170        if self.table_schema.is_none() {
171            return;
172        }
173
174        let schema = self.table_schema.unwrap();
175
176        // Look for column list between table name and VALUES
177        // We need to look backwards from current position (after VALUES)
178        let before_values = &self.stmt[..self.pos.saturating_sub(6)];
179        let stmt_str = String::from_utf8_lossy(before_values);
180
181        // Find the last (...) before VALUES
182        if let Some(close_paren) = stmt_str.rfind(')') {
183            if let Some(open_paren) = stmt_str[..close_paren].rfind('(') {
184                let col_list = &stmt_str[open_paren + 1..close_paren];
185                // Check if this looks like a column list (no VALUES, etc.)
186                if !col_list.to_uppercase().contains("SELECT") {
187                    let cols: Vec<&str> = col_list.split(',').collect();
188                    self.column_order = cols
189                        .iter()
190                        .map(|c| {
191                            let name = c.trim().trim_matches('`').trim_matches('"');
192                            schema.get_column_id(name)
193                        })
194                        .collect();
195                    return;
196                }
197            }
198        }
199
200        // No explicit column list - use natural order
201        self.column_order = schema.columns.iter().map(|c| Some(c.ordinal)).collect();
202    }
203
204    /// Parse a single row "(val1, val2, ...)"
205    fn parse_row(&mut self) -> anyhow::Result<Option<ParsedRow>> {
206        self.skip_whitespace();
207
208        if self.pos >= self.stmt.len() || self.stmt[self.pos] != b'(' {
209            return Ok(None);
210        }
211
212        let start = self.pos;
213        self.pos += 1; // Skip '('
214
215        let mut values: Vec<ParsedValue> = Vec::new();
216        let mut depth = 1;
217
218        while self.pos < self.stmt.len() && depth > 0 {
219            self.skip_whitespace();
220
221            if self.pos >= self.stmt.len() {
222                break;
223            }
224
225            match self.stmt[self.pos] {
226                b'(' => {
227                    depth += 1;
228                    self.pos += 1;
229                }
230                b')' => {
231                    depth -= 1;
232                    self.pos += 1;
233                }
234                b',' if depth == 1 => {
235                    self.pos += 1;
236                }
237                _ if depth == 1 => {
238                    values.push(self.parse_value()?);
239                }
240                _ => {
241                    self.pos += 1;
242                }
243            }
244        }
245
246        let end = self.pos;
247        let raw = self.stmt[start..end].to_vec();
248
249        // Extract PK, FK, and all values if we have a schema
250        let (pk, fk_values, all_values) = if let Some(schema) = self.table_schema {
251            self.extract_pk_fk(&values, schema)
252        } else {
253            (None, Vec::new(), Vec::new())
254        };
255
256        Ok(Some(ParsedRow {
257            raw,
258            pk,
259            fk_values,
260            all_values,
261        }))
262    }
263
264    /// Parse a single value (string, number, NULL, etc.)
265    fn parse_value(&mut self) -> anyhow::Result<ParsedValue> {
266        self.skip_whitespace();
267
268        if self.pos >= self.stmt.len() {
269            return Ok(ParsedValue::Null);
270        }
271
272        let b = self.stmt[self.pos];
273
274        // NULL
275        if self.pos + 4 <= self.stmt.len() {
276            let word = &self.stmt[self.pos..self.pos + 4];
277            if word.eq_ignore_ascii_case(b"NULL") {
278                self.pos += 4;
279                return Ok(ParsedValue::Null);
280            }
281        }
282
283        // String literal
284        if b == b'\'' {
285            return self.parse_string_value();
286        }
287
288        // Hex literal (0x...)
289        if b == b'0' && self.pos + 1 < self.stmt.len() {
290            let next = self.stmt[self.pos + 1];
291            if next == b'x' || next == b'X' {
292                return self.parse_hex_value();
293            }
294        }
295
296        // Number or expression
297        self.parse_number_value()
298    }
299
300    /// Parse a string literal 'value'
301    fn parse_string_value(&mut self) -> anyhow::Result<ParsedValue> {
302        self.pos += 1; // Skip opening quote
303
304        let mut value = Vec::new();
305        let mut escape_next = false;
306
307        while self.pos < self.stmt.len() {
308            let b = self.stmt[self.pos];
309
310            if escape_next {
311                // Handle MySQL escape sequences
312                let escaped = match b {
313                    b'n' => b'\n',
314                    b'r' => b'\r',
315                    b't' => b'\t',
316                    b'0' => 0,
317                    _ => b, // \', \\, etc.
318                };
319                value.push(escaped);
320                escape_next = false;
321                self.pos += 1;
322            } else if b == b'\\' {
323                escape_next = true;
324                self.pos += 1;
325            } else if b == b'\'' {
326                // Check for escaped quote ''
327                if self.pos + 1 < self.stmt.len() && self.stmt[self.pos + 1] == b'\'' {
328                    value.push(b'\'');
329                    self.pos += 2;
330                } else {
331                    self.pos += 1; // End of string
332                    break;
333                }
334            } else {
335                value.push(b);
336                self.pos += 1;
337            }
338        }
339
340        let text = String::from_utf8_lossy(&value).into_owned();
341
342        Ok(ParsedValue::String { value: text })
343    }
344
345    /// Parse a hex literal 0xABCD...
346    fn parse_hex_value(&mut self) -> anyhow::Result<ParsedValue> {
347        let start = self.pos;
348        self.pos += 2; // Skip 0x
349
350        while self.pos < self.stmt.len() {
351            let b = self.stmt[self.pos];
352            if b.is_ascii_hexdigit() {
353                self.pos += 1;
354            } else {
355                break;
356            }
357        }
358
359        let raw = self.stmt[start..self.pos].to_vec();
360        Ok(ParsedValue::Hex(raw))
361    }
362
363    /// Parse a number or other non-string value
364    fn parse_number_value(&mut self) -> anyhow::Result<ParsedValue> {
365        let start = self.pos;
366        let mut has_dot = false;
367
368        // Handle leading minus
369        if self.pos < self.stmt.len() && self.stmt[self.pos] == b'-' {
370            self.pos += 1;
371        }
372
373        while self.pos < self.stmt.len() {
374            let b = self.stmt[self.pos];
375            if b.is_ascii_digit() {
376                self.pos += 1;
377            } else if b == b'.' && !has_dot {
378                has_dot = true;
379                self.pos += 1;
380            } else if b == b'e' || b == b'E' {
381                // Scientific notation
382                self.pos += 1;
383                if self.pos < self.stmt.len()
384                    && (self.stmt[self.pos] == b'+' || self.stmt[self.pos] == b'-')
385                {
386                    self.pos += 1;
387                }
388            } else if b == b',' || b == b')' || b.is_ascii_whitespace() {
389                break;
390            } else {
391                // Unknown character in number, skip to next delimiter
392                while self.pos < self.stmt.len() {
393                    let c = self.stmt[self.pos];
394                    if c == b',' || c == b')' {
395                        break;
396                    }
397                    self.pos += 1;
398                }
399                break;
400            }
401        }
402
403        let raw = self.stmt[start..self.pos].to_vec();
404        let value_str = String::from_utf8_lossy(&raw);
405
406        // Try to parse as integer
407        if !has_dot {
408            if let Ok(n) = value_str.parse::<i64>() {
409                return Ok(ParsedValue::Integer(n));
410            }
411            if let Ok(n) = value_str.parse::<i128>() {
412                return Ok(ParsedValue::BigInteger(n));
413            }
414        }
415
416        // Fall back to raw value
417        Ok(ParsedValue::Other(raw))
418    }
419
420    /// Skip whitespace and newlines
421    fn skip_whitespace(&mut self) {
422        while self.pos < self.stmt.len() {
423            let b = self.stmt[self.pos];
424            if b.is_ascii_whitespace() {
425                self.pos += 1;
426            } else {
427                break;
428            }
429        }
430    }
431
432    /// Extract PK, FK, and all values from parsed values
433    fn extract_pk_fk(
434        &self,
435        values: &[ParsedValue],
436        schema: &TableSchema,
437    ) -> (Option<PkTuple>, Vec<(FkRef, PkTuple)>, Vec<PkValue>) {
438        let mut pk_values = PkTuple::new();
439        let mut fk_values = Vec::new();
440
441        // Build all_values: convert each value to PkValue
442        let all_values: Vec<PkValue> = values
443            .iter()
444            .enumerate()
445            .map(|(idx, v)| {
446                let col = self
447                    .column_order
448                    .get(idx)
449                    .and_then(|c| *c)
450                    .and_then(|id| schema.column(id));
451                self.value_to_pk(v, col)
452            })
453            .collect();
454
455        // Build PK from columns marked as primary key
456        for (idx, col_id_opt) in self.column_order.iter().enumerate() {
457            if let Some(col_id) = col_id_opt {
458                if schema.is_pk_column(*col_id) {
459                    if let Some(value) = values.get(idx) {
460                        let pk_val = self.value_to_pk(value, schema.column(*col_id));
461                        pk_values.push(pk_val);
462                    }
463                }
464            }
465        }
466
467        // Build FK tuples
468        for (fk_idx, fk) in schema.foreign_keys.iter().enumerate() {
469            if fk.referenced_table_id.is_none() {
470                continue;
471            }
472
473            let mut fk_tuple = PkTuple::new();
474            let mut all_non_null = true;
475
476            for &col_id in &fk.columns {
477                // Find the value index for this column
478                if let Some(idx) = self.column_order.iter().position(|&c| c == Some(col_id)) {
479                    if let Some(value) = values.get(idx) {
480                        let pk_val = self.value_to_pk(value, schema.column(col_id));
481                        if pk_val.is_null() {
482                            all_non_null = false;
483                            break;
484                        }
485                        fk_tuple.push(pk_val);
486                    }
487                }
488            }
489
490            if all_non_null && !fk_tuple.is_empty() {
491                fk_values.push((
492                    FkRef {
493                        table_id: schema.id.0,
494                        fk_index: fk_idx as u16,
495                    },
496                    fk_tuple,
497                ));
498            }
499        }
500
501        let pk = if pk_values.is_empty() || pk_values.iter().any(|v| v.is_null()) {
502            None
503        } else {
504            Some(pk_values)
505        };
506
507        (pk, fk_values, all_values)
508    }
509
510    /// Convert a parsed value to a PkValue
511    fn value_to_pk(&self, value: &ParsedValue, col: Option<&crate::schema::Column>) -> PkValue {
512        match value {
513            ParsedValue::Null => PkValue::Null,
514            ParsedValue::Integer(n) => PkValue::Int(*n),
515            ParsedValue::BigInteger(n) => PkValue::BigInt(*n),
516            ParsedValue::String { value } => {
517                // Check if this might be an integer stored as string
518                if let Some(col) = col {
519                    match col.col_type {
520                        ColumnType::Int => {
521                            if let Ok(n) = value.parse::<i64>() {
522                                return PkValue::Int(n);
523                            }
524                        }
525                        ColumnType::BigInt => {
526                            if let Ok(n) = value.parse::<i128>() {
527                                return PkValue::BigInt(n);
528                            }
529                        }
530                        _ => {}
531                    }
532                }
533                PkValue::Text(value.clone().into_boxed_str())
534            }
535            ParsedValue::Hex(raw) => {
536                PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
537            }
538            ParsedValue::Other(raw) => {
539                PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
540            }
541        }
542    }
543}
544
/// A single value from a row tuple, before conversion to [`PkValue`].
#[derive(Clone, Debug)]
enum ParsedValue {
    /// The `NULL` literal.
    Null,
    /// Integer literal that fits in `i64`.
    Integer(i64),
    /// Integer literal that needs `i128`.
    BigInteger(i128),
    /// Quoted string literal with escape sequences decoded.
    String { value: String },
    /// Hex literal kept verbatim, `0x` prefix included.
    Hex(Vec<u8>),
    /// Any other value, kept as its raw bytes.
    Other(Vec<u8>),
}
555
556/// Parse all rows from a MySQL INSERT statement
557pub fn parse_mysql_insert_rows(
558    stmt: &[u8],
559    schema: &TableSchema,
560) -> anyhow::Result<Vec<ParsedRow>> {
561    let mut parser = InsertParser::new(stmt).with_schema(schema);
562    parser.parse_rows()
563}
564
565/// Parse rows without schema (just raw row extraction)
566pub fn parse_mysql_insert_rows_raw(stmt: &[u8]) -> anyhow::Result<Vec<ParsedRow>> {
567    let mut parser = InsertParser::new(stmt);
568    parser.parse_rows()
569}