1use std::collections::HashSet;
37use std::sync::Arc;
38
39use crate::error::{Error, JsonRpcError};
40use crate::prompt::Prompt;
41use crate::resource::Resource;
42use crate::session::SessionState;
43use crate::tool::Tool;
44
/// A capability that can be checked by a [`CapabilityFilter`].
///
/// Implementors expose the name of the capability; the name-based filters
/// ([`CapabilityFilter::allow_list`] and [`CapabilityFilter::deny_list`])
/// match against that name.
///
/// Every implementor must also be `Send + Sync`.
pub trait Filterable: Send + Sync {
    /// Returns the name of this capability.
    fn name(&self) -> &str;
}
52
53impl Filterable for Tool {
54 fn name(&self) -> &str {
55 &self.name
56 }
57}
58
59impl Filterable for Resource {
60 fn name(&self) -> &str {
61 &self.name
62 }
63}
64
65impl Filterable for Prompt {
66 fn name(&self) -> &str {
67 &self.name
68 }
69}
70
71#[derive(Clone, Default)]
73#[non_exhaustive]
74pub enum DenialBehavior {
75 #[default]
81 NotFound,
82 Unauthorized,
88 Custom(Arc<dyn Fn(&str) -> Error + Send + Sync>),
93}
94
95impl std::fmt::Debug for DenialBehavior {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 match self {
98 Self::NotFound => write!(f, "NotFound"),
99 Self::Unauthorized => write!(f, "Unauthorized"),
100 Self::Custom(_) => write!(f, "Custom(...)"),
101 }
102 }
103}
104
105impl DenialBehavior {
106 pub fn custom<F>(f: F) -> Self
108 where
109 F: Fn(&str) -> Error + Send + Sync + 'static,
110 {
111 Self::Custom(Arc::new(f))
112 }
113
114 pub fn to_error(&self, name: &str) -> Error {
116 match self {
117 Self::NotFound => Error::JsonRpc(JsonRpcError::method_not_found(name)),
118 Self::Unauthorized => {
119 Error::JsonRpc(JsonRpcError::forbidden(format!("Unauthorized: {}", name)))
120 }
121 Self::Custom(f) => f(name),
122 }
123 }
124}
125
126pub struct CapabilityFilter<T: Filterable> {
147 #[allow(clippy::type_complexity)]
148 filter: Arc<dyn Fn(&SessionState, &T) -> bool + Send + Sync>,
149 denial: DenialBehavior,
150}
151
152impl<T: Filterable> Clone for CapabilityFilter<T> {
153 fn clone(&self) -> Self {
154 Self {
155 filter: Arc::clone(&self.filter),
156 denial: self.denial.clone(),
157 }
158 }
159}
160
161impl<T: Filterable> std::fmt::Debug for CapabilityFilter<T> {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("CapabilityFilter")
164 .field("denial", &self.denial)
165 .finish_non_exhaustive()
166 }
167}
168
169impl<T: Filterable> CapabilityFilter<T> {
170 pub fn new<F>(filter: F) -> Self
187 where
188 F: Fn(&SessionState, &T) -> bool + Send + Sync + 'static,
189 {
190 Self {
191 filter: Arc::new(filter),
192 denial: DenialBehavior::default(),
193 }
194 }
195
196 pub fn denial_behavior(mut self, behavior: DenialBehavior) -> Self {
209 self.denial = behavior;
210 self
211 }
212
213 pub fn is_visible(&self, session: &SessionState, capability: &T) -> bool {
215 (self.filter)(session, capability)
216 }
217
218 pub fn denial_error(&self, name: &str) -> Error {
220 self.denial.to_error(name)
221 }
222
223 pub fn allow_list(names: &[&str]) -> Self
237 where
238 T: 'static,
239 {
240 let allowed: HashSet<String> = names.iter().map(|s| (*s).to_string()).collect();
241 Self::new(move |_session, cap: &T| allowed.contains(cap.name()))
242 }
243
244 pub fn deny_list(names: &[&str]) -> Self
258 where
259 T: 'static,
260 {
261 let denied: HashSet<String> = names.iter().map(|s| (*s).to_string()).collect();
262 Self::new(move |_session, cap: &T| !denied.contains(cap.name()))
263 }
264}
265
266impl CapabilityFilter<Tool> {
267 pub fn write_guard<F>(is_write_allowed: F) -> Self
290 where
291 F: Fn(&SessionState) -> bool + Send + Sync + 'static,
292 {
293 Self::new(move |session, tool: &Tool| {
294 let read_only = tool.annotations.as_ref().is_some_and(|a| a.read_only_hint);
295 read_only || is_write_allowed(session)
296 })
297 }
298}
299
300pub type ToolFilter = CapabilityFilter<Tool>;
302
303pub type ResourceFilter = CapabilityFilter<Resource>;
305
306pub type PromptFilter = CapabilityFilter<Prompt>;
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CallToolResult;
    use crate::tool::ToolBuilder;

    /// Builds a tool that is not marked read-only through the builder.
    fn make_test_tool(name: &str) -> Tool {
        ToolBuilder::new(name)
            .description("Test tool")
            .handler(|_: serde_json::Value| async { Ok(CallToolResult::text("ok")) })
            .build()
    }

    /// Builds a tool that is marked read-only through the builder.
    fn make_read_only_tool(name: &str) -> Tool {
        ToolBuilder::new(name)
            .description("Read-only tool")
            .read_only()
            .handler(|_: serde_json::Value| async { Ok(CallToolResult::text("ok")) })
            .build()
    }

    /// Asserts that `error` is a JSON-RPC error carrying the code these tests
    /// expect from the `Unauthorized` denial behavior.
    fn assert_forbidden_code(error: Error) {
        let Error::JsonRpc(rpc) = error else {
            panic!("Expected JsonRpc error");
        };
        assert_eq!(rpc.code, -32007);
    }

    #[test]
    fn test_filter_allows() {
        let session = SessionState::new();
        let filter = CapabilityFilter::new(|_, tool: &Tool| tool.name() != "blocked");

        let visible = filter.is_visible(&session, &make_test_tool("allowed"));
        let hidden = !filter.is_visible(&session, &make_test_tool("blocked"));

        assert!(visible);
        assert!(hidden);
    }

    #[test]
    fn test_denial_behavior_not_found() {
        let error = DenialBehavior::NotFound.to_error("test_tool");
        assert!(matches!(error, Error::JsonRpc(_)));
    }

    #[test]
    fn test_denial_behavior_unauthorized() {
        let error = DenialBehavior::Unauthorized.to_error("test_tool");
        let Error::JsonRpc(rpc) = error else {
            panic!("Expected JsonRpc error");
        };
        assert_eq!(rpc.code, -32007);
        assert!(rpc.message.contains("Unauthorized"));
    }

    #[test]
    fn test_denial_behavior_custom() {
        let behavior = DenialBehavior::custom(|name| Error::tool(format!("No access to {}", name)));
        let Error::Tool(tool_error) = behavior.to_error("secret_tool") else {
            panic!("Expected Tool error");
        };
        assert!(tool_error.message.contains("No access to secret_tool"));
    }

    #[test]
    fn test_filter_clone() {
        let original = CapabilityFilter::new(|_, _: &Tool| true);
        let copy = original.clone();

        assert!(copy.is_visible(&SessionState::new(), &make_test_tool("test")));
    }

    #[test]
    fn test_filter_with_denial_behavior() {
        let filter = CapabilityFilter::new(|_, _: &Tool| false)
            .denial_behavior(DenialBehavior::Unauthorized);

        assert_forbidden_code(filter.denial_error("test"));
    }

    #[test]
    fn test_write_guard_allows_read_only_when_writes_blocked() {
        let session = SessionState::new();
        let filter = CapabilityFilter::<Tool>::write_guard(|_| false);

        assert!(filter.is_visible(&session, &make_read_only_tool("reader")));
    }

    #[test]
    fn test_write_guard_blocks_write_tool_when_writes_blocked() {
        let session = SessionState::new();
        let filter = CapabilityFilter::<Tool>::write_guard(|_| false);

        assert!(!filter.is_visible(&session, &make_test_tool("writer")));
    }

    #[test]
    fn test_write_guard_allows_write_tool_when_writes_allowed() {
        let session = SessionState::new();
        let filter = CapabilityFilter::<Tool>::write_guard(|_| true);

        assert!(filter.is_visible(&session, &make_test_tool("writer")));
    }

    #[test]
    fn test_write_guard_with_denial_behavior() {
        let session = SessionState::new();
        let filter = CapabilityFilter::<Tool>::write_guard(|_| false)
            .denial_behavior(DenialBehavior::Unauthorized);

        assert!(!filter.is_visible(&session, &make_test_tool("writer")));
        assert_forbidden_code(filter.denial_error("writer"));
    }

    #[test]
    fn test_allow_list_shows_listed_tools() {
        let session = SessionState::new();
        let filter = CapabilityFilter::<Tool>::allow_list(&["query", "list_tables"]);

        for listed in ["query", "list_tables"] {
            assert!(filter.is_visible(&session, &make_test_tool(listed)));
        }
        for unlisted in ["delete", "drop_table"] {
            assert!(!filter.is_visible(&session, &make_test_tool(unlisted)));
        }
    }

    #[test]
    fn test_allow_list_empty_blocks_all() {
        let session = SessionState::new();
        let filter = CapabilityFilter::<Tool>::allow_list(&[]);

        assert!(!filter.is_visible(&session, &make_test_tool("anything")));
    }

    #[test]
    fn test_deny_list_hides_listed_tools() {
        let session = SessionState::new();
        let filter = CapabilityFilter::<Tool>::deny_list(&["delete", "drop_table"]);

        for unlisted in ["query", "list_tables"] {
            assert!(filter.is_visible(&session, &make_test_tool(unlisted)));
        }
        for listed in ["delete", "drop_table"] {
            assert!(!filter.is_visible(&session, &make_test_tool(listed)));
        }
    }

    #[test]
    fn test_deny_list_empty_allows_all() {
        let session = SessionState::new();
        let filter = CapabilityFilter::<Tool>::deny_list(&[]);

        assert!(filter.is_visible(&session, &make_test_tool("anything")));
    }

    #[test]
    fn test_allow_list_with_denial_behavior() {
        let session = SessionState::new();
        let filter = CapabilityFilter::<Tool>::allow_list(&["query"])
            .denial_behavior(DenialBehavior::Unauthorized);

        assert!(!filter.is_visible(&session, &make_test_tool("delete")));
        assert_forbidden_code(filter.denial_error("delete"));
    }
}