1use std::borrow::Cow;
34use std::convert::Infallible;
35use std::fmt;
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::task::{Context, Poll};
40
41use schemars::{JsonSchema, Schema, SchemaGenerator};
42use serde::Serialize;
43use serde::de::DeserializeOwned;
44use serde_json::Value;
45use tower::util::BoxCloneService;
46use tower_service::Service;
47
48use crate::context::RequestContext;
49use crate::error::{Error, Result, ResultExt};
50use crate::protocol::{
51 CallToolResult, TaskSupportMode, ToolAnnotations, ToolDefinition, ToolExecution, ToolIcon,
52};
53
54#[derive(Debug, Clone)]
63pub struct ToolRequest {
64 pub ctx: RequestContext,
66 pub args: Value,
68}
69
70impl ToolRequest {
71 pub fn new(ctx: RequestContext, args: Value) -> Self {
73 Self { ctx, args }
74 }
75}
76
/// Type-erased, cloneable tower service that executes a tool.
///
/// The error type is [`Infallible`]: tools built in this module wrap their
/// service in [`ToolCatchError`], so failures are carried inside the returned
/// [`CallToolResult`] instead of surfacing as a service error.
pub type BoxToolService = BoxCloneService<ToolRequest, CallToolResult, Infallible>;
83
84pub struct ToolCatchError<S> {
90 inner: S,
91}
92
93impl<S> ToolCatchError<S> {
94 pub fn new(inner: S) -> Self {
96 Self { inner }
97 }
98}
99
100impl<S: Clone> Clone for ToolCatchError<S> {
101 fn clone(&self) -> Self {
102 Self {
103 inner: self.inner.clone(),
104 }
105 }
106}
107
108impl<S: fmt::Debug> fmt::Debug for ToolCatchError<S> {
109 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110 f.debug_struct("ToolCatchError")
111 .field("inner", &self.inner)
112 .finish()
113 }
114}
115
116impl<S> Service<ToolRequest> for ToolCatchError<S>
117where
118 S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
119 S::Error: fmt::Display + Send,
120 S::Future: Send,
121{
122 type Response = CallToolResult;
123 type Error = Infallible;
124 type Future =
125 Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Infallible>> + Send>>;
126
127 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
128 match self.inner.poll_ready(cx) {
130 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
131 Poll::Ready(Err(_)) => Poll::Ready(Ok(())),
132 Poll::Pending => Poll::Pending,
133 }
134 }
135
136 fn call(&mut self, req: ToolRequest) -> Self::Future {
137 let fut = self.inner.call(req);
138
139 Box::pin(async move {
140 match fut.await {
141 Ok(result) => Ok(result),
142 Err(err) => Ok(CallToolResult::error(err.to_string())),
143 }
144 })
145 }
146}
147
148#[derive(Clone)]
179pub struct GuardLayer<G> {
180 guard: G,
181}
182
183impl<G> GuardLayer<G> {
184 pub fn new(guard: G) -> Self {
189 Self { guard }
190 }
191}
192
193impl<G, S> tower::Layer<S> for GuardLayer<G>
194where
195 G: Clone,
196{
197 type Service = GuardService<G, S>;
198
199 fn layer(&self, inner: S) -> Self::Service {
200 GuardService {
201 guard: self.guard.clone(),
202 inner,
203 }
204 }
205}
206
207#[derive(Clone)]
211pub struct GuardService<G, S> {
212 guard: G,
213 inner: S,
214}
215
216impl<G, S> Service<ToolRequest> for GuardService<G, S>
217where
218 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
219 S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
220 S::Error: Into<Error> + Send,
221 S::Future: Send,
222{
223 type Response = CallToolResult;
224 type Error = Error;
225 type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
226
227 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
228 self.inner.poll_ready(cx).map_err(Into::into)
229 }
230
231 fn call(&mut self, req: ToolRequest) -> Self::Future {
232 match (self.guard)(&req) {
233 Ok(()) => {
234 let fut = self.inner.call(req);
235 Box::pin(async move { fut.await.map_err(Into::into) })
236 }
237 Err(msg) => Box::pin(async move { Err(Error::tool(msg)) }),
238 }
239 }
240}
241
242#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
262pub struct NoParams;
263
264impl<'de> serde::Deserialize<'de> for NoParams {
265 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
266 where
267 D: serde::Deserializer<'de>,
268 {
269 struct NoParamsVisitor;
271
272 impl<'de> serde::de::Visitor<'de> for NoParamsVisitor {
273 type Value = NoParams;
274
275 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
276 formatter.write_str("null or an object")
277 }
278
279 fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
280 where
281 E: serde::de::Error,
282 {
283 Ok(NoParams)
284 }
285
286 fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
287 where
288 E: serde::de::Error,
289 {
290 Ok(NoParams)
291 }
292
293 fn visit_some<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
294 where
295 D: serde::Deserializer<'de>,
296 {
297 serde::Deserialize::deserialize(deserializer)
298 }
299
300 fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
301 where
302 A: serde::de::MapAccess<'de>,
303 {
304 while map
306 .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
307 .is_some()
308 {}
309 Ok(NoParams)
310 }
311 }
312
313 deserializer.deserialize_any(NoParamsVisitor)
314 }
315}
316
317impl JsonSchema for NoParams {
318 fn schema_name() -> Cow<'static, str> {
319 Cow::Borrowed("NoParams")
320 }
321
322 fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
323 serde_json::json!({
324 "type": "object"
325 })
326 .try_into()
327 .expect("valid schema")
328 }
329}
330
331pub fn validate_tool_name(name: &str) -> Result<()> {
339 if name.is_empty() {
340 return Err(Error::tool("Tool name cannot be empty"));
341 }
342 if name.len() > 128 {
343 return Err(Error::tool(format!(
344 "Tool name '{}' exceeds maximum length of 128 characters (got {})",
345 name,
346 name.len()
347 )));
348 }
349 if let Some(invalid_char) = name
350 .chars()
351 .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
352 {
353 return Err(Error::tool(format!(
354 "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
355 name, invalid_char
356 )));
357 }
358 Ok(())
359}
360
/// Heap-allocated, pinned, `Send` future borrowing data for `'a` and
/// resolving to `T`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
363
/// Handler behind a [`Tool`]: receives raw JSON arguments and produces a
/// [`CallToolResult`].
///
/// Implementors must be `Send + Sync`; the returned futures borrow `self`.
pub trait ToolHandler: Send + Sync {
    /// Handle a call with the given raw arguments.
    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;

