Skip to main content

tandem_core/
plugins.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use tokio::fs;
8use tokio::sync::RwLock;
9
10use crate::permissions::PermissionAction;
11
12#[derive(Debug, Clone, Serialize, Deserialize, Default)]
13pub struct PluginManifest {
14    pub name: String,
15    #[serde(default = "default_true")]
16    pub enabled: bool,
17    pub system_prompt_prefix: Option<String>,
18    pub system_prompt_suffix: Option<String>,
19    #[serde(default)]
20    pub allow_tools: Vec<String>,
21    #[serde(default)]
22    pub deny_tools: Vec<String>,
23    #[serde(default)]
24    pub shell_env: HashMap<String, String>,
25    pub tool_output_suffix: Option<String>,
26}
27
28#[derive(Clone)]
29pub struct PluginRegistry {
30    plugins: Arc<RwLock<Vec<PluginManifest>>>,
31}
32
33impl PluginRegistry {
34    pub async fn new(workspace_root: impl Into<PathBuf>) -> anyhow::Result<Self> {
35        let root: PathBuf = workspace_root.into();
36        let plugins = load_plugins(root.join(".tandem").join("plugins")).await?;
37        Ok(Self {
38            plugins: Arc::new(RwLock::new(plugins)),
39        })
40    }
41
42    pub async fn list(&self) -> Vec<PluginManifest> {
43        self.plugins.read().await.clone()
44    }
45
46    pub async fn transform_prompt(&self, prompt: String) -> String {
47        let plugins = self.plugins.read().await;
48        let mut transformed = prompt;
49        for plugin in plugins.iter().filter(|p| p.enabled) {
50            if let Some(prefix) = &plugin.system_prompt_prefix {
51                transformed = format!("{prefix}\n\n{transformed}");
52            }
53            if let Some(suffix) = &plugin.system_prompt_suffix {
54                transformed = format!("{transformed}\n\n{suffix}");
55            }
56        }
57        transformed
58    }
59
60    pub async fn permission_override(&self, tool_name: &str) -> Option<PermissionAction> {
61        let plugins = self.plugins.read().await;
62        let mut action = None;
63        for plugin in plugins.iter().filter(|p| p.enabled) {
64            if plugin.deny_tools.iter().any(|t| t == tool_name) {
65                action = Some(PermissionAction::Deny);
66            }
67            if plugin.allow_tools.iter().any(|t| t == tool_name) {
68                action = Some(PermissionAction::Allow);
69            }
70        }
71        action
72    }
73
74    pub async fn inject_tool_args(&self, tool_name: &str, mut args: Value) -> Value {
75        if tool_name != "bash" {
76            return args;
77        }
78
79        let plugins = self.plugins.read().await;
80        let mut merged_env = serde_json::Map::new();
81        for plugin in plugins.iter().filter(|p| p.enabled) {
82            for (k, v) in &plugin.shell_env {
83                merged_env.insert(k.clone(), Value::String(v.clone()));
84            }
85        }
86        if !merged_env.is_empty() {
87            args["env"] = Value::Object(merged_env);
88        }
89        args
90    }
91
92    pub async fn transform_tool_output(&self, output: String) -> String {
93        let plugins = self.plugins.read().await;
94        let mut transformed = output;
95        for plugin in plugins.iter().filter(|p| p.enabled) {
96            if let Some(suffix) = &plugin.tool_output_suffix {
97                transformed = format!("{transformed}\n{suffix}");
98            }
99        }
100        transformed
101    }
102}
103
/// Serde fallback for `PluginManifest::enabled`: a manifest that omits the
/// field is treated as enabled.
fn default_true() -> bool {
    true
}
107
108async fn load_plugins(dir: PathBuf) -> anyhow::Result<Vec<PluginManifest>> {
109    let mut out = Vec::new();
110    let mut entries = match fs::read_dir(&dir).await {
111        Ok(rd) => rd,
112        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(out),
113        Err(err) => return Err(err.into()),
114    };
115
116    while let Some(entry) = entries.next_entry().await? {
117        let path = entry.path();
118        let Some(ext) = path.extension().and_then(|v| v.to_str()) else {
119            continue;
120        };
121        if ext != "json" {
122            continue;
123        }
124        let raw = fs::read_to_string(&path).await?;
125        if let Ok(parsed) = serde_json::from_str::<PluginManifest>(&raw) {
126            out.push(parsed);
127        }
128    }
129    out.sort_by(|a, b| a.name.cmp(&b.name));
130    Ok(out)
131}