whatsapp_rust_sqlite_storage/
sqlite_store.rs

1use crate::schema::*;
2use async_trait::async_trait;
3use diesel::prelude::*;
4use diesel::r2d2::{ConnectionManager, Pool};
5use diesel::sql_query;
6use diesel::sqlite::SqliteConnection;
7use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
8use log::warn;
9use prost::Message;
10use std::sync::Arc;
11use wacore::appstate::hash::HashState;
12use wacore::appstate::processor::AppStateMutationMAC;
13use wacore::libsignal::protocol::{KeyPair, PrivateKey, PublicKey};
14use wacore::store::Device as CoreDevice;
15use wacore::store::error::{Result, StoreError};
16use wacore::store::traits::*;
17use waproto::whatsapp as wa;
18
/// Migrations embedded at compile time from the `migrations` directory.
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");

/// r2d2 pool of Diesel SQLite connections.
type SqlitePool = Pool<ConnectionManager<SqliteConnection>>;
/// One row of the `device` table, in the column order that
/// `load_device_data_for_device` destructures it.
type DeviceRow = (
    // id
    i32,
    // lid (empty string when unset)
    String,
    // pn (empty string when unset)
    String,
    // registration_id
    i32,
    // noise_key (private || public bytes)
    Vec<u8>,
    // identity_key (private || public bytes)
    Vec<u8>,
    // signed_pre_key (private || public bytes)
    Vec<u8>,
    // signed_pre_key_id
    i32,
    // signed_pre_key_signature
    Vec<u8>,
    // adv_secret_key
    Vec<u8>,
    // account (protobuf-encoded, optional)
    Option<Vec<u8>>,
    // push_name
    String,
    // app_version_primary
    i32,
    // app_version_secondary
    i32,
    // app_version_tertiary
    i64,
    // app_version_last_fetched_ms
    i64,
    // edge_routing_info (optional)
    Option<Vec<u8>>,
);
41
/// SQLite-backed store built on a Diesel r2d2 connection pool.
///
/// The type derives `Clone`; clones share the same semaphore through the `Arc`.
#[derive(Clone)]
pub struct SqliteStore {
    /// Connection pool used by every query in this store.
    pub(crate) pool: SqlitePool,
    /// Semaphore (created with a single permit in `new`) acquired by
    /// `with_semaphore` and by the retrying `put_*` writers.
    pub(crate) db_semaphore: Arc<tokio::sync::Semaphore>,
    /// Device id this instance is bound to; `new` sets it to 1 and
    /// `new_for_device` overrides it.
    device_id: i32,
}
48
49#[derive(Debug, Clone, Copy)]
50struct ConnectionOptions;
51
52impl diesel::r2d2::CustomizeConnection<SqliteConnection, diesel::r2d2::Error>
53    for ConnectionOptions
54{
55    fn on_acquire(
56        &self,
57        conn: &mut SqliteConnection,
58    ) -> std::result::Result<(), diesel::r2d2::Error> {
59        diesel::sql_query("PRAGMA busy_timeout = 30000;")
60            .execute(conn)
61            .map_err(diesel::r2d2::Error::QueryError)?;
62        diesel::sql_query("PRAGMA synchronous = NORMAL;")
63            .execute(conn)
64            .map_err(diesel::r2d2::Error::QueryError)?;
65        diesel::sql_query("PRAGMA cache_size = 512;")
66            .execute(conn)
67            .map_err(diesel::r2d2::Error::QueryError)?;
68        diesel::sql_query("PRAGMA temp_store = memory;")
69            .execute(conn)
70            .map_err(diesel::r2d2::Error::QueryError)?;
71        diesel::sql_query("PRAGMA foreign_keys = ON;")
72            .execute(conn)
73            .map_err(diesel::r2d2::Error::QueryError)?;
74        Ok(())
75    }
76}
77
78impl SqliteStore {
79    pub async fn new(database_url: &str) -> std::result::Result<Self, StoreError> {
80        let manager = ConnectionManager::<SqliteConnection>::new(database_url);
81
82        let pool_size = 2;
83
84        let pool = Pool::builder()
85            .max_size(pool_size)
86            .connection_customizer(Box::new(ConnectionOptions))
87            .build(manager)
88            .map_err(|e| StoreError::Connection(e.to_string()))?;
89
90        let pool_clone = pool.clone();
91        tokio::task::spawn_blocking(move || -> std::result::Result<(), StoreError> {
92            let mut conn = pool_clone
93                .get()
94                .map_err(|e| StoreError::Connection(e.to_string()))?;
95
96            diesel::sql_query("PRAGMA journal_mode = WAL;")
97                .execute(&mut conn)
98                .map_err(|e| StoreError::Database(e.to_string()))?;
99
100            conn.run_pending_migrations(MIGRATIONS)
101                .map_err(|e| StoreError::Migration(e.to_string()))?;
102
103            Ok(())
104        })
105        .await
106        .map_err(|e| StoreError::Database(e.to_string()))??;
107
108        Ok(Self {
109            pool,
110            db_semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
111            device_id: 1,
112        })
113    }
114
115    pub async fn new_for_device(
116        database_url: &str,
117        device_id: i32,
118    ) -> std::result::Result<Self, StoreError> {
119        let mut store = Self::new(database_url).await?;
120        store.device_id = device_id;
121        Ok(store)
122    }
123
    /// Returns the device id this store instance is bound to.
    pub fn device_id(&self) -> i32 {
        self.device_id
    }
127
128    async fn with_semaphore<F, T>(&self, f: F) -> Result<T>
129    where
130        F: FnOnce() -> Result<T> + Send + 'static,
131        T: Send + 'static,
132    {
133        let permit = self
134            .db_semaphore
135            .clone()
136            .acquire_owned()
137            .await
138            .map_err(|e| StoreError::Database(format!("Semaphore error: {}", e)))?;
139        let result = tokio::task::spawn_blocking(move || {
140            let res = f();
141            drop(permit);
142            res
143        })
144        .await
145        .map_err(|e| StoreError::Database(e.to_string()))??;
146        Ok(result)
147    }
148
149    fn serialize_keypair(&self, key_pair: &KeyPair) -> Result<Vec<u8>> {
150        let mut bytes = Vec::with_capacity(64);
151        bytes.extend_from_slice(&key_pair.private_key.serialize());
152        bytes.extend_from_slice(key_pair.public_key.public_key_bytes());
153        Ok(bytes)
154    }
155
156    fn deserialize_keypair(&self, bytes: &[u8]) -> Result<KeyPair> {
157        if bytes.len() != 64 {
158            return Err(StoreError::Serialization(format!(
159                "Invalid KeyPair length: {}",
160                bytes.len()
161            )));
162        }
163
164        let private_key = PrivateKey::deserialize(&bytes[0..32])
165            .map_err(|e| StoreError::Serialization(e.to_string()))?;
166        let public_key = PublicKey::from_djb_public_key_bytes(&bytes[32..64])
167            .map_err(|e| StoreError::Serialization(e.to_string()))?;
168
169        Ok(KeyPair::new(public_key, private_key))
170    }
171
172    pub async fn save_device_data_for_device(
173        &self,
174        device_id: i32,
175        device_data: &CoreDevice,
176    ) -> Result<()> {
177        let pool = self.pool.clone();
178        let noise_key_data = self.serialize_keypair(&device_data.noise_key)?;
179        let identity_key_data = self.serialize_keypair(&device_data.identity_key)?;
180        let signed_pre_key_data = self.serialize_keypair(&device_data.signed_pre_key)?;
181        let account_data = device_data
182            .account
183            .as_ref()
184            .map(|account| account.encode_to_vec());
185        let registration_id = device_data.registration_id as i32;
186        let signed_pre_key_id = device_data.signed_pre_key_id as i32;
187        let signed_pre_key_signature: Vec<u8> = device_data.signed_pre_key_signature.to_vec();
188        let adv_secret_key: Vec<u8> = device_data.adv_secret_key.to_vec();
189        let push_name = device_data.push_name.clone();
190        let app_version_primary = device_data.app_version_primary as i32;
191        let app_version_secondary = device_data.app_version_secondary as i32;
192        let app_version_tertiary = device_data.app_version_tertiary as i64;
193        let app_version_last_fetched_ms = device_data.app_version_last_fetched_ms;
194        let edge_routing_info = device_data.edge_routing_info.clone();
195        let new_lid = device_data
196            .lid
197            .as_ref()
198            .map(|j| j.to_string())
199            .unwrap_or_default();
200        let new_pn = device_data
201            .pn
202            .as_ref()
203            .map(|j| j.to_string())
204            .unwrap_or_default();
205
206        tokio::task::spawn_blocking(move || -> Result<()> {
207            let mut conn = pool
208                .get()
209                .map_err(|e| StoreError::Connection(e.to_string()))?;
210
211            diesel::insert_into(device::table)
212                .values((
213                    device::id.eq(device_id),
214                    device::lid.eq(&new_lid),
215                    device::pn.eq(&new_pn),
216                    device::registration_id.eq(registration_id),
217                    device::noise_key.eq(&noise_key_data),
218                    device::identity_key.eq(&identity_key_data),
219                    device::signed_pre_key.eq(&signed_pre_key_data),
220                    device::signed_pre_key_id.eq(signed_pre_key_id),
221                    device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
222                    device::adv_secret_key.eq(&adv_secret_key[..]),
223                    device::account.eq(account_data.clone()),
224                    device::push_name.eq(&push_name),
225                    device::app_version_primary.eq(app_version_primary),
226                    device::app_version_secondary.eq(app_version_secondary),
227                    device::app_version_tertiary.eq(app_version_tertiary),
228                    device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
229                    device::edge_routing_info.eq(edge_routing_info.clone()),
230                ))
231                .on_conflict(device::id)
232                .do_update()
233                .set((
234                    device::lid.eq(&new_lid),
235                    device::pn.eq(&new_pn),
236                    device::registration_id.eq(registration_id),
237                    device::noise_key.eq(&noise_key_data),
238                    device::identity_key.eq(&identity_key_data),
239                    device::signed_pre_key.eq(&signed_pre_key_data),
240                    device::signed_pre_key_id.eq(signed_pre_key_id),
241                    device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
242                    device::adv_secret_key.eq(&adv_secret_key[..]),
243                    device::account.eq(account_data.clone()),
244                    device::push_name.eq(&push_name),
245                    device::app_version_primary.eq(app_version_primary),
246                    device::app_version_secondary.eq(app_version_secondary),
247                    device::app_version_tertiary.eq(app_version_tertiary),
248                    device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
249                    device::edge_routing_info.eq(edge_routing_info),
250                ))
251                .execute(&mut conn)
252                .map_err(|e| StoreError::Database(e.to_string()))?;
253
254            Ok(())
255        })
256        .await
257        .map_err(|e| StoreError::Database(e.to_string()))??;
258
259        Ok(())
260    }
261
262    pub async fn create_new_device(&self) -> Result<i32> {
263        use crate::schema::device;
264
265        let pool = self.pool.clone();
266        tokio::task::spawn_blocking(move || -> Result<i32> {
267            let mut conn = pool
268                .get()
269                .map_err(|e| StoreError::Connection(e.to_string()))?;
270
271            let new_device = wacore::store::Device::new();
272
273            let noise_key_data = {
274                let mut bytes = Vec::with_capacity(64);
275                bytes.extend_from_slice(&new_device.noise_key.private_key.serialize());
276                bytes.extend_from_slice(new_device.noise_key.public_key.public_key_bytes());
277                bytes
278            };
279            let identity_key_data = {
280                let mut bytes = Vec::with_capacity(64);
281                bytes.extend_from_slice(&new_device.identity_key.private_key.serialize());
282                bytes.extend_from_slice(new_device.identity_key.public_key.public_key_bytes());
283                bytes
284            };
285            let signed_pre_key_data = {
286                let mut bytes = Vec::with_capacity(64);
287                bytes.extend_from_slice(&new_device.signed_pre_key.private_key.serialize());
288                bytes.extend_from_slice(new_device.signed_pre_key.public_key.public_key_bytes());
289                bytes
290            };
291
292            diesel::insert_into(device::table)
293                .values((
294                    device::lid.eq(""),
295                    device::pn.eq(""),
296                    device::registration_id.eq(new_device.registration_id as i32),
297                    device::noise_key.eq(&noise_key_data),
298                    device::identity_key.eq(&identity_key_data),
299                    device::signed_pre_key.eq(&signed_pre_key_data),
300                    device::signed_pre_key_id.eq(new_device.signed_pre_key_id as i32),
301                    device::signed_pre_key_signature.eq(&new_device.signed_pre_key_signature[..]),
302                    device::adv_secret_key.eq(&new_device.adv_secret_key[..]),
303                    device::account.eq(None::<Vec<u8>>),
304                    device::push_name.eq(&new_device.push_name),
305                    device::app_version_primary.eq(new_device.app_version_primary as i32),
306                    device::app_version_secondary.eq(new_device.app_version_secondary as i32),
307                    device::app_version_tertiary.eq(new_device.app_version_tertiary as i64),
308                    device::app_version_last_fetched_ms.eq(new_device.app_version_last_fetched_ms),
309                    device::edge_routing_info.eq(None::<Vec<u8>>),
310                ))
311                .execute(&mut conn)
312                .map_err(|e| StoreError::Database(e.to_string()))?;
313
314            use diesel::sql_types::Integer;
315
316            #[derive(QueryableByName)]
317            struct LastInsertedId {
318                #[diesel(sql_type = Integer)]
319                last_insert_rowid: i32,
320            }
321
322            let device_id: i32 = sql_query("SELECT last_insert_rowid() as last_insert_rowid")
323                .get_result::<LastInsertedId>(&mut conn)
324                .map_err(|e| StoreError::Database(e.to_string()))?
325                .last_insert_rowid;
326
327            Ok(device_id)
328        })
329        .await
330        .map_err(|e| StoreError::Database(e.to_string()))?
331    }
332
333    pub async fn device_exists(&self, device_id: i32) -> Result<bool> {
334        use crate::schema::device;
335
336        let pool = self.pool.clone();
337        tokio::task::spawn_blocking(move || -> Result<bool> {
338            let mut conn = pool
339                .get()
340                .map_err(|e| StoreError::Connection(e.to_string()))?;
341
342            let count: i64 = device::table
343                .filter(device::id.eq(device_id))
344                .count()
345                .get_result(&mut conn)
346                .map_err(|e| StoreError::Database(e.to_string()))?;
347
348            Ok(count > 0)
349        })
350        .await
351        .map_err(|e| StoreError::Database(e.to_string()))?
352    }
353
354    pub async fn load_device_data_for_device(&self, device_id: i32) -> Result<Option<CoreDevice>> {
355        use crate::schema::device;
356
357        let pool = self.pool.clone();
358        let row = tokio::task::spawn_blocking(move || -> Result<Option<DeviceRow>> {
359            let mut conn = pool
360                .get()
361                .map_err(|e| StoreError::Connection(e.to_string()))?;
362            let result = device::table
363                .filter(device::id.eq(device_id))
364                .first::<DeviceRow>(&mut conn)
365                .optional()
366                .map_err(|e| StoreError::Database(e.to_string()))?;
367            Ok(result)
368        })
369        .await
370        .map_err(|e| StoreError::Database(e.to_string()))??;
371
372        if let Some((
373            _device_id,
374            lid_str,
375            pn_str,
376            registration_id,
377            noise_key_data,
378            identity_key_data,
379            signed_pre_key_data,
380            signed_pre_key_id,
381            signed_pre_key_signature_data,
382            adv_secret_key_data,
383            account_data,
384            push_name,
385            app_version_primary,
386            app_version_secondary,
387            app_version_tertiary,
388            app_version_last_fetched_ms,
389            edge_routing_info,
390        )) = row
391        {
392            let id = if !pn_str.is_empty() {
393                pn_str.parse().ok()
394            } else {
395                None
396            };
397            let lid = if !lid_str.is_empty() {
398                lid_str.parse().ok()
399            } else {
400                None
401            };
402
403            let noise_key = self.deserialize_keypair(&noise_key_data)?;
404            let identity_key = self.deserialize_keypair(&identity_key_data)?;
405            let signed_pre_key = self.deserialize_keypair(&signed_pre_key_data)?;
406
407            let signed_pre_key_signature: [u8; 64] =
408                signed_pre_key_signature_data.try_into().map_err(|_| {
409                    StoreError::Serialization("Invalid signed_pre_key_signature length".to_string())
410                })?;
411
412            let adv_secret_key: [u8; 32] = adv_secret_key_data.try_into().map_err(|_| {
413                StoreError::Serialization("Invalid adv_secret_key length".to_string())
414            })?;
415
416            let account = account_data
417                .map(|data| {
418                    wa::AdvSignedDeviceIdentity::decode(&data[..])
419                        .map_err(|e| StoreError::Serialization(e.to_string()))
420                })
421                .transpose()?;
422
423            Ok(Some(CoreDevice {
424                pn: id,
425                lid,
426                registration_id: registration_id as u32,
427                noise_key,
428                identity_key,
429                signed_pre_key,
430                signed_pre_key_id: signed_pre_key_id as u32,
431                signed_pre_key_signature,
432                adv_secret_key,
433                account,
434                push_name,
435                app_version_primary: app_version_primary as u32,
436                app_version_secondary: app_version_secondary as u32,
437                app_version_tertiary: app_version_tertiary.try_into().unwrap_or(0u32),
438                app_version_last_fetched_ms,
439                device_props: {
440                    use wacore::store::device::DEVICE_PROPS;
441                    DEVICE_PROPS.clone()
442                },
443                edge_routing_info,
444            }))
445        } else {
446            Ok(None)
447        }
448    }
449
450    pub async fn put_identity_for_device(
451        &self,
452        address: &str,
453        key: [u8; 32],
454        device_id: i32,
455    ) -> Result<()> {
456        let pool = self.pool.clone();
457        let db_semaphore = self.db_semaphore.clone();
458        let address_owned = address.to_string();
459        let key_vec = key.to_vec();
460
461        const MAX_RETRIES: u32 = 5;
462
463        for attempt in 0..=MAX_RETRIES {
464            let permit =
465                db_semaphore.clone().acquire_owned().await.map_err(|e| {
466                    StoreError::Database(format!("Failed to acquire semaphore: {}", e))
467                })?;
468
469            let pool_clone = pool.clone();
470            let address_clone = address_owned.clone();
471            let key_clone = key_vec.clone();
472
473            let result = tokio::task::spawn_blocking(move || -> Result<()> {
474                let mut conn = pool_clone
475                    .get()
476                    .map_err(|e| StoreError::Connection(e.to_string()))?;
477                diesel::insert_into(identities::table)
478                    .values((
479                        identities::address.eq(address_clone),
480                        identities::key.eq(&key_clone[..]),
481                        identities::device_id.eq(device_id),
482                    ))
483                    .on_conflict((identities::address, identities::device_id))
484                    .do_update()
485                    .set(identities::key.eq(&key_clone[..]))
486                    .execute(&mut conn)
487                    .map_err(|e| StoreError::Database(e.to_string()))?;
488                Ok(())
489            })
490            .await;
491
492            drop(permit);
493
494            match result {
495                Ok(Ok(())) => return Ok(()),
496                Ok(Err(e)) => {
497                    let error_msg = e.to_string();
498                    if (error_msg.contains("locked") || error_msg.contains("busy"))
499                        && attempt < MAX_RETRIES
500                    {
501                        let delay_ms = 10 * 2u64.pow(attempt);
502                        warn!(
503                            "Identity write failed (attempt {}/{}): {}. Retrying in {}ms...",
504                            attempt + 1,
505                            MAX_RETRIES + 1,
506                            error_msg,
507                            delay_ms
508                        );
509                        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
510                        continue;
511                    }
512                    return Err(e);
513                }
514                Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
515            }
516        }
517
518        Err(StoreError::Database(format!(
519            "Identity write failed after {} attempts",
520            MAX_RETRIES + 1
521        )))
522    }
523
524    pub async fn delete_identity_for_device(&self, address: &str, device_id: i32) -> Result<()> {
525        let pool = self.pool.clone();
526        let address_owned = address.to_string();
527
528        tokio::task::spawn_blocking(move || -> Result<()> {
529            let mut conn = pool
530                .get()
531                .map_err(|e| StoreError::Connection(e.to_string()))?;
532            diesel::delete(
533                identities::table
534                    .filter(identities::address.eq(address_owned))
535                    .filter(identities::device_id.eq(device_id)),
536            )
537            .execute(&mut conn)
538            .map_err(|e| StoreError::Database(e.to_string()))?;
539            Ok(())
540        })
541        .await
542        .map_err(|e| StoreError::Database(e.to_string()))??;
543
544        Ok(())
545    }
546
547    pub async fn load_identity_for_device(
548        &self,
549        address: &str,
550        device_id: i32,
551    ) -> Result<Option<Vec<u8>>> {
552        let pool = self.pool.clone();
553        let address = address.to_string();
554        let result = self
555            .with_semaphore(move || -> Result<Option<Vec<u8>>> {
556                let mut conn = pool
557                    .get()
558                    .map_err(|e| StoreError::Connection(e.to_string()))?;
559                let res: Option<Vec<u8>> = identities::table
560                    .select(identities::key)
561                    .filter(identities::address.eq(address))
562                    .filter(identities::device_id.eq(device_id))
563                    .first(&mut conn)
564                    .optional()
565                    .map_err(|e| StoreError::Database(e.to_string()))?;
566                Ok(res)
567            })
568            .await?;
569
570        Ok(result)
571    }
572
573    pub async fn get_session_for_device(
574        &self,
575        address: &str,
576        device_id: i32,
577    ) -> Result<Option<Vec<u8>>> {
578        let pool = self.pool.clone();
579        let address_for_query = address.to_string();
580        let result = self
581            .with_semaphore(move || -> Result<Option<Vec<u8>>> {
582                let mut conn = pool
583                    .get()
584                    .map_err(|e| StoreError::Connection(e.to_string()))?;
585                let res: Option<Vec<u8>> = sessions::table
586                    .select(sessions::record)
587                    .filter(sessions::address.eq(address_for_query.clone()))
588                    .filter(sessions::device_id.eq(device_id))
589                    .first(&mut conn)
590                    .optional()
591                    .map_err(|e| StoreError::Database(e.to_string()))?;
592
593                Ok(res)
594            })
595            .await?;
596
597        Ok(result)
598    }
599
600    pub async fn put_session_for_device(
601        &self,
602        address: &str,
603        session: &[u8],
604        device_id: i32,
605    ) -> Result<()> {
606        let pool = self.pool.clone();
607        let db_semaphore = self.db_semaphore.clone();
608        let address_owned = address.to_string();
609        let session_vec = session.to_vec();
610
611        const MAX_RETRIES: u32 = 5;
612
613        for attempt in 0..=MAX_RETRIES {
614            let permit =
615                db_semaphore.clone().acquire_owned().await.map_err(|e| {
616                    StoreError::Database(format!("Failed to acquire semaphore: {}", e))
617                })?;
618
619            let pool_clone = pool.clone();
620            let address_clone = address_owned.clone();
621            let session_clone = session_vec.clone();
622
623            let result = tokio::task::spawn_blocking(move || -> Result<()> {
624                let mut conn = pool_clone
625                    .get()
626                    .map_err(|e| StoreError::Connection(e.to_string()))?;
627                diesel::insert_into(sessions::table)
628                    .values((
629                        sessions::address.eq(address_clone),
630                        sessions::record.eq(&session_clone),
631                        sessions::device_id.eq(device_id),
632                    ))
633                    .on_conflict((sessions::address, sessions::device_id))
634                    .do_update()
635                    .set(sessions::record.eq(&session_clone))
636                    .execute(&mut conn)
637                    .map_err(|e| StoreError::Database(e.to_string()))?;
638                Ok(())
639            })
640            .await;
641
642            drop(permit);
643
644            match result {
645                Ok(Ok(())) => {
646                    return Ok(());
647                }
648                Ok(Err(e)) => {
649                    let error_msg = e.to_string();
650                    if (error_msg.contains("locked") || error_msg.contains("busy"))
651                        && attempt < MAX_RETRIES
652                    {
653                        let delay_ms = 10 * 2u64.pow(attempt);
654                        warn!(
655                            "Session write failed (attempt {}/{}): {}. Retrying in {}ms...",
656                            attempt + 1,
657                            MAX_RETRIES + 1,
658                            error_msg,
659                            delay_ms
660                        );
661                        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
662                        continue;
663                    }
664                    return Err(e);
665                }
666                Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
667            }
668        }
669
670        Err(StoreError::Database(format!(
671            "Session write failed after {} attempts",
672            MAX_RETRIES + 1
673        )))
674    }
675
676    pub async fn delete_session_for_device(&self, address: &str, device_id: i32) -> Result<()> {
677        let pool = self.pool.clone();
678        let address_owned = address.to_string();
679
680        tokio::task::spawn_blocking(move || -> Result<()> {
681            let mut conn = pool
682                .get()
683                .map_err(|e| StoreError::Connection(e.to_string()))?;
684            diesel::delete(
685                sessions::table
686                    .filter(sessions::address.eq(address_owned))
687                    .filter(sessions::device_id.eq(device_id)),
688            )
689            .execute(&mut conn)
690            .map_err(|e| StoreError::Database(e.to_string()))?;
691            Ok(())
692        })
693        .await
694        .map_err(|e| StoreError::Database(e.to_string()))??;
695
696        Ok(())
697    }
698
699    pub async fn put_sender_key_for_device(
700        &self,
701        address: &str,
702        record: &[u8],
703        device_id: i32,
704    ) -> Result<()> {
705        let pool = self.pool.clone();
706        let address = address.to_string();
707        let record_vec = record.to_vec();
708        tokio::task::spawn_blocking(move || -> Result<()> {
709            let mut conn = pool
710                .get()
711                .map_err(|e| StoreError::Connection(e.to_string()))?;
712            diesel::insert_into(sender_keys::table)
713                .values((
714                    sender_keys::address.eq(address),
715                    sender_keys::record.eq(&record_vec),
716                    sender_keys::device_id.eq(device_id),
717                ))
718                .on_conflict((sender_keys::address, sender_keys::device_id))
719                .do_update()
720                .set(sender_keys::record.eq(&record_vec))
721                .execute(&mut conn)
722                .map_err(|e| StoreError::Database(e.to_string()))?;
723            Ok(())
724        })
725        .await
726        .map_err(|e| StoreError::Database(e.to_string()))??;
727        Ok(())
728    }
729
730    pub async fn get_sender_key_for_device(
731        &self,
732        address: &str,
733        device_id: i32,
734    ) -> Result<Option<Vec<u8>>> {
735        let pool = self.pool.clone();
736        let address = address.to_string();
737        tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
738            let mut conn = pool
739                .get()
740                .map_err(|e| StoreError::Connection(e.to_string()))?;
741            let res: Option<Vec<u8>> = sender_keys::table
742                .select(sender_keys::record)
743                .filter(sender_keys::address.eq(address))
744                .filter(sender_keys::device_id.eq(device_id))
745                .first(&mut conn)
746                .optional()
747                .map_err(|e| StoreError::Database(e.to_string()))?;
748            Ok(res)
749        })
750        .await
751        .map_err(|e| StoreError::Database(e.to_string()))?
752    }
753
754    pub async fn delete_sender_key_for_device(&self, address: &str, device_id: i32) -> Result<()> {
755        let pool = self.pool.clone();
756        let address = address.to_string();
757        tokio::task::spawn_blocking(move || -> Result<()> {
758            let mut conn = pool
759                .get()
760                .map_err(|e| StoreError::Connection(e.to_string()))?;
761            diesel::delete(
762                sender_keys::table
763                    .filter(sender_keys::address.eq(address))
764                    .filter(sender_keys::device_id.eq(device_id)),
765            )
766            .execute(&mut conn)
767            .map_err(|e| StoreError::Database(e.to_string()))?;
768            Ok(())
769        })
770        .await
771        .map_err(|e| StoreError::Database(e.to_string()))??;
772        Ok(())
773    }
774
775    pub async fn get_app_state_sync_key_for_device(
776        &self,
777        key_id: &[u8],
778        device_id: i32,
779    ) -> Result<Option<AppStateSyncKey>> {
780        let pool = self.pool.clone();
781        let key_id = key_id.to_vec();
782        let res: Option<Vec<u8>> =
783            tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
784                let mut conn = pool
785                    .get()
786                    .map_err(|e| StoreError::Connection(e.to_string()))?;
787                let res: Option<Vec<u8>> = app_state_keys::table
788                    .select(app_state_keys::key_data)
789                    .filter(app_state_keys::key_id.eq(&key_id))
790                    .filter(app_state_keys::device_id.eq(device_id))
791                    .first(&mut conn)
792                    .optional()
793                    .map_err(|e| StoreError::Database(e.to_string()))?;
794                Ok(res)
795            })
796            .await
797            .map_err(|e| StoreError::Database(e.to_string()))??;
798
799        if let Some(data) = res {
800            let (key, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
801                .map_err(|e| StoreError::Serialization(e.to_string()))?;
802            Ok(Some(key))
803        } else {
804            Ok(None)
805        }
806    }
807
808    pub async fn set_app_state_sync_key_for_device(
809        &self,
810        key_id: &[u8],
811        key: AppStateSyncKey,
812        device_id: i32,
813    ) -> Result<()> {
814        let pool = self.pool.clone();
815        let key_id = key_id.to_vec();
816        let data = bincode::serde::encode_to_vec(&key, bincode::config::standard())
817            .map_err(|e| StoreError::Serialization(e.to_string()))?;
818        tokio::task::spawn_blocking(move || -> Result<()> {
819            let mut conn = pool
820                .get()
821                .map_err(|e| StoreError::Connection(e.to_string()))?;
822            diesel::insert_into(app_state_keys::table)
823                .values((
824                    app_state_keys::key_id.eq(&key_id),
825                    app_state_keys::key_data.eq(&data),
826                    app_state_keys::device_id.eq(device_id),
827                ))
828                .on_conflict((app_state_keys::key_id, app_state_keys::device_id))
829                .do_update()
830                .set(app_state_keys::key_data.eq(&data))
831                .execute(&mut conn)
832                .map_err(|e| StoreError::Database(e.to_string()))?;
833            Ok(())
834        })
835        .await
836        .map_err(|e| StoreError::Database(e.to_string()))??;
837        Ok(())
838    }
839
840    pub async fn get_app_state_version_for_device(
841        &self,
842        name: &str,
843        device_id: i32,
844    ) -> Result<HashState> {
845        let pool = self.pool.clone();
846        let name = name.to_string();
847        let res: Option<Vec<u8>> =
848            tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
849                let mut conn = pool
850                    .get()
851                    .map_err(|e| StoreError::Connection(e.to_string()))?;
852                let res: Option<Vec<u8>> = app_state_versions::table
853                    .select(app_state_versions::state_data)
854                    .filter(app_state_versions::name.eq(name))
855                    .filter(app_state_versions::device_id.eq(device_id))
856                    .first(&mut conn)
857                    .optional()
858                    .map_err(|e| StoreError::Database(e.to_string()))?;
859                Ok(res)
860            })
861            .await
862            .map_err(|e| StoreError::Database(e.to_string()))??;
863
864        if let Some(data) = res {
865            let (state, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
866                .map_err(|e| StoreError::Serialization(e.to_string()))?;
867            Ok(state)
868        } else {
869            Ok(HashState::default())
870        }
871    }
872
873    pub async fn set_app_state_version_for_device(
874        &self,
875        name: &str,
876        state: HashState,
877        device_id: i32,
878    ) -> Result<()> {
879        let pool = self.pool.clone();
880        let name = name.to_string();
881        let data = bincode::serde::encode_to_vec(&state, bincode::config::standard())
882            .map_err(|e| StoreError::Serialization(e.to_string()))?;
883        tokio::task::spawn_blocking(move || -> Result<()> {
884            let mut conn = pool
885                .get()
886                .map_err(|e| StoreError::Connection(e.to_string()))?;
887            diesel::insert_into(app_state_versions::table)
888                .values((
889                    app_state_versions::name.eq(&name),
890                    app_state_versions::state_data.eq(&data),
891                    app_state_versions::device_id.eq(device_id),
892                ))
893                .on_conflict((app_state_versions::name, app_state_versions::device_id))
894                .do_update()
895                .set(app_state_versions::state_data.eq(&data))
896                .execute(&mut conn)
897                .map_err(|e| StoreError::Database(e.to_string()))?;
898            Ok(())
899        })
900        .await
901        .map_err(|e| StoreError::Database(e.to_string()))??;
902        Ok(())
903    }
904
905    pub async fn put_app_state_mutation_macs_for_device(
906        &self,
907        name: &str,
908        version: u64,
909        mutations: &[AppStateMutationMAC],
910        device_id: i32,
911    ) -> Result<()> {
912        if mutations.is_empty() {
913            return Ok(());
914        }
915        let pool = self.pool.clone();
916        let name = name.to_string();
917        let mutations: Vec<AppStateMutationMAC> = mutations.to_vec();
918        tokio::task::spawn_blocking(move || -> Result<()> {
919            let mut conn = pool
920                .get()
921                .map_err(|e| StoreError::Connection(e.to_string()))?;
922            for m in mutations {
923                diesel::insert_into(app_state_mutation_macs::table)
924                    .values((
925                        app_state_mutation_macs::name.eq(&name),
926                        app_state_mutation_macs::version.eq(version as i64),
927                        app_state_mutation_macs::index_mac.eq(&m.index_mac),
928                        app_state_mutation_macs::value_mac.eq(&m.value_mac),
929                        app_state_mutation_macs::device_id.eq(device_id),
930                    ))
931                    .on_conflict((
932                        app_state_mutation_macs::name,
933                        app_state_mutation_macs::index_mac,
934                        app_state_mutation_macs::device_id,
935                    ))
936                    .do_update()
937                    .set((
938                        app_state_mutation_macs::version.eq(version as i64),
939                        app_state_mutation_macs::value_mac.eq(&m.value_mac),
940                    ))
941                    .execute(&mut conn)
942                    .map_err(|e| StoreError::Database(e.to_string()))?;
943            }
944            Ok(())
945        })
946        .await
947        .map_err(|e| StoreError::Database(e.to_string()))??;
948        Ok(())
949    }
950
951    pub async fn delete_app_state_mutation_macs_for_device(
952        &self,
953        name: &str,
954        index_macs: &[Vec<u8>],
955        device_id: i32,
956    ) -> Result<()> {
957        if index_macs.is_empty() {
958            return Ok(());
959        }
960        let pool = self.pool.clone();
961        let name = name.to_string();
962        let index_macs: Vec<Vec<u8>> = index_macs.to_vec();
963        tokio::task::spawn_blocking(move || -> Result<()> {
964            let mut conn = pool
965                .get()
966                .map_err(|e| StoreError::Connection(e.to_string()))?;
967            for idx in index_macs {
968                diesel::delete(
969                    app_state_mutation_macs::table.filter(
970                        app_state_mutation_macs::name
971                            .eq(&name)
972                            .and(app_state_mutation_macs::index_mac.eq(&idx))
973                            .and(app_state_mutation_macs::device_id.eq(device_id)),
974                    ),
975                )
976                .execute(&mut conn)
977                .map_err(|e| StoreError::Database(e.to_string()))?;
978            }
979            Ok(())
980        })
981        .await
982        .map_err(|e| StoreError::Database(e.to_string()))??;
983        Ok(())
984    }
985
986    pub async fn get_app_state_mutation_mac_for_device(
987        &self,
988        name: &str,
989        index_mac: &[u8],
990        device_id: i32,
991    ) -> Result<Option<Vec<u8>>> {
992        let pool = self.pool.clone();
993        let name = name.to_string();
994        let index_mac = index_mac.to_vec();
995        tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
996            let mut conn = pool
997                .get()
998                .map_err(|e| StoreError::Connection(e.to_string()))?;
999            let res: Option<Vec<u8>> = app_state_mutation_macs::table
1000                .select(app_state_mutation_macs::value_mac)
1001                .filter(app_state_mutation_macs::name.eq(&name))
1002                .filter(app_state_mutation_macs::index_mac.eq(&index_mac))
1003                .filter(app_state_mutation_macs::device_id.eq(device_id))
1004                .first(&mut conn)
1005                .optional()
1006                .map_err(|e| StoreError::Database(e.to_string()))?;
1007            Ok(res)
1008        })
1009        .await
1010        .map_err(|e| StoreError::Database(e.to_string()))?
1011    }
1012}
1013
1014#[async_trait]
1015impl SignalStore for SqliteStore {
1016    async fn put_identity(&self, address: &str, key: [u8; 32]) -> Result<()> {
1017        self.put_identity_for_device(address, key, self.device_id)
1018            .await
1019    }
1020
1021    async fn load_identity(&self, address: &str) -> Result<Option<Vec<u8>>> {
1022        self.load_identity_for_device(address, self.device_id).await
1023    }
1024
1025    async fn delete_identity(&self, address: &str) -> Result<()> {
1026        self.delete_identity_for_device(address, self.device_id)
1027            .await
1028    }
1029
1030    async fn get_session(&self, address: &str) -> Result<Option<Vec<u8>>> {
1031        self.get_session_for_device(address, self.device_id).await
1032    }
1033
1034    async fn put_session(&self, address: &str, session: &[u8]) -> Result<()> {
1035        self.put_session_for_device(address, session, self.device_id)
1036            .await
1037    }
1038
1039    async fn delete_session(&self, address: &str) -> Result<()> {
1040        self.delete_session_for_device(address, self.device_id)
1041            .await
1042    }
1043
1044    async fn store_prekey(&self, id: u32, record: &[u8], uploaded: bool) -> Result<()> {
1045        let pool = self.pool.clone();
1046        let device_id = self.device_id;
1047        let record = record.to_vec();
1048        tokio::task::spawn_blocking(move || -> Result<()> {
1049            let mut conn = pool
1050                .get()
1051                .map_err(|e| StoreError::Connection(e.to_string()))?;
1052            diesel::insert_into(prekeys::table)
1053                .values((
1054                    prekeys::id.eq(id as i32),
1055                    prekeys::key.eq(&record),
1056                    prekeys::uploaded.eq(uploaded),
1057                    prekeys::device_id.eq(device_id),
1058                ))
1059                .on_conflict((prekeys::id, prekeys::device_id))
1060                .do_update()
1061                .set((prekeys::key.eq(&record), prekeys::uploaded.eq(uploaded)))
1062                .execute(&mut conn)
1063                .map_err(|e| StoreError::Database(e.to_string()))?;
1064            Ok(())
1065        })
1066        .await
1067        .map_err(|e| StoreError::Database(e.to_string()))??;
1068        Ok(())
1069    }
1070
1071    async fn load_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1072        let pool = self.pool.clone();
1073        let device_id = self.device_id;
1074        tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1075            let mut conn = pool
1076                .get()
1077                .map_err(|e| StoreError::Connection(e.to_string()))?;
1078            let res: Option<Vec<u8>> = prekeys::table
1079                .select(prekeys::key)
1080                .filter(prekeys::id.eq(id as i32))
1081                .filter(prekeys::device_id.eq(device_id))
1082                .first(&mut conn)
1083                .optional()
1084                .map_err(|e| StoreError::Database(e.to_string()))?;
1085            Ok(res)
1086        })
1087        .await
1088        .map_err(|e| StoreError::Database(e.to_string()))?
1089    }
1090
1091    async fn remove_prekey(&self, id: u32) -> Result<()> {
1092        let pool = self.pool.clone();
1093        let device_id = self.device_id;
1094        tokio::task::spawn_blocking(move || -> Result<()> {
1095            let mut conn = pool
1096                .get()
1097                .map_err(|e| StoreError::Connection(e.to_string()))?;
1098            diesel::delete(
1099                prekeys::table
1100                    .filter(prekeys::id.eq(id as i32))
1101                    .filter(prekeys::device_id.eq(device_id)),
1102            )
1103            .execute(&mut conn)
1104            .map_err(|e| StoreError::Database(e.to_string()))?;
1105            Ok(())
1106        })
1107        .await
1108        .map_err(|e| StoreError::Database(e.to_string()))??;
1109        Ok(())
1110    }
1111
1112    async fn store_signed_prekey(&self, id: u32, record: &[u8]) -> Result<()> {
1113        let pool = self.pool.clone();
1114        let device_id = self.device_id;
1115        let record = record.to_vec();
1116        tokio::task::spawn_blocking(move || -> Result<()> {
1117            let mut conn = pool
1118                .get()
1119                .map_err(|e| StoreError::Connection(e.to_string()))?;
1120            diesel::insert_into(signed_prekeys::table)
1121                .values((
1122                    signed_prekeys::id.eq(id as i32),
1123                    signed_prekeys::record.eq(&record),
1124                    signed_prekeys::device_id.eq(device_id),
1125                ))
1126                .on_conflict((signed_prekeys::id, signed_prekeys::device_id))
1127                .do_update()
1128                .set(signed_prekeys::record.eq(&record))
1129                .execute(&mut conn)
1130                .map_err(|e| StoreError::Database(e.to_string()))?;
1131            Ok(())
1132        })
1133        .await
1134        .map_err(|e| StoreError::Database(e.to_string()))??;
1135        Ok(())
1136    }
1137
1138    async fn load_signed_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1139        let pool = self.pool.clone();
1140        let device_id = self.device_id;
1141        tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1142            let mut conn = pool
1143                .get()
1144                .map_err(|e| StoreError::Connection(e.to_string()))?;
1145            let res: Option<Vec<u8>> = signed_prekeys::table
1146                .select(signed_prekeys::record)
1147                .filter(signed_prekeys::id.eq(id as i32))
1148                .filter(signed_prekeys::device_id.eq(device_id))
1149                .first(&mut conn)
1150                .optional()
1151                .map_err(|e| StoreError::Database(e.to_string()))?;
1152            Ok(res)
1153        })
1154        .await
1155        .map_err(|e| StoreError::Database(e.to_string()))?
1156    }
1157
1158    async fn load_all_signed_prekeys(&self) -> Result<Vec<(u32, Vec<u8>)>> {
1159        let pool = self.pool.clone();
1160        let device_id = self.device_id;
1161        tokio::task::spawn_blocking(move || -> Result<Vec<(u32, Vec<u8>)>> {
1162            let mut conn = pool
1163                .get()
1164                .map_err(|e| StoreError::Connection(e.to_string()))?;
1165            let results: Vec<(i32, Vec<u8>)> = signed_prekeys::table
1166                .select((signed_prekeys::id, signed_prekeys::record))
1167                .filter(signed_prekeys::device_id.eq(device_id))
1168                .load(&mut conn)
1169                .map_err(|e| StoreError::Database(e.to_string()))?;
1170            Ok(results
1171                .into_iter()
1172                .map(|(id, record)| (id as u32, record))
1173                .collect())
1174        })
1175        .await
1176        .map_err(|e| StoreError::Database(e.to_string()))?
1177    }
1178
1179    async fn remove_signed_prekey(&self, id: u32) -> Result<()> {
1180        let pool = self.pool.clone();
1181        let device_id = self.device_id;
1182        tokio::task::spawn_blocking(move || -> Result<()> {
1183            let mut conn = pool
1184                .get()
1185                .map_err(|e| StoreError::Connection(e.to_string()))?;
1186            diesel::delete(
1187                signed_prekeys::table
1188                    .filter(signed_prekeys::id.eq(id as i32))
1189                    .filter(signed_prekeys::device_id.eq(device_id)),
1190            )
1191            .execute(&mut conn)
1192            .map_err(|e| StoreError::Database(e.to_string()))?;
1193            Ok(())
1194        })
1195        .await
1196        .map_err(|e| StoreError::Database(e.to_string()))??;
1197        Ok(())
1198    }
1199
1200    async fn put_sender_key(&self, address: &str, record: &[u8]) -> Result<()> {
1201        self.put_sender_key_for_device(address, record, self.device_id)
1202            .await
1203    }
1204
1205    async fn get_sender_key(&self, address: &str) -> Result<Option<Vec<u8>>> {
1206        self.get_sender_key_for_device(address, self.device_id)
1207            .await
1208    }
1209
1210    async fn delete_sender_key(&self, address: &str) -> Result<()> {
1211        self.delete_sender_key_for_device(address, self.device_id)
1212            .await
1213    }
1214}
1215
1216#[async_trait]
1217impl AppSyncStore for SqliteStore {
1218    async fn get_sync_key(&self, key_id: &[u8]) -> Result<Option<AppStateSyncKey>> {
1219        self.get_app_state_sync_key_for_device(key_id, self.device_id)
1220            .await
1221    }
1222
1223    async fn set_sync_key(&self, key_id: &[u8], key: AppStateSyncKey) -> Result<()> {
1224        self.set_app_state_sync_key_for_device(key_id, key, self.device_id)
1225            .await
1226    }
1227
1228    async fn get_version(&self, name: &str) -> Result<HashState> {
1229        self.get_app_state_version_for_device(name, self.device_id)
1230            .await
1231    }
1232
1233    async fn set_version(&self, name: &str, state: HashState) -> Result<()> {
1234        self.set_app_state_version_for_device(name, state, self.device_id)
1235            .await
1236    }
1237
1238    async fn put_mutation_macs(
1239        &self,
1240        name: &str,
1241        version: u64,
1242        mutations: &[AppStateMutationMAC],
1243    ) -> Result<()> {
1244        self.put_app_state_mutation_macs_for_device(name, version, mutations, self.device_id)
1245            .await
1246    }
1247
1248    async fn get_mutation_mac(&self, name: &str, index_mac: &[u8]) -> Result<Option<Vec<u8>>> {
1249        self.get_app_state_mutation_mac_for_device(name, index_mac, self.device_id)
1250            .await
1251    }
1252
1253    async fn delete_mutation_macs(&self, name: &str, index_macs: &[Vec<u8>]) -> Result<()> {
1254        self.delete_app_state_mutation_macs_for_device(name, index_macs, self.device_id)
1255            .await
1256    }
1257}
1258
1259#[async_trait]
1260impl ProtocolStore for SqliteStore {
1261    async fn get_skdm_recipients(&self, group_jid: &str) -> Result<Vec<String>> {
1262        let pool = self.pool.clone();
1263        let device_id = self.device_id;
1264        let group_jid = group_jid.to_string();
1265        tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
1266            let mut conn = pool
1267                .get()
1268                .map_err(|e| StoreError::Connection(e.to_string()))?;
1269            let recipients: Vec<String> = skdm_recipients::table
1270                .select(skdm_recipients::device_jid)
1271                .filter(skdm_recipients::group_jid.eq(&group_jid))
1272                .filter(skdm_recipients::device_id.eq(device_id))
1273                .load(&mut conn)
1274                .map_err(|e| StoreError::Database(e.to_string()))?;
1275            Ok(recipients)
1276        })
1277        .await
1278        .map_err(|e| StoreError::Database(e.to_string()))?
1279    }
1280
1281    async fn add_skdm_recipients(&self, group_jid: &str, device_jids: &[String]) -> Result<()> {
1282        if device_jids.is_empty() {
1283            return Ok(());
1284        }
1285        let pool = self.pool.clone();
1286        let device_id = self.device_id;
1287        let group_jid = group_jid.to_string();
1288        let device_jids: Vec<String> = device_jids.to_vec();
1289        let now = std::time::SystemTime::now()
1290            .duration_since(std::time::UNIX_EPOCH)
1291            .unwrap_or_default()
1292            .as_secs() as i32;
1293        tokio::task::spawn_blocking(move || -> Result<()> {
1294            let mut conn = pool
1295                .get()
1296                .map_err(|e| StoreError::Connection(e.to_string()))?;
1297            for device_jid in device_jids {
1298                diesel::insert_into(skdm_recipients::table)
1299                    .values((
1300                        skdm_recipients::group_jid.eq(&group_jid),
1301                        skdm_recipients::device_jid.eq(&device_jid),
1302                        skdm_recipients::device_id.eq(device_id),
1303                        skdm_recipients::created_at.eq(now),
1304                    ))
1305                    .on_conflict((
1306                        skdm_recipients::group_jid,
1307                        skdm_recipients::device_jid,
1308                        skdm_recipients::device_id,
1309                    ))
1310                    .do_nothing()
1311                    .execute(&mut conn)
1312                    .map_err(|e| StoreError::Database(e.to_string()))?;
1313            }
1314            Ok(())
1315        })
1316        .await
1317        .map_err(|e| StoreError::Database(e.to_string()))??;
1318        Ok(())
1319    }
1320
1321    async fn clear_skdm_recipients(&self, group_jid: &str) -> Result<()> {
1322        let pool = self.pool.clone();
1323        let device_id = self.device_id;
1324        let group_jid = group_jid.to_string();
1325        tokio::task::spawn_blocking(move || -> Result<()> {
1326            let mut conn = pool
1327                .get()
1328                .map_err(|e| StoreError::Connection(e.to_string()))?;
1329            diesel::delete(
1330                skdm_recipients::table
1331                    .filter(skdm_recipients::group_jid.eq(&group_jid))
1332                    .filter(skdm_recipients::device_id.eq(device_id)),
1333            )
1334            .execute(&mut conn)
1335            .map_err(|e| StoreError::Database(e.to_string()))?;
1336            Ok(())
1337        })
1338        .await
1339        .map_err(|e| StoreError::Database(e.to_string()))??;
1340        Ok(())
1341    }
1342
1343    async fn get_lid_mapping(&self, lid: &str) -> Result<Option<LidPnMappingEntry>> {
1344        let pool = self.pool.clone();
1345        let device_id = self.device_id;
1346        let lid = lid.to_string();
1347        tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1348            let mut conn = pool
1349                .get()
1350                .map_err(|e| StoreError::Connection(e.to_string()))?;
1351            let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1352                .select((
1353                    lid_pn_mapping::lid,
1354                    lid_pn_mapping::phone_number,
1355                    lid_pn_mapping::created_at,
1356                    lid_pn_mapping::learning_source,
1357                    lid_pn_mapping::updated_at,
1358                ))
1359                .filter(lid_pn_mapping::lid.eq(&lid))
1360                .filter(lid_pn_mapping::device_id.eq(device_id))
1361                .first(&mut conn)
1362                .optional()
1363                .map_err(|e| StoreError::Database(e.to_string()))?;
1364            Ok(row.map(
1365                |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1366                    lid,
1367                    phone_number,
1368                    created_at,
1369                    updated_at,
1370                    learning_source,
1371                },
1372            ))
1373        })
1374        .await
1375        .map_err(|e| StoreError::Database(e.to_string()))?
1376    }
1377
1378    async fn get_pn_mapping(&self, phone: &str) -> Result<Option<LidPnMappingEntry>> {
1379        let pool = self.pool.clone();
1380        let device_id = self.device_id;
1381        let phone = phone.to_string();
1382        tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1383            let mut conn = pool
1384                .get()
1385                .map_err(|e| StoreError::Connection(e.to_string()))?;
1386            let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1387                .select((
1388                    lid_pn_mapping::lid,
1389                    lid_pn_mapping::phone_number,
1390                    lid_pn_mapping::created_at,
1391                    lid_pn_mapping::learning_source,
1392                    lid_pn_mapping::updated_at,
1393                ))
1394                .filter(lid_pn_mapping::phone_number.eq(&phone))
1395                .filter(lid_pn_mapping::device_id.eq(device_id))
1396                .order(lid_pn_mapping::updated_at.desc())
1397                .first(&mut conn)
1398                .optional()
1399                .map_err(|e| StoreError::Database(e.to_string()))?;
1400            Ok(row.map(
1401                |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1402                    lid,
1403                    phone_number,
1404                    created_at,
1405                    updated_at,
1406                    learning_source,
1407                },
1408            ))
1409        })
1410        .await
1411        .map_err(|e| StoreError::Database(e.to_string()))?
1412    }
1413
1414    async fn put_lid_mapping(&self, entry: &LidPnMappingEntry) -> Result<()> {
1415        let pool = self.pool.clone();
1416        let device_id = self.device_id;
1417        let entry = entry.clone();
1418        tokio::task::spawn_blocking(move || -> Result<()> {
1419            let mut conn = pool
1420                .get()
1421                .map_err(|e| StoreError::Connection(e.to_string()))?;
1422            diesel::insert_into(lid_pn_mapping::table)
1423                .values((
1424                    lid_pn_mapping::lid.eq(&entry.lid),
1425                    lid_pn_mapping::phone_number.eq(&entry.phone_number),
1426                    lid_pn_mapping::created_at.eq(entry.created_at),
1427                    lid_pn_mapping::learning_source.eq(&entry.learning_source),
1428                    lid_pn_mapping::updated_at.eq(entry.updated_at),
1429                    lid_pn_mapping::device_id.eq(device_id),
1430                ))
1431                .on_conflict((lid_pn_mapping::lid, lid_pn_mapping::device_id))
1432                .do_update()
1433                .set((
1434                    lid_pn_mapping::phone_number.eq(&entry.phone_number),
1435                    lid_pn_mapping::learning_source.eq(&entry.learning_source),
1436                    lid_pn_mapping::updated_at.eq(entry.updated_at),
1437                ))
1438                .execute(&mut conn)
1439                .map_err(|e| StoreError::Database(e.to_string()))?;
1440            Ok(())
1441        })
1442        .await
1443        .map_err(|e| StoreError::Database(e.to_string()))??;
1444        Ok(())
1445    }
1446
1447    async fn get_all_lid_mappings(&self) -> Result<Vec<LidPnMappingEntry>> {
1448        let pool = self.pool.clone();
1449        let device_id = self.device_id;
1450        tokio::task::spawn_blocking(move || -> Result<Vec<LidPnMappingEntry>> {
1451            let mut conn = pool
1452                .get()
1453                .map_err(|e| StoreError::Connection(e.to_string()))?;
1454            let rows: Vec<(String, String, i64, String, i64)> = lid_pn_mapping::table
1455                .select((
1456                    lid_pn_mapping::lid,
1457                    lid_pn_mapping::phone_number,
1458                    lid_pn_mapping::created_at,
1459                    lid_pn_mapping::learning_source,
1460                    lid_pn_mapping::updated_at,
1461                ))
1462                .filter(lid_pn_mapping::device_id.eq(device_id))
1463                .load(&mut conn)
1464                .map_err(|e| StoreError::Database(e.to_string()))?;
1465            Ok(rows
1466                .into_iter()
1467                .map(
1468                    |(lid, phone_number, created_at, learning_source, updated_at)| {
1469                        LidPnMappingEntry {
1470                            lid,
1471                            phone_number,
1472                            created_at,
1473                            updated_at,
1474                            learning_source,
1475                        }
1476                    },
1477                )
1478                .collect())
1479        })
1480        .await
1481        .map_err(|e| StoreError::Database(e.to_string()))?
1482    }
1483
1484    async fn save_base_key(&self, address: &str, message_id: &str, base_key: &[u8]) -> Result<()> {
1485        let pool = self.pool.clone();
1486        let device_id = self.device_id;
1487        let address = address.to_string();
1488        let message_id = message_id.to_string();
1489        let base_key = base_key.to_vec();
1490        let now = std::time::SystemTime::now()
1491            .duration_since(std::time::UNIX_EPOCH)
1492            .unwrap_or_default()
1493            .as_secs() as i32;
1494        tokio::task::spawn_blocking(move || -> Result<()> {
1495            let mut conn = pool
1496                .get()
1497                .map_err(|e| StoreError::Connection(e.to_string()))?;
1498            diesel::insert_into(base_keys::table)
1499                .values((
1500                    base_keys::address.eq(&address),
1501                    base_keys::message_id.eq(&message_id),
1502                    base_keys::base_key.eq(&base_key),
1503                    base_keys::device_id.eq(device_id),
1504                    base_keys::created_at.eq(now),
1505                ))
1506                .on_conflict((
1507                    base_keys::address,
1508                    base_keys::message_id,
1509                    base_keys::device_id,
1510                ))
1511                .do_update()
1512                .set(base_keys::base_key.eq(&base_key))
1513                .execute(&mut conn)
1514                .map_err(|e| StoreError::Database(e.to_string()))?;
1515            Ok(())
1516        })
1517        .await
1518        .map_err(|e| StoreError::Database(e.to_string()))??;
1519        Ok(())
1520    }
1521
1522    async fn has_same_base_key(
1523        &self,
1524        address: &str,
1525        message_id: &str,
1526        current_base_key: &[u8],
1527    ) -> Result<bool> {
1528        let pool = self.pool.clone();
1529        let device_id = self.device_id;
1530        let address = address.to_string();
1531        let message_id = message_id.to_string();
1532        let current_base_key = current_base_key.to_vec();
1533        tokio::task::spawn_blocking(move || -> Result<bool> {
1534            let mut conn = pool
1535                .get()
1536                .map_err(|e| StoreError::Connection(e.to_string()))?;
1537            let stored_key: Option<Vec<u8>> = base_keys::table
1538                .select(base_keys::base_key)
1539                .filter(base_keys::address.eq(&address))
1540                .filter(base_keys::message_id.eq(&message_id))
1541                .filter(base_keys::device_id.eq(device_id))
1542                .first(&mut conn)
1543                .optional()
1544                .map_err(|e| StoreError::Database(e.to_string()))?;
1545            Ok(stored_key.as_ref() == Some(&current_base_key))
1546        })
1547        .await
1548        .map_err(|e| StoreError::Database(e.to_string()))?
1549    }
1550
1551    async fn delete_base_key(&self, address: &str, message_id: &str) -> Result<()> {
1552        let pool = self.pool.clone();
1553        let device_id = self.device_id;
1554        let address = address.to_string();
1555        let message_id = message_id.to_string();
1556        tokio::task::spawn_blocking(move || -> Result<()> {
1557            let mut conn = pool
1558                .get()
1559                .map_err(|e| StoreError::Connection(e.to_string()))?;
1560            diesel::delete(
1561                base_keys::table
1562                    .filter(base_keys::address.eq(&address))
1563                    .filter(base_keys::message_id.eq(&message_id))
1564                    .filter(base_keys::device_id.eq(device_id)),
1565            )
1566            .execute(&mut conn)
1567            .map_err(|e| StoreError::Database(e.to_string()))?;
1568            Ok(())
1569        })
1570        .await
1571        .map_err(|e| StoreError::Database(e.to_string()))??;
1572        Ok(())
1573    }
1574
1575    async fn update_device_list(&self, record: DeviceListRecord) -> Result<()> {
1576        let pool = self.pool.clone();
1577        let device_id = self.device_id;
1578        let devices_json = serde_json::to_string(&record.devices)
1579            .map_err(|e| StoreError::Serialization(e.to_string()))?;
1580        let now = std::time::SystemTime::now()
1581            .duration_since(std::time::UNIX_EPOCH)
1582            .unwrap_or_default()
1583            .as_secs() as i32;
1584        tokio::task::spawn_blocking(move || -> Result<()> {
1585            let mut conn = pool
1586                .get()
1587                .map_err(|e| StoreError::Connection(e.to_string()))?;
1588            diesel::insert_into(device_registry::table)
1589                .values((
1590                    device_registry::user_id.eq(&record.user),
1591                    device_registry::devices_json.eq(&devices_json),
1592                    device_registry::timestamp.eq(record.timestamp as i32),
1593                    device_registry::phash.eq(&record.phash),
1594                    device_registry::device_id.eq(device_id),
1595                    device_registry::updated_at.eq(now),
1596                ))
1597                .on_conflict((device_registry::user_id, device_registry::device_id))
1598                .do_update()
1599                .set((
1600                    device_registry::devices_json.eq(&devices_json),
1601                    device_registry::timestamp.eq(record.timestamp as i32),
1602                    device_registry::phash.eq(&record.phash),
1603                    device_registry::updated_at.eq(now),
1604                ))
1605                .execute(&mut conn)
1606                .map_err(|e| StoreError::Database(e.to_string()))?;
1607            Ok(())
1608        })
1609        .await
1610        .map_err(|e| StoreError::Database(e.to_string()))??;
1611        Ok(())
1612    }
1613
1614    async fn get_devices(&self, user: &str) -> Result<Option<DeviceListRecord>> {
1615        let pool = self.pool.clone();
1616        let device_id = self.device_id;
1617        let user = user.to_string();
1618        tokio::task::spawn_blocking(move || -> Result<Option<DeviceListRecord>> {
1619            let mut conn = pool
1620                .get()
1621                .map_err(|e| StoreError::Connection(e.to_string()))?;
1622            let row: Option<(String, String, i32, Option<String>)> = device_registry::table
1623                .select((
1624                    device_registry::user_id,
1625                    device_registry::devices_json,
1626                    device_registry::timestamp,
1627                    device_registry::phash,
1628                ))
1629                .filter(device_registry::user_id.eq(&user))
1630                .filter(device_registry::device_id.eq(device_id))
1631                .first(&mut conn)
1632                .optional()
1633                .map_err(|e| StoreError::Database(e.to_string()))?;
1634            match row {
1635                Some((user, devices_json, timestamp, phash)) => {
1636                    let devices: Vec<DeviceInfo> = serde_json::from_str(&devices_json)
1637                        .map_err(|e| StoreError::Serialization(e.to_string()))?;
1638                    Ok(Some(DeviceListRecord {
1639                        user,
1640                        devices,
1641                        timestamp: timestamp as i64,
1642                        phash,
1643                    }))
1644                }
1645                None => Ok(None),
1646            }
1647        })
1648        .await
1649        .map_err(|e| StoreError::Database(e.to_string()))?
1650    }
1651
1652    async fn mark_forget_sender_key(&self, group_jid: &str, participant: &str) -> Result<()> {
1653        let pool = self.pool.clone();
1654        let device_id = self.device_id;
1655        let group_jid = group_jid.to_string();
1656        let participant = participant.to_string();
1657        let now = std::time::SystemTime::now()
1658            .duration_since(std::time::UNIX_EPOCH)
1659            .unwrap_or_default()
1660            .as_secs() as i32;
1661        tokio::task::spawn_blocking(move || -> Result<()> {
1662            let mut conn = pool
1663                .get()
1664                .map_err(|e| StoreError::Connection(e.to_string()))?;
1665            diesel::insert_into(sender_key_status::table)
1666                .values((
1667                    sender_key_status::group_jid.eq(&group_jid),
1668                    sender_key_status::participant.eq(&participant),
1669                    sender_key_status::device_id.eq(device_id),
1670                    sender_key_status::marked_at.eq(now),
1671                ))
1672                .on_conflict((
1673                    sender_key_status::group_jid,
1674                    sender_key_status::participant,
1675                    sender_key_status::device_id,
1676                ))
1677                .do_update()
1678                .set(sender_key_status::marked_at.eq(now))
1679                .execute(&mut conn)
1680                .map_err(|e| StoreError::Database(e.to_string()))?;
1681            Ok(())
1682        })
1683        .await
1684        .map_err(|e| StoreError::Database(e.to_string()))??;
1685        Ok(())
1686    }
1687
1688    async fn consume_forget_marks(&self, group_jid: &str) -> Result<Vec<String>> {
1689        let pool = self.pool.clone();
1690        let device_id = self.device_id;
1691        let group_jid = group_jid.to_string();
1692        tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
1693            let mut conn = pool
1694                .get()
1695                .map_err(|e| StoreError::Connection(e.to_string()))?;
1696            let participants: Vec<String> = sender_key_status::table
1697                .select(sender_key_status::participant)
1698                .filter(sender_key_status::group_jid.eq(&group_jid))
1699                .filter(sender_key_status::device_id.eq(device_id))
1700                .load(&mut conn)
1701                .map_err(|e| StoreError::Database(e.to_string()))?;
1702            diesel::delete(
1703                sender_key_status::table
1704                    .filter(sender_key_status::group_jid.eq(&group_jid))
1705                    .filter(sender_key_status::device_id.eq(device_id)),
1706            )
1707            .execute(&mut conn)
1708            .map_err(|e| StoreError::Database(e.to_string()))?;
1709            Ok(participants)
1710        })
1711        .await
1712        .map_err(|e| StoreError::Database(e.to_string()))?
1713    }
1714}
1715
1716#[async_trait]
1717impl DeviceStore for SqliteStore {
1718    async fn save(&self, device: &CoreDevice) -> Result<()> {
1719        SqliteStore::save_device_data_for_device(self, self.device_id, device).await
1720    }
1721
1722    async fn load(&self) -> Result<Option<CoreDevice>> {
1723        SqliteStore::load_device_data_for_device(self, self.device_id).await
1724    }
1725
1726    async fn exists(&self) -> Result<bool> {
1727        SqliteStore::device_exists(self, self.device_id).await
1728    }
1729
1730    async fn create(&self) -> Result<i32> {
1731        SqliteStore::create_new_device(self).await
1732    }
1733}
1734
/// Round-trip tests for the device registry and the sender-key "forget" marks,
/// each run against a fresh `:memory:` store.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store backed by an in-memory SQLite database.
    async fn create_test_store() -> SqliteStore {
        SqliteStore::new(":memory:")
            .await
            .expect("Failed to create test store")
    }

    /// A saved device list comes back with every field intact.
    #[tokio::test]
    async fn test_device_registry_save_and_get() {
        let store = create_test_store().await;

        let record = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![
                DeviceInfo {
                    device_id: 0,
                    key_index: None,
                },
                DeviceInfo {
                    device_id: 1,
                    key_index: Some(42),
                },
            ],
            timestamp: 1234567890,
            phash: Some("2:abcdef".to_string()),
        };

        store.update_device_list(record).await.expect("save failed");
        let loaded = store
            .get_devices("1234567890")
            .await
            .expect("get failed")
            .expect("record should exist");

        assert_eq!(loaded.user, "1234567890");
        assert_eq!(loaded.devices.len(), 2);
        assert_eq!(loaded.devices[0].device_id, 0);
        assert_eq!(loaded.devices[0].key_index, None);
        assert_eq!(loaded.devices[1].device_id, 1);
        assert_eq!(loaded.devices[1].key_index, Some(42));
        // The timestamp passes through a 32-bit column; make sure it survives.
        assert_eq!(loaded.timestamp, 1234567890);
        assert_eq!(loaded.phash, Some("2:abcdef".to_string()));
    }

    /// Saving a second record for the same user replaces the first one.
    #[tokio::test]
    async fn test_device_registry_update_existing() {
        let store = create_test_store().await;

        let record1 = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![DeviceInfo {
                device_id: 0,
                key_index: None,
            }],
            timestamp: 1000,
            phash: Some("2:old".to_string()),
        };
        store
            .update_device_list(record1)
            .await
            .expect("save1 failed");

        let record2 = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![
                DeviceInfo {
                    device_id: 0,
                    key_index: None,
                },
                DeviceInfo {
                    device_id: 2,
                    key_index: None,
                },
            ],
            timestamp: 2000,
            phash: Some("2:new".to_string()),
        };
        store
            .update_device_list(record2)
            .await
            .expect("save2 failed");

        let loaded = store
            .get_devices("1234567890")
            .await
            .expect("get failed")
            .expect("record should exist");

        assert_eq!(loaded.devices.len(), 2);
        assert_eq!(loaded.devices[1].device_id, 2);
        // Every updatable column must reflect the second record.
        assert_eq!(loaded.timestamp, 2000);
        assert_eq!(loaded.phash, Some("2:new".to_string()));
    }

    /// Looking up an unknown user yields `None`, not an error.
    #[tokio::test]
    async fn test_device_registry_get_nonexistent() {
        let store = create_test_store().await;
        let result = store.get_devices("nonexistent").await.expect("get failed");
        assert!(result.is_none());
    }

    /// A mark is returned once and is gone on the next consume.
    #[tokio::test]
    async fn test_sender_key_status_mark_and_consume() {
        let store = create_test_store().await;

        let group = "group123@g.us";
        let participant = "user1@s.whatsapp.net";

        store
            .mark_forget_sender_key(group, participant)
            .await
            .expect("mark failed");

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed.len(), 1);
        assert!(consumed.contains(&participant.to_string()));

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(consumed.is_empty());
    }

    /// All marks of a group are returned together and then cleared.
    #[tokio::test]
    async fn test_sender_key_status_consume_multiple() {
        let store = create_test_store().await;

        let group = "group123@g.us";

        store
            .mark_forget_sender_key(group, "user1@s.whatsapp.net")
            .await
            .expect("mark failed");
        store
            .mark_forget_sender_key(group, "user2@s.whatsapp.net")
            .await
            .expect("mark failed");

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed.len(), 2);
        assert!(consumed.contains(&"user1@s.whatsapp.net".to_string()));
        assert!(consumed.contains(&"user2@s.whatsapp.net".to_string()));

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(consumed.is_empty());
    }

    /// A mark in one group is not visible when consuming another group.
    #[tokio::test]
    async fn test_sender_key_status_different_groups() {
        let store = create_test_store().await;

        let group1 = "group1@g.us";
        let group2 = "group2@g.us";
        let participant = "user@s.whatsapp.net";

        store
            .mark_forget_sender_key(group1, participant)
            .await
            .expect("mark failed");

        let consumed = store.consume_forget_marks(group1).await.unwrap();
        assert_eq!(consumed.len(), 1);

        let consumed = store.consume_forget_marks(group2).await.unwrap();
        assert!(consumed.is_empty());
    }

    /// Consuming one group must leave the marks of other groups in place.
    #[tokio::test]
    async fn test_sender_key_status_consume_leaves_other_groups() {
        let store = create_test_store().await;

        let group1 = "group1@g.us";
        let group2 = "group2@g.us";
        let participant = "user@s.whatsapp.net";

        store
            .mark_forget_sender_key(group1, participant)
            .await
            .expect("mark failed");
        store
            .mark_forget_sender_key(group2, participant)
            .await
            .expect("mark failed");

        let consumed = store
            .consume_forget_marks(group1)
            .await
            .expect("consume failed");
        assert_eq!(consumed, vec![participant.to_string()]);

        // The delete for group1 must not have touched group2's mark.
        let consumed = store
            .consume_forget_marks(group2)
            .await
            .expect("consume failed");
        assert_eq!(consumed, vec![participant.to_string()]);
    }
}