Skip to main content

schema_risk/
config.rs

//! Configuration loader for `schema-risk.yml` / `schema-risk.yaml`.
//!
//! Loads per-project configuration from `schema-risk.yml` (or
//! `schema-risk.yaml`) in the current directory, or from a path supplied via
//! `--config`. Falls back gracefully to built-in defaults when no file is found or it cannot be parsed.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
11// ─────────────────────────────────────────────
12// Top-level config struct
13// ─────────────────────────────────────────────
14
15/// Root configuration loaded from `schema-risk.yml`.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(default)]
18pub struct Config {
19    pub version: u32,
20    pub database: DatabaseConfig,
21    pub migrations: MigrationsConfig,
22    pub thresholds: Thresholds,
23    pub rules: RulesConfig,
24    pub scan: ScanConfig,
25    pub guard: GuardConfig,
26    pub output: OutputConfig,
27}
28
29impl Default for Config {
30    fn default() -> Self {
31        Self {
32            version: 2,
33            database: DatabaseConfig::default(),
34            migrations: MigrationsConfig::default(),
35            thresholds: Thresholds::default(),
36            rules: RulesConfig::default(),
37            scan: ScanConfig::default(),
38            guard: GuardConfig::default(),
39            output: OutputConfig::default(),
40        }
41    }
42}
43
44// ─────────────────────────────────────────────
45// Database config
46// ─────────────────────────────────────────────
47
48/// Database connection configuration.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50#[serde(default)]
51pub struct DatabaseConfig {
52    /// Primary environment variable to read database URL from.
53    pub url_env: String,
54    /// Fallback environment variables to check (in order).
55    pub fallback_envs: Vec<String>,
56    /// Whether to automatically load `.env` file.
57    pub load_dotenv: bool,
58    /// Direct database URL (not recommended — use env vars instead).
59    pub url: Option<String>,
60}
61
62impl Default for DatabaseConfig {
63    fn default() -> Self {
64        Self {
65            url_env: "DATABASE_URL".to_string(),
66            fallback_envs: vec!["DB_URL".to_string(), "POSTGRES_URL".to_string()],
67            load_dotenv: true,
68            url: None,
69        }
70    }
71}
72
73// ─────────────────────────────────────────────
74// Migrations config
75// ─────────────────────────────────────────────
76
77/// Migration file discovery configuration.
78#[derive(Debug, Clone, Serialize, Deserialize)]
79#[serde(default)]
80pub struct MigrationsConfig {
81    /// Custom paths to scan for migrations (empty = auto-detect only).
82    pub paths: Vec<String>,
83    /// Enable automatic migration directory discovery.
84    pub auto_discover: bool,
85    /// Glob patterns for migration files.
86    pub patterns: Vec<String>,
87}
88
89impl Default for MigrationsConfig {
90    fn default() -> Self {
91        Self {
92            paths: vec![],
93            auto_discover: true,
94            patterns: vec![
95                "prisma/migrations/**/migration.sql".to_string(),
96                "db/migrate/**/*.sql".to_string(),
97                "migrations/**/*.sql".to_string(),
98                "alembic/versions/**/*.sql".to_string(),
99                "drizzle/**/*.sql".to_string(),
100                "supabase/migrations/**/*.sql".to_string(),
101                "flyway/sql/**/*.sql".to_string(),
102                "src/migrations/**/*.sql".to_string(),
103                "database/migrations/**/*.sql".to_string(),
104            ],
105        }
106    }
107}
108
109// ─────────────────────────────────────────────
110// Thresholds
111// ─────────────────────────────────────────────
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114#[serde(default)]
115pub struct Thresholds {
116    /// Exit non-zero if any migration reaches this risk level.
117    pub fail_on: String,
118    /// Show guard prompt starting at this risk level.
119    pub guard_on: String,
120}
121
122impl Default for Thresholds {
123    fn default() -> Self {
124        Self {
125            fail_on: "high".to_string(),
126            guard_on: "medium".to_string(),
127        }
128    }
129}
130
131// ─────────────────────────────────────────────
132// Rules config
133// ─────────────────────────────────────────────
134
135#[derive(Debug, Clone, Serialize, Deserialize, Default)]
136#[serde(default)]
137pub struct RulesConfig {
138    /// Rule IDs to disable (e.g. ["R03", "R07"])
139    pub disabled: Vec<String>,
140    /// Per-table risk overrides.
141    pub table_overrides: HashMap<String, TableOverride>,
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize, Default)]
145#[serde(default)]
146pub struct TableOverride {
147    /// Allow higher risk level on this table.
148    pub max_risk: Option<String>,
149    /// Skip risk analysis entirely for this table.
150    pub ignored: bool,
151}
152
153// ─────────────────────────────────────────────
154// Scan config
155// ─────────────────────────────────────────────
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
158#[serde(default)]
159pub struct ScanConfig {
160    pub root_dir: String,
161    pub extensions: Vec<String>,
162    pub exclude: Vec<String>,
163    /// Skip columns/tables with fewer than 4 characters (avoids false positives).
164    pub skip_short_identifiers: bool,
165}
166
167impl Default for ScanConfig {
168    fn default() -> Self {
169        Self {
170            root_dir: ".".to_string(),
171            extensions: vec![
172                "rs".to_string(),
173                "py".to_string(),
174                "go".to_string(),
175                "ts".to_string(),
176                "js".to_string(),
177                "rb".to_string(),
178                "java".to_string(),
179                "kt".to_string(),
180                "cs".to_string(),
181                "php".to_string(),
182            ],
183            exclude: vec![
184                "target/".to_string(),
185                "node_modules/".to_string(),
186                "vendor/".to_string(),
187                ".git/".to_string(),
188                "dist/".to_string(),
189                "build/".to_string(),
190            ],
191            skip_short_identifiers: true,
192        }
193    }
194}
195
196// ─────────────────────────────────────────────
197// Guard config
198// ─────────────────────────────────────────────
199
200#[derive(Debug, Clone, Serialize, Deserialize)]
201#[serde(default)]
202pub struct GuardConfig {
203    /// Require full phrase "yes I am sure" for Critical operations.
204    pub require_typed_confirmation: bool,
205    /// Path to write the audit log JSON.
206    pub audit_log: String,
207    /// Always exit 4 when actor is detected as an AI agent.
208    pub block_agents: bool,
209    /// Exit 4 for CI pipelines (default: false - just print warning).
210    pub block_ci: bool,
211}
212
213impl Default for GuardConfig {
214    fn default() -> Self {
215        Self {
216            require_typed_confirmation: true,
217            audit_log: ".schemarisk-audit.json".to_string(),
218            block_agents: true,
219            block_ci: false,
220        }
221    }
222}
223
224// ─────────────────────────────────────────────
225// Output config
226// ─────────────────────────────────────────────
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
229#[serde(default)]
230pub struct OutputConfig {
231    pub format: String,
232    pub color: bool,
233    pub show_recommendations: bool,
234    pub show_impact: bool,
235}
236
237impl Default for OutputConfig {
238    fn default() -> Self {
239        Self {
240            format: "terminal".to_string(),
241            color: true,
242            show_recommendations: true,
243            show_impact: true,
244        }
245    }
246}
247
248// ─────────────────────────────────────────────
249// Loader
250// ─────────────────────────────────────────────
251
252/// Load configuration from a file path, falling back to defaults if absent.
253///
254/// Searches in order:
255/// 1. The path supplied via `--config <PATH>`
256/// 2. `./schema-risk.yml`
257/// 3. `./schema-risk.yaml`
258/// 4. Built-in defaults
259pub fn load(path: Option<&str>) -> Config {
260    let candidates: Vec<&str> = if let Some(p) = path {
261        vec![p]
262    } else {
263        vec!["schema-risk.yml", "schema-risk.yaml"]
264    };
265
266    for candidate in &candidates {
267        if let Some(config) = try_load(Path::new(candidate)) {
268            return config;
269        }
270    }
271
272    Config::default()
273}
274
275fn try_load(path: &Path) -> Option<Config> {
276    if !path.exists() {
277        return None;
278    }
279    let contents = std::fs::read_to_string(path).ok()?;
280    match serde_yaml::from_str::<Config>(&contents) {
281        Ok(c) => Some(c),
282        Err(e) => {
283            eprintln!("warning: Failed to parse {}: {e}", path.display());
284            None
285        }
286    }
287}
288
/// Return the canonical YAML template written by `schemarisk init`.
///
/// The `migrations.patterns` list mirrors the built-in defaults of
/// `MigrationsConfig`, so a project initialised from this template discovers
/// the same migration files as a project with no config file at all.
///
/// Note that the `rules.table_overrides` entries (`audit_log`, `sessions`)
/// are live YAML examples, not comments: a file written from this template
/// carries those two overrides until the user edits them.
pub fn default_yaml_template() -> &'static str {
    r#"# schema-risk.yml — per-project SchemaRisk configuration
