1#[cfg(test)]
2mod edge_case_tests;
3
4use once_cell::sync::Lazy;
5use regex::bytes::Regex;
6use std::io::{BufRead, BufReader, Read};
7
/// Read-buffer size (64 KiB) used for inputs of up to 1 GiB; see `determine_buffer_size`.
pub const SMALL_BUFFER_SIZE: usize = 1 << 16;
/// Read-buffer size (256 KiB) used for inputs larger than 1 GiB; see `determine_buffer_size`.
pub const MEDIUM_BUFFER_SIZE: usize = 1 << 18;
10
/// SQL dialect of a dump; selects dialect-specific lexing and name-quoting rules.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum SqlDialect {
    /// MySQL and MariaDB. This is the default.
    #[default]
    MySql,
    /// PostgreSQL.
    Postgres,
    /// SQLite.
    Sqlite,
}

impl std::str::FromStr for SqlDialect {
    type Err = String;

    /// Parses a dialect name, ignoring case.
    ///
    /// Accepts `mysql`/`mariadb`, `postgres`/`postgresql`/`pg` and `sqlite`/`sqlite3`.
    ///
    /// # Errors
    /// Returns a message naming the valid options for any other input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let dialect = match lowered.as_str() {
            "mysql" | "mariadb" => Some(Self::MySql),
            "postgres" | "postgresql" | "pg" => Some(Self::Postgres),
            "sqlite" | "sqlite3" => Some(Self::Sqlite),
            _ => None,
        };
        dialect.ok_or_else(|| {
            format!(
                "Unknown dialect: {}. Valid options: mysql, postgres, sqlite",
                s
            )
        })
    }
}

impl std::fmt::Display for SqlDialect {
    /// Writes the canonical lowercase name: `mysql`, `postgres` or `sqlite`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::MySql => "mysql",
            Self::Postgres => "postgres",
            Self::Sqlite => "sqlite",
        };
        f.write_str(name)
    }
}
48
49#[derive(Debug, Clone)]
51pub struct DialectDetectionResult {
52 pub dialect: SqlDialect,
53 pub confidence: DialectConfidence,
54}
55
56#[derive(Debug, Clone, Copy, PartialEq, Eq)]
58pub enum DialectConfidence {
59 High,
61 Medium,
63 Low,
65}
66
67#[derive(Default)]
68struct DialectScore {
69 mysql: u32,
70 postgres: u32,
71 sqlite: u32,
72}
73
74pub fn detect_dialect(header: &[u8]) -> DialectDetectionResult {
77 let mut score = DialectScore::default();
78
79 if contains_bytes(header, b"pg_dump") {
81 score.postgres += 10;
82 }
83 if contains_bytes(header, b"PostgreSQL database dump") {
84 score.postgres += 10;
85 }
86 if contains_bytes(header, b"MySQL dump") {
87 score.mysql += 10;
88 }
89 if contains_bytes(header, b"MariaDB dump") {
90 score.mysql += 10;
91 }
92 if contains_bytes(header, b"SQLite") {
93 score.sqlite += 10;
94 }
95
96 if contains_bytes(header, b"COPY ") && contains_bytes(header, b"FROM stdin") {
98 score.postgres += 5;
99 }
100 if contains_bytes(header, b"search_path") {
101 score.postgres += 5;
102 }
103 if contains_bytes(header, b"/*!40") || contains_bytes(header, b"/*!50") {
104 score.mysql += 5;
105 }
106 if contains_bytes(header, b"LOCK TABLES") {
107 score.mysql += 5;
108 }
109 if contains_bytes(header, b"PRAGMA") {
110 score.sqlite += 5;
111 }
112
113 if contains_bytes(header, b"$$") {
115 score.postgres += 2;
116 }
117 if contains_bytes(header, b"CREATE EXTENSION") {
118 score.postgres += 2;
119 }
120 if contains_bytes(header, b"BEGIN TRANSACTION") {
122 score.sqlite += 2;
123 }
124 if header.contains(&b'`') {
126 score.mysql += 2;
127 }
128
129 let max_score = score.mysql.max(score.postgres).max(score.sqlite);
131
132 if max_score == 0 {
133 return DialectDetectionResult {
134 dialect: SqlDialect::MySql,
135 confidence: DialectConfidence::Low,
136 };
137 }
138
139 let (dialect, confidence) = if score.postgres > score.mysql && score.postgres > score.sqlite {
140 let conf = if score.postgres >= 10 {
141 DialectConfidence::High
142 } else if score.postgres >= 5 {
143 DialectConfidence::Medium
144 } else {
145 DialectConfidence::Low
146 };
147 (SqlDialect::Postgres, conf)
148 } else if score.sqlite > score.mysql {
149 let conf = if score.sqlite >= 10 {
150 DialectConfidence::High
151 } else if score.sqlite >= 5 {
152 DialectConfidence::Medium
153 } else {
154 DialectConfidence::Low
155 };
156 (SqlDialect::Sqlite, conf)
157 } else {
158 let conf = if score.mysql >= 10 {
159 DialectConfidence::High
160 } else if score.mysql >= 5 {
161 DialectConfidence::Medium
162 } else {
163 DialectConfidence::Low
164 };
165 (SqlDialect::MySql, conf)
166 };
167
168 DialectDetectionResult {
169 dialect,
170 confidence,
171 }
172}
173
174pub fn detect_dialect_from_file(path: &std::path::Path) -> std::io::Result<DialectDetectionResult> {
176 use std::fs::File;
177 use std::io::Read;
178
179 let mut file = File::open(path)?;
180 let mut buf = [0u8; 8192];
181 let n = file.read(&mut buf)?;
182 Ok(detect_dialect(&buf[..n]))
183}
184
185#[inline]
186fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
187 haystack
188 .windows(needle.len())
189 .any(|window| window == needle)
190}
191
/// Kind of SQL statement recognised by [`Parser::parse_statement_with_dialect`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// Not one of the recognised kinds below.
    Unknown,
    CreateTable,
    Insert,
    CreateIndex,
    AlterTable,
    DropTable,
    /// PostgreSQL `COPY` statement.
    Copy,
}

