1use crate::schema::{
2 schema_utils::{
3 ClientMessage, McpMessage, MessageFromClient, NotificationFromClient, RequestFromClient,
4 ResultFromServer, ServerMessage,
5 },
6 CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams,
7 CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation,
8 InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams,
9 ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest,
10 ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams,
11 LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RequestId,
12 RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities,
13 SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams,
14 UnsubscribeRequest, UnsubscribeRequestParams,
15};
16use crate::{error::SdkResult, utils::format_assertion_message};
17use async_trait::async_trait;
18use std::{sync::Arc, time::Duration};
19
20#[async_trait]
21pub trait McpClient: Sync + Send {
22 async fn start(self: Arc<Self>) -> SdkResult<()>;
23 fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()>;
24
25 async fn terminate_session(&self);
26
27 async fn shut_down(&self) -> SdkResult<()>;
28 async fn is_shut_down(&self) -> bool;
29
30 fn client_info(&self) -> &InitializeRequestParams;
31 fn server_info(&self) -> Option<InitializeResult>;
32
33 fn is_initialized(&self) -> bool {
35 self.server_info().is_some()
36 }
37
38 fn server_version(&self) -> Option<Implementation> {
41 self.server_info()
42 .map(|server_details| server_details.server_info)
43 }
44
45 fn server_capabilities(&self) -> Option<ServerCapabilities> {
48 self.server_info().map(|item| item.capabilities)
49 }
50
51 fn server_has_tools(&self) -> Option<bool> {
66 self.server_info()
67 .map(|server_details| server_details.capabilities.tools.is_some())
68 }
69
70 fn server_has_prompts(&self) -> Option<bool> {
82 self.server_info()
83 .map(|server_details| server_details.capabilities.prompts.is_some())
84 }
85
86 fn server_has_experimental(&self) -> Option<bool> {
98 self.server_info()
99 .map(|server_details| server_details.capabilities.experimental.is_some())
100 }
101
102 fn server_has_resources(&self) -> Option<bool> {
114 self.server_info()
115 .map(|server_details| server_details.capabilities.resources.is_some())
116 }
117
118 fn server_supports_logging(&self) -> Option<bool> {
130 self.server_info()
131 .map(|server_details| server_details.capabilities.logging.is_some())
132 }
133
134 fn server_supports_completion(&self) -> Option<bool> {
146 self.server_info()
147 .map(|server_details| server_details.capabilities.completions.is_some())
148 }
149
150 fn instructions(&self) -> Option<String> {
151 self.server_info()?.instructions
152 }
153
154 async fn request(
160 &self,
161 request: RequestFromClient,
162 timeout: Option<Duration>,
163 ) -> SdkResult<ResultFromServer> {
164 let response = self
165 .send(MessageFromClient::RequestFromClient(request), None, timeout)
166 .await?;
167
168 let server_message = response.ok_or_else(|| {
169 RpcError::internal_error()
170 .with_message("An empty response was received from the client.".to_string())
171 })?;
172
173 if server_message.is_error() {
174 return Err(server_message.as_error()?.error.into());
175 }
176
177 return Ok(server_message.as_response()?.result);
178 }
179
180 async fn send(
181 &self,
182 message: MessageFromClient,
183 request_id: Option<RequestId>,
184 request_timeout: Option<Duration>,
185 ) -> SdkResult<Option<ServerMessage>>;
186
187 async fn send_batch(
188 &self,
189 messages: Vec<ClientMessage>,
190 timeout: Option<Duration>,
191 ) -> SdkResult<Option<Vec<ServerMessage>>>;
192
193 async fn send_notification(&self, notification: NotificationFromClient) -> SdkResult<()> {
197 self.send(notification.into(), None, None).await?;
198 Ok(())
199 }
200
201 async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
212 let ping_request = PingRequest::new(None);
213 let response = self.request(ping_request.into(), timeout).await?;
214 Ok(response.try_into()?)
215 }
216
217 async fn complete(
218 &self,
219 params: CompleteRequestParams,
220 ) -> SdkResult<crate::schema::CompleteResult> {
221 let request = CompleteRequest::new(params);
222 let response = self.request(request.into(), None).await?;
223 Ok(response.try_into()?)
224 }
225
226 async fn set_logging_level(&self, level: LoggingLevel) -> SdkResult<crate::schema::Result> {
227 let request = SetLevelRequest::new(SetLevelRequestParams { level });
228 let response = self.request(request.into(), None).await?;
229 Ok(response.try_into()?)
230 }
231
232 async fn get_prompt(
233 &self,
234 params: GetPromptRequestParams,
235 ) -> SdkResult<crate::schema::GetPromptResult> {
236 let request = GetPromptRequest::new(params);
237 let response = self.request(request.into(), None).await?;
238 Ok(response.try_into()?)
239 }
240
241 async fn list_prompts(
242 &self,
243 params: Option<ListPromptsRequestParams>,
244 ) -> SdkResult<crate::schema::ListPromptsResult> {
245 let request = ListPromptsRequest::new(params);
246 let response = self.request(request.into(), None).await?;
247 Ok(response.try_into()?)
248 }
249
250 async fn list_resources(
251 &self,
252 params: Option<ListResourcesRequestParams>,
253 ) -> SdkResult<crate::schema::ListResourcesResult> {
254 let request =
259 ListResourcesRequest::new(params.or(Some(ListResourcesRequestParams::default())));
260 let response = self.request(request.into(), None).await?;
261 Ok(response.try_into()?)
262 }
263
264 async fn list_resource_templates(
265 &self,
266 params: Option<ListResourceTemplatesRequestParams>,
267 ) -> SdkResult<crate::schema::ListResourceTemplatesResult> {
268 let request = ListResourceTemplatesRequest::new(params);
269 let response = self.request(request.into(), None).await?;
270 Ok(response.try_into()?)
271 }
272
273 async fn read_resource(
274 &self,
275 params: ReadResourceRequestParams,
276 ) -> SdkResult<crate::schema::ReadResourceResult> {
277 let request = ReadResourceRequest::new(params);
278 let response = self.request(request.into(), None).await?;
279 Ok(response.try_into()?)
280 }
281
282 async fn subscribe_resource(
283 &self,
284 params: SubscribeRequestParams,
285 ) -> SdkResult<crate::schema::Result> {
286 let request = SubscribeRequest::new(params);
287 let response = self.request(request.into(), None).await?;
288 Ok(response.try_into()?)
289 }
290
291 async fn unsubscribe_resource(
292 &self,
293 params: UnsubscribeRequestParams,
294 ) -> SdkResult<crate::schema::Result> {
295 let request = UnsubscribeRequest::new(params);
296 let response = self.request(request.into(), None).await?;
297 Ok(response.try_into()?)
298 }
299
300 async fn call_tool(&self, params: CallToolRequestParams) -> SdkResult<CallToolResult> {
301 let request = CallToolRequest::new(params);
302 let response = self.request(request.into(), None).await?;
303 Ok(response.try_into()?)
304 }
305
306 async fn list_tools(
307 &self,
308 params: Option<ListToolsRequestParams>,
309 ) -> SdkResult<crate::schema::ListToolsResult> {
310 let request = ListToolsRequest::new(params);
311 let response = self.request(request.into(), None).await?;
312 Ok(response.try_into()?)
313 }
314
315 async fn send_roots_list_changed(
316 &self,
317 params: Option<RootsListChangedNotificationParams>,
318 ) -> SdkResult<()> {
319 let notification = RootsListChangedNotification::new(params);
320 self.send_notification(notification.into()).await
321 }
322
323 fn assert_server_capabilities(&self, request_method: &String) -> SdkResult<()> {
329 let entity = "Server";
330
331 let capabilities = self.server_capabilities().ok_or::<RpcError>(
332 RpcError::internal_error().with_message("Server is not initialized!".to_string()),
333 )?;
334
335 if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
336 return Err(RpcError::internal_error()
337 .with_message(format_assertion_message(entity, "logging", request_method))
338 .into());
339 }
340
341 if [
342 GetPromptRequest::method_name(),
343 ListPromptsRequest::method_name(),
344 ]
345 .contains(request_method)
346 && capabilities.prompts.is_none()
347 {
348 return Err(RpcError::internal_error()
349 .with_message(format_assertion_message(entity, "prompts", request_method))
350 .into());
351 }
352
353 if [
354 ListResourcesRequest::method_name(),
355 ListResourceTemplatesRequest::method_name(),
356 ReadResourceRequest::method_name(),
357 SubscribeRequest::method_name(),
358 UnsubscribeRequest::method_name(),
359 ]
360 .contains(request_method)
361 && capabilities.resources.is_none()
362 {
363 return Err(RpcError::internal_error()
364 .with_message(format_assertion_message(
365 entity,
366 "resources",
367 request_method,
368 ))
369 .into());
370 }
371
372 if [
373 CallToolRequest::method_name(),
374 ListToolsRequest::method_name(),
375 ]
376 .contains(request_method)
377 && capabilities.tools.is_none()
378 {
379 return Err(RpcError::internal_error()
380 .with_message(format_assertion_message(entity, "tools", request_method))
381 .into());
382 }
383
384 Ok(())
385 }
386
387 fn assert_client_notification_capabilities(
388 &self,
389 notification_method: &String,
390 ) -> std::result::Result<(), RpcError> {
391 let entity = "Client";
392 let capabilities = &self.client_info().capabilities;
393
394 if *notification_method == RootsListChangedNotification::method_name()
395 && capabilities.roots.is_some()
396 {
397 return Err(
398 RpcError::internal_error().with_message(format_assertion_message(
399 entity,
400 "roots list changed notifications",
401 notification_method,
402 )),
403 );
404 }
405
406 Ok(())
407 }
408
409 fn assert_client_request_capabilities(
410 &self,
411 request_method: &String,
412 ) -> std::result::Result<(), RpcError> {
413 let entity = "Client";
414 let capabilities = &self.client_info().capabilities;
415
416 if *request_method == CreateMessageRequest::method_name() && capabilities.sampling.is_some()
417 {
418 return Err(
419 RpcError::internal_error().with_message(format_assertion_message(
420 entity,
421 "sampling capability",
422 request_method,
423 )),
424 );
425 }
426
427 if *request_method == ListRootsRequest::method_name() && capabilities.roots.is_some() {
428 return Err(
429 RpcError::internal_error().with_message(format_assertion_message(
430 entity,
431 "roots capability",
432 request_method,
433 )),
434 );
435 }
436
437 Ok(())
438 }
439}