Skip to main content

tower_mcp/
tool.rs

1//! Tool definition and builder API
2//!
3//! Provides ergonomic ways to define MCP tools:
4//!
5//! 1. **Builder pattern** - Fluent API for defining tools
6//! 2. **Trait-based** - Implement `McpTool` for full control
7//! 3. **Function-based** - Quick tools from async functions
8//!
9//! ## Per-Tool Middleware
10//!
11//! Tools are implemented as Tower services internally, enabling middleware
12//! composition via the `.layer()` method:
13//!
14//! ```rust
15//! use std::time::Duration;
16//! use tower::timeout::TimeoutLayer;
17//! use tower_mcp::{ToolBuilder, CallToolResult};
18//! use schemars::JsonSchema;
19//! use serde::Deserialize;
20//!
21//! #[derive(Debug, Deserialize, JsonSchema)]
22//! struct SearchInput { query: String }
23//!
24//! let tool = ToolBuilder::new("slow_search")
25//!     .description("Search with extended timeout")
26//!     .handler(|input: SearchInput| async move {
27//!         Ok(CallToolResult::text("result"))
28//!     })
29//!     .layer(TimeoutLayer::new(Duration::from_secs(30)))
30//!     .build();
31//! ```
32
33use std::borrow::Cow;
34use std::convert::Infallible;
35use std::fmt;
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::task::{Context, Poll};
40
41use schemars::{JsonSchema, Schema, SchemaGenerator};
42use serde::Serialize;
43use serde::de::DeserializeOwned;
44use serde_json::Value;
45use tower::util::BoxCloneService;
46use tower_service::Service;
47
48use crate::context::RequestContext;
49use crate::error::{Error, Result, ResultExt};
50use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
51
52// =============================================================================
53// Service Types for Per-Tool Middleware
54// =============================================================================
55
56/// Request type for tool services.
57///
58/// Contains the request context (for progress reporting, cancellation, etc.)
59/// and the tool arguments as raw JSON.
60#[derive(Debug, Clone)]
61pub struct ToolRequest {
62    /// Request context for progress reporting, cancellation, and client requests
63    pub ctx: RequestContext,
64    /// Tool arguments as raw JSON
65    pub args: Value,
66}
67
68impl ToolRequest {
69    /// Create a new tool request
70    pub fn new(ctx: RequestContext, args: Value) -> Self {
71        Self { ctx, args }
72    }
73}
74
75/// A boxed, cloneable tool service with `Error = Infallible`.
76///
77/// This is the internal service type that tools use. Middleware errors are
78/// caught and converted to `CallToolResult::error()` responses, so the
79/// service never fails at the Tower level.
80pub type BoxToolService = BoxCloneService<ToolRequest, CallToolResult, Infallible>;
81
82/// Catches errors from the inner service and converts them to `CallToolResult::error()`.
83///
84/// This wrapper ensures that middleware errors (e.g., timeouts, rate limits)
85/// and handler errors are converted to tool-level error responses with
86/// `is_error: true`, rather than propagating as Tower service errors.
87pub struct ToolCatchError<S> {
88    inner: S,
89}
90
91impl<S> ToolCatchError<S> {
92    /// Create a new `ToolCatchError` wrapping the given service.
93    pub fn new(inner: S) -> Self {
94        Self { inner }
95    }
96}
97
98impl<S: Clone> Clone for ToolCatchError<S> {
99    fn clone(&self) -> Self {
100        Self {
101            inner: self.inner.clone(),
102        }
103    }
104}
105
106impl<S: fmt::Debug> fmt::Debug for ToolCatchError<S> {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        f.debug_struct("ToolCatchError")
109            .field("inner", &self.inner)
110            .finish()
111    }
112}
113
114impl<S> Service<ToolRequest> for ToolCatchError<S>
115where
116    S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
117    S::Error: fmt::Display + Send,
118    S::Future: Send,
119{
120    type Response = CallToolResult;
121    type Error = Infallible;
122    type Future =
123        Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Infallible>> + Send>>;
124
125    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
126        // Map any readiness error to Infallible (we catch it on call)
127        match self.inner.poll_ready(cx) {
128            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
129            Poll::Ready(Err(_)) => Poll::Ready(Ok(())),
130            Poll::Pending => Poll::Pending,
131        }
132    }
133
134    fn call(&mut self, req: ToolRequest) -> Self::Future {
135        let fut = self.inner.call(req);
136
137        Box::pin(async move {
138            match fut.await {
139                Ok(result) => Ok(result),
140                Err(err) => Ok(CallToolResult::error(err.to_string())),
141            }
142        })
143    }
144}
145
146/// A tower [`Layer`](tower::Layer) that applies a guard function before the inner service.
147///
148/// Guards run before the tool handler and can short-circuit with an error message.
149/// Use via [`ToolBuilderWithHandler::guard`] or [`Tool::with_guard`] rather than
150/// constructing directly.
151///
152/// # Example
153///
154/// ```rust
155/// use tower_mcp::{ToolBuilder, ToolRequest, CallToolResult};
156/// use schemars::JsonSchema;
157/// use serde::Deserialize;
158///
159/// #[derive(Debug, Deserialize, JsonSchema)]
160/// struct DeleteInput { id: String, confirm: bool }
161///
162/// let tool = ToolBuilder::new("delete")
163///     .description("Delete a record")
164///     .handler(|input: DeleteInput| async move {
165///         Ok(CallToolResult::text(format!("deleted {}", input.id)))
166///     })
167///     .guard(|req: &ToolRequest| {
168///         let confirm = req.args.get("confirm").and_then(|v| v.as_bool()).unwrap_or(false);
169///         if !confirm {
170///             return Err("Must set confirm=true to delete".to_string());
171///         }
172///         Ok(())
173///     })
174///     .build();
175/// ```
176#[derive(Clone)]
177pub struct GuardLayer<G> {
178    guard: G,
179}
180
181impl<G> GuardLayer<G> {
182    /// Create a new guard layer from a closure.
183    ///
184    /// The closure receives a `&ToolRequest` and returns `Ok(())` to proceed
185    /// or `Err(String)` to reject with an error message.
186    pub fn new(guard: G) -> Self {
187        Self { guard }
188    }
189}
190
191impl<G, S> tower::Layer<S> for GuardLayer<G>
192where
193    G: Clone,
194{
195    type Service = GuardService<G, S>;
196
197    fn layer(&self, inner: S) -> Self::Service {
198        GuardService {
199            guard: self.guard.clone(),
200            inner,
201        }
202    }
203}
204
205/// Service wrapper that runs a guard check before calling the inner service.
206///
207/// Created by [`GuardLayer`]. See its documentation for usage.
208#[derive(Clone)]
209pub struct GuardService<G, S> {
210    guard: G,
211    inner: S,
212}
213
214impl<G, S> Service<ToolRequest> for GuardService<G, S>
215where
216    G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
217    S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
218    S::Error: Into<Error> + Send,
219    S::Future: Send,
220{
221    type Response = CallToolResult;
222    type Error = Error;
223    type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
224
225    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
226        self.inner.poll_ready(cx).map_err(Into::into)
227    }
228
229    fn call(&mut self, req: ToolRequest) -> Self::Future {
230        match (self.guard)(&req) {
231            Ok(()) => {
232                let fut = self.inner.call(req);
233                Box::pin(async move { fut.await.map_err(Into::into) })
234            }
235            Err(msg) => Box::pin(async move { Err(Error::tool(msg)) }),
236        }
237    }
238}
239
240/// A marker type for tools that take no parameters.
241///
242/// Use this instead of `()` when defining tools with no input parameters.
243/// The unit type `()` generates `"type": "null"` in JSON Schema, which many
244/// MCP clients reject. `NoParams` generates `"type": "object"` with no
245/// required properties, which is the correct schema for parameterless tools.
246///
247/// # Example
248///
249/// ```rust
250/// use tower_mcp::{ToolBuilder, CallToolResult, NoParams};
251///
252/// let tool = ToolBuilder::new("get_status")
253///     .description("Get current status")
254///     .handler(|_input: NoParams| async move {
255///         Ok(CallToolResult::text("OK"))
256///     })
257///     .build();
258/// ```
259#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
260pub struct NoParams;
261
262impl<'de> serde::Deserialize<'de> for NoParams {
263    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
264    where
265        D: serde::Deserializer<'de>,
266    {
267        // Accept null, empty object, or any object (ignoring all fields)
268        struct NoParamsVisitor;
269
270        impl<'de> serde::de::Visitor<'de> for NoParamsVisitor {
271            type Value = NoParams;
272
273            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
274                formatter.write_str("null or an object")
275            }
276
277            fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
278            where
279                E: serde::de::Error,
280            {
281                Ok(NoParams)
282            }
283
284            fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
285            where
286                E: serde::de::Error,
287            {
288                Ok(NoParams)
289            }
290
291            fn visit_some<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
292            where
293                D: serde::Deserializer<'de>,
294            {
295                serde::Deserialize::deserialize(deserializer)
296            }
297
298            fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
299            where
300                A: serde::de::MapAccess<'de>,
301            {
302                // Drain the map, ignoring all entries
303                while map
304                    .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
305                    .is_some()
306                {}
307                Ok(NoParams)
308            }
309        }
310
311        deserializer.deserialize_any(NoParamsVisitor)
312    }
313}
314
315impl JsonSchema for NoParams {
316    fn schema_name() -> Cow<'static, str> {
317        Cow::Borrowed("NoParams")
318    }
319
320    fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
321        serde_json::json!({
322            "type": "object"
323        })
324        .try_into()
325        .expect("valid schema")
326    }
327}
328
329/// Validate a tool name according to MCP spec.
330///
331/// Tool names must be:
332/// - 1-128 characters long
333/// - Contain only alphanumeric characters, underscores, hyphens, and dots
334///
335/// Returns `Ok(())` if valid, `Err` with description if invalid.
336pub fn validate_tool_name(name: &str) -> Result<()> {
337    if name.is_empty() {
338        return Err(Error::tool("Tool name cannot be empty"));
339    }
340    if name.len() > 128 {
341        return Err(Error::tool(format!(
342            "Tool name '{}' exceeds maximum length of 128 characters (got {})",
343            name,
344            name.len()
345        )));
346    }
347    if let Some(invalid_char) = name
348        .chars()
349        .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
350    {
351        return Err(Error::tool(format!(
352            "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
353            name, invalid_char
354        )));
355    }
356    Ok(())
357}
358
359/// A boxed future for tool handlers
360pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
361
362/// Tool handler trait - the core abstraction for tool execution
363pub trait ToolHandler: Send + Sync {
364    /// Execute the tool with the given arguments
365    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;
366
367    /// Execute the tool with request context for progress/cancellation support
368    ///
369    /// The default implementation ignores the context and calls `call`.
370    /// Override this to receive progress/cancellation context.
371    fn call_with_context(
372        &self,
373        _ctx: RequestContext,
374        args: Value,
375    ) -> BoxFuture<'_, Result<CallToolResult>> {
376        self.call(args)
377    }
378
379    /// Returns true if this handler uses context (for optimization)
380    fn uses_context(&self) -> bool {
381        false
382    }
383
384    /// Get the tool's input schema
385    fn input_schema(&self) -> Value;
386}
387
388/// Adapts a `ToolHandler` to a Tower `Service<ToolRequest>`.
389///
390/// This is an internal adapter that bridges the handler abstraction to the
391/// service abstraction, enabling middleware composition.
392pub(crate) struct ToolHandlerService<H> {
393    handler: Arc<H>,
394}
395
396impl<H> ToolHandlerService<H> {
397    pub(crate) fn new(handler: H) -> Self {
398        Self {
399            handler: Arc::new(handler),
400        }
401    }
402}
403
404impl<H> Clone for ToolHandlerService<H> {
405    fn clone(&self) -> Self {
406        Self {
407            handler: self.handler.clone(),
408        }
409    }
410}
411
412impl<H> Service<ToolRequest> for ToolHandlerService<H>
413where
414    H: ToolHandler + 'static,
415{
416    type Response = CallToolResult;
417    type Error = Error;
418    type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
419
420    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
421        Poll::Ready(Ok(()))
422    }
423
424    fn call(&mut self, req: ToolRequest) -> Self::Future {
425        let handler = self.handler.clone();
426        Box::pin(async move { handler.call_with_context(req.ctx, req.args).await })
427    }
428}
429
430/// A complete tool definition with service-based execution.
431///
432/// Tools are implemented as Tower services internally, enabling middleware
433/// composition via the builder's `.layer()` method. The service is wrapped
434/// in [`ToolCatchError`] to convert any errors (from handlers or middleware)
435/// into `CallToolResult::error()` responses.
436pub struct Tool {
437    /// Tool name (must be 1-128 chars, alphanumeric/underscore/hyphen/dot only)
438    pub name: String,
439    /// Human-readable title for the tool
440    pub title: Option<String>,
441    /// Description of what the tool does
442    pub description: Option<String>,
443    /// JSON Schema for the tool's output (optional)
444    pub output_schema: Option<Value>,
445    /// Icons for the tool
446    pub icons: Option<Vec<ToolIcon>>,
447    /// Tool annotations (hints about behavior)
448    pub annotations: Option<ToolAnnotations>,
449    /// The boxed service that executes the tool
450    pub(crate) service: BoxToolService,
451    /// JSON Schema for the tool's input
452    pub(crate) input_schema: Value,
453}
454
455impl std::fmt::Debug for Tool {
456    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457        f.debug_struct("Tool")
458            .field("name", &self.name)
459            .field("title", &self.title)
460            .field("description", &self.description)
461            .field("output_schema", &self.output_schema)
462            .field("icons", &self.icons)
463            .field("annotations", &self.annotations)
464            .finish_non_exhaustive()
465    }
466}
467
468// SAFETY: BoxCloneService is Send + Sync (tower provides unsafe impl Sync),
469// and all other fields in Tool are Send + Sync.
470unsafe impl Send for Tool {}
471unsafe impl Sync for Tool {}
472
473impl Clone for Tool {
474    fn clone(&self) -> Self {
475        Self {
476            name: self.name.clone(),
477            title: self.title.clone(),
478            description: self.description.clone(),
479            output_schema: self.output_schema.clone(),
480            icons: self.icons.clone(),
481            annotations: self.annotations.clone(),
482            service: self.service.clone(),
483            input_schema: self.input_schema.clone(),
484        }
485    }
486}
487
488impl Tool {
489    /// Create a new tool builder
490    pub fn builder(name: impl Into<String>) -> ToolBuilder {
491        ToolBuilder::new(name)
492    }
493
494    /// Get the tool definition for tools/list
495    pub fn definition(&self) -> ToolDefinition {
496        ToolDefinition {
497            name: self.name.clone(),
498            title: self.title.clone(),
499            description: self.description.clone(),
500            input_schema: self.input_schema.clone(),
501            output_schema: self.output_schema.clone(),
502            icons: self.icons.clone(),
503            annotations: self.annotations.clone(),
504        }
505    }
506
507    /// Call the tool without context
508    ///
509    /// Creates a dummy request context. For full context support, use
510    /// [`call_with_context`](Self::call_with_context).
511    pub fn call(&self, args: Value) -> BoxFuture<'static, CallToolResult> {
512        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
513        self.call_with_context(ctx, args)
514    }
515
516    /// Call the tool with request context
517    ///
518    /// The context provides progress reporting, cancellation support, and
519    /// access to client requests (for sampling, etc.).
520    ///
521    /// # Note
522    ///
523    /// This method returns `CallToolResult` directly (not `Result<CallToolResult>`).
524    /// Any errors from the handler or middleware are converted to
525    /// `CallToolResult::error()` with `is_error: true`.
526    pub fn call_with_context(
527        &self,
528        ctx: RequestContext,
529        args: Value,
530    ) -> BoxFuture<'static, CallToolResult> {
531        use tower::ServiceExt;
532        let service = self.service.clone();
533        Box::pin(async move {
534            // ServiceExt::oneshot properly handles poll_ready before call
535            // Service is Infallible, so unwrap is safe
536            service.oneshot(ToolRequest::new(ctx, args)).await.unwrap()
537        })
538    }
539
540    /// Apply a guard to this built tool.
541    ///
542    /// The guard runs before the handler and can short-circuit with an error.
543    /// This is useful for applying the same guard to multiple tools (per-group
544    /// pattern):
545    ///
546    /// ```rust
547    /// use tower_mcp::{ToolBuilder, CallToolResult};
548    /// use tower_mcp::tool::ToolRequest;
549    /// use schemars::JsonSchema;
550    /// use serde::Deserialize;
551    ///
552    /// #[derive(Debug, Deserialize, JsonSchema)]
553    /// struct Input { value: String }
554    ///
555    /// fn build_tool(name: &str) -> tower_mcp::tool::Tool {
556    ///     ToolBuilder::new(name)
557    ///         .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
558    ///         .build()
559    /// }
560    ///
561    /// let guard = |_req: &ToolRequest| -> Result<(), String> { Ok(()) };
562    ///
563    /// let tools: Vec<_> = vec![build_tool("a"), build_tool("b")]
564    ///     .into_iter()
565    ///     .map(|t| t.with_guard(guard.clone()))
566    ///     .collect();
567    /// ```
568    pub fn with_guard<G>(self, guard: G) -> Self
569    where
570        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
571    {
572        let guarded = GuardService {
573            guard,
574            inner: self.service,
575        };
576        let caught = ToolCatchError::new(guarded);
577        Tool {
578            service: BoxCloneService::new(caught),
579            ..self
580        }
581    }
582
583    /// Create a new tool with a prefixed name.
584    ///
585    /// This creates a copy of the tool with its name prefixed by the given
586    /// string and a dot separator. For example, if the tool is named "query"
587    /// and the prefix is "db", the new tool will be named "db.query".
588    ///
589    /// This is used internally by `McpRouter::nest()` to namespace tools.
590    ///
591    /// # Example
592    ///
593    /// ```rust
594    /// use tower_mcp::{ToolBuilder, CallToolResult};
595    /// use schemars::JsonSchema;
596    /// use serde::Deserialize;
597    ///
598    /// #[derive(Debug, Deserialize, JsonSchema)]
599    /// struct Input { value: String }
600    ///
601    /// let tool = ToolBuilder::new("query")
602    ///     .description("Query the database")
603    ///     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
604    ///     .build();
605    ///
606    /// let prefixed = tool.with_name_prefix("db");
607    /// assert_eq!(prefixed.name, "db.query");
608    /// ```
609    pub fn with_name_prefix(&self, prefix: &str) -> Self {
610        Self {
611            name: format!("{}.{}", prefix, self.name),
612            title: self.title.clone(),
613            description: self.description.clone(),
614            output_schema: self.output_schema.clone(),
615            icons: self.icons.clone(),
616            annotations: self.annotations.clone(),
617            service: self.service.clone(),
618            input_schema: self.input_schema.clone(),
619        }
620    }
621
622    /// Create a tool from a handler (internal helper)
623    fn from_handler<H: ToolHandler + 'static>(
624        name: String,
625        title: Option<String>,
626        description: Option<String>,
627        output_schema: Option<Value>,
628        icons: Option<Vec<ToolIcon>>,
629        annotations: Option<ToolAnnotations>,
630        handler: H,
631    ) -> Self {
632        let input_schema = handler.input_schema();
633        let handler_service = ToolHandlerService::new(handler);
634        let catch_error = ToolCatchError::new(handler_service);
635        let service = BoxCloneService::new(catch_error);
636
637        Self {
638            name,
639            title,
640            description,
641            output_schema,
642            icons,
643            annotations,
644            service,
645            input_schema,
646        }
647    }
648}
649
650// =============================================================================
651// Builder API
652// =============================================================================
653
654/// Builder for creating tools with a fluent API
655///
656/// # Example
657///
658/// ```rust
659/// use tower_mcp::{ToolBuilder, CallToolResult};
660/// use schemars::JsonSchema;
661/// use serde::Deserialize;
662///
663/// #[derive(Debug, Deserialize, JsonSchema)]
664/// struct GreetInput {
665///     name: String,
666/// }
667///
668/// let tool = ToolBuilder::new("greet")
669///     .description("Greet someone by name")
670///     .handler(|input: GreetInput| async move {
671///         Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
672///     })
673///     .build();
674///
675/// assert_eq!(tool.name, "greet");
676/// ```
677pub struct ToolBuilder {
678    name: String,
679    title: Option<String>,
680    description: Option<String>,
681    output_schema: Option<Value>,
682    icons: Option<Vec<ToolIcon>>,
683    annotations: Option<ToolAnnotations>,
684}
685
686impl ToolBuilder {
687    /// Create a new tool builder with the given name.
688    ///
689    /// Tool names must be 1-128 characters and contain only alphanumeric
690    /// characters, underscores, hyphens, and dots.
691    ///
692    /// Use [`try_new`](Self::try_new) if the name comes from runtime input.
693    ///
694    /// # Panics
695    ///
696    /// Panics if `name` is empty, exceeds 128 characters, or contains
697    /// characters other than ASCII alphanumerics, `_`, `-`, and `.`.
698    pub fn new(name: impl Into<String>) -> Self {
699        let name = name.into();
700        if let Err(e) = validate_tool_name(&name) {
701            panic!("{e}");
702        }
703        Self {
704            name,
705            title: None,
706            description: None,
707            output_schema: None,
708            icons: None,
709            annotations: None,
710        }
711    }
712
713    /// Create a new tool builder, returning an error if the name is invalid.
714    ///
715    /// This is the fallible alternative to [`new`](Self::new) for cases where
716    /// the tool name comes from runtime input (e.g., user configuration or
717    /// database).
718    pub fn try_new(name: impl Into<String>) -> Result<Self> {
719        let name = name.into();
720        validate_tool_name(&name)?;
721        Ok(Self {
722            name,
723            title: None,
724            description: None,
725            output_schema: None,
726            icons: None,
727            annotations: None,
728        })
729    }
730
731    /// Set a human-readable title for the tool
732    pub fn title(mut self, title: impl Into<String>) -> Self {
733        self.title = Some(title.into());
734        self
735    }
736
737    /// Set the output schema (JSON Schema for structured output)
738    pub fn output_schema(mut self, schema: Value) -> Self {
739        self.output_schema = Some(schema);
740        self
741    }
742
743    /// Add an icon for the tool
744    pub fn icon(mut self, src: impl Into<String>) -> Self {
745        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
746            src: src.into(),
747            mime_type: None,
748            sizes: None,
749        });
750        self
751    }
752
753    /// Add an icon with metadata
754    pub fn icon_with_meta(
755        mut self,
756        src: impl Into<String>,
757        mime_type: Option<String>,
758        sizes: Option<Vec<String>>,
759    ) -> Self {
760        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
761            src: src.into(),
762            mime_type,
763            sizes,
764        });
765        self
766    }
767
768    /// Set the tool description
769    pub fn description(mut self, description: impl Into<String>) -> Self {
770        self.description = Some(description.into());
771        self
772    }
773
774    /// Mark the tool as read-only (does not modify state)
775    pub fn read_only(mut self) -> Self {
776        self.annotations
777            .get_or_insert_with(ToolAnnotations::default)
778            .read_only_hint = true;
779        self
780    }
781
782    /// Mark the tool as non-destructive
783    pub fn non_destructive(mut self) -> Self {
784        self.annotations
785            .get_or_insert_with(ToolAnnotations::default)
786            .destructive_hint = false;
787        self
788    }
789
790    /// Mark the tool as idempotent (same args = same effect)
791    pub fn idempotent(mut self) -> Self {
792        self.annotations
793            .get_or_insert_with(ToolAnnotations::default)
794            .idempotent_hint = true;
795        self
796    }
797
798    /// Set tool annotations directly
799    pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
800        self.annotations = Some(annotations);
801        self
802    }
803
804    /// Create a tool that takes no parameters.
805    ///
806    /// This is a convenience method for tools that don't require any input.
807    /// It generates the correct `{"type": "object"}` schema that MCP clients expect.
808    ///
809    /// # Example
810    ///
811    /// ```rust
812    /// use tower_mcp::{ToolBuilder, CallToolResult};
813    ///
814    /// let tool = ToolBuilder::new("get_status")
815    ///     .description("Get current status")
816    ///     .no_params_handler(|| async {
817    ///         Ok(CallToolResult::text("OK"))
818    ///     })
819    ///     .build();
820    /// ```
821    pub fn no_params_handler<F, Fut>(self, handler: F) -> ToolBuilderWithNoParamsHandler<F>
822    where
823        F: Fn() -> Fut + Send + Sync + 'static,
824        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
825    {
826        ToolBuilderWithNoParamsHandler {
827            name: self.name,
828            title: self.title,
829            description: self.description,
830            output_schema: self.output_schema,
831            icons: self.icons,
832            annotations: self.annotations,
833            handler,
834        }
835    }
836
837    /// Specify input type and handler.
838    ///
839    /// The input type must implement `JsonSchema` and `DeserializeOwned`.
840    /// The handler receives the deserialized input and returns a `CallToolResult`.
841    ///
842    /// # State Sharing
843    ///
844    /// To share state across tool calls (e.g., database connections, API clients),
845    /// wrap your state in an `Arc` and clone it into the async block:
846    ///
847    /// ```rust
848    /// use std::sync::Arc;
849    /// use tower_mcp::{ToolBuilder, CallToolResult};
850    /// use schemars::JsonSchema;
851    /// use serde::Deserialize;
852    ///
853    /// struct AppState {
854    ///     api_key: String,
855    /// }
856    ///
857    /// #[derive(Debug, Deserialize, JsonSchema)]
858    /// struct MyInput {
859    ///     query: String,
860    /// }
861    ///
862    /// let state = Arc::new(AppState { api_key: "secret".to_string() });
863    ///
864    /// let tool = ToolBuilder::new("my_tool")
865    ///     .description("A tool that uses shared state")
866    ///     .handler(move |input: MyInput| {
867    ///         let state = state.clone(); // Clone Arc for the async block
868    ///         async move {
869    ///             // Use state.api_key here...
870    ///             Ok(CallToolResult::text(format!("Query: {}", input.query)))
871    ///         }
872    ///     })
873    ///     .build();
874    /// ```
875    ///
876    /// The `move` keyword on the closure captures the `Arc<AppState>`, and
877    /// cloning it inside the closure body allows each async invocation to
878    /// have its own reference to the shared state.
879    pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
880    where
881        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
882        F: Fn(I) -> Fut + Send + Sync + 'static,
883        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
884    {
885        ToolBuilderWithHandler {
886            name: self.name,
887            title: self.title,
888            description: self.description,
889            output_schema: self.output_schema,
890            icons: self.icons,
891            annotations: self.annotations,
892            handler,
893            _phantom: std::marker::PhantomData,
894        }
895    }
896
897    /// Create a tool using the extractor pattern.
898    ///
899    /// This method provides an axum-inspired way to define handlers where state,
900    /// context, and input are extracted declaratively from function parameters.
901    /// This reduces the combinatorial explosion of handler variants like
902    /// `handler_with_state`, `handler_with_context`, etc.
903    ///
904    /// # Schema Auto-Detection
905    ///
906    /// When a [`Json<T>`](crate::extract::Json) extractor is used, the proper JSON
907    /// schema is automatically generated from `T`'s `JsonSchema` implementation.
908    /// This means `extractor_handler` produces the same schema as
909    /// `extractor_handler_typed` for the common case, without requiring a turbofish.
910    ///
911    /// # Extractors
912    ///
913    /// Built-in extractors available in [`crate::extract`]:
914    /// - [`Json<T>`](crate::extract::Json) - Deserialize JSON arguments to type `T`
915    /// - [`State<T>`](crate::extract::State) - Extract cloned state
916    /// - [`Extension<T>`](crate::extract::Extension) - Extract router-level state
917    /// - [`Context`](crate::extract::Context) - Extract request context
918    /// - [`RawArgs`](crate::extract::RawArgs) - Extract raw JSON arguments
919    ///
920    /// # Per-Tool Middleware
921    ///
922    /// The returned builder supports `.layer()` to apply Tower middleware:
923    ///
924    /// ```rust
925    /// use std::sync::Arc;
926    /// use std::time::Duration;
927    /// use tower::timeout::TimeoutLayer;
928    /// use tower_mcp::{ToolBuilder, CallToolResult};
929    /// use tower_mcp::extract::{Json, State};
930    /// use schemars::JsonSchema;
931    /// use serde::Deserialize;
932    ///
933    /// #[derive(Clone)]
934    /// struct Database { url: String }
935    ///
936    /// #[derive(Debug, Deserialize, JsonSchema)]
937    /// struct QueryInput { query: String }
938    ///
939    /// let db = Arc::new(Database { url: "postgres://...".to_string() });
940    ///
941    /// let tool = ToolBuilder::new("search")
942    ///     .description("Search the database")
943    ///     .extractor_handler(db, |
944    ///         State(db): State<Arc<Database>>,
945    ///         Json(input): Json<QueryInput>,
946    ///     | async move {
947    ///         Ok(CallToolResult::text(format!("Searched {} with: {}", db.url, input.query)))
948    ///     })
949    ///     .layer(TimeoutLayer::new(Duration::from_secs(30)))
950    ///     .build();
951    /// ```
952    ///
953    /// # Example
954    ///
955    /// ```rust
956    /// use std::sync::Arc;
957    /// use tower_mcp::{ToolBuilder, CallToolResult};
958    /// use tower_mcp::extract::{Json, State, Context};
959    /// use schemars::JsonSchema;
960    /// use serde::Deserialize;
961    ///
962    /// #[derive(Clone)]
963    /// struct Database { url: String }
964    ///
965    /// #[derive(Debug, Deserialize, JsonSchema)]
966    /// struct QueryInput { query: String }
967    ///
968    /// let db = Arc::new(Database { url: "postgres://...".to_string() });
969    ///
970    /// let tool = ToolBuilder::new("search")
971    ///     .description("Search the database")
972    ///     .extractor_handler(db, |
973    ///         State(db): State<Arc<Database>>,
974    ///         ctx: Context,
975    ///         Json(input): Json<QueryInput>,
976    ///     | async move {
977    ///         if ctx.is_cancelled() {
978    ///             return Ok(CallToolResult::error("Cancelled"));
979    ///         }
980    ///         ctx.report_progress(0.5, Some(1.0), Some("Searching...")).await;
981    ///         Ok(CallToolResult::text(format!("Searched {} with: {}", db.url, input.query)))
982    ///     })
983    ///     .build();
984    /// ```
985    ///
986    /// # Type Inference
987    ///
988    /// The compiler infers extractor types from the function signature. Make sure
989    /// to annotate the extractor types explicitly in the closure parameters.
990    pub fn extractor_handler<S, F, T>(
991        self,
992        state: S,
993        handler: F,
994    ) -> crate::extract::ToolBuilderWithExtractor<S, F, T>
995    where
996        S: Clone + Send + Sync + 'static,
997        F: crate::extract::ExtractorHandler<S, T> + Clone,
998        T: Send + Sync + 'static,
999    {
1000        crate::extract::ToolBuilderWithExtractor {
1001            name: self.name,
1002            title: self.title,
1003            description: self.description,
1004            output_schema: self.output_schema,
1005            icons: self.icons,
1006            annotations: self.annotations,
1007            state,
1008            handler,
1009            input_schema: F::input_schema(),
1010            _phantom: std::marker::PhantomData,
1011        }
1012    }
1013
1014    /// Create a tool using the extractor pattern with typed JSON input.
1015    ///
1016    /// This is similar to [`extractor_handler`](Self::extractor_handler) but requires
1017    /// an explicit type parameter for the JSON input type via turbofish syntax.
1018    ///
1019    /// Since `extractor_handler` now auto-detects the JSON schema from `Json<T>`
1020    /// extractors, this method is typically unnecessary. It remains available for
1021    /// cases where you need explicit control over the schema type parameter.
1022    ///
1023    /// # Example
1024    ///
1025    /// ```rust
1026    /// use std::sync::Arc;
1027    /// use tower_mcp::{ToolBuilder, CallToolResult};
1028    /// use tower_mcp::extract::{Json, State};
1029    /// use schemars::JsonSchema;
1030    /// use serde::Deserialize;
1031    ///
1032    /// #[derive(Clone)]
1033    /// struct AppState { prefix: String }
1034    ///
1035    /// #[derive(Debug, Deserialize, JsonSchema)]
1036    /// struct GreetInput { name: String }
1037    ///
1038    /// let state = Arc::new(AppState { prefix: "Hello".to_string() });
1039    ///
1040    /// let tool = ToolBuilder::new("greet")
1041    ///     .description("Greet someone")
1042    ///     .extractor_handler_typed::<_, _, _, GreetInput>(state, |
1043    ///         State(app): State<Arc<AppState>>,
1044    ///         Json(input): Json<GreetInput>,
1045    ///     | async move {
1046    ///         Ok(CallToolResult::text(format!("{}, {}!", app.prefix, input.name)))
1047    ///     })
1048    ///     .build();
1049    /// ```
1050    pub fn extractor_handler_typed<S, F, T, I>(
1051        self,
1052        state: S,
1053        handler: F,
1054    ) -> crate::extract::ToolBuilderWithTypedExtractor<S, F, T, I>
1055    where
1056        S: Clone + Send + Sync + 'static,
1057        F: crate::extract::TypedExtractorHandler<S, T, I> + Clone,
1058        T: Send + Sync + 'static,
1059        I: schemars::JsonSchema + Send + Sync + 'static,
1060    {
1061        crate::extract::ToolBuilderWithTypedExtractor {
1062            name: self.name,
1063            title: self.title,
1064            description: self.description,
1065            output_schema: self.output_schema,
1066            icons: self.icons,
1067            annotations: self.annotations,
1068            state,
1069            handler,
1070            _phantom: std::marker::PhantomData,
1071        }
1072    }
1073}
1074
/// Handler for tools with no parameters.
///
/// Used internally by [`ToolBuilder::no_params_handler`]. The wrapped closure
/// takes no arguments, and any JSON arguments sent with a call are discarded.
struct NoParamsTypedHandler<F> {
    /// Zero-argument closure producing the tool's result future.
    handler: F,
}
1081
1082impl<F, Fut> ToolHandler for NoParamsTypedHandler<F>
1083where
1084    F: Fn() -> Fut + Send + Sync + 'static,
1085    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1086{
1087    fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1088        Box::pin(async move { (self.handler)().await })
1089    }
1090
1091    fn input_schema(&self) -> Value {
1092        serde_json::json!({ "type": "object" })
1093    }
1094}
1095
/// Builder state after handler is specified
///
/// Produced by [`ToolBuilder::handler`]. Call `build()` to finish the tool, or
/// call `layer()` / `guard()` first to attach per-tool middleware.
pub struct ToolBuilderWithHandler<I, F> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    /// Typed closure; call arguments are deserialized into `I` before it runs.
    handler: F,
    /// Records the input type `I` without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}
