1#[cfg(test)]
2mod edge_case_tests;
3
4use once_cell::sync::Lazy;
5use regex::bytes::Regex;
6use std::io::{BufRead, BufReader, Read};
7
/// Read-buffer capacity of 64 KiB.
pub const SMALL_BUFFER_SIZE: usize = 64 << 10;
/// Read-buffer capacity of 256 KiB.
pub const MEDIUM_BUFFER_SIZE: usize = 256 << 10;
10
/// SQL flavour of the input being parsed.
///
/// The dialect decides which quoting and escaping rules the parser applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqlDialect {
    /// MySQL-style SQL; this is the default dialect.
    #[default]
    MySql,
    /// PostgreSQL-style SQL.
    Postgres,
    /// SQLite-style SQL.
    Sqlite,
}
22
23impl std::str::FromStr for SqlDialect {
24 type Err = String;
25
26 fn from_str(s: &str) -> Result<Self, Self::Err> {
27 match s.to_lowercase().as_str() {
28 "mysql" | "mariadb" => Ok(SqlDialect::MySql),
29 "postgres" | "postgresql" | "pg" => Ok(SqlDialect::Postgres),
30 "sqlite" | "sqlite3" => Ok(SqlDialect::Sqlite),
31 _ => Err(format!("Unknown dialect: {}. Valid options: mysql, postgres, sqlite", s)),
32 }
33 }
34}
35
36impl std::fmt::Display for SqlDialect {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 match self {
39 SqlDialect::MySql => write!(f, "mysql"),
40 SqlDialect::Postgres => write!(f, "postgres"),
41 SqlDialect::Sqlite => write!(f, "sqlite"),
42 }
43 }
44}
45
/// Kind of SQL statement, as classified from the statement's leading keywords.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// Anything that is not one of the recognised statement kinds.
    Unknown,
    /// `CREATE TABLE ...`
    CreateTable,
    /// `INSERT INTO ...`
    Insert,
    /// `CREATE INDEX ...`
    CreateIndex,
    /// `ALTER TABLE ...`
    AlterTable,
    /// `DROP TABLE ...`
    DropTable,
    /// `COPY ...`
    Copy,
}
57
58static CREATE_TABLE_RE: Lazy<Regex> =
59 Lazy::new(|| Regex::new(r"(?i)^\s*CREATE\s+TABLE\s+`?([^\s`(]+)`?").unwrap());
60
61static INSERT_INTO_RE: Lazy<Regex> =
62 Lazy::new(|| Regex::new(r"(?i)^\s*INSERT\s+INTO\s+`?([^\s`(]+)`?").unwrap());
63
64static CREATE_INDEX_RE: Lazy<Regex> =
65 Lazy::new(|| Regex::new(r"(?i)ON\s+`?([^\s`(;]+)`?").unwrap());
66
67static ALTER_TABLE_RE: Lazy<Regex> =
68 Lazy::new(|| Regex::new(r"(?i)ALTER\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
69
70static DROP_TABLE_RE: Lazy<Regex> =
71 Lazy::new(|| Regex::new(r"(?i)DROP\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
72
73static COPY_RE: Lazy<Regex> = Lazy::new(|| {
75 Regex::new(r#"(?i)^\s*COPY\s+(?:ONLY\s+)?[`"]?([^\s`"(]+)[`"]?"#).unwrap()
76});
77
78static CREATE_TABLE_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
84 Regex::new(r#"(?i)^\s*CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
85});
86
87static INSERT_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
88 Regex::new(r#"(?i)^\s*INSERT\s+INTO\s+(?:ONLY\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
89});
90
91pub struct Parser<R: Read> {
92 reader: BufReader<R>,
93 stmt_buffer: Vec<u8>,
94 dialect: SqlDialect,
95 in_copy_data: bool,
97}
98
99impl<R: Read> Parser<R> {
100 pub fn new(reader: R, buffer_size: usize) -> Self {
101 Self::with_dialect(reader, buffer_size, SqlDialect::default())
102 }
103
104 pub fn with_dialect(reader: R, buffer_size: usize, dialect: SqlDialect) -> Self {
105 Self {
106 reader: BufReader::with_capacity(buffer_size, reader),
107 stmt_buffer: Vec::with_capacity(32 * 1024),
108 dialect,
109 in_copy_data: false,
110 }
111 }
112
113 pub fn read_statement(&mut self) -> std::io::Result<Option<Vec<u8>>> {
114 if self.in_copy_data {
116 return self.read_copy_data();
117 }
118
119 self.stmt_buffer.clear();
120
121 let mut inside_single_quote = false;
122 let mut inside_double_quote = false;
123 let mut escaped = false;
124 let mut in_line_comment = false;
125 let mut in_dollar_quote = false;
127 let mut dollar_tag: Vec<u8> = Vec::new();
128
129 loop {
130 let buf = self.reader.fill_buf()?;
131 if buf.is_empty() {
132 if self.stmt_buffer.is_empty() {
133 return Ok(None);
134 }
135 let result = std::mem::take(&mut self.stmt_buffer);
136 return Ok(Some(result));
137 }
138
139 let mut consumed = 0;
140 let mut found_terminator = false;
141
142 for (i, &b) in buf.iter().enumerate() {
143 let inside_string = inside_single_quote || inside_double_quote || in_dollar_quote;
144
145 if in_line_comment {
147 if b == b'\n' {
148 in_line_comment = false;
149 }
150 continue;
151 }
152
153 if escaped {
154 escaped = false;
155 continue;
156 }
157
158 if b == b'\\' && inside_string && self.dialect == SqlDialect::MySql {
160 escaped = true;
161 continue;
162 }
163
164 if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
166 in_line_comment = true;
167 continue;
168 }
169
170 if self.dialect == SqlDialect::Postgres && !inside_single_quote && !inside_double_quote {
172 if b == b'$' && !in_dollar_quote {
173 if let Some(end) = buf[i + 1..].iter().position(|&c| c == b'$') {
175 dollar_tag = buf[i + 1..i + 1 + end].to_vec();
176 in_dollar_quote = true;
177 continue;
178 }
179 } else if b == b'$' && in_dollar_quote {
180 let tag_len = dollar_tag.len();
182 if i + 1 + tag_len < buf.len()
183 && buf[i + 1..i + 1 + tag_len] == dollar_tag[..]
184 && buf.get(i + 1 + tag_len) == Some(&b'$')
185 {
186 in_dollar_quote = false;
187 dollar_tag.clear();
188 continue;
189 }
190 }
191 }
192
193 if b == b'\'' && !inside_double_quote && !in_dollar_quote {
194 inside_single_quote = !inside_single_quote;
195 } else if b == b'"' && !inside_single_quote && !in_dollar_quote {
196 inside_double_quote = !inside_double_quote;
197 } else if b == b';' && !inside_string {
198 self.stmt_buffer.extend_from_slice(&buf[..=i]);
199 consumed = i + 1;
200 found_terminator = true;
201 break;
202 }
203 }
204
205 if found_terminator {
206 self.reader.consume(consumed);
207 let result = std::mem::take(&mut self.stmt_buffer);
208
209 if self.dialect == SqlDialect::Postgres && self.is_copy_from_stdin(&result) {
211 self.in_copy_data = true;
212 }
213
214 return Ok(Some(result));
215 }
216
217 self.stmt_buffer.extend_from_slice(buf);
218 let len = buf.len();
219 self.reader.consume(len);
220 }
221 }
222
223 fn is_copy_from_stdin(&self, stmt: &[u8]) -> bool {
225 let stmt = strip_leading_comments_and_whitespace(stmt);
227 if stmt.len() < 4 {
228 return false;
229 }
230
231 let upper: Vec<u8> = stmt.iter().take(500).map(|b| b.to_ascii_uppercase()).collect();
233 upper.starts_with(b"COPY ") &&
234 (upper.windows(10).any(|w| w == b"FROM STDIN") ||
235 upper.windows(11).any(|w| w == b"FROM STDIN;"))
236 }
237
238 fn read_copy_data(&mut self) -> std::io::Result<Option<Vec<u8>>> {
240 self.stmt_buffer.clear();
241
242 loop {
243 let buf = self.reader.fill_buf()?;
245 if buf.is_empty() {
246 self.in_copy_data = false;
247 if self.stmt_buffer.is_empty() {
248 return Ok(None);
249 }
250 return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
251 }
252
253 let newline_pos = buf.iter().position(|&b| b == b'\n');
255
256 if let Some(i) = newline_pos {
257 self.stmt_buffer.extend_from_slice(&buf[..=i]);
259 self.reader.consume(i + 1);
260
261 if self.ends_with_copy_terminator() {
264 self.in_copy_data = false;
265 return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
266 }
267 } else {
269 let len = buf.len();
271 self.stmt_buffer.extend_from_slice(buf);
272 self.reader.consume(len);
273 }
274 }
275 }
276
277 fn ends_with_copy_terminator(&self) -> bool {
279 let data = &self.stmt_buffer;
280 if data.len() < 2 {
281 return false;
282 }
283
284 let last_newline = data[..data.len() - 1]
287 .iter()
288 .rposition(|&b| b == b'\n')
289 .map(|i| i + 1)
290 .unwrap_or(0);
291
292 let last_line = &data[last_newline..];
293
294 last_line == b"\\.\n" || last_line == b"\\.\r\n"
296 }
297
298 pub fn parse_statement(stmt: &[u8]) -> (StatementType, String) {
299 Self::parse_statement_with_dialect(stmt, SqlDialect::MySql)
300 }
301
302 pub fn parse_statement_with_dialect(stmt: &[u8], dialect: SqlDialect) -> (StatementType, String) {
304 let stmt = strip_leading_comments_and_whitespace(stmt);
306
307 if stmt.len() < 4 {
308 return (StatementType::Unknown, String::new());
309 }
310
311 let upper_prefix: Vec<u8> = stmt
312 .iter()
313 .take(25)
314 .map(|b| b.to_ascii_uppercase())
315 .collect();
316
317 if upper_prefix.starts_with(b"COPY ") {
319 if let Some(caps) = COPY_RE.captures(stmt) {
320 if let Some(m) = caps.get(1) {
321 let name = String::from_utf8_lossy(m.as_bytes()).into_owned();
322 let table_name = name.split('.').last().unwrap_or(&name).to_string();
324 return (StatementType::Copy, table_name);
325 }
326 }
327 }
328
329 if upper_prefix.starts_with(b"CREATE TABLE") {
330 if let Some(name) = extract_table_name_flexible(stmt, 12, dialect) {
332 return (StatementType::CreateTable, name);
333 }
334 if let Some(caps) = CREATE_TABLE_FLEXIBLE_RE.captures(stmt) {
336 if let Some(m) = caps.get(1) {
337 return (
338 StatementType::CreateTable,
339 String::from_utf8_lossy(m.as_bytes()).into_owned(),
340 );
341 }
342 }
343 if let Some(caps) = CREATE_TABLE_RE.captures(stmt) {
345 if let Some(m) = caps.get(1) {
346 return (
347 StatementType::CreateTable,
348 String::from_utf8_lossy(m.as_bytes()).into_owned(),
349 );
350 }
351 }
352 }
353
354 if upper_prefix.starts_with(b"INSERT INTO") || upper_prefix.starts_with(b"INSERT ONLY") {
355 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
356 return (StatementType::Insert, name);
357 }
358 if let Some(caps) = INSERT_FLEXIBLE_RE.captures(stmt) {
359 if let Some(m) = caps.get(1) {
360 return (
361 StatementType::Insert,
362 String::from_utf8_lossy(m.as_bytes()).into_owned(),
363 );
364 }
365 }
366 if let Some(caps) = INSERT_INTO_RE.captures(stmt) {
367 if let Some(m) = caps.get(1) {
368 return (
369 StatementType::Insert,
370 String::from_utf8_lossy(m.as_bytes()).into_owned(),
371 );
372 }
373 }
374 }
375
376 if upper_prefix.starts_with(b"CREATE INDEX") {
377 if let Some(caps) = CREATE_INDEX_RE.captures(stmt) {
378 if let Some(m) = caps.get(1) {
379 return (
380 StatementType::CreateIndex,
381 String::from_utf8_lossy(m.as_bytes()).into_owned(),
382 );
383 }
384 }
385 }
386
387 if upper_prefix.starts_with(b"ALTER TABLE") {
388 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
389 return (StatementType::AlterTable, name);
390 }
391 if let Some(caps) = ALTER_TABLE_RE.captures(stmt) {
392 if let Some(m) = caps.get(1) {
393 return (
394 StatementType::AlterTable,
395 String::from_utf8_lossy(m.as_bytes()).into_owned(),
396 );
397 }
398 }
399 }
400
401 if upper_prefix.starts_with(b"DROP TABLE") {
402 if let Some(name) = extract_table_name_flexible(stmt, 10, dialect) {
403 return (StatementType::DropTable, name);
404 }
405 if let Some(caps) = DROP_TABLE_RE.captures(stmt) {
406 if let Some(m) = caps.get(1) {
407 return (
408 StatementType::DropTable,
409 String::from_utf8_lossy(m.as_bytes()).into_owned(),
410 );
411 }
412 }
413 }
414
415 (StatementType::Unknown, String::new())
416 }
417}
418
/// Returns `data` without its leading spaces, tabs, line feeds and carriage
/// returns. Other control bytes are left in place.
#[inline]
fn trim_ascii_start(data: &[u8]) -> &[u8] {
    let mut rest = data;
    while let [b' ' | b'\t' | b'\n' | b'\r', tail @ ..] = rest {
        rest = tail;
    }
    rest
}
427
428fn strip_leading_comments_and_whitespace(mut data: &[u8]) -> &[u8] {
431 loop {
432 data = trim_ascii_start(data);
434
435 if data.len() >= 2 && data[0] == b'-' && data[1] == b'-' {
436 if let Some(pos) = data.iter().position(|&b| b == b'\n') {
438 data = &data[pos + 1..];
439 continue;
440 } else {
441 return &[];
443 }
444 }
445
446 break;
447 }
448
449 data
450}
451
452#[inline]
458fn extract_table_name_flexible(stmt: &[u8], offset: usize, dialect: SqlDialect) -> Option<String> {
459 let mut i = offset;
460
461 while i < stmt.len() && is_whitespace(stmt[i]) {
463 i += 1;
464 }
465
466 if i >= stmt.len() {
467 return None;
468 }
469
470 let upper_check: Vec<u8> = stmt[i..].iter().take(20).map(|b| b.to_ascii_uppercase()).collect();
472 if upper_check.starts_with(b"IF NOT EXISTS") {
473 i += 13; while i < stmt.len() && is_whitespace(stmt[i]) {
475 i += 1;
476 }
477 }
478
479 let upper_check: Vec<u8> = stmt[i..].iter().take(10).map(|b| b.to_ascii_uppercase()).collect();
481 if upper_check.starts_with(b"ONLY ") || upper_check.starts_with(b"ONLY\t") {
482 i += 4;
483 while i < stmt.len() && is_whitespace(stmt[i]) {
484 i += 1;
485 }
486 }
487
488 if i >= stmt.len() {
489 return None;
490 }
491
492 let mut parts: Vec<String> = Vec::new();
494
495 loop {
496 let quote_char = match stmt.get(i) {
498 Some(b'`') if dialect == SqlDialect::MySql => {
499 i += 1;
500 Some(b'`')
501 }
502 Some(b'"') if dialect != SqlDialect::MySql => {
503 i += 1;
504 Some(b'"')
505 }
506 Some(b'"') => {
507 i += 1;
509 Some(b'"')
510 }
511 _ => None,
512 };
513
514 let start = i;
515
516 while i < stmt.len() {
517 let b = stmt[i];
518 if let Some(q) = quote_char {
519 if b == q {
520 let name = &stmt[start..i];
521 parts.push(String::from_utf8_lossy(name).into_owned());
522 i += 1; break;
524 }
525 } else if is_whitespace(b) || b == b'(' || b == b';' || b == b',' || b == b'.' {
526 if i > start {
527 let name = &stmt[start..i];
528 parts.push(String::from_utf8_lossy(name).into_owned());
529 }
530 break;
531 }
532 i += 1;
533 }
534
535 if quote_char.is_some() && i <= start {
537 break;
538 }
539
540 while i < stmt.len() && is_whitespace(stmt[i]) {
542 i += 1;
543 }
544
545 if i < stmt.len() && stmt[i] == b'.' {
546 i += 1; while i < stmt.len() && is_whitespace(stmt[i]) {
548 i += 1;
549 }
550 } else {
552 break;
553 }
554 }
555
556 parts.pop()
558}
559
/// Returns `true` for space, tab, line feed and carriage return.
#[inline]
fn is_whitespace(b: u8) -> bool {
    b == b' ' || b == b'\t' || b == b'\n' || b == b'\r'
}
564
565pub fn determine_buffer_size(file_size: u64) -> usize {
566 if file_size > 1024 * 1024 * 1024 {
567 MEDIUM_BUFFER_SIZE
568 } else {
569 SMALL_BUFFER_SIZE
570 }
571}
572
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies `sql` with the default-dialect entry point.
    fn classify(sql: &[u8]) -> (StatementType, String) {
        Parser::<&[u8]>::parse_statement(sql)
    }

    /// Builds a default-dialect parser over an in-memory script.
    fn parser_over(sql: &[u8]) -> Parser<&[u8]> {
        Parser::new(sql, 1024)
    }

    #[test]
    fn test_parse_create_table() {
        let expected = (StatementType::CreateTable, "users".to_string());
        assert_eq!(classify(b"CREATE TABLE users (id INT);"), expected);
    }

    #[test]
    fn test_parse_create_table_backticks() {
        let expected = (StatementType::CreateTable, "my_table".to_string());
        assert_eq!(classify(b"CREATE TABLE `my_table` (id INT);"), expected);
    }

    #[test]
    fn test_parse_insert() {
        let expected = (StatementType::Insert, "posts".to_string());
        assert_eq!(classify(b"INSERT INTO posts VALUES (1, 'test');"), expected);
    }

    #[test]
    fn test_parse_insert_backticks() {
        let expected = (StatementType::Insert, "comments".to_string());
        assert_eq!(classify(b"INSERT INTO `comments` VALUES (1);"), expected);
    }

    #[test]
    fn test_parse_alter_table() {
        let expected = (StatementType::AlterTable, "orders".to_string());
        assert_eq!(classify(b"ALTER TABLE orders ADD COLUMN status INT;"), expected);
    }

    #[test]
    fn test_parse_drop_table() {
        let expected = (StatementType::DropTable, "temp_data".to_string());
        assert_eq!(classify(b"DROP TABLE temp_data;"), expected);
    }

    #[test]
    fn test_read_statement_basic() {
        let mut parser = parser_over(b"CREATE TABLE t1 (id INT); INSERT INTO t1 VALUES (1);");

        let first = parser.read_statement().unwrap().unwrap();
        assert_eq!(first, b"CREATE TABLE t1 (id INT);");

        // The separating space stays attached to the second statement.
        let second = parser.read_statement().unwrap().unwrap();
        assert_eq!(second, b" INSERT INTO t1 VALUES (1);");

        assert!(parser.read_statement().unwrap().is_none());
    }

    #[test]
    fn test_read_statement_with_strings() {
        // A semicolon inside a string literal must not end the statement.
        let mut parser = parser_over(b"INSERT INTO t1 VALUES ('hello; world');");

        let stmt = parser.read_statement().unwrap().unwrap();
        assert_eq!(stmt, b"INSERT INTO t1 VALUES ('hello; world');");
    }

    #[test]
    fn test_read_statement_with_escaped_quotes() {
        // A backslash-escaped quote must not close the string literal.
        let mut parser = parser_over(b"INSERT INTO t1 VALUES ('it\\'s a test');");

        let stmt = parser.read_statement().unwrap().unwrap();
        assert_eq!(stmt, b"INSERT INTO t1 VALUES ('it\\'s a test');");
    }
}
658
#[cfg(test)]
mod copy_tests {
    use super::*;
    use std::io::Cursor;

