1use std::{
34 borrow::Cow,
35 sync::{LazyLock, RwLock},
36};
37
38use anyhow::{Context as _, Result};
39use tera::Tera;
40
41use crate::node::Node;
42
43#[derive(Clone, Debug)]
45pub struct Prompt {
46 template_ref: TemplateRef,
47 context: Option<tera::Context>,
48}
49
/// Source of the template text used by a [`Prompt`].
#[derive(Debug, Clone)]
enum TemplateRef {
    /// Inline template source, rendered on demand with `Tera::one_off`.
    OneOff(Cow<'static, str>),
    /// Name of a template that is looked up in [`SWIFTIDE_TERA`] at render time.
    Tera(Cow<'static, str>),
}
57
58pub static SWIFTIDE_TERA: LazyLock<RwLock<Tera>> = LazyLock::new(|| RwLock::new(Tera::default()));
59
60impl Prompt {
61 pub fn extend(other: &Tera) -> Result<()> {
73 let mut swiftide_tera = SWIFTIDE_TERA.write().unwrap();
74 swiftide_tera.extend(other)?;
75 Ok(())
76 }
77
78 pub fn from_compiled_template(name: impl Into<Cow<'static, str>>) -> Prompt {
80 Prompt {
81 template_ref: TemplateRef::Tera(name.into()),
82 context: None,
83 }
84 }
85
86 #[must_use]
88 pub fn with_node(mut self, node: &Node) -> Self {
89 let context = self.context.get_or_insert_with(tera::Context::default);
90 context.insert("node", &node);
91 self
92 }
93
94 #[must_use]
96 pub fn with_context(mut self, new_context: impl Into<tera::Context>) -> Self {
97 let context = self.context.get_or_insert_with(tera::Context::default);
98 context.extend(new_context.into());
99
100 self
101 }
102
103 #[must_use]
105 pub fn with_context_value(mut self, key: &str, value: impl Into<tera::Value>) -> Self {
106 let context = self.context.get_or_insert_with(tera::Context::default);
107 context.insert(key, &value.into());
108 self
109 }
110
111 pub fn render(&self) -> Result<String> {
123 if self.context.is_none()
124 && let TemplateRef::OneOff(ref template) = self.template_ref
125 {
126 return Ok(template.to_string());
127 }
128
129 let context: Cow<'_, tera::Context> = self
130 .context
131 .as_ref()
132 .map_or_else(|| Cow::Owned(tera::Context::default()), Cow::Borrowed);
133
134 match &self.template_ref {
135 TemplateRef::OneOff(template) => {
136 tera::Tera::one_off(template.as_ref(), &context, false)
137 .context("Failed to render one-off template")
138 }
139 TemplateRef::Tera(template) => SWIFTIDE_TERA
140 .read()
141 .unwrap()
142 .render(template.as_ref(), &context)
143 .context("Failed to render template"),
144 }
145 }
146}
147
148impl From<&'static str> for Prompt {
149 fn from(prompt: &'static str) -> Self {
150 Prompt {
151 template_ref: TemplateRef::OneOff(prompt.into()),
152 context: None,
153 }
154 }
155}
156
157impl From<String> for Prompt {
158 fn from(prompt: String) -> Self {
159 Prompt {
160 template_ref: TemplateRef::OneOff(prompt.into()),
161 context: None,
162 }
163 }
164}
165
#[cfg(test)]
mod test {
    use super::*;

    // All prompt APIs are synchronous, so plain `#[test]` is used instead of spinning up a
    // tokio runtime per test.

    #[test]
    fn test_prompt() {
        let prompt: Prompt = "hello {{world}}".into();
        let prompt = prompt.with_context_value("world", "swiftide");
        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_prompt_with_node() {
        let prompt: Prompt = "hello {{node.chunk}}".into();
        let node = Node::new("test");
        let prompt = prompt.with_node(&node);
        assert_eq!(prompt.render().unwrap(), "hello test");
    }

    #[test]
    fn test_one_off_from_string() {
        let mut prompt: Prompt = "hello {{world}}".into();
        prompt = prompt.with_context_value("world", "swiftide");

        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_with_context() {
        let mut context = tera::Context::new();
        context.insert("world", "swiftide");

        let prompt = Prompt::from("hello {{world}}").with_context(context);
        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_extending_with_custom_repository() {
        // The template is added inline, so an empty repository is enough; globbing the
        // filesystem (e.g. `Tera::new("**/...")`) would walk the whole working directory.
        let mut custom_tera = Tera::default();

        custom_tera
            .add_raw_template("hello", "hello {{world}}")
            .unwrap();

        Prompt::extend(&custom_tera).unwrap();

        let prompt =
            Prompt::from_compiled_template("hello").with_context_value("world", "swiftide");

        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_missing_compiled_template_is_an_error() {
        let prompt = Prompt::from_compiled_template("swiftide-test-template-that-does-not-exist");

        assert!(prompt.render().is_err());
    }

    #[test]
    fn test_coercion_to_prompt() {
        let raw: &str = "hello {{world}}";

        let prompt: Prompt = raw.into();
        assert_eq!(
            prompt
                .with_context_value("world", "swiftide")
                .render()
                .unwrap(),
            "hello swiftide"
        );

        let prompt: Prompt = raw.to_string().into();
        assert_eq!(
            prompt
                .with_context_value("world", "swiftide")
                .render()
                .unwrap(),
            "hello swiftide"
        );
    }

    #[test]
    fn test_assume_rendered_unless_context_methods_called() {
        let prompt = Prompt::from("hello {{world}}");

        assert_eq!(prompt.render().unwrap(), "hello {{world}}");
    }
}