sql_splitter/convert/
mod.rs

1//! Convert command for translating SQL dumps between dialects.
2//!
3//! Supports conversion between MySQL, PostgreSQL, and SQLite dialects with:
4//! - Identifier quoting conversion (backticks ↔ double quotes)
5//! - String escape normalization (\' ↔ '')
6//! - Data type mapping (AUTO_INCREMENT ↔ SERIAL ↔ INTEGER PRIMARY KEY)
7//! - COPY FROM stdin → INSERT conversion
8//! - Session header conversion
9//! - Warning system for unsupported features
10
11mod copy_to_insert;
12mod types;
13mod warnings;
14
15#[allow(unused_imports)]
16pub use copy_to_insert::{
17    copy_to_inserts, parse_copy_data, parse_copy_header, CopyHeader, CopyValue,
18};
19
20use crate::parser::{Parser, SqlDialect, StatementType};
21use crate::progress::ProgressReader;
22use crate::splitter::Compression;
23use indicatif::{ProgressBar, ProgressStyle};
24use std::fs::File;
25use std::io::{BufWriter, Read, Write};
26use std::path::PathBuf;
27
28pub use types::TypeMapper;
29pub use warnings::{ConvertWarning, WarningCollector};
30
31/// Configuration for the convert command
32#[derive(Debug)]
33pub struct ConvertConfig {
34    /// Input SQL file
35    pub input: PathBuf,
36    /// Output SQL file (None for stdout)
37    pub output: Option<PathBuf>,
38    /// Source dialect (auto-detected if None)
39    pub from_dialect: Option<SqlDialect>,
40    /// Target dialect
41    pub to_dialect: SqlDialect,
42    /// Dry run mode
43    pub dry_run: bool,
44    /// Show progress
45    pub progress: bool,
46    /// Strict mode (fail on any unsupported feature)
47    pub strict: bool,
48}
49
50impl Default for ConvertConfig {
51    fn default() -> Self {
52        Self {
53            input: PathBuf::new(),
54            output: None,
55            from_dialect: None,
56            to_dialect: SqlDialect::Postgres,
57            dry_run: false,
58            progress: false,
59            strict: false,
60        }
61    }
62}
63
64/// Statistics from convert operation
65#[derive(Debug, Default)]
66pub struct ConvertStats {
67    /// Total statements processed
68    pub statements_processed: u64,
69    /// Statements converted
70    pub statements_converted: u64,
71    /// Statements passed through unchanged
72    pub statements_unchanged: u64,
73    /// Statements skipped (unsupported)
74    pub statements_skipped: u64,
75    /// Warnings generated
76    pub warnings: Vec<ConvertWarning>,
77}
78
79/// Main converter that dispatches to specific dialect converters
80pub struct Converter {
81    from: SqlDialect,
82    to: SqlDialect,
83    warnings: WarningCollector,
84    strict: bool,
85    /// Pending COPY header for data block processing
86    pending_copy_header: Option<CopyHeader>,
87}
88
89impl Converter {
90    pub fn new(from: SqlDialect, to: SqlDialect) -> Self {
91        Self {
92            from,
93            to,
94            warnings: WarningCollector::new(),
95            strict: false,
96            pending_copy_header: None,
97        }
98    }
99
100    pub fn with_strict(mut self, strict: bool) -> Self {
101        self.strict = strict;
102        self
103    }
104
105    /// Check if we have a pending COPY header (waiting for data block)
106    pub fn has_pending_copy(&self) -> bool {
107        self.pending_copy_header.is_some()
108    }
109
110    /// Process a COPY data block using the pending header
111    pub fn process_copy_data(&mut self, data: &[u8]) -> Result<Vec<Vec<u8>>, ConvertWarning> {
112        if let Some(header) = self.pending_copy_header.take() {
113            if self.from == SqlDialect::Postgres && self.to != SqlDialect::Postgres {
114                // Convert COPY data to INSERT statements
115                let inserts = copy_to_inserts(&header, data, self.to);
116                return Ok(inserts);
117            }
118        }
119        // Pass through if same dialect or no pending header
120        Ok(vec![data.to_vec()])
121    }
122
123    /// Convert a single statement
124    pub fn convert_statement(&mut self, stmt: &[u8]) -> Result<Vec<u8>, ConvertWarning> {
125        let (stmt_type, table_name) =
126            Parser::<&[u8]>::parse_statement_with_dialect(stmt, self.from);
127
128        let table = if table_name.is_empty() {
129            None
130        } else {
131            Some(table_name.as_str())
132        };
133
134        match stmt_type {
135            StatementType::CreateTable => self.convert_create_table(stmt, table),
136            StatementType::Insert => self.convert_insert(stmt, table),
137            StatementType::CreateIndex => self.convert_create_index(stmt),
138            StatementType::AlterTable => self.convert_alter_table(stmt),
139            StatementType::DropTable => self.convert_drop_table(stmt),
140            StatementType::Copy => self.convert_copy(stmt, table),
141            StatementType::Unknown => self.convert_other(stmt),
142        }
143    }
144
145    /// Convert CREATE TABLE statement
146    fn convert_create_table(
147        &mut self,
148        stmt: &[u8],
149        table_name: Option<&str>,
150    ) -> Result<Vec<u8>, ConvertWarning> {
151        let stmt_str = String::from_utf8_lossy(stmt);
152        let mut result = stmt_str.to_string();
153
154        // Detect unsupported features BEFORE conversion (so we see original types)
155        self.detect_unsupported_features(&result, table_name)?;
156
157        // Convert identifier quoting
158        result = self.convert_identifiers(&result);
159
160        // Convert data types
161        result = self.convert_data_types(&result);
162
163        // Convert AUTO_INCREMENT
164        result = self.convert_auto_increment(&result, table_name);
165
166        // Convert PostgreSQL-specific syntax
167        if self.from == SqlDialect::Postgres && self.to != SqlDialect::Postgres {
168            result = self.strip_postgres_casts(&result);
169            result = self.convert_nextval(&result);
170            result = self.convert_default_now(&result);
171            result = self.strip_schema_prefix(&result);
172        }
173
174        // Convert string escapes
175        result = self.convert_string_escapes(&result);
176
177        // Strip MySQL conditional comments
178        result = self.strip_conditional_comments(&result);
179
180        // Convert ENGINE clause
181        result = self.strip_engine_clause(&result);
182
183        // Convert CHARSET/COLLATE
184        result = self.strip_charset_clauses(&result);
185
186        Ok(result.into_bytes())
187    }
188
189    /// Convert INSERT statement
190    fn convert_insert(
191        &mut self,
192        stmt: &[u8],
193        _table_name: Option<&str>,
194    ) -> Result<Vec<u8>, ConvertWarning> {
195        let stmt_str = String::from_utf8_lossy(stmt);
196        let mut result = stmt_str.to_string();
197
198        // Convert identifier quoting
199        result = self.convert_identifiers(&result);
200
201        // Convert PostgreSQL-specific syntax
202        if self.from == SqlDialect::Postgres && self.to != SqlDialect::Postgres {
203            result = self.strip_postgres_casts(&result);
204            result = self.strip_schema_prefix(&result);
205        }
206
207        // Convert string escapes (careful with data!)
208        result = self.convert_string_escapes(&result);
209
210        Ok(result.into_bytes())
211    }
212
213    /// Convert CREATE INDEX statement
214    fn convert_create_index(&mut self, stmt: &[u8]) -> Result<Vec<u8>, ConvertWarning> {
215        let stmt_str = String::from_utf8_lossy(stmt);
216        let mut result = stmt_str.to_string();
217
218        // Convert identifier quoting
219        result = self.convert_identifiers(&result);
220
221        // Convert PostgreSQL-specific syntax
222        if self.from == SqlDialect::Postgres && self.to != SqlDialect::Postgres {
223            result = self.strip_postgres_casts(&result);
224            result = self.strip_schema_prefix(&result);
225        }
226
227        // Detect FULLTEXT/SPATIAL
228        if result.contains("FULLTEXT") || result.contains("fulltext") {
229            self.warnings.add(ConvertWarning::UnsupportedFeature {
230                feature: "FULLTEXT INDEX".to_string(),
231                suggestion: Some("Use PostgreSQL GIN index or skip".to_string()),
232            });
233            if self.strict {
234                return Err(ConvertWarning::UnsupportedFeature {
235                    feature: "FULLTEXT INDEX".to_string(),
236                    suggestion: None,
237                });
238            }
239        }
240
241        Ok(result.into_bytes())
242    }
243
244    /// Convert ALTER TABLE statement
245    fn convert_alter_table(&mut self, stmt: &[u8]) -> Result<Vec<u8>, ConvertWarning> {
246        let stmt_str = String::from_utf8_lossy(stmt);
247        let mut result = stmt_str.to_string();
248
249        result = self.convert_identifiers(&result);
250        result = self.convert_data_types(&result);
251
252        // Convert PostgreSQL-specific syntax
253        if self.from == SqlDialect::Postgres && self.to != SqlDialect::Postgres {
254            result = self.strip_postgres_casts(&result);
255            result = self.convert_nextval(&result);
256            result = self.convert_default_now(&result);
257            result = self.strip_schema_prefix(&result);
258        }
259
260        Ok(result.into_bytes())
261    }
262
263    /// Convert DROP TABLE statement
264    fn convert_drop_table(&mut self, stmt: &[u8]) -> Result<Vec<u8>, ConvertWarning> {
265        let stmt_str = String::from_utf8_lossy(stmt);
266        let mut result = stmt_str.to_string();
267
268        result = self.convert_identifiers(&result);
269
270        // Strip PostgreSQL schema prefix
271        if self.from == SqlDialect::Postgres && self.to != SqlDialect::Postgres {
272            result = self.strip_schema_prefix(&result);
273        }
274
275        Ok(result.into_bytes())
276    }
277
278    /// Convert COPY statement (PostgreSQL-specific)
279    ///
280    /// This handles the COPY header. The data block is processed separately
281    /// via process_copy_data() when called from the run() function.
282    fn convert_copy(
283        &mut self,
284        stmt: &[u8],
285        _table_name: Option<&str>,
286    ) -> Result<Vec<u8>, ConvertWarning> {
287        let stmt_str = String::from_utf8_lossy(stmt);
288
289        // Check if this contains "FROM stdin" (COPY header) or is data
290        let upper = stmt_str.to_uppercase();
291        if upper.contains("FROM STDIN") {
292            // This is a COPY header - parse it and store for later
293            if let Some(header) = parse_copy_header(&stmt_str) {
294                if self.from == SqlDialect::Postgres && self.to != SqlDialect::Postgres {
295                    // Store the header, will convert data block in process_copy_data
296                    self.pending_copy_header = Some(header);
297                    // Return empty - the actual INSERT will be generated from data
298                    return Ok(Vec::new());
299                }
300            }
301        }
302
303        // If same dialect or couldn't parse, pass through
304        Ok(stmt.to_vec())
305    }
306
307    /// Convert other statements (comments, session settings, etc.)
308    fn convert_other(&mut self, stmt: &[u8]) -> Result<Vec<u8>, ConvertWarning> {
309        let stmt_str = String::from_utf8_lossy(stmt);
310        let result = stmt_str.to_string();
311        let trimmed = result.trim();
312
313        // Skip MySQL session commands when converting to other dialects
314        if self.from == SqlDialect::MySql
315            && self.to != SqlDialect::MySql
316            && self.is_mysql_session_command(&result)
317        {
318            return Ok(Vec::new()); // Skip
319        }
320
321        // Skip PostgreSQL session commands and unsupported features when converting to other dialects
322        if self.from == SqlDialect::Postgres
323            && self.to != SqlDialect::Postgres
324            && self.is_postgres_session_command(&result)
325        {
326            return Ok(Vec::new()); // Skip
327        }
328        if self.from == SqlDialect::Postgres
329            && self.to != SqlDialect::Postgres
330            && self.is_postgres_only_feature(trimmed)
331        {
332            self.warnings.add(ConvertWarning::SkippedStatement {
333                reason: "PostgreSQL-only feature".to_string(),
334                statement_preview: trimmed.chars().take(60).collect(),
335            });
336            return Ok(Vec::new()); // Skip
337        }
338
339        // Skip SQLite pragmas when converting to other dialects
340        if self.from == SqlDialect::Sqlite
341            && self.to != SqlDialect::Sqlite
342            && self.is_sqlite_pragma(&result)
343        {
344            return Ok(Vec::new()); // Skip
345        }
346
347        // Strip conditional comments
348        if result.contains("/*!") {
349            let stripped = self.strip_conditional_comments(&result);
350            return Ok(stripped.into_bytes());
351        }
352
353        Ok(stmt.to_vec())
354    }
355
356    /// Check if statement is a MySQL session command
357    fn is_mysql_session_command(&self, stmt: &str) -> bool {
358        let upper = stmt.to_uppercase();
359        upper.contains("SET NAMES")
360            || upper.contains("SET CHARACTER")
361            || upper.contains("SET SQL_MODE")
362            || upper.contains("SET TIME_ZONE")
363            || upper.contains("SET FOREIGN_KEY_CHECKS")
364            || upper.contains("LOCK TABLES")
365            || upper.contains("UNLOCK TABLES")
366    }
367
368    /// Check if statement is a PostgreSQL session command or unsupported statement
369    fn is_postgres_session_command(&self, stmt: &str) -> bool {
370        let upper = stmt.to_uppercase();
371        // Session/transaction settings
372        upper.contains("SET CLIENT_ENCODING")
373            || upper.contains("SET STANDARD_CONFORMING_STRINGS")
374            || upper.contains("SET CHECK_FUNCTION_BODIES")
375            || upper.contains("SET SEARCH_PATH")
376            || upper.contains("SET DEFAULT_TABLESPACE")
377            || upper.contains("SET LOCK_TIMEOUT")
378            || upper.contains("SET IDLE_IN_TRANSACTION_SESSION_TIMEOUT")
379            || upper.contains("SET ROW_SECURITY")
380            || upper.contains("SET STATEMENT_TIMEOUT")
381            || upper.contains("SET XMLOPTION")
382            || upper.contains("SET CLIENT_MIN_MESSAGES")
383            || upper.contains("SET DEFAULT_TABLE_ACCESS_METHOD")
384            || upper.contains("SELECT PG_CATALOG")
385            // Ownership/permission statements
386            || upper.contains("OWNER TO")
387            || upper.contains("GRANT ")
388            || upper.contains("REVOKE ")
389    }
390
391    /// Check if statement is a PostgreSQL-only feature that should be skipped
392    fn is_postgres_only_feature(&self, stmt: &str) -> bool {
393        // Strip leading comments to find the actual statement
394        let stripped = self.strip_leading_sql_comments(stmt);
395        let upper = stripped.to_uppercase();
396
397        // These PostgreSQL features have no MySQL/SQLite equivalent
398        upper.starts_with("CREATE DOMAIN")
399            || upper.starts_with("CREATE TYPE")
400            || upper.starts_with("CREATE FUNCTION")
401            || upper.starts_with("CREATE PROCEDURE")
402            || upper.starts_with("CREATE AGGREGATE")
403            || upper.starts_with("CREATE OPERATOR")
404            || upper.starts_with("CREATE SEQUENCE")
405            || upper.starts_with("CREATE EXTENSION")
406            || upper.starts_with("CREATE SCHEMA")
407            || upper.starts_with("CREATE TRIGGER")
408            || upper.starts_with("ALTER DOMAIN")
409            || upper.starts_with("ALTER TYPE")
410            || upper.starts_with("ALTER FUNCTION")
411            || upper.starts_with("ALTER SEQUENCE")
412            || upper.starts_with("ALTER SCHEMA")
413            || upper.starts_with("COMMENT ON")
414    }
415
416    /// Strip leading SQL comments (-- and /* */) from a string
417    fn strip_leading_sql_comments(&self, stmt: &str) -> String {
418        let mut result = stmt.trim();
419        loop {
420            // Strip -- comments
421            if result.starts_with("--") {
422                if let Some(pos) = result.find('\n') {
423                    result = result[pos + 1..].trim();
424                    continue;
425                } else {
426                    return String::new();
427                }
428            }
429            // Strip /* */ comments
430            if result.starts_with("/*") {
431                if let Some(pos) = result.find("*/") {
432                    result = result[pos + 2..].trim();
433                    continue;
434                } else {
435                    return String::new();
436                }
437            }
438            break;
439        }
440        result.to_string()
441    }
442
443    /// Check if statement is a SQLite pragma
444    fn is_sqlite_pragma(&self, stmt: &str) -> bool {
445        let upper = stmt.to_uppercase();
446        upper.contains("PRAGMA")
447    }
448
449    /// Convert identifier quoting based on dialects
450    fn convert_identifiers(&self, stmt: &str) -> String {
451        match (self.from, self.to) {
452            (SqlDialect::MySql, SqlDialect::Postgres | SqlDialect::Sqlite) => {
453                // Backticks → double quotes
454                self.backticks_to_double_quotes(stmt)
455            }
456            (SqlDialect::Postgres | SqlDialect::Sqlite, SqlDialect::MySql) => {
457                // Double quotes → backticks
458                self.double_quotes_to_backticks(stmt)
459            }
460            _ => stmt.to_string(),
461        }
462    }
463
464    /// Convert backticks to double quotes
465    pub fn backticks_to_double_quotes(&self, stmt: &str) -> String {
466        let mut result = String::with_capacity(stmt.len());
467        let mut in_string = false;
468        let mut in_backtick = false;
469
470        for c in stmt.chars() {
471            if c == '\'' && !in_backtick {
472                in_string = !in_string;
473                result.push(c);
474            } else if c == '`' && !in_string {
475                in_backtick = !in_backtick;
476                result.push('"');
477            } else {
478                result.push(c);
479            }
480        }
481        result
482    }
483
484    /// Convert double quotes to backticks
485    pub fn double_quotes_to_backticks(&self, stmt: &str) -> String {
486        let mut result = String::with_capacity(stmt.len());
487        let mut in_string = false;
488        let mut in_dquote = false;
489        let chars = stmt.chars();
490
491        for c in chars {
492            if c == '\'' && !in_dquote {
493                in_string = !in_string;
494                result.push(c);
495            } else if c == '"' && !in_string {
496                in_dquote = !in_dquote;
497                result.push('`');
498            } else {
499                result.push(c);
500            }
501        }
502        result
503    }
504
505    /// Convert data types between dialects
506    fn convert_data_types(&self, stmt: &str) -> String {
507        TypeMapper::convert(stmt, self.from, self.to)
508    }
509
510    /// Convert AUTO_INCREMENT/SERIAL syntax
511    fn convert_auto_increment(&self, stmt: &str, _table_name: Option<&str>) -> String {
512        match (self.from, self.to) {
513            (SqlDialect::MySql, SqlDialect::Postgres) => {
514                // INT AUTO_INCREMENT → SERIAL
515                // BIGINT AUTO_INCREMENT → BIGSERIAL
516                let result = stmt.replace("BIGINT AUTO_INCREMENT", "BIGSERIAL");
517                let result = result.replace("bigint AUTO_INCREMENT", "BIGSERIAL");
518                let result = result.replace("INT AUTO_INCREMENT", "SERIAL");
519                let result = result.replace("int AUTO_INCREMENT", "SERIAL");
520                result.replace("AUTO_INCREMENT", "") // Clean up any remaining
521            }
522            (SqlDialect::MySql, SqlDialect::Sqlite) => {
523                // INT AUTO_INCREMENT PRIMARY KEY → INTEGER PRIMARY KEY
524                // The AUTOINCREMENT keyword is optional in SQLite
525                let result = stmt.replace("INT AUTO_INCREMENT", "INTEGER");
526                let result = result.replace("int AUTO_INCREMENT", "INTEGER");
527                result.replace("AUTO_INCREMENT", "")
528            }
529            (SqlDialect::Postgres, SqlDialect::MySql) => {
530                // SERIAL → INT AUTO_INCREMENT
531                // BIGSERIAL → BIGINT AUTO_INCREMENT
532                let result = stmt.replace("BIGSERIAL", "BIGINT AUTO_INCREMENT");
533                let result = result.replace("bigserial", "BIGINT AUTO_INCREMENT");
534                let result = result.replace("SMALLSERIAL", "SMALLINT AUTO_INCREMENT");
535                let result = result.replace("smallserial", "SMALLINT AUTO_INCREMENT");
536                let result = result.replace("SERIAL", "INT AUTO_INCREMENT");
537                result.replace("serial", "INT AUTO_INCREMENT")
538            }
539            (SqlDialect::Postgres, SqlDialect::Sqlite) => {
540                // SERIAL → INTEGER (SQLite auto-increments INTEGER PRIMARY KEY)
541                let result = stmt.replace("BIGSERIAL", "INTEGER");
542                let result = result.replace("bigserial", "INTEGER");
543                let result = result.replace("SMALLSERIAL", "INTEGER");
544                let result = result.replace("smallserial", "INTEGER");
545                let result = result.replace("SERIAL", "INTEGER");
546                result.replace("serial", "INTEGER")
547            }
548            (SqlDialect::Sqlite, SqlDialect::MySql) => {
549                // SQLite uses INTEGER PRIMARY KEY for auto-increment
550                // We can't easily detect this pattern, so just pass through
551                stmt.to_string()
552            }
553            (SqlDialect::Sqlite, SqlDialect::Postgres) => {
554                // SQLite uses INTEGER PRIMARY KEY for auto-increment
555                // We can't easily detect this pattern, so just pass through
556                stmt.to_string()
557            }
558            _ => stmt.to_string(),
559        }
560    }
561
562    /// Convert string escape sequences
563    fn convert_string_escapes(&self, stmt: &str) -> String {
564        match (self.from, self.to) {
565            (SqlDialect::MySql, SqlDialect::Postgres | SqlDialect::Sqlite) => {
566                // MySQL uses \' for escaping, PostgreSQL/SQLite use ''
567                self.mysql_escapes_to_standard(stmt)
568            }
569            _ => stmt.to_string(),
570        }
571    }
572
573    /// Convert MySQL backslash escapes to standard SQL double-quote escapes
574    fn mysql_escapes_to_standard(&self, stmt: &str) -> String {
575        let mut result = String::with_capacity(stmt.len());
576        let mut chars = stmt.chars().peekable();
577        let mut in_string = false;
578
579        while let Some(c) = chars.next() {
580            if c == '\'' {
581                in_string = !in_string;
582                result.push(c);
583            } else if c == '\\' && in_string {
584                // Check next character
585                if let Some(&next) = chars.peek() {
586                    match next {
587                        '\'' => {
588                            // \' → ''
589                            chars.next();
590                            result.push_str("''");
591                        }
592                        '\\' => {
593                            // \\ → keep as-is for data integrity
594                            chars.next();
595                            result.push_str("\\\\");
596                        }
597                        'n' | 'r' | 't' | '0' => {
598                            // Keep common escapes as-is
599                            result.push(c);
600                        }
601                        _ => {
602                            result.push(c);
603                        }
604                    }
605                } else {
606                    result.push(c);
607                }
608            } else {
609                result.push(c);
610            }
611        }
612        result
613    }
614
615    /// Strip MySQL conditional comments /*!40101 ... */
616    fn strip_conditional_comments(&self, stmt: &str) -> String {
617        let mut result = String::with_capacity(stmt.len());
618        let mut chars = stmt.chars().peekable();
619
620        while let Some(c) = chars.next() {
621            if c == '/' && chars.peek() == Some(&'*') {
622                chars.next(); // consume *
623                if chars.peek() == Some(&'!') {
624                    // Skip conditional comment
625                    chars.next(); // consume !
626                                  // Skip version number
627                    while chars.peek().map(|c| c.is_ascii_digit()).unwrap_or(false) {
628                        chars.next();
629                    }
630                    // Skip content until */
631                    let mut depth = 1;
632                    while depth > 0 {
633                        match chars.next() {
634                            Some('*') if chars.peek() == Some(&'/') => {
635                                chars.next();
636                                depth -= 1;
637                            }
638                            Some('/') if chars.peek() == Some(&'*') => {
639                                chars.next();
640                                depth += 1;
641                            }
642                            None => break,
643                            _ => {}
644                        }
645                    }
646                } else {
647                    // Regular comment, keep it
648                    result.push('/');
649                    result.push('*');
650                }
651            } else {
652                result.push(c);
653            }
654        }
655        result
656    }
657
658    /// Strip ENGINE clause
659    fn strip_engine_clause(&self, stmt: &str) -> String {
660        if self.to == SqlDialect::MySql {
661            return stmt.to_string();
662        }
663
664        // Remove ENGINE=InnoDB, ENGINE=MyISAM, etc.
665        let re = regex::Regex::new(r"(?i)\s*ENGINE\s*=\s*\w+").unwrap();
666        re.replace_all(stmt, "").to_string()
667    }
668
669    /// Strip CHARSET/COLLATE clauses
670    fn strip_charset_clauses(&self, stmt: &str) -> String {
671        if self.to == SqlDialect::MySql {
672            return stmt.to_string();
673        }
674
675        let result = stmt.to_string();
676        let re1 = regex::Regex::new(r"(?i)\s*(DEFAULT\s+)?CHARSET\s*=\s*\w+").unwrap();
677        let result = re1.replace_all(&result, "").to_string();
678
679        let re2 = regex::Regex::new(r"(?i)\s*COLLATE\s*=?\s*\w+").unwrap();
680        re2.replace_all(&result, "").to_string()
681    }
682
683    /// Strip PostgreSQL type casts (::type and ::regclass)
684    fn strip_postgres_casts(&self, stmt: &str) -> String {
685        use once_cell::sync::Lazy;
686        use regex::Regex;
687
688        // Match ::regclass, ::text, ::integer, etc. (including complex types like character varying)
689        static RE_CAST: Lazy<Regex> = Lazy::new(|| {
690            Regex::new(r"::[a-zA-Z_][a-zA-Z0-9_]*(?:\s+[a-zA-Z_][a-zA-Z0-9_]*)*").unwrap()
691        });
692
693        RE_CAST.replace_all(stmt, "").to_string()
694    }
695
696    /// Convert nextval('sequence') to NULL or remove (AUTO_INCREMENT handles it)
697    fn convert_nextval(&self, stmt: &str) -> String {
698        use once_cell::sync::Lazy;
699        use regex::Regex;
700
701        // Match nextval('sequence_name'::regclass) or nextval('sequence_name')
702        // Remove the DEFAULT nextval(...) entirely - AUTO_INCREMENT is already applied
703        static RE_NEXTVAL: Lazy<Regex> =
704            Lazy::new(|| Regex::new(r"(?i)\s*DEFAULT\s+nextval\s*\([^)]+\)").unwrap());
705
706        RE_NEXTVAL.replace_all(stmt, "").to_string()
707    }
708
709    /// Convert DEFAULT now() to DEFAULT CURRENT_TIMESTAMP
710    fn convert_default_now(&self, stmt: &str) -> String {
711        use once_cell::sync::Lazy;
712        use regex::Regex;
713
714        static RE_NOW: Lazy<Regex> =
715            Lazy::new(|| Regex::new(r"(?i)\bDEFAULT\s+now\s*\(\s*\)").unwrap());
716
717        RE_NOW
718            .replace_all(stmt, "DEFAULT CURRENT_TIMESTAMP")
719            .to_string()
720    }
721
722    /// Strip schema prefix from table names (e.g., public.users -> users)
723    fn strip_schema_prefix(&self, stmt: &str) -> String {
724        use once_cell::sync::Lazy;
725        use regex::Regex;
726
727        // Match schema.table patterns (with optional quotes)
728        // Handle: public.table, "public"."table", public."table"
729        static RE_SCHEMA: Lazy<Regex> =
730            Lazy::new(|| Regex::new(r#"(?i)\b(public|pg_catalog|pg_temp)\s*\.\s*"#).unwrap());
731
732        RE_SCHEMA.replace_all(stmt, "").to_string()
733    }
734
735    /// Detect unsupported features and add warnings
736    fn detect_unsupported_features(
737        &mut self,
738        stmt: &str,
739        table_name: Option<&str>,
740    ) -> Result<(), ConvertWarning> {
741        let upper = stmt.to_uppercase();
742
743        // MySQL-specific features
744        if self.from == SqlDialect::MySql {
745            // ENUM types
746            if upper.contains("ENUM(") {
747                let warning = ConvertWarning::UnsupportedFeature {
748                    feature: format!(
749                        "ENUM type{}",
750                        table_name
751                            .map(|t| format!(" in table {}", t))
752                            .unwrap_or_default()
753                    ),
754                    suggestion: Some(
755                        "Converted to VARCHAR - consider adding CHECK constraint".to_string(),
756                    ),
757                };
758                self.warnings.add(warning.clone());
759                if self.strict {
760                    return Err(warning);
761                }
762            }
763
764            // SET types (MySQL)
765            if upper.contains("SET(") {
766                let warning = ConvertWarning::UnsupportedFeature {
767                    feature: format!(
768                        "SET type{}",
769                        table_name
770                            .map(|t| format!(" in table {}", t))
771                            .unwrap_or_default()
772                    ),
773                    suggestion: Some(
774                        "Converted to VARCHAR - SET semantics not preserved".to_string(),
775                    ),
776                };
777                self.warnings.add(warning.clone());
778                if self.strict {
779                    return Err(warning);
780                }
781            }
782
783            // UNSIGNED
784            if upper.contains("UNSIGNED") {
785                self.warnings.add(ConvertWarning::UnsupportedFeature {
786                    feature: "UNSIGNED modifier".to_string(),
787                    suggestion: Some(
788                        "Removed - consider adding CHECK constraint for non-negative values"
789                            .to_string(),
790                    ),
791                });
792            }
793        }
794
795        // PostgreSQL-specific features
796        if self.from == SqlDialect::Postgres {
797            // Array types
798            if upper.contains("[]") || upper.contains("ARRAY[") {
799                let warning = ConvertWarning::UnsupportedFeature {
800                    feature: format!(
801                        "Array type{}",
802                        table_name
803                            .map(|t| format!(" in table {}", t))
804                            .unwrap_or_default()
805                    ),
806                    suggestion: Some(
807                        "Array types not supported in target dialect - consider using JSON"
808                            .to_string(),
809                    ),
810                };
811                self.warnings.add(warning.clone());
812                if self.strict {
813                    return Err(warning);
814                }
815            }
816
817            // INHERITS
818            if upper.contains("INHERITS") {
819                let warning = ConvertWarning::UnsupportedFeature {
820                    feature: "Table inheritance (INHERITS)".to_string(),
821                    suggestion: Some(
822                        "PostgreSQL table inheritance not supported in target dialect".to_string(),
823                    ),
824                };
825                self.warnings.add(warning.clone());
826                if self.strict {
827                    return Err(warning);
828                }
829            }
830
831            // PARTITION BY
832            if upper.contains("PARTITION BY") && self.to == SqlDialect::Sqlite {
833                let warning = ConvertWarning::UnsupportedFeature {
834                    feature: "Table partitioning".to_string(),
835                    suggestion: Some("Partitioning not supported in SQLite".to_string()),
836                };
837                self.warnings.add(warning.clone());
838                if self.strict {
839                    return Err(warning);
840                }
841            }
842        }
843
844        Ok(())
845    }
846
847    /// Get collected warnings
848    pub fn warnings(&self) -> &[ConvertWarning] {
849        self.warnings.warnings()
850    }
851}
852
853/// Run the convert command
854pub fn run(config: ConvertConfig) -> anyhow::Result<ConvertStats> {
855    let mut stats = ConvertStats::default();
856
857    // Detect or use specified source dialect
858    let from_dialect = if let Some(d) = config.from_dialect {
859        d
860    } else {
861        let result = crate::parser::detect_dialect_from_file(&config.input)?;
862        if config.progress {
863            eprintln!(
864                "Auto-detected source dialect: {} (confidence: {:?})",
865                result.dialect, result.confidence
866            );
867        }
868        result.dialect
869    };
870
871    // Check for same dialect
872    if from_dialect == config.to_dialect {
873        anyhow::bail!(
874            "Source and target dialects are the same ({}). No conversion needed.",
875            from_dialect
876        );
877    }
878
879    // Get file size for progress tracking
880    let file_size = std::fs::metadata(&config.input)?.len();
881
882    let progress_bar = if config.progress {
883        let pb = ProgressBar::new(file_size);
884        pb.set_style(
885            ProgressStyle::with_template(
886                "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%) {msg}",
887            )
888            .unwrap()
889            .progress_chars("█▓▒░  ")
890            .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
891        );
892        pb.enable_steady_tick(std::time::Duration::from_millis(100));
893        pb.set_message("Converting...");
894        Some(pb)
895    } else {
896        None
897    };
898
899    // Create converter
900    let mut converter = Converter::new(from_dialect, config.to_dialect).with_strict(config.strict);
901
902    // Open input file with optional progress tracking
903    let file = File::open(&config.input)?;
904    let compression = Compression::from_path(&config.input);
905    let reader: Box<dyn Read> = if let Some(ref pb) = progress_bar {
906        let pb_clone = pb.clone();
907        let progress_reader = ProgressReader::new(file, move |bytes| {
908            pb_clone.set_position(bytes);
909        });
910        compression.wrap_reader(Box::new(progress_reader))
911    } else {
912        compression.wrap_reader(Box::new(file))
913    };
914    let mut parser = Parser::with_dialect(reader, 64 * 1024, from_dialect);
915
916    // Open output
917    let mut writer: Box<dyn Write> = if config.dry_run {
918        Box::new(std::io::sink())
919    } else {
920        match &config.output {
921            Some(path) => {
922                if let Some(parent) = path.parent() {
923                    std::fs::create_dir_all(parent)?;
924                }
925                Box::new(BufWriter::with_capacity(256 * 1024, File::create(path)?))
926            }
927            None => Box::new(BufWriter::new(std::io::stdout())),
928        }
929    };
930
931    // Write header
932    if !config.dry_run {
933        write_header(&mut writer, &config, from_dialect)?;
934    }
935
936    // Process statements
937    while let Some(stmt) = parser.read_statement()? {
938        stats.statements_processed += 1;
939
940        // Check if this is a COPY data block (follows a COPY header)
941        if converter.has_pending_copy() {
942            // This is a data block, convert it to INSERT statements
943            match converter.process_copy_data(&stmt) {
944                Ok(inserts) => {
945                    for insert in inserts {
946                        if !insert.is_empty() {
947                            stats.statements_converted += 1;
948                            if !config.dry_run {
949                                writer.write_all(&insert)?;
950                                writer.write_all(b"\n")?;
951                            }
952                        }
953                    }
954                }
955                Err(warning) => {
956                    stats.warnings.push(warning);
957                    stats.statements_skipped += 1;
958                }
959            }
960            continue;
961        }
962
963        match converter.convert_statement(&stmt) {
964            Ok(converted) => {
965                if converted.is_empty() {
966                    stats.statements_skipped += 1;
967                } else if converted == stmt {
968                    stats.statements_unchanged += 1;
969                    if !config.dry_run {
970                        writer.write_all(&converted)?;
971                        writer.write_all(b"\n")?;
972                    }
973                } else {
974                    stats.statements_converted += 1;
975                    if !config.dry_run {
976                        writer.write_all(&converted)?;
977                        writer.write_all(b"\n")?;
978                    }
979                }
980            }
981            Err(warning) => {
982                stats.warnings.push(warning);
983                stats.statements_skipped += 1;
984            }
985        }
986    }
987
988    // Collect warnings
989    stats.warnings.extend(converter.warnings().iter().cloned());
990
991    if let Some(pb) = progress_bar {
992        pb.finish_with_message("done");
993    }
994
995    Ok(stats)
996}
997
998/// Write output header
999fn write_header(
1000    writer: &mut dyn Write,
1001    config: &ConvertConfig,
1002    from: SqlDialect,
1003) -> std::io::Result<()> {
1004    writeln!(writer, "-- Converted by sql-splitter")?;
1005    writeln!(writer, "-- From: {} → To: {}", from, config.to_dialect)?;
1006    writeln!(writer, "-- Source: {}", config.input.display())?;
1007    writeln!(writer)?;
1008
1009    // Write dialect-specific header
1010    match config.to_dialect {
1011        SqlDialect::Postgres => {
1012            writeln!(writer, "SET client_encoding = 'UTF8';")?;
1013            writeln!(writer, "SET standard_conforming_strings = on;")?;
1014        }
1015        SqlDialect::Sqlite => {
1016            writeln!(writer, "PRAGMA foreign_keys = OFF;")?;
1017        }
1018        SqlDialect::MySql => {
1019            writeln!(writer, "SET NAMES utf8mb4;")?;
1020            writeln!(writer, "SET FOREIGN_KEY_CHECKS = 0;")?;
1021        }
1022    }
1023    writeln!(writer)?;
1024
1025    Ok(())
1026}