1use std::sync::Arc;
39
40use crate::error::{Error, JsonRpcError};
41use crate::prompt::Prompt;
42use crate::resource::Resource;
43use crate::session::SessionState;
44use crate::tool::Tool;
45
/// A capability that can be screened by a [`CapabilityFilter`].
///
/// Implemented in this module for [`Tool`], [`Resource`], and [`Prompt`].
/// The `Send + Sync` supertraits let implementors be shared across threads.
pub trait Filterable: Send + Sync {
    /// Returns the capability's name.
    fn name(&self) -> &str;
}
53
54impl Filterable for Tool {
55 fn name(&self) -> &str {
56 &self.name
57 }
58}
59
60impl Filterable for Resource {
61 fn name(&self) -> &str {
62 &self.name
63 }
64}
65
66impl Filterable for Prompt {
67 fn name(&self) -> &str {
68 &self.name
69 }
70}
71
/// Selects which error is produced when a filtered-out capability is requested.
///
/// Defaults to [`DenialBehavior::NotFound`].
#[derive(Clone, Default)]
pub enum DenialBehavior {
    /// Produce `Error::JsonRpc(JsonRpcError::method_not_found(name))`.
    /// This is the default variant.
    #[default]
    NotFound,
    /// Produce `Error::JsonRpc(JsonRpcError::forbidden(..))` with the message
    /// `"Unauthorized: <name>"`.
    Unauthorized,
    /// Delegate to a caller-supplied closure that receives the capability name.
    /// Held in an `Arc` so cloning the behavior only bumps a reference count.
    Custom(Arc<dyn Fn(&str) -> Error + Send + Sync>),
}
84
85impl std::fmt::Debug for DenialBehavior {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 match self {
88 Self::NotFound => write!(f, "NotFound"),
89 Self::Unauthorized => write!(f, "Unauthorized"),
90 Self::Custom(_) => write!(f, "Custom(...)"),
91 }
92 }
93}
94
95impl DenialBehavior {
96 pub fn custom<F>(f: F) -> Self
98 where
99 F: Fn(&str) -> Error + Send + Sync + 'static,
100 {
101 Self::Custom(Arc::new(f))
102 }
103
104 pub fn to_error(&self, name: &str) -> Error {
106 match self {
107 Self::NotFound => Error::JsonRpc(JsonRpcError::method_not_found(name)),
108 Self::Unauthorized => {
109 Error::JsonRpc(JsonRpcError::forbidden(format!("Unauthorized: {}", name)))
110 }
111 Self::Custom(f) => f(name),
112 }
113 }
114}
115
116pub struct CapabilityFilter<T: Filterable> {
137 #[allow(clippy::type_complexity)]
138 filter: Arc<dyn Fn(&SessionState, &T) -> bool + Send + Sync>,
139 denial: DenialBehavior,
140}
141
142impl<T: Filterable> Clone for CapabilityFilter<T> {
143 fn clone(&self) -> Self {
144 Self {
145 filter: Arc::clone(&self.filter),
146 denial: self.denial.clone(),
147 }
148 }
149}
150
151impl<T: Filterable> std::fmt::Debug for CapabilityFilter<T> {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 f.debug_struct("CapabilityFilter")
154 .field("denial", &self.denial)
155 .finish_non_exhaustive()
156 }
157}
158
159impl<T: Filterable> CapabilityFilter<T> {
160 pub fn new<F>(filter: F) -> Self
177 where
178 F: Fn(&SessionState, &T) -> bool + Send + Sync + 'static,
179 {
180 Self {
181 filter: Arc::new(filter),
182 denial: DenialBehavior::default(),
183 }
184 }
185
186 pub fn denial_behavior(mut self, behavior: DenialBehavior) -> Self {
199 self.denial = behavior;
200 self
201 }
202
203 pub fn is_visible(&self, session: &SessionState, capability: &T) -> bool {
205 (self.filter)(session, capability)
206 }
207
208 pub fn denial_error(&self, name: &str) -> Error {
210 self.denial.to_error(name)
211 }
212}
213
/// A [`CapabilityFilter`] over [`Tool`]s.
pub type ToolFilter = CapabilityFilter<Tool>;

/// A [`CapabilityFilter`] over [`Resource`]s.
pub type ResourceFilter = CapabilityFilter<Resource>;

/// A [`CapabilityFilter`] over [`Prompt`]s.
pub type PromptFilter = CapabilityFilter<Prompt>;
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CallToolResult;
    use crate::tool::ToolBuilder;

    /// Builds a minimal tool whose handler always succeeds; only its name
    /// matters to these tests.
    fn make_test_tool(name: &str) -> Tool {
        ToolBuilder::new(name)
            .description("Test tool")
            .handler(|_: serde_json::Value| async { Ok(CallToolResult::text("ok")) })
            .build()
            .unwrap()
    }

    /// A name-based predicate hides exactly the tool it rejects.
    #[test]
    fn test_filter_allows() {
        let filter = CapabilityFilter::new(|_, tool: &Tool| tool.name() != "blocked");
        let session = SessionState::new();
        let visible = make_test_tool("allowed");
        let hidden = make_test_tool("blocked");

        assert!(filter.is_visible(&session, &visible));
        assert!(!filter.is_visible(&session, &hidden));
    }

    /// `NotFound` maps to a JSON-RPC error.
    #[test]
    fn test_denial_behavior_not_found() {
        let error = DenialBehavior::NotFound.to_error("test_tool");
        assert!(matches!(error, Error::JsonRpc(_)));
    }

    /// `Unauthorized` maps to a JSON-RPC error with code -32007 and an
    /// "Unauthorized" message.
    #[test]
    fn test_denial_behavior_unauthorized() {
        let error = DenialBehavior::Unauthorized.to_error("test_tool");
        if let Error::JsonRpc(rpc) = error {
            assert_eq!(rpc.code, -32007);
            assert!(rpc.message.contains("Unauthorized"));
        } else {
            panic!("Expected JsonRpc error");
        }
    }

    /// A custom closure's error is returned unchanged.
    #[test]
    fn test_denial_behavior_custom() {
        let behavior =
            DenialBehavior::custom(|name| Error::tool(format!("No access to {}", name)));
        match behavior.to_error("secret_tool") {
            Error::Tool(tool_err) => {
                assert!(tool_err.message.contains("No access to secret_tool"));
            }
            _ => panic!("Expected Tool error"),
        }
    }

    /// A cloned filter evaluates the same shared predicate.
    #[test]
    fn test_filter_clone() {
        let original = CapabilityFilter::new(|_, _: &Tool| true);
        let copy = original.clone();
        let session = SessionState::new();
        let tool = make_test_tool("test");
        assert!(copy.is_visible(&session, &tool));
    }

    /// The builder-configured denial behavior drives `denial_error`.
    #[test]
    fn test_filter_with_denial_behavior() {
        let filter = CapabilityFilter::new(|_, _: &Tool| false)
            .denial_behavior(DenialBehavior::Unauthorized);

        match filter.denial_error("test") {
            Error::JsonRpc(rpc) => assert_eq!(rpc.code, -32007),
            _ => panic!("Expected JsonRpc error"),
        }
    }
}