1pub mod mysql_insert;
2pub mod postgres_copy;
3
4pub use mysql_insert::{parse_insert_for_bulk, ParsedValue};
6
7use once_cell::sync::Lazy;
8use regex::bytes::Regex;
9use std::io::{BufRead, BufReader, Read};
10
/// Read-buffer capacity of 64 KiB; `determine_buffer_size` selects it for
/// inputs of at most 1 GiB.
pub const SMALL_BUFFER_SIZE: usize = 64 * 1024;
/// Read-buffer capacity of 256 KiB; `determine_buffer_size` selects it for
/// inputs larger than 1 GiB.
pub const MEDIUM_BUFFER_SIZE: usize = 256 * 1024;
13
/// SQL dialects understood by the detector and the statement parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqlDialect {
    /// MySQL or MariaDB; this is the default dialect.
    #[default]
    MySql,
    /// PostgreSQL.
    Postgres,
    /// SQLite.
    Sqlite,
    /// Microsoft SQL Server.
    Mssql,
}

impl std::str::FromStr for SqlDialect {
    type Err = String;

    /// Parses a dialect name or one of its aliases, ignoring letter case.
    ///
    /// # Errors
    ///
    /// Returns a message listing the valid options when `s` is not a known
    /// name.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Every accepted spelling, already lower-cased.
        const ALIASES: &[(&str, SqlDialect)] = &[
            ("mysql", SqlDialect::MySql),
            ("mariadb", SqlDialect::MySql),
            ("postgres", SqlDialect::Postgres),
            ("postgresql", SqlDialect::Postgres),
            ("pg", SqlDialect::Postgres),
            ("sqlite", SqlDialect::Sqlite),
            ("sqlite3", SqlDialect::Sqlite),
            ("mssql", SqlDialect::Mssql),
            ("sqlserver", SqlDialect::Mssql),
            ("sql_server", SqlDialect::Mssql),
            ("tsql", SqlDialect::Mssql),
        ];

        let wanted = s.to_lowercase();
        ALIASES
            .iter()
            .find(|(alias, _)| *alias == wanted.as_str())
            .map(|&(_, dialect)| dialect)
            .ok_or_else(|| {
                format!(
                    "Unknown dialect: {}. Valid options: mysql, postgres, sqlite, mssql",
                    s
                )
            })
    }
}

