1use super::{should_include_table, DiffWarning};
7use crate::parser::{
8 determine_buffer_size, mysql_insert, postgres_copy, Parser, SqlDialect, StatementType,
9};
10use crate::pk::{hash_pk_values, PkHash};
11use crate::progress::ProgressReader;
12use crate::schema::Schema;
13use crate::splitter::Compression;
14use ahash::AHashMap;
15use glob::Pattern;
16use serde::Serialize;
17use std::collections::HashMap;
18use std::fs::File;
19use std::hash::{Hash, Hasher};
20use std::io::Read;
21use std::path::PathBuf;
22use std::sync::Arc;
23
/// Configuration for data-level diffing of two SQL dump files.
#[derive(Debug, Clone)]
pub struct DataDiffOptions {
    /// Global cap on primary-key entries held in memory across all tables.
    pub max_pk_entries_global: usize,
    /// Per-table cap on primary-key entries held in memory.
    pub max_pk_entries_per_table: usize,
    /// Number of sample PK strings to collect per change category (0 disables sampling).
    pub sample_size: usize,
    /// Table names/patterns to include (empty = all tables).
    pub tables: Vec<String>,
    /// Table names/patterns to exclude.
    pub exclude: Vec<String>,
    /// When true, tables without a primary key are still row-counted instead of skipped.
    pub allow_no_pk: bool,
    /// Lowercase table name -> column names to use as the primary key instead of the schema's.
    pub pk_overrides: std::collections::HashMap<String, Vec<String>>,
    /// Glob patterns (matched against lowercase "table.column") for columns
    /// excluded from row digests.
    pub ignore_columns: Vec<String>,
}
44
45impl Default for DataDiffOptions {
46 fn default() -> Self {
47 Self {
48 max_pk_entries_global: 10_000_000,
49 max_pk_entries_per_table: 5_000_000,
50 sample_size: 0,
51 tables: Vec::new(),
52 exclude: Vec::new(),
53 allow_no_pk: false,
54 pk_overrides: std::collections::HashMap::new(),
55 ignore_columns: Vec::new(),
56 }
57 }
58}
59
/// Result of a data-level diff: one summary per table that had any rows or changes.
#[derive(Debug, Serialize)]
pub struct DataDiff {
    /// Table name -> summary of row-level changes.
    pub tables: HashMap<String, TableDataDiff>,
}
66
/// Per-table summary of row-level differences between the old and new dumps.
#[derive(Debug, Serialize, Clone, Default)]
pub struct TableDataDiff {
    /// Rows seen in the old dump.
    pub old_row_count: u64,
    /// Rows seen in the new dump.
    pub new_row_count: u64,
    /// Rows present only in the new dump (or the count delta when truncated).
    pub added_count: u64,
    /// Rows present only in the old dump (or the count delta when truncated).
    pub removed_count: u64,
    /// Rows present on both sides whose digests differ (0 when truncated).
    pub modified_count: u64,
    /// True when PK tracking was abandoned due to entry limits; counts above
    /// then degrade to row-count deltas.
    pub truncated: bool,
    /// Sample PKs of added rows (populated only when sampling is enabled).
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub sample_added_pks: Vec<String>,
    /// Sample PKs of removed rows.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub sample_removed_pks: Vec<String>,
    /// Sample PKs of modified rows.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub sample_modified_pks: Vec<String>,
}
92
/// Accumulated scan state for one table on one side (old or new) of the diff.
struct TableState {
    /// Total rows seen; still incremented after truncation.
    row_count: u64,
    /// PK hash -> full-row digest; `None` once tracking was abandoned (truncated).
    pk_to_digest: Option<AHashMap<PkHash, u64>>,
    /// PK hash -> printable PK, kept only when sampling is enabled.
    pk_strings: Option<AHashMap<PkHash, String>>,
    /// True when the per-table or global entry limit was exceeded.
    truncated: bool,
}
105
106impl TableState {
107 fn new() -> Self {
108 Self {
109 row_count: 0,
110 pk_to_digest: Some(AHashMap::new()),
111 pk_strings: None,
112 truncated: false,
113 }
114 }
115
116 fn new_with_pk_strings() -> Self {
117 Self {
118 row_count: 0,
119 pk_to_digest: Some(AHashMap::new()),
120 pk_strings: Some(AHashMap::new()),
121 truncated: false,
122 }
123 }
124}
125
126fn hash_row_digest(values: &[mysql_insert::PkValue]) -> u64 {
128 let mut hasher = ahash::AHasher::default();
129 for v in values {
130 match v {
131 mysql_insert::PkValue::Int(i) => {
132 0u8.hash(&mut hasher);
133 i.hash(&mut hasher);
134 }
135 mysql_insert::PkValue::BigInt(i) => {
136 1u8.hash(&mut hasher);
137 i.hash(&mut hasher);
138 }
139 mysql_insert::PkValue::Text(s) => {
140 2u8.hash(&mut hasher);
141 s.hash(&mut hasher);
142 }
143 mysql_insert::PkValue::Null => {
144 3u8.hash(&mut hasher);
145 }
146 }
147 }
148 hasher.finish()
149}
150
151fn format_single_pk(v: &mysql_insert::PkValue) -> String {
153 match v {
154 mysql_insert::PkValue::Int(i) => i.to_string(),
155 mysql_insert::PkValue::BigInt(i) => i.to_string(),
156 mysql_insert::PkValue::Text(s) => s.to_string(),
157 mysql_insert::PkValue::Null => "NULL".to_string(),
158 }
159}
160
161fn format_pk_value(pk: &[mysql_insert::PkValue]) -> String {
163 if pk.len() == 1 {
164 format_single_pk(&pk[0])
165 } else {
166 let parts: Vec<String> = pk.iter().map(format_single_pk).collect();
167 format!("({})", parts.join(", "))
168 }
169}
170
171fn parse_ignore_patterns(patterns: &[String]) -> Vec<Pattern> {
173 patterns
174 .iter()
175 .filter_map(|p| Pattern::new(&p.to_lowercase()).ok())
176 .collect()
177}
178
179fn should_ignore_column(table: &str, column: &str, patterns: &[Pattern]) -> bool {
181 let full_name = format!("{}.{}", table.to_lowercase(), column.to_lowercase());
182 patterns.iter().any(|p| p.matches(&full_name))
183}
184
185fn hash_row_digest_with_ignore(values: &[mysql_insert::PkValue], ignore_indices: &[usize]) -> u64 {
187 let mut hasher = ahash::AHasher::default();
188 for (i, v) in values.iter().enumerate() {
189 if ignore_indices.contains(&i) {
190 continue;
191 }
192 match v {
193 mysql_insert::PkValue::Int(val) => {
194 0u8.hash(&mut hasher);
195 val.hash(&mut hasher);
196 }
197 mysql_insert::PkValue::BigInt(val) => {
198 1u8.hash(&mut hasher);
199 val.hash(&mut hasher);
200 }
201 mysql_insert::PkValue::Text(s) => {
202 2u8.hash(&mut hasher);
203 s.hash(&mut hasher);
204 }
205 mysql_insert::PkValue::Null => {
206 3u8.hash(&mut hasher);
207 }
208 }
209 }
210 hasher.finish()
211}
212
/// Streams SQL dump files, accumulating per-table PK/row-digest state for the
/// old and new sides, then computes the data-level diff.
pub struct DataDiffer {
    options: DataDiffOptions,
    /// Per-table state built from the "old" dump.
    old_state: AHashMap<String, TableState>,
    /// Per-table state built from the "new" dump.
    new_state: AHashMap<String, TableState>,
    /// Total PK entries across all tables and both sides, checked against the global cap.
    total_pk_entries: usize,
    /// Set once the global cap is hit; afterwards rows are only counted.
    global_truncated: bool,
    /// (table, column order) captured from a Postgres COPY header, awaiting its data block.
    current_copy_context: Option<(String, Vec<String>)>,
    warnings: Vec<DiffWarning>,
    /// Tables that already received a one-time warning (map used as a set).
    warned_tables: AHashMap<String, ()>,
    /// Compiled `ignore_columns` globs (lowercase "table.column").
    ignore_patterns: Vec<Pattern>,
    /// Cache of ignored column indices per table, keyed by lowercase name.
    ignore_indices_cache: AHashMap<String, Vec<usize>>,
}
235
236impl DataDiffer {
237 pub fn new(options: DataDiffOptions) -> Self {
239 let ignore_patterns = parse_ignore_patterns(&options.ignore_columns);
240 Self {
241 options,
242 old_state: AHashMap::new(),
243 new_state: AHashMap::new(),
244 total_pk_entries: 0,
245 global_truncated: false,
246 current_copy_context: None,
247 warnings: Vec::new(),
248 warned_tables: AHashMap::new(),
249 ignore_patterns,
250 ignore_indices_cache: AHashMap::new(),
251 }
252 }
253
    /// Column indices of `table_name` that match an `ignore_columns` glob,
    /// computed once per table and cached under the lowercase table name.
    ///
    /// Emits a warning the first time an ignored column turns out to be part
    /// of the primary key, since that can make distinct rows compare equal.
    fn get_ignore_indices(
        &mut self,
        table_name: &str,
        table_schema: &crate::schema::TableSchema,
    ) -> Vec<usize> {
        let table_lower = table_name.to_lowercase();
        if let Some(indices) = self.ignore_indices_cache.get(&table_lower) {
            return indices.clone();
        }

        // Primary-key column positions, used only to detect PK/ignore overlap.
        let pk_indices: Vec<usize> = table_schema
            .primary_key
            .iter()
            .map(|col_id| col_id.0 as usize)
            .collect();

        let mut indices: Vec<usize> = Vec::new();
        for (i, col) in table_schema.columns.iter().enumerate() {
            if should_ignore_column(table_name, &col.name, &self.ignore_patterns) {
                // Warning suppressed for tables already warned about elsewhere;
                // the cache insert below keeps this branch one-shot per table.
                if pk_indices.contains(&i) && !self.warned_tables.contains_key(&table_lower) {
                    self.warnings.push(DiffWarning {
                        table: Some(table_name.to_string()),
                        message: format!(
                            "Ignoring primary key column '{}' may affect diff accuracy",
                            col.name
                        ),
                    });
                }
                indices.push(i);
            }
        }

        self.ignore_indices_cache
            .insert(table_lower, indices.clone());
        indices
    }
293
294 fn get_effective_pk_indices(
297 &self,
298 table_name: &str,
299 table_schema: &crate::schema::TableSchema,
300 ) -> (Vec<usize>, bool, Vec<String>) {
301 if let Some(override_cols) = self.options.pk_overrides.get(&table_name.to_lowercase()) {
302 let mut indices: Vec<usize> = Vec::new();
303 let mut invalid_cols: Vec<String> = Vec::new();
304
305 for col_name in override_cols {
306 if let Some(idx) = table_schema
307 .columns
308 .iter()
309 .position(|c| c.name.eq_ignore_ascii_case(col_name))
310 {
311 indices.push(idx);
312 } else {
313 invalid_cols.push(col_name.clone());
314 }
315 }
316
317 (indices, true, invalid_cols)
318 } else {
319 let indices: Vec<usize> = table_schema
320 .primary_key
321 .iter()
322 .map(|col_id| col_id.0 as usize)
323 .collect();
324 (indices, false, Vec::new())
325 }
326 }
327
328 fn extract_pk_from_values(
330 &self,
331 all_values: &[mysql_insert::PkValue],
332 pk_indices: &[usize],
333 ) -> Option<smallvec::SmallVec<[mysql_insert::PkValue; 2]>> {
334 if pk_indices.is_empty() {
335 return None;
336 }
337 let mut pk_values: smallvec::SmallVec<[mysql_insert::PkValue; 2]> =
338 smallvec::SmallVec::new();
339 for &idx in pk_indices {
340 if let Some(val) = all_values.get(idx) {
341 if val.is_null() {
342 return None;
343 }
344 pk_values.push(val.clone());
345 } else {
346 return None;
347 }
348 }
349 if pk_values.is_empty() {
350 None
351 } else {
352 Some(pk_values)
353 }
354 }
355
    /// Stream one dump file, feeding every INSERT / COPY statement for the
    /// selected tables into the old- or new-side state maps.
    ///
    /// `is_old` selects which side of the diff this file populates.
    /// `byte_offset` / `total_bytes` let the optional `progress_fn` report
    /// progress across a multi-file scan.
    #[allow(clippy::too_many_arguments)]
    pub fn scan_file(
        &mut self,
        path: &PathBuf,
        schema: &Schema,
        dialect: SqlDialect,
        is_old: bool,
        progress_fn: &Option<Arc<dyn Fn(u64, u64) + Send + Sync>>,
        byte_offset: u64,
        total_bytes: u64,
    ) -> anyhow::Result<()> {
        let file = File::open(path)?;
        let file_size = file.metadata()?.len();
        let buffer_size = determine_buffer_size(file_size);
        let compression = Compression::from_path(path);

        // Optionally wrap the file in a progress-reporting reader, then in a
        // decompressor chosen from the file extension.
        let reader: Box<dyn Read> = if let Some(ref cb) = progress_fn {
            let cb = Arc::clone(cb);
            let progress_reader = ProgressReader::new(file, move |bytes| {
                cb(byte_offset + bytes, total_bytes);
            });
            compression.wrap_reader(Box::new(progress_reader))?
        } else {
            compression.wrap_reader(Box::new(file))?
        };

        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);

        // Postgres COPY data arrives as a separate statement after the COPY
        // header; this holds (table, column order) between the two.
        self.current_copy_context = None;

        while let Some(stmt) = parser.read_statement()? {
            let (stmt_type, table_name) =
                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);

            // A Postgres COPY payload is reported as Unknown; recognize it by
            // the trailing `\.` terminator and process it with the context
            // captured from the preceding COPY header.
            if dialect == SqlDialect::Postgres && stmt_type == StatementType::Unknown {
                if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                    if let Some((ref copy_table, ref column_order)) =
                        self.current_copy_context.clone()
                    {
                        if should_include_table(
                            copy_table,
                            &self.options.tables,
                            &self.options.exclude,
                        ) {
                            if let Some(table_schema) = schema.get_table(copy_table) {
                                let has_pk = !table_schema.primary_key.is_empty();
                                let has_pk_override = self
                                    .options
                                    .pk_overrides
                                    .contains_key(&copy_table.to_lowercase());
                                if has_pk || self.options.allow_no_pk || has_pk_override {
                                    self.process_copy_data(
                                        &stmt,
                                        copy_table,
                                        table_schema,
                                        column_order.clone(),
                                        is_old,
                                    )?;
                                } else if !self.warned_tables.contains_key(copy_table) {
                                    // Warn once per table that cannot be compared.
                                    self.warned_tables.insert(copy_table.clone(), ());
                                    self.warnings.push(DiffWarning {
                                        table: Some(copy_table.clone()),
                                        message: "No primary key, data comparison skipped"
                                            .to_string(),
                                    });
                                }
                            }
                        }
                    }
                }
                self.current_copy_context = None;
                continue;
            }

            if table_name.is_empty() {
                continue;
            }

            if !should_include_table(&table_name, &self.options.tables, &self.options.exclude) {
                continue;
            }

            // Statements for tables absent from the parsed schema are skipped.
            let table_schema = match schema.get_table(&table_name) {
                Some(t) => t,
                None => continue,
            };

            let has_pk_override = self
                .options
                .pk_overrides
                .contains_key(&table_name.to_lowercase());
            if table_schema.primary_key.is_empty() && !self.options.allow_no_pk && !has_pk_override
            {
                // Without a PK (or an override) rows cannot be matched between
                // dumps; warn once per table and skip.
                if !self.warned_tables.contains_key(&table_name) {
                    self.warned_tables.insert(table_name.clone(), ());
                    self.warnings.push(DiffWarning {
                        table: Some(table_name.clone()),
                        message: "No primary key, data comparison skipped".to_string(),
                    });
                }
                continue;
            }

            match stmt_type {
                StatementType::Insert => {
                    self.process_insert_statement(&stmt, &table_name, table_schema, is_old)?;
                }
                StatementType::Copy => {
                    // Remember the COPY header's table and column order; the
                    // data block follows as the next statement.
                    let header = String::from_utf8_lossy(&stmt);
                    let column_order = postgres_copy::parse_copy_columns(&header);
                    self.current_copy_context = Some((table_name.clone(), column_order));
                }
                _ => {}
            }
        }

        Ok(())
    }
