Skip to main content

tower_mcp/
prompt.rs

1//! Prompt definition and builder API
2//!
3//! Provides ergonomic ways to define MCP prompts:
4//!
5//! 1. **Builder pattern** - Fluent API for defining prompts
6//! 2. **Trait-based** - Implement `McpPrompt` for full control
7//! 3. **Per-prompt middleware** - Apply tower middleware layers to individual prompts
8//!
9//! # Per-Prompt Middleware
10//!
11//! The `.layer()` method on `PromptBuilder` (after `.handler()`) allows applying
12//! tower middleware to a single prompt. This is useful for prompt-specific concerns
13//! like timeouts, rate limiting, or caching.
14//!
15//! ```rust
16//! use std::collections::HashMap;
17//! use std::time::Duration;
18//! use tower::timeout::TimeoutLayer;
19//! use tower_mcp::prompt::PromptBuilder;
20//! use tower_mcp::protocol::{GetPromptResult, PromptMessage, PromptRole, Content};
21//!
22//! let prompt = PromptBuilder::new("slow_prompt")
23//!     .description("A prompt that might take a while")
24//!     .handler(|args: HashMap<String, String>| async move {
25//!         // Slow prompt generation logic...
26//!         Ok(GetPromptResult {
27//!             description: Some("Generated prompt".to_string()),
28//!             messages: vec![PromptMessage {
29//!                 role: PromptRole::User,
30//!                 content: Content::Text {
31//!                     text: "Hello!".to_string(),
32//!                     annotations: None,
33//!                 },
34//!             }],
35//!         })
36//!     })
37//!     .layer(TimeoutLayer::new(Duration::from_secs(5)));
38//!
39//! assert_eq!(prompt.name, "slow_prompt");
40//! ```
41
42use std::collections::HashMap;
43use std::convert::Infallible;
44use std::fmt;
45use std::future::Future;
46use std::pin::Pin;
47use std::sync::Arc;
48use std::task::{Context, Poll};
49
50use tokio::sync::Mutex;
51use tower::util::BoxCloneService;
52use tower::{Layer, ServiceExt};
53use tower_service::Service;
54
55use crate::context::RequestContext;
56use crate::error::{Error, Result};
57use crate::protocol::{
58    Content, GetPromptResult, PromptArgument, PromptDefinition, PromptMessage, PromptRole,
59    RequestId, ToolIcon,
60};
61
/// A boxed future for prompt handlers
///
/// A pinned, heap-allocated, type-erased future that is `Send` and may borrow
/// data for the lifetime `'a`. It is the return type of the [`PromptHandler`]
/// methods, which are invoked through `Arc<dyn PromptHandler>`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
64
65// =============================================================================
66// Per-Prompt Middleware Types
67// =============================================================================
68
69/// Request type for prompt middleware.
70///
71/// Contains the request context and prompt arguments, allowing middleware
72/// to access and modify the request before it reaches the prompt handler.
73#[derive(Debug, Clone)]
74pub struct PromptRequest {
75    /// The request context with progress reporting, cancellation, etc.
76    pub context: RequestContext,
77    /// The prompt arguments (name -> value)
78    pub arguments: HashMap<String, String>,
79}
80
81impl PromptRequest {
82    /// Create a new prompt request with the given context and arguments.
83    pub fn new(context: RequestContext, arguments: HashMap<String, String>) -> Self {
84        Self { context, arguments }
85    }
86
87    /// Create a prompt request with a default context (for testing or simple use cases).
88    pub fn with_arguments(arguments: HashMap<String, String>) -> Self {
89        Self {
90            context: RequestContext::new(RequestId::Number(0)),
91            arguments,
92        }
93    }
94}
95
/// A boxed, cloneable prompt service with `Error = Infallible`.
///
/// This is the service type used internally after applying middleware layers.
/// It wraps any `Service<PromptRequest>` implementation so that the prompt
/// handler can consume it without knowing the concrete middleware stack.
///
/// Requests are [`PromptRequest`] values and responses are `GetPromptResult`
/// values. Because the error type is `Infallible`, the `.layer()` builders in
/// this module wrap the middleware stack in [`PromptCatchError`] before boxing.
pub type BoxPromptService = BoxCloneService<PromptRequest, GetPromptResult, Infallible>;
102
103/// A service wrapper that catches errors from middleware and converts them
104/// into prompt errors, maintaining the `Error = Infallible` contract.
105///
106/// When a middleware layer (e.g., `TimeoutLayer`) produces an error, this
107/// wrapper converts it into a prompt error. This allows error information to
108/// flow through the normal response path rather than requiring special
109/// error handling.
110pub struct PromptCatchError<S> {
111    inner: S,
112}
113
114impl<S> PromptCatchError<S> {
115    /// Create a new `PromptCatchError` wrapping the given service.
116    pub fn new(inner: S) -> Self {
117        Self { inner }
118    }
119}
120
121impl<S: Clone> Clone for PromptCatchError<S> {
122    fn clone(&self) -> Self {
123        Self {
124            inner: self.inner.clone(),
125        }
126    }
127}
128
129impl<S: fmt::Debug> fmt::Debug for PromptCatchError<S> {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        f.debug_struct("PromptCatchError")
132            .field("inner", &self.inner)
133            .finish()
134    }
135}
136
137impl<S> Service<PromptRequest> for PromptCatchError<S>
138where
139    S: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
140    S::Error: fmt::Display + Send,
141    S::Future: Send,
142{
143    type Response = GetPromptResult;
144    type Error = Infallible;
145    type Future =
146        Pin<Box<dyn Future<Output = std::result::Result<GetPromptResult, Infallible>> + Send>>;
147
148    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
149        self.inner.poll_ready(cx).map_err(|_| unreachable!())
150    }
151
152    fn call(&mut self, req: PromptRequest) -> Self::Future {
153        let fut = self.inner.call(req);
154
155        Box::pin(async move {
156            match fut.await {
157                Ok(response) => Ok(response),
158                Err(err) => {
159                    // Convert middleware error to an error prompt result
160                    // We return a single error message as the prompt content
161                    Ok(GetPromptResult {
162                        description: Some(format!("Prompt error: {}", err)),
163                        messages: vec![PromptMessage {
164                            role: PromptRole::Assistant,
165                            content: Content::Text {
166                                text: format!("Error generating prompt: {}", err),
167                                annotations: None,
168                            },
169                        }],
170                    })
171                }
172            }
173        })
174    }
175}
176
177/// Adapts a prompt handler function into a `Service<PromptRequest>`.
178///
179/// This allows the handler to be wrapped with tower middleware layers.
180/// Used by `.layer()` on `PromptBuilderWithHandler`.
181pub struct PromptHandlerService<F> {
182    handler: F,
183}
184
185impl<F> Clone for PromptHandlerService<F>
186where
187    F: Clone,
188{
189    fn clone(&self) -> Self {
190        Self {
191            handler: self.handler.clone(),
192        }
193    }
194}
195
196impl<F, Fut> Service<PromptRequest> for PromptHandlerService<F>
197where
198    F: Fn(HashMap<String, String>) -> Fut + Clone + Send + Sync + 'static,
199    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
200{
201    type Response = GetPromptResult;
202    type Error = Error;
203    type Future = Pin<Box<dyn Future<Output = std::result::Result<GetPromptResult, Error>> + Send>>;
204
205    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
206        Poll::Ready(Ok(()))
207    }
208
209    fn call(&mut self, req: PromptRequest) -> Self::Future {
210        let handler = self.handler.clone();
211        Box::pin(async move { handler(req.arguments).await })
212    }
213}
214
215/// Adapts a context-aware prompt handler function into a `Service<PromptRequest>`.
216///
217/// Used by `.layer()` on `PromptBuilderWithContextHandler`.
218pub struct PromptContextHandlerService<F> {
219    handler: F,
220}
221
222impl<F> Clone for PromptContextHandlerService<F>
223where
224    F: Clone,
225{
226    fn clone(&self) -> Self {
227        Self {
228            handler: self.handler.clone(),
229        }
230    }
231}
232
233impl<F, Fut> Service<PromptRequest> for PromptContextHandlerService<F>
234where
235    F: Fn(RequestContext, HashMap<String, String>) -> Fut + Clone + Send + Sync + 'static,
236    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
237{
238    type Response = GetPromptResult;
239    type Error = Error;
240    type Future = Pin<Box<dyn Future<Output = std::result::Result<GetPromptResult, Error>> + Send>>;
241
242    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
243        Poll::Ready(Ok(()))
244    }
245
246    fn call(&mut self, req: PromptRequest) -> Self::Future {
247        let handler = self.handler.clone();
248        Box::pin(async move { handler(req.context, req.arguments).await })
249    }
250}
251
/// Prompt handler trait - the core abstraction for prompt generation
///
/// Implementors must be `Send + Sync`; [`Prompt`] stores its handler as
/// `Arc<dyn PromptHandler>` and forwards `get`, `get_with_context` and
/// `uses_context` to it.
pub trait PromptHandler: Send + Sync {
    /// Get the prompt with the given arguments
    ///
    /// `arguments` maps argument names to their string values. The returned
    /// future may borrow from `self`.
    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>>;

