Skip to main content

tower_mcp/
extract.rs

1//! Extractor pattern for tool handlers
2//!
3//! This module provides an axum-inspired extractor pattern that makes state and context
4//! injection more declarative, reducing the combinatorial explosion of handler variants.
5//!
6//! # Overview
7//!
8//! Extractors implement [`FromToolRequest`], which extracts data from the tool request
9//! (context, state, and arguments). Multiple extractors can be combined in handler
10//! function parameters.
11//!
12//! # Built-in Extractors
13//!
14//! - [`Json<T>`] - Extract typed input from args (deserializes JSON)
15//! - [`State<T>`] - Extract shared state from per-tool state (cloned for each request)
16//! - [`Extension<T>`] - Extract data from router extensions (via `router.with_state()`)
17//! - [`Context`] - Extract the [`RequestContext`] for progress, cancellation, etc.
18//! - [`RawArgs`] - Extract raw `serde_json::Value` arguments
19//!
20//! ## State vs Extension
21//!
22//! - Use **`State<T>`** when state is passed directly to `extractor_handler()` (per-tool state)
23//! - Use **`Extension<T>`** when state is set via `McpRouter::with_state()` (router-level state)
24//!
25//! # Example
26//!
27//! ```rust
28//! use std::sync::Arc;
29//! use tower_mcp::{ToolBuilder, CallToolResult};
30//! use tower_mcp::extract::{Json, State, Context};
31//! use schemars::JsonSchema;
32//! use serde::Deserialize;
33//!
34//! #[derive(Clone)]
35//! struct AppState {
36//!     db_url: String,
37//! }
38//!
39//! #[derive(Debug, Deserialize, JsonSchema)]
40//! struct QueryInput {
41//!     query: String,
42//! }
43//!
44//! let state = Arc::new(AppState { db_url: "postgres://...".to_string() });
45//!
46//! let tool = ToolBuilder::new("search")
47//!     .description("Search the database")
48//!     .extractor_handler(state, |
49//!         State(db): State<Arc<AppState>>,
50//!         ctx: Context,
51//!         Json(input): Json<QueryInput>,
52//!     | async move {
53//!         // Check cancellation
54//!         if ctx.is_cancelled() {
55//!             return Ok(CallToolResult::error("Cancelled"));
56//!         }
57//!         // Report progress
58//!         ctx.report_progress(0.5, Some(1.0), Some("Searching...")).await;
59//!         // Use state
60//!         Ok(CallToolResult::text(format!("Searched {} with query: {}", db.db_url, input.query)))
61//!     })
62//!     .build();
63//! ```
64//!
65//! # Extractor Order
66//!
67//! The order of extractors in the function signature doesn't matter. Each extractor
68//! independently extracts its data from the request.
69//!
70//! # Error Handling
71//!
72//! If an extractor fails (e.g., JSON deserialization fails), the handler returns
73//! a `CallToolResult::error()` with the rejection message.
74
75use std::future::Future;
76use std::marker::PhantomData;
77use std::ops::Deref;
78use std::pin::Pin;
79
80use schemars::JsonSchema;
81use serde::de::DeserializeOwned;
82use serde_json::Value;
83
84use crate::context::RequestContext;
85use crate::error::{Error, Result};
86use crate::protocol::CallToolResult;
87
88// =============================================================================
89// Rejection Types
90// =============================================================================
91
/// A free-form rejection that carries nothing but a human-readable message.
///
/// Intended for custom extractors that have no richer error data to report.
/// When more detail is available, prefer one of the typed rejections such as
/// [`JsonRejection`] or [`ExtensionRejection`].
#[derive(Debug, Clone)]
pub struct Rejection {
    message: String,
}

impl Rejection {
    /// Builds a rejection from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }

    /// Returns the message this rejection was created with.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl std::fmt::Display for Rejection {
    // The message is emitted verbatim; formatter flags (width, fill) are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for Rejection {}
123
124impl From<Rejection> for Error {
125    fn from(rejection: Rejection) -> Self {
126        Error::tool(rejection.message)
127    }
128}
129
/// Rejection returned when JSON deserialization fails.
///
/// Carries the error message and, optionally, the path of the field that failed
/// (for example `users[0].name`). A path is present only when the rejection is
/// built with [`JsonRejection::with_path`]. Converting from a `serde_json::Error`
/// leaves it unset.
///
/// # Example
///
/// ```rust
/// use tower_mcp::extract::JsonRejection;
///
/// let rejection = JsonRejection::new("missing field `name`");
/// assert!(rejection.message().contains("name"));
/// assert_eq!(rejection.path(), None);
///
/// let nested = JsonRejection::with_path("invalid type", "users[0].name");
/// assert_eq!(nested.to_string(), "Invalid input at `users[0].name`: invalid type");
/// ```
#[derive(Debug, Clone)]
pub struct JsonRejection {
    message: String,
    /// The path to the failing field, if known (e.g., "users[0].name")
    path: Option<String>,
}

impl JsonRejection {
    /// Create a JSON rejection from an error message, with no field path.
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
            path: None,
        }
    }

    /// Create a JSON rejection with a message and the path to the failing field.
    pub fn with_path(message: impl Into<String>, path: impl Into<String>) -> Self {
        Self {
            message: message.into(),
            path: Some(path.into()),
        }
    }

    /// Get the error message (without the "Invalid input" prefix added by `Display`).
    pub fn message(&self) -> &str {
        &self.message
    }

    /// Get the path to the failing field, if one was recorded.
    pub fn path(&self) -> Option<&str> {
        self.path.as_deref()
    }
}

impl std::fmt::Display for JsonRejection {
    /// Formats as `Invalid input at `<path>`: <message>` when a path is known,
    /// otherwise as `Invalid input: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(path) = &self.path {
            write!(f, "Invalid input at `{}`: {}", path, self.message)
        } else {
            write!(f, "Invalid input: {}", self.message)
        }
    }
}

impl std::error::Error for JsonRejection {}
189
190impl From<JsonRejection> for Error {
191    fn from(rejection: JsonRejection) -> Self {
192        Error::tool(rejection.to_string())
193    }
194}
195
196impl From<serde_json::Error> for JsonRejection {
197    fn from(err: serde_json::Error) -> Self {
198        // Try to extract path information from serde error
199        let path = if err.is_data() {
200            // serde_json provides line/column but not field path in the error itself
201            // The path is embedded in the message for some error types
202            None
203        } else {
204            None
205        };
206
207        Self {
208            message: err.to_string(),
209            path,
210        }
211    }
212}
213
/// Rejection produced by the [`Extension`] extractor when the requested type is
/// absent from the request's extensions.
///
/// It records the Rust type name of the missing extension.
///
/// # Example
///
/// ```rust
/// use tower_mcp::extract::ExtensionRejection;
///
/// let rejection = ExtensionRejection::not_found::<String>();
/// assert!(rejection.type_name().contains("String"));
/// ```
#[derive(Debug, Clone)]
pub struct ExtensionRejection {
    type_name: &'static str,
}

impl ExtensionRejection {
    /// Builds a rejection naming `T` as the extension that could not be found.
    pub fn not_found<T>() -> Self {
        let type_name = std::any::type_name::<T>();
        Self { type_name }
    }

