1pub mod mysql_insert;
2pub mod postgres_copy;
3
4use once_cell::sync::Lazy;
5use regex::bytes::Regex;
6use std::io::{BufRead, BufReader, Read};
7
/// Read-buffer capacity (64 KiB) that `determine_buffer_size` picks for inputs of at most 1 GiB.
pub const SMALL_BUFFER_SIZE: usize = 64 << 10;
/// Read-buffer capacity (256 KiB) that `determine_buffer_size` picks for inputs above 1 GiB.
pub const MEDIUM_BUFFER_SIZE: usize = 256 << 10;
10
/// SQL dialect of a dump; defaults to [`SqlDialect::MySql`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqlDialect {
    /// MySQL / MariaDB (the default).
    #[default]
    MySql,
    /// PostgreSQL.
    Postgres,
    /// SQLite.
    Sqlite,
}

impl std::str::FromStr for SqlDialect {
    type Err = String;

    /// Parses a dialect name, ignoring case.
    ///
    /// Accepted spellings: `mysql`/`mariadb`, `postgres`/`postgresql`/`pg`,
    /// `sqlite`/`sqlite3`.
    ///
    /// # Errors
    /// Any other input yields a message listing the valid options.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let dialect = match normalized.as_str() {
            "mysql" | "mariadb" => SqlDialect::MySql,
            "postgres" | "postgresql" | "pg" => SqlDialect::Postgres,
            "sqlite" | "sqlite3" => SqlDialect::Sqlite,
            _ => {
                return Err(format!(
                    "Unknown dialect: {}. Valid options: mysql, postgres, sqlite",
                    s
                ))
            }
        };
        Ok(dialect)
    }
}

impl std::fmt::Display for SqlDialect {
    /// Writes the canonical lowercase name, which `from_str` accepts back.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SqlDialect::MySql => "mysql",
            SqlDialect::Postgres => "postgres",
            SqlDialect::Sqlite => "sqlite",
        };
        f.write_str(name)
    }
}
48
/// Result of [`detect_dialect`]: the chosen dialect and how strongly the
/// header evidence supported it.
#[derive(Debug, Clone)]
pub struct DialectDetectionResult {
    /// Dialect picked by the scoring heuristic; MySQL when no marker matched.
    pub dialect: SqlDialect,
    /// Confidence bucket derived from the winning dialect's score.
    pub confidence: DialectConfidence,
}

/// Strength of the evidence behind a dialect guess.
///
/// `detect_dialect` maps the winning score to `High` (>= 10), `Medium` (>= 5)
/// or `Low` (anything lower, including no matches at all).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DialectConfidence {
    /// Winning score of at least 10 (a single dump-header marker is worth 10).
    High,
    /// Winning score from 5 to 9.
    Medium,
    /// Winning score below 5, or no markers found.
    Low,
}

