sql_splitter/splitter/
mod.rs

1use crate::parser::{determine_buffer_size, Parser, SqlDialect, StatementType};
2use crate::writer::WriterPool;
3use ahash::AHashSet;
4use std::fs::File;
5use std::io::Read;
6use std::path::PathBuf;
7
/// Summary of a completed [`Splitter::split`] run.
///
/// Only statements that were attributed to a table and passed the optional
/// table filter contribute to these numbers.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Stats {
    /// Number of statements that were counted (and written unless in dry-run).
    pub statements_processed: u64,
    /// Number of distinct table names encountered among counted statements.
    pub tables_found: usize,
    /// Sum of the byte lengths of all counted statements.
    pub bytes_processed: u64,
    /// Distinct table names, in the order they were first encountered.
    pub table_names: Vec<String>,
}
14
15#[derive(Default)]
16pub struct SplitterConfig {
17    pub dialect: SqlDialect,
18    pub dry_run: bool,
19    pub table_filter: Option<AHashSet<String>>,
20    pub progress_fn: Option<Box<dyn Fn(u64)>>,
21}
22
23pub struct Splitter {
24    input_file: PathBuf,
25    output_dir: PathBuf,
26    config: SplitterConfig,
27}
28
29impl Splitter {
30    pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
31        Self {
32            input_file,
33            output_dir,
34            config: SplitterConfig::default(),
35        }
36    }
37
38    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
39        self.config.dialect = dialect;
40        self
41    }
42
43    pub fn with_dry_run(mut self, dry_run: bool) -> Self {
44        self.config.dry_run = dry_run;
45        self
46    }
47
48    pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
49        if !tables.is_empty() {
50            self.config.table_filter = Some(tables.into_iter().collect());
51        }
52        self
53    }
54
55    pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
56        self.config.progress_fn = Some(Box::new(f));
57        self
58    }
59
60    pub fn split(self) -> anyhow::Result<Stats> {
61        let file = File::open(&self.input_file)?;
62        let file_size = file.metadata()?.len();
63        let buffer_size = determine_buffer_size(file_size);
64        let dialect = self.config.dialect;
65
66        let reader: Box<dyn Read> = if self.config.progress_fn.is_some() {
67            Box::new(ProgressReader::new(file, self.config.progress_fn.unwrap()))
68        } else {
69            Box::new(file)
70        };
71
72        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
73
74        let mut writer_pool = WriterPool::new(self.output_dir.clone());
75        if !self.config.dry_run {
76            writer_pool.ensure_output_dir()?;
77        }
78
79        let mut tables_seen: AHashSet<String> = AHashSet::new();
80        let mut stats = Stats {
81            statements_processed: 0,
82            tables_found: 0,
83            bytes_processed: 0,
84            table_names: Vec::new(),
85        };
86        
87        // Track the last COPY table for PostgreSQL COPY data blocks
88        let mut last_copy_table: Option<String> = None;
89
90        while let Some(stmt) = parser.read_statement()? {
91            let (stmt_type, mut table_name) = Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
92
93            // Track COPY statements for data association
94            if stmt_type == StatementType::Copy {
95                last_copy_table = Some(table_name.clone());
96            }
97
98            // Handle PostgreSQL COPY data blocks - associate with last COPY table
99            let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
100                // Check if this looks like COPY data (ends with \.\n)
101                if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
102                    table_name = last_copy_table.take().unwrap();
103                    true
104                } else {
105                    false
106                }
107            } else {
108                false
109            };
110
111            if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
112                continue;
113            }
114
115            if let Some(ref filter) = self.config.table_filter {
116                if !filter.contains(&table_name) {
117                    continue;
118                }
119            }
120
121            if !tables_seen.contains(&table_name) {
122                tables_seen.insert(table_name.clone());
123                stats.tables_found += 1;
124                stats.table_names.push(table_name.clone());
125            }
126
127            if !self.config.dry_run {
128                writer_pool.write_statement(&table_name, &stmt)?;
129            }
130
131            stats.statements_processed += 1;
132            stats.bytes_processed += stmt.len() as u64;
133        }
134
135        if !self.config.dry_run {
136            writer_pool.close_all()?;
137        }
138
139        Ok(stats)
140    }
141}
142
/// [`Read`] adapter that reports cumulative progress through a callback.
///
/// Each `read` on the wrapper forwards to the inner reader, adds the number of
/// bytes returned to a running total, and then invokes the callback with that
/// total.
///
/// The type itself carries no trait bound; `R: Read` is required only by the
/// `Read` implementation that actually calls into the inner reader.
struct ProgressReader<R> {
    /// Underlying byte source.
    reader: R,
    /// Invoked after every successful read with the running byte total.
    callback: Box<dyn Fn(u64)>,
    /// Total number of bytes the inner reader has returned so far.
    bytes_read: u64,
}