    /// Returns the type name (as given by `std::any::type_name`) of the missing extension.
    pub fn type_name(&self) -> &'static str {
        self.type_name
    }
}

impl std::fmt::Display for ExtensionRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ExtensionRejection { type_name } = self;
        write!(
            f,
            "Extension of type `{}` not found. Did you call `router.with_state()` or `router.with_extension()`?",
            type_name
        )
    }
}

impl std::error::Error for ExtensionRejection {}
257
258impl From<ExtensionRejection> for Error {
259    fn from(rejection: ExtensionRejection) -> Self {
260        Error::tool(rejection.to_string())
261    }
262}
263
264/// Trait for extracting data from a tool request.
265///
266/// Implement this trait to create custom extractors that can be used
267/// in `extractor_handler` functions.
268///
269/// # Type Parameters
270///
271/// - `S` - The state type. Defaults to `()` for extractors that don't need state.
272///
273/// # Example
274///
275/// ```rust
276/// use tower_mcp::extract::{FromToolRequest, Rejection};
277/// use tower_mcp::RequestContext;
278/// use serde_json::Value;
279///
280/// struct RequestId(String);
281///
282/// impl<S> FromToolRequest<S> for RequestId {
283///     type Rejection = Rejection;
284///
285///     fn from_tool_request(
286///         ctx: &RequestContext,
287///         _state: &S,
288///         _args: &Value,
289///     ) -> Result<Self, Self::Rejection> {
290///         Ok(RequestId(format!("{:?}", ctx.request_id())))
291///     }
292/// }
293/// ```
294pub trait FromToolRequest<S = ()>: Sized {
295    /// The rejection type returned when extraction fails.
296    type Rejection: Into<Error>;
297
298    /// Extract this type from the tool request.
299    ///
300    /// # Arguments
301    ///
302    /// * `ctx` - The request context with progress, cancellation, etc.
303    /// * `state` - The shared state passed to the handler
304    /// * `args` - The raw JSON arguments to the tool
305    fn from_tool_request(
306        ctx: &RequestContext,
307        state: &S,
308        args: &Value,
309    ) -> std::result::Result<Self, Self::Rejection>;
310}
311
312// =============================================================================
313// Built-in Extractors
314// =============================================================================
315
316/// Extract and deserialize JSON arguments into a typed struct.
317///
318/// This extractor deserializes the tool's JSON arguments into type `T`.
319/// The type must implement [`serde::de::DeserializeOwned`] and [`schemars::JsonSchema`].
320///
321/// # Example
322///
323/// ```rust
324/// use tower_mcp::extract::Json;
325/// use schemars::JsonSchema;
326/// use serde::Deserialize;
327///
328/// #[derive(Debug, Deserialize, JsonSchema)]
329/// struct MyInput {
330///     name: String,
331///     count: i32,
332/// }
333///
334/// // In an extractor handler:
335/// // |Json(input): Json<MyInput>| async move { ... }
336/// ```
337///
338/// # Rejection
339///
340/// Returns a [`JsonRejection`] if deserialization fails. The rejection contains
341/// the error message and potentially the path to the failing field.
342#[derive(Debug, Clone, Copy)]
343pub struct Json<T>(pub T);
344
345impl<T> Deref for Json<T> {
346    type Target = T;
347
348    fn deref(&self) -> &Self::Target {
349        &self.0
350    }
351}
352
353impl<S, T> FromToolRequest<S> for Json<T>
354where
355    T: DeserializeOwned,
356{
357    type Rejection = JsonRejection;
358
359    fn from_tool_request(
360        _ctx: &RequestContext,
361        _state: &S,
362        args: &Value,
363    ) -> std::result::Result<Self, Self::Rejection> {
364        serde_json::from_value(args.clone())
365            .map(Json)
366            .map_err(JsonRejection::from)
367    }
368}
369
370/// Extract shared state.
371///
372/// This extractor clones the state passed to `extractor_handler` and provides
373/// it to the handler. The state type must match the type passed to the builder.
374///
375/// # Example
376///
377/// ```rust
378/// use std::sync::Arc;
379/// use tower_mcp::extract::State;
380///
381/// #[derive(Clone)]
382/// struct AppState {
383///     db_url: String,
384/// }
385///
386/// // In an extractor handler:
387/// // |State(state): State<Arc<AppState>>| async move { ... }
388/// ```
389///
390/// # Note
391///
392/// For expensive-to-clone types, wrap them in `Arc` before passing to
393/// `extractor_handler`.
394#[derive(Debug, Clone, Copy)]
395pub struct State<T>(pub T);
396
397impl<T> Deref for State<T> {
398    type Target = T;
399
400    fn deref(&self) -> &Self::Target {
401        &self.0
402    }
403}
404
405impl<S: Clone> FromToolRequest<S> for State<S> {
406    type Rejection = Rejection;
407
408    fn from_tool_request(
409        _ctx: &RequestContext,
410        state: &S,
411        _args: &Value,
412    ) -> std::result::Result<Self, Self::Rejection> {
413        Ok(State(state.clone()))
414    }
415}
416
417/// Extract the request context.
418///
419/// This extractor provides access to the [`RequestContext`], which contains:
420/// - Progress reporting via `report_progress()`
421/// - Cancellation checking via `is_cancelled()`
422/// - Sampling capabilities via `sample()`
423/// - Elicitation capabilities via `elicit_form()` and `elicit_url()`
424/// - Log sending via `send_log()`
425///
426/// # Example
427///
428/// ```rust
429/// use tower_mcp::extract::Context;
430///
431/// // In an extractor handler:
432/// // |ctx: Context| async move {
433/// //     ctx.report_progress(0.5, Some(1.0), Some("Working...")).await;
434/// //     // ...
435/// // }
436/// ```
437#[derive(Debug, Clone)]
438pub struct Context(RequestContext);
439
440impl Context {
441    /// Get the inner RequestContext
442    pub fn into_inner(self) -> RequestContext {
443        self.0
444    }
445}
446
447impl Deref for Context {
448    type Target = RequestContext;
449
450    fn deref(&self) -> &Self::Target {
451        &self.0
452    }
453}
454
455impl<S> FromToolRequest<S> for Context {
456    type Rejection = Rejection;
457
458    fn from_tool_request(
459        ctx: &RequestContext,
460        _state: &S,
461        _args: &Value,
462    ) -> std::result::Result<Self, Self::Rejection> {
463        Ok(Context(ctx.clone()))
464    }
465}
466
467/// Extract raw JSON arguments.
468///
469/// This extractor provides the raw `serde_json::Value` arguments without
470/// any deserialization. Useful when you need full control over argument
471/// parsing or when the schema is dynamic.
472///
473/// # Example
474///
475/// ```rust
476/// use tower_mcp::extract::RawArgs;
477///
478/// // In an extractor handler:
479/// // |RawArgs(args): RawArgs| async move {
480/// //     // args is serde_json::Value
481/// //     if let Some(name) = args.get("name") { ... }
482/// // }
483/// ```
484#[derive(Debug, Clone)]
485pub struct RawArgs(pub Value);
486
487impl Deref for RawArgs {
488    type Target = Value;
489
490    fn deref(&self) -> &Self::Target {
491        &self.0
492    }
493}
494
495impl<S> FromToolRequest<S> for RawArgs {
496    type Rejection = Rejection;
497
498    fn from_tool_request(
499        _ctx: &RequestContext,
500        _state: &S,
501        args: &Value,
502    ) -> std::result::Result<Self, Self::Rejection> {
503        Ok(RawArgs(args.clone()))
504    }
505}
506
507/// Extract typed data from router extensions.
508///
509/// This extractor retrieves data that was added to the router via
510/// [`crate::McpRouter::with_state()`] or [`crate::McpRouter::with_extension()`], or
511/// inserted by middleware into the request context's extensions.
512///
513/// # Example
514///
515/// ```rust
516/// use std::sync::Arc;
517/// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
518/// use tower_mcp::extract::{Extension, Json};
519/// use schemars::JsonSchema;
520/// use serde::Deserialize;
521///
522/// #[derive(Clone)]
523/// struct DatabasePool {
524///     url: String,
525/// }
526///
527/// #[derive(Deserialize, JsonSchema)]
528/// struct QueryInput {
529///     sql: String,
530/// }
531///
532/// let pool = Arc::new(DatabasePool { url: "postgres://...".into() });
533///
534/// let tool = ToolBuilder::new("query")
535///     .description("Run a query")
536///     .extractor_handler(
537///         (),
538///         |Extension(db): Extension<Arc<DatabasePool>>, Json(input): Json<QueryInput>| async move {
539///             Ok(CallToolResult::text(format!("Query on {}: {}", db.url, input.sql)))
540///         },
541///     )
542///     .build();
543///
544/// let router = McpRouter::new()
545///     .with_state(pool)
546///     .tool(tool);
547/// ```
548///
549/// # Rejection
550///
551/// Returns an [`ExtensionRejection`] if the requested type is not found in the extensions.
552/// The rejection contains the type name of the missing extension.
553#[derive(Debug, Clone)]
554pub struct Extension<T>(pub T);
555
556impl<T> Deref for Extension<T> {
557    type Target = T;
558
559    fn deref(&self) -> &Self::Target {
560        &self.0
561    }
562}
563
564impl<S, T> FromToolRequest<S> for Extension<T>
565where
566    T: Clone + Send + Sync + 'static,
567{
568    type Rejection = ExtensionRejection;
569
570    fn from_tool_request(
571        ctx: &RequestContext,
572        _state: &S,
573        _args: &Value,
574    ) -> std::result::Result<Self, Self::Rejection> {
575        ctx.extension::<T>()
576            .cloned()
577            .map(Extension)
578            .ok_or_else(ExtensionRejection::not_found::<T>)
579    }
580}
581
582// =============================================================================
583// Handler Trait
584// =============================================================================
585
586/// A handler that uses extractors.
587///
588/// This trait is implemented for functions that take extractors as arguments.
589/// You don't need to implement this trait directly; it's automatically
590/// implemented for compatible async functions.
591pub trait ExtractorHandler<S, T>: Clone + Send + Sync + 'static {
592    /// The future returned by the handler.
593    type Future: Future<Output = Result<CallToolResult>> + Send;
594
595    /// Call the handler with extracted values.
596    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
597
598    /// Get the input schema for this handler.
599    ///
600    /// Returns `None` if no `Json<T>` extractor is used.
601    fn input_schema() -> Value;
602}
603
604// Implementation for single extractor
605impl<S, F, Fut, T1> ExtractorHandler<S, (T1,)> for F
606where
607    S: Clone + Send + Sync + 'static,
608    F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
609    Fut: Future<Output = Result<CallToolResult>> + Send,
610    T1: FromToolRequest<S> + HasSchema + Send,
611{
612    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
613
614    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
615        Box::pin(async move {
616            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
617            self(t1).await
618        })
619    }
620
621    fn input_schema() -> Value {
622        if let Some(schema) = T1::schema() {
623            return schema;
624        }
625        serde_json::json!({
626            "type": "object",
627            "additionalProperties": true
628        })
629    }
630}
631
632// Implementation for two extractors
633impl<S, F, Fut, T1, T2> ExtractorHandler<S, (T1, T2)> for F
634where
635    S: Clone + Send + Sync + 'static,
636    F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
637    Fut: Future<Output = Result<CallToolResult>> + Send,
638    T1: FromToolRequest<S> + HasSchema + Send,
639    T2: FromToolRequest<S> + HasSchema + Send,
640{
641    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
642
643    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
644        Box::pin(async move {
645            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
646            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
647            self(t1, t2).await
648        })
649    }
650
651    fn input_schema() -> Value {
652        if let Some(schema) = T2::schema() {
653            return schema;
654        }
655        if let Some(schema) = T1::schema() {
656            return schema;
657        }
658        serde_json::json!({
659            "type": "object",
660            "additionalProperties": true
661        })
662    }
663}
664
665// Implementation for three extractors
666impl<S, F, Fut, T1, T2, T3> ExtractorHandler<S, (T1, T2, T3)> for F
667where
668    S: Clone + Send + Sync + 'static,
669    F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
670    Fut: Future<Output = Result<CallToolResult>> + Send,
671    T1: FromToolRequest<S> + HasSchema + Send,
672    T2: FromToolRequest<S> + HasSchema + Send,
673    T3: FromToolRequest<S> + HasSchema + Send,
674{
675    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
676
677    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
678        Box::pin(async move {
679            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
680            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
681            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
682            self(t1, t2, t3).await
683        })
684    }
685
686    fn input_schema() -> Value {
687        if let Some(schema) = T3::schema() {
688            return schema;
689        }
690        if let Some(schema) = T2::schema() {
691            return schema;
692        }
693        if let Some(schema) = T1::schema() {
694            return schema;
695        }
696        serde_json::json!({
697            "type": "object",
698            "additionalProperties": true
699        })
700    }
701}
702
703// Implementation for four extractors
704impl<S, F, Fut, T1, T2, T3, T4> ExtractorHandler<S, (T1, T2, T3, T4)> for F
705where
706    S: Clone + Send + Sync + 'static,
707    F: Fn(T1, T2, T3, T4) -> Fut + Clone + Send + Sync + 'static,
708    Fut: Future<Output = Result<CallToolResult>> + Send,
709    T1: FromToolRequest<S> + HasSchema + Send,
710    T2: FromToolRequest<S> + HasSchema + Send,
711    T3: FromToolRequest<S> + HasSchema + Send,
712    T4: FromToolRequest<S> + HasSchema + Send,
713{
714    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
715
716    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
717        Box::pin(async move {
718            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
719            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
720            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
721            let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
722            self(t1, t2, t3, t4).await
723        })
724    }
725
726    fn input_schema() -> Value {
727        if let Some(schema) = T4::schema() {
728            return schema;
729        }
730        if let Some(schema) = T3::schema() {
731            return schema;
732        }
733        if let Some(schema) = T2::schema() {
734            return schema;
735        }
736        if let Some(schema) = T1::schema() {
737            return schema;
738        }
739        serde_json::json!({
740            "type": "object",
741            "additionalProperties": true
742        })
743    }
744}
745
746// Implementation for five extractors
747impl<S, F, Fut, T1, T2, T3, T4, T5> ExtractorHandler<S, (T1, T2, T3, T4, T5)> for F
748where
749    S: Clone + Send + Sync + 'static,
750    F: Fn(T1, T2, T3, T4, T5) -> Fut + Clone + Send + Sync + 'static,
751    Fut: Future<Output = Result<CallToolResult>> + Send,
752    T1: FromToolRequest<S> + HasSchema + Send,
753    T2: FromToolRequest<S> + HasSchema + Send,
754    T3: FromToolRequest<S> + HasSchema + Send,
755    T4: FromToolRequest<S> + HasSchema + Send,
756    T5: FromToolRequest<S> + HasSchema + Send,
757{
758    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
759
760    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
761        Box::pin(async move {
762            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
763            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
764            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
765            let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
766            let t5 = T5::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
767            self(t1, t2, t3, t4, t5).await
768        })
769    }
770
771    fn input_schema() -> Value {
772        if let Some(schema) = T5::schema() {
773            return schema;
774        }
775        if let Some(schema) = T4::schema() {
776            return schema;
777        }
778        if let Some(schema) = T3::schema() {
779            return schema;
780        }
781        if let Some(schema) = T2::schema() {
782            return schema;
783        }
784        if let Some(schema) = T1::schema() {
785            return schema;
786        }
787        serde_json::json!({
788            "type": "object",
789            "additionalProperties": true
790        })
791    }
792}
793
794// =============================================================================
795// Schema Extraction Helper
796// =============================================================================
797
798/// Helper trait to get schema from `Json<T>` extractor
799pub trait HasSchema {
800    /// Returns the JSON Schema for this type, if available.
801    fn schema() -> Option<Value>;
802}
803
804impl<T: JsonSchema> HasSchema for Json<T> {
805    fn schema() -> Option<Value> {
806        let schema = schemars::schema_for!(T);
807        serde_json::to_value(schema)
808            .ok()
809            .map(crate::tool::ensure_object_schema)
810    }
811}
812
813// Default impl for non-Json extractors
814impl HasSchema for Context {
815    fn schema() -> Option<Value> {
816        None
817    }
818}
819
820impl HasSchema for RawArgs {
821    fn schema() -> Option<Value> {
822        None
823    }
824}
825
826impl<T> HasSchema for State<T> {
827    fn schema() -> Option<Value> {
828        None
829    }
830}
831
832impl<T> HasSchema for Extension<T> {
833    fn schema() -> Option<Value> {
834        None
835    }
836}
837
838// =============================================================================
839// Typed Extractor Handler
840// =============================================================================
841
842/// A handler that uses extractors with typed JSON input.
843///
844/// This trait is similar to [`ExtractorHandler`] but provides proper JSON
845/// schema generation for the input type when `Json<T>` is used.
846#[deprecated(
847    since = "0.8.0",
848    note = "Use `ExtractorHandler` instead -- `extractor_handler` auto-detects JSON schema from `Json<T>` extractors"
849)]
850pub trait TypedExtractorHandler<S, T, I>: Clone + Send + Sync + 'static
851where
852    I: JsonSchema,
853{
854    /// The future returned by the handler.
855    type Future: Future<Output = Result<CallToolResult>> + Send;
856
857    /// Call the handler with extracted values.
858    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
859}
860
861// Single extractor with Json<T>
862#[allow(deprecated)]
863impl<S, F, Fut, T> TypedExtractorHandler<S, (Json<T>,), T> for F
864where
865    S: Clone + Send + Sync + 'static,
866    F: Fn(Json<T>) -> Fut + Clone + Send + Sync + 'static,
867    Fut: Future<Output = Result<CallToolResult>> + Send,
868    T: DeserializeOwned + JsonSchema + Send,
869{
870    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
871
872    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
873        Box::pin(async move {
874            let t1 =
875                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
876            self(t1).await
877        })
878    }
879}
880
881// Two extractors ending with Json<T>
882#[allow(deprecated)]
883impl<S, F, Fut, T1, T> TypedExtractorHandler<S, (T1, Json<T>), T> for F
884where
885    S: Clone + Send + Sync + 'static,
886    F: Fn(T1, Json<T>) -> Fut + Clone + Send + Sync + 'static,
887    Fut: Future<Output = Result<CallToolResult>> + Send,
888    T1: FromToolRequest<S> + Send,
889    T: DeserializeOwned + JsonSchema + Send,
890{
891    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
892
893    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
894        Box::pin(async move {
895            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
896            let t2 =
897                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
898            self(t1, t2).await
899        })
900    }
901}
902
903// Three extractors ending with Json<T>
904#[allow(deprecated)]
905impl<S, F, Fut, T1, T2, T> TypedExtractorHandler<S, (T1, T2, Json<T>), T> for F
906where
907    S: Clone + Send + Sync + 'static,
908    F: Fn(T1, T2, Json<T>) -> Fut + Clone + Send + Sync + 'static,
909    Fut: Future<Output = Result<CallToolResult>> + Send,
910    T1: FromToolRequest<S> + Send,
911    T2: FromToolRequest<S> + Send,
912    T: DeserializeOwned + JsonSchema + Send,
913{
914    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
915
916    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
917        Box::pin(async move {
918            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
919            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
920            let t3 =
921                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
922            self(t1, t2, t3).await
923        })
924    }
925}
926
927// Four extractors ending with Json<T>
928#[allow(deprecated)]
929impl<S, F, Fut, T1, T2, T3, T> TypedExtractorHandler<S, (T1, T2, T3, Json<T>), T> for F
930where
931    S: Clone + Send + Sync + 'static,
932    F: Fn(T1, T2, T3, Json<T>) -> Fut + Clone + Send + Sync + 'static,
933    Fut: Future<Output = Result<CallToolResult>> + Send,
934    T1: FromToolRequest<S> + Send,
935    T2: FromToolRequest<S> + Send,
936    T3: FromToolRequest<S> + Send,
937    T: DeserializeOwned + JsonSchema + Send,
938{
939    type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
940
941    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
942        Box::pin(async move {
943            let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
944            let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
945            let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
946            let t4 =
947                Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
948            self(t1, t2, t3, t4).await
949        })
950    }
951}
952
953// =============================================================================
954// ToolBuilder Extensions
955// =============================================================================
956
957use crate::tool::{
958    BoxFuture, GuardLayer, Tool, ToolCatchError, ToolHandler, ToolHandlerService, ToolRequest,
959};
960use tower::util::BoxCloneService;
961use tower_service::Service;
962
963/// Internal handler wrapper for extractor-based handlers
964pub(crate) struct ExtractorToolHandler<S, F, T> {
965    state: S,
966    handler: F,
967    input_schema: Value,
968    _phantom: PhantomData<T>,
969}
970
971impl<S, F, T> ToolHandler for ExtractorToolHandler<S, F, T>
972where
973    S: Clone + Send + Sync + 'static,
974    F: ExtractorHandler<S, T> + Clone,
975    T: Send + Sync + 'static,
976{
977    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
978        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
979        self.call_with_context(ctx, args)
980    }
981
982    fn call_with_context(
983        &self,
984        ctx: RequestContext,
985        args: Value,
986    ) -> BoxFuture<'_, Result<CallToolResult>> {
987        let state = self.state.clone();
988        let handler = self.handler.clone();
989        Box::pin(async move { handler.call(ctx, state, args).await })
990    }
991
992    fn uses_context(&self) -> bool {
993        true
994    }
995
996    fn input_schema(&self) -> Value {
997        self.input_schema.clone()
998    }
999}
1000
/// Builder state for extractor-based handlers
///
/// Holds the tool metadata collected so far together with the shared state, the
/// extractor handler and the precomputed input schema. Call [`build`](Self::build)
/// to produce a [`Tool`], or [`layer`](Self::layer) / [`guard`](Self::guard) to
/// attach middleware first.
#[doc(hidden)]
pub struct ToolBuilderWithExtractor<S, F, T> {
    /// Tool name, copied into the built [`Tool`].
    pub(crate) name: String,
    /// Optional human-readable title.
    pub(crate) title: Option<String>,
    /// Optional description.
    pub(crate) description: Option<String>,
    /// Optional output schema.
    pub(crate) output_schema: Option<Value>,
    /// Optional icons.
    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
    /// Optional tool annotations.
    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
    /// Task support mode, copied into the built [`Tool`].
    pub(crate) task_support: crate::protocol::TaskSupportMode,
    /// State cloned into each handler invocation.
    pub(crate) state: S,
    /// The extractor handler.
    pub(crate) handler: F,
    /// Input schema used both by the handler wrapper and by the built [`Tool`].
    pub(crate) input_schema: Value,
    /// Ties the extractor tuple type `T` to the builder without storing a value.
    pub(crate) _phantom: PhantomData<T>,
}
1016
1017impl<S, F, T> ToolBuilderWithExtractor<S, F, T>
1018where
1019    S: Clone + Send + Sync + 'static,
1020    F: ExtractorHandler<S, T> + Clone,
1021    T: Send + Sync + 'static,
1022{
1023    /// Build the tool.
1024    pub fn build(self) -> Tool {
1025        let handler = ExtractorToolHandler {
1026            state: self.state,
1027            handler: self.handler,
1028            input_schema: self.input_schema.clone(),
1029            _phantom: PhantomData,
1030        };
1031
1032        let handler_service = ToolHandlerService::new(handler);
1033        let catch_error = ToolCatchError::new(handler_service);
1034        let service = BoxCloneService::new(catch_error);
1035
1036        Tool {
1037            name: self.name,
1038            title: self.title,
1039            description: self.description,
1040            output_schema: self.output_schema,
1041            icons: self.icons,
1042            annotations: self.annotations,
1043            task_support: self.task_support,
1044            service,
1045            input_schema: self.input_schema,
1046        }
1047    }
1048
1049    /// Apply a Tower layer (middleware) to this tool.
1050    ///
1051    /// The layer wraps the tool's handler service, enabling functionality like
1052    /// timeouts, rate limiting, and metrics collection at the per-tool level.
1053    ///
1054    /// # Example
1055    ///
1056    /// ```rust
1057    /// use std::sync::Arc;
1058    /// use std::time::Duration;
1059    /// use tower::timeout::TimeoutLayer;
1060    /// use tower_mcp::{ToolBuilder, CallToolResult};
1061    /// use tower_mcp::extract::{Json, State};
1062    /// use schemars::JsonSchema;
1063    /// use serde::Deserialize;
1064    ///
1065    /// #[derive(Clone)]
1066    /// struct AppState { prefix: String }
1067    ///
1068    /// #[derive(Debug, Deserialize, JsonSchema)]
1069    /// struct QueryInput { query: String }
1070    ///
1071    /// let state = Arc::new(AppState { prefix: "db".to_string() });
1072    ///
1073    /// let tool = ToolBuilder::new("search")
1074    ///     .description("Search with timeout")
1075    ///     .extractor_handler(state, |
1076    ///         State(app): State<Arc<AppState>>,
1077    ///         Json(input): Json<QueryInput>,
1078    ///     | async move {
1079    ///         Ok(CallToolResult::text(format!("{}: {}", app.prefix, input.query)))
1080    ///     })
1081    ///     .layer(TimeoutLayer::new(Duration::from_secs(30)))
1082    ///     .build();
1083    /// ```
1084    pub fn layer<L>(self, layer: L) -> ToolBuilderWithExtractorLayer<S, F, T, L> {
1085        ToolBuilderWithExtractorLayer {
1086            name: self.name,
1087            title: self.title,
1088            description: self.description,
1089            output_schema: self.output_schema,
1090            icons: self.icons,
1091            annotations: self.annotations,
1092            task_support: self.task_support,
1093            state: self.state,
1094            handler: self.handler,
1095            input_schema: self.input_schema,
1096            layer,
1097            _phantom: PhantomData,
1098        }
1099    }
1100
1101    /// Apply a guard to this tool.
1102    ///
1103    /// See [`ToolBuilderWithHandler::guard`](crate::ToolBuilder) for details.
1104    pub fn guard<G>(self, guard: G) -> ToolBuilderWithExtractorLayer<S, F, T, GuardLayer<G>>
1105    where
1106        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1107    {
1108        self.layer(GuardLayer::new(guard))
1109    }
1110}
1111
/// Builder state after a layer has been applied to an extractor handler.
///
/// This builder allows chaining additional layers and building the final tool.
/// It carries the same metadata, state, handler and input schema as the
/// unlayered extractor builder, plus the accumulated layer.
#[doc(hidden)]
pub struct ToolBuilderWithExtractorLayer<S, F, T, L> {
    /// Tool name, copied into the built `Tool`.
    name: String,
    /// Optional human-readable title.
    title: Option<String>,
    /// Optional description.
    description: Option<String>,
    /// Optional output schema.
    output_schema: Option<Value>,
    /// Optional icons.
    icons: Option<Vec<crate::protocol::ToolIcon>>,
    /// Optional tool annotations.
    annotations: Option<crate::protocol::ToolAnnotations>,
    /// Task support mode, copied into the built `Tool`.
    task_support: crate::protocol::TaskSupportMode,
    /// State cloned into each handler invocation.
    state: S,
    /// The extractor handler.
    handler: F,
    /// Input schema used both by the handler wrapper and by the built `Tool`.
    input_schema: Value,
    /// Accumulated layer; after repeated `.layer()` calls this is a nested
    /// `tower::layer::util::Stack`. Applied around the handler service at build time.
    layer: L,
    /// Ties the extractor tuple type `T` to the builder without storing a value.
    _phantom: PhantomData<T>,
}
1130
1131#[allow(private_bounds)]
1132impl<S, F, T, L> ToolBuilderWithExtractorLayer<S, F, T, L>
1133where
1134    S: Clone + Send + Sync + 'static,
1135    F: ExtractorHandler<S, T> + Clone,
1136    T: Send + Sync + 'static,
1137    L: tower::Layer<ToolHandlerService<ExtractorToolHandler<S, F, T>>>
1138        + Clone
1139        + Send
1140        + Sync
1141        + 'static,
1142    L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1143    <L::Service as Service<ToolRequest>>::Error: std::fmt::Display + Send,
1144    <L::Service as Service<ToolRequest>>::Future: Send,
1145{
1146    /// Build the tool with the applied layer(s).
1147    pub fn build(self) -> Tool {
1148        let handler = ExtractorToolHandler {
1149            state: self.state,
1150            handler: self.handler,
1151            input_schema: self.input_schema.clone(),
1152            _phantom: PhantomData,
1153        };
1154
1155        let handler_service = ToolHandlerService::new(handler);
1156        let layered = self.layer.layer(handler_service);
1157        let catch_error = ToolCatchError::new(layered);
1158        let service = BoxCloneService::new(catch_error);
1159
1160        Tool {
1161            name: self.name,
1162            title: self.title,
1163            description: self.description,
1164            output_schema: self.output_schema,
1165            icons: self.icons,
1166            annotations: self.annotations,
1167            task_support: self.task_support,
1168            service,
1169            input_schema: self.input_schema,
1170        }
1171    }
1172
1173    /// Apply an additional Tower layer (middleware).
1174    ///
1175    /// Layers are applied in order, with earlier layers wrapping later ones.
1176    /// This means the first layer added is the outermost middleware.
1177    pub fn layer<L2>(
1178        self,
1179        layer: L2,
1180    ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<L2, L>> {
1181        ToolBuilderWithExtractorLayer {
1182            name: self.name,
1183            title: self.title,
1184            description: self.description,
1185            output_schema: self.output_schema,
1186            icons: self.icons,
1187            annotations: self.annotations,
1188            task_support: self.task_support,
1189            state: self.state,
1190            handler: self.handler,
1191            input_schema: self.input_schema,
1192            layer: tower::layer::util::Stack::new(layer, self.layer),
1193            _phantom: PhantomData,
1194        }
1195    }
1196
1197    /// Apply a guard to this tool.
1198    ///
1199    /// See [`ToolBuilderWithHandler::guard`](crate::ToolBuilder) for details.
1200    pub fn guard<G>(
1201        self,
1202        guard: G,
1203    ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<GuardLayer<G>, L>>
1204    where
1205        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1206    {
1207        self.layer(GuardLayer::new(guard))
1208    }
1209}
1210
/// Builder state for extractor-based handlers with typed JSON input
///
/// Unlike the non-typed extractor builder, the input schema is not stored
/// up front: it is taken from `input_schema_override` or generated from `I`
/// when the tool is built.
#[doc(hidden)]
#[deprecated(
    since = "0.8.0",
    note = "Use `ToolBuilderWithExtractor` via `extractor_handler` instead"
)]
pub struct ToolBuilderWithTypedExtractor<S, F, T, I> {
    /// Tool name, copied into the built `Tool`.
    pub(crate) name: String,
    /// Optional human-readable title.
    pub(crate) title: Option<String>,
    /// Optional description.
    pub(crate) description: Option<String>,
    /// Optional output schema.
    pub(crate) output_schema: Option<Value>,
    /// When `Some`, used instead of the schema generated from `I`.
    pub(crate) input_schema_override: Option<Value>,
    /// Optional icons.
    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
    /// Optional tool annotations.
    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
    /// Task support mode, copied into the built `Tool`.
    pub(crate) task_support: crate::protocol::TaskSupportMode,
    /// State cloned into each handler invocation.
    pub(crate) state: S,
    /// The typed extractor handler.
    pub(crate) handler: F,
    /// Ties the extractor tuple type `T` and input type `I` to the builder.
    pub(crate) _phantom: PhantomData<(T, I)>,
}
1230
1231#[allow(deprecated)]
1232impl<S, F, T, I> ToolBuilderWithTypedExtractor<S, F, T, I>
1233where
1234    S: Clone + Send + Sync + 'static,
1235    F: TypedExtractorHandler<S, T, I> + Clone,
1236    T: Send + Sync + 'static,
1237    I: JsonSchema + Send + Sync + 'static,
1238{
1239    /// Build the tool.
1240    pub fn build(self) -> Tool {
1241        let input_schema = {
1242            let schema = self.input_schema_override.unwrap_or_else(|| {
1243                let schema = schemars::schema_for!(I);
1244                serde_json::to_value(schema).unwrap_or_else(|_| {
1245                    serde_json::json!({
1246                        "type": "object"
1247                    })
1248                })
1249            });
1250            crate::tool::ensure_object_schema(schema)
1251        };
1252
1253        let handler = TypedExtractorToolHandler {
1254            state: self.state,
1255            handler: self.handler,
1256            input_schema: input_schema.clone(),
1257            _phantom: PhantomData,
1258        };
1259
1260        let handler_service = crate::tool::ToolHandlerService::new(handler);
1261        let catch_error = ToolCatchError::new(handler_service);
1262        let service = BoxCloneService::new(catch_error);
1263
1264        Tool {
1265            name: self.name,
1266            title: self.title,
1267            description: self.description,
1268            output_schema: self.output_schema,
1269            icons: self.icons,
1270            annotations: self.annotations,
1271            task_support: self.task_support,
1272            service,
1273            input_schema,
1274        }
1275    }
1276}
1277
1278/// Internal handler wrapper for typed extractor-based handlers
1279struct TypedExtractorToolHandler<S, F, T, I> {
1280    state: S,
1281    handler: F,
1282    input_schema: Value,
1283    _phantom: PhantomData<(T, I)>,
1284}
1285
1286#[allow(deprecated)]
1287impl<S, F, T, I> ToolHandler for TypedExtractorToolHandler<S, F, T, I>
1288where
1289    S: Clone + Send + Sync + 'static,
1290    F: TypedExtractorHandler<S, T, I> + Clone,
1291    T: Send + Sync + 'static,
1292    I: JsonSchema + Send + Sync + 'static,
1293{
1294    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1295        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
1296        self.call_with_context(ctx, args)
1297    }
1298
1299    fn call_with_context(
1300        &self,
1301        ctx: RequestContext,
1302        args: Value,
1303    ) -> BoxFuture<'_, Result<CallToolResult>> {
1304        let state = self.state.clone();
1305        let handler = self.handler.clone();
1306        Box::pin(async move { handler.call(ctx, state, args).await })
1307    }
1308
1309    fn uses_context(&self) -> bool {
1310        true
1311    }
1312
1313    fn input_schema(&self) -> Value {
1314        self.input_schema.clone()
1315    }
1316}
1317
// Unit tests for the individual extractors, the rejection types, the
// `ExtractorHandler` impls, and the `ToolBuilder` integration (schemas and layers).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RequestId;
    use schemars::JsonSchema;
    use serde::Deserialize;
    use std::sync::Arc;

    // Shared input type. Both fields are required, so omitting one must fail
    // deserialization (see `test_json_extraction_error`).
    #[derive(Debug, Deserialize, JsonSchema)]
    struct TestInput {
        name: String,
        count: i32,
    }

    // `Json<T>` deserializes the raw arguments into `T`.
    #[test]
    fn test_json_extraction() {
        let args = serde_json::json!({"name": "test", "count": 42});
        let ctx = RequestContext::new(RequestId::Number(1));

        let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
        assert!(result.is_ok());
        let Json(input) = result.unwrap();
        assert_eq!(input.name, "test");
        assert_eq!(input.count, 42);
    }

    // A missing required field produces a rejection that names the field.
    #[test]
    fn test_json_extraction_error() {
        let args = serde_json::json!({"name": "test"}); // missing count
        let ctx = RequestContext::new(RequestId::Number(1));

        let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
        assert!(result.is_err());
        let rejection = result.unwrap_err();
        // JsonRejection contains the serde error message
        assert!(rejection.message().contains("count"));
    }

    // `State<T>` yields the per-tool state passed to the extractor.
    #[test]
    fn test_state_extraction() {
        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("my-state".to_string());

        let result = State::<Arc<String>>::from_tool_request(&ctx, &state, &args);
        assert!(result.is_ok());
        let State(extracted) = result.unwrap();
        assert_eq!(*extracted, "my-state");
    }

    // `Context` exposes the request context, including its request id.
    #[test]
    fn test_context_extraction() {
        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(42));

        let result = Context::from_tool_request(&ctx, &(), &args);
        assert!(result.is_ok());
        let extracted = result.unwrap();
        assert_eq!(*extracted.request_id(), RequestId::Number(42));
    }

    // `RawArgs` hands back the untouched JSON arguments.
    #[test]
    fn test_raw_args_extraction() {
        let args = serde_json::json!({"foo": "bar", "baz": 123});
        let ctx = RequestContext::new(RequestId::Number(1));

        let result = RawArgs::from_tool_request(&ctx, &(), &args);
        assert!(result.is_ok());
        let RawArgs(extracted) = result.unwrap();
        assert_eq!(extracted["foo"], "bar");
        assert_eq!(extracted["baz"], 123);
    }

    // `Extension<T>` reads a value previously inserted into the context's extensions.
    #[test]
    fn test_extension_extraction() {
        use crate::context::Extensions;

        #[derive(Clone, Debug, PartialEq)]
        struct DatabasePool {
            url: String,
        }

        let args = serde_json::json!({});

        // Create extensions with a value
        let mut extensions = Extensions::new();
        extensions.insert(Arc::new(DatabasePool {
            url: "postgres://localhost".to_string(),
        }));

        // Create context with extensions
        let ctx = RequestContext::new(RequestId::Number(1)).with_extensions(Arc::new(extensions));

        // Extract the extension
        let result = Extension::<Arc<DatabasePool>>::from_tool_request(&ctx, &(), &args);
        assert!(result.is_ok());
        let Extension(pool) = result.unwrap();
        assert_eq!(pool.url, "postgres://localhost");
    }

    // A missing extension is rejected, and the rejection reports the requested type.
    #[test]
    fn test_extension_extraction_missing() {
        #[derive(Clone, Debug)]
        struct NotPresent;

        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(1));

        // Try to extract something that's not in extensions
        let result = Extension::<NotPresent>::from_tool_request(&ctx, &(), &args);
        assert!(result.is_err());
        let rejection = result.unwrap_err();
        // ExtensionRejection contains the type name
        assert!(rejection.type_name().contains("NotPresent"));
    }

    // Handler closure with a single `Json<T>` extractor.
    #[tokio::test]
    async fn test_single_extractor_handler() {
        let handler = |Json(input): Json<TestInput>| async move {
            Ok(CallToolResult::text(format!(
                "{}: {}",
                input.name, input.count
            )))
        };

        let ctx = RequestContext::new(RequestId::Number(1));
        let args = serde_json::json!({"name": "test", "count": 5});

        // Use explicit trait to avoid ambiguity
        let result: Result<CallToolResult> =
            ExtractorHandler::<(), (Json<TestInput>,)>::call(handler, ctx, (), args).await;
        assert!(result.is_ok());
    }

    // Handler closure combining `State<T>` and `Json<T>`.
    #[tokio::test]
    async fn test_two_extractor_handler() {
        let handler = |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
            Ok(CallToolResult::text(format!(
                "{}: {} - {}",
                state, input.name, input.count
            )))
        };

        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("prefix".to_string());
        let args = serde_json::json!({"name": "test", "count": 5});

        // Use explicit trait to avoid ambiguity
        let result: Result<CallToolResult> = ExtractorHandler::<
            Arc<String>,
            (State<Arc<String>>, Json<TestInput>),
        >::call(handler, ctx, state, args)
        .await;
        assert!(result.is_ok());
    }

    // Handler closure combining `State<T>`, `Context` and `Json<T>`.
    #[tokio::test]
    async fn test_three_extractor_handler() {
        let handler = |State(state): State<Arc<String>>,
                       ctx: Context,
                       Json(input): Json<TestInput>| async move {
            // Verify we can access all extractors
            assert!(!ctx.is_cancelled());
            Ok(CallToolResult::text(format!(
                "{}: {} - {}",
                state, input.name, input.count
            )))
        };

        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("prefix".to_string());
        let args = serde_json::json!({"name": "test", "count": 5});

        // Use explicit trait to avoid ambiguity
        let result: Result<CallToolResult> = ExtractorHandler::<
            Arc<String>,
            (State<Arc<String>>, Context, Json<TestInput>),
        >::call(handler, ctx, state, args)
        .await;
        assert!(result.is_ok());
    }

    // `Json<T>` advertises a schema with the input type's properties.
    #[test]
    fn test_json_schema_generation() {
        let schema = Json::<TestInput>::schema();
        assert!(schema.is_some());
        let schema = schema.unwrap();
        assert!(schema.get("properties").is_some());
    }

    // A generic `Rejection` keeps its message when converted into `Error`.
    #[test]
    fn test_rejection_into_error() {
        let rejection = Rejection::new("test error");
        let error: Error = rejection.into();
        assert!(error.to_string().contains("test error"));
    }

    // `JsonRejection` with and without a path, plus its conversion into `Error`.
    #[test]
    fn test_json_rejection() {
        // Test basic JsonRejection
        let rejection = JsonRejection::new("missing field `name`");
        assert_eq!(rejection.message(), "missing field `name`");
        assert!(rejection.path().is_none());
        assert!(rejection.to_string().contains("Invalid input"));

        // Test JsonRejection with path
        let rejection = JsonRejection::with_path("expected string", "users[0].name");
        assert_eq!(rejection.message(), "expected string");
        assert_eq!(rejection.path(), Some("users[0].name"));
        assert!(rejection.to_string().contains("users[0].name"));

        // Test conversion to Error
        let error: Error = rejection.into();
        assert!(error.to_string().contains("users[0].name"));
    }

    // Converting a real `serde_json::Error` into `JsonRejection` keeps the field name.
    #[test]
    fn test_json_rejection_from_serde_error() {
        // Create a real serde error by deserializing invalid JSON
        #[derive(Debug, serde::Deserialize)]
        struct TestStruct {
            #[allow(dead_code)]
            name: String,
        }

        let result: std::result::Result<TestStruct, _> =
            serde_json::from_value(serde_json::json!({"count": 42}));
        assert!(result.is_err());

        let rejection: JsonRejection = result.unwrap_err().into();
        assert!(rejection.message().contains("name"));
    }

    // `ExtensionRejection` names the type and hints at `with_state`.
    #[test]
    fn test_extension_rejection() {
        // Test ExtensionRejection
        let rejection = ExtensionRejection::not_found::<String>();
        assert!(rejection.type_name().contains("String"));
        assert!(rejection.to_string().contains("not found"));
        assert!(rejection.to_string().contains("with_state"));

        // Test conversion to Error
        let error: Error = rejection.into();
        assert!(error.to_string().contains("not found"));
    }

    // End-to-end: `ToolBuilder::extractor_handler` builds a callable tool.
    #[tokio::test]
    async fn test_tool_builder_extractor_handler() {
        use crate::ToolBuilder;

        let state = Arc::new("shared-state".to_string());

        let tool =
            ToolBuilder::new("test_extractor")
                .description("Test extractor handler")
                .extractor_handler(
                    state,
                    |State(state): State<Arc<String>>,
                     ctx: Context,
                     Json(input): Json<TestInput>| async move {
                        assert!(!ctx.is_cancelled());
                        Ok(CallToolResult::text(format!(
                            "{}: {} - {}",
                            state, input.name, input.count
                        )))
                    },
                )
                .build();

        assert_eq!(tool.name, "test_extractor");
        assert_eq!(tool.description.as_deref(), Some("Test extractor handler"));

        // Test calling the tool
        let result = tool
            .call(serde_json::json!({"name": "test", "count": 42}))
            .await;
        assert!(!result.is_error);
    }

    // The deprecated typed builder still generates the schema from the input type.
    #[tokio::test]
    #[allow(deprecated)]
    async fn test_tool_builder_extractor_handler_typed() {
        use crate::ToolBuilder;

        let state = Arc::new("typed-state".to_string());

        let tool = ToolBuilder::new("test_typed")
            .description("Test typed extractor handler")
            .extractor_handler_typed::<_, _, _, TestInput>(
                state,
                |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
                    Ok(CallToolResult::text(format!(
                        "{}: {} - {}",
                        state, input.name, input.count
                    )))
                },
            )
            .build();

        assert_eq!(tool.name, "test_typed");

        // Verify schema is properly generated from TestInput
        let def = tool.definition();
        let schema = def.input_schema;
        assert!(schema.get("properties").is_some());

        // Test calling the tool
        let result = tool
            .call(serde_json::json!({"name": "world", "count": 99}))
            .await;
        assert!(!result.is_error);
    }

    // `extractor_handler` picks up the schema of the `Json<T>` parameter automatically.
    #[tokio::test]
    async fn test_extractor_handler_auto_schema() {
        use crate::ToolBuilder;

        let state = Arc::new("auto-schema".to_string());

        // extractor_handler (not _typed) should auto-detect Json<TestInput> schema
        let tool = ToolBuilder::new("test_auto_schema")
            .description("Test auto schema detection")
            .extractor_handler(
                state,
                |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
                    Ok(CallToolResult::text(format!(
                        "{}: {} - {}",
                        state, input.name, input.count
                    )))
                },
            )
            .build();

        // Verify schema is properly generated from TestInput (not generic object)
        let def = tool.definition();
        let schema = def.input_schema;
        assert!(
            schema.get("properties").is_some(),
            "Schema should have properties from TestInput, got: {}",
            schema
        );
        let props = schema.get("properties").unwrap();
        assert!(
            props.get("name").is_some(),
            "Schema should have 'name' property"
        );
        assert!(
            props.get("count").is_some(),
            "Schema should have 'count' property"
        );

        // Test calling the tool
        let result = tool
            .call(serde_json::json!({"name": "world", "count": 99}))
            .await;
        assert!(!result.is_error);
    }

    // Without a `Json<T>` parameter the schema falls back to an open object.
    #[test]
    fn test_extractor_handler_no_json_fallback() {
        use crate::ToolBuilder;

        // extractor_handler without Json<T> should fall back to generic schema
        let tool = ToolBuilder::new("test_no_json")
            .description("Test no json fallback")
            .extractor_handler((), |RawArgs(args): RawArgs| async move {
                Ok(CallToolResult::json(args))
            })
            .build();

        let def = tool.definition();
        let schema = def.input_schema;
        assert_eq!(
            schema.get("type").and_then(|v| v.as_str()),
            Some("object"),
            "Schema should be generic object"
        );
        assert_eq!(
            schema.get("additionalProperties").and_then(|v| v.as_bool()),
            Some(true),
            "Schema should allow additional properties"
        );
        // Should NOT have specific properties
        assert!(
            schema.get("properties").is_none(),
            "Generic schema should not have specific properties"
        );
    }

    // A layered extractor tool still runs and keeps its generated schema.
    #[tokio::test]
    async fn test_extractor_handler_with_layer() {
        use crate::ToolBuilder;
        use std::time::Duration;
        use tower::timeout::TimeoutLayer;

        let state = Arc::new("layered".to_string());

        let tool = ToolBuilder::new("test_extractor_layer")
            .description("Test extractor handler with layer")
            .extractor_handler(
                state,
                |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
                    Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
                },
            )
            .layer(TimeoutLayer::new(Duration::from_secs(5)))
            .build();

        // Verify the tool works
        let result = tool
            .call(serde_json::json!({"name": "test", "count": 1}))
            .await;
        assert!(!result.is_error);
        assert_eq!(result.first_text().unwrap(), "layered: test");

        // Verify schema is still properly generated
        let def = tool.definition();
        let schema = def.input_schema;
        assert!(
            schema.get("properties").is_some(),
            "Schema should have properties even with layer"
        );
    }

    // A layer error (timeout) surfaces as an error result rather than a failed call.
    #[tokio::test]
    async fn test_extractor_handler_with_timeout_layer() {
        use crate::ToolBuilder;
        use std::time::Duration;
        use tower::timeout::TimeoutLayer;

        let tool = ToolBuilder::new("test_extractor_timeout")
            .description("Test extractor handler timeout")
            .extractor_handler((), |Json(input): Json<TestInput>| async move {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok(CallToolResult::text(input.name.to_string()))
            })
            .layer(TimeoutLayer::new(Duration::from_millis(50)))
            .build();

        // Should timeout
        let result = tool
            .call(serde_json::json!({"name": "slow", "count": 1}))
            .await;
        assert!(result.is_error);
        let msg = result.first_text().unwrap().to_lowercase();
        assert!(
            msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
            "Expected timeout error, got: {}",
            msg
        );
    }

    // Chaining several layers composes them via `Stack` and the tool still works.
    #[tokio::test]
    async fn test_extractor_handler_with_multiple_layers() {
        use crate::ToolBuilder;
        use std::time::Duration;
        use tower::limit::ConcurrencyLimitLayer;
        use tower::timeout::TimeoutLayer;

        let state = Arc::new("multi".to_string());

        let tool = ToolBuilder::new("test_multi_layer")
            .description("Test multiple layers")
            .extractor_handler(
                state,
                |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
                    Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
                },
            )
            .layer(TimeoutLayer::new(Duration::from_secs(5)))
            .layer(ConcurrencyLimitLayer::new(10))
            .build();

        let result = tool
            .call(serde_json::json!({"name": "test", "count": 1}))
            .await;
        assert!(!result.is_error);
        assert_eq!(result.first_text().unwrap(), "multi: test");
    }
}