// sql_splitter/splitter/mod.rs

1use crate::parser::{determine_buffer_size, ContentFilter, Parser, SqlDialect, StatementType};
2use crate::writer::WriterPool;
3use ahash::AHashSet;
4use std::fs::File;
5use std::io::Read;
6use std::path::{Path, PathBuf};
7
/// Summary statistics collected during a split run.
#[derive(Debug, Clone, Default)]
pub struct Stats {
    /// Number of SQL statements routed to an output table.
    pub statements_processed: u64,
    /// Number of distinct tables encountered.
    pub tables_found: usize,
    /// Total size in bytes of the statements processed.
    pub bytes_processed: u64,
    /// Distinct table names, in first-seen order.
    pub table_names: Vec<String>,
}
14
/// Tunable options for a [`Splitter`] run, populated via the `with_*` builders.
#[derive(Default)]
pub struct SplitterConfig {
    /// SQL dialect used when parsing statements.
    pub dialect: SqlDialect,
    /// When true, collect statistics without writing any output files.
    pub dry_run: bool,
    /// If set, only statements belonging to these tables are processed.
    pub table_filter: Option<AHashSet<String>>,
    /// Optional callback invoked with the cumulative bytes read from the input.
    pub progress_fn: Option<Box<dyn Fn(u64)>>,
    /// Selects schema statements, data statements, or both.
    pub content_filter: ContentFilter,
}
23
/// Compression format detected from file extension
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Compression {
    /// Plain, uncompressed input.
    None,
    /// gzip (`.gz` / `.gzip`)
    Gzip,
    /// bzip2 (`.bz2` / `.bzip2`)
    Bzip2,
    /// xz / LZMA (`.xz` / `.lzma`)
    Xz,
    /// Zstandard (`.zst` / `.zstd`)
    Zstd,
}
33
34impl Compression {
35    /// Detect compression format from file extension
36    pub fn from_path(path: &Path) -> Self {
37        let ext = path
38            .extension()
39            .and_then(|e| e.to_str())
40            .map(|e| e.to_lowercase());
41
42        match ext.as_deref() {
43            Some("gz" | "gzip") => Compression::Gzip,
44            Some("bz2" | "bzip2") => Compression::Bzip2,
45            Some("xz" | "lzma") => Compression::Xz,
46            Some("zst" | "zstd") => Compression::Zstd,
47            _ => Compression::None,
48        }
49    }
50
51    /// Wrap a reader with the appropriate decompressor
52    pub fn wrap_reader<'a>(&self, reader: Box<dyn Read + 'a>) -> Box<dyn Read + 'a> {
53        match self {
54            Compression::None => reader,
55            Compression::Gzip => Box::new(flate2::read::GzDecoder::new(reader)),
56            Compression::Bzip2 => Box::new(bzip2::read::BzDecoder::new(reader)),
57            Compression::Xz => Box::new(xz2::read::XzDecoder::new(reader)),
58            Compression::Zstd => Box::new(zstd::stream::read::Decoder::new(reader).unwrap()),
59        }
60    }
61}
62
63impl std::fmt::Display for Compression {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        match self {
66            Compression::None => write!(f, "none"),
67            Compression::Gzip => write!(f, "gzip"),
68            Compression::Bzip2 => write!(f, "bzip2"),
69            Compression::Xz => write!(f, "xz"),
70            Compression::Zstd => write!(f, "zstd"),
71        }
72    }
73}
74
/// Splits a (possibly compressed) SQL dump into one output file per table.
pub struct Splitter {
    /// Path to the input dump; compression is inferred from its extension.
    input_file: PathBuf,
    /// Directory where per-table `.sql` files are written.
    output_dir: PathBuf,
    /// Run options, populated via the `with_*` builder methods.
    config: SplitterConfig,
}
80
81impl Splitter {
82    pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
83        Self {
84            input_file,
85            output_dir,
86            config: SplitterConfig::default(),
87        }
88    }
89
90    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
91        self.config.dialect = dialect;
92        self
93    }
94
95    pub fn with_dry_run(mut self, dry_run: bool) -> Self {
96        self.config.dry_run = dry_run;
97        self
98    }
99
100    pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
101        if !tables.is_empty() {
102            self.config.table_filter = Some(tables.into_iter().collect());
103        }
104        self
105    }
106
107    pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
108        self.config.progress_fn = Some(Box::new(f));
109        self
110    }
111
112    pub fn with_content_filter(mut self, filter: ContentFilter) -> Self {
113        self.config.content_filter = filter;
114        self
115    }
116
117    pub fn split(self) -> anyhow::Result<Stats> {
118        let file = File::open(&self.input_file)?;
119        let file_size = file.metadata()?.len();
120        let buffer_size = determine_buffer_size(file_size);
121        let dialect = self.config.dialect;
122        let content_filter = self.config.content_filter;
123
124        // Detect and apply decompression
125        let compression = Compression::from_path(&self.input_file);
126
127        let reader: Box<dyn Read> = if self.config.progress_fn.is_some() {
128            let progress_reader = ProgressReader::new(file, self.config.progress_fn.unwrap());
129            compression.wrap_reader(Box::new(progress_reader))
130        } else {
131            compression.wrap_reader(Box::new(file))
132        };
133
134        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
135
136        let mut writer_pool = WriterPool::new(self.output_dir.clone());
137        if !self.config.dry_run {
138            writer_pool.ensure_output_dir()?;
139        }
140
141        let mut tables_seen: AHashSet<String> = AHashSet::new();
142        let mut stats = Stats {
143            statements_processed: 0,
144            tables_found: 0,
145            bytes_processed: 0,
146            table_names: Vec::new(),
147        };
148
149        // Track the last COPY table for PostgreSQL COPY data blocks
150        let mut last_copy_table: Option<String> = None;
151
152        while let Some(stmt) = parser.read_statement()? {
153            let (stmt_type, mut table_name) =
154                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
155
156            // Track COPY statements for data association
157            if stmt_type == StatementType::Copy {
158                last_copy_table = Some(table_name.clone());
159            }
160
161            // Handle PostgreSQL COPY data blocks - associate with last COPY table
162            let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
163                // Check if this looks like COPY data (ends with \.\n)
164                if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
165                    table_name = last_copy_table.take().unwrap();
166                    true
167                } else {
168                    false
169                }
170            } else {
171                false
172            };
173
174            if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
175                continue;
176            }
177
178            // Apply content filter (schema-only or data-only)
179            match content_filter {
180                ContentFilter::SchemaOnly => {
181                    if !stmt_type.is_schema() {
182                        continue;
183                    }
184                }
185                ContentFilter::DataOnly => {
186                    // For data-only, include INSERT, COPY, and COPY data blocks
187                    if !stmt_type.is_data() && !is_copy_data {
188                        continue;
189                    }
190                }
191                ContentFilter::All => {}
192            }
193
194            if let Some(ref filter) = self.config.table_filter {
195                if !filter.contains(&table_name) {
196                    continue;
197                }
198            }
199
200            if !tables_seen.contains(&table_name) {
201                tables_seen.insert(table_name.clone());
202                stats.tables_found += 1;
203                stats.table_names.push(table_name.clone());
204            }
205
206            if !self.config.dry_run {
207                writer_pool.write_statement(&table_name, &stmt)?;
208            }
209
210            stats.statements_processed += 1;
211            stats.bytes_processed += stmt.len() as u64;
212        }
213
214        if !self.config.dry_run {
215            writer_pool.close_all()?;
216        }
217
218        Ok(stats)
219    }
220}
221
/// `Read` adapter that reports the running total of bytes read to a callback.
struct ProgressReader<R: Read> {
    /// Underlying reader being wrapped.
    reader: R,
    /// Invoked after every successful read with the cumulative byte count.
    callback: Box<dyn Fn(u64)>,
    /// Total bytes read so far.
    bytes_read: u64,
}

