// whatsapp_rust/store/signal.rs

1use crate::store::Device;
2use async_trait::async_trait;
3use std::sync::Arc;
4use tokio::sync::Mutex;
5use wacore::libsignal::protocol::error::Result as SignalResult;
6use wacore::libsignal::protocol::{
7    Direction, IdentityChange, IdentityKey, IdentityKeyPair, IdentityKeyStore, PrivateKey,
8    ProtocolAddress, PublicKey, SenderKeyRecord, SenderKeyStore, SessionRecord,
9    SignalProtocolError,
10};
11use wacore::libsignal::store::sender_key_name::SenderKeyName;
12use wacore::libsignal::store::*;
13use waproto::whatsapp::{PreKeyRecordStructure, SignedPreKeyRecordStructure};
14
15type StoreError = Box<dyn std::error::Error + Send + Sync>;
16
/// Implements `IdentityKeyStore`, `PreKeyStore`, `SignedPreKeyStore` and `SessionStore`
/// for a newtype wrapper around a shared, lockable [`Device`] by forwarding every call to
/// the wrapped `Device`.
///
/// * `$wrapper_ty` - the wrapper type; its `.0` field is the shared lock around the device.
/// * `$read_lock` - name of the lock method awaited before loads, lookups and `contains_*`
///   checks (this file passes `read` for the `RwLock` wrapper and `lock` for the `Mutex` one).
/// * `$write_lock` - name of the lock method awaited before saves, stores, removals and
///   deletions (this file passes `write` for the `RwLock` wrapper and `lock` for the `Mutex` one).
///
/// Each generated method binds the guard to a local and awaits the forwarded `Device` call
/// before returning, so the lock is held for the whole delegated operation.
macro_rules! impl_store_wrapper {
    ($wrapper_ty:ty, $read_lock:ident, $write_lock:ident) => {
        #[async_trait]
        impl IdentityKeyStore for $wrapper_ty {
            async fn get_identity_key_pair(&self) -> SignalResult<IdentityKeyPair> {
                let device = self.0.$read_lock().await;
                let outcome = device.get_identity_key_pair().await;
                outcome
            }

            async fn get_local_registration_id(&self) -> SignalResult<u32> {
                let device = self.0.$read_lock().await;
                let outcome = device.get_local_registration_id().await;
                outcome
            }

            async fn save_identity(
                &mut self,
                address: &ProtocolAddress,
                identity_key: &IdentityKey,
            ) -> SignalResult<IdentityChange> {
                // `Device::save_identity` takes `&mut self`, so the guard must be mutable.
                let mut device = self.0.$write_lock().await;
                let outcome = device.save_identity(address, identity_key).await;
                outcome
            }

            async fn is_trusted_identity(
                &self,
                address: &ProtocolAddress,
                identity_key: &IdentityKey,
                direction: Direction,
            ) -> SignalResult<bool> {
                let device = self.0.$read_lock().await;
                let outcome = device
                    .is_trusted_identity(address, identity_key, direction)
                    .await;
                outcome
            }

            async fn get_identity(
                &self,
                address: &ProtocolAddress,
            ) -> SignalResult<Option<IdentityKey>> {
                let device = self.0.$read_lock().await;
                let outcome = device.get_identity(address).await;
                outcome
            }
        }

        #[async_trait]
        impl PreKeyStore for $wrapper_ty {
            async fn load_prekey(
                &self,
                prekey_id: u32,
            ) -> Result<Option<PreKeyRecordStructure>, StoreError> {
                let device = self.0.$read_lock().await;
                let outcome = device.load_prekey(prekey_id).await;
                outcome
            }

            async fn store_prekey(
                &self,
                prekey_id: u32,
                record: PreKeyRecordStructure,
                uploaded: bool,
            ) -> Result<(), StoreError> {
                let device = self.0.$write_lock().await;
                let outcome = device.store_prekey(prekey_id, record, uploaded).await;
                outcome
            }

            async fn contains_prekey(&self, prekey_id: u32) -> Result<bool, StoreError> {
                let device = self.0.$read_lock().await;
                let outcome = device.contains_prekey(prekey_id).await;
                outcome
            }

            async fn remove_prekey(&self, prekey_id: u32) -> Result<(), StoreError> {
                let device = self.0.$write_lock().await;
                let outcome = device.remove_prekey(prekey_id).await;
                outcome
            }
        }

        #[async_trait]
        impl SignedPreKeyStore for $wrapper_ty {
            async fn load_signed_prekey(
                &self,
                signed_prekey_id: u32,
            ) -> Result<Option<SignedPreKeyRecordStructure>, StoreError> {
                let device = self.0.$read_lock().await;
                let outcome = device.load_signed_prekey(signed_prekey_id).await;
                outcome
            }

            async fn load_signed_prekeys(
                &self,
            ) -> Result<Vec<SignedPreKeyRecordStructure>, StoreError> {
                let device = self.0.$read_lock().await;
                let outcome = device.load_signed_prekeys().await;
                outcome
            }

            async fn store_signed_prekey(
                &self,
                signed_prekey_id: u32,
                record: SignedPreKeyRecordStructure,
            ) -> Result<(), StoreError> {
                let device = self.0.$write_lock().await;
                let outcome = device.store_signed_prekey(signed_prekey_id, record).await;
                outcome
            }

            async fn contains_signed_prekey(
                &self,
                signed_prekey_id: u32,
            ) -> Result<bool, StoreError> {
                let device = self.0.$read_lock().await;
                let outcome = device.contains_signed_prekey(signed_prekey_id).await;
                outcome
            }

            async fn remove_signed_prekey(&self, signed_prekey_id: u32) -> Result<(), StoreError> {
                let device = self.0.$write_lock().await;
                let outcome = device.remove_signed_prekey(signed_prekey_id).await;
                outcome
            }
        }

        #[async_trait]
        impl SessionStore for $wrapper_ty {
            async fn load_session(
                &self,
                address: &ProtocolAddress,
            ) -> Result<SessionRecord, StoreError> {
                let device = self.0.$read_lock().await;
                let outcome = device.load_session(address).await;
                outcome
            }

            async fn get_sub_device_sessions(&self, name: &str) -> Result<Vec<u32>, StoreError> {
                let device = self.0.$read_lock().await;
                let outcome = device.get_sub_device_sessions(name).await;
                outcome
            }

            async fn store_session(
                &self,
                address: &ProtocolAddress,
                record: &SessionRecord,
            ) -> Result<(), StoreError> {
                let device = self.0.$write_lock().await;
                let outcome = device.store_session(address, record).await;
                outcome
            }

            async fn contains_session(
                &self,
                address: &ProtocolAddress,
            ) -> Result<bool, StoreError> {
                let device = self.0.$read_lock().await;
                let outcome = device.contains_session(address).await;
                outcome
            }

            async fn delete_session(&self, address: &ProtocolAddress) -> Result<(), StoreError> {
                let device = self.0.$write_lock().await;
                let outcome = device.delete_session(address).await;
                outcome
            }

            async fn delete_all_sessions(&self, name: &str) -> Result<(), StoreError> {
                let device = self.0.$write_lock().await;
                let outcome = device.delete_all_sessions(name).await;
                outcome
            }
        }
    };
}
190
191#[async_trait]
192impl IdentityKeyStore for Device {
193    async fn get_identity_key_pair(&self) -> SignalResult<IdentityKeyPair> {
194        let private_key_bytes = self.identity_key.private_key;
195        let private_key = PrivateKey::deserialize(&private_key_bytes.serialize())?;
196        let ikp = IdentityKeyPair::try_from(private_key)?;
197        Ok(ikp)
198    }
199
200    async fn get_local_registration_id(&self) -> SignalResult<u32> {
201        Ok(self.registration_id)
202    }
203
204    async fn save_identity(
205        &mut self,
206        address: &ProtocolAddress,
207        identity_key: &IdentityKey,
208    ) -> SignalResult<IdentityChange> {
209        let address_str = address.to_string();
210        let key_bytes = identity_key.public_key().public_key_bytes();
211        let existing_identity_opt = self.get_identity(address).await?;
212
213        self.backend
214            .put_identity(
215                &address_str,
216                key_bytes.try_into().map_err(|_| {
217                    SignalProtocolError::InvalidArgument("Invalid key length".into())
218                })?,
219            )
220            .await
221            .map_err(|e| {
222                SignalProtocolError::InvalidState("backend put_identity", e.to_string())
223            })?;
224
225        match existing_identity_opt {
226            None => Ok(IdentityChange::NewOrUnchanged),
227            Some(existing) if &existing == identity_key => Ok(IdentityChange::NewOrUnchanged),
228            Some(_) => Ok(IdentityChange::ReplacedExisting),
229        }
230    }
231
232    async fn is_trusted_identity(
233        &self,
234        address: &ProtocolAddress,
235        identity_key: &IdentityKey,
236        _direction: Direction,
237    ) -> SignalResult<bool> {
238        // Trust on first use: if we don't have an identity stored, trust this one
239        // If we have one stored, it must match
240        match self.get_identity(address).await? {
241            None => Ok(true), // Trust on first use
242            Some(stored_identity) => Ok(&stored_identity == identity_key),
243        }
244    }
245
246    async fn get_identity(&self, address: &ProtocolAddress) -> SignalResult<Option<IdentityKey>> {
247        let identity_bytes = self
248            .backend
249            .load_identity(&address.to_string())
250            .await
251            .map_err(|e| {
252                SignalProtocolError::InvalidState("backend get_identity", e.to_string())
253            })?;
254
255        match identity_bytes {
256            Some(bytes) if !bytes.is_empty() => {
257                let public_key = PublicKey::from_djb_public_key_bytes(&bytes)?;
258                Ok(Some(IdentityKey::new(public_key)))
259            }
260            _ => Ok(None),
261        }
262    }
263}
264
265#[async_trait]
266impl PreKeyStore for Device {
267    async fn load_prekey(
268        &self,
269        prekey_id: u32,
270    ) -> Result<Option<PreKeyRecordStructure>, StoreError> {
271        use prost::Message;
272        use wacore::libsignal::protocol::KeyPair;
273        use wacore::libsignal::store::record_helpers::new_pre_key_record;
274
275        match self.backend.load_prekey(prekey_id).await {
276            Ok(Some(bytes)) => {
277                // Try new format first (protobuf-encoded PreKeyRecordStructure)
278                if let Ok(record) = PreKeyRecordStructure::decode(bytes.as_slice()) {
279                    return Ok(Some(record));
280                }
281
282                // Fallback: old format stored just the private key bytes (32 bytes)
283                // Reconstruct the full record by deriving the public key
284                if let Ok(private_key) = PrivateKey::deserialize(&bytes)
285                    && let Ok(public_key) = private_key.public_key()
286                {
287                    let key_pair = KeyPair::new(public_key, private_key);
288                    let record = new_pre_key_record(prekey_id, &key_pair);
289                    return Ok(Some(record));
290                }
291
292                // Could not decode in either format
293                Ok(None)
294            }
295            Ok(None) => Ok(None),
296            Err(e) => Err(Box::new(e) as StoreError),
297        }
298    }
299
300    async fn store_prekey(
301        &self,
302        prekey_id: u32,
303        record: PreKeyRecordStructure,
304        uploaded: bool,
305    ) -> Result<(), StoreError> {
306        use prost::Message;
307        let bytes = record.encode_to_vec();
308        self.backend
309            .store_prekey(prekey_id, &bytes, uploaded)
310            .await
311            .map_err(|e| Box::new(e) as StoreError)
312    }
313
314    async fn contains_prekey(&self, prekey_id: u32) -> Result<bool, StoreError> {
315        match self.backend.load_prekey(prekey_id).await {
316            Ok(opt) => Ok(opt.is_some()),
317            Err(e) => Err(Box::new(e) as StoreError),
318        }
319    }
320
321    async fn remove_prekey(&self, prekey_id: u32) -> Result<(), StoreError> {
322        self.backend
323            .remove_prekey(prekey_id)
324            .await
325            .map_err(|e| Box::new(e) as StoreError)
326    }
327}
328
329#[async_trait]
330impl SignedPreKeyStore for Device {
331    async fn load_signed_prekey(
332        &self,
333        signed_prekey_id: u32,
334    ) -> Result<Option<SignedPreKeyRecordStructure>, StoreError> {
335        if signed_prekey_id == self.signed_pre_key_id {
336            use wacore::libsignal::protocol::{KeyPair, PrivateKey, PublicKey};
337
338            let public_key = PublicKey::from_djb_public_key_bytes(
339                self.signed_pre_key.public_key.public_key_bytes(),
340            )
341            .map_err(|e| Box::new(e) as StoreError)?;
342            let private_key = PrivateKey::deserialize(&self.signed_pre_key.private_key.serialize())
343                .map_err(|e| Box::new(e) as StoreError)?;
344            let key_pair = KeyPair::new(public_key, private_key);
345
346            let record = wacore::libsignal::store::record_helpers::new_signed_pre_key_record(
347                self.signed_pre_key_id,
348                &key_pair,
349                self.signed_pre_key_signature,
350                chrono::Utc::now(),
351            );
352            return Ok(Some(record));
353        }
354        Ok(None)
355    }
356
357    async fn load_signed_prekeys(&self) -> Result<Vec<SignedPreKeyRecordStructure>, StoreError> {
358        log::warn!(
359            "Device: load_signed_prekeys() - returning empty list. Only the device's own signed pre-key should be accessed via load_signed_prekey()."
360        );
361        Ok(Vec::new())
362    }
363
364    async fn store_signed_prekey(
365        &self,
366        signed_prekey_id: u32,
367        _record: SignedPreKeyRecordStructure,
368    ) -> Result<(), StoreError> {
369        log::warn!(
370            "Device: store_signed_prekey({}) - no-op. Signed pre-keys should only be set once during device creation/pairing and managed via PersistenceManager.",
371            signed_prekey_id
372        );
373        Ok(())
374    }
375
376    async fn contains_signed_prekey(&self, signed_prekey_id: u32) -> Result<bool, StoreError> {
377        Ok(signed_prekey_id == self.signed_pre_key_id)
378    }
379
380    async fn remove_signed_prekey(&self, signed_prekey_id: u32) -> Result<(), StoreError> {
381        log::warn!(
382            "Device: remove_signed_prekey({}) - no-op. Signed pre-keys are managed via PersistenceManager and should not be removed individually.",
383            signed_prekey_id
384        );
385        Ok(())
386    }
387}
388
389#[async_trait]
390impl SessionStore for Device {
391    async fn load_session(&self, address: &ProtocolAddress) -> Result<SessionRecord, StoreError> {
392        let address_str = address.to_string();
393        match self.backend.get_session(&address_str).await {
394            Ok(Some(session_data)) => {
395                SessionRecord::deserialize(&session_data).map_err(|e| Box::new(e) as StoreError)
396            }
397            Ok(None) => Ok(SessionRecord::new_fresh()),
398            Err(e) => Err(Box::new(e) as StoreError),
399        }
400    }
401
402    async fn get_sub_device_sessions(&self, name: &str) -> Result<Vec<u32>, StoreError> {
403        let _ = name;
404        Ok(Vec::new())
405    }
406
407    async fn store_session(
408        &self,
409        address: &ProtocolAddress,
410        record: &SessionRecord,
411    ) -> Result<(), StoreError> {
412        let address_str = address.to_string();
413        let session_data = record.serialize().map_err(|e| Box::new(e) as StoreError)?;
414
415        self.backend
416            .put_session(&address_str, &session_data)
417            .await
418            .map_err(|e| Box::new(e) as StoreError)
419    }
420
421    async fn contains_session(&self, address: &ProtocolAddress) -> Result<bool, StoreError> {
422        let address_str = address.to_string();
423        self.backend
424            .has_session(&address_str)
425            .await
426            .map_err(|e| Box::new(e) as StoreError)
427    }
428
429    async fn delete_session(&self, address: &ProtocolAddress) -> Result<(), StoreError> {
430        let address_str = address.to_string();
431        self.backend
432            .delete_session(&address_str)
433            .await
434            .map_err(|e| Box::new(e) as StoreError)
435    }
436
437    async fn delete_all_sessions(&self, name: &str) -> Result<(), StoreError> {
438        let _ = name;
439        Ok(())
440    }
441}
442
443use tokio::sync::RwLock;
444
445pub struct DeviceRwLockWrapper(pub Arc<RwLock<Device>>);
446
447impl DeviceRwLockWrapper {
448    pub fn new(device: Arc<RwLock<Device>>) -> Self {
449        Self(device)
450    }
451}
452
453impl Clone for DeviceRwLockWrapper {
454    fn clone(&self) -> Self {
455        Self(self.0.clone())
456    }
457}
458
459impl_store_wrapper!(DeviceRwLockWrapper, read, write);
460
461pub struct DeviceStore(pub Arc<Mutex<Device>>);
462
463impl DeviceStore {
464    pub fn new(device: Arc<Mutex<Device>>) -> Self {
465        Self(device)
466    }
467}
468
469impl Clone for DeviceStore {
470    fn clone(&self) -> Self {
471        Self(self.0.clone())
472    }
473}
474
475impl_store_wrapper!(DeviceStore, lock, lock);
476
477#[async_trait]
478impl SenderKeyStore for Device {
479    async fn store_sender_key(
480        &mut self,
481        sender_key_name: &SenderKeyName,
482        record: &SenderKeyRecord,
483    ) -> SignalResult<()> {
484        let unique_key = format!(
485            "{}:{}",
486            sender_key_name.group_id(),
487            sender_key_name.sender_id()
488        );
489        let serialized_record = record.serialize()?;
490        self.backend
491            .put_sender_key(&unique_key, &serialized_record)
492            .await
493            .map_err(|e| SignalProtocolError::InvalidState("store_sender_key", e.to_string()))
494    }
495
496    async fn load_sender_key(
497        &mut self,
498        sender_key_name: &SenderKeyName,
499    ) -> SignalResult<Option<SenderKeyRecord>> {
500        let unique_key = format!(
501            "{}:{}",
502            sender_key_name.group_id(),
503            sender_key_name.sender_id()
504        );
505        match self
506            .backend
507            .get_sender_key(&unique_key)
508            .await
509            .map_err(|e| SignalProtocolError::InvalidState("load_sender_key", e.to_string()))?
510        {
511            Some(data) => {
512                let record = SenderKeyRecord::deserialize(&data)?;
513                if record.serialize()?.is_empty() {
514                    Ok(None)
515                } else {
516                    Ok(Some(record))
517                }
518            }
519            None => Ok(None),
520        }
521    }
522}