Skip to main content

tower_mcp/
extract.rs

1//! Extractor pattern for tool handlers
2//!
3//! This module provides an axum-inspired extractor pattern that makes state and context
4//! injection more declarative, reducing the combinatorial explosion of handler variants.
5//!
6//! # Overview
7//!
8//! Extractors implement [`FromToolRequest`], which extracts data from the tool request
9//! (context, state, and arguments). Multiple extractors can be combined in handler
10//! function parameters.
11//!
12//! # Built-in Extractors
13//!
14//! - [`Json<T>`] - Extract typed input from args (deserializes JSON)
15//! - [`State<T>`] - Extract shared state from per-tool state (cloned for each request)
16//! - [`Extension<T>`] - Extract data from router extensions (via `router.with_state()`)
17//! - [`Context`] - Extract the [`RequestContext`] for progress, cancellation, etc.
18//! - [`RawArgs`] - Extract raw `serde_json::Value` arguments
19//!
20//! ## State vs Extension
21//!
22//! - Use **`State<T>`** when state is passed directly to `extractor_handler()` (per-tool state)
23//! - Use **`Extension<T>`** when state is set via `McpRouter::with_state()` (router-level state)
24//!
25//! # Example
26//!
27//! ```rust
28//! use std::sync::Arc;
29//! use tower_mcp::{ToolBuilder, CallToolResult};
30//! use tower_mcp::extract::{Json, State, Context};
31//! use schemars::JsonSchema;
32//! use serde::Deserialize;
33//!
34//! #[derive(Clone)]
35//! struct AppState {
36//!     db_url: String,
37//! }
38//!
39//! #[derive(Debug, Deserialize, JsonSchema)]
40//! struct QueryInput {
41//!     query: String,
42//! }
43//!
44//! let state = Arc::new(AppState { db_url: "postgres://...".to_string() });
45//!
46//! let tool = ToolBuilder::new("search")
47//!     .description("Search the database")
48//!     .extractor_handler(state, |
49//!         State(db): State<Arc<AppState>>,
50//!         ctx: Context,
51//!         Json(input): Json<QueryInput>,
52//!     | async move {
53//!         // Check cancellation
54//!         if ctx.is_cancelled() {
55//!             return Ok(CallToolResult::error("Cancelled"));
56//!         }
57//!         // Report progress
58//!         ctx.report_progress(0.5, Some(1.0), Some("Searching...")).await;
59//!         // Use state
60//!         Ok(CallToolResult::text(format!("Searched {} with query: {}", db.db_url, input.query)))
61//!     })
62//!     .build();
63//! ```
64//!
65//! # Extractor Order
66//!
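//! Each extractor independently extracts its data from the request, so the order of
//! extractors in the function signature doesn't affect what any of them receives.
//! Extractors run left to right and the first rejection is returned without running
//! the rest. If more than one extractor provides an input schema, the right-most wins.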
69//!
70//! # Error Handling
71//!
72//! If an extractor fails (e.g., JSON deserialization fails), the handler returns
73//! a `CallToolResult::error()` with the rejection message.
74
75use std::future::Future;
76use std::marker::PhantomData;
77use std::ops::Deref;
78use std::pin::Pin;
79
80use schemars::JsonSchema;
81use serde::de::DeserializeOwned;
82use serde_json::Value;
83
84use crate::context::RequestContext;
85use crate::error::{Error, Result};
86use crate::protocol::CallToolResult;
87
88// =============================================================================
89// Rejection Types
90// =============================================================================
91
/// A general-purpose rejection that carries only a human-readable message.
///
/// Intended for custom extractors that have nothing more structured to
/// report. When more specific error information is available, prefer a typed
/// rejection such as [`JsonRejection`] or [`ExtensionRejection`].
#[derive(Debug, Clone)]
pub struct Rejection {
    message: String,
}

impl Rejection {
    /// Build a rejection from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }

    /// Borrow the message this rejection was created with.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl std::fmt::Display for Rejection {
    /// Writes the message verbatim, without any prefix or decoration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for Rejection {}
123
124impl From<Rejection> for Error {
125    fn from(rejection: Rejection) -> Self {
126        Error::tool(rejection.message)
127    }
128}
129
130/// Rejection returned when JSON deserialization fails.
131///
132/// This rejection provides structured information about the deserialization
133/// error, including the path to the failing field when available.
134///
135/// # Example
136///
137/// ```rust
138/// use tower_mcp::extract::JsonRejection;
139///
140/// let rejection = JsonRejection::new("missing field `name`");
141/// assert!(rejection.message().contains("name"));
142/// ```
#[derive(Debug, Clone)]
pub struct JsonRejection {
    /// Human-readable description of what went wrong.
    message: String,
    /// Location of the failing field (e.g., "users[0].name"), when known.
    path: Option<String>,
}

impl JsonRejection {
    /// Build a rejection that carries a message and no field path.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            message,
            path: None,
        }
    }

    /// Build a rejection that also records the path to the failing field.
    pub fn with_path(message: impl Into<String>, path: impl Into<String>) -> Self {
        let message = message.into();
        let path = Some(path.into());
        Self { message, path }
    }

    /// Borrow the error message.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }

    /// Borrow the path to the failing field, or `None` when none was recorded.
    pub fn path(&self) -> Option<&str> {
        self.path.as_ref().map(String::as_str)
    }
}

impl std::fmt::Display for JsonRejection {
    /// Renders `Invalid input at `<path>`: <message>` when a path is present,
    /// and `Invalid input: <message>` otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.path.as_deref() {
            Some(location) => write!(f, "Invalid input at `{}`: {}", location, self.message),
            None => write!(f, "Invalid input: {}", self.message),
        }
    }
}

impl std::error::Error for JsonRejection {}
189
190impl From<JsonRejection> for Error {
191    fn from(rejection: JsonRejection) -> Self {
192        Error::tool(rejection.to_string())
193    }
194}
195
196impl From<serde_json::Error> for JsonRejection {
197    fn from(err: serde_json::Error) -> Self {
198        // Try to extract path information from serde error
199        let path = if err.is_data() {
200            // serde_json provides line/column but not field path in the error itself
201            // The path is embedded in the message for some error types
202            None
203        } else {
204            None
205        };
206
207        Self {
208            message: err.to_string(),
209            path,
210        }
211    }
212}
213
214/// Rejection returned when an extension is not found.
215///
216/// This rejection is returned by the [`Extension`] extractor when the
217/// requested type is not present in the router's extensions.
218///
219/// # Example
220///
221/// ```rust
222/// use tower_mcp::extract::ExtensionRejection;
223///
224/// let rejection = ExtensionRejection::not_found::<String>();
225/// assert!(rejection.type_name().contains("String"));
226/// ```
#[derive(Debug, Clone)]
pub struct ExtensionRejection {
    /// Name of the requested type, as reported by `std::any::type_name`.
    type_name: &'static str,
}

impl ExtensionRejection {
    /// Build a rejection recording that an extension of type `T` was missing.
    pub fn not_found<T>() -> Self {
        let type_name = std::any::type_name::<T>();
        Self { type_name }
    }

    /// Name of the extension type that could not be found.
    pub fn type_name(&self) -> &'static str {
        self.type_name
    }
}

impl std::fmt::Display for ExtensionRejection {
    /// Names the missing type and hints at the router calls that register extensions.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let missing = self.type_name;
        write!(
            f,
            "Extension of type `{}` not found. Did you call `router.with_state()` or `router.with_extension()`?",
            missing
        )
    }
}

impl std::error::Error for ExtensionRejection {}
257
258impl From<ExtensionRejection> for Error {
259    fn from(rejection: ExtensionRejection) -> Self {
260        Error::tool(rejection.to_string())
261    }
262}
263
/// Trait for extracting data from a tool request.
///
/// Implement this trait to create custom extractors that can be used
/// in `extractor_handler` functions. Every extractor receives the same three
/// inputs (request context, handler state, raw JSON arguments) by reference
/// and builds its own value from them.
///
/// # Type Parameters
///
/// - `S` - The state type. Defaults to `()` for extractors that don't need state.
///
/// # Example
///
/// ```rust
/// use tower_mcp::extract::{FromToolRequest, Rejection};
/// use tower_mcp::RequestContext;
/// use serde_json::Value;
///
/// struct RequestId(String);
///
/// impl<S> FromToolRequest<S> for RequestId {
///     type Rejection = Rejection;
///
///     fn from_tool_request(
///         ctx: &RequestContext,
///         _state: &S,
///         _args: &Value,
///     ) -> Result<Self, Self::Rejection> {
///         Ok(RequestId(format!("{:?}", ctx.request_id())))
///     }
/// }
/// ```
pub trait FromToolRequest<S = ()>: Sized {
    /// The rejection type returned when extraction fails.
    ///
    /// It must be convertible into the crate's [`Error`] so that handlers can
    /// report a failed extraction as an ordinary error.
    type Rejection: Into<Error>;

    /// Extract this type from the tool request.
    ///
    /// # Arguments
    ///
    /// * `ctx` - The request context with progress, cancellation, etc.
    /// * `state` - The shared state passed to the handler
    /// * `args` - The raw JSON arguments to the tool
    ///
    /// # Errors
    ///
    /// Returns `Self::Rejection` when the value cannot be produced from the
    /// given inputs.
    fn from_tool_request(
        ctx: &RequestContext,
        state: &S,
        args: &Value,
    ) -> std::result::Result<Self, Self::Rejection>;
}
311
312// =============================================================================
313// Built-in Extractors
314// =============================================================================
315
316/// Extract and deserialize JSON arguments into a typed struct.
317///
318/// This extractor deserializes the tool's JSON arguments into type `T`.
319/// The type must implement [`serde::de::DeserializeOwned`] and [`schemars::JsonSchema`].
320///
321/// # Example
322///
323/// ```rust
324/// use tower_mcp::extract::Json;
325/// use schemars::JsonSchema;
326/// use serde::Deserialize;
327///
328/// #[derive(Debug, Deserialize, JsonSchema)]
329/// struct MyInput {
330///     name: String,
331///     count: i32,
332/// }
333///
334/// // In an extractor handler:
335/// // |Json(input): Json<MyInput>| async move { ... }
336/// ```
337///
338/// # Rejection
339///
340/// Returns a [`JsonRejection`] if deserialization fails. The rejection contains
341/// the error message and potentially the path to the failing field.
#[derive(Debug, Clone, Copy)]
pub struct Json<T>(pub T);

impl<T> Deref for Json<T> {
    type Target = T;

    /// Borrow the deserialized value held by this extractor.
    fn deref(&self) -> &T {
        let Json(inner) = self;
        inner
    }
}
352
353impl<S, T> FromToolRequest<S> for Json<T>
354where
355    T: DeserializeOwned,
356{
357    type Rejection = JsonRejection;
358
359    fn from_tool_request(
360        _ctx: &RequestContext,
361        _state: &S,
362        args: &Value,
363    ) -> std::result::Result<Self, Self::Rejection> {
364        serde_json::from_value(args.clone())
365            .map(Json)
366            .map_err(JsonRejection::from)
367    }
368}
369
370/// Extract shared state.
371///
372/// This extractor clones the state passed to `extractor_handler` and provides
373/// it to the handler. The state type must match the type passed to the builder.
374///
375/// # Example
376///
377/// ```rust
378/// use std::sync::Arc;
379/// use tower_mcp::extract::State;
380///
381/// #[derive(Clone)]
382/// struct AppState {
383///     db_url: String,
384/// }
385///
386/// // In an extractor handler:
387/// // |State(state): State<Arc<AppState>>| async move { ... }
388/// ```
389///
390/// # Note
391///
392/// For expensive-to-clone types, wrap them in `Arc` before passing to
393/// `extractor_handler`.
#[derive(Debug, Clone, Copy)]
pub struct State<T>(pub T);

impl<T> Deref for State<T> {
    type Target = T;

    /// Borrow the state value held by this extractor.
    fn deref(&self) -> &T {
        let State(inner) = self;
        inner
    }
}
404
405impl<S: Clone> FromToolRequest<S> for State<S> {
406    type Rejection = Rejection;
407
408    fn from_tool_request(
409        _ctx: &RequestContext,
410        state: &S,
411        _args: &Value,
412    ) -> std::result::Result<Self, Self::Rejection> {
413        Ok(State(state.clone()))
414    }
415}
416
417/// Extract the request context.
418///
419/// This extractor provides access to the [`RequestContext`], which contains:
420/// - Progress reporting via `report_progress()`
421/// - Cancellation checking via `is_cancelled()`
422/// - Sampling capabilities via `sample()`
423/// - Elicitation capabilities via `elicit_form()` and `elicit_url()`
424/// - Log sending via `send_log()`
425///
426/// # Example
427///
428/// ```rust
429/// use tower_mcp::extract::Context;
430///
431/// // In an extractor handler:
432/// // |ctx: Context| async move {
433/// //     ctx.report_progress(0.5, Some(1.0), Some("Working...")).await;
434/// //     // ...
435/// // }
436/// ```
437#[derive(Debug, Clone)]
438pub struct Context(RequestContext);
439
440impl Context {
441    /// Get the inner RequestContext
442    pub fn into_inner(self) -> RequestContext {
443        self.0
444    }
445}
446
447impl Deref for Context {
448    type Target = RequestContext;
449
450    fn deref(&self) -> &Self::Target {
451        &self.0
452    }
453}
454
455impl<S> FromToolRequest<S> for Context {
456    type Rejection = Rejection;
457
458    fn from_tool_request(
459        ctx: &RequestContext,
460        _state: &S,
461        _args: &Value,
462    ) -> std::result::Result<Self, Self::Rejection> {
463        Ok(Context(ctx.clone()))
464    }
465}
466
467/// Extract raw JSON arguments.
468///
469/// This extractor provides the raw `serde_json::Value` arguments without
470/// any deserialization. Useful when you need full control over argument
471/// parsing or when the schema is dynamic.
472///
473/// # Example
474///
475/// ```rust
476/// use tower_mcp::extract::RawArgs;
477///
478/// // In an extractor handler:
479/// // |RawArgs(args): RawArgs| async move {
480/// //     // args is serde_json::Value
481/// //     if let Some(name) = args.get("name") { ... }
482/// // }
483/// ```
484#[derive(Debug, Clone)]
485pub struct RawArgs(pub Value);
486
487impl Deref for RawArgs {
488    type Target = Value;
489
490    fn deref(&self) -> &Self::Target {
491        &self.0
492    }
493}
494
495impl<S> FromToolRequest<S> for RawArgs {
496    type Rejection = Rejection;
497
498    fn from_tool_request(
499        _ctx: &RequestContext,
500        _state: &S,
501        args: &Value,
502    ) -> std::result::Result<Self, Self::Rejection> {
503        Ok(RawArgs(args.clone()))
504    }
505}
506
507/// Extract typed data from router extensions.
508///
509/// This extractor retrieves data that was added to the router via
510/// [`crate::McpRouter::with_state()`] or [`crate::McpRouter::with_extension()`], or
511/// inserted by middleware into the request context's extensions.
512///
513/// # Example
514///
515/// ```rust
516/// use std::sync::Arc;
517/// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
518/// use tower_mcp::extract::{Extension, Json};
519/// use schemars::JsonSchema;
520/// use serde::Deserialize;
521///
522/// #[derive(Clone)]
523/// struct DatabasePool {
524///     url: String,
525/// }
526///
527/// #[derive(Deserialize, JsonSchema)]
528/// struct QueryInput {
529///     sql: String,
530/// }
531///
532/// let pool = Arc::new(DatabasePool { url: "postgres://...".into() });
533///
534/// let tool = ToolBuilder::new("query")
535///     .description("Run a query")
536///     .extractor_handler(
537///         (),
538///         |Extension(db): Extension<Arc<DatabasePool>>, Json(input): Json<QueryInput>| async move {
539///             Ok(CallToolResult::text(format!("Query on {}: {}", db.url, input.sql)))
540///         },
541///     )
542///     .build();
543///
544/// let router = McpRouter::new()
545///     .with_state(pool)
546///     .tool(tool);
547/// ```
548///
549/// # Rejection
550///
551/// Returns an [`ExtensionRejection`] if the requested type is not found in the extensions.
552/// The rejection contains the type name of the missing extension.
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);

impl<T> Deref for Extension<T> {
    type Target = T;

    /// Borrow the extension value held by this extractor.
    fn deref(&self) -> &T {
        let Extension(inner) = self;
        inner
    }
}
563
564impl<S, T> FromToolRequest<S> for Extension<T>
565where
566    T: Clone + Send + Sync + 'static,
567{
568    type Rejection = ExtensionRejection;
569
570    fn from_tool_request(
571        ctx: &RequestContext,
572        _state: &S,
573        _args: &Value,
574    ) -> std::result::Result<Self, Self::Rejection> {
575        ctx.extension::<T>()
576            .cloned()
577            .map(Extension)
578            .ok_or_else(ExtensionRejection::not_found::<T>)
579    }
580}
581
582// =============================================================================
583// Handler Trait
584// =============================================================================
585
/// A handler that uses extractors.
///
/// This trait is implemented for functions that take extractors as arguments.
/// You don't need to implement this trait directly; it's automatically
/// implemented for compatible async functions.
///
/// # Type Parameters
///
/// - `S` - The state type handed to every extractor.
/// - `T` - A tuple of the extractor types the handler accepts. It only
///   distinguishes the implementations for different argument lists.
pub trait ExtractorHandler<S, T>: Clone + Send + Sync + 'static {
    /// The future returned by the handler.
    type Future: Future<Output = Result<CallToolResult>> + Send;

    /// Call the handler with extracted values.
    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;

    /// Get the input schema for this handler.
    ///
    /// Always returns a schema value. The implementations in this module use
    /// the schema reported by an extractor's [`HasSchema`] implementation when
    /// one is available, and otherwise fall back to the permissive
    /// `{"type": "object", "additionalProperties": true}` schema.
    fn input_schema() -> Value;
}
603
604// Implementation for single extractor
605impl<S, F, Fut, T1> ExtractorHandler<S, (T1,)> for F
606where
607    S: Clone + Send + Sync + 'static,
608    F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
609    Fut: Future<Output = Result<CallToolResult>> + Send,
610    T1: FromToolRequest<S> + HasSchema + Send,
611{
612    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
613
614    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
615        Box::pin(async move {
616            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
617            self(t1).await
618        })
619    }
620
621    fn input_schema() -> Value {
622        if let Some(schema) = T1::schema() {
623            return schema;
624        }
625        serde_json::json!({
626            "type": "object",
627            "additionalProperties": true
628        })
629    }
630}
631
632// Implementation for two extractors
633impl<S, F, Fut, T1, T2> ExtractorHandler<S, (T1, T2)> for F
634where
635    S: Clone + Send + Sync + 'static,
636    F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
637    Fut: Future<Output = Result<CallToolResult>> + Send,
638    T1: FromToolRequest<S> + HasSchema + Send,
639    T2: FromToolRequest<S> + HasSchema + Send,
640{
641    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
642
643    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
644        Box::pin(async move {
645            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
646            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
647            self(t1, t2).await
648        })
649    }
650
651    fn input_schema() -> Value {
652        if let Some(schema) = T2::schema() {
653            return schema;
654        }
655        if let Some(schema) = T1::schema() {
656            return schema;
657        }
658        serde_json::json!({
659            "type": "object",
660            "additionalProperties": true
661        })
662    }
663}
664
665// Implementation for three extractors
666impl<S, F, Fut, T1, T2, T3> ExtractorHandler<S, (T1, T2, T3)> for F
667where
668    S: Clone + Send + Sync + 'static,
669    F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
670    Fut: Future<Output = Result<CallToolResult>> + Send,
671    T1: FromToolRequest<S> + HasSchema + Send,
672    T2: FromToolRequest<S> + HasSchema + Send,
673    T3: FromToolRequest<S> + HasSchema + Send,
674{
675    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
676
677    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
678        Box::pin(async move {
679            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
680            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
681            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
682            self(t1, t2, t3).await
683        })
684    }
685
686    fn input_schema() -> Value {
687        if let Some(schema) = T3::schema() {
688            return schema;
689        }
690        if let Some(schema) = T2::schema() {
691            return schema;
692        }
693        if let Some(schema) = T1::schema() {
694            return schema;
695        }
696        serde_json::json!({
697            "type": "object",
698            "additionalProperties": true
699        })
700    }
701}
702
703// Implementation for four extractors
704impl<S, F, Fut, T1, T2, T3, T4> ExtractorHandler<S, (T1, T2, T3, T4)> for F
705where
706    S: Clone + Send + Sync + 'static,
707    F: Fn(T1, T2, T3, T4) -> Fut + Clone + Send + Sync + 'static,
708    Fut: Future<Output = Result<CallToolResult>> + Send,
709    T1: FromToolRequest<S> + HasSchema + Send,
710    T2: FromToolRequest<S> + HasSchema + Send,
711    T3: FromToolRequest<S> + HasSchema + Send,
712    T4: FromToolRequest<S> + HasSchema + Send,
713{
714    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
715
716    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
717        Box::pin(async move {
718            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
719            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
720            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
721            let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
722            self(t1, t2, t3, t4).await
723        })
724    }
725
726    fn input_schema() -> Value {
727        if let Some(schema) = T4::schema() {
728            return schema;
729        }
730        if let Some(schema) = T3::schema() {
731            return schema;
732        }
733        if let Some(schema) = T2::schema() {
734            return schema;
735        }
736        if let Some(schema) = T1::schema() {
737            return schema;
738        }
739        serde_json::json!({
740            "type": "object",
741            "additionalProperties": true
742        })
743    }
744}
745
746// Implementation for five extractors
747impl<S, F, Fut, T1, T2, T3, T4, T5> ExtractorHandler<S, (T1, T2, T3, T4, T5)> for F
748where
749    S: Clone + Send + Sync + 'static,
750    F: Fn(T1, T2, T3, T4, T5) -> Fut + Clone + Send + Sync + 'static,
751    Fut: Future<Output = Result<CallToolResult>> + Send,
752    T1: FromToolRequest<S> + HasSchema + Send,
753    T2: FromToolRequest<S> + HasSchema + Send,
754    T3: FromToolRequest<S> + HasSchema + Send,
755    T4: FromToolRequest<S> + HasSchema + Send,
756    T5: FromToolRequest<S> + HasSchema + Send,
757{
758    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
759
760    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
761        Box::pin(async move {
762            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
763            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
764            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
765            let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
766            let t5 = T5::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
767            self(t1, t2, t3, t4, t5).await
768        })
769    }
770
771    fn input_schema() -> Value {
772        if let Some(schema) = T5::schema() {
773            return schema;
774        }
775        if let Some(schema) = T4::schema() {
776            return schema;
777        }
778        if let Some(schema) = T3::schema() {
779            return schema;
780        }
781        if let Some(schema) = T2::schema() {
782            return schema;
783        }
784        if let Some(schema) = T1::schema() {
785            return schema;
786        }
787        serde_json::json!({
788            "type": "object",
789            "additionalProperties": true
790        })
791    }
792}
793
794// =============================================================================
795// Schema Extraction Helper
796// =============================================================================
797
/// Helper trait to get schema from `Json<T>` extractor
///
/// Every extractor used with [`ExtractorHandler`] implements this trait so
/// the handler can ask each parameter type for an input schema.
pub trait HasSchema {
    /// Return the JSON schema this extractor contributes to the tool's input,
    /// or `None` when it contributes none.
    fn schema() -> Option<Value>;
}
802
803impl<T: JsonSchema> HasSchema for Json<T> {
804    fn schema() -> Option<Value> {
805        let schema = schemars::schema_for!(T);
806        serde_json::to_value(schema).ok()
807    }
808}
809
// Default impl for non-Json extractors: they report no input schema.
impl HasSchema for Context {
    /// `Context` contributes no input schema.
    fn schema() -> Option<Value> {
        None
    }
}
816
impl HasSchema for RawArgs {
    /// `RawArgs` contributes no input schema.
    fn schema() -> Option<Value> {
        None
    }
}
822
impl<T> HasSchema for State<T> {
    /// `State<T>` contributes no input schema.
    fn schema() -> Option<Value> {
        None
    }
}
828
impl<T> HasSchema for Extension<T> {
    /// `Extension<T>` contributes no input schema.
    fn schema() -> Option<Value> {
        None
    }
}
834
835// =============================================================================
836// Typed Extractor Handler
837// =============================================================================
838
/// A handler that uses extractors with typed JSON input.
///
/// This trait is similar to [`ExtractorHandler`] but provides proper JSON
/// schema generation for the input type when `Json<T>` is used.
///
/// # Type Parameters
///
/// - `S` - The state type handed to every extractor.
/// - `T` - A tuple of the extractor types the handler accepts.
/// - `I` - The JSON input type. The implementations in this module set it to
///   the type wrapped by the handler's final `Json<_>` parameter.
pub trait TypedExtractorHandler<S, T, I>: Clone + Send + Sync + 'static
where
    I: JsonSchema,
{
    /// The future returned by the handler.
    type Future: Future<Output = Result<CallToolResult>> + Send;

    /// Call the handler with extracted values.
    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
}
853
854// Single extractor with Json<T>
855impl<S, F, Fut, T> TypedExtractorHandler<S, (Json<T>,), T> for F
856where
857    S: Clone + Send + Sync + 'static,
858    F: Fn(Json<T>) -> Fut + Clone + Send + Sync + 'static,
859    Fut: Future<Output = Result<CallToolResult>> + Send,
860    T: DeserializeOwned + JsonSchema + Send,
861{
862    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
863
864    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
865        Box::pin(async move {
866            let t1 =
867                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
868            self(t1).await
869        })
870    }
871}
872
873// Two extractors ending with Json<T>
874impl<S, F, Fut, T1, T> TypedExtractorHandler<S, (T1, Json<T>), T> for F
875where
876    S: Clone + Send + Sync + 'static,
877    F: Fn(T1, Json<T>) -> Fut + Clone + Send + Sync + 'static,
878    Fut: Future<Output = Result<CallToolResult>> + Send,
879    T1: FromToolRequest<S> + Send,
880    T: DeserializeOwned + JsonSchema + Send,
881{
882    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
883
884    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
885        Box::pin(async move {
886            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
887            let t2 =
888                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
889            self(t1, t2).await
890        })
891    }
892}
893
894// Three extractors ending with Json<T>
895impl<S, F, Fut, T1, T2, T> TypedExtractorHandler<S, (T1, T2, Json<T>), T> for F
896where
897    S: Clone + Send + Sync + 'static,
898    F: Fn(T1, T2, Json<T>) -> Fut + Clone + Send + Sync + 'static,
899    Fut: Future<Output = Result<CallToolResult>> + Send,
900    T1: FromToolRequest<S> + Send,
901    T2: FromToolRequest<S> + Send,
902    T: DeserializeOwned + JsonSchema + Send,
903{
904    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
905
906    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
907        Box::pin(async move {
908            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
909            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
910            let t3 =
911                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
912            self(t1, t2, t3).await
913        })
914    }
915}
916
917// Four extractors ending with Json<T>
918impl<S, F, Fut, T1, T2, T3, T> TypedExtractorHandler<S, (T1, T2, T3, Json<T>), T> for F
919where
920    S: Clone + Send + Sync + 'static,
921    F: Fn(T1, T2, T3, Json<T>) -> Fut + Clone + Send + Sync + 'static,
922    Fut: Future<Output = Result<CallToolResult>> + Send,
923    T1: FromToolRequest<S> + Send,
924    T2: FromToolRequest<S> + Send,
925    T3: FromToolRequest<S> + Send,
926    T: DeserializeOwned + JsonSchema + Send,
927{
928    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
929
930    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
931        Box::pin(async move {
932            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
933            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
934            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
935            let t4 =
936                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
937            self(t1, t2, t3, t4).await
938        })
939    }
940}
941
942// =============================================================================
943// ToolBuilder Extensions
944// =============================================================================
945
946use crate::tool::{
947    BoxFuture, GuardLayer, Tool, ToolCatchError, ToolHandler, ToolHandlerService, ToolRequest,
948};
949use tower::util::BoxCloneService;
950use tower_service::Service;
951
/// Internal handler wrapper for extractor-based handlers
///
/// Bundles a handler function with the state and input schema it was built
/// with so the pair can be driven through the [`ToolHandler`] trait.
pub(crate) struct ExtractorToolHandler<S, F, T> {
    /// State value; cloned and passed to the handler on every call.
    state: S,
    /// The user's handler; cloned for every call.
    handler: F,
    /// Schema returned from `ToolHandler::input_schema`.
    input_schema: Value,
    /// Ties the wrapper to the extractor tuple type `T` without storing a value.
    _phantom: PhantomData<T>,
}
959
960impl<S, F, T> ToolHandler for ExtractorToolHandler<S, F, T>
961where
962    S: Clone + Send + Sync + 'static,
963    F: ExtractorHandler<S, T> + Clone,
964    T: Send + Sync + 'static,
965{
966    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
967        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
968        self.call_with_context(ctx, args)
969    }
970
971    fn call_with_context(
972        &self,
973        ctx: RequestContext,
974        args: Value,
975    ) -> BoxFuture<'_, Result<CallToolResult>> {
976        let state = self.state.clone();
977        let handler = self.handler.clone();
978        Box::pin(async move { handler.call(ctx, state, args).await })
979    }
980
981    fn uses_context(&self) -> bool {
982        true
983    }
984
985    fn input_schema(&self) -> Value {
986        self.input_schema.clone()
987    }
988}
989
990/// Builder state for extractor-based handlers
991pub struct ToolBuilderWithExtractor<S, F, T> {
992    pub(crate) name: String,
993    pub(crate) title: Option<String>,
994    pub(crate) description: Option<String>,
995    pub(crate) output_schema: Option<Value>,
996    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
997    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
998    pub(crate) state: S,
999    pub(crate) handler: F,
1000    pub(crate) input_schema: Value,
1001    pub(crate) _phantom: PhantomData<T>,
1002}
1003
1004impl<S, F, T> ToolBuilderWithExtractor<S, F, T>
1005where
1006    S: Clone + Send + Sync + 'static,
1007    F: ExtractorHandler<S, T> + Clone,
1008    T: Send + Sync + 'static,
1009{
1010    /// Build the tool.
1011    pub fn build(self) -> Tool {
1012        let handler = ExtractorToolHandler {
1013            state: self.state,
1014            handler: self.handler,
1015            input_schema: self.input_schema.clone(),
1016            _phantom: PhantomData,
1017        };
1018
1019        let handler_service = ToolHandlerService::new(handler);
1020        let catch_error = ToolCatchError::new(handler_service);
1021        let service = BoxCloneService::new(catch_error);
1022
1023        Tool {
1024            name: self.name,
1025            title: self.title,
1026            description: self.description,
1027            output_schema: self.output_schema,
1028            icons: self.icons,
1029            annotations: self.annotations,
1030            service,
1031            input_schema: self.input_schema,
1032        }
1033    }
1034
1035    /// Apply a Tower layer (middleware) to this tool.
1036    ///
1037    /// The layer wraps the tool's handler service, enabling functionality like
1038    /// timeouts, rate limiting, and metrics collection at the per-tool level.
1039    ///
1040    /// # Example
1041    ///
1042    /// ```rust
1043    /// use std::sync::Arc;
1044    /// use std::time::Duration;
1045    /// use tower::timeout::TimeoutLayer;
1046    /// use tower_mcp::{ToolBuilder, CallToolResult};
1047    /// use tower_mcp::extract::{Json, State};
1048    /// use schemars::JsonSchema;
1049    /// use serde::Deserialize;
1050    ///
1051    /// #[derive(Clone)]
1052    /// struct AppState { prefix: String }
1053    ///
1054    /// #[derive(Debug, Deserialize, JsonSchema)]
1055    /// struct QueryInput { query: String }
1056    ///
1057    /// let state = Arc::new(AppState { prefix: "db".to_string() });
1058    ///
1059    /// let tool = ToolBuilder::new("search")
1060    ///     .description("Search with timeout")
1061    ///     .extractor_handler(state, |
1062    ///         State(app): State<Arc<AppState>>,
1063    ///         Json(input): Json<QueryInput>,
1064    ///     | async move {
1065    ///         Ok(CallToolResult::text(format!("{}: {}", app.prefix, input.query)))
1066    ///     })
1067    ///     .layer(TimeoutLayer::new(Duration::from_secs(30)))
1068    ///     .build();
1069    /// ```
1070    pub fn layer<L>(self, layer: L) -> ToolBuilderWithExtractorLayer<S, F, T, L> {
1071        ToolBuilderWithExtractorLayer {
1072            name: self.name,
1073            title: self.title,
1074            description: self.description,
1075            output_schema: self.output_schema,
1076            icons: self.icons,
1077            annotations: self.annotations,
1078            state: self.state,
1079            handler: self.handler,
1080            input_schema: self.input_schema,
1081            layer,
1082            _phantom: PhantomData,
1083        }
1084    }
1085
1086    /// Apply a guard to this tool.
1087    ///
1088    /// See [`ToolBuilderWithHandler::guard`](crate::ToolBuilder) for details.
1089    pub fn guard<G>(self, guard: G) -> ToolBuilderWithExtractorLayer<S, F, T, GuardLayer<G>>
1090    where
1091        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1092    {
1093        self.layer(GuardLayer::new(guard))
1094    }
1095}
1096
1097/// Builder state after a layer has been applied to an extractor handler.
1098///
1099/// This builder allows chaining additional layers and building the final tool.
1100pub struct ToolBuilderWithExtractorLayer<S, F, T, L> {
1101    name: String,
1102    title: Option<String>,
1103    description: Option<String>,
1104    output_schema: Option<Value>,
1105    icons: Option<Vec<crate::protocol::ToolIcon>>,
1106    annotations: Option<crate::protocol::ToolAnnotations>,
1107    state: S,
1108    handler: F,
1109    input_schema: Value,
1110    layer: L,
1111    _phantom: PhantomData<T>,
1112}
1113
1114#[allow(private_bounds)]
1115impl<S, F, T, L> ToolBuilderWithExtractorLayer<S, F, T, L>
1116where
1117    S: Clone + Send + Sync + 'static,
1118    F: ExtractorHandler<S, T> + Clone,
1119    T: Send + Sync + 'static,
1120    L: tower::Layer<ToolHandlerService<ExtractorToolHandler<S, F, T>>>
1121        + Clone
1122        + Send
1123        + Sync
1124        + 'static,
1125    L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1126    <L::Service as Service<ToolRequest>>::Error: std::fmt::Display + Send,
1127    <L::Service as Service<ToolRequest>>::Future: Send,
1128{
1129    /// Build the tool with the applied layer(s).
1130    pub fn build(self) -> Tool {
1131        let handler = ExtractorToolHandler {
1132            state: self.state,
1133            handler: self.handler,
1134            input_schema: self.input_schema.clone(),
1135            _phantom: PhantomData,
1136        };
1137
1138        let handler_service = ToolHandlerService::new(handler);
1139        let layered = self.layer.layer(handler_service);
1140        let catch_error = ToolCatchError::new(layered);
1141        let service = BoxCloneService::new(catch_error);
1142
1143        Tool {
1144            name: self.name,
1145            title: self.title,
1146            description: self.description,
1147            output_schema: self.output_schema,
1148            icons: self.icons,
1149            annotations: self.annotations,
1150            service,
1151            input_schema: self.input_schema,
1152        }
1153    }
1154
1155    /// Apply an additional Tower layer (middleware).
1156    ///
1157    /// Layers are applied in order, with earlier layers wrapping later ones.
1158    /// This means the first layer added is the outermost middleware.
1159    pub fn layer<L2>(
1160        self,
1161        layer: L2,
1162    ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<L2, L>> {
1163        ToolBuilderWithExtractorLayer {
1164            name: self.name,
1165            title: self.title,
1166            description: self.description,
1167            output_schema: self.output_schema,
1168            icons: self.icons,
1169            annotations: self.annotations,
1170            state: self.state,
1171            handler: self.handler,
1172            input_schema: self.input_schema,
1173            layer: tower::layer::util::Stack::new(layer, self.layer),
1174            _phantom: PhantomData,
1175        }
1176    }
1177
1178    /// Apply a guard to this tool.
1179    ///
1180    /// See [`ToolBuilderWithHandler::guard`](crate::ToolBuilder) for details.
1181    pub fn guard<G>(
1182        self,
1183        guard: G,
1184    ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<GuardLayer<G>, L>>
1185    where
1186        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1187    {
1188        self.layer(GuardLayer::new(guard))
1189    }
1190}
1191
1192/// Builder state for extractor-based handlers with typed JSON input
1193pub struct ToolBuilderWithTypedExtractor<S, F, T, I> {
1194    pub(crate) name: String,
1195    pub(crate) title: Option<String>,
1196    pub(crate) description: Option<String>,
1197    pub(crate) output_schema: Option<Value>,
1198    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
1199    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
1200    pub(crate) state: S,
1201    pub(crate) handler: F,
1202    pub(crate) _phantom: PhantomData<(T, I)>,
1203}
1204
1205impl<S, F, T, I> ToolBuilderWithTypedExtractor<S, F, T, I>
1206where
1207    S: Clone + Send + Sync + 'static,
1208    F: TypedExtractorHandler<S, T, I> + Clone,
1209    T: Send + Sync + 'static,
1210    I: JsonSchema + Send + Sync + 'static,
1211{
1212    /// Build the tool.
1213    pub fn build(self) -> Tool {
1214        let input_schema = {
1215            let schema = schemars::schema_for!(I);
1216            serde_json::to_value(schema).unwrap_or_else(|_| {
1217                serde_json::json!({
1218                    "type": "object"
1219                })
1220            })
1221        };
1222
1223        let handler = TypedExtractorToolHandler {
1224            state: self.state,
1225            handler: self.handler,
1226            input_schema: input_schema.clone(),
1227            _phantom: PhantomData,
1228        };
1229
1230        let handler_service = crate::tool::ToolHandlerService::new(handler);
1231        let catch_error = ToolCatchError::new(handler_service);
1232        let service = BoxCloneService::new(catch_error);
1233
1234        Tool {
1235            name: self.name,
1236            title: self.title,
1237            description: self.description,
1238            output_schema: self.output_schema,
1239            icons: self.icons,
1240            annotations: self.annotations,
1241            service,
1242            input_schema,
1243        }
1244    }
1245}
1246
1247/// Internal handler wrapper for typed extractor-based handlers
1248struct TypedExtractorToolHandler<S, F, T, I> {
1249    state: S,
1250    handler: F,
1251    input_schema: Value,
1252    _phantom: PhantomData<(T, I)>,
1253}
1254
1255impl<S, F, T, I> ToolHandler for TypedExtractorToolHandler<S, F, T, I>
1256where
1257    S: Clone + Send + Sync + 'static,
1258    F: TypedExtractorHandler<S, T, I> + Clone,
1259    T: Send + Sync + 'static,
1260    I: JsonSchema + Send + Sync + 'static,
1261{
1262    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1263        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
1264        self.call_with_context(ctx, args)
1265    }
1266
1267    fn call_with_context(
1268        &self,
1269        ctx: RequestContext,
1270        args: Value,
1271    ) -> BoxFuture<'_, Result<CallToolResult>> {
1272        let state = self.state.clone();
1273        let handler = self.handler.clone();
1274        Box::pin(async move { handler.call(ctx, state, args).await })
1275    }
1276
1277    fn uses_context(&self) -> bool {
1278        true
1279    }
1280
1281    fn input_schema(&self) -> Value {
1282        self.input_schema.clone()
1283    }
1284}
1285
1286#[cfg(test)]
1287mod tests {
1288    use super::*;
1289    use crate::protocol::RequestId;
1290    use schemars::JsonSchema;
1291    use serde::Deserialize;
1292    use std::sync::Arc;
1293
1294    #[derive(Debug, Deserialize, JsonSchema)]
1295    struct TestInput {
1296        name: String,
1297        count: i32,
1298    }
1299
1300    #[test]
1301    fn test_json_extraction() {
1302        let args = serde_json::json!({"name": "test", "count": 42});
1303        let ctx = RequestContext::new(RequestId::Number(1));
1304
1305        let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
1306        assert!(result.is_ok());
1307        let Json(input) = result.unwrap();
1308        assert_eq!(input.name, "test");
1309        assert_eq!(input.count, 42);
1310    }
1311
1312    #[test]
1313    fn test_json_extraction_error() {
1314        let args = serde_json::json!({"name": "test"}); // missing count
1315        let ctx = RequestContext::new(RequestId::Number(1));
1316
1317        let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
1318        assert!(result.is_err());
1319        let rejection = result.unwrap_err();
1320        // JsonRejection contains the serde error message
1321        assert!(rejection.message().contains("count"));
1322    }
1323
1324    #[test]
1325    fn test_state_extraction() {
1326        let args = serde_json::json!({});
1327        let ctx = RequestContext::new(RequestId::Number(1));
1328        let state = Arc::new("my-state".to_string());
1329
1330        let result = State::<Arc<String>>::from_tool_request(&ctx, &state, &args);
1331        assert!(result.is_ok());
1332        let State(extracted) = result.unwrap();
1333        assert_eq!(*extracted, "my-state");
1334    }
1335
1336    #[test]
1337    fn test_context_extraction() {
1338        let args = serde_json::json!({});
1339        let ctx = RequestContext::new(RequestId::Number(42));
1340
1341        let result = Context::from_tool_request(&ctx, &(), &args);
1342        assert!(result.is_ok());
1343        let extracted = result.unwrap();
1344        assert_eq!(*extracted.request_id(), RequestId::Number(42));
1345    }
1346
1347    #[test]
1348    fn test_raw_args_extraction() {
1349        let args = serde_json::json!({"foo": "bar", "baz": 123});
1350        let ctx = RequestContext::new(RequestId::Number(1));
1351
1352        let result = RawArgs::from_tool_request(&ctx, &(), &args);
1353        assert!(result.is_ok());
1354        let RawArgs(extracted) = result.unwrap();
1355        assert_eq!(extracted["foo"], "bar");
1356        assert_eq!(extracted["baz"], 123);
1357    }
1358
1359    #[test]
1360    fn test_extension_extraction() {
1361        use crate::context::Extensions;
1362
1363        #[derive(Clone, Debug, PartialEq)]
1364        struct DatabasePool {
1365            url: String,
1366        }
1367
1368        let args = serde_json::json!({});
1369
1370        // Create extensions with a value
1371        let mut extensions = Extensions::new();
1372        extensions.insert(Arc::new(DatabasePool {
1373            url: "postgres://localhost".to_string(),
1374        }));
1375
1376        // Create context with extensions
1377        let ctx = RequestContext::new(RequestId::Number(1)).with_extensions(Arc::new(extensions));
1378
1379        // Extract the extension
1380        let result = Extension::<Arc<DatabasePool>>::from_tool_request(&ctx, &(), &args);
1381        assert!(result.is_ok());
1382        let Extension(pool) = result.unwrap();
1383        assert_eq!(pool.url, "postgres://localhost");
1384    }
1385
1386    #[test]
1387    fn test_extension_extraction_missing() {
1388        #[derive(Clone, Debug)]
1389        struct NotPresent;
1390
1391        let args = serde_json::json!({});
1392        let ctx = RequestContext::new(RequestId::Number(1));
1393
1394        // Try to extract something that's not in extensions
1395        let result = Extension::<NotPresent>::from_tool_request(&ctx, &(), &args);
1396        assert!(result.is_err());
1397        let rejection = result.unwrap_err();
1398        // ExtensionRejection contains the type name
1399        assert!(rejection.type_name().contains("NotPresent"));
1400    }
1401
1402    #[tokio::test]
1403    async fn test_single_extractor_handler() {
1404        let handler = |Json(input): Json<TestInput>| async move {
1405            Ok(CallToolResult::text(format!(
1406                "{}: {}",
1407                input.name, input.count
1408            )))
1409        };
1410
1411        let ctx = RequestContext::new(RequestId::Number(1));
1412        let args = serde_json::json!({"name": "test", "count": 5});
1413
1414        // Use explicit trait to avoid ambiguity
1415        let result: Result<CallToolResult> =
1416            ExtractorHandler::<(), (Json<TestInput>,)>::call(handler, ctx, (), args).await;
1417        assert!(result.is_ok());
1418    }
1419
1420    #[tokio::test]
1421    async fn test_two_extractor_handler() {
1422        let handler = |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1423            Ok(CallToolResult::text(format!(
1424                "{}: {} - {}",
1425                state, input.name, input.count
1426            )))
1427        };
1428
1429        let ctx = RequestContext::new(RequestId::Number(1));
1430        let state = Arc::new("prefix".to_string());
1431        let args = serde_json::json!({"name": "test", "count": 5});
1432
1433        // Use explicit trait to avoid ambiguity
1434        let result: Result<CallToolResult> = ExtractorHandler::<
1435            Arc<String>,
1436            (State<Arc<String>>, Json<TestInput>),
1437        >::call(handler, ctx, state, args)
1438        .await;
1439        assert!(result.is_ok());
1440    }
1441
1442    #[tokio::test]
1443    async fn test_three_extractor_handler() {
1444        let handler = |State(state): State<Arc<String>>,
1445                       ctx: Context,
1446                       Json(input): Json<TestInput>| async move {
1447            // Verify we can access all extractors
1448            assert!(!ctx.is_cancelled());
1449            Ok(CallToolResult::text(format!(
1450                "{}: {} - {}",
1451                state, input.name, input.count
1452            )))
1453        };
1454
1455        let ctx = RequestContext::new(RequestId::Number(1));
1456        let state = Arc::new("prefix".to_string());
1457        let args = serde_json::json!({"name": "test", "count": 5});
1458
1459        // Use explicit trait to avoid ambiguity
1460        let result: Result<CallToolResult> = ExtractorHandler::<
1461            Arc<String>,
1462            (State<Arc<String>>, Context, Json<TestInput>),
1463        >::call(handler, ctx, state, args)
1464        .await;
1465        assert!(result.is_ok());
1466    }
1467
1468    #[test]
1469    fn test_json_schema_generation() {
1470        let schema = Json::<TestInput>::schema();
1471        assert!(schema.is_some());
1472        let schema = schema.unwrap();
1473        assert!(schema.get("properties").is_some());
1474    }
1475
1476    #[test]
1477    fn test_rejection_into_error() {
1478        let rejection = Rejection::new("test error");
1479        let error: Error = rejection.into();
1480        assert!(error.to_string().contains("test error"));
1481    }
1482
1483    #[test]
1484    fn test_json_rejection() {
1485        // Test basic JsonRejection
1486        let rejection = JsonRejection::new("missing field `name`");
1487        assert_eq!(rejection.message(), "missing field `name`");
1488        assert!(rejection.path().is_none());
1489        assert!(rejection.to_string().contains("Invalid input"));
1490
1491        // Test JsonRejection with path
1492        let rejection = JsonRejection::with_path("expected string", "users[0].name");
1493        assert_eq!(rejection.message(), "expected string");
1494        assert_eq!(rejection.path(), Some("users[0].name"));
1495        assert!(rejection.to_string().contains("users[0].name"));
1496
1497        // Test conversion to Error
1498        let error: Error = rejection.into();
1499        assert!(error.to_string().contains("users[0].name"));
1500    }
1501
1502    #[test]
1503    fn test_json_rejection_from_serde_error() {
1504        // Create a real serde error by deserializing invalid JSON
1505        #[derive(Debug, serde::Deserialize)]
1506        struct TestStruct {
1507            #[allow(dead_code)]
1508            name: String,
1509        }
1510
1511        let result: std::result::Result<TestStruct, _> =
1512            serde_json::from_value(serde_json::json!({"count": 42}));
1513        assert!(result.is_err());
1514
1515        let rejection: JsonRejection = result.unwrap_err().into();
1516        assert!(rejection.message().contains("name"));
1517    }
1518
1519    #[test]
1520    fn test_extension_rejection() {
1521        // Test ExtensionRejection
1522        let rejection = ExtensionRejection::not_found::<String>();
1523        assert!(rejection.type_name().contains("String"));
1524        assert!(rejection.to_string().contains("not found"));
1525        assert!(rejection.to_string().contains("with_state"));
1526
1527        // Test conversion to Error
1528        let error: Error = rejection.into();
1529        assert!(error.to_string().contains("not found"));
1530    }
1531
1532    #[tokio::test]
1533    async fn test_tool_builder_extractor_handler() {
1534        use crate::ToolBuilder;
1535
1536        let state = Arc::new("shared-state".to_string());
1537
1538        let tool =
1539            ToolBuilder::new("test_extractor")
1540                .description("Test extractor handler")
1541                .extractor_handler(
1542                    state,
1543                    |State(state): State<Arc<String>>,
1544                     ctx: Context,
1545                     Json(input): Json<TestInput>| async move {
1546                        assert!(!ctx.is_cancelled());
1547                        Ok(CallToolResult::text(format!(
1548                            "{}: {} - {}",
1549                            state, input.name, input.count
1550                        )))
1551                    },
1552                )
1553                .build();
1554
1555        assert_eq!(tool.name, "test_extractor");
1556        assert_eq!(tool.description.as_deref(), Some("Test extractor handler"));
1557
1558        // Test calling the tool
1559        let result = tool
1560            .call(serde_json::json!({"name": "test", "count": 42}))
1561            .await;
1562        assert!(!result.is_error);
1563    }
1564
1565    #[tokio::test]
1566    async fn test_tool_builder_extractor_handler_typed() {
1567        use crate::ToolBuilder;
1568
1569        let state = Arc::new("typed-state".to_string());
1570
1571        let tool = ToolBuilder::new("test_typed")
1572            .description("Test typed extractor handler")
1573            .extractor_handler_typed::<_, _, _, TestInput>(
1574                state,
1575                |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1576                    Ok(CallToolResult::text(format!(
1577                        "{}: {} - {}",
1578                        state, input.name, input.count
1579                    )))
1580                },
1581            )
1582            .build();
1583
1584        assert_eq!(tool.name, "test_typed");
1585
1586        // Verify schema is properly generated from TestInput
1587        let def = tool.definition();
1588        let schema = def.input_schema;
1589        assert!(schema.get("properties").is_some());
1590
1591        // Test calling the tool
1592        let result = tool
1593            .call(serde_json::json!({"name": "world", "count": 99}))
1594            .await;
1595        assert!(!result.is_error);
1596    }
1597
1598    #[tokio::test]
1599    async fn test_extractor_handler_auto_schema() {
1600        use crate::ToolBuilder;
1601
1602        let state = Arc::new("auto-schema".to_string());
1603
1604        // extractor_handler (not _typed) should auto-detect Json<TestInput> schema
1605        let tool = ToolBuilder::new("test_auto_schema")
1606            .description("Test auto schema detection")
1607            .extractor_handler(
1608                state,
1609                |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1610                    Ok(CallToolResult::text(format!(
1611                        "{}: {} - {}",
1612                        state, input.name, input.count
1613                    )))
1614                },
1615            )
1616            .build();
1617
1618        // Verify schema is properly generated from TestInput (not generic object)
1619        let def = tool.definition();
1620        let schema = def.input_schema;
1621        assert!(
1622            schema.get("properties").is_some(),
1623            "Schema should have properties from TestInput, got: {}",
1624            schema
1625        );
1626        let props = schema.get("properties").unwrap();
1627        assert!(
1628            props.get("name").is_some(),
1629            "Schema should have 'name' property"
1630        );
1631        assert!(
1632            props.get("count").is_some(),
1633            "Schema should have 'count' property"
1634        );
1635
1636        // Test calling the tool
1637        let result = tool
1638            .call(serde_json::json!({"name": "world", "count": 99}))
1639            .await;
1640        assert!(!result.is_error);
1641    }
1642
1643    #[test]
1644    fn test_extractor_handler_no_json_fallback() {
1645        use crate::ToolBuilder;
1646
1647        // extractor_handler without Json<T> should fall back to generic schema
1648        let tool = ToolBuilder::new("test_no_json")
1649            .description("Test no json fallback")
1650            .extractor_handler((), |RawArgs(args): RawArgs| async move {
1651                Ok(CallToolResult::json(args))
1652            })
1653            .build();
1654
1655        let def = tool.definition();
1656        let schema = def.input_schema;
1657        assert_eq!(
1658            schema.get("type").and_then(|v| v.as_str()),
1659            Some("object"),
1660            "Schema should be generic object"
1661        );
1662        assert_eq!(
1663            schema.get("additionalProperties").and_then(|v| v.as_bool()),
1664            Some(true),
1665            "Schema should allow additional properties"
1666        );
1667        // Should NOT have specific properties
1668        assert!(
1669            schema.get("properties").is_none(),
1670            "Generic schema should not have specific properties"
1671        );
1672    }
1673
1674    #[tokio::test]
1675    async fn test_extractor_handler_with_layer() {
1676        use crate::ToolBuilder;
1677        use std::time::Duration;
1678        use tower::timeout::TimeoutLayer;
1679
1680        let state = Arc::new("layered".to_string());
1681
1682        let tool = ToolBuilder::new("test_extractor_layer")
1683            .description("Test extractor handler with layer")
1684            .extractor_handler(
1685                state,
1686                |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1687                    Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
1688                },
1689            )
1690            .layer(TimeoutLayer::new(Duration::from_secs(5)))
1691            .build();
1692
1693        // Verify the tool works
1694        let result = tool
1695            .call(serde_json::json!({"name": "test", "count": 1}))
1696            .await;
1697        assert!(!result.is_error);
1698        assert_eq!(result.first_text().unwrap(), "layered: test");
1699
1700        // Verify schema is still properly generated
1701        let def = tool.definition();
1702        let schema = def.input_schema;
1703        assert!(
1704            schema.get("properties").is_some(),
1705            "Schema should have properties even with layer"
1706        );
1707    }
1708
1709    #[tokio::test]
1710    async fn test_extractor_handler_with_timeout_layer() {
1711        use crate::ToolBuilder;
1712        use std::time::Duration;
1713        use tower::timeout::TimeoutLayer;
1714
1715        let tool = ToolBuilder::new("test_extractor_timeout")
1716            .description("Test extractor handler timeout")
1717            .extractor_handler((), |Json(input): Json<TestInput>| async move {
1718                tokio::time::sleep(Duration::from_millis(200)).await;
1719                Ok(CallToolResult::text(input.name.to_string()))
1720            })
1721            .layer(TimeoutLayer::new(Duration::from_millis(50)))
1722            .build();
1723
1724        // Should timeout
1725        let result = tool
1726            .call(serde_json::json!({"name": "slow", "count": 1}))
1727            .await;
1728        assert!(result.is_error);
1729        let msg = result.first_text().unwrap().to_lowercase();
1730        assert!(
1731            msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
1732            "Expected timeout error, got: {}",
1733            msg
1734        );
1735    }
1736
1737    #[tokio::test]
1738    async fn test_extractor_handler_with_multiple_layers() {
1739        use crate::ToolBuilder;
1740        use std::time::Duration;
1741        use tower::limit::ConcurrencyLimitLayer;
1742        use tower::timeout::TimeoutLayer;
1743
1744        let state = Arc::new("multi".to_string());
1745
1746        let tool = ToolBuilder::new("test_multi_layer")
1747            .description("Test multiple layers")
1748            .extractor_handler(
1749                state,
1750                |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
1751                    Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
1752                },
1753            )
1754            .layer(TimeoutLayer::new(Duration::from_secs(5)))
1755            .layer(ConcurrencyLimitLayer::new(10))
1756            .build();
1757
1758        let result = tool
1759            .call(serde_json::json!({"name": "test", "count": 1}))
1760            .await;
1761        assert!(!result.is_error);
1762        assert_eq!(result.first_text().unwrap(), "multi: test");
1763    }
1764}