use anyhow::Context as _;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
13#[serde(rename_all = "lowercase")]
14pub enum GlobalTableMode {
15 None,
17 #[default]
19 Lookups,
20 All,
22}
23
24impl std::str::FromStr for GlobalTableMode {
25 type Err = String;
26
27 fn from_str(s: &str) -> Result<Self, Self::Err> {
28 match s.to_lowercase().as_str() {
29 "none" => Ok(GlobalTableMode::None),
30 "lookups" => Ok(GlobalTableMode::Lookups),
31 "all" => Ok(GlobalTableMode::All),
32 _ => Err(format!(
33 "Unknown global mode: {}. Valid options: none, lookups, all",
34 s
35 )),
36 }
37 }
38}
39
40impl std::fmt::Display for GlobalTableMode {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 match self {
43 GlobalTableMode::None => write!(f, "none"),
44 GlobalTableMode::Lookups => write!(f, "lookups"),
45 GlobalTableMode::All => write!(f, "all"),
46 }
47 }
48}
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
52#[serde(rename_all = "lowercase")]
53pub enum ShardTableClassification {
54 TenantRoot,
56 TenantDependent,
58 Junction,
60 Lookup,
62 System,
64 #[default]
66 Unknown,
67}
68
69impl std::fmt::Display for ShardTableClassification {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 match self {
72 ShardTableClassification::TenantRoot => write!(f, "tenant-root"),
73 ShardTableClassification::TenantDependent => write!(f, "tenant-dependent"),
74 ShardTableClassification::Junction => write!(f, "junction"),
75 ShardTableClassification::Lookup => write!(f, "lookup"),
76 ShardTableClassification::System => write!(f, "system"),
77 ShardTableClassification::Unknown => write!(f, "unknown"),
78 }
79 }
80}
81
82#[derive(Debug, Clone, Default, Serialize, Deserialize)]
84#[serde(default)]
85pub struct TableOverride {
86 pub role: Option<ShardTableClassification>,
88 pub include: Option<bool>,
90 pub self_fk: Option<String>,
92 pub skip: bool,
94}
95
96#[derive(Debug, Clone, Default, Serialize, Deserialize)]
98#[serde(default)]
99pub struct TenantConfig {
100 pub column: Option<String>,
102 #[serde(default)]
104 pub root_tables: Vec<String>,
105}
106
107#[derive(Debug, Clone, Default, Serialize, Deserialize)]
109#[serde(default)]
110pub struct ShardYamlConfig {
111 pub tenant: TenantConfig,
113 #[serde(default)]
115 pub tables: HashMap<String, TableOverride>,
116 pub include_global: Option<GlobalTableMode>,
118}
119
120impl ShardYamlConfig {
121 pub fn load(path: &Path) -> anyhow::Result<Self> {
123 let content = fs::read_to_string(path)?;
124 let config: ShardYamlConfig = serde_yaml::from_str(&content)?;
125 Ok(config)
126 }
127
128 pub fn get_table_override(&self, table_name: &str) -> Option<&TableOverride> {
130 self.tables.get(table_name).or_else(|| {
131 let lower = table_name.to_lowercase();
132 self.tables
133 .iter()
134 .find(|(k, _)| k.to_lowercase() == lower)
135 .map(|(_, v)| v)
136 })
137 }
138
139 pub fn get_classification(&self, table_name: &str) -> Option<ShardTableClassification> {
141 self.get_table_override(table_name).and_then(|o| o.role)
142 }
143
144 pub fn should_skip(&self, table_name: &str) -> bool {
146 self.get_table_override(table_name)
147 .map(|o| o.skip)
148 .unwrap_or(false)
149 }
150
151 #[allow(dead_code)]
153 pub fn get_self_fk(&self, table_name: &str) -> Option<&str> {
154 self.get_table_override(table_name)
155 .and_then(|o| o.self_fk.as_deref())
156 }
157}
158
/// Name-based heuristics for classifying tables.
///
/// Every check lower-cases the table name first, so matching is case-insensitive.
pub struct DefaultShardClassifier;

impl DefaultShardClassifier {
    /// Candidate tenant-key column names. NOTE(review): used outside this file section.
    pub const TENANT_COLUMNS: &'static [&'static str] = &[
        "company_id",
        "tenant_id",
        "organization_id",
        "org_id",
        "account_id",
        "team_id",
        "workspace_id",
    ];

    /// Framework bookkeeping table names (the list appears to mirror Laravel's built-in
    /// tables).
    ///
    /// A pattern ending in `_` is a prefix (`pulse_` matches `pulse_entries`). Every other
    /// pattern must equal the whole table name. This is why `cache_locks` and
    /// `telescope_entries_tags` are listed explicitly.
    pub const SYSTEM_PATTERNS: &'static [&'static str] = &[
        "migrations",
        "failed_jobs",
        "job_batches",
        "jobs",
        "cache",
        "cache_locks",
        "sessions",
        "password_reset_tokens",
        "personal_access_tokens",
        "telescope_entries",
        "telescope_entries_tags",
        "telescope_monitoring",
        "pulse_",
        "horizon_",
    ];

    /// Table names treated as shared lookup data; matched by exact name.
    pub const LOOKUP_PATTERNS: &'static [&'static str] = &[
        "countries",
        "states",
        "provinces",
        "cities",
        "currencies",
        "languages",
        "timezones",
        "permissions",
        "roles",
        "settings",
    ];

    /// Returns `true` if `table_name` matches an entry of [`Self::SYSTEM_PATTERNS`].
    ///
    /// Only `_`-terminated patterns act as prefixes. Exact patterns such as `jobs` or
    /// `sessions` must match the whole name, so application tables like
    /// `jobs_applications` or `sessions_speakers` are not misclassified as system tables.
    pub fn is_system_table(table_name: &str) -> bool {
        let lower = table_name.to_lowercase();
        Self::SYSTEM_PATTERNS.iter().any(|&pattern| {
            if pattern.ends_with('_') {
                lower.starts_with(pattern)
            } else {
                lower == pattern
            }
        })
    }

    /// Returns `true` if `table_name` exactly equals an entry of
    /// [`Self::LOOKUP_PATTERNS`], ignoring case.
    pub fn is_lookup_table(table_name: &str) -> bool {
        let lower = table_name.to_lowercase();
        Self::LOOKUP_PATTERNS.iter().any(|&name| name == lower)
    }

    /// Returns `true` if `table_name` contains `_has_` or ends with `_pivot`, `_link` or
    /// `_map`, ignoring case.
    pub fn is_junction_table_by_name(table_name: &str) -> bool {
        let lower = table_name.to_lowercase();
        lower.contains("_has_")
            || ["_pivot", "_link", "_map"]
                .iter()
                .any(|suffix| lower.ends_with(suffix))
    }
}