    /// Handle a call with access to the request context.
    ///
    /// The default implementation ignores `_ctx` and forwards to
    /// [`ToolHandler::call`]. `ToolHandlerService` in this module always
    /// dispatches through this method.
    fn call_with_context(
        &self,
        _ctx: RequestContext,
        args: Value,
    ) -> BoxFuture<'_, Result<CallToolResult>> {
        self.call(args)
    }

    /// Whether this handler uses the request context. Defaults to `false`.
    ///
    /// NOTE(review): not read anywhere in this chunk — confirm its consumers.
    fn uses_context(&self) -> bool {
        false
    }

    /// JSON Schema describing the arguments accepted by [`ToolHandler::call`].
    fn input_schema(&self) -> Value;
}
389
390pub(crate) struct ToolHandlerService<H> {
395 handler: Arc<H>,
396}
397
398impl<H> ToolHandlerService<H> {
399 pub(crate) fn new(handler: H) -> Self {
400 Self {
401 handler: Arc::new(handler),
402 }
403 }
404}
405
406impl<H> Clone for ToolHandlerService<H> {
407 fn clone(&self) -> Self {
408 Self {
409 handler: self.handler.clone(),
410 }
411 }
412}
413
414impl<H> Service<ToolRequest> for ToolHandlerService<H>
415where
416 H: ToolHandler + 'static,
417{
418 type Response = CallToolResult;
419 type Error = Error;
420 type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
421
422 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
423 Poll::Ready(Ok(()))
424 }
425
426 fn call(&mut self, req: ToolRequest) -> Self::Future {
427 let handler = self.handler.clone();
428 Box::pin(async move { handler.call_with_context(req.ctx, req.args).await })
429 }
430}
431
/// A tool: its advertised metadata plus the boxed service that executes it.
pub struct Tool {
    /// Tool name. The builders in this module validate it with
    /// `validate_tool_name`; `with_name_prefix` and direct writes to this
    /// public field do not.
    pub name: String,
    /// Optional human-readable title.
    pub title: Option<String>,
    /// Optional description.
    pub description: Option<String>,
    /// Optional JSON Schema describing the tool's output.
    pub output_schema: Option<Value>,
    /// Optional icons.
    pub icons: Option<Vec<ToolIcon>>,
    /// Optional behavior hints (read-only, destructive, idempotent, ...).
    pub annotations: Option<ToolAnnotations>,
    /// Task support mode; `Forbidden` omits the `execution` entry from the
    /// definition produced by `Tool::definition`.
    pub task_support: TaskSupportMode,
    /// Infallible service that executes calls; errors are carried in the result.
    pub(crate) service: BoxToolService,
    /// JSON Schema describing the tool's arguments.
    pub(crate) input_schema: Value,
}
458
459impl std::fmt::Debug for Tool {
460 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
461 f.debug_struct("Tool")
462 .field("name", &self.name)
463 .field("title", &self.title)
464 .field("description", &self.description)
465 .field("output_schema", &self.output_schema)
466 .field("icons", &self.icons)
467 .field("annotations", &self.annotations)
468 .field("task_support", &self.task_support)
469 .finish_non_exhaustive()
470 }
471}
472
// SAFETY: NOTE(review): these impls are NOT checked by the compiler.
// - `service` is a tower `BoxCloneService`, which is `Send` but not `Sync`.
//   The only access to it through `&Tool` in this file is
//   `self.service.clone()`, so `Sync` relies on the boxed service's `Clone`
//   being safe to run concurrently from several threads through a shared
//   reference. The type system does not enforce that; confirm for every
//   service stack that ends up inside a `Tool` (or switch to a `Sync` box).
// - The remaining fields are owned metadata (`String`, `Option`,
//   `serde_json::Value`, `ToolIcon`, `ToolAnnotations`, `TaskSupportMode`).
//   They are presumably `Send + Sync` — confirm. If so, the `Send` impl is
//   redundant.
unsafe impl Send for Tool {}
unsafe impl Sync for Tool {}
477
478impl Clone for Tool {
479 fn clone(&self) -> Self {
480 Self {
481 name: self.name.clone(),
482 title: self.title.clone(),
483 description: self.description.clone(),
484 output_schema: self.output_schema.clone(),
485 icons: self.icons.clone(),
486 annotations: self.annotations.clone(),
487 task_support: self.task_support,
488 service: self.service.clone(),
489 input_schema: self.input_schema.clone(),
490 }
491 }
492}
493
494impl Tool {
495 pub fn builder(name: impl Into<String>) -> ToolBuilder {
497 ToolBuilder::new(name)
498 }
499
500 pub fn definition(&self) -> ToolDefinition {
502 let execution = match self.task_support {
503 TaskSupportMode::Forbidden => None,
504 mode => Some(ToolExecution {
505 task_support: Some(mode),
506 }),
507 };
508 ToolDefinition {
509 name: self.name.clone(),
510 title: self.title.clone(),
511 description: self.description.clone(),
512 input_schema: self.input_schema.clone(),
513 output_schema: self.output_schema.clone(),
514 icons: self.icons.clone(),
515 annotations: self.annotations.clone(),
516 execution,
517 meta: None,
518 }
519 }
520
521 pub fn call(&self, args: Value) -> BoxFuture<'static, CallToolResult> {
526 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
527 self.call_with_context(ctx, args)
528 }
529
530 pub fn call_with_context(
541 &self,
542 ctx: RequestContext,
543 args: Value,
544 ) -> BoxFuture<'static, CallToolResult> {
545 use tower::ServiceExt;
546 let service = self.service.clone();
547 Box::pin(async move {
548 service.oneshot(ToolRequest::new(ctx, args)).await.unwrap()
551 })
552 }
553
554 pub fn with_guard<G>(self, guard: G) -> Self
583 where
584 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
585 {
586 let guarded = GuardService {
587 guard,
588 inner: self.service,
589 };
590 let caught = ToolCatchError::new(guarded);
591 Tool {
592 service: BoxCloneService::new(caught),
593 ..self
594 }
595 }
596
597 pub fn with_name_prefix(&self, prefix: &str) -> Self {
624 Self {
625 name: format!("{}.{}", prefix, self.name),
626 title: self.title.clone(),
627 description: self.description.clone(),
628 output_schema: self.output_schema.clone(),
629 icons: self.icons.clone(),
630 annotations: self.annotations.clone(),
631 task_support: self.task_support,
632 service: self.service.clone(),
633 input_schema: self.input_schema.clone(),
634 }
635 }
636
637 #[allow(clippy::too_many_arguments)]
639 fn from_handler<H: ToolHandler + 'static>(
640 name: String,
641 title: Option<String>,
642 description: Option<String>,
643 output_schema: Option<Value>,
644 icons: Option<Vec<ToolIcon>>,
645 annotations: Option<ToolAnnotations>,
646 task_support: TaskSupportMode,
647 handler: H,
648 ) -> Self {
649 let input_schema = handler.input_schema();
650 let handler_service = ToolHandlerService::new(handler);
651 let catch_error = ToolCatchError::new(handler_service);
652 let service = BoxCloneService::new(catch_error);
653
654 Self {
655 name,
656 title,
657 description,
658 output_schema,
659 icons,
660 annotations,
661 task_support,
662 service,
663 input_schema,
664 }
665 }
666}
667
/// Builder for [`Tool`]: collects metadata before a handler is attached.
///
/// Attaching a handler (`handler`, `no_params_handler`, `extractor_handler`,
/// `extractor_handler_typed`) moves the collected metadata into a
/// handler-specific builder.
pub struct ToolBuilder {
    // Validated tool name.
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    task_support: TaskSupportMode,
}
704
705impl ToolBuilder {
706 pub fn new(name: impl Into<String>) -> Self {
718 let name = name.into();
719 if let Err(e) = validate_tool_name(&name) {
720 panic!("{e}");
721 }
722 Self {
723 name,
724 title: None,
725 description: None,
726 output_schema: None,
727 icons: None,
728 annotations: None,
729 task_support: TaskSupportMode::default(),
730 }
731 }
732
733 pub fn try_new(name: impl Into<String>) -> Result<Self> {
739 let name = name.into();
740 validate_tool_name(&name)?;
741 Ok(Self {
742 name,
743 title: None,
744 description: None,
745 output_schema: None,
746 icons: None,
747 annotations: None,
748 task_support: TaskSupportMode::default(),
749 })
750 }
751
752 pub fn title(mut self, title: impl Into<String>) -> Self {
754 self.title = Some(title.into());
755 self
756 }
757
758 pub fn output_schema(mut self, schema: Value) -> Self {
760 self.output_schema = Some(schema);
761 self
762 }
763
764 pub fn icon(mut self, src: impl Into<String>) -> Self {
766 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
767 src: src.into(),
768 mime_type: None,
769 sizes: None,
770 theme: None,
771 });
772 self
773 }
774
775 pub fn icon_with_meta(
777 mut self,
778 src: impl Into<String>,
779 mime_type: Option<String>,
780 sizes: Option<Vec<String>>,
781 ) -> Self {
782 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
783 src: src.into(),
784 mime_type,
785 sizes,
786 theme: None,
787 });
788 self
789 }
790
791 pub fn description(mut self, description: impl Into<String>) -> Self {
793 self.description = Some(description.into());
794 self
795 }
796
797 pub fn read_only(mut self) -> Self {
799 self.annotations
800 .get_or_insert_with(ToolAnnotations::default)
801 .read_only_hint = true;
802 self
803 }
804
805 pub fn non_destructive(mut self) -> Self {
807 self.annotations
808 .get_or_insert_with(ToolAnnotations::default)
809 .destructive_hint = false;
810 self
811 }
812
813 pub fn destructive(mut self) -> Self {
815 self.annotations
816 .get_or_insert_with(ToolAnnotations::default)
817 .destructive_hint = true;
818 self
819 }
820
821 pub fn idempotent(mut self) -> Self {
823 self.annotations
824 .get_or_insert_with(ToolAnnotations::default)
825 .idempotent_hint = true;
826 self
827 }
828
829 pub fn read_only_safe(mut self) -> Self {
835 let ann = self
836 .annotations
837 .get_or_insert_with(ToolAnnotations::default);
838 ann.read_only_hint = true;
839 ann.idempotent_hint = true;
840 ann.destructive_hint = false;
841 self
842 }
843
844 pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
846 self.annotations = Some(annotations);
847 self
848 }
849
850 pub fn task_support(mut self, mode: TaskSupportMode) -> Self {
852 self.task_support = mode;
853 self
854 }
855
856 pub fn no_params_handler<F, Fut>(self, handler: F) -> ToolBuilderWithNoParamsHandler<F>
874 where
875 F: Fn() -> Fut + Send + Sync + 'static,
876 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
877 {
878 ToolBuilderWithNoParamsHandler {
879 name: self.name,
880 title: self.title,
881 description: self.description,
882 output_schema: self.output_schema,
883 icons: self.icons,
884 annotations: self.annotations,
885 task_support: self.task_support,
886 handler,
887 }
888 }
889
890 pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
933 where
934 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
935 F: Fn(I) -> Fut + Send + Sync + 'static,
936 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
937 {
938 ToolBuilderWithHandler {
939 name: self.name,
940 title: self.title,
941 description: self.description,
942 output_schema: self.output_schema,
943 icons: self.icons,
944 annotations: self.annotations,
945 task_support: self.task_support,
946 handler,
947 _phantom: std::marker::PhantomData,
948 }
949 }
950
951 pub fn extractor_handler<S, F, T>(
1045 self,
1046 state: S,
1047 handler: F,
1048 ) -> crate::extract::ToolBuilderWithExtractor<S, F, T>
1049 where
1050 S: Clone + Send + Sync + 'static,
1051 F: crate::extract::ExtractorHandler<S, T> + Clone,
1052 T: Send + Sync + 'static,
1053 {
1054 crate::extract::ToolBuilderWithExtractor {
1055 name: self.name,
1056 title: self.title,
1057 description: self.description,
1058 output_schema: self.output_schema,
1059 icons: self.icons,
1060 annotations: self.annotations,
1061 task_support: self.task_support,
1062 state,
1063 handler,
1064 input_schema: F::input_schema(),
1065 _phantom: std::marker::PhantomData,
1066 }
1067 }
1068
1069 pub fn extractor_handler_typed<S, F, T, I>(
1106 self,
1107 state: S,
1108 handler: F,
1109 ) -> crate::extract::ToolBuilderWithTypedExtractor<S, F, T, I>
1110 where
1111 S: Clone + Send + Sync + 'static,
1112 F: crate::extract::TypedExtractorHandler<S, T, I> + Clone,
1113 T: Send + Sync + 'static,
1114 I: schemars::JsonSchema + Send + Sync + 'static,
1115 {
1116 crate::extract::ToolBuilderWithTypedExtractor {
1117 name: self.name,
1118 title: self.title,
1119 description: self.description,
1120 output_schema: self.output_schema,
1121 icons: self.icons,
1122 annotations: self.annotations,
1123 task_support: self.task_support,
1124 state,
1125 handler,
1126 _phantom: std::marker::PhantomData,
1127 }
1128 }
1129}
1130
1131struct NoParamsTypedHandler<F> {
1135 handler: F,
1136}
1137
1138impl<F, Fut> ToolHandler for NoParamsTypedHandler<F>
1139where
1140 F: Fn() -> Fut + Send + Sync + 'static,
1141 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1142{
1143 fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1144 Box::pin(async move { (self.handler)().await })
1145 }
1146
1147 fn input_schema(&self) -> Value {
1148 serde_json::json!({ "type": "object" })
1149 }
1150}
1151
/// [`ToolBuilder`] state after a typed handler `F` (taking input `I`) has been
/// attached; finish with `build`, or add tower layers with `layer` / `guard`.
pub struct ToolBuilderWithHandler<I, F> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    task_support: TaskSupportMode,
    handler: F,
    // Records the handler's input type `I` without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}
