1use std::borrow::Cow;
35use std::convert::Infallible;
36use std::fmt;
37use std::future::Future;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::task::{Context, Poll};
41
42use schemars::{JsonSchema, Schema, SchemaGenerator};
43use serde::Serialize;
44use serde::de::DeserializeOwned;
45use serde_json::Value;
46use tower::util::BoxCloneService;
47use tower_service::Service;
48
49use crate::context::RequestContext;
50use crate::error::{Error, Result};
51use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
52
53#[derive(Debug, Clone)]
62pub struct ToolRequest {
63 pub ctx: RequestContext,
65 pub args: Value,
67}
68
69impl ToolRequest {
70 pub fn new(ctx: RequestContext, args: Value) -> Self {
72 Self { ctx, args }
73 }
74}
75
/// Type-erased, cloneable tower service that executes a tool.
///
/// Its error type is [`Infallible`]: failures are expected to be reported inside
/// the returned [`CallToolResult`] (the builders in this module achieve that by
/// wrapping the handler service in `ToolCatchError`) rather than as service errors.
pub type BoxToolService = BoxCloneService<ToolRequest, CallToolResult, Infallible>;
82
83pub struct ToolCatchError<S> {
89 inner: S,
90}
91
92impl<S> ToolCatchError<S> {
93 pub fn new(inner: S) -> Self {
95 Self { inner }
96 }
97}
98
99impl<S: Clone> Clone for ToolCatchError<S> {
100 fn clone(&self) -> Self {
101 Self {
102 inner: self.inner.clone(),
103 }
104 }
105}
106
107impl<S: fmt::Debug> fmt::Debug for ToolCatchError<S> {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 f.debug_struct("ToolCatchError")
110 .field("inner", &self.inner)
111 .finish()
112 }
113}
114
115impl<S> Service<ToolRequest> for ToolCatchError<S>
116where
117 S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
118 S::Error: fmt::Display + Send,
119 S::Future: Send,
120{
121 type Response = CallToolResult;
122 type Error = Infallible;
123 type Future =
124 Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Infallible>> + Send>>;
125
126 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
127 match self.inner.poll_ready(cx) {
129 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
130 Poll::Ready(Err(_)) => Poll::Ready(Ok(())),
131 Poll::Pending => Poll::Pending,
132 }
133 }
134
135 fn call(&mut self, req: ToolRequest) -> Self::Future {
136 let fut = self.inner.call(req);
137
138 Box::pin(async move {
139 match fut.await {
140 Ok(result) => Ok(result),
141 Err(err) => Ok(CallToolResult::error(err.to_string())),
142 }
143 })
144 }
145}
146
147#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
168pub struct NoParams;
169
170impl<'de> serde::Deserialize<'de> for NoParams {
171 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
172 where
173 D: serde::Deserializer<'de>,
174 {
175 struct NoParamsVisitor;
177
178 impl<'de> serde::de::Visitor<'de> for NoParamsVisitor {
179 type Value = NoParams;
180
181 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
182 formatter.write_str("null or an object")
183 }
184
185 fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
186 where
187 E: serde::de::Error,
188 {
189 Ok(NoParams)
190 }
191
192 fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
193 where
194 E: serde::de::Error,
195 {
196 Ok(NoParams)
197 }
198
199 fn visit_some<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
200 where
201 D: serde::Deserializer<'de>,
202 {
203 serde::Deserialize::deserialize(deserializer)
204 }
205
206 fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
207 where
208 A: serde::de::MapAccess<'de>,
209 {
210 while map
212 .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
213 .is_some()
214 {}
215 Ok(NoParams)
216 }
217 }
218
219 deserializer.deserialize_any(NoParamsVisitor)
220 }
221}
222
223impl JsonSchema for NoParams {
224 fn schema_name() -> Cow<'static, str> {
225 Cow::Borrowed("NoParams")
226 }
227
228 fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
229 serde_json::json!({
230 "type": "object"
231 })
232 .try_into()
233 .expect("valid schema")
234 }
235}
236
237pub fn validate_tool_name(name: &str) -> Result<()> {
245 if name.is_empty() {
246 return Err(Error::tool("Tool name cannot be empty"));
247 }
248 if name.len() > 128 {
249 return Err(Error::tool(format!(
250 "Tool name '{}' exceeds maximum length of 128 characters (got {})",
251 name,
252 name.len()
253 )));
254 }
255 if let Some(invalid_char) = name
256 .chars()
257 .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
258 {
259 return Err(Error::tool(format!(
260 "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
261 name, invalid_char
262 )));
263 }
264 Ok(())
265}
266
/// A pinned, boxed, `Send` future that may borrow data for `'a`; the return
/// type of the [`ToolHandler`] methods.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Interface for running a tool from raw JSON arguments.
///
/// `Tool::from_handler` wraps an implementation in `ToolHandlerService`, which
/// always invokes `call_with_context`.
pub trait ToolHandler: Send + Sync {
    /// Runs the tool with the raw JSON `args` (not yet deserialized).
    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;

