1use std::future::Future;
76use std::marker::PhantomData;
77use std::ops::Deref;
78use std::pin::Pin;
79
80use schemars::JsonSchema;
81use serde::de::DeserializeOwned;
82use serde_json::Value;
83
84use crate::context::RequestContext;
85use crate::error::{Error, Result};
86use crate::protocol::CallToolResult;
87
/// A generic, message-only rejection returned by extractors whose failure
/// has no richer structure than a human-readable reason.
#[derive(Debug, Clone)]
pub struct Rejection {
    message: String,
}

impl Rejection {
    /// Creates a rejection carrying `message`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Rejection { message }
    }

    /// Returns the human-readable reason for the rejection.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl std::fmt::Display for Rejection {
    /// Writes the bare message with no prefix; width/fill flags are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for Rejection {}
123
124impl From<Rejection> for Error {
125 fn from(rejection: Rejection) -> Self {
126 Error::tool(rejection.message)
127 }
128}
129
/// Rejection produced when the tool arguments cannot be deserialized into
/// the type requested by the [`Json`] extractor.
///
/// Optionally records a `path` locating the failure inside the input.
#[derive(Debug, Clone)]
pub struct JsonRejection {
    message: String,
    path: Option<String>,
}

impl JsonRejection {
    /// Creates a rejection with a message and no location.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        JsonRejection { message, path: None }
    }

    /// Creates a rejection with a message and the location of the bad value.
    pub fn with_path(message: impl Into<String>, path: impl Into<String>) -> Self {
        let message = message.into();
        let location = path.into();
        JsonRejection {
            message,
            path: Some(location),
        }
    }

    /// Returns the underlying error message (without the "Invalid input" prefix).
    pub fn message(&self) -> &str {
        self.message.as_str()
    }

    /// Returns the location of the failure, if one was recorded.
    pub fn path(&self) -> Option<&str> {
        match &self.path {
            Some(location) => Some(location.as_str()),
            None => None,
        }
    }
}

impl std::fmt::Display for JsonRejection {
    /// Renders `Invalid input: <msg>`, or `Invalid input at `<path>`: <msg>`
    /// when a location is known.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.path.as_deref() {
            Some(location) => write!(f, "Invalid input at `{}`: {}", location, self.message),
            None => write!(f, "Invalid input: {}", self.message),
        }
    }
}

impl std::error::Error for JsonRejection {}
189
190impl From<JsonRejection> for Error {
191 fn from(rejection: JsonRejection) -> Self {
192 Error::tool(rejection.to_string())
193 }
194}
195
196impl From<serde_json::Error> for JsonRejection {
197 fn from(err: serde_json::Error) -> Self {
198 let path = if err.is_data() {
200 None
203 } else {
204 None
205 };
206
207 Self {
208 message: err.to_string(),
209 path,
210 }
211 }
212}
213
/// Rejection produced by the [`Extension`] extractor when no value of the
/// requested type is present on the request context.
#[derive(Debug, Clone)]
pub struct ExtensionRejection {
    type_name: &'static str,
}

impl ExtensionRejection {
    /// Builds a rejection describing a missing extension of type `T`.
    pub fn not_found<T>() -> Self {
        let type_name = std::any::type_name::<T>();
        ExtensionRejection { type_name }
    }

    /// Name of the missing type, as reported by `std::any::type_name`.
    pub fn type_name(&self) -> &'static str {
        self.type_name
    }
}

impl std::fmt::Display for ExtensionRejection {
    /// Names the missing type and hints at the registration calls.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Extension of type `{}` not found. ", self.type_name)?;
        f.write_str("Did you call `router.with_state()` or `router.with_extension()`?")
    }
}

impl std::error::Error for ExtensionRejection {}
257
258impl From<ExtensionRejection> for Error {
259 fn from(rejection: ExtensionRejection) -> Self {
260 Error::tool(rejection.to_string())
261 }
262}
263
/// Extraction of a handler argument from the pieces of a tool call.
///
/// Implementors receive the request context, the shared state `S` (which
/// defaults to `()`), and the raw JSON arguments. In this module's handler
/// implementations, every extractor is handed the same, complete `args`.
pub trait FromToolRequest<S = ()>: Sized {
    /// Error produced when extraction fails. It must convert into the crate
    /// [`Error`] so handler wrappers can propagate it with `?`.
    type Rejection: Into<Error>;

    /// Builds `Self` from the request context, state and arguments.
    ///
    /// # Errors
    ///
    /// Returns `Self::Rejection` when the value cannot be extracted.
    fn from_tool_request(
        ctx: &RequestContext,
        state: &S,
        args: &Value,
    ) -> std::result::Result<Self, Self::Rejection>;
}
311
/// Extractor that deserializes the tool arguments into `T`.
///
/// Among the extractors in this module, `Json<T>` is the one that
/// contributes a schema through [`HasSchema`].
#[derive(Debug, Clone, Copy)]
pub struct Json<T>(pub T);

impl<T> Deref for Json<T> {
    type Target = T;

    /// Borrows the deserialized value.
    fn deref(&self) -> &T {
        let Json(inner) = self;
        inner
    }
}
352
353impl<S, T> FromToolRequest<S> for Json<T>
354where
355 T: DeserializeOwned,
356{
357 type Rejection = JsonRejection;
358
359 fn from_tool_request(
360 _ctx: &RequestContext,
361 _state: &S,
362 args: &Value,
363 ) -> std::result::Result<Self, Self::Rejection> {
364 serde_json::from_value(args.clone())
365 .map(Json)
366 .map_err(JsonRejection::from)
367 }
368}
369
/// Extractor yielding a clone of the shared state handed to the tool.
#[derive(Debug, Clone, Copy)]
pub struct State<T>(pub T);