    /// Get the prompt with request context
    ///
    /// The default implementation ignores the context and calls `get`.
    /// Override this to receive context for progress reporting, cancellation, etc.
    fn get_with_context(
        &self,
        _ctx: RequestContext,
        arguments: HashMap<String, String>,
    ) -> BoxFuture<'_, Result<GetPromptResult>> {
        self.get(arguments)
    }

    /// Returns true if this handler uses context (for optimization)
    ///
    /// The default implementation returns `false`.
    fn uses_context(&self) -> bool {
        false
    }
}
274
275/// A complete prompt definition with handler
276pub struct Prompt {
277    pub name: String,
278    pub title: Option<String>,
279    pub description: Option<String>,
280    pub icons: Option<Vec<ToolIcon>>,
281    pub arguments: Vec<PromptArgument>,
282    handler: Arc<dyn PromptHandler>,
283}
284
285impl Clone for Prompt {
286    fn clone(&self) -> Self {
287        Self {
288            name: self.name.clone(),
289            title: self.title.clone(),
290            description: self.description.clone(),
291            icons: self.icons.clone(),
292            arguments: self.arguments.clone(),
293            handler: self.handler.clone(),
294        }
295    }
296}
297
298impl std::fmt::Debug for Prompt {
299    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        f.debug_struct("Prompt")
301            .field("name", &self.name)
302            .field("title", &self.title)
303            .field("description", &self.description)
304            .field("icons", &self.icons)
305            .field("arguments", &self.arguments)
306            .finish_non_exhaustive()
307    }
308}
309
310impl Prompt {
311    /// Create a new prompt builder
312    pub fn builder(name: impl Into<String>) -> PromptBuilder {
313        PromptBuilder::new(name)
314    }
315
316    /// Get the prompt definition for prompts/list
317    pub fn definition(&self) -> PromptDefinition {
318        PromptDefinition {
319            name: self.name.clone(),
320            title: self.title.clone(),
321            description: self.description.clone(),
322            icons: self.icons.clone(),
323            arguments: self.arguments.clone(),
324        }
325    }
326
327    /// Get the prompt with arguments
328    pub fn get(
329        &self,
330        arguments: HashMap<String, String>,
331    ) -> BoxFuture<'_, Result<GetPromptResult>> {
332        self.handler.get(arguments)
333    }
334
335    /// Get the prompt with request context
336    ///
337    /// Use this when you have a RequestContext available for progress/cancellation.
338    pub fn get_with_context(
339        &self,
340        ctx: RequestContext,
341        arguments: HashMap<String, String>,
342    ) -> BoxFuture<'_, Result<GetPromptResult>> {
343        self.handler.get_with_context(ctx, arguments)
344    }
345
346    /// Returns true if this prompt uses context
347    pub fn uses_context(&self) -> bool {
348        self.handler.uses_context()
349    }
350}
351
352// =============================================================================
353// Builder API
354// =============================================================================
355
/// Builder for creating prompts with a fluent API
///
/// Collects the prompt's metadata (name, title, description, icons and
/// argument declarations). Supplying a handler via `.handler()` or
/// `.handler_with_context()` moves the collected metadata into the next
/// builder stage.
///
/// # Example
///
/// ```rust
/// use tower_mcp::prompt::PromptBuilder;
/// use tower_mcp::protocol::{GetPromptResult, PromptMessage, PromptRole, Content};
///
/// let prompt = PromptBuilder::new("greet")
///     .description("Generate a greeting")
///     .required_arg("name", "The name to greet")
///     .handler(|args| async move {
///         let name = args.get("name").map(|s| s.as_str()).unwrap_or("World");
///         Ok(GetPromptResult {
///             description: Some("A greeting prompt".to_string()),
///             messages: vec![PromptMessage {
///                 role: PromptRole::User,
///                 content: Content::Text {
///                     text: format!("Please greet {}", name),
///                     annotations: None,
///                 },
///             }],
///         })
///     })
///     .build();
///
/// assert_eq!(prompt.name, "greet");
/// ```
pub struct PromptBuilder {
    /// Prompt name, set by `new`.
    name: String,
    /// Optional human-readable title, set by `title`.
    title: Option<String>,
    /// Optional description, set by `description`.
    description: Option<String>,
    /// Icons added via `icon` / `icon_with_meta`; `None` until the first one is added.
    icons: Option<Vec<ToolIcon>>,
    /// Argument declarations in the order they were added.
    arguments: Vec<PromptArgument>,
}
391
392impl PromptBuilder {
393    pub fn new(name: impl Into<String>) -> Self {
394        Self {
395            name: name.into(),
396            title: None,
397            description: None,
398            icons: None,
399            arguments: Vec::new(),
400        }
401    }
402
403    /// Set a human-readable title for the prompt
404    pub fn title(mut self, title: impl Into<String>) -> Self {
405        self.title = Some(title.into());
406        self
407    }
408
409    /// Set the prompt description
410    pub fn description(mut self, description: impl Into<String>) -> Self {
411        self.description = Some(description.into());
412        self
413    }
414
415    /// Add an icon for the prompt
416    pub fn icon(mut self, src: impl Into<String>) -> Self {
417        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
418            src: src.into(),
419            mime_type: None,
420            sizes: None,
421        });
422        self
423    }
424
425    /// Add an icon with metadata
426    pub fn icon_with_meta(
427        mut self,
428        src: impl Into<String>,
429        mime_type: Option<String>,
430        sizes: Option<Vec<String>>,
431    ) -> Self {
432        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
433            src: src.into(),
434            mime_type,
435            sizes,
436        });
437        self
438    }
439
440    /// Add a required argument
441    pub fn required_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
442        self.arguments.push(PromptArgument {
443            name: name.into(),
444            description: Some(description.into()),
445            required: true,
446        });
447        self
448    }
449
450    /// Add an optional argument
451    pub fn optional_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
452        self.arguments.push(PromptArgument {
453            name: name.into(),
454            description: Some(description.into()),
455            required: false,
456        });
457        self
458    }
459
460    /// Add an argument with full control
461    pub fn argument(mut self, arg: PromptArgument) -> Self {
462        self.arguments.push(arg);
463        self
464    }
465
466    /// Set the handler function for getting the prompt
467    ///
468    /// Returns a `PromptBuilderWithHandler` which can be finalized with `.build()`
469    /// or have middleware applied with `.layer()`.
470    pub fn handler<F, Fut>(self, handler: F) -> PromptBuilderWithHandler<F>
471    where
472        F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
473        Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
474    {
475        PromptBuilderWithHandler {
476            name: self.name,
477            title: self.title,
478            description: self.description,
479            icons: self.icons,
480            arguments: self.arguments,
481            handler,
482        }
483    }
484
485    /// Set a context-aware handler function for getting the prompt
486    ///
487    /// The handler receives a `RequestContext` for progress reporting and
488    /// cancellation checking, along with the prompt arguments.
489    pub fn handler_with_context<F, Fut>(self, handler: F) -> PromptBuilderWithContextHandler<F>
490    where
491        F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
492        Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
493    {
494        PromptBuilderWithContextHandler {
495            name: self.name,
496            title: self.title,
497            description: self.description,
498            icons: self.icons,
499            arguments: self.arguments,
500            handler,
501        }
502    }
503
504    /// Create a static prompt (no arguments needed)
505    pub fn static_prompt(self, messages: Vec<PromptMessage>) -> Prompt {
506        let description = self.description.clone();
507        self.handler(move |_| {
508            let messages = messages.clone();
509            let description = description.clone();
510            async move {
511                Ok(GetPromptResult {
512                    description,
513                    messages,
514                })
515            }
516        })
517        .build()
518    }
519
520    /// Create a simple text prompt with a user message
521    pub fn user_message(self, text: impl Into<String>) -> Prompt {
522        let text = text.into();
523        self.static_prompt(vec![PromptMessage {
524            role: PromptRole::User,
525            content: Content::Text {
526                text,
527                annotations: None,
528            },
529        }])
530    }
531
532    /// Finalize the builder into a Prompt
533    ///
534    /// This is an alias for `handler(...).build()` for when you want to
535    /// explicitly mark the build step.
536    pub fn build<F, Fut>(self, handler: F) -> Prompt
537    where
538        F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
539        Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
540    {
541        self.handler(handler).build()
542    }
543}
544
/// Builder state after handler is specified
///
/// This allows either calling `.build()` to create the prompt directly,
/// or `.layer()` to apply middleware before building.
///
/// Created by `PromptBuilder::handler`, which moves the collected metadata
/// into this struct unchanged.
pub struct PromptBuilderWithHandler<F> {
    /// Prompt name.
    name: String,
    /// Optional human-readable title.
    title: Option<String>,
    /// Optional description.
    description: Option<String>,
    /// Optional icons.
    icons: Option<Vec<ToolIcon>>,
    /// Argument declarations.
    arguments: Vec<PromptArgument>,
    /// Handler function taking the prompt arguments.
    handler: F,
}
557
558impl<F, Fut> PromptBuilderWithHandler<F>
559where
560    F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
561    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
562{
563    /// Build the prompt without any middleware
564    pub fn build(self) -> Prompt {
565        Prompt {
566            name: self.name,
567            title: self.title,
568            description: self.description,
569            icons: self.icons,
570            arguments: self.arguments,
571            handler: Arc::new(FnHandler {
572                handler: self.handler,
573            }),
574        }
575    }
576
577    /// Apply a tower middleware layer to this prompt
578    ///
579    /// The layer wraps the prompt handler, allowing middleware like timeouts,
580    /// rate limiting, or retries to be applied to this specific prompt.
581    ///
582    /// # Example
583    ///
584    /// ```rust
585    /// use std::collections::HashMap;
586    /// use std::time::Duration;
587    /// use tower::timeout::TimeoutLayer;
588    /// use tower_mcp::prompt::PromptBuilder;
589    /// use tower_mcp::protocol::{GetPromptResult, PromptMessage, PromptRole, Content};
590    ///
591    /// let prompt = PromptBuilder::new("slow_prompt")
592    ///     .description("A prompt that might take a while")
593    ///     .handler(|_args: HashMap<String, String>| async move {
594    ///         Ok(GetPromptResult {
595    ///             description: Some("Generated prompt".to_string()),
596    ///             messages: vec![PromptMessage {
597    ///                 role: PromptRole::User,
598    ///                 content: Content::Text {
599    ///                     text: "Hello!".to_string(),
600    ///                     annotations: None,
601    ///                 },
602    ///             }],
603    ///         })
604    ///     })
605    ///     .layer(TimeoutLayer::new(Duration::from_secs(5)));
606    /// ```
607    pub fn layer<L>(self, layer: L) -> Prompt
608    where
609        L: Layer<PromptHandlerService<F>> + Send + Sync + 'static,
610        L::Service: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
611        <L::Service as Service<PromptRequest>>::Error: fmt::Display + Send,
612        <L::Service as Service<PromptRequest>>::Future: Send,
613    {
614        let service = PromptHandlerService {
615            handler: self.handler,
616        };
617        let wrapped = layer.layer(service);
618        let boxed = BoxCloneService::new(PromptCatchError::new(wrapped));
619
620        Prompt {
621            name: self.name,
622            title: self.title,
623            description: self.description,
624            icons: self.icons,
625            arguments: self.arguments,
626            handler: Arc::new(ServiceHandler {
627                service: Mutex::new(boxed),
628            }),
629        }
630    }
631}
632
/// Builder state after context-aware handler is specified
///
/// Created by `PromptBuilder::handler_with_context`, which moves the
/// collected metadata into this struct unchanged. Finish with `.build()`
/// or `.layer()`.
pub struct PromptBuilderWithContextHandler<F> {
    /// Prompt name.
    name: String,
    /// Optional human-readable title.
    title: Option<String>,
    /// Optional description.
    description: Option<String>,
    /// Optional icons.
    icons: Option<Vec<ToolIcon>>,
    /// Argument declarations.
    arguments: Vec<PromptArgument>,
    /// Handler function taking the request context and the prompt arguments.
    handler: F,
}
642
643impl<F, Fut> PromptBuilderWithContextHandler<F>
644where
645    F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
646    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
647{
648    /// Build the prompt without any middleware
649    pub fn build(self) -> Prompt {
650        Prompt {
651            name: self.name,
652            title: self.title,
653            description: self.description,
654            icons: self.icons,
655            arguments: self.arguments,
656            handler: Arc::new(ContextAwareHandler {
657                handler: self.handler,
658            }),
659        }
660    }
661
662    /// Apply a tower middleware layer to this prompt
663    pub fn layer<L>(self, layer: L) -> Prompt
664    where
665        L: Layer<PromptContextHandlerService<F>> + Send + Sync + 'static,
666        L::Service: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
667        <L::Service as Service<PromptRequest>>::Error: fmt::Display + Send,
668        <L::Service as Service<PromptRequest>>::Future: Send,
669    {
670        let service = PromptContextHandlerService {
671            handler: self.handler,
672        };
673        let wrapped = layer.layer(service);
674        let boxed = BoxCloneService::new(PromptCatchError::new(wrapped));
675
676        Prompt {
677            name: self.name,
678            title: self.title,
679            description: self.description,
680            icons: self.icons,
681            arguments: self.arguments,
682            handler: Arc::new(ServiceContextHandler {
683                service: Mutex::new(boxed),
684            }),
685        }
686    }
687}
688
689// =============================================================================
690// Handler implementations
691// =============================================================================
692
693/// Handler wrapping a function
694struct FnHandler<F> {
695    handler: F,
696}
697
698impl<F, Fut> PromptHandler for FnHandler<F>
699where
700    F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
701    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
702{
703    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
704        Box::pin((self.handler)(arguments))
705    }
706}
707
708/// Handler that receives request context
709struct ContextAwareHandler<F> {
710    handler: F,
711}
712
713impl<F, Fut> PromptHandler for ContextAwareHandler<F>
714where
715    F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + 'static,
716    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
717{
718    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
719        // When called without context, create a dummy context
720        let ctx = RequestContext::new(RequestId::Number(0));
721        self.get_with_context(ctx, arguments)
722    }
723
724    fn get_with_context(
725        &self,
726        ctx: RequestContext,
727        arguments: HashMap<String, String>,
728    ) -> BoxFuture<'_, Result<GetPromptResult>> {
729        Box::pin((self.handler)(ctx, arguments))
730    }
731
732    fn uses_context(&self) -> bool {
733        true
734    }
735}
736
737/// Handler wrapping a boxed service (used when middleware is applied)
738///
739/// Uses a Mutex to make the BoxCloneService (which is Send but not Sync) safe
740/// for use in a Sync context. Since we clone the service before each call,
741/// the lock is only held briefly during the clone.
742struct ServiceHandler {
743    service: Mutex<BoxPromptService>,
744}
745
746impl PromptHandler for ServiceHandler {
747    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
748        Box::pin(async move {
749            let req = PromptRequest::with_arguments(arguments);
750            let mut service = self.service.lock().await.clone();
751            match service.ready().await {
752                Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
753                Err(e) => match e {},
754            }
755        })
756    }
757
758    fn get_with_context(
759        &self,
760        ctx: RequestContext,
761        arguments: HashMap<String, String>,
762    ) -> BoxFuture<'_, Result<GetPromptResult>> {
763        Box::pin(async move {
764            let req = PromptRequest::new(ctx, arguments);
765            let mut service = self.service.lock().await.clone();
766            match service.ready().await {
767                Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
768                Err(e) => match e {},
769            }
770        })
771    }
772}
773
774/// Handler wrapping a boxed service for context-aware prompts
775struct ServiceContextHandler {
776    service: Mutex<BoxPromptService>,
777}
778
779impl PromptHandler for ServiceContextHandler {
780    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
781        let ctx = RequestContext::new(RequestId::Number(0));
782        self.get_with_context(ctx, arguments)
783    }
784
785    fn get_with_context(
786        &self,
787        ctx: RequestContext,
788        arguments: HashMap<String, String>,
789    ) -> BoxFuture<'_, Result<GetPromptResult>> {
790        Box::pin(async move {
791            let req = PromptRequest::new(ctx, arguments);
792            let mut service = self.service.lock().await.clone();
793            match service.ready().await {
794                Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
795                Err(e) => match e {},
796            }
797        })
798    }
799
800    fn uses_context(&self) -> bool {
801        true
802    }
803}
804
805// =============================================================================
806// Trait-based prompt definition
807// =============================================================================
808
/// Trait for defining prompts with full control
///
/// Implement this trait when you need more control than the builder provides,
/// or when you want to define prompts as standalone types.
///
/// # Example
///
/// ```rust
/// use std::collections::HashMap;
/// use tower_mcp::prompt::McpPrompt;
/// use tower_mcp::protocol::{GetPromptResult, PromptArgument, PromptMessage, PromptRole, Content};
/// use tower_mcp::error::Result;
///
/// struct CodeReviewPrompt;
///
/// impl McpPrompt for CodeReviewPrompt {
///     const NAME: &'static str = "code_review";
///     const DESCRIPTION: &'static str = "Review code for issues";
///
///     fn arguments(&self) -> Vec<PromptArgument> {
///         vec![
///             PromptArgument {
///                 name: "code".to_string(),
///                 description: Some("The code to review".to_string()),
///                 required: true,
///             },
///             PromptArgument {
///                 name: "language".to_string(),
///                 description: Some("Programming language".to_string()),
///                 required: false,
///             },
///         ]
///     }
///
///     async fn get(&self, args: HashMap<String, String>) -> Result<GetPromptResult> {
///         let code = args.get("code").map(|s| s.as_str()).unwrap_or("");
///         let lang = args.get("language").map(|s| s.as_str()).unwrap_or("unknown");
///
///         Ok(GetPromptResult {
///             description: Some("Code review prompt".to_string()),
///             messages: vec![PromptMessage {
///                 role: PromptRole::User,
///                 content: Content::Text {
///                     text: format!("Please review this {} code:\n\n```{}\n{}\n```", lang, lang, code),
///                     annotations: None,
///                 },
///             }],
///         })
///     }
/// }
///
/// let prompt = CodeReviewPrompt.into_prompt();
/// assert_eq!(prompt.name, "code_review");
/// ```
pub trait McpPrompt: Send + Sync + 'static {
    /// Prompt name; `into_prompt` copies it into `Prompt::name`.
    const NAME: &'static str;
    /// Prompt description; `into_prompt` copies it into `Prompt::description`.
    const DESCRIPTION: &'static str;

    /// Define the arguments for this prompt
    ///
    /// The default implementation declares no arguments.
    fn arguments(&self) -> Vec<PromptArgument> {
        Vec::new()
    }

    /// Generate the prompt for the given arguments (name -> value).
    ///
    /// The returned future must be `Send`.
    fn get(
        &self,
        arguments: HashMap<String, String>,
    ) -> impl Future<Output = Result<GetPromptResult>> + Send;

    /// Convert to a Prompt instance
    ///
    /// `arguments()` is queried once, before `self` is moved into an `Arc`
    /// that the resulting handler shares. The produced `Prompt` has no title
    /// and no icons.
    fn into_prompt(self) -> Prompt
    where
        Self: Sized,
    {
        let arguments = self.arguments();
        let prompt = Arc::new(self);
        Prompt {
            name: Self::NAME.to_string(),
            title: None,
            description: Some(Self::DESCRIPTION.to_string()),
            icons: None,
            arguments,
            handler: Arc::new(McpPromptHandler { prompt }),
        }
    }
}
894
895/// Wrapper to make McpPrompt implement PromptHandler
896struct McpPromptHandler<T: McpPrompt> {
897    prompt: Arc<T>,
898}
899
900impl<T: McpPrompt> PromptHandler for McpPromptHandler<T> {
901    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
902        let prompt = self.prompt.clone();
903        Box::pin(async move { prompt.get(arguments).await })
904    }
905}
906
#[cfg(test)]
mod tests {
    use super::*;

