sql_splitter/shard/
mod.rs

//! Shard command for extracting tenant-specific data from SQL dumps.
//!
//! The shard command extracts data belonging to a specific tenant by:
//! - Identifying tables with the tenant column (tenant roots)
//! - Following FK chains to include dependent data
//! - Including junction/pivot tables where any FK matches tenant data
//! - Optionally including global/lookup tables
//!
//! Supports MySQL, PostgreSQL, and SQLite dialects.
10
11mod config;
12
13pub use config::{
14    DefaultShardClassifier, GlobalTableMode, ShardTableClassification, ShardYamlConfig,
15};
16
17use crate::parser::mysql_insert::{parse_mysql_insert_rows, ParsedRow, PkSet, PkTuple, PkValue};
18use crate::parser::postgres_copy::{parse_copy_columns, parse_postgres_copy_rows, ParsedCopyRow};
19use crate::parser::{ContentFilter, Parser, SqlDialect, StatementType};
20use crate::schema::{SchemaBuilder, SchemaGraph, TableId, TableSchema};
21use crate::splitter::Splitter;
22use ahash::{AHashMap, AHashSet};
23use indicatif::{ProgressBar, ProgressStyle};
24use std::fs::{self, File};
25use std::io::{BufWriter, Write};
26use std::path::{Path, PathBuf};
27use tempfile::TempDir;
28
29/// Configuration for the shard command
30#[derive(Debug)]
31pub struct ShardConfig {
32    /// Input SQL file
33    pub input: PathBuf,
34    /// Output SQL file (None for stdout)
35    pub output: Option<PathBuf>,
36    /// SQL dialect
37    pub dialect: SqlDialect,
38    /// Tenant column name (auto-detected if None)
39    pub tenant_column: Option<String>,
40    /// Tenant value to extract
41    pub tenant_value: String,
42    /// Explicit root tables (tables that have the tenant column)
43    pub root_tables: Vec<String>,
44    /// How to handle global/lookup tables
45    pub include_global: GlobalTableMode,
46    /// Dry run mode (show stats only)
47    pub dry_run: bool,
48    /// Show progress
49    pub progress: bool,
50    /// YAML config file path
51    pub config_file: Option<PathBuf>,
52    /// Maximum selected rows (memory guard)
53    pub max_selected_rows: Option<usize>,
54    /// Fail if any FK integrity issues detected
55    pub strict_fk: bool,
56    /// Include schema statements in output
57    pub include_schema: bool,
58}
59
60impl Default for ShardConfig {
61    fn default() -> Self {
62        Self {
63            input: PathBuf::new(),
64            output: None,
65            dialect: SqlDialect::MySql,
66            tenant_column: None,
67            tenant_value: String::new(),
68            root_tables: Vec::new(),
69            include_global: GlobalTableMode::Lookups,
70            dry_run: false,
71            progress: false,
72            config_file: None,
73            max_selected_rows: Some(10_000_000),
74            strict_fk: false,
75            include_schema: true,
76        }
77    }
78}
79
80/// Statistics from shard operation
81#[derive(Debug, Default)]
82pub struct ShardStats {
83    /// Number of tables processed
84    pub tables_processed: usize,
85    /// Number of tables skipped
86    pub tables_skipped: usize,
87    /// Number of tables with data included
88    pub tables_with_data: usize,
89    /// Total rows selected
90    pub total_rows_selected: u64,
91    /// Total rows seen
92    pub total_rows_seen: u64,
93    /// Per-table statistics
94    pub table_stats: Vec<TableShardStats>,
95    /// Warning messages
96    pub warnings: Vec<String>,
97    /// FK orphan count (rows with missing parents)
98    pub fk_orphans_skipped: u64,
99    /// Detected tenant column
100    pub detected_tenant_column: Option<String>,
101}
102
103/// Per-table sharding statistics
104#[derive(Debug, Clone)]
105pub struct TableShardStats {
106    pub name: String,
107    pub rows_seen: u64,
108    pub rows_selected: u64,
109    pub classification: ShardTableClassification,
110}
111
112/// Runtime state for a table during sharding
113struct TableRuntime {
114    /// Table name
115    name: String,
116    /// Selected rows (raw INSERT format)
117    selected_rows: Vec<SelectedRow>,
118    /// Primary key set for FK membership checks
119    pk_set: PkSet,
120    /// Rows seen count
121    rows_seen: u64,
122    /// Whether to skip this table
123    skip: bool,
124    /// Table classification
125    classification: ShardTableClassification,
126    /// FK orphans encountered
127    fk_orphans: u64,
128    /// Column index for tenant column (if this is a tenant root)
129    tenant_column_index: Option<usize>,
130}
131
/// Source format of a selected row's raw bytes.
#[derive(Debug, Clone, Copy, PartialEq)]
enum RowFormat {
    /// A value tuple taken from an `INSERT ... VALUES` statement.
    Insert,
    /// A tab-separated data line from a PostgreSQL `COPY` block.
    Copy,
}

