// sql_splitter/splitter/mod.rs
use crate::parser::{determine_buffer_size, ContentFilter, Parser, SqlDialect, StatementType};
2use crate::progress::ProgressReader;
3use crate::writer::WriterPool;
4use ahash::AHashSet;
5use anyhow::Context;
6use serde::Serialize;
7use std::fs::File;
8use std::io::Read;
9use std::path::{Path, PathBuf};
10
/// Summary of a completed (or dry-run) split, serializable for JSON reporting.
#[derive(Serialize)]
pub struct Stats {
    /// Number of statements that passed all filters and were counted/written.
    pub statements_processed: u64,
    /// Number of distinct tables encountered (mirrors `table_names.len()`).
    pub tables_found: usize,
    /// Total size in bytes of the processed statements (after decompression).
    pub bytes_processed: u64,
    /// Distinct table names in the order they were first seen.
    pub table_names: Vec<String>,
}
23
/// Runtime options for a `Splitter`, populated via its `with_*` builder methods.
#[derive(Default)]
pub struct SplitterConfig {
    /// SQL dialect used when parsing statements.
    pub dialect: SqlDialect,
    /// When true, collect statistics without writing any output files.
    pub dry_run: bool,
    /// If set, only statements attributed to these tables are emitted.
    pub table_filter: Option<AHashSet<String>>,
    /// Optional callback invoked with cumulative bytes read from the input file.
    pub progress_fn: Option<Box<dyn Fn(u64)>>,
    /// Selects schema-only, data-only, or all statements.
    pub content_filter: ContentFilter,
}
38
/// Compression container format, detected from the input file's extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Compression {
    None,
    Gzip,
    Bzip2,
    Xz,
    Zstd,
}
48
49impl Compression {
50 pub fn from_path(path: &Path) -> Self {
52 let ext = path
53 .extension()
54 .and_then(|e| e.to_str())
55 .map(|e| e.to_lowercase());
56
57 match ext.as_deref() {
58 Some("gz" | "gzip") => Compression::Gzip,
59 Some("bz2" | "bzip2") => Compression::Bzip2,
60 Some("xz" | "lzma") => Compression::Xz,
61 Some("zst" | "zstd") => Compression::Zstd,
62 _ => Compression::None,
63 }
64 }
65
66 pub fn wrap_reader<'a>(
68 &self,
69 reader: Box<dyn Read + 'a>,
70 ) -> std::io::Result<Box<dyn Read + 'a>> {
71 Ok(match self {
72 Compression::None => reader,
73 Compression::Gzip => Box::new(flate2::read::GzDecoder::new(reader)),
74 Compression::Bzip2 => Box::new(bzip2::read::BzDecoder::new(reader)),
75 Compression::Xz => Box::new(xz2::read::XzDecoder::new(reader)),
76 Compression::Zstd => Box::new(zstd::stream::read::Decoder::new(reader)?),
77 })
78 }
79}
80
81impl std::fmt::Display for Compression {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 match self {
84 Compression::None => write!(f, "none"),
85 Compression::Gzip => write!(f, "gzip"),
86 Compression::Bzip2 => write!(f, "bzip2"),
87 Compression::Xz => write!(f, "xz"),
88 Compression::Zstd => write!(f, "zstd"),
89 }
90 }
91}
92
/// Streams a (possibly compressed) SQL dump and splits it into per-table files.
pub struct Splitter {
    input_file: PathBuf,  // source dump; compression is inferred from its extension
    output_dir: PathBuf,  // destination directory for the per-table output files
    config: SplitterConfig,
}
98
99impl Splitter {
100 pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
101 Self {
102 input_file,
103 output_dir,
104 config: SplitterConfig::default(),
105 }
106 }
107
108 pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
109 self.config.dialect = dialect;
110 self
111 }
112
113 pub fn with_dry_run(mut self, dry_run: bool) -> Self {
114 self.config.dry_run = dry_run;
115 self
116 }
117
118 pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
119 if !tables.is_empty() {
120 self.config.table_filter = Some(tables.into_iter().collect());
121 }
122 self
123 }
124
125 pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
126 self.config.progress_fn = Some(Box::new(f));
127 self
128 }
129
130 pub fn with_content_filter(mut self, filter: ContentFilter) -> Self {
131 self.config.content_filter = filter;
132 self
133 }
134
135 pub fn split(mut self) -> anyhow::Result<Stats> {
136 let file = File::open(&self.input_file)
137 .with_context(|| format!("Failed to open input file: {:?}", self.input_file))?;
138 let file_size = file.metadata()?.len();
139 let buffer_size = determine_buffer_size(file_size);
140 let dialect = self.config.dialect;
141 let content_filter = self.config.content_filter;
142
143 let compression = Compression::from_path(&self.input_file);
145
146 let reader: Box<dyn Read> = if let Some(cb) = self.config.progress_fn.take() {
147 let progress_reader = ProgressReader::new(file, cb);
148 compression
149 .wrap_reader(Box::new(progress_reader))
150 .with_context(|| {
151 format!(
152 "Failed to initialize {} decompression for {:?}",
153 compression, self.input_file
154 )
155 })?
156 } else {
157 compression.wrap_reader(Box::new(file)).with_context(|| {
158 format!(
159 "Failed to initialize {} decompression for {:?}",
160 compression, self.input_file
161 )
162 })?
163 };
164
165 let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
166
167 let mut writer_pool = WriterPool::new(self.output_dir.clone());
168 if !self.config.dry_run {
169 writer_pool.ensure_output_dir().with_context(|| {
170 format!("Failed to create output directory: {:?}", self.output_dir)
171 })?;
172 }
173
174 let mut tables_seen: AHashSet<String> = AHashSet::new();
175 let mut stats = Stats {
176 statements_processed: 0,
177 tables_found: 0,
178 bytes_processed: 0,
179 table_names: Vec::new(),
180 };
181
182 let mut last_copy_table: Option<String> = None;
184
185 while let Some(stmt) = parser.read_statement()? {
186 let (stmt_type, mut table_name) =
187 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
188
189 if stmt_type == StatementType::Copy {
191 last_copy_table = Some(table_name.clone());
192 }
193
194 let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
196 if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
198 if let Some(copy_table) = last_copy_table.take() {
200 table_name = copy_table;
201 true
202 } else {
203 false
204 }
205 } else {
206 false
207 }
208 } else {
209 false
210 };
211
212 if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
213 continue;
214 }
215
216 match content_filter {
218 ContentFilter::SchemaOnly => {
219 if !stmt_type.is_schema() {
220 continue;
221 }
222 }
223 ContentFilter::DataOnly => {
224 if !stmt_type.is_data() && !is_copy_data {
226 continue;
227 }
228 }
229 ContentFilter::All => {}
230 }
231
232 if let Some(ref filter) = self.config.table_filter {
233 if !filter.contains(&table_name) {
234 continue;
235 }
236 }
237
238 if !tables_seen.contains(&table_name) {
239 tables_seen.insert(table_name.clone());
240 stats.tables_found += 1;
241 stats.table_names.push(table_name.clone());
242 }
243
244 if !self.config.dry_run {
245 let write_result = if self.config.dialect == SqlDialect::Mssql {
248 let trimmed = stmt
249 .iter()
250 .rev()
251 .find(|&&b| b != b'\n' && b != b'\r' && b != b' ' && b != b'\t');
252 if trimmed != Some(&b';') {
253 writer_pool.write_statement_with_suffix(&table_name, &stmt, b";")
255 } else {
256 writer_pool.write_statement(&table_name, &stmt)
257 }
258 } else {
259 writer_pool.write_statement(&table_name, &stmt)
260 };
261 write_result.with_context(|| {
262 format!("Failed to write statement to table file: {}", table_name)
263 })?;
264 }
265
266 stats.statements_processed += 1;
267 stats.bytes_processed += stmt.len() as u64;
268 }
269
270 if !self.config.dry_run {
271 writer_pool.close_all()?;
272 }
273
274 Ok(stats)
275 }
276}