impl StatementType {
    /// `true` for statements that create, alter or drop tables and indexes.
    pub fn is_schema(&self) -> bool {
        // Exhaustive on purpose: a new variant must be classified explicitly.
        match self {
            Self::CreateTable | Self::CreateIndex | Self::AlterTable | Self::DropTable => true,
            Self::Unknown | Self::Insert | Self::Copy => false,
        }
    }

    /// `true` for statements that carry row data (`INSERT`, `COPY`).
    pub fn is_data(&self) -> bool {
        match self {
            Self::Insert | Self::Copy => true,
            Self::Unknown
            | Self::CreateTable
            | Self::CreateIndex
            | Self::AlterTable
            | Self::DropTable => false,
        }
    }
}
221
/// Which categories of statements to keep.
///
/// NOTE(review): presumably callers apply this through
/// [`StatementType::is_schema`] / [`StatementType::is_data`]; nothing here enforces it.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ContentFilter {
    /// Keep every statement. This is the default.
    #[default]
    All,
    /// Keep only schema statements.
    SchemaOnly,
    /// Keep only data statements.
    DataOnly,
}
233
234static CREATE_TABLE_RE: Lazy<Regex> =
235 Lazy::new(|| Regex::new(r"(?i)^\s*CREATE\s+TABLE\s+`?([^\s`(]+)`?").unwrap());
236
237static INSERT_INTO_RE: Lazy<Regex> =
238 Lazy::new(|| Regex::new(r"(?i)^\s*INSERT\s+INTO\s+`?([^\s`(]+)`?").unwrap());
239
240static CREATE_INDEX_RE: Lazy<Regex> =
241 Lazy::new(|| Regex::new(r"(?i)ON\s+`?([^\s`(;]+)`?").unwrap());
242
243static ALTER_TABLE_RE: Lazy<Regex> =
244 Lazy::new(|| Regex::new(r"(?i)ALTER\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
245
246static DROP_TABLE_RE: Lazy<Regex> = Lazy::new(|| {
247 Regex::new(r#"(?i)DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?[`"]?([^\s`"`;]+)[`"]?"#).unwrap()
248});
249
250static COPY_RE: Lazy<Regex> =
252 Lazy::new(|| Regex::new(r#"(?i)^\s*COPY\s+(?:ONLY\s+)?[`"]?([^\s`"(]+)[`"]?"#).unwrap());
253
254static CREATE_TABLE_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
260 Regex::new(r#"(?i)^\s*CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
261});
262
263static INSERT_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
264 Regex::new(
265 r#"(?i)^\s*INSERT\s+INTO\s+(?:ONLY\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#,
266 )
267 .unwrap()
268});
269
270pub struct Parser<R: Read> {
271 reader: BufReader<R>,
272 stmt_buffer: Vec<u8>,
273 dialect: SqlDialect,
274 in_copy_data: bool,
276}
277
278impl<R: Read> Parser<R> {
279 #[allow(dead_code)]
280 pub fn new(reader: R, buffer_size: usize) -> Self {
281 Self::with_dialect(reader, buffer_size, SqlDialect::default())
282 }
283
284 pub fn with_dialect(reader: R, buffer_size: usize, dialect: SqlDialect) -> Self {
285 Self {
286 reader: BufReader::with_capacity(buffer_size, reader),
287 stmt_buffer: Vec::with_capacity(32 * 1024),
288 dialect,
289 in_copy_data: false,
290 }
291 }
292
293 pub fn read_statement(&mut self) -> std::io::Result<Option<Vec<u8>>> {
294 if self.in_copy_data {
296 return self.read_copy_data();
297 }
298
299 self.stmt_buffer.clear();
300
301 let mut inside_single_quote = false;
302 let mut inside_double_quote = false;
303 let mut escaped = false;
304 let mut in_line_comment = false;
305 let mut in_dollar_quote = false;
307 let mut dollar_tag: Vec<u8> = Vec::new();
308
309 loop {
310 let buf = self.reader.fill_buf()?;
311 if buf.is_empty() {
312 if self.stmt_buffer.is_empty() {
313 return Ok(None);
314 }
315 let result = std::mem::take(&mut self.stmt_buffer);
316 return Ok(Some(result));
317 }
318
319 let mut consumed = 0;
320 let mut found_terminator = false;
321
322 for (i, &b) in buf.iter().enumerate() {
323 let inside_string = inside_single_quote || inside_double_quote || in_dollar_quote;
324
325 if in_line_comment {
327 if b == b'\n' {
328 in_line_comment = false;
329 }
330 continue;
331 }
332
333 if escaped {
334 escaped = false;
335 continue;
336 }
337
338 if b == b'\\' && inside_string && self.dialect == SqlDialect::MySql {
340 escaped = true;
341 continue;
342 }
343
344 if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
346 in_line_comment = true;
347 continue;
348 }
349
350 if self.dialect == SqlDialect::Postgres
352 && !inside_single_quote
353 && !inside_double_quote
354 {
355 if b == b'$' && !in_dollar_quote {
356 if let Some(end) = buf[i + 1..].iter().position(|&c| c == b'$') {
358 let tag_bytes = &buf[i + 1..i + 1 + end];
359
360 let is_valid_tag = if tag_bytes.is_empty() {
362 true
363 } else {
364 let mut iter = tag_bytes.iter();
365 match iter.next() {
366 Some(&first)
367 if first.is_ascii_alphabetic() || first == b'_' =>
368 {
369 iter.all(|&c| c.is_ascii_alphanumeric() || c == b'_')
370 }
371 _ => false,
372 }
373 };
374
375 if is_valid_tag {
376 dollar_tag = tag_bytes.to_vec();
377 in_dollar_quote = true;
378 continue;
379 }
380 }
382 } else if b == b'$' && in_dollar_quote {
383 let tag_len = dollar_tag.len();
385 if i + 1 + tag_len < buf.len()
386 && buf[i + 1..i + 1 + tag_len] == dollar_tag[..]
387 && buf.get(i + 1 + tag_len) == Some(&b'$')
388 {
389 in_dollar_quote = false;
390 dollar_tag.clear();
391 continue;
392 }
393 }
394 }
395
396 if b == b'\'' && !inside_double_quote && !in_dollar_quote {
397 inside_single_quote = !inside_single_quote;
398 } else if b == b'"' && !inside_single_quote && !in_dollar_quote {
399 inside_double_quote = !inside_double_quote;
400 } else if b == b';' && !inside_string {
401 self.stmt_buffer.extend_from_slice(&buf[..=i]);
402 consumed = i + 1;
403 found_terminator = true;
404 break;
405 }
406 }
407
408 if found_terminator {
409 self.reader.consume(consumed);
410 let result = std::mem::take(&mut self.stmt_buffer);
411
412 if self.dialect == SqlDialect::Postgres && self.is_copy_from_stdin(&result) {
414 self.in_copy_data = true;
415 }
416
417 return Ok(Some(result));
418 }
419
420 self.stmt_buffer.extend_from_slice(buf);
421 let len = buf.len();
422 self.reader.consume(len);
423 }
424 }
425
426 fn is_copy_from_stdin(&self, stmt: &[u8]) -> bool {
428 let stmt = strip_leading_comments_and_whitespace(stmt);
430 if stmt.len() < 4 {
431 return false;
432 }
433
434 let upper: Vec<u8> = stmt
436 .iter()
437 .take(500)
438 .map(|b| b.to_ascii_uppercase())
439 .collect();
440 upper.starts_with(b"COPY ")
441 && (upper.windows(10).any(|w| w == b"FROM STDIN")
442 || upper.windows(11).any(|w| w == b"FROM STDIN;"))
443 }
444
445 fn read_copy_data(&mut self) -> std::io::Result<Option<Vec<u8>>> {
447 self.stmt_buffer.clear();
448
449 loop {
450 let buf = self.reader.fill_buf()?;
452 if buf.is_empty() {
453 self.in_copy_data = false;
454 if self.stmt_buffer.is_empty() {
455 return Ok(None);
456 }
457 return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
458 }
459
460 let newline_pos = buf.iter().position(|&b| b == b'\n');
462
463 if let Some(i) = newline_pos {
464 self.stmt_buffer.extend_from_slice(&buf[..=i]);
466 self.reader.consume(i + 1);
467
468 if self.ends_with_copy_terminator() {
471 self.in_copy_data = false;
472 return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
473 }
474 } else {
476 let len = buf.len();
478 self.stmt_buffer.extend_from_slice(buf);
479 self.reader.consume(len);
480 }
481 }
482 }
483
484 fn ends_with_copy_terminator(&self) -> bool {
486 let data = &self.stmt_buffer;
487 if data.len() < 2 {
488 return false;
489 }
490
491 let last_newline = data[..data.len() - 1]
494 .iter()
495 .rposition(|&b| b == b'\n')
496 .map(|i| i + 1)
497 .unwrap_or(0);
498
499 let last_line = &data[last_newline..];
500
501 last_line == b"\\.\n" || last_line == b"\\.\r\n"
503 }
504
505 #[allow(dead_code)]
506 pub fn parse_statement(stmt: &[u8]) -> (StatementType, String) {
507 Self::parse_statement_with_dialect(stmt, SqlDialect::MySql)
508 }
509
510 pub fn parse_statement_with_dialect(
512 stmt: &[u8],
513 dialect: SqlDialect,
514 ) -> (StatementType, String) {
515 let stmt = strip_leading_comments_and_whitespace(stmt);
517
518 if stmt.len() < 4 {
519 return (StatementType::Unknown, String::new());
520 }
521
522 let upper_prefix: Vec<u8> = stmt
523 .iter()
524 .take(25)
525 .map(|b| b.to_ascii_uppercase())
526 .collect();
527
528 if upper_prefix.starts_with(b"COPY ") {
530 if let Some(caps) = COPY_RE.captures(stmt) {
531 if let Some(m) = caps.get(1) {
532 let name = String::from_utf8_lossy(m.as_bytes()).into_owned();
533 let table_name = name.split('.').next_back().unwrap_or(&name).to_string();
535 return (StatementType::Copy, table_name);
536 }
537 }
538 }
539
540 if upper_prefix.starts_with(b"CREATE TABLE") {
541 if let Some(name) = extract_table_name_flexible(stmt, 12, dialect) {
543 return (StatementType::CreateTable, name);
544 }
545 if let Some(caps) = CREATE_TABLE_FLEXIBLE_RE.captures(stmt) {
547 if let Some(m) = caps.get(1) {
548 return (
549 StatementType::CreateTable,
550 String::from_utf8_lossy(m.as_bytes()).into_owned(),
551 );
552 }
553 }
554 if let Some(caps) = CREATE_TABLE_RE.captures(stmt) {
556 if let Some(m) = caps.get(1) {
557 return (
558 StatementType::CreateTable,
559 String::from_utf8_lossy(m.as_bytes()).into_owned(),
560 );
561 }
562 }
563 }
564
565 if upper_prefix.starts_with(b"INSERT INTO") || upper_prefix.starts_with(b"INSERT ONLY") {
566 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
567 return (StatementType::Insert, name);
568 }
569 if let Some(caps) = INSERT_FLEXIBLE_RE.captures(stmt) {
570 if let Some(m) = caps.get(1) {
571 return (
572 StatementType::Insert,
573 String::from_utf8_lossy(m.as_bytes()).into_owned(),
574 );
575 }
576 }
577 if let Some(caps) = INSERT_INTO_RE.captures(stmt) {
578 if let Some(m) = caps.get(1) {
579 return (
580 StatementType::Insert,
581 String::from_utf8_lossy(m.as_bytes()).into_owned(),
582 );
583 }
584 }
585 }
586
587 if upper_prefix.starts_with(b"CREATE INDEX") {
588 if let Some(caps) = CREATE_INDEX_RE.captures(stmt) {
589 if let Some(m) = caps.get(1) {
590 return (
591 StatementType::CreateIndex,
592 String::from_utf8_lossy(m.as_bytes()).into_owned(),
593 );
594 }
595 }
596 }
597
598 if upper_prefix.starts_with(b"ALTER TABLE") {
599 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
600 return (StatementType::AlterTable, name);
601 }
602 if let Some(caps) = ALTER_TABLE_RE.captures(stmt) {
603 if let Some(m) = caps.get(1) {
604 return (
605 StatementType::AlterTable,
606 String::from_utf8_lossy(m.as_bytes()).into_owned(),
607 );
608 }
609 }
610 }
611
612 if upper_prefix.starts_with(b"DROP TABLE") {
613 if let Some(name) = extract_table_name_flexible(stmt, 10, dialect) {
614 return (StatementType::DropTable, name);
615 }
616 if let Some(caps) = DROP_TABLE_RE.captures(stmt) {
617 if let Some(m) = caps.get(1) {
618 return (
619 StatementType::DropTable,
620 String::from_utf8_lossy(m.as_bytes()).into_owned(),
621 );
622 }
623 }
624 }
625
626 (StatementType::Unknown, String::new())
627 }
628}
629
/// Drops leading spaces, tabs, newlines and carriage returns.
///
/// Form feeds and other bytes are kept.
#[inline]
fn trim_ascii_start(data: &[u8]) -> &[u8] {
    let start = data
        .iter()
        .position(|&b| !matches!(b, b' ' | b'\t' | b'\n' | b'\r'))
        .unwrap_or(data.len());
    &data[start..]
}

