tower_mcp/filter.rs
1//! Session-based capability filtering.
2//!
3//! This module provides types for filtering tools, resources, and prompts
4//! based on session state. Different sessions can see different capabilities
5//! based on user identity, roles, API keys, or other session context.
6//!
7//! # Example
8//!
9//! ```rust
10//! use tower_mcp::{McpRouter, ToolBuilder, CallToolResult, CapabilityFilter, Tool, Filterable};
11//! use schemars::JsonSchema;
12//! use serde::Deserialize;
13//!
14//! #[derive(Debug, Deserialize, JsonSchema)]
15//! struct Input { value: String }
16//!
17//! let public_tool = ToolBuilder::new("public")
18//! .description("Available to everyone")
19//! .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
20//! .build();
21//!
22//! let admin_tool = ToolBuilder::new("admin")
23//! .description("Admin only")
24//! .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
25//! .build();
26//!
27//! let router = McpRouter::new()
28//! .tool(public_tool)
29//! .tool(admin_tool)
30//! .tool_filter(CapabilityFilter::new(|_session, tool: &Tool| {
31//! // In real code, check session.extensions() for auth claims
32//! tool.name() != "admin"
33//! }));
34//! ```
35
36use std::sync::Arc;
37
38use crate::error::{Error, JsonRpcError};
39use crate::prompt::Prompt;
40use crate::resource::Resource;
41use crate::session::SessionState;
42use crate::tool::Tool;
43
/// Trait for capabilities that can be filtered by session.
///
/// Implemented for [`Tool`], [`Resource`], and [`Prompt`]. Implementors must
/// be `Send + Sync`.
pub trait Filterable: Send + Sync {
    /// Returns the name of this capability.
    ///
    /// Filter predicates typically compare against this value (e.g.
    /// `tool.name() != "admin"` in the examples in this module).
    fn name(&self) -> &str;
}
51
/// Exposes the tool's `name` field to capability filters.
impl Filterable for Tool {
    fn name(&self) -> &str {
        &self.name
    }
}
57
/// Exposes the resource's `name` field to capability filters.
impl Filterable for Resource {
    fn name(&self) -> &str {
        &self.name
    }
}
63
/// Exposes the prompt's `name` field to capability filters.
impl Filterable for Prompt {
    fn name(&self) -> &str {
        &self.name
    }
}
69
70/// Behavior when a filtered capability is accessed directly.
71#[derive(Clone, Default)]
72pub enum DenialBehavior {
73 /// Return "method not found" error -- hides the capability entirely.
74 ///
75 /// This is the default and recommended for security. Use this in
76 /// multi-tenant scenarios where tools should not be discoverable by
77 /// unauthorized users.
78 #[default]
79 NotFound,
80 /// Return an "unauthorized" error, revealing the capability exists.
81 ///
82 /// Use this when the client should know about the capability but is
83 /// not permitted to invoke it (e.g., premium features behind an
84 /// upgrade prompt).
85 Unauthorized,
86 /// Use a custom error generator for application-specific responses.
87 ///
88 /// Use this when you need custom status codes, domain-specific error
89 /// messages, or structured error payloads.
90 Custom(Arc<dyn Fn(&str) -> Error + Send + Sync>),
91}
92
93impl std::fmt::Debug for DenialBehavior {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 match self {
96 Self::NotFound => write!(f, "NotFound"),
97 Self::Unauthorized => write!(f, "Unauthorized"),
98 Self::Custom(_) => write!(f, "Custom(...)"),
99 }
100 }
101}
102
103impl DenialBehavior {
104 /// Create a custom denial behavior with the given error generator.
105 pub fn custom<F>(f: F) -> Self
106 where
107 F: Fn(&str) -> Error + Send + Sync + 'static,
108 {
109 Self::Custom(Arc::new(f))
110 }
111
112 /// Generate the appropriate error for a denied capability.
113 pub fn to_error(&self, name: &str) -> Error {
114 match self {
115 Self::NotFound => Error::JsonRpc(JsonRpcError::method_not_found(name)),
116 Self::Unauthorized => {
117 Error::JsonRpc(JsonRpcError::forbidden(format!("Unauthorized: {}", name)))
118 }
119 Self::Custom(f) => f(name),
120 }
121 }
122}
123
124/// A filter for capabilities based on session state.
125///
126/// Use this to control which tools, resources, or prompts are visible
127/// to each session.
128///
129/// # Example
130///
131/// ```rust
132/// use tower_mcp::{CapabilityFilter, DenialBehavior, Tool, Filterable};
133///
134/// // Filter that only shows tools starting with "public_"
135/// let filter = CapabilityFilter::new(|_session, tool: &Tool| {
136/// tool.name().starts_with("public_")
137/// });
138///
139/// // Filter with custom denial behavior
140/// let filter_with_401 = CapabilityFilter::new(|_session, tool: &Tool| {
141/// tool.name() != "admin"
142/// }).denial_behavior(DenialBehavior::Unauthorized);
143/// ```
144pub struct CapabilityFilter<T: Filterable> {
145 #[allow(clippy::type_complexity)]
146 filter: Arc<dyn Fn(&SessionState, &T) -> bool + Send + Sync>,
147 denial: DenialBehavior,
148}
149
150impl<T: Filterable> Clone for CapabilityFilter<T> {
151 fn clone(&self) -> Self {
152 Self {
153 filter: Arc::clone(&self.filter),
154 denial: self.denial.clone(),
155 }
156 }
157}
158
159impl<T: Filterable> std::fmt::Debug for CapabilityFilter<T> {
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 f.debug_struct("CapabilityFilter")
162 .field("denial", &self.denial)
163 .finish_non_exhaustive()
164 }
165}
166
167impl<T: Filterable> CapabilityFilter<T> {
168 /// Create a new capability filter with the given predicate.
169 ///
170 /// The predicate receives the session state and capability, and returns
171 /// `true` if the capability should be visible to the session.
172 ///
173 /// # Example
174 ///
175 /// ```rust
176 /// use tower_mcp::{CapabilityFilter, Tool, Filterable};
177 ///
178 /// let filter = CapabilityFilter::new(|_session, tool: &Tool| {
179 /// // Check session extensions for auth claims
180 /// // session.extensions().get::<UserClaims>()...
181 /// tool.name() != "admin_only"
182 /// });
183 /// ```
184 pub fn new<F>(filter: F) -> Self
185 where
186 F: Fn(&SessionState, &T) -> bool + Send + Sync + 'static,
187 {
188 Self {
189 filter: Arc::new(filter),
190 denial: DenialBehavior::default(),
191 }
192 }
193
194 /// Set the behavior when a filtered capability is accessed directly.
195 ///
196 /// Default is [`DenialBehavior::NotFound`].
197 ///
198 /// # Example
199 ///
200 /// ```rust
201 /// use tower_mcp::{CapabilityFilter, DenialBehavior, Tool, Filterable};
202 ///
203 /// let filter = CapabilityFilter::new(|_, tool: &Tool| tool.name() != "secret")
204 /// .denial_behavior(DenialBehavior::Unauthorized);
205 /// ```
206 pub fn denial_behavior(mut self, behavior: DenialBehavior) -> Self {
207 self.denial = behavior;
208 self
209 }
210
211 /// Check if the given capability is visible to the session.
212 pub fn is_visible(&self, session: &SessionState, capability: &T) -> bool {
213 (self.filter)(session, capability)
214 }
215
216 /// Get the error to return when access is denied.
217 pub fn denial_error(&self, name: &str) -> Error {
218 self.denial.to_error(name)
219 }
220}
221
/// Type alias for tool filters: a [`CapabilityFilter`] over [`Tool`]s.
pub type ToolFilter = CapabilityFilter<Tool>;

