Skip to main content

tower_mcp/
filter.rs

1//! Session-based capability filtering.
2//!
3//! This module provides types for filtering tools, resources, and prompts
4//! based on session state. Different sessions can see different capabilities
5//! based on user identity, roles, API keys, or other session context.
6//!
7//! # Example
8//!
9//! ```rust
10//! use tower_mcp::{McpRouter, ToolBuilder, CallToolResult, CapabilityFilter, Tool, Filterable};
11//! use schemars::JsonSchema;
12//! use serde::Deserialize;
13//!
14//! #[derive(Debug, Deserialize, JsonSchema)]
15//! struct Input { value: String }
16//!
17//! let public_tool = ToolBuilder::new("public")
18//!     .description("Available to everyone")
19//!     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
20//!     .build();
21//!
22//! let admin_tool = ToolBuilder::new("admin")
23//!     .description("Admin only")
24//!     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
25//!     .build();
26//!
27//! let router = McpRouter::new()
28//!     .tool(public_tool)
29//!     .tool(admin_tool)
30//!     .tool_filter(CapabilityFilter::new(|_session, tool: &Tool| {
31//!         // In real code, check session.extensions() for auth claims
32//!         tool.name() != "admin"
33//!     }));
34//! ```
35
36use std::sync::Arc;
37
38use crate::error::{Error, JsonRpcError};
39use crate::prompt::Prompt;
40use crate::resource::Resource;
41use crate::session::SessionState;
42use crate::tool::Tool;
43
/// A named capability that a session-aware filter can show or hide.
///
/// This module implements the trait for [`Tool`], [`Resource`], and
/// [`Prompt`]. Every implementor must also be `Send + Sync`, as required
/// by the supertrait bounds.
pub trait Filterable: Send + Sync {
    /// Returns the name that identifies this capability.
    fn name(&self) -> &str;
}
51
52impl Filterable for Tool {
53    fn name(&self) -> &str {
54        &self.name
55    }
56}
57
58impl Filterable for Resource {
59    fn name(&self) -> &str {
60        &self.name
61    }
62}
63
64impl Filterable for Prompt {
65    fn name(&self) -> &str {
66        &self.name
67    }
68}
69
70/// Behavior when a filtered capability is accessed directly.
71#[derive(Clone, Default)]
72pub enum DenialBehavior {
73    /// Return "method not found" error -- hides the capability entirely.
74    ///
75    /// This is the default and recommended for security. Use this in
76    /// multi-tenant scenarios where tools should not be discoverable by
77    /// unauthorized users.
78    #[default]
79    NotFound,
80    /// Return an "unauthorized" error, revealing the capability exists.
81    ///
82    /// Use this when the client should know about the capability but is
83    /// not permitted to invoke it (e.g., premium features behind an
84    /// upgrade prompt).
85    Unauthorized,
86    /// Use a custom error generator for application-specific responses.
87    ///
88    /// Use this when you need custom status codes, domain-specific error
89    /// messages, or structured error payloads.
90    Custom(Arc<dyn Fn(&str) -> Error + Send + Sync>),
91}
92
93impl std::fmt::Debug for DenialBehavior {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        match self {
96            Self::NotFound => write!(f, "NotFound"),
97            Self::Unauthorized => write!(f, "Unauthorized"),
98            Self::Custom(_) => write!(f, "Custom(...)"),
99        }
100    }
101}
102
103impl DenialBehavior {
104    /// Create a custom denial behavior with the given error generator.
105    pub fn custom<F>(f: F) -> Self
106    where
107        F: Fn(&str) -> Error + Send + Sync + 'static,
108    {
109        Self::Custom(Arc::new(f))
110    }
111
112    /// Generate the appropriate error for a denied capability.
113    pub fn to_error(&self, name: &str) -> Error {
114        match self {
115            Self::NotFound => Error::JsonRpc(JsonRpcError::method_not_found(name)),
116            Self::Unauthorized => {
117                Error::JsonRpc(JsonRpcError::forbidden(format!("Unauthorized: {}", name)))
118            }
119            Self::Custom(f) => f(name),
120        }
121    }
122}
123
124/// A filter for capabilities based on session state.
125///
126/// Use this to control which tools, resources, or prompts are visible
127/// to each session.
128///
129/// # Example
130///
131/// ```rust
132/// use tower_mcp::{CapabilityFilter, DenialBehavior, Tool, Filterable};
133///
134/// // Filter that only shows tools starting with "public_"
135/// let filter = CapabilityFilter::new(|_session, tool: &Tool| {
136///     tool.name().starts_with("public_")
137/// });
138///
139/// // Filter with custom denial behavior
140/// let filter_with_401 = CapabilityFilter::new(|_session, tool: &Tool| {
141///     tool.name() != "admin"
142/// }).denial_behavior(DenialBehavior::Unauthorized);
143/// ```
144pub struct CapabilityFilter<T: Filterable> {
145    #[allow(clippy::type_complexity)]
146    filter: Arc<dyn Fn(&SessionState, &T) -> bool + Send + Sync>,
147    denial: DenialBehavior,
148}
149
150impl<T: Filterable> Clone for CapabilityFilter<T> {
151    fn clone(&self) -> Self {
152        Self {
153            filter: Arc::clone(&self.filter),
154            denial: self.denial.clone(),
155        }
156    }
157}
158
159impl<T: Filterable> std::fmt::Debug for CapabilityFilter<T> {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        f.debug_struct("CapabilityFilter")
162            .field("denial", &self.denial)
163            .finish_non_exhaustive()
164    }
165}
166
167impl<T: Filterable> CapabilityFilter<T> {
168    /// Create a new capability filter with the given predicate.
169    ///
170    /// The predicate receives the session state and capability, and returns
171    /// `true` if the capability should be visible to the session.
172    ///
173    /// # Example
174    ///
175    /// ```rust
176    /// use tower_mcp::{CapabilityFilter, Tool, Filterable};
177    ///
178    /// let filter = CapabilityFilter::new(|_session, tool: &Tool| {
179    ///     // Check session extensions for auth claims
180    ///     // session.extensions().get::<UserClaims>()...
181    ///     tool.name() != "admin_only"
182    /// });
183    /// ```
184    pub fn new<F>(filter: F) -> Self
185    where
186        F: Fn(&SessionState, &T) -> bool + Send + Sync + 'static,
187    {
188        Self {
189            filter: Arc::new(filter),
190            denial: DenialBehavior::default(),
191        }
192    }
193
194    /// Set the behavior when a filtered capability is accessed directly.
195    ///
196    /// Default is [`DenialBehavior::NotFound`].
197    ///
198    /// # Example
199    ///
200    /// ```rust
201    /// use tower_mcp::{CapabilityFilter, DenialBehavior, Tool, Filterable};
202    ///
203    /// let filter = CapabilityFilter::new(|_, tool: &Tool| tool.name() != "secret")
204    ///     .denial_behavior(DenialBehavior::Unauthorized);
205    /// ```
206    pub fn denial_behavior(mut self, behavior: DenialBehavior) -> Self {
207        self.denial = behavior;
208        self
209    }
210
211    /// Check if the given capability is visible to the session.
212    pub fn is_visible(&self, session: &SessionState, capability: &T) -> bool {
213        (self.filter)(session, capability)
214    }
215
216    /// Get the error to return when access is denied.
217    pub fn denial_error(&self, name: &str) -> Error {
218        self.denial.to_error(name)
219    }
220}
221
222/// Type alias for tool filters.
223pub type ToolFilter = CapabilityFilter<Tool>;
224
225/// Type alias for resource filters.
226pub type ResourceFilter = CapabilityFilter<Resource>;
227
228/// Type alias for prompt filters.
229pub type PromptFilter = CapabilityFilter<Prompt>;
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CallToolResult;
    use crate::tool::ToolBuilder;

    /// Builds a throwaway tool with the given name whose handler always
    /// answers with the text "ok".
    fn make_test_tool(name: &str) -> Tool {
        ToolBuilder::new(name)
            .description("Test tool")
            .handler(|_: serde_json::Value| async { Ok(CallToolResult::text("ok")) })
            .build()
    }

    /// The predicate's verdict decides visibility, tool by tool.
    #[test]
    fn test_filter_allows() {
        let session = SessionState::new();
        let filter = CapabilityFilter::new(|_, tool: &Tool| tool.name() != "blocked");

        let allowed_visible = filter.is_visible(&session, &make_test_tool("allowed"));
        let blocked_visible = filter.is_visible(&session, &make_test_tool("blocked"));

        assert!(allowed_visible);
        assert!(!blocked_visible);
    }

    /// `NotFound` produces a JSON-RPC error.
    #[test]
    fn test_denial_behavior_not_found() {
        let error = DenialBehavior::NotFound.to_error("test_tool");
        assert!(matches!(error, Error::JsonRpc(_)));
    }

    /// `Unauthorized` produces the forbidden code and an "Unauthorized" message.
    #[test]
    fn test_denial_behavior_unauthorized() {
        let error = DenialBehavior::Unauthorized.to_error("test_tool");
        if let Error::JsonRpc(e) = error {
            assert_eq!(e.code, -32007); // McpErrorCode::Forbidden
            assert!(e.message.contains("Unauthorized"));
        } else {
            panic!("Expected JsonRpc error");
        }
    }

    /// A custom generator receives the capability name and its error is
    /// returned untouched.
    #[test]
    fn test_denial_behavior_custom() {
        let generator = |name: &str| Error::tool(format!("No access to {}", name));
        let error = DenialBehavior::custom(generator).to_error("secret_tool");
        if let Error::Tool(e) = error {
            assert!(e.message.contains("No access to secret_tool"));
        } else {
            panic!("Expected Tool error");
        }
    }

    /// A cloned filter keeps evaluating the same predicate.
    #[test]
    fn test_filter_clone() {
        let original = CapabilityFilter::new(|_, _: &Tool| true);
        let cloned = original.clone();

        let tool = make_test_tool("test");
        let session = SessionState::new();
        assert!(cloned.is_visible(&session, &tool));
    }

    /// The configured denial behavior drives `denial_error`.
    #[test]
    fn test_filter_with_denial_behavior() {
        let filter = CapabilityFilter::new(|_, _: &Tool| false)
            .denial_behavior(DenialBehavior::Unauthorized);

        if let Error::JsonRpc(e) = filter.denial_error("test") {
            assert_eq!(e.code, -32007); // McpErrorCode::Forbidden
        } else {
            panic!("Expected JsonRpc error");
        }
    }
}