/// Skips leading whitespace, `--` line comments and `/* ... */` block comments.
///
/// MySQL/MariaDB executable comments (`/*!...*/`, `/*M!...*/`) are code, not
/// comments, so they are left in place. An unterminated leading comment
/// consumes the rest of the input, and an empty slice is returned.
fn strip_leading_comments_and_whitespace(mut data: &[u8]) -> &[u8] {
    loop {
        data = trim_ascii_start(data);

        if data.starts_with(b"--") {
            match data.iter().position(|&b| b == b'\n') {
                Some(newline) => data = &data[newline + 1..],
                None => return &[],
            }
        } else if data.starts_with(b"/*") && !data.starts_with(b"/*!") && !data.starts_with(b"/*M!") {
            // Search after the opener, so that `/*/` is not taken as closed.
            match data[2..].windows(2).position(|w| w == b"*/") {
                Some(close) => data = &data[2 + close + 2..],
                None => return &[],
            }
        } else {
            return data;
        }
    }
}
662
663#[inline]
669fn extract_table_name_flexible(stmt: &[u8], offset: usize, dialect: SqlDialect) -> Option<String> {
670 let mut i = offset;
671
672 while i < stmt.len() && is_whitespace(stmt[i]) {
674 i += 1;
675 }
676
677 if i >= stmt.len() {
678 return None;
679 }
680
681 let upper_check: Vec<u8> = stmt[i..]
683 .iter()
684 .take(20)
685 .map(|b| b.to_ascii_uppercase())
686 .collect();
687 if upper_check.starts_with(b"IF NOT EXISTS") {
688 i += 13; while i < stmt.len() && is_whitespace(stmt[i]) {
690 i += 1;
691 }
692 } else if upper_check.starts_with(b"IF EXISTS") {
693 i += 9; while i < stmt.len() && is_whitespace(stmt[i]) {
695 i += 1;
696 }
697 }
698
699 let upper_check: Vec<u8> = stmt[i..]
701 .iter()
702 .take(10)
703 .map(|b| b.to_ascii_uppercase())
704 .collect();
705 if upper_check.starts_with(b"ONLY ") || upper_check.starts_with(b"ONLY\t") {
706 i += 4;
707 while i < stmt.len() && is_whitespace(stmt[i]) {
708 i += 1;
709 }
710 }
711
712 if i >= stmt.len() {
713 return None;
714 }
715
716 let mut parts: Vec<String> = Vec::new();
718
719 loop {
720 let quote_char = match stmt.get(i) {
722 Some(b'`') if dialect == SqlDialect::MySql => {
723 i += 1;
724 Some(b'`')
725 }
726 Some(b'"') if dialect != SqlDialect::MySql => {
727 i += 1;
728 Some(b'"')
729 }
730 Some(b'"') => {
731 i += 1;
733 Some(b'"')
734 }
735 _ => None,
736 };
737
738 let start = i;
739
740 while i < stmt.len() {
741 let b = stmt[i];
742 if let Some(q) = quote_char {
743 if b == q {
744 let name = &stmt[start..i];
745 parts.push(String::from_utf8_lossy(name).into_owned());
746 i += 1; break;
748 }
749 } else if is_whitespace(b) || b == b'(' || b == b';' || b == b',' || b == b'.' {
750 if i > start {
751 let name = &stmt[start..i];
752 parts.push(String::from_utf8_lossy(name).into_owned());
753 }
754 break;
755 }
756 i += 1;
757 }
758
759 if quote_char.is_some() && i <= start {
761 break;
762 }
763
764 while i < stmt.len() && is_whitespace(stmt[i]) {
766 i += 1;
767 }
768
769 if i < stmt.len() && stmt[i] == b'.' {
770 i += 1; while i < stmt.len() && is_whitespace(stmt[i]) {
772 i += 1;
773 }
774 } else {
776 break;
777 }
778 }
779
780 parts.pop()
782}
783
784#[inline]
785fn is_whitespace(b: u8) -> bool {
786 matches!(b, b' ' | b'\t' | b'\n' | b'\r')
787}
788
789pub fn determine_buffer_size(file_size: u64) -> usize {
790 if file_size > 1024 * 1024 * 1024 {
791 MEDIUM_BUFFER_SIZE
792 } else {
793 SMALL_BUFFER_SIZE
794 }
795}
796
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `stmt` with the MySQL default and checks the statement type and table name.
    fn assert_parsed(stmt: &[u8], expected_type: StatementType, expected_table: &str) {
        let (typ, name) = Parser::<&[u8]>::parse_statement(stmt);
        assert_eq!(typ, expected_type);
        assert_eq!(name, expected_table);
    }

    /// Splits `sql` with a MySQL parser (1 KiB buffer) and returns every statement in order.
    fn split_all(sql: &[u8]) -> Vec<Vec<u8>> {
        let mut parser = Parser::new(sql, 1024);
        std::iter::from_fn(|| parser.read_statement().unwrap()).collect()
    }

    #[test]
    fn test_parse_create_table() {
        assert_parsed(b"CREATE TABLE users (id INT);", StatementType::CreateTable, "users");
    }

    #[test]
    fn test_parse_create_table_backticks() {
        assert_parsed(b"CREATE TABLE `my_table` (id INT);", StatementType::CreateTable, "my_table");
    }

    #[test]
    fn test_parse_insert() {
        assert_parsed(b"INSERT INTO posts VALUES (1, 'test');", StatementType::Insert, "posts");
    }

    #[test]
    fn test_parse_insert_backticks() {
        assert_parsed(b"INSERT INTO `comments` VALUES (1);", StatementType::Insert, "comments");
    }

    #[test]
    fn test_parse_alter_table() {
        assert_parsed(b"ALTER TABLE orders ADD COLUMN status INT;", StatementType::AlterTable, "orders");
    }

    #[test]
    fn test_parse_drop_table() {
        assert_parsed(b"DROP TABLE temp_data;", StatementType::DropTable, "temp_data");
    }

    #[test]
    fn test_read_statement_basic() {
        let statements = split_all(b"CREATE TABLE t1 (id INT); INSERT INTO t1 VALUES (1);");
        assert_eq!(
            statements,
            vec![
                b"CREATE TABLE t1 (id INT);".to_vec(),
                b" INSERT INTO t1 VALUES (1);".to_vec(),
            ]
        );
    }

    #[test]
    fn test_read_statement_with_strings() {
        let statements = split_all(b"INSERT INTO t1 VALUES ('hello; world');");
        assert_eq!(statements, vec![b"INSERT INTO t1 VALUES ('hello; world');".to_vec()]);
    }

    #[test]
    fn test_read_statement_with_escaped_quotes() {
        let statements = split_all(b"INSERT INTO t1 VALUES ('it\\'s a test');");
        assert_eq!(statements, vec![b"INSERT INTO t1 VALUES ('it\\'s a test');".to_vec()]);
    }
}
882
#[cfg(test)]
mod copy_tests {
    use super::*;
    use std::io::Cursor;

