// sql_splitter/splitter/mod.rs

use std::fs::File;
use std::io::Read;
use std::path::{Path, PathBuf};

use ahash::AHashSet;
use anyhow::Context;
use serde::Serialize;

use crate::parser::{determine_buffer_size, ContentFilter, Parser, SqlDialect, StatementType};
use crate::progress::ProgressReader;
use crate::writer::WriterPool;

11#[derive(Serialize)]
13pub struct Stats {
14 pub statements_processed: u64,
16 pub tables_found: usize,
18 pub bytes_processed: u64,
20 pub table_names: Vec<String>,
22}
23
24#[derive(Default)]
26pub struct SplitterConfig {
27 pub dialect: SqlDialect,
29 pub dry_run: bool,
31 pub table_filter: Option<AHashSet<String>>,
33 pub progress_fn: Option<Box<dyn Fn(u64)>>,
35 pub content_filter: ContentFilter,
37}
38
/// Compression format of an input dump, inferred from its file extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Compression {
    /// Plain, uncompressed input.
    None,
    /// gzip (`.gz`, `.gzip`).
    Gzip,
    /// bzip2 (`.bz2`, `.bzip2`).
    Bzip2,
    /// xz (`.xz`, `.lzma`).
    Xz,
    /// Zstandard (`.zst`, `.zstd`).
    Zstd,
}

impl Compression {
    /// Infers the compression format from `path`'s extension, case-insensitively.
    ///
    /// Only the final extension is inspected, so `dump.sql.gz` maps to
    /// [`Compression::Gzip`]. A missing, non-UTF-8 or unrecognised extension maps to
    /// [`Compression::None`].
    pub fn from_path(path: &Path) -> Self {
        let ext = path
            .extension()
            .and_then(|e| e.to_str())
            .map(|e| e.to_lowercase());

        match ext.as_deref() {
            Some("gz" | "gzip") => Compression::Gzip,
            Some("bz2" | "bzip2") => Compression::Bzip2,
            Some("xz" | "lzma") => Compression::Xz,
            Some("zst" | "zstd") => Compression::Zstd,
            _ => Compression::None,
        }
    }

    /// Wraps `reader` in a streaming decompressor for this format.
    ///
    /// [`Compression::None`] returns `reader` unchanged. Inputs made of several
    /// concatenated members/streams (e.g. `cat a.gz b.gz`, `pbzip2` output,
    /// concatenated `.xz` streams) are decoded in full rather than stopping after
    /// the first one.
    ///
    /// # Errors
    ///
    /// Without the `compression` feature, every compressed format yields an
    /// [`std::io::ErrorKind::Unsupported`] error. With it, zstd decoder
    /// initialisation failures are propagated.
    pub fn wrap_reader<'a>(
        &self,
        reader: Box<dyn Read + 'a>,
    ) -> std::io::Result<Box<dyn Read + 'a>> {
        Ok(match self {
            Compression::None => reader,
            // `GzDecoder` stops after the first gzip member, silently truncating
            // concatenated archives; `MultiGzDecoder` reads every member.
            #[cfg(feature = "compression")]
            Compression::Gzip => Box::new(flate2::read::MultiGzDecoder::new(reader)),
            // `BzDecoder` likewise stops after the first stream, which drops most of
            // a file produced by parallel compressors such as pbzip2.
            #[cfg(feature = "compression")]
            Compression::Bzip2 => Box::new(bzip2::read::MultiBzDecoder::new(reader)),
            // `new_multi_decoder` consumes all concatenated `.xz` streams.
            // NOTE(review): `.lzma` files (legacy LZMA-alone) are routed here too, but
            // this is an `.xz`-stream decoder — confirm such inputs actually decode.
            #[cfg(feature = "compression")]
            Compression::Xz => Box::new(xz2::read::XzDecoder::new_multi_decoder(reader)),
            #[cfg(feature = "compression")]
            Compression::Zstd => Box::new(zstd::stream::read::Decoder::new(reader)?),
            #[cfg(not(feature = "compression"))]
            _ => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::Unsupported,
                    "compressed input requires the `compression` feature",
                ))
            }
        })
    }
}

impl std::fmt::Display for Compression {
    /// Writes the lowercase format name, e.g. `gzip` or `none`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Compression::None => "none",
            Compression::Gzip => "gzip",
            Compression::Bzip2 => "bzip2",
            Compression::Xz => "xz",
            Compression::Zstd => "zstd",
        };
        f.write_str(name)
    }
}

