Skip to main content

sql_splitter/merger/
mod.rs

//! Merger module for combining split SQL files back into a single dump.
2
3use crate::parser::SqlDialect;
4use std::collections::HashSet;
5use std::fs::{self, File};
6use std::io::{self, BufRead, BufReader, BufWriter, Write};
7use std::path::PathBuf;
8
/// Statistics collected by a merge operation.
///
/// Besides `Debug` and `Default`, the type is `Clone`, `PartialEq` and `Eq` so results can be
/// stored, duplicated and compared directly (for example in tests).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct MergeStats {
    /// Number of tables merged.
    pub tables_merged: usize,
    /// Total bytes written to output.
    pub bytes_written: u64,
    /// Names of tables that were merged.
    pub table_names: Vec<String>,
}
19
20/// Merger configuration
21#[derive(Default)]
22pub struct MergerConfig {
23    pub dialect: SqlDialect,
24    pub tables: Option<HashSet<String>>,
25    pub exclude: HashSet<String>,
26    pub add_transaction: bool,
27    pub add_header: bool,
28}
29
30/// Merger for combining split SQL files
31pub struct Merger {
32    input_dir: PathBuf,
33    output: Option<PathBuf>,
34    config: MergerConfig,
35}
36
37impl Merger {
38    pub fn new(input_dir: PathBuf, output: Option<PathBuf>) -> Self {
39        Self {
40            input_dir,
41            output,
42            config: MergerConfig::default(),
43        }
44    }
45
46    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
47        self.config.dialect = dialect;
48        self
49    }
50
51    pub fn with_tables(mut self, tables: HashSet<String>) -> Self {
52        self.config.tables = Some(tables);
53        self
54    }
55
56    pub fn with_exclude(mut self, exclude: HashSet<String>) -> Self {
57        self.config.exclude = exclude;
58        self
59    }
60
61    pub fn with_transaction(mut self, add_transaction: bool) -> Self {
62        self.config.add_transaction = add_transaction;
63        self
64    }
65
66    pub fn with_header(mut self, add_header: bool) -> Self {
67        self.config.add_header = add_header;
68        self
69    }
70
71    /// Run the merge operation
72    pub fn merge(&self) -> anyhow::Result<MergeStats> {
73        // Discover SQL files
74        let sql_files = self.discover_sql_files()?;
75        if sql_files.is_empty() {
76            anyhow::bail!(
77                "no .sql files found in directory: {}",
78                self.input_dir.display()
79            );
80        }
81
82        // Filter files
83        let filtered_files: Vec<(String, PathBuf)> = sql_files
84            .into_iter()
85            .filter(|(name, _)| {
86                let name_lower = name.to_lowercase();
87                if let Some(ref include) = self.config.tables {
88                    if !include.contains(&name_lower) {
89                        return false;
90                    }
91                }
92                !self.config.exclude.contains(&name_lower)
93            })
94            .collect();
95
96        if filtered_files.is_empty() {
97            anyhow::bail!("no tables remaining after filtering");
98        }
99
100        // Sort alphabetically
101        let mut sorted_files = filtered_files;
102        sorted_files.sort_by(|a, b| a.0.cmp(&b.0));
103
104        // Merge to output
105        if let Some(ref out_path) = self.output {
106            if let Some(parent) = out_path.parent() {
107                fs::create_dir_all(parent)?;
108            }
109            let file = File::create(out_path)?;
110            let writer = BufWriter::with_capacity(256 * 1024, file);
111            self.merge_files(sorted_files, writer)
112        } else {
113            let stdout = io::stdout();
114            let writer = BufWriter::new(stdout.lock());
115            self.merge_files(sorted_files, writer)
116        }
117    }
118
119    fn discover_sql_files(&self) -> anyhow::Result<Vec<(String, PathBuf)>> {
120        let mut files = Vec::new();
121
122        for entry in fs::read_dir(&self.input_dir)? {
123            let entry = entry?;
124            let path = entry.path();
125
126            if path.is_file() {
127                if let Some(ext) = path.extension() {
128                    if ext.eq_ignore_ascii_case("sql") {
129                        if let Some(stem) = path.file_stem() {
130                            let table_name = stem.to_string_lossy().to_string();
131                            files.push((table_name, path));
132                        }
133                    }
134                }
135            }
136        }
137
138        Ok(files)
139    }
140
141    fn merge_files<W: Write>(
142        &self,
143        files: Vec<(String, PathBuf)>,
144        mut writer: W,
145    ) -> anyhow::Result<MergeStats> {
146        let mut stats = MergeStats::default();
147
148        // Write header
149        if self.config.add_header {
150            self.write_header(&mut writer, files.len())?;
151        }
152
153        // Write transaction start
154        if self.config.add_transaction {
155            let tx_start = self.transaction_start();
156            writer.write_all(tx_start.as_bytes())?;
157            stats.bytes_written += tx_start.len() as u64;
158        }
159
160        // Merge each file
161        for (table_name, path) in &files {
162            // Write table separator
163            let separator = format!(
164                "\n-- ============================================================\n-- Table: {}\n-- ============================================================\n\n",
165                table_name
166            );
167            writer.write_all(separator.as_bytes())?;
168            stats.bytes_written += separator.len() as u64;
169
170            // Stream file content
171            let file = File::open(path)?;
172            let reader = BufReader::with_capacity(64 * 1024, file);
173
174            for line in reader.lines() {
175                let line = line?;
176                writer.write_all(line.as_bytes())?;
177                writer.write_all(b"\n")?;
178                stats.bytes_written += line.len() as u64 + 1;
179            }
180
181            stats.table_names.push(table_name.clone());
182            stats.tables_merged += 1;
183        }
184
185        // Write transaction end
186        if self.config.add_transaction {
187            let tx_end = "\nCOMMIT;\n";
188            writer.write_all(tx_end.as_bytes())?;
189            stats.bytes_written += tx_end.len() as u64;
190        }
191
192        // Write footer
193        if self.config.add_header {
194            self.write_footer(&mut writer)?;
195        }
196
197        writer.flush()?;
198
199        Ok(stats)
200    }
201
202    fn write_header<W: Write>(&self, w: &mut W, table_count: usize) -> io::Result<()> {
203        writeln!(w, "-- SQL Merge Output")?;
204        writeln!(w, "-- Generated by sql-splitter")?;
205        writeln!(w, "-- Tables: {}", table_count)?;
206        writeln!(w, "-- Dialect: {}", self.config.dialect)?;
207        writeln!(w)?;
208
209        match self.config.dialect {
210            SqlDialect::MySql => {
211                writeln!(w, "SET NAMES utf8mb4;")?;
212                writeln!(w, "SET FOREIGN_KEY_CHECKS = 0;")?;
213            }
214            SqlDialect::Postgres => {
215                writeln!(w, "SET client_encoding = 'UTF8';")?;
216            }
217            SqlDialect::Sqlite => {
218                writeln!(w, "PRAGMA foreign_keys = OFF;")?;
219            }
220            SqlDialect::Mssql => {
221                writeln!(w, "SET ANSI_NULLS ON;")?;
222                writeln!(w, "SET QUOTED_IDENTIFIER ON;")?;
223                writeln!(w, "SET NOCOUNT ON;")?;
224            }
225        }
226        writeln!(w)?;
227
228        Ok(())
229    }
230
231    fn write_footer<W: Write>(&self, w: &mut W) -> io::Result<()> {
232        writeln!(w)?;
233        match self.config.dialect {
234            SqlDialect::MySql => {
235                writeln!(w, "SET FOREIGN_KEY_CHECKS = 1;")?;
236            }
237            SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {}
238        }
239        Ok(())
240    }
241
242    fn transaction_start(&self) -> &'static str {
243        match self.config.dialect {
244            SqlDialect::MySql => "START TRANSACTION;\n\n",
245            SqlDialect::Postgres => "BEGIN;\n\n",
246            SqlDialect::Sqlite | SqlDialect::Mssql => "BEGIN TRANSACTION;\n\n",
247        }
248    }
249}