    /// PostgreSQL parser over an in-memory dump with a 1 KiB buffer.
    fn postgres_parser(data: &[u8]) -> Parser<Cursor<&[u8]>> {
        Parser::with_dialect(Cursor::new(data), 1024, SqlDialect::Postgres)
    }

    /// Reads the next chunk, which must exist, and decodes it lossily for assertions.
    fn next_text(parser: &mut Parser<Cursor<&[u8]>>) -> String {
        let chunk = parser.read_statement().unwrap().unwrap();
        String::from_utf8_lossy(&chunk).into_owned()
    }

    #[test]
    fn test_copy_from_stdin_detection() {
        let mut parser = postgres_parser(
            b"COPY public.table_001 (id, col_int, col_varchar, col_text, col_decimal, created_at) FROM stdin;\n\
              1\t6892\tvalue_1\tLorem ipsum\n\
              \\.\n",
        );

        let header = next_text(&mut parser);
        assert!(header.starts_with("COPY"), "First statement should be COPY");
        assert!(header.contains("FROM stdin"), "Should contain FROM stdin");

        let data = next_text(&mut parser);
        assert!(data.contains("1\t6892"), "Data block should contain first row");
        assert!(data.ends_with("\\.\n"), "Data block should end with terminator");
    }

    #[test]
    fn test_copy_with_leading_comments() {
        let mut parser = postgres_parser(
            b"--\n-- Data for Name: table_001\n--\n\n\
              COPY public.table_001 (id, name) FROM stdin;\n\
              1\tfoo\n\
              \\.\n",
        );

        let copy_stmt = parser.read_statement().unwrap().unwrap();
        let (stmt_type, table_name) =
            Parser::<&[u8]>::parse_statement_with_dialect(&copy_stmt, SqlDialect::Postgres);
        assert_eq!((stmt_type, table_name.as_str()), (StatementType::Copy, "table_001"));

        let data = next_text(&mut parser);
        assert!(data.ends_with("\\.\n"), "Data block should end with terminator");
    }
}
936
#[cfg(test)]
mod dialect_detection_tests {
    use super::*;

