1use std::collections::HashMap;
46use std::convert::Infallible;
47use std::fmt;
48use std::future::Future;
49use std::pin::Pin;
50use std::sync::Arc;
51use std::task::{Context, Poll};
52
53use pin_project_lite::pin_project;
54
55use tokio::sync::Mutex;
56use tower::util::BoxCloneService;
57use tower::{Layer, ServiceExt};
58use tower_service::Service;
59
60use crate::context::RequestContext;
61use crate::error::{Error, Result};
62use crate::protocol::{
63 Content, GetPromptResult, PromptArgument, PromptDefinition, PromptMessage, PromptRole,
64 RequestId, ToolIcon,
65};
66
/// Heap-allocated, type-erased future that is `Send` and may borrow data for `'a`.
///
/// Used as the return type of the object-safe prompt handler methods.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
69
70#[derive(Debug, Clone)]
79pub struct PromptRequest {
80 pub context: RequestContext,
82 pub arguments: HashMap<String, String>,
84}
85
86impl PromptRequest {
87 pub fn new(context: RequestContext, arguments: HashMap<String, String>) -> Self {
89 Self { context, arguments }
90 }
91
92 pub fn with_arguments(arguments: HashMap<String, String>) -> Self {
94 Self {
95 context: RequestContext::new(RequestId::Number(0)),
96 arguments,
97 }
98 }
99}
100
101pub type BoxPromptService = BoxCloneService<PromptRequest, GetPromptResult, Infallible>;
107
108#[doc(hidden)]
116pub struct PromptCatchError<S> {
117 inner: S,
118}
119
120impl<S> PromptCatchError<S> {
121 pub fn new(inner: S) -> Self {
123 Self { inner }
124 }
125}
126
127impl<S: Clone> Clone for PromptCatchError<S> {
128 fn clone(&self) -> Self {
129 Self {
130 inner: self.inner.clone(),
131 }
132 }
133}
134
135impl<S: fmt::Debug> fmt::Debug for PromptCatchError<S> {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 f.debug_struct("PromptCatchError")
138 .field("inner", &self.inner)
139 .finish()
140 }
141}
142
143pin_project! {
144 #[doc(hidden)]
146 pub struct PromptCatchErrorFuture<F> {
147 #[pin]
148 inner: F,
149 }
150}
151
152impl<F, E> Future for PromptCatchErrorFuture<F>
153where
154 F: Future<Output = std::result::Result<GetPromptResult, E>>,
155 E: fmt::Display,
156{
157 type Output = std::result::Result<GetPromptResult, Infallible>;
158
159 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
160 match self.project().inner.poll(cx) {
161 Poll::Pending => Poll::Pending,
162 Poll::Ready(Ok(response)) => Poll::Ready(Ok(response)),
163 Poll::Ready(Err(err)) => Poll::Ready(Ok(GetPromptResult {
164 description: Some(format!("Prompt error: {}", err)),
165 messages: vec![PromptMessage {
166 role: PromptRole::Assistant,
167 content: Content::Text {
168 text: format!("Error generating prompt: {}", err),
169 annotations: None,
170 meta: None,
171 },
172 meta: None,
173 }],
174 meta: None,
175 })),
176 }
177 }
178}
179
180impl<S> Service<PromptRequest> for PromptCatchError<S>
181where
182 S: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
183 S::Error: fmt::Display + Send,
184 S::Future: Send,
185{
186 type Response = GetPromptResult;
187 type Error = Infallible;
188 type Future = PromptCatchErrorFuture<S::Future>;
189
190 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
191 self.inner.poll_ready(cx).map_err(|_| unreachable!())
192 }
193
194 fn call(&mut self, req: PromptRequest) -> Self::Future {
195 PromptCatchErrorFuture {
196 inner: self.inner.call(req),
197 }
198 }
199}
200
201#[doc(hidden)]
206pub struct PromptHandlerService<F> {
207 handler: F,
208}
209
210impl<F> Clone for PromptHandlerService<F>
211where
212 F: Clone,
213{
214 fn clone(&self) -> Self {
215 Self {
216 handler: self.handler.clone(),
217 }
218 }
219}
220
221impl<F, Fut> Service<PromptRequest> for PromptHandlerService<F>
222where
223 F: Fn(HashMap<String, String>) -> Fut + Clone + Send + Sync + 'static,
224 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
225{
226 type Response = GetPromptResult;
227 type Error = Error;
228 type Future = Pin<Box<dyn Future<Output = std::result::Result<GetPromptResult, Error>> + Send>>;
229
230 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
231 Poll::Ready(Ok(()))
232 }
233
234 fn call(&mut self, req: PromptRequest) -> Self::Future {
235 let handler = self.handler.clone();
236 Box::pin(async move { handler(req.arguments).await })
237 }
238}
239
240#[doc(hidden)]
244pub struct PromptContextHandlerService<F> {
245 handler: F,
246}
247
248impl<F> Clone for PromptContextHandlerService<F>
249where
250 F: Clone,
251{
252 fn clone(&self) -> Self {
253 Self {
254 handler: self.handler.clone(),
255 }
256 }
257}
258
259impl<F, Fut> Service<PromptRequest> for PromptContextHandlerService<F>
260where
261 F: Fn(RequestContext, HashMap<String, String>) -> Fut + Clone + Send + Sync + 'static,
262 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
263{
264 type Response = GetPromptResult;
265 type Error = Error;
266 type Future = Pin<Box<dyn Future<Output = std::result::Result<GetPromptResult, Error>> + Send>>;
267
268 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
269 Poll::Ready(Ok(()))
270 }
271
272 fn call(&mut self, req: PromptRequest) -> Self::Future {
273 let handler = self.handler.clone();
274 Box::pin(async move { handler(req.context, req.arguments).await })
275 }
276}
277
278pub trait PromptHandler: Send + Sync {
280 fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>>;
282
283 fn get_with_context(
288 &self,
289 _ctx: RequestContext,
290 arguments: HashMap<String, String>,
291 ) -> BoxFuture<'_, Result<GetPromptResult>> {
292 self.get(arguments)
293 }
294
295 fn uses_context(&self) -> bool {
297 false
298 }
299}
300
301pub struct Prompt {
303 pub name: String,
305 pub title: Option<String>,
307 pub description: Option<String>,
309 pub icons: Option<Vec<ToolIcon>>,
311 pub arguments: Vec<PromptArgument>,
313 handler: Arc<dyn PromptHandler>,
314}
315
316impl Clone for Prompt {
317 fn clone(&self) -> Self {
318 Self {
319 name: self.name.clone(),
320 title: self.title.clone(),
321 description: self.description.clone(),
322 icons: self.icons.clone(),
323 arguments: self.arguments.clone(),
324 handler: self.handler.clone(),
325 }
326 }
327}
328
329impl std::fmt::Debug for Prompt {
330 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331 f.debug_struct("Prompt")
332 .field("name", &self.name)
333 .field("title", &self.title)
334 .field("description", &self.description)
335 .field("icons", &self.icons)
336 .field("arguments", &self.arguments)
337 .finish_non_exhaustive()
338 }
339}
340
341impl Prompt {
342 pub fn builder(name: impl Into<String>) -> PromptBuilder {
344 PromptBuilder::new(name)
345 }
346
347 pub fn definition(&self) -> PromptDefinition {
349 PromptDefinition {
350 name: self.name.clone(),
351 title: self.title.clone(),
352 description: self.description.clone(),
353 icons: self.icons.clone(),
354 arguments: self.arguments.clone(),
355 meta: None,
356 }
357 }
358
359 pub fn get(
361 &self,
362 arguments: HashMap<String, String>,
363 ) -> BoxFuture<'_, Result<GetPromptResult>> {
364 self.handler.get(arguments)
365 }
366
367 pub fn get_with_context(
371 &self,
372 ctx: RequestContext,
373 arguments: HashMap<String, String>,
374 ) -> BoxFuture<'_, Result<GetPromptResult>> {
375 self.handler.get_with_context(ctx, arguments)
376 }
377
378 pub fn uses_context(&self) -> bool {
380 self.handler.uses_context()
381 }
382}
383
384pub struct PromptBuilder {
420 name: String,
421 title: Option<String>,
422 description: Option<String>,
423 icons: Option<Vec<ToolIcon>>,
424 arguments: Vec<PromptArgument>,
425}
426
427impl PromptBuilder {
428 pub fn new(name: impl Into<String>) -> Self {
430 Self {
431 name: name.into(),
432 title: None,
433 description: None,
434 icons: None,
435 arguments: Vec::new(),
436 }
437 }
438
439 pub fn title(mut self, title: impl Into<String>) -> Self {
441 self.title = Some(title.into());
442 self
443 }
444
445 pub fn description(mut self, description: impl Into<String>) -> Self {
447 self.description = Some(description.into());
448 self
449 }
450
451 pub fn icon(mut self, src: impl Into<String>) -> Self {
453 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
454 src: src.into(),
455 mime_type: None,
456 sizes: None,
457 theme: None,
458 });
459 self
460 }
461
462 pub fn icon_with_meta(
464 mut self,
465 src: impl Into<String>,
466 mime_type: Option<String>,
467 sizes: Option<Vec<String>>,
468 ) -> Self {
469 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
470 src: src.into(),
471 mime_type,
472 sizes,
473 theme: None,
474 });
475 self
476 }
477
478 pub fn required_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
480 self.arguments.push(PromptArgument {
481 name: name.into(),
482 description: Some(description.into()),
483 required: true,
484 });
485 self
486 }
487
488 pub fn optional_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
490 self.arguments.push(PromptArgument {
491 name: name.into(),
492 description: Some(description.into()),
493 required: false,
494 });
495 self
496 }
497
498 pub fn argument(mut self, arg: PromptArgument) -> Self {
500 self.arguments.push(arg);
501 self
502 }
503
504 pub fn handler<F, Fut>(self, handler: F) -> PromptBuilderWithHandler<F>
552 where
553 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
554 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
555 {
556 PromptBuilderWithHandler {
557 name: self.name,
558 title: self.title,
559 description: self.description,
560 icons: self.icons,
561 arguments: self.arguments,
562 handler,
563 }
564 }
565
566 pub fn handler_with_context<F, Fut>(self, handler: F) -> PromptBuilderWithContextHandler<F>
571 where
572 F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
573 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
574 {
575 PromptBuilderWithContextHandler {
576 name: self.name,
577 title: self.title,
578 description: self.description,
579 icons: self.icons,
580 arguments: self.arguments,
581 handler,
582 }
583 }
584
585 pub fn static_prompt(self, messages: Vec<PromptMessage>) -> Prompt {
587 let description = self.description.clone();
588 self.handler(move |_| {
589 let messages = messages.clone();
590 let description = description.clone();
591 async move {
592 Ok(GetPromptResult {
593 description,
594 messages,
595 meta: None,
596 })
597 }
598 })
599 .build()
600 }
601
602 pub fn user_message(self, text: impl Into<String>) -> Prompt {
604 let text = text.into();
605 self.static_prompt(vec![PromptMessage {
606 role: PromptRole::User,
607 content: Content::Text {
608 text,
609 annotations: None,
610 meta: None,
611 },
612 meta: None,
613 }])
614 }
615
616 pub fn build<F, Fut>(self, handler: F) -> Prompt
621 where
622 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
623 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
624 {
625 self.handler(handler).build()
626 }
627}
628
629#[doc(hidden)]
634pub struct PromptBuilderWithHandler<F> {
635 name: String,
636 title: Option<String>,
637 description: Option<String>,
638 icons: Option<Vec<ToolIcon>>,
639 arguments: Vec<PromptArgument>,
640 handler: F,
641}
642
643impl<F, Fut> PromptBuilderWithHandler<F>
644where
645 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
646 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
647{
648 pub fn build(self) -> Prompt {
650 Prompt {
651 name: self.name,
652 title: self.title,
653 description: self.description,
654 icons: self.icons,
655 arguments: self.arguments,
656 handler: Arc::new(FnHandler {
657 handler: self.handler,
658 }),
659 }
660 }
661
662 pub fn layer<L>(self, layer: L) -> Prompt
696 where
697 L: Layer<PromptHandlerService<F>> + Send + Sync + 'static,
698 L::Service: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
699 <L::Service as Service<PromptRequest>>::Error: fmt::Display + Send,
700 <L::Service as Service<PromptRequest>>::Future: Send,
701 {
702 let service = PromptHandlerService {
703 handler: self.handler,
704 };
705 let wrapped = layer.layer(service);
706 let boxed = BoxCloneService::new(PromptCatchError::new(wrapped));
707
708 Prompt {
709 name: self.name,
710 title: self.title,
711 description: self.description,
712 icons: self.icons,
713 arguments: self.arguments,
714 handler: Arc::new(ServiceHandler {
715 service: Mutex::new(boxed),
716 }),
717 }
718 }
719}
720
721#[doc(hidden)]
723pub struct PromptBuilderWithContextHandler<F> {
724 name: String,
725 title: Option<String>,
726 description: Option<String>,
727 icons: Option<Vec<ToolIcon>>,
728 arguments: Vec<PromptArgument>,
729 handler: F,
730}
731
732impl<F, Fut> PromptBuilderWithContextHandler<F>
733where
734 F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
735 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
736{
737 pub fn build(self) -> Prompt {
739 Prompt {
740 name: self.name,
741 title: self.title,
742 description: self.description,
743 icons: self.icons,
744 arguments: self.arguments,
745 handler: Arc::new(ContextAwareHandler {
746 handler: self.handler,
747 }),
748 }
749 }
750
751 pub fn layer<L>(self, layer: L) -> Prompt
753 where
754 L: Layer<PromptContextHandlerService<F>> + Send + Sync + 'static,
755 L::Service: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
756 <L::Service as Service<PromptRequest>>::Error: fmt::Display + Send,
757 <L::Service as Service<PromptRequest>>::Future: Send,
758 {
759 let service = PromptContextHandlerService {
760 handler: self.handler,
761 };
762 let wrapped = layer.layer(service);
763 let boxed = BoxCloneService::new(PromptCatchError::new(wrapped));
764
765 Prompt {
766 name: self.name,
767 title: self.title,
768 description: self.description,
769 icons: self.icons,
770 arguments: self.arguments,
771 handler: Arc::new(ServiceContextHandler {
772 service: Mutex::new(boxed),
773 }),
774 }
775 }
776}
777
778struct FnHandler<F> {
784 handler: F,
785}
786
787impl<F, Fut> PromptHandler for FnHandler<F>
788where
789 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
790 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
791{
792 fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
793 Box::pin((self.handler)(arguments))
794 }
795}
796
797struct ContextAwareHandler<F> {
799 handler: F,
800}
801
802impl<F, Fut> PromptHandler for ContextAwareHandler<F>
803where
804 F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + 'static,
805 Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
806{
807 fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
808 let ctx = RequestContext::new(RequestId::Number(0));
810 self.get_with_context(ctx, arguments)
811 }
812
813 fn get_with_context(
814 &self,
815 ctx: RequestContext,
816 arguments: HashMap<String, String>,
817 ) -> BoxFuture<'_, Result<GetPromptResult>> {
818 Box::pin((self.handler)(ctx, arguments))
819 }
820
821 fn uses_context(&self) -> bool {
822 true
823 }
824}
825
826struct ServiceHandler {
832 service: Mutex<BoxPromptService>,
833}
834
835impl PromptHandler for ServiceHandler {
836 fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
837 Box::pin(async move {
838 let req = PromptRequest::with_arguments(arguments);
839 let mut service = self.service.lock().await.clone();
840 match service.ready().await {
841 Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
842 Err(e) => match e {},
843 }
844 })
845 }
846
847 fn get_with_context(
848 &self,
849 ctx: RequestContext,
850 arguments: HashMap<String, String>,
851 ) -> BoxFuture<'_, Result<GetPromptResult>> {
852 Box::pin(async move {
853 let req = PromptRequest::new(ctx, arguments);
854 let mut service = self.service.lock().await.clone();
855 match service.ready().await {
856 Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
857 Err(e) => match e {},
858 }
859 })
860 }
861}
862
863struct ServiceContextHandler {
865 service: Mutex<BoxPromptService>,
866}
867
868impl PromptHandler for ServiceContextHandler {
869 fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
870 let ctx = RequestContext::new(RequestId::Number(0));
871 self.get_with_context(ctx, arguments)
872 }
873
874 fn get_with_context(
875 &self,
876 ctx: RequestContext,
877 arguments: HashMap<String, String>,
878 ) -> BoxFuture<'_, Result<GetPromptResult>> {
879 Box::pin(async move {
880 let req = PromptRequest::new(ctx, arguments);
881 let mut service = self.service.lock().await.clone();
882 match service.ready().await {
883 Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
884 Err(e) => match e {},
885 }
886 })
887 }
888
889 fn uses_context(&self) -> bool {
890 true
891 }
892}
893
894pub trait McpPrompt: Send + Sync + 'static {
956 const NAME: &'static str;
958 const DESCRIPTION: &'static str;
960
961 fn arguments(&self) -> Vec<PromptArgument> {
963 Vec::new()
964 }
965
966 fn get(
968 &self,
969 arguments: HashMap<String, String>,
970 ) -> impl Future<Output = Result<GetPromptResult>> + Send;
971
972 fn into_prompt(self) -> Prompt
974 where
975 Self: Sized,
976 {
977 let arguments = self.arguments();
978 let prompt = Arc::new(self);
979 Prompt {
980 name: Self::NAME.to_string(),
981 title: None,
982 description: Some(Self::DESCRIPTION.to_string()),
983 icons: None,
984 arguments,
985 handler: Arc::new(McpPromptHandler { prompt }),
986 }
987 }
988}
989
990struct McpPromptHandler<T: McpPrompt> {
992 prompt: Arc<T>,
993}
994
995impl<T: McpPrompt> PromptHandler for McpPromptHandler<T> {
996 fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
997 let prompt = self.prompt.clone();
998 Box::pin(async move { prompt.get(arguments).await })
999 }
1000}
1001
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tower::timeout::TimeoutLayer;

    /// Builds a result holding one user text message.
    fn user_text_result(description: &str, text: impl Into<String>) -> GetPromptResult {
        GetPromptResult {
            description: Some(description.to_string()),
            messages: vec![PromptMessage {
                role: PromptRole::User,
                content: Content::Text {
                    text: text.into(),
                    annotations: None,
                    meta: None,
                },
                meta: None,
            }],
            meta: None,
        }
    }

    /// Returns the text of message `index`, panicking on non-text content.
    fn message_text(result: &GetPromptResult, index: usize) -> &str {
        match &result.messages[index].content {
            Content::Text { text, .. } => text.as_str(),
            _ => panic!("Expected text content"),
        }
    }

    #[tokio::test]
    async fn test_builder_prompt() {
        let prompt = PromptBuilder::new("greet")
            .description("A greeting prompt")
            .required_arg("name", "Name to greet")
            .handler(|args| async move {
                let name = args.get("name").map(String::as_str).unwrap_or("World");
                Ok(user_text_result("Greeting", format!("Hello, {}!", name)))
            })
            .build();

        assert_eq!(prompt.name, "greet");
        assert_eq!(prompt.description.as_deref(), Some("A greeting prompt"));
        assert_eq!(prompt.arguments.len(), 1);
        assert!(prompt.arguments[0].required);

        let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
        let result = prompt.get(args).await.unwrap();

        assert_eq!(result.messages.len(), 1);
        assert_eq!(message_text(&result, 0), "Hello, Alice!");
    }

    #[tokio::test]
    async fn test_static_prompt() {
        let prompt = PromptBuilder::new("help")
            .description("Help prompt")
            .user_message("How can I help you today?");

        let result = prompt.get(HashMap::new()).await.unwrap();
        assert_eq!(result.messages.len(), 1);
        assert_eq!(message_text(&result, 0), "How can I help you today?");
    }

    #[tokio::test]
    async fn test_trait_prompt() {
        struct TestPrompt;

        impl McpPrompt for TestPrompt {
            const NAME: &'static str = "test";
            const DESCRIPTION: &'static str = "A test prompt";

            fn arguments(&self) -> Vec<PromptArgument> {
                vec![PromptArgument {
                    name: "input".to_string(),
                    description: Some("Test input".to_string()),
                    required: true,
                }]
            }

            async fn get(&self, args: HashMap<String, String>) -> Result<GetPromptResult> {
                let input = args.get("input").map(String::as_str).unwrap_or("default");
                Ok(user_text_result("Test", format!("Input: {}", input)))
            }
        }

        let prompt = TestPrompt.into_prompt();
        assert_eq!(prompt.name, "test");
        assert_eq!(prompt.arguments.len(), 1);

        let args = HashMap::from([("input".to_string(), "hello".to_string())]);
        let result = prompt.get(args).await.unwrap();

        assert_eq!(message_text(&result, 0), "Input: hello");
    }

    #[test]
    fn test_prompt_definition() {
        let prompt = PromptBuilder::new("test")
            .description("Test description")
            .required_arg("arg1", "First arg")
            .optional_arg("arg2", "Second arg")
            .user_message("Test");

        let def = prompt.definition();
        assert_eq!(def.name, "test");
        assert_eq!(def.description.as_deref(), Some("Test description"));
        assert_eq!(def.arguments.len(), 2);
        assert!(def.arguments[0].required);
        assert!(!def.arguments[1].required);
    }

    #[tokio::test]
    async fn test_handler_with_context() {
        let prompt = PromptBuilder::new("context_prompt")
            .description("A prompt with context")
            .handler_with_context(|ctx: RequestContext, args| async move {
                // Touch the context to prove it is usable inside the handler.
                let _ = ctx.is_cancelled();
                let name = args.get("name").map(String::as_str).unwrap_or("World");
                Ok(user_text_result("Context prompt", format!("Hello, {}!", name)))
            })
            .build();

        assert_eq!(prompt.name, "context_prompt");
        assert!(prompt.uses_context());

        let ctx = RequestContext::new(RequestId::Number(1));
        let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
        let result = prompt.get_with_context(ctx, args).await.unwrap();

        assert_eq!(message_text(&result, 0), "Hello, Alice!");
    }

    #[tokio::test]
    async fn test_prompt_with_timeout_layer() {
        let prompt = PromptBuilder::new("timeout_prompt")
            .description("A prompt with timeout")
            .handler(|args: HashMap<String, String>| async move {
                let name = args.get("name").map(String::as_str).unwrap_or("World");
                Ok(user_text_result("Timeout prompt", format!("Hello, {}!", name)))
            })
            .layer(TimeoutLayer::new(Duration::from_secs(5)));

        assert_eq!(prompt.name, "timeout_prompt");

        let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
        let result = prompt.get(args).await.unwrap();

        assert_eq!(message_text(&result, 0), "Hello, Alice!");
    }

    #[tokio::test]
    async fn test_prompt_timeout_expires() {
        let prompt = PromptBuilder::new("slow_prompt")
            .description("A slow prompt")
            .handler(|_args: HashMap<String, String>| async move {
                // Far longer than the 50ms timeout below; the timeout fires first.
                tokio::time::sleep(Duration::from_secs(1)).await;
                Ok(user_text_result("Slow prompt", "This should not appear"))
            })
            .layer(TimeoutLayer::new(Duration::from_millis(50)));

        let result = prompt.get(HashMap::new()).await.unwrap();

        // The timeout error is converted into an error-describing result.
        assert!(result.description.as_ref().unwrap().contains("error"));
        assert!(message_text(&result, 0).contains("Error generating prompt"));
    }

    #[tokio::test]
    async fn test_context_handler_with_layer() {
        let prompt = PromptBuilder::new("context_timeout")
            .description("Context prompt with timeout")
            .handler_with_context(
                |_ctx: RequestContext, args: HashMap<String, String>| async move {
                    let name = args.get("name").map(String::as_str).unwrap_or("World");
                    Ok(user_text_result("Context timeout", format!("Hello, {}!", name)))
                },
            )
            .layer(TimeoutLayer::new(Duration::from_secs(5)));

        assert_eq!(prompt.name, "context_timeout");
        assert!(prompt.uses_context());

        let ctx = RequestContext::new(RequestId::Number(1));
        let args = HashMap::from([("name".to_string(), "Bob".to_string())]);
        let result = prompt.get_with_context(ctx, args).await.unwrap();

        assert_eq!(message_text(&result, 0), "Hello, Bob!");
    }

    #[test]
    fn test_prompt_request_construction() {
        let args: HashMap<String, String> =
            HashMap::from([("key".to_string(), "value".to_string())]);

        let req = PromptRequest::with_arguments(args.clone());
        assert_eq!(req.arguments.get("key"), Some(&"value".to_string()));

        let ctx = RequestContext::new(RequestId::Number(42));
        let req2 = PromptRequest::new(ctx, args);
        assert_eq!(req2.arguments.get("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_prompt_catch_error_clone() {
        // Only checks that the wrapper is Clone when the handler closure is.
        let handler = PromptHandlerService {
            handler: |_args: HashMap<String, String>| async {
                Ok::<GetPromptResult, Error>(GetPromptResult {
                    description: None,
                    messages: vec![],
                    meta: None,
                })
            },
        };
        let catch_error = PromptCatchError::new(handler);
        let _clone = catch_error.clone();
    }
}