Skip to main content

whatsapp_rust/store/
signal.rs

1use crate::store::Device;
2use async_lock::Mutex;
3use async_trait::async_trait;
4use std::sync::Arc;
5use wacore::libsignal::protocol::error::Result as SignalResult;
6use wacore::libsignal::protocol::{
7    Direction, IdentityChange, IdentityKey, IdentityKeyPair, IdentityKeyStore, PrivateKey,
8    ProtocolAddress, PublicKey, SenderKeyRecord, SenderKeyStore, SessionRecord,
9    SignalProtocolError,
10};
11use wacore::libsignal::store::sender_key_name::SenderKeyName;
12use wacore::libsignal::store::*;
13use waproto::whatsapp::{PreKeyRecordStructure, SignedPreKeyRecordStructure};
14
/// Boxed, thread-safe error type returned by the pre-key, signed-pre-key and
/// session store trait implementations in this module.
type StoreError = Box<dyn std::error::Error + Sync + Send>;
16
/// Implements `IdentityKeyStore`, `PreKeyStore`, `SignedPreKeyStore` and
/// `SessionStore` for a newtype wrapping a lock-protected `Device`.
///
/// Every generated method acquires the lock through `self.0` and forwards to the
/// method of the same name on the inner `Device`. The guard is a temporary of the
/// forwarding expression, so it is held until the forwarded future has completed.
///
/// * `$wrapper`   — the newtype the traits are implemented for.
/// * `$shared`    — lock method used by read-only calls (e.g. `read`).
/// * `$exclusive` — lock method used by mutating calls (e.g. `write`).
///
/// For a mutex, both lock arguments can name the same method (`lock, lock`).
macro_rules! impl_store_wrapper {
    ($wrapper:ty, $shared:ident, $exclusive:ident) => {
        #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
        #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
        impl IdentityKeyStore for $wrapper {
            async fn get_identity_key_pair(&self) -> SignalResult<IdentityKeyPair> {
                self.0.$shared().await.get_identity_key_pair().await
            }

            async fn get_local_registration_id(&self) -> SignalResult<u32> {
                self.0.$shared().await.get_local_registration_id().await
            }

            async fn get_identity(
                &self,
                address: &ProtocolAddress,
            ) -> SignalResult<Option<IdentityKey>> {
                self.0.$shared().await.get_identity(address).await
            }

            async fn is_trusted_identity(
                &self,
                address: &ProtocolAddress,
                identity_key: &IdentityKey,
                direction: Direction,
            ) -> SignalResult<bool> {
                self.0
                    .$shared()
                    .await
                    .is_trusted_identity(address, identity_key, direction)
                    .await
            }

            // Exclusive lock: the inner `save_identity` takes `&mut self`.
            async fn save_identity(
                &mut self,
                address: &ProtocolAddress,
                identity_key: &IdentityKey,
            ) -> SignalResult<IdentityChange> {
                self.0
                    .$exclusive()
                    .await
                    .save_identity(address, identity_key)
                    .await
            }
        }

        #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
        #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
        impl PreKeyStore for $wrapper {
            async fn load_prekey(
                &self,
                prekey_id: u32,
            ) -> Result<Option<PreKeyRecordStructure>, StoreError> {
                self.0.$shared().await.load_prekey(prekey_id).await
            }

            async fn contains_prekey(&self, prekey_id: u32) -> Result<bool, StoreError> {
                self.0.$shared().await.contains_prekey(prekey_id).await
            }

            // Mutations take the exclusive lock so they never interleave.
            async fn store_prekey(
                &self,
                prekey_id: u32,
                record: PreKeyRecordStructure,
                uploaded: bool,
            ) -> Result<(), StoreError> {
                self.0
                    .$exclusive()
                    .await
                    .store_prekey(prekey_id, record, uploaded)
                    .await
            }

            async fn remove_prekey(&self, prekey_id: u32) -> Result<(), StoreError> {
                self.0.$exclusive().await.remove_prekey(prekey_id).await
            }
        }

        #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
        #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
        impl SignedPreKeyStore for $wrapper {
            async fn load_signed_prekey(
                &self,
                signed_prekey_id: u32,
            ) -> Result<Option<SignedPreKeyRecordStructure>, StoreError> {
                self.0.$shared().await.load_signed_prekey(signed_prekey_id).await
            }

            async fn load_signed_prekeys(
                &self,
            ) -> Result<Vec<SignedPreKeyRecordStructure>, StoreError> {
                self.0.$shared().await.load_signed_prekeys().await
            }

            async fn contains_signed_prekey(
                &self,
                signed_prekey_id: u32,
            ) -> Result<bool, StoreError> {
                self.0.$shared().await.contains_signed_prekey(signed_prekey_id).await
            }

            async fn store_signed_prekey(
                &self,
                signed_prekey_id: u32,
                record: SignedPreKeyRecordStructure,
            ) -> Result<(), StoreError> {
                self.0
                    .$exclusive()
                    .await
                    .store_signed_prekey(signed_prekey_id, record)
                    .await
            }

            async fn remove_signed_prekey(&self, signed_prekey_id: u32) -> Result<(), StoreError> {
                self.0.$exclusive().await.remove_signed_prekey(signed_prekey_id).await
            }
        }

        #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
        #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
        impl SessionStore for $wrapper {
            async fn load_session(
                &self,
                address: &ProtocolAddress,
            ) -> Result<SessionRecord, StoreError> {
                self.0.$shared().await.load_session(address).await
            }

            async fn get_sub_device_sessions(&self, name: &str) -> Result<Vec<u32>, StoreError> {
                self.0.$shared().await.get_sub_device_sessions(name).await
            }

            async fn contains_session(
                &self,
                address: &ProtocolAddress,
            ) -> Result<bool, StoreError> {
                self.0.$shared().await.contains_session(address).await
            }

            async fn store_session(
                &self,
                address: &ProtocolAddress,
                record: &SessionRecord,
            ) -> Result<(), StoreError> {
                self.0.$exclusive().await.store_session(address, record).await
            }

            async fn delete_session(&self, address: &ProtocolAddress) -> Result<(), StoreError> {
                self.0.$exclusive().await.delete_session(address).await
            }

            async fn delete_all_sessions(&self, name: &str) -> Result<(), StoreError> {
                self.0.$exclusive().await.delete_all_sessions(name).await
            }
        }
    };
}
194
195#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
196#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
197impl IdentityKeyStore for Device {
198    async fn get_identity_key_pair(&self) -> SignalResult<IdentityKeyPair> {
199        Ok(self.identity_key.clone().into())
200    }
201
202    async fn get_local_registration_id(&self) -> SignalResult<u32> {
203        Ok(self.registration_id)
204    }
205
206    async fn save_identity(
207        &mut self,
208        address: &ProtocolAddress,
209        identity_key: &IdentityKey,
210    ) -> SignalResult<IdentityChange> {
211        let address_str = address.to_string();
212        let key_bytes = identity_key.public_key().public_key_bytes();
213        let existing_identity_opt = self.get_identity(address).await?;
214
215        self.backend
216            .put_identity(
217                &address_str,
218                key_bytes.try_into().map_err(|_| {
219                    SignalProtocolError::InvalidArgument("Invalid key length".into())
220                })?,
221            )
222            .await
223            .map_err(|e| {
224                SignalProtocolError::InvalidState("backend put_identity", e.to_string())
225            })?;
226
227        match existing_identity_opt {
228            None => Ok(IdentityChange::NewOrUnchanged),
229            Some(existing) if &existing == identity_key => Ok(IdentityChange::NewOrUnchanged),
230            Some(_) => Ok(IdentityChange::ReplacedExisting),
231        }
232    }
233
234    async fn is_trusted_identity(
235        &self,
236        address: &ProtocolAddress,
237        identity_key: &IdentityKey,
238        _direction: Direction,
239    ) -> SignalResult<bool> {
240        // Trust on first use: if we don't have an identity stored, trust this one
241        // If we have one stored, it must match
242        match self.get_identity(address).await? {
243            None => Ok(true), // Trust on first use
244            Some(stored_identity) => Ok(&stored_identity == identity_key),
245        }
246    }
247
248    async fn get_identity(&self, address: &ProtocolAddress) -> SignalResult<Option<IdentityKey>> {
249        let identity_bytes = self
250            .backend
251            .load_identity(&address.to_string())
252            .await
253            .map_err(|e| {
254                SignalProtocolError::InvalidState("backend get_identity", e.to_string())
255            })?;
256
257        match identity_bytes {
258            Some(bytes) if !bytes.is_empty() => {
259                let public_key = PublicKey::from_djb_public_key_bytes(&bytes)?;
260                Ok(Some(IdentityKey::new(public_key)))
261            }
262            _ => Ok(None),
263        }
264    }
265}
266
267#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
268#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
269impl PreKeyStore for Device {
270    async fn load_prekey(
271        &self,
272        prekey_id: u32,
273    ) -> Result<Option<PreKeyRecordStructure>, StoreError> {
274        use prost::Message;
275        use wacore::libsignal::protocol::KeyPair;
276        use wacore::libsignal::store::record_helpers::new_pre_key_record;
277
278        match self.backend.load_prekey(prekey_id).await {
279            Ok(Some(bytes)) => {
280                // Try new format first (protobuf-encoded PreKeyRecordStructure)
281                if let Ok(record) = PreKeyRecordStructure::decode(bytes.as_slice()) {
282                    return Ok(Some(record));
283                }
284
285                // Fallback: old format stored just the private key bytes (32 bytes)
286                // Reconstruct the full record by deriving the public key
287                if let Ok(private_key) = PrivateKey::deserialize(&bytes)
288                    && let Ok(public_key) = private_key.public_key()
289                {
290                    let key_pair = KeyPair::new(public_key, private_key);
291                    let record = new_pre_key_record(prekey_id, &key_pair);
292                    return Ok(Some(record));
293                }
294
295                // Could not decode in either format
296                Ok(None)
297            }
298            Ok(None) => Ok(None),
299            Err(e) => Err(Box::new(e) as StoreError),
300        }
301    }
302
303    async fn store_prekey(
304        &self,
305        prekey_id: u32,
306        record: PreKeyRecordStructure,
307        uploaded: bool,
308    ) -> Result<(), StoreError> {
309        use prost::Message;
310        let bytes = record.encode_to_vec();
311        self.backend
312            .store_prekey(prekey_id, &bytes, uploaded)
313            .await
314            .map_err(|e| Box::new(e) as StoreError)
315    }
316
317    async fn contains_prekey(&self, prekey_id: u32) -> Result<bool, StoreError> {
318        match self.backend.load_prekey(prekey_id).await {
319            Ok(opt) => Ok(opt.is_some()),
320            Err(e) => Err(Box::new(e) as StoreError),
321        }
322    }
323
324    async fn remove_prekey(&self, prekey_id: u32) -> Result<(), StoreError> {
325        self.backend
326            .remove_prekey(prekey_id)
327            .await
328            .map_err(|e| Box::new(e) as StoreError)
329    }
330}
331
332#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
333#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
334impl SignedPreKeyStore for Device {
335    async fn load_signed_prekey(
336        &self,
337        signed_prekey_id: u32,
338    ) -> Result<Option<SignedPreKeyRecordStructure>, StoreError> {
339        if signed_prekey_id == self.signed_pre_key_id {
340            let record = wacore::libsignal::store::record_helpers::new_signed_pre_key_record(
341                self.signed_pre_key_id,
342                &self.signed_pre_key,
343                self.signed_pre_key_signature,
344                wacore::time::now_utc(),
345            );
346            return Ok(Some(record));
347        }
348        Ok(None)
349    }
350
351    async fn load_signed_prekeys(&self) -> Result<Vec<SignedPreKeyRecordStructure>, StoreError> {
352        log::warn!(
353            "Device: load_signed_prekeys() - returning empty list. Only the device's own signed pre-key should be accessed via load_signed_prekey()."
354        );
355        Ok(Vec::new())
356    }
357
358    async fn store_signed_prekey(
359        &self,
360        signed_prekey_id: u32,
361        _record: SignedPreKeyRecordStructure,
362    ) -> Result<(), StoreError> {
363        log::warn!(
364            "Device: store_signed_prekey({}) - no-op. Signed pre-keys should only be set once during device creation/pairing and managed via PersistenceManager.",
365            signed_prekey_id
366        );
367        Ok(())
368    }
369
370    async fn contains_signed_prekey(&self, signed_prekey_id: u32) -> Result<bool, StoreError> {
371        Ok(signed_prekey_id == self.signed_pre_key_id)
372    }
373
374    async fn remove_signed_prekey(&self, signed_prekey_id: u32) -> Result<(), StoreError> {
375        log::warn!(
376            "Device: remove_signed_prekey({}) - no-op. Signed pre-keys are managed via PersistenceManager and should not be removed individually.",
377            signed_prekey_id
378        );
379        Ok(())
380    }
381}
382
383#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
384#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
385impl SessionStore for Device {
386    async fn load_session(&self, address: &ProtocolAddress) -> Result<SessionRecord, StoreError> {
387        let address_str = address.to_string();
388        match self.backend.get_session(&address_str).await {
389            Ok(Some(session_data)) => {
390                SessionRecord::deserialize(&session_data).map_err(|e| Box::new(e) as StoreError)
391            }
392            Ok(None) => Ok(SessionRecord::new_fresh()),
393            Err(e) => Err(Box::new(e) as StoreError),
394        }
395    }
396
397    async fn get_sub_device_sessions(&self, name: &str) -> Result<Vec<u32>, StoreError> {
398        let _ = name;
399        Ok(Vec::new())
400    }
401
402    async fn store_session(
403        &self,
404        address: &ProtocolAddress,
405        record: &SessionRecord,
406    ) -> Result<(), StoreError> {
407        let address_str = address.to_string();
408        let session_data = record.serialize().map_err(|e| Box::new(e) as StoreError)?;
409
410        self.backend
411            .put_session(&address_str, &session_data)
412            .await
413            .map_err(|e| Box::new(e) as StoreError)
414    }
415
416    async fn contains_session(&self, address: &ProtocolAddress) -> Result<bool, StoreError> {
417        let address_str = address.to_string();
418        self.backend
419            .has_session(&address_str)
420            .await
421            .map_err(|e| Box::new(e) as StoreError)
422    }
423
424    async fn delete_session(&self, address: &ProtocolAddress) -> Result<(), StoreError> {
425        let address_str = address.to_string();
426        self.backend
427            .delete_session(&address_str)
428            .await
429            .map_err(|e| Box::new(e) as StoreError)
430    }
431
432    async fn delete_all_sessions(&self, name: &str) -> Result<(), StoreError> {
433        let _ = name;
434        Ok(())
435    }
436}
437
438use async_lock::RwLock;
439
440pub struct DeviceRwLockWrapper(pub Arc<RwLock<Device>>);
441
442impl DeviceRwLockWrapper {
443    pub fn new(device: Arc<RwLock<Device>>) -> Self {
444        Self(device)
445    }
446}
447
448impl Clone for DeviceRwLockWrapper {
449    fn clone(&self) -> Self {
450        Self(self.0.clone())
451    }
452}
453
454impl_store_wrapper!(DeviceRwLockWrapper, read, write);
455
456pub struct DeviceStore(pub Arc<Mutex<Device>>);
457
458impl DeviceStore {
459    pub fn new(device: Arc<Mutex<Device>>) -> Self {
460        Self(device)
461    }
462}
463
464impl Clone for DeviceStore {
465    fn clone(&self) -> Self {
466        Self(self.0.clone())
467    }
468}
469
470impl_store_wrapper!(DeviceStore, lock, lock);
471
472#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
473#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
474impl SenderKeyStore for Device {
475    async fn store_sender_key(
476        &mut self,
477        sender_key_name: &SenderKeyName,
478        record: &SenderKeyRecord,
479    ) -> SignalResult<()> {
480        let unique_key = format!(
481            "{}:{}",
482            sender_key_name.group_id(),
483            sender_key_name.sender_id()
484        );
485        let serialized_record = record.serialize()?;
486        self.backend
487            .put_sender_key(&unique_key, &serialized_record)
488            .await
489            .map_err(|e| SignalProtocolError::InvalidState("store_sender_key", e.to_string()))
490    }
491
492    async fn load_sender_key(
493        &mut self,
494        sender_key_name: &SenderKeyName,
495    ) -> SignalResult<Option<SenderKeyRecord>> {
496        let unique_key = format!(
497            "{}:{}",
498            sender_key_name.group_id(),
499            sender_key_name.sender_id()
500        );
501        match self
502            .backend
503            .get_sender_key(&unique_key)
504            .await
505            .map_err(|e| SignalProtocolError::InvalidState("load_sender_key", e.to_string()))?
506        {
507            Some(data) => {
508                let record = SenderKeyRecord::deserialize(&data)?;
509                if record.serialize()?.is_empty() {
510                    Ok(None)
511                } else {
512                    Ok(Some(record))
513                }
514            }
515            None => Ok(None),
516        }
517    }
518}