1use std::collections::HashMap;
43use std::convert::Infallible;
44use std::fmt;
45use std::future::Future;
46use std::pin::Pin;
47use std::sync::Arc;
48use std::task::{Context, Poll};
49
50use tokio::sync::Mutex;
51use tower::util::BoxCloneService;
52use tower::{Layer, ServiceExt};
53use tower_service::Service;
54
55use crate::context::RequestContext;
56use crate::error::{Error, Result};
57use crate::protocol::{
58 Content, GetPromptResult, PromptArgument, PromptDefinition, PromptMessage, PromptRole,
59 RequestId, ToolIcon,
60};
61
/// A heap-allocated, pinned, `Send` future that yields `T` and may borrow
/// data for the lifetime `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
64
/// The request type handed to prompt services: a request context plus the
/// string-keyed, string-valued prompt arguments.
#[derive(Debug, Clone)]
pub struct PromptRequest {
    /// Context of the request this prompt invocation belongs to.
    pub context: RequestContext,
    /// Prompt arguments, keyed by argument name.
    pub arguments: HashMap<String, String>,
}
80
81impl PromptRequest {
82 pub fn new(context: RequestContext, arguments: HashMap<String, String>) -> Self {
84 Self { context, arguments }
85 }
86
87 pub fn with_arguments(arguments: HashMap<String, String>) -> Self {
89 Self {
90 context: RequestContext::new(RequestId::Number(0)),
91 arguments,
92 }
93 }
94}
95
/// A type-erased, cloneable prompt service: takes a [`PromptRequest`],
/// produces a [`GetPromptResult`], and cannot fail (`Infallible` error).
pub type BoxPromptService = BoxCloneService<PromptRequest, GetPromptResult, Infallible>;
102
/// Wrapper around an inner prompt service.
///
/// Its `Service` implementation turns an error returned by the inner
/// service's call into a [`GetPromptResult`] that carries the error message,
/// so the wrapper's own error type is `Infallible`.
pub struct PromptCatchError<S> {
    /// The wrapped service whose errors are caught.
    inner: S,
}
113
114impl<S> PromptCatchError<S> {
115 pub fn new(inner: S) -> Self {
117 Self { inner }
118 }
119}
120
121impl<S: Clone> Clone for PromptCatchError<S> {
122 fn clone(&self) -> Self {
123 Self {
124 inner: self.inner.clone(),
125 }
126 }
127}
128
129impl<S: fmt::Debug> fmt::Debug for PromptCatchError<S> {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131 f.debug_struct("PromptCatchError")
132 .field("inner", &self.inner)
133 .finish()
134 }
135}
136
137impl<S> Service<PromptRequest> for PromptCatchError<S>
138where
139 S: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
140 S::Error: fmt::Display + Send,
141 S::Future: Send,
142{
143 type Response = GetPromptResult;
144 type Error = Infallible;
145 type Future =
146 Pin<Box<dyn Future<Output = std::result::Result<GetPromptResult, Infallible>> + Send>>;
147
148 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
149 self.inner.poll_ready(cx).map_err(|_| unreachable!())
150 }
151
152 fn call(&mut self, req: PromptRequest) -> Self::Future {
153 let fut = self.inner.call(req);
154
155 Box::pin(async move {
156 match fut.await {
157 Ok(response) => Ok(response),
158 Err(err) => {
159 Ok(GetPromptResult {
162 description: Some(format!("Prompt error: {}", err)),
163 messages: vec![PromptMessage {
164 role: PromptRole::Assistant,
165 content: Content::Text {
166 text: format!("Error generating prompt: {}", err),
167 annotations: None,
168 },
169 }],
170 })
171 }
172 }
173 })
174 }
175}
176
/// Service wrapper around a handler function that receives only the prompt
/// arguments.
pub struct PromptHandlerService<F> {
    /// The wrapped handler function.
    handler: F,
}
184
185impl<F> Clone for PromptHandlerService<F>
186where
187 F: Clone,
188{
189 fn clone(&self) -> Self {
190 Self {
191 handler: self.handler.clone(),
192 }
193 }
194}
195
196impl<F, Fut> Service<PromptRequest> for PromptHandlerService<F>
197where
198 F: Fn(HashMap<String, String>) -> Fut + Clone + Send + Sync + 'static,
199 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
200{
201 type Response = GetPromptResult;
202 type Error = Error;
203 type Future = Pin<Box<dyn Future<Output = std::result::Result<GetPromptResult, Error>> + Send>>;
204
205 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
206 Poll::Ready(Ok(()))
207 }
208
209 fn call(&mut self, req: PromptRequest) -> Self::Future {
210 let handler = self.handler.clone();
211 Box::pin(async move { handler(req.arguments).await })
212 }
213}
214
/// Service wrapper around a handler function that receives both the request
/// context and the prompt arguments.
pub struct PromptContextHandlerService<F> {
    /// The wrapped context-aware handler function.
    handler: F,
}
221
222impl<F> Clone for PromptContextHandlerService<F>
223where
224 F: Clone,
225{
226 fn clone(&self) -> Self {
227 Self {
228 handler: self.handler.clone(),
229 }
230 }
231}
232
233impl<F, Fut> Service<PromptRequest> for PromptContextHandlerService<F>
234where
235 F: Fn(RequestContext, HashMap<String, String>) -> Fut + Clone + Send + Sync + 'static,
236 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
237{
238 type Response = GetPromptResult;
239 type Error = Error;
240 type Future = Pin<Box<dyn Future<Output = std::result::Result<GetPromptResult, Error>> + Send>>;
241
242 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
243 Poll::Ready(Ok(()))
244 }
245
246 fn call(&mut self, req: PromptRequest) -> Self::Future {
247 let handler = self.handler.clone();
248 Box::pin(async move { handler(req.context, req.arguments).await })
249 }
250}
251
/// Object-safe interface for producing a prompt result from arguments.
///
/// Implementors must be `Send + Sync`; [`Prompt`] stores its handler as
/// `Arc<dyn PromptHandler>`.
pub trait PromptHandler: Send + Sync {
    /// Produces the prompt result for the given arguments.
    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>>;

