Skip to main content

tower_mcp/
tool.rs

1//! Tool definition and builder API
2//!
3//! Provides ergonomic ways to define MCP tools:
4//!
5//! 1. **Builder pattern** - Fluent API for defining tools
6//! 2. **Trait-based** - Implement `McpTool` for full control
7//! 3. **Function-based** - Quick tools from async functions
8//!
9//! ## Per-Tool Middleware
10//!
11//! Tools are implemented as Tower services internally, enabling middleware
12//! composition via the `.layer()` method:
13//!
14//! ```rust
15//! use std::time::Duration;
16//! use tower::timeout::TimeoutLayer;
17//! use tower_mcp::{ToolBuilder, CallToolResult};
18//! use schemars::JsonSchema;
19//! use serde::Deserialize;
20//!
21//! #[derive(Debug, Deserialize, JsonSchema)]
22//! struct SearchInput { query: String }
23//!
24//! let tool = ToolBuilder::new("slow_search")
25//!     .description("Search with extended timeout")
26//!     .handler(|input: SearchInput| async move {
27//!         Ok(CallToolResult::text("result"))
28//!     })
29//!     .layer(TimeoutLayer::new(Duration::from_secs(30)))
30//!     .build();
31//! ```
32
33use std::borrow::Cow;
34use std::convert::Infallible;
35use std::fmt;
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::task::{Context, Poll};
40
41use schemars::{JsonSchema, Schema, SchemaGenerator};
42use serde::Serialize;
43use serde::de::DeserializeOwned;
44use serde_json::Value;
45use tower::util::BoxCloneService;
46use tower_service::Service;
47
48use crate::context::RequestContext;
49use crate::error::{Error, Result, ResultExt};
50use crate::protocol::{
51    CallToolResult, TaskSupportMode, ToolAnnotations, ToolDefinition, ToolExecution, ToolIcon,
52};
53
54// =============================================================================
55// Service Types for Per-Tool Middleware
56// =============================================================================
57
58/// Request type for tool services.
59///
60/// Contains the request context (for progress reporting, cancellation, etc.)
61/// and the tool arguments as raw JSON.
62#[derive(Debug, Clone)]
63pub struct ToolRequest {
64    /// Request context for progress reporting, cancellation, and client requests
65    pub ctx: RequestContext,
66    /// Tool arguments as raw JSON
67    pub args: Value,
68}
69
70impl ToolRequest {
71    /// Create a new tool request
72    pub fn new(ctx: RequestContext, args: Value) -> Self {
73        Self { ctx, args }
74    }
75}
76
77/// A boxed, cloneable tool service with `Error = Infallible`.
78///
79/// This is the internal service type that tools use. Middleware errors are
80/// caught and converted to `CallToolResult::error()` responses, so the
81/// service never fails at the Tower level.
82pub type BoxToolService = BoxCloneService<ToolRequest, CallToolResult, Infallible>;
83
84/// Catches errors from the inner service and converts them to `CallToolResult::error()`.
85///
86/// This wrapper ensures that middleware errors (e.g., timeouts, rate limits)
87/// and handler errors are converted to tool-level error responses with
88/// `is_error: true`, rather than propagating as Tower service errors.
89pub struct ToolCatchError<S> {
90    inner: S,
91}
92
93impl<S> ToolCatchError<S> {
94    /// Create a new `ToolCatchError` wrapping the given service.
95    pub fn new(inner: S) -> Self {
96        Self { inner }
97    }
98}
99
100impl<S: Clone> Clone for ToolCatchError<S> {
101    fn clone(&self) -> Self {
102        Self {
103            inner: self.inner.clone(),
104        }
105    }
106}
107
108impl<S: fmt::Debug> fmt::Debug for ToolCatchError<S> {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        f.debug_struct("ToolCatchError")
111            .field("inner", &self.inner)
112            .finish()
113    }
114}
115
116impl<S> Service<ToolRequest> for ToolCatchError<S>
117where
118    S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
119    S::Error: fmt::Display + Send,
120    S::Future: Send,
121{
122    type Response = CallToolResult;
123    type Error = Infallible;
124    type Future =
125        Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Infallible>> + Send>>;
126
127    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
128        // Map any readiness error to Infallible (we catch it on call)
129        match self.inner.poll_ready(cx) {
130            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
131            Poll::Ready(Err(_)) => Poll::Ready(Ok(())),
132            Poll::Pending => Poll::Pending,
133        }
134    }
135
136    fn call(&mut self, req: ToolRequest) -> Self::Future {
137        let fut = self.inner.call(req);
138
139        Box::pin(async move {
140            match fut.await {
141                Ok(result) => Ok(result),
142                Err(err) => Ok(CallToolResult::error(err.to_string())),
143            }
144        })
145    }
146}
147
148/// A tower [`Layer`](tower::Layer) that applies a guard function before the inner service.
149///
150/// Guards run before the tool handler and can short-circuit with an error message.
151/// Use via [`ToolBuilderWithHandler::guard`] or [`Tool::with_guard`] rather than
152/// constructing directly.
153///
154/// # Example
155///
156/// ```rust
157/// use tower_mcp::{ToolBuilder, ToolRequest, CallToolResult};
158/// use schemars::JsonSchema;
159/// use serde::Deserialize;
160///
161/// #[derive(Debug, Deserialize, JsonSchema)]
162/// struct DeleteInput { id: String, confirm: bool }
163///
164/// let tool = ToolBuilder::new("delete")
165///     .description("Delete a record")
166///     .handler(|input: DeleteInput| async move {
167///         Ok(CallToolResult::text(format!("deleted {}", input.id)))
168///     })
169///     .guard(|req: &ToolRequest| {
170///         let confirm = req.args.get("confirm").and_then(|v| v.as_bool()).unwrap_or(false);
171///         if !confirm {
172///             return Err("Must set confirm=true to delete".to_string());
173///         }
174///         Ok(())
175///     })
176///     .build();
177/// ```
178#[derive(Clone)]
179pub struct GuardLayer<G> {
180    guard: G,
181}
182
183impl<G> GuardLayer<G> {
184    /// Create a new guard layer from a closure.
185    ///
186    /// The closure receives a `&ToolRequest` and returns `Ok(())` to proceed
187    /// or `Err(String)` to reject with an error message.
188    pub fn new(guard: G) -> Self {
189        Self { guard }
190    }
191}
192
193impl<G, S> tower::Layer<S> for GuardLayer<G>
194where
195    G: Clone,
196{
197    type Service = GuardService<G, S>;
198
199    fn layer(&self, inner: S) -> Self::Service {
200        GuardService {
201            guard: self.guard.clone(),
202            inner,
203        }
204    }
205}
206
207/// Service wrapper that runs a guard check before calling the inner service.
208///
209/// Created by [`GuardLayer`]. See its documentation for usage.
210#[derive(Clone)]
211pub struct GuardService<G, S> {
212    guard: G,
213    inner: S,
214}
215
216impl<G, S> Service<ToolRequest> for GuardService<G, S>
217where
218    G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
219    S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
220    S::Error: Into<Error> + Send,
221    S::Future: Send,
222{
223    type Response = CallToolResult;
224    type Error = Error;
225    type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
226
227    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
228        self.inner.poll_ready(cx).map_err(Into::into)
229    }
230
231    fn call(&mut self, req: ToolRequest) -> Self::Future {
232        match (self.guard)(&req) {
233            Ok(()) => {
234                let fut = self.inner.call(req);
235                Box::pin(async move { fut.await.map_err(Into::into) })
236            }
237            Err(msg) => Box::pin(async move { Err(Error::tool(msg)) }),
238        }
239    }
240}
241
242/// A marker type for tools that take no parameters.
243///
244/// Use this instead of `()` when defining tools with no input parameters.
245/// The unit type `()` generates `"type": "null"` in JSON Schema, which many
246/// MCP clients reject. `NoParams` generates `"type": "object"` with no
247/// required properties, which is the correct schema for parameterless tools.
248///
249/// # Example
250///
251/// ```rust
252/// use tower_mcp::{ToolBuilder, CallToolResult, NoParams};
253///
254/// let tool = ToolBuilder::new("get_status")
255///     .description("Get current status")
256///     .handler(|_input: NoParams| async move {
257///         Ok(CallToolResult::text("OK"))
258///     })
259///     .build();
260/// ```
261#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
262pub struct NoParams;
263
264impl<'de> serde::Deserialize<'de> for NoParams {
265    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
266    where
267        D: serde::Deserializer<'de>,
268    {
269        // Accept null, empty object, or any object (ignoring all fields)
270        struct NoParamsVisitor;
271
272        impl<'de> serde::de::Visitor<'de> for NoParamsVisitor {
273            type Value = NoParams;
274
275            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
276                formatter.write_str("null or an object")
277            }
278
279            fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
280            where
281                E: serde::de::Error,
282            {
283                Ok(NoParams)
284            }
285
286            fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
287            where
288                E: serde::de::Error,
289            {
290                Ok(NoParams)
291            }
292
293            fn visit_some<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
294            where
295                D: serde::Deserializer<'de>,
296            {
297                serde::Deserialize::deserialize(deserializer)
298            }
299
300            fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
301            where
302                A: serde::de::MapAccess<'de>,
303            {
304                // Drain the map, ignoring all entries
305                while map
306                    .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
307                    .is_some()
308                {}
309                Ok(NoParams)
310            }
311        }
312
313        deserializer.deserialize_any(NoParamsVisitor)
314    }
315}
316
317impl JsonSchema for NoParams {
318    fn schema_name() -> Cow<'static, str> {
319        Cow::Borrowed("NoParams")
320    }
321
322    fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
323        serde_json::json!({
324            "type": "object"
325        })
326        .try_into()
327        .expect("valid schema")
328    }
329}
330
331/// Validate a tool name according to MCP spec.
332///
333/// Tool names must be:
334/// - 1-128 characters long
335/// - Contain only alphanumeric characters, underscores, hyphens, and dots
336///
337/// Returns `Ok(())` if valid, `Err` with description if invalid.
338pub fn validate_tool_name(name: &str) -> Result<()> {
339    if name.is_empty() {
340        return Err(Error::tool("Tool name cannot be empty"));
341    }
342    if name.len() > 128 {
343        return Err(Error::tool(format!(
344            "Tool name '{}' exceeds maximum length of 128 characters (got {})",
345            name,
346            name.len()
347        )));
348    }
349    if let Some(invalid_char) = name
350        .chars()
351        .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
352    {
353        return Err(Error::tool(format!(
354            "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
355            name, invalid_char
356        )));
357    }
358    Ok(())
359}
360
361/// A boxed future for tool handlers
362pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
363
364/// Tool handler trait - the core abstraction for tool execution
365pub trait ToolHandler: Send + Sync {
366    /// Execute the tool with the given arguments
367    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;
368
369    /// Execute the tool with request context for progress/cancellation support
370    ///
371    /// The default implementation ignores the context and calls `call`.
372    /// Override this to receive progress/cancellation context.
373    fn call_with_context(
374        &self,
375        _ctx: RequestContext,
376        args: Value,
377    ) -> BoxFuture<'_, Result<CallToolResult>> {
378        self.call(args)
379    }
380
381    /// Returns true if this handler uses context (for optimization)
382    fn uses_context(&self) -> bool {
383        false
384    }
385
386    /// Get the tool's input schema
387    fn input_schema(&self) -> Value;
388}
389
390/// Adapts a `ToolHandler` to a Tower `Service<ToolRequest>`.
391///
392/// This is an internal adapter that bridges the handler abstraction to the
393/// service abstraction, enabling middleware composition.
394pub(crate) struct ToolHandlerService<H> {
395    handler: Arc<H>,
396}
397
398impl<H> ToolHandlerService<H> {
399    pub(crate) fn new(handler: H) -> Self {
400        Self {
401            handler: Arc::new(handler),
402        }
403    }
404}
405
406impl<H> Clone for ToolHandlerService<H> {
407    fn clone(&self) -> Self {
408        Self {
409            handler: self.handler.clone(),
410        }
411    }
412}
413
414impl<H> Service<ToolRequest> for ToolHandlerService<H>
415where
416    H: ToolHandler + 'static,
417{
418    type Response = CallToolResult;
419    type Error = Error;
420    type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
421
422    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
423        Poll::Ready(Ok(()))
424    }
425
426    fn call(&mut self, req: ToolRequest) -> Self::Future {
427        let handler = self.handler.clone();
428        Box::pin(async move { handler.call_with_context(req.ctx, req.args).await })
429    }
430}
431
432/// A complete tool definition with service-based execution.
433///
434/// Tools are implemented as Tower services internally, enabling middleware
435/// composition via the builder's `.layer()` method. The service is wrapped
436/// in [`ToolCatchError`] to convert any errors (from handlers or middleware)
437/// into `CallToolResult::error()` responses.
438pub struct Tool {
439    /// Tool name (must be 1-128 chars, alphanumeric/underscore/hyphen/dot only)
440    pub name: String,
441    /// Human-readable title for the tool
442    pub title: Option<String>,
443    /// Description of what the tool does
444    pub description: Option<String>,
445    /// JSON Schema for the tool's output (optional)
446    pub output_schema: Option<Value>,
447    /// Icons for the tool
448    pub icons: Option<Vec<ToolIcon>>,
449    /// Tool annotations (hints about behavior)
450    pub annotations: Option<ToolAnnotations>,
451    /// Task support mode for this tool
452    pub task_support: TaskSupportMode,
453    /// The boxed service that executes the tool
454    pub(crate) service: BoxToolService,
455    /// JSON Schema for the tool's input
456    pub(crate) input_schema: Value,
457}
458
459impl std::fmt::Debug for Tool {
460    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
461        f.debug_struct("Tool")
462            .field("name", &self.name)
463            .field("title", &self.title)
464            .field("description", &self.description)
465            .field("output_schema", &self.output_schema)
466            .field("icons", &self.icons)
467            .field("annotations", &self.annotations)
468            .field("task_support", &self.task_support)
469            .finish_non_exhaustive()
470    }
471}
472
473// SAFETY: BoxCloneService is Send + Sync (tower provides unsafe impl Sync),
474// and all other fields in Tool are Send + Sync.
475unsafe impl Send for Tool {}
476unsafe impl Sync for Tool {}
477
478impl Clone for Tool {
479    fn clone(&self) -> Self {
480        Self {
481            name: self.name.clone(),
482            title: self.title.clone(),
483            description: self.description.clone(),
484            output_schema: self.output_schema.clone(),
485            icons: self.icons.clone(),
486            annotations: self.annotations.clone(),
487            task_support: self.task_support,
488            service: self.service.clone(),
489            input_schema: self.input_schema.clone(),
490        }
491    }
492}
493
494impl Tool {
495    /// Create a new tool builder
496    pub fn builder(name: impl Into<String>) -> ToolBuilder {
497        ToolBuilder::new(name)
498    }
499
500    /// Get the tool definition for tools/list
501    pub fn definition(&self) -> ToolDefinition {
502        let execution = match self.task_support {
503            TaskSupportMode::Forbidden => None,
504            mode => Some(ToolExecution {
505                task_support: Some(mode),
506            }),
507        };
508        ToolDefinition {
509            name: self.name.clone(),
510            title: self.title.clone(),
511            description: self.description.clone(),
512            input_schema: self.input_schema.clone(),
513            output_schema: self.output_schema.clone(),
514            icons: self.icons.clone(),
515            annotations: self.annotations.clone(),
516            execution,
517            meta: None,
518        }
519    }
520
521    /// Call the tool without context
522    ///
523    /// Creates a dummy request context. For full context support, use
524    /// [`call_with_context`](Self::call_with_context).
525    pub fn call(&self, args: Value) -> BoxFuture<'static, CallToolResult> {
526        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
527        self.call_with_context(ctx, args)
528    }
529
530    /// Call the tool with request context
531    ///
532    /// The context provides progress reporting, cancellation support, and
533    /// access to client requests (for sampling, etc.).
534    ///
535    /// # Note
536    ///
537    /// This method returns `CallToolResult` directly (not `Result<CallToolResult>`).
538    /// Any errors from the handler or middleware are converted to
539    /// `CallToolResult::error()` with `is_error: true`.
540    pub fn call_with_context(
541        &self,
542        ctx: RequestContext,
543        args: Value,
544    ) -> BoxFuture<'static, CallToolResult> {
545        use tower::ServiceExt;
546        let service = self.service.clone();
547        Box::pin(async move {
548            // ServiceExt::oneshot properly handles poll_ready before call
549            // Service is Infallible, so unwrap is safe
550            service.oneshot(ToolRequest::new(ctx, args)).await.unwrap()
551        })
552    }
553
554    /// Apply a guard to this built tool.
555    ///
556    /// The guard runs before the handler and can short-circuit with an error.
557    /// This is useful for applying the same guard to multiple tools (per-group
558    /// pattern):
559    ///
560    /// ```rust
561    /// use tower_mcp::{ToolBuilder, CallToolResult};
562    /// use tower_mcp::tool::ToolRequest;
563    /// use schemars::JsonSchema;
564    /// use serde::Deserialize;
565    ///
566    /// #[derive(Debug, Deserialize, JsonSchema)]
567    /// struct Input { value: String }
568    ///
569    /// fn build_tool(name: &str) -> tower_mcp::tool::Tool {
570    ///     ToolBuilder::new(name)
571    ///         .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
572    ///         .build()
573    /// }
574    ///
575    /// let guard = |_req: &ToolRequest| -> Result<(), String> { Ok(()) };
576    ///
577    /// let tools: Vec<_> = vec![build_tool("a"), build_tool("b")]
578    ///     .into_iter()
579    ///     .map(|t| t.with_guard(guard.clone()))
580    ///     .collect();
581    /// ```
582    pub fn with_guard<G>(self, guard: G) -> Self
583    where
584        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
585    {
586        let guarded = GuardService {
587            guard,
588            inner: self.service,
589        };
590        let caught = ToolCatchError::new(guarded);
591        Tool {
592            service: BoxCloneService::new(caught),
593            ..self
594        }
595    }
596
597    /// Create a new tool with a prefixed name.
598    ///
599    /// This creates a copy of the tool with its name prefixed by the given
600    /// string and a dot separator. For example, if the tool is named "query"
601    /// and the prefix is "db", the new tool will be named "db.query".
602    ///
603    /// This is used internally by `McpRouter::nest()` to namespace tools.
604    ///
605    /// # Example
606    ///
607    /// ```rust
608    /// use tower_mcp::{ToolBuilder, CallToolResult};
609    /// use schemars::JsonSchema;
610    /// use serde::Deserialize;
611    ///
612    /// #[derive(Debug, Deserialize, JsonSchema)]
613    /// struct Input { value: String }
614    ///
615    /// let tool = ToolBuilder::new("query")
616    ///     .description("Query the database")
617    ///     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
618    ///     .build();
619    ///
620    /// let prefixed = tool.with_name_prefix("db");
621    /// assert_eq!(prefixed.name, "db.query");
622    /// ```
623    pub fn with_name_prefix(&self, prefix: &str) -> Self {
624        Self {
625            name: format!("{}.{}", prefix, self.name),
626            title: self.title.clone(),
627            description: self.description.clone(),
628            output_schema: self.output_schema.clone(),
629            icons: self.icons.clone(),
630            annotations: self.annotations.clone(),
631            task_support: self.task_support,
632            service: self.service.clone(),
633            input_schema: self.input_schema.clone(),
634        }
635    }
636
637    /// Create a tool from a handler (internal helper)
638    #[allow(clippy::too_many_arguments)]
639    fn from_handler<H: ToolHandler + 'static>(
640        name: String,
641        title: Option<String>,
642        description: Option<String>,
643        output_schema: Option<Value>,
644        icons: Option<Vec<ToolIcon>>,
645        annotations: Option<ToolAnnotations>,
646        task_support: TaskSupportMode,
647        handler: H,
648    ) -> Self {
649        let input_schema = handler.input_schema();
650        let handler_service = ToolHandlerService::new(handler);
651        let catch_error = ToolCatchError::new(handler_service);
652        let service = BoxCloneService::new(catch_error);
653
654        Self {
655            name,
656            title,
657            description,
658            output_schema,
659            icons,
660            annotations,
661            task_support,
662            service,
663            input_schema,
664        }
665    }
666}
667
668// =============================================================================
669// Builder API
670// =============================================================================
671
672/// Builder for creating tools with a fluent API
673///
674/// # Example
675///
676/// ```rust
677/// use tower_mcp::{ToolBuilder, CallToolResult};
678/// use schemars::JsonSchema;
679/// use serde::Deserialize;
680///
681/// #[derive(Debug, Deserialize, JsonSchema)]
682/// struct GreetInput {
683///     name: String,
684/// }
685///
686/// let tool = ToolBuilder::new("greet")
687///     .description("Greet someone by name")
688///     .handler(|input: GreetInput| async move {
689///         Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
690///     })
691///     .build();
692///
693/// assert_eq!(tool.name, "greet");
694/// ```
695pub struct ToolBuilder {
696    name: String,
697    title: Option<String>,
698    description: Option<String>,
699    output_schema: Option<Value>,
700    icons: Option<Vec<ToolIcon>>,
701    annotations: Option<ToolAnnotations>,
702    task_support: TaskSupportMode,
703}
704
705impl ToolBuilder {
706    /// Create a new tool builder with the given name.
707    ///
708    /// Tool names must be 1-128 characters and contain only alphanumeric
709    /// characters, underscores, hyphens, and dots.
710    ///
711    /// Use [`try_new`](Self::try_new) if the name comes from runtime input.
712    ///
713    /// # Panics
714    ///
715    /// Panics if `name` is empty, exceeds 128 characters, or contains
716    /// characters other than ASCII alphanumerics, `_`, `-`, and `.`.
717    pub fn new(name: impl Into<String>) -> Self {
718        let name = name.into();
719        if let Err(e) = validate_tool_name(&name) {
720            panic!("{e}");
721        }
722        Self {
723            name,
724            title: None,
725            description: None,
726            output_schema: None,
727            icons: None,
728            annotations: None,
729            task_support: TaskSupportMode::default(),
730        }
731    }
732
733    /// Create a new tool builder, returning an error if the name is invalid.
734    ///
735    /// This is the fallible alternative to [`new`](Self::new) for cases where
736    /// the tool name comes from runtime input (e.g., user configuration or
737    /// database).
738    pub fn try_new(name: impl Into<String>) -> Result<Self> {
739        let name = name.into();
740        validate_tool_name(&name)?;
741        Ok(Self {
742            name,
743            title: None,
744            description: None,
745            output_schema: None,
746            icons: None,
747            annotations: None,
748            task_support: TaskSupportMode::default(),
749        })
750    }
751
752    /// Set a human-readable title for the tool
753    pub fn title(mut self, title: impl Into<String>) -> Self {
754        self.title = Some(title.into());
755        self
756    }
757
758    /// Set the output schema (JSON Schema for structured output)
759    pub fn output_schema(mut self, schema: Value) -> Self {
760        self.output_schema = Some(schema);
761        self
762    }
763
764    /// Add an icon for the tool
765    pub fn icon(mut self, src: impl Into<String>) -> Self {
766        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
767            src: src.into(),
768            mime_type: None,
769            sizes: None,
770            theme: None,
771        });
772        self
773    }
774
775    /// Add an icon with metadata
776    pub fn icon_with_meta(
777        mut self,
778        src: impl Into<String>,
779        mime_type: Option<String>,
780        sizes: Option<Vec<String>>,
781    ) -> Self {
782        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
783            src: src.into(),
784            mime_type,
785            sizes,
786            theme: None,
787        });
788        self
789    }
790
791    /// Set the tool description
792    pub fn description(mut self, description: impl Into<String>) -> Self {
793        self.description = Some(description.into());
794        self
795    }
796
797    /// Mark the tool as read-only (does not modify state)
798    pub fn read_only(mut self) -> Self {
799        self.annotations
800            .get_or_insert_with(ToolAnnotations::default)
801            .read_only_hint = true;
802        self
803    }
804
805    /// Mark the tool as non-destructive
806    pub fn non_destructive(mut self) -> Self {
807        self.annotations
808            .get_or_insert_with(ToolAnnotations::default)
809            .destructive_hint = false;
810        self
811    }
812
813    /// Mark the tool as destructive (may perform irreversible operations)
814    pub fn destructive(mut self) -> Self {
815        self.annotations
816            .get_or_insert_with(ToolAnnotations::default)
817            .destructive_hint = true;
818        self
819    }
820
821    /// Mark the tool as idempotent (same args = same effect)
822    pub fn idempotent(mut self) -> Self {
823        self.annotations
824            .get_or_insert_with(ToolAnnotations::default)
825            .idempotent_hint = true;
826        self
827    }
828
829    /// Mark the tool as read-only, idempotent, and non-destructive.
830    ///
831    /// This is a convenience method for safe, side-effect-free tools.
832    /// For finer control, use `.read_only()`, `.idempotent()`, and
833    /// `.non_destructive()` individually.
834    pub fn read_only_safe(mut self) -> Self {
835        let ann = self
836            .annotations
837            .get_or_insert_with(ToolAnnotations::default);
838        ann.read_only_hint = true;
839        ann.idempotent_hint = true;
840        ann.destructive_hint = false;
841        self
842    }
843
844    /// Set tool annotations directly
845    pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
846        self.annotations = Some(annotations);
847        self
848    }
849
850    /// Set the task support mode for this tool
851    pub fn task_support(mut self, mode: TaskSupportMode) -> Self {
852        self.task_support = mode;
853        self
854    }
855
856    /// Create a tool that takes no parameters.
857    ///
858    /// This is a convenience method for tools that don't require any input.
859    /// It generates the correct `{"type": "object"}` schema that MCP clients expect.
860    ///
861    /// # Example
862    ///
863    /// ```rust
864    /// use tower_mcp::{ToolBuilder, CallToolResult};
865    ///
866    /// let tool = ToolBuilder::new("get_status")
867    ///     .description("Get current status")
868    ///     .no_params_handler(|| async {
869    ///         Ok(CallToolResult::text("OK"))
870    ///     })
871    ///     .build();
872    /// ```
873    pub fn no_params_handler<F, Fut>(self, handler: F) -> ToolBuilderWithNoParamsHandler<F>
874    where
875        F: Fn() -> Fut + Send + Sync + 'static,
876        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
877    {
878        ToolBuilderWithNoParamsHandler {
879            name: self.name,
880            title: self.title,
881            description: self.description,
882            output_schema: self.output_schema,
883            icons: self.icons,
884            annotations: self.annotations,
885            task_support: self.task_support,
886            handler,
887        }
888    }
889
890    /// Specify input type and handler.
891    ///
892    /// The input type must implement `JsonSchema` and `DeserializeOwned`.
893    /// The handler receives the deserialized input and returns a `CallToolResult`.
894    ///
895    /// # State Sharing
896    ///
897    /// To share state across tool calls (e.g., database connections, API clients),
898    /// wrap your state in an `Arc` and clone it into the async block:
899    ///
900    /// ```rust
901    /// use std::sync::Arc;
902    /// use tower_mcp::{ToolBuilder, CallToolResult};
903    /// use schemars::JsonSchema;
904    /// use serde::Deserialize;
905    ///
906    /// struct AppState {
907    ///     api_key: String,
908    /// }
909    ///
910    /// #[derive(Debug, Deserialize, JsonSchema)]
911    /// struct MyInput {
912    ///     query: String,
913    /// }
914    ///
915    /// let state = Arc::new(AppState { api_key: "secret".to_string() });
916    ///
917    /// let tool = ToolBuilder::new("my_tool")
918    ///     .description("A tool that uses shared state")
919    ///     .handler(move |input: MyInput| {
920    ///         let state = state.clone(); // Clone Arc for the async block
921    ///         async move {
922    ///             // Use state.api_key here...
923    ///             Ok(CallToolResult::text(format!("Query: {}", input.query)))
924    ///         }
925    ///     })
926    ///     .build();
927    /// ```
928    ///
929    /// The `move` keyword on the closure captures the `Arc<AppState>`, and
930    /// cloning it inside the closure body allows each async invocation to
931    /// have its own reference to the shared state.
932    pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
933    where
934        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
935        F: Fn(I) -> Fut + Send + Sync + 'static,
936        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
937    {
938        ToolBuilderWithHandler {
939            name: self.name,
940            title: self.title,
941            description: self.description,
942            output_schema: self.output_schema,
943            icons: self.icons,
944            annotations: self.annotations,
945            task_support: self.task_support,
946            handler,
947            _phantom: std::marker::PhantomData,
948        }
949    }
950
951    /// Create a tool using the extractor pattern.
952    ///
953    /// This method provides an axum-inspired way to define handlers where state,
954    /// context, and input are extracted declaratively from function parameters.
955    /// This reduces the combinatorial explosion of handler variants like
956    /// `handler_with_state`, `handler_with_context`, etc.
957    ///
958    /// # Schema Auto-Detection
959    ///
960    /// When a [`Json<T>`](crate::extract::Json) extractor is used, the proper JSON
961    /// schema is automatically generated from `T`'s `JsonSchema` implementation.
962    /// This means `extractor_handler` produces the same schema as
963    /// `extractor_handler_typed` for the common case, without requiring a turbofish.
964    ///
965    /// # Extractors
966    ///
967    /// Built-in extractors available in [`crate::extract`]:
968    /// - [`Json<T>`](crate::extract::Json) - Deserialize JSON arguments to type `T`
969    /// - [`State<T>`](crate::extract::State) - Extract cloned state
970    /// - [`Extension<T>`](crate::extract::Extension) - Extract router-level state
971    /// - [`Context`](crate::extract::Context) - Extract request context
972    /// - [`RawArgs`](crate::extract::RawArgs) - Extract raw JSON arguments
973    ///
974    /// # Per-Tool Middleware
975    ///
976    /// The returned builder supports `.layer()` to apply Tower middleware:
977    ///
978    /// ```rust
979    /// use std::sync::Arc;
980    /// use std::time::Duration;
981    /// use tower::timeout::TimeoutLayer;
982    /// use tower_mcp::{ToolBuilder, CallToolResult};
983    /// use tower_mcp::extract::{Json, State};
984    /// use schemars::JsonSchema;
985    /// use serde::Deserialize;
986    ///
987    /// #[derive(Clone)]
988    /// struct Database { url: String }
989    ///
990    /// #[derive(Debug, Deserialize, JsonSchema)]
991    /// struct QueryInput { query: String }
992    ///
993    /// let db = Arc::new(Database { url: "postgres://...".to_string() });
994    ///
995    /// let tool = ToolBuilder::new("search")
996    ///     .description("Search the database")
997    ///     .extractor_handler(db, |
998    ///         State(db): State<Arc<Database>>,
999    ///         Json(input): Json<QueryInput>,
1000    ///     | async move {
1001    ///         Ok(CallToolResult::text(format!("Searched {} with: {}", db.url, input.query)))
1002    ///     })
1003    ///     .layer(TimeoutLayer::new(Duration::from_secs(30)))
1004    ///     .build();
1005    /// ```
1006    ///
1007    /// # Example
1008    ///
1009    /// ```rust
1010    /// use std::sync::Arc;
1011    /// use tower_mcp::{ToolBuilder, CallToolResult};
1012    /// use tower_mcp::extract::{Json, State, Context};
1013    /// use schemars::JsonSchema;
1014    /// use serde::Deserialize;
1015    ///
1016    /// #[derive(Clone)]
1017    /// struct Database { url: String }
1018    ///
1019    /// #[derive(Debug, Deserialize, JsonSchema)]
1020    /// struct QueryInput { query: String }
1021    ///
1022    /// let db = Arc::new(Database { url: "postgres://...".to_string() });
1023    ///
1024    /// let tool = ToolBuilder::new("search")
1025    ///     .description("Search the database")
1026    ///     .extractor_handler(db, |
1027    ///         State(db): State<Arc<Database>>,
1028    ///         ctx: Context,
1029    ///         Json(input): Json<QueryInput>,
1030    ///     | async move {
1031    ///         if ctx.is_cancelled() {
1032    ///             return Ok(CallToolResult::error("Cancelled"));
1033    ///         }
1034    ///         ctx.report_progress(0.5, Some(1.0), Some("Searching...")).await;
1035    ///         Ok(CallToolResult::text(format!("Searched {} with: {}", db.url, input.query)))
1036    ///     })
1037    ///     .build();
1038    /// ```
1039    ///
1040    /// # Type Inference
1041    ///
1042    /// The compiler infers extractor types from the function signature. Make sure
1043    /// to annotate the extractor types explicitly in the closure parameters.
1044    pub fn extractor_handler<S, F, T>(
1045        self,
1046        state: S,
1047        handler: F,
1048    ) -> crate::extract::ToolBuilderWithExtractor<S, F, T>
1049    where
1050        S: Clone + Send + Sync + 'static,
1051        F: crate::extract::ExtractorHandler<S, T> + Clone,
1052        T: Send + Sync + 'static,
1053    {
1054        crate::extract::ToolBuilderWithExtractor {
1055            name: self.name,
1056            title: self.title,
1057            description: self.description,
1058            output_schema: self.output_schema,
1059            icons: self.icons,
1060            annotations: self.annotations,
1061            task_support: self.task_support,
1062            state,
1063            handler,
1064            input_schema: F::input_schema(),
1065            _phantom: std::marker::PhantomData,
1066        }
1067    }
1068
1069    /// Create a tool using the extractor pattern with typed JSON input.
1070    ///
1071    /// This is similar to [`extractor_handler`](Self::extractor_handler) but requires
1072    /// an explicit type parameter for the JSON input type via turbofish syntax.
1073    ///
1074    /// Since `extractor_handler` now auto-detects the JSON schema from `Json<T>`
1075    /// extractors, this method is typically unnecessary. It remains available for
1076    /// cases where you need explicit control over the schema type parameter.
1077    ///
1078    /// # Example
1079    ///
1080    /// ```rust
1081    /// use std::sync::Arc;
1082    /// use tower_mcp::{ToolBuilder, CallToolResult};
1083    /// use tower_mcp::extract::{Json, State};
1084    /// use schemars::JsonSchema;
1085    /// use serde::Deserialize;
1086    ///
1087    /// #[derive(Clone)]
1088    /// struct AppState { prefix: String }
1089    ///
1090    /// #[derive(Debug, Deserialize, JsonSchema)]
1091    /// struct GreetInput { name: String }
1092    ///
1093    /// let state = Arc::new(AppState { prefix: "Hello".to_string() });
1094    ///
1095    /// let tool = ToolBuilder::new("greet")
1096    ///     .description("Greet someone")
1097    ///     .extractor_handler_typed::<_, _, _, GreetInput>(state, |
1098    ///         State(app): State<Arc<AppState>>,
1099    ///         Json(input): Json<GreetInput>,
1100    ///     | async move {
1101    ///         Ok(CallToolResult::text(format!("{}, {}!", app.prefix, input.name)))
1102    ///     })
1103    ///     .build();
1104    /// ```
1105    pub fn extractor_handler_typed<S, F, T, I>(
1106        self,
1107        state: S,
1108        handler: F,
1109    ) -> crate::extract::ToolBuilderWithTypedExtractor<S, F, T, I>
1110    where
1111        S: Clone + Send + Sync + 'static,
1112        F: crate::extract::TypedExtractorHandler<S, T, I> + Clone,
1113        T: Send + Sync + 'static,
1114        I: schemars::JsonSchema + Send + Sync + 'static,
1115    {
1116        crate::extract::ToolBuilderWithTypedExtractor {
1117            name: self.name,
1118            title: self.title,
1119            description: self.description,
1120            output_schema: self.output_schema,
1121            icons: self.icons,
1122            annotations: self.annotations,
1123            task_support: self.task_support,
1124            state,
1125            handler,
1126            _phantom: std::marker::PhantomData,
1127        }
1128    }
1129}
1130
1131/// Handler for tools with no parameters.
1132///
1133/// Used internally by [`ToolBuilder::no_params_handler`].
1134struct NoParamsTypedHandler<F> {
1135    handler: F,
1136}
1137
1138impl<F, Fut> ToolHandler for NoParamsTypedHandler<F>
1139where
1140    F: Fn() -> Fut + Send + Sync + 'static,
1141    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1142{
1143    fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1144        Box::pin(async move { (self.handler)().await })
1145    }
1146
1147    fn input_schema(&self) -> Value {
1148        serde_json::json!({ "type": "object" })
1149    }
1150}
1151
/// Builder state after a typed handler has been supplied.
///
/// Produced by [`ToolBuilder::handler`]. From here, call `build` to produce a
/// [`Tool`], or call `layer` / `guard` first to attach per-tool middleware.
pub struct ToolBuilderWithHandler<I, F> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    task_support: TaskSupportMode,
    // The user's async handler; it is called with the deserialized input `I`.
    handler: F,
    // Records the input type `I` without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}