486
    /// Parse a MySQL-style INSERT statement and record each of its rows on
    /// the appropriate side (`is_old`) of the diff.
    fn process_insert_statement(
        &mut self,
        stmt: &[u8],
        table_name: &str,
        table_schema: &crate::schema::TableSchema,
        is_old: bool,
    ) -> anyhow::Result<()> {
        let rows = mysql_insert::parse_mysql_insert_rows(stmt, table_schema)?;

        let (pk_indices, has_override, invalid_cols) =
            self.get_effective_pk_indices(table_name, table_schema);

        let ignore_indices = self.get_ignore_indices(table_name, table_schema);

        // Warn once if the user's PK override names columns missing from the schema.
        if !invalid_cols.is_empty() && !self.warned_tables.contains_key(table_name) {
            self.warned_tables.insert(table_name.to_string(), ());
            self.warnings.push(DiffWarning {
                table: Some(table_name.to_string()),
                message: format!(
                    "Primary key override column(s) not found: {}",
                    invalid_cols.join(", ")
                ),
            });
        }

        for row in rows {
            // With an override, re-extract the PK from the full value list;
            // otherwise use the PK the parser already pulled out.
            let effective_pk = if has_override {
                self.extract_pk_from_values(&row.all_values, &pk_indices)
            } else {
                row.pk
            };
            self.record_row(
                table_name,
                &effective_pk,
                &row.all_values,
                is_old,
                &ignore_indices,
            );
        }

        Ok(())
    }
