1mod config;
9mod reservoir;
10
11pub use config::{DefaultClassifier, GlobalTableMode, SampleYamlConfig, TableClassification};
12pub use reservoir::Reservoir;
13
14use crate::parser::mysql_insert::{parse_mysql_insert_rows, ParsedRow, PkSet};
15use crate::parser::postgres_copy::{parse_copy_columns, parse_postgres_copy_rows, ParsedCopyRow};
16use crate::parser::{ContentFilter, Parser, SqlDialect, StatementType};
17use crate::schema::{SchemaBuilder, SchemaGraph, TableId};
18use crate::splitter::Splitter;
19use ahash::AHashMap;
20use indicatif::{ProgressBar, ProgressStyle};
21use rand::rngs::StdRng;
22use rand::{Rng, SeedableRng};
23use std::fs::{self, File};
24use std::io::{BufWriter, Write};
25use std::path::{Path, PathBuf};
26use tempfile::TempDir;
27
/// How many rows to keep from each table.
#[derive(Debug, Clone, Copy)]
pub enum SampleMode {
    /// Keep roughly this percentage of rows (Bernoulli sampling).
    Percent(u32),
    /// Keep at most this many rows (reservoir sampling).
    Rows(usize),
}
36
/// Runtime options for a sampling run (typically populated from the CLI).
#[derive(Debug)]
pub struct SampleConfig {
    /// Path to the input SQL dump.
    pub input: PathBuf,
    /// Output path; `None` writes to stdout.
    pub output: Option<PathBuf>,
    /// SQL dialect of the dump (MySQL / Postgres / SQLite).
    pub dialect: SqlDialect,
    /// Global sampling mode (may be overridden per table via YAML config).
    pub mode: SampleMode,
    /// When true, drop rows whose FK parents were not selected.
    pub preserve_relations: bool,
    /// If set, only these tables are sampled (case-insensitive).
    pub tables_filter: Option<Vec<String>>,
    /// Tables to skip entirely (case-insensitive).
    pub exclude: Vec<String>,
    /// Tables forced to the `Root` classification.
    pub root_tables: Vec<String>,
    /// Policy for lookup/global tables (skip, keep lookups, keep all).
    pub include_global: GlobalTableMode,
    /// RNG seed; same seed + same input => same sample.
    pub seed: u64,
    /// When true, compute statistics but write no output.
    pub dry_run: bool,
    /// When true, show a progress spinner and log to stderr.
    pub progress: bool,
    /// Optional YAML config with per-table overrides.
    pub config_file: Option<PathBuf>,
    /// Hard cap on total selected rows across all tables.
    pub max_total_rows: Option<usize>,
    /// When true, error out on FK orphans instead of dropping them.
    pub strict_fk: bool,
    /// When true, copy schema statements (CREATE/ALTER) into the output.
    pub include_schema: bool,
}
73
impl Default for SampleConfig {
    /// Defaults: MySQL dialect, 10% sampling, lookups included, schema
    /// emitted, everything else off/empty.
    fn default() -> Self {
        Self {
            input: PathBuf::new(),
            output: None,
            dialect: SqlDialect::MySql,
            mode: SampleMode::Percent(10),
            preserve_relations: false,
            tables_filter: None,
            exclude: Vec::new(),
            root_tables: Vec::new(),
            include_global: GlobalTableMode::Lookups,
            // NOTE: random by default, so two default-config runs produce
            // different samples; set an explicit seed for reproducibility.
            seed: rand::random(),
            dry_run: false,
            progress: false,
            config_file: None,
            max_total_rows: None,
            strict_fk: false,
            include_schema: true,
        }
    }
}
96
/// Aggregate results of a sampling run, also used for the dry-run report.
#[derive(Debug, Default)]
pub struct SampleStats {
    /// Number of tables that produced a per-table stats entry.
    pub tables_sampled: usize,
    /// Number of tables skipped (excluded, filtered, or system/lookup policy).
    pub tables_skipped: usize,
    /// Rows selected across all sampled tables.
    pub total_rows_selected: u64,
    /// Rows scanned across all sampled tables.
    pub total_rows_seen: u64,
    /// Per-table breakdown, in processing order.
    pub table_stats: Vec<TableSampleStats>,
    /// Human-readable warnings (FK cycles, row-cap hit, ...).
    pub warnings: Vec<String>,
    /// Rows dropped because a referenced parent row was not selected.
    pub fk_orphans_rejected: u64,
}
115
/// Per-table sampling summary.
#[derive(Debug, Clone)]
pub struct TableSampleStats {
    pub name: String,
    /// Rows scanned in this table's dump file.
    pub rows_seen: u64,
    /// Rows that made it into the sample.
    pub rows_selected: u64,
    /// How the table was classified (root / lookup / system / normal).
    pub classification: TableClassification,
}
124
/// Mutable per-table state accumulated while sampling in dependency order.
struct TableRuntime {
    name: String,
    /// Rows chosen for output (raw bytes plus their source format).
    selected_rows: Vec<SelectedRow>,
    /// PKs of selected rows; child tables consult this for FK membership.
    pk_set: PkSet,
    rows_seen: u64,
    /// True when the table is excluded/filtered out up front.
    skip: bool,
    classification: TableClassification,
    /// Rows rejected in this table due to missing FK parents.
    fk_orphans: u64,
}
142
/// A parsed data row from either source format (MySQL INSERT or Postgres
/// COPY), so sampling logic can treat both uniformly.
enum UnifiedRow {
    Insert(ParsedRow),
    Copy(ParsedCopyRow),
}
148
/// Source format of a selected row's raw bytes; decides how the bytes are
/// converted when the output is written.
// `Eq` added alongside `PartialEq`: the equality here is total (no
// float-like partial cases), and deriving it unlocks use in hash maps /
// exhaustive equality reasoning (clippy: derive_partial_eq_without_eq).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RowFormat {
    /// Row came from a MySQL-style `INSERT ... VALUES (...)` tuple.
    Insert,
    /// Row came from a Postgres `COPY` data line (tab-separated).
    Copy,
}
155
/// A sampled row reduced to its raw bytes plus the format tag needed to
/// re-emit it (parse metadata is dropped to save memory).
struct SelectedRow {
    raw: Vec<u8>,
    format: RowFormat,
}
161
162impl UnifiedRow {
163 fn pk(&self) -> Option<&smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>> {
164 match self {
165 UnifiedRow::Insert(r) => r.pk.as_ref(),
166 UnifiedRow::Copy(r) => r.pk.as_ref(),
167 }
168 }
169
170 fn fk_values(
171 &self,
172 ) -> &[(
173 crate::parser::mysql_insert::FkRef,
174 smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>,
175 )] {
176 match self {
177 UnifiedRow::Insert(r) => &r.fk_values,
178 UnifiedRow::Copy(r) => &r.fk_values,
179 }
180 }
181
182 fn into_selected(self) -> SelectedRow {
183 match self {
184 UnifiedRow::Insert(r) => SelectedRow {
185 raw: r.raw,
186 format: RowFormat::Insert,
187 },
188 UnifiedRow::Copy(r) => SelectedRow {
189 raw: r.raw,
190 format: RowFormat::Copy,
191 },
192 }
193 }
194}
195
/// Entry point for a sampling run.
///
/// Pipeline: (1) split the dump into per-table files in a temp dir,
/// (2) build the FK schema graph, (3) sample each table in dependency
/// order so parent PK sets exist before children are FK-filtered,
/// (4) write the sampled dump unless `dry_run` is set.
///
/// # Errors
/// Fails on I/O or parse errors, and on FK orphans when `strict_fk` is set.
pub fn run(config: SampleConfig) -> anyhow::Result<SampleStats> {
    // Optional YAML file carrying per-table overrides (skip/rows/percent/class).
    let yaml_config = if let Some(ref path) = config.config_file {
        Some(SampleYamlConfig::load(path)?)
    } else {
        None
    };

    // Seeded RNG: identical seed + input yields an identical sample.
    let mut rng = StdRng::seed_from_u64(config.seed);
    let mut stats = SampleStats::default();

    let progress_bar = if config.progress {
        let pb = ProgressBar::new_spinner();
        pb.set_style(
            ProgressStyle::default_spinner()
                .template("{spinner:.green} {msg}")
                .unwrap(),
        );
        Some(pb)
    } else {
        None
    };

    // Per-table files live in a temp dir that is cleaned up on drop.
    let temp_dir = TempDir::new()?;
    let tables_dir = temp_dir.path().join("tables");

    if let Some(ref pb) = progress_bar {
        pb.set_message("Splitting dump into per-table files...");
    }

    // Phase 1: split the monolithic dump into one .sql file per table.
    let splitter = Splitter::new(config.input.clone(), tables_dir.clone())
        .with_dialect(config.dialect)
        .with_content_filter(ContentFilter::All);

    let split_stats = splitter.split()?;

    if let Some(ref pb) = progress_bar {
        pb.set_message(format!(
            "Split complete: {} tables, {} statements",
            split_stats.tables_found, split_stats.statements_processed
        ));
    }

    if let Some(ref pb) = progress_bar {
        pb.set_message("Building schema graph...");
    }

    // Phase 2: parse CREATE/ALTER TABLE statements into an FK graph.
    let graph = build_schema_graph(&tables_dir, &config)?;

    // Topological order for acyclic tables; cyclic tables are returned apart.
    let (topo_order, cyclic_tables) = graph.processing_order();

    if !cyclic_tables.is_empty() {
        let names: Vec<_> = cyclic_tables
            .iter()
            .filter_map(|&id| graph.table_name(id))
            .collect();
        let msg = format!(
            "Warning: {} tables have FK cycles (intra-cycle FK enforcement disabled): {:?}",
            cyclic_tables.len(),
            names
        );
        if config.progress {
            eprintln!("{}", msg);
        }
        stats.warnings.push(msg);
    }

    // Fast membership test used to relax FK checks within a cycle.
    let cyclic_set: ahash::AHashSet<TableId> = cyclic_tables.iter().copied().collect();

    // Tables the user explicitly declared as roots (case-insensitive).
    let explicit_roots: ahash::AHashSet<String> = config
        .root_tables
        .iter()
        .map(|s| s.to_lowercase())
        .collect();

    // Pre-create runtime state (classification, skip flag, empty PK set)
    // for every table in the schema.
    let mut runtimes: AHashMap<TableId, TableRuntime> = AHashMap::new();
    for table in graph.schema.iter() {
        let classification =
            determine_classification(&table.name, &graph, table.id, &yaml_config, &explicit_roots);
        let skip = should_skip_table(&table.name, &config, &yaml_config, classification);

        runtimes.insert(
            table.id,
            TableRuntime {
                name: table.name.clone(),
                selected_rows: Vec::new(),
                pk_set: PkSet::default(),
                rows_seen: 0,
                skip,
                classification,
                fk_orphans: 0,
            },
        );
    }

    if let Some(ref pb) = progress_bar {
        pb.set_message(format!(
            "Sampling {} tables in dependency order...",
            topo_order.len()
        ));
    }

    // Cyclic tables are processed last, after all acyclic dependencies.
    let all_tables: Vec<TableId> = topo_order.into_iter().chain(cyclic_tables).collect();

    let mut total_selected: u64 = 0;

    // Phase 3: sample each table. Parents come first, so their pk_set is
    // complete when a child row's FKs are checked.
    for table_id in &all_tables {
        let table_schema = match graph.schema.table(*table_id) {
            Some(s) => s,
            None => continue,
        };

        // Copy out what we need so `runtimes` isn't borrowed across the
        // FK-membership checks below (which read other tables' runtimes).
        let (should_skip, table_name, classification) = {
            let runtime = match runtimes.get(table_id) {
                Some(r) => r,
                None => continue,
            };
            (runtime.skip, runtime.name.clone(), runtime.classification)
        };

        if should_skip {
            stats.tables_skipped += 1;
            continue;
        }

        // Lookup tables are copied whole (or skipped); system tables are
        // always skipped; everything else uses the configured mode.
        let sample_mode = match classification {
            TableClassification::Lookup => {
                match config.include_global {
                    GlobalTableMode::None => {
                        stats.tables_skipped += 1;
                        continue;
                    }
                    GlobalTableMode::Lookups | GlobalTableMode::All => {
                        SampleMode::Percent(100)
                    }
                }
            }
            TableClassification::System => {
                stats.tables_skipped += 1;
                continue;
            }
            _ => get_table_sample_mode(&table_name, &config, &yaml_config),
        };

        // Tables present in the schema but without a data file are skipped.
        let table_file = tables_dir.join(format!("{}.sql", table_name));
        if !table_file.exists() {
            continue;
        }

        let file = File::open(&table_file)?;
        let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);

        // Rows that survive FK filtering and are eligible for sampling.
        let mut candidates: Vec<UnifiedRow> = Vec::new();
        let mut rows_seen = 0u64;
        let mut fk_orphans = 0u64;

        // Column list from the most recent COPY header (Postgres only).
        let mut copy_columns: Vec<String> = Vec::new();

        while let Some(stmt) = parser.read_statement()? {
            let (stmt_type, _) =
                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);

            match stmt_type {
                StatementType::Insert => {
                    let rows = parse_mysql_insert_rows(&stmt, table_schema)?;

                    for row in rows {
                        rows_seen += 1;
                        let unified = UnifiedRow::Insert(row);

                        // FK pre-filter: a row referencing an unselected
                        // parent is dropped (or fails the run in strict mode).
                        if config.preserve_relations {
                            let (passes, orphan) = check_unified_fk_membership(
                                &unified,
                                table_schema,
                                &runtimes,
                                &cyclic_set,
                                table_id,
                            );
                            if !passes {
                                fk_orphans += 1;
                                if orphan && config.strict_fk {
                                    anyhow::bail!(
                                        "FK integrity violation in table '{}': row references missing parent",
                                        table_name
                                    );
                                }
                                continue;
                            }
                        }

                        candidates.push(unified);
                    }
                }
                StatementType::Copy => {
                    // COPY header: remember the column order for the data
                    // block that follows.
                    let header = String::from_utf8_lossy(&stmt);
                    copy_columns = parse_copy_columns(&header);
                }
                StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
                    // COPY data blocks are terminated by a `\.` line.
                    if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
                        let rows =
                            parse_postgres_copy_rows(&stmt, table_schema, copy_columns.clone())?;

                        for row in rows {
                            rows_seen += 1;
                            let unified = UnifiedRow::Copy(row);

                            // Same FK pre-filter as the INSERT path above.
                            if config.preserve_relations {
                                let (passes, orphan) = check_unified_fk_membership(
                                    &unified,
                                    table_schema,
                                    &runtimes,
                                    &cyclic_set,
                                    table_id,
                                );
                                if !passes {
                                    fk_orphans += 1;
                                    if orphan && config.strict_fk {
                                        anyhow::bail!(
                                            "FK integrity violation in table '{}': row references missing parent",
                                            table_name
                                        );
                                    }
                                    continue;
                                }
                            }

                            candidates.push(unified);
                        }
                    }
                }
                _ => {}
            }
        }

        // Global row cap. NOTE(review): this compares against the candidate
        // count *before* sampling and aborts the whole table loop — tables
        // later in the order produce no rows at all once the cap would be
        // exceeded. Confirm this conservative behavior is intended.
        if let Some(max) = config.max_total_rows {
            if total_selected + candidates.len() as u64 > max as u64 {
                let msg = format!(
                    "Warning: Reached max_total_rows limit ({}) at table '{}'",
                    max, table_name
                );
                stats.warnings.push(msg);
                break;
            }
        }

        let selected = sample_rows(&candidates, sample_mode, &mut rng);

        total_selected += selected.len() as u64;

        let runtime = runtimes.get_mut(table_id).unwrap();
        runtime.rows_seen = rows_seen;
        runtime.fk_orphans = fk_orphans;

        // Record selected rows and publish their PKs so child tables can
        // FK-check against them.
        for row in selected {
            if let Some(pk) = row.pk() {
                runtime.pk_set.insert(pk.clone());
            }
            runtime.selected_rows.push(row.into_selected());
        }

        stats.fk_orphans_rejected += fk_orphans;

        stats.table_stats.push(TableSampleStats {
            name: runtime.name.clone(),
            rows_seen: runtime.rows_seen,
            rows_selected: runtime.selected_rows.len() as u64,
            classification: runtime.classification,
        });
    }

    // Roll per-table numbers up into the global totals.
    for table_stats in &stats.table_stats {
        stats.total_rows_seen += table_stats.rows_seen;
        stats.total_rows_selected += table_stats.rows_selected;
    }
    stats.tables_sampled = stats.table_stats.len();

    if let Some(ref pb) = progress_bar {
        pb.finish_with_message("Sampling complete");
    }

    if config.dry_run {
        return Ok(stats);
    }

    if config.progress {
        eprintln!("Writing output...");
    }

    // Phase 4: emit header, schema (optional), data, and footer.
    write_output(&config, &graph, &all_tables, &runtimes, &tables_dir, &stats)?;

    Ok(stats)
}
511
512fn build_schema_graph(tables_dir: &Path, config: &SampleConfig) -> anyhow::Result<SchemaGraph> {
514 let mut builder = SchemaBuilder::new();
515
516 for entry in fs::read_dir(tables_dir)? {
517 let entry = entry?;
518 let path = entry.path();
519
520 if path.extension().map(|e| e == "sql").unwrap_or(false) {
521 let file = File::open(&path)?;
522 let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
523
524 while let Some(stmt) = parser.read_statement()? {
525 let stmt_str = String::from_utf8_lossy(&stmt);
526 let (stmt_type, _) =
527 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
528
529 match stmt_type {
530 StatementType::CreateTable => {
531 builder.parse_create_table(&stmt_str);
532 }
533 StatementType::AlterTable => {
534 builder.parse_alter_table(&stmt_str);
535 }
536 _ => {}
537 }
538 }
539 }
540 }
541
542 Ok(SchemaGraph::from_schema(builder.build()))
543}
544
545fn determine_classification(
547 name: &str,
548 graph: &SchemaGraph,
549 table_id: TableId,
550 yaml_config: &Option<SampleYamlConfig>,
551 explicit_roots: &ahash::AHashSet<String>,
552) -> TableClassification {
553 if explicit_roots.contains(&name.to_lowercase()) {
555 return TableClassification::Root;
556 }
557
558 if let Some(ref config) = yaml_config {
560 let class = config.get_classification(name);
561 if class != TableClassification::Normal {
562 return class;
563 }
564 }
565
566 if graph.parents[table_id.0 as usize].is_empty() {
568 return TableClassification::Root;
569 }
570
571 DefaultClassifier::classify(name)
573}
574
575fn should_skip_table(
577 name: &str,
578 config: &SampleConfig,
579 yaml_config: &Option<SampleYamlConfig>,
580 classification: TableClassification,
581) -> bool {
582 let name_lower = name.to_lowercase();
583
584 if config
586 .exclude
587 .iter()
588 .any(|e| e.to_lowercase() == name_lower)
589 {
590 return true;
591 }
592
593 if let Some(ref yc) = yaml_config {
595 if yc.should_skip(name) {
596 return true;
597 }
598 }
599
600 if let Some(ref filter) = config.tables_filter {
602 if !filter.iter().any(|f| f.to_lowercase() == name_lower) {
603 return true;
604 }
605 }
606
607 if classification == TableClassification::System {
609 return true;
610 }
611
612 false
613}
614
615fn get_table_sample_mode(
617 name: &str,
618 config: &SampleConfig,
619 yaml_config: &Option<SampleYamlConfig>,
620) -> SampleMode {
621 if let Some(ref yc) = yaml_config {
623 if let Some(rows) = yc.get_rows(name) {
624 return SampleMode::Rows(rows);
625 }
626 if let Some(percent) = yc.get_percent(name) {
627 return SampleMode::Percent(percent);
628 }
629 }
630
631 config.mode
633}
634
635fn check_unified_fk_membership(
637 row: &UnifiedRow,
638 table_schema: &crate::schema::TableSchema,
639 runtimes: &AHashMap<TableId, TableRuntime>,
640 cyclic_set: &ahash::AHashSet<TableId>,
641 current_table_id: &TableId,
642) -> (bool, bool) {
643 let mut passes = true;
644 let mut is_orphan = false;
645
646 for (fk_ref, fk_tuple) in row.fk_values() {
647 if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
648 if let Some(parent_id) = fk.referenced_table_id {
649 if cyclic_set.contains(&parent_id) && cyclic_set.contains(current_table_id) {
651 continue;
652 }
653
654 if let Some(parent_runtime) = runtimes.get(&parent_id) {
656 if !parent_runtime.pk_set.contains(fk_tuple) {
657 passes = false;
658 is_orphan = true;
659 break;
660 }
661 }
662 }
663 }
664 }
665
666 (passes, is_orphan)
667}
668
669fn sample_rows(candidates: &[UnifiedRow], mode: SampleMode, rng: &mut StdRng) -> Vec<UnifiedRow> {
671 match mode {
672 SampleMode::Percent(p) => {
673 let prob = p as f64 / 100.0;
675 candidates
676 .iter()
677 .filter(|_| rng.gen_bool(prob.min(1.0)))
678 .map(|r| match r {
679 UnifiedRow::Insert(row) => UnifiedRow::Insert(row.clone()),
680 UnifiedRow::Copy(row) => UnifiedRow::Copy(row.clone()),
681 })
682 .collect()
683 }
684 SampleMode::Rows(n) => {
685 let mut reservoir = Reservoir::new(n, StdRng::from_rng(rng).unwrap());
687 for row in candidates {
688 let cloned = match row {
689 UnifiedRow::Insert(r) => UnifiedRow::Insert(r.clone()),
690 UnifiedRow::Copy(r) => UnifiedRow::Copy(r.clone()),
691 };
692 reservoir.consider(cloned);
693 }
694 reservoir.into_items()
695 }
696 }
697}
698
/// Write the sampled dump: header comments, dialect preamble, optional
/// schema statements, batched INSERTs per table, and dialect footer.
/// Output goes to `config.output` or stdout.
fn write_output(
    config: &SampleConfig,
    _graph: &SchemaGraph,
    table_order: &[TableId],
    runtimes: &AHashMap<TableId, TableRuntime>,
    tables_dir: &Path,
    stats: &SampleStats,
) -> anyhow::Result<()> {
    // Large buffer for file output; plain BufWriter for stdout.
    let mut writer: Box<dyn Write> = match &config.output {
        Some(path) => {
            if let Some(parent) = path.parent() {
                fs::create_dir_all(parent)?;
            }
            Box::new(BufWriter::with_capacity(256 * 1024, File::create(path)?))
        }
        None => Box::new(BufWriter::new(std::io::stdout())),
    };

    write_header(&mut writer, config, stats)?;

    write_dialect_header(&mut writer, config.dialect)?;

    // Schema pass: copy CREATE/ALTER statements for every table that
    // actually contributes rows, in dependency order.
    if config.include_schema {
        for &table_id in table_order {
            let runtime = match runtimes.get(&table_id) {
                Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
                _ => continue,
            };

            let table_file = tables_dir.join(format!("{}.sql", runtime.name));
            if !table_file.exists() {
                continue;
            }

            // Re-read the per-table file to recover its schema statements.
            let file = File::open(&table_file)?;
            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);

            while let Some(stmt) = parser.read_statement()? {
                let (stmt_type, _) =
                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);

                if stmt_type.is_schema() {
                    writer.write_all(&stmt)?;
                    writer.write_all(b"\n")?;
                }
            }
        }
    }

    // Data pass: emit the selected rows as multi-row INSERT statements.
    for &table_id in table_order {
        let runtime = match runtimes.get(&table_id) {
            Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
            _ => continue,
        };

        let table_name = &runtime.name;
        let row_count = runtime.selected_rows.len();

        writeln!(writer, "\n-- Data: {} ({} rows)", table_name, row_count)?;

        // Rows per INSERT statement; keeps statements a manageable size.
        const CHUNK_SIZE: usize = 1000;

        // Identifier quoting differs per dialect.
        let quoted_name = match config.dialect {
            SqlDialect::MySql => format!("`{}`", table_name),
            SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", table_name),
        };

        for chunk in runtime.selected_rows.chunks(CHUNK_SIZE) {
            writeln!(writer, "INSERT INTO {} VALUES", quoted_name)?;

            for (i, row) in chunk.iter().enumerate() {
                if i > 0 {
                    writer.write_all(b",\n")?;
                }

                // Convert the raw bytes to the output dialect's quoting.
                let values = match row.format {
                    RowFormat::Insert => {
                        match config.dialect {
                            // MySQL-sourced tuples need requoting for Postgres.
                            SqlDialect::Postgres => convert_row_to_postgres(&row.raw),
                            _ => row.raw.clone(),
                        }
                    }
                    RowFormat::Copy => {
                        // COPY lines always need rewriting into VALUES tuples.
                        convert_copy_to_insert_values(&row.raw, config.dialect)
                    }
                };
                writer.write_all(&values)?;
            }

            writer.write_all(b";\n")?;
        }
    }

    write_dialect_footer(&mut writer, config.dialect)?;

    writer.flush()?;

    Ok(())
}
810
811fn write_header<W: Write>(
813 writer: &mut W,
814 config: &SampleConfig,
815 stats: &SampleStats,
816) -> std::io::Result<()> {
817 writeln!(writer, "-- Sampled from: {}", config.input.display())?;
818 writeln!(
819 writer,
820 "-- Date: {}",
821 chrono::Local::now().format("%Y-%m-%d %H:%M:%S")
822 )?;
823 writeln!(
824 writer,
825 "-- Mode: {:?}{}",
826 config.mode,
827 if config.preserve_relations {
828 ", preserve-relations"
829 } else {
830 ""
831 }
832 )?;
833 writeln!(writer, "-- Seed: {}", config.seed)?;
834 writeln!(writer, "-- Dialect: {}", config.dialect)?;
835 writeln!(writer, "--")?;
836 writeln!(writer, "-- Statistics:")?;
837 writeln!(writer, "-- Tables sampled: {}", stats.tables_sampled)?;
838 writeln!(writer, "-- Tables skipped: {}", stats.tables_skipped)?;
839
840 let percent = if stats.total_rows_seen > 0 {
841 (stats.total_rows_selected as f64 / stats.total_rows_seen as f64) * 100.0
842 } else {
843 0.0
844 };
845 writeln!(
846 writer,
847 "-- Total rows: {} (from {} original, {:.1}%)",
848 stats.total_rows_selected, stats.total_rows_seen, percent
849 )?;
850
851 if stats.fk_orphans_rejected > 0 {
852 writeln!(
853 writer,
854 "-- FK orphans rejected: {}",
855 stats.fk_orphans_rejected
856 )?;
857 }
858
859 if !stats.warnings.is_empty() {
860 writeln!(writer, "-- Warnings: {}", stats.warnings.len())?;
861 }
862
863 writeln!(writer)?;
864
865 Ok(())
866}
867
868fn write_dialect_header<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
870 match dialect {
871 SqlDialect::MySql => {
872 writeln!(writer, "SET NAMES utf8mb4;")?;
873 writeln!(writer, "SET FOREIGN_KEY_CHECKS = 0;")?;
874 }
875 SqlDialect::Postgres => {
876 writeln!(writer, "SET client_encoding = 'UTF8';")?;
877 writeln!(writer, "SET session_replication_role = replica;")?;
878 }
879 SqlDialect::Sqlite => {
880 writeln!(writer, "PRAGMA foreign_keys = OFF;")?;
881 }
882 }
883 writeln!(writer)?;
884 Ok(())
885}
886
887fn write_dialect_footer<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
889 writeln!(writer)?;
890 match dialect {
891 SqlDialect::MySql => {
892 writeln!(writer, "SET FOREIGN_KEY_CHECKS = 1;")?;
893 }
894 SqlDialect::Postgres => {
895 writeln!(writer, "SET session_replication_role = DEFAULT;")?;
896 }
897 SqlDialect::Sqlite => {
898 writeln!(writer, "PRAGMA foreign_keys = ON;")?;
899 }
900 }
901 Ok(())
902}
903
/// Convert a MySQL-quoted values tuple to PostgreSQL quoting.
///
/// MySQL dumps escape an embedded single quote as `\'`; PostgreSQL (with
/// standard_conforming_strings) expects `''`. Escaped backslashes (`\\`)
/// are consumed as a unit so that the byte after them is never mistaken
/// for the start of an escape — previously `\\'` (escaped backslash
/// followed by the *closing* quote) was rewritten to `\''`, corrupting
/// the statement.
///
/// NOTE(review): other MySQL escapes (`\n`, `\0`, ...) pass through
/// unchanged — confirm upstream rows never contain them, or extend this.
fn convert_row_to_postgres(row: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(row.len());
    let mut i = 0;

    while i < row.len() {
        if row[i] == b'\\' && i + 1 < row.len() {
            match row[i + 1] {
                // \' -> '' (escaped quote)
                b'\'' => {
                    result.push(b'\'');
                    result.push(b'\'');
                    i += 2;
                }
                // \\ -> copied verbatim as a unit; the following byte is
                // plain data, not the start of an escape sequence.
                b'\\' => {
                    result.push(b'\\');
                    result.push(b'\\');
                    i += 2;
                }
                // Any other escape passes through untouched.
                _ => {
                    result.push(row[i]);
                    i += 1;
                }
            }
        } else {
            result.push(row[i]);
            i += 1;
        }
    }

    result
}
925
926fn convert_copy_to_insert_values(row: &[u8], dialect: SqlDialect) -> Vec<u8> {
928 let mut result = Vec::with_capacity(row.len() + 20);
929 result.push(b'(');
930
931 let fields: Vec<&[u8]> = row.split(|&b| b == b'\t').collect();
932
933 for (i, field) in fields.iter().enumerate() {
934 if i > 0 {
935 result.extend_from_slice(b", ");
936 }
937
938 if *field == b"\\N" {
940 result.extend_from_slice(b"NULL");
941 } else if field.is_empty() {
942 match dialect {
944 SqlDialect::MySql => result.extend_from_slice(b"''"),
945 SqlDialect::Postgres | SqlDialect::Sqlite => result.extend_from_slice(b"''"),
946 }
947 } else if is_numeric(field) {
948 result.extend_from_slice(field);
950 } else {
951 result.push(b'\'');
953 for &b in *field {
954 match b {
955 b'\'' => {
956 match dialect {
958 SqlDialect::MySql => result.extend_from_slice(b"\\'"),
959 SqlDialect::Postgres | SqlDialect::Sqlite => {
960 result.extend_from_slice(b"''")
961 }
962 }
963 }
964 b'\\' if dialect == SqlDialect::MySql => {
965 result.extend_from_slice(b"\\\\");
967 }
968 _ => result.push(b),
969 }
970 }
971 result.push(b'\'');
972 }
973 }
974
975 result.push(b')');
976 result
977}
978
/// True when `s` is a well-formed numeric literal:
/// `[+-]? digits [ '.' digits* ]? | [+-]? '.' digits`, optionally followed
/// by an exponent `('e'|'E') [+-]? digits`.
///
/// The previous version skipped `e`/`E` unconditionally, so junk like
/// `"e5"`, `"1e"` or `"1e5e5"` was classified as numeric and emitted
/// unquoted into the output SQL; it also rejected valid signed exponents
/// (`"1e-5"`). This version validates the full grammar.
fn is_numeric(s: &[u8]) -> bool {
    let n = s.len();
    if n == 0 {
        return false;
    }

    let mut i = 0;

    // Optional leading sign.
    if s[i] == b'-' || s[i] == b'+' {
        i += 1;
    }

    // Integer part.
    let mut int_digits = 0;
    while i < n && s[i].is_ascii_digit() {
        i += 1;
        int_digits += 1;
    }

    // Optional fractional part.
    let mut frac_digits = 0;
    if i < n && s[i] == b'.' {
        i += 1;
        while i < n && s[i].is_ascii_digit() {
            i += 1;
            frac_digits += 1;
        }
    }

    // At least one digit must appear before the exponent.
    if int_digits + frac_digits == 0 {
        return false;
    }

    // Optional exponent: requires at least one digit after e/E and sign.
    if i < n && (s[i] == b'e' || s[i] == b'E') {
        i += 1;
        if i < n && (s[i] == b'-' || s[i] == b'+') {
            i += 1;
        }
        let mut exp_digits = 0;
        while i < n && s[i].is_ascii_digit() {
            i += 1;
            exp_digits += 1;
        }
        if exp_digits == 0 {
            return false;
        }
    }

    // The whole slice must have been consumed.
    i == n
}