    use std::time::Duration;
    use tower::timeout::TimeoutLayer;

    /// Builds an argument map holding exactly one `key` -> `value` entry.
    fn single_arg(key: &str, value: &str) -> HashMap<String, String> {
        HashMap::from([(key.to_string(), value.to_string())])
    }

    /// Builds a result with the given description and a single user-role
    /// text message (no annotations).
    fn user_text_result(description: &str, text: impl Into<String>) -> GetPromptResult {
        GetPromptResult {
            description: Some(description.to_string()),
            messages: vec![PromptMessage {
                role: PromptRole::User,
                content: Content::Text {
                    text: text.into(),
                    annotations: None,
                },
            }],
        }
    }

    /// Formats the greeting shared by several handlers; the name falls back
    /// to "World" when the `name` argument is absent.
    fn greeting(args: &HashMap<String, String>) -> String {
        let name = args.get("name").map(String::as_str).unwrap_or("World");
        format!("Hello, {}!", name)
    }

    /// Returns the text of the first message, panicking if that message is
    /// not text content.
    fn first_text(result: &GetPromptResult) -> &str {
        match &result.messages[0].content {
            Content::Text { text, .. } => text.as_str(),
            _ => panic!("Expected text content"),
        }
    }

    /// A builder-defined prompt exposes its metadata and runs its handler.
    #[tokio::test]
    async fn test_builder_prompt() {
        let prompt = PromptBuilder::new("greet")
            .description("A greeting prompt")
            .required_arg("name", "Name to greet")
            .handler(|args| async move { Ok(user_text_result("Greeting", greeting(&args))) })
            .build();

        assert_eq!(prompt.name, "greet");
        assert_eq!(prompt.description.as_deref(), Some("A greeting prompt"));
        assert_eq!(prompt.arguments.len(), 1);
        assert!(prompt.arguments[0].required);

        let result = prompt.get(single_arg("name", "Alice")).await.unwrap();

        assert_eq!(result.messages.len(), 1);
        assert_eq!(first_text(&result), "Hello, Alice!");
    }

