sql_splitter/parser/
mod.rs

1#[cfg(test)]
2mod edge_case_tests;
3
4use once_cell::sync::Lazy;
5use regex::bytes::Regex;
6use std::io::{BufRead, BufReader, Read};
7
/// 64 KiB read buffer; chosen by `determine_buffer_size` for inputs up to 1 GiB.
pub const SMALL_BUFFER_SIZE: usize = 64 << 10;
/// 256 KiB read buffer; chosen by `determine_buffer_size` for inputs larger than 1 GiB.
pub const MEDIUM_BUFFER_SIZE: usize = 256 << 10;
10
/// SQL dialect whose dump format and quoting rules the parser follows.
///
/// Parsed case-insensitively from `mysql`/`mariadb`, `postgres`/`postgresql`/`pg` and
/// `sqlite`/`sqlite3`; displayed as `mysql`, `postgres` or `sqlite`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SqlDialect {
    /// MySQL/MariaDB mysqldump format (backtick quoting, backslash escapes). The default.
    #[default]
    MySql,
    /// PostgreSQL pg_dump format (double-quote identifiers, COPY FROM stdin, dollar-quoting).
    Postgres,
    /// SQLite .dump format (double-quote identifiers, '' escapes).
    Sqlite,
}

impl std::str::FromStr for SqlDialect {
    type Err = String;

    /// Look `s` up (case-insensitively) in the table of accepted dialect spellings.
    ///
    /// # Errors
    ///
    /// Returns a message naming `s` and the valid options when it is not recognised.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const ALIASES: [(&str, SqlDialect); 7] = [
            ("mysql", SqlDialect::MySql),
            ("mariadb", SqlDialect::MySql),
            ("postgres", SqlDialect::Postgres),
            ("postgresql", SqlDialect::Postgres),
            ("pg", SqlDialect::Postgres),
            ("sqlite", SqlDialect::Sqlite),
            ("sqlite3", SqlDialect::Sqlite),
        ];

        let wanted = s.to_lowercase();
        ALIASES
            .iter()
            .find(|(alias, _)| *alias == wanted)
            .map(|&(_, dialect)| dialect)
            .ok_or_else(|| format!("Unknown dialect: {}. Valid options: mysql, postgres, sqlite", s))
    }
}

