1use std::future::Future;
76use std::marker::PhantomData;
77use std::ops::Deref;
78use std::pin::Pin;
79
80use schemars::JsonSchema;
81use serde::de::DeserializeOwned;
82use serde_json::Value;
83
84use crate::context::RequestContext;
85use crate::error::{Error, Result};
86use crate::protocol::CallToolResult;
87
/// A generic extractor rejection that carries only a human-readable message.
///
/// In this module it is the `Rejection` type of the infallible extractors
/// (`State`, `Context` and `RawArgs`).
#[derive(Debug, Clone)]
pub struct Rejection {
    message: String,
}

impl Rejection {
    /// Creates a rejection from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self { message }
    }

    /// Returns the rejection message.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl std::fmt::Display for Rejection {
    /// Writes the bare message with no prefix. Width/fill flags from the
    /// caller's format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for Rejection {}
123
124impl From<Rejection> for Error {
125 fn from(rejection: Rejection) -> Self {
126 Error::tool(rejection.message)
127 }
128}
129
/// Rejection produced when tool arguments cannot be deserialized into the
/// type requested by a `Json` extractor.
///
/// It can optionally record the location of the offending value (`path`).
/// When present, the path is included in the `Display` output.
#[derive(Debug, Clone)]
pub struct JsonRejection {
    message: String,
    path: Option<String>,
}

impl JsonRejection {
    /// Creates a rejection with no location information.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            message,
            path: None,
        }
    }

    /// Creates a rejection that points at `path` (for example `users[0].name`).
    pub fn with_path(message: impl Into<String>, path: impl Into<String>) -> Self {
        let mut rejection = Self::new(message);
        rejection.path = Some(path.into());
        rejection
    }

    /// Returns the underlying error message, without the `Invalid input` prefix.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }

    /// Returns the location of the invalid value, if one was recorded.
    pub fn path(&self) -> Option<&str> {
        self.path.as_deref()
    }
}

impl std::fmt::Display for JsonRejection {
    /// Renders `Invalid input at `<path>`: <message>` when a path is known,
    /// otherwise `Invalid input: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.path() {
            Some(path) => write!(f, "Invalid input at `{}`: {}", path, self.message),
            None => write!(f, "Invalid input: {}", self.message),
        }
    }
}

impl std::error::Error for JsonRejection {}
189
190impl From<JsonRejection> for Error {
191 fn from(rejection: JsonRejection) -> Self {
192 Error::tool(rejection.to_string())
193 }
194}
195
196impl From<serde_json::Error> for JsonRejection {
197 fn from(err: serde_json::Error) -> Self {
198 let path = if err.is_data() {
200 None
203 } else {
204 None
205 };
206
207 Self {
208 message: err.to_string(),
209 path,
210 }
211 }
212}
213
/// Rejection returned by the `Extension` extractor when no value of the
/// requested type is present in the request context.
#[derive(Debug, Clone)]
pub struct ExtensionRejection {
    type_name: &'static str,
}

impl ExtensionRejection {
    /// Creates a rejection for a missing extension of type `T`.
    ///
    /// `T` may be unsized (e.g. `str` or `dyn Trait`). Only its name is
    /// recorded, via [`std::any::type_name`], which itself accepts `?Sized`
    /// types. Relaxing the bound is backward-compatible with existing
    /// `not_found::<T>()` callers.
    pub fn not_found<T: ?Sized>() -> Self {
        Self {
            type_name: std::any::type_name::<T>(),
        }
    }

    /// Returns the name of the missing extension type.
    ///
    /// The string comes from [`std::any::type_name`], whose exact format is
    /// not guaranteed to be stable. Use it for diagnostics only.
    pub fn type_name(&self) -> &'static str {
        self.type_name
    }
}

impl std::fmt::Display for ExtensionRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Extension of type `{}` not found. Did you call `router.with_state()` or `router.with_extension()`?",
            self.type_name
        )
    }
}

impl std::error::Error for ExtensionRejection {}
257
258impl From<ExtensionRejection> for Error {
259 fn from(rejection: ExtensionRejection) -> Self {
260 Error::tool(rejection.to_string())
261 }
262}
263
264pub trait FromToolRequest<S = ()>: Sized {
295 type Rejection: Into<Error>;
297
298 fn from_tool_request(
306 ctx: &RequestContext,
307 state: &S,
308 args: &Value,
309 ) -> std::result::Result<Self, Self::Rejection>;
310}
311
/// Extractor that deserializes the tool arguments into `T`.
///
/// The wrapped value is public, so it can be destructured in a handler
/// signature (`Json(input): Json<MyInput>`). It can also be used through
/// `Deref`.
#[derive(Debug, Clone, Copy)]
pub struct Json<T>(pub T);

impl<T> Deref for Json<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Json(inner) = self;
        inner
    }
}
352
353impl<S, T> FromToolRequest<S> for Json<T>
354where
355 T: DeserializeOwned,
356{
357 type Rejection = JsonRejection;
358
359 fn from_tool_request(
360 _ctx: &RequestContext,
361 _state: &S,
362 args: &Value,
363 ) -> std::result::Result<Self, Self::Rejection> {
364 serde_json::from_value(args.clone())
365 .map(Json)
366 .map_err(JsonRejection::from)
367 }
368}
369
/// Extractor that yields a clone of the tool's shared state.
///
/// The wrapped value is public for destructuring (`State(db): State<Db>`)
/// and is also reachable through `Deref`.
#[derive(Debug, Clone, Copy)]
pub struct State<T>(pub T);