1107
/// Builder state for tools with no parameters.
///
/// Created by [`ToolBuilder::no_params_handler`]. Call `build()` to finish the
/// tool, or call `layer()` / `guard()` first to attach per-tool middleware.
pub struct ToolBuilderWithNoParamsHandler<F> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    /// Zero-argument closure; wrapped in `NoParamsTypedHandler` at build time.
    handler: F,
}
1120
1121impl<F, Fut> ToolBuilderWithNoParamsHandler<F>
1122where
1123    F: Fn() -> Fut + Send + Sync + 'static,
1124    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1125{
1126    /// Build the tool.
1127    pub fn build(self) -> Tool {
1128        Tool::from_handler(
1129            self.name,
1130            self.title,
1131            self.description,
1132            self.output_schema,
1133            self.icons,
1134            self.annotations,
1135            NoParamsTypedHandler {
1136                handler: self.handler,
1137            },
1138        )
1139    }
1140
1141    /// Apply a Tower layer (middleware) to this tool.
1142    ///
1143    /// See [`ToolBuilderWithHandler::layer`] for details.
1144    pub fn layer<L>(self, layer: L) -> ToolBuilderWithNoParamsHandlerLayer<F, L> {
1145        ToolBuilderWithNoParamsHandlerLayer {
1146            name: self.name,
1147            title: self.title,
1148            description: self.description,
1149            output_schema: self.output_schema,
1150            icons: self.icons,
1151            annotations: self.annotations,
1152            handler: self.handler,
1153            layer,
1154        }
1155    }
1156
1157    /// Apply a guard to this tool.
1158    ///
1159    /// See [`ToolBuilderWithHandler::guard`] for details.
1160    pub fn guard<G>(self, guard: G) -> ToolBuilderWithNoParamsHandlerLayer<F, GuardLayer<G>>
1161    where
1162        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1163    {
1164        self.layer(GuardLayer::new(guard))
1165    }
1166}
1167
/// Builder state after a layer has been applied to a no-params handler.
///
/// Returned by `ToolBuilderWithNoParamsHandler::layer` / `guard`. Further
/// `layer()` calls stack more middleware before `build()`.
pub struct ToolBuilderWithNoParamsHandlerLayer<F, L> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    /// Zero-argument closure; wrapped in `NoParamsTypedHandler` at build time.
    handler: F,
    /// Accumulated middleware. Each additional `layer()` call nests it in a
    /// `tower::layer::util::Stack`.
    layer: L,
}
1179
1180#[allow(private_bounds)]
1181impl<F, Fut, L> ToolBuilderWithNoParamsHandlerLayer<F, L>
1182where
1183    F: Fn() -> Fut + Send + Sync + 'static,
1184    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1185    L: tower::Layer<ToolHandlerService<NoParamsTypedHandler<F>>> + Clone + Send + Sync + 'static,
1186    L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1187    <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
1188    <L::Service as Service<ToolRequest>>::Future: Send,
1189{
1190    /// Build the tool with the applied layer(s).
1191    pub fn build(self) -> Tool {
1192        let input_schema = serde_json::json!({ "type": "object" });
1193
1194        let handler_service = ToolHandlerService::new(NoParamsTypedHandler {
1195            handler: self.handler,
1196        });
1197        let layered = self.layer.layer(handler_service);
1198        let catch_error = ToolCatchError::new(layered);
1199        let service = BoxCloneService::new(catch_error);
1200
1201        Tool {
1202            name: self.name,
1203            title: self.title,
1204            description: self.description,
1205            output_schema: self.output_schema,
1206            icons: self.icons,
1207            annotations: self.annotations,
1208            service,
1209            input_schema,
1210        }
1211    }
1212
1213    /// Apply an additional Tower layer (middleware).
1214    pub fn layer<L2>(
1215        self,
1216        layer: L2,
1217    ) -> ToolBuilderWithNoParamsHandlerLayer<F, tower::layer::util::Stack<L2, L>> {
1218        ToolBuilderWithNoParamsHandlerLayer {
1219            name: self.name,
1220            title: self.title,
1221            description: self.description,
1222            output_schema: self.output_schema,
1223            icons: self.icons,
1224            annotations: self.annotations,
1225            handler: self.handler,
1226            layer: tower::layer::util::Stack::new(layer, self.layer),
1227        }
1228    }
1229
1230    /// Apply a guard to this tool.
1231    ///
1232    /// See [`ToolBuilderWithHandler::guard`] for details.
1233    pub fn guard<G>(
1234        self,
1235        guard: G,
1236    ) -> ToolBuilderWithNoParamsHandlerLayer<F, tower::layer::util::Stack<GuardLayer<G>, L>>
1237    where
1238        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1239    {
1240        self.layer(GuardLayer::new(guard))
1241    }
1242}
1243
1244impl<I, F, Fut> ToolBuilderWithHandler<I, F>
1245where
1246    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1247    F: Fn(I) -> Fut + Send + Sync + 'static,
1248    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1249{
1250    /// Build the tool.
1251    pub fn build(self) -> Tool {
1252        Tool::from_handler(
1253            self.name,
1254            self.title,
1255            self.description,
1256            self.output_schema,
1257            self.icons,
1258            self.annotations,
1259            TypedHandler {
1260                handler: self.handler,
1261                _phantom: std::marker::PhantomData,
1262            },
1263        )
1264    }
1265
1266    /// Apply a Tower layer (middleware) to this tool.
1267    ///
1268    /// The layer wraps the tool's handler service, enabling functionality like
1269    /// timeouts, rate limiting, and metrics collection at the per-tool level.
1270    ///
1271    /// # Example
1272    ///
1273    /// ```rust
1274    /// use std::time::Duration;
1275    /// use tower::timeout::TimeoutLayer;
1276    /// use tower_mcp::{ToolBuilder, CallToolResult};
1277    /// use schemars::JsonSchema;
1278    /// use serde::Deserialize;
1279    ///
1280    /// #[derive(Debug, Deserialize, JsonSchema)]
1281    /// struct Input { query: String }
1282    ///
1283    /// let tool = ToolBuilder::new("search")
1284    ///     .description("Search with timeout")
1285    ///     .handler(|input: Input| async move {
1286    ///         Ok(CallToolResult::text("result"))
1287    ///     })
1288    ///     .layer(TimeoutLayer::new(Duration::from_secs(30)))
1289    ///     .build();
1290    /// ```
1291    pub fn layer<L>(self, layer: L) -> ToolBuilderWithLayer<I, F, L> {
1292        ToolBuilderWithLayer {
1293            name: self.name,
1294            title: self.title,
1295            description: self.description,
1296            output_schema: self.output_schema,
1297            icons: self.icons,
1298            annotations: self.annotations,
1299            handler: self.handler,
1300            layer,
1301            _phantom: std::marker::PhantomData,
1302        }
1303    }
1304
1305    /// Apply a guard to this tool.
1306    ///
1307    /// The guard runs before the handler and can short-circuit with an error
1308    /// message. This is syntactic sugar for `.layer(GuardLayer::new(f))`.
1309    ///
1310    /// See [`GuardLayer`] for a full example.
1311    pub fn guard<G>(self, guard: G) -> ToolBuilderWithLayer<I, F, GuardLayer<G>>
1312    where
1313        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1314    {
1315        self.layer(GuardLayer::new(guard))
1316    }
1317}
1318
/// Builder state after a layer has been applied to the handler.
///
/// This builder allows chaining additional layers and building the final tool.
pub struct ToolBuilderWithLayer<I, F, L> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    /// Typed closure; wrapped in a `TypedHandler` at build time.
    handler: F,
    /// Accumulated middleware. Each additional `layer()` call nests it in a
    /// `tower::layer::util::Stack`.
    layer: L,
    /// Records the handler's input type `I` without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}
