1use std::any::Any;
4use std::collections::HashMap;
5use std::marker::PhantomData;
6use std::{future::Future, pin::Pin};
7
8use typesec_core::{Capability, Permission, Resource, typestate::Authenticated};
9
10use crate::{SecureAgent, executor::TaskError};
11
12pub type ToolFuture<'a> = Pin<Box<dyn Future<Output = Result<(), TaskError>> + Send + 'a>>;
14
/// Descriptive metadata for a protected tool, as exposed by the tool registry.
///
/// Built by `ProtectedTool::new`. The permission name comes from the tool's
/// permission type, and the resource id comes from the tool's resource at
/// construction time.
///
/// `Hash` is derived alongside `Eq` so specs can be stored in hash-based
/// collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ToolSpec {
    /// Name the tool is registered and looked up under.
    pub name: String,
    /// Human-readable explanation of what the tool does.
    pub description: String,
    /// Name of the permission a capability must carry (`Permission::name`).
    pub required_permission: &'static str,
    /// Identifier of the resource the tool operates on.
    pub resource_id: String,
}
27
28pub struct ProtectedTool<P, R, F>
30where
31 P: Permission,
32 R: Resource,
33{
34 spec: ToolSpec,
35 resource: R,
36 action: F,
37 _permission: PhantomData<fn() -> P>,
38}
39
40impl<P, R, F> ProtectedTool<P, R, F>
41where
42 P: Permission,
43 R: Resource,
44{
45 pub fn new(
47 name: impl Into<String>,
48 description: impl Into<String>,
49 resource: R,
50 action: F,
51 ) -> Self {
52 let resource_id = resource.resource_id().to_string();
53 Self {
54 spec: ToolSpec {
55 name: name.into(),
56 description: description.into(),
57 required_permission: P::name(),
58 resource_id,
59 },
60 resource,
61 action,
62 _permission: PhantomData,
63 }
64 }
65
66 pub fn spec(&self) -> &ToolSpec {
68 &self.spec
69 }
70}
71
72impl<P, R, F> ProtectedTool<P, R, F>
73where
74 P: Permission,
75 R: Resource,
76 F: for<'a> Fn(&'a R) -> ToolFuture<'a>,
77{
78 pub async fn invoke(
80 &self,
81 agent: &SecureAgent<Authenticated>,
82 cap: &Capability<P, R>,
83 ) -> Result<(), TaskError> {
84 if cap.subject() != agent.subject() {
85 return Err(TaskError::CapabilityMismatch(format!(
86 "capability was minted for subject '{}', not '{}'",
87 cap.subject(),
88 agent.subject()
89 )));
90 }
91 if cap.resource_id() != self.resource.resource_id() {
92 return Err(TaskError::CapabilityMismatch(format!(
93 "capability covers resource '{}', not '{}'",
94 cap.resource_id(),
95 self.resource.resource_id()
96 )));
97 }
98 cap.ensure_active()?;
99
100 tracing::info!(
101 subject = %agent.subject(),
102 permission = %Capability::<P, R>::permission_name(),
103 resource = %cap.resource_id(),
104 tool = %self.spec.name,
105 "invoking protected tool"
106 );
107 (self.action)(&self.resource).await
108 }
109}
110
111trait ErasedTool: Send + Sync {
112 fn spec(&self) -> &ToolSpec;
113
114 fn invoke_erased<'a>(
115 &'a self,
116 agent: &'a SecureAgent<Authenticated>,
117 cap: &'a (dyn Any + Send + Sync),
118 ) -> ToolFuture<'a>;
119}
120
121impl<P, R, F> ErasedTool for ProtectedTool<P, R, F>
122where
123 P: Permission + 'static,
124 R: Resource + 'static,
125 F: for<'a> Fn(&'a R) -> ToolFuture<'a> + Send + Sync + 'static,
126{
127 fn spec(&self) -> &ToolSpec {
128 &self.spec
129 }
130
131 fn invoke_erased<'a>(
132 &'a self,
133 agent: &'a SecureAgent<Authenticated>,
134 cap: &'a (dyn Any + Send + Sync),
135 ) -> ToolFuture<'a> {
136 let Some(cap) = cap.downcast_ref::<Capability<P, R>>() else {
137 return Box::pin(async move {
138 Err(TaskError::CapabilityMismatch(format!(
139 "tool '{}' requires Capability<{}, {}>",
140 self.spec.name,
141 P::name(),
142 R::resource_type()
143 )))
144 });
145 };
146
147 Box::pin(async move { self.invoke(agent, cap).await })
148 }
149}
150
151#[derive(Default)]
153pub struct ToolRegistry {
154 tools: HashMap<String, Box<dyn ErasedTool>>,
155}
156
157impl ToolRegistry {
158 pub fn new() -> Self {
160 Self::default()
161 }
162
163 pub fn register<P, R, F>(&mut self, tool: ProtectedTool<P, R, F>)
167 where
168 P: Permission + 'static,
169 R: Resource + 'static,
170 F: for<'a> Fn(&'a R) -> ToolFuture<'a> + Send + Sync + 'static,
171 {
172 self.tools.insert(tool.spec.name.clone(), Box::new(tool));
173 }
174
175 pub fn list_specs(&self) -> Vec<ToolSpec> {
177 self.tools
178 .values()
179 .map(|tool| tool.spec().clone())
180 .collect()
181 }
182
183 pub fn spec(&self, name: &str) -> Option<&ToolSpec> {
185 self.tools.get(name).map(|tool| tool.spec())
186 }
187
188 pub async fn invoke(
190 &self,
191 name: &str,
192 agent: &SecureAgent<Authenticated>,
193 cap: &(dyn Any + Send + Sync),
194 ) -> Result<(), TaskError> {
195 let tool = self
196 .tools
197 .get(name)
198 .ok_or_else(|| TaskError::UnknownTool(name.to_owned()))?;
199 tool.invoke_erased(agent, cap).await
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use typesec_core::{
        CanExecute, CanRead, ResourceId, SubjectId,
        policy::{PolicyEngine, PolicyResult},
        resource::GenericResource,
        typestate::Credentials,
    };

    /// Policy that grants every request, so tests exercise only the tool-side checks.
    struct AllowAll;
    impl PolicyEngine for AllowAll {
        fn check(&self, _: &SubjectId, _: &str, _: &ResourceId) -> PolicyResult {
            PolicyResult::Allow
        }
    }

    /// Tool action that does nothing and succeeds.
    fn no_op_tool(_resource: &GenericResource) -> ToolFuture<'_> {
        Box::pin(async { Ok(()) })
    }

    #[tokio::test]
    async fn protected_tool_invokes_with_capability() {
        let agent = SecureAgent::new(Arc::new(AllowAll))
            .authenticate_unverified(Credentials::new("agent:test", "tok"))
            .expect("auth ok");
        let resource = GenericResource::new("Gmail.ListEmails", "tool");
        let cap: Capability<CanExecute, GenericResource> =
            agent.request_capability(&resource).await.expect("cap ok");
        let tool = ProtectedTool::<CanExecute, _, _>::new(
            "gmail.list",
            "List email messages",
            resource,
            no_op_tool,
        );

        assert_eq!(tool.spec().required_permission, "execute");
        tool.invoke(&agent, &cap).await.expect("tool should run");
    }

    #[tokio::test]
    async fn protected_tool_rejects_capability_for_other_resource() {
        let agent = SecureAgent::new(Arc::new(AllowAll))
            .authenticate_unverified(Credentials::new("agent:test", "tok"))
            .expect("auth ok");
        let cap_resource = GenericResource::new("Gmail.ListEmails", "tool");
        let tool_resource = GenericResource::new("Gmail.DeleteEmail", "tool");
        let cap: Capability<CanExecute, GenericResource> = agent
            .request_capability(&cap_resource)
            .await
            .expect("cap ok");
        let tool = ProtectedTool::<CanExecute, _, _>::new(
            "gmail.delete",
            "Delete an email message",
            tool_resource,
            no_op_tool,
        );

        assert!(matches!(
            tool.invoke(&agent, &cap).await,
            Err(TaskError::CapabilityMismatch(reason)) if reason.contains("Gmail.ListEmails")
        ));
    }

    /// A capability minted for one agent must not be usable by another agent.
    #[tokio::test]
    async fn protected_tool_rejects_capability_for_other_subject() {
        let agent = SecureAgent::new(Arc::new(AllowAll))
            .authenticate_unverified(Credentials::new("agent:test", "tok"))
            .expect("auth ok");
        let other = SecureAgent::new(Arc::new(AllowAll))
            .authenticate_unverified(Credentials::new("agent:other", "tok"))
            .expect("auth ok");
        let resource = GenericResource::new("Gmail.ListEmails", "tool");
        let cap: Capability<CanExecute, GenericResource> =
            other.request_capability(&resource).await.expect("cap ok");
        let tool = ProtectedTool::<CanExecute, _, _>::new(
            "gmail.list",
            "List email messages",
            resource,
            no_op_tool,
        );

        assert!(matches!(
            tool.invoke(&agent, &cap).await,
            Err(TaskError::CapabilityMismatch(reason)) if reason.contains("subject")
        ));
    }

    #[tokio::test]
    async fn registry_lists_and_invokes_registered_tools() {
        let agent = SecureAgent::new(Arc::new(AllowAll))
            .authenticate_unverified(Credentials::new("agent:test", "tok"))
            .expect("auth ok");
        let resource = GenericResource::new("Gmail.ListEmails", "tool");
        let cap: Capability<CanExecute, GenericResource> =
            agent.request_capability(&resource).await.expect("cap ok");
        let mut registry = ToolRegistry::new();
        registry.register(ProtectedTool::<CanExecute, _, _>::new(
            "gmail.list",
            "List email messages",
            resource,
            no_op_tool,
        ));

        let specs = registry.list_specs();
        assert_eq!(specs.len(), 1);
        assert_eq!(specs[0].name, "gmail.list");
        assert_eq!(
            registry
                .spec("gmail.list")
                .expect("registered spec")
                .required_permission,
            "execute"
        );

        registry
            .invoke("gmail.list", &agent, &cap)
            .await
            .expect("registry invoke should run");
    }

    #[tokio::test]
    async fn registry_rejects_wrong_capability_type_and_unknown_tool() {
        let agent = SecureAgent::new(Arc::new(AllowAll))
            .authenticate_unverified(Credentials::new("agent:test", "tok"))
            .expect("auth ok");
        let resource = GenericResource::new("Gmail.ListEmails", "tool");
        let read_cap: Capability<CanRead, GenericResource> =
            agent.request_capability(&resource).await.expect("cap ok");
        let mut registry = ToolRegistry::new();
        registry.register(ProtectedTool::<CanExecute, _, _>::new(
            "gmail.list",
            "List email messages",
            resource,
            no_op_tool,
        ));

        // A read capability cannot be downcast to the execute capability the tool needs.
        assert!(matches!(
            registry.invoke("gmail.list", &agent, &read_cap).await,
            Err(TaskError::CapabilityMismatch(reason)) if reason.contains("Capability<execute")
        ));
        assert!(matches!(
            registry.invoke("missing", &agent, &read_cap).await,
            Err(TaskError::UnknownTool(name)) if name == "missing"
        ));
    }
}