Skip to main content

whatsapp_rust_sqlite_storage/
sqlite_store.rs

1use crate::schema::*;
2use async_trait::async_trait;
3use diesel::prelude::*;
4use diesel::r2d2::{ConnectionManager, Pool};
5use diesel::result::{DatabaseErrorKind, Error as DieselError};
6use diesel::sql_query;
7use diesel::sqlite::SqliteConnection;
8use diesel::upsert::excluded;
9use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
10use log::warn;
11use std::sync::Arc;
12use wacore::appstate::hash::HashState;
13use wacore::appstate::processor::AppStateMutationMAC;
14use wacore::libsignal::protocol::{KeyPair, PrivateKey, PublicKey};
15use wacore::store::Device as CoreDevice;
16use wacore::store::error::{Result, StoreError};
17use wacore::store::traits::*;
18use wacore_binary::jid::Jid;
19
20/// Internal error type that preserves the Diesel error for structured matching
21/// before converting to `StoreError`. Used in retry loops where we need to
22/// distinguish retriable SQLite lock errors from other failures.
23enum DieselOrStore {
24    Diesel(DieselError),
25    Store(StoreError),
26}
27
28impl From<DieselOrStore> for StoreError {
29    fn from(e: DieselOrStore) -> Self {
30        match e {
31            DieselOrStore::Diesel(e) => StoreError::Database(e.to_string()),
32            DieselOrStore::Store(e) => e,
33        }
34    }
35}
36
37/// Check if a Diesel error represents a retriable SQLite lock contention.
38///
39/// SQLite BUSY (error code 5) and LOCKED (error code 6) both map to
40/// `DatabaseError(Unknown, _)` in Diesel. We inspect the error message
41/// from `sqlite3_errmsg()` to distinguish them from other unknown errors.
42fn is_retriable_sqlite_error(error: &DieselError) -> bool {
43    match error {
44        DieselError::DatabaseError(DatabaseErrorKind::Unknown, info) => {
45            let msg = info.message();
46            msg.contains("locked") || msg.contains("busy")
47        }
48        _ => false,
49    }
50}
51
52const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
53
54type SqlitePool = Pool<ConnectionManager<SqliteConnection>>;
55type DeviceRow = (
56    i32,
57    String,
58    String,
59    i32,
60    Vec<u8>,
61    Vec<u8>,
62    Vec<u8>,
63    i32,
64    Vec<u8>,
65    Vec<u8>,
66    Option<Vec<u8>>,
67    String,
68    i32,
69    i32,
70    i64,
71    i64,
72    Option<Vec<u8>>,
73    Option<String>,
74    i32,
75);
76
77#[derive(Clone)]
78pub struct SqliteStore {
79    pub(crate) pool: SqlitePool,
80    pub(crate) db_semaphore: Arc<tokio::sync::Semaphore>,
81    pub(crate) database_path: String,
82    device_id: i32,
83}
84
85#[derive(Debug, Clone, Copy)]
86struct ConnectionOptions;
87
88impl diesel::r2d2::CustomizeConnection<SqliteConnection, diesel::r2d2::Error>
89    for ConnectionOptions
90{
91    fn on_acquire(
92        &self,
93        conn: &mut SqliteConnection,
94    ) -> std::result::Result<(), diesel::r2d2::Error> {
95        diesel::sql_query("PRAGMA busy_timeout = 30000;")
96            .execute(conn)
97            .map_err(diesel::r2d2::Error::QueryError)?;
98        diesel::sql_query("PRAGMA synchronous = NORMAL;")
99            .execute(conn)
100            .map_err(diesel::r2d2::Error::QueryError)?;
101        diesel::sql_query("PRAGMA cache_size = 512;")
102            .execute(conn)
103            .map_err(diesel::r2d2::Error::QueryError)?;
104        diesel::sql_query("PRAGMA temp_store = memory;")
105            .execute(conn)
106            .map_err(diesel::r2d2::Error::QueryError)?;
107        diesel::sql_query("PRAGMA foreign_keys = ON;")
108            .execute(conn)
109            .map_err(diesel::r2d2::Error::QueryError)?;
110        Ok(())
111    }
112}
113
114fn parse_database_path(database_url: &str) -> Result<String> {
115    // Reject in-memory databases
116    if database_url == ":memory:" {
117        return Err(StoreError::Database(
118            "Snapshot not supported for in-memory databases".to_string(),
119        ));
120    }
121
122    // Strip query string and fragment
123    let path = database_url
124        .split(['?', '#'])
125        .next()
126        .unwrap_or(database_url);
127
128    // Remove sqlite:// prefix if present
129    let path = path.trim_start_matches("sqlite://");
130
131    // Check if the resulting path looks like an in-memory marker
132    if path == ":memory:" || path.starts_with(":memory:?") {
133        return Err(StoreError::Database(
134            "Snapshot not supported for in-memory databases".to_string(),
135        ));
136    }
137
138    Ok(path.to_string())
139}
140
141impl SqliteStore {
142    pub async fn new(database_url: &str) -> std::result::Result<Self, StoreError> {
143        let manager = ConnectionManager::<SqliteConnection>::new(database_url);
144
145        let pool_size = 2;
146
147        let pool = Pool::builder()
148            .max_size(pool_size)
149            .connection_customizer(Box::new(ConnectionOptions))
150            .build(manager)
151            .map_err(|e| StoreError::Connection(e.to_string()))?;
152
153        let pool_clone = pool.clone();
154        tokio::task::spawn_blocking(move || -> std::result::Result<(), StoreError> {
155            let mut conn = pool_clone
156                .get()
157                .map_err(|e| StoreError::Connection(e.to_string()))?;
158
159            diesel::sql_query("PRAGMA journal_mode = WAL;")
160                .execute(&mut conn)
161                .map_err(|e| StoreError::Database(e.to_string()))?;
162
163            conn.run_pending_migrations(MIGRATIONS)
164                .map_err(|e| StoreError::Migration(e.to_string()))?;
165
166            Ok(())
167        })
168        .await
169        .map_err(|e| StoreError::Database(e.to_string()))??;
170
171        let database_path = parse_database_path(database_url)?;
172
173        Ok(Self {
174            pool,
175            db_semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
176            database_path,
177            device_id: 1,
178        })
179    }
180
181    pub async fn new_for_device(
182        database_url: &str,
183        device_id: i32,
184    ) -> std::result::Result<Self, StoreError> {
185        let mut store = Self::new(database_url).await?;
186        store.device_id = device_id;
187        Ok(store)
188    }
189
190    pub fn device_id(&self) -> i32 {
191        self.device_id
192    }
193
194    async fn with_semaphore<F, T>(&self, f: F) -> Result<T>
195    where
196        F: FnOnce() -> Result<T> + Send + 'static,
197        T: Send + 'static,
198    {
199        let permit = self
200            .db_semaphore
201            .clone()
202            .acquire_owned()
203            .await
204            .map_err(|e| StoreError::Database(format!("Semaphore error: {}", e)))?;
205        let result = tokio::task::spawn_blocking(move || {
206            let res = f();
207            drop(permit);
208            res
209        })
210        .await
211        .map_err(|e| StoreError::Database(e.to_string()))??;
212        Ok(result)
213    }
214
215    /// Execute a database operation with semaphore serialization and retry on
216    /// transient SQLite lock/busy errors. Mirrors WhatsApp Web's PromiseQueue
217    /// pattern that serializes database commits to avoid concurrent write contention.
218    async fn with_retry<F, T>(&self, op_name: &str, make_op: F) -> Result<T>
219    where
220        F: Fn() -> Box<
221            dyn FnOnce(&mut SqliteConnection) -> std::result::Result<T, DieselError> + Send,
222        >,
223        T: Send + 'static,
224    {
225        const MAX_RETRIES: u32 = 5;
226
227        for attempt in 0..=MAX_RETRIES {
228            let permit = self
229                .db_semaphore
230                .clone()
231                .acquire_owned()
232                .await
233                .map_err(|e| StoreError::Database(format!("Semaphore error: {}", e)))?;
234
235            let pool = self.pool.clone();
236            let op = make_op();
237
238            let result =
239                tokio::task::spawn_blocking(move || -> std::result::Result<T, DieselOrStore> {
240                    let _permit = permit;
241                    let mut conn = pool
242                        .get()
243                        .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
244                    op(&mut conn).map_err(DieselOrStore::Diesel)
245                })
246                .await;
247
248            match result {
249                Ok(Ok(val)) => return Ok(val),
250                Ok(Err(DieselOrStore::Diesel(ref e)))
251                    if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
252                {
253                    let delay_ms = 10u64 * (1u64 << attempt.min(4));
254                    tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
255                }
256                Ok(Err(e)) => return Err(e.into()),
257                Err(e) => return Err(StoreError::Database(e.to_string())),
258            }
259        }
260
261        Err(StoreError::Database(format!(
262            "{} exhausted retries",
263            op_name
264        )))
265    }
266
267    fn serialize_keypair(&self, key_pair: &KeyPair) -> Result<Vec<u8>> {
268        let mut bytes = Vec::with_capacity(64);
269        bytes.extend_from_slice(key_pair.private_key.serialize());
270        bytes.extend_from_slice(key_pair.public_key.public_key_bytes());
271        Ok(bytes)
272    }
273
274    fn deserialize_keypair(&self, bytes: &[u8]) -> Result<KeyPair> {
275        if bytes.len() != 64 {
276            return Err(StoreError::Serialization(format!(
277                "Invalid KeyPair length: {}",
278                bytes.len()
279            )));
280        }
281
282        let private_key = PrivateKey::deserialize(&bytes[0..32])
283            .map_err(|e| StoreError::Serialization(e.to_string()))?;
284        let public_key = PublicKey::from_djb_public_key_bytes(&bytes[32..64])
285            .map_err(|e| StoreError::Serialization(e.to_string()))?;
286
287        Ok(KeyPair::new(public_key, private_key))
288    }
289
290    pub async fn save_device_data_for_device(
291        &self,
292        device_id: i32,
293        device_data: &CoreDevice,
294    ) -> Result<()> {
295        let pool = self.pool.clone();
296        let noise_key_data = self.serialize_keypair(&device_data.noise_key)?;
297        let identity_key_data = self.serialize_keypair(&device_data.identity_key)?;
298        let signed_pre_key_data = self.serialize_keypair(&device_data.signed_pre_key)?;
299        let account_data = device_data
300            .account
301            .as_ref()
302            .map(wacore::store::device::account_serde::to_bytes);
303        let registration_id = device_data.registration_id as i32;
304        let signed_pre_key_id = device_data.signed_pre_key_id as i32;
305        let signed_pre_key_signature: Vec<u8> = device_data.signed_pre_key_signature.to_vec();
306        let adv_secret_key: Vec<u8> = device_data.adv_secret_key.to_vec();
307        let push_name = device_data.push_name.clone();
308        let app_version_primary = device_data.app_version_primary as i32;
309        let app_version_secondary = device_data.app_version_secondary as i32;
310        let app_version_tertiary = device_data.app_version_tertiary as i64;
311        let app_version_last_fetched_ms = device_data.app_version_last_fetched_ms;
312        let edge_routing_info = device_data.edge_routing_info.clone();
313        let props_hash = device_data.props_hash.clone();
314        let next_pre_key_id = device_data.next_pre_key_id as i32;
315        let new_lid = device_data
316            .lid
317            .as_ref()
318            .map(|j| j.to_string())
319            .unwrap_or_default();
320        let new_pn = device_data
321            .pn
322            .as_ref()
323            .map(|j| j.to_string())
324            .unwrap_or_default();
325
326        tokio::task::spawn_blocking(move || -> Result<()> {
327            let mut conn = pool
328                .get()
329                .map_err(|e| StoreError::Connection(e.to_string()))?;
330
331            diesel::insert_into(device::table)
332                .values((
333                    device::id.eq(device_id),
334                    device::lid.eq(&new_lid),
335                    device::pn.eq(&new_pn),
336                    device::registration_id.eq(registration_id),
337                    device::noise_key.eq(&noise_key_data),
338                    device::identity_key.eq(&identity_key_data),
339                    device::signed_pre_key.eq(&signed_pre_key_data),
340                    device::signed_pre_key_id.eq(signed_pre_key_id),
341                    device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
342                    device::adv_secret_key.eq(&adv_secret_key[..]),
343                    device::account.eq(account_data.clone()),
344                    device::push_name.eq(&push_name),
345                    device::app_version_primary.eq(app_version_primary),
346                    device::app_version_secondary.eq(app_version_secondary),
347                    device::app_version_tertiary.eq(app_version_tertiary),
348                    device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
349                    device::edge_routing_info.eq(edge_routing_info.clone()),
350                    device::props_hash.eq(props_hash.clone()),
351                    device::next_pre_key_id.eq(next_pre_key_id),
352                ))
353                .on_conflict(device::id)
354                .do_update()
355                .set((
356                    device::lid.eq(&new_lid),
357                    device::pn.eq(&new_pn),
358                    device::registration_id.eq(registration_id),
359                    device::noise_key.eq(&noise_key_data),
360                    device::identity_key.eq(&identity_key_data),
361                    device::signed_pre_key.eq(&signed_pre_key_data),
362                    device::signed_pre_key_id.eq(signed_pre_key_id),
363                    device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
364                    device::adv_secret_key.eq(&adv_secret_key[..]),
365                    device::account.eq(account_data.clone()),
366                    device::push_name.eq(&push_name),
367                    device::app_version_primary.eq(app_version_primary),
368                    device::app_version_secondary.eq(app_version_secondary),
369                    device::app_version_tertiary.eq(app_version_tertiary),
370                    device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
371                    device::edge_routing_info.eq(edge_routing_info),
372                    device::props_hash.eq(props_hash),
373                    device::next_pre_key_id.eq(next_pre_key_id),
374                ))
375                .execute(&mut conn)
376                .map_err(|e| StoreError::Database(e.to_string()))?;
377
378            Ok(())
379        })
380        .await
381        .map_err(|e| StoreError::Database(e.to_string()))??;
382
383        Ok(())
384    }
385
386    pub async fn create_new_device(&self) -> Result<i32> {
387        use crate::schema::device;
388
389        let pool = self.pool.clone();
390        tokio::task::spawn_blocking(move || -> Result<i32> {
391            let mut conn = pool
392                .get()
393                .map_err(|e| StoreError::Connection(e.to_string()))?;
394
395            let new_device = wacore::store::Device::new();
396
397            let noise_key_data = {
398                let mut bytes = Vec::with_capacity(64);
399                bytes.extend_from_slice(new_device.noise_key.private_key.serialize());
400                bytes.extend_from_slice(new_device.noise_key.public_key.public_key_bytes());
401                bytes
402            };
403            let identity_key_data = {
404                let mut bytes = Vec::with_capacity(64);
405                bytes.extend_from_slice(new_device.identity_key.private_key.serialize());
406                bytes.extend_from_slice(new_device.identity_key.public_key.public_key_bytes());
407                bytes
408            };
409            let signed_pre_key_data = {
410                let mut bytes = Vec::with_capacity(64);
411                bytes.extend_from_slice(new_device.signed_pre_key.private_key.serialize());
412                bytes.extend_from_slice(new_device.signed_pre_key.public_key.public_key_bytes());
413                bytes
414            };
415
416            diesel::insert_into(device::table)
417                .values((
418                    device::lid.eq(""),
419                    device::pn.eq(""),
420                    device::registration_id.eq(new_device.registration_id as i32),
421                    device::noise_key.eq(&noise_key_data),
422                    device::identity_key.eq(&identity_key_data),
423                    device::signed_pre_key.eq(&signed_pre_key_data),
424                    device::signed_pre_key_id.eq(new_device.signed_pre_key_id as i32),
425                    device::signed_pre_key_signature.eq(&new_device.signed_pre_key_signature[..]),
426                    device::adv_secret_key.eq(&new_device.adv_secret_key[..]),
427                    device::account.eq(None::<Vec<u8>>),
428                    device::push_name.eq(&new_device.push_name),
429                    device::app_version_primary.eq(new_device.app_version_primary as i32),
430                    device::app_version_secondary.eq(new_device.app_version_secondary as i32),
431                    device::app_version_tertiary.eq(new_device.app_version_tertiary as i64),
432                    device::app_version_last_fetched_ms.eq(new_device.app_version_last_fetched_ms),
433                    device::edge_routing_info.eq(None::<Vec<u8>>),
434                    device::props_hash.eq(None::<String>),
435                    device::next_pre_key_id.eq(new_device.next_pre_key_id as i32),
436                ))
437                .execute(&mut conn)
438                .map_err(|e| StoreError::Database(e.to_string()))?;
439
440            use diesel::sql_types::Integer;
441
442            #[derive(QueryableByName)]
443            struct LastInsertedId {
444                #[diesel(sql_type = Integer)]
445                last_insert_rowid: i32,
446            }
447
448            let device_id: i32 = sql_query("SELECT last_insert_rowid() as last_insert_rowid")
449                .get_result::<LastInsertedId>(&mut conn)
450                .map_err(|e| StoreError::Database(e.to_string()))?
451                .last_insert_rowid;
452
453            Ok(device_id)
454        })
455        .await
456        .map_err(|e| StoreError::Database(e.to_string()))?
457    }
458
459    pub async fn device_exists(&self, device_id: i32) -> Result<bool> {
460        use crate::schema::device;
461
462        let pool = self.pool.clone();
463        tokio::task::spawn_blocking(move || -> Result<bool> {
464            let mut conn = pool
465                .get()
466                .map_err(|e| StoreError::Connection(e.to_string()))?;
467
468            let count: i64 = device::table
469                .filter(device::id.eq(device_id))
470                .count()
471                .get_result(&mut conn)
472                .map_err(|e| StoreError::Database(e.to_string()))?;
473
474            Ok(count > 0)
475        })
476        .await
477        .map_err(|e| StoreError::Database(e.to_string()))?
478    }
479
480    pub async fn load_device_data_for_device(&self, device_id: i32) -> Result<Option<CoreDevice>> {
481        use crate::schema::device;
482
483        let pool = self.pool.clone();
484        let row = tokio::task::spawn_blocking(move || -> Result<Option<DeviceRow>> {
485            let mut conn = pool
486                .get()
487                .map_err(|e| StoreError::Connection(e.to_string()))?;
488            let result = device::table
489                .filter(device::id.eq(device_id))
490                .first::<DeviceRow>(&mut conn)
491                .optional()
492                .map_err(|e| StoreError::Database(e.to_string()))?;
493            Ok(result)
494        })
495        .await
496        .map_err(|e| StoreError::Database(e.to_string()))??;
497
498        if let Some((
499            _device_id,
500            lid_str,
501            pn_str,
502            registration_id,
503            noise_key_data,
504            identity_key_data,
505            signed_pre_key_data,
506            signed_pre_key_id,
507            signed_pre_key_signature_data,
508            adv_secret_key_data,
509            account_data,
510            push_name,
511            app_version_primary,
512            app_version_secondary,
513            app_version_tertiary,
514            app_version_last_fetched_ms,
515            edge_routing_info,
516            props_hash,
517            next_pre_key_id,
518        )) = row
519        {
520            let id = if !pn_str.is_empty() {
521                pn_str.parse().ok()
522            } else {
523                None
524            };
525            let lid = if !lid_str.is_empty() {
526                lid_str.parse().ok()
527            } else {
528                None
529            };
530
531            let noise_key = self.deserialize_keypair(&noise_key_data)?;
532            let identity_key = self.deserialize_keypair(&identity_key_data)?;
533            let signed_pre_key = self.deserialize_keypair(&signed_pre_key_data)?;
534
535            let signed_pre_key_signature: [u8; 64] =
536                signed_pre_key_signature_data.try_into().map_err(|_| {
537                    StoreError::Serialization("Invalid signed_pre_key_signature length".to_string())
538                })?;
539
540            let adv_secret_key: [u8; 32] = adv_secret_key_data.try_into().map_err(|_| {
541                StoreError::Serialization("Invalid adv_secret_key length".to_string())
542            })?;
543
544            let account = account_data
545                .map(|data| {
546                    wacore::store::device::account_serde::from_bytes(&data)
547                        .map_err(|e| StoreError::Serialization(e.to_string()))
548                })
549                .transpose()?;
550
551            Ok(Some(CoreDevice {
552                pn: id,
553                lid,
554                registration_id: registration_id as u32,
555                noise_key,
556                identity_key,
557                signed_pre_key,
558                signed_pre_key_id: signed_pre_key_id as u32,
559                signed_pre_key_signature,
560                adv_secret_key,
561                account,
562                push_name,
563                app_version_primary: app_version_primary as u32,
564                app_version_secondary: app_version_secondary as u32,
565                app_version_tertiary: app_version_tertiary.try_into().unwrap_or(0u32),
566                app_version_last_fetched_ms,
567                device_props: {
568                    use wacore::store::device::DEVICE_PROPS;
569                    DEVICE_PROPS.clone()
570                },
571                edge_routing_info,
572                props_hash,
573                next_pre_key_id: next_pre_key_id as u32,
574            }))
575        } else {
576            Ok(None)
577        }
578    }
579
580    pub async fn put_identity_for_device(
581        &self,
582        address: &str,
583        key: [u8; 32],
584        device_id: i32,
585    ) -> Result<()> {
586        let pool = self.pool.clone();
587        let db_semaphore = self.db_semaphore.clone();
588        let address_owned = address.to_string();
589        let key_vec = key.to_vec();
590
591        const MAX_RETRIES: u32 = 5;
592
593        for attempt in 0..=MAX_RETRIES {
594            let permit =
595                db_semaphore.clone().acquire_owned().await.map_err(|e| {
596                    StoreError::Database(format!("Failed to acquire semaphore: {}", e))
597                })?;
598
599            let pool_clone = pool.clone();
600            let address_clone = address_owned.clone();
601            let key_clone = key_vec.clone();
602
603            let result =
604                tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
605                    let mut conn = pool_clone
606                        .get()
607                        .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
608                    diesel::insert_into(identities::table)
609                        .values((
610                            identities::address.eq(address_clone),
611                            identities::key.eq(&key_clone[..]),
612                            identities::device_id.eq(device_id),
613                        ))
614                        .on_conflict((identities::address, identities::device_id))
615                        .do_update()
616                        .set(identities::key.eq(&key_clone[..]))
617                        .execute(&mut conn)
618                        .map_err(DieselOrStore::Diesel)?;
619                    Ok(())
620                })
621                .await;
622
623            drop(permit);
624
625            match result {
626                Ok(Ok(())) => return Ok(()),
627                Ok(Err(DieselOrStore::Diesel(ref e)))
628                    if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
629                {
630                    let delay_ms = 10 * 2u64.pow(attempt);
631                    warn!(
632                        "Identity write failed (attempt {}/{}): {e}. Retrying in {delay_ms}ms...",
633                        attempt + 1,
634                        MAX_RETRIES + 1,
635                    );
636                    tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
637                    continue;
638                }
639                Ok(Err(e)) => return Err(e.into()),
640                Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
641            }
642        }
643
644        Err(StoreError::Database(format!(
645            "Identity write failed after {} attempts",
646            MAX_RETRIES + 1
647        )))
648    }
649
650    pub async fn delete_identity_for_device(&self, address: &str, device_id: i32) -> Result<()> {
651        let pool = self.pool.clone();
652        let address_owned = address.to_string();
653
654        tokio::task::spawn_blocking(move || -> Result<()> {
655            let mut conn = pool
656                .get()
657                .map_err(|e| StoreError::Connection(e.to_string()))?;
658            diesel::delete(
659                identities::table
660                    .filter(identities::address.eq(address_owned))
661                    .filter(identities::device_id.eq(device_id)),
662            )
663            .execute(&mut conn)
664            .map_err(|e| StoreError::Database(e.to_string()))?;
665            Ok(())
666        })
667        .await
668        .map_err(|e| StoreError::Database(e.to_string()))??;
669
670        Ok(())
671    }
672
673    pub async fn load_identity_for_device(
674        &self,
675        address: &str,
676        device_id: i32,
677    ) -> Result<Option<Vec<u8>>> {
678        let pool = self.pool.clone();
679        let address = address.to_string();
680        let result = self
681            .with_semaphore(move || -> Result<Option<Vec<u8>>> {
682                let mut conn = pool
683                    .get()
684                    .map_err(|e| StoreError::Connection(e.to_string()))?;
685                let res: Option<Vec<u8>> = identities::table
686                    .select(identities::key)
687                    .filter(identities::address.eq(address))
688                    .filter(identities::device_id.eq(device_id))
689                    .first(&mut conn)
690                    .optional()
691                    .map_err(|e| StoreError::Database(e.to_string()))?;
692                Ok(res)
693            })
694            .await?;
695
696        Ok(result)
697    }
698
699    pub async fn get_session_for_device(
700        &self,
701        address: &str,
702        device_id: i32,
703    ) -> Result<Option<Vec<u8>>> {
704        let pool = self.pool.clone();
705        let address_for_query = address.to_string();
706        let result = self
707            .with_semaphore(move || -> Result<Option<Vec<u8>>> {
708                let mut conn = pool
709                    .get()
710                    .map_err(|e| StoreError::Connection(e.to_string()))?;
711                let res: Option<Vec<u8>> = sessions::table
712                    .select(sessions::record)
713                    .filter(sessions::address.eq(address_for_query.clone()))
714                    .filter(sessions::device_id.eq(device_id))
715                    .first(&mut conn)
716                    .optional()
717                    .map_err(|e| StoreError::Database(e.to_string()))?;
718
719                Ok(res)
720            })
721            .await?;
722
723        Ok(result)
724    }
725
726    pub async fn put_session_for_device(
727        &self,
728        address: &str,
729        session: &[u8],
730        device_id: i32,
731    ) -> Result<()> {
732        let pool = self.pool.clone();
733        let db_semaphore = self.db_semaphore.clone();
734        let address_owned = address.to_string();
735        let session_vec = session.to_vec();
736
737        const MAX_RETRIES: u32 = 5;
738
739        for attempt in 0..=MAX_RETRIES {
740            let permit =
741                db_semaphore.clone().acquire_owned().await.map_err(|e| {
742                    StoreError::Database(format!("Failed to acquire semaphore: {}", e))
743                })?;
744
745            let pool_clone = pool.clone();
746            let address_clone = address_owned.clone();
747            let session_clone = session_vec.clone();
748
749            let result =
750                tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
751                    let mut conn = pool_clone
752                        .get()
753                        .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
754                    diesel::insert_into(sessions::table)
755                        .values((
756                            sessions::address.eq(address_clone),
757                            sessions::record.eq(&session_clone),
758                            sessions::device_id.eq(device_id),
759                        ))
760                        .on_conflict((sessions::address, sessions::device_id))
761                        .do_update()
762                        .set(sessions::record.eq(&session_clone))
763                        .execute(&mut conn)
764                        .map_err(DieselOrStore::Diesel)?;
765                    Ok(())
766                })
767                .await;
768
769            drop(permit);
770
771            match result {
772                Ok(Ok(())) => return Ok(()),
773                Ok(Err(DieselOrStore::Diesel(ref e)))
774                    if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
775                {
776                    let delay_ms = 10 * 2u64.pow(attempt);
777                    warn!(
778                        "Session write failed (attempt {}/{}): {e}. Retrying in {delay_ms}ms...",
779                        attempt + 1,
780                        MAX_RETRIES + 1,
781                    );
782                    tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
783                    continue;
784                }
785                Ok(Err(e)) => return Err(e.into()),
786                Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
787            }
788        }
789
790        Err(StoreError::Database(format!(
791            "Session write failed after {} attempts",
792            MAX_RETRIES + 1
793        )))
794    }
795
796    pub async fn delete_session_for_device(&self, address: &str, device_id: i32) -> Result<()> {
797        let pool = self.pool.clone();
798        let address_owned = address.to_string();
799
800        tokio::task::spawn_blocking(move || -> Result<()> {
801            let mut conn = pool
802                .get()
803                .map_err(|e| StoreError::Connection(e.to_string()))?;
804            diesel::delete(
805                sessions::table
806                    .filter(sessions::address.eq(address_owned))
807                    .filter(sessions::device_id.eq(device_id)),
808            )
809            .execute(&mut conn)
810            .map_err(|e| StoreError::Database(e.to_string()))?;
811            Ok(())
812        })
813        .await
814        .map_err(|e| StoreError::Database(e.to_string()))??;
815
816        Ok(())
817    }
818
819    pub async fn put_sender_key_for_device(
820        &self,
821        address: &str,
822        record: &[u8],
823        device_id: i32,
824    ) -> Result<()> {
825        let pool = self.pool.clone();
826        let address = address.to_string();
827        let record_vec = record.to_vec();
828        tokio::task::spawn_blocking(move || -> Result<()> {
829            let mut conn = pool
830                .get()
831                .map_err(|e| StoreError::Connection(e.to_string()))?;
832            diesel::insert_into(sender_keys::table)
833                .values((
834                    sender_keys::address.eq(address),
835                    sender_keys::record.eq(&record_vec),
836                    sender_keys::device_id.eq(device_id),
837                ))
838                .on_conflict((sender_keys::address, sender_keys::device_id))
839                .do_update()
840                .set(sender_keys::record.eq(&record_vec))
841                .execute(&mut conn)
842                .map_err(|e| StoreError::Database(e.to_string()))?;
843            Ok(())
844        })
845        .await
846        .map_err(|e| StoreError::Database(e.to_string()))??;
847        Ok(())
848    }
849
850    pub async fn get_sender_key_for_device(
851        &self,
852        address: &str,
853        device_id: i32,
854    ) -> Result<Option<Vec<u8>>> {
855        let pool = self.pool.clone();
856        let address = address.to_string();
857        tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
858            let mut conn = pool
859                .get()
860                .map_err(|e| StoreError::Connection(e.to_string()))?;
861            let res: Option<Vec<u8>> = sender_keys::table
862                .select(sender_keys::record)
863                .filter(sender_keys::address.eq(address))
864                .filter(sender_keys::device_id.eq(device_id))
865                .first(&mut conn)
866                .optional()
867                .map_err(|e| StoreError::Database(e.to_string()))?;
868            Ok(res)
869        })
870        .await
871        .map_err(|e| StoreError::Database(e.to_string()))?
872    }
873
874    pub async fn delete_sender_key_for_device(&self, address: &str, device_id: i32) -> Result<()> {
875        let pool = self.pool.clone();
876        let address = address.to_string();
877        tokio::task::spawn_blocking(move || -> Result<()> {
878            let mut conn = pool
879                .get()
880                .map_err(|e| StoreError::Connection(e.to_string()))?;
881            diesel::delete(
882                sender_keys::table
883                    .filter(sender_keys::address.eq(address))
884                    .filter(sender_keys::device_id.eq(device_id)),
885            )
886            .execute(&mut conn)
887            .map_err(|e| StoreError::Database(e.to_string()))?;
888            Ok(())
889        })
890        .await
891        .map_err(|e| StoreError::Database(e.to_string()))??;
892        Ok(())
893    }
894
895    pub async fn get_app_state_sync_key_for_device(
896        &self,
897        key_id: &[u8],
898        device_id: i32,
899    ) -> Result<Option<AppStateSyncKey>> {
900        let pool = self.pool.clone();
901        let key_id = key_id.to_vec();
902        let res: Option<Vec<u8>> =
903            tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
904                let mut conn = pool
905                    .get()
906                    .map_err(|e| StoreError::Connection(e.to_string()))?;
907                let res: Option<Vec<u8>> = app_state_keys::table
908                    .select(app_state_keys::key_data)
909                    .filter(app_state_keys::key_id.eq(&key_id))
910                    .filter(app_state_keys::device_id.eq(device_id))
911                    .first(&mut conn)
912                    .optional()
913                    .map_err(|e| StoreError::Database(e.to_string()))?;
914                Ok(res)
915            })
916            .await
917            .map_err(|e| StoreError::Database(e.to_string()))??;
918
919        if let Some(data) = res {
920            let (key, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
921                .map_err(|e| StoreError::Serialization(e.to_string()))?;
922            Ok(Some(key))
923        } else {
924            Ok(None)
925        }
926    }
927
928    pub async fn set_app_state_sync_key_for_device(
929        &self,
930        key_id: &[u8],
931        key: AppStateSyncKey,
932        device_id: i32,
933    ) -> Result<()> {
934        let pool = self.pool.clone();
935        let key_id = key_id.to_vec();
936        let data = bincode::serde::encode_to_vec(&key, bincode::config::standard())
937            .map_err(|e| StoreError::Serialization(e.to_string()))?;
938        tokio::task::spawn_blocking(move || -> Result<()> {
939            let mut conn = pool
940                .get()
941                .map_err(|e| StoreError::Connection(e.to_string()))?;
942            diesel::insert_into(app_state_keys::table)
943                .values((
944                    app_state_keys::key_id.eq(&key_id),
945                    app_state_keys::key_data.eq(&data),
946                    app_state_keys::device_id.eq(device_id),
947                ))
948                .on_conflict((app_state_keys::key_id, app_state_keys::device_id))
949                .do_update()
950                .set(app_state_keys::key_data.eq(&data))
951                .execute(&mut conn)
952                .map_err(|e| StoreError::Database(e.to_string()))?;
953            Ok(())
954        })
955        .await
956        .map_err(|e| StoreError::Database(e.to_string()))??;
957        Ok(())
958    }
959
960    pub async fn get_latest_app_state_sync_key_id_for_device(
961        &self,
962        device_id: i32,
963    ) -> Result<Option<Vec<u8>>> {
964        let pool = self.pool.clone();
965        let res: Option<Vec<u8>> =
966            tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
967                let mut conn = pool
968                    .get()
969                    .map_err(|e| StoreError::Connection(e.to_string()))?;
970                let res: Option<Vec<u8>> = app_state_keys::table
971                    .select(app_state_keys::key_id)
972                    .filter(app_state_keys::device_id.eq(device_id))
973                    .order(app_state_keys::key_id.desc())
974                    .first(&mut conn)
975                    .optional()
976                    .map_err(|e| StoreError::Database(e.to_string()))?;
977                Ok(res)
978            })
979            .await
980            .map_err(|e| StoreError::Database(e.to_string()))??;
981        Ok(res)
982    }
983
984    pub async fn get_app_state_version_for_device(
985        &self,
986        name: &str,
987        device_id: i32,
988    ) -> Result<HashState> {
989        let pool = self.pool.clone();
990        let name = name.to_string();
991        let res: Option<Vec<u8>> =
992            tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
993                let mut conn = pool
994                    .get()
995                    .map_err(|e| StoreError::Connection(e.to_string()))?;
996                let res: Option<Vec<u8>> = app_state_versions::table
997                    .select(app_state_versions::state_data)
998                    .filter(app_state_versions::name.eq(name))
999                    .filter(app_state_versions::device_id.eq(device_id))
1000                    .first(&mut conn)
1001                    .optional()
1002                    .map_err(|e| StoreError::Database(e.to_string()))?;
1003                Ok(res)
1004            })
1005            .await
1006            .map_err(|e| StoreError::Database(e.to_string()))??;
1007
1008        if let Some(data) = res {
1009            let (state, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
1010                .map_err(|e| StoreError::Serialization(e.to_string()))?;
1011            Ok(state)
1012        } else {
1013            Ok(HashState::default())
1014        }
1015    }
1016
1017    pub async fn set_app_state_version_for_device(
1018        &self,
1019        name: &str,
1020        state: HashState,
1021        device_id: i32,
1022    ) -> Result<()> {
1023        let name = name.to_string();
1024        let data = bincode::serde::encode_to_vec(&state, bincode::config::standard())
1025            .map_err(|e| StoreError::Serialization(e.to_string()))?;
1026        self.with_retry("set_app_state_version", || {
1027            let name = name.clone();
1028            let data = data.clone();
1029            Box::new(move |conn: &mut SqliteConnection| {
1030                diesel::insert_into(app_state_versions::table)
1031                    .values((
1032                        app_state_versions::name.eq(&name),
1033                        app_state_versions::state_data.eq(&data),
1034                        app_state_versions::device_id.eq(device_id),
1035                    ))
1036                    .on_conflict((app_state_versions::name, app_state_versions::device_id))
1037                    .do_update()
1038                    .set(app_state_versions::state_data.eq(&data))
1039                    .execute(conn)?;
1040                Ok(())
1041            })
1042        })
1043        .await
1044    }
1045
1046    pub async fn put_app_state_mutation_macs_for_device(
1047        &self,
1048        name: &str,
1049        version: u64,
1050        mutations: &[AppStateMutationMAC],
1051        device_id: i32,
1052    ) -> Result<()> {
1053        if mutations.is_empty() {
1054            return Ok(());
1055        }
1056        let name = name.to_string();
1057        let mutations: Vec<AppStateMutationMAC> = mutations.to_vec();
1058        self.with_retry("put_app_state_mutation_macs", || {
1059            let name = name.clone();
1060            let mutations = mutations.clone();
1061            Box::new(move |conn: &mut SqliteConnection| {
1062                let records: Vec<_> = mutations
1063                    .iter()
1064                    .map(|m| {
1065                        (
1066                            app_state_mutation_macs::name.eq(&name),
1067                            app_state_mutation_macs::version.eq(version as i64),
1068                            app_state_mutation_macs::index_mac.eq(&m.index_mac),
1069                            app_state_mutation_macs::value_mac.eq(&m.value_mac),
1070                            app_state_mutation_macs::device_id.eq(device_id),
1071                        )
1072                    })
1073                    .collect();
1074
1075                // SQLite variable limit is typically 999 or 32766.
1076                // Each row has 5 columns. 100 rows * 5 = 500 params, which is safe.
1077                const CHUNK_SIZE: usize = 100;
1078
1079                for chunk in records.chunks(CHUNK_SIZE) {
1080                    diesel::insert_into(app_state_mutation_macs::table)
1081                        .values(chunk)
1082                        .on_conflict((
1083                            app_state_mutation_macs::name,
1084                            app_state_mutation_macs::index_mac,
1085                            app_state_mutation_macs::device_id,
1086                        ))
1087                        .do_update()
1088                        .set((
1089                            app_state_mutation_macs::version
1090                                .eq(excluded(app_state_mutation_macs::version)),
1091                            app_state_mutation_macs::value_mac
1092                                .eq(excluded(app_state_mutation_macs::value_mac)),
1093                        ))
1094                        .execute(conn)?;
1095                }
1096                Ok(())
1097            })
1098        })
1099        .await
1100    }
1101
1102    pub async fn delete_app_state_mutation_macs_for_device(
1103        &self,
1104        name: &str,
1105        index_macs: &[Vec<u8>],
1106        device_id: i32,
1107    ) -> Result<()> {
1108        if index_macs.is_empty() {
1109            return Ok(());
1110        }
1111        let name = name.to_string();
1112        let index_macs: Vec<Vec<u8>> = index_macs.to_vec();
1113        self.with_retry("delete_app_state_mutation_macs", || {
1114            let name = name.clone();
1115            let index_macs = index_macs.clone();
1116            Box::new(move |conn: &mut SqliteConnection| {
1117                // SQLite variable limit is usually 999 or higher.
1118                // We use a safe chunk size to stay well within limits.
1119                const CHUNK_SIZE: usize = 500;
1120
1121                for chunk in index_macs.chunks(CHUNK_SIZE) {
1122                    diesel::delete(
1123                        app_state_mutation_macs::table.filter(
1124                            app_state_mutation_macs::name
1125                                .eq(&name)
1126                                .and(app_state_mutation_macs::index_mac.eq_any(chunk))
1127                                .and(app_state_mutation_macs::device_id.eq(device_id)),
1128                        ),
1129                    )
1130                    .execute(conn)?;
1131                }
1132                Ok(())
1133            })
1134        })
1135        .await
1136    }
1137
1138    pub async fn get_app_state_mutation_mac_for_device(
1139        &self,
1140        name: &str,
1141        index_mac: &[u8],
1142        device_id: i32,
1143    ) -> Result<Option<Vec<u8>>> {
1144        let pool = self.pool.clone();
1145        let name = name.to_string();
1146        let index_mac = index_mac.to_vec();
1147        tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1148            let mut conn = pool
1149                .get()
1150                .map_err(|e| StoreError::Connection(e.to_string()))?;
1151            let res: Option<Vec<u8>> = app_state_mutation_macs::table
1152                .select(app_state_mutation_macs::value_mac)
1153                .filter(app_state_mutation_macs::name.eq(&name))
1154                .filter(app_state_mutation_macs::index_mac.eq(&index_mac))
1155                .filter(app_state_mutation_macs::device_id.eq(device_id))
1156                .first(&mut conn)
1157                .optional()
1158                .map_err(|e| StoreError::Database(e.to_string()))?;
1159            Ok(res)
1160        })
1161        .await
1162        .map_err(|e| StoreError::Database(e.to_string()))?
1163    }
1164}
1165
1166#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1167#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1168impl SignalStore for SqliteStore {
1169    async fn put_identity(&self, address: &str, key: [u8; 32]) -> Result<()> {
1170        self.put_identity_for_device(address, key, self.device_id)
1171            .await
1172    }
1173
1174    async fn load_identity(&self, address: &str) -> Result<Option<Vec<u8>>> {
1175        self.load_identity_for_device(address, self.device_id).await
1176    }
1177
1178    async fn delete_identity(&self, address: &str) -> Result<()> {
1179        self.delete_identity_for_device(address, self.device_id)
1180            .await
1181    }
1182
1183    async fn get_session(&self, address: &str) -> Result<Option<Vec<u8>>> {
1184        self.get_session_for_device(address, self.device_id).await
1185    }
1186
1187    async fn put_session(&self, address: &str, session: &[u8]) -> Result<()> {
1188        self.put_session_for_device(address, session, self.device_id)
1189            .await
1190    }
1191
1192    async fn delete_session(&self, address: &str) -> Result<()> {
1193        self.delete_session_for_device(address, self.device_id)
1194            .await
1195    }
1196
1197    async fn store_prekey(&self, id: u32, record: &[u8], uploaded: bool) -> Result<()> {
1198        let pool = self.pool.clone();
1199        let db_semaphore = self.db_semaphore.clone();
1200        let device_id = self.device_id;
1201        let record = record.to_vec();
1202
1203        const MAX_RETRIES: u32 = 5;
1204
1205        for attempt in 0..=MAX_RETRIES {
1206            let permit =
1207                db_semaphore.clone().acquire_owned().await.map_err(|e| {
1208                    StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1209                })?;
1210
1211            let pool_clone = pool.clone();
1212            let record_clone = record.clone();
1213
1214            let result =
1215                tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1216                    let mut conn = pool_clone
1217                        .get()
1218                        .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1219                    diesel::insert_into(prekeys::table)
1220                        .values((
1221                            prekeys::id.eq(id as i32),
1222                            prekeys::key.eq(&record_clone),
1223                            prekeys::uploaded.eq(uploaded),
1224                            prekeys::device_id.eq(device_id),
1225                        ))
1226                        .on_conflict((prekeys::id, prekeys::device_id))
1227                        .do_update()
1228                        .set((
1229                            prekeys::key.eq(&record_clone),
1230                            prekeys::uploaded.eq(uploaded),
1231                        ))
1232                        .execute(&mut conn)
1233                        .map_err(DieselOrStore::Diesel)?;
1234                    Ok(())
1235                })
1236                .await;
1237
1238            drop(permit);
1239
1240            match result {
1241                Ok(Ok(())) => return Ok(()),
1242                Ok(Err(DieselOrStore::Diesel(ref e)))
1243                    if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1244                {
1245                    let delay_ms = 10u64 * (1u64 << attempt.min(4));
1246                    tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1247                }
1248                Ok(Err(e)) => return Err(e.into()),
1249                Err(e) => return Err(StoreError::Database(e.to_string())),
1250            }
1251        }
1252
1253        Err(StoreError::Database(
1254            "store_prekey exhausted retries".to_string(),
1255        ))
1256    }
1257
1258    async fn store_prekeys_batch(&self, keys: &[(u32, Vec<u8>)], uploaded: bool) -> Result<()> {
1259        if keys.is_empty() {
1260            return Ok(());
1261        }
1262
1263        let pool = self.pool.clone();
1264        let db_semaphore = self.db_semaphore.clone();
1265        let device_id = self.device_id;
1266        let keys = keys.to_vec();
1267
1268        const MAX_RETRIES: u32 = 5;
1269
1270        for attempt in 0..=MAX_RETRIES {
1271            let permit =
1272                db_semaphore.clone().acquire_owned().await.map_err(|e| {
1273                    StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1274                })?;
1275
1276            let pool_clone = pool.clone();
1277            let keys_clone = keys.clone();
1278
1279            let result =
1280                tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1281                    let mut conn = pool_clone
1282                        .get()
1283                        .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1284
1285                    conn.transaction(|conn| {
1286                        for (id, record) in &keys_clone {
1287                            diesel::insert_into(prekeys::table)
1288                                .values((
1289                                    prekeys::id.eq(*id as i32),
1290                                    prekeys::key.eq(record),
1291                                    prekeys::uploaded.eq(uploaded),
1292                                    prekeys::device_id.eq(device_id),
1293                                ))
1294                                .on_conflict((prekeys::id, prekeys::device_id))
1295                                .do_update()
1296                                .set((prekeys::key.eq(record), prekeys::uploaded.eq(uploaded)))
1297                                .execute(conn)?;
1298                        }
1299                        Ok::<(), diesel::result::Error>(())
1300                    })
1301                    .map_err(DieselOrStore::Diesel)
1302                })
1303                .await;
1304
1305            drop(permit);
1306
1307            match result {
1308                Ok(Ok(())) => return Ok(()),
1309                Ok(Err(DieselOrStore::Diesel(ref e)))
1310                    if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1311                {
1312                    let delay_ms = 10u64 * (1u64 << attempt.min(4));
1313                    tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1314                }
1315                Ok(Err(e)) => return Err(e.into()),
1316                Err(e) => return Err(StoreError::Database(e.to_string())),
1317            }
1318        }
1319
1320        Err(StoreError::Database(
1321            "store_prekeys_batch exhausted retries".to_string(),
1322        ))
1323    }
1324
1325    async fn load_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1326        let pool = self.pool.clone();
1327        let device_id = self.device_id;
1328        tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1329            let mut conn = pool
1330                .get()
1331                .map_err(|e| StoreError::Connection(e.to_string()))?;
1332            let res: Option<Vec<u8>> = prekeys::table
1333                .select(prekeys::key)
1334                .filter(prekeys::id.eq(id as i32))
1335                .filter(prekeys::device_id.eq(device_id))
1336                .first(&mut conn)
1337                .optional()
1338                .map_err(|e| StoreError::Database(e.to_string()))?;
1339            Ok(res)
1340        })
1341        .await
1342        .map_err(|e| StoreError::Database(e.to_string()))?
1343    }
1344
1345    async fn remove_prekey(&self, id: u32) -> Result<()> {
1346        let pool = self.pool.clone();
1347        let db_semaphore = self.db_semaphore.clone();
1348        let device_id = self.device_id;
1349
1350        const MAX_RETRIES: u32 = 5;
1351
1352        for attempt in 0..=MAX_RETRIES {
1353            let permit =
1354                db_semaphore.clone().acquire_owned().await.map_err(|e| {
1355                    StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1356                })?;
1357
1358            let pool_clone = pool.clone();
1359
1360            let result =
1361                tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1362                    let mut conn = pool_clone
1363                        .get()
1364                        .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1365                    diesel::delete(
1366                        prekeys::table
1367                            .filter(prekeys::id.eq(id as i32))
1368                            .filter(prekeys::device_id.eq(device_id)),
1369                    )
1370                    .execute(&mut conn)
1371                    .map_err(DieselOrStore::Diesel)?;
1372                    Ok(())
1373                })
1374                .await;
1375
1376            drop(permit);
1377
1378            match result {
1379                Ok(Ok(())) => return Ok(()),
1380                Ok(Err(DieselOrStore::Diesel(ref e)))
1381                    if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1382                {
1383                    let delay_ms = 10u64 * (1u64 << attempt.min(4));
1384                    tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1385                }
1386                Ok(Err(e)) => return Err(e.into()),
1387                Err(e) => return Err(StoreError::Database(e.to_string())),
1388            }
1389        }
1390
1391        Err(StoreError::Database(
1392            "remove_prekey exhausted retries".to_string(),
1393        ))
1394    }
1395
1396    async fn get_max_prekey_id(&self) -> Result<u32> {
1397        let pool = self.pool.clone();
1398        let device_id = self.device_id;
1399        let db_semaphore = self.db_semaphore.clone();
1400        let _permit = db_semaphore
1401            .acquire()
1402            .await
1403            .map_err(|e| StoreError::Database(format!("Failed to acquire semaphore: {}", e)))?;
1404
1405        tokio::task::spawn_blocking(move || -> Result<u32> {
1406            let mut conn = pool
1407                .get()
1408                .map_err(|e| StoreError::Connection(e.to_string()))?;
1409            use diesel::dsl::max;
1410            let result: Option<i32> = prekeys::table
1411                .filter(prekeys::device_id.eq(device_id))
1412                .select(max(prekeys::id))
1413                .first(&mut conn)
1414                .map_err(|e| StoreError::Database(e.to_string()))?;
1415            Ok(result.unwrap_or(0) as u32)
1416        })
1417        .await
1418        .map_err(|e| StoreError::Database(e.to_string()))?
1419    }
1420
1421    async fn store_signed_prekey(&self, id: u32, record: &[u8]) -> Result<()> {
1422        let pool = self.pool.clone();
1423        let db_semaphore = self.db_semaphore.clone();
1424        let device_id = self.device_id;
1425        let record = record.to_vec();
1426
1427        const MAX_RETRIES: u32 = 5;
1428
1429        for attempt in 0..=MAX_RETRIES {
1430            let permit =
1431                db_semaphore.clone().acquire_owned().await.map_err(|e| {
1432                    StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1433                })?;
1434
1435            let pool_clone = pool.clone();
1436            let record_clone = record.clone();
1437
1438            let result =
1439                tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1440                    let mut conn = pool_clone
1441                        .get()
1442                        .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1443                    diesel::insert_into(signed_prekeys::table)
1444                        .values((
1445                            signed_prekeys::id.eq(id as i32),
1446                            signed_prekeys::record.eq(&record_clone),
1447                            signed_prekeys::device_id.eq(device_id),
1448                        ))
1449                        .on_conflict((signed_prekeys::id, signed_prekeys::device_id))
1450                        .do_update()
1451                        .set(signed_prekeys::record.eq(&record_clone))
1452                        .execute(&mut conn)
1453                        .map_err(DieselOrStore::Diesel)?;
1454                    Ok(())
1455                })
1456                .await;
1457
1458            drop(permit);
1459
1460            match result {
1461                Ok(Ok(())) => return Ok(()),
1462                Ok(Err(DieselOrStore::Diesel(ref e)))
1463                    if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1464                {
1465                    let delay_ms = 10u64 * (1u64 << attempt.min(4));
1466                    tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1467                }
1468                Ok(Err(e)) => return Err(e.into()),
1469                Err(e) => return Err(StoreError::Database(e.to_string())),
1470            }
1471        }
1472
1473        Err(StoreError::Database(
1474            "store_signed_prekey exhausted retries".to_string(),
1475        ))
1476    }
1477
1478    async fn load_signed_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1479        let pool = self.pool.clone();
1480        let device_id = self.device_id;
1481        tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1482            let mut conn = pool
1483                .get()
1484                .map_err(|e| StoreError::Connection(e.to_string()))?;
1485            let res: Option<Vec<u8>> = signed_prekeys::table
1486                .select(signed_prekeys::record)
1487                .filter(signed_prekeys::id.eq(id as i32))
1488                .filter(signed_prekeys::device_id.eq(device_id))
1489                .first(&mut conn)
1490                .optional()
1491                .map_err(|e| StoreError::Database(e.to_string()))?;
1492            Ok(res)
1493        })
1494        .await
1495        .map_err(|e| StoreError::Database(e.to_string()))?
1496    }
1497
1498    async fn load_all_signed_prekeys(&self) -> Result<Vec<(u32, Vec<u8>)>> {
1499        let pool = self.pool.clone();
1500        let device_id = self.device_id;
1501        tokio::task::spawn_blocking(move || -> Result<Vec<(u32, Vec<u8>)>> {
1502            let mut conn = pool
1503                .get()
1504                .map_err(|e| StoreError::Connection(e.to_string()))?;
1505            let results: Vec<(i32, Vec<u8>)> = signed_prekeys::table
1506                .select((signed_prekeys::id, signed_prekeys::record))
1507                .filter(signed_prekeys::device_id.eq(device_id))
1508                .load(&mut conn)
1509                .map_err(|e| StoreError::Database(e.to_string()))?;
1510            Ok(results
1511                .into_iter()
1512                .map(|(id, record)| (id as u32, record))
1513                .collect())
1514        })
1515        .await
1516        .map_err(|e| StoreError::Database(e.to_string()))?
1517    }
1518
1519    async fn remove_signed_prekey(&self, id: u32) -> Result<()> {
1520        let pool = self.pool.clone();
1521        let db_semaphore = self.db_semaphore.clone();
1522        let device_id = self.device_id;
1523
1524        const MAX_RETRIES: u32 = 5;
1525
1526        for attempt in 0..=MAX_RETRIES {
1527            let permit =
1528                db_semaphore.clone().acquire_owned().await.map_err(|e| {
1529                    StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1530                })?;
1531
1532            let pool_clone = pool.clone();
1533
1534            let result =
1535                tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1536                    let mut conn = pool_clone
1537                        .get()
1538                        .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1539                    diesel::delete(
1540                        signed_prekeys::table
1541                            .filter(signed_prekeys::id.eq(id as i32))
1542                            .filter(signed_prekeys::device_id.eq(device_id)),
1543                    )
1544                    .execute(&mut conn)
1545                    .map_err(DieselOrStore::Diesel)?;
1546                    Ok(())
1547                })
1548                .await;
1549
1550            drop(permit);
1551
1552            match result {
1553                Ok(Ok(())) => return Ok(()),
1554                Ok(Err(DieselOrStore::Diesel(ref e)))
1555                    if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1556                {
1557                    let delay_ms = 10u64 * (1u64 << attempt.min(4));
1558                    tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1559                }
1560                Ok(Err(e)) => return Err(e.into()),
1561                Err(e) => return Err(StoreError::Database(e.to_string())),
1562            }
1563        }
1564
1565        Err(StoreError::Database(
1566            "remove_signed_prekey exhausted retries".to_string(),
1567        ))
1568    }
1569
1570    async fn put_sender_key(&self, address: &str, record: &[u8]) -> Result<()> {
1571        self.put_sender_key_for_device(address, record, self.device_id)
1572            .await
1573    }
1574
1575    async fn get_sender_key(&self, address: &str) -> Result<Option<Vec<u8>>> {
1576        self.get_sender_key_for_device(address, self.device_id)
1577            .await
1578    }
1579
1580    async fn delete_sender_key(&self, address: &str) -> Result<()> {
1581        self.delete_sender_key_for_device(address, self.device_id)
1582            .await
1583    }
1584}
1585
1586#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1587#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1588impl AppSyncStore for SqliteStore {
1589    async fn get_sync_key(&self, key_id: &[u8]) -> Result<Option<AppStateSyncKey>> {
1590        self.get_app_state_sync_key_for_device(key_id, self.device_id)
1591            .await
1592    }
1593
1594    async fn set_sync_key(&self, key_id: &[u8], key: AppStateSyncKey) -> Result<()> {
1595        self.set_app_state_sync_key_for_device(key_id, key, self.device_id)
1596            .await
1597    }
1598
1599    async fn get_version(&self, name: &str) -> Result<HashState> {
1600        self.get_app_state_version_for_device(name, self.device_id)
1601            .await
1602    }
1603
1604    async fn set_version(&self, name: &str, state: HashState) -> Result<()> {
1605        self.set_app_state_version_for_device(name, state, self.device_id)
1606            .await
1607    }
1608
1609    async fn put_mutation_macs(
1610        &self,
1611        name: &str,
1612        version: u64,
1613        mutations: &[AppStateMutationMAC],
1614    ) -> Result<()> {
1615        self.put_app_state_mutation_macs_for_device(name, version, mutations, self.device_id)
1616            .await
1617    }
1618
1619    async fn get_mutation_mac(&self, name: &str, index_mac: &[u8]) -> Result<Option<Vec<u8>>> {
1620        self.get_app_state_mutation_mac_for_device(name, index_mac, self.device_id)
1621            .await
1622    }
1623
1624    async fn delete_mutation_macs(&self, name: &str, index_macs: &[Vec<u8>]) -> Result<()> {
1625        self.delete_app_state_mutation_macs_for_device(name, index_macs, self.device_id)
1626            .await
1627    }
1628
1629    async fn get_latest_sync_key_id(&self) -> Result<Option<Vec<u8>>> {
1630        self.get_latest_app_state_sync_key_id_for_device(self.device_id)
1631            .await
1632    }
1633}
1634
1635#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1636#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1637impl ProtocolStore for SqliteStore {
1638    async fn get_skdm_recipients(&self, group_jid: &str) -> Result<Vec<Jid>> {
1639        let pool = self.pool.clone();
1640        let device_id = self.device_id;
1641        let group_jid = group_jid.to_string();
1642        tokio::task::spawn_blocking(move || -> Result<Vec<Jid>> {
1643            let mut conn = pool
1644                .get()
1645                .map_err(|e| StoreError::Connection(e.to_string()))?;
1646            let recipients: Vec<String> = skdm_recipients::table
1647                .select(skdm_recipients::device_jid)
1648                .filter(skdm_recipients::group_jid.eq(&group_jid))
1649                .filter(skdm_recipients::device_id.eq(device_id))
1650                .load(&mut conn)
1651                .map_err(|e| StoreError::Database(e.to_string()))?;
1652            let jids: Vec<Jid> = recipients
1653                .iter()
1654                .filter_map(|s| match s.parse::<Jid>() {
1655                    Ok(jid) => Some(jid),
1656                    Err(e) => {
1657                        warn!("Failed to parse SKDM recipient '{}': {}", s, e);
1658                        None
1659                    }
1660                })
1661                .collect();
1662            Ok(jids)
1663        })
1664        .await
1665        .map_err(|e| StoreError::Database(e.to_string()))?
1666    }
1667
1668    async fn add_skdm_recipients(&self, group_jid: &str, device_jids: &[Jid]) -> Result<()> {
1669        if device_jids.is_empty() {
1670            return Ok(());
1671        }
1672        let pool = self.pool.clone();
1673        let device_id = self.device_id;
1674        let group_jid = group_jid.to_string();
1675        let device_jid_strs: Vec<String> = device_jids.iter().map(|j| j.to_string()).collect();
1676        let now = std::time::SystemTime::now()
1677            .duration_since(std::time::UNIX_EPOCH)
1678            .unwrap_or_default()
1679            .as_secs() as i32;
1680        tokio::task::spawn_blocking(move || -> Result<()> {
1681            let mut conn = pool
1682                .get()
1683                .map_err(|e| StoreError::Connection(e.to_string()))?;
1684
1685            let values: Vec<_> = device_jid_strs
1686                .iter()
1687                .map(|device_jid| {
1688                    (
1689                        skdm_recipients::group_jid.eq(&group_jid),
1690                        skdm_recipients::device_jid.eq(device_jid),
1691                        skdm_recipients::device_id.eq(device_id),
1692                        skdm_recipients::created_at.eq(now),
1693                    )
1694                })
1695                .collect();
1696
1697            const CHUNK_SIZE: usize = 200; // SQLite variable limit ~999, 4 cols/row
1698
1699            for chunk in values.chunks(CHUNK_SIZE) {
1700                diesel::insert_into(skdm_recipients::table)
1701                    .values(chunk)
1702                    .on_conflict((
1703                        skdm_recipients::group_jid,
1704                        skdm_recipients::device_jid,
1705                        skdm_recipients::device_id,
1706                    ))
1707                    .do_nothing()
1708                    .execute(&mut conn)
1709                    .map_err(|e| StoreError::Database(e.to_string()))?;
1710            }
1711            Ok(())
1712        })
1713        .await
1714        .map_err(|e| StoreError::Database(e.to_string()))??;
1715        Ok(())
1716    }
1717
1718    async fn clear_skdm_recipients(&self, group_jid: &str) -> Result<()> {
1719        let pool = self.pool.clone();
1720        let device_id = self.device_id;
1721        let group_jid = group_jid.to_string();
1722        tokio::task::spawn_blocking(move || -> Result<()> {
1723            let mut conn = pool
1724                .get()
1725                .map_err(|e| StoreError::Connection(e.to_string()))?;
1726            diesel::delete(
1727                skdm_recipients::table
1728                    .filter(skdm_recipients::group_jid.eq(&group_jid))
1729                    .filter(skdm_recipients::device_id.eq(device_id)),
1730            )
1731            .execute(&mut conn)
1732            .map_err(|e| StoreError::Database(e.to_string()))?;
1733            Ok(())
1734        })
1735        .await
1736        .map_err(|e| StoreError::Database(e.to_string()))??;
1737        Ok(())
1738    }
1739
1740    async fn get_lid_mapping(&self, lid: &str) -> Result<Option<LidPnMappingEntry>> {
1741        let pool = self.pool.clone();
1742        let device_id = self.device_id;
1743        let lid = lid.to_string();
1744        tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1745            let mut conn = pool
1746                .get()
1747                .map_err(|e| StoreError::Connection(e.to_string()))?;
1748            let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1749                .select((
1750                    lid_pn_mapping::lid,
1751                    lid_pn_mapping::phone_number,
1752                    lid_pn_mapping::created_at,
1753                    lid_pn_mapping::learning_source,
1754                    lid_pn_mapping::updated_at,
1755                ))
1756                .filter(lid_pn_mapping::lid.eq(&lid))
1757                .filter(lid_pn_mapping::device_id.eq(device_id))
1758                .first(&mut conn)
1759                .optional()
1760                .map_err(|e| StoreError::Database(e.to_string()))?;
1761            Ok(row.map(
1762                |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1763                    lid,
1764                    phone_number,
1765                    created_at,
1766                    updated_at,
1767                    learning_source,
1768                },
1769            ))
1770        })
1771        .await
1772        .map_err(|e| StoreError::Database(e.to_string()))?
1773    }
1774
1775    async fn get_pn_mapping(&self, phone: &str) -> Result<Option<LidPnMappingEntry>> {
1776        let pool = self.pool.clone();
1777        let device_id = self.device_id;
1778        let phone = phone.to_string();
1779        tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1780            let mut conn = pool
1781                .get()
1782                .map_err(|e| StoreError::Connection(e.to_string()))?;
1783            let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1784                .select((
1785                    lid_pn_mapping::lid,
1786                    lid_pn_mapping::phone_number,
1787                    lid_pn_mapping::created_at,
1788                    lid_pn_mapping::learning_source,
1789                    lid_pn_mapping::updated_at,
1790                ))
1791                .filter(lid_pn_mapping::phone_number.eq(&phone))
1792                .filter(lid_pn_mapping::device_id.eq(device_id))
1793                .order(lid_pn_mapping::updated_at.desc())
1794                .first(&mut conn)
1795                .optional()
1796                .map_err(|e| StoreError::Database(e.to_string()))?;
1797            Ok(row.map(
1798                |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1799                    lid,
1800                    phone_number,
1801                    created_at,
1802                    updated_at,
1803                    learning_source,
1804                },
1805            ))
1806        })
1807        .await
1808        .map_err(|e| StoreError::Database(e.to_string()))?
1809    }
1810
1811    async fn put_lid_mapping(&self, entry: &LidPnMappingEntry) -> Result<()> {
1812        let pool = self.pool.clone();
1813        let device_id = self.device_id;
1814        let entry = entry.clone();
1815        tokio::task::spawn_blocking(move || -> Result<()> {
1816            let mut conn = pool
1817                .get()
1818                .map_err(|e| StoreError::Connection(e.to_string()))?;
1819            diesel::insert_into(lid_pn_mapping::table)
1820                .values((
1821                    lid_pn_mapping::lid.eq(&entry.lid),
1822                    lid_pn_mapping::phone_number.eq(&entry.phone_number),
1823                    lid_pn_mapping::created_at.eq(entry.created_at),
1824                    lid_pn_mapping::learning_source.eq(&entry.learning_source),
1825                    lid_pn_mapping::updated_at.eq(entry.updated_at),
1826                    lid_pn_mapping::device_id.eq(device_id),
1827                ))
1828                .on_conflict((lid_pn_mapping::lid, lid_pn_mapping::device_id))
1829                .do_update()
1830                .set((
1831                    lid_pn_mapping::phone_number.eq(&entry.phone_number),
1832                    lid_pn_mapping::learning_source.eq(&entry.learning_source),
1833                    lid_pn_mapping::updated_at.eq(entry.updated_at),
1834                ))
1835                .execute(&mut conn)
1836                .map_err(|e| StoreError::Database(e.to_string()))?;
1837            Ok(())
1838        })
1839        .await
1840        .map_err(|e| StoreError::Database(e.to_string()))??;
1841        Ok(())
1842    }
1843
1844    async fn get_all_lid_mappings(&self) -> Result<Vec<LidPnMappingEntry>> {
1845        let pool = self.pool.clone();
1846        let device_id = self.device_id;
1847        tokio::task::spawn_blocking(move || -> Result<Vec<LidPnMappingEntry>> {
1848            let mut conn = pool
1849                .get()
1850                .map_err(|e| StoreError::Connection(e.to_string()))?;
1851            let rows: Vec<(String, String, i64, String, i64)> = lid_pn_mapping::table
1852                .select((
1853                    lid_pn_mapping::lid,
1854                    lid_pn_mapping::phone_number,
1855                    lid_pn_mapping::created_at,
1856                    lid_pn_mapping::learning_source,
1857                    lid_pn_mapping::updated_at,
1858                ))
1859                .filter(lid_pn_mapping::device_id.eq(device_id))
1860                .load(&mut conn)
1861                .map_err(|e| StoreError::Database(e.to_string()))?;
1862            Ok(rows
1863                .into_iter()
1864                .map(
1865                    |(lid, phone_number, created_at, learning_source, updated_at)| {
1866                        LidPnMappingEntry {
1867                            lid,
1868                            phone_number,
1869                            created_at,
1870                            updated_at,
1871                            learning_source,
1872                        }
1873                    },
1874                )
1875                .collect())
1876        })
1877        .await
1878        .map_err(|e| StoreError::Database(e.to_string()))?
1879    }
1880
1881    async fn save_base_key(&self, address: &str, message_id: &str, base_key: &[u8]) -> Result<()> {
1882        let pool = self.pool.clone();
1883        let device_id = self.device_id;
1884        let address = address.to_string();
1885        let message_id = message_id.to_string();
1886        let base_key = base_key.to_vec();
1887        let now = std::time::SystemTime::now()
1888            .duration_since(std::time::UNIX_EPOCH)
1889            .unwrap_or_default()
1890            .as_secs() as i32;
1891        tokio::task::spawn_blocking(move || -> Result<()> {
1892            let mut conn = pool
1893                .get()
1894                .map_err(|e| StoreError::Connection(e.to_string()))?;
1895            diesel::insert_into(base_keys::table)
1896                .values((
1897                    base_keys::address.eq(&address),
1898                    base_keys::message_id.eq(&message_id),
1899                    base_keys::base_key.eq(&base_key),
1900                    base_keys::device_id.eq(device_id),
1901                    base_keys::created_at.eq(now),
1902                ))
1903                .on_conflict((
1904                    base_keys::address,
1905                    base_keys::message_id,
1906                    base_keys::device_id,
1907                ))
1908                .do_update()
1909                .set(base_keys::base_key.eq(&base_key))
1910                .execute(&mut conn)
1911                .map_err(|e| StoreError::Database(e.to_string()))?;
1912            Ok(())
1913        })
1914        .await
1915        .map_err(|e| StoreError::Database(e.to_string()))??;
1916        Ok(())
1917    }
1918
1919    async fn has_same_base_key(
1920        &self,
1921        address: &str,
1922        message_id: &str,
1923        current_base_key: &[u8],
1924    ) -> Result<bool> {
1925        let pool = self.pool.clone();
1926        let device_id = self.device_id;
1927        let address = address.to_string();
1928        let message_id = message_id.to_string();
1929        let current_base_key = current_base_key.to_vec();
1930        tokio::task::spawn_blocking(move || -> Result<bool> {
1931            let mut conn = pool
1932                .get()
1933                .map_err(|e| StoreError::Connection(e.to_string()))?;
1934            let stored_key: Option<Vec<u8>> = base_keys::table
1935                .select(base_keys::base_key)
1936                .filter(base_keys::address.eq(&address))
1937                .filter(base_keys::message_id.eq(&message_id))
1938                .filter(base_keys::device_id.eq(device_id))
1939                .first(&mut conn)
1940                .optional()
1941                .map_err(|e| StoreError::Database(e.to_string()))?;
1942            Ok(stored_key.as_ref() == Some(&current_base_key))
1943        })
1944        .await
1945        .map_err(|e| StoreError::Database(e.to_string()))?
1946    }
1947
1948    async fn delete_base_key(&self, address: &str, message_id: &str) -> Result<()> {
1949        let pool = self.pool.clone();
1950        let device_id = self.device_id;
1951        let address = address.to_string();
1952        let message_id = message_id.to_string();
1953        tokio::task::spawn_blocking(move || -> Result<()> {
1954            let mut conn = pool
1955                .get()
1956                .map_err(|e| StoreError::Connection(e.to_string()))?;
1957            diesel::delete(
1958                base_keys::table
1959                    .filter(base_keys::address.eq(&address))
1960                    .filter(base_keys::message_id.eq(&message_id))
1961                    .filter(base_keys::device_id.eq(device_id)),
1962            )
1963            .execute(&mut conn)
1964            .map_err(|e| StoreError::Database(e.to_string()))?;
1965            Ok(())
1966        })
1967        .await
1968        .map_err(|e| StoreError::Database(e.to_string()))??;
1969        Ok(())
1970    }
1971
1972    async fn update_device_list(&self, record: DeviceListRecord) -> Result<()> {
1973        let pool = self.pool.clone();
1974        let device_id = self.device_id;
1975        let devices_json = serde_json::to_string(&record.devices)
1976            .map_err(|e| StoreError::Serialization(e.to_string()))?;
1977        let now = std::time::SystemTime::now()
1978            .duration_since(std::time::UNIX_EPOCH)
1979            .unwrap_or_default()
1980            .as_secs() as i32;
1981        tokio::task::spawn_blocking(move || -> Result<()> {
1982            let mut conn = pool
1983                .get()
1984                .map_err(|e| StoreError::Connection(e.to_string()))?;
1985            diesel::insert_into(device_registry::table)
1986                .values((
1987                    device_registry::user_id.eq(&record.user),
1988                    device_registry::devices_json.eq(&devices_json),
1989                    device_registry::timestamp.eq(record.timestamp as i32),
1990                    device_registry::phash.eq(&record.phash),
1991                    device_registry::device_id.eq(device_id),
1992                    device_registry::updated_at.eq(now),
1993                ))
1994                .on_conflict((device_registry::user_id, device_registry::device_id))
1995                .do_update()
1996                .set((
1997                    device_registry::devices_json.eq(&devices_json),
1998                    device_registry::timestamp.eq(record.timestamp as i32),
1999                    device_registry::phash.eq(&record.phash),
2000                    device_registry::updated_at.eq(now),
2001                ))
2002                .execute(&mut conn)
2003                .map_err(|e| StoreError::Database(e.to_string()))?;
2004            Ok(())
2005        })
2006        .await
2007        .map_err(|e| StoreError::Database(e.to_string()))??;
2008        Ok(())
2009    }
2010
2011    async fn get_devices(&self, user: &str) -> Result<Option<DeviceListRecord>> {
2012        let pool = self.pool.clone();
2013        let device_id = self.device_id;
2014        let user = user.to_string();
2015        tokio::task::spawn_blocking(move || -> Result<Option<DeviceListRecord>> {
2016            let mut conn = pool
2017                .get()
2018                .map_err(|e| StoreError::Connection(e.to_string()))?;
2019            let row: Option<(String, String, i32, Option<String>)> = device_registry::table
2020                .select((
2021                    device_registry::user_id,
2022                    device_registry::devices_json,
2023                    device_registry::timestamp,
2024                    device_registry::phash,
2025                ))
2026                .filter(device_registry::user_id.eq(&user))
2027                .filter(device_registry::device_id.eq(device_id))
2028                .first(&mut conn)
2029                .optional()
2030                .map_err(|e| StoreError::Database(e.to_string()))?;
2031            match row {
2032                Some((user, devices_json, timestamp, phash)) => {
2033                    let devices: Vec<DeviceInfo> = serde_json::from_str(&devices_json)
2034                        .map_err(|e| StoreError::Serialization(e.to_string()))?;
2035                    Ok(Some(DeviceListRecord {
2036                        user,
2037                        devices,
2038                        timestamp: timestamp as i64,
2039                        phash,
2040                    }))
2041                }
2042                None => Ok(None),
2043            }
2044        })
2045        .await
2046        .map_err(|e| StoreError::Database(e.to_string()))?
2047    }
2048
2049    async fn mark_forget_sender_key(&self, group_jid: &str, participant: &str) -> Result<()> {
2050        let pool = self.pool.clone();
2051        let device_id = self.device_id;
2052        let group_jid = group_jid.to_string();
2053        let participant = participant.to_string();
2054        let now = std::time::SystemTime::now()
2055            .duration_since(std::time::UNIX_EPOCH)
2056            .unwrap_or_default()
2057            .as_secs() as i32;
2058        tokio::task::spawn_blocking(move || -> Result<()> {
2059            let mut conn = pool
2060                .get()
2061                .map_err(|e| StoreError::Connection(e.to_string()))?;
2062            diesel::insert_into(sender_key_status::table)
2063                .values((
2064                    sender_key_status::group_jid.eq(&group_jid),
2065                    sender_key_status::participant.eq(&participant),
2066                    sender_key_status::device_id.eq(device_id),
2067                    sender_key_status::marked_at.eq(now),
2068                ))
2069                .on_conflict((
2070                    sender_key_status::group_jid,
2071                    sender_key_status::participant,
2072                    sender_key_status::device_id,
2073                ))
2074                .do_update()
2075                .set(sender_key_status::marked_at.eq(now))
2076                .execute(&mut conn)
2077                .map_err(|e| StoreError::Database(e.to_string()))?;
2078            Ok(())
2079        })
2080        .await
2081        .map_err(|e| StoreError::Database(e.to_string()))??;
2082        Ok(())
2083    }
2084
2085    async fn consume_forget_marks(&self, group_jid: &str) -> Result<Vec<String>> {
2086        let pool = self.pool.clone();
2087        let device_id = self.device_id;
2088        let group_jid = group_jid.to_string();
2089        tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
2090            let mut conn = pool
2091                .get()
2092                .map_err(|e| StoreError::Connection(e.to_string()))?;
2093            let participants: Vec<String> = sender_key_status::table
2094                .select(sender_key_status::participant)
2095                .filter(sender_key_status::group_jid.eq(&group_jid))
2096                .filter(sender_key_status::device_id.eq(device_id))
2097                .load(&mut conn)
2098                .map_err(|e| StoreError::Database(e.to_string()))?;
2099            diesel::delete(
2100                sender_key_status::table
2101                    .filter(sender_key_status::group_jid.eq(&group_jid))
2102                    .filter(sender_key_status::device_id.eq(device_id)),
2103            )
2104            .execute(&mut conn)
2105            .map_err(|e| StoreError::Database(e.to_string()))?;
2106            Ok(participants)
2107        })
2108        .await
2109        .map_err(|e| StoreError::Database(e.to_string()))?
2110    }
2111
2112    async fn get_tc_token(&self, jid: &str) -> Result<Option<TcTokenEntry>> {
2113        let pool = self.pool.clone();
2114        let device_id = self.device_id;
2115        let jid = jid.to_string();
2116        tokio::task::spawn_blocking(move || -> Result<Option<TcTokenEntry>> {
2117            let mut conn = pool
2118                .get()
2119                .map_err(|e| StoreError::Connection(e.to_string()))?;
2120            let row: Option<(Vec<u8>, i64, Option<i64>)> = tc_tokens::table
2121                .select((
2122                    tc_tokens::token,
2123                    tc_tokens::token_timestamp,
2124                    tc_tokens::sender_timestamp,
2125                ))
2126                .filter(tc_tokens::jid.eq(&jid))
2127                .filter(tc_tokens::device_id.eq(device_id))
2128                .first(&mut conn)
2129                .optional()
2130                .map_err(|e| StoreError::Database(e.to_string()))?;
2131            Ok(
2132                row.map(|(token, token_timestamp, sender_timestamp)| TcTokenEntry {
2133                    token,
2134                    token_timestamp,
2135                    sender_timestamp,
2136                }),
2137            )
2138        })
2139        .await
2140        .map_err(|e| StoreError::Database(e.to_string()))?
2141    }
2142
2143    async fn put_tc_token(&self, jid: &str, entry: &TcTokenEntry) -> Result<()> {
2144        let pool = self.pool.clone();
2145        let device_id = self.device_id;
2146        let jid = jid.to_string();
2147        let entry = entry.clone();
2148        let now = std::time::SystemTime::now()
2149            .duration_since(std::time::UNIX_EPOCH)
2150            .unwrap_or_default()
2151            .as_secs() as i64;
2152        tokio::task::spawn_blocking(move || -> Result<()> {
2153            let mut conn = pool
2154                .get()
2155                .map_err(|e| StoreError::Connection(e.to_string()))?;
2156            diesel::insert_into(tc_tokens::table)
2157                .values((
2158                    tc_tokens::jid.eq(&jid),
2159                    tc_tokens::token.eq(&entry.token),
2160                    tc_tokens::token_timestamp.eq(entry.token_timestamp),
2161                    tc_tokens::sender_timestamp.eq(entry.sender_timestamp),
2162                    tc_tokens::device_id.eq(device_id),
2163                    tc_tokens::updated_at.eq(now),
2164                ))
2165                .on_conflict((tc_tokens::jid, tc_tokens::device_id))
2166                .do_update()
2167                .set((
2168                    tc_tokens::token.eq(&entry.token),
2169                    tc_tokens::token_timestamp.eq(entry.token_timestamp),
2170                    tc_tokens::sender_timestamp.eq(entry.sender_timestamp),
2171                    tc_tokens::updated_at.eq(now),
2172                ))
2173                .execute(&mut conn)
2174                .map_err(|e| StoreError::Database(e.to_string()))?;
2175            Ok(())
2176        })
2177        .await
2178        .map_err(|e| StoreError::Database(e.to_string()))??;
2179        Ok(())
2180    }
2181
2182    async fn delete_tc_token(&self, jid: &str) -> Result<()> {
2183        let pool = self.pool.clone();
2184        let device_id = self.device_id;
2185        let jid = jid.to_string();
2186        tokio::task::spawn_blocking(move || -> Result<()> {
2187            let mut conn = pool
2188                .get()
2189                .map_err(|e| StoreError::Connection(e.to_string()))?;
2190            diesel::delete(
2191                tc_tokens::table
2192                    .filter(tc_tokens::jid.eq(&jid))
2193                    .filter(tc_tokens::device_id.eq(device_id)),
2194            )
2195            .execute(&mut conn)
2196            .map_err(|e| StoreError::Database(e.to_string()))?;
2197            Ok(())
2198        })
2199        .await
2200        .map_err(|e| StoreError::Database(e.to_string()))??;
2201        Ok(())
2202    }
2203
2204    async fn get_all_tc_token_jids(&self) -> Result<Vec<String>> {
2205        let pool = self.pool.clone();
2206        let device_id = self.device_id;
2207        tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
2208            let mut conn = pool
2209                .get()
2210                .map_err(|e| StoreError::Connection(e.to_string()))?;
2211            let jids: Vec<String> = tc_tokens::table
2212                .select(tc_tokens::jid)
2213                .filter(tc_tokens::device_id.eq(device_id))
2214                .load(&mut conn)
2215                .map_err(|e| StoreError::Database(e.to_string()))?;
2216            Ok(jids)
2217        })
2218        .await
2219        .map_err(|e| StoreError::Database(e.to_string()))?
2220    }
2221
2222    async fn delete_expired_tc_tokens(&self, cutoff_timestamp: i64) -> Result<u32> {
2223        let pool = self.pool.clone();
2224        let device_id = self.device_id;
2225        tokio::task::spawn_blocking(move || -> Result<u32> {
2226            let mut conn = pool
2227                .get()
2228                .map_err(|e| StoreError::Connection(e.to_string()))?;
2229            let deleted = diesel::delete(
2230                tc_tokens::table
2231                    .filter(tc_tokens::token_timestamp.lt(cutoff_timestamp))
2232                    .filter(tc_tokens::device_id.eq(device_id)),
2233            )
2234            .execute(&mut conn)
2235            .map_err(|e| StoreError::Database(e.to_string()))?;
2236            Ok(deleted as u32)
2237        })
2238        .await
2239        .map_err(|e| StoreError::Database(e.to_string()))?
2240    }
2241
2242    async fn store_sent_message(
2243        &self,
2244        chat_jid: &str,
2245        message_id: &str,
2246        payload: &[u8],
2247    ) -> Result<()> {
2248        let chat_jid = chat_jid.to_string();
2249        let message_id = message_id.to_string();
2250        // Arc avoids cloning the full payload bytes on each retry iteration
2251        let payload: Arc<Vec<u8>> = Arc::new(payload.to_vec());
2252        let device_id = self.device_id;
2253        self.with_retry("store_sent_message", || {
2254            let chat_jid = chat_jid.clone();
2255            let message_id = message_id.clone();
2256            let payload = Arc::clone(&payload);
2257            Box::new(move |conn: &mut SqliteConnection| {
2258                diesel::replace_into(sent_messages::table)
2259                    .values((
2260                        sent_messages::chat_jid.eq(&chat_jid),
2261                        sent_messages::message_id.eq(&message_id),
2262                        sent_messages::payload.eq(payload.as_slice()),
2263                        sent_messages::device_id.eq(device_id),
2264                    ))
2265                    .execute(conn)?;
2266                Ok(())
2267            })
2268        })
2269        .await
2270    }
2271
2272    async fn take_sent_message(&self, chat_jid: &str, message_id: &str) -> Result<Option<Vec<u8>>> {
2273        let chat_jid = chat_jid.to_string();
2274        let message_id = message_id.to_string();
2275        let device_id = self.device_id;
2276        // Atomic SELECT+DELETE with retry for SQLITE_BUSY resilience.
2277        self.with_retry("take_sent_message", || {
2278            let chat_jid = chat_jid.clone();
2279            let message_id = message_id.clone();
2280            Box::new(move |conn: &mut SqliteConnection| {
2281                conn.immediate_transaction(|conn| {
2282                    let row: Option<Vec<u8>> = sent_messages::table
2283                        .select(sent_messages::payload)
2284                        .filter(sent_messages::chat_jid.eq(&chat_jid))
2285                        .filter(sent_messages::message_id.eq(&message_id))
2286                        .filter(sent_messages::device_id.eq(device_id))
2287                        .first(conn)
2288                        .optional()?;
2289                    if row.is_some() {
2290                        diesel::delete(
2291                            sent_messages::table
2292                                .filter(sent_messages::chat_jid.eq(&chat_jid))
2293                                .filter(sent_messages::message_id.eq(&message_id))
2294                                .filter(sent_messages::device_id.eq(device_id)),
2295                        )
2296                        .execute(conn)?;
2297                    }
2298                    Ok(row)
2299                })
2300            })
2301        })
2302        .await
2303    }
2304
2305    async fn delete_expired_sent_messages(&self, cutoff_timestamp: i64) -> Result<u32> {
2306        let pool = self.pool.clone();
2307        let device_id = self.device_id;
2308        tokio::task::spawn_blocking(move || -> Result<u32> {
2309            let mut conn = pool
2310                .get()
2311                .map_err(|e| StoreError::Connection(e.to_string()))?;
2312            let deleted = diesel::delete(
2313                sent_messages::table
2314                    .filter(sent_messages::created_at.lt(cutoff_timestamp))
2315                    .filter(sent_messages::device_id.eq(device_id)),
2316            )
2317            .execute(&mut conn)
2318            .map_err(|e| StoreError::Database(e.to_string()))?;
2319            Ok(deleted as u32)
2320        })
2321        .await
2322        .map_err(|e| StoreError::Database(e.to_string()))?
2323    }
2324}
2325
2326#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
2327#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
2328impl DeviceStore for SqliteStore {
2329    async fn save(&self, device: &CoreDevice) -> Result<()> {
2330        SqliteStore::save_device_data_for_device(self, self.device_id, device).await
2331    }
2332
2333    async fn load(&self) -> Result<Option<CoreDevice>> {
2334        SqliteStore::load_device_data_for_device(self, self.device_id).await
2335    }
2336
2337    async fn exists(&self) -> Result<bool> {
2338        SqliteStore::device_exists(self, self.device_id).await
2339    }
2340
2341    async fn create(&self) -> Result<i32> {
2342        SqliteStore::create_new_device(self).await
2343    }
2344
2345    async fn snapshot_db(&self, name: &str, extra_content: Option<&[u8]>) -> Result<()> {
2346        fn sanitize_snapshot_name(name: &str) -> Result<String> {
2347            const MAX_LENGTH: usize = 100;
2348
2349            let sanitized: String = name
2350                .chars()
2351                .map(|c| {
2352                    if c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.' {
2353                        c
2354                    } else {
2355                        '_'
2356                    }
2357                })
2358                .collect();
2359
2360            let sanitized = sanitized
2361                .split('.')
2362                .filter(|part| !part.is_empty() && *part != "..")
2363                .collect::<Vec<_>>()
2364                .join(".");
2365
2366            let sanitized = sanitized.trim_matches(['/', '\\', '.']);
2367
2368            if sanitized.is_empty() {
2369                return Err(StoreError::Database(
2370                    "Snapshot name cannot be empty after sanitization".to_string(),
2371                ));
2372            }
2373
2374            if sanitized.len() > MAX_LENGTH {
2375                return Err(StoreError::Database(format!(
2376                    "Snapshot name exceeds maximum length of {} characters",
2377                    MAX_LENGTH
2378                )));
2379            }
2380
2381            Ok(sanitized.to_string())
2382        }
2383
2384        let sanitized_name = sanitize_snapshot_name(name)?;
2385
2386        let pool = self.pool.clone();
2387        let db_path = self.database_path.clone();
2388        let extra_data = extra_content.map(|b| b.to_vec());
2389
2390        tokio::task::spawn_blocking(move || -> Result<()> {
2391            let mut conn = pool
2392                .get()
2393                .map_err(|e| StoreError::Connection(e.to_string()))?;
2394
2395            let timestamp = std::time::SystemTime::now()
2396                .duration_since(std::time::UNIX_EPOCH)
2397                .unwrap_or_default()
2398                .as_secs();
2399
2400            // Construct target path: db_path.snapshot-TIMESTAMP-SANITIZED_NAME
2401            let target_path = format!("{}.snapshot-{}-{}", db_path, timestamp, sanitized_name);
2402
2403            // Use VACUUM INTO to create a consistent backup
2404            // Note: We escape single quotes in the path just in case
2405            let query = format!("VACUUM INTO '{}'", target_path.replace("'", "''"));
2406
2407            diesel::sql_query(query)
2408                .execute(&mut conn)
2409                .map_err(|e| StoreError::Database(e.to_string()))?;
2410
2411            // Save extra content if provided
2412            if let Some(data) = extra_data {
2413                let extra_path = format!("{}.json", target_path);
2414                std::fs::write(&extra_path, data).map_err(|e| {
2415                    StoreError::Database(format!("Failed to write snapshot extra content: {}", e))
2416                })?;
2417            }
2418
2419            Ok(())
2420        })
2421        .await
2422        .map_err(|e| StoreError::Database(e.to_string()))??;
2423
2424        Ok(())
2425    }
2426}
2427
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a store on a shared-cache in-memory database whose name is
    /// unique per call (process id plus a process-wide counter).
    async fn create_test_store() -> SqliteStore {
        use std::sync::atomic::{AtomicU64, Ordering};
        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
        let pid = std::process::id();
        let seq = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        let uri = format!("file:memdb_test_{}_{}?mode=memory&cache=shared", pid, seq);
        SqliteStore::new(&uri)
            .await
            .expect("Failed to create test store")
    }

    /// A plain filesystem path is returned unchanged.
    #[test]
    fn test_parse_database_path_regular_path() {
        let parsed = parse_database_path("/var/lib/whatsapp/database.db").unwrap();
        assert_eq!(parsed, "/var/lib/whatsapp/database.db");
    }

    /// The `sqlite://` scheme prefix is stripped.
    #[test]
    fn test_parse_database_path_with_sqlite_prefix() {
        let parsed = parse_database_path("sqlite:///var/lib/whatsapp/database.db").unwrap();
        assert_eq!(parsed, "/var/lib/whatsapp/database.db");
    }

    /// Query parameters are dropped from the path.
    #[test]
    fn test_parse_database_path_with_query_params() {
        let parsed = parse_database_path("file:database.db?mode=memory&cache=shared").unwrap();
        assert_eq!(parsed, "file:database.db");
    }

    /// A `#fragment` suffix is dropped from the path.
    #[test]
    fn test_parse_database_path_with_fragment() {
        let parsed = parse_database_path("file:database.db#fragment").unwrap();
        assert_eq!(parsed, "file:database.db");
    }

    /// Scheme prefix, query and fragment are all removed together.
    #[test]
    fn test_parse_database_path_with_both_query_and_fragment() {
        let parsed = parse_database_path("sqlite:///var/lib/database.db?mode=ro#backup").unwrap();
        assert_eq!(parsed, "/var/lib/database.db");
    }

    /// The bare `:memory:` path is rejected with a "not supported" error.
    #[test]
    fn test_parse_database_path_in_memory_rejected() {
        let outcome = parse_database_path(":memory:");
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not supported"));
    }

    /// `:memory:` followed by query parameters is rejected as well.
    #[test]
    fn test_parse_database_path_in_memory_with_query_rejected() {
        let outcome = parse_database_path(":memory:?cache=shared");
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not supported"));
    }

    /// A stored device list can be read back with all of its fields.
    #[tokio::test]
    async fn test_device_registry_save_and_get() {
        let store = create_test_store().await;

        let first = DeviceInfo {
            device_id: 0,
            key_index: None,
        };
        let second = DeviceInfo {
            device_id: 1,
            key_index: Some(42),
        };
        store
            .update_device_list(DeviceListRecord {
                user: "1234567890".to_string(),
                devices: vec![first, second],
                timestamp: 1234567890,
                phash: Some("2:abcdef".to_string()),
            })
            .await
            .expect("save failed");

        let fetched = store
            .get_devices("1234567890")
            .await
            .expect("get failed")
            .expect("record should exist");

        assert_eq!(fetched.user, "1234567890");
        assert_eq!(fetched.devices.len(), 2);
        assert_eq!(fetched.devices[0].device_id, 0);
        assert_eq!(fetched.devices[1].device_id, 1);
        assert_eq!(fetched.devices[1].key_index, Some(42));
        assert_eq!(fetched.phash.as_deref(), Some("2:abcdef"));
    }

    /// Updating the list of an already known user replaces the earlier record.
    #[tokio::test]
    async fn test_device_registry_update_existing() {
        let store = create_test_store().await;

        let original = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![DeviceInfo {
                device_id: 0,
                key_index: None,
            }],
            timestamp: 1000,
            phash: Some("2:old".to_string()),
        };
        store
            .update_device_list(original)
            .await
            .expect("save1 failed");

        let replacement = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![
                DeviceInfo {
                    device_id: 0,
                    key_index: None,
                },
                DeviceInfo {
                    device_id: 2,
                    key_index: None,
                },
            ],
            timestamp: 2000,
            phash: Some("2:new".to_string()),
        };
        store
            .update_device_list(replacement)
            .await
            .expect("save2 failed");

        let fetched = store
            .get_devices("1234567890")
            .await
            .expect("get failed")
            .expect("record should exist");

        assert_eq!(fetched.devices.len(), 2);
        assert_eq!(fetched.phash.as_deref(), Some("2:new"));
    }

    /// Looking up an unknown user yields `None` rather than an error.
    #[tokio::test]
    async fn test_device_registry_get_nonexistent() {
        let store = create_test_store().await;
        let missing = store.get_devices("nonexistent").await.expect("get failed");
        assert!(missing.is_none());
    }

    /// A forget mark is returned by the first consume and gone on the second.
    #[tokio::test]
    async fn test_sender_key_status_mark_and_consume() {
        let store = create_test_store().await;

        let group = "group123@g.us";
        let participant = "user1@s.whatsapp.net";

        store
            .mark_forget_sender_key(group, participant)
            .await
            .expect("mark failed");

        let first_pass = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(first_pass.len(), 1);
        assert!(first_pass.contains(&participant.to_string()));

        let second_pass = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(second_pass.is_empty());
    }

    /// Several marks in one group are all returned by a single consume.
    #[tokio::test]
    async fn test_sender_key_status_consume_multiple() {
        let store = create_test_store().await;

        let group = "group123@g.us";
        let participants = ["user1@s.whatsapp.net", "user2@s.whatsapp.net"];

        for participant in participants {
            store
                .mark_forget_sender_key(group, participant)
                .await
                .expect("mark failed");
        }

        let first_pass = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(first_pass.len(), 2);
        for participant in participants {
            assert!(first_pass.contains(&participant.to_string()));
        }

        let second_pass = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(second_pass.is_empty());
    }

    /// A stored token entry is read back with identical fields.
    #[tokio::test]
    async fn test_tc_token_put_and_get() {
        let store = create_test_store().await;

        let stored = TcTokenEntry {
            token: vec![1, 2, 3, 4, 5],
            token_timestamp: 1707000000,
            sender_timestamp: Some(1707000100),
        };

        store
            .put_tc_token("user@lid", &stored)
            .await
            .expect("put failed");

        let fetched = store
            .get_tc_token("user@lid")
            .await
            .expect("get failed")
            .expect("should exist");

        assert_eq!(fetched.token, vec![1, 2, 3, 4, 5]);
        assert_eq!(fetched.token_timestamp, 1707000000);
        assert_eq!(fetched.sender_timestamp, Some(1707000100));
    }

    /// Putting a second entry for the same JID overwrites the first one.
    #[tokio::test]
    async fn test_tc_token_upsert() {
        let store = create_test_store().await;

        let initial = TcTokenEntry {
            token: vec![1, 2, 3],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        store.put_tc_token("user@lid", &initial).await.unwrap();

        let updated = TcTokenEntry {
            token: vec![4, 5, 6],
            token_timestamp: 2000,
            sender_timestamp: Some(1500),
        };
        store.put_tc_token("user@lid", &updated).await.unwrap();

        let fetched = store.get_tc_token("user@lid").await.unwrap().unwrap();
        assert_eq!(fetched.token, vec![4, 5, 6]);
        assert_eq!(fetched.token_timestamp, 2000);
        assert_eq!(fetched.sender_timestamp, Some(1500));
    }

    /// A deleted token can no longer be fetched.
    #[tokio::test]
    async fn test_tc_token_delete() {
        let store = create_test_store().await;

        let stored = TcTokenEntry {
            token: vec![1, 2, 3],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        store.put_tc_token("user@lid", &stored).await.unwrap();
        store.delete_tc_token("user@lid").await.unwrap();

        let after_delete = store.get_tc_token("user@lid").await.unwrap();
        assert!(after_delete.is_none());
    }

    /// Every JID that has a stored token is listed.
    #[tokio::test]
    async fn test_tc_token_get_all_jids() {
        let store = create_test_store().await;

        let shared_entry = TcTokenEntry {
            token: vec![1],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        let expected = ["user1@lid", "user2@lid", "user3@lid"];
        for jid in expected {
            store.put_tc_token(jid, &shared_entry).await.unwrap();
        }

        // The listing is sorted before comparing so its order does not matter.
        let mut listed = store.get_all_tc_token_jids().await.unwrap();
        listed.sort();
        assert_eq!(listed, expected);
    }

    /// Only tokens older than the cutoff are removed by the expiry sweep.
    #[tokio::test]
    async fn test_tc_token_delete_expired() {
        let store = create_test_store().await;

        let stale = TcTokenEntry {
            token: vec![1],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        let fresh = TcTokenEntry {
            token: vec![2],
            token_timestamp: 5000,
            sender_timestamp: None,
        };
        store.put_tc_token("old@lid", &stale).await.unwrap();
        store.put_tc_token("recent@lid", &fresh).await.unwrap();

        let removed = store.delete_expired_tc_tokens(3000).await.unwrap();
        assert_eq!(removed, 1);

        assert!(store.get_tc_token("old@lid").await.unwrap().is_none());
        assert!(store.get_tc_token("recent@lid").await.unwrap().is_some());
    }

    /// Fetching a token for an unknown JID yields `None`.
    #[tokio::test]
    async fn test_tc_token_get_nonexistent() {
        let store = create_test_store().await;
        let missing = store.get_tc_token("nonexistent@lid").await.unwrap();
        assert!(missing.is_none());
    }

    /// A mark placed in one group is not visible when consuming another group.
    #[tokio::test]
    async fn test_sender_key_status_different_groups() {
        let store = create_test_store().await;

        let marked_group = "group1@g.us";
        let other_group = "group2@g.us";
        let participant = "user@s.whatsapp.net";

        store
            .mark_forget_sender_key(marked_group, participant)
            .await
            .expect("mark failed");

        let from_marked = store.consume_forget_marks(marked_group).await.unwrap();
        assert_eq!(from_marked.len(), 1);

        let from_other = store.consume_forget_marks(other_group).await.unwrap();
        assert!(from_other.is_empty());
    }
}