impl std::fmt::Display for SqlDialect {
    /// Write the canonical lowercase name, which `FromStr` accepts back.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            SqlDialect::MySql => "mysql",
            SqlDialect::Postgres => "postgres",
            SqlDialect::Sqlite => "sqlite",
        };
        f.write_str(name)
    }
}
45
/// Kind of statement reported by `Parser::parse_statement`, alongside the table name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StatementType {
    /// Statement whose kind or table name could not be determined.
    Unknown,
    /// `CREATE TABLE` statement.
    CreateTable,
    /// `INSERT INTO` statement.
    Insert,
    /// `CREATE INDEX` statement.
    CreateIndex,
    /// `ALTER TABLE` statement.
    AlterTable,
    /// `DROP TABLE` statement.
    DropTable,
    /// PostgreSQL `COPY` statement (e.g. `COPY ... FROM stdin`).
    Copy,
}
57
58static CREATE_TABLE_RE: Lazy<Regex> =
59    Lazy::new(|| Regex::new(r"(?i)^\s*CREATE\s+TABLE\s+`?([^\s`(]+)`?").unwrap());
60
61static INSERT_INTO_RE: Lazy<Regex> =
62    Lazy::new(|| Regex::new(r"(?i)^\s*INSERT\s+INTO\s+`?([^\s`(]+)`?").unwrap());
63
64static CREATE_INDEX_RE: Lazy<Regex> =
65    Lazy::new(|| Regex::new(r"(?i)ON\s+`?([^\s`(;]+)`?").unwrap());
66
67static ALTER_TABLE_RE: Lazy<Regex> =
68    Lazy::new(|| Regex::new(r"(?i)ALTER\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
69
70static DROP_TABLE_RE: Lazy<Regex> =
71    Lazy::new(|| Regex::new(r"(?i)DROP\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
72
73// PostgreSQL COPY statement regex
74static COPY_RE: Lazy<Regex> = Lazy::new(|| {
75    Regex::new(r#"(?i)^\s*COPY\s+(?:ONLY\s+)?[`"]?([^\s`"(]+)[`"]?"#).unwrap()
76});
77
78// More flexible table name regex that handles:
79// - Backticks: `table`
80// - Double quotes: "table"
81// - Schema qualified: schema.table, `schema`.`table`, "schema"."table"
82// - IF NOT EXISTS
83static CREATE_TABLE_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
84    Regex::new(r#"(?i)^\s*CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
85});
86
87static INSERT_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
88    Regex::new(r#"(?i)^\s*INSERT\s+INTO\s+(?:ONLY\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
89});
90
91pub struct Parser<R: Read> {
92    reader: BufReader<R>,
93    stmt_buffer: Vec<u8>,
94    dialect: SqlDialect,
95    /// For PostgreSQL: true when reading COPY data block
96    in_copy_data: bool,
97}
98
99impl<R: Read> Parser<R> {
100    pub fn new(reader: R, buffer_size: usize) -> Self {
101        Self::with_dialect(reader, buffer_size, SqlDialect::default())
102    }
103
104    pub fn with_dialect(reader: R, buffer_size: usize, dialect: SqlDialect) -> Self {
105        Self {
106            reader: BufReader::with_capacity(buffer_size, reader),
107            stmt_buffer: Vec::with_capacity(32 * 1024),
108            dialect,
109            in_copy_data: false,
110        }
111    }
112
113    pub fn read_statement(&mut self) -> std::io::Result<Option<Vec<u8>>> {
114        // If we're in PostgreSQL COPY data mode, read until we see the terminator
115        if self.in_copy_data {
116            return self.read_copy_data();
117        }
118
119        self.stmt_buffer.clear();
120
121        let mut inside_single_quote = false;
122        let mut inside_double_quote = false;
123        let mut escaped = false;
124        let mut in_line_comment = false;
125        // For PostgreSQL dollar-quoting: track the tag
126        let mut in_dollar_quote = false;
127        let mut dollar_tag: Vec<u8> = Vec::new();
128
129        loop {
130            let buf = self.reader.fill_buf()?;
131            if buf.is_empty() {
132                if self.stmt_buffer.is_empty() {
133                    return Ok(None);
134                }
135                let result = std::mem::take(&mut self.stmt_buffer);
136                return Ok(Some(result));
137            }
138
139            let mut consumed = 0;
140            let mut found_terminator = false;
141
142            for (i, &b) in buf.iter().enumerate() {
143                let inside_string = inside_single_quote || inside_double_quote || in_dollar_quote;
144
145                // End of line comment on newline
146                if in_line_comment {
147                    if b == b'\n' {
148                        in_line_comment = false;
149                    }
150                    continue;
151                }
152
153                if escaped {
154                    escaped = false;
155                    continue;
156                }
157
158                // Handle backslash escapes (MySQL style)
159                if b == b'\\' && inside_string && self.dialect == SqlDialect::MySql {
160                    escaped = true;
161                    continue;
162                }
163
164                // Handle line comments (-- to end of line)
165                if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
166                    in_line_comment = true;
167                    continue;
168                }
169
170                // Handle dollar-quoting for PostgreSQL
171                if self.dialect == SqlDialect::Postgres && !inside_single_quote && !inside_double_quote {
172                    if b == b'$' && !in_dollar_quote {
173                        // Start of dollar-quote: scan for the closing $
174                        if let Some(end) = buf[i + 1..].iter().position(|&c| c == b'$') {
175                            dollar_tag = buf[i + 1..i + 1 + end].to_vec();
176                            in_dollar_quote = true;
177                            continue;
178                        }
179                    } else if b == b'$' && in_dollar_quote {
180                        // Potential end of dollar-quote
181                        let tag_len = dollar_tag.len();
182                        if i + 1 + tag_len < buf.len() 
183                            && buf[i + 1..i + 1 + tag_len] == dollar_tag[..]
184                            && buf.get(i + 1 + tag_len) == Some(&b'$')
185                        {
186                            in_dollar_quote = false;
187                            dollar_tag.clear();
188                            continue;
189                        }
190                    }
191                }
192
193                if b == b'\'' && !inside_double_quote && !in_dollar_quote {
194                    inside_single_quote = !inside_single_quote;
195                } else if b == b'"' && !inside_single_quote && !in_dollar_quote {
196                    inside_double_quote = !inside_double_quote;
197                } else if b == b';' && !inside_string {
198                    self.stmt_buffer.extend_from_slice(&buf[..=i]);
199                    consumed = i + 1;
200                    found_terminator = true;
201                    break;
202                }
203            }
204
205            if found_terminator {
206                self.reader.consume(consumed);
207                let result = std::mem::take(&mut self.stmt_buffer);
208                
209                // Check if this is a PostgreSQL COPY FROM stdin statement
210                if self.dialect == SqlDialect::Postgres && self.is_copy_from_stdin(&result) {
211                    self.in_copy_data = true;
212                }
213                
214                return Ok(Some(result));
215            }
216
217            self.stmt_buffer.extend_from_slice(buf);
218            let len = buf.len();
219            self.reader.consume(len);
220        }
221    }
222
223    /// Check if statement is a PostgreSQL COPY FROM stdin
224    fn is_copy_from_stdin(&self, stmt: &[u8]) -> bool {
225        // Strip leading comments (pg_dump adds -- comments before COPY statements)
226        let stmt = strip_leading_comments_and_whitespace(stmt);
227        if stmt.len() < 4 {
228            return false;
229        }
230        
231        // Take enough bytes to cover column lists - typical COPY statements are <500 bytes
232        let upper: Vec<u8> = stmt.iter().take(500).map(|b| b.to_ascii_uppercase()).collect();
233        upper.starts_with(b"COPY ") && 
234            (upper.windows(10).any(|w| w == b"FROM STDIN") ||
235             upper.windows(11).any(|w| w == b"FROM STDIN;"))
236    }
237
238    /// Read PostgreSQL COPY data block until we see the terminator line (\.)
239    fn read_copy_data(&mut self) -> std::io::Result<Option<Vec<u8>>> {
240        self.stmt_buffer.clear();
241        
242        loop {
243            // First, fill the buffer and check if empty
244            let buf = self.reader.fill_buf()?;
245            if buf.is_empty() {
246                self.in_copy_data = false;
247                if self.stmt_buffer.is_empty() {
248                    return Ok(None);
249                }
250                return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
251            }
252
253            // Look for a newline in the buffer
254            let newline_pos = buf.iter().position(|&b| b == b'\n');
255            
256            if let Some(i) = newline_pos {
257                // Include this newline
258                self.stmt_buffer.extend_from_slice(&buf[..=i]);
259                self.reader.consume(i + 1);
260                
261                // Check if the line we just added ends the COPY block
262                // Looking for a line that is just "\.\n" or "\.\r\n"
263                if self.ends_with_copy_terminator() {
264                    self.in_copy_data = false;
265                    return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
266                }
267                // Continue reading - we need to process more lines
268            } else {
269                // No newline found, consume the whole buffer and continue
270                let len = buf.len();
271                self.stmt_buffer.extend_from_slice(buf);
272                self.reader.consume(len);
273            }
274        }
275    }
276
277    /// Check if buffer ends with the COPY terminator line (\.)
278    fn ends_with_copy_terminator(&self) -> bool {
279        let data = &self.stmt_buffer;
280        if data.len() < 2 {
281            return false;
282        }
283        
284        // Look for a line that is just "\.\n" or "\.\r\n"
285        // We need to find the start of the last line
286        let last_newline = data[..data.len() - 1]
287            .iter()
288            .rposition(|&b| b == b'\n')
289            .map(|i| i + 1)
290            .unwrap_or(0);
291        
292        let last_line = &data[last_newline..];
293        
294        // Check if it's "\.\n" or "\.\r\n"
295        last_line == b"\\.\n" || last_line == b"\\.\r\n"
296    }
297
298    pub fn parse_statement(stmt: &[u8]) -> (StatementType, String) {
299        Self::parse_statement_with_dialect(stmt, SqlDialect::MySql)
300    }
301
302    /// Parse a statement with dialect-specific handling
303    pub fn parse_statement_with_dialect(stmt: &[u8], dialect: SqlDialect) -> (StatementType, String) {
304        // Strip leading comments (e.g., pg_dump adds -- comments before statements)
305        let stmt = strip_leading_comments_and_whitespace(stmt);
306
307        if stmt.len() < 4 {
308            return (StatementType::Unknown, String::new());
309        }
310
311        let upper_prefix: Vec<u8> = stmt
312            .iter()
313            .take(25)
314            .map(|b| b.to_ascii_uppercase())
315            .collect();
316
317        // PostgreSQL COPY statement
318        if upper_prefix.starts_with(b"COPY ") {
319            if let Some(caps) = COPY_RE.captures(stmt) {
320                if let Some(m) = caps.get(1) {
321                    let name = String::from_utf8_lossy(m.as_bytes()).into_owned();
322                    // Handle schema.table - extract just the table name
323                    let table_name = name.split('.').last().unwrap_or(&name).to_string();
324                    return (StatementType::Copy, table_name);
325                }
326            }
327        }
328
329        if upper_prefix.starts_with(b"CREATE TABLE") {
330            // Try fast extraction first
331            if let Some(name) = extract_table_name_flexible(stmt, 12, dialect) {
332                return (StatementType::CreateTable, name);
333            }
334            // Fall back to flexible regex
335            if let Some(caps) = CREATE_TABLE_FLEXIBLE_RE.captures(stmt) {
336                if let Some(m) = caps.get(1) {
337                    return (
338                        StatementType::CreateTable,
339                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
340                    );
341                }
342            }
343            // Original regex as last resort
344            if let Some(caps) = CREATE_TABLE_RE.captures(stmt) {
345                if let Some(m) = caps.get(1) {
346                    return (
347                        StatementType::CreateTable,
348                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
349                    );
350                }
351            }
352        }
353
354        if upper_prefix.starts_with(b"INSERT INTO") || upper_prefix.starts_with(b"INSERT ONLY") {
355            if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
356                return (StatementType::Insert, name);
357            }
358            if let Some(caps) = INSERT_FLEXIBLE_RE.captures(stmt) {
359                if let Some(m) = caps.get(1) {
360                    return (
361                        StatementType::Insert,
362                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
363                    );
364                }
365            }
366            if let Some(caps) = INSERT_INTO_RE.captures(stmt) {
367                if let Some(m) = caps.get(1) {
368                    return (
369                        StatementType::Insert,
370                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
371                    );
372                }
373            }
374        }
375
376        if upper_prefix.starts_with(b"CREATE INDEX") {
377            if let Some(caps) = CREATE_INDEX_RE.captures(stmt) {
378                if let Some(m) = caps.get(1) {
379                    return (
380                        StatementType::CreateIndex,
381                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
382                    );
383                }
384            }
385        }
386
387        if upper_prefix.starts_with(b"ALTER TABLE") {
388            if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
389                return (StatementType::AlterTable, name);
390            }
391            if let Some(caps) = ALTER_TABLE_RE.captures(stmt) {
392                if let Some(m) = caps.get(1) {
393                    return (
394                        StatementType::AlterTable,
395                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
396                    );
397                }
398            }
399        }
400
401        if upper_prefix.starts_with(b"DROP TABLE") {
402            if let Some(name) = extract_table_name_flexible(stmt, 10, dialect) {
403                return (StatementType::DropTable, name);
404            }
405            if let Some(caps) = DROP_TABLE_RE.captures(stmt) {
406                if let Some(m) = caps.get(1) {
407                    return (
408                        StatementType::DropTable,
409                        String::from_utf8_lossy(m.as_bytes()).into_owned(),
410                    );
411                }
412            }
413        }
414
415        (StatementType::Unknown, String::new())
416    }
417}
418
/// Return `data` without its leading spaces, tabs, carriage returns and newlines.
///
/// Other ASCII whitespace (form feed, vertical tab) is kept.
#[inline]
fn trim_ascii_start(mut data: &[u8]) -> &[u8] {
    while let [b' ' | b'\t' | b'\n' | b'\r', rest @ ..] = data {
        data = rest;
    }
    data
}

