sql_splitter/parser/mod.rs

1#[cfg(test)]
2mod edge_case_tests;
3
4use once_cell::sync::Lazy;
5use regex::bytes::Regex;
6use std::io::{BufRead, BufReader, Read};
7
/// Reader buffer size (64 KiB) chosen by `determine_buffer_size` for dumps of at most 1 GiB.
pub const SMALL_BUFFER_SIZE: usize = 64 << 10;
/// Reader buffer size (256 KiB) chosen by `determine_buffer_size` for dumps larger than 1 GiB.
pub const MEDIUM_BUFFER_SIZE: usize = 256 << 10;
10
/// Dump flavour the parser should expect; it selects the quoting and escaping rules used
/// while splitting statements.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SqlDialect {
    /// MySQL / MariaDB `mysqldump` output: backtick identifiers, backslash escapes in strings.
    #[default]
    MySql,
    /// PostgreSQL `pg_dump` output: double-quoted identifiers, `COPY ... FROM stdin`,
    /// dollar-quoted strings.
    Postgres,
    /// SQLite `.dump` output: double-quoted identifiers, quotes escaped by doubling (`''`).
    Sqlite,
}

impl std::str::FromStr for SqlDialect {
    type Err = String;

    /// Parses a dialect name case-insensitively.
    ///
    /// Accepted spellings: `mysql`/`mariadb`, `postgres`/`postgresql`/`pg`, `sqlite`/`sqlite3`.
    /// Any other input yields an error message listing the valid options.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let dialect = match lowered.as_str() {
            "mysql" | "mariadb" => SqlDialect::MySql,
            "postgres" | "postgresql" | "pg" => SqlDialect::Postgres,
            "sqlite" | "sqlite3" => SqlDialect::Sqlite,
            _ => {
                return Err(format!(
                    "Unknown dialect: {}. Valid options: mysql, postgres, sqlite",
                    s
                ))
            }
        };
        Ok(dialect)
    }
}

impl std::fmt::Display for SqlDialect {
    /// Writes the canonical lowercase name, which `from_str` accepts back.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SqlDialect::MySql => "mysql",
            SqlDialect::Postgres => "postgres",
            SqlDialect::Sqlite => "sqlite",
        };
        f.write_str(name)
    }
}
48
49/// Result of dialect auto-detection
50#[derive(Debug, Clone)]
51pub struct DialectDetectionResult {
52    pub dialect: SqlDialect,
53    pub confidence: DialectConfidence,
54}
55
/// How strongly the header markers pointed at the detected dialect.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DialectConfidence {
    /// Winning score of 10 or more, e.g. a definitive marker such as `pg_dump` or `MySQL dump`.
    High,
    /// Winning score between 5 and 9: only likely markers were found.
    Medium,
    /// Winning score below 5, including the no-marker fallback to MySQL.
    Low,
}
66
/// Per-dialect marker tallies accumulated by `detect_dialect`; all start at zero.
#[derive(Default)]
struct DialectScore {
    /// Points for MySQL / MariaDB markers.
    mysql: u32,
    /// Points for PostgreSQL markers.
    postgres: u32,
    /// Points for SQLite markers.
    sqlite: u32,
}
73
74/// Detect SQL dialect from file header content.
75/// Reads up to 8KB and looks for dialect-specific markers.
76pub fn detect_dialect(header: &[u8]) -> DialectDetectionResult {
77    let mut score = DialectScore::default();
78
79    // High confidence markers (+10)
80    if contains_bytes(header, b"pg_dump") {
81        score.postgres += 10;
82    }
83    if contains_bytes(header, b"PostgreSQL database dump") {
84        score.postgres += 10;
85    }
86    if contains_bytes(header, b"MySQL dump") {
87        score.mysql += 10;
88    }
89    if contains_bytes(header, b"MariaDB dump") {
90        score.mysql += 10;
91    }
92    if contains_bytes(header, b"SQLite") {
93        score.sqlite += 10;
94    }
95
96    // Medium confidence markers (+5)
97    if contains_bytes(header, b"COPY ") && contains_bytes(header, b"FROM stdin") {
98        score.postgres += 5;
99    }
100    if contains_bytes(header, b"search_path") {
101        score.postgres += 5;
102    }
103    if contains_bytes(header, b"/*!40") || contains_bytes(header, b"/*!50") {
104        score.mysql += 5;
105    }
106    if contains_bytes(header, b"LOCK TABLES") {
107        score.mysql += 5;
108    }
109    if contains_bytes(header, b"PRAGMA") {
110        score.sqlite += 5;
111    }
112
113    // Low confidence markers (+2)
114    if contains_bytes(header, b"$$") {
115        score.postgres += 2;
116    }
117    if contains_bytes(header, b"CREATE EXTENSION") {
118        score.postgres += 2;
119    }
120    // BEGIN TRANSACTION is generic ANSI SQL, only slightly suggests SQLite
121    if contains_bytes(header, b"BEGIN TRANSACTION") {
122        score.sqlite += 2;
123    }
124    // Backticks suggest MySQL (could also appear in data/comments)
125    if header.contains(&b'`') {
126        score.mysql += 2;
127    }
128
129    // Determine winner and confidence
130    let max_score = score.mysql.max(score.postgres).max(score.sqlite);
131
132    if max_score == 0 {
133        return DialectDetectionResult {
134            dialect: SqlDialect::MySql,
135            confidence: DialectConfidence::Low,
136        };
137    }
138
139    let (dialect, confidence) = if score.postgres > score.mysql && score.postgres > score.sqlite {
140        let conf = if score.postgres >= 10 {
141            DialectConfidence::High
142        } else if score.postgres >= 5 {
143            DialectConfidence::Medium
144        } else {
145            DialectConfidence::Low
146        };
147        (SqlDialect::Postgres, conf)
148    } else if score.sqlite > score.mysql {
149        let conf = if score.sqlite >= 10 {
150            DialectConfidence::High
151        } else if score.sqlite >= 5 {
152            DialectConfidence::Medium
153        } else {
154            DialectConfidence::Low
155        };
156        (SqlDialect::Sqlite, conf)
157    } else {
158        let conf = if score.mysql >= 10 {
159            DialectConfidence::High
160        } else if score.mysql >= 5 {
161            DialectConfidence::Medium
162        } else {
163            DialectConfidence::Low
164        };
165        (SqlDialect::MySql, conf)
166    };
167
168    DialectDetectionResult {
169        dialect,
170        confidence,
171    }
172}
173
174/// Detect dialect from a file, reading first 8KB
175pub fn detect_dialect_from_file(path: &std::path::Path) -> std::io::Result<DialectDetectionResult> {
176    use std::fs::File;
177    use std::io::Read;
178
179    let mut file = File::open(path)?;
180    let mut buf = [0u8; 8192];
181    let n = file.read(&mut buf)?;
182    Ok(detect_dialect(&buf[..n]))
183}
184
/// Returns `true` if `needle` occurs as a contiguous subslice of `haystack`.
///
/// An empty needle is contained in every haystack, matching `str::contains("")`.
#[inline]
fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    // `slice::windows` panics for a window size of 0, so answer the empty needle up front.
    needle.is_empty()
        || haystack
            .windows(needle.len())
            .any(|window| window == needle)
}
191
/// Kind of SQL statement recognised by the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// Anything the parser does not classify.
    Unknown,
    /// `CREATE TABLE ...`
    CreateTable,
    /// `INSERT INTO ...`
    Insert,
    /// `CREATE INDEX ...`
    CreateIndex,
    /// `ALTER TABLE ...`
    AlterTable,
    /// `DROP TABLE ...`
    DropTable,
    /// PostgreSQL `COPY ... FROM stdin`
    Copy,
}