impl<T> Deref for State<T> {
    type Target = T;

    /// Borrows the wrapped state.
    fn deref(&self) -> &T {
        let State(inner) = self;
        inner
    }
}
404
405impl<S: Clone> FromToolRequest<S> for State<S> {
406 type Rejection = Rejection;
407
408 fn from_tool_request(
409 _ctx: &RequestContext,
410 state: &S,
411 _args: &Value,
412 ) -> std::result::Result<Self, Self::Rejection> {
413 Ok(State(state.clone()))
414 }
415}
416
417#[derive(Debug, Clone)]
438pub struct Context(RequestContext);
439
440impl Context {
441 pub fn into_inner(self) -> RequestContext {
443 self.0
444 }
445}
446
447impl Deref for Context {
448 type Target = RequestContext;
449
450 fn deref(&self) -> &Self::Target {
451 &self.0
452 }
453}
454
455impl<S> FromToolRequest<S> for Context {
456 type Rejection = Rejection;
457
458 fn from_tool_request(
459 ctx: &RequestContext,
460 _state: &S,
461 _args: &Value,
462 ) -> std::result::Result<Self, Self::Rejection> {
463 Ok(Context(ctx.clone()))
464 }
465}
466
467#[derive(Debug, Clone)]
485pub struct RawArgs(pub Value);
486
487impl Deref for RawArgs {
488 type Target = Value;
489
490 fn deref(&self) -> &Self::Target {
491 &self.0
492 }
493}
494
495impl<S> FromToolRequest<S> for RawArgs {
496 type Rejection = Rejection;
497
498 fn from_tool_request(
499 _ctx: &RequestContext,
500 _state: &S,
501 args: &Value,
502 ) -> std::result::Result<Self, Self::Rejection> {
503 Ok(RawArgs(args.clone()))
504 }
505}
506
/// Extractor that clones a value of type `T` registered on the request
/// context's extensions.
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);

impl<T> Deref for Extension<T> {
    type Target = T;