    /// `user_message` yields a prompt that returns the fixed message.
    #[tokio::test]
    async fn test_static_prompt() {
        let prompt = PromptBuilder::new("help")
            .description("Help prompt")
            .user_message("How can I help you today?");

        let result = prompt.get(HashMap::new()).await.unwrap();

        assert_eq!(result.messages.len(), 1);
        assert_eq!(first_text(&result), "How can I help you today?");
    }

    /// A type implementing `McpPrompt` converts into a working `Prompt`.
    #[tokio::test]
    async fn test_trait_prompt() {
        struct TestPrompt;

        impl McpPrompt for TestPrompt {
            const NAME: &'static str = "test";
            const DESCRIPTION: &'static str = "A test prompt";

            fn arguments(&self) -> Vec<PromptArgument> {
                vec![PromptArgument {
                    name: "input".to_string(),
                    description: Some("Test input".to_string()),
                    required: true,
                }]
            }

            async fn get(&self, args: HashMap<String, String>) -> Result<GetPromptResult> {
                let input = args.get("input").map(String::as_str).unwrap_or("default");
                Ok(user_text_result("Test", format!("Input: {}", input)))
            }
        }

        let prompt = TestPrompt.into_prompt();
        assert_eq!(prompt.name, "test");
        assert_eq!(prompt.arguments.len(), 1);

        let result = prompt.get(single_arg("input", "hello")).await.unwrap();

        assert_eq!(first_text(&result), "Input: hello");
    }

