Skip to main content

wacore/
session.rs

//! Session management with deduplication for concurrent prekey fetches.
//!
//! This module implements a pattern similar to WhatsApp Web's `ensureE2ESessions`,
//! which provides:
//! - Deduplication: multiple concurrent requests for the same JID share a single fetch.
//! - Batching: prekey fetches are split into batches of at most
//!   [`SESSION_CHECK_BATCH_SIZE`] JIDs.
//!
//! This prevents redundant network requests when sending messages to the same
//! recipient from multiple concurrent operations.

11use async_lock::Mutex;
12use futures::channel::oneshot;
13use std::collections::{HashMap, HashSet};
14use wacore_binary::jid::Jid;

// Tests live in `whatsapp-rust/src/session.rs` (they use tokio for spawning).

/// Upper bound on how many JIDs are sent together in one prekey fetch request.
///
/// Mirrors the `SESSION_CHECK_BATCH` constant used by WhatsApp Web.
pub const SESSION_CHECK_BATCH_SIZE: usize = 50;
21
/// Result of a session ensure operation: `Ok(())` on success, otherwise the
/// [`SessionError`] that occurred.
pub type SessionResult = Result<(), SessionError>;

/// Errors that can occur during session management.
///
/// `PartialEq`/`Eq` are derived so outcomes can be compared directly, e.g.
/// `assert_eq!(result, Err(SessionError::ChannelClosed))`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionError {
    /// The prekey fetch operation failed; carries the underlying error message.
    FetchFailed(String),
    /// The session establishment failed; carries a description of the failure.
    EstablishmentFailed(String),
    /// An internal notification channel was closed before a result was delivered.
    ChannelClosed,
}

impl std::fmt::Display for SessionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SessionError::FetchFailed(msg) => write!(f, "prekey fetch failed: {msg}"),
            SessionError::EstablishmentFailed(msg) => {
                write!(f, "session establishment failed: {msg}")
            }
            SessionError::ChannelClosed => write!(f, "internal channel closed"),
        }
    }
}

