Skip to main content

sql_splitter/parser/
mod.rs

1pub mod mysql_insert;
2pub mod postgres_copy;
3
4// Re-export types for bulk loading
5pub use mysql_insert::{parse_insert_for_bulk, ParsedValue};
6
7use once_cell::sync::Lazy;
8use regex::bytes::Regex;
9use std::io::{BufRead, BufReader, Read};
10
/// Smaller preset read-buffer capacity: 64 KiB.
pub const SMALL_BUFFER_SIZE: usize = 1 << 16;
/// Larger preset read-buffer capacity: 256 KiB.
pub const MEDIUM_BUFFER_SIZE: usize = 1 << 18;
13
/// SQL dialect whose quoting, comment and batch rules the parser follows.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqlDialect {
    /// MySQL/MariaDB `mysqldump` output (backtick identifiers, backslash escapes).
    /// This is the default dialect.
    #[default]
    MySql,
    /// PostgreSQL `pg_dump` output (double-quoted identifiers, `COPY ... FROM stdin`,
    /// dollar-quoted bodies).
    Postgres,
    /// SQLite `.dump` output (double-quoted identifiers, `''` escapes).
    Sqlite,
    /// Microsoft SQL Server / T-SQL (`[bracketed]` identifiers, `GO` batches,
    /// `N'unicode'` strings).
    Mssql,
}

impl std::str::FromStr for SqlDialect {
    type Err = String;

    /// Parse a dialect name; accepted spellings are matched ignoring ASCII case.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const ALIASES: [(&str, SqlDialect); 11] = [
            ("mysql", SqlDialect::MySql),
            ("mariadb", SqlDialect::MySql),
            ("postgres", SqlDialect::Postgres),
            ("postgresql", SqlDialect::Postgres),
            ("pg", SqlDialect::Postgres),
            ("sqlite", SqlDialect::Sqlite),
            ("sqlite3", SqlDialect::Sqlite),
            ("mssql", SqlDialect::Mssql),
            ("sqlserver", SqlDialect::Mssql),
            ("sql_server", SqlDialect::Mssql),
            ("tsql", SqlDialect::Mssql),
        ];

        ALIASES
            .iter()
            .find(|(alias, _)| s.eq_ignore_ascii_case(alias))
            .map(|&(_, dialect)| dialect)
            .ok_or_else(|| {
                format!(
                    "Unknown dialect: {}. Valid options: mysql, postgres, sqlite, mssql",
                    s
                )
            })
    }
}

