swiftide_core/
prompt.rs

1//! Prompts templating and management
2//!
3//! Prompts are first class citizens in Swiftide and use [tera] under the hood. tera
4//! uses jinja style templates which allows for a lot of flexibility.
5//!
//! Conceptually, a [`Prompt`] is something you send to a large language model, for example
//! through [`SimplePrompt`][crate::SimplePrompt]. A prompt can carry context for variable
//! substitution and other templating features.
9//!
10//! Transformers in Swiftide come with default prompts, and they can be customized or replaced as
11//! needed.
12//!
//! Templates can be precompiled by adding them to the shared Swiftide [`Tera`] repository with
//! [`Prompt::extend`], and are then referenced by name with [`Prompt::from_compiled_template`].
//! Prompts can also be created on the fly from a `&'static str` or a [`String`]; such one-off
//! templates are compiled every time they are rendered.
//!
//! A one-off prompt without any added context is returned verbatim by [`Prompt::render`],
//! without any template processing.
//!
//! It's recommended to precompile your templates.
21//!
22//! # Example
23//!
//! ```
//! use swiftide_core::prompt::Prompt;
//!
//! let prompt = Prompt::from("hello {{world}}").with_context_value("world", "swiftide");
//!
//! assert_eq!(prompt.render().unwrap(), "hello swiftide");
//! ```
33use std::{
34    borrow::Cow,
35    sync::{LazyLock, RwLock},
36};
37
38use anyhow::{Context as _, Result};
39use tera::Tera;
40
41use crate::node::Node;
42
43/// A Prompt can be used with large language models to prompt.
44#[derive(Clone, Debug)]
45pub struct Prompt {
46    template_ref: TemplateRef,
47    context: Option<tera::Context>,
48}
49
/// Identifies the template a [`Prompt`] renders.
#[derive(Debug, Clone)]
enum TemplateRef {
    /// Inline template source, compiled on every render via `Tera::one_off`.
    OneOff(Cow<'static, str>),
    /// Name of a template registered in the shared `SWIFTIDE_TERA` repository.
    Tera(Cow<'static, str>),
}
57
58pub static SWIFTIDE_TERA: LazyLock<RwLock<Tera>> = LazyLock::new(|| RwLock::new(Tera::default()));
59
60impl Prompt {
61    /// Extend the swiftide repository with another Tera instance.
62    ///
63    /// You can use this to add your own templates, functions and partials.
64    ///
65    /// # Panics
66    ///
67    /// Panics if the `RWLock` is poisoned.
68    ///
69    /// # Errors
70    ///
71    /// Errors if the `Tera` instance cannot be extended.
72    pub fn extend(other: &Tera) -> Result<()> {
73        let mut swiftide_tera = SWIFTIDE_TERA.write().unwrap();
74        swiftide_tera.extend(other)?;
75        Ok(())
76    }
77
78    /// Create a new prompt from a compiled template that is present in the Tera repository
79    pub fn from_compiled_template(name: impl Into<Cow<'static, str>>) -> Prompt {
80        Prompt {
81            template_ref: TemplateRef::Tera(name.into()),
82            context: None,
83        }
84    }
85
86    /// Adds an `ingestion::Node` to the context of the Prompt
87    #[must_use]
88    pub fn with_node(mut self, node: &Node) -> Self {
89        let context = self.context.get_or_insert_with(tera::Context::default);
90        context.insert("node", &node);
91        self
92    }
93
94    /// Adds anything that implements [Into<tera::Context>], like `Serialize` to the Prompt
95    #[must_use]
96    pub fn with_context(mut self, new_context: impl Into<tera::Context>) -> Self {
97        let context = self.context.get_or_insert_with(tera::Context::default);
98        context.extend(new_context.into());
99
100        self
101    }
102
103    /// Adds a key-value pair to the context of the Prompt
104    #[must_use]
105    pub fn with_context_value(mut self, key: &str, value: impl Into<tera::Value>) -> Self {
106        let context = self.context.get_or_insert_with(tera::Context::default);
107        context.insert(key, &value.into());
108        self
109    }
110
111    /// Renders a prompt
112    ///
113    /// If no context is provided, the prompt will be rendered as is.
114    ///
115    /// # Errors
116    ///
117    /// See `Template::render`
118    ///
119    /// # Panics
120    ///
121    /// Panics if the `RWLock` is poisoned.
122    pub fn render(&self) -> Result<String> {
123        if self.context.is_none()
124            && let TemplateRef::OneOff(ref template) = self.template_ref
125        {
126            return Ok(template.to_string());
127        }
128
129        let context: Cow<'_, tera::Context> = self
130            .context
131            .as_ref()
132            .map_or_else(|| Cow::Owned(tera::Context::default()), Cow::Borrowed);
133
134        match &self.template_ref {
135            TemplateRef::OneOff(template) => {
136                tera::Tera::one_off(template.as_ref(), &context, false)
137                    .context("Failed to render one-off template")
138            }
139            TemplateRef::Tera(template) => SWIFTIDE_TERA
140                .read()
141                .unwrap()
142                .render(template.as_ref(), &context)
143                .context("Failed to render template"),
144        }
145    }
146}
147
148impl From<&'static str> for Prompt {
149    fn from(prompt: &'static str) -> Self {
150        Prompt {
151            template_ref: TemplateRef::OneOff(prompt.into()),
152            context: None,
153        }
154    }
155}
156
157impl From<String> for Prompt {
158    fn from(prompt: String) -> Self {
159        Prompt {
160            template_ref: TemplateRef::OneOff(prompt.into()),
161            context: None,
162        }
163    }
164}
165
#[cfg(test)]
mod test {
    use super::*;

