// tandem_core/config.rs

use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use anyhow::Context;
use serde::{Deserialize, Serialize};
use serde_json::{json, Map, Value};
use tokio::fs;
use tokio::sync::RwLock;

10#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11pub struct ProviderConfig {
12    pub api_key: Option<String>,
13    pub url: Option<String>,
14    pub default_model: Option<String>,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize, Default)]
18pub struct AppConfig {
19    #[serde(default)]
20    pub providers: HashMap<String, ProviderConfig>,
21    pub default_provider: Option<String>,
22}
23
24#[derive(Debug, Clone, Default)]
25struct ConfigLayers {
26    global: Value,
27    project: Value,
28    managed: Value,
29    env: Value,
30    runtime: Value,
31    cli: Value,
32}
33
34#[derive(Clone)]
35pub struct ConfigStore {
36    project_path: PathBuf,
37    global_path: PathBuf,
38    managed_path: PathBuf,
39    layers: Arc<RwLock<ConfigLayers>>,
40}
41
42impl ConfigStore {
43    pub async fn new(path: impl AsRef<Path>, cli_overrides: Option<Value>) -> anyhow::Result<Self> {
44        let project_path = path.as_ref().to_path_buf();
45        if let Some(parent) = project_path.parent() {
46            fs::create_dir_all(parent).await?;
47        }
48        let managed_path = project_path
49            .parent()
50            .unwrap_or_else(|| Path::new("."))
51            .join("managed_config.json");
52        let global_path = resolve_global_config_path().await?;
53
54        let mut global = read_json_file(&global_path)
55            .await
56            .unwrap_or_else(|_| empty_object());
57        let mut project = read_json_file(&project_path)
58            .await
59            .unwrap_or_else(|_| empty_object());
60        let mut managed = read_json_file(&managed_path)
61            .await
62            .unwrap_or_else(|_| empty_object());
63
64        scrub_persisted_secrets(&mut global, Some(&global_path)).await?;
65        scrub_persisted_secrets(&mut project, Some(&project_path)).await?;
66        scrub_persisted_secrets(&mut managed, Some(&managed_path)).await?;
67
68        let layers = ConfigLayers {
69            global,
70            project,
71            managed,
72            env: env_layer(),
73            runtime: empty_object(),
74            cli: cli_overrides.unwrap_or_else(empty_object),
75        };
76
77        let store = Self {
78            project_path,
79            global_path,
80            managed_path,
81            layers: Arc::new(RwLock::new(layers)),
82        };
83        store.save_project().await?;
84        store.save_global().await?;
85        Ok(store)
86    }
87
88    pub async fn get(&self) -> AppConfig {
89        let merged = self.get_effective_value().await;
90        serde_json::from_value(merged).unwrap_or_default()
91    }
92
93    pub async fn get_effective_value(&self) -> Value {
94        let layers = self.layers.read().await.clone();
95        let mut merged = empty_object();
96        deep_merge(&mut merged, &layers.global);
97        deep_merge(&mut merged, &layers.project);
98        deep_merge(&mut merged, &layers.managed);
99        deep_merge(&mut merged, &layers.env);
100        deep_merge(&mut merged, &layers.runtime);
101        deep_merge(&mut merged, &layers.cli);
102        merged
103    }
104
105    pub async fn get_project_value(&self) -> Value {
106        self.layers.read().await.project.clone()
107    }
108
109    pub async fn get_global_value(&self) -> Value {
110        self.layers.read().await.global.clone()
111    }
112
113    pub async fn get_layers_value(&self) -> Value {
114        let layers = self.layers.read().await;
115        json!({
116            "global": layers.global,
117            "project": layers.project,
118            "managed": layers.managed,
119            "env": layers.env,
120            "runtime": layers.runtime,
121            "cli": layers.cli
122        })
123    }
124
125    pub async fn set(&self, config: AppConfig) -> anyhow::Result<()> {
126        let value = serde_json::to_value(config)?;
127        self.set_project_value(value).await
128    }
129
130    pub async fn patch_project(&self, patch: Value) -> anyhow::Result<Value> {
131        {
132            let mut layers = self.layers.write().await;
133            deep_merge(&mut layers.project, &patch);
134        }
135        self.save_project().await?;
136        Ok(self.get_effective_value().await)
137    }
138
139    pub async fn patch_global(&self, patch: Value) -> anyhow::Result<Value> {
140        {
141            let mut layers = self.layers.write().await;
142            deep_merge(&mut layers.global, &patch);
143        }
144        self.save_global().await?;
145        Ok(self.get_effective_value().await)
146    }
147
148    pub async fn patch_runtime(&self, patch: Value) -> anyhow::Result<Value> {
149        {
150            let mut layers = self.layers.write().await;
151            deep_merge(&mut layers.runtime, &patch);
152        }
153        Ok(self.get_effective_value().await)
154    }
155
156    pub async fn replace_project_value(&self, value: Value) -> anyhow::Result<Value> {
157        self.set_project_value(value).await?;
158        Ok(self.get_effective_value().await)
159    }
160
161    pub async fn delete_runtime_provider_key(&self, provider_id: &str) -> anyhow::Result<Value> {
162        let provider = provider_id.trim().to_string();
163        {
164            let mut layers = self.layers.write().await;
165            let Some(root) = layers.runtime.as_object_mut() else {
166                return Ok(self.get_effective_value().await);
167            };
168            let Some(providers) = root.get_mut("providers").and_then(|v| v.as_object_mut()) else {
169                return Ok(self.get_effective_value().await);
170            };
171            let existing_key = providers
172                .keys()
173                .find(|k| k.eq_ignore_ascii_case(&provider))
174                .cloned();
175            let Some(existing_key) = existing_key else {
176                return Ok(self.get_effective_value().await);
177            };
178            let Some(cfg) = providers
179                .get_mut(&existing_key)
180                .and_then(|v| v.as_object_mut())
181            else {
182                return Ok(self.get_effective_value().await);
183            };
184            cfg.remove("api_key");
185            cfg.remove("apiKey");
186            if cfg.is_empty() {
187                providers.remove(&existing_key);
188            }
189        }
190        Ok(self.get_effective_value().await)
191    }
192
193    async fn set_project_value(&self, value: Value) -> anyhow::Result<()> {
194        self.layers.write().await.project = value;
195        self.save_project().await
196    }
197
198    async fn save_project(&self) -> anyhow::Result<()> {
199        let snapshot = self.layers.read().await.project.clone();
200        write_json_file(&self.project_path, &snapshot).await
201    }
202
203    async fn save_global(&self) -> anyhow::Result<()> {
204        let snapshot = self.layers.read().await.global.clone();
205        write_json_file(&self.global_path, &snapshot).await
206    }
207
208    #[allow(dead_code)]
209    async fn save_managed(&self) -> anyhow::Result<()> {
210        let snapshot = self.layers.read().await.managed.clone();
211        write_json_file(&self.managed_path, &snapshot).await
212    }
213}
214
215fn empty_object() -> Value {
216    Value::Object(Map::new())
217}
218
219async fn write_json_file(path: &Path, value: &Value) -> anyhow::Result<()> {
220    if let Some(parent) = path.parent() {
221        fs::create_dir_all(parent).await?;
222    }
223    let mut to_write = value.clone();
224    if !is_legacy_opencode_path(path) {
225        strip_persisted_secrets(&mut to_write);
226    }
227    let raw = serde_json::to_string_pretty(&to_write)?;
228    fs::write(path, raw).await?;
229    Ok(())
230}
231
232fn strip_persisted_secrets(value: &mut Value) {
233    if let Value::Object(root) = value {
234        if let Some(channels) = root.get_mut("channels").and_then(|v| v.as_object_mut()) {
235            for channel in ["telegram", "discord", "slack"] {
236                if let Some(cfg) = channels.get_mut(channel).and_then(|v| v.as_object_mut()) {
237                    cfg.remove("bot_token");
238                    cfg.remove("botToken");
239                }
240            }
241        }
242
243        let Some(providers) = root.get_mut("providers").and_then(|v| v.as_object_mut()) else {
244            return;
245        };
246        for (provider_id, provider_cfg) in providers.iter_mut() {
247            let Value::Object(cfg) = provider_cfg else {
248                continue;
249            };
250            if !cfg.contains_key("api_key") && !cfg.contains_key("apiKey") {
251                continue;
252            }
253            if provider_has_runtime_secret(provider_id) {
254                cfg.remove("api_key");
255                cfg.remove("apiKey");
256            }
257        }
258    }
259}
260
261async fn scrub_persisted_secrets(value: &mut Value, path: Option<&Path>) -> anyhow::Result<()> {
262    if let Some(target) = path {
263        if is_legacy_opencode_path(target) {
264            return Ok(());
265        }
266    }
267    let before = value.clone();
268    strip_persisted_secrets(value);
269    if *value != before {
270        if let Some(target) = path {
271            write_json_file(target, value).await?;
272        }
273    }
274    Ok(())
275}
276
/// Reports whether `path`, compared ASCII case-insensitively, contains `opencode`
/// anywhere in its (lossily decoded) text.
///
/// Such paths are treated as legacy OpenCode config files, and secrets are neither
/// scrubbed on load nor stripped on save.
/// NOTE(review): this is a plain substring test, so unrelated paths such as
/// `/home/opencoder/...` also match. Confirm that this is acceptable.
fn is_legacy_opencode_path(path: &Path) -> bool {
    let lowered = path.to_string_lossy().to_ascii_lowercase();
    lowered.contains("opencode")
}

