sql_splitter/splitter/mod.rs

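//! Core splitting logic: statements are read from a (possibly compressed)
//! SQL dump one at a time, classified and filtered, then routed to a
//! per-table output file via the writer pool.
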
use crate::parser::{determine_buffer_size, ContentFilter, Parser, SqlDialect, StatementType};
use crate::progress::ProgressReader;
use crate::writer::WriterPool;
use ahash::AHashSet;
use anyhow::Context;
use serde::Serialize;
use std::fs::File;
use std::io::Read;
use std::path::{Path, PathBuf};

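/// Summary of a split run: how many statements and bytes were processed and
/// which tables were encountered.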
#[derive(Serialize)]
pub struct Stats {
    pub statements_processed: u64,
    pub tables_found: usize,
    pub bytes_processed: u64,
    pub table_names: Vec<String>,
}

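/// Options controlling a split: SQL dialect, dry-run mode, an optional table
/// filter, an optional progress callback, and a schema/data content filter.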
#[derive(Default)]
pub struct SplitterConfig {
    pub dialect: SqlDialect,
    pub dry_run: bool,
    pub table_filter: Option<AHashSet<String>>,
    pub progress_fn: Option<Box<dyn Fn(u64)>>,
    pub content_filter: ContentFilter,
}

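/// Compression format of the input dump, detected from its file extension.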
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Compression {
    None,
    Gzip,
    Bzip2,
    Xz,
    Zstd,
}

impl Compression {
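    /// Guesses the compression format from the file extension
    /// (e.g. `.gz`, `.bz2`, `.xz`, `.zst`); anything else is treated as uncompressed.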
    pub fn from_path(path: &Path) -> Self {
        let ext = path
            .extension()
            .and_then(|e| e.to_str())
            .map(|e| e.to_lowercase());

        match ext.as_deref() {
            Some("gz" | "gzip") => Compression::Gzip,
            Some("bz2" | "bzip2") => Compression::Bzip2,
            Some("xz" | "lzma") => Compression::Xz,
            Some("zst" | "zstd") => Compression::Zstd,
            _ => Compression::None,
        }
    }

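    /// Wraps a reader in the matching streaming decompressor.
    /// `Compression::None` returns the reader unchanged; constructing the
    /// zstd decoder is the only step here that can fail.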
    pub fn wrap_reader<'a>(
        &self,
        reader: Box<dyn Read + 'a>,
    ) -> std::io::Result<Box<dyn Read + 'a>> {
        Ok(match self {
            Compression::None => reader,
            Compression::Gzip => Box::new(flate2::read::GzDecoder::new(reader)),
            Compression::Bzip2 => Box::new(bzip2::read::BzDecoder::new(reader)),
            Compression::Xz => Box::new(xz2::read::XzDecoder::new(reader)),
            Compression::Zstd => Box::new(zstd::stream::read::Decoder::new(reader)?),
        })
    }
}

impl std::fmt::Display for Compression {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Compression::None => write!(f, "none"),
            Compression::Gzip => write!(f, "gzip"),
            Compression::Bzip2 => write!(f, "bzip2"),
            Compression::Xz => write!(f, "xz"),
            Compression::Zstd => write!(f, "zstd"),
        }
    }
}

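/// Splits a single SQL dump file into one output file per table.
///
/// Built with [`new`](Splitter::new) and the builder-style `with_*` methods,
/// then driven by [`split`](Splitter::split). Illustrative usage (paths are
/// placeholders and the `sql_splitter::splitter` import path is assumed, not
/// confirmed):
///
/// ```ignore
/// use sql_splitter::splitter::Splitter;
///
/// let stats = Splitter::new("dump.sql.gz".into(), "out".into())
///     .with_dry_run(true)
///     .split()
///     .expect("split failed");
/// println!("found {} tables", stats.tables_found);
/// ```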
pub struct Splitter {
    input_file: PathBuf,
    output_dir: PathBuf,
    config: SplitterConfig,
}

impl Splitter {
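    /// Creates a splitter for `input_file` that writes per-table files into
    /// `output_dir`, using the default `SplitterConfig`.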
    pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
        Self {
            input_file,
            output_dir,
            config: SplitterConfig::default(),
        }
    }

    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
        self.config.dialect = dialect;
        self
    }

    pub fn with_dry_run(mut self, dry_run: bool) -> Self {
        self.config.dry_run = dry_run;
        self
    }

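    /// Restricts output to the named tables; an empty list leaves filtering off.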
    pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
        if !tables.is_empty() {
            self.config.table_filter = Some(tables.into_iter().collect());
        }
        self
    }

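    /// Registers a progress callback. It is driven by `ProgressReader`, which
    /// wraps the raw file, so the reported counts refer to compressed
    /// (on-disk) bytes rather than decompressed SQL.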
    pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
        self.config.progress_fn = Some(Box::new(f));
        self
    }

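    /// Limits output to schema statements, data statements, or both (`ContentFilter::All`).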
    pub fn with_content_filter(mut self, filter: ContentFilter) -> Self {
        self.config.content_filter = filter;
        self
    }

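    /// Runs the split: opens and (if needed) decompresses the input, reads it
    /// statement by statement, and writes each statement to its table's file
    /// unless `dry_run` is set. Returns aggregate [`Stats`].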
    pub fn split(mut self) -> anyhow::Result<Stats> {
        let file = File::open(&self.input_file)
            .with_context(|| format!("Failed to open input file: {:?}", self.input_file))?;
        let file_size = file.metadata()?.len();
        let buffer_size = determine_buffer_size(file_size);
        let dialect = self.config.dialect;
        let content_filter = self.config.content_filter;

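        // Detect input compression from the file extension.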
        let compression = Compression::from_path(&self.input_file);

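        // Build the reader chain: File -> optional ProgressReader -> decompressor.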
        let reader: Box<dyn Read> = if let Some(cb) = self.config.progress_fn.take() {
            let progress_reader = ProgressReader::new(file, cb);
            compression
                .wrap_reader(Box::new(progress_reader))
                .with_context(|| {
                    format!(
                        "Failed to initialize {} decompression for {:?}",
                        compression, self.input_file
                    )
                })?
        } else {
            compression.wrap_reader(Box::new(file)).with_context(|| {
                format!(
                    "Failed to initialize {} decompression for {:?}",
                    compression, self.input_file
                )
            })?
        };

        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);

        let mut writer_pool = WriterPool::new(self.output_dir.clone());
        if !self.config.dry_run {
            writer_pool.ensure_output_dir().with_context(|| {
                format!("Failed to create output directory: {:?}", self.output_dir)
            })?;
        }

        let mut tables_seen: AHashSet<String> = AHashSet::new();
        let mut stats = Stats {
            statements_processed: 0,
            tables_found: 0,
            bytes_processed: 0,
            table_names: Vec::new(),
        };

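        // Table targeted by the most recent COPY statement, so the raw data
        // block that follows it can be attributed to the right table.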
        let mut last_copy_table: Option<String> = None;

        while let Some(stmt) = parser.read_statement()? {
            let (stmt_type, mut table_name) =
                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);

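            // Remember the COPY target; its data block arrives as the next "statement".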
            if stmt_type == StatementType::Copy {
                last_copy_table = Some(table_name.clone());
            }

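            // An unclassified statement that follows a COPY and ends with the
            // `\.` terminator is that COPY's data block.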
            let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
                if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                    if let Some(copy_table) = last_copy_table.take() {
                        table_name = copy_table;
                        true
                    } else {
                        false
                    }
                } else {
                    false
                }
            } else {
                false
            };

            if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
                continue;
            }

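            // Apply the schema-only / data-only content filter.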
            match content_filter {
                ContentFilter::SchemaOnly => {
                    if !stmt_type.is_schema() {
                        continue;
                    }
                }
                ContentFilter::DataOnly => {
                    if !stmt_type.is_data() && !is_copy_data {
                        continue;
                    }
                }
                ContentFilter::All => {}
            }

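            // Skip tables that are not in the optional table filter.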
            if let Some(ref filter) = self.config.table_filter {
                if !filter.contains(&table_name) {
                    continue;
                }
            }

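            // Record each table the first time it appears.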
            if !tables_seen.contains(&table_name) {
                tables_seen.insert(table_name.clone());
                stats.tables_found += 1;
                stats.table_names.push(table_name.clone());
            }

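            // Write the statement. MSSQL statements may arrive without a
            // trailing semicolon, so one is appended when missing.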
            if !self.config.dry_run {
                let write_result = if self.config.dialect == SqlDialect::Mssql {
                    let trimmed = stmt
                        .iter()
                        .rev()
                        .find(|&&b| b != b'\n' && b != b'\r' && b != b' ' && b != b'\t');
                    if trimmed != Some(&b';') {
                        writer_pool.write_statement_with_suffix(&table_name, &stmt, b";")
                    } else {
                        writer_pool.write_statement(&table_name, &stmt)
                    }
                } else {
                    writer_pool.write_statement(&table_name, &stmt)
                };
                write_result.with_context(|| {
                    format!("Failed to write statement to table file: {}", table_name)
                })?;
            }

            stats.statements_processed += 1;
            stats.bytes_processed += stmt.len() as u64;
        }

        if !self.config.dry_run {
            writer_pool.close_all()?;
        }

        Ok(stats)
    }
}