/// Type alias for resource filters: a [`CapabilityFilter`] over [`Resource`]s.
pub type ResourceFilter = CapabilityFilter<Resource>;

/// Type alias for prompt filters: a [`CapabilityFilter`] over [`Prompt`]s.
pub type PromptFilter = CapabilityFilter<Prompt>;
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CallToolResult;
    use crate::tool::ToolBuilder;

    /// Builds a minimal tool with the given name for filter tests.
    fn make_test_tool(name: &str) -> Tool {
        ToolBuilder::new(name)
            .description("Test tool")
            .handler(|_: serde_json::Value| async { Ok(CallToolResult::text("ok")) })
            .build()
    }

    #[test]
    fn test_filter_allows() {
        let filter = CapabilityFilter::new(|_, tool: &Tool| tool.name() != "blocked");
        let session = SessionState::new();
        let allowed = make_test_tool("allowed");
        let blocked = make_test_tool("blocked");

        assert!(filter.is_visible(&session, &allowed));
        assert!(!filter.is_visible(&session, &blocked));
    }

    #[test]
    fn test_denial_behavior_not_found() {
        let behavior = DenialBehavior::NotFound;
        let error = behavior.to_error("test_tool");
        match error {
            // -32601 is the JSON-RPC 2.0 "Method not found" code.
            Error::JsonRpc(e) => assert_eq!(e.code, -32601),
            _ => panic!("Expected JsonRpc error"),
        }
    }

    #[test]
    fn test_denial_behavior_unauthorized() {
        let behavior = DenialBehavior::Unauthorized;
        let error = behavior.to_error("test_tool");
        match error {
            Error::JsonRpc(e) => {
                assert_eq!(e.code, -32007); // McpErrorCode::Forbidden
                assert!(e.message.contains("Unauthorized"));
                assert!(e.message.contains("test_tool"));
            }
            _ => panic!("Expected JsonRpc error"),
        }
    }

    #[test]
    fn test_denial_behavior_custom() {
        let behavior = DenialBehavior::custom(|name| Error::tool(format!("No access to {}", name)));
        let error = behavior.to_error("secret_tool");
        match error {
            Error::Tool(e) => {
                assert!(e.message.contains("No access to secret_tool"));
            }
            _ => panic!("Expected Tool error"),
        }
    }

    #[test]
    fn test_filter_clone() {
        let filter = CapabilityFilter::new(|_, _: &Tool| true);
        let cloned = filter.clone();
        let session = SessionState::new();
        let tool = make_test_tool("test");
        assert!(cloned.is_visible(&session, &tool));
    }

    #[test]
    fn test_filter_clone_shares_predicate() {
        let filter = CapabilityFilter::new(|_, tool: &Tool| tool.name() != "hidden");
        let cloned = filter.clone();
        let session = SessionState::new();
        let hidden = make_test_tool("hidden");
        assert!(!filter.is_visible(&session, &hidden));
        assert!(!cloned.is_visible(&session, &hidden));
    }

    #[test]
    fn test_filter_default_denial_is_not_found() {
        let filter = CapabilityFilter::new(|_, _: &Tool| false);
        match (
            filter.denial_error("test"),
            DenialBehavior::NotFound.to_error("test"),
        ) {
            (Error::JsonRpc(actual), Error::JsonRpc(expected)) => {
                assert_eq!(actual.code, expected.code);
            }
            _ => panic!("Expected JsonRpc errors"),
        }
    }

    #[test]
    fn test_filter_with_denial_behavior() {
        let filter = CapabilityFilter::new(|_, _: &Tool| false)
            .denial_behavior(DenialBehavior::Unauthorized);

        let error = filter.denial_error("test");
        match error {
            Error::JsonRpc(e) => assert_eq!(e.code, -32007), // McpErrorCode::Forbidden
            _ => panic!("Expected JsonRpc error"),
        }
    }

    #[test]
    fn test_debug_output() {
        assert_eq!(format!("{:?}", DenialBehavior::NotFound), "NotFound");
        assert_eq!(format!("{:?}", DenialBehavior::Unauthorized), "Unauthorized");
        let custom = DenialBehavior::custom(|name| Error::tool(format!("{}", name)));
        assert_eq!(format!("{:?}", custom), "Custom(...)");

        let filter = CapabilityFilter::new(|_, _: &Tool| true);
        assert_eq!(
            format!("{:?}", filter),
            "CapabilityFilter { denial: NotFound, .. }"
        );
    }
}