1use std::any::Any;
4use std::collections::HashMap;
5use std::marker::PhantomData;
6use std::{future::Future, pin::Pin};
7
8use typesec_core::{Capability, Permission, Resource, typestate::Authenticated};
9
10use crate::{SecureAgent, executor::TaskError};
11
12pub type ToolFuture<'a> = Pin<Box<dyn Future<Output = Result<(), TaskError>> + Send + 'a>>;
14
/// Public, type-erased description of a protected tool.
///
/// A `ToolSpec` carries only owned strings and a `&'static str`, so it can be cloned freely and
/// handed to callers (e.g. for listing) without exposing the tool's resource or action. It derives
/// `Hash` alongside `Eq` so specs can be stored in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ToolSpec {
    /// Name the tool was constructed with; the registry keys tools by this value.
    pub name: String,
    /// Human-readable description supplied when the tool was constructed.
    pub description: String,
    /// Name of the permission a capability must carry to invoke the tool (from `Permission::name()`).
    pub required_permission: &'static str,
    /// Identifier of the resource the tool acts on, captured from `Resource::resource_id()` at
    /// construction time.
    pub resource_id: String,
}
27
28pub struct ProtectedTool<P, R, F>
30where
31 P: Permission,
32 R: Resource,
33{
34 spec: ToolSpec,
35 resource: R,
36 action: F,
37 _permission: PhantomData<fn() -> P>,
38}
39
40impl<P, R, F> ProtectedTool<P, R, F>
41where
42 P: Permission,
43 R: Resource,
44{
45 pub fn new(
47 name: impl Into<String>,
48 description: impl Into<String>,
49 resource: R,
50 action: F,
51 ) -> Self {
52 let resource_id = resource.resource_id().to_string();
53 Self {
54 spec: ToolSpec {
55 name: name.into(),
56 description: description.into(),
57 required_permission: P::name(),
58 resource_id,
59 },
60 resource,
61 action,
62 _permission: PhantomData,
63 }
64 }
65
66 pub fn spec(&self) -> &ToolSpec {
68 &self.spec
69 }
70}
71
72impl<P, R, F> ProtectedTool<P, R, F>
73where
74 P: Permission,
75 R: Resource,
76 F: for<'a> Fn(&'a R) -> ToolFuture<'a>,
77{
78 pub async fn invoke(
80 &self,
81 agent: &SecureAgent<Authenticated>,
82 cap: &Capability<P, R>,
83 ) -> Result<(), TaskError> {
84 if cap.subject() != agent.subject() {
85 return Err(TaskError::CapabilityMismatch(format!(
86 "capability was minted for subject '{}', not '{}'",
87 cap.subject(),
88 agent.subject()
89 )));
90 }
91 if cap.resource_id() != self.resource.resource_id() {
92 return Err(TaskError::CapabilityMismatch(format!(
93 "capability covers resource '{}', not '{}'",
94 cap.resource_id(),
95 self.resource.resource_id()
96 )));
97 }
98 cap.ensure_active()?;
99
100 tracing::info!(
101 subject = %agent.subject(),
102 permission = %Capability::<P, R>::permission_name(),
103 resource = %cap.resource_id(),
104 tool = %self.spec.name,
105 "invoking protected tool"
106 );
107 (self.action)(&self.resource).await
108 }
109}
110
/// Object-safe view of a `ProtectedTool`, letting tools with different `P`/`R`/`F` type
/// parameters live side by side in one `ToolRegistry` map.
///
/// The capability arrives as `dyn Any`; the implementation in this file downcasts it back to the
/// concrete `Capability<P, R>` the tool expects.
trait ErasedTool: Send + Sync {
    /// Borrows the tool's public description.
    fn spec(&self) -> &ToolSpec;

    /// Type-erased counterpart of `ProtectedTool::invoke`.
    ///
    /// Every borrow shares `'a` because the returned future may hold all of them until it
    /// completes.
    fn invoke_erased<'a>(
        &'a self,
        agent: &'a SecureAgent<Authenticated>,
        cap: &'a (dyn Any + Send + Sync),
    ) -> ToolFuture<'a>;
}
120
121impl<P, R, F> ErasedTool for ProtectedTool<P, R, F>
122where
123 P: Permission + 'static,
124 R: Resource + 'static,
125 F: for<'a> Fn(&'a R) -> ToolFuture<'a> + Send + Sync + 'static,
126{
127 fn spec(&self) -> &ToolSpec {
128 &self.spec
129 }
130
131 fn invoke_erased<'a>(
132 &'a self,
133 agent: &'a SecureAgent<Authenticated>,
134 cap: &'a (dyn Any + Send + Sync),
135 ) -> ToolFuture<'a> {
136 let Some(cap) = cap.downcast_ref::<Capability<P, R>>() else {
137 return Box::pin(async move {
138 Err(TaskError::CapabilityMismatch(format!(
139 "tool '{}' requires Capability<{}, {}>",
140 self.spec.name,
141 P::name(),
142 R::resource_type()
143 )))
144 });
145 };
146
147 Box::pin(async move { self.invoke(agent, cap).await })
148 }
149}
150
/// Collection of type-erased protected tools, keyed by the name in each tool's spec.
///
/// Callers can list, look up, and invoke tools by name without knowing each tool's permission or
/// resource type.
#[derive(Default)]
pub struct ToolRegistry {
    // Boxed trait objects so tools with different `P`/`R`/`F` can share one map.
    tools: HashMap<String, Box<dyn ErasedTool>>,
}
156
157impl ToolRegistry {
158 pub fn new() -> Self {
160 Self::default()
161 }
162
163 pub fn register<P, R, F>(&mut self, tool: ProtectedTool<P, R, F>)
167 where
168 P: Permission + 'static,
169 R: Resource + 'static,
170 F: for<'a> Fn(&'a R) -> ToolFuture<'a> + Send + Sync + 'static,
171 {
172 self.tools.insert(tool.spec.name.clone(), Box::new(tool));
173 }
174
175 pub fn list_specs(&self) -> Vec<ToolSpec> {
177 self.tools
178 .values()
179 .map(|tool| tool.spec().clone())
180 .collect()
181 }
182
183 pub fn spec(&self, name: &str) -> Option<&ToolSpec> {
185 self.tools.get(name).map(|tool| tool.spec())
186 }
187
188 pub async fn invoke(
190 &self,
191 name: &str,
192 agent: &SecureAgent<Authenticated>,
193 cap: &(dyn Any + Send + Sync),
194 ) -> Result<(), TaskError> {
195 let tool = self
196 .tools
197 .get(name)
198 .ok_or_else(|| TaskError::UnknownTool(name.to_owned()))?;
199 tool.invoke_erased(agent, cap).await
200 }
201}
202
203#[cfg(test)]
204mod tests;