532
533 fn process_copy_data(
535 &mut self,
536 data_stmt: &[u8],
537 table_name: &str,
538 table_schema: &crate::schema::TableSchema,
539 column_order: Vec<String>,
540 is_old: bool,
541 ) -> anyhow::Result<()> {
542 let data = data_stmt
545 .iter()
546 .skip_while(|&&b| b == b'\n' || b == b'\r' || b == b' ' || b == b'\t')
547 .cloned()
548 .collect::<Vec<u8>>();
549
550 if data.is_empty() {
551 return Ok(());
552 }
553
554 let rows = postgres_copy::parse_postgres_copy_rows(&data, table_schema, column_order)?;
555
556 let (pk_indices, has_override, invalid_cols) =
557 self.get_effective_pk_indices(table_name, table_schema);
558
559 let ignore_indices = self.get_ignore_indices(table_name, table_schema);
561
562 if !invalid_cols.is_empty() && !self.warned_tables.contains_key(table_name) {
564 self.warned_tables.insert(table_name.to_string(), ());
565 self.warnings.push(DiffWarning {
566 table: Some(table_name.to_string()),
567 message: format!(
568 "Primary key override column(s) not found: {}",
569 invalid_cols.join(", ")
570 ),
571 });
572 }
573
574 for row in rows {
575 let effective_pk = if has_override {
576 self.extract_pk_from_values(&row.all_values, &pk_indices)
577 } else {
578 row.pk
579 };
580 self.record_row(
581 table_name,
582 &effective_pk,
583 &row.all_values,
584 is_old,
585 &ignore_indices,
586 );
587 }
588
589 Ok(())
590 }
591
    /// Record a single row in the old- or new-side state for `table_name`.
    ///
    /// Operates in three regimes:
    /// 1. globally truncated — only row counts are maintained;
    /// 2. a per-table or global cap is hit — tracking maps are dropped and
    ///    the table falls back to count-only mode;
    /// 3. normal — PK hash -> row digest (plus an optional printable PK when
    ///    sampling is enabled).
    fn record_row(
        &mut self,
        table_name: &str,
        pk: &Option<smallvec::SmallVec<[mysql_insert::PkValue; 2]>>,
        all_values: &[mysql_insert::PkValue],
        is_old: bool,
        ignore_indices: &[usize],
    ) {
        if self.global_truncated {
            // Past the global cap: tables first seen now start directly in
            // truncated, count-only mode.
            let state = if is_old {
                self.old_state
                    .entry(table_name.to_string())
                    .or_insert_with(|| {
                        let mut s = TableState::new();
                        s.pk_to_digest = None;
                        s.truncated = true;
                        s
                    })
            } else {
                self.new_state
                    .entry(table_name.to_string())
                    .or_insert_with(|| {
                        let mut s = TableState::new();
                        s.pk_to_digest = None;
                        s.truncated = true;
                        s
                    })
            };
            state.row_count += 1;
            return;
        }

        let sample_size = self.options.sample_size;
        // PK display strings are only tracked when sampling is requested.
        let state = if is_old {
            self.old_state
                .entry(table_name.to_string())
                .or_insert_with(|| {
                    if sample_size > 0 {
                        TableState::new_with_pk_strings()
                    } else {
                        TableState::new()
                    }
                })
        } else {
            self.new_state
                .entry(table_name.to_string())
                .or_insert_with(|| {
                    if sample_size > 0 {
                        TableState::new_with_pk_strings()
                    } else {
                        TableState::new()
                    }
                })
        };

        state.row_count += 1;

        // Per-table cap: drop the tracking maps and fall back to counting.
        // The row that trips the cap is counted but not recorded — the map is
        // being discarded anyway.
        if let Some(ref map) = state.pk_to_digest {
            if map.len() >= self.options.max_pk_entries_per_table {
                state.pk_to_digest = None;
                state.pk_strings = None;
                state.truncated = true;
                return;
            }
        }

        // Global cap: flip the whole differ into count-only mode from here on.
        if self.total_pk_entries >= self.options.max_pk_entries_global {
            self.global_truncated = true;
            state.pk_to_digest = None;
            state.pk_strings = None;
            state.truncated = true;
            return;
        }

        // Rows without a usable PK are counted above but never tracked.
        if let Some(ref pk_values) = pk {
            if let Some(ref mut map) = state.pk_to_digest {
                let pk_hash = hash_pk_values(pk_values);
                let row_digest = if ignore_indices.is_empty() {
                    hash_row_digest(all_values)
                } else {
                    hash_row_digest_with_ignore(all_values, ignore_indices)
                };
                map.insert(pk_hash, row_digest);
                self.total_pk_entries += 1;

                if let Some(ref mut pk_str_map) = state.pk_strings {
                    pk_str_map.insert(pk_hash, format_pk_value(pk_values));
                }
            }
        }
    }