104pub struct Splitter {
105 input_file: PathBuf,
106 output_dir: PathBuf,
107 config: SplitterConfig,
108}
109
110impl Splitter {
111 pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
112 Self {
113 input_file,
114 output_dir,
115 config: SplitterConfig::default(),
116 }
117 }
118
119 pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
120 self.config.dialect = dialect;
121 self
122 }
123
124 pub fn with_dry_run(mut self, dry_run: bool) -> Self {
125 self.config.dry_run = dry_run;
126 self
127 }
128
129 pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
130 if !tables.is_empty() {
131 self.config.table_filter = Some(tables.into_iter().collect());
132 }
133 self
134 }
135
136 pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
137 self.config.progress_fn = Some(Box::new(f));
138 self
139 }
140
141 pub fn with_content_filter(mut self, filter: ContentFilter) -> Self {
142 self.config.content_filter = filter;
143 self
144 }
145
146 pub fn split(mut self) -> anyhow::Result<Stats> {
147 let file = File::open(&self.input_file)
148 .with_context(|| format!("Failed to open input file: {:?}", self.input_file))?;
149 let file_size = file.metadata()?.len();
150 let buffer_size = determine_buffer_size(file_size);
151 let dialect = self.config.dialect;
152 let content_filter = self.config.content_filter;
153
154 let compression = Compression::from_path(&self.input_file);
156
157 let reader: Box<dyn Read> = if let Some(cb) = self.config.progress_fn.take() {
158 let progress_reader = ProgressReader::new(file, cb);
159 compression
160 .wrap_reader(Box::new(progress_reader))
161 .with_context(|| {
162 format!(
163 "Failed to initialize {} decompression for {:?}",
164 compression, self.input_file
165 )
166 })?
167 } else {
168 compression.wrap_reader(Box::new(file)).with_context(|| {
169 format!(
170 "Failed to initialize {} decompression for {:?}",
171 compression, self.input_file
172 )
173 })?
174 };
175
176 let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
177
178 let mut writer_pool = WriterPool::new(self.output_dir.clone());
179 if !self.config.dry_run {
180 writer_pool.ensure_output_dir().with_context(|| {
181 format!("Failed to create output directory: {:?}", self.output_dir)
182 })?;
183 }
184
185 let mut tables_seen: AHashSet<String> = AHashSet::new();
186 let mut stats = Stats {
187 statements_processed: 0,
188 tables_found: 0,
189 bytes_processed: 0,
190 table_names: Vec::new(),
191 };
192
193 let mut last_copy_table: Option<String> = None;
195
196 while let Some(stmt) = parser.read_statement()? {
197 let (stmt_type, mut table_name) =
198 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
199
200 if stmt_type == StatementType::Copy {
202 last_copy_table = Some(table_name.clone());
203 }
204
205 let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
207 if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
209 if let Some(copy_table) = last_copy_table.take() {
211 table_name = copy_table;
212 true
213 } else {
214 false
215 }
216 } else {
217 false
218 }
219 } else {
220 false
221 };
222
223 if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
224 continue;
225 }
226
227 match content_filter {
229 ContentFilter::SchemaOnly => {
230 if !stmt_type.is_schema() {
231 continue;
232 }
233 }
234 ContentFilter::DataOnly => {
235 if !stmt_type.is_data() && !is_copy_data {
237 continue;
238 }
239 }
240 ContentFilter::All => {}
241 }
242
243 if let Some(ref filter) = self.config.table_filter {
244 if !filter.contains(&table_name) {
245 continue;
246 }
247 }
248
249 if !tables_seen.contains(&table_name) {
250 tables_seen.insert(table_name.clone());
251 stats.tables_found += 1;
252 stats.table_names.push(table_name.clone());
253 }
254
255 if !self.config.dry_run {
256 let write_result = if self.config.dialect == SqlDialect::Mssql {
259 let trimmed = stmt
260 .iter()
261 .rev()
262 .find(|&&b| b != b'\n' && b != b'\r' && b != b' ' && b != b'\t');
263 if trimmed != Some(&b';') {
264 writer_pool.write_statement_with_suffix(&table_name, &stmt, b";")
266 } else {
267 writer_pool.write_statement(&table_name, &stmt)
268 }
269 } else {
270 writer_pool.write_statement(&table_name, &stmt)
271 };
272 write_result.with_context(|| {
273 format!("Failed to write statement to table file: {}", table_name)
274 })?;
275 }
276
277 stats.statements_processed += 1;
278 stats.bytes_processed += stmt.len() as u64;
279 }
280
281 if !self.config.dry_run {
282 writer_pool.close_all()?;
283 }
284
285 Ok(stats)
286 }
287}