impl StatementType {
    /// `true` for schema (DDL) statements: table creation/alteration/removal and index creation.
    pub fn is_schema(&self) -> bool {
        match self {
            StatementType::CreateTable
            | StatementType::CreateIndex
            | StatementType::AlterTable
            | StatementType::DropTable => true,
            StatementType::Unknown | StatementType::Insert | StatementType::Copy => false,
        }
    }

    /// `true` for data (DML) statements: `INSERT` and PostgreSQL `COPY`.
    pub fn is_data(&self) -> bool {
        match self {
            StatementType::Insert | StatementType::Copy => true,
            StatementType::Unknown
            | StatementType::CreateTable
            | StatementType::CreateIndex
            | StatementType::AlterTable
            | StatementType::DropTable => false,
        }
    }
}
221
/// Which statement categories a split should include.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ContentFilter {
    /// Schema and data statements alike (the default).
    #[default]
    All,
    /// Schema statements only: CREATE TABLE, CREATE INDEX, ALTER TABLE, DROP TABLE.
    SchemaOnly,
    /// Data statements only: INSERT and COPY.
    DataOnly,
}
233
234static CREATE_TABLE_RE: Lazy<Regex> =
235    Lazy::new(|| Regex::new(r"(?i)^\s*CREATE\s+TABLE\s+`?([^\s`(]+)`?").unwrap());
236
237static INSERT_INTO_RE: Lazy<Regex> =
238    Lazy::new(|| Regex::new(r"(?i)^\s*INSERT\s+INTO\s+`?([^\s`(]+)`?").unwrap());
239
240static CREATE_INDEX_RE: Lazy<Regex> =
241    Lazy::new(|| Regex::new(r"(?i)ON\s+`?([^\s`(;]+)`?").unwrap());
242
243static ALTER_TABLE_RE: Lazy<Regex> =
244    Lazy::new(|| Regex::new(r"(?i)ALTER\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
245
246static DROP_TABLE_RE: Lazy<Regex> = Lazy::new(|| {
247    Regex::new(r#"(?i)DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?[`"]?([^\s`"`;]+)[`"]?"#).unwrap()
248});
249
250// PostgreSQL COPY statement regex
251static COPY_RE: Lazy<Regex> =
252    Lazy::new(|| Regex::new(r#"(?i)^\s*COPY\s+(?:ONLY\s+)?[`"]?([^\s`"(]+)[`"]?"#).unwrap());
253
254// More flexible table name regex that handles:
255// - Backticks: `table`
256// - Double quotes: "table"
257// - Schema qualified: schema.table, `schema`.`table`, "schema"."table"
258// - IF NOT EXISTS
259static CREATE_TABLE_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
260    Regex::new(r#"(?i)^\s*CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
261});
262
263static INSERT_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
264    Regex::new(
265        r#"(?i)^\s*INSERT\s+INTO\s+(?:ONLY\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#,
266    )
267    .unwrap()
268});
269
270pub struct Parser<R: Read> {
271    reader: BufReader<R>,
272    stmt_buffer: Vec<u8>,
273    dialect: SqlDialect,
274    /// For PostgreSQL: true when reading COPY data block
275    in_copy_data: bool,
276}
277
278impl<R: Read> Parser<R> {
279    #[allow(dead_code)]
280    pub fn new(reader: R, buffer_size: usize) -> Self {
281        Self::with_dialect(reader, buffer_size, SqlDialect::default())
282    }
283
284    pub fn with_dialect(reader: R, buffer_size: usize, dialect: SqlDialect) -> Self {
285        Self {
286            reader: BufReader::with_capacity(buffer_size, reader),
287            stmt_buffer: Vec::with_capacity(32 * 1024),
288            dialect,
289            in_copy_data: false,
290        }
291    }
292
293    pub fn read_statement(&mut self) -> std::io::Result<Option<Vec<u8>>> {
294        // If we're in PostgreSQL COPY data mode, read until we see the terminator
295        if self.in_copy_data {
296            return self.read_copy_data();
297        }
298
299        self.stmt_buffer.clear();
300
301        let mut inside_single_quote = false;
302        let mut inside_double_quote = false;
303        let mut escaped = false;
304        let mut in_line_comment = false;
305        // For PostgreSQL dollar-quoting: track the tag
306        let mut in_dollar_quote = false;
307        let mut dollar_tag: Vec<u8> = Vec::new();
308
309        loop {
310            let buf = self.reader.fill_buf()?;
311            if buf.is_empty() {
312                if self.stmt_buffer.is_empty() {
313                    return Ok(None);
314                }
315                let result = std::mem::take(&mut self.stmt_buffer);
316                return Ok(Some(result));
317            }
318
319            let mut consumed = 0;
320            let mut found_terminator = false;
321
322            for (i, &b) in buf.iter().enumerate() {
323                let inside_string = inside_single_quote || inside_double_quote || in_dollar_quote;
324
325                // End of line comment on newline
326                if in_line_comment {
327                    if b == b'\n' {
328                        in_line_comment = false;
329                    }
330                    continue;
331                }
332
333                if escaped {
334                    escaped = false;
335                    continue;
336                }
337
338                // Handle backslash escapes (MySQL style)
339                if b == b'\\' && inside_string && self.dialect == SqlDialect::MySql {
340                    escaped = true;
341                    continue;
342                }
343
344                // Handle line comments (-- to end of line)
345                if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
346                    in_line_comment = true;
347                    continue;
348                }
349
350                // Handle dollar-quoting for PostgreSQL
351                if self.dialect == SqlDialect::Postgres
352                    && !inside_single_quote
353                    && !inside_double_quote
354                {
355                    if b == b'$' && !in_dollar_quote {
356                        // Start of dollar-quote: scan for the closing $
357                        if let Some(end) = buf[i + 1..].iter().position(|&c| c == b'$') {
358                            let tag_bytes = &buf[i + 1..i + 1 + end];
359
360                            // Validate tag: must be empty OR identifier-like [A-Za-z_][A-Za-z0-9_]*
361                            let is_valid_tag = if tag_bytes.is_empty() {
362                                true
363                            } else {
364                                let mut iter = tag_bytes.iter();
365                                match iter.next() {
366                                    Some(&first)
367                                        if first.is_ascii_alphabetic() || first == b'_' =>
368                                    {
369                                        iter.all(|&c| c.is_ascii_alphanumeric() || c == b'_')
370                                    }
371                                    _ => false,
372                                }
373                            };
374
375                            if is_valid_tag {
376                                dollar_tag = tag_bytes.to_vec();
377                                in_dollar_quote = true;
378                                continue;
379                            }
380                            // Invalid tag - treat $ as normal character
381                        }
382                    } else if b == b'$' && in_dollar_quote {
383                        // Potential end of dollar-quote
384                        let tag_len = dollar_tag.len();
385                        if i + 1 + tag_len < buf.len()
386                            && buf[i + 1..i + 1 + tag_len] == dollar_tag[..]
387                            && buf.get(i + 1 + tag_len) == Some(&b'$')
388                        {
389                            in_dollar_quote = false;
390                            dollar_tag.clear();
391                            continue;
392                        }
393                    }
394                }
395
396                if b == b'\'' && !inside_double_quote && !in_dollar_quote {
397                    inside_single_quote = !inside_single_quote;
398                } else if b == b'"' && !inside_single_quote && !in_dollar_quote {
399                    inside_double_quote = !inside_double_quote;
400                } else if b == b';' && !inside_string {
401                    self.stmt_buffer.extend_from_slice(&buf[..=i]);
402                    consumed = i + 1;
403                    found_terminator = true;
404                    break;
405                }
406            }
407
408            if found_terminator {
409                self.reader.consume(consumed);
410                let result = std::mem::take(&mut self.stmt_buffer);
411
412                // Check if this is a PostgreSQL COPY FROM stdin statement
413                if self.dialect == SqlDialect::Postgres && self.is_copy_from_stdin(&result) {
414                    self.in_copy_data = true;
415                }
416
417                return Ok(Some(result));
418            }
419
420            self.stmt_buffer.extend_from_slice(buf);
421            let len = buf.len();
422            self.reader.consume(len);
423        }
424    }
425
426    /// Check if statement is a PostgreSQL COPY FROM stdin
427    fn is_copy_from_stdin(&self, stmt: &[u8]) -> bool {
428        // Strip leading comments (pg_dump adds -- comments before COPY statements)
429        let stmt = strip_leading_comments_and_whitespace(stmt);
430        if stmt.len() < 4 {
431            return false;
432        }
433
434        // Take enough bytes to cover column lists - typical COPY statements are <500 bytes
435        let upper: Vec<u8> = stmt
436            .iter()
437            .take(500)
438            .map(|b| b.to_ascii_uppercase())
439            .collect();
440        upper.starts_with(b"COPY ")
441            && (upper.windows(10).any(|w| w == b"FROM STDIN")
442                || upper.windows(11).any(|w| w == b"FROM STDIN;"))
443    }
444
445    /// Read PostgreSQL COPY data block until we see the terminator line (\.)
446    fn read_copy_data(&mut self) -> std::io::Result<Option<Vec<u8>>> {
447        self.stmt_buffer.clear();
448
449        loop {
450            // First, fill the buffer and check if empty
451            let buf = self.reader.fill_buf()?;
452            if buf.is_empty() {
453                self.in_copy_data = false;
454                if self.stmt_buffer.is_empty() {
455                    return Ok(None);
456                }
457                return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
458            }
459
460            // Look for a newline in the buffer
461            let newline_pos = buf.iter().position(|&b| b == b'\n');
462
463            if let Some(i) = newline_pos {
464                // Include this newline
465                self.stmt_buffer.extend_from_slice(&buf[..=i]);
466                self.reader.consume(i + 1);
467
468                // Check if the line we just added ends the COPY block
469                // Looking for a line that is just "\.\n" or "\.\r\n"
470                if self.ends_with_copy_terminator() {
471                    self.in_copy_data = false;
472                    return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
473                }
474                // Continue reading - we need to process more lines
475            } else {
476                // No newline found, consume the whole buffer and continue
477                let len = buf.len();
478                self.stmt_buffer.extend_from_slice(buf);
479                self.reader.consume(len);
480            }
481        }
482    }
483
484    /// Check if buffer ends with the COPY terminator line (\.)
485    fn ends_with_copy_terminator(&self) -> bool {
486        let data = &self.stmt_buffer;
487        if data.len() < 2 {
488            return false;
489        }
490
491        // Look for a line that is just "\.\n" or "\.\r\n"
492        // We need to find the start of the last line
493        let last_newline = data[..data.len() - 1]
494            .iter()
495            .rposition(|&b| b == b'\n')
496            .map(|i| i + 1)
497            .unwrap_or(0);
498
499        let last_line = &data[last_newline..];
500
501        // Check if it's "\.\n" or "\.\r\n"
502        last_line == b"\\.\n" || last_line == b"\\.\r\n"
503    }
504
505    #[allow(dead_code)]
506    pub fn parse_statement(stmt: &[u8]) -> (StatementType, String) {
507        Self::parse_statement_with_dialect(stmt, SqlDialect::MySql)
508    }
509
510    /// Parse a statement with dialect-specific handling
511    pub fn parse_statement_with_dialect(
512        stmt: &[u8],
513        dialect: SqlDialect,
514    ) -> (StatementType, String) {
515        // Strip leading comments (e.g., pg_dump adds -- comments before statements)
516        let stmt = strip_leading_comments_and_whitespace(stmt);
517
518        if stmt.len() < 4 {
519            return (StatementType::Unknown, String::new());
520        }
521
522        let upper_prefix: Vec<u8> = stmt
523            .iter()
524            .take(25)
525            .map(|b| b.to_ascii_uppercase())
526            .collect();
527
528        // PostgreSQL COPY statement
529        if upper_prefix.starts_with(b"COPY ") {
530            if let Some(caps) = COPY_RE.captures(stmt) {
531                if let Some(m) = caps.get(1) {
532                    let name = String::from_utf8_lossy(m.as_bytes()).into_owned();
533                    // Handle schema.table - extract just the table name
534                    let table_name = name.split('.').next_back().unwrap_or(&name).to_string();
535                    return (StatementType::Copy, table_name);
536                }
537            }
538        }
539
540        if upper_prefix.starts_with(b"CREATE TABLE") {
541            // Try fast extraction first
542            if let Some(name) = extract_table_name_flexible(stmt, 12, dialect) {
543                return (StatementType::CreateTable, name);
544            }
545            // Fall back to flexible regex
546            if let Some(caps) = CREATE_TABLE_FLEXIBLE_RE.captures(stmt) {
547                if let Some(m) = caps.get(1) {
548                    return (
549                        StatementType::CreateTable,
550                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
551                    );
552                }
553            }
554            // Original regex as last resort
555            if let Some(caps) = CREATE_TABLE_RE.captures(stmt) {
556                if let Some(m) = caps.get(1) {
557                    return (
558                        StatementType::CreateTable,
559                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
560                    );
561                }
562            }
563        }
564
565        if upper_prefix.starts_with(b"INSERT INTO") || upper_prefix.starts_with(b"INSERT ONLY") {
566            if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
567                return (StatementType::Insert, name);
568            }
569            if let Some(caps) = INSERT_FLEXIBLE_RE.captures(stmt) {
570                if let Some(m) = caps.get(1) {
571                    return (
572                        StatementType::Insert,
573                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
574                    );
575                }
576            }
577            if let Some(caps) = INSERT_INTO_RE.captures(stmt) {
578                if let Some(m) = caps.get(1) {
579                    return (
580                        StatementType::Insert,
581                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
582                    );
583                }
584            }
585        }
586
587        if upper_prefix.starts_with(b"CREATE INDEX") {
588            if let Some(caps) = CREATE_INDEX_RE.captures(stmt) {
589                if let Some(m) = caps.get(1) {
590                    return (
591                        StatementType::CreateIndex,
592                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
593                    );
594                }
595            }
596        }
597
598        if upper_prefix.starts_with(b"ALTER TABLE") {
599            if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
600                return (StatementType::AlterTable, name);
601            }
602            if let Some(caps) = ALTER_TABLE_RE.captures(stmt) {
603                if let Some(m) = caps.get(1) {
604                    return (
605                        StatementType::AlterTable,
606                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
607                    );
608                }
609            }
610        }
611
612        if upper_prefix.starts_with(b"DROP TABLE") {
613            if let Some(name) = extract_table_name_flexible(stmt, 10, dialect) {
614                return (StatementType::DropTable, name);
615            }
616            if let Some(caps) = DROP_TABLE_RE.captures(stmt) {
617                if let Some(m) = caps.get(1) {
618                    return (
619                        StatementType::DropTable,
620                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
621                    );
622                }
623            }
624        }
625
626        (StatementType::Unknown, String::new())
627    }
628}
629
/// Drops leading spaces, tabs, line feeds and carriage returns from `data`.
#[inline]
fn trim_ascii_start(data: &[u8]) -> &[u8] {
    let mut rest = data;
    while let [b' ' | b'\t' | b'\n' | b'\r', tail @ ..] = rest {
        rest = tail;
    }
    rest
}
638
639/// Strip leading whitespace and SQL line comments (`-- ...`) from a statement.
640/// This makes parsing robust to pg_dump-style comment blocks before statements.
641fn strip_leading_comments_and_whitespace(mut data: &[u8]) -> &[u8] {
642    loop {
643        // First trim leading ASCII whitespace
644        data = trim_ascii_start(data);
645
646        if data.len() >= 2 && data[0] == b'-' && data[1] == b'-' {
647            // Skip until end of line
648            if let Some(pos) = data.iter().position(|&b| b == b'\n') {
649                data = &data[pos + 1..];
650                continue;
651            } else {
652                // Comment runs to EOF, nothing left
653                return &[];
654            }
655        }
656
657        break;
658    }
659
660    data
661}
662
663/// Extract table name with support for:
664/// - IF NOT EXISTS
665/// - ONLY (PostgreSQL)
666/// - Schema-qualified names (schema.table)
667/// - Both backtick and double-quote quoting
668#[inline]
669fn extract_table_name_flexible(stmt: &[u8], offset: usize, dialect: SqlDialect) -> Option<String> {
670    let mut i = offset;
671
672    // Skip whitespace
673    while i < stmt.len() && is_whitespace(stmt[i]) {
674        i += 1;
675    }
676
677    if i >= stmt.len() {
678        return None;
679    }
680
681    // Check for IF NOT EXISTS or IF EXISTS
682    let upper_check: Vec<u8> = stmt[i..]
683        .iter()
684        .take(20)
685        .map(|b| b.to_ascii_uppercase())
686        .collect();
687    if upper_check.starts_with(b"IF NOT EXISTS") {
688        i += 13; // Skip "IF NOT EXISTS"
689        while i < stmt.len() && is_whitespace(stmt[i]) {
690            i += 1;
691        }
692    } else if upper_check.starts_with(b"IF EXISTS") {
693        i += 9; // Skip "IF EXISTS"
694        while i < stmt.len() && is_whitespace(stmt[i]) {
695            i += 1;
696        }
697    }
698
699    // Check for ONLY (PostgreSQL)
700    let upper_check: Vec<u8> = stmt[i..]
701        .iter()
702        .take(10)
703        .map(|b| b.to_ascii_uppercase())
704        .collect();
705    if upper_check.starts_with(b"ONLY ") || upper_check.starts_with(b"ONLY\t") {
706        i += 4;
707        while i < stmt.len() && is_whitespace(stmt[i]) {
708            i += 1;
709        }
710    }
711
712    if i >= stmt.len() {
713        return None;
714    }
715
716    // Read identifier (potentially schema-qualified)
717    let mut parts: Vec<String> = Vec::new();
718
719    loop {
720        // Determine quote character
721        let quote_char = match stmt.get(i) {
722            Some(b'`') if dialect == SqlDialect::MySql => {
723                i += 1;
724                Some(b'`')
725            }
726            Some(b'"') if dialect != SqlDialect::MySql => {
727                i += 1;
728                Some(b'"')
729            }
730            Some(b'"') => {
731                // Allow double quotes for MySQL too (though less common)
732                i += 1;
733                Some(b'"')
734            }
735            _ => None,
736        };
737
738        let start = i;
739
740        while i < stmt.len() {
741            let b = stmt[i];
742            if let Some(q) = quote_char {
743                if b == q {
744                    let name = &stmt[start..i];
745                    parts.push(String::from_utf8_lossy(name).into_owned());
746                    i += 1; // Skip closing quote
747                    break;
748                }
749            } else if is_whitespace(b) || b == b'(' || b == b';' || b == b',' || b == b'.' {
750                if i > start {
751                    let name = &stmt[start..i];
752                    parts.push(String::from_utf8_lossy(name).into_owned());
753                }
754                break;
755            }
756            i += 1;
757        }
758
759        // If at end of quoted name without finding close quote, bail
760        if quote_char.is_some() && i <= start {
761            break;
762        }
763
764        // Check for schema separator (.)
765        while i < stmt.len() && is_whitespace(stmt[i]) {
766            i += 1;
767        }
768
769        if i < stmt.len() && stmt[i] == b'.' {
770            i += 1; // Skip the dot
771            while i < stmt.len() && is_whitespace(stmt[i]) {
772                i += 1;
773            }
774            // Continue to read the next identifier (table name)
775        } else {
776            break;
777        }
778    }
779
780    // Return the last part (table name), not the schema
781    parts.pop()
782}
783
/// Returns `true` for the whitespace bytes the parser skips: space, tab, LF and CR.
#[inline]
fn is_whitespace(b: u8) -> bool {
    b == b' ' || b == b'\t' || b == b'\n' || b == b'\r'
}
788
789pub fn determine_buffer_size(file_size: u64) -> usize {
790    if file_size > 1024 * 1024 * 1024 {
791        MEDIUM_BUFFER_SIZE
792    } else {
793        SMALL_BUFFER_SIZE
794    }
795}
796
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies `stmt` with the default (MySQL) dialect.
    fn classify(stmt: &[u8]) -> (StatementType, String) {
        Parser::<&[u8]>::parse_statement(stmt)
    }

    /// Pulls the next statement out of `parser`; reading from a byte slice cannot fail.
    fn next_statement(parser: &mut Parser<&[u8]>) -> Option<Vec<u8>> {
        parser
            .read_statement()
            .expect("reading from a byte slice cannot fail")
    }

    #[test]
    fn test_parse_create_table() {
        assert_eq!(
            classify(b"CREATE TABLE users (id INT);"),
            (StatementType::CreateTable, String::from("users"))
        );
    }

    #[test]
    fn test_parse_create_table_backticks() {
        assert_eq!(
            classify(b"CREATE TABLE `my_table` (id INT);"),
            (StatementType::CreateTable, String::from("my_table"))
        );
    }

    #[test]
    fn test_parse_insert() {
        assert_eq!(
            classify(b"INSERT INTO posts VALUES (1, 'test');"),
            (StatementType::Insert, String::from("posts"))
        );
    }

    #[test]
    fn test_parse_insert_backticks() {
        assert_eq!(
            classify(b"INSERT INTO `comments` VALUES (1);"),
            (StatementType::Insert, String::from("comments"))
        );
    }

    #[test]
    fn test_parse_alter_table() {
        assert_eq!(
            classify(b"ALTER TABLE orders ADD COLUMN status INT;"),
            (StatementType::AlterTable, String::from("orders"))
        );
    }

    #[test]
    fn test_parse_drop_table() {
        assert_eq!(
            classify(b"DROP TABLE temp_data;"),
            (StatementType::DropTable, String::from("temp_data"))
        );
    }

    #[test]
    fn test_read_statement_basic() {
        let sql: &[u8] = b"CREATE TABLE t1 (id INT); INSERT INTO t1 VALUES (1);";
        let mut parser = Parser::new(sql, 1024);

        let expected: [&[u8]; 2] = [b"CREATE TABLE t1 (id INT);", b" INSERT INTO t1 VALUES (1);"];
        for want in expected {
            assert_eq!(next_statement(&mut parser).as_deref(), Some(want));
        }
        assert!(next_statement(&mut parser).is_none());
    }

    #[test]
    fn test_read_statement_with_strings() {
        let sql: &[u8] = b"INSERT INTO t1 VALUES ('hello; world');";
        let mut parser = Parser::new(sql, 1024);

        assert_eq!(next_statement(&mut parser).as_deref(), Some(sql));
    }

    #[test]
    fn test_read_statement_with_escaped_quotes() {
        let sql: &[u8] = b"INSERT INTO t1 VALUES ('it\\'s a test');";
        let mut parser = Parser::new(sql, 1024);

        assert_eq!(next_statement(&mut parser).as_deref(), Some(sql));
    }
}
882
#[cfg(test)]
mod copy_tests {
    use super::*;
    use std::io::Cursor;

