Skip to main content

synaptic_prompts/
template.rs

1use std::collections::HashMap;
2
3use thiserror::Error;
4
5#[derive(Debug, Error, PartialEq, Eq)]
6pub enum PromptError {
7    #[error("missing variable: {0}")]
8    MissingVariable(String),
9}
10
/// A text template containing `{{name}}` placeholders, together with
/// optional default values for some of those placeholders.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PromptTemplate {
    /// Raw template text that is scanned for `{{ ... }}` placeholders.
    template: String,
    /// Default values consulted at render time when the caller does not
    /// supply a value for a placeholder.
    partial_variables: HashMap<String, String>,
}
16
17impl PromptTemplate {
18    pub fn new(template: impl Into<String>) -> Self {
19        Self {
20            template: template.into(),
21            partial_variables: HashMap::new(),
22        }
23    }
24
25    /// Set a partial variable that will be used as a default during rendering.
26    /// Provided values at render time take precedence over partial variables.
27    pub fn with_partial(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
28        self.partial_variables.insert(key.into(), value.into());
29        self
30    }
31
32    pub fn render(&self, values: &HashMap<String, String>) -> Result<String, PromptError> {
33        // Merge partial_variables with provided values; provided values take precedence
34        let mut merged = self.partial_variables.clone();
35        for (k, v) in values {
36            merged.insert(k.clone(), v.clone());
37        }
38
39        let mut output = String::with_capacity(self.template.len());
40        let mut rest = self.template.as_str();
41
42        while let Some(start) = rest.find("{{") {
43            output.push_str(&rest[..start]);
44            let after_start = &rest[start + 2..];
45            if let Some(end) = after_start.find("}}") {
46                let key = after_start[..end].trim();
47                let value = merged
48                    .get(key)
49                    .ok_or_else(|| PromptError::MissingVariable(key.to_string()))?;
50                output.push_str(value);
51                rest = &after_start[end + 2..];
52            } else {
53                output.push_str(&rest[start..]);
54                rest = "";
55                break;
56            }
57        }
58
59        output.push_str(rest);
60        Ok(output)
61    }
62}