tensorlogic_cli/
config.rs1use anyhow::{Context, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::fs;
12use std::path::{Path, PathBuf};
13
14use crate::macros::MacroDef;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18#[serde(default)]
19pub struct Config {
20 pub strategy: String,
22
23 pub domains: HashMap<String, usize>,
25
26 pub output_format: String,
28
29 pub validate: bool,
31
32 pub debug: bool,
34
35 pub colored: bool,
37
38 pub repl: ReplConfig,
40
41 pub watch: WatchConfig,
43
44 pub cache: CacheConfig,
46
47 pub macros: Vec<MacroDef>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53#[serde(default)]
54pub struct CacheConfig {
55 pub enabled: bool,
57
58 pub max_entries: usize,
60
61 pub disk_cache_enabled: bool,
63
64 pub disk_cache_max_size_mb: usize,
66
67 pub disk_cache_dir: Option<PathBuf>,
69}
70
71impl Default for CacheConfig {
72 fn default() -> Self {
73 Self {
74 enabled: true,
75 max_entries: 100,
76 disk_cache_enabled: true,
77 disk_cache_max_size_mb: 500,
78 disk_cache_dir: None,
79 }
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84#[serde(default)]
85pub struct ReplConfig {
86 pub prompt: String,
88
89 pub history_file: String,
91
92 pub max_history: usize,
94
95 pub auto_save: bool,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
100#[serde(default)]
101pub struct WatchConfig {
102 pub debounce_ms: u64,
104
105 pub clear_screen: bool,
107
108 pub show_timestamps: bool,
110}
111
112impl Default for Config {
113 fn default() -> Self {
114 let mut domains = HashMap::new();
115 domains.insert("D".to_string(), 100);
116
117 Self {
118 strategy: "soft_differentiable".to_string(),
119 domains,
120 output_format: "graph".to_string(),
121 validate: false,
122 debug: false,
123 colored: true,
124 repl: ReplConfig::default(),
125 watch: WatchConfig::default(),
126 cache: CacheConfig::default(),
127 macros: Vec::new(),
128 }
129 }
130}
131
132impl Default for ReplConfig {
133 fn default() -> Self {
134 Self {
135 prompt: "tensorlogic> ".to_string(),
136 history_file: ".tensorlogic_history".to_string(),
137 max_history: 1000,
138 auto_save: true,
139 }
140 }
141}
142
143impl Default for WatchConfig {
144 fn default() -> Self {
145 Self {
146 debounce_ms: 500,
147 clear_screen: true,
148 show_timestamps: true,
149 }
150 }
151}
152
153impl Config {
154 pub fn load(path: &Path) -> Result<Self> {
156 let content = fs::read_to_string(path)
157 .with_context(|| format!("Failed to read config file: {}", path.display()))?;
158
159 toml::from_str(&content)
160 .with_context(|| format!("Failed to parse config file: {}", path.display()))
161 }
162
163 pub fn save(&self, path: &Path) -> Result<()> {
165 let content = toml::to_string_pretty(self).context("Failed to serialize configuration")?;
166
167 fs::write(path, content)
168 .with_context(|| format!("Failed to write config file: {}", path.display()))
169 }
170
171 pub fn load_default() -> Self {
178 if let Ok(path) = std::env::var("TENSORLOGIC_CONFIG") {
180 if let Ok(config) = Self::load(Path::new(&path)) {
181 return config;
182 }
183 }
184
185 let current_config = PathBuf::from(".tensorlogicrc");
187 if current_config.exists() {
188 if let Ok(config) = Self::load(¤t_config) {
189 return config;
190 }
191 }
192
193 if let Some(home) = dirs::home_dir() {
195 let home_config = home.join(".tensorlogicrc");
196 if home_config.exists() {
197 if let Ok(config) = Self::load(&home_config) {
198 return config;
199 }
200 }
201 }
202
203 Self::default()
205 }
206
207 pub fn config_path() -> PathBuf {
209 let current = PathBuf::from(".tensorlogicrc");
211 if current.exists() {
212 return current;
213 }
214
215 if let Some(home) = dirs::home_dir() {
217 home.join(".tensorlogicrc")
218 } else {
219 current
220 }
221 }
222
223 pub fn create_default() -> Result<PathBuf> {
225 let config = Self::default();
226 let path = Self::config_path();
227 config.save(&path)?;
228 Ok(path)
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every setting of `a` and `b` agrees. Fields are compared
    /// one by one because `Config` does not implement `PartialEq`.
    fn assert_same_settings(a: &Config, b: &Config) {
        assert_eq!(a.strategy, b.strategy);
        assert_eq!(a.domains, b.domains);
        assert_eq!(a.output_format, b.output_format);
        assert_eq!(a.validate, b.validate);
        assert_eq!(a.debug, b.debug);
        assert_eq!(a.colored, b.colored);
        assert_eq!(a.repl.prompt, b.repl.prompt);
        assert_eq!(a.repl.history_file, b.repl.history_file);
        assert_eq!(a.repl.max_history, b.repl.max_history);
        assert_eq!(a.repl.auto_save, b.repl.auto_save);
        assert_eq!(a.watch.debounce_ms, b.watch.debounce_ms);
        assert_eq!(a.watch.clear_screen, b.watch.clear_screen);
        assert_eq!(a.watch.show_timestamps, b.watch.show_timestamps);
        assert_eq!(a.cache.enabled, b.cache.enabled);
        assert_eq!(a.cache.max_entries, b.cache.max_entries);
        assert_eq!(a.cache.disk_cache_enabled, b.cache.disk_cache_enabled);
        assert_eq!(a.cache.disk_cache_max_size_mb, b.cache.disk_cache_max_size_mb);
        assert_eq!(a.cache.disk_cache_dir, b.cache.disk_cache_dir);
        assert_eq!(a.macros.len(), b.macros.len());
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.strategy, "soft_differentiable");
        assert_eq!(config.domains.get("D"), Some(&100));
        assert_eq!(config.domains.len(), 1);
        assert_eq!(config.output_format, "graph");
        assert!(!config.validate);
        assert!(!config.debug);
        assert!(config.colored);
        assert!(config.macros.is_empty());
    }

    #[test]
    fn test_nested_defaults() {
        let config = Config::default();

        assert_eq!(config.repl.prompt, "tensorlogic> ");
        assert_eq!(config.repl.history_file, ".tensorlogic_history");
        assert_eq!(config.repl.max_history, 1000);
        assert!(config.repl.auto_save);

        assert_eq!(config.watch.debounce_ms, 500);
        assert!(config.watch.clear_screen);
        assert!(config.watch.show_timestamps);

        assert!(config.cache.enabled);
        assert_eq!(config.cache.max_entries, 100);
        assert!(config.cache.disk_cache_enabled);
        assert_eq!(config.cache.disk_cache_max_size_mb, 500);
        assert!(config.cache.disk_cache_dir.is_none());
    }

    #[test]
    fn test_serialize_deserialize() {
        let config = Config::default();
        let toml_str = toml::to_string(&config).unwrap();
        let deserialized: Config = toml::from_str(&toml_str).unwrap();
        assert_same_settings(&config, &deserialized);
    }

    #[test]
    fn test_partial_config_falls_back_to_defaults() {
        let source = "strategy = \"hard\"\n\n[repl]\nprompt = \"> \"\n";
        let config: Config = toml::from_str(source).unwrap();

        // Explicit keys win...
        assert_eq!(config.strategy, "hard");
        assert_eq!(config.repl.prompt, "> ");
        // ...and everything else comes from the `Default` impls via `#[serde(default)]`.
        assert_eq!(config.output_format, "graph");
        assert_eq!(config.domains.get("D"), Some(&100));
        assert_eq!(config.repl.max_history, 1000);
        assert_eq!(config.watch.debounce_ms, 500);
        assert_eq!(config.cache.max_entries, 100);
    }

    #[test]
    fn test_save_and_load_roundtrip() {
        let path = std::env::temp_dir().join(format!(
            "tensorlogic_config_roundtrip_{}.toml",
            std::process::id()
        ));

        let mut config = Config::default();
        config.strategy = "hard".to_string();
        config.domains.insert("E".to_string(), 7);
        config.cache.max_entries = 42;

        config.save(&path).unwrap();
        let loaded = Config::load(&path);
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&path);

        assert_same_settings(&config, &loaded.unwrap());
    }

    #[test]
    fn test_load_missing_file_is_error() {
        let path = std::env::temp_dir().join(format!(
            "tensorlogic_config_missing_{}.toml",
            std::process::id()
        ));
        assert!(Config::load(&path).is_err());
    }
}