whatsapp_rust/store/signal.rs

1use crate::store::Device;
2use async_trait::async_trait;
3use std::sync::Arc;
4use tokio::sync::Mutex;
5use wacore::libsignal::protocol::error::Result as SignalResult;
6use wacore::libsignal::protocol::{
7    Direction, IdentityChange, IdentityKey, IdentityKeyPair, IdentityKeyStore, PrivateKey,
8    ProtocolAddress, PublicKey, SenderKeyRecord, SenderKeyStore, SessionRecord,
9    SignalProtocolError,
10};
11use wacore::libsignal::store::sender_key_name::SenderKeyName;
12use wacore::libsignal::store::*;
13use waproto::whatsapp::{PreKeyRecordStructure, SignedPreKeyRecordStructure};
14
/// Boxed, thread-safe error returned by the pre-key, signed pre-key and
/// session store implementations in this module.
type StoreError = Box<dyn std::error::Error + Sync + Send>;
16
/// Implements `IdentityKeyStore`, `PreKeyStore`, `SignedPreKeyStore` and
/// `SessionStore` for a lock-wrapping newtype by forwarding every call to the
/// [`Device`] behind the lock.
///
/// Arguments:
/// * `$target` — the newtype receiving the impls; its field `.0` must expose
///   the two async lock methods named below (in this file it is either an
///   `Arc<RwLock<Device>>` or an `Arc<Mutex<Device>>`).
/// * `$shared` — lock method used by the non-mutating trait methods
///   (`read` for an `RwLock`, `lock` for a `Mutex`).
/// * `$exclusive` — lock method used by the methods that write to the store
///   (`write` for an `RwLock`, `lock` for a `Mutex`).
///
/// Each method acquires the guard first and keeps it alive until the
/// forwarded `Device` future has completed.
macro_rules! impl_store_wrapper {
    ($target:ty, $shared:ident, $exclusive:ident) => {
        #[async_trait]
        impl IdentityKeyStore for $target {
            async fn get_identity_key_pair(&self) -> SignalResult<IdentityKeyPair> {
                let device = self.0.$shared().await;
                device.get_identity_key_pair().await
            }

            async fn get_local_registration_id(&self) -> SignalResult<u32> {
                let device = self.0.$shared().await;
                device.get_local_registration_id().await
            }

            async fn save_identity(
                &mut self,
                address: &ProtocolAddress,
                identity_key: &IdentityKey,
            ) -> SignalResult<IdentityChange> {
                // `Device::save_identity` takes `&mut self`, so the guard must be mutable.
                let mut device = self.0.$exclusive().await;
                device.save_identity(address, identity_key).await
            }

            async fn is_trusted_identity(
                &self,
                address: &ProtocolAddress,
                identity_key: &IdentityKey,
                direction: Direction,
            ) -> SignalResult<bool> {
                let device = self.0.$shared().await;
                device
                    .is_trusted_identity(address, identity_key, direction)
                    .await
            }

            async fn get_identity(
                &self,
                address: &ProtocolAddress,
            ) -> SignalResult<Option<IdentityKey>> {
                let device = self.0.$shared().await;
                device.get_identity(address).await
            }
        }

        #[async_trait]
        impl PreKeyStore for $target {
            async fn load_prekey(
                &self,
                prekey_id: u32,
            ) -> Result<Option<PreKeyRecordStructure>, StoreError> {
                let device = self.0.$shared().await;
                device.load_prekey(prekey_id).await
            }

            async fn store_prekey(
                &self,
                prekey_id: u32,
                record: PreKeyRecordStructure,
                uploaded: bool,
            ) -> Result<(), StoreError> {
                let device = self.0.$exclusive().await;
                device.store_prekey(prekey_id, record, uploaded).await
            }

            async fn contains_prekey(&self, prekey_id: u32) -> Result<bool, StoreError> {
                let device = self.0.$shared().await;
                device.contains_prekey(prekey_id).await
            }

            async fn remove_prekey(&self, prekey_id: u32) -> Result<(), StoreError> {
                let device = self.0.$exclusive().await;
                device.remove_prekey(prekey_id).await
            }
        }

        #[async_trait]
        impl SignedPreKeyStore for $target {
            async fn load_signed_prekey(
                &self,
                signed_prekey_id: u32,
            ) -> Result<Option<SignedPreKeyRecordStructure>, StoreError> {
                let device = self.0.$shared().await;
                device.load_signed_prekey(signed_prekey_id).await
            }

            async fn load_signed_prekeys(
                &self,
            ) -> Result<Vec<SignedPreKeyRecordStructure>, StoreError> {
                let device = self.0.$shared().await;
                device.load_signed_prekeys().await
            }

            async fn store_signed_prekey(
                &self,
                signed_prekey_id: u32,
                record: SignedPreKeyRecordStructure,
            ) -> Result<(), StoreError> {
                let device = self.0.$exclusive().await;
                device.store_signed_prekey(signed_prekey_id, record).await
            }

            async fn contains_signed_prekey(
                &self,
                signed_prekey_id: u32,
            ) -> Result<bool, StoreError> {
                let device = self.0.$shared().await;
                device.contains_signed_prekey(signed_prekey_id).await
            }

            async fn remove_signed_prekey(&self, signed_prekey_id: u32) -> Result<(), StoreError> {
                let device = self.0.$exclusive().await;
                device.remove_signed_prekey(signed_prekey_id).await
            }
        }

        #[async_trait]
        impl SessionStore for $target {
            async fn load_session(
                &self,
                address: &ProtocolAddress,
            ) -> Result<SessionRecord, StoreError> {
                let device = self.0.$shared().await;
                device.load_session(address).await
            }

            async fn get_sub_device_sessions(&self, name: &str) -> Result<Vec<u32>, StoreError> {
                let device = self.0.$shared().await;
                device.get_sub_device_sessions(name).await
            }

            async fn store_session(
                &self,
                address: &ProtocolAddress,
                record: &SessionRecord,
            ) -> Result<(), StoreError> {
                let device = self.0.$exclusive().await;
                device.store_session(address, record).await
            }

            async fn contains_session(
                &self,
                address: &ProtocolAddress,
            ) -> Result<bool, StoreError> {
                let device = self.0.$shared().await;
                device.contains_session(address).await
            }

            async fn delete_session(&self, address: &ProtocolAddress) -> Result<(), StoreError> {
                let device = self.0.$exclusive().await;
                device.delete_session(address).await
            }

            async fn delete_all_sessions(&self, name: &str) -> Result<(), StoreError> {
                let device = self.0.$exclusive().await;
                device.delete_all_sessions(name).await
            }
        }
    };
}
190
191#[async_trait]
192impl IdentityKeyStore for Device {
193    async fn get_identity_key_pair(&self) -> SignalResult<IdentityKeyPair> {
194        let private_key_bytes = self.identity_key.private_key;
195        let private_key = PrivateKey::deserialize(&private_key_bytes.serialize())?;
196        let ikp = IdentityKeyPair::try_from(private_key)?;
197        Ok(ikp)
198    }
199
200    async fn get_local_registration_id(&self) -> SignalResult<u32> {
201        Ok(self.registration_id)
202    }
203
204    async fn save_identity(
205        &mut self,
206        address: &ProtocolAddress,
207        identity_key: &IdentityKey,
208    ) -> SignalResult<IdentityChange> {
209        let address_str = address.to_string();
210        let key_bytes = identity_key.public_key().public_key_bytes();
211        let existing_identity_opt = self.get_identity(address).await?;
212
213        self.backend
214            .put_identity(
215                &address_str,
216                key_bytes.try_into().map_err(|_| {
217                    SignalProtocolError::InvalidArgument("Invalid key length".into())
218                })?,
219            )
220            .await
221            .map_err(|e| {
222                SignalProtocolError::InvalidState("backend put_identity", e.to_string())
223            })?;
224
225        match existing_identity_opt {
226            None => Ok(IdentityChange::NewOrUnchanged),
227            Some(existing) if &existing == identity_key => Ok(IdentityChange::NewOrUnchanged),
228            Some(_) => Ok(IdentityChange::ReplacedExisting),
229        }
230    }
231
232    async fn is_trusted_identity(
233        &self,
234        address: &ProtocolAddress,
235        identity_key: &IdentityKey,
236        direction: Direction,
237    ) -> SignalResult<bool> {
238        let key_bytes = identity_key.public_key().public_key_bytes();
239        let key_array: [u8; 32] = key_bytes
240            .try_into()
241            .map_err(|_| SignalProtocolError::InvalidArgument("Invalid key length".into()))?;
242
243        self.backend
244            .is_trusted_identity(&address.to_string(), &key_array, direction)
245            .await
246            .map_err(|e| {
247                SignalProtocolError::InvalidState("backend is_trusted_identity", e.to_string())
248            })
249    }
250
251    async fn get_identity(&self, address: &ProtocolAddress) -> SignalResult<Option<IdentityKey>> {
252        let identity_bytes = self
253            .backend
254            .load_identity(&address.to_string())
255            .await
256            .map_err(|e| {
257                SignalProtocolError::InvalidState("backend get_identity", e.to_string())
258            })?;
259
260        match identity_bytes {
261            Some(bytes) if !bytes.is_empty() => {
262                let public_key = PublicKey::from_djb_public_key_bytes(&bytes)?;
263                Ok(Some(IdentityKey::new(public_key)))
264            }
265            _ => Ok(None),
266        }
267    }
268}
269
270#[async_trait]
271impl PreKeyStore for Device {
272    async fn load_prekey(
273        &self,
274        prekey_id: u32,
275    ) -> Result<Option<PreKeyRecordStructure>, StoreError> {
276        self.backend.load_prekey(prekey_id).await
277    }
278
279    async fn store_prekey(
280        &self,
281        prekey_id: u32,
282        record: PreKeyRecordStructure,
283        uploaded: bool,
284    ) -> Result<(), StoreError> {
285        self.backend.store_prekey(prekey_id, record, uploaded).await
286    }
287
288    async fn contains_prekey(&self, prekey_id: u32) -> Result<bool, StoreError> {
289        self.backend.contains_prekey(prekey_id).await
290    }
291
292    async fn remove_prekey(&self, prekey_id: u32) -> Result<(), StoreError> {
293        self.backend.remove_prekey(prekey_id).await
294    }
295}
296
297#[async_trait]
298impl SignedPreKeyStore for Device {
299    async fn load_signed_prekey(
300        &self,
301        signed_prekey_id: u32,
302    ) -> Result<Option<SignedPreKeyRecordStructure>, StoreError> {
303        if signed_prekey_id == self.signed_pre_key_id {
304            use wacore::libsignal::protocol::{KeyPair, PrivateKey, PublicKey};
305
306            let public_key = PublicKey::from_djb_public_key_bytes(
307                self.signed_pre_key.public_key.public_key_bytes(),
308            )
309            .map_err(|e| Box::new(e) as StoreError)?;
310            let private_key = PrivateKey::deserialize(&self.signed_pre_key.private_key.serialize())
311                .map_err(|e| Box::new(e) as StoreError)?;
312            let key_pair = KeyPair::new(public_key, private_key);
313
314            let record = wacore::libsignal::store::record_helpers::new_signed_pre_key_record(
315                self.signed_pre_key_id,
316                &key_pair,
317                self.signed_pre_key_signature,
318                chrono::Utc::now(),
319            );
320            return Ok(Some(record));
321        }
322        Ok(None)
323    }
324
325    async fn load_signed_prekeys(&self) -> Result<Vec<SignedPreKeyRecordStructure>, StoreError> {
326        log::warn!(
327            "Device: load_signed_prekeys() - returning empty list. Only the device's own signed pre-key should be accessed via load_signed_prekey()."
328        );
329        Ok(Vec::new())
330    }
331
332    async fn store_signed_prekey(
333        &self,
334        signed_prekey_id: u32,
335        _record: SignedPreKeyRecordStructure,
336    ) -> Result<(), StoreError> {
337        log::warn!(
338            "Device: store_signed_prekey({}) - no-op. Signed pre-keys should only be set once during device creation/pairing and managed via PersistenceManager.",
339            signed_prekey_id
340        );
341        Ok(())
342    }
343
344    async fn contains_signed_prekey(&self, signed_prekey_id: u32) -> Result<bool, StoreError> {
345        Ok(signed_prekey_id == self.signed_pre_key_id)
346    }
347
348    async fn remove_signed_prekey(&self, signed_prekey_id: u32) -> Result<(), StoreError> {
349        log::warn!(
350            "Device: remove_signed_prekey({}) - no-op. Signed pre-keys are managed via PersistenceManager and should not be removed individually.",
351            signed_prekey_id
352        );
353        Ok(())
354    }
355}
356
357#[async_trait]
358impl SessionStore for Device {
359    async fn load_session(&self, address: &ProtocolAddress) -> Result<SessionRecord, StoreError> {
360        let address_str = address.to_string();
361        match self.backend.get_session(&address_str).await {
362            Ok(Some(session_data)) => {
363                SessionRecord::deserialize(&session_data).map_err(|e| Box::new(e) as StoreError)
364            }
365            Ok(None) => Ok(SessionRecord::new_fresh()),
366            Err(e) => Err(Box::new(e) as StoreError),
367        }
368    }
369
370    async fn get_sub_device_sessions(&self, name: &str) -> Result<Vec<u32>, StoreError> {
371        let _ = name;
372        Ok(Vec::new())
373    }
374
375    async fn store_session(
376        &self,
377        address: &ProtocolAddress,
378        record: &SessionRecord,
379    ) -> Result<(), StoreError> {
380        let address_str = address.to_string();
381        let session_data = record.serialize().map_err(|e| Box::new(e) as StoreError)?;
382
383        self.backend
384            .put_session(&address_str, &session_data)
385            .await
386            .map_err(|e| Box::new(e) as StoreError)
387    }
388
389    async fn contains_session(&self, address: &ProtocolAddress) -> Result<bool, StoreError> {
390        let address_str = address.to_string();
391        self.backend
392            .has_session(&address_str)
393            .await
394            .map_err(|e| Box::new(e) as StoreError)
395    }
396
397    async fn delete_session(&self, address: &ProtocolAddress) -> Result<(), StoreError> {
398        let address_str = address.to_string();
399        self.backend
400            .delete_session(&address_str)
401            .await
402            .map_err(|e| Box::new(e) as StoreError)
403    }
404
405    async fn delete_all_sessions(&self, name: &str) -> Result<(), StoreError> {
406        let _ = name;
407        Ok(())
408    }
409}
410
411pub struct DeviceArcWrapper(pub Arc<Device>);
412
413impl DeviceArcWrapper {
414    pub fn new(device: Arc<Device>) -> Self {
415        Self(device)
416    }
417}
418
419impl Clone for DeviceArcWrapper {
420    fn clone(&self) -> Self {
421        Self(self.0.clone())
422    }
423}
424
425use tokio::sync::RwLock;
426
427pub struct DeviceRwLockWrapper(pub Arc<RwLock<Device>>);
428
429impl DeviceRwLockWrapper {
430    pub fn new(device: Arc<RwLock<Device>>) -> Self {
431        Self(device)
432    }
433}
434
435impl Clone for DeviceRwLockWrapper {
436    fn clone(&self) -> Self {
437        Self(self.0.clone())
438    }
439}
440
441impl_store_wrapper!(DeviceRwLockWrapper, read, write);
442
443pub struct DeviceStore(pub Arc<Mutex<Device>>);
444
445impl DeviceStore {
446    pub fn new(device: Arc<Mutex<Device>>) -> Self {
447        Self(device)
448    }
449}
450
451impl Clone for DeviceStore {
452    fn clone(&self) -> Self {
453        Self(self.0.clone())
454    }
455}
456
457impl_store_wrapper!(DeviceStore, lock, lock);
458
459#[async_trait]
460impl SenderKeyStore for Device {
461    async fn store_sender_key(
462        &mut self,
463        sender_key_name: &SenderKeyName,
464        record: &SenderKeyRecord,
465    ) -> SignalResult<()> {
466        let unique_key = format!(
467            "{}:{}",
468            sender_key_name.group_id(),
469            sender_key_name.sender_id()
470        );
471        let serialized_record = record.serialize()?;
472        self.backend
473            .put_sender_key(&unique_key, &serialized_record)
474            .await
475            .map_err(|e| SignalProtocolError::InvalidState("store_sender_key", e.to_string()))
476    }
477
478    async fn load_sender_key(
479        &mut self,
480        sender_key_name: &SenderKeyName,
481    ) -> SignalResult<Option<SenderKeyRecord>> {
482        let unique_key = format!(
483            "{}:{}",
484            sender_key_name.group_id(),
485            sender_key_name.sender_id()
486        );
487        match self
488            .backend
489            .get_sender_key(&unique_key)
490            .await
491            .map_err(|e| SignalProtocolError::InvalidState("load_sender_key", e.to_string()))?
492        {
493            Some(data) => {
494                let record = SenderKeyRecord::deserialize(&data)?;
495                if record.serialize()?.is_empty() {
496                    Ok(None)
497                } else {
498                    Ok(Some(record))
499                }
500            }
501            None => Ok(None),
502        }
503    }
504}