1mod config;
12
13pub use config::{
14 DefaultShardClassifier, GlobalTableMode, ShardTableClassification, ShardYamlConfig,
15};
16
17use crate::parser::mysql_insert::{parse_mysql_insert_rows, ParsedRow, PkSet, PkTuple, PkValue};
18use crate::parser::postgres_copy::{parse_copy_columns, parse_postgres_copy_rows, ParsedCopyRow};
19use crate::parser::{ContentFilter, Parser, SqlDialect, StatementType};
20use crate::schema::{SchemaBuilder, SchemaGraph, TableId, TableSchema};
21use crate::splitter::Splitter;
22use ahash::{AHashMap, AHashSet};
23use indicatif::{ProgressBar, ProgressStyle};
24use std::fs::{self, File};
25use std::io::{BufWriter, Write};
26use std::path::{Path, PathBuf};
27use tempfile::TempDir;
28
/// Settings for a single shard-extraction run (see `run`).
#[derive(Debug)]
pub struct ShardConfig {
    /// Path of the SQL dump to read; handed to the splitter.
    pub input: PathBuf,
    /// Destination file for the shard. `None` writes the result to stdout.
    pub output: Option<PathBuf>,
    /// SQL dialect used for parsing the dump and for formatting the output.
    pub dialect: SqlDialect,
    /// Explicit tenant column name. When `None`, the YAML config and then
    /// auto-detection are consulted.
    pub tenant_column: Option<String>,
    /// Tenant value to extract, as text; parsed into an integer when possible.
    pub tenant_value: String,
    /// Table names that are treated as tenant roots regardless of their columns.
    pub root_tables: Vec<String>,
    /// Controls whether lookup and unclassified tables are copied in full.
    pub include_global: GlobalTableMode,
    /// When `true`, statistics are computed but no output is written.
    pub dry_run: bool,
    /// When `true`, a progress spinner with status messages is created.
    pub progress: bool,
    /// Optional YAML configuration file loaded at the start of the run.
    pub config_file: Option<PathBuf>,
    /// Upper bound on the total number of rows selected across all tables.
    pub max_selected_rows: Option<usize>,
    /// NOTE(review): not read anywhere in the code visible here; confirm
    /// where strict FK handling is enforced.
    pub strict_fk: bool,
    /// When `true`, schema statements of exported tables are written before the data.
    pub include_schema: bool,
}
59
impl Default for ShardConfig {
    /// Defaults: MySQL dialect, empty input path and tenant value, output to
    /// stdout, lookup tables included, schema included, no progress display,
    /// and a cap of ten million selected rows.
    fn default() -> Self {
        Self {
            input: PathBuf::new(),
            output: None,
            dialect: SqlDialect::MySql,
            tenant_column: None,
            tenant_value: String::new(),
            root_tables: Vec::new(),
            include_global: GlobalTableMode::Lookups,
            dry_run: false,
            progress: false,
            config_file: None,
            // Guard rail against accumulating an unbounded number of rows in memory.
            max_selected_rows: Some(10_000_000),
            strict_fk: false,
            include_schema: true,
        }
    }
}
79
/// Summary of a shard run, returned by `run` and echoed in the output header.
#[derive(Debug, Default)]
pub struct ShardStats {
    /// Number of entries in `table_stats`.
    pub tables_processed: usize,
    /// Tables excluded by configuration or by their classification.
    pub tables_skipped: usize,
    /// Tables for which at least one row was selected.
    pub tables_with_data: usize,
    /// Sum of `rows_selected` over `table_stats`.
    pub total_rows_selected: u64,
    /// Sum of `rows_seen` over `table_stats`.
    pub total_rows_seen: u64,
    /// Per-table counters, in processing order.
    pub table_stats: Vec<TableShardStats>,
    /// Human-readable warnings (FK cycles, row limit reached).
    pub warnings: Vec<String>,
    /// Rows of tenant-dependent tables that were rejected by the row filter.
    pub fk_orphans_skipped: u64,
    /// Tenant column actually used for the run (explicit, from YAML, or detected).
    pub detected_tenant_column: Option<String>,
}
102
/// Row counters for one processed table.
#[derive(Debug, Clone)]
pub struct TableShardStats {
    /// Table name.
    pub name: String,
    /// Rows parsed from the table's data statements.
    pub rows_seen: u64,
    /// Rows kept for the shard.
    pub rows_selected: u64,
    /// Classification that decided how the table was filtered.
    pub classification: ShardTableClassification,
}
111
/// Mutable per-table state accumulated while `run` walks the tables.
struct TableRuntime {
    /// Table name; also used to locate the per-table `.sql` file.
    name: String,
    /// Rows kept for the shard, in the order they were read.
    selected_rows: Vec<SelectedRow>,
    /// Primary keys of the selected rows; consulted when filtering rows of
    /// tables that reference this one.
    pk_set: PkSet,
    /// Rows parsed from the table's data statements.
    rows_seen: u64,
    /// `true` when the table is excluded from processing and output.
    skip: bool,
    /// Classification assigned to the table.
    classification: ShardTableClassification,
    /// Rejected rows counted for tenant-dependent tables.
    fk_orphans: u64,
    /// Position of the tenant column; only computed for tenant-root tables.
    tenant_column_index: Option<usize>,
}
131
/// Source syntax of a selected row's raw bytes.
#[derive(Debug, Clone, Copy, PartialEq)]
enum RowFormat {
    /// Row came from an `INSERT` statement.
    Insert,
    /// Row came from a PostgreSQL `COPY` data block.
    Copy,
}
138
/// A row kept for the shard, stored as its raw bytes plus their source format.
struct SelectedRow {
    /// Raw row bytes as produced by the row parser.
    raw: Vec<u8>,
    /// Tells the writer how `raw` must be converted on output.
    format: RowFormat,
}
144
/// A parsed row from either supported data syntax, so that filtering code can
/// treat both uniformly.
enum UnifiedRow {
    /// Row parsed from an `INSERT` statement.
    Insert(ParsedRow),
    /// Row parsed from a `COPY` data block.
    Copy(ParsedCopyRow),
}
150
151impl UnifiedRow {
152 fn pk(&self) -> Option<&PkTuple> {
153 match self {
154 UnifiedRow::Insert(r) => r.pk.as_ref(),
155 UnifiedRow::Copy(r) => r.pk.as_ref(),
156 }
157 }
158
159 fn fk_values(&self) -> &[(crate::parser::mysql_insert::FkRef, PkTuple)] {
160 match self {
161 UnifiedRow::Insert(r) => &r.fk_values,
162 UnifiedRow::Copy(r) => &r.fk_values,
163 }
164 }
165
166 fn into_selected(self) -> SelectedRow {
167 match self {
168 UnifiedRow::Insert(r) => SelectedRow {
169 raw: r.raw,
170 format: RowFormat::Insert,
171 },
172 UnifiedRow::Copy(r) => SelectedRow {
173 raw: r.raw,
174 format: RowFormat::Copy,
175 },
176 }
177 }
178}
179
180pub fn run(config: ShardConfig) -> anyhow::Result<ShardStats> {
182 let yaml_config = if let Some(ref path) = config.config_file {
183 Some(ShardYamlConfig::load(path)?)
184 } else {
185 None
186 };
187
188 let mut stats = ShardStats::default();
189
190 let progress_bar = if config.progress {
191 let pb = ProgressBar::new_spinner();
192 pb.set_style(
193 ProgressStyle::default_spinner()
194 .template("{spinner:.green} {msg}")
195 .unwrap(),
196 );
197 Some(pb)
198 } else {
199 None
200 };
201
202 let temp_dir = TempDir::new()?;
204 let tables_dir = temp_dir.path().join("tables");
205
206 if let Some(ref pb) = progress_bar {
207 pb.set_message("Splitting dump into per-table files...");
208 }
209
210 let splitter = Splitter::new(config.input.clone(), tables_dir.clone())
211 .with_dialect(config.dialect)
212 .with_content_filter(ContentFilter::All);
213
214 let split_stats = splitter.split()?;
215
216 if let Some(ref pb) = progress_bar {
217 pb.set_message(format!(
218 "Split complete: {} tables, {} statements",
219 split_stats.tables_found, split_stats.statements_processed
220 ));
221 }
222
223 if let Some(ref pb) = progress_bar {
225 pb.set_message("Building schema graph...");
226 }
227
228 let graph = build_schema_graph(&tables_dir, &config)?;
229
230 let tenant_column = detect_tenant_column(&config, &yaml_config, &graph)?;
232 stats.detected_tenant_column = Some(tenant_column.clone());
233
234 if let Some(ref pb) = progress_bar {
235 pb.set_message(format!("Using tenant column: {}", tenant_column));
236 }
237
238 let tenant_pk_value = parse_tenant_value(&config.tenant_value);
240
241 let (topo_order, cyclic_tables) = graph.processing_order();
243
244 if !cyclic_tables.is_empty() {
245 let names: Vec<_> = cyclic_tables
246 .iter()
247 .filter_map(|&id| graph.table_name(id))
248 .collect();
249 stats.warnings.push(format!(
250 "{} tables have FK cycles (relaxed FK enforcement): {:?}",
251 cyclic_tables.len(),
252 names
253 ));
254 }
255
256 let cyclic_set: AHashSet<TableId> = cyclic_tables.iter().copied().collect();
257
258 let tenant_root_ids = find_tenant_root_tables(&graph, &tenant_column, &config, &yaml_config);
260
261 let reachable_from_roots = compute_reachable_tables(&graph, &tenant_root_ids);
263
264 let mut runtimes: AHashMap<TableId, TableRuntime> = AHashMap::new();
266 for table in graph.schema.iter() {
267 let classification = classify_table(
268 &table.name,
269 table.id,
270 &graph,
271 &tenant_root_ids,
272 &reachable_from_roots,
273 &yaml_config,
274 );
275
276 let tenant_column_index = if classification == ShardTableClassification::TenantRoot {
277 find_tenant_column_index(table, &tenant_column)
278 } else {
279 None
280 };
281
282 let skip = should_skip_table(&table.name, classification, &config, &yaml_config);
283
284 runtimes.insert(
285 table.id,
286 TableRuntime {
287 name: table.name.clone(),
288 selected_rows: Vec::new(),
289 pk_set: PkSet::default(),
290 rows_seen: 0,
291 skip,
292 classification,
293 fk_orphans: 0,
294 tenant_column_index,
295 },
296 );
297 }
298
299 if let Some(ref pb) = progress_bar {
301 pb.set_message(format!(
302 "Processing {} tables for tenant {}...",
303 topo_order.len() + cyclic_tables.len(),
304 config.tenant_value
305 ));
306 }
307
308 let all_tables: Vec<TableId> = topo_order.into_iter().chain(cyclic_tables).collect();
309 let mut total_selected: u64 = 0;
310
311 for &table_id in &all_tables {
312 let table_schema = match graph.schema.table(table_id) {
313 Some(s) => s,
314 None => continue,
315 };
316
317 let (should_skip, table_name, classification, tenant_col_idx) = {
318 let runtime = match runtimes.get(&table_id) {
319 Some(r) => r,
320 None => continue,
321 };
322 (
323 runtime.skip,
324 runtime.name.clone(),
325 runtime.classification,
326 runtime.tenant_column_index,
327 )
328 };
329
330 if should_skip {
331 stats.tables_skipped += 1;
332 continue;
333 }
334
335 let include_all = match classification {
337 ShardTableClassification::Lookup => match config.include_global {
338 GlobalTableMode::None => {
339 stats.tables_skipped += 1;
340 continue;
341 }
342 GlobalTableMode::Lookups | GlobalTableMode::All => true,
343 },
344 ShardTableClassification::System => {
345 stats.tables_skipped += 1;
346 continue;
347 }
348 ShardTableClassification::Unknown => match config.include_global {
349 GlobalTableMode::All => true,
350 _ => {
351 stats.tables_skipped += 1;
352 continue;
353 }
354 },
355 _ => false,
356 };
357
358 let table_file = tables_dir.join(format!("{}.sql", table_name));
359 if !table_file.exists() {
360 continue;
361 }
362
363 let file = File::open(&table_file)?;
364 let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
365
366 let mut rows_seen = 0u64;
367 let mut fk_orphans = 0u64;
368 let mut copy_columns: Vec<String> = Vec::new();
369
370 while let Some(stmt) = parser.read_statement()? {
371 let (stmt_type, _) =
372 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
373
374 match stmt_type {
375 StatementType::Insert => {
376 let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
377
378 for row in rows {
379 rows_seen += 1;
380 let unified = UnifiedRow::Insert(row);
381
382 let should_include = if include_all {
383 true
384 } else {
385 should_include_row(
386 &unified,
387 table_schema,
388 classification,
389 tenant_col_idx,
390 &tenant_pk_value,
391 &runtimes,
392 &cyclic_set,
393 &table_id,
394 )
395 };
396
397 if !should_include {
398 if classification == ShardTableClassification::TenantDependent {
399 fk_orphans += 1;
400 }
401 continue;
402 }
403
404 if let Some(max) = config.max_selected_rows {
406 if total_selected >= max as u64 {
407 stats.warnings.push(format!(
408 "Reached max_selected_rows limit ({}) at table '{}'",
409 max, table_name
410 ));
411 break;
412 }
413 }
414
415 total_selected += 1;
416
417 let runtime = runtimes.get_mut(&table_id).unwrap();
418 if let Some(pk) = unified.pk() {
419 runtime.pk_set.insert(pk.clone());
420 }
421 runtime.selected_rows.push(unified.into_selected());
422 }
423 }
424 StatementType::Copy => {
425 let header = String::from_utf8_lossy(&stmt);
426 copy_columns = parse_copy_columns(&header);
427 }
428 StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
429 if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
430 let rows =
431 parse_postgres_copy_rows(&stmt, table_schema, copy_columns.clone())?;
432
433 for row in rows {
434 rows_seen += 1;
435 let unified = UnifiedRow::Copy(row);
436
437 let should_include = if include_all {
438 true
439 } else {
440 should_include_row(
441 &unified,
442 table_schema,
443 classification,
444 tenant_col_idx,
445 &tenant_pk_value,
446 &runtimes,
447 &cyclic_set,
448 &table_id,
449 )
450 };
451
452 if !should_include {
453 if classification == ShardTableClassification::TenantDependent {
454 fk_orphans += 1;
455 }
456 continue;
457 }
458
459 if let Some(max) = config.max_selected_rows {
460 if total_selected >= max as u64 {
461 break;
462 }
463 }
464
465 total_selected += 1;
466
467 let runtime = runtimes.get_mut(&table_id).unwrap();
468 if let Some(pk) = unified.pk() {
469 runtime.pk_set.insert(pk.clone());
470 }
471 runtime.selected_rows.push(unified.into_selected());
472 }
473 }
474 }
475 _ => {}
476 }
477 }
478
479 let runtime = runtimes.get_mut(&table_id).unwrap();
480 runtime.rows_seen = rows_seen;
481 runtime.fk_orphans = fk_orphans;
482 stats.fk_orphans_skipped += fk_orphans;
483
484 if !runtime.selected_rows.is_empty() {
485 stats.tables_with_data += 1;
486 }
487
488 stats.table_stats.push(TableShardStats {
489 name: runtime.name.clone(),
490 rows_seen: runtime.rows_seen,
491 rows_selected: runtime.selected_rows.len() as u64,
492 classification: runtime.classification,
493 });
494 }
495
496 for table_stat in &stats.table_stats {
498 stats.total_rows_seen += table_stat.rows_seen;
499 stats.total_rows_selected += table_stat.rows_selected;
500 }
501 stats.tables_processed = stats.table_stats.len();
502
503 if let Some(ref pb) = progress_bar {
504 pb.finish_with_message("Processing complete");
505 }
506
507 if config.dry_run {
509 return Ok(stats);
510 }
511
512 write_output(&config, &graph, &all_tables, &runtimes, &tables_dir, &stats)?;
513
514 Ok(stats)
515}
516
517fn build_schema_graph(tables_dir: &Path, config: &ShardConfig) -> anyhow::Result<SchemaGraph> {
519 let mut builder = SchemaBuilder::new();
520
521 for entry in fs::read_dir(tables_dir)? {
522 let entry = entry?;
523 let path = entry.path();
524
525 if path.extension().is_some_and(|e| e == "sql") {
526 let file = File::open(&path)?;
527 let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
528
529 while let Some(stmt) = parser.read_statement()? {
530 let (stmt_type, _) =
531 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
532
533 match stmt_type {
534 StatementType::CreateTable => {
535 let stmt_str = String::from_utf8_lossy(&stmt);
536 builder.parse_create_table(&stmt_str);
537 }
538 StatementType::AlterTable => {
539 let stmt_str = String::from_utf8_lossy(&stmt);
540 builder.parse_alter_table(&stmt_str);
541 }
542 _ => {}
543 }
544 }
545 }
546 }
547
548 Ok(SchemaGraph::from_schema(builder.build()))
549}
550
551fn detect_tenant_column(
553 config: &ShardConfig,
554 yaml_config: &Option<ShardYamlConfig>,
555 graph: &SchemaGraph,
556) -> anyhow::Result<String> {
557 if let Some(ref col) = config.tenant_column {
559 return Ok(col.clone());
560 }
561
562 if let Some(ref yaml) = yaml_config {
564 if let Some(ref col) = yaml.tenant.column {
565 return Ok(col.clone());
566 }
567 }
568
569 for candidate in DefaultShardClassifier::TENANT_COLUMNS {
571 let mut found_in_tables = 0;
572 for table in graph.schema.iter() {
573 if table.get_column(candidate).is_some() {
574 found_in_tables += 1;
575 }
576 }
577 if found_in_tables >= 2 {
578 return Ok(candidate.to_string());
579 }
580 }
581
582 anyhow::bail!(
583 "Could not auto-detect tenant column. Please specify with --tenant-column. \
584 Looked for: {:?}",
585 DefaultShardClassifier::TENANT_COLUMNS
586 )
587}
588
589fn parse_tenant_value(value: &str) -> PkValue {
591 if let Ok(i) = value.parse::<i64>() {
592 PkValue::Int(i)
593 } else if let Ok(i) = value.parse::<i128>() {
594 PkValue::BigInt(i)
595 } else {
596 PkValue::Text(value.into())
597 }
598}
599
600fn find_tenant_root_tables(
602 graph: &SchemaGraph,
603 tenant_column: &str,
604 config: &ShardConfig,
605 yaml_config: &Option<ShardYamlConfig>,
606) -> AHashSet<TableId> {
607 let mut roots = AHashSet::new();
608
609 let explicit_roots: AHashSet<String> = config
611 .root_tables
612 .iter()
613 .chain(
614 yaml_config
615 .as_ref()
616 .map(|y| &y.tenant.root_tables)
617 .unwrap_or(&Vec::new()),
618 )
619 .map(|s| s.to_lowercase())
620 .collect();
621
622 for table in graph.schema.iter() {
623 let lower_name = table.name.to_lowercase();
624
625 if explicit_roots.contains(&lower_name) || table.get_column(tenant_column).is_some() {
626 roots.insert(table.id);
627 }
628 }
629
630 roots
631}
632
633fn compute_reachable_tables(
635 graph: &SchemaGraph,
636 tenant_roots: &AHashSet<TableId>,
637) -> AHashSet<TableId> {
638 let mut reachable = tenant_roots.clone();
639 let mut queue: Vec<TableId> = tenant_roots.iter().copied().collect();
640
641 while let Some(table_id) = queue.pop() {
642 for &child_id in &graph.children[table_id.0 as usize] {
643 if !reachable.contains(&child_id) {
644 reachable.insert(child_id);
645 queue.push(child_id);
646 }
647 }
648 }
649
650 reachable
651}
652
653fn classify_table(
655 table_name: &str,
656 table_id: TableId,
657 graph: &SchemaGraph,
658 tenant_roots: &AHashSet<TableId>,
659 reachable: &AHashSet<TableId>,
660 yaml_config: &Option<ShardYamlConfig>,
661) -> ShardTableClassification {
662 if let Some(ref yaml) = yaml_config {
664 if let Some(class) = yaml.get_classification(table_name) {
665 return class;
666 }
667 }
668
669 if tenant_roots.contains(&table_id) {
671 return ShardTableClassification::TenantRoot;
672 }
673
674 if reachable.contains(&table_id) {
676 if is_junction_table(table_name, table_id, graph) {
678 return ShardTableClassification::Junction;
679 }
680 return ShardTableClassification::TenantDependent;
681 }
682
683 if DefaultShardClassifier::is_system_table(table_name) {
685 return ShardTableClassification::System;
686 }
687
688 if DefaultShardClassifier::is_lookup_table(table_name) {
690 return ShardTableClassification::Lookup;
691 }
692
693 ShardTableClassification::Unknown
694}
695
696fn is_junction_table(table_name: &str, table_id: TableId, graph: &SchemaGraph) -> bool {
698 if DefaultShardClassifier::is_junction_table_by_name(table_name) {
700 return true;
701 }
702
703 if let Some(table) = graph.schema.table(table_id) {
705 let fk_count = table.foreign_keys.len();
706 let fk_col_count: usize = table.foreign_keys.iter().map(|fk| fk.columns.len()).sum();
707 let total_cols = table.columns.len();
708
709 if fk_count >= 2 && fk_col_count >= total_cols.saturating_sub(2) {
711 return true;
712 }
713 }
714
715 false
716}
717
718fn find_tenant_column_index(table: &TableSchema, tenant_column: &str) -> Option<usize> {
720 table
721 .columns
722 .iter()
723 .position(|c| c.name.eq_ignore_ascii_case(tenant_column))
724}
725
726fn should_skip_table(
728 table_name: &str,
729 classification: ShardTableClassification,
730 config: &ShardConfig,
731 yaml_config: &Option<ShardYamlConfig>,
732) -> bool {
733 if let Some(ref yaml) = yaml_config {
735 if yaml.should_skip(table_name) {
736 return true;
737 }
738 }
739
740 if classification == ShardTableClassification::System {
742 return true;
743 }
744
745 if classification == ShardTableClassification::Lookup {
747 return config.include_global == GlobalTableMode::None;
748 }
749
750 false
751}
752
753#[allow(clippy::too_many_arguments)]
755fn should_include_row(
756 row: &UnifiedRow,
757 table_schema: &TableSchema,
758 classification: ShardTableClassification,
759 tenant_column_index: Option<usize>,
760 tenant_value: &PkValue,
761 runtimes: &AHashMap<TableId, TableRuntime>,
762 cyclic_set: &AHashSet<TableId>,
763 current_table_id: &TableId,
764) -> bool {
765 match classification {
766 ShardTableClassification::TenantRoot => {
767 if let Some(idx) = tenant_column_index {
769 match row {
770 UnifiedRow::Insert(r) => {
771 if let Some(val) = extract_column_value(&r.raw, idx) {
772 return &val == tenant_value;
773 }
774 }
775 UnifiedRow::Copy(r) => {
776 if let Some(val) = extract_copy_column_value(&r.raw, idx) {
777 return &val == tenant_value;
778 }
779 }
780 }
781 }
782 false
783 }
784 ShardTableClassification::TenantDependent => {
785 for (fk_ref, fk_tuple) in row.fk_values() {
787 if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
788 if let Some(parent_id) = fk.referenced_table_id {
789 if cyclic_set.contains(&parent_id) && cyclic_set.contains(current_table_id)
791 {
792 continue;
793 }
794
795 if let Some(parent_runtime) = runtimes.get(&parent_id) {
796 if parent_runtime.pk_set.contains(fk_tuple) {
797 return true;
798 }
799 }
800 }
801 }
802 }
803 false
804 }
805 ShardTableClassification::Junction => {
806 for (fk_ref, fk_tuple) in row.fk_values() {
808 if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
809 if let Some(parent_id) = fk.referenced_table_id {
810 if let Some(parent_runtime) = runtimes.get(&parent_id) {
811 if parent_runtime.pk_set.contains(fk_tuple) {
812 return true;
813 }
814 }
815 }
816 }
817 }
818 false
819 }
820 _ => false,
821 }
822}
823
824fn extract_column_value(raw: &[u8], column_index: usize) -> Option<PkValue> {
826 let mut values = Vec::new();
827 let mut current_start = 0;
828 let mut in_string = false;
829 let mut escape_next = false;
830 let mut paren_depth = 0;
831
832 let start = raw.iter().position(|&b| b == b'(')?;
834 let raw = &raw[start + 1..];
835
836 for (i, &b) in raw.iter().enumerate() {
837 if escape_next {
838 escape_next = false;
839 continue;
840 }
841
842 if b == b'\\' && in_string {
843 escape_next = true;
844 continue;
845 }
846
847 if b == b'\'' && !escape_next {
848 in_string = !in_string;
849 continue;
850 }
851
852 if in_string {
853 continue;
854 }
855
856 match b {
857 b'(' => paren_depth += 1,
858 b')' => {
859 if paren_depth == 0 {
860 values.push(&raw[current_start..i]);
861 break;
862 }
863 paren_depth -= 1;
864 }
865 b',' if paren_depth == 0 => {
866 values.push(&raw[current_start..i]);
867 current_start = i + 1;
868 }
869 _ => {}
870 }
871 }
872
873 if column_index >= values.len() {
874 return None;
875 }
876
877 parse_value_bytes(values[column_index])
878}
879
880fn extract_copy_column_value(raw: &[u8], column_index: usize) -> Option<PkValue> {
882 let fields: Vec<&[u8]> = raw.split(|&b| b == b'\t').collect();
883 if column_index >= fields.len() {
884 return None;
885 }
886
887 let field = fields[column_index];
888 if field == b"\\N" {
889 return Some(PkValue::Null);
890 }
891
892 parse_value_bytes(field)
893}
894
895fn parse_value_bytes(bytes: &[u8]) -> Option<PkValue> {
897 let trimmed = bytes
898 .iter()
899 .skip_while(|&&b| b == b' ')
900 .take_while(|&&b| b != b' ')
901 .copied()
902 .collect::<Vec<_>>();
903
904 if trimmed.is_empty() {
905 return None;
906 }
907
908 if trimmed.eq_ignore_ascii_case(b"null") {
910 return Some(PkValue::Null);
911 }
912
913 let unquoted = if trimmed.first() == Some(&b'\'') && trimmed.last() == Some(&b'\'') {
915 &trimmed[1..trimmed.len() - 1]
916 } else {
917 &trimmed[..]
918 };
919
920 if let Ok(s) = std::str::from_utf8(unquoted) {
922 if let Ok(i) = s.parse::<i64>() {
923 return Some(PkValue::Int(i));
924 }
925 if let Ok(i) = s.parse::<i128>() {
926 return Some(PkValue::BigInt(i));
927 }
928 return Some(PkValue::Text(s.into()));
929 }
930
931 None
932}
933
934fn write_output(
936 config: &ShardConfig,
937 _graph: &SchemaGraph,
938 table_order: &[TableId],
939 runtimes: &AHashMap<TableId, TableRuntime>,
940 tables_dir: &Path,
941 stats: &ShardStats,
942) -> anyhow::Result<()> {
943 let mut writer: Box<dyn Write> = match &config.output {
944 Some(path) => {
945 if let Some(parent) = path.parent() {
946 fs::create_dir_all(parent)?;
947 }
948 Box::new(BufWriter::with_capacity(256 * 1024, File::create(path)?))
949 }
950 None => Box::new(BufWriter::new(std::io::stdout())),
951 };
952
953 write_header(&mut writer, config, stats)?;
955
956 write_dialect_header(&mut writer, config.dialect)?;
958
959 if config.include_schema {
961 for &table_id in table_order {
962 let runtime = match runtimes.get(&table_id) {
963 Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
964 _ => continue,
965 };
966
967 let table_file = tables_dir.join(format!("{}.sql", runtime.name));
968 if !table_file.exists() {
969 continue;
970 }
971
972 let file = File::open(&table_file)?;
973 let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
974
975 while let Some(stmt) = parser.read_statement()? {
976 let (stmt_type, _) =
977 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
978
979 if stmt_type.is_schema() {
980 writer.write_all(&stmt)?;
981 writer.write_all(b"\n")?;
982 }
983 }
984 }
985 }
986
987 for &table_id in table_order {
989 let runtime = match runtimes.get(&table_id) {
990 Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
991 _ => continue,
992 };
993
994 let table_name = &runtime.name;
995 let row_count = runtime.selected_rows.len();
996
997 writeln!(writer, "\n-- Data: {} ({} rows)", table_name, row_count)?;
998
999 const CHUNK_SIZE: usize = 1000;
1000
1001 let quoted_name = match config.dialect {
1002 SqlDialect::MySql => format!("`{}`", table_name),
1003 SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", table_name),
1004 };
1005
1006 for chunk in runtime.selected_rows.chunks(CHUNK_SIZE) {
1007 writeln!(writer, "INSERT INTO {} VALUES", quoted_name)?;
1008
1009 for (i, row) in chunk.iter().enumerate() {
1010 if i > 0 {
1011 writer.write_all(b",\n")?;
1012 }
1013
1014 let values = match row.format {
1015 RowFormat::Insert => match config.dialect {
1016 SqlDialect::Postgres => convert_row_to_postgres(&row.raw),
1017 _ => row.raw.clone(),
1018 },
1019 RowFormat::Copy => convert_copy_to_insert_values(&row.raw, config.dialect),
1020 };
1021 writer.write_all(&values)?;
1022 }
1023
1024 writer.write_all(b";\n")?;
1025 }
1026 }
1027
1028 write_dialect_footer(&mut writer, config.dialect)?;
1030
1031 writer.flush()?;
1032
1033 Ok(())
1034}
1035
1036fn write_header<W: Write>(
1038 writer: &mut W,
1039 config: &ShardConfig,
1040 stats: &ShardStats,
1041) -> std::io::Result<()> {
1042 writeln!(writer, "-- Sharded from: {}", config.input.display())?;
1043 writeln!(
1044 writer,
1045 "-- Date: {}",
1046 chrono::Local::now().format("%Y-%m-%d %H:%M:%S")
1047 )?;
1048 if let Some(ref col) = stats.detected_tenant_column {
1049 writeln!(writer, "-- Tenant column: {}", col)?;
1050 }
1051 writeln!(writer, "-- Tenant value: {}", config.tenant_value)?;
1052 writeln!(writer, "-- Dialect: {}", config.dialect)?;
1053 writeln!(writer, "--")?;
1054 writeln!(writer, "-- Statistics:")?;
1055 writeln!(writer, "-- Tables processed: {}", stats.tables_processed)?;
1056 writeln!(writer, "-- Tables with data: {}", stats.tables_with_data)?;
1057 writeln!(writer, "-- Tables skipped: {}", stats.tables_skipped)?;
1058
1059 let percent = if stats.total_rows_seen > 0 {
1060 (stats.total_rows_selected as f64 / stats.total_rows_seen as f64) * 100.0
1061 } else {
1062 0.0
1063 };
1064 writeln!(
1065 writer,
1066 "-- Total rows: {} (from {} original, {:.1}%)",
1067 stats.total_rows_selected, stats.total_rows_seen, percent
1068 )?;
1069
1070 if stats.fk_orphans_skipped > 0 {
1071 writeln!(
1072 writer,
1073 "-- FK orphans skipped: {}",
1074 stats.fk_orphans_skipped
1075 )?;
1076 }
1077
1078 if !stats.warnings.is_empty() {
1079 writeln!(writer, "-- Warnings: {}", stats.warnings.len())?;
1080 }
1081
1082 writeln!(writer)?;
1083
1084 Ok(())
1085}
1086
1087fn write_dialect_header<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1089 match dialect {
1090 SqlDialect::MySql => {
1091 writeln!(writer, "SET NAMES utf8mb4;")?;
1092 writeln!(writer, "SET FOREIGN_KEY_CHECKS = 0;")?;
1093 }
1094 SqlDialect::Postgres => {
1095 writeln!(writer, "SET client_encoding = 'UTF8';")?;
1096 writeln!(writer, "SET session_replication_role = replica;")?;
1097 }
1098 SqlDialect::Sqlite => {
1099 writeln!(writer, "PRAGMA foreign_keys = OFF;")?;
1100 }
1101 }
1102 writeln!(writer)?;
1103 Ok(())
1104}
1105
1106fn write_dialect_footer<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1108 writeln!(writer)?;
1109 match dialect {
1110 SqlDialect::MySql => {
1111 writeln!(writer, "SET FOREIGN_KEY_CHECKS = 1;")?;
1112 }
1113 SqlDialect::Postgres => {
1114 writeln!(writer, "SET session_replication_role = DEFAULT;")?;
1115 }
1116 SqlDialect::Sqlite => {
1117 writeln!(writer, "PRAGMA foreign_keys = ON;")?;
1118 }
1119 }
1120 Ok(())
1121}
1122
/// Rewrites backslash-escaped single quotes (`\'`) in a raw `INSERT` row as
/// doubled quotes (`''`).
///
/// An escaped backslash (`\\`) is copied through as a unit, so a quote that
/// follows it — typically the closing quote of a string ending in a
/// backslash — is not mistaken for an escaped quote. All other bytes are
/// copied unchanged.
fn convert_row_to_postgres(row: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(row.len());
    let mut rest = row;

    loop {
        match rest {
            // Escaped backslash: consume both bytes so the second one cannot
            // pair with a following quote.
            [b'\\', b'\\', tail @ ..] => {
                result.extend_from_slice(b"\\\\");
                rest = tail;
            }
            [b'\\', b'\'', tail @ ..] => {
                result.extend_from_slice(b"''");
                rest = tail;
            }
            [byte, tail @ ..] => {
                result.push(*byte);
                rest = tail;
            }
            [] => break,
        }
    }

    result
}
1141
1142fn convert_copy_to_insert_values(row: &[u8], dialect: SqlDialect) -> Vec<u8> {
1144 let mut result = Vec::with_capacity(row.len() + 20);
1145 result.push(b'(');
1146
1147 let fields: Vec<&[u8]> = row.split(|&b| b == b'\t').collect();
1148
1149 for (i, field) in fields.iter().enumerate() {
1150 if i > 0 {
1151 result.extend_from_slice(b", ");
1152 }
1153
1154 if *field == b"\\N" {
1155 result.extend_from_slice(b"NULL");
1156 } else if field.is_empty() {
1157 result.extend_from_slice(b"''");
1158 } else if is_numeric(field) {
1159 result.extend_from_slice(field);
1160 } else {
1161 result.push(b'\'');
1162 for &b in *field {
1163 match b {
1164 b'\'' => match dialect {
1165 SqlDialect::MySql => result.extend_from_slice(b"\\'"),
1166 SqlDialect::Postgres | SqlDialect::Sqlite => {
1167 result.extend_from_slice(b"''")
1168 }
1169 },
1170 b'\\' if dialect == SqlDialect::MySql => {
1171 result.extend_from_slice(b"\\\\");
1172 }
1173 _ => result.push(b),
1174 }
1175 }
1176 result.push(b'\'');
1177 }
1178 }
1179
1180 result.push(b')');
1181 result
1182}
1183
/// Tells whether `s` is a numeric literal that can be emitted unquoted.
///
/// Accepted form: an optional sign, digits with at most one decimal point
/// (at least one digit overall), and an optional exponent consisting of
/// `e`/`E`, an optional sign and at least one digit. Examples: `42`,
/// `-3.14`, `.5`, `5.`, `1e10`, `1.5e-3`.
///
/// Anything else — including an empty slice, a bare sign, `e5` or `1e` — is
/// rejected so that it gets quoted as text instead of producing invalid SQL.
fn is_numeric(s: &[u8]) -> bool {
    /// Length of the leading run of ASCII digits.
    fn digit_run(bytes: &[u8]) -> usize {
        bytes.iter().take_while(|b| b.is_ascii_digit()).count()
    }

    /// Drops one leading `+` or `-`, if present.
    fn strip_sign(bytes: &[u8]) -> &[u8] {
        match bytes {
            [b'+' | b'-', rest @ ..] => rest,
            _ => bytes,
        }
    }

    let unsigned = strip_sign(s);

    let int_len = digit_run(unsigned);
    let mut rest = &unsigned[int_len..];

    let mut frac_len = 0;
    if let [b'.', tail @ ..] = rest {
        frac_len = digit_run(tail);
        rest = &tail[frac_len..];
    }

    // The mantissa needs a digit on at least one side of the decimal point.
    if int_len + frac_len == 0 {
        return false;
    }

    match rest {
        [] => true,
        [b'e' | b'E', exponent @ ..] => {
            let digits = strip_sign(exponent);
            !digits.is_empty() && digit_run(digits) == digits.len()
        }
        _ => false,
    }
}