1164
/// [`ToolBuilder`] state after a parameterless handler `F` has been attached;
/// finish with `build`, or add tower layers with `layer` / `guard`.
pub struct ToolBuilderWithNoParamsHandler<F> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    task_support: TaskSupportMode,
    handler: F,
}
1178
1179impl<F, Fut> ToolBuilderWithNoParamsHandler<F>
1180where
1181 F: Fn() -> Fut + Send + Sync + 'static,
1182 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1183{
1184 pub fn build(self) -> Tool {
1186 Tool::from_handler(
1187 self.name,
1188 self.title,
1189 self.description,
1190 self.output_schema,
1191 self.icons,
1192 self.annotations,
1193 self.task_support,
1194 NoParamsTypedHandler {
1195 handler: self.handler,
1196 },
1197 )
1198 }
1199
1200 pub fn layer<L>(self, layer: L) -> ToolBuilderWithNoParamsHandlerLayer<F, L> {
1204 ToolBuilderWithNoParamsHandlerLayer {
1205 name: self.name,
1206 title: self.title,
1207 description: self.description,
1208 output_schema: self.output_schema,
1209 icons: self.icons,
1210 annotations: self.annotations,
1211 task_support: self.task_support,
1212 handler: self.handler,
1213 layer,
1214 }
1215 }
1216
1217 pub fn guard<G>(self, guard: G) -> ToolBuilderWithNoParamsHandlerLayer<F, GuardLayer<G>>
1221 where
1222 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1223 {
1224 self.layer(GuardLayer::new(guard))
1225 }
1226}
1227
/// Parameterless-handler builder that also carries a tower layer (or a
/// `Stack` of layers) to wrap around the handler service.
pub struct ToolBuilderWithNoParamsHandlerLayer<F, L> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    task_support: TaskSupportMode,
    handler: F,
    // Layer(s) applied to the handler service in `build`.
    layer: L,
}
1240
1241#[allow(private_bounds)]
1242impl<F, Fut, L> ToolBuilderWithNoParamsHandlerLayer<F, L>
1243where
1244 F: Fn() -> Fut + Send + Sync + 'static,
1245 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1246 L: tower::Layer<ToolHandlerService<NoParamsTypedHandler<F>>> + Clone + Send + Sync + 'static,
1247 L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1248 <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
1249 <L::Service as Service<ToolRequest>>::Future: Send,
1250{
1251 pub fn build(self) -> Tool {
1253 let input_schema = serde_json::json!({ "type": "object" });
1254
1255 let handler_service = ToolHandlerService::new(NoParamsTypedHandler {
1256 handler: self.handler,
1257 });
1258 let layered = self.layer.layer(handler_service);
1259 let catch_error = ToolCatchError::new(layered);
1260 let service = BoxCloneService::new(catch_error);
1261
1262 Tool {
1263 name: self.name,
1264 title: self.title,
1265 description: self.description,
1266 output_schema: self.output_schema,
1267 icons: self.icons,
1268 annotations: self.annotations,
1269 task_support: self.task_support,
1270 service,
1271 input_schema,
1272 }
1273 }
1274
1275 pub fn layer<L2>(
1277 self,
1278 layer: L2,
1279 ) -> ToolBuilderWithNoParamsHandlerLayer<F, tower::layer::util::Stack<L2, L>> {
1280 ToolBuilderWithNoParamsHandlerLayer {
1281 name: self.name,
1282 title: self.title,
1283 description: self.description,
1284 output_schema: self.output_schema,
1285 icons: self.icons,
1286 annotations: self.annotations,
1287 task_support: self.task_support,
1288 handler: self.handler,
1289 layer: tower::layer::util::Stack::new(layer, self.layer),
1290 }
1291 }
1292
1293 pub fn guard<G>(
1297 self,
1298 guard: G,
1299 ) -> ToolBuilderWithNoParamsHandlerLayer<F, tower::layer::util::Stack<GuardLayer<G>, L>>
1300 where
1301 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1302 {
1303 self.layer(GuardLayer::new(guard))
1304 }
1305}
1306
1307impl<I, F, Fut> ToolBuilderWithHandler<I, F>
1308where
1309 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1310 F: Fn(I) -> Fut + Send + Sync + 'static,
1311 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1312{
1313 pub fn build(self) -> Tool {
1315 Tool::from_handler(
1316 self.name,
1317 self.title,
1318 self.description,
1319 self.output_schema,
1320 self.icons,
1321 self.annotations,
1322 self.task_support,
1323 TypedHandler {
1324 handler: self.handler,
1325 _phantom: std::marker::PhantomData,
1326 },
1327 )
1328 }
1329
1330 pub fn layer<L>(self, layer: L) -> ToolBuilderWithLayer<I, F, L> {
1356 ToolBuilderWithLayer {
1357 name: self.name,
1358 title: self.title,
1359 description: self.description,
1360 output_schema: self.output_schema,
1361 icons: self.icons,
1362 annotations: self.annotations,
1363 task_support: self.task_support,
1364 handler: self.handler,
1365 layer,
1366 _phantom: std::marker::PhantomData,
1367 }
1368 }
1369
1370 pub fn guard<G>(self, guard: G) -> ToolBuilderWithLayer<I, F, GuardLayer<G>>
1377 where
1378 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1379 {
1380 self.layer(GuardLayer::new(guard))
1381 }
1382}
1383
/// Typed-handler builder that also carries a tower layer (or a `Stack` of
/// layers) to wrap around the handler service.
pub struct ToolBuilderWithLayer<I, F, L> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    task_support: TaskSupportMode,
    handler: F,
    // Layer(s) applied to the handler service in `build`.
    layer: L,
    // Records the handler's input type `I` without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}
