1use super::{should_include_table, DiffWarning};
7use glob::Pattern;
8use crate::parser::{
9 determine_buffer_size, mysql_insert, postgres_copy, Parser, SqlDialect, StatementType,
10};
11use crate::pk::{hash_pk_values, PkHash};
12use crate::progress::ProgressReader;
13use crate::schema::Schema;
14use crate::splitter::Compression;
15use ahash::AHashMap;
16use serde::Serialize;
17use std::collections::HashMap;
18use std::fs::File;
19use std::hash::{Hash, Hasher};
20use std::io::Read;
21use std::path::PathBuf;
22use std::sync::Arc;
23
/// Configuration controlling how a data-level diff is performed.
#[derive(Debug, Clone)]
pub struct DataDiffOptions {
    /// Hard cap on PK entries tracked across all tables and both sides combined.
    pub max_pk_entries_global: usize,
    /// Hard cap on PK entries tracked for any single table.
    pub max_pk_entries_per_table: usize,
    /// How many sample PK strings to keep per change category (0 disables sampling).
    pub sample_size: usize,
    /// Only diff these tables; empty means all tables.
    pub tables: Vec<String>,
    /// Tables to exclude from the diff.
    pub exclude: Vec<String>,
    /// When true, tables without a primary key are still scanned instead of
    /// being skipped with a warning.
    pub allow_no_pk: bool,
    /// Per-table primary-key override: lowercased table name -> column names
    /// to use as the effective key (matched case-insensitively).
    pub pk_overrides: std::collections::HashMap<String, Vec<String>>,
    /// Glob patterns matched against lowercased "table.column" names whose
    /// values are excluded from row digests.
    pub ignore_columns: Vec<String>,
}
44
45impl Default for DataDiffOptions {
46 fn default() -> Self {
47 Self {
48 max_pk_entries_global: 10_000_000,
49 max_pk_entries_per_table: 5_000_000,
50 sample_size: 0,
51 tables: Vec::new(),
52 exclude: Vec::new(),
53 allow_no_pk: false,
54 pk_overrides: std::collections::HashMap::new(),
55 ignore_columns: Vec::new(),
56 }
57 }
58}
59
/// Result of a data diff: one entry per table that had rows or changes.
#[derive(Debug, Serialize)]
pub struct DataDiff {
    /// Per-table diff results, keyed by table name.
    pub tables: HashMap<String, TableDataDiff>,
}
66
/// Diff summary for a single table.
#[derive(Debug, Serialize, Clone, Default)]
pub struct TableDataDiff {
    /// Rows observed in the old dump.
    pub old_row_count: u64,
    /// Rows observed in the new dump.
    pub new_row_count: u64,
    /// Rows present only in the new dump (or the row-count surplus when truncated).
    pub added_count: u64,
    /// Rows present only in the old dump (or the row-count deficit when truncated).
    pub removed_count: u64,
    /// Rows present on both sides whose digests differ (0 when truncated).
    pub modified_count: u64,
    /// True when PK tracking was cut short by a limit, so the counts above
    /// are row-count approximations rather than exact.
    pub truncated: bool,
    /// Up to `sample_size` printable PKs of added rows.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub sample_added_pks: Vec<String>,
    /// Up to `sample_size` printable PKs of removed rows.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub sample_removed_pks: Vec<String>,
    /// Up to `sample_size` printable PKs of modified rows.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub sample_modified_pks: Vec<String>,
}
92
/// Accumulated scan state for one table on one side (old or new).
struct TableState {
    // Total data rows seen for this table.
    row_count: u64,
    // PK hash -> row digest; None once tracking was abandoned (truncated).
    pk_to_digest: Option<AHashMap<PkHash, u64>>,
    // PK hash -> printable PK string; only populated when sampling is enabled.
    pk_strings: Option<AHashMap<PkHash, String>>,
    // True when a per-table or global limit stopped PK tracking.
    truncated: bool,
}
105
106impl TableState {
107 fn new() -> Self {
108 Self {
109 row_count: 0,
110 pk_to_digest: Some(AHashMap::new()),
111 pk_strings: None,
112 truncated: false,
113 }
114 }
115
116 fn new_with_pk_strings() -> Self {
117 Self {
118 row_count: 0,
119 pk_to_digest: Some(AHashMap::new()),
120 pk_strings: Some(AHashMap::new()),
121 truncated: false,
122 }
123 }
124}
125
126fn hash_row_digest(values: &[mysql_insert::PkValue]) -> u64 {
128 let mut hasher = ahash::AHasher::default();
129 for v in values {
130 match v {
131 mysql_insert::PkValue::Int(i) => {
132 0u8.hash(&mut hasher);
133 i.hash(&mut hasher);
134 }
135 mysql_insert::PkValue::BigInt(i) => {
136 1u8.hash(&mut hasher);
137 i.hash(&mut hasher);
138 }
139 mysql_insert::PkValue::Text(s) => {
140 2u8.hash(&mut hasher);
141 s.hash(&mut hasher);
142 }
143 mysql_insert::PkValue::Null => {
144 3u8.hash(&mut hasher);
145 }
146 }
147 }
148 hasher.finish()
149}
150
151fn format_single_pk(v: &mysql_insert::PkValue) -> String {
153 match v {
154 mysql_insert::PkValue::Int(i) => i.to_string(),
155 mysql_insert::PkValue::BigInt(i) => i.to_string(),
156 mysql_insert::PkValue::Text(s) => s.to_string(),
157 mysql_insert::PkValue::Null => "NULL".to_string(),
158 }
159}
160
161fn format_pk_value(pk: &[mysql_insert::PkValue]) -> String {
163 if pk.len() == 1 {
164 format_single_pk(&pk[0])
165 } else {
166 let parts: Vec<String> = pk.iter().map(format_single_pk).collect();
167 format!("({})", parts.join(", "))
168 }
169}
170
171fn parse_ignore_patterns(patterns: &[String]) -> Vec<Pattern> {
173 patterns
174 .iter()
175 .filter_map(|p| Pattern::new(&p.to_lowercase()).ok())
176 .collect()
177}
178
179fn should_ignore_column(table: &str, column: &str, patterns: &[Pattern]) -> bool {
181 let full_name = format!("{}.{}", table.to_lowercase(), column.to_lowercase());
182 patterns.iter().any(|p| p.matches(&full_name))
183}
184
185fn hash_row_digest_with_ignore(
187 values: &[mysql_insert::PkValue],
188 ignore_indices: &[usize],
189) -> u64 {
190 let mut hasher = ahash::AHasher::default();
191 for (i, v) in values.iter().enumerate() {
192 if ignore_indices.contains(&i) {
193 continue;
194 }
195 match v {
196 mysql_insert::PkValue::Int(val) => {
197 0u8.hash(&mut hasher);
198 val.hash(&mut hasher);
199 }
200 mysql_insert::PkValue::BigInt(val) => {
201 1u8.hash(&mut hasher);
202 val.hash(&mut hasher);
203 }
204 mysql_insert::PkValue::Text(s) => {
205 2u8.hash(&mut hasher);
206 s.hash(&mut hasher);
207 }
208 mysql_insert::PkValue::Null => {
209 3u8.hash(&mut hasher);
210 }
211 }
212 }
213 hasher.finish()
214}
215
/// Streaming differ: scans the "old" and "new" SQL dumps once each, then
/// compares per-table PK-hash -> row-digest maps to produce a `DataDiff`.
pub struct DataDiffer {
    // Limits, table filters, PK overrides, and ignore patterns.
    options: DataDiffOptions,
    // Per-table state accumulated while scanning the old dump.
    old_state: AHashMap<String, TableState>,
    // Per-table state accumulated while scanning the new dump.
    new_state: AHashMap<String, TableState>,
    // Total PK entries recorded across all tables/sides; checked against
    // options.max_pk_entries_global.
    total_pk_entries: usize,
    // Once set, no further PK entries are recorded anywhere; diffs fall
    // back to row-count comparison.
    global_truncated: bool,
    // (table name, column order) of the Postgres COPY statement whose data
    // payload is expected as the next parsed statement.
    current_copy_context: Option<(String, Vec<String>)>,
    // Warnings surfaced to the caller alongside the diff.
    warnings: Vec<DiffWarning>,
    // Tables already warned about, to dedupe warnings (map used as a set).
    warned_tables: AHashMap<String, ()>,
    // Compiled globs from options.ignore_columns.
    ignore_patterns: Vec<Pattern>,
    // Cache: lowercased table name -> column indices excluded from digests.
    ignore_indices_cache: AHashMap<String, Vec<usize>>,
}
238
239impl DataDiffer {
240 pub fn new(options: DataDiffOptions) -> Self {
242 let ignore_patterns = parse_ignore_patterns(&options.ignore_columns);
243 Self {
244 options,
245 old_state: AHashMap::new(),
246 new_state: AHashMap::new(),
247 total_pk_entries: 0,
248 global_truncated: false,
249 current_copy_context: None,
250 warnings: Vec::new(),
251 warned_tables: AHashMap::new(),
252 ignore_patterns,
253 ignore_indices_cache: AHashMap::new(),
254 }
255 }
256
    /// Resolve which column indices of `table_name` are excluded from row
    /// digests, based on the compiled ignore patterns.
    ///
    /// Results are cached per lowercased table name, so the body below runs
    /// at most once per table. Warns when an ignored column is part of the
    /// primary key, since that can make distinct rows collide.
    fn get_ignore_indices(
        &mut self,
        table_name: &str,
        table_schema: &crate::schema::TableSchema,
    ) -> Vec<usize> {
        let table_lower = table_name.to_lowercase();
        if let Some(indices) = self.ignore_indices_cache.get(&table_lower) {
            return indices.clone();
        }

        // Positions of the primary-key columns within the column list.
        let pk_indices: Vec<usize> = table_schema
            .primary_key
            .iter()
            .map(|col_id| col_id.0 as usize)
            .collect();

        let mut indices: Vec<usize> = Vec::new();
        for (i, col) in table_schema.columns.iter().enumerate() {
            if should_ignore_column(table_name, &col.name, &self.ignore_patterns) {
                // NOTE(review): warned_tables is checked with the lowercased
                // name here but keyed by the original-cased name elsewhere in
                // this file, and nothing is inserted here — confirm the
                // intended dedup scope of this warning.
                if pk_indices.contains(&i) && !self.warned_tables.contains_key(&table_lower) {
                    self.warnings.push(DiffWarning {
                        table: Some(table_name.to_string()),
                        message: format!(
                            "Ignoring primary key column '{}' may affect diff accuracy",
                            col.name
                        ),
                    });
                }
                indices.push(i);
            }
        }

        self.ignore_indices_cache
            .insert(table_lower, indices.clone());
        indices
    }
296
297 fn get_effective_pk_indices(
300 &self,
301 table_name: &str,
302 table_schema: &crate::schema::TableSchema,
303 ) -> (Vec<usize>, bool, Vec<String>) {
304 if let Some(override_cols) = self.options.pk_overrides.get(&table_name.to_lowercase()) {
305 let mut indices: Vec<usize> = Vec::new();
306 let mut invalid_cols: Vec<String> = Vec::new();
307
308 for col_name in override_cols {
309 if let Some(idx) = table_schema
310 .columns
311 .iter()
312 .position(|c| c.name.eq_ignore_ascii_case(col_name))
313 {
314 indices.push(idx);
315 } else {
316 invalid_cols.push(col_name.clone());
317 }
318 }
319
320 (indices, true, invalid_cols)
321 } else {
322 let indices: Vec<usize> = table_schema
323 .primary_key
324 .iter()
325 .map(|col_id| col_id.0 as usize)
326 .collect();
327 (indices, false, Vec::new())
328 }
329 }
330
331 fn extract_pk_from_values(
333 &self,
334 all_values: &[mysql_insert::PkValue],
335 pk_indices: &[usize],
336 ) -> Option<smallvec::SmallVec<[mysql_insert::PkValue; 2]>> {
337 if pk_indices.is_empty() {
338 return None;
339 }
340 let mut pk_values: smallvec::SmallVec<[mysql_insert::PkValue; 2]> =
341 smallvec::SmallVec::new();
342 for &idx in pk_indices {
343 if let Some(val) = all_values.get(idx) {
344 if val.is_null() {
345 return None;
346 }
347 pk_values.push(val.clone());
348 } else {
349 return None;
350 }
351 }
352 if pk_values.is_empty() {
353 None
354 } else {
355 Some(pk_values)
356 }
357 }
358
    /// Stream one dump file, recording per-table row state on the side
    /// selected by `is_old`.
    ///
    /// `byte_offset` / `total_bytes` let the optional progress callback
    /// report a position within a multi-file scan. The file may be
    /// compressed; `Compression::from_path` decides from the extension.
    #[allow(clippy::too_many_arguments)]
    pub fn scan_file(
        &mut self,
        path: &PathBuf,
        schema: &Schema,
        dialect: SqlDialect,
        is_old: bool,
        progress_fn: &Option<Arc<dyn Fn(u64, u64) + Send + Sync>>,
        byte_offset: u64,
        total_bytes: u64,
    ) -> anyhow::Result<()> {
        let file = File::open(path)?;
        let file_size = file.metadata()?.len();
        // Buffer sized from the on-disk (possibly compressed) size.
        let buffer_size = determine_buffer_size(file_size);
        let compression = Compression::from_path(path);

        // The ProgressReader wraps the raw file *before* decompression, so
        // reported offsets correspond to on-disk byte positions.
        let reader: Box<dyn Read> = if let Some(ref cb) = progress_fn {
            let cb = Arc::clone(cb);
            let progress_reader = ProgressReader::new(file, move |bytes| {
                cb(byte_offset + bytes, total_bytes);
            });
            compression.wrap_reader(Box::new(progress_reader))
        } else {
            compression.wrap_reader(Box::new(file))
        };

        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);

        // Any COPY context from a previous file must not leak into this one.
        self.current_copy_context = None;

        while let Some(stmt) = parser.read_statement()? {
            let (stmt_type, table_name) =
                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);

            // Postgres COPY data arrives as an "Unknown" statement terminated
            // by a `\.` line; pair it with the COPY header captured below.
            if dialect == SqlDialect::Postgres && stmt_type == StatementType::Unknown {
                if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                    if let Some((ref copy_table, ref column_order)) =
                        self.current_copy_context.clone()
                    {
                        if should_include_table(
                            copy_table,
                            &self.options.tables,
                            &self.options.exclude,
                        ) {
                            if let Some(table_schema) = schema.get_table(copy_table) {
                                let has_pk = !table_schema.primary_key.is_empty();
                                let has_pk_override = self
                                    .options
                                    .pk_overrides
                                    .contains_key(&copy_table.to_lowercase());
                                if has_pk || self.options.allow_no_pk || has_pk_override {
                                    self.process_copy_data(
                                        &stmt,
                                        copy_table,
                                        table_schema,
                                        column_order.clone(),
                                        is_old,
                                    )?;
                                } else if !self.warned_tables.contains_key(copy_table) {
                                    // Warn once per table, then skip its data.
                                    self.warned_tables.insert(copy_table.clone(), ());
                                    self.warnings.push(DiffWarning {
                                        table: Some(copy_table.clone()),
                                        message: "No primary key, data comparison skipped"
                                            .to_string(),
                                    });
                                }
                            }
                        }
                    }
                }
                // The COPY payload (or unrelated statement) is consumed either way.
                self.current_copy_context = None;
                continue;
            }

            if table_name.is_empty() {
                continue;
            }

            if !should_include_table(&table_name, &self.options.tables, &self.options.exclude) {
                continue;
            }

            // Rows of tables absent from the parsed schema cannot be keyed.
            let table_schema = match schema.get_table(&table_name) {
                Some(t) => t,
                None => continue,
            };

            let has_pk_override = self
                .options
                .pk_overrides
                .contains_key(&table_name.to_lowercase());
            if table_schema.primary_key.is_empty() && !self.options.allow_no_pk && !has_pk_override
            {
                // No usable key and no override: warn once and skip the table.
                if !self.warned_tables.contains_key(&table_name) {
                    self.warned_tables.insert(table_name.clone(), ());
                    self.warnings.push(DiffWarning {
                        table: Some(table_name.clone()),
                        message: "No primary key, data comparison skipped".to_string(),
                    });
                }
                continue;
            }

            match stmt_type {
                StatementType::Insert => {
                    self.process_insert_statement(&stmt, &table_name, table_schema, is_old)?;
                }
                StatementType::Copy => {
                    // Remember the COPY header; its data payload follows as a
                    // separate statement handled at the top of this loop.
                    let header = String::from_utf8_lossy(&stmt);
                    let column_order = postgres_copy::parse_copy_columns(&header);
                    self.current_copy_context = Some((table_name.clone(), column_order));
                }
                _ => {}
            }
        }

        Ok(())
    }
