Skip to main content

whatsapp_rust_sqlite_storage/
sqlite_store.rs

1use crate::schema::*;
2use async_trait::async_trait;
3use diesel::prelude::*;
4use diesel::r2d2::{ConnectionManager, Pool};
5use diesel::sql_query;
6use diesel::sqlite::SqliteConnection;
7use diesel::upsert::excluded;
8use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
9use log::warn;
10use prost::Message;
11use std::sync::Arc;
12use wacore::appstate::hash::HashState;
13use wacore::appstate::processor::AppStateMutationMAC;
14use wacore::libsignal::protocol::{KeyPair, PrivateKey, PublicKey};
15use wacore::store::Device as CoreDevice;
16use wacore::store::error::{Result, StoreError};
17use wacore::store::traits::*;
18use wacore_binary::jid::Jid;
19use waproto::whatsapp as wa;
20
/// Diesel migrations embedded into the binary from the crate's `migrations` directory;
/// `SqliteStore::new` applies any that are still pending via `run_pending_migrations`.
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
22
/// r2d2 pool of diesel SQLite connections shared by all `SqliteStore` clones.
type SqlitePool = Pool<ConnectionManager<SqliteConnection>>;
/// One complete row of the `device` table, in column order, as loaded with
/// `device::table.first::<DeviceRow>()` in `load_device_data_for_device`.
type DeviceRow = (
    i32,             // id
    String,          // lid (empty string when the device has no LID)
    String,          // pn (empty string when the device has no phone-number JID)
    i32,             // registration_id
    Vec<u8>,         // noise_key (64 bytes, `serialize_keypair` encoding)
    Vec<u8>,         // identity_key (64 bytes, `serialize_keypair` encoding)
    Vec<u8>,         // signed_pre_key (64 bytes, `serialize_keypair` encoding)
    i32,             // signed_pre_key_id
    Vec<u8>,         // signed_pre_key_signature (must be 64 bytes)
    Vec<u8>,         // adv_secret_key (must be 32 bytes)
    Option<Vec<u8>>, // account (protobuf-encoded `AdvSignedDeviceIdentity`)
    String,          // push_name
    i32,             // app_version_primary
    i32,             // app_version_secondary
    i64,             // app_version_tertiary
    i64,             // app_version_last_fetched_ms
    Option<Vec<u8>>, // edge_routing_info
    Option<String>,  // props_hash
);
44
/// SQLite-backed storage for wacore device data, identity keys, sessions, sender keys and
/// app-state sync keys.
///
/// Clones share the same connection pool and the same semaphore.
#[derive(Clone)]
pub struct SqliteStore {
    /// Connection pool, built by `SqliteStore::new` with at most two connections.
    pub(crate) pool: SqlitePool,
    /// Single-permit semaphore (`Semaphore::new(1)` in `new`) that serializes the database
    /// calls routed through it: identity/session writes and `with_semaphore` callers.
    pub(crate) db_semaphore: Arc<tokio::sync::Semaphore>,
    /// Database file path derived from the connection URL by `parse_database_path`.
    pub(crate) database_path: String,
    /// Id of the `device` row this store is bound to (`new` uses 1, `new_for_device`
    /// overrides it).
    device_id: i32,
}
52
53#[derive(Debug, Clone, Copy)]
54struct ConnectionOptions;
55
56impl diesel::r2d2::CustomizeConnection<SqliteConnection, diesel::r2d2::Error>
57    for ConnectionOptions
58{
59    fn on_acquire(
60        &self,
61        conn: &mut SqliteConnection,
62    ) -> std::result::Result<(), diesel::r2d2::Error> {
63        diesel::sql_query("PRAGMA busy_timeout = 30000;")
64            .execute(conn)
65            .map_err(diesel::r2d2::Error::QueryError)?;
66        diesel::sql_query("PRAGMA synchronous = NORMAL;")
67            .execute(conn)
68            .map_err(diesel::r2d2::Error::QueryError)?;
69        diesel::sql_query("PRAGMA cache_size = 512;")
70            .execute(conn)
71            .map_err(diesel::r2d2::Error::QueryError)?;
72        diesel::sql_query("PRAGMA temp_store = memory;")
73            .execute(conn)
74            .map_err(diesel::r2d2::Error::QueryError)?;
75        diesel::sql_query("PRAGMA foreign_keys = ON;")
76            .execute(conn)
77            .map_err(diesel::r2d2::Error::QueryError)?;
78        Ok(())
79    }
80}
81
82fn parse_database_path(database_url: &str) -> Result<String> {
83    // Reject in-memory databases
84    if database_url == ":memory:" {
85        return Err(StoreError::Database(
86            "Snapshot not supported for in-memory databases".to_string(),
87        ));
88    }
89
90    // Strip query string and fragment
91    let path = database_url
92        .split(['?', '#'])
93        .next()
94        .unwrap_or(database_url);
95
96    // Remove sqlite:// prefix if present
97    let path = path.trim_start_matches("sqlite://");
98
99    // Check if the resulting path looks like an in-memory marker
100    if path == ":memory:" || path.starts_with(":memory:?") {
101        return Err(StoreError::Database(
102            "Snapshot not supported for in-memory databases".to_string(),
103        ));
104    }
105
106    Ok(path.to_string())
107}
108
109impl SqliteStore {
110    pub async fn new(database_url: &str) -> std::result::Result<Self, StoreError> {
111        let manager = ConnectionManager::<SqliteConnection>::new(database_url);
112
113        let pool_size = 2;
114
115        let pool = Pool::builder()
116            .max_size(pool_size)
117            .connection_customizer(Box::new(ConnectionOptions))
118            .build(manager)
119            .map_err(|e| StoreError::Connection(e.to_string()))?;
120
121        let pool_clone = pool.clone();
122        tokio::task::spawn_blocking(move || -> std::result::Result<(), StoreError> {
123            let mut conn = pool_clone
124                .get()
125                .map_err(|e| StoreError::Connection(e.to_string()))?;
126
127            diesel::sql_query("PRAGMA journal_mode = WAL;")
128                .execute(&mut conn)
129                .map_err(|e| StoreError::Database(e.to_string()))?;
130
131            conn.run_pending_migrations(MIGRATIONS)
132                .map_err(|e| StoreError::Migration(e.to_string()))?;
133
134            Ok(())
135        })
136        .await
137        .map_err(|e| StoreError::Database(e.to_string()))??;
138
139        let database_path = parse_database_path(database_url)?;
140
141        Ok(Self {
142            pool,
143            db_semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
144            database_path,
145            device_id: 1,
146        })
147    }
148
149    pub async fn new_for_device(
150        database_url: &str,
151        device_id: i32,
152    ) -> std::result::Result<Self, StoreError> {
153        let mut store = Self::new(database_url).await?;
154        store.device_id = device_id;
155        Ok(store)
156    }
157
158    pub fn device_id(&self) -> i32 {
159        self.device_id
160    }
161
162    async fn with_semaphore<F, T>(&self, f: F) -> Result<T>
163    where
164        F: FnOnce() -> Result<T> + Send + 'static,
165        T: Send + 'static,
166    {
167        let permit = self
168            .db_semaphore
169            .clone()
170            .acquire_owned()
171            .await
172            .map_err(|e| StoreError::Database(format!("Semaphore error: {}", e)))?;
173        let result = tokio::task::spawn_blocking(move || {
174            let res = f();
175            drop(permit);
176            res
177        })
178        .await
179        .map_err(|e| StoreError::Database(e.to_string()))??;
180        Ok(result)
181    }
182
183    fn serialize_keypair(&self, key_pair: &KeyPair) -> Result<Vec<u8>> {
184        let mut bytes = Vec::with_capacity(64);
185        bytes.extend_from_slice(key_pair.private_key.serialize());
186        bytes.extend_from_slice(key_pair.public_key.public_key_bytes());
187        Ok(bytes)
188    }
189
190    fn deserialize_keypair(&self, bytes: &[u8]) -> Result<KeyPair> {
191        if bytes.len() != 64 {
192            return Err(StoreError::Serialization(format!(
193                "Invalid KeyPair length: {}",
194                bytes.len()
195            )));
196        }
197
198        let private_key = PrivateKey::deserialize(&bytes[0..32])
199            .map_err(|e| StoreError::Serialization(e.to_string()))?;
200        let public_key = PublicKey::from_djb_public_key_bytes(&bytes[32..64])
201            .map_err(|e| StoreError::Serialization(e.to_string()))?;
202
203        Ok(KeyPair::new(public_key, private_key))
204    }
205
206    pub async fn save_device_data_for_device(
207        &self,
208        device_id: i32,
209        device_data: &CoreDevice,
210    ) -> Result<()> {
211        let pool = self.pool.clone();
212        let noise_key_data = self.serialize_keypair(&device_data.noise_key)?;
213        let identity_key_data = self.serialize_keypair(&device_data.identity_key)?;
214        let signed_pre_key_data = self.serialize_keypair(&device_data.signed_pre_key)?;
215        let account_data = device_data
216            .account
217            .as_ref()
218            .map(|account| account.encode_to_vec());
219        let registration_id = device_data.registration_id as i32;
220        let signed_pre_key_id = device_data.signed_pre_key_id as i32;
221        let signed_pre_key_signature: Vec<u8> = device_data.signed_pre_key_signature.to_vec();
222        let adv_secret_key: Vec<u8> = device_data.adv_secret_key.to_vec();
223        let push_name = device_data.push_name.clone();
224        let app_version_primary = device_data.app_version_primary as i32;
225        let app_version_secondary = device_data.app_version_secondary as i32;
226        let app_version_tertiary = device_data.app_version_tertiary as i64;
227        let app_version_last_fetched_ms = device_data.app_version_last_fetched_ms;
228        let edge_routing_info = device_data.edge_routing_info.clone();
229        let props_hash = device_data.props_hash.clone();
230        let new_lid = device_data
231            .lid
232            .as_ref()
233            .map(|j| j.to_string())
234            .unwrap_or_default();
235        let new_pn = device_data
236            .pn
237            .as_ref()
238            .map(|j| j.to_string())
239            .unwrap_or_default();
240
241        tokio::task::spawn_blocking(move || -> Result<()> {
242            let mut conn = pool
243                .get()
244                .map_err(|e| StoreError::Connection(e.to_string()))?;
245
246            diesel::insert_into(device::table)
247                .values((
248                    device::id.eq(device_id),
249                    device::lid.eq(&new_lid),
250                    device::pn.eq(&new_pn),
251                    device::registration_id.eq(registration_id),
252                    device::noise_key.eq(&noise_key_data),
253                    device::identity_key.eq(&identity_key_data),
254                    device::signed_pre_key.eq(&signed_pre_key_data),
255                    device::signed_pre_key_id.eq(signed_pre_key_id),
256                    device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
257                    device::adv_secret_key.eq(&adv_secret_key[..]),
258                    device::account.eq(account_data.clone()),
259                    device::push_name.eq(&push_name),
260                    device::app_version_primary.eq(app_version_primary),
261                    device::app_version_secondary.eq(app_version_secondary),
262                    device::app_version_tertiary.eq(app_version_tertiary),
263                    device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
264                    device::edge_routing_info.eq(edge_routing_info.clone()),
265                    device::props_hash.eq(props_hash.clone()),
266                ))
267                .on_conflict(device::id)
268                .do_update()
269                .set((
270                    device::lid.eq(&new_lid),
271                    device::pn.eq(&new_pn),
272                    device::registration_id.eq(registration_id),
273                    device::noise_key.eq(&noise_key_data),
274                    device::identity_key.eq(&identity_key_data),
275                    device::signed_pre_key.eq(&signed_pre_key_data),
276                    device::signed_pre_key_id.eq(signed_pre_key_id),
277                    device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
278                    device::adv_secret_key.eq(&adv_secret_key[..]),
279                    device::account.eq(account_data.clone()),
280                    device::push_name.eq(&push_name),
281                    device::app_version_primary.eq(app_version_primary),
282                    device::app_version_secondary.eq(app_version_secondary),
283                    device::app_version_tertiary.eq(app_version_tertiary),
284                    device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
285                    device::edge_routing_info.eq(edge_routing_info),
286                    device::props_hash.eq(props_hash),
287                ))
288                .execute(&mut conn)
289                .map_err(|e| StoreError::Database(e.to_string()))?;
290
291            Ok(())
292        })
293        .await
294        .map_err(|e| StoreError::Database(e.to_string()))??;
295
296        Ok(())
297    }
298
299    pub async fn create_new_device(&self) -> Result<i32> {
300        use crate::schema::device;
301
302        let pool = self.pool.clone();
303        tokio::task::spawn_blocking(move || -> Result<i32> {
304            let mut conn = pool
305                .get()
306                .map_err(|e| StoreError::Connection(e.to_string()))?;
307
308            let new_device = wacore::store::Device::new();
309
310            let noise_key_data = {
311                let mut bytes = Vec::with_capacity(64);
312                bytes.extend_from_slice(new_device.noise_key.private_key.serialize());
313                bytes.extend_from_slice(new_device.noise_key.public_key.public_key_bytes());
314                bytes
315            };
316            let identity_key_data = {
317                let mut bytes = Vec::with_capacity(64);
318                bytes.extend_from_slice(new_device.identity_key.private_key.serialize());
319                bytes.extend_from_slice(new_device.identity_key.public_key.public_key_bytes());
320                bytes
321            };
322            let signed_pre_key_data = {
323                let mut bytes = Vec::with_capacity(64);
324                bytes.extend_from_slice(new_device.signed_pre_key.private_key.serialize());
325                bytes.extend_from_slice(new_device.signed_pre_key.public_key.public_key_bytes());
326                bytes
327            };
328
329            diesel::insert_into(device::table)
330                .values((
331                    device::lid.eq(""),
332                    device::pn.eq(""),
333                    device::registration_id.eq(new_device.registration_id as i32),
334                    device::noise_key.eq(&noise_key_data),
335                    device::identity_key.eq(&identity_key_data),
336                    device::signed_pre_key.eq(&signed_pre_key_data),
337                    device::signed_pre_key_id.eq(new_device.signed_pre_key_id as i32),
338                    device::signed_pre_key_signature.eq(&new_device.signed_pre_key_signature[..]),
339                    device::adv_secret_key.eq(&new_device.adv_secret_key[..]),
340                    device::account.eq(None::<Vec<u8>>),
341                    device::push_name.eq(&new_device.push_name),
342                    device::app_version_primary.eq(new_device.app_version_primary as i32),
343                    device::app_version_secondary.eq(new_device.app_version_secondary as i32),
344                    device::app_version_tertiary.eq(new_device.app_version_tertiary as i64),
345                    device::app_version_last_fetched_ms.eq(new_device.app_version_last_fetched_ms),
346                    device::edge_routing_info.eq(None::<Vec<u8>>),
347                    device::props_hash.eq(None::<String>),
348                ))
349                .execute(&mut conn)
350                .map_err(|e| StoreError::Database(e.to_string()))?;
351
352            use diesel::sql_types::Integer;
353
354            #[derive(QueryableByName)]
355            struct LastInsertedId {
356                #[diesel(sql_type = Integer)]
357                last_insert_rowid: i32,
358            }
359
360            let device_id: i32 = sql_query("SELECT last_insert_rowid() as last_insert_rowid")
361                .get_result::<LastInsertedId>(&mut conn)
362                .map_err(|e| StoreError::Database(e.to_string()))?
363                .last_insert_rowid;
364
365            Ok(device_id)
366        })
367        .await
368        .map_err(|e| StoreError::Database(e.to_string()))?
369    }
370
371    pub async fn device_exists(&self, device_id: i32) -> Result<bool> {
372        use crate::schema::device;
373
374        let pool = self.pool.clone();
375        tokio::task::spawn_blocking(move || -> Result<bool> {
376            let mut conn = pool
377                .get()
378                .map_err(|e| StoreError::Connection(e.to_string()))?;
379
380            let count: i64 = device::table
381                .filter(device::id.eq(device_id))
382                .count()
383                .get_result(&mut conn)
384                .map_err(|e| StoreError::Database(e.to_string()))?;
385
386            Ok(count > 0)
387        })
388        .await
389        .map_err(|e| StoreError::Database(e.to_string()))?
390    }
391
392    pub async fn load_device_data_for_device(&self, device_id: i32) -> Result<Option<CoreDevice>> {
393        use crate::schema::device;
394
395        let pool = self.pool.clone();
396        let row = tokio::task::spawn_blocking(move || -> Result<Option<DeviceRow>> {
397            let mut conn = pool
398                .get()
399                .map_err(|e| StoreError::Connection(e.to_string()))?;
400            let result = device::table
401                .filter(device::id.eq(device_id))
402                .first::<DeviceRow>(&mut conn)
403                .optional()
404                .map_err(|e| StoreError::Database(e.to_string()))?;
405            Ok(result)
406        })
407        .await
408        .map_err(|e| StoreError::Database(e.to_string()))??;
409
410        if let Some((
411            _device_id,
412            lid_str,
413            pn_str,
414            registration_id,
415            noise_key_data,
416            identity_key_data,
417            signed_pre_key_data,
418            signed_pre_key_id,
419            signed_pre_key_signature_data,
420            adv_secret_key_data,
421            account_data,
422            push_name,
423            app_version_primary,
424            app_version_secondary,
425            app_version_tertiary,
426            app_version_last_fetched_ms,
427            edge_routing_info,
428            props_hash,
429        )) = row
430        {
431            let id = if !pn_str.is_empty() {
432                pn_str.parse().ok()
433            } else {
434                None
435            };
436            let lid = if !lid_str.is_empty() {
437                lid_str.parse().ok()
438            } else {
439                None
440            };
441
442            let noise_key = self.deserialize_keypair(&noise_key_data)?;
443            let identity_key = self.deserialize_keypair(&identity_key_data)?;
444            let signed_pre_key = self.deserialize_keypair(&signed_pre_key_data)?;
445
446            let signed_pre_key_signature: [u8; 64] =
447                signed_pre_key_signature_data.try_into().map_err(|_| {
448                    StoreError::Serialization("Invalid signed_pre_key_signature length".to_string())
449                })?;
450
451            let adv_secret_key: [u8; 32] = adv_secret_key_data.try_into().map_err(|_| {
452                StoreError::Serialization("Invalid adv_secret_key length".to_string())
453            })?;
454
455            let account = account_data
456                .map(|data| {
457                    wa::AdvSignedDeviceIdentity::decode(&data[..])
458                        .map_err(|e| StoreError::Serialization(e.to_string()))
459                })
460                .transpose()?;
461
462            Ok(Some(CoreDevice {
463                pn: id,
464                lid,
465                registration_id: registration_id as u32,
466                noise_key,
467                identity_key,
468                signed_pre_key,
469                signed_pre_key_id: signed_pre_key_id as u32,
470                signed_pre_key_signature,
471                adv_secret_key,
472                account,
473                push_name,
474                app_version_primary: app_version_primary as u32,
475                app_version_secondary: app_version_secondary as u32,
476                app_version_tertiary: app_version_tertiary.try_into().unwrap_or(0u32),
477                app_version_last_fetched_ms,
478                device_props: {
479                    use wacore::store::device::DEVICE_PROPS;
480                    DEVICE_PROPS.clone()
481                },
482                edge_routing_info,
483                props_hash,
484            }))
485        } else {
486            Ok(None)
487        }
488    }
489
490    pub async fn put_identity_for_device(
491        &self,
492        address: &str,
493        key: [u8; 32],
494        device_id: i32,
495    ) -> Result<()> {
496        let pool = self.pool.clone();
497        let db_semaphore = self.db_semaphore.clone();
498        let address_owned = address.to_string();
499        let key_vec = key.to_vec();
500
501        const MAX_RETRIES: u32 = 5;
502
503        for attempt in 0..=MAX_RETRIES {
504            let permit =
505                db_semaphore.clone().acquire_owned().await.map_err(|e| {
506                    StoreError::Database(format!("Failed to acquire semaphore: {}", e))
507                })?;
508
509            let pool_clone = pool.clone();
510            let address_clone = address_owned.clone();
511            let key_clone = key_vec.clone();
512
513            let result = tokio::task::spawn_blocking(move || -> Result<()> {
514                let mut conn = pool_clone
515                    .get()
516                    .map_err(|e| StoreError::Connection(e.to_string()))?;
517                diesel::insert_into(identities::table)
518                    .values((
519                        identities::address.eq(address_clone),
520                        identities::key.eq(&key_clone[..]),
521                        identities::device_id.eq(device_id),
522                    ))
523                    .on_conflict((identities::address, identities::device_id))
524                    .do_update()
525                    .set(identities::key.eq(&key_clone[..]))
526                    .execute(&mut conn)
527                    .map_err(|e| StoreError::Database(e.to_string()))?;
528                Ok(())
529            })
530            .await;
531
532            drop(permit);
533
534            match result {
535                Ok(Ok(())) => return Ok(()),
536                Ok(Err(e)) => {
537                    let error_msg = e.to_string();
538                    if (error_msg.contains("locked") || error_msg.contains("busy"))
539                        && attempt < MAX_RETRIES
540                    {
541                        let delay_ms = 10 * 2u64.pow(attempt);
542                        warn!(
543                            "Identity write failed (attempt {}/{}): {}. Retrying in {}ms...",
544                            attempt + 1,
545                            MAX_RETRIES + 1,
546                            error_msg,
547                            delay_ms
548                        );
549                        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
550                        continue;
551                    }
552                    return Err(e);
553                }
554                Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
555            }
556        }
557
558        Err(StoreError::Database(format!(
559            "Identity write failed after {} attempts",
560            MAX_RETRIES + 1
561        )))
562    }
563
564    pub async fn delete_identity_for_device(&self, address: &str, device_id: i32) -> Result<()> {
565        let pool = self.pool.clone();
566        let address_owned = address.to_string();
567
568        tokio::task::spawn_blocking(move || -> Result<()> {
569            let mut conn = pool
570                .get()
571                .map_err(|e| StoreError::Connection(e.to_string()))?;
572            diesel::delete(
573                identities::table
574                    .filter(identities::address.eq(address_owned))
575                    .filter(identities::device_id.eq(device_id)),
576            )
577            .execute(&mut conn)
578            .map_err(|e| StoreError::Database(e.to_string()))?;
579            Ok(())
580        })
581        .await
582        .map_err(|e| StoreError::Database(e.to_string()))??;
583
584        Ok(())
585    }
586
587    pub async fn load_identity_for_device(
588        &self,
589        address: &str,
590        device_id: i32,
591    ) -> Result<Option<Vec<u8>>> {
592        let pool = self.pool.clone();
593        let address = address.to_string();
594        let result = self
595            .with_semaphore(move || -> Result<Option<Vec<u8>>> {
596                let mut conn = pool
597                    .get()
598                    .map_err(|e| StoreError::Connection(e.to_string()))?;
599                let res: Option<Vec<u8>> = identities::table
600                    .select(identities::key)
601                    .filter(identities::address.eq(address))
602                    .filter(identities::device_id.eq(device_id))
603                    .first(&mut conn)
604                    .optional()
605                    .map_err(|e| StoreError::Database(e.to_string()))?;
606                Ok(res)
607            })
608            .await?;
609
610        Ok(result)
611    }
612
613    pub async fn get_session_for_device(
614        &self,
615        address: &str,
616        device_id: i32,
617    ) -> Result<Option<Vec<u8>>> {
618        let pool = self.pool.clone();
619        let address_for_query = address.to_string();
620        let result = self
621            .with_semaphore(move || -> Result<Option<Vec<u8>>> {
622                let mut conn = pool
623                    .get()
624                    .map_err(|e| StoreError::Connection(e.to_string()))?;
625                let res: Option<Vec<u8>> = sessions::table
626                    .select(sessions::record)
627                    .filter(sessions::address.eq(address_for_query.clone()))
628                    .filter(sessions::device_id.eq(device_id))
629                    .first(&mut conn)
630                    .optional()
631                    .map_err(|e| StoreError::Database(e.to_string()))?;
632
633                Ok(res)
634            })
635            .await?;
636
637        Ok(result)
638    }
639
640    pub async fn put_session_for_device(
641        &self,
642        address: &str,
643        session: &[u8],
644        device_id: i32,
645    ) -> Result<()> {
646        let pool = self.pool.clone();
647        let db_semaphore = self.db_semaphore.clone();
648        let address_owned = address.to_string();
649        let session_vec = session.to_vec();
650
651        const MAX_RETRIES: u32 = 5;
652
653        for attempt in 0..=MAX_RETRIES {
654            let permit =
655                db_semaphore.clone().acquire_owned().await.map_err(|e| {
656                    StoreError::Database(format!("Failed to acquire semaphore: {}", e))
657                })?;
658
659            let pool_clone = pool.clone();
660            let address_clone = address_owned.clone();
661            let session_clone = session_vec.clone();
662
663            let result = tokio::task::spawn_blocking(move || -> Result<()> {
664                let mut conn = pool_clone
665                    .get()
666                    .map_err(|e| StoreError::Connection(e.to_string()))?;
667                diesel::insert_into(sessions::table)
668                    .values((
669                        sessions::address.eq(address_clone),
670                        sessions::record.eq(&session_clone),
671                        sessions::device_id.eq(device_id),
672                    ))
673                    .on_conflict((sessions::address, sessions::device_id))
674                    .do_update()
675                    .set(sessions::record.eq(&session_clone))
676                    .execute(&mut conn)
677                    .map_err(|e| StoreError::Database(e.to_string()))?;
678                Ok(())
679            })
680            .await;
681
682            drop(permit);
683
684            match result {
685                Ok(Ok(())) => {
686                    return Ok(());
687                }
688                Ok(Err(e)) => {
689                    let error_msg = e.to_string();
690                    if (error_msg.contains("locked") || error_msg.contains("busy"))
691                        && attempt < MAX_RETRIES
692                    {
693                        let delay_ms = 10 * 2u64.pow(attempt);
694                        warn!(
695                            "Session write failed (attempt {}/{}): {}. Retrying in {}ms...",
696                            attempt + 1,
697                            MAX_RETRIES + 1,
698                            error_msg,
699                            delay_ms
700                        );
701                        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
702                        continue;
703                    }
704                    return Err(e);
705                }
706                Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
707            }
708        }
709
710        Err(StoreError::Database(format!(
711            "Session write failed after {} attempts",
712            MAX_RETRIES + 1
713        )))
714    }
715
716    pub async fn delete_session_for_device(&self, address: &str, device_id: i32) -> Result<()> {
717        let pool = self.pool.clone();
718        let address_owned = address.to_string();
719
720        tokio::task::spawn_blocking(move || -> Result<()> {
721            let mut conn = pool
722                .get()
723                .map_err(|e| StoreError::Connection(e.to_string()))?;
724            diesel::delete(
725                sessions::table
726                    .filter(sessions::address.eq(address_owned))
727                    .filter(sessions::device_id.eq(device_id)),
728            )
729            .execute(&mut conn)
730            .map_err(|e| StoreError::Database(e.to_string()))?;
731            Ok(())
732        })
733        .await
734        .map_err(|e| StoreError::Database(e.to_string()))??;
735
736        Ok(())
737    }
738
739    pub async fn put_sender_key_for_device(
740        &self,
741        address: &str,
742        record: &[u8],
743        device_id: i32,
744    ) -> Result<()> {
745        let pool = self.pool.clone();
746        let address = address.to_string();
747        let record_vec = record.to_vec();
748        tokio::task::spawn_blocking(move || -> Result<()> {
749            let mut conn = pool
750                .get()
751                .map_err(|e| StoreError::Connection(e.to_string()))?;
752            diesel::insert_into(sender_keys::table)
753                .values((
754                    sender_keys::address.eq(address),
755                    sender_keys::record.eq(&record_vec),
756                    sender_keys::device_id.eq(device_id),
757                ))
758                .on_conflict((sender_keys::address, sender_keys::device_id))
759                .do_update()
760                .set(sender_keys::record.eq(&record_vec))
761                .execute(&mut conn)
762                .map_err(|e| StoreError::Database(e.to_string()))?;
763            Ok(())
764        })
765        .await
766        .map_err(|e| StoreError::Database(e.to_string()))??;
767        Ok(())
768    }
769
770    pub async fn get_sender_key_for_device(
771        &self,
772        address: &str,
773        device_id: i32,
774    ) -> Result<Option<Vec<u8>>> {
775        let pool = self.pool.clone();
776        let address = address.to_string();
777        tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
778            let mut conn = pool
779                .get()
780                .map_err(|e| StoreError::Connection(e.to_string()))?;
781            let res: Option<Vec<u8>> = sender_keys::table
782                .select(sender_keys::record)
783                .filter(sender_keys::address.eq(address))
784                .filter(sender_keys::device_id.eq(device_id))
785                .first(&mut conn)
786                .optional()
787                .map_err(|e| StoreError::Database(e.to_string()))?;
788            Ok(res)
789        })
790        .await
791        .map_err(|e| StoreError::Database(e.to_string()))?
792    }
793
794    pub async fn delete_sender_key_for_device(&self, address: &str, device_id: i32) -> Result<()> {
795        let pool = self.pool.clone();
796        let address = address.to_string();
797        tokio::task::spawn_blocking(move || -> Result<()> {
798            let mut conn = pool
799                .get()
800                .map_err(|e| StoreError::Connection(e.to_string()))?;
801            diesel::delete(
802                sender_keys::table
803                    .filter(sender_keys::address.eq(address))
804                    .filter(sender_keys::device_id.eq(device_id)),
805            )
806            .execute(&mut conn)
807            .map_err(|e| StoreError::Database(e.to_string()))?;
808            Ok(())
809        })
810        .await
811        .map_err(|e| StoreError::Database(e.to_string()))??;
812        Ok(())
813    }
814
815    pub async fn get_app_state_sync_key_for_device(
816        &self,
817        key_id: &[u8],
818        device_id: i32,
819    ) -> Result<Option<AppStateSyncKey>> {
820        let pool = self.pool.clone();
821        let key_id = key_id.to_vec();
822        let res: Option<Vec<u8>> =
823            tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
824                let mut conn = pool
825                    .get()
826                    .map_err(|e| StoreError::Connection(e.to_string()))?;
827                let res: Option<Vec<u8>> = app_state_keys::table
828                    .select(app_state_keys::key_data)
829                    .filter(app_state_keys::key_id.eq(&key_id))
830                    .filter(app_state_keys::device_id.eq(device_id))
831                    .first(&mut conn)
832                    .optional()
833                    .map_err(|e| StoreError::Database(e.to_string()))?;
834                Ok(res)
835            })
836            .await
837            .map_err(|e| StoreError::Database(e.to_string()))??;
838
839        if let Some(data) = res {
840            let (key, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
841                .map_err(|e| StoreError::Serialization(e.to_string()))?;
842            Ok(Some(key))
843        } else {
844            Ok(None)
845        }
846    }
847
848    pub async fn set_app_state_sync_key_for_device(
849        &self,
850        key_id: &[u8],
851        key: AppStateSyncKey,
852        device_id: i32,
853    ) -> Result<()> {
854        let pool = self.pool.clone();
855        let key_id = key_id.to_vec();
856        let data = bincode::serde::encode_to_vec(&key, bincode::config::standard())
857            .map_err(|e| StoreError::Serialization(e.to_string()))?;
858        tokio::task::spawn_blocking(move || -> Result<()> {
859            let mut conn = pool
860                .get()
861                .map_err(|e| StoreError::Connection(e.to_string()))?;
862            diesel::insert_into(app_state_keys::table)
863                .values((
864                    app_state_keys::key_id.eq(&key_id),
865                    app_state_keys::key_data.eq(&data),
866                    app_state_keys::device_id.eq(device_id),
867                ))
868                .on_conflict((app_state_keys::key_id, app_state_keys::device_id))
869                .do_update()
870                .set(app_state_keys::key_data.eq(&data))
871                .execute(&mut conn)
872                .map_err(|e| StoreError::Database(e.to_string()))?;
873            Ok(())
874        })
875        .await
876        .map_err(|e| StoreError::Database(e.to_string()))??;
877        Ok(())
878    }
879
880    pub async fn get_app_state_version_for_device(
881        &self,
882        name: &str,
883        device_id: i32,
884    ) -> Result<HashState> {
885        let pool = self.pool.clone();
886        let name = name.to_string();
887        let res: Option<Vec<u8>> =
888            tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
889                let mut conn = pool
890                    .get()
891                    .map_err(|e| StoreError::Connection(e.to_string()))?;
892                let res: Option<Vec<u8>> = app_state_versions::table
893                    .select(app_state_versions::state_data)
894                    .filter(app_state_versions::name.eq(name))
895                    .filter(app_state_versions::device_id.eq(device_id))
896                    .first(&mut conn)
897                    .optional()
898                    .map_err(|e| StoreError::Database(e.to_string()))?;
899                Ok(res)
900            })
901            .await
902            .map_err(|e| StoreError::Database(e.to_string()))??;
903
904        if let Some(data) = res {
905            let (state, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
906                .map_err(|e| StoreError::Serialization(e.to_string()))?;
907            Ok(state)
908        } else {
909            Ok(HashState::default())
910        }
911    }
912
913    pub async fn set_app_state_version_for_device(
914        &self,
915        name: &str,
916        state: HashState,
917        device_id: i32,
918    ) -> Result<()> {
919        let pool = self.pool.clone();
920        let name = name.to_string();
921        let data = bincode::serde::encode_to_vec(&state, bincode::config::standard())
922            .map_err(|e| StoreError::Serialization(e.to_string()))?;
923        tokio::task::spawn_blocking(move || -> Result<()> {
924            let mut conn = pool
925                .get()
926                .map_err(|e| StoreError::Connection(e.to_string()))?;
927            diesel::insert_into(app_state_versions::table)
928                .values((
929                    app_state_versions::name.eq(&name),
930                    app_state_versions::state_data.eq(&data),
931                    app_state_versions::device_id.eq(device_id),
932                ))
933                .on_conflict((app_state_versions::name, app_state_versions::device_id))
934                .do_update()
935                .set(app_state_versions::state_data.eq(&data))
936                .execute(&mut conn)
937                .map_err(|e| StoreError::Database(e.to_string()))?;
938            Ok(())
939        })
940        .await
941        .map_err(|e| StoreError::Database(e.to_string()))??;
942        Ok(())
943    }
944
945    pub async fn put_app_state_mutation_macs_for_device(
946        &self,
947        name: &str,
948        version: u64,
949        mutations: &[AppStateMutationMAC],
950        device_id: i32,
951    ) -> Result<()> {
952        if mutations.is_empty() {
953            return Ok(());
954        }
955        let pool = self.pool.clone();
956        let name = name.to_string();
957        let mutations: Vec<AppStateMutationMAC> = mutations.to_vec();
958        tokio::task::spawn_blocking(move || -> Result<()> {
959            let mut conn = pool
960                .get()
961                .map_err(|e| StoreError::Connection(e.to_string()))?;
962
963            let records: Vec<_> = mutations
964                .iter()
965                .map(|m| {
966                    (
967                        app_state_mutation_macs::name.eq(&name),
968                        app_state_mutation_macs::version.eq(version as i64),
969                        app_state_mutation_macs::index_mac.eq(&m.index_mac),
970                        app_state_mutation_macs::value_mac.eq(&m.value_mac),
971                        app_state_mutation_macs::device_id.eq(device_id),
972                    )
973                })
974                .collect();
975
976            // SQLite variable limit is typically 999 or 32766.
977            // Each row has 5 columns. 100 rows * 5 = 500 params, which is safe.
978            const CHUNK_SIZE: usize = 100;
979
980            for chunk in records.chunks(CHUNK_SIZE) {
981                diesel::insert_into(app_state_mutation_macs::table)
982                    .values(chunk)
983                    .on_conflict((
984                        app_state_mutation_macs::name,
985                        app_state_mutation_macs::index_mac,
986                        app_state_mutation_macs::device_id,
987                    ))
988                    .do_update()
989                    .set((
990                        app_state_mutation_macs::version
991                            .eq(excluded(app_state_mutation_macs::version)),
992                        app_state_mutation_macs::value_mac
993                            .eq(excluded(app_state_mutation_macs::value_mac)),
994                    ))
995                    .execute(&mut conn)
996                    .map_err(|e| StoreError::Database(e.to_string()))?;
997            }
998            Ok(())
999        })
1000        .await
1001        .map_err(|e| StoreError::Database(e.to_string()))??;
1002        Ok(())
1003    }
1004
1005    pub async fn delete_app_state_mutation_macs_for_device(
1006        &self,
1007        name: &str,
1008        index_macs: &[Vec<u8>],
1009        device_id: i32,
1010    ) -> Result<()> {
1011        if index_macs.is_empty() {
1012            return Ok(());
1013        }
1014        let pool = self.pool.clone();
1015        let name = name.to_string();
1016        let index_macs: Vec<Vec<u8>> = index_macs.to_vec();
1017        tokio::task::spawn_blocking(move || -> Result<()> {
1018            let mut conn = pool
1019                .get()
1020                .map_err(|e| StoreError::Connection(e.to_string()))?;
1021
1022            // SQLite variable limit is usually 999 or higher.
1023            // We use a safe chunk size to stay well within limits.
1024            const CHUNK_SIZE: usize = 500;
1025
1026            for chunk in index_macs.chunks(CHUNK_SIZE) {
1027                diesel::delete(
1028                    app_state_mutation_macs::table.filter(
1029                        app_state_mutation_macs::name
1030                            .eq(&name)
1031                            .and(app_state_mutation_macs::index_mac.eq_any(chunk))
1032                            .and(app_state_mutation_macs::device_id.eq(device_id)),
1033                    ),
1034                )
1035                .execute(&mut conn)
1036                .map_err(|e| StoreError::Database(e.to_string()))?;
1037            }
1038            Ok(())
1039        })
1040        .await
1041        .map_err(|e| StoreError::Database(e.to_string()))??;
1042        Ok(())
1043    }
1044
1045    pub async fn get_app_state_mutation_mac_for_device(
1046        &self,
1047        name: &str,
1048        index_mac: &[u8],
1049        device_id: i32,
1050    ) -> Result<Option<Vec<u8>>> {
1051        let pool = self.pool.clone();
1052        let name = name.to_string();
1053        let index_mac = index_mac.to_vec();
1054        tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1055            let mut conn = pool
1056                .get()
1057                .map_err(|e| StoreError::Connection(e.to_string()))?;
1058            let res: Option<Vec<u8>> = app_state_mutation_macs::table
1059                .select(app_state_mutation_macs::value_mac)
1060                .filter(app_state_mutation_macs::name.eq(&name))
1061                .filter(app_state_mutation_macs::index_mac.eq(&index_mac))
1062                .filter(app_state_mutation_macs::device_id.eq(device_id))
1063                .first(&mut conn)
1064                .optional()
1065                .map_err(|e| StoreError::Database(e.to_string()))?;
1066            Ok(res)
1067        })
1068        .await
1069        .map_err(|e| StoreError::Database(e.to_string()))?
1070    }
1071}
1072
1073#[async_trait]
1074impl SignalStore for SqliteStore {
1075    async fn put_identity(&self, address: &str, key: [u8; 32]) -> Result<()> {
1076        self.put_identity_for_device(address, key, self.device_id)
1077            .await
1078    }
1079
1080    async fn load_identity(&self, address: &str) -> Result<Option<Vec<u8>>> {
1081        self.load_identity_for_device(address, self.device_id).await
1082    }
1083
1084    async fn delete_identity(&self, address: &str) -> Result<()> {
1085        self.delete_identity_for_device(address, self.device_id)
1086            .await
1087    }
1088
1089    async fn get_session(&self, address: &str) -> Result<Option<Vec<u8>>> {
1090        self.get_session_for_device(address, self.device_id).await
1091    }
1092
1093    async fn put_session(&self, address: &str, session: &[u8]) -> Result<()> {
1094        self.put_session_for_device(address, session, self.device_id)
1095            .await
1096    }
1097
1098    async fn delete_session(&self, address: &str) -> Result<()> {
1099        self.delete_session_for_device(address, self.device_id)
1100            .await
1101    }
1102
1103    async fn store_prekey(&self, id: u32, record: &[u8], uploaded: bool) -> Result<()> {
1104        let pool = self.pool.clone();
1105        let device_id = self.device_id;
1106        let record = record.to_vec();
1107        tokio::task::spawn_blocking(move || -> Result<()> {
1108            let mut conn = pool
1109                .get()
1110                .map_err(|e| StoreError::Connection(e.to_string()))?;
1111            diesel::insert_into(prekeys::table)
1112                .values((
1113                    prekeys::id.eq(id as i32),
1114                    prekeys::key.eq(&record),
1115                    prekeys::uploaded.eq(uploaded),
1116                    prekeys::device_id.eq(device_id),
1117                ))
1118                .on_conflict((prekeys::id, prekeys::device_id))
1119                .do_update()
1120                .set((prekeys::key.eq(&record), prekeys::uploaded.eq(uploaded)))
1121                .execute(&mut conn)
1122                .map_err(|e| StoreError::Database(e.to_string()))?;
1123            Ok(())
1124        })
1125        .await
1126        .map_err(|e| StoreError::Database(e.to_string()))??;
1127        Ok(())
1128    }
1129
1130    async fn load_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1131        let pool = self.pool.clone();
1132        let device_id = self.device_id;
1133        tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1134            let mut conn = pool
1135                .get()
1136                .map_err(|e| StoreError::Connection(e.to_string()))?;
1137            let res: Option<Vec<u8>> = prekeys::table
1138                .select(prekeys::key)
1139                .filter(prekeys::id.eq(id as i32))
1140                .filter(prekeys::device_id.eq(device_id))
1141                .first(&mut conn)
1142                .optional()
1143                .map_err(|e| StoreError::Database(e.to_string()))?;
1144            Ok(res)
1145        })
1146        .await
1147        .map_err(|e| StoreError::Database(e.to_string()))?
1148    }
1149
1150    async fn remove_prekey(&self, id: u32) -> Result<()> {
1151        let pool = self.pool.clone();
1152        let device_id = self.device_id;
1153        tokio::task::spawn_blocking(move || -> Result<()> {
1154            let mut conn = pool
1155                .get()
1156                .map_err(|e| StoreError::Connection(e.to_string()))?;
1157            diesel::delete(
1158                prekeys::table
1159                    .filter(prekeys::id.eq(id as i32))
1160                    .filter(prekeys::device_id.eq(device_id)),
1161            )
1162            .execute(&mut conn)
1163            .map_err(|e| StoreError::Database(e.to_string()))?;
1164            Ok(())
1165        })
1166        .await
1167        .map_err(|e| StoreError::Database(e.to_string()))??;
1168        Ok(())
1169    }
1170
1171    async fn store_signed_prekey(&self, id: u32, record: &[u8]) -> Result<()> {
1172        let pool = self.pool.clone();
1173        let device_id = self.device_id;
1174        let record = record.to_vec();
1175        tokio::task::spawn_blocking(move || -> Result<()> {
1176            let mut conn = pool
1177                .get()
1178                .map_err(|e| StoreError::Connection(e.to_string()))?;
1179            diesel::insert_into(signed_prekeys::table)
1180                .values((
1181                    signed_prekeys::id.eq(id as i32),
1182                    signed_prekeys::record.eq(&record),
1183                    signed_prekeys::device_id.eq(device_id),
1184                ))
1185                .on_conflict((signed_prekeys::id, signed_prekeys::device_id))
1186                .do_update()
1187                .set(signed_prekeys::record.eq(&record))
1188                .execute(&mut conn)
1189                .map_err(|e| StoreError::Database(e.to_string()))?;
1190            Ok(())
1191        })
1192        .await
1193        .map_err(|e| StoreError::Database(e.to_string()))??;
1194        Ok(())
1195    }
1196
1197    async fn load_signed_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1198        let pool = self.pool.clone();
1199        let device_id = self.device_id;
1200        tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1201            let mut conn = pool
1202                .get()
1203                .map_err(|e| StoreError::Connection(e.to_string()))?;
1204            let res: Option<Vec<u8>> = signed_prekeys::table
1205                .select(signed_prekeys::record)
1206                .filter(signed_prekeys::id.eq(id as i32))
1207                .filter(signed_prekeys::device_id.eq(device_id))
1208                .first(&mut conn)
1209                .optional()
1210                .map_err(|e| StoreError::Database(e.to_string()))?;
1211            Ok(res)
1212        })
1213        .await
1214        .map_err(|e| StoreError::Database(e.to_string()))?
1215    }
1216
1217    async fn load_all_signed_prekeys(&self) -> Result<Vec<(u32, Vec<u8>)>> {
1218        let pool = self.pool.clone();
1219        let device_id = self.device_id;
1220        tokio::task::spawn_blocking(move || -> Result<Vec<(u32, Vec<u8>)>> {
1221            let mut conn = pool
1222                .get()
1223                .map_err(|e| StoreError::Connection(e.to_string()))?;
1224            let results: Vec<(i32, Vec<u8>)> = signed_prekeys::table
1225                .select((signed_prekeys::id, signed_prekeys::record))
1226                .filter(signed_prekeys::device_id.eq(device_id))
1227                .load(&mut conn)
1228                .map_err(|e| StoreError::Database(e.to_string()))?;
1229            Ok(results
1230                .into_iter()
1231                .map(|(id, record)| (id as u32, record))
1232                .collect())
1233        })
1234        .await
1235        .map_err(|e| StoreError::Database(e.to_string()))?
1236    }
1237
1238    async fn remove_signed_prekey(&self, id: u32) -> Result<()> {
1239        let pool = self.pool.clone();
1240        let device_id = self.device_id;
1241        tokio::task::spawn_blocking(move || -> Result<()> {
1242            let mut conn = pool
1243                .get()
1244                .map_err(|e| StoreError::Connection(e.to_string()))?;
1245            diesel::delete(
1246                signed_prekeys::table
1247                    .filter(signed_prekeys::id.eq(id as i32))
1248                    .filter(signed_prekeys::device_id.eq(device_id)),
1249            )
1250            .execute(&mut conn)
1251            .map_err(|e| StoreError::Database(e.to_string()))?;
1252            Ok(())
1253        })
1254        .await
1255        .map_err(|e| StoreError::Database(e.to_string()))??;
1256        Ok(())
1257    }
1258
1259    async fn put_sender_key(&self, address: &str, record: &[u8]) -> Result<()> {
1260        self.put_sender_key_for_device(address, record, self.device_id)
1261            .await
1262    }
1263
1264    async fn get_sender_key(&self, address: &str) -> Result<Option<Vec<u8>>> {
1265        self.get_sender_key_for_device(address, self.device_id)
1266            .await
1267    }
1268
1269    async fn delete_sender_key(&self, address: &str) -> Result<()> {
1270        self.delete_sender_key_for_device(address, self.device_id)
1271            .await
1272    }
1273}
1274
1275#[async_trait]
1276impl AppSyncStore for SqliteStore {
1277    async fn get_sync_key(&self, key_id: &[u8]) -> Result<Option<AppStateSyncKey>> {
1278        self.get_app_state_sync_key_for_device(key_id, self.device_id)
1279            .await
1280    }
1281
1282    async fn set_sync_key(&self, key_id: &[u8], key: AppStateSyncKey) -> Result<()> {
1283        self.set_app_state_sync_key_for_device(key_id, key, self.device_id)
1284            .await
1285    }
1286
1287    async fn get_version(&self, name: &str) -> Result<HashState> {
1288        self.get_app_state_version_for_device(name, self.device_id)
1289            .await
1290    }
1291
1292    async fn set_version(&self, name: &str, state: HashState) -> Result<()> {
1293        self.set_app_state_version_for_device(name, state, self.device_id)
1294            .await
1295    }
1296
1297    async fn put_mutation_macs(
1298        &self,
1299        name: &str,
1300        version: u64,
1301        mutations: &[AppStateMutationMAC],
1302    ) -> Result<()> {
1303        self.put_app_state_mutation_macs_for_device(name, version, mutations, self.device_id)
1304            .await
1305    }
1306
1307    async fn get_mutation_mac(&self, name: &str, index_mac: &[u8]) -> Result<Option<Vec<u8>>> {
1308        self.get_app_state_mutation_mac_for_device(name, index_mac, self.device_id)
1309            .await
1310    }
1311
1312    async fn delete_mutation_macs(&self, name: &str, index_macs: &[Vec<u8>]) -> Result<()> {
1313        self.delete_app_state_mutation_macs_for_device(name, index_macs, self.device_id)
1314            .await
1315    }
1316}
1317
1318#[async_trait]
1319impl ProtocolStore for SqliteStore {
1320    async fn get_skdm_recipients(&self, group_jid: &str) -> Result<Vec<Jid>> {
1321        let pool = self.pool.clone();
1322        let device_id = self.device_id;
1323        let group_jid = group_jid.to_string();
1324        tokio::task::spawn_blocking(move || -> Result<Vec<Jid>> {
1325            let mut conn = pool
1326                .get()
1327                .map_err(|e| StoreError::Connection(e.to_string()))?;
1328            let recipients: Vec<String> = skdm_recipients::table
1329                .select(skdm_recipients::device_jid)
1330                .filter(skdm_recipients::group_jid.eq(&group_jid))
1331                .filter(skdm_recipients::device_id.eq(device_id))
1332                .load(&mut conn)
1333                .map_err(|e| StoreError::Database(e.to_string()))?;
1334            let jids: Vec<Jid> = recipients
1335                .iter()
1336                .filter_map(|s| match s.parse::<Jid>() {
1337                    Ok(jid) => Some(jid),
1338                    Err(e) => {
1339                        warn!("Failed to parse SKDM recipient '{}': {}", s, e);
1340                        None
1341                    }
1342                })
1343                .collect();
1344            Ok(jids)
1345        })
1346        .await
1347        .map_err(|e| StoreError::Database(e.to_string()))?
1348    }
1349
1350    async fn add_skdm_recipients(&self, group_jid: &str, device_jids: &[Jid]) -> Result<()> {
1351        if device_jids.is_empty() {
1352            return Ok(());
1353        }
1354        let pool = self.pool.clone();
1355        let device_id = self.device_id;
1356        let group_jid = group_jid.to_string();
1357        let device_jid_strs: Vec<String> = device_jids.iter().map(|j| j.to_string()).collect();
1358        let now = std::time::SystemTime::now()
1359            .duration_since(std::time::UNIX_EPOCH)
1360            .unwrap_or_default()
1361            .as_secs() as i32;
1362        tokio::task::spawn_blocking(move || -> Result<()> {
1363            let mut conn = pool
1364                .get()
1365                .map_err(|e| StoreError::Connection(e.to_string()))?;
1366
1367            let values: Vec<_> = device_jid_strs
1368                .iter()
1369                .map(|device_jid| {
1370                    (
1371                        skdm_recipients::group_jid.eq(&group_jid),
1372                        skdm_recipients::device_jid.eq(device_jid),
1373                        skdm_recipients::device_id.eq(device_id),
1374                        skdm_recipients::created_at.eq(now),
1375                    )
1376                })
1377                .collect();
1378
1379            const CHUNK_SIZE: usize = 200; // SQLite variable limit ~999, 4 cols/row
1380
1381            for chunk in values.chunks(CHUNK_SIZE) {
1382                diesel::insert_into(skdm_recipients::table)
1383                    .values(chunk)
1384                    .on_conflict((
1385                        skdm_recipients::group_jid,
1386                        skdm_recipients::device_jid,
1387                        skdm_recipients::device_id,
1388                    ))
1389                    .do_nothing()
1390                    .execute(&mut conn)
1391                    .map_err(|e| StoreError::Database(e.to_string()))?;
1392            }
1393            Ok(())
1394        })
1395        .await
1396        .map_err(|e| StoreError::Database(e.to_string()))??;
1397        Ok(())
1398    }
1399
1400    async fn clear_skdm_recipients(&self, group_jid: &str) -> Result<()> {
1401        let pool = self.pool.clone();
1402        let device_id = self.device_id;
1403        let group_jid = group_jid.to_string();
1404        tokio::task::spawn_blocking(move || -> Result<()> {
1405            let mut conn = pool
1406                .get()
1407                .map_err(|e| StoreError::Connection(e.to_string()))?;
1408            diesel::delete(
1409                skdm_recipients::table
1410                    .filter(skdm_recipients::group_jid.eq(&group_jid))
1411                    .filter(skdm_recipients::device_id.eq(device_id)),
1412            )
1413            .execute(&mut conn)
1414            .map_err(|e| StoreError::Database(e.to_string()))?;
1415            Ok(())
1416        })
1417        .await
1418        .map_err(|e| StoreError::Database(e.to_string()))??;
1419        Ok(())
1420    }
1421
1422    async fn get_lid_mapping(&self, lid: &str) -> Result<Option<LidPnMappingEntry>> {
1423        let pool = self.pool.clone();
1424        let device_id = self.device_id;
1425        let lid = lid.to_string();
1426        tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1427            let mut conn = pool
1428                .get()
1429                .map_err(|e| StoreError::Connection(e.to_string()))?;
1430            let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1431                .select((
1432                    lid_pn_mapping::lid,
1433                    lid_pn_mapping::phone_number,
1434                    lid_pn_mapping::created_at,
1435                    lid_pn_mapping::learning_source,
1436                    lid_pn_mapping::updated_at,
1437                ))
1438                .filter(lid_pn_mapping::lid.eq(&lid))
1439                .filter(lid_pn_mapping::device_id.eq(device_id))
1440                .first(&mut conn)
1441                .optional()
1442                .map_err(|e| StoreError::Database(e.to_string()))?;
1443            Ok(row.map(
1444                |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1445                    lid,
1446                    phone_number,
1447                    created_at,
1448                    updated_at,
1449                    learning_source,
1450                },
1451            ))
1452        })
1453        .await
1454        .map_err(|e| StoreError::Database(e.to_string()))?
1455    }
1456
1457    async fn get_pn_mapping(&self, phone: &str) -> Result<Option<LidPnMappingEntry>> {
1458        let pool = self.pool.clone();
1459        let device_id = self.device_id;
1460        let phone = phone.to_string();
1461        tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1462            let mut conn = pool
1463                .get()
1464                .map_err(|e| StoreError::Connection(e.to_string()))?;
1465            let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1466                .select((
1467                    lid_pn_mapping::lid,
1468                    lid_pn_mapping::phone_number,
1469                    lid_pn_mapping::created_at,
1470                    lid_pn_mapping::learning_source,
1471                    lid_pn_mapping::updated_at,
1472                ))
1473                .filter(lid_pn_mapping::phone_number.eq(&phone))
1474                .filter(lid_pn_mapping::device_id.eq(device_id))
1475                .order(lid_pn_mapping::updated_at.desc())
1476                .first(&mut conn)
1477                .optional()
1478                .map_err(|e| StoreError::Database(e.to_string()))?;
1479            Ok(row.map(
1480                |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1481                    lid,
1482                    phone_number,
1483                    created_at,
1484                    updated_at,
1485                    learning_source,
1486                },
1487            ))
1488        })
1489        .await
1490        .map_err(|e| StoreError::Database(e.to_string()))?
1491    }
1492
1493    async fn put_lid_mapping(&self, entry: &LidPnMappingEntry) -> Result<()> {
1494        let pool = self.pool.clone();
1495        let device_id = self.device_id;
1496        let entry = entry.clone();
1497        tokio::task::spawn_blocking(move || -> Result<()> {
1498            let mut conn = pool
1499                .get()
1500                .map_err(|e| StoreError::Connection(e.to_string()))?;
1501            diesel::insert_into(lid_pn_mapping::table)
1502                .values((
1503                    lid_pn_mapping::lid.eq(&entry.lid),
1504                    lid_pn_mapping::phone_number.eq(&entry.phone_number),
1505                    lid_pn_mapping::created_at.eq(entry.created_at),
1506                    lid_pn_mapping::learning_source.eq(&entry.learning_source),
1507                    lid_pn_mapping::updated_at.eq(entry.updated_at),
1508                    lid_pn_mapping::device_id.eq(device_id),
1509                ))
1510                .on_conflict((lid_pn_mapping::lid, lid_pn_mapping::device_id))
1511                .do_update()
1512                .set((
1513                    lid_pn_mapping::phone_number.eq(&entry.phone_number),
1514                    lid_pn_mapping::learning_source.eq(&entry.learning_source),
1515                    lid_pn_mapping::updated_at.eq(entry.updated_at),
1516                ))
1517                .execute(&mut conn)
1518                .map_err(|e| StoreError::Database(e.to_string()))?;
1519            Ok(())
1520        })
1521        .await
1522        .map_err(|e| StoreError::Database(e.to_string()))??;
1523        Ok(())
1524    }
1525
1526    async fn get_all_lid_mappings(&self) -> Result<Vec<LidPnMappingEntry>> {
1527        let pool = self.pool.clone();
1528        let device_id = self.device_id;
1529        tokio::task::spawn_blocking(move || -> Result<Vec<LidPnMappingEntry>> {
1530            let mut conn = pool
1531                .get()
1532                .map_err(|e| StoreError::Connection(e.to_string()))?;
1533            let rows: Vec<(String, String, i64, String, i64)> = lid_pn_mapping::table
1534                .select((
1535                    lid_pn_mapping::lid,
1536                    lid_pn_mapping::phone_number,
1537                    lid_pn_mapping::created_at,
1538                    lid_pn_mapping::learning_source,
1539                    lid_pn_mapping::updated_at,
1540                ))
1541                .filter(lid_pn_mapping::device_id.eq(device_id))
1542                .load(&mut conn)
1543                .map_err(|e| StoreError::Database(e.to_string()))?;
1544            Ok(rows
1545                .into_iter()
1546                .map(
1547                    |(lid, phone_number, created_at, learning_source, updated_at)| {
1548                        LidPnMappingEntry {
1549                            lid,
1550                            phone_number,
1551                            created_at,
1552                            updated_at,
1553                            learning_source,
1554                        }
1555                    },
1556                )
1557                .collect())
1558        })
1559        .await
1560        .map_err(|e| StoreError::Database(e.to_string()))?
1561    }
1562
1563    async fn save_base_key(&self, address: &str, message_id: &str, base_key: &[u8]) -> Result<()> {
1564        let pool = self.pool.clone();
1565        let device_id = self.device_id;
1566        let address = address.to_string();
1567        let message_id = message_id.to_string();
1568        let base_key = base_key.to_vec();
1569        let now = std::time::SystemTime::now()
1570            .duration_since(std::time::UNIX_EPOCH)
1571            .unwrap_or_default()
1572            .as_secs() as i32;
1573        tokio::task::spawn_blocking(move || -> Result<()> {
1574            let mut conn = pool
1575                .get()
1576                .map_err(|e| StoreError::Connection(e.to_string()))?;
1577            diesel::insert_into(base_keys::table)
1578                .values((
1579                    base_keys::address.eq(&address),
1580                    base_keys::message_id.eq(&message_id),
1581                    base_keys::base_key.eq(&base_key),
1582                    base_keys::device_id.eq(device_id),
1583                    base_keys::created_at.eq(now),
1584                ))
1585                .on_conflict((
1586                    base_keys::address,
1587                    base_keys::message_id,
1588                    base_keys::device_id,
1589                ))
1590                .do_update()
1591                .set(base_keys::base_key.eq(&base_key))
1592                .execute(&mut conn)
1593                .map_err(|e| StoreError::Database(e.to_string()))?;
1594            Ok(())
1595        })
1596        .await
1597        .map_err(|e| StoreError::Database(e.to_string()))??;
1598        Ok(())
1599    }
1600
1601    async fn has_same_base_key(
1602        &self,
1603        address: &str,
1604        message_id: &str,
1605        current_base_key: &[u8],
1606    ) -> Result<bool> {
1607        let pool = self.pool.clone();
1608        let device_id = self.device_id;
1609        let address = address.to_string();
1610        let message_id = message_id.to_string();
1611        let current_base_key = current_base_key.to_vec();
1612        tokio::task::spawn_blocking(move || -> Result<bool> {
1613            let mut conn = pool
1614                .get()
1615                .map_err(|e| StoreError::Connection(e.to_string()))?;
1616            let stored_key: Option<Vec<u8>> = base_keys::table
1617                .select(base_keys::base_key)
1618                .filter(base_keys::address.eq(&address))
1619                .filter(base_keys::message_id.eq(&message_id))
1620                .filter(base_keys::device_id.eq(device_id))
1621                .first(&mut conn)
1622                .optional()
1623                .map_err(|e| StoreError::Database(e.to_string()))?;
1624            Ok(stored_key.as_ref() == Some(&current_base_key))
1625        })
1626        .await
1627        .map_err(|e| StoreError::Database(e.to_string()))?
1628    }
1629
1630    async fn delete_base_key(&self, address: &str, message_id: &str) -> Result<()> {
1631        let pool = self.pool.clone();
1632        let device_id = self.device_id;
1633        let address = address.to_string();
1634        let message_id = message_id.to_string();
1635        tokio::task::spawn_blocking(move || -> Result<()> {
1636            let mut conn = pool
1637                .get()
1638                .map_err(|e| StoreError::Connection(e.to_string()))?;
1639            diesel::delete(
1640                base_keys::table
1641                    .filter(base_keys::address.eq(&address))
1642                    .filter(base_keys::message_id.eq(&message_id))
1643                    .filter(base_keys::device_id.eq(device_id)),
1644            )
1645            .execute(&mut conn)
1646            .map_err(|e| StoreError::Database(e.to_string()))?;
1647            Ok(())
1648        })
1649        .await
1650        .map_err(|e| StoreError::Database(e.to_string()))??;
1651        Ok(())
1652    }
1653
1654    async fn update_device_list(&self, record: DeviceListRecord) -> Result<()> {
1655        let pool = self.pool.clone();
1656        let device_id = self.device_id;
1657        let devices_json = serde_json::to_string(&record.devices)
1658            .map_err(|e| StoreError::Serialization(e.to_string()))?;
1659        let now = std::time::SystemTime::now()
1660            .duration_since(std::time::UNIX_EPOCH)
1661            .unwrap_or_default()
1662            .as_secs() as i32;
1663        tokio::task::spawn_blocking(move || -> Result<()> {
1664            let mut conn = pool
1665                .get()
1666                .map_err(|e| StoreError::Connection(e.to_string()))?;
1667            diesel::insert_into(device_registry::table)
1668                .values((
1669                    device_registry::user_id.eq(&record.user),
1670                    device_registry::devices_json.eq(&devices_json),
1671                    device_registry::timestamp.eq(record.timestamp as i32),
1672                    device_registry::phash.eq(&record.phash),
1673                    device_registry::device_id.eq(device_id),
1674                    device_registry::updated_at.eq(now),
1675                ))
1676                .on_conflict((device_registry::user_id, device_registry::device_id))
1677                .do_update()
1678                .set((
1679                    device_registry::devices_json.eq(&devices_json),
1680                    device_registry::timestamp.eq(record.timestamp as i32),
1681                    device_registry::phash.eq(&record.phash),
1682                    device_registry::updated_at.eq(now),
1683                ))
1684                .execute(&mut conn)
1685                .map_err(|e| StoreError::Database(e.to_string()))?;
1686            Ok(())
1687        })
1688        .await
1689        .map_err(|e| StoreError::Database(e.to_string()))??;
1690        Ok(())
1691    }
1692
1693    async fn get_devices(&self, user: &str) -> Result<Option<DeviceListRecord>> {
1694        let pool = self.pool.clone();
1695        let device_id = self.device_id;
1696        let user = user.to_string();
1697        tokio::task::spawn_blocking(move || -> Result<Option<DeviceListRecord>> {
1698            let mut conn = pool
1699                .get()
1700                .map_err(|e| StoreError::Connection(e.to_string()))?;
1701            let row: Option<(String, String, i32, Option<String>)> = device_registry::table
1702                .select((
1703                    device_registry::user_id,
1704                    device_registry::devices_json,
1705                    device_registry::timestamp,
1706                    device_registry::phash,
1707                ))
1708                .filter(device_registry::user_id.eq(&user))
1709                .filter(device_registry::device_id.eq(device_id))
1710                .first(&mut conn)
1711                .optional()
1712                .map_err(|e| StoreError::Database(e.to_string()))?;
1713            match row {
1714                Some((user, devices_json, timestamp, phash)) => {
1715                    let devices: Vec<DeviceInfo> = serde_json::from_str(&devices_json)
1716                        .map_err(|e| StoreError::Serialization(e.to_string()))?;
1717                    Ok(Some(DeviceListRecord {
1718                        user,
1719                        devices,
1720                        timestamp: timestamp as i64,
1721                        phash,
1722                    }))
1723                }
1724                None => Ok(None),
1725            }
1726        })
1727        .await
1728        .map_err(|e| StoreError::Database(e.to_string()))?
1729    }
1730
1731    async fn mark_forget_sender_key(&self, group_jid: &str, participant: &str) -> Result<()> {
1732        let pool = self.pool.clone();
1733        let device_id = self.device_id;
1734        let group_jid = group_jid.to_string();
1735        let participant = participant.to_string();
1736        let now = std::time::SystemTime::now()
1737            .duration_since(std::time::UNIX_EPOCH)
1738            .unwrap_or_default()
1739            .as_secs() as i32;
1740        tokio::task::spawn_blocking(move || -> Result<()> {
1741            let mut conn = pool
1742                .get()
1743                .map_err(|e| StoreError::Connection(e.to_string()))?;
1744            diesel::insert_into(sender_key_status::table)
1745                .values((
1746                    sender_key_status::group_jid.eq(&group_jid),
1747                    sender_key_status::participant.eq(&participant),
1748                    sender_key_status::device_id.eq(device_id),
1749                    sender_key_status::marked_at.eq(now),
1750                ))
1751                .on_conflict((
1752                    sender_key_status::group_jid,
1753                    sender_key_status::participant,
1754                    sender_key_status::device_id,
1755                ))
1756                .do_update()
1757                .set(sender_key_status::marked_at.eq(now))
1758                .execute(&mut conn)
1759                .map_err(|e| StoreError::Database(e.to_string()))?;
1760            Ok(())
1761        })
1762        .await
1763        .map_err(|e| StoreError::Database(e.to_string()))??;
1764        Ok(())
1765    }
1766
1767    async fn consume_forget_marks(&self, group_jid: &str) -> Result<Vec<String>> {
1768        let pool = self.pool.clone();
1769        let device_id = self.device_id;
1770        let group_jid = group_jid.to_string();
1771        tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
1772            let mut conn = pool
1773                .get()
1774                .map_err(|e| StoreError::Connection(e.to_string()))?;
1775            let participants: Vec<String> = sender_key_status::table
1776                .select(sender_key_status::participant)
1777                .filter(sender_key_status::group_jid.eq(&group_jid))
1778                .filter(sender_key_status::device_id.eq(device_id))
1779                .load(&mut conn)
1780                .map_err(|e| StoreError::Database(e.to_string()))?;
1781            diesel::delete(
1782                sender_key_status::table
1783                    .filter(sender_key_status::group_jid.eq(&group_jid))
1784                    .filter(sender_key_status::device_id.eq(device_id)),
1785            )
1786            .execute(&mut conn)
1787            .map_err(|e| StoreError::Database(e.to_string()))?;
1788            Ok(participants)
1789        })
1790        .await
1791        .map_err(|e| StoreError::Database(e.to_string()))?
1792    }
1793
1794    async fn get_tc_token(&self, jid: &str) -> Result<Option<TcTokenEntry>> {
1795        let pool = self.pool.clone();
1796        let device_id = self.device_id;
1797        let jid = jid.to_string();
1798        tokio::task::spawn_blocking(move || -> Result<Option<TcTokenEntry>> {
1799            let mut conn = pool
1800                .get()
1801                .map_err(|e| StoreError::Connection(e.to_string()))?;
1802            let row: Option<(Vec<u8>, i64, Option<i64>)> = tc_tokens::table
1803                .select((
1804                    tc_tokens::token,
1805                    tc_tokens::token_timestamp,
1806                    tc_tokens::sender_timestamp,
1807                ))
1808                .filter(tc_tokens::jid.eq(&jid))
1809                .filter(tc_tokens::device_id.eq(device_id))
1810                .first(&mut conn)
1811                .optional()
1812                .map_err(|e| StoreError::Database(e.to_string()))?;
1813            Ok(
1814                row.map(|(token, token_timestamp, sender_timestamp)| TcTokenEntry {
1815                    token,
1816                    token_timestamp,
1817                    sender_timestamp,
1818                }),
1819            )
1820        })
1821        .await
1822        .map_err(|e| StoreError::Database(e.to_string()))?
1823    }
1824
1825    async fn put_tc_token(&self, jid: &str, entry: &TcTokenEntry) -> Result<()> {
1826        let pool = self.pool.clone();
1827        let device_id = self.device_id;
1828        let jid = jid.to_string();
1829        let entry = entry.clone();
1830        let now = std::time::SystemTime::now()
1831            .duration_since(std::time::UNIX_EPOCH)
1832            .unwrap_or_default()
1833            .as_secs() as i64;
1834        tokio::task::spawn_blocking(move || -> Result<()> {
1835            let mut conn = pool
1836                .get()
1837                .map_err(|e| StoreError::Connection(e.to_string()))?;
1838            diesel::insert_into(tc_tokens::table)
1839                .values((
1840                    tc_tokens::jid.eq(&jid),
1841                    tc_tokens::token.eq(&entry.token),
1842                    tc_tokens::token_timestamp.eq(entry.token_timestamp),
1843                    tc_tokens::sender_timestamp.eq(entry.sender_timestamp),
1844                    tc_tokens::device_id.eq(device_id),
1845                    tc_tokens::updated_at.eq(now),
1846                ))
1847                .on_conflict((tc_tokens::jid, tc_tokens::device_id))
1848                .do_update()
1849                .set((
1850                    tc_tokens::token.eq(&entry.token),
1851                    tc_tokens::token_timestamp.eq(entry.token_timestamp),
1852                    tc_tokens::sender_timestamp.eq(entry.sender_timestamp),
1853                    tc_tokens::updated_at.eq(now),
1854                ))
1855                .execute(&mut conn)
1856                .map_err(|e| StoreError::Database(e.to_string()))?;
1857            Ok(())
1858        })
1859        .await
1860        .map_err(|e| StoreError::Database(e.to_string()))??;
1861        Ok(())
1862    }
1863
1864    async fn delete_tc_token(&self, jid: &str) -> Result<()> {
1865        let pool = self.pool.clone();
1866        let device_id = self.device_id;
1867        let jid = jid.to_string();
1868        tokio::task::spawn_blocking(move || -> Result<()> {
1869            let mut conn = pool
1870                .get()
1871                .map_err(|e| StoreError::Connection(e.to_string()))?;
1872            diesel::delete(
1873                tc_tokens::table
1874                    .filter(tc_tokens::jid.eq(&jid))
1875                    .filter(tc_tokens::device_id.eq(device_id)),
1876            )
1877            .execute(&mut conn)
1878            .map_err(|e| StoreError::Database(e.to_string()))?;
1879            Ok(())
1880        })
1881        .await
1882        .map_err(|e| StoreError::Database(e.to_string()))??;
1883        Ok(())
1884    }
1885
1886    async fn get_all_tc_token_jids(&self) -> Result<Vec<String>> {
1887        let pool = self.pool.clone();
1888        let device_id = self.device_id;
1889        tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
1890            let mut conn = pool
1891                .get()
1892                .map_err(|e| StoreError::Connection(e.to_string()))?;
1893            let jids: Vec<String> = tc_tokens::table
1894                .select(tc_tokens::jid)
1895                .filter(tc_tokens::device_id.eq(device_id))
1896                .load(&mut conn)
1897                .map_err(|e| StoreError::Database(e.to_string()))?;
1898            Ok(jids)
1899        })
1900        .await
1901        .map_err(|e| StoreError::Database(e.to_string()))?
1902    }
1903
1904    async fn delete_expired_tc_tokens(&self, cutoff_timestamp: i64) -> Result<u32> {
1905        let pool = self.pool.clone();
1906        let device_id = self.device_id;
1907        tokio::task::spawn_blocking(move || -> Result<u32> {
1908            let mut conn = pool
1909                .get()
1910                .map_err(|e| StoreError::Connection(e.to_string()))?;
1911            let deleted = diesel::delete(
1912                tc_tokens::table
1913                    .filter(tc_tokens::token_timestamp.lt(cutoff_timestamp))
1914                    .filter(tc_tokens::device_id.eq(device_id)),
1915            )
1916            .execute(&mut conn)
1917            .map_err(|e| StoreError::Database(e.to_string()))?;
1918            Ok(deleted as u32)
1919        })
1920        .await
1921        .map_err(|e| StoreError::Database(e.to_string()))?
1922    }
1923}
1924
1925#[async_trait]
1926impl DeviceStore for SqliteStore {
1927    async fn save(&self, device: &CoreDevice) -> Result<()> {
1928        SqliteStore::save_device_data_for_device(self, self.device_id, device).await
1929    }
1930
1931    async fn load(&self) -> Result<Option<CoreDevice>> {
1932        SqliteStore::load_device_data_for_device(self, self.device_id).await
1933    }
1934
1935    async fn exists(&self) -> Result<bool> {
1936        SqliteStore::device_exists(self, self.device_id).await
1937    }
1938
1939    async fn create(&self) -> Result<i32> {
1940        SqliteStore::create_new_device(self).await
1941    }
1942
1943    async fn snapshot_db(&self, name: &str, extra_content: Option<&[u8]>) -> Result<()> {
1944        fn sanitize_snapshot_name(name: &str) -> Result<String> {
1945            const MAX_LENGTH: usize = 100;
1946
1947            let sanitized: String = name
1948                .chars()
1949                .map(|c| {
1950                    if c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.' {
1951                        c
1952                    } else {
1953                        '_'
1954                    }
1955                })
1956                .collect();
1957
1958            let sanitized = sanitized
1959                .split('.')
1960                .filter(|part| !part.is_empty() && *part != "..")
1961                .collect::<Vec<_>>()
1962                .join(".");
1963
1964            let sanitized = sanitized.trim_matches(['/', '\\', '.']);
1965
1966            if sanitized.is_empty() {
1967                return Err(StoreError::Database(
1968                    "Snapshot name cannot be empty after sanitization".to_string(),
1969                ));
1970            }
1971
1972            if sanitized.len() > MAX_LENGTH {
1973                return Err(StoreError::Database(format!(
1974                    "Snapshot name exceeds maximum length of {} characters",
1975                    MAX_LENGTH
1976                )));
1977            }
1978
1979            Ok(sanitized.to_string())
1980        }
1981
1982        let sanitized_name = sanitize_snapshot_name(name)?;
1983
1984        let pool = self.pool.clone();
1985        let db_path = self.database_path.clone();
1986        let extra_data = extra_content.map(|b| b.to_vec());
1987
1988        tokio::task::spawn_blocking(move || -> Result<()> {
1989            let mut conn = pool
1990                .get()
1991                .map_err(|e| StoreError::Connection(e.to_string()))?;
1992
1993            let timestamp = std::time::SystemTime::now()
1994                .duration_since(std::time::UNIX_EPOCH)
1995                .unwrap_or_default()
1996                .as_secs();
1997
1998            // Construct target path: db_path.snapshot-TIMESTAMP-SANITIZED_NAME
1999            let target_path = format!("{}.snapshot-{}-{}", db_path, timestamp, sanitized_name);
2000
2001            // Use VACUUM INTO to create a consistent backup
2002            // Note: We escape single quotes in the path just in case
2003            let query = format!("VACUUM INTO '{}'", target_path.replace("'", "''"));
2004
2005            diesel::sql_query(query)
2006                .execute(&mut conn)
2007                .map_err(|e| StoreError::Database(e.to_string()))?;
2008
2009            // Save extra content if provided
2010            if let Some(data) = extra_data {
2011                let extra_path = format!("{}.json", target_path);
2012                std::fs::write(&extra_path, data).map_err(|e| {
2013                    StoreError::Database(format!("Failed to write snapshot extra content: {}", e))
2014                })?;
2015            }
2016
2017            Ok(())
2018        })
2019        .await
2020        .map_err(|e| StoreError::Database(e.to_string()))??;
2021
2022        Ok(())
2023    }
2024}
2025
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh store backed by a uniquely named shared-cache in-memory
    /// SQLite database, so tests do not observe each other's rows.
    async fn create_test_store() -> SqliteStore {
        use std::time::{SystemTime, UNIX_EPOCH};
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        let db_name = format!("file:memdb_test_{}?mode=memory&cache=shared", timestamp);
        SqliteStore::new(&db_name)
            .await
            .expect("Failed to create test store")
    }

    #[test]
    fn test_parse_database_path_regular_path() {
        let path = "/var/lib/whatsapp/database.db";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "/var/lib/whatsapp/database.db");
    }

    #[test]
    fn test_parse_database_path_with_sqlite_prefix() {
        let path = "sqlite:///var/lib/whatsapp/database.db";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "/var/lib/whatsapp/database.db");
    }

    #[test]
    fn test_parse_database_path_with_query_params() {
        let path = "file:database.db?mode=memory&cache=shared";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "file:database.db");
    }

    #[test]
    fn test_parse_database_path_with_fragment() {
        let path = "file:database.db#fragment";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "file:database.db");
    }

    #[test]
    fn test_parse_database_path_with_both_query_and_fragment() {
        let path = "sqlite:///var/lib/database.db?mode=ro#backup";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "/var/lib/database.db");
    }

    #[test]
    fn test_parse_database_path_in_memory_rejected() {
        let result = parse_database_path(":memory:");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not supported"));
    }

    #[test]
    fn test_parse_database_path_in_memory_with_query_rejected() {
        let result = parse_database_path(":memory:?cache=shared");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not supported"));
    }

    #[tokio::test]
    async fn test_device_registry_save_and_get() {
        let store = create_test_store().await;

        let record = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![
                DeviceInfo {
                    device_id: 0,
                    key_index: None,
                },
                DeviceInfo {
                    device_id: 1,
                    key_index: Some(42),
                },
            ],
            timestamp: 1234567890,
            phash: Some("2:abcdef".to_string()),
        };

        store.update_device_list(record).await.expect("save failed");
        let loaded = store
            .get_devices("1234567890")
            .await
            .expect("get failed")
            .expect("record should exist");

        assert_eq!(loaded.user, "1234567890");
        assert_eq!(loaded.devices.len(), 2);
        assert_eq!(loaded.devices[0].device_id, 0);
        assert_eq!(loaded.devices[1].device_id, 1);
        assert_eq!(loaded.devices[1].key_index, Some(42));
        assert_eq!(loaded.phash, Some("2:abcdef".to_string()));
    }

    #[tokio::test]
    async fn test_device_registry_update_existing() {
        let store = create_test_store().await;

        let record1 = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![DeviceInfo {
                device_id: 0,
                key_index: None,
            }],
            timestamp: 1000,
            phash: Some("2:old".to_string()),
        };
        store
            .update_device_list(record1)
            .await
            .expect("save1 failed");

        let record2 = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![
                DeviceInfo {
                    device_id: 0,
                    key_index: None,
                },
                DeviceInfo {
                    device_id: 2,
                    key_index: None,
                },
            ],
            timestamp: 2000,
            phash: Some("2:new".to_string()),
        };
        store
            .update_device_list(record2)
            .await
            .expect("save2 failed");

        let loaded = store
            .get_devices("1234567890")
            .await
            .expect("get failed")
            .expect("record should exist");

        assert_eq!(loaded.devices.len(), 2);
        assert_eq!(loaded.phash, Some("2:new".to_string()));
    }

    #[tokio::test]
    async fn test_device_registry_get_nonexistent() {
        let store = create_test_store().await;
        let result = store.get_devices("nonexistent").await.expect("get failed");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_sender_key_status_mark_and_consume() {
        let store = create_test_store().await;

        let group = "group123@g.us";
        let participant = "user1@s.whatsapp.net";

        store
            .mark_forget_sender_key(group, participant)
            .await
            .expect("mark failed");

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed.len(), 1);
        assert!(consumed.contains(&participant.to_string()));

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(consumed.is_empty());
    }

    #[tokio::test]
    async fn test_sender_key_status_consume_multiple() {
        let store = create_test_store().await;

        let group = "group123@g.us";

        store
            .mark_forget_sender_key(group, "user1@s.whatsapp.net")
            .await
            .expect("mark failed");
        store
            .mark_forget_sender_key(group, "user2@s.whatsapp.net")
            .await
            .expect("mark failed");

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed.len(), 2);
        assert!(consumed.contains(&"user1@s.whatsapp.net".to_string()));
        assert!(consumed.contains(&"user2@s.whatsapp.net".to_string()));

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(consumed.is_empty());
    }

    #[tokio::test]
    async fn test_sender_key_status_remark_after_consume() {
        let store = create_test_store().await;

        let group = "group123@g.us";
        let participant = "user1@s.whatsapp.net";

        store
            .mark_forget_sender_key(group, participant)
            .await
            .expect("mark failed");
        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed, vec![participant.to_string()]);

        // A mark added after a consume must be visible to the next consume.
        store
            .mark_forget_sender_key(group, participant)
            .await
            .expect("re-mark failed");
        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed, vec![participant.to_string()]);

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(consumed.is_empty());
    }

    #[tokio::test]
    async fn test_tc_token_put_and_get() {
        let store = create_test_store().await;

        let entry = TcTokenEntry {
            token: vec![1, 2, 3, 4, 5],
            token_timestamp: 1707000000,
            sender_timestamp: Some(1707000100),
        };

        store
            .put_tc_token("user@lid", &entry)
            .await
            .expect("put failed");

        let loaded = store
            .get_tc_token("user@lid")
            .await
            .expect("get failed")
            .expect("should exist");

        assert_eq!(loaded.token, vec![1, 2, 3, 4, 5]);
        assert_eq!(loaded.token_timestamp, 1707000000);
        assert_eq!(loaded.sender_timestamp, Some(1707000100));
    }

    #[tokio::test]
    async fn test_tc_token_upsert() {
        let store = create_test_store().await;

        let entry1 = TcTokenEntry {
            token: vec![1, 2, 3],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        store.put_tc_token("user@lid", &entry1).await.unwrap();

        let entry2 = TcTokenEntry {
            token: vec![4, 5, 6],
            token_timestamp: 2000,
            sender_timestamp: Some(1500),
        };
        store.put_tc_token("user@lid", &entry2).await.unwrap();

        let loaded = store.get_tc_token("user@lid").await.unwrap().unwrap();
        assert_eq!(loaded.token, vec![4, 5, 6]);
        assert_eq!(loaded.token_timestamp, 2000);
        assert_eq!(loaded.sender_timestamp, Some(1500));
    }

    #[tokio::test]
    async fn test_tc_token_delete() {
        let store = create_test_store().await;

        let entry = TcTokenEntry {
            token: vec![1, 2, 3],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        store.put_tc_token("user@lid", &entry).await.unwrap();
        store.delete_tc_token("user@lid").await.unwrap();

        let result = store.get_tc_token("user@lid").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_tc_token_get_all_jids() {
        let store = create_test_store().await;

        let entry = TcTokenEntry {
            token: vec![1],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        store.put_tc_token("user1@lid", &entry).await.unwrap();
        store.put_tc_token("user2@lid", &entry).await.unwrap();
        store.put_tc_token("user3@lid", &entry).await.unwrap();

        let mut jids = store.get_all_tc_token_jids().await.unwrap();
        jids.sort();
        assert_eq!(jids, vec!["user1@lid", "user2@lid", "user3@lid"]);
    }

    #[tokio::test]
    async fn test_tc_token_delete_expired() {
        let store = create_test_store().await;

        let old = TcTokenEntry {
            token: vec![1],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        let recent = TcTokenEntry {
            token: vec![2],
            token_timestamp: 5000,
            sender_timestamp: None,
        };
        store.put_tc_token("old@lid", &old).await.unwrap();
        store.put_tc_token("recent@lid", &recent).await.unwrap();

        let deleted = store.delete_expired_tc_tokens(3000).await.unwrap();
        assert_eq!(deleted, 1);

        assert!(store.get_tc_token("old@lid").await.unwrap().is_none());
        assert!(store.get_tc_token("recent@lid").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_tc_token_get_nonexistent() {
        let store = create_test_store().await;
        let result = store.get_tc_token("nonexistent@lid").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_sender_key_status_different_groups() {
        let store = create_test_store().await;

        let group1 = "group1@g.us";
        let group2 = "group2@g.us";
        let participant = "user@s.whatsapp.net";

        store
            .mark_forget_sender_key(group1, participant)
            .await
            .expect("mark failed");

        let consumed = store.consume_forget_marks(group1).await.unwrap();
        assert_eq!(consumed.len(), 1);

        let consumed = store.consume_forget_marks(group2).await.unwrap();
        assert!(consumed.is_empty());
    }

    #[tokio::test]
    async fn test_base_key_save_compare_and_delete() {
        let store = create_test_store().await;

        let address = "user@s.whatsapp.net.0";
        let message_id = "MSG-1";

        store
            .save_base_key(address, message_id, &[1, 2, 3])
            .await
            .expect("save failed");
        assert!(
            store
                .has_same_base_key(address, message_id, &[1, 2, 3])
                .await
                .expect("compare failed")
        );
        assert!(
            !store
                .has_same_base_key(address, message_id, &[9, 9, 9])
                .await
                .expect("compare failed")
        );

        store
            .delete_base_key(address, message_id)
            .await
            .expect("delete failed");
        assert!(
            !store
                .has_same_base_key(address, message_id, &[1, 2, 3])
                .await
                .expect("compare failed")
        );
    }

    #[tokio::test]
    async fn test_base_key_upsert_overwrites() {
        let store = create_test_store().await;

        let address = "user@s.whatsapp.net.0";
        let message_id = "MSG-2";

        store
            .save_base_key(address, message_id, &[1, 2, 3])
            .await
            .expect("save1 failed");
        store
            .save_base_key(address, message_id, &[4, 5, 6])
            .await
            .expect("save2 failed");

        assert!(
            store
                .has_same_base_key(address, message_id, &[4, 5, 6])
                .await
                .expect("compare failed")
        );
        assert!(
            !store
                .has_same_base_key(address, message_id, &[1, 2, 3])
                .await
                .expect("compare failed")
        );
    }

    #[tokio::test]
    async fn test_base_key_missing_is_not_same() {
        let store = create_test_store().await;

        assert!(
            !store
                .has_same_base_key("nobody@s.whatsapp.net.0", "MSG-X", &[1])
                .await
                .expect("compare failed")
        );
        store
            .delete_base_key("nobody@s.whatsapp.net.0", "MSG-X")
            .await
            .expect("deleting a missing base key should succeed");
    }
}