1pub mod mysql_insert;
2pub mod postgres_copy;
3
4pub use mysql_insert::{parse_insert_for_bulk, ParsedValue};
6
7use once_cell::sync::Lazy;
8use regex::bytes::Regex;
9use std::io::{BufRead, BufReader, Read};
10
/// Read-buffer size (64 KiB) chosen by `determine_buffer_size` for dumps of at most 1 GiB.
pub const SMALL_BUFFER_SIZE: usize = 64 << 10;
/// Read-buffer size (256 KiB) chosen by `determine_buffer_size` for dumps above 1 GiB.
pub const MEDIUM_BUFFER_SIZE: usize = 256 << 10;
13
/// SQL dialect whose quoting and statement-splitting rules the parser applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqlDialect {
    /// MySQL / MariaDB; the default dialect.
    #[default]
    MySql,
    /// PostgreSQL.
    Postgres,
    /// SQLite.
    Sqlite,
    /// Microsoft SQL Server (T-SQL).
    Mssql,
}

impl std::str::FromStr for SqlDialect {
    type Err = String;

    /// Parses a dialect name case-insensitively, accepting common aliases
    /// (`mariadb`, `pg`, `postgresql`, `sqlite3`, `sqlserver`, `sql_server`, `tsql`).
    ///
    /// # Errors
    /// Returns a message listing the valid options when `s` is not recognised.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let dialect = match normalized.as_str() {
            "mysql" | "mariadb" => SqlDialect::MySql,
            "postgres" | "postgresql" | "pg" => SqlDialect::Postgres,
            "sqlite" | "sqlite3" => SqlDialect::Sqlite,
            "mssql" | "sqlserver" | "sql_server" | "tsql" => SqlDialect::Mssql,
            _ => {
                return Err(format!(
                    "Unknown dialect: {}. Valid options: mysql, postgres, sqlite, mssql",
                    s
                ))
            }
        };
        Ok(dialect)
    }
}