489
490 fn process_insert_statement(
492 &mut self,
493 stmt: &[u8],
494 table_name: &str,
495 table_schema: &crate::schema::TableSchema,
496 is_old: bool,
497 ) -> anyhow::Result<()> {
498 let rows = mysql_insert::parse_mysql_insert_rows(stmt, table_schema)?;
499
500 let (pk_indices, has_override, invalid_cols) =
501 self.get_effective_pk_indices(table_name, table_schema);
502
503 let ignore_indices = self.get_ignore_indices(table_name, table_schema);
505
506 if !invalid_cols.is_empty() && !self.warned_tables.contains_key(table_name) {
508 self.warned_tables.insert(table_name.to_string(), ());
509 self.warnings.push(DiffWarning {
510 table: Some(table_name.to_string()),
511 message: format!(
512 "Primary key override column(s) not found: {}",
513 invalid_cols.join(", ")
514 ),
515 });
516 }
517
518 for row in rows {
519 let effective_pk = if has_override {
520 self.extract_pk_from_values(&row.all_values, &pk_indices)
521 } else {
522 row.pk
523 };
524 self.record_row(table_name, &effective_pk, &row.all_values, is_old, &ignore_indices);
525 }
526
527 Ok(())
528 }
529
530 fn process_copy_data(
532 &mut self,
533 data_stmt: &[u8],
534 table_name: &str,
535 table_schema: &crate::schema::TableSchema,
536 column_order: Vec<String>,
537 is_old: bool,
538 ) -> anyhow::Result<()> {
539 let data = data_stmt
542 .iter()
543 .skip_while(|&&b| b == b'\n' || b == b'\r' || b == b' ' || b == b'\t')
544 .cloned()
545 .collect::<Vec<u8>>();
546
547 if data.is_empty() {
548 return Ok(());
549 }
550
551 let rows = postgres_copy::parse_postgres_copy_rows(&data, table_schema, column_order)?;
552
553 let (pk_indices, has_override, invalid_cols) =
554 self.get_effective_pk_indices(table_name, table_schema);
555
556 let ignore_indices = self.get_ignore_indices(table_name, table_schema);
558
559 if !invalid_cols.is_empty() && !self.warned_tables.contains_key(table_name) {
561 self.warned_tables.insert(table_name.to_string(), ());
562 self.warnings.push(DiffWarning {
563 table: Some(table_name.to_string()),
564 message: format!(
565 "Primary key override column(s) not found: {}",
566 invalid_cols.join(", ")
567 ),
568 });
569 }
570
571 for row in rows {
572 let effective_pk = if has_override {
573 self.extract_pk_from_values(&row.all_values, &pk_indices)
574 } else {
575 row.pk
576 };
577 self.record_row(table_name, &effective_pk, &row.all_values, is_old, &ignore_indices);
578 }
579
580 Ok(())
581 }
582
    /// Record one data row for `table_name` on the old or new side.
    ///
    /// Bumps the table's row count, and — unless a truncation limit has been
    /// hit — stores a PK-hash -> row-digest entry (plus a printable PK string
    /// when sampling is enabled). Rows without a usable PK only bump the
    /// row count.
    fn record_row(
        &mut self,
        table_name: &str,
        pk: &Option<smallvec::SmallVec<[mysql_insert::PkValue; 2]>>,
        all_values: &[mysql_insert::PkValue],
        is_old: bool,
        ignore_indices: &[usize],
    ) {
        // After global truncation, only row counts are maintained; states for
        // tables first seen now are created already truncated with no PK map.
        if self.global_truncated {
            let state = if is_old {
                self.old_state
                    .entry(table_name.to_string())
                    .or_insert_with(|| {
                        let mut s = TableState::new();
                        s.pk_to_digest = None;
                        s.truncated = true;
                        s
                    })
            } else {
                self.new_state
                    .entry(table_name.to_string())
                    .or_insert_with(|| {
                        let mut s = TableState::new();
                        s.pk_to_digest = None;
                        s.truncated = true;
                        s
                    })
            };
            state.row_count += 1;
            return;
        }

        let sample_size = self.options.sample_size;
        let state = if is_old {
            self.old_state
                .entry(table_name.to_string())
                .or_insert_with(|| {
                    if sample_size > 0 {
                        TableState::new_with_pk_strings()
                    } else {
                        TableState::new()
                    }
                })
        } else {
            self.new_state
                .entry(table_name.to_string())
                .or_insert_with(|| {
                    if sample_size > 0 {
                        TableState::new_with_pk_strings()
                    } else {
                        TableState::new()
                    }
                })
        };

        state.row_count += 1;

        // Per-table cap reached: drop all PKs tracked for this table and
        // fall back to row-count comparison for it.
        if let Some(ref map) = state.pk_to_digest {
            if map.len() >= self.options.max_pk_entries_per_table {
                state.pk_to_digest = None;
                state.pk_strings = None;
                state.truncated = true;
                return;
            }
        }

        // Global cap reached: stop tracking PKs everywhere from now on.
        if self.total_pk_entries >= self.options.max_pk_entries_global {
            self.global_truncated = true;
            state.pk_to_digest = None;
            state.pk_strings = None;
            state.truncated = true;
            return;
        }

        if let Some(ref pk_values) = pk {
            if let Some(ref mut map) = state.pk_to_digest {
                let pk_hash = hash_pk_values(pk_values);
                let row_digest = if ignore_indices.is_empty() {
                    hash_row_digest(all_values)
                } else {
                    hash_row_digest_with_ignore(all_values, ignore_indices)
                };
                map.insert(pk_hash, row_digest);
                self.total_pk_entries += 1;

                // Keep a printable PK so samples can be reported later.
                if let Some(ref mut pk_str_map) = state.pk_strings {
                    pk_str_map.insert(pk_hash, format_pk_value(pk_values));
                }
            }
        }
    }