/// Per-dialect point totals accumulated by `detect_dialect`.
#[derive(Default)]
struct DialectScore {
    mysql: u32,
    postgres: u32,
    sqlite: u32,
}
73
74pub fn detect_dialect(header: &[u8]) -> DialectDetectionResult {
77 let mut score = DialectScore::default();
78
79 if contains_bytes(header, b"pg_dump") {
81 score.postgres += 10;
82 }
83 if contains_bytes(header, b"PostgreSQL database dump") {
84 score.postgres += 10;
85 }
86 if contains_bytes(header, b"MySQL dump") {
87 score.mysql += 10;
88 }
89 if contains_bytes(header, b"MariaDB dump") {
90 score.mysql += 10;
91 }
92 if contains_bytes(header, b"SQLite") {
93 score.sqlite += 10;
94 }
95
96 if contains_bytes(header, b"COPY ") && contains_bytes(header, b"FROM stdin") {
98 score.postgres += 5;
99 }
100 if contains_bytes(header, b"search_path") {
101 score.postgres += 5;
102 }
103 if contains_bytes(header, b"/*!40") || contains_bytes(header, b"/*!50") {
104 score.mysql += 5;
105 }
106 if contains_bytes(header, b"LOCK TABLES") {
107 score.mysql += 5;
108 }
109 if contains_bytes(header, b"PRAGMA") {
110 score.sqlite += 5;
111 }
112
113 if contains_bytes(header, b"$$") {
115 score.postgres += 2;
116 }
117 if contains_bytes(header, b"CREATE EXTENSION") {
118 score.postgres += 2;
119 }
120 if contains_bytes(header, b"BEGIN TRANSACTION") {
122 score.sqlite += 2;
123 }
124 if header.contains(&b'`') {
126 score.mysql += 2;
127 }
128
129 let max_score = score.mysql.max(score.postgres).max(score.sqlite);
131
132 if max_score == 0 {
133 return DialectDetectionResult {
134 dialect: SqlDialect::MySql,
135 confidence: DialectConfidence::Low,
136 };
137 }
138
139 let (dialect, confidence) = if score.postgres > score.mysql && score.postgres > score.sqlite {
140 let conf = if score.postgres >= 10 {
141 DialectConfidence::High
142 } else if score.postgres >= 5 {
143 DialectConfidence::Medium
144 } else {
145 DialectConfidence::Low
146 };
147 (SqlDialect::Postgres, conf)
148 } else if score.sqlite > score.mysql {
149 let conf = if score.sqlite >= 10 {
150 DialectConfidence::High
151 } else if score.sqlite >= 5 {
152 DialectConfidence::Medium
153 } else {
154 DialectConfidence::Low
155 };
156 (SqlDialect::Sqlite, conf)
157 } else {
158 let conf = if score.mysql >= 10 {
159 DialectConfidence::High
160 } else if score.mysql >= 5 {
161 DialectConfidence::Medium
162 } else {
163 DialectConfidence::Low
164 };
165 (SqlDialect::MySql, conf)
166 };
167
168 DialectDetectionResult {
169 dialect,
170 confidence,
171 }
172}
173
174pub fn detect_dialect_from_file(path: &std::path::Path) -> std::io::Result<DialectDetectionResult> {
176 use std::fs::File;
177 use std::io::Read;
178
179 let mut file = File::open(path)?;
180 let mut buf = [0u8; 8192];
181 let n = file.read(&mut buf)?;
182 Ok(detect_dialect(&buf[..n]))
183}
184
/// Returns `true` if `needle` occurs as a contiguous subslice of `haystack`.
///
/// An empty `needle` matches every haystack, as `str::contains("")` does.
/// This case is handled up front because `slice::windows(0)` panics.
#[inline]
fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    if needle.is_empty() {
        return true;
    }
    haystack
        .windows(needle.len())
        .any(|window| window == needle)
}
191
/// Coarse classification of a single SQL statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// Not recognised as any of the kinds below.
    Unknown,
    /// `CREATE TABLE ...`
    CreateTable,
    /// `INSERT INTO ...`
    Insert,
    /// `CREATE INDEX ...`
    CreateIndex,
    /// `ALTER TABLE ...`
    AlterTable,
    /// `DROP TABLE ...`
    DropTable,
    /// PostgreSQL `COPY ...`
    Copy,
}

impl StatementType {
    /// `true` for statements that define or change table structure:
    /// `CreateTable`, `CreateIndex`, `AlterTable` and `DropTable`.
    pub fn is_schema(&self) -> bool {
        match self {
            StatementType::CreateTable
            | StatementType::CreateIndex
            | StatementType::AlterTable
            | StatementType::DropTable => true,
            StatementType::Unknown | StatementType::Insert | StatementType::Copy => false,
        }
    }

    /// `true` for statements that carry row data: `Insert` and `Copy`.
    pub fn is_data(&self) -> bool {
        match self {
            StatementType::Insert | StatementType::Copy => true,
            StatementType::Unknown
            | StatementType::CreateTable
            | StatementType::CreateIndex
            | StatementType::AlterTable
            | StatementType::DropTable => false,
        }
    }
}
221
/// Selects which categories of statements an operation should keep.
///
/// NOTE(review): presumably applied by callers outside this file via
/// `StatementType::is_schema` / `StatementType::is_data` - confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ContentFilter {
    /// No filtering (the default variant).
    #[default]
    All,
    /// Schema statements only.
    SchemaOnly,
    /// Data statements only.
    DataOnly,
}
233
/// Last-resort pattern for `CREATE TABLE <name>` at statement start.
/// Group 1 captures the token after `TABLE` (an optional backtick is skipped)
/// up to whitespace, a backtick or `(`.
/// NOTE(review): `IF NOT EXISTS` is not skipped, so it would capture `IF`.
static CREATE_TABLE_RE: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"(?i)^\s*CREATE\s+TABLE\s+`?([^\s`(]+)`?").unwrap());

