1use std::future::Future;
76use std::marker::PhantomData;
77use std::ops::Deref;
78use std::pin::Pin;
79
80use schemars::JsonSchema;
81use serde::de::DeserializeOwned;
82use serde_json::Value;
83
84use crate::context::RequestContext;
85use crate::error::{Error, Result};
86use crate::protocol::CallToolResult;
87
/// A plain, message-only rejection returned by extractors whose failure
/// needs no extra structure (and used as the never-produced rejection type
/// of infallible extractors).
#[derive(Debug, Clone)]
pub struct Rejection {
    message: String,
}

impl Rejection {
    /// Builds a rejection carrying `message` as its human-readable text.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }

    /// Borrows the message this rejection was created with.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl std::fmt::Display for Rejection {
    // Rendered as the bare message, with no prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for Rejection {}
123
124impl From<Rejection> for Error {
125 fn from(rejection: Rejection) -> Self {
126 Error::tool(rejection.message)
127 }
128}
129
/// Rejection produced by the [`Json`] extractor when the tool arguments
/// cannot be deserialized into the requested type.
///
/// An optional `path` (e.g. `users[0].name`) pinpoints the offending
/// location in the input. When present, it is included in the `Display`
/// output.
#[derive(Debug, Clone)]
pub struct JsonRejection {
    message: String,
    path: Option<String>,
}

impl JsonRejection {
    /// Builds a rejection with a message and no location information.
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            path: None,
            message: message.into(),
        }
    }

    /// Builds a rejection whose error occurred at `path` in the input.
    pub fn with_path(message: impl Into<String>, path: impl Into<String>) -> Self {
        let mut rejection = Self::new(message);
        rejection.path = Some(path.into());
        rejection
    }

    /// Borrows the underlying error message (without the path prefix).
    pub fn message(&self) -> &str {
        self.message.as_str()
    }

    /// Borrows the location of the error, if one was recorded.
    pub fn path(&self) -> Option<&str> {
        self.path.as_deref()
    }
}

impl std::fmt::Display for JsonRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.path() {
            Some(path) => write!(f, "Invalid input at `{}`: {}", path, self.message()),
            None => write!(f, "Invalid input: {}", self.message()),
        }
    }
}

impl std::error::Error for JsonRejection {}
189
190impl From<JsonRejection> for Error {
191 fn from(rejection: JsonRejection) -> Self {
192 Error::tool(rejection.to_string())
193 }
194}
195
196impl From<serde_json::Error> for JsonRejection {
197 fn from(err: serde_json::Error) -> Self {
198 let path = if err.is_data() {
200 None
203 } else {
204 None
205 };
206
207 Self {
208 message: err.to_string(),
209 path,
210 }
211 }
212}
213
/// Rejection produced by the [`Extension`] extractor when no value of the
/// requested type is present in the request context's extensions.
#[derive(Debug, Clone)]
pub struct ExtensionRejection {
    type_name: &'static str,
}

impl ExtensionRejection {
    /// Builds a rejection for a missing extension of type `T`.
    ///
    /// `T` may be unsized (e.g. `str` or a `dyn Trait`), because only its
    /// type name is recorded via [`std::any::type_name`].
    pub fn not_found<T: ?Sized>() -> Self {
        Self {
            type_name: std::any::type_name::<T>(),
        }
    }

    /// Returns the name of the type that could not be found.
    ///
    /// The value comes from [`std::any::type_name`] and is meant for
    /// diagnostics only; its exact format is not guaranteed by std.
    pub fn type_name(&self) -> &'static str {
        self.type_name
    }
}

impl std::fmt::Display for ExtensionRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Extension of type `{}` not found. Did you call `router.with_state()` or `router.with_extension()`?",
            self.type_name
        )
    }
}

impl std::error::Error for ExtensionRejection {}
257
258impl From<ExtensionRejection> for Error {
259 fn from(rejection: ExtensionRejection) -> Self {
260 Error::tool(rejection.to_string())
261 }
262}
263
264pub trait FromToolRequest<S = ()>: Sized {
295 type Rejection: Into<Error>;
297
298 fn from_tool_request(
306 ctx: &RequestContext,
307 state: &S,
308 args: &Value,
309 ) -> std::result::Result<Self, Self::Rejection>;
310}
311
/// Extractor that deserializes the tool's JSON arguments into `T`.
///
/// The inner value is public, so it can be destructured directly in a
/// handler signature (`Json(input): Json<MyInput>`) or reached through
/// `Deref`.
#[derive(Debug, Clone, Copy)]
pub struct Json<T>(pub T);

impl<T> Deref for Json<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Json(inner) = self;
        inner
    }
}
352
353impl<S, T> FromToolRequest<S> for Json<T>
354where
355 T: DeserializeOwned,
356{
357 type Rejection = JsonRejection;
358
359 fn from_tool_request(
360 _ctx: &RequestContext,
361 _state: &S,
362 args: &Value,
363 ) -> std::result::Result<Self, Self::Rejection> {
364 serde_json::from_value(args.clone())
365 .map(Json)
366 .map_err(JsonRejection::from)
367 }
368}
369
/// Extractor yielding a clone of the shared state the tool was built with.
///
/// The inner value is public for destructuring (`State(db): State<Db>`) and
/// is also reachable through `Deref`.
#[derive(Debug, Clone, Copy)]
pub struct State<T>(pub T);