    /// Borrows the extracted extension value.
    fn deref(&self) -> &T {
        let Extension(inner) = self;
        inner
    }
}
563
564impl<S, T> FromToolRequest<S> for Extension<T>
565where
566 T: Clone + Send + Sync + 'static,
567{
568 type Rejection = ExtensionRejection;
569
570 fn from_tool_request(
571 ctx: &RequestContext,
572 _state: &S,
573 _args: &Value,
574 ) -> std::result::Result<Self, Self::Rejection> {
575 ctx.extension::<T>()
576 .cloned()
577 .map(Extension)
578 .ok_or_else(ExtensionRejection::not_found::<T>)
579 }
580}
581
/// A handler function whose arguments are all [`FromToolRequest`] extractors.
///
/// `T` is the tuple of extractor types, e.g. `(State<_>, Json<_>)`. It
/// distinguishes the blanket implementations this module provides for
/// closures of one to five arguments.
pub trait ExtractorHandler<S, T>: Clone + Send + Sync + 'static {
    /// Future resolving to the tool result.
    type Future: Future<Output = Result<CallToolResult>> + Send;

    /// Runs the extractors against the request and invokes the handler.
    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;

    /// Input schema to advertise for the tool. In this module's
    /// implementations, this is the schema of the last argument that provides
    /// one, falling back to a permissive object schema.
    fn input_schema() -> Value;
}
603
604impl<S, F, Fut, T1> ExtractorHandler<S, (T1,)> for F
606where
607 S: Clone + Send + Sync + 'static,
608 F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
609 Fut: Future<Output = Result<CallToolResult>> + Send,
610 T1: FromToolRequest<S> + HasSchema + Send,
611{
612 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
613
614 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
615 Box::pin(async move {
616 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
617 self(t1).await
618 })
619 }
620
621 fn input_schema() -> Value {
622 if let Some(schema) = T1::schema() {
623 return schema;
624 }
625 serde_json::json!({
626 "type": "object",
627 "additionalProperties": true
628 })
629 }
630}
631
632impl<S, F, Fut, T1, T2> ExtractorHandler<S, (T1, T2)> for F
634where
635 S: Clone + Send + Sync + 'static,
636 F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
637 Fut: Future<Output = Result<CallToolResult>> + Send,
638 T1: FromToolRequest<S> + HasSchema + Send,
639 T2: FromToolRequest<S> + HasSchema + Send,
640{
641 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
642
643 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
644 Box::pin(async move {
645 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
646 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
647 self(t1, t2).await
648 })
649 }
650
651 fn input_schema() -> Value {
652 if let Some(schema) = T2::schema() {
653 return schema;
654 }
655 if let Some(schema) = T1::schema() {
656 return schema;
657 }
658 serde_json::json!({
659 "type": "object",
660 "additionalProperties": true
661 })
662 }
663}
664
665impl<S, F, Fut, T1, T2, T3> ExtractorHandler<S, (T1, T2, T3)> for F
667where
668 S: Clone + Send + Sync + 'static,
669 F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
670 Fut: Future<Output = Result<CallToolResult>> + Send,
671 T1: FromToolRequest<S> + HasSchema + Send,
672 T2: FromToolRequest<S> + HasSchema + Send,
673 T3: FromToolRequest<S> + HasSchema + Send,
674{
675 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
676
677 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
678 Box::pin(async move {
679 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
680 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
681 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
682 self(t1, t2, t3).await
683 })
684 }
685
686 fn input_schema() -> Value {
687 if let Some(schema) = T3::schema() {
688 return schema;
689 }
690 if let Some(schema) = T2::schema() {
691 return schema;
692 }
693 if let Some(schema) = T1::schema() {
694 return schema;
695 }
696 serde_json::json!({
697 "type": "object",
698 "additionalProperties": true
699 })
700 }
701}
702
703impl<S, F, Fut, T1, T2, T3, T4> ExtractorHandler<S, (T1, T2, T3, T4)> for F
705where
706 S: Clone + Send + Sync + 'static,
707 F: Fn(T1, T2, T3, T4) -> Fut + Clone + Send + Sync + 'static,
708 Fut: Future<Output = Result<CallToolResult>> + Send,
709 T1: FromToolRequest<S> + HasSchema + Send,
710 T2: FromToolRequest<S> + HasSchema + Send,
711 T3: FromToolRequest<S> + HasSchema + Send,
712 T4: FromToolRequest<S> + HasSchema + Send,
713{
714 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
715
716 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
717 Box::pin(async move {
718 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
719 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
720 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
721 let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
722 self(t1, t2, t3, t4).await
723 })
724 }
725
726 fn input_schema() -> Value {
727 if let Some(schema) = T4::schema() {
728 return schema;
729 }
730 if let Some(schema) = T3::schema() {
731 return schema;
732 }
733 if let Some(schema) = T2::schema() {
734 return schema;
735 }
736 if let Some(schema) = T1::schema() {
737 return schema;
738 }
739 serde_json::json!({
740 "type": "object",
741 "additionalProperties": true
742 })
743 }
744}
745
746impl<S, F, Fut, T1, T2, T3, T4, T5> ExtractorHandler<S, (T1, T2, T3, T4, T5)> for F
748where
749 S: Clone + Send + Sync + 'static,
750 F: Fn(T1, T2, T3, T4, T5) -> Fut + Clone + Send + Sync + 'static,
751 Fut: Future<Output = Result<CallToolResult>> + Send,
752 T1: FromToolRequest<S> + HasSchema + Send,
753 T2: FromToolRequest<S> + HasSchema + Send,
754 T3: FromToolRequest<S> + HasSchema + Send,
755 T4: FromToolRequest<S> + HasSchema + Send,
756 T5: FromToolRequest<S> + HasSchema + Send,
757{
758 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
759
760 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
761 Box::pin(async move {
762 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
763 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
764 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
765 let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
766 let t5 = T5::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
767 self(t1, t2, t3, t4, t5).await
768 })
769 }
770
771 fn input_schema() -> Value {
772 if let Some(schema) = T5::schema() {
773 return schema;
774 }
775 if let Some(schema) = T4::schema() {
776 return schema;
777 }
778 if let Some(schema) = T3::schema() {
779 return schema;
780 }
781 if let Some(schema) = T2::schema() {
782 return schema;
783 }
784 if let Some(schema) = T1::schema() {
785 return schema;
786 }
787 serde_json::json!({
788 "type": "object",
789 "additionalProperties": true
790 })
791 }
792}
793
/// Optional JSON schema contributed by an extractor.
///
/// In this module only `Json<T>` returns `Some`. `Context`, `RawArgs`,
/// `State` and `Extension` all return `None`.
pub trait HasSchema {
    /// Returns the extractor's input schema, if it provides one.
    fn schema() -> Option<Value>;
}
802
803impl<T: JsonSchema> HasSchema for Json<T> {
804 fn schema() -> Option<Value> {
805 let schema = schemars::schema_for!(T);
806 serde_json::to_value(schema).ok()
807 }
808}
809
810impl HasSchema for Context {
812 fn schema() -> Option<Value> {
813 None
814 }
815}
816
817impl HasSchema for RawArgs {
818 fn schema() -> Option<Value> {
819 None
820 }
821}
822
823impl<T> HasSchema for State<T> {
824 fn schema() -> Option<Value> {
825 None
826 }
827}
828
829impl<T> HasSchema for Extension<T> {
830 fn schema() -> Option<Value> {
831 None
832 }
833}
834
/// Variant of [`ExtractorHandler`] whose input schema comes from an explicit
/// type `I` rather than from the extractors.
///
/// The implementations in this module require the handler's last argument to
/// be `Json<I>`. `T` is the tuple of argument types.
pub trait TypedExtractorHandler<S, T, I>: Clone + Send + Sync + 'static
where
    I: JsonSchema,
{
    /// Future resolving to the tool result.
    type Future: Future<Output = Result<CallToolResult>> + Send;

    /// Runs the extractors against the request and invokes the handler.
    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
}
853
854impl<S, F, Fut, T> TypedExtractorHandler<S, (Json<T>,), T> for F
856where
857 S: Clone + Send + Sync + 'static,
858 F: Fn(Json<T>) -> Fut + Clone + Send + Sync + 'static,
859 Fut: Future<Output = Result<CallToolResult>> + Send,
860 T: DeserializeOwned + JsonSchema + Send,
861{
862 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
863
864 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
865 Box::pin(async move {
866 let t1 =
867 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
868 self(t1).await
869 })
870 }
871}
872
873impl<S, F, Fut, T1, T> TypedExtractorHandler<S, (T1, Json<T>), T> for F
875where
876 S: Clone + Send + Sync + 'static,
877 F: Fn(T1, Json<T>) -> Fut + Clone + Send + Sync + 'static,
878 Fut: Future<Output = Result<CallToolResult>> + Send,
879 T1: FromToolRequest<S> + Send,
880 T: DeserializeOwned + JsonSchema + Send,
881{
882 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
883
884 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
885 Box::pin(async move {
886 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
887 let t2 =
888 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
889 self(t1, t2).await
890 })
891 }
892}
893
894impl<S, F, Fut, T1, T2, T> TypedExtractorHandler<S, (T1, T2, Json<T>), T> for F
896where
897 S: Clone + Send + Sync + 'static,
898 F: Fn(T1, T2, Json<T>) -> Fut + Clone + Send + Sync + 'static,
899 Fut: Future<Output = Result<CallToolResult>> + Send,
900 T1: FromToolRequest<S> + Send,
901 T2: FromToolRequest<S> + Send,
902 T: DeserializeOwned + JsonSchema + Send,
903{
904 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
905
906 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
907 Box::pin(async move {
908 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
909 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
910 let t3 =
911 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
912 self(t1, t2, t3).await
913 })
914 }
915}
916
917impl<S, F, Fut, T1, T2, T3, T> TypedExtractorHandler<S, (T1, T2, T3, Json<T>), T> for F
919where
920 S: Clone + Send + Sync + 'static,
921 F: Fn(T1, T2, T3, Json<T>) -> Fut + Clone + Send + Sync + 'static,
922 Fut: Future<Output = Result<CallToolResult>> + Send,
923 T1: FromToolRequest<S> + Send,
924 T2: FromToolRequest<S> + Send,
925 T3: FromToolRequest<S> + Send,
926 T: DeserializeOwned + JsonSchema + Send,
927{
928 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
929
930 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
931 Box::pin(async move {
932 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
933 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
934 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
935 let t4 =
936 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
937 self(t1, t2, t3, t4).await
938 })
939 }
940}
941
942use crate::tool::{
947 BoxFuture, GuardLayer, Tool, ToolCatchError, ToolHandler, ToolHandlerService, ToolRequest,
948};
949use tower::util::BoxCloneService;
950use tower_service::Service;
951
952pub(crate) struct ExtractorToolHandler<S, F, T> {
954 state: S,
955 handler: F,
956 input_schema: Value,
957 _phantom: PhantomData<T>,
958}
959
960impl<S, F, T> ToolHandler for ExtractorToolHandler<S, F, T>
961where
962 S: Clone + Send + Sync + 'static,
963 F: ExtractorHandler<S, T> + Clone,
964 T: Send + Sync + 'static,
965{
966 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
967 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
968 self.call_with_context(ctx, args)
969 }
970
971 fn call_with_context(
972 &self,
973 ctx: RequestContext,
974 args: Value,
975 ) -> BoxFuture<'_, Result<CallToolResult>> {
976 let state = self.state.clone();
977 let handler = self.handler.clone();
978 Box::pin(async move { handler.call(ctx, state, args).await })
979 }
980
981 fn uses_context(&self) -> bool {
982 true
983 }
984
985 fn input_schema(&self) -> Value {
986 self.input_schema.clone()
987 }
988}
989
990pub struct ToolBuilderWithExtractor<S, F, T> {
992 pub(crate) name: String,
993 pub(crate) title: Option<String>,
994 pub(crate) description: Option<String>,
995 pub(crate) output_schema: Option<Value>,
996 pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
997 pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
998 pub(crate) state: S,
999 pub(crate) handler: F,
1000 pub(crate) input_schema: Value,
1001 pub(crate) _phantom: PhantomData<T>,
1002}
1003
1004impl<S, F, T> ToolBuilderWithExtractor<S, F, T>
1005where
1006 S: Clone + Send + Sync + 'static,
1007 F: ExtractorHandler<S, T> + Clone,
1008 T: Send + Sync + 'static,
1009{
1010 pub fn build(self) -> Tool {
1012 let handler = ExtractorToolHandler {
1013 state: self.state,
1014 handler: self.handler,
1015 input_schema: self.input_schema.clone(),
1016 _phantom: PhantomData,
1017 };
1018
1019 let handler_service = ToolHandlerService::new(handler);
1020 let catch_error = ToolCatchError::new(handler_service);
1021 let service = BoxCloneService::new(catch_error);
1022
1023 Tool {
1024 name: self.name,
1025 title: self.title,
1026 description: self.description,
1027 output_schema: self.output_schema,
1028 icons: self.icons,
1029 annotations: self.annotations,
1030 service,
1031 input_schema: self.input_schema,
1032 }
1033 }
1034
1035 pub fn layer<L>(self, layer: L) -> ToolBuilderWithExtractorLayer<S, F, T, L> {
1071 ToolBuilderWithExtractorLayer {
1072 name: self.name,
1073 title: self.title,
1074 description: self.description,
1075 output_schema: self.output_schema,
1076 icons: self.icons,
1077 annotations: self.annotations,
1078 state: self.state,
1079 handler: self.handler,
1080 input_schema: self.input_schema,
1081 layer,
1082 _phantom: PhantomData,
1083 }
1084 }
1085
1086 pub fn guard<G>(self, guard: G) -> ToolBuilderWithExtractorLayer<S, F, T, GuardLayer<G>>
1090 where
1091 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1092 {
1093 self.layer(GuardLayer::new(guard))
1094 }
1095}
1096
1097pub struct ToolBuilderWithExtractorLayer<S, F, T, L> {
1101 name: String,
1102 title: Option<String>,
1103 description: Option<String>,
1104 output_schema: Option<Value>,
1105 icons: Option<Vec<crate::protocol::ToolIcon>>,
1106 annotations: Option<crate::protocol::ToolAnnotations>,
1107 state: S,
1108 handler: F,
1109 input_schema: Value,
1110 layer: L,
1111 _phantom: PhantomData<T>,
1112}
1113
1114#[allow(private_bounds)]
1115impl<S, F, T, L> ToolBuilderWithExtractorLayer<S, F, T, L>
1116where
1117 S: Clone + Send + Sync + 'static,
1118 F: ExtractorHandler<S, T> + Clone,
1119 T: Send + Sync + 'static,
1120 L: tower::Layer<ToolHandlerService<ExtractorToolHandler<S, F, T>>>
1121 + Clone
1122 + Send
1123 + Sync
1124 + 'static,
1125 L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1126 <L::Service as Service<ToolRequest>>::Error: std::fmt::Display + Send,
1127 <L::Service as Service<ToolRequest>>::Future: Send,
1128{
1129 pub fn build(self) -> Tool {
1131 let handler = ExtractorToolHandler {
1132 state: self.state,
1133 handler: self.handler,
1134 input_schema: self.input_schema.clone(),
1135 _phantom: PhantomData,
1136 };
1137
1138 let handler_service = ToolHandlerService::new(handler);
1139 let layered = self.layer.layer(handler_service);
1140 let catch_error = ToolCatchError::new(layered);
1141 let service = BoxCloneService::new(catch_error);
1142
1143 Tool {
1144 name: self.name,
1145 title: self.title,
1146 description: self.description,
1147 output_schema: self.output_schema,
1148 icons: self.icons,
1149 annotations: self.annotations,
1150 service,
1151 input_schema: self.input_schema,
1152 }
1153 }
1154
1155 pub fn layer<L2>(
1160 self,
1161 layer: L2,
1162 ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<L2, L>> {
1163 ToolBuilderWithExtractorLayer {
1164 name: self.name,
1165 title: self.title,
1166 description: self.description,
1167 output_schema: self.output_schema,
1168 icons: self.icons,
1169 annotations: self.annotations,
1170 state: self.state,
1171 handler: self.handler,
1172 input_schema: self.input_schema,
1173 layer: tower::layer::util::Stack::new(layer, self.layer),
1174 _phantom: PhantomData,
1175 }
1176 }
1177
1178 pub fn guard<G>(
1182 self,
1183 guard: G,
1184 ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<GuardLayer<G>, L>>
1185 where
1186 G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1187 {
1188 self.layer(GuardLayer::new(guard))
1189 }
1190}
1191
1192pub struct ToolBuilderWithTypedExtractor<S, F, T, I> {
1194 pub(crate) name: String,
1195 pub(crate) title: Option<String>,
1196 pub(crate) description: Option<String>,
1197 pub(crate) output_schema: Option<Value>,
1198 pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
1199 pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
1200 pub(crate) state: S,
1201 pub(crate) handler: F,
1202 pub(crate) _phantom: PhantomData<(T, I)>,
1203}
1204
1205impl<S, F, T, I> ToolBuilderWithTypedExtractor<S, F, T, I>
1206where
1207 S: Clone + Send + Sync + 'static,
1208 F: TypedExtractorHandler<S, T, I> + Clone,
1209 T: Send + Sync + 'static,
1210 I: JsonSchema + Send + Sync + 'static,
1211{
1212 pub fn build(self) -> Tool {
1214 let input_schema = {
1215 let schema = schemars::schema_for!(I);
1216 serde_json::to_value(schema).unwrap_or_else(|_| {
1217 serde_json::json!({
1218 "type": "object"
1219 })
1220 })
1221 };
1222
1223 let handler = TypedExtractorToolHandler {
1224 state: self.state,
1225 handler: self.handler,
1226 input_schema: input_schema.clone(),
1227 _phantom: PhantomData,
1228 };
1229
1230 let handler_service = crate::tool::ToolHandlerService::new(handler);
1231 let catch_error = ToolCatchError::new(handler_service);
1232 let service = BoxCloneService::new(catch_error);
1233
1234 Tool {
1235 name: self.name,
1236 title: self.title,
1237 description: self.description,
1238 output_schema: self.output_schema,
1239 icons: self.icons,
1240 annotations: self.annotations,
1241 service,
1242 input_schema,
1243 }
1244 }
1245}
1246
1247struct TypedExtractorToolHandler<S, F, T, I> {
1249 state: S,
1250 handler: F,
1251 input_schema: Value,
1252 _phantom: PhantomData<(T, I)>,
1253}
1254
1255impl<S, F, T, I> ToolHandler for TypedExtractorToolHandler<S, F, T, I>
1256where
1257 S: Clone + Send + Sync + 'static,
1258 F: TypedExtractorHandler<S, T, I> + Clone,
1259 T: Send + Sync + 'static,
1260 I: JsonSchema + Send + Sync + 'static,
1261{
1262 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1263 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
1264 self.call_with_context(ctx, args)
1265 }
1266
1267 fn call_with_context(
1268 &self,
1269 ctx: RequestContext,
1270 args: Value,
1271 ) -> BoxFuture<'_, Result<CallToolResult>> {
1272 let state = self.state.clone();
1273 let handler = self.handler.clone();
1274 Box::pin(async move { handler.call(ctx, state, args).await })
1275 }
1276
1277 fn uses_context(&self) -> bool {
1278 true
1279 }
1280
1281 fn input_schema(&self) -> Value {
1282 self.input_schema.clone()
1283 }
1284}
1285
1286#[cfg(test)]
1287mod tests {
1288 use super::*;
1289 use crate::protocol::RequestId;
1290 use schemars::JsonSchema;
1291 use serde::Deserialize;
1292 use std::sync::Arc;
1293
1294 #[derive(Debug, Deserialize, JsonSchema)]
1295 struct TestInput {
1296 name: String,
1297 count: i32,
1298 }
1299
1300 #[test]
1301 fn test_json_extraction() {
1302 let args = serde_json::json!({"name": "test", "count": 42});
1303 let ctx = RequestContext::new(RequestId::Number(1));
1304
1305 let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
1306 assert!(result.is_ok());
1307 let Json(input) = result.unwrap();
1308 assert_eq!(input.name, "test");
1309 assert_eq!(input.count, 42);
1310 }
1311
1312 #[test]
1313 fn test_json_extraction_error() {
1314 let args = serde_json::json!({"name": "test"}); let ctx = RequestContext::new(RequestId::Number(1));
1316
1317 let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
1318 assert!(result.is_err());
1319 let rejection = result.unwrap_err();
1320 assert!(rejection.message().contains("count"));
1322 }
1323
1324 #[test]
1325 fn test_state_extraction() {
1326 let args = serde_json::json!({});
1327 let ctx = RequestContext::new(RequestId::Number(1));
1328 let state = Arc::new("my-state".to_string());
1329
1330 let result = State::<Arc<String>>::from_tool_request(&ctx, &state, &args);
1331 assert!(result.is_ok());
1332 let State(extracted) = result.unwrap();
1333 assert_eq!(*extracted, "my-state");
1334 }
1335
1336 #[test]
1337 fn test_context_extraction() {
1338 let args = serde_json::json!({});
1339 let ctx = RequestContext::new(RequestId::Number(42));
1340
1341 let result = Context::from_tool_request(&ctx, &(), &args);
1342 assert!(result.is_ok());
1343 let extracted = result.unwrap();
1344 assert_eq!(*extracted.request_id(), RequestId::Number(42));
1345 }
1346
1347 #[test]
1348 fn test_raw_args_extraction() {
1349 let args = serde_json::json!({"foo": "bar", "baz": 123});
1350 let ctx = RequestContext::new(RequestId::Number(1));
1351
1352 let result = RawArgs::from_tool_request(&ctx, &(), &args);
1353 assert!(result.is_ok());
1354 let RawArgs(extracted) = result.unwrap();
1355 assert_eq!(extracted["foo"], "bar");
1356 assert_eq!(extracted["baz"], 123);
1357 }
1358
1359 #[test]
1360 fn test_extension_extraction() {
1361 use crate::context::Extensions;
1362
1363 #[derive(Clone, Debug, PartialEq)]
1364 struct DatabasePool {
1365 url: String,
1366 }
1367
1368 let args = serde_json::json!({});
1369
1370 let mut extensions = Extensions::new();
1372 extensions.insert(Arc::new(DatabasePool {
1373 url: "postgres://localhost".to_string(),
1374 }));
1375
1376 let ctx = RequestContext::new(RequestId::Number(1)).with_extensions(Arc::new(extensions));
1378
1379 let result = Extension::<Arc<DatabasePool>>::from_tool_request(&ctx, &(), &args);
1381 assert!(result.is_ok());
1382 let Extension(pool) = result.unwrap();
1383 assert_eq!(pool.url, "postgres://localhost");
1384 }
1385
1386 #[test]
1387 fn test_extension_extraction_missing() {
1388 #[derive(Clone, Debug)]
1389 struct NotPresent;
1390
1391 let args = serde_json::json!({});
1392 let ctx = RequestContext::new(RequestId::Number(1));
1393
1394 let result = Extension::<NotPresent>::from_tool_request(&ctx, &(), &args);
1396 assert!(result.is_err());
1397 let rejection = result.unwrap_err();
1398 assert!(rejection.type_name().contains("NotPresent"));
1400 }
1401
1402 #[tokio::test]
1403 async fn test_single_extractor_handler() {
1404 let handler = |Json(input): Json<TestInput>| async move {
1405 Ok(CallToolResult::text(format!(
1406 "{}: {}",
1407 input.name, input.count
1408 )))
1409 };
1410
1411 let ctx = RequestContext::new(RequestId::Number(1));
1412 let args = serde_json::json!({"name": "test", "count": 5});
1413
1414 let result: Result<CallToolResult> =
1416 ExtractorHandler::<(), (Json<TestInput>,)>::call(handler, ctx, (), args).await;
1417 assert!(result.is_ok());
1418 }
1419
1420 #[tokio::test]
1421 async fn test_two_extractor_handler() {
1422 let handler = |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1423 Ok(CallToolResult::text(format!(
1424 "{}: {} - {}",
1425 state, input.name, input.count
1426 )))
1427 };
1428
1429 let ctx = RequestContext::new(RequestId::Number(1));
1430 let state = Arc::new("prefix".to_string());
1431 let args = serde_json::json!({"name": "test", "count": 5});
1432
1433 let result: Result<CallToolResult> = ExtractorHandler::<
1435 Arc<String>,
1436 (State<Arc<String>>, Json<TestInput>),
1437 >::call(handler, ctx, state, args)
1438 .await;
1439 assert!(result.is_ok());
1440 }
1441
1442 #[tokio::test]
1443 async fn test_three_extractor_handler() {
1444 let handler = |State(state): State<Arc<String>>,
1445 ctx: Context,
1446 Json(input): Json<TestInput>| async move {
1447 assert!(!ctx.is_cancelled());
1449 Ok(CallToolResult::text(format!(
1450 "{}: {} - {}",
1451 state, input.name, input.count
1452 )))
1453 };
1454
1455 let ctx = RequestContext::new(RequestId::Number(1));
1456 let state = Arc::new("prefix".to_string());
1457 let args = serde_json::json!({"name": "test", "count": 5});
1458
1459 let result: Result<CallToolResult> = ExtractorHandler::<
1461 Arc<String>,
1462 (State<Arc<String>>, Context, Json<TestInput>),
1463 >::call(handler, ctx, state, args)
1464 .await;
1465 assert!(result.is_ok());
1466 }
1467
1468 #[test]
1469 fn test_json_schema_generation() {
1470 let schema = Json::<TestInput>::schema();
1471 assert!(schema.is_some());
1472 let schema = schema.unwrap();
1473 assert!(schema.get("properties").is_some());
1474 }
1475
1476 #[test]
1477 fn test_rejection_into_error() {
1478 let rejection = Rejection::new("test error");
1479 let error: Error = rejection.into();
1480 assert!(error.to_string().contains("test error"));
1481 }
1482
1483 #[test]
1484 fn test_json_rejection() {
1485 let rejection = JsonRejection::new("missing field `name`");
1487 assert_eq!(rejection.message(), "missing field `name`");
1488 assert!(rejection.path().is_none());
1489 assert!(rejection.to_string().contains("Invalid input"));
1490
1491 let rejection = JsonRejection::with_path("expected string", "users[0].name");
1493 assert_eq!(rejection.message(), "expected string");
1494 assert_eq!(rejection.path(), Some("users[0].name"));
1495 assert!(rejection.to_string().contains("users[0].name"));
1496
1497 let error: Error = rejection.into();
1499 assert!(error.to_string().contains("users[0].name"));
1500 }
1501
1502 #[test]
1503 fn test_json_rejection_from_serde_error() {
1504 #[derive(Debug, serde::Deserialize)]
1506 struct TestStruct {
1507 #[allow(dead_code)]
1508 name: String,
1509 }
1510
1511 let result: std::result::Result<TestStruct, _> =
1512 serde_json::from_value(serde_json::json!({"count": 42}));
1513 assert!(result.is_err());
1514
1515 let rejection: JsonRejection = result.unwrap_err().into();
1516 assert!(rejection.message().contains("name"));
1517 }
1518
1519 #[test]
1520 fn test_extension_rejection() {
1521 let rejection = ExtensionRejection::not_found::<String>();
1523 assert!(rejection.type_name().contains("String"));
1524 assert!(rejection.to_string().contains("not found"));
1525 assert!(rejection.to_string().contains("with_state"));
1526
1527 let error: Error = rejection.into();
1529 assert!(error.to_string().contains("not found"));
1530 }
1531
1532 #[tokio::test]
1533 async fn test_tool_builder_extractor_handler() {
1534 use crate::ToolBuilder;
1535
1536 let state = Arc::new("shared-state".to_string());
1537
1538 let tool =
1539 ToolBuilder::new("test_extractor")
1540 .description("Test extractor handler")
1541 .extractor_handler(
1542 state,
1543 |State(state): State<Arc<String>>,
1544 ctx: Context,
1545 Json(input): Json<TestInput>| async move {
1546 assert!(!ctx.is_cancelled());
1547 Ok(CallToolResult::text(format!(
1548 "{}: {} - {}",
1549 state, input.name, input.count
1550 )))
1551 },
1552 )
1553 .build();
1554
1555 assert_eq!(tool.name, "test_extractor");
1556 assert_eq!(tool.description.as_deref(), Some("Test extractor handler"));
1557
1558 let result = tool
1560 .call(serde_json::json!({"name": "test", "count": 42}))
1561 .await;
1562 assert!(!result.is_error);
1563 }
1564
1565 #[tokio::test]
1566 async fn test_tool_builder_extractor_handler_typed() {
1567 use crate::ToolBuilder;
1568
1569 let state = Arc::new("typed-state".to_string());
1570
1571 let tool = ToolBuilder::new("test_typed")
1572 .description("Test typed extractor handler")
1573 .extractor_handler_typed::<_, _, _, TestInput>(
1574 state,
1575 |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1576 Ok(CallToolResult::text(format!(
1577 "{}: {} - {}",
1578 state, input.name, input.count
1579 )))
1580 },
1581 )
1582 .build();
1583
1584 assert_eq!(tool.name, "test_typed");
1585
1586 let def = tool.definition();
1588 let schema = def.input_schema;
1589 assert!(schema.get("properties").is_some());
1590
1591 let result = tool
1593 .call(serde_json::json!({"name": "world", "count": 99}))
1594 .await;
1595 assert!(!result.is_error);
1596 }
1597
1598 #[tokio::test]
1599 async fn test_extractor_handler_auto_schema() {
1600 use crate::ToolBuilder;
1601
1602 let state = Arc::new("auto-schema".to_string());
1603
1604 let tool = ToolBuilder::new("test_auto_schema")
1606 .description("Test auto schema detection")
1607 .extractor_handler(
1608 state,
1609 |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1610 Ok(CallToolResult::text(format!(
1611 "{}: {} - {}",
1612 state, input.name, input.count
1613 )))
1614 },
1615 )
1616 .build();
1617
1618 let def = tool.definition();
1620 let schema = def.input_schema;
1621 assert!(
1622 schema.get("properties").is_some(),
1623 "Schema should have properties from TestInput, got: {}",
1624 schema
1625 );
1626 let props = schema.get("properties").unwrap();
1627 assert!(
1628 props.get("name").is_some(),
1629 "Schema should have 'name' property"
1630 );
1631 assert!(
1632 props.get("count").is_some(),
1633 "Schema should have 'count' property"
1634 );
1635
1636 let result = tool
1638 .call(serde_json::json!({"name": "world", "count": 99}))
1639 .await;
1640 assert!(!result.is_error);
1641 }
1642
1643 #[test]
1644 fn test_extractor_handler_no_json_fallback() {
1645 use crate::ToolBuilder;
1646
1647 let tool = ToolBuilder::new("test_no_json")
1649 .description("Test no json fallback")
1650 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1651 Ok(CallToolResult::json(args))
1652 })
1653 .build();
1654
1655 let def = tool.definition();
1656 let schema = def.input_schema;
1657 assert_eq!(
1658 schema.get("type").and_then(|v| v.as_str()),
1659 Some("object"),
1660 "Schema should be generic object"
1661 );
1662 assert_eq!(
1663 schema.get("additionalProperties").and_then(|v| v.as_bool()),
1664 Some(true),
1665 "Schema should allow additional properties"
1666 );
1667 assert!(
1669 schema.get("properties").is_none(),
1670 "Generic schema should not have specific properties"
1671 );
1672 }
1673
1674 #[tokio::test]
1675 async fn test_extractor_handler_with_layer() {
1676 use crate::ToolBuilder;
1677 use std::time::Duration;
1678 use tower::timeout::TimeoutLayer;
1679
1680 let state = Arc::new("layered".to_string());
1681
1682 let tool = ToolBuilder::new("test_extractor_layer")
1683 .description("Test extractor handler with layer")
1684 .extractor_handler(
1685 state,
1686 |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1687 Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
1688 },
1689 )
1690 .layer(TimeoutLayer::new(Duration::from_secs(5)))
1691 .build();
1692
1693 let result = tool
1695 .call(serde_json::json!({"name": "test", "count": 1}))
1696 .await;
1697 assert!(!result.is_error);
1698 assert_eq!(result.first_text().unwrap(), "layered: test");
1699
1700 let def = tool.definition();
1702 let schema = def.input_schema;
1703 assert!(
1704 schema.get("properties").is_some(),
1705 "Schema should have properties even with layer"
1706 );
1707 }
1708
1709 #[tokio::test]
1710 async fn test_extractor_handler_with_timeout_layer() {
1711 use crate::ToolBuilder;
1712 use std::time::Duration;
1713 use tower::timeout::TimeoutLayer;
1714
1715 let tool = ToolBuilder::new("test_extractor_timeout")
1716 .description("Test extractor handler timeout")
1717 .extractor_handler((), |Json(input): Json<TestInput>| async move {
1718 tokio::time::sleep(Duration::from_millis(200)).await;
1719 Ok(CallToolResult::text(input.name.to_string()))
1720 })
1721 .layer(TimeoutLayer::new(Duration::from_millis(50)))
1722 .build();
1723
1724 let result = tool
1726 .call(serde_json::json!({"name": "slow", "count": 1}))
1727 .await;
1728 assert!(result.is_error);
1729 let msg = result.first_text().unwrap().to_lowercase();
1730 assert!(
1731 msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
1732 "Expected timeout error, got: {}",
1733 msg
1734 );
1735 }
1736
1737 #[tokio::test]
1738 async fn test_extractor_handler_with_multiple_layers() {
1739 use crate::ToolBuilder;
1740 use std::time::Duration;
1741 use tower::limit::ConcurrencyLimitLayer;
1742 use tower::timeout::TimeoutLayer;
1743
1744 let state = Arc::new("multi".to_string());
1745
1746 let tool = ToolBuilder::new("test_multi_layer")
1747 .description("Test multiple layers")
1748 .extractor_handler(
1749 state,
1750 |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1751 Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
1752 },
1753 )
1754 .layer(TimeoutLayer::new(Duration::from_secs(5)))
1755 .layer(ConcurrencyLimitLayer::new(10))
1756 .build();
1757
1758 let result = tool
1759 .call(serde_json::json!({"name": "test", "count": 1}))
1760 .await;
1761 assert!(!result.is_error);
1762 assert_eq!(result.first_text().unwrap(), "multi: test");
1763 }
1764}