sql_splitter/parser/
mod.rs

1#[cfg(test)]
2mod edge_case_tests;
3
4use once_cell::sync::Lazy;
5use regex::bytes::Regex;
6use std::io::{BufRead, BufReader, Read};
7
/// Read-buffer capacity of 64 KiB; `determine_buffer_size` picks it for inputs up to 1 GiB.
pub const SMALL_BUFFER_SIZE: usize = 64 << 10;
/// Read-buffer capacity of 256 KiB; `determine_buffer_size` picks it for inputs above 1 GiB.
pub const MEDIUM_BUFFER_SIZE: usize = 256 << 10;
10
/// The SQL dump flavour the parser should expect.
///
/// The dialect decides how quoting, escaping and dump-specific constructs are treated while
/// statements are split and classified. `MySql` is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SqlDialect {
    /// MySQL / MariaDB `mysqldump` output: backtick-quoted identifiers, backslash escapes
    /// inside strings.
    #[default]
    MySql,
    /// PostgreSQL `pg_dump` output: double-quoted identifiers, `COPY ... FROM stdin` data
    /// blocks, dollar-quoted strings.
    Postgres,
    /// SQLite `.dump` output: double-quoted identifiers, `''` escapes inside strings.
    Sqlite,
}
22
23impl std::str::FromStr for SqlDialect {
24    type Err = String;
25
26    fn from_str(s: &str) -> Result<Self, Self::Err> {
27        match s.to_lowercase().as_str() {
28            "mysql" | "mariadb" => Ok(SqlDialect::MySql),
29            "postgres" | "postgresql" | "pg" => Ok(SqlDialect::Postgres),
30            "sqlite" | "sqlite3" => Ok(SqlDialect::Sqlite),
31            _ => Err(format!(
32                "Unknown dialect: {}. Valid options: mysql, postgres, sqlite",
33                s
34            )),
35        }
36    }
37}
38
39impl std::fmt::Display for SqlDialect {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            SqlDialect::MySql => write!(f, "mysql"),
43            SqlDialect::Postgres => write!(f, "postgres"),
44            SqlDialect::Sqlite => write!(f, "sqlite"),
45        }
46    }
47}
48
/// Classification of a statement as reported by `Parser::parse_statement`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StatementType {
    /// Anything the parser does not recognise.
    Unknown,
    /// `CREATE TABLE ...`
    CreateTable,
    /// `INSERT INTO ...`
    Insert,
    /// `CREATE INDEX ... ON <table> ...`
    CreateIndex,
    /// `ALTER TABLE ...`
    AlterTable,
    /// `DROP TABLE ...`
    DropTable,
    /// PostgreSQL `COPY` statement, e.g. the header of a `COPY ... FROM stdin` data block.
    Copy,
}
60
61static CREATE_TABLE_RE: Lazy<Regex> =
62    Lazy::new(|| Regex::new(r"(?i)^\s*CREATE\s+TABLE\s+`?([^\s`(]+)`?").unwrap());
63
64static INSERT_INTO_RE: Lazy<Regex> =
65    Lazy::new(|| Regex::new(r"(?i)^\s*INSERT\s+INTO\s+`?([^\s`(]+)`?").unwrap());
66
67static CREATE_INDEX_RE: Lazy<Regex> =
68    Lazy::new(|| Regex::new(r"(?i)ON\s+`?([^\s`(;]+)`?").unwrap());
69
70static ALTER_TABLE_RE: Lazy<Regex> =
71    Lazy::new(|| Regex::new(r"(?i)ALTER\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
72
73static DROP_TABLE_RE: Lazy<Regex> =
74    Lazy::new(|| Regex::new(r"(?i)DROP\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
75
76// PostgreSQL COPY statement regex
77static COPY_RE: Lazy<Regex> =
78    Lazy::new(|| Regex::new(r#"(?i)^\s*COPY\s+(?:ONLY\s+)?[`"]?([^\s`"(]+)[`"]?"#).unwrap());
79
80// More flexible table name regex that handles:
81// - Backticks: `table`
82// - Double quotes: "table"
83// - Schema qualified: schema.table, `schema`.`table`, "schema"."table"
84// - IF NOT EXISTS
85static CREATE_TABLE_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
86    Regex::new(r#"(?i)^\s*CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
87});
88
89static INSERT_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
90    Regex::new(
91        r#"(?i)^\s*INSERT\s+INTO\s+(?:ONLY\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#,
92    )
93    .unwrap()
94});
95
96pub struct Parser<R: Read> {
97    reader: BufReader<R>,
98    stmt_buffer: Vec<u8>,
99    dialect: SqlDialect,
100    /// For PostgreSQL: true when reading COPY data block
101    in_copy_data: bool,
102}
103
104impl<R: Read> Parser<R> {
105    pub fn new(reader: R, buffer_size: usize) -> Self {
106        Self::with_dialect(reader, buffer_size, SqlDialect::default())
107    }
108
109    pub fn with_dialect(reader: R, buffer_size: usize, dialect: SqlDialect) -> Self {
110        Self {
111            reader: BufReader::with_capacity(buffer_size, reader),
112            stmt_buffer: Vec::with_capacity(32 * 1024),
113            dialect,
114            in_copy_data: false,
115        }
116    }
117
118    pub fn read_statement(&mut self) -> std::io::Result<Option<Vec<u8>>> {
119        // If we're in PostgreSQL COPY data mode, read until we see the terminator
120        if self.in_copy_data {
121            return self.read_copy_data();
122        }
123
124        self.stmt_buffer.clear();
125
126        let mut inside_single_quote = false;
127        let mut inside_double_quote = false;
128        let mut escaped = false;
129        let mut in_line_comment = false;
130        // For PostgreSQL dollar-quoting: track the tag
131        let mut in_dollar_quote = false;
132        let mut dollar_tag: Vec<u8> = Vec::new();
133
134        loop {
135            let buf = self.reader.fill_buf()?;
136            if buf.is_empty() {
137                if self.stmt_buffer.is_empty() {
138                    return Ok(None);
139                }
140                let result = std::mem::take(&mut self.stmt_buffer);
141                return Ok(Some(result));
142            }
143
144            let mut consumed = 0;
145            let mut found_terminator = false;
146
147            for (i, &b) in buf.iter().enumerate() {
148                let inside_string = inside_single_quote || inside_double_quote || in_dollar_quote;
149
150                // End of line comment on newline
151                if in_line_comment {
152                    if b == b'\n' {
153                        in_line_comment = false;
154                    }
155                    continue;
156                }
157
158                if escaped {
159                    escaped = false;
160                    continue;
161                }
162
163                // Handle backslash escapes (MySQL style)
164                if b == b'\\' && inside_string && self.dialect == SqlDialect::MySql {
165                    escaped = true;
166                    continue;
167                }
168
169                // Handle line comments (-- to end of line)
170                if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
171                    in_line_comment = true;
172                    continue;
173                }
174
175                // Handle dollar-quoting for PostgreSQL
176                if self.dialect == SqlDialect::Postgres
177                    && !inside_single_quote
178                    && !inside_double_quote
179                {
180                    if b == b'$' && !in_dollar_quote {
181                        // Start of dollar-quote: scan for the closing $
182                        if let Some(end) = buf[i + 1..].iter().position(|&c| c == b'$') {
183                            dollar_tag = buf[i + 1..i + 1 + end].to_vec();
184                            in_dollar_quote = true;
185                            continue;
186                        }
187                    } else if b == b'$' && in_dollar_quote {
188                        // Potential end of dollar-quote
189                        let tag_len = dollar_tag.len();
190                        if i + 1 + tag_len < buf.len()
191                            && buf[i + 1..i + 1 + tag_len] == dollar_tag[..]
192                            && buf.get(i + 1 + tag_len) == Some(&b'$')
193                        {
194                            in_dollar_quote = false;
195                            dollar_tag.clear();
196                            continue;
197                        }
198                    }
199                }
200
201                if b == b'\'' && !inside_double_quote && !in_dollar_quote {
202                    inside_single_quote = !inside_single_quote;
203                } else if b == b'"' && !inside_single_quote && !in_dollar_quote {
204                    inside_double_quote = !inside_double_quote;
205                } else if b == b';' && !inside_string {
206                    self.stmt_buffer.extend_from_slice(&buf[..=i]);
207                    consumed = i + 1;
208                    found_terminator = true;
209                    break;
210                }
211            }
212
213            if found_terminator {
214                self.reader.consume(consumed);
215                let result = std::mem::take(&mut self.stmt_buffer);
216
217                // Check if this is a PostgreSQL COPY FROM stdin statement
218                if self.dialect == SqlDialect::Postgres && self.is_copy_from_stdin(&result) {
219                    self.in_copy_data = true;
220                }
221
222                return Ok(Some(result));
223            }
224
225            self.stmt_buffer.extend_from_slice(buf);
226            let len = buf.len();
227            self.reader.consume(len);
228        }
229    }
230
231    /// Check if statement is a PostgreSQL COPY FROM stdin
232    fn is_copy_from_stdin(&self, stmt: &[u8]) -> bool {
233        // Strip leading comments (pg_dump adds -- comments before COPY statements)
234        let stmt = strip_leading_comments_and_whitespace(stmt);
235        if stmt.len() < 4 {
236            return false;
237        }
238
239        // Take enough bytes to cover column lists - typical COPY statements are <500 bytes
240        let upper: Vec<u8> = stmt
241            .iter()
242            .take(500)
243            .map(|b| b.to_ascii_uppercase())
244            .collect();
245        upper.starts_with(b"COPY ")
246            && (upper.windows(10).any(|w| w == b"FROM STDIN")
247                || upper.windows(11).any(|w| w == b"FROM STDIN;"))
248    }
249
250    /// Read PostgreSQL COPY data block until we see the terminator line (\.)
251    fn read_copy_data(&mut self) -> std::io::Result<Option<Vec<u8>>> {
252        self.stmt_buffer.clear();
253
254        loop {
255            // First, fill the buffer and check if empty
256            let buf = self.reader.fill_buf()?;
257            if buf.is_empty() {
258                self.in_copy_data = false;
259                if self.stmt_buffer.is_empty() {
260                    return Ok(None);
261                }
262                return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
263            }
264
265            // Look for a newline in the buffer
266            let newline_pos = buf.iter().position(|&b| b == b'\n');
267
268            if let Some(i) = newline_pos {
269                // Include this newline
270                self.stmt_buffer.extend_from_slice(&buf[..=i]);
271                self.reader.consume(i + 1);
272
273                // Check if the line we just added ends the COPY block
274                // Looking for a line that is just "\.\n" or "\.\r\n"
275                if self.ends_with_copy_terminator() {
276                    self.in_copy_data = false;
277                    return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
278                }
279                // Continue reading - we need to process more lines
280            } else {
281                // No newline found, consume the whole buffer and continue
282                let len = buf.len();
283                self.stmt_buffer.extend_from_slice(buf);
284                self.reader.consume(len);
285            }
286        }
287    }
288
289    /// Check if buffer ends with the COPY terminator line (\.)
290    fn ends_with_copy_terminator(&self) -> bool {
291        let data = &self.stmt_buffer;
292        if data.len() < 2 {
293            return false;
294        }
295
296        // Look for a line that is just "\.\n" or "\.\r\n"
297        // We need to find the start of the last line
298        let last_newline = data[..data.len() - 1]
299            .iter()
300            .rposition(|&b| b == b'\n')
301            .map(|i| i + 1)
302            .unwrap_or(0);
303
304        let last_line = &data[last_newline..];
305
306        // Check if it's "\.\n" or "\.\r\n"
307        last_line == b"\\.\n" || last_line == b"\\.\r\n"
308    }
309
310    pub fn parse_statement(stmt: &[u8]) -> (StatementType, String) {
311        Self::parse_statement_with_dialect(stmt, SqlDialect::MySql)
312    }
313
314    /// Parse a statement with dialect-specific handling
315    pub fn parse_statement_with_dialect(
316        stmt: &[u8],
317        dialect: SqlDialect,
318    ) -> (StatementType, String) {
319        // Strip leading comments (e.g., pg_dump adds -- comments before statements)
320        let stmt = strip_leading_comments_and_whitespace(stmt);
321
322        if stmt.len() < 4 {
323            return (StatementType::Unknown, String::new());
324        }
325
326        let upper_prefix: Vec<u8> = stmt
327            .iter()
328            .take(25)
329            .map(|b| b.to_ascii_uppercase())
330            .collect();
331
332        // PostgreSQL COPY statement
333        if upper_prefix.starts_with(b"COPY ") {
334            if let Some(caps) = COPY_RE.captures(stmt) {
335                if let Some(m) = caps.get(1) {
336                    let name = String::from_utf8_lossy(m.as_bytes()).into_owned();
337                    // Handle schema.table - extract just the table name
338                    let table_name = name.split('.').last().unwrap_or(&name).to_string();
339                    return (StatementType::Copy, table_name);
340                }
341            }
342        }
343
344        if upper_prefix.starts_with(b"CREATE TABLE") {
345            // Try fast extraction first
346            if let Some(name) = extract_table_name_flexible(stmt, 12, dialect) {
347                return (StatementType::CreateTable, name);
348            }
349            // Fall back to flexible regex
350            if let Some(caps) = CREATE_TABLE_FLEXIBLE_RE.captures(stmt) {
351                if let Some(m) = caps.get(1) {
352                    return (
353                        StatementType::CreateTable,
354                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
355                    );
356                }
357            }
358            // Original regex as last resort
359            if let Some(caps) = CREATE_TABLE_RE.captures(stmt) {
360                if let Some(m) = caps.get(1) {
361                    return (
362                        StatementType::CreateTable,
363                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
364                    );
365                }
366            }
367        }
368
369        if upper_prefix.starts_with(b"INSERT INTO") || upper_prefix.starts_with(b"INSERT ONLY") {
370            if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
371                return (StatementType::Insert, name);
372            }
373            if let Some(caps) = INSERT_FLEXIBLE_RE.captures(stmt) {
374                if let Some(m) = caps.get(1) {
375                    return (
376                        StatementType::Insert,
377                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
378                    );
379                }
380            }
381            if let Some(caps) = INSERT_INTO_RE.captures(stmt) {
382                if let Some(m) = caps.get(1) {
383                    return (
384                        StatementType::Insert,
385                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
386                    );
387                }
388            }
389        }
390
391        if upper_prefix.starts_with(b"CREATE INDEX") {
392            if let Some(caps) = CREATE_INDEX_RE.captures(stmt) {
393                if let Some(m) = caps.get(1) {
394                    return (
395                        StatementType::CreateIndex,
396                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
397                    );
398                }
399            }
400        }
401
402        if upper_prefix.starts_with(b"ALTER TABLE") {
403            if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
404                return (StatementType::AlterTable, name);
405            }
406            if let Some(caps) = ALTER_TABLE_RE.captures(stmt) {
407                if let Some(m) = caps.get(1) {
408                    return (
409                        StatementType::AlterTable,
410                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
411                    );
412                }
413            }
414        }
415
416        if upper_prefix.starts_with(b"DROP TABLE") {
417            if let Some(name) = extract_table_name_flexible(stmt, 10, dialect) {
418                return (StatementType::DropTable, name);
419            }
420            if let Some(caps) = DROP_TABLE_RE.captures(stmt) {
421                if let Some(m) = caps.get(1) {
422                    return (
423                        StatementType::DropTable,
424                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
425                    );
426                }
427            }
428        }
429
430        (StatementType::Unknown, String::new())
431    }
432}
433
/// Returns `data` without its leading run of space, tab, LF and CR bytes.
///
/// Any other byte stops the trim; that includes form feed.
#[inline]
fn trim_ascii_start(data: &[u8]) -> &[u8] {
    let skipped = data
        .iter()
        .take_while(|&&b| b == b' ' || b == b'\t' || b == b'\n' || b == b'\r')
        .count();
    &data[skipped..]
}
442
443/// Strip leading whitespace and SQL line comments (`-- ...`) from a statement.
444/// This makes parsing robust to pg_dump-style comment blocks before statements.
445fn strip_leading_comments_and_whitespace(mut data: &[u8]) -> &[u8] {
446    loop {
447        // First trim leading ASCII whitespace
448        data = trim_ascii_start(data);
449
450        if data.len() >= 2 && data[0] == b'-' && data[1] == b'-' {
451            // Skip until end of line
452            if let Some(pos) = data.iter().position(|&b| b == b'\n') {
453                data = &data[pos + 1..];
454                continue;
455            } else {
456                // Comment runs to EOF, nothing left
457                return &[];
458            }
459        }
460
461        break;
462    }
463
464    data
465}
466
467/// Extract table name with support for:
468/// - IF NOT EXISTS
469/// - ONLY (PostgreSQL)
470/// - Schema-qualified names (schema.table)
471/// - Both backtick and double-quote quoting
472#[inline]
473fn extract_table_name_flexible(stmt: &[u8], offset: usize, dialect: SqlDialect) -> Option<String> {
474    let mut i = offset;
475
476    // Skip whitespace
477    while i < stmt.len() && is_whitespace(stmt[i]) {
478        i += 1;
479    }
480
481    if i >= stmt.len() {
482        return None;
483    }
484
485    // Check for IF NOT EXISTS
486    let upper_check: Vec<u8> = stmt[i..]
487        .iter()
488        .take(20)
489        .map(|b| b.to_ascii_uppercase())
490        .collect();
491    if upper_check.starts_with(b"IF NOT EXISTS") {
492        i += 13; // Skip "IF NOT EXISTS"
493        while i < stmt.len() && is_whitespace(stmt[i]) {
494            i += 1;
495        }
496    }
497
498    // Check for ONLY (PostgreSQL)
499    let upper_check: Vec<u8> = stmt[i..]
500        .iter()
501        .take(10)
502        .map(|b| b.to_ascii_uppercase())
503        .collect();
504    if upper_check.starts_with(b"ONLY ") || upper_check.starts_with(b"ONLY\t") {
505        i += 4;
506        while i < stmt.len() && is_whitespace(stmt[i]) {
507            i += 1;
508        }
509    }
510
511    if i >= stmt.len() {
512        return None;
513    }
514
515    // Read identifier (potentially schema-qualified)
516    let mut parts: Vec<String> = Vec::new();
517
518    loop {
519        // Determine quote character
520        let quote_char = match stmt.get(i) {
521            Some(b'`') if dialect == SqlDialect::MySql => {
522                i += 1;
523                Some(b'`')
524            }
525            Some(b'"') if dialect != SqlDialect::MySql => {
526                i += 1;
527                Some(b'"')
528            }
529            Some(b'"') => {
530                // Allow double quotes for MySQL too (though less common)
531                i += 1;
532                Some(b'"')
533            }
534            _ => None,
535        };
536
537        let start = i;
538
539        while i < stmt.len() {
540            let b = stmt[i];
541            if let Some(q) = quote_char {
542                if b == q {
543                    let name = &stmt[start..i];
544                    parts.push(String::from_utf8_lossy(name).into_owned());
545                    i += 1; // Skip closing quote
546                    break;
547                }
548            } else if is_whitespace(b) || b == b'(' || b == b';' || b == b',' || b == b'.' {
549                if i > start {
550                    let name = &stmt[start..i];
551                    parts.push(String::from_utf8_lossy(name).into_owned());
552                }
553                break;
554            }
555            i += 1;
556        }
557
558        // If at end of quoted name without finding close quote, bail
559        if quote_char.is_some() && i <= start {
560            break;
561        }
562
563        // Check for schema separator (.)
564        while i < stmt.len() && is_whitespace(stmt[i]) {
565            i += 1;
566        }
567
568        if i < stmt.len() && stmt[i] == b'.' {
569            i += 1; // Skip the dot
570            while i < stmt.len() && is_whitespace(stmt[i]) {
571                i += 1;
572            }
573            // Continue to read the next identifier (table name)
574        } else {
575            break;
576        }
577    }
578
579    // Return the last part (table name), not the schema
580    parts.pop()
581}
582
/// Returns `true` for space, tab, LF and CR.
///
/// Form feed and vertical tab are not treated as whitespace.
#[inline]
fn is_whitespace(b: u8) -> bool {
    b == b' ' || b == b'\t' || b == b'\n' || b == b'\r'
}
587
588pub fn determine_buffer_size(file_size: u64) -> usize {
589    if file_size > 1024 * 1024 * 1024 {
590        MEDIUM_BUFFER_SIZE
591    } else {
592        SMALL_BUFFER_SIZE
593    }
594}
595
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies `sql` with the default (MySQL) rules and checks the type and table name.
    fn assert_classified(sql: &[u8], expected_type: StatementType, expected_name: &str) {
        let (typ, name) = Parser::<&[u8]>::parse_statement(sql);
        assert_eq!(typ, expected_type);
        assert_eq!(name, expected_name);
    }

    /// Builds a default-dialect parser over an in-memory byte slice.
    fn parser_over(sql: &[u8]) -> Parser<&[u8]> {
        Parser::new(sql, 1024)
    }

    #[test]
    fn test_parse_create_table() {
        assert_classified(
            b"CREATE TABLE users (id INT);",
            StatementType::CreateTable,
            "users",
        );
    }

    #[test]
    fn test_parse_create_table_backticks() {
        assert_classified(
            b"CREATE TABLE `my_table` (id INT);",
            StatementType::CreateTable,
            "my_table",
        );
    }

    #[test]
    fn test_parse_insert() {
        assert_classified(
            b"INSERT INTO posts VALUES (1, 'test');",
            StatementType::Insert,
            "posts",
        );
    }

    #[test]
    fn test_parse_insert_backticks() {
        assert_classified(
            b"INSERT INTO `comments` VALUES (1);",
            StatementType::Insert,
            "comments",
        );
    }

    #[test]
    fn test_parse_alter_table() {
        assert_classified(
            b"ALTER TABLE orders ADD COLUMN status INT;",
            StatementType::AlterTable,
            "orders",
        );
    }

    #[test]
    fn test_parse_drop_table() {
        assert_classified(b"DROP TABLE temp_data;", StatementType::DropTable, "temp_data");
    }

    #[test]
    fn test_read_statement_basic() {
        let mut parser = parser_over(b"CREATE TABLE t1 (id INT); INSERT INTO t1 VALUES (1);");

        let first = parser.read_statement().unwrap().unwrap();
        assert_eq!(first, b"CREATE TABLE t1 (id INT);");

        // Whitespace between statements stays attached to the following statement.
        let second = parser.read_statement().unwrap().unwrap();
        assert_eq!(second, b" INSERT INTO t1 VALUES (1);");

        assert!(parser.read_statement().unwrap().is_none());
    }

    #[test]
    fn test_read_statement_with_strings() {
        let sql = b"INSERT INTO t1 VALUES ('hello; world');";
        let stmt = parser_over(sql).read_statement().unwrap().unwrap();
        assert_eq!(stmt, b"INSERT INTO t1 VALUES ('hello; world');");
    }

    #[test]
    fn test_read_statement_with_escaped_quotes() {
        let sql = b"INSERT INTO t1 VALUES ('it\\'s a test');";
        let stmt = parser_over(sql).read_statement().unwrap().unwrap();
        assert_eq!(stmt, b"INSERT INTO t1 VALUES ('it\\'s a test');");
    }
}
681
#[cfg(test)]
mod copy_tests {
    use super::*;
    use std::io::Cursor;

