synwire_core/prompts/
template.rs1use std::collections::HashMap;
4
5use crate::error::SynwireError;
6use crate::prompts::PromptValue;
7
/// Placeholder syntax understood by a [`PromptTemplate`].
///
/// Marked `#[non_exhaustive]` so more syntaxes can be added later without a breaking
/// change; `match`es outside this crate must include a wildcard arm.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum TemplateFormat {
    /// `{name}`-style placeholders (named after Python f-strings). This is the default.
    #[default]
    FString,
}
16
17#[derive(Debug, Clone)]
31pub struct PromptTemplate {
32 template: String,
33 input_variables: Vec<String>,
34 _template_format: TemplateFormat,
35}
36
37impl PromptTemplate {
38 pub fn new(template: impl Into<String>, input_variables: Vec<String>) -> Self {
40 Self {
41 template: template.into(),
42 input_variables,
43 _template_format: TemplateFormat::default(),
44 }
45 }
46
47 pub fn input_variables(&self) -> &[String] {
49 &self.input_variables
50 }
51
52 pub fn format(&self, variables: &HashMap<String, String>) -> Result<String, SynwireError> {
58 let mut result = self.template.clone();
59 for var in &self.input_variables {
60 let value = variables.get(var).ok_or_else(|| SynwireError::Prompt {
61 message: format!("missing required variable '{var}'"),
62 })?;
63 result = result.replace(&format!("{{{var}}}"), value);
64 }
65 Ok(result)
66 }
67
68 pub fn to_prompt_value(
74 &self,
75 variables: &HashMap<String, String>,
76 ) -> Result<PromptValue, SynwireError> {
77 let text = self.format(variables)?;
78 Ok(PromptValue::String(text))
79 }
80}
81
#[cfg(test)]
// Tests unwrap deliberately: an unexpected `Err` should fail the test with a panic.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a variable map from `(name, value)` pairs.
    fn vars(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_owned(), value.to_owned()))
            .collect()
    }

    #[test]
    fn test_prompt_template_format() {
        let tpl = PromptTemplate::new("Hello {name}", vec!["name".into()]);
        let rendered = tpl.format(&vars(&[("name", "World")])).unwrap();
        assert_eq!(rendered, "Hello World");
    }

    #[test]
    fn test_prompt_template_format_multiple_vars() {
        let declared = vec!["name".into(), "age".into()];
        let tpl = PromptTemplate::new("Hello {name}, you are {age}", declared);
        let rendered = tpl
            .format(&vars(&[("name", "Alice"), ("age", "30")]))
            .unwrap();
        assert_eq!(rendered, "Hello Alice, you are 30");
    }

    #[test]
    fn test_prompt_template_missing_variable() {
        let tpl = PromptTemplate::new("Hello {name}", vec!["name".into()]);
        let msg = tpl.format(&vars(&[])).unwrap_err().to_string();
        assert!(
            msg.contains("name"),
            "error should mention the missing variable, got: {msg}"
        );
    }

    #[test]
    fn test_prompt_template_to_prompt_value() {
        let tpl = PromptTemplate::new("Hello {name}", vec!["name".into()]);
        let value = tpl.to_prompt_value(&vars(&[("name", "World")])).unwrap();
        assert_eq!(value.to_text(), "Hello World");
    }

    #[test]
    fn test_input_variables_getter() {
        let tpl = PromptTemplate::new("Hi {a} {b}", vec!["a".into(), "b".into()]);
        let expected = ["a", "b"];
        assert_eq!(tpl.input_variables(), &expected);
    }
}