    /// Produces the prompt result with access to the request context.
    ///
    /// The default implementation ignores the context and forwards to
    /// [`PromptHandler::get`].
    fn get_with_context(
        &self,
        _ctx: RequestContext,
        arguments: HashMap<String, String>,
    ) -> BoxFuture<'_, Result<GetPromptResult>> {
        self.get(arguments)
    }

    /// Reports whether this handler makes use of the request context.
    /// Defaults to `false`.
    fn uses_context(&self) -> bool {
        false
    }
}
274
/// A prompt: its descriptive metadata plus the handler that produces its
/// result.
pub struct Prompt {
    /// Name of the prompt.
    pub name: String,
    /// Optional title.
    pub title: Option<String>,
    /// Optional description.
    pub description: Option<String>,
    /// Optional list of icons.
    pub icons: Option<Vec<ToolIcon>>,
    /// Arguments the prompt declares.
    pub arguments: Vec<PromptArgument>,
    /// Shared, type-erased handler invoked by `get` / `get_with_context`.
    handler: Arc<dyn PromptHandler>,
}
284
285impl Clone for Prompt {
286 fn clone(&self) -> Self {
287 Self {
288 name: self.name.clone(),
289 title: self.title.clone(),
290 description: self.description.clone(),
291 icons: self.icons.clone(),
292 arguments: self.arguments.clone(),
293 handler: self.handler.clone(),
294 }
295 }
296}
297
298impl std::fmt::Debug for Prompt {
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 f.debug_struct("Prompt")
301 .field("name", &self.name)
302 .field("title", &self.title)
303 .field("description", &self.description)
304 .field("icons", &self.icons)
305 .field("arguments", &self.arguments)
306 .finish_non_exhaustive()
307 }
308}
309
310impl Prompt {
311 pub fn builder(name: impl Into<String>) -> PromptBuilder {
313 PromptBuilder::new(name)
314 }
315
316 pub fn definition(&self) -> PromptDefinition {
318 PromptDefinition {
319 name: self.name.clone(),
320 title: self.title.clone(),
321 description: self.description.clone(),
322 icons: self.icons.clone(),
323 arguments: self.arguments.clone(),
324 }
325 }
326
327 pub fn get(
329 &self,
330 arguments: HashMap<String, String>,
331 ) -> BoxFuture<'_, Result<GetPromptResult>> {
332 self.handler.get(arguments)
333 }
334
335 pub fn get_with_context(
339 &self,
340 ctx: RequestContext,
341 arguments: HashMap<String, String>,
342 ) -> BoxFuture<'_, Result<GetPromptResult>> {
343 self.handler.get_with_context(ctx, arguments)
344 }
345
346 pub fn uses_context(&self) -> bool {
348 self.handler.uses_context()
349 }
350}
351
/// Builder that collects a prompt's metadata before a handler is attached.
pub struct PromptBuilder {
    /// Name of the prompt being built.
    name: String,
    /// Optional title.
    title: Option<String>,
    /// Optional description.
    description: Option<String>,
    /// Icons added so far; `None` until the first icon is added.
    icons: Option<Vec<ToolIcon>>,
    /// Arguments declared so far, in insertion order.
    arguments: Vec<PromptArgument>,
}
391
392impl PromptBuilder {
393 pub fn new(name: impl Into<String>) -> Self {
394 Self {
395 name: name.into(),
396 title: None,
397 description: None,
398 icons: None,
399 arguments: Vec::new(),
400 }
401 }
402
403 pub fn title(mut self, title: impl Into<String>) -> Self {
405 self.title = Some(title.into());
406 self
407 }
408
409 pub fn description(mut self, description: impl Into<String>) -> Self {
411 self.description = Some(description.into());
412 self
413 }
414
415 pub fn icon(mut self, src: impl Into<String>) -> Self {
417 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
418 src: src.into(),
419 mime_type: None,
420 sizes: None,
421 });
422 self
423 }
424
425 pub fn icon_with_meta(
427 mut self,
428 src: impl Into<String>,
429 mime_type: Option<String>,
430 sizes: Option<Vec<String>>,
431 ) -> Self {
432 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
433 src: src.into(),
434 mime_type,
435 sizes,
436 });
437 self
438 }
439
440 pub fn required_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
442 self.arguments.push(PromptArgument {
443 name: name.into(),
444 description: Some(description.into()),
445 required: true,
446 });
447 self
448 }
449
450 pub fn optional_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
452 self.arguments.push(PromptArgument {
453 name: name.into(),
454 description: Some(description.into()),
455 required: false,
456 });
457 self
458 }
459
460 pub fn argument(mut self, arg: PromptArgument) -> Self {
462 self.arguments.push(arg);
463 self
464 }
465
466 pub fn handler<F, Fut>(self, handler: F) -> PromptBuilderWithHandler<F>
471 where
472 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
473 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
474 {
475 PromptBuilderWithHandler {
476 name: self.name,
477 title: self.title,
478 description: self.description,
479 icons: self.icons,
480 arguments: self.arguments,
481 handler,
482 }
483 }
484
485 pub fn handler_with_context<F, Fut>(self, handler: F) -> PromptBuilderWithContextHandler<F>
490 where
491 F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
492 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
493 {
494 PromptBuilderWithContextHandler {
495 name: self.name,
496 title: self.title,
497 description: self.description,
498 icons: self.icons,
499 arguments: self.arguments,
500 handler,
501 }
502 }
503
504 pub fn static_prompt(self, messages: Vec<PromptMessage>) -> Prompt {
506 let description = self.description.clone();
507 self.handler(move |_| {
508 let messages = messages.clone();
509 let description = description.clone();
510 async move {
511 Ok(GetPromptResult {
512 description,
513 messages,
514 })
515 }
516 })
517 .build()
518 }
519
520 pub fn user_message(self, text: impl Into<String>) -> Prompt {
522 let text = text.into();
523 self.static_prompt(vec![PromptMessage {
524 role: PromptRole::User,
525 content: Content::Text {
526 text,
527 annotations: None,
528 },
529 }])
530 }
531
532 pub fn build<F, Fut>(self, handler: F) -> Prompt
537 where
538 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
539 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
540 {
541 self.handler(handler).build()
542 }
543}
544
/// Builder stage reached after an arguments-only handler has been attached;
/// carries the collected metadata plus the handler.
pub struct PromptBuilderWithHandler<F> {
    /// Name of the prompt being built.
    name: String,
    /// Optional title.
    title: Option<String>,
    /// Optional description.
    description: Option<String>,
    /// Optional icons.
    icons: Option<Vec<ToolIcon>>,
    /// Declared arguments.
    arguments: Vec<PromptArgument>,
    /// The handler function attached to the prompt.
    handler: F,
}
557
558impl<F, Fut> PromptBuilderWithHandler<F>
559where
560 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
561 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
562{
563 pub fn build(self) -> Prompt {
565 Prompt {
566 name: self.name,
567 title: self.title,
568 description: self.description,
569 icons: self.icons,
570 arguments: self.arguments,
571 handler: Arc::new(FnHandler {
572 handler: self.handler,
573 }),
574 }
575 }
576
577 pub fn layer<L>(self, layer: L) -> Prompt
608 where
609 L: Layer<PromptHandlerService<F>> + Send + Sync + 'static,
610 L::Service: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
611 <L::Service as Service<PromptRequest>>::Error: fmt::Display + Send,
612 <L::Service as Service<PromptRequest>>::Future: Send,
613 {
614 let service = PromptHandlerService {
615 handler: self.handler,
616 };
617 let wrapped = layer.layer(service);
618 let boxed = BoxCloneService::new(PromptCatchError::new(wrapped));
619
620 Prompt {
621 name: self.name,
622 title: self.title,
623 description: self.description,
624 icons: self.icons,
625 arguments: self.arguments,
626 handler: Arc::new(ServiceHandler {
627 service: Mutex::new(boxed),
628 }),
629 }
630 }
631}
632
/// Builder stage reached after a context-aware handler has been attached;
/// carries the collected metadata plus the handler.
pub struct PromptBuilderWithContextHandler<F> {
    /// Name of the prompt being built.
    name: String,
    /// Optional title.
    title: Option<String>,
    /// Optional description.
    description: Option<String>,
    /// Optional icons.
    icons: Option<Vec<ToolIcon>>,
    /// Declared arguments.
    arguments: Vec<PromptArgument>,
    /// The context-aware handler function attached to the prompt.
    handler: F,
}
642
643impl<F, Fut> PromptBuilderWithContextHandler<F>
644where
645 F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
646 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
647{
648 pub fn build(self) -> Prompt {
650 Prompt {
651 name: self.name,
652 title: self.title,
653 description: self.description,
654 icons: self.icons,
655 arguments: self.arguments,
656 handler: Arc::new(ContextAwareHandler {
657 handler: self.handler,
658 }),
659 }
660 }
661
662 pub fn layer<L>(self, layer: L) -> Prompt
664 where
665 L: Layer<PromptContextHandlerService<F>> + Send + Sync + 'static,
666 L::Service: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
667 <L::Service as Service<PromptRequest>>::Error: fmt::Display + Send,
668 <L::Service as Service<PromptRequest>>::Future: Send,
669 {
670 let service = PromptContextHandlerService {
671 handler: self.handler,
672 };
673 let wrapped = layer.layer(service);
674 let boxed = BoxCloneService::new(PromptCatchError::new(wrapped));
675
676 Prompt {
677 name: self.name,
678 title: self.title,
679 description: self.description,
680 icons: self.icons,
681 arguments: self.arguments,
682 handler: Arc::new(ServiceContextHandler {
683 service: Mutex::new(boxed),
684 }),
685 }
686 }
687}
688
/// [`PromptHandler`] adapter around a plain function that takes only the
/// prompt arguments.
struct FnHandler<F> {
    /// The wrapped function.
    handler: F,
}
697
698impl<F, Fut> PromptHandler for FnHandler<F>
699where
700 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
701 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
702{
703 fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
704 Box::pin((self.handler)(arguments))
705 }
706}
707
/// [`PromptHandler`] adapter around a function that takes the request
/// context as well as the prompt arguments.
struct ContextAwareHandler<F> {
    /// The wrapped context-aware function.
    handler: F,
}
712
713impl<F, Fut> PromptHandler for ContextAwareHandler<F>
714where
715 F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + 'static,
716 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
717{
718 fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
719 let ctx = RequestContext::new(RequestId::Number(0));
721 self.get_with_context(ctx, arguments)
722 }
723
724 fn get_with_context(
725 &self,
726 ctx: RequestContext,
727 arguments: HashMap<String, String>,
728 ) -> BoxFuture<'_, Result<GetPromptResult>> {
729 Box::pin((self.handler)(ctx, arguments))
730 }
731
732 fn uses_context(&self) -> bool {
733 true
734 }
735}
736
/// [`PromptHandler`] backed by a boxed prompt service.
struct ServiceHandler {
    /// The boxed service, kept behind an async mutex; the handler
    /// implementation locks it and clones the service for each request.
    service: Mutex<BoxPromptService>,
}
745
746impl PromptHandler for ServiceHandler {
747 fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
748 Box::pin(async move {
749 let req = PromptRequest::with_arguments(arguments);
750 let mut service = self.service.lock().await.clone();
751 match service.ready().await {
752 Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
753 Err(e) => match e {},
754 }
755 })
756 }
757
758 fn get_with_context(
759 &self,
760 ctx: RequestContext,
761 arguments: HashMap<String, String>,
762 ) -> BoxFuture<'_, Result<GetPromptResult>> {
763 Box::pin(async move {
764 let req = PromptRequest::new(ctx, arguments);
765 let mut service = self.service.lock().await.clone();
766 match service.ready().await {
767 Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
768 Err(e) => match e {},
769 }
770 })
771 }
772}
773
/// [`PromptHandler`] backed by a boxed prompt service; unlike
/// `ServiceHandler`, it reports that it uses the request context.
struct ServiceContextHandler {
    /// The boxed service, kept behind an async mutex; the handler
    /// implementation locks it and clones the service for each request.
    service: Mutex<BoxPromptService>,
}
778
779impl PromptHandler for ServiceContextHandler {
780 fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
781 let ctx = RequestContext::new(RequestId::Number(0));
782 self.get_with_context(ctx, arguments)
783 }
784
785 fn get_with_context(
786 &self,
787 ctx: RequestContext,
788 arguments: HashMap<String, String>,
789 ) -> BoxFuture<'_, Result<GetPromptResult>> {
790 Box::pin(async move {
791 let req = PromptRequest::new(ctx, arguments);
792 let mut service = self.service.lock().await.clone();
793 match service.ready().await {
794 Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
795 Err(e) => match e {},
796 }
797 })
798 }
799
800 fn uses_context(&self) -> bool {
801 true
802 }
803}
804
805pub trait McpPrompt: Send + Sync + 'static {
864 const NAME: &'static str;
865 const DESCRIPTION: &'static str;
866
867 fn arguments(&self) -> Vec<PromptArgument> {
869 Vec::new()
870 }
871
872 fn get(
873 &self,
874 arguments: HashMap<String, String>,
875 ) -> impl Future<Output = Result<GetPromptResult>> + Send;
876
877 fn into_prompt(self) -> Prompt
879 where
880 Self: Sized,
881 {
882 let arguments = self.arguments();
883 let prompt = Arc::new(self);
884 Prompt {
885 name: Self::NAME.to_string(),
886 title: None,
887 description: Some(Self::DESCRIPTION.to_string()),
888 icons: None,
889 arguments,
890 handler: Arc::new(McpPromptHandler { prompt }),
891 }
892 }
893}
894
/// [`PromptHandler`] adapter around a shared [`McpPrompt`] implementation.
struct McpPromptHandler<T: McpPrompt> {
    /// The shared prompt implementation.
    prompt: Arc<T>,
}
899
900impl<T: McpPrompt> PromptHandler for McpPromptHandler<T> {
901 fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
902 let prompt = self.prompt.clone();
903 Box::pin(async move { prompt.get(arguments).await })
904 }
905}
906
// Unit tests for the prompt builder, the trait-based prompt definition, and
// the layered (tower middleware) variants.
#[cfg(test)]
mod tests {
    use super::*;

