1use std::{sync::Arc, time::Duration};
2
3use crate::schema::{
4 schema_utils::{
5 self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient,
6 NotificationFromClient, RequestFromClient, ResultFromServer, ServerMessage, ServerMessages,
7 },
8 CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams,
9 CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation,
10 InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams,
11 ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest,
12 ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams,
13 LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RequestId,
14 RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities,
15 SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams,
16 UnsubscribeRequest, UnsubscribeRequestParams,
17};
18use crate::{error::SdkResult, utils::format_assertion_message};
19use async_trait::async_trait;
20use rust_mcp_transport::{McpDispatch, MessageDispatcher};
21
22#[async_trait]
23pub trait McpClient: Sync + Send {
24 async fn start(self: Arc<Self>) -> SdkResult<()>;
25 fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()>;
26
27 async fn shut_down(&self) -> SdkResult<()>;
28 async fn is_shut_down(&self) -> bool;
29
30 fn sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>>
31 where
32 MessageDispatcher<ServerMessage>:
33 McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>;
34
35 fn client_info(&self) -> &InitializeRequestParams;
36 fn server_info(&self) -> Option<InitializeResult>;
37
38 fn is_initialized(&self) -> bool {
40 self.server_info().is_some()
41 }
42
43 fn server_version(&self) -> Option<Implementation> {
46 self.server_info()
47 .map(|server_details| server_details.server_info)
48 }
49
50 fn server_capabilities(&self) -> Option<ServerCapabilities> {
53 self.server_info().map(|item| item.capabilities)
54 }
55
56 fn server_has_tools(&self) -> Option<bool> {
71 self.server_info()
72 .map(|server_details| server_details.capabilities.tools.is_some())
73 }
74
75 fn server_has_prompts(&self) -> Option<bool> {
87 self.server_info()
88 .map(|server_details| server_details.capabilities.prompts.is_some())
89 }
90
91 fn server_has_experimental(&self) -> Option<bool> {
103 self.server_info()
104 .map(|server_details| server_details.capabilities.experimental.is_some())
105 }
106
107 fn server_has_resources(&self) -> Option<bool> {
119 self.server_info()
120 .map(|server_details| server_details.capabilities.resources.is_some())
121 }
122
123 fn server_supports_logging(&self) -> Option<bool> {
135 self.server_info()
136 .map(|server_details| server_details.capabilities.logging.is_some())
137 }
138
139 fn instructions(&self) -> Option<String> {
140 self.server_info()?.instructions
141 }
142
143 async fn request(
149 &self,
150 request: RequestFromClient,
151 timeout: Option<Duration>,
152 ) -> SdkResult<ResultFromServer> {
153 let response = self
154 .send(MessageFromClient::RequestFromClient(request), None, timeout)
155 .await?;
156
157 let server_message = response.ok_or_else(|| {
158 RpcError::internal_error()
159 .with_message("An empty response was received from the client.".to_string())
160 })?;
161
162 if server_message.is_error() {
163 return Err(server_message.as_error()?.error.into());
164 }
165
166 return Ok(server_message.as_response()?.result);
167 }
168
169 async fn send(
170 &self,
171 message: MessageFromClient,
172 request_id: Option<RequestId>,
173 timeout: Option<Duration>,
174 ) -> SdkResult<Option<ServerMessage>>;
175
176 async fn send_batch(
177 &self,
178 messages: Vec<ClientMessage>,
179 timeout: Option<Duration>,
180 ) -> SdkResult<Option<Vec<ServerMessage>>> {
181 let sender = self.sender();
182 let sender = sender.read().await;
183 let sender = sender
184 .as_ref()
185 .ok_or(schema_utils::SdkError::connection_closed())?;
186
187 let response = sender
188 .send_message(ClientMessages::Batch(messages), timeout)
189 .await?;
190
191 match response {
192 Some(res) => {
193 let server_results = res.as_batch()?;
194 Ok(Some(server_results))
195 }
196 None => Ok(None),
197 }
198 }
199
200 async fn send_notification(&self, notification: NotificationFromClient) -> SdkResult<()> {
204 let sender = self.sender();
205 let sender = sender.read().await;
206 let sender = sender
207 .as_ref()
208 .ok_or(schema_utils::SdkError::connection_closed())?;
209
210 let mcp_message = ClientMessage::from_message(MessageFromClient::from(notification), None)?;
211
212 sender
213 .send_message(ClientMessages::Single(mcp_message), None)
214 .await?;
215 Ok(())
216 }
217
218 async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
229 let ping_request = PingRequest::new(None);
230 let response = self.request(ping_request.into(), timeout).await?;
231 Ok(response.try_into()?)
232 }
233
234 async fn complete(
235 &self,
236 params: CompleteRequestParams,
237 ) -> SdkResult<crate::schema::CompleteResult> {
238 let request = CompleteRequest::new(params);
239 let response = self.request(request.into(), None).await?;
240 Ok(response.try_into()?)
241 }
242
243 async fn set_logging_level(&self, level: LoggingLevel) -> SdkResult<crate::schema::Result> {
244 let request = SetLevelRequest::new(SetLevelRequestParams { level });
245 let response = self.request(request.into(), None).await?;
246 Ok(response.try_into()?)
247 }
248
249 async fn get_prompt(
250 &self,
251 params: GetPromptRequestParams,
252 ) -> SdkResult<crate::schema::GetPromptResult> {
253 let request = GetPromptRequest::new(params);
254 let response = self.request(request.into(), None).await?;
255 Ok(response.try_into()?)
256 }
257
258 async fn list_prompts(
259 &self,
260 params: Option<ListPromptsRequestParams>,
261 ) -> SdkResult<crate::schema::ListPromptsResult> {
262 let request = ListPromptsRequest::new(params);
263 let response = self.request(request.into(), None).await?;
264 Ok(response.try_into()?)
265 }
266
267 async fn list_resources(
268 &self,
269 params: Option<ListResourcesRequestParams>,
270 ) -> SdkResult<crate::schema::ListResourcesResult> {
271 let request =
276 ListResourcesRequest::new(params.or(Some(ListResourcesRequestParams::default())));
277 let response = self.request(request.into(), None).await?;
278 Ok(response.try_into()?)
279 }
280
281 async fn list_resource_templates(
282 &self,
283 params: Option<ListResourceTemplatesRequestParams>,
284 ) -> SdkResult<crate::schema::ListResourceTemplatesResult> {
285 let request = ListResourceTemplatesRequest::new(params);
286 let response = self.request(request.into(), None).await?;
287 Ok(response.try_into()?)
288 }
289
290 async fn read_resource(
291 &self,
292 params: ReadResourceRequestParams,
293 ) -> SdkResult<crate::schema::ReadResourceResult> {
294 let request = ReadResourceRequest::new(params);
295 let response = self.request(request.into(), None).await?;
296 Ok(response.try_into()?)
297 }
298
299 async fn subscribe_resource(
300 &self,
301 params: SubscribeRequestParams,
302 ) -> SdkResult<crate::schema::Result> {
303 let request = SubscribeRequest::new(params);
304 let response = self.request(request.into(), None).await?;
305 Ok(response.try_into()?)
306 }
307
308 async fn unsubscribe_resource(
309 &self,
310 params: UnsubscribeRequestParams,
311 ) -> SdkResult<crate::schema::Result> {
312 let request = UnsubscribeRequest::new(params);
313 let response = self.request(request.into(), None).await?;
314 Ok(response.try_into()?)
315 }
316
317 async fn call_tool(&self, params: CallToolRequestParams) -> SdkResult<CallToolResult> {
318 let request = CallToolRequest::new(params);
319 let response = self.request(request.into(), None).await?;
320 Ok(response.try_into()?)
321 }
322
323 async fn list_tools(
324 &self,
325 params: Option<ListToolsRequestParams>,
326 ) -> SdkResult<crate::schema::ListToolsResult> {
327 let request = ListToolsRequest::new(params);
328 let response = self.request(request.into(), None).await?;
329 Ok(response.try_into()?)
330 }
331
332 async fn send_roots_list_changed(
333 &self,
334 params: Option<RootsListChangedNotificationParams>,
335 ) -> SdkResult<()> {
336 let notification = RootsListChangedNotification::new(params);
337 self.send_notification(notification.into()).await
338 }
339
340 fn assert_server_capabilities(&self, request_method: &String) -> SdkResult<()> {
346 let entity = "Server";
347
348 let capabilities = self.server_capabilities().ok_or::<RpcError>(
349 RpcError::internal_error().with_message("Server is not initialized!".to_string()),
350 )?;
351
352 if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
353 return Err(RpcError::internal_error()
354 .with_message(format_assertion_message(entity, "logging", request_method))
355 .into());
356 }
357
358 if [
359 GetPromptRequest::method_name(),
360 ListPromptsRequest::method_name(),
361 ]
362 .contains(request_method)
363 && capabilities.prompts.is_none()
364 {
365 return Err(RpcError::internal_error()
366 .with_message(format_assertion_message(entity, "prompts", request_method))
367 .into());
368 }
369
370 if [
371 ListResourcesRequest::method_name(),
372 ListResourceTemplatesRequest::method_name(),
373 ReadResourceRequest::method_name(),
374 SubscribeRequest::method_name(),
375 UnsubscribeRequest::method_name(),
376 ]
377 .contains(request_method)
378 && capabilities.resources.is_none()
379 {
380 return Err(RpcError::internal_error()
381 .with_message(format_assertion_message(
382 entity,
383 "resources",
384 request_method,
385 ))
386 .into());
387 }
388
389 if [
390 CallToolRequest::method_name(),
391 ListToolsRequest::method_name(),
392 ]
393 .contains(request_method)
394 && capabilities.tools.is_none()
395 {
396 return Err(RpcError::internal_error()
397 .with_message(format_assertion_message(entity, "tools", request_method))
398 .into());
399 }
400
401 Ok(())
402 }
403
404 fn assert_client_notification_capabilities(
405 &self,
406 notification_method: &String,
407 ) -> std::result::Result<(), RpcError> {
408 let entity = "Client";
409 let capabilities = &self.client_info().capabilities;
410
411 if *notification_method == RootsListChangedNotification::method_name()
412 && capabilities.roots.is_some()
413 {
414 return Err(
415 RpcError::internal_error().with_message(format_assertion_message(
416 entity,
417 "roots list changed notifications",
418 notification_method,
419 )),
420 );
421 }
422
423 Ok(())
424 }
425
426 fn assert_client_request_capabilities(
427 &self,
428 request_method: &String,
429 ) -> std::result::Result<(), RpcError> {
430 let entity = "Client";
431 let capabilities = &self.client_info().capabilities;
432
433 if *request_method == CreateMessageRequest::method_name() && capabilities.sampling.is_some()
434 {
435 return Err(
436 RpcError::internal_error().with_message(format_assertion_message(
437 entity,
438 "sampling capability",
439 request_method,
440 )),
441 );
442 }
443
444 if *request_method == ListRootsRequest::method_name() && capabilities.roots.is_some() {
445 return Err(
446 RpcError::internal_error().with_message(format_assertion_message(
447 entity,
448 "roots capability",
449 request_method,
450 )),
451 );
452 }
453
454 Ok(())
455 }
456}