Skip to main content

turbomcp_server/
visibility.rs

1//! Progressive disclosure through component visibility control.
2//!
3//! This module provides the ability to dynamically show/hide tools, resources,
4//! and prompts based on tags. This enables patterns like:
5//!
6//! - Hiding admin tools until explicitly unlocked
7//! - Progressive disclosure of advanced features
8//! - Role-based component visibility
9//!
10//! # Memory Management
11//!
12//! Session visibility overrides are stored in a per-layer map keyed by session ID.
13//! **IMPORTANT**: You must ensure cleanup happens when sessions end to prevent
14//! memory leaks. Use one of these approaches:
15//!
16//! 1. **Recommended**: Use [`VisibilitySessionGuard`] which automatically cleans up on drop
17//! 2. **Manual**: Call [`VisibilityLayer::clear_session`] when a session disconnects
18//!
19//! # Example
20//!
21//! ```rust,ignore
22//! use turbomcp_server::visibility::{VisibilityLayer, VisibilitySessionGuard};
23//! use turbomcp_types::component::ComponentFilter;
24//!
25//! // Create a visibility layer that hides admin tools by default
26//! let layer = VisibilityLayer::new(server)
27//!     .with_disabled(ComponentFilter::with_tags(["admin"]));
28//!
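//! // Tools, resources, and prompts tagged with "admin" are always filtered out
//! // of `list_*` results, and calls/reads/gets for them fail with a not-found
//! // error unless the request's session has enabled the tag (see below)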
31//!
32//! async fn handle_session(layer: &VisibilityLayer<MyHandler>, session_id: &str) {
33//!     // Guard ensures cleanup when it goes out of scope
34//!     let _guard = layer.session_guard(session_id);
35//!
36//!     // Enable admin tools for this session
37//!     layer.enable_for_session(session_id, &["admin".to_string()]);
38//!
39//!     // ... handle requests ...
40//!
41//! } // Guard dropped here, session state automatically cleaned up
42//! ```
43
44use std::collections::HashSet;
45use std::sync::Arc;
46
47use parking_lot::RwLock;
48use turbomcp_core::context::RequestContext;
49use turbomcp_core::error::{McpError, McpResult};
50use turbomcp_core::handler::McpHandler;
51use turbomcp_types::{
52    ComponentFilter, ComponentMeta, Prompt, PromptResult, Resource, ResourceResult, Tool,
53    ToolResult,
54};
55
56/// Type alias for session visibility maps to reduce complexity.
57type SessionVisibilityMap = Arc<dashmap::DashMap<String, HashSet<String>>>;
58
59/// RAII guard that automatically cleans up session visibility state when dropped.
60///
61/// This is the recommended way to manage session visibility lifetime. Create a guard
62/// at the start of a session and let it clean up automatically when the session ends.
63///
64/// # Example
65///
66/// ```rust,ignore
67/// use turbomcp_server::visibility::VisibilityLayer;
68///
69/// async fn handle_connection<H: McpHandler>(layer: &VisibilityLayer<H>, session_id: &str) {
70///     let _guard = layer.session_guard(session_id);
71///
72///     // Enable admin tools for this session
73///     layer.enable_for_session(session_id, &["admin".to_string()]);
74///
75///     // ... handle requests ...
76///
77/// } // State automatically cleaned up here
78/// ```
79#[derive(Debug)]
80pub struct VisibilitySessionGuard {
81    session_id: String,
82    session_enabled: SessionVisibilityMap,
83    session_disabled: SessionVisibilityMap,
84}
85
86impl VisibilitySessionGuard {
87    /// Get the session ID this guard is managing.
88    pub fn session_id(&self) -> &str {
89        &self.session_id
90    }
91}
92
93impl Drop for VisibilitySessionGuard {
94    fn drop(&mut self) {
95        self.session_enabled.remove(&self.session_id);
96        self.session_disabled.remove(&self.session_id);
97    }
98}
99
100/// A visibility layer that wraps an `McpHandler` and filters components.
101///
102/// This allows per-session control over which tools, resources, and prompts
103/// are visible to clients through the `list_*` methods.
104///
105/// **Warning**: Session overrides stored in this layer must be manually cleaned up
106/// via [`clear_session`](Self::clear_session) or by using a [`VisibilitySessionGuard`]
107/// to prevent unbounded memory growth.
108#[derive(Clone)]
109pub struct VisibilityLayer<H> {
110    /// The wrapped handler
111    inner: H,
112    /// Globally disabled component filters
113    global_disabled: Arc<RwLock<Vec<ComponentFilter>>>,
114    /// Session-specific visibility overrides (keyed by session_id)
115    ///
116    /// **Warning**: Entries must be manually cleaned up via [`clear_session`](Self::clear_session)
117    /// or [`session_guard`](Self::session_guard) to prevent unbounded memory growth.
118    session_enabled: SessionVisibilityMap,
119    session_disabled: SessionVisibilityMap,
120}
121
122impl<H: std::fmt::Debug> std::fmt::Debug for VisibilityLayer<H> {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        f.debug_struct("VisibilityLayer")
125            .field("inner", &self.inner)
126            .field("global_disabled_count", &self.global_disabled.read().len())
127            .field("session_enabled_count", &self.session_enabled.len())
128            .field("session_disabled_count", &self.session_disabled.len())
129            .finish()
130    }
131}
132
133impl<H: McpHandler> VisibilityLayer<H> {
134    /// Create a new visibility layer wrapping the given handler.
135    pub fn new(inner: H) -> Self {
136        Self {
137            inner,
138            global_disabled: Arc::new(RwLock::new(Vec::new())),
139            session_enabled: Arc::new(dashmap::DashMap::new()),
140            session_disabled: Arc::new(dashmap::DashMap::new()),
141        }
142    }
143
144    /// Disable components matching the filter globally.
145    ///
146    /// This affects all sessions unless explicitly enabled per-session.
147    #[must_use]
148    pub fn with_disabled(self, filter: ComponentFilter) -> Self {
149        self.global_disabled.write().push(filter);
150        self
151    }
152
153    /// Disable components with the given tags globally.
154    #[must_use]
155    pub fn disable_tags<I, S>(self, tags: I) -> Self
156    where
157        I: IntoIterator<Item = S>,
158        S: Into<String>,
159    {
160        self.with_disabled(ComponentFilter::with_tags(tags))
161    }
162
163    /// Check if a component is visible given its metadata and session.
164    fn is_visible(&self, meta: &ComponentMeta, session_id: Option<&str>) -> bool {
165        // Check global disabled filters
166        let global_disabled = self.global_disabled.read();
167        let globally_hidden = global_disabled.iter().any(|filter| filter.matches(meta));
168
169        if !globally_hidden {
170            // Not globally hidden - check if session explicitly disabled it
171            if let Some(sid) = session_id
172                && let Some(disabled) = self.session_disabled.get(sid)
173                && meta.tags.iter().any(|t| disabled.contains(t))
174            {
175                return false;
176            }
177            return true;
178        }
179
180        // Globally hidden - check if session explicitly enabled it
181        if let Some(sid) = session_id
182            && let Some(enabled) = self.session_enabled.get(sid)
183            && meta.tags.iter().any(|t| enabled.contains(t))
184        {
185            return true;
186        }
187
188        false
189    }
190
191    /// Enable components with the given tags for a specific session.
192    pub fn enable_for_session(&self, session_id: &str, tags: &[String]) {
193        let mut entry = self
194            .session_enabled
195            .entry(session_id.to_string())
196            .or_default();
197        entry.extend(tags.iter().cloned());
198
199        // Remove from disabled if present
200        if let Some(mut disabled) = self.session_disabled.get_mut(session_id) {
201            for tag in tags {
202                disabled.remove(tag);
203            }
204        }
205    }
206
207    /// Disable components with the given tags for a specific session.
208    pub fn disable_for_session(&self, session_id: &str, tags: &[String]) {
209        let mut entry = self
210            .session_disabled
211            .entry(session_id.to_string())
212            .or_default();
213        entry.extend(tags.iter().cloned());
214
215        // Remove from enabled if present
216        if let Some(mut enabled) = self.session_enabled.get_mut(session_id) {
217            for tag in tags {
218                enabled.remove(tag);
219            }
220        }
221    }
222
223    /// Clear all session-specific overrides.
224    pub fn clear_session(&self, session_id: &str) {
225        self.session_enabled.remove(session_id);
226        self.session_disabled.remove(session_id);
227    }
228
229    /// Create an RAII guard that automatically cleans up session state on drop.
230    ///
231    /// This is the recommended way to manage session visibility lifetime.
232    ///
233    /// # Example
234    ///
235    /// ```rust,ignore
236    /// async fn handle_connection<H: McpHandler>(layer: &VisibilityLayer<H>, session_id: &str) {
237    ///     let _guard = layer.session_guard(session_id);
238    ///
239    ///     layer.enable_for_session(session_id, &["admin".to_string()]);
240    ///
241    ///     // ... handle requests ...
242    ///
243    /// } // State automatically cleaned up here
244    /// ```
245    pub fn session_guard(&self, session_id: impl Into<String>) -> VisibilitySessionGuard {
246        VisibilitySessionGuard {
247            session_id: session_id.into(),
248            session_enabled: Arc::clone(&self.session_enabled),
249            session_disabled: Arc::clone(&self.session_disabled),
250        }
251    }
252
253    /// Get the number of active sessions with visibility overrides.
254    ///
255    /// This is useful for monitoring memory usage.
256    pub fn active_sessions_count(&self) -> usize {
257        // Count unique session IDs across both maps
258        let mut sessions = HashSet::new();
259        for entry in self.session_enabled.iter() {
260            sessions.insert(entry.key().clone());
261        }
262        for entry in self.session_disabled.iter() {
263            sessions.insert(entry.key().clone());
264        }
265        sessions.len()
266    }
267
268    /// Get a reference to the inner handler.
269    pub fn inner(&self) -> &H {
270        &self.inner
271    }
272
273    /// Get a mutable reference to the inner handler.
274    pub fn inner_mut(&mut self) -> &mut H {
275        &mut self.inner
276    }
277
278    /// Unwrap the layer and return the inner handler.
279    pub fn into_inner(self) -> H {
280        self.inner
281    }
282}
283
284#[allow(clippy::manual_async_fn)]
285impl<H: McpHandler> McpHandler for VisibilityLayer<H> {
286    fn server_info(&self) -> turbomcp_types::ServerInfo {
287        self.inner.server_info()
288    }
289
290    fn list_tools(&self) -> Vec<Tool> {
291        self.inner
292            .list_tools()
293            .into_iter()
294            .filter(|tool| {
295                let meta = ComponentMeta::from_meta_value(tool.meta.as_ref());
296                self.is_visible(&meta, None)
297            })
298            .collect()
299    }
300
301    fn list_resources(&self) -> Vec<Resource> {
302        self.inner
303            .list_resources()
304            .into_iter()
305            .filter(|resource| {
306                let meta = ComponentMeta::from_meta_value(resource.meta.as_ref());
307                self.is_visible(&meta, None)
308            })
309            .collect()
310    }
311
312    fn list_prompts(&self) -> Vec<Prompt> {
313        self.inner
314            .list_prompts()
315            .into_iter()
316            .filter(|prompt| {
317                let meta = ComponentMeta::from_meta_value(prompt.meta.as_ref());
318                self.is_visible(&meta, None)
319            })
320            .collect()
321    }
322
323    fn call_tool<'a>(
324        &'a self,
325        name: &'a str,
326        args: serde_json::Value,
327        ctx: &'a RequestContext,
328    ) -> impl std::future::Future<Output = McpResult<ToolResult>> + turbomcp_core::marker::MaybeSend + 'a
329    {
330        async move {
331            // Check if tool is visible for this session
332            let tools = self.inner.list_tools();
333            let tool = tools.iter().find(|t| t.name == name);
334
335            if let Some(tool) = tool {
336                let meta = ComponentMeta::from_meta_value(tool.meta.as_ref());
337                if !self.is_visible(&meta, ctx.get_metadata("session_id")) {
338                    return Err(McpError::tool_not_found(name));
339                }
340            }
341
342            self.inner.call_tool(name, args, ctx).await
343        }
344    }
345
346    fn read_resource<'a>(
347        &'a self,
348        uri: &'a str,
349        ctx: &'a RequestContext,
350    ) -> impl std::future::Future<Output = McpResult<ResourceResult>>
351    + turbomcp_core::marker::MaybeSend
352    + 'a {
353        async move {
354            // Check if resource is visible for this session
355            let resources = self.inner.list_resources();
356            let resource = resources.iter().find(|r| r.uri == uri);
357
358            if let Some(resource) = resource {
359                let meta = ComponentMeta::from_meta_value(resource.meta.as_ref());
360                if !self.is_visible(&meta, ctx.get_metadata("session_id")) {
361                    return Err(McpError::resource_not_found(uri));
362                }
363            }
364
365            self.inner.read_resource(uri, ctx).await
366        }
367    }
368
369    fn get_prompt<'a>(
370        &'a self,
371        name: &'a str,
372        args: Option<serde_json::Value>,
373        ctx: &'a RequestContext,
374    ) -> impl std::future::Future<Output = McpResult<PromptResult>> + turbomcp_core::marker::MaybeSend + 'a
375    {
376        async move {
377            // Check if prompt is visible for this session
378            let prompts = self.inner.list_prompts();
379            let prompt = prompts.iter().find(|p| p.name == name);
380
381            if let Some(prompt) = prompt {
382                let meta = ComponentMeta::from_meta_value(prompt.meta.as_ref());
383                if !self.is_visible(&meta, ctx.get_metadata("session_id")) {
384                    return Err(McpError::prompt_not_found(name));
385                }
386            }
387
388            self.inner.get_prompt(name, args, ctx).await
389        }
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler exposing one "public"-tagged tool and one "admin"-tagged tool.
    #[derive(Clone, Debug)]
    struct MockHandler;

    // Explicit `impl Future` signatures mirror the layer's own `McpHandler` impl.
    #[allow(clippy::manual_async_fn)]
    impl McpHandler for MockHandler {
        fn server_info(&self) -> turbomcp_types::ServerInfo {
            turbomcp_types::ServerInfo::new("test", "1.0.0")
        }

        fn list_tools(&self) -> Vec<Tool> {
            vec![
                Tool {
                    name: "public_tool".to_string(),
                    description: Some("Public tool".to_string()),
                    meta: Some({
                        let mut m = std::collections::HashMap::new();
                        m.insert("tags".to_string(), serde_json::json!(["public"]));
                        m
                    }),
                    ..Default::default()
                },
                Tool {
                    name: "admin_tool".to_string(),
                    description: Some("Admin tool".to_string()),
                    meta: Some({
                        let mut m = std::collections::HashMap::new();
                        m.insert("tags".to_string(), serde_json::json!(["admin"]));
                        m
                    }),
                    ..Default::default()
                },
            ]
        }

        fn list_resources(&self) -> Vec<Resource> {
            vec![]
        }

        fn list_prompts(&self) -> Vec<Prompt> {
            vec![]
        }

        fn call_tool<'a>(
            &'a self,
            name: &'a str,
            _args: serde_json::Value,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<ToolResult>>
        + turbomcp_core::marker::MaybeSend
        + 'a {
            async move { Ok(ToolResult::text(format!("Called {}", name))) }
        }

        fn read_resource<'a>(
            &'a self,
            _uri: &'a str,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<ResourceResult>>
        + turbomcp_core::marker::MaybeSend
        + 'a {
            async move { Err(McpError::resource_not_found("not found")) }
        }

        fn get_prompt<'a>(
            &'a self,
            _name: &'a str,
            _args: Option<serde_json::Value>,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<PromptResult>>
        + turbomcp_core::marker::MaybeSend
        + 'a {
            async move { Err(McpError::prompt_not_found("not found")) }
        }
    }

    #[test]
    fn test_visibility_layer_hides_admin() {
        let layer = VisibilityLayer::new(MockHandler).disable_tags(["admin"]);

        let tools = layer.list_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "public_tool");
    }

    #[test]
    fn test_visibility_layer_shows_all_by_default() {
        let layer = VisibilityLayer::new(MockHandler);

        let tools = layer.list_tools();
        assert_eq!(tools.len(), 2);
    }

    #[test]
    fn test_session_enable_override() {
        let layer = VisibilityLayer::new(MockHandler).disable_tags(["admin"]);

        // Initially hidden
        assert_eq!(layer.list_tools().len(), 1);

        // Enable for session
        layer.enable_for_session("session1", &["admin".to_string()]);

        // Still hidden in list_tools (doesn't take session context)
        // but call_tool would work with session context
        assert_eq!(layer.list_tools().len(), 1);

        // Cleanup
        layer.clear_session("session1");
    }

    /// Session overrides change per-session visibility without leaking across
    /// sessions or into session-less checks.
    #[test]
    fn test_session_overrides_drive_is_visible() {
        let layer = VisibilityLayer::new(MockHandler).disable_tags(["admin"]);

        let tools = MockHandler.list_tools();
        let admin_tool = tools
            .iter()
            .find(|t| t.name == "admin_tool")
            .expect("mock defines admin_tool");
        let public_tool = tools
            .iter()
            .find(|t| t.name == "public_tool")
            .expect("mock defines public_tool");
        let admin = ComponentMeta::from_meta_value(admin_tool.meta.as_ref());
        let public = ComponentMeta::from_meta_value(public_tool.meta.as_ref());

        // Without overrides only the global filter applies.
        assert!(!layer.is_visible(&admin, None));
        assert!(!layer.is_visible(&admin, Some("s1")));
        assert!(layer.is_visible(&public, Some("s1")));

        // Enabling "admin" unhides it for that session only.
        layer.enable_for_session("s1", &["admin".to_string()]);
        assert!(layer.is_visible(&admin, Some("s1")));
        assert!(!layer.is_visible(&admin, Some("s2")));
        assert!(!layer.is_visible(&admin, None));

        // Disabling a globally visible tag hides it for that session only.
        layer.disable_for_session("s1", &["public".to_string()]);
        assert!(!layer.is_visible(&public, Some("s1")));
        assert!(layer.is_visible(&public, Some("s2")));

        // Disabling a previously enabled tag revokes the enable.
        layer.disable_for_session("s1", &["admin".to_string()]);
        assert!(!layer.is_visible(&admin, Some("s1")));

        // Re-enabling a previously disabled tag lifts the session-level hide.
        layer.enable_for_session("s1", &["public".to_string()]);
        assert!(layer.is_visible(&public, Some("s1")));

        layer.clear_session("s1");
        assert_eq!(layer.active_sessions_count(), 0);
    }

    #[test]
    fn test_session_guard_cleanup() {
        let layer = VisibilityLayer::new(MockHandler).disable_tags(["admin"]);

        {
            let _guard = layer.session_guard("guard-session");

            // Enable admin for this session
            layer.enable_for_session("guard-session", &["admin".to_string()]);
            layer.disable_for_session("guard-session", &["public".to_string()]);

            // Session state exists
            assert!(layer.active_sessions_count() > 0);
        }

        // After guard drops, session state should be cleaned up
        assert_eq!(layer.active_sessions_count(), 0);
    }

    /// Dropping a guard must only clear its own session's overrides.
    #[test]
    fn test_session_guard_leaves_other_sessions() {
        let layer = VisibilityLayer::new(MockHandler);
        layer.enable_for_session("keep", &["tag".to_string()]);

        {
            let guard = layer.session_guard("drop-me");
            assert_eq!(guard.session_id(), "drop-me");
            layer.enable_for_session("drop-me", &["tag".to_string()]);
            assert_eq!(layer.active_sessions_count(), 2);
        }

        assert_eq!(layer.active_sessions_count(), 1);
        layer.clear_session("keep");
        assert_eq!(layer.active_sessions_count(), 0);
    }

    #[test]
    fn test_active_sessions_count() {
        let layer = VisibilityLayer::new(MockHandler);

        assert_eq!(layer.active_sessions_count(), 0);

        layer.enable_for_session("session1", &["tag1".to_string()]);
        assert_eq!(layer.active_sessions_count(), 1);

        layer.disable_for_session("session2", &["tag2".to_string()]);
        assert_eq!(layer.active_sessions_count(), 2);

        // Same session, different tag - should not increase count
        layer.enable_for_session("session1", &["tag2".to_string()]);
        assert_eq!(layer.active_sessions_count(), 2);

        layer.clear_session("session1");
        assert_eq!(layer.active_sessions_count(), 1);

        layer.clear_session("session2");
        assert_eq!(layer.active_sessions_count(), 0);
    }
}