sql_splitter/splitter/
mod.rs

1use crate::parser::{determine_buffer_size, ContentFilter, Parser, SqlDialect, StatementType};
2use crate::progress::ProgressReader;
3use crate::writer::WriterPool;
4use ahash::AHashSet;
5use serde::Serialize;
6use std::fs::File;
7use std::io::Read;
8use std::path::{Path, PathBuf};
9
10#[derive(Serialize)]
11pub struct Stats {
12    pub statements_processed: u64,
13    pub tables_found: usize,
14    pub bytes_processed: u64,
15    pub table_names: Vec<String>,
16}
17
/// Options for a [`Splitter`] run, assembled via the `with_*` builder
/// methods on [`Splitter`]; all fields start at their `Default` values.
#[derive(Default)]
pub struct SplitterConfig {
    // SQL dialect used by the parser to classify statements.
    pub dialect: SqlDialect,
    // When true, statements are counted but nothing is written to disk.
    pub dry_run: bool,
    // If set, only statements for these exact table names are emitted.
    pub table_filter: Option<AHashSet<String>>,
    // Optional callback handed to the `ProgressReader` wrapping the input
    // file — presumably invoked with byte counts as input is read; see
    // `ProgressReader` for the exact contract.
    pub progress_fn: Option<Box<dyn Fn(u64)>>,
    // Schema-only / data-only / all filter applied per statement.
    pub content_filter: ContentFilter,
}
26
/// Compression format detected from file extension
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Compression {
    /// Plain, uncompressed input (also the fallback for unknown extensions).
    None,
    /// gzip (`.gz` / `.gzip`).
    Gzip,
    /// bzip2 (`.bz2` / `.bzip2`).
    Bzip2,
    /// XZ / LZMA (`.xz` / `.lzma`).
    Xz,
    /// Zstandard (`.zst` / `.zstd`).
    Zstd,
}
36
37impl Compression {
38    /// Detect compression format from file extension
39    pub fn from_path(path: &Path) -> Self {
40        let ext = path
41            .extension()
42            .and_then(|e| e.to_str())
43            .map(|e| e.to_lowercase());
44
45        match ext.as_deref() {
46            Some("gz" | "gzip") => Compression::Gzip,
47            Some("bz2" | "bzip2") => Compression::Bzip2,
48            Some("xz" | "lzma") => Compression::Xz,
49            Some("zst" | "zstd") => Compression::Zstd,
50            _ => Compression::None,
51        }
52    }
53
54    /// Wrap a reader with the appropriate decompressor
55    pub fn wrap_reader<'a>(&self, reader: Box<dyn Read + 'a>) -> Box<dyn Read + 'a> {
56        match self {
57            Compression::None => reader,
58            Compression::Gzip => Box::new(flate2::read::GzDecoder::new(reader)),
59            Compression::Bzip2 => Box::new(bzip2::read::BzDecoder::new(reader)),
60            Compression::Xz => Box::new(xz2::read::XzDecoder::new(reader)),
61            Compression::Zstd => Box::new(zstd::stream::read::Decoder::new(reader).unwrap()),
62        }
63    }
64}
65
66impl std::fmt::Display for Compression {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        match self {
69            Compression::None => write!(f, "none"),
70            Compression::Gzip => write!(f, "gzip"),
71            Compression::Bzip2 => write!(f, "bzip2"),
72            Compression::Xz => write!(f, "xz"),
73            Compression::Zstd => write!(f, "zstd"),
74        }
75    }
76}
77
/// Streams a SQL dump and routes each statement to a per-table writer.
pub struct Splitter {
    // Path to the (possibly compressed) SQL dump to read.
    input_file: PathBuf,
    // Directory handed to `WriterPool` for the per-table output files.
    output_dir: PathBuf,
    // Run options, assembled via the `with_*` builder methods.
    config: SplitterConfig,
}
83
84impl Splitter {
85    pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
86        Self {
87            input_file,
88            output_dir,
89            config: SplitterConfig::default(),
90        }
91    }
92
93    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
94        self.config.dialect = dialect;
95        self
96    }
97
98    pub fn with_dry_run(mut self, dry_run: bool) -> Self {
99        self.config.dry_run = dry_run;
100        self
101    }
102
103    pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
104        if !tables.is_empty() {
105            self.config.table_filter = Some(tables.into_iter().collect());
106        }
107        self
108    }
109
110    pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
111        self.config.progress_fn = Some(Box::new(f));
112        self
113    }
114
115    pub fn with_content_filter(mut self, filter: ContentFilter) -> Self {
116        self.config.content_filter = filter;
117        self
118    }
119
120    pub fn split(mut self) -> anyhow::Result<Stats> {
121        let file = File::open(&self.input_file)?;
122        let file_size = file.metadata()?.len();
123        let buffer_size = determine_buffer_size(file_size);
124        let dialect = self.config.dialect;
125        let content_filter = self.config.content_filter;
126
127        // Detect and apply decompression
128        let compression = Compression::from_path(&self.input_file);
129
130        let reader: Box<dyn Read> = if let Some(cb) = self.config.progress_fn.take() {
131            let progress_reader = ProgressReader::new(file, cb);
132            compression.wrap_reader(Box::new(progress_reader))
133        } else {
134            compression.wrap_reader(Box::new(file))
135        };
136
137        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
138
139        let mut writer_pool = WriterPool::new(self.output_dir.clone());
140        if !self.config.dry_run {
141            writer_pool.ensure_output_dir()?;
142        }
143
144        let mut tables_seen: AHashSet<String> = AHashSet::new();
145        let mut stats = Stats {
146            statements_processed: 0,
147            tables_found: 0,
148            bytes_processed: 0,
149            table_names: Vec::new(),
150        };
151
152        // Track the last COPY table for PostgreSQL COPY data blocks
153        let mut last_copy_table: Option<String> = None;
154
155        while let Some(stmt) = parser.read_statement()? {
156            let (stmt_type, mut table_name) =
157                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
158
159            // Track COPY statements for data association
160            if stmt_type == StatementType::Copy {
161                last_copy_table = Some(table_name.clone());
162            }
163
164            // Handle PostgreSQL COPY data blocks - associate with last COPY table
165            let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
166                // Check if this looks like COPY data (ends with \.\n)
167                if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
168                    table_name = last_copy_table.take().unwrap();
169                    true
170                } else {
171                    false
172                }
173            } else {
174                false
175            };
176
177            if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
178                continue;
179            }
180
181            // Apply content filter (schema-only or data-only)
182            match content_filter {
183                ContentFilter::SchemaOnly => {
184                    if !stmt_type.is_schema() {
185                        continue;
186                    }
187                }
188                ContentFilter::DataOnly => {
189                    // For data-only, include INSERT, COPY, and COPY data blocks
190                    if !stmt_type.is_data() && !is_copy_data {
191                        continue;
192                    }
193                }
194                ContentFilter::All => {}
195            }
196
197            if let Some(ref filter) = self.config.table_filter {
198                if !filter.contains(&table_name) {
199                    continue;
200                }
201            }
202
203            if !tables_seen.contains(&table_name) {
204                tables_seen.insert(table_name.clone());
205                stats.tables_found += 1;
206                stats.table_names.push(table_name.clone());
207            }
208
209            if !self.config.dry_run {
210                writer_pool.write_statement(&table_name, &stmt)?;
211            }
212
213            stats.statements_processed += 1;
214            stats.bytes_processed += stmt.len() as u64;
215        }
216
217        if !self.config.dry_run {
218            writer_pool.close_all()?;
219        }
220
221        Ok(stats)
222    }
223}