1333
1334// Allow private_bounds because these internal types (ToolHandlerService, TypedHandler, etc.)
1335// are implementation details that users don't interact with directly.
1336#[allow(private_bounds)]
1337impl<I, F, Fut, L> ToolBuilderWithLayer<I, F, L>
1338where
1339    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1340    F: Fn(I) -> Fut + Send + Sync + 'static,
1341    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1342    L: tower::Layer<ToolHandlerService<TypedHandler<I, F>>> + Clone + Send + Sync + 'static,
1343    L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1344    <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
1345    <L::Service as Service<ToolRequest>>::Future: Send,
1346{
1347    /// Build the tool with the applied layer(s).
1348    pub fn build(self) -> Tool {
1349        let input_schema = schemars::schema_for!(I);
1350        let input_schema = serde_json::to_value(input_schema)
1351            .unwrap_or_else(|_| serde_json::json!({ "type": "object" }));
1352
1353        let handler_service = ToolHandlerService::new(TypedHandler {
1354            handler: self.handler,
1355            _phantom: std::marker::PhantomData,
1356        });
1357        let layered = self.layer.layer(handler_service);
1358        let catch_error = ToolCatchError::new(layered);
1359        let service = BoxCloneService::new(catch_error);
1360
1361        Tool {
1362            name: self.name,
1363            title: self.title,
1364            description: self.description,
1365            output_schema: self.output_schema,
1366            icons: self.icons,
1367            annotations: self.annotations,
1368            service,
1369            input_schema,
1370        }
1371    }
1372
1373    /// Apply an additional Tower layer (middleware).
1374    ///
1375    /// Layers are applied in order, with earlier layers wrapping later ones.
1376    /// This means the first layer added is the outermost middleware.
1377    pub fn layer<L2>(
1378        self,
1379        layer: L2,
1380    ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<L2, L>> {
1381        ToolBuilderWithLayer {
1382            name: self.name,
1383            title: self.title,
1384            description: self.description,
1385            output_schema: self.output_schema,
1386            icons: self.icons,
1387            annotations: self.annotations,
1388            handler: self.handler,
1389            layer: tower::layer::util::Stack::new(layer, self.layer),
1390            _phantom: std::marker::PhantomData,
1391        }
1392    }
1393
1394    /// Apply a guard to this tool.
1395    ///
1396    /// See [`ToolBuilderWithHandler::guard`] for details.
1397    pub fn guard<G>(
1398        self,
1399        guard: G,
1400    ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<GuardLayer<G>, L>>
1401    where
1402        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1403    {
1404        self.layer(GuardLayer::new(guard))
1405    }
1406}
1407
1408// =============================================================================
1409// Handler implementations
1410// =============================================================================
1411
/// Handler that deserializes input to a specific type
///
/// Wraps a closure taking `I`. The call's JSON arguments are deserialized into
/// `I` before the closure runs, and the input schema is generated from `I`'s
/// `JsonSchema` implementation.
struct TypedHandler<I, F> {
    handler: F,
    /// Ties the input type `I` to the handler without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}
1417
1418impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
1419where
1420    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1421    F: Fn(I) -> Fut + Send + Sync + 'static,
1422    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1423{
1424    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1425        Box::pin(async move {
1426            let input: I = serde_json::from_value(args).tool_context("Invalid input")?;
1427            (self.handler)(input).await
1428        })
1429    }
1430
1431    fn input_schema(&self) -> Value {
1432        let schema = schemars::schema_for!(I);
1433        serde_json::to_value(schema).unwrap_or_else(|_| {
1434            serde_json::json!({
1435                "type": "object"
1436            })
1437        })
1438    }
1439}
1440
1441// =============================================================================
1442// Trait-based tool definition
1443// =============================================================================
1444
1445/// Trait for defining tools with full control
1446///
1447/// Implement this trait when you need more control than the builder provides,
1448/// or when you want to define tools as standalone types.
1449///
1450/// # Example
1451///
1452/// ```rust
1453/// use tower_mcp::tool::McpTool;
1454/// use tower_mcp::error::Result;
1455/// use schemars::JsonSchema;
1456/// use serde::{Deserialize, Serialize};
1457///
1458/// #[derive(Debug, Deserialize, JsonSchema)]
1459/// struct AddInput {
1460///     a: i64,
1461///     b: i64,
1462/// }
1463///
1464/// struct AddTool;
1465///
1466/// impl McpTool for AddTool {
1467///     const NAME: &'static str = "add";
1468///     const DESCRIPTION: &'static str = "Add two numbers";
1469///
1470///     type Input = AddInput;
1471///     type Output = i64;
1472///
1473///     async fn call(&self, input: Self::Input) -> Result<Self::Output> {
1474///         Ok(input.a + input.b)
1475///     }
1476/// }
1477///
1478/// let tool = AddTool.into_tool();
1479/// assert_eq!(tool.name, "add");
1480/// ```
1481pub trait McpTool: Send + Sync + 'static {
1482    const NAME: &'static str;
1483    const DESCRIPTION: &'static str;
1484
1485    type Input: JsonSchema + DeserializeOwned + Send;
1486    type Output: Serialize + Send;
1487
1488    fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
1489
1490    /// Optional annotations for the tool
1491    fn annotations(&self) -> Option<ToolAnnotations> {
1492        None
1493    }
1494
1495    /// Convert to a [`Tool`] instance.
1496    ///
1497    /// # Panics
1498    ///
1499    /// Panics if [`NAME`](Self::NAME) is not a valid tool name. Since `NAME`
1500    /// is a `&'static str`, invalid names are caught immediately during
1501    /// development.
1502    fn into_tool(self) -> Tool
1503    where
1504        Self: Sized,
1505    {
1506        if let Err(e) = validate_tool_name(Self::NAME) {
1507            panic!("{e}");
1508        }
1509        let annotations = self.annotations();
1510        let tool = Arc::new(self);
1511        Tool::from_handler(
1512            Self::NAME.to_string(),
1513            None,
1514            Some(Self::DESCRIPTION.to_string()),
1515            None,
1516            None,
1517            annotations,
1518            McpToolHandler { tool },
1519        )
1520    }
1521}
1522
/// Wrapper to make McpTool implement ToolHandler
///
/// The tool is kept behind an `Arc` so that each call's future can hold its own
/// cheap clone of it.
struct McpToolHandler<T: McpTool> {
    tool: Arc<T>,
}
1527
1528impl<T: McpTool> ToolHandler for McpToolHandler<T> {
1529    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1530        let tool = self.tool.clone();
1531        Box::pin(async move {
1532            let input: T::Input = serde_json::from_value(args).tool_context("Invalid input")?;
1533            let output = tool.call(input).await?;
1534            let value = serde_json::to_value(output).tool_context("Failed to serialize output")?;
1535            Ok(CallToolResult::json(value))
1536        })
1537    }
1538
1539    fn input_schema(&self) -> Value {
1540        let schema = schemars::schema_for!(T::Input);
1541        serde_json::to_value(schema).unwrap_or_else(|_| {
1542            serde_json::json!({
1543                "type": "object"
1544            })
1545        })
1546    }
1547}
1548
1549#[cfg(test)]
1550mod tests {
1551    use super::*;
1552    use crate::extract::{Context, Json, RawArgs, State};
1553    use crate::protocol::Content;
1554    use schemars::JsonSchema;
1555    use serde::Deserialize;
1556
    /// Minimal typed input shared by several of the builder tests in this module.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct GreetInput {
        /// Name to greet.
        name: String,
    }