689
690 pub fn compute_diff(self) -> (DataDiff, Vec<DiffWarning>) {
692 let mut tables: HashMap<String, TableDataDiff> = HashMap::new();
693 let sample_size = self.options.sample_size;
694
695 let mut all_tables: Vec<String> = self.old_state.keys().cloned().collect();
697 for name in self.new_state.keys() {
698 if !all_tables.contains(name) {
699 all_tables.push(name.clone());
700 }
701 }
702
703 for table_name in all_tables {
704 let old_state = self.old_state.get(&table_name);
705 let new_state = self.new_state.get(&table_name);
706
707 let mut diff = TableDataDiff {
708 old_row_count: old_state.map(|s| s.row_count).unwrap_or(0),
709 new_row_count: new_state.map(|s| s.row_count).unwrap_or(0),
710 truncated: old_state.map(|s| s.truncated).unwrap_or(false)
711 || new_state.map(|s| s.truncated).unwrap_or(false)
712 || self.global_truncated,
713 ..Default::default()
714 };
715
716 let old_map = old_state.and_then(|s| s.pk_to_digest.as_ref());
718 let new_map = new_state.and_then(|s| s.pk_to_digest.as_ref());
719
720 let old_pk_strings = old_state.and_then(|s| s.pk_strings.as_ref());
722 let new_pk_strings = new_state.and_then(|s| s.pk_strings.as_ref());
723
724 match (old_map, new_map) {
725 (Some(old), Some(new)) => {
726 for pk_hash in new.keys() {
728 if !old.contains_key(pk_hash) {
729 diff.added_count += 1;
730
731 if sample_size > 0 && diff.sample_added_pks.len() < sample_size {
733 if let Some(pk_str) = new_pk_strings.and_then(|m| m.get(pk_hash)) {
734 diff.sample_added_pks.push(pk_str.clone());
735 }
736 }
737 }
738 }
739
740 for (pk_hash, old_digest) in old {
742 match new.get(pk_hash) {
743 None => {
744 diff.removed_count += 1;
745
746 if sample_size > 0 && diff.sample_removed_pks.len() < sample_size {
748 if let Some(pk_str) =
749 old_pk_strings.and_then(|m| m.get(pk_hash))
750 {
751 diff.sample_removed_pks.push(pk_str.clone());
752 }
753 }
754 }
755 Some(new_digest) => {
756 if old_digest != new_digest {
757 diff.modified_count += 1;
758
759 if sample_size > 0
761 && diff.sample_modified_pks.len() < sample_size
762 {
763 if let Some(pk_str) =
764 old_pk_strings.and_then(|m| m.get(pk_hash))
765 {
766 diff.sample_modified_pks.push(pk_str.clone());
767 }
768 }
769 }
770 }
771 }
772 }
773 }
774 _ => {
775 if diff.new_row_count > diff.old_row_count {
777 diff.added_count = diff.new_row_count - diff.old_row_count;
778 } else if diff.old_row_count > diff.new_row_count {
779 diff.removed_count = diff.old_row_count - diff.new_row_count;
780 }
781 }
782 }
783
784 if diff.old_row_count > 0
786 || diff.new_row_count > 0
787 || diff.added_count > 0
788 || diff.removed_count > 0
789 || diff.modified_count > 0
790 {
791 tables.insert(table_name, diff);
792 }
793 }
794
795 (DataDiff { tables }, self.warnings)
796 }
797}