Skip to main content

typesec_agent/
tool.rs

1//! Capability-bound tool wrappers for agent and MCP-style tool execution.
2
3use std::any::Any;
4use std::collections::HashMap;
5use std::marker::PhantomData;
6use std::{future::Future, pin::Pin};
7
8use typesec_core::{Capability, Permission, Resource, typestate::Authenticated};
9
10use crate::{SecureAgent, executor::TaskError};
11
12/// Boxed future returned by protected tool handlers.
13pub type ToolFuture<'a> = Pin<Box<dyn Future<Output = Result<(), TaskError>> + Send + 'a>>;
14
/// Authorization metadata for a protected tool: what it is called, what it
/// does, and which permission/resource pair a caller must hold to run it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ToolSpec {
    /// Name under which the tool is exposed to an agent or MCP client.
    pub name: String,
    /// Free-form, human-readable summary of the tool.
    pub description: String,
    /// Name of the permission a caller's capability must carry.
    pub required_permission: &'static str,
    /// Identifier of the resource that permission must apply to.
    pub resource_id: String,
}
27
28/// A tool that cannot run unless the caller supplies a matching capability.
29pub struct ProtectedTool<P, R, F>
30where
31    P: Permission,
32    R: Resource,
33{
34    spec: ToolSpec,
35    resource: R,
36    action: F,
37    _permission: PhantomData<fn() -> P>,
38}
39
40impl<P, R, F> ProtectedTool<P, R, F>
41where
42    P: Permission,
43    R: Resource,
44{
45    /// Create a new protected tool.
46    pub fn new(
47        name: impl Into<String>,
48        description: impl Into<String>,
49        resource: R,
50        action: F,
51    ) -> Self {
52        let resource_id = resource.resource_id().to_string();
53        Self {
54            spec: ToolSpec {
55                name: name.into(),
56                description: description.into(),
57                required_permission: P::name(),
58                resource_id,
59            },
60            resource,
61            action,
62            _permission: PhantomData,
63        }
64    }
65
66    /// Return this tool's authorization metadata.
67    pub fn spec(&self) -> &ToolSpec {
68        &self.spec
69    }
70}
71
72impl<P, R, F> ProtectedTool<P, R, F>
73where
74    P: Permission,
75    R: Resource,
76    F: for<'a> Fn(&'a R) -> ToolFuture<'a>,
77{
78    /// Invoke the tool with a typed capability.
79    pub async fn invoke(
80        &self,
81        agent: &SecureAgent<Authenticated>,
82        cap: &Capability<P, R>,
83    ) -> Result<(), TaskError> {
84        if cap.subject() != agent.subject() {
85            return Err(TaskError::CapabilityMismatch(format!(
86                "capability was minted for subject '{}', not '{}'",
87                cap.subject(),
88                agent.subject()
89            )));
90        }
91        if cap.resource_id() != self.resource.resource_id() {
92            return Err(TaskError::CapabilityMismatch(format!(
93                "capability covers resource '{}', not '{}'",
94                cap.resource_id(),
95                self.resource.resource_id()
96            )));
97        }
98        cap.ensure_active()?;
99
100        tracing::info!(
101            subject = %agent.subject(),
102            permission = %Capability::<P, R>::permission_name(),
103            resource = %cap.resource_id(),
104            tool = %self.spec.name,
105            "invoking protected tool"
106        );
107        (self.action)(&self.resource).await
108    }
109}
110
111trait ErasedTool: Send + Sync {
112    fn spec(&self) -> &ToolSpec;
113
114    fn invoke_erased<'a>(
115        &'a self,
116        agent: &'a SecureAgent<Authenticated>,
117        cap: &'a (dyn Any + Send + Sync),
118    ) -> ToolFuture<'a>;
119}
120
121impl<P, R, F> ErasedTool for ProtectedTool<P, R, F>
122where
123    P: Permission + 'static,
124    R: Resource + 'static,
125    F: for<'a> Fn(&'a R) -> ToolFuture<'a> + Send + Sync + 'static,
126{
127    fn spec(&self) -> &ToolSpec {
128        &self.spec
129    }
130
131    fn invoke_erased<'a>(
132        &'a self,
133        agent: &'a SecureAgent<Authenticated>,
134        cap: &'a (dyn Any + Send + Sync),
135    ) -> ToolFuture<'a> {
136        let Some(cap) = cap.downcast_ref::<Capability<P, R>>() else {
137            return Box::pin(async move {
138                Err(TaskError::CapabilityMismatch(format!(
139                    "tool '{}' requires Capability<{}, {}>",
140                    self.spec.name,
141                    P::name(),
142                    R::resource_type()
143                )))
144            });
145        };
146
147        Box::pin(async move { self.invoke(agent, cap).await })
148    }
149}
150
151/// Registry for named capability-protected tools.
152#[derive(Default)]
153pub struct ToolRegistry {
154    tools: HashMap<String, Box<dyn ErasedTool>>,
155}
156
157impl ToolRegistry {
158    /// Create an empty registry.
159    pub fn new() -> Self {
160        Self::default()
161    }
162
163    /// Register a protected tool by its exposed name.
164    ///
165    /// Registering another tool with the same name replaces the previous one.
166    pub fn register<P, R, F>(&mut self, tool: ProtectedTool<P, R, F>)
167    where
168        P: Permission + 'static,
169        R: Resource + 'static,
170        F: for<'a> Fn(&'a R) -> ToolFuture<'a> + Send + Sync + 'static,
171    {
172        self.tools.insert(tool.spec.name.clone(), Box::new(tool));
173    }
174
175    /// Return metadata for every registered tool.
176    pub fn list_specs(&self) -> Vec<ToolSpec> {
177        self.tools
178            .values()
179            .map(|tool| tool.spec().clone())
180            .collect()
181    }
182
183    /// Return metadata for one registered tool.
184    pub fn spec(&self, name: &str) -> Option<&ToolSpec> {
185        self.tools.get(name).map(|tool| tool.spec())
186    }
187
188    /// Invoke a named tool with an erased capability.
189    pub async fn invoke(
190        &self,
191        name: &str,
192        agent: &SecureAgent<Authenticated>,
193        cap: &(dyn Any + Send + Sync),
194    ) -> Result<(), TaskError> {
195        let tool = self
196            .tools
197            .get(name)
198            .ok_or_else(|| TaskError::UnknownTool(name.to_owned()))?;
199        tool.invoke_erased(agent, cap).await
200    }
201}
202
203#[cfg(test)]
204mod tests {
205    use super::*;
206    use std::sync::Arc;
207    use typesec_core::{
208        CanExecute, CanRead, ResourceId, SubjectId,
209        policy::{PolicyEngine, PolicyResult},
210        resource::GenericResource,
211        typestate::Credentials,
212    };
213
214    struct AllowAll;
215    impl PolicyEngine for AllowAll {
216        fn check(&self, _: &SubjectId, _: &str, _: &ResourceId) -> PolicyResult {
217            PolicyResult::Allow
218        }
219    }
220
221    fn no_op_tool(_resource: &GenericResource) -> ToolFuture<'_> {
222        Box::pin(async { Ok(()) })
223    }
224
225    #[tokio::test]
226    async fn protected_tool_invokes_with_capability() {
227        let agent = SecureAgent::new(Arc::new(AllowAll))
228            .authenticate_unverified(Credentials::new("agent:test", "tok"))
229            .expect("auth ok");
230        let resource = GenericResource::new("Gmail.ListEmails", "tool");
231        let cap: Capability<CanExecute, GenericResource> =
232            agent.request_capability(&resource).await.expect("cap ok");
233        let tool = ProtectedTool::<CanExecute, _, _>::new(
234            "gmail.list",
235            "List email messages",
236            resource,
237            no_op_tool,
238        );
239
240        assert_eq!(tool.spec().required_permission, "execute");
241        tool.invoke(&agent, &cap).await.expect("tool should run");
242    }
243
244    #[tokio::test]
245    async fn protected_tool_rejects_capability_for_other_resource() {
246        let agent = SecureAgent::new(Arc::new(AllowAll))
247            .authenticate_unverified(Credentials::new("agent:test", "tok"))
248            .expect("auth ok");
249        let cap_resource = GenericResource::new("Gmail.ListEmails", "tool");
250        let tool_resource = GenericResource::new("Gmail.DeleteEmail", "tool");
251        let cap: Capability<CanExecute, GenericResource> = agent
252            .request_capability(&cap_resource)
253            .await
254            .expect("cap ok");
255        let tool = ProtectedTool::<CanExecute, _, _>::new(
256            "gmail.delete",
257            "Delete an email message",
258            tool_resource,
259            no_op_tool,
260        );
261
262        assert!(matches!(
263            tool.invoke(&agent, &cap).await,
264            Err(TaskError::CapabilityMismatch(reason)) if reason.contains("Gmail.ListEmails")
265        ));
266    }
267
268    #[tokio::test]
269    async fn registry_lists_and_invokes_registered_tools() {
270        let agent = SecureAgent::new(Arc::new(AllowAll))
271            .authenticate_unverified(Credentials::new("agent:test", "tok"))
272            .expect("auth ok");
273        let resource = GenericResource::new("Gmail.ListEmails", "tool");
274        let cap: Capability<CanExecute, GenericResource> =
275            agent.request_capability(&resource).await.expect("cap ok");
276        let mut registry = ToolRegistry::new();
277        registry.register(ProtectedTool::<CanExecute, _, _>::new(
278            "gmail.list",
279            "List email messages",
280            resource,
281            no_op_tool,
282        ));
283
284        let specs = registry.list_specs();
285        assert_eq!(specs.len(), 1);
286        assert_eq!(specs[0].name, "gmail.list");
287        assert_eq!(
288            registry
289                .spec("gmail.list")
290                .expect("registered spec")
291                .required_permission,
292            "execute"
293        );
294
295        registry
296            .invoke("gmail.list", &agent, &cap)
297            .await
298            .expect("registry invoke should run");
299    }
300
301    #[tokio::test]
302    async fn registry_rejects_wrong_capability_type_and_unknown_tool() {
303        let agent = SecureAgent::new(Arc::new(AllowAll))
304            .authenticate_unverified(Credentials::new("agent:test", "tok"))
305            .expect("auth ok");
306        let resource = GenericResource::new("Gmail.ListEmails", "tool");
307        let read_cap: Capability<CanRead, GenericResource> =
308            agent.request_capability(&resource).await.expect("cap ok");
309        let mut registry = ToolRegistry::new();
310        registry.register(ProtectedTool::<CanExecute, _, _>::new(
311            "gmail.list",
312            "List email messages",
313            resource,
314            no_op_tool,
315        ));
316
317        assert!(matches!(
318            registry.invoke("gmail.list", &agent, &read_cap).await,
319            Err(TaskError::CapabilityMismatch(reason)) if reason.contains("Capability<execute")
320        ));
321        assert!(matches!(
322            registry.invoke("missing", &agent, &read_cap).await,
323            Err(TaskError::UnknownTool(name)) if name == "missing"
324        ));
325    }
326}