impl std::fmt::Display for SqlDialect {
    /// Write the canonical lowercase name (the first alias `from_str` accepts).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SqlDialect::MySql => "mysql",
            SqlDialect::Postgres => "postgres",
            SqlDialect::Sqlite => "sqlite",
            SqlDialect::Mssql => "mssql",
        };
        f.write_str(name)
    }
}
55
56/// Result of dialect auto-detection
57#[derive(Debug, Clone)]
58pub struct DialectDetectionResult {
59    pub dialect: SqlDialect,
60    pub confidence: DialectConfidence,
61}
62
/// How much evidence backed a dialect-detection verdict.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum DialectConfidence {
    /// Definitive markers were found (e.g. a "pg_dump" or "MySQL dump" banner).
    High,
    /// Only likely markers were found.
    Medium,
    /// Weak or no evidence; with no markers at all the detector falls back to MySQL.
    Low,
}
73
/// Running point totals per dialect while scanning a header; all start at zero.
#[derive(Default)]
struct DialectScore {
    /// Points for MySQL/MariaDB.
    mysql: u32,
    /// Points for PostgreSQL.
    postgres: u32,
    /// Points for SQLite.
    sqlite: u32,
    /// Points for MSSQL/T-SQL.
    mssql: u32,
}
81
82/// Detect SQL dialect from file header content.
83/// Reads up to 8KB and looks for dialect-specific markers.
84pub fn detect_dialect(header: &[u8]) -> DialectDetectionResult {
85    let mut score = DialectScore::default();
86
87    // High confidence markers (+10)
88    if contains_bytes(header, b"pg_dump") {
89        score.postgres += 10;
90    }
91    if contains_bytes(header, b"PostgreSQL database dump") {
92        score.postgres += 10;
93    }
94    if contains_bytes(header, b"MySQL dump") {
95        score.mysql += 10;
96    }
97    if contains_bytes(header, b"MariaDB dump") {
98        score.mysql += 10;
99    }
100    if contains_bytes(header, b"SQLite") {
101        score.sqlite += 10;
102    }
103
104    // Medium confidence markers (+5)
105    if contains_bytes(header, b"COPY ") && contains_bytes(header, b"FROM stdin") {
106        score.postgres += 5;
107    }
108    if contains_bytes(header, b"search_path") {
109        score.postgres += 5;
110    }
111    if contains_bytes(header, b"/*!40") || contains_bytes(header, b"/*!50") {
112        score.mysql += 5;
113    }
114    if contains_bytes(header, b"LOCK TABLES") {
115        score.mysql += 5;
116    }
117    if contains_bytes(header, b"PRAGMA") {
118        score.sqlite += 5;
119    }
120
121    // Low confidence markers (+2)
122    if contains_bytes(header, b"$$") {
123        score.postgres += 2;
124    }
125    if contains_bytes(header, b"CREATE EXTENSION") {
126        score.postgres += 2;
127    }
128    // BEGIN TRANSACTION is generic ANSI SQL, only slightly suggests SQLite
129    if contains_bytes(header, b"BEGIN TRANSACTION") {
130        score.sqlite += 2;
131    }
132    // Backticks suggest MySQL (could also appear in data/comments)
133    if header.contains(&b'`') {
134        score.mysql += 2;
135    }
136
137    // MSSQL/T-SQL markers
138    // High confidence markers (+20)
139    if contains_bytes(header, b"SET ANSI_NULLS") {
140        score.mssql += 20;
141    }
142    if contains_bytes(header, b"SET QUOTED_IDENTIFIER") {
143        score.mssql += 20;
144    }
145
146    // Medium confidence markers (+10-15)
147    // GO as batch separator on its own line (check for common patterns)
148    if contains_bytes(header, b"\nGO\n") || contains_bytes(header, b"\nGO\r\n") {
149        score.mssql += 15;
150    }
151    // Square bracket identifiers
152    if header.contains(&b'[') && header.contains(&b']') {
153        score.mssql += 10;
154    }
155    if contains_bytes(header, b"IDENTITY(") {
156        score.mssql += 10;
157    }
158    if contains_bytes(header, b"ON [PRIMARY]") {
159        score.mssql += 10;
160    }
161
162    // Low confidence markers (+5)
163    if contains_bytes(header, b"N'") {
164        score.mssql += 5;
165    }
166    if contains_bytes(header, b"NVARCHAR") {
167        score.mssql += 5;
168    }
169    if contains_bytes(header, b"CLUSTERED") {
170        score.mssql += 5;
171    }
172    if contains_bytes(header, b"SET NOCOUNT") {
173        score.mssql += 5;
174    }
175
176    // Determine winner and confidence
177    let max_score = score
178        .mysql
179        .max(score.postgres)
180        .max(score.sqlite)
181        .max(score.mssql);
182
183    if max_score == 0 {
184        return DialectDetectionResult {
185            dialect: SqlDialect::MySql,
186            confidence: DialectConfidence::Low,
187        };
188    }
189
190    // Find the dialect with the highest score
191    let (dialect, winning_score) = if score.mssql > score.mysql
192        && score.mssql > score.postgres
193        && score.mssql > score.sqlite
194    {
195        (SqlDialect::Mssql, score.mssql)
196    } else if score.postgres > score.mysql && score.postgres > score.sqlite {
197        (SqlDialect::Postgres, score.postgres)
198    } else if score.sqlite > score.mysql {
199        (SqlDialect::Sqlite, score.sqlite)
200    } else {
201        (SqlDialect::MySql, score.mysql)
202    };
203
204    // Determine confidence based on winning score
205    let confidence = if winning_score >= 10 {
206        DialectConfidence::High
207    } else if winning_score >= 5 {
208        DialectConfidence::Medium
209    } else {
210        DialectConfidence::Low
211    };
212
213    DialectDetectionResult {
214        dialect,
215        confidence,
216    }
217}
218
219/// Detect dialect from a file, reading first 8KB
220pub fn detect_dialect_from_file(path: &std::path::Path) -> std::io::Result<DialectDetectionResult> {
221    use std::fs::File;
222    use std::io::Read;
223
224    let mut file = File::open(path)?;
225    let mut buf = [0u8; 8192];
226    let n = file.read(&mut buf)?;
227    Ok(detect_dialect(&buf[..n]))
228}
229
/// Return `true` if `needle` occurs as a contiguous run anywhere in `haystack`.
///
/// An empty `needle` is considered present in every haystack. Without the guard,
/// `slice::windows(0)` would panic.
#[inline]
fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    if needle.is_empty() {
        return true;
    }
    haystack
        .windows(needle.len())
        .any(|window| window == needle)
}
236
/// Check if a line is a MSSQL `GO` batch separator.
///
/// After stripping leading spaces/tabs/CRs and trailing whitespace (including the
/// newline), the line must be exactly `GO` in any letter case. Alternatively it may be
/// `GO` followed by spaces/tabs and a decimal repeat count.
/// Examples: "GO\n", "  GO  \n", "GO 100\n", "go\r\n".
fn is_go_line(line: &[u8]) -> bool {
    let leading_blank = |b: &u8| matches!(*b, b' ' | b'\t' | b'\r');
    let trailing_blank = |b: &u8| matches!(*b, b' ' | b'\t' | b'\r' | b'\n');

    let start = line
        .iter()
        .position(|b| !leading_blank(b))
        .unwrap_or(line.len());
    let end = line[start..]
        .iter()
        .rposition(|b| !trailing_blank(b))
        .map_or(start, |last| start + last + 1);
    let word = &line[start..end];

    let is_g = |b: &u8| b.eq_ignore_ascii_case(&b'g');
    let is_o = |b: &u8| b.eq_ignore_ascii_case(&b'o');

    match word {
        [g, o] => is_g(g) && is_o(o),
        // "GO <count>": after the separator only blanks and then digits may follow.
        // `all` over an empty tail is true.
        [g, o, sep, count @ ..] if is_g(g) && is_o(o) && matches!(*sep, b' ' | b'\t') => count
            .iter()
            .skip_while(|b| matches!(**b, b' ' | b'\t'))
            .all(u8::is_ascii_digit),
        _ => false,
    }
}
289
/// Category of a SQL statement, as classified by the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// Not one of the recognised kinds below.
    Unknown,
    /// `CREATE TABLE`
    CreateTable,
    /// `INSERT` (MSSQL `BULK INSERT` is reported as this too)
    Insert,
    /// `CREATE INDEX` and its UNIQUE/CLUSTERED/NONCLUSTERED variants
    CreateIndex,
    /// `ALTER TABLE`
    AlterTable,
    /// `DROP TABLE`
    DropTable,
    /// PostgreSQL COPY FROM stdin
    Copy,
}

