1use std::borrow::Cow;
34use std::convert::Infallible;
35use std::fmt;
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::task::{Context, Poll};
40
41use schemars::{JsonSchema, Schema, SchemaGenerator};
42use serde::Serialize;
43use serde::de::DeserializeOwned;
44use serde_json::Value;
45use tower::util::BoxCloneService;
46use tower_service::Service;
47
48use crate::context::RequestContext;
49use crate::error::{Error, Result, ResultExt};
50use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
51
52#[derive(Debug, Clone)]
61pub struct ToolRequest {
62 pub ctx: RequestContext,
64 pub args: Value,
66}
67
68impl ToolRequest {
69 pub fn new(ctx: RequestContext, args: Value) -> Self {
71 Self { ctx, args }
72 }
73}
74
75pub type BoxToolService = BoxCloneService<ToolRequest, CallToolResult, Infallible>;
81
82pub struct ToolCatchError<S> {
88 inner: S,
89}
90
91impl<S> ToolCatchError<S> {
92 pub fn new(inner: S) -> Self {
94 Self { inner }
95 }
96}
97
98impl<S: Clone> Clone for ToolCatchError<S> {
99 fn clone(&self) -> Self {
100 Self {
101 inner: self.inner.clone(),
102 }
103 }
104}
105
106impl<S: fmt::Debug> fmt::Debug for ToolCatchError<S> {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 f.debug_struct("ToolCatchError")
109 .field("inner", &self.inner)
110 .finish()
111 }
112}
113
114impl<S> Service<ToolRequest> for ToolCatchError<S>
115where
116 S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
117 S::Error: fmt::Display + Send,
118 S::Future: Send,
119{
120 type Response = CallToolResult;
121 type Error = Infallible;
122 type Future =
123 Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Infallible>> + Send>>;
124
125 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
126 match self.inner.poll_ready(cx) {
128 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
129 Poll::Ready(Err(_)) => Poll::Ready(Ok(())),
130 Poll::Pending => Poll::Pending,
131 }
132 }
133
134 fn call(&mut self, req: ToolRequest) -> Self::Future {
135 let fut = self.inner.call(req);
136
137 Box::pin(async move {
138 match fut.await {
139 Ok(result) => Ok(result),
140 Err(err) => Ok(CallToolResult::error(err.to_string())),
141 }
142 })
143 }
144}
145
146#[derive(Clone)]
177pub struct GuardLayer<G> {
178 guard: G,
179}
180
181impl<G> GuardLayer<G> {
182 pub fn new(guard: G) -> Self {
187 Self { guard }
188 }
189}
190
191impl<G, S> tower::Layer<S> for GuardLayer<G>
192where
193 G: Clone,
194{
195 type Service = GuardService<G, S>;
196
197 fn layer(&self, inner: S) -> Self::Service {
198 GuardService {
199 guard: self.guard.clone(),
200 inner,
201 }
202 }
203}
204
205#[derive(Clone)]
209pub struct GuardService<G, S> {
210 guard: G,
211 inner: S,
212}
213
214impl<G, S> Service<ToolRequest> for GuardService<G, S>
215where
216 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
217 S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
218 S::Error: Into<Error> + Send,
219 S::Future: Send,
220{
221 type Response = CallToolResult;
222 type Error = Error;
223 type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
224
225 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
226 self.inner.poll_ready(cx).map_err(Into::into)
227 }
228
229 fn call(&mut self, req: ToolRequest) -> Self::Future {
230 match (self.guard)(&req) {
231 Ok(()) => {
232 let fut = self.inner.call(req);
233 Box::pin(async move { fut.await.map_err(Into::into) })
234 }
235 Err(msg) => Box::pin(async move { Err(Error::tool(msg)) }),
236 }
237 }
238}
239
240#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
260pub struct NoParams;
261
262impl<'de> serde::Deserialize<'de> for NoParams {
263 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
264 where
265 D: serde::Deserializer<'de>,
266 {
267 struct NoParamsVisitor;
269
270 impl<'de> serde::de::Visitor<'de> for NoParamsVisitor {
271 type Value = NoParams;
272
273 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
274 formatter.write_str("null or an object")
275 }
276
277 fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
278 where
279 E: serde::de::Error,
280 {
281 Ok(NoParams)
282 }
283
284 fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
285 where
286 E: serde::de::Error,
287 {
288 Ok(NoParams)
289 }
290
291 fn visit_some<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
292 where
293 D: serde::Deserializer<'de>,
294 {
295 serde::Deserialize::deserialize(deserializer)
296 }
297
298 fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
299 where
300 A: serde::de::MapAccess<'de>,
301 {
302 while map
304 .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
305 .is_some()
306 {}
307 Ok(NoParams)
308 }
309 }
310
311 deserializer.deserialize_any(NoParamsVisitor)
312 }
313}
314
315impl JsonSchema for NoParams {
316 fn schema_name() -> Cow<'static, str> {
317 Cow::Borrowed("NoParams")
318 }
319
320 fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
321 serde_json::json!({
322 "type": "object"
323 })
324 .try_into()
325 .expect("valid schema")
326 }
327}
328
329pub fn validate_tool_name(name: &str) -> Result<()> {
337 if name.is_empty() {
338 return Err(Error::tool("Tool name cannot be empty"));
339 }
340 if name.len() > 128 {
341 return Err(Error::tool(format!(
342 "Tool name '{}' exceeds maximum length of 128 characters (got {})",
343 name,
344 name.len()
345 )));
346 }
347 if let Some(invalid_char) = name
348 .chars()
349 .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
350 {
351 return Err(Error::tool(format!(
352 "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
353 name, invalid_char
354 )));
355 }
356 Ok(())
357}
358
/// Heap-allocated, `Send` future that may borrow for `'a`.
///
/// Used as the return type of [`ToolHandler`] methods (borrowing the handler)
/// and of [`Tool::call`] / [`Tool::call_with_context`] (with `'a = 'static`).
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
361
362pub trait ToolHandler: Send + Sync {
364 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;
366
367 fn call_with_context(
372 &self,
373 _ctx: RequestContext,
374 args: Value,
375 ) -> BoxFuture<'_, Result<CallToolResult>> {
376 self.call(args)
377 }
378
379 fn uses_context(&self) -> bool {
381 false
382 }
383
384 fn input_schema(&self) -> Value;
386}
387
388pub(crate) struct ToolHandlerService<H> {
393 handler: Arc<H>,
394}
395
396impl<H> ToolHandlerService<H> {
397 pub(crate) fn new(handler: H) -> Self {
398 Self {
399 handler: Arc::new(handler),
400 }
401 }
402}
403
404impl<H> Clone for ToolHandlerService<H> {
405 fn clone(&self) -> Self {
406 Self {
407 handler: self.handler.clone(),
408 }
409 }
410}
411
412impl<H> Service<ToolRequest> for ToolHandlerService<H>
413where
414 H: ToolHandler + 'static,
415{
416 type Response = CallToolResult;
417 type Error = Error;
418 type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
419
420 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
421 Poll::Ready(Ok(()))
422 }
423
424 fn call(&mut self, req: ToolRequest) -> Self::Future {
425 let handler = self.handler.clone();
426 Box::pin(async move { handler.call_with_context(req.ctx, req.args).await })
427 }
428}
429
430pub struct Tool {
437 pub name: String,
439 pub title: Option<String>,
441 pub description: Option<String>,
443 pub output_schema: Option<Value>,
445 pub icons: Option<Vec<ToolIcon>>,
447 pub annotations: Option<ToolAnnotations>,
449 pub(crate) service: BoxToolService,
451 pub(crate) input_schema: Value,
453}
454
455impl std::fmt::Debug for Tool {
456 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457 f.debug_struct("Tool")
458 .field("name", &self.name)
459 .field("title", &self.title)
460 .field("description", &self.description)
461 .field("output_schema", &self.output_schema)
462 .field("icons", &self.icons)
463 .field("annotations", &self.annotations)
464 .finish_non_exhaustive()
465 }
466}
467
// SAFETY: NOTE(review) — the justification for this impl is not visible here;
// confirm. If every field of `Tool` is already `Send` (tower documents
// `BoxCloneService` as `Send`), this impl is presumably redundant.
unsafe impl Send for Tool {}
// SAFETY: NOTE(review) — tower's `BoxCloneService` is `Send` but not `Sync`
// (per tower docs; confirm for the pinned version), so this impl asserts that
// sharing `&Tool` across threads is sound. In this file the service is only
// cloned through `&self` (`Tool::call_with_context`, `Clone`,
// `with_name_prefix`). That runs the wrapped services' `Clone` impls
// concurrently, which is only sound if every layered service can be cloned
// from several threads at once. The layered builders in this file only
// require `L::Service: Clone + Send`, not `Sync` — verify this invariant or
// tighten those bounds.
unsafe impl Sync for Tool {}
472
473impl Clone for Tool {
474 fn clone(&self) -> Self {
475 Self {
476 name: self.name.clone(),
477 title: self.title.clone(),
478 description: self.description.clone(),
479 output_schema: self.output_schema.clone(),
480 icons: self.icons.clone(),
481 annotations: self.annotations.clone(),
482 service: self.service.clone(),
483 input_schema: self.input_schema.clone(),
484 }
485 }
486}
487
488impl Tool {
489 pub fn builder(name: impl Into<String>) -> ToolBuilder {
491 ToolBuilder::new(name)
492 }
493
494 pub fn definition(&self) -> ToolDefinition {
496 ToolDefinition {
497 name: self.name.clone(),
498 title: self.title.clone(),
499 description: self.description.clone(),
500 input_schema: self.input_schema.clone(),
501 output_schema: self.output_schema.clone(),
502 icons: self.icons.clone(),
503 annotations: self.annotations.clone(),
504 }
505 }
506
507 pub fn call(&self, args: Value) -> BoxFuture<'static, CallToolResult> {
512 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
513 self.call_with_context(ctx, args)
514 }
515
516 pub fn call_with_context(
527 &self,
528 ctx: RequestContext,
529 args: Value,
530 ) -> BoxFuture<'static, CallToolResult> {
531 use tower::ServiceExt;
532 let service = self.service.clone();
533 Box::pin(async move {
534 service.oneshot(ToolRequest::new(ctx, args)).await.unwrap()
537 })
538 }
539
540 pub fn with_guard<G>(self, guard: G) -> Self
569 where
570 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
571 {
572 let guarded = GuardService {
573 guard,
574 inner: self.service,
575 };
576 let caught = ToolCatchError::new(guarded);
577 Tool {
578 service: BoxCloneService::new(caught),
579 ..self
580 }
581 }
582
583 pub fn with_name_prefix(&self, prefix: &str) -> Self {
610 Self {
611 name: format!("{}.{}", prefix, self.name),
612 title: self.title.clone(),
613 description: self.description.clone(),
614 output_schema: self.output_schema.clone(),
615 icons: self.icons.clone(),
616 annotations: self.annotations.clone(),
617 service: self.service.clone(),
618 input_schema: self.input_schema.clone(),
619 }
620 }
621
622 fn from_handler<H: ToolHandler + 'static>(
624 name: String,
625 title: Option<String>,
626 description: Option<String>,
627 output_schema: Option<Value>,
628 icons: Option<Vec<ToolIcon>>,
629 annotations: Option<ToolAnnotations>,
630 handler: H,
631 ) -> Self {
632 let input_schema = handler.input_schema();
633 let handler_service = ToolHandlerService::new(handler);
634 let catch_error = ToolCatchError::new(handler_service);
635 let service = BoxCloneService::new(catch_error);
636
637 Self {
638 name,
639 title,
640 description,
641 output_schema,
642 icons,
643 annotations,
644 service,
645 input_schema,
646 }
647 }
648}
649
/// Builder for [`Tool`]s.
///
/// Collects metadata, then turns into a handler-specific builder through
/// [`ToolBuilder::handler`], [`ToolBuilder::no_params_handler`],
/// [`ToolBuilder::extractor_handler`] or [`ToolBuilder::extractor_handler_typed`].
pub struct ToolBuilder {
    /// Tool name; checked with `validate_tool_name` by `new`/`try_new`.
    name: String,
    /// Optional display title.
    title: Option<String>,
    /// Optional description.
    description: Option<String>,
    /// Optional JSON Schema for structured output.
    output_schema: Option<Value>,
    /// Icons accumulated by `icon`/`icon_with_meta`.
    icons: Option<Vec<ToolIcon>>,
    /// Behavioral hints, created on first use by the hint setters.
    annotations: Option<ToolAnnotations>,
}
685
686impl ToolBuilder {
687 pub fn new(name: impl Into<String>) -> Self {
699 let name = name.into();
700 if let Err(e) = validate_tool_name(&name) {
701 panic!("{e}");
702 }
703 Self {
704 name,
705 title: None,
706 description: None,
707 output_schema: None,
708 icons: None,
709 annotations: None,
710 }
711 }
712
713 pub fn try_new(name: impl Into<String>) -> Result<Self> {
719 let name = name.into();
720 validate_tool_name(&name)?;
721 Ok(Self {
722 name,
723 title: None,
724 description: None,
725 output_schema: None,
726 icons: None,
727 annotations: None,
728 })
729 }
730
731 pub fn title(mut self, title: impl Into<String>) -> Self {
733 self.title = Some(title.into());
734 self
735 }
736
737 pub fn output_schema(mut self, schema: Value) -> Self {
739 self.output_schema = Some(schema);
740 self
741 }
742
743 pub fn icon(mut self, src: impl Into<String>) -> Self {
745 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
746 src: src.into(),
747 mime_type: None,
748 sizes: None,
749 });
750 self
751 }
752
753 pub fn icon_with_meta(
755 mut self,
756 src: impl Into<String>,
757 mime_type: Option<String>,
758 sizes: Option<Vec<String>>,
759 ) -> Self {
760 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
761 src: src.into(),
762 mime_type,
763 sizes,
764 });
765 self
766 }
767
768 pub fn description(mut self, description: impl Into<String>) -> Self {
770 self.description = Some(description.into());
771 self
772 }
773
774 pub fn read_only(mut self) -> Self {
776 self.annotations
777 .get_or_insert_with(ToolAnnotations::default)
778 .read_only_hint = true;
779 self
780 }
781
782 pub fn non_destructive(mut self) -> Self {
784 self.annotations
785 .get_or_insert_with(ToolAnnotations::default)
786 .destructive_hint = false;
787 self
788 }
789
790 pub fn idempotent(mut self) -> Self {
792 self.annotations
793 .get_or_insert_with(ToolAnnotations::default)
794 .idempotent_hint = true;
795 self
796 }
797
798 pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
800 self.annotations = Some(annotations);
801 self
802 }
803
804 pub fn no_params_handler<F, Fut>(self, handler: F) -> ToolBuilderWithNoParamsHandler<F>
822 where
823 F: Fn() -> Fut + Send + Sync + 'static,
824 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
825 {
826 ToolBuilderWithNoParamsHandler {
827 name: self.name,
828 title: self.title,
829 description: self.description,
830 output_schema: self.output_schema,
831 icons: self.icons,
832 annotations: self.annotations,
833 handler,
834 }
835 }
836
837 pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
880 where
881 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
882 F: Fn(I) -> Fut + Send + Sync + 'static,
883 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
884 {
885 ToolBuilderWithHandler {
886 name: self.name,
887 title: self.title,
888 description: self.description,
889 output_schema: self.output_schema,
890 icons: self.icons,
891 annotations: self.annotations,
892 handler,
893 _phantom: std::marker::PhantomData,
894 }
895 }
896
897 pub fn extractor_handler<S, F, T>(
991 self,
992 state: S,
993 handler: F,
994 ) -> crate::extract::ToolBuilderWithExtractor<S, F, T>
995 where
996 S: Clone + Send + Sync + 'static,
997 F: crate::extract::ExtractorHandler<S, T> + Clone,
998 T: Send + Sync + 'static,
999 {
1000 crate::extract::ToolBuilderWithExtractor {
1001 name: self.name,
1002 title: self.title,
1003 description: self.description,
1004 output_schema: self.output_schema,
1005 icons: self.icons,
1006 annotations: self.annotations,
1007 state,
1008 handler,
1009 input_schema: F::input_schema(),
1010 _phantom: std::marker::PhantomData,
1011 }
1012 }
1013
1014 pub fn extractor_handler_typed<S, F, T, I>(
1051 self,
1052 state: S,
1053 handler: F,
1054 ) -> crate::extract::ToolBuilderWithTypedExtractor<S, F, T, I>
1055 where
1056 S: Clone + Send + Sync + 'static,
1057 F: crate::extract::TypedExtractorHandler<S, T, I> + Clone,
1058 T: Send + Sync + 'static,
1059 I: schemars::JsonSchema + Send + Sync + 'static,
1060 {
1061 crate::extract::ToolBuilderWithTypedExtractor {
1062 name: self.name,
1063 title: self.title,
1064 description: self.description,
1065 output_schema: self.output_schema,
1066 icons: self.icons,
1067 annotations: self.annotations,
1068 state,
1069 handler,
1070 _phantom: std::marker::PhantomData,
1071 }
1072 }
1073}
1074
1075struct NoParamsTypedHandler<F> {
1079 handler: F,
1080}
1081
1082impl<F, Fut> ToolHandler for NoParamsTypedHandler<F>
1083where
1084 F: Fn() -> Fut + Send + Sync + 'static,
1085 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1086{
1087 fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1088 Box::pin(async move { (self.handler)().await })
1089 }
1090
1091 fn input_schema(&self) -> Value {
1092 serde_json::json!({ "type": "object" })
1093 }
1094}
1095
/// Builder state after a typed handler has been attached via
/// [`ToolBuilder::handler`]; finish with `build`, or add middleware with
/// `layer`/`guard`.
pub struct ToolBuilderWithHandler<I, F> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    /// User closure receiving the deserialized input.
    handler: F,
    // Records the input type `I` without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}