/// Strip leading whitespace and SQL line comments (`-- ...`) from a statement.
///
/// This keeps parsing robust to pg_dump-style comment blocks before statements. A comment
/// that runs to the end of `data` leaves nothing, so an empty slice is returned.
fn strip_leading_comments_and_whitespace(data: &[u8]) -> &[u8] {
    let mut rest = trim_ascii_start(data);
    while rest.starts_with(b"--") {
        rest = match rest.iter().position(|&b| b == b'\n') {
            Some(newline) => trim_ascii_start(&rest[newline + 1..]),
            None => return &[],
        };
    }
    rest
}
451
452/// Extract table name with support for:
453/// - IF NOT EXISTS
454/// - ONLY (PostgreSQL)
455/// - Schema-qualified names (schema.table)
456/// - Both backtick and double-quote quoting
457#[inline]
458fn extract_table_name_flexible(stmt: &[u8], offset: usize, dialect: SqlDialect) -> Option<String> {
459    let mut i = offset;
460
461    // Skip whitespace
462    while i < stmt.len() && is_whitespace(stmt[i]) {
463        i += 1;
464    }
465
466    if i >= stmt.len() {
467        return None;
468    }
469
470    // Check for IF NOT EXISTS
471    let upper_check: Vec<u8> = stmt[i..].iter().take(20).map(|b| b.to_ascii_uppercase()).collect();
472    if upper_check.starts_with(b"IF NOT EXISTS") {
473        i += 13; // Skip "IF NOT EXISTS"
474        while i < stmt.len() && is_whitespace(stmt[i]) {
475            i += 1;
476        }
477    }
478
479    // Check for ONLY (PostgreSQL)
480    let upper_check: Vec<u8> = stmt[i..].iter().take(10).map(|b| b.to_ascii_uppercase()).collect();
481    if upper_check.starts_with(b"ONLY ") || upper_check.starts_with(b"ONLY\t") {
482        i += 4;
483        while i < stmt.len() && is_whitespace(stmt[i]) {
484            i += 1;
485        }
486    }
487
488    if i >= stmt.len() {
489        return None;
490    }
491
492    // Read identifier (potentially schema-qualified)
493    let mut parts: Vec<String> = Vec::new();
494    
495    loop {
496        // Determine quote character
497        let quote_char = match stmt.get(i) {
498            Some(b'`') if dialect == SqlDialect::MySql => {
499                i += 1;
500                Some(b'`')
501            }
502            Some(b'"') if dialect != SqlDialect::MySql => {
503                i += 1;
504                Some(b'"')
505            }
506            Some(b'"') => {
507                // Allow double quotes for MySQL too (though less common)
508                i += 1;
509                Some(b'"')
510            }
511            _ => None,
512        };
513
514        let start = i;
515
516        while i < stmt.len() {
517            let b = stmt[i];
518            if let Some(q) = quote_char {
519                if b == q {
520                    let name = &stmt[start..i];
521                    parts.push(String::from_utf8_lossy(name).into_owned());
522                    i += 1; // Skip closing quote
523                    break;
524                }
525            } else if is_whitespace(b) || b == b'(' || b == b';' || b == b',' || b == b'.' {
526                if i > start {
527                    let name = &stmt[start..i];
528                    parts.push(String::from_utf8_lossy(name).into_owned());
529                }
530                break;
531            }
532            i += 1;
533        }
534
535        // If at end of quoted name without finding close quote, bail
536        if quote_char.is_some() && i <= start {
537            break;
538        }
539
540        // Check for schema separator (.)
541        while i < stmt.len() && is_whitespace(stmt[i]) {
542            i += 1;
543        }
544        
545        if i < stmt.len() && stmt[i] == b'.' {
546            i += 1; // Skip the dot
547            while i < stmt.len() && is_whitespace(stmt[i]) {
548                i += 1;
549            }
550            // Continue to read the next identifier (table name)
551        } else {
552            break;
553        }
554    }
555
556    // Return the last part (table name), not the schema
557    parts.pop()
558}
559
560#[inline]
561fn is_whitespace(b: u8) -> bool {
562    matches!(b, b' ' | b'\t' | b'\n' | b'\r')
563}
564
565pub fn determine_buffer_size(file_size: u64) -> usize {
566    if file_size > 1024 * 1024 * 1024 {
567        MEDIUM_BUFFER_SIZE
568    } else {
569        SMALL_BUFFER_SIZE
570    }
571}
572
#[cfg(test)]
mod tests {
    use super::*;

