1use std::path::{Path, PathBuf};
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct SourceConfig {
8 #[serde(default = "default_true")]
9 pub enabled: bool,
10 #[serde(default = "default_timeout")]
11 pub timeout: f64,
12 #[serde(default = "default_max_retries")]
13 pub max_retries: u32,
14 #[serde(default)]
15 pub api_key: String,
16}
17
18impl Default for SourceConfig {
19 fn default() -> Self {
20 Self {
21 enabled: true,
22 timeout: 30.0,
23 max_retries: 3,
24 api_key: String::new(),
25 }
26 }
27}
28
/// Serde default for boolean switches that are on unless the config turns them off.
fn default_true() -> bool {
    const ON: bool = true;
    ON
}
32
/// Serde default for a source's `timeout` field (presumably seconds — TODO confirm).
fn default_timeout() -> f64 {
    const TIMEOUT: f64 = 30.0;
    TIMEOUT
}
36
/// Serde default for a source's `max_retries` field.
fn default_max_retries() -> u32 {
    const RETRIES: u32 = 3;
    RETRIES
}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct EpoConfig {
44 #[serde(default = "default_true")]
45 pub enabled: bool,
46 #[serde(default = "default_timeout")]
47 pub timeout: f64,
48 #[serde(default = "default_max_retries")]
49 pub max_retries: u32,
50 #[serde(default)]
51 pub consumer_key: String,
52 #[serde(default)]
53 pub consumer_secret: String,
54}
55
56impl Default for EpoConfig {
57 fn default() -> Self {
58 Self {
59 enabled: true,
60 timeout: 30.0,
61 max_retries: 3,
62 consumer_key: String::new(),
63 consumer_secret: String::new(),
64 }
65 }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct ChatConfig {
71 #[serde(default = "default_model")]
72 pub model: String,
73 #[serde(default = "default_max_tokens")]
74 pub max_tokens: u32,
75 #[serde(default = "default_scoring_concurrency")]
76 pub scoring_concurrency: u32,
77}
78
79impl Default for ChatConfig {
80 fn default() -> Self {
81 Self {
82 model: "claude-sonnet-4-6".to_string(),
83 max_tokens: 4096,
84 scoring_concurrency: 5,
85 }
86 }
87}
88
/// Serde default for `ChatConfig::model`.
fn default_model() -> String {
    String::from("claude-sonnet-4-6")
}
92
/// Serde default for `ChatConfig::max_tokens`.
fn default_max_tokens() -> u32 {
    const MAX_TOKENS: u32 = 4_096;
    MAX_TOKENS
}
96
/// Serde default for `ChatConfig::scoring_concurrency`.
fn default_scoring_concurrency() -> u32 {
    const CONCURRENCY: u32 = 5;
    CONCURRENCY
}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct UiConfig {
104 #[serde(default = "default_true")]
108 pub show_institutional_hint: bool,
109}
110
111impl Default for UiConfig {
112 fn default() -> Self {
113 Self {
114 show_institutional_hint: true,
115 }
116 }
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct Config {
122 pub db_path: PathBuf,
123 #[serde(default = "default_sources")]
124 pub default_sources: Vec<String>,
125 #[serde(default)]
126 pub pubmed: SourceConfig,
127 #[serde(default)]
128 pub arxiv: SourceConfig,
129 #[serde(default)]
130 pub openalex: SourceConfig,
131 #[serde(default)]
132 pub inspire: SourceConfig,
133 #[serde(default)]
134 pub patentsview: SourceConfig,
135 #[serde(default)]
136 pub lens: SourceConfig,
137 #[serde(default)]
138 pub epo: EpoConfig,
139 #[serde(default)]
140 pub chat: ChatConfig,
141 #[serde(default)]
142 pub ui: UiConfig,
143}
144
/// Serde default for `Config::default_sources`.
///
/// Returns `pubmed`, `arxiv`, `openalex` and `inspire`, in that order.
fn default_sources() -> Vec<String> {
    ["pubmed", "arxiv", "openalex", "inspire"]
        .iter()
        .map(|name| (*name).to_owned())
        .collect()
}
153
154impl Config {
155 pub fn papers_dir(&self) -> PathBuf {
157 self.db_path
158 .parent()
159 .unwrap_or_else(|| Path::new("."))
160 .join("papers")
161 }
162}
163
164impl Default for Config {
165 fn default() -> Self {
166 let workspace = find_workspace_root().unwrap_or_else(|| std::env::current_dir().unwrap());
167 Self {
168 db_path: default_db_path(&workspace),
169 default_sources: default_sources(),
170 pubmed: SourceConfig::default(),
171 arxiv: SourceConfig::default(),
172 openalex: SourceConfig::default(),
173 inspire: SourceConfig::default(),
174 patentsview: SourceConfig::default(),
175 lens: SourceConfig::default(),
176 epo: EpoConfig::default(),
177 chat: ChatConfig::default(),
178 ui: UiConfig::default(),
179 }
180 }
181}
182
/// Asks git for the top-level directory of the repository containing the current directory.
///
/// Returns `None` in any of these cases:
/// - git cannot be spawned;
/// - git exits unsuccessfully (for example outside a repository);
/// - git's output is not valid UTF-8.
///
/// Surrounding whitespace, such as the trailing newline, is trimmed from the path.
fn find_workspace_root() -> Option<PathBuf> {
    let result = std::process::Command::new("git")
        .arg("rev-parse")
        .arg("--show-toplevel")
        .output()
        .ok()?;

    if !result.status.success() {
        return None;
    }

    String::from_utf8(result.stdout)
        .ok()
        .map(|toplevel| PathBuf::from(toplevel.trim()))
}
196
/// Resolves where the database file lives.
///
/// A non-empty `SCITADEL_DB` takes precedence, with home-directory expansion via
/// `expand_home`. An empty value is treated like an unset variable, because an
/// empty path is not a usable database location. Otherwise the database is
/// `<workspace>/.scitadel/scitadel.db`.
fn default_db_path(workspace: &Path) -> PathBuf {
    match std::env::var("SCITADEL_DB") {
        Ok(db) if !db.is_empty() => PathBuf::from(expand_home(db)),
        _ => workspace.join(".scitadel").join("scitadel.db"),
    }
}

/// Expands a leading `~` to `$HOME`, but only in the forms `~` and `~/…`.
///
/// `~user/…` is returned unchanged: this code cannot look up another user's home
/// directory, and substituting `$HOME` would point at the wrong place. If `HOME`
/// is unset, the path is also returned unchanged.
fn expand_home(path: String) -> String {
    let home_relative = match path.strip_prefix('~') {
        Some(rest) => rest.is_empty() || rest.starts_with(std::path::is_separator),
        None => false,
    };
    if home_relative {
        if let Ok(home) = std::env::var("HOME") {
            // `~` is a single byte, so slicing at 1 stays on a char boundary.
            return format!("{home}{}", &path[1..]);
        }
    }
    path
}
215
216pub fn load_config() -> Config {
220 use crate::credentials::resolve;
221
222 let workspace = find_workspace_root().unwrap_or_else(|| std::env::current_dir().unwrap());
223 let db_path = default_db_path(&workspace);
224
225 let config_path = workspace.join(".scitadel").join("config.toml");
227 let mut config: Config = std::fs::read_to_string(&config_path)
228 .ok()
229 .and_then(|contents| toml::from_str(&contents).ok())
230 .unwrap_or_default();
231
232 config.db_path = db_path;
233
234 config.pubmed.api_key = resolve(
236 "pubmed.api_key",
237 "SCITADEL_PUBMED_API_KEY",
238 &config.pubmed.api_key,
239 )
240 .unwrap_or_default();
241
242 config.openalex.api_key = resolve(
243 "openalex.email",
244 "SCITADEL_OPENALEX_EMAIL",
245 &config.openalex.api_key,
246 )
247 .unwrap_or_default();
248
249 config.patentsview.api_key = resolve(
250 "patentsview.api_key",
251 "SCITADEL_PATENTSVIEW_KEY",
252 &config.patentsview.api_key,
253 )
254 .unwrap_or_default();
255
256 config.lens.api_key = resolve(
257 "lens.api_token",
258 "SCITADEL_LENS_TOKEN",
259 &config.lens.api_key,
260 )
261 .unwrap_or_default();
262
263 config.epo.consumer_key = resolve(
264 "epo.consumer_key",
265 "SCITADEL_EPO_KEY",
266 &config.epo.consumer_key,
267 )
268 .unwrap_or_default();
269
270 config.epo.consumer_secret = resolve(
271 "epo.consumer_secret",
272 "SCITADEL_EPO_SECRET",
273 &config.epo.consumer_secret,
274 )
275 .unwrap_or_default();
276
277 if let Ok(model) = std::env::var("SCITADEL_CHAT_MODEL") {
279 config.chat.model = model;
280 }
281 if let Ok(tokens) = std::env::var("SCITADEL_CHAT_MAX_TOKENS")
282 && let Ok(v) = tokens.parse()
283 {
284 config.chat.max_tokens = v;
285 }
286 if let Ok(conc) = std::env::var("SCITADEL_SCORING_CONCURRENCY")
287 && let Ok(v) = conc.parse()
288 {
289 config.chat.scoring_concurrency = v;
290 }
291
292 config
293}
294
295pub fn load_config_from(path: &Path) -> Result<Config, crate::error::CoreError> {
297 let contents = std::fs::read_to_string(path)?;
298 toml::from_str(&contents).map_err(|e| crate::error::CoreError::Config(e.to_string()))
299}