    // A builder-created prompt exposes its metadata and passes the supplied
    // arguments to the handler.
    #[tokio::test]
    async fn test_builder_prompt() {
        let prompt = PromptBuilder::new("greet")
            .description("A greeting prompt")
            .required_arg("name", "Name to greet")
            .handler(|args| async move {
                let name = args.get("name").map(|s| s.as_str()).unwrap_or("World");
                Ok(GetPromptResult {
                    description: Some("Greeting".to_string()),
                    messages: vec![PromptMessage {
                        role: PromptRole::User,
                        content: Content::Text {
                            text: format!("Hello, {}!", name),
                            annotations: None,
                        },
                    }],
                })
            })
            .build();

        assert_eq!(prompt.name, "greet");
        assert_eq!(prompt.description.as_deref(), Some("A greeting prompt"));
        assert_eq!(prompt.arguments.len(), 1);
        assert!(prompt.arguments[0].required);

        let mut args = HashMap::new();
        args.insert("name".to_string(), "Alice".to_string());
        let result = prompt.get(args).await.unwrap();

        assert_eq!(result.messages.len(), 1);
        match &result.messages[0].content {
            Content::Text { text, .. } => assert_eq!(text, "Hello, Alice!"),
            _ => panic!("Expected text content"),
        }
    }