impl std::fmt::Display for SqlDialect {
    /// Writes the canonical lowercase name, which `FromStr` parses back.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SqlDialect::MySql => "mysql",
            SqlDialect::Postgres => "postgres",
            SqlDialect::Sqlite => "sqlite",
            SqlDialect::Mssql => "mssql",
        };
        f.write_str(name)
    }
}
55
56#[derive(Debug, Clone)]
58pub struct DialectDetectionResult {
59 pub dialect: SqlDialect,
60 pub confidence: DialectConfidence,
61}
62
/// How much evidence backed a detected dialect.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DialectConfidence {
    /// `detect_dialect` assigns this for a winning score of 10 or more.
    High,
    /// `detect_dialect` assigns this for a winning score of 5 to 9.
    Medium,
    /// `detect_dialect` assigns this for a winning score below 5, including no markers at all.
    Low,
}
73
/// Per-dialect marker weights accumulated while scanning a dump header.
/// Every tally starts at zero.
#[derive(Default)]
struct DialectScore {
    mysql: u32,
    postgres: u32,
    sqlite: u32,
    mssql: u32,
}
81
82pub fn detect_dialect(header: &[u8]) -> DialectDetectionResult {
85 let mut score = DialectScore::default();
86
87 if contains_bytes(header, b"pg_dump") {
89 score.postgres += 10;
90 }
91 if contains_bytes(header, b"PostgreSQL database dump") {
92 score.postgres += 10;
93 }
94 if contains_bytes(header, b"MySQL dump") {
95 score.mysql += 10;
96 }
97 if contains_bytes(header, b"MariaDB dump") {
98 score.mysql += 10;
99 }
100 if contains_bytes(header, b"SQLite") {
101 score.sqlite += 10;
102 }
103
104 if contains_bytes(header, b"COPY ") && contains_bytes(header, b"FROM stdin") {
106 score.postgres += 5;
107 }
108 if contains_bytes(header, b"search_path") {
109 score.postgres += 5;
110 }
111 if contains_bytes(header, b"/*!40") || contains_bytes(header, b"/*!50") {
112 score.mysql += 5;
113 }
114 if contains_bytes(header, b"LOCK TABLES") {
115 score.mysql += 5;
116 }
117 if contains_bytes(header, b"PRAGMA") {
118 score.sqlite += 5;
119 }
120
121 if contains_bytes(header, b"$$") {
123 score.postgres += 2;
124 }
125 if contains_bytes(header, b"CREATE EXTENSION") {
126 score.postgres += 2;
127 }
128 if contains_bytes(header, b"BEGIN TRANSACTION") {
130 score.sqlite += 2;
131 }
132 if header.contains(&b'`') {
134 score.mysql += 2;
135 }
136
137 if contains_bytes(header, b"SET ANSI_NULLS") {
140 score.mssql += 20;
141 }
142 if contains_bytes(header, b"SET QUOTED_IDENTIFIER") {
143 score.mssql += 20;
144 }
145
146 if contains_bytes(header, b"\nGO\n") || contains_bytes(header, b"\nGO\r\n") {
149 score.mssql += 15;
150 }
151 if header.contains(&b'[') && header.contains(&b']') {
153 score.mssql += 10;
154 }
155 if contains_bytes(header, b"IDENTITY(") {
156 score.mssql += 10;
157 }
158 if contains_bytes(header, b"ON [PRIMARY]") {
159 score.mssql += 10;
160 }
161
162 if contains_bytes(header, b"N'") {
164 score.mssql += 5;
165 }
166 if contains_bytes(header, b"NVARCHAR") {
167 score.mssql += 5;
168 }
169 if contains_bytes(header, b"CLUSTERED") {
170 score.mssql += 5;
171 }
172 if contains_bytes(header, b"SET NOCOUNT") {
173 score.mssql += 5;
174 }
175
176 let max_score = score
178 .mysql
179 .max(score.postgres)
180 .max(score.sqlite)
181 .max(score.mssql);
182
183 if max_score == 0 {
184 return DialectDetectionResult {
185 dialect: SqlDialect::MySql,
186 confidence: DialectConfidence::Low,
187 };
188 }
189
190 let (dialect, winning_score) = if score.mssql > score.mysql
192 && score.mssql > score.postgres
193 && score.mssql > score.sqlite
194 {
195 (SqlDialect::Mssql, score.mssql)
196 } else if score.postgres > score.mysql && score.postgres > score.sqlite {
197 (SqlDialect::Postgres, score.postgres)
198 } else if score.sqlite > score.mysql {
199 (SqlDialect::Sqlite, score.sqlite)
200 } else {
201 (SqlDialect::MySql, score.mysql)
202 };
203
204 let confidence = if winning_score >= 10 {
206 DialectConfidence::High
207 } else if winning_score >= 5 {
208 DialectConfidence::Medium
209 } else {
210 DialectConfidence::Low
211 };
212
213 DialectDetectionResult {
214 dialect,
215 confidence,
216 }
217}
218
219pub fn detect_dialect_from_file(path: &std::path::Path) -> std::io::Result<DialectDetectionResult> {
221 use std::fs::File;
222 use std::io::Read;
223
224 let mut file = File::open(path)?;
225 let mut buf = [0u8; 8192];
226 let n = file.read(&mut buf)?;
227 Ok(detect_dialect(&buf[..n]))
228}
229
/// Returns `true` if `needle` occurs as a contiguous run anywhere in `haystack`.
///
/// An empty needle always matches, mirroring `str::contains("")`. It is handled
/// up front because `slice::windows` panics on a window size of zero.
#[inline]
fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    if needle.is_empty() {
        return true;
    }
    haystack
        .windows(needle.len())
        .any(|window| window == needle)
}
236
/// Returns `true` if `line` is a T-SQL batch separator: `GO` (any case),
/// optionally followed by spaces/tabs and a decimal repeat count.
///
/// Leading spaces, tabs and CRs, and trailing spaces, tabs, CRs and LFs are ignored.
fn is_go_line(line: &[u8]) -> bool {
    let is_lead_ws = |b: &u8| matches!(*b, b' ' | b'\t' | b'\r');
    let is_trail_ws = |b: &u8| matches!(*b, b' ' | b'\t' | b'\r' | b'\n');

    let begin = line
        .iter()
        .position(|b| !is_lead_ws(b))
        .unwrap_or(line.len());
    let body = &line[begin..];
    let finish = body.iter().rposition(|b| !is_trail_ws(b)).map_or(0, |p| p + 1);
    let trimmed = &body[..finish];

    if trimmed.len() < 2 || !trimmed[..2].eq_ignore_ascii_case(b"GO") {
        return false;
    }

    match trimmed[2..].split_first() {
        // Bare `GO`.
        None => true,
        // `GO <count>`: the separator must be blank, the remainder all digits.
        Some((&sep, rest)) if sep == b' ' || sep == b'\t' => rest
            .iter()
            .skip_while(|&&b| b == b' ' || b == b'\t')
            .all(|b| b.is_ascii_digit()),
        // Something like `GOTO`.
        _ => false,
    }
}
289
/// Coarse classification of a SQL statement, as produced by statement parsing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// Not one of the recognised statement kinds.
    Unknown,
    /// `CREATE TABLE`.
    CreateTable,
    /// `INSERT INTO` (also used for T-SQL `BULK INSERT`).
    Insert,
    /// `CREATE [UNIQUE | CLUSTERED | NONCLUSTERED] INDEX`.
    CreateIndex,
    /// `ALTER TABLE`.
    AlterTable,
    /// `DROP TABLE`.
    DropTable,
    /// PostgreSQL `COPY`.
    Copy,
}

