Skip to main content

synwire_mcp_adapters/
session.rs

//! RAII session guard for a single MCP server connection.
//!
//! [`McpClientSession`] wraps a [`McpTransport`], caches the tool list, and
//! schedules a best-effort disconnect when dropped.
5
6use std::sync::Arc;
7
8use synwire_core::agents::error::AgentError;
9use synwire_core::mcp::traits::{
10    McpConnectionState, McpServerStatus, McpToolDescriptor, McpTransport,
11};
12
13// ---------------------------------------------------------------------------
14// McpClientSession
15// ---------------------------------------------------------------------------
16
17/// A guard-based session for a single MCP server connection.
18///
19/// The session connects on creation and disconnects when dropped. Tool
20/// descriptors are cached after the first successful [`list_tools`] call
21/// to avoid redundant round-trips.
22///
23/// [`list_tools`]: McpTransport::list_tools
24pub struct McpClientSession {
25    /// Server name (used for logging).
26    name: String,
27    /// The underlying transport.
28    transport: Arc<dyn McpTransport>,
29    /// Cached tool descriptors (populated by [`populate_tool_cache`]).
30    tool_cache: Vec<McpToolDescriptor>,
31}
32
33impl std::fmt::Debug for McpClientSession {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("McpClientSession")
36            .field("name", &self.name)
37            .field("tool_cache_len", &self.tool_cache.len())
38            .finish_non_exhaustive()
39    }
40}
41
42impl McpClientSession {
43    /// Connects to the server and returns a new session.
44    ///
45    /// # Errors
46    ///
47    /// Returns [`AgentError`] if the transport fails to connect.
48    pub async fn connect(
49        name: impl Into<String>,
50        transport: Arc<dyn McpTransport>,
51    ) -> Result<Self, AgentError> {
52        transport.connect().await?;
53        let name = name.into();
54        tracing::debug!(server = %name, "McpClientSession connected");
55        Ok(Self {
56            name,
57            transport,
58            tool_cache: Vec::new(),
59        })
60    }
61
62    /// Returns the server name.
63    #[must_use]
64    pub fn name(&self) -> &str {
65        &self.name
66    }
67
68    /// Returns the current connection status from the transport.
69    pub async fn status(&self) -> McpServerStatus {
70        self.transport.status().await
71    }
72
73    /// Returns `true` if the session is currently connected.
74    pub async fn is_connected(&self) -> bool {
75        self.transport.status().await.state == McpConnectionState::Connected
76    }
77
78    /// Returns a reference to the cached tool descriptors.
79    ///
80    /// The cache is empty until [`populate_tool_cache`](Self::populate_tool_cache)
81    /// has been called.
82    #[must_use]
83    pub fn cached_tools(&self) -> &[McpToolDescriptor] {
84        &self.tool_cache
85    }
86
87    /// Fetches the tool list from the server and caches the results.
88    ///
89    /// Overwrites any previously cached descriptors.
90    ///
91    /// # Errors
92    ///
93    /// Returns [`AgentError`] if the transport call fails.
94    pub async fn populate_tool_cache(&mut self) -> Result<(), AgentError> {
95        self.tool_cache = self.transport.list_tools().await?;
96        tracing::debug!(
97            server = %self.name,
98            tools = self.tool_cache.len(),
99            "Tool cache populated"
100        );
101        Ok(())
102    }
103
104    /// Returns a reference to the underlying transport.
105    #[must_use]
106    pub fn transport(&self) -> &Arc<dyn McpTransport> {
107        &self.transport
108    }
109}
110
111impl Drop for McpClientSession {
112    fn drop(&mut self) {
113        // Spawn a best-effort disconnect on the tokio runtime.
114        // If no runtime is available (e.g. during test teardown), the error is
115        // silently ignored because we cannot await here.
116        let transport = Arc::clone(&self.transport);
117        let name = self.name.clone();
118        let _handle = tokio::task::spawn(async move {
119            if let Err(e) = transport.disconnect().await {
120                tracing::warn!(server = %name, error = %e, "Error during McpClientSession drop disconnect");
121            }
122        });
123    }
124}