1use std::future::Future;
76use std::marker::PhantomData;
77use std::ops::Deref;
78use std::pin::Pin;
79
80use schemars::JsonSchema;
81use serde::de::DeserializeOwned;
82use serde_json::Value;
83
84use crate::context::RequestContext;
85use crate::error::{Error, Result};
86use crate::protocol::CallToolResult;
87
/// A generic extractor rejection that carries only a human-readable message.
///
/// Used as the `Rejection` type of extractors that never fail on their own
/// (state, context, raw arguments).
#[derive(Debug, Clone)]
pub struct Rejection {
    message: String,
}

impl Rejection {
    /// Creates a rejection with the given message.
    pub fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Rejection { message }
    }

    /// Returns the rejection message.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl std::fmt::Display for Rejection {
    /// Writes the bare message, without any prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.message())
    }
}

impl std::error::Error for Rejection {}
123
124impl From<Rejection> for Error {
125 fn from(rejection: Rejection) -> Self {
126 Error::tool(rejection.message)
127 }
128}
129
/// Rejection produced when the tool-call arguments cannot be deserialized into
/// the type requested by a `Json` extractor.
///
/// It optionally records the path of the offending value, and that path is
/// included in the `Display` output.
#[derive(Debug, Clone)]
pub struct JsonRejection {
    message: String,
    path: Option<String>,
}

impl JsonRejection {
    /// Creates a rejection with a message and no path.
    pub fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        JsonRejection {
            message,
            path: None,
        }
    }

    /// Creates a rejection with a message and the path of the offending value
    /// (e.g. `users[0].name`).
    pub fn with_path(message: impl Into<String>, path: impl Into<String>) -> Self {
        let mut rejection = Self::new(message);
        rejection.path = Some(path.into());
        rejection
    }

    /// Returns the underlying error message, without the path prefix.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }

    /// Returns the path of the offending value, if one was recorded.
    pub fn path(&self) -> Option<&str> {
        self.path.as_deref()
    }
}

impl std::fmt::Display for JsonRejection {
    /// Formats as ``Invalid input at `<path>`: <message>`` when a path is
    /// known, otherwise as `Invalid input: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.path.as_deref() {
            Some(path) => write!(f, "Invalid input at `{}`: {}", path, self.message),
            None => write!(f, "Invalid input: {}", self.message),
        }
    }
}

impl std::error::Error for JsonRejection {}
189
190impl From<JsonRejection> for Error {
191 fn from(rejection: JsonRejection) -> Self {
192 Error::tool(rejection.to_string())
193 }
194}
195
196impl From<serde_json::Error> for JsonRejection {
197 fn from(err: serde_json::Error) -> Self {
198 let path = if err.is_data() {
200 None
203 } else {
204 None
205 };
206
207 Self {
208 message: err.to_string(),
209 path,
210 }
211 }
212}
213
/// Rejection returned by the `Extension<T>` extractor when no value of type
/// `T` is present in the request context's extensions.
#[derive(Debug, Clone)]
pub struct ExtensionRejection {
    type_name: &'static str,
}

impl ExtensionRejection {
    /// Builds a rejection reporting that an extension of type `T` is missing.
    pub fn not_found<T>() -> Self {
        let type_name = std::any::type_name::<T>();
        ExtensionRejection { type_name }
    }

    /// Returns the `std::any::type_name` of the missing extension type.
    pub fn type_name(&self) -> &'static str {
        let Self { type_name } = *self;
        type_name
    }
}

impl std::fmt::Display for ExtensionRejection {
    /// Names the missing type and hints at how extensions are registered.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let missing = self.type_name();
        write!(
            f,
            "Extension of type `{}` not found. Did you call `router.with_state()` or `router.with_extension()`?",
            missing
        )
    }
}

impl std::error::Error for ExtensionRejection {}
257
258impl From<ExtensionRejection> for Error {
259 fn from(rejection: ExtensionRejection) -> Self {
260 Error::tool(rejection.to_string())
261 }
262}
263
/// Types that can be extracted from an incoming tool call.
///
/// An extractor receives the per-request [`RequestContext`], a reference to the
/// shared state `S` (defaults to `()`), and the raw JSON arguments. The
/// extractor handlers in this module call `from_tool_request` once per handler
/// parameter, left to right, against the same context, state and arguments.
/// The first rejection stops the call before the handler runs.
pub trait FromToolRequest<S = ()>: Sized {
    /// Error returned when extraction fails. It is converted into [`Error`]
    /// and returned in place of the handler's result.
    type Rejection: Into<Error>;