    /// Classify `stmt` with the default (MySQL) rules and check both kind and table name.
    fn assert_parsed(stmt: &[u8], expected_type: StatementType, expected_name: &str) {
        let (typ, name) = Parser::<&[u8]>::parse_statement(stmt);
        assert_eq!(typ, expected_type);
        assert_eq!(name, expected_name);
    }

    /// Split `sql` with a 1 KiB buffer, returning every statement until the parser reports
    /// end of input.
    fn split_all(sql: &[u8]) -> Vec<Vec<u8>> {
        let mut parser = Parser::new(sql, 1024);
        std::iter::from_fn(|| parser.read_statement().unwrap()).collect()
    }

    #[test]
    fn test_parse_create_table() {
        assert_parsed(b"CREATE TABLE users (id INT);", StatementType::CreateTable, "users");
    }

    #[test]
    fn test_parse_create_table_backticks() {
        assert_parsed(b"CREATE TABLE `my_table` (id INT);", StatementType::CreateTable, "my_table");
    }

    #[test]
    fn test_parse_insert() {
        assert_parsed(b"INSERT INTO posts VALUES (1, 'test');", StatementType::Insert, "posts");
    }

    #[test]
    fn test_parse_insert_backticks() {
        assert_parsed(b"INSERT INTO `comments` VALUES (1);", StatementType::Insert, "comments");
    }

