sql_splitter/parser/
mysql_insert.rs

1//! MySQL INSERT statement row parser.
2//!
3//! Parses INSERT INTO ... VALUES statements to extract individual rows
4//! and optionally extract PK/FK column values for dependency tracking.
5
6use crate::schema::{ColumnId, ColumnType, TableSchema};
7use ahash::AHashSet;
8use smallvec::SmallVec;
9
/// A single key column value, as extracted from an INSERT row.
///
/// Used both for primary-key columns and for the foreign-key columns that
/// refer to them, so values can be compared and hashed directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PkValue {
    /// Integer that fits in 64 bits (the common case for keys)
    Int(i64),
    /// Integer that needs more than 64 bits
    BigInt(i128),
    /// Textual key value
    Text(Box<str>),
    /// SQL NULL (for a foreign key this typically means "no dependency")
    Null,
}

impl PkValue {
    /// Returns `true` only for the [`PkValue::Null`] variant.
    pub fn is_null(&self) -> bool {
        // Spelled out exhaustively so a newly added variant forces a decision here.
        match self {
            PkValue::Null => true,
            PkValue::Int(_) | PkValue::BigInt(_) | PkValue::Text(_) => false,
        }
    }
}
29
/// Tuple of PK values for composite primary keys.
///
/// Backed by a `SmallVec` with inline room for two values, so one- and
/// two-column keys need no separate heap allocation for the tuple itself.
pub type PkTuple = SmallVec<[PkValue; 2]>;
32
/// Set of primary key values for a table.
///
/// Each element is one (possibly composite) key tuple; hashing uses `ahash`.
pub type PkSet = AHashSet<PkTuple>;
35
/// Reference to a specific foreign key in a table.
///
/// Identifies a foreign key by the owning table's numeric id plus the key's
/// position in that table's `foreign_keys` list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FkRef {
    /// Table containing the FK
    pub table_id: u32,
    /// Index of the FK within the table's foreign_keys vector
    pub fk_index: u16,
}
44
/// A parsed row from an INSERT statement.
///
/// `pk` and `fk_values` are only populated when the parser was given a table
/// schema; without one they are `None` and empty respectively.
#[derive(Debug, Clone)]
pub struct ParsedRow {
    /// Raw bytes of the row value list: "(val1, val2, ...)"
    /// (copied verbatim from the statement, parentheses included)
    pub raw: Vec<u8>,
    /// Extracted primary key values (if table has PK and values are non-NULL)
    pub pk: Option<PkTuple>,
    /// Extracted foreign key values with their references
    /// Only includes FKs where all columns are non-NULL
    pub fk_values: Vec<(FkRef, PkTuple)>,
}
56
/// Parser for MySQL INSERT statements.
///
/// Borrows the statement bytes (and, optionally, a table schema) for `'a` and
/// walks them with a byte cursor; nothing is copied until a row is emitted.
pub struct InsertParser<'a> {
    /// The complete INSERT statement being parsed
    stmt: &'a [u8],
    /// Current byte offset of the cursor into `stmt`
    pos: usize,
    /// Schema used for PK/FK extraction; `None` means raw row extraction only
    table_schema: Option<&'a TableSchema>,
    /// Column order in the INSERT (maps value index -> column ID)
    /// An entry is `None` when the schema lookup for that column name returned nothing
    column_order: Vec<Option<ColumnId>>,
}
65
66impl<'a> InsertParser<'a> {
67    /// Create a new parser for an INSERT statement
68    pub fn new(stmt: &'a [u8]) -> Self {
69        Self {
70            stmt,
71            pos: 0,
72            table_schema: None,
73            column_order: Vec::new(),
74        }
75    }
76
77    /// Set the table schema for PK/FK extraction
78    pub fn with_schema(mut self, schema: &'a TableSchema) -> Self {
79        self.table_schema = Some(schema);
80        self
81    }
82
83    /// Parse all rows from the INSERT statement
84    pub fn parse_rows(&mut self) -> anyhow::Result<Vec<ParsedRow>> {
85        // Find the VALUES keyword
86        let values_pos = self.find_values_keyword()?;
87        self.pos = values_pos;
88
89        // Parse column list if present
90        self.parse_column_list();
91
92        // Parse each row
93        let mut rows = Vec::new();
94        while self.pos < self.stmt.len() {
95            self.skip_whitespace();
96
97            if self.pos >= self.stmt.len() {
98                break;
99            }
100
101            if self.stmt[self.pos] == b'(' {
102                if let Some(row) = self.parse_row()? {
103                    rows.push(row);
104                }
105            } else if self.stmt[self.pos] == b',' {
106                self.pos += 1;
107            } else if self.stmt[self.pos] == b';' {
108                break;
109            } else {
110                self.pos += 1;
111            }
112        }
113
114        Ok(rows)
115    }
116
117    /// Find the VALUES keyword and return position after it
118    fn find_values_keyword(&self) -> anyhow::Result<usize> {
119        let stmt_str = String::from_utf8_lossy(self.stmt);
120        let upper = stmt_str.to_uppercase();
121
122        if let Some(pos) = upper.find("VALUES") {
123            Ok(pos + 6) // Length of "VALUES"
124        } else {
125            anyhow::bail!("INSERT statement missing VALUES keyword")
126        }
127    }
128
129    /// Parse optional column list after INSERT INTO table_name
130    fn parse_column_list(&mut self) {
131        if self.table_schema.is_none() {
132            return;
133        }
134
135        let schema = self.table_schema.unwrap();
136
137        // Look for column list between table name and VALUES
138        // We need to look backwards from current position (after VALUES)
139        let before_values = &self.stmt[..self.pos.saturating_sub(6)];
140        let stmt_str = String::from_utf8_lossy(before_values);
141
142        // Find the last (...) before VALUES
143        if let Some(close_paren) = stmt_str.rfind(')') {
144            if let Some(open_paren) = stmt_str[..close_paren].rfind('(') {
145                let col_list = &stmt_str[open_paren + 1..close_paren];
146                // Check if this looks like a column list (no VALUES, etc.)
147                if !col_list.to_uppercase().contains("SELECT") {
148                    let cols: Vec<&str> = col_list.split(',').collect();
149                    self.column_order = cols
150                        .iter()
151                        .map(|c| {
152                            let name = c.trim().trim_matches('`').trim_matches('"');
153                            schema.get_column_id(name)
154                        })
155                        .collect();
156                    return;
157                }
158            }
159        }
160
161        // No explicit column list - use natural order
162        self.column_order = schema.columns.iter().map(|c| Some(c.ordinal)).collect();
163    }
164
165    /// Parse a single row "(val1, val2, ...)"
166    fn parse_row(&mut self) -> anyhow::Result<Option<ParsedRow>> {
167        self.skip_whitespace();
168
169        if self.pos >= self.stmt.len() || self.stmt[self.pos] != b'(' {
170            return Ok(None);
171        }
172
173        let start = self.pos;
174        self.pos += 1; // Skip '('
175
176        let mut values: Vec<ParsedValue> = Vec::new();
177        let mut depth = 1;
178
179        while self.pos < self.stmt.len() && depth > 0 {
180            self.skip_whitespace();
181
182            if self.pos >= self.stmt.len() {
183                break;
184            }
185
186            match self.stmt[self.pos] {
187                b'(' => {
188                    depth += 1;
189                    self.pos += 1;
190                }
191                b')' => {
192                    depth -= 1;
193                    self.pos += 1;
194                }
195                b',' if depth == 1 => {
196                    self.pos += 1;
197                }
198                _ if depth == 1 => {
199                    values.push(self.parse_value()?);
200                }
201                _ => {
202                    self.pos += 1;
203                }
204            }
205        }
206
207        let end = self.pos;
208        let raw = self.stmt[start..end].to_vec();
209
210        // Extract PK and FK values if we have a schema
211        let (pk, fk_values) = if let Some(schema) = self.table_schema {
212            self.extract_pk_fk(&values, schema)
213        } else {
214            (None, Vec::new())
215        };
216
217        Ok(Some(ParsedRow { raw, pk, fk_values }))
218    }
219
220    /// Parse a single value (string, number, NULL, etc.)
221    fn parse_value(&mut self) -> anyhow::Result<ParsedValue> {
222        self.skip_whitespace();
223
224        if self.pos >= self.stmt.len() {
225            return Ok(ParsedValue::Null);
226        }
227
228        let b = self.stmt[self.pos];
229
230        // NULL
231        if self.pos + 4 <= self.stmt.len() {
232            let word = &self.stmt[self.pos..self.pos + 4];
233            if word.eq_ignore_ascii_case(b"NULL") {
234                self.pos += 4;
235                return Ok(ParsedValue::Null);
236            }
237        }
238
239        // String literal
240        if b == b'\'' {
241            return self.parse_string_value();
242        }
243
244        // Hex literal (0x...)
245        if b == b'0' && self.pos + 1 < self.stmt.len() {
246            let next = self.stmt[self.pos + 1];
247            if next == b'x' || next == b'X' {
248                return self.parse_hex_value();
249            }
250        }
251
252        // Number or expression
253        self.parse_number_value()
254    }
255
256    /// Parse a string literal 'value'
257    fn parse_string_value(&mut self) -> anyhow::Result<ParsedValue> {
258        self.pos += 1; // Skip opening quote
259
260        let mut value = Vec::new();
261        let mut escape_next = false;
262
263        while self.pos < self.stmt.len() {
264            let b = self.stmt[self.pos];
265
266            if escape_next {
267                // Handle MySQL escape sequences
268                let escaped = match b {
269                    b'n' => b'\n',
270                    b'r' => b'\r',
271                    b't' => b'\t',
272                    b'0' => 0,
273                    _ => b, // \', \\, etc.
274                };
275                value.push(escaped);
276                escape_next = false;
277                self.pos += 1;
278            } else if b == b'\\' {
279                escape_next = true;
280                self.pos += 1;
281            } else if b == b'\'' {
282                // Check for escaped quote ''
283                if self.pos + 1 < self.stmt.len() && self.stmt[self.pos + 1] == b'\'' {
284                    value.push(b'\'');
285                    self.pos += 2;
286                } else {
287                    self.pos += 1; // End of string
288                    break;
289                }
290            } else {
291                value.push(b);
292                self.pos += 1;
293            }
294        }
295
296        let text = String::from_utf8_lossy(&value).into_owned();
297
298        Ok(ParsedValue::String { value: text })
299    }
300
301    /// Parse a hex literal 0xABCD...
302    fn parse_hex_value(&mut self) -> anyhow::Result<ParsedValue> {
303        let start = self.pos;
304        self.pos += 2; // Skip 0x
305
306        while self.pos < self.stmt.len() {
307            let b = self.stmt[self.pos];
308            if b.is_ascii_hexdigit() {
309                self.pos += 1;
310            } else {
311                break;
312            }
313        }
314
315        let raw = self.stmt[start..self.pos].to_vec();
316        Ok(ParsedValue::Hex(raw))
317    }
318
319    /// Parse a number or other non-string value
320    fn parse_number_value(&mut self) -> anyhow::Result<ParsedValue> {
321        let start = self.pos;
322        let mut has_dot = false;
323
324        // Handle leading minus
325        if self.pos < self.stmt.len() && self.stmt[self.pos] == b'-' {
326            self.pos += 1;
327        }
328
329        while self.pos < self.stmt.len() {
330            let b = self.stmt[self.pos];
331            if b.is_ascii_digit() {
332                self.pos += 1;
333            } else if b == b'.' && !has_dot {
334                has_dot = true;
335                self.pos += 1;
336            } else if b == b'e' || b == b'E' {
337                // Scientific notation
338                self.pos += 1;
339                if self.pos < self.stmt.len()
340                    && (self.stmt[self.pos] == b'+' || self.stmt[self.pos] == b'-')
341                {
342                    self.pos += 1;
343                }
344            } else if b == b',' || b == b')' || b.is_ascii_whitespace() {
345                break;
346            } else {
347                // Unknown character in number, skip to next delimiter
348                while self.pos < self.stmt.len() {
349                    let c = self.stmt[self.pos];
350                    if c == b',' || c == b')' {
351                        break;
352                    }
353                    self.pos += 1;
354                }
355                break;
356            }
357        }
358
359        let raw = self.stmt[start..self.pos].to_vec();
360        let value_str = String::from_utf8_lossy(&raw);
361
362        // Try to parse as integer
363        if !has_dot {
364            if let Ok(n) = value_str.parse::<i64>() {
365                return Ok(ParsedValue::Integer(n));
366            }
367            if let Ok(n) = value_str.parse::<i128>() {
368                return Ok(ParsedValue::BigInteger(n));
369            }
370        }
371
372        // Fall back to raw value
373        Ok(ParsedValue::Other(raw))
374    }
375
376    /// Skip whitespace and newlines
377    fn skip_whitespace(&mut self) {
378        while self.pos < self.stmt.len() {
379            let b = self.stmt[self.pos];
380            if b.is_ascii_whitespace() {
381                self.pos += 1;
382            } else {
383                break;
384            }
385        }
386    }
387
388    /// Extract PK and FK values from parsed values
389    fn extract_pk_fk(
390        &self,
391        values: &[ParsedValue],
392        schema: &TableSchema,
393    ) -> (Option<PkTuple>, Vec<(FkRef, PkTuple)>) {
394        let mut pk_values = PkTuple::new();
395        let mut fk_values = Vec::new();
396
397        // Build PK from columns marked as primary key
398        for (idx, col_id_opt) in self.column_order.iter().enumerate() {
399            if let Some(col_id) = col_id_opt {
400                if schema.is_pk_column(*col_id) {
401                    if let Some(value) = values.get(idx) {
402                        let pk_val = self.value_to_pk(value, schema.column(*col_id));
403                        pk_values.push(pk_val);
404                    }
405                }
406            }
407        }
408
409        // Build FK tuples
410        for (fk_idx, fk) in schema.foreign_keys.iter().enumerate() {
411            if fk.referenced_table_id.is_none() {
412                continue;
413            }
414
415            let mut fk_tuple = PkTuple::new();
416            let mut all_non_null = true;
417
418            for &col_id in &fk.columns {
419                // Find the value index for this column
420                if let Some(idx) = self.column_order.iter().position(|&c| c == Some(col_id)) {
421                    if let Some(value) = values.get(idx) {
422                        let pk_val = self.value_to_pk(value, schema.column(col_id));
423                        if pk_val.is_null() {
424                            all_non_null = false;
425                            break;
426                        }
427                        fk_tuple.push(pk_val);
428                    }
429                }
430            }
431
432            if all_non_null && !fk_tuple.is_empty() {
433                fk_values.push((
434                    FkRef {
435                        table_id: schema.id.0,
436                        fk_index: fk_idx as u16,
437                    },
438                    fk_tuple,
439                ));
440            }
441        }
442
443        let pk = if pk_values.is_empty() || pk_values.iter().any(|v| v.is_null()) {
444            None
445        } else {
446            Some(pk_values)
447        };
448
449        (pk, fk_values)
450    }
451
452    /// Convert a parsed value to a PkValue
453    fn value_to_pk(&self, value: &ParsedValue, col: Option<&crate::schema::Column>) -> PkValue {
454        match value {
455            ParsedValue::Null => PkValue::Null,
456            ParsedValue::Integer(n) => PkValue::Int(*n),
457            ParsedValue::BigInteger(n) => PkValue::BigInt(*n),
458            ParsedValue::String { value } => {
459                // Check if this might be an integer stored as string
460                if let Some(col) = col {
461                    match col.col_type {
462                        ColumnType::Int => {
463                            if let Ok(n) = value.parse::<i64>() {
464                                return PkValue::Int(n);
465                            }
466                        }
467                        ColumnType::BigInt => {
468                            if let Ok(n) = value.parse::<i128>() {
469                                return PkValue::BigInt(n);
470                            }
471                        }
472                        _ => {}
473                    }
474                }
475                PkValue::Text(value.clone().into_boxed_str())
476            }
477            ParsedValue::Hex(raw) => {
478                PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
479            }
480            ParsedValue::Other(raw) => {
481                PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
482            }
483        }
484    }
485}
486
/// Internal representation of a parsed value
#[derive(Debug, Clone)]
enum ParsedValue {
    /// The `NULL` keyword (also produced when input ends where a value was expected)
    Null,
    /// Integer literal that fits in `i64`
    Integer(i64),
    /// Integer literal that only fits in `i128`
    BigInteger(i128),
    /// Contents of a single-quoted string literal, with escapes decoded
    String { value: String },
    /// Raw bytes of a hex literal, including the `0x` prefix
    Hex(Vec<u8>),
    /// Raw bytes of any other value (decimals, expressions, ...)
    Other(Vec<u8>),
}
497
498/// Parse all rows from a MySQL INSERT statement
499pub fn parse_mysql_insert_rows(
500    stmt: &[u8],
501    schema: &TableSchema,
502) -> anyhow::Result<Vec<ParsedRow>> {
503    let mut parser = InsertParser::new(stmt).with_schema(schema);
504    parser.parse_rows()
505}
506
507/// Parse rows without schema (just raw row extraction)
508pub fn parse_mysql_insert_rows_raw(stmt: &[u8]) -> anyhow::Result<Vec<ParsedRow>> {
509    let mut parser = InsertParser::new(stmt);
510    parser.parse_rows()
511}