1use std::collections::HashMap;
4
5use crate::error::SynwireError;
6use crate::messages::Message;
7use crate::prompts::PromptValue;
8
/// One entry of a [`ChatPromptTemplate`].
///
/// The `System`, `Human`, and `AI` variants hold template text in which
/// `{name}` placeholders are filled from the caller's variables when the
/// prompt is formatted. `Placeholder` instead holds the *name of a variable*
/// whose value is expanded into zero or more whole messages.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum MessageTemplate {
    /// Template text for a system message.
    System(String),
    /// Template text for a human (user) message.
    Human(String),
    /// Template text for an AI (assistant) message.
    AI(String),
    /// Name of a variable expanded into messages at format time. If the
    /// variable is not supplied, nothing is emitted for this entry.
    Placeholder(String),
}
22
23#[derive(Debug, Clone)]
42pub struct ChatPromptTemplate {
43 messages: Vec<MessageTemplate>,
44 input_variables: Vec<String>,
45}
46
/// List the `{name}` placeholders in `template`, in order of appearance.
///
/// A placeholder is the text between a `{` and the next `}` that follows it.
/// Empty placeholders (`{}`) are ignored. Scanning stops at the first `{`
/// that has no closing `}`. Duplicate names are returned once per
/// occurrence; callers that need unique names must de-duplicate.
fn extract_variables(template: &str) -> Vec<String> {
    let mut found = Vec::new();
    let mut cursor = template;
    // Jump past the next `{`, then split at the first `}` after it.
    while let Some((name, tail)) = cursor
        .split_once('{')
        .and_then(|(_, after_open)| after_open.split_once('}'))
    {
        if !name.is_empty() {
            found.push(name.to_owned());
        }
        cursor = tail;
    }
    found
}
65
66fn expand_placeholder(value: &str, out: &mut Vec<Message>) {
76 if let Ok(serde_json::Value::Array(arr)) = serde_json::from_str::<serde_json::Value>(value) {
77 for item in &arr {
78 let Some(role) = item.get("role").and_then(serde_json::Value::as_str) else {
79 continue;
80 };
81 let Some(content) = item.get("content").and_then(serde_json::Value::as_str) else {
82 continue;
83 };
84 let msg = match role {
85 "system" => Message::system(content),
86 "human" | "user" => Message::human(content),
87 "ai" | "assistant" => Message::ai(content),
88 _ => continue,
89 };
90 out.push(msg);
91 }
92 } else {
93 out.push(Message::human(value));
94 }
95}
96
97fn substitute(template: &str, variables: &HashMap<String, String>) -> Result<String, SynwireError> {
99 let mut result = template.to_owned();
100 for var in &extract_variables(template) {
101 let value = variables.get(var).ok_or_else(|| SynwireError::Prompt {
102 message: format!("missing required variable '{var}'"),
103 })?;
104 result = result.replace(&format!("{{{var}}}"), value);
105 }
106 Ok(result)
107}
108
109impl ChatPromptTemplate {
110 pub fn from_messages(messages: Vec<MessageTemplate>) -> Self {
114 let mut seen = std::collections::HashSet::new();
115 let mut input_variables = Vec::new();
116 for msg in &messages {
117 let tpl = match msg {
118 MessageTemplate::System(t) | MessageTemplate::Human(t) | MessageTemplate::AI(t) => {
119 t.as_str()
120 }
121 MessageTemplate::Placeholder(_) => continue,
122 };
123 for var in extract_variables(tpl) {
124 if seen.insert(var.clone()) {
125 input_variables.push(var);
126 }
127 }
128 }
129 Self {
130 messages,
131 input_variables,
132 }
133 }
134
135 pub fn input_variables(&self) -> &[String] {
137 &self.input_variables
138 }
139
140 pub fn format_messages(
151 &self,
152 variables: &HashMap<String, String>,
153 ) -> Result<Vec<Message>, SynwireError> {
154 let mut result = Vec::with_capacity(self.messages.len());
155 for msg in &self.messages {
156 match msg {
157 MessageTemplate::System(tpl) => {
158 let text = substitute(tpl, variables)?;
159 result.push(Message::system(text));
160 }
161 MessageTemplate::Human(tpl) => {
162 let text = substitute(tpl, variables)?;
163 result.push(Message::human(text));
164 }
165 MessageTemplate::AI(tpl) => {
166 let text = substitute(tpl, variables)?;
167 result.push(Message::ai(text));
168 }
169 MessageTemplate::Placeholder(name) => {
170 if let Some(value) = variables.get(name.as_str()) {
171 expand_placeholder(value, &mut result);
172 }
173 }
174 }
175 }
176 Ok(result)
177 }
178
179 pub fn to_prompt_value(
185 &self,
186 variables: &HashMap<String, String>,
187 ) -> Result<PromptValue, SynwireError> {
188 let messages = self.format_messages(variables)?;
189 Ok(PromptValue::Messages(messages))
190 }
191}
192
#[cfg(test)]
// Tests are expected to panic on failure, so `unwrap` is acceptable here.
#[allow(clippy::unwrap_used)]
mod tests {
    //! Unit tests for variable extraction, substitution, and placeholder
    //! expansion in `ChatPromptTemplate`.

    use super::*;

