1use std::{fmt, sync::Arc};
2
3use serde::Serialize;
4use tera::{Context, Tera};
5
6pub use crate::traits::MessageBuilder;
7
8#[derive(Clone)]
9pub struct TemplatedMessage {
10 template: String,
11 context: Context,
12}
13
14impl TemplatedMessage {
15 pub fn new(template: &str) -> Self {
16 Self {
17 template: template.to_string(),
18 context: Context::new(),
19 }
20 }
21
22 pub fn insert<T: Serialize + ?Sized, S: Into<String>>(&mut self, key: S, val: &T) {
23 self.context.insert(key, val);
24 }
25
26 pub fn remove(&mut self, index: &str) -> bool {
27 self.context.remove(index).is_some()
28 }
29
30 }
32
33#[derive(Clone)]
34pub enum PromptMessage {
35 KeyValue {
36 title: String,
37 messages: Vec<(String, Arc<dyn Fn() -> String + Send + Sync>)>,
38 },
39 Templated(TemplatedMessage),
40 Simple(String),
41}
42
43impl From<TemplatedMessage> for PromptMessage {
44 fn from(value: TemplatedMessage) -> Self {
45 PromptMessage::Templated(value)
46 }
47}
48
49impl From<String> for PromptMessage {
50 fn from(value: String) -> Self {
51 PromptMessage::Simple(value)
52 }
53}
54
55impl From<&str> for PromptMessage {
56 fn from(value: &str) -> Self {
57 PromptMessage::Simple(value.to_string())
58 }
59}
60impl std::fmt::Debug for PromptMessage {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 match self {
69 PromptMessage::KeyValue { title, messages } => {
70 f.debug_struct("KeyValue")
71 .field("title", &title)
72 .field(
73 "messages",
74 &messages
75 .iter()
76 .map(|(key, _)| key.clone())
77 .collect::<Vec<String>>(),
78 ) .finish()
80 }
81 PromptMessage::Templated(templated_msg) => f
82 .debug_struct("Templated")
83 .field("template", &templated_msg.template)
84 .field("context", &templated_msg.context)
85 .finish(),
86 PromptMessage::Simple(message) => f.debug_tuple("Simple").field(message).finish(),
87 }
88 }
89}
90
91impl PromptMessage {
92 pub fn new_key_value(title: &str) -> Self {
93 PromptMessage::KeyValue {
94 title: title.into(),
95 messages: Vec::new(),
96 }
97 }
98
99 pub fn new_templated(templated_msg: TemplatedMessage) -> Self {
100 PromptMessage::Templated(templated_msg)
101 }
102
103 pub fn new_simple(message: String) -> Self {
104 PromptMessage::Simple(message)
105 }
106
107 pub fn add_message(&mut self, key: &str, value: &str) {
109 match self {
110 PromptMessage::KeyValue { messages, .. } => {
111 let v = value.to_string();
112 let value_arc = Arc::new(move || v.clone());
113 messages.push((key.into(), value_arc));
114 }
115 _ => panic!("add_message is only valid for KeyValue variants"),
116 }
117 }
118
119 pub fn add_message_dyn<F>(&mut self, key: &str, value: F)
121 where
122 F: Fn() -> String + 'static + Send + Sync,
123 {
124 match self {
125 PromptMessage::KeyValue { messages, .. } => {
126 messages.push((key.into(), Arc::new(value)));
127 }
128 _ => panic!("add_message_dyn is only valid for KeyValue variants"),
129 }
130 }
131}
132
133impl MessageBuilder for PromptMessage {
134 fn build(&mut self) -> String {
135 match self {
136 PromptMessage::KeyValue { title, messages } => {
137 let rendered_messages = messages
138 .iter()
139 .map(|(key, value_fn)| {
140 if key.is_empty() {
141 format!("{}", value_fn())
142 } else {
143 format!("{}: {}", key, value_fn())
144 }
145 })
146 .collect::<Vec<String>>()
147 .join("\n");
148
149 if title.is_empty() {
150 rendered_messages
151 } else {
152 format!("[{}]\n{}", title, rendered_messages)
153 }
154 }
155 PromptMessage::Templated(templated_msg) => {
156 let mut tera = Tera::default();
157 tera.add_raw_template("template", &templated_msg.template)
158 .unwrap();
159 tera.render("template", &templated_msg.context).unwrap()
160 }
161 PromptMessage::Simple(message) => message.clone(),
162 }
163 }
164}
165
166pub struct PromptMessageBuilder<T>
168where
169 T: IntoIterator,
170 T::Item: MessageBuilder,
171{
172 groups: Option<T>,
173}
174
175impl<T> PromptMessageBuilder<T>
176where
177 T: IntoIterator,
178 T::Item: MessageBuilder,
179{
180 pub fn new(groups: T) -> Self {
182 PromptMessageBuilder {
183 groups: Some(groups),
184 }
185 }
186}
187
188impl<T> MessageBuilder for PromptMessageBuilder<T>
189where
190 T: IntoIterator,
191 T::Item: MessageBuilder,
192{
193 fn build(&mut self) -> String {
194 let groups = self
195 .groups
196 .take()
197 .expect("Groups should not be taken more than once");
198
199 groups
200 .into_iter()
201 .map(|mut group| group.build())
202 .collect::<Vec<String>>()
203 .join("\n\n")
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_value_insert_and_build() {
        let mut group = PromptMessage::new_key_value("Test Group");
        group.add_message_dyn("Key1", || "Value1".to_string());
        group.add_message("Key2", "Value2");

        let output = group.build();
        let expected_output = "[Test Group]\nKey1: Value1\nKey2: Value2";
        assert_eq!(output, expected_output);
    }

    /// An empty title drops the `[title]` header, and an empty key drops the `key: ` prefix.
    #[test]
    fn test_key_value_without_title_or_key() {
        let mut group = PromptMessage::new_key_value("");
        group.add_message("", "bare value");
        group.add_message("Key", "Value");

        assert_eq!(group.build(), "bare value\nKey: Value");
    }

    /// Dynamic values are re-evaluated on every build rather than captured once.
    #[test]
    fn test_key_value_dyn_is_evaluated_per_build() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let counter = std::sync::Arc::new(AtomicUsize::new(0));
        let counter_in_closure = std::sync::Arc::clone(&counter);
        let mut group = PromptMessage::new_key_value("Counter");
        group.add_message_dyn("n", move || {
            counter_in_closure
                .fetch_add(1, Ordering::SeqCst)
                .to_string()
        });

        assert_eq!(group.build(), "[Counter]\nn: 0");
        assert_eq!(group.build(), "[Counter]\nn: 1");
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    #[should_panic(expected = "add_message is only valid for KeyValue variants")]
    fn test_add_message_on_simple_panics() {
        let mut group = PromptMessage::new_simple("oops".to_string());
        group.add_message("k", "v");
    }

    #[test]
    fn test_templated_build() {
        let mut msg = TemplatedMessage::new("Hello, {{ name }}!");
        msg.insert("name", "World");
        let mut group = PromptMessage::new_templated(msg);
        let output = group.build();
        let expected_output = "Hello, World!";
        assert_eq!(output, expected_output);
    }

    #[test]
    fn test_templated_remove_reports_presence() {
        let mut msg = TemplatedMessage::new("unused");
        msg.insert("name", "World");
        assert!(msg.remove("name"));
        assert!(!msg.remove("name"));
    }

    #[test]
    fn test_simple_build() {
        let mut group = PromptMessage::new_simple("Just a simple message.".to_string());
        let output = group.build();
        let expected_output = "Just a simple message.";
        assert_eq!(output, expected_output);
    }

    #[test]
    fn test_from_conversions() {
        let mut from_str: PromptMessage = "hello".into();
        assert_eq!(from_str.build(), "hello");
        let mut from_string: PromptMessage = String::from("world").into();
        assert_eq!(from_string.build(), "world");
    }

    #[test]
    fn test_builder_joins_groups_with_blank_line() {
        let groups = vec![
            PromptMessage::new_simple("first".to_string()),
            PromptMessage::from("second"),
        ];
        let mut builder = PromptMessageBuilder::new(groups);
        assert_eq!(builder.build(), "first\n\nsecond");
    }

    #[test]
    fn test_debug_lists_keys_only() {
        let mut group = PromptMessage::new_key_value("T");
        group.add_message("k", "v");
        assert_eq!(
            format!("{:?}", group),
            "KeyValue { title: \"T\", messages: [\"k\"] }"
        );
    }
}