impl<T> Deref for State<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let State(inner) = self;
        inner
    }
}
404
405impl<S: Clone> FromToolRequest<S> for State<S> {
406 type Rejection = Rejection;
407
408 fn from_tool_request(
409 _ctx: &RequestContext,
410 state: &S,
411 _args: &Value,
412 ) -> std::result::Result<Self, Self::Rejection> {
413 Ok(State(state.clone()))
414 }
415}
416
417#[derive(Debug, Clone)]
438pub struct Context(RequestContext);
439
440impl Context {
441 pub fn into_inner(self) -> RequestContext {
443 self.0
444 }
445}
446
447impl Deref for Context {
448 type Target = RequestContext;
449
450 fn deref(&self) -> &Self::Target {
451 &self.0
452 }
453}
454
455impl<S> FromToolRequest<S> for Context {
456 type Rejection = Rejection;
457
458 fn from_tool_request(
459 ctx: &RequestContext,
460 _state: &S,
461 _args: &Value,
462 ) -> std::result::Result<Self, Self::Rejection> {
463 Ok(Context(ctx.clone()))
464 }
465}
466
467#[derive(Debug, Clone)]
485pub struct RawArgs(pub Value);
486
487impl Deref for RawArgs {
488 type Target = Value;
489
490 fn deref(&self) -> &Self::Target {
491 &self.0
492 }
493}
494
495impl<S> FromToolRequest<S> for RawArgs {
496 type Rejection = Rejection;
497
498 fn from_tool_request(
499 _ctx: &RequestContext,
500 _state: &S,
501 args: &Value,
502 ) -> std::result::Result<Self, Self::Rejection> {
503 Ok(RawArgs(args.clone()))
504 }
505}
506
/// Extractor that clones a value of type `T` out of the request context's
/// extensions.
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);

