wacore/session.rs
1//! Session management with deduplication for concurrent prekey fetches.
2//!
3//! This module implements a pattern similar to WhatsApp Web's `ensureE2ESessions`,
4//! which provides:
5//! - Deduplication: Multiple concurrent requests for the same JID share a single fetch
6//! - Batching: Prekey fetches are batched up to SESSION_CHECK_BATCH_SIZE
7//!
8//! This prevents redundant network requests when sending messages to the same
9//! recipient from multiple concurrent operations.
10
11use async_lock::Mutex;
12use futures::channel::oneshot;
13use std::collections::{HashMap, HashSet};
14use wacore_binary::jid::Jid;
15
16// Tests live in whatsapp-rust/src/session.rs (they use tokio for spawning)
17
18/// Maximum number of JIDs to include in a single prekey fetch request.
19/// Matches WhatsApp Web's SESSION_CHECK_BATCH constant.
20pub const SESSION_CHECK_BATCH_SIZE: usize = 50;
21
22/// Result of a session ensure operation
23pub type SessionResult = Result<(), SessionError>;
24
25/// Errors that can occur during session management
26#[derive(Debug, Clone)]
27pub enum SessionError {
28 /// The prekey fetch operation failed
29 FetchFailed(String),
30 /// The session establishment failed
31 EstablishmentFailed(String),
32 /// Internal channel error
33 ChannelClosed,
34}
35
36impl std::fmt::Display for SessionError {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 match self {
39 SessionError::FetchFailed(msg) => write!(f, "prekey fetch failed: {}", msg),
40 SessionError::EstablishmentFailed(msg) => {
41 write!(f, "session establishment failed: {}", msg)
42 }
43 SessionError::ChannelClosed => write!(f, "internal channel closed"),
44 }
45 }
46}
47
48impl std::error::Error for SessionError {}
49
50/// Manages session establishment with deduplication.
51///
52/// When multiple concurrent operations need sessions for overlapping JIDs,
53/// this manager ensures only one prekey fetch is performed per JID.
54/// Subsequent requests wait for the in-flight fetch to complete.
55pub struct SessionManager {
56 /// JIDs currently being processed (prekeys being fetched + sessions being established)
57 processing: Mutex<HashSet<String>>,
58
59 /// JIDs waiting for processing, mapped to their notification channels.
60 /// When a JID finishes processing, all waiters are notified.
61 pending: Mutex<HashMap<String, Vec<oneshot::Sender<SessionResult>>>>,
62}
63
64impl SessionManager {
65 /// Create a new SessionManager
66 pub fn new() -> Self {
67 Self {
68 processing: Mutex::new(HashSet::new()),
69 pending: Mutex::new(HashMap::new()),
70 }
71 }
72
73 /// Ensure sessions exist for the given JIDs.
74 ///
75 /// This method deduplicates requests: if a JID is already being processed,
76 /// this call will wait for that processing to complete rather than
77 /// initiating a duplicate fetch.
78 ///
79 /// # Arguments
80 /// * `jids` - JIDs that need sessions
81 /// * `has_session` - Closure to check if a session already exists
82 /// * `fetch_and_establish` - Closure to fetch prekeys and establish sessions
83 ///
84 /// # Returns
85 /// Ok(()) if all sessions were established (or already existed)
86 pub async fn ensure_sessions<F, H, Fut>(
87 &self,
88 jids: Vec<Jid>,
89 has_session: H,
90 fetch_and_establish: F,
91 ) -> SessionResult
92 where
93 H: Fn(&Jid) -> bool,
94 F: Fn(Vec<Jid>) -> Fut,
95 Fut: std::future::Future<Output = Result<(), anyhow::Error>>,
96 {
97 if jids.is_empty() {
98 return Ok(());
99 }
100
101 // Step 1: Filter to JIDs that actually need sessions
102 let jids_needing_sessions: Vec<Jid> =
103 jids.into_iter().filter(|jid| !has_session(jid)).collect();
104
105 if jids_needing_sessions.is_empty() {
106 return Ok(());
107 }
108
109 // Step 2: Determine which JIDs we need to process vs wait for.
110 // Store (Jid, String) pairs so the string key computed here can be
111 // reused in step 3 cleanup, avoiding a redundant second to_string() pass.
112 let (to_process, to_wait) = {
113 let mut processing = self.processing.lock().await;
114 let mut pending = self.pending.lock().await;
115
116 let mut to_process: Vec<(Jid, String)> =
117 Vec::with_capacity(jids_needing_sessions.len());
118 let mut to_wait = Vec::with_capacity(jids_needing_sessions.len());
119
120 for jid in jids_needing_sessions {
121 let jid_str = jid.to_string();
122
123 if processing.contains(&jid_str) {
124 // Already being processed - we need to wait
125 let (tx, rx) = oneshot::channel();
126 pending.entry(jid_str).or_default().push(tx);
127 to_wait.push(rx);
128 } else {
129 // Not being processed - we'll handle it
130 processing.insert(jid_str.clone());
131 to_process.push((jid, jid_str));
132 }
133 }
134
135 (to_process, to_wait)
136 };
137
138 // Step 3: Process JIDs we're responsible for (in batches)
139 let mut process_error: Option<SessionError> = None;
140
141 if !to_process.is_empty() {
142 // Process in batches of SESSION_CHECK_BATCH_SIZE
143 for batch in to_process.chunks(SESSION_CHECK_BATCH_SIZE) {
144 let batch_jids: Vec<Jid> = batch.iter().map(|(jid, _)| jid.clone()).collect();
145
146 let result = fetch_and_establish(batch_jids).await;
147
148 // Notify any waiters and remove from processing
149 let notify_result = match &result {
150 Ok(()) => Ok(()),
151 Err(e) => Err(SessionError::FetchFailed(e.to_string())),
152 };
153
154 if notify_result.is_err() && process_error.is_none() {
155 process_error = Some(notify_result.clone().unwrap_err());
156 }
157
158 // Clean up processing set and notify waiters
159 // Reuse the string keys stored alongside each JID in step 2.
160 {
161 let mut processing = self.processing.lock().await;
162 let mut pending = self.pending.lock().await;
163
164 for (_, jid_str) in batch {
165 processing.remove(jid_str);
166
167 if let Some(waiters) = pending.remove(jid_str) {
168 for waiter in waiters {
169 let _ = waiter.send(notify_result.clone());
170 }
171 }
172 }
173 }
174 }
175 }
176
177 // Step 4: Wait for JIDs being processed by others
178 for rx in to_wait {
179 match rx.await {
180 Ok(result) => {
181 if let Err(e) = result
182 && process_error.is_none()
183 {
184 process_error = Some(e);
185 }
186 }
187 Err(_) => {
188 if process_error.is_none() {
189 process_error = Some(SessionError::ChannelClosed);
190 }
191 }
192 }
193 }
194
195 match process_error {
196 Some(e) => Err(e),
197 None => Ok(()),
198 }
199 }
200
201 /// Check if a JID is currently being processed
202 pub async fn is_processing(&self, jid: &str) -> bool {
203 self.processing.lock().await.contains(jid)
204 }
205
206 /// Get the number of JIDs currently being processed
207 pub async fn processing_count(&self) -> usize {
208 self.processing.lock().await.len()
209 }
210
211 /// Get the number of JIDs with pending waiters
212 pub async fn pending_count(&self) -> usize {
213 self.pending.lock().await.len()
214 }
215}
216
217impl Default for SessionManager {
218 fn default() -> Self {
219 Self::new()
220 }
221}