1use crate::schema::*;
2use async_trait::async_trait;
3use diesel::prelude::*;
4use diesel::r2d2::{ConnectionManager, Pool};
5use diesel::result::{DatabaseErrorKind, Error as DieselError};
6use diesel::sql_query;
7use diesel::sqlite::SqliteConnection;
8use diesel::upsert::excluded;
9use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
10use log::warn;
11use std::sync::Arc;
12use wacore::appstate::hash::HashState;
13use wacore::appstate::processor::AppStateMutationMAC;
14use wacore::libsignal::protocol::{KeyPair, PrivateKey, PublicKey};
15use wacore::store::Device as CoreDevice;
16use wacore::store::error::{Result, StoreError};
17use wacore::store::traits::*;
18use wacore_binary::jid::Jid;
19
/// Error produced inside a blocking database task.
///
/// Keeps a raw Diesel error apart from an already-converted [`StoreError`] so
/// the retry loops can inspect the Diesel error (via
/// `is_retriable_sqlite_error`) before it is flattened into a string.
enum DieselOrStore {
    /// A Diesel error, preserved unconverted.
    Diesel(DieselError),
    /// A store-level error, e.g. failing to obtain a pooled connection.
    Store(StoreError),
}
27
28impl From<DieselOrStore> for StoreError {
29 fn from(e: DieselOrStore) -> Self {
30 match e {
31 DieselOrStore::Diesel(e) => StoreError::Database(e.to_string()),
32 DieselOrStore::Store(e) => e,
33 }
34 }
35}
36
37fn is_retriable_sqlite_error(error: &DieselError) -> bool {
43 match error {
44 DieselError::DatabaseError(DatabaseErrorKind::Unknown, info) => {
45 let msg = info.message();
46 msg.contains("locked") || msg.contains("busy")
47 }
48 _ => false,
49 }
50}
51
/// Migrations embedded at compile time from the `migrations` directory;
/// `SqliteStore::new` applies the pending ones.
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
53
/// r2d2 connection pool handing out Diesel SQLite connections.
type SqlitePool = Pool<ConnectionManager<SqliteConnection>>;
/// One row of the `device` table as loaded by `load_device_data_for_device`.
///
/// The slot names below follow the order in which that method destructures
/// the tuple.
type DeviceRow = (
    i32,             // device id
    String,          // lid (empty string when absent)
    String,          // pn (empty string when absent)
    i32,             // registration_id
    Vec<u8>,         // noise_key
    Vec<u8>,         // identity_key
    Vec<u8>,         // signed_pre_key
    i32,             // signed_pre_key_id
    Vec<u8>,         // signed_pre_key_signature
    Vec<u8>,         // adv_secret_key
    Option<Vec<u8>>, // account
    String,          // push_name
    i32,             // app_version_primary
    i32,             // app_version_secondary
    i64,             // app_version_tertiary
    i64,             // app_version_last_fetched_ms
    Option<Vec<u8>>, // edge_routing_info
    Option<String>,  // props_hash
    i32,             // next_pre_key_id
);
76
/// SQLite-backed store built on a Diesel/r2d2 connection pool.
///
/// Cloning is cheap in the sense that the pool and the semaphore are shared
/// handles; clones therefore operate on the same database.
#[derive(Clone)]
pub struct SqliteStore {
    /// Connection pool used by every database operation.
    pub(crate) pool: SqlitePool,
    /// Single-permit semaphore (see `new`) that `with_semaphore`, `with_retry`
    /// and the identity/session writers acquire before touching the database.
    pub(crate) db_semaphore: Arc<tokio::sync::Semaphore>,
    /// Path derived from the database URL by `parse_database_path`.
    pub(crate) database_path: String,
    /// Device id this store instance is bound to; `new` sets it to 1 and
    /// `new_for_device` overrides it.
    device_id: i32,
}
84
/// Zero-sized r2d2 connection customizer; its `on_acquire` implementation
/// issues the per-connection PRAGMA statements.
#[derive(Debug, Clone, Copy)]
struct ConnectionOptions;
87
88impl diesel::r2d2::CustomizeConnection<SqliteConnection, diesel::r2d2::Error>
89 for ConnectionOptions
90{
91 fn on_acquire(
92 &self,
93 conn: &mut SqliteConnection,
94 ) -> std::result::Result<(), diesel::r2d2::Error> {
95 diesel::sql_query("PRAGMA busy_timeout = 30000;")
96 .execute(conn)
97 .map_err(diesel::r2d2::Error::QueryError)?;
98 diesel::sql_query("PRAGMA synchronous = NORMAL;")
99 .execute(conn)
100 .map_err(diesel::r2d2::Error::QueryError)?;
101 diesel::sql_query("PRAGMA cache_size = 512;")
102 .execute(conn)
103 .map_err(diesel::r2d2::Error::QueryError)?;
104 diesel::sql_query("PRAGMA temp_store = memory;")
105 .execute(conn)
106 .map_err(diesel::r2d2::Error::QueryError)?;
107 diesel::sql_query("PRAGMA foreign_keys = ON;")
108 .execute(conn)
109 .map_err(diesel::r2d2::Error::QueryError)?;
110 Ok(())
111 }
112}
113
114fn parse_database_path(database_url: &str) -> Result<String> {
115 if database_url == ":memory:" {
117 return Err(StoreError::Database(
118 "Snapshot not supported for in-memory databases".to_string(),
119 ));
120 }
121
122 let path = database_url
124 .split(['?', '#'])
125 .next()
126 .unwrap_or(database_url);
127
128 let path = path.trim_start_matches("sqlite://");
130
131 if path == ":memory:" || path.starts_with(":memory:?") {
133 return Err(StoreError::Database(
134 "Snapshot not supported for in-memory databases".to_string(),
135 ));
136 }
137
138 Ok(path.to_string())
139}
140
141impl SqliteStore {
142 pub async fn new(database_url: &str) -> std::result::Result<Self, StoreError> {
143 let manager = ConnectionManager::<SqliteConnection>::new(database_url);
144
145 let pool_size = 2;
146
147 let pool = Pool::builder()
148 .max_size(pool_size)
149 .connection_customizer(Box::new(ConnectionOptions))
150 .build(manager)
151 .map_err(|e| StoreError::Connection(e.to_string()))?;
152
153 let pool_clone = pool.clone();
154 tokio::task::spawn_blocking(move || -> std::result::Result<(), StoreError> {
155 let mut conn = pool_clone
156 .get()
157 .map_err(|e| StoreError::Connection(e.to_string()))?;
158
159 diesel::sql_query("PRAGMA journal_mode = WAL;")
160 .execute(&mut conn)
161 .map_err(|e| StoreError::Database(e.to_string()))?;
162
163 conn.run_pending_migrations(MIGRATIONS)
164 .map_err(|e| StoreError::Migration(e.to_string()))?;
165
166 Ok(())
167 })
168 .await
169 .map_err(|e| StoreError::Database(e.to_string()))??;
170
171 let database_path = parse_database_path(database_url)?;
172
173 Ok(Self {
174 pool,
175 db_semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
176 database_path,
177 device_id: 1,
178 })
179 }
180
181 pub async fn new_for_device(
182 database_url: &str,
183 device_id: i32,
184 ) -> std::result::Result<Self, StoreError> {
185 let mut store = Self::new(database_url).await?;
186 store.device_id = device_id;
187 Ok(store)
188 }
189
    /// Returns the device id this store instance is bound to.
    pub fn device_id(&self) -> i32 {
        self.device_id
    }
193
194 async fn with_semaphore<F, T>(&self, f: F) -> Result<T>
195 where
196 F: FnOnce() -> Result<T> + Send + 'static,
197 T: Send + 'static,
198 {
199 let permit = self
200 .db_semaphore
201 .clone()
202 .acquire_owned()
203 .await
204 .map_err(|e| StoreError::Database(format!("Semaphore error: {}", e)))?;
205 let result = tokio::task::spawn_blocking(move || {
206 let res = f();
207 drop(permit);
208 res
209 })
210 .await
211 .map_err(|e| StoreError::Database(e.to_string()))??;
212 Ok(result)
213 }
214
215 async fn with_retry<F, T>(&self, op_name: &str, make_op: F) -> Result<T>
219 where
220 F: Fn() -> Box<
221 dyn FnOnce(&mut SqliteConnection) -> std::result::Result<T, DieselError> + Send,
222 >,
223 T: Send + 'static,
224 {
225 const MAX_RETRIES: u32 = 5;
226
227 for attempt in 0..=MAX_RETRIES {
228 let permit = self
229 .db_semaphore
230 .clone()
231 .acquire_owned()
232 .await
233 .map_err(|e| StoreError::Database(format!("Semaphore error: {}", e)))?;
234
235 let pool = self.pool.clone();
236 let op = make_op();
237
238 let result =
239 tokio::task::spawn_blocking(move || -> std::result::Result<T, DieselOrStore> {
240 let _permit = permit;
241 let mut conn = pool
242 .get()
243 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
244 op(&mut conn).map_err(DieselOrStore::Diesel)
245 })
246 .await;
247
248 match result {
249 Ok(Ok(val)) => return Ok(val),
250 Ok(Err(DieselOrStore::Diesel(ref e)))
251 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
252 {
253 let delay_ms = 10u64 * (1u64 << attempt.min(4));
254 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
255 }
256 Ok(Err(e)) => return Err(e.into()),
257 Err(e) => return Err(StoreError::Database(e.to_string())),
258 }
259 }
260
261 Err(StoreError::Database(format!(
262 "{} exhausted retries",
263 op_name
264 )))
265 }
266
267 fn serialize_keypair(&self, key_pair: &KeyPair) -> Result<Vec<u8>> {
268 let mut bytes = Vec::with_capacity(64);
269 bytes.extend_from_slice(key_pair.private_key.serialize());
270 bytes.extend_from_slice(key_pair.public_key.public_key_bytes());
271 Ok(bytes)
272 }
273
274 fn deserialize_keypair(&self, bytes: &[u8]) -> Result<KeyPair> {
275 if bytes.len() != 64 {
276 return Err(StoreError::Serialization(format!(
277 "Invalid KeyPair length: {}",
278 bytes.len()
279 )));
280 }
281
282 let private_key = PrivateKey::deserialize(&bytes[0..32])
283 .map_err(|e| StoreError::Serialization(e.to_string()))?;
284 let public_key = PublicKey::from_djb_public_key_bytes(&bytes[32..64])
285 .map_err(|e| StoreError::Serialization(e.to_string()))?;
286
287 Ok(KeyPair::new(public_key, private_key))
288 }
289
290 pub async fn save_device_data_for_device(
291 &self,
292 device_id: i32,
293 device_data: &CoreDevice,
294 ) -> Result<()> {
295 let pool = self.pool.clone();
296 let noise_key_data = self.serialize_keypair(&device_data.noise_key)?;
297 let identity_key_data = self.serialize_keypair(&device_data.identity_key)?;
298 let signed_pre_key_data = self.serialize_keypair(&device_data.signed_pre_key)?;
299 let account_data = device_data
300 .account
301 .as_ref()
302 .map(wacore::store::device::account_serde::to_bytes);
303 let registration_id = device_data.registration_id as i32;
304 let signed_pre_key_id = device_data.signed_pre_key_id as i32;
305 let signed_pre_key_signature: Vec<u8> = device_data.signed_pre_key_signature.to_vec();
306 let adv_secret_key: Vec<u8> = device_data.adv_secret_key.to_vec();
307 let push_name = device_data.push_name.clone();
308 let app_version_primary = device_data.app_version_primary as i32;
309 let app_version_secondary = device_data.app_version_secondary as i32;
310 let app_version_tertiary = device_data.app_version_tertiary as i64;
311 let app_version_last_fetched_ms = device_data.app_version_last_fetched_ms;
312 let edge_routing_info = device_data.edge_routing_info.clone();
313 let props_hash = device_data.props_hash.clone();
314 let next_pre_key_id = device_data.next_pre_key_id as i32;
315 let new_lid = device_data
316 .lid
317 .as_ref()
318 .map(|j| j.to_string())
319 .unwrap_or_default();
320 let new_pn = device_data
321 .pn
322 .as_ref()
323 .map(|j| j.to_string())
324 .unwrap_or_default();
325
326 tokio::task::spawn_blocking(move || -> Result<()> {
327 let mut conn = pool
328 .get()
329 .map_err(|e| StoreError::Connection(e.to_string()))?;
330
331 diesel::insert_into(device::table)
332 .values((
333 device::id.eq(device_id),
334 device::lid.eq(&new_lid),
335 device::pn.eq(&new_pn),
336 device::registration_id.eq(registration_id),
337 device::noise_key.eq(&noise_key_data),
338 device::identity_key.eq(&identity_key_data),
339 device::signed_pre_key.eq(&signed_pre_key_data),
340 device::signed_pre_key_id.eq(signed_pre_key_id),
341 device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
342 device::adv_secret_key.eq(&adv_secret_key[..]),
343 device::account.eq(account_data.clone()),
344 device::push_name.eq(&push_name),
345 device::app_version_primary.eq(app_version_primary),
346 device::app_version_secondary.eq(app_version_secondary),
347 device::app_version_tertiary.eq(app_version_tertiary),
348 device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
349 device::edge_routing_info.eq(edge_routing_info.clone()),
350 device::props_hash.eq(props_hash.clone()),
351 device::next_pre_key_id.eq(next_pre_key_id),
352 ))
353 .on_conflict(device::id)
354 .do_update()
355 .set((
356 device::lid.eq(&new_lid),
357 device::pn.eq(&new_pn),
358 device::registration_id.eq(registration_id),
359 device::noise_key.eq(&noise_key_data),
360 device::identity_key.eq(&identity_key_data),
361 device::signed_pre_key.eq(&signed_pre_key_data),
362 device::signed_pre_key_id.eq(signed_pre_key_id),
363 device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
364 device::adv_secret_key.eq(&adv_secret_key[..]),
365 device::account.eq(account_data.clone()),
366 device::push_name.eq(&push_name),
367 device::app_version_primary.eq(app_version_primary),
368 device::app_version_secondary.eq(app_version_secondary),
369 device::app_version_tertiary.eq(app_version_tertiary),
370 device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
371 device::edge_routing_info.eq(edge_routing_info),
372 device::props_hash.eq(props_hash),
373 device::next_pre_key_id.eq(next_pre_key_id),
374 ))
375 .execute(&mut conn)
376 .map_err(|e| StoreError::Database(e.to_string()))?;
377
378 Ok(())
379 })
380 .await
381 .map_err(|e| StoreError::Database(e.to_string()))??;
382
383 Ok(())
384 }
385
386 pub async fn create_new_device(&self) -> Result<i32> {
387 use crate::schema::device;
388
389 let pool = self.pool.clone();
390 tokio::task::spawn_blocking(move || -> Result<i32> {
391 let mut conn = pool
392 .get()
393 .map_err(|e| StoreError::Connection(e.to_string()))?;
394
395 let new_device = wacore::store::Device::new();
396
397 let noise_key_data = {
398 let mut bytes = Vec::with_capacity(64);
399 bytes.extend_from_slice(new_device.noise_key.private_key.serialize());
400 bytes.extend_from_slice(new_device.noise_key.public_key.public_key_bytes());
401 bytes
402 };
403 let identity_key_data = {
404 let mut bytes = Vec::with_capacity(64);
405 bytes.extend_from_slice(new_device.identity_key.private_key.serialize());
406 bytes.extend_from_slice(new_device.identity_key.public_key.public_key_bytes());
407 bytes
408 };
409 let signed_pre_key_data = {
410 let mut bytes = Vec::with_capacity(64);
411 bytes.extend_from_slice(new_device.signed_pre_key.private_key.serialize());
412 bytes.extend_from_slice(new_device.signed_pre_key.public_key.public_key_bytes());
413 bytes
414 };
415
416 diesel::insert_into(device::table)
417 .values((
418 device::lid.eq(""),
419 device::pn.eq(""),
420 device::registration_id.eq(new_device.registration_id as i32),
421 device::noise_key.eq(&noise_key_data),
422 device::identity_key.eq(&identity_key_data),
423 device::signed_pre_key.eq(&signed_pre_key_data),
424 device::signed_pre_key_id.eq(new_device.signed_pre_key_id as i32),
425 device::signed_pre_key_signature.eq(&new_device.signed_pre_key_signature[..]),
426 device::adv_secret_key.eq(&new_device.adv_secret_key[..]),
427 device::account.eq(None::<Vec<u8>>),
428 device::push_name.eq(&new_device.push_name),
429 device::app_version_primary.eq(new_device.app_version_primary as i32),
430 device::app_version_secondary.eq(new_device.app_version_secondary as i32),
431 device::app_version_tertiary.eq(new_device.app_version_tertiary as i64),
432 device::app_version_last_fetched_ms.eq(new_device.app_version_last_fetched_ms),
433 device::edge_routing_info.eq(None::<Vec<u8>>),
434 device::props_hash.eq(None::<String>),
435 device::next_pre_key_id.eq(new_device.next_pre_key_id as i32),
436 ))
437 .execute(&mut conn)
438 .map_err(|e| StoreError::Database(e.to_string()))?;
439
440 use diesel::sql_types::Integer;
441
442 #[derive(QueryableByName)]
443 struct LastInsertedId {
444 #[diesel(sql_type = Integer)]
445 last_insert_rowid: i32,
446 }
447
448 let device_id: i32 = sql_query("SELECT last_insert_rowid() as last_insert_rowid")
449 .get_result::<LastInsertedId>(&mut conn)
450 .map_err(|e| StoreError::Database(e.to_string()))?
451 .last_insert_rowid;
452
453 Ok(device_id)
454 })
455 .await
456 .map_err(|e| StoreError::Database(e.to_string()))?
457 }
458
459 pub async fn device_exists(&self, device_id: i32) -> Result<bool> {
460 use crate::schema::device;
461
462 let pool = self.pool.clone();
463 tokio::task::spawn_blocking(move || -> Result<bool> {
464 let mut conn = pool
465 .get()
466 .map_err(|e| StoreError::Connection(e.to_string()))?;
467
468 let count: i64 = device::table
469 .filter(device::id.eq(device_id))
470 .count()
471 .get_result(&mut conn)
472 .map_err(|e| StoreError::Database(e.to_string()))?;
473
474 Ok(count > 0)
475 })
476 .await
477 .map_err(|e| StoreError::Database(e.to_string()))?
478 }
479
480 pub async fn load_device_data_for_device(&self, device_id: i32) -> Result<Option<CoreDevice>> {
481 use crate::schema::device;
482
483 let pool = self.pool.clone();
484 let row = tokio::task::spawn_blocking(move || -> Result<Option<DeviceRow>> {
485 let mut conn = pool
486 .get()
487 .map_err(|e| StoreError::Connection(e.to_string()))?;
488 let result = device::table
489 .filter(device::id.eq(device_id))
490 .first::<DeviceRow>(&mut conn)
491 .optional()
492 .map_err(|e| StoreError::Database(e.to_string()))?;
493 Ok(result)
494 })
495 .await
496 .map_err(|e| StoreError::Database(e.to_string()))??;
497
498 if let Some((
499 _device_id,
500 lid_str,
501 pn_str,
502 registration_id,
503 noise_key_data,
504 identity_key_data,
505 signed_pre_key_data,
506 signed_pre_key_id,
507 signed_pre_key_signature_data,
508 adv_secret_key_data,
509 account_data,
510 push_name,
511 app_version_primary,
512 app_version_secondary,
513 app_version_tertiary,
514 app_version_last_fetched_ms,
515 edge_routing_info,
516 props_hash,
517 next_pre_key_id,
518 )) = row
519 {
520 let id = if !pn_str.is_empty() {
521 pn_str.parse().ok()
522 } else {
523 None
524 };
525 let lid = if !lid_str.is_empty() {
526 lid_str.parse().ok()
527 } else {
528 None
529 };
530
531 let noise_key = self.deserialize_keypair(&noise_key_data)?;
532 let identity_key = self.deserialize_keypair(&identity_key_data)?;
533 let signed_pre_key = self.deserialize_keypair(&signed_pre_key_data)?;
534
535 let signed_pre_key_signature: [u8; 64] =
536 signed_pre_key_signature_data.try_into().map_err(|_| {
537 StoreError::Serialization("Invalid signed_pre_key_signature length".to_string())
538 })?;
539
540 let adv_secret_key: [u8; 32] = adv_secret_key_data.try_into().map_err(|_| {
541 StoreError::Serialization("Invalid adv_secret_key length".to_string())
542 })?;
543
544 let account = account_data
545 .map(|data| {
546 wacore::store::device::account_serde::from_bytes(&data)
547 .map_err(|e| StoreError::Serialization(e.to_string()))
548 })
549 .transpose()?;
550
551 Ok(Some(CoreDevice {
552 pn: id,
553 lid,
554 registration_id: registration_id as u32,
555 noise_key,
556 identity_key,
557 signed_pre_key,
558 signed_pre_key_id: signed_pre_key_id as u32,
559 signed_pre_key_signature,
560 adv_secret_key,
561 account,
562 push_name,
563 app_version_primary: app_version_primary as u32,
564 app_version_secondary: app_version_secondary as u32,
565 app_version_tertiary: app_version_tertiary.try_into().unwrap_or(0u32),
566 app_version_last_fetched_ms,
567 device_props: {
568 use wacore::store::device::DEVICE_PROPS;
569 DEVICE_PROPS.clone()
570 },
571 edge_routing_info,
572 props_hash,
573 next_pre_key_id: next_pre_key_id as u32,
574 }))
575 } else {
576 Ok(None)
577 }
578 }
579
580 pub async fn put_identity_for_device(
581 &self,
582 address: &str,
583 key: [u8; 32],
584 device_id: i32,
585 ) -> Result<()> {
586 let pool = self.pool.clone();
587 let db_semaphore = self.db_semaphore.clone();
588 let address_owned = address.to_string();
589 let key_vec = key.to_vec();
590
591 const MAX_RETRIES: u32 = 5;
592
593 for attempt in 0..=MAX_RETRIES {
594 let permit =
595 db_semaphore.clone().acquire_owned().await.map_err(|e| {
596 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
597 })?;
598
599 let pool_clone = pool.clone();
600 let address_clone = address_owned.clone();
601 let key_clone = key_vec.clone();
602
603 let result =
604 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
605 let mut conn = pool_clone
606 .get()
607 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
608 diesel::insert_into(identities::table)
609 .values((
610 identities::address.eq(address_clone),
611 identities::key.eq(&key_clone[..]),
612 identities::device_id.eq(device_id),
613 ))
614 .on_conflict((identities::address, identities::device_id))
615 .do_update()
616 .set(identities::key.eq(&key_clone[..]))
617 .execute(&mut conn)
618 .map_err(DieselOrStore::Diesel)?;
619 Ok(())
620 })
621 .await;
622
623 drop(permit);
624
625 match result {
626 Ok(Ok(())) => return Ok(()),
627 Ok(Err(DieselOrStore::Diesel(ref e)))
628 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
629 {
630 let delay_ms = 10 * 2u64.pow(attempt);
631 warn!(
632 "Identity write failed (attempt {}/{}): {e}. Retrying in {delay_ms}ms...",
633 attempt + 1,
634 MAX_RETRIES + 1,
635 );
636 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
637 continue;
638 }
639 Ok(Err(e)) => return Err(e.into()),
640 Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
641 }
642 }
643
644 Err(StoreError::Database(format!(
645 "Identity write failed after {} attempts",
646 MAX_RETRIES + 1
647 )))
648 }
649
650 pub async fn delete_identity_for_device(&self, address: &str, device_id: i32) -> Result<()> {
651 let pool = self.pool.clone();
652 let address_owned = address.to_string();
653
654 tokio::task::spawn_blocking(move || -> Result<()> {
655 let mut conn = pool
656 .get()
657 .map_err(|e| StoreError::Connection(e.to_string()))?;
658 diesel::delete(
659 identities::table
660 .filter(identities::address.eq(address_owned))
661 .filter(identities::device_id.eq(device_id)),
662 )
663 .execute(&mut conn)
664 .map_err(|e| StoreError::Database(e.to_string()))?;
665 Ok(())
666 })
667 .await
668 .map_err(|e| StoreError::Database(e.to_string()))??;
669
670 Ok(())
671 }
672
673 pub async fn load_identity_for_device(
674 &self,
675 address: &str,
676 device_id: i32,
677 ) -> Result<Option<Vec<u8>>> {
678 let pool = self.pool.clone();
679 let address = address.to_string();
680 let result = self
681 .with_semaphore(move || -> Result<Option<Vec<u8>>> {
682 let mut conn = pool
683 .get()
684 .map_err(|e| StoreError::Connection(e.to_string()))?;
685 let res: Option<Vec<u8>> = identities::table
686 .select(identities::key)
687 .filter(identities::address.eq(address))
688 .filter(identities::device_id.eq(device_id))
689 .first(&mut conn)
690 .optional()
691 .map_err(|e| StoreError::Database(e.to_string()))?;
692 Ok(res)
693 })
694 .await?;
695
696 Ok(result)
697 }
698
699 pub async fn get_session_for_device(
700 &self,
701 address: &str,
702 device_id: i32,
703 ) -> Result<Option<Vec<u8>>> {
704 let pool = self.pool.clone();
705 let address_for_query = address.to_string();
706 let result = self
707 .with_semaphore(move || -> Result<Option<Vec<u8>>> {
708 let mut conn = pool
709 .get()
710 .map_err(|e| StoreError::Connection(e.to_string()))?;
711 let res: Option<Vec<u8>> = sessions::table
712 .select(sessions::record)
713 .filter(sessions::address.eq(address_for_query.clone()))
714 .filter(sessions::device_id.eq(device_id))
715 .first(&mut conn)
716 .optional()
717 .map_err(|e| StoreError::Database(e.to_string()))?;
718
719 Ok(res)
720 })
721 .await?;
722
723 Ok(result)
724 }
725
726 pub async fn put_session_for_device(
727 &self,
728 address: &str,
729 session: &[u8],
730 device_id: i32,
731 ) -> Result<()> {
732 let pool = self.pool.clone();
733 let db_semaphore = self.db_semaphore.clone();
734 let address_owned = address.to_string();
735 let session_vec = session.to_vec();
736
737 const MAX_RETRIES: u32 = 5;
738
739 for attempt in 0..=MAX_RETRIES {
740 let permit =
741 db_semaphore.clone().acquire_owned().await.map_err(|e| {
742 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
743 })?;
744
745 let pool_clone = pool.clone();
746 let address_clone = address_owned.clone();
747 let session_clone = session_vec.clone();
748
749 let result =
750 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
751 let mut conn = pool_clone
752 .get()
753 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
754 diesel::insert_into(sessions::table)
755 .values((
756 sessions::address.eq(address_clone),
757 sessions::record.eq(&session_clone),
758 sessions::device_id.eq(device_id),
759 ))
760 .on_conflict((sessions::address, sessions::device_id))
761 .do_update()
762 .set(sessions::record.eq(&session_clone))
763 .execute(&mut conn)
764 .map_err(DieselOrStore::Diesel)?;
765 Ok(())
766 })
767 .await;
768
769 drop(permit);
770
771 match result {
772 Ok(Ok(())) => return Ok(()),
773 Ok(Err(DieselOrStore::Diesel(ref e)))
774 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
775 {
776 let delay_ms = 10 * 2u64.pow(attempt);
777 warn!(
778 "Session write failed (attempt {}/{}): {e}. Retrying in {delay_ms}ms...",
779 attempt + 1,
780 MAX_RETRIES + 1,
781 );
782 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
783 continue;
784 }
785 Ok(Err(e)) => return Err(e.into()),
786 Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
787 }
788 }
789
790 Err(StoreError::Database(format!(
791 "Session write failed after {} attempts",
792 MAX_RETRIES + 1
793 )))
794 }
795
796 pub async fn delete_session_for_device(&self, address: &str, device_id: i32) -> Result<()> {
797 let pool = self.pool.clone();
798 let address_owned = address.to_string();
799
800 tokio::task::spawn_blocking(move || -> Result<()> {
801 let mut conn = pool
802 .get()
803 .map_err(|e| StoreError::Connection(e.to_string()))?;
804 diesel::delete(
805 sessions::table
806 .filter(sessions::address.eq(address_owned))
807 .filter(sessions::device_id.eq(device_id)),
808 )
809 .execute(&mut conn)
810 .map_err(|e| StoreError::Database(e.to_string()))?;
811 Ok(())
812 })
813 .await
814 .map_err(|e| StoreError::Database(e.to_string()))??;
815
816 Ok(())
817 }
818
819 pub async fn put_sender_key_for_device(
820 &self,
821 address: &str,
822 record: &[u8],
823 device_id: i32,
824 ) -> Result<()> {
825 let pool = self.pool.clone();
826 let address = address.to_string();
827 let record_vec = record.to_vec();
828 tokio::task::spawn_blocking(move || -> Result<()> {
829 let mut conn = pool
830 .get()
831 .map_err(|e| StoreError::Connection(e.to_string()))?;
832 diesel::insert_into(sender_keys::table)
833 .values((
834 sender_keys::address.eq(address),
835 sender_keys::record.eq(&record_vec),
836 sender_keys::device_id.eq(device_id),
837 ))
838 .on_conflict((sender_keys::address, sender_keys::device_id))
839 .do_update()
840 .set(sender_keys::record.eq(&record_vec))
841 .execute(&mut conn)
842 .map_err(|e| StoreError::Database(e.to_string()))?;
843 Ok(())
844 })
845 .await
846 .map_err(|e| StoreError::Database(e.to_string()))??;
847 Ok(())
848 }
849
850 pub async fn get_sender_key_for_device(
851 &self,
852 address: &str,
853 device_id: i32,
854 ) -> Result<Option<Vec<u8>>> {
855 let pool = self.pool.clone();
856 let address = address.to_string();
857 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
858 let mut conn = pool
859 .get()
860 .map_err(|e| StoreError::Connection(e.to_string()))?;
861 let res: Option<Vec<u8>> = sender_keys::table
862 .select(sender_keys::record)
863 .filter(sender_keys::address.eq(address))
864 .filter(sender_keys::device_id.eq(device_id))
865 .first(&mut conn)
866 .optional()
867 .map_err(|e| StoreError::Database(e.to_string()))?;
868 Ok(res)
869 })
870 .await
871 .map_err(|e| StoreError::Database(e.to_string()))?
872 }
873
874 pub async fn delete_sender_key_for_device(&self, address: &str, device_id: i32) -> Result<()> {
875 let pool = self.pool.clone();
876 let address = address.to_string();
877 tokio::task::spawn_blocking(move || -> Result<()> {
878 let mut conn = pool
879 .get()
880 .map_err(|e| StoreError::Connection(e.to_string()))?;
881 diesel::delete(
882 sender_keys::table
883 .filter(sender_keys::address.eq(address))
884 .filter(sender_keys::device_id.eq(device_id)),
885 )
886 .execute(&mut conn)
887 .map_err(|e| StoreError::Database(e.to_string()))?;
888 Ok(())
889 })
890 .await
891 .map_err(|e| StoreError::Database(e.to_string()))??;
892 Ok(())
893 }
894
895 pub async fn get_app_state_sync_key_for_device(
896 &self,
897 key_id: &[u8],
898 device_id: i32,
899 ) -> Result<Option<AppStateSyncKey>> {
900 let pool = self.pool.clone();
901 let key_id = key_id.to_vec();
902 let res: Option<Vec<u8>> =
903 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
904 let mut conn = pool
905 .get()
906 .map_err(|e| StoreError::Connection(e.to_string()))?;
907 let res: Option<Vec<u8>> = app_state_keys::table
908 .select(app_state_keys::key_data)
909 .filter(app_state_keys::key_id.eq(&key_id))
910 .filter(app_state_keys::device_id.eq(device_id))
911 .first(&mut conn)
912 .optional()
913 .map_err(|e| StoreError::Database(e.to_string()))?;
914 Ok(res)
915 })
916 .await
917 .map_err(|e| StoreError::Database(e.to_string()))??;
918
919 if let Some(data) = res {
920 let (key, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
921 .map_err(|e| StoreError::Serialization(e.to_string()))?;
922 Ok(Some(key))
923 } else {
924 Ok(None)
925 }
926 }
927
928 pub async fn set_app_state_sync_key_for_device(
929 &self,
930 key_id: &[u8],
931 key: AppStateSyncKey,
932 device_id: i32,
933 ) -> Result<()> {
934 let pool = self.pool.clone();
935 let key_id = key_id.to_vec();
936 let data = bincode::serde::encode_to_vec(&key, bincode::config::standard())
937 .map_err(|e| StoreError::Serialization(e.to_string()))?;
938 tokio::task::spawn_blocking(move || -> Result<()> {
939 let mut conn = pool
940 .get()
941 .map_err(|e| StoreError::Connection(e.to_string()))?;
942 diesel::insert_into(app_state_keys::table)
943 .values((
944 app_state_keys::key_id.eq(&key_id),
945 app_state_keys::key_data.eq(&data),
946 app_state_keys::device_id.eq(device_id),
947 ))
948 .on_conflict((app_state_keys::key_id, app_state_keys::device_id))
949 .do_update()
950 .set(app_state_keys::key_data.eq(&data))
951 .execute(&mut conn)
952 .map_err(|e| StoreError::Database(e.to_string()))?;
953 Ok(())
954 })
955 .await
956 .map_err(|e| StoreError::Database(e.to_string()))??;
957 Ok(())
958 }
959
960 pub async fn get_latest_app_state_sync_key_id_for_device(
961 &self,
962 device_id: i32,
963 ) -> Result<Option<Vec<u8>>> {
964 let pool = self.pool.clone();
965 let res: Option<Vec<u8>> =
966 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
967 let mut conn = pool
968 .get()
969 .map_err(|e| StoreError::Connection(e.to_string()))?;
970 let res: Option<Vec<u8>> = app_state_keys::table
971 .select(app_state_keys::key_id)
972 .filter(app_state_keys::device_id.eq(device_id))
973 .order(app_state_keys::key_id.desc())
974 .first(&mut conn)
975 .optional()
976 .map_err(|e| StoreError::Database(e.to_string()))?;
977 Ok(res)
978 })
979 .await
980 .map_err(|e| StoreError::Database(e.to_string()))??;
981 Ok(res)
982 }
983
984 pub async fn get_app_state_version_for_device(
985 &self,
986 name: &str,
987 device_id: i32,
988 ) -> Result<HashState> {
989 let pool = self.pool.clone();
990 let name = name.to_string();
991 let res: Option<Vec<u8>> =
992 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
993 let mut conn = pool
994 .get()
995 .map_err(|e| StoreError::Connection(e.to_string()))?;
996 let res: Option<Vec<u8>> = app_state_versions::table
997 .select(app_state_versions::state_data)
998 .filter(app_state_versions::name.eq(name))
999 .filter(app_state_versions::device_id.eq(device_id))
1000 .first(&mut conn)
1001 .optional()
1002 .map_err(|e| StoreError::Database(e.to_string()))?;
1003 Ok(res)
1004 })
1005 .await
1006 .map_err(|e| StoreError::Database(e.to_string()))??;
1007
1008 if let Some(data) = res {
1009 let (state, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
1010 .map_err(|e| StoreError::Serialization(e.to_string()))?;
1011 Ok(state)
1012 } else {
1013 Ok(HashState::default())
1014 }
1015 }
1016
1017 pub async fn set_app_state_version_for_device(
1018 &self,
1019 name: &str,
1020 state: HashState,
1021 device_id: i32,
1022 ) -> Result<()> {
1023 let name = name.to_string();
1024 let data = bincode::serde::encode_to_vec(&state, bincode::config::standard())
1025 .map_err(|e| StoreError::Serialization(e.to_string()))?;
1026 self.with_retry("set_app_state_version", || {
1027 let name = name.clone();
1028 let data = data.clone();
1029 Box::new(move |conn: &mut SqliteConnection| {
1030 diesel::insert_into(app_state_versions::table)
1031 .values((
1032 app_state_versions::name.eq(&name),
1033 app_state_versions::state_data.eq(&data),
1034 app_state_versions::device_id.eq(device_id),
1035 ))
1036 .on_conflict((app_state_versions::name, app_state_versions::device_id))
1037 .do_update()
1038 .set(app_state_versions::state_data.eq(&data))
1039 .execute(conn)?;
1040 Ok(())
1041 })
1042 })
1043 .await
1044 }
1045
1046 pub async fn put_app_state_mutation_macs_for_device(
1047 &self,
1048 name: &str,
1049 version: u64,
1050 mutations: &[AppStateMutationMAC],
1051 device_id: i32,
1052 ) -> Result<()> {
1053 if mutations.is_empty() {
1054 return Ok(());
1055 }
1056 let name = name.to_string();
1057 let mutations: Vec<AppStateMutationMAC> = mutations.to_vec();
1058 self.with_retry("put_app_state_mutation_macs", || {
1059 let name = name.clone();
1060 let mutations = mutations.clone();
1061 Box::new(move |conn: &mut SqliteConnection| {
1062 let records: Vec<_> = mutations
1063 .iter()
1064 .map(|m| {
1065 (
1066 app_state_mutation_macs::name.eq(&name),
1067 app_state_mutation_macs::version.eq(version as i64),
1068 app_state_mutation_macs::index_mac.eq(&m.index_mac),
1069 app_state_mutation_macs::value_mac.eq(&m.value_mac),
1070 app_state_mutation_macs::device_id.eq(device_id),
1071 )
1072 })
1073 .collect();
1074
1075 const CHUNK_SIZE: usize = 100;
1078
1079 for chunk in records.chunks(CHUNK_SIZE) {
1080 diesel::insert_into(app_state_mutation_macs::table)
1081 .values(chunk)
1082 .on_conflict((
1083 app_state_mutation_macs::name,
1084 app_state_mutation_macs::index_mac,
1085 app_state_mutation_macs::device_id,
1086 ))
1087 .do_update()
1088 .set((
1089 app_state_mutation_macs::version
1090 .eq(excluded(app_state_mutation_macs::version)),
1091 app_state_mutation_macs::value_mac
1092 .eq(excluded(app_state_mutation_macs::value_mac)),
1093 ))
1094 .execute(conn)?;
1095 }
1096 Ok(())
1097 })
1098 })
1099 .await
1100 }
1101
1102 pub async fn delete_app_state_mutation_macs_for_device(
1103 &self,
1104 name: &str,
1105 index_macs: &[Vec<u8>],
1106 device_id: i32,
1107 ) -> Result<()> {
1108 if index_macs.is_empty() {
1109 return Ok(());
1110 }
1111 let name = name.to_string();
1112 let index_macs: Vec<Vec<u8>> = index_macs.to_vec();
1113 self.with_retry("delete_app_state_mutation_macs", || {
1114 let name = name.clone();
1115 let index_macs = index_macs.clone();
1116 Box::new(move |conn: &mut SqliteConnection| {
1117 const CHUNK_SIZE: usize = 500;
1120
1121 for chunk in index_macs.chunks(CHUNK_SIZE) {
1122 diesel::delete(
1123 app_state_mutation_macs::table.filter(
1124 app_state_mutation_macs::name
1125 .eq(&name)
1126 .and(app_state_mutation_macs::index_mac.eq_any(chunk))
1127 .and(app_state_mutation_macs::device_id.eq(device_id)),
1128 ),
1129 )
1130 .execute(conn)?;
1131 }
1132 Ok(())
1133 })
1134 })
1135 .await
1136 }
1137
1138 pub async fn get_app_state_mutation_mac_for_device(
1139 &self,
1140 name: &str,
1141 index_mac: &[u8],
1142 device_id: i32,
1143 ) -> Result<Option<Vec<u8>>> {
1144 let pool = self.pool.clone();
1145 let name = name.to_string();
1146 let index_mac = index_mac.to_vec();
1147 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1148 let mut conn = pool
1149 .get()
1150 .map_err(|e| StoreError::Connection(e.to_string()))?;
1151 let res: Option<Vec<u8>> = app_state_mutation_macs::table
1152 .select(app_state_mutation_macs::value_mac)
1153 .filter(app_state_mutation_macs::name.eq(&name))
1154 .filter(app_state_mutation_macs::index_mac.eq(&index_mac))
1155 .filter(app_state_mutation_macs::device_id.eq(device_id))
1156 .first(&mut conn)
1157 .optional()
1158 .map_err(|e| StoreError::Database(e.to_string()))?;
1159 Ok(res)
1160 })
1161 .await
1162 .map_err(|e| StoreError::Database(e.to_string()))?
1163 }
1164}
1165
1166#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1167#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1168impl SignalStore for SqliteStore {
1169 async fn put_identity(&self, address: &str, key: [u8; 32]) -> Result<()> {
1170 self.put_identity_for_device(address, key, self.device_id)
1171 .await
1172 }
1173
1174 async fn load_identity(&self, address: &str) -> Result<Option<Vec<u8>>> {
1175 self.load_identity_for_device(address, self.device_id).await
1176 }
1177
1178 async fn delete_identity(&self, address: &str) -> Result<()> {
1179 self.delete_identity_for_device(address, self.device_id)
1180 .await
1181 }
1182
1183 async fn get_session(&self, address: &str) -> Result<Option<Vec<u8>>> {
1184 self.get_session_for_device(address, self.device_id).await
1185 }
1186
1187 async fn put_session(&self, address: &str, session: &[u8]) -> Result<()> {
1188 self.put_session_for_device(address, session, self.device_id)
1189 .await
1190 }
1191
1192 async fn delete_session(&self, address: &str) -> Result<()> {
1193 self.delete_session_for_device(address, self.device_id)
1194 .await
1195 }
1196
1197 async fn store_prekey(&self, id: u32, record: &[u8], uploaded: bool) -> Result<()> {
1198 let pool = self.pool.clone();
1199 let db_semaphore = self.db_semaphore.clone();
1200 let device_id = self.device_id;
1201 let record = record.to_vec();
1202
1203 const MAX_RETRIES: u32 = 5;
1204
1205 for attempt in 0..=MAX_RETRIES {
1206 let permit =
1207 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1208 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1209 })?;
1210
1211 let pool_clone = pool.clone();
1212 let record_clone = record.clone();
1213
1214 let result =
1215 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1216 let mut conn = pool_clone
1217 .get()
1218 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1219 diesel::insert_into(prekeys::table)
1220 .values((
1221 prekeys::id.eq(id as i32),
1222 prekeys::key.eq(&record_clone),
1223 prekeys::uploaded.eq(uploaded),
1224 prekeys::device_id.eq(device_id),
1225 ))
1226 .on_conflict((prekeys::id, prekeys::device_id))
1227 .do_update()
1228 .set((
1229 prekeys::key.eq(&record_clone),
1230 prekeys::uploaded.eq(uploaded),
1231 ))
1232 .execute(&mut conn)
1233 .map_err(DieselOrStore::Diesel)?;
1234 Ok(())
1235 })
1236 .await;
1237
1238 drop(permit);
1239
1240 match result {
1241 Ok(Ok(())) => return Ok(()),
1242 Ok(Err(DieselOrStore::Diesel(ref e)))
1243 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1244 {
1245 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1246 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1247 }
1248 Ok(Err(e)) => return Err(e.into()),
1249 Err(e) => return Err(StoreError::Database(e.to_string())),
1250 }
1251 }
1252
1253 Err(StoreError::Database(
1254 "store_prekey exhausted retries".to_string(),
1255 ))
1256 }
1257
1258 async fn store_prekeys_batch(&self, keys: &[(u32, Vec<u8>)], uploaded: bool) -> Result<()> {
1259 if keys.is_empty() {
1260 return Ok(());
1261 }
1262
1263 let pool = self.pool.clone();
1264 let db_semaphore = self.db_semaphore.clone();
1265 let device_id = self.device_id;
1266 let keys = keys.to_vec();
1267
1268 const MAX_RETRIES: u32 = 5;
1269
1270 for attempt in 0..=MAX_RETRIES {
1271 let permit =
1272 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1273 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1274 })?;
1275
1276 let pool_clone = pool.clone();
1277 let keys_clone = keys.clone();
1278
1279 let result =
1280 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1281 let mut conn = pool_clone
1282 .get()
1283 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1284
1285 conn.transaction(|conn| {
1286 for (id, record) in &keys_clone {
1287 diesel::insert_into(prekeys::table)
1288 .values((
1289 prekeys::id.eq(*id as i32),
1290 prekeys::key.eq(record),
1291 prekeys::uploaded.eq(uploaded),
1292 prekeys::device_id.eq(device_id),
1293 ))
1294 .on_conflict((prekeys::id, prekeys::device_id))
1295 .do_update()
1296 .set((prekeys::key.eq(record), prekeys::uploaded.eq(uploaded)))
1297 .execute(conn)?;
1298 }
1299 Ok::<(), diesel::result::Error>(())
1300 })
1301 .map_err(DieselOrStore::Diesel)
1302 })
1303 .await;
1304
1305 drop(permit);
1306
1307 match result {
1308 Ok(Ok(())) => return Ok(()),
1309 Ok(Err(DieselOrStore::Diesel(ref e)))
1310 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1311 {
1312 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1313 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1314 }
1315 Ok(Err(e)) => return Err(e.into()),
1316 Err(e) => return Err(StoreError::Database(e.to_string())),
1317 }
1318 }
1319
1320 Err(StoreError::Database(
1321 "store_prekeys_batch exhausted retries".to_string(),
1322 ))
1323 }
1324
1325 async fn load_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1326 let pool = self.pool.clone();
1327 let device_id = self.device_id;
1328 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1329 let mut conn = pool
1330 .get()
1331 .map_err(|e| StoreError::Connection(e.to_string()))?;
1332 let res: Option<Vec<u8>> = prekeys::table
1333 .select(prekeys::key)
1334 .filter(prekeys::id.eq(id as i32))
1335 .filter(prekeys::device_id.eq(device_id))
1336 .first(&mut conn)
1337 .optional()
1338 .map_err(|e| StoreError::Database(e.to_string()))?;
1339 Ok(res)
1340 })
1341 .await
1342 .map_err(|e| StoreError::Database(e.to_string()))?
1343 }
1344
1345 async fn remove_prekey(&self, id: u32) -> Result<()> {
1346 let pool = self.pool.clone();
1347 let db_semaphore = self.db_semaphore.clone();
1348 let device_id = self.device_id;
1349
1350 const MAX_RETRIES: u32 = 5;
1351
1352 for attempt in 0..=MAX_RETRIES {
1353 let permit =
1354 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1355 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1356 })?;
1357
1358 let pool_clone = pool.clone();
1359
1360 let result =
1361 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1362 let mut conn = pool_clone
1363 .get()
1364 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1365 diesel::delete(
1366 prekeys::table
1367 .filter(prekeys::id.eq(id as i32))
1368 .filter(prekeys::device_id.eq(device_id)),
1369 )
1370 .execute(&mut conn)
1371 .map_err(DieselOrStore::Diesel)?;
1372 Ok(())
1373 })
1374 .await;
1375
1376 drop(permit);
1377
1378 match result {
1379 Ok(Ok(())) => return Ok(()),
1380 Ok(Err(DieselOrStore::Diesel(ref e)))
1381 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1382 {
1383 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1384 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1385 }
1386 Ok(Err(e)) => return Err(e.into()),
1387 Err(e) => return Err(StoreError::Database(e.to_string())),
1388 }
1389 }
1390
1391 Err(StoreError::Database(
1392 "remove_prekey exhausted retries".to_string(),
1393 ))
1394 }
1395
1396 async fn get_max_prekey_id(&self) -> Result<u32> {
1397 let pool = self.pool.clone();
1398 let device_id = self.device_id;
1399 let db_semaphore = self.db_semaphore.clone();
1400 let _permit = db_semaphore
1401 .acquire()
1402 .await
1403 .map_err(|e| StoreError::Database(format!("Failed to acquire semaphore: {}", e)))?;
1404
1405 tokio::task::spawn_blocking(move || -> Result<u32> {
1406 let mut conn = pool
1407 .get()
1408 .map_err(|e| StoreError::Connection(e.to_string()))?;
1409 use diesel::dsl::max;
1410 let result: Option<i32> = prekeys::table
1411 .filter(prekeys::device_id.eq(device_id))
1412 .select(max(prekeys::id))
1413 .first(&mut conn)
1414 .map_err(|e| StoreError::Database(e.to_string()))?;
1415 Ok(result.unwrap_or(0) as u32)
1416 })
1417 .await
1418 .map_err(|e| StoreError::Database(e.to_string()))?
1419 }
1420
1421 async fn store_signed_prekey(&self, id: u32, record: &[u8]) -> Result<()> {
1422 let pool = self.pool.clone();
1423 let db_semaphore = self.db_semaphore.clone();
1424 let device_id = self.device_id;
1425 let record = record.to_vec();
1426
1427 const MAX_RETRIES: u32 = 5;
1428
1429 for attempt in 0..=MAX_RETRIES {
1430 let permit =
1431 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1432 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1433 })?;
1434
1435 let pool_clone = pool.clone();
1436 let record_clone = record.clone();
1437
1438 let result =
1439 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1440 let mut conn = pool_clone
1441 .get()
1442 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1443 diesel::insert_into(signed_prekeys::table)
1444 .values((
1445 signed_prekeys::id.eq(id as i32),
1446 signed_prekeys::record.eq(&record_clone),
1447 signed_prekeys::device_id.eq(device_id),
1448 ))
1449 .on_conflict((signed_prekeys::id, signed_prekeys::device_id))
1450 .do_update()
1451 .set(signed_prekeys::record.eq(&record_clone))
1452 .execute(&mut conn)
1453 .map_err(DieselOrStore::Diesel)?;
1454 Ok(())
1455 })
1456 .await;
1457
1458 drop(permit);
1459
1460 match result {
1461 Ok(Ok(())) => return Ok(()),
1462 Ok(Err(DieselOrStore::Diesel(ref e)))
1463 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1464 {
1465 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1466 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1467 }
1468 Ok(Err(e)) => return Err(e.into()),
1469 Err(e) => return Err(StoreError::Database(e.to_string())),
1470 }
1471 }
1472
1473 Err(StoreError::Database(
1474 "store_signed_prekey exhausted retries".to_string(),
1475 ))
1476 }
1477
1478 async fn load_signed_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1479 let pool = self.pool.clone();
1480 let device_id = self.device_id;
1481 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1482 let mut conn = pool
1483 .get()
1484 .map_err(|e| StoreError::Connection(e.to_string()))?;
1485 let res: Option<Vec<u8>> = signed_prekeys::table
1486 .select(signed_prekeys::record)
1487 .filter(signed_prekeys::id.eq(id as i32))
1488 .filter(signed_prekeys::device_id.eq(device_id))
1489 .first(&mut conn)
1490 .optional()
1491 .map_err(|e| StoreError::Database(e.to_string()))?;
1492 Ok(res)
1493 })
1494 .await
1495 .map_err(|e| StoreError::Database(e.to_string()))?
1496 }
1497
1498 async fn load_all_signed_prekeys(&self) -> Result<Vec<(u32, Vec<u8>)>> {
1499 let pool = self.pool.clone();
1500 let device_id = self.device_id;
1501 tokio::task::spawn_blocking(move || -> Result<Vec<(u32, Vec<u8>)>> {
1502 let mut conn = pool
1503 .get()
1504 .map_err(|e| StoreError::Connection(e.to_string()))?;
1505 let results: Vec<(i32, Vec<u8>)> = signed_prekeys::table
1506 .select((signed_prekeys::id, signed_prekeys::record))
1507 .filter(signed_prekeys::device_id.eq(device_id))
1508 .load(&mut conn)
1509 .map_err(|e| StoreError::Database(e.to_string()))?;
1510 Ok(results
1511 .into_iter()
1512 .map(|(id, record)| (id as u32, record))
1513 .collect())
1514 })
1515 .await
1516 .map_err(|e| StoreError::Database(e.to_string()))?
1517 }
1518
1519 async fn remove_signed_prekey(&self, id: u32) -> Result<()> {
1520 let pool = self.pool.clone();
1521 let db_semaphore = self.db_semaphore.clone();
1522 let device_id = self.device_id;
1523
1524 const MAX_RETRIES: u32 = 5;
1525
1526 for attempt in 0..=MAX_RETRIES {
1527 let permit =
1528 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1529 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1530 })?;
1531
1532 let pool_clone = pool.clone();
1533
1534 let result =
1535 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1536 let mut conn = pool_clone
1537 .get()
1538 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1539 diesel::delete(
1540 signed_prekeys::table
1541 .filter(signed_prekeys::id.eq(id as i32))
1542 .filter(signed_prekeys::device_id.eq(device_id)),
1543 )
1544 .execute(&mut conn)
1545 .map_err(DieselOrStore::Diesel)?;
1546 Ok(())
1547 })
1548 .await;
1549
1550 drop(permit);
1551
1552 match result {
1553 Ok(Ok(())) => return Ok(()),
1554 Ok(Err(DieselOrStore::Diesel(ref e)))
1555 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1556 {
1557 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1558 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1559 }
1560 Ok(Err(e)) => return Err(e.into()),
1561 Err(e) => return Err(StoreError::Database(e.to_string())),
1562 }
1563 }
1564
1565 Err(StoreError::Database(
1566 "remove_signed_prekey exhausted retries".to_string(),
1567 ))
1568 }
1569
1570 async fn put_sender_key(&self, address: &str, record: &[u8]) -> Result<()> {
1571 self.put_sender_key_for_device(address, record, self.device_id)
1572 .await
1573 }
1574
1575 async fn get_sender_key(&self, address: &str) -> Result<Option<Vec<u8>>> {
1576 self.get_sender_key_for_device(address, self.device_id)
1577 .await
1578 }
1579
1580 async fn delete_sender_key(&self, address: &str) -> Result<()> {
1581 self.delete_sender_key_for_device(address, self.device_id)
1582 .await
1583 }
1584}
1585
// `AppSyncStore` implementation: every method forwards to the matching
// `*_for_device` method, passing this store's own `device_id`.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AppSyncStore for SqliteStore {
    /// Fetches the app-state sync key identified by `key_id` for this device.
    async fn get_sync_key(&self, key_id: &[u8]) -> Result<Option<AppStateSyncKey>> {
        self.get_app_state_sync_key_for_device(key_id, self.device_id)
            .await
    }

    /// Stores `key` under `key_id` for this device.
    async fn set_sync_key(&self, key_id: &[u8], key: AppStateSyncKey) -> Result<()> {
        self.set_app_state_sync_key_for_device(key_id, key, self.device_id)
            .await
    }

    /// Returns the stored hash state of collection `name` for this device.
    async fn get_version(&self, name: &str) -> Result<HashState> {
        self.get_app_state_version_for_device(name, self.device_id)
            .await
    }

    /// Stores the hash state of collection `name` for this device.
    async fn set_version(&self, name: &str, state: HashState) -> Result<()> {
        self.set_app_state_version_for_device(name, state, self.device_id)
            .await
    }

    /// Upserts the given mutation MACs of collection `name` at `version` for
    /// this device.
    async fn put_mutation_macs(
        &self,
        name: &str,
        version: u64,
        mutations: &[AppStateMutationMAC],
    ) -> Result<()> {
        self.put_app_state_mutation_macs_for_device(name, version, mutations, self.device_id)
            .await
    }

    /// Looks up the value MAC stored for `index_mac` in collection `name` for
    /// this device.
    async fn get_mutation_mac(&self, name: &str, index_mac: &[u8]) -> Result<Option<Vec<u8>>> {
        self.get_app_state_mutation_mac_for_device(name, index_mac, self.device_id)
            .await
    }

    /// Deletes the mutation MACs of collection `name` listed in `index_macs`
    /// for this device.
    async fn delete_mutation_macs(&self, name: &str, index_macs: &[Vec<u8>]) -> Result<()> {
        self.delete_app_state_mutation_macs_for_device(name, index_macs, self.device_id)
            .await
    }

    /// Returns the id of the latest app-state sync key for this device, as
    /// reported by `get_latest_app_state_sync_key_id_for_device`.
    async fn get_latest_sync_key_id(&self) -> Result<Option<Vec<u8>>> {
        self.get_latest_app_state_sync_key_id_for_device(self.device_id)
            .await
    }
}
1634
1635#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1636#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1637impl ProtocolStore for SqliteStore {
1638 async fn get_skdm_recipients(&self, group_jid: &str) -> Result<Vec<Jid>> {
1639 let pool = self.pool.clone();
1640 let device_id = self.device_id;
1641 let group_jid = group_jid.to_string();
1642 tokio::task::spawn_blocking(move || -> Result<Vec<Jid>> {
1643 let mut conn = pool
1644 .get()
1645 .map_err(|e| StoreError::Connection(e.to_string()))?;
1646 let recipients: Vec<String> = skdm_recipients::table
1647 .select(skdm_recipients::device_jid)
1648 .filter(skdm_recipients::group_jid.eq(&group_jid))
1649 .filter(skdm_recipients::device_id.eq(device_id))
1650 .load(&mut conn)
1651 .map_err(|e| StoreError::Database(e.to_string()))?;
1652 let jids: Vec<Jid> = recipients
1653 .iter()
1654 .filter_map(|s| match s.parse::<Jid>() {
1655 Ok(jid) => Some(jid),
1656 Err(e) => {
1657 warn!("Failed to parse SKDM recipient '{}': {}", s, e);
1658 None
1659 }
1660 })
1661 .collect();
1662 Ok(jids)
1663 })
1664 .await
1665 .map_err(|e| StoreError::Database(e.to_string()))?
1666 }
1667
1668 async fn add_skdm_recipients(&self, group_jid: &str, device_jids: &[Jid]) -> Result<()> {
1669 if device_jids.is_empty() {
1670 return Ok(());
1671 }
1672 let pool = self.pool.clone();
1673 let device_id = self.device_id;
1674 let group_jid = group_jid.to_string();
1675 let device_jid_strs: Vec<String> = device_jids.iter().map(|j| j.to_string()).collect();
1676 let now = std::time::SystemTime::now()
1677 .duration_since(std::time::UNIX_EPOCH)
1678 .unwrap_or_default()
1679 .as_secs() as i32;
1680 tokio::task::spawn_blocking(move || -> Result<()> {
1681 let mut conn = pool
1682 .get()
1683 .map_err(|e| StoreError::Connection(e.to_string()))?;
1684
1685 let values: Vec<_> = device_jid_strs
1686 .iter()
1687 .map(|device_jid| {
1688 (
1689 skdm_recipients::group_jid.eq(&group_jid),
1690 skdm_recipients::device_jid.eq(device_jid),
1691 skdm_recipients::device_id.eq(device_id),
1692 skdm_recipients::created_at.eq(now),
1693 )
1694 })
1695 .collect();
1696
1697 const CHUNK_SIZE: usize = 200; for chunk in values.chunks(CHUNK_SIZE) {
1700 diesel::insert_into(skdm_recipients::table)
1701 .values(chunk)
1702 .on_conflict((
1703 skdm_recipients::group_jid,
1704 skdm_recipients::device_jid,
1705 skdm_recipients::device_id,
1706 ))
1707 .do_nothing()
1708 .execute(&mut conn)
1709 .map_err(|e| StoreError::Database(e.to_string()))?;
1710 }
1711 Ok(())
1712 })
1713 .await
1714 .map_err(|e| StoreError::Database(e.to_string()))??;
1715 Ok(())
1716 }
1717
1718 async fn clear_skdm_recipients(&self, group_jid: &str) -> Result<()> {
1719 let pool = self.pool.clone();
1720 let device_id = self.device_id;
1721 let group_jid = group_jid.to_string();
1722 tokio::task::spawn_blocking(move || -> Result<()> {
1723 let mut conn = pool
1724 .get()
1725 .map_err(|e| StoreError::Connection(e.to_string()))?;
1726 diesel::delete(
1727 skdm_recipients::table
1728 .filter(skdm_recipients::group_jid.eq(&group_jid))
1729 .filter(skdm_recipients::device_id.eq(device_id)),
1730 )
1731 .execute(&mut conn)
1732 .map_err(|e| StoreError::Database(e.to_string()))?;
1733 Ok(())
1734 })
1735 .await
1736 .map_err(|e| StoreError::Database(e.to_string()))??;
1737 Ok(())
1738 }
1739
1740 async fn get_lid_mapping(&self, lid: &str) -> Result<Option<LidPnMappingEntry>> {
1741 let pool = self.pool.clone();
1742 let device_id = self.device_id;
1743 let lid = lid.to_string();
1744 tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1745 let mut conn = pool
1746 .get()
1747 .map_err(|e| StoreError::Connection(e.to_string()))?;
1748 let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1749 .select((
1750 lid_pn_mapping::lid,
1751 lid_pn_mapping::phone_number,
1752 lid_pn_mapping::created_at,
1753 lid_pn_mapping::learning_source,
1754 lid_pn_mapping::updated_at,
1755 ))
1756 .filter(lid_pn_mapping::lid.eq(&lid))
1757 .filter(lid_pn_mapping::device_id.eq(device_id))
1758 .first(&mut conn)
1759 .optional()
1760 .map_err(|e| StoreError::Database(e.to_string()))?;
1761 Ok(row.map(
1762 |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1763 lid,
1764 phone_number,
1765 created_at,
1766 updated_at,
1767 learning_source,
1768 },
1769 ))
1770 })
1771 .await
1772 .map_err(|e| StoreError::Database(e.to_string()))?
1773 }
1774
1775 async fn get_pn_mapping(&self, phone: &str) -> Result<Option<LidPnMappingEntry>> {
1776 let pool = self.pool.clone();
1777 let device_id = self.device_id;
1778 let phone = phone.to_string();
1779 tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1780 let mut conn = pool
1781 .get()
1782 .map_err(|e| StoreError::Connection(e.to_string()))?;
1783 let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1784 .select((
1785 lid_pn_mapping::lid,
1786 lid_pn_mapping::phone_number,
1787 lid_pn_mapping::created_at,
1788 lid_pn_mapping::learning_source,
1789 lid_pn_mapping::updated_at,
1790 ))
1791 .filter(lid_pn_mapping::phone_number.eq(&phone))
1792 .filter(lid_pn_mapping::device_id.eq(device_id))
1793 .order(lid_pn_mapping::updated_at.desc())
1794 .first(&mut conn)
1795 .optional()
1796 .map_err(|e| StoreError::Database(e.to_string()))?;
1797 Ok(row.map(
1798 |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1799 lid,
1800 phone_number,
1801 created_at,
1802 updated_at,
1803 learning_source,
1804 },
1805 ))
1806 })
1807 .await
1808 .map_err(|e| StoreError::Database(e.to_string()))?
1809 }
1810
1811 async fn put_lid_mapping(&self, entry: &LidPnMappingEntry) -> Result<()> {
1812 let pool = self.pool.clone();
1813 let device_id = self.device_id;
1814 let entry = entry.clone();
1815 tokio::task::spawn_blocking(move || -> Result<()> {
1816 let mut conn = pool
1817 .get()
1818 .map_err(|e| StoreError::Connection(e.to_string()))?;
1819 diesel::insert_into(lid_pn_mapping::table)
1820 .values((
1821 lid_pn_mapping::lid.eq(&entry.lid),
1822 lid_pn_mapping::phone_number.eq(&entry.phone_number),
1823 lid_pn_mapping::created_at.eq(entry.created_at),
1824 lid_pn_mapping::learning_source.eq(&entry.learning_source),
1825 lid_pn_mapping::updated_at.eq(entry.updated_at),
1826 lid_pn_mapping::device_id.eq(device_id),
1827 ))
1828 .on_conflict((lid_pn_mapping::lid, lid_pn_mapping::device_id))
1829 .do_update()
1830 .set((
1831 lid_pn_mapping::phone_number.eq(&entry.phone_number),
1832 lid_pn_mapping::learning_source.eq(&entry.learning_source),
1833 lid_pn_mapping::updated_at.eq(entry.updated_at),
1834 ))
1835 .execute(&mut conn)
1836 .map_err(|e| StoreError::Database(e.to_string()))?;
1837 Ok(())
1838 })
1839 .await
1840 .map_err(|e| StoreError::Database(e.to_string()))??;
1841 Ok(())
1842 }
1843
1844 async fn get_all_lid_mappings(&self) -> Result<Vec<LidPnMappingEntry>> {
1845 let pool = self.pool.clone();
1846 let device_id = self.device_id;
1847 tokio::task::spawn_blocking(move || -> Result<Vec<LidPnMappingEntry>> {
1848 let mut conn = pool
1849 .get()
1850 .map_err(|e| StoreError::Connection(e.to_string()))?;
1851 let rows: Vec<(String, String, i64, String, i64)> = lid_pn_mapping::table
1852 .select((
1853 lid_pn_mapping::lid,
1854 lid_pn_mapping::phone_number,
1855 lid_pn_mapping::created_at,
1856 lid_pn_mapping::learning_source,
1857 lid_pn_mapping::updated_at,
1858 ))
1859 .filter(lid_pn_mapping::device_id.eq(device_id))
1860 .load(&mut conn)
1861 .map_err(|e| StoreError::Database(e.to_string()))?;
1862 Ok(rows
1863 .into_iter()
1864 .map(
1865 |(lid, phone_number, created_at, learning_source, updated_at)| {
1866 LidPnMappingEntry {
1867 lid,
1868 phone_number,
1869 created_at,
1870 updated_at,
1871 learning_source,
1872 }
1873 },
1874 )
1875 .collect())
1876 })
1877 .await
1878 .map_err(|e| StoreError::Database(e.to_string()))?
1879 }
1880
1881 async fn save_base_key(&self, address: &str, message_id: &str, base_key: &[u8]) -> Result<()> {
1882 let pool = self.pool.clone();
1883 let device_id = self.device_id;
1884 let address = address.to_string();
1885 let message_id = message_id.to_string();
1886 let base_key = base_key.to_vec();
1887 let now = std::time::SystemTime::now()
1888 .duration_since(std::time::UNIX_EPOCH)
1889 .unwrap_or_default()
1890 .as_secs() as i32;
1891 tokio::task::spawn_blocking(move || -> Result<()> {
1892 let mut conn = pool
1893 .get()
1894 .map_err(|e| StoreError::Connection(e.to_string()))?;
1895 diesel::insert_into(base_keys::table)
1896 .values((
1897 base_keys::address.eq(&address),
1898 base_keys::message_id.eq(&message_id),
1899 base_keys::base_key.eq(&base_key),
1900 base_keys::device_id.eq(device_id),
1901 base_keys::created_at.eq(now),
1902 ))
1903 .on_conflict((
1904 base_keys::address,
1905 base_keys::message_id,
1906 base_keys::device_id,
1907 ))
1908 .do_update()
1909 .set(base_keys::base_key.eq(&base_key))
1910 .execute(&mut conn)
1911 .map_err(|e| StoreError::Database(e.to_string()))?;
1912 Ok(())
1913 })
1914 .await
1915 .map_err(|e| StoreError::Database(e.to_string()))??;
1916 Ok(())
1917 }
1918
1919 async fn has_same_base_key(
1920 &self,
1921 address: &str,
1922 message_id: &str,
1923 current_base_key: &[u8],
1924 ) -> Result<bool> {
1925 let pool = self.pool.clone();
1926 let device_id = self.device_id;
1927 let address = address.to_string();
1928 let message_id = message_id.to_string();
1929 let current_base_key = current_base_key.to_vec();
1930 tokio::task::spawn_blocking(move || -> Result<bool> {
1931 let mut conn = pool
1932 .get()
1933 .map_err(|e| StoreError::Connection(e.to_string()))?;
1934 let stored_key: Option<Vec<u8>> = base_keys::table
1935 .select(base_keys::base_key)
1936 .filter(base_keys::address.eq(&address))
1937 .filter(base_keys::message_id.eq(&message_id))
1938 .filter(base_keys::device_id.eq(device_id))
1939 .first(&mut conn)
1940 .optional()
1941 .map_err(|e| StoreError::Database(e.to_string()))?;
1942 Ok(stored_key.as_ref() == Some(¤t_base_key))
1943 })
1944 .await
1945 .map_err(|e| StoreError::Database(e.to_string()))?
1946 }
1947
1948 async fn delete_base_key(&self, address: &str, message_id: &str) -> Result<()> {
1949 let pool = self.pool.clone();
1950 let device_id = self.device_id;
1951 let address = address.to_string();
1952 let message_id = message_id.to_string();
1953 tokio::task::spawn_blocking(move || -> Result<()> {
1954 let mut conn = pool
1955 .get()
1956 .map_err(|e| StoreError::Connection(e.to_string()))?;
1957 diesel::delete(
1958 base_keys::table
1959 .filter(base_keys::address.eq(&address))
1960 .filter(base_keys::message_id.eq(&message_id))
1961 .filter(base_keys::device_id.eq(device_id)),
1962 )
1963 .execute(&mut conn)
1964 .map_err(|e| StoreError::Database(e.to_string()))?;
1965 Ok(())
1966 })
1967 .await
1968 .map_err(|e| StoreError::Database(e.to_string()))??;
1969 Ok(())
1970 }
1971
1972 async fn update_device_list(&self, record: DeviceListRecord) -> Result<()> {
1973 let pool = self.pool.clone();
1974 let device_id = self.device_id;
1975 let devices_json = serde_json::to_string(&record.devices)
1976 .map_err(|e| StoreError::Serialization(e.to_string()))?;
1977 let now = std::time::SystemTime::now()
1978 .duration_since(std::time::UNIX_EPOCH)
1979 .unwrap_or_default()
1980 .as_secs() as i32;
1981 tokio::task::spawn_blocking(move || -> Result<()> {
1982 let mut conn = pool
1983 .get()
1984 .map_err(|e| StoreError::Connection(e.to_string()))?;
1985 diesel::insert_into(device_registry::table)
1986 .values((
1987 device_registry::user_id.eq(&record.user),
1988 device_registry::devices_json.eq(&devices_json),
1989 device_registry::timestamp.eq(record.timestamp as i32),
1990 device_registry::phash.eq(&record.phash),
1991 device_registry::device_id.eq(device_id),
1992 device_registry::updated_at.eq(now),
1993 ))
1994 .on_conflict((device_registry::user_id, device_registry::device_id))
1995 .do_update()
1996 .set((
1997 device_registry::devices_json.eq(&devices_json),
1998 device_registry::timestamp.eq(record.timestamp as i32),
1999 device_registry::phash.eq(&record.phash),
2000 device_registry::updated_at.eq(now),
2001 ))
2002 .execute(&mut conn)
2003 .map_err(|e| StoreError::Database(e.to_string()))?;
2004 Ok(())
2005 })
2006 .await
2007 .map_err(|e| StoreError::Database(e.to_string()))??;
2008 Ok(())
2009 }
2010
2011 async fn get_devices(&self, user: &str) -> Result<Option<DeviceListRecord>> {
2012 let pool = self.pool.clone();
2013 let device_id = self.device_id;
2014 let user = user.to_string();
2015 tokio::task::spawn_blocking(move || -> Result<Option<DeviceListRecord>> {
2016 let mut conn = pool
2017 .get()
2018 .map_err(|e| StoreError::Connection(e.to_string()))?;
2019 let row: Option<(String, String, i32, Option<String>)> = device_registry::table
2020 .select((
2021 device_registry::user_id,
2022 device_registry::devices_json,
2023 device_registry::timestamp,
2024 device_registry::phash,
2025 ))
2026 .filter(device_registry::user_id.eq(&user))
2027 .filter(device_registry::device_id.eq(device_id))
2028 .first(&mut conn)
2029 .optional()
2030 .map_err(|e| StoreError::Database(e.to_string()))?;
2031 match row {
2032 Some((user, devices_json, timestamp, phash)) => {
2033 let devices: Vec<DeviceInfo> = serde_json::from_str(&devices_json)
2034 .map_err(|e| StoreError::Serialization(e.to_string()))?;
2035 Ok(Some(DeviceListRecord {
2036 user,
2037 devices,
2038 timestamp: timestamp as i64,
2039 phash,
2040 }))
2041 }
2042 None => Ok(None),
2043 }
2044 })
2045 .await
2046 .map_err(|e| StoreError::Database(e.to_string()))?
2047 }
2048
2049 async fn mark_forget_sender_key(&self, group_jid: &str, participant: &str) -> Result<()> {
2050 let pool = self.pool.clone();
2051 let device_id = self.device_id;
2052 let group_jid = group_jid.to_string();
2053 let participant = participant.to_string();
2054 let now = std::time::SystemTime::now()
2055 .duration_since(std::time::UNIX_EPOCH)
2056 .unwrap_or_default()
2057 .as_secs() as i32;
2058 tokio::task::spawn_blocking(move || -> Result<()> {
2059 let mut conn = pool
2060 .get()
2061 .map_err(|e| StoreError::Connection(e.to_string()))?;
2062 diesel::insert_into(sender_key_status::table)
2063 .values((
2064 sender_key_status::group_jid.eq(&group_jid),
2065 sender_key_status::participant.eq(&participant),
2066 sender_key_status::device_id.eq(device_id),
2067 sender_key_status::marked_at.eq(now),
2068 ))
2069 .on_conflict((
2070 sender_key_status::group_jid,
2071 sender_key_status::participant,
2072 sender_key_status::device_id,
2073 ))
2074 .do_update()
2075 .set(sender_key_status::marked_at.eq(now))
2076 .execute(&mut conn)
2077 .map_err(|e| StoreError::Database(e.to_string()))?;
2078 Ok(())
2079 })
2080 .await
2081 .map_err(|e| StoreError::Database(e.to_string()))??;
2082 Ok(())
2083 }
2084
2085 async fn consume_forget_marks(&self, group_jid: &str) -> Result<Vec<String>> {
2086 let pool = self.pool.clone();
2087 let device_id = self.device_id;
2088 let group_jid = group_jid.to_string();
2089 tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
2090 let mut conn = pool
2091 .get()
2092 .map_err(|e| StoreError::Connection(e.to_string()))?;
2093 let participants: Vec<String> = sender_key_status::table
2094 .select(sender_key_status::participant)
2095 .filter(sender_key_status::group_jid.eq(&group_jid))
2096 .filter(sender_key_status::device_id.eq(device_id))
2097 .load(&mut conn)
2098 .map_err(|e| StoreError::Database(e.to_string()))?;
2099 diesel::delete(
2100 sender_key_status::table
2101 .filter(sender_key_status::group_jid.eq(&group_jid))
2102 .filter(sender_key_status::device_id.eq(device_id)),
2103 )
2104 .execute(&mut conn)
2105 .map_err(|e| StoreError::Database(e.to_string()))?;
2106 Ok(participants)
2107 })
2108 .await
2109 .map_err(|e| StoreError::Database(e.to_string()))?
2110 }
2111
2112 async fn get_tc_token(&self, jid: &str) -> Result<Option<TcTokenEntry>> {
2113 let pool = self.pool.clone();
2114 let device_id = self.device_id;
2115 let jid = jid.to_string();
2116 tokio::task::spawn_blocking(move || -> Result<Option<TcTokenEntry>> {
2117 let mut conn = pool
2118 .get()
2119 .map_err(|e| StoreError::Connection(e.to_string()))?;
2120 let row: Option<(Vec<u8>, i64, Option<i64>)> = tc_tokens::table
2121 .select((
2122 tc_tokens::token,
2123 tc_tokens::token_timestamp,
2124 tc_tokens::sender_timestamp,
2125 ))
2126 .filter(tc_tokens::jid.eq(&jid))
2127 .filter(tc_tokens::device_id.eq(device_id))
2128 .first(&mut conn)
2129 .optional()
2130 .map_err(|e| StoreError::Database(e.to_string()))?;
2131 Ok(
2132 row.map(|(token, token_timestamp, sender_timestamp)| TcTokenEntry {
2133 token,
2134 token_timestamp,
2135 sender_timestamp,
2136 }),
2137 )
2138 })
2139 .await
2140 .map_err(|e| StoreError::Database(e.to_string()))?
2141 }
2142
2143 async fn put_tc_token(&self, jid: &str, entry: &TcTokenEntry) -> Result<()> {
2144 let pool = self.pool.clone();
2145 let device_id = self.device_id;
2146 let jid = jid.to_string();
2147 let entry = entry.clone();
2148 let now = std::time::SystemTime::now()
2149 .duration_since(std::time::UNIX_EPOCH)
2150 .unwrap_or_default()
2151 .as_secs() as i64;
2152 tokio::task::spawn_blocking(move || -> Result<()> {
2153 let mut conn = pool
2154 .get()
2155 .map_err(|e| StoreError::Connection(e.to_string()))?;
2156 diesel::insert_into(tc_tokens::table)
2157 .values((
2158 tc_tokens::jid.eq(&jid),
2159 tc_tokens::token.eq(&entry.token),
2160 tc_tokens::token_timestamp.eq(entry.token_timestamp),
2161 tc_tokens::sender_timestamp.eq(entry.sender_timestamp),
2162 tc_tokens::device_id.eq(device_id),
2163 tc_tokens::updated_at.eq(now),
2164 ))
2165 .on_conflict((tc_tokens::jid, tc_tokens::device_id))
2166 .do_update()
2167 .set((
2168 tc_tokens::token.eq(&entry.token),
2169 tc_tokens::token_timestamp.eq(entry.token_timestamp),
2170 tc_tokens::sender_timestamp.eq(entry.sender_timestamp),
2171 tc_tokens::updated_at.eq(now),
2172 ))
2173 .execute(&mut conn)
2174 .map_err(|e| StoreError::Database(e.to_string()))?;
2175 Ok(())
2176 })
2177 .await
2178 .map_err(|e| StoreError::Database(e.to_string()))??;
2179 Ok(())
2180 }
2181
2182 async fn delete_tc_token(&self, jid: &str) -> Result<()> {
2183 let pool = self.pool.clone();
2184 let device_id = self.device_id;
2185 let jid = jid.to_string();
2186 tokio::task::spawn_blocking(move || -> Result<()> {
2187 let mut conn = pool
2188 .get()
2189 .map_err(|e| StoreError::Connection(e.to_string()))?;
2190 diesel::delete(
2191 tc_tokens::table
2192 .filter(tc_tokens::jid.eq(&jid))
2193 .filter(tc_tokens::device_id.eq(device_id)),
2194 )
2195 .execute(&mut conn)
2196 .map_err(|e| StoreError::Database(e.to_string()))?;
2197 Ok(())
2198 })
2199 .await
2200 .map_err(|e| StoreError::Database(e.to_string()))??;
2201 Ok(())
2202 }
2203
2204 async fn get_all_tc_token_jids(&self) -> Result<Vec<String>> {
2205 let pool = self.pool.clone();
2206 let device_id = self.device_id;
2207 tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
2208 let mut conn = pool
2209 .get()
2210 .map_err(|e| StoreError::Connection(e.to_string()))?;
2211 let jids: Vec<String> = tc_tokens::table
2212 .select(tc_tokens::jid)
2213 .filter(tc_tokens::device_id.eq(device_id))
2214 .load(&mut conn)
2215 .map_err(|e| StoreError::Database(e.to_string()))?;
2216 Ok(jids)
2217 })
2218 .await
2219 .map_err(|e| StoreError::Database(e.to_string()))?
2220 }
2221
2222 async fn delete_expired_tc_tokens(&self, cutoff_timestamp: i64) -> Result<u32> {
2223 let pool = self.pool.clone();
2224 let device_id = self.device_id;
2225 tokio::task::spawn_blocking(move || -> Result<u32> {
2226 let mut conn = pool
2227 .get()
2228 .map_err(|e| StoreError::Connection(e.to_string()))?;
2229 let deleted = diesel::delete(
2230 tc_tokens::table
2231 .filter(tc_tokens::token_timestamp.lt(cutoff_timestamp))
2232 .filter(tc_tokens::device_id.eq(device_id)),
2233 )
2234 .execute(&mut conn)
2235 .map_err(|e| StoreError::Database(e.to_string()))?;
2236 Ok(deleted as u32)
2237 })
2238 .await
2239 .map_err(|e| StoreError::Database(e.to_string()))?
2240 }
2241
2242 async fn store_sent_message(
2243 &self,
2244 chat_jid: &str,
2245 message_id: &str,
2246 payload: &[u8],
2247 ) -> Result<()> {
2248 let chat_jid = chat_jid.to_string();
2249 let message_id = message_id.to_string();
2250 let payload: Arc<Vec<u8>> = Arc::new(payload.to_vec());
2252 let device_id = self.device_id;
2253 self.with_retry("store_sent_message", || {
2254 let chat_jid = chat_jid.clone();
2255 let message_id = message_id.clone();
2256 let payload = Arc::clone(&payload);
2257 Box::new(move |conn: &mut SqliteConnection| {
2258 diesel::replace_into(sent_messages::table)
2259 .values((
2260 sent_messages::chat_jid.eq(&chat_jid),
2261 sent_messages::message_id.eq(&message_id),
2262 sent_messages::payload.eq(payload.as_slice()),
2263 sent_messages::device_id.eq(device_id),
2264 ))
2265 .execute(conn)?;
2266 Ok(())
2267 })
2268 })
2269 .await
2270 }
2271
2272 async fn take_sent_message(&self, chat_jid: &str, message_id: &str) -> Result<Option<Vec<u8>>> {
2273 let chat_jid = chat_jid.to_string();
2274 let message_id = message_id.to_string();
2275 let device_id = self.device_id;
2276 self.with_retry("take_sent_message", || {
2278 let chat_jid = chat_jid.clone();
2279 let message_id = message_id.clone();
2280 Box::new(move |conn: &mut SqliteConnection| {
2281 conn.immediate_transaction(|conn| {
2282 let row: Option<Vec<u8>> = sent_messages::table
2283 .select(sent_messages::payload)
2284 .filter(sent_messages::chat_jid.eq(&chat_jid))
2285 .filter(sent_messages::message_id.eq(&message_id))
2286 .filter(sent_messages::device_id.eq(device_id))
2287 .first(conn)
2288 .optional()?;
2289 if row.is_some() {
2290 diesel::delete(
2291 sent_messages::table
2292 .filter(sent_messages::chat_jid.eq(&chat_jid))
2293 .filter(sent_messages::message_id.eq(&message_id))
2294 .filter(sent_messages::device_id.eq(device_id)),
2295 )
2296 .execute(conn)?;
2297 }
2298 Ok(row)
2299 })
2300 })
2301 })
2302 .await
2303 }
2304
2305 async fn delete_expired_sent_messages(&self, cutoff_timestamp: i64) -> Result<u32> {
2306 let pool = self.pool.clone();
2307 let device_id = self.device_id;
2308 tokio::task::spawn_blocking(move || -> Result<u32> {
2309 let mut conn = pool
2310 .get()
2311 .map_err(|e| StoreError::Connection(e.to_string()))?;
2312 let deleted = diesel::delete(
2313 sent_messages::table
2314 .filter(sent_messages::created_at.lt(cutoff_timestamp))
2315 .filter(sent_messages::device_id.eq(device_id)),
2316 )
2317 .execute(&mut conn)
2318 .map_err(|e| StoreError::Database(e.to_string()))?;
2319 Ok(deleted as u32)
2320 })
2321 .await
2322 .map_err(|e| StoreError::Database(e.to_string()))?
2323 }
2324}
2325
2326#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
2327#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
2328impl DeviceStore for SqliteStore {
2329 async fn save(&self, device: &CoreDevice) -> Result<()> {
2330 SqliteStore::save_device_data_for_device(self, self.device_id, device).await
2331 }
2332
2333 async fn load(&self) -> Result<Option<CoreDevice>> {
2334 SqliteStore::load_device_data_for_device(self, self.device_id).await
2335 }
2336
2337 async fn exists(&self) -> Result<bool> {
2338 SqliteStore::device_exists(self, self.device_id).await
2339 }
2340
2341 async fn create(&self) -> Result<i32> {
2342 SqliteStore::create_new_device(self).await
2343 }
2344
2345 async fn snapshot_db(&self, name: &str, extra_content: Option<&[u8]>) -> Result<()> {
2346 fn sanitize_snapshot_name(name: &str) -> Result<String> {
2347 const MAX_LENGTH: usize = 100;
2348
2349 let sanitized: String = name
2350 .chars()
2351 .map(|c| {
2352 if c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.' {
2353 c
2354 } else {
2355 '_'
2356 }
2357 })
2358 .collect();
2359
2360 let sanitized = sanitized
2361 .split('.')
2362 .filter(|part| !part.is_empty() && *part != "..")
2363 .collect::<Vec<_>>()
2364 .join(".");
2365
2366 let sanitized = sanitized.trim_matches(['/', '\\', '.']);
2367
2368 if sanitized.is_empty() {
2369 return Err(StoreError::Database(
2370 "Snapshot name cannot be empty after sanitization".to_string(),
2371 ));
2372 }
2373
2374 if sanitized.len() > MAX_LENGTH {
2375 return Err(StoreError::Database(format!(
2376 "Snapshot name exceeds maximum length of {} characters",
2377 MAX_LENGTH
2378 )));
2379 }
2380
2381 Ok(sanitized.to_string())
2382 }
2383
2384 let sanitized_name = sanitize_snapshot_name(name)?;
2385
2386 let pool = self.pool.clone();
2387 let db_path = self.database_path.clone();
2388 let extra_data = extra_content.map(|b| b.to_vec());
2389
2390 tokio::task::spawn_blocking(move || -> Result<()> {
2391 let mut conn = pool
2392 .get()
2393 .map_err(|e| StoreError::Connection(e.to_string()))?;
2394
2395 let timestamp = std::time::SystemTime::now()
2396 .duration_since(std::time::UNIX_EPOCH)
2397 .unwrap_or_default()
2398 .as_secs();
2399
2400 let target_path = format!("{}.snapshot-{}-{}", db_path, timestamp, sanitized_name);
2402
2403 let query = format!("VACUUM INTO '{}'", target_path.replace("'", "''"));
2406
2407 diesel::sql_query(query)
2408 .execute(&mut conn)
2409 .map_err(|e| StoreError::Database(e.to_string()))?;
2410
2411 if let Some(data) = extra_data {
2413 let extra_path = format!("{}.json", target_path);
2414 std::fs::write(&extra_path, data).map_err(|e| {
2415 StoreError::Database(format!("Failed to write snapshot extra content: {}", e))
2416 })?;
2417 }
2418
2419 Ok(())
2420 })
2421 .await
2422 .map_err(|e| StoreError::Database(e.to_string()))??;
2423
2424 Ok(())
2425 }
2426}
2427
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a store on a shared-cache in-memory database. The process id and
    /// a per-process counter make the database name unique for each call.
    async fn create_test_store() -> SqliteStore {
        use std::sync::atomic::{AtomicU64, Ordering};
        static NEXT_DB: AtomicU64 = AtomicU64::new(0);
        let sequence = NEXT_DB.fetch_add(1, Ordering::Relaxed);
        let uri = format!(
            "file:memdb_test_{}_{}?mode=memory&cache=shared",
            std::process::id(),
            sequence
        );
        SqliteStore::new(&uri)
            .await
            .expect("Failed to create test store")
    }

    /// Builds a `TcTokenEntry` from its three fields.
    fn tc_entry(
        token: &[u8],
        token_timestamp: i64,
        sender_timestamp: Option<i64>,
    ) -> TcTokenEntry {
        TcTokenEntry {
            token: token.to_vec(),
            token_timestamp,
            sender_timestamp,
        }
    }

    // ---- parse_database_path ----

    #[test]
    fn test_parse_database_path_regular_path() {
        let parsed = parse_database_path("/var/lib/whatsapp/database.db").unwrap();
        assert_eq!(parsed, "/var/lib/whatsapp/database.db");
    }

    #[test]
    fn test_parse_database_path_with_sqlite_prefix() {
        let parsed = parse_database_path("sqlite:///var/lib/whatsapp/database.db").unwrap();
        assert_eq!(parsed, "/var/lib/whatsapp/database.db");
    }

    #[test]
    fn test_parse_database_path_with_query_params() {
        let parsed = parse_database_path("file:database.db?mode=memory&cache=shared").unwrap();
        assert_eq!(parsed, "file:database.db");
    }

    #[test]
    fn test_parse_database_path_with_fragment() {
        let parsed = parse_database_path("file:database.db#fragment").unwrap();
        assert_eq!(parsed, "file:database.db");
    }

    #[test]
    fn test_parse_database_path_with_both_query_and_fragment() {
        let parsed = parse_database_path("sqlite:///var/lib/database.db?mode=ro#backup").unwrap();
        assert_eq!(parsed, "/var/lib/database.db");
    }

    #[test]
    fn test_parse_database_path_in_memory_rejected() {
        let outcome = parse_database_path(":memory:");
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not supported"));
    }

    #[test]
    fn test_parse_database_path_in_memory_with_query_rejected() {
        let outcome = parse_database_path(":memory:?cache=shared");
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not supported"));
    }

    // ---- device registry ----

    #[tokio::test]
    async fn test_device_registry_save_and_get() {
        let store = create_test_store().await;
        let user = "1234567890";

        let devices = vec![
            DeviceInfo {
                device_id: 0,
                key_index: None,
            },
            DeviceInfo {
                device_id: 1,
                key_index: Some(42),
            },
        ];
        store
            .update_device_list(DeviceListRecord {
                user: user.to_string(),
                devices,
                timestamp: 1234567890,
                phash: Some("2:abcdef".to_string()),
            })
            .await
            .expect("save failed");

        let loaded = store
            .get_devices(user)
            .await
            .expect("get failed")
            .expect("record should exist");

        assert_eq!(loaded.user, user);
        assert_eq!(loaded.devices.len(), 2);
        let (first, second) = (&loaded.devices[0], &loaded.devices[1]);
        assert_eq!(first.device_id, 0);
        assert_eq!(second.device_id, 1);
        assert_eq!(second.key_index, Some(42));
        assert_eq!(loaded.phash.as_deref(), Some("2:abcdef"));
    }

    #[tokio::test]
    async fn test_device_registry_update_existing() {
        let store = create_test_store().await;
        let user = "1234567890";

        let initial = DeviceListRecord {
            user: user.to_string(),
            devices: vec![DeviceInfo {
                device_id: 0,
                key_index: None,
            }],
            timestamp: 1000,
            phash: Some("2:old".to_string()),
        };
        store
            .update_device_list(initial)
            .await
            .expect("save1 failed");

        // The second record for the same user must replace the first.
        let replacement = DeviceListRecord {
            user: user.to_string(),
            devices: vec![
                DeviceInfo {
                    device_id: 0,
                    key_index: None,
                },
                DeviceInfo {
                    device_id: 2,
                    key_index: None,
                },
            ],
            timestamp: 2000,
            phash: Some("2:new".to_string()),
        };
        store
            .update_device_list(replacement)
            .await
            .expect("save2 failed");

        let loaded = store
            .get_devices(user)
            .await
            .expect("get failed")
            .expect("record should exist");

        assert_eq!(loaded.devices.len(), 2);
        assert_eq!(loaded.phash.as_deref(), Some("2:new"));
    }

    #[tokio::test]
    async fn test_device_registry_get_nonexistent() {
        let store = create_test_store().await;
        let missing = store.get_devices("nonexistent").await.expect("get failed");
        assert!(missing.is_none());
    }

    // ---- sender key status ----

    #[tokio::test]
    async fn test_sender_key_status_mark_and_consume() {
        let store = create_test_store().await;
        let (group, participant) = ("group123@g.us", "user1@s.whatsapp.net");

        store
            .mark_forget_sender_key(group, participant)
            .await
            .expect("mark failed");

        let first_pass = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(first_pass, vec![participant.to_string()]);

        // Consuming clears the marks, so a second pass finds nothing.
        let second_pass = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(second_pass.is_empty());
    }

    #[tokio::test]
    async fn test_sender_key_status_consume_multiple() {
        let store = create_test_store().await;
        let group = "group123@g.us";

        for participant in ["user1@s.whatsapp.net", "user2@s.whatsapp.net"] {
            store
                .mark_forget_sender_key(group, participant)
                .await
                .expect("mark failed");
        }

        let mut first_pass = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        // Row order is not specified by the query, so compare after sorting.
        first_pass.sort();
        assert_eq!(
            first_pass,
            vec!["user1@s.whatsapp.net", "user2@s.whatsapp.net"]
        );

        let second_pass = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(second_pass.is_empty());
    }

    // ---- tc tokens ----

    #[tokio::test]
    async fn test_tc_token_put_and_get() {
        let store = create_test_store().await;
        let entry = tc_entry(&[1, 2, 3, 4, 5], 1707000000, Some(1707000100));

        store
            .put_tc_token("user@lid", &entry)
            .await
            .expect("put failed");

        let loaded = store
            .get_tc_token("user@lid")
            .await
            .expect("get failed")
            .expect("should exist");

        assert_eq!(loaded.token, vec![1, 2, 3, 4, 5]);
        assert_eq!(loaded.token_timestamp, 1707000000);
        assert_eq!(loaded.sender_timestamp, Some(1707000100));
    }

    #[tokio::test]
    async fn test_tc_token_upsert() {
        let store = create_test_store().await;

        let original = tc_entry(&[1, 2, 3], 1000, None);
        store.put_tc_token("user@lid", &original).await.unwrap();

        let overwrite = tc_entry(&[4, 5, 6], 2000, Some(1500));
        store.put_tc_token("user@lid", &overwrite).await.unwrap();

        let loaded = store.get_tc_token("user@lid").await.unwrap().unwrap();
        assert_eq!(loaded.token, vec![4, 5, 6]);
        assert_eq!(loaded.token_timestamp, 2000);
        assert_eq!(loaded.sender_timestamp, Some(1500));
    }

    #[tokio::test]
    async fn test_tc_token_delete() {
        let store = create_test_store().await;

        let entry = tc_entry(&[1, 2, 3], 1000, None);
        store.put_tc_token("user@lid", &entry).await.unwrap();
        store.delete_tc_token("user@lid").await.unwrap();

        let after_delete = store.get_tc_token("user@lid").await.unwrap();
        assert!(after_delete.is_none());
    }

    #[tokio::test]
    async fn test_tc_token_get_all_jids() {
        let store = create_test_store().await;

        let entry = tc_entry(&[1], 1000, None);
        for jid in ["user1@lid", "user2@lid", "user3@lid"] {
            store.put_tc_token(jid, &entry).await.unwrap();
        }

        let mut jids = store.get_all_tc_token_jids().await.unwrap();
        jids.sort();
        assert_eq!(jids, vec!["user1@lid", "user2@lid", "user3@lid"]);
    }

    #[tokio::test]
    async fn test_tc_token_delete_expired() {
        let store = create_test_store().await;

        let old = tc_entry(&[1], 1000, None);
        let recent = tc_entry(&[2], 5000, None);
        store.put_tc_token("old@lid", &old).await.unwrap();
        store.put_tc_token("recent@lid", &recent).await.unwrap();

        // Only the entry stamped before the cutoff is removed.
        let removed = store.delete_expired_tc_tokens(3000).await.unwrap();
        assert_eq!(removed, 1);

        let old_row = store.get_tc_token("old@lid").await.unwrap();
        let recent_row = store.get_tc_token("recent@lid").await.unwrap();
        assert!(old_row.is_none());
        assert!(recent_row.is_some());
    }

    #[tokio::test]
    async fn test_tc_token_get_nonexistent() {
        let store = create_test_store().await;
        let missing = store.get_tc_token("nonexistent@lid").await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_sender_key_status_different_groups() {
        let store = create_test_store().await;
        let (marked_group, other_group) = ("group1@g.us", "group2@g.us");
        let participant = "user@s.whatsapp.net";

        store
            .mark_forget_sender_key(marked_group, participant)
            .await
            .expect("mark failed");

        let from_marked = store.consume_forget_marks(marked_group).await.unwrap();
        assert_eq!(from_marked.len(), 1);

        // A mark in one group must not surface when consuming another.
        let from_other = store.consume_forget_marks(other_group).await.unwrap();
        assert!(from_other.is_empty());
    }
}