/// Last-resort pattern for `INSERT INTO <name>` at statement start; group 1
/// uses the same capture rules as `CREATE_TABLE_RE`.
static INSERT_INTO_RE: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"(?i)^\s*INSERT\s+INTO\s+`?([^\s`(]+)`?").unwrap());
239
240static CREATE_INDEX_RE: Lazy<Regex> =
241 Lazy::new(|| Regex::new(r"(?i)ON\s+`?([^\s`(;]+)`?").unwrap());
242
/// Fallback for `ALTER TABLE <name>`; not anchored to the statement start.
/// Group 1 captures the token after `TABLE` (an optional backtick is skipped)
/// up to whitespace, a backtick or `;`.
/// NOTE(review): `ALTER TABLE ONLY t` would capture `ONLY`.
static ALTER_TABLE_RE: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"(?i)ALTER\s+TABLE\s+`?([^\s`;]+)`?").unwrap());

/// Fallback for `DROP TABLE [IF EXISTS] <name>`; the name may be wrapped in a
/// backtick or double quote. The backtick is listed twice in the negated
/// class, which is harmless.
static DROP_TABLE_RE: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r#"(?i)DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?[`"]?([^\s`"`;]+)[`"]?"#).unwrap()
});

/// Fallback for `COPY [ONLY] <name>` at statement start. Group 1 may still
/// contain a `schema.` prefix, and a quote inside the name (e.g. `public."T"`)
/// ends the capture early.
static COPY_RE: Lazy<Regex> =
    Lazy::new(|| Regex::new(r#"(?i)^\s*COPY\s+(?:ONLY\s+)?[`"]?([^\s`"(]+)[`"]?"#).unwrap());

/// `CREATE TABLE [IF NOT EXISTS] [schema.]name` with optional backtick or
/// double quotes. Group 1 is the table name without the schema, and only
/// word characters are captured.
static CREATE_TABLE_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r#"(?i)^\s*CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
});