1164
1165/// Builder state for tools with no parameters.
1166///
1167/// Created by [`ToolBuilder::no_params_handler`].
1168pub struct ToolBuilderWithNoParamsHandler<F> {
1169    name: String,
1170    title: Option<String>,
1171    description: Option<String>,
1172    output_schema: Option<Value>,
1173    icons: Option<Vec<ToolIcon>>,
1174    annotations: Option<ToolAnnotations>,
1175    task_support: TaskSupportMode,
1176    handler: F,
1177}
1178
1179impl<F, Fut> ToolBuilderWithNoParamsHandler<F>
1180where
1181    F: Fn() -> Fut + Send + Sync + 'static,
1182    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1183{
1184    /// Build the tool.
1185    pub fn build(self) -> Tool {
1186        Tool::from_handler(
1187            self.name,
1188            self.title,
1189            self.description,
1190            self.output_schema,
1191            self.icons,
1192            self.annotations,
1193            self.task_support,
1194            NoParamsTypedHandler {
1195                handler: self.handler,
1196            },
1197        )
1198    }
1199
1200    /// Apply a Tower layer (middleware) to this tool.
1201    ///
1202    /// See [`ToolBuilderWithHandler::layer`] for details.
1203    pub fn layer<L>(self, layer: L) -> ToolBuilderWithNoParamsHandlerLayer<F, L> {
1204        ToolBuilderWithNoParamsHandlerLayer {
1205            name: self.name,
1206            title: self.title,
1207            description: self.description,
1208            output_schema: self.output_schema,
1209            icons: self.icons,
1210            annotations: self.annotations,
1211            task_support: self.task_support,
1212            handler: self.handler,
1213            layer,
1214        }
1215    }
1216
1217    /// Apply a guard to this tool.
1218    ///
1219    /// See [`ToolBuilderWithHandler::guard`] for details.
1220    pub fn guard<G>(self, guard: G) -> ToolBuilderWithNoParamsHandlerLayer<F, GuardLayer<G>>
1221    where
1222        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1223    {
1224        self.layer(GuardLayer::new(guard))
1225    }
1226}
1227
1228/// Builder state after a layer has been applied to a no-params handler.
1229pub struct ToolBuilderWithNoParamsHandlerLayer<F, L> {
1230    name: String,
1231    title: Option<String>,
1232    description: Option<String>,
1233    output_schema: Option<Value>,
1234    icons: Option<Vec<ToolIcon>>,
1235    annotations: Option<ToolAnnotations>,
1236    task_support: TaskSupportMode,
1237    handler: F,
1238    layer: L,
1239}
1240
1241#[allow(private_bounds)]
1242impl<F, Fut, L> ToolBuilderWithNoParamsHandlerLayer<F, L>
1243where
1244    F: Fn() -> Fut + Send + Sync + 'static,
1245    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1246    L: tower::Layer<ToolHandlerService<NoParamsTypedHandler<F>>> + Clone + Send + Sync + 'static,
1247    L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1248    <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
1249    <L::Service as Service<ToolRequest>>::Future: Send,
1250{
1251    /// Build the tool with the applied layer(s).
1252    pub fn build(self) -> Tool {
1253        let input_schema = serde_json::json!({ "type": "object" });
1254
1255        let handler_service = ToolHandlerService::new(NoParamsTypedHandler {
1256            handler: self.handler,
1257        });
1258        let layered = self.layer.layer(handler_service);
1259        let catch_error = ToolCatchError::new(layered);
1260        let service = BoxCloneService::new(catch_error);
1261
1262        Tool {
1263            name: self.name,
1264            title: self.title,
1265            description: self.description,
1266            output_schema: self.output_schema,
1267            icons: self.icons,
1268            annotations: self.annotations,
1269            task_support: self.task_support,
1270            service,
1271            input_schema,
1272        }
1273    }
1274
1275    /// Apply an additional Tower layer (middleware).
1276    pub fn layer<L2>(
1277        self,
1278        layer: L2,
1279    ) -> ToolBuilderWithNoParamsHandlerLayer<F, tower::layer::util::Stack<L2, L>> {
1280        ToolBuilderWithNoParamsHandlerLayer {
1281            name: self.name,
1282            title: self.title,
1283            description: self.description,
1284            output_schema: self.output_schema,
1285            icons: self.icons,
1286            annotations: self.annotations,
1287            task_support: self.task_support,
1288            handler: self.handler,
1289            layer: tower::layer::util::Stack::new(layer, self.layer),
1290        }
1291    }
1292
1293    /// Apply a guard to this tool.
1294    ///
1295    /// See [`ToolBuilderWithHandler::guard`] for details.
1296    pub fn guard<G>(
1297        self,
1298        guard: G,
1299    ) -> ToolBuilderWithNoParamsHandlerLayer<F, tower::layer::util::Stack<GuardLayer<G>, L>>
1300    where
1301        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1302    {
1303        self.layer(GuardLayer::new(guard))
1304    }
1305}
1306
1307impl<I, F, Fut> ToolBuilderWithHandler<I, F>
1308where
1309    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1310    F: Fn(I) -> Fut + Send + Sync + 'static,
1311    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1312{
1313    /// Build the tool.
1314    pub fn build(self) -> Tool {
1315        Tool::from_handler(
1316            self.name,
1317            self.title,
1318            self.description,
1319            self.output_schema,
1320            self.icons,
1321            self.annotations,
1322            self.task_support,
1323            TypedHandler {
1324                handler: self.handler,
1325                _phantom: std::marker::PhantomData,
1326            },
1327        )
1328    }
1329
1330    /// Apply a Tower layer (middleware) to this tool.
1331    ///
1332    /// The layer wraps the tool's handler service, enabling functionality like
1333    /// timeouts, rate limiting, and metrics collection at the per-tool level.
1334    ///
1335    /// # Example
1336    ///
1337    /// ```rust
1338    /// use std::time::Duration;
1339    /// use tower::timeout::TimeoutLayer;
1340    /// use tower_mcp::{ToolBuilder, CallToolResult};
1341    /// use schemars::JsonSchema;
1342    /// use serde::Deserialize;
1343    ///
1344    /// #[derive(Debug, Deserialize, JsonSchema)]
1345    /// struct Input { query: String }
1346    ///
1347    /// let tool = ToolBuilder::new("search")
1348    ///     .description("Search with timeout")
1349    ///     .handler(|input: Input| async move {
1350    ///         Ok(CallToolResult::text("result"))
1351    ///     })
1352    ///     .layer(TimeoutLayer::new(Duration::from_secs(30)))
1353    ///     .build();
1354    /// ```
1355    pub fn layer<L>(self, layer: L) -> ToolBuilderWithLayer<I, F, L> {
1356        ToolBuilderWithLayer {
1357            name: self.name,
1358            title: self.title,
1359            description: self.description,
1360            output_schema: self.output_schema,
1361            icons: self.icons,
1362            annotations: self.annotations,
1363            task_support: self.task_support,
1364            handler: self.handler,
1365            layer,
1366            _phantom: std::marker::PhantomData,
1367        }
1368    }
1369
1370    /// Apply a guard to this tool.
1371    ///
1372    /// The guard runs before the handler and can short-circuit with an error
1373    /// message. This is syntactic sugar for `.layer(GuardLayer::new(f))`.
1374    ///
1375    /// See [`GuardLayer`] for a full example.
1376    pub fn guard<G>(self, guard: G) -> ToolBuilderWithLayer<I, F, GuardLayer<G>>
1377    where
1378        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1379    {
1380        self.layer(GuardLayer::new(guard))
1381    }
1382}
1383
1384/// Builder state after a layer has been applied to the handler.
1385///
1386/// This builder allows chaining additional layers and building the final tool.
1387pub struct ToolBuilderWithLayer<I, F, L> {
1388    name: String,
1389    title: Option<String>,
1390    description: Option<String>,
1391    output_schema: Option<Value>,
1392    icons: Option<Vec<ToolIcon>>,
1393    annotations: Option<ToolAnnotations>,
1394    task_support: TaskSupportMode,
1395    handler: F,
1396    layer: L,
1397    _phantom: std::marker::PhantomData<I>,
1398}
1399
1400// Allow private_bounds because these internal types (ToolHandlerService, TypedHandler, etc.)
1401// are implementation details that users don't interact with directly.
1402#[allow(private_bounds)]
1403impl<I, F, Fut, L> ToolBuilderWithLayer<I, F, L>
1404where
1405    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1406    F: Fn(I) -> Fut + Send + Sync + 'static,
1407    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1408    L: tower::Layer<ToolHandlerService<TypedHandler<I, F>>> + Clone + Send + Sync + 'static,
1409    L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
1410    <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
1411    <L::Service as Service<ToolRequest>>::Future: Send,
1412{
1413    /// Build the tool with the applied layer(s).
1414    pub fn build(self) -> Tool {
1415        let input_schema = schemars::schema_for!(I);
1416        let input_schema = serde_json::to_value(input_schema)
1417            .unwrap_or_else(|_| serde_json::json!({ "type": "object" }));
1418
1419        let handler_service = ToolHandlerService::new(TypedHandler {
1420            handler: self.handler,
1421            _phantom: std::marker::PhantomData,
1422        });
1423        let layered = self.layer.layer(handler_service);
1424        let catch_error = ToolCatchError::new(layered);
1425        let service = BoxCloneService::new(catch_error);
1426
1427        Tool {
1428            name: self.name,
1429            title: self.title,
1430            description: self.description,
1431            output_schema: self.output_schema,
1432            icons: self.icons,
1433            annotations: self.annotations,
1434            task_support: self.task_support,
1435            service,
1436            input_schema,
1437        }
1438    }
1439
1440    /// Apply an additional Tower layer (middleware).
1441    ///
1442    /// Layers are applied in order, with earlier layers wrapping later ones.
1443    /// This means the first layer added is the outermost middleware.
1444    pub fn layer<L2>(
1445        self,
1446        layer: L2,
1447    ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<L2, L>> {
1448        ToolBuilderWithLayer {
1449            name: self.name,
1450            title: self.title,
1451            description: self.description,
1452            output_schema: self.output_schema,
1453            icons: self.icons,
1454            annotations: self.annotations,
1455            task_support: self.task_support,
1456            handler: self.handler,
1457            layer: tower::layer::util::Stack::new(layer, self.layer),
1458            _phantom: std::marker::PhantomData,
1459        }
1460    }
1461
1462    /// Apply a guard to this tool.
1463    ///
1464    /// See [`ToolBuilderWithHandler::guard`] for details.
1465    pub fn guard<G>(
1466        self,
1467        guard: G,
1468    ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<GuardLayer<G>, L>>
1469    where
1470        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
1471    {
1472        self.layer(GuardLayer::new(guard))
1473    }
1474}
1475
1476// =============================================================================
1477// Handler implementations
1478// =============================================================================
1479
1480/// Handler that deserializes input to a specific type
1481struct TypedHandler<I, F> {
1482    handler: F,
1483    _phantom: std::marker::PhantomData<I>,
1484}
1485
1486impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
1487where
1488    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
1489    F: Fn(I) -> Fut + Send + Sync + 'static,
1490    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
1491{
1492    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1493        Box::pin(async move {
1494            let input: I = serde_json::from_value(args).tool_context("Invalid input")?;
1495            (self.handler)(input).await
1496        })
1497    }
1498
1499    fn input_schema(&self) -> Value {
1500        let schema = schemars::schema_for!(I);
1501        serde_json::to_value(schema).unwrap_or_else(|_| {
1502            serde_json::json!({
1503                "type": "object"
1504            })
1505        })
1506    }
1507}
1508
1509// =============================================================================
1510// Trait-based tool definition
1511// =============================================================================
1512
1513/// Trait for defining tools with full control
1514///
1515/// Implement this trait when you need more control than the builder provides,
1516/// or when you want to define tools as standalone types.
1517///
1518/// # Example
1519///
1520/// ```rust
1521/// use tower_mcp::tool::McpTool;
1522/// use tower_mcp::error::Result;
1523/// use schemars::JsonSchema;
1524/// use serde::{Deserialize, Serialize};
1525///
1526/// #[derive(Debug, Deserialize, JsonSchema)]
1527/// struct AddInput {
1528///     a: i64,
1529///     b: i64,
1530/// }
1531///
1532/// struct AddTool;
1533///
1534/// impl McpTool for AddTool {
1535///     const NAME: &'static str = "add";
1536///     const DESCRIPTION: &'static str = "Add two numbers";
1537///
1538///     type Input = AddInput;
1539///     type Output = i64;
1540///
1541///     async fn call(&self, input: Self::Input) -> Result<Self::Output> {
1542///         Ok(input.a + input.b)
1543///     }
1544/// }
1545///
1546/// let tool = AddTool.into_tool();
1547/// assert_eq!(tool.name, "add");
1548/// ```
1549pub trait McpTool: Send + Sync + 'static {
1550    const NAME: &'static str;
1551    const DESCRIPTION: &'static str;
1552
1553    type Input: JsonSchema + DeserializeOwned + Send;
1554    type Output: Serialize + Send;
1555
1556    fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
1557
1558    /// Optional annotations for the tool
1559    fn annotations(&self) -> Option<ToolAnnotations> {
1560        None
1561    }
1562
1563    /// Convert to a [`Tool`] instance.
1564    ///
1565    /// # Panics
1566    ///
1567    /// Panics if [`NAME`](Self::NAME) is not a valid tool name. Since `NAME`
1568    /// is a `&'static str`, invalid names are caught immediately during
1569    /// development.
1570    fn into_tool(self) -> Tool
1571    where
1572        Self: Sized,
1573    {
1574        if let Err(e) = validate_tool_name(Self::NAME) {
1575            panic!("{e}");
1576        }
1577        let annotations = self.annotations();
1578        let tool = Arc::new(self);
1579        Tool::from_handler(
1580            Self::NAME.to_string(),
1581            None,
1582            Some(Self::DESCRIPTION.to_string()),
1583            None,
1584            None,
1585            annotations,
1586            TaskSupportMode::default(),
1587            McpToolHandler { tool },
1588        )
1589    }
1590}
1591
1592/// Wrapper to make McpTool implement ToolHandler
1593struct McpToolHandler<T: McpTool> {
1594    tool: Arc<T>,
1595}
1596
1597impl<T: McpTool> ToolHandler for McpToolHandler<T> {
1598    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1599        let tool = self.tool.clone();
1600        Box::pin(async move {
1601            let input: T::Input = serde_json::from_value(args).tool_context("Invalid input")?;
1602            let output = tool.call(input).await?;
1603            let value = serde_json::to_value(output).tool_context("Failed to serialize output")?;
1604            Ok(CallToolResult::json(value))
1605        })
1606    }
1607
1608    fn input_schema(&self) -> Value {
1609        let schema = schemars::schema_for!(T::Input);
1610        serde_json::to_value(schema).unwrap_or_else(|_| {
1611            serde_json::json!({
1612                "type": "object"
1613            })
1614        })
1615    }
1616}
1617
1618#[cfg(test)]
1619mod tests {
1620    use super::*;
1621    use crate::extract::{Context, Json, RawArgs, State};
1622    use crate::protocol::Content;
1623    use schemars::JsonSchema;
1624    use serde::Deserialize;
1625
1626    #[derive(Debug, Deserialize, JsonSchema)]
1627    struct GreetInput {
1628        name: String,
1629    }
1630
1631    #[tokio::test]
1632    async fn test_builder_tool() {
1633        let tool = ToolBuilder::new("greet")
1634            .description("Greet someone")
1635            .handler(|input: GreetInput| async move {
1636                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1637            })
1638            .build();
1639
1640        assert_eq!(tool.name, "greet");
1641        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1642
1643        let result = tool.call(serde_json::json!({"name": "World"})).await;
1644
1645        assert!(!result.is_error);
1646    }
1647
1648    #[tokio::test]
1649    async fn test_raw_handler() {
1650        let tool = ToolBuilder::new("echo")
1651            .description("Echo input")
1652            .extractor_handler((), |RawArgs(args): RawArgs| async move {
1653                Ok(CallToolResult::json(args))
1654            })
1655            .build();
1656
1657        let result = tool.call(serde_json::json!({"foo": "bar"})).await;
1658
1659        assert!(!result.is_error);
1660    }
1661
1662    #[test]
1663    fn test_invalid_tool_name_empty() {
1664        let err = ToolBuilder::try_new("").err().expect("should fail");
1665        assert!(err.to_string().contains("cannot be empty"));
1666    }
1667
1668    #[test]
1669    fn test_invalid_tool_name_too_long() {
1670        let long_name = "a".repeat(129);
1671        let err = ToolBuilder::try_new(long_name).err().expect("should fail");
1672        assert!(err.to_string().contains("exceeds maximum"));
1673    }
1674
1675    #[test]
1676    fn test_invalid_tool_name_bad_chars() {
1677        let err = ToolBuilder::try_new("my tool!").err().expect("should fail");
1678        assert!(err.to_string().contains("invalid character"));
1679    }
1680
1681    #[test]
1682    #[should_panic(expected = "cannot be empty")]
1683    fn test_new_panics_on_empty_name() {
1684        ToolBuilder::new("");
1685    }
1686
1687    #[test]
1688    #[should_panic(expected = "exceeds maximum")]
1689    fn test_new_panics_on_too_long_name() {
1690        ToolBuilder::new("a".repeat(129));
1691    }
1692
1693    #[test]
1694    #[should_panic(expected = "invalid character")]
1695    fn test_new_panics_on_invalid_chars() {
1696        ToolBuilder::new("my tool!");
1697    }
1698
1699    #[test]
1700    fn test_valid_tool_names() {
1701        // All valid characters
1702        let names = [
1703            "my_tool",
1704            "my-tool",
1705            "my.tool",
1706            "MyTool123",
1707            "a",
1708            &"a".repeat(128),
1709        ];
1710        for name in names {
1711            assert!(
1712                ToolBuilder::try_new(name).is_ok(),
1713                "Expected '{}' to be valid",
1714                name
1715            );
1716        }
1717    }
1718
1719    #[tokio::test]
1720    async fn test_context_aware_handler() {
1721        use crate::context::notification_channel;
1722        use crate::protocol::{ProgressToken, RequestId};
1723
1724        #[derive(Debug, Deserialize, JsonSchema)]
1725        struct ProcessInput {
1726            count: i32,
1727        }
1728
1729        let tool = ToolBuilder::new("process")
1730            .description("Process with context")
1731            .extractor_handler(
1732                (),
1733                |ctx: Context, Json(input): Json<ProcessInput>| async move {
1734                    // Simulate progress reporting
1735                    for i in 0..input.count {
1736                        if ctx.is_cancelled() {
1737                            return Ok(CallToolResult::error("Cancelled"));
1738                        }
1739                        ctx.report_progress(i as f64, Some(input.count as f64), None)
1740                            .await;
1741                    }
1742                    Ok(CallToolResult::text(format!(
1743                        "Processed {} items",
1744                        input.count
1745                    )))
1746                },
1747            )
1748            .build();
1749
1750        assert_eq!(tool.name, "process");
1751
1752        // Test with a context that has progress token and notification sender
1753        let (tx, mut rx) = notification_channel(10);
1754        let ctx = RequestContext::new(RequestId::Number(1))
1755            .with_progress_token(ProgressToken::Number(42))
1756            .with_notification_sender(tx);
1757
1758        let result = tool
1759            .call_with_context(ctx, serde_json::json!({"count": 3}))
1760            .await;
1761
1762        assert!(!result.is_error);
1763
1764        // Check that progress notifications were sent
1765        let mut progress_count = 0;
1766        while rx.try_recv().is_ok() {
1767            progress_count += 1;
1768        }
1769        assert_eq!(progress_count, 3);
1770    }
1771
1772    #[tokio::test]
1773    async fn test_context_aware_handler_cancellation() {
1774        use crate::protocol::RequestId;
1775        use std::sync::atomic::{AtomicI32, Ordering};
1776
1777        #[derive(Debug, Deserialize, JsonSchema)]
1778        struct LongRunningInput {
1779            iterations: i32,
1780        }
1781
1782        let iterations_completed = Arc::new(AtomicI32::new(0));
1783        let iterations_ref = iterations_completed.clone();
1784
1785        let tool = ToolBuilder::new("long_running")
1786            .description("Long running task")
1787            .extractor_handler(
1788                (),
1789                move |ctx: Context, Json(input): Json<LongRunningInput>| {
1790                    let completed = iterations_ref.clone();
1791                    async move {
1792                        for i in 0..input.iterations {
1793                            if ctx.is_cancelled() {
1794                                return Ok(CallToolResult::error("Cancelled"));
1795                            }
1796                            completed.fetch_add(1, Ordering::SeqCst);
1797                            // Simulate work
1798                            tokio::task::yield_now().await;
1799                            // Cancel after iteration 2
1800                            if i == 2 {
1801                                ctx.cancellation_token().cancel();
1802                            }
1803                        }
1804                        Ok(CallToolResult::text("Done"))
1805                    }
1806                },
1807            )
1808            .build();
1809
1810        let ctx = RequestContext::new(RequestId::Number(1));
1811
1812        let result = tool
1813            .call_with_context(ctx, serde_json::json!({"iterations": 10}))
1814            .await;
1815
1816        // Should have been cancelled after 3 iterations (0, 1, 2)
1817        // The next iteration (3) checks cancellation and returns
1818        assert!(result.is_error);
1819        assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
1820    }
1821
1822    #[tokio::test]
1823    async fn test_tool_builder_with_enhanced_fields() {
1824        let output_schema = serde_json::json!({
1825            "type": "object",
1826            "properties": {
1827                "greeting": {"type": "string"}
1828            }
1829        });
1830
1831        let tool = ToolBuilder::new("greet")
1832            .title("Greeting Tool")
1833            .description("Greet someone")
1834            .output_schema(output_schema.clone())
1835            .icon("https://example.com/icon.png")
1836            .icon_with_meta(
1837                "https://example.com/icon-large.png",
1838                Some("image/png".to_string()),
1839                Some(vec!["96x96".to_string()]),
1840            )
1841            .handler(|input: GreetInput| async move {
1842                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1843            })
1844            .build();
1845
1846        assert_eq!(tool.name, "greet");
1847        assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
1848        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1849        assert_eq!(tool.output_schema, Some(output_schema));
1850        assert!(tool.icons.is_some());
1851        assert_eq!(tool.icons.as_ref().unwrap().len(), 2);
1852
1853        // Test definition includes new fields
1854        let def = tool.definition();
1855        assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
1856        assert!(def.output_schema.is_some());
1857        assert!(def.icons.is_some());
1858    }
1859
1860    #[tokio::test]
1861    async fn test_handler_with_state() {
1862        let shared = Arc::new("shared-state".to_string());
1863
1864        let tool = ToolBuilder::new("stateful")
1865            .description("Uses shared state")
1866            .extractor_handler(
1867                shared,
1868                |State(state): State<Arc<String>>, Json(input): Json<GreetInput>| async move {
1869                    Ok(CallToolResult::text(format!(
1870                        "{}: Hello, {}!",
1871                        state, input.name
1872                    )))
1873                },
1874            )
1875            .build();
1876
1877        let result = tool.call(serde_json::json!({"name": "World"})).await;
1878        assert!(!result.is_error);
1879    }
1880
1881    #[tokio::test]
1882    async fn test_handler_with_state_and_context() {
1883        use crate::protocol::RequestId;
1884
1885        let shared = Arc::new(42_i32);
1886
1887        let tool =
1888            ToolBuilder::new("stateful_ctx")
1889                .description("Uses state and context")
1890                .extractor_handler(
1891                    shared,
1892                    |State(state): State<Arc<i32>>,
1893                     _ctx: Context,
1894                     Json(input): Json<GreetInput>| async move {
1895                        Ok(CallToolResult::text(format!(
1896                            "{}: Hello, {}!",
1897                            state, input.name
1898                        )))
1899                    },
1900                )
1901                .build();
1902
1903        let ctx = RequestContext::new(RequestId::Number(1));
1904        let result = tool
1905            .call_with_context(ctx, serde_json::json!({"name": "World"}))
1906            .await;
1907        assert!(!result.is_error);
1908    }
1909
1910    #[tokio::test]
1911    async fn test_handler_no_params() {
1912        let tool = ToolBuilder::new("no_params")
1913            .description("Takes no parameters")
1914            .extractor_handler((), |Json(_): Json<NoParams>| async {
1915                Ok(CallToolResult::text("no params result"))
1916            })
1917            .build();
1918
1919        assert_eq!(tool.name, "no_params");
1920
1921        // Should work with empty args
1922        let result = tool.call(serde_json::json!({})).await;
1923        assert!(!result.is_error);
1924
1925        // Should also work with unexpected args (ignored)
1926        let result = tool.call(serde_json::json!({"unexpected": "value"})).await;
1927        assert!(!result.is_error);
1928
1929        // Check input schema includes type: object
1930        let schema = tool.definition().input_schema;
1931        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1932    }
1933
1934    #[tokio::test]
1935    async fn test_handler_with_state_no_params() {
1936        let shared = Arc::new("shared_value".to_string());
1937
1938        let tool = ToolBuilder::new("with_state_no_params")
1939            .description("Takes no parameters but has state")
1940            .extractor_handler(
1941                shared,
1942                |State(state): State<Arc<String>>, Json(_): Json<NoParams>| async move {
1943                    Ok(CallToolResult::text(format!("state: {}", state)))
1944                },
1945            )
1946            .build();
1947
1948        assert_eq!(tool.name, "with_state_no_params");
1949
1950        // Should work with empty args
1951        let result = tool.call(serde_json::json!({})).await;
1952        assert!(!result.is_error);
1953        assert_eq!(result.first_text().unwrap(), "state: shared_value");
1954
1955        // Check input schema includes type: object
1956        let schema = tool.definition().input_schema;
1957        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1958    }
1959
1960    #[tokio::test]
1961    async fn test_handler_no_params_with_context() {
1962        let tool = ToolBuilder::new("no_params_with_context")
1963            .description("Takes no parameters but has context")
1964            .extractor_handler((), |_ctx: Context, Json(_): Json<NoParams>| async move {
1965                Ok(CallToolResult::text("context available"))
1966            })
1967            .build();
1968
1969        assert_eq!(tool.name, "no_params_with_context");
1970
1971        let result = tool.call(serde_json::json!({})).await;
1972        assert!(!result.is_error);
1973        assert_eq!(result.first_text().unwrap(), "context available");
1974    }
1975
1976    #[tokio::test]
1977    async fn test_handler_with_state_and_context_no_params() {
1978        let shared = Arc::new("shared".to_string());
1979
1980        let tool = ToolBuilder::new("state_context_no_params")
1981            .description("Has state and context, no params")
1982            .extractor_handler(
1983                shared,
1984                |State(state): State<Arc<String>>,
1985                 _ctx: Context,
1986                 Json(_): Json<NoParams>| async move {
1987                    Ok(CallToolResult::text(format!("state: {}", state)))
1988                },
1989            )
1990            .build();
1991
1992        assert_eq!(tool.name, "state_context_no_params");
1993
1994        let result = tool.call(serde_json::json!({})).await;
1995        assert!(!result.is_error);
1996        assert_eq!(result.first_text().unwrap(), "state: shared");
1997    }
1998
1999    #[tokio::test]
2000    async fn test_raw_handler_with_state() {
2001        let prefix = Arc::new("prefix:".to_string());
2002
2003        let tool = ToolBuilder::new("raw_with_state")
2004            .description("Raw handler with state")
2005            .extractor_handler(
2006                prefix,
2007                |State(state): State<Arc<String>>, RawArgs(args): RawArgs| async move {
2008                    Ok(CallToolResult::text(format!("{} {}", state, args)))
2009                },
2010            )
2011            .build();
2012
2013        assert_eq!(tool.name, "raw_with_state");
2014
2015        let result = tool.call(serde_json::json!({"key": "value"})).await;
2016        assert!(!result.is_error);
2017        assert!(result.first_text().unwrap().starts_with("prefix:"));
2018    }
2019
2020    #[tokio::test]
2021    async fn test_raw_handler_with_state_and_context() {
2022        let prefix = Arc::new("prefix:".to_string());
2023
2024        let tool = ToolBuilder::new("raw_state_context")
2025            .description("Raw handler with state and context")
2026            .extractor_handler(
2027                prefix,
2028                |State(state): State<Arc<String>>,
2029                 _ctx: Context,
2030                 RawArgs(args): RawArgs| async move {
2031                    Ok(CallToolResult::text(format!("{} {}", state, args)))
2032                },
2033            )
2034            .build();
2035
2036        assert_eq!(tool.name, "raw_state_context");
2037
2038        let result = tool.call(serde_json::json!({"key": "value"})).await;
2039        assert!(!result.is_error);
2040        assert!(result.first_text().unwrap().starts_with("prefix:"));
2041    }
2042
2043    #[tokio::test]
2044    async fn test_tool_with_timeout_layer() {
2045        use std::time::Duration;
2046        use tower::timeout::TimeoutLayer;
2047
2048        #[derive(Debug, Deserialize, JsonSchema)]
2049        struct SlowInput {
2050            delay_ms: u64,
2051        }
2052
2053        // Create a tool with a short timeout
2054        let tool = ToolBuilder::new("slow_tool")
2055            .description("A slow tool")
2056            .handler(|input: SlowInput| async move {
2057                tokio::time::sleep(Duration::from_millis(input.delay_ms)).await;
2058                Ok(CallToolResult::text("completed"))
2059            })
2060            .layer(TimeoutLayer::new(Duration::from_millis(50)))
2061            .build();
2062
2063        // Fast call should succeed
2064        let result = tool.call(serde_json::json!({"delay_ms": 10})).await;
2065        assert!(!result.is_error);
2066        assert_eq!(result.first_text().unwrap(), "completed");
2067
2068        // Slow call should timeout and return an error result
2069        let result = tool.call(serde_json::json!({"delay_ms": 200})).await;
2070        assert!(result.is_error);
2071        // Tower's timeout error message is "request timed out"
2072        let msg = result.first_text().unwrap().to_lowercase();
2073        assert!(
2074            msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
2075            "Expected timeout error, got: {}",
2076            msg
2077        );
2078    }
2079
2080    #[tokio::test]
2081    async fn test_tool_with_concurrency_limit_layer() {
2082        use std::sync::atomic::{AtomicU32, Ordering};
2083        use std::time::Duration;
2084        use tower::limit::ConcurrencyLimitLayer;
2085
2086        #[derive(Debug, Deserialize, JsonSchema)]
2087        struct WorkInput {
2088            id: u32,
2089        }
2090
2091        let max_concurrent = Arc::new(AtomicU32::new(0));
2092        let current_concurrent = Arc::new(AtomicU32::new(0));
2093        let max_ref = max_concurrent.clone();
2094        let current_ref = current_concurrent.clone();
2095
2096        // Create a tool with concurrency limit of 2
2097        let tool = ToolBuilder::new("concurrent_tool")
2098            .description("A concurrent tool")
2099            .handler(move |input: WorkInput| {
2100                let max = max_ref.clone();
2101                let current = current_ref.clone();
2102                async move {
2103                    // Track concurrency
2104                    let prev = current.fetch_add(1, Ordering::SeqCst);
2105                    max.fetch_max(prev + 1, Ordering::SeqCst);
2106
2107                    // Simulate work
2108                    tokio::time::sleep(Duration::from_millis(50)).await;
2109
2110                    current.fetch_sub(1, Ordering::SeqCst);
2111                    Ok(CallToolResult::text(format!("completed {}", input.id)))
2112                }
2113            })
2114            .layer(ConcurrencyLimitLayer::new(2))
2115            .build();
2116
2117        // Launch 4 concurrent calls
2118        let handles: Vec<_> = (0..4)
2119            .map(|i| {
2120                let t = tool.call(serde_json::json!({"id": i}));
2121                tokio::spawn(t)
2122            })
2123            .collect();
2124
2125        for handle in handles {
2126            let result = handle.await.unwrap();
2127            assert!(!result.is_error);
2128        }
2129
2130        // Max concurrent should not exceed 2
2131        assert!(max_concurrent.load(Ordering::SeqCst) <= 2);
2132    }
2133
2134    #[tokio::test]
2135    async fn test_tool_with_multiple_layers() {
2136        use std::time::Duration;
2137        use tower::limit::ConcurrencyLimitLayer;
2138        use tower::timeout::TimeoutLayer;
2139
2140        #[derive(Debug, Deserialize, JsonSchema)]
2141        struct Input {
2142            value: String,
2143        }
2144
2145        // Create a tool with multiple layers stacked
2146        let tool = ToolBuilder::new("multi_layer_tool")
2147            .description("Tool with multiple layers")
2148            .handler(|input: Input| async move {
2149                Ok(CallToolResult::text(format!("processed: {}", input.value)))
2150            })
2151            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2152            .layer(ConcurrencyLimitLayer::new(10))
2153            .build();
2154
2155        let result = tool.call(serde_json::json!({"value": "test"})).await;
2156        assert!(!result.is_error);
2157        assert_eq!(result.first_text().unwrap(), "processed: test");
2158    }
2159
2160    #[test]
2161    fn test_tool_catch_error_clone() {
2162        // ToolCatchError should be Clone when inner is Clone
2163        // Use a simple tool that we can clone
2164        let tool = ToolBuilder::new("test")
2165            .description("test")
2166            .extractor_handler((), |RawArgs(_args): RawArgs| async {
2167                Ok(CallToolResult::text("ok"))
2168            })
2169            .build();
2170        // The tool contains a BoxToolService which is cloneable
2171        let _clone = tool.call(serde_json::json!({}));
2172    }
2173
2174    #[test]
2175    fn test_tool_catch_error_debug() {
2176        // ToolCatchError implements Debug when inner implements Debug
2177        // Since our internal services don't require Debug, just verify
2178        // that ToolCatchError has a Debug impl for appropriate types
2179        #[derive(Debug, Clone)]
2180        struct DebugService;
2181
2182        impl Service<ToolRequest> for DebugService {
2183            type Response = CallToolResult;
2184            type Error = crate::error::Error;
2185            type Future = Pin<
2186                Box<
2187                    dyn Future<Output = std::result::Result<CallToolResult, crate::error::Error>>
2188                        + Send,
2189                >,
2190            >;
2191
2192            fn poll_ready(
2193                &mut self,
2194                _cx: &mut std::task::Context<'_>,
2195            ) -> Poll<std::result::Result<(), Self::Error>> {
2196                Poll::Ready(Ok(()))
2197            }
2198
2199            fn call(&mut self, _req: ToolRequest) -> Self::Future {
2200                Box::pin(async { Ok(CallToolResult::text("ok")) })
2201            }
2202        }
2203
2204        let catch_error = ToolCatchError::new(DebugService);
2205        let debug = format!("{:?}", catch_error);
2206        assert!(debug.contains("ToolCatchError"));
2207    }
2208
2209    #[test]
2210    fn test_tool_request_new() {
2211        use crate::protocol::RequestId;
2212
2213        let ctx = RequestContext::new(RequestId::Number(42));
2214        let args = serde_json::json!({"key": "value"});
2215        let req = ToolRequest::new(ctx.clone(), args.clone());
2216
2217        assert_eq!(req.args, args);
2218    }
2219
2220    #[test]
2221    fn test_no_params_schema() {
2222        // NoParams should produce a schema with type: "object"
2223        let schema = schemars::schema_for!(NoParams);
2224        let schema_value = serde_json::to_value(&schema).unwrap();
2225        assert_eq!(
2226            schema_value.get("type").and_then(|v| v.as_str()),
2227            Some("object"),
2228            "NoParams should generate type: object schema"
2229        );
2230    }
2231
2232    #[test]
2233    fn test_no_params_deserialize() {
2234        // NoParams should deserialize from various inputs
2235        let from_empty_object: NoParams = serde_json::from_str("{}").unwrap();
2236        assert_eq!(from_empty_object, NoParams);
2237
2238        let from_null: NoParams = serde_json::from_str("null").unwrap();
2239        assert_eq!(from_null, NoParams);
2240
2241        // Should also accept objects with unexpected fields (ignored)
2242        let from_object_with_fields: NoParams =
2243            serde_json::from_str(r#"{"unexpected": "value"}"#).unwrap();
2244        assert_eq!(from_object_with_fields, NoParams);
2245    }
2246
2247    #[tokio::test]
2248    async fn test_no_params_type_in_handler() {
2249        // NoParams can be used as a handler input type
2250        let tool = ToolBuilder::new("status")
2251            .description("Get status")
2252            .handler(|_input: NoParams| async move { Ok(CallToolResult::text("OK")) })
2253            .build();
2254
2255        // Check schema has type: object (not type: null like () would produce)
2256        let schema = tool.definition().input_schema;
2257        assert_eq!(
2258            schema.get("type").and_then(|v| v.as_str()),
2259            Some("object"),
2260            "NoParams handler should produce type: object schema"
2261        );
2262
2263        // Should work with empty input
2264        let result = tool.call(serde_json::json!({})).await;
2265        assert!(!result.is_error);
2266    }
2267
2268    #[tokio::test]
2269    async fn test_tool_with_name_prefix() {
2270        #[derive(Debug, Deserialize, JsonSchema)]
2271        struct Input {
2272            value: String,
2273        }
2274
2275        let tool = ToolBuilder::new("query")
2276            .description("Query something")
2277            .title("Query Tool")
2278            .handler(|input: Input| async move { Ok(CallToolResult::text(&input.value)) })
2279            .build();
2280
2281        // Create prefixed version
2282        let prefixed = tool.with_name_prefix("db");
2283
2284        // Check name is prefixed
2285        assert_eq!(prefixed.name, "db.query");
2286
2287        // Check other fields are preserved
2288        assert_eq!(prefixed.description.as_deref(), Some("Query something"));
2289        assert_eq!(prefixed.title.as_deref(), Some("Query Tool"));
2290
2291        // Check the tool still works
2292        let result = prefixed
2293            .call(serde_json::json!({"value": "test input"}))
2294            .await;
2295        assert!(!result.is_error);
2296        match &result.content[0] {
2297            Content::Text { text, .. } => assert_eq!(text, "test input"),
2298            _ => panic!("Expected text content"),
2299        }
2300    }
2301
2302    #[tokio::test]
2303    async fn test_tool_with_name_prefix_multiple_levels() {
2304        let tool = ToolBuilder::new("action")
2305            .description("Do something")
2306            .handler(|_: NoParams| async move { Ok(CallToolResult::text("done")) })
2307            .build();
2308
2309        // Apply multiple prefixes
2310        let prefixed = tool.with_name_prefix("level1");
2311        assert_eq!(prefixed.name, "level1.action");
2312
2313        let double_prefixed = prefixed.with_name_prefix("level0");
2314        assert_eq!(double_prefixed.name, "level0.level1.action");
2315    }
2316
2317    // =============================================================================
2318    // no_params_handler tests
2319    // =============================================================================
2320
2321    #[tokio::test]
2322    async fn test_no_params_handler_basic() {
2323        let tool = ToolBuilder::new("get_status")
2324            .description("Get current status")
2325            .no_params_handler(|| async { Ok(CallToolResult::text("OK")) })
2326            .build();
2327
2328        assert_eq!(tool.name, "get_status");
2329        assert_eq!(tool.description.as_deref(), Some("Get current status"));
2330
2331        // Should work with empty args
2332        let result = tool.call(serde_json::json!({})).await;
2333        assert!(!result.is_error);
2334        assert_eq!(result.first_text().unwrap(), "OK");
2335
2336        // Should also work with null args
2337        let result = tool.call(serde_json::json!(null)).await;
2338        assert!(!result.is_error);
2339
2340        // Check input schema has type: object
2341        let schema = tool.definition().input_schema;
2342        assert_eq!(schema.get("type").and_then(|v| v.as_str()), Some("object"));
2343    }
2344
2345    #[tokio::test]
2346    async fn test_no_params_handler_with_captured_state() {
2347        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
2348        let counter_ref = counter.clone();
2349
2350        let tool = ToolBuilder::new("increment")
2351            .description("Increment counter")
2352            .no_params_handler(move || {
2353                let c = counter_ref.clone();
2354                async move {
2355                    let prev = c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2356                    Ok(CallToolResult::text(format!("Incremented from {}", prev)))
2357                }
2358            })
2359            .build();
2360
2361        // Call multiple times
2362        let _ = tool.call(serde_json::json!({})).await;
2363        let _ = tool.call(serde_json::json!({})).await;
2364        let result = tool.call(serde_json::json!({})).await;
2365
2366        assert!(!result.is_error);
2367        assert_eq!(result.first_text().unwrap(), "Incremented from 2");
2368        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 3);
2369    }
2370
2371    #[tokio::test]
2372    async fn test_no_params_handler_with_layer() {
2373        use std::time::Duration;
2374        use tower::timeout::TimeoutLayer;
2375
2376        let tool = ToolBuilder::new("slow_status")
2377            .description("Slow status check")
2378            .no_params_handler(|| async {
2379                tokio::time::sleep(Duration::from_millis(10)).await;
2380                Ok(CallToolResult::text("done"))
2381            })
2382            .layer(TimeoutLayer::new(Duration::from_secs(1)))
2383            .build();
2384
2385        let result = tool.call(serde_json::json!({})).await;
2386        assert!(!result.is_error);
2387        assert_eq!(result.first_text().unwrap(), "done");
2388    }
2389
2390    #[tokio::test]
2391    async fn test_no_params_handler_timeout() {
2392        use std::time::Duration;
2393        use tower::timeout::TimeoutLayer;
2394
2395        let tool = ToolBuilder::new("very_slow_status")
2396            .description("Very slow status check")
2397            .no_params_handler(|| async {
2398                tokio::time::sleep(Duration::from_millis(200)).await;
2399                Ok(CallToolResult::text("done"))
2400            })
2401            .layer(TimeoutLayer::new(Duration::from_millis(50)))
2402            .build();
2403
2404        let result = tool.call(serde_json::json!({})).await;
2405        assert!(result.is_error);
2406        let msg = result.first_text().unwrap().to_lowercase();
2407        assert!(
2408            msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
2409            "Expected timeout error, got: {}",
2410            msg
2411        );
2412    }
2413
2414    #[tokio::test]
2415    async fn test_no_params_handler_with_multiple_layers() {
2416        use std::time::Duration;
2417        use tower::limit::ConcurrencyLimitLayer;
2418        use tower::timeout::TimeoutLayer;
2419
2420        let tool = ToolBuilder::new("multi_layer_status")
2421            .description("Status with multiple layers")
2422            .no_params_handler(|| async { Ok(CallToolResult::text("status ok")) })
2423            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2424            .layer(ConcurrencyLimitLayer::new(10))
2425            .build();
2426
2427        let result = tool.call(serde_json::json!({})).await;
2428        assert!(!result.is_error);
2429        assert_eq!(result.first_text().unwrap(), "status ok");
2430    }
2431
2432    // =========================================================================
2433    // Guard tests
2434    // =========================================================================
2435
2436    #[tokio::test]
2437    async fn test_guard_allows_request() {
2438        #[derive(Debug, Deserialize, JsonSchema)]
2439        #[allow(dead_code)]
2440        struct DeleteInput {
2441            id: String,
2442            confirm: bool,
2443        }
2444
2445        let tool = ToolBuilder::new("delete")
2446            .description("Delete a record")
2447            .handler(|input: DeleteInput| async move {
2448                Ok(CallToolResult::text(format!("deleted {}", input.id)))
2449            })
2450            .guard(|req: &ToolRequest| {
2451                let confirm = req
2452                    .args
2453                    .get("confirm")
2454                    .and_then(|v| v.as_bool())
2455                    .unwrap_or(false);
2456                if !confirm {
2457                    return Err("Must set confirm=true to delete".to_string());
2458                }
2459                Ok(())
2460            })
2461            .build();
2462
2463        let result = tool
2464            .call(serde_json::json!({"id": "abc", "confirm": true}))
2465            .await;
2466        assert!(!result.is_error);
2467        assert_eq!(result.first_text().unwrap(), "deleted abc");
2468    }
2469
2470    #[tokio::test]
2471    async fn test_guard_rejects_request() {
2472        #[derive(Debug, Deserialize, JsonSchema)]
2473        #[allow(dead_code)]
2474        struct DeleteInput2 {
2475            id: String,
2476            confirm: bool,
2477        }
2478
2479        let tool = ToolBuilder::new("delete2")
2480            .description("Delete a record")
2481            .handler(|input: DeleteInput2| async move {
2482                Ok(CallToolResult::text(format!("deleted {}", input.id)))
2483            })
2484            .guard(|req: &ToolRequest| {
2485                let confirm = req
2486                    .args
2487                    .get("confirm")
2488                    .and_then(|v| v.as_bool())
2489                    .unwrap_or(false);
2490                if !confirm {
2491                    return Err("Must set confirm=true to delete".to_string());
2492                }
2493                Ok(())
2494            })
2495            .build();
2496
2497        let result = tool
2498            .call(serde_json::json!({"id": "abc", "confirm": false}))
2499            .await;
2500        assert!(result.is_error);
2501        assert!(
2502            result
2503                .first_text()
2504                .unwrap()
2505                .contains("Must set confirm=true")
2506        );
2507    }
2508
2509    #[tokio::test]
2510    async fn test_guard_with_layer() {
2511        use std::time::Duration;
2512        use tower::timeout::TimeoutLayer;
2513
2514        let tool = ToolBuilder::new("guarded_timeout")
2515            .description("Guarded with timeout")
2516            .handler(|input: GreetInput| async move {
2517                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
2518            })
2519            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2520            .guard(|_req: &ToolRequest| Ok(()))
2521            .build();
2522
2523        let result = tool.call(serde_json::json!({"name": "World"})).await;
2524        assert!(!result.is_error);
2525        assert_eq!(result.first_text().unwrap(), "Hello, World!");
2526    }
2527
2528    #[tokio::test]
2529    async fn test_guard_on_no_params_handler() {
2530        let allowed = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(true));
2531        let allowed_clone = allowed.clone();
2532
2533        let tool = ToolBuilder::new("status")
2534            .description("Get status")
2535            .no_params_handler(|| async { Ok(CallToolResult::text("ok")) })
2536            .guard(move |_req: &ToolRequest| {
2537                if allowed_clone.load(std::sync::atomic::Ordering::Relaxed) {
2538                    Ok(())
2539                } else {
2540                    Err("Access denied".to_string())
2541                }
2542            })
2543            .build();
2544
2545        // Allowed
2546        let result = tool.call(serde_json::json!({})).await;
2547        assert!(!result.is_error);
2548        assert_eq!(result.first_text().unwrap(), "ok");
2549
2550        // Denied
2551        allowed.store(false, std::sync::atomic::Ordering::Relaxed);
2552        let result = tool.call(serde_json::json!({})).await;
2553        assert!(result.is_error);
2554        assert!(result.first_text().unwrap().contains("Access denied"));
2555    }
2556
2557    #[tokio::test]
2558    async fn test_guard_on_no_params_handler_with_layer() {
2559        use std::time::Duration;
2560        use tower::timeout::TimeoutLayer;
2561
2562        let tool = ToolBuilder::new("status_layered")
2563            .description("Get status with layers")
2564            .no_params_handler(|| async { Ok(CallToolResult::text("ok")) })
2565            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2566            .guard(|_req: &ToolRequest| Ok(()))
2567            .build();
2568
2569        let result = tool.call(serde_json::json!({})).await;
2570        assert!(!result.is_error);
2571        assert_eq!(result.first_text().unwrap(), "ok");
2572    }
2573
2574    #[tokio::test]
2575    async fn test_guard_on_extractor_handler() {
2576        use std::sync::Arc;
2577
2578        #[derive(Clone)]
2579        struct AppState {
2580            prefix: String,
2581        }
2582
2583        #[derive(Debug, Deserialize, JsonSchema)]
2584        struct QueryInput {
2585            query: String,
2586        }
2587
2588        let state = Arc::new(AppState {
2589            prefix: "db".to_string(),
2590        });
2591
2592        let tool = ToolBuilder::new("search")
2593            .description("Search")
2594            .extractor_handler(
2595                state,
2596                |State(app): State<Arc<AppState>>, Json(input): Json<QueryInput>| async move {
2597                    Ok(CallToolResult::text(format!(
2598                        "{}: {}",
2599                        app.prefix, input.query
2600                    )))
2601                },
2602            )
2603            .guard(|req: &ToolRequest| {
2604                let query = req.args.get("query").and_then(|v| v.as_str()).unwrap_or("");
2605                if query.is_empty() {
2606                    return Err("Query cannot be empty".to_string());
2607                }
2608                Ok(())
2609            })
2610            .build();
2611
2612        // Valid query
2613        let result = tool.call(serde_json::json!({"query": "hello"})).await;
2614        assert!(!result.is_error);
2615        assert_eq!(result.first_text().unwrap(), "db: hello");
2616
2617        // Empty query rejected by guard
2618        let result = tool.call(serde_json::json!({"query": ""})).await;
2619        assert!(result.is_error);
2620        assert!(
2621            result
2622                .first_text()
2623                .unwrap()
2624                .contains("Query cannot be empty")
2625        );
2626    }
2627
2628    #[tokio::test]
2629    async fn test_guard_on_extractor_handler_with_layer() {
2630        use std::sync::Arc;
2631        use std::time::Duration;
2632        use tower::timeout::TimeoutLayer;
2633
2634        #[derive(Clone)]
2635        struct AppState2 {
2636            prefix: String,
2637        }
2638
2639        #[derive(Debug, Deserialize, JsonSchema)]
2640        struct QueryInput2 {
2641            query: String,
2642        }
2643
2644        let state = Arc::new(AppState2 {
2645            prefix: "db".to_string(),
2646        });
2647
2648        let tool = ToolBuilder::new("search2")
2649            .description("Search with layer and guard")
2650            .extractor_handler(
2651                state,
2652                |State(app): State<Arc<AppState2>>, Json(input): Json<QueryInput2>| async move {
2653                    Ok(CallToolResult::text(format!(
2654                        "{}: {}",
2655                        app.prefix, input.query
2656                    )))
2657                },
2658            )
2659            .layer(TimeoutLayer::new(Duration::from_secs(5)))
2660            .guard(|_req: &ToolRequest| Ok(()))
2661            .build();
2662
2663        let result = tool.call(serde_json::json!({"query": "hello"})).await;
2664        assert!(!result.is_error);
2665        assert_eq!(result.first_text().unwrap(), "db: hello");
2666    }
2667
2668    #[tokio::test]
2669    async fn test_tool_with_guard_post_build() {
2670        let tool = ToolBuilder::new("admin_action")
2671            .description("Admin action")
2672            .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("done")) })
2673            .build();
2674
2675        // Apply guard after building
2676        let guarded = tool.with_guard(|req: &ToolRequest| {
2677            let name = req.args.get("name").and_then(|v| v.as_str()).unwrap_or("");
2678            if name == "admin" {
2679                Ok(())
2680            } else {
2681                Err("Only admin allowed".to_string())
2682            }
2683        });
2684
2685        // Admin passes
2686        let result = guarded.call(serde_json::json!({"name": "admin"})).await;
2687        assert!(!result.is_error);
2688
2689        // Non-admin blocked
2690        let result = guarded.call(serde_json::json!({"name": "user"})).await;
2691        assert!(result.is_error);
2692        assert!(result.first_text().unwrap().contains("Only admin allowed"));
2693    }
2694
2695    #[tokio::test]
2696    async fn test_with_guard_preserves_tool_metadata() {
2697        let tool = ToolBuilder::new("my_tool")
2698            .description("A tool")
2699            .title("My Tool")
2700            .read_only()
2701            .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("done")) })
2702            .build();
2703
2704        let guarded = tool.with_guard(|_req: &ToolRequest| Ok(()));
2705
2706        assert_eq!(guarded.name, "my_tool");
2707        assert_eq!(guarded.description.as_deref(), Some("A tool"));
2708        assert_eq!(guarded.title.as_deref(), Some("My Tool"));
2709        assert!(guarded.annotations.is_some());
2710    }
2711
2712    #[tokio::test]
2713    async fn test_guard_group_pattern() {
2714        // Demonstrate applying the same guard to multiple tools (per-group pattern)
2715        let require_auth = |req: &ToolRequest| {
2716            let token = req
2717                .args
2718                .get("_token")
2719                .and_then(|v| v.as_str())
2720                .unwrap_or("");
2721            if token == "valid" {
2722                Ok(())
2723            } else {
2724                Err("Authentication required".to_string())
2725            }
2726        };
2727
2728        let tool1 = ToolBuilder::new("action1")
2729            .description("Action 1")
2730            .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("action1")) })
2731            .build();
2732        let tool2 = ToolBuilder::new("action2")
2733            .description("Action 2")
2734            .handler(|_input: GreetInput| async move { Ok(CallToolResult::text("action2")) })
2735            .build();
2736
2737        // Apply same guard to both
2738        let guarded1 = tool1.with_guard(require_auth);
2739        let guarded2 = tool2.with_guard(require_auth);
2740
2741        // Without auth
2742        let r1 = guarded1
2743            .call(serde_json::json!({"name": "test", "_token": "invalid"}))
2744            .await;
2745        let r2 = guarded2
2746            .call(serde_json::json!({"name": "test", "_token": "invalid"}))
2747            .await;
2748        assert!(r1.is_error);
2749        assert!(r2.is_error);
2750
2751        // With auth
2752        let r1 = guarded1
2753            .call(serde_json::json!({"name": "test", "_token": "valid"}))
2754            .await;
2755        let r2 = guarded2
2756            .call(serde_json::json!({"name": "test", "_token": "valid"}))
2757            .await;
2758        assert!(!r1.is_error);
2759        assert!(!r2.is_error);
2760    }
2761}