Skip to main content

sql_splitter/shard/
config.rs

1//! YAML configuration for the shard command.
2//!
3//! Supports tenant column specification, table classification overrides,
4//! and system/lookup table patterns.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fs;
9use std::path::Path;
10
11/// How to handle global/lookup tables during sharding
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
13#[serde(rename_all = "lowercase")]
14pub enum GlobalTableMode {
15    /// Exclude global tables from output
16    None,
17    /// Include lookup tables in full (default)
18    #[default]
19    Lookups,
20    /// Include all global tables in full
21    All,
22}
23
24impl std::str::FromStr for GlobalTableMode {
25    type Err = String;
26
27    fn from_str(s: &str) -> Result<Self, Self::Err> {
28        match s.to_lowercase().as_str() {
29            "none" => Ok(GlobalTableMode::None),
30            "lookups" => Ok(GlobalTableMode::Lookups),
31            "all" => Ok(GlobalTableMode::All),
32            _ => Err(format!(
33                "Unknown global mode: {}. Valid options: none, lookups, all",
34                s
35            )),
36        }
37    }
38}
39
40impl std::fmt::Display for GlobalTableMode {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        match self {
43            GlobalTableMode::None => write!(f, "none"),
44            GlobalTableMode::Lookups => write!(f, "lookups"),
45            GlobalTableMode::All => write!(f, "all"),
46        }
47    }
48}
49
50/// Table classification for sharding behavior
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
52#[serde(rename_all = "lowercase")]
53pub enum ShardTableClassification {
54    /// Table has the tenant column directly
55    TenantRoot,
56    /// Table is connected to tenant via FK chain
57    TenantDependent,
58    /// Junction/pivot table (many-to-many, include if any FK matches)
59    Junction,
60    /// Global/lookup table (include fully or skip based on config)
61    Lookup,
62    /// System table (skip by default: migrations, jobs, cache)
63    System,
64    /// Normal table that couldn't be classified
65    #[default]
66    Unknown,
67}
68
69impl std::fmt::Display for ShardTableClassification {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        match self {
72            ShardTableClassification::TenantRoot => write!(f, "tenant-root"),
73            ShardTableClassification::TenantDependent => write!(f, "tenant-dependent"),
74            ShardTableClassification::Junction => write!(f, "junction"),
75            ShardTableClassification::Lookup => write!(f, "lookup"),
76            ShardTableClassification::System => write!(f, "system"),
77            ShardTableClassification::Unknown => write!(f, "unknown"),
78        }
79    }
80}
81
82/// Per-table configuration override
83#[derive(Debug, Clone, Default, Serialize, Deserialize)]
84#[serde(default)]
85pub struct TableOverride {
86    /// Override classification
87    pub role: Option<ShardTableClassification>,
88    /// Include this lookup/global table
89    pub include: Option<bool>,
90    /// Self-referential FK column (e.g., parent_id for hierarchical tables)
91    pub self_fk: Option<String>,
92    /// Skip this table entirely
93    pub skip: bool,
94}
95
96/// Tenant configuration section
97#[derive(Debug, Clone, Default, Serialize, Deserialize)]
98#[serde(default)]
99pub struct TenantConfig {
100    /// Column name for tenant identification
101    pub column: Option<String>,
102    /// Explicit root tables (tables that have the tenant column)
103    #[serde(default)]
104    pub root_tables: Vec<String>,
105}
106
107/// Complete YAML configuration for shard command
108#[derive(Debug, Clone, Default, Serialize, Deserialize)]
109#[serde(default)]
110pub struct ShardYamlConfig {
111    /// Tenant configuration
112    pub tenant: TenantConfig,
113    /// Per-table overrides
114    #[serde(default)]
115    pub tables: HashMap<String, TableOverride>,
116    /// Global table handling
117    pub include_global: Option<GlobalTableMode>,
118}
119
120impl ShardYamlConfig {
121    /// Load configuration from a YAML file
122    pub fn load(path: &Path) -> anyhow::Result<Self> {
123        let content = fs::read_to_string(path)?;
124        let config: ShardYamlConfig = serde_yaml::from_str(&content)?;
125        Ok(config)
126    }
127
128    /// Get override for a specific table
129    pub fn get_table_override(&self, table_name: &str) -> Option<&TableOverride> {
130        self.tables.get(table_name).or_else(|| {
131            let lower = table_name.to_lowercase();
132            self.tables
133                .iter()
134                .find(|(k, _)| k.to_lowercase() == lower)
135                .map(|(_, v)| v)
136        })
137    }
138
139    /// Get classification override for a table
140    pub fn get_classification(&self, table_name: &str) -> Option<ShardTableClassification> {
141        self.get_table_override(table_name).and_then(|o| o.role)
142    }
143
144    /// Check if a table should be skipped
145    pub fn should_skip(&self, table_name: &str) -> bool {
146        self.get_table_override(table_name)
147            .map(|o| o.skip)
148            .unwrap_or(false)
149    }
150
151    /// Get self-FK column for hierarchical tables (for future self-referential closure)
152    #[allow(dead_code)]
153    pub fn get_self_fk(&self, table_name: &str) -> Option<&str> {
154        self.get_table_override(table_name)
155            .and_then(|o| o.self_fk.as_deref())
156    }
157}
158
/// Default patterns for table classification when no config file provided
pub struct DefaultShardClassifier;

