1use anyhow::Result;
35
36use crate::{node::Node, template::Template};
37
38#[derive(Clone, Debug)]
40pub struct Prompt {
41 template: Template,
42 context: Option<tera::Context>,
43}
44
45#[deprecated(
46 since = "0.16.0",
47 note = "Use `Template` instead; they serve a more general purpose"
48)]
49pub type PromptTemplate = Template;
50
51impl Prompt {
52 #[must_use]
54 pub fn with_node(mut self, node: &Node) -> Self {
55 let context = self.context.get_or_insert_with(tera::Context::default);
56 context.insert("node", &node);
57 self
58 }
59
60 #[must_use]
62 pub fn with_context(mut self, new_context: impl Into<tera::Context>) -> Self {
63 let context = self.context.get_or_insert_with(tera::Context::default);
64 context.extend(new_context.into());
65
66 self
67 }
68
69 #[must_use]
71 pub fn with_context_value(mut self, key: &str, value: impl Into<tera::Value>) -> Self {
72 let context = self.context.get_or_insert_with(tera::Context::default);
73 context.insert(key, &value.into());
74 self
75 }
76
77 pub async fn render(&self) -> Result<String> {
85 if let Some(context) = &self.context {
86 self.template.render(context).await
87 } else {
88 match &self.template {
89 Template::CompiledTemplate(_) => {
90 self.template.render(&tera::Context::default()).await
91 }
92 Template::String(string) => Ok(string.clone()),
93 Template::Static(string) => Ok((*string).to_string()),
94 }
95 }
96 }
97}
98
99impl From<&'static str> for Prompt {
100 fn from(prompt: &'static str) -> Self {
101 Prompt {
102 template: Template::Static(prompt),
103 context: None,
104 }
105 }
106}
107
108impl From<String> for Prompt {
109 fn from(prompt: String) -> Self {
110 Prompt {
111 template: Template::String(prompt),
112 context: None,
113 }
114 }
115}
116
117impl From<&Template> for Prompt {
118 fn from(template: &Template) -> Self {
119 Prompt {
120 template: template.clone(),
121 context: None,
122 }
123 }
124}
125
#[cfg(test)]
mod test {
    use super::*;

    #[tokio::test]
    async fn test_prompt() {
        let template = Template::try_compiled_from_str("hello {{world}}")
            .await
            .unwrap();
        let prompt = template.to_prompt().with_context_value("world", "swiftide");
        assert_eq!(prompt.render().await.unwrap(), "hello swiftide");
    }

    #[tokio::test]
    async fn test_prompt_with_node() {
        let template = Template::try_compiled_from_str("hello {{node.chunk}}")
            .await
            .unwrap();
        let node = Node::new("test");
        let prompt = template.to_prompt().with_node(&node);
        assert_eq!(prompt.render().await.unwrap(), "hello test");
    }

    #[tokio::test]
    async fn test_one_off_from_string() {
        let mut prompt: Prompt = "hello {{world}}".into();
        prompt = prompt.with_context_value("world", "swiftide");

        assert_eq!(prompt.render().await.unwrap(), "hello swiftide");
    }

    #[tokio::test]
    async fn test_extending_with_custom_repository() {
        // An empty Tera instance is all this test needs. `Tera::new` with a `**` glob would
        // walk the working directory, making the test slow and filesystem-dependent.
        let mut custom_tera = tera::Tera::default();

        custom_tera
            .add_raw_template("hello", "hello {{world}}")
            .unwrap();

        Template::extend(&custom_tera).await.unwrap();

        let prompt = Template::from_compiled_template_name("hello")
            .to_prompt()
            .with_context_value("world", "swiftide");

        assert_eq!(prompt.render().await.unwrap(), "hello swiftide");
    }

    #[tokio::test]
    async fn test_coercion_to_prompt() {
        let raw: &str = "hello {{world}}";

        let prompt: Prompt = raw.into();
        assert_eq!(
            prompt
                .with_context_value("world", "swiftide")
                .render()
                .await
                .unwrap(),
            "hello swiftide"
        );

        let prompt: Prompt = raw.to_string().into();
        assert_eq!(
            prompt
                .with_context_value("world", "swiftide")
                .render()
                .await
                .unwrap(),
            "hello swiftide"
        );
    }

    #[tokio::test]
    async fn test_coercion_to_template() {
        let raw: &str = "hello {{world}}";

        let prompt: Template = raw.into();
        assert_eq!(
            prompt
                .to_prompt()
                .with_context_value("world", "swiftide")
                .render()
                .await
                .unwrap(),
            "hello swiftide"
        );

        let prompt: Template = raw.to_string().into();
        assert_eq!(
            prompt
                .to_prompt()
                .with_context_value("world", "swiftide")
                .render()
                .await
                .unwrap(),
            "hello swiftide"
        );
    }

    #[tokio::test]
    async fn test_assume_rendered_unless_context_methods_called() {
        let prompt = Prompt::from("hello {{world}}");

        assert_eq!(prompt.render().await.unwrap(), "hello {{world}}");
    }

    #[tokio::test]
    async fn test_owned_string_returned_verbatim_without_context() {
        let prompt = Prompt::from("hello {{world}}".to_string());

        assert_eq!(prompt.render().await.unwrap(), "hello {{world}}");
    }

    #[tokio::test]
    async fn test_with_context_merges_into_existing_context() {
        let mut extra = tera::Context::new();
        extra.insert("world", "swiftide");

        let prompt = Prompt::from("{{greeting}} {{world}}")
            .with_context_value("greeting", "hello")
            .with_context(extra);

        assert_eq!(prompt.render().await.unwrap(), "hello swiftide");
    }
}