Skip to main content

sql_splitter/splitter/
mod.rs

1use crate::parser::{determine_buffer_size, ContentFilter, Parser, SqlDialect, StatementType};
2use crate::progress::ProgressReader;
3use crate::writer::WriterPool;
4use ahash::AHashSet;
5use anyhow::Context;
6use serde::Serialize;
7use std::fs::File;
8use std::io::Read;
9use std::path::{Path, PathBuf};
10
11/// Statistics from a split operation.
12#[derive(Serialize)]
13pub struct Stats {
14    /// Total statements processed.
15    pub statements_processed: u64,
16    /// Number of unique tables found.
17    pub tables_found: usize,
18    /// Total bytes processed from input.
19    pub bytes_processed: u64,
20    /// Names of all tables found.
21    pub table_names: Vec<String>,
22}
23
24/// Configuration for the splitter.
25#[derive(Default)]
26pub struct SplitterConfig {
27    /// SQL dialect for parsing.
28    pub dialect: SqlDialect,
29    /// If true, parse without writing output files.
30    pub dry_run: bool,
31    /// If set, only process tables in this set.
32    pub table_filter: Option<AHashSet<String>>,
33    /// Optional callback for progress reporting.
34    pub progress_fn: Option<Box<dyn Fn(u64)>>,
35    /// Filter for which statement types to include.
36    pub content_filter: ContentFilter,
37}
38
/// Compression format of an input file, detected from its file extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Compression {
    /// Plain, uncompressed input.
    None,
    /// gzip (`.gz`, `.gzip`).
    Gzip,
    /// bzip2 (`.bz2`, `.bzip2`).
    Bzip2,
    /// xz (`.xz`); `.lzma` is mapped here as well.
    Xz,
    /// Zstandard (`.zst`, `.zstd`).
    Zstd,
}

impl Compression {
    /// Detect the compression format from the extension of `path`, case-insensitively.
    ///
    /// Only the last extension is inspected, so `dump.sql.gz` is [`Compression::Gzip`].
    /// A missing, non-UTF-8, or unrecognised extension yields [`Compression::None`].
    pub fn from_path(path: &Path) -> Self {
        let ext = path
            .extension()
            .and_then(|e| e.to_str())
            .map(|e| e.to_lowercase());

        match ext.as_deref() {
            Some("gz" | "gzip") => Compression::Gzip,
            Some("bz2" | "bzip2") => Compression::Bzip2,
            Some("xz" | "lzma") => Compression::Xz,
            Some("zst" | "zstd") => Compression::Zstd,
            _ => Compression::None,
        }
    }

    /// Wrap `reader` in the decoder for this format.
    ///
    /// Inputs made of several concatenated members/streams (e.g. `cat a.gz b.gz`,
    /// bgzip output, pbzip2 output, concatenated `.xz` files) are decoded in full
    /// instead of stopping silently after the first member.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`std::io::ErrorKind::Unsupported`] for every compressed
    /// format when the crate is built without the `compression` feature. It also
    /// propagates any error from initialising the zstd decoder.
    pub fn wrap_reader<'a>(
        &self,
        reader: Box<dyn Read + 'a>,
    ) -> std::io::Result<Box<dyn Read + 'a>> {
        Ok(match self {
            Compression::None => reader,
            // `GzDecoder` ends after the first gzip member and drops the rest without
            // an error; `MultiGzDecoder` keeps reading concatenated members.
            #[cfg(feature = "compression")]
            Compression::Gzip => Box::new(flate2::read::MultiGzDecoder::new(reader)),
            // Same pitfall for bzip2: parallel compressors such as pbzip2 emit many streams.
            #[cfg(feature = "compression")]
            Compression::Bzip2 => Box::new(bzip2::read::MultiBzDecoder::new(reader)),
            // NOTE(review): this decoder only understands the `.xz` container; legacy
            // `.lzma` (LZMA_alone) files mapped to `Xz` by `from_path` will fail to decode.
            #[cfg(feature = "compression")]
            Compression::Xz => Box::new(xz2::read::XzDecoder::new_multi_decoder(reader)),
            // The zstd streaming decoder already continues across concatenated frames.
            #[cfg(feature = "compression")]
            Compression::Zstd => Box::new(zstd::stream::read::Decoder::new(reader)?),
            #[cfg(not(feature = "compression"))]
            _ => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::Unsupported,
                    "compressed input requires the `compression` feature",
                ))
            }
        })
    }
}

