use crate::parser::{determine_buffer_size, ContentFilter, Parser, SqlDialect, StatementType};
use crate::writer::WriterPool;
use ahash::AHashSet;
use std::fs::File;
use std::io::Read;
use std::path::{Path, PathBuf};

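/// Summary statistics for a completed split run: how many statements and
/// bytes were processed and which tables were encountered.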
pub struct Stats {
    pub statements_processed: u64,
    pub tables_found: usize,
    pub bytes_processed: u64,
    pub table_names: Vec<String>,
}

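/// Options controlling how a dump is split. `Default` disables dry-run mode
/// and applies no table filter or progress callback.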
#[derive(Default)]
pub struct SplitterConfig {
    pub dialect: SqlDialect,
    pub dry_run: bool,
    pub table_filter: Option<AHashSet<String>>,
    pub progress_fn: Option<Box<dyn Fn(u64)>>,
    pub content_filter: ContentFilter,
}

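/// Compression formats recognized by the input file's extension.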
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Compression {
    None,
    Gzip,
    Bzip2,
    Xz,
    Zstd,
}

impl Compression {
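    /// Detects the compression format from the file extension
    /// (e.g. `dump.sql.gz` maps to `Gzip`); unknown or missing extensions
    /// map to `Compression::None`.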
    pub fn from_path(path: &Path) -> Self {
        let ext = path
            .extension()
            .and_then(|e| e.to_str())
            .map(|e| e.to_lowercase());

        match ext.as_deref() {
            Some("gz" | "gzip") => Compression::Gzip,
            Some("bz2" | "bzip2") => Compression::Bzip2,
            Some("xz" | "lzma") => Compression::Xz,
            Some("zst" | "zstd") => Compression::Zstd,
            _ => Compression::None,
        }
    }

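    /// Wraps `reader` in the matching streaming decoder, or returns it
    /// unchanged for `Compression::None`. Panics if the zstd decoder cannot
    /// be initialized.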
    pub fn wrap_reader<'a>(&self, reader: Box<dyn Read + 'a>) -> Box<dyn Read + 'a> {
        match self {
            Compression::None => reader,
            Compression::Gzip => Box::new(flate2::read::GzDecoder::new(reader)),
            Compression::Bzip2 => Box::new(bzip2::read::BzDecoder::new(reader)),
            Compression::Xz => Box::new(xz2::read::XzDecoder::new(reader)),
            Compression::Zstd => Box::new(
                zstd::stream::read::Decoder::new(reader)
                    .expect("failed to initialize zstd decoder"),
            ),
        }
    }
}

impl std::fmt::Display for Compression {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Compression::None => write!(f, "none"),
            Compression::Gzip => write!(f, "gzip"),
            Compression::Bzip2 => write!(f, "bzip2"),
            Compression::Xz => write!(f, "xz"),
            Compression::Zstd => write!(f, "zstd"),
        }
    }
}

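/// Builder-style entry point: reads a (possibly compressed) SQL dump and
/// writes one `<table>.sql` file per table into the output directory.
///
/// Usage sketch (paths and filter values are illustrative; marked `ignore`
/// because the crate's public re-exports are not shown in this file):
///
/// ```ignore
/// let stats = Splitter::new("dump.sql.gz".into(), "out".into())
///     .with_table_filter(vec!["users".to_string()])
///     .with_dry_run(false)
///     .split()?;
/// println!("{} tables, {} statements", stats.tables_found, stats.statements_processed);
/// ```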
pub struct Splitter {
    input_file: PathBuf,
    output_dir: PathBuf,
    config: SplitterConfig,
}

impl Splitter {
    pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
        Self {
            input_file,
            output_dir,
            config: SplitterConfig::default(),
        }
    }

    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
        self.config.dialect = dialect;
        self
    }

    pub fn with_dry_run(mut self, dry_run: bool) -> Self {
        self.config.dry_run = dry_run;
        self
    }

    pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
        if !tables.is_empty() {
            self.config.table_filter = Some(tables.into_iter().collect());
        }
        self
    }

    pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
        self.config.progress_fn = Some(Box::new(f));
        self
    }

    pub fn with_content_filter(mut self, filter: ContentFilter) -> Self {
        self.config.content_filter = filter;
        self
    }

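    /// Streams the dump statement by statement, transparently decompressing
    /// based on the input file's extension, applies the configured table and
    /// content filters, and writes each statement to its table's output file
    /// (unless dry-run is enabled). Returns the accumulated [`Stats`].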
    pub fn split(self) -> anyhow::Result<Stats> {
        let file = File::open(&self.input_file)?;
        let file_size = file.metadata()?.len();
        let buffer_size = determine_buffer_size(file_size);
        let dialect = self.config.dialect;
        let content_filter = self.config.content_filter;

        let compression = Compression::from_path(&self.input_file);

        let reader: Box<dyn Read> = match self.config.progress_fn {
            Some(progress_fn) => {
                let progress_reader = ProgressReader::new(file, progress_fn);
                compression.wrap_reader(Box::new(progress_reader))
            }
            None => compression.wrap_reader(Box::new(file)),
        };

        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);

        let mut writer_pool = WriterPool::new(self.output_dir.clone());
        if !self.config.dry_run {
            writer_pool.ensure_output_dir()?;
        }

        let mut tables_seen: AHashSet<String> = AHashSet::new();
        let mut stats = Stats {
            statements_processed: 0,
            tables_found: 0,
            bytes_processed: 0,
            table_names: Vec::new(),
        };

        let mut last_copy_table: Option<String> = None;

        while let Some(stmt) = parser.read_statement()? {
            let (stmt_type, mut table_name) =
                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);

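            // PostgreSQL COPY handling: a COPY statement is followed by a raw
            // data block that the parser reports as Unknown. Remember the most
            // recent COPY target so the data block (terminated by `\.`) can be
            // attributed to that table.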
            if stmt_type == StatementType::Copy {
                last_copy_table = Some(table_name.clone());
            }

            let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
                if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                    table_name = last_copy_table.take().unwrap();
                    true
                } else {
                    false
                }
            } else {
                false
            };

            if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
                continue;
            }

            match content_filter {
                ContentFilter::SchemaOnly => {
                    if !stmt_type.is_schema() {
                        continue;
                    }
                }
                ContentFilter::DataOnly => {
                    if !stmt_type.is_data() && !is_copy_data {
                        continue;
                    }
                }
                ContentFilter::All => {}
            }

            if let Some(ref filter) = self.config.table_filter {
                if !filter.contains(&table_name) {
                    continue;
                }
            }

            if !tables_seen.contains(&table_name) {
                tables_seen.insert(table_name.clone());
                stats.tables_found += 1;
                stats.table_names.push(table_name.clone());
            }

            if !self.config.dry_run {
                writer_pool.write_statement(&table_name, &stmt)?;
            }

            stats.statements_processed += 1;
            stats.bytes_processed += stmt.len() as u64;
        }

        if !self.config.dry_run {
            writer_pool.close_all()?;
        }

        Ok(stats)
    }
}

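/// A `Read` adapter that reports the cumulative number of bytes read from the
/// underlying (still-compressed) input file to a progress callback.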
struct ProgressReader<R: Read> {
    reader: R,
    callback: Box<dyn Fn(u64)>,
    bytes_read: u64,
}

impl<R: Read> ProgressReader<R> {
    fn new(reader: R, callback: Box<dyn Fn(u64)>) -> Self {
        Self {
            reader,
            callback,
            bytes_read: 0,
        }
    }
}

impl<R: Read> Read for ProgressReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let n = self.reader.read(buf)?;
        self.bytes_read += n as u64;
        (self.callback)(self.bytes_read);
        Ok(n)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_splitter_basic() {
        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        let output_dir = temp_dir.path().join("output");

        std::fs::write(
            &input_file,
            b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nCREATE TABLE posts (id INT);\n",
        )
        .unwrap();

        let splitter = Splitter::new(input_file, output_dir.clone());
        let stats = splitter.split().unwrap();

        assert_eq!(stats.tables_found, 2);
        assert_eq!(stats.statements_processed, 3);

        assert!(output_dir.join("users.sql").exists());
        assert!(output_dir.join("posts.sql").exists());
    }

    #[test]
    fn test_splitter_dry_run() {
        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        let output_dir = temp_dir.path().join("output");

        std::fs::write(&input_file, b"CREATE TABLE users (id INT);").unwrap();

        let splitter = Splitter::new(input_file, output_dir.clone()).with_dry_run(true);
        let stats = splitter.split().unwrap();

        assert_eq!(stats.tables_found, 1);
        assert!(!output_dir.exists());
    }

    #[test]
    fn test_splitter_table_filter() {
        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        let output_dir = temp_dir.path().join("output");

        std::fs::write(
            &input_file,
            b"CREATE TABLE users (id INT);\nCREATE TABLE posts (id INT);\nCREATE TABLE orders (id INT);",
        )
        .unwrap();

        let splitter = Splitter::new(input_file, output_dir.clone())
            .with_table_filter(vec!["users".to_string(), "orders".to_string()]);
        let stats = splitter.split().unwrap();

        assert_eq!(stats.tables_found, 2);
        assert!(output_dir.join("users.sql").exists());
        assert!(!output_dir.join("posts.sql").exists());
        assert!(output_dir.join("orders.sql").exists());
    }

    #[test]
    fn test_splitter_schema_only() {
        use crate::parser::ContentFilter;

        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        let output_dir = temp_dir.path().join("output");

        std::fs::write(
            &input_file,
            b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nINSERT INTO users VALUES (2);",
        )
        .unwrap();

        let splitter = Splitter::new(input_file, output_dir.clone())
            .with_content_filter(ContentFilter::SchemaOnly);
        let stats = splitter.split().unwrap();

        assert_eq!(stats.tables_found, 1);
        assert_eq!(stats.statements_processed, 1);

        let content = std::fs::read_to_string(output_dir.join("users.sql")).unwrap();
        assert!(content.contains("CREATE TABLE"));
        assert!(!content.contains("INSERT"));
    }

    #[test]
    fn test_splitter_data_only() {
        use crate::parser::ContentFilter;

        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        let output_dir = temp_dir.path().join("output");

        std::fs::write(
            &input_file,
            b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nINSERT INTO users VALUES (2);",
        )
        .unwrap();

        let splitter = Splitter::new(input_file, output_dir.clone())
            .with_content_filter(ContentFilter::DataOnly);
        let stats = splitter.split().unwrap();

        assert_eq!(stats.tables_found, 1);
        assert_eq!(stats.statements_processed, 2);

        let content = std::fs::read_to_string(output_dir.join("users.sql")).unwrap();
        assert!(!content.contains("CREATE TABLE"));
        assert!(content.contains("INSERT"));
    }

    #[test]
    fn test_splitter_gzip_compressed() {
        use flate2::write::GzEncoder;
        use flate2::Compression as GzCompression;
        use std::io::Write;

        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql.gz");
        let output_dir = temp_dir.path().join("output");

        let file = std::fs::File::create(&input_file).unwrap();
        let mut encoder = GzEncoder::new(file, GzCompression::default());
        encoder
            .write_all(b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);")
            .unwrap();
        encoder.finish().unwrap();

        let splitter = Splitter::new(input_file, output_dir.clone());
        let stats = splitter.split().unwrap();

        assert_eq!(stats.tables_found, 1);
        assert_eq!(stats.statements_processed, 2);
        assert!(output_dir.join("users.sql").exists());
    }

    #[test]
    fn test_compression_detection() {
        assert_eq!(
            Compression::from_path(std::path::Path::new("file.sql")),
            Compression::None
        );
        assert_eq!(
            Compression::from_path(std::path::Path::new("file.sql.gz")),
            Compression::Gzip
        );
        assert_eq!(
            Compression::from_path(std::path::Path::new("file.sql.bz2")),
            Compression::Bzip2
        );
        assert_eq!(
            Compression::from_path(std::path::Path::new("file.sql.xz")),
            Compression::Xz
        );
        assert_eq!(
            Compression::from_path(std::path::Path::new("file.sql.zst")),
            Compression::Zstd
        );
    }
}