// sql_splitter/sample/mod.rs

1//! Sample command for creating reduced datasets from SQL dumps.
2//!
3//! The sample command creates reduced datasets while optionally preserving
4//! foreign key integrity through dependency-aware FK chain resolution.
5//!
6//! Supports MySQL, PostgreSQL, and SQLite dialects.
7
8mod config;
9mod reservoir;
10
11pub use config::{DefaultClassifier, GlobalTableMode, SampleYamlConfig, TableClassification};
12pub use reservoir::Reservoir;
13
14use crate::parser::mysql_insert::{parse_mysql_insert_rows, ParsedRow, PkSet};
15use crate::parser::postgres_copy::{parse_copy_columns, parse_postgres_copy_rows, ParsedCopyRow};
16use crate::parser::{ContentFilter, Parser, SqlDialect, StatementType};
17use crate::schema::{SchemaBuilder, SchemaGraph, TableId};
18use crate::splitter::Splitter;
19use ahash::AHashMap;
20use indicatif::{ProgressBar, ProgressStyle};
21use rand::rngs::StdRng;
22use rand::{Rng, SeedableRng};
23use std::fs::{self, File};
24use std::io::{BufWriter, Write};
25use std::path::{Path, PathBuf};
26use tempfile::TempDir;
27
/// How many rows to keep from each table.
#[derive(Debug, Clone, Copy)]
pub enum SampleMode {
    /// Keep roughly N percent of the rows (independent per-row decision).
    Percent(u32),
    /// Keep at most N rows (uniform reservoir sample).
    Rows(usize),
}
36
/// Configuration for the sample command.
///
/// Construct via [`SampleConfig::default`] and override the fields that
/// matter; [`run`] consumes the configuration.
#[derive(Debug)]
pub struct SampleConfig {
    /// Input SQL file
    pub input: PathBuf,
    /// Output SQL file (None for stdout)
    pub output: Option<PathBuf>,
    /// SQL dialect
    pub dialect: SqlDialect,
    /// Sampling mode applied to tables without a per-table YAML override
    pub mode: SampleMode,
    /// Preserve foreign key relationships: reject rows whose FK parents
    /// were not themselves selected
    pub preserve_relations: bool,
    /// Only sample these tables (None = all); matched case-insensitively
    pub tables_filter: Option<Vec<String>>,
    /// Exclude these tables (matched case-insensitively)
    pub exclude: Vec<String>,
    /// Root tables for sampling (start from these); matched case-insensitively
    pub root_tables: Vec<String>,
    /// How to handle global/lookup tables
    pub include_global: GlobalTableMode,
    /// Random seed for reproducibility (Default draws a random one)
    pub seed: u64,
    /// Dry run mode (collect and return stats only; nothing is written)
    pub dry_run: bool,
    /// Show progress (spinner plus stderr messages)
    pub progress: bool,
    /// YAML config file path with per-table overrides
    pub config_file: Option<PathBuf>,
    /// Maximum total rows to sample (explosion guard; sampling stops early
    /// once the candidate count would exceed it)
    pub max_total_rows: Option<usize>,
    /// Fail if any FK integrity issues detected (instead of silently
    /// dropping the orphan row)
    pub strict_fk: bool,
    /// Include schema statements in output
    pub include_schema: bool,
}
73
74impl Default for SampleConfig {
75    fn default() -> Self {
76        Self {
77            input: PathBuf::new(),
78            output: None,
79            dialect: SqlDialect::MySql,
80            mode: SampleMode::Percent(10),
81            preserve_relations: false,
82            tables_filter: None,
83            exclude: Vec::new(),
84            root_tables: Vec::new(),
85            include_global: GlobalTableMode::Lookups,
86            seed: rand::random(),
87            dry_run: false,
88            progress: false,
89            config_file: None,
90            max_total_rows: None,
91            strict_fk: false,
92            include_schema: true,
93        }
94    }
95}
96
/// Statistics from sample operation
#[derive(Debug, Default)]
pub struct SampleStats {
    /// Number of tables sampled
    pub tables_sampled: usize,
    /// Number of tables skipped (excluded, filtered out, system, or
    /// lookup tables under `GlobalTableMode::None`)
    pub tables_skipped: usize,
    /// Total rows selected across all sampled tables
    pub total_rows_selected: u64,
    /// Total rows seen across all sampled tables
    pub total_rows_seen: u64,
    /// Per-table statistics, in processing (dependency) order
    pub table_stats: Vec<TableSampleStats>,
    /// Warning messages (FK cycles, max_total_rows cutoff, ...)
    pub warnings: Vec<String>,
    /// FK orphan count (rows rejected due to missing parents)
    pub fk_orphans_rejected: u64,
}
115
/// Per-table sampling statistics
#[derive(Debug, Clone)]
pub struct TableSampleStats {
    /// Table name as found in the dump
    pub name: String,
    /// Rows encountered while scanning the table's statements
    pub rows_seen: u64,
    /// Rows that made it into the sample
    pub rows_selected: u64,
    /// How the table was classified (root, lookup, system, normal)
    pub classification: TableClassification,
}
124
/// Runtime state for a table during sampling.
///
/// One instance per table; populated as tables are processed in dependency
/// order so that child tables can consult their parents' `pk_set`.
struct TableRuntime {
    /// Table name
    name: String,
    /// Selected rows with format metadata
    selected_rows: Vec<SelectedRow>,
    /// Primary keys of the selected rows; child tables check their FK
    /// tuples against this set when `preserve_relations` is on
    pk_set: PkSet,
    /// Rows seen count
    rows_seen: u64,
    /// Whether to skip this table (exclude list, filter, YAML skip, or
    /// system classification)
    skip: bool,
    /// Table classification
    classification: TableClassification,
    /// FK orphans rejected for this table
    fk_orphans: u64,
}
142
/// Combined row representation for both MySQL INSERT and PostgreSQL COPY
enum UnifiedRow {
    /// Row parsed from a MySQL `INSERT ... VALUES` statement
    Insert(ParsedRow),
    /// Row parsed from a PostgreSQL `COPY` data block
    Copy(ParsedCopyRow),
}
148
/// Row format indicator for output
#[derive(Debug, Clone, Copy, PartialEq)]
enum RowFormat {
    /// Raw bytes came from an INSERT statement (VALUES tuple text)
    Insert,
    /// Raw bytes came from a COPY data block (tab-separated line)
    Copy,
}
155
/// Selected row with format metadata
struct SelectedRow {
    /// Raw bytes of the row as parsed from the dump
    raw: Vec<u8>,
    /// Which statement kind `raw` came from, so output can convert it
    /// to INSERT VALUES correctly
    format: RowFormat,
}
161
162impl UnifiedRow {
163    fn pk(&self) -> Option<&smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>> {
164        match self {
165            UnifiedRow::Insert(r) => r.pk.as_ref(),
166            UnifiedRow::Copy(r) => r.pk.as_ref(),
167        }
168    }
169
170    fn fk_values(
171        &self,
172    ) -> &[(
173        crate::parser::mysql_insert::FkRef,
174        smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>,
175    )] {
176        match self {
177            UnifiedRow::Insert(r) => &r.fk_values,
178            UnifiedRow::Copy(r) => &r.fk_values,
179        }
180    }
181
182    fn into_selected(self) -> SelectedRow {
183        match self {
184            UnifiedRow::Insert(r) => SelectedRow {
185                raw: r.raw,
186                format: RowFormat::Insert,
187            },
188            UnifiedRow::Copy(r) => SelectedRow {
189                raw: r.raw,
190                format: RowFormat::Copy,
191            },
192        }
193    }
194}
195
/// Run the sample command end-to-end.
///
/// Pipeline:
/// 1. Phase 0 — split the dump into per-table files inside a temp dir.
/// 2. Phase 1 — build a schema graph from the split files and compute a
///    dependency-aware processing order (cyclic tables appended last).
/// 3. Phase 2 — sample each table's rows in that order; with
///    `preserve_relations` a row is only kept when all of its FK parents
///    were selected earlier.
/// 4. Phase 3 — unless `dry_run`, synthesize the output SQL file.
///
/// Returns aggregate and per-table [`SampleStats`] in every mode.
///
/// # Errors
/// Propagates I/O, split, and parse errors; additionally bails out on the
/// first FK orphan when `config.strict_fk` is set.
pub fn run(config: SampleConfig) -> anyhow::Result<SampleStats> {
    // Load YAML config if provided
    let yaml_config = if let Some(ref path) = config.config_file {
        Some(SampleYamlConfig::load(path)?)
    } else {
        None
    };

    // Seeded RNG so runs are reproducible for a fixed `config.seed`.
    let mut rng = StdRng::seed_from_u64(config.seed);
    let mut stats = SampleStats::default();

    // Progress bar setup
    let progress_bar = if config.progress {
        let pb = ProgressBar::new_spinner();
        pb.set_style(
            ProgressStyle::default_spinner()
                .template("{spinner:.green} {msg}")
                .unwrap(),
        );
        Some(pb)
    } else {
        None
    };

    // Phase 0: Split into temp per-table files
    // (the TempDir guard must stay alive until we finish reading from it)
    let temp_dir = TempDir::new()?;
    let tables_dir = temp_dir.path().join("tables");

    if let Some(ref pb) = progress_bar {
        pb.set_message("Splitting dump into per-table files...");
    }

    let splitter = Splitter::new(config.input.clone(), tables_dir.clone())
        .with_dialect(config.dialect)
        .with_content_filter(ContentFilter::All);

    let split_stats = splitter.split()?;

    if let Some(ref pb) = progress_bar {
        pb.set_message(format!(
            "Split complete: {} tables, {} statements",
            split_stats.tables_found, split_stats.statements_processed
        ));
    }

    // Phase 1: Build schema graph
    if let Some(ref pb) = progress_bar {
        pb.set_message("Building schema graph...");
    }

    let graph = build_schema_graph(&tables_dir, &config)?;

    let (topo_order, cyclic_tables) = graph.processing_order();

    if !cyclic_tables.is_empty() {
        let names: Vec<_> = cyclic_tables
            .iter()
            .filter_map(|&id| graph.table_name(id))
            .collect();
        let msg = format!(
            "Warning: {} tables have FK cycles (intra-cycle FK enforcement disabled): {:?}",
            cyclic_tables.len(),
            names
        );
        if config.progress {
            eprintln!("{}", msg);
        }
        stats.warnings.push(msg);
    }

    // Build set of cyclic table IDs for quick lookup
    let cyclic_set: ahash::AHashSet<TableId> = cyclic_tables.iter().copied().collect();

    // Determine root tables (names compared case-insensitively)
    let explicit_roots: ahash::AHashSet<String> = config
        .root_tables
        .iter()
        .map(|s| s.to_lowercase())
        .collect();

    // Initialize table runtimes with classification
    let mut runtimes: AHashMap<TableId, TableRuntime> = AHashMap::new();
    for table in graph.schema.iter() {
        let classification =
            determine_classification(&table.name, &graph, table.id, &yaml_config, &explicit_roots);
        let skip = should_skip_table(&table.name, &config, &yaml_config, classification);

        runtimes.insert(
            table.id,
            TableRuntime {
                name: table.name.clone(),
                selected_rows: Vec::new(),
                pk_set: PkSet::default(),
                rows_seen: 0,
                skip,
                classification,
                fk_orphans: 0,
            },
        );
    }

    // Phase 2: Process tables in dependency order
    if let Some(ref pb) = progress_bar {
        pb.set_message(format!(
            "Sampling {} tables in dependency order...",
            topo_order.len()
        ));
    }

    // Process acyclic tables first, then cyclic tables
    let all_tables: Vec<TableId> = topo_order.into_iter().chain(cyclic_tables).collect();

    let mut total_selected: u64 = 0;

    for table_id in &all_tables {
        let table_schema = match graph.schema.table(*table_id) {
            Some(s) => s,
            None => continue,
        };

        // Check if we should skip this table
        // (scoped borrow so `runtimes` is free for the FK checks below)
        let (should_skip, table_name, classification) = {
            let runtime = match runtimes.get(table_id) {
                Some(r) => r,
                None => continue,
            };
            (runtime.skip, runtime.name.clone(), runtime.classification)
        };

        if should_skip {
            stats.tables_skipped += 1;
            continue;
        }

        // Handle lookup/global tables specially
        let sample_mode = match classification {
            TableClassification::Lookup => {
                match config.include_global {
                    GlobalTableMode::None => {
                        stats.tables_skipped += 1;
                        continue;
                    }
                    GlobalTableMode::Lookups | GlobalTableMode::All => {
                        // Include all rows
                        SampleMode::Percent(100)
                    }
                }
            }
            TableClassification::System => {
                stats.tables_skipped += 1;
                continue;
            }
            _ => get_table_sample_mode(&table_name, &config, &yaml_config),
        };

        let table_file = tables_dir.join(format!("{}.sql", table_name));
        if !table_file.exists() {
            continue;
        }

        // Parse statements from this table file
        let file = File::open(&table_file)?;
        let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);

        // Collect rows based on sampling mode
        let mut candidates: Vec<UnifiedRow> = Vec::new();
        let mut rows_seen = 0u64;
        let mut fk_orphans = 0u64;

        // For PostgreSQL COPY, track the current column order
        let mut copy_columns: Vec<String> = Vec::new();

        while let Some(stmt) = parser.read_statement()? {
            let (stmt_type, _) =
                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);

            match stmt_type {
                StatementType::Insert => {
                    let rows = parse_mysql_insert_rows(&stmt, table_schema)?;

                    for row in rows {
                        rows_seen += 1;
                        let unified = UnifiedRow::Insert(row);

                        // Drop (or fail on) rows whose FK parents were not kept.
                        if config.preserve_relations {
                            let (passes, orphan) = check_unified_fk_membership(
                                &unified,
                                table_schema,
                                &runtimes,
                                &cyclic_set,
                                table_id,
                            );
                            if !passes {
                                fk_orphans += 1;
                                if orphan && config.strict_fk {
                                    anyhow::bail!(
                                        "FK integrity violation in table '{}': row references missing parent",
                                        table_name
                                    );
                                }
                                continue;
                            }
                        }

                        candidates.push(unified);
                    }
                }
                StatementType::Copy => {
                    // Extract column order from COPY header
                    let header = String::from_utf8_lossy(&stmt);
                    copy_columns = parse_copy_columns(&header);
                }
                StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
                    // This might be COPY data (blocks end with the \. marker)
                    if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                        let rows =
                            parse_postgres_copy_rows(&stmt, table_schema, copy_columns.clone())?;

                        for row in rows {
                            rows_seen += 1;
                            let unified = UnifiedRow::Copy(row);

                            // Same FK-orphan handling as the INSERT branch above.
                            if config.preserve_relations {
                                let (passes, orphan) = check_unified_fk_membership(
                                    &unified,
                                    table_schema,
                                    &runtimes,
                                    &cyclic_set,
                                    table_id,
                                );
                                if !passes {
                                    fk_orphans += 1;
                                    if orphan && config.strict_fk {
                                        anyhow::bail!(
                                            "FK integrity violation in table '{}': row references missing parent",
                                            table_name
                                        );
                                    }
                                    continue;
                                }
                            }

                            candidates.push(unified);
                        }
                    }
                }
                _ => {}
            }
        }

        // Check max_total_rows guard
        // NOTE(review): the guard uses the pre-sampling candidate count, so
        // sampling can stop before the limit is actually reached — confirm
        // that this conservative cutoff is intended.
        if let Some(max) = config.max_total_rows {
            if total_selected + candidates.len() as u64 > max as u64 {
                let msg = format!(
                    "Warning: Reached max_total_rows limit ({}) at table '{}'",
                    max, table_name
                );
                stats.warnings.push(msg);
                break;
            }
        }

        // Apply sampling to candidates
        let selected = sample_rows(&candidates, sample_mode, &mut rng);

        // Update total count
        total_selected += selected.len() as u64;

        // Store selected rows and update PK set (children consult it later)
        let runtime = runtimes.get_mut(table_id).unwrap();
        runtime.rows_seen = rows_seen;
        runtime.fk_orphans = fk_orphans;

        for row in selected {
            if let Some(pk) = row.pk() {
                runtime.pk_set.insert(pk.clone());
            }
            runtime.selected_rows.push(row.into_selected());
        }

        stats.fk_orphans_rejected += fk_orphans;

        stats.table_stats.push(TableSampleStats {
            name: runtime.name.clone(),
            rows_seen: runtime.rows_seen,
            rows_selected: runtime.selected_rows.len() as u64,
            classification: runtime.classification,
        });
    }

    // Calculate totals
    for table_stats in &stats.table_stats {
        stats.total_rows_seen += table_stats.rows_seen;
        stats.total_rows_selected += table_stats.rows_selected;
    }
    stats.tables_sampled = stats.table_stats.len();

    if let Some(ref pb) = progress_bar {
        pb.finish_with_message("Sampling complete");
    }

    // Phase 3: Output synthesis
    if config.dry_run {
        return Ok(stats);
    }

    if config.progress {
        eprintln!("Writing output...");
    }

    write_output(&config, &graph, &all_tables, &runtimes, &tables_dir, &stats)?;

    Ok(stats)
}
511
512/// Build schema graph from split table files
513fn build_schema_graph(tables_dir: &Path, config: &SampleConfig) -> anyhow::Result<SchemaGraph> {
514    let mut builder = SchemaBuilder::new();
515
516    for entry in fs::read_dir(tables_dir)? {
517        let entry = entry?;
518        let path = entry.path();
519
520        if path.extension().map(|e| e == "sql").unwrap_or(false) {
521            let file = File::open(&path)?;
522            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
523
524            while let Some(stmt) = parser.read_statement()? {
525                let stmt_str = String::from_utf8_lossy(&stmt);
526                let (stmt_type, _) =
527                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
528
529                match stmt_type {
530                    StatementType::CreateTable => {
531                        builder.parse_create_table(&stmt_str);
532                    }
533                    StatementType::AlterTable => {
534                        builder.parse_alter_table(&stmt_str);
535                    }
536                    _ => {}
537                }
538            }
539        }
540    }
541
542    Ok(SchemaGraph::from_schema(builder.build()))
543}
544
545/// Determine table classification
546fn determine_classification(
547    name: &str,
548    graph: &SchemaGraph,
549    table_id: TableId,
550    yaml_config: &Option<SampleYamlConfig>,
551    explicit_roots: &ahash::AHashSet<String>,
552) -> TableClassification {
553    // Check explicit roots first
554    if explicit_roots.contains(&name.to_lowercase()) {
555        return TableClassification::Root;
556    }
557
558    // Check YAML config
559    if let Some(ref config) = yaml_config {
560        let class = config.get_classification(name);
561        if class != TableClassification::Normal {
562            return class;
563        }
564    }
565
566    // Check if it's a graph root (no parents)
567    if graph.parents[table_id.0 as usize].is_empty() {
568        return TableClassification::Root;
569    }
570
571    // Use default classifier
572    DefaultClassifier::classify(name)
573}
574
575/// Check if a table should be skipped
576fn should_skip_table(
577    name: &str,
578    config: &SampleConfig,
579    yaml_config: &Option<SampleYamlConfig>,
580    classification: TableClassification,
581) -> bool {
582    let name_lower = name.to_lowercase();
583
584    // Check exclude list
585    if config
586        .exclude
587        .iter()
588        .any(|e| e.to_lowercase() == name_lower)
589    {
590        return true;
591    }
592
593    // Check YAML skip
594    if let Some(ref yc) = yaml_config {
595        if yc.should_skip(name) {
596            return true;
597        }
598    }
599
600    // Check include filter
601    if let Some(ref filter) = config.tables_filter {
602        if !filter.iter().any(|f| f.to_lowercase() == name_lower) {
603            return true;
604        }
605    }
606
607    // Skip system tables by default
608    if classification == TableClassification::System {
609        return true;
610    }
611
612    false
613}
614
615/// Get sample mode for a specific table
616fn get_table_sample_mode(
617    name: &str,
618    config: &SampleConfig,
619    yaml_config: &Option<SampleYamlConfig>,
620) -> SampleMode {
621    // Check YAML config first
622    if let Some(ref yc) = yaml_config {
623        if let Some(rows) = yc.get_rows(name) {
624            return SampleMode::Rows(rows);
625        }
626        if let Some(percent) = yc.get_percent(name) {
627            return SampleMode::Percent(percent);
628        }
629    }
630
631    // Fall back to global config
632    config.mode
633}
634
635/// Check FK membership for a unified row (works with both INSERT and COPY rows)
636fn check_unified_fk_membership(
637    row: &UnifiedRow,
638    table_schema: &crate::schema::TableSchema,
639    runtimes: &AHashMap<TableId, TableRuntime>,
640    cyclic_set: &ahash::AHashSet<TableId>,
641    current_table_id: &TableId,
642) -> (bool, bool) {
643    let mut passes = true;
644    let mut is_orphan = false;
645
646    for (fk_ref, fk_tuple) in row.fk_values() {
647        if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
648            if let Some(parent_id) = fk.referenced_table_id {
649                // Skip FK check for cyclic tables
650                if cyclic_set.contains(&parent_id) && cyclic_set.contains(current_table_id) {
651                    continue;
652                }
653
654                // Check if parent row exists in parent's pk_set
655                if let Some(parent_runtime) = runtimes.get(&parent_id) {
656                    if !parent_runtime.pk_set.contains(fk_tuple) {
657                        passes = false;
658                        is_orphan = true;
659                        break;
660                    }
661                }
662            }
663        }
664    }
665
666    (passes, is_orphan)
667}
668
669/// Sample rows according to sampling mode
670fn sample_rows(candidates: &[UnifiedRow], mode: SampleMode, rng: &mut StdRng) -> Vec<UnifiedRow> {
671    match mode {
672        SampleMode::Percent(p) => {
673            // Bernoulli sampling
674            let prob = p as f64 / 100.0;
675            candidates
676                .iter()
677                .filter(|_| rng.gen_bool(prob.min(1.0)))
678                .map(|r| match r {
679                    UnifiedRow::Insert(row) => UnifiedRow::Insert(row.clone()),
680                    UnifiedRow::Copy(row) => UnifiedRow::Copy(row.clone()),
681                })
682                .collect()
683        }
684        SampleMode::Rows(n) => {
685            // Reservoir sampling
686            let mut reservoir = Reservoir::new(n, StdRng::from_rng(rng).unwrap());
687            for row in candidates {
688                let cloned = match row {
689                    UnifiedRow::Insert(r) => UnifiedRow::Insert(r.clone()),
690                    UnifiedRow::Copy(r) => UnifiedRow::Copy(r.clone()),
691                };
692                reservoir.consider(cloned);
693            }
694            reservoir.into_items()
695        }
696    }
697}
698
/// Write the sampled dataset as a single SQL script.
///
/// Output layout: comment banner, dialect prelude, optionally each table's
/// schema statements (copied verbatim from the split files), then each
/// table's selected rows as chunked multi-row `INSERT` statements in
/// dependency order, and finally the dialect footer.
///
/// Tables that were skipped or ended with zero selected rows are omitted
/// entirely (schema and data).
fn write_output(
    config: &SampleConfig,
    _graph: &SchemaGraph,
    table_order: &[TableId],
    runtimes: &AHashMap<TableId, TableRuntime>,
    tables_dir: &Path,
    stats: &SampleStats,
) -> anyhow::Result<()> {
    // Either a buffered file (creating parent dirs as needed) or stdout.
    let mut writer: Box<dyn Write> = match &config.output {
        Some(path) => {
            if let Some(parent) = path.parent() {
                fs::create_dir_all(parent)?;
            }
            Box::new(BufWriter::with_capacity(256 * 1024, File::create(path)?))
        }
        None => Box::new(BufWriter::new(std::io::stdout())),
    };

    // Write header comment
    write_header(&mut writer, config, stats)?;

    // Write dialect-specific header
    write_dialect_header(&mut writer, config.dialect)?;

    // Write schema for each table (if enabled)
    if config.include_schema {
        for &table_id in table_order {
            let runtime = match runtimes.get(&table_id) {
                Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
                _ => continue,
            };

            let table_file = tables_dir.join(format!("{}.sql", runtime.name));
            if !table_file.exists() {
                continue;
            }

            // Copy schema statements from table file
            // (re-parses the split file and keeps schema statements only)
            let file = File::open(&table_file)?;
            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);

            while let Some(stmt) = parser.read_statement()? {
                let (stmt_type, _) =
                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);

                if stmt_type.is_schema() {
                    writer.write_all(&stmt)?;
                    writer.write_all(b"\n")?;
                }
            }
        }
    }

    // Write data for each table
    for &table_id in table_order {
        let runtime = match runtimes.get(&table_id) {
            Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
            _ => continue,
        };

        let table_name = &runtime.name;
        let row_count = runtime.selected_rows.len();

        writeln!(writer, "\n-- Data: {} ({} rows)", table_name, row_count)?;

        // Write INSERTs in chunks (compact multi-row format)
        const CHUNK_SIZE: usize = 1000;

        // Get the table name quoting based on dialect
        let quoted_name = match config.dialect {
            SqlDialect::MySql => format!("`{}`", table_name),
            SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", table_name),
        };

        // NOTE(review): rows are emitted without an explicit column list,
        // so the raw tuples must cover every column in table order —
        // confirm the parsers guarantee this.
        for chunk in runtime.selected_rows.chunks(CHUNK_SIZE) {
            writeln!(writer, "INSERT INTO {} VALUES", quoted_name)?;

            for (i, row) in chunk.iter().enumerate() {
                if i > 0 {
                    writer.write_all(b",\n")?;
                }

                // Convert row to INSERT VALUES format based on original format
                let values = match row.format {
                    RowFormat::Insert => {
                        // Already in INSERT format, but may need dialect conversion
                        match config.dialect {
                            SqlDialect::Postgres => convert_row_to_postgres(&row.raw),
                            _ => row.raw.clone(),
                        }
                    }
                    RowFormat::Copy => {
                        // Convert COPY format to INSERT VALUES
                        convert_copy_to_insert_values(&row.raw, config.dialect)
                    }
                };
                writer.write_all(&values)?;
            }

            writer.write_all(b";\n")?;
        }
    }

    // Write dialect-specific footer
    write_dialect_footer(&mut writer, config.dialect)?;

    writer.flush()?;

    Ok(())
}
810
811/// Write header comment
812fn write_header<W: Write>(
813    writer: &mut W,
814    config: &SampleConfig,
815    stats: &SampleStats,
816) -> std::io::Result<()> {
817    writeln!(writer, "-- Sampled from: {}", config.input.display())?;
818    writeln!(
819        writer,
820        "-- Date: {}",
821        chrono::Local::now().format("%Y-%m-%d %H:%M:%S")
822    )?;
823    writeln!(
824        writer,
825        "-- Mode: {:?}{}",
826        config.mode,
827        if config.preserve_relations {
828            ", preserve-relations"
829        } else {
830            ""
831        }
832    )?;
833    writeln!(writer, "-- Seed: {}", config.seed)?;
834    writeln!(writer, "-- Dialect: {}", config.dialect)?;
835    writeln!(writer, "--")?;
836    writeln!(writer, "-- Statistics:")?;
837    writeln!(writer, "--   Tables sampled: {}", stats.tables_sampled)?;
838    writeln!(writer, "--   Tables skipped: {}", stats.tables_skipped)?;
839
840    let percent = if stats.total_rows_seen > 0 {
841        (stats.total_rows_selected as f64 / stats.total_rows_seen as f64) * 100.0
842    } else {
843        0.0
844    };
845    writeln!(
846        writer,
847        "--   Total rows: {} (from {} original, {:.1}%)",
848        stats.total_rows_selected, stats.total_rows_seen, percent
849    )?;
850
851    if stats.fk_orphans_rejected > 0 {
852        writeln!(
853            writer,
854            "--   FK orphans rejected: {}",
855            stats.fk_orphans_rejected
856        )?;
857    }
858
859    if !stats.warnings.is_empty() {
860        writeln!(writer, "--   Warnings: {}", stats.warnings.len())?;
861    }
862
863    writeln!(writer)?;
864
865    Ok(())
866}
867
868/// Write dialect-specific header
869fn write_dialect_header<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
870    match dialect {
871        SqlDialect::MySql => {
872            writeln!(writer, "SET NAMES utf8mb4;")?;
873            writeln!(writer, "SET FOREIGN_KEY_CHECKS = 0;")?;
874        }
875        SqlDialect::Postgres => {
876            writeln!(writer, "SET client_encoding = 'UTF8';")?;
877            writeln!(writer, "SET session_replication_role = replica;")?;
878        }
879        SqlDialect::Sqlite => {
880            writeln!(writer, "PRAGMA foreign_keys = OFF;")?;
881        }
882    }
883    writeln!(writer)?;
884    Ok(())
885}
886
887/// Write dialect-specific footer
888fn write_dialect_footer<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
889    writeln!(writer)?;
890    match dialect {
891        SqlDialect::MySql => {
892            writeln!(writer, "SET FOREIGN_KEY_CHECKS = 1;")?;
893        }
894        SqlDialect::Postgres => {
895            writeln!(writer, "SET session_replication_role = DEFAULT;")?;
896        }
897        SqlDialect::Sqlite => {
898            writeln!(writer, "PRAGMA foreign_keys = ON;")?;
899        }
900    }
901    Ok(())
902}
903
/// Convert a MySQL-style row tuple to PostgreSQL quoting.
///
/// MySQL escapes a single quote inside a string literal as `\'`; PostgreSQL
/// expects `''`. Escaped backslashes (`\\`) are copied through as a pair so
/// that a quote following one — which terminates the MySQL literal — is not
/// mistaken for an escaped quote (the previous version turned `\\'` into
/// `\''`, corrupting the quoting).
///
/// NOTE(review): other MySQL escapes (`\n`, `\t`, `\"`, ...) are passed
/// through verbatim; a full converter would translate those too.
fn convert_row_to_postgres(row: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(row.len());
    let mut i = 0;

    while i < row.len() {
        if row[i] == b'\\' && i + 1 < row.len() {
            match row[i + 1] {
                // MySQL `\'` -> PostgreSQL `''`
                b'\'' => {
                    result.extend_from_slice(b"''");
                    i += 2;
                }
                // `\\` is an escaped backslash: copy the pair verbatim so
                // the following byte is not treated as escaped.
                b'\\' => {
                    result.extend_from_slice(b"\\\\");
                    i += 2;
                }
                _ => {
                    result.push(row[i]);
                    i += 1;
                }
            }
        } else {
            result.push(row[i]);
            i += 1;
        }
    }

    result
}
925
926/// Convert PostgreSQL COPY format (tab-separated) to INSERT VALUES format
927fn convert_copy_to_insert_values(row: &[u8], dialect: SqlDialect) -> Vec<u8> {
928    let mut result = Vec::with_capacity(row.len() + 20);
929    result.push(b'(');
930
931    let fields: Vec<&[u8]> = row.split(|&b| b == b'\t').collect();
932
933    for (i, field) in fields.iter().enumerate() {
934        if i > 0 {
935            result.extend_from_slice(b", ");
936        }
937
938        // Check for NULL marker
939        if *field == b"\\N" {
940            result.extend_from_slice(b"NULL");
941        } else if field.is_empty() {
942            // Empty string
943            match dialect {
944                SqlDialect::MySql => result.extend_from_slice(b"''"),
945                SqlDialect::Postgres | SqlDialect::Sqlite => result.extend_from_slice(b"''"),
946            }
947        } else if is_numeric(field) {
948            // Numeric value - no quotes needed
949            result.extend_from_slice(field);
950        } else {
951            // String value - needs quoting
952            result.push(b'\'');
953            for &b in *field {
954                match b {
955                    b'\'' => {
956                        // Escape single quote
957                        match dialect {
958                            SqlDialect::MySql => result.extend_from_slice(b"\\'"),
959                            SqlDialect::Postgres | SqlDialect::Sqlite => {
960                                result.extend_from_slice(b"''")
961                            }
962                        }
963                    }
964                    b'\\' if dialect == SqlDialect::MySql => {
965                        // Escape backslash in MySQL
966                        result.extend_from_slice(b"\\\\");
967                    }
968                    _ => result.push(b),
969                }
970            }
971            result.push(b'\'');
972        }
973    }
974
975    result.push(b')');
976    result
977}
978
/// Check whether a byte slice looks like a SQL numeric literal.
///
/// Accepts an optional leading sign, a mantissa of digits with at most one
/// decimal point (`123`, `-4.5`, `.5`, `123.`), and an optional exponent
/// with its own optional sign (`1e9`, `2.5E-3`). The mantissa must contain
/// at least one digit, and so must the exponent when present.
///
/// The previous implementation accepted any mix of digits and dots once an
/// `e` appeared (e.g. `e5`, `1e`, `1e2.3`), which would have been emitted
/// unquoted by `convert_copy_to_insert_values` and produced invalid SQL; it
/// also rejected valid signed exponents like `1e-5`.
fn is_numeric(s: &[u8]) -> bool {
    let mut i = 0;

    // Optional leading sign.
    if !s.is_empty() && (s[0] == b'+' || s[0] == b'-') {
        i = 1;
    }

    // Mantissa: digits with at most one '.'.
    let mut mantissa_digit = false;
    let mut seen_dot = false;
    while i < s.len() {
        match s[i] {
            b'0'..=b'9' => mantissa_digit = true,
            b'.' if !seen_dot => seen_dot = true,
            _ => break,
        }
        i += 1;
    }
    if !mantissa_digit {
        return false;
    }
    if i == s.len() {
        return true;
    }

    // Optional exponent: [eE][+-]?digits, at least one digit required.
    if s[i] != b'e' && s[i] != b'E' {
        return false;
    }
    i += 1;
    if i < s.len() && (s[i] == b'+' || s[i] == b'-') {
        i += 1;
    }
    if i == s.len() {
        return false;
    }
    s[i..].iter().all(|b| b.is_ascii_digit())
}