1561
1562    #[tokio::test]
1563    async fn test_builder_tool() {
1564        let tool = ToolBuilder::new("greet")
1565            .description("Greet someone")
1566            .handler(|input: GreetInput| async move {
1567                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1568            })
1569            .build();
1570
1571        assert_eq!(tool.name, "greet");
1572        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1573
1574        let result = tool.call(serde_json::json!({"name": "World"})).await;
1575
1576        assert!(!result.is_error);
1577    }
1578
1579    #[tokio::test]
1580    async fn test_raw_handler() {
1581        let tool = ToolBuilder::new("echo")
1582            .description("Echo input")
1583            .extractor_handler((), |RawArgs(args): RawArgs| async move {
1584                Ok(CallToolResult::json(args))
1585            })
1586            .build();
1587
1588        let result = tool.call(serde_json::json!({"foo": "bar"})).await;
1589
1590        assert!(!result.is_error);
1591    }
1592
1593    #[test]
1594    fn test_invalid_tool_name_empty() {
1595        let err = ToolBuilder::try_new("").err().expect("should fail");
1596        assert!(err.to_string().contains("cannot be empty"));
1597    }
1598
1599    #[test]
1600    fn test_invalid_tool_name_too_long() {
1601        let long_name = "a".repeat(129);
1602        let err = ToolBuilder::try_new(long_name).err().expect("should fail");
1603        assert!(err.to_string().contains("exceeds maximum"));
1604    }
1605
1606    #[test]
1607    fn test_invalid_tool_name_bad_chars() {
1608        let err = ToolBuilder::try_new("my tool!").err().expect("should fail");
1609        assert!(err.to_string().contains("invalid character"));
1610    }
1611
1612    #[test]
1613    #[should_panic(expected = "cannot be empty")]
1614    fn test_new_panics_on_empty_name() {
1615        ToolBuilder::new("");
1616    }
1617
1618    #[test]
1619    #[should_panic(expected = "exceeds maximum")]
1620    fn test_new_panics_on_too_long_name() {
1621        ToolBuilder::new("a".repeat(129));
1622    }
1623
1624    #[test]
1625    #[should_panic(expected = "invalid character")]
1626    fn test_new_panics_on_invalid_chars() {
1627        ToolBuilder::new("my tool!");
1628    }
1629
1630    #[test]
1631    fn test_valid_tool_names() {
1632        // All valid characters
1633        let names = [
1634            "my_tool",
1635            "my-tool",
1636            "my.tool",
1637            "MyTool123",
1638            "a",
1639            &"a".repeat(128),
1640        ];
1641        for name in names {
1642            assert!(
1643                ToolBuilder::try_new(name).is_ok(),
1644                "Expected '{}' to be valid",
1645                name
1646            );
1647        }
1648    }
1649
1650    #[tokio::test]
1651    async fn test_context_aware_handler() {
1652        use crate::context::notification_channel;
1653        use crate::protocol::{ProgressToken, RequestId};
1654
1655        #[derive(Debug, Deserialize, JsonSchema)]
1656        struct ProcessInput {
1657            count: i32,
1658        }
1659
1660        let tool = ToolBuilder::new("process")
1661            .description("Process with context")
1662            .extractor_handler(
1663                (),
1664                |ctx: Context, Json(input): Json<ProcessInput>| async move {
1665                    // Simulate progress reporting
1666                    for i in 0..input.count {
1667                        if ctx.is_cancelled() {
1668                            return Ok(CallToolResult::error("Cancelled"));
1669                        }
1670                        ctx.report_progress(i as f64, Some(input.count as f64), None)
1671                            .await;
1672                    }
1673                    Ok(CallToolResult::text(format!(
1674                        "Processed {} items",
1675                        input.count
1676                    )))
1677                },
1678            )
1679            .build();
1680
1681        assert_eq!(tool.name, "process");
1682
1683        // Test with a context that has progress token and notification sender
1684        let (tx, mut rx) = notification_channel(10);
1685        let ctx = RequestContext::new(RequestId::Number(1))
1686            .with_progress_token(ProgressToken::Number(42))
1687            .with_notification_sender(tx);
1688
1689        let result = tool
1690            .call_with_context(ctx, serde_json::json!({"count": 3}))
1691            .await;
1692
1693        assert!(!result.is_error);
1694
1695        // Check that progress notifications were sent
1696        let mut progress_count = 0;
1697        while rx.try_recv().is_ok() {
1698            progress_count += 1;
1699        }
1700        assert_eq!(progress_count, 3);
1701    }
1702
    #[tokio::test]
    async fn test_context_aware_handler_cancellation() {
        // Checks that a handler polling `ctx.is_cancelled()` stops early once its
        // request's cancellation token is cancelled partway through the work.
        use crate::protocol::RequestId;
        use std::sync::atomic::{AtomicI32, Ordering};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct LongRunningInput {
            iterations: i32,
        }

        // Shared counter so the test can observe how many iterations actually ran.
        let iterations_completed = Arc::new(AtomicI32::new(0));
        let iterations_ref = iterations_completed.clone();

        let tool = ToolBuilder::new("long_running")
            .description("Long running task")
            .extractor_handler(
                (),
                move |ctx: Context, Json(input): Json<LongRunningInput>| {
                    // The closure may be called more than once, so each call's
                    // future gets its own handle to the counter.
                    let completed = iterations_ref.clone();
                    async move {
                        for i in 0..input.iterations {
                            if ctx.is_cancelled() {
                                return Ok(CallToolResult::error("Cancelled"));
                            }
                            completed.fetch_add(1, Ordering::SeqCst);
                            // Simulate work
                            tokio::task::yield_now().await;
                            // Cancel after iteration 2 (the handler cancels its
                            // own token to simulate a cancellation request)
                            if i == 2 {
                                ctx.cancellation_token().cancel();
                            }
                        }
                        Ok(CallToolResult::text("Done"))
                    }
                },
            )
            .build();

        let ctx = RequestContext::new(RequestId::Number(1));

        let result = tool
            .call_with_context(ctx, serde_json::json!({"iterations": 10}))
            .await;

        // Should have been cancelled after 3 iterations (0, 1, 2)
        // The next iteration (3) checks cancellation and returns
        assert!(result.is_error);
        assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
    }
