Skip to main content

synwire_core/prompts/
chat.rs

1//! Chat prompt template with message-level variable substitution.
2
3use std::collections::HashMap;
4
5use crate::error::SynwireError;
6use crate::messages::Message;
7use crate::prompts::PromptValue;
8
/// A template for a single message in a chat prompt.
///
/// The text carried by `System`, `Human` and `AI` may contain `{variable}`
/// references that are filled in when the prompt is formatted. `Placeholder`
/// instead carries the *name* of a variable whose value is expanded into zero
/// or more messages at format time.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum MessageTemplate {
    /// A system message template.
    System(String),
    /// A human (user) message template.
    Human(String),
    /// An AI (assistant) message template.
    AI(String),
    /// A placeholder for dynamic messages (e.g., chat history), identified by
    /// the name of the variable to look up at format time.
    Placeholder(String),
}
22
/// A template that formats messages with variable substitution.
///
/// `{variable}` references inside system, human and AI message templates are
/// filled from a `HashMap<String, String>`; `Placeholder` entries are looked
/// up in the same map and expanded into zero or more messages at format time.
///
/// # Examples
///
/// ```
/// use std::collections::HashMap;
/// use synwire_core::prompts::{ChatPromptTemplate, MessageTemplate};
///
/// let tpl = ChatPromptTemplate::from_messages(vec![
///     MessageTemplate::System("You are {role}".into()),
///     MessageTemplate::Human("{question}".into()),
/// ]);
/// let mut vars = HashMap::new();
/// vars.insert("role".into(), "a helpful assistant".into());
/// vars.insert("question".into(), "What is Rust?".into());
/// let messages = tpl.format_messages(&vars).unwrap();
/// assert_eq!(messages.len(), 2);
/// ```
#[derive(Debug, Clone)]
pub struct ChatPromptTemplate {
    // The message templates, in the order they are rendered.
    messages: Vec<MessageTemplate>,
    // De-duplicated `{variable}` names referenced by the non-placeholder
    // templates, in first-appearance order (filled in by `from_messages`).
    input_variables: Vec<String>,
}
46
/// Collects the names of `{variable}` references in `template`, in order.
///
/// A `{` opens a reference that runs up to the next `}` (so the name may
/// itself contain `{`). Empty references (`{}`) are ignored, and scanning
/// stops at the first `{` that has no closing `}`. Duplicates are kept.
fn extract_variables(template: &str) -> Vec<String> {
    let mut names = Vec::new();
    let mut remaining = template;
    loop {
        let Some((_, after_open)) = remaining.split_once('{') else {
            break;
        };
        let Some((name, after_close)) = after_open.split_once('}') else {
            break;
        };
        if !name.is_empty() {
            names.push(name.to_owned());
        }
        remaining = after_close;
    }
    names
}
65
66/// Expands a placeholder value into one or more messages.
67///
68/// If the value is a valid JSON array of objects with `role` and `content`
69/// fields, each element is converted to the corresponding [`Message`] variant.
70/// Recognised roles are `"system"`, `"human"` / `"user"`, and `"ai"` /
71/// `"assistant"`; unrecognised roles are silently skipped.
72///
73/// If the value is not a valid JSON array, it is treated as a single human
74/// message whose content is the raw string.
75fn expand_placeholder(value: &str, out: &mut Vec<Message>) {
76    if let Ok(serde_json::Value::Array(arr)) = serde_json::from_str::<serde_json::Value>(value) {
77        for item in &arr {
78            let Some(role) = item.get("role").and_then(serde_json::Value::as_str) else {
79                continue;
80            };
81            let Some(content) = item.get("content").and_then(serde_json::Value::as_str) else {
82                continue;
83            };
84            let msg = match role {
85                "system" => Message::system(content),
86                "human" | "user" => Message::human(content),
87                "ai" | "assistant" => Message::ai(content),
88                _ => continue,
89            };
90            out.push(msg);
91        }
92    } else {
93        out.push(Message::human(value));
94    }
95}
96
97/// Performs `{variable}` substitution on a template string.
98fn substitute(template: &str, variables: &HashMap<String, String>) -> Result<String, SynwireError> {
99    let mut result = template.to_owned();
100    for var in &extract_variables(template) {
101        let value = variables.get(var).ok_or_else(|| SynwireError::Prompt {
102            message: format!("missing required variable '{var}'"),
103        })?;
104        result = result.replace(&format!("{{{var}}}"), value);
105    }
106    Ok(result)
107}
108
109impl ChatPromptTemplate {
110    /// Creates a `ChatPromptTemplate` from a list of message templates.
111    ///
112    /// Input variables are automatically extracted from template strings.
113    pub fn from_messages(messages: Vec<MessageTemplate>) -> Self {
114        let mut seen = std::collections::HashSet::new();
115        let mut input_variables = Vec::new();
116        for msg in &messages {
117            let tpl = match msg {
118                MessageTemplate::System(t) | MessageTemplate::Human(t) | MessageTemplate::AI(t) => {
119                    t.as_str()
120                }
121                MessageTemplate::Placeholder(_) => continue,
122            };
123            for var in extract_variables(tpl) {
124                if seen.insert(var.clone()) {
125                    input_variables.push(var);
126                }
127            }
128        }
129        Self {
130            messages,
131            input_variables,
132        }
133    }
134
135    /// Returns the input variables required by this template.
136    pub fn input_variables(&self) -> &[String] {
137        &self.input_variables
138    }
139
140    /// Formats all message templates into concrete [`Message`] values.
141    ///
142    /// `Placeholder` templates are expanded by looking up their variable name
143    /// in the provided map. If the value is a JSON array of `{role, content}`
144    /// objects the corresponding messages are injected; otherwise the value is
145    /// treated as a single human message. Missing placeholders are skipped.
146    ///
147    /// # Errors
148    ///
149    /// Returns [`SynwireError::Prompt`] if a required variable is missing.
150    pub fn format_messages(
151        &self,
152        variables: &HashMap<String, String>,
153    ) -> Result<Vec<Message>, SynwireError> {
154        let mut result = Vec::with_capacity(self.messages.len());
155        for msg in &self.messages {
156            match msg {
157                MessageTemplate::System(tpl) => {
158                    let text = substitute(tpl, variables)?;
159                    result.push(Message::system(text));
160                }
161                MessageTemplate::Human(tpl) => {
162                    let text = substitute(tpl, variables)?;
163                    result.push(Message::human(text));
164                }
165                MessageTemplate::AI(tpl) => {
166                    let text = substitute(tpl, variables)?;
167                    result.push(Message::ai(text));
168                }
169                MessageTemplate::Placeholder(name) => {
170                    if let Some(value) = variables.get(name.as_str()) {
171                        expand_placeholder(value, &mut result);
172                    }
173                }
174            }
175        }
176        Ok(result)
177    }
178
179    /// Formats messages and wraps them in a [`PromptValue::Messages`].
180    ///
181    /// # Errors
182    ///
183    /// Returns [`SynwireError::Prompt`] if a required variable is missing.
184    pub fn to_prompt_value(
185        &self,
186        variables: &HashMap<String, String>,
187    ) -> Result<PromptValue, SynwireError> {
188        let messages = self.format_messages(variables)?;
189        Ok(PromptValue::Messages(messages))
190    }
191}
192
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a variable map from `(name, value)` pairs.
    fn vars(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_owned(), value.to_owned()))
            .collect()
    }

    #[test]
    fn test_chat_prompt_template_format_messages() {
        let template = ChatPromptTemplate::from_messages(vec![
            MessageTemplate::System("You are {role}".into()),
            MessageTemplate::Human("{question}".into()),
        ]);
        let variables = vars(&[("role", "a helpful assistant"), ("question", "What is Rust?")]);

        let messages = template.format_messages(&variables).unwrap();

        assert_eq!(messages.len(), 2);
        let (system, human) = (&messages[0], &messages[1]);
        assert_eq!(system.message_type(), "system");
        assert_eq!(system.content().as_text(), "You are a helpful assistant");
        assert_eq!(human.message_type(), "human");
        assert_eq!(human.content().as_text(), "What is Rust?");
    }

    #[test]
    fn test_chat_prompt_template_to_prompt_value() {
        let template =
            ChatPromptTemplate::from_messages(vec![MessageTemplate::Human("Hello {name}".into())]);
        let prompt = template.to_prompt_value(&vars(&[("name", "World")])).unwrap();
        let messages = prompt.to_messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content().as_text(), "Hello World");
    }

    #[test]
    fn test_chat_prompt_template_missing_variable() {
        let template =
            ChatPromptTemplate::from_messages(vec![MessageTemplate::Human("{question}".into())]);
        let msg = template
            .format_messages(&HashMap::new())
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("question"),
            "error should mention the missing variable, got: {msg}"
        );
    }

    #[test]
    fn test_extract_variables() {
        let found = extract_variables("Hello {name}, you are {age} years old");
        assert_eq!(found, vec!["name", "age"]);
    }

    #[test]
    fn test_extract_variables_empty() {
        assert!(extract_variables("No variables here").is_empty());
    }

    #[test]
    fn test_input_variables_auto_extracted() {
        let template = ChatPromptTemplate::from_messages(vec![
            MessageTemplate::System("You are {role}".into()),
            MessageTemplate::Human("{question} about {topic}".into()),
        ]);
        assert_eq!(template.input_variables(), &["role", "question", "topic"]);
    }

    #[test]
    fn test_placeholder_missing_variable_skipped() {
        let template = ChatPromptTemplate::from_messages(vec![
            MessageTemplate::System("Hello".into()),
            MessageTemplate::Placeholder("history".into()),
            MessageTemplate::Human("{question}".into()),
        ]);
        let messages = template
            .format_messages(&vars(&[("question", "Hi")]))
            .unwrap();
        // No "history" value was supplied, so only the system and human messages remain.
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn test_placeholder_json_array_expansion() {
        let template = ChatPromptTemplate::from_messages(vec![
            MessageTemplate::System("You are helpful.".into()),
            MessageTemplate::Placeholder("history".into()),
            MessageTemplate::Human("{question}".into()),
        ]);
        let history = serde_json::json!([
            {"role": "human", "content": "What is 2+2?"},
            {"role": "ai", "content": "4"},
        ])
        .to_string();
        let variables = vars(&[("history", history.as_str()), ("question", "And 3+3?")]);

        let messages = template.format_messages(&variables).unwrap();

        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].message_type(), "system");
        assert_eq!(messages[1].message_type(), "human");
        assert_eq!(messages[1].content().as_text(), "What is 2+2?");
        assert_eq!(messages[2].message_type(), "ai");
        assert_eq!(messages[2].content().as_text(), "4");
        assert_eq!(messages[3].message_type(), "human");
        assert_eq!(messages[3].content().as_text(), "And 3+3?");
    }

    #[test]
    fn test_placeholder_plain_string_becomes_human_message() {
        let template =
            ChatPromptTemplate::from_messages(vec![MessageTemplate::Placeholder("input".into())]);
        let messages = template
            .format_messages(&vars(&[("input", "Tell me a joke")]))
            .unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].message_type(), "human");
        assert_eq!(messages[0].content().as_text(), "Tell me a joke");
    }

    #[test]
    fn test_placeholder_recognises_user_and_assistant_roles() {
        let template =
            ChatPromptTemplate::from_messages(vec![MessageTemplate::Placeholder("history".into())]);
        let history = serde_json::json!([
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
            {"role": "system", "content": "Be concise"},
        ])
        .to_string();

        let messages = template
            .format_messages(&vars(&[("history", history.as_str())]))
            .unwrap();

        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].message_type(), "human");
        assert_eq!(messages[1].message_type(), "ai");
        assert_eq!(messages[2].message_type(), "system");
    }

    #[test]
    fn test_placeholder_skips_items_with_unknown_role() {
        let template =
            ChatPromptTemplate::from_messages(vec![MessageTemplate::Placeholder("history".into())]);
        let history = serde_json::json!([
            {"role": "human", "content": "Hi"},
            {"role": "tool", "content": "result"},
            {"role": "ai", "content": "Done"},
        ])
        .to_string();

        let messages = template
            .format_messages(&vars(&[("history", history.as_str())]))
            .unwrap();

        // The "tool" entry is dropped; only the human and AI turns survive.
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].message_type(), "human");
        assert_eq!(messages[1].message_type(), "ai");
    }
}