    /// `definition()` reflects the name, description and declared arguments.
    #[test]
    fn test_prompt_definition() {
        let prompt = PromptBuilder::new("test")
            .description("Test description")
            .required_arg("arg1", "First arg")
            .optional_arg("arg2", "Second arg")
            .user_message("Test");

        let def = prompt.definition();

        assert_eq!(def.name, "test");
        assert_eq!(def.description.as_deref(), Some("Test description"));
        assert_eq!(def.arguments.len(), 2);
        assert!(def.arguments[0].required);
        assert!(!def.arguments[1].required);
    }

    /// A context-aware handler receives the `RequestContext` and is flagged
    /// via `uses_context()`.
    #[tokio::test]
    async fn test_handler_with_context() {
        let prompt = PromptBuilder::new("context_prompt")
            .description("A prompt with context")
            .handler_with_context(|ctx: RequestContext, args| async move {
                // Touch the context to show it is usable inside the handler.
                let _ = ctx.is_cancelled();
                Ok(user_text_result("Context prompt", greeting(&args)))
            })
            .build();

        assert_eq!(prompt.name, "context_prompt");
        assert!(prompt.uses_context());

        let ctx = RequestContext::new(RequestId::Number(1));
        let result = prompt
            .get_with_context(ctx, single_arg("name", "Alice"))
            .await
            .unwrap();

        assert_eq!(first_text(&result), "Hello, Alice!");
    }

