Skip to main content

tower_mcp/
filter.rs

1//! Session-based capability filtering.
2//!
3//! This module provides types for filtering tools, resources, and prompts
4//! based on session state. Different sessions can see different capabilities
5//! based on user identity, roles, API keys, or other session context.
6//!
7//! # Example
8//!
9//! ```rust
10//! use tower_mcp::{McpRouter, ToolBuilder, CallToolResult, CapabilityFilter, Tool, Filterable};
11//! use schemars::JsonSchema;
12//! use serde::Deserialize;
13//!
14//! #[derive(Debug, Deserialize, JsonSchema)]
15//! struct Input { value: String }
16//!
17//! let public_tool = ToolBuilder::new("public")
18//!     .description("Available to everyone")
19//!     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
20//!     .build()
21//!     .unwrap();
22//!
23//! let admin_tool = ToolBuilder::new("admin")
24//!     .description("Admin only")
25//!     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
26//!     .build()
27//!     .unwrap();
28//!
29//! let router = McpRouter::new()
30//!     .tool(public_tool)
31//!     .tool(admin_tool)
32//!     .tool_filter(CapabilityFilter::new(|_session, tool: &Tool| {
33//!         // In real code, check session.extensions() for auth claims
34//!         tool.name() != "admin"
35//!     }));
36//! ```
37
38use std::sync::Arc;
39
40use crate::error::{Error, JsonRpcError};
41use crate::prompt::Prompt;
42use crate::resource::Resource;
43use crate::session::SessionState;
44use crate::tool::Tool;
45
/// A capability (tool, resource, or prompt) that a [`CapabilityFilter`] can
/// show to or hide from individual sessions.
///
/// Implementors must be `Send + Sync`, so filtered capabilities can be shared
/// between threads. Implementations are provided for [`Tool`], [`Resource`],
/// and [`Prompt`].
pub trait Filterable
where
    Self: Send + Sync,
{
    /// The identifying name of this capability.
    fn name(&self) -> &str;
}
53
54impl Filterable for Tool {
55    fn name(&self) -> &str {
56        &self.name
57    }
58}
59
60impl Filterable for Resource {
61    fn name(&self) -> &str {
62        &self.name
63    }
64}
65
66impl Filterable for Prompt {
67    fn name(&self) -> &str {
68        &self.name
69    }
70}
71
72/// Behavior when a filtered capability is accessed directly.
73#[derive(Clone, Default)]
74pub enum DenialBehavior {
75    /// Return "method not found" error - don't reveal the capability exists.
76    /// This is the default and recommended for security.
77    #[default]
78    NotFound,
79    /// Return an "unauthorized" error, revealing the capability exists.
80    Unauthorized,
81    /// Use a custom error generator.
82    Custom(Arc<dyn Fn(&str) -> Error + Send + Sync>),
83}
84
85impl std::fmt::Debug for DenialBehavior {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        match self {
88            Self::NotFound => write!(f, "NotFound"),
89            Self::Unauthorized => write!(f, "Unauthorized"),
90            Self::Custom(_) => write!(f, "Custom(...)"),
91        }
92    }
93}
94
95impl DenialBehavior {
96    /// Create a custom denial behavior with the given error generator.
97    pub fn custom<F>(f: F) -> Self
98    where
99        F: Fn(&str) -> Error + Send + Sync + 'static,
100    {
101        Self::Custom(Arc::new(f))
102    }
103
104    /// Generate the appropriate error for a denied capability.
105    pub fn to_error(&self, name: &str) -> Error {
106        match self {
107            Self::NotFound => Error::JsonRpc(JsonRpcError::method_not_found(name)),
108            Self::Unauthorized => {
109                Error::JsonRpc(JsonRpcError::forbidden(format!("Unauthorized: {}", name)))
110            }
111            Self::Custom(f) => f(name),
112        }
113    }
114}
115
116/// A filter for capabilities based on session state.
117///
118/// Use this to control which tools, resources, or prompts are visible
119/// to each session.
120///
121/// # Example
122///
123/// ```rust
124/// use tower_mcp::{CapabilityFilter, DenialBehavior, Tool, Filterable};
125///
126/// // Filter that only shows tools starting with "public_"
127/// let filter = CapabilityFilter::new(|_session, tool: &Tool| {
128///     tool.name().starts_with("public_")
129/// });
130///
131/// // Filter with custom denial behavior
132/// let filter_with_401 = CapabilityFilter::new(|_session, tool: &Tool| {
133///     tool.name() != "admin"
134/// }).denial_behavior(DenialBehavior::Unauthorized);
135/// ```
136pub struct CapabilityFilter<T: Filterable> {
137    #[allow(clippy::type_complexity)]
138    filter: Arc<dyn Fn(&SessionState, &T) -> bool + Send + Sync>,
139    denial: DenialBehavior,
140}
141
142impl<T: Filterable> Clone for CapabilityFilter<T> {
143    fn clone(&self) -> Self {
144        Self {
145            filter: Arc::clone(&self.filter),
146            denial: self.denial.clone(),
147        }
148    }
149}
150
151impl<T: Filterable> std::fmt::Debug for CapabilityFilter<T> {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        f.debug_struct("CapabilityFilter")
154            .field("denial", &self.denial)
155            .finish_non_exhaustive()
156    }
157}
158
159impl<T: Filterable> CapabilityFilter<T> {
160    /// Create a new capability filter with the given predicate.
161    ///
162    /// The predicate receives the session state and capability, and returns
163    /// `true` if the capability should be visible to the session.
164    ///
165    /// # Example
166    ///
167    /// ```rust
168    /// use tower_mcp::{CapabilityFilter, Tool, Filterable};
169    ///
170    /// let filter = CapabilityFilter::new(|_session, tool: &Tool| {
171    ///     // Check session extensions for auth claims
172    ///     // session.extensions().get::<UserClaims>()...
173    ///     tool.name() != "admin_only"
174    /// });
175    /// ```
176    pub fn new<F>(filter: F) -> Self
177    where
178        F: Fn(&SessionState, &T) -> bool + Send + Sync + 'static,
179    {
180        Self {
181            filter: Arc::new(filter),
182            denial: DenialBehavior::default(),
183        }
184    }
185
186    /// Set the behavior when a filtered capability is accessed directly.
187    ///
188    /// Default is [`DenialBehavior::NotFound`].
189    ///
190    /// # Example
191    ///
192    /// ```rust
193    /// use tower_mcp::{CapabilityFilter, DenialBehavior, Tool, Filterable};
194    ///
195    /// let filter = CapabilityFilter::new(|_, tool: &Tool| tool.name() != "secret")
196    ///     .denial_behavior(DenialBehavior::Unauthorized);
197    /// ```
198    pub fn denial_behavior(mut self, behavior: DenialBehavior) -> Self {
199        self.denial = behavior;
200        self
201    }
202
203    /// Check if the given capability is visible to the session.
204    pub fn is_visible(&self, session: &SessionState, capability: &T) -> bool {
205        (self.filter)(session, capability)
206    }
207
208    /// Get the error to return when access is denied.
209    pub fn denial_error(&self, name: &str) -> Error {
210        self.denial.to_error(name)
211    }
212}
213
214/// Type alias for tool filters.
215pub type ToolFilter = CapabilityFilter<Tool>;
216
217/// Type alias for resource filters.
218pub type ResourceFilter = CapabilityFilter<Resource>;
219
220/// Type alias for prompt filters.
221pub type PromptFilter = CapabilityFilter<Prompt>;
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CallToolResult;
    use crate::tool::ToolBuilder;

    /// JSON-RPC 2.0 "Method not found" code.
    const METHOD_NOT_FOUND: i64 = -32601;
    /// `McpErrorCode::Forbidden`.
    const FORBIDDEN: i64 = -32007;

    fn make_test_tool(name: &str) -> Tool {
        ToolBuilder::new(name)
            .description("Test tool")
            .handler(|_: serde_json::Value| async { Ok(CallToolResult::text("ok")) })
            .build()
            .unwrap()
    }

    #[test]
    fn test_filter_allows() {
        let filter = CapabilityFilter::new(|_, tool: &Tool| tool.name() != "blocked");
        let session = SessionState::new();
        let allowed = make_test_tool("allowed");
        let blocked = make_test_tool("blocked");

        assert!(filter.is_visible(&session, &allowed));
        assert!(!filter.is_visible(&session, &blocked));
    }

    #[test]
    fn test_denial_behavior_not_found() {
        let behavior = DenialBehavior::NotFound;
        let error = behavior.to_error("test_tool");
        match error {
            Error::JsonRpc(e) => assert_eq!(i64::from(e.code), METHOD_NOT_FOUND),
            _ => panic!("Expected JsonRpc error"),
        }
    }

    #[test]
    fn test_denial_behavior_unauthorized() {
        let behavior = DenialBehavior::Unauthorized;
        let error = behavior.to_error("test_tool");
        match error {
            Error::JsonRpc(e) => {
                assert_eq!(i64::from(e.code), FORBIDDEN);
                assert!(e.message.contains("Unauthorized: test_tool"));
            }
            _ => panic!("Expected JsonRpc error"),
        }
    }

    #[test]
    fn test_denial_behavior_custom() {
        let behavior = DenialBehavior::custom(|name| Error::tool(format!("No access to {}", name)));
        let error = behavior.to_error("secret_tool");
        match error {
            Error::Tool(e) => {
                assert!(e.message.contains("No access to secret_tool"));
            }
            _ => panic!("Expected Tool error"),
        }
    }

    #[test]
    fn test_default_denial_is_not_found() {
        assert!(matches!(DenialBehavior::default(), DenialBehavior::NotFound));

        // A freshly constructed filter must hide denied capabilities by default.
        let filter = CapabilityFilter::new(|_, _: &Tool| false);
        match filter.denial_error("hidden") {
            Error::JsonRpc(e) => assert_eq!(i64::from(e.code), METHOD_NOT_FOUND),
            _ => panic!("Expected JsonRpc error"),
        }
    }

    #[test]
    fn test_debug_output() {
        assert_eq!(format!("{:?}", DenialBehavior::NotFound), "NotFound");
        assert_eq!(format!("{:?}", DenialBehavior::Unauthorized), "Unauthorized");
        let custom = DenialBehavior::custom(|name| Error::tool(name.to_string()));
        assert_eq!(format!("{:?}", custom), "Custom(...)");

        let filter = CapabilityFilter::new(|_, _: &Tool| true);
        assert_eq!(
            format!("{:?}", filter),
            "CapabilityFilter { denial: NotFound, .. }"
        );
    }

    #[test]
    fn test_filter_clone() {
        let filter = CapabilityFilter::new(|_, _: &Tool| true);
        let cloned = filter.clone();
        let session = SessionState::new();
        let tool = make_test_tool("test");
        assert!(cloned.is_visible(&session, &tool));
    }

    #[test]
    fn test_filter_clone_keeps_denial_behavior() {
        let filter = CapabilityFilter::new(|_, _: &Tool| false)
            .denial_behavior(DenialBehavior::Unauthorized);
        let cloned = filter.clone();
        match cloned.denial_error("test") {
            Error::JsonRpc(e) => assert_eq!(i64::from(e.code), FORBIDDEN),
            _ => panic!("Expected JsonRpc error"),
        }
    }

    #[test]
    fn test_filter_with_denial_behavior() {
        let filter = CapabilityFilter::new(|_, _: &Tool| false)
            .denial_behavior(DenialBehavior::Unauthorized);

        let error = filter.denial_error("test");
        match error {
            Error::JsonRpc(e) => assert_eq!(i64::from(e.code), FORBIDDEN),
            _ => panic!("Expected JsonRpc error"),
        }
    }
}