1use std::collections::HashSet;
45use std::sync::Arc;
46
47use parking_lot::RwLock;
48use turbomcp_core::context::RequestContext;
49use turbomcp_core::error::{McpError, McpResult};
50use turbomcp_core::handler::McpHandler;
51use turbomcp_types::{
52 ComponentFilter, ComponentMeta, Prompt, PromptResult, Resource, ResourceResult, Tool,
53 ToolResult,
54};
55
/// Per-session tag sets keyed by session id.
///
/// Wrapped in an `Arc` so the same map is shared by a `VisibilityLayer`, its
/// clones, and every `VisibilitySessionGuard` handed out by `session_guard`.
type SessionVisibilityMap = Arc<dashmap::DashMap<String, HashSet<String>>>;
58
59#[derive(Debug)]
80pub struct VisibilitySessionGuard {
81 session_id: String,
82 session_enabled: SessionVisibilityMap,
83 session_disabled: SessionVisibilityMap,
84}
85
86impl VisibilitySessionGuard {
87 pub fn session_id(&self) -> &str {
89 &self.session_id
90 }
91}
92
93impl Drop for VisibilitySessionGuard {
94 fn drop(&mut self) {
95 self.session_enabled.remove(&self.session_id);
96 self.session_disabled.remove(&self.session_id);
97 }
98}
99
/// Handler wrapper that hides tools, resources and prompts of an inner
/// [`McpHandler`] using global filters plus per-session tag overrides.
///
/// Every field except `inner` is an `Arc`, so clones of a layer share the
/// same global filters and session state (the inner `H` is cloned as-is).
#[derive(Clone)]
pub struct VisibilityLayer<H> {
    /// The wrapped handler that actually serves requests.
    inner: H,
    /// Global filters; a component matching any of them is hidden by default.
    global_disabled: Arc<RwLock<Vec<ComponentFilter>>>,
    /// Per-session tags that make globally hidden components visible again.
    session_enabled: SessionVisibilityMap,
    /// Per-session tags that hide components which are otherwise visible.
    session_disabled: SessionVisibilityMap,
}
121
122impl<H: std::fmt::Debug> std::fmt::Debug for VisibilityLayer<H> {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 f.debug_struct("VisibilityLayer")
125 .field("inner", &self.inner)
126 .field("global_disabled_count", &self.global_disabled.read().len())
127 .field("session_enabled_count", &self.session_enabled.len())
128 .field("session_disabled_count", &self.session_disabled.len())
129 .finish()
130 }
131}
132
133impl<H: McpHandler> VisibilityLayer<H> {
134 pub fn new(inner: H) -> Self {
136 Self {
137 inner,
138 global_disabled: Arc::new(RwLock::new(Vec::new())),
139 session_enabled: Arc::new(dashmap::DashMap::new()),
140 session_disabled: Arc::new(dashmap::DashMap::new()),
141 }
142 }
143
144 #[must_use]
148 pub fn with_disabled(self, filter: ComponentFilter) -> Self {
149 self.global_disabled.write().push(filter);
150 self
151 }
152
153 #[must_use]
155 pub fn disable_tags<I, S>(self, tags: I) -> Self
156 where
157 I: IntoIterator<Item = S>,
158 S: Into<String>,
159 {
160 self.with_disabled(ComponentFilter::with_tags(tags))
161 }
162
163 fn is_visible(&self, meta: &ComponentMeta, session_id: Option<&str>) -> bool {
165 let global_disabled = self.global_disabled.read();
167 let globally_hidden = global_disabled.iter().any(|filter| filter.matches(meta));
168
169 if !globally_hidden {
170 if let Some(sid) = session_id
172 && let Some(disabled) = self.session_disabled.get(sid)
173 && meta.tags.iter().any(|t| disabled.contains(t))
174 {
175 return false;
176 }
177 return true;
178 }
179
180 if let Some(sid) = session_id
182 && let Some(enabled) = self.session_enabled.get(sid)
183 && meta.tags.iter().any(|t| enabled.contains(t))
184 {
185 return true;
186 }
187
188 false
189 }
190
191 pub fn enable_for_session(&self, session_id: &str, tags: &[String]) {
193 let mut entry = self
194 .session_enabled
195 .entry(session_id.to_string())
196 .or_default();
197 entry.extend(tags.iter().cloned());
198
199 if let Some(mut disabled) = self.session_disabled.get_mut(session_id) {
201 for tag in tags {
202 disabled.remove(tag);
203 }
204 }
205 }
206
207 pub fn disable_for_session(&self, session_id: &str, tags: &[String]) {
209 let mut entry = self
210 .session_disabled
211 .entry(session_id.to_string())
212 .or_default();
213 entry.extend(tags.iter().cloned());
214
215 if let Some(mut enabled) = self.session_enabled.get_mut(session_id) {
217 for tag in tags {
218 enabled.remove(tag);
219 }
220 }
221 }
222
223 pub fn clear_session(&self, session_id: &str) {
225 self.session_enabled.remove(session_id);
226 self.session_disabled.remove(session_id);
227 }
228
229 pub fn session_guard(&self, session_id: impl Into<String>) -> VisibilitySessionGuard {
246 VisibilitySessionGuard {
247 session_id: session_id.into(),
248 session_enabled: Arc::clone(&self.session_enabled),
249 session_disabled: Arc::clone(&self.session_disabled),
250 }
251 }
252
253 pub fn active_sessions_count(&self) -> usize {
257 let mut sessions = HashSet::new();
259 for entry in self.session_enabled.iter() {
260 sessions.insert(entry.key().clone());
261 }
262 for entry in self.session_disabled.iter() {
263 sessions.insert(entry.key().clone());
264 }
265 sessions.len()
266 }
267
268 pub fn inner(&self) -> &H {
270 &self.inner
271 }
272
273 pub fn inner_mut(&mut self) -> &mut H {
275 &mut self.inner
276 }
277
278 pub fn into_inner(self) -> H {
280 self.inner
281 }
282}
283
284#[allow(clippy::manual_async_fn)]
285impl<H: McpHandler> McpHandler for VisibilityLayer<H> {
286 fn server_info(&self) -> turbomcp_types::ServerInfo {
287 self.inner.server_info()
288 }
289
290 fn list_tools(&self) -> Vec<Tool> {
291 self.inner
292 .list_tools()
293 .into_iter()
294 .filter(|tool| {
295 let meta = ComponentMeta::from_meta_value(tool.meta.as_ref());
296 self.is_visible(&meta, None)
297 })
298 .collect()
299 }
300
301 fn list_resources(&self) -> Vec<Resource> {
302 self.inner
303 .list_resources()
304 .into_iter()
305 .filter(|resource| {
306 let meta = ComponentMeta::from_meta_value(resource.meta.as_ref());
307 self.is_visible(&meta, None)
308 })
309 .collect()
310 }
311
312 fn list_prompts(&self) -> Vec<Prompt> {
313 self.inner
314 .list_prompts()
315 .into_iter()
316 .filter(|prompt| {
317 let meta = ComponentMeta::from_meta_value(prompt.meta.as_ref());
318 self.is_visible(&meta, None)
319 })
320 .collect()
321 }
322
323 fn call_tool<'a>(
324 &'a self,
325 name: &'a str,
326 args: serde_json::Value,
327 ctx: &'a RequestContext,
328 ) -> impl std::future::Future<Output = McpResult<ToolResult>> + turbomcp_core::marker::MaybeSend + 'a
329 {
330 async move {
331 let tools = self.inner.list_tools();
333 let tool = tools.iter().find(|t| t.name == name);
334
335 if let Some(tool) = tool {
336 let meta = ComponentMeta::from_meta_value(tool.meta.as_ref());
337 if !self.is_visible(&meta, ctx.get_metadata("session_id")) {
338 return Err(McpError::tool_not_found(name));
339 }
340 }
341
342 self.inner.call_tool(name, args, ctx).await
343 }
344 }
345
346 fn read_resource<'a>(
347 &'a self,
348 uri: &'a str,
349 ctx: &'a RequestContext,
350 ) -> impl std::future::Future<Output = McpResult<ResourceResult>>
351 + turbomcp_core::marker::MaybeSend
352 + 'a {
353 async move {
354 let resources = self.inner.list_resources();
356 let resource = resources.iter().find(|r| r.uri == uri);
357
358 if let Some(resource) = resource {
359 let meta = ComponentMeta::from_meta_value(resource.meta.as_ref());
360 if !self.is_visible(&meta, ctx.get_metadata("session_id")) {
361 return Err(McpError::resource_not_found(uri));
362 }
363 }
364
365 self.inner.read_resource(uri, ctx).await
366 }
367 }
368
369 fn get_prompt<'a>(
370 &'a self,
371 name: &'a str,
372 args: Option<serde_json::Value>,
373 ctx: &'a RequestContext,
374 ) -> impl std::future::Future<Output = McpResult<PromptResult>> + turbomcp_core::marker::MaybeSend + 'a
375 {
376 async move {
377 let prompts = self.inner.list_prompts();
379 let prompt = prompts.iter().find(|p| p.name == name);
380
381 if let Some(prompt) = prompt {
382 let meta = ComponentMeta::from_meta_value(prompt.meta.as_ref());
383 if !self.is_visible(&meta, ctx.get_metadata("session_id")) {
384 return Err(McpError::prompt_not_found(name));
385 }
386 }
387
388 self.inner.get_prompt(name, args, ctx).await
389 }
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Inner handler exposing a `public`-tagged tool and an `admin`-tagged tool.
    #[derive(Clone, Debug)]
    struct MockHandler;

    // Mirrors the trait's `fn -> impl Future + MaybeSend` signatures.
    #[allow(clippy::manual_async_fn)]
    impl McpHandler for MockHandler {
        fn server_info(&self) -> turbomcp_types::ServerInfo {
            turbomcp_types::ServerInfo::new("test", "1.0.0")
        }

        fn list_tools(&self) -> Vec<Tool> {
            vec![
                Tool {
                    name: "public_tool".to_string(),
                    description: Some("Public tool".to_string()),
                    meta: Some({
                        let mut m = std::collections::HashMap::new();
                        m.insert("tags".to_string(), serde_json::json!(["public"]));
                        m
                    }),
                    ..Default::default()
                },
                Tool {
                    name: "admin_tool".to_string(),
                    description: Some("Admin tool".to_string()),
                    meta: Some({
                        let mut m = std::collections::HashMap::new();
                        m.insert("tags".to_string(), serde_json::json!(["admin"]));
                        m
                    }),
                    ..Default::default()
                },
            ]
        }

        fn list_resources(&self) -> Vec<Resource> {
            vec![]
        }

        fn list_prompts(&self) -> Vec<Prompt> {
            vec![]
        }

        fn call_tool<'a>(
            &'a self,
            name: &'a str,
            _args: serde_json::Value,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<ToolResult>>
        + turbomcp_core::marker::MaybeSend
        + 'a {
            async move { Ok(ToolResult::text(format!("Called {}", name))) }
        }

        fn read_resource<'a>(
            &'a self,
            _uri: &'a str,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<ResourceResult>>
        + turbomcp_core::marker::MaybeSend
        + 'a {
            async move { Err(McpError::resource_not_found("not found")) }
        }

        fn get_prompt<'a>(
            &'a self,
            _name: &'a str,
            _args: Option<serde_json::Value>,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<PromptResult>>
        + turbomcp_core::marker::MaybeSend
        + 'a {
            async move { Err(McpError::prompt_not_found("not found")) }
        }
    }

    /// Parsed `ComponentMeta` of the mock tool called `name`.
    fn tool_meta(name: &str) -> ComponentMeta {
        let tools = MockHandler.list_tools();
        let tool = tools
            .iter()
            .find(|t| t.name == name)
            .expect("mock handler defines this tool");
        ComponentMeta::from_meta_value(tool.meta.as_ref())
    }

    #[test]
    fn test_visibility_layer_hides_admin() {
        let layer = VisibilityLayer::new(MockHandler).disable_tags(["admin"]);

        let tools = layer.list_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "public_tool");
    }

    #[test]
    fn test_visibility_layer_shows_all_by_default() {
        let layer = VisibilityLayer::new(MockHandler);

        let tools = layer.list_tools();
        assert_eq!(tools.len(), 2);
    }

    #[test]
    fn test_session_enable_override() {
        let layer = VisibilityLayer::new(MockHandler).disable_tags(["admin"]);
        let admin = tool_meta("admin_tool");

        assert_eq!(layer.list_tools().len(), 1);
        assert!(!layer.is_visible(&admin, Some("session1")));

        layer.enable_for_session("session1", &["admin".to_string()]);

        // Listing has no session context, so it still hides the tool...
        assert_eq!(layer.list_tools().len(), 1);
        // ...but the enabling session, and only that session, can see it.
        assert!(layer.is_visible(&admin, Some("session1")));
        assert!(!layer.is_visible(&admin, Some("session2")));
        assert!(!layer.is_visible(&admin, None));

        layer.clear_session("session1");
        assert!(!layer.is_visible(&admin, Some("session1")));
    }

    #[test]
    fn test_session_disable_and_reenable() {
        let layer = VisibilityLayer::new(MockHandler);
        let public = tool_meta("public_tool");

        layer.disable_for_session("s", &["public".to_string()]);
        assert!(!layer.is_visible(&public, Some("s")));
        assert!(layer.is_visible(&public, Some("other")));
        assert!(layer.is_visible(&public, None));

        // Enabling removes the tag from the session's disabled set again.
        layer.enable_for_session("s", &["public".to_string()]);
        assert!(layer.is_visible(&public, Some("s")));
    }

    #[test]
    fn test_disable_revokes_session_enable() {
        let layer = VisibilityLayer::new(MockHandler).disable_tags(["admin"]);
        let admin = tool_meta("admin_tool");

        layer.enable_for_session("s", &["admin".to_string()]);
        assert!(layer.is_visible(&admin, Some("s")));

        layer.disable_for_session("s", &["admin".to_string()]);
        assert!(!layer.is_visible(&admin, Some("s")));
    }

    #[test]
    fn test_session_guard_cleanup() {
        let layer = VisibilityLayer::new(MockHandler).disable_tags(["admin"]);

        {
            let _guard = layer.session_guard("guard-session");

            layer.enable_for_session("guard-session", &["admin".to_string()]);
            layer.disable_for_session("guard-session", &["public".to_string()]);

            assert!(layer.active_sessions_count() > 0);
        }

        assert_eq!(layer.active_sessions_count(), 0);
    }

    #[test]
    fn test_active_sessions_count() {
        let layer = VisibilityLayer::new(MockHandler);

        assert_eq!(layer.active_sessions_count(), 0);

        layer.enable_for_session("session1", &["tag1".to_string()]);
        assert_eq!(layer.active_sessions_count(), 1);

        layer.disable_for_session("session2", &["tag2".to_string()]);
        assert_eq!(layer.active_sessions_count(), 2);

        layer.enable_for_session("session1", &["tag2".to_string()]);
        assert_eq!(layer.active_sessions_count(), 2);

        layer.clear_session("session1");
        assert_eq!(layer.active_sessions_count(), 1);

        layer.clear_session("session2");
        assert_eq!(layer.active_sessions_count(), 0);
    }
}