    /// A generous timeout layer does not interfere with a fast handler.
    #[tokio::test]
    async fn test_prompt_with_timeout_layer() {
        let prompt = PromptBuilder::new("timeout_prompt")
            .description("A prompt with timeout")
            .handler(|args: HashMap<String, String>| async move {
                Ok(user_text_result("Timeout prompt", greeting(&args)))
            })
            .layer(TimeoutLayer::new(Duration::from_secs(5)));

        assert_eq!(prompt.name, "timeout_prompt");

        let result = prompt.get(single_arg("name", "Alice")).await.unwrap();

        assert_eq!(first_text(&result), "Hello, Alice!");
    }

    /// When the timeout layer fires, the caller gets an error-describing
    /// result instead of the handler's output.
    #[tokio::test]
    async fn test_prompt_timeout_expires() {
        let prompt = PromptBuilder::new("slow_prompt")
            .description("A slow prompt")
            .handler(|_args: HashMap<String, String>| async move {
                // Sleep well past the 10ms timeout configured below.
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(user_text_result("Slow prompt", "This should not appear"))
            })
            .layer(TimeoutLayer::new(Duration::from_millis(10)));

        let result = prompt.get(HashMap::new()).await.unwrap();

        // The timeout is reported through the result rather than as an `Err`.
        let description = result.description.as_deref().unwrap();
        assert!(description.contains("error"));
        assert!(first_text(&result).contains("Error generating prompt"));
    }