/// `INSERT INTO [ONLY] [schema.]name` with optional backtick or double quotes.
/// Group 1 is the table name without the schema, and only word characters
/// are captured.
static INSERT_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
    Regex::new(
        r#"(?i)^\s*INSERT\s+INTO\s+(?:ONLY\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#,
    )
    .unwrap()
});
269
/// Streaming splitter that yields one SQL statement at a time from a reader.
///
/// A statement ends at a `;` outside quotes and `--` comments. In PostgreSQL
/// mode, the data block after a `COPY ... FROM stdin` statement is yielded by
/// the following read as a separate item.
pub struct Parser<R: Read> {
    /// Buffered source of dump bytes.
    reader: BufReader<R>,
    /// Bytes of the statement (or COPY data block) currently being assembled.
    stmt_buffer: Vec<u8>,
    /// Controls backslash escapes (MySQL), dollar quotes and COPY handling (PostgreSQL).
    dialect: SqlDialect,
    /// Set after a `COPY ... FROM stdin` statement; the next read returns its data block.
    in_copy_data: bool,
}
277
278impl<R: Read> Parser<R> {
279 #[allow(dead_code)]
280 pub fn new(reader: R, buffer_size: usize) -> Self {
281 Self::with_dialect(reader, buffer_size, SqlDialect::default())
282 }
283
284 pub fn with_dialect(reader: R, buffer_size: usize, dialect: SqlDialect) -> Self {
285 Self {
286 reader: BufReader::with_capacity(buffer_size, reader),
287 stmt_buffer: Vec::with_capacity(32 * 1024),
288 dialect,
289 in_copy_data: false,
290 }
291 }
292
    /// Reads the next statement, including its terminating `;`.
    ///
    /// Returns `Ok(None)` once the input is exhausted and nothing is pending.
    /// At EOF, any pending bytes without a `;` are returned as a final statement.
    /// Bytes are returned verbatim, so whitespace and comments before or inside
    /// the statement are kept.
    ///
    /// A `;` ends the statement only when it is outside single quotes, double
    /// quotes and `--` line comments, and (PostgreSQL only) outside `$tag$`
    /// dollar quotes. Backslash escapes inside quoted text are honoured only
    /// for MySQL. In PostgreSQL mode, a statement recognised as
    /// `COPY ... FROM stdin` sets COPY mode, and the next call returns the data
    /// block via `read_copy_data`.
    ///
    /// NOTE(review): lookahead for `--` and for dollar-quote tags inspects only
    /// the current `fill_buf` chunk. A comment opener or tag that straddles a
    /// chunk boundary is therefore not recognised. `/* ... */` and `#` comments
    /// are not tracked by this scanner.
    ///
    /// # Errors
    /// Propagates I/O errors from the underlying reader.
    pub fn read_statement(&mut self) -> std::io::Result<Option<Vec<u8>>> {
        if self.in_copy_data {
            // Previous statement was `COPY ... FROM stdin`: return its data block.
            return self.read_copy_data();
        }

        self.stmt_buffer.clear();

        // Lexer state lives outside the refill loop because one statement may
        // span several `fill_buf` chunks.
        let mut inside_single_quote = false;
        let mut inside_double_quote = false;
        let mut escaped = false;
        let mut in_line_comment = false;
        let mut in_dollar_quote = false;
        // Tag between the `$` delimiters of the open dollar quote (empty for `$$`).
        let mut dollar_tag: Vec<u8> = Vec::new();

        loop {
            let buf = self.reader.fill_buf()?;
            if buf.is_empty() {
                // EOF: hand back whatever is left as an unterminated final statement.
                if self.stmt_buffer.is_empty() {
                    return Ok(None);
                }
                let result = std::mem::take(&mut self.stmt_buffer);
                return Ok(Some(result));
            }

            let mut consumed = 0;
            let mut found_terminator = false;

            for (i, &b) in buf.iter().enumerate() {
                let inside_string = inside_single_quote || inside_double_quote || in_dollar_quote;

                if in_line_comment {
                    // Ignore everything up to and including the newline.
                    if b == b'\n' {
                        in_line_comment = false;
                    }
                    continue;
                }

                if escaped {
                    // The byte after a backslash is taken literally.
                    escaped = false;
                    continue;
                }

                if b == b'\\' && inside_string && self.dialect == SqlDialect::MySql {
                    escaped = true;
                    continue;
                }

                // `--` opens a line comment; the second `-` must be in this chunk.
                if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
                    in_line_comment = true;
                    continue;
                }

                if self.dialect == SqlDialect::Postgres
                    && !inside_single_quote
                    && !inside_double_quote
                {
                    if b == b'$' && !in_dollar_quote {
                        // Possible opening `$tag$`: the next `$` must be in this chunk.
                        if let Some(end) = buf[i + 1..].iter().position(|&c| c == b'$') {
                            let tag_bytes = &buf[i + 1..i + 1 + end];

                            // Tag must be empty or look like [A-Za-z_][A-Za-z0-9_]*.
                            let is_valid_tag = if tag_bytes.is_empty() {
                                true
                            } else {
                                let mut iter = tag_bytes.iter();
                                match iter.next() {
                                    Some(&first)
                                        if first.is_ascii_alphabetic() || first == b'_' =>
                                    {
                                        iter.all(|&c| c.is_ascii_alphanumeric() || c == b'_')
                                    }
                                    _ => false,
                                }
                            };

                            if is_valid_tag {
                                dollar_tag = tag_bytes.to_vec();
                                in_dollar_quote = true;
                                continue;
                            }
                        }
                    } else if b == b'$' && in_dollar_quote {
                        // Possible closing `$tag$`: same tag followed by `$`, within this chunk.
                        let tag_len = dollar_tag.len();
                        if i + 1 + tag_len < buf.len()
                            && buf[i + 1..i + 1 + tag_len] == dollar_tag[..]
                            && buf.get(i + 1 + tag_len) == Some(&b'$')
                        {
                            in_dollar_quote = false;
                            dollar_tag.clear();
                            continue;
                        }
                    }
                }

                if b == b'\'' && !inside_double_quote && !in_dollar_quote {
                    // A doubled quote toggles off then on, so it needs no special case.
                    inside_single_quote = !inside_single_quote;
                } else if b == b'"' && !inside_single_quote && !in_dollar_quote {
                    inside_double_quote = !inside_double_quote;
                } else if b == b';' && !inside_string {
                    self.stmt_buffer.extend_from_slice(&buf[..=i]);
                    consumed = i + 1;
                    found_terminator = true;
                    break;
                }
            }

            if found_terminator {
                self.reader.consume(consumed);
                let result = std::mem::take(&mut self.stmt_buffer);

                if self.dialect == SqlDialect::Postgres && self.is_copy_from_stdin(&result) {
                    // Inline data rows follow; the next call reads them.
                    self.in_copy_data = true;
                }

                return Ok(Some(result));
            }

            // No terminator in this chunk: keep all of it and refill.
            self.stmt_buffer.extend_from_slice(buf);
            let len = buf.len();
            self.reader.consume(len);
        }
    }
