Skip to main content

tower_mcp/
tool.rs

1//! Tool definition and builder API
2//!
3//! Provides ergonomic ways to define MCP tools:
4//!
5//! 1. **Builder pattern** - Fluent API for defining tools
6//! 2. **Trait-based** - Implement `McpTool` for full control
7//! 3. **Function-based** - Quick tools from async functions
8
9use std::borrow::Cow;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use schemars::{JsonSchema, Schema, SchemaGenerator};
15use serde::Serialize;
16use serde::de::DeserializeOwned;
17use serde_json::Value;
18
19use crate::context::RequestContext;
20use crate::error::{Error, Result};
21use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
22
/// Input type for tools that accept no arguments.
///
/// Prefer this over `()` as a handler's input type. Schema generation for
/// `()` yields `"type": "null"`, which many MCP clients refuse; `NoParams`
/// instead advertises an object schema with no required properties, which
/// is the right shape for a parameterless tool.
///
/// # Example
///
/// ```rust
/// use tower_mcp::{ToolBuilder, CallToolResult, NoParams};
///
/// let tool = ToolBuilder::new("get_status")
///     .description("Get current status")
///     .handler(|_input: NoParams| async move {
///         Ok(CallToolResult::text("OK"))
///     })
///     .build()
///     .unwrap();
/// ```
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct NoParams;
45
46impl<'de> serde::Deserialize<'de> for NoParams {
47    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
48    where
49        D: serde::Deserializer<'de>,
50    {
51        // Accept null, empty object, or any object (ignoring all fields)
52        struct NoParamsVisitor;
53
54        impl<'de> serde::de::Visitor<'de> for NoParamsVisitor {
55            type Value = NoParams;
56
57            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
58                formatter.write_str("null or an object")
59            }
60
61            fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
62            where
63                E: serde::de::Error,
64            {
65                Ok(NoParams)
66            }
67
68            fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
69            where
70                E: serde::de::Error,
71            {
72                Ok(NoParams)
73            }
74
75            fn visit_some<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
76            where
77                D: serde::Deserializer<'de>,
78            {
79                serde::Deserialize::deserialize(deserializer)
80            }
81
82            fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
83            where
84                A: serde::de::MapAccess<'de>,
85            {
86                // Drain the map, ignoring all entries
87                while map
88                    .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
89                    .is_some()
90                {}
91                Ok(NoParams)
92            }
93        }
94
95        deserializer.deserialize_any(NoParamsVisitor)
96    }
97}
98
99impl JsonSchema for NoParams {
100    fn schema_name() -> Cow<'static, str> {
101        Cow::Borrowed("NoParams")
102    }
103
104    fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
105        serde_json::json!({
106            "type": "object"
107        })
108        .try_into()
109        .expect("valid schema")
110    }
111}
112
113/// Validate a tool name according to MCP spec.
114///
115/// Tool names must be:
116/// - 1-128 characters long
117/// - Contain only alphanumeric characters, underscores, hyphens, and dots
118///
119/// Returns `Ok(())` if valid, `Err` with description if invalid.
120pub fn validate_tool_name(name: &str) -> Result<()> {
121    if name.is_empty() {
122        return Err(Error::tool("Tool name cannot be empty"));
123    }
124    if name.len() > 128 {
125        return Err(Error::tool(format!(
126            "Tool name '{}' exceeds maximum length of 128 characters (got {})",
127            name,
128            name.len()
129        )));
130    }
131    if let Some(invalid_char) = name
132        .chars()
133        .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
134    {
135        return Err(Error::tool(format!(
136            "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
137            name, invalid_char
138        )));
139    }
140    Ok(())
141}
142
/// Heap-allocated, type-erased future returned by tool handlers.
///
/// The future is `Send`, so it may be moved across threads, and it may
/// borrow data for the lifetime `'a`.
pub type BoxFuture<'a, T> =
    std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