    /// Layers can be stacked on a context-aware handler without losing the
    /// `uses_context()` flag.
    #[tokio::test]
    async fn test_context_handler_with_layer() {
        let prompt = PromptBuilder::new("context_timeout")
            .description("Context prompt with timeout")
            .handler_with_context(
                |_ctx: RequestContext, args: HashMap<String, String>| async move {
                    Ok(user_text_result("Context timeout", greeting(&args)))
                },
            )
            .layer(TimeoutLayer::new(Duration::from_secs(5)));

        assert_eq!(prompt.name, "context_timeout");
        assert!(prompt.uses_context());

        let ctx = RequestContext::new(RequestId::Number(1));
        let result = prompt
            .get_with_context(ctx, single_arg("name", "Bob"))
            .await
            .unwrap();

        assert_eq!(first_text(&result), "Hello, Bob!");
    }

    /// Both `PromptRequest` constructors keep the supplied arguments.
    #[test]
    fn test_prompt_request_construction() {
        let args = single_arg("key", "value");

        let without_ctx = PromptRequest::with_arguments(args.clone());
        assert_eq!(
            without_ctx.arguments.get("key").map(String::as_str),
            Some("value")
        );

        let ctx = RequestContext::new(RequestId::Number(42));
        let with_ctx = PromptRequest::new(ctx, args);
        assert_eq!(
            with_ctx.arguments.get("key").map(String::as_str),
            Some("value")
        );
    }

    /// `PromptCatchError` wrapping a `PromptHandlerService` can be built and
    /// cloned.
    #[test]
    fn test_prompt_catch_error_clone() {
        let service = PromptHandlerService {
            handler: |_args: HashMap<String, String>| async {
                Ok::<GetPromptResult, Error>(GetPromptResult {
                    description: None,
                    messages: Vec::new(),
                })
            },
        };
        let catch_error = PromptCatchError::new(service);
        let _clone = catch_error.clone();
        // No `Debug` check here: PromptCatchError with PromptHandlerService
        // doesn't implement Debug because the handler closure doesn't.
    }
}