    // Rendering is fully synchronous, so plain `#[test]` functions are used; no async runtime
    // is required.

    #[test]
    fn test_prompt() {
        let prompt: Prompt = "hello {{world}}".into();
        let prompt = prompt.with_context_value("world", "swiftide");
        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_prompt_with_node() {
        let prompt: Prompt = "hello {{node.chunk}}".into();
        let node = Node::new("test");
        let prompt = prompt.with_node(&node);
        assert_eq!(prompt.render().unwrap(), "hello test");
    }

    #[test]
    fn test_one_off_from_string() {
        let mut prompt: Prompt = String::from("hello {{world}}").into();
        prompt = prompt.with_context_value("world", "swiftide");

        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_with_context_merges_values() {
        let mut context = tera::Context::new();
        context.insert("world", "swiftide");

        let prompt = Prompt::from("hello {{world}} from {{who}}")
            .with_context_value("who", "me")
            .with_context(context);

        assert_eq!(prompt.render().unwrap(), "hello swiftide from me");
    }

    #[test]
    fn test_extending_with_custom_repository() {
        // Build the repository in memory; a glob such as `Tera::new("**/...")` would scan the
        // file system relative to the working directory for no benefit.
        let mut custom_tera = tera::Tera::default();

        custom_tera
            .add_raw_template("hello", "hello {{world}}")
            .unwrap();

        Prompt::extend(&custom_tera).unwrap();

        let prompt =
            Prompt::from_compiled_template("hello").with_context_value("world", "swiftide");

        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_compiled_template_is_rendered_without_context() {
        let mut custom_tera = tera::Tera::default();

        // The Tera comment is stripped only if the template is actually rendered.
        custom_tera
            .add_raw_template("greeting_without_context", "hello{# stripped #} swiftide")
            .unwrap();

        Prompt::extend(&custom_tera).unwrap();

        let prompt = Prompt::from_compiled_template("greeting_without_context");

        assert_eq!(prompt.render().unwrap(), "hello swiftide");
    }

    #[test]
    fn test_missing_compiled_template_errors() {
        let prompt = Prompt::from_compiled_template("this-template-does-not-exist");

        assert!(prompt.render().is_err());
    }

    #[test]
    fn test_coercion_to_prompt() {
        // str
        let raw: &str = "hello {{world}}";

        let prompt: Prompt = raw.into();
        assert_eq!(
            prompt
                .with_context_value("world", "swiftide")
                .render()
                .unwrap(),
            "hello swiftide"
        );

        // String
        let prompt: Prompt = raw.to_string().into();
        assert_eq!(
            prompt
                .with_context_value("world", "swiftide")
                .render()
                .unwrap(),
            "hello swiftide"
        );
    }

    #[test]
    fn test_assume_rendered_unless_context_methods_called() {
        let prompt = Prompt::from("hello {{world}}");

        assert_eq!(prompt.render().unwrap(), "hello {{world}}");
    }
}