impl StatementType {
    /// `true` for schema (DDL) statements: CREATE TABLE, CREATE INDEX, ALTER TABLE, DROP TABLE.
    pub fn is_schema(&self) -> bool {
        match self {
            StatementType::CreateTable
            | StatementType::CreateIndex
            | StatementType::AlterTable
            | StatementType::DropTable => true,
            StatementType::Unknown | StatementType::Insert | StatementType::Copy => false,
        }
    }

    /// `true` for data (DML) statements: INSERT and COPY.
    pub fn is_data(&self) -> bool {
        match self {
            StatementType::Insert | StatementType::Copy => true,
            StatementType::Unknown
            | StatementType::CreateTable
            | StatementType::CreateIndex
            | StatementType::AlterTable
            | StatementType::DropTable => false,
        }
    }
}
319
/// Which categories of statements a split keeps.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Default)]
pub enum ContentFilter {
    /// Keep schema and data statements alike; this is the default.
    #[default]
    All,
    /// Keep only schema statements (CREATE TABLE, CREATE INDEX, ALTER TABLE, DROP TABLE).
    SchemaOnly,
    /// Keep only data statements (INSERT, COPY).
    DataOnly,
}
331
332static CREATE_TABLE_RE: Lazy<Regex> =
333    Lazy::new(|| Regex::new(r"(?i)^\s*CREATE\s+TABLE\s+`?([^\s`(]+)`?").unwrap());
334
335static INSERT_INTO_RE: Lazy<Regex> =
336    Lazy::new(|| Regex::new(r"(?i)^\s*INSERT\s+INTO\s+`?([^\s`(]+)`?").unwrap());
337
338static CREATE_INDEX_RE: Lazy<Regex> =
339    Lazy::new(|| Regex::new(r"(?i)ON\s+`?([^\s`(;]+)`?").unwrap());
340
341// MSSQL CREATE INDEX: extracts table from ON [schema].[table] or ON [table]
342// Matches: ON [table], ON [dbo].[table], ON [db].[dbo].[table]
343// Captures the last bracketed or unbracketed identifier before (
344static CREATE_INDEX_MSSQL_RE: Lazy<Regex> =
345    Lazy::new(|| Regex::new(r"(?i)ON\s+(?:\[?[^\[\]\s]+\]?\s*\.\s*)*\[([^\[\]]+)\]").unwrap());
346
347static ALTER_TABLE_RE: Lazy<Regex> =
348    Lazy::new(|| Regex::new(r"(?i)ALTER\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
349
350static DROP_TABLE_RE: Lazy<Regex> = Lazy::new(|| {
351    Regex::new(r#"(?i)DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?[`"]?([^\s`"`;]+)[`"]?"#).unwrap()
352});
353
354// PostgreSQL COPY statement regex
355static COPY_RE: Lazy<Regex> =
356    Lazy::new(|| Regex::new(r#"(?i)^\s*COPY\s+(?:ONLY\s+)?[`"]?([^\s`"(]+)[`"]?"#).unwrap());
357
358// More flexible table name regex that handles:
359// - Backticks: `table`
360// - Double quotes: "table"
361// - Schema qualified: schema.table, `schema`.`table`, "schema"."table"
362// - IF NOT EXISTS
363static CREATE_TABLE_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
364    Regex::new(r#"(?i)^\s*CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
365});
366
367static INSERT_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
368    Regex::new(
369        r#"(?i)^\s*INSERT\s+INTO\s+(?:ONLY\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#,
370    )
371    .unwrap()
372});
373
374pub struct Parser<R: Read> {
375    reader: BufReader<R>,
376    stmt_buffer: Vec<u8>,
377    dialect: SqlDialect,
378    /// For PostgreSQL: true when reading COPY data block
379    in_copy_data: bool,
380}
381
382impl<R: Read> Parser<R> {
383    #[allow(dead_code)]
384    pub fn new(reader: R, buffer_size: usize) -> Self {
385        Self::with_dialect(reader, buffer_size, SqlDialect::default())
386    }
387
388    pub fn with_dialect(reader: R, buffer_size: usize, dialect: SqlDialect) -> Self {
389        Self {
390            reader: BufReader::with_capacity(buffer_size, reader),
391            stmt_buffer: Vec::with_capacity(32 * 1024),
392            dialect,
393            in_copy_data: false,
394        }
395    }
396
397    pub fn read_statement(&mut self) -> std::io::Result<Option<Vec<u8>>> {
398        // If we're in PostgreSQL COPY data mode, read until we see the terminator
399        if self.in_copy_data {
400            return self.read_copy_data();
401        }
402
403        // For MSSQL, use line-based parsing to handle GO batch separator
404        if self.dialect == SqlDialect::Mssql {
405            return self.read_statement_mssql();
406        }
407
408        self.stmt_buffer.clear();
409
410        let mut inside_single_quote = false;
411        let mut inside_double_quote = false;
412        let mut escaped = false;
413        let mut in_line_comment = false;
414        // For PostgreSQL dollar-quoting: track the tag
415        let mut in_dollar_quote = false;
416        let mut dollar_tag: Vec<u8> = Vec::new();
417
418        loop {
419            let buf = self.reader.fill_buf()?;
420            if buf.is_empty() {
421                if self.stmt_buffer.is_empty() {
422                    return Ok(None);
423                }
424                let result = std::mem::take(&mut self.stmt_buffer);
425                return Ok(Some(result));
426            }
427
428            let mut consumed = 0;
429            let mut found_terminator = false;
430
431            for (i, &b) in buf.iter().enumerate() {
432                let inside_string = inside_single_quote || inside_double_quote || in_dollar_quote;
433
434                // End of line comment on newline
435                if in_line_comment {
436                    if b == b'\n' {
437                        in_line_comment = false;
438                    }
439                    continue;
440                }
441
442                if escaped {
443                    escaped = false;
444                    continue;
445                }
446
447                // Handle backslash escapes (MySQL style)
448                if b == b'\\' && inside_string && self.dialect == SqlDialect::MySql {
449                    escaped = true;
450                    continue;
451                }
452
453                // Handle line comments (-- to end of line)
454                if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
455                    in_line_comment = true;
456                    continue;
457                }
458
459                // Handle dollar-quoting for PostgreSQL
460                if self.dialect == SqlDialect::Postgres
461                    && !inside_single_quote
462                    && !inside_double_quote
463                {
464                    if b == b'$' && !in_dollar_quote {
465                        // Start of dollar-quote: scan for the closing $
466                        if let Some(end) = buf[i + 1..].iter().position(|&c| c == b'$') {
467                            let tag_bytes = &buf[i + 1..i + 1 + end];
468
469                            // Validate tag: must be empty OR identifier-like [A-Za-z_][A-Za-z0-9_]*
470                            let is_valid_tag = if tag_bytes.is_empty() {
471                                true
472                            } else {
473                                let mut iter = tag_bytes.iter();
474                                match iter.next() {
475                                    Some(&first)
476                                        if first.is_ascii_alphabetic() || first == b'_' =>
477                                    {
478                                        iter.all(|&c| c.is_ascii_alphanumeric() || c == b'_')
479                                    }
480                                    _ => false,
481                                }
482                            };
483
484                            if is_valid_tag {
485                                dollar_tag = tag_bytes.to_vec();
486                                in_dollar_quote = true;
487                                continue;
488                            }
489                            // Invalid tag - treat $ as normal character
490                        }
491                    } else if b == b'$' && in_dollar_quote {
492                        // Potential end of dollar-quote
493                        let tag_len = dollar_tag.len();
494                        if i + 1 + tag_len < buf.len()
495                            && buf[i + 1..i + 1 + tag_len] == dollar_tag[..]
496                            && buf.get(i + 1 + tag_len) == Some(&b'$')
497                        {
498                            in_dollar_quote = false;
499                            dollar_tag.clear();
500                            continue;
501                        }
502                    }
503                }
504
505                if b == b'\'' && !inside_double_quote && !in_dollar_quote {
506                    inside_single_quote = !inside_single_quote;
507                } else if b == b'"' && !inside_single_quote && !in_dollar_quote {
508                    inside_double_quote = !inside_double_quote;
509                } else if b == b';' && !inside_string {
510                    self.stmt_buffer.extend_from_slice(&buf[..=i]);
511                    consumed = i + 1;
512                    found_terminator = true;
513                    break;
514                }
515            }
516
517            if found_terminator {
518                self.reader.consume(consumed);
519                let result = std::mem::take(&mut self.stmt_buffer);
520
521                // Check if this is a PostgreSQL COPY FROM stdin statement
522                if self.dialect == SqlDialect::Postgres && self.is_copy_from_stdin(&result) {
523                    self.in_copy_data = true;
524                }
525
526                return Ok(Some(result));
527            }
528
529            self.stmt_buffer.extend_from_slice(buf);
530            let len = buf.len();
531            self.reader.consume(len);
532        }
533    }
534
535    /// Check if statement is a PostgreSQL COPY FROM stdin
536    fn is_copy_from_stdin(&self, stmt: &[u8]) -> bool {
537        // Strip leading comments (pg_dump adds -- comments before COPY statements)
538        let stmt = strip_leading_comments_and_whitespace(stmt);
539        if stmt.len() < 15 {
540            // Minimum: "COPY x FROM STDIN" = 17 chars
541            return false;
542        }
543
544        // Check prefix with stack-allocated buffer (avoid heap allocation)
545        let mut prefix = [0u8; 5];
546        for (i, &b) in stmt.iter().take(5).enumerate() {
547            prefix[i] = b.to_ascii_uppercase();
548        }
549        if &prefix != b"COPY " {
550            return false;
551        }
552
553        // Search for "FROM STDIN" case-insensitively without allocating
554        // Look within first 500 bytes (typical COPY statements are shorter)
555        let search_len = stmt.len().min(500);
556        for i in 0..search_len.saturating_sub(10) {
557            if stmt[i..i + 10]
558                .iter()
559                .zip(b"FROM STDIN".iter())
560                .all(|(&a, &b)| a.to_ascii_uppercase() == b)
561            {
562                return true;
563            }
564        }
565        false
566    }
567
568    /// Read PostgreSQL COPY data block until we see the terminator line (\.)
569    fn read_copy_data(&mut self) -> std::io::Result<Option<Vec<u8>>> {
570        self.stmt_buffer.clear();
571
572        loop {
573            // First, fill the buffer and check if empty
574            let buf = self.reader.fill_buf()?;
575            if buf.is_empty() {
576                self.in_copy_data = false;
577                if self.stmt_buffer.is_empty() {
578                    return Ok(None);
579                }
580                return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
581            }
582
583            // Look for a newline in the buffer
584            let newline_pos = buf.iter().position(|&b| b == b'\n');
585
586            if let Some(i) = newline_pos {
587                // Include this newline
588                self.stmt_buffer.extend_from_slice(&buf[..=i]);
589                self.reader.consume(i + 1);
590
591                // Check if the line we just added ends the COPY block
592                // Looking for a line that is just "\.\n" or "\.\r\n"
593                if self.ends_with_copy_terminator() {
594                    self.in_copy_data = false;
595                    return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
596                }
597                // Continue reading - we need to process more lines
598            } else {
599                // No newline found, consume the whole buffer and continue
600                let len = buf.len();
601                self.stmt_buffer.extend_from_slice(buf);
602                self.reader.consume(len);
603            }
604        }
605    }
606
607    /// Check if buffer ends with the COPY terminator line (\.)
608    fn ends_with_copy_terminator(&self) -> bool {
609        let data = &self.stmt_buffer;
610        if data.len() < 2 {
611            return false;
612        }
613
614        // Look for a line that is just "\.\n" or "\.\r\n"
615        // We need to find the start of the last line
616        let last_newline = data[..data.len() - 1]
617            .iter()
618            .rposition(|&b| b == b'\n')
619            .map(|i| i + 1)
620            .unwrap_or(0);
621
622        let last_line = &data[last_newline..];
623
624        // Check if it's "\.\n" or "\.\r\n"
625        last_line == b"\\.\n" || last_line == b"\\.\r\n"
626    }
627
628    /// Read MSSQL statement with GO batch separator support
629    /// GO is a batch separator that appears on its own line
630    fn read_statement_mssql(&mut self) -> std::io::Result<Option<Vec<u8>>> {
631        self.stmt_buffer.clear();
632
633        let mut inside_single_quote = false;
634        let mut inside_bracket_quote = false;
635        let mut in_line_comment = false;
636        let mut line_start = 0usize; // Track where current line started in stmt_buffer
637
638        loop {
639            let buf = self.reader.fill_buf()?;
640            if buf.is_empty() {
641                if self.stmt_buffer.is_empty() {
642                    return Ok(None);
643                }
644                let result = std::mem::take(&mut self.stmt_buffer);
645                return Ok(Some(result));
646            }
647
648            let mut consumed = 0;
649            let mut found_terminator = false;
650
651            for (i, &b) in buf.iter().enumerate() {
652                let inside_string = inside_single_quote || inside_bracket_quote;
653
654                // End of line comment on newline
655                if in_line_comment {
656                    if b == b'\n' {
657                        in_line_comment = false;
658                        // Add to buffer and update line_start
659                        self.stmt_buffer.extend_from_slice(&buf[consumed..=i]);
660                        consumed = i + 1;
661                        line_start = self.stmt_buffer.len();
662                    }
663                    continue;
664                }
665
666                // Handle line comments (-- to end of line)
667                if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
668                    in_line_comment = true;
669                    continue;
670                }
671
672                // Handle N'...' unicode strings - treat N as prefix, ' as quote start
673                // (The N is just a prefix, single quote handling is the same)
674
675                // Handle string quotes
676                if b == b'\'' && !inside_bracket_quote {
677                    inside_single_quote = !inside_single_quote;
678                } else if b == b'[' && !inside_single_quote {
679                    inside_bracket_quote = true;
680                } else if b == b']' && inside_bracket_quote {
681                    // Check for escaped ]]
682                    if i + 1 < buf.len() && buf[i + 1] == b']' {
683                        // Skip the escape sequence - consume one extra ]
684                        continue;
685                    }
686                    inside_bracket_quote = false;
687                } else if b == b';' && !inside_string {
688                    // Semicolon is a statement terminator in MSSQL too
689                    self.stmt_buffer.extend_from_slice(&buf[consumed..=i]);
690                    consumed = i + 1;
691                    found_terminator = true;
692                    break;
693                } else if b == b'\n' && !inside_string {
694                    // Check if the current line (from line_start to here) is just "GO"
695                    // First, add bytes up to and including the newline
696                    self.stmt_buffer.extend_from_slice(&buf[consumed..=i]);
697                    consumed = i + 1;
698
699                    // Get the line we just completed
700                    let line = &self.stmt_buffer[line_start..];
701                    if is_go_line(line) {
702                        // Remove the GO line from the buffer
703                        self.stmt_buffer.truncate(line_start);
704                        // Trim trailing whitespace from the statement
705                        while self
706                            .stmt_buffer
707                            .last()
708                            .is_some_and(|&b| b == b'\n' || b == b'\r' || b == b' ' || b == b'\t')
709                        {
710                            self.stmt_buffer.pop();
711                        }
712                        // If we have content, return it
713                        if !self.stmt_buffer.is_empty() {
714                            self.reader.consume(consumed);
715                            let result = std::mem::take(&mut self.stmt_buffer);
716                            return Ok(Some(result));
717                        }
718                        // Otherwise, reset and continue (empty batch)
719                        line_start = 0;
720                    } else {
721                        // Update line_start to after the newline
722                        line_start = self.stmt_buffer.len();
723                    }
724                    continue;
725                }
726            }
727
728            if found_terminator {
729                self.reader.consume(consumed);
730                let result = std::mem::take(&mut self.stmt_buffer);
731                return Ok(Some(result));
732            }
733
734            // Add remaining bytes to buffer
735            if consumed < buf.len() {
736                self.stmt_buffer.extend_from_slice(&buf[consumed..]);
737            }
738            let len = buf.len();
739            self.reader.consume(len);
740        }
741    }
742
743    #[allow(dead_code)]
744    pub fn parse_statement(stmt: &[u8]) -> (StatementType, String) {
745        Self::parse_statement_with_dialect(stmt, SqlDialect::MySql)
746    }
747
748    /// Parse a statement with dialect-specific handling
749    pub fn parse_statement_with_dialect(
750        stmt: &[u8],
751        dialect: SqlDialect,
752    ) -> (StatementType, String) {
753        // Strip leading comments (e.g., pg_dump adds -- comments before statements)
754        let stmt = strip_leading_comments_and_whitespace(stmt);
755
756        if stmt.len() < 4 {
757            return (StatementType::Unknown, String::new());
758        }
759
760        // Use stack-allocated buffer to avoid heap allocation in hot path
761        let mut upper_prefix = [0u8; 25];
762        let prefix_len = stmt.len().min(25);
763        for (i, &b) in stmt.iter().take(prefix_len).enumerate() {
764            upper_prefix[i] = b.to_ascii_uppercase();
765        }
766        let upper_prefix = &upper_prefix[..prefix_len];
767
768        // PostgreSQL COPY statement
769        if upper_prefix.starts_with(b"COPY ") {
770            if let Some(caps) = COPY_RE.captures(stmt) {
771                if let Some(m) = caps.get(1) {
772                    let name = String::from_utf8_lossy(m.as_bytes()).into_owned();
773                    // Handle schema.table - extract just the table name
774                    let table_name = name.split('.').next_back().unwrap_or(&name).to_string();
775                    return (StatementType::Copy, table_name);
776                }
777            }
778        }
779
780        if upper_prefix.starts_with(b"CREATE TABLE") {
781            // Try fast extraction first
782            if let Some(name) = extract_table_name_flexible(stmt, 12, dialect) {
783                return (StatementType::CreateTable, name);
784            }
785            // Fall back to flexible regex
786            if let Some(caps) = CREATE_TABLE_FLEXIBLE_RE.captures(stmt) {
787                if let Some(m) = caps.get(1) {
788                    return (
789                        StatementType::CreateTable,
790                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
791                    );
792                }
793            }
794            // Original regex as last resort
795            if let Some(caps) = CREATE_TABLE_RE.captures(stmt) {
796                if let Some(m) = caps.get(1) {
797                    return (
798                        StatementType::CreateTable,
799                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
800                    );
801                }
802            }
803        }
804
805        if upper_prefix.starts_with(b"INSERT INTO") || upper_prefix.starts_with(b"INSERT ONLY") {
806            if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
807                return (StatementType::Insert, name);
808            }
809            if let Some(caps) = INSERT_FLEXIBLE_RE.captures(stmt) {
810                if let Some(m) = caps.get(1) {
811                    return (
812                        StatementType::Insert,
813                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
814                    );
815                }
816            }
817            if let Some(caps) = INSERT_INTO_RE.captures(stmt) {
818                if let Some(m) = caps.get(1) {
819                    return (
820                        StatementType::Insert,
821                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
822                    );
823                }
824            }
825        }
826
827        if upper_prefix.starts_with(b"CREATE INDEX")
828            || upper_prefix.starts_with(b"CREATE UNIQUE")
829            || upper_prefix.starts_with(b"CREATE CLUSTERED")
830            || upper_prefix.starts_with(b"CREATE NONCLUSTER")
831        {
832            // For MSSQL, try the bracket-aware regex first
833            if dialect == SqlDialect::Mssql {
834                if let Some(caps) = CREATE_INDEX_MSSQL_RE.captures(stmt) {
835                    if let Some(m) = caps.get(1) {
836                        return (
837                            StatementType::CreateIndex,
838                            String::from_utf8_lossy(m.as_bytes()).into_owned(),
839                        );
840                    }
841                }
842            }
843            // Fall back to generic regex for MySQL/PostgreSQL/SQLite
844            if let Some(caps) = CREATE_INDEX_RE.captures(stmt) {
845                if let Some(m) = caps.get(1) {
846                    return (
847                        StatementType::CreateIndex,
848                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
849                    );
850                }
851            }
852        }
853
854        if upper_prefix.starts_with(b"ALTER TABLE") {
855            if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
856                return (StatementType::AlterTable, name);
857            }
858            if let Some(caps) = ALTER_TABLE_RE.captures(stmt) {
859                if let Some(m) = caps.get(1) {
860                    return (
861                        StatementType::AlterTable,
862                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
863                    );
864                }
865            }
866        }
867
868        if upper_prefix.starts_with(b"DROP TABLE") {
869            if let Some(name) = extract_table_name_flexible(stmt, 10, dialect) {
870                return (StatementType::DropTable, name);
871            }
872            if let Some(caps) = DROP_TABLE_RE.captures(stmt) {
873                if let Some(m) = caps.get(1) {
874                    return (
875                        StatementType::DropTable,
876                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
877                    );
878                }
879            }
880        }
881
882        // MSSQL BULK INSERT - treat as Insert statement type
883        if upper_prefix.starts_with(b"BULK INSERT") {
884            if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
885                return (StatementType::Insert, name);
886            }
887        }
888
889        (StatementType::Unknown, String::new())
890    }
891}
892
/// Return `data` with its leading ASCII space, tab, LF and CR bytes removed.
///
/// Only those four bytes are skipped; other ASCII whitespace such as form feed
/// is left in place. Returns an empty slice if `data` is entirely whitespace.
#[inline]
fn trim_ascii_start(data: &[u8]) -> &[u8] {
    let skipped = data
        .iter()
        .take_while(|&&byte| matches!(byte, b' ' | b'\t' | b'\n' | b'\r'))
        .count();
    &data[skipped..]
}
901
902/// Strip leading whitespace and SQL line comments (`-- ...`) from a statement.
903/// This makes parsing robust to pg_dump-style comment blocks before statements.
904fn strip_leading_comments_and_whitespace(mut data: &[u8]) -> &[u8] {
905    loop {
906        // First trim leading ASCII whitespace
907        data = trim_ascii_start(data);
908
909        // Handle -- line comments
910        if data.len() >= 2 && data[0] == b'-' && data[1] == b'-' {
911            // Skip until end of line
912            if let Some(pos) = data.iter().position(|&b| b == b'\n') {
913                data = &data[pos + 1..];
914                continue;
915            } else {
916                // Comment runs to EOF, nothing left
917                return &[];
918            }
919        }
920
921        // Handle /* */ block comments (including MySQL conditional comments)
922        if data.len() >= 2 && data[0] == b'/' && data[1] == b'*' {
923            // Find the closing */
924            let mut i = 2;
925            let mut depth = 1;
926            while i < data.len() - 1 && depth > 0 {
927                if data[i] == b'*' && data[i + 1] == b'/' {
928                    depth -= 1;
929                    i += 2;
930                } else if data[i] == b'/' && data[i + 1] == b'*' {
931                    depth += 1;
932                    i += 2;
933                } else {
934                    i += 1;
935                }
936            }
937            if depth == 0 {
938                data = &data[i..];
939                continue;
940            } else {
941                // Unclosed comment runs to EOF
942                return &[];
943            }
944        }
945
946        // Handle # line comments (MySQL)
947        if !data.is_empty() && data[0] == b'#' {
948            if let Some(pos) = data.iter().position(|&b| b == b'\n') {
949                data = &data[pos + 1..];
950                continue;
951            } else {
952                return &[];
953            }
954        }
955
956        break;
957    }
958
959    data
960}
961
962/// Extract table name with support for:
963/// - IF NOT EXISTS
964/// - ONLY (PostgreSQL)
965/// - Schema-qualified names (schema.table)
966/// - Both backtick and double-quote quoting
967#[inline]
968fn extract_table_name_flexible(stmt: &[u8], offset: usize, dialect: SqlDialect) -> Option<String> {
969    let mut i = offset;
970
971    // Skip whitespace
972    while i < stmt.len() && is_whitespace(stmt[i]) {
973        i += 1;
974    }
975
976    if i >= stmt.len() {
977        return None;
978    }
979
980    // Check for IF NOT EXISTS or IF EXISTS (stack-allocated to avoid heap allocation)
981    let mut upper_check = [0u8; 20];
982    let check_len = (stmt.len() - i).min(20);
983    for (idx, &b) in stmt[i..].iter().take(check_len).enumerate() {
984        upper_check[idx] = b.to_ascii_uppercase();
985    }
986    let upper_slice = &upper_check[..check_len];
987    if upper_slice.starts_with(b"IF NOT EXISTS") {
988        i += 13; // Skip "IF NOT EXISTS"
989        while i < stmt.len() && is_whitespace(stmt[i]) {
990            i += 1;
991        }
992    } else if upper_slice.starts_with(b"IF EXISTS") {
993        i += 9; // Skip "IF EXISTS"
994        while i < stmt.len() && is_whitespace(stmt[i]) {
995            i += 1;
996        }
997    }
998
999    // Check for ONLY (PostgreSQL) - reuse first 10 bytes or re-check if position changed
1000    let only_check = if i < stmt.len() {
1001        let mut buf = [0u8; 10];
1002        let len = (stmt.len() - i).min(10);
1003        for (idx, &b) in stmt[i..].iter().take(len).enumerate() {
1004            buf[idx] = b.to_ascii_uppercase();
1005        }
1006        (buf, len)
1007    } else {
1008        ([0u8; 10], 0)
1009    };
1010    let only_slice = &only_check.0[..only_check.1];
1011    if only_slice.starts_with(b"ONLY ") || only_slice.starts_with(b"ONLY\t") {
1012        i += 4;
1013        while i < stmt.len() && is_whitespace(stmt[i]) {
1014            i += 1;
1015        }
1016    }
1017
1018    if i >= stmt.len() {
1019        return None;
1020    }
1021
1022    // Read identifier (potentially schema-qualified)
1023    let mut parts: Vec<String> = Vec::new();
1024
1025    loop {
1026        // Determine quote character based on dialect
1027        let (quote_char, close_char) = match stmt.get(i) {
1028            Some(b'`') if dialect == SqlDialect::MySql => {
1029                i += 1;
1030                (Some(b'`'), b'`')
1031            }
1032            Some(b'"') if dialect != SqlDialect::MySql => {
1033                i += 1;
1034                (Some(b'"'), b'"')
1035            }
1036            Some(b'"') => {
1037                // Allow double quotes for MySQL too (though less common)
1038                i += 1;
1039                (Some(b'"'), b'"')
1040            }
1041            Some(b'[') if dialect == SqlDialect::Mssql => {
1042                // MSSQL square bracket quoting
1043                i += 1;
1044                (Some(b'['), b']')
1045            }
1046            _ => (None, 0),
1047        };
1048
1049        let start = i;
1050
1051        while i < stmt.len() {
1052            let b = stmt[i];
1053            if quote_char.is_some() {
1054                if b == close_char {
1055                    // For MSSQL, check for escaped ]]
1056                    if dialect == SqlDialect::Mssql
1057                        && close_char == b']'
1058                        && i + 1 < stmt.len()
1059                        && stmt[i + 1] == b']'
1060                    {
1061                        // Escaped bracket, skip both
1062                        i += 2;
1063                        continue;
1064                    }
1065                    let name = &stmt[start..i];
1066                    // For MSSQL, unescape ]] to ]
1067                    let name_str = if dialect == SqlDialect::Mssql {
1068                        String::from_utf8_lossy(name).replace("]]", "]")
1069                    } else {
1070                        String::from_utf8_lossy(name).into_owned()
1071                    };
1072                    parts.push(name_str);
1073                    i += 1; // Skip closing quote
1074                    break;
1075                }
1076            } else if is_whitespace(b) || b == b'(' || b == b';' || b == b',' || b == b'.' {
1077                if i > start {
1078                    let name = &stmt[start..i];
1079                    parts.push(String::from_utf8_lossy(name).into_owned());
1080                }
1081                break;
1082            }
1083            i += 1;
1084        }
1085
1086        // If at end of quoted name without finding close quote, bail
1087        if quote_char.is_some() && i <= start {
1088            break;
1089        }
1090
1091        // Check for schema separator (.)
1092        while i < stmt.len() && is_whitespace(stmt[i]) {
1093            i += 1;
1094        }
1095
1096        if i < stmt.len() && stmt[i] == b'.' {
1097            i += 1; // Skip the dot
1098            while i < stmt.len() && is_whitespace(stmt[i]) {
1099                i += 1;
1100            }
1101            // Continue to read the next identifier (table name)
1102        } else {
1103            break;
1104        }
1105    }
1106
1107    // Return the last part (table name), not the schema
1108    parts.pop()
1109}
1110
/// Whether `b` is one of the ASCII whitespace bytes this parser skips:
/// space, tab, line feed or carriage return.
///
/// Form feed and vertical tab are not treated as whitespace.
#[inline]
fn is_whitespace(b: u8) -> bool {
    b" \t\n\r".contains(&b)
}
1115
1116pub fn determine_buffer_size(file_size: u64) -> usize {
1117    if file_size > 1024 * 1024 * 1024 {
1118        MEDIUM_BUFFER_SIZE
1119    } else {
1120        SMALL_BUFFER_SIZE
1121    }
1122}