impl std::fmt::Display for Compression {
    /// Writes the lowercase format name (`none`, `gzip`, `bzip2`, `xz`, `zstd`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Compression::None => "none",
            Compression::Gzip => "gzip",
            Compression::Bzip2 => "bzip2",
            Compression::Xz => "xz",
            Compression::Zstd => "zstd",
        };
        f.write_str(name)
    }
}
103
104pub struct Splitter {
105    input_file: PathBuf,
106    output_dir: PathBuf,
107    config: SplitterConfig,
108}
109
110impl Splitter {
111    pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
112        Self {
113            input_file,
114            output_dir,
115            config: SplitterConfig::default(),
116        }
117    }
118
119    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
120        self.config.dialect = dialect;
121        self
122    }
123
124    pub fn with_dry_run(mut self, dry_run: bool) -> Self {
125        self.config.dry_run = dry_run;
126        self
127    }
128
129    pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
130        if !tables.is_empty() {
131            self.config.table_filter = Some(tables.into_iter().collect());
132        }
133        self
134    }
135
136    pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
137        self.config.progress_fn = Some(Box::new(f));
138        self
139    }
140
141    pub fn with_content_filter(mut self, filter: ContentFilter) -> Self {
142        self.config.content_filter = filter;
143        self
144    }
145
146    pub fn split(mut self) -> anyhow::Result<Stats> {
147        let file = File::open(&self.input_file)
148            .with_context(|| format!("Failed to open input file: {:?}", self.input_file))?;
149        let file_size = file.metadata()?.len();
150        let buffer_size = determine_buffer_size(file_size);
151        let dialect = self.config.dialect;
152        let content_filter = self.config.content_filter;
153
154        // Detect and apply decompression
155        let compression = Compression::from_path(&self.input_file);
156
157        let reader: Box<dyn Read> = if let Some(cb) = self.config.progress_fn.take() {
158            let progress_reader = ProgressReader::new(file, cb);
159            compression
160                .wrap_reader(Box::new(progress_reader))
161                .with_context(|| {
162                    format!(
163                        "Failed to initialize {} decompression for {:?}",
164                        compression, self.input_file
165                    )
166                })?
167        } else {
168            compression.wrap_reader(Box::new(file)).with_context(|| {
169                format!(
170                    "Failed to initialize {} decompression for {:?}",
171                    compression, self.input_file
172                )
173            })?
174        };
175
176        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
177
178        let mut writer_pool = WriterPool::new(self.output_dir.clone());
179        if !self.config.dry_run {
180            writer_pool.ensure_output_dir().with_context(|| {
181                format!("Failed to create output directory: {:?}", self.output_dir)
182            })?;
183        }
184
185        let mut tables_seen: AHashSet<String> = AHashSet::new();
186        let mut stats = Stats {
187            statements_processed: 0,
188            tables_found: 0,
189            bytes_processed: 0,
190            table_names: Vec::new(),
191        };
192
193        // Track the last COPY table for PostgreSQL COPY data blocks
194        let mut last_copy_table: Option<String> = None;
195
196        while let Some(stmt) = parser.read_statement()? {
197            let (stmt_type, mut table_name) =
198                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
199
200            // Track COPY statements for data association
201            if stmt_type == StatementType::Copy {
202                last_copy_table = Some(table_name.clone());
203            }
204
205            // Handle PostgreSQL COPY data blocks - associate with last COPY table
206            let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
207                // Check if this looks like COPY data (ends with \.\n)
208                if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
209                    // Safe: we just checked is_some() above
210                    if let Some(copy_table) = last_copy_table.take() {
211                        table_name = copy_table;
212                        true
213                    } else {
214                        false
215                    }
216                } else {
217                    false
218                }
219            } else {
220                false
221            };
222
223            if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
224                continue;
225            }
226
227            // Apply content filter (schema-only or data-only)
228            match content_filter {
229                ContentFilter::SchemaOnly => {
230                    if !stmt_type.is_schema() {
231                        continue;
232                    }
233                }
234                ContentFilter::DataOnly => {
235                    // For data-only, include INSERT, COPY, and COPY data blocks
236                    if !stmt_type.is_data() && !is_copy_data {
237                        continue;
238                    }
239                }
240                ContentFilter::All => {}
241            }
242
243            if let Some(ref filter) = self.config.table_filter {
244                if !filter.contains(&table_name) {
245                    continue;
246                }
247            }
248
249            if !tables_seen.contains(&table_name) {
250                tables_seen.insert(table_name.clone());
251                stats.tables_found += 1;
252                stats.table_names.push(table_name.clone());
253            }
254
255            if !self.config.dry_run {
256                // For MSSQL, add semicolon if statement doesn't end with one
257                // (MSSQL uses GO as batch separator, but we need semicolons for re-parsing)
258                let write_result = if self.config.dialect == SqlDialect::Mssql {
259                    let trimmed = stmt
260                        .iter()
261                        .rev()
262                        .find(|&&b| b != b'\n' && b != b'\r' && b != b' ' && b != b'\t');
263                    if trimmed != Some(&b';') {
264                        // Write statement + semicolon without cloning
265                        writer_pool.write_statement_with_suffix(&table_name, &stmt, b";")
266                    } else {
267                        writer_pool.write_statement(&table_name, &stmt)
268                    }
269                } else {
270                    writer_pool.write_statement(&table_name, &stmt)
271                };
272                write_result.with_context(|| {
273                    format!("Failed to write statement to table file: {}", table_name)
274                })?;
275            }
276
277            stats.statements_processed += 1;
278            stats.bytes_processed += stmt.len() as u64;
279        }
280
281        if !self.config.dry_run {
282            writer_pool.close_all()?;
283        }
284
285        Ok(stats)
286    }
287}