    /// Runs detection on `header`; checks the dialect and, when given, the confidence.
    fn expect_detection(header: &[u8], dialect: SqlDialect, confidence: Option<DialectConfidence>) {
        let result = detect_dialect(header);
        assert_eq!(result.dialect, dialect);
        if let Some(confidence) = confidence {
            assert_eq!(result.confidence, confidence);
        }
    }

    #[test]
    fn test_detect_mysql_dump_header() {
        expect_detection(
            b"-- MySQL dump 10.13 Distrib 8.0.32, for Linux (x86_64)\n\
              --\n\
              -- Host: localhost Database: mydb\n\
              -- ------------------------------------------------------\n\
              -- Server version 8.0.32\n\
              \n\
              /*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;\n",
            SqlDialect::MySql,
            Some(DialectConfidence::High),
        );
    }

    #[test]
    fn test_detect_mariadb_dump_header() {
        expect_detection(
            b"-- MariaDB dump 10.19 Distrib 10.11.2-MariaDB\n\
              --\n\
              -- Host: localhost Database: test\n",
            SqlDialect::MySql,
            Some(DialectConfidence::High),
        );
    }

    #[test]
    fn test_detect_postgres_pgdump_header() {
        expect_detection(
            b"--\n\
              -- PostgreSQL database dump\n\
              --\n\
              \n\
              -- Dumped from database version 15.2\n\
              -- Dumped by pg_dump version 15.2\n\
              \n\
              SET statement_timeout = 0;\n\
              SET search_path = public, pg_catalog;\n",
            SqlDialect::Postgres,
            Some(DialectConfidence::High),
        );
    }

