1use std::borrow::Cow;
35use std::convert::Infallible;
36use std::fmt;
37use std::future::Future;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::task::{Context, Poll};
41
42use schemars::{JsonSchema, Schema, SchemaGenerator};
43use serde::Serialize;
44use serde::de::DeserializeOwned;
45use serde_json::Value;
46use tower::util::BoxCloneService;
47use tower_service::Service;
48
49use crate::context::RequestContext;
50use crate::error::{Error, Result};
51use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
52
53#[derive(Debug, Clone)]
62pub struct ToolRequest {
63 pub ctx: RequestContext,
65 pub args: Value,
67}
68
69impl ToolRequest {
70 pub fn new(ctx: RequestContext, args: Value) -> Self {
72 Self { ctx, args }
73 }
74}
75
76pub type BoxToolService = BoxCloneService<ToolRequest, CallToolResult, Infallible>;
82
83pub struct ToolCatchError<S> {
89 inner: S,
90}
91
92impl<S> ToolCatchError<S> {
93 pub fn new(inner: S) -> Self {
95 Self { inner }
96 }
97}
98
99impl<S: Clone> Clone for ToolCatchError<S> {
100 fn clone(&self) -> Self {
101 Self {
102 inner: self.inner.clone(),
103 }
104 }
105}
106
107impl<S: fmt::Debug> fmt::Debug for ToolCatchError<S> {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 f.debug_struct("ToolCatchError")
110 .field("inner", &self.inner)
111 .finish()
112 }
113}
114
115impl<S> Service<ToolRequest> for ToolCatchError<S>
116where
117 S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
118 S::Error: fmt::Display + Send,
119 S::Future: Send,
120{
121 type Response = CallToolResult;
122 type Error = Infallible;
123 type Future =
124 Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Infallible>> + Send>>;
125
126 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
127 match self.inner.poll_ready(cx) {
129 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
130 Poll::Ready(Err(_)) => Poll::Ready(Ok(())),
131 Poll::Pending => Poll::Pending,
132 }
133 }
134
135 fn call(&mut self, req: ToolRequest) -> Self::Future {
136 let fut = self.inner.call(req);
137
138 Box::pin(async move {
139 match fut.await {
140 Ok(result) => Ok(result),
141 Err(err) => Ok(CallToolResult::error(err.to_string())),
142 }
143 })
144 }
145}
146
147#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
168pub struct NoParams;
169
170impl<'de> serde::Deserialize<'de> for NoParams {
171 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
172 where
173 D: serde::Deserializer<'de>,
174 {
175 struct NoParamsVisitor;
177
178 impl<'de> serde::de::Visitor<'de> for NoParamsVisitor {
179 type Value = NoParams;
180
181 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
182 formatter.write_str("null or an object")
183 }
184
185 fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
186 where
187 E: serde::de::Error,
188 {
189 Ok(NoParams)
190 }
191
192 fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
193 where
194 E: serde::de::Error,
195 {
196 Ok(NoParams)
197 }
198
199 fn visit_some<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
200 where
201 D: serde::Deserializer<'de>,
202 {
203 serde::Deserialize::deserialize(deserializer)
204 }
205
206 fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
207 where
208 A: serde::de::MapAccess<'de>,
209 {
210 while map
212 .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
213 .is_some()
214 {}
215 Ok(NoParams)
216 }
217 }
218
219 deserializer.deserialize_any(NoParamsVisitor)
220 }
221}
222
223impl JsonSchema for NoParams {
224 fn schema_name() -> Cow<'static, str> {
225 Cow::Borrowed("NoParams")
226 }
227
228 fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
229 serde_json::json!({
230 "type": "object"
231 })
232 .try_into()
233 .expect("valid schema")
234 }
235}
236
237pub fn validate_tool_name(name: &str) -> Result<()> {
245 if name.is_empty() {
246 return Err(Error::tool("Tool name cannot be empty"));
247 }
248 if name.len() > 128 {
249 return Err(Error::tool(format!(
250 "Tool name '{}' exceeds maximum length of 128 characters (got {})",
251 name,
252 name.len()
253 )));
254 }
255 if let Some(invalid_char) = name
256 .chars()
257 .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
258 {
259 return Err(Error::tool(format!(
260 "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
261 name, invalid_char
262 )));
263 }
264 Ok(())
265}
266
267pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
269
270pub trait ToolHandler: Send + Sync {
272 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;
274
275 fn call_with_context(
280 &self,
281 _ctx: RequestContext,
282 args: Value,
283 ) -> BoxFuture<'_, Result<CallToolResult>> {
284 self.call(args)
285 }
286
287 fn uses_context(&self) -> bool {
289 false
290 }
291
292 fn input_schema(&self) -> Value;
294}
295
296pub(crate) struct ToolHandlerService<H> {
301 handler: Arc<H>,
302}
303
304impl<H> ToolHandlerService<H> {
305 pub(crate) fn new(handler: H) -> Self {
306 Self {
307 handler: Arc::new(handler),
308 }
309 }
310}
311
312impl<H> Clone for ToolHandlerService<H> {
313 fn clone(&self) -> Self {
314 Self {
315 handler: self.handler.clone(),
316 }
317 }
318}
319
320impl<H> Service<ToolRequest> for ToolHandlerService<H>
321where
322 H: ToolHandler + 'static,
323{
324 type Response = CallToolResult;
325 type Error = Error;
326 type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
327
328 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
329 Poll::Ready(Ok(()))
330 }
331
332 fn call(&mut self, req: ToolRequest) -> Self::Future {
333 let handler = self.handler.clone();
334 Box::pin(async move { handler.call_with_context(req.ctx, req.args).await })
335 }
336}
337
338pub struct Tool {
345 pub name: String,
347 pub title: Option<String>,
349 pub description: Option<String>,
351 pub output_schema: Option<Value>,
353 pub icons: Option<Vec<ToolIcon>>,
355 pub annotations: Option<ToolAnnotations>,
357 pub(crate) service: BoxToolService,
359 pub(crate) input_schema: Value,
361}
362
363impl std::fmt::Debug for Tool {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 f.debug_struct("Tool")
366 .field("name", &self.name)
367 .field("title", &self.title)
368 .field("description", &self.description)
369 .field("output_schema", &self.output_schema)
370 .field("icons", &self.icons)
371 .field("annotations", &self.annotations)
372 .finish_non_exhaustive()
373 }
374}
375
376unsafe impl Send for Tool {}
379unsafe impl Sync for Tool {}
380
381impl Clone for Tool {
382 fn clone(&self) -> Self {
383 Self {
384 name: self.name.clone(),
385 title: self.title.clone(),
386 description: self.description.clone(),
387 output_schema: self.output_schema.clone(),
388 icons: self.icons.clone(),
389 annotations: self.annotations.clone(),
390 service: self.service.clone(),
391 input_schema: self.input_schema.clone(),
392 }
393 }
394}
395
396impl Tool {
397 pub fn builder(name: impl Into<String>) -> ToolBuilder {
399 ToolBuilder::new(name)
400 }
401
402 pub fn definition(&self) -> ToolDefinition {
404 ToolDefinition {
405 name: self.name.clone(),
406 title: self.title.clone(),
407 description: self.description.clone(),
408 input_schema: self.input_schema.clone(),
409 output_schema: self.output_schema.clone(),
410 icons: self.icons.clone(),
411 annotations: self.annotations.clone(),
412 }
413 }
414
415 pub fn call(&self, args: Value) -> BoxFuture<'static, CallToolResult> {
420 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
421 self.call_with_context(ctx, args)
422 }
423
424 pub fn call_with_context(
435 &self,
436 ctx: RequestContext,
437 args: Value,
438 ) -> BoxFuture<'static, CallToolResult> {
439 use tower::ServiceExt;
440 let service = self.service.clone();
441 Box::pin(async move {
442 service.oneshot(ToolRequest::new(ctx, args)).await.unwrap()
445 })
446 }
447
448 pub fn with_name_prefix(&self, prefix: &str) -> Self {
476 Self {
477 name: format!("{}.{}", prefix, self.name),
478 title: self.title.clone(),
479 description: self.description.clone(),
480 output_schema: self.output_schema.clone(),
481 icons: self.icons.clone(),
482 annotations: self.annotations.clone(),
483 service: self.service.clone(),
484 input_schema: self.input_schema.clone(),
485 }
486 }
487
488 fn from_handler<H: ToolHandler + 'static>(
490 name: String,
491 title: Option<String>,
492 description: Option<String>,
493 output_schema: Option<Value>,
494 icons: Option<Vec<ToolIcon>>,
495 annotations: Option<ToolAnnotations>,
496 handler: H,
497 ) -> Self {
498 let input_schema = handler.input_schema();
499 let handler_service = ToolHandlerService::new(handler);
500 let catch_error = ToolCatchError::new(handler_service);
501 let service = BoxCloneService::new(catch_error);
502
503 Self {
504 name,
505 title,
506 description,
507 output_schema,
508 icons,
509 annotations,
510 service,
511 input_schema,
512 }
513 }
514}
515
516pub struct ToolBuilder {
545 name: String,
546 title: Option<String>,
547 description: Option<String>,
548 output_schema: Option<Value>,
549 icons: Option<Vec<ToolIcon>>,
550 annotations: Option<ToolAnnotations>,
551}
552
553impl ToolBuilder {
554 pub fn new(name: impl Into<String>) -> Self {
555 Self {
556 name: name.into(),
557 title: None,
558 description: None,
559 output_schema: None,
560 icons: None,
561 annotations: None,
562 }
563 }
564
565 pub fn title(mut self, title: impl Into<String>) -> Self {
567 self.title = Some(title.into());
568 self
569 }
570
571 pub fn output_schema(mut self, schema: Value) -> Self {
573 self.output_schema = Some(schema);
574 self
575 }
576
577 pub fn icon(mut self, src: impl Into<String>) -> Self {
579 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
580 src: src.into(),
581 mime_type: None,
582 sizes: None,
583 });
584 self
585 }
586
587 pub fn icon_with_meta(
589 mut self,
590 src: impl Into<String>,
591 mime_type: Option<String>,
592 sizes: Option<Vec<String>>,
593 ) -> Self {
594 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
595 src: src.into(),
596 mime_type,
597 sizes,
598 });
599 self
600 }
601
602 pub fn description(mut self, description: impl Into<String>) -> Self {
604 self.description = Some(description.into());
605 self
606 }
607
608 pub fn read_only(mut self) -> Self {
610 self.annotations
611 .get_or_insert_with(ToolAnnotations::default)
612 .read_only_hint = true;
613 self
614 }
615
616 pub fn non_destructive(mut self) -> Self {
618 self.annotations
619 .get_or_insert_with(ToolAnnotations::default)
620 .destructive_hint = false;
621 self
622 }
623
624 pub fn idempotent(mut self) -> Self {
626 self.annotations
627 .get_or_insert_with(ToolAnnotations::default)
628 .idempotent_hint = true;
629 self
630 }
631
632 pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
634 self.annotations = Some(annotations);
635 self
636 }
637
638 pub fn no_params_handler<F, Fut>(self, handler: F) -> ToolBuilderWithNoParamsHandler<F>
657 where
658 F: Fn() -> Fut + Send + Sync + 'static,
659 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
660 {
661 ToolBuilderWithNoParamsHandler {
662 name: self.name,
663 title: self.title,
664 description: self.description,
665 output_schema: self.output_schema,
666 icons: self.icons,
667 annotations: self.annotations,
668 handler,
669 }
670 }
671
672 pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
716 where
717 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
718 F: Fn(I) -> Fut + Send + Sync + 'static,
719 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
720 {
721 ToolBuilderWithHandler {
722 name: self.name,
723 title: self.title,
724 description: self.description,
725 output_schema: self.output_schema,
726 icons: self.icons,
727 annotations: self.annotations,
728 handler,
729 _phantom: std::marker::PhantomData,
730 }
731 }
732
733 pub fn extractor_handler<S, F, T>(
787 self,
788 state: S,
789 handler: F,
790 ) -> crate::extract::ToolBuilderWithExtractor<S, F, T>
791 where
792 S: Clone + Send + Sync + 'static,
793 F: crate::extract::ExtractorHandler<S, T> + Clone,
794 T: Send + Sync + 'static,
795 {
796 crate::extract::ToolBuilderWithExtractor {
797 name: self.name,
798 title: self.title,
799 description: self.description,
800 output_schema: self.output_schema,
801 icons: self.icons,
802 annotations: self.annotations,
803 state,
804 handler,
805 input_schema: F::input_schema(),
806 _phantom: std::marker::PhantomData,
807 }
808 }
809
810 pub fn extractor_handler_typed<S, F, T, I>(
844 self,
845 state: S,
846 handler: F,
847 ) -> crate::extract::ToolBuilderWithTypedExtractor<S, F, T, I>
848 where
849 S: Clone + Send + Sync + 'static,
850 F: crate::extract::TypedExtractorHandler<S, T, I> + Clone,
851 T: Send + Sync + 'static,
852 I: schemars::JsonSchema + Send + Sync + 'static,
853 {
854 crate::extract::ToolBuilderWithTypedExtractor {
855 name: self.name,
856 title: self.title,
857 description: self.description,
858 output_schema: self.output_schema,
859 icons: self.icons,
860 annotations: self.annotations,
861 state,
862 handler,
863 _phantom: std::marker::PhantomData,
864 }
865 }
866}
867
868struct NoParamsTypedHandler<F> {
872 handler: F,
873}
874
875impl<F, Fut> ToolHandler for NoParamsTypedHandler<F>
876where
877 F: Fn() -> Fut + Send + Sync + 'static,
878 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
879{
880 fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
881 Box::pin(async move { (self.handler)().await })
882 }
883
884 fn input_schema(&self) -> Value {
885 serde_json::json!({ "type": "object" })
886 }
887}
888
889pub struct ToolBuilderWithHandler<I, F> {
891 name: String,
892 title: Option<String>,
893 description: Option<String>,
894 output_schema: Option<Value>,
895 icons: Option<Vec<ToolIcon>>,
896 annotations: Option<ToolAnnotations>,
897 handler: F,
898 _phantom: std::marker::PhantomData<I>,
899}
900
901pub struct ToolBuilderWithNoParamsHandler<F> {
905 name: String,
906 title: Option<String>,
907 description: Option<String>,
908 output_schema: Option<Value>,
909 icons: Option<Vec<ToolIcon>>,
910 annotations: Option<ToolAnnotations>,
911 handler: F,
912}
913
914impl<F, Fut> ToolBuilderWithNoParamsHandler<F>
915where
916 F: Fn() -> Fut + Send + Sync + 'static,
917 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
918{
919 pub fn build(self) -> Result<Tool> {
923 validate_tool_name(&self.name)?;
924 Ok(Tool::from_handler(
925 self.name,
926 self.title,
927 self.description,
928 self.output_schema,
929 self.icons,
930 self.annotations,
931 NoParamsTypedHandler {
932 handler: self.handler,
933 },
934 ))
935 }
936
937 pub fn layer<L>(self, layer: L) -> ToolBuilderWithNoParamsHandlerLayer<F, L> {
941 ToolBuilderWithNoParamsHandlerLayer {
942 name: self.name,
943 title: self.title,
944 description: self.description,
945 output_schema: self.output_schema,
946 icons: self.icons,
947 annotations: self.annotations,
948 handler: self.handler,
949 layer,
950 }
951 }
952}
953
954pub struct ToolBuilderWithNoParamsHandlerLayer<F, L> {
956 name: String,
957 title: Option<String>,
958 description: Option<String>,
959 output_schema: Option<Value>,
960 icons: Option<Vec<ToolIcon>>,
961 annotations: Option<ToolAnnotations>,
962 handler: F,
963 layer: L,
964}
965
966#[allow(private_bounds)]
967impl<F, Fut, L> ToolBuilderWithNoParamsHandlerLayer<F, L>
968where
969 F: Fn() -> Fut + Send + Sync + 'static,
970 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
971 L: tower::Layer<ToolHandlerService<NoParamsTypedHandler<F>>> + Clone + Send + Sync + 'static,
972 L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
973 <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
974 <L::Service as Service<ToolRequest>>::Future: Send,
975{
976 pub fn build(self) -> Result<Tool> {
980 validate_tool_name(&self.name)?;
981
982 let input_schema = serde_json::json!({ "type": "object" });
983
984 let handler_service = ToolHandlerService::new(NoParamsTypedHandler {
985 handler: self.handler,
986 });
987 let layered = self.layer.layer(handler_service);
988 let catch_error = ToolCatchError::new(layered);
989 let service = BoxCloneService::new(catch_error);
990
991 Ok(Tool {
992 name: self.name,
993 title: self.title,
994 description: self.description,
995 output_schema: self.output_schema,
996 icons: self.icons,
997 annotations: self.annotations,
998 service,
999 input_schema,
1000 })
1001 }
1002
1003 pub fn layer<L2>(
1005 self,
1006 layer: L2,
1007 ) -> ToolBuilderWithNoParamsHandlerLayer<F, tower::layer::util::Stack<L2, L>> {
1008 ToolBuilderWithNoParamsHandlerLayer {
1009 name: self.name,
1010 title: self.title,
1011 description: self.description,
1012 output_schema: self.output_schema,
1013 icons: self.icons,
1014 annotations: self.annotations,
1015 handler: self.handler,
1016 layer: tower::layer::util::Stack::new(layer, self.layer),
1017 }
1018 }
1019}
1020
1021impl<I, F, Fut> ToolBuilderWithHandler<I, F>
1022where
1023 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1024 F: Fn(I) -> Fut + Send + Sync + 'static,
1025 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1026{
1027 pub fn build(self) -> Result<Tool> {
1031 validate_tool_name(&self.name)?;
1032 Ok(Tool::from_handler(
1033 self.name,
1034 self.title,
1035 self.description,
1036 self.output_schema,
1037 self.icons,
1038 self.annotations,
1039 TypedHandler {
1040 handler: self.handler,
1041 _phantom: std::marker::PhantomData,
1042 },
1043 ))
1044 }
1045
1046 pub fn layer<L>(self, layer: L) -> ToolBuilderWithLayer<I, F, L> {
1073 ToolBuilderWithLayer {
1074 name: self.name,
1075 title: self.title,
1076 description: self.description,
1077 output_schema: self.output_schema,
1078 icons: self.icons,
1079 annotations: self.annotations,
1080 handler: self.handler,
1081 layer,
1082 _phantom: std::marker::PhantomData,
1083 }
1084 }
1085}
1086
1087pub struct ToolBuilderWithLayer<I, F, L> {
1091 name: String,
1092 title: Option<String>,
1093 description: Option<String>,
1094 output_schema: Option<Value>,
1095 icons: Option<Vec<ToolIcon>>,
1096 annotations: Option<ToolAnnotations>,
1097 handler: F,
1098 layer: L,
1099 _phantom: std::marker::PhantomData<I>,
1100}
1101
1102#[allow(private_bounds)]
1105impl<I, F, Fut, L> ToolBuilderWithLayer<I, F, L>
1106where
1107 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1108 F: Fn(I) -> Fut + Send + Sync + 'static,
1109 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1110 L: tower::Layer<ToolHandlerService<TypedHandler<I, F>>> + Clone + Send + Sync + 'static,
1111 L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1112 <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
1113 <L::Service as Service<ToolRequest>>::Future: Send,
1114{
1115 pub fn build(self) -> Result<Tool> {
1119 validate_tool_name(&self.name)?;
1120
1121 let input_schema = schemars::schema_for!(I);
1122 let input_schema = serde_json::to_value(input_schema)
1123 .unwrap_or_else(|_| serde_json::json!({ "type": "object" }));
1124
1125 let handler_service = ToolHandlerService::new(TypedHandler {
1126 handler: self.handler,
1127 _phantom: std::marker::PhantomData,
1128 });
1129 let layered = self.layer.layer(handler_service);
1130 let catch_error = ToolCatchError::new(layered);
1131 let service = BoxCloneService::new(catch_error);
1132
1133 Ok(Tool {
1134 name: self.name,
1135 title: self.title,
1136 description: self.description,
1137 output_schema: self.output_schema,
1138 icons: self.icons,
1139 annotations: self.annotations,
1140 service,
1141 input_schema,
1142 })
1143 }
1144
1145 pub fn layer<L2>(
1150 self,
1151 layer: L2,
1152 ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<L2, L>> {
1153 ToolBuilderWithLayer {
1154 name: self.name,
1155 title: self.title,
1156 description: self.description,
1157 output_schema: self.output_schema,
1158 icons: self.icons,
1159 annotations: self.annotations,
1160 handler: self.handler,
1161 layer: tower::layer::util::Stack::new(layer, self.layer),
1162 _phantom: std::marker::PhantomData,
1163 }
1164 }
1165}
1166
1167struct TypedHandler<I, F> {
1173 handler: F,
1174 _phantom: std::marker::PhantomData<I>,
1175}
1176
1177impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
1178where
1179 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1180 F: Fn(I) -> Fut + Send + Sync + 'static,
1181 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1182{
1183 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1184 Box::pin(async move {
1185 let input: I = serde_json::from_value(args)
1186 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
1187 (self.handler)(input).await
1188 })
1189 }
1190
1191 fn input_schema(&self) -> Value {
1192 let schema = schemars::schema_for!(I);
1193 serde_json::to_value(schema).unwrap_or_else(|_| {
1194 serde_json::json!({
1195 "type": "object"
1196 })
1197 })
1198 }
1199}
1200
1201pub trait McpTool: Send + Sync + 'static {
1242 const NAME: &'static str;
1243 const DESCRIPTION: &'static str;
1244
1245 type Input: JsonSchema + DeserializeOwned + Send;
1246 type Output: Serialize + Send;
1247
1248 fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
1249
1250 fn annotations(&self) -> Option<ToolAnnotations> {
1252 None
1253 }
1254
1255 fn into_tool(self) -> Result<Tool>
1259 where
1260 Self: Sized,
1261 {
1262 validate_tool_name(Self::NAME)?;
1263 let annotations = self.annotations();
1264 let tool = Arc::new(self);
1265 Ok(Tool::from_handler(
1266 Self::NAME.to_string(),
1267 None,
1268 Some(Self::DESCRIPTION.to_string()),
1269 None,
1270 None,
1271 annotations,
1272 McpToolHandler { tool },
1273 ))
1274 }
1275}
1276
1277struct McpToolHandler<T: McpTool> {
1279 tool: Arc<T>,
1280}
1281
1282impl<T: McpTool> ToolHandler for McpToolHandler<T> {
1283 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1284 let tool = self.tool.clone();
1285 Box::pin(async move {
1286 let input: T::Input = serde_json::from_value(args)
1287 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
1288 let output = tool.call(input).await?;
1289 let value = serde_json::to_value(output)
1290 .map_err(|e| Error::tool(format!("Failed to serialize output: {}", e)))?;
1291 Ok(CallToolResult::json(value))
1292 })
1293 }
1294
1295 fn input_schema(&self) -> Value {
1296 let schema = schemars::schema_for!(T::Input);
1297 serde_json::to_value(schema).unwrap_or_else(|_| {
1298 serde_json::json!({
1299 "type": "object"
1300 })
1301 })
1302 }
1303}
1304
1305#[cfg(test)]
1306mod tests {
1307 use super::*;
1308 use crate::extract::{Context, Json, RawArgs, State};
1309 use crate::protocol::Content;
1310 use schemars::JsonSchema;
1311 use serde::Deserialize;
1312
1313 #[derive(Debug, Deserialize, JsonSchema)]
1314 struct GreetInput {
1315 name: String,
1316 }
1317
1318 #[tokio::test]
1319 async fn test_builder_tool() {
1320 let tool = ToolBuilder::new("greet")
1321 .description("Greet someone")
1322 .handler(|input: GreetInput| async move {
1323 Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1324 })
1325 .build()
1326 .expect("valid tool name");
1327
1328 assert_eq!(tool.name, "greet");
1329 assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1330
1331 let result = tool.call(serde_json::json!({"name": "World"})).await;
1332
1333 assert!(!result.is_error);
1334 }
1335
1336 #[tokio::test]
1337 async fn test_raw_handler() {
1338 let tool = ToolBuilder::new("echo")
1339 .description("Echo input")
1340 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1341 Ok(CallToolResult::json(args))
1342 })
1343 .build()
1344 .expect("valid tool name");
1345
1346 let result = tool.call(serde_json::json!({"foo": "bar"})).await;
1347
1348 assert!(!result.is_error);
1349 }
1350
1351 #[test]
1352 fn test_invalid_tool_name_empty() {
1353 let result = ToolBuilder::new("")
1354 .description("Empty name")
1355 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1356 Ok(CallToolResult::json(args))
1357 })
1358 .build();
1359
1360 assert!(result.is_err());
1361 assert!(result.unwrap_err().to_string().contains("cannot be empty"));
1362 }
1363
1364 #[test]
1365 fn test_invalid_tool_name_too_long() {
1366 let long_name = "a".repeat(129);
1367 let result = ToolBuilder::new(long_name)
1368 .description("Too long")
1369 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1370 Ok(CallToolResult::json(args))
1371 })
1372 .build();
1373
1374 assert!(result.is_err());
1375 assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
1376 }
1377
1378 #[test]
1379 fn test_invalid_tool_name_bad_chars() {
1380 let result = ToolBuilder::new("my tool!")
1381 .description("Bad chars")
1382 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1383 Ok(CallToolResult::json(args))
1384 })
1385 .build();
1386
1387 assert!(result.is_err());
1388 assert!(
1389 result
1390 .unwrap_err()
1391 .to_string()
1392 .contains("invalid character")
1393 );
1394 }
1395
1396 #[test]
1397 fn test_valid_tool_names() {
1398 let names = [
1400 "my_tool",
1401 "my-tool",
1402 "my.tool",
1403 "MyTool123",
1404 "a",
1405 &"a".repeat(128),
1406 ];
1407 for name in names {
1408 let result = ToolBuilder::new(name)
1409 .description("Valid")
1410 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1411 Ok(CallToolResult::json(args))
1412 })
1413 .build();
1414 assert!(result.is_ok(), "Expected '{}' to be valid", name);
1415 }
1416 }
1417
1418 #[tokio::test]
1419 async fn test_context_aware_handler() {
1420 use crate::context::notification_channel;
1421 use crate::protocol::{ProgressToken, RequestId};
1422
1423 #[derive(Debug, Deserialize, JsonSchema)]
1424 struct ProcessInput {
1425 count: i32,
1426 }
1427
1428 let tool = ToolBuilder::new("process")
1429 .description("Process with context")
1430 .extractor_handler_typed::<_, _, _, ProcessInput>(
1431 (),
1432 |ctx: Context, Json(input): Json<ProcessInput>| async move {
1433 for i in 0..input.count {
1435 if ctx.is_cancelled() {
1436 return Ok(CallToolResult::error("Cancelled"));
1437 }
1438 ctx.report_progress(i as f64, Some(input.count as f64), None)
1439 .await;
1440 }
1441 Ok(CallToolResult::text(format!(
1442 "Processed {} items",
1443 input.count
1444 )))
1445 },
1446 )
1447 .build()
1448 .expect("valid tool name");
1449
1450 assert_eq!(tool.name, "process");
1451
1452 let (tx, mut rx) = notification_channel(10);
1454 let ctx = RequestContext::new(RequestId::Number(1))
1455 .with_progress_token(ProgressToken::Number(42))
1456 .with_notification_sender(tx);
1457
1458 let result = tool
1459 .call_with_context(ctx, serde_json::json!({"count": 3}))
1460 .await;
1461
1462 assert!(!result.is_error);
1463
1464 let mut progress_count = 0;
1466 while rx.try_recv().is_ok() {
1467 progress_count += 1;
1468 }
1469 assert_eq!(progress_count, 3);
1470 }
1471
1472 #[tokio::test]
1473 async fn test_context_aware_handler_cancellation() {
1474 use crate::protocol::RequestId;
1475 use std::sync::atomic::{AtomicI32, Ordering};
1476
1477 #[derive(Debug, Deserialize, JsonSchema)]
1478 struct LongRunningInput {
1479 iterations: i32,
1480 }
1481
1482 let iterations_completed = Arc::new(AtomicI32::new(0));
1483 let iterations_ref = iterations_completed.clone();
1484
1485 let tool = ToolBuilder::new("long_running")
1486 .description("Long running task")
1487 .extractor_handler_typed::<_, _, _, LongRunningInput>(
1488 (),
1489 move |ctx: Context, Json(input): Json<LongRunningInput>| {
1490 let completed = iterations_ref.clone();
1491 async move {
1492 for i in 0..input.iterations {
1493 if ctx.is_cancelled() {
1494 return Ok(CallToolResult::error("Cancelled"));
1495 }
1496 completed.fetch_add(1, Ordering::SeqCst);
1497 tokio::task::yield_now().await;
1499 if i == 2 {
1501 ctx.cancellation_token().cancel();
1502 }
1503 }
1504 Ok(CallToolResult::text("Done"))
1505 }
1506 },
1507 )
1508 .build()
1509 .expect("valid tool name");
1510
1511 let ctx = RequestContext::new(RequestId::Number(1));
1512
1513 let result = tool
1514 .call_with_context(ctx, serde_json::json!({"iterations": 10}))
1515 .await;
1516
1517 assert!(result.is_error);
1520 assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
1521 }
1522
1523 #[tokio::test]
1524 async fn test_tool_builder_with_enhanced_fields() {
1525 let output_schema = serde_json::json!({
1526 "type": "object",
1527 "properties": {
1528 "greeting": {"type": "string"}
1529 }
1530 });
1531
1532 let tool = ToolBuilder::new("greet")
1533 .title("Greeting Tool")
1534 .description("Greet someone")
1535 .output_schema(output_schema.clone())
1536 .icon("https://example.com/icon.png")
1537 .icon_with_meta(
1538 "https://example.com/icon-large.png",
1539 Some("image/png".to_string()),
1540 Some(vec!["96x96".to_string()]),
1541 )
1542 .handler(|input: GreetInput| async move {
1543 Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1544 })
1545 .build()
1546 .expect("valid tool name");
1547
1548 assert_eq!(tool.name, "greet");
1549 assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
1550 assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1551 assert_eq!(tool.output_schema, Some(output_schema));
1552 assert!(tool.icons.is_some());
1553 assert_eq!(tool.icons.as_ref().unwrap().len(), 2);
1554
1555 let def = tool.definition();
1557 assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
1558 assert!(def.output_schema.is_some());
1559 assert!(def.icons.is_some());
1560 }
1561
1562 #[tokio::test]
1563 async fn test_handler_with_state() {
1564 let shared = Arc::new("shared-state".to_string());
1565
1566 let tool = ToolBuilder::new("stateful")
1567 .description("Uses shared state")
1568 .extractor_handler_typed::<_, _, _, GreetInput>(
1569 shared,
1570 |State(state): State<Arc<String>>, Json(input): Json<GreetInput>| async move {
1571 Ok(CallToolResult::text(format!(
1572 "{}: Hello, {}!",
1573 state, input.name
1574 )))
1575 },
1576 )
1577 .build()
1578 .expect("valid tool name");
1579
1580 let result = tool.call(serde_json::json!({"name": "World"})).await;
1581 assert!(!result.is_error);
1582 }
1583
1584 #[tokio::test]
1585 async fn test_handler_with_state_and_context() {
1586 use crate::protocol::RequestId;
1587
1588 let shared = Arc::new(42_i32);
1589
1590 let tool =
1591 ToolBuilder::new("stateful_ctx")
1592 .description("Uses state and context")
1593 .extractor_handler_typed::<_, _, _, GreetInput>(
1594 shared,
1595 |State(state): State<Arc<i32>>,
1596 _ctx: Context,
1597 Json(input): Json<GreetInput>| async move {
1598 Ok(CallToolResult::text(format!(
1599 "{}: Hello, {}!",
1600 state, input.name
1601 )))
1602 },
1603 )
1604 .build()
1605 .expect("valid tool name");
1606
1607 let ctx = RequestContext::new(RequestId::Number(1));
1608 let result = tool
1609 .call_with_context(ctx, serde_json::json!({"name": "World"}))
1610 .await;
1611 assert!(!result.is_error);
1612 }
1613
1614 #[tokio::test]
1615 async fn test_handler_no_params() {
1616 let tool = ToolBuilder::new("no_params")
1617 .description("Takes no parameters")
1618 .extractor_handler_typed::<_, _, _, NoParams>((), |Json(_): Json<NoParams>| async {
1619 Ok(CallToolResult::text("no params result"))
1620 })
1621 .build()
1622 .expect("valid tool name");
1623
1624 assert_eq!(tool.name, "no_params");
1625
1626 let result = tool.call(serde_json::json!({})).await;
1628 assert!(!result.is_error);
1629
1630 let result = tool.call(serde_json::json!({"unexpected": "value"})).await;
1632 assert!(!result.is_error);
1633
1634 let schema = tool.definition().input_schema;
1636 assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1637 }
1638
1639 #[tokio::test]
1640 async fn test_handler_with_state_no_params() {
1641 let shared = Arc::new("shared_value".to_string());
1642
1643 let tool = ToolBuilder::new("with_state_no_params")
1644 .description("Takes no parameters but has state")
1645 .extractor_handler_typed::<_, _, _, NoParams>(
1646 shared,
1647 |State(state): State<Arc<String>>, Json(_): Json<NoParams>| async move {
1648 Ok(CallToolResult::text(format!("state: {}", state)))
1649 },
1650 )
1651 .build()
1652 .expect("valid tool name");
1653
1654 assert_eq!(tool.name, "with_state_no_params");
1655
1656 let result = tool.call(serde_json::json!({})).await;
1658 assert!(!result.is_error);
1659 assert_eq!(result.first_text().unwrap(), "state: shared_value");
1660
1661 let schema = tool.definition().input_schema;
1663 assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1664 }
1665
1666 #[tokio::test]
1667 async fn test_handler_no_params_with_context() {
1668 let tool = ToolBuilder::new("no_params_with_context")
1669 .description("Takes no parameters but has context")
1670 .extractor_handler_typed::<_, _, _, NoParams>(
1671 (),
1672 |_ctx: Context, Json(_): Json<NoParams>| async move {
1673 Ok(CallToolResult::text("context available"))
1674 },
1675 )
1676 .build()
1677 .expect("valid tool name");
1678
1679 assert_eq!(tool.name, "no_params_with_context");
1680
1681 let result = tool.call(serde_json::json!({})).await;
1682 assert!(!result.is_error);
1683 assert_eq!(result.first_text().unwrap(), "context available");
1684 }
1685
1686 #[tokio::test]
1687 async fn test_handler_with_state_and_context_no_params() {
1688 let shared = Arc::new("shared".to_string());
1689
1690 let tool = ToolBuilder::new("state_context_no_params")
1691 .description("Has state and context, no params")
1692 .extractor_handler_typed::<_, _, _, NoParams>(
1693 shared,
1694 |State(state): State<Arc<String>>,
1695 _ctx: Context,
1696 Json(_): Json<NoParams>| async move {
1697 Ok(CallToolResult::text(format!("state: {}", state)))
1698 },
1699 )
1700 .build()
1701 .expect("valid tool name");
1702
1703 assert_eq!(tool.name, "state_context_no_params");
1704
1705 let result = tool.call(serde_json::json!({})).await;
1706 assert!(!result.is_error);
1707 assert_eq!(result.first_text().unwrap(), "state: shared");
1708 }
1709
1710 #[tokio::test]
1711 async fn test_raw_handler_with_state() {
1712 let prefix = Arc::new("prefix:".to_string());
1713
1714 let tool = ToolBuilder::new("raw_with_state")
1715 .description("Raw handler with state")
1716 .extractor_handler(
1717 prefix,
1718 |State(state): State<Arc<String>>, RawArgs(args): RawArgs| async move {
1719 Ok(CallToolResult::text(format!("{} {}", state, args)))
1720 },
1721 )
1722 .build()
1723 .expect("valid tool name");
1724
1725 assert_eq!(tool.name, "raw_with_state");
1726
1727 let result = tool.call(serde_json::json!({"key": "value"})).await;
1728 assert!(!result.is_error);
1729 assert!(result.first_text().unwrap().starts_with("prefix:"));
1730 }
1731
1732 #[tokio::test]
1733 async fn test_raw_handler_with_state_and_context() {
1734 let prefix = Arc::new("prefix:".to_string());
1735
1736 let tool = ToolBuilder::new("raw_state_context")
1737 .description("Raw handler with state and context")
1738 .extractor_handler(
1739 prefix,
1740 |State(state): State<Arc<String>>,
1741 _ctx: Context,
1742 RawArgs(args): RawArgs| async move {
1743 Ok(CallToolResult::text(format!("{} {}", state, args)))
1744 },
1745 )
1746 .build()
1747 .expect("valid tool name");
1748
1749 assert_eq!(tool.name, "raw_state_context");
1750
1751 let result = tool.call(serde_json::json!({"key": "value"})).await;
1752 assert!(!result.is_error);
1753 assert!(result.first_text().unwrap().starts_with("prefix:"));
1754 }
1755
1756 #[tokio::test]
1757 async fn test_tool_with_timeout_layer() {
1758 use std::time::Duration;
1759 use tower::timeout::TimeoutLayer;
1760
1761 #[derive(Debug, Deserialize, JsonSchema)]
1762 struct SlowInput {
1763 delay_ms: u64,
1764 }
1765
1766 let tool = ToolBuilder::new("slow_tool")
1768 .description("A slow tool")
1769 .handler(|input: SlowInput| async move {
1770 tokio::time::sleep(Duration::from_millis(input.delay_ms)).await;
1771 Ok(CallToolResult::text("completed"))
1772 })
1773 .layer(TimeoutLayer::new(Duration::from_millis(50)))
1774 .build()
1775 .expect("valid tool name");
1776
1777 let result = tool.call(serde_json::json!({"delay_ms": 10})).await;
1779 assert!(!result.is_error);
1780 assert_eq!(result.first_text().unwrap(), "completed");
1781
1782 let result = tool.call(serde_json::json!({"delay_ms": 200})).await;
1784 assert!(result.is_error);
1785 let msg = result.first_text().unwrap().to_lowercase();
1787 assert!(
1788 msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
1789 "Expected timeout error, got: {}",
1790 msg
1791 );
1792 }
1793
1794 #[tokio::test]
1795 async fn test_tool_with_concurrency_limit_layer() {
1796 use std::sync::atomic::{AtomicU32, Ordering};
1797 use std::time::Duration;
1798 use tower::limit::ConcurrencyLimitLayer;
1799
1800 #[derive(Debug, Deserialize, JsonSchema)]
1801 struct WorkInput {
1802 id: u32,
1803 }
1804
1805 let max_concurrent = Arc::new(AtomicU32::new(0));
1806 let current_concurrent = Arc::new(AtomicU32::new(0));
1807 let max_ref = max_concurrent.clone();
1808 let current_ref = current_concurrent.clone();
1809
1810 let tool = ToolBuilder::new("concurrent_tool")
1812 .description("A concurrent tool")
1813 .handler(move |input: WorkInput| {
1814 let max = max_ref.clone();
1815 let current = current_ref.clone();
1816 async move {
1817 let prev = current.fetch_add(1, Ordering::SeqCst);
1819 max.fetch_max(prev + 1, Ordering::SeqCst);
1820
1821 tokio::time::sleep(Duration::from_millis(50)).await;
1823
1824 current.fetch_sub(1, Ordering::SeqCst);
1825 Ok(CallToolResult::text(format!("completed {}", input.id)))
1826 }
1827 })
1828 .layer(ConcurrencyLimitLayer::new(2))
1829 .build()
1830 .expect("valid tool name");
1831
1832 let handles: Vec<_> = (0..4)
1834 .map(|i| {
1835 let t = tool.call(serde_json::json!({"id": i}));
1836 tokio::spawn(t)
1837 })
1838 .collect();
1839
1840 for handle in handles {
1841 let result = handle.await.unwrap();
1842 assert!(!result.is_error);
1843 }
1844
1845 assert!(max_concurrent.load(Ordering::SeqCst) <= 2);
1847 }
1848
1849 #[tokio::test]
1850 async fn test_tool_with_multiple_layers() {
1851 use std::time::Duration;
1852 use tower::limit::ConcurrencyLimitLayer;
1853 use tower::timeout::TimeoutLayer;
1854
1855 #[derive(Debug, Deserialize, JsonSchema)]
1856 struct Input {
1857 value: String,
1858 }
1859
1860 let tool = ToolBuilder::new("multi_layer_tool")
1862 .description("Tool with multiple layers")
1863 .handler(|input: Input| async move {
1864 Ok(CallToolResult::text(format!("processed: {}", input.value)))
1865 })
1866 .layer(TimeoutLayer::new(Duration::from_secs(5)))
1867 .layer(ConcurrencyLimitLayer::new(10))
1868 .build()
1869 .expect("valid tool name");
1870
1871 let result = tool.call(serde_json::json!({"value": "test"})).await;
1872 assert!(!result.is_error);
1873 assert_eq!(result.first_text().unwrap(), "processed: test");
1874 }
1875
1876 #[test]
1877 fn test_tool_catch_error_clone() {
1878 let tool = ToolBuilder::new("test")
1881 .description("test")
1882 .extractor_handler((), |RawArgs(_args): RawArgs| async {
1883 Ok(CallToolResult::text("ok"))
1884 })
1885 .build()
1886 .unwrap();
1887 let _clone = tool.call(serde_json::json!({}));
1889 }
1890
/// `ToolCatchError`'s manual `Debug` impl renders as a struct named
/// `ToolCatchError` with an `inner` field that delegates to the wrapped
/// service's `Debug`.
#[test]
fn test_tool_catch_error_debug() {
    #[derive(Debug, Clone)]
    struct DebugService;

    impl Service<ToolRequest> for DebugService {
        type Response = CallToolResult;
        type Error = crate::error::Error;
        type Future = Pin<
            Box<
                dyn Future<Output = std::result::Result<CallToolResult, crate::error::Error>>
                    + Send,
            >,
        >;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<std::result::Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: ToolRequest) -> Self::Future {
            Box::pin(async { Ok(CallToolResult::text("ok")) })
        }
    }

    let catch_error = ToolCatchError::new(DebugService);
    let debug = format!("{:?}", catch_error);
    // Exact match: checks the struct name and that `inner` is forwarded.
    assert_eq!(debug, "ToolCatchError { inner: DebugService }");
}
1925
/// `ToolRequest::new` stores the supplied arguments unchanged.
#[test]
fn test_tool_request_new() {
    use crate::protocol::RequestId;

    let ctx = RequestContext::new(RequestId::Number(42));
    let args = serde_json::json!({"key": "value"});
    // `ctx` is not used after construction, so move it instead of cloning.
    let req = ToolRequest::new(ctx, args.clone());

    assert_eq!(req.args, args);
}
1936
/// The JSON schema generated for `NoParams` declares `type: object`.
#[test]
fn test_no_params_schema() {
    let schema_json = serde_json::to_value(schemars::schema_for!(NoParams)).unwrap();
    let declared_type = schema_json.get("type").and_then(serde_json::Value::as_str);
    assert_eq!(
        declared_type,
        Some("object"),
        "NoParams should generate type: object schema"
    );
}
1948
/// `NoParams` deserializes from an empty object, from `null`, and from an
/// object with unexpected fields.
#[test]
fn test_no_params_deserialize() {
    let accepted_inputs = ["{}", "null", r#"{"unexpected": "value"}"#];
    for raw in accepted_inputs {
        let parsed: NoParams = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed, NoParams);
    }
}
1963
1964 #[tokio::test]
1965 async fn test_no_params_type_in_handler() {
1966 let tool = ToolBuilder::new("status")
1968 .description("Get status")
1969 .handler(|_input: NoParams| async move { Ok(CallToolResult::text("OK")) })
1970 .build()
1971 .expect("valid tool name");
1972
1973 let schema = tool.definition().input_schema;
1975 assert_eq!(
1976 schema.get("type").and_then(|v| v.as_str()),
1977 Some("object"),
1978 "NoParams handler should produce type: object schema"
1979 );
1980
1981 let result = tool.call(serde_json::json!({})).await;
1983 assert!(!result.is_error);
1984 }
1985
1986 #[tokio::test]
1987 async fn test_tool_with_name_prefix() {
1988 #[derive(Debug, Deserialize, JsonSchema)]
1989 struct Input {
1990 value: String,
1991 }
1992
1993 let tool = ToolBuilder::new("query")
1994 .description("Query something")
1995 .title("Query Tool")
1996 .handler(|input: Input| async move { Ok(CallToolResult::text(&input.value)) })
1997 .build()
1998 .expect("valid tool name");
1999
2000 let prefixed = tool.with_name_prefix("db");
2002
2003 assert_eq!(prefixed.name, "db.query");
2005
2006 assert_eq!(prefixed.description.as_deref(), Some("Query something"));
2008 assert_eq!(prefixed.title.as_deref(), Some("Query Tool"));
2009
2010 let result = prefixed
2012 .call(serde_json::json!({"value": "test input"}))
2013 .await;
2014 assert!(!result.is_error);
2015 match &result.content[0] {
2016 Content::Text { text, .. } => assert_eq!(text, "test input"),
2017 _ => panic!("Expected text content"),
2018 }
2019 }
2020
2021 #[tokio::test]
2022 async fn test_tool_with_name_prefix_multiple_levels() {
2023 let tool = ToolBuilder::new("action")
2024 .description("Do something")
2025 .handler(|_: NoParams| async move { Ok(CallToolResult::text("done")) })
2026 .build()
2027 .expect("valid tool name");
2028
2029 let prefixed = tool.with_name_prefix("level1");
2031 assert_eq!(prefixed.name, "level1.action");
2032
2033 let double_prefixed = prefixed.with_name_prefix("level0");
2034 assert_eq!(double_prefixed.name, "level0.level1.action");
2035 }
2036
2037 #[tokio::test]
2042 async fn test_no_params_handler_basic() {
2043 let tool = ToolBuilder::new("get_status")
2044 .description("Get current status")
2045 .no_params_handler(|| async { Ok(CallToolResult::text("OK")) })
2046 .build()
2047 .expect("valid tool name");
2048
2049 assert_eq!(tool.name, "get_status");
2050 assert_eq!(tool.description.as_deref(), Some("Get current status"));
2051
2052 let result = tool.call(serde_json::json!({})).await;
2054 assert!(!result.is_error);
2055 assert_eq!(result.first_text().unwrap(), "OK");
2056
2057 let result = tool.call(serde_json::json!(null)).await;
2059 assert!(!result.is_error);
2060
2061 let schema = tool.definition().input_schema;
2063 assert_eq!(schema.get("type").and_then(|v| v.as_str()), Some("object"));
2064 }
2065
2066 #[tokio::test]
2067 async fn test_no_params_handler_with_captured_state() {
2068 let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
2069 let counter_ref = counter.clone();
2070
2071 let tool = ToolBuilder::new("increment")
2072 .description("Increment counter")
2073 .no_params_handler(move || {
2074 let c = counter_ref.clone();
2075 async move {
2076 let prev = c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2077 Ok(CallToolResult::text(format!("Incremented from {}", prev)))
2078 }
2079 })
2080 .build()
2081 .expect("valid tool name");
2082
2083 let _ = tool.call(serde_json::json!({})).await;
2085 let _ = tool.call(serde_json::json!({})).await;
2086 let result = tool.call(serde_json::json!({})).await;
2087
2088 assert!(!result.is_error);
2089 assert_eq!(result.first_text().unwrap(), "Incremented from 2");
2090 assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 3);
2091 }
2092
2093 #[tokio::test]
2094 async fn test_no_params_handler_with_layer() {
2095 use std::time::Duration;
2096 use tower::timeout::TimeoutLayer;
2097
2098 let tool = ToolBuilder::new("slow_status")
2099 .description("Slow status check")
2100 .no_params_handler(|| async {
2101 tokio::time::sleep(Duration::from_millis(10)).await;
2102 Ok(CallToolResult::text("done"))
2103 })
2104 .layer(TimeoutLayer::new(Duration::from_secs(1)))
2105 .build()
2106 .expect("valid tool name");
2107
2108 let result = tool.call(serde_json::json!({})).await;
2109 assert!(!result.is_error);
2110 assert_eq!(result.first_text().unwrap(), "done");
2111 }
2112
2113 #[tokio::test]
2114 async fn test_no_params_handler_timeout() {
2115 use std::time::Duration;
2116 use tower::timeout::TimeoutLayer;
2117
2118 let tool = ToolBuilder::new("very_slow_status")
2119 .description("Very slow status check")
2120 .no_params_handler(|| async {
2121 tokio::time::sleep(Duration::from_millis(200)).await;
2122 Ok(CallToolResult::text("done"))
2123 })
2124 .layer(TimeoutLayer::new(Duration::from_millis(50)))
2125 .build()
2126 .expect("valid tool name");
2127
2128 let result = tool.call(serde_json::json!({})).await;
2129 assert!(result.is_error);
2130 let msg = result.first_text().unwrap().to_lowercase();
2131 assert!(
2132 msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
2133 "Expected timeout error, got: {}",
2134 msg
2135 );
2136 }
2137
2138 #[tokio::test]
2139 async fn test_no_params_handler_with_multiple_layers() {
2140 use std::time::Duration;
2141 use tower::limit::ConcurrencyLimitLayer;
2142 use tower::timeout::TimeoutLayer;
2143
2144 let tool = ToolBuilder::new("multi_layer_status")
2145 .description("Status with multiple layers")
2146 .no_params_handler(|| async { Ok(CallToolResult::text("status ok")) })
2147 .layer(TimeoutLayer::new(Duration::from_secs(5)))
2148 .layer(ConcurrencyLimitLayer::new(10))
2149 .build()
2150 .expect("valid tool name");
2151
2152 let result = tool.call(serde_json::json!({})).await;
2153 assert!(!result.is_error);
2154 assert_eq!(result.first_text().unwrap(), "status ok");
2155 }
2156}