    // `user_message` yields a static prompt with exactly one text message.
    #[tokio::test]
    async fn test_static_prompt() {
        let prompt = PromptBuilder::new("help")
            .description("Help prompt")
            .user_message("How can I help you today?");

        let result = prompt.get(HashMap::new()).await.unwrap();
        assert_eq!(result.messages.len(), 1);
        match &result.messages[0].content {
            Content::Text { text, .. } => assert_eq!(text, "How can I help you today?"),
            _ => panic!("Expected text content"),
        }
    }

    // A type implementing `McpPrompt` converts into a working `Prompt`.
    #[tokio::test]
    async fn test_trait_prompt() {
        struct TestPrompt;

        impl McpPrompt for TestPrompt {
            const NAME: &'static str = "test";
            const DESCRIPTION: &'static str = "A test prompt";

            fn arguments(&self) -> Vec<PromptArgument> {
                vec![PromptArgument {
                    name: "input".to_string(),
                    description: Some("Test input".to_string()),
                    required: true,
                }]
            }

            async fn get(&self, args: HashMap<String, String>) -> Result<GetPromptResult> {
                let input = args.get("input").map(|s| s.as_str()).unwrap_or("default");
                Ok(GetPromptResult {
                    description: Some("Test".to_string()),
                    messages: vec![PromptMessage {
                        role: PromptRole::User,
                        content: Content::Text {
                            text: format!("Input: {}", input),
                            annotations: None,
                        },
                    }],
                })
            }
        }

        let prompt = TestPrompt.into_prompt();
        assert_eq!(prompt.name, "test");
        assert_eq!(prompt.arguments.len(), 1);

        let mut args = HashMap::new();
        args.insert("input".to_string(), "hello".to_string());
        let result = prompt.get(args).await.unwrap();

        match &result.messages[0].content {
            Content::Text { text, .. } => assert_eq!(text, "Input: hello"),
            _ => panic!("Expected text content"),
        }
    }