680
681 pub fn compute_diff(self) -> (DataDiff, Vec<DiffWarning>) {
683 let mut tables: HashMap<String, TableDataDiff> = HashMap::new();
684 let sample_size = self.options.sample_size;
685
686 let mut all_tables: Vec<String> = self.old_state.keys().cloned().collect();
688 for name in self.new_state.keys() {
689 if !all_tables.contains(name) {
690 all_tables.push(name.clone());
691 }
692 }
693
694 for table_name in all_tables {
695 let old_state = self.old_state.get(&table_name);
696 let new_state = self.new_state.get(&table_name);
697
698 let mut diff = TableDataDiff {
699 old_row_count: old_state.map(|s| s.row_count).unwrap_or(0),
700 new_row_count: new_state.map(|s| s.row_count).unwrap_or(0),
701 truncated: old_state.map(|s| s.truncated).unwrap_or(false)
702 || new_state.map(|s| s.truncated).unwrap_or(false)
703 || self.global_truncated,
704 ..Default::default()
705 };
706
707 let old_map = old_state.and_then(|s| s.pk_to_digest.as_ref());
709 let new_map = new_state.and_then(|s| s.pk_to_digest.as_ref());
710
711 let old_pk_strings = old_state.and_then(|s| s.pk_strings.as_ref());
713 let new_pk_strings = new_state.and_then(|s| s.pk_strings.as_ref());
714
715 match (old_map, new_map) {
716 (Some(old), Some(new)) => {
717 for pk_hash in new.keys() {
719 if !old.contains_key(pk_hash) {
720 diff.added_count += 1;
721
722 if sample_size > 0 && diff.sample_added_pks.len() < sample_size {
724 if let Some(pk_str) = new_pk_strings.and_then(|m| m.get(pk_hash)) {
725 diff.sample_added_pks.push(pk_str.clone());
726 }
727 }
728 }
729 }
730
731 for (pk_hash, old_digest) in old {
733 match new.get(pk_hash) {
734 None => {
735 diff.removed_count += 1;
736
737 if sample_size > 0 && diff.sample_removed_pks.len() < sample_size {
739 if let Some(pk_str) =
740 old_pk_strings.and_then(|m| m.get(pk_hash))
741 {
742 diff.sample_removed_pks.push(pk_str.clone());
743 }
744 }
745 }
746 Some(new_digest) => {
747 if old_digest != new_digest {
748 diff.modified_count += 1;
749
750 if sample_size > 0
752 && diff.sample_modified_pks.len() < sample_size
753 {
754 if let Some(pk_str) =
755 old_pk_strings.and_then(|m| m.get(pk_hash))
756 {
757 diff.sample_modified_pks.push(pk_str.clone());
758 }
759 }
760 }
761 }
762 }
763 }
764 }
765 _ => {
766 if diff.new_row_count > diff.old_row_count {
768 diff.added_count = diff.new_row_count - diff.old_row_count;
769 } else if diff.old_row_count > diff.new_row_count {
770 diff.removed_count = diff.old_row_count - diff.new_row_count;
771 }
772 }
773 }
774
775 if diff.old_row_count > 0
777 || diff.new_row_count > 0
778 || diff.added_count > 0
779 || diff.removed_count > 0
780 || diff.modified_count > 0
781 {
782 tables.insert(table_name, diff);
783 }
784 }
785
786 (DataDiff { tables }, self.warnings)
787 }
788}