1use std::future::Future;
77use std::marker::PhantomData;
78use std::ops::Deref;
79use std::pin::Pin;
80
81use schemars::JsonSchema;
82use serde::de::DeserializeOwned;
83use serde_json::Value;
84
85use crate::context::RequestContext;
86use crate::error::{Error, Result};
87use crate::protocol::CallToolResult;
88
/// A general-purpose extractor rejection that carries only a human-readable message.
#[derive(Debug, Clone)]
pub struct Rejection {
    /// Text describing why extraction failed; also used verbatim by `Display`.
    message: String,
}

impl Rejection {
    /// Builds a rejection from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }

    /// Borrows the rejection message.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl std::fmt::Display for Rejection {
    /// Writes the message as-is, with no prefix and ignoring formatter flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for Rejection {}
124
125impl From<Rejection> for Error {
126 fn from(rejection: Rejection) -> Self {
127 Error::tool(rejection.message)
128 }
129}
130
/// Rejection produced when the tool arguments cannot be deserialized.
///
/// Optionally records the location (`path`) of the offending value; when
/// present it is included in the `Display` output.
#[derive(Debug, Clone)]
pub struct JsonRejection {
    /// Description of the deserialization failure.
    message: String,
    /// Location of the failing value inside the input, if known.
    path: Option<String>,
}

impl JsonRejection {
    /// Creates a rejection without location information.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            message,
            path: None,
        }
    }

    /// Creates a rejection that also records where in the input it occurred.
    pub fn with_path(message: impl Into<String>, path: impl Into<String>) -> Self {
        // `new` converts the message first, preserving the original conversion order.
        let mut rejection = Self::new(message);
        rejection.path = Some(path.into());
        rejection
    }

    /// Borrows the failure description.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }

    /// Borrows the recorded location, if any.
    pub fn path(&self) -> Option<&str> {
        self.path.as_ref().map(String::as_str)
    }
}

impl std::fmt::Display for JsonRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.path.as_deref() {
            Some(path) => write!(f, "Invalid input at `{}`: {}", path, self.message),
            None => write!(f, "Invalid input: {}", self.message),
        }
    }
}

impl std::error::Error for JsonRejection {}
190
191impl From<JsonRejection> for Error {
192 fn from(rejection: JsonRejection) -> Self {
193 Error::tool(rejection.to_string())
194 }
195}
196
197impl From<serde_json::Error> for JsonRejection {
198 fn from(err: serde_json::Error) -> Self {
199 let path = if err.is_data() {
201 None
204 } else {
205 None
206 };
207
208 Self {
209 message: err.to_string(),
210 path,
211 }
212 }
213}
214
/// Rejection produced when a requested request extension is not present.
#[derive(Debug, Clone)]
pub struct ExtensionRejection {
    /// `std::any::type_name` of the missing extension type.
    type_name: &'static str,
}

impl ExtensionRejection {
    /// Builds the rejection for a missing extension of type `T`.
    pub fn not_found<T>() -> Self {
        let type_name = std::any::type_name::<T>();
        Self { type_name }
    }

    /// Name of the extension type that could not be found.
    pub fn type_name(&self) -> &'static str {
        self.type_name
    }
}

impl std::fmt::Display for ExtensionRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let missing = self.type_name;
        write!(
            f,
            "Extension of type `{}` not found. Did you call `router.with_state()` or `router.with_extension()`?",
            missing
        )
    }
}

impl std::error::Error for ExtensionRejection {}
258
259impl From<ExtensionRejection> for Error {
260 fn from(rejection: ExtensionRejection) -> Self {
261 Error::tool(rejection.to_string())
262 }
263}
264
/// Types that can be extracted from an incoming tool call.
///
/// An extractor receives the per-request context, a reference to the shared
/// state `S` (which defaults to `()`), and the raw JSON arguments. The handler
/// impls in this module call it once per handler argument, passing the same
/// three values each time.
pub trait FromToolRequest<S = ()>: Sized {
    /// Error returned when extraction fails; it must convert into the crate [`Error`].
    type Rejection: Into<Error>;

    /// Attempts to build `Self` from the request.
    ///
    /// # Errors
    ///
    /// Returns `Self::Rejection` when the request does not contain what this
    /// extractor needs.
    fn from_tool_request(
        ctx: &RequestContext,
        state: &S,
        args: &Value,
    ) -> std::result::Result<Self, Self::Rejection>;
}
312
/// Extractor that deserializes the tool's JSON arguments into `T`.
#[derive(Debug, Clone, Copy)]
pub struct Json<T>(pub T);