impl std::fmt::Display for SqlDialect {
    /// Writes the canonical lower-case name, which `from_str` accepts back.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SqlDialect::MySql => "mysql",
            SqlDialect::Postgres => "postgres",
            SqlDialect::Sqlite => "sqlite",
            SqlDialect::Mssql => "mssql",
        };
        f.write_str(name)
    }
}
55
56#[derive(Debug, Clone)]
58pub struct DialectDetectionResult {
59 pub dialect: SqlDialect,
60 pub confidence: DialectConfidence,
61}
62
/// How much evidence backed a detected dialect.
///
/// `detect_dialect` derives the level from the winning score.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DialectConfidence {
    /// Winning score of 10 or more.
    High,
    /// Winning score from 5 to 9.
    Medium,
    /// Winning score below 5, including the case where nothing matched.
    Low,
}
73
/// Per-dialect evidence points accumulated while scanning a dump header.
/// All counters start at zero.
#[derive(Default)]
struct DialectScore {
    /// Points in favour of MySQL / MariaDB.
    mysql: u32,
    /// Points in favour of PostgreSQL.
    postgres: u32,
    /// Points in favour of SQLite.
    sqlite: u32,
    /// Points in favour of SQL Server.
    mssql: u32,
}
81
82pub fn detect_dialect(header: &[u8]) -> DialectDetectionResult {
85 let mut score = DialectScore::default();
86
87 if contains_bytes(header, b"pg_dump") {
89 score.postgres += 10;
90 }
91 if contains_bytes(header, b"PostgreSQL database dump") {
92 score.postgres += 10;
93 }
94 if contains_bytes(header, b"MySQL dump") {
95 score.mysql += 10;
96 }
97 if contains_bytes(header, b"MariaDB dump") {
98 score.mysql += 10;
99 }
100 if contains_bytes(header, b"SQLite") {
101 score.sqlite += 10;
102 }
103
104 if contains_bytes(header, b"COPY ") && contains_bytes(header, b"FROM stdin") {
106 score.postgres += 5;
107 }
108 if contains_bytes(header, b"search_path") {
109 score.postgres += 5;
110 }
111 if contains_bytes(header, b"/*!40") || contains_bytes(header, b"/*!50") {
112 score.mysql += 5;
113 }
114 if contains_bytes(header, b"LOCK TABLES") {
115 score.mysql += 5;
116 }
117 if contains_bytes(header, b"PRAGMA") {
118 score.sqlite += 5;
119 }
120
121 if contains_bytes(header, b"$$") {
123 score.postgres += 2;
124 }
125 if contains_bytes(header, b"CREATE EXTENSION") {
126 score.postgres += 2;
127 }
128 if contains_bytes(header, b"BEGIN TRANSACTION") {
130 score.sqlite += 2;
131 }
132 if header.contains(&b'`') {
134 score.mysql += 2;
135 }
136
137 if contains_bytes(header, b"SET ANSI_NULLS") {
140 score.mssql += 20;
141 }
142 if contains_bytes(header, b"SET QUOTED_IDENTIFIER") {
143 score.mssql += 20;
144 }
145
146 if contains_bytes(header, b"\nGO\n") || contains_bytes(header, b"\nGO\r\n") {
149 score.mssql += 15;
150 }
151 if header.contains(&b'[') && header.contains(&b']') {
153 score.mssql += 10;
154 }
155 if contains_bytes(header, b"IDENTITY(") {
156 score.mssql += 10;
157 }
158 if contains_bytes(header, b"ON [PRIMARY]") {
159 score.mssql += 10;
160 }
161
162 if contains_bytes(header, b"N'") {
164 score.mssql += 5;
165 }
166 if contains_bytes(header, b"NVARCHAR") {
167 score.mssql += 5;
168 }
169 if contains_bytes(header, b"CLUSTERED") {
170 score.mssql += 5;
171 }
172 if contains_bytes(header, b"SET NOCOUNT") {
173 score.mssql += 5;
174 }
175
176 let max_score = score
178 .mysql
179 .max(score.postgres)
180 .max(score.sqlite)
181 .max(score.mssql);
182
183 if max_score == 0 {
184 return DialectDetectionResult {
185 dialect: SqlDialect::MySql,
186 confidence: DialectConfidence::Low,
187 };
188 }
189
190 let (dialect, winning_score) = if score.mssql > score.mysql
192 && score.mssql > score.postgres
193 && score.mssql > score.sqlite
194 {
195 (SqlDialect::Mssql, score.mssql)
196 } else if score.postgres > score.mysql && score.postgres > score.sqlite {
197 (SqlDialect::Postgres, score.postgres)
198 } else if score.sqlite > score.mysql {
199 (SqlDialect::Sqlite, score.sqlite)
200 } else {
201 (SqlDialect::MySql, score.mysql)
202 };
203
204 let confidence = if winning_score >= 10 {
206 DialectConfidence::High
207 } else if winning_score >= 5 {
208 DialectConfidence::Medium
209 } else {
210 DialectConfidence::Low
211 };
212
213 DialectDetectionResult {
214 dialect,
215 confidence,
216 }
217}
218
219pub fn detect_dialect_from_file(path: &std::path::Path) -> std::io::Result<DialectDetectionResult> {
221 use std::fs::File;
222 use std::io::Read;
223
224 let mut file = File::open(path)?;
225 let mut buf = [0u8; 8192];
226 let n = file.read(&mut buf)?;
227 Ok(detect_dialect(&buf[..n]))
228}
229
/// Returns `true` when `needle` occurs as a contiguous run inside `haystack`.
///
/// An empty `needle` is contained in every haystack, including an empty one.
#[inline]
fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    // `slice::windows` panics for a window size of zero, so the empty needle
    // is answered before it is used as a window length.
    if needle.is_empty() {
        return true;
    }
    haystack
        .windows(needle.len())
        .any(|window| window == needle)
}
236
/// Reports whether `line` is a `GO` batch separator.
///
/// Leading spaces, tabs and carriage returns are skipped, as are trailing
/// spaces, tabs, carriage returns and newlines. What remains must be `GO` in
/// any letter case, optionally followed by spaces/tabs and a run of ASCII
/// digits (as in `GO 5`).
fn is_go_line(line: &[u8]) -> bool {
    let blank = |b: u8| b == b' ' || b == b'\t';

    let lead = line
        .iter()
        .take_while(|&&b| blank(b) || b == b'\r')
        .count();
    let body = &line[lead..];

    // Length of `body` once trailing whitespace and line breaks are dropped.
    let kept = body
        .iter()
        .rposition(|&b| !(blank(b) || b == b'\r' || b == b'\n'))
        .map_or(0, |last| last + 1);

    match &body[..kept] {
        [g, o, tail @ ..]
            if g.eq_ignore_ascii_case(&b'G') && o.eq_ignore_ascii_case(&b'O') =>
        {
            match tail {
                // Bare `GO`.
                [] => true,
                // `GO` followed by a separator: only a digit run may follow.
                [sep, count @ ..] if blank(*sep) => count
                    .iter()
                    .skip_while(|&&b| blank(b))
                    .all(u8::is_ascii_digit),
                // Something glued to `GO`, such as `GOTO`.
                _ => false,
            }
        }
        _ => false,
    }
}
289
/// Kind of SQL statement recognised by the statement classifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// Anything that is not one of the kinds below.
    Unknown,
    /// `CREATE TABLE`.
    CreateTable,
    /// `INSERT INTO` (also used for `BULK INSERT`).
    Insert,
    /// `CREATE INDEX` and its unique/clustered variants.
    CreateIndex,
    /// `ALTER TABLE`.
    AlterTable,
    /// `DROP TABLE`.
    DropTable,
    /// `COPY`.
    Copy,
}