    /// Builds a PostgreSQL-mode parser with a 1 KiB buffer over an in-memory dump.
    fn pg_parser(data: &[u8]) -> Parser<Cursor<&[u8]>> {
        Parser::with_dialect(Cursor::new(data), 1024, SqlDialect::Postgres)
    }

    /// Reads the next statement as lossily decoded text, panicking if none is left.
    fn next_text(parser: &mut Parser<Cursor<&[u8]>>) -> String {
        let bytes = parser
            .read_statement()
            .expect("reading from memory cannot fail")
            .expect("another statement was expected");
        String::from_utf8_lossy(&bytes).into_owned()
    }

    #[test]
    fn test_copy_from_stdin_detection() {
        let data: &[u8] = b"COPY public.table_001 (id, col_int, col_varchar, col_text, col_decimal, created_at) FROM stdin;\n1\t6892\tvalue_1\tLorem ipsum\n\\.\n";
        let mut parser = pg_parser(data);

        // The COPY header comes first...
        let header = next_text(&mut parser);
        assert!(header.starts_with("COPY"), "First statement should be COPY");
        assert!(header.contains("FROM stdin"), "Should contain FROM stdin");

        // ...followed by its data block.
        let block = next_text(&mut parser);
        assert!(block.contains("1\t6892"), "Data block should contain first row");
        assert!(block.ends_with("\\.\n"), "Data block should end with terminator");
    }