    // `definition` mirrors the builder's metadata, including the order and
    // required flags of the arguments.
    #[test]
    fn test_prompt_definition() {
        let prompt = PromptBuilder::new("test")
            .description("Test description")
            .required_arg("arg1", "First arg")
            .optional_arg("arg2", "Second arg")
            .user_message("Test");

        let def = prompt.definition();
        assert_eq!(def.name, "test");
        assert_eq!(def.description.as_deref(), Some("Test description"));
        assert_eq!(def.arguments.len(), 2);
        assert!(def.arguments[0].required);
        assert!(!def.arguments[1].required);
    }

    // A context-aware handler receives the context and reports
    // `uses_context() == true`.
    #[tokio::test]
    async fn test_handler_with_context() {
        let prompt = PromptBuilder::new("context_prompt")
            .description("A prompt with context")
            .handler_with_context(|ctx: RequestContext, args| async move {
                // Touch the context to show it is usable inside the handler.
                let _ = ctx.is_cancelled();
                let name = args.get("name").map(|s| s.as_str()).unwrap_or("World");
                Ok(GetPromptResult {
                    description: Some("Context prompt".to_string()),
                    messages: vec![PromptMessage {
                        role: PromptRole::User,
                        content: Content::Text {
                            text: format!("Hello, {}!", name),
                            annotations: None,
                        },
                    }],
                })
            })
            .build();

        assert_eq!(prompt.name, "context_prompt");
        assert!(prompt.uses_context());

        let ctx = RequestContext::new(RequestId::Number(1));
        let mut args = HashMap::new();
        args.insert("name".to_string(), "Alice".to_string());
        let result = prompt.get_with_context(ctx, args).await.unwrap();

        match &result.messages[0].content {
            Content::Text { text, .. } => assert_eq!(text, "Hello, Alice!"),
            _ => panic!("Expected text content"),
        }
    }

