1use super::MessageStore;
4use super::types::*;
5use super::user_handle::UserHandle;
6use anyhow::Result;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
13pub struct ThreadManager {
15 _store: MessageStore,
16 thread_cache: Arc<RwLock<HashMap<ThreadId, ThreadView>>>,
18 subscriptions: Arc<RwLock<HashSet<ThreadId>>>,
20}
21
22impl ThreadManager {
23 pub fn new(store: MessageStore) -> Self {
25 Self {
26 _store: store,
27 thread_cache: Arc::new(RwLock::new(HashMap::new())),
28 subscriptions: Arc::new(RwLock::new(HashSet::new())),
29 }
30 }
31
32 pub async fn create_thread(&self, parent_message: &RichMessage) -> Result<ThreadId> {
34 let thread_id = ThreadId::new();
35
36 let thread_view = ThreadView {
38 parent_message: parent_message.clone(),
39 replies: Vec::new(),
40 participants: vec![parent_message.sender.clone()],
41 is_following: true,
42 unread_count: 0,
43 last_activity: parent_message.created_at,
44 };
45
46 let mut cache = self.thread_cache.write().await;
48 cache.insert(thread_id, thread_view);
49
50 let mut subs = self.subscriptions.write().await;
52 subs.insert(thread_id);
53
54 Ok(thread_id)
55 }
56
57 pub async fn add_to_thread(&self, thread_id: ThreadId, message: &RichMessage) -> Result<()> {
59 let mut cache = self.thread_cache.write().await;
60
61 if let Some(thread) = cache.get_mut(&thread_id) {
62 thread.replies.push(message.clone());
64
65 if !thread.participants.contains(&message.sender) {
67 thread.participants.push(message.sender.clone());
68 }
69
70 thread.last_activity = message.created_at;
72
73 } else {
76 let thread = self.fetch_thread(thread_id).await?;
78 cache.insert(thread_id, thread);
79 }
80
81 Ok(())
82 }
83
84 pub async fn update_thread(&self, thread_id: ThreadId, message: &RichMessage) -> Result<()> {
86 self.add_to_thread(thread_id, message).await
87 }
88
89 pub async fn get_thread(&self, thread_id: ThreadId) -> Result<ThreadView> {
91 let cache = self.thread_cache.read().await;
93 if let Some(thread) = cache.get(&thread_id) {
94 return Ok(thread.clone());
95 }
96 drop(cache);
97
98 let thread = self.fetch_thread(thread_id).await?;
100
101 let mut cache = self.thread_cache.write().await;
103 cache.insert(thread_id, thread.clone());
104
105 Ok(thread)
106 }
107
108 pub async fn get_channel_threads(&self, channel_id: ChannelId) -> Result<Vec<ThreadSummary>> {
110 let cache = self.thread_cache.read().await;
112 let threads: Vec<ThreadSummary> = cache
113 .values()
114 .filter(|t| t.parent_message.channel_id == channel_id)
115 .map(ThreadSummary::from)
116 .collect();
117
118 Ok(threads)
119 }
120
121 pub async fn mark_thread_read(&self, thread_id: ThreadId) -> Result<()> {
123 let mut cache = self.thread_cache.write().await;
124 if let Some(thread) = cache.get_mut(&thread_id) {
125 thread.unread_count = 0;
126 }
127 Ok(())
128 }
129
130 pub async fn set_following(&self, thread_id: ThreadId, following: bool) -> Result<()> {
132 let mut subs = self.subscriptions.write().await;
133
134 if following {
135 subs.insert(thread_id);
136 } else {
137 subs.remove(&thread_id);
138 }
139
140 let mut cache = self.thread_cache.write().await;
142 if let Some(thread) = cache.get_mut(&thread_id) {
143 thread.is_following = following;
144 }
145
146 Ok(())
147 }
148
149 pub async fn get_followed_threads(&self) -> Result<Vec<ThreadId>> {
151 let subs = self.subscriptions.read().await;
152 Ok(subs.iter().copied().collect())
153 }
154
155 pub async fn resolve_thread(&self, thread_id: ThreadId) -> Result<()> {
157 tracing::info!("Thread {:?} resolved", thread_id);
160 Ok(())
161 }
162
163 async fn fetch_thread(&self, _thread_id: ThreadId) -> Result<ThreadView> {
165 Ok(ThreadView {
168 parent_message: RichMessage::new(
169 UserHandle::from("system-thread-mock-user"),
170 ChannelId::new(),
171 MessageContent::Text("Mock thread parent".to_string()),
172 ),
173 replies: Vec::new(),
174 participants: Vec::new(),
175 is_following: false,
176 unread_count: 0,
177 last_activity: Utc::now(),
178 })
179 }
180}
181
182#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct ThreadSummary {
185 pub thread_id: ThreadId,
186 pub parent_preview: String,
187 pub reply_count: u32,
188 pub participant_count: u32,
189 pub last_activity: DateTime<Utc>,
190 pub unread_count: u32,
191 pub is_following: bool,
192}
193
194impl From<&ThreadView> for ThreadSummary {
195 fn from(thread: &ThreadView) -> Self {
196 let parent_preview = match &thread.parent_message.content {
197 MessageContent::Text(text) => text.chars().take(100).collect(),
198 MessageContent::RichText(rich) => rich.raw.chars().take(100).collect(),
199 _ => "[Media]".to_string(),
200 };
201
202 Self {
203 thread_id: thread.parent_message.thread_id.unwrap_or_default(),
204 parent_preview,
205 reply_count: thread.replies.len() as u32,
206 participant_count: thread.participants.len() as u32,
207 last_activity: thread.last_activity,
208 unread_count: thread.unread_count,
209 is_following: thread.is_following,
210 }
211 }
212}
213
214#[derive(Debug, Clone, Serialize, Deserialize)]
216pub struct ThreadNotificationPrefs {
217 pub all_replies: bool,
219 pub mentions_only: bool,
221 pub muted: bool,
223 pub custom_sound: Option<String>,
225}
226
227impl Default for ThreadNotificationPrefs {
228 fn default() -> Self {
229 Self {
230 all_replies: true,
231 mentions_only: false,
232 muted: false,
233 custom_sound: None,
234 }
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ThreadManager` over a database message store that is wired
    /// to a mock DHT client. Shared by all tests in this module.
    async fn new_manager() -> ThreadManager {
        let store = super::super::database::DatabaseMessageStore::new(
            super::super::DhtClient::new_mock(),
            None,
        )
        .await
        .expect("mock-backed message store should initialize");

        ThreadManager::new(store)
    }

    /// A freshly created thread is cached with its parent, has no replies,
    /// one participant, and is followed.
    #[tokio::test]
    async fn test_thread_creation() {
        let manager = new_manager().await;

        let parent = RichMessage::new(
            UserHandle::from("alice"),
            ChannelId::new(),
            MessageContent::Text("Start a thread".to_string()),
        );

        let thread_id = manager.create_thread(&parent).await.unwrap();
        let thread = manager.get_thread(thread_id).await.unwrap();

        assert_eq!(thread.parent_message.id, parent.id);
        assert_eq!(thread.replies.len(), 0);
        assert_eq!(thread.participants.len(), 1);
        assert!(thread.is_following);
    }

    /// A reply from a new sender is appended and adds that sender to the
    /// participants.
    #[tokio::test]
    async fn test_thread_replies() {
        let manager = new_manager().await;

        let parent = RichMessage::new(
            UserHandle::from("alice"),
            ChannelId::new(),
            MessageContent::Text("Start a thread".to_string()),
        );

        let thread_id = manager.create_thread(&parent).await.unwrap();

        let reply = RichMessage::new(
            UserHandle::from("eve"),
            parent.channel_id,
            MessageContent::Text("Reply to thread".to_string()),
        );

        manager.add_to_thread(thread_id, &reply).await.unwrap();

        let thread = manager.get_thread(thread_id).await.unwrap();
        assert_eq!(thread.replies.len(), 1);
        assert_eq!(thread.participants.len(), 2);
    }
}