sql_splitter/parser/
mod.rs

1pub mod mysql_insert;
2pub mod postgres_copy;
3
4// Re-export types for bulk loading
5pub use mysql_insert::{parse_insert_for_bulk, ParsedValue};
6
7use once_cell::sync::Lazy;
8use regex::bytes::Regex;
9use std::io::{BufRead, BufReader, Read};
10
/// Small read-buffer preset: 64 KiB.
pub const SMALL_BUFFER_SIZE: usize = 1 << 16;
/// Medium read-buffer preset: 256 KiB.
pub const MEDIUM_BUFFER_SIZE: usize = 1 << 18;
13
/// SQL dialect that selects quoting, escaping and batching rules in the parser.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlDialect {
    /// MySQL/MariaDB `mysqldump` output: backtick identifiers, backslash escapes.
    #[default]
    MySql,
    /// PostgreSQL `pg_dump` output: double-quoted identifiers, `COPY ... FROM stdin`,
    /// dollar-quoted bodies.
    Postgres,
    /// SQLite `.dump` output: double-quoted identifiers, `''` escapes.
    Sqlite,
    /// Microsoft SQL Server / T-SQL: `[bracket]` identifiers, `GO` batches, `N'...'` strings.
    Mssql,
}

impl std::str::FromStr for SqlDialect {
    type Err = String;

    /// Parses a dialect name case-insensitively, accepting common aliases
    /// such as `mariadb`, `pg`, `sqlite3` and `tsql`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let dialect = match s.to_lowercase().as_str() {
            "mysql" | "mariadb" => SqlDialect::MySql,
            "postgres" | "postgresql" | "pg" => SqlDialect::Postgres,
            "sqlite" | "sqlite3" => SqlDialect::Sqlite,
            "mssql" | "sqlserver" | "sql_server" | "tsql" => SqlDialect::Mssql,
            _ => {
                return Err(format!(
                    "Unknown dialect: {}. Valid options: mysql, postgres, sqlite, mssql",
                    s
                ))
            }
        };
        Ok(dialect)
    }
}