version: 2

# Database connection settings
database:
  url_env: DATABASE_URL           # primary env var for DB URL
  fallback_envs:                  # fallback env vars to check (in order)
    - DB_URL
    - POSTGRES_URL
  load_dotenv: true               # auto-load .env file
  # url: postgres://...           # direct URL (not recommended - use env vars)

# Migration file discovery
migrations:
  paths: []                       # custom paths (empty = auto-detect)
  auto_discover: true             # scan for common migration patterns
  patterns:                       # glob patterns for SQL migration files
    - "prisma/migrations/**/migration.sql"
    - "db/migrate/**/*.sql"
    - "migrations/**/*.sql"
    - "alembic/versions/**/*.sql"
    - "drizzle/**/*.sql"
    - "supabase/migrations/**/*.sql"
    - "flyway/sql/**/*.sql"
    - "src/migrations/**/*.sql"
    - "database/migrations/**/*.sql"

thresholds:
  fail_on: high                   # low | medium | high | critical
  guard_on: medium                # operations at this level trigger guard prompts

rules:
  disabled: []                    # e.g. [R03, R07]
  table_overrides:
    audit_log:
      max_risk: critical          # allow higher risk on append-only tables
    sessions:
      ignored: true               # skip risk analysis entirely

scan:
  root_dir: "."                   # directory to scan for code impact
  extensions: [rs, py, go, ts, js, rb, java, kt, cs, php]
  exclude: [target/, node_modules/, vendor/, .git/, dist/, build/]
  skip_short_identifiers: true    # skip columns < 4 chars (avoids false positives)

