// sql_splitter/merger/mod.rs

//! Merger module for combining split SQL files back into a single dump.

use crate::parser::SqlDialect;
use std::collections::HashSet;
use std::fs::{self, File};
use std::io::{self, BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};

/// Statistics collected during a merge operation.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct MergeStats {
    /// Number of table files that were appended to the output.
    pub tables_merged: usize,
    /// Running total of bytes the merger recorded as written to the output.
    pub bytes_written: u64,
    /// Names of the merged tables (the file stems), in the order they were written.
    pub table_names: Vec<String>,
}

17/// Merger configuration
18#[derive(Default)]
19pub struct MergerConfig {
20    pub dialect: SqlDialect,
21    pub tables: Option<HashSet<String>>,
22    pub exclude: HashSet<String>,
23    pub add_transaction: bool,
24    pub add_header: bool,
25}
26
27/// Merger for combining split SQL files
28pub struct Merger {
29    input_dir: PathBuf,
30    output: Option<PathBuf>,
31    config: MergerConfig,
32}
33
34impl Merger {
35    pub fn new(input_dir: PathBuf, output: Option<PathBuf>) -> Self {
36        Self {
37            input_dir,
38            output,
39            config: MergerConfig::default(),
40        }
41    }
42
43    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
44        self.config.dialect = dialect;
45        self
46    }
47
48    pub fn with_tables(mut self, tables: HashSet<String>) -> Self {
49        self.config.tables = Some(tables);
50        self
51    }
52
53    pub fn with_exclude(mut self, exclude: HashSet<String>) -> Self {
54        self.config.exclude = exclude;
55        self
56    }
57
58    pub fn with_transaction(mut self, add_transaction: bool) -> Self {
59        self.config.add_transaction = add_transaction;
60        self
61    }
62
63    pub fn with_header(mut self, add_header: bool) -> Self {
64        self.config.add_header = add_header;
65        self
66    }
67
68    /// Run the merge operation
69    pub fn merge(&self) -> anyhow::Result<MergeStats> {
70        // Discover SQL files
71        let sql_files = self.discover_sql_files()?;
72        if sql_files.is_empty() {
73            anyhow::bail!(
74                "no .sql files found in directory: {}",
75                self.input_dir.display()
76            );
77        }
78
79        // Filter files
80        let filtered_files: Vec<(String, PathBuf)> = sql_files
81            .into_iter()
82            .filter(|(name, _)| {
83                let name_lower = name.to_lowercase();
84                if let Some(ref include) = self.config.tables {
85                    if !include.contains(&name_lower) {
86                        return false;
87                    }
88                }
89                !self.config.exclude.contains(&name_lower)
90            })
91            .collect();
92
93        if filtered_files.is_empty() {
94            anyhow::bail!("no tables remaining after filtering");
95        }
96
97        // Sort alphabetically
98        let mut sorted_files = filtered_files;
99        sorted_files.sort_by(|a, b| a.0.cmp(&b.0));
100
101        // Merge to output
102        if let Some(ref out_path) = self.output {
103            if let Some(parent) = out_path.parent() {
104                fs::create_dir_all(parent)?;
105            }
106            let file = File::create(out_path)?;
107            let writer = BufWriter::with_capacity(256 * 1024, file);
108            self.merge_files(sorted_files, writer)
109        } else {
110            let stdout = io::stdout();
111            let writer = BufWriter::new(stdout.lock());
112            self.merge_files(sorted_files, writer)
113        }
114    }
115
116    fn discover_sql_files(&self) -> anyhow::Result<Vec<(String, PathBuf)>> {
117        let mut files = Vec::new();
118
119        for entry in fs::read_dir(&self.input_dir)? {
120            let entry = entry?;
121            let path = entry.path();
122
123            if path.is_file() {
124                if let Some(ext) = path.extension() {
125                    if ext.eq_ignore_ascii_case("sql") {
126                        if let Some(stem) = path.file_stem() {
127                            let table_name = stem.to_string_lossy().to_string();
128                            files.push((table_name, path));
129                        }
130                    }
131                }
132            }
133        }
134
135        Ok(files)
136    }
137
138    fn merge_files<W: Write>(
139        &self,
140        files: Vec<(String, PathBuf)>,
141        mut writer: W,
142    ) -> anyhow::Result<MergeStats> {
143        let mut stats = MergeStats::default();
144
145        // Write header
146        if self.config.add_header {
147            self.write_header(&mut writer, files.len())?;
148        }
149
150        // Write transaction start
151        if self.config.add_transaction {
152            let tx_start = self.transaction_start();
153            writer.write_all(tx_start.as_bytes())?;
154            stats.bytes_written += tx_start.len() as u64;
155        }
156
157        // Merge each file
158        for (table_name, path) in &files {
159            // Write table separator
160            let separator = format!(
161                "\n-- ============================================================\n-- Table: {}\n-- ============================================================\n\n",
162                table_name
163            );
164            writer.write_all(separator.as_bytes())?;
165            stats.bytes_written += separator.len() as u64;
166
167            // Stream file content
168            let file = File::open(path)?;
169            let reader = BufReader::with_capacity(64 * 1024, file);
170
171            for line in reader.lines() {
172                let line = line?;
173                writer.write_all(line.as_bytes())?;
174                writer.write_all(b"\n")?;
175                stats.bytes_written += line.len() as u64 + 1;
176            }
177
178            stats.table_names.push(table_name.clone());
179            stats.tables_merged += 1;
180        }
181
182        // Write transaction end
183        if self.config.add_transaction {
184            let tx_end = "\nCOMMIT;\n";
185            writer.write_all(tx_end.as_bytes())?;
186            stats.bytes_written += tx_end.len() as u64;
187        }
188
189        // Write footer
190        if self.config.add_header {
191            self.write_footer(&mut writer)?;
192        }
193
194        writer.flush()?;
195
196        Ok(stats)
197    }
198
199    fn write_header<W: Write>(&self, w: &mut W, table_count: usize) -> io::Result<()> {
200        writeln!(w, "-- SQL Merge Output")?;
201        writeln!(w, "-- Generated by sql-splitter")?;
202        writeln!(w, "-- Tables: {}", table_count)?;
203        writeln!(w, "-- Dialect: {}", self.config.dialect)?;
204        writeln!(w)?;
205
206        match self.config.dialect {
207            SqlDialect::MySql => {
208                writeln!(w, "SET NAMES utf8mb4;")?;
209                writeln!(w, "SET FOREIGN_KEY_CHECKS = 0;")?;
210            }
211            SqlDialect::Postgres => {
212                writeln!(w, "SET client_encoding = 'UTF8';")?;
213            }
214            SqlDialect::Sqlite => {
215                writeln!(w, "PRAGMA foreign_keys = OFF;")?;
216            }
217            SqlDialect::Mssql => {
218                writeln!(w, "SET ANSI_NULLS ON;")?;
219                writeln!(w, "SET QUOTED_IDENTIFIER ON;")?;
220                writeln!(w, "SET NOCOUNT ON;")?;
221            }
222        }
223        writeln!(w)?;
224
225        Ok(())
226    }
227
228    fn write_footer<W: Write>(&self, w: &mut W) -> io::Result<()> {
229        writeln!(w)?;
230        match self.config.dialect {
231            SqlDialect::MySql => {
232                writeln!(w, "SET FOREIGN_KEY_CHECKS = 1;")?;
233            }
234            SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {}
235        }
236        Ok(())
237    }
238
239    fn transaction_start(&self) -> &'static str {
240        match self.config.dialect {
241            SqlDialect::MySql => "START TRANSACTION;\n\n",
242            SqlDialect::Postgres => "BEGIN;\n\n",
243            SqlDialect::Sqlite | SqlDialect::Mssql => "BEGIN TRANSACTION;\n\n",
244        }
245    }
246}