    // A generous timeout layer does not interfere with a fast handler.
    #[tokio::test]
    async fn test_prompt_with_timeout_layer() {
        use std::time::Duration;
        use tower::timeout::TimeoutLayer;

        let prompt = PromptBuilder::new("timeout_prompt")
            .description("A prompt with timeout")
            .handler(|args: HashMap<String, String>| async move {
                let name = args.get("name").map(|s| s.as_str()).unwrap_or("World");
                Ok(GetPromptResult {
                    description: Some("Timeout prompt".to_string()),
                    messages: vec![PromptMessage {
                        role: PromptRole::User,
                        content: Content::Text {
                            text: format!("Hello, {}!", name),
                            annotations: None,
                        },
                    }],
                })
            })
            .layer(TimeoutLayer::new(Duration::from_secs(5)));

        assert_eq!(prompt.name, "timeout_prompt");

        let mut args = HashMap::new();
        args.insert("name".to_string(), "Alice".to_string());
        let result = prompt.get(args).await.unwrap();

        match &result.messages[0].content {
            Content::Text { text, .. } => assert_eq!(text, "Hello, Alice!"),
            _ => panic!("Expected text content"),
        }
    }

    // When the timeout fires before the handler finishes, the call still
    // returns `Ok`, with the error reported inside the result.
    #[tokio::test]
    async fn test_prompt_timeout_expires() {
        use std::time::Duration;
        use tower::timeout::TimeoutLayer;

        let prompt = PromptBuilder::new("slow_prompt")
            .description("A slow prompt")
            .handler(|_args: HashMap<String, String>| async move {
                // Sleep longer than the 10 ms timeout configured below.
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(GetPromptResult {
                    description: Some("Slow prompt".to_string()),
                    messages: vec![PromptMessage {
                        role: PromptRole::User,
                        content: Content::Text {
                            text: "This should not appear".to_string(),
                            annotations: None,
                        },
                    }],
                })
            })
            .layer(TimeoutLayer::new(Duration::from_millis(10)));

        let result = prompt.get(HashMap::new()).await.unwrap();

        // The description and the message both carry the error text.
        assert!(result.description.as_ref().unwrap().contains("error"));
        match &result.messages[0].content {
            Content::Text { text, .. } => {
                assert!(text.contains("Error generating prompt"));
            }
            _ => panic!("Expected text content"),
        }
    }