impl<T> Deref for Extension<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Extension(inner) = self;
        inner
    }
}
563
564impl<S, T> FromToolRequest<S> for Extension<T>
565where
566 T: Clone + Send + Sync + 'static,
567{
568 type Rejection = ExtensionRejection;
569
570 fn from_tool_request(
571 ctx: &RequestContext,
572 _state: &S,
573 _args: &Value,
574 ) -> std::result::Result<Self, Self::Rejection> {
575 ctx.extension::<T>()
576 .cloned()
577 .map(Extension)
578 .ok_or_else(ExtensionRejection::not_found::<T>)
579 }
580}
581
582pub trait ExtractorHandler<S, T>: Clone + Send + Sync + 'static {
592 type Future: Future<Output = Result<CallToolResult>> + Send;
594
595 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
597
598 fn input_schema() -> Value;
602}
603
604impl<S, F, Fut, T1> ExtractorHandler<S, (T1,)> for F
606where
607 S: Clone + Send + Sync + 'static,
608 F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
609 Fut: Future<Output = Result<CallToolResult>> + Send,
610 T1: FromToolRequest<S> + HasSchema + Send,
611{
612 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
613
614 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
615 Box::pin(async move {
616 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
617 self(t1).await
618 })
619 }
620
621 fn input_schema() -> Value {
622 if let Some(schema) = T1::schema() {
623 return schema;
624 }
625 serde_json::json!({
626 "type": "object",
627 "additionalProperties": true
628 })
629 }
630}
631
632impl<S, F, Fut, T1, T2> ExtractorHandler<S, (T1, T2)> for F
634where
635 S: Clone + Send + Sync + 'static,
636 F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
637 Fut: Future<Output = Result<CallToolResult>> + Send,
638 T1: FromToolRequest<S> + HasSchema + Send,
639 T2: FromToolRequest<S> + HasSchema + Send,
640{
641 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
642
643 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
644 Box::pin(async move {
645 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
646 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
647 self(t1, t2).await
648 })
649 }
650
651 fn input_schema() -> Value {
652 if let Some(schema) = T2::schema() {
653 return schema;
654 }
655 if let Some(schema) = T1::schema() {
656 return schema;
657 }
658 serde_json::json!({
659 "type": "object",
660 "additionalProperties": true
661 })
662 }
663}
664
665impl<S, F, Fut, T1, T2, T3> ExtractorHandler<S, (T1, T2, T3)> for F
667where
668 S: Clone + Send + Sync + 'static,
669 F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
670 Fut: Future<Output = Result<CallToolResult>> + Send,
671 T1: FromToolRequest<S> + HasSchema + Send,
672 T2: FromToolRequest<S> + HasSchema + Send,
673 T3: FromToolRequest<S> + HasSchema + Send,
674{
675 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
676
677 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
678 Box::pin(async move {
679 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
680 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
681 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
682 self(t1, t2, t3).await
683 })
684 }
685
686 fn input_schema() -> Value {
687 if let Some(schema) = T3::schema() {
688 return schema;
689 }
690 if let Some(schema) = T2::schema() {
691 return schema;
692 }
693 if let Some(schema) = T1::schema() {
694 return schema;
695 }
696 serde_json::json!({
697 "type": "object",
698 "additionalProperties": true
699 })
700 }
701}
702
703impl<S, F, Fut, T1, T2, T3, T4> ExtractorHandler<S, (T1, T2, T3, T4)> for F
705where
706 S: Clone + Send + Sync + 'static,
707 F: Fn(T1, T2, T3, T4) -> Fut + Clone + Send + Sync + 'static,
708 Fut: Future<Output = Result<CallToolResult>> + Send,
709 T1: FromToolRequest<S> + HasSchema + Send,
710 T2: FromToolRequest<S> + HasSchema + Send,
711 T3: FromToolRequest<S> + HasSchema + Send,
712 T4: FromToolRequest<S> + HasSchema + Send,
713{
714 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
715
716 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
717 Box::pin(async move {
718 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
719 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
720 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
721 let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
722 self(t1, t2, t3, t4).await
723 })
724 }
725
726 fn input_schema() -> Value {
727 if let Some(schema) = T4::schema() {
728 return schema;
729 }
730 if let Some(schema) = T3::schema() {
731 return schema;
732 }
733 if let Some(schema) = T2::schema() {
734 return schema;
735 }
736 if let Some(schema) = T1::schema() {
737 return schema;
738 }
739 serde_json::json!({
740 "type": "object",
741 "additionalProperties": true
742 })
743 }
744}
745
746impl<S, F, Fut, T1, T2, T3, T4, T5> ExtractorHandler<S, (T1, T2, T3, T4, T5)> for F
748where
749 S: Clone + Send + Sync + 'static,
750 F: Fn(T1, T2, T3, T4, T5) -> Fut + Clone + Send + Sync + 'static,
751 Fut: Future<Output = Result<CallToolResult>> + Send,
752 T1: FromToolRequest<S> + HasSchema + Send,
753 T2: FromToolRequest<S> + HasSchema + Send,
754 T3: FromToolRequest<S> + HasSchema + Send,
755 T4: FromToolRequest<S> + HasSchema + Send,
756 T5: FromToolRequest<S> + HasSchema + Send,
757{
758 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
759
760 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
761 Box::pin(async move {
762 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
763 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
764 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
765 let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
766 let t5 = T5::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
767 self(t1, t2, t3, t4, t5).await
768 })
769 }
770
771 fn input_schema() -> Value {
772 if let Some(schema) = T5::schema() {
773 return schema;
774 }
775 if let Some(schema) = T4::schema() {
776 return schema;
777 }
778 if let Some(schema) = T3::schema() {
779 return schema;
780 }
781 if let Some(schema) = T2::schema() {
782 return schema;
783 }
784 if let Some(schema) = T1::schema() {
785 return schema;
786 }
787 serde_json::json!({
788 "type": "object",
789 "additionalProperties": true
790 })
791 }
792}
793
794pub trait HasSchema {
800 fn schema() -> Option<Value>;
802}
803
804impl<T: JsonSchema> HasSchema for Json<T> {
805 fn schema() -> Option<Value> {
806 let schema = schemars::schema_for!(T);
807 serde_json::to_value(schema)
808 .ok()
809 .map(crate::tool::ensure_object_schema)
810 }
811}
812
813impl HasSchema for Context {
815 fn schema() -> Option<Value> {
816 None
817 }
818}
819
820impl HasSchema for RawArgs {
821 fn schema() -> Option<Value> {
822 None
823 }
824}
825
826impl<T> HasSchema for State<T> {
827 fn schema() -> Option<Value> {
828 None
829 }
830}
831
832impl<T> HasSchema for Extension<T> {
833 fn schema() -> Option<Value> {
834 None
835 }
836}
837
838#[deprecated(
847 since = "0.8.0",
848 note = "Use `ExtractorHandler` instead -- `extractor_handler` auto-detects JSON schema from `Json<T>` extractors"
849)]
850pub trait TypedExtractorHandler<S, T, I>: Clone + Send + Sync + 'static
851where
852 I: JsonSchema,
853{
854 type Future: Future<Output = Result<CallToolResult>> + Send;
856
857 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
859}
860
861#[allow(deprecated)]
863impl<S, F, Fut, T> TypedExtractorHandler<S, (Json<T>,), T> for F
864where
865 S: Clone + Send + Sync + 'static,
866 F: Fn(Json<T>) -> Fut + Clone + Send + Sync + 'static,
867 Fut: Future<Output = Result<CallToolResult>> + Send,
868 T: DeserializeOwned + JsonSchema + Send,
869{
870 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
871
872 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
873 Box::pin(async move {
874 let t1 =
875 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
876 self(t1).await
877 })
878 }
879}
880
881#[allow(deprecated)]
883impl<S, F, Fut, T1, T> TypedExtractorHandler<S, (T1, Json<T>), T> for F
884where
885 S: Clone + Send + Sync + 'static,
886 F: Fn(T1, Json<T>) -> Fut + Clone + Send + Sync + 'static,
887 Fut: Future<Output = Result<CallToolResult>> + Send,
888 T1: FromToolRequest<S> + Send,
889 T: DeserializeOwned + JsonSchema + Send,
890{
891 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
892
893 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
894 Box::pin(async move {
895 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
896 let t2 =
897 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
898 self(t1, t2).await
899 })
900 }
901}
902
903#[allow(deprecated)]
905impl<S, F, Fut, T1, T2, T> TypedExtractorHandler<S, (T1, T2, Json<T>), T> for F
906where
907 S: Clone + Send + Sync + 'static,
908 F: Fn(T1, T2, Json<T>) -> Fut + Clone + Send + Sync + 'static,
909 Fut: Future<Output = Result<CallToolResult>> + Send,
910 T1: FromToolRequest<S> + Send,
911 T2: FromToolRequest<S> + Send,
912 T: DeserializeOwned + JsonSchema + Send,
913{
914 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
915
916 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
917 Box::pin(async move {
918 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
919 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
920 let t3 =
921 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
922 self(t1, t2, t3).await
923 })
924 }
925}
926
927#[allow(deprecated)]
929impl<S, F, Fut, T1, T2, T3, T> TypedExtractorHandler<S, (T1, T2, T3, Json<T>), T> for F
930where
931 S: Clone + Send + Sync + 'static,
932 F: Fn(T1, T2, T3, Json<T>) -> Fut + Clone + Send + Sync + 'static,
933 Fut: Future<Output = Result<CallToolResult>> + Send,
934 T1: FromToolRequest<S> + Send,
935 T2: FromToolRequest<S> + Send,
936 T3: FromToolRequest<S> + Send,
937 T: DeserializeOwned + JsonSchema + Send,
938{
939 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
940
941 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
942 Box::pin(async move {
943 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
944 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
945 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
946 let t4 =
947 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
948 self(t1, t2, t3, t4).await
949 })
950 }
951}
952
953use crate::tool::{
958 BoxFuture, GuardLayer, Tool, ToolCatchError, ToolHandler, ToolHandlerService, ToolRequest,
959};
960use tower::util::BoxCloneService;
961use tower_service::Service;
962
963pub(crate) struct ExtractorToolHandler<S, F, T> {
965 state: S,
966 handler: F,
967 input_schema: Value,
968 _phantom: PhantomData<T>,
969}
970
971impl<S, F, T> ToolHandler for ExtractorToolHandler<S, F, T>
972where
973 S: Clone + Send + Sync + 'static,
974 F: ExtractorHandler<S, T> + Clone,
975 T: Send + Sync + 'static,
976{
977 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
978 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
979 self.call_with_context(ctx, args)
980 }
981
982 fn call_with_context(
983 &self,
984 ctx: RequestContext,
985 args: Value,
986 ) -> BoxFuture<'_, Result<CallToolResult>> {
987 let state = self.state.clone();
988 let handler = self.handler.clone();
989 Box::pin(async move { handler.call(ctx, state, args).await })
990 }
991
992 fn uses_context(&self) -> bool {
993 true
994 }
995
996 fn input_schema(&self) -> Value {
997 self.input_schema.clone()
998 }
999}
1000
1001#[doc(hidden)]
1003pub struct ToolBuilderWithExtractor<S, F, T> {
1004 pub(crate) name: String,
1005 pub(crate) title: Option<String>,
1006 pub(crate) description: Option<String>,
1007 pub(crate) output_schema: Option<Value>,
1008 pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
1009 pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
1010 pub(crate) task_support: crate::protocol::TaskSupportMode,
1011 pub(crate) state: S,
1012 pub(crate) handler: F,
1013 pub(crate) input_schema: Value,
1014 pub(crate) _phantom: PhantomData<T>,
1015}
1016
1017impl<S, F, T> ToolBuilderWithExtractor<S, F, T>
1018where
1019 S: Clone + Send + Sync + 'static,
1020 F: ExtractorHandler<S, T> + Clone,
1021 T: Send + Sync + 'static,
1022{
1023 pub fn build(self) -> Tool {
1025 let handler = ExtractorToolHandler {
1026 state: self.state,
1027 handler: self.handler,
1028 input_schema: self.input_schema.clone(),
1029 _phantom: PhantomData,
1030 };
1031
1032 let handler_service = ToolHandlerService::new(handler);
1033 let catch_error = ToolCatchError::new(handler_service);
1034 let service = BoxCloneService::new(catch_error);
1035
1036 Tool {
1037 name: self.name,
1038 title: self.title,
1039 description: self.description,
1040 output_schema: self.output_schema,
1041 icons: self.icons,
1042 annotations: self.annotations,
1043 task_support: self.task_support,
1044 service,
1045 input_schema: self.input_schema,
1046 }
1047 }
1048
1049 pub fn layer<L>(self, layer: L) -> ToolBuilderWithExtractorLayer<S, F, T, L> {
1085 ToolBuilderWithExtractorLayer {
1086 name: self.name,
1087 title: self.title,
1088 description: self.description,
1089 output_schema: self.output_schema,
1090 icons: self.icons,
1091 annotations: self.annotations,
1092 task_support: self.task_support,
1093 state: self.state,
1094 handler: self.handler,
1095 input_schema: self.input_schema,
1096 layer,
1097 _phantom: PhantomData,
1098 }
1099 }
1100
1101 pub fn guard<G>(self, guard: G) -> ToolBuilderWithExtractorLayer<S, F, T, GuardLayer<G>>
1105 where
1106 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1107 {
1108 self.layer(GuardLayer::new(guard))
1109 }
1110}
1111
1112#[doc(hidden)]
1116pub struct ToolBuilderWithExtractorLayer<S, F, T, L> {
1117 name: String,
1118 title: Option<String>,
1119 description: Option<String>,
1120 output_schema: Option<Value>,
1121 icons: Option<Vec<crate::protocol::ToolIcon>>,
1122 annotations: Option<crate::protocol::ToolAnnotations>,
1123 task_support: crate::protocol::TaskSupportMode,
1124 state: S,
1125 handler: F,
1126 input_schema: Value,
1127 layer: L,
1128 _phantom: PhantomData<T>,
1129}
1130
1131#[allow(private_bounds)]
1132impl<S, F, T, L> ToolBuilderWithExtractorLayer<S, F, T, L>
1133where
1134 S: Clone + Send + Sync + 'static,
1135 F: ExtractorHandler<S, T> + Clone,
1136 T: Send + Sync + 'static,
1137 L: tower::Layer<ToolHandlerService<ExtractorToolHandler<S, F, T>>>
1138 + Clone
1139 + Send
1140 + Sync
1141 + 'static,
1142 L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1143 <L::Service as Service<ToolRequest>>::Error: std::fmt::Display + Send,
1144 <L::Service as Service<ToolRequest>>::Future: Send,
1145{
1146 pub fn build(self) -> Tool {
1148 let handler = ExtractorToolHandler {
1149 state: self.state,
1150 handler: self.handler,
1151 input_schema: self.input_schema.clone(),
1152 _phantom: PhantomData,
1153 };
1154
1155 let handler_service = ToolHandlerService::new(handler);
1156 let layered = self.layer.layer(handler_service);
1157 let catch_error = ToolCatchError::new(layered);
1158 let service = BoxCloneService::new(catch_error);
1159
1160 Tool {
1161 name: self.name,
1162 title: self.title,
1163 description: self.description,
1164 output_schema: self.output_schema,
1165 icons: self.icons,
1166 annotations: self.annotations,
1167 task_support: self.task_support,
1168 service,
1169 input_schema: self.input_schema,
1170 }
1171 }
1172
1173 pub fn layer<L2>(
1178 self,
1179 layer: L2,
1180 ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<L2, L>> {
1181 ToolBuilderWithExtractorLayer {
1182 name: self.name,
1183 title: self.title,
1184 description: self.description,
1185 output_schema: self.output_schema,
1186 icons: self.icons,
1187 annotations: self.annotations,
1188 task_support: self.task_support,
1189 state: self.state,
1190 handler: self.handler,
1191 input_schema: self.input_schema,
1192 layer: tower::layer::util::Stack::new(layer, self.layer),
1193 _phantom: PhantomData,
1194 }
1195 }
1196
1197 pub fn guard<G>(
1201 self,
1202 guard: G,
1203 ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<GuardLayer<G>, L>>
1204 where
1205 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1206 {
1207 self.layer(GuardLayer::new(guard))
1208 }
1209}
1210
1211#[doc(hidden)]
1213#[deprecated(
1214 since = "0.8.0",
1215 note = "Use `ToolBuilderWithExtractor` via `extractor_handler` instead"
1216)]
1217pub struct ToolBuilderWithTypedExtractor<S, F, T, I> {
1218 pub(crate) name: String,
1219 pub(crate) title: Option<String>,
1220 pub(crate) description: Option<String>,
1221 pub(crate) output_schema: Option<Value>,
1222 pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
1223 pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
1224 pub(crate) task_support: crate::protocol::TaskSupportMode,
1225 pub(crate) state: S,
1226 pub(crate) handler: F,
1227 pub(crate) _phantom: PhantomData<(T, I)>,
1228}
1229
1230#[allow(deprecated)]
1231impl<S, F, T, I> ToolBuilderWithTypedExtractor<S, F, T, I>
1232where
1233 S: Clone + Send + Sync + 'static,
1234 F: TypedExtractorHandler<S, T, I> + Clone,
1235 T: Send + Sync + 'static,
1236 I: JsonSchema + Send + Sync + 'static,
1237{
1238 pub fn build(self) -> Tool {
1240 let input_schema = {
1241 let schema = schemars::schema_for!(I);
1242 let schema = serde_json::to_value(schema).unwrap_or_else(|_| {
1243 serde_json::json!({
1244 "type": "object"
1245 })
1246 });
1247 crate::tool::ensure_object_schema(schema)
1248 };
1249
1250 let handler = TypedExtractorToolHandler {
1251 state: self.state,
1252 handler: self.handler,
1253 input_schema: input_schema.clone(),
1254 _phantom: PhantomData,
1255 };
1256
1257 let handler_service = crate::tool::ToolHandlerService::new(handler);
1258 let catch_error = ToolCatchError::new(handler_service);
1259 let service = BoxCloneService::new(catch_error);
1260
1261 Tool {
1262 name: self.name,
1263 title: self.title,
1264 description: self.description,
1265 output_schema: self.output_schema,
1266 icons: self.icons,
1267 annotations: self.annotations,
1268 task_support: self.task_support,
1269 service,
1270 input_schema,
1271 }
1272 }
1273}
1274
1275struct TypedExtractorToolHandler<S, F, T, I> {
1277 state: S,
1278 handler: F,
1279 input_schema: Value,
1280 _phantom: PhantomData<(T, I)>,
1281}
1282
1283#[allow(deprecated)]
1284impl<S, F, T, I> ToolHandler for TypedExtractorToolHandler<S, F, T, I>
1285where
1286 S: Clone + Send + Sync + 'static,
1287 F: TypedExtractorHandler<S, T, I> + Clone,
1288 T: Send + Sync + 'static,
1289 I: JsonSchema + Send + Sync + 'static,
1290{
1291 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1292 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
1293 self.call_with_context(ctx, args)
1294 }
1295
1296 fn call_with_context(
1297 &self,
1298 ctx: RequestContext,
1299 args: Value,
1300 ) -> BoxFuture<'_, Result<CallToolResult>> {
1301 let state = self.state.clone();
1302 let handler = self.handler.clone();
1303 Box::pin(async move { handler.call(ctx, state, args).await })
1304 }
1305
1306 fn uses_context(&self) -> bool {
1307 true
1308 }
1309
1310 fn input_schema(&self) -> Value {
1311 self.input_schema.clone()
1312 }
1313}
1314
1315#[cfg(test)]
1316mod tests {
1317 use super::*;
1318 use crate::protocol::RequestId;
1319 use schemars::JsonSchema;
1320 use serde::Deserialize;
1321 use std::sync::Arc;
1322
1323 #[derive(Debug, Deserialize, JsonSchema)]
1324 struct TestInput {
1325 name: String,
1326 count: i32,
1327 }
1328
1329 #[test]
1330 fn test_json_extraction() {
1331 let args = serde_json::json!({"name": "test", "count": 42});
1332 let ctx = RequestContext::new(RequestId::Number(1));
1333
1334 let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
1335 assert!(result.is_ok());
1336 let Json(input) = result.unwrap();
1337 assert_eq!(input.name, "test");
1338 assert_eq!(input.count, 42);
1339 }
1340
1341 #[test]
1342 fn test_json_extraction_error() {
1343 let args = serde_json::json!({"name": "test"}); let ctx = RequestContext::new(RequestId::Number(1));
1345
1346 let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
1347 assert!(result.is_err());
1348 let rejection = result.unwrap_err();
1349 assert!(rejection.message().contains("count"));
1351 }
1352
1353 #[test]
1354 fn test_state_extraction() {
1355 let args = serde_json::json!({});
1356 let ctx = RequestContext::new(RequestId::Number(1));
1357 let state = Arc::new("my-state".to_string());
1358
1359 let result = State::<Arc<String>>::from_tool_request(&ctx, &state, &args);
1360 assert!(result.is_ok());
1361 let State(extracted) = result.unwrap();
1362 assert_eq!(*extracted, "my-state");
1363 }
1364
1365 #[test]
1366 fn test_context_extraction() {
1367 let args = serde_json::json!({});
1368 let ctx = RequestContext::new(RequestId::Number(42));
1369
1370 let result = Context::from_tool_request(&ctx, &(), &args);
1371 assert!(result.is_ok());
1372 let extracted = result.unwrap();
1373 assert_eq!(*extracted.request_id(), RequestId::Number(42));
1374 }
1375
1376 #[test]
1377 fn test_raw_args_extraction() {
1378 let args = serde_json::json!({"foo": "bar", "baz": 123});
1379 let ctx = RequestContext::new(RequestId::Number(1));
1380
1381 let result = RawArgs::from_tool_request(&ctx, &(), &args);
1382 assert!(result.is_ok());
1383 let RawArgs(extracted) = result.unwrap();
1384 assert_eq!(extracted["foo"], "bar");
1385 assert_eq!(extracted["baz"], 123);
1386 }
1387
1388 #[test]
1389 fn test_extension_extraction() {
1390 use crate::context::Extensions;
1391
1392 #[derive(Clone, Debug, PartialEq)]
1393 struct DatabasePool {
1394 url: String,
1395 }
1396
1397 let args = serde_json::json!({});
1398
1399 let mut extensions = Extensions::new();
1401 extensions.insert(Arc::new(DatabasePool {
1402 url: "postgres://localhost".to_string(),
1403 }));
1404
1405 let ctx = RequestContext::new(RequestId::Number(1)).with_extensions(Arc::new(extensions));
1407
1408 let result = Extension::<Arc<DatabasePool>>::from_tool_request(&ctx, &(), &args);
1410 assert!(result.is_ok());
1411 let Extension(pool) = result.unwrap();
1412 assert_eq!(pool.url, "postgres://localhost");
1413 }
1414
1415 #[test]
1416 fn test_extension_extraction_missing() {
1417 #[derive(Clone, Debug)]
1418 struct NotPresent;
1419
1420 let args = serde_json::json!({});
1421 let ctx = RequestContext::new(RequestId::Number(1));
1422
1423 let result = Extension::<NotPresent>::from_tool_request(&ctx, &(), &args);
1425 assert!(result.is_err());
1426 let rejection = result.unwrap_err();
1427 assert!(rejection.type_name().contains("NotPresent"));
1429 }
1430
1431 #[tokio::test]
1432 async fn test_single_extractor_handler() {
1433 let handler = |Json(input): Json<TestInput>| async move {
1434 Ok(CallToolResult::text(format!(
1435 "{}: {}",
1436 input.name, input.count
1437 )))
1438 };
1439
1440 let ctx = RequestContext::new(RequestId::Number(1));
1441 let args = serde_json::json!({"name": "test", "count": 5});
1442
1443 let result: Result<CallToolResult> =
1445 ExtractorHandler::<(), (Json<TestInput>,)>::call(handler, ctx, (), args).await;
1446 assert!(result.is_ok());
1447 }
1448
1449 #[tokio::test]
1450 async fn test_two_extractor_handler() {
1451 let handler = |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1452 Ok(CallToolResult::text(format!(
1453 "{}: {} - {}",
1454 state, input.name, input.count
1455 )))
1456 };
1457
1458 let ctx = RequestContext::new(RequestId::Number(1));
1459 let state = Arc::new("prefix".to_string());
1460 let args = serde_json::json!({"name": "test", "count": 5});
1461
1462 let result: Result<CallToolResult> = ExtractorHandler::<
1464 Arc<String>,
1465 (State<Arc<String>>, Json<TestInput>),
1466 >::call(handler, ctx, state, args)
1467 .await;
1468 assert!(result.is_ok());
1469 }
1470
1471 #[tokio::test]
1472 async fn test_three_extractor_handler() {
1473 let handler = |State(state): State<Arc<String>>,
1474 ctx: Context,
1475 Json(input): Json<TestInput>| async move {
1476 assert!(!ctx.is_cancelled());
1478 Ok(CallToolResult::text(format!(
1479 "{}: {} - {}",
1480 state, input.name, input.count
1481 )))
1482 };
1483
1484 let ctx = RequestContext::new(RequestId::Number(1));
1485 let state = Arc::new("prefix".to_string());
1486 let args = serde_json::json!({"name": "test", "count": 5});
1487
1488 let result: Result<CallToolResult> = ExtractorHandler::<
1490 Arc<String>,
1491 (State<Arc<String>>, Context, Json<TestInput>),
1492 >::call(handler, ctx, state, args)
1493 .await;
1494 assert!(result.is_ok());
1495 }
1496
1497 #[test]
1498 fn test_json_schema_generation() {
1499 let schema = Json::<TestInput>::schema();
1500 assert!(schema.is_some());
1501 let schema = schema.unwrap();
1502 assert!(schema.get("properties").is_some());
1503 }
1504
1505 #[test]
1506 fn test_rejection_into_error() {
1507 let rejection = Rejection::new("test error");
1508 let error: Error = rejection.into();
1509 assert!(error.to_string().contains("test error"));
1510 }
1511
1512 #[test]
1513 fn test_json_rejection() {
1514 let rejection = JsonRejection::new("missing field `name`");
1516 assert_eq!(rejection.message(), "missing field `name`");
1517 assert!(rejection.path().is_none());
1518 assert!(rejection.to_string().contains("Invalid input"));
1519
1520 let rejection = JsonRejection::with_path("expected string", "users[0].name");
1522 assert_eq!(rejection.message(), "expected string");
1523 assert_eq!(rejection.path(), Some("users[0].name"));
1524 assert!(rejection.to_string().contains("users[0].name"));
1525
1526 let error: Error = rejection.into();
1528 assert!(error.to_string().contains("users[0].name"));
1529 }
1530
1531 #[test]
1532 fn test_json_rejection_from_serde_error() {
1533 #[derive(Debug, serde::Deserialize)]
1535 struct TestStruct {
1536 #[allow(dead_code)]
1537 name: String,
1538 }
1539
1540 let result: std::result::Result<TestStruct, _> =
1541 serde_json::from_value(serde_json::json!({"count": 42}));
1542 assert!(result.is_err());
1543
1544 let rejection: JsonRejection = result.unwrap_err().into();
1545 assert!(rejection.message().contains("name"));
1546 }
1547
1548 #[test]
1549 fn test_extension_rejection() {
1550 let rejection = ExtensionRejection::not_found::<String>();
1552 assert!(rejection.type_name().contains("String"));
1553 assert!(rejection.to_string().contains("not found"));
1554 assert!(rejection.to_string().contains("with_state"));
1555
1556 let error: Error = rejection.into();
1558 assert!(error.to_string().contains("not found"));
1559 }
1560
1561 #[tokio::test]
1562 async fn test_tool_builder_extractor_handler() {
1563 use crate::ToolBuilder;
1564
1565 let state = Arc::new("shared-state".to_string());
1566
1567 let tool =
1568 ToolBuilder::new("test_extractor")
1569 .description("Test extractor handler")
1570 .extractor_handler(
1571 state,
1572 |State(state): State<Arc<String>>,
1573 ctx: Context,
1574 Json(input): Json<TestInput>| async move {
1575 assert!(!ctx.is_cancelled());
1576 Ok(CallToolResult::text(format!(
1577 "{}: {} - {}",
1578 state, input.name, input.count
1579 )))
1580 },
1581 )
1582 .build();
1583
1584 assert_eq!(tool.name, "test_extractor");
1585 assert_eq!(tool.description.as_deref(), Some("Test extractor handler"));
1586
1587 let result = tool
1589 .call(serde_json::json!({"name": "test", "count": 42}))
1590 .await;
1591 assert!(!result.is_error);
1592 }
1593
1594 #[tokio::test]
1595 #[allow(deprecated)]
1596 async fn test_tool_builder_extractor_handler_typed() {
1597 use crate::ToolBuilder;
1598
1599 let state = Arc::new("typed-state".to_string());
1600
1601 let tool = ToolBuilder::new("test_typed")
1602 .description("Test typed extractor handler")
1603 .extractor_handler_typed::<_, _, _, TestInput>(
1604 state,
1605 |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1606 Ok(CallToolResult::text(format!(
1607 "{}: {} - {}",
1608 state, input.name, input.count
1609 )))
1610 },
1611 )
1612 .build();
1613
1614 assert_eq!(tool.name, "test_typed");
1615
1616 let def = tool.definition();
1618 let schema = def.input_schema;
1619 assert!(schema.get("properties").is_some());
1620
1621 let result = tool
1623 .call(serde_json::json!({"name": "world", "count": 99}))
1624 .await;
1625 assert!(!result.is_error);
1626 }
1627
1628 #[tokio::test]
1629 async fn test_extractor_handler_auto_schema() {
1630 use crate::ToolBuilder;
1631
1632 let state = Arc::new("auto-schema".to_string());
1633
1634 let tool = ToolBuilder::new("test_auto_schema")
1636 .description("Test auto schema detection")
1637 .extractor_handler(
1638 state,
1639 |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1640 Ok(CallToolResult::text(format!(
1641 "{}: {} - {}",
1642 state, input.name, input.count
1643 )))
1644 },
1645 )
1646 .build();
1647
1648 let def = tool.definition();
1650 let schema = def.input_schema;
1651 assert!(
1652 schema.get("properties").is_some(),
1653 "Schema should have properties from TestInput, got: {}",
1654 schema
1655 );
1656 let props = schema.get("properties").unwrap();
1657 assert!(
1658 props.get("name").is_some(),
1659 "Schema should have 'name' property"
1660 );
1661 assert!(
1662 props.get("count").is_some(),
1663 "Schema should have 'count' property"
1664 );
1665
1666 let result = tool
1668 .call(serde_json::json!({"name": "world", "count": 99}))
1669 .await;
1670 assert!(!result.is_error);
1671 }
1672
1673 #[test]
1674 fn test_extractor_handler_no_json_fallback() {
1675 use crate::ToolBuilder;
1676
1677 let tool = ToolBuilder::new("test_no_json")
1679 .description("Test no json fallback")
1680 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1681 Ok(CallToolResult::json(args))
1682 })
1683 .build();
1684
1685 let def = tool.definition();
1686 let schema = def.input_schema;
1687 assert_eq!(
1688 schema.get("type").and_then(|v| v.as_str()),
1689 Some("object"),
1690 "Schema should be generic object"
1691 );
1692 assert_eq!(
1693 schema.get("additionalProperties").and_then(|v| v.as_bool()),
1694 Some(true),
1695 "Schema should allow additional properties"
1696 );
1697 assert!(
1699 schema.get("properties").is_none(),
1700 "Generic schema should not have specific properties"
1701 );
1702 }
1703
1704 #[tokio::test]
1705 async fn test_extractor_handler_with_layer() {
1706 use crate::ToolBuilder;
1707 use std::time::Duration;
1708 use tower::timeout::TimeoutLayer;
1709
1710 let state = Arc::new("layered".to_string());
1711
1712 let tool = ToolBuilder::new("test_extractor_layer")
1713 .description("Test extractor handler with layer")
1714 .extractor_handler(
1715 state,
1716 |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1717 Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
1718 },
1719 )
1720 .layer(TimeoutLayer::new(Duration::from_secs(5)))
1721 .build();
1722
1723 let result = tool
1725 .call(serde_json::json!({"name": "test", "count": 1}))
1726 .await;
1727 assert!(!result.is_error);
1728 assert_eq!(result.first_text().unwrap(), "layered: test");
1729
1730 let def = tool.definition();
1732 let schema = def.input_schema;
1733 assert!(
1734 schema.get("properties").is_some(),
1735 "Schema should have properties even with layer"
1736 );
1737 }
1738
1739 #[tokio::test]
1740 async fn test_extractor_handler_with_timeout_layer() {
1741 use crate::ToolBuilder;
1742 use std::time::Duration;
1743 use tower::timeout::TimeoutLayer;
1744
1745 let tool = ToolBuilder::new("test_extractor_timeout")
1746 .description("Test extractor handler timeout")
1747 .extractor_handler((), |Json(input): Json<TestInput>| async move {
1748 tokio::time::sleep(Duration::from_millis(200)).await;
1749 Ok(CallToolResult::text(input.name.to_string()))
1750 })
1751 .layer(TimeoutLayer::new(Duration::from_millis(50)))
1752 .build();
1753
1754 let result = tool
1756 .call(serde_json::json!({"name": "slow", "count": 1}))
1757 .await;
1758 assert!(result.is_error);
1759 let msg = result.first_text().unwrap().to_lowercase();
1760 assert!(
1761 msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
1762 "Expected timeout error, got: {}",
1763 msg
1764 );
1765 }
1766
1767 #[tokio::test]
1768 async fn test_extractor_handler_with_multiple_layers() {
1769 use crate::ToolBuilder;
1770 use std::time::Duration;
1771 use tower::limit::ConcurrencyLimitLayer;
1772 use tower::timeout::TimeoutLayer;
1773
1774 let state = Arc::new("multi".to_string());
1775
1776 let tool = ToolBuilder::new("test_multi_layer")
1777 .description("Test multiple layers")
1778 .extractor_handler(
1779 state,
1780 |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1781 Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
1782 },
1783 )
1784 .layer(TimeoutLayer::new(Duration::from_secs(5)))
1785 .layer(ConcurrencyLimitLayer::new(10))
1786 .build();
1787
1788 let result = tool
1789 .call(serde_json::json!({"name": "test", "count": 1}))
1790 .await;
1791 assert!(!result.is_error);
1792 assert_eq!(result.first_text().unwrap(), "multi: test");
1793 }
1794}