impl<T> Deref for Json<T> {
    type Target = T;

    /// Exposes the deserialized value directly.
    fn deref(&self) -> &T {
        let Json(inner) = self;
        inner
    }
}
353
354impl<S, T> FromToolRequest<S> for Json<T>
355where
356 T: DeserializeOwned,
357{
358 type Rejection = JsonRejection;
359
360 fn from_tool_request(
361 _ctx: &RequestContext,
362 _state: &S,
363 args: &Value,
364 ) -> std::result::Result<Self, Self::Rejection> {
365 serde_json::from_value(args.clone())
366 .map(Json)
367 .map_err(JsonRejection::from)
368 }
369}
370
/// Extractor that hands the handler a clone of the shared state.
#[derive(Debug, Clone, Copy)]
pub struct State<T>(pub T);

impl<T> Deref for State<T> {
    type Target = T;

    /// Exposes the wrapped state directly.
    fn deref(&self) -> &T {
        let State(inner) = self;
        inner
    }
}
405
406impl<S: Clone> FromToolRequest<S> for State<S> {
407 type Rejection = Rejection;
408
409 fn from_tool_request(
410 _ctx: &RequestContext,
411 state: &S,
412 _args: &Value,
413 ) -> std::result::Result<Self, Self::Rejection> {
414 Ok(State(state.clone()))
415 }
416}
417
418#[derive(Debug, Clone)]
439pub struct Context(RequestContext);
440
441impl Context {
442 pub fn into_inner(self) -> RequestContext {
444 self.0
445 }
446}
447
448impl Deref for Context {
449 type Target = RequestContext;
450
451 fn deref(&self) -> &Self::Target {
452 &self.0
453 }
454}
455
456impl<S> FromToolRequest<S> for Context {
457 type Rejection = Rejection;
458
459 fn from_tool_request(
460 ctx: &RequestContext,
461 _state: &S,
462 _args: &Value,
463 ) -> std::result::Result<Self, Self::Rejection> {
464 Ok(Context(ctx.clone()))
465 }
466}
467
468#[derive(Debug, Clone)]
486pub struct RawArgs(pub Value);
487
488impl Deref for RawArgs {
489 type Target = Value;
490
491 fn deref(&self) -> &Self::Target {
492 &self.0
493 }
494}
495
496impl<S> FromToolRequest<S> for RawArgs {
497 type Rejection = Rejection;
498
499 fn from_tool_request(
500 _ctx: &RequestContext,
501 _state: &S,
502 args: &Value,
503 ) -> std::result::Result<Self, Self::Rejection> {
504 Ok(RawArgs(args.clone()))
505 }
506}
507
/// Extractor that fetches a value of type `T` from the request's extensions.
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);

impl<T> Deref for Extension<T> {
    type Target = T;