    #[test]
    fn test_detect_postgres_copy_statement() {
        expect_detection(
            b"COPY public.users (id, name, email) FROM stdin;\n\
              1\tAlice\talice@example.com\n\
              2\tBob\tbob@example.com\n\
              \\.\n",
            SqlDialect::Postgres,
            Some(DialectConfidence::Medium),
        );
    }

    #[test]
    fn test_detect_postgres_dollar_quoting() {
        expect_detection(
            b"CREATE OR REPLACE FUNCTION test() RETURNS void AS $$\n\
              BEGIN\n    \
              RAISE NOTICE 'Hello';\n\
              END;\n\
              $$ LANGUAGE plpgsql;\n",
            SqlDialect::Postgres,
            None,
        );
    }

    #[test]
    fn test_detect_sqlite_dump_header() {
        expect_detection(
            b"-- SQLite database dump\n\
              PRAGMA foreign_keys=OFF;\n\
              BEGIN TRANSACTION;\n\
              CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT);\n\
              INSERT INTO users VALUES(1,'Alice');\n\
              COMMIT;\n",
            SqlDialect::Sqlite,
            Some(DialectConfidence::High),
        );
    }

    #[test]
    fn test_detect_sqlite_pragma_only() {
        expect_detection(
            b"PRAGMA foreign_keys=OFF;\n\
              CREATE TABLE test (id INT);\n",
            SqlDialect::Sqlite,
            Some(DialectConfidence::Medium),
        );
    }