impl std::fmt::Display for SqlDialect {
    /// Writes the canonical lowercase name; it parses back via `FromStr`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SqlDialect::MySql => "mysql",
            SqlDialect::Postgres => "postgres",
            SqlDialect::Sqlite => "sqlite",
            SqlDialect::Mssql => "mssql",
        };
        f.write_str(name)
    }
}
55
56/// Result of dialect auto-detection
57#[derive(Debug, Clone)]
58pub struct DialectDetectionResult {
59    pub dialect: SqlDialect,
60    pub confidence: DialectConfidence,
61}
62
/// How strongly the detected dialect is backed by markers found in the header.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DialectConfidence {
    /// Winning score of 10 or more, e.g. a dump banner such as "pg_dump" or "MySQL dump".
    High,
    /// Winning score from 5 to 9: only likely markers were found.
    Medium,
    /// Winning score below 5, including the no-marker fallback to MySQL.
    Low,
}
73
/// Running marker tallies, one per dialect, used by `detect_dialect`.
/// Every counter starts at zero.
#[derive(Default)]
struct DialectScore {
    /// Points for MySQL/MariaDB markers.
    mysql: u32,
    /// Points for PostgreSQL markers.
    postgres: u32,
    /// Points for SQLite markers.
    sqlite: u32,
    /// Points for MSSQL/T-SQL markers.
    mssql: u32,
}
81
82/// Detect SQL dialect from file header content.
83/// Reads up to 8KB and looks for dialect-specific markers.
84pub fn detect_dialect(header: &[u8]) -> DialectDetectionResult {
85    let mut score = DialectScore::default();
86
87    // High confidence markers (+10)
88    if contains_bytes(header, b"pg_dump") {
89        score.postgres += 10;
90    }
91    if contains_bytes(header, b"PostgreSQL database dump") {
92        score.postgres += 10;
93    }
94    if contains_bytes(header, b"MySQL dump") {
95        score.mysql += 10;
96    }
97    if contains_bytes(header, b"MariaDB dump") {
98        score.mysql += 10;
99    }
100    if contains_bytes(header, b"SQLite") {
101        score.sqlite += 10;
102    }
103
104    // Medium confidence markers (+5)
105    if contains_bytes(header, b"COPY ") && contains_bytes(header, b"FROM stdin") {
106        score.postgres += 5;
107    }
108    if contains_bytes(header, b"search_path") {
109        score.postgres += 5;
110    }
111    if contains_bytes(header, b"/*!40") || contains_bytes(header, b"/*!50") {
112        score.mysql += 5;
113    }
114    if contains_bytes(header, b"LOCK TABLES") {
115        score.mysql += 5;
116    }
117    if contains_bytes(header, b"PRAGMA") {
118        score.sqlite += 5;
119    }
120
121    // Low confidence markers (+2)
122    if contains_bytes(header, b"$$") {
123        score.postgres += 2;
124    }
125    if contains_bytes(header, b"CREATE EXTENSION") {
126        score.postgres += 2;
127    }
128    // BEGIN TRANSACTION is generic ANSI SQL, only slightly suggests SQLite
129    if contains_bytes(header, b"BEGIN TRANSACTION") {
130        score.sqlite += 2;
131    }
132    // Backticks suggest MySQL (could also appear in data/comments)
133    if header.contains(&b'`') {
134        score.mysql += 2;
135    }
136
137    // MSSQL/T-SQL markers
138    // High confidence markers (+20)
139    if contains_bytes(header, b"SET ANSI_NULLS") {
140        score.mssql += 20;
141    }
142    if contains_bytes(header, b"SET QUOTED_IDENTIFIER") {
143        score.mssql += 20;
144    }
145
146    // Medium confidence markers (+10-15)
147    // GO as batch separator on its own line (check for common patterns)
148    if contains_bytes(header, b"\nGO\n") || contains_bytes(header, b"\nGO\r\n") {
149        score.mssql += 15;
150    }
151    // Square bracket identifiers
152    if header.contains(&b'[') && header.contains(&b']') {
153        score.mssql += 10;
154    }
155    if contains_bytes(header, b"IDENTITY(") {
156        score.mssql += 10;
157    }
158    if contains_bytes(header, b"ON [PRIMARY]") {
159        score.mssql += 10;
160    }
161
162    // Low confidence markers (+5)
163    if contains_bytes(header, b"N'") {
164        score.mssql += 5;
165    }
166    if contains_bytes(header, b"NVARCHAR") {
167        score.mssql += 5;
168    }
169    if contains_bytes(header, b"CLUSTERED") {
170        score.mssql += 5;
171    }
172    if contains_bytes(header, b"SET NOCOUNT") {
173        score.mssql += 5;
174    }
175
176    // Determine winner and confidence
177    let max_score = score
178        .mysql
179        .max(score.postgres)
180        .max(score.sqlite)
181        .max(score.mssql);
182
183    if max_score == 0 {
184        return DialectDetectionResult {
185            dialect: SqlDialect::MySql,
186            confidence: DialectConfidence::Low,
187        };
188    }
189
190    // Find the dialect with the highest score
191    let (dialect, winning_score) = if score.mssql > score.mysql
192        && score.mssql > score.postgres
193        && score.mssql > score.sqlite
194    {
195        (SqlDialect::Mssql, score.mssql)
196    } else if score.postgres > score.mysql && score.postgres > score.sqlite {
197        (SqlDialect::Postgres, score.postgres)
198    } else if score.sqlite > score.mysql {
199        (SqlDialect::Sqlite, score.sqlite)
200    } else {
201        (SqlDialect::MySql, score.mysql)
202    };
203
204    // Determine confidence based on winning score
205    let confidence = if winning_score >= 10 {
206        DialectConfidence::High
207    } else if winning_score >= 5 {
208        DialectConfidence::Medium
209    } else {
210        DialectConfidence::Low
211    };
212
213    DialectDetectionResult {
214        dialect,
215        confidence,
216    }
217}
218
219/// Detect dialect from a file, reading first 8KB
220pub fn detect_dialect_from_file(path: &std::path::Path) -> std::io::Result<DialectDetectionResult> {
221    use std::fs::File;
222    use std::io::Read;
223
224    let mut file = File::open(path)?;
225    let mut buf = [0u8; 8192];
226    let n = file.read(&mut buf)?;
227    Ok(detect_dialect(&buf[..n]))
228}
229
/// Returns `true` if `needle` occurs anywhere in `haystack`.
///
/// An empty needle always matches, mirroring `str::contains("")`. It is
/// handled up front because `slice::windows(0)` panics.
#[inline]
fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    if needle.is_empty() {
        return true;
    }
    haystack
        .windows(needle.len())
        .any(|window| window == needle)
}
236
/// Returns `true` if `line` is a T-SQL `GO` batch separator.
///
/// The keyword is matched case-insensitively. Leading spaces, tabs and
/// carriage returns are ignored, as is any trailing whitespace including the
/// line's own newline. An optional repeat count may follow after a space or
/// tab. Examples: "GO\n", "  GO  \n", "GO 100\n", "go\r\n".
fn is_go_line(line: &[u8]) -> bool {
    let leading = line
        .iter()
        .take_while(|&&c| matches!(c, b' ' | b'\t' | b'\r'))
        .count();
    let line = &line[leading..];

    let trailing = line
        .iter()
        .rev()
        .take_while(|&&c| matches!(c, b' ' | b'\t' | b'\r' | b'\n'))
        .count();
    let line = &line[..line.len() - trailing];

    let is_go = |g: u8, o: u8| g.eq_ignore_ascii_case(&b'g') && o.eq_ignore_ascii_case(&b'o');

    match *line {
        // Bare keyword.
        [g, o] => is_go(g, o),
        // Keyword, separator, then an optional (whitespace-padded) digit count.
        [g, o, sep, ref count @ ..] if is_go(g, o) && (sep == b' ' || sep == b'\t') => count
            .iter()
            .skip_while(|&&c| c == b' ' || c == b'\t')
            .all(u8::is_ascii_digit),
        _ => false,
    }
}
289
/// Kind of SQL statement recognised by the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// Anything the parser does not classify.
    Unknown,
    CreateTable,
    Insert,
    CreateIndex,
    AlterTable,
    DropTable,
    /// PostgreSQL COPY FROM stdin
    Copy,
}

