Skip to main content

tower_mcp/
filter.rs

1//! Session-based capability filtering.
2//!
3//! This module provides types for filtering tools, resources, and prompts
4//! based on session state. Different sessions can see different capabilities
5//! based on user identity, roles, API keys, or other session context.
6//!
7//! # Example
8//!
9//! ```rust
10//! use tower_mcp::{McpRouter, ToolBuilder, CallToolResult, CapabilityFilter, Tool, Filterable};
11//! use schemars::JsonSchema;
12//! use serde::Deserialize;
13//!
14//! #[derive(Debug, Deserialize, JsonSchema)]
15//! struct Input { value: String }
16//!
17//! let public_tool = ToolBuilder::new("public")
18//!     .description("Available to everyone")
19//!     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
20//!     .build();
21//!
22//! let admin_tool = ToolBuilder::new("admin")
23//!     .description("Admin only")
24//!     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
25//!     .build();
26//!
27//! let router = McpRouter::new()
28//!     .tool(public_tool)
29//!     .tool(admin_tool)
30//!     .tool_filter(CapabilityFilter::new(|_session, tool: &Tool| {
31//!         // In real code, check session.extensions() for auth claims
32//!         tool.name() != "admin"
33//!     }));
34//! ```
35
36use std::collections::HashSet;
37use std::sync::Arc;
38
39use crate::error::{Error, JsonRpcError};
40use crate::prompt::Prompt;
41use crate::resource::Resource;
42use crate::session::SessionState;
43use crate::tool::Tool;
44
/// A capability (tool, resource, or prompt) that a [`CapabilityFilter`] can
/// decide to show or hide for a given session.
///
/// The only thing a filter needs from a capability is its name. This crate
/// implements the trait for [`Tool`], [`Resource`], and [`Prompt`].
pub trait Filterable: Send + Sync {
    /// The capability's name, as used by name-based filters.
    fn name(&self) -> &str;
}
52
53impl Filterable for Tool {
54    fn name(&self) -> &str {
55        &self.name
56    }
57}
58
59impl Filterable for Resource {
60    fn name(&self) -> &str {
61        &self.name
62    }
63}
64
65impl Filterable for Prompt {
66    fn name(&self) -> &str {
67        &self.name
68    }
69}
70
71/// Behavior when a filtered capability is accessed directly.
72#[derive(Clone, Default)]
73#[non_exhaustive]
74pub enum DenialBehavior {
75    /// Return "method not found" error -- hides the capability entirely.
76    ///
77    /// This is the default and recommended for security. Use this in
78    /// multi-tenant scenarios where tools should not be discoverable by
79    /// unauthorized users.
80    #[default]
81    NotFound,
82    /// Return an "unauthorized" error, revealing the capability exists.
83    ///
84    /// Use this when the client should know about the capability but is
85    /// not permitted to invoke it (e.g., premium features behind an
86    /// upgrade prompt).
87    Unauthorized,
88    /// Use a custom error generator for application-specific responses.
89    ///
90    /// Use this when you need custom status codes, domain-specific error
91    /// messages, or structured error payloads.
92    Custom(Arc<dyn Fn(&str) -> Error + Send + Sync>),
93}
94
95impl std::fmt::Debug for DenialBehavior {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        match self {
98            Self::NotFound => write!(f, "NotFound"),
99            Self::Unauthorized => write!(f, "Unauthorized"),
100            Self::Custom(_) => write!(f, "Custom(...)"),
101        }
102    }
103}
104
105impl DenialBehavior {
106    /// Create a custom denial behavior with the given error generator.
107    pub fn custom<F>(f: F) -> Self
108    where
109        F: Fn(&str) -> Error + Send + Sync + 'static,
110    {
111        Self::Custom(Arc::new(f))
112    }
113
114    /// Generate the appropriate error for a denied capability.
115    pub fn to_error(&self, name: &str) -> Error {
116        match self {
117            Self::NotFound => Error::JsonRpc(JsonRpcError::method_not_found(name)),
118            Self::Unauthorized => {
119                Error::JsonRpc(JsonRpcError::forbidden(format!("Unauthorized: {}", name)))
120            }
121            Self::Custom(f) => f(name),
122        }
123    }
124}
125
126/// A filter for capabilities based on session state.
127///
128/// Use this to control which tools, resources, or prompts are visible
129/// to each session.
130///
131/// # Example
132///
133/// ```rust
134/// use tower_mcp::{CapabilityFilter, DenialBehavior, Tool, Filterable};
135///
136/// // Filter that only shows tools starting with "public_"
137/// let filter = CapabilityFilter::new(|_session, tool: &Tool| {
138///     tool.name().starts_with("public_")
139/// });
140///
141/// // Filter with custom denial behavior
142/// let filter_with_401 = CapabilityFilter::new(|_session, tool: &Tool| {
143///     tool.name() != "admin"
144/// }).denial_behavior(DenialBehavior::Unauthorized);
145/// ```
146pub struct CapabilityFilter<T: Filterable> {
147    #[allow(clippy::type_complexity)]
148    filter: Arc<dyn Fn(&SessionState, &T) -> bool + Send + Sync>,
149    denial: DenialBehavior,
150}
151
152impl<T: Filterable> Clone for CapabilityFilter<T> {
153    fn clone(&self) -> Self {
154        Self {
155            filter: Arc::clone(&self.filter),
156            denial: self.denial.clone(),
157        }
158    }
159}
160
161impl<T: Filterable> std::fmt::Debug for CapabilityFilter<T> {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        f.debug_struct("CapabilityFilter")
164            .field("denial", &self.denial)
165            .finish_non_exhaustive()
166    }
167}
168
169impl<T: Filterable> CapabilityFilter<T> {
170    /// Create a new capability filter with the given predicate.
171    ///
172    /// The predicate receives the session state and capability, and returns
173    /// `true` if the capability should be visible to the session.
174    ///
175    /// # Example
176    ///
177    /// ```rust
178    /// use tower_mcp::{CapabilityFilter, Tool, Filterable};
179    ///
180    /// let filter = CapabilityFilter::new(|_session, tool: &Tool| {
181    ///     // Check session extensions for auth claims
182    ///     // session.extensions().get::<UserClaims>()...
183    ///     tool.name() != "admin_only"
184    /// });
185    /// ```
186    pub fn new<F>(filter: F) -> Self
187    where
188        F: Fn(&SessionState, &T) -> bool + Send + Sync + 'static,
189    {
190        Self {
191            filter: Arc::new(filter),
192            denial: DenialBehavior::default(),
193        }
194    }
195
196    /// Set the behavior when a filtered capability is accessed directly.
197    ///
198    /// Default is [`DenialBehavior::NotFound`].
199    ///
200    /// # Example
201    ///
202    /// ```rust
203    /// use tower_mcp::{CapabilityFilter, DenialBehavior, Tool, Filterable};
204    ///
205    /// let filter = CapabilityFilter::new(|_, tool: &Tool| tool.name() != "secret")
206    ///     .denial_behavior(DenialBehavior::Unauthorized);
207    /// ```
208    pub fn denial_behavior(mut self, behavior: DenialBehavior) -> Self {
209        self.denial = behavior;
210        self
211    }
212
213    /// Check if the given capability is visible to the session.
214    pub fn is_visible(&self, session: &SessionState, capability: &T) -> bool {
215        (self.filter)(session, capability)
216    }
217
218    /// Get the error to return when access is denied.
219    pub fn denial_error(&self, name: &str) -> Error {
220        self.denial.to_error(name)
221    }
222
223    /// Create a filter that only shows capabilities whose names are in the list.
224    ///
225    /// Capabilities not in the list are hidden. This is useful for exposing
226    /// a curated subset of capabilities (e.g., from a config file or CLI flag).
227    ///
228    /// # Example
229    ///
230    /// ```rust
231    /// use tower_mcp::{CapabilityFilter, Tool};
232    ///
233    /// // Only expose these two tools
234    /// let filter = CapabilityFilter::<Tool>::allow_list(&["query", "list_tables"]);
235    /// ```
236    pub fn allow_list(names: &[&str]) -> Self
237    where
238        T: 'static,
239    {
240        let allowed: HashSet<String> = names.iter().map(|s| (*s).to_string()).collect();
241        Self::new(move |_session, cap: &T| allowed.contains(cap.name()))
242    }
243
244    /// Create a filter that hides capabilities whose names are in the list.
245    ///
246    /// All capabilities are visible except those explicitly listed. This is
247    /// useful for blocking specific dangerous or irrelevant capabilities.
248    ///
249    /// # Example
250    ///
251    /// ```rust
252    /// use tower_mcp::{CapabilityFilter, Tool};
253    ///
254    /// // Hide these destructive tools
255    /// let filter = CapabilityFilter::<Tool>::deny_list(&["delete", "drop_table"]);
256    /// ```
257    pub fn deny_list(names: &[&str]) -> Self
258    where
259        T: 'static,
260    {
261        let denied: HashSet<String> = names.iter().map(|s| (*s).to_string()).collect();
262        Self::new(move |_session, cap: &T| !denied.contains(cap.name()))
263    }
264}
265
266impl CapabilityFilter<Tool> {
267    /// Create a filter that blocks non-read-only tools when the predicate returns `false`.
268    ///
269    /// Read-only tools (those with `read_only_hint = true`) are always allowed.
270    /// Non-read-only tools are only allowed when `is_write_allowed` returns `true`
271    /// for the current session.
272    ///
273    /// This provides annotation-based write protection without requiring
274    /// manual guards in every write tool handler.
275    ///
276    /// # Example
277    ///
278    /// ```rust
279    /// use tower_mcp::{CapabilityFilter, Tool};
280    ///
281    /// // Block all write tools unconditionally
282    /// let filter = CapabilityFilter::<Tool>::write_guard(|_session| false);
283    ///
284    /// // Allow writes based on session state
285    /// // let filter = CapabilityFilter::<Tool>::write_guard(|session| {
286    /// //     session.get::<WriteEnabled>().is_some()
287    /// // });
288    /// ```
289    pub fn write_guard<F>(is_write_allowed: F) -> Self
290    where
291        F: Fn(&SessionState) -> bool + Send + Sync + 'static,
292    {
293        Self::new(move |session, tool: &Tool| {
294            let read_only = tool.annotations.as_ref().is_some_and(|a| a.read_only_hint);
295            read_only || is_write_allowed(session)
296        })
297    }
298}
299
300/// Type alias for tool filters.
301pub type ToolFilter = CapabilityFilter<Tool>;
302
303/// Type alias for resource filters.
304pub type ResourceFilter = CapabilityFilter<Resource>;
305
306/// Type alias for prompt filters.
307pub type PromptFilter = CapabilityFilter<Prompt>;
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CallToolResult;
    use crate::tool::ToolBuilder;

    fn make_test_tool(name: &str) -> Tool {
        ToolBuilder::new(name)
            .description("Test tool")
            .handler(|_: serde_json::Value| async { Ok(CallToolResult::text("ok")) })
            .build()
    }

    #[test]
    fn test_filter_allows() {
        let filter = CapabilityFilter::new(|_, tool: &Tool| tool.name() != "blocked");
        let session = SessionState::new();
        let allowed = make_test_tool("allowed");
        let blocked = make_test_tool("blocked");

        assert!(filter.is_visible(&session, &allowed));
        assert!(!filter.is_visible(&session, &blocked));
    }

    #[test]
    fn test_denial_behavior_not_found() {
        let behavior = DenialBehavior::NotFound;
        let error = behavior.to_error("test_tool");
        match error {
            Error::JsonRpc(e) => assert_eq!(e.code, -32601), // JSON-RPC "Method not found"
            _ => panic!("Expected JsonRpc error"),
        }
    }

    #[test]
    fn test_denial_behavior_unauthorized() {
        let behavior = DenialBehavior::Unauthorized;
        let error = behavior.to_error("test_tool");
        match error {
            Error::JsonRpc(e) => {
                assert_eq!(e.code, -32007); // McpErrorCode::Forbidden
                assert!(e.message.contains("Unauthorized"));
            }
            _ => panic!("Expected JsonRpc error"),
        }
    }

    #[test]
    fn test_denial_behavior_custom() {
        let behavior = DenialBehavior::custom(|name| Error::tool(format!("No access to {}", name)));
        let error = behavior.to_error("secret_tool");
        match error {
            Error::Tool(e) => {
                assert!(e.message.contains("No access to secret_tool"));
            }
            _ => panic!("Expected Tool error"),
        }
    }

    #[test]
    fn test_denial_behavior_debug() {
        assert_eq!(format!("{:?}", DenialBehavior::NotFound), "NotFound");
        assert_eq!(format!("{:?}", DenialBehavior::Unauthorized), "Unauthorized");
        let custom = DenialBehavior::custom(|name| Error::tool(format!("denied {name}")));
        assert_eq!(format!("{custom:?}"), "Custom(...)");
    }

    #[test]
    fn test_filter_debug_hides_predicate() {
        let filter = CapabilityFilter::<Tool>::deny_list(&["x"]);
        assert_eq!(
            format!("{filter:?}"),
            "CapabilityFilter { denial: NotFound, .. }"
        );
    }

    #[test]
    fn test_filter_clone() {
        let filter = CapabilityFilter::new(|_, _: &Tool| true);
        let cloned = filter.clone();
        let session = SessionState::new();
        let tool = make_test_tool("test");
        assert!(cloned.is_visible(&session, &tool));
    }

    #[test]
    fn test_clone_preserves_denial_behavior() {
        let filter = CapabilityFilter::new(|_, _: &Tool| false)
            .denial_behavior(DenialBehavior::Unauthorized);
        let cloned = filter.clone();
        match cloned.denial_error("test") {
            Error::JsonRpc(e) => assert_eq!(e.code, -32007), // McpErrorCode::Forbidden
            _ => panic!("Expected JsonRpc error"),
        }
    }

    #[test]
    fn test_default_denial_behavior_is_not_found() {
        let filter = CapabilityFilter::new(|_, _: &Tool| false);
        match filter.denial_error("hidden") {
            Error::JsonRpc(e) => assert_eq!(e.code, -32601), // JSON-RPC "Method not found"
            _ => panic!("Expected JsonRpc error"),
        }
    }

    #[test]
    fn test_filter_with_denial_behavior() {
        let filter = CapabilityFilter::new(|_, _: &Tool| false)
            .denial_behavior(DenialBehavior::Unauthorized);

        let error = filter.denial_error("test");
        match error {
            Error::JsonRpc(e) => assert_eq!(e.code, -32007), // McpErrorCode::Forbidden
            _ => panic!("Expected JsonRpc error"),
        }
    }

    fn make_read_only_tool(name: &str) -> Tool {
        ToolBuilder::new(name)
            .description("Read-only tool")
            .read_only()
            .handler(|_: serde_json::Value| async { Ok(CallToolResult::text("ok")) })
            .build()
    }

    #[test]
    fn test_write_guard_allows_read_only_when_writes_blocked() {
        let filter = CapabilityFilter::<Tool>::write_guard(|_| false);
        let session = SessionState::new();
        let tool = make_read_only_tool("reader");

        assert!(filter.is_visible(&session, &tool));
    }

    #[test]
    fn test_write_guard_blocks_write_tool_when_writes_blocked() {
        let filter = CapabilityFilter::<Tool>::write_guard(|_| false);
        let session = SessionState::new();
        let tool = make_test_tool("writer");

        assert!(!filter.is_visible(&session, &tool));
    }

    #[test]
    fn test_write_guard_allows_write_tool_when_writes_allowed() {
        let filter = CapabilityFilter::<Tool>::write_guard(|_| true);
        let session = SessionState::new();
        let tool = make_test_tool("writer");

        assert!(filter.is_visible(&session, &tool));
    }

    #[test]
    fn test_write_guard_skips_predicate_for_read_only_tools() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let filter = CapabilityFilter::<Tool>::write_guard(move |_| {
            counter.fetch_add(1, Ordering::SeqCst);
            false
        });
        let session = SessionState::new();

        assert!(filter.is_visible(&session, &make_read_only_tool("reader")));
        assert_eq!(calls.load(Ordering::SeqCst), 0);
        assert!(!filter.is_visible(&session, &make_test_tool("writer")));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_write_guard_with_denial_behavior() {
        let filter = CapabilityFilter::<Tool>::write_guard(|_| false)
            .denial_behavior(DenialBehavior::Unauthorized);
        let session = SessionState::new();
        let tool = make_test_tool("writer");

        assert!(!filter.is_visible(&session, &tool));
        let error = filter.denial_error("writer");
        match error {
            Error::JsonRpc(e) => assert_eq!(e.code, -32007),
            _ => panic!("Expected JsonRpc error"),
        }
    }

    #[test]
    fn test_allow_list_shows_listed_tools() {
        let filter = CapabilityFilter::<Tool>::allow_list(&["query", "list_tables"]);
        let session = SessionState::new();

        assert!(filter.is_visible(&session, &make_test_tool("query")));
        assert!(filter.is_visible(&session, &make_test_tool("list_tables")));
        assert!(!filter.is_visible(&session, &make_test_tool("delete")));
        assert!(!filter.is_visible(&session, &make_test_tool("drop_table")));
    }

    #[test]
    fn test_allow_list_empty_blocks_all() {
        let filter = CapabilityFilter::<Tool>::allow_list(&[]);
        let session = SessionState::new();

        assert!(!filter.is_visible(&session, &make_test_tool("anything")));
    }

    #[test]
    fn test_deny_list_hides_listed_tools() {
        let filter = CapabilityFilter::<Tool>::deny_list(&["delete", "drop_table"]);
        let session = SessionState::new();

        assert!(filter.is_visible(&session, &make_test_tool("query")));
        assert!(filter.is_visible(&session, &make_test_tool("list_tables")));
        assert!(!filter.is_visible(&session, &make_test_tool("delete")));
        assert!(!filter.is_visible(&session, &make_test_tool("drop_table")));
    }

    #[test]
    fn test_deny_list_empty_allows_all() {
        let filter = CapabilityFilter::<Tool>::deny_list(&[]);
        let session = SessionState::new();

        assert!(filter.is_visible(&session, &make_test_tool("anything")));
    }

    #[test]
    fn test_allow_list_with_denial_behavior() {
        let filter = CapabilityFilter::<Tool>::allow_list(&["query"])
            .denial_behavior(DenialBehavior::Unauthorized);
        let session = SessionState::new();

        assert!(!filter.is_visible(&session, &make_test_tool("delete")));
        let error = filter.denial_error("delete");
        match error {
            Error::JsonRpc(e) => assert_eq!(e.code, -32007),
            _ => panic!("Expected JsonRpc error"),
        }
    }
}