1107
1108pub struct ToolBuilderWithNoParamsHandler<F> {
1112 name: String,
1113 title: Option<String>,
1114 description: Option<String>,
1115 output_schema: Option<Value>,
1116 icons: Option<Vec<ToolIcon>>,
1117 annotations: Option<ToolAnnotations>,
1118 handler: F,
1119}
1120
1121impl<F, Fut> ToolBuilderWithNoParamsHandler<F>
1122where
1123 F: Fn() -> Fut + Send + Sync + 'static,
1124 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1125{
1126 pub fn build(self) -> Tool {
1128 Tool::from_handler(
1129 self.name,
1130 self.title,
1131 self.description,
1132 self.output_schema,
1133 self.icons,
1134 self.annotations,
1135 NoParamsTypedHandler {
1136 handler: self.handler,
1137 },
1138 )
1139 }
1140
1141 pub fn layer<L>(self, layer: L) -> ToolBuilderWithNoParamsHandlerLayer<F, L> {
1145 ToolBuilderWithNoParamsHandlerLayer {
1146 name: self.name,
1147 title: self.title,
1148 description: self.description,
1149 output_schema: self.output_schema,
1150 icons: self.icons,
1151 annotations: self.annotations,
1152 handler: self.handler,
1153 layer,
1154 }
1155 }
1156
1157 pub fn guard<G>(self, guard: G) -> ToolBuilderWithNoParamsHandlerLayer<F, GuardLayer<G>>
1161 where
1162 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1163 {
1164 self.layer(GuardLayer::new(guard))
1165 }
1166}
1167
1168pub struct ToolBuilderWithNoParamsHandlerLayer<F, L> {
1170 name: String,
1171 title: Option<String>,
1172 description: Option<String>,
1173 output_schema: Option<Value>,
1174 icons: Option<Vec<ToolIcon>>,
1175 annotations: Option<ToolAnnotations>,
1176 handler: F,
1177 layer: L,
1178}
1179
1180#[allow(private_bounds)]
1181impl<F, Fut, L> ToolBuilderWithNoParamsHandlerLayer<F, L>
1182where
1183 F: Fn() -> Fut + Send + Sync + 'static,
1184 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1185 L: tower::Layer<ToolHandlerService<NoParamsTypedHandler<F>>> + Clone + Send + Sync + 'static,
1186 L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1187 <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
1188 <L::Service as Service<ToolRequest>>::Future: Send,
1189{
1190 pub fn build(self) -> Tool {
1192 let input_schema = serde_json::json!({ "type": "object" });
1193
1194 let handler_service = ToolHandlerService::new(NoParamsTypedHandler {
1195 handler: self.handler,
1196 });
1197 let layered = self.layer.layer(handler_service);
1198 let catch_error = ToolCatchError::new(layered);
1199 let service = BoxCloneService::new(catch_error);
1200
1201 Tool {
1202 name: self.name,
1203 title: self.title,
1204 description: self.description,
1205 output_schema: self.output_schema,
1206 icons: self.icons,
1207 annotations: self.annotations,
1208 service,
1209 input_schema,
1210 }
1211 }
1212
1213 pub fn layer<L2>(
1215 self,
1216 layer: L2,
1217 ) -> ToolBuilderWithNoParamsHandlerLayer<F, tower::layer::util::Stack<L2, L>> {
1218 ToolBuilderWithNoParamsHandlerLayer {
1219 name: self.name,
1220 title: self.title,
1221 description: self.description,
1222 output_schema: self.output_schema,
1223 icons: self.icons,
1224 annotations: self.annotations,
1225 handler: self.handler,
1226 layer: tower::layer::util::Stack::new(layer, self.layer),
1227 }
1228 }
1229
1230 pub fn guard<G>(
1234 self,
1235 guard: G,
1236 ) -> ToolBuilderWithNoParamsHandlerLayer<F, tower::layer::util::Stack<GuardLayer<G>, L>>
1237 where
1238 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1239 {
1240 self.layer(GuardLayer::new(guard))
1241 }
1242}
1243
1244impl<I, F, Fut> ToolBuilderWithHandler<I, F>
1245where
1246 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1247 F: Fn(I) -> Fut + Send + Sync + 'static,
1248 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1249{
1250 pub fn build(self) -> Tool {
1252 Tool::from_handler(
1253 self.name,
1254 self.title,
1255 self.description,
1256 self.output_schema,
1257 self.icons,
1258 self.annotations,
1259 TypedHandler {
1260 handler: self.handler,
1261 _phantom: std::marker::PhantomData,
1262 },
1263 )
1264 }
1265
1266 pub fn layer<L>(self, layer: L) -> ToolBuilderWithLayer<I, F, L> {
1292 ToolBuilderWithLayer {
1293 name: self.name,
1294 title: self.title,
1295 description: self.description,
1296 output_schema: self.output_schema,
1297 icons: self.icons,
1298 annotations: self.annotations,
1299 handler: self.handler,
1300 layer,
1301 _phantom: std::marker::PhantomData,
1302 }
1303 }
1304
1305 pub fn guard<G>(self, guard: G) -> ToolBuilderWithLayer<I, F, GuardLayer<G>>
1312 where
1313 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1314 {
1315 self.layer(GuardLayer::new(guard))
1316 }
1317}
1318
1319pub struct ToolBuilderWithLayer<I, F, L> {
1323 name: String,
1324 title: Option<String>,
1325 description: Option<String>,
1326 output_schema: Option<Value>,
1327 icons: Option<Vec<ToolIcon>>,
1328 annotations: Option<ToolAnnotations>,
1329 handler: F,
1330 layer: L,
1331 _phantom: std::marker::PhantomData<I>,
1332}
1333
1334#[allow(private_bounds)]
1337impl<I, F, Fut, L> ToolBuilderWithLayer<I, F, L>
1338where
1339 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1340 F: Fn(I) -> Fut + Send + Sync + 'static,
1341 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1342 L: tower::Layer<ToolHandlerService<TypedHandler<I, F>>> + Clone + Send + Sync + 'static,
1343 L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1344 <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
1345 <L::Service as Service<ToolRequest>>::Future: Send,
1346{
1347 pub fn build(self) -> Tool {
1349 let input_schema = schemars::schema_for!(I);
1350 let input_schema = serde_json::to_value(input_schema)
1351 .unwrap_or_else(|_| serde_json::json!({ "type": "object" }));
1352
1353 let handler_service = ToolHandlerService::new(TypedHandler {
1354 handler: self.handler,
1355 _phantom: std::marker::PhantomData,
1356 });
1357 let layered = self.layer.layer(handler_service);
1358 let catch_error = ToolCatchError::new(layered);
1359 let service = BoxCloneService::new(catch_error);
1360
1361 Tool {
1362 name: self.name,
1363 title: self.title,
1364 description: self.description,
1365 output_schema: self.output_schema,
1366 icons: self.icons,
1367 annotations: self.annotations,
1368 service,
1369 input_schema,
1370 }
1371 }
1372
1373 pub fn layer<L2>(
1378 self,
1379 layer: L2,
1380 ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<L2, L>> {
1381 ToolBuilderWithLayer {
1382 name: self.name,
1383 title: self.title,
1384 description: self.description,
1385 output_schema: self.output_schema,
1386 icons: self.icons,
1387 annotations: self.annotations,
1388 handler: self.handler,
1389 layer: tower::layer::util::Stack::new(layer, self.layer),
1390 _phantom: std::marker::PhantomData,
1391 }
1392 }
1393
1394 pub fn guard<G>(
1398 self,
1399 guard: G,
1400 ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<GuardLayer<G>, L>>
1401 where
1402 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1403 {
1404 self.layer(GuardLayer::new(guard))
1405 }
1406}
1407
1408struct TypedHandler<I, F> {
1414 handler: F,
1415 _phantom: std::marker::PhantomData<I>,
1416}
1417
1418impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
1419where
1420 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1421 F: Fn(I) -> Fut + Send + Sync + 'static,
1422 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1423{
1424 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1425 Box::pin(async move {
1426 let input: I = serde_json::from_value(args).tool_context("Invalid input")?;
1427 (self.handler)(input).await
1428 })
1429 }
1430
1431 fn input_schema(&self) -> Value {
1432 let schema = schemars::schema_for!(I);
1433 serde_json::to_value(schema).unwrap_or_else(|_| {
1434 serde_json::json!({
1435 "type": "object"
1436 })
1437 })
1438 }
1439}
1440
1441pub trait McpTool: Send + Sync + 'static {
1482 const NAME: &'static str;
1483 const DESCRIPTION: &'static str;
1484
1485 type Input: JsonSchema + DeserializeOwned + Send;
1486 type Output: Serialize + Send;
1487
1488 fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
1489
1490 fn annotations(&self) -> Option<ToolAnnotations> {
1492 None
1493 }
1494
1495 fn into_tool(self) -> Tool
1503 where
1504 Self: Sized,
1505 {
1506 if let Err(e) = validate_tool_name(Self::NAME) {
1507 panic!("{e}");
1508 }
1509 let annotations = self.annotations();
1510 let tool = Arc::new(self);
1511 Tool::from_handler(
1512 Self::NAME.to_string(),
1513 None,
1514 Some(Self::DESCRIPTION.to_string()),
1515 None,
1516 None,
1517 annotations,
1518 McpToolHandler { tool },
1519 )
1520 }
1521}
1522
1523struct McpToolHandler<T: McpTool> {
1525 tool: Arc<T>,
1526}
1527
1528impl<T: McpTool> ToolHandler for McpToolHandler<T> {
1529 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1530 let tool = self.tool.clone();
1531 Box::pin(async move {
1532 let input: T::Input = serde_json::from_value(args).tool_context("Invalid input")?;
1533 let output = tool.call(input).await?;
1534 let value = serde_json::to_value(output).tool_context("Failed to serialize output")?;
1535 Ok(CallToolResult::json(value))
1536 })
1537 }
1538
1539 fn input_schema(&self) -> Value {
1540 let schema = schemars::schema_for!(T::Input);
1541 serde_json::to_value(schema).unwrap_or_else(|_| {
1542 serde_json::json!({
1543 "type": "object"
1544 })
1545 })
1546 }
1547}
1548
1549#[cfg(test)]
1550mod tests {
1551 use super::*;
1552 use crate::extract::{Context, Json, RawArgs, State};
1553 use crate::protocol::Content;
1554 use schemars::JsonSchema;
1555 use serde::Deserialize;
1556
1557 #[derive(Debug, Deserialize, JsonSchema)]
1558 struct GreetInput {
1559 name: String,
1560 }
1561
1562 #[tokio::test]
1563 async fn test_builder_tool() {
1564 let tool = ToolBuilder::new("greet")
1565 .description("Greet someone")
1566 .handler(|input: GreetInput| async move {
1567 Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1568 })
1569 .build();
1570
1571 assert_eq!(tool.name, "greet");
1572 assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1573
1574 let result = tool.call(serde_json::json!({"name": "World"})).await;
1575
1576 assert!(!result.is_error);
1577 }
1578
1579 #[tokio::test]
1580 async fn test_raw_handler() {
1581 let tool = ToolBuilder::new("echo")
1582 .description("Echo input")
1583 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1584 Ok(CallToolResult::json(args))
1585 })
1586 .build();
1587
1588 let result = tool.call(serde_json::json!({"foo": "bar"})).await;
1589
1590 assert!(!result.is_error);
1591 }
1592
1593 #[test]
1594 fn test_invalid_tool_name_empty() {
1595 let err = ToolBuilder::try_new("").err().expect("should fail");
1596 assert!(err.to_string().contains("cannot be empty"));
1597 }
1598
1599 #[test]
1600 fn test_invalid_tool_name_too_long() {
1601 let long_name = "a".repeat(129);
1602 let err = ToolBuilder::try_new(long_name).err().expect("should fail");
1603 assert!(err.to_string().contains("exceeds maximum"));
1604 }
1605
1606 #[test]
1607 fn test_invalid_tool_name_bad_chars() {
1608 let err = ToolBuilder::try_new("my tool!").err().expect("should fail");
1609 assert!(err.to_string().contains("invalid character"));
1610 }
1611
1612 #[test]
1613 #[should_panic(expected = "cannot be empty")]
1614 fn test_new_panics_on_empty_name() {
1615 ToolBuilder::new("");
1616 }
1617
1618 #[test]
1619 #[should_panic(expected = "exceeds maximum")]
1620 fn test_new_panics_on_too_long_name() {
1621 ToolBuilder::new("a".repeat(129));
1622 }
1623
1624 #[test]
1625 #[should_panic(expected = "invalid character")]
1626 fn test_new_panics_on_invalid_chars() {
1627 ToolBuilder::new("my tool!");
1628 }
1629
1630 #[test]
1631 fn test_valid_tool_names() {
1632 let names = [
1634 "my_tool",
1635 "my-tool",
1636 "my.tool",
1637 "MyTool123",
1638 "a",
1639 &"a".repeat(128),
1640 ];
1641 for name in names {
1642 assert!(
1643 ToolBuilder::try_new(name).is_ok(),
1644 "Expected '{}' to be valid",
1645 name
1646 );
1647 }
1648 }
1649
1650 #[tokio::test]
1651 async fn test_context_aware_handler() {
1652 use crate::context::notification_channel;
1653 use crate::protocol::{ProgressToken, RequestId};
1654
1655 #[derive(Debug, Deserialize, JsonSchema)]
1656 struct ProcessInput {
1657 count: i32,
1658 }
1659
1660 let tool = ToolBuilder::new("process")
1661 .description("Process with context")
1662 .extractor_handler(
1663 (),
1664 |ctx: Context, Json(input): Json<ProcessInput>| async move {
1665 for i in 0..input.count {
1667 if ctx.is_cancelled() {
1668 return Ok(CallToolResult::error("Cancelled"));
1669 }
1670 ctx.report_progress(i as f64, Some(input.count as f64), None)
1671 .await;
1672 }
1673 Ok(CallToolResult::text(format!(
1674 "Processed {} items",
1675 input.count
1676 )))
1677 },
1678 )
1679 .build();
1680
1681 assert_eq!(tool.name, "process");
1682
1683 let (tx, mut rx) = notification_channel(10);
1685 let ctx = RequestContext::new(RequestId::Number(1))
1686 .with_progress_token(ProgressToken::Number(42))
1687 .with_notification_sender(tx);
1688
1689 let result = tool
1690 .call_with_context(ctx, serde_json::json!({"count": 3}))
1691 .await;
1692
1693 assert!(!result.is_error);
1694
1695 let mut progress_count = 0;
1697 while rx.try_recv().is_ok() {
1698 progress_count += 1;
1699 }
1700 assert_eq!(progress_count, 3);
1701 }
1702
1703 #[tokio::test]
1704 async fn test_context_aware_handler_cancellation() {
1705 use crate::protocol::RequestId;
1706 use std::sync::atomic::{AtomicI32, Ordering};
1707
1708 #[derive(Debug, Deserialize, JsonSchema)]
1709 struct LongRunningInput {
1710 iterations: i32,
1711 }
1712
1713 let iterations_completed = Arc::new(AtomicI32::new(0));
1714 let iterations_ref = iterations_completed.clone();
1715
1716 let tool = ToolBuilder::new("long_running")
1717 .description("Long running task")
1718 .extractor_handler(
1719 (),
1720 move |ctx: Context, Json(input): Json<LongRunningInput>| {
1721 let completed = iterations_ref.clone();
1722 async move {
1723 for i in 0..input.iterations {
1724 if ctx.is_cancelled() {
1725 return Ok(CallToolResult::error("Cancelled"));
1726 }
1727 completed.fetch_add(1, Ordering::SeqCst);
1728 tokio::task::yield_now().await;
1730 if i == 2 {
1732 ctx.cancellation_token().cancel();
1733 }
1734 }
1735 Ok(CallToolResult::text("Done"))
1736 }
1737 },
1738 )
1739 .build();
1740
1741 let ctx = RequestContext::new(RequestId::Number(1));
1742
1743 let result = tool
1744 .call_with_context(ctx, serde_json::json!({"iterations": 10}))
1745 .await;
1746
1747 assert!(result.is_error);
1750 assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
1751 }
1752
1753 #[tokio::test]
1754 async fn test_tool_builder_with_enhanced_fields() {
1755 let output_schema = serde_json::json!({
1756 "type": "object",
1757 "properties": {
1758 "greeting": {"type": "string"}
1759 }
1760 });
1761
1762 let tool = ToolBuilder::new("greet")
1763 .title("Greeting Tool")
1764 .description("Greet someone")
1765 .output_schema(output_schema.clone())
1766 .icon("https://example.com/icon.png")
1767 .icon_with_meta(
1768 "https://example.com/icon-large.png",
1769 Some("image/png".to_string()),
1770 Some(vec!["96x96".to_string()]),
1771 )
1772 .handler(|input: GreetInput| async move {
1773 Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1774 })
1775 .build();
1776
1777 assert_eq!(tool.name, "greet");
1778 assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
1779 assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1780 assert_eq!(tool.output_schema, Some(output_schema));
1781 assert!(tool.icons.is_some());
1782 assert_eq!(tool.icons.as_ref().unwrap().len(), 2);
1783
1784 let def = tool.definition();
1786 assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
1787 assert!(def.output_schema.is_some());
1788 assert!(def.icons.is_some());
1789 }
1790
1791 #[tokio::test]
1792 async fn test_handler_with_state() {
1793 let shared = Arc::new("shared-state".to_string());
1794
1795 let tool = ToolBuilder::new("stateful")
1796 .description("Uses shared state")
1797 .extractor_handler(
1798 shared,
1799 |State(state): State<Arc<String>>, Json(input): Json<GreetInput>| async move {
1800 Ok(CallToolResult::text(format!(
1801 "{}: Hello, {}!",
1802 state, input.name
1803 )))
1804 },
1805 )
1806 .build();
1807
1808 let result = tool.call(serde_json::json!({"name": "World"})).await;
1809 assert!(!result.is_error);
1810 }
1811
1812 #[tokio::test]
1813 async fn test_handler_with_state_and_context() {
1814 use crate::protocol::RequestId;
1815
1816 let shared = Arc::new(42_i32);
1817
1818 let tool =
1819 ToolBuilder::new("stateful_ctx")
1820 .description("Uses state and context")
1821 .extractor_handler(
1822 shared,
1823 |State(state): State<Arc<i32>>,
1824 _ctx: Context,
1825 Json(input): Json<GreetInput>| async move {
1826 Ok(CallToolResult::text(format!(
1827 "{}: Hello, {}!",
1828 state, input.name
1829 )))
1830 },
1831 )
1832 .build();
1833
1834 let ctx = RequestContext::new(RequestId::Number(1));
1835 let result = tool
1836 .call_with_context(ctx, serde_json::json!({"name": "World"}))
1837 .await;
1838 assert!(!result.is_error);
1839 }
1840
1841 #[tokio::test]
1842 async fn test_handler_no_params() {
1843 let tool = ToolBuilder::new("no_params")
1844 .description("Takes no parameters")
1845 .extractor_handler((), |Json(_): Json<NoParams>| async {
1846 Ok(CallToolResult::text("no params result"))
1847 })
1848 .build();
1849
1850 assert_eq!(tool.name, "no_params");
1851
1852 let result = tool.call(serde_json::json!({})).await;
1854 assert!(!result.is_error);
1855
1856 let result = tool.call(serde_json::json!({"unexpected": "value"})).await;
1858 assert!(!result.is_error);
1859
1860 let schema = tool.definition().input_schema;
1862 assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1863 }
1864
1865 #[tokio::test]
1866 async fn test_handler_with_state_no_params() {
1867 let shared = Arc::new("shared_value".to_string());
1868
1869 let tool = ToolBuilder::new("with_state_no_params")
1870 .description("Takes no parameters but has state")
1871 .extractor_handler(
1872 shared,
1873 |State(state): State<Arc<String>>, Json(_): Json<NoParams>| async move {
1874 Ok(CallToolResult::text(format!("state: {}", state)))
1875 },
1876 )
1877 .build();
1878
1879 assert_eq!(tool.name, "with_state_no_params");
1880
1881 let result = tool.call(serde_json::json!({})).await;
1883 assert!(!result.is_error);
1884 assert_eq!(result.first_text().unwrap(), "state: shared_value");
1885
1886 let schema = tool.definition().input_schema;
1888 assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1889 }
1890
1891 #[tokio::test]
1892 async fn test_handler_no_params_with_context() {
1893 let tool = ToolBuilder::new("no_params_with_context")
1894 .description("Takes no parameters but has context")
1895 .extractor_handler((), |_ctx: Context, Json(_): Json<NoParams>| async move {
1896 Ok(CallToolResult::text("context available"))
1897 })
1898 .build();
1899
1900 assert_eq!(tool.name, "no_params_with_context");
1901
1902 let result = tool.call(serde_json::json!({})).await;
1903 assert!(!result.is_error);
1904 assert_eq!(result.first_text().unwrap(), "context available");
1905 }
1906
1907 #[tokio::test]
1908 async fn test_handler_with_state_and_context_no_params() {
1909 let shared = Arc::new("shared".to_string());
1910
1911 let tool = ToolBuilder::new("state_context_no_params")
1912 .description("Has state and context, no params")
1913 .extractor_handler(
1914 shared,
1915 |State(state): State<Arc<String>>,
1916 _ctx: Context,
1917 Json(_): Json<NoParams>| async move {
1918 Ok(CallToolResult::text(format!("state: {}", state)))
1919 },
1920 )
1921 .build();
1922
1923 assert_eq!(tool.name, "state_context_no_params");
1924
1925 let result = tool.call(serde_json::json!({})).await;
1926 assert!(!result.is_error);
1927 assert_eq!(result.first_text().unwrap(), "state: shared");
1928 }
1929
1930 #[tokio::test]
1931 async fn test_raw_handler_with_state() {
1932 let prefix = Arc::new("prefix:".to_string());
1933
1934 let tool = ToolBuilder::new("raw_with_state")
1935 .description("Raw handler with state")
1936 .extractor_handler(
1937 prefix,
1938 |State(state): State<Arc<String>>, RawArgs(args): RawArgs| async move {
1939 Ok(CallToolResult::text(format!("{} {}", state, args)))
1940 },
1941 )
1942 .build();
1943
1944 assert_eq!(tool.name, "raw_with_state");
1945
1946 let result = tool.call(serde_json::json!({"key": "value"})).await;
1947 assert!(!result.is_error);
1948 assert!(result.first_text().unwrap().starts_with("prefix:"));
1949 }
1950
1951 #[tokio::test]
1952 async fn test_raw_handler_with_state_and_context() {
1953 let prefix = Arc::new("prefix:".to_string());
1954
1955 let tool = ToolBuilder::new("raw_state_context")
1956 .description("Raw handler with state and context")
1957 .extractor_handler(
1958 prefix,
1959 |State(state): State<Arc<String>>,
1960 _ctx: Context,
1961 RawArgs(args): RawArgs| async move {
1962 Ok(CallToolResult::text(format!("{} {}", state, args)))
1963 },
1964 )
1965 .build();
1966
1967 assert_eq!(tool.name, "raw_state_context");
1968
1969 let result = tool.call(serde_json::json!({"key": "value"})).await;
1970 assert!(!result.is_error);
1971 assert!(result.first_text().unwrap().starts_with("prefix:"));
1972 }
1973
1974 #[tokio::test]
1975 async fn test_tool_with_timeout_layer() {
1976 use std::time::Duration;
1977 use tower::timeout::TimeoutLayer;
1978
1979 #[derive(Debug, Deserialize, JsonSchema)]
1980 struct SlowInput {
1981 delay_ms: u64,
1982 }
1983
1984 let tool = ToolBuilder::new("slow_tool")
1986 .description("A slow tool")
1987 .handler(|input: SlowInput| async move {
1988 tokio::time::sleep(Duration::from_millis(input.delay_ms)).await;
1989 Ok(CallToolResult::text("completed"))
1990 })
1991 .layer(TimeoutLayer::new(Duration::from_millis(50)))
1992 .build();
1993
1994 let result = tool.call(serde_json::json!({"delay_ms": 10})).await;
1996 assert!(!result.is_error);
1997 assert_eq!(result.first_text().unwrap(), "completed");
1998
1999 let result = tool.call(serde_json::json!({"delay_ms": 200})).await;
2001 assert!(result.is_error);
2002 let msg = result.first_text().unwrap().to_lowercase();
2004 assert!(
2005 msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
2006 "Expected timeout error, got: {}",
2007 msg
2008 );
2009 }
2010
2011 #[tokio::test]
2012 async fn test_tool_with_concurrency_limit_layer() {
2013 use std::sync::atomic::{AtomicU32, Ordering};
2014 use std::time::Duration;
2015 use tower::limit::ConcurrencyLimitLayer;
2016
2017 #[derive(Debug, Deserialize, JsonSchema)]
2018 struct WorkInput {
2019 id: u32,
2020 }
2021
2022 let max_concurrent = Arc::new(AtomicU32::new(0));
2023 let current_concurrent = Arc::new(AtomicU32::new(0));
2024 let max_ref = max_concurrent.clone();
2025 let current_ref = current_concurrent.clone();
2026
2027 let tool = ToolBuilder::new("concurrent_tool")
2029 .description("A concurrent tool")
2030 .handler(move |input: WorkInput| {
2031 let max = max_ref.clone();
2032 let current = current_ref.clone();
2033 async move {
2034 let prev = current.fetch_add(1, Ordering::SeqCst);
2036 max.fetch_max(prev + 1, Ordering::SeqCst);
2037
2038 tokio::time::sleep(Duration::from_millis(50)).await;
2040
2041 current.fetch_sub(1, Ordering::SeqCst);
2042 Ok(CallToolResult::text(format!("completed {}", input.id)))
2043 }
2044 })
2045 .layer(ConcurrencyLimitLayer::new(2))
2046 .build();
2047
2048 let handles: Vec<_> = (0..4)
2050 .map(|i| {
2051 let t = tool.call(serde_json::json!({"id": i}));
2052 tokio::spawn(t)
2053 })
2054 .collect();
2055
2056 for handle in handles {
2057 let result = handle.await.unwrap();
2058 assert!(!result.is_error);
2059 }
2060
2061 assert!(max_concurrent.load(Ordering::SeqCst) <= 2);
2063 }
2064
2065 #[tokio::test]
2066 async fn test_tool_with_multiple_layers() {
2067 use std::time::Duration;
2068 use tower::limit::ConcurrencyLimitLayer;
2069 use tower::timeout::TimeoutLayer;
2070
2071 #[derive(Debug, Deserialize, JsonSchema)]
2072 struct Input {
2073 value: String,
2074 }
2075
2076 let tool = ToolBuilder::new("multi_layer_tool")
2078 .description("Tool with multiple layers")
2079 .handler(|input: Input| async move {
2080 Ok(CallToolResult::text(format!("processed: {}", input.value)))
2081 })
2082 .layer(TimeoutLayer::new(Duration::from_secs(5)))
2083 .layer(ConcurrencyLimitLayer::new(10))
2084 .build();
2085
2086 let result = tool.call(serde_json::json!({"value": "test"})).await;
2087 assert!(!result.is_error);
2088 assert_eq!(result.first_text().unwrap(), "processed: test");
2089 }
2090
2091 #[test]
2092 fn test_tool_catch_error_clone() {
2093 let tool = ToolBuilder::new("test")
2096 .description("test")
2097 .extractor_handler((), |RawArgs(_args): RawArgs| async {
2098 Ok(CallToolResult::text("ok"))
2099 })
2100 .build();
2101 let _clone = tool.call(serde_json::json!({}));
2103 }
2104
2105 #[test]
2106 fn test_tool_catch_error_debug() {
2107 #[derive(Debug, Clone)]
2111 struct DebugService;
2112
2113 impl Service<ToolRequest> for DebugService {
2114 type Response = CallToolResult;
2115 type Error = crate::error::Error;
2116 type Future = Pin<
2117 Box<
2118 dyn Future<Output = std::result::Result<CallToolResult, crate::error::Error>>
2119 + Send,
2120 >,
2121 >;
2122
2123 fn poll_ready(
2124 &mut self,
2125 _cx: &mut std::task::Context<'_>,
2126 ) -> Poll<std::result::Result<(), Self::Error>> {
2127 Poll::Ready(Ok(()))
2128 }
2129
2130 fn call(&mut self, _req: ToolRequest) -> Self::Future {
2131 Box::pin(async { Ok(CallToolResult::text("ok")) })
2132 }
2133 }
2134
2135 let catch_error = ToolCatchError::new(DebugService);
2136 let debug = format!("{:?}", catch_error);
2137 assert!(debug.contains("ToolCatchError"));
2138 }
2139
2140 #[test]
2141 fn test_tool_request_new() {
2142 use crate::protocol::RequestId;
2143
2144 let ctx = RequestContext::new(RequestId::Number(42));
2145 let args = serde_json::json!({"key": "value"});
2146 let req = ToolRequest::new(ctx.clone(), args.clone());
2147
2148 assert_eq!(req.args, args);
2149 }
2150
2151 #[test]
2152 fn test_no_params_schema() {
2153 let schema = schemars::schema_for!(NoParams);
2155 let schema_value = serde_json::to_value(&schema).unwrap();
2156 assert_eq!(
2157 schema_value.get("type").and_then(|v| v.as_str()),
2158 Some("object"),
2159 "NoParams should generate type: object schema"
2160 );
2161 }
2162
2163 #[test]
2164 fn test_no_params_deserialize() {
2165 let from_empty_object: NoParams = serde_json::from_str("{}").unwrap();
2167 assert_eq!(from_empty_object, NoParams);
2168
2169 let from_null: NoParams = serde_json::from_str("null").unwrap();
2170 assert_eq!(from_null, NoParams);
2171
2172 let from_object_with_fields: NoParams =
2174 serde_json::from_str(r#"{"unexpected": "value"}"#).unwrap();
2175 assert_eq!(from_object_with_fields, NoParams);
2176 }
2177
2178 #[tokio::test]
2179 async fn test_no_params_type_in_handler() {
2180 let tool = ToolBuilder::new("status")
2182 .description("Get status")
2183 .handler(|_input: NoParams| async move { Ok(CallToolResult::text("OK")) })
2184 .build();
2185
2186 let schema = tool.definition().input_schema;
2188 assert_eq!(
2189 schema.get("type").and_then(|v| v.as_str()),
2190 Some("object"),
2191 "NoParams handler should produce type: object schema"
2192 );
2193
2194 let result = tool.call(serde_json::json!({})).await;
2196 assert!(!result.is_error);
2197 }
2198
2199 #[tokio::test]
2200 async fn test_tool_with_name_prefix() {
2201 #[derive(Debug, Deserialize, JsonSchema)]
2202 struct Input {
2203 value: String,
2204 }
2205
2206 let tool = ToolBuilder::new("query")
2207 .description("Query something")
2208 .title("Query Tool")
2209 .handler(|input: Input| async move { Ok(CallToolResult::text(&input.value)) })
2210 .build();
2211
2212 let prefixed = tool.with_name_prefix("db");
2214
2215 assert_eq!(prefixed.name, "db.query");
2217
2218 assert_eq!(prefixed.description.as_deref(), Some("Query something"));
2220 assert_eq!(prefixed.title.as_deref(), Some("Query Tool"));
2221
2222 let result = prefixed
2224 .call(serde_json::json!({"value": "test input"}))
2225 .await;
2226 assert!(!result.is_error);
2227 match &result.content[0] {
2228 Content::Text { text, .. } => assert_eq!(text, "test input"),
2229 _ => panic!("Expected text content"),
2230 }
2231 }
2232
2233 #[tokio::test]
2234 async fn test_tool_with_name_prefix_multiple_levels() {
2235 let tool = ToolBuilder::new("action")
2236 .description("Do something")
2237 .handler(|_: NoParams| async move { Ok(CallToolResult::text("done")) })
2238 .build();
2239
2240 let prefixed = tool.with_name_prefix("level1");
2242 assert_eq!(prefixed.name, "level1.action");
2243
2244 let double_prefixed = prefixed.with_name_prefix("level0");
2245 assert_eq!(double_prefixed.name, "level0.level1.action");
2246 }
2247
2248 #[tokio::test]
2253 async fn test_no_params_handler_basic() {
2254 let tool = ToolBuilder::new("get_status")
2255 .description("Get current status")
2256 .no_params_handler(|| async { Ok(CallToolResult::text("OK")) })
2257 .build();
2258
2259 assert_eq!(tool.name, "get_status");
2260 assert_eq!(tool.description.as_deref(), Some("Get current status"));
2261
2262 let result = tool.call(serde_json::json!({})).await;
2264 assert!(!result.is_error);
2265 assert_eq!(result.first_text().unwrap(), "OK");
2266
2267 let result = tool.call(serde_json::json!(null)).await;
2269 assert!(!result.is_error);
2270
2271 let schema = tool.definition().input_schema;
2273 assert_eq!(schema.get("type").and_then(|v| v.as_str()), Some("object"));
2274 }
2275
2276 #[tokio::test]
2277 async fn test_no_params_handler_with_captured_state() {
2278 let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
2279 let counter_ref = counter.clone();
2280
2281 let tool = ToolBuilder::new("increment")
2282 .description("Increment counter")
2283 .no_params_handler(move || {
2284 let c = counter_ref.clone();
2285 async move {
2286 let prev = c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2287 Ok(CallToolResult::text(format!("Incremented from {}", prev)))
2288 }
2289 })
2290 .build();
2291
2292 let _ = tool.call(serde_json::json!({})).await;
2294 let _ = tool.call(serde_json::json!({})).await;
2295 let result = tool.call(serde_json::json!({})).await;
2296
2297 assert!(!result.is_error);
2298 assert_eq!(result.first_text().unwrap(), "Incremented from 2");
2299 assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 3);
2300 }
2301
2302 #[tokio::test]
2303 async fn test_no_params_handler_with_layer() {
2304 use std::time::Duration;
2305 use tower::timeout::TimeoutLayer;
2306
2307 let tool = ToolBuilder::new("slow_status")
2308 .description("Slow status check")
2309 .no_params_handler(|| async {
2310 tokio::time::sleep(Duration::from_millis(10)).await;
2311 Ok(CallToolResult::text("done"))
2312 })
2313 .layer(TimeoutLayer::new(Duration::from_secs(1)))
2314 .build();
2315
2316 let result = tool.call(serde_json::json!({})).await;
2317 assert!(!result.is_error);
2318 assert_eq!(result.first_text().unwrap(), "done");
2319 }
2320
2321 #[tokio::test]
2322 async fn test_no_params_handler_timeout() {
2323 use std::time::Duration;
2324 use tower::timeout::TimeoutLayer;
2325
2326 let tool = ToolBuilder::new("very_slow_status")
2327 .description("Very slow status check")
2328 .no_params_handler(|| async {
2329 tokio::time::sleep(Duration::from_millis(200)).await;
2330 Ok(CallToolResult::text("done"))
2331 })
2332 .layer(TimeoutLayer::new(Duration::from_millis(50)))
2333 .build();
2334
2335 let result = tool.call(serde_json::json!({})).await;
2336 assert!(result.is_error);
2337 let msg = result.first_text().unwrap().to_lowercase();
2338 assert!(
2339 msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
2340 "Expected timeout error, got: {}",
2341 msg
2342 );
2343 }
2344
2345 #[tokio::test]
2346 async fn test_no_params_handler_with_multiple_layers() {
2347 use std::time::Duration;
2348 use tower::limit::ConcurrencyLimitLayer;
2349 use tower::timeout::TimeoutLayer;
2350
2351 let tool = ToolBuilder::new("multi_layer_status")
2352 .description("Status with multiple layers")
2353 .no_params_handler(|| async { Ok(CallToolResult::text("status ok")) })
2354 .layer(TimeoutLayer::new(Duration::from_secs(5)))
2355 .layer(ConcurrencyLimitLayer::new(10))
2356 .build();
2357
2358 let result = tool.call(serde_json::json!({})).await;
2359 assert!(!result.is_error);
2360 assert_eq!(result.first_text().unwrap(), "status ok");
2361 }
2362
2363 #[tokio::test]
2368 async fn test_guard_allows_request() {
2369 #[derive(Debug, Deserialize, JsonSchema)]
2370 #[allow(dead_code)]
2371 struct DeleteInput {
2372 id: String,
2373 confirm: bool,
2374 }
2375
2376 let tool = ToolBuilder::new("delete")
2377 .description("Delete a record")
2378 .handler(|input: DeleteInput| async move {
2379 Ok(CallToolResult::text(format!("deleted {}", input.id)))
2380 })
2381 .guard(|req: &ToolRequest| {
2382 let confirm = req
2383 .args
2384 .get("confirm")
2385 .and_then(|v| v.as_bool())
2386 .unwrap_or(false);
2387 if !confirm {
2388 return Err("Must set confirm=true to delete".to_string());
2389 }
2390 Ok(())
2391 })
2392 .build();
2393
2394 let result = tool
2395 .call(serde_json::json!({"id": "abc", "confirm": true}))
2396 .await;
2397 assert!(!result.is_error);
2398 assert_eq!(result.first_text().unwrap(), "deleted abc");
2399 }
2400
2401 #[tokio::test]
2402 async fn test_guard_rejects_request() {
2403 #[derive(Debug, Deserialize, JsonSchema)]
2404 #[allow(dead_code)]
2405 struct DeleteInput2 {
2406 id: String,
2407 confirm: bool,
2408 }
2409
2410 let tool = ToolBuilder::new("delete2")
2411 .description("Delete a record")
2412 .handler(|input: DeleteInput2| async move {
2413 Ok(CallToolResult::text(format!("deleted {}", input.id)))
2414 })
2415 .guard(|req: &ToolRequest| {
2416 let confirm = req
2417 .args
2418 .get("confirm")
2419 .and_then(|v| v.as_bool())
2420 .unwrap_or(false);
2421 if !confirm {
2422 return Err("Must set confirm=true to delete".to_string());
2423 }
2424 Ok(())
2425 })
2426 .build();
2427
2428 let result = tool
2429 .call(serde_json::json!({"id": "abc", "confirm": false}))
2430 .await;
2431 assert!(result.is_error);
2432 assert!(
2433 result
2434 .first_text()
2435 .unwrap()
2436 .contains("Must set confirm=true")
2437 );
2438 }
2439
2440 #[tokio::test]
2441 async fn test_guard_with_layer() {
2442 use std::time::Duration;
2443 use tower::timeout::TimeoutLayer;
2444
2445 let tool = ToolBuilder::new("guarded_timeout")
2446 .description("Guarded with timeout")
2447 .handler(|input: GreetInput| async move {
2448 Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
2449 })
2450 .layer(TimeoutLayer::new(Duration::from_secs(5)))
2451 .guard(|_req: &ToolRequest| Ok(()))
2452 .build();
2453
2454 let result = tool.call(serde_json::json!({"name": "World"})).await;
2455 assert!(!result.is_error);
2456 assert_eq!(result.first_text().unwrap(), "Hello, World!");
2457 }
2458
2459 #[tokio::test]
2460 async fn test_guard_on_no_params_handler() {
2461 let allowed = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(true));
2462 let allowed_clone = allowed.clone();
2463
2464 let tool = ToolBuilder::new("status")
2465 .description("Get status")
2466 .no_params_handler(|| async { Ok(CallToolResult::text("ok")) })
2467 .guard(move |_req: &ToolRequest| {
2468 if allowed_clone.load(std::sync::atomic::Ordering::Relaxed) {
2469 Ok(())
2470 } else {
2471 Err("Access denied".to_string())
2472 }
2473 })
2474 .build();
2475
2476 let result = tool.call(serde_json::json!({})).await;
2478 assert!(!result.is_error);
2479 assert_eq!(result.first_text().unwrap(), "ok");
2480
2481 allowed.store(false, std::sync::atomic::Ordering::Relaxed);
2483 let result = tool.call(serde_json::json!({})).await;
2484 assert!(result.is_error);
2485 assert!(result.first_text().unwrap().contains("Access denied"));
2486 }
2487
2488 #[tokio::test]
2489 async fn test_guard_on_no_params_handler_with_layer() {
2490 use std::time::Duration;
2491 use tower::timeout::TimeoutLayer;
2492
2493 let tool = ToolBuilder::new("status_layered")
2494 .description("Get status with layers")
2495 .no_params_handler(|| async { Ok(CallToolResult::text("ok")) })
2496 .layer(TimeoutLayer::new(Duration::from_secs(5)))
2497 .guard(|_req: &ToolRequest| Ok(()))
2498 .build();
2499
2500 let result = tool.call(serde_json::json!({})).await;
2501 assert!(!result.is_error);
2502 assert_eq!(result.first_text().unwrap(), "ok");
2503 }
2504
2505 #[tokio::test]
2506 async fn test_guard_on_extractor_handler() {
2507 use std::sync::Arc;
2508
2509 #[derive(Clone)]
2510 struct AppState {
2511 prefix: String,
2512 }
2513
2514 #[derive(Debug, Deserialize, JsonSchema)]
2515 struct QueryInput {
2516 query: String,
2517 }
2518
2519 let state = Arc::new(AppState {
2520 prefix: "db".to_string(),
2521 });
2522
2523 let tool = ToolBuilder::new("search")
2524 .description("Search")
2525 .extractor_handler(
2526 state,
2527 |State(app): State<Arc<AppState>>, Json(input): Json<QueryInput>| async move {
2528 Ok(CallToolResult::text(format!(
2529 "{}: {}",
2530 app.prefix, input.query
2531 )))
2532 },
2533 )
2534 .guard(|req: &ToolRequest| {
2535 let query = req.args.get("query").and_then(|v| v.as_str()).unwrap_or("");
2536 if query.is_empty() {
2537 return Err("Query cannot be empty".to_string());
2538 }
2539 Ok(())
2540 })
2541 .build();
2542
2543 let result = tool.call(serde_json::json!({"query": "hello"})).await;
2545 assert!(!result.is_error);
2546 assert_eq!(result.first_text().unwrap(), "db: hello");
2547
2548 let result = tool.call(serde_json::json!({"query": ""})).await;
2550 assert!(result.is_error);
2551 assert!(
2552 result
2553 .first_text()
2554 .unwrap()
2555 .contains("Query cannot be empty")
2556 );
2557 }
2558
2559 #[tokio::test]
2560 async fn test_guard_on_extractor_handler_with_layer() {
2561 use std::sync::Arc;
2562 use std::time::Duration;
2563 use tower::timeout::TimeoutLayer;
2564
2565 #[derive(Clone)]
2566 struct AppState2 {
2567 prefix: String,
2568 }
2569
2570 #[derive(Debug, Deserialize, JsonSchema)]
2571 struct QueryInput2 {
2572 query: String,
2573 }
2574
2575 let state = Arc::new(AppState2 {
2576 prefix: "db".to_string(),
2577 });
2578
2579 let tool = ToolBuilder::new("search2")
2580 .description("Search with layer and guard")
2581 .extractor_handler(
2582 state,
2583 |State(app): State<Arc<AppState2>>, Json(input): Json<QueryInput2>| async move {
2584 Ok(CallToolResult::text(format!(
2585 "{}: {}",
2586 app.prefix, input.query
2587 )))
2588 },
2589 )
2590 .layer(TimeoutLayer::new(Duration::from_secs(5)))
2591 .guard(|_req: &ToolRequest| Ok(()))
2592 .build();
2593
2594 let result = tool.call(serde_json::json!({"query": "hello"})).await;
2595 assert!(!result.is_error);
2596 assert_eq!(result.first_text().unwrap(), "db: hello");
2597 }
2598
2599 #[tokio::test]
2600 async fn test_tool_with_guard_post_build() {
2601 let tool = ToolBuilder::new("admin_action")
2602 .description("Admin action")
2603 .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("done")) })
2604 .build();
2605
2606 let guarded = tool.with_guard(|req: &ToolRequest| {
2608 let name = req.args.get("name").and_then(|v| v.as_str()).unwrap_or("");
2609 if name == "admin" {
2610 Ok(())
2611 } else {
2612 Err("Only admin allowed".to_string())
2613 }
2614 });
2615
2616 let result = guarded.call(serde_json::json!({"name": "admin"})).await;
2618 assert!(!result.is_error);
2619
2620 let result = guarded.call(serde_json::json!({"name": "user"})).await;
2622 assert!(result.is_error);
2623 assert!(result.first_text().unwrap().contains("Only admin allowed"));
2624 }
2625
2626 #[tokio::test]
2627 async fn test_with_guard_preserves_tool_metadata() {
2628 let tool = ToolBuilder::new("my_tool")
2629 .description("A tool")
2630 .title("My Tool")
2631 .read_only()
2632 .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("done")) })
2633 .build();
2634
2635 let guarded = tool.with_guard(|_req: &ToolRequest| Ok(()));
2636
2637 assert_eq!(guarded.name, "my_tool");
2638 assert_eq!(guarded.description.as_deref(), Some("A tool"));
2639 assert_eq!(guarded.title.as_deref(), Some("My Tool"));
2640 assert!(guarded.annotations.is_some());
2641 }
2642
2643 #[tokio::test]
2644 async fn test_guard_group_pattern() {
2645 let require_auth = |req: &ToolRequest| {
2647 let token = req
2648 .args
2649 .get("_token")
2650 .and_then(|v| v.as_str())
2651 .unwrap_or("");
2652 if token == "valid" {
2653 Ok(())
2654 } else {
2655 Err("Authentication required".to_string())
2656 }
2657 };
2658
2659 let tool1 = ToolBuilder::new("action1")
2660 .description("Action 1")
2661 .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("action1")) })
2662 .build();
2663 let tool2 = ToolBuilder::new("action2")
2664 .description("Action 2")
2665 .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("action2")) })
2666 .build();
2667
2668 let guarded1 = tool1.with_guard(require_auth);
2670 let guarded2 = tool2.with_guard(require_auth);
2671
2672 let r1 = guarded1
2674 .call(serde_json::json!({"name": "test", "_token": "invalid"}))
2675 .await;
2676 let r2 = guarded2
2677 .call(serde_json::json!({"name": "test", "_token": "invalid"}))
2678 .await;
2679 assert!(r1.is_error);
2680 assert!(r2.is_error);
2681
2682 let r1 = guarded1
2684 .call(serde_json::json!({"name": "test", "_token": "valid"}))
2685 .await;
2686 let r2 = guarded2
2687 .call(serde_json::json!({"name": "test", "_token": "valid"}))
2688 .await;
2689 assert!(!r1.is_error);
2690 assert!(!r2.is_error);
2691 }
2692}