425
426 fn is_copy_from_stdin(&self, stmt: &[u8]) -> bool {
428 let stmt = strip_leading_comments_and_whitespace(stmt);
430 if stmt.len() < 4 {
431 return false;
432 }
433
434 let upper: Vec<u8> = stmt
436 .iter()
437 .take(500)
438 .map(|b| b.to_ascii_uppercase())
439 .collect();
440 upper.starts_with(b"COPY ")
441 && (upper.windows(10).any(|w| w == b"FROM STDIN")
442 || upper.windows(11).any(|w| w == b"FROM STDIN;"))
443 }
444
445 fn read_copy_data(&mut self) -> std::io::Result<Option<Vec<u8>>> {
447 self.stmt_buffer.clear();
448
449 loop {
450 let buf = self.reader.fill_buf()?;
452 if buf.is_empty() {
453 self.in_copy_data = false;
454 if self.stmt_buffer.is_empty() {
455 return Ok(None);
456 }
457 return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
458 }
459
460 let newline_pos = buf.iter().position(|&b| b == b'\n');
462
463 if let Some(i) = newline_pos {
464 self.stmt_buffer.extend_from_slice(&buf[..=i]);
466 self.reader.consume(i + 1);
467
468 if self.ends_with_copy_terminator() {
471 self.in_copy_data = false;
472 return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
473 }
474 } else {
476 let len = buf.len();
478 self.stmt_buffer.extend_from_slice(buf);
479 self.reader.consume(len);
480 }
481 }
482 }
483
484 fn ends_with_copy_terminator(&self) -> bool {
486 let data = &self.stmt_buffer;
487 if data.len() < 2 {
488 return false;
489 }
490
491 let last_newline = data[..data.len() - 1]
494 .iter()
495 .rposition(|&b| b == b'\n')
496 .map(|i| i + 1)
497 .unwrap_or(0);
498
499 let last_line = &data[last_newline..];
500
501 last_line == b"\\.\n" || last_line == b"\\.\r\n"
503 }
504
505 #[allow(dead_code)]
506 pub fn parse_statement(stmt: &[u8]) -> (StatementType, String) {
507 Self::parse_statement_with_dialect(stmt, SqlDialect::MySql)
508 }
509
510 pub fn parse_statement_with_dialect(
512 stmt: &[u8],
513 dialect: SqlDialect,
514 ) -> (StatementType, String) {
515 let stmt = strip_leading_comments_and_whitespace(stmt);
517
518 if stmt.len() < 4 {
519 return (StatementType::Unknown, String::new());
520 }
521
522 let upper_prefix: Vec<u8> = stmt
523 .iter()
524 .take(25)
525 .map(|b| b.to_ascii_uppercase())
526 .collect();
527
528 if upper_prefix.starts_with(b"COPY ") {
530 if let Some(caps) = COPY_RE.captures(stmt) {
531 if let Some(m) = caps.get(1) {
532 let name = String::from_utf8_lossy(m.as_bytes()).into_owned();
533 let table_name = name.split('.').next_back().unwrap_or(&name).to_string();
535 return (StatementType::Copy, table_name);
536 }
537 }
538 }
539
540 if upper_prefix.starts_with(b"CREATE TABLE") {
541 if let Some(name) = extract_table_name_flexible(stmt, 12, dialect) {
543 return (StatementType::CreateTable, name);
544 }
545 if let Some(caps) = CREATE_TABLE_FLEXIBLE_RE.captures(stmt) {
547 if let Some(m) = caps.get(1) {
548 return (
549 StatementType::CreateTable,
550 String::from_utf8_lossy(m.as_bytes()).into_owned(),
551 );
552 }
553 }
554 if let Some(caps) = CREATE_TABLE_RE.captures(stmt) {
556 if let Some(m) = caps.get(1) {
557 return (
558 StatementType::CreateTable,
559 String::from_utf8_lossy(m.as_bytes()).into_owned(),
560 );
561 }
562 }
563 }
564
565 if upper_prefix.starts_with(b"INSERT INTO") || upper_prefix.starts_with(b"INSERT ONLY") {
566 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
567 return (StatementType::Insert, name);
568 }
569 if let Some(caps) = INSERT_FLEXIBLE_RE.captures(stmt) {
570 if let Some(m) = caps.get(1) {
571 return (
572 StatementType::Insert,
573 String::from_utf8_lossy(m.as_bytes()).into_owned(),
574 );
575 }
576 }
577 if let Some(caps) = INSERT_INTO_RE.captures(stmt) {
578 if let Some(m) = caps.get(1) {
579 return (
580 StatementType::Insert,
581 String::from_utf8_lossy(m.as_bytes()).into_owned(),
582 );
583 }
584 }
585 }
586
587 if upper_prefix.starts_with(b"CREATE INDEX") {
588 if let Some(caps) = CREATE_INDEX_RE.captures(stmt) {
589 if let Some(m) = caps.get(1) {
590 return (
591 StatementType::CreateIndex,
592 String::from_utf8_lossy(m.as_bytes()).into_owned(),
593 );
594 }
595 }
596 }
597
598 if upper_prefix.starts_with(b"ALTER TABLE") {
599 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
600 return (StatementType::AlterTable, name);
601 }
602 if let Some(caps) = ALTER_TABLE_RE.captures(stmt) {
603 if let Some(m) = caps.get(1) {
604 return (
605 StatementType::AlterTable,
606 String::from_utf8_lossy(m.as_bytes()).into_owned(),
607 );
608 }
609 }
610 }
611
612 if upper_prefix.starts_with(b"DROP TABLE") {
613 if let Some(name) = extract_table_name_flexible(stmt, 10, dialect) {
614 return (StatementType::DropTable, name);
615 }
616 if let Some(caps) = DROP_TABLE_RE.captures(stmt) {
617 if let Some(m) = caps.get(1) {
618 return (
619 StatementType::DropTable,
620 String::from_utf8_lossy(m.as_bytes()).into_owned(),
621 );
622 }
623 }
624 }
625
626 (StatementType::Unknown, String::new())
627 }
628}
629
/// Returns `data` with its leading ASCII whitespace (space, tab, LF, CR) removed.
#[inline]
fn trim_ascii_start(data: &[u8]) -> &[u8] {
    let mut rest = data;
    while let [first, tail @ ..] = rest {
        if !matches!(*first, b' ' | b'\t' | b'\n' | b'\r') {
            break;
        }
        rest = tail;
    }
    rest
}

