Skip to main content

wraith_runtime/
mcp_client.rs

1use std::collections::BTreeMap;
2
3use crate::config::{McpOAuthConfig, McpServerConfig, ScopedMcpServerConfig};
4use crate::mcp::{mcp_server_signature, mcp_tool_prefix, normalize_name_for_mcp};
5
/// Transport a client uses to reach one MCP server, derived from that server's
/// [`McpServerConfig`] by [`McpClientTransport::from_config`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum McpClientTransport {
    /// Built from `McpServerConfig::Stdio`: a command with its arguments and environment.
    Stdio(McpStdioTransport),
    /// Built from `McpServerConfig::Sse`; may carry OAuth settings.
    Sse(McpRemoteTransport),
    /// Built from `McpServerConfig::Http`; may carry OAuth settings.
    Http(McpRemoteTransport),
    /// Built from `McpServerConfig::Ws`; its `auth` is always `McpClientAuth::None`.
    WebSocket(McpRemoteTransport),
    /// Built from `McpServerConfig::Sdk`; identified only by name.
    Sdk(McpSdkTransport),
    /// Built from `McpServerConfig::ManagedProxy`: a URL plus an identifier.
    ManagedProxy(McpManagedProxyTransport),
}
15
/// Command, arguments and environment for a stdio MCP server, copied from its config.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpStdioTransport {
    /// Command named in the server config.
    pub command: String,
    /// Arguments from the server config, in their original order.
    pub args: Vec<String>,
    /// Environment entries from the server config, kept sorted by key (`BTreeMap`).
    pub env: BTreeMap<String, String>,
}
22
/// Connection details shared by the SSE, HTTP and WebSocket transport variants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpRemoteTransport {
    /// Server URL copied from the config.
    pub url: String,
    /// Header name/value pairs from the config, kept sorted by name (`BTreeMap`).
    pub headers: BTreeMap<String, String>,
    /// Optional headers helper named in the config; how it is used is not defined here.
    pub headers_helper: Option<String>,
    /// OAuth settings when an SSE/HTTP config supplied them; always `None` for WebSocket.
    pub auth: McpClientAuth,
}
30
/// Transport for an SDK server config, which identifies the server only by name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpSdkTransport {
    /// Name copied from the SDK server config.
    pub name: String,
}
35
/// Transport for a `ManagedProxy` server config: a URL and an identifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpManagedProxyTransport {
    /// URL copied from the config.
    pub url: String,
    /// Identifier copied from the config.
    pub id: String,
}
41
/// Authentication mode attached to a remote (SSE/HTTP/WebSocket) transport.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum McpClientAuth {
    /// No authentication settings were configured.
    None,
    /// OAuth settings copied from the server config.
    OAuth(McpOAuthConfig),
}
47
/// Identity and transport details for one configured MCP server, as assembled by
/// [`McpClientBootstrap::from_scoped_config`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpClientBootstrap {
    /// Server name exactly as supplied by the caller.
    pub server_name: String,
    /// `server_name` passed through `normalize_name_for_mcp`
    /// (the tests expect `"remote server"` to become `"remote_server"`).
    pub normalized_name: String,
    /// Prefix produced by `mcp_tool_prefix`
    /// (the tests expect `mcp__stdio-server__` for `stdio-server`).
    pub tool_prefix: String,
    /// Result of `mcp_server_signature`; the tests show `None` for an SDK server.
    pub signature: Option<String>,
    /// Transport derived from the server's configuration.
    pub transport: McpClientTransport,
}
56
57impl McpClientBootstrap {
58    #[must_use]
59    pub fn from_scoped_config(server_name: &str, config: &ScopedMcpServerConfig) -> Self {
60        Self {
61            server_name: server_name.to_string(),
62            normalized_name: normalize_name_for_mcp(server_name),
63            tool_prefix: mcp_tool_prefix(server_name),
64            signature: mcp_server_signature(&config.config),
65            transport: McpClientTransport::from_config(&config.config),
66        }
67    }
68}
69
70impl McpClientTransport {
71    #[must_use]
72    pub fn from_config(config: &McpServerConfig) -> Self {
73        match config {
74            McpServerConfig::Stdio(config) => Self::Stdio(McpStdioTransport {
75                command: config.command.clone(),
76                args: config.args.clone(),
77                env: config.env.clone(),
78            }),
79            McpServerConfig::Sse(config) => Self::Sse(McpRemoteTransport {
80                url: config.url.clone(),
81                headers: config.headers.clone(),
82                headers_helper: config.headers_helper.clone(),
83                auth: McpClientAuth::from_oauth(config.oauth.clone()),
84            }),
85            McpServerConfig::Http(config) => Self::Http(McpRemoteTransport {
86                url: config.url.clone(),
87                headers: config.headers.clone(),
88                headers_helper: config.headers_helper.clone(),
89                auth: McpClientAuth::from_oauth(config.oauth.clone()),
90            }),
91            McpServerConfig::Ws(config) => Self::WebSocket(McpRemoteTransport {
92                url: config.url.clone(),
93                headers: config.headers.clone(),
94                headers_helper: config.headers_helper.clone(),
95                auth: McpClientAuth::None,
96            }),
97            McpServerConfig::Sdk(config) => Self::Sdk(McpSdkTransport {
98                name: config.name.clone(),
99            }),
100            McpServerConfig::ManagedProxy(config) => Self::ManagedProxy(McpManagedProxyTransport {
101                url: config.url.clone(),
102                id: config.id.clone(),
103            }),
104        }
105    }
106}
107
108impl McpClientAuth {
109    #[must_use]
110    pub fn from_oauth(oauth: Option<McpOAuthConfig>) -> Self {
111        oauth.map_or(Self::None, Self::OAuth)
112    }
113
114    #[must_use]
115    pub const fn requires_user_auth(&self) -> bool {
116        matches!(self, Self::OAuth(_))
117    }
118}
119
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::config::{
        ConfigSource, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig,
        McpStdioServerConfig, McpWebSocketServerConfig, ScopedMcpServerConfig,
    };

    use super::{McpClientAuth, McpClientBootstrap, McpClientTransport};

    /// A stdio config keeps its command, args and env, and gets a name-derived prefix
    /// and signature.
    #[test]
    fn bootstraps_stdio_servers_into_transport_targets() {
        let stdio_config = McpStdioServerConfig {
            command: "uvx".to_string(),
            args: vec!["mcp-server".to_string()],
            env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]),
        };
        let scoped = ScopedMcpServerConfig {
            scope: ConfigSource::User,
            config: McpServerConfig::Stdio(stdio_config),
        };

        let bootstrap = McpClientBootstrap::from_scoped_config("stdio-server", &scoped);
        assert_eq!(bootstrap.normalized_name, "stdio-server");
        assert_eq!(bootstrap.tool_prefix, "mcp__stdio-server__");
        assert_eq!(
            bootstrap.signature.as_deref(),
            Some("stdio:[uvx|mcp-server]")
        );

        let stdio = match bootstrap.transport {
            McpClientTransport::Stdio(stdio) => stdio,
            other => panic!("expected stdio transport, got {other:?}"),
        };
        assert_eq!(stdio.command, "uvx");
        assert_eq!(stdio.args, vec!["mcp-server"]);
        assert_eq!(stdio.env.get("TOKEN").map(String::as_str), Some("secret"));
    }

    /// An HTTP config with OAuth settings yields an HTTP transport that requires auth.
    #[test]
    fn bootstraps_remote_servers_with_oauth_auth() {
        let oauth_settings = McpOAuthConfig {
            client_id: Some("client-id".to_string()),
            callback_port: Some(7777),
            auth_server_metadata_url: Some(
                "https://issuer.example/.well-known/oauth-authorization-server".to_string(),
            ),
            xaa: Some(true),
        };
        let scoped = ScopedMcpServerConfig {
            scope: ConfigSource::Project,
            config: McpServerConfig::Http(McpRemoteServerConfig {
                url: "https://vendor.example/mcp".to_string(),
                headers: BTreeMap::from([("X-Test".to_string(), "1".to_string())]),
                headers_helper: Some("helper.sh".to_string()),
                oauth: Some(oauth_settings),
            }),
        };

        let bootstrap = McpClientBootstrap::from_scoped_config("remote server", &scoped);
        assert_eq!(bootstrap.normalized_name, "remote_server");

        let http = match bootstrap.transport {
            McpClientTransport::Http(http) => http,
            other => panic!("expected http transport, got {other:?}"),
        };
        assert_eq!(http.url, "https://vendor.example/mcp");
        assert_eq!(http.headers_helper.as_deref(), Some("helper.sh"));
        assert!(http.auth.requires_user_auth());

        let oauth = match http.auth {
            McpClientAuth::OAuth(oauth) => oauth,
            other @ McpClientAuth::None => panic!("expected oauth auth, got {other:?}"),
        };
        assert_eq!(oauth.client_id.as_deref(), Some("client-id"));
    }

    /// WebSocket and SDK configs map to their own transports with no OAuth and,
    /// for SDK, no signature.
    #[test]
    fn bootstraps_websocket_and_sdk_transports_without_oauth() {
        let ws_scoped = ScopedMcpServerConfig {
            scope: ConfigSource::Local,
            config: McpServerConfig::Ws(McpWebSocketServerConfig {
                url: "wss://vendor.example/mcp".to_string(),
                headers: BTreeMap::new(),
                headers_helper: None,
            }),
        };
        let sdk_scoped = ScopedMcpServerConfig {
            scope: ConfigSource::Local,
            config: McpServerConfig::Sdk(McpSdkServerConfig {
                name: "sdk-server".to_string(),
            }),
        };

        let ws_bootstrap = McpClientBootstrap::from_scoped_config("ws server", &ws_scoped);
        let websocket = match ws_bootstrap.transport {
            McpClientTransport::WebSocket(websocket) => websocket,
            other => panic!("expected websocket transport, got {other:?}"),
        };
        assert_eq!(websocket.url, "wss://vendor.example/mcp");
        assert!(!websocket.auth.requires_user_auth());

        let sdk_bootstrap = McpClientBootstrap::from_scoped_config("sdk server", &sdk_scoped);
        assert_eq!(sdk_bootstrap.signature, None);
        let sdk = match sdk_bootstrap.transport {
            McpClientTransport::Sdk(sdk) => sdk,
            other => panic!("expected sdk transport, got {other:?}"),
        };
        assert_eq!(sdk.name, "sdk-server");
    }
}