1use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11pub type Result<T> = std::result::Result<T, GuidanceError>;
13
14#[derive(Debug, thiserror::Error)]
16pub enum GuidanceError {
17 #[error("Template error: {0}")]
18 Template(String),
19
20 #[error("Variable not found: {0}")]
21 VariableNotFound(String),
22
23 #[error("Invalid format: {0}")]
24 InvalidFormat(String),
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct PromptTemplate {
30 pub name: String,
32
33 pub content: String,
35
36 #[serde(default)]
38 pub defaults: HashMap<String, String>,
39}
40
41impl PromptTemplate {
42 pub fn new(name: impl Into<String>, content: impl Into<String>) -> Self {
44 Self {
45 name: name.into(),
46 content: content.into(),
47 defaults: HashMap::new(),
48 }
49 }
50
51 pub fn with_default(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
53 self.defaults.insert(key.into(), value.into());
54 self
55 }
56
57 pub fn render(&self, variables: &HashMap<String, String>) -> Result<String> {
59 let mut result = self.content.clone();
60
61 let mut all_vars = self.defaults.clone();
63 all_vars.extend(variables.clone());
64
65 for (key, value) in all_vars {
67 let placeholder = format!("{{{{{}}}}}", key);
68 result = result.replace(&placeholder, &value);
69 }
70
71 if result.contains("{{") && result.contains("}}") {
73 return Err(GuidanceError::Template(
74 "Template contains unresolved placeholders".to_string(),
75 ));
76 }
77
78 Ok(result)
79 }
80}
81
82#[derive(Debug, Default)]
84pub struct TemplateRegistry {
85 templates: HashMap<String, PromptTemplate>,
86}
87
88impl TemplateRegistry {
89 pub fn new() -> Self {
91 Self::default()
92 }
93
94 pub fn register(&mut self, template: PromptTemplate) {
96 self.templates.insert(template.name.clone(), template);
97 }
98
99 pub fn get(&self, name: &str) -> Option<&PromptTemplate> {
101 self.templates.get(name)
102 }
103
104 pub fn render(&self, name: &str, variables: &HashMap<String, String>) -> Result<String> {
106 let template = self
107 .get(name)
108 .ok_or_else(|| GuidanceError::VariableNotFound(name.to_string()))?;
109 template.render(variables)
110 }
111
112 pub fn list(&self) -> Vec<String> {
114 self.templates.keys().cloned().collect()
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned variables map from borrowed key/value pairs.
    fn vars(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    #[test]
    fn test_template_creation() {
        let tpl = PromptTemplate::new("test", "Hello {{name}}!");
        assert_eq!(tpl.name, "test");
        assert_eq!(tpl.content, "Hello {{name}}!");
    }

    #[test]
    fn test_template_render() {
        let rendered = PromptTemplate::new("test", "Hello {{name}}!")
            .render(&vars(&[("name", "World")]))
            .unwrap();
        assert_eq!(rendered, "Hello World!");
    }

    #[test]
    fn test_template_defaults() {
        let tpl = PromptTemplate::new("test", "Hello {{name}}!").with_default("name", "Default");
        let rendered = tpl.render(&vars(&[])).unwrap();
        assert_eq!(rendered, "Hello Default!");
    }

    #[test]
    fn test_template_override_default() {
        let tpl = PromptTemplate::new("test", "Hello {{name}}!").with_default("name", "Default");
        let rendered = tpl.render(&vars(&[("name", "Custom")])).unwrap();
        assert_eq!(rendered, "Hello Custom!");
    }

    #[test]
    fn test_registry() {
        let mut reg = TemplateRegistry::new();
        reg.register(PromptTemplate::new("greeting", "Hello {{name}}!"));

        assert!(reg.get("greeting").is_some());
        assert_eq!(reg.list().len(), 1);
    }

    #[test]
    fn test_registry_render() {
        let mut reg = TemplateRegistry::new();
        reg.register(PromptTemplate::new("greeting", "Hello {{name}}!"));

        let rendered = reg
            .render("greeting", &vars(&[("name", "World")]))
            .unwrap();
        assert_eq!(rendered, "Hello World!");
    }
}