1mod config;
12
13pub use config::{
14 DefaultShardClassifier, GlobalTableMode, ShardTableClassification, ShardYamlConfig,
15};
16
17use crate::parser::mysql_insert::{parse_mysql_insert_rows, ParsedRow, PkSet, PkTuple, PkValue};
18use crate::parser::postgres_copy::{parse_copy_columns, parse_postgres_copy_rows, ParsedCopyRow};
19use crate::parser::{ContentFilter, Parser, SqlDialect, StatementType};
20use crate::schema::{SchemaBuilder, SchemaGraph, TableId, TableSchema};
21use crate::splitter::Splitter;
22use ahash::{AHashMap, AHashSet};
23use indicatif::{ProgressBar, ProgressStyle};
24use std::fs::{self, File};
25use std::io::{BufWriter, Write};
26use std::path::{Path, PathBuf};
27use tempfile::TempDir;
28
/// Options for a single shard run: which dump to read, which tenant to
/// extract, and how to write the result.
#[derive(Debug)]
pub struct ShardConfig {
    /// Path of the SQL dump to read.
    pub input: PathBuf,
    /// Destination file for the sharded dump; `None` writes to stdout.
    pub output: Option<PathBuf>,
    /// SQL dialect used both to parse the input and to format the output.
    pub dialect: SqlDialect,
    /// Tenant column name. When `None`, the YAML config and then
    /// auto-detection are consulted.
    pub tenant_column: Option<String>,
    /// Tenant value to extract; parsed as an integer when possible, otherwise
    /// compared as text.
    pub tenant_value: String,
    /// Extra table names (matched case-insensitively) to treat as tenant roots.
    pub root_tables: Vec<String>,
    /// Which tenant-independent tables (lookup / unclassified) are copied in full.
    pub include_global: GlobalTableMode,
    /// Compute statistics only; no output is written.
    pub dry_run: bool,
    /// Show a progress bar and status messages on stderr.
    pub progress: bool,
    /// Optional YAML file with tenant settings and per-table overrides.
    pub config_file: Option<PathBuf>,
    /// Upper bound on rows selected across all tables; `None` means unlimited.
    pub max_selected_rows: Option<usize>,
    /// NOTE(review): not read anywhere in this module; presumably meant to turn
    /// FK orphans into an error instead of a skip — confirm.
    pub strict_fk: bool,
    /// Emit each emitted table's schema statements (per
    /// `StatementType::is_schema`) before the data.
    pub include_schema: bool,
}
59
60impl Default for ShardConfig {
61 fn default() -> Self {
62 Self {
63 input: PathBuf::new(),
64 output: None,
65 dialect: SqlDialect::MySql,
66 tenant_column: None,
67 tenant_value: String::new(),
68 root_tables: Vec::new(),
69 include_global: GlobalTableMode::Lookups,
70 dry_run: false,
71 progress: false,
72 config_file: None,
73 max_selected_rows: Some(10_000_000),
74 strict_fk: false,
75 include_schema: true,
76 }
77 }
78}
79
/// Summary of a shard run, returned by `run` and serializable for reporting.
#[derive(Debug, Default, serde::Serialize)]
pub struct ShardStats {
    /// Number of tables that were read and got an entry in `table_stats`.
    pub tables_processed: usize,
    /// Number of tables skipped by classification, YAML rules, or `include_global`.
    pub tables_skipped: usize,
    /// Number of processed tables with at least one selected row.
    pub tables_with_data: usize,
    /// Sum of `rows_selected` over `table_stats`.
    pub total_rows_selected: u64,
    /// Sum of `rows_seen` over `table_stats`.
    pub total_rows_seen: u64,
    /// Per-table counts, in processing order.
    pub table_stats: Vec<TableShardStats>,
    /// Non-fatal issues (FK cycles, selection limit reached, ...).
    pub warnings: Vec<String>,
    /// Rows of tenant-dependent tables rejected because no FK referenced a
    /// selected parent row.
    pub fk_orphans_skipped: u64,
    /// Tenant column that was actually used (explicit, from YAML, or detected).
    pub detected_tenant_column: Option<String>,
}
102
/// Per-table row counts reported in `ShardStats::table_stats`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TableShardStats {
    /// Table name.
    pub name: String,
    /// Rows parsed from the table's data statements.
    pub rows_seen: u64,
    /// Rows kept for the requested tenant.
    pub rows_selected: u64,
    /// Classification that drove the selection rule for this table.
    pub classification: ShardTableClassification,
}
111
112struct TableRuntime {
114 name: String,
116 selected_rows: Vec<SelectedRow>,
118 pk_set: PkSet,
120 rows_seen: u64,
122 skip: bool,
124 classification: ShardTableClassification,
126 fk_orphans: u64,
128 tenant_column_index: Option<usize>,
130}
131
/// Source syntax a selected row's raw bytes are in; decides how the row is
/// converted when the output is written.
#[derive(Debug, Clone, Copy, PartialEq)]
enum RowFormat {
    /// A value tuple taken from an `INSERT ... VALUES` statement.
    Insert,
    /// A tab-separated line taken from a PostgreSQL `COPY` data block.
    Copy,
}
138
139struct SelectedRow {
141 raw: Vec<u8>,
142 format: RowFormat,
143}
144
145enum UnifiedRow {
147 Insert(ParsedRow),
148 Copy(ParsedCopyRow),
149}
150
151impl UnifiedRow {
152 fn pk(&self) -> Option<&PkTuple> {
153 match self {
154 UnifiedRow::Insert(r) => r.pk.as_ref(),
155 UnifiedRow::Copy(r) => r.pk.as_ref(),
156 }
157 }
158
159 fn fk_values(&self) -> &[(crate::parser::mysql_insert::FkRef, PkTuple)] {
160 match self {
161 UnifiedRow::Insert(r) => &r.fk_values,
162 UnifiedRow::Copy(r) => &r.fk_values,
163 }
164 }
165
166 fn into_selected(self) -> SelectedRow {
167 match self {
168 UnifiedRow::Insert(r) => SelectedRow {
169 raw: r.raw,
170 format: RowFormat::Insert,
171 },
172 UnifiedRow::Copy(r) => SelectedRow {
173 raw: r.raw,
174 format: RowFormat::Copy,
175 },
176 }
177 }
178}
179
180pub fn run(config: ShardConfig) -> anyhow::Result<ShardStats> {
182 let yaml_config = if let Some(ref path) = config.config_file {
183 Some(ShardYamlConfig::load(path)?)
184 } else {
185 None
186 };
187
188 let mut stats = ShardStats::default();
189
190 let file_size = std::fs::metadata(&config.input)?.len();
192
193 let progress_bar = if config.progress {
195 let pb = ProgressBar::new(file_size);
196 pb.set_style(
197 ProgressStyle::with_template(
198 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%) {msg}",
199 )
200 .unwrap()
201 .progress_chars("█▓▒░ ")
202 .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
203 );
204 pb.enable_steady_tick(std::time::Duration::from_millis(100));
205 pb.set_message("Splitting dump...");
206 Some(pb)
207 } else {
208 None
209 };
210
211 let temp_dir = TempDir::new()?;
213 let tables_dir = temp_dir.path().join("tables");
214
215 let mut splitter = Splitter::new(config.input.clone(), tables_dir.clone())
216 .with_dialect(config.dialect)
217 .with_content_filter(ContentFilter::All);
218
219 if let Some(ref pb) = progress_bar {
220 let pb_clone = pb.clone();
221 splitter = splitter.with_progress(move |bytes| {
222 pb_clone.set_position(bytes);
223 });
224 }
225
226 let split_stats = splitter.split()?;
227
228 if let Some(ref pb) = progress_bar {
230 pb.finish_and_clear();
231 }
232
233 if config.progress {
234 eprintln!(
235 "Split complete: {} tables, {} statements",
236 split_stats.tables_found, split_stats.statements_processed
237 );
238 }
239
240 if config.progress {
242 eprintln!("Building schema graph...");
243 }
244
245 let graph = build_schema_graph(&tables_dir, &config)?;
246
247 let tenant_column = detect_tenant_column(&config, &yaml_config, &graph)?;
249 stats.detected_tenant_column = Some(tenant_column.clone());
250
251 if config.progress {
252 eprintln!("Using tenant column: {}", tenant_column);
253 }
254
255 let tenant_pk_value = parse_tenant_value(&config.tenant_value);
257
258 let (topo_order, cyclic_tables) = graph.processing_order();
260
261 if !cyclic_tables.is_empty() {
262 let names: Vec<_> = cyclic_tables
263 .iter()
264 .filter_map(|&id| graph.table_name(id))
265 .collect();
266 stats.warnings.push(format!(
267 "{} tables have FK cycles (relaxed FK enforcement): {:?}",
268 cyclic_tables.len(),
269 names
270 ));
271 }
272
273 let cyclic_set: AHashSet<TableId> = cyclic_tables.iter().copied().collect();
274
275 let tenant_root_ids = find_tenant_root_tables(&graph, &tenant_column, &config, &yaml_config);
277
278 let reachable_from_roots = compute_reachable_tables(&graph, &tenant_root_ids);
280
281 let mut runtimes: AHashMap<TableId, TableRuntime> = AHashMap::new();
283 for table in graph.schema.iter() {
284 let classification = classify_table(
285 &table.name,
286 table.id,
287 &graph,
288 &tenant_root_ids,
289 &reachable_from_roots,
290 &yaml_config,
291 );
292
293 let tenant_column_index = if classification == ShardTableClassification::TenantRoot {
294 find_tenant_column_index(table, &tenant_column)
295 } else {
296 None
297 };
298
299 let skip = should_skip_table(&table.name, classification, &config, &yaml_config);
300
301 runtimes.insert(
302 table.id,
303 TableRuntime {
304 name: table.name.clone(),
305 selected_rows: Vec::new(),
306 pk_set: PkSet::default(),
307 rows_seen: 0,
308 skip,
309 classification,
310 fk_orphans: 0,
311 tenant_column_index,
312 },
313 );
314 }
315
316 if config.progress {
318 eprintln!(
319 "Processing {} tables for tenant {}...",
320 topo_order.len() + cyclic_tables.len(),
321 config.tenant_value
322 );
323 }
324
325 let all_tables: Vec<TableId> = topo_order.into_iter().chain(cyclic_tables).collect();
326 let mut total_selected: u64 = 0;
327
328 for &table_id in &all_tables {
329 let table_schema = match graph.schema.table(table_id) {
330 Some(s) => s,
331 None => continue,
332 };
333
334 let (should_skip, table_name, classification, tenant_col_idx) = {
335 let runtime = match runtimes.get(&table_id) {
336 Some(r) => r,
337 None => continue,
338 };
339 (
340 runtime.skip,
341 runtime.name.clone(),
342 runtime.classification,
343 runtime.tenant_column_index,
344 )
345 };
346
347 if should_skip {
348 stats.tables_skipped += 1;
349 continue;
350 }
351
352 let include_all = match classification {
354 ShardTableClassification::Lookup => match config.include_global {
355 GlobalTableMode::None => {
356 stats.tables_skipped += 1;
357 continue;
358 }
359 GlobalTableMode::Lookups | GlobalTableMode::All => true,
360 },
361 ShardTableClassification::System => {
362 stats.tables_skipped += 1;
363 continue;
364 }
365 ShardTableClassification::Unknown => match config.include_global {
366 GlobalTableMode::All => true,
367 _ => {
368 stats.tables_skipped += 1;
369 continue;
370 }
371 },
372 _ => false,
373 };
374
375 let table_file = tables_dir.join(format!("{}.sql", table_name));
376 if !table_file.exists() {
377 continue;
378 }
379
380 let file = File::open(&table_file)?;
381 let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
382
383 let mut rows_seen = 0u64;
384 let mut fk_orphans = 0u64;
385 let mut copy_columns: Vec<String> = Vec::new();
386
387 while let Some(stmt) = parser.read_statement()? {
388 let (stmt_type, _) =
389 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
390
391 match stmt_type {
392 StatementType::Insert => {
393 let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
394
395 for row in rows {
396 rows_seen += 1;
397 let unified = UnifiedRow::Insert(row);
398
399 let should_include = if include_all {
400 true
401 } else {
402 should_include_row(
403 &unified,
404 table_schema,
405 classification,
406 tenant_col_idx,
407 &tenant_pk_value,
408 &runtimes,
409 &cyclic_set,
410 &table_id,
411 )
412 };
413
414 if !should_include {
415 if classification == ShardTableClassification::TenantDependent {
416 fk_orphans += 1;
417 }
418 continue;
419 }
420
421 if let Some(max) = config.max_selected_rows {
423 if total_selected >= max as u64 {
424 stats.warnings.push(format!(
425 "Reached max_selected_rows limit ({}) at table '{}'",
426 max, table_name
427 ));
428 break;
429 }
430 }
431
432 total_selected += 1;
433
434 let runtime = runtimes.get_mut(&table_id).unwrap();
435 if let Some(pk) = unified.pk() {
436 runtime.pk_set.insert(pk.clone());
437 }
438 runtime.selected_rows.push(unified.into_selected());
439 }
440 }
441 StatementType::Copy => {
442 let header = String::from_utf8_lossy(&stmt);
443 copy_columns = parse_copy_columns(&header);
444 }
445 StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
446 if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
447 let rows =
448 parse_postgres_copy_rows(&stmt, table_schema, copy_columns.clone())?;
449
450 for row in rows {
451 rows_seen += 1;
452 let unified = UnifiedRow::Copy(row);
453
454 let should_include = if include_all {
455 true
456 } else {
457 should_include_row(
458 &unified,
459 table_schema,
460 classification,
461 tenant_col_idx,
462 &tenant_pk_value,
463 &runtimes,
464 &cyclic_set,
465 &table_id,
466 )
467 };
468
469 if !should_include {
470 if classification == ShardTableClassification::TenantDependent {
471 fk_orphans += 1;
472 }
473 continue;
474 }
475
476 if let Some(max) = config.max_selected_rows {
477 if total_selected >= max as u64 {
478 break;
479 }
480 }
481
482 total_selected += 1;
483
484 let runtime = runtimes.get_mut(&table_id).unwrap();
485 if let Some(pk) = unified.pk() {
486 runtime.pk_set.insert(pk.clone());
487 }
488 runtime.selected_rows.push(unified.into_selected());
489 }
490 }
491 }
492 _ => {}
493 }
494 }
495
496 let runtime = runtimes.get_mut(&table_id).unwrap();
497 runtime.rows_seen = rows_seen;
498 runtime.fk_orphans = fk_orphans;
499 stats.fk_orphans_skipped += fk_orphans;
500
501 if !runtime.selected_rows.is_empty() {
502 stats.tables_with_data += 1;
503 }
504
505 stats.table_stats.push(TableShardStats {
506 name: runtime.name.clone(),
507 rows_seen: runtime.rows_seen,
508 rows_selected: runtime.selected_rows.len() as u64,
509 classification: runtime.classification,
510 });
511 }
512
513 for table_stat in &stats.table_stats {
515 stats.total_rows_seen += table_stat.rows_seen;
516 stats.total_rows_selected += table_stat.rows_selected;
517 }
518 stats.tables_processed = stats.table_stats.len();
519
520 if config.progress {
521 eprintln!("Processing complete");
522 }
523
524 if config.dry_run {
526 return Ok(stats);
527 }
528
529 write_output(&config, &graph, &all_tables, &runtimes, &tables_dir, &stats)?;
530
531 Ok(stats)
532}
533
534fn build_schema_graph(tables_dir: &Path, config: &ShardConfig) -> anyhow::Result<SchemaGraph> {
536 let mut builder = SchemaBuilder::new();
537
538 for entry in fs::read_dir(tables_dir)? {
539 let entry = entry?;
540 let path = entry.path();
541
542 if path.extension().is_some_and(|e| e == "sql") {
543 let file = File::open(&path)?;
544 let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
545
546 while let Some(stmt) = parser.read_statement()? {
547 let (stmt_type, _) =
548 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
549
550 match stmt_type {
551 StatementType::CreateTable => {
552 let stmt_str = String::from_utf8_lossy(&stmt);
553 builder.parse_create_table(&stmt_str);
554 }
555 StatementType::AlterTable => {
556 let stmt_str = String::from_utf8_lossy(&stmt);
557 builder.parse_alter_table(&stmt_str);
558 }
559 _ => {}
560 }
561 }
562 }
563 }
564
565 Ok(SchemaGraph::from_schema(builder.build()))
566}
567
568fn detect_tenant_column(
570 config: &ShardConfig,
571 yaml_config: &Option<ShardYamlConfig>,
572 graph: &SchemaGraph,
573) -> anyhow::Result<String> {
574 if let Some(ref col) = config.tenant_column {
576 return Ok(col.clone());
577 }
578
579 if let Some(ref yaml) = yaml_config {
581 if let Some(ref col) = yaml.tenant.column {
582 return Ok(col.clone());
583 }
584 }
585
586 for candidate in DefaultShardClassifier::TENANT_COLUMNS {
588 let mut found_in_tables = 0;
589 for table in graph.schema.iter() {
590 if table.get_column(candidate).is_some() {
591 found_in_tables += 1;
592 }
593 }
594 if found_in_tables >= 2 {
595 return Ok(candidate.to_string());
596 }
597 }
598
599 anyhow::bail!(
600 "Could not auto-detect tenant column. Please specify with --tenant-column. \
601 Looked for: {:?}",
602 DefaultShardClassifier::TENANT_COLUMNS
603 )
604}
605
606fn parse_tenant_value(value: &str) -> PkValue {
608 if let Ok(i) = value.parse::<i64>() {
609 PkValue::Int(i)
610 } else if let Ok(i) = value.parse::<i128>() {
611 PkValue::BigInt(i)
612 } else {
613 PkValue::Text(value.into())
614 }
615}
616
617fn find_tenant_root_tables(
619 graph: &SchemaGraph,
620 tenant_column: &str,
621 config: &ShardConfig,
622 yaml_config: &Option<ShardYamlConfig>,
623) -> AHashSet<TableId> {
624 let mut roots = AHashSet::new();
625
626 let explicit_roots: AHashSet<String> = config
628 .root_tables
629 .iter()
630 .chain(
631 yaml_config
632 .as_ref()
633 .map(|y| &y.tenant.root_tables)
634 .unwrap_or(&Vec::new()),
635 )
636 .map(|s| s.to_lowercase())
637 .collect();
638
639 for table in graph.schema.iter() {
640 let lower_name = table.name.to_lowercase();
641
642 if explicit_roots.contains(&lower_name) || table.get_column(tenant_column).is_some() {
643 roots.insert(table.id);
644 }
645 }
646
647 roots
648}
649
650fn compute_reachable_tables(
652 graph: &SchemaGraph,
653 tenant_roots: &AHashSet<TableId>,
654) -> AHashSet<TableId> {
655 let mut reachable = tenant_roots.clone();
656 let mut queue: Vec<TableId> = tenant_roots.iter().copied().collect();
657
658 while let Some(table_id) = queue.pop() {
659 for &child_id in &graph.children[table_id.0 as usize] {
660 if !reachable.contains(&child_id) {
661 reachable.insert(child_id);
662 queue.push(child_id);
663 }
664 }
665 }
666
667 reachable
668}
669
670fn classify_table(
672 table_name: &str,
673 table_id: TableId,
674 graph: &SchemaGraph,
675 tenant_roots: &AHashSet<TableId>,
676 reachable: &AHashSet<TableId>,
677 yaml_config: &Option<ShardYamlConfig>,
678) -> ShardTableClassification {
679 if let Some(ref yaml) = yaml_config {
681 if let Some(class) = yaml.get_classification(table_name) {
682 return class;
683 }
684 }
685
686 if tenant_roots.contains(&table_id) {
688 return ShardTableClassification::TenantRoot;
689 }
690
691 if reachable.contains(&table_id) {
693 if is_junction_table(table_name, table_id, graph) {
695 return ShardTableClassification::Junction;
696 }
697 return ShardTableClassification::TenantDependent;
698 }
699
700 if DefaultShardClassifier::is_system_table(table_name) {
702 return ShardTableClassification::System;
703 }
704
705 if DefaultShardClassifier::is_lookup_table(table_name) {
707 return ShardTableClassification::Lookup;
708 }
709
710 ShardTableClassification::Unknown
711}
712
713fn is_junction_table(table_name: &str, table_id: TableId, graph: &SchemaGraph) -> bool {
715 if DefaultShardClassifier::is_junction_table_by_name(table_name) {
717 return true;
718 }
719
720 if let Some(table) = graph.schema.table(table_id) {
722 let fk_count = table.foreign_keys.len();
723 let fk_col_count: usize = table.foreign_keys.iter().map(|fk| fk.columns.len()).sum();
724 let total_cols = table.columns.len();
725
726 if fk_count >= 2 && fk_col_count >= total_cols.saturating_sub(2) {
728 return true;
729 }
730 }
731
732 false
733}
734
735fn find_tenant_column_index(table: &TableSchema, tenant_column: &str) -> Option<usize> {
737 table
738 .columns
739 .iter()
740 .position(|c| c.name.eq_ignore_ascii_case(tenant_column))
741}
742
743fn should_skip_table(
745 table_name: &str,
746 classification: ShardTableClassification,
747 config: &ShardConfig,
748 yaml_config: &Option<ShardYamlConfig>,
749) -> bool {
750 if let Some(ref yaml) = yaml_config {
752 if yaml.should_skip(table_name) {
753 return true;
754 }
755 }
756
757 if classification == ShardTableClassification::System {
759 return true;
760 }
761
762 if classification == ShardTableClassification::Lookup {
764 return config.include_global == GlobalTableMode::None;
765 }
766
767 false
768}
769
770#[allow(clippy::too_many_arguments)]
772fn should_include_row(
773 row: &UnifiedRow,
774 table_schema: &TableSchema,
775 classification: ShardTableClassification,
776 tenant_column_index: Option<usize>,
777 tenant_value: &PkValue,
778 runtimes: &AHashMap<TableId, TableRuntime>,
779 cyclic_set: &AHashSet<TableId>,
780 current_table_id: &TableId,
781) -> bool {
782 match classification {
783 ShardTableClassification::TenantRoot => {
784 if let Some(idx) = tenant_column_index {
786 match row {
787 UnifiedRow::Insert(r) => {
788 if let Some(val) = r.get_column_value(idx) {
789 return val == tenant_value;
790 }
791 }
792 UnifiedRow::Copy(r) => {
793 if let Some(val) = r.get_column_value(idx) {
794 return val == tenant_value;
795 }
796 }
797 }
798 }
799 false
800 }
801 ShardTableClassification::TenantDependent => {
802 for (fk_ref, fk_tuple) in row.fk_values() {
804 if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
805 if let Some(parent_id) = fk.referenced_table_id {
806 if cyclic_set.contains(&parent_id) && cyclic_set.contains(current_table_id)
808 {
809 continue;
810 }
811
812 if let Some(parent_runtime) = runtimes.get(&parent_id) {
813 if parent_runtime.pk_set.contains(fk_tuple) {
814 return true;
815 }
816 }
817 }
818 }
819 }
820 false
821 }
822 ShardTableClassification::Junction => {
823 for (fk_ref, fk_tuple) in row.fk_values() {
825 if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
826 if let Some(parent_id) = fk.referenced_table_id {
827 if let Some(parent_runtime) = runtimes.get(&parent_id) {
828 if parent_runtime.pk_set.contains(fk_tuple) {
829 return true;
830 }
831 }
832 }
833 }
834 }
835 false
836 }
837 _ => false,
838 }
839}
840
841fn extract_column_value(raw: &[u8], column_index: usize) -> Option<PkValue> {
843 let mut values = Vec::new();
844 let mut current_start = 0;
845 let mut in_string = false;
846 let mut escape_next = false;
847 let mut paren_depth = 0;
848
849 let start = raw.iter().position(|&b| b == b'(')?;
851 let raw = &raw[start + 1..];
852
853 for (i, &b) in raw.iter().enumerate() {
854 if escape_next {
855 escape_next = false;
856 continue;
857 }
858
859 if b == b'\\' && in_string {
860 escape_next = true;
861 continue;
862 }
863
864 if b == b'\'' && !escape_next {
865 in_string = !in_string;
866 continue;
867 }
868
869 if in_string {
870 continue;
871 }
872
873 match b {
874 b'(' => paren_depth += 1,
875 b')' => {
876 if paren_depth == 0 {
877 values.push(&raw[current_start..i]);
878 break;
879 }
880 paren_depth -= 1;
881 }
882 b',' if paren_depth == 0 => {
883 values.push(&raw[current_start..i]);
884 current_start = i + 1;
885 }
886 _ => {}
887 }
888 }
889
890 if column_index >= values.len() {
891 return None;
892 }
893
894 parse_value_bytes(values[column_index])
895}
896
897fn extract_copy_column_value(raw: &[u8], column_index: usize) -> Option<PkValue> {
899 let fields: Vec<&[u8]> = raw.split(|&b| b == b'\t').collect();
900 if column_index >= fields.len() {
901 return None;
902 }
903
904 let field = fields[column_index];
905 if field == b"\\N" {
906 return Some(PkValue::Null);
907 }
908
909 parse_value_bytes(field)
910}
911
912fn parse_value_bytes(bytes: &[u8]) -> Option<PkValue> {
914 let trimmed = bytes
915 .iter()
916 .skip_while(|&&b| b == b' ')
917 .take_while(|&&b| b != b' ')
918 .copied()
919 .collect::<Vec<_>>();
920
921 if trimmed.is_empty() {
922 return None;
923 }
924
925 if trimmed.eq_ignore_ascii_case(b"null") {
927 return Some(PkValue::Null);
928 }
929
930 let unquoted = if trimmed.first() == Some(&b'\'') && trimmed.last() == Some(&b'\'') {
932 &trimmed[1..trimmed.len() - 1]
933 } else {
934 &trimmed[..]
935 };
936
937 if let Ok(s) = std::str::from_utf8(unquoted) {
939 if let Ok(i) = s.parse::<i64>() {
940 return Some(PkValue::Int(i));
941 }
942 if let Ok(i) = s.parse::<i128>() {
943 return Some(PkValue::BigInt(i));
944 }
945 return Some(PkValue::Text(s.into()));
946 }
947
948 None
949}
950
951fn write_output(
953 config: &ShardConfig,
954 _graph: &SchemaGraph,
955 table_order: &[TableId],
956 runtimes: &AHashMap<TableId, TableRuntime>,
957 tables_dir: &Path,
958 stats: &ShardStats,
959) -> anyhow::Result<()> {
960 let mut writer: Box<dyn Write> = match &config.output {
961 Some(path) => {
962 if let Some(parent) = path.parent() {
963 fs::create_dir_all(parent)?;
964 }
965 Box::new(BufWriter::with_capacity(256 * 1024, File::create(path)?))
966 }
967 None => Box::new(BufWriter::new(std::io::stdout())),
968 };
969
970 write_header(&mut writer, config, stats)?;
972
973 write_dialect_header(&mut writer, config.dialect)?;
975
976 if config.include_schema {
978 for &table_id in table_order {
979 let runtime = match runtimes.get(&table_id) {
980 Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
981 _ => continue,
982 };
983
984 let table_file = tables_dir.join(format!("{}.sql", runtime.name));
985 if !table_file.exists() {
986 continue;
987 }
988
989 let file = File::open(&table_file)?;
990 let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
991
992 while let Some(stmt) = parser.read_statement()? {
993 let (stmt_type, _) =
994 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
995
996 if stmt_type.is_schema() {
997 writer.write_all(&stmt)?;
998 writer.write_all(b"\n")?;
999 }
1000 }
1001 }
1002 }
1003
1004 for &table_id in table_order {
1006 let runtime = match runtimes.get(&table_id) {
1007 Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
1008 _ => continue,
1009 };
1010
1011 let table_name = &runtime.name;
1012 let row_count = runtime.selected_rows.len();
1013
1014 writeln!(writer, "\n-- Data: {} ({} rows)", table_name, row_count)?;
1015
1016 const CHUNK_SIZE: usize = 1000;
1017
1018 let quoted_name = match config.dialect {
1019 SqlDialect::MySql => format!("`{}`", table_name),
1020 SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", table_name),
1021 SqlDialect::Mssql => format!("[{}]", table_name),
1022 };
1023
1024 for chunk in runtime.selected_rows.chunks(CHUNK_SIZE) {
1025 writeln!(writer, "INSERT INTO {} VALUES", quoted_name)?;
1026
1027 for (i, row) in chunk.iter().enumerate() {
1028 if i > 0 {
1029 writer.write_all(b",\n")?;
1030 }
1031
1032 let values = match row.format {
1033 RowFormat::Insert => match config.dialect {
1034 SqlDialect::Postgres => convert_row_to_postgres(&row.raw),
1035 _ => row.raw.clone(),
1036 },
1037 RowFormat::Copy => convert_copy_to_insert_values(&row.raw, config.dialect),
1038 };
1039 writer.write_all(&values)?;
1040 }
1041
1042 writer.write_all(b";\n")?;
1043 }
1044 }
1045
1046 write_dialect_footer(&mut writer, config.dialect)?;
1048
1049 writer.flush()?;
1050
1051 Ok(())
1052}
1053
1054fn write_header<W: Write>(
1056 writer: &mut W,
1057 config: &ShardConfig,
1058 stats: &ShardStats,
1059) -> std::io::Result<()> {
1060 writeln!(writer, "-- Sharded from: {}", config.input.display())?;
1061 writeln!(
1062 writer,
1063 "-- Date: {}",
1064 chrono::Local::now().format("%Y-%m-%d %H:%M:%S")
1065 )?;
1066 if let Some(ref col) = stats.detected_tenant_column {
1067 writeln!(writer, "-- Tenant column: {}", col)?;
1068 }
1069 writeln!(writer, "-- Tenant value: {}", config.tenant_value)?;
1070 writeln!(writer, "-- Dialect: {}", config.dialect)?;
1071 writeln!(writer, "--")?;
1072 writeln!(writer, "-- Statistics:")?;
1073 writeln!(writer, "-- Tables processed: {}", stats.tables_processed)?;
1074 writeln!(writer, "-- Tables with data: {}", stats.tables_with_data)?;
1075 writeln!(writer, "-- Tables skipped: {}", stats.tables_skipped)?;
1076
1077 let percent = if stats.total_rows_seen > 0 {
1078 (stats.total_rows_selected as f64 / stats.total_rows_seen as f64) * 100.0
1079 } else {
1080 0.0
1081 };
1082 writeln!(
1083 writer,
1084 "-- Total rows: {} (from {} original, {:.1}%)",
1085 stats.total_rows_selected, stats.total_rows_seen, percent
1086 )?;
1087
1088 if stats.fk_orphans_skipped > 0 {
1089 writeln!(
1090 writer,
1091 "-- FK orphans skipped: {}",
1092 stats.fk_orphans_skipped
1093 )?;
1094 }
1095
1096 if !stats.warnings.is_empty() {
1097 writeln!(writer, "-- Warnings: {}", stats.warnings.len())?;
1098 }
1099
1100 writeln!(writer)?;
1101
1102 Ok(())
1103}
1104
1105fn write_dialect_header<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1107 match dialect {
1108 SqlDialect::MySql => {
1109 writeln!(writer, "SET NAMES utf8mb4;")?;
1110 writeln!(writer, "SET FOREIGN_KEY_CHECKS = 0;")?;
1111 }
1112 SqlDialect::Postgres => {
1113 writeln!(writer, "SET client_encoding = 'UTF8';")?;
1114 writeln!(writer, "SET session_replication_role = replica;")?;
1115 }
1116 SqlDialect::Sqlite => {
1117 writeln!(writer, "PRAGMA foreign_keys = OFF;")?;
1118 }
1119 SqlDialect::Mssql => {
1120 writeln!(writer, "SET ANSI_NULLS ON;")?;
1121 writeln!(writer, "SET QUOTED_IDENTIFIER ON;")?;
1122 writeln!(writer, "SET NOCOUNT ON;")?;
1123 }
1124 }
1125 writeln!(writer)?;
1126 Ok(())
1127}
1128
1129fn write_dialect_footer<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1131 writeln!(writer)?;
1132 match dialect {
1133 SqlDialect::MySql => {
1134 writeln!(writer, "SET FOREIGN_KEY_CHECKS = 1;")?;
1135 }
1136 SqlDialect::Postgres => {
1137 writeln!(writer, "SET session_replication_role = DEFAULT;")?;
1138 }
1139 SqlDialect::Sqlite => {
1140 writeln!(writer, "PRAGMA foreign_keys = ON;")?;
1141 }
1142 SqlDialect::Mssql => {
1143 }
1145 }
1146 Ok(())
1147}
1148
/// Rewrite MySQL-style escaped quotes in a raw INSERT tuple for PostgreSQL.
///
/// `\'` becomes the standard doubled quote `''`. An escaped backslash `\\` is
/// consumed as a pair and copied unchanged. Previously its second backslash
/// could pair with a following closing quote: `'a\\'` turned into `'a\''` and
/// broke the statement. All other bytes are copied verbatim.
fn convert_row_to_postgres(row: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(row.len());
    let mut i = 0;

    while i < row.len() {
        match (row[i], row.get(i + 1).copied()) {
            (b'\\', Some(b'\'')) => {
                result.extend_from_slice(b"''");
                i += 2;
            }
            (b'\\', Some(b'\\')) => {
                result.extend_from_slice(b"\\\\");
                i += 2;
            }
            (byte, _) => {
                result.push(byte);
                i += 1;
            }
        }
    }

    result
}
1167
1168fn convert_copy_to_insert_values(row: &[u8], dialect: SqlDialect) -> Vec<u8> {
1170 let mut result = Vec::with_capacity(row.len() + 20);
1171 result.push(b'(');
1172
1173 let fields: Vec<&[u8]> = row.split(|&b| b == b'\t').collect();
1174
1175 for (i, field) in fields.iter().enumerate() {
1176 if i > 0 {
1177 result.extend_from_slice(b", ");
1178 }
1179
1180 if *field == b"\\N" {
1181 result.extend_from_slice(b"NULL");
1182 } else if field.is_empty() {
1183 result.extend_from_slice(b"''");
1184 } else if is_numeric(field) {
1185 result.extend_from_slice(field);
1186 } else {
1187 result.push(b'\'');
1188 for &b in *field {
1189 match b {
1190 b'\'' => match dialect {
1191 SqlDialect::MySql => result.extend_from_slice(b"\\'"),
1192 SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {
1193 result.extend_from_slice(b"''")
1194 }
1195 },
1196 b'\\' if dialect == SqlDialect::MySql => {
1197 result.extend_from_slice(b"\\\\");
1198 }
1199 _ => result.push(b),
1200 }
1201 }
1202 result.push(b'\'');
1203 }
1204 }
1205
1206 result.push(b')');
1207 result
1208}
1209
1210fn is_numeric(s: &[u8]) -> bool {
1212 if s.is_empty() {
1213 return false;
1214 }
1215
1216 let mut has_digit = false;
1217 let mut has_dot = false;
1218 let mut start = 0;
1219
1220 if s[0] == b'-' || s[0] == b'+' {
1221 start = 1;
1222 }
1223
1224 for &b in &s[start..] {
1225 match b {
1226 b'0'..=b'9' => has_digit = true,
1227 b'.' if !has_dot => has_dot = true,
1228 b'e' | b'E' => continue,
1229 _ => return false,
1230 }
1231 }
1232
1233 has_digit
1234}