Skip to main content

sql_splitter/shard/
mod.rs

1//! Shard command for extracting tenant-specific data from SQL dumps.
2//!
3//! The shard command extracts data belonging to a specific tenant by:
4//! - Identifying tables with the tenant column (tenant roots)
5//! - Following FK chains to include dependent data
6//! - Including junction/pivot tables where any FK matches tenant data
7//! - Optionally including global/lookup tables
8//!
9//! Supports MySQL, PostgreSQL, and SQLite dialects.
10
11mod config;
12
13pub use config::{
14    DefaultShardClassifier, GlobalTableMode, ShardTableClassification, ShardYamlConfig,
15};
16
17use crate::parser::mysql_insert::{parse_mysql_insert_rows, ParsedRow, PkSet, PkTuple, PkValue};
18use crate::parser::postgres_copy::{parse_copy_columns, parse_postgres_copy_rows, ParsedCopyRow};
19use crate::parser::{ContentFilter, Parser, SqlDialect, StatementType};
20use crate::schema::{SchemaBuilder, SchemaGraph, TableId, TableSchema};
21use crate::splitter::Splitter;
22use ahash::{AHashMap, AHashSet};
23use indicatif::{ProgressBar, ProgressStyle};
24use std::fs::{self, File};
25use std::io::{BufWriter, Write};
26use std::path::{Path, PathBuf};
27use tempfile::TempDir;
28
29/// Configuration for the shard command
30#[derive(Debug)]
31pub struct ShardConfig {
32    /// Input SQL file
33    pub input: PathBuf,
34    /// Output SQL file (None for stdout)
35    pub output: Option<PathBuf>,
36    /// SQL dialect
37    pub dialect: SqlDialect,
38    /// Tenant column name (auto-detected if None)
39    pub tenant_column: Option<String>,
40    /// Tenant value to extract
41    pub tenant_value: String,
42    /// Explicit root tables (tables that have the tenant column)
43    pub root_tables: Vec<String>,
44    /// How to handle global/lookup tables
45    pub include_global: GlobalTableMode,
46    /// Dry run mode (show stats only)
47    pub dry_run: bool,
48    /// Show progress
49    pub progress: bool,
50    /// YAML config file path
51    pub config_file: Option<PathBuf>,
52    /// Maximum selected rows (memory guard)
53    pub max_selected_rows: Option<usize>,
54    /// Fail if any FK integrity issues detected
55    pub strict_fk: bool,
56    /// Include schema statements in output
57    pub include_schema: bool,
58}
59
60impl Default for ShardConfig {
61    fn default() -> Self {
62        Self {
63            input: PathBuf::new(),
64            output: None,
65            dialect: SqlDialect::MySql,
66            tenant_column: None,
67            tenant_value: String::new(),
68            root_tables: Vec::new(),
69            include_global: GlobalTableMode::Lookups,
70            dry_run: false,
71            progress: false,
72            config_file: None,
73            max_selected_rows: Some(10_000_000),
74            strict_fk: false,
75            include_schema: true,
76        }
77    }
78}
79
80/// Statistics from shard operation
81#[derive(Debug, Default, serde::Serialize)]
82pub struct ShardStats {
83    /// Number of tables processed
84    pub tables_processed: usize,
85    /// Number of tables skipped
86    pub tables_skipped: usize,
87    /// Number of tables with data included
88    pub tables_with_data: usize,
89    /// Total rows selected
90    pub total_rows_selected: u64,
91    /// Total rows seen
92    pub total_rows_seen: u64,
93    /// Per-table statistics
94    pub table_stats: Vec<TableShardStats>,
95    /// Warning messages
96    pub warnings: Vec<String>,
97    /// FK orphan count (rows with missing parents)
98    pub fk_orphans_skipped: u64,
99    /// Detected tenant column
100    pub detected_tenant_column: Option<String>,
101}
102
103/// Per-table sharding statistics
104#[derive(Debug, Clone, serde::Serialize)]
105pub struct TableShardStats {
106    pub name: String,
107    pub rows_seen: u64,
108    pub rows_selected: u64,
109    pub classification: ShardTableClassification,
110}
111
112/// Runtime state for a table during sharding
113struct TableRuntime {
114    /// Table name
115    name: String,
116    /// Selected rows (raw INSERT format)
117    selected_rows: Vec<SelectedRow>,
118    /// Primary key set for FK membership checks
119    pk_set: PkSet,
120    /// Rows seen count
121    rows_seen: u64,
122    /// Whether to skip this table
123    skip: bool,
124    /// Table classification
125    classification: ShardTableClassification,
126    /// FK orphans encountered
127    fk_orphans: u64,
128    /// Column index for tenant column (if this is a tenant root)
129    tenant_column_index: Option<usize>,
130}
131
/// Tags the statement syntax a selected row originated from.
#[derive(Debug, Clone, Copy, PartialEq)]
enum RowFormat {
    /// Row came from an `INSERT` statement.
    Insert,
    /// Row came from a PostgreSQL `COPY` data block.
    Copy,
}
138
139/// Selected row with format metadata
140struct SelectedRow {
141    raw: Vec<u8>,
142    format: RowFormat,
143}
144
145/// Combined row representation for both MySQL INSERT and PostgreSQL COPY
146enum UnifiedRow {
147    Insert(ParsedRow),
148    Copy(ParsedCopyRow),
149}
150
151impl UnifiedRow {
152    fn pk(&self) -> Option<&PkTuple> {
153        match self {
154            UnifiedRow::Insert(r) => r.pk.as_ref(),
155            UnifiedRow::Copy(r) => r.pk.as_ref(),
156        }
157    }
158
159    fn fk_values(&self) -> &[(crate::parser::mysql_insert::FkRef, PkTuple)] {
160        match self {
161            UnifiedRow::Insert(r) => &r.fk_values,
162            UnifiedRow::Copy(r) => &r.fk_values,
163        }
164    }
165
166    fn into_selected(self) -> SelectedRow {
167        match self {
168            UnifiedRow::Insert(r) => SelectedRow {
169                raw: r.raw,
170                format: RowFormat::Insert,
171            },
172            UnifiedRow::Copy(r) => SelectedRow {
173                raw: r.raw,
174                format: RowFormat::Copy,
175            },
176        }
177    }
178}
179
180/// Run the shard command
181pub fn run(config: ShardConfig) -> anyhow::Result<ShardStats> {
182    let yaml_config = if let Some(ref path) = config.config_file {
183        Some(ShardYamlConfig::load(path)?)
184    } else {
185        None
186    };
187
188    let mut stats = ShardStats::default();
189
190    // Get file size for progress tracking
191    let file_size = std::fs::metadata(&config.input)?.len();
192
193    // Progress bar setup - byte-based for the split phase
194    let progress_bar = if config.progress {
195        let pb = ProgressBar::new(file_size);
196        pb.set_style(
197            ProgressStyle::with_template(
198                "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%) {msg}",
199            )
200            .unwrap()
201            .progress_chars("█▓▒░  ")
202            .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
203        );
204        pb.enable_steady_tick(std::time::Duration::from_millis(100));
205        pb.set_message("Splitting dump...");
206        Some(pb)
207    } else {
208        None
209    };
210
211    // Phase 0: Split into temp per-table files
212    let temp_dir = TempDir::new()?;
213    let tables_dir = temp_dir.path().join("tables");
214
215    let mut splitter = Splitter::new(config.input.clone(), tables_dir.clone())
216        .with_dialect(config.dialect)
217        .with_content_filter(ContentFilter::All);
218
219    if let Some(ref pb) = progress_bar {
220        let pb_clone = pb.clone();
221        splitter = splitter.with_progress(move |bytes| {
222            pb_clone.set_position(bytes);
223        });
224    }
225
226    let split_stats = splitter.split()?;
227
228    // Finish byte-based progress, switch to milestone messages
229    if let Some(ref pb) = progress_bar {
230        pb.finish_and_clear();
231    }
232
233    if config.progress {
234        eprintln!(
235            "Split complete: {} tables, {} statements",
236            split_stats.tables_found, split_stats.statements_processed
237        );
238    }
239
240    // Phase 1: Build schema graph
241    if config.progress {
242        eprintln!("Building schema graph...");
243    }
244
245    let graph = build_schema_graph(&tables_dir, &config)?;
246
247    // Detect or use configured tenant column
248    let tenant_column = detect_tenant_column(&config, &yaml_config, &graph)?;
249    stats.detected_tenant_column = Some(tenant_column.clone());
250
251    if config.progress {
252        eprintln!("Using tenant column: {}", tenant_column);
253    }
254
255    // Parse tenant value
256    let tenant_pk_value = parse_tenant_value(&config.tenant_value);
257
258    // Phase 2: Classify tables and build runtimes
259    let (topo_order, cyclic_tables) = graph.processing_order();
260
261    if !cyclic_tables.is_empty() {
262        let names: Vec<_> = cyclic_tables
263            .iter()
264            .filter_map(|&id| graph.table_name(id))
265            .collect();
266        stats.warnings.push(format!(
267            "{} tables have FK cycles (relaxed FK enforcement): {:?}",
268            cyclic_tables.len(),
269            names
270        ));
271    }
272
273    let cyclic_set: AHashSet<TableId> = cyclic_tables.iter().copied().collect();
274
275    // Determine tenant root tables
276    let tenant_root_ids = find_tenant_root_tables(&graph, &tenant_column, &config, &yaml_config);
277
278    // Build reachability from tenant roots
279    let reachable_from_roots = compute_reachable_tables(&graph, &tenant_root_ids);
280
281    // Initialize table runtimes with classification
282    let mut runtimes: AHashMap<TableId, TableRuntime> = AHashMap::new();
283    for table in graph.schema.iter() {
284        let classification = classify_table(
285            &table.name,
286            table.id,
287            &graph,
288            &tenant_root_ids,
289            &reachable_from_roots,
290            &yaml_config,
291        );
292
293        let tenant_column_index = if classification == ShardTableClassification::TenantRoot {
294            find_tenant_column_index(table, &tenant_column)
295        } else {
296            None
297        };
298
299        let skip = should_skip_table(&table.name, classification, &config, &yaml_config);
300
301        runtimes.insert(
302            table.id,
303            TableRuntime {
304                name: table.name.clone(),
305                selected_rows: Vec::new(),
306                pk_set: PkSet::default(),
307                rows_seen: 0,
308                skip,
309                classification,
310                fk_orphans: 0,
311                tenant_column_index,
312            },
313        );
314    }
315
316    // Phase 3: Process tables in dependency order
317    if config.progress {
318        eprintln!(
319            "Processing {} tables for tenant {}...",
320            topo_order.len() + cyclic_tables.len(),
321            config.tenant_value
322        );
323    }
324
325    let all_tables: Vec<TableId> = topo_order.into_iter().chain(cyclic_tables).collect();
326    let mut total_selected: u64 = 0;
327
328    for &table_id in &all_tables {
329        let table_schema = match graph.schema.table(table_id) {
330            Some(s) => s,
331            None => continue,
332        };
333
334        let (should_skip, table_name, classification, tenant_col_idx) = {
335            let runtime = match runtimes.get(&table_id) {
336                Some(r) => r,
337                None => continue,
338            };
339            (
340                runtime.skip,
341                runtime.name.clone(),
342                runtime.classification,
343                runtime.tenant_column_index,
344            )
345        };
346
347        if should_skip {
348            stats.tables_skipped += 1;
349            continue;
350        }
351
352        // Handle lookup/system tables
353        let include_all = match classification {
354            ShardTableClassification::Lookup => match config.include_global {
355                GlobalTableMode::None => {
356                    stats.tables_skipped += 1;
357                    continue;
358                }
359                GlobalTableMode::Lookups | GlobalTableMode::All => true,
360            },
361            ShardTableClassification::System => {
362                stats.tables_skipped += 1;
363                continue;
364            }
365            ShardTableClassification::Unknown => match config.include_global {
366                GlobalTableMode::All => true,
367                _ => {
368                    stats.tables_skipped += 1;
369                    continue;
370                }
371            },
372            _ => false,
373        };
374
375        let table_file = tables_dir.join(format!("{}.sql", table_name));
376        if !table_file.exists() {
377            continue;
378        }
379
380        let file = File::open(&table_file)?;
381        let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
382
383        let mut rows_seen = 0u64;
384        let mut fk_orphans = 0u64;
385        let mut copy_columns: Vec<String> = Vec::new();
386
387        while let Some(stmt) = parser.read_statement()? {
388            let (stmt_type, _) =
389                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
390
391            match stmt_type {
392                StatementType::Insert => {
393                    let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
394
395                    for row in rows {
396                        rows_seen += 1;
397                        let unified = UnifiedRow::Insert(row);
398
399                        let should_include = if include_all {
400                            true
401                        } else {
402                            should_include_row(
403                                &unified,
404                                table_schema,
405                                classification,
406                                tenant_col_idx,
407                                &tenant_pk_value,
408                                &runtimes,
409                                &cyclic_set,
410                                &table_id,
411                            )
412                        };
413
414                        if !should_include {
415                            if classification == ShardTableClassification::TenantDependent {
416                                fk_orphans += 1;
417                            }
418                            continue;
419                        }
420
421                        // Check max_selected_rows guard
422                        if let Some(max) = config.max_selected_rows {
423                            if total_selected >= max as u64 {
424                                stats.warnings.push(format!(
425                                    "Reached max_selected_rows limit ({}) at table '{}'",
426                                    max, table_name
427                                ));
428                                break;
429                            }
430                        }
431
432                        total_selected += 1;
433
434                        let runtime = runtimes.get_mut(&table_id).unwrap();
435                        if let Some(pk) = unified.pk() {
436                            runtime.pk_set.insert(pk.clone());
437                        }
438                        runtime.selected_rows.push(unified.into_selected());
439                    }
440                }
441                StatementType::Copy => {
442                    let header = String::from_utf8_lossy(&stmt);
443                    copy_columns = parse_copy_columns(&header);
444                }
445                StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
446                    if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
447                        let rows =
448                            parse_postgres_copy_rows(&stmt, table_schema, copy_columns.clone())?;
449
450                        for row in rows {
451                            rows_seen += 1;
452                            let unified = UnifiedRow::Copy(row);
453
454                            let should_include = if include_all {
455                                true
456                            } else {
457                                should_include_row(
458                                    &unified,
459                                    table_schema,
460                                    classification,
461                                    tenant_col_idx,
462                                    &tenant_pk_value,
463                                    &runtimes,
464                                    &cyclic_set,
465                                    &table_id,
466                                )
467                            };
468
469                            if !should_include {
470                                if classification == ShardTableClassification::TenantDependent {
471                                    fk_orphans += 1;
472                                }
473                                continue;
474                            }
475
476                            if let Some(max) = config.max_selected_rows {
477                                if total_selected >= max as u64 {
478                                    break;
479                                }
480                            }
481
482                            total_selected += 1;
483
484                            let runtime = runtimes.get_mut(&table_id).unwrap();
485                            if let Some(pk) = unified.pk() {
486                                runtime.pk_set.insert(pk.clone());
487                            }
488                            runtime.selected_rows.push(unified.into_selected());
489                        }
490                    }
491                }
492                _ => {}
493            }
494        }
495
496        let runtime = runtimes.get_mut(&table_id).unwrap();
497        runtime.rows_seen = rows_seen;
498        runtime.fk_orphans = fk_orphans;
499        stats.fk_orphans_skipped += fk_orphans;
500
501        if !runtime.selected_rows.is_empty() {
502            stats.tables_with_data += 1;
503        }
504
505        stats.table_stats.push(TableShardStats {
506            name: runtime.name.clone(),
507            rows_seen: runtime.rows_seen,
508            rows_selected: runtime.selected_rows.len() as u64,
509            classification: runtime.classification,
510        });
511    }
512
513    // Calculate totals
514    for table_stat in &stats.table_stats {
515        stats.total_rows_seen += table_stat.rows_seen;
516        stats.total_rows_selected += table_stat.rows_selected;
517    }
518    stats.tables_processed = stats.table_stats.len();
519
520    if config.progress {
521        eprintln!("Processing complete");
522    }
523
524    // Phase 4: Output synthesis
525    if config.dry_run {
526        return Ok(stats);
527    }
528
529    write_output(&config, &graph, &all_tables, &runtimes, &tables_dir, &stats)?;
530
531    Ok(stats)
532}
533
534/// Build schema graph from split table files
535fn build_schema_graph(tables_dir: &Path, config: &ShardConfig) -> anyhow::Result<SchemaGraph> {
536    let mut builder = SchemaBuilder::new();
537
538    for entry in fs::read_dir(tables_dir)? {
539        let entry = entry?;
540        let path = entry.path();
541
542        if path.extension().is_some_and(|e| e == "sql") {
543            let file = File::open(&path)?;
544            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
545
546            while let Some(stmt) = parser.read_statement()? {
547                let (stmt_type, _) =
548                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
549
550                match stmt_type {
551                    StatementType::CreateTable => {
552                        let stmt_str = String::from_utf8_lossy(&stmt);
553                        builder.parse_create_table(&stmt_str);
554                    }
555                    StatementType::AlterTable => {
556                        let stmt_str = String::from_utf8_lossy(&stmt);
557                        builder.parse_alter_table(&stmt_str);
558                    }
559                    _ => {}
560                }
561            }
562        }
563    }
564
565    Ok(SchemaGraph::from_schema(builder.build()))
566}
567
568/// Detect tenant column from config or by scanning schema
569fn detect_tenant_column(
570    config: &ShardConfig,
571    yaml_config: &Option<ShardYamlConfig>,
572    graph: &SchemaGraph,
573) -> anyhow::Result<String> {
574    // Check CLI option first
575    if let Some(ref col) = config.tenant_column {
576        return Ok(col.clone());
577    }
578
579    // Check YAML config
580    if let Some(ref yaml) = yaml_config {
581        if let Some(ref col) = yaml.tenant.column {
582            return Ok(col.clone());
583        }
584    }
585
586    // Auto-detect from schema
587    for candidate in DefaultShardClassifier::TENANT_COLUMNS {
588        let mut found_in_tables = 0;
589        for table in graph.schema.iter() {
590            if table.get_column(candidate).is_some() {
591                found_in_tables += 1;
592            }
593        }
594        if found_in_tables >= 2 {
595            return Ok(candidate.to_string());
596        }
597    }
598
599    anyhow::bail!(
600        "Could not auto-detect tenant column. Please specify with --tenant-column. \
601         Looked for: {:?}",
602        DefaultShardClassifier::TENANT_COLUMNS
603    )
604}
605
606/// Parse tenant value string into PkValue
607fn parse_tenant_value(value: &str) -> PkValue {
608    if let Ok(i) = value.parse::<i64>() {
609        PkValue::Int(i)
610    } else if let Ok(i) = value.parse::<i128>() {
611        PkValue::BigInt(i)
612    } else {
613        PkValue::Text(value.into())
614    }
615}
616
617/// Find tables that have the tenant column
618fn find_tenant_root_tables(
619    graph: &SchemaGraph,
620    tenant_column: &str,
621    config: &ShardConfig,
622    yaml_config: &Option<ShardYamlConfig>,
623) -> AHashSet<TableId> {
624    let mut roots = AHashSet::new();
625
626    // Explicit roots from config
627    let explicit_roots: AHashSet<String> = config
628        .root_tables
629        .iter()
630        .chain(
631            yaml_config
632                .as_ref()
633                .map(|y| &y.tenant.root_tables)
634                .unwrap_or(&Vec::new()),
635        )
636        .map(|s| s.to_lowercase())
637        .collect();
638
639    for table in graph.schema.iter() {
640        let lower_name = table.name.to_lowercase();
641
642        if explicit_roots.contains(&lower_name) || table.get_column(tenant_column).is_some() {
643            roots.insert(table.id);
644        }
645    }
646
647    roots
648}
649
650/// Compute tables reachable from tenant roots via FK relationships
651fn compute_reachable_tables(
652    graph: &SchemaGraph,
653    tenant_roots: &AHashSet<TableId>,
654) -> AHashSet<TableId> {
655    let mut reachable = tenant_roots.clone();
656    let mut queue: Vec<TableId> = tenant_roots.iter().copied().collect();
657
658    while let Some(table_id) = queue.pop() {
659        for &child_id in &graph.children[table_id.0 as usize] {
660            if !reachable.contains(&child_id) {
661                reachable.insert(child_id);
662                queue.push(child_id);
663            }
664        }
665    }
666
667    reachable
668}
669
670/// Classify a table for sharding
671fn classify_table(
672    table_name: &str,
673    table_id: TableId,
674    graph: &SchemaGraph,
675    tenant_roots: &AHashSet<TableId>,
676    reachable: &AHashSet<TableId>,
677    yaml_config: &Option<ShardYamlConfig>,
678) -> ShardTableClassification {
679    // Check YAML override first
680    if let Some(ref yaml) = yaml_config {
681        if let Some(class) = yaml.get_classification(table_name) {
682            return class;
683        }
684    }
685
686    // Check if it's a tenant root
687    if tenant_roots.contains(&table_id) {
688        return ShardTableClassification::TenantRoot;
689    }
690
691    // Check if reachable from tenant roots (dependent)
692    if reachable.contains(&table_id) {
693        // Check if it might be a junction table
694        if is_junction_table(table_name, table_id, graph) {
695            return ShardTableClassification::Junction;
696        }
697        return ShardTableClassification::TenantDependent;
698    }
699
700    // Check system patterns
701    if DefaultShardClassifier::is_system_table(table_name) {
702        return ShardTableClassification::System;
703    }
704
705    // Check lookup patterns
706    if DefaultShardClassifier::is_lookup_table(table_name) {
707        return ShardTableClassification::Lookup;
708    }
709
710    ShardTableClassification::Unknown
711}
712
713/// Check if a table is a junction table
714fn is_junction_table(table_name: &str, table_id: TableId, graph: &SchemaGraph) -> bool {
715    // Name-based heuristic
716    if DefaultShardClassifier::is_junction_table_by_name(table_name) {
717        return true;
718    }
719
720    // Structure-based: table with multiple FKs and few/no other columns
721    if let Some(table) = graph.schema.table(table_id) {
722        let fk_count = table.foreign_keys.len();
723        let fk_col_count: usize = table.foreign_keys.iter().map(|fk| fk.columns.len()).sum();
724        let total_cols = table.columns.len();
725
726        // Junction tables typically have mostly FK columns
727        if fk_count >= 2 && fk_col_count >= total_cols.saturating_sub(2) {
728            return true;
729        }
730    }
731
732    false
733}
734
735/// Find the index of the tenant column in a table
736fn find_tenant_column_index(table: &TableSchema, tenant_column: &str) -> Option<usize> {
737    table
738        .columns
739        .iter()
740        .position(|c| c.name.eq_ignore_ascii_case(tenant_column))
741}
742
743/// Determine if a table should be skipped
744fn should_skip_table(
745    table_name: &str,
746    classification: ShardTableClassification,
747    config: &ShardConfig,
748    yaml_config: &Option<ShardYamlConfig>,
749) -> bool {
750    // Check YAML skip override
751    if let Some(ref yaml) = yaml_config {
752        if yaml.should_skip(table_name) {
753            return true;
754        }
755    }
756
757    // System tables always skipped
758    if classification == ShardTableClassification::System {
759        return true;
760    }
761
762    // Lookup tables depend on config
763    if classification == ShardTableClassification::Lookup {
764        return config.include_global == GlobalTableMode::None;
765    }
766
767    false
768}
769
770/// Check if a row should be included in the shard
771#[allow(clippy::too_many_arguments)]
772fn should_include_row(
773    row: &UnifiedRow,
774    table_schema: &TableSchema,
775    classification: ShardTableClassification,
776    tenant_column_index: Option<usize>,
777    tenant_value: &PkValue,
778    runtimes: &AHashMap<TableId, TableRuntime>,
779    cyclic_set: &AHashSet<TableId>,
780    current_table_id: &TableId,
781) -> bool {
782    match classification {
783        ShardTableClassification::TenantRoot => {
784            // Check tenant column value using column_map for correct mapping
785            if let Some(idx) = tenant_column_index {
786                match row {
787                    UnifiedRow::Insert(r) => {
788                        if let Some(val) = r.get_column_value(idx) {
789                            return val == tenant_value;
790                        }
791                    }
792                    UnifiedRow::Copy(r) => {
793                        if let Some(val) = r.get_column_value(idx) {
794                            return val == tenant_value;
795                        }
796                    }
797                }
798            }
799            false
800        }
801        ShardTableClassification::TenantDependent => {
802            // Check if any FK points to a selected row
803            for (fk_ref, fk_tuple) in row.fk_values() {
804                if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
805                    if let Some(parent_id) = fk.referenced_table_id {
806                        // Skip FK check for cyclic tables
807                        if cyclic_set.contains(&parent_id) && cyclic_set.contains(current_table_id)
808                        {
809                            continue;
810                        }
811
812                        if let Some(parent_runtime) = runtimes.get(&parent_id) {
813                            if parent_runtime.pk_set.contains(fk_tuple) {
814                                return true;
815                            }
816                        }
817                    }
818                }
819            }
820            false
821        }
822        ShardTableClassification::Junction => {
823            // Include if ANY FK points to a selected row
824            for (fk_ref, fk_tuple) in row.fk_values() {
825                if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
826                    if let Some(parent_id) = fk.referenced_table_id {
827                        if let Some(parent_runtime) = runtimes.get(&parent_id) {
828                            if parent_runtime.pk_set.contains(fk_tuple) {
829                                return true;
830                            }
831                        }
832                    }
833                }
834            }
835            false
836        }
837        _ => false,
838    }
839}
840
841/// Extract a column value from INSERT row bytes by index
842fn extract_column_value(raw: &[u8], column_index: usize) -> Option<PkValue> {
843    let mut values = Vec::new();
844    let mut current_start = 0;
845    let mut in_string = false;
846    let mut escape_next = false;
847    let mut paren_depth = 0;
848
849    // Skip leading (
850    let start = raw.iter().position(|&b| b == b'(')?;
851    let raw = &raw[start + 1..];
852
853    for (i, &b) in raw.iter().enumerate() {
854        if escape_next {
855            escape_next = false;
856            continue;
857        }
858
859        if b == b'\\' && in_string {
860            escape_next = true;
861            continue;
862        }
863
864        if b == b'\'' && !escape_next {
865            in_string = !in_string;
866            continue;
867        }
868
869        if in_string {
870            continue;
871        }
872
873        match b {
874            b'(' => paren_depth += 1,
875            b')' => {
876                if paren_depth == 0 {
877                    values.push(&raw[current_start..i]);
878                    break;
879                }
880                paren_depth -= 1;
881            }
882            b',' if paren_depth == 0 => {
883                values.push(&raw[current_start..i]);
884                current_start = i + 1;
885            }
886            _ => {}
887        }
888    }
889
890    if column_index >= values.len() {
891        return None;
892    }
893
894    parse_value_bytes(values[column_index])
895}
896
897/// Extract a column value from COPY row bytes by index
898fn extract_copy_column_value(raw: &[u8], column_index: usize) -> Option<PkValue> {
899    let fields: Vec<&[u8]> = raw.split(|&b| b == b'\t').collect();
900    if column_index >= fields.len() {
901        return None;
902    }
903
904    let field = fields[column_index];
905    if field == b"\\N" {
906        return Some(PkValue::Null);
907    }
908
909    parse_value_bytes(field)
910}
911
912/// Parse a byte slice into a PkValue
913fn parse_value_bytes(bytes: &[u8]) -> Option<PkValue> {
914    let trimmed = bytes
915        .iter()
916        .skip_while(|&&b| b == b' ')
917        .take_while(|&&b| b != b' ')
918        .copied()
919        .collect::<Vec<_>>();
920
921    if trimmed.is_empty() {
922        return None;
923    }
924
925    // Check for NULL
926    if trimmed.eq_ignore_ascii_case(b"null") {
927        return Some(PkValue::Null);
928    }
929
930    // Remove quotes for strings
931    let unquoted = if trimmed.first() == Some(&b'\'') && trimmed.last() == Some(&b'\'') {
932        &trimmed[1..trimmed.len() - 1]
933    } else {
934        &trimmed[..]
935    };
936
937    // Try parsing as number
938    if let Ok(s) = std::str::from_utf8(unquoted) {
939        if let Ok(i) = s.parse::<i64>() {
940            return Some(PkValue::Int(i));
941        }
942        if let Ok(i) = s.parse::<i128>() {
943            return Some(PkValue::BigInt(i));
944        }
945        return Some(PkValue::Text(s.into()));
946    }
947
948    None
949}
950
951/// Write the sharded output
952fn write_output(
953    config: &ShardConfig,
954    _graph: &SchemaGraph,
955    table_order: &[TableId],
956    runtimes: &AHashMap<TableId, TableRuntime>,
957    tables_dir: &Path,
958    stats: &ShardStats,
959) -> anyhow::Result<()> {
960    let mut writer: Box<dyn Write> = match &config.output {
961        Some(path) => {
962            if let Some(parent) = path.parent() {
963                fs::create_dir_all(parent)?;
964            }
965            Box::new(BufWriter::with_capacity(256 * 1024, File::create(path)?))
966        }
967        None => Box::new(BufWriter::new(std::io::stdout())),
968    };
969
970    // Write header comment
971    write_header(&mut writer, config, stats)?;
972
973    // Write dialect-specific header
974    write_dialect_header(&mut writer, config.dialect)?;
975
976    // Write schema for each table (if enabled)
977    if config.include_schema {
978        for &table_id in table_order {
979            let runtime = match runtimes.get(&table_id) {
980                Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
981                _ => continue,
982            };
983
984            let table_file = tables_dir.join(format!("{}.sql", runtime.name));
985            if !table_file.exists() {
986                continue;
987            }
988
989            let file = File::open(&table_file)?;
990            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
991
992            while let Some(stmt) = parser.read_statement()? {
993                let (stmt_type, _) =
994                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
995
996                if stmt_type.is_schema() {
997                    writer.write_all(&stmt)?;
998                    writer.write_all(b"\n")?;
999                }
1000            }
1001        }
1002    }
1003
1004    // Write data for each table
1005    for &table_id in table_order {
1006        let runtime = match runtimes.get(&table_id) {
1007            Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
1008            _ => continue,
1009        };
1010
1011        let table_name = &runtime.name;
1012        let row_count = runtime.selected_rows.len();
1013
1014        writeln!(writer, "\n-- Data: {} ({} rows)", table_name, row_count)?;
1015
1016        const CHUNK_SIZE: usize = 1000;
1017
1018        let quoted_name = match config.dialect {
1019            SqlDialect::MySql => format!("`{}`", table_name),
1020            SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", table_name),
1021            SqlDialect::Mssql => format!("[{}]", table_name),
1022        };
1023
1024        for chunk in runtime.selected_rows.chunks(CHUNK_SIZE) {
1025            writeln!(writer, "INSERT INTO {} VALUES", quoted_name)?;
1026
1027            for (i, row) in chunk.iter().enumerate() {
1028                if i > 0 {
1029                    writer.write_all(b",\n")?;
1030                }
1031
1032                let values = match row.format {
1033                    RowFormat::Insert => match config.dialect {
1034                        SqlDialect::Postgres => convert_row_to_postgres(&row.raw),
1035                        _ => row.raw.clone(),
1036                    },
1037                    RowFormat::Copy => convert_copy_to_insert_values(&row.raw, config.dialect),
1038                };
1039                writer.write_all(&values)?;
1040            }
1041
1042            writer.write_all(b";\n")?;
1043        }
1044    }
1045
1046    // Write dialect-specific footer
1047    write_dialect_footer(&mut writer, config.dialect)?;
1048
1049    writer.flush()?;
1050
1051    Ok(())
1052}
1053
1054/// Write header comment
1055fn write_header<W: Write>(
1056    writer: &mut W,
1057    config: &ShardConfig,
1058    stats: &ShardStats,
1059) -> std::io::Result<()> {
1060    writeln!(writer, "-- Sharded from: {}", config.input.display())?;
1061    writeln!(
1062        writer,
1063        "-- Date: {}",
1064        chrono::Local::now().format("%Y-%m-%d %H:%M:%S")
1065    )?;
1066    if let Some(ref col) = stats.detected_tenant_column {
1067        writeln!(writer, "-- Tenant column: {}", col)?;
1068    }
1069    writeln!(writer, "-- Tenant value: {}", config.tenant_value)?;
1070    writeln!(writer, "-- Dialect: {}", config.dialect)?;
1071    writeln!(writer, "--")?;
1072    writeln!(writer, "-- Statistics:")?;
1073    writeln!(writer, "--   Tables processed: {}", stats.tables_processed)?;
1074    writeln!(writer, "--   Tables with data: {}", stats.tables_with_data)?;
1075    writeln!(writer, "--   Tables skipped: {}", stats.tables_skipped)?;
1076
1077    let percent = if stats.total_rows_seen > 0 {
1078        (stats.total_rows_selected as f64 / stats.total_rows_seen as f64) * 100.0
1079    } else {
1080        0.0
1081    };
1082    writeln!(
1083        writer,
1084        "--   Total rows: {} (from {} original, {:.1}%)",
1085        stats.total_rows_selected, stats.total_rows_seen, percent
1086    )?;
1087
1088    if stats.fk_orphans_skipped > 0 {
1089        writeln!(
1090            writer,
1091            "--   FK orphans skipped: {}",
1092            stats.fk_orphans_skipped
1093        )?;
1094    }
1095
1096    if !stats.warnings.is_empty() {
1097        writeln!(writer, "--   Warnings: {}", stats.warnings.len())?;
1098    }
1099
1100    writeln!(writer)?;
1101
1102    Ok(())
1103}
1104
1105/// Write dialect-specific header
1106fn write_dialect_header<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1107    match dialect {
1108        SqlDialect::MySql => {
1109            writeln!(writer, "SET NAMES utf8mb4;")?;
1110            writeln!(writer, "SET FOREIGN_KEY_CHECKS = 0;")?;
1111        }
1112        SqlDialect::Postgres => {
1113            writeln!(writer, "SET client_encoding = 'UTF8';")?;
1114            writeln!(writer, "SET session_replication_role = replica;")?;
1115        }
1116        SqlDialect::Sqlite => {
1117            writeln!(writer, "PRAGMA foreign_keys = OFF;")?;
1118        }
1119        SqlDialect::Mssql => {
1120            writeln!(writer, "SET ANSI_NULLS ON;")?;
1121            writeln!(writer, "SET QUOTED_IDENTIFIER ON;")?;
1122            writeln!(writer, "SET NOCOUNT ON;")?;
1123        }
1124    }
1125    writeln!(writer)?;
1126    Ok(())
1127}
1128
1129/// Write dialect-specific footer
1130fn write_dialect_footer<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
1131    writeln!(writer)?;
1132    match dialect {
1133        SqlDialect::MySql => {
1134            writeln!(writer, "SET FOREIGN_KEY_CHECKS = 1;")?;
1135        }
1136        SqlDialect::Postgres => {
1137            writeln!(writer, "SET session_replication_role = DEFAULT;")?;
1138        }
1139        SqlDialect::Sqlite => {
1140            writeln!(writer, "PRAGMA foreign_keys = ON;")?;
1141        }
1142        SqlDialect::Mssql => {
1143            // No footer needed
1144        }
1145    }
1146    Ok(())
1147}
1148
/// Convert a MySQL-style row to PostgreSQL syntax
///
/// Rewrites every backslash-escaped single quote (`\'`) as the SQL-standard
/// doubled quote (`''`). All other bytes are copied through unchanged.
///
/// Backslash pairs are consumed as a unit: in `\\'` the quote belongs to the
/// surrounding syntax (it follows an escaped backslash), so it is left alone
/// instead of being mistaken for an escaped quote, which would otherwise
/// leave the string literal unterminated.
fn convert_row_to_postgres(row: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(row.len());
    let mut i = 0;

    while i < row.len() {
        match (row[i], row.get(i + 1).copied()) {
            // Escaped quote -> doubled quote.
            (b'\\', Some(b'\'')) => {
                result.extend_from_slice(b"''");
                i += 2;
            }
            // Escaped backslash: copy the pair verbatim and step over both
            // bytes so the second one cannot start a new escape.
            (b'\\', Some(b'\\')) => {
                result.extend_from_slice(b"\\\\");
                i += 2;
            }
            (byte, _) => {
                result.push(byte);
                i += 1;
            }
        }
    }

    result
}
1167
1168/// Convert PostgreSQL COPY format to INSERT VALUES format
1169fn convert_copy_to_insert_values(row: &[u8], dialect: SqlDialect) -> Vec<u8> {
1170    let mut result = Vec::with_capacity(row.len() + 20);
1171    result.push(b'(');
1172
1173    let fields: Vec<&[u8]> = row.split(|&b| b == b'\t').collect();
1174
1175    for (i, field) in fields.iter().enumerate() {
1176        if i > 0 {
1177            result.extend_from_slice(b", ");
1178        }
1179
1180        if *field == b"\\N" {
1181            result.extend_from_slice(b"NULL");
1182        } else if field.is_empty() {
1183            result.extend_from_slice(b"''");
1184        } else if is_numeric(field) {
1185            result.extend_from_slice(field);
1186        } else {
1187            result.push(b'\'');
1188            for &b in *field {
1189                match b {
1190                    b'\'' => match dialect {
1191                        SqlDialect::MySql => result.extend_from_slice(b"\\'"),
1192                        SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {
1193                            result.extend_from_slice(b"''")
1194                        }
1195                    },
1196                    b'\\' if dialect == SqlDialect::MySql => {
1197                        result.extend_from_slice(b"\\\\");
1198                    }
1199                    _ => result.push(b),
1200                }
1201            }
1202            result.push(b'\'');
1203        }
1204    }
1205
1206    result.push(b')');
1207    result
1208}
1209
/// Check if a byte slice represents a numeric value
///
/// Accepted form: an optional leading `+`/`-`, a mantissa of ASCII digits
/// with at most one `.` and at least one digit, optionally followed by `e`/`E`
/// and one or more digits.
///
/// The exponent marker is only valid once, after the mantissa, and must be
/// followed by digits, so strings such as `e5`, `1e` or `1ee5` are not treated
/// as numbers and get quoted by the caller. Signed exponents (`1e-5`) are not
/// accepted, so such values stay quoted.
fn is_numeric(s: &[u8]) -> bool {
    let unsigned = match s {
        [b'-' | b'+', rest @ ..] => rest,
        _ => s,
    };

    // Split at the first exponent marker; any later marker ends up in the
    // exponent part and fails the digits-only check below.
    let (mantissa, exponent) = match unsigned.iter().position(|&b| b == b'e' || b == b'E') {
        Some(at) => (&unsigned[..at], Some(&unsigned[at + 1..])),
        None => (unsigned, None),
    };

    let mut digits = 0usize;
    let mut dots = 0usize;
    for &b in mantissa {
        match b {
            b'0'..=b'9' => digits += 1,
            b'.' => dots += 1,
            _ => return false,
        }
    }
    if digits == 0 || dots > 1 {
        return false;
    }

    match exponent {
        Some(exp) => !exp.is_empty() && exp.iter().all(u8::is_ascii_digit),
        None => true,
    }
}