145
146/// Tool handler trait - the core abstraction for tool execution
147pub trait ToolHandler: Send + Sync {
148    /// Execute the tool with the given arguments
149    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;
150
151    /// Execute the tool with request context for progress/cancellation support
152    ///
153    /// The default implementation ignores the context and calls `call`.
154    /// Override this to receive progress/cancellation context.
155    fn call_with_context(
156        &self,
157        _ctx: RequestContext,
158        args: Value,
159    ) -> BoxFuture<'_, Result<CallToolResult>> {
160        self.call(args)
161    }
162
163    /// Returns true if this handler uses context (for optimization)
164    fn uses_context(&self) -> bool {
165        false
166    }
167
168    /// Get the tool's input schema
169    fn input_schema(&self) -> Value;
170}
171
172/// A complete tool definition with handler
173pub struct Tool {
174    pub name: String,
175    pub title: Option<String>,
176    pub description: Option<String>,
177    pub output_schema: Option<Value>,
178    pub icons: Option<Vec<ToolIcon>>,
179    pub annotations: Option<ToolAnnotations>,
180    handler: Arc<dyn ToolHandler>,
181}
182
183impl std::fmt::Debug for Tool {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        f.debug_struct("Tool")
186            .field("name", &self.name)
187            .field("title", &self.title)
188            .field("description", &self.description)
189            .field("output_schema", &self.output_schema)
190            .field("icons", &self.icons)
191            .field("annotations", &self.annotations)
192            .finish_non_exhaustive()
193    }
194}
195
196impl Tool {
197    /// Create a new tool builder
198    pub fn builder(name: impl Into<String>) -> ToolBuilder {
199        ToolBuilder::new(name)
200    }
201
202    /// Get the tool definition for tools/list
203    pub fn definition(&self) -> ToolDefinition {
204        ToolDefinition {
205            name: self.name.clone(),
206            title: self.title.clone(),
207            description: self.description.clone(),
208            input_schema: self.handler.input_schema(),
209            output_schema: self.output_schema.clone(),
210            icons: self.icons.clone(),
211            annotations: self.annotations.clone(),
212        }
213    }
214
215    /// Call the tool without context
216    pub fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
217        self.handler.call(args)
218    }
219
220    /// Call the tool with request context
221    ///
222    /// Use this when you have a RequestContext available for progress/cancellation.
223    pub fn call_with_context(
224        &self,
225        ctx: RequestContext,
226        args: Value,
227    ) -> BoxFuture<'_, Result<CallToolResult>> {
228        self.handler.call_with_context(ctx, args)
229    }
230
231    /// Returns true if this tool uses context
232    pub fn uses_context(&self) -> bool {
233        self.handler.uses_context()
234    }
235}
236
237// =============================================================================
238// Builder API
239// =============================================================================
240
241/// Builder for creating tools with a fluent API
242///
243/// # Example
244///
245/// ```rust
246/// use tower_mcp::{ToolBuilder, CallToolResult};
247/// use schemars::JsonSchema;
248/// use serde::Deserialize;
249///
250/// #[derive(Debug, Deserialize, JsonSchema)]
251/// struct GreetInput {
252///     name: String,
253/// }
254///
255/// let tool = ToolBuilder::new("greet")
256///     .description("Greet someone by name")
257///     .handler(|input: GreetInput| async move {
258///         Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
259///     })
260///     .build()
261///     .expect("valid tool name");
262///
263/// assert_eq!(tool.name, "greet");
264/// ```
265pub struct ToolBuilder {
266    name: String,
267    title: Option<String>,
268    description: Option<String>,
269    output_schema: Option<Value>,
270    icons: Option<Vec<ToolIcon>>,
271    annotations: Option<ToolAnnotations>,
272}
273
274impl ToolBuilder {
275    pub fn new(name: impl Into<String>) -> Self {
276        Self {
277            name: name.into(),
278            title: None,
279            description: None,
280            output_schema: None,
281            icons: None,
282            annotations: None,
283        }
284    }
285
286    /// Set a human-readable title for the tool
287    pub fn title(mut self, title: impl Into<String>) -> Self {
288        self.title = Some(title.into());
289        self
290    }
291
292    /// Set the output schema (JSON Schema for structured output)
293    pub fn output_schema(mut self, schema: Value) -> Self {
294        self.output_schema = Some(schema);
295        self
296    }
297
298    /// Add an icon for the tool
299    pub fn icon(mut self, src: impl Into<String>) -> Self {
300        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
301            src: src.into(),
302            mime_type: None,
303            sizes: None,
304        });
305        self
306    }
307
308    /// Add an icon with metadata
309    pub fn icon_with_meta(
310        mut self,
311        src: impl Into<String>,
312        mime_type: Option<String>,
313        sizes: Option<Vec<String>>,
314    ) -> Self {
315        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
316            src: src.into(),
317            mime_type,
318            sizes,
319        });
320        self
321    }
322
323    /// Set the tool description
324    pub fn description(mut self, description: impl Into<String>) -> Self {
325        self.description = Some(description.into());
326        self
327    }
328
329    /// Mark the tool as read-only (does not modify state)
330    pub fn read_only(mut self) -> Self {
331        self.annotations
332            .get_or_insert_with(ToolAnnotations::default)
333            .read_only_hint = true;
334        self
335    }
336
337    /// Mark the tool as non-destructive
338    pub fn non_destructive(mut self) -> Self {
339        self.annotations
340            .get_or_insert_with(ToolAnnotations::default)
341            .destructive_hint = false;
342        self
343    }
344
345    /// Mark the tool as idempotent (same args = same effect)
346    pub fn idempotent(mut self) -> Self {
347        self.annotations
348            .get_or_insert_with(ToolAnnotations::default)
349            .idempotent_hint = true;
350        self
351    }
352
353    /// Set tool annotations directly
354    pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
355        self.annotations = Some(annotations);
356        self
357    }
358
359    /// Specify input type and handler.
360    ///
361    /// The input type must implement `JsonSchema` and `DeserializeOwned`.
362    /// The handler receives the deserialized input and returns a `CallToolResult`.
363    ///
364    /// # State Sharing
365    ///
366    /// To share state across tool calls (e.g., database connections, API clients),
367    /// wrap your state in an `Arc` and clone it into the async block:
368    ///
369    /// ```rust
370    /// use std::sync::Arc;
371    /// use tower_mcp::{ToolBuilder, CallToolResult};
372    /// use schemars::JsonSchema;
373    /// use serde::Deserialize;
374    ///
375    /// struct AppState {
376    ///     api_key: String,
377    /// }
378    ///
379    /// #[derive(Debug, Deserialize, JsonSchema)]
380    /// struct MyInput {
381    ///     query: String,
382    /// }
383    ///
384    /// let state = Arc::new(AppState { api_key: "secret".to_string() });
385    ///
386    /// let tool = ToolBuilder::new("my_tool")
387    ///     .description("A tool that uses shared state")
388    ///     .handler(move |input: MyInput| {
389    ///         let state = state.clone(); // Clone Arc for the async block
390    ///         async move {
391    ///             // Use state.api_key here...
392    ///             Ok(CallToolResult::text(format!("Query: {}", input.query)))
393    ///         }
394    ///     })
395    ///     .build()
396    ///     .expect("valid tool name");
397    /// ```
398    ///
399    /// The `move` keyword on the closure captures the `Arc<AppState>`, and
400    /// cloning it inside the closure body allows each async invocation to
401    /// have its own reference to the shared state.
402    pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
403    where
404        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
405        F: Fn(I) -> Fut + Send + Sync + 'static,
406        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
407    {
408        ToolBuilderWithHandler {
409            name: self.name,
410            title: self.title,
411            description: self.description,
412            output_schema: self.output_schema,
413            icons: self.icons,
414            annotations: self.annotations,
415            handler,
416            _phantom: std::marker::PhantomData,
417        }
418    }
419
420    /// Specify input type and context-aware handler
421    ///
422    /// The handler receives a `RequestContext` for progress reporting and
423    /// cancellation checking, along with the deserialized input.
424    ///
425    /// # Example
426    ///
427    /// ```rust
428    /// use tower_mcp::{ToolBuilder, CallToolResult, RequestContext};
429    /// use schemars::JsonSchema;
430    /// use serde::Deserialize;
431    ///
432    /// #[derive(Debug, Deserialize, JsonSchema)]
433    /// struct ProcessInput {
434    ///     items: Vec<String>,
435    /// }
436    ///
437    /// let tool = ToolBuilder::new("process")
438    ///     .description("Process items with progress")
439    ///     .handler_with_context(|ctx: RequestContext, input: ProcessInput| async move {
440    ///         for (i, item) in input.items.iter().enumerate() {
441    ///             if ctx.is_cancelled() {
442    ///                 return Ok(CallToolResult::error("Cancelled"));
443    ///             }
444    ///             ctx.report_progress(i as f64, Some(input.items.len() as f64), Some("Processing...")).await;
445    ///             // Process item...
446    ///         }
447    ///         Ok(CallToolResult::text("Done"))
448    ///     })
449    ///     .build()
450    ///     .expect("valid tool name");
451    /// ```
452    pub fn handler_with_context<I, F, Fut>(self, handler: F) -> ToolBuilderWithContextHandler<I, F>
453    where
454        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
455        F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
456        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
457    {
458        ToolBuilderWithContextHandler {
459            name: self.name,
460            title: self.title,
461            description: self.description,
462            output_schema: self.output_schema,
463            icons: self.icons,
464            annotations: self.annotations,
465            handler,
466            _phantom: std::marker::PhantomData,
467        }
468    }
469
470    /// Specify input type, shared state, and handler.
471    ///
472    /// The state is cloned for each invocation, so wrapping it in an `Arc`
473    /// is recommended for expensive-to-clone types. This eliminates the
474    /// boilerplate of cloning state inside a `move` closure.
475    ///
476    /// # Example
477    ///
478    /// ```rust
479    /// use std::sync::Arc;
480    /// use tower_mcp::{ToolBuilder, CallToolResult};
481    /// use schemars::JsonSchema;
482    /// use serde::Deserialize;
483    ///
484    /// #[derive(Debug, Deserialize, JsonSchema)]
485    /// struct QueryInput { query: String }
486    ///
487    /// struct Db { connection_string: String }
488    ///
489    /// let db = Arc::new(Db { connection_string: "postgres://...".to_string() });
490    ///
491    /// let tool = ToolBuilder::new("search")
492    ///     .description("Search the database")
493    ///     .handler_with_state(db, |db: Arc<Db>, input: QueryInput| async move {
494    ///         Ok(CallToolResult::text(format!("Queried: {}", input.query)))
495    ///     })
496    ///     .build()
497    ///     .expect("valid tool name");
498    /// ```
499    pub fn handler_with_state<S, I, F, Fut>(
500        self,
501        state: S,
502        handler: F,
503    ) -> ToolBuilderWithHandler<
504        I,
505        impl Fn(I) -> BoxFuture<'static, Result<CallToolResult>> + Send + Sync + 'static,
506    >
507    where
508        S: Clone + Send + Sync + 'static,
509        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
510        F: Fn(S, I) -> Fut + Send + Sync + 'static,
511        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
512    {
513        let handler = Arc::new(handler);
514        self.handler(move |input: I| {
515            let state = state.clone();
516            let handler = handler.clone();
517            Box::pin(async move { handler(state, input).await })
518                as BoxFuture<'static, Result<CallToolResult>>
519        })
520    }
521
522    /// Specify input type, shared state, and context-aware handler.
523    ///
524    /// Combines state injection with `RequestContext` access for progress
525    /// reporting, cancellation, sampling, and logging.
526    ///
527    /// # Example
528    ///
529    /// ```rust
530    /// use std::sync::Arc;
531    /// use tower_mcp::{ToolBuilder, CallToolResult, RequestContext};
532    /// use schemars::JsonSchema;
533    /// use serde::Deserialize;
534    ///
535    /// #[derive(Debug, Deserialize, JsonSchema)]
536    /// struct QueryInput { query: String }
537    ///
538    /// struct Db { connection_string: String }
539    ///
540    /// let db = Arc::new(Db { connection_string: "postgres://...".to_string() });
541    ///
542    /// let tool = ToolBuilder::new("search")
543    ///     .description("Search the database with progress")
544    ///     .handler_with_state_and_context(db, |db: Arc<Db>, ctx: RequestContext, input: QueryInput| async move {
545    ///         ctx.report_progress(0.0, Some(1.0), Some("Searching...")).await;
546    ///         Ok(CallToolResult::text(format!("Queried: {}", input.query)))
547    ///     })
548    ///     .build()
549    ///     .expect("valid tool name");
550    /// ```
551    pub fn handler_with_state_and_context<S, I, F, Fut>(
552        self,
553        state: S,
554        handler: F,
555    ) -> ToolBuilderWithContextHandler<
556        I,
557        impl Fn(RequestContext, I) -> BoxFuture<'static, Result<CallToolResult>> + Send + Sync + 'static,
558    >
559    where
560        S: Clone + Send + Sync + 'static,
561        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
562        F: Fn(S, RequestContext, I) -> Fut + Send + Sync + 'static,
563        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
564    {
565        let handler = Arc::new(handler);
566        self.handler_with_context(move |ctx: RequestContext, input: I| {
567            let state = state.clone();
568            let handler = handler.clone();
569            Box::pin(async move { handler(state, ctx, input).await })
570                as BoxFuture<'static, Result<CallToolResult>>
571        })
572    }
573
574    /// Create a tool that takes no parameters.
575    ///
576    /// The handler receives no input arguments. An empty object input schema
577    /// is generated automatically. Returns `Result<Tool>` directly.
578    ///
579    /// # Example
580    ///
581    /// ```rust
582    /// use tower_mcp::{ToolBuilder, CallToolResult};
583    ///
584    /// let tool = ToolBuilder::new("server_time")
585    ///     .description("Get the current server time")
586    ///     .handler_no_params(|| async {
587    ///         Ok(CallToolResult::text("2025-01-01T00:00:00Z"))
588    ///     })
589    ///     .expect("valid tool name");
590    ///
591    /// assert_eq!(tool.name, "server_time");
592    /// ```
593    pub fn handler_no_params<F, Fut>(self, handler: F) -> Result<Tool>
594    where
595        F: Fn() -> Fut + Send + Sync + 'static,
596        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
597    {
598        validate_tool_name(&self.name)?;
599        Ok(Tool {
600            name: self.name,
601            title: self.title,
602            description: self.description,
603            output_schema: self.output_schema,
604            icons: self.icons,
605            annotations: self.annotations,
606            handler: Arc::new(NoParamsHandler { handler }),
607        })
608    }
609
610    /// Create a tool with raw JSON handling (no automatic deserialization)
611    ///
612    /// Returns an error if the tool name is invalid.
613    pub fn raw_handler<F, Fut>(self, handler: F) -> Result<Tool>
614    where
615        F: Fn(Value) -> Fut + Send + Sync + 'static,
616        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
617    {
618        validate_tool_name(&self.name)?;
619        Ok(Tool {
620            name: self.name,
621            title: self.title,
622            description: self.description,
623            output_schema: self.output_schema,
624            icons: self.icons,
625            annotations: self.annotations,
626            handler: Arc::new(RawHandler { handler }),
627        })
628    }
629
630    /// Create a tool with raw JSON handling and request context
631    ///
632    /// The handler receives a `RequestContext` for progress reporting,
633    /// cancellation, sampling, and logging, along with raw JSON arguments.
634    ///
635    /// Returns an error if the tool name is invalid.
636    pub fn raw_handler_with_context<F, Fut>(self, handler: F) -> Result<Tool>
637    where
638        F: Fn(RequestContext, Value) -> Fut + Send + Sync + 'static,
639        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
640    {
641        validate_tool_name(&self.name)?;
642        Ok(Tool {
643            name: self.name,
644            title: self.title,
645            description: self.description,
646            output_schema: self.output_schema,
647            icons: self.icons,
648            annotations: self.annotations,
649            handler: Arc::new(RawContextHandler { handler }),
650        })
651    }
652}
653
654/// Builder state after handler is specified
655pub struct ToolBuilderWithHandler<I, F> {
656    name: String,
657    title: Option<String>,
658    description: Option<String>,
659    output_schema: Option<Value>,
660    icons: Option<Vec<ToolIcon>>,
661    annotations: Option<ToolAnnotations>,
662    handler: F,
663    _phantom: std::marker::PhantomData<I>,
664}
665
666impl<I, F, Fut> ToolBuilderWithHandler<I, F>
667where
668    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
669    F: Fn(I) -> Fut + Send + Sync + 'static,
670    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
671{
672    /// Build the tool
673    ///
674    /// Returns an error if the tool name is invalid.
675    pub fn build(self) -> Result<Tool> {
676        validate_tool_name(&self.name)?;
677        Ok(Tool {
678            name: self.name,
679            title: self.title,
680            description: self.description,
681            output_schema: self.output_schema,
682            icons: self.icons,
683            annotations: self.annotations,
684            handler: Arc::new(TypedHandler {
685                handler: self.handler,
686                _phantom: std::marker::PhantomData,
687            }),
688        })
689    }
690}
691
692/// Builder state after context-aware handler is specified
693pub struct ToolBuilderWithContextHandler<I, F> {
694    name: String,
695    title: Option<String>,
696    description: Option<String>,
697    output_schema: Option<Value>,
698    icons: Option<Vec<ToolIcon>>,
699    annotations: Option<ToolAnnotations>,
700    handler: F,
701    _phantom: std::marker::PhantomData<I>,
702}
703
704impl<I, F, Fut> ToolBuilderWithContextHandler<I, F>
705where
706    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
707    F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
708    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
709{
710    /// Build the tool
711    ///
712    /// Returns an error if the tool name is invalid.
713    pub fn build(self) -> Result<Tool> {
714        validate_tool_name(&self.name)?;
715        Ok(Tool {
716            name: self.name,
717            title: self.title,
718            description: self.description,
719            output_schema: self.output_schema,
720            icons: self.icons,
721            annotations: self.annotations,
722            handler: Arc::new(ContextAwareHandler {
723                handler: self.handler,
724                _phantom: std::marker::PhantomData,
725            }),
726        })
727    }
728}
729
730// =============================================================================
731// Handler implementations
732// =============================================================================
733
/// [`ToolHandler`] adapter that deserializes the JSON arguments into `I` and
/// passes the value to a closure taking only that input.
struct TypedHandler<I, F> {
    /// The user-supplied closure.
    handler: F,
    /// Marks the input type; no `I` value is stored.
    _phantom: std::marker::PhantomData<I>,
}
739
740impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
741where
742    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
743    F: Fn(I) -> Fut + Send + Sync + 'static,
744    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
745{
746    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
747        Box::pin(async move {
748            let input: I = serde_json::from_value(args)
749                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
750            (self.handler)(input).await
751        })
752    }
753
754    fn input_schema(&self) -> Value {
755        let schema = schemars::schema_for!(I);
756        serde_json::to_value(schema).unwrap_or_else(|_| {
757            serde_json::json!({
758                "type": "object"
759            })
760        })
761    }
762}
763
/// [`ToolHandler`] adapter that hands the untouched JSON arguments to a
/// closure.
struct RawHandler<F> {
    /// The user-supplied closure, called with the raw `Value`.
    handler: F,
}
768
769impl<F, Fut> ToolHandler for RawHandler<F>
770where
771    F: Fn(Value) -> Fut + Send + Sync + 'static,
772    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
773{
774    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
775        Box::pin((self.handler)(args))
776    }
777
778    fn input_schema(&self) -> Value {
779        // Raw handlers accept any JSON
780        serde_json::json!({
781            "type": "object",
782            "additionalProperties": true
783        })
784    }
785}
786
/// [`ToolHandler`] adapter whose closure receives a [`RequestContext`] plus the
/// untouched JSON arguments.
struct RawContextHandler<F> {
    /// The user-supplied closure, called with `(ctx, raw_args)`.
    handler: F,
}
791
792impl<F, Fut> ToolHandler for RawContextHandler<F>
793where
794    F: Fn(RequestContext, Value) -> Fut + Send + Sync + 'static,
795    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
796{
797    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
798        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
799        self.call_with_context(ctx, args)
800    }
801
802    fn call_with_context(
803        &self,
804        ctx: RequestContext,
805        args: Value,
806    ) -> BoxFuture<'_, Result<CallToolResult>> {
807        Box::pin((self.handler)(ctx, args))
808    }
809
810    fn uses_context(&self) -> bool {
811        true
812    }
813
814    fn input_schema(&self) -> Value {
815        // Raw context handlers accept any JSON object
816        serde_json::json!({
817            "type": "object",
818            "additionalProperties": true
819        })
820    }
821}
822
/// [`ToolHandler`] adapter that deserializes the arguments into `I` and calls a
/// closure taking `(RequestContext, I)`.
struct ContextAwareHandler<I, F> {
    /// The user-supplied closure.
    handler: F,
    /// Marks the input type; no `I` value is stored.
    _phantom: std::marker::PhantomData<I>,
}
828
829impl<I, F, Fut> ToolHandler for ContextAwareHandler<I, F>
830where
831    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
832    F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
833    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
834{
835    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
836        // When called without context, create a dummy context
837        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
838        self.call_with_context(ctx, args)
839    }
840
841    fn call_with_context(
842        &self,
843        ctx: RequestContext,
844        args: Value,
845    ) -> BoxFuture<'_, Result<CallToolResult>> {
846        Box::pin(async move {
847            let input: I = serde_json::from_value(args)
848                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
849            (self.handler)(ctx, input).await
850        })
851    }
852
853    fn uses_context(&self) -> bool {
854        true
855    }
856
857    fn input_schema(&self) -> Value {
858        let schema = schemars::schema_for!(I);
859        serde_json::to_value(schema).unwrap_or_else(|_| {
860            serde_json::json!({
861                "type": "object"
862            })
863        })
864    }
865}
866
/// [`ToolHandler`] adapter for closures that take no input at all.
struct NoParamsHandler<F> {
    /// The user-supplied zero-argument closure.
    handler: F,
}
871
872impl<F, Fut> ToolHandler for NoParamsHandler<F>
873where
874    F: Fn() -> Fut + Send + Sync + 'static,
875    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
876{
877    fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
878        Box::pin((self.handler)())
879    }
880
881    fn input_schema(&self) -> Value {
882        serde_json::json!({
883            "type": "object",
884            "properties": {}
885        })
886    }
887}
888
889// =============================================================================
890// Trait-based tool definition
891// =============================================================================
892
893/// Trait for defining tools with full control
894///
895/// Implement this trait when you need more control than the builder provides,
896/// or when you want to define tools as standalone types.
897///
898/// # Example
899///
900/// ```rust
901/// use tower_mcp::tool::McpTool;
902/// use tower_mcp::error::Result;
903/// use schemars::JsonSchema;
904/// use serde::{Deserialize, Serialize};
905///
906/// #[derive(Debug, Deserialize, JsonSchema)]
907/// struct AddInput {
908///     a: i64,
909///     b: i64,
910/// }
911///
912/// struct AddTool;
913///
914/// impl McpTool for AddTool {
915///     const NAME: &'static str = "add";
916///     const DESCRIPTION: &'static str = "Add two numbers";
917///
918///     type Input = AddInput;
919///     type Output = i64;
920///
921///     async fn call(&self, input: Self::Input) -> Result<Self::Output> {
922///         Ok(input.a + input.b)
923///     }
924/// }
925///
926/// let tool = AddTool.into_tool().expect("valid tool name");
927/// assert_eq!(tool.name, "add");
928/// ```
929pub trait McpTool: Send + Sync + 'static {
930    const NAME: &'static str;
931    const DESCRIPTION: &'static str;
932
933    type Input: JsonSchema + DeserializeOwned + Send;
934    type Output: Serialize + Send;
935
936    fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
937
938    /// Optional annotations for the tool
939    fn annotations(&self) -> Option<ToolAnnotations> {
940        None
941    }
942
943    /// Convert to a Tool instance
944    ///
945    /// Returns an error if the tool name is invalid.
946    fn into_tool(self) -> Result<Tool>
947    where
948        Self: Sized,
949    {
950        validate_tool_name(Self::NAME)?;
951        let annotations = self.annotations();
952        let tool = Arc::new(self);
953        Ok(Tool {
954            name: Self::NAME.to_string(),
955            title: None,
956            description: Some(Self::DESCRIPTION.to_string()),
957            output_schema: None,
958            icons: None,
959            annotations,
960            handler: Arc::new(McpToolHandler { tool }),
961        })
962    }
963}
964
965/// Wrapper to make McpTool implement ToolHandler
966struct McpToolHandler<T: McpTool> {
967    tool: Arc<T>,
968}
969
970impl<T: McpTool> ToolHandler for McpToolHandler<T> {
971    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
972        let tool = self.tool.clone();
973        Box::pin(async move {
974            let input: T::Input = serde_json::from_value(args)
975                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
976            let output = tool.call(input).await?;
977            let value = serde_json::to_value(output)
978                .map_err(|e| Error::tool(format!("Failed to serialize output: {}", e)))?;
979            Ok(CallToolResult::json(value))
980        })
981    }
982
983    fn input_schema(&self) -> Value {
984        let schema = schemars::schema_for!(T::Input);
985        serde_json::to_value(schema).unwrap_or_else(|_| {
986            serde_json::json!({
987                "type": "object"
988            })
989        })
990    }
991}
992
993#[cfg(test)]
994mod tests {
995    use super::*;
996    use schemars::JsonSchema;
997    use serde::Deserialize;
998
999    #[derive(Debug, Deserialize, JsonSchema)]
1000    struct GreetInput {
1001        name: String,
1002    }
1003
1004    #[tokio::test]
1005    async fn test_builder_tool() {
1006        let tool = ToolBuilder::new("greet")
1007            .description("Greet someone")
1008            .handler(|input: GreetInput| async move {
1009                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1010            })
1011            .build()
1012            .expect("valid tool name");
1013
1014        assert_eq!(tool.name, "greet");
1015        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1016
1017        let result = tool
1018            .call(serde_json::json!({"name": "World"}))
1019            .await
1020            .unwrap();
1021
1022        assert!(!result.is_error);
1023    }
1024
1025    #[tokio::test]
1026    async fn test_raw_handler() {
1027        let tool = ToolBuilder::new("echo")
1028            .description("Echo input")
1029            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) })
1030            .expect("valid tool name");
1031
1032        let result = tool.call(serde_json::json!({"foo": "bar"})).await.unwrap();
1033
1034        assert!(!result.is_error);
1035    }
1036
1037    #[test]
1038    fn test_invalid_tool_name_empty() {
1039        let result = ToolBuilder::new("")
1040            .description("Empty name")
1041            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });
1042
1043        assert!(result.is_err());
1044        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
1045    }
1046
1047    #[test]
1048    fn test_invalid_tool_name_too_long() {
1049        let long_name = "a".repeat(129);
1050        let result = ToolBuilder::new(long_name)
1051            .description("Too long")
1052            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });
1053
1054        assert!(result.is_err());
1055        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
1056    }
1057
1058    #[test]
1059    fn test_invalid_tool_name_bad_chars() {
1060        let result = ToolBuilder::new("my tool!")
1061            .description("Bad chars")
1062            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });
1063
1064        assert!(result.is_err());
1065        assert!(
1066            result
1067                .unwrap_err()
1068                .to_string()
1069                .contains("invalid character")
1070        );
1071    }
1072
1073    #[test]
1074    fn test_valid_tool_names() {
1075        // All valid characters
1076        let names = [
1077            "my_tool",
1078            "my-tool",
1079            "my.tool",
1080            "MyTool123",
1081            "a",
1082            &"a".repeat(128),
1083        ];
1084        for name in names {
1085            let result = ToolBuilder::new(name)
1086                .description("Valid")
1087                .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });
1088            assert!(result.is_ok(), "Expected '{}' to be valid", name);
1089        }
1090    }
1091
1092    #[tokio::test]
1093    async fn test_context_aware_handler() {
1094        use crate::context::{RequestContext, notification_channel};
1095        use crate::protocol::{ProgressToken, RequestId};
1096
1097        #[derive(Debug, Deserialize, JsonSchema)]
1098        struct ProcessInput {
1099            count: i32,
1100        }
1101
1102        let tool = ToolBuilder::new("process")
1103            .description("Process with context")
1104            .handler_with_context(|ctx: RequestContext, input: ProcessInput| async move {
1105                // Simulate progress reporting
1106                for i in 0..input.count {
1107                    if ctx.is_cancelled() {
1108                        return Ok(CallToolResult::error("Cancelled"));
1109                    }
1110                    ctx.report_progress(i as f64, Some(input.count as f64), None)
1111                        .await;
1112                }
1113                Ok(CallToolResult::text(format!(
1114                    "Processed {} items",
1115                    input.count
1116                )))
1117            })
1118            .build()
1119            .expect("valid tool name");
1120
1121        assert_eq!(tool.name, "process");
1122        assert!(tool.uses_context());
1123
1124        // Test with a context that has progress token and notification sender
1125        let (tx, mut rx) = notification_channel(10);
1126        let ctx = RequestContext::new(RequestId::Number(1))
1127            .with_progress_token(ProgressToken::Number(42))
1128            .with_notification_sender(tx);
1129
1130        let result = tool
1131            .call_with_context(ctx, serde_json::json!({"count": 3}))
1132            .await
1133            .unwrap();
1134
1135        assert!(!result.is_error);
1136
1137        // Check that progress notifications were sent
1138        let mut progress_count = 0;
1139        while rx.try_recv().is_ok() {
1140            progress_count += 1;
1141        }
1142        assert_eq!(progress_count, 3);
1143    }
1144
1145    #[tokio::test]
1146    async fn test_context_aware_handler_cancellation() {
1147        use crate::context::RequestContext;
1148        use crate::protocol::RequestId;
1149        use std::sync::Arc;
1150        use std::sync::atomic::{AtomicI32, Ordering};
1151
1152        #[derive(Debug, Deserialize, JsonSchema)]
1153        struct LongRunningInput {
1154            iterations: i32,
1155        }
1156
1157        let iterations_completed = Arc::new(AtomicI32::new(0));
1158        let iterations_ref = iterations_completed.clone();
1159
1160        let tool = ToolBuilder::new("long_running")
1161            .description("Long running task")
1162            .handler_with_context(move |ctx: RequestContext, input: LongRunningInput| {
1163                let completed = iterations_ref.clone();
1164                async move {
1165                    for i in 0..input.iterations {
1166                        if ctx.is_cancelled() {
1167                            return Ok(CallToolResult::error("Cancelled"));
1168                        }
1169                        completed.fetch_add(1, Ordering::SeqCst);
1170                        // Simulate work
1171                        tokio::task::yield_now().await;
1172                        // Cancel after iteration 2
1173                        if i == 2 {
1174                            ctx.cancellation_token().cancel();
1175                        }
1176                    }
1177                    Ok(CallToolResult::text("Done"))
1178                }
1179            })
1180            .build()
1181            .expect("valid tool name");
1182
1183        let ctx = RequestContext::new(RequestId::Number(1));
1184
1185        let result = tool
1186            .call_with_context(ctx, serde_json::json!({"iterations": 10}))
1187            .await
1188            .unwrap();
1189
1190        // Should have been cancelled after 3 iterations (0, 1, 2)
1191        // The next iteration (3) checks cancellation and returns
1192        assert!(result.is_error);
1193        assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
1194    }
1195
1196    #[tokio::test]
1197    async fn test_tool_builder_with_enhanced_fields() {
1198        let output_schema = serde_json::json!({
1199            "type": "object",
1200            "properties": {
1201                "greeting": {"type": "string"}
1202            }
1203        });
1204
1205        let tool = ToolBuilder::new("greet")
1206            .title("Greeting Tool")
1207            .description("Greet someone")
1208            .output_schema(output_schema.clone())
1209            .icon("https://example.com/icon.png")
1210            .icon_with_meta(
1211                "https://example.com/icon-large.png",
1212                Some("image/png".to_string()),
1213                Some(vec!["96x96".to_string()]),
1214            )
1215            .handler(|input: GreetInput| async move {
1216                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1217            })
1218            .build()
1219            .expect("valid tool name");
1220
1221        assert_eq!(tool.name, "greet");
1222        assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
1223        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1224        assert_eq!(tool.output_schema, Some(output_schema));
1225        assert!(tool.icons.is_some());
1226        assert_eq!(tool.icons.as_ref().unwrap().len(), 2);
1227
1228        // Test definition includes new fields
1229        let def = tool.definition();
1230        assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
1231        assert!(def.output_schema.is_some());
1232        assert!(def.icons.is_some());
1233    }
1234
1235    #[tokio::test]
1236    async fn test_handler_with_state() {
1237        let shared = Arc::new("shared-state".to_string());
1238
1239        let tool = ToolBuilder::new("stateful")
1240            .description("Uses shared state")
1241            .handler_with_state(shared, |state: Arc<String>, input: GreetInput| async move {
1242                Ok(CallToolResult::text(format!(
1243                    "{}: Hello, {}!",
1244                    state, input.name
1245                )))
1246            })
1247            .build()
1248            .expect("valid tool name");
1249
1250        let result = tool
1251            .call(serde_json::json!({"name": "World"}))
1252            .await
1253            .unwrap();
1254        assert!(!result.is_error);
1255    }
1256
1257    #[tokio::test]
1258    async fn test_handler_with_state_and_context() {
1259        use crate::context::RequestContext;
1260        use crate::protocol::RequestId;
1261
1262        let shared = Arc::new(42_i32);
1263
1264        let tool = ToolBuilder::new("stateful_ctx")
1265            .description("Uses state and context")
1266            .handler_with_state_and_context(
1267                shared,
1268                |state: Arc<i32>, _ctx: RequestContext, input: GreetInput| async move {
1269                    Ok(CallToolResult::text(format!(
1270                        "{}: Hello, {}!",
1271                        state, input.name
1272                    )))
1273                },
1274            )
1275            .build()
1276            .expect("valid tool name");
1277
1278        assert!(tool.uses_context());
1279
1280        let ctx = RequestContext::new(RequestId::Number(1));
1281        let result = tool
1282            .call_with_context(ctx, serde_json::json!({"name": "World"}))
1283            .await
1284            .unwrap();
1285        assert!(!result.is_error);
1286    }
1287
1288    #[tokio::test]
1289    async fn test_handler_no_params() {
1290        let tool = ToolBuilder::new("no_params")
1291            .description("Takes no parameters")
1292            .handler_no_params(|| async { Ok(CallToolResult::text("no params result")) })
1293            .expect("valid tool name");
1294
1295        assert_eq!(tool.name, "no_params");
1296
1297        // Should work with empty args
1298        let result = tool.call(serde_json::json!({})).await.unwrap();
1299        assert!(!result.is_error);
1300
1301        // Should also work with unexpected args (ignored)
1302        let result = tool
1303            .call(serde_json::json!({"unexpected": "value"}))
1304            .await
1305            .unwrap();
1306        assert!(!result.is_error);
1307
1308        // Check input schema is an empty-properties object
1309        let schema = tool.definition().input_schema;
1310        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1311        assert!(
1312            schema
1313                .get("properties")
1314                .unwrap()
1315                .as_object()
1316                .unwrap()
1317                .is_empty()
1318        );
1319    }
1320
1321    #[test]
1322    fn test_no_params_schema() {
1323        // NoParams should produce a schema with type: "object"
1324        let schema = schemars::schema_for!(NoParams);
1325        let schema_value = serde_json::to_value(&schema).unwrap();
1326        assert_eq!(
1327            schema_value.get("type").and_then(|v| v.as_str()),
1328            Some("object"),
1329            "NoParams should generate type: object schema"
1330        );
1331    }
1332
1333    #[test]
1334    fn test_no_params_deserialize() {
1335        // NoParams should deserialize from various inputs
1336        let from_empty_object: NoParams = serde_json::from_str("{}").unwrap();
1337        assert_eq!(from_empty_object, NoParams);
1338
1339        let from_null: NoParams = serde_json::from_str("null").unwrap();
1340        assert_eq!(from_null, NoParams);
1341
1342        // Should also accept objects with unexpected fields (ignored)
1343        let from_object_with_fields: NoParams =
1344            serde_json::from_str(r#"{"unexpected": "value"}"#).unwrap();
1345        assert_eq!(from_object_with_fields, NoParams);
1346    }
1347
1348    #[tokio::test]
1349    async fn test_no_params_type_in_handler() {
1350        // NoParams can be used as a handler input type
1351        let tool = ToolBuilder::new("status")
1352            .description("Get status")
1353            .handler(|_input: NoParams| async move { Ok(CallToolResult::text("OK")) })
1354            .build()
1355            .expect("valid tool name");
1356
1357        // Check schema has type: object (not type: null like () would produce)
1358        let schema = tool.definition().input_schema;
1359        assert_eq!(
1360            schema.get("type").and_then(|v| v.as_str()),
1361            Some("object"),
1362            "NoParams handler should produce type: object schema"
1363        );
1364
1365        // Should work with empty input
1366        let result = tool.call(serde_json::json!({})).await.unwrap();
1367        assert!(!result.is_error);
1368    }
1369}