1752
1753    #[tokio::test]
1754    async fn test_tool_builder_with_enhanced_fields() {
1755        let output_schema = serde_json::json!({
1756            "type": "object",
1757            "properties": {
1758                "greeting": {"type": "string"}
1759            }
1760        });
1761
1762        let tool = ToolBuilder::new("greet")
1763            .title("Greeting Tool")
1764            .description("Greet someone")
1765            .output_schema(output_schema.clone())
1766            .icon("https://example.com/icon.png")
1767            .icon_with_meta(
1768                "https://example.com/icon-large.png",
1769                Some("image/png".to_string()),
1770                Some(vec!["96x96".to_string()]),
1771            )
1772            .handler(|input: GreetInput| async move {
1773                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1774            })
1775            .build();
1776
1777        assert_eq!(tool.name, "greet");
1778        assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
1779        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1780        assert_eq!(tool.output_schema, Some(output_schema));
1781        assert!(tool.icons.is_some());
1782        assert_eq!(tool.icons.as_ref().unwrap().len(), 2);
1783
1784        // Test definition includes new fields
1785        let def = tool.definition();
1786        assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
1787        assert!(def.output_schema.is_some());
1788        assert!(def.icons.is_some());
1789    }
1790
1791    #[tokio::test]
1792    async fn test_handler_with_state() {
1793        let shared = Arc::new("shared-state".to_string());
1794
1795        let tool = ToolBuilder::new("stateful")
1796            .description("Uses shared state")
1797            .extractor_handler(
1798                shared,
1799                |State(state): State<Arc<String>>, Json(input): Json<GreetInput>| async move {
1800                    Ok(CallToolResult::text(format!(
1801                        "{}: Hello, {}!",
1802                        state, input.name
1803                    )))
1804                },
1805            )
1806            .build();
1807
1808        let result = tool.call(serde_json::json!({"name": "World"})).await;
1809        assert!(!result.is_error);
1810    }
1811
1812    #[tokio::test]
1813    async fn test_handler_with_state_and_context() {
1814        use crate::protocol::RequestId;
1815
1816        let shared = Arc::new(42_i32);
1817
1818        let tool =
1819            ToolBuilder::new("stateful_ctx")
1820                .description("Uses state and context")
1821                .extractor_handler(
1822                    shared,
1823                    |State(state): State<Arc<i32>>,
1824                     _ctx: Context,
1825                     Json(input): Json<GreetInput>| async move {
1826                        Ok(CallToolResult::text(format!(
1827                            "{}: Hello, {}!",
1828                            state, input.name
1829                        )))
1830                    },
1831                )
1832                .build();
1833
1834        let ctx = RequestContext::new(RequestId::Number(1));
1835        let result = tool
1836            .call_with_context(ctx, serde_json::json!({"name": "World"}))
1837            .await;
1838        assert!(!result.is_error);
1839    }
1840
1841    #[tokio::test]
1842    async fn test_handler_no_params() {
1843        let tool = ToolBuilder::new("no_params")
1844            .description("Takes no parameters")
1845            .extractor_handler((), |Json(_): Json<NoParams>| async {
1846                Ok(CallToolResult::text("no params result"))
1847            })
1848            .build();
1849
1850        assert_eq!(tool.name, "no_params");
1851
1852        // Should work with empty args
1853        let result = tool.call(serde_json::json!({})).await;
1854        assert!(!result.is_error);
1855
1856        // Should also work with unexpected args (ignored)
1857        let result = tool.call(serde_json::json!({"unexpected": "value"})).await;
1858        assert!(!result.is_error);
1859
1860        // Check input schema includes type: object
1861        let schema = tool.definition().input_schema;
1862        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1863    }
1864
1865    #[tokio::test]
1866    async fn test_handler_with_state_no_params() {
1867        let shared = Arc::new("shared_value".to_string());
1868
1869        let tool = ToolBuilder::new("with_state_no_params")
1870            .description("Takes no parameters but has state")
1871            .extractor_handler(
1872                shared,
1873                |State(state): State<Arc<String>>, Json(_): Json<NoParams>| async move {
1874                    Ok(CallToolResult::text(format!("state: {}", state)))
1875                },
1876            )
1877            .build();
1878
1879        assert_eq!(tool.name, "with_state_no_params");
1880
1881        // Should work with empty args
1882        let result = tool.call(serde_json::json!({})).await;
1883        assert!(!result.is_error);
1884        assert_eq!(result.first_text().unwrap(), "state: shared_value");
1885
1886        // Check input schema includes type: object
1887        let schema = tool.definition().input_schema;
1888        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1889    }
1890
1891    #[tokio::test]
1892    async fn test_handler_no_params_with_context() {
1893        let tool = ToolBuilder::new("no_params_with_context")
1894            .description("Takes no parameters but has context")
1895            .extractor_handler((), |_ctx: Context, Json(_): Json<NoParams>| async move {
1896                Ok(CallToolResult::text("context available"))
1897            })
1898            .build();
1899
1900        assert_eq!(tool.name, "no_params_with_context");
1901
1902        let result = tool.call(serde_json::json!({})).await;
1903        assert!(!result.is_error);
1904        assert_eq!(result.first_text().unwrap(), "context available");
1905    }
1906
1907    #[tokio::test]
1908    async fn test_handler_with_state_and_context_no_params() {
1909        let shared = Arc::new("shared".to_string());
1910
1911        let tool = ToolBuilder::new("state_context_no_params")
1912            .description("Has state and context, no params")
1913            .extractor_handler(
1914                shared,
1915                |State(state): State<Arc<String>>,
1916                 _ctx: Context,
1917                 Json(_): Json<NoParams>| async move {
1918                    Ok(CallToolResult::text(format!("state: {}", state)))
1919                },
1920            )
1921            .build();
1922
1923        assert_eq!(tool.name, "state_context_no_params");
1924
1925        let result = tool.call(serde_json::json!({})).await;
1926        assert!(!result.is_error);
1927        assert_eq!(result.first_text().unwrap(), "state: shared");
1928    }
1929
1930    #[tokio::test]
1931    async fn test_raw_handler_with_state() {
1932        let prefix = Arc::new("prefix:".to_string());
1933
1934        let tool = ToolBuilder::new("raw_with_state")
1935            .description("Raw handler with state")
1936            .extractor_handler(
1937                prefix,
1938                |State(state): State<Arc<String>>, RawArgs(args): RawArgs| async move {
1939                    Ok(CallToolResult::text(format!("{} {}", state, args)))
1940                },
1941            )
1942            .build();
1943
1944        assert_eq!(tool.name, "raw_with_state");
1945
1946        let result = tool.call(serde_json::json!({"key": "value"})).await;
1947        assert!(!result.is_error);
1948        assert!(result.first_text().unwrap().starts_with("prefix:"));
1949    }
1950
1951    #[tokio::test]
1952    async fn test_raw_handler_with_state_and_context() {
1953        let prefix = Arc::new("prefix:".to_string());
1954
1955        let tool = ToolBuilder::new("raw_state_context")
1956            .description("Raw handler with state and context")
1957            .extractor_handler(
1958                prefix,
1959                |State(state): State<Arc<String>>,
1960                 _ctx: Context,
1961                 RawArgs(args): RawArgs| async move {
1962                    Ok(CallToolResult::text(format!("{} {}", state, args)))
1963                },
1964            )
1965            .build();
1966
1967        assert_eq!(tool.name, "raw_state_context");
1968
1969        let result = tool.call(serde_json::json!({"key": "value"})).await;
1970        assert!(!result.is_error);
1971        assert!(result.first_text().unwrap().starts_with("prefix:"));
1972    }
1973
1974    #[tokio::test]
1975    async fn test_tool_with_timeout_layer() {
1976        use std::time::Duration;
1977        use tower::timeout::TimeoutLayer;
1978
1979        #[derive(Debug, Deserialize, JsonSchema)]
1980        struct SlowInput {
1981            delay_ms: u64,
1982        }
1983
1984        // Create a tool with a short timeout
1985        let tool = ToolBuilder::new("slow_tool")
1986            .description("A slow tool")
1987            .handler(|input: SlowInput| async move {
1988                tokio::time::sleep(Duration::from_millis(input.delay_ms)).await;
1989                Ok(CallToolResult::text("completed"))
1990            })
1991            .layer(TimeoutLayer::new(Duration::from_millis(50)))
1992            .build();
1993
1994        // Fast call should succeed
1995        let result = tool.call(serde_json::json!({"delay_ms": 10})).await;
1996        assert!(!result.is_error);
1997        assert_eq!(result.first_text().unwrap(), "completed");
1998
1999        // Slow call should timeout and return an error result
2000        let result = tool.call(serde_json::json!({"delay_ms": 200})).await;
2001        assert!(result.is_error);
2002        // Tower's timeout error message is "request timed out"
2003        let msg = result.first_text().unwrap().to_lowercase();
2004        assert!(
2005            msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
2006            "Expected timeout error, got: {}",
2007            msg
2008        );
2009    }
2010
2011    #[tokio::test]
2012    async fn test_tool_with_concurrency_limit_layer() {
2013        use std::sync::atomic::{AtomicU32, Ordering};
2014        use std::time::Duration;
2015        use tower::limit::ConcurrencyLimitLayer;
2016
2017        #[derive(Debug, Deserialize, JsonSchema)]
2018        struct WorkInput {
2019            id: u32,
2020        }
2021
2022        let max_concurrent = Arc::new(AtomicU32::new(0));
2023        let current_concurrent = Arc::new(AtomicU32::new(0));
2024        let max_ref = max_concurrent.clone();
2025        let current_ref = current_concurrent.clone();
2026
2027        // Create a tool with concurrency limit of 2
2028        let tool = ToolBuilder::new("concurrent_tool")
2029            .description("A concurrent tool")
2030            .handler(move |input: WorkInput| {
2031                let max = max_ref.clone();
2032                let current = current_ref.clone();
2033                async move {
2034                    // Track concurrency
2035                    let prev = current.fetch_add(1, Ordering::SeqCst);
2036                    max.fetch_max(prev + 1, Ordering::SeqCst);
2037
2038                    // Simulate work
2039                    tokio::time::sleep(Duration::from_millis(50)).await;
2040
2041                    current.fetch_sub(1, Ordering::SeqCst);
2042                    Ok(CallToolResult::text(format!("completed {}", input.id)))
2043                }
2044            })
2045            .layer(ConcurrencyLimitLayer::new(2))
2046            .build();
2047
2048        // Launch 4 concurrent calls
2049        let handles: Vec<_> = (0..4)
2050            .map(|i| {
2051                let t = tool.call(serde_json::json!({"id": i}));
2052                tokio::spawn(t)
2053            })
2054            .collect();
2055
2056        for handle in handles {
2057            let result = handle.await.unwrap();
2058            assert!(!result.is_error);
2059        }
2060
2061        // Max concurrent should not exceed 2
2062        assert!(max_concurrent.load(Ordering::SeqCst) <= 2);
2063    }
2064
2065    #[tokio::test]
2066    async fn test_tool_with_multiple_layers() {
2067        use std::time::Duration;
2068        use tower::limit::ConcurrencyLimitLayer;
2069        use tower::timeout::TimeoutLayer;
2070
2071        #[derive(Debug, Deserialize, JsonSchema)]
2072        struct Input {
2073            value: String,
2074        }
2075
2076        // Create a tool with multiple layers stacked
2077        let tool = ToolBuilder::new("multi_layer_tool")
2078            .description("Tool with multiple layers")
2079            .handler(|input: Input| async move {
2080                Ok(CallToolResult::text(format!("processed: {}", input.value)))
2081            })
2082            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2083            .layer(ConcurrencyLimitLayer::new(10))
2084            .build();
2085
2086        let result = tool.call(serde_json::json!({"value": "test"})).await;
2087        assert!(!result.is_error);
2088        assert_eq!(result.first_text().unwrap(), "processed: test");
2089    }
2090
2091    #[test]
2092    fn test_tool_catch_error_clone() {
2093        // ToolCatchError should be Clone when inner is Clone
2094        // Use a simple tool that we can clone
2095        let tool = ToolBuilder::new("test")
2096            .description("test")
2097            .extractor_handler((), |RawArgs(_args): RawArgs| async {
2098                Ok(CallToolResult::text("ok"))
2099            })
2100            .build();
2101        // The tool contains a BoxToolService which is cloneable
2102        let _clone = tool.call(serde_json::json!({}));
2103    }
2104
2105    #[test]
2106    fn test_tool_catch_error_debug() {
2107        // ToolCatchError implements Debug when inner implements Debug
2108        // Since our internal services don't require Debug, just verify
2109        // that ToolCatchError has a Debug impl for appropriate types
2110        #[derive(Debug, Clone)]
2111        struct DebugService;
2112
2113        impl Service<ToolRequest> for DebugService {
2114            type Response = CallToolResult;
2115            type Error = crate::error::Error;
2116            type Future = Pin<
2117                Box<
2118                    dyn Future<Output = std::result::Result<CallToolResult, crate::error::Error>>
2119                        + Send,
2120                >,
2121            >;
2122
2123            fn poll_ready(
2124                &mut self,
2125                _cx: &mut std::task::Context<'_>,
2126            ) -> Poll<std::result::Result<(), Self::Error>> {
2127                Poll::Ready(Ok(()))
2128            }
2129
2130            fn call(&mut self, _req: ToolRequest) -> Self::Future {
2131                Box::pin(async { Ok(CallToolResult::text("ok")) })
2132            }
2133        }
2134
2135        let catch_error = ToolCatchError::new(DebugService);
2136        let debug = format!("{:?}", catch_error);
2137        assert!(debug.contains("ToolCatchError"));
2138    }
2139
2140    #[test]
2141    fn test_tool_request_new() {
2142        use crate::protocol::RequestId;
2143
2144        let ctx = RequestContext::new(RequestId::Number(42));
2145        let args = serde_json::json!({"key": "value"});
2146        let req = ToolRequest::new(ctx.clone(), args.clone());
2147
2148        assert_eq!(req.args, args);
2149    }
2150
2151    #[test]
2152    fn test_no_params_schema() {
2153        // NoParams should produce a schema with type: "object"
2154        let schema = schemars::schema_for!(NoParams);
2155        let schema_value = serde_json::to_value(&schema).unwrap();
2156        assert_eq!(
2157            schema_value.get("type").and_then(|v| v.as_str()),
2158            Some("object"),
2159            "NoParams should generate type: object schema"
2160        );
2161    }
2162
2163    #[test]
2164    fn test_no_params_deserialize() {
2165        // NoParams should deserialize from various inputs
2166        let from_empty_object: NoParams = serde_json::from_str("{}").unwrap();
2167        assert_eq!(from_empty_object, NoParams);
2168
2169        let from_null: NoParams = serde_json::from_str("null").unwrap();
2170        assert_eq!(from_null, NoParams);
2171
2172        // Should also accept objects with unexpected fields (ignored)
2173        let from_object_with_fields: NoParams =
2174            serde_json::from_str(r#"{"unexpected": "value"}"#).unwrap();
2175        assert_eq!(from_object_with_fields, NoParams);
2176    }
2177
2178    #[tokio::test]
2179    async fn test_no_params_type_in_handler() {
2180        // NoParams can be used as a handler input type
2181        let tool = ToolBuilder::new("status")
2182            .description("Get status")
2183            .handler(|_input: NoParams| async move { Ok(CallToolResult::text("OK")) })
2184            .build();
2185
2186        // Check schema has type: object (not type: null like () would produce)
2187        let schema = tool.definition().input_schema;
2188        assert_eq!(
2189            schema.get("type").and_then(|v| v.as_str()),
2190            Some("object"),
2191            "NoParams handler should produce type: object schema"
2192        );
2193
2194        // Should work with empty input
2195        let result = tool.call(serde_json::json!({})).await;
2196        assert!(!result.is_error);
2197    }
2198
2199    #[tokio::test]
2200    async fn test_tool_with_name_prefix() {
2201        #[derive(Debug, Deserialize, JsonSchema)]
2202        struct Input {
2203            value: String,
2204        }
2205
2206        let tool = ToolBuilder::new("query")
2207            .description("Query something")
2208            .title("Query Tool")
2209            .handler(|input: Input| async move { Ok(CallToolResult::text(&input.value)) })
2210            .build();
2211
2212        // Create prefixed version
2213        let prefixed = tool.with_name_prefix("db");
2214
2215        // Check name is prefixed
2216        assert_eq!(prefixed.name, "db.query");
2217
2218        // Check other fields are preserved
2219        assert_eq!(prefixed.description.as_deref(), Some("Query something"));
2220        assert_eq!(prefixed.title.as_deref(), Some("Query Tool"));
2221
2222        // Check the tool still works
2223        let result = prefixed
2224            .call(serde_json::json!({"value": "test input"}))
2225            .await;
2226        assert!(!result.is_error);
2227        match &result.content[0] {
2228            Content::Text { text, .. } => assert_eq!(text, "test input"),
2229            _ => panic!("Expected text content"),
2230        }
2231    }
2232
2233    #[tokio::test]
2234    async fn test_tool_with_name_prefix_multiple_levels() {
2235        let tool = ToolBuilder::new("action")
2236            .description("Do something")
2237            .handler(|_: NoParams| async move { Ok(CallToolResult::text("done")) })
2238            .build();
2239
2240        // Apply multiple prefixes
2241        let prefixed = tool.with_name_prefix("level1");
2242        assert_eq!(prefixed.name, "level1.action");
2243
2244        let double_prefixed = prefixed.with_name_prefix("level0");
2245        assert_eq!(double_prefixed.name, "level0.level1.action");
2246    }
2247
2248    // =============================================================================
2249    // no_params_handler tests
2250    // =============================================================================
2251
2252    #[tokio::test]
2253    async fn test_no_params_handler_basic() {
2254        let tool = ToolBuilder::new("get_status")
2255            .description("Get current status")
2256            .no_params_handler(|| async { Ok(CallToolResult::text("OK")) })
2257            .build();
2258
2259        assert_eq!(tool.name, "get_status");
2260        assert_eq!(tool.description.as_deref(), Some("Get current status"));
2261
2262        // Should work with empty args
2263        let result = tool.call(serde_json::json!({})).await;
2264        assert!(!result.is_error);
2265        assert_eq!(result.first_text().unwrap(), "OK");
2266
2267        // Should also work with null args
2268        let result = tool.call(serde_json::json!(null)).await;
2269        assert!(!result.is_error);
2270
2271        // Check input schema has type: object
2272        let schema = tool.definition().input_schema;
2273        assert_eq!(schema.get("type").and_then(|v| v.as_str()), Some("object"));
2274    }
2275
2276    #[tokio::test]
2277    async fn test_no_params_handler_with_captured_state() {
2278        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
2279        let counter_ref = counter.clone();
2280
2281        let tool = ToolBuilder::new("increment")
2282            .description("Increment counter")
2283            .no_params_handler(move || {
2284                let c = counter_ref.clone();
2285                async move {
2286                    let prev = c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2287                    Ok(CallToolResult::text(format!("Incremented from {}", prev)))
2288                }
2289            })
2290            .build();
2291
2292        // Call multiple times
2293        let _ = tool.call(serde_json::json!({})).await;
2294        let _ = tool.call(serde_json::json!({})).await;
2295        let result = tool.call(serde_json::json!({})).await;
2296
2297        assert!(!result.is_error);
2298        assert_eq!(result.first_text().unwrap(), "Incremented from 2");
2299        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 3);
2300    }
2301
2302    #[tokio::test]
2303    async fn test_no_params_handler_with_layer() {
2304        use std::time::Duration;
2305        use tower::timeout::TimeoutLayer;
2306
2307        let tool = ToolBuilder::new("slow_status")
2308            .description("Slow status check")
2309            .no_params_handler(|| async {
2310                tokio::time::sleep(Duration::from_millis(10)).await;
2311                Ok(CallToolResult::text("done"))
2312            })
2313            .layer(TimeoutLayer::new(Duration::from_secs(1)))
2314            .build();
2315
2316        let result = tool.call(serde_json::json!({})).await;
2317        assert!(!result.is_error);
2318        assert_eq!(result.first_text().unwrap(), "done");
2319    }
2320
2321    #[tokio::test]
2322    async fn test_no_params_handler_timeout() {
2323        use std::time::Duration;
2324        use tower::timeout::TimeoutLayer;
2325
2326        let tool = ToolBuilder::new("very_slow_status")
2327            .description("Very slow status check")
2328            .no_params_handler(|| async {
2329                tokio::time::sleep(Duration::from_millis(200)).await;
2330                Ok(CallToolResult::text("done"))
2331            })
2332            .layer(TimeoutLayer::new(Duration::from_millis(50)))
2333            .build();
2334
2335        let result = tool.call(serde_json::json!({})).await;
2336        assert!(result.is_error);
2337        let msg = result.first_text().unwrap().to_lowercase();
2338        assert!(
2339            msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
2340            "Expected timeout error, got: {}",
2341            msg
2342        );
2343    }
2344
2345    #[tokio::test]
2346    async fn test_no_params_handler_with_multiple_layers() {
2347        use std::time::Duration;
2348        use tower::limit::ConcurrencyLimitLayer;
2349        use tower::timeout::TimeoutLayer;
2350
2351        let tool = ToolBuilder::new("multi_layer_status")
2352            .description("Status with multiple layers")
2353            .no_params_handler(|| async { Ok(CallToolResult::text("status ok")) })
2354            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2355            .layer(ConcurrencyLimitLayer::new(10))
2356            .build();
2357
2358        let result = tool.call(serde_json::json!({})).await;
2359        assert!(!result.is_error);
2360        assert_eq!(result.first_text().unwrap(), "status ok");
2361    }
2362
2363    // =========================================================================
2364    // Guard tests
2365    // =========================================================================
2366
2367    #[tokio::test]
2368    async fn test_guard_allows_request() {
2369        #[derive(Debug, Deserialize, JsonSchema)]
2370        #[allow(dead_code)]
2371        struct DeleteInput {
2372            id: String,
2373            confirm: bool,
2374        }
2375
2376        let tool = ToolBuilder::new("delete")
2377            .description("Delete a record")
2378            .handler(|input: DeleteInput| async move {
2379                Ok(CallToolResult::text(format!("deleted {}", input.id)))
2380            })
2381            .guard(|req: &ToolRequest| {
2382                let confirm = req
2383                    .args
2384                    .get("confirm")
2385                    .and_then(|v| v.as_bool())
2386                    .unwrap_or(false);
2387                if !confirm {
2388                    return Err("Must set confirm=true to delete".to_string());
2389                }
2390                Ok(())
2391            })
2392            .build();
2393
2394        let result = tool
2395            .call(serde_json::json!({"id": "abc", "confirm": true}))
2396            .await;
2397        assert!(!result.is_error);
2398        assert_eq!(result.first_text().unwrap(), "deleted abc");
2399    }
2400
2401    #[tokio::test]
2402    async fn test_guard_rejects_request() {
2403        #[derive(Debug, Deserialize, JsonSchema)]
2404        #[allow(dead_code)]
2405        struct DeleteInput2 {
2406            id: String,
2407            confirm: bool,
2408        }
2409
2410        let tool = ToolBuilder::new("delete2")
2411            .description("Delete a record")
2412            .handler(|input: DeleteInput2| async move {
2413                Ok(CallToolResult::text(format!("deleted {}", input.id)))
2414            })
2415            .guard(|req: &ToolRequest| {
2416                let confirm = req
2417                    .args
2418                    .get("confirm")
2419                    .and_then(|v| v.as_bool())
2420                    .unwrap_or(false);
2421                if !confirm {
2422                    return Err("Must set confirm=true to delete".to_string());
2423                }
2424                Ok(())
2425            })
2426            .build();
2427
2428        let result = tool
2429            .call(serde_json::json!({"id": "abc", "confirm": false}))
2430            .await;
2431        assert!(result.is_error);
2432        assert!(
2433            result
2434                .first_text()
2435                .unwrap()
2436                .contains("Must set confirm=true")
2437        );
2438    }
2439
2440    #[tokio::test]
2441    async fn test_guard_with_layer() {
2442        use std::time::Duration;
2443        use tower::timeout::TimeoutLayer;
2444
2445        let tool = ToolBuilder::new("guarded_timeout")
2446            .description("Guarded with timeout")
2447            .handler(|input: GreetInput| async move {
2448                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
2449            })
2450            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2451            .guard(|_req: &ToolRequest| Ok(()))
2452            .build();
2453
2454        let result = tool.call(serde_json::json!({"name": "World"})).await;
2455        assert!(!result.is_error);
2456        assert_eq!(result.first_text().unwrap(), "Hello, World!");
2457    }
2458
2459    #[tokio::test]
2460    async fn test_guard_on_no_params_handler() {
2461        let allowed = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(true));
2462        let allowed_clone = allowed.clone();
2463
2464        let tool = ToolBuilder::new("status")
2465            .description("Get status")
2466            .no_params_handler(|| async { Ok(CallToolResult::text("ok")) })
2467            .guard(move |_req: &ToolRequest| {
2468                if allowed_clone.load(std::sync::atomic::Ordering::Relaxed) {
2469                    Ok(())
2470                } else {
2471                    Err("Access denied".to_string())
2472                }
2473            })
2474            .build();
2475
2476        // Allowed
2477        let result = tool.call(serde_json::json!({})).await;
2478        assert!(!result.is_error);
2479        assert_eq!(result.first_text().unwrap(), "ok");
2480
2481        // Denied
2482        allowed.store(false, std::sync::atomic::Ordering::Relaxed);
2483        let result = tool.call(serde_json::json!({})).await;
2484        assert!(result.is_error);
2485        assert!(result.first_text().unwrap().contains("Access denied"));
2486    }
2487
2488    #[tokio::test]
2489    async fn test_guard_on_no_params_handler_with_layer() {
2490        use std::time::Duration;
2491        use tower::timeout::TimeoutLayer;
2492
2493        let tool = ToolBuilder::new("status_layered")
2494            .description("Get status with layers")
2495            .no_params_handler(|| async { Ok(CallToolResult::text("ok")) })
2496            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2497            .guard(|_req: &ToolRequest| Ok(()))
2498            .build();
2499
2500        let result = tool.call(serde_json::json!({})).await;
2501        assert!(!result.is_error);
2502        assert_eq!(result.first_text().unwrap(), "ok");
2503    }
2504
2505    #[tokio::test]
2506    async fn test_guard_on_extractor_handler() {
2507        use std::sync::Arc;
2508
2509        #[derive(Clone)]
2510        struct AppState {
2511            prefix: String,
2512        }
2513
2514        #[derive(Debug, Deserialize, JsonSchema)]
2515        struct QueryInput {
2516            query: String,
2517        }
2518
2519        let state = Arc::new(AppState {
2520            prefix: "db".to_string(),
2521        });
2522
2523        let tool = ToolBuilder::new("search")
2524            .description("Search")
2525            .extractor_handler(
2526                state,
2527                |State(app): State<Arc<AppState>>, Json(input): Json<QueryInput>| async move {
2528                    Ok(CallToolResult::text(format!(
2529                        "{}: {}",
2530                        app.prefix, input.query
2531                    )))
2532                },
2533            )
2534            .guard(|req: &ToolRequest| {
2535                let query = req.args.get("query").and_then(|v| v.as_str()).unwrap_or("");
2536                if query.is_empty() {
2537                    return Err("Query cannot be empty".to_string());
2538                }
2539                Ok(())
2540            })
2541            .build();
2542
2543        // Valid query
2544        let result = tool.call(serde_json::json!({"query": "hello"})).await;
2545        assert!(!result.is_error);
2546        assert_eq!(result.first_text().unwrap(), "db: hello");
2547
2548        // Empty query rejected by guard
2549        let result = tool.call(serde_json::json!({"query": ""})).await;
2550        assert!(result.is_error);
2551        assert!(
2552            result
2553                .first_text()
2554                .unwrap()
2555                .contains("Query cannot be empty")
2556        );
2557    }
2558
2559    #[tokio::test]
2560    async fn test_guard_on_extractor_handler_with_layer() {
2561        use std::sync::Arc;
2562        use std::time::Duration;
2563        use tower::timeout::TimeoutLayer;
2564
2565        #[derive(Clone)]
2566        struct AppState2 {
2567            prefix: String,
2568        }
2569
2570        #[derive(Debug, Deserialize, JsonSchema)]
2571        struct QueryInput2 {
2572            query: String,
2573        }
2574
2575        let state = Arc::new(AppState2 {
2576            prefix: "db".to_string(),
2577        });
2578
2579        let tool = ToolBuilder::new("search2")
2580            .description("Search with layer and guard")
2581            .extractor_handler(
2582                state,
2583                |State(app): State<Arc<AppState2>>, Json(input): Json<QueryInput2>| async move {
2584                    Ok(CallToolResult::text(format!(
2585                        "{}: {}",
2586                        app.prefix, input.query
2587                    )))
2588                },
2589            )
2590            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2591            .guard(|_req: &ToolRequest| Ok(()))
2592            .build();
2593
2594        let result = tool.call(serde_json::json!({"query": "hello"})).await;
2595        assert!(!result.is_error);
2596        assert_eq!(result.first_text().unwrap(), "db: hello");
2597    }
2598
2599    #[tokio::test]
2600    async fn test_tool_with_guard_post_build() {
2601        let tool = ToolBuilder::new("admin_action")
2602            .description("Admin action")
2603            .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("done")) })
2604            .build();
2605
2606        // Apply guard after building
2607        let guarded = tool.with_guard(|req: &ToolRequest| {
2608            let name = req.args.get("name").and_then(|v| v.as_str()).unwrap_or("");
2609            if name == "admin" {
2610                Ok(())
2611            } else {
2612                Err("Only admin allowed".to_string())
2613            }
2614        });
2615
2616        // Admin passes
2617        let result = guarded.call(serde_json::json!({"name": "admin"})).await;
2618        assert!(!result.is_error);
2619
2620        // Non-admin blocked
2621        let result = guarded.call(serde_json::json!({"name": "user"})).await;
2622        assert!(result.is_error);
2623        assert!(result.first_text().unwrap().contains("Only admin allowed"));
2624    }
2625
2626    #[tokio::test]
2627    async fn test_with_guard_preserves_tool_metadata() {
2628        let tool = ToolBuilder::new("my_tool")
2629            .description("A tool")
2630            .title("My Tool")
2631            .read_only()
2632            .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("done")) })
2633            .build();
2634
2635        let guarded = tool.with_guard(|_req: &ToolRequest| Ok(()));
2636
2637        assert_eq!(guarded.name, "my_tool");
2638        assert_eq!(guarded.description.as_deref(), Some("A tool"));
2639        assert_eq!(guarded.title.as_deref(), Some("My Tool"));
2640        assert!(guarded.annotations.is_some());
2641    }
2642
2643    #[tokio::test]
2644    async fn test_guard_group_pattern() {
2645        // Demonstrate applying the same guard to multiple tools (per-group pattern)
2646        let require_auth = |req: &ToolRequest| {
2647            let token = req
2648                .args
2649                .get("_token")
2650                .and_then(|v| v.as_str())
2651                .unwrap_or("");
2652            if token == "valid" {
2653                Ok(())
2654            } else {
2655                Err("Authentication required".to_string())
2656            }
2657        };
2658
2659        let tool1 = ToolBuilder::new("action1")
2660            .description("Action 1")
2661            .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("action1")) })
2662            .build();
2663        let tool2 = ToolBuilder::new("action2")
2664            .description("Action 2")
2665            .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("action2")) })
2666            .build();
2667
2668        // Apply same guard to both
2669        let guarded1 = tool1.with_guard(require_auth);
2670        let guarded2 = tool2.with_guard(require_auth);
2671
2672        // Without auth
2673        let r1 = guarded1
2674            .call(serde_json::json!({"name": "test", "_token": "invalid"}))
2675            .await;
2676        let r2 = guarded2
2677            .call(serde_json::json!({"name": "test", "_token": "invalid"}))
2678            .await;
2679        assert!(r1.is_error);
2680        assert!(r2.is_error);
2681
2682        // With auth
2683        let r1 = guarded1
2684            .call(serde_json::json!({"name": "test", "_token": "valid"}))
2685            .await;
2686        let r2 = guarded2
2687            .call(serde_json::json!({"name": "test", "_token": "valid"}))
2688            .await;
2689        assert!(!r1.is_error);
2690        assert!(!r2.is_error);
2691    }
2692}