sql_splitter/sample/
config.rs

1//! YAML configuration for the sample command.
2//!
3//! Supports per-table sampling strategies and table classification.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::fs;
8use std::path::Path;
9
10/// How to handle global/lookup tables
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
12#[serde(rename_all = "lowercase")]
13pub enum GlobalTableMode {
14    /// Exclude global tables
15    None,
16    /// Include lookup tables in full (default)
17    #[default]
18    Lookups,
19    /// Include all global tables in full
20    All,
21}
22
23impl std::str::FromStr for GlobalTableMode {
24    type Err = String;
25
26    fn from_str(s: &str) -> Result<Self, Self::Err> {
27        match s.to_lowercase().as_str() {
28            "none" => Ok(GlobalTableMode::None),
29            "lookups" => Ok(GlobalTableMode::Lookups),
30            "all" => Ok(GlobalTableMode::All),
31            _ => Err(format!(
32                "Unknown global mode: {}. Valid options: none, lookups, all",
33                s
34            )),
35        }
36    }
37}
38
39impl std::fmt::Display for GlobalTableMode {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            GlobalTableMode::None => write!(f, "none"),
43            GlobalTableMode::Lookups => write!(f, "lookups"),
44            GlobalTableMode::All => write!(f, "all"),
45        }
46    }
47}
48
49/// Table classification for sampling behavior
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
51#[serde(rename_all = "lowercase")]
52pub enum TableClassification {
53    /// Normal table, sample according to mode
54    #[default]
55    Normal,
56    /// Root table (has no FK dependencies or explicitly specified)
57    Root,
58    /// Global/lookup table (include fully or skip based on --include-global)
59    Lookup,
60    /// System table (skip by default: migrations, jobs, cache)
61    System,
62    /// Junction/pivot table (many-to-many)
63    Junction,
64}
65
66/// Per-table sampling configuration
67#[derive(Debug, Clone, Default, Serialize, Deserialize)]
68#[serde(default)]
69pub struct TableConfig {
70    /// Sample percentage for this table (overrides default)
71    pub percent: Option<u32>,
72    /// Fixed row count for this table (overrides default)
73    pub rows: Option<usize>,
74    /// Skip this table entirely
75    pub skip: bool,
76    /// Override table classification
77    pub classification: Option<TableClassification>,
78}
79
80/// Default sampling settings
81#[derive(Debug, Clone, Default, Serialize, Deserialize)]
82#[serde(default)]
83pub struct DefaultConfig {
84    /// Default sample percentage
85    pub percent: Option<u32>,
86    /// Default row count
87    pub rows: Option<usize>,
88}
89
90/// Table classification lists
91#[derive(Debug, Clone, Default, Serialize, Deserialize)]
92#[serde(default)]
93pub struct ClassificationConfig {
94    /// Tables to classify as global (include fully)
95    #[serde(default)]
96    pub global: Vec<String>,
97    /// Tables to classify as system (skip by default)
98    #[serde(default)]
99    pub system: Vec<String>,
100    /// Tables to classify as lookup (include based on --include-global)
101    #[serde(default)]
102    pub lookup: Vec<String>,
103    /// Tables to classify as root (start sampling from these)
104    #[serde(default)]
105    pub root: Vec<String>,
106}
107
108/// Complete YAML configuration for sample command
109#[derive(Debug, Clone, Default, Serialize, Deserialize)]
110#[serde(default)]
111pub struct SampleYamlConfig {
112    /// Default sampling settings
113    pub default: DefaultConfig,
114    /// Table classification lists
115    pub classification: ClassificationConfig,
116    /// Per-table settings
117    #[serde(default)]
118    pub tables: HashMap<String, TableConfig>,
119}
120
121impl SampleYamlConfig {
122    /// Load configuration from a YAML file
123    pub fn load(path: &Path) -> anyhow::Result<Self> {
124        let content = fs::read_to_string(path)?;
125        let config: SampleYamlConfig = serde_yaml::from_str(&content)?;
126        Ok(config)
127    }
128
129    /// Get configuration for a specific table
130    pub fn get_table_config(&self, table_name: &str) -> Option<&TableConfig> {
131        self.tables.get(table_name).or_else(|| {
132            // Try case-insensitive match
133            let lower = table_name.to_lowercase();
134            self.tables
135                .iter()
136                .find(|(k, _)| k.to_lowercase() == lower)
137                .map(|(_, v)| v)
138        })
139    }
140
141    /// Get classification for a table
142    pub fn get_classification(&self, table_name: &str) -> TableClassification {
143        // Check per-table override first
144        if let Some(config) = self.get_table_config(table_name) {
145            if let Some(class) = config.classification {
146                return class;
147            }
148        }
149
150        let lower = table_name.to_lowercase();
151
152        // Check classification lists
153        if self
154            .classification
155            .global
156            .iter()
157            .any(|t| t.to_lowercase() == lower)
158        {
159            return TableClassification::Lookup;
160        }
161        if self
162            .classification
163            .system
164            .iter()
165            .any(|t| t.to_lowercase() == lower)
166        {
167            return TableClassification::System;
168        }
169        if self
170            .classification
171            .lookup
172            .iter()
173            .any(|t| t.to_lowercase() == lower)
174        {
175            return TableClassification::Lookup;
176        }
177        if self
178            .classification
179            .root
180            .iter()
181            .any(|t| t.to_lowercase() == lower)
182        {
183            return TableClassification::Root;
184        }
185
186        TableClassification::Normal
187    }
188
189    /// Check if a table should be skipped
190    pub fn should_skip(&self, table_name: &str) -> bool {
191        if let Some(config) = self.get_table_config(table_name) {
192            return config.skip;
193        }
194        false
195    }
196
197    /// Get sample percent for a table (table-specific or default)
198    pub fn get_percent(&self, table_name: &str) -> Option<u32> {
199        if let Some(config) = self.get_table_config(table_name) {
200            if config.percent.is_some() {
201                return config.percent;
202            }
203        }
204        self.default.percent
205    }
206
207    /// Get sample rows for a table (table-specific or default)
208    pub fn get_rows(&self, table_name: &str) -> Option<usize> {
209        if let Some(config) = self.get_table_config(table_name) {
210            if config.rows.is_some() {
211                return config.rows;
212            }
213        }
214        self.default.rows
215    }
216}
217
218/// Default patterns for table classification (used when no config file)
219pub struct DefaultClassifier;
220
221impl DefaultClassifier {
222    /// Well-known system table patterns
223    const SYSTEM_PATTERNS: &'static [&'static str] = &[
224        "migrations",
225        "failed_jobs",
226        "job_batches",
227        "jobs",
228        "cache",
229        "cache_locks",
230        "sessions",
231        "password_reset_tokens",
232        "personal_access_tokens",
233        "telescope_entries",
234        "telescope_entries_tags",
235        "telescope_monitoring",
236        "pulse_",
237        "horizon_",
238    ];
239
240    /// Well-known lookup/global table patterns
241    const LOOKUP_PATTERNS: &'static [&'static str] = &[
242        "countries",
243        "states",
244        "provinces",
245        "cities",
246        "currencies",
247        "languages",
248        "timezones",
249        "permissions",
250        "roles",
251        "settings",
252    ];
253
254    /// Classify a table using default patterns
255    pub fn classify(table_name: &str) -> TableClassification {
256        let lower = table_name.to_lowercase();
257
258        // Check system patterns
259        for pattern in Self::SYSTEM_PATTERNS {
260            if lower.starts_with(pattern) || lower == *pattern {
261                return TableClassification::System;
262            }
263        }
264
265        // Check lookup patterns
266        for pattern in Self::LOOKUP_PATTERNS {
267            if lower == *pattern {
268                return TableClassification::Lookup;
269            }
270        }
271
272        TableClassification::Normal
273    }
274}