sllm/
message.rs

1use std::{fmt, sync::Arc};
2
3use serde::Serialize;
4use tera::{Context, Tera};
5
6pub use crate::traits::MessageBuilder;
7
8#[derive(Clone)]
9pub struct TemplatedMessage {
10    template: String,
11    context: Context,
12}
13
14impl TemplatedMessage {
15    pub fn new(template: &str) -> Self {
16        Self {
17            template: template.to_string(),
18            context: Context::new(),
19        }
20    }
21
22    pub fn insert<T: Serialize + ?Sized, S: Into<String>>(&mut self, key: S, val: &T) {
23        self.context.insert(key, val);
24    }
25
26    pub fn remove(&mut self, index: &str) -> bool {
27        self.context.remove(index).is_some()
28    }
29
30    // TODO get
31}
32
33#[derive(Clone)]
34pub enum PromptMessage {
35    KeyValue {
36        title: String,
37        messages: Vec<(String, Arc<dyn Fn() -> String + Send + Sync>)>,
38    },
39    Templated(TemplatedMessage),
40    Simple(String),
41}
42
43impl From<TemplatedMessage> for PromptMessage {
44    fn from(value: TemplatedMessage) -> Self {
45        PromptMessage::Templated(value)
46    }
47}
48
49impl From<String> for PromptMessage {
50    fn from(value: String) -> Self {
51        PromptMessage::Simple(value)
52    }
53}
54
55impl From<&str> for PromptMessage {
56    fn from(value: &str) -> Self {
57        PromptMessage::Simple(value.to_string())
58    }
59}
60// #[derive(Clone)]
61// pub struct PromptMessageGroup {
62//     title: String,
63//     messages: Vec<(String, Arc<dyn Fn() -> String + Send + Sync>)>,
64// }
65
66impl std::fmt::Debug for PromptMessage {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        match self {
69            PromptMessage::KeyValue { title, messages } => {
70                f.debug_struct("KeyValue")
71                    .field("title", &title)
72                    .field(
73                        "messages",
74                        &messages
75                            .iter()
76                            .map(|(key, _)| key.clone())
77                            .collect::<Vec<String>>(),
78                    ) // can't display closures, so only the keys.
79                    .finish()
80            }
81            PromptMessage::Templated(templated_msg) => f
82                .debug_struct("Templated")
83                .field("template", &templated_msg.template)
84                .field("context", &templated_msg.context)
85                .finish(),
86            PromptMessage::Simple(message) => f.debug_tuple("Simple").field(message).finish(),
87        }
88    }
89}
90
91impl PromptMessage {
92    pub fn new_key_value(title: &str) -> Self {
93        PromptMessage::KeyValue {
94            title: title.into(),
95            messages: Vec::new(),
96        }
97    }
98
99    pub fn new_templated(templated_msg: TemplatedMessage) -> Self {
100        PromptMessage::Templated(templated_msg)
101    }
102
103    pub fn new_simple(message: String) -> Self {
104        PromptMessage::Simple(message)
105    }
106
107    // A method to add a static message
108    pub fn add_message(&mut self, key: &str, value: &str) {
109        match self {
110            PromptMessage::KeyValue { messages, .. } => {
111                let v = value.to_string();
112                let value_arc = Arc::new(move || v.clone());
113                messages.push((key.into(), value_arc));
114            }
115            _ => panic!("add_message is only valid for KeyValue variants"),
116        }
117    }
118
119    // A method to add a dynamic message
120    pub fn add_message_dyn<F>(&mut self, key: &str, value: F)
121    where
122        F: Fn() -> String + 'static + Send + Sync,
123    {
124        match self {
125            PromptMessage::KeyValue { messages, .. } => {
126                messages.push((key.into(), Arc::new(value)));
127            }
128            _ => panic!("add_message_dyn is only valid for KeyValue variants"),
129        }
130    }
131}
132
133impl MessageBuilder for PromptMessage {
134    fn build(&mut self) -> String {
135        match self {
136            PromptMessage::KeyValue { title, messages } => {
137                let rendered_messages = messages
138                    .iter()
139                    .map(|(key, value_fn)| {
140                        if key.is_empty() {
141                            format!("{}", value_fn())
142                        } else {
143                            format!("{}: {}", key, value_fn())
144                        }
145                    })
146                    .collect::<Vec<String>>()
147                    .join("\n");
148
149                if title.is_empty() {
150                    rendered_messages
151                } else {
152                    format!("[{}]\n{}", title, rendered_messages)
153                }
154            }
155            PromptMessage::Templated(templated_msg) => {
156                let mut tera = Tera::default();
157                tera.add_raw_template("template", &templated_msg.template)
158                    .unwrap();
159                tera.render("template", &templated_msg.context).unwrap()
160            }
161            PromptMessage::Simple(message) => message.clone(),
162        }
163    }
164}
165
/// One-shot builder that renders a collection of message groups.
///
/// The groups are stored in an `Option` and taken out when the builder is
/// built, so a given builder can only be built once.
///
/// The `IntoIterator` / `MessageBuilder` requirements are stated on the
/// `impl` blocks that need them rather than on the type itself.
pub struct PromptMessageBuilder<T> {
    /// The groups still waiting to be rendered; `None` once they were taken.
    groups: Option<T>,
}
174
175impl<T> PromptMessageBuilder<T>
176where
177    T: IntoIterator,
178    T::Item: MessageBuilder,
179{
180    // Constructor for a new PromptMessageBuilder with an iterable of items that implement MessageBuilder.
181    pub fn new(groups: T) -> Self {
182        PromptMessageBuilder {
183            groups: Some(groups),
184        }
185    }
186}
187
188impl<T> MessageBuilder for PromptMessageBuilder<T>
189where
190    T: IntoIterator,
191    T::Item: MessageBuilder,
192{
193    fn build(&mut self) -> String {
194        let groups = self
195            .groups
196            .take()
197            .expect("Groups should not be taken more than once");
198
199        groups
200            .into_iter()
201            .map(|mut group| group.build())
202            .collect::<Vec<String>>()
203            .join("\n\n")
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// A titled group renders as `[title]` followed by one `key: value` line per entry.
    #[test]
    fn test_key_value_insert_and_build() {
        let mut group = PromptMessage::new_key_value("Test Group");
        group.add_message_dyn("Key1", || "Value1".to_string());
        group.add_message("Key2", "Value2");

        let output = group.build();
        let expected_output = "[Test Group]\nKey1: Value1\nKey2: Value2";
        assert_eq!(output, expected_output);
    }

    /// An empty title omits the header line; an empty key omits the `key: ` prefix.
    #[test]
    fn test_key_value_without_title_or_key() {
        let mut group = PromptMessage::new_key_value("");
        group.add_message("", "bare value");
        group.add_message("Key", "Value");

        assert_eq!(group.build(), "bare value\nKey: Value");
    }

    /// Context values are substituted into the template.
    #[test]
    fn test_templated_build() {
        let mut msg = TemplatedMessage::new("Hello, {{ name }}!");
        msg.insert("name", "World");
        let mut group = PromptMessage::new_templated(msg);
        let output = group.build();
        let expected_output = "Hello, World!";
        assert_eq!(output, expected_output);
    }

    /// `remove` reports whether the key was present.
    #[test]
    fn test_templated_remove() {
        let mut msg = TemplatedMessage::new("unused");
        msg.insert("name", "World");

        assert!(msg.remove("name"));
        assert!(!msg.remove("name"));
    }

    /// A simple message builds to its own text.
    #[test]
    fn test_simple_build() {
        let mut group = PromptMessage::new_simple("Just a simple message.".to_string());
        let output = group.build();
        let expected_output = "Just a simple message.";
        assert_eq!(output, expected_output);
    }

    /// Every `From` conversion yields a message that builds to the expected text.
    #[test]
    fn test_from_conversions() {
        let mut borrowed = PromptMessage::from("borrowed");
        let mut owned = PromptMessage::from(String::from("owned"));
        let mut templated = PromptMessage::from(TemplatedMessage::new("static"));

        assert_eq!(borrowed.build(), "borrowed");
        assert_eq!(owned.build(), "owned");
        assert_eq!(templated.build(), "static");
    }

    /// The builder joins the rendered groups with a blank line.
    #[test]
    fn test_builder_joins_groups_with_blank_line() {
        let groups = vec![PromptMessage::from("first"), PromptMessage::from("second")];
        let mut builder = PromptMessageBuilder::new(groups);

        assert_eq!(builder.build(), "first\n\nsecond");
    }

    /// Adding an entry to a non-`KeyValue` message panics.
    #[test]
    #[should_panic(expected = "only valid for KeyValue variants")]
    fn test_add_message_rejects_non_key_value() {
        let mut msg = PromptMessage::from("plain");
        msg.add_message("Key", "Value");
    }
}