sql_splitter/splitter/
mod.rs1use crate::parser::{determine_buffer_size, Parser, SqlDialect, StatementType};
2use crate::writer::WriterPool;
3use ahash::AHashSet;
4use std::fs::File;
5use std::io::Read;
6use std::path::PathBuf;
7
/// Summary of a [`Splitter::split`] run.
///
/// Only statements that were routed to a table (after skipping unknown statements and
/// applying any table filter) contribute to these counters.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Stats {
    /// Number of statements routed to a table (written, or merely counted in a dry run).
    pub statements_processed: u64,
    /// Number of distinct tables that had at least one statement routed to them.
    pub tables_found: usize,
    /// Total size, in bytes, of the routed statements.
    pub bytes_processed: u64,
    /// Distinct table names, in order of first appearance.
    pub table_names: Vec<String>,
}
14
15#[derive(Default)]
16pub struct SplitterConfig {
17 pub dialect: SqlDialect,
18 pub dry_run: bool,
19 pub table_filter: Option<AHashSet<String>>,
20 pub progress_fn: Option<Box<dyn Fn(u64)>>,
21}
22
23pub struct Splitter {
24 input_file: PathBuf,
25 output_dir: PathBuf,
26 config: SplitterConfig,
27}
28
29impl Splitter {
30 pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
31 Self {
32 input_file,
33 output_dir,
34 config: SplitterConfig::default(),
35 }
36 }
37
38 pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
39 self.config.dialect = dialect;
40 self
41 }
42
43 pub fn with_dry_run(mut self, dry_run: bool) -> Self {
44 self.config.dry_run = dry_run;
45 self
46 }
47
48 pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
49 if !tables.is_empty() {
50 self.config.table_filter = Some(tables.into_iter().collect());
51 }
52 self
53 }
54
55 pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
56 self.config.progress_fn = Some(Box::new(f));
57 self
58 }
59
60 pub fn split(self) -> anyhow::Result<Stats> {
61 let file = File::open(&self.input_file)?;
62 let file_size = file.metadata()?.len();
63 let buffer_size = determine_buffer_size(file_size);
64 let dialect = self.config.dialect;
65
66 let reader: Box<dyn Read> = if self.config.progress_fn.is_some() {
67 Box::new(ProgressReader::new(file, self.config.progress_fn.unwrap()))
68 } else {
69 Box::new(file)
70 };
71
72 let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
73
74 let mut writer_pool = WriterPool::new(self.output_dir.clone());
75 if !self.config.dry_run {
76 writer_pool.ensure_output_dir()?;
77 }
78
79 let mut tables_seen: AHashSet<String> = AHashSet::new();
80 let mut stats = Stats {
81 statements_processed: 0,
82 tables_found: 0,
83 bytes_processed: 0,
84 table_names: Vec::new(),
85 };
86
87 let mut last_copy_table: Option<String> = None;
89
90 while let Some(stmt) = parser.read_statement()? {
91 let (stmt_type, mut table_name) =
92 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
93
94 if stmt_type == StatementType::Copy {
96 last_copy_table = Some(table_name.clone());
97 }
98
99 let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
101 if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
103 table_name = last_copy_table.take().unwrap();
104 true
105 } else {
106 false
107 }
108 } else {
109 false
110 };
111
112 if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
113 continue;
114 }
115
116 if let Some(ref filter) = self.config.table_filter {
117 if !filter.contains(&table_name) {
118 continue;
119 }
120 }
121
122 if !tables_seen.contains(&table_name) {
123 tables_seen.insert(table_name.clone());
124 stats.tables_found += 1;
125 stats.table_names.push(table_name.clone());
126 }
127
128 if !self.config.dry_run {
129 writer_pool.write_statement(&table_name, &stmt)?;
130 }
131
132 stats.statements_processed += 1;
133 stats.bytes_processed += stmt.len() as u64;
134 }
135
136 if !self.config.dry_run {
137 writer_pool.close_all()?;
138 }
139
140 Ok(stats)
141 }
142}
143
/// [`Read`] adapter that reports the running total of bytes read to a callback.
///
/// The callback fires after every successful `read` on the inner reader, including a
/// zero-byte read. A read that returns an error does not invoke it.
struct ProgressReader<R: Read> {
    /// Underlying byte source.
    inner: R,
    /// Receives the cumulative byte count after each successful read.
    on_progress: Box<dyn Fn(u64)>,
    /// Bytes read so far.
    total: u64,
}

impl<R: Read> ProgressReader<R> {
    /// Wraps `reader`, starting the byte count at zero.
    fn new(reader: R, callback: Box<dyn Fn(u64)>) -> Self {
        ProgressReader {
            inner: reader,
            on_progress: callback,
            total: 0,
        }
    }
}

impl<R: Read> Read for ProgressReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let count = self.inner.read(buf)?;
        self.total += count as u64;
        let total_so_far = self.total;
        (self.on_progress)(total_so_far);
        Ok(count)
    }
}
168
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `contents` to `input.sql` inside a fresh temporary directory.
    ///
    /// Returns the directory guard (bind it to a named variable so the directory outlives
    /// the test body), the input file path, and the not-yet-created output directory path.
    fn fixture(contents: &[u8]) -> (TempDir, std::path::PathBuf, std::path::PathBuf) {
        let tmp = TempDir::new().unwrap();
        let input = tmp.path().join("input.sql");
        let output = tmp.path().join("output");
        std::fs::write(&input, contents).unwrap();
        (tmp, input, output)
    }

    #[test]
    fn test_splitter_basic() {
        let (_tmp, input, output) = fixture(
            b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nCREATE TABLE posts (id INT);\n",
        );

        let stats = Splitter::new(input, output.clone()).split().unwrap();

        assert_eq!(stats.tables_found, 2);
        assert_eq!(stats.statements_processed, 3);
        assert!(output.join("users.sql").exists());
        assert!(output.join("posts.sql").exists());
    }

    #[test]
    fn test_splitter_dry_run() {
        let (_tmp, input, output) = fixture(b"CREATE TABLE users (id INT);");

        let stats = Splitter::new(input, output.clone())
            .with_dry_run(true)
            .split()
            .unwrap();

        assert_eq!(stats.tables_found, 1);
        assert!(!output.exists());
    }

    #[test]
    fn test_splitter_table_filter() {
        let (_tmp, input, output) = fixture(
            b"CREATE TABLE users (id INT);\nCREATE TABLE posts (id INT);\nCREATE TABLE orders (id INT);",
        );
        let wanted = vec!["users".to_string(), "orders".to_string()];

        let stats = Splitter::new(input, output.clone())
            .with_table_filter(wanted)
            .split()
            .unwrap();

        assert_eq!(stats.tables_found, 2);
        assert!(output.join("users.sql").exists());
        assert!(!output.join("posts.sql").exists());
        assert!(output.join("orders.sql").exists());
    }
}