sql_splitter/parser/
mysql_insert.rs

//! MySQL INSERT statement row parser.
//!
//! Parses `INSERT INTO ... VALUES` statements to extract individual rows
//! and, when a table schema is supplied, their PK/FK column values for
//! dependency tracking.
5
6use crate::schema::{ColumnId, ColumnType, TableSchema};
7use ahash::AHashSet;
8use smallvec::SmallVec;
9
/// A single primary-key (or foreign-key) column value.
///
/// Variants with different tags never compare equal, so `Int(1) != BigInt(1)`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PkValue {
    /// Integer stored as `i64` (covers most keys).
    Int(i64),
    /// Integer stored as `i128`.
    BigInt(i128),
    /// Text value.
    Text(Box<str>),
    /// SQL `NULL`; for a foreign key this usually means "no dependency".
    Null,
}

impl PkValue {
    /// Returns `true` when this value is SQL `NULL`.
    pub fn is_null(&self) -> bool {
        self == &PkValue::Null
    }
}
29
30/// Tuple of PK values for composite primary keys
31pub type PkTuple = SmallVec<[PkValue; 2]>;
32
33/// Set of primary key values for a table (stores full tuples)
34pub type PkSet = AHashSet<PkTuple>;
35
36/// Compact hash-based set of primary keys for memory efficiency.
37/// Uses 64-bit hashes instead of full values - suitable for large datasets
38/// where collision risk is acceptable (sampling, validation).
39pub type PkHashSet = AHashSet<u64>;
40
41/// Hash a PK tuple into a compact 64-bit hash for memory-efficient storage.
42/// Uses AHash for fast, high-quality hashing.
43pub fn hash_pk_tuple(pk: &PkTuple) -> u64 {
44    use std::hash::{Hash, Hasher};
45    let mut hasher = ahash::AHasher::default();
46
47    // Include arity (number of columns) in the hash
48    (pk.len() as u8).hash(&mut hasher);
49
50    for v in pk {
51        match v {
52            PkValue::Int(i) => {
53                0u8.hash(&mut hasher);
54                i.hash(&mut hasher);
55            }
56            PkValue::BigInt(i) => {
57                1u8.hash(&mut hasher);
58                i.hash(&mut hasher);
59            }
60            PkValue::Text(s) => {
61                2u8.hash(&mut hasher);
62                s.hash(&mut hasher);
63            }
64            PkValue::Null => {
65                3u8.hash(&mut hasher);
66            }
67        }
68    }
69
70    hasher.finish()
71}
72
/// Identifies one foreign key of one table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FkRef {
    /// ID of the table that declares the foreign key.
    pub table_id: u32,
    /// Position of the foreign key in that table's `foreign_keys` list.
    pub fk_index: u16,
}
81
82/// A parsed row from an INSERT statement
83#[derive(Debug, Clone)]
84pub struct ParsedRow {
85    /// Raw bytes of the row value list: "(val1, val2, ...)"
86    pub raw: Vec<u8>,
87    /// Extracted primary key values (if table has PK and values are non-NULL)
88    pub pk: Option<PkTuple>,
89    /// Extracted foreign key values with their references
90    /// Only includes FKs where all columns are non-NULL
91    pub fk_values: Vec<(FkRef, PkTuple)>,
92}
93
94/// Parser for MySQL INSERT statements
95pub struct InsertParser<'a> {
96    stmt: &'a [u8],
97    pos: usize,
98    table_schema: Option<&'a TableSchema>,
99    /// Column order in the INSERT (maps value index -> column ID)
100    column_order: Vec<Option<ColumnId>>,
101}
102
103impl<'a> InsertParser<'a> {
104    /// Create a new parser for an INSERT statement
105    pub fn new(stmt: &'a [u8]) -> Self {
106        Self {
107            stmt,
108            pos: 0,
109            table_schema: None,
110            column_order: Vec::new(),
111        }
112    }
113
114    /// Set the table schema for PK/FK extraction
115    pub fn with_schema(mut self, schema: &'a TableSchema) -> Self {
116        self.table_schema = Some(schema);
117        self
118    }
119
120    /// Parse all rows from the INSERT statement
121    pub fn parse_rows(&mut self) -> anyhow::Result<Vec<ParsedRow>> {
122        // Find the VALUES keyword
123        let values_pos = self.find_values_keyword()?;
124        self.pos = values_pos;
125
126        // Parse column list if present
127        self.parse_column_list();
128
129        // Parse each row
130        let mut rows = Vec::new();
131        while self.pos < self.stmt.len() {
132            self.skip_whitespace();
133
134            if self.pos >= self.stmt.len() {
135                break;
136            }
137
138            if self.stmt[self.pos] == b'(' {
139                if let Some(row) = self.parse_row()? {
140                    rows.push(row);
141                }
142            } else if self.stmt[self.pos] == b',' {
143                self.pos += 1;
144            } else if self.stmt[self.pos] == b';' {
145                break;
146            } else {
147                self.pos += 1;
148            }
149        }
150
151        Ok(rows)
152    }
153
154    /// Find the VALUES keyword and return position after it
155    fn find_values_keyword(&self) -> anyhow::Result<usize> {
156        let stmt_str = String::from_utf8_lossy(self.stmt);
157        let upper = stmt_str.to_uppercase();
158
159        if let Some(pos) = upper.find("VALUES") {
160            Ok(pos + 6) // Length of "VALUES"
161        } else {
162            anyhow::bail!("INSERT statement missing VALUES keyword")
163        }
164    }
165
166    /// Parse optional column list after INSERT INTO table_name
167    fn parse_column_list(&mut self) {
168        if self.table_schema.is_none() {
169            return;
170        }
171
172        let schema = self.table_schema.unwrap();
173
174        // Look for column list between table name and VALUES
175        // We need to look backwards from current position (after VALUES)
176        let before_values = &self.stmt[..self.pos.saturating_sub(6)];
177        let stmt_str = String::from_utf8_lossy(before_values);
178
179        // Find the last (...) before VALUES
180        if let Some(close_paren) = stmt_str.rfind(')') {
181            if let Some(open_paren) = stmt_str[..close_paren].rfind('(') {
182                let col_list = &stmt_str[open_paren + 1..close_paren];
183                // Check if this looks like a column list (no VALUES, etc.)
184                if !col_list.to_uppercase().contains("SELECT") {
185                    let cols: Vec<&str> = col_list.split(',').collect();
186                    self.column_order = cols
187                        .iter()
188                        .map(|c| {
189                            let name = c.trim().trim_matches('`').trim_matches('"');
190                            schema.get_column_id(name)
191                        })
192                        .collect();
193                    return;
194                }
195            }
196        }
197
198        // No explicit column list - use natural order
199        self.column_order = schema.columns.iter().map(|c| Some(c.ordinal)).collect();
200    }
201
202    /// Parse a single row "(val1, val2, ...)"
203    fn parse_row(&mut self) -> anyhow::Result<Option<ParsedRow>> {
204        self.skip_whitespace();
205
206        if self.pos >= self.stmt.len() || self.stmt[self.pos] != b'(' {
207            return Ok(None);
208        }
209
210        let start = self.pos;
211        self.pos += 1; // Skip '('
212
213        let mut values: Vec<ParsedValue> = Vec::new();
214        let mut depth = 1;
215
216        while self.pos < self.stmt.len() && depth > 0 {
217            self.skip_whitespace();
218
219            if self.pos >= self.stmt.len() {
220                break;
221            }
222
223            match self.stmt[self.pos] {
224                b'(' => {
225                    depth += 1;
226                    self.pos += 1;
227                }
228                b')' => {
229                    depth -= 1;
230                    self.pos += 1;
231                }
232                b',' if depth == 1 => {
233                    self.pos += 1;
234                }
235                _ if depth == 1 => {
236                    values.push(self.parse_value()?);
237                }
238                _ => {
239                    self.pos += 1;
240                }
241            }
242        }
243
244        let end = self.pos;
245        let raw = self.stmt[start..end].to_vec();
246
247        // Extract PK and FK values if we have a schema
248        let (pk, fk_values) = if let Some(schema) = self.table_schema {
249            self.extract_pk_fk(&values, schema)
250        } else {
251            (None, Vec::new())
252        };
253
254        Ok(Some(ParsedRow { raw, pk, fk_values }))
255    }
256
257    /// Parse a single value (string, number, NULL, etc.)
258    fn parse_value(&mut self) -> anyhow::Result<ParsedValue> {
259        self.skip_whitespace();
260
261        if self.pos >= self.stmt.len() {
262            return Ok(ParsedValue::Null);
263        }
264
265        let b = self.stmt[self.pos];
266
267        // NULL
268        if self.pos + 4 <= self.stmt.len() {
269            let word = &self.stmt[self.pos..self.pos + 4];
270            if word.eq_ignore_ascii_case(b"NULL") {
271                self.pos += 4;
272                return Ok(ParsedValue::Null);
273            }
274        }
275
276        // String literal
277        if b == b'\'' {
278            return self.parse_string_value();
279        }
280
281        // Hex literal (0x...)
282        if b == b'0' && self.pos + 1 < self.stmt.len() {
283            let next = self.stmt[self.pos + 1];
284            if next == b'x' || next == b'X' {
285                return self.parse_hex_value();
286            }
287        }
288
289        // Number or expression
290        self.parse_number_value()
291    }
292
293    /// Parse a string literal 'value'
294    fn parse_string_value(&mut self) -> anyhow::Result<ParsedValue> {
295        self.pos += 1; // Skip opening quote
296
297        let mut value = Vec::new();
298        let mut escape_next = false;
299
300        while self.pos < self.stmt.len() {
301            let b = self.stmt[self.pos];
302
303            if escape_next {
304                // Handle MySQL escape sequences
305                let escaped = match b {
306                    b'n' => b'\n',
307                    b'r' => b'\r',
308                    b't' => b'\t',
309                    b'0' => 0,
310                    _ => b, // \', \\, etc.
311                };
312                value.push(escaped);
313                escape_next = false;
314                self.pos += 1;
315            } else if b == b'\\' {
316                escape_next = true;
317                self.pos += 1;
318            } else if b == b'\'' {
319                // Check for escaped quote ''
320                if self.pos + 1 < self.stmt.len() && self.stmt[self.pos + 1] == b'\'' {
321                    value.push(b'\'');
322                    self.pos += 2;
323                } else {
324                    self.pos += 1; // End of string
325                    break;
326                }
327            } else {
328                value.push(b);
329                self.pos += 1;
330            }
331        }
332
333        let text = String::from_utf8_lossy(&value).into_owned();
334
335        Ok(ParsedValue::String { value: text })
336    }
337
338    /// Parse a hex literal 0xABCD...
339    fn parse_hex_value(&mut self) -> anyhow::Result<ParsedValue> {
340        let start = self.pos;
341        self.pos += 2; // Skip 0x
342
343        while self.pos < self.stmt.len() {
344            let b = self.stmt[self.pos];
345            if b.is_ascii_hexdigit() {
346                self.pos += 1;
347            } else {
348                break;
349            }
350        }
351
352        let raw = self.stmt[start..self.pos].to_vec();
353        Ok(ParsedValue::Hex(raw))
354    }
355
356    /// Parse a number or other non-string value
357    fn parse_number_value(&mut self) -> anyhow::Result<ParsedValue> {
358        let start = self.pos;
359        let mut has_dot = false;
360
361        // Handle leading minus
362        if self.pos < self.stmt.len() && self.stmt[self.pos] == b'-' {
363            self.pos += 1;
364        }
365
366        while self.pos < self.stmt.len() {
367            let b = self.stmt[self.pos];
368            if b.is_ascii_digit() {
369                self.pos += 1;
370            } else if b == b'.' && !has_dot {
371                has_dot = true;
372                self.pos += 1;
373            } else if b == b'e' || b == b'E' {
374                // Scientific notation
375                self.pos += 1;
376                if self.pos < self.stmt.len()
377                    && (self.stmt[self.pos] == b'+' || self.stmt[self.pos] == b'-')
378                {
379                    self.pos += 1;
380                }
381            } else if b == b',' || b == b')' || b.is_ascii_whitespace() {
382                break;
383            } else {
384                // Unknown character in number, skip to next delimiter
385                while self.pos < self.stmt.len() {
386                    let c = self.stmt[self.pos];
387                    if c == b',' || c == b')' {
388                        break;
389                    }
390                    self.pos += 1;
391                }
392                break;
393            }
394        }
395
396        let raw = self.stmt[start..self.pos].to_vec();
397        let value_str = String::from_utf8_lossy(&raw);
398
399        // Try to parse as integer
400        if !has_dot {
401            if let Ok(n) = value_str.parse::<i64>() {
402                return Ok(ParsedValue::Integer(n));
403            }
404            if let Ok(n) = value_str.parse::<i128>() {
405                return Ok(ParsedValue::BigInteger(n));
406            }
407        }
408
409        // Fall back to raw value
410        Ok(ParsedValue::Other(raw))
411    }
412
413    /// Skip whitespace and newlines
414    fn skip_whitespace(&mut self) {
415        while self.pos < self.stmt.len() {
416            let b = self.stmt[self.pos];
417            if b.is_ascii_whitespace() {
418                self.pos += 1;
419            } else {
420                break;
421            }
422        }
423    }
424
425    /// Extract PK and FK values from parsed values
426    fn extract_pk_fk(
427        &self,
428        values: &[ParsedValue],
429        schema: &TableSchema,
430    ) -> (Option<PkTuple>, Vec<(FkRef, PkTuple)>) {
431        let mut pk_values = PkTuple::new();
432        let mut fk_values = Vec::new();
433
434        // Build PK from columns marked as primary key
435        for (idx, col_id_opt) in self.column_order.iter().enumerate() {
436            if let Some(col_id) = col_id_opt {
437                if schema.is_pk_column(*col_id) {
438                    if let Some(value) = values.get(idx) {
439                        let pk_val = self.value_to_pk(value, schema.column(*col_id));
440                        pk_values.push(pk_val);
441                    }
442                }
443            }
444        }
445
446        // Build FK tuples
447        for (fk_idx, fk) in schema.foreign_keys.iter().enumerate() {
448            if fk.referenced_table_id.is_none() {
449                continue;
450            }
451
452            let mut fk_tuple = PkTuple::new();
453            let mut all_non_null = true;
454
455            for &col_id in &fk.columns {
456                // Find the value index for this column
457                if let Some(idx) = self.column_order.iter().position(|&c| c == Some(col_id)) {
458                    if let Some(value) = values.get(idx) {
459                        let pk_val = self.value_to_pk(value, schema.column(col_id));
460                        if pk_val.is_null() {
461                            all_non_null = false;
462                            break;
463                        }
464                        fk_tuple.push(pk_val);
465                    }
466                }
467            }
468
469            if all_non_null && !fk_tuple.is_empty() {
470                fk_values.push((
471                    FkRef {
472                        table_id: schema.id.0,
473                        fk_index: fk_idx as u16,
474                    },
475                    fk_tuple,
476                ));
477            }
478        }
479
480        let pk = if pk_values.is_empty() || pk_values.iter().any(|v| v.is_null()) {
481            None
482        } else {
483            Some(pk_values)
484        };
485
486        (pk, fk_values)
487    }
488
489    /// Convert a parsed value to a PkValue
490    fn value_to_pk(&self, value: &ParsedValue, col: Option<&crate::schema::Column>) -> PkValue {
491        match value {
492            ParsedValue::Null => PkValue::Null,
493            ParsedValue::Integer(n) => PkValue::Int(*n),
494            ParsedValue::BigInteger(n) => PkValue::BigInt(*n),
495            ParsedValue::String { value } => {
496                // Check if this might be an integer stored as string
497                if let Some(col) = col {
498                    match col.col_type {
499                        ColumnType::Int => {
500                            if let Ok(n) = value.parse::<i64>() {
501                                return PkValue::Int(n);
502                            }
503                        }
504                        ColumnType::BigInt => {
505                            if let Ok(n) = value.parse::<i128>() {
506                                return PkValue::BigInt(n);
507                            }
508                        }
509                        _ => {}
510                    }
511                }
512                PkValue::Text(value.clone().into_boxed_str())
513            }
514            ParsedValue::Hex(raw) => {
515                PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
516            }
517            ParsedValue::Other(raw) => {
518                PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
519            }
520        }
521    }
522}
523
/// One value parsed from a row tuple (internal to this module).
#[derive(Debug, Clone)]
enum ParsedValue {
    /// SQL `NULL`.
    Null,
    /// Unquoted integer that fits in `i64`.
    Integer(i64),
    /// Unquoted integer that fits in `i128` but not in `i64`.
    BigInteger(i128),
    /// Quoted string literal with escapes resolved.
    String { value: String },
    /// `0x...` hex literal, kept as its raw source bytes.
    Hex(Vec<u8>),
    /// Any other value (decimals, expressions, ...), kept as raw source bytes.
    Other(Vec<u8>),
}
534
535/// Parse all rows from a MySQL INSERT statement
536pub fn parse_mysql_insert_rows(
537    stmt: &[u8],
538    schema: &TableSchema,
539) -> anyhow::Result<Vec<ParsedRow>> {
540    let mut parser = InsertParser::new(stmt).with_schema(schema);
541    parser.parse_rows()
542}
543
544/// Parse rows without schema (just raw row extraction)
545pub fn parse_mysql_insert_rows_raw(stmt: &[u8]) -> anyhow::Result<Vec<ParsedRow>> {
546    let mut parser = InsertParser::new(stmt);
547    parser.parse_rows()
548}