283fn provider_has_runtime_secret(provider_id: &str) -> bool {
284    provider_env_candidates(provider_id).into_iter().any(|key| {
285        std::env::var(&key)
286            .map(|v| !v.trim().is_empty())
287            .unwrap_or(false)
288    })
289}
290
/// Lists the environment variable names that may hold an API key for `provider_id`.
///
/// The list always contains `<ID>_API_KEY`, where `<ID>` is the id upper-cased with every
/// non-ASCII-alphanumeric character replaced by `_`. Some well-known providers add a
/// conventional alias (e.g. `azure` → `AZURE_OPENAI_API_KEY`, `copilot` → `GITHUB_TOKEN`).
/// The result is sorted and de-duplicated.
fn provider_env_candidates(provider_id: &str) -> Vec<String> {
    let sanitized: String = provider_id
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() {
                c.to_ascii_uppercase()
            } else {
                '_'
            }
        })
        .collect();

    let alias = match provider_id.to_ascii_lowercase().as_str() {
        "openai" => Some("OPENAI_API_KEY"),
        "openrouter" => Some("OPENROUTER_API_KEY"),
        "anthropic" => Some("ANTHROPIC_API_KEY"),
        "groq" => Some("GROQ_API_KEY"),
        "mistral" => Some("MISTRAL_API_KEY"),
        "together" => Some("TOGETHER_API_KEY"),
        "azure" => Some("AZURE_OPENAI_API_KEY"),
        "vertex" => Some("VERTEX_API_KEY"),
        "bedrock" => Some("BEDROCK_API_KEY"),
        "copilot" => Some("GITHUB_TOKEN"),
        "cohere" => Some("COHERE_API_KEY"),
        "zen" | "opencode_zen" | "opencodezen" => Some("OPENCODE_ZEN_API_KEY"),
        _ => None,
    };

    let mut names: Vec<String> = std::iter::once(format!("{sanitized}_API_KEY"))
        .chain(alias.map(str::to_string))
        .collect();
    names.sort();
    names.dedup();
    names
}