impl<R: Read> ProgressReader<R> {
    /// Wrap `reader`, reporting cumulative progress to `callback`.
    fn new(reader: R, callback: Box<dyn Fn(u64)>) -> Self {
        ProgressReader {
            bytes_read: 0,
            reader,
            callback,
        }
    }
}

impl<R: Read> Read for ProgressReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // Errors propagate untouched; the callback only fires on success.
        self.reader.read(buf).map(|count| {
            self.bytes_read += count as u64;
            (self.callback)(self.bytes_read);
            count
        })
    }
}
246
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // End-to-end: a dump with two tables yields two output files and
    // per-statement/per-table stats.
    #[test]
    fn test_splitter_basic() {
        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        let output_dir = temp_dir.path().join("output");

        std::fs::write(
            &input_file,
            b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nCREATE TABLE posts (id INT);\n",
        )
        .unwrap();

        let splitter = Splitter::new(input_file, output_dir.clone());
        let stats = splitter.split().unwrap();

        assert_eq!(stats.tables_found, 2);
        assert_eq!(stats.statements_processed, 3);

        assert!(output_dir.join("users.sql").exists());
        assert!(output_dir.join("posts.sql").exists());
    }

    // Dry run still collects stats but must not create the output directory.
    #[test]
    fn test_splitter_dry_run() {
        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        let output_dir = temp_dir.path().join("output");

        std::fs::write(&input_file, b"CREATE TABLE users (id INT);").unwrap();

        let splitter = Splitter::new(input_file, output_dir.clone()).with_dry_run(true);
        let stats = splitter.split().unwrap();

        assert_eq!(stats.tables_found, 1);
        assert!(!output_dir.exists());
    }

    // Table filter: only the named tables are written; others are skipped.
    #[test]
    fn test_splitter_table_filter() {
        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        let output_dir = temp_dir.path().join("output");

        std::fs::write(
            &input_file,
            b"CREATE TABLE users (id INT);\nCREATE TABLE posts (id INT);\nCREATE TABLE orders (id INT);",
        )
        .unwrap();

        let splitter = Splitter::new(input_file, output_dir.clone())
            .with_table_filter(vec!["users".to_string(), "orders".to_string()]);
        let stats = splitter.split().unwrap();

        assert_eq!(stats.tables_found, 2);
        assert!(output_dir.join("users.sql").exists());
        assert!(!output_dir.join("posts.sql").exists());
        assert!(output_dir.join("orders.sql").exists());
    }

    // SchemaOnly: CREATE TABLE is kept, INSERTs are dropped.
    #[test]
    fn test_splitter_schema_only() {
        use crate::parser::ContentFilter;

        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        let output_dir = temp_dir.path().join("output");

        std::fs::write(
            &input_file,
            b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nINSERT INTO users VALUES (2);",
        )
        .unwrap();

        let splitter = Splitter::new(input_file, output_dir.clone())
            .with_content_filter(ContentFilter::SchemaOnly);
        let stats = splitter.split().unwrap();

        assert_eq!(stats.tables_found, 1);
        assert_eq!(stats.statements_processed, 1); // Only CREATE TABLE

        let content = std::fs::read_to_string(output_dir.join("users.sql")).unwrap();
        assert!(content.contains("CREATE TABLE"));
        assert!(!content.contains("INSERT"));
    }

    // DataOnly: INSERTs are kept, CREATE TABLE is dropped.
    #[test]
    fn test_splitter_data_only() {
        use crate::parser::ContentFilter;

        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        let output_dir = temp_dir.path().join("output");

        std::fs::write(
            &input_file,
            b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nINSERT INTO users VALUES (2);",
        )
        .unwrap();

        let splitter = Splitter::new(input_file, output_dir.clone())
            .with_content_filter(ContentFilter::DataOnly);
        let stats = splitter.split().unwrap();

        assert_eq!(stats.tables_found, 1);
        assert_eq!(stats.statements_processed, 2); // Only INSERTs

        let content = std::fs::read_to_string(output_dir.join("users.sql")).unwrap();
        assert!(!content.contains("CREATE TABLE"));
        assert!(content.contains("INSERT"));
    }

    // A .gz input is transparently decompressed based on its extension.
    #[test]
    fn test_splitter_gzip_compressed() {
        use flate2::write::GzEncoder;
        use flate2::Compression as GzCompression;
        use std::io::Write;

        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql.gz");
        let output_dir = temp_dir.path().join("output");

        // Create gzipped SQL file
        let file = std::fs::File::create(&input_file).unwrap();
        let mut encoder = GzEncoder::new(file, GzCompression::default());
        encoder
            .write_all(b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);")
            .unwrap();
        encoder.finish().unwrap();

        let splitter = Splitter::new(input_file, output_dir.clone());
        let stats = splitter.split().unwrap();

        assert_eq!(stats.tables_found, 1);
        assert_eq!(stats.statements_processed, 2);
        assert!(output_dir.join("users.sql").exists());
    }

    // Extension-based detection covers each supported format plus plain SQL.
    #[test]
    fn test_compression_detection() {
        assert_eq!(
            Compression::from_path(std::path::Path::new("file.sql")),
            Compression::None
        );
        assert_eq!(
            Compression::from_path(std::path::Path::new("file.sql.gz")),
            Compression::Gzip
        );
        assert_eq!(
            Compression::from_path(std::path::Path::new("file.sql.bz2")),
            Compression::Bzip2
        );
        assert_eq!(
            Compression::from_path(std::path::Path::new("file.sql.xz")),
            Compression::Xz
        );
        assert_eq!(
            Compression::from_path(std::path::Path::new("file.sql.zst")),
            Compression::Zstd
        );
    }
}