    /// Exposes the wrapped extension value directly.
    fn deref(&self) -> &T {
        let Extension(inner) = self;
        inner
    }
}
565
566impl<S, T> FromToolRequest<S> for Extension<T>
567where
568 T: Clone + Send + Sync + 'static,
569{
570 type Rejection = ExtensionRejection;
571
572 fn from_tool_request(
573 ctx: &RequestContext,
574 _state: &S,
575 _args: &Value,
576 ) -> std::result::Result<Self, Self::Rejection> {
577 ctx.extension::<T>()
578 .cloned()
579 .map(Extension)
580 .ok_or_else(ExtensionRejection::not_found::<T>)
581 }
582}
583
/// A tool handler whose arguments are produced by [`FromToolRequest`] extractors.
///
/// `S` is the shared state type and `T` is the tuple of extractor types taken by
/// the handler. This module implements it for functions and closures taking
/// one to five extractor arguments.
pub trait ExtractorHandler<S, T>: Clone + Send + Sync + 'static {
    /// Future resolving to the tool result.
    type Future: Future<Output = Result<CallToolResult>> + Send;

    /// Handles one tool call given its context, a copy of the state, and the raw arguments.
    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;

    /// JSON Schema to advertise as the tool's input schema.
    fn input_schema() -> Value;
}
605
606impl<S, F, Fut, T1> ExtractorHandler<S, (T1,)> for F
608where
609 S: Clone + Send + Sync + 'static,
610 F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
611 Fut: Future<Output = Result<CallToolResult>> + Send,
612 T1: FromToolRequest<S> + Send,
613{
614 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
615
616 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
617 Box::pin(async move {
618 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
619 self(t1).await
620 })
621 }
622
623 fn input_schema() -> Value {
624 serde_json::json!({
626 "type": "object",
627 "additionalProperties": true
628 })
629 }
630}
631
632impl<S, F, Fut, T1, T2> ExtractorHandler<S, (T1, T2)> for F
634where
635 S: Clone + Send + Sync + 'static,
636 F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
637 Fut: Future<Output = Result<CallToolResult>> + Send,
638 T1: FromToolRequest<S> + Send,
639 T2: FromToolRequest<S> + Send,
640{
641 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
642
643 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
644 Box::pin(async move {
645 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
646 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
647 self(t1, t2).await
648 })
649 }
650
651 fn input_schema() -> Value {
652 serde_json::json!({
653 "type": "object",
654 "additionalProperties": true
655 })
656 }
657}
658
659impl<S, F, Fut, T1, T2, T3> ExtractorHandler<S, (T1, T2, T3)> for F
661where
662 S: Clone + Send + Sync + 'static,
663 F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
664 Fut: Future<Output = Result<CallToolResult>> + Send,
665 T1: FromToolRequest<S> + Send,
666 T2: FromToolRequest<S> + Send,
667 T3: FromToolRequest<S> + Send,
668{
669 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
670
671 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
672 Box::pin(async move {
673 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
674 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
675 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
676 self(t1, t2, t3).await
677 })
678 }
679
680 fn input_schema() -> Value {
681 serde_json::json!({
682 "type": "object",
683 "additionalProperties": true
684 })
685 }
686}
687
688impl<S, F, Fut, T1, T2, T3, T4> ExtractorHandler<S, (T1, T2, T3, T4)> for F
690where
691 S: Clone + Send + Sync + 'static,
692 F: Fn(T1, T2, T3, T4) -> Fut + Clone + Send + Sync + 'static,
693 Fut: Future<Output = Result<CallToolResult>> + Send,
694 T1: FromToolRequest<S> + Send,
695 T2: FromToolRequest<S> + Send,
696 T3: FromToolRequest<S> + Send,
697 T4: FromToolRequest<S> + Send,
698{
699 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
700
701 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
702 Box::pin(async move {
703 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
704 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
705 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
706 let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
707 self(t1, t2, t3, t4).await
708 })
709 }
710
711 fn input_schema() -> Value {
712 serde_json::json!({
713 "type": "object",
714 "additionalProperties": true
715 })
716 }
717}
718
719impl<S, F, Fut, T1, T2, T3, T4, T5> ExtractorHandler<S, (T1, T2, T3, T4, T5)> for F
721where
722 S: Clone + Send + Sync + 'static,
723 F: Fn(T1, T2, T3, T4, T5) -> Fut + Clone + Send + Sync + 'static,
724 Fut: Future<Output = Result<CallToolResult>> + Send,
725 T1: FromToolRequest<S> + Send,
726 T2: FromToolRequest<S> + Send,
727 T3: FromToolRequest<S> + Send,
728 T4: FromToolRequest<S> + Send,
729 T5: FromToolRequest<S> + Send,
730{
731 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
732
733 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
734 Box::pin(async move {
735 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
736 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
737 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
738 let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
739 let t5 = T5::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
740 self(t1, t2, t3, t4, t5).await
741 })
742 }
743
744 fn input_schema() -> Value {
745 serde_json::json!({
746 "type": "object",
747 "additionalProperties": true
748 })
749 }
750}
751
/// Extractors that may contribute a JSON Schema describing the tool input.
pub trait HasSchema {
    /// Returns this extractor's input schema, or `None` if it contributes none
    /// (as the `Context`, `RawArgs` and `State` impls in this module do).
    fn schema() -> Option<Value>;
}
760
761impl<T: JsonSchema> HasSchema for Json<T> {
762 fn schema() -> Option<Value> {
763 let schema = schemars::schema_for!(T);
764 serde_json::to_value(schema).ok()
765 }
766}
767
768impl HasSchema for Context {
770 fn schema() -> Option<Value> {
771 None
772 }
773}
774
775impl HasSchema for RawArgs {
776 fn schema() -> Option<Value> {
777 None
778 }
779}
780
781impl<T> HasSchema for State<T> {
782 fn schema() -> Option<Value> {
783 None
784 }
785}
786
/// An extractor-based handler whose tool input type `I` is known statically.
///
/// `T` is the tuple of handler argument types. The impls in this module cover
/// handlers whose final argument is `Json<I>`; `I: JsonSchema` lets the tool's
/// input schema be generated from `I`.
pub trait TypedExtractorHandler<S, T, I>: Clone + Send + Sync + 'static
where
    I: JsonSchema,
{
    /// Future resolving to the tool result.
    type Future: Future<Output = Result<CallToolResult>> + Send;

    /// Handles one tool call given its context, a copy of the state, and the raw arguments.
    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
}
805
806impl<S, F, Fut, T> TypedExtractorHandler<S, (Json<T>,), T> for F
808where
809 S: Clone + Send + Sync + 'static,
810 F: Fn(Json<T>) -> Fut + Clone + Send + Sync + 'static,
811 Fut: Future<Output = Result<CallToolResult>> + Send,
812 T: DeserializeOwned + JsonSchema + Send,
813{
814 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
815
816 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
817 Box::pin(async move {
818 let t1 =
819 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
820 self(t1).await
821 })
822 }
823}
824
825impl<S, F, Fut, T1, T> TypedExtractorHandler<S, (T1, Json<T>), T> for F
827where
828 S: Clone + Send + Sync + 'static,
829 F: Fn(T1, Json<T>) -> Fut + Clone + Send + Sync + 'static,
830 Fut: Future<Output = Result<CallToolResult>> + Send,
831 T1: FromToolRequest<S> + Send,
832 T: DeserializeOwned + JsonSchema + Send,
833{
834 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
835
836 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
837 Box::pin(async move {
838 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
839 let t2 =
840 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
841 self(t1, t2).await
842 })
843 }
844}
845
846impl<S, F, Fut, T1, T2, T> TypedExtractorHandler<S, (T1, T2, Json<T>), T> for F
848where
849 S: Clone + Send + Sync + 'static,
850 F: Fn(T1, T2, Json<T>) -> Fut + Clone + Send + Sync + 'static,
851 Fut: Future<Output = Result<CallToolResult>> + Send,
852 T1: FromToolRequest<S> + Send,
853 T2: FromToolRequest<S> + Send,
854 T: DeserializeOwned + JsonSchema + Send,
855{
856 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
857
858 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
859 Box::pin(async move {
860 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
861 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
862 let t3 =
863 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
864 self(t1, t2, t3).await
865 })
866 }
867}
868
869impl<S, F, Fut, T1, T2, T3, T> TypedExtractorHandler<S, (T1, T2, T3, Json<T>), T> for F
871where
872 S: Clone + Send + Sync + 'static,
873 F: Fn(T1, T2, T3, Json<T>) -> Fut + Clone + Send + Sync + 'static,
874 Fut: Future<Output = Result<CallToolResult>> + Send,
875 T1: FromToolRequest<S> + Send,
876 T2: FromToolRequest<S> + Send,
877 T3: FromToolRequest<S> + Send,
878 T: DeserializeOwned + JsonSchema + Send,
879{
880 type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
881
882 fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
883 Box::pin(async move {
884 let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
885 let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
886 let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
887 let t4 =
888 Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
889 self(t1, t2, t3, t4).await
890 })
891 }
892}
893
894use crate::tool::{BoxFuture, Tool, ToolCatchError, ToolHandler, validate_tool_name};
899use tower::util::BoxCloneService;
900
901pub(crate) struct ExtractorToolHandler<S, F, T> {
903 state: S,
904 handler: F,
905 input_schema: Value,
906 _phantom: PhantomData<T>,
907}
908
909impl<S, F, T> ToolHandler for ExtractorToolHandler<S, F, T>
910where
911 S: Clone + Send + Sync + 'static,
912 F: ExtractorHandler<S, T> + Clone,
913 T: Send + Sync + 'static,
914{
915 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
916 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
917 self.call_with_context(ctx, args)
918 }
919
920 fn call_with_context(
921 &self,
922 ctx: RequestContext,
923 args: Value,
924 ) -> BoxFuture<'_, Result<CallToolResult>> {
925 let state = self.state.clone();
926 let handler = self.handler.clone();
927 Box::pin(async move { handler.call(ctx, state, args).await })
928 }
929
930 fn uses_context(&self) -> bool {
931 true
932 }
933
934 fn input_schema(&self) -> Value {
935 self.input_schema.clone()
936 }
937}
938
939pub struct ToolBuilderWithExtractor<S, F, T> {
941 pub(crate) name: String,
942 pub(crate) title: Option<String>,
943 pub(crate) description: Option<String>,
944 pub(crate) output_schema: Option<Value>,
945 pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
946 pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
947 pub(crate) state: S,
948 pub(crate) handler: F,
949 pub(crate) input_schema: Value,
950 pub(crate) _phantom: PhantomData<T>,
951}
952
953impl<S, F, T> ToolBuilderWithExtractor<S, F, T>
954where
955 S: Clone + Send + Sync + 'static,
956 F: ExtractorHandler<S, T> + Clone,
957 T: Send + Sync + 'static,
958{
959 pub fn build(self) -> Result<Tool> {
963 validate_tool_name(&self.name)?;
964
965 let handler = ExtractorToolHandler {
966 state: self.state,
967 handler: self.handler,
968 input_schema: self.input_schema.clone(),
969 _phantom: PhantomData,
970 };
971
972 let handler_service = crate::tool::ToolHandlerService::new(handler);
973 let catch_error = ToolCatchError::new(handler_service);
974 let service = BoxCloneService::new(catch_error);
975
976 Ok(Tool {
977 name: self.name,
978 title: self.title,
979 description: self.description,
980 output_schema: self.output_schema,
981 icons: self.icons,
982 annotations: self.annotations,
983 service,
984 input_schema: self.input_schema,
985 })
986 }
987}
988
989pub struct ToolBuilderWithTypedExtractor<S, F, T, I> {
991 pub(crate) name: String,
992 pub(crate) title: Option<String>,
993 pub(crate) description: Option<String>,
994 pub(crate) output_schema: Option<Value>,
995 pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
996 pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
997 pub(crate) state: S,
998 pub(crate) handler: F,
999 pub(crate) _phantom: PhantomData<(T, I)>,
1000}
1001
1002impl<S, F, T, I> ToolBuilderWithTypedExtractor<S, F, T, I>
1003where
1004 S: Clone + Send + Sync + 'static,
1005 F: TypedExtractorHandler<S, T, I> + Clone,
1006 T: Send + Sync + 'static,
1007 I: JsonSchema + Send + Sync + 'static,
1008{
1009 pub fn build(self) -> Result<Tool> {
1013 validate_tool_name(&self.name)?;
1014
1015 let input_schema = {
1016 let schema = schemars::schema_for!(I);
1017 serde_json::to_value(schema).unwrap_or_else(|_| {
1018 serde_json::json!({
1019 "type": "object"
1020 })
1021 })
1022 };
1023
1024 let handler = TypedExtractorToolHandler {
1025 state: self.state,
1026 handler: self.handler,
1027 input_schema: input_schema.clone(),
1028 _phantom: PhantomData,
1029 };
1030
1031 let handler_service = crate::tool::ToolHandlerService::new(handler);
1032 let catch_error = ToolCatchError::new(handler_service);
1033 let service = BoxCloneService::new(catch_error);
1034
1035 Ok(Tool {
1036 name: self.name,
1037 title: self.title,
1038 description: self.description,
1039 output_schema: self.output_schema,
1040 icons: self.icons,
1041 annotations: self.annotations,
1042 service,
1043 input_schema,
1044 })
1045 }
1046}
1047
1048struct TypedExtractorToolHandler<S, F, T, I> {
1050 state: S,
1051 handler: F,
1052 input_schema: Value,
1053 _phantom: PhantomData<(T, I)>,
1054}
1055
1056impl<S, F, T, I> ToolHandler for TypedExtractorToolHandler<S, F, T, I>
1057where
1058 S: Clone + Send + Sync + 'static,
1059 F: TypedExtractorHandler<S, T, I> + Clone,
1060 T: Send + Sync + 'static,
1061 I: JsonSchema + Send + Sync + 'static,
1062{
1063 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1064 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
1065 self.call_with_context(ctx, args)
1066 }
1067
1068 fn call_with_context(
1069 &self,
1070 ctx: RequestContext,
1071 args: Value,
1072 ) -> BoxFuture<'_, Result<CallToolResult>> {
1073 let state = self.state.clone();
1074 let handler = self.handler.clone();
1075 Box::pin(async move { handler.call(ctx, state, args).await })
1076 }
1077
1078 fn uses_context(&self) -> bool {
1079 true
1080 }
1081
1082 fn input_schema(&self) -> Value {
1083 self.input_schema.clone()
1084 }
1085}
1086
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RequestId;
    use schemars::JsonSchema;
    use serde::Deserialize;
    use std::sync::Arc;

    /// Argument payload shared by the extraction, schema and handler tests.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct TestInput {
        name: String,
        count: i32,
    }

    #[test]
    fn test_json_extraction() {
        let args = serde_json::json!({"name": "test", "count": 42});
        let ctx = RequestContext::new(RequestId::Number(1));

        let Json(input) = Json::<TestInput>::from_tool_request(&ctx, &(), &args)
            .expect("well-formed arguments should deserialize");
        assert_eq!(input.name, "test");
        assert_eq!(input.count, 42);
    }

    #[test]
    fn test_json_extraction_error() {
        // `count` is a required field of `TestInput` and is deliberately omitted.
        let args = serde_json::json!({"name": "test"});
        let ctx = RequestContext::new(RequestId::Number(1));

        let rejection = Json::<TestInput>::from_tool_request(&ctx, &(), &args)
            .expect_err("missing `count` must be rejected");
        assert!(rejection.message().contains("count"));
    }

    #[test]
    fn test_state_extraction() {
        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("my-state".to_string());

        let State(extracted) = State::<Arc<String>>::from_tool_request(&ctx, &state, &args)
            .expect("state extraction always succeeds");
        assert_eq!(*extracted, "my-state");
    }

    #[test]
    fn test_context_extraction() {
        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(42));

        let extracted = Context::from_tool_request(&ctx, &(), &args)
            .expect("context extraction always succeeds");
        assert_eq!(*extracted.request_id(), RequestId::Number(42));
    }

    #[test]
    fn test_raw_args_extraction() {
        let args = serde_json::json!({"foo": "bar", "baz": 123});
        let ctx = RequestContext::new(RequestId::Number(1));

        let RawArgs(extracted) = RawArgs::from_tool_request(&ctx, &(), &args)
            .expect("raw-args extraction always succeeds");
        assert_eq!(extracted["foo"], "bar");
        assert_eq!(extracted["baz"], 123);
    }

    #[test]
    fn test_extension_extraction() {
        use crate::context::Extensions;

        #[derive(Clone, Debug, PartialEq)]
        struct DatabasePool {
            url: String,
        }

        let args = serde_json::json!({});

        let mut extensions = Extensions::new();
        extensions.insert(Arc::new(DatabasePool {
            url: "postgres://localhost".to_string(),
        }));

        let ctx = RequestContext::new(RequestId::Number(1)).with_extensions(Arc::new(extensions));

        let Extension(pool) = Extension::<Arc<DatabasePool>>::from_tool_request(&ctx, &(), &args)
            .expect("inserted extension should be found");
        assert_eq!(pool.url, "postgres://localhost");
    }

    #[test]
    fn test_extension_extraction_missing() {
        #[derive(Clone, Debug)]
        struct NotPresent;

        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(1));

        let rejection = Extension::<NotPresent>::from_tool_request(&ctx, &(), &args)
            .expect_err("absent extension must be rejected");
        assert!(rejection.type_name().contains("NotPresent"));
    }

    #[tokio::test]
    async fn test_single_extractor_handler() {
        let handler = |Json(input): Json<TestInput>| async move {
            Ok(CallToolResult::text(format!(
                "{}: {}",
                input.name, input.count
            )))
        };

        let ctx = RequestContext::new(RequestId::Number(1));
        let args = serde_json::json!({"name": "test", "count": 5});

        let result: Result<CallToolResult> =
            ExtractorHandler::<(), (Json<TestInput>,)>::call(handler, ctx, (), args).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_two_extractor_handler() {
        let handler = |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
            Ok(CallToolResult::text(format!(
                "{}: {} - {}",
                state, input.name, input.count
            )))
        };

        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("prefix".to_string());
        let args = serde_json::json!({"name": "test", "count": 5});

        let result: Result<CallToolResult> = ExtractorHandler::<
            Arc<String>,
            (State<Arc<String>>, Json<TestInput>),
        >::call(handler, ctx, state, args)
        .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_three_extractor_handler() {
        let handler = |State(state): State<Arc<String>>,
                       ctx: Context,
                       Json(input): Json<TestInput>| async move {
            assert!(!ctx.is_cancelled());
            Ok(CallToolResult::text(format!(
                "{}: {} - {}",
                state, input.name, input.count
            )))
        };

        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("prefix".to_string());
        let args = serde_json::json!({"name": "test", "count": 5});

        let result: Result<CallToolResult> = ExtractorHandler::<
            Arc<String>,
            (State<Arc<String>>, Context, Json<TestInput>),
        >::call(handler, ctx, state, args)
        .await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_json_schema_generation() {
        let schema = Json::<TestInput>::schema().expect("TestInput derives JsonSchema");
        assert!(schema.get("properties").is_some());
    }

    #[test]
    fn test_rejection_into_error() {
        let rejection = Rejection::new("test error");
        let error: Error = rejection.into();
        assert!(error.to_string().contains("test error"));
    }

    #[test]
    fn test_json_rejection() {
        let rejection = JsonRejection::new("missing field `name`");
        assert_eq!(rejection.message(), "missing field `name`");
        assert!(rejection.path().is_none());
        assert!(rejection.to_string().contains("Invalid input"));

        let rejection = JsonRejection::with_path("expected string", "users[0].name");
        assert_eq!(rejection.message(), "expected string");
        assert_eq!(rejection.path(), Some("users[0].name"));
        assert!(rejection.to_string().contains("users[0].name"));

        let error: Error = rejection.into();
        assert!(error.to_string().contains("users[0].name"));
    }

    #[test]
    fn test_json_rejection_from_serde_error() {
        #[derive(Debug, serde::Deserialize)]
        struct TestStruct {
            #[allow(dead_code)] // only deserialization failure is under test
            name: String,
        }

        let result: std::result::Result<TestStruct, _> =
            serde_json::from_value(serde_json::json!({"count": 42}));
        let rejection: JsonRejection = result
            .expect_err("missing `name` must fail to deserialize")
            .into();
        assert!(rejection.message().contains("name"));
    }

    #[test]
    fn test_extension_rejection() {
        let rejection = ExtensionRejection::not_found::<String>();
        assert!(rejection.type_name().contains("String"));
        assert!(rejection.to_string().contains("not found"));
        assert!(rejection.to_string().contains("with_state"));

        let error: Error = rejection.into();
        assert!(error.to_string().contains("not found"));
    }

    #[tokio::test]
    async fn test_tool_builder_extractor_handler() {
        use crate::ToolBuilder;

        let state = Arc::new("shared-state".to_string());

        let tool = ToolBuilder::new("test_extractor")
            .description("Test extractor handler")
            .extractor_handler(
                state,
                |State(state): State<Arc<String>>,
                 ctx: Context,
                 Json(input): Json<TestInput>| async move {
                    assert!(!ctx.is_cancelled());
                    Ok(CallToolResult::text(format!(
                        "{}: {} - {}",
                        state, input.name, input.count
                    )))
                },
            )
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "test_extractor");
        assert_eq!(tool.description.as_deref(), Some("Test extractor handler"));

        let result = tool
            .call(serde_json::json!({"name": "test", "count": 42}))
            .await;
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn test_tool_builder_extractor_handler_typed() {
        use crate::ToolBuilder;

        let state = Arc::new("typed-state".to_string());

        let tool = ToolBuilder::new("test_typed")
            .description("Test typed extractor handler")
            .extractor_handler_typed::<_, _, _, TestInput>(
                state,
                |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
                    Ok(CallToolResult::text(format!(
                        "{}: {} - {}",
                        state, input.name, input.count
                    )))
                },
            )
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "test_typed");

        // The typed builder derives the input schema from `TestInput`.
        let def = tool.definition();
        let schema = def.input_schema;
        assert!(schema.get("properties").is_some());

        let result = tool
            .call(serde_json::json!({"name": "world", "count": 99}))
            .await;
        assert!(!result.is_error);
    }
}