    /// Builds a PostgreSQL-dialect parser over an in-memory dump.
    fn postgres_parser(dump: &[u8]) -> Parser<Cursor<&[u8]>> {
        Parser::with_dialect(Cursor::new(dump), 1024, SqlDialect::Postgres)
    }

    #[test]
    fn test_copy_from_stdin_detection() {
        let mut parser = postgres_parser(
            b"COPY public.table_001 (id, col_int, col_varchar, col_text, col_decimal, created_at) FROM stdin;\n1\t6892\tvalue_1\tLorem ipsum\n\\.\n",
        );

        // The COPY header comes back as an ordinary statement.
        let header = parser.read_statement().unwrap().unwrap();
        let header = String::from_utf8_lossy(&header);
        assert!(header.starts_with("COPY"), "First statement should be COPY");
        assert!(header.contains("FROM stdin"), "Should contain FROM stdin");

        // The rows and the terminator line come back together as one block.
        let block = parser.read_statement().unwrap().unwrap();
        let block = String::from_utf8_lossy(&block);
        assert!(block.contains("1\t6892"), "Data block should contain first row");
        assert!(block.ends_with("\\.\n"), "Data block should end with terminator");
    }

    #[test]
    fn test_copy_with_leading_comments() {
        let mut parser = postgres_parser(
            b"--\n-- Data for Name: table_001\n--\n\nCOPY public.table_001 (id, name) FROM stdin;\n1\tfoo\n\\.\n",
        );

        // Leading comments must not hide the COPY statement from classification.
        let header = parser.read_statement().unwrap().unwrap();
        let (stmt_type, table_name) =
            Parser::<&[u8]>::parse_statement_with_dialect(&header, SqlDialect::Postgres);
        assert_eq!(stmt_type, StatementType::Copy);
        assert_eq!(table_name, "table_001");

        let block = parser.read_statement().unwrap().unwrap();
        let block = String::from_utf8_lossy(&block);
        assert!(block.ends_with("\\.\n"), "Data block should end with terminator");
    }
}