1use super::SendMessageRequest;
4use super::types::*;
5use super::user_handle::UserHandle;
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10pub struct MessageComposer {
12 drafts: HashMap<ChannelId, DraftMessage>,
14 mention_cache: Vec<UserHandle>,
16 emoji_shortcuts: HashMap<String, String>,
18}
19
20impl Default for MessageComposer {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl MessageComposer {
27 pub fn new() -> Self {
29 let mut emoji_shortcuts = HashMap::new();
30 emoji_shortcuts.insert(":)".to_string(), "😊".to_string());
31 emoji_shortcuts.insert(":D".to_string(), "😃".to_string());
32 emoji_shortcuts.insert(":((".to_string(), "😢".to_string());
33 emoji_shortcuts.insert("<3".to_string(), "❤️".to_string());
34 emoji_shortcuts.insert(":fire:".to_string(), "🔥".to_string());
35 emoji_shortcuts.insert(":rocket:".to_string(), "🚀".to_string());
36 emoji_shortcuts.insert(":+1:".to_string(), "👍".to_string());
37 emoji_shortcuts.insert(":-1:".to_string(), "👎".to_string());
38
39 Self {
40 drafts: HashMap::new(),
41 mention_cache: Vec::new(),
42 emoji_shortcuts,
43 }
44 }
45
46 pub fn start_draft(&mut self, channel_id: ChannelId) -> &mut DraftMessage {
48 self.drafts
49 .entry(channel_id)
50 .or_insert_with(|| DraftMessage::new(channel_id))
51 }
52
53 pub fn get_draft(&self, channel_id: ChannelId) -> Option<&DraftMessage> {
55 self.drafts.get(&channel_id)
56 }
57
58 pub fn update_draft(&mut self, channel_id: ChannelId, text: String) {
60 let draft = self.start_draft(channel_id);
61 draft.text = text;
62 draft.update_formatted();
63 }
64
65 pub fn add_mention(&mut self, channel_id: ChannelId, user: UserHandle) {
67 let draft = self.start_draft(channel_id);
68 draft.mentions.push(user.clone());
69
70 let mention_text = format!("@{} ", user);
72 draft.text.push_str(&mention_text);
73 draft.update_formatted();
74 }
75
76 pub fn add_attachment(&mut self, channel_id: ChannelId, attachment: DraftAttachment) {
78 let draft = self.start_draft(channel_id);
79 draft.attachments.push(attachment);
80 }
81
82 pub fn remove_attachment(&mut self, channel_id: ChannelId, index: usize) {
84 if let Some(draft) = self.drafts.get_mut(&channel_id)
85 && index < draft.attachments.len()
86 {
87 draft.attachments.remove(index);
88 }
89 }
90
91 pub fn set_reply_to(&mut self, channel_id: ChannelId, message_id: MessageId) {
93 let draft = self.start_draft(channel_id);
94 draft.reply_to = Some(message_id);
95 }
96
97 pub fn set_thread(&mut self, channel_id: ChannelId, thread_id: ThreadId) {
99 let draft = self.start_draft(channel_id);
100 draft.thread_id = Some(thread_id);
101 }
102
103 pub fn clear_draft(&mut self, channel_id: ChannelId) {
105 self.drafts.remove(&channel_id);
106 }
107
108 pub fn get_mention_suggestions(&self, partial: &str) -> Vec<UserHandle> {
110 self.mention_cache
111 .iter()
112 .filter(|user| {
113 user.as_str()
114 .to_lowercase()
115 .contains(&partial.to_lowercase())
116 })
117 .cloned()
118 .collect()
119 }
120
121 pub fn update_mention_cache(&mut self, users: Vec<UserHandle>) {
123 self.mention_cache = users;
124 }
125
126 pub fn apply_formatting(&mut self, channel_id: ChannelId, format: TextFormat) {
128 let draft = self.start_draft(channel_id);
129
130 match format {
131 TextFormat::Bold => {
132 draft.text = format!("**{}**", draft.text);
133 }
134 TextFormat::Italic => {
135 draft.text = format!("*{}*", draft.text);
136 }
137 TextFormat::Code => {
138 draft.text = format!("`{}`", draft.text);
139 }
140 TextFormat::Strike => {
141 draft.text = format!("~~{}~~", draft.text);
142 }
143 TextFormat::Quote => {
144 draft.text = format!("> {}", draft.text);
145 }
146 TextFormat::CodeBlock(lang) => {
147 draft.text = format!("```{}\n{}\n```", lang, draft.text);
148 }
149 }
150
151 draft.update_formatted();
152 }
153
154 pub fn insert_emoji(&mut self, channel_id: ChannelId, emoji: String) {
156 let draft = self.start_draft(channel_id);
157 draft.text.push_str(&emoji);
158 draft.update_formatted();
159 }
160
161 pub fn process_shortcuts(&mut self, channel_id: ChannelId) {
163 let shortcuts = self.emoji_shortcuts.clone();
165
166 let draft = self.start_draft(channel_id);
167
168 for (shortcut, emoji) in &shortcuts {
169 draft.text = draft.text.replace(shortcut, emoji);
170 }
171
172 draft.update_formatted();
173 }
174
175 pub fn validate_draft(&self, channel_id: ChannelId) -> Result<()> {
177 let draft = self
178 .drafts
179 .get(&channel_id)
180 .ok_or_else(|| anyhow::anyhow!("No draft found"))?;
181
182 if draft.text.trim().is_empty() && draft.attachments.is_empty() {
184 return Err(anyhow::anyhow!("Cannot send empty message"));
185 }
186
187 if draft.text.len() > 10000 {
189 return Err(anyhow::anyhow!("Message too long (max 10000 characters)"));
190 }
191
192 let total_size: usize = draft.attachments.iter().map(|a| a.size).sum();
194
195 if total_size > 100 * 1024 * 1024 {
196 return Err(anyhow::anyhow!("Total attachment size exceeds 100MB"));
197 }
198
199 Ok(())
200 }
201
202 pub fn build_message(
204 &self,
205 channel_id: ChannelId,
206 ) -> Result<SendMessageRequest> {
207 let draft = self
208 .drafts
209 .get(&channel_id)
210 .ok_or_else(|| anyhow::anyhow!("No draft found"))?;
211
212 let content = if draft.formatted_text.is_some() {
214 MessageContent::RichText(MarkdownContent {
215 raw: draft.text.clone(),
216 formatted: draft.formatted_text.clone().unwrap_or_default(),
217 mentions: draft.mentions.clone(),
218 links: draft.extract_links(),
219 })
220 } else {
221 MessageContent::Text(draft.text.clone())
222 };
223
224 let attachments = draft.attachments.iter().map(|a| a.data.clone()).collect();
226
227 Ok(SendMessageRequest {
228 channel_id,
229 content,
230 attachments,
231 thread_id: draft.thread_id,
232 reply_to: draft.reply_to,
233 mentions: draft.mentions.clone(),
234 ephemeral: draft.ephemeral,
235 })
236 }
237}
238
/// An unsent message being composed for a single channel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DraftMessage {
    /// Channel the draft belongs to; set when the draft is created.
    pub channel_id: ChannelId,
    /// Raw text as typed, including any Markdown markers.
    pub text: String,
    /// Set by `update_formatted` when the text contains Markdown marker
    /// characters; `None` for a freshly created draft.
    pub formatted_text: Option<String>,
    /// Users mentioned in this draft.
    pub mentions: Vec<UserHandle>,
    /// Files queued to be sent with the message.
    pub attachments: Vec<DraftAttachment>,
    /// Message this draft replies to, if any.
    pub reply_to: Option<MessageId>,
    /// Thread this draft is posted in, if any.
    pub thread_id: Option<ThreadId>,
    /// Copied into the outgoing request's `ephemeral` field; `false` for a
    /// freshly created draft.
    pub ephemeral: bool,
    /// UTC time at which the draft was created.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// UTC time of the most recent `update_formatted` call (equals
    /// `created_at` for a freshly created draft).
    pub updated_at: chrono::DateTime<chrono::Utc>,
}
253
254impl DraftMessage {
255 fn new(channel_id: ChannelId) -> Self {
257 let now = chrono::Utc::now();
258 Self {
259 channel_id,
260 text: String::new(),
261 formatted_text: None,
262 mentions: Vec::new(),
263 attachments: Vec::new(),
264 reply_to: None,
265 thread_id: None,
266 ephemeral: false,
267 created_at: now,
268 updated_at: now,
269 }
270 }
271
272 fn update_formatted(&mut self) {
274 if self.text.contains("**")
276 || self.text.contains("*")
277 || self.text.contains("`")
278 || self.text.contains("~~")
279 {
280 self.formatted_text = Some(self.text.clone());
281 }
282
283 self.updated_at = chrono::Utc::now();
284 }
285
286 fn extract_links(&self) -> Vec<String> {
288 fn try_re(p: &str) -> Option<regex::Regex> {
289 regex::Regex::new(p).ok()
290 }
291 let url_regex = try_re(r"https?://[^\s<]+[^<.,:;'!\?\s]")
292 .or_else(|| try_re(r"https?://.+"))
293 .or_else(|| try_re(r"https?://.*"));
294
295 if let Some(re) = url_regex {
296 re.find_iter(&self.text)
297 .map(|m| m.as_str().to_string())
298 .collect()
299 } else {
300 Vec::new()
301 }
302 }
303}
304
/// A file queued on a draft, held in memory until the message is built.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DraftAttachment {
    /// Name of the attached file.
    pub filename: String,
    /// MIME type of the attachment (not inspected anywhere in this file).
    pub mime_type: String,
    /// Declared size, summed by `validate_draft` against the total attachment
    /// limit. Presumably bytes, and it is not checked against `data.len()`
    /// here — confirm with callers.
    pub size: usize,
    /// Raw content; cloned into the outgoing request by `build_message`.
    pub data: Vec<u8>,
    /// Optional thumbnail bytes (not read anywhere in this file).
    pub thumbnail: Option<Vec<u8>>,
}
314
/// Markdown formatting that `MessageComposer::apply_formatting` can wrap
/// around a draft's whole text.
#[derive(Debug, Clone)]
pub enum TextFormat {
    /// Wraps the text in `**`.
    Bold,
    /// Wraps the text in `*`.
    Italic,
    /// Wraps the text in single backticks.
    Code,
    /// Wraps the text in `~~`.
    Strike,
    /// Prefixes the text with `> `.
    Quote,
    /// Wraps the text in a triple-backtick fence; the string is written
    /// directly after the opening fence.
    CodeBlock(String),
}
325
/// One entry of an autocomplete list.
///
/// NOTE(review): nothing in this file constructs or reads this type; the
/// field descriptions below are inferred from the names — confirm against
/// the consuming code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutocompleteSuggestion {
    /// Presumably the label shown for the suggestion.
    pub text: String,
    /// Presumably an icon identifier or glyph shown beside the label.
    pub icon: String,
    /// Presumably a secondary, explanatory line.
    pub description: String,
    /// What accepting the suggestion should do.
    pub action: AutocompleteAction,
}
334
/// The effect of accepting an [`AutocompleteSuggestion`].
///
/// NOTE(review): no code in this file matches on these variants; the
/// descriptions are inferred from the variant names and payload types.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AutocompleteAction {
    /// Insert a mention of the given user.
    InsertMention(UserHandle),
    /// Insert the given emoji string.
    InsertEmoji(String),
    /// Insert the given command string.
    InsertCommand(String),
    /// Insert a reference to the given channel.
    InsertChannel(ChannelId),
}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Updating a draft stores the supplied text verbatim.
    #[test]
    fn test_draft_creation() {
        let mut composer = MessageComposer::new();
        let channel = ChannelId::new();

        composer.update_draft(channel, String::from("Hello world"));

        let stored = composer
            .get_draft(channel)
            .expect("update_draft should create the draft");
        assert_eq!(stored.text, "Hello world");
    }

    /// Adding a mention records the user and appends `@handle` to the text.
    #[test]
    fn test_mention_addition() {
        let mut composer = MessageComposer::new();
        let channel = ChannelId::new();
        let alice = UserHandle::from("alice");

        composer.add_mention(channel, alice.clone());

        let stored = composer
            .get_draft(channel)
            .expect("add_mention should create the draft");
        assert!(stored.mentions.contains(&alice));
        assert!(stored.text.contains("@alice"));
    }

    /// Known shortcuts in the text are replaced by their emoji.
    #[test]
    fn test_emoji_shortcuts() {
        let mut composer = MessageComposer::new();
        let channel = ChannelId::new();

        composer.update_draft(channel, String::from("Hello :) :fire:"));
        composer.process_shortcuts(channel);

        let stored = composer
            .get_draft(channel)
            .expect("update_draft should create the draft");
        for emoji in ["😊", "🔥"] {
            assert!(stored.text.contains(emoji));
        }
    }

    /// Validation fails without a draft and succeeds once text is present.
    #[test]
    fn test_draft_validation() {
        let mut composer = MessageComposer::new();
        let channel = ChannelId::new();

        assert!(composer.validate_draft(channel).is_err());

        composer.update_draft(channel, String::from("Valid message"));
        assert!(composer.validate_draft(channel).is_ok());
    }
}