    #[test]
    fn test_copy_with_leading_comments() {
        // pg_dump writes `--` comment lines before each COPY statement.
        let data: &[u8] = b"--\n-- Data for Name: table_001\n--\n\nCOPY public.table_001 (id, name) FROM stdin;\n1\tfoo\n\\.\n";
        let mut parser = pg_parser(data);

        let header = parser
            .read_statement()
            .expect("reading from memory cannot fail")
            .expect("COPY header expected");
        let (stmt_type, table_name) =
            Parser::<&[u8]>::parse_statement_with_dialect(&header, SqlDialect::Postgres);
        assert_eq!(stmt_type, StatementType::Copy);
        assert_eq!(table_name, "table_001");

        let block = next_text(&mut parser);
        assert!(block.ends_with("\\.\n"), "Data block should end with terminator");
    }
}
936
#[cfg(test)]
mod dialect_detection_tests {
    use super::*;

    /// Runs [`detect_dialect`] on `header`, asserts that `expected` is the detected dialect,
    /// and returns the reported confidence so callers can pin it where it matters.
    fn detected_confidence(header: &[u8], expected: SqlDialect) -> DialectConfidence {
        let outcome = detect_dialect(header);
        assert_eq!(outcome.dialect, expected);
        outcome.confidence
    }

