// sql_splitter/splitter/mod.rs

1use crate::parser::{determine_buffer_size, ContentFilter, Parser, SqlDialect, StatementType};
2use crate::progress::ProgressReader;
3use crate::writer::WriterPool;
4use ahash::AHashSet;
5use std::fs::File;
6use std::io::Read;
7use std::path::{Path, PathBuf};
8
/// Summary of a completed (or dry-run) split operation.
#[derive(Debug, Clone, Default)]
pub struct Stats {
    /// Number of statements that passed all filters and were routed to a table.
    pub statements_processed: u64,
    /// Count of distinct tables encountered among emitted statements.
    pub tables_found: usize,
    /// Total bytes of the emitted statements (after decompression).
    pub bytes_processed: u64,
    /// Distinct table names, in first-seen order.
    pub table_names: Vec<String>,
}
15
/// Options controlling a [`Splitter`] run; normally set via the `with_*`
/// builder methods rather than constructed directly.
#[derive(Default)]
pub struct SplitterConfig {
    /// SQL dialect used when parsing statements.
    pub dialect: SqlDialect,
    /// When true, statements are parsed and counted but nothing is written.
    pub dry_run: bool,
    /// If set, only statements attributed to these table names are emitted.
    pub table_filter: Option<AHashSet<String>>,
    /// Optional progress callback, handed to `ProgressReader` during `split`.
    /// NOTE(review): presumably invoked with bytes read — confirm in the
    /// `progress` module.
    pub progress_fn: Option<Box<dyn Fn(u64)>>,
    /// Selects schema-only, data-only, or all statements.
    pub content_filter: ContentFilter,
}
24
/// Compression format detected from file extension
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Compression {
    /// No compression; input is read as-is.
    None,
    /// Gzip (`.gz` / `.gzip`), decoded via `flate2`.
    Gzip,
    /// Bzip2 (`.bz2` / `.bzip2`), decoded via `bzip2`.
    Bzip2,
    /// XZ/LZMA (`.xz` / `.lzma`), decoded via `xz2`.
    Xz,
    /// Zstandard (`.zst` / `.zstd`), decoded via `zstd`.
    Zstd,
}
34
35impl Compression {
36    /// Detect compression format from file extension
37    pub fn from_path(path: &Path) -> Self {
38        let ext = path
39            .extension()
40            .and_then(|e| e.to_str())
41            .map(|e| e.to_lowercase());
42
43        match ext.as_deref() {
44            Some("gz" | "gzip") => Compression::Gzip,
45            Some("bz2" | "bzip2") => Compression::Bzip2,
46            Some("xz" | "lzma") => Compression::Xz,
47            Some("zst" | "zstd") => Compression::Zstd,
48            _ => Compression::None,
49        }
50    }
51
52    /// Wrap a reader with the appropriate decompressor
53    pub fn wrap_reader<'a>(&self, reader: Box<dyn Read + 'a>) -> Box<dyn Read + 'a> {
54        match self {
55            Compression::None => reader,
56            Compression::Gzip => Box::new(flate2::read::GzDecoder::new(reader)),
57            Compression::Bzip2 => Box::new(bzip2::read::BzDecoder::new(reader)),
58            Compression::Xz => Box::new(xz2::read::XzDecoder::new(reader)),
59            Compression::Zstd => Box::new(zstd::stream::read::Decoder::new(reader).unwrap()),
60        }
61    }
62}
63
64impl std::fmt::Display for Compression {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        match self {
67            Compression::None => write!(f, "none"),
68            Compression::Gzip => write!(f, "gzip"),
69            Compression::Bzip2 => write!(f, "bzip2"),
70            Compression::Xz => write!(f, "xz"),
71            Compression::Zstd => write!(f, "zstd"),
72        }
73    }
74}
75
/// Streams a SQL dump file and splits its statements into per-table outputs.
pub struct Splitter {
    /// Path to the (possibly compressed) SQL dump to read.
    input_file: PathBuf,
    /// Directory where per-table output files are created.
    output_dir: PathBuf,
    /// Runtime options, populated via the `with_*` builder methods.
    config: SplitterConfig,
}
81
82impl Splitter {
83    pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
84        Self {
85            input_file,
86            output_dir,
87            config: SplitterConfig::default(),
88        }
89    }
90
91    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
92        self.config.dialect = dialect;
93        self
94    }
95
96    pub fn with_dry_run(mut self, dry_run: bool) -> Self {
97        self.config.dry_run = dry_run;
98        self
99    }
100
101    pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
102        if !tables.is_empty() {
103            self.config.table_filter = Some(tables.into_iter().collect());
104        }
105        self
106    }
107
108    pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
109        self.config.progress_fn = Some(Box::new(f));
110        self
111    }
112
113    pub fn with_content_filter(mut self, filter: ContentFilter) -> Self {
114        self.config.content_filter = filter;
115        self
116    }
117
    /// Execute the split, consuming the splitter.
    ///
    /// Streams the input file (decompressing when the extension indicates a
    /// known format), parses it into statements, and routes each statement to
    /// a per-table writer. In dry-run mode nothing is written; statements are
    /// only counted. Returns aggregate [`Stats`].
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from opening/reading the input and from
    /// creating the output directory or writing/closing output files.
    pub fn split(mut self) -> anyhow::Result<Stats> {
        let file = File::open(&self.input_file)?;
        // Parser buffer is sized from the on-disk (possibly compressed) size.
        let file_size = file.metadata()?.len();
        let buffer_size = determine_buffer_size(file_size);
        let dialect = self.config.dialect;
        let content_filter = self.config.content_filter;

        // Detect and apply decompression
        let compression = Compression::from_path(&self.input_file);

        // The progress callback observes the raw file stream (i.e. before
        // decompression); the decompressor reads through the progress reader.
        let reader: Box<dyn Read> = if let Some(cb) = self.config.progress_fn.take() {
            let progress_reader = ProgressReader::new(file, cb);
            compression.wrap_reader(Box::new(progress_reader))
        } else {
            compression.wrap_reader(Box::new(file))
        };

        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);

        let mut writer_pool = WriterPool::new(self.output_dir.clone());
        if !self.config.dry_run {
            writer_pool.ensure_output_dir()?;
        }

        // Dedupe set so each table is recorded in the stats only once.
        let mut tables_seen: AHashSet<String> = AHashSet::new();
        let mut stats = Stats {
            statements_processed: 0,
            tables_found: 0,
            bytes_processed: 0,
            table_names: Vec::new(),
        };

        // Track the last COPY table for PostgreSQL COPY data blocks
        let mut last_copy_table: Option<String> = None;

        while let Some(stmt) = parser.read_statement()? {
            let (stmt_type, mut table_name) =
                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);

            // Track COPY statements for data association
            if stmt_type == StatementType::Copy {
                last_copy_table = Some(table_name.clone());
            }

            // Handle PostgreSQL COPY data blocks - associate with last COPY table
            let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
                // Check if this looks like COPY data (ends with \.\n)
                if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                    // `unwrap` is safe: `is_some()` was checked just above.
                    // `take` clears the association so later Unknown
                    // statements are not misattributed to this table.
                    table_name = last_copy_table.take().unwrap();
                    true
                } else {
                    false
                }
            } else {
                false
            };

            // Drop anything we cannot attribute to a table; COPY data blocks
            // are exempt because their table came from `last_copy_table`.
            if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
                continue;
            }

            // Apply content filter (schema-only or data-only)
            match content_filter {
                ContentFilter::SchemaOnly => {
                    if !stmt_type.is_schema() {
                        continue;
                    }
                }
                ContentFilter::DataOnly => {
                    // For data-only, include INSERT, COPY, and COPY data blocks
                    if !stmt_type.is_data() && !is_copy_data {
                        continue;
                    }
                }
                ContentFilter::All => {}
            }

            // Honor an explicit table whitelist, if configured.
            if let Some(ref filter) = self.config.table_filter {
                if !filter.contains(&table_name) {
                    continue;
                }
            }

            // First sighting of this table: record it.
            if !tables_seen.contains(&table_name) {
                tables_seen.insert(table_name.clone());
                stats.tables_found += 1;
                stats.table_names.push(table_name.clone());
            }

            if !self.config.dry_run {
                writer_pool.write_statement(&table_name, &stmt)?;
            }

            // Stats count only statements that survived every filter above.
            stats.statements_processed += 1;
            stats.bytes_processed += stmt.len() as u64;
        }

        if !self.config.dry_run {
            writer_pool.close_all()?;
        }

        Ok(stats)
    }
221}