321async fn read_json_file(path: &Path) -> anyhow::Result<Value> {
322    if !path.exists() {
323        return Ok(empty_object());
324    }
325    let raw = fs::read_to_string(path).await?;
326    Ok(serde_json::from_str::<Value>(&raw).unwrap_or_else(|_| empty_object()))
327}
328
329async fn resolve_global_config_path() -> anyhow::Result<PathBuf> {
330    if let Ok(path) = std::env::var("TANDEM_GLOBAL_CONFIG") {
331        let path = PathBuf::from(path);
332        if let Some(parent) = path.parent() {
333            fs::create_dir_all(parent).await?;
334        }
335        return Ok(path);
336    }
337    if let Some(config_dir) = dirs::config_dir() {
338        let path = config_dir.join("tandem").join("config.json");
339        if let Some(parent) = path.parent() {
340            fs::create_dir_all(parent).await?;
341        }
342        return Ok(path);
343    }
344    Ok(PathBuf::from(".tandem/global_config.json"))
345}
346
347fn env_layer() -> Value {
348    let mut root = empty_object();
349
350    if let Ok(enabled) = std::env::var("TANDEM_WEB_UI") {
351        if let Some(v) = parse_bool_like(&enabled) {
352            deep_merge(&mut root, &json!({ "web_ui": { "enabled": v } }));
353        }
354    }
355    if let Ok(prefix) = std::env::var("TANDEM_WEB_UI_PREFIX") {
356        if !prefix.trim().is_empty() {
357            deep_merge(&mut root, &json!({ "web_ui": { "path_prefix": prefix } }));
358        }
359    }
360    if let Ok(token) = std::env::var("TANDEM_TELEGRAM_BOT_TOKEN") {
361        if !token.trim().is_empty() {
362            let allowed = std::env::var("TANDEM_TELEGRAM_ALLOWED_USERS")
363                .map(|s| parse_csv(&s))
364                .unwrap_or_else(|_| vec!["*".to_string()]);
365            let mention_only = std::env::var("TANDEM_TELEGRAM_MENTION_ONLY")
366                .ok()
367                .and_then(|v| parse_bool_like(&v))
368                .unwrap_or(false);
369            deep_merge(
370                &mut root,
371                &json!({
372                    "channels": {
373                        "telegram": {
374                            "bot_token": token,
375                            "allowed_users": allowed,
376                            "mention_only": mention_only
377                        }
378                    }
379                }),
380            );
381        }
382    }
383    if let Ok(token) = std::env::var("TANDEM_DISCORD_BOT_TOKEN") {
384        if !token.trim().is_empty() {
385            let allowed = std::env::var("TANDEM_DISCORD_ALLOWED_USERS")
386                .map(|s| parse_csv(&s))
387                .unwrap_or_else(|_| vec!["*".to_string()]);
388            let mention_only = std::env::var("TANDEM_DISCORD_MENTION_ONLY")
389                .ok()
390                .and_then(|v| parse_bool_like(&v))
391                .unwrap_or(true);
392            let guild_id = std::env::var("TANDEM_DISCORD_GUILD_ID").ok();
393            deep_merge(
394                &mut root,
395                &json!({
396                    "channels": {
397                        "discord": {
398                            "bot_token": token,
399                            "guild_id": guild_id,
400                            "allowed_users": allowed,
401                            "mention_only": mention_only
402                        }
403                    }
404                }),
405            );
406        }
407    }
408    if let Ok(token) = std::env::var("TANDEM_SLACK_BOT_TOKEN") {
409        if !token.trim().is_empty() {
410            if let Ok(channel_id) = std::env::var("TANDEM_SLACK_CHANNEL_ID") {
411                if !channel_id.trim().is_empty() {
412                    let allowed = std::env::var("TANDEM_SLACK_ALLOWED_USERS")
413                        .map(|s| parse_csv(&s))
414                        .unwrap_or_else(|_| vec!["*".to_string()]);
415                    deep_merge(
416                        &mut root,
417                        &json!({
418                            "channels": {
419                                "slack": {
420                                    "bot_token": token,
421                                    "channel_id": channel_id,
422                                    "allowed_users": allowed
423                                }
424                            }
425                        }),
426                    );
427                }
428            }
429        }
430    }
431
432    if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
433        deep_merge(
434            &mut root,
435            &json!({
436                "providers": {
437                    "openai": {
438                        "api_key": api_key,
439                        "url": "https://api.openai.com/v1",
440                        "default_model": "gpt-4o-mini"
441                    }
442                }
443            }),
444        );
445    }
446    add_openai_env(
447        &mut root,
448        "openrouter",
449        "OPENROUTER_API_KEY",
450        "https://openrouter.ai/api/v1",
451        "openai/gpt-4o-mini",
452    );
453    add_openai_env(
454        &mut root,
455        "groq",
456        "GROQ_API_KEY",
457        "https://api.groq.com/openai/v1",
458        "llama-3.1-8b-instant",
459    );
460    add_openai_env(
461        &mut root,
462        "mistral",
463        "MISTRAL_API_KEY",
464        "https://api.mistral.ai/v1",
465        "mistral-small-latest",
466    );
467    add_openai_env(
468        &mut root,
469        "together",
470        "TOGETHER_API_KEY",
471        "https://api.together.xyz/v1",
472        "meta-llama/Llama-3.1-8B-Instruct-Turbo",
473    );
474    add_openai_env(
475        &mut root,
476        "azure",
477        "AZURE_OPENAI_API_KEY",
478        "https://example.openai.azure.com/openai/deployments/default",
479        "gpt-4o-mini",
480    );
481    add_openai_env(
482        &mut root,
483        "vertex",
484        "VERTEX_API_KEY",
485        "https://aiplatform.googleapis.com/v1",
486        "gemini-1.5-flash",
487    );
488    add_openai_env(
489        &mut root,
490        "bedrock",
491        "BEDROCK_API_KEY",
492        "https://bedrock-runtime.us-east-1.amazonaws.com",
493        "anthropic.claude-3-5-sonnet-20240620-v1:0",
494    );
495    add_openai_env(
496        &mut root,
497        "copilot",
498        "GITHUB_TOKEN",
499        "https://api.githubcopilot.com",
500        "gpt-4o-mini",
501    );
502    add_openai_env(
503        &mut root,
504        "cohere",
505        "COHERE_API_KEY",
506        "https://api.cohere.com/v2",
507        "command-r-plus",
508    );
509    if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
510        deep_merge(
511            &mut root,
512            &json!({
513                "providers": {
514                    "anthropic": {
515                        "api_key": api_key,
516                        "url": "https://api.anthropic.com/v1",
517                        "default_model": "claude-3-5-sonnet-latest"
518                    }
519                }
520            }),
521        );
522    }
523    if let Ok(ollama_url) = std::env::var("OLLAMA_URL") {
524        deep_merge(
525            &mut root,
526            &json!({
527                "providers": {
528                    "ollama": {
529                        "url": ollama_url,
530                        "default_model": "llama3.1:8b"
531                    }
532                }
533            }),
534        );
535    } else if std::net::TcpStream::connect("127.0.0.1:11434").is_ok() {
536        deep_merge(
537            &mut root,
538            &json!({
539                "providers": {
540                    "ollama": {
541                        "url": "http://127.0.0.1:11434/v1",
542                        "default_model": "llama3.1:8b"
543                    }
544                }
545            }),
546        );
547    }
548
549    root
550}
551
/// Interprets common textual booleans, ignoring surrounding whitespace and ASCII case.
///
/// Returns `Some(true)` for `1`/`true`/`yes`/`on`, `Some(false)` for
/// `0`/`false`/`no`/`off`, and `None` for anything else (including the empty string).
fn parse_bool_like(raw: &str) -> Option<bool> {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];

    let normalized = raw.trim().to_ascii_lowercase();
    if TRUTHY.contains(&normalized.as_str()) {
        Some(true)
    } else if FALSY.contains(&normalized.as_str()) {
        Some(false)
    } else {
        None
    }
}

