1use std::{
35 borrow::Cow,
36 sync::{LazyLock, RwLock},
37};
38
39use anyhow::{Context as _, Result};
40use tera::Tera;
41
42use crate::node::Node;
43
44#[derive(Clone, Debug)]
46pub struct Prompt {
47 template_ref: TemplateRef,
48 context: Option<tera::Context>,
49}
50
/// Identifies the template that a [`Prompt`] renders.
#[derive(Debug, Clone)]
enum TemplateRef {
    /// The template source itself, rendered with `Tera::one_off`.
    OneOff(String),
    /// The name of a template registered in [`SWIFTIDE_TERA`].
    Tera(String),
}
58
59pub static SWIFTIDE_TERA: LazyLock<RwLock<Tera>> = LazyLock::new(|| RwLock::new(Tera::default()));
60
61impl Prompt {
62 pub fn extend(other: &Tera) -> Result<()> {
74 let mut swiftide_tera = SWIFTIDE_TERA.write().unwrap();
75 swiftide_tera.extend(other)?;
76 Ok(())
77 }
78
79 pub fn from_compiled_template(name: impl Into<String>) -> Prompt {
81 Prompt {
82 template_ref: TemplateRef::Tera(name.into()),
83 context: None,
84 }
85 }
86
87 #[must_use]
89 pub fn with_node(mut self, node: &Node) -> Self {
90 let context = self.context.get_or_insert_with(tera::Context::default);
91 context.insert("node", &node);
92 self
93 }
94
95 #[must_use]
97 pub fn with_context(mut self, new_context: impl Into<tera::Context>) -> Self {
98 let context = self.context.get_or_insert_with(tera::Context::default);
99 context.extend(new_context.into());
100
101 self
102 }
103
104 #[must_use]
106 pub fn with_context_value(mut self, key: &str, value: impl Into<tera::Value>) -> Self {
107 let context = self.context.get_or_insert_with(tera::Context::default);
108 context.insert(key, &value.into());
109 self
110 }
111
112 pub fn render(&self) -> Result<String> {
124 if self.context.is_none() {
125 if let TemplateRef::OneOff(ref template) = self.template_ref {
126 return Ok(template.to_string());
127 }
128 }
129
130 let context: Cow<'_, tera::Context> = self
131 .context
132 .as_ref()
133 .map_or_else(|| Cow::Owned(tera::Context::default()), Cow::Borrowed);
134
135 match &self.template_ref {
136 TemplateRef::OneOff(template) => {
137 tera::Tera::one_off(template.as_ref(), &context, false)
138 .context("Failed to render one-off template")
139 }
140 TemplateRef::Tera(template) => SWIFTIDE_TERA
141 .read()
142 .unwrap()
143 .render(template.as_ref(), &context)
144 .context("Failed to render template"),
145 }
146 }
147}
148
149impl From<&str> for Prompt {
150 fn from(prompt: &str) -> Self {
151 Prompt {
152 template_ref: TemplateRef::OneOff(prompt.into()),
153 context: None,
154 }
155 }
156}
157
158impl From<String> for Prompt {
159 fn from(prompt: String) -> Self {
160 Prompt {
161 template_ref: TemplateRef::OneOff(prompt),
162 context: None,
163 }
164 }
165}
166
#[cfg(test)]
mod test {
    use super::*;

    // None of these tests await anything, so plain `#[test]` is enough; no async runtime
    // is needed.

    #[test]
    fn test_prompt() {
        let prompt: Prompt = "hello {{world}}".into();
        let prompt = prompt.with_context_value("world", "swiftide");
        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_prompt_with_node() {
        let prompt: Prompt = "hello {{node.chunk}}".into();
        let node = Node::new("test");
        let prompt = prompt.with_node(&node);
        assert_eq!(prompt.render().unwrap(), "hello test");
    }

    #[test]
    fn test_one_off_from_string() {
        let mut prompt: Prompt = "hello {{world}}".into();
        prompt = prompt.with_context_value("world", "swiftide");

        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_with_context() {
        let mut context = tera::Context::new();
        context.insert("world", "swiftide");

        let prompt = Prompt::from("hello {{world}}").with_context(context);

        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_extending_with_custom_repository() {
        // Start from an empty repository; globbing (e.g. `**/...`) would walk the whole
        // working-directory tree for no benefit.
        let mut custom_tera = tera::Tera::default();

        custom_tera
            .add_raw_template("hello", "hello {{world}}")
            .unwrap();

        Prompt::extend(&custom_tera).unwrap();

        let prompt =
            Prompt::from_compiled_template("hello").with_context_value("world", "swiftide");

        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_unknown_compiled_template_is_an_error() {
        let prompt = Prompt::from_compiled_template("this-template-is-not-registered");

        assert!(prompt.render().is_err());
    }

    #[test]
    fn test_coercion_to_prompt() {
        let raw: &str = "hello {{world}}";

        let prompt: Prompt = raw.into();
        assert_eq!(
            prompt
                .with_context_value("world", "swiftide")
                .render()
                .unwrap(),
            "hello swiftide"
        );

        let prompt: Prompt = raw.to_string().into();
        assert_eq!(
            prompt
                .with_context_value("world", "swiftide")
                .render()
                .unwrap(),
            "hello swiftide"
        );
    }

    #[test]
    fn test_assume_rendered_unless_context_methods_called() {
        let prompt = Prompt::from("hello {{world}}");

        assert_eq!(prompt.render().unwrap(), "hello {{world}}");
    }
}