/// Raw bytes of a row kept in the shard, tagged with the format they are in.
struct SelectedRow {
    raw: Vec<u8>,
    format: RowFormat,
}
144
145/// Combined row representation for both MySQL INSERT and PostgreSQL COPY
146enum UnifiedRow {
147    Insert(ParsedRow),
148    Copy(ParsedCopyRow),
149}
150
151impl UnifiedRow {
152    fn pk(&self) -> Option<&PkTuple> {
153        match self {
154            UnifiedRow::Insert(r) => r.pk.as_ref(),
155            UnifiedRow::Copy(r) => r.pk.as_ref(),
156        }
157    }
158
159    fn fk_values(&self) -> &[(crate::parser::mysql_insert::FkRef, PkTuple)] {
160        match self {
161            UnifiedRow::Insert(r) => &r.fk_values,
162            UnifiedRow::Copy(r) => &r.fk_values,
163        }
164    }
165
166    fn into_selected(self) -> SelectedRow {
167        match self {
168            UnifiedRow::Insert(r) => SelectedRow {
169                raw: r.raw,
170                format: RowFormat::Insert,
171            },
172            UnifiedRow::Copy(r) => SelectedRow {
173                raw: r.raw,
174                format: RowFormat::Copy,
175            },
176        }
177    }
178}
179
180/// Run the shard command
181pub fn run(config: ShardConfig) -> anyhow::Result<ShardStats> {
182    let yaml_config = if let Some(ref path) = config.config_file {
183        Some(ShardYamlConfig::load(path)?)
184    } else {
185        None
186    };
187
188    let mut stats = ShardStats::default();
189
190    let progress_bar = if config.progress {
191        let pb = ProgressBar::new_spinner();
192        pb.set_style(
193            ProgressStyle::default_spinner()
194                .template("{spinner:.green} {msg}")
195                .unwrap(),
196        );
197        Some(pb)
198    } else {
199        None
200    };
201
202    // Phase 0: Split into temp per-table files
203    let temp_dir = TempDir::new()?;
204    let tables_dir = temp_dir.path().join("tables");
205
206    if let Some(ref pb) = progress_bar {
207        pb.set_message("Splitting dump into per-table files...");
208    }
209
210    let splitter = Splitter::new(config.input.clone(), tables_dir.clone())
211        .with_dialect(config.dialect)
212        .with_content_filter(ContentFilter::All);
213
214    let split_stats = splitter.split()?;
215
216    if let Some(ref pb) = progress_bar {
217        pb.set_message(format!(
218            "Split complete: {} tables, {} statements",
219            split_stats.tables_found, split_stats.statements_processed
220        ));
221    }
222
223    // Phase 1: Build schema graph
224    if let Some(ref pb) = progress_bar {
225        pb.set_message("Building schema graph...");
226    }
227
228    let graph = build_schema_graph(&tables_dir, &config)?;
229
230    // Detect or use configured tenant column
231    let tenant_column = detect_tenant_column(&config, &yaml_config, &graph)?;
232    stats.detected_tenant_column = Some(tenant_column.clone());
233
234    if let Some(ref pb) = progress_bar {
235        pb.set_message(format!("Using tenant column: {}", tenant_column));
236    }
237
238    // Parse tenant value
239    let tenant_pk_value = parse_tenant_value(&config.tenant_value);
240
241    // Phase 2: Classify tables and build runtimes
242    let (topo_order, cyclic_tables) = graph.processing_order();
243
244    if !cyclic_tables.is_empty() {
245        let names: Vec<_> = cyclic_tables
246            .iter()
247            .filter_map(|&id| graph.table_name(id))
248            .collect();
249        stats.warnings.push(format!(
250            "{} tables have FK cycles (relaxed FK enforcement): {:?}",
251            cyclic_tables.len(),
252            names
253        ));
254    }
255
256    let cyclic_set: AHashSet<TableId> = cyclic_tables.iter().copied().collect();
257
258    // Determine tenant root tables
259    let tenant_root_ids = find_tenant_root_tables(&graph, &tenant_column, &config, &yaml_config);
260
261    // Build reachability from tenant roots
262    let reachable_from_roots = compute_reachable_tables(&graph, &tenant_root_ids);
263
264    // Initialize table runtimes with classification
265    let mut runtimes: AHashMap<TableId, TableRuntime> = AHashMap::new();
266    for table in graph.schema.iter() {
267        let classification = classify_table(
268            &table.name,
269            table.id,
270            &graph,
271            &tenant_root_ids,
272            &reachable_from_roots,
273            &yaml_config,
274        );
275
276        let tenant_column_index = if classification == ShardTableClassification::TenantRoot {
277            find_tenant_column_index(table, &tenant_column)
278        } else {
279            None
280        };
281
282        let skip = should_skip_table(&table.name, classification, &config, &yaml_config);
283
284        runtimes.insert(
285            table.id,
286            TableRuntime {
287                name: table.name.clone(),
288                selected_rows: Vec::new(),
289                pk_set: PkSet::default(),
290                rows_seen: 0,
291                skip,
292                classification,
293                fk_orphans: 0,
294                tenant_column_index,
295            },
296        );
297    }
298
299    // Phase 3: Process tables in dependency order
300    if let Some(ref pb) = progress_bar {
301        pb.set_message(format!(
302            "Processing {} tables for tenant {}...",
303            topo_order.len() + cyclic_tables.len(),
304            config.tenant_value
305        ));
306    }
307
308    let all_tables: Vec<TableId> = topo_order.into_iter().chain(cyclic_tables).collect();
309    let mut total_selected: u64 = 0;
310
311    for &table_id in &all_tables {
312        let table_schema = match graph.schema.table(table_id) {
313            Some(s) => s,
314            None => continue,
315        };
316
317        let (should_skip, table_name, classification, tenant_col_idx) = {
318            let runtime = match runtimes.get(&table_id) {
319                Some(r) => r,
320                None => continue,
321            };
322            (
323                runtime.skip,
324                runtime.name.clone(),
325                runtime.classification,
326                runtime.tenant_column_index,
327            )
328        };
329
330        if should_skip {
331            stats.tables_skipped += 1;
332            continue;
333        }
334
335        // Handle lookup/system tables
336        let include_all = match classification {
337            ShardTableClassification::Lookup => match config.include_global {
338                GlobalTableMode::None => {
339                    stats.tables_skipped += 1;
340                    continue;
341                }
342                GlobalTableMode::Lookups | GlobalTableMode::All => true,
343            },
344            ShardTableClassification::System => {
345                stats.tables_skipped += 1;
346                continue;
347            }
348            ShardTableClassification::Unknown => match config.include_global {
349                GlobalTableMode::All => true,
350                _ => {
351                    stats.tables_skipped += 1;
352                    continue;
353                }
354            },
355            _ => false,
356        };
357
358        let table_file = tables_dir.join(format!("{}.sql", table_name));
359        if !table_file.exists() {
360            continue;
361        }
362
363        let file = File::open(&table_file)?;
364        let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
365
366        let mut rows_seen = 0u64;
367        let mut fk_orphans = 0u64;
368        let mut copy_columns: Vec<String> = Vec::new();
369
370        while let Some(stmt) = parser.read_statement()? {
371            let (stmt_type, _) =
372                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
373
374            match stmt_type {
375                StatementType::Insert => {
376                    let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
377
378                    for row in rows {
379                        rows_seen += 1;
380                        let unified = UnifiedRow::Insert(row);
381
382                        let should_include = if include_all {
383                            true
384                        } else {
385                            should_include_row(
386                                &unified,
387                                table_schema,
388                                classification,
389                                tenant_col_idx,
390                                &tenant_pk_value,
391                                &runtimes,
392                                &cyclic_set,
393                                &table_id,
394                            )
395                        };
396
397                        if !should_include {
398                            if classification == ShardTableClassification::TenantDependent {
399                                fk_orphans += 1;
400                            }
401                            continue;
402                        }
403
404                        // Check max_selected_rows guard
405                        if let Some(max) = config.max_selected_rows {
406                            if total_selected >= max as u64 {
407                                stats.warnings.push(format!(
408                                    "Reached max_selected_rows limit ({}) at table '{}'",
409                                    max, table_name
410                                ));
411                                break;
412                            }
413                        }
414
415                        total_selected += 1;
416
417                        let runtime = runtimes.get_mut(&table_id).unwrap();
418                        if let Some(pk) = unified.pk() {
419                            runtime.pk_set.insert(pk.clone());
420                        }
421                        runtime.selected_rows.push(unified.into_selected());
422                    }
423                }
424                StatementType::Copy => {
425                    let header = String::from_utf8_lossy(&stmt);
426                    copy_columns = parse_copy_columns(&header);
427                }
428                StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
429                    if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
430                        let rows =
431                            parse_postgres_copy_rows(&stmt, table_schema, copy_columns.clone())?;
432
433                        for row in rows {
434                            rows_seen += 1;
435                            let unified = UnifiedRow::Copy(row);
436
437                            let should_include = if include_all {
438                                true
439                            } else {
440                                should_include_row(
441                                    &unified,
442                                    table_schema,
443                                    classification,
444                                    tenant_col_idx,
445                                    &tenant_pk_value,
446                                    &runtimes,
447                                    &cyclic_set,
448                                    &table_id,
449                                )
450                            };
451
452                            if !should_include {
453                                if classification == ShardTableClassification::TenantDependent {
454                                    fk_orphans += 1;
455                                }
456                                continue;
457                            }
458
459                            if let Some(max) = config.max_selected_rows {
460                                if total_selected >= max as u64 {
461                                    break;
462                                }
463                            }
464
465                            total_selected += 1;
466
467                            let runtime = runtimes.get_mut(&table_id).unwrap();
468                            if let Some(pk) = unified.pk() {
469                                runtime.pk_set.insert(pk.clone());
470                            }
471                            runtime.selected_rows.push(unified.into_selected());
472                        }
473                    }
474                }
475                _ => {}
476            }
477        }
478
479        let runtime = runtimes.get_mut(&table_id).unwrap();
480        runtime.rows_seen = rows_seen;
481        runtime.fk_orphans = fk_orphans;
482        stats.fk_orphans_skipped += fk_orphans;
483
484        if !runtime.selected_rows.is_empty() {
485            stats.tables_with_data += 1;
486        }
487
488        stats.table_stats.push(TableShardStats {
489            name: runtime.name.clone(),
490            rows_seen: runtime.rows_seen,
491            rows_selected: runtime.selected_rows.len() as u64,
492            classification: runtime.classification,
493        });
494    }
495
496    // Calculate totals
497    for table_stat in &stats.table_stats {
498        stats.total_rows_seen += table_stat.rows_seen;
499        stats.total_rows_selected += table_stat.rows_selected;
500    }
501    stats.tables_processed = stats.table_stats.len();
502
503    if let Some(ref pb) = progress_bar {
504        pb.finish_with_message("Processing complete");
505    }
506
507    // Phase 4: Output synthesis
508    if config.dry_run {
509        return Ok(stats);
510    }
511
512    write_output(&config, &graph, &all_tables, &runtimes, &tables_dir, &stats)?;
513
514    Ok(stats)
515}
516
517/// Build schema graph from split table files
518fn build_schema_graph(tables_dir: &Path, config: &ShardConfig) -> anyhow::Result<SchemaGraph> {
519    let mut builder = SchemaBuilder::new();
520
521    for entry in fs::read_dir(tables_dir)? {
522        let entry = entry?;
523        let path = entry.path();
524
525        if path.extension().is_some_and(|e| e == "sql") {
526            let file = File::open(&path)?;
527            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
528
529            while let Some(stmt) = parser.read_statement()? {
530                let (stmt_type, _) =
531                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
532
533                match stmt_type {
534                    StatementType::CreateTable => {
535                        let stmt_str = String::from_utf8_lossy(&stmt);
536                        builder.parse_create_table(&stmt_str);
537                    }
538                    StatementType::AlterTable => {
539                        let stmt_str = String::from_utf8_lossy(&stmt);
540                        builder.parse_alter_table(&stmt_str);
541                    }
542                    _ => {}
543                }
544            }
545        }
546    }
547
548    Ok(SchemaGraph::from_schema(builder.build()))
549}
550
551/// Detect tenant column from config or by scanning schema
552fn detect_tenant_column(
553    config: &ShardConfig,
554    yaml_config: &Option<ShardYamlConfig>,
555    graph: &SchemaGraph,
556) -> anyhow::Result<String> {
557    // Check CLI option first
558    if let Some(ref col) = config.tenant_column {
559        return Ok(col.clone());
560    }
561
562    // Check YAML config
563    if let Some(ref yaml) = yaml_config {
564        if let Some(ref col) = yaml.tenant.column {
565            return Ok(col.clone());
566        }
567    }
568
569    // Auto-detect from schema
570    for candidate in DefaultShardClassifier::TENANT_COLUMNS {
571        let mut found_in_tables = 0;
572        for table in graph.schema.iter() {
573            if table.get_column(candidate).is_some() {
574                found_in_tables += 1;
575            }
576        }
577        if found_in_tables >= 2 {
578            return Ok(candidate.to_string());
579        }
580    }
581
582    anyhow::bail!(
583        "Could not auto-detect tenant column. Please specify with --tenant-column. \
584         Looked for: {:?}",
585        DefaultShardClassifier::TENANT_COLUMNS
586    )
587}
588
589/// Parse tenant value string into PkValue
590fn parse_tenant_value(value: &str) -> PkValue {
591    if let Ok(i) = value.parse::<i64>() {
592        PkValue::Int(i)
593    } else if let Ok(i) = value.parse::<i128>() {
594        PkValue::BigInt(i)
595    } else {
596        PkValue::Text(value.into())
597    }
598}
599
600/// Find tables that have the tenant column
601fn find_tenant_root_tables(
602    graph: &SchemaGraph,
603    tenant_column: &str,
604    config: &ShardConfig,
605    yaml_config: &Option<ShardYamlConfig>,
606) -> AHashSet<TableId> {
607    let mut roots = AHashSet::new();
608
609    // Explicit roots from config
610    let explicit_roots: AHashSet<String> = config
611        .root_tables
612        .iter()
613        .chain(
614            yaml_config
615                .as_ref()
616                .map(|y| &y.tenant.root_tables)
617                .unwrap_or(&Vec::new()),
618        )
619        .map(|s| s.to_lowercase())
620        .collect();
621
622    for table in graph.schema.iter() {
623        let lower_name = table.name.to_lowercase();
624
625        if explicit_roots.contains(&lower_name) || table.get_column(tenant_column).is_some() {
626            roots.insert(table.id);
627        }
628    }
629
630    roots
631}
632
633/// Compute tables reachable from tenant roots via FK relationships
634fn compute_reachable_tables(
635    graph: &SchemaGraph,
636    tenant_roots: &AHashSet<TableId>,
637) -> AHashSet<TableId> {
638    let mut reachable = tenant_roots.clone();
639    let mut queue: Vec<TableId> = tenant_roots.iter().copied().collect();
640
641    while let Some(table_id) = queue.pop() {
642        for &child_id in &graph.children[table_id.0 as usize] {
643            if !reachable.contains(&child_id) {
644                reachable.insert(child_id);
645                queue.push(child_id);
646            }
647        }
648    }
649
650    reachable
651}
652
653/// Classify a table for sharding
654fn classify_table(
655    table_name: &str,
656    table_id: TableId,
657    graph: &SchemaGraph,
658    tenant_roots: &AHashSet<TableId>,
659    reachable: &AHashSet<TableId>,
660    yaml_config: &Option<ShardYamlConfig>,
661) -> ShardTableClassification {
662    // Check YAML override first
663    if let Some(ref yaml) = yaml_config {
664        if let Some(class) = yaml.get_classification(table_name) {
665            return class;
666        }
667    }
668
669    // Check if it's a tenant root
670    if tenant_roots.contains(&table_id) {
671        return ShardTableClassification::TenantRoot;
672    }
673
674    // Check if reachable from tenant roots (dependent)
675    if reachable.contains(&table_id) {
676        // Check if it might be a junction table
677        if is_junction_table(table_name, table_id, graph) {
678            return ShardTableClassification::Junction;
679        }
680        return ShardTableClassification::TenantDependent;
681    }
682
683    // Check system patterns
684    if DefaultShardClassifier::is_system_table(table_name) {
685        return ShardTableClassification::System;
686    }
687
688    // Check lookup patterns
689    if DefaultShardClassifier::is_lookup_table(table_name) {
690        return ShardTableClassification::Lookup;
691    }
692
693    ShardTableClassification::Unknown
694}
695
696/// Check if a table is a junction table
697fn is_junction_table(table_name: &str, table_id: TableId, graph: &SchemaGraph) -> bool {
698    // Name-based heuristic
699    if DefaultShardClassifier::is_junction_table_by_name(table_name) {
700        return true;
701    }
702
703    // Structure-based: table with multiple FKs and few/no other columns
704    if let Some(table) = graph.schema.table(table_id) {
705        let fk_count = table.foreign_keys.len();
706        let fk_col_count: usize = table.foreign_keys.iter().map(|fk| fk.columns.len()).sum();
707        let total_cols = table.columns.len();
708
709        // Junction tables typically have mostly FK columns
710        if fk_count >= 2 && fk_col_count >= total_cols.saturating_sub(2) {
711            return true;
712        }
713    }
714
715    false
716}
717
718/// Find the index of the tenant column in a table
719fn find_tenant_column_index(table: &TableSchema, tenant_column: &str) -> Option<usize> {
720    table
721        .columns
722        .iter()
723        .position(|c| c.name.eq_ignore_ascii_case(tenant_column))
724}
725
726/// Determine if a table should be skipped
727fn should_skip_table(
728    table_name: &str,
729    classification: ShardTableClassification,
730    config: &ShardConfig,
731    yaml_config: &Option<ShardYamlConfig>,
732) -> bool {
733    // Check YAML skip override
734    if let Some(ref yaml) = yaml_config {
735        if yaml.should_skip(table_name) {
736            return true;
737        }
738    }
739
740    // System tables always skipped
741    if classification == ShardTableClassification::System {
742        return true;
743    }
744
745    // Lookup tables depend on config
746    if classification == ShardTableClassification::Lookup {
747        return config.include_global == GlobalTableMode::None;
748    }
749
750    false
751}
752
753/// Check if a row should be included in the shard
754#[allow(clippy::too_many_arguments)]
755fn should_include_row(
756    row: &UnifiedRow,
757    table_schema: &TableSchema,
758    classification: ShardTableClassification,
759    tenant_column_index: Option<usize>,
760    tenant_value: &PkValue,
761    runtimes: &AHashMap<TableId, TableRuntime>,
762    cyclic_set: &AHashSet<TableId>,
763    current_table_id: &TableId,
764) -> bool {
765    match classification {
766        ShardTableClassification::TenantRoot => {
767            // Check tenant column value
768            if let Some(idx) = tenant_column_index {
769                match row {
770                    UnifiedRow::Insert(r) => {
771                        if let Some(val) = extract_column_value(&r.raw, idx) {
772                            return &val == tenant_value;
773                        }
774                    }
775                    UnifiedRow::Copy(r) => {
776                        if let Some(val) = extract_copy_column_value(&r.raw, idx) {
777                            return &val == tenant_value;
778                        }
779                    }
780                }
781            }
782            false
783        }
784        ShardTableClassification::TenantDependent => {
785            // Check if any FK points to a selected row
786            for (fk_ref, fk_tuple) in row.fk_values() {
787                if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
788                    if let Some(parent_id) = fk.referenced_table_id {
789                        // Skip FK check for cyclic tables
790                        if cyclic_set.contains(&parent_id) && cyclic_set.contains(current_table_id)
791                        {
792                            continue;
793                        }
794
795                        if let Some(parent_runtime) = runtimes.get(&parent_id) {
796                            if parent_runtime.pk_set.contains(fk_tuple) {
797                                return true;
798                            }
799                        }
800                    }
801                }
802            }
803            false
804        }
805        ShardTableClassification::Junction => {
806            // Include if ANY FK points to a selected row
807            for (fk_ref, fk_tuple) in row.fk_values() {
808                if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
809                    if let Some(parent_id) = fk.referenced_table_id {
810                        if let Some(parent_runtime) = runtimes.get(&parent_id) {
811                            if parent_runtime.pk_set.contains(fk_tuple) {
812                                return true;
813                            }
814                        }
815                    }
816                }
817            }
818            false
819        }
820        _ => false,
821    }
822}
823
824/// Extract a column value from INSERT row bytes by index
825fn extract_column_value(raw: &[u8], column_index: usize) -> Option<PkValue> {
826    let mut values = Vec::new();
827    let mut current_start = 0;
828    let mut in_string = false;
829    let mut escape_next = false;
830    let mut paren_depth = 0;
831
832    // Skip leading (
833    let start = raw.iter().position(|&b| b == b'(')?;
834    let raw = &raw[start + 1..];
835
836    for (i, &b) in raw.iter().enumerate() {
837        if escape_next {
838            escape_next = false;
839            continue;
840        }
841
842        if b == b'\\' && in_string {
843            escape_next = true;
844            continue;
845        }
846
847        if b == b'\'' && !escape_next {
848            in_string = !in_string;
849            continue;
850        }
851
852        if in_string {
853            continue;
854        }
855
856        match b {
857            b'(' => paren_depth += 1,
858            b')' => {
859                if paren_depth == 0 {
860                    values.push(&raw[current_start..i]);
861                    break;
862                }
863                paren_depth -= 1;
864            }
865            b',' if paren_depth == 0 => {
866                values.push(&raw[current_start..i]);
867                current_start = i + 1;
868            }
869            _ => {}
870        }
871    }
872
873    if column_index >= values.len() {
874        return None;
875    }
876
877    parse_value_bytes(values[column_index])
878}
879
880/// Extract a column value from COPY row bytes by index
881fn extract_copy_column_value(raw: &[u8], column_index: usize) -> Option<PkValue> {
882    let fields: Vec<&[u8]> = raw.split(|&b| b == b'\t').collect();
883    if column_index >= fields.len() {
884        return None;
885    }
886
887    let field = fields[column_index];
888    if field == b"\\N" {
889        return Some(PkValue::Null);
890    }
891
892    parse_value_bytes(field)
893}
894
895/// Parse a byte slice into a PkValue
896fn parse_value_bytes(bytes: &[u8]) -> Option<PkValue> {
897    let trimmed = bytes
898        .iter()
899        .skip_while(|&&b| b == b' ')
900        .take_while(|&&b| b != b' ')
901        .copied()
902        .collect::<Vec<_>>();
903
904    if trimmed.is_empty() {
905        return None;
906    }
907
908    // Check for NULL
909    if trimmed.eq_ignore_ascii_case(b"null") {
910        return Some(PkValue::Null);
911    }
912
913    // Remove quotes for strings
914    let unquoted = if trimmed.first() == Some(&b'\'') && trimmed.last() == Some(&b'\'') {
915        &trimmed[1..trimmed.len() - 1]
916    } else {
917        &trimmed[..]
918    };
919
920    // Try parsing as number
921    if let Ok(s) = std::str::from_utf8(unquoted) {
922        if let Ok(i) = s.parse::<i64>() {
923            return Some(PkValue::Int(i));
924        }
925        if let Ok(i) = s.parse::<i128>() {
926            return Some(PkValue::BigInt(i));
927        }
928        return Some(PkValue::Text(s.into()));
929    }
930
931    None
932}
933
934/// Write the sharded output
935fn write_output(
936    config: &ShardConfig,
937    _graph: &SchemaGraph,
938    table_order: &[TableId],
939    runtimes: &AHashMap<TableId, TableRuntime>,
940    tables_dir: &Path,
941    stats: &ShardStats,
942) -> anyhow::Result<()> {
943    let mut writer: Box<dyn Write> = match &config.output {
944        Some(path) => {
945            if let Some(parent) = path.parent() {
946                fs::create_dir_all(parent)?;
947            }
948            Box::new(BufWriter::with_capacity(256 * 1024, File::create(path)?))
949        }
950        None => Box::new(BufWriter::new(std::io::stdout())),
951    };
952
953    // Write header comment
954    write_header(&mut writer, config, stats)?;
955
956    // Write dialect-specific header
957    write_dialect_header(&mut writer, config.dialect)?;
958
959    // Write schema for each table (if enabled)
960    if config.include_schema {
961        for &table_id in table_order {
962            let runtime = match runtimes.get(&table_id) {
963                Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
964                _ => continue,
965            };
966
967            let table_file = tables_dir.join(format!("{}.sql", runtime.name));
968            if !table_file.exists() {
969                continue;
970            }
971
972            let file = File::open(&table_file)?;
973            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
974
975            while let Some(stmt) = parser.read_statement()? {
976                let (stmt_type, _) =
977                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
978
979                if stmt_type.is_schema() {
980                    writer.write_all(&stmt)?;
981                    writer.write_all(b"\n")?;
982                }
983            }
984        }
985    }
986
987    // Write data for each table
988    for &table_id in table_order {
989        let runtime = match runtimes.get(&table_id) {
990            Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
991            _ => continue,
992        };
993
994        let table_name = &runtime.name;
995        let row_count = runtime.selected_rows.len();
996
997        writeln!(writer, "\n-- Data: {} ({} rows)", table_name, row_count)?;
998
999        const CHUNK_SIZE: usize = 1000;
1000
1001        let quoted_name = match config.dialect {
1002            SqlDialect::MySql => format!("`{}`", table_name),
1003            SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", table_name),
1004        };
1005
1006        for chunk in runtime.selected_rows.chunks(CHUNK_SIZE) {
1007            writeln!(writer, "INSERT INTO {} VALUES", quoted_name)?;
1008
1009            for (i, row) in chunk.iter().enumerate() {
1010                if i > 0 {
1011                    writer.write_all(b",\n")?;
1012                }
1013
1014                let values = match row.format {
1015                    RowFormat::Insert => match config.dialect {
1016                        SqlDialect::Postgres => convert_row_to_postgres(&row.raw),
1017                        _ => row.raw.clone(),
1018                    },
1019                    RowFormat::Copy => convert_copy_to_insert_values(&row.raw, config.dialect),
1020                };
1021                writer.write_all(&values)?;
1022            }
1023
1024            writer.write_all(b";\n")?;
1025        }
1026    }
1027
1028    // Write dialect-specific footer
1029    write_dialect_footer(&mut writer, config.dialect)?;
1030
1031    writer.flush()?;
1032
1033    Ok(())
1034}
1035
1036/// Write header comment
1037fn write_header<W: Write>(
1038    writer: &mut W,
1039    config: &ShardConfig,
1040    stats: &ShardStats,
1041) -> std::io::Result<()> {
1042    writeln!(writer, "-- Sharded from: {}", config.input.display())?;
1043    writeln!(
1044        writer,
1045        "-- Date: {}",
1046        chrono::Local::now().format("%Y-%m-%d %H:%M:%S")
1047    )?;
1048    if let Some(ref col) = stats.detected_tenant_column {
1049        writeln!(writer, "-- Tenant column: {}", col)?;
1050    }
1051    writeln!(writer, "-- Tenant value: {}", config.tenant_value)?;
1052    writeln!(writer, "-- Dialect: {}", config.dialect)?;
1053    writeln!(writer, "--")?;
1054    writeln!(writer, "-- Statistics:")?;
1055    writeln!(writer, "--   Tables processed: {}", stats.tables_processed)?;
1056    writeln!(writer, "--   Tables with data: {}", stats.tables_with_data)?;
1057    writeln!(writer, "--   Tables skipped: {}", stats.tables_skipped)?;
1058
1059    let percent = if stats.total_rows_seen > 0 {
1060        (stats.total_rows_selected as f64 / stats.total_rows_seen as f64) * 100.0
1061    } else {
1062        0.0
1063    };
1064    writeln!(
1065        writer,
1066        "--   Total rows: {} (from {} original, {:.1}%)",
1067        stats.total_rows_selected, stats.total_rows_seen, percent
1068    )?;
1069
1070    if stats.fk_orphans_skipped > 0 {
1071        writeln!(
1072            writer,
1073            "--   FK orphans skipped: {}",
1074            stats.fk_orphans_skipped
1075        )?;
1076    }
1077
1078    if !stats.warnings.is_empty() {
1079        writeln!(writer, "--   Warnings: {}", stats.warnings.len())?;
1080    }
1081
1082    writeln!(writer)?;
1083
1084    Ok(())
1085}
1086
1087/// Write dialect-specific header
1088fn write_dialect_header<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1089    match dialect {
1090        SqlDialect::MySql => {
1091            writeln!(writer, "SET NAMES utf8mb4;")?;
1092            writeln!(writer, "SET FOREIGN_KEY_CHECKS = 0;")?;
1093        }
1094        SqlDialect::Postgres => {
1095            writeln!(writer, "SET client_encoding = 'UTF8';")?;
1096            writeln!(writer, "SET session_replication_role = replica;")?;
1097        }
1098        SqlDialect::Sqlite => {
1099            writeln!(writer, "PRAGMA foreign_keys = OFF;")?;
1100        }
1101    }
1102    writeln!(writer)?;
1103    Ok(())
1104}
1105
1106/// Write dialect-specific footer
1107fn write_dialect_footer<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1108    writeln!(writer)?;
1109    match dialect {
1110        SqlDialect::MySql => {
1111            writeln!(writer, "SET FOREIGN_KEY_CHECKS = 1;")?;
1112        }
1113        SqlDialect::Postgres => {
1114            writeln!(writer, "SET session_replication_role = DEFAULT;")?;
1115        }
1116        SqlDialect::Sqlite => {
1117            writeln!(writer, "PRAGMA foreign_keys = ON;")?;
1118        }
1119    }
1120    Ok(())
1121}
1122
/// Convert a MySQL-style row to PostgreSQL syntax.
///
/// The only rewrite is turning a backslash-escaped single quote (`\'`) into
/// the SQL-standard doubled quote (`''`); every other byte is copied as-is.
///
/// Backslash escapes are consumed as two-byte pairs. An escaped backslash
/// (`\\`) is therefore copied through intact and does not swallow the quote
/// that follows it: the literal `'C:\\'` keeps its closing quote instead of
/// becoming the unterminated `'C:\\''`.
///
/// NOTE(review): other MySQL escapes (`\\`, `\n`, `\0`, ...) are left
/// untouched; with `standard_conforming_strings = on` PostgreSQL reads those
/// backslashes literally.
fn convert_row_to_postgres(row: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(row.len());
    let mut bytes = row.iter().copied();

    while let Some(byte) = bytes.next() {
        if byte != b'\\' {
            result.push(byte);
            continue;
        }
        match bytes.next() {
            Some(b'\'') => result.extend_from_slice(b"''"),
            Some(escaped) => {
                result.push(b'\\');
                result.push(escaped);
            }
            // A trailing lone backslash has nothing to escape.
            None => result.push(b'\\'),
        }
    }

    result
}
1141
1142/// Convert PostgreSQL COPY format to INSERT VALUES format
1143fn convert_copy_to_insert_values(row: &[u8], dialect: SqlDialect) -> Vec<u8> {
1144    let mut result = Vec::with_capacity(row.len() + 20);
1145    result.push(b'(');
1146
1147    let fields: Vec<&[u8]> = row.split(|&b| b == b'\t').collect();
1148
1149    for (i, field) in fields.iter().enumerate() {
1150        if i > 0 {
1151            result.extend_from_slice(b", ");
1152        }
1153
1154        if *field == b"\\N" {
1155            result.extend_from_slice(b"NULL");
1156        } else if field.is_empty() {
1157            result.extend_from_slice(b"''");
1158        } else if is_numeric(field) {
1159            result.extend_from_slice(field);
1160        } else {
1161            result.push(b'\'');
1162            for &b in *field {
1163                match b {
1164                    b'\'' => match dialect {
1165                        SqlDialect::MySql => result.extend_from_slice(b"\\'"),
1166                        SqlDialect::Postgres | SqlDialect::Sqlite => {
1167                            result.extend_from_slice(b"''")
1168                        }
1169                    },
1170                    b'\\' if dialect == SqlDialect::MySql => {
1171                        result.extend_from_slice(b"\\\\");
1172                    }
1173                    _ => result.push(b),
1174                }
1175            }
1176            result.push(b'\'');
1177        }
1178    }
1179
1180    result.push(b')');
1181    result
1182}
1183
/// Check whether `s` is a plain decimal numeric literal that can be emitted
/// unquoted in SQL.
///
/// Accepted grammar: an optional `+`/`-` sign; a mantissa of ASCII digits
/// with at most one `.` and at least one digit (`5`, `.5`, `5.`, `3.14`);
/// then an optional exponent consisting of `e`/`E`, an optional sign, and at
/// least one digit (`1e10`, `2.5E-3`).
///
/// Anything else returns `false` so the caller quotes it. This includes
/// values like `e5`, `1e` or `1e5e5`, which would otherwise be emitted as an
/// identifier or invalid SQL.
fn is_numeric(s: &[u8]) -> bool {
    let body = match s.first() {
        Some(b'-') | Some(b'+') => &s[1..],
        _ => s,
    };

    // Mantissa: digits with at most one decimal point.
    let mut i = 0;
    let mut mantissa_digits = 0;
    let mut seen_dot = false;
    while i < body.len() {
        match body[i] {
            b'0'..=b'9' => mantissa_digits += 1,
            b'.' if !seen_dot => seen_dot = true,
            _ => break,
        }
        i += 1;
    }
    if mantissa_digits == 0 {
        return false;
    }
    if i == body.len() {
        return true;
    }

    // Optional exponent: e/E, optional sign, then one or more digits.
    if body[i] != b'e' && body[i] != b'E' {
        return false;
    }
    i += 1;
    if i < body.len() && (body[i] == b'+' || body[i] == b'-') {
        i += 1;
    }
    let exponent = &body[i..];
    !exponent.is_empty() && exponent.iter().all(u8::is_ascii_digit)
}