1use std::{sync::Arc, time::Duration};
2
3use crate::schema::{
4 schema_utils::{
5 self, McpMessage, MessageFromClient, NotificationFromClient, RequestFromClient,
6 ResultFromServer, ServerMessage,
7 },
8 CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams,
9 CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation,
10 InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams,
11 ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest,
12 ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams,
13 LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams,
14 RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities,
15 SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams,
16 UnsubscribeRequest, UnsubscribeRequestParams,
17};
18use async_trait::async_trait;
19use rust_mcp_transport::{McpDispatch, MessageDispatcher};
20
21use crate::{error::SdkResult, utils::format_assertion_message};
22
23#[async_trait]
24pub trait McpClient: Sync + Send {
25 async fn start(self: Arc<Self>) -> SdkResult<()>;
26 fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()>;
27
28 async fn shut_down(&self) -> SdkResult<()>;
29 async fn is_shut_down(&self) -> bool;
30
31 async fn sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>
32 where
33 MessageDispatcher<ServerMessage>: McpDispatch<ServerMessage, MessageFromClient>;
34
35 fn client_info(&self) -> &InitializeRequestParams;
36 fn server_info(&self) -> Option<InitializeResult>;
37
38 #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")]
39 fn get_client_info(&self) -> &InitializeRequestParams {
40 self.client_info()
41 }
42
43 #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")]
44 fn get_server_info(&self) -> Option<InitializeResult> {
45 self.server_info()
46 }
47
48 fn is_initialized(&self) -> bool {
50 self.server_info().is_some()
51 }
52
53 fn server_version(&self) -> Option<Implementation> {
56 self.server_info()
57 .map(|server_details| server_details.server_info)
58 }
59
60 #[deprecated(since = "0.2.0", note = "Use `server_version()` instead.")]
61 fn get_server_version(&self) -> Option<Implementation> {
62 self.server_info()
63 .map(|server_details| server_details.server_info)
64 }
65
66 fn server_capabilities(&self) -> Option<ServerCapabilities> {
69 self.server_info().map(|item| item.capabilities)
70 }
71
72 #[deprecated(since = "0.2.0", note = "Use `server_capabilities()` instead.")]
73 fn get_server_capabilities(&self) -> Option<ServerCapabilities> {
74 self.server_info().map(|item| item.capabilities)
75 }
76
77 fn server_has_tools(&self) -> Option<bool> {
92 self.server_info()
93 .map(|server_details| server_details.capabilities.tools.is_some())
94 }
95
96 fn server_has_prompts(&self) -> Option<bool> {
108 self.server_info()
109 .map(|server_details| server_details.capabilities.prompts.is_some())
110 }
111
112 fn server_has_experimental(&self) -> Option<bool> {
124 self.server_info()
125 .map(|server_details| server_details.capabilities.experimental.is_some())
126 }
127
128 fn server_has_resources(&self) -> Option<bool> {
140 self.server_info()
141 .map(|server_details| server_details.capabilities.resources.is_some())
142 }
143
144 fn server_supports_logging(&self) -> Option<bool> {
156 self.server_info()
157 .map(|server_details| server_details.capabilities.logging.is_some())
158 }
159 #[deprecated(since = "0.2.0", note = "Use `instructions()` instead.")]
160 fn get_instructions(&self) -> Option<String> {
161 self.server_info()?.instructions
162 }
163
164 fn instructions(&self) -> Option<String> {
165 self.server_info()?.instructions
166 }
167
168 async fn request(
174 &self,
175 request: RequestFromClient,
176 timeout: Option<Duration>,
177 ) -> SdkResult<ResultFromServer> {
178 let sender = self.sender().await.read().await;
179 let sender = sender
180 .as_ref()
181 .ok_or(schema_utils::SdkError::connection_closed())?;
182
183 let response = sender
185 .send(MessageFromClient::RequestFromClient(request), None, timeout)
186 .await?;
187
188 let server_message = response.ok_or_else(|| {
189 RpcError::internal_error()
190 .with_message("An empty response was received from the server.".to_string())
191 })?;
192
193 if server_message.is_error() {
194 return Err(server_message.as_error()?.error.into());
195 }
196
197 return Ok(server_message.as_response()?.result);
198 }
199
200 async fn send_notification(&self, notification: NotificationFromClient) -> SdkResult<()> {
204 let sender = self.sender().await.read().await;
205 let sender = sender
206 .as_ref()
207 .ok_or(schema_utils::SdkError::connection_closed())?;
208 sender
209 .send(
210 MessageFromClient::NotificationFromClient(notification),
211 None,
212 None,
213 )
214 .await?;
215 Ok(())
216 }
217
218 async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
229 let ping_request = PingRequest::new(None);
230 let response = self.request(ping_request.into(), timeout).await?;
231 Ok(response.try_into()?)
232 }
233
234 async fn complete(
235 &self,
236 params: CompleteRequestParams,
237 ) -> SdkResult<crate::schema::CompleteResult> {
238 let request = CompleteRequest::new(params);
239 let response = self.request(request.into(), None).await?;
240 Ok(response.try_into()?)
241 }
242
243 async fn set_logging_level(&self, level: LoggingLevel) -> SdkResult<crate::schema::Result> {
244 let request = SetLevelRequest::new(SetLevelRequestParams { level });
245 let response = self.request(request.into(), None).await?;
246 Ok(response.try_into()?)
247 }
248
249 async fn get_prompt(
250 &self,
251 params: GetPromptRequestParams,
252 ) -> SdkResult<crate::schema::GetPromptResult> {
253 let request = GetPromptRequest::new(params);
254 let response = self.request(request.into(), None).await?;
255 Ok(response.try_into()?)
256 }
257
258 async fn list_prompts(
259 &self,
260 params: Option<ListPromptsRequestParams>,
261 ) -> SdkResult<crate::schema::ListPromptsResult> {
262 let request = ListPromptsRequest::new(params);
263 let response = self.request(request.into(), None).await?;
264 Ok(response.try_into()?)
265 }
266
267 async fn list_resources(
268 &self,
269 params: Option<ListResourcesRequestParams>,
270 ) -> SdkResult<crate::schema::ListResourcesResult> {
271 let request =
276 ListResourcesRequest::new(params.or(Some(ListResourcesRequestParams::default())));
277 let response = self.request(request.into(), None).await?;
278 Ok(response.try_into()?)
279 }
280
281 async fn list_resource_templates(
282 &self,
283 params: Option<ListResourceTemplatesRequestParams>,
284 ) -> SdkResult<crate::schema::ListResourceTemplatesResult> {
285 let request = ListResourceTemplatesRequest::new(params);
286 let response = self.request(request.into(), None).await?;
287 Ok(response.try_into()?)
288 }
289
290 async fn read_resource(
291 &self,
292 params: ReadResourceRequestParams,
293 ) -> SdkResult<crate::schema::ReadResourceResult> {
294 let request = ReadResourceRequest::new(params);
295 let response = self.request(request.into(), None).await?;
296 Ok(response.try_into()?)
297 }
298
299 async fn subscribe_resource(
300 &self,
301 params: SubscribeRequestParams,
302 ) -> SdkResult<crate::schema::Result> {
303 let request = SubscribeRequest::new(params);
304 let response = self.request(request.into(), None).await?;
305 Ok(response.try_into()?)
306 }
307
308 async fn unsubscribe_resource(
309 &self,
310 params: UnsubscribeRequestParams,
311 ) -> SdkResult<crate::schema::Result> {
312 let request = UnsubscribeRequest::new(params);
313 let response = self.request(request.into(), None).await?;
314 Ok(response.try_into()?)
315 }
316
317 async fn call_tool(&self, params: CallToolRequestParams) -> SdkResult<CallToolResult> {
318 let request = CallToolRequest::new(params);
319 let response = self.request(request.into(), None).await?;
320 Ok(response.try_into()?)
321 }
322
323 async fn list_tools(
324 &self,
325 params: Option<ListToolsRequestParams>,
326 ) -> SdkResult<crate::schema::ListToolsResult> {
327 let request = ListToolsRequest::new(params);
328 let response = self.request(request.into(), None).await?;
329 Ok(response.try_into()?)
330 }
331
332 async fn send_roots_list_changed(
333 &self,
334 params: Option<RootsListChangedNotificationParams>,
335 ) -> SdkResult<()> {
336 let notification = RootsListChangedNotification::new(params);
337 self.send_notification(notification.into()).await
338 }
339
340 fn assert_server_capabilities(&self, request_method: &String) -> SdkResult<()> {
346 let entity = "Server";
347
348 let capabilities = self.server_capabilities().ok_or::<RpcError>(
349 RpcError::internal_error().with_message("Server is not initialized!".to_string()),
350 )?;
351
352 if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
353 return Err(RpcError::internal_error()
354 .with_message(format_assertion_message(entity, "logging", request_method))
355 .into());
356 }
357
358 if [
359 GetPromptRequest::method_name(),
360 ListPromptsRequest::method_name(),
361 ]
362 .contains(request_method)
363 && capabilities.prompts.is_none()
364 {
365 return Err(RpcError::internal_error()
366 .with_message(format_assertion_message(entity, "prompts", request_method))
367 .into());
368 }
369
370 if [
371 ListResourcesRequest::method_name(),
372 ListResourceTemplatesRequest::method_name(),
373 ReadResourceRequest::method_name(),
374 SubscribeRequest::method_name(),
375 UnsubscribeRequest::method_name(),
376 ]
377 .contains(request_method)
378 && capabilities.resources.is_none()
379 {
380 return Err(RpcError::internal_error()
381 .with_message(format_assertion_message(
382 entity,
383 "resources",
384 request_method,
385 ))
386 .into());
387 }
388
389 if [
390 CallToolRequest::method_name(),
391 ListToolsRequest::method_name(),
392 ]
393 .contains(request_method)
394 && capabilities.tools.is_none()
395 {
396 return Err(RpcError::internal_error()
397 .with_message(format_assertion_message(entity, "tools", request_method))
398 .into());
399 }
400
401 Ok(())
402 }
403
404 fn assert_client_notification_capabilities(
405 &self,
406 notification_method: &String,
407 ) -> std::result::Result<(), RpcError> {
408 let entity = "Client";
409 let capabilities = &self.client_info().capabilities;
410
411 if *notification_method == RootsListChangedNotification::method_name()
412 && capabilities.roots.is_some()
413 {
414 return Err(
415 RpcError::internal_error().with_message(format_assertion_message(
416 entity,
417 "roots list changed notifications",
418 notification_method,
419 )),
420 );
421 }
422
423 Ok(())
424 }
425
426 fn assert_client_request_capabilities(
427 &self,
428 request_method: &String,
429 ) -> std::result::Result<(), RpcError> {
430 let entity = "Client";
431 let capabilities = &self.client_info().capabilities;
432
433 if *request_method == CreateMessageRequest::method_name() && capabilities.sampling.is_some()
434 {
435 return Err(
436 RpcError::internal_error().with_message(format_assertion_message(
437 entity,
438 "sampling capability",
439 request_method,
440 )),
441 );
442 }
443
444 if *request_method == ListRootsRequest::method_name() && capabilities.roots.is_some() {
445 return Err(
446 RpcError::internal_error().with_message(format_assertion_message(
447 entity,
448 "roots capability",
449 request_method,
450 )),
451 );
452 }
453
454 Ok(())
455 }
456}