swiftide_core/
prompt.rs

1//! Prompts templating and management
2//!
3//! Prompts are first class citizens in Swiftide and use [tera] under the hood. tera
4//! uses jinja style templates which allows for a lot of flexibility.
5//!
//! Conceptually, a [Prompt] is something you send to, e.g.,
//! [`SimplePrompt`][crate::SimplePrompt]. A prompt can have
//! context added to it for variable substitution and other templating features.
9//!
10//! Transformers in Swiftide come with default prompts, and they can be customized or replaced as
11//! needed.
12//!
//! A [`Template`] can be compiled with [`Template::try_compiled_from_str`]. Prompts can also be
//! created on the fly from a `String` or a `&'static str` (via [`From`]/[`Into`]). Compiled
//! templates are stored in an internal repository.
16//!
17//! Additionally, `Template::String` and `Template::Static` can be used to create
18//! templates on the fly as well.
19//!
20//! It's recommended to precompile your templates.
21//!
22//! # Example
23//!
24//! ```
//! # #[tokio::main]
26//! # async fn main() {
27//! # use swiftide_core::template::Template;
28//! let template = Template::try_compiled_from_str("hello {{world}}").await.unwrap();
29//! let prompt = template.to_prompt().with_context_value("world", "swiftide");
30//!
31//! assert_eq!(prompt.render().await.unwrap(), "hello swiftide");
32//! # }
33//! ```
34use anyhow::Result;
35
36use crate::{node::Node, template::Template};
37
38/// A Prompt can be used with large language models to prompt.
39#[derive(Clone, Debug)]
40pub struct Prompt {
41    template: Template,
42    context: Option<tera::Context>,
43}
44
45#[deprecated(
46    since = "0.16.0",
47    note = "Use `Template` instead; they serve a more general purpose"
48)]
49pub type PromptTemplate = Template;
50
51impl Prompt {
52    /// Adds an `ingestion::Node` to the context of the Prompt
53    #[must_use]
54    pub fn with_node(mut self, node: &Node) -> Self {
55        let context = self.context.get_or_insert_with(tera::Context::default);
56        context.insert("node", &node);
57        self
58    }
59
60    /// Adds anything that implements [Into<tera::Context>], like `Serialize` to the Prompt
61    #[must_use]
62    pub fn with_context(mut self, new_context: impl Into<tera::Context>) -> Self {
63        let context = self.context.get_or_insert_with(tera::Context::default);
64        context.extend(new_context.into());
65
66        self
67    }
68
69    /// Adds a key-value pair to the context of the Prompt
70    #[must_use]
71    pub fn with_context_value(mut self, key: &str, value: impl Into<tera::Value>) -> Self {
72        let context = self.context.get_or_insert_with(tera::Context::default);
73        context.insert(key, &value.into());
74        self
75    }
76
77    /// Renders a prompt
78    ///
79    /// If no context is provided, the prompt will be rendered as is.
80    ///
81    /// # Errors
82    ///
83    /// See `Template::render`
84    pub async fn render(&self) -> Result<String> {
85        if let Some(context) = &self.context {
86            self.template.render(context).await
87        } else {
88            match &self.template {
89                Template::CompiledTemplate(_) => {
90                    self.template.render(&tera::Context::default()).await
91                }
92                Template::String(string) => Ok(string.clone()),
93                Template::Static(string) => Ok((*string).to_string()),
94            }
95        }
96    }
97}
98
99impl From<&'static str> for Prompt {
100    fn from(prompt: &'static str) -> Self {
101        Prompt {
102            template: Template::Static(prompt),
103            context: None,
104        }
105    }
106}
107
108impl From<String> for Prompt {
109    fn from(prompt: String) -> Self {
110        Prompt {
111            template: Template::String(prompt),
112            context: None,
113        }
114    }
115}
116
117impl From<&Template> for Prompt {
118    fn from(template: &Template) -> Self {
119        Prompt {
120            template: template.clone(),
121            context: None,
122        }
123    }
124}
125
#[cfg(test)]
mod test {
    use super::*;

    /// Template source shared by most tests below.
    const HELLO_WORLD: &str = "hello {{world}}";
    /// Expected output once `world` is bound to `"swiftide"`.
    const HELLO_SWIFTIDE: &str = "hello swiftide";

    #[tokio::test]
    async fn test_prompt() {
        let prompt = Template::try_compiled_from_str(HELLO_WORLD)
            .await
            .unwrap()
            .to_prompt()
            .with_context_value("world", "swiftide");

        assert_eq!(prompt.render().await.unwrap(), HELLO_SWIFTIDE);
    }

    #[tokio::test]
    async fn test_prompt_with_node() {
        let node = Node::new("test");
        let prompt = Template::try_compiled_from_str("hello {{node.chunk}}")
            .await
            .unwrap()
            .to_prompt()
            .with_node(&node);

        assert_eq!(prompt.render().await.unwrap(), "hello test");
    }

    #[tokio::test]
    async fn test_one_off_from_string() {
        let prompt = Prompt::from(HELLO_WORLD).with_context_value("world", "swiftide");

        assert_eq!(prompt.render().await.unwrap(), HELLO_SWIFTIDE);
    }

    #[tokio::test]
    async fn test_extending_with_custom_repository() {
        let mut custom_tera = tera::Tera::new("**/some/prompts.md").unwrap();
        custom_tera.add_raw_template("hello", HELLO_WORLD).unwrap();

        Template::extend(&custom_tera).await.unwrap();

        let prompt = Template::from_compiled_template_name("hello")
            .to_prompt()
            .with_context_value("world", "swiftide");

        assert_eq!(prompt.render().await.unwrap(), HELLO_SWIFTIDE);
    }

    #[tokio::test]
    async fn test_coercion_to_prompt() {
        // Both the `&str` and the `String` conversions must render identically.
        let from_str: Prompt = HELLO_WORLD.into();
        let from_string: Prompt = HELLO_WORLD.to_string().into();

        for prompt in [from_str, from_string] {
            let rendered = prompt
                .with_context_value("world", "swiftide")
                .render()
                .await
                .unwrap();
            assert_eq!(rendered, HELLO_SWIFTIDE);
        }
    }

    #[tokio::test]
    async fn test_coercion_to_template() {
        // Same as above, but going through `Template` first.
        let from_str: Template = HELLO_WORLD.into();
        let from_string: Template = HELLO_WORLD.to_string().into();

        for template in [from_str, from_string] {
            let rendered = template
                .to_prompt()
                .with_context_value("world", "swiftide")
                .render()
                .await
                .unwrap();
            assert_eq!(rendered, HELLO_SWIFTIDE);
        }
    }

    #[tokio::test]
    async fn test_assume_rendered_unless_context_methods_called() {
        // Without any context, a plain string prompt is returned untouched.
        let prompt = Prompt::from(HELLO_WORLD);

        assert_eq!(prompt.render().await.unwrap(), HELLO_WORLD);
    }
}
233}