/// Skips any run of leading whitespace and comments and returns the remainder.
///
/// Handles `--` and `#` line comments (through the newline) and `/* ... */`
/// block comments, which may nest. If a leading comment is never closed, the
/// rest of the input counts as comment and an empty slice is returned.
fn strip_leading_comments_and_whitespace(mut data: &[u8]) -> &[u8] {
    loop {
        data = trim_ascii_start(data);

        // Offset just past the leading comment; `None` when it is unterminated.
        let resume_at = if data.starts_with(b"--") || data.starts_with(b"#") {
            data.iter().position(|&b| b == b'\n').map(|nl| nl + 1)
        } else if data.starts_with(b"/*") {
            let mut depth = 1usize;
            let mut pos = 2;
            while depth > 0 && pos + 1 < data.len() {
                match (data[pos], data[pos + 1]) {
                    (b'*', b'/') => {
                        depth -= 1;
                        pos += 2;
                    }
                    (b'/', b'*') => {
                        depth += 1;
                        pos += 2;
                    }
                    _ => pos += 1,
                }
            }
            (depth == 0).then_some(pos)
        } else {
            return data;
        };

        match resume_at {
            Some(offset) => data = &data[offset..],
            None => return &[],
        }
    }
}
698
699#[inline]
705fn extract_table_name_flexible(stmt: &[u8], offset: usize, dialect: SqlDialect) -> Option<String> {
706 let mut i = offset;
707
708 while i < stmt.len() && is_whitespace(stmt[i]) {
710 i += 1;
711 }
712
713 if i >= stmt.len() {
714 return None;
715 }
716
717 let upper_check: Vec<u8> = stmt[i..]
719 .iter()
720 .take(20)
721 .map(|b| b.to_ascii_uppercase())
722 .collect();
723 if upper_check.starts_with(b"IF NOT EXISTS") {
724 i += 13; while i < stmt.len() && is_whitespace(stmt[i]) {
726 i += 1;
727 }
728 } else if upper_check.starts_with(b"IF EXISTS") {
729 i += 9; while i < stmt.len() && is_whitespace(stmt[i]) {
731 i += 1;
732 }
733 }
734
735 let upper_check: Vec<u8> = stmt[i..]
737 .iter()
738 .take(10)
739 .map(|b| b.to_ascii_uppercase())
740 .collect();
741 if upper_check.starts_with(b"ONLY ") || upper_check.starts_with(b"ONLY\t") {
742 i += 4;
743 while i < stmt.len() && is_whitespace(stmt[i]) {
744 i += 1;
745 }
746 }
747
748 if i >= stmt.len() {
749 return None;
750 }
751
752 let mut parts: Vec<String> = Vec::new();
754
755 loop {
756 let quote_char = match stmt.get(i) {
758 Some(b'`') if dialect == SqlDialect::MySql => {
759 i += 1;
760 Some(b'`')
761 }
762 Some(b'"') if dialect != SqlDialect::MySql => {
763 i += 1;
764 Some(b'"')
765 }
766 Some(b'"') => {
767 i += 1;
769 Some(b'"')
770 }
771 _ => None,
772 };
773
774 let start = i;
775
776 while i < stmt.len() {
777 let b = stmt[i];
778 if let Some(q) = quote_char {
779 if b == q {
780 let name = &stmt[start..i];
781 parts.push(String::from_utf8_lossy(name).into_owned());
782 i += 1; break;
784 }
785 } else if is_whitespace(b) || b == b'(' || b == b';' || b == b',' || b == b'.' {
786 if i > start {
787 let name = &stmt[start..i];
788 parts.push(String::from_utf8_lossy(name).into_owned());
789 }
790 break;
791 }
792 i += 1;
793 }
794
795 if quote_char.is_some() && i <= start {
797 break;
798 }
799
800 while i < stmt.len() && is_whitespace(stmt[i]) {
802 i += 1;
803 }
804
805 if i < stmt.len() && stmt[i] == b'.' {
806 i += 1; while i < stmt.len() && is_whitespace(stmt[i]) {
808 i += 1;
809 }
810 } else {
812 break;
813 }
814 }
815
816 parts.pop()
818}
819
820#[inline]
821fn is_whitespace(b: u8) -> bool {
822 matches!(b, b' ' | b'\t' | b'\n' | b'\r')
823}
824
825pub fn determine_buffer_size(file_size: u64) -> usize {
826 if file_size > 1024 * 1024 * 1024 {
827 MEDIUM_BUFFER_SIZE
828 } else {
829 SMALL_BUFFER_SIZE
830 }
831}