impl DefaultShardClassifier {
    /// Well-known tenant column names (in priority order)
    pub const TENANT_COLUMNS: &'static [&'static str] = &[
        "company_id",
        "tenant_id",
        "organization_id",
        "org_id",
        "account_id",
        "team_id",
        "workspace_id",
    ];

    /// Well-known system table patterns.
    ///
    /// An entry ending in `_` (e.g. `pulse_`) is a prefix and matches every
    /// table whose lowercased name starts with it. Every other entry must equal
    /// the whole lowercased table name; this keeps tables such as `jobsites` or
    /// `cached_reports` from being misclassified.
    pub const SYSTEM_PATTERNS: &'static [&'static str] = &[
        "migrations",
        "failed_jobs",
        "job_batches",
        "jobs",
        "cache",
        "cache_locks",
        "sessions",
        "password_reset_tokens",
        "personal_access_tokens",
        "telescope_entries",
        "telescope_entries_tags",
        "telescope_monitoring",
        "pulse_",
        "horizon_",
    ];

    /// Well-known lookup/global table names, matched against the whole
    /// lowercased table name.
    pub const LOOKUP_PATTERNS: &'static [&'static str] = &[
        "countries",
        "states",
        "provinces",
        "cities",
        "currencies",
        "languages",
        "timezones",
        "permissions",
        "roles",
        "settings",
    ];

    /// Check (case-insensitively) whether a table name matches a system table pattern.
    ///
    /// Entries ending in `_` match as prefixes; all others must match exactly.
    pub fn is_system_table(table_name: &str) -> bool {
        let lower = table_name.to_lowercase();
        Self::SYSTEM_PATTERNS.iter().any(|pattern| {
            if pattern.ends_with('_') {
                lower.starts_with(pattern)
            } else {
                lower == *pattern
            }
        })
    }

    /// Check (case-insensitively) whether a table name is a well-known lookup table.
    pub fn is_lookup_table(table_name: &str) -> bool {
        let lower = table_name.to_lowercase();
        Self::LOOKUP_PATTERNS.contains(&lower.as_str())
    }

    /// Detect a junction table by name (case-insensitive).
    ///
    /// Matches names containing `_has_`, or ending in `_pivot`, `_link` or `_map`.
    pub fn is_junction_table_by_name(table_name: &str) -> bool {
        let lower = table_name.to_lowercase();
        lower.contains("_has_")
            || lower.ends_with("_pivot")
            || lower.ends_with("_link")
            || lower.ends_with("_map")
    }
}