    // Layers also work with context-aware handlers, and the resulting prompt
    // still reports that it uses the context.
    #[tokio::test]
    async fn test_context_handler_with_layer() {
        use std::time::Duration;
        use tower::timeout::TimeoutLayer;

        let prompt = PromptBuilder::new("context_timeout")
            .description("Context prompt with timeout")
            .handler_with_context(
                |_ctx: RequestContext, args: HashMap<String, String>| async move {
                    let name = args.get("name").map(|s| s.as_str()).unwrap_or("World");
                    Ok(GetPromptResult {
                        description: Some("Context timeout".to_string()),
                        messages: vec![PromptMessage {
                            role: PromptRole::User,
                            content: Content::Text {
                                text: format!("Hello, {}!", name),
                                annotations: None,
                            },
                        }],
                    })
                },
            )
            .layer(TimeoutLayer::new(Duration::from_secs(5)));

        assert_eq!(prompt.name, "context_timeout");
        assert!(prompt.uses_context());

        let ctx = RequestContext::new(RequestId::Number(1));
        let mut args = HashMap::new();
        args.insert("name".to_string(), "Bob".to_string());
        let result = prompt.get_with_context(ctx, args).await.unwrap();

        match &result.messages[0].content {
            Content::Text { text, .. } => assert_eq!(text, "Hello, Bob!"),
            _ => panic!("Expected text content"),
        }
    }

    // Both `PromptRequest` constructors keep the supplied arguments.
    #[test]
    fn test_prompt_request_construction() {
        let args: HashMap<String, String> = [("key".to_string(), "value".to_string())]
            .into_iter()
            .collect();

        let req = PromptRequest::with_arguments(args.clone());
        assert_eq!(req.arguments.get("key"), Some(&"value".to_string()));

        let ctx = RequestContext::new(RequestId::Number(42));
        let req2 = PromptRequest::new(ctx, args);
        assert_eq!(req2.arguments.get("key"), Some(&"value".to_string()));
    }

    // `PromptCatchError` is cloneable when the wrapped service is.
    #[test]
    fn test_prompt_catch_error_clone() {
        let handler = PromptHandlerService {
            handler: |_args: HashMap<String, String>| async {
                Ok::<GetPromptResult, Error>(GetPromptResult {
                    description: None,
                    messages: vec![],
                })
            },
        };
        let catch_error = PromptCatchError::new(handler);
        let _clone = catch_error.clone();
    }
}