impl<T> Deref for State<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let State(inner) = self;
        inner
    }
}
404
405impl<S: Clone> FromToolRequest<S> for State<S> {
406 type Rejection = Rejection;
407
408 fn from_tool_request(
409 _ctx: &RequestContext,
410 state: &S,
411 _args: &Value,
412 ) -> std::result::Result<Self, Self::Rejection> {
413 Ok(State(state.clone()))
414 }
415}
416
417#[derive(Debug, Clone)]
438pub struct Context(RequestContext);
439
440impl Context {
441 pub fn into_inner(self) -> RequestContext {
443 self.0
444 }
445}
446
447impl Deref for Context {
448 type Target = RequestContext;
449
450 fn deref(&self) -> &Self::Target {
451 &self.0
452 }
453}
454
455impl<S> FromToolRequest<S> for Context {
456 type Rejection = Rejection;
457
458 fn from_tool_request(
459 ctx: &RequestContext,
460 _state: &S,
461 _args: &Value,
462 ) -> std::result::Result<Self, Self::Rejection> {
463 Ok(Context(ctx.clone()))
464 }
465}
466
467#[derive(Debug, Clone)]
485pub struct RawArgs(pub Value);
486
487impl Deref for RawArgs {
488 type Target = Value;
489
490 fn deref(&self) -> &Self::Target {
491 &self.0
492 }
493}
494
495impl<S> FromToolRequest<S> for RawArgs {
496 type Rejection = Rejection;
497
498 fn from_tool_request(
499 _ctx: &RequestContext,
500 _state: &S,
501 args: &Value,
502 ) -> std::result::Result<Self, Self::Rejection> {
503 Ok(RawArgs(args.clone()))
504 }
505}
506
/// Extractor yielding a clone of a typed value stored in the request
/// context's extensions.
///
/// The inner value is public for destructuring (`Extension(pool)`) and is
/// also reachable through `Deref`.
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);