    /// Builds a PostgreSQL-dialect parser over `data`.
    fn postgres_parser(data: &[u8]) -> Parser<Cursor<&[u8]>> {
        Parser::with_dialect(Cursor::new(data), 1024, SqlDialect::Postgres)
    }

    /// Reads the next chunk and decodes it lossily for substring assertions.
    fn next_text<R: Read>(parser: &mut Parser<R>) -> String {
        let chunk = parser.read_statement().unwrap().unwrap();
        String::from_utf8_lossy(&chunk).into_owned()
    }

    #[test]
    fn test_copy_from_stdin_detection() {
        let mut parser = postgres_parser(b"COPY public.table_001 (id, col_int, col_varchar, col_text, col_decimal, created_at) FROM stdin;\n1\t6892\tvalue_1\tLorem ipsum\n\\.\n");

        // The COPY header comes first...
        let header = next_text(&mut parser);
        assert!(header.starts_with("COPY"), "First statement should be COPY");
        assert!(header.contains("FROM stdin"), "Should contain FROM stdin");

        // ...followed by the raw data block.
        let data = next_text(&mut parser);
        assert!(
            data.contains("1\t6892"),
            "Data block should contain first row"
        );
        assert!(
            data.ends_with("\\.\n"),
            "Data block should end with terminator"
        );
    }

    #[test]
    fn test_copy_with_leading_comments() {
        // pg_dump adds -- comments before COPY statements
        let mut parser = postgres_parser(b"--\n-- Data for Name: table_001\n--\n\nCOPY public.table_001 (id, name) FROM stdin;\n1\tfoo\n\\.\n");

        // The header (with its leading comments) must still classify as COPY.
        let header = parser.read_statement().unwrap().unwrap();
        let (stmt_type, table_name) =
            Parser::<&[u8]>::parse_statement_with_dialect(&header, SqlDialect::Postgres);
        assert_eq!(stmt_type, StatementType::Copy);
        assert_eq!(table_name, "table_001");

        let data = next_text(&mut parser);
        assert!(
            data.ends_with("\\.\n"),
            "Data block should end with terminator"
        );
    }
}