1use crate::store::Device;
2use async_trait::async_trait;
3use std::sync::Arc;
4use tokio::sync::Mutex;
5use wacore::libsignal::protocol::error::Result as SignalResult;
6use wacore::libsignal::protocol::{
7 Direction, IdentityChange, IdentityKey, IdentityKeyPair, IdentityKeyStore, PrivateKey,
8 ProtocolAddress, PublicKey, SenderKeyRecord, SenderKeyStore, SessionRecord,
9 SignalProtocolError,
10};
11use wacore::libsignal::store::sender_key_name::SenderKeyName;
12use wacore::libsignal::store::*;
13use waproto::whatsapp::{PreKeyRecordStructure, SignedPreKeyRecordStructure};
14
/// Boxed, thread-safe dynamic error used as the error type of the fallible
/// pre-key, signed pre-key and session store methods in this file.
type StoreError = Box<dyn std::error::Error + Send + Sync>;
16
/// Implements `IdentityKeyStore`, `PreKeyStore`, `SignedPreKeyStore` and
/// `SessionStore` for a newtype wrapper whose field `0` is an async lock
/// around the value that really implements those traits.
///
/// * `$wrapper_ty` - the wrapper type receiving the trait implementations.
/// * `$read_lock` - name of the async lock method used by methods that only
///   look data up.
/// * `$write_lock` - name of the async lock method used by methods that store
///   or remove data.
///
/// Each generated method acquires the lock, forwards its arguments unchanged
/// to the same-named method on the guarded value, and keeps the guard alive
/// until the forwarded call has finished.
macro_rules! impl_store_wrapper {
    ($wrapper_ty:ty, $read_lock:ident, $write_lock:ident) => {
        #[async_trait]
        impl IdentityKeyStore for $wrapper_ty {
            async fn get_identity_key_pair(&self) -> SignalResult<IdentityKeyPair> {
                let guard = self.0.$read_lock().await;
                guard.get_identity_key_pair().await
            }

            async fn get_local_registration_id(&self) -> SignalResult<u32> {
                let guard = self.0.$read_lock().await;
                guard.get_local_registration_id().await
            }

            async fn save_identity(
                &mut self,
                address: &ProtocolAddress,
                identity_key: &IdentityKey,
            ) -> SignalResult<IdentityChange> {
                // The inner `save_identity` takes `&mut self`, so the guard
                // itself has to be mutable here.
                let mut guard = self.0.$write_lock().await;
                guard.save_identity(address, identity_key).await
            }

            async fn is_trusted_identity(
                &self,
                address: &ProtocolAddress,
                identity_key: &IdentityKey,
                direction: Direction,
            ) -> SignalResult<bool> {
                let guard = self.0.$read_lock().await;
                guard
                    .is_trusted_identity(address, identity_key, direction)
                    .await
            }

            async fn get_identity(
                &self,
                address: &ProtocolAddress,
            ) -> SignalResult<Option<IdentityKey>> {
                let guard = self.0.$read_lock().await;
                guard.get_identity(address).await
            }
        }

        #[async_trait]
        impl PreKeyStore for $wrapper_ty {
            async fn load_prekey(
                &self,
                prekey_id: u32,
            ) -> Result<Option<PreKeyRecordStructure>, StoreError> {
                let guard = self.0.$read_lock().await;
                guard.load_prekey(prekey_id).await
            }

            async fn store_prekey(
                &self,
                prekey_id: u32,
                record: PreKeyRecordStructure,
                uploaded: bool,
            ) -> Result<(), StoreError> {
                let guard = self.0.$write_lock().await;
                guard.store_prekey(prekey_id, record, uploaded).await
            }

            async fn contains_prekey(&self, prekey_id: u32) -> Result<bool, StoreError> {
                let guard = self.0.$read_lock().await;
                guard.contains_prekey(prekey_id).await
            }

            async fn remove_prekey(&self, prekey_id: u32) -> Result<(), StoreError> {
                let guard = self.0.$write_lock().await;
                guard.remove_prekey(prekey_id).await
            }
        }

        #[async_trait]
        impl SignedPreKeyStore for $wrapper_ty {
            async fn load_signed_prekey(
                &self,
                signed_prekey_id: u32,
            ) -> Result<Option<SignedPreKeyRecordStructure>, StoreError> {
                let guard = self.0.$read_lock().await;
                guard.load_signed_prekey(signed_prekey_id).await
            }

            async fn load_signed_prekeys(
                &self,
            ) -> Result<Vec<SignedPreKeyRecordStructure>, StoreError> {
                let guard = self.0.$read_lock().await;
                guard.load_signed_prekeys().await
            }

            async fn store_signed_prekey(
                &self,
                signed_prekey_id: u32,
                record: SignedPreKeyRecordStructure,
            ) -> Result<(), StoreError> {
                let guard = self.0.$write_lock().await;
                guard.store_signed_prekey(signed_prekey_id, record).await
            }

            async fn contains_signed_prekey(
                &self,
                signed_prekey_id: u32,
            ) -> Result<bool, StoreError> {
                let guard = self.0.$read_lock().await;
                guard.contains_signed_prekey(signed_prekey_id).await
            }

            async fn remove_signed_prekey(&self, signed_prekey_id: u32) -> Result<(), StoreError> {
                let guard = self.0.$write_lock().await;
                guard.remove_signed_prekey(signed_prekey_id).await
            }
        }

        #[async_trait]
        impl SessionStore for $wrapper_ty {
            async fn load_session(
                &self,
                address: &ProtocolAddress,
            ) -> Result<SessionRecord, StoreError> {
                let guard = self.0.$read_lock().await;
                guard.load_session(address).await
            }

            async fn get_sub_device_sessions(&self, name: &str) -> Result<Vec<u32>, StoreError> {
                let guard = self.0.$read_lock().await;
                guard.get_sub_device_sessions(name).await
            }

            async fn store_session(
                &self,
                address: &ProtocolAddress,
                record: &SessionRecord,
            ) -> Result<(), StoreError> {
                let guard = self.0.$write_lock().await;
                guard.store_session(address, record).await
            }

            async fn contains_session(
                &self,
                address: &ProtocolAddress,
            ) -> Result<bool, StoreError> {
                let guard = self.0.$read_lock().await;
                guard.contains_session(address).await
            }

            async fn delete_session(&self, address: &ProtocolAddress) -> Result<(), StoreError> {
                let guard = self.0.$write_lock().await;
                guard.delete_session(address).await
            }

            async fn delete_all_sessions(&self, name: &str) -> Result<(), StoreError> {
                let guard = self.0.$write_lock().await;
                guard.delete_all_sessions(name).await
            }
        }
    };
}
190
191#[async_trait]
192impl IdentityKeyStore for Device {
193 async fn get_identity_key_pair(&self) -> SignalResult<IdentityKeyPair> {
194 let private_key_bytes = self.identity_key.private_key;
195 let private_key = PrivateKey::deserialize(&private_key_bytes.serialize())?;
196 let ikp = IdentityKeyPair::try_from(private_key)?;
197 Ok(ikp)
198 }
199
200 async fn get_local_registration_id(&self) -> SignalResult<u32> {
201 Ok(self.registration_id)
202 }
203
204 async fn save_identity(
205 &mut self,
206 address: &ProtocolAddress,
207 identity_key: &IdentityKey,
208 ) -> SignalResult<IdentityChange> {
209 let address_str = address.to_string();
210 let key_bytes = identity_key.public_key().public_key_bytes();
211 let existing_identity_opt = self.get_identity(address).await?;
212
213 self.backend
214 .put_identity(
215 &address_str,
216 key_bytes.try_into().map_err(|_| {
217 SignalProtocolError::InvalidArgument("Invalid key length".into())
218 })?,
219 )
220 .await
221 .map_err(|e| {
222 SignalProtocolError::InvalidState("backend put_identity", e.to_string())
223 })?;
224
225 match existing_identity_opt {
226 None => Ok(IdentityChange::NewOrUnchanged),
227 Some(existing) if &existing == identity_key => Ok(IdentityChange::NewOrUnchanged),
228 Some(_) => Ok(IdentityChange::ReplacedExisting),
229 }
230 }
231
232 async fn is_trusted_identity(
233 &self,
234 address: &ProtocolAddress,
235 identity_key: &IdentityKey,
236 _direction: Direction,
237 ) -> SignalResult<bool> {
238 match self.get_identity(address).await? {
241 None => Ok(true), Some(stored_identity) => Ok(&stored_identity == identity_key),
243 }
244 }
245
246 async fn get_identity(&self, address: &ProtocolAddress) -> SignalResult<Option<IdentityKey>> {
247 let identity_bytes = self
248 .backend
249 .load_identity(&address.to_string())
250 .await
251 .map_err(|e| {
252 SignalProtocolError::InvalidState("backend get_identity", e.to_string())
253 })?;
254
255 match identity_bytes {
256 Some(bytes) if !bytes.is_empty() => {
257 let public_key = PublicKey::from_djb_public_key_bytes(&bytes)?;
258 Ok(Some(IdentityKey::new(public_key)))
259 }
260 _ => Ok(None),
261 }
262 }
263}
264
265#[async_trait]
266impl PreKeyStore for Device {
267 async fn load_prekey(
268 &self,
269 prekey_id: u32,
270 ) -> Result<Option<PreKeyRecordStructure>, StoreError> {
271 use prost::Message;
272 use wacore::libsignal::protocol::KeyPair;
273 use wacore::libsignal::store::record_helpers::new_pre_key_record;
274
275 match self.backend.load_prekey(prekey_id).await {
276 Ok(Some(bytes)) => {
277 if let Ok(record) = PreKeyRecordStructure::decode(bytes.as_slice()) {
279 return Ok(Some(record));
280 }
281
282 if let Ok(private_key) = PrivateKey::deserialize(&bytes)
285 && let Ok(public_key) = private_key.public_key()
286 {
287 let key_pair = KeyPair::new(public_key, private_key);
288 let record = new_pre_key_record(prekey_id, &key_pair);
289 return Ok(Some(record));
290 }
291
292 Ok(None)
294 }
295 Ok(None) => Ok(None),
296 Err(e) => Err(Box::new(e) as StoreError),
297 }
298 }
299
300 async fn store_prekey(
301 &self,
302 prekey_id: u32,
303 record: PreKeyRecordStructure,
304 uploaded: bool,
305 ) -> Result<(), StoreError> {
306 use prost::Message;
307 let bytes = record.encode_to_vec();
308 self.backend
309 .store_prekey(prekey_id, &bytes, uploaded)
310 .await
311 .map_err(|e| Box::new(e) as StoreError)
312 }
313
314 async fn contains_prekey(&self, prekey_id: u32) -> Result<bool, StoreError> {
315 match self.backend.load_prekey(prekey_id).await {
316 Ok(opt) => Ok(opt.is_some()),
317 Err(e) => Err(Box::new(e) as StoreError),
318 }
319 }
320
321 async fn remove_prekey(&self, prekey_id: u32) -> Result<(), StoreError> {
322 self.backend
323 .remove_prekey(prekey_id)
324 .await
325 .map_err(|e| Box::new(e) as StoreError)
326 }
327}
328
329#[async_trait]
330impl SignedPreKeyStore for Device {
331 async fn load_signed_prekey(
332 &self,
333 signed_prekey_id: u32,
334 ) -> Result<Option<SignedPreKeyRecordStructure>, StoreError> {
335 if signed_prekey_id == self.signed_pre_key_id {
336 use wacore::libsignal::protocol::{KeyPair, PrivateKey, PublicKey};
337
338 let public_key = PublicKey::from_djb_public_key_bytes(
339 self.signed_pre_key.public_key.public_key_bytes(),
340 )
341 .map_err(|e| Box::new(e) as StoreError)?;
342 let private_key = PrivateKey::deserialize(&self.signed_pre_key.private_key.serialize())
343 .map_err(|e| Box::new(e) as StoreError)?;
344 let key_pair = KeyPair::new(public_key, private_key);
345
346 let record = wacore::libsignal::store::record_helpers::new_signed_pre_key_record(
347 self.signed_pre_key_id,
348 &key_pair,
349 self.signed_pre_key_signature,
350 chrono::Utc::now(),
351 );
352 return Ok(Some(record));
353 }
354 Ok(None)
355 }
356
357 async fn load_signed_prekeys(&self) -> Result<Vec<SignedPreKeyRecordStructure>, StoreError> {
358 log::warn!(
359 "Device: load_signed_prekeys() - returning empty list. Only the device's own signed pre-key should be accessed via load_signed_prekey()."
360 );
361 Ok(Vec::new())
362 }
363
364 async fn store_signed_prekey(
365 &self,
366 signed_prekey_id: u32,
367 _record: SignedPreKeyRecordStructure,
368 ) -> Result<(), StoreError> {
369 log::warn!(
370 "Device: store_signed_prekey({}) - no-op. Signed pre-keys should only be set once during device creation/pairing and managed via PersistenceManager.",
371 signed_prekey_id
372 );
373 Ok(())
374 }
375
376 async fn contains_signed_prekey(&self, signed_prekey_id: u32) -> Result<bool, StoreError> {
377 Ok(signed_prekey_id == self.signed_pre_key_id)
378 }
379
380 async fn remove_signed_prekey(&self, signed_prekey_id: u32) -> Result<(), StoreError> {
381 log::warn!(
382 "Device: remove_signed_prekey({}) - no-op. Signed pre-keys are managed via PersistenceManager and should not be removed individually.",
383 signed_prekey_id
384 );
385 Ok(())
386 }
387}
388
389#[async_trait]
390impl SessionStore for Device {
391 async fn load_session(&self, address: &ProtocolAddress) -> Result<SessionRecord, StoreError> {
392 let address_str = address.to_string();
393 match self.backend.get_session(&address_str).await {
394 Ok(Some(session_data)) => {
395 SessionRecord::deserialize(&session_data).map_err(|e| Box::new(e) as StoreError)
396 }
397 Ok(None) => Ok(SessionRecord::new_fresh()),
398 Err(e) => Err(Box::new(e) as StoreError),
399 }
400 }
401
402 async fn get_sub_device_sessions(&self, name: &str) -> Result<Vec<u32>, StoreError> {
403 let _ = name;
404 Ok(Vec::new())
405 }
406
407 async fn store_session(
408 &self,
409 address: &ProtocolAddress,
410 record: &SessionRecord,
411 ) -> Result<(), StoreError> {
412 let address_str = address.to_string();
413 let session_data = record.serialize().map_err(|e| Box::new(e) as StoreError)?;
414
415 self.backend
416 .put_session(&address_str, &session_data)
417 .await
418 .map_err(|e| Box::new(e) as StoreError)
419 }
420
421 async fn contains_session(&self, address: &ProtocolAddress) -> Result<bool, StoreError> {
422 let address_str = address.to_string();
423 self.backend
424 .has_session(&address_str)
425 .await
426 .map_err(|e| Box::new(e) as StoreError)
427 }
428
429 async fn delete_session(&self, address: &ProtocolAddress) -> Result<(), StoreError> {
430 let address_str = address.to_string();
431 self.backend
432 .delete_session(&address_str)
433 .await
434 .map_err(|e| Box::new(e) as StoreError)
435 }
436
437 async fn delete_all_sessions(&self, name: &str) -> Result<(), StoreError> {
438 let _ = name;
439 Ok(())
440 }
441}
442
443use tokio::sync::RwLock;
444
445pub struct DeviceRwLockWrapper(pub Arc<RwLock<Device>>);
446
447impl DeviceRwLockWrapper {
448 pub fn new(device: Arc<RwLock<Device>>) -> Self {
449 Self(device)
450 }
451}
452
453impl Clone for DeviceRwLockWrapper {
454 fn clone(&self) -> Self {
455 Self(self.0.clone())
456 }
457}
458
459impl_store_wrapper!(DeviceRwLockWrapper, read, write);
460
461pub struct DeviceStore(pub Arc<Mutex<Device>>);
462
463impl DeviceStore {
464 pub fn new(device: Arc<Mutex<Device>>) -> Self {
465 Self(device)
466 }
467}
468
469impl Clone for DeviceStore {
470 fn clone(&self) -> Self {
471 Self(self.0.clone())
472 }
473}
474
475impl_store_wrapper!(DeviceStore, lock, lock);
476
477#[async_trait]
478impl SenderKeyStore for Device {
479 async fn store_sender_key(
480 &mut self,
481 sender_key_name: &SenderKeyName,
482 record: &SenderKeyRecord,
483 ) -> SignalResult<()> {
484 let unique_key = format!(
485 "{}:{}",
486 sender_key_name.group_id(),
487 sender_key_name.sender_id()
488 );
489 let serialized_record = record.serialize()?;
490 self.backend
491 .put_sender_key(&unique_key, &serialized_record)
492 .await
493 .map_err(|e| SignalProtocolError::InvalidState("store_sender_key", e.to_string()))
494 }
495
496 async fn load_sender_key(
497 &mut self,
498 sender_key_name: &SenderKeyName,
499 ) -> SignalResult<Option<SenderKeyRecord>> {
500 let unique_key = format!(
501 "{}:{}",
502 sender_key_name.group_id(),
503 sender_key_name.sender_id()
504 );
505 match self
506 .backend
507 .get_sender_key(&unique_key)
508 .await
509 .map_err(|e| SignalProtocolError::InvalidState("load_sender_key", e.to_string()))?
510 {
511 Some(data) => {
512 let record = SenderKeyRecord::deserialize(&data)?;
513 if record.serialize()?.is_empty() {
514 Ok(None)
515 } else {
516 Ok(Some(record))
517 }
518 }
519 None => Ok(None),
520 }
521 }
522}