    #[test]
    fn test_parse_alter_table() {
        assert_parsed(b"ALTER TABLE orders ADD COLUMN status INT;", StatementType::AlterTable, "orders");
    }

    #[test]
    fn test_parse_drop_table() {
        assert_parsed(b"DROP TABLE temp_data;", StatementType::DropTable, "temp_data");
    }

    #[test]
    fn test_read_statement_basic() {
        let statements = split_all(b"CREATE TABLE t1 (id INT); INSERT INTO t1 VALUES (1);");
        assert_eq!(
            statements,
            vec![
                b"CREATE TABLE t1 (id INT);".to_vec(),
                b" INSERT INTO t1 VALUES (1);".to_vec(),
            ]
        );
    }

    #[test]
    fn test_read_statement_with_strings() {
        let sql = b"INSERT INTO t1 VALUES ('hello; world');";
        assert_eq!(split_all(sql), vec![sql.to_vec()]);
    }

    #[test]
    fn test_read_statement_with_escaped_quotes() {
        let sql = b"INSERT INTO t1 VALUES ('it\\'s a test');";
        assert_eq!(split_all(sql), vec![sql.to_vec()]);
    }
}
658
#[cfg(test)]
mod copy_tests {
    use super::*;
    use std::io::Cursor;

    /// Build a PostgreSQL-dialect parser with a 1 KiB buffer over an in-memory dump.
    fn pg_parser(data: &[u8]) -> Parser<Cursor<&[u8]>> {
        Parser::with_dialect(Cursor::new(data), 1024, SqlDialect::Postgres)
    }