guard:
  require_typed_confirmation: true  # "yes I am sure" for Critical ops
  audit_log: ".schemarisk-audit.json"
  block_agents: true              # always block AI agents
  block_ci: false                 # set true to block CI pipelines

output:
  format: terminal                # terminal | json | markdown | sarif
  color: true
  show_recommendations: true
  show_impact: true
"#
}
346
347// ─────────────────────────────────────────────
348// Tests
349// ─────────────────────────────────────────────
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the built-in defaults across several sections.
    #[test]
    fn default_config_has_sensible_values() {
        let Config {
            thresholds,
            guard,
            scan,
            database,
            migrations,
            ..
        } = Config::default();

        assert_eq!(thresholds.fail_on, "high");
        assert_eq!(thresholds.guard_on, "medium");
        assert!(guard.block_agents);
        assert!(!guard.block_ci);
        assert!(scan.skip_short_identifiers);
        assert_eq!(database.url_env, "DATABASE_URL");
        assert!(database.load_dotenv);
        assert!(migrations.auto_discover);
    }

    /// The template shipped by `init` must deserialize into a `Config`.
    #[test]
    fn yaml_template_parses_correctly() {
        let parsed = serde_yaml::from_str::<Config>(default_yaml_template());
        let cfg = parsed.expect("template should be valid YAML");

        assert_eq!(cfg.version, 2);
        assert_eq!(cfg.thresholds.fail_on, "high");
        assert_eq!(cfg.database.url_env, "DATABASE_URL");
        assert!(cfg.migrations.auto_discover);
    }

    /// Every field of the database defaults is pinned.
    #[test]
    fn database_config_defaults() {
        let DatabaseConfig {
            url_env,
            fallback_envs,
            load_dotenv,
            url,
        } = DatabaseConfig::default();

        assert_eq!(url_env, "DATABASE_URL");
        assert_eq!(fallback_envs, vec!["DB_URL", "POSTGRES_URL"]);
        assert!(load_dotenv);
        assert!(url.is_none());
    }

    /// Migration defaults: no custom paths, discovery on, patterns present.
    #[test]
    fn migrations_config_defaults() {
        let MigrationsConfig {
            paths,
            auto_discover,
            patterns,
        } = MigrationsConfig::default();

        assert!(paths.is_empty());
        assert!(auto_discover);
        assert!(!patterns.is_empty());
        assert!(patterns.iter().any(|glob| glob.contains("prisma")));
    }
}