sql_splitter/splitter/
mod.rs

1use crate::parser::{determine_buffer_size, Parser, SqlDialect, StatementType};
2use crate::writer::WriterPool;
3use ahash::AHashSet;
4use std::fs::File;
5use std::io::Read;
6use std::path::PathBuf;
7
/// Summary of a single [`Splitter::split`] run.
///
/// Only statements that were attributed to a table and passed the optional
/// table filter are counted.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Stats {
    /// Number of statements routed to a table (written, or counted only in dry-run mode).
    pub statements_processed: u64,
    /// Number of distinct table names encountered; equals `table_names.len()`.
    pub tables_found: usize,
    /// Sum of the byte lengths of all counted statements.
    pub bytes_processed: u64,
    /// Distinct table names in the order they were first seen.
    pub table_names: Vec<String>,
}
14
15#[derive(Default)]
16pub struct SplitterConfig {
17    pub dialect: SqlDialect,
18    pub dry_run: bool,
19    pub table_filter: Option<AHashSet<String>>,
20    pub progress_fn: Option<Box<dyn Fn(u64)>>,
21}
22
23pub struct Splitter {
24    input_file: PathBuf,
25    output_dir: PathBuf,
26    config: SplitterConfig,
27}
28
29impl Splitter {
30    pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
31        Self {
32            input_file,
33            output_dir,
34            config: SplitterConfig::default(),
35        }
36    }
37
38    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
39        self.config.dialect = dialect;
40        self
41    }
42
43    pub fn with_dry_run(mut self, dry_run: bool) -> Self {
44        self.config.dry_run = dry_run;
45        self
46    }
47
48    pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
49        if !tables.is_empty() {
50            self.config.table_filter = Some(tables.into_iter().collect());
51        }
52        self
53    }
54
55    pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
56        self.config.progress_fn = Some(Box::new(f));
57        self
58    }
59
60    pub fn split(self) -> anyhow::Result<Stats> {
61        let file = File::open(&self.input_file)?;
62        let file_size = file.metadata()?.len();
63        let buffer_size = determine_buffer_size(file_size);
64        let dialect = self.config.dialect;
65
66        let reader: Box<dyn Read> = if self.config.progress_fn.is_some() {
67            Box::new(ProgressReader::new(file, self.config.progress_fn.unwrap()))
68        } else {
69            Box::new(file)
70        };
71
72        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
73
74        let mut writer_pool = WriterPool::new(self.output_dir.clone());
75        if !self.config.dry_run {
76            writer_pool.ensure_output_dir()?;
77        }
78
79        let mut tables_seen: AHashSet<String> = AHashSet::new();
80        let mut stats = Stats {
81            statements_processed: 0,
82            tables_found: 0,
83            bytes_processed: 0,
84            table_names: Vec::new(),
85        };
86
87        // Track the last COPY table for PostgreSQL COPY data blocks
88        let mut last_copy_table: Option<String> = None;
89
90        while let Some(stmt) = parser.read_statement()? {
91            let (stmt_type, mut table_name) =
92                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
93
94            // Track COPY statements for data association
95            if stmt_type == StatementType::Copy {
96                last_copy_table = Some(table_name.clone());
97            }
98
99            // Handle PostgreSQL COPY data blocks - associate with last COPY table
100            let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
101                // Check if this looks like COPY data (ends with \.\n)
102                if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
103                    table_name = last_copy_table.take().unwrap();
104                    true
105                } else {
106                    false
107                }
108            } else {
109                false
110            };
111
112            if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
113                continue;
114            }
115
116            if let Some(ref filter) = self.config.table_filter {
117                if !filter.contains(&table_name) {
118                    continue;
119                }
120            }
121
122            if !tables_seen.contains(&table_name) {
123                tables_seen.insert(table_name.clone());
124                stats.tables_found += 1;
125                stats.table_names.push(table_name.clone());
126            }
127
128            if !self.config.dry_run {
129                writer_pool.write_statement(&table_name, &stmt)?;
130            }
131
132            stats.statements_processed += 1;
133            stats.bytes_processed += stmt.len() as u64;
134        }
135
136        if !self.config.dry_run {
137            writer_pool.close_all()?;
138        }
139
140        Ok(stats)
141    }
142}
143
/// `Read` adapter that reports the running total of bytes read to a callback.
struct ProgressReader<R>
where
    R: Read,
{
    /// Underlying source that all reads are forwarded to.
    reader: R,
    /// Invoked after every successful read with the cumulative byte count.
    callback: Box<dyn Fn(u64)>,
    /// Total number of bytes read through this adapter so far.
    bytes_read: u64,
}

impl<R> ProgressReader<R>
where
    R: Read,
{
    /// Wraps `reader`, starting the byte counter at zero.
    fn new(reader: R, callback: Box<dyn Fn(u64)>) -> Self {
        let bytes_read = 0;
        Self {
            reader,
            callback,
            bytes_read,
        }
    }
}

impl<R> Read for ProgressReader<R>
where
    R: Read,
{
    /// Forwards the read to the inner reader, then adds the number of bytes
    /// obtained to the running total and passes that total to the callback.
    ///
    /// The callback is not invoked when the inner read returns an error; it
    /// is invoked for zero-length reads such as end of input.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.reader.read(buf).map(|count| {
            self.bytes_read += count as u64;
            (self.callback)(self.bytes_read);
            count
        })
    }
}
168
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a scratch directory containing `input.sql` filled with `sql`.
    ///
    /// Returns the directory guard (which must stay alive for the duration of
    /// the test), the input file path and the not-yet-created output path.
    fn scratch(sql: &[u8]) -> (TempDir, std::path::PathBuf, std::path::PathBuf) {
        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        let output_dir = temp_dir.path().join("output");
        std::fs::write(&input_file, sql).unwrap();
        (temp_dir, input_file, output_dir)
    }

    /// Three statements across two tables produce one output file per table.
    #[test]
    fn test_splitter_basic() {
        let (_guard, input_file, output_dir) = scratch(
            b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nCREATE TABLE posts (id INT);\n",
        );

        let stats = Splitter::new(input_file, output_dir.clone())
            .split()
            .unwrap();

        assert_eq!(stats.tables_found, 2);
        assert_eq!(stats.statements_processed, 3);

        for table_file in ["users.sql", "posts.sql"] {
            assert!(output_dir.join(table_file).exists());
        }
    }

    /// Dry-run mode counts tables but never creates the output directory.
    #[test]
    fn test_splitter_dry_run() {
        let (_guard, input_file, output_dir) = scratch(b"CREATE TABLE users (id INT);");

        let stats = Splitter::new(input_file, output_dir.clone())
            .with_dry_run(true)
            .split()
            .unwrap();

        assert_eq!(stats.tables_found, 1);
        assert!(!output_dir.exists());
    }

    /// Only tables named in the filter are counted and written.
    #[test]
    fn test_splitter_table_filter() {
        let (_guard, input_file, output_dir) = scratch(
            b"CREATE TABLE users (id INT);\nCREATE TABLE posts (id INT);\nCREATE TABLE orders (id INT);",
        );

        let wanted = vec!["users".to_string(), "orders".to_string()];
        let stats = Splitter::new(input_file, output_dir.clone())
            .with_table_filter(wanted)
            .split()
            .unwrap();

        assert_eq!(stats.tables_found, 2);
        assert!(output_dir.join("users.sql").exists());
        assert!(!output_dir.join("posts.sql").exists());
        assert!(output_dir.join("orders.sql").exists());
    }
}