impl StatementType {
    /// Returns `true` for statements that define or change schema:
    /// `CreateTable`, `CreateIndex`, `AlterTable` and `DropTable`.
    pub fn is_schema(&self) -> bool {
        match self {
            StatementType::CreateTable
            | StatementType::CreateIndex
            | StatementType::AlterTable
            | StatementType::DropTable => true,
            StatementType::Unknown | StatementType::Insert | StatementType::Copy => false,
        }
    }

    /// Returns `true` for statements that carry row data: `Insert` and `Copy`.
    pub fn is_data(&self) -> bool {
        match self {
            StatementType::Insert | StatementType::Copy => true,
            StatementType::Unknown
            | StatementType::CreateTable
            | StatementType::CreateIndex
            | StatementType::AlterTable
            | StatementType::DropTable => false,
        }
    }
}
319
/// Selects which statements a consumer wants to keep.
///
/// NOTE(review): the filtering itself is not implemented in this file; the
/// variant descriptions assume they pair with `StatementType::is_schema` and
/// `StatementType::is_data` — confirm against the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ContentFilter {
    /// No filtering; this is the default.
    #[default]
    All,
    /// Presumably keeps schema statements only.
    SchemaOnly,
    /// Presumably keeps data statements only.
    DataOnly,
}
331
332static CREATE_TABLE_RE: Lazy<Regex> =
333 Lazy::new(|| Regex::new(r"(?i)^\s*CREATE\s+TABLE\s+`?([^\s`(]+)`?").unwrap());
334
335static INSERT_INTO_RE: Lazy<Regex> =
336 Lazy::new(|| Regex::new(r"(?i)^\s*INSERT\s+INTO\s+`?([^\s`(]+)`?").unwrap());
337
338static CREATE_INDEX_RE: Lazy<Regex> =
339 Lazy::new(|| Regex::new(r"(?i)ON\s+`?([^\s`(;]+)`?").unwrap());
340
341static CREATE_INDEX_MSSQL_RE: Lazy<Regex> =
345 Lazy::new(|| Regex::new(r"(?i)ON\s+(?:\[?[^\[\]\s]+\]?\s*\.\s*)*\[([^\[\]]+)\]").unwrap());
346
347static ALTER_TABLE_RE: Lazy<Regex> =
348 Lazy::new(|| Regex::new(r"(?i)ALTER\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
349
350static DROP_TABLE_RE: Lazy<Regex> = Lazy::new(|| {
351 Regex::new(r#"(?i)DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?[`"]?([^\s`"`;]+)[`"]?"#).unwrap()
352});
353
354static COPY_RE: Lazy<Regex> =
356 Lazy::new(|| Regex::new(r#"(?i)^\s*COPY\s+(?:ONLY\s+)?[`"]?([^\s`"(]+)[`"]?"#).unwrap());
357
358static CREATE_TABLE_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
364 Regex::new(r#"(?i)^\s*CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
365});
366
367static INSERT_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
368 Regex::new(
369 r#"(?i)^\s*INSERT\s+INTO\s+(?:ONLY\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#,
370 )
371 .unwrap()
372});
373
374pub struct Parser<R: Read> {
375 reader: BufReader<R>,
376 stmt_buffer: Vec<u8>,
377 dialect: SqlDialect,
378 in_copy_data: bool,
380}
381
382impl<R: Read> Parser<R> {
383 #[allow(dead_code)]
384 pub fn new(reader: R, buffer_size: usize) -> Self {
385 Self::with_dialect(reader, buffer_size, SqlDialect::default())
386 }
387
388 pub fn with_dialect(reader: R, buffer_size: usize, dialect: SqlDialect) -> Self {
389 Self {
390 reader: BufReader::with_capacity(buffer_size, reader),
391 stmt_buffer: Vec::with_capacity(32 * 1024),
392 dialect,
393 in_copy_data: false,
394 }
395 }
396
397 pub fn read_statement(&mut self) -> std::io::Result<Option<Vec<u8>>> {
398 if self.in_copy_data {
400 return self.read_copy_data();
401 }
402
403 if self.dialect == SqlDialect::Mssql {
405 return self.read_statement_mssql();
406 }
407
408 self.stmt_buffer.clear();
409
410 let mut inside_single_quote = false;
411 let mut inside_double_quote = false;
412 let mut escaped = false;
413 let mut in_line_comment = false;
414 let mut in_dollar_quote = false;
416 let mut dollar_tag: Vec<u8> = Vec::new();
417
418 loop {
419 let buf = self.reader.fill_buf()?;
420 if buf.is_empty() {
421 if self.stmt_buffer.is_empty() {
422 return Ok(None);
423 }
424 let result = std::mem::take(&mut self.stmt_buffer);
425 return Ok(Some(result));
426 }
427
428 let mut consumed = 0;
429 let mut found_terminator = false;
430
431 for (i, &b) in buf.iter().enumerate() {
432 let inside_string = inside_single_quote || inside_double_quote || in_dollar_quote;
433
434 if in_line_comment {
436 if b == b'\n' {
437 in_line_comment = false;
438 }
439 continue;
440 }
441
442 if escaped {
443 escaped = false;
444 continue;
445 }
446
447 if b == b'\\' && inside_string && self.dialect == SqlDialect::MySql {
449 escaped = true;
450 continue;
451 }
452
453 if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
455 in_line_comment = true;
456 continue;
457 }
458
459 if self.dialect == SqlDialect::Postgres
461 && !inside_single_quote
462 && !inside_double_quote
463 {
464 if b == b'$' && !in_dollar_quote {
465 if let Some(end) = buf[i + 1..].iter().position(|&c| c == b'$') {
467 let tag_bytes = &buf[i + 1..i + 1 + end];
468
469 let is_valid_tag = if tag_bytes.is_empty() {
471 true
472 } else {
473 let mut iter = tag_bytes.iter();
474 match iter.next() {
475 Some(&first)
476 if first.is_ascii_alphabetic() || first == b'_' =>
477 {
478 iter.all(|&c| c.is_ascii_alphanumeric() || c == b'_')
479 }
480 _ => false,
481 }
482 };
483
484 if is_valid_tag {
485 dollar_tag = tag_bytes.to_vec();
486 in_dollar_quote = true;
487 continue;
488 }
489 }
491 } else if b == b'$' && in_dollar_quote {
492 let tag_len = dollar_tag.len();
494 if i + 1 + tag_len < buf.len()
495 && buf[i + 1..i + 1 + tag_len] == dollar_tag[..]
496 && buf.get(i + 1 + tag_len) == Some(&b'$')
497 {
498 in_dollar_quote = false;
499 dollar_tag.clear();
500 continue;
501 }
502 }
503 }
504
505 if b == b'\'' && !inside_double_quote && !in_dollar_quote {
506 inside_single_quote = !inside_single_quote;
507 } else if b == b'"' && !inside_single_quote && !in_dollar_quote {
508 inside_double_quote = !inside_double_quote;
509 } else if b == b';' && !inside_string {
510 self.stmt_buffer.extend_from_slice(&buf[..=i]);
511 consumed = i + 1;
512 found_terminator = true;
513 break;
514 }
515 }
516
517 if found_terminator {
518 self.reader.consume(consumed);
519 let result = std::mem::take(&mut self.stmt_buffer);
520
521 if self.dialect == SqlDialect::Postgres && self.is_copy_from_stdin(&result) {
523 self.in_copy_data = true;
524 }
525
526 return Ok(Some(result));
527 }
528
529 self.stmt_buffer.extend_from_slice(buf);
530 let len = buf.len();
531 self.reader.consume(len);
532 }
533 }
534
535 fn is_copy_from_stdin(&self, stmt: &[u8]) -> bool {
537 let stmt = strip_leading_comments_and_whitespace(stmt);
539 if stmt.len() < 15 {
540 return false;
542 }
543
544 let mut prefix = [0u8; 5];
546 for (i, &b) in stmt.iter().take(5).enumerate() {
547 prefix[i] = b.to_ascii_uppercase();
548 }
549 if &prefix != b"COPY " {
550 return false;
551 }
552
553 let search_len = stmt.len().min(500);
556 for i in 0..search_len.saturating_sub(10) {
557 if stmt[i..i + 10]
558 .iter()
559 .zip(b"FROM STDIN".iter())
560 .all(|(&a, &b)| a.to_ascii_uppercase() == b)
561 {
562 return true;
563 }
564 }
565 false
566 }
567
568 fn read_copy_data(&mut self) -> std::io::Result<Option<Vec<u8>>> {
570 self.stmt_buffer.clear();
571
572 loop {
573 let buf = self.reader.fill_buf()?;
575 if buf.is_empty() {
576 self.in_copy_data = false;
577 if self.stmt_buffer.is_empty() {
578 return Ok(None);
579 }
580 return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
581 }
582
583 let newline_pos = buf.iter().position(|&b| b == b'\n');
585
586 if let Some(i) = newline_pos {
587 self.stmt_buffer.extend_from_slice(&buf[..=i]);
589 self.reader.consume(i + 1);
590
591 if self.ends_with_copy_terminator() {
594 self.in_copy_data = false;
595 return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
596 }
597 } else {
599 let len = buf.len();
601 self.stmt_buffer.extend_from_slice(buf);
602 self.reader.consume(len);
603 }
604 }
605 }
606
607 fn ends_with_copy_terminator(&self) -> bool {
609 let data = &self.stmt_buffer;
610 if data.len() < 2 {
611 return false;
612 }
613
614 let last_newline = data[..data.len() - 1]
617 .iter()
618 .rposition(|&b| b == b'\n')
619 .map(|i| i + 1)
620 .unwrap_or(0);
621
622 let last_line = &data[last_newline..];
623
624 last_line == b"\\.\n" || last_line == b"\\.\r\n"
626 }
627
628 fn read_statement_mssql(&mut self) -> std::io::Result<Option<Vec<u8>>> {
631 self.stmt_buffer.clear();
632
633 let mut inside_single_quote = false;
634 let mut inside_bracket_quote = false;
635 let mut in_line_comment = false;
636 let mut line_start = 0usize; loop {
639 let buf = self.reader.fill_buf()?;
640 if buf.is_empty() {
641 if self.stmt_buffer.is_empty() {
642 return Ok(None);
643 }
644 let result = std::mem::take(&mut self.stmt_buffer);
645 return Ok(Some(result));
646 }
647
648 let mut consumed = 0;
649 let mut found_terminator = false;
650
651 for (i, &b) in buf.iter().enumerate() {
652 let inside_string = inside_single_quote || inside_bracket_quote;
653
654 if in_line_comment {
656 if b == b'\n' {
657 in_line_comment = false;
658 self.stmt_buffer.extend_from_slice(&buf[consumed..=i]);
660 consumed = i + 1;
661 line_start = self.stmt_buffer.len();
662 }
663 continue;
664 }
665
666 if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
668 in_line_comment = true;
669 continue;
670 }
671
672 if b == b'\'' && !inside_bracket_quote {
677 inside_single_quote = !inside_single_quote;
678 } else if b == b'[' && !inside_single_quote {
679 inside_bracket_quote = true;
680 } else if b == b']' && inside_bracket_quote {
681 if i + 1 < buf.len() && buf[i + 1] == b']' {
683 continue;
685 }
686 inside_bracket_quote = false;
687 } else if b == b';' && !inside_string {
688 self.stmt_buffer.extend_from_slice(&buf[consumed..=i]);
690 consumed = i + 1;
691 found_terminator = true;
692 break;
693 } else if b == b'\n' && !inside_string {
694 self.stmt_buffer.extend_from_slice(&buf[consumed..=i]);
697 consumed = i + 1;
698
699 let line = &self.stmt_buffer[line_start..];
701 if is_go_line(line) {
702 self.stmt_buffer.truncate(line_start);
704 while self
706 .stmt_buffer
707 .last()
708 .is_some_and(|&b| b == b'\n' || b == b'\r' || b == b' ' || b == b'\t')
709 {
710 self.stmt_buffer.pop();
711 }
712 if !self.stmt_buffer.is_empty() {
714 self.reader.consume(consumed);
715 let result = std::mem::take(&mut self.stmt_buffer);
716 return Ok(Some(result));
717 }
718 line_start = 0;
720 } else {
721 line_start = self.stmt_buffer.len();
723 }
724 continue;
725 }
726 }
727
728 if found_terminator {
729 self.reader.consume(consumed);
730 let result = std::mem::take(&mut self.stmt_buffer);
731 return Ok(Some(result));
732 }
733
734 if consumed < buf.len() {
736 self.stmt_buffer.extend_from_slice(&buf[consumed..]);
737 }
738 let len = buf.len();
739 self.reader.consume(len);
740 }
741 }
742
743 #[allow(dead_code)]
744 pub fn parse_statement(stmt: &[u8]) -> (StatementType, String) {
745 Self::parse_statement_with_dialect(stmt, SqlDialect::MySql)
746 }
747
748 pub fn parse_statement_with_dialect(
750 stmt: &[u8],
751 dialect: SqlDialect,
752 ) -> (StatementType, String) {
753 let stmt = strip_leading_comments_and_whitespace(stmt);
755
756 if stmt.len() < 4 {
757 return (StatementType::Unknown, String::new());
758 }
759
760 let mut upper_prefix = [0u8; 25];
762 let prefix_len = stmt.len().min(25);
763 for (i, &b) in stmt.iter().take(prefix_len).enumerate() {
764 upper_prefix[i] = b.to_ascii_uppercase();
765 }
766 let upper_prefix = &upper_prefix[..prefix_len];
767
768 if upper_prefix.starts_with(b"COPY ") {
770 if let Some(caps) = COPY_RE.captures(stmt) {
771 if let Some(m) = caps.get(1) {
772 let name = String::from_utf8_lossy(m.as_bytes()).into_owned();
773 let table_name = name.split('.').next_back().unwrap_or(&name).to_string();
775 return (StatementType::Copy, table_name);
776 }
777 }
778 }
779
780 if upper_prefix.starts_with(b"CREATE TABLE") {
781 if let Some(name) = extract_table_name_flexible(stmt, 12, dialect) {
783 return (StatementType::CreateTable, name);
784 }
785 if let Some(caps) = CREATE_TABLE_FLEXIBLE_RE.captures(stmt) {
787 if let Some(m) = caps.get(1) {
788 return (
789 StatementType::CreateTable,
790 String::from_utf8_lossy(m.as_bytes()).into_owned(),
791 );
792 }
793 }
794 if let Some(caps) = CREATE_TABLE_RE.captures(stmt) {
796 if let Some(m) = caps.get(1) {
797 return (
798 StatementType::CreateTable,
799 String::from_utf8_lossy(m.as_bytes()).into_owned(),
800 );
801 }
802 }
803 }
804
805 if upper_prefix.starts_with(b"INSERT INTO") || upper_prefix.starts_with(b"INSERT ONLY") {
806 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
807 return (StatementType::Insert, name);
808 }
809 if let Some(caps) = INSERT_FLEXIBLE_RE.captures(stmt) {
810 if let Some(m) = caps.get(1) {
811 return (
812 StatementType::Insert,
813 String::from_utf8_lossy(m.as_bytes()).into_owned(),
814 );
815 }
816 }
817 if let Some(caps) = INSERT_INTO_RE.captures(stmt) {
818 if let Some(m) = caps.get(1) {
819 return (
820 StatementType::Insert,
821 String::from_utf8_lossy(m.as_bytes()).into_owned(),
822 );
823 }
824 }
825 }
826
827 if upper_prefix.starts_with(b"CREATE INDEX")
828 || upper_prefix.starts_with(b"CREATE UNIQUE")
829 || upper_prefix.starts_with(b"CREATE CLUSTERED")
830 || upper_prefix.starts_with(b"CREATE NONCLUSTER")
831 {
832 if dialect == SqlDialect::Mssql {
834 if let Some(caps) = CREATE_INDEX_MSSQL_RE.captures(stmt) {
835 if let Some(m) = caps.get(1) {
836 return (
837 StatementType::CreateIndex,
838 String::from_utf8_lossy(m.as_bytes()).into_owned(),
839 );
840 }
841 }
842 }
843 if let Some(caps) = CREATE_INDEX_RE.captures(stmt) {
845 if let Some(m) = caps.get(1) {
846 return (
847 StatementType::CreateIndex,
848 String::from_utf8_lossy(m.as_bytes()).into_owned(),
849 );
850 }
851 }
852 }
853
854 if upper_prefix.starts_with(b"ALTER TABLE") {
855 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
856 return (StatementType::AlterTable, name);
857 }
858 if let Some(caps) = ALTER_TABLE_RE.captures(stmt) {
859 if let Some(m) = caps.get(1) {
860 return (
861 StatementType::AlterTable,
862 String::from_utf8_lossy(m.as_bytes()).into_owned(),
863 );
864 }
865 }
866 }
867
868 if upper_prefix.starts_with(b"DROP TABLE") {
869 if let Some(name) = extract_table_name_flexible(stmt, 10, dialect) {
870 return (StatementType::DropTable, name);
871 }
872 if let Some(caps) = DROP_TABLE_RE.captures(stmt) {
873 if let Some(m) = caps.get(1) {
874 return (
875 StatementType::DropTable,
876 String::from_utf8_lossy(m.as_bytes()).into_owned(),
877 );
878 }
879 }
880 }
881
882 if upper_prefix.starts_with(b"BULK INSERT") {
884 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
885 return (StatementType::Insert, name);
886 }
887 }
888
889 (StatementType::Unknown, String::new())
890 }
891}
892
/// Returns `data` without its leading spaces, tabs, newlines and carriage
/// returns.
#[inline]
fn trim_ascii_start(data: &[u8]) -> &[u8] {
    let skipped = data
        .iter()
        .take_while(|b| matches!(**b, b' ' | b'\t' | b'\n' | b'\r'))
        .count();
    &data[skipped..]
}
901
902fn strip_leading_comments_and_whitespace(mut data: &[u8]) -> &[u8] {
905 loop {
906 data = trim_ascii_start(data);
908
909 if data.len() >= 2 && data[0] == b'-' && data[1] == b'-' {
911 if let Some(pos) = data.iter().position(|&b| b == b'\n') {
913 data = &data[pos + 1..];
914 continue;
915 } else {
916 return &[];
918 }
919 }
920
921 if data.len() >= 2 && data[0] == b'/' && data[1] == b'*' {
923 let mut i = 2;
925 let mut depth = 1;
926 while i < data.len() - 1 && depth > 0 {
927 if data[i] == b'*' && data[i + 1] == b'/' {
928 depth -= 1;
929 i += 2;
930 } else if data[i] == b'/' && data[i + 1] == b'*' {
931 depth += 1;
932 i += 2;
933 } else {
934 i += 1;
935 }
936 }
937 if depth == 0 {
938 data = &data[i..];
939 continue;
940 } else {
941 return &[];
943 }
944 }
945
946 if !data.is_empty() && data[0] == b'#' {
948 if let Some(pos) = data.iter().position(|&b| b == b'\n') {
949 data = &data[pos + 1..];
950 continue;
951 } else {
952 return &[];
953 }
954 }
955
956 break;
957 }
958
959 data
960}
961
962#[inline]
968fn extract_table_name_flexible(stmt: &[u8], offset: usize, dialect: SqlDialect) -> Option<String> {
969 let mut i = offset;
970
971 while i < stmt.len() && is_whitespace(stmt[i]) {
973 i += 1;
974 }
975
976 if i >= stmt.len() {
977 return None;
978 }
979
980 let mut upper_check = [0u8; 20];
982 let check_len = (stmt.len() - i).min(20);
983 for (idx, &b) in stmt[i..].iter().take(check_len).enumerate() {
984 upper_check[idx] = b.to_ascii_uppercase();
985 }
986 let upper_slice = &upper_check[..check_len];
987 if upper_slice.starts_with(b"IF NOT EXISTS") {
988 i += 13; while i < stmt.len() && is_whitespace(stmt[i]) {
990 i += 1;
991 }
992 } else if upper_slice.starts_with(b"IF EXISTS") {
993 i += 9; while i < stmt.len() && is_whitespace(stmt[i]) {
995 i += 1;
996 }
997 }
998
999 let only_check = if i < stmt.len() {
1001 let mut buf = [0u8; 10];
1002 let len = (stmt.len() - i).min(10);
1003 for (idx, &b) in stmt[i..].iter().take(len).enumerate() {
1004 buf[idx] = b.to_ascii_uppercase();
1005 }
1006 (buf, len)
1007 } else {
1008 ([0u8; 10], 0)
1009 };
1010 let only_slice = &only_check.0[..only_check.1];
1011 if only_slice.starts_with(b"ONLY ") || only_slice.starts_with(b"ONLY\t") {
1012 i += 4;
1013 while i < stmt.len() && is_whitespace(stmt[i]) {
1014 i += 1;
1015 }
1016 }
1017
1018 if i >= stmt.len() {
1019 return None;
1020 }
1021
1022 let mut parts: Vec<String> = Vec::new();
1024
1025 loop {
1026 let (quote_char, close_char) = match stmt.get(i) {
1028 Some(b'`') if dialect == SqlDialect::MySql => {
1029 i += 1;
1030 (Some(b'`'), b'`')
1031 }
1032 Some(b'"') if dialect != SqlDialect::MySql => {
1033 i += 1;
1034 (Some(b'"'), b'"')
1035 }
1036 Some(b'"') => {
1037 i += 1;
1039 (Some(b'"'), b'"')
1040 }
1041 Some(b'[') if dialect == SqlDialect::Mssql => {
1042 i += 1;
1044 (Some(b'['), b']')
1045 }
1046 _ => (None, 0),
1047 };
1048
1049 let start = i;
1050
1051 while i < stmt.len() {
1052 let b = stmt[i];
1053 if quote_char.is_some() {
1054 if b == close_char {
1055 if dialect == SqlDialect::Mssql
1057 && close_char == b']'
1058 && i + 1 < stmt.len()
1059 && stmt[i + 1] == b']'
1060 {
1061 i += 2;
1063 continue;
1064 }
1065 let name = &stmt[start..i];
1066 let name_str = if dialect == SqlDialect::Mssql {
1068 String::from_utf8_lossy(name).replace("]]", "]")
1069 } else {
1070 String::from_utf8_lossy(name).into_owned()
1071 };
1072 parts.push(name_str);
1073 i += 1; break;
1075 }
1076 } else if is_whitespace(b) || b == b'(' || b == b';' || b == b',' || b == b'.' {
1077 if i > start {
1078 let name = &stmt[start..i];
1079 parts.push(String::from_utf8_lossy(name).into_owned());
1080 }
1081 break;
1082 }
1083 i += 1;
1084 }
1085
1086 if quote_char.is_some() && i <= start {
1088 break;
1089 }
1090
1091 while i < stmt.len() && is_whitespace(stmt[i]) {
1093 i += 1;
1094 }
1095
1096 if i < stmt.len() && stmt[i] == b'.' {
1097 i += 1; while i < stmt.len() && is_whitespace(stmt[i]) {
1099 i += 1;
1100 }
1101 } else {
1103 break;
1104 }
1105 }
1106
1107 parts.pop()
1109}
1110
/// Returns `true` for a space, tab, newline or carriage return.
#[inline]
fn is_whitespace(b: u8) -> bool {
    b" \t\n\r".contains(&b)
}
1115
1116pub fn determine_buffer_size(file_size: u64) -> usize {
1117 if file_size > 1024 * 1024 * 1024 {
1118 MEDIUM_BUFFER_SIZE
1119 } else {
1120 SMALL_BUFFER_SIZE
1121 }
1122}