impl<T> Deref for Extension<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Extension(inner) = self;
        inner
    }
}
563
564impl<S, T> FromToolRequest<S> for Extension<T>
565where
566 T: Clone + Send + Sync + 'static,
567{
568 type Rejection = ExtensionRejection;
569
570 fn from_tool_request(
571 ctx: &RequestContext,
572 _state: &S,
573 _args: &Value,
574 ) -> std::result::Result<Self, Self::Rejection> {
575 ctx.extension::<T>()
576 .cloned()
577 .map(Extension)
578 .ok_or_else(ExtensionRejection::not_found::<T>)
579 }
580}
581
582pub trait ExtractorHandler<S, T>: Clone + Send + Sync + 'static {
592 type Future: Future<Output = Result<CallToolResult>> + Send;
594
595 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
597
598 fn input_schema() -> Value;
602}
603
604impl<S, F, Fut, T1> ExtractorHandler<S, (T1,)> for F
606where
607 S: Clone + Send + Sync + 'static,
608 F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
609 Fut: Future<Output = Result<CallToolResult>> + Send,
610 T1: FromToolRequest<S> + HasSchema + Send,
611{
612 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
613
614 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
615 Box::pin(async move {
616 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
617 self(t1).await
618 })
619 }
620
621 fn input_schema() -> Value {
622 if let Some(schema) = T1::schema() {
623 return schema;
624 }
625 serde_json::json!({
626 "type": "object",
627 "additionalProperties": true
628 })
629 }
630}
631
632impl<S, F, Fut, T1, T2> ExtractorHandler<S, (T1, T2)> for F
634where
635 S: Clone + Send + Sync + 'static,
636 F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
637 Fut: Future<Output = Result<CallToolResult>> + Send,
638 T1: FromToolRequest<S> + HasSchema + Send,
639 T2: FromToolRequest<S> + HasSchema + Send,
640{
641 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
642
643 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
644 Box::pin(async move {
645 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
646 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
647 self(t1, t2).await
648 })
649 }
650
651 fn input_schema() -> Value {
652 if let Some(schema) = T2::schema() {
653 return schema;
654 }
655 if let Some(schema) = T1::schema() {
656 return schema;
657 }
658 serde_json::json!({
659 "type": "object",
660 "additionalProperties": true
661 })
662 }
663}
664
665impl<S, F, Fut, T1, T2, T3> ExtractorHandler<S, (T1, T2, T3)> for F
667where
668 S: Clone + Send + Sync + 'static,
669 F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
670 Fut: Future<Output = Result<CallToolResult>> + Send,
671 T1: FromToolRequest<S> + HasSchema + Send,
672 T2: FromToolRequest<S> + HasSchema + Send,
673 T3: FromToolRequest<S> + HasSchema + Send,
674{
675 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
676
677 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
678 Box::pin(async move {
679 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
680 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
681 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
682 self(t1, t2, t3).await
683 })
684 }
685
686 fn input_schema() -> Value {
687 if let Some(schema) = T3::schema() {
688 return schema;
689 }
690 if let Some(schema) = T2::schema() {
691 return schema;
692 }
693 if let Some(schema) = T1::schema() {
694 return schema;
695 }
696 serde_json::json!({
697 "type": "object",
698 "additionalProperties": true
699 })
700 }
701}
702
703impl<S, F, Fut, T1, T2, T3, T4> ExtractorHandler<S, (T1, T2, T3, T4)> for F
705where
706 S: Clone + Send + Sync + 'static,
707 F: Fn(T1, T2, T3, T4) -> Fut + Clone + Send + Sync + 'static,
708 Fut: Future<Output = Result<CallToolResult>> + Send,
709 T1: FromToolRequest<S> + HasSchema + Send,
710 T2: FromToolRequest<S> + HasSchema + Send,
711 T3: FromToolRequest<S> + HasSchema + Send,
712 T4: FromToolRequest<S> + HasSchema + Send,
713{
714 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
715
716 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
717 Box::pin(async move {
718 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
719 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
720 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
721 let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
722 self(t1, t2, t3, t4).await
723 })
724 }
725
726 fn input_schema() -> Value {
727 if let Some(schema) = T4::schema() {
728 return schema;
729 }
730 if let Some(schema) = T3::schema() {
731 return schema;
732 }
733 if let Some(schema) = T2::schema() {
734 return schema;
735 }
736 if let Some(schema) = T1::schema() {
737 return schema;
738 }
739 serde_json::json!({
740 "type": "object",
741 "additionalProperties": true
742 })
743 }
744}
745
746impl<S, F, Fut, T1, T2, T3, T4, T5> ExtractorHandler<S, (T1, T2, T3, T4, T5)> for F
748where
749 S: Clone + Send + Sync + 'static,
750 F: Fn(T1, T2, T3, T4, T5) -> Fut + Clone + Send + Sync + 'static,
751 Fut: Future<Output = Result<CallToolResult>> + Send,
752 T1: FromToolRequest<S> + HasSchema + Send,
753 T2: FromToolRequest<S> + HasSchema + Send,
754 T3: FromToolRequest<S> + HasSchema + Send,
755 T4: FromToolRequest<S> + HasSchema + Send,
756 T5: FromToolRequest<S> + HasSchema + Send,
757{
758 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
759
760 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
761 Box::pin(async move {
762 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
763 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
764 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
765 let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
766 let t5 = T5::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
767 self(t1, t2, t3, t4, t5).await
768 })
769 }
770
771 fn input_schema() -> Value {
772 if let Some(schema) = T5::schema() {
773 return schema;
774 }
775 if let Some(schema) = T4::schema() {
776 return schema;
777 }
778 if let Some(schema) = T3::schema() {
779 return schema;
780 }
781 if let Some(schema) = T2::schema() {
782 return schema;
783 }
784 if let Some(schema) = T1::schema() {
785 return schema;
786 }
787 serde_json::json!({
788 "type": "object",
789 "additionalProperties": true
790 })
791 }
792}
793
794pub trait HasSchema {
800 fn schema() -> Option<Value>;
802}
803
804impl<T: JsonSchema> HasSchema for Json<T> {
805 fn schema() -> Option<Value> {
806 let schema = schemars::schema_for!(T);
807 serde_json::to_value(schema)
808 .ok()
809 .map(crate::tool::ensure_object_schema)
810 }
811}
812
813impl HasSchema for Context {
815 fn schema() -> Option<Value> {
816 None
817 }
818}
819
820impl HasSchema for RawArgs {
821 fn schema() -> Option<Value> {
822 None
823 }
824}
825
826impl<T> HasSchema for State<T> {
827 fn schema() -> Option<Value> {
828 None
829 }
830}
831
832impl<T> HasSchema for Extension<T> {
833 fn schema() -> Option<Value> {
834 None
835 }
836}
837
838#[deprecated(
847 since = "0.8.0",
848 note = "Use `ExtractorHandler` instead -- `extractor_handler` auto-detects JSON schema from `Json<T>` extractors"
849)]
850pub trait TypedExtractorHandler<S, T, I>: Clone + Send + Sync + 'static
851where
852 I: JsonSchema,
853{
854 type Future: Future<Output = Result<CallToolResult>> + Send;
856
857 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
859}
860
861#[allow(deprecated)]
863impl<S, F, Fut, T> TypedExtractorHandler<S, (Json<T>,), T> for F
864where
865 S: Clone + Send + Sync + 'static,
866 F: Fn(Json<T>) -> Fut + Clone + Send + Sync + 'static,
867 Fut: Future<Output = Result<CallToolResult>> + Send,
868 T: DeserializeOwned + JsonSchema + Send,
869{
870 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
871
872 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
873 Box::pin(async move {
874 let t1 =
875 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
876 self(t1).await
877 })
878 }
879}
880
881#[allow(deprecated)]
883impl<S, F, Fut, T1, T> TypedExtractorHandler<S, (T1, Json<T>), T> for F
884where
885 S: Clone + Send + Sync + 'static,
886 F: Fn(T1, Json<T>) -> Fut + Clone + Send + Sync + 'static,
887 Fut: Future<Output = Result<CallToolResult>> + Send,
888 T1: FromToolRequest<S> + Send,
889 T: DeserializeOwned + JsonSchema + Send,
890{
891 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
892
893 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
894 Box::pin(async move {
895 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
896 let t2 =
897 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
898 self(t1, t2).await
899 })
900 }
901}
902
903#[allow(deprecated)]
905impl<S, F, Fut, T1, T2, T> TypedExtractorHandler<S, (T1, T2, Json<T>), T> for F
906where
907 S: Clone + Send + Sync + 'static,
908 F: Fn(T1, T2, Json<T>) -> Fut + Clone + Send + Sync + 'static,
909 Fut: Future<Output = Result<CallToolResult>> + Send,
910 T1: FromToolRequest<S> + Send,
911 T2: FromToolRequest<S> + Send,
912 T: DeserializeOwned + JsonSchema + Send,
913{
914 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
915
916 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
917 Box::pin(async move {
918 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
919 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
920 let t3 =
921 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
922 self(t1, t2, t3).await
923 })
924 }
925}
926
927#[allow(deprecated)]
929impl<S, F, Fut, T1, T2, T3, T> TypedExtractorHandler<S, (T1, T2, T3, Json<T>), T> for F
930where
931 S: Clone + Send + Sync + 'static,
932 F: Fn(T1, T2, T3, Json<T>) -> Fut + Clone + Send + Sync + 'static,
933 Fut: Future<Output = Result<CallToolResult>> + Send,
934 T1: FromToolRequest<S> + Send,
935 T2: FromToolRequest<S> + Send,
936 T3: FromToolRequest<S> + Send,
937 T: DeserializeOwned + JsonSchema + Send,
938{
939 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
940
941 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
942 Box::pin(async move {
943 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
944 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
945 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
946 let t4 =
947 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
948 self(t1, t2, t3, t4).await
949 })
950 }
951}
952
953use crate::tool::{
958 BoxFuture, GuardLayer, Tool, ToolCatchError, ToolHandler, ToolHandlerService, ToolRequest,
959};
960use tower::util::BoxCloneService;
961use tower_service::Service;
962
963pub(crate) struct ExtractorToolHandler<S, F, T> {
965 state: S,
966 handler: F,
967 input_schema: Value,
968 _phantom: PhantomData<T>,
969}
970
971impl<S, F, T> ToolHandler for ExtractorToolHandler<S, F, T>
972where
973 S: Clone + Send + Sync + 'static,
974 F: ExtractorHandler<S, T> + Clone,
975 T: Send + Sync + 'static,
976{
977 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
978 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
979 self.call_with_context(ctx, args)
980 }
981
982 fn call_with_context(
983 &self,
984 ctx: RequestContext,
985 args: Value,
986 ) -> BoxFuture<'_, Result<CallToolResult>> {
987 let state = self.state.clone();
988 let handler = self.handler.clone();
989 Box::pin(async move { handler.call(ctx, state, args).await })
990 }
991
992 fn uses_context(&self) -> bool {
993 true
994 }
995
996 fn input_schema(&self) -> Value {
997 self.input_schema.clone()
998 }
999}
1000
1001#[doc(hidden)]
1003pub struct ToolBuilderWithExtractor<S, F, T> {
1004 pub(crate) name: String,
1005 pub(crate) title: Option<String>,
1006 pub(crate) description: Option<String>,
1007 pub(crate) output_schema: Option<Value>,
1008 pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
1009 pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
1010 pub(crate) task_support: crate::protocol::TaskSupportMode,
1011 pub(crate) state: S,
1012 pub(crate) handler: F,
1013 pub(crate) input_schema: Value,
1014 pub(crate) _phantom: PhantomData<T>,
1015}
1016
1017impl<S, F, T> ToolBuilderWithExtractor<S, F, T>
1018where
1019 S: Clone + Send + Sync + 'static,
1020 F: ExtractorHandler<S, T> + Clone,
1021 T: Send + Sync + 'static,
1022{
1023 pub fn build(self) -> Tool {
1025 let handler = ExtractorToolHandler {
1026 state: self.state,
1027 handler: self.handler,
1028 input_schema: self.input_schema.clone(),
1029 _phantom: PhantomData,
1030 };
1031
1032 let handler_service = ToolHandlerService::new(handler);
1033 let catch_error = ToolCatchError::new(handler_service);
1034 let service = BoxCloneService::new(catch_error);
1035
1036 Tool {
1037 name: self.name,
1038 title: self.title,
1039 description: self.description,
1040 output_schema: self.output_schema,
1041 icons: self.icons,
1042 annotations: self.annotations,
1043 task_support: self.task_support,
1044 service,
1045 input_schema: self.input_schema,
1046 }
1047 }
1048
1049 pub fn layer<L>(self, layer: L) -> ToolBuilderWithExtractorLayer<S, F, T, L> {
1085 ToolBuilderWithExtractorLayer {
1086 name: self.name,
1087 title: self.title,
1088 description: self.description,
1089 output_schema: self.output_schema,
1090 icons: self.icons,
1091 annotations: self.annotations,
1092 task_support: self.task_support,
1093 state: self.state,
1094 handler: self.handler,
1095 input_schema: self.input_schema,
1096 layer,
1097 _phantom: PhantomData,
1098 }
1099 }
1100
1101 pub fn guard<G>(self, guard: G) -> ToolBuilderWithExtractorLayer<S, F, T, GuardLayer<G>>
1105 where
1106 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1107 {
1108 self.layer(GuardLayer::new(guard))
1109 }
1110}
1111
1112#[doc(hidden)]
1116pub struct ToolBuilderWithExtractorLayer<S, F, T, L> {
1117 name: String,
1118 title: Option<String>,
1119 description: Option<String>,
1120 output_schema: Option<Value>,
1121 icons: Option<Vec<crate::protocol::ToolIcon>>,
1122 annotations: Option<crate::protocol::ToolAnnotations>,
1123 task_support: crate::protocol::TaskSupportMode,
1124 state: S,
1125 handler: F,
1126 input_schema: Value,
1127 layer: L,
1128 _phantom: PhantomData<T>,
1129}
1130
1131#[allow(private_bounds)]
1132impl<S, F, T, L> ToolBuilderWithExtractorLayer<S, F, T, L>
1133where
1134 S: Clone + Send + Sync + 'static,
1135 F: ExtractorHandler<S, T> + Clone,
1136 T: Send + Sync + 'static,
1137 L: tower::Layer<ToolHandlerService<ExtractorToolHandler<S, F, T>>>
1138 + Clone
1139 + Send
1140 + Sync
1141 + 'static,
1142 L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1143 <L::Service as Service<ToolRequest>>::Error: std::fmt::Display + Send,
1144 <L::Service as Service<ToolRequest>>::Future: Send,
1145{
1146 pub fn build(self) -> Tool {
1148 let handler = ExtractorToolHandler {
1149 state: self.state,
1150 handler: self.handler,
1151 input_schema: self.input_schema.clone(),
1152 _phantom: PhantomData,
1153 };
1154
1155 let handler_service = ToolHandlerService::new(handler);
1156 let layered = self.layer.layer(handler_service);
1157 let catch_error = ToolCatchError::new(layered);
1158 let service = BoxCloneService::new(catch_error);
1159
1160 Tool {
1161 name: self.name,
1162 title: self.title,
1163 description: self.description,
1164 output_schema: self.output_schema,
1165 icons: self.icons,
1166 annotations: self.annotations,
1167 task_support: self.task_support,
1168 service,
1169 input_schema: self.input_schema,
1170 }
1171 }
1172
1173 pub fn layer<L2>(
1178 self,
1179 layer: L2,
1180 ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<L2, L>> {
1181 ToolBuilderWithExtractorLayer {
1182 name: self.name,
1183 title: self.title,
1184 description: self.description,
1185 output_schema: self.output_schema,
1186 icons: self.icons,
1187 annotations: self.annotations,
1188 task_support: self.task_support,
1189 state: self.state,
1190 handler: self.handler,
1191 input_schema: self.input_schema,
1192 layer: tower::layer::util::Stack::new(layer, self.layer),
1193 _phantom: PhantomData,
1194 }
1195 }
1196
1197 pub fn guard<G>(
1201 self,
1202 guard: G,
1203 ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<GuardLayer<G>, L>>
1204 where
1205 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1206 {
1207 self.layer(GuardLayer::new(guard))
1208 }
1209}
1210
1211#[doc(hidden)]
1213#[deprecated(
1214 since = "0.8.0",
1215 note = "Use `ToolBuilderWithExtractor` via `extractor_handler` instead"
1216)]
1217pub struct ToolBuilderWithTypedExtractor<S, F, T, I> {
1218 pub(crate) name: String,
1219 pub(crate) title: Option<String>,
1220 pub(crate) description: Option<String>,
1221 pub(crate) output_schema: Option<Value>,
1222 pub(crate) input_schema_override: Option<Value>,
1223 pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
1224 pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
1225 pub(crate) task_support: crate::protocol::TaskSupportMode,
1226 pub(crate) state: S,
1227 pub(crate) handler: F,
1228 pub(crate) _phantom: PhantomData<(T, I)>,
1229}
1230
1231#[allow(deprecated)]
1232impl<S, F, T, I> ToolBuilderWithTypedExtractor<S, F, T, I>
1233where
1234 S: Clone + Send + Sync + 'static,
1235 F: TypedExtractorHandler<S, T, I> + Clone,
1236 T: Send + Sync + 'static,
1237 I: JsonSchema + Send + Sync + 'static,
1238{
1239 pub fn build(self) -> Tool {
1241 let input_schema = {
1242 let schema = self.input_schema_override.unwrap_or_else(|| {
1243 let schema = schemars::schema_for!(I);
1244 serde_json::to_value(schema).unwrap_or_else(|_| {
1245 serde_json::json!({
1246 "type": "object"
1247 })
1248 })
1249 });
1250 crate::tool::ensure_object_schema(schema)
1251 };
1252
1253 let handler = TypedExtractorToolHandler {
1254 state: self.state,
1255 handler: self.handler,
1256 input_schema: input_schema.clone(),
1257 _phantom: PhantomData,
1258 };
1259
1260 let handler_service = crate::tool::ToolHandlerService::new(handler);
1261 let catch_error = ToolCatchError::new(handler_service);
1262 let service = BoxCloneService::new(catch_error);
1263
1264 Tool {
1265 name: self.name,
1266 title: self.title,
1267 description: self.description,
1268 output_schema: self.output_schema,
1269 icons: self.icons,
1270 annotations: self.annotations,
1271 task_support: self.task_support,
1272 service,
1273 input_schema,
1274 }
1275 }
1276}
1277
1278struct TypedExtractorToolHandler<S, F, T, I> {
1280 state: S,
1281 handler: F,
1282 input_schema: Value,
1283 _phantom: PhantomData<(T, I)>,
1284}
1285
1286#[allow(deprecated)]
1287impl<S, F, T, I> ToolHandler for TypedExtractorToolHandler<S, F, T, I>
1288where
1289 S: Clone + Send + Sync + 'static,
1290 F: TypedExtractorHandler<S, T, I> + Clone,
1291 T: Send + Sync + 'static,
1292 I: JsonSchema + Send + Sync + 'static,
1293{
1294 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1295 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
1296 self.call_with_context(ctx, args)
1297 }
1298
1299 fn call_with_context(
1300 &self,
1301 ctx: RequestContext,
1302 args: Value,
1303 ) -> BoxFuture<'_, Result<CallToolResult>> {
1304 let state = self.state.clone();
1305 let handler = self.handler.clone();
1306 Box::pin(async move { handler.call(ctx, state, args).await })
1307 }
1308
1309 fn uses_context(&self) -> bool {
1310 true
1311 }
1312
1313 fn input_schema(&self) -> Value {
1314 self.input_schema.clone()
1315 }
1316}
1317
1318#[cfg(test)]
1319mod tests {
1320 use super::*;
1321 use crate::protocol::RequestId;
1322 use schemars::JsonSchema;
1323 use serde::Deserialize;
1324 use std::sync::Arc;
1325
1326 #[derive(Debug, Deserialize, JsonSchema)]
1327 struct TestInput {
1328 name: String,
1329 count: i32,
1330 }
1331
1332 #[test]
1333 fn test_json_extraction() {
1334 let args = serde_json::json!({"name": "test", "count": 42});
1335 let ctx = RequestContext::new(RequestId::Number(1));
1336
1337 let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
1338 assert!(result.is_ok());
1339 let Json(input) = result.unwrap();
1340 assert_eq!(input.name, "test");
1341 assert_eq!(input.count, 42);
1342 }
1343
1344 #[test]
1345 fn test_json_extraction_error() {
1346 let args = serde_json::json!({"name": "test"}); let ctx = RequestContext::new(RequestId::Number(1));
1348
1349 let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
1350 assert!(result.is_err());
1351 let rejection = result.unwrap_err();
1352 assert!(rejection.message().contains("count"));
1354 }
1355
1356 #[test]
1357 fn test_state_extraction() {
1358 let args = serde_json::json!({});
1359 let ctx = RequestContext::new(RequestId::Number(1));
1360 let state = Arc::new("my-state".to_string());
1361
1362 let result = State::<Arc<String>>::from_tool_request(&ctx, &state, &args);
1363 assert!(result.is_ok());
1364 let State(extracted) = result.unwrap();
1365 assert_eq!(*extracted, "my-state");
1366 }
1367
1368 #[test]
1369 fn test_context_extraction() {
1370 let args = serde_json::json!({});
1371 let ctx = RequestContext::new(RequestId::Number(42));
1372
1373 let result = Context::from_tool_request(&ctx, &(), &args);
1374 assert!(result.is_ok());
1375 let extracted = result.unwrap();
1376 assert_eq!(*extracted.request_id(), RequestId::Number(42));
1377 }
1378
1379 #[test]
1380 fn test_raw_args_extraction() {
1381 let args = serde_json::json!({"foo": "bar", "baz": 123});
1382 let ctx = RequestContext::new(RequestId::Number(1));
1383
1384 let result = RawArgs::from_tool_request(&ctx, &(), &args);
1385 assert!(result.is_ok());
1386 let RawArgs(extracted) = result.unwrap();
1387 assert_eq!(extracted["foo"], "bar");
1388 assert_eq!(extracted["baz"], 123);
1389 }
1390
1391 #[test]
1392 fn test_extension_extraction() {
1393 use crate::context::Extensions;
1394
1395 #[derive(Clone, Debug, PartialEq)]
1396 struct DatabasePool {
1397 url: String,
1398 }
1399
1400 let args = serde_json::json!({});
1401
1402 let mut extensions = Extensions::new();
1404 extensions.insert(Arc::new(DatabasePool {
1405 url: "postgres://localhost".to_string(),
1406 }));
1407
1408 let ctx = RequestContext::new(RequestId::Number(1)).with_extensions(Arc::new(extensions));
1410
1411 let result = Extension::<Arc<DatabasePool>>::from_tool_request(&ctx, &(), &args);
1413 assert!(result.is_ok());
1414 let Extension(pool) = result.unwrap();
1415 assert_eq!(pool.url, "postgres://localhost");
1416 }
1417
1418 #[test]
1419 fn test_extension_extraction_missing() {
1420 #[derive(Clone, Debug)]
1421 struct NotPresent;
1422
1423 let args = serde_json::json!({});
1424 let ctx = RequestContext::new(RequestId::Number(1));
1425
1426 let result = Extension::<NotPresent>::from_tool_request(&ctx, &(), &args);
1428 assert!(result.is_err());
1429 let rejection = result.unwrap_err();
1430 assert!(rejection.type_name().contains("NotPresent"));
1432 }
1433
1434 #[tokio::test]
1435 async fn test_single_extractor_handler() {
1436 let handler = |Json(input): Json<TestInput>| async move {
1437 Ok(CallToolResult::text(format!(
1438 "{}: {}",
1439 input.name, input.count
1440 )))
1441 };
1442
1443 let ctx = RequestContext::new(RequestId::Number(1));
1444 let args = serde_json::json!({"name": "test", "count": 5});
1445
1446 let result: Result<CallToolResult> =
1448 ExtractorHandler::<(), (Json<TestInput>,)>::call(handler, ctx, (), args).await;
1449 assert!(result.is_ok());
1450 }
1451
1452 #[tokio::test]
1453 async fn test_two_extractor_handler() {
1454 let handler = |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1455 Ok(CallToolResult::text(format!(
1456 "{}: {} - {}",
1457 state, input.name, input.count
1458 )))
1459 };
1460
1461 let ctx = RequestContext::new(RequestId::Number(1));
1462 let state = Arc::new("prefix".to_string());
1463 let args = serde_json::json!({"name": "test", "count": 5});
1464
1465 let result: Result<CallToolResult> = ExtractorHandler::<
1467 Arc<String>,
1468 (State<Arc<String>>, Json<TestInput>),
1469 >::call(handler, ctx, state, args)
1470 .await;
1471 assert!(result.is_ok());
1472 }
1473
1474 #[tokio::test]
1475 async fn test_three_extractor_handler() {
1476 let handler = |State(state): State<Arc<String>>,
1477 ctx: Context,
1478 Json(input): Json<TestInput>| async move {
1479 assert!(!ctx.is_cancelled());
1481 Ok(CallToolResult::text(format!(
1482 "{}: {} - {}",
1483 state, input.name, input.count
1484 )))
1485 };
1486
1487 let ctx = RequestContext::new(RequestId::Number(1));
1488 let state = Arc::new("prefix".to_string());
1489 let args = serde_json::json!({"name": "test", "count": 5});
1490
1491 let result: Result<CallToolResult> = ExtractorHandler::<
1493 Arc<String>,
1494 (State<Arc<String>>, Context, Json<TestInput>),
1495 >::call(handler, ctx, state, args)
1496 .await;
1497 assert!(result.is_ok());
1498 }
1499
1500 #[test]
1501 fn test_json_schema_generation() {
1502 let schema = Json::<TestInput>::schema();
1503 assert!(schema.is_some());
1504 let schema = schema.unwrap();
1505 assert!(schema.get("properties").is_some());
1506 }
1507
1508 #[test]
1509 fn test_rejection_into_error() {
1510 let rejection = Rejection::new("test error");
1511 let error: Error = rejection.into();
1512 assert!(error.to_string().contains("test error"));
1513 }
1514
1515 #[test]
1516 fn test_json_rejection() {
1517 let rejection = JsonRejection::new("missing field `name`");
1519 assert_eq!(rejection.message(), "missing field `name`");
1520 assert!(rejection.path().is_none());
1521 assert!(rejection.to_string().contains("Invalid input"));
1522
1523 let rejection = JsonRejection::with_path("expected string", "users[0].name");
1525 assert_eq!(rejection.message(), "expected string");
1526 assert_eq!(rejection.path(), Some("users[0].name"));
1527 assert!(rejection.to_string().contains("users[0].name"));
1528
1529 let error: Error = rejection.into();
1531 assert!(error.to_string().contains("users[0].name"));
1532 }
1533
1534 #[test]
1535 fn test_json_rejection_from_serde_error() {
1536 #[derive(Debug, serde::Deserialize)]
1538 struct TestStruct {
1539 #[allow(dead_code)]
1540 name: String,
1541 }
1542
1543 let result: std::result::Result<TestStruct, _> =
1544 serde_json::from_value(serde_json::json!({"count": 42}));
1545 assert!(result.is_err());
1546
1547 let rejection: JsonRejection = result.unwrap_err().into();
1548 assert!(rejection.message().contains("name"));
1549 }
1550
1551 #[test]
1552 fn test_extension_rejection() {
1553 let rejection = ExtensionRejection::not_found::<String>();
1555 assert!(rejection.type_name().contains("String"));
1556 assert!(rejection.to_string().contains("not found"));
1557 assert!(rejection.to_string().contains("with_state"));
1558
1559 let error: Error = rejection.into();
1561 assert!(error.to_string().contains("not found"));
1562 }
1563
1564 #[tokio::test]
1565 async fn test_tool_builder_extractor_handler() {
1566 use crate::ToolBuilder;
1567
1568 let state = Arc::new("shared-state".to_string());
1569
1570 let tool =
1571 ToolBuilder::new("test_extractor")
1572 .description("Test extractor handler")
1573 .extractor_handler(
1574 state,
1575 |State(state): State<Arc<String>>,
1576 ctx: Context,
1577 Json(input): Json<TestInput>| async move {
1578 assert!(!ctx.is_cancelled());
1579 Ok(CallToolResult::text(format!(
1580 "{}: {} - {}",
1581 state, input.name, input.count
1582 )))
1583 },
1584 )
1585 .build();
1586
1587 assert_eq!(tool.name, "test_extractor");
1588 assert_eq!(tool.description.as_deref(), Some("Test extractor handler"));
1589
1590 let result = tool
1592 .call(serde_json::json!({"name": "test", "count": 42}))
1593 .await;
1594 assert!(!result.is_error);
1595 }
1596
1597 #[tokio::test]
1598 #[allow(deprecated)]
1599 async fn test_tool_builder_extractor_handler_typed() {
1600 use crate::ToolBuilder;
1601
1602 let state = Arc::new("typed-state".to_string());
1603
1604 let tool = ToolBuilder::new("test_typed")
1605 .description("Test typed extractor handler")
1606 .extractor_handler_typed::<_, _, _, TestInput>(
1607 state,
1608 |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1609 Ok(CallToolResult::text(format!(
1610 "{}: {} - {}",
1611 state, input.name, input.count
1612 )))
1613 },
1614 )
1615 .build();
1616
1617 assert_eq!(tool.name, "test_typed");
1618
1619 let def = tool.definition();
1621 let schema = def.input_schema;
1622 assert!(schema.get("properties").is_some());
1623
1624 let result = tool
1626 .call(serde_json::json!({"name": "world", "count": 99}))
1627 .await;
1628 assert!(!result.is_error);
1629 }
1630
1631 #[tokio::test]
1632 async fn test_extractor_handler_auto_schema() {
1633 use crate::ToolBuilder;
1634
1635 let state = Arc::new("auto-schema".to_string());
1636
1637 let tool = ToolBuilder::new("test_auto_schema")
1639 .description("Test auto schema detection")
1640 .extractor_handler(
1641 state,
1642 |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1643 Ok(CallToolResult::text(format!(
1644 "{}: {} - {}",
1645 state, input.name, input.count
1646 )))
1647 },
1648 )
1649 .build();
1650
1651 let def = tool.definition();
1653 let schema = def.input_schema;
1654 assert!(
1655 schema.get("properties").is_some(),
1656 "Schema should have properties from TestInput, got: {}",
1657 schema
1658 );
1659 let props = schema.get("properties").unwrap();
1660 assert!(
1661 props.get("name").is_some(),
1662 "Schema should have 'name' property"
1663 );
1664 assert!(
1665 props.get("count").is_some(),
1666 "Schema should have 'count' property"
1667 );
1668
1669 let result = tool
1671 .call(serde_json::json!({"name": "world", "count": 99}))
1672 .await;
1673 assert!(!result.is_error);
1674 }
1675
1676 #[test]
1677 fn test_extractor_handler_no_json_fallback() {
1678 use crate::ToolBuilder;
1679
1680 let tool = ToolBuilder::new("test_no_json")
1682 .description("Test no json fallback")
1683 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1684 Ok(CallToolResult::json(args))
1685 })
1686 .build();
1687
1688 let def = tool.definition();
1689 let schema = def.input_schema;
1690 assert_eq!(
1691 schema.get("type").and_then(|v| v.as_str()),
1692 Some("object"),
1693 "Schema should be generic object"
1694 );
1695 assert_eq!(
1696 schema.get("additionalProperties").and_then(|v| v.as_bool()),
1697 Some(true),
1698 "Schema should allow additional properties"
1699 );
1700 assert!(
1702 schema.get("properties").is_none(),
1703 "Generic schema should not have specific properties"
1704 );
1705 }
1706
1707 #[tokio::test]
1708 async fn test_extractor_handler_with_layer() {
1709 use crate::ToolBuilder;
1710 use std::time::Duration;
1711 use tower::timeout::TimeoutLayer;
1712
1713 let state = Arc::new("layered".to_string());
1714
1715 let tool = ToolBuilder::new("test_extractor_layer")
1716 .description("Test extractor handler with layer")
1717 .extractor_handler(
1718 state,
1719 |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1720 Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
1721 },
1722 )
1723 .layer(TimeoutLayer::new(Duration::from_secs(5)))
1724 .build();
1725
1726 let result = tool
1728 .call(serde_json::json!({"name": "test", "count": 1}))
1729 .await;
1730 assert!(!result.is_error);
1731 assert_eq!(result.first_text().unwrap(), "layered: test");
1732
1733 let def = tool.definition();
1735 let schema = def.input_schema;
1736 assert!(
1737 schema.get("properties").is_some(),
1738 "Schema should have properties even with layer"
1739 );
1740 }
1741
1742 #[tokio::test]
1743 async fn test_extractor_handler_with_timeout_layer() {
1744 use crate::ToolBuilder;
1745 use std::time::Duration;
1746 use tower::timeout::TimeoutLayer;
1747
1748 let tool = ToolBuilder::new("test_extractor_timeout")
1749 .description("Test extractor handler timeout")
1750 .extractor_handler((), |Json(input): Json<TestInput>| async move {
1751 tokio::time::sleep(Duration::from_millis(200)).await;
1752 Ok(CallToolResult::text(input.name.to_string()))
1753 })
1754 .layer(TimeoutLayer::new(Duration::from_millis(50)))
1755 .build();
1756
1757 let result = tool
1759 .call(serde_json::json!({"name": "slow", "count": 1}))
1760 .await;
1761 assert!(result.is_error);
1762 let msg = result.first_text().unwrap().to_lowercase();
1763 assert!(
1764 msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
1765 "Expected timeout error, got: {}",
1766 msg
1767 );
1768 }
1769
1770 #[tokio::test]
1771 async fn test_extractor_handler_with_multiple_layers() {
1772 use crate::ToolBuilder;
1773 use std::time::Duration;
1774 use tower::limit::ConcurrencyLimitLayer;
1775 use tower::timeout::TimeoutLayer;
1776
1777 let state = Arc::new("multi".to_string());
1778
1779 let tool = ToolBuilder::new("test_multi_layer")
1780 .description("Test multiple layers")
1781 .extractor_handler(
1782 state,
1783 |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1784 Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
1785 },
1786 )
1787 .layer(TimeoutLayer::new(Duration::from_secs(5)))
1788 .layer(ConcurrencyLimitLayer::new(10))
1789 .build();
1790
1791 let result = tool
1792 .call(serde_json::json!({"name": "test", "count": 1}))
1793 .await;
1794 assert!(!result.is_error);
1795 assert_eq!(result.first_text().unwrap(), "multi: test");
1796 }
1797}