impl StatementType {
    /// `true` for statements that define or change table structure.
    pub fn is_schema(&self) -> bool {
        match self {
            Self::CreateTable | Self::CreateIndex | Self::AlterTable | Self::DropTable => true,
            Self::Unknown | Self::Insert | Self::Copy => false,
        }
    }

    /// `true` for statements that load rows (`INSERT` or `COPY`).
    pub fn is_data(&self) -> bool {
        match self {
            Self::Insert | Self::Copy => true,
            Self::Unknown
            | Self::CreateTable
            | Self::CreateIndex
            | Self::AlterTable
            | Self::DropTable => false,
        }
    }
}
319
/// Which categories of statements a caller wants to keep.
///
/// NOTE(review): how callers apply this is not visible here. Presumably
/// `SchemaOnly`/`DataOnly` pair with `StatementType::is_schema`/`is_data`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ContentFilter {
    /// Keep everything (the default).
    #[default]
    All,
    /// Keep schema statements only.
    SchemaOnly,
    /// Keep data statements only.
    DataOnly,
}
331
332static CREATE_TABLE_RE: Lazy<Regex> =
333 Lazy::new(|| Regex::new(r"(?i)^\s*CREATE\s+TABLE\s+`?([^\s`(]+)`?").unwrap());
334
335static INSERT_INTO_RE: Lazy<Regex> =
336 Lazy::new(|| Regex::new(r"(?i)^\s*INSERT\s+INTO\s+`?([^\s`(]+)`?").unwrap());
337
338static CREATE_INDEX_RE: Lazy<Regex> =
339 Lazy::new(|| Regex::new(r"(?i)ON\s+`?([^\s`(;]+)`?").unwrap());
340
341static CREATE_INDEX_MSSQL_RE: Lazy<Regex> =
345 Lazy::new(|| Regex::new(r"(?i)ON\s+(?:\[?[^\[\]\s]+\]?\s*\.\s*)*\[([^\[\]]+)\]").unwrap());
346
347static ALTER_TABLE_RE: Lazy<Regex> =
348 Lazy::new(|| Regex::new(r"(?i)ALTER\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
349
350static DROP_TABLE_RE: Lazy<Regex> = Lazy::new(|| {
351 Regex::new(r#"(?i)DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?[`"]?([^\s`"`;]+)[`"]?"#).unwrap()
352});
353
354static COPY_RE: Lazy<Regex> =
356 Lazy::new(|| Regex::new(r#"(?i)^\s*COPY\s+(?:ONLY\s+)?[`"]?([^\s`"(]+)[`"]?"#).unwrap());
357
358static CREATE_TABLE_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
364 Regex::new(r#"(?i)^\s*CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
365});
366
367static INSERT_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
368 Regex::new(
369 r#"(?i)^\s*INSERT\s+INTO\s+(?:ONLY\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#,
370 )
371 .unwrap()
372});
373
374pub struct Parser<R: Read> {
375 reader: BufReader<R>,
376 stmt_buffer: Vec<u8>,
377 dialect: SqlDialect,
378 in_copy_data: bool,
380}
381
382impl<R: Read> Parser<R> {
383 #[allow(dead_code)]
384 pub fn new(reader: R, buffer_size: usize) -> Self {
385 Self::with_dialect(reader, buffer_size, SqlDialect::default())
386 }
387
388 pub fn with_dialect(reader: R, buffer_size: usize, dialect: SqlDialect) -> Self {
389 Self {
390 reader: BufReader::with_capacity(buffer_size, reader),
391 stmt_buffer: Vec::with_capacity(32 * 1024),
392 dialect,
393 in_copy_data: false,
394 }
395 }
396
397 pub fn read_statement(&mut self) -> std::io::Result<Option<Vec<u8>>> {
398 if self.in_copy_data {
400 return self.read_copy_data();
401 }
402
403 if self.dialect == SqlDialect::Mssql {
405 return self.read_statement_mssql();
406 }
407
408 self.stmt_buffer.clear();
409
410 let mut inside_single_quote = false;
411 let mut inside_double_quote = false;
412 let mut escaped = false;
413 let mut in_line_comment = false;
414 let mut in_dollar_quote = false;
416 let mut dollar_tag: Vec<u8> = Vec::new();
417
418 loop {
419 let buf = self.reader.fill_buf()?;
420 if buf.is_empty() {
421 if self.stmt_buffer.is_empty() {
422 return Ok(None);
423 }
424 let result = std::mem::take(&mut self.stmt_buffer);
425 return Ok(Some(result));
426 }
427
428 let mut consumed = 0;
429 let mut found_terminator = false;
430
431 for (i, &b) in buf.iter().enumerate() {
432 let inside_string = inside_single_quote || inside_double_quote || in_dollar_quote;
433
434 if in_line_comment {
436 if b == b'\n' {
437 in_line_comment = false;
438 }
439 continue;
440 }
441
442 if escaped {
443 escaped = false;
444 continue;
445 }
446
447 if b == b'\\' && inside_string && self.dialect == SqlDialect::MySql {
449 escaped = true;
450 continue;
451 }
452
453 if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
455 in_line_comment = true;
456 continue;
457 }
458
459 if self.dialect == SqlDialect::Postgres
461 && !inside_single_quote
462 && !inside_double_quote
463 {
464 if b == b'$' && !in_dollar_quote {
465 if let Some(end) = buf[i + 1..].iter().position(|&c| c == b'$') {
467 let tag_bytes = &buf[i + 1..i + 1 + end];
468
469 let is_valid_tag = if tag_bytes.is_empty() {
471 true
472 } else {
473 let mut iter = tag_bytes.iter();
474 match iter.next() {
475 Some(&first)
476 if first.is_ascii_alphabetic() || first == b'_' =>
477 {
478 iter.all(|&c| c.is_ascii_alphanumeric() || c == b'_')
479 }
480 _ => false,
481 }
482 };
483
484 if is_valid_tag {
485 dollar_tag = tag_bytes.to_vec();
486 in_dollar_quote = true;
487 continue;
488 }
489 }
491 } else if b == b'$' && in_dollar_quote {
492 let tag_len = dollar_tag.len();
494 if i + 1 + tag_len < buf.len()
495 && buf[i + 1..i + 1 + tag_len] == dollar_tag[..]
496 && buf.get(i + 1 + tag_len) == Some(&b'$')
497 {
498 in_dollar_quote = false;
499 dollar_tag.clear();
500 continue;
501 }
502 }
503 }
504
505 if b == b'\'' && !inside_double_quote && !in_dollar_quote {
506 inside_single_quote = !inside_single_quote;
507 } else if b == b'"' && !inside_single_quote && !in_dollar_quote {
508 inside_double_quote = !inside_double_quote;
509 } else if b == b';' && !inside_string {
510 self.stmt_buffer.extend_from_slice(&buf[..=i]);
511 consumed = i + 1;
512 found_terminator = true;
513 break;
514 }
515 }
516
517 if found_terminator {
518 self.reader.consume(consumed);
519 let result = std::mem::take(&mut self.stmt_buffer);
520
521 if self.dialect == SqlDialect::Postgres && self.is_copy_from_stdin(&result) {
523 self.in_copy_data = true;
524 }
525
526 return Ok(Some(result));
527 }
528
529 self.stmt_buffer.extend_from_slice(buf);
530 let len = buf.len();
531 self.reader.consume(len);
532 }
533 }
534
535 fn is_copy_from_stdin(&self, stmt: &[u8]) -> bool {
537 let stmt = strip_leading_comments_and_whitespace(stmt);
539 if stmt.len() < 4 {
540 return false;
541 }
542
543 let upper: Vec<u8> = stmt
545 .iter()
546 .take(500)
547 .map(|b| b.to_ascii_uppercase())
548 .collect();
549 upper.starts_with(b"COPY ")
550 && (upper.windows(10).any(|w| w == b"FROM STDIN")
551 || upper.windows(11).any(|w| w == b"FROM STDIN;"))
552 }
553
554 fn read_copy_data(&mut self) -> std::io::Result<Option<Vec<u8>>> {
556 self.stmt_buffer.clear();
557
558 loop {
559 let buf = self.reader.fill_buf()?;
561 if buf.is_empty() {
562 self.in_copy_data = false;
563 if self.stmt_buffer.is_empty() {
564 return Ok(None);
565 }
566 return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
567 }
568
569 let newline_pos = buf.iter().position(|&b| b == b'\n');
571
572 if let Some(i) = newline_pos {
573 self.stmt_buffer.extend_from_slice(&buf[..=i]);
575 self.reader.consume(i + 1);
576
577 if self.ends_with_copy_terminator() {
580 self.in_copy_data = false;
581 return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
582 }
583 } else {
585 let len = buf.len();
587 self.stmt_buffer.extend_from_slice(buf);
588 self.reader.consume(len);
589 }
590 }
591 }
592
593 fn ends_with_copy_terminator(&self) -> bool {
595 let data = &self.stmt_buffer;
596 if data.len() < 2 {
597 return false;
598 }
599
600 let last_newline = data[..data.len() - 1]
603 .iter()
604 .rposition(|&b| b == b'\n')
605 .map(|i| i + 1)
606 .unwrap_or(0);
607
608 let last_line = &data[last_newline..];
609
610 last_line == b"\\.\n" || last_line == b"\\.\r\n"
612 }
613
614 fn read_statement_mssql(&mut self) -> std::io::Result<Option<Vec<u8>>> {
617 self.stmt_buffer.clear();
618
619 let mut inside_single_quote = false;
620 let mut inside_bracket_quote = false;
621 let mut in_line_comment = false;
622 let mut line_start = 0usize; loop {
625 let buf = self.reader.fill_buf()?;
626 if buf.is_empty() {
627 if self.stmt_buffer.is_empty() {
628 return Ok(None);
629 }
630 let result = std::mem::take(&mut self.stmt_buffer);
631 return Ok(Some(result));
632 }
633
634 let mut consumed = 0;
635 let mut found_terminator = false;
636
637 for (i, &b) in buf.iter().enumerate() {
638 let inside_string = inside_single_quote || inside_bracket_quote;
639
640 if in_line_comment {
642 if b == b'\n' {
643 in_line_comment = false;
644 self.stmt_buffer.extend_from_slice(&buf[consumed..=i]);
646 consumed = i + 1;
647 line_start = self.stmt_buffer.len();
648 }
649 continue;
650 }
651
652 if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
654 in_line_comment = true;
655 continue;
656 }
657
658 if b == b'\'' && !inside_bracket_quote {
663 inside_single_quote = !inside_single_quote;
664 } else if b == b'[' && !inside_single_quote {
665 inside_bracket_quote = true;
666 } else if b == b']' && inside_bracket_quote {
667 if i + 1 < buf.len() && buf[i + 1] == b']' {
669 continue;
671 }
672 inside_bracket_quote = false;
673 } else if b == b';' && !inside_string {
674 self.stmt_buffer.extend_from_slice(&buf[consumed..=i]);
676 consumed = i + 1;
677 found_terminator = true;
678 break;
679 } else if b == b'\n' && !inside_string {
680 self.stmt_buffer.extend_from_slice(&buf[consumed..=i]);
683 consumed = i + 1;
684
685 let line = &self.stmt_buffer[line_start..];
687 if is_go_line(line) {
688 self.stmt_buffer.truncate(line_start);
690 while self
692 .stmt_buffer
693 .last()
694 .is_some_and(|&b| b == b'\n' || b == b'\r' || b == b' ' || b == b'\t')
695 {
696 self.stmt_buffer.pop();
697 }
698 if !self.stmt_buffer.is_empty() {
700 self.reader.consume(consumed);
701 let result = std::mem::take(&mut self.stmt_buffer);
702 return Ok(Some(result));
703 }
704 line_start = 0;
706 } else {
707 line_start = self.stmt_buffer.len();
709 }
710 continue;
711 }
712 }
713
714 if found_terminator {
715 self.reader.consume(consumed);
716 let result = std::mem::take(&mut self.stmt_buffer);
717 return Ok(Some(result));
718 }
719
720 if consumed < buf.len() {
722 self.stmt_buffer.extend_from_slice(&buf[consumed..]);
723 }
724 let len = buf.len();
725 self.reader.consume(len);
726 }
727 }
728
729 #[allow(dead_code)]
730 pub fn parse_statement(stmt: &[u8]) -> (StatementType, String) {
731 Self::parse_statement_with_dialect(stmt, SqlDialect::MySql)
732 }
733
734 pub fn parse_statement_with_dialect(
736 stmt: &[u8],
737 dialect: SqlDialect,
738 ) -> (StatementType, String) {
739 let stmt = strip_leading_comments_and_whitespace(stmt);
741
742 if stmt.len() < 4 {
743 return (StatementType::Unknown, String::new());
744 }
745
746 let upper_prefix: Vec<u8> = stmt
747 .iter()
748 .take(25)
749 .map(|b| b.to_ascii_uppercase())
750 .collect();
751
752 if upper_prefix.starts_with(b"COPY ") {
754 if let Some(caps) = COPY_RE.captures(stmt) {
755 if let Some(m) = caps.get(1) {
756 let name = String::from_utf8_lossy(m.as_bytes()).into_owned();
757 let table_name = name.split('.').next_back().unwrap_or(&name).to_string();
759 return (StatementType::Copy, table_name);
760 }
761 }
762 }
763
764 if upper_prefix.starts_with(b"CREATE TABLE") {
765 if let Some(name) = extract_table_name_flexible(stmt, 12, dialect) {
767 return (StatementType::CreateTable, name);
768 }
769 if let Some(caps) = CREATE_TABLE_FLEXIBLE_RE.captures(stmt) {
771 if let Some(m) = caps.get(1) {
772 return (
773 StatementType::CreateTable,
774 String::from_utf8_lossy(m.as_bytes()).into_owned(),
775 );
776 }
777 }
778 if let Some(caps) = CREATE_TABLE_RE.captures(stmt) {
780 if let Some(m) = caps.get(1) {
781 return (
782 StatementType::CreateTable,
783 String::from_utf8_lossy(m.as_bytes()).into_owned(),
784 );
785 }
786 }
787 }
788
789 if upper_prefix.starts_with(b"INSERT INTO") || upper_prefix.starts_with(b"INSERT ONLY") {
790 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
791 return (StatementType::Insert, name);
792 }
793 if let Some(caps) = INSERT_FLEXIBLE_RE.captures(stmt) {
794 if let Some(m) = caps.get(1) {
795 return (
796 StatementType::Insert,
797 String::from_utf8_lossy(m.as_bytes()).into_owned(),
798 );
799 }
800 }
801 if let Some(caps) = INSERT_INTO_RE.captures(stmt) {
802 if let Some(m) = caps.get(1) {
803 return (
804 StatementType::Insert,
805 String::from_utf8_lossy(m.as_bytes()).into_owned(),
806 );
807 }
808 }
809 }
810
811 if upper_prefix.starts_with(b"CREATE INDEX")
812 || upper_prefix.starts_with(b"CREATE UNIQUE")
813 || upper_prefix.starts_with(b"CREATE CLUSTERED")
814 || upper_prefix.starts_with(b"CREATE NONCLUSTER")
815 {
816 if dialect == SqlDialect::Mssql {
818 if let Some(caps) = CREATE_INDEX_MSSQL_RE.captures(stmt) {
819 if let Some(m) = caps.get(1) {
820 return (
821 StatementType::CreateIndex,
822 String::from_utf8_lossy(m.as_bytes()).into_owned(),
823 );
824 }
825 }
826 }
827 if let Some(caps) = CREATE_INDEX_RE.captures(stmt) {
829 if let Some(m) = caps.get(1) {
830 return (
831 StatementType::CreateIndex,
832 String::from_utf8_lossy(m.as_bytes()).into_owned(),
833 );
834 }
835 }
836 }
837
838 if upper_prefix.starts_with(b"ALTER TABLE") {
839 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
840 return (StatementType::AlterTable, name);
841 }
842 if let Some(caps) = ALTER_TABLE_RE.captures(stmt) {
843 if let Some(m) = caps.get(1) {
844 return (
845 StatementType::AlterTable,
846 String::from_utf8_lossy(m.as_bytes()).into_owned(),
847 );
848 }
849 }
850 }
851
852 if upper_prefix.starts_with(b"DROP TABLE") {
853 if let Some(name) = extract_table_name_flexible(stmt, 10, dialect) {
854 return (StatementType::DropTable, name);
855 }
856 if let Some(caps) = DROP_TABLE_RE.captures(stmt) {
857 if let Some(m) = caps.get(1) {
858 return (
859 StatementType::DropTable,
860 String::from_utf8_lossy(m.as_bytes()).into_owned(),
861 );
862 }
863 }
864 }
865
866 if upper_prefix.starts_with(b"BULK INSERT") {
868 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
869 return (StatementType::Insert, name);
870 }
871 }
872
873 (StatementType::Unknown, String::new())
874 }
875}
876
/// Returns `data` without its leading ASCII spaces, tabs, CRs and LFs.
/// Other bytes, including form feeds, are kept.
#[inline]
fn trim_ascii_start(data: &[u8]) -> &[u8] {
    let mut rest = data;
    while let [b' ' | b'\t' | b'\n' | b'\r', tail @ ..] = rest {
        rest = tail;
    }
    rest
}
885
886fn strip_leading_comments_and_whitespace(mut data: &[u8]) -> &[u8] {
889 loop {
890 data = trim_ascii_start(data);
892
893 if data.len() >= 2 && data[0] == b'-' && data[1] == b'-' {
895 if let Some(pos) = data.iter().position(|&b| b == b'\n') {
897 data = &data[pos + 1..];
898 continue;
899 } else {
900 return &[];
902 }
903 }
904
905 if data.len() >= 2 && data[0] == b'/' && data[1] == b'*' {
907 let mut i = 2;
909 let mut depth = 1;
910 while i < data.len() - 1 && depth > 0 {
911 if data[i] == b'*' && data[i + 1] == b'/' {
912 depth -= 1;
913 i += 2;
914 } else if data[i] == b'/' && data[i + 1] == b'*' {
915 depth += 1;
916 i += 2;
917 } else {
918 i += 1;
919 }
920 }
921 if depth == 0 {
922 data = &data[i..];
923 continue;
924 } else {
925 return &[];
927 }
928 }
929
930 if !data.is_empty() && data[0] == b'#' {
932 if let Some(pos) = data.iter().position(|&b| b == b'\n') {
933 data = &data[pos + 1..];
934 continue;
935 } else {
936 return &[];
937 }
938 }
939
940 break;
941 }
942
943 data
944}
945
946#[inline]
952fn extract_table_name_flexible(stmt: &[u8], offset: usize, dialect: SqlDialect) -> Option<String> {
953 let mut i = offset;
954
955 while i < stmt.len() && is_whitespace(stmt[i]) {
957 i += 1;
958 }
959
960 if i >= stmt.len() {
961 return None;
962 }
963
964 let upper_check: Vec<u8> = stmt[i..]
966 .iter()
967 .take(20)
968 .map(|b| b.to_ascii_uppercase())
969 .collect();
970 if upper_check.starts_with(b"IF NOT EXISTS") {
971 i += 13; while i < stmt.len() && is_whitespace(stmt[i]) {
973 i += 1;
974 }
975 } else if upper_check.starts_with(b"IF EXISTS") {
976 i += 9; while i < stmt.len() && is_whitespace(stmt[i]) {
978 i += 1;
979 }
980 }
981
982 let upper_check: Vec<u8> = stmt[i..]
984 .iter()
985 .take(10)
986 .map(|b| b.to_ascii_uppercase())
987 .collect();
988 if upper_check.starts_with(b"ONLY ") || upper_check.starts_with(b"ONLY\t") {
989 i += 4;
990 while i < stmt.len() && is_whitespace(stmt[i]) {
991 i += 1;
992 }
993 }
994
995 if i >= stmt.len() {
996 return None;
997 }
998
999 let mut parts: Vec<String> = Vec::new();
1001
1002 loop {
1003 let (quote_char, close_char) = match stmt.get(i) {
1005 Some(b'`') if dialect == SqlDialect::MySql => {
1006 i += 1;
1007 (Some(b'`'), b'`')
1008 }
1009 Some(b'"') if dialect != SqlDialect::MySql => {
1010 i += 1;
1011 (Some(b'"'), b'"')
1012 }
1013 Some(b'"') => {
1014 i += 1;
1016 (Some(b'"'), b'"')
1017 }
1018 Some(b'[') if dialect == SqlDialect::Mssql => {
1019 i += 1;
1021 (Some(b'['), b']')
1022 }
1023 _ => (None, 0),
1024 };
1025
1026 let start = i;
1027
1028 while i < stmt.len() {
1029 let b = stmt[i];
1030 if quote_char.is_some() {
1031 if b == close_char {
1032 if dialect == SqlDialect::Mssql
1034 && close_char == b']'
1035 && i + 1 < stmt.len()
1036 && stmt[i + 1] == b']'
1037 {
1038 i += 2;
1040 continue;
1041 }
1042 let name = &stmt[start..i];
1043 let name_str = if dialect == SqlDialect::Mssql {
1045 String::from_utf8_lossy(name).replace("]]", "]")
1046 } else {
1047 String::from_utf8_lossy(name).into_owned()
1048 };
1049 parts.push(name_str);
1050 i += 1; break;
1052 }
1053 } else if is_whitespace(b) || b == b'(' || b == b';' || b == b',' || b == b'.' {
1054 if i > start {
1055 let name = &stmt[start..i];
1056 parts.push(String::from_utf8_lossy(name).into_owned());
1057 }
1058 break;
1059 }
1060 i += 1;
1061 }
1062
1063 if quote_char.is_some() && i <= start {
1065 break;
1066 }
1067
1068 while i < stmt.len() && is_whitespace(stmt[i]) {
1070 i += 1;
1071 }
1072
1073 if i < stmt.len() && stmt[i] == b'.' {
1074 i += 1; while i < stmt.len() && is_whitespace(stmt[i]) {
1076 i += 1;
1077 }
1078 } else {
1080 break;
1081 }
1082 }
1083
1084 parts.pop()
1086}
1087
/// Returns `true` for the ASCII whitespace bytes the SQL scanners skip:
/// space, tab, LF and CR.
#[inline]
fn is_whitespace(b: u8) -> bool {
    b" \t\n\r".contains(&b)
}
1092
1093pub fn determine_buffer_size(file_size: u64) -> usize {
1094 if file_size > 1024 * 1024 * 1024 {
1095 MEDIUM_BUFFER_SIZE
1096 } else {
1097 SMALL_BUFFER_SIZE
1098 }
1099}