Skip to main content

the_code_graph_cli/
config.rs

1use domain::error::{CodeGraphError, Result};
2use serde::Deserialize;
3use std::path::Path;
4
5#[derive(Debug, Clone, Default, Deserialize)]
6pub struct CodeGraphConfig {
7    pub index: Option<IndexConfig>,
8    pub search: Option<SearchConfig>,
9    pub watch: Option<WatchConfig>,
10    pub flows: Option<FlowsConfig>,
11    pub risk: Option<RiskCliConfig>,
12    pub communities: Option<CommunitiesConfig>,
13    pub embeddings: Option<EmbeddingsCliConfig>,
14    #[serde(rename = "dead-code")]
15    pub dead_code: Option<DeadCodeCliConfig>,
16}
17
18#[derive(Debug, Clone, Default, Deserialize)]
19pub struct IndexConfig {
20    pub exclude: Option<Vec<String>>,
21}
22
23#[derive(Debug, Clone, Default, Deserialize)]
24pub struct SearchConfig {
25    pub max_results: Option<usize>,
26    pub rrf_k: Option<usize>,
27    pub kind_boost: Option<bool>,
28}
29
30#[derive(Debug, Clone, Default, Deserialize)]
31pub struct EmbeddingsCliConfig {
32    pub enabled: Option<bool>,
33    pub model: Option<String>,
34    pub batch_size: Option<usize>,
35}
36
37#[derive(Debug, Clone, Default, Deserialize)]
38pub struct WatchConfig {
39    pub debounce_ms: Option<u64>,
40}
41
42#[derive(Debug, Clone, Default, Deserialize)]
43pub struct FlowsConfig {
44    pub extra_entry_points: Option<Vec<String>>,
45    pub excluded_entry_points: Option<Vec<String>>,
46}
47
48#[derive(Debug, Clone, Default, Deserialize)]
49pub struct RiskCliConfig {
50    pub weight_criticality: Option<f64>,
51    pub weight_coupling: Option<f64>,
52    pub weight_test_gap: Option<f64>,
53    pub weight_sensitivity: Option<f64>,
54    pub extra_security_patterns: Option<Vec<String>>,
55    pub excluded_security_patterns: Option<Vec<String>>,
56}
57
58#[derive(Debug, Clone, Default, Deserialize)]
59pub struct CommunitiesConfig {
60    pub resolution: Option<f64>,
61    pub min_community_size: Option<usize>,
62    pub seed: Option<u64>,
63}
64
65#[derive(Debug, Clone, Default, Deserialize)]
66pub struct DeadCodeCliConfig {
67    pub exclude_patterns: Option<Vec<String>>,
68    pub entry_point_patterns: Option<Vec<String>>,
69    pub migration_patterns: Option<Vec<String>>,
70}
71
72pub fn load_config(project_root: &Path) -> Result<CodeGraphConfig> {
73    let config_path = project_root.join(".code-graph").join("config.toml");
74    if !config_path.exists() {
75        return Ok(CodeGraphConfig::default());
76    }
77    let content =
78        std::fs::read_to_string(&config_path).map_err(|e| CodeGraphError::FileSystem {
79            path: config_path.clone(),
80            source: e,
81        })?;
82    toml::from_str(&content).map_err(|e| CodeGraphError::Other(format!("invalid config: {e}")))
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a temporary project root whose `.code-graph/config.toml`
    /// contains `contents`. Keep the returned guard alive while the directory
    /// is in use: dropping it deletes the directory.
    fn project_with_config(contents: &str) -> tempfile::TempDir {
        let tmp = tempfile::tempdir().unwrap();
        let dir = tmp.path().join(".code-graph");
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join("config.toml"), contents).unwrap();
        tmp
    }

    #[test]
    fn missing_config_returns_defaults() {
        let tmp = tempfile::tempdir().unwrap();
        let config = load_config(tmp.path()).unwrap();
        assert!(config.index.is_none());
        assert!(config.search.is_none());
        assert!(config.watch.is_none());
        assert!(config.flows.is_none());
        assert!(config.risk.is_none());
        assert!(config.communities.is_none());
        assert!(config.embeddings.is_none());
        assert!(config.dead_code.is_none());
    }

    #[test]
    fn config_dir_without_file_returns_defaults() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::create_dir_all(tmp.path().join(".code-graph")).unwrap();
        let config = load_config(tmp.path()).unwrap();
        assert!(config.index.is_none());
        assert!(config.search.is_none());
    }

    #[test]
    fn empty_config_file_returns_defaults() {
        let tmp = project_with_config("");
        let config = load_config(tmp.path()).unwrap();
        assert!(config.index.is_none());
        assert!(config.dead_code.is_none());
    }

    #[test]
    fn unreadable_config_path_returns_error() {
        // Something exists at the config path but cannot be read as text; this
        // must surface as an error rather than fall back to defaults.
        let tmp = tempfile::tempdir().unwrap();
        std::fs::create_dir_all(tmp.path().join(".code-graph").join("config.toml")).unwrap();
        assert!(load_config(tmp.path()).is_err());
    }

    #[test]
    fn valid_config_parses() {
        let tmp = project_with_config(
            r#"
[index]
exclude = ["target", "node_modules"]

[search]
max_results = 50
"#,
        );
        let config = load_config(tmp.path()).unwrap();
        let index = config.index.unwrap();
        assert_eq!(index.exclude.unwrap(), vec!["target", "node_modules"]);
        assert_eq!(config.search.unwrap().max_results, Some(50));
    }

    #[test]
    fn watch_config_parses() {
        let tmp = project_with_config("[watch]\ndebounce_ms = 250\n");
        let config = load_config(tmp.path()).unwrap();
        assert_eq!(config.watch.unwrap().debounce_ms, Some(250));
    }

    #[test]
    fn flows_config_parses() {
        let tmp = project_with_config(
            r#"
[flows]
extra_entry_points = ["src/custom.rs::handler"]
excluded_entry_points = ["src/test_helper.rs::setup"]
"#,
        );
        let config = load_config(tmp.path()).unwrap();
        let flows = config.flows.unwrap();
        assert_eq!(
            flows.extra_entry_points.unwrap(),
            vec!["src/custom.rs::handler"]
        );
        assert_eq!(
            flows.excluded_entry_points.unwrap(),
            vec!["src/test_helper.rs::setup"]
        );
    }

    #[test]
    fn communities_config_parses() {
        let tmp = project_with_config(
            r#"
[communities]
resolution = 1.5
min_community_size = 3
seed = 42
"#,
        );
        let config = load_config(tmp.path()).unwrap();
        let cc = config.communities.unwrap();
        assert!((cc.resolution.unwrap() - 1.5).abs() < f64::EPSILON);
        assert_eq!(cc.min_community_size, Some(3));
        assert_eq!(cc.seed, Some(42));
    }

    #[test]
    fn embeddings_config_parses() {
        let tmp = project_with_config(
            r#"
[embeddings]
enabled = true
model = "all-MiniLM-L6-v2"
batch_size = 32

[search]
rrf_k = 60
kind_boost = true
"#,
        );
        let config = load_config(tmp.path()).unwrap();
        let emb = config.embeddings.unwrap();
        assert_eq!(emb.enabled, Some(true));
        assert_eq!(emb.model.as_deref(), Some("all-MiniLM-L6-v2"));
        assert_eq!(emb.batch_size, Some(32));
        let search = config.search.unwrap();
        assert_eq!(search.rrf_k, Some(60));
        assert_eq!(search.kind_boost, Some(true));
    }

    #[test]
    fn dead_code_config_parses() {
        let tmp = project_with_config(
            r#"
[dead-code]
exclude_patterns = ["**/generated/**", "**/proto/**"]
migration_patterns = ["**/migrations/**"]
entry_point_patterns = ["*_handler", "*_endpoint"]
"#,
        );
        let config = load_config(tmp.path()).unwrap();
        let dc = config.dead_code.unwrap();
        assert_eq!(
            dc.exclude_patterns.unwrap(),
            vec!["**/generated/**", "**/proto/**"]
        );
        assert_eq!(dc.migration_patterns.unwrap(), vec!["**/migrations/**"]);
        assert_eq!(
            dc.entry_point_patterns.unwrap(),
            vec!["*_handler", "*_endpoint"]
        );
    }

    #[test]
    fn invalid_toml_returns_error() {
        let tmp = project_with_config("not valid toml {{{{");
        assert!(load_config(tmp.path()).is_err());
    }

    #[test]
    fn wrong_value_type_returns_error() {
        let tmp = project_with_config("[search]\nmax_results = \"fifty\"\n");
        assert!(load_config(tmp.path()).is_err());
    }

    #[test]
    fn risk_config_parses() {
        let tmp = project_with_config(
            r#"
[risk]
weight_criticality = 0.40
weight_coupling = 0.20
weight_test_gap = 0.20
weight_sensitivity = 0.20
extra_security_patterns = ["unsafe", "inject"]
excluded_security_patterns = ["hash"]
"#,
        );
        let config = load_config(tmp.path()).unwrap();
        let risk = config.risk.unwrap();
        assert!((risk.weight_criticality.unwrap() - 0.40).abs() < f64::EPSILON);
        assert!((risk.weight_coupling.unwrap() - 0.20).abs() < f64::EPSILON);
        assert!((risk.weight_test_gap.unwrap() - 0.20).abs() < f64::EPSILON);
        assert!((risk.weight_sensitivity.unwrap() - 0.20).abs() < f64::EPSILON);
        assert_eq!(
            risk.extra_security_patterns.unwrap(),
            vec!["unsafe", "inject"]
        );
        assert_eq!(risk.excluded_security_patterns.unwrap(), vec!["hash"]);
    }
}