Skip to main content

sql_splitter/differ/
mod.rs

1//! Diff module for comparing two SQL dumps.
2//!
3//! This module provides:
4//! - Schema comparison (tables added/removed/modified, columns, PKs, FKs)
5//! - Data comparison (row counts: added/removed/modified)
6//! - Memory-bounded operation using PK hashing
7//! - Multiple output formats (text, json, sql)
8
9mod data;
10mod output;
11mod schema;
12
13pub use data::*;
14pub use output::*;
15pub use schema::*;
16
17use crate::parser::{determine_buffer_size, Parser, SqlDialect, StatementType};
18use crate::progress::ProgressReader;
19use crate::schema::{Schema, SchemaBuilder};
20use crate::splitter::Compression;
21use serde::Serialize;
22use std::fs::File;
23use std::io::Read;
24use std::path::PathBuf;
25use std::sync::Arc;
26
27/// Configuration for the diff operation
28#[derive(Debug, Clone)]
29pub struct DiffConfig {
30    /// Path to the "old" SQL file
31    pub old_path: PathBuf,
32    /// Path to the "new" SQL file
33    pub new_path: PathBuf,
34    /// SQL dialect (auto-detected if None)
35    pub dialect: Option<SqlDialect>,
36    /// Only compare schema, skip data
37    pub schema_only: bool,
38    /// Only compare data, skip schema
39    pub data_only: bool,
40    /// Tables to include (if empty, include all)
41    pub tables: Vec<String>,
42    /// Tables to exclude
43    pub exclude: Vec<String>,
44    /// Output format
45    pub format: DiffOutputFormat,
46    /// Show verbose row-level details
47    pub verbose: bool,
48    /// Show progress bar
49    pub progress: bool,
50    /// Maximum PK entries to track globally
51    pub max_pk_entries: usize,
52    /// Don't skip tables without PK, use all columns as key
53    pub allow_no_pk: bool,
54    /// Ignore column order when comparing schemas
55    pub ignore_column_order: bool,
56    /// Primary key overrides: table name -> column names
57    pub pk_overrides: std::collections::HashMap<String, Vec<String>>,
58    /// Column patterns to ignore (glob format: table.column)
59    pub ignore_columns: Vec<String>,
60}
61
62impl Default for DiffConfig {
63    fn default() -> Self {
64        Self {
65            old_path: PathBuf::new(),
66            new_path: PathBuf::new(),
67            dialect: None,
68            schema_only: false,
69            data_only: false,
70            tables: Vec::new(),
71            exclude: Vec::new(),
72            format: DiffOutputFormat::Text,
73            verbose: false,
74            progress: false,
75            max_pk_entries: 10_000_000, // 10M entries ~= 160MB
76            allow_no_pk: false,
77            ignore_column_order: false,
78            pk_overrides: std::collections::HashMap::new(),
79            ignore_columns: Vec::new(),
80        }
81    }
82}
83
/// How a diff result is rendered.
///
/// `Text` is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DiffOutputFormat {
    #[default]
    Text,
    Json,
    Sql,
}

impl std::str::FromStr for DiffOutputFormat {
    type Err = String;

    /// Parses a format name, ignoring letter case.
    ///
    /// Accepts `text`, `json` and `sql`; any other input yields an error
    /// message that echoes the input exactly as it was given.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let wanted = s.to_lowercase();
        let known = [
            ("text", Self::Text),
            ("json", Self::Json),
            ("sql", Self::Sql),
        ];