impl StatementType {
    /// True for schema (DDL) statements: CREATE TABLE, CREATE INDEX,
    /// ALTER TABLE and DROP TABLE.
    pub fn is_schema(&self) -> bool {
        match self {
            StatementType::CreateTable
            | StatementType::CreateIndex
            | StatementType::AlterTable
            | StatementType::DropTable => true,
            StatementType::Unknown | StatementType::Insert | StatementType::Copy => false,
        }
    }

    /// True for data (DML) statements: INSERT and PostgreSQL COPY.
    pub fn is_data(&self) -> bool {
        match self {
            StatementType::Insert | StatementType::Copy => true,
            StatementType::Unknown
            | StatementType::CreateTable
            | StatementType::CreateIndex
            | StatementType::AlterTable
            | StatementType::DropTable => false,
        }
    }
}
319
/// Which statement categories a split keeps.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContentFilter {
    /// Keep schema and data statements alike (the default).
    #[default]
    All,
    /// Keep only schema statements: CREATE TABLE, CREATE INDEX, ALTER TABLE, DROP TABLE.
    SchemaOnly,
    /// Keep only data statements: INSERT and COPY.
    DataOnly,
}
331
332static CREATE_TABLE_RE: Lazy<Regex> =
333    Lazy::new(|| Regex::new(r"(?i)^\s*CREATE\s+TABLE\s+`?([^\s`(]+)`?").unwrap());
334
335static INSERT_INTO_RE: Lazy<Regex> =
336    Lazy::new(|| Regex::new(r"(?i)^\s*INSERT\s+INTO\s+`?([^\s`(]+)`?").unwrap());
337
338static CREATE_INDEX_RE: Lazy<Regex> =
339    Lazy::new(|| Regex::new(r"(?i)ON\s+`?([^\s`(;]+)`?").unwrap());
340
341// MSSQL CREATE INDEX: extracts table from ON [schema].[table] or ON [table]
342// Matches: ON [table], ON [dbo].[table], ON [db].[dbo].[table]
343// Captures the last bracketed or unbracketed identifier before (
344static CREATE_INDEX_MSSQL_RE: Lazy<Regex> =
345    Lazy::new(|| Regex::new(r"(?i)ON\s+(?:\[?[^\[\]\s]+\]?\s*\.\s*)*\[([^\[\]]+)\]").unwrap());
346
347static ALTER_TABLE_RE: Lazy<Regex> =
348    Lazy::new(|| Regex::new(r"(?i)ALTER\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
349
350static DROP_TABLE_RE: Lazy<Regex> = Lazy::new(|| {
351    Regex::new(r#"(?i)DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?[`"]?([^\s`"`;]+)[`"]?"#).unwrap()
352});
353
354// PostgreSQL COPY statement regex
355static COPY_RE: Lazy<Regex> =
356    Lazy::new(|| Regex::new(r#"(?i)^\s*COPY\s+(?:ONLY\s+)?[`"]?([^\s`"(]+)[`"]?"#).unwrap());
357
358// More flexible table name regex that handles:
359// - Backticks: `table`
360// - Double quotes: "table"
361// - Schema qualified: schema.table, `schema`.`table`, "schema"."table"
362// - IF NOT EXISTS
363static CREATE_TABLE_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
364    Regex::new(r#"(?i)^\s*CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
365});
366
367static INSERT_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
368    Regex::new(
369        r#"(?i)^\s*INSERT\s+INTO\s+(?:ONLY\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#,
370    )
371    .unwrap()
372});
373
374pub struct Parser<R: Read> {
375    reader: BufReader<R>,
376    stmt_buffer: Vec<u8>,
377    dialect: SqlDialect,
378    /// For PostgreSQL: true when reading COPY data block
379    in_copy_data: bool,
380}
381
382impl<R: Read> Parser<R> {
383    #[allow(dead_code)]
384    pub fn new(reader: R, buffer_size: usize) -> Self {
385        Self::with_dialect(reader, buffer_size, SqlDialect::default())
386    }
387
388    pub fn with_dialect(reader: R, buffer_size: usize, dialect: SqlDialect) -> Self {
389        Self {
390            reader: BufReader::with_capacity(buffer_size, reader),
391            stmt_buffer: Vec::with_capacity(32 * 1024),
392            dialect,
393            in_copy_data: false,
394        }
395    }
396
397    pub fn read_statement(&mut self) -> std::io::Result<Option<Vec<u8>>> {
398        // If we're in PostgreSQL COPY data mode, read until we see the terminator
399        if self.in_copy_data {
400            return self.read_copy_data();
401        }
402
403        // For MSSQL, use line-based parsing to handle GO batch separator
404        if self.dialect == SqlDialect::Mssql {
405            return self.read_statement_mssql();
406        }
407
408        self.stmt_buffer.clear();
409
410        let mut inside_single_quote = false;
411        let mut inside_double_quote = false;
412        let mut escaped = false;
413        let mut in_line_comment = false;
414        // For PostgreSQL dollar-quoting: track the tag
415        let mut in_dollar_quote = false;
416        let mut dollar_tag: Vec<u8> = Vec::new();
417
418        loop {
419            let buf = self.reader.fill_buf()?;
420            if buf.is_empty() {
421                if self.stmt_buffer.is_empty() {
422                    return Ok(None);
423                }
424                let result = std::mem::take(&mut self.stmt_buffer);
425                return Ok(Some(result));
426            }
427
428            let mut consumed = 0;
429            let mut found_terminator = false;
430
431            for (i, &b) in buf.iter().enumerate() {
432                let inside_string = inside_single_quote || inside_double_quote || in_dollar_quote;
433
434                // End of line comment on newline
435                if in_line_comment {
436                    if b == b'\n' {
437                        in_line_comment = false;
438                    }
439                    continue;
440                }
441
442                if escaped {
443                    escaped = false;
444                    continue;
445                }
446
447                // Handle backslash escapes (MySQL style)
448                if b == b'\\' && inside_string && self.dialect == SqlDialect::MySql {
449                    escaped = true;
450                    continue;
451                }
452
453                // Handle line comments (-- to end of line)
454                if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
455                    in_line_comment = true;
456                    continue;
457                }
458
459                // Handle dollar-quoting for PostgreSQL
460                if self.dialect == SqlDialect::Postgres
461                    && !inside_single_quote
462                    && !inside_double_quote
463                {
464                    if b == b'$' && !in_dollar_quote {
465                        // Start of dollar-quote: scan for the closing $
466                        if let Some(end) = buf[i + 1..].iter().position(|&c| c == b'$') {
467                            let tag_bytes = &buf[i + 1..i + 1 + end];
468
469                            // Validate tag: must be empty OR identifier-like [A-Za-z_][A-Za-z0-9_]*
470                            let is_valid_tag = if tag_bytes.is_empty() {
471                                true
472                            } else {
473                                let mut iter = tag_bytes.iter();
474                                match iter.next() {
475                                    Some(&first)
476                                        if first.is_ascii_alphabetic() || first == b'_' =>
477                                    {
478                                        iter.all(|&c| c.is_ascii_alphanumeric() || c == b'_')
479                                    }
480                                    _ => false,
481                                }
482                            };
483
484                            if is_valid_tag {
485                                dollar_tag = tag_bytes.to_vec();
486                                in_dollar_quote = true;
487                                continue;
488                            }
489                            // Invalid tag - treat $ as normal character
490                        }
491                    } else if b == b'$' && in_dollar_quote {
492                        // Potential end of dollar-quote
493                        let tag_len = dollar_tag.len();
494                        if i + 1 + tag_len < buf.len()
495                            && buf[i + 1..i + 1 + tag_len] == dollar_tag[..]
496                            && buf.get(i + 1 + tag_len) == Some(&b'$')
497                        {
498                            in_dollar_quote = false;
499                            dollar_tag.clear();
500                            continue;
501                        }
502                    }
503                }
504
505                if b == b'\'' && !inside_double_quote && !in_dollar_quote {
506                    inside_single_quote = !inside_single_quote;
507                } else if b == b'"' && !inside_single_quote && !in_dollar_quote {
508                    inside_double_quote = !inside_double_quote;
509                } else if b == b';' && !inside_string {
510                    self.stmt_buffer.extend_from_slice(&buf[..=i]);
511                    consumed = i + 1;
512                    found_terminator = true;
513                    break;
514                }
515            }
516
517            if found_terminator {
518                self.reader.consume(consumed);
519                let result = std::mem::take(&mut self.stmt_buffer);
520
521                // Check if this is a PostgreSQL COPY FROM stdin statement
522                if self.dialect == SqlDialect::Postgres && self.is_copy_from_stdin(&result) {
523                    self.in_copy_data = true;
524                }
525
526                return Ok(Some(result));
527            }
528
529            self.stmt_buffer.extend_from_slice(buf);
530            let len = buf.len();
531            self.reader.consume(len);
532        }
533    }
534
535    /// Check if statement is a PostgreSQL COPY FROM stdin
536    fn is_copy_from_stdin(&self, stmt: &[u8]) -> bool {
537        // Strip leading comments (pg_dump adds -- comments before COPY statements)
538        let stmt = strip_leading_comments_and_whitespace(stmt);
539        if stmt.len() < 4 {
540            return false;
541        }
542
543        // Take enough bytes to cover column lists - typical COPY statements are <500 bytes
544        let upper: Vec<u8> = stmt
545            .iter()
546            .take(500)
547            .map(|b| b.to_ascii_uppercase())
548            .collect();
549        upper.starts_with(b"COPY ")
550            && (upper.windows(10).any(|w| w == b"FROM STDIN")
551                || upper.windows(11).any(|w| w == b"FROM STDIN;"))
552    }
553
554    /// Read PostgreSQL COPY data block until we see the terminator line (\.)
555    fn read_copy_data(&mut self) -> std::io::Result<Option<Vec<u8>>> {
556        self.stmt_buffer.clear();
557
558        loop {
559            // First, fill the buffer and check if empty
560            let buf = self.reader.fill_buf()?;
561            if buf.is_empty() {
562                self.in_copy_data = false;
563                if self.stmt_buffer.is_empty() {
564                    return Ok(None);
565                }
566                return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
567            }
568
569            // Look for a newline in the buffer
570            let newline_pos = buf.iter().position(|&b| b == b'\n');
571
572            if let Some(i) = newline_pos {
573                // Include this newline
574                self.stmt_buffer.extend_from_slice(&buf[..=i]);
575                self.reader.consume(i + 1);
576
577                // Check if the line we just added ends the COPY block
578                // Looking for a line that is just "\.\n" or "\.\r\n"
579                if self.ends_with_copy_terminator() {
580                    self.in_copy_data = false;
581                    return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
582                }
583                // Continue reading - we need to process more lines
584            } else {
585                // No newline found, consume the whole buffer and continue
586                let len = buf.len();
587                self.stmt_buffer.extend_from_slice(buf);
588                self.reader.consume(len);
589            }
590        }
591    }
592
593    /// Check if buffer ends with the COPY terminator line (\.)
594    fn ends_with_copy_terminator(&self) -> bool {
595        let data = &self.stmt_buffer;
596        if data.len() < 2 {
597            return false;
598        }
599
600        // Look for a line that is just "\.\n" or "\.\r\n"
601        // We need to find the start of the last line
602        let last_newline = data[..data.len() - 1]
603            .iter()
604            .rposition(|&b| b == b'\n')
605            .map(|i| i + 1)
606            .unwrap_or(0);
607
608        let last_line = &data[last_newline..];
609
610        // Check if it's "\.\n" or "\.\r\n"
611        last_line == b"\\.\n" || last_line == b"\\.\r\n"
612    }
613
614    /// Read MSSQL statement with GO batch separator support
615    /// GO is a batch separator that appears on its own line
616    fn read_statement_mssql(&mut self) -> std::io::Result<Option<Vec<u8>>> {
617        self.stmt_buffer.clear();
618
619        let mut inside_single_quote = false;
620        let mut inside_bracket_quote = false;
621        let mut in_line_comment = false;
622        let mut line_start = 0usize; // Track where current line started in stmt_buffer
623
624        loop {
625            let buf = self.reader.fill_buf()?;
626            if buf.is_empty() {
627                if self.stmt_buffer.is_empty() {
628                    return Ok(None);
629                }
630                let result = std::mem::take(&mut self.stmt_buffer);
631                return Ok(Some(result));
632            }
633
634            let mut consumed = 0;
635            let mut found_terminator = false;
636
637            for (i, &b) in buf.iter().enumerate() {
638                let inside_string = inside_single_quote || inside_bracket_quote;
639
640                // End of line comment on newline
641                if in_line_comment {
642                    if b == b'\n' {
643                        in_line_comment = false;
644                        // Add to buffer and update line_start
645                        self.stmt_buffer.extend_from_slice(&buf[consumed..=i]);
646                        consumed = i + 1;
647                        line_start = self.stmt_buffer.len();
648                    }
649                    continue;
650                }
651
652                // Handle line comments (-- to end of line)
653                if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
654                    in_line_comment = true;
655                    continue;
656                }
657
658                // Handle N'...' unicode strings - treat N as prefix, ' as quote start
659                // (The N is just a prefix, single quote handling is the same)
660
661                // Handle string quotes
662                if b == b'\'' && !inside_bracket_quote {
663                    inside_single_quote = !inside_single_quote;
664                } else if b == b'[' && !inside_single_quote {
665                    inside_bracket_quote = true;
666                } else if b == b']' && inside_bracket_quote {
667                    // Check for escaped ]]
668                    if i + 1 < buf.len() && buf[i + 1] == b']' {
669                        // Skip the escape sequence - consume one extra ]
670                        continue;
671                    }
672                    inside_bracket_quote = false;
673                } else if b == b';' && !inside_string {
674                    // Semicolon is a statement terminator in MSSQL too
675                    self.stmt_buffer.extend_from_slice(&buf[consumed..=i]);
676                    consumed = i + 1;
677                    found_terminator = true;
678                    break;
679                } else if b == b'\n' && !inside_string {
680                    // Check if the current line (from line_start to here) is just "GO"
681                    // First, add bytes up to and including the newline
682                    self.stmt_buffer.extend_from_slice(&buf[consumed..=i]);
683                    consumed = i + 1;
684
685                    // Get the line we just completed
686                    let line = &self.stmt_buffer[line_start..];
687                    if is_go_line(line) {
688                        // Remove the GO line from the buffer
689                        self.stmt_buffer.truncate(line_start);
690                        // Trim trailing whitespace from the statement
691                        while self
692                            .stmt_buffer
693                            .last()
694                            .is_some_and(|&b| b == b'\n' || b == b'\r' || b == b' ' || b == b'\t')
695                        {
696                            self.stmt_buffer.pop();
697                        }
698                        // If we have content, return it
699                        if !self.stmt_buffer.is_empty() {
700                            self.reader.consume(consumed);
701                            let result = std::mem::take(&mut self.stmt_buffer);
702                            return Ok(Some(result));
703                        }
704                        // Otherwise, reset and continue (empty batch)
705                        line_start = 0;
706                    } else {
707                        // Update line_start to after the newline
708                        line_start = self.stmt_buffer.len();
709                    }
710                    continue;
711                }
712            }
713
714            if found_terminator {
715                self.reader.consume(consumed);
716                let result = std::mem::take(&mut self.stmt_buffer);
717                return Ok(Some(result));
718            }
719
720            // Add remaining bytes to buffer
721            if consumed < buf.len() {
722                self.stmt_buffer.extend_from_slice(&buf[consumed..]);
723            }
724            let len = buf.len();
725            self.reader.consume(len);
726        }
727    }
728
729    #[allow(dead_code)]
730    pub fn parse_statement(stmt: &[u8]) -> (StatementType, String) {
731        Self::parse_statement_with_dialect(stmt, SqlDialect::MySql)
732    }
733
734    /// Parse a statement with dialect-specific handling
735    pub fn parse_statement_with_dialect(
736        stmt: &[u8],
737        dialect: SqlDialect,
738    ) -> (StatementType, String) {
739        // Strip leading comments (e.g., pg_dump adds -- comments before statements)
740        let stmt = strip_leading_comments_and_whitespace(stmt);
741
742        if stmt.len() < 4 {
743            return (StatementType::Unknown, String::new());
744        }
745
746        let upper_prefix: Vec<u8> = stmt
747            .iter()
748            .take(25)
749            .map(|b| b.to_ascii_uppercase())
750            .collect();
751
752        // PostgreSQL COPY statement
753        if upper_prefix.starts_with(b"COPY ") {
754            if let Some(caps) = COPY_RE.captures(stmt) {
755                if let Some(m) = caps.get(1) {
756                    let name = String::from_utf8_lossy(m.as_bytes()).into_owned();
757                    // Handle schema.table - extract just the table name
758                    let table_name = name.split('.').next_back().unwrap_or(&name).to_string();
759                    return (StatementType::Copy, table_name);
760                }
761            }
762        }
763
764        if upper_prefix.starts_with(b"CREATE TABLE") {
765            // Try fast extraction first
766            if let Some(name) = extract_table_name_flexible(stmt, 12, dialect) {
767                return (StatementType::CreateTable, name);
768            }
769            // Fall back to flexible regex
770            if let Some(caps) = CREATE_TABLE_FLEXIBLE_RE.captures(stmt) {
771                if let Some(m) = caps.get(1) {
772                    return (
773                        StatementType::CreateTable,
774                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
775                    );
776                }
777            }
778            // Original regex as last resort
779            if let Some(caps) = CREATE_TABLE_RE.captures(stmt) {
780                if let Some(m) = caps.get(1) {
781                    return (
782                        StatementType::CreateTable,
783                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
784                    );
785                }
786            }
787        }
788
789        if upper_prefix.starts_with(b"INSERT INTO") || upper_prefix.starts_with(b"INSERT ONLY") {
790            if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
791                return (StatementType::Insert, name);
792            }
793            if let Some(caps) = INSERT_FLEXIBLE_RE.captures(stmt) {
794                if let Some(m) = caps.get(1) {
795                    return (
796                        StatementType::Insert,
797                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
798                    );
799                }
800            }
801            if let Some(caps) = INSERT_INTO_RE.captures(stmt) {
802                if let Some(m) = caps.get(1) {
803                    return (
804                        StatementType::Insert,
805                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
806                    );
807                }
808            }
809        }
810
811        if upper_prefix.starts_with(b"CREATE INDEX")
812            || upper_prefix.starts_with(b"CREATE UNIQUE")
813            || upper_prefix.starts_with(b"CREATE CLUSTERED")
814            || upper_prefix.starts_with(b"CREATE NONCLUSTER")
815        {
816            // For MSSQL, try the bracket-aware regex first
817            if dialect == SqlDialect::Mssql {
818                if let Some(caps) = CREATE_INDEX_MSSQL_RE.captures(stmt) {
819                    if let Some(m) = caps.get(1) {
820                        return (
821                            StatementType::CreateIndex,
822                            String::from_utf8_lossy(m.as_bytes()).into_owned(),
823                        );
824                    }
825                }
826            }
827            // Fall back to generic regex for MySQL/PostgreSQL/SQLite
828            if let Some(caps) = CREATE_INDEX_RE.captures(stmt) {
829                if let Some(m) = caps.get(1) {
830                    return (
831                        StatementType::CreateIndex,
832                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
833                    );
834                }
835            }
836        }
837
838        if upper_prefix.starts_with(b"ALTER TABLE") {
839            if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
840                return (StatementType::AlterTable, name);
841            }
842            if let Some(caps) = ALTER_TABLE_RE.captures(stmt) {
843                if let Some(m) = caps.get(1) {
844                    return (
845                        StatementType::AlterTable,
846                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
847                    );
848                }
849            }
850        }
851
852        if upper_prefix.starts_with(b"DROP TABLE") {
853            if let Some(name) = extract_table_name_flexible(stmt, 10, dialect) {
854                return (StatementType::DropTable, name);
855            }
856            if let Some(caps) = DROP_TABLE_RE.captures(stmt) {
857                if let Some(m) = caps.get(1) {
858                    return (
859                        StatementType::DropTable,
860                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
861                    );
862                }
863            }
864        }
865
866        // MSSQL BULK INSERT - treat as Insert statement type
867        if upper_prefix.starts_with(b"BULK INSERT") {
868            if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
869                return (StatementType::Insert, name);
870            }
871        }
872
873        (StatementType::Unknown, String::new())
874    }
875}
876
/// Return `data` with every leading space, tab, line-feed, or carriage-return
/// byte removed. Other bytes (including form feed) stop the trim.
#[inline]
fn trim_ascii_start(data: &[u8]) -> &[u8] {
    let mut rest = data;
    while let [b' ' | b'\t' | b'\n' | b'\r', tail @ ..] = rest {
        rest = tail;
    }
    rest
}
885
886/// Strip leading whitespace and SQL line comments (`-- ...`) from a statement.
887/// This makes parsing robust to pg_dump-style comment blocks before statements.
888fn strip_leading_comments_and_whitespace(mut data: &[u8]) -> &[u8] {
889    loop {
890        // First trim leading ASCII whitespace
891        data = trim_ascii_start(data);
892
893        // Handle -- line comments
894        if data.len() >= 2 && data[0] == b'-' && data[1] == b'-' {
895            // Skip until end of line
896            if let Some(pos) = data.iter().position(|&b| b == b'\n') {
897                data = &data[pos + 1..];
898                continue;
899            } else {
900                // Comment runs to EOF, nothing left
901                return &[];
902            }
903        }
904
905        // Handle /* */ block comments (including MySQL conditional comments)
906        if data.len() >= 2 && data[0] == b'/' && data[1] == b'*' {
907            // Find the closing */
908            let mut i = 2;
909            let mut depth = 1;
910            while i < data.len() - 1 && depth > 0 {
911                if data[i] == b'*' && data[i + 1] == b'/' {
912                    depth -= 1;
913                    i += 2;
914                } else if data[i] == b'/' && data[i + 1] == b'*' {
915                    depth += 1;
916                    i += 2;
917                } else {
918                    i += 1;
919                }
920            }
921            if depth == 0 {
922                data = &data[i..];
923                continue;
924            } else {
925                // Unclosed comment runs to EOF
926                return &[];
927            }
928        }
929
930        // Handle # line comments (MySQL)
931        if !data.is_empty() && data[0] == b'#' {
932            if let Some(pos) = data.iter().position(|&b| b == b'\n') {
933                data = &data[pos + 1..];
934                continue;
935            } else {
936                return &[];
937            }
938        }
939
940        break;
941    }
942
943    data
944}
945
946/// Extract table name with support for:
947/// - IF NOT EXISTS
948/// - ONLY (PostgreSQL)
949/// - Schema-qualified names (schema.table)
950/// - Both backtick and double-quote quoting
951#[inline]
952fn extract_table_name_flexible(stmt: &[u8], offset: usize, dialect: SqlDialect) -> Option<String> {
953    let mut i = offset;
954
955    // Skip whitespace
956    while i < stmt.len() && is_whitespace(stmt[i]) {
957        i += 1;
958    }
959
960    if i >= stmt.len() {
961        return None;
962    }
963
964    // Check for IF NOT EXISTS or IF EXISTS
965    let upper_check: Vec<u8> = stmt[i..]
966        .iter()
967        .take(20)
968        .map(|b| b.to_ascii_uppercase())
969        .collect();
970    if upper_check.starts_with(b"IF NOT EXISTS") {
971        i += 13; // Skip "IF NOT EXISTS"
972        while i < stmt.len() && is_whitespace(stmt[i]) {
973            i += 1;
974        }
975    } else if upper_check.starts_with(b"IF EXISTS") {
976        i += 9; // Skip "IF EXISTS"
977        while i < stmt.len() && is_whitespace(stmt[i]) {
978            i += 1;
979        }
980    }
981
982    // Check for ONLY (PostgreSQL)
983    let upper_check: Vec<u8> = stmt[i..]
984        .iter()
985        .take(10)
986        .map(|b| b.to_ascii_uppercase())
987        .collect();
988    if upper_check.starts_with(b"ONLY ") || upper_check.starts_with(b"ONLY\t") {
989        i += 4;
990        while i < stmt.len() && is_whitespace(stmt[i]) {
991            i += 1;
992        }
993    }
994
995    if i >= stmt.len() {
996        return None;
997    }
998
999    // Read identifier (potentially schema-qualified)
1000    let mut parts: Vec<String> = Vec::new();
1001
1002    loop {
1003        // Determine quote character based on dialect
1004        let (quote_char, close_char) = match stmt.get(i) {
1005            Some(b'`') if dialect == SqlDialect::MySql => {
1006                i += 1;
1007                (Some(b'`'), b'`')
1008            }
1009            Some(b'"') if dialect != SqlDialect::MySql => {
1010                i += 1;
1011                (Some(b'"'), b'"')
1012            }
1013            Some(b'"') => {
1014                // Allow double quotes for MySQL too (though less common)
1015                i += 1;
1016                (Some(b'"'), b'"')
1017            }
1018            Some(b'[') if dialect == SqlDialect::Mssql => {
1019                // MSSQL square bracket quoting
1020                i += 1;
1021                (Some(b'['), b']')
1022            }
1023            _ => (None, 0),
1024        };
1025
1026        let start = i;
1027
1028        while i < stmt.len() {
1029            let b = stmt[i];
1030            if quote_char.is_some() {
1031                if b == close_char {
1032                    // For MSSQL, check for escaped ]]
1033                    if dialect == SqlDialect::Mssql
1034                        && close_char == b']'
1035                        && i + 1 < stmt.len()
1036                        && stmt[i + 1] == b']'
1037                    {
1038                        // Escaped bracket, skip both
1039                        i += 2;
1040                        continue;
1041                    }
1042                    let name = &stmt[start..i];
1043                    // For MSSQL, unescape ]] to ]
1044                    let name_str = if dialect == SqlDialect::Mssql {
1045                        String::from_utf8_lossy(name).replace("]]", "]")
1046                    } else {
1047                        String::from_utf8_lossy(name).into_owned()
1048                    };
1049                    parts.push(name_str);
1050                    i += 1; // Skip closing quote
1051                    break;
1052                }
1053            } else if is_whitespace(b) || b == b'(' || b == b';' || b == b',' || b == b'.' {
1054                if i > start {
1055                    let name = &stmt[start..i];
1056                    parts.push(String::from_utf8_lossy(name).into_owned());
1057                }
1058                break;
1059            }
1060            i += 1;
1061        }
1062
1063        // If at end of quoted name without finding close quote, bail
1064        if quote_char.is_some() && i <= start {
1065            break;
1066        }
1067
1068        // Check for schema separator (.)
1069        while i < stmt.len() && is_whitespace(stmt[i]) {
1070            i += 1;
1071        }
1072
1073        if i < stmt.len() && stmt[i] == b'.' {
1074            i += 1; // Skip the dot
1075            while i < stmt.len() && is_whitespace(stmt[i]) {
1076                i += 1;
1077            }
1078            // Continue to read the next identifier (table name)
1079        } else {
1080            break;
1081        }
1082    }
1083
1084    // Return the last part (table name), not the schema
1085    parts.pop()
1086}
1087
/// True for the four ASCII whitespace bytes the parser recognizes:
/// space, tab, line feed, and carriage return.
#[inline]
fn is_whitespace(b: u8) -> bool {
    b" \t\n\r".contains(&b)
}
1092
1093pub fn determine_buffer_size(file_size: u64) -> usize {
1094    if file_size > 1024 * 1024 * 1024 {
1095        MEDIUM_BUFFER_SIZE
1096    } else {
1097        SMALL_BUFFER_SIZE
1098    }
1099}