    #[test]
    fn test_detect_mysql_dump_header() {
        let header = b"-- MySQL dump 10.13  Distrib 8.0.32, for Linux (x86_64)
--
-- Host: localhost    Database: mydb
-- ------------------------------------------------------
-- Server version	8.0.32

/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;
";
        assert_eq!(
            detected_confidence(header, SqlDialect::MySql),
            DialectConfidence::High
        );
    }

    #[test]
    fn test_detect_mariadb_dump_header() {
        // MariaDB dumps are handled by the MySQL dialect.
        let header = b"-- MariaDB dump 10.19  Distrib 10.11.2-MariaDB
--
-- Host: localhost    Database: test
";
        assert_eq!(
            detected_confidence(header, SqlDialect::MySql),
            DialectConfidence::High
        );
    }

    #[test]
    fn test_detect_postgres_pgdump_header() {
        let header = b"--
-- PostgreSQL database dump
--

-- Dumped from database version 15.2
-- Dumped by pg_dump version 15.2

SET statement_timeout = 0;
SET search_path = public, pg_catalog;
";
        assert_eq!(
            detected_confidence(header, SqlDialect::Postgres),
            DialectConfidence::High
        );
    }

    #[test]
    fn test_detect_postgres_copy_statement() {
        let header = b"COPY public.users (id, name, email) FROM stdin;
1\tAlice\talice@example.com
2\tBob\tbob@example.com
\\.
";
        assert_eq!(
            detected_confidence(header, SqlDialect::Postgres),
            DialectConfidence::Medium
        );
    }