1399
1400#[allow(private_bounds)]
1403impl<I, F, Fut, L> ToolBuilderWithLayer<I, F, L>
1404where
1405 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1406 F: Fn(I) -> Fut + Send + Sync + 'static,
1407 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1408 L: tower::Layer<ToolHandlerService<TypedHandler<I, F>>> + Clone + Send + Sync + 'static,
1409 L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1410 <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
1411 <L::Service as Service<ToolRequest>>::Future: Send,
1412{
1413 pub fn build(self) -> Tool {
1415 let input_schema = schemars::schema_for!(I);
1416 let input_schema = serde_json::to_value(input_schema)
1417 .unwrap_or_else(|_| serde_json::json!({ "type": "object" }));
1418
1419 let handler_service = ToolHandlerService::new(TypedHandler {
1420 handler: self.handler,
1421 _phantom: std::marker::PhantomData,
1422 });
1423 let layered = self.layer.layer(handler_service);
1424 let catch_error = ToolCatchError::new(layered);
1425 let service = BoxCloneService::new(catch_error);
1426
1427 Tool {
1428 name: self.name,
1429 title: self.title,
1430 description: self.description,
1431 output_schema: self.output_schema,
1432 icons: self.icons,
1433 annotations: self.annotations,
1434 task_support: self.task_support,
1435 service,
1436 input_schema,
1437 }
1438 }
1439
1440 pub fn layer<L2>(
1445 self,
1446 layer: L2,
1447 ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<L2, L>> {
1448 ToolBuilderWithLayer {
1449 name: self.name,
1450 title: self.title,
1451 description: self.description,
1452 output_schema: self.output_schema,
1453 icons: self.icons,
1454 annotations: self.annotations,
1455 task_support: self.task_support,
1456 handler: self.handler,
1457 layer: tower::layer::util::Stack::new(layer, self.layer),
1458 _phantom: std::marker::PhantomData,
1459 }
1460 }
1461
1462 pub fn guard<G>(
1466 self,
1467 guard: G,
1468 ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<GuardLayer<G>, L>>
1469 where
1470 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1471 {
1472 self.layer(GuardLayer::new(guard))
1473 }
1474}
1475
1476struct TypedHandler<I, F> {
1482 handler: F,
1483 _phantom: std::marker::PhantomData<I>,
1484}
1485
1486impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
1487where
1488 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1489 F: Fn(I) -> Fut + Send + Sync + 'static,
1490 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1491{
1492 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1493 Box::pin(async move {
1494 let input: I = serde_json::from_value(args).tool_context("Invalid input")?;
1495 (self.handler)(input).await
1496 })
1497 }
1498
1499 fn input_schema(&self) -> Value {
1500 let schema = schemars::schema_for!(I);
1501 serde_json::to_value(schema).unwrap_or_else(|_| {
1502 serde_json::json!({
1503 "type": "object"
1504 })
1505 })
1506 }
1507}
1508
1509pub trait McpTool: Send + Sync + 'static {
1550 const NAME: &'static str;
1551 const DESCRIPTION: &'static str;
1552
1553 type Input: JsonSchema + DeserializeOwned + Send;
1554 type Output: Serialize + Send;
1555
1556 fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
1557
1558 fn annotations(&self) -> Option<ToolAnnotations> {
1560 None
1561 }
1562
1563 fn into_tool(self) -> Tool
1571 where
1572 Self: Sized,
1573 {
1574 if let Err(e) = validate_tool_name(Self::NAME) {
1575 panic!("{e}");
1576 }
1577 let annotations = self.annotations();
1578 let tool = Arc::new(self);
1579 Tool::from_handler(
1580 Self::NAME.to_string(),
1581 None,
1582 Some(Self::DESCRIPTION.to_string()),
1583 None,
1584 None,
1585 annotations,
1586 TaskSupportMode::default(),
1587 McpToolHandler { tool },
1588 )
1589 }
1590}
1591
1592struct McpToolHandler<T: McpTool> {
1594 tool: Arc<T>,
1595}
1596
1597impl<T: McpTool> ToolHandler for McpToolHandler<T> {
1598 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1599 let tool = self.tool.clone();
1600 Box::pin(async move {
1601 let input: T::Input = serde_json::from_value(args).tool_context("Invalid input")?;
1602 let output = tool.call(input).await?;
1603 let value = serde_json::to_value(output).tool_context("Failed to serialize output")?;
1604 Ok(CallToolResult::json(value))
1605 })
1606 }
1607
1608 fn input_schema(&self) -> Value {
1609 let schema = schemars::schema_for!(T::Input);
1610 serde_json::to_value(schema).unwrap_or_else(|_| {
1611 serde_json::json!({
1612 "type": "object"
1613 })
1614 })
1615 }
1616}
1617
1618#[cfg(test)]
1619mod tests {
1620 use super::*;
1621 use crate::extract::{Context, Json, RawArgs, State};
1622 use crate::protocol::Content;
1623 use schemars::JsonSchema;
1624 use serde::Deserialize;
1625
1626 #[derive(Debug, Deserialize, JsonSchema)]
1627 struct GreetInput {
1628 name: String,
1629 }
1630
1631 #[tokio::test]
1632 async fn test_builder_tool() {
1633 let tool = ToolBuilder::new("greet")
1634 .description("Greet someone")
1635 .handler(|input: GreetInput| async move {
1636 Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1637 })
1638 .build();
1639
1640 assert_eq!(tool.name, "greet");
1641 assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1642
1643 let result = tool.call(serde_json::json!({"name": "World"})).await;
1644
1645 assert!(!result.is_error);
1646 }
1647
1648 #[tokio::test]
1649 async fn test_raw_handler() {
1650 let tool = ToolBuilder::new("echo")
1651 .description("Echo input")
1652 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1653 Ok(CallToolResult::json(args))
1654 })
1655 .build();
1656
1657 let result = tool.call(serde_json::json!({"foo": "bar"})).await;
1658
1659 assert!(!result.is_error);
1660 }
1661
1662 #[test]
1663 fn test_invalid_tool_name_empty() {
1664 let err = ToolBuilder::try_new("").err().expect("should fail");
1665 assert!(err.to_string().contains("cannot be empty"));
1666 }
1667
1668 #[test]
1669 fn test_invalid_tool_name_too_long() {
1670 let long_name = "a".repeat(129);
1671 let err = ToolBuilder::try_new(long_name).err().expect("should fail");
1672 assert!(err.to_string().contains("exceeds maximum"));
1673 }
1674
1675 #[test]
1676 fn test_invalid_tool_name_bad_chars() {
1677 let err = ToolBuilder::try_new("my tool!").err().expect("should fail");
1678 assert!(err.to_string().contains("invalid character"));
1679 }
1680
1681 #[test]
1682 #[should_panic(expected = "cannot be empty")]
1683 fn test_new_panics_on_empty_name() {
1684 ToolBuilder::new("");
1685 }
1686
1687 #[test]
1688 #[should_panic(expected = "exceeds maximum")]
1689 fn test_new_panics_on_too_long_name() {
1690 ToolBuilder::new("a".repeat(129));
1691 }
1692
1693 #[test]
1694 #[should_panic(expected = "invalid character")]
1695 fn test_new_panics_on_invalid_chars() {
1696 ToolBuilder::new("my tool!");
1697 }
1698
1699 #[test]
1700 fn test_valid_tool_names() {
1701 let names = [
1703 "my_tool",
1704 "my-tool",
1705 "my.tool",
1706 "MyTool123",
1707 "a",
1708 &"a".repeat(128),
1709 ];
1710 for name in names {
1711 assert!(
1712 ToolBuilder::try_new(name).is_ok(),
1713 "Expected '{}' to be valid",
1714 name
1715 );
1716 }
1717 }
1718
1719 #[tokio::test]
1720 async fn test_context_aware_handler() {
1721 use crate::context::notification_channel;
1722 use crate::protocol::{ProgressToken, RequestId};
1723
1724 #[derive(Debug, Deserialize, JsonSchema)]
1725 struct ProcessInput {
1726 count: i32,
1727 }
1728
1729 let tool = ToolBuilder::new("process")
1730 .description("Process with context")
1731 .extractor_handler(
1732 (),
1733 |ctx: Context, Json(input): Json<ProcessInput>| async move {
1734 for i in 0..input.count {
1736 if ctx.is_cancelled() {
1737 return Ok(CallToolResult::error("Cancelled"));
1738 }
1739 ctx.report_progress(i as f64, Some(input.count as f64), None)
1740 .await;
1741 }
1742 Ok(CallToolResult::text(format!(
1743 "Processed {} items",
1744 input.count
1745 )))
1746 },
1747 )
1748 .build();
1749
1750 assert_eq!(tool.name, "process");
1751
1752 let (tx, mut rx) = notification_channel(10);
1754 let ctx = RequestContext::new(RequestId::Number(1))
1755 .with_progress_token(ProgressToken::Number(42))
1756 .with_notification_sender(tx);
1757
1758 let result = tool
1759 .call_with_context(ctx, serde_json::json!({"count": 3}))
1760 .await;
1761
1762 assert!(!result.is_error);
1763
1764 let mut progress_count = 0;
1766 while rx.try_recv().is_ok() {
1767 progress_count += 1;
1768 }
1769 assert_eq!(progress_count, 3);
1770 }
1771
1772 #[tokio::test]
1773 async fn test_context_aware_handler_cancellation() {
1774 use crate::protocol::RequestId;
1775 use std::sync::atomic::{AtomicI32, Ordering};
1776
1777 #[derive(Debug, Deserialize, JsonSchema)]
1778 struct LongRunningInput {
1779 iterations: i32,
1780 }
1781
1782 let iterations_completed = Arc::new(AtomicI32::new(0));
1783 let iterations_ref = iterations_completed.clone();
1784
1785 let tool = ToolBuilder::new("long_running")
1786 .description("Long running task")
1787 .extractor_handler(
1788 (),
1789 move |ctx: Context, Json(input): Json<LongRunningInput>| {
1790 let completed = iterations_ref.clone();
1791 async move {
1792 for i in 0..input.iterations {
1793 if ctx.is_cancelled() {
1794 return Ok(CallToolResult::error("Cancelled"));
1795 }
1796 completed.fetch_add(1, Ordering::SeqCst);
1797 tokio::task::yield_now().await;
1799 if i == 2 {
1801 ctx.cancellation_token().cancel();
1802 }
1803 }
1804 Ok(CallToolResult::text("Done"))
1805 }
1806 },
1807 )
1808 .build();
1809
1810 let ctx = RequestContext::new(RequestId::Number(1));
1811
1812 let result = tool
1813 .call_with_context(ctx, serde_json::json!({"iterations": 10}))
1814 .await;
1815
1816 assert!(result.is_error);
1819 assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
1820 }
1821
1822 #[tokio::test]
1823 async fn test_tool_builder_with_enhanced_fields() {
1824 let output_schema = serde_json::json!({
1825 "type": "object",
1826 "properties": {
1827 "greeting": {"type": "string"}
1828 }
1829 });
1830
1831 let tool = ToolBuilder::new("greet")
1832 .title("Greeting Tool")
1833 .description("Greet someone")
1834 .output_schema(output_schema.clone())
1835 .icon("https://example.com/icon.png")
1836 .icon_with_meta(
1837 "https://example.com/icon-large.png",
1838 Some("image/png".to_string()),
1839 Some(vec!["96x96".to_string()]),
1840 )
1841 .handler(|input: GreetInput| async move {
1842 Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1843 })
1844 .build();
1845
1846 assert_eq!(tool.name, "greet");
1847 assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
1848 assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1849 assert_eq!(tool.output_schema, Some(output_schema));
1850 assert!(tool.icons.is_some());
1851 assert_eq!(tool.icons.as_ref().unwrap().len(), 2);
1852
1853 let def = tool.definition();
1855 assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
1856 assert!(def.output_schema.is_some());
1857 assert!(def.icons.is_some());
1858 }
1859
1860 #[tokio::test]
1861 async fn test_handler_with_state() {
1862 let shared = Arc::new("shared-state".to_string());
1863
1864 let tool = ToolBuilder::new("stateful")
1865 .description("Uses shared state")
1866 .extractor_handler(
1867 shared,
1868 |State(state): State<Arc<String>>, Json(input): Json<GreetInput>| async move {
1869 Ok(CallToolResult::text(format!(
1870 "{}: Hello, {}!",
1871 state, input.name
1872 )))
1873 },
1874 )
1875 .build();
1876
1877 let result = tool.call(serde_json::json!({"name": "World"})).await;
1878 assert!(!result.is_error);
1879 }
1880
1881 #[tokio::test]
1882 async fn test_handler_with_state_and_context() {
1883 use crate::protocol::RequestId;
1884
1885 let shared = Arc::new(42_i32);
1886
1887 let tool =
1888 ToolBuilder::new("stateful_ctx")
1889 .description("Uses state and context")
1890 .extractor_handler(
1891 shared,
1892 |State(state): State<Arc<i32>>,
1893 _ctx: Context,
1894 Json(input): Json<GreetInput>| async move {
1895 Ok(CallToolResult::text(format!(
1896 "{}: Hello, {}!",
1897 state, input.name
1898 )))
1899 },
1900 )
1901 .build();
1902
1903 let ctx = RequestContext::new(RequestId::Number(1));
1904 let result = tool
1905 .call_with_context(ctx, serde_json::json!({"name": "World"}))
1906 .await;
1907 assert!(!result.is_error);
1908 }
1909
1910 #[tokio::test]
1911 async fn test_handler_no_params() {
1912 let tool = ToolBuilder::new("no_params")
1913 .description("Takes no parameters")
1914 .extractor_handler((), |Json(_): Json<NoParams>| async {
1915 Ok(CallToolResult::text("no params result"))
1916 })
1917 .build();
1918
1919 assert_eq!(tool.name, "no_params");
1920
1921 let result = tool.call(serde_json::json!({})).await;
1923 assert!(!result.is_error);
1924
1925 let result = tool.call(serde_json::json!({"unexpected": "value"})).await;
1927 assert!(!result.is_error);
1928
1929 let schema = tool.definition().input_schema;
1931 assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1932 }
1933
1934 #[tokio::test]
1935 async fn test_handler_with_state_no_params() {
1936 let shared = Arc::new("shared_value".to_string());
1937
1938 let tool = ToolBuilder::new("with_state_no_params")
1939 .description("Takes no parameters but has state")
1940 .extractor_handler(
1941 shared,
1942 |State(state): State<Arc<String>>, Json(_): Json<NoParams>| async move {
1943 Ok(CallToolResult::text(format!("state: {}", state)))
1944 },
1945 )
1946 .build();
1947
1948 assert_eq!(tool.name, "with_state_no_params");
1949
1950 let result = tool.call(serde_json::json!({})).await;
1952 assert!(!result.is_error);
1953 assert_eq!(result.first_text().unwrap(), "state: shared_value");
1954
1955 let schema = tool.definition().input_schema;
1957 assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1958 }
1959
1960 #[tokio::test]
1961 async fn test_handler_no_params_with_context() {
1962 let tool = ToolBuilder::new("no_params_with_context")
1963 .description("Takes no parameters but has context")
1964 .extractor_handler((), |_ctx: Context, Json(_): Json<NoParams>| async move {
1965 Ok(CallToolResult::text("context available"))
1966 })
1967 .build();
1968
1969 assert_eq!(tool.name, "no_params_with_context");
1970
1971 let result = tool.call(serde_json::json!({})).await;
1972 assert!(!result.is_error);
1973 assert_eq!(result.first_text().unwrap(), "context available");
1974 }
1975
1976 #[tokio::test]
1977 async fn test_handler_with_state_and_context_no_params() {
1978 let shared = Arc::new("shared".to_string());
1979
1980 let tool = ToolBuilder::new("state_context_no_params")
1981 .description("Has state and context, no params")
1982 .extractor_handler(
1983 shared,
1984 |State(state): State<Arc<String>>,
1985 _ctx: Context,
1986 Json(_): Json<NoParams>| async move {
1987 Ok(CallToolResult::text(format!("state: {}", state)))
1988 },
1989 )
1990 .build();
1991
1992 assert_eq!(tool.name, "state_context_no_params");
1993
1994 let result = tool.call(serde_json::json!({})).await;
1995 assert!(!result.is_error);
1996 assert_eq!(result.first_text().unwrap(), "state: shared");
1997 }
1998
1999 #[tokio::test]
2000 async fn test_raw_handler_with_state() {
2001 let prefix = Arc::new("prefix:".to_string());
2002
2003 let tool = ToolBuilder::new("raw_with_state")
2004 .description("Raw handler with state")
2005 .extractor_handler(
2006 prefix,
2007 |State(state): State<Arc<String>>, RawArgs(args): RawArgs| async move {
2008 Ok(CallToolResult::text(format!("{} {}", state, args)))
2009 },
2010 )
2011 .build();
2012
2013 assert_eq!(tool.name, "raw_with_state");
2014
2015 let result = tool.call(serde_json::json!({"key": "value"})).await;
2016 assert!(!result.is_error);
2017 assert!(result.first_text().unwrap().starts_with("prefix:"));
2018 }
2019
2020 #[tokio::test]
2021 async fn test_raw_handler_with_state_and_context() {
2022 let prefix = Arc::new("prefix:".to_string());
2023
2024 let tool = ToolBuilder::new("raw_state_context")
2025 .description("Raw handler with state and context")
2026 .extractor_handler(
2027 prefix,
2028 |State(state): State<Arc<String>>,
2029 _ctx: Context,
2030 RawArgs(args): RawArgs| async move {
2031 Ok(CallToolResult::text(format!("{} {}", state, args)))
2032 },
2033 )
2034 .build();
2035
2036 assert_eq!(tool.name, "raw_state_context");
2037
2038 let result = tool.call(serde_json::json!({"key": "value"})).await;
2039 assert!(!result.is_error);
2040 assert!(result.first_text().unwrap().starts_with("prefix:"));
2041 }
2042
2043 #[tokio::test]
2044 async fn test_tool_with_timeout_layer() {
2045 use std::time::Duration;
2046 use tower::timeout::TimeoutLayer;
2047
2048 #[derive(Debug, Deserialize, JsonSchema)]
2049 struct SlowInput {
2050 delay_ms: u64,
2051 }
2052
2053 let tool = ToolBuilder::new("slow_tool")
2055 .description("A slow tool")
2056 .handler(|input: SlowInput| async move {
2057 tokio::time::sleep(Duration::from_millis(input.delay_ms)).await;
2058 Ok(CallToolResult::text("completed"))
2059 })
2060 .layer(TimeoutLayer::new(Duration::from_millis(50)))
2061 .build();
2062
2063 let result = tool.call(serde_json::json!({"delay_ms": 10})).await;
2065 assert!(!result.is_error);
2066 assert_eq!(result.first_text().unwrap(), "completed");
2067
2068 let result = tool.call(serde_json::json!({"delay_ms": 200})).await;
2070 assert!(result.is_error);
2071 let msg = result.first_text().unwrap().to_lowercase();
2073 assert!(
2074 msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
2075 "Expected timeout error, got: {}",
2076 msg
2077 );
2078 }
2079
2080 #[tokio::test]
2081 async fn test_tool_with_concurrency_limit_layer() {
2082 use std::sync::atomic::{AtomicU32, Ordering};
2083 use std::time::Duration;
2084 use tower::limit::ConcurrencyLimitLayer;
2085
2086 #[derive(Debug, Deserialize, JsonSchema)]
2087 struct WorkInput {
2088 id: u32,
2089 }
2090
2091 let max_concurrent = Arc::new(AtomicU32::new(0));
2092 let current_concurrent = Arc::new(AtomicU32::new(0));
2093 let max_ref = max_concurrent.clone();
2094 let current_ref = current_concurrent.clone();
2095
2096 let tool = ToolBuilder::new("concurrent_tool")
2098 .description("A concurrent tool")
2099 .handler(move |input: WorkInput| {
2100 let max = max_ref.clone();
2101 let current = current_ref.clone();
2102 async move {
2103 let prev = current.fetch_add(1, Ordering::SeqCst);
2105 max.fetch_max(prev + 1, Ordering::SeqCst);
2106
2107 tokio::time::sleep(Duration::from_millis(50)).await;
2109
2110 current.fetch_sub(1, Ordering::SeqCst);
2111 Ok(CallToolResult::text(format!("completed {}", input.id)))
2112 }
2113 })
2114 .layer(ConcurrencyLimitLayer::new(2))
2115 .build();
2116
2117 let handles: Vec<_> = (0..4)
2119 .map(|i| {
2120 let t = tool.call(serde_json::json!({"id": i}));
2121 tokio::spawn(t)
2122 })
2123 .collect();
2124
2125 for handle in handles {
2126 let result = handle.await.unwrap();
2127 assert!(!result.is_error);
2128 }
2129
2130 assert!(max_concurrent.load(Ordering::SeqCst) <= 2);
2132 }
2133
2134 #[tokio::test]
2135 async fn test_tool_with_multiple_layers() {
2136 use std::time::Duration;
2137 use tower::limit::ConcurrencyLimitLayer;
2138 use tower::timeout::TimeoutLayer;
2139
2140 #[derive(Debug, Deserialize, JsonSchema)]
2141 struct Input {
2142 value: String,
2143 }
2144
2145 let tool = ToolBuilder::new("multi_layer_tool")
2147 .description("Tool with multiple layers")
2148 .handler(|input: Input| async move {
2149 Ok(CallToolResult::text(format!("processed: {}", input.value)))
2150 })
2151 .layer(TimeoutLayer::new(Duration::from_secs(5)))
2152 .layer(ConcurrencyLimitLayer::new(10))
2153 .build();
2154
2155 let result = tool.call(serde_json::json!({"value": "test"})).await;
2156 assert!(!result.is_error);
2157 assert_eq!(result.first_text().unwrap(), "processed: test");
2158 }
2159
2160 #[test]
2161 fn test_tool_catch_error_clone() {
2162 let tool = ToolBuilder::new("test")
2165 .description("test")
2166 .extractor_handler((), |RawArgs(_args): RawArgs| async {
2167 Ok(CallToolResult::text("ok"))
2168 })
2169 .build();
2170 let _clone = tool.call(serde_json::json!({}));
2172 }
2173
2174 #[test]
2175 fn test_tool_catch_error_debug() {
2176 #[derive(Debug, Clone)]
2180 struct DebugService;
2181
2182 impl Service<ToolRequest> for DebugService {
2183 type Response = CallToolResult;
2184 type Error = crate::error::Error;
2185 type Future = Pin<
2186 Box<
2187 dyn Future<Output = std::result::Result<CallToolResult, crate::error::Error>>
2188 + Send,
2189 >,
2190 >;
2191
2192 fn poll_ready(
2193 &mut self,
2194 _cx: &mut std::task::Context<'_>,
2195 ) -> Poll<std::result::Result<(), Self::Error>> {
2196 Poll::Ready(Ok(()))
2197 }
2198
2199 fn call(&mut self, _req: ToolRequest) -> Self::Future {
2200 Box::pin(async { Ok(CallToolResult::text("ok")) })
2201 }
2202 }
2203
2204 let catch_error = ToolCatchError::new(DebugService);
2205 let debug = format!("{:?}", catch_error);
2206 assert!(debug.contains("ToolCatchError"));
2207 }
2208
2209 #[test]
2210 fn test_tool_request_new() {
2211 use crate::protocol::RequestId;
2212
2213 let ctx = RequestContext::new(RequestId::Number(42));
2214 let args = serde_json::json!({"key": "value"});
2215 let req = ToolRequest::new(ctx.clone(), args.clone());
2216
2217 assert_eq!(req.args, args);
2218 }
2219
2220 #[test]
2221 fn test_no_params_schema() {
2222 let schema = schemars::schema_for!(NoParams);
2224 let schema_value = serde_json::to_value(&schema).unwrap();
2225 assert_eq!(
2226 schema_value.get("type").and_then(|v| v.as_str()),
2227 Some("object"),
2228 "NoParams should generate type: object schema"
2229 );
2230 }
2231
2232 #[test]
2233 fn test_no_params_deserialize() {
2234 let from_empty_object: NoParams = serde_json::from_str("{}").unwrap();
2236 assert_eq!(from_empty_object, NoParams);
2237
2238 let from_null: NoParams = serde_json::from_str("null").unwrap();
2239 assert_eq!(from_null, NoParams);
2240
2241 let from_object_with_fields: NoParams =
2243 serde_json::from_str(r#"{"unexpected": "value"}"#).unwrap();
2244 assert_eq!(from_object_with_fields, NoParams);
2245 }
2246
2247 #[tokio::test]
2248 async fn test_no_params_type_in_handler() {
2249 let tool = ToolBuilder::new("status")
2251 .description("Get status")
2252 .handler(|_input: NoParams| async move { Ok(CallToolResult::text("OK")) })
2253 .build();
2254
2255 let schema = tool.definition().input_schema;
2257 assert_eq!(
2258 schema.get("type").and_then(|v| v.as_str()),
2259 Some("object"),
2260 "NoParams handler should produce type: object schema"
2261 );
2262
2263 let result = tool.call(serde_json::json!({})).await;
2265 assert!(!result.is_error);
2266 }
2267
2268 #[tokio::test]
2269 async fn test_tool_with_name_prefix() {
2270 #[derive(Debug, Deserialize, JsonSchema)]
2271 struct Input {
2272 value: String,
2273 }
2274
2275 let tool = ToolBuilder::new("query")
2276 .description("Query something")
2277 .title("Query Tool")
2278 .handler(|input: Input| async move { Ok(CallToolResult::text(&input.value)) })
2279 .build();
2280
2281 let prefixed = tool.with_name_prefix("db");
2283
2284 assert_eq!(prefixed.name, "db.query");
2286
2287 assert_eq!(prefixed.description.as_deref(), Some("Query something"));
2289 assert_eq!(prefixed.title.as_deref(), Some("Query Tool"));
2290
2291 let result = prefixed
2293 .call(serde_json::json!({"value": "test input"}))
2294 .await;
2295 assert!(!result.is_error);
2296 match &result.content[0] {
2297 Content::Text { text, .. } => assert_eq!(text, "test input"),
2298 _ => panic!("Expected text content"),
2299 }
2300 }
2301
2302 #[tokio::test]
2303 async fn test_tool_with_name_prefix_multiple_levels() {
2304 let tool = ToolBuilder::new("action")
2305 .description("Do something")
2306 .handler(|_: NoParams| async move { Ok(CallToolResult::text("done")) })
2307 .build();
2308
2309 let prefixed = tool.with_name_prefix("level1");
2311 assert_eq!(prefixed.name, "level1.action");
2312
2313 let double_prefixed = prefixed.with_name_prefix("level0");
2314 assert_eq!(double_prefixed.name, "level0.level1.action");
2315 }
2316
2317 #[tokio::test]
2322 async fn test_no_params_handler_basic() {
2323 let tool = ToolBuilder::new("get_status")
2324 .description("Get current status")
2325 .no_params_handler(|| async { Ok(CallToolResult::text("OK")) })
2326 .build();
2327
2328 assert_eq!(tool.name, "get_status");
2329 assert_eq!(tool.description.as_deref(), Some("Get current status"));
2330
2331 let result = tool.call(serde_json::json!({})).await;
2333 assert!(!result.is_error);
2334 assert_eq!(result.first_text().unwrap(), "OK");
2335
2336 let result = tool.call(serde_json::json!(null)).await;
2338 assert!(!result.is_error);
2339
2340 let schema = tool.definition().input_schema;
2342 assert_eq!(schema.get("type").and_then(|v| v.as_str()), Some("object"));
2343 }
2344
2345 #[tokio::test]
2346 async fn test_no_params_handler_with_captured_state() {
2347 let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
2348 let counter_ref = counter.clone();
2349
2350 let tool = ToolBuilder::new("increment")
2351 .description("Increment counter")
2352 .no_params_handler(move || {
2353 let c = counter_ref.clone();
2354 async move {
2355 let prev = c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2356 Ok(CallToolResult::text(format!("Incremented from {}", prev)))
2357 }
2358 })
2359 .build();
2360
2361 let _ = tool.call(serde_json::json!({})).await;
2363 let _ = tool.call(serde_json::json!({})).await;
2364 let result = tool.call(serde_json::json!({})).await;
2365
2366 assert!(!result.is_error);
2367 assert_eq!(result.first_text().unwrap(), "Incremented from 2");
2368 assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 3);
2369 }
2370
2371 #[tokio::test]
2372 async fn test_no_params_handler_with_layer() {
2373 use std::time::Duration;
2374 use tower::timeout::TimeoutLayer;
2375
2376 let tool = ToolBuilder::new("slow_status")
2377 .description("Slow status check")
2378 .no_params_handler(|| async {
2379 tokio::time::sleep(Duration::from_millis(10)).await;
2380 Ok(CallToolResult::text("done"))
2381 })
2382 .layer(TimeoutLayer::new(Duration::from_secs(1)))
2383 .build();
2384
2385 let result = tool.call(serde_json::json!({})).await;
2386 assert!(!result.is_error);
2387 assert_eq!(result.first_text().unwrap(), "done");
2388 }
2389
2390 #[tokio::test]
2391 async fn test_no_params_handler_timeout() {
2392 use std::time::Duration;
2393 use tower::timeout::TimeoutLayer;
2394
2395 let tool = ToolBuilder::new("very_slow_status")
2396 .description("Very slow status check")
2397 .no_params_handler(|| async {
2398 tokio::time::sleep(Duration::from_millis(200)).await;
2399 Ok(CallToolResult::text("done"))
2400 })
2401 .layer(TimeoutLayer::new(Duration::from_millis(50)))
2402 .build();
2403
2404 let result = tool.call(serde_json::json!({})).await;
2405 assert!(result.is_error);
2406 let msg = result.first_text().unwrap().to_lowercase();
2407 assert!(
2408 msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
2409 "Expected timeout error, got: {}",
2410 msg
2411 );
2412 }
2413
2414 #[tokio::test]
2415 async fn test_no_params_handler_with_multiple_layers() {
2416 use std::time::Duration;
2417 use tower::limit::ConcurrencyLimitLayer;
2418 use tower::timeout::TimeoutLayer;
2419
2420 let tool = ToolBuilder::new("multi_layer_status")
2421 .description("Status with multiple layers")
2422 .no_params_handler(|| async { Ok(CallToolResult::text("status ok")) })
2423 .layer(TimeoutLayer::new(Duration::from_secs(5)))
2424 .layer(ConcurrencyLimitLayer::new(10))
2425 .build();
2426
2427 let result = tool.call(serde_json::json!({})).await;
2428 assert!(!result.is_error);
2429 assert_eq!(result.first_text().unwrap(), "status ok");
2430 }
2431
2432 #[tokio::test]
2437 async fn test_guard_allows_request() {
2438 #[derive(Debug, Deserialize, JsonSchema)]
2439 #[allow(dead_code)]
2440 struct DeleteInput {
2441 id: String,
2442 confirm: bool,
2443 }
2444
2445 let tool = ToolBuilder::new("delete")
2446 .description("Delete a record")
2447 .handler(|input: DeleteInput| async move {
2448 Ok(CallToolResult::text(format!("deleted {}", input.id)))
2449 })
2450 .guard(|req: &ToolRequest| {
2451 let confirm = req
2452 .args
2453 .get("confirm")
2454 .and_then(|v| v.as_bool())
2455 .unwrap_or(false);
2456 if !confirm {
2457 return Err("Must set confirm=true to delete".to_string());
2458 }
2459 Ok(())
2460 })
2461 .build();
2462
2463 let result = tool
2464 .call(serde_json::json!({"id": "abc", "confirm": true}))
2465 .await;
2466 assert!(!result.is_error);
2467 assert_eq!(result.first_text().unwrap(), "deleted abc");
2468 }
2469
2470 #[tokio::test]
2471 async fn test_guard_rejects_request() {
2472 #[derive(Debug, Deserialize, JsonSchema)]
2473 #[allow(dead_code)]
2474 struct DeleteInput2 {
2475 id: String,
2476 confirm: bool,
2477 }
2478
2479 let tool = ToolBuilder::new("delete2")
2480 .description("Delete a record")
2481 .handler(|input: DeleteInput2| async move {
2482 Ok(CallToolResult::text(format!("deleted {}", input.id)))
2483 })
2484 .guard(|req: &ToolRequest| {
2485 let confirm = req
2486 .args
2487 .get("confirm")
2488 .and_then(|v| v.as_bool())
2489 .unwrap_or(false);
2490 if !confirm {
2491 return Err("Must set confirm=true to delete".to_string());
2492 }
2493 Ok(())
2494 })
2495 .build();
2496
2497 let result = tool
2498 .call(serde_json::json!({"id": "abc", "confirm": false}))
2499 .await;
2500 assert!(result.is_error);
2501 assert!(
2502 result
2503 .first_text()
2504 .unwrap()
2505 .contains("Must set confirm=true")
2506 );
2507 }
2508
2509 #[tokio::test]
2510 async fn test_guard_with_layer() {
2511 use std::time::Duration;
2512 use tower::timeout::TimeoutLayer;
2513
2514 let tool = ToolBuilder::new("guarded_timeout")
2515 .description("Guarded with timeout")
2516 .handler(|input: GreetInput| async move {
2517 Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
2518 })
2519 .layer(TimeoutLayer::new(Duration::from_secs(5)))
2520 .guard(|_req: &ToolRequest| Ok(()))
2521 .build();
2522
2523 let result = tool.call(serde_json::json!({"name": "World"})).await;
2524 assert!(!result.is_error);
2525 assert_eq!(result.first_text().unwrap(), "Hello, World!");
2526 }
2527
2528 #[tokio::test]
2529 async fn test_guard_on_no_params_handler() {
2530 let allowed = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(true));
2531 let allowed_clone = allowed.clone();
2532
2533 let tool = ToolBuilder::new("status")
2534 .description("Get status")
2535 .no_params_handler(|| async { Ok(CallToolResult::text("ok")) })
2536 .guard(move |_req: &ToolRequest| {
2537 if allowed_clone.load(std::sync::atomic::Ordering::Relaxed) {
2538 Ok(())
2539 } else {
2540 Err("Access denied".to_string())
2541 }
2542 })
2543 .build();
2544
2545 let result = tool.call(serde_json::json!({})).await;
2547 assert!(!result.is_error);
2548 assert_eq!(result.first_text().unwrap(), "ok");
2549
2550 allowed.store(false, std::sync::atomic::Ordering::Relaxed);
2552 let result = tool.call(serde_json::json!({})).await;
2553 assert!(result.is_error);
2554 assert!(result.first_text().unwrap().contains("Access denied"));
2555 }
2556
2557 #[tokio::test]
2558 async fn test_guard_on_no_params_handler_with_layer() {
2559 use std::time::Duration;
2560 use tower::timeout::TimeoutLayer;
2561
2562 let tool = ToolBuilder::new("status_layered")
2563 .description("Get status with layers")
2564 .no_params_handler(|| async { Ok(CallToolResult::text("ok")) })
2565 .layer(TimeoutLayer::new(Duration::from_secs(5)))
2566 .guard(|_req: &ToolRequest| Ok(()))
2567 .build();
2568
2569 let result = tool.call(serde_json::json!({})).await;
2570 assert!(!result.is_error);
2571 assert_eq!(result.first_text().unwrap(), "ok");
2572 }
2573
2574 #[tokio::test]
2575 async fn test_guard_on_extractor_handler() {
2576 use std::sync::Arc;
2577
2578 #[derive(Clone)]
2579 struct AppState {
2580 prefix: String,
2581 }
2582
2583 #[derive(Debug, Deserialize, JsonSchema)]
2584 struct QueryInput {
2585 query: String,
2586 }
2587
2588 let state = Arc::new(AppState {
2589 prefix: "db".to_string(),
2590 });
2591
2592 let tool = ToolBuilder::new("search")
2593 .description("Search")
2594 .extractor_handler(
2595 state,
2596 |State(app): State<Arc<AppState>>, Json(input): Json<QueryInput>| async move {
2597 Ok(CallToolResult::text(format!(
2598 "{}: {}",
2599 app.prefix, input.query
2600 )))
2601 },
2602 )
2603 .guard(|req: &ToolRequest| {
2604 let query = req.args.get("query").and_then(|v| v.as_str()).unwrap_or("");
2605 if query.is_empty() {
2606 return Err("Query cannot be empty".to_string());
2607 }
2608 Ok(())
2609 })
2610 .build();
2611
2612 let result = tool.call(serde_json::json!({"query": "hello"})).await;
2614 assert!(!result.is_error);
2615 assert_eq!(result.first_text().unwrap(), "db: hello");
2616
2617 let result = tool.call(serde_json::json!({"query": ""})).await;
2619 assert!(result.is_error);
2620 assert!(
2621 result
2622 .first_text()
2623 .unwrap()
2624 .contains("Query cannot be empty")
2625 );
2626 }
2627
2628 #[tokio::test]
2629 async fn test_guard_on_extractor_handler_with_layer() {
2630 use std::sync::Arc;
2631 use std::time::Duration;
2632 use tower::timeout::TimeoutLayer;
2633
2634 #[derive(Clone)]
2635 struct AppState2 {
2636 prefix: String,
2637 }
2638
2639 #[derive(Debug, Deserialize, JsonSchema)]
2640 struct QueryInput2 {
2641 query: String,
2642 }
2643
2644 let state = Arc::new(AppState2 {
2645 prefix: "db".to_string(),
2646 });
2647
2648 let tool = ToolBuilder::new("search2")
2649 .description("Search with layer and guard")
2650 .extractor_handler(
2651 state,
2652 |State(app): State<Arc<AppState2>>, Json(input): Json<QueryInput2>| async move {
2653 Ok(CallToolResult::text(format!(
2654 "{}: {}",
2655 app.prefix, input.query
2656 )))
2657 },
2658 )
2659 .layer(TimeoutLayer::new(Duration::from_secs(5)))
2660 .guard(|_req: &ToolRequest| Ok(()))
2661 .build();
2662
2663 let result = tool.call(serde_json::json!({"query": "hello"})).await;
2664 assert!(!result.is_error);
2665 assert_eq!(result.first_text().unwrap(), "db: hello");
2666 }
2667
2668 #[tokio::test]
2669 async fn test_tool_with_guard_post_build() {
2670 let tool = ToolBuilder::new("admin_action")
2671 .description("Admin action")
2672 .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("done")) })
2673 .build();
2674
2675 let guarded = tool.with_guard(|req: &ToolRequest| {
2677 let name = req.args.get("name").and_then(|v| v.as_str()).unwrap_or("");
2678 if name == "admin" {
2679 Ok(())
2680 } else {
2681 Err("Only admin allowed".to_string())
2682 }
2683 });
2684
2685 let result = guarded.call(serde_json::json!({"name": "admin"})).await;
2687 assert!(!result.is_error);
2688
2689 let result = guarded.call(serde_json::json!({"name": "user"})).await;
2691 assert!(result.is_error);
2692 assert!(result.first_text().unwrap().contains("Only admin allowed"));
2693 }
2694
2695 #[tokio::test]
2696 async fn test_with_guard_preserves_tool_metadata() {
2697 let tool = ToolBuilder::new("my_tool")
2698 .description("A tool")
2699 .title("My Tool")
2700 .read_only()
2701 .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("done")) })
2702 .build();
2703
2704 let guarded = tool.with_guard(|_req: &ToolRequest| Ok(()));
2705
2706 assert_eq!(guarded.name, "my_tool");
2707 assert_eq!(guarded.description.as_deref(), Some("A tool"));
2708 assert_eq!(guarded.title.as_deref(), Some("My Tool"));
2709 assert!(guarded.annotations.is_some());
2710 }
2711
2712 #[tokio::test]
2713 async fn test_guard_group_pattern() {
2714 let require_auth = |req: &ToolRequest| {
2716 let token = req
2717 .args
2718 .get("_token")
2719 .and_then(|v| v.as_str())
2720 .unwrap_or("");
2721 if token == "valid" {
2722 Ok(())
2723 } else {
2724 Err("Authentication required".to_string())
2725 }
2726 };
2727
2728 let tool1 = ToolBuilder::new("action1")
2729 .description("Action 1")
2730 .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("action1")) })
2731 .build();
2732 let tool2 = ToolBuilder::new("action2")
2733 .description("Action 2")
2734 .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("action2")) })
2735 .build();
2736
2737 let guarded1 = tool1.with_guard(require_auth);
2739 let guarded2 = tool2.with_guard(require_auth);
2740
2741 let r1 = guarded1
2743 .call(serde_json::json!({"name": "test", "_token": "invalid"}))
2744 .await;
2745 let r2 = guarded2
2746 .call(serde_json::json!({"name": "test", "_token": "invalid"}))
2747 .await;
2748 assert!(r1.is_error);
2749 assert!(r2.is_error);
2750
2751 let r1 = guarded1
2753 .call(serde_json::json!({"name": "test", "_token": "valid"}))
2754 .await;
2755 let r2 = guarded2
2756 .call(serde_json::json!({"name": "test", "_token": "valid"}))
2757 .await;
2758 assert!(!r1.is_error);
2759 assert!(!r2.is_error);
2760 }
2761}