1use crate::schema::{
2 schema_utils::{
3 ClientMessage, McpMessage, MessageFromClient, NotificationFromClient, RequestFromClient,
4 ResultFromServer, ServerMessage,
5 },
6 CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams,
7 CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation,
8 InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams,
9 ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest,
10 ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams,
11 LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RequestId,
12 RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities,
13 SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams,
14 UnsubscribeRequest, UnsubscribeRequestParams,
15};
16use crate::{error::SdkResult, utils::format_assertion_message};
17use async_trait::async_trait;
18use std::{sync::Arc, time::Duration};
19
20#[async_trait]
21pub trait McpClient: Sync + Send {
22 async fn start(self: Arc<Self>) -> SdkResult<()>;
23 fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()>;
24
25 async fn terminate_session(&self);
26
27 async fn shut_down(&self) -> SdkResult<()>;
28 async fn is_shut_down(&self) -> bool;
29
30 fn client_info(&self) -> &InitializeRequestParams;
31 fn server_info(&self) -> Option<InitializeResult>;
32
33 fn is_initialized(&self) -> bool {
35 self.server_info().is_some()
36 }
37
38 fn server_version(&self) -> Option<Implementation> {
41 self.server_info()
42 .map(|server_details| server_details.server_info)
43 }
44
45 fn server_capabilities(&self) -> Option<ServerCapabilities> {
48 self.server_info().map(|item| item.capabilities)
49 }
50
51 fn server_has_tools(&self) -> Option<bool> {
66 self.server_info()
67 .map(|server_details| server_details.capabilities.tools.is_some())
68 }
69
70 fn server_has_prompts(&self) -> Option<bool> {
82 self.server_info()
83 .map(|server_details| server_details.capabilities.prompts.is_some())
84 }
85
86 fn server_has_experimental(&self) -> Option<bool> {
98 self.server_info()
99 .map(|server_details| server_details.capabilities.experimental.is_some())
100 }
101
102 fn server_has_resources(&self) -> Option<bool> {
114 self.server_info()
115 .map(|server_details| server_details.capabilities.resources.is_some())
116 }
117
118 fn server_supports_logging(&self) -> Option<bool> {
130 self.server_info()
131 .map(|server_details| server_details.capabilities.logging.is_some())
132 }
133
134 fn instructions(&self) -> Option<String> {
135 self.server_info()?.instructions
136 }
137
138 async fn request(
144 &self,
145 request: RequestFromClient,
146 timeout: Option<Duration>,
147 ) -> SdkResult<ResultFromServer> {
148 let response = self
149 .send(MessageFromClient::RequestFromClient(request), None, timeout)
150 .await?;
151
152 let server_message = response.ok_or_else(|| {
153 RpcError::internal_error()
154 .with_message("An empty response was received from the client.".to_string())
155 })?;
156
157 if server_message.is_error() {
158 return Err(server_message.as_error()?.error.into());
159 }
160
161 return Ok(server_message.as_response()?.result);
162 }
163
164 async fn send(
165 &self,
166 message: MessageFromClient,
167 request_id: Option<RequestId>,
168 request_timeout: Option<Duration>,
169 ) -> SdkResult<Option<ServerMessage>>;
170
171 async fn send_batch(
172 &self,
173 messages: Vec<ClientMessage>,
174 timeout: Option<Duration>,
175 ) -> SdkResult<Option<Vec<ServerMessage>>>;
176
177 async fn send_notification(&self, notification: NotificationFromClient) -> SdkResult<()> {
181 self.send(notification.into(), None, None).await?;
182 Ok(())
183 }
184
185 async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
196 let ping_request = PingRequest::new(None);
197 let response = self.request(ping_request.into(), timeout).await?;
198 Ok(response.try_into()?)
199 }
200
201 async fn complete(
202 &self,
203 params: CompleteRequestParams,
204 ) -> SdkResult<crate::schema::CompleteResult> {
205 let request = CompleteRequest::new(params);
206 let response = self.request(request.into(), None).await?;
207 Ok(response.try_into()?)
208 }
209
210 async fn set_logging_level(&self, level: LoggingLevel) -> SdkResult<crate::schema::Result> {
211 let request = SetLevelRequest::new(SetLevelRequestParams { level });
212 let response = self.request(request.into(), None).await?;
213 Ok(response.try_into()?)
214 }
215
216 async fn get_prompt(
217 &self,
218 params: GetPromptRequestParams,
219 ) -> SdkResult<crate::schema::GetPromptResult> {
220 let request = GetPromptRequest::new(params);
221 let response = self.request(request.into(), None).await?;
222 Ok(response.try_into()?)
223 }
224
225 async fn list_prompts(
226 &self,
227 params: Option<ListPromptsRequestParams>,
228 ) -> SdkResult<crate::schema::ListPromptsResult> {
229 let request = ListPromptsRequest::new(params);
230 let response = self.request(request.into(), None).await?;
231 Ok(response.try_into()?)
232 }
233
234 async fn list_resources(
235 &self,
236 params: Option<ListResourcesRequestParams>,
237 ) -> SdkResult<crate::schema::ListResourcesResult> {
238 let request =
243 ListResourcesRequest::new(params.or(Some(ListResourcesRequestParams::default())));
244 let response = self.request(request.into(), None).await?;
245 Ok(response.try_into()?)
246 }
247
248 async fn list_resource_templates(
249 &self,
250 params: Option<ListResourceTemplatesRequestParams>,
251 ) -> SdkResult<crate::schema::ListResourceTemplatesResult> {
252 let request = ListResourceTemplatesRequest::new(params);
253 let response = self.request(request.into(), None).await?;
254 Ok(response.try_into()?)
255 }
256
257 async fn read_resource(
258 &self,
259 params: ReadResourceRequestParams,
260 ) -> SdkResult<crate::schema::ReadResourceResult> {
261 let request = ReadResourceRequest::new(params);
262 let response = self.request(request.into(), None).await?;
263 Ok(response.try_into()?)
264 }
265
266 async fn subscribe_resource(
267 &self,
268 params: SubscribeRequestParams,
269 ) -> SdkResult<crate::schema::Result> {
270 let request = SubscribeRequest::new(params);
271 let response = self.request(request.into(), None).await?;
272 Ok(response.try_into()?)
273 }
274
275 async fn unsubscribe_resource(
276 &self,
277 params: UnsubscribeRequestParams,
278 ) -> SdkResult<crate::schema::Result> {
279 let request = UnsubscribeRequest::new(params);
280 let response = self.request(request.into(), None).await?;
281 Ok(response.try_into()?)
282 }
283
284 async fn call_tool(&self, params: CallToolRequestParams) -> SdkResult<CallToolResult> {
285 let request = CallToolRequest::new(params);
286 let response = self.request(request.into(), None).await?;
287 Ok(response.try_into()?)
288 }
289
290 async fn list_tools(
291 &self,
292 params: Option<ListToolsRequestParams>,
293 ) -> SdkResult<crate::schema::ListToolsResult> {
294 let request = ListToolsRequest::new(params);
295 let response = self.request(request.into(), None).await?;
296 Ok(response.try_into()?)
297 }
298
299 async fn send_roots_list_changed(
300 &self,
301 params: Option<RootsListChangedNotificationParams>,
302 ) -> SdkResult<()> {
303 let notification = RootsListChangedNotification::new(params);
304 self.send_notification(notification.into()).await
305 }
306
307 fn assert_server_capabilities(&self, request_method: &String) -> SdkResult<()> {
313 let entity = "Server";
314
315 let capabilities = self.server_capabilities().ok_or::<RpcError>(
316 RpcError::internal_error().with_message("Server is not initialized!".to_string()),
317 )?;
318
319 if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
320 return Err(RpcError::internal_error()
321 .with_message(format_assertion_message(entity, "logging", request_method))
322 .into());
323 }
324
325 if [
326 GetPromptRequest::method_name(),
327 ListPromptsRequest::method_name(),
328 ]
329 .contains(request_method)
330 && capabilities.prompts.is_none()
331 {
332 return Err(RpcError::internal_error()
333 .with_message(format_assertion_message(entity, "prompts", request_method))
334 .into());
335 }
336
337 if [
338 ListResourcesRequest::method_name(),
339 ListResourceTemplatesRequest::method_name(),
340 ReadResourceRequest::method_name(),
341 SubscribeRequest::method_name(),
342 UnsubscribeRequest::method_name(),
343 ]
344 .contains(request_method)
345 && capabilities.resources.is_none()
346 {
347 return Err(RpcError::internal_error()
348 .with_message(format_assertion_message(
349 entity,
350 "resources",
351 request_method,
352 ))
353 .into());
354 }
355
356 if [
357 CallToolRequest::method_name(),
358 ListToolsRequest::method_name(),
359 ]
360 .contains(request_method)
361 && capabilities.tools.is_none()
362 {
363 return Err(RpcError::internal_error()
364 .with_message(format_assertion_message(entity, "tools", request_method))
365 .into());
366 }
367
368 Ok(())
369 }
370
371 fn assert_client_notification_capabilities(
372 &self,
373 notification_method: &String,
374 ) -> std::result::Result<(), RpcError> {
375 let entity = "Client";
376 let capabilities = &self.client_info().capabilities;
377
378 if *notification_method == RootsListChangedNotification::method_name()
379 && capabilities.roots.is_some()
380 {
381 return Err(
382 RpcError::internal_error().with_message(format_assertion_message(
383 entity,
384 "roots list changed notifications",
385 notification_method,
386 )),
387 );
388 }
389
390 Ok(())
391 }
392
393 fn assert_client_request_capabilities(
394 &self,
395 request_method: &String,
396 ) -> std::result::Result<(), RpcError> {
397 let entity = "Client";
398 let capabilities = &self.client_info().capabilities;
399
400 if *request_method == CreateMessageRequest::method_name() && capabilities.sampling.is_some()
401 {
402 return Err(
403 RpcError::internal_error().with_message(format_assertion_message(
404 entity,
405 "sampling capability",
406 request_method,
407 )),
408 );
409 }
410
411 if *request_method == ListRootsRequest::method_name() && capabilities.roots.is_some() {
412 return Err(
413 RpcError::internal_error().with_message(format_assertion_message(
414 entity,
415 "roots capability",
416 request_method,
417 )),
418 );
419 }
420
421 Ok(())
422 }
423}