impl<R> ProgressReader<R> {
    /// Wraps `reader`, starting the byte counter at zero.
    fn new(reader: R, callback: Box<dyn Fn(u64)>) -> Self {
        Self {
            reader,
            callback,
            bytes_read: 0,
        }
    }
}

impl<R: Read> Read for ProgressReader<R> {
    /// Reads from the inner reader and reports the updated byte total.
    ///
    /// The callback also fires for a zero-length read (for example at end of
    /// input), repeating the previous total. If the inner read fails, the error
    /// is returned and the callback is not invoked.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let n = self.reader.read(buf)?;
        self.bytes_read += n as u64;
        (self.callback)(self.bytes_read);
        Ok(n)
    }
}
167
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a scratch directory containing `input.sql` with the given SQL.
    ///
    /// Returns the directory guard (which must stay alive for the files to
    /// persist), the input path, and the path of a not-yet-created `output`
    /// directory.
    fn fixture(sql: &[u8]) -> (TempDir, PathBuf, PathBuf) {
        let scratch = TempDir::new().unwrap();
        let input = scratch.path().join("input.sql");
        let output = scratch.path().join("output");
        std::fs::write(&input, sql).unwrap();
        (scratch, input, output)
    }

    /// Two tables and three statements end up in two per-table files.
    #[test]
    fn test_splitter_basic() {
        let (_scratch, input_file, output_dir) = fixture(
            b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nCREATE TABLE posts (id INT);\n",
        );

        let stats = Splitter::new(input_file, output_dir.clone())
            .split()
            .unwrap();

        assert_eq!(stats.tables_found, 2);
        assert_eq!(stats.statements_processed, 3);

        for file_name in ["users.sql", "posts.sql"] {
            assert!(output_dir.join(file_name).exists());
        }
    }

    /// Dry-run still counts tables but never creates the output directory.
    #[test]
    fn test_splitter_dry_run() {
        let (_scratch, input_file, output_dir) = fixture(b"CREATE TABLE users (id INT);");

        let stats = Splitter::new(input_file, output_dir.clone())
            .with_dry_run(true)
            .split()
            .unwrap();

        assert_eq!(stats.tables_found, 1);
        assert!(!output_dir.exists());
    }

    /// Only tables named in the filter are counted and written.
    #[test]
    fn test_splitter_table_filter() {
        let (_scratch, input_file, output_dir) = fixture(
            b"CREATE TABLE users (id INT);\nCREATE TABLE posts (id INT);\nCREATE TABLE orders (id INT);",
        );

        let wanted = vec!["users".to_string(), "orders".to_string()];
        let stats = Splitter::new(input_file, output_dir.clone())
            .with_table_filter(wanted)
            .split()
            .unwrap();

        assert_eq!(stats.tables_found, 2);
        assert!(output_dir.join("users.sql").exists());
        assert!(!output_dir.join("posts.sql").exists());
        assert!(output_dir.join("orders.sql").exists());
    }
}