    #[test]
    fn test_detect_mysql_backticks() {
        expect_detection(
            b"CREATE TABLE `users` (\n  \
              `id` int NOT NULL AUTO_INCREMENT,\n  \
              `name` varchar(255) DEFAULT NULL,\n  \
              PRIMARY KEY (`id`)\n\
              );\n",
            SqlDialect::MySql,
            None,
        );
    }

    #[test]
    fn test_detect_mysql_conditional_comments() {
        expect_detection(
            b"/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;\n\
              /*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;\n\
              /*!50503 SET NAMES utf8mb4 */;\n",
            SqlDialect::MySql,
            Some(DialectConfidence::Medium),
        );
    }

    #[test]
    fn test_detect_mysql_lock_tables() {
        expect_detection(
            b"LOCK TABLES `users` WRITE;\n\
              INSERT INTO `users` VALUES (1,'test');\n\
              UNLOCK TABLES;\n",
            SqlDialect::MySql,
            Some(DialectConfidence::Medium),
        );
    }

    #[test]
    fn test_detect_empty_defaults_to_mysql() {
        expect_detection(b"", SqlDialect::MySql, Some(DialectConfidence::Low));
    }

    #[test]
    fn test_detect_generic_sql_defaults_to_mysql() {
        expect_detection(
            b"CREATE TABLE users (id INT, name VARCHAR(100));\n\
              INSERT INTO users VALUES (1, 'Alice');\n",
            SqlDialect::MySql,
            Some(DialectConfidence::Low),
        );
    }

