// whatsapp_rust/retry.rs

1use crate::client::Client;
2use crate::types::events::Receipt;
3use dashmap::Entry;
4use log::{info, warn};
5use prost::Message;
6use rand::TryRngCore;
7use scopeguard;
8use std::sync::Arc;
9use wacore::libsignal::protocol::{KeyPair, ProtocolAddress};
10use wacore::libsignal::store::PreKeyStore;
11use wacore::libsignal::store::SessionStore;
12use wacore::types::jid::JidExt;
13use wacore_binary::builder::NodeBuilder;
14use wacore_binary::jid::JidExt as _;
15
16impl Client {
17    pub(crate) async fn handle_retry_receipt(
18        self: &Arc<Self>,
19        receipt: &Receipt,
20        node: &wacore_binary::node::Node,
21    ) -> Result<(), anyhow::Error> {
22        let retry_child = node
23            .get_optional_child("retry")
24            .ok_or_else(|| anyhow::anyhow!("<retry> child missing from receipt"))?;
25
26        let message_id = retry_child.attrs().string("id");
27
28        // For group messages, only retry once per message id to avoid loops
29        if receipt.source.chat.is_group() {
30            let dedupe_key = format!("{}:{}", receipt.source.chat, message_id);
31            match self.retried_group_messages.entry(dedupe_key.clone()) {
32                Entry::Occupied(_) => {
33                    log::debug!(
34                        "Ignoring subsequent retry for group message {}: already handled.",
35                        dedupe_key
36                    );
37                    return Ok(());
38                }
39                Entry::Vacant(e) => {
40                    e.insert(());
41                }
42            }
43        }
44
45        {
46            let mut pending = self.pending_retries.lock().await;
47            if pending.contains(&message_id) {
48                log::debug!("Ignoring retry for {message_id}: a retry is already in progress.");
49                return Ok(());
50            }
51            pending.insert(message_id.clone());
52        }
53        let _guard = scopeguard::guard((self.clone(), message_id.clone()), |(client, id)| {
54            tokio::spawn(async move {
55                client.pending_retries.lock().await.remove(&id);
56            });
57        });
58
59        let original_msg_arc = match self
60            .take_recent_message(receipt.source.chat.clone(), message_id.clone())
61            .await
62        {
63            Ok(Some(msg)) => msg,
64            Ok(None) => {
65                log::debug!(
66                    "Ignoring retry for message {message_id}: already handled or not found in cache."
67                );
68                return Ok(());
69            }
70            Err(e) => {
71                log::warn!("Failed to retrieve recent message for retry {message_id}: {e}");
72                return Ok(()); // Continue without the original message if retrieval failed
73            }
74        };
75
76        if receipt.source.chat.is_group() {
77            let dedupe_key = format!("{}:{}", receipt.source.chat, message_id);
78            self.retried_group_messages.insert(dedupe_key, ());
79        }
80
81        let participant_jid = receipt.source.sender.clone();
82
83        if receipt.source.chat.is_group() {
84            let device_snapshot = self.persistence_manager.get_device_snapshot().await;
85            let own_lid = device_snapshot
86                .lid
87                .clone()
88                .ok_or_else(|| anyhow::anyhow!("LID missing for group retry handling"))?;
89
90            let sender_address =
91                ProtocolAddress::new(own_lid.user.clone(), u32::from(own_lid.device).into());
92            let sender_key_name = wacore::libsignal::store::sender_key_name::SenderKeyName::new(
93                receipt.source.chat.to_string(),
94                sender_address.to_string(),
95            );
96
97            let device_store = self.persistence_manager.get_device_arc().await;
98            let device_guard = device_store.read().await;
99
100            // The store saves group sender keys under the composite key "{group}:{sender}".
101            let unique_key_to_delete = format!(
102                "{}:{}",
103                sender_key_name.group_id(),
104                sender_key_name.sender_id()
105            );
106
107            if let Err(e) = device_guard
108                .backend
109                .delete_sender_key(&unique_key_to_delete)
110                .await
111            {
112                log::warn!(
113                    "Failed to delete sender key for group {}: {}",
114                    receipt.source.chat,
115                    e
116                );
117            } else {
118                info!(
119                    "Deleted sender key for group {} due to retry receipt from {}",
120                    receipt.source.chat, participant_jid
121                );
122            }
123        } else {
124            let signal_address = participant_jid.to_protocol_address();
125
126            let device_store = self.persistence_manager.get_device_arc().await;
127            if let Err(e) = device_store
128                .write()
129                .await
130                .delete_session(&signal_address)
131                .await
132            {
133                log::warn!("Failed to delete session for {signal_address}: {e}");
134            } else {
135                info!("Deleted session for {signal_address} due to retry receipt");
136            }
137        }
138
139        if receipt.source.chat.is_group() {
140            info!(
141                "Handling group message retry for {}. Creating and attaching a new SenderKeyDistributionMessage.",
142                message_id
143            );
144
145            self.send_message_impl(
146                receipt.source.chat.clone(),
147                Arc::clone(&original_msg_arc),
148                Some(message_id.clone()), // Pass Some(message_id)
149                false,
150                true,
151                None,
152            )
153            .await?;
154        } else {
155            self.send_message_impl(
156                receipt.source.chat.clone(),
157                Arc::clone(&original_msg_arc),
158                Some(message_id), // Pass Some(message_id)
159                false,
160                true,
161                None,
162            )
163            .await?;
164        }
165
166        Ok(())
167    }
168
169    pub(crate) async fn send_retry_receipt(
170        &self,
171        info: &crate::types::message::MessageInfo,
172    ) -> Result<(), anyhow::Error> {
173        warn!(
174            "Sending retry receipt for message {} from {}",
175            info.id, info.source.sender
176        );
177
178        let device_snapshot = self.persistence_manager.get_device_snapshot().await;
179        let device_store = self.persistence_manager.get_device_arc().await;
180        let device_guard = device_store.read().await;
181
182        let new_prekey_id = (rand::random::<u32>() % 16777215) + 1;
183        let new_prekey_keypair = KeyPair::generate(&mut rand::rngs::OsRng.unwrap_err());
184        let new_prekey_record = wacore::libsignal::store::record_helpers::new_pre_key_record(
185            new_prekey_id,
186            &new_prekey_keypair,
187        );
188        // This key is not uploaded to the server pool, so mark as false
189        if let Err(e) = device_guard
190            .store_prekey(new_prekey_id, new_prekey_record, false)
191            .await
192        {
193            warn!("Failed to store new prekey for retry receipt: {e:?}");
194        }
195        drop(device_guard);
196
197        let registration_id_bytes = device_snapshot.registration_id.to_be_bytes().to_vec();
198
199        let identity_key_bytes = device_snapshot
200            .identity_key
201            .public_key
202            .public_key_bytes()
203            .to_vec();
204
205        let prekey_id_bytes = new_prekey_id.to_be_bytes()[1..].to_vec();
206        let prekey_value_bytes = new_prekey_keypair.public_key.public_key_bytes().to_vec();
207
208        let skey_id_bytes = 1u32.to_be_bytes()[1..].to_vec();
209        let skey_value_bytes = device_snapshot
210            .signed_pre_key
211            .public_key
212            .public_key_bytes()
213            .to_vec();
214        let skey_sig_bytes = device_snapshot.signed_pre_key_signature.to_vec();
215
216        let device_identity_bytes = device_snapshot
217            .account
218            .as_ref()
219            .ok_or_else(|| anyhow::anyhow!("Missing device account info for retry receipt"))?
220            .encode_to_vec();
221
222        let retry_node = NodeBuilder::new("retry")
223            .attr("v", "1")
224            .attr("id", info.id.clone())
225            .attr("t", info.timestamp.timestamp().to_string())
226            .attr("count", "1")
227            .build();
228
229        let type_bytes = vec![5u8];
230
231        let keys_node = NodeBuilder::new("keys")
232            .children([
233                NodeBuilder::new("type").bytes(type_bytes).build(),
234                NodeBuilder::new("identity")
235                    .bytes(identity_key_bytes)
236                    .build(),
237                NodeBuilder::new("key")
238                    .children([
239                        NodeBuilder::new("id").bytes(prekey_id_bytes).build(),
240                        NodeBuilder::new("value").bytes(prekey_value_bytes).build(),
241                    ])
242                    .build(),
243                NodeBuilder::new("skey")
244                    .children([
245                        NodeBuilder::new("id").bytes(skey_id_bytes).build(),
246                        NodeBuilder::new("value").bytes(skey_value_bytes).build(),
247                        NodeBuilder::new("signature").bytes(skey_sig_bytes).build(),
248                    ])
249                    .build(),
250                NodeBuilder::new("device-identity")
251                    .bytes(device_identity_bytes)
252                    .build(),
253            ])
254            .build();
255
256        let receipt_to = if info.source.is_group {
257            info.source.chat.to_string()
258        } else {
259            info.source.sender.to_string()
260        };
261
262        let registration_node = NodeBuilder::new("registration")
263            .bytes(registration_id_bytes)
264            .build();
265
266        let receipt_node = NodeBuilder::new("receipt")
267            .attr("to", receipt_to)
268            .attr("id", info.id.clone())
269            .attr("type", "retry")
270            .attr("participant", info.source.sender.to_string())
271            .children([retry_node, registration_node, keys_node])
272            .build();
273
274        self.send_node(receipt_node).await?;
275        Ok(())
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::persistence_manager::PersistenceManager;
    use wacore_binary::jid::Jid;
    use waproto::whatsapp as wa;

    /// A cached recent message is handed out exactly once.
    ///
    /// The first `take_recent_message` returns it; a second take for the same
    /// chat/id yields `None`.
    #[tokio::test]
    async fn recent_message_cache_insert_and_take() {
        let _ = env_logger::builder().is_test(true).try_init();

        let sqlite = crate::store::sqlite_store::SqliteStore::new(":memory:")
            .await
            .unwrap();
        let backend: Arc<dyn crate::store::traits::Backend> = Arc::new(sqlite);
        let persistence = Arc::new(PersistenceManager::new(backend).await.unwrap());
        let (client, _sync_rx) = Client::new(Arc::clone(&persistence)).await;

        let chat: Jid = "120363021033254949@g.us".parse().unwrap();
        let msg_id = String::from("ABC123");
        let message = wa::Message {
            conversation: Some("hello".into()),
            ..Default::default()
        };

        // Insert through the public API.
        client
            .add_recent_message(chat.clone(), msg_id.clone(), Arc::new(message))
            .await
            .expect("Failed to add recent message");

        // Give the cache's manager task time to apply the insert.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        // The first take yields the message and removes it from the cache.
        let first = client
            .take_recent_message(chat.clone(), msg_id.clone())
            .await
            .unwrap_or_else(|e| panic!("Failed to take recent message: {}", e));
        assert!(first.is_some());
        let taken = first.unwrap();
        assert_eq!(taken.conversation.as_deref(), Some("hello"));

        // A second take finds nothing.
        let second = client
            .take_recent_message(chat, msg_id)
            .await
            .unwrap_or_else(|e| panic!("Failed to take recent message: {}", e));
        assert!(second.is_none());
    }
}
333}