        known
            .iter()
            .find(|(name, _)| *name == wanted)
            .map(|&(_, format)| format)
            .ok_or_else(|| format!("Unknown format: {}. Use: text, json, sql", s))
    }
}
105
106/// A warning generated during diff operation
107#[derive(Debug, Serialize, Clone)]
108pub struct DiffWarning {
109    #[serde(skip_serializing_if = "Option::is_none")]
110    pub table: Option<String>,
111    pub message: String,
112}
113
114/// Complete diff result
115#[derive(Debug, Serialize)]
116pub struct DiffResult {
117    /// Schema differences
118    #[serde(skip_serializing_if = "Option::is_none")]
119    pub schema: Option<SchemaDiff>,
120    /// Data differences
121    #[serde(skip_serializing_if = "Option::is_none")]
122    pub data: Option<DataDiff>,
123    /// Warnings generated during diff
124    #[serde(skip_serializing_if = "Vec::is_empty")]
125    pub warnings: Vec<DiffWarning>,
126    /// Summary statistics
127    pub summary: DiffSummary,
128}
129
130/// Summary of differences
131#[derive(Debug, Serialize)]
132pub struct DiffSummary {
133    /// Number of tables added
134    pub tables_added: usize,
135    /// Number of tables removed
136    pub tables_removed: usize,
137    /// Number of tables modified (schema or data)
138    pub tables_modified: usize,
139    /// Total rows added across all tables
140    pub rows_added: u64,
141    /// Total rows removed across all tables
142    pub rows_removed: u64,
143    /// Total rows modified across all tables
144    pub rows_modified: u64,
145    /// Whether any data was truncated due to memory limits
146    pub truncated: bool,
147}
148
149/// Main differ engine
150pub struct Differ {
151    config: DiffConfig,
152    dialect: SqlDialect,
153    progress_fn: Option<Arc<dyn Fn(u64, u64) + Send + Sync>>,
154}
155
156impl Differ {
157    /// Create a new differ with the given configuration
158    pub fn new(config: DiffConfig) -> Self {
159        Self {
160            dialect: config.dialect.unwrap_or(SqlDialect::MySql),
161            config,
162            progress_fn: None,
163        }
164    }
165
166    /// Set a progress callback (receives current bytes, total bytes)
167    pub fn with_progress<F>(mut self, f: F) -> Self
168    where
169        F: Fn(u64, u64) + Send + Sync + 'static,
170    {
171        self.progress_fn = Some(Arc::new(f));
172        self
173    }
174
175    /// Run the diff operation
176    pub fn diff(self) -> anyhow::Result<DiffResult> {
177        // Calculate total bytes for progress (4 passes max: 2 schema + 2 data)
178        let old_size = std::fs::metadata(&self.config.old_path)?.len();
179        let new_size = std::fs::metadata(&self.config.new_path)?.len();
180        let total_bytes = if self.config.schema_only || self.config.data_only {
181            old_size + new_size
182        } else {
183            (old_size + new_size) * 2 // Schema pass + data pass for each file
184        };
185
186        // Pass 0: Extract schemas from both files
187        let (old_schema, new_schema) = if !self.config.data_only {
188            let old = self.extract_schema(&self.config.old_path.clone(), 0, total_bytes)?;
189            let new = self.extract_schema(&self.config.new_path.clone(), old_size, total_bytes)?;
190            (Some(old), Some(new))
191        } else {
192            // Even for data-only, we need schema for PK info
193            let old = self.extract_schema(&self.config.old_path.clone(), 0, total_bytes)?;
194            let new = self.extract_schema(&self.config.new_path.clone(), old_size, total_bytes)?;
195            (Some(old), Some(new))
196        };
197
198        // Schema comparison
199        let schema_diff = if !self.config.data_only {
200            old_schema
201                .as_ref()
202                .zip(new_schema.as_ref())
203                .map(|(old, new)| compare_schemas(old, new, &self.config))
204        } else {
205            None
206        };
207
208        // Data comparison
209        let (data_diff, warnings) = if !self.config.schema_only {
210            let old_schema = old_schema
211                .as_ref()
212                .ok_or_else(|| anyhow::anyhow!("Schema required for data comparison"))?;
213            let new_schema = new_schema
214                .as_ref()
215                .ok_or_else(|| anyhow::anyhow!("Schema required for data comparison"))?;
216
217            let base_offset = if self.config.data_only {
218                0
219            } else {
220                old_size + new_size
221            };
222
223            let (data, warns) =
224                self.compare_data(old_schema, new_schema, base_offset, total_bytes)?;
225            (Some(data), warns)
226        } else {
227            (None, Vec::new())
228        };
229
230        // Build summary
231        let summary = self.build_summary(&schema_diff, &data_diff);
232
233        Ok(DiffResult {
234            schema: schema_diff,
235            data: data_diff,
236            warnings,
237            summary,
238        })
239    }
240
241    /// Extract schema from a SQL file
242    fn extract_schema(
243        &self,
244        path: &PathBuf,
245        byte_offset: u64,
246        total_bytes: u64,
247    ) -> anyhow::Result<Schema> {
248        let file = File::open(path)?;
249        let file_size = file.metadata()?.len();
250        let buffer_size = determine_buffer_size(file_size);
251        let compression = Compression::from_path(path);
252
253        let reader: Box<dyn Read> = if let Some(ref cb) = self.progress_fn {
254            let cb = Arc::clone(cb);
255            let progress_reader = ProgressReader::new(file, move |bytes| {
256                cb(byte_offset + bytes, total_bytes);
257            });
258            compression.wrap_reader(Box::new(progress_reader))?
259        } else {
260            compression.wrap_reader(Box::new(file))?
261        };
262
263        let mut parser = Parser::with_dialect(reader, buffer_size, self.dialect);
264        let mut builder = SchemaBuilder::new();
265
266        while let Some(stmt) = parser.read_statement()? {
267            let (stmt_type, _table_name) =
268                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, self.dialect);
269
270            match stmt_type {
271                StatementType::CreateTable => {
272                    if let Ok(stmt_str) = std::str::from_utf8(&stmt) {
273                        builder.parse_create_table(stmt_str);
274                    }
275                }
276                StatementType::AlterTable => {
277                    if let Ok(stmt_str) = std::str::from_utf8(&stmt) {
278                        builder.parse_alter_table(stmt_str);
279                    }
280                }
281                StatementType::CreateIndex => {
282                    if let Ok(stmt_str) = std::str::from_utf8(&stmt) {
283                        builder.parse_create_index(stmt_str);
284                    }
285                }
286                _ => {}
287            }
288        }
289
290        Ok(builder.build())
291    }
292
293    /// Compare data between two SQL files
294    fn compare_data(
295        &self,
296        old_schema: &Schema,
297        new_schema: &Schema,
298        byte_offset: u64,
299        total_bytes: u64,
300    ) -> anyhow::Result<(DataDiff, Vec<DiffWarning>)> {
301        let mut data_differ = DataDiffer::new(DataDiffOptions {
302            max_pk_entries_global: self.config.max_pk_entries,
303            max_pk_entries_per_table: self.config.max_pk_entries / 2,
304            sample_size: if self.config.verbose { 100 } else { 0 },
305            tables: self.config.tables.clone(),
306            exclude: self.config.exclude.clone(),
307            allow_no_pk: self.config.allow_no_pk,
308            pk_overrides: self.config.pk_overrides.clone(),
309            ignore_columns: self.config.ignore_columns.clone(),
310        });
311
312        let old_size = std::fs::metadata(&self.config.old_path)?.len();
313
314        // Pass 1: Scan old file
315        data_differ.scan_file(
316            &self.config.old_path,
317            old_schema,
318            self.dialect,
319            true, // is_old
320            &self.progress_fn,
321            byte_offset,
322            total_bytes,
323        )?;
324
325        // Pass 2: Scan new file
326        data_differ.scan_file(
327            &self.config.new_path,
328            new_schema,
329            self.dialect,
330            false, // is_old
331            &self.progress_fn,
332            byte_offset + old_size,
333            total_bytes,
334        )?;
335
336        Ok(data_differ.compute_diff())
337    }
338
339    /// Build summary from diff results
340    fn build_summary(
341        &self,
342        schema_diff: &Option<SchemaDiff>,
343        data_diff: &Option<DataDiff>,
344    ) -> DiffSummary {
345        let (tables_added, tables_removed, schema_modified) = schema_diff
346            .as_ref()
347            .map(|s| {
348                (
349                    s.tables_added.len(),
350                    s.tables_removed.len(),
351                    s.tables_modified.len(),
352                )
353            })
354            .unwrap_or((0, 0, 0));
355
356        let (rows_added, rows_removed, rows_modified, data_modified, truncated) = data_diff
357            .as_ref()
358            .map(|d| {
359                let mut added = 0u64;
360                let mut removed = 0u64;
361                let mut modified = 0u64;
362                let mut tables_with_changes = 0usize;
363                let mut any_truncated = false;
364
365                for table_diff in d.tables.values() {
366                    added += table_diff.added_count;
367                    removed += table_diff.removed_count;
368                    modified += table_diff.modified_count;
369                    if table_diff.added_count > 0
370                        || table_diff.removed_count > 0
371                        || table_diff.modified_count > 0
372                    {
373                        tables_with_changes += 1;
374                    }
375                    if table_diff.truncated {
376                        any_truncated = true;
377                    }
378                }
379
380                (added, removed, modified, tables_with_changes, any_truncated)
381            })
382            .unwrap_or((0, 0, 0, 0, false));
383
384        DiffSummary {
385            tables_added,
386            tables_removed,
387            tables_modified: schema_modified.max(data_modified),
388            rows_added,
389            rows_removed,
390            rows_modified,
391            truncated,
392        }
393    }
394}
395
/// Check if a table should be included based on filter config.
///
/// A table is included when it appears in `tables` (or `tables` is empty)
/// and does not appear in `exclude`. Names are compared after lowercasing
/// both sides, so the match is case-insensitive.
pub fn should_include_table(table_name: &str, tables: &[String], exclude: &[String]) -> bool {
    // Nothing to filter on: skip the lowercase allocation altogether.
    if tables.is_empty() && exclude.is_empty() {
        return true;
    }

    // Lowercase the candidate once and reuse it for both lists.
    let name_lower = table_name.to_lowercase();
    let is_listed = |names: &[String]| names.iter().any(|t| t.to_lowercase() == name_lower);

    // An empty include list admits every table; the exclude list always wins.
    (tables.is_empty() || is_listed(tables)) && !is_listed(exclude)
}