    #[test]
    fn test_detect_postgres_dollar_quoting() {
        let header = b"CREATE OR REPLACE FUNCTION test() RETURNS void AS $$
BEGIN
    RAISE NOTICE 'Hello';
END;
$$ LANGUAGE plpgsql;
";
        // Only the dialect is pinned here; confidence is not asserted.
        detected_confidence(header, SqlDialect::Postgres);
    }

    #[test]
    fn test_detect_sqlite_dump_header() {
        // Real sqlite3 .dump output has a comment at the top
        let header = b"-- SQLite database dump
PRAGMA foreign_keys=OFF;
BEGIN TRANSACTION;
CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT);
INSERT INTO users VALUES(1,'Alice');
COMMIT;
";
        // The SQLite comment, PRAGMA and BEGIN TRANSACTION together must reach High.
        assert_eq!(
            detected_confidence(header, SqlDialect::Sqlite),
            DialectConfidence::High
        );
    }

    #[test]
    fn test_detect_sqlite_pragma_only() {
        let header = b"PRAGMA foreign_keys=OFF;
CREATE TABLE test (id INT);
";
        assert_eq!(
            detected_confidence(header, SqlDialect::Sqlite),
            DialectConfidence::Medium
        );
    }

    #[test]
    fn test_detect_mysql_backticks() {
        let header = b"CREATE TABLE `users` (
  `id` int NOT NULL AUTO_INCREMENT,
  `name` varchar(255) DEFAULT NULL,
  PRIMARY KEY (`id`)
);
";
        // Only the dialect is pinned here; confidence is not asserted.
        detected_confidence(header, SqlDialect::MySql);
    }

    #[test]
    fn test_detect_mysql_conditional_comments() {
        let header = b"/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;
/*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;
/*!50503 SET NAMES utf8mb4 */;
";
        assert_eq!(
            detected_confidence(header, SqlDialect::MySql),
            DialectConfidence::Medium
        );
    }

    #[test]
    fn test_detect_mysql_lock_tables() {
        let header = b"LOCK TABLES `users` WRITE;
INSERT INTO `users` VALUES (1,'test');
UNLOCK TABLES;
";
        assert_eq!(
            detected_confidence(header, SqlDialect::MySql),
            DialectConfidence::Medium
        );
    }

    #[test]
    fn test_detect_empty_defaults_to_mysql() {
        assert_eq!(
            detected_confidence(b"", SqlDialect::MySql),
            DialectConfidence::Low
        );
    }

    #[test]
    fn test_detect_generic_sql_defaults_to_mysql() {
        let header = b"CREATE TABLE users (id INT, name VARCHAR(100));
INSERT INTO users VALUES (1, 'Alice');
";
        assert_eq!(
            detected_confidence(header, SqlDialect::MySql),
            DialectConfidence::Low
        );
    }

    #[test]
    fn test_detect_postgres_create_extension() {
        let header = b"CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";
CREATE TABLE users (id uuid DEFAULT uuid_generate_v4());
";
        // Only the dialect is pinned here; confidence is not asserted.
        detected_confidence(header, SqlDialect::Postgres);
    }

    #[test]
    fn test_detect_sqlite_comment() {
        let header = b"-- SQLite database dump
-- Created by sqlite3

CREATE TABLE test (id INTEGER);
";
        assert_eq!(
            detected_confidence(header, SqlDialect::Sqlite),
            DialectConfidence::High
        );
    }

    #[test]
    fn test_scoring_postgres_beats_mysql_backticks() {
        // Backticks inside data values must not outweigh a pg_dump header.
        let header = b"--
-- PostgreSQL database dump
--
-- Dumped by pg_dump version 15.2

INSERT INTO notes VALUES (1, 'Use `code` for inline code');
";
        assert_eq!(
            detected_confidence(header, SqlDialect::Postgres),
            DialectConfidence::High
        );
    }

    #[test]
    fn test_begin_transaction_alone_is_low_confidence() {
        // BEGIN TRANSACTION is generic SQL: it tips the result to SQLite, but only weakly.
        let header = b"BEGIN TRANSACTION;
CREATE TABLE t (id INTEGER);
COMMIT;
";
        assert_eq!(
            detected_confidence(header, SqlDialect::Sqlite),
            DialectConfidence::Low
        );
    }

    #[test]
    fn test_backticks_only_is_low_confidence() {
        // Backtick quoting on its own points at MySQL but is not a definitive marker.
        let header = b"CREATE TABLE `users` (id INT);
INSERT INTO `users` VALUES (1);
";
        assert_eq!(
            detected_confidence(header, SqlDialect::MySql),
            DialectConfidence::Low
        );
    }

    #[test]
    fn test_conflicting_markers_postgres_wins() {
        // A PostgreSQL dump header must win over MySQL-looking backticks in row data,
        // and with High confidence.
        let header = b"-- PostgreSQL database dump
SET search_path = public;
INSERT INTO notes VALUES (1, 'Use `backticks` for code');
";
        assert_eq!(
            detected_confidence(header, SqlDialect::Postgres),
            DialectConfidence::High
        );
    }
}