    /// System and human templates are both substituted, and the message
    /// types follow the template variants.
    #[test]
    fn test_chat_prompt_template_format_messages() {
        let tpl = ChatPromptTemplate::from_messages(vec![
            MessageTemplate::System("You are {role}".into()),
            MessageTemplate::Human("{question}".into()),
        ]);
        let mut vars = HashMap::new();
        let _ = vars.insert("role".into(), "a helpful assistant".into());
        let _ = vars.insert("question".into(), "What is Rust?".into());

        let messages = tpl.format_messages(&vars).unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].message_type(), "system");
        assert_eq!(
            messages[0].content().as_text(),
            "You are a helpful assistant"
        );
        assert_eq!(messages[1].message_type(), "human");
        assert_eq!(messages[1].content().as_text(), "What is Rust?");
    }

    /// `to_prompt_value` wraps the same rendered messages.
    #[test]
    fn test_chat_prompt_template_to_prompt_value() {
        let tpl =
            ChatPromptTemplate::from_messages(vec![MessageTemplate::Human("Hello {name}".into())]);
        let mut vars = HashMap::new();
        let _ = vars.insert("name".into(), "World".into());
        let pv = tpl.to_prompt_value(&vars).unwrap();
        let messages = pv.to_messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content().as_text(), "Hello World");
    }

    /// A missing variable in a non-placeholder template is an error whose
    /// text names the variable.
    #[test]
    fn test_chat_prompt_template_missing_variable() {
        let tpl =
            ChatPromptTemplate::from_messages(vec![MessageTemplate::Human("{question}".into())]);
        let vars = HashMap::new();
        let err = tpl.format_messages(&vars).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("question"),
            "error should mention the missing variable, got: {msg}"
        );
    }

    /// Variable names are returned in order of appearance.
    #[test]
    fn test_extract_variables() {
        let vars = extract_variables("Hello {name}, you are {age} years old");
        assert_eq!(vars, vec!["name", "age"]);
    }

    /// Text without braces yields no variables.
    #[test]
    fn test_extract_variables_empty() {
        let vars = extract_variables("No variables here");
        assert!(vars.is_empty());
    }

    /// `from_messages` collects variables across all templates, in order.
    #[test]
    fn test_input_variables_auto_extracted() {
        let tpl = ChatPromptTemplate::from_messages(vec![
            MessageTemplate::System("You are {role}".into()),
            MessageTemplate::Human("{question} about {topic}".into()),
        ]);
        assert_eq!(tpl.input_variables(), &["role", "question", "topic"]);
    }

    /// A placeholder whose variable is absent contributes no messages and
    /// is not an error.
    #[test]
    fn test_placeholder_missing_variable_skipped() {
        let tpl = ChatPromptTemplate::from_messages(vec![
            MessageTemplate::System("Hello".into()),
            MessageTemplate::Placeholder("history".into()),
            MessageTemplate::Human("{question}".into()),
        ]);
        let mut vars = HashMap::new();
        let _ = vars.insert("question".into(), "Hi".into());
        let messages = tpl.format_messages(&vars).unwrap();
        assert_eq!(messages.len(), 2);
    }

    /// A JSON array of `{role, content}` objects expands into messages
    /// in place, between the surrounding templates.
    #[test]
    fn test_placeholder_json_array_expansion() {
        let tpl = ChatPromptTemplate::from_messages(vec![
            MessageTemplate::System("You are helpful.".into()),
            MessageTemplate::Placeholder("history".into()),
            MessageTemplate::Human("{question}".into()),
        ]);
        let history = serde_json::json!([
            {"role": "human", "content": "What is 2+2?"},
            {"role": "ai", "content": "4"},
        ]);
        let mut vars = HashMap::new();
        let _ = vars.insert("history".into(), history.to_string());
        let _ = vars.insert("question".into(), "And 3+3?".into());
        let messages = tpl.format_messages(&vars).unwrap();
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].message_type(), "system");
        assert_eq!(messages[1].message_type(), "human");
        assert_eq!(messages[1].content().as_text(), "What is 2+2?");
        assert_eq!(messages[2].message_type(), "ai");
        assert_eq!(messages[2].content().as_text(), "4");
        assert_eq!(messages[3].message_type(), "human");
        assert_eq!(messages[3].content().as_text(), "And 3+3?");
    }

    /// A value that is not a JSON array becomes a single human message.
    #[test]
    fn test_placeholder_plain_string_becomes_human_message() {
        let tpl =
            ChatPromptTemplate::from_messages(vec![MessageTemplate::Placeholder("input".into())]);
        let mut vars = HashMap::new();
        let _ = vars.insert("input".into(), "Tell me a joke".into());
        let messages = tpl.format_messages(&vars).unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].message_type(), "human");
        assert_eq!(messages[0].content().as_text(), "Tell me a joke");
    }

    /// OpenAI-style role names (`user`/`assistant`) are accepted alongside
    /// `system`.
    #[test]
    fn test_placeholder_recognises_user_and_assistant_roles() {
        let tpl =
            ChatPromptTemplate::from_messages(vec![MessageTemplate::Placeholder("history".into())]);
        let history = serde_json::json!([
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
            {"role": "system", "content": "Be concise"},
        ]);
        let mut vars = HashMap::new();
        let _ = vars.insert("history".into(), history.to_string());
        let messages = tpl.format_messages(&vars).unwrap();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].message_type(), "human");
        assert_eq!(messages[1].message_type(), "ai");
        assert_eq!(messages[2].message_type(), "system");
    }

    /// Array entries with an unrecognised role (here `tool`) are dropped
    /// without failing the whole expansion.
    #[test]
    fn test_placeholder_skips_items_with_unknown_role() {
        let tpl =
            ChatPromptTemplate::from_messages(vec![MessageTemplate::Placeholder("history".into())]);
        let history = serde_json::json!([
            {"role": "human", "content": "Hi"},
            {"role": "tool", "content": "result"},
            {"role": "ai", "content": "Done"},
        ]);
        let mut vars = HashMap::new();
        let _ = vars.insert("history".into(), history.to_string());
        let messages = tpl.format_messages(&vars).unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].message_type(), "human");
        assert_eq!(messages[1].message_type(), "ai");
    }
}