whatsapp_rust/store/
device_aware_store.rs

1use super::sqlite_store::SqliteStore;
2use crate::store::schema::*;
3use async_trait::async_trait;
4use diesel::prelude::*;
5use prost::Message;
6use std::sync::Arc;
7use wacore::appstate::hash::HashState;
8use wacore::store::error::Result;
9use wacore::store::traits::*;
10
11/// A device-aware wrapper around SqliteStore that ensures all operations
12/// are scoped to a specific device_id for proper multi-account isolation.
13#[derive(Clone)]
14pub struct DeviceAwareSqliteStore {
15    store: Arc<SqliteStore>,
16    device_id: i32,
17}
18
19impl DeviceAwareSqliteStore {
20    pub fn new(store: Arc<SqliteStore>, device_id: i32) -> Self {
21        Self { store, device_id }
22    }
23}
24
25#[async_trait]
26impl IdentityStore for DeviceAwareSqliteStore {
27    async fn put_identity(&self, address: &str, key: [u8; 32]) -> Result<()> {
28        self.store
29            .put_identity_for_device(address, key, self.device_id)
30            .await
31    }
32
33    async fn delete_identity(&self, address: &str) -> Result<()> {
34        self.store
35            .delete_identity_for_device(address, self.device_id)
36            .await
37    }
38
39    async fn is_trusted_identity(
40        &self,
41        address: &str,
42        key: &[u8; 32],
43        _direction: wacore::libsignal::protocol::Direction,
44    ) -> Result<bool> {
45        // For now, we'll trust all identities as per the original implementation
46        // but we should load and check against stored identity
47        match self.load_identity(address).await? {
48            Some(stored_key) => Ok(stored_key == key.to_vec()),
49            None => Ok(true), // Trust on first use
50        }
51    }
52
53    async fn load_identity(&self, address: &str) -> Result<Option<Vec<u8>>> {
54        self.store
55            .load_identity_for_device(address, self.device_id)
56            .await
57    }
58}
59
60#[async_trait]
61impl SessionStore for DeviceAwareSqliteStore {
62    async fn get_session(&self, address: &str) -> Result<Option<Vec<u8>>> {
63        self.store
64            .get_session_for_device(address, self.device_id)
65            .await
66    }
67
68    async fn put_session(&self, address: &str, session: &[u8]) -> Result<()> {
69        self.store
70            .put_session_for_device(address, session, self.device_id)
71            .await
72    }
73
74    async fn delete_session(&self, address: &str) -> Result<()> {
75        self.store
76            .delete_session_for_device(address, self.device_id)
77            .await
78    }
79
80    async fn has_session(&self, address: &str) -> Result<bool> {
81        self.store
82            .has_session_for_device(address, self.device_id)
83            .await
84    }
85}
86
87#[async_trait]
88impl SenderKeyStoreHelper for DeviceAwareSqliteStore {
89    async fn put_sender_key(&self, address: &str, record: &[u8]) -> Result<()> {
90        self.store
91            .put_sender_key_for_device(address, record, self.device_id)
92            .await
93    }
94
95    async fn get_sender_key(&self, address: &str) -> Result<Option<Vec<u8>>> {
96        self.store
97            .get_sender_key_for_device(address, self.device_id)
98            .await
99    }
100
101    async fn delete_sender_key(&self, address: &str) -> Result<()> {
102        self.store
103            .delete_sender_key_for_device(address, self.device_id)
104            .await
105    }
106}
107
108#[async_trait]
109impl AppStateKeyStore for DeviceAwareSqliteStore {
110    async fn get_app_state_sync_key(&self, key_id: &[u8]) -> Result<Option<AppStateSyncKey>> {
111        self.store
112            .get_app_state_sync_key_for_device(key_id, self.device_id)
113            .await
114    }
115
116    async fn set_app_state_sync_key(&self, key_id: &[u8], key: AppStateSyncKey) -> Result<()> {
117        self.store
118            .set_app_state_sync_key_for_device(key_id, key, self.device_id)
119            .await
120    }
121}
122
123#[async_trait]
124impl AppStateStore for DeviceAwareSqliteStore {
125    async fn get_app_state_version(&self, name: &str) -> Result<HashState> {
126        self.store
127            .get_app_state_version_for_device(name, self.device_id)
128            .await
129    }
130
131    async fn set_app_state_version(&self, name: &str, state: HashState) -> Result<()> {
132        self.store
133            .set_app_state_version_for_device(name, state, self.device_id)
134            .await
135    }
136
137    async fn put_app_state_mutation_macs(
138        &self,
139        name: &str,
140        version: u64,
141        mutations: &[AppStateMutationMAC],
142    ) -> Result<()> {
143        self.store
144            .put_app_state_mutation_macs_for_device(name, version, mutations, self.device_id)
145            .await
146    }
147
148    async fn delete_app_state_mutation_macs(
149        &self,
150        name: &str,
151        index_macs: &[Vec<u8>],
152    ) -> Result<()> {
153        self.store
154            .delete_app_state_mutation_macs_for_device(name, index_macs, self.device_id)
155            .await
156    }
157
158    async fn get_app_state_mutation_mac(
159        &self,
160        name: &str,
161        index_mac: &[u8],
162    ) -> Result<Option<Vec<u8>>> {
163        self.store
164            .get_app_state_mutation_mac_for_device(name, index_mac, self.device_id)
165            .await
166    }
167}
168
169// Implement libsignal::store traits by delegating to the original SqliteStore
170// but with device_id filtering for the prekey operations
171#[async_trait]
172impl wacore::libsignal::store::PreKeyStore for DeviceAwareSqliteStore {
173    async fn load_prekey(
174        &self,
175        prekey_id: u32,
176    ) -> std::result::Result<
177        Option<waproto::whatsapp::PreKeyRecordStructure>,
178        Box<dyn std::error::Error + Send + Sync>,
179    > {
180        let pool = self.store.pool.clone();
181        let device_id = self.device_id;
182
183        tokio::task::spawn_blocking(move || -> std::result::Result<Option<waproto::whatsapp::PreKeyRecordStructure>, Box<dyn std::error::Error + Send + Sync>> {
184            let mut conn = pool.get()?;
185
186            let key_data: Option<Vec<u8>> = prekeys::table
187                .select(prekeys::key)
188                .filter(prekeys::id.eq(prekey_id as i32))
189                .filter(prekeys::device_id.eq(device_id))
190                .first(&mut conn)
191                .optional()?;
192
193                match key_data {
194                Some(data) => {
195                    let record = waproto::whatsapp::PreKeyRecordStructure::decode(&data[..])?;
196                    Ok(Some(record))
197                }
198                None => Ok(None),
199            }
200        })
201        .await?
202    }
203
204    async fn store_prekey(
205        &self,
206        prekey_id: u32,
207        record: waproto::whatsapp::PreKeyRecordStructure,
208        uploaded: bool,
209    ) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
210        let pool = self.store.pool.clone();
211        let device_id = self.device_id;
212        let key_data = record.encode_to_vec();
213
214        tokio::task::spawn_blocking(
215            move || -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
216                let mut conn = pool.get()?;
217
218                diesel::insert_into(prekeys::table)
219                    .values((
220                        prekeys::id.eq(prekey_id as i32),
221                        prekeys::key.eq(&key_data),
222                        prekeys::device_id.eq(device_id),
223                        prekeys::uploaded.eq(uploaded),
224                    ))
225                    .on_conflict((prekeys::id, prekeys::device_id))
226                    .do_update()
227                    .set((prekeys::key.eq(&key_data), prekeys::uploaded.eq(uploaded)))
228                    .execute(&mut conn)?;
229                Ok(())
230            },
231        )
232        .await?
233    }
234
235    async fn contains_prekey(
236        &self,
237        prekey_id: u32,
238    ) -> std::result::Result<bool, Box<dyn std::error::Error + Send + Sync>> {
239        let pool = self.store.pool.clone();
240        let device_id = self.device_id;
241
242        tokio::task::spawn_blocking(
243            move || -> std::result::Result<bool, Box<dyn std::error::Error + Send + Sync>> {
244                let mut conn = pool.get()?;
245
246                let count: i64 = prekeys::table
247                    .filter(prekeys::id.eq(prekey_id as i32))
248                    .filter(prekeys::device_id.eq(device_id))
249                    .count()
250                    .get_result(&mut conn)?;
251
252                Ok(count > 0)
253            },
254        )
255        .await?
256    }
257
258    async fn remove_prekey(
259        &self,
260        prekey_id: u32,
261    ) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
262        let pool = self.store.pool.clone();
263        let device_id = self.device_id;
264
265        tokio::task::spawn_blocking(
266            move || -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
267                let mut conn = pool.get()?;
268
269                diesel::delete(
270                    prekeys::table
271                        .filter(prekeys::id.eq(prekey_id as i32))
272                        .filter(prekeys::device_id.eq(device_id)),
273                )
274                .execute(&mut conn)?;
275                Ok(())
276            },
277        )
278        .await?
279    }
280}
281
282#[async_trait]
283impl wacore::libsignal::store::SignedPreKeyStore for DeviceAwareSqliteStore {
284    async fn load_signed_prekey(
285        &self,
286        signed_prekey_id: u32,
287    ) -> std::result::Result<
288        Option<waproto::whatsapp::SignedPreKeyRecordStructure>,
289        Box<dyn std::error::Error + Send + Sync>,
290    > {
291        let pool = self.store.pool.clone();
292        let device_id = self.device_id;
293
294        tokio::task::spawn_blocking(move || -> std::result::Result<Option<waproto::whatsapp::SignedPreKeyRecordStructure>, Box<dyn std::error::Error + Send + Sync>> {
295            let mut conn = pool.get()?;
296
297            let record_data: Option<Vec<u8>> = signed_prekeys::table
298                .select(signed_prekeys::record)
299                .filter(signed_prekeys::id.eq(signed_prekey_id as i32))
300                .filter(signed_prekeys::device_id.eq(device_id))
301                .first(&mut conn)
302                .optional()?;
303
304                match record_data {
305                Some(data) => {
306                    let record = waproto::whatsapp::SignedPreKeyRecordStructure::decode(&data[..])?;
307                    Ok(Some(record))
308                }
309                None => Ok(None),
310            }
311        })
312        .await?
313    }
314
315    async fn load_signed_prekeys(
316        &self,
317    ) -> std::result::Result<
318        Vec<waproto::whatsapp::SignedPreKeyRecordStructure>,
319        Box<dyn std::error::Error + Send + Sync>,
320    > {
321        let pool = self.store.pool.clone();
322        let device_id = self.device_id;
323
324        tokio::task::spawn_blocking(move || -> std::result::Result<Vec<waproto::whatsapp::SignedPreKeyRecordStructure>, Box<dyn std::error::Error + Send + Sync>> {
325            let mut conn = pool.get()?;
326
327            let records_data: Vec<Vec<u8>> = signed_prekeys::table
328                .select(signed_prekeys::record)
329                .filter(signed_prekeys::device_id.eq(device_id))
330                .load(&mut conn)?;
331
332            let mut records = Vec::new();
333            for data in records_data {
334                let record = waproto::whatsapp::SignedPreKeyRecordStructure::decode(&data[..])?;
335                records.push(record);
336            }
337            Ok(records)
338        })
339        .await?
340    }
341
342    async fn store_signed_prekey(
343        &self,
344        signed_prekey_id: u32,
345        record: waproto::whatsapp::SignedPreKeyRecordStructure,
346    ) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
347        let pool = self.store.pool.clone();
348        let device_id = self.device_id;
349        let record_data = record.encode_to_vec();
350
351        tokio::task::spawn_blocking(
352            move || -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
353                let mut conn = pool.get()?;
354
355                diesel::insert_into(signed_prekeys::table)
356                    .values((
357                        signed_prekeys::id.eq(signed_prekey_id as i32),
358                        signed_prekeys::record.eq(&record_data),
359                        signed_prekeys::device_id.eq(device_id),
360                    ))
361                    .on_conflict((signed_prekeys::id, signed_prekeys::device_id))
362                    .do_update()
363                    .set(signed_prekeys::record.eq(&record_data))
364                    .execute(&mut conn)?;
365                Ok(())
366            },
367        )
368        .await?
369    }
370
371    async fn contains_signed_prekey(
372        &self,
373        signed_prekey_id: u32,
374    ) -> std::result::Result<bool, Box<dyn std::error::Error + Send + Sync>> {
375        let pool = self.store.pool.clone();
376        let device_id = self.device_id;
377
378        tokio::task::spawn_blocking(
379            move || -> std::result::Result<bool, Box<dyn std::error::Error + Send + Sync>> {
380                let mut conn = pool.get()?;
381
382                let count: i64 = signed_prekeys::table
383                    .filter(signed_prekeys::id.eq(signed_prekey_id as i32))
384                    .filter(signed_prekeys::device_id.eq(device_id))
385                    .count()
386                    .get_result(&mut conn)?;
387
388                Ok(count > 0)
389            },
390        )
391        .await?
392    }
393
394    async fn remove_signed_prekey(
395        &self,
396        signed_prekey_id: u32,
397    ) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
398        let pool = self.store.pool.clone();
399        let device_id = self.device_id;
400
401        tokio::task::spawn_blocking(
402            move || -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
403                let mut conn = pool.get()?;
404
405                diesel::delete(
406                    signed_prekeys::table
407                        .filter(signed_prekeys::id.eq(signed_prekey_id as i32))
408                        .filter(signed_prekeys::device_id.eq(device_id)),
409                )
410                .execute(&mut conn)?;
411                Ok(())
412            },
413        )
414        .await?
415    }
416}
417
418#[async_trait]
419impl wacore::store::traits::DevicePersistence for DeviceAwareSqliteStore {
420    async fn save_device_data(
421        &self,
422        device_data: &wacore::store::Device,
423    ) -> wacore::store::error::Result<()> {
424        self.store.save_device_data(device_data).await
425    }
426
427    async fn save_device_data_for_device(
428        &self,
429        device_id: i32,
430        device_data: &wacore::store::Device,
431    ) -> wacore::store::error::Result<()> {
432        self.store
433            .save_device_data_for_device(device_id, device_data)
434            .await
435    }
436
437    async fn load_device_data(
438        &self,
439    ) -> wacore::store::error::Result<Option<wacore::store::Device>> {
440        self.store.load_device_data().await
441    }
442
443    async fn load_device_data_for_device(
444        &self,
445        device_id: i32,
446    ) -> wacore::store::error::Result<Option<wacore::store::Device>> {
447        self.store.load_device_data_for_device(device_id).await
448    }
449
450    async fn device_exists(&self, device_id: i32) -> wacore::store::error::Result<bool> {
451        self.store.device_exists(device_id).await
452    }
453
454    async fn create_new_device(&self) -> wacore::store::error::Result<i32> {
455        self.store.create_new_device().await
456    }
457}