1use crate::store::Device;
2use async_trait::async_trait;
3use std::sync::Arc;
4use tokio::sync::Mutex;
5use wacore::libsignal::protocol::error::Result as SignalResult;
6use wacore::libsignal::protocol::{
7 Direction, IdentityChange, IdentityKey, IdentityKeyPair, IdentityKeyStore, PrivateKey,
8 ProtocolAddress, PublicKey, SenderKeyRecord, SenderKeyStore, SessionRecord,
9 SignalProtocolError,
10};
11use wacore::libsignal::store::sender_key_name::SenderKeyName;
12use wacore::libsignal::store::*;
13use waproto::whatsapp::{PreKeyRecordStructure, SignedPreKeyRecordStructure};
14
/// Boxed, thread-safe error type returned by the pre-key, signed pre-key and session
/// store methods implemented in this module.
type StoreError = Box<dyn std::error::Error + Sync + Send>;
16
/// Implements the Signal store traits for a lock-wrapped device handle.
///
/// * `$wrapper_ty` – the wrapper type; its field `.0` must expose the two lock methods.
/// * `$read_lock` – name of the async lock method used by read-only operations.
/// * `$write_lock` – name of the async lock method used by mutating operations.
///
/// Every generated method awaits the chosen lock and then forwards the call, with the
/// same arguments, to the locked value's implementation of the same trait method. The
/// lock guard is a temporary inside the forwarding expression, so it remains held until
/// the inner `.await` completes.
///
/// Generated impls: `IdentityKeyStore`, `PreKeyStore`, `SignedPreKeyStore` and
/// `SessionStore`. `SenderKeyStore` is not generated by this macro.
///
/// In this file it is invoked with `read`/`write` for a `RwLock`-based wrapper and with
/// `lock`/`lock` for a `Mutex`-based wrapper, where both kinds of access are exclusive.
macro_rules! impl_store_wrapper {
    ($wrapper_ty:ty, $read_lock:ident, $write_lock:ident) => {
        // Identity keys: reads use `$read_lock`; `save_identity` needs `&mut` access to
        // the inner value and therefore uses `$write_lock`.
        #[async_trait]
        impl IdentityKeyStore for $wrapper_ty {
            async fn get_identity_key_pair(&self) -> SignalResult<IdentityKeyPair> {
                self.0.$read_lock().await.get_identity_key_pair().await
            }

            async fn get_local_registration_id(&self) -> SignalResult<u32> {
                self.0.$read_lock().await.get_local_registration_id().await
            }

            async fn save_identity(
                &mut self,
                address: &ProtocolAddress,
                identity_key: &IdentityKey,
            ) -> SignalResult<IdentityChange> {
                self.0
                    .$write_lock()
                    .await
                    .save_identity(address, identity_key)
                    .await
            }

            async fn is_trusted_identity(
                &self,
                address: &ProtocolAddress,
                identity_key: &IdentityKey,
                direction: Direction,
            ) -> SignalResult<bool> {
                self.0
                    .$read_lock()
                    .await
                    .is_trusted_identity(address, identity_key, direction)
                    .await
            }

            async fn get_identity(
                &self,
                address: &ProtocolAddress,
            ) -> SignalResult<Option<IdentityKey>> {
                self.0.$read_lock().await.get_identity(address).await
            }
        }

        // One-time pre-keys: load/contains take the read lock; store/remove take the
        // write lock.
        #[async_trait]
        impl PreKeyStore for $wrapper_ty {
            async fn load_prekey(
                &self,
                prekey_id: u32,
            ) -> Result<Option<PreKeyRecordStructure>, StoreError> {
                self.0.$read_lock().await.load_prekey(prekey_id).await
            }

            async fn store_prekey(
                &self,
                prekey_id: u32,
                record: PreKeyRecordStructure,
                uploaded: bool,
            ) -> Result<(), StoreError> {
                self.0
                    .$write_lock()
                    .await
                    .store_prekey(prekey_id, record, uploaded)
                    .await
            }

            async fn contains_prekey(&self, prekey_id: u32) -> Result<bool, StoreError> {
                self.0.$read_lock().await.contains_prekey(prekey_id).await
            }

            async fn remove_prekey(&self, prekey_id: u32) -> Result<(), StoreError> {
                self.0.$write_lock().await.remove_prekey(prekey_id).await
            }
        }

        // Signed pre-keys: same read/write split as one-time pre-keys.
        #[async_trait]
        impl SignedPreKeyStore for $wrapper_ty {
            async fn load_signed_prekey(
                &self,
                signed_prekey_id: u32,
            ) -> Result<Option<SignedPreKeyRecordStructure>, StoreError> {
                self.0
                    .$read_lock()
                    .await
                    .load_signed_prekey(signed_prekey_id)
                    .await
            }

            async fn load_signed_prekeys(
                &self,
            ) -> Result<Vec<SignedPreKeyRecordStructure>, StoreError> {
                self.0.$read_lock().await.load_signed_prekeys().await
            }

            async fn store_signed_prekey(
                &self,
                signed_prekey_id: u32,
                record: SignedPreKeyRecordStructure,
            ) -> Result<(), StoreError> {
                self.0
                    .$write_lock()
                    .await
                    .store_signed_prekey(signed_prekey_id, record)
                    .await
            }

            async fn contains_signed_prekey(
                &self,
                signed_prekey_id: u32,
            ) -> Result<bool, StoreError> {
                self.0
                    .$read_lock()
                    .await
                    .contains_signed_prekey(signed_prekey_id)
                    .await
            }

            async fn remove_signed_prekey(&self, signed_prekey_id: u32) -> Result<(), StoreError> {
                self.0
                    .$write_lock()
                    .await
                    .remove_signed_prekey(signed_prekey_id)
                    .await
            }
        }

        // Sessions: lookups take the read lock; store/delete operations take the write
        // lock.
        #[async_trait]
        impl SessionStore for $wrapper_ty {
            async fn load_session(
                &self,
                address: &ProtocolAddress,
            ) -> Result<SessionRecord, StoreError> {
                self.0.$read_lock().await.load_session(address).await
            }

            async fn get_sub_device_sessions(&self, name: &str) -> Result<Vec<u32>, StoreError> {
                self.0
                    .$read_lock()
                    .await
                    .get_sub_device_sessions(name)
                    .await
            }

            async fn store_session(
                &self,
                address: &ProtocolAddress,
                record: &SessionRecord,
            ) -> Result<(), StoreError> {
                self.0
                    .$write_lock()
                    .await
                    .store_session(address, record)
                    .await
            }

            async fn contains_session(
                &self,
                address: &ProtocolAddress,
            ) -> Result<bool, StoreError> {
                self.0.$read_lock().await.contains_session(address).await
            }

            async fn delete_session(&self, address: &ProtocolAddress) -> Result<(), StoreError> {
                self.0.$write_lock().await.delete_session(address).await
            }

            async fn delete_all_sessions(&self, name: &str) -> Result<(), StoreError> {
                self.0.$write_lock().await.delete_all_sessions(name).await
            }
        }
    };
}
190
191#[async_trait]
192impl IdentityKeyStore for Device {
193 async fn get_identity_key_pair(&self) -> SignalResult<IdentityKeyPair> {
194 let private_key_bytes = self.identity_key.private_key;
195 let private_key = PrivateKey::deserialize(&private_key_bytes.serialize())?;
196 let ikp = IdentityKeyPair::try_from(private_key)?;
197 Ok(ikp)
198 }
199
200 async fn get_local_registration_id(&self) -> SignalResult<u32> {
201 Ok(self.registration_id)
202 }
203
204 async fn save_identity(
205 &mut self,
206 address: &ProtocolAddress,
207 identity_key: &IdentityKey,
208 ) -> SignalResult<IdentityChange> {
209 let address_str = address.to_string();
210 let key_bytes = identity_key.public_key().public_key_bytes();
211 let existing_identity_opt = self.get_identity(address).await?;
212
213 self.backend
214 .put_identity(
215 &address_str,
216 key_bytes.try_into().map_err(|_| {
217 SignalProtocolError::InvalidArgument("Invalid key length".into())
218 })?,
219 )
220 .await
221 .map_err(|e| {
222 SignalProtocolError::InvalidState("backend put_identity", e.to_string())
223 })?;
224
225 match existing_identity_opt {
226 None => Ok(IdentityChange::NewOrUnchanged),
227 Some(existing) if &existing == identity_key => Ok(IdentityChange::NewOrUnchanged),
228 Some(_) => Ok(IdentityChange::ReplacedExisting),
229 }
230 }
231
232 async fn is_trusted_identity(
233 &self,
234 address: &ProtocolAddress,
235 identity_key: &IdentityKey,
236 direction: Direction,
237 ) -> SignalResult<bool> {
238 let key_bytes = identity_key.public_key().public_key_bytes();
239 let key_array: [u8; 32] = key_bytes
240 .try_into()
241 .map_err(|_| SignalProtocolError::InvalidArgument("Invalid key length".into()))?;
242
243 self.backend
244 .is_trusted_identity(&address.to_string(), &key_array, direction)
245 .await
246 .map_err(|e| {
247 SignalProtocolError::InvalidState("backend is_trusted_identity", e.to_string())
248 })
249 }
250
251 async fn get_identity(&self, address: &ProtocolAddress) -> SignalResult<Option<IdentityKey>> {
252 let identity_bytes = self
253 .backend
254 .load_identity(&address.to_string())
255 .await
256 .map_err(|e| {
257 SignalProtocolError::InvalidState("backend get_identity", e.to_string())
258 })?;
259
260 match identity_bytes {
261 Some(bytes) if !bytes.is_empty() => {
262 let public_key = PublicKey::from_djb_public_key_bytes(&bytes)?;
263 Ok(Some(IdentityKey::new(public_key)))
264 }
265 _ => Ok(None),
266 }
267 }
268}
269
270#[async_trait]
271impl PreKeyStore for Device {
272 async fn load_prekey(
273 &self,
274 prekey_id: u32,
275 ) -> Result<Option<PreKeyRecordStructure>, StoreError> {
276 self.backend.load_prekey(prekey_id).await
277 }
278
279 async fn store_prekey(
280 &self,
281 prekey_id: u32,
282 record: PreKeyRecordStructure,
283 uploaded: bool,
284 ) -> Result<(), StoreError> {
285 self.backend.store_prekey(prekey_id, record, uploaded).await
286 }
287
288 async fn contains_prekey(&self, prekey_id: u32) -> Result<bool, StoreError> {
289 self.backend.contains_prekey(prekey_id).await
290 }
291
292 async fn remove_prekey(&self, prekey_id: u32) -> Result<(), StoreError> {
293 self.backend.remove_prekey(prekey_id).await
294 }
295}
296
297#[async_trait]
298impl SignedPreKeyStore for Device {
299 async fn load_signed_prekey(
300 &self,
301 signed_prekey_id: u32,
302 ) -> Result<Option<SignedPreKeyRecordStructure>, StoreError> {
303 if signed_prekey_id == self.signed_pre_key_id {
304 use wacore::libsignal::protocol::{KeyPair, PrivateKey, PublicKey};
305
306 let public_key = PublicKey::from_djb_public_key_bytes(
307 self.signed_pre_key.public_key.public_key_bytes(),
308 )
309 .map_err(|e| Box::new(e) as StoreError)?;
310 let private_key = PrivateKey::deserialize(&self.signed_pre_key.private_key.serialize())
311 .map_err(|e| Box::new(e) as StoreError)?;
312 let key_pair = KeyPair::new(public_key, private_key);
313
314 let record = wacore::libsignal::store::record_helpers::new_signed_pre_key_record(
315 self.signed_pre_key_id,
316 &key_pair,
317 self.signed_pre_key_signature,
318 chrono::Utc::now(),
319 );
320 return Ok(Some(record));
321 }
322 Ok(None)
323 }
324
325 async fn load_signed_prekeys(&self) -> Result<Vec<SignedPreKeyRecordStructure>, StoreError> {
326 log::warn!(
327 "Device: load_signed_prekeys() - returning empty list. Only the device's own signed pre-key should be accessed via load_signed_prekey()."
328 );
329 Ok(Vec::new())
330 }
331
332 async fn store_signed_prekey(
333 &self,
334 signed_prekey_id: u32,
335 _record: SignedPreKeyRecordStructure,
336 ) -> Result<(), StoreError> {
337 log::warn!(
338 "Device: store_signed_prekey({}) - no-op. Signed pre-keys should only be set once during device creation/pairing and managed via PersistenceManager.",
339 signed_prekey_id
340 );
341 Ok(())
342 }
343
344 async fn contains_signed_prekey(&self, signed_prekey_id: u32) -> Result<bool, StoreError> {
345 Ok(signed_prekey_id == self.signed_pre_key_id)
346 }
347
348 async fn remove_signed_prekey(&self, signed_prekey_id: u32) -> Result<(), StoreError> {
349 log::warn!(
350 "Device: remove_signed_prekey({}) - no-op. Signed pre-keys are managed via PersistenceManager and should not be removed individually.",
351 signed_prekey_id
352 );
353 Ok(())
354 }
355}
356
357#[async_trait]
358impl SessionStore for Device {
359 async fn load_session(&self, address: &ProtocolAddress) -> Result<SessionRecord, StoreError> {
360 let address_str = address.to_string();
361 match self.backend.get_session(&address_str).await {
362 Ok(Some(session_data)) => {
363 SessionRecord::deserialize(&session_data).map_err(|e| Box::new(e) as StoreError)
364 }
365 Ok(None) => Ok(SessionRecord::new_fresh()),
366 Err(e) => Err(Box::new(e) as StoreError),
367 }
368 }
369
370 async fn get_sub_device_sessions(&self, name: &str) -> Result<Vec<u32>, StoreError> {
371 let _ = name;
372 Ok(Vec::new())
373 }
374
375 async fn store_session(
376 &self,
377 address: &ProtocolAddress,
378 record: &SessionRecord,
379 ) -> Result<(), StoreError> {
380 let address_str = address.to_string();
381 let session_data = record.serialize().map_err(|e| Box::new(e) as StoreError)?;
382
383 self.backend
384 .put_session(&address_str, &session_data)
385 .await
386 .map_err(|e| Box::new(e) as StoreError)
387 }
388
389 async fn contains_session(&self, address: &ProtocolAddress) -> Result<bool, StoreError> {
390 let address_str = address.to_string();
391 self.backend
392 .has_session(&address_str)
393 .await
394 .map_err(|e| Box::new(e) as StoreError)
395 }
396
397 async fn delete_session(&self, address: &ProtocolAddress) -> Result<(), StoreError> {
398 let address_str = address.to_string();
399 self.backend
400 .delete_session(&address_str)
401 .await
402 .map_err(|e| Box::new(e) as StoreError)
403 }
404
405 async fn delete_all_sessions(&self, name: &str) -> Result<(), StoreError> {
406 let _ = name;
407 Ok(())
408 }
409}
410
411pub struct DeviceArcWrapper(pub Arc<Device>);
412
413impl DeviceArcWrapper {
414 pub fn new(device: Arc<Device>) -> Self {
415 Self(device)
416 }
417}
418
419impl Clone for DeviceArcWrapper {
420 fn clone(&self) -> Self {
421 Self(self.0.clone())
422 }
423}
424
425use tokio::sync::RwLock;
426
427pub struct DeviceRwLockWrapper(pub Arc<RwLock<Device>>);
428
429impl DeviceRwLockWrapper {
430 pub fn new(device: Arc<RwLock<Device>>) -> Self {
431 Self(device)
432 }
433}
434
435impl Clone for DeviceRwLockWrapper {
436 fn clone(&self) -> Self {
437 Self(self.0.clone())
438 }
439}
440
441impl_store_wrapper!(DeviceRwLockWrapper, read, write);
442
443pub struct DeviceStore(pub Arc<Mutex<Device>>);
444
445impl DeviceStore {
446 pub fn new(device: Arc<Mutex<Device>>) -> Self {
447 Self(device)
448 }
449}
450
451impl Clone for DeviceStore {
452 fn clone(&self) -> Self {
453 Self(self.0.clone())
454 }
455}
456
457impl_store_wrapper!(DeviceStore, lock, lock);
458
459#[async_trait]
460impl SenderKeyStore for Device {
461 async fn store_sender_key(
462 &mut self,
463 sender_key_name: &SenderKeyName,
464 record: &SenderKeyRecord,
465 ) -> SignalResult<()> {
466 let unique_key = format!(
467 "{}:{}",
468 sender_key_name.group_id(),
469 sender_key_name.sender_id()
470 );
471 let serialized_record = record.serialize()?;
472 self.backend
473 .put_sender_key(&unique_key, &serialized_record)
474 .await
475 .map_err(|e| SignalProtocolError::InvalidState("store_sender_key", e.to_string()))
476 }
477
478 async fn load_sender_key(
479 &mut self,
480 sender_key_name: &SenderKeyName,
481 ) -> SignalResult<Option<SenderKeyRecord>> {
482 let unique_key = format!(
483 "{}:{}",
484 sender_key_name.group_id(),
485 sender_key_name.sender_id()
486 );
487 match self
488 .backend
489 .get_sender_key(&unique_key)
490 .await
491 .map_err(|e| SignalProtocolError::InvalidState("load_sender_key", e.to_string()))?
492 {
493 Some(data) => {
494 let record = SenderKeyRecord::deserialize(&data)?;
495 if record.serialize()?.is_empty() {
496 Ok(None)
497 } else {
498 Ok(Some(record))
499 }
500 }
501 None => Ok(None),
502 }
503 }
504}