    /// Read the next chunk, which must exist, and decode it lossily for assertions.
    fn next_text(parser: &mut Parser<Cursor<&[u8]>>) -> String {
        let chunk = parser.read_statement().unwrap().unwrap();
        String::from_utf8_lossy(&chunk).into_owned()
    }

    #[test]
    fn test_copy_from_stdin_detection() {
        let mut parser = pg_parser(b"COPY public.table_001 (id, col_int, col_varchar, col_text, col_decimal, created_at) FROM stdin;\n1\t6892\tvalue_1\tLorem ipsum\n\\.\n");

        // The COPY header comes first...
        let header = next_text(&mut parser);
        assert!(header.starts_with("COPY"), "First statement should be COPY");
        assert!(header.contains("FROM stdin"), "Should contain FROM stdin");

        // ...then the raw data block through its terminator line.
        let data_block = next_text(&mut parser);
        assert!(data_block.contains("1\t6892"), "Data block should contain first row");
        assert!(data_block.ends_with("\\.\n"), "Data block should end with terminator");
    }

    #[test]
    fn test_copy_with_leading_comments() {
        // pg_dump adds -- comments before COPY statements
        let mut parser = pg_parser(b"--\n-- Data for Name: table_001\n--\n\nCOPY public.table_001 (id, name) FROM stdin;\n1\tfoo\n\\.\n");

        // The header (still carrying its leading comments) is recognised as a COPY.
        let header = next_text(&mut parser);
        let (stmt_type, table_name) =
            Parser::<&[u8]>::parse_statement_with_dialect(header.as_bytes(), SqlDialect::Postgres);
        assert_eq!(stmt_type, StatementType::Copy);
        assert_eq!(table_name, "table_001");

        let data_block = next_text(&mut parser);
        assert!(data_block.ends_with("\\.\n"), "Data block should end with terminator");
    }
}