// The variants carry only messages, not underlying error values, so the
// default `source()` (returning `None`) is used.
impl std::error::Error for SessionError {}
49
50/// Manages session establishment with deduplication.
51///
52/// When multiple concurrent operations need sessions for overlapping JIDs,
53/// this manager ensures only one prekey fetch is performed per JID.
54/// Subsequent requests wait for the in-flight fetch to complete.
55pub struct SessionManager {
56    /// JIDs currently being processed (prekeys being fetched + sessions being established)
57    processing: Mutex<HashSet<String>>,
58
59    /// JIDs waiting for processing, mapped to their notification channels.
60    /// When a JID finishes processing, all waiters are notified.
61    pending: Mutex<HashMap<String, Vec<oneshot::Sender<SessionResult>>>>,
62}
63
64impl SessionManager {
65    /// Create a new SessionManager
66    pub fn new() -> Self {
67        Self {
68            processing: Mutex::new(HashSet::new()),
69            pending: Mutex::new(HashMap::new()),
70        }
71    }
72
73    /// Ensure sessions exist for the given JIDs.
74    ///
75    /// This method deduplicates requests: if a JID is already being processed,
76    /// this call will wait for that processing to complete rather than
77    /// initiating a duplicate fetch.
78    ///
79    /// # Arguments
80    /// * `jids` - JIDs that need sessions
81    /// * `has_session` - Closure to check if a session already exists
82    /// * `fetch_and_establish` - Closure to fetch prekeys and establish sessions
83    ///
84    /// # Returns
85    /// Ok(()) if all sessions were established (or already existed)
86    pub async fn ensure_sessions<F, H, Fut>(
87        &self,
88        jids: Vec<Jid>,
89        has_session: H,
90        fetch_and_establish: F,
91    ) -> SessionResult
92    where
93        H: Fn(&Jid) -> bool,
94        F: Fn(Vec<Jid>) -> Fut,
95        Fut: std::future::Future<Output = Result<(), anyhow::Error>>,
96    {
97        if jids.is_empty() {
98            return Ok(());
99        }
100
101        // Step 1: Filter to JIDs that actually need sessions
102        let jids_needing_sessions: Vec<Jid> =
103            jids.into_iter().filter(|jid| !has_session(jid)).collect();
104
105        if jids_needing_sessions.is_empty() {
106            return Ok(());
107        }
108
109        // Step 2: Determine which JIDs we need to process vs wait for.
110        // Store (Jid, String) pairs so the string key computed here can be
111        // reused in step 3 cleanup, avoiding a redundant second to_string() pass.
112        let (to_process, to_wait) = {
113            let mut processing = self.processing.lock().await;
114            let mut pending = self.pending.lock().await;
115
116            let mut to_process: Vec<(Jid, String)> =
117                Vec::with_capacity(jids_needing_sessions.len());
118            let mut to_wait = Vec::with_capacity(jids_needing_sessions.len());
119
120            for jid in jids_needing_sessions {
121                let jid_str = jid.to_string();
122
123                if processing.contains(&jid_str) {
124                    // Already being processed - we need to wait
125                    let (tx, rx) = oneshot::channel();
126                    pending.entry(jid_str).or_default().push(tx);
127                    to_wait.push(rx);
128                } else {
129                    // Not being processed - we'll handle it
130                    processing.insert(jid_str.clone());
131                    to_process.push((jid, jid_str));
132                }
133            }
134
135            (to_process, to_wait)
136        };
137
138        // Step 3: Process JIDs we're responsible for (in batches)
139        let mut process_error: Option<SessionError> = None;
140
141        if !to_process.is_empty() {
142            // Process in batches of SESSION_CHECK_BATCH_SIZE
143            for batch in to_process.chunks(SESSION_CHECK_BATCH_SIZE) {
144                let batch_jids: Vec<Jid> = batch.iter().map(|(jid, _)| jid.clone()).collect();
145
146                let result = fetch_and_establish(batch_jids).await;
147
148                // Notify any waiters and remove from processing
149                let notify_result = match &result {
150                    Ok(()) => Ok(()),
151                    Err(e) => Err(SessionError::FetchFailed(e.to_string())),
152                };
153
154                if notify_result.is_err() && process_error.is_none() {
155                    process_error = Some(notify_result.clone().unwrap_err());
156                }
157
158                // Clean up processing set and notify waiters
159                // Reuse the string keys stored alongside each JID in step 2.
160                {
161                    let mut processing = self.processing.lock().await;
162                    let mut pending = self.pending.lock().await;
163
164                    for (_, jid_str) in batch {
165                        processing.remove(jid_str);
166
167                        if let Some(waiters) = pending.remove(jid_str) {
168                            for waiter in waiters {
169                                let _ = waiter.send(notify_result.clone());
170                            }
171                        }
172                    }
173                }
174            }
175        }
176
177        // Step 4: Wait for JIDs being processed by others
178        for rx in to_wait {
179            match rx.await {
180                Ok(result) => {
181                    if let Err(e) = result
182                        && process_error.is_none()
183                    {
184                        process_error = Some(e);
185                    }
186                }
187                Err(_) => {
188                    if process_error.is_none() {
189                        process_error = Some(SessionError::ChannelClosed);
190                    }
191                }
192            }
193        }
194
195        match process_error {
196            Some(e) => Err(e),
197            None => Ok(()),
198        }
199    }
200
201    /// Check if a JID is currently being processed
202    pub async fn is_processing(&self, jid: &str) -> bool {
203        self.processing.lock().await.contains(jid)
204    }
205
206    /// Get the number of JIDs currently being processed
207    pub async fn processing_count(&self) -> usize {
208        self.processing.lock().await.len()
209    }
210
211    /// Get the number of JIDs with pending waiters
212    pub async fn pending_count(&self) -> usize {
213        self.pending.lock().await.len()
214    }
215}
216
217impl Default for SessionManager {
218    fn default() -> Self {
219        Self::new()
220    }
221}