    /// Attempts to build `Self` from the request.
    ///
    /// # Errors
    ///
    /// Returns `Self::Rejection` when the request does not contain what this
    /// extractor needs.
    fn from_tool_request(
        ctx: &RequestContext,
        state: &S,
        args: &Value,
    ) -> std::result::Result<Self, Self::Rejection>;
}
311
312#[derive(Debug, Clone, Copy)]
343pub struct Json<T>(pub T);
344
345impl<T> Deref for Json<T> {
346 type Target = T;
347
348 fn deref(&self) -> &Self::Target {
349 &self.0
350 }
351}
352
353impl<S, T> FromToolRequest<S> for Json<T>
354where
355 T: DeserializeOwned,
356{
357 type Rejection = JsonRejection;
358
359 fn from_tool_request(
360 _ctx: &RequestContext,
361 _state: &S,
362 args: &Value,
363 ) -> std::result::Result<Self, Self::Rejection> {
364 serde_json::from_value(args.clone())
365 .map(Json)
366 .map_err(JsonRejection::from)
367 }
368}
369
370#[derive(Debug, Clone, Copy)]
395pub struct State<T>(pub T);
396
397impl<T> Deref for State<T> {
398 type Target = T;
399
400 fn deref(&self) -> &Self::Target {
401 &self.0
402 }
403}
404
405impl<S: Clone> FromToolRequest<S> for State<S> {
406 type Rejection = Rejection;
407
408 fn from_tool_request(
409 _ctx: &RequestContext,
410 state: &S,
411 _args: &Value,
412 ) -> std::result::Result<Self, Self::Rejection> {
413 Ok(State(state.clone()))
414 }
415}
416
417#[derive(Debug, Clone)]
438pub struct Context(RequestContext);
439
440impl Context {
441 pub fn into_inner(self) -> RequestContext {
443 self.0
444 }
445}
446
447impl Deref for Context {
448 type Target = RequestContext;
449
450 fn deref(&self) -> &Self::Target {
451 &self.0
452 }
453}
454
455impl<S> FromToolRequest<S> for Context {
456 type Rejection = Rejection;
457
458 fn from_tool_request(
459 ctx: &RequestContext,
460 _state: &S,
461 _args: &Value,
462 ) -> std::result::Result<Self, Self::Rejection> {
463 Ok(Context(ctx.clone()))
464 }
465}
466
467#[derive(Debug, Clone)]
485pub struct RawArgs(pub Value);
486
487impl Deref for RawArgs {
488 type Target = Value;
489
490 fn deref(&self) -> &Self::Target {
491 &self.0
492 }
493}
494
495impl<S> FromToolRequest<S> for RawArgs {
496 type Rejection = Rejection;
497
498 fn from_tool_request(
499 _ctx: &RequestContext,
500 _state: &S,
501 args: &Value,
502 ) -> std::result::Result<Self, Self::Rejection> {
503 Ok(RawArgs(args.clone()))
504 }
505}
506
507#[derive(Debug, Clone)]
554pub struct Extension<T>(pub T);
555
556impl<T> Deref for Extension<T> {
557 type Target = T;
558
559 fn deref(&self) -> &Self::Target {
560 &self.0
561 }
562}
563
564impl<S, T> FromToolRequest<S> for Extension<T>
565where
566 T: Clone + Send + Sync + 'static,
567{
568 type Rejection = ExtensionRejection;
569
570 fn from_tool_request(
571 ctx: &RequestContext,
572 _state: &S,
573 _args: &Value,
574 ) -> std::result::Result<Self, Self::Rejection> {
575 ctx.extension::<T>()
576 .cloned()
577 .map(Extension)
578 .ok_or_else(ExtensionRejection::not_found::<T>)
579 }
580}
581
/// A handler whose parameters are all [`FromToolRequest`] extractors. `T` is
/// the tuple of extractor types.
///
/// This module implements it for `Fn` closures and functions taking one to five
/// extractors and returning a future of `Result<CallToolResult>`.
pub trait ExtractorHandler<S, T>: Clone + Send + Sync + 'static {
    /// Future returned by [`ExtractorHandler::call`].
    type Future: Future<Output = Result<CallToolResult>> + Send;

    /// Invokes the handler for one request. In this module's implementations,
    /// every extractor runs against `ctx`, `state` and `args` first.
    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;

    /// JSON schema to advertise as the tool's input schema.
    fn input_schema() -> Value;
}
603
604impl<S, F, Fut, T1> ExtractorHandler<S, (T1,)> for F
606where
607 S: Clone + Send + Sync + 'static,
608 F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
609 Fut: Future<Output = Result<CallToolResult>> + Send,
610 T1: FromToolRequest<S> + HasSchema + Send,
611{
612 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
613
614 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
615 Box::pin(async move {
616 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
617 self(t1).await
618 })
619 }
620
621 fn input_schema() -> Value {
622 if let Some(schema) = T1::schema() {
623 return schema;
624 }
625 serde_json::json!({
626 "type": "object",
627 "additionalProperties": true
628 })
629 }
630}
631
632impl<S, F, Fut, T1, T2> ExtractorHandler<S, (T1, T2)> for F
634where
635 S: Clone + Send + Sync + 'static,
636 F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
637 Fut: Future<Output = Result<CallToolResult>> + Send,
638 T1: FromToolRequest<S> + HasSchema + Send,
639 T2: FromToolRequest<S> + HasSchema + Send,
640{
641 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
642
643 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
644 Box::pin(async move {
645 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
646 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
647 self(t1, t2).await
648 })
649 }
650
651 fn input_schema() -> Value {
652 if let Some(schema) = T2::schema() {
653 return schema;
654 }
655 if let Some(schema) = T1::schema() {
656 return schema;
657 }
658 serde_json::json!({
659 "type": "object",
660 "additionalProperties": true
661 })
662 }
663}
664
665impl<S, F, Fut, T1, T2, T3> ExtractorHandler<S, (T1, T2, T3)> for F
667where
668 S: Clone + Send + Sync + 'static,
669 F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
670 Fut: Future<Output = Result<CallToolResult>> + Send,
671 T1: FromToolRequest<S> + HasSchema + Send,
672 T2: FromToolRequest<S> + HasSchema + Send,
673 T3: FromToolRequest<S> + HasSchema + Send,
674{
675 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
676
677 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
678 Box::pin(async move {
679 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
680 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
681 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
682 self(t1, t2, t3).await
683 })
684 }
685
686 fn input_schema() -> Value {
687 if let Some(schema) = T3::schema() {
688 return schema;
689 }
690 if let Some(schema) = T2::schema() {
691 return schema;
692 }
693 if let Some(schema) = T1::schema() {
694 return schema;
695 }
696 serde_json::json!({
697 "type": "object",
698 "additionalProperties": true
699 })
700 }
701}
702
703impl<S, F, Fut, T1, T2, T3, T4> ExtractorHandler<S, (T1, T2, T3, T4)> for F
705where
706 S: Clone + Send + Sync + 'static,
707 F: Fn(T1, T2, T3, T4) -> Fut + Clone + Send + Sync + 'static,
708 Fut: Future<Output = Result<CallToolResult>> + Send,
709 T1: FromToolRequest<S> + HasSchema + Send,
710 T2: FromToolRequest<S> + HasSchema + Send,
711 T3: FromToolRequest<S> + HasSchema + Send,
712 T4: FromToolRequest<S> + HasSchema + Send,
713{
714 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
715
716 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
717 Box::pin(async move {
718 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
719 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
720 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
721 let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
722 self(t1, t2, t3, t4).await
723 })
724 }
725
726 fn input_schema() -> Value {
727 if let Some(schema) = T4::schema() {
728 return schema;
729 }
730 if let Some(schema) = T3::schema() {
731 return schema;
732 }
733 if let Some(schema) = T2::schema() {
734 return schema;
735 }
736 if let Some(schema) = T1::schema() {
737 return schema;
738 }
739 serde_json::json!({
740 "type": "object",
741 "additionalProperties": true
742 })
743 }
744}
745
746impl<S, F, Fut, T1, T2, T3, T4, T5> ExtractorHandler<S, (T1, T2, T3, T4, T5)> for F
748where
749 S: Clone + Send + Sync + 'static,
750 F: Fn(T1, T2, T3, T4, T5) -> Fut + Clone + Send + Sync + 'static,
751 Fut: Future<Output = Result<CallToolResult>> + Send,
752 T1: FromToolRequest<S> + HasSchema + Send,
753 T2: FromToolRequest<S> + HasSchema + Send,
754 T3: FromToolRequest<S> + HasSchema + Send,
755 T4: FromToolRequest<S> + HasSchema + Send,
756 T5: FromToolRequest<S> + HasSchema + Send,
757{
758 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
759
760 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
761 Box::pin(async move {
762 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
763 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
764 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
765 let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
766 let t5 = T5::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
767 self(t1, t2, t3, t4, t5).await
768 })
769 }
770
771 fn input_schema() -> Value {
772 if let Some(schema) = T5::schema() {
773 return schema;
774 }
775 if let Some(schema) = T4::schema() {
776 return schema;
777 }
778 if let Some(schema) = T3::schema() {
779 return schema;
780 }
781 if let Some(schema) = T2::schema() {
782 return schema;
783 }
784 if let Some(schema) = T1::schema() {
785 return schema;
786 }
787 serde_json::json!({
788 "type": "object",
789 "additionalProperties": true
790 })
791 }
792}
793
/// Optional JSON-schema contribution of an extractor.
///
/// The `ExtractorHandler` implementations in this module use the schema of the
/// right-most extractor that returns `Some`. If none does, they fall back to
/// `{"type": "object", "additionalProperties": true}`.
pub trait HasSchema {
    /// Returns this extractor's JSON schema, or `None` if it contributes none.
    fn schema() -> Option<Value>;
}
802
803impl<T: JsonSchema> HasSchema for Json<T> {
804 fn schema() -> Option<Value> {
805 let schema = schemars::schema_for!(T);
806 serde_json::to_value(schema).ok()
807 }
808}
809
810impl HasSchema for Context {
812 fn schema() -> Option<Value> {
813 None
814 }
815}
816
817impl HasSchema for RawArgs {
818 fn schema() -> Option<Value> {
819 None
820 }
821}
822
823impl<T> HasSchema for State<T> {
824 fn schema() -> Option<Value> {
825 None
826 }
827}
828
829impl<T> HasSchema for Extension<T> {
830 fn schema() -> Option<Value> {
831 None
832 }
833}
834
/// Like [`ExtractorHandler`], but the tool's input schema is generated directly
/// from the type `I` rather than from the extractors.
///
/// `T` is the tuple of extractor types. This module's implementations require
/// the last handler parameter to be `Json<I>`.
pub trait TypedExtractorHandler<S, T, I>: Clone + Send + Sync + 'static
where
    I: JsonSchema,
{
    /// Future returned by [`TypedExtractorHandler::call`].
    type Future: Future<Output = Result<CallToolResult>> + Send;

    /// Runs the extractors against `ctx`, `state` and `args`, then invokes the
    /// handler with the extracted values.
    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
}
853
854impl<S, F, Fut, T> TypedExtractorHandler<S, (Json<T>,), T> for F
856where
857 S: Clone + Send + Sync + 'static,
858 F: Fn(Json<T>) -> Fut + Clone + Send + Sync + 'static,
859 Fut: Future<Output = Result<CallToolResult>> + Send,
860 T: DeserializeOwned + JsonSchema + Send,
861{
862 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
863
864 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
865 Box::pin(async move {
866 let t1 =
867 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
868 self(t1).await
869 })
870 }
871}
872
873impl<S, F, Fut, T1, T> TypedExtractorHandler<S, (T1, Json<T>), T> for F
875where
876 S: Clone + Send + Sync + 'static,
877 F: Fn(T1, Json<T>) -> Fut + Clone + Send + Sync + 'static,
878 Fut: Future<Output = Result<CallToolResult>> + Send,
879 T1: FromToolRequest<S> + Send,
880 T: DeserializeOwned + JsonSchema + Send,
881{
882 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
883
884 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
885 Box::pin(async move {
886 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
887 let t2 =
888 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
889 self(t1, t2).await
890 })
891 }
892}
893
894impl<S, F, Fut, T1, T2, T> TypedExtractorHandler<S, (T1, T2, Json<T>), T> for F
896where
897 S: Clone + Send + Sync + 'static,
898 F: Fn(T1, T2, Json<T>) -> Fut + Clone + Send + Sync + 'static,
899 Fut: Future<Output = Result<CallToolResult>> + Send,
900 T1: FromToolRequest<S> + Send,
901 T2: FromToolRequest<S> + Send,
902 T: DeserializeOwned + JsonSchema + Send,
903{
904 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
905
906 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
907 Box::pin(async move {
908 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
909 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
910 let t3 =
911 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
912 self(t1, t2, t3).await
913 })
914 }
915}
916
917impl<S, F, Fut, T1, T2, T3, T> TypedExtractorHandler<S, (T1, T2, T3, Json<T>), T> for F
919where
920 S: Clone + Send + Sync + 'static,
921 F: Fn(T1, T2, T3, Json<T>) -> Fut + Clone + Send + Sync + 'static,
922 Fut: Future<Output = Result<CallToolResult>> + Send,
923 T1: FromToolRequest<S> + Send,
924 T2: FromToolRequest<S> + Send,
925 T3: FromToolRequest<S> + Send,
926 T: DeserializeOwned + JsonSchema + Send,
927{
928 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
929
930 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
931 Box::pin(async move {
932 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
933 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
934 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
935 let t4 =
936 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
937 self(t1, t2, t3, t4).await
938 })
939 }
940}
941
942use crate::tool::{
947 BoxFuture, GuardLayer, Tool, ToolCatchError, ToolHandler, ToolHandlerService, ToolRequest,
948};
949use tower::util::BoxCloneService;
950use tower_service::Service;
951
952pub(crate) struct ExtractorToolHandler<S, F, T> {
954 state: S,
955 handler: F,
956 input_schema: Value,
957 _phantom: PhantomData<T>,
958}
959
960impl<S, F, T> ToolHandler for ExtractorToolHandler<S, F, T>
961where
962 S: Clone + Send + Sync + 'static,
963 F: ExtractorHandler<S, T> + Clone,
964 T: Send + Sync + 'static,
965{
966 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
967 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
968 self.call_with_context(ctx, args)
969 }
970
971 fn call_with_context(
972 &self,
973 ctx: RequestContext,
974 args: Value,
975 ) -> BoxFuture<'_, Result<CallToolResult>> {
976 let state = self.state.clone();
977 let handler = self.handler.clone();
978 Box::pin(async move { handler.call(ctx, state, args).await })
979 }
980
981 fn uses_context(&self) -> bool {
982 true
983 }
984
985 fn input_schema(&self) -> Value {
986 self.input_schema.clone()
987 }
988}
989
/// Builder for a [`Tool`] whose handler is an [`ExtractorHandler`] together
/// with its shared state.
///
/// Finish with `build`, or add tower middleware first with `layer` / `guard`.
pub struct ToolBuilderWithExtractor<S, F, T> {
    pub(crate) name: String,
    pub(crate) title: Option<String>,
    pub(crate) description: Option<String>,
    pub(crate) output_schema: Option<Value>,
    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
    pub(crate) task_support: crate::protocol::TaskSupportMode,
    /// State handed to the handler; cloned for every call.
    pub(crate) state: S,
    /// The extractor handler.
    pub(crate) handler: F,
    /// Input schema advertised for the tool.
    pub(crate) input_schema: Value,
    /// Records the extractor tuple type `T` without storing a value of it.
    pub(crate) _phantom: PhantomData<T>,
}
1004
1005impl<S, F, T> ToolBuilderWithExtractor<S, F, T>
1006where
1007 S: Clone + Send + Sync + 'static,
1008 F: ExtractorHandler<S, T> + Clone,
1009 T: Send + Sync + 'static,
1010{
1011 pub fn build(self) -> Tool {
1013 let handler = ExtractorToolHandler {
1014 state: self.state,
1015 handler: self.handler,
1016 input_schema: self.input_schema.clone(),
1017 _phantom: PhantomData,
1018 };
1019
1020 let handler_service = ToolHandlerService::new(handler);
1021 let catch_error = ToolCatchError::new(handler_service);
1022 let service = BoxCloneService::new(catch_error);
1023
1024 Tool {
1025 name: self.name,
1026 title: self.title,
1027 description: self.description,
1028 output_schema: self.output_schema,
1029 icons: self.icons,
1030 annotations: self.annotations,
1031 task_support: self.task_support,
1032 service,
1033 input_schema: self.input_schema,
1034 }
1035 }
1036
1037 pub fn layer<L>(self, layer: L) -> ToolBuilderWithExtractorLayer<S, F, T, L> {
1073 ToolBuilderWithExtractorLayer {
1074 name: self.name,
1075 title: self.title,
1076 description: self.description,
1077 output_schema: self.output_schema,
1078 icons: self.icons,
1079 annotations: self.annotations,
1080 task_support: self.task_support,
1081 state: self.state,
1082 handler: self.handler,
1083 input_schema: self.input_schema,
1084 layer,
1085 _phantom: PhantomData,
1086 }
1087 }
1088
1089 pub fn guard<G>(self, guard: G) -> ToolBuilderWithExtractorLayer<S, F, T, GuardLayer<G>>
1093 where
1094 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1095 {
1096 self.layer(GuardLayer::new(guard))
1097 }
1098}
1099
/// Like `ToolBuilderWithExtractor`, but carries a tower layer `L` that is
/// applied around the handler service when the tool is built.
pub struct ToolBuilderWithExtractorLayer<S, F, T, L> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<crate::protocol::ToolIcon>>,
    annotations: Option<crate::protocol::ToolAnnotations>,
    task_support: crate::protocol::TaskSupportMode,
    /// State handed to the handler; cloned for every call.
    state: S,
    /// The extractor handler.
    handler: F,
    /// Input schema advertised for the tool.
    input_schema: Value,
    /// Middleware (possibly a `tower::layer::util::Stack` of several layers).
    layer: L,
    /// Records the extractor tuple type `T` without storing a value of it.
    _phantom: PhantomData<T>,
}
1117
1118#[allow(private_bounds)]
1119impl<S, F, T, L> ToolBuilderWithExtractorLayer<S, F, T, L>
1120where
1121 S: Clone + Send + Sync + 'static,
1122 F: ExtractorHandler<S, T> + Clone,
1123 T: Send + Sync + 'static,
1124 L: tower::Layer<ToolHandlerService<ExtractorToolHandler<S, F, T>>>
1125 + Clone
1126 + Send
1127 + Sync
1128 + 'static,
1129 L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1130 <L::Service as Service<ToolRequest>>::Error: std::fmt::Display + Send,
1131 <L::Service as Service<ToolRequest>>::Future: Send,
1132{
1133 pub fn build(self) -> Tool {
1135 let handler = ExtractorToolHandler {
1136 state: self.state,
1137 handler: self.handler,
1138 input_schema: self.input_schema.clone(),
1139 _phantom: PhantomData,
1140 };
1141
1142 let handler_service = ToolHandlerService::new(handler);
1143 let layered = self.layer.layer(handler_service);
1144 let catch_error = ToolCatchError::new(layered);
1145 let service = BoxCloneService::new(catch_error);
1146
1147 Tool {
1148 name: self.name,
1149 title: self.title,
1150 description: self.description,
1151 output_schema: self.output_schema,
1152 icons: self.icons,
1153 annotations: self.annotations,
1154 task_support: self.task_support,
1155 service,
1156 input_schema: self.input_schema,
1157 }
1158 }
1159
1160 pub fn layer<L2>(
1165 self,
1166 layer: L2,
1167 ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<L2, L>> {
1168 ToolBuilderWithExtractorLayer {
1169 name: self.name,
1170 title: self.title,
1171 description: self.description,
1172 output_schema: self.output_schema,
1173 icons: self.icons,
1174 annotations: self.annotations,
1175 task_support: self.task_support,
1176 state: self.state,
1177 handler: self.handler,
1178 input_schema: self.input_schema,
1179 layer: tower::layer::util::Stack::new(layer, self.layer),
1180 _phantom: PhantomData,
1181 }
1182 }
1183
1184 pub fn guard<G>(
1188 self,
1189 guard: G,
1190 ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<GuardLayer<G>, L>>
1191 where
1192 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1193 {
1194 self.layer(GuardLayer::new(guard))
1195 }
1196}
1197
/// Builder for a [`Tool`] whose handler is a [`TypedExtractorHandler`].
///
/// The input schema is generated from `I` when the tool is built.
pub struct ToolBuilderWithTypedExtractor<S, F, T, I> {
    pub(crate) name: String,
    pub(crate) title: Option<String>,
    pub(crate) description: Option<String>,
    pub(crate) output_schema: Option<Value>,
    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
    pub(crate) task_support: crate::protocol::TaskSupportMode,
    /// State handed to the handler; cloned for every call.
    pub(crate) state: S,
    /// The typed extractor handler.
    pub(crate) handler: F,
    /// Records the extractor tuple type `T` and the schema type `I`.
    pub(crate) _phantom: PhantomData<(T, I)>,
}
1211
1212impl<S, F, T, I> ToolBuilderWithTypedExtractor<S, F, T, I>
1213where
1214 S: Clone + Send + Sync + 'static,
1215 F: TypedExtractorHandler<S, T, I> + Clone,
1216 T: Send + Sync + 'static,
1217 I: JsonSchema + Send + Sync + 'static,
1218{
1219 pub fn build(self) -> Tool {
1221 let input_schema = {
1222 let schema = schemars::schema_for!(I);
1223 serde_json::to_value(schema).unwrap_or_else(|_| {
1224 serde_json::json!({
1225 "type": "object"
1226 })
1227 })
1228 };
1229
1230 let handler = TypedExtractorToolHandler {
1231 state: self.state,
1232 handler: self.handler,
1233 input_schema: input_schema.clone(),
1234 _phantom: PhantomData,
1235 };
1236
1237 let handler_service = crate::tool::ToolHandlerService::new(handler);
1238 let catch_error = ToolCatchError::new(handler_service);
1239 let service = BoxCloneService::new(catch_error);
1240
1241 Tool {
1242 name: self.name,
1243 title: self.title,
1244 description: self.description,
1245 output_schema: self.output_schema,
1246 icons: self.icons,
1247 annotations: self.annotations,
1248 task_support: self.task_support,
1249 service,
1250 input_schema,
1251 }
1252 }
1253}
1254
1255struct TypedExtractorToolHandler<S, F, T, I> {
1257 state: S,
1258 handler: F,
1259 input_schema: Value,
1260 _phantom: PhantomData<(T, I)>,
1261}
1262
1263impl<S, F, T, I> ToolHandler for TypedExtractorToolHandler<S, F, T, I>
1264where
1265 S: Clone + Send + Sync + 'static,
1266 F: TypedExtractorHandler<S, T, I> + Clone,
1267 T: Send + Sync + 'static,
1268 I: JsonSchema + Send + Sync + 'static,
1269{
1270 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1271 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
1272 self.call_with_context(ctx, args)
1273 }
1274
1275 fn call_with_context(
1276 &self,
1277 ctx: RequestContext,
1278 args: Value,
1279 ) -> BoxFuture<'_, Result<CallToolResult>> {
1280 let state = self.state.clone();
1281 let handler = self.handler.clone();
1282 Box::pin(async move { handler.call(ctx, state, args).await })
1283 }
1284
1285 fn uses_context(&self) -> bool {
1286 true
1287 }
1288
1289 fn input_schema(&self) -> Value {
1290 self.input_schema.clone()
1291 }
1292}
1293
1294#[cfg(test)]
1295mod tests {
1296 use super::*;
1297 use crate::protocol::RequestId;
1298 use schemars::JsonSchema;
1299 use serde::Deserialize;
1300 use std::sync::Arc;
1301
1302 #[derive(Debug, Deserialize, JsonSchema)]
1303 struct TestInput {
1304 name: String,
1305 count: i32,
1306 }
1307
1308 #[test]
1309 fn test_json_extraction() {
1310 let args = serde_json::json!({"name": "test", "count": 42});
1311 let ctx = RequestContext::new(RequestId::Number(1));
1312
1313 let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
1314 assert!(result.is_ok());
1315 let Json(input) = result.unwrap();
1316 assert_eq!(input.name, "test");
1317 assert_eq!(input.count, 42);
1318 }
1319
1320 #[test]
1321 fn test_json_extraction_error() {
1322 let args = serde_json::json!({"name": "test"}); let ctx = RequestContext::new(RequestId::Number(1));
1324
1325 let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
1326 assert!(result.is_err());
1327 let rejection = result.unwrap_err();
1328 assert!(rejection.message().contains("count"));
1330 }
1331
1332 #[test]
1333 fn test_state_extraction() {
1334 let args = serde_json::json!({});
1335 let ctx = RequestContext::new(RequestId::Number(1));
1336 let state = Arc::new("my-state".to_string());
1337
1338 let result = State::<Arc<String>>::from_tool_request(&ctx, &state, &args);
1339 assert!(result.is_ok());
1340 let State(extracted) = result.unwrap();
1341 assert_eq!(*extracted, "my-state");
1342 }
1343
1344 #[test]
1345 fn test_context_extraction() {
1346 let args = serde_json::json!({});
1347 let ctx = RequestContext::new(RequestId::Number(42));
1348
1349 let result = Context::from_tool_request(&ctx, &(), &args);
1350 assert!(result.is_ok());
1351 let extracted = result.unwrap();
1352 assert_eq!(*extracted.request_id(), RequestId::Number(42));
1353 }
1354
1355 #[test]
1356 fn test_raw_args_extraction() {
1357 let args = serde_json::json!({"foo": "bar", "baz": 123});
1358 let ctx = RequestContext::new(RequestId::Number(1));
1359
1360 let result = RawArgs::from_tool_request(&ctx, &(), &args);
1361 assert!(result.is_ok());
1362 let RawArgs(extracted) = result.unwrap();
1363 assert_eq!(extracted["foo"], "bar");
1364 assert_eq!(extracted["baz"], 123);
1365 }
1366
1367 #[test]
1368 fn test_extension_extraction() {
1369 use crate::context::Extensions;
1370
1371 #[derive(Clone, Debug, PartialEq)]
1372 struct DatabasePool {
1373 url: String,
1374 }
1375
1376 let args = serde_json::json!({});
1377
1378 let mut extensions = Extensions::new();
1380 extensions.insert(Arc::new(DatabasePool {
1381 url: "postgres://localhost".to_string(),
1382 }));
1383
1384 let ctx = RequestContext::new(RequestId::Number(1)).with_extensions(Arc::new(extensions));
1386
1387 let result = Extension::<Arc<DatabasePool>>::from_tool_request(&ctx, &(), &args);
1389 assert!(result.is_ok());
1390 let Extension(pool) = result.unwrap();
1391 assert_eq!(pool.url, "postgres://localhost");
1392 }
1393
1394 #[test]
1395 fn test_extension_extraction_missing() {
1396 #[derive(Clone, Debug)]
1397 struct NotPresent;
1398
1399 let args = serde_json::json!({});
1400 let ctx = RequestContext::new(RequestId::Number(1));
1401
1402 let result = Extension::<NotPresent>::from_tool_request(&ctx, &(), &args);
1404 assert!(result.is_err());
1405 let rejection = result.unwrap_err();
1406 assert!(rejection.type_name().contains("NotPresent"));
1408 }
1409
1410 #[tokio::test]
1411 async fn test_single_extractor_handler() {
1412 let handler = |Json(input): Json<TestInput>| async move {
1413 Ok(CallToolResult::text(format!(
1414 "{}: {}",
1415 input.name, input.count
1416 )))
1417 };
1418
1419 let ctx = RequestContext::new(RequestId::Number(1));
1420 let args = serde_json::json!({"name": "test", "count": 5});
1421
1422 let result: Result<CallToolResult> =
1424 ExtractorHandler::<(), (Json<TestInput>,)>::call(handler, ctx, (), args).await;
1425 assert!(result.is_ok());
1426 }
1427
1428 #[tokio::test]
1429 async fn test_two_extractor_handler() {
1430 let handler = |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1431 Ok(CallToolResult::text(format!(
1432 "{}: {} - {}",
1433 state, input.name, input.count
1434 )))
1435 };
1436
1437 let ctx = RequestContext::new(RequestId::Number(1));
1438 let state = Arc::new("prefix".to_string());
1439 let args = serde_json::json!({"name": "test", "count": 5});
1440
1441 let result: Result<CallToolResult> = ExtractorHandler::<
1443 Arc<String>,
1444 (State<Arc<String>>, Json<TestInput>),
1445 >::call(handler, ctx, state, args)
1446 .await;
1447 assert!(result.is_ok());
1448 }
1449
1450 #[tokio::test]
1451 async fn test_three_extractor_handler() {
1452 let handler = |State(state): State<Arc<String>>,
1453 ctx: Context,
1454 Json(input): Json<TestInput>| async move {
1455 assert!(!ctx.is_cancelled());
1457 Ok(CallToolResult::text(format!(
1458 "{}: {} - {}",
1459 state, input.name, input.count
1460 )))
1461 };
1462
1463 let ctx = RequestContext::new(RequestId::Number(1));
1464 let state = Arc::new("prefix".to_string());
1465 let args = serde_json::json!({"name": "test", "count": 5});
1466
1467 let result: Result<CallToolResult> = ExtractorHandler::<
1469 Arc<String>,
1470 (State<Arc<String>>, Context, Json<TestInput>),
1471 >::call(handler, ctx, state, args)
1472 .await;
1473 assert!(result.is_ok());
1474 }
1475
1476 #[test]
1477 fn test_json_schema_generation() {
1478 let schema = Json::<TestInput>::schema();
1479 assert!(schema.is_some());
1480 let schema = schema.unwrap();
1481 assert!(schema.get("properties").is_some());
1482 }
1483
1484 #[test]
1485 fn test_rejection_into_error() {
1486 let rejection = Rejection::new("test error");
1487 let error: Error = rejection.into();
1488 assert!(error.to_string().contains("test error"));
1489 }
1490
1491 #[test]
1492 fn test_json_rejection() {
1493 let rejection = JsonRejection::new("missing field `name`");
1495 assert_eq!(rejection.message(), "missing field `name`");
1496 assert!(rejection.path().is_none());
1497 assert!(rejection.to_string().contains("Invalid input"));
1498
1499 let rejection = JsonRejection::with_path("expected string", "users[0].name");
1501 assert_eq!(rejection.message(), "expected string");
1502 assert_eq!(rejection.path(), Some("users[0].name"));
1503 assert!(rejection.to_string().contains("users[0].name"));
1504
1505 let error: Error = rejection.into();
1507 assert!(error.to_string().contains("users[0].name"));
1508 }
1509
1510 #[test]
1511 fn test_json_rejection_from_serde_error() {
1512 #[derive(Debug, serde::Deserialize)]
1514 struct TestStruct {
1515 #[allow(dead_code)]
1516 name: String,
1517 }
1518
1519 let result: std::result::Result<TestStruct, _> =
1520 serde_json::from_value(serde_json::json!({"count": 42}));
1521 assert!(result.is_err());
1522
1523 let rejection: JsonRejection = result.unwrap_err().into();
1524 assert!(rejection.message().contains("name"));
1525 }
1526
1527 #[test]
1528 fn test_extension_rejection() {
1529 let rejection = ExtensionRejection::not_found::<String>();
1531 assert!(rejection.type_name().contains("String"));
1532 assert!(rejection.to_string().contains("not found"));
1533 assert!(rejection.to_string().contains("with_state"));
1534
1535 let error: Error = rejection.into();
1537 assert!(error.to_string().contains("not found"));
1538 }
1539
1540 #[tokio::test]
1541 async fn test_tool_builder_extractor_handler() {
1542 use crate::ToolBuilder;
1543
1544 let state = Arc::new("shared-state".to_string());
1545
1546 let tool =
1547 ToolBuilder::new("test_extractor")
1548 .description("Test extractor handler")
1549 .extractor_handler(
1550 state,
1551 |State(state): State<Arc<String>>,
1552 ctx: Context,
1553 Json(input): Json<TestInput>| async move {
1554 assert!(!ctx.is_cancelled());
1555 Ok(CallToolResult::text(format!(
1556 "{}: {} - {}",
1557 state, input.name, input.count
1558 )))
1559 },
1560 )
1561 .build();
1562
1563 assert_eq!(tool.name, "test_extractor");
1564 assert_eq!(tool.description.as_deref(), Some("Test extractor handler"));
1565
1566 let result = tool
1568 .call(serde_json::json!({"name": "test", "count": 42}))
1569 .await;
1570 assert!(!result.is_error);
1571 }
1572
1573 #[tokio::test]
1574 async fn test_tool_builder_extractor_handler_typed() {
1575 use crate::ToolBuilder;
1576
1577 let state = Arc::new("typed-state".to_string());
1578
1579 let tool = ToolBuilder::new("test_typed")
1580 .description("Test typed extractor handler")
1581 .extractor_handler_typed::<_, _, _, TestInput>(
1582 state,
1583 |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1584 Ok(CallToolResult::text(format!(
1585 "{}: {} - {}",
1586 state, input.name, input.count
1587 )))
1588 },
1589 )
1590 .build();
1591
1592 assert_eq!(tool.name, "test_typed");
1593
1594 let def = tool.definition();
1596 let schema = def.input_schema;
1597 assert!(schema.get("properties").is_some());
1598
1599 let result = tool
1601 .call(serde_json::json!({"name": "world", "count": 99}))
1602 .await;
1603 assert!(!result.is_error);
1604 }
1605
1606 #[tokio::test]
1607 async fn test_extractor_handler_auto_schema() {
1608 use crate::ToolBuilder;
1609
1610 let state = Arc::new("auto-schema".to_string());
1611
1612 let tool = ToolBuilder::new("test_auto_schema")
1614 .description("Test auto schema detection")
1615 .extractor_handler(
1616 state,
1617 |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1618 Ok(CallToolResult::text(format!(
1619 "{}: {} - {}",
1620 state, input.name, input.count
1621 )))
1622 },
1623 )
1624 .build();
1625
1626 let def = tool.definition();
1628 let schema = def.input_schema;
1629 assert!(
1630 schema.get("properties").is_some(),
1631 "Schema should have properties from TestInput, got: {}",
1632 schema
1633 );
1634 let props = schema.get("properties").unwrap();
1635 assert!(
1636 props.get("name").is_some(),
1637 "Schema should have 'name' property"
1638 );
1639 assert!(
1640 props.get("count").is_some(),
1641 "Schema should have 'count' property"
1642 );
1643
1644 let result = tool
1646 .call(serde_json::json!({"name": "world", "count": 99}))
1647 .await;
1648 assert!(!result.is_error);
1649 }
1650
1651 #[test]
1652 fn test_extractor_handler_no_json_fallback() {
1653 use crate::ToolBuilder;
1654
1655 let tool = ToolBuilder::new("test_no_json")
1657 .description("Test no json fallback")
1658 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1659 Ok(CallToolResult::json(args))
1660 })
1661 .build();
1662
1663 let def = tool.definition();
1664 let schema = def.input_schema;
1665 assert_eq!(
1666 schema.get("type").and_then(|v| v.as_str()),
1667 Some("object"),
1668 "Schema should be generic object"
1669 );
1670 assert_eq!(
1671 schema.get("additionalProperties").and_then(|v| v.as_bool()),
1672 Some(true),
1673 "Schema should allow additional properties"
1674 );
1675 assert!(
1677 schema.get("properties").is_none(),
1678 "Generic schema should not have specific properties"
1679 );
1680 }
1681
1682 #[tokio::test]
1683 async fn test_extractor_handler_with_layer() {
1684 use crate::ToolBuilder;
1685 use std::time::Duration;
1686 use tower::timeout::TimeoutLayer;
1687
1688 let state = Arc::new("layered".to_string());
1689
1690 let tool = ToolBuilder::new("test_extractor_layer")
1691 .description("Test extractor handler with layer")
1692 .extractor_handler(
1693 state,
1694 |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1695 Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
1696 },
1697 )
1698 .layer(TimeoutLayer::new(Duration::from_secs(5)))
1699 .build();
1700
1701 let result = tool
1703 .call(serde_json::json!({"name": "test", "count": 1}))
1704 .await;
1705 assert!(!result.is_error);
1706 assert_eq!(result.first_text().unwrap(), "layered: test");
1707
1708 let def = tool.definition();
1710 let schema = def.input_schema;
1711 assert!(
1712 schema.get("properties").is_some(),
1713 "Schema should have properties even with layer"
1714 );
1715 }
1716
1717 #[tokio::test]
1718 async fn test_extractor_handler_with_timeout_layer() {
1719 use crate::ToolBuilder;
1720 use std::time::Duration;
1721 use tower::timeout::TimeoutLayer;
1722
1723 let tool = ToolBuilder::new("test_extractor_timeout")
1724 .description("Test extractor handler timeout")
1725 .extractor_handler((), |Json(input): Json<TestInput>| async move {
1726 tokio::time::sleep(Duration::from_millis(200)).await;
1727 Ok(CallToolResult::text(input.name.to_string()))
1728 })
1729 .layer(TimeoutLayer::new(Duration::from_millis(50)))
1730 .build();
1731
1732 let result = tool
1734 .call(serde_json::json!({"name": "slow", "count": 1}))
1735 .await;
1736 assert!(result.is_error);
1737 let msg = result.first_text().unwrap().to_lowercase();
1738 assert!(
1739 msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
1740 "Expected timeout error, got: {}",
1741 msg
1742 );
1743 }
1744
1745 #[tokio::test]
1746 async fn test_extractor_handler_with_multiple_layers() {
1747 use crate::ToolBuilder;
1748 use std::time::Duration;
1749 use tower::limit::ConcurrencyLimitLayer;
1750 use tower::timeout::TimeoutLayer;
1751
1752 let state = Arc::new("multi".to_string());
1753
1754 let tool = ToolBuilder::new("test_multi_layer")
1755 .description("Test multiple layers")
1756 .extractor_handler(
1757 state,
1758 |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1759 Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
1760 },
1761 )
1762 .layer(TimeoutLayer::new(Duration::from_secs(5)))
1763 .layer(ConcurrencyLimitLayer::new(10))
1764 .build();
1765
1766 let result = tool
1767 .call(serde_json::json!({"name": "test", "count": 1}))
1768 .await;
1769 assert!(!result.is_error);
1770 assert_eq!(result.first_text().unwrap(), "multi: test");
1771 }
1772}