    /// Runs the tool with access to the request context.
    ///
    /// The default implementation ignores `_ctx` and forwards to
    /// [`ToolHandler::call`]; handlers that need the context override it.
    fn call_with_context(
        &self,
        _ctx: RequestContext,
        args: Value,
    ) -> BoxFuture<'_, Result<CallToolResult>> {
        self.call(args)
    }

    /// Flag defaulting to `false`.
    // NOTE(review): by its name this reports whether the handler reads the
    // request context; no caller is visible in this chunk — confirm usage.
    fn uses_context(&self) -> bool {
        false
    }

    /// JSON Schema of the accepted arguments; `Tool::from_handler` stores it as
    /// the tool's advertised input schema.
    fn input_schema(&self) -> Value;
}
295
296pub(crate) struct ToolHandlerService<H> {
301 handler: Arc<H>,
302}
303
304impl<H> ToolHandlerService<H> {
305 pub(crate) fn new(handler: H) -> Self {
306 Self {
307 handler: Arc::new(handler),
308 }
309 }
310}
311
312impl<H> Clone for ToolHandlerService<H> {
313 fn clone(&self) -> Self {
314 Self {
315 handler: self.handler.clone(),
316 }
317 }
318}
319
320impl<H> Service<ToolRequest> for ToolHandlerService<H>
321where
322 H: ToolHandler + 'static,
323{
324 type Response = CallToolResult;
325 type Error = Error;
326 type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
327
328 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
329 Poll::Ready(Ok(()))
330 }
331
332 fn call(&mut self, req: ToolRequest) -> Self::Future {
333 let handler = self.handler.clone();
334 Box::pin(async move { handler.call_with_context(req.ctx, req.args).await })
335 }
336}
337
/// A tool ready to be listed and invoked: protocol metadata plus the
/// type-erased service that executes it.
pub struct Tool {
    /// Tool name. The builders in this module check it with `validate_tool_name`;
    /// the field itself is unchecked.
    pub name: String,
    /// Optional human-readable title.
    pub title: Option<String>,
    /// Optional description.
    pub description: Option<String>,
    /// Optional JSON Schema describing the tool's output.
    pub output_schema: Option<Value>,
    /// Optional icons.
    pub icons: Option<Vec<ToolIcon>>,
    /// Optional behavior hints (read-only, destructive, idempotent, ...).
    pub annotations: Option<ToolAnnotations>,
    /// Service that runs the tool; its error type is `Infallible`.
    pub(crate) service: BoxToolService,
    /// JSON Schema for the arguments, exposed through `definition()`.
    pub(crate) input_schema: Value,
}
362
363impl std::fmt::Debug for Tool {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 f.debug_struct("Tool")
366 .field("name", &self.name)
367 .field("title", &self.title)
368 .field("description", &self.description)
369 .field("output_schema", &self.output_schema)
370 .field("icons", &self.icons)
371 .field("annotations", &self.annotations)
372 .finish_non_exhaustive()
373 }
374}
375
376unsafe impl Send for Tool {}
379unsafe impl Sync for Tool {}
380
381impl Clone for Tool {
382 fn clone(&self) -> Self {
383 Self {
384 name: self.name.clone(),
385 title: self.title.clone(),
386 description: self.description.clone(),
387 output_schema: self.output_schema.clone(),
388 icons: self.icons.clone(),
389 annotations: self.annotations.clone(),
390 service: self.service.clone(),
391 input_schema: self.input_schema.clone(),
392 }
393 }
394}
395
396impl Tool {
397 pub fn builder(name: impl Into<String>) -> ToolBuilder {
399 ToolBuilder::new(name)
400 }
401
402 pub fn definition(&self) -> ToolDefinition {
404 ToolDefinition {
405 name: self.name.clone(),
406 title: self.title.clone(),
407 description: self.description.clone(),
408 input_schema: self.input_schema.clone(),
409 output_schema: self.output_schema.clone(),
410 icons: self.icons.clone(),
411 annotations: self.annotations.clone(),
412 }
413 }
414
415 pub fn call(&self, args: Value) -> BoxFuture<'static, CallToolResult> {
420 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
421 self.call_with_context(ctx, args)
422 }
423
424 pub fn call_with_context(
435 &self,
436 ctx: RequestContext,
437 args: Value,
438 ) -> BoxFuture<'static, CallToolResult> {
439 use tower::ServiceExt;
440 let service = self.service.clone();
441 Box::pin(async move {
442 service.oneshot(ToolRequest::new(ctx, args)).await.unwrap()
445 })
446 }
447
448 pub fn with_name_prefix(&self, prefix: &str) -> Self {
476 Self {
477 name: format!("{}.{}", prefix, self.name),
478 title: self.title.clone(),
479 description: self.description.clone(),
480 output_schema: self.output_schema.clone(),
481 icons: self.icons.clone(),
482 annotations: self.annotations.clone(),
483 service: self.service.clone(),
484 input_schema: self.input_schema.clone(),
485 }
486 }
487
488 fn from_handler<H: ToolHandler + 'static>(
490 name: String,
491 title: Option<String>,
492 description: Option<String>,
493 output_schema: Option<Value>,
494 icons: Option<Vec<ToolIcon>>,
495 annotations: Option<ToolAnnotations>,
496 handler: H,
497 ) -> Self {
498 let input_schema = handler.input_schema();
499 let handler_service = ToolHandlerService::new(handler);
500 let catch_error = ToolCatchError::new(handler_service);
501 let service = BoxCloneService::new(catch_error);
502
503 Self {
504 name,
505 title,
506 description,
507 output_schema,
508 icons,
509 annotations,
510 service,
511 input_schema,
512 }
513 }
514}
515
516pub struct ToolBuilder {
545 name: String,
546 title: Option<String>,
547 description: Option<String>,
548 output_schema: Option<Value>,
549 icons: Option<Vec<ToolIcon>>,
550 annotations: Option<ToolAnnotations>,
551}
552
553impl ToolBuilder {
554 pub fn new(name: impl Into<String>) -> Self {
555 Self {
556 name: name.into(),
557 title: None,
558 description: None,
559 output_schema: None,
560 icons: None,
561 annotations: None,
562 }
563 }
564
565 pub fn title(mut self, title: impl Into<String>) -> Self {
567 self.title = Some(title.into());
568 self
569 }
570
571 pub fn output_schema(mut self, schema: Value) -> Self {
573 self.output_schema = Some(schema);
574 self
575 }
576
577 pub fn icon(mut self, src: impl Into<String>) -> Self {
579 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
580 src: src.into(),
581 mime_type: None,
582 sizes: None,
583 });
584 self
585 }
586
587 pub fn icon_with_meta(
589 mut self,
590 src: impl Into<String>,
591 mime_type: Option<String>,
592 sizes: Option<Vec<String>>,
593 ) -> Self {
594 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
595 src: src.into(),
596 mime_type,
597 sizes,
598 });
599 self
600 }
601
602 pub fn description(mut self, description: impl Into<String>) -> Self {
604 self.description = Some(description.into());
605 self
606 }
607
608 pub fn read_only(mut self) -> Self {
610 self.annotations
611 .get_or_insert_with(ToolAnnotations::default)
612 .read_only_hint = true;
613 self
614 }
615
616 pub fn non_destructive(mut self) -> Self {
618 self.annotations
619 .get_or_insert_with(ToolAnnotations::default)
620 .destructive_hint = false;
621 self
622 }
623
624 pub fn idempotent(mut self) -> Self {
626 self.annotations
627 .get_or_insert_with(ToolAnnotations::default)
628 .idempotent_hint = true;
629 self
630 }
631
632 pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
634 self.annotations = Some(annotations);
635 self
636 }
637
638 pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
682 where
683 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
684 F: Fn(I) -> Fut + Send + Sync + 'static,
685 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
686 {
687 ToolBuilderWithHandler {
688 name: self.name,
689 title: self.title,
690 description: self.description,
691 output_schema: self.output_schema,
692 icons: self.icons,
693 annotations: self.annotations,
694 handler,
695 _phantom: std::marker::PhantomData,
696 }
697 }
698
699 pub fn extractor_handler<S, F, T>(
753 self,
754 state: S,
755 handler: F,
756 ) -> crate::extract::ToolBuilderWithExtractor<S, F, T>
757 where
758 S: Clone + Send + Sync + 'static,
759 F: crate::extract::ExtractorHandler<S, T> + Clone,
760 T: Send + Sync + 'static,
761 {
762 crate::extract::ToolBuilderWithExtractor {
763 name: self.name,
764 title: self.title,
765 description: self.description,
766 output_schema: self.output_schema,
767 icons: self.icons,
768 annotations: self.annotations,
769 state,
770 handler,
771 input_schema: F::input_schema(),
772 _phantom: std::marker::PhantomData,
773 }
774 }
775
776 pub fn extractor_handler_typed<S, F, T, I>(
810 self,
811 state: S,
812 handler: F,
813 ) -> crate::extract::ToolBuilderWithTypedExtractor<S, F, T, I>
814 where
815 S: Clone + Send + Sync + 'static,
816 F: crate::extract::TypedExtractorHandler<S, T, I> + Clone,
817 T: Send + Sync + 'static,
818 I: schemars::JsonSchema + Send + Sync + 'static,
819 {
820 crate::extract::ToolBuilderWithTypedExtractor {
821 name: self.name,
822 title: self.title,
823 description: self.description,
824 output_schema: self.output_schema,
825 icons: self.icons,
826 annotations: self.annotations,
827 state,
828 handler,
829 _phantom: std::marker::PhantomData,
830 }
831 }
832}
833
834pub struct ToolBuilderWithHandler<I, F> {
836 name: String,
837 title: Option<String>,
838 description: Option<String>,
839 output_schema: Option<Value>,
840 icons: Option<Vec<ToolIcon>>,
841 annotations: Option<ToolAnnotations>,
842 handler: F,
843 _phantom: std::marker::PhantomData<I>,
844}
845
846impl<I, F, Fut> ToolBuilderWithHandler<I, F>
847where
848 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
849 F: Fn(I) -> Fut + Send + Sync + 'static,
850 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
851{
852 pub fn build(self) -> Result<Tool> {
856 validate_tool_name(&self.name)?;
857 Ok(Tool::from_handler(
858 self.name,
859 self.title,
860 self.description,
861 self.output_schema,
862 self.icons,
863 self.annotations,
864 TypedHandler {
865 handler: self.handler,
866 _phantom: std::marker::PhantomData,
867 },
868 ))
869 }
870
871 pub fn layer<L>(self, layer: L) -> ToolBuilderWithLayer<I, F, L> {
898 ToolBuilderWithLayer {
899 name: self.name,
900 title: self.title,
901 description: self.description,
902 output_schema: self.output_schema,
903 icons: self.icons,
904 annotations: self.annotations,
905 handler: self.handler,
906 layer,
907 _phantom: std::marker::PhantomData,
908 }
909 }
910}
911
912pub struct ToolBuilderWithLayer<I, F, L> {
916 name: String,
917 title: Option<String>,
918 description: Option<String>,
919 output_schema: Option<Value>,
920 icons: Option<Vec<ToolIcon>>,
921 annotations: Option<ToolAnnotations>,
922 handler: F,
923 layer: L,
924 _phantom: std::marker::PhantomData<I>,
925}
926
927#[allow(private_bounds)]
930impl<I, F, Fut, L> ToolBuilderWithLayer<I, F, L>
931where
932 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
933 F: Fn(I) -> Fut + Send + Sync + 'static,
934 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
935 L: tower::Layer<ToolHandlerService<TypedHandler<I, F>>> + Clone + Send + Sync + 'static,
936 L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
937 <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
938 <L::Service as Service<ToolRequest>>::Future: Send,
939{
940 pub fn build(self) -> Result<Tool> {
944 validate_tool_name(&self.name)?;
945
946 let input_schema = schemars::schema_for!(I);
947 let input_schema = serde_json::to_value(input_schema)
948 .unwrap_or_else(|_| serde_json::json!({ "type": "object" }));
949
950 let handler_service = ToolHandlerService::new(TypedHandler {
951 handler: self.handler,
952 _phantom: std::marker::PhantomData,
953 });
954 let layered = self.layer.layer(handler_service);
955 let catch_error = ToolCatchError::new(layered);
956 let service = BoxCloneService::new(catch_error);
957
958 Ok(Tool {
959 name: self.name,
960 title: self.title,
961 description: self.description,
962 output_schema: self.output_schema,
963 icons: self.icons,
964 annotations: self.annotations,
965 service,
966 input_schema,
967 })
968 }
969
970 pub fn layer<L2>(
975 self,
976 layer: L2,
977 ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<L2, L>> {
978 ToolBuilderWithLayer {
979 name: self.name,
980 title: self.title,
981 description: self.description,
982 output_schema: self.output_schema,
983 icons: self.icons,
984 annotations: self.annotations,
985 handler: self.handler,
986 layer: tower::layer::util::Stack::new(layer, self.layer),
987 _phantom: std::marker::PhantomData,
988 }
989 }
990}
991
992struct TypedHandler<I, F> {
998 handler: F,
999 _phantom: std::marker::PhantomData<I>,
1000}
1001
1002impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
1003where
1004 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1005 F: Fn(I) -> Fut + Send + Sync + 'static,
1006 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1007{
1008 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1009 Box::pin(async move {
1010 let input: I = serde_json::from_value(args)
1011 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
1012 (self.handler)(input).await
1013 })
1014 }
1015
1016 fn input_schema(&self) -> Value {
1017 let schema = schemars::schema_for!(I);
1018 serde_json::to_value(schema).unwrap_or_else(|_| {
1019 serde_json::json!({
1020 "type": "object"
1021 })
1022 })
1023 }
1024}
1025
1026pub trait McpTool: Send + Sync + 'static {
1067 const NAME: &'static str;
1068 const DESCRIPTION: &'static str;
1069
1070 type Input: JsonSchema + DeserializeOwned + Send;
1071 type Output: Serialize + Send;
1072
1073 fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
1074
1075 fn annotations(&self) -> Option<ToolAnnotations> {
1077 None
1078 }
1079
1080 fn into_tool(self) -> Result<Tool>
1084 where
1085 Self: Sized,
1086 {
1087 validate_tool_name(Self::NAME)?;
1088 let annotations = self.annotations();
1089 let tool = Arc::new(self);
1090 Ok(Tool::from_handler(
1091 Self::NAME.to_string(),
1092 None,
1093 Some(Self::DESCRIPTION.to_string()),
1094 None,
1095 None,
1096 annotations,
1097 McpToolHandler { tool },
1098 ))
1099 }
1100}
1101
1102struct McpToolHandler<T: McpTool> {
1104 tool: Arc<T>,
1105}
1106
1107impl<T: McpTool> ToolHandler for McpToolHandler<T> {
1108 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1109 let tool = self.tool.clone();
1110 Box::pin(async move {
1111 let input: T::Input = serde_json::from_value(args)
1112 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
1113 let output = tool.call(input).await?;
1114 let value = serde_json::to_value(output)
1115 .map_err(|e| Error::tool(format!("Failed to serialize output: {}", e)))?;
1116 Ok(CallToolResult::json(value))
1117 })
1118 }
1119
1120 fn input_schema(&self) -> Value {
1121 let schema = schemars::schema_for!(T::Input);
1122 serde_json::to_value(schema).unwrap_or_else(|_| {
1123 serde_json::json!({
1124 "type": "object"
1125 })
1126 })
1127 }
1128}
1129
1130#[cfg(test)]
1131mod tests {
1132 use super::*;
1133 use crate::extract::{Context, Json, RawArgs, State};
1134 use crate::protocol::Content;
1135 use schemars::JsonSchema;
1136 use serde::Deserialize;
1137
1138 #[derive(Debug, Deserialize, JsonSchema)]
1139 struct GreetInput {
1140 name: String,
1141 }
1142
1143 #[tokio::test]
1144 async fn test_builder_tool() {
1145 let tool = ToolBuilder::new("greet")
1146 .description("Greet someone")
1147 .handler(|input: GreetInput| async move {
1148 Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1149 })
1150 .build()
1151 .expect("valid tool name");
1152
1153 assert_eq!(tool.name, "greet");
1154 assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1155
1156 let result = tool.call(serde_json::json!({"name": "World"})).await;
1157
1158 assert!(!result.is_error);
1159 }
1160
1161 #[tokio::test]
1162 async fn test_raw_handler() {
1163 let tool = ToolBuilder::new("echo")
1164 .description("Echo input")
1165 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1166 Ok(CallToolResult::json(args))
1167 })
1168 .build()
1169 .expect("valid tool name");
1170
1171 let result = tool.call(serde_json::json!({"foo": "bar"})).await;
1172
1173 assert!(!result.is_error);
1174 }
1175
1176 #[test]
1177 fn test_invalid_tool_name_empty() {
1178 let result = ToolBuilder::new("")
1179 .description("Empty name")
1180 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1181 Ok(CallToolResult::json(args))
1182 })
1183 .build();
1184
1185 assert!(result.is_err());
1186 assert!(result.unwrap_err().to_string().contains("cannot be empty"));
1187 }
1188
1189 #[test]
1190 fn test_invalid_tool_name_too_long() {
1191 let long_name = "a".repeat(129);
1192 let result = ToolBuilder::new(long_name)
1193 .description("Too long")
1194 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1195 Ok(CallToolResult::json(args))
1196 })
1197 .build();
1198
1199 assert!(result.is_err());
1200 assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
1201 }
1202
1203 #[test]
1204 fn test_invalid_tool_name_bad_chars() {
1205 let result = ToolBuilder::new("my tool!")
1206 .description("Bad chars")
1207 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1208 Ok(CallToolResult::json(args))
1209 })
1210 .build();
1211
1212 assert!(result.is_err());
1213 assert!(
1214 result
1215 .unwrap_err()
1216 .to_string()
1217 .contains("invalid character")
1218 );
1219 }
1220
1221 #[test]
1222 fn test_valid_tool_names() {
1223 let names = [
1225 "my_tool",
1226 "my-tool",
1227 "my.tool",
1228 "MyTool123",
1229 "a",
1230 &"a".repeat(128),
1231 ];
1232 for name in names {
1233 let result = ToolBuilder::new(name)
1234 .description("Valid")
1235 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1236 Ok(CallToolResult::json(args))
1237 })
1238 .build();
1239 assert!(result.is_ok(), "Expected '{}' to be valid", name);
1240 }
1241 }
1242
1243 #[tokio::test]
1244 async fn test_context_aware_handler() {
1245 use crate::context::notification_channel;
1246 use crate::protocol::{ProgressToken, RequestId};
1247
1248 #[derive(Debug, Deserialize, JsonSchema)]
1249 struct ProcessInput {
1250 count: i32,
1251 }
1252
1253 let tool = ToolBuilder::new("process")
1254 .description("Process with context")
1255 .extractor_handler_typed::<_, _, _, ProcessInput>(
1256 (),
1257 |ctx: Context, Json(input): Json<ProcessInput>| async move {
1258 for i in 0..input.count {
1260 if ctx.is_cancelled() {
1261 return Ok(CallToolResult::error("Cancelled"));
1262 }
1263 ctx.report_progress(i as f64, Some(input.count as f64), None)
1264 .await;
1265 }
1266 Ok(CallToolResult::text(format!(
1267 "Processed {} items",
1268 input.count
1269 )))
1270 },
1271 )
1272 .build()
1273 .expect("valid tool name");
1274
1275 assert_eq!(tool.name, "process");
1276
1277 let (tx, mut rx) = notification_channel(10);
1279 let ctx = RequestContext::new(RequestId::Number(1))
1280 .with_progress_token(ProgressToken::Number(42))
1281 .with_notification_sender(tx);
1282
1283 let result = tool
1284 .call_with_context(ctx, serde_json::json!({"count": 3}))
1285 .await;
1286
1287 assert!(!result.is_error);
1288
1289 let mut progress_count = 0;
1291 while rx.try_recv().is_ok() {
1292 progress_count += 1;
1293 }
1294 assert_eq!(progress_count, 3);
1295 }
1296
1297 #[tokio::test]
1298 async fn test_context_aware_handler_cancellation() {
1299 use crate::protocol::RequestId;
1300 use std::sync::atomic::{AtomicI32, Ordering};
1301
1302 #[derive(Debug, Deserialize, JsonSchema)]
1303 struct LongRunningInput {
1304 iterations: i32,
1305 }
1306
1307 let iterations_completed = Arc::new(AtomicI32::new(0));
1308 let iterations_ref = iterations_completed.clone();
1309
1310 let tool = ToolBuilder::new("long_running")
1311 .description("Long running task")
1312 .extractor_handler_typed::<_, _, _, LongRunningInput>(
1313 (),
1314 move |ctx: Context, Json(input): Json<LongRunningInput>| {
1315 let completed = iterations_ref.clone();
1316 async move {
1317 for i in 0..input.iterations {
1318 if ctx.is_cancelled() {
1319 return Ok(CallToolResult::error("Cancelled"));
1320 }
1321 completed.fetch_add(1, Ordering::SeqCst);
1322 tokio::task::yield_now().await;
1324 if i == 2 {
1326 ctx.cancellation_token().cancel();
1327 }
1328 }
1329 Ok(CallToolResult::text("Done"))
1330 }
1331 },
1332 )
1333 .build()
1334 .expect("valid tool name");
1335
1336 let ctx = RequestContext::new(RequestId::Number(1));
1337
1338 let result = tool
1339 .call_with_context(ctx, serde_json::json!({"iterations": 10}))
1340 .await;
1341
1342 assert!(result.is_error);
1345 assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
1346 }
1347
1348 #[tokio::test]
1349 async fn test_tool_builder_with_enhanced_fields() {
1350 let output_schema = serde_json::json!({
1351 "type": "object",
1352 "properties": {
1353 "greeting": {"type": "string"}
1354 }
1355 });
1356
1357 let tool = ToolBuilder::new("greet")
1358 .title("Greeting Tool")
1359 .description("Greet someone")
1360 .output_schema(output_schema.clone())
1361 .icon("https://example.com/icon.png")
1362 .icon_with_meta(
1363 "https://example.com/icon-large.png",
1364 Some("image/png".to_string()),
1365 Some(vec!["96x96".to_string()]),
1366 )
1367 .handler(|input: GreetInput| async move {
1368 Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1369 })
1370 .build()
1371 .expect("valid tool name");
1372
1373 assert_eq!(tool.name, "greet");
1374 assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
1375 assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1376 assert_eq!(tool.output_schema, Some(output_schema));
1377 assert!(tool.icons.is_some());
1378 assert_eq!(tool.icons.as_ref().unwrap().len(), 2);
1379
1380 let def = tool.definition();
1382 assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
1383 assert!(def.output_schema.is_some());
1384 assert!(def.icons.is_some());
1385 }
1386
1387 #[tokio::test]
1388 async fn test_handler_with_state() {
1389 let shared = Arc::new("shared-state".to_string());
1390
1391 let tool = ToolBuilder::new("stateful")
1392 .description("Uses shared state")
1393 .extractor_handler_typed::<_, _, _, GreetInput>(
1394 shared,
1395 |State(state): State<Arc<String>>, Json(input): Json<GreetInput>| async move {
1396 Ok(CallToolResult::text(format!(
1397 "{}: Hello, {}!",
1398 state, input.name
1399 )))
1400 },
1401 )
1402 .build()
1403 .expect("valid tool name");
1404
1405 let result = tool.call(serde_json::json!({"name": "World"})).await;
1406 assert!(!result.is_error);
1407 }
1408
1409 #[tokio::test]
1410 async fn test_handler_with_state_and_context() {
1411 use crate::protocol::RequestId;
1412
1413 let shared = Arc::new(42_i32);
1414
1415 let tool =
1416 ToolBuilder::new("stateful_ctx")
1417 .description("Uses state and context")
1418 .extractor_handler_typed::<_, _, _, GreetInput>(
1419 shared,
1420 |State(state): State<Arc<i32>>,
1421 _ctx: Context,
1422 Json(input): Json<GreetInput>| async move {
1423 Ok(CallToolResult::text(format!(
1424 "{}: Hello, {}!",
1425 state, input.name
1426 )))
1427 },
1428 )
1429 .build()
1430 .expect("valid tool name");
1431
1432 let ctx = RequestContext::new(RequestId::Number(1));
1433 let result = tool
1434 .call_with_context(ctx, serde_json::json!({"name": "World"}))
1435 .await;
1436 assert!(!result.is_error);
1437 }
1438
1439 #[tokio::test]
1440 async fn test_handler_no_params() {
1441 let tool = ToolBuilder::new("no_params")
1442 .description("Takes no parameters")
1443 .extractor_handler_typed::<_, _, _, NoParams>((), |Json(_): Json<NoParams>| async {
1444 Ok(CallToolResult::text("no params result"))
1445 })
1446 .build()
1447 .expect("valid tool name");
1448
1449 assert_eq!(tool.name, "no_params");
1450
1451 let result = tool.call(serde_json::json!({})).await;
1453 assert!(!result.is_error);
1454
1455 let result = tool.call(serde_json::json!({"unexpected": "value"})).await;
1457 assert!(!result.is_error);
1458
1459 let schema = tool.definition().input_schema;
1461 assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1462 }
1463
1464 #[tokio::test]
1465 async fn test_handler_with_state_no_params() {
1466 let shared = Arc::new("shared_value".to_string());
1467
1468 let tool = ToolBuilder::new("with_state_no_params")
1469 .description("Takes no parameters but has state")
1470 .extractor_handler_typed::<_, _, _, NoParams>(
1471 shared,
1472 |State(state): State<Arc<String>>, Json(_): Json<NoParams>| async move {
1473 Ok(CallToolResult::text(format!("state: {}", state)))
1474 },
1475 )
1476 .build()
1477 .expect("valid tool name");
1478
1479 assert_eq!(tool.name, "with_state_no_params");
1480
1481 let result = tool.call(serde_json::json!({})).await;
1483 assert!(!result.is_error);
1484 assert_eq!(result.first_text().unwrap(), "state: shared_value");
1485
1486 let schema = tool.definition().input_schema;
1488 assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1489 }
1490
1491 #[tokio::test]
1492 async fn test_handler_no_params_with_context() {
1493 let tool = ToolBuilder::new("no_params_with_context")
1494 .description("Takes no parameters but has context")
1495 .extractor_handler_typed::<_, _, _, NoParams>(
1496 (),
1497 |_ctx: Context, Json(_): Json<NoParams>| async move {
1498 Ok(CallToolResult::text("context available"))
1499 },
1500 )
1501 .build()
1502 .expect("valid tool name");
1503
1504 assert_eq!(tool.name, "no_params_with_context");
1505
1506 let result = tool.call(serde_json::json!({})).await;
1507 assert!(!result.is_error);
1508 assert_eq!(result.first_text().unwrap(), "context available");
1509 }
1510
1511 #[tokio::test]
1512 async fn test_handler_with_state_and_context_no_params() {
1513 let shared = Arc::new("shared".to_string());
1514
1515 let tool = ToolBuilder::new("state_context_no_params")
1516 .description("Has state and context, no params")
1517 .extractor_handler_typed::<_, _, _, NoParams>(
1518 shared,
1519 |State(state): State<Arc<String>>,
1520 _ctx: Context,
1521 Json(_): Json<NoParams>| async move {
1522 Ok(CallToolResult::text(format!("state: {}", state)))
1523 },
1524 )
1525 .build()
1526 .expect("valid tool name");
1527
1528 assert_eq!(tool.name, "state_context_no_params");
1529
1530 let result = tool.call(serde_json::json!({})).await;
1531 assert!(!result.is_error);
1532 assert_eq!(result.first_text().unwrap(), "state: shared");
1533 }
1534
1535 #[tokio::test]
1536 async fn test_raw_handler_with_state() {
1537 let prefix = Arc::new("prefix:".to_string());
1538
1539 let tool = ToolBuilder::new("raw_with_state")
1540 .description("Raw handler with state")
1541 .extractor_handler(
1542 prefix,
1543 |State(state): State<Arc<String>>, RawArgs(args): RawArgs| async move {
1544 Ok(CallToolResult::text(format!("{} {}", state, args)))
1545 },
1546 )
1547 .build()
1548 .expect("valid tool name");
1549
1550 assert_eq!(tool.name, "raw_with_state");
1551
1552 let result = tool.call(serde_json::json!({"key": "value"})).await;
1553 assert!(!result.is_error);
1554 assert!(result.first_text().unwrap().starts_with("prefix:"));
1555 }
1556
1557 #[tokio::test]
1558 async fn test_raw_handler_with_state_and_context() {
1559 let prefix = Arc::new("prefix:".to_string());
1560
1561 let tool = ToolBuilder::new("raw_state_context")
1562 .description("Raw handler with state and context")
1563 .extractor_handler(
1564 prefix,
1565 |State(state): State<Arc<String>>,
1566 _ctx: Context,
1567 RawArgs(args): RawArgs| async move {
1568 Ok(CallToolResult::text(format!("{} {}", state, args)))
1569 },
1570 )
1571 .build()
1572 .expect("valid tool name");
1573
1574 assert_eq!(tool.name, "raw_state_context");
1575
1576 let result = tool.call(serde_json::json!({"key": "value"})).await;
1577 assert!(!result.is_error);
1578 assert!(result.first_text().unwrap().starts_with("prefix:"));
1579 }
1580
1581 #[tokio::test]
1582 async fn test_tool_with_timeout_layer() {
1583 use std::time::Duration;
1584 use tower::timeout::TimeoutLayer;
1585
1586 #[derive(Debug, Deserialize, JsonSchema)]
1587 struct SlowInput {
1588 delay_ms: u64,
1589 }
1590
1591 let tool = ToolBuilder::new("slow_tool")
1593 .description("A slow tool")
1594 .handler(|input: SlowInput| async move {
1595 tokio::time::sleep(Duration::from_millis(input.delay_ms)).await;
1596 Ok(CallToolResult::text("completed"))
1597 })
1598 .layer(TimeoutLayer::new(Duration::from_millis(50)))
1599 .build()
1600 .expect("valid tool name");
1601
1602 let result = tool.call(serde_json::json!({"delay_ms": 10})).await;
1604 assert!(!result.is_error);
1605 assert_eq!(result.first_text().unwrap(), "completed");
1606
1607 let result = tool.call(serde_json::json!({"delay_ms": 200})).await;
1609 assert!(result.is_error);
1610 let msg = result.first_text().unwrap().to_lowercase();
1612 assert!(
1613 msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
1614 "Expected timeout error, got: {}",
1615 msg
1616 );
1617 }
1618
1619 #[tokio::test]
1620 async fn test_tool_with_concurrency_limit_layer() {
1621 use std::sync::atomic::{AtomicU32, Ordering};
1622 use std::time::Duration;
1623 use tower::limit::ConcurrencyLimitLayer;
1624
1625 #[derive(Debug, Deserialize, JsonSchema)]
1626 struct WorkInput {
1627 id: u32,
1628 }
1629
1630 let max_concurrent = Arc::new(AtomicU32::new(0));
1631 let current_concurrent = Arc::new(AtomicU32::new(0));
1632 let max_ref = max_concurrent.clone();
1633 let current_ref = current_concurrent.clone();
1634
1635 let tool = ToolBuilder::new("concurrent_tool")
1637 .description("A concurrent tool")
1638 .handler(move |input: WorkInput| {
1639 let max = max_ref.clone();
1640 let current = current_ref.clone();
1641 async move {
1642 let prev = current.fetch_add(1, Ordering::SeqCst);
1644 max.fetch_max(prev + 1, Ordering::SeqCst);
1645
1646 tokio::time::sleep(Duration::from_millis(50)).await;
1648
1649 current.fetch_sub(1, Ordering::SeqCst);
1650 Ok(CallToolResult::text(format!("completed {}", input.id)))
1651 }
1652 })
1653 .layer(ConcurrencyLimitLayer::new(2))
1654 .build()
1655 .expect("valid tool name");
1656
1657 let handles: Vec<_> = (0..4)
1659 .map(|i| {
1660 let t = tool.call(serde_json::json!({"id": i}));
1661 tokio::spawn(t)
1662 })
1663 .collect();
1664
1665 for handle in handles {
1666 let result = handle.await.unwrap();
1667 assert!(!result.is_error);
1668 }
1669
1670 assert!(max_concurrent.load(Ordering::SeqCst) <= 2);
1672 }
1673
1674 #[tokio::test]
1675 async fn test_tool_with_multiple_layers() {
1676 use std::time::Duration;
1677 use tower::limit::ConcurrencyLimitLayer;
1678 use tower::timeout::TimeoutLayer;
1679
1680 #[derive(Debug, Deserialize, JsonSchema)]
1681 struct Input {
1682 value: String,
1683 }
1684
1685 let tool = ToolBuilder::new("multi_layer_tool")
1687 .description("Tool with multiple layers")
1688 .handler(|input: Input| async move {
1689 Ok(CallToolResult::text(format!("processed: {}", input.value)))
1690 })
1691 .layer(TimeoutLayer::new(Duration::from_secs(5)))
1692 .layer(ConcurrencyLimitLayer::new(10))
1693 .build()
1694 .expect("valid tool name");
1695
1696 let result = tool.call(serde_json::json!({"value": "test"})).await;
1697 assert!(!result.is_error);
1698 assert_eq!(result.first_text().unwrap(), "processed: test");
1699 }
1700
1701 #[test]
1702 fn test_tool_catch_error_clone() {
1703 let tool = ToolBuilder::new("test")
1706 .description("test")
1707 .extractor_handler((), |RawArgs(_args): RawArgs| async {
1708 Ok(CallToolResult::text("ok"))
1709 })
1710 .build()
1711 .unwrap();
1712 let _clone = tool.call(serde_json::json!({}));
1714 }
1715
1716 #[test]
1717 fn test_tool_catch_error_debug() {
1718 #[derive(Debug, Clone)]
1722 struct DebugService;
1723
1724 impl Service<ToolRequest> for DebugService {
1725 type Response = CallToolResult;
1726 type Error = crate::error::Error;
1727 type Future = Pin<
1728 Box<
1729 dyn Future<Output = std::result::Result<CallToolResult, crate::error::Error>>
1730 + Send,
1731 >,
1732 >;
1733
1734 fn poll_ready(
1735 &mut self,
1736 _cx: &mut std::task::Context<'_>,
1737 ) -> Poll<std::result::Result<(), Self::Error>> {
1738 Poll::Ready(Ok(()))
1739 }
1740
1741 fn call(&mut self, _req: ToolRequest) -> Self::Future {
1742 Box::pin(async { Ok(CallToolResult::text("ok")) })
1743 }
1744 }
1745
1746 let catch_error = ToolCatchError::new(DebugService);
1747 let debug = format!("{:?}", catch_error);
1748 assert!(debug.contains("ToolCatchError"));
1749 }
1750
1751 #[test]
1752 fn test_tool_request_new() {
1753 use crate::protocol::RequestId;
1754
1755 let ctx = RequestContext::new(RequestId::Number(42));
1756 let args = serde_json::json!({"key": "value"});
1757 let req = ToolRequest::new(ctx.clone(), args.clone());
1758
1759 assert_eq!(req.args, args);
1760 }
1761
1762 #[test]
1763 fn test_no_params_schema() {
1764 let schema = schemars::schema_for!(NoParams);
1766 let schema_value = serde_json::to_value(&schema).unwrap();
1767 assert_eq!(
1768 schema_value.get("type").and_then(|v| v.as_str()),
1769 Some("object"),
1770 "NoParams should generate type: object schema"
1771 );
1772 }
1773
1774 #[test]
1775 fn test_no_params_deserialize() {
1776 let from_empty_object: NoParams = serde_json::from_str("{}").unwrap();
1778 assert_eq!(from_empty_object, NoParams);
1779
1780 let from_null: NoParams = serde_json::from_str("null").unwrap();
1781 assert_eq!(from_null, NoParams);
1782
1783 let from_object_with_fields: NoParams =
1785 serde_json::from_str(r#"{"unexpected": "value"}"#).unwrap();
1786 assert_eq!(from_object_with_fields, NoParams);
1787 }
1788
1789 #[tokio::test]
1790 async fn test_no_params_type_in_handler() {
1791 let tool = ToolBuilder::new("status")
1793 .description("Get status")
1794 .handler(|_input: NoParams| async move { Ok(CallToolResult::text("OK")) })
1795 .build()
1796 .expect("valid tool name");
1797
1798 let schema = tool.definition().input_schema;
1800 assert_eq!(
1801 schema.get("type").and_then(|v| v.as_str()),
1802 Some("object"),
1803 "NoParams handler should produce type: object schema"
1804 );
1805
1806 let result = tool.call(serde_json::json!({})).await;
1808 assert!(!result.is_error);
1809 }
1810
1811 #[tokio::test]
1812 async fn test_tool_with_name_prefix() {
1813 #[derive(Debug, Deserialize, JsonSchema)]
1814 struct Input {
1815 value: String,
1816 }
1817
1818 let tool = ToolBuilder::new("query")
1819 .description("Query something")
1820 .title("Query Tool")
1821 .handler(|input: Input| async move { Ok(CallToolResult::text(&input.value)) })
1822 .build()
1823 .expect("valid tool name");
1824
1825 let prefixed = tool.with_name_prefix("db");
1827
1828 assert_eq!(prefixed.name, "db.query");
1830
1831 assert_eq!(prefixed.description.as_deref(), Some("Query something"));
1833 assert_eq!(prefixed.title.as_deref(), Some("Query Tool"));
1834
1835 let result = prefixed
1837 .call(serde_json::json!({"value": "test input"}))
1838 .await;
1839 assert!(!result.is_error);
1840 match &result.content[0] {
1841 Content::Text { text, .. } => assert_eq!(text, "test input"),
1842 _ => panic!("Expected text content"),
1843 }
1844 }
1845
1846 #[tokio::test]
1847 async fn test_tool_with_name_prefix_multiple_levels() {
1848 let tool = ToolBuilder::new("action")
1849 .description("Do something")
1850 .handler(|_: NoParams| async move { Ok(CallToolResult::text("done")) })
1851 .build()
1852 .expect("valid tool name");
1853
1854 let prefixed = tool.with_name_prefix("level1");
1856 assert_eq!(prefixed.name, "level1.action");
1857
1858 let double_prefixed = prefixed.with_name_prefix("level0");
1859 assert_eq!(double_prefixed.name, "level0.level1.action");
1860 }
1861}