/// Splits a comma-separated list, trimming entries and dropping empty ones.
///
/// A value that is exactly `*` after trimming yields the wildcard list `["*"]`.
fn parse_csv(raw: &str) -> Vec<String> {
    match raw.trim() {
        "*" => vec!["*".to_string()],
        _ => raw
            .split(',')
            .map(str::trim)
            .filter(|item| !item.is_empty())
            .map(String::from)
            .collect(),
    }
}

570fn add_openai_env(root: &mut Value, provider: &str, key_env: &str, default_url: &str, model: &str) {
571    if let Ok(api_key) = std::env::var(key_env) {
572        deep_merge(
573            root,
574            &json!({
575                "providers": {
576                    provider: {
577                        "api_key": api_key,
578                        "url": default_url,
579                        "default_model": model
580                    }
581                }
582            }),
583        );
584    }
585}
586
587fn deep_merge(base: &mut Value, overlay: &Value) {
588    if overlay.is_null() {
589        return;
590    }
591    match (base, overlay) {
592        (Value::Object(base_map), Value::Object(overlay_map)) => {
593            for (key, value) in overlay_map {
594                if value.is_null() {
595                    continue;
596                }
597                match base_map.get_mut(key) {
598                    Some(existing) => deep_merge(existing, value),
599                    None => {
600                        base_map.insert(key.clone(), value.clone());
601                    }
602                }
603            }
604        }
605        (base_value, overlay_value) => {
606            *base_value = overlay_value.clone();
607        }
608    }
609}
610
611impl From<ProviderConfig> for tandem_providers::ProviderConfig {
612    fn from(value: ProviderConfig) -> Self {
613        Self {
614            api_key: value.api_key,
615            url: value.url,
616            default_model: value.default_model,
617        }
618    }
619}
620
621impl From<AppConfig> for tandem_providers::AppConfig {
622    fn from(value: AppConfig) -> Self {
623        Self {
624            providers: value
625                .providers
626                .into_iter()
627                .map(|(k, v)| (k, v.into()))
628                .collect(),
629            default_provider: value.default_provider,
630        }
631    }
632}
633
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Returns a path in the system temp dir that is unique per test name, process and
    /// call time. The file itself is not created.
    fn unique_temp_file(name: &str) -> PathBuf {
        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        std::env::temp_dir().join(format!(
            "tandem-core-config-{name}-{}-{ts}.json",
            std::process::id()
        ))
    }

    #[test]
    fn strip_persisted_secrets_removes_channel_bot_tokens() {
        let mut value = json!({
            "channels": {
                "telegram": {
                    "bot_token": "tg-secret",
                    "allowed_users": ["*"]
                },
                "discord": {
                    "botToken": "dc-secret",
                    "allowed_users": ["*"],
                    "mention_only": true
                },
                "slack": {
                    "bot_token": "sl-secret",
                    "channel_id": "C123"
                }
            },
            "providers": {}
        });

        strip_persisted_secrets(&mut value);

        let channels = &value["channels"];
        let telegram = channels["telegram"].as_object().expect("telegram object");
        assert!(!telegram.contains_key("bot_token"));
        assert_eq!(telegram.get("allowed_users"), Some(&json!(["*"])));
        let discord = channels["discord"].as_object().expect("discord object");
        assert!(!discord.contains_key("botToken"));
        assert_eq!(discord.get("mention_only"), Some(&json!(true)));
        let slack = channels["slack"].as_object().expect("slack object");
        assert!(!slack.contains_key("bot_token"));
        assert_eq!(slack.get("channel_id"), Some(&json!("C123")));
    }

    #[tokio::test]
    async fn scrub_persisted_secrets_rewrites_channel_tokens_on_disk() {
        let path = unique_temp_file("scrub");
        let original = json!({
            "channels": {
                "telegram": {
                    "bot_token": "tg-secret",
                    "allowed_users": ["@alice"]
                }
            },
            "providers": {}
        });
        let raw = serde_json::to_string_pretty(&original).expect("serialize");
        fs::write(&path, raw).await.expect("write");

        let mut loaded =
            serde_json::from_str::<Value>(&fs::read_to_string(&path).await.expect("read before"))
                .expect("parse");

        let scrubbed = scrub_persisted_secrets(&mut loaded, Some(&path)).await;
        let persisted = fs::read_to_string(&path).await;
        // Clean up before asserting so a failing assertion does not leak the temp file.
        let _ = fs::remove_file(&path).await;

        scrubbed.expect("scrub");
        let persisted =
            serde_json::from_str::<Value>(&persisted.expect("read after")).expect("parse persisted");
        for doc in [&loaded, &persisted] {
            let telegram = doc["channels"]["telegram"]
                .as_object()
                .expect("telegram object");
            assert!(!telegram.contains_key("bot_token"));
            assert_eq!(telegram.get("allowed_users"), Some(&json!(["@alice"])));
        }
    }

    #[test]
    fn deep_merge_overrides_recursively_and_ignores_nulls() {
        let mut base = json!({
            "a": { "x": 1, "y": [1, 2] },
            "b": "keep"
        });
        deep_merge(
            &mut base,
            &json!({
                "a": { "x": 2, "y": [3], "z": true },
                "b": null,
                "c": 5
            }),
        );
        assert_eq!(
            base,
            json!({
                "a": { "x": 2, "y": [3], "z": true },
                "b": "keep",
                "c": 5
            })
        );

        let mut scalar = json!(1);
        deep_merge(&mut scalar, &Value::Null);
        assert_eq!(scalar, json!(1));
        deep_merge(&mut scalar, &json!({ "k": "v" }));
        assert_eq!(scalar, json!({ "k": "v" }));
    }

    #[test]
    fn parse_helpers_accept_common_spellings() {
        assert_eq!(parse_bool_like(" Yes "), Some(true));
        assert_eq!(parse_bool_like("0"), Some(false));
        assert_eq!(parse_bool_like("maybe"), None);
        assert_eq!(parse_csv(" * "), vec!["*"]);
        assert_eq!(parse_csv("alice, bob,,"), vec!["alice", "bob"]);
        assert!(parse_csv("").is_empty());
    }

    #[test]
    fn provider_env_candidates_include_well_known_aliases() {
        assert_eq!(provider_env_candidates("openai"), vec!["OPENAI_API_KEY"]);
        assert_eq!(
            provider_env_candidates("azure"),
            vec!["AZURE_API_KEY", "AZURE_OPENAI_API_KEY"]
        );
        assert_eq!(
            provider_env_candidates("copilot"),
            vec!["COPILOT_API_KEY", "GITHUB_TOKEN"]
        );
        assert_eq!(
            provider_env_candidates("my-provider"),
            vec!["MY_PROVIDER_API_KEY"]
        );
    }

    #[tokio::test]
    async fn delete_runtime_provider_key_removes_only_the_key() {
        // Built directly (no env layer, no disk I/O) so the test is hermetic.
        let store = ConfigStore {
            project_path: unique_temp_file("project"),
            global_path: unique_temp_file("global"),
            managed_path: unique_temp_file("managed"),
            layers: Arc::new(RwLock::new(ConfigLayers::default())),
        };
        store
            .patch_runtime(json!({
                "providers": {
                    "OpenAI": { "api_key": "sk-test", "url": "https://example.test/v1" },
                    "groq": { "apiKey": "gk-test" }
                }
            }))
            .await
            .expect("patch runtime");

        // Ids are trimmed and matched case-insensitively; other fields survive.
        let effective = store
            .delete_runtime_provider_key(" openai ")
            .await
            .expect("delete openai");
        let openai = effective["providers"]["OpenAI"]
            .as_object()
            .expect("openai entry kept");
        assert!(!openai.contains_key("api_key"));
        assert_eq!(openai.get("url"), Some(&json!("https://example.test/v1")));

        // An entry left empty after removing the key is dropped entirely.
        let effective = store
            .delete_runtime_provider_key("groq")
            .await
            .expect("delete groq");
        assert!(effective["providers"].get("groq").is_none());
    }
}