    #[test]
    fn test_detect_postgres_create_extension() {
        expect_detection(
            b"CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";\n\
              CREATE TABLE users (id uuid DEFAULT uuid_generate_v4());\n",
            SqlDialect::Postgres,
            None,
        );
    }

    #[test]
    fn test_detect_sqlite_comment() {
        expect_detection(
            b"-- SQLite database dump\n\
              -- Created by sqlite3\n\
              \n\
              CREATE TABLE test (id INTEGER);\n",
            SqlDialect::Sqlite,
            Some(DialectConfidence::High),
        );
    }

    #[test]
    fn test_scoring_postgres_beats_mysql_backticks() {
        expect_detection(
            b"--\n\
              -- PostgreSQL database dump\n\
              --\n\
              -- Dumped by pg_dump version 15.2\n\
              \n\
              INSERT INTO notes VALUES (1, 'Use `code` for inline code');\n",
            SqlDialect::Postgres,
            Some(DialectConfidence::High),
        );
    }

    #[test]
    fn test_begin_transaction_alone_is_low_confidence() {
        expect_detection(
            b"BEGIN TRANSACTION;\n\
              CREATE TABLE t (id INTEGER);\n\
              COMMIT;\n",
            SqlDialect::Sqlite,
            Some(DialectConfidence::Low),
        );
    }

    #[test]
    fn test_backticks_only_is_low_confidence() {
        expect_detection(
            b"CREATE TABLE `users` (id INT);\n\
              INSERT INTO `users` VALUES (1);\n",
            SqlDialect::MySql,
            Some(DialectConfidence::Low),
        );
    }

    #[test]
    fn test_conflicting_markers_postgres_wins() {
        expect_detection(
            b"-- PostgreSQL database dump\n\
              SET search_path = public;\n\
              INSERT INTO notes VALUES (1, 'Use `backticks` for code');\n",
            SqlDialect::Postgres,
            Some(DialectConfidence::High),
        );
    }
}