swiftide_core/
prompt.rs

1//! Prompts templating and management
2//!
//! Prompts are first-class citizens in Swiftide and use [tera] under the hood. Tera
//! uses Jinja-style templates, which allow for a lot of flexibility.
5//!
//! Conceptually, a [`Prompt`] is something you send to, e.g.,
//! [`SimplePrompt`][crate::SimplePrompt]. A prompt can have context
//! added for variable substitution and other templating features.
9//!
10//! Transformers in Swiftide come with default prompts, and they can be customized or replaced as
11//! needed.
12//!
//! [`Template`][crate::template::Template] can be added with
//! [`Template::try_compiled_from_str`][crate::template::Template::try_compiled_from_str].
//! Prompts can also be created on the fly from a `&str` or a `String`. Compiled templates are
//! stored in an internal repository, [`SWIFTIDE_TERA`].
16//!
17//! Additionally, `Template::String` and `Template::Static` can be used to create
18//! templates on the fly as well.
19//!
20//! It's recommended to precompile your templates.
21//!
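//! # Example
//!
//! ```
//! # use swiftide_core::prompt::Prompt;
//! // A one-off template, rendered on the fly.
//! let prompt = Prompt::from("hello {{world}}").with_context_value("world", "swiftide");
//! assert_eq!(prompt.render().unwrap(), "hello swiftide");
//!
//! // A precompiled template, registered once in the shared repository and referenced by name.
//! let mut tera = tera::Tera::default();
//! tera.add_raw_template("greeting", "hello {{world}}").unwrap();
//! Prompt::extend(&tera).unwrap();
//!
//! let prompt =
//!     Prompt::from_compiled_template("greeting").with_context_value("world", "swiftide");
//! assert_eq!(prompt.render().unwrap(), "hello swiftide");
//! ```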
34use std::{
35    borrow::Cow,
36    sync::{LazyLock, RwLock},
37};
38
39use anyhow::{Context as _, Result};
40use tera::Tera;
41
42use crate::node::Node;
43
44/// A Prompt can be used with large language models to prompt.
45#[derive(Clone, Debug)]
46pub struct Prompt {
47    template_ref: TemplateRef,
48    context: Option<tera::Context>,
49}
50
/// Identifies the template a [`Prompt`] renders.
#[derive(Debug, Clone)]
enum TemplateRef {
    /// The template source itself, rendered on the fly with `Tera::one_off`.
    OneOff(String),
    /// The name of a template registered in [`SWIFTIDE_TERA`].
    Tera(String),
}
58
59pub static SWIFTIDE_TERA: LazyLock<RwLock<Tera>> = LazyLock::new(|| RwLock::new(Tera::default()));
60
61impl Prompt {
62    /// Extend the swiftide repository with another Tera instance.
63    ///
64    /// You can use this to add your own templates, functions and partials.
65    ///
66    /// # Panics
67    ///
68    /// Panics if the `RWLock` is poisoned.
69    ///
70    /// # Errors
71    ///
72    /// Errors if the `Tera` instance cannot be extended.
73    pub fn extend(other: &Tera) -> Result<()> {
74        let mut swiftide_tera = SWIFTIDE_TERA.write().unwrap();
75        swiftide_tera.extend(other)?;
76        Ok(())
77    }
78
79    /// Create a new prompt from a compiled template that is present in the Tera repository
80    pub fn from_compiled_template(name: impl Into<String>) -> Prompt {
81        Prompt {
82            template_ref: TemplateRef::Tera(name.into()),
83            context: None,
84        }
85    }
86
87    /// Adds an `ingestion::Node` to the context of the Prompt
88    #[must_use]
89    pub fn with_node(mut self, node: &Node) -> Self {
90        let context = self.context.get_or_insert_with(tera::Context::default);
91        context.insert("node", &node);
92        self
93    }
94
95    /// Adds anything that implements [Into<tera::Context>], like `Serialize` to the Prompt
96    #[must_use]
97    pub fn with_context(mut self, new_context: impl Into<tera::Context>) -> Self {
98        let context = self.context.get_or_insert_with(tera::Context::default);
99        context.extend(new_context.into());
100
101        self
102    }
103
104    /// Adds a key-value pair to the context of the Prompt
105    #[must_use]
106    pub fn with_context_value(mut self, key: &str, value: impl Into<tera::Value>) -> Self {
107        let context = self.context.get_or_insert_with(tera::Context::default);
108        context.insert(key, &value.into());
109        self
110    }
111
112    /// Renders a prompt
113    ///
114    /// If no context is provided, the prompt will be rendered as is.
115    ///
116    /// # Errors
117    ///
118    /// See `Template::render`
119    ///
120    /// # Panics
121    ///
122    /// Panics if the `RWLock` is poisoned.
123    pub fn render(&self) -> Result<String> {
124        if self.context.is_none() {
125            if let TemplateRef::OneOff(ref template) = self.template_ref {
126                return Ok(template.to_string());
127            }
128        }
129
130        let context: Cow<'_, tera::Context> = self
131            .context
132            .as_ref()
133            .map_or_else(|| Cow::Owned(tera::Context::default()), Cow::Borrowed);
134
135        match &self.template_ref {
136            TemplateRef::OneOff(template) => {
137                tera::Tera::one_off(template.as_ref(), &context, false)
138                    .context("Failed to render one-off template")
139            }
140            TemplateRef::Tera(template) => SWIFTIDE_TERA
141                .read()
142                .unwrap()
143                .render(template.as_ref(), &context)
144                .context("Failed to render template"),
145        }
146    }
147}
148
149impl From<&str> for Prompt {
150    fn from(prompt: &str) -> Self {
151        Prompt {
152            template_ref: TemplateRef::OneOff(prompt.into()),
153            context: None,
154        }
155    }
156}
157
158impl From<String> for Prompt {
159    fn from(prompt: String) -> Self {
160        Prompt {
161            template_ref: TemplateRef::OneOff(prompt),
162            context: None,
163        }
164    }
165}
166
#[cfg(test)]
mod test {
    use super::*;

    // `Prompt::render` is synchronous, so plain `#[test]`s suffice; no async runtime is needed.

    #[test]
    fn test_prompt() {
        let prompt: Prompt = "hello {{world}}".into();
        let prompt = prompt.with_context_value("world", "swiftide");
        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_prompt_with_node() {
        let prompt: Prompt = "hello {{node.chunk}}".into();
        let node = Node::new("test");
        let prompt = prompt.with_node(&node);
        assert_eq!(prompt.render().unwrap(), "hello test");
    }

    #[test]
    fn test_one_off_from_string() {
        let mut prompt: Prompt = String::from("hello {{world}}").into();
        prompt = prompt.with_context_value("world", "swiftide");

        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_extending_with_custom_repository() {
        // Build the custom repository in memory. A recursive glob such as
        // `**/some/prompts.md` makes tera scan the filesystem, which slows the test down and
        // ties it to the working directory.
        let mut custom_tera = tera::Tera::default();

        custom_tera
            .add_raw_template("hello", "hello {{world}}")
            .unwrap();

        Prompt::extend(&custom_tera).unwrap();

        let prompt =
            Prompt::from_compiled_template("hello").with_context_value("world", "swiftide");

        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_compiled_template_renders_without_context() {
        // Unique name: SWIFTIDE_TERA is shared by every test in the process.
        let mut custom_tera = tera::Tera::default();

        custom_tera
            .add_raw_template("test_compiled_without_context", "hello swiftide")
            .unwrap();

        Prompt::extend(&custom_tera).unwrap();

        let prompt = Prompt::from_compiled_template("test_compiled_without_context");

        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_coercion_to_prompt() {
        // str
        let raw: &str = "hello {{world}}";

        let prompt: Prompt = raw.into();
        assert_eq!(
            prompt
                .with_context_value("world", "swiftide")
                .render()
                .unwrap(),
            "hello swiftide"
        );

        // String
        let prompt: Prompt = raw.to_string().into();
        assert_eq!(
            prompt
                .with_context_value("world", "swiftide")
                .render()
                .unwrap(),
            "hello swiftide"
        );
    }

    #[test]
    fn test_assume_rendered_unless_context_methods_called() {
        let prompt = Prompt::from("hello {{world}}");

        assert_eq!(prompt.render().unwrap(), "hello {{world}}");
    }
}