1use crate::schema::*;
2use async_trait::async_trait;
3use diesel::prelude::*;
4use diesel::r2d2::{ConnectionManager, Pool};
5use diesel::result::{DatabaseErrorKind, Error as DieselError};
6use diesel::sql_query;
7use diesel::sqlite::SqliteConnection;
8use diesel::upsert::excluded;
9use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
10use log::warn;
11use prost::Message;
12use std::sync::Arc;
13use wacore::appstate::hash::HashState;
14use wacore::appstate::processor::AppStateMutationMAC;
15use wacore::libsignal::protocol::{KeyPair, PrivateKey, PublicKey};
16use wacore::store::Device as CoreDevice;
17use wacore::store::error::{Result, StoreError};
18use wacore::store::traits::*;
19use wacore_binary::jid::Jid;
20use waproto::whatsapp as wa;
21
22enum DieselOrStore {
26 Diesel(DieselError),
27 Store(StoreError),
28}
29
30impl From<DieselOrStore> for StoreError {
31 fn from(e: DieselOrStore) -> Self {
32 match e {
33 DieselOrStore::Diesel(e) => StoreError::Database(e.to_string()),
34 DieselOrStore::Store(e) => e,
35 }
36 }
37}
38
39fn is_retriable_sqlite_error(error: &DieselError) -> bool {
45 match error {
46 DieselError::DatabaseError(DatabaseErrorKind::Unknown, info) => {
47 let msg = info.message();
48 msg.contains("locked") || msg.contains("busy")
49 }
50 _ => false,
51 }
52}
53
54const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
55
56type SqlitePool = Pool<ConnectionManager<SqliteConnection>>;
57type DeviceRow = (
58 i32,
59 String,
60 String,
61 i32,
62 Vec<u8>,
63 Vec<u8>,
64 Vec<u8>,
65 i32,
66 Vec<u8>,
67 Vec<u8>,
68 Option<Vec<u8>>,
69 String,
70 i32,
71 i32,
72 i64,
73 i64,
74 Option<Vec<u8>>,
75 Option<String>,
76 i32,
77);
78
79#[derive(Clone)]
80pub struct SqliteStore {
81 pub(crate) pool: SqlitePool,
82 pub(crate) db_semaphore: Arc<tokio::sync::Semaphore>,
83 pub(crate) database_path: String,
84 device_id: i32,
85}
86
87#[derive(Debug, Clone, Copy)]
88struct ConnectionOptions;
89
90impl diesel::r2d2::CustomizeConnection<SqliteConnection, diesel::r2d2::Error>
91 for ConnectionOptions
92{
93 fn on_acquire(
94 &self,
95 conn: &mut SqliteConnection,
96 ) -> std::result::Result<(), diesel::r2d2::Error> {
97 diesel::sql_query("PRAGMA busy_timeout = 30000;")
98 .execute(conn)
99 .map_err(diesel::r2d2::Error::QueryError)?;
100 diesel::sql_query("PRAGMA synchronous = NORMAL;")
101 .execute(conn)
102 .map_err(diesel::r2d2::Error::QueryError)?;
103 diesel::sql_query("PRAGMA cache_size = 512;")
104 .execute(conn)
105 .map_err(diesel::r2d2::Error::QueryError)?;
106 diesel::sql_query("PRAGMA temp_store = memory;")
107 .execute(conn)
108 .map_err(diesel::r2d2::Error::QueryError)?;
109 diesel::sql_query("PRAGMA foreign_keys = ON;")
110 .execute(conn)
111 .map_err(diesel::r2d2::Error::QueryError)?;
112 Ok(())
113 }
114}
115
116fn parse_database_path(database_url: &str) -> Result<String> {
117 if database_url == ":memory:" {
119 return Err(StoreError::Database(
120 "Snapshot not supported for in-memory databases".to_string(),
121 ));
122 }
123
124 let path = database_url
126 .split(['?', '#'])
127 .next()
128 .unwrap_or(database_url);
129
130 let path = path.trim_start_matches("sqlite://");
132
133 if path == ":memory:" || path.starts_with(":memory:?") {
135 return Err(StoreError::Database(
136 "Snapshot not supported for in-memory databases".to_string(),
137 ));
138 }
139
140 Ok(path.to_string())
141}
142
143impl SqliteStore {
144 pub async fn new(database_url: &str) -> std::result::Result<Self, StoreError> {
145 let manager = ConnectionManager::<SqliteConnection>::new(database_url);
146
147 let pool_size = 2;
148
149 let pool = Pool::builder()
150 .max_size(pool_size)
151 .connection_customizer(Box::new(ConnectionOptions))
152 .build(manager)
153 .map_err(|e| StoreError::Connection(e.to_string()))?;
154
155 let pool_clone = pool.clone();
156 tokio::task::spawn_blocking(move || -> std::result::Result<(), StoreError> {
157 let mut conn = pool_clone
158 .get()
159 .map_err(|e| StoreError::Connection(e.to_string()))?;
160
161 diesel::sql_query("PRAGMA journal_mode = WAL;")
162 .execute(&mut conn)
163 .map_err(|e| StoreError::Database(e.to_string()))?;
164
165 conn.run_pending_migrations(MIGRATIONS)
166 .map_err(|e| StoreError::Migration(e.to_string()))?;
167
168 Ok(())
169 })
170 .await
171 .map_err(|e| StoreError::Database(e.to_string()))??;
172
173 let database_path = parse_database_path(database_url)?;
174
175 Ok(Self {
176 pool,
177 db_semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
178 database_path,
179 device_id: 1,
180 })
181 }
182
183 pub async fn new_for_device(
184 database_url: &str,
185 device_id: i32,
186 ) -> std::result::Result<Self, StoreError> {
187 let mut store = Self::new(database_url).await?;
188 store.device_id = device_id;
189 Ok(store)
190 }
191
192 pub fn device_id(&self) -> i32 {
193 self.device_id
194 }
195
196 async fn with_semaphore<F, T>(&self, f: F) -> Result<T>
197 where
198 F: FnOnce() -> Result<T> + Send + 'static,
199 T: Send + 'static,
200 {
201 let permit = self
202 .db_semaphore
203 .clone()
204 .acquire_owned()
205 .await
206 .map_err(|e| StoreError::Database(format!("Semaphore error: {}", e)))?;
207 let result = tokio::task::spawn_blocking(move || {
208 let res = f();
209 drop(permit);
210 res
211 })
212 .await
213 .map_err(|e| StoreError::Database(e.to_string()))??;
214 Ok(result)
215 }
216
217 async fn with_retry<F, T>(&self, op_name: &str, make_op: F) -> Result<T>
221 where
222 F: Fn() -> Box<
223 dyn FnOnce(&mut SqliteConnection) -> std::result::Result<T, DieselError> + Send,
224 >,
225 T: Send + 'static,
226 {
227 const MAX_RETRIES: u32 = 5;
228
229 for attempt in 0..=MAX_RETRIES {
230 let permit = self
231 .db_semaphore
232 .clone()
233 .acquire_owned()
234 .await
235 .map_err(|e| StoreError::Database(format!("Semaphore error: {}", e)))?;
236
237 let pool = self.pool.clone();
238 let op = make_op();
239
240 let result =
241 tokio::task::spawn_blocking(move || -> std::result::Result<T, DieselOrStore> {
242 let _permit = permit;
243 let mut conn = pool
244 .get()
245 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
246 op(&mut conn).map_err(DieselOrStore::Diesel)
247 })
248 .await;
249
250 match result {
251 Ok(Ok(val)) => return Ok(val),
252 Ok(Err(DieselOrStore::Diesel(ref e)))
253 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
254 {
255 let delay_ms = 10u64 * (1u64 << attempt.min(4));
256 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
257 }
258 Ok(Err(e)) => return Err(e.into()),
259 Err(e) => return Err(StoreError::Database(e.to_string())),
260 }
261 }
262
263 Err(StoreError::Database(format!(
264 "{} exhausted retries",
265 op_name
266 )))
267 }
268
269 fn serialize_keypair(&self, key_pair: &KeyPair) -> Result<Vec<u8>> {
270 let mut bytes = Vec::with_capacity(64);
271 bytes.extend_from_slice(key_pair.private_key.serialize());
272 bytes.extend_from_slice(key_pair.public_key.public_key_bytes());
273 Ok(bytes)
274 }
275
276 fn deserialize_keypair(&self, bytes: &[u8]) -> Result<KeyPair> {
277 if bytes.len() != 64 {
278 return Err(StoreError::Serialization(format!(
279 "Invalid KeyPair length: {}",
280 bytes.len()
281 )));
282 }
283
284 let private_key = PrivateKey::deserialize(&bytes[0..32])
285 .map_err(|e| StoreError::Serialization(e.to_string()))?;
286 let public_key = PublicKey::from_djb_public_key_bytes(&bytes[32..64])
287 .map_err(|e| StoreError::Serialization(e.to_string()))?;
288
289 Ok(KeyPair::new(public_key, private_key))
290 }
291
292 pub async fn save_device_data_for_device(
293 &self,
294 device_id: i32,
295 device_data: &CoreDevice,
296 ) -> Result<()> {
297 let pool = self.pool.clone();
298 let noise_key_data = self.serialize_keypair(&device_data.noise_key)?;
299 let identity_key_data = self.serialize_keypair(&device_data.identity_key)?;
300 let signed_pre_key_data = self.serialize_keypair(&device_data.signed_pre_key)?;
301 let account_data = device_data
302 .account
303 .as_ref()
304 .map(|account| account.encode_to_vec());
305 let registration_id = device_data.registration_id as i32;
306 let signed_pre_key_id = device_data.signed_pre_key_id as i32;
307 let signed_pre_key_signature: Vec<u8> = device_data.signed_pre_key_signature.to_vec();
308 let adv_secret_key: Vec<u8> = device_data.adv_secret_key.to_vec();
309 let push_name = device_data.push_name.clone();
310 let app_version_primary = device_data.app_version_primary as i32;
311 let app_version_secondary = device_data.app_version_secondary as i32;
312 let app_version_tertiary = device_data.app_version_tertiary as i64;
313 let app_version_last_fetched_ms = device_data.app_version_last_fetched_ms;
314 let edge_routing_info = device_data.edge_routing_info.clone();
315 let props_hash = device_data.props_hash.clone();
316 let next_pre_key_id = device_data.next_pre_key_id as i32;
317 let new_lid = device_data
318 .lid
319 .as_ref()
320 .map(|j| j.to_string())
321 .unwrap_or_default();
322 let new_pn = device_data
323 .pn
324 .as_ref()
325 .map(|j| j.to_string())
326 .unwrap_or_default();
327
328 tokio::task::spawn_blocking(move || -> Result<()> {
329 let mut conn = pool
330 .get()
331 .map_err(|e| StoreError::Connection(e.to_string()))?;
332
333 diesel::insert_into(device::table)
334 .values((
335 device::id.eq(device_id),
336 device::lid.eq(&new_lid),
337 device::pn.eq(&new_pn),
338 device::registration_id.eq(registration_id),
339 device::noise_key.eq(&noise_key_data),
340 device::identity_key.eq(&identity_key_data),
341 device::signed_pre_key.eq(&signed_pre_key_data),
342 device::signed_pre_key_id.eq(signed_pre_key_id),
343 device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
344 device::adv_secret_key.eq(&adv_secret_key[..]),
345 device::account.eq(account_data.clone()),
346 device::push_name.eq(&push_name),
347 device::app_version_primary.eq(app_version_primary),
348 device::app_version_secondary.eq(app_version_secondary),
349 device::app_version_tertiary.eq(app_version_tertiary),
350 device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
351 device::edge_routing_info.eq(edge_routing_info.clone()),
352 device::props_hash.eq(props_hash.clone()),
353 device::next_pre_key_id.eq(next_pre_key_id),
354 ))
355 .on_conflict(device::id)
356 .do_update()
357 .set((
358 device::lid.eq(&new_lid),
359 device::pn.eq(&new_pn),
360 device::registration_id.eq(registration_id),
361 device::noise_key.eq(&noise_key_data),
362 device::identity_key.eq(&identity_key_data),
363 device::signed_pre_key.eq(&signed_pre_key_data),
364 device::signed_pre_key_id.eq(signed_pre_key_id),
365 device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
366 device::adv_secret_key.eq(&adv_secret_key[..]),
367 device::account.eq(account_data.clone()),
368 device::push_name.eq(&push_name),
369 device::app_version_primary.eq(app_version_primary),
370 device::app_version_secondary.eq(app_version_secondary),
371 device::app_version_tertiary.eq(app_version_tertiary),
372 device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
373 device::edge_routing_info.eq(edge_routing_info),
374 device::props_hash.eq(props_hash),
375 device::next_pre_key_id.eq(next_pre_key_id),
376 ))
377 .execute(&mut conn)
378 .map_err(|e| StoreError::Database(e.to_string()))?;
379
380 Ok(())
381 })
382 .await
383 .map_err(|e| StoreError::Database(e.to_string()))??;
384
385 Ok(())
386 }
387
388 pub async fn create_new_device(&self) -> Result<i32> {
389 use crate::schema::device;
390
391 let pool = self.pool.clone();
392 tokio::task::spawn_blocking(move || -> Result<i32> {
393 let mut conn = pool
394 .get()
395 .map_err(|e| StoreError::Connection(e.to_string()))?;
396
397 let new_device = wacore::store::Device::new();
398
399 let noise_key_data = {
400 let mut bytes = Vec::with_capacity(64);
401 bytes.extend_from_slice(new_device.noise_key.private_key.serialize());
402 bytes.extend_from_slice(new_device.noise_key.public_key.public_key_bytes());
403 bytes
404 };
405 let identity_key_data = {
406 let mut bytes = Vec::with_capacity(64);
407 bytes.extend_from_slice(new_device.identity_key.private_key.serialize());
408 bytes.extend_from_slice(new_device.identity_key.public_key.public_key_bytes());
409 bytes
410 };
411 let signed_pre_key_data = {
412 let mut bytes = Vec::with_capacity(64);
413 bytes.extend_from_slice(new_device.signed_pre_key.private_key.serialize());
414 bytes.extend_from_slice(new_device.signed_pre_key.public_key.public_key_bytes());
415 bytes
416 };
417
418 diesel::insert_into(device::table)
419 .values((
420 device::lid.eq(""),
421 device::pn.eq(""),
422 device::registration_id.eq(new_device.registration_id as i32),
423 device::noise_key.eq(&noise_key_data),
424 device::identity_key.eq(&identity_key_data),
425 device::signed_pre_key.eq(&signed_pre_key_data),
426 device::signed_pre_key_id.eq(new_device.signed_pre_key_id as i32),
427 device::signed_pre_key_signature.eq(&new_device.signed_pre_key_signature[..]),
428 device::adv_secret_key.eq(&new_device.adv_secret_key[..]),
429 device::account.eq(None::<Vec<u8>>),
430 device::push_name.eq(&new_device.push_name),
431 device::app_version_primary.eq(new_device.app_version_primary as i32),
432 device::app_version_secondary.eq(new_device.app_version_secondary as i32),
433 device::app_version_tertiary.eq(new_device.app_version_tertiary as i64),
434 device::app_version_last_fetched_ms.eq(new_device.app_version_last_fetched_ms),
435 device::edge_routing_info.eq(None::<Vec<u8>>),
436 device::props_hash.eq(None::<String>),
437 device::next_pre_key_id.eq(new_device.next_pre_key_id as i32),
438 ))
439 .execute(&mut conn)
440 .map_err(|e| StoreError::Database(e.to_string()))?;
441
442 use diesel::sql_types::Integer;
443
444 #[derive(QueryableByName)]
445 struct LastInsertedId {
446 #[diesel(sql_type = Integer)]
447 last_insert_rowid: i32,
448 }
449
450 let device_id: i32 = sql_query("SELECT last_insert_rowid() as last_insert_rowid")
451 .get_result::<LastInsertedId>(&mut conn)
452 .map_err(|e| StoreError::Database(e.to_string()))?
453 .last_insert_rowid;
454
455 Ok(device_id)
456 })
457 .await
458 .map_err(|e| StoreError::Database(e.to_string()))?
459 }
460
461 pub async fn device_exists(&self, device_id: i32) -> Result<bool> {
462 use crate::schema::device;
463
464 let pool = self.pool.clone();
465 tokio::task::spawn_blocking(move || -> Result<bool> {
466 let mut conn = pool
467 .get()
468 .map_err(|e| StoreError::Connection(e.to_string()))?;
469
470 let count: i64 = device::table
471 .filter(device::id.eq(device_id))
472 .count()
473 .get_result(&mut conn)
474 .map_err(|e| StoreError::Database(e.to_string()))?;
475
476 Ok(count > 0)
477 })
478 .await
479 .map_err(|e| StoreError::Database(e.to_string()))?
480 }
481
482 pub async fn load_device_data_for_device(&self, device_id: i32) -> Result<Option<CoreDevice>> {
483 use crate::schema::device;
484
485 let pool = self.pool.clone();
486 let row = tokio::task::spawn_blocking(move || -> Result<Option<DeviceRow>> {
487 let mut conn = pool
488 .get()
489 .map_err(|e| StoreError::Connection(e.to_string()))?;
490 let result = device::table
491 .filter(device::id.eq(device_id))
492 .first::<DeviceRow>(&mut conn)
493 .optional()
494 .map_err(|e| StoreError::Database(e.to_string()))?;
495 Ok(result)
496 })
497 .await
498 .map_err(|e| StoreError::Database(e.to_string()))??;
499
500 if let Some((
501 _device_id,
502 lid_str,
503 pn_str,
504 registration_id,
505 noise_key_data,
506 identity_key_data,
507 signed_pre_key_data,
508 signed_pre_key_id,
509 signed_pre_key_signature_data,
510 adv_secret_key_data,
511 account_data,
512 push_name,
513 app_version_primary,
514 app_version_secondary,
515 app_version_tertiary,
516 app_version_last_fetched_ms,
517 edge_routing_info,
518 props_hash,
519 next_pre_key_id,
520 )) = row
521 {
522 let id = if !pn_str.is_empty() {
523 pn_str.parse().ok()
524 } else {
525 None
526 };
527 let lid = if !lid_str.is_empty() {
528 lid_str.parse().ok()
529 } else {
530 None
531 };
532
533 let noise_key = self.deserialize_keypair(&noise_key_data)?;
534 let identity_key = self.deserialize_keypair(&identity_key_data)?;
535 let signed_pre_key = self.deserialize_keypair(&signed_pre_key_data)?;
536
537 let signed_pre_key_signature: [u8; 64] =
538 signed_pre_key_signature_data.try_into().map_err(|_| {
539 StoreError::Serialization("Invalid signed_pre_key_signature length".to_string())
540 })?;
541
542 let adv_secret_key: [u8; 32] = adv_secret_key_data.try_into().map_err(|_| {
543 StoreError::Serialization("Invalid adv_secret_key length".to_string())
544 })?;
545
546 let account = account_data
547 .map(|data| {
548 wa::AdvSignedDeviceIdentity::decode(&data[..])
549 .map_err(|e| StoreError::Serialization(e.to_string()))
550 })
551 .transpose()?;
552
553 Ok(Some(CoreDevice {
554 pn: id,
555 lid,
556 registration_id: registration_id as u32,
557 noise_key,
558 identity_key,
559 signed_pre_key,
560 signed_pre_key_id: signed_pre_key_id as u32,
561 signed_pre_key_signature,
562 adv_secret_key,
563 account,
564 push_name,
565 app_version_primary: app_version_primary as u32,
566 app_version_secondary: app_version_secondary as u32,
567 app_version_tertiary: app_version_tertiary.try_into().unwrap_or(0u32),
568 app_version_last_fetched_ms,
569 device_props: {
570 use wacore::store::device::DEVICE_PROPS;
571 DEVICE_PROPS.clone()
572 },
573 edge_routing_info,
574 props_hash,
575 next_pre_key_id: next_pre_key_id as u32,
576 }))
577 } else {
578 Ok(None)
579 }
580 }
581
582 pub async fn put_identity_for_device(
583 &self,
584 address: &str,
585 key: [u8; 32],
586 device_id: i32,
587 ) -> Result<()> {
588 let pool = self.pool.clone();
589 let db_semaphore = self.db_semaphore.clone();
590 let address_owned = address.to_string();
591 let key_vec = key.to_vec();
592
593 const MAX_RETRIES: u32 = 5;
594
595 for attempt in 0..=MAX_RETRIES {
596 let permit =
597 db_semaphore.clone().acquire_owned().await.map_err(|e| {
598 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
599 })?;
600
601 let pool_clone = pool.clone();
602 let address_clone = address_owned.clone();
603 let key_clone = key_vec.clone();
604
605 let result =
606 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
607 let mut conn = pool_clone
608 .get()
609 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
610 diesel::insert_into(identities::table)
611 .values((
612 identities::address.eq(address_clone),
613 identities::key.eq(&key_clone[..]),
614 identities::device_id.eq(device_id),
615 ))
616 .on_conflict((identities::address, identities::device_id))
617 .do_update()
618 .set(identities::key.eq(&key_clone[..]))
619 .execute(&mut conn)
620 .map_err(DieselOrStore::Diesel)?;
621 Ok(())
622 })
623 .await;
624
625 drop(permit);
626
627 match result {
628 Ok(Ok(())) => return Ok(()),
629 Ok(Err(DieselOrStore::Diesel(ref e)))
630 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
631 {
632 let delay_ms = 10 * 2u64.pow(attempt);
633 warn!(
634 "Identity write failed (attempt {}/{}): {e}. Retrying in {delay_ms}ms...",
635 attempt + 1,
636 MAX_RETRIES + 1,
637 );
638 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
639 continue;
640 }
641 Ok(Err(e)) => return Err(e.into()),
642 Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
643 }
644 }
645
646 Err(StoreError::Database(format!(
647 "Identity write failed after {} attempts",
648 MAX_RETRIES + 1
649 )))
650 }
651
652 pub async fn delete_identity_for_device(&self, address: &str, device_id: i32) -> Result<()> {
653 let pool = self.pool.clone();
654 let address_owned = address.to_string();
655
656 tokio::task::spawn_blocking(move || -> Result<()> {
657 let mut conn = pool
658 .get()
659 .map_err(|e| StoreError::Connection(e.to_string()))?;
660 diesel::delete(
661 identities::table
662 .filter(identities::address.eq(address_owned))
663 .filter(identities::device_id.eq(device_id)),
664 )
665 .execute(&mut conn)
666 .map_err(|e| StoreError::Database(e.to_string()))?;
667 Ok(())
668 })
669 .await
670 .map_err(|e| StoreError::Database(e.to_string()))??;
671
672 Ok(())
673 }
674
675 pub async fn load_identity_for_device(
676 &self,
677 address: &str,
678 device_id: i32,
679 ) -> Result<Option<Vec<u8>>> {
680 let pool = self.pool.clone();
681 let address = address.to_string();
682 let result = self
683 .with_semaphore(move || -> Result<Option<Vec<u8>>> {
684 let mut conn = pool
685 .get()
686 .map_err(|e| StoreError::Connection(e.to_string()))?;
687 let res: Option<Vec<u8>> = identities::table
688 .select(identities::key)
689 .filter(identities::address.eq(address))
690 .filter(identities::device_id.eq(device_id))
691 .first(&mut conn)
692 .optional()
693 .map_err(|e| StoreError::Database(e.to_string()))?;
694 Ok(res)
695 })
696 .await?;
697
698 Ok(result)
699 }
700
701 pub async fn get_session_for_device(
702 &self,
703 address: &str,
704 device_id: i32,
705 ) -> Result<Option<Vec<u8>>> {
706 let pool = self.pool.clone();
707 let address_for_query = address.to_string();
708 let result = self
709 .with_semaphore(move || -> Result<Option<Vec<u8>>> {
710 let mut conn = pool
711 .get()
712 .map_err(|e| StoreError::Connection(e.to_string()))?;
713 let res: Option<Vec<u8>> = sessions::table
714 .select(sessions::record)
715 .filter(sessions::address.eq(address_for_query.clone()))
716 .filter(sessions::device_id.eq(device_id))
717 .first(&mut conn)
718 .optional()
719 .map_err(|e| StoreError::Database(e.to_string()))?;
720
721 Ok(res)
722 })
723 .await?;
724
725 Ok(result)
726 }
727
728 pub async fn put_session_for_device(
729 &self,
730 address: &str,
731 session: &[u8],
732 device_id: i32,
733 ) -> Result<()> {
734 let pool = self.pool.clone();
735 let db_semaphore = self.db_semaphore.clone();
736 let address_owned = address.to_string();
737 let session_vec = session.to_vec();
738
739 const MAX_RETRIES: u32 = 5;
740
741 for attempt in 0..=MAX_RETRIES {
742 let permit =
743 db_semaphore.clone().acquire_owned().await.map_err(|e| {
744 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
745 })?;
746
747 let pool_clone = pool.clone();
748 let address_clone = address_owned.clone();
749 let session_clone = session_vec.clone();
750
751 let result =
752 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
753 let mut conn = pool_clone
754 .get()
755 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
756 diesel::insert_into(sessions::table)
757 .values((
758 sessions::address.eq(address_clone),
759 sessions::record.eq(&session_clone),
760 sessions::device_id.eq(device_id),
761 ))
762 .on_conflict((sessions::address, sessions::device_id))
763 .do_update()
764 .set(sessions::record.eq(&session_clone))
765 .execute(&mut conn)
766 .map_err(DieselOrStore::Diesel)?;
767 Ok(())
768 })
769 .await;
770
771 drop(permit);
772
773 match result {
774 Ok(Ok(())) => return Ok(()),
775 Ok(Err(DieselOrStore::Diesel(ref e)))
776 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
777 {
778 let delay_ms = 10 * 2u64.pow(attempt);
779 warn!(
780 "Session write failed (attempt {}/{}): {e}. Retrying in {delay_ms}ms...",
781 attempt + 1,
782 MAX_RETRIES + 1,
783 );
784 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
785 continue;
786 }
787 Ok(Err(e)) => return Err(e.into()),
788 Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
789 }
790 }
791
792 Err(StoreError::Database(format!(
793 "Session write failed after {} attempts",
794 MAX_RETRIES + 1
795 )))
796 }
797
798 pub async fn delete_session_for_device(&self, address: &str, device_id: i32) -> Result<()> {
799 let pool = self.pool.clone();
800 let address_owned = address.to_string();
801
802 tokio::task::spawn_blocking(move || -> Result<()> {
803 let mut conn = pool
804 .get()
805 .map_err(|e| StoreError::Connection(e.to_string()))?;
806 diesel::delete(
807 sessions::table
808 .filter(sessions::address.eq(address_owned))
809 .filter(sessions::device_id.eq(device_id)),
810 )
811 .execute(&mut conn)
812 .map_err(|e| StoreError::Database(e.to_string()))?;
813 Ok(())
814 })
815 .await
816 .map_err(|e| StoreError::Database(e.to_string()))??;
817
818 Ok(())
819 }
820
821 pub async fn put_sender_key_for_device(
822 &self,
823 address: &str,
824 record: &[u8],
825 device_id: i32,
826 ) -> Result<()> {
827 let pool = self.pool.clone();
828 let address = address.to_string();
829 let record_vec = record.to_vec();
830 tokio::task::spawn_blocking(move || -> Result<()> {
831 let mut conn = pool
832 .get()
833 .map_err(|e| StoreError::Connection(e.to_string()))?;
834 diesel::insert_into(sender_keys::table)
835 .values((
836 sender_keys::address.eq(address),
837 sender_keys::record.eq(&record_vec),
838 sender_keys::device_id.eq(device_id),
839 ))
840 .on_conflict((sender_keys::address, sender_keys::device_id))
841 .do_update()
842 .set(sender_keys::record.eq(&record_vec))
843 .execute(&mut conn)
844 .map_err(|e| StoreError::Database(e.to_string()))?;
845 Ok(())
846 })
847 .await
848 .map_err(|e| StoreError::Database(e.to_string()))??;
849 Ok(())
850 }
851
852 pub async fn get_sender_key_for_device(
853 &self,
854 address: &str,
855 device_id: i32,
856 ) -> Result<Option<Vec<u8>>> {
857 let pool = self.pool.clone();
858 let address = address.to_string();
859 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
860 let mut conn = pool
861 .get()
862 .map_err(|e| StoreError::Connection(e.to_string()))?;
863 let res: Option<Vec<u8>> = sender_keys::table
864 .select(sender_keys::record)
865 .filter(sender_keys::address.eq(address))
866 .filter(sender_keys::device_id.eq(device_id))
867 .first(&mut conn)
868 .optional()
869 .map_err(|e| StoreError::Database(e.to_string()))?;
870 Ok(res)
871 })
872 .await
873 .map_err(|e| StoreError::Database(e.to_string()))?
874 }
875
876 pub async fn delete_sender_key_for_device(&self, address: &str, device_id: i32) -> Result<()> {
877 let pool = self.pool.clone();
878 let address = address.to_string();
879 tokio::task::spawn_blocking(move || -> Result<()> {
880 let mut conn = pool
881 .get()
882 .map_err(|e| StoreError::Connection(e.to_string()))?;
883 diesel::delete(
884 sender_keys::table
885 .filter(sender_keys::address.eq(address))
886 .filter(sender_keys::device_id.eq(device_id)),
887 )
888 .execute(&mut conn)
889 .map_err(|e| StoreError::Database(e.to_string()))?;
890 Ok(())
891 })
892 .await
893 .map_err(|e| StoreError::Database(e.to_string()))??;
894 Ok(())
895 }
896
897 pub async fn get_app_state_sync_key_for_device(
898 &self,
899 key_id: &[u8],
900 device_id: i32,
901 ) -> Result<Option<AppStateSyncKey>> {
902 let pool = self.pool.clone();
903 let key_id = key_id.to_vec();
904 let res: Option<Vec<u8>> =
905 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
906 let mut conn = pool
907 .get()
908 .map_err(|e| StoreError::Connection(e.to_string()))?;
909 let res: Option<Vec<u8>> = app_state_keys::table
910 .select(app_state_keys::key_data)
911 .filter(app_state_keys::key_id.eq(&key_id))
912 .filter(app_state_keys::device_id.eq(device_id))
913 .first(&mut conn)
914 .optional()
915 .map_err(|e| StoreError::Database(e.to_string()))?;
916 Ok(res)
917 })
918 .await
919 .map_err(|e| StoreError::Database(e.to_string()))??;
920
921 if let Some(data) = res {
922 let (key, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
923 .map_err(|e| StoreError::Serialization(e.to_string()))?;
924 Ok(Some(key))
925 } else {
926 Ok(None)
927 }
928 }
929
930 pub async fn set_app_state_sync_key_for_device(
931 &self,
932 key_id: &[u8],
933 key: AppStateSyncKey,
934 device_id: i32,
935 ) -> Result<()> {
936 let pool = self.pool.clone();
937 let key_id = key_id.to_vec();
938 let data = bincode::serde::encode_to_vec(&key, bincode::config::standard())
939 .map_err(|e| StoreError::Serialization(e.to_string()))?;
940 tokio::task::spawn_blocking(move || -> Result<()> {
941 let mut conn = pool
942 .get()
943 .map_err(|e| StoreError::Connection(e.to_string()))?;
944 diesel::insert_into(app_state_keys::table)
945 .values((
946 app_state_keys::key_id.eq(&key_id),
947 app_state_keys::key_data.eq(&data),
948 app_state_keys::device_id.eq(device_id),
949 ))
950 .on_conflict((app_state_keys::key_id, app_state_keys::device_id))
951 .do_update()
952 .set(app_state_keys::key_data.eq(&data))
953 .execute(&mut conn)
954 .map_err(|e| StoreError::Database(e.to_string()))?;
955 Ok(())
956 })
957 .await
958 .map_err(|e| StoreError::Database(e.to_string()))??;
959 Ok(())
960 }
961
962 pub async fn get_latest_app_state_sync_key_id_for_device(
963 &self,
964 device_id: i32,
965 ) -> Result<Option<Vec<u8>>> {
966 let pool = self.pool.clone();
967 let res: Option<Vec<u8>> =
968 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
969 let mut conn = pool
970 .get()
971 .map_err(|e| StoreError::Connection(e.to_string()))?;
972 let res: Option<Vec<u8>> = app_state_keys::table
973 .select(app_state_keys::key_id)
974 .filter(app_state_keys::device_id.eq(device_id))
975 .order(app_state_keys::key_id.desc())
976 .first(&mut conn)
977 .optional()
978 .map_err(|e| StoreError::Database(e.to_string()))?;
979 Ok(res)
980 })
981 .await
982 .map_err(|e| StoreError::Database(e.to_string()))??;
983 Ok(res)
984 }
985
986 pub async fn get_app_state_version_for_device(
987 &self,
988 name: &str,
989 device_id: i32,
990 ) -> Result<HashState> {
991 let pool = self.pool.clone();
992 let name = name.to_string();
993 let res: Option<Vec<u8>> =
994 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
995 let mut conn = pool
996 .get()
997 .map_err(|e| StoreError::Connection(e.to_string()))?;
998 let res: Option<Vec<u8>> = app_state_versions::table
999 .select(app_state_versions::state_data)
1000 .filter(app_state_versions::name.eq(name))
1001 .filter(app_state_versions::device_id.eq(device_id))
1002 .first(&mut conn)
1003 .optional()
1004 .map_err(|e| StoreError::Database(e.to_string()))?;
1005 Ok(res)
1006 })
1007 .await
1008 .map_err(|e| StoreError::Database(e.to_string()))??;
1009
1010 if let Some(data) = res {
1011 let (state, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
1012 .map_err(|e| StoreError::Serialization(e.to_string()))?;
1013 Ok(state)
1014 } else {
1015 Ok(HashState::default())
1016 }
1017 }
1018
1019 pub async fn set_app_state_version_for_device(
1020 &self,
1021 name: &str,
1022 state: HashState,
1023 device_id: i32,
1024 ) -> Result<()> {
1025 let name = name.to_string();
1026 let data = bincode::serde::encode_to_vec(&state, bincode::config::standard())
1027 .map_err(|e| StoreError::Serialization(e.to_string()))?;
1028 self.with_retry("set_app_state_version", || {
1029 let name = name.clone();
1030 let data = data.clone();
1031 Box::new(move |conn: &mut SqliteConnection| {
1032 diesel::insert_into(app_state_versions::table)
1033 .values((
1034 app_state_versions::name.eq(&name),
1035 app_state_versions::state_data.eq(&data),
1036 app_state_versions::device_id.eq(device_id),
1037 ))
1038 .on_conflict((app_state_versions::name, app_state_versions::device_id))
1039 .do_update()
1040 .set(app_state_versions::state_data.eq(&data))
1041 .execute(conn)?;
1042 Ok(())
1043 })
1044 })
1045 .await
1046 }
1047
1048 pub async fn put_app_state_mutation_macs_for_device(
1049 &self,
1050 name: &str,
1051 version: u64,
1052 mutations: &[AppStateMutationMAC],
1053 device_id: i32,
1054 ) -> Result<()> {
1055 if mutations.is_empty() {
1056 return Ok(());
1057 }
1058 let name = name.to_string();
1059 let mutations: Vec<AppStateMutationMAC> = mutations.to_vec();
1060 self.with_retry("put_app_state_mutation_macs", || {
1061 let name = name.clone();
1062 let mutations = mutations.clone();
1063 Box::new(move |conn: &mut SqliteConnection| {
1064 let records: Vec<_> = mutations
1065 .iter()
1066 .map(|m| {
1067 (
1068 app_state_mutation_macs::name.eq(&name),
1069 app_state_mutation_macs::version.eq(version as i64),
1070 app_state_mutation_macs::index_mac.eq(&m.index_mac),
1071 app_state_mutation_macs::value_mac.eq(&m.value_mac),
1072 app_state_mutation_macs::device_id.eq(device_id),
1073 )
1074 })
1075 .collect();
1076
1077 const CHUNK_SIZE: usize = 100;
1080
1081 for chunk in records.chunks(CHUNK_SIZE) {
1082 diesel::insert_into(app_state_mutation_macs::table)
1083 .values(chunk)
1084 .on_conflict((
1085 app_state_mutation_macs::name,
1086 app_state_mutation_macs::index_mac,
1087 app_state_mutation_macs::device_id,
1088 ))
1089 .do_update()
1090 .set((
1091 app_state_mutation_macs::version
1092 .eq(excluded(app_state_mutation_macs::version)),
1093 app_state_mutation_macs::value_mac
1094 .eq(excluded(app_state_mutation_macs::value_mac)),
1095 ))
1096 .execute(conn)?;
1097 }
1098 Ok(())
1099 })
1100 })
1101 .await
1102 }
1103
1104 pub async fn delete_app_state_mutation_macs_for_device(
1105 &self,
1106 name: &str,
1107 index_macs: &[Vec<u8>],
1108 device_id: i32,
1109 ) -> Result<()> {
1110 if index_macs.is_empty() {
1111 return Ok(());
1112 }
1113 let name = name.to_string();
1114 let index_macs: Vec<Vec<u8>> = index_macs.to_vec();
1115 self.with_retry("delete_app_state_mutation_macs", || {
1116 let name = name.clone();
1117 let index_macs = index_macs.clone();
1118 Box::new(move |conn: &mut SqliteConnection| {
1119 const CHUNK_SIZE: usize = 500;
1122
1123 for chunk in index_macs.chunks(CHUNK_SIZE) {
1124 diesel::delete(
1125 app_state_mutation_macs::table.filter(
1126 app_state_mutation_macs::name
1127 .eq(&name)
1128 .and(app_state_mutation_macs::index_mac.eq_any(chunk))
1129 .and(app_state_mutation_macs::device_id.eq(device_id)),
1130 ),
1131 )
1132 .execute(conn)?;
1133 }
1134 Ok(())
1135 })
1136 })
1137 .await
1138 }
1139
1140 pub async fn get_app_state_mutation_mac_for_device(
1141 &self,
1142 name: &str,
1143 index_mac: &[u8],
1144 device_id: i32,
1145 ) -> Result<Option<Vec<u8>>> {
1146 let pool = self.pool.clone();
1147 let name = name.to_string();
1148 let index_mac = index_mac.to_vec();
1149 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1150 let mut conn = pool
1151 .get()
1152 .map_err(|e| StoreError::Connection(e.to_string()))?;
1153 let res: Option<Vec<u8>> = app_state_mutation_macs::table
1154 .select(app_state_mutation_macs::value_mac)
1155 .filter(app_state_mutation_macs::name.eq(&name))
1156 .filter(app_state_mutation_macs::index_mac.eq(&index_mac))
1157 .filter(app_state_mutation_macs::device_id.eq(device_id))
1158 .first(&mut conn)
1159 .optional()
1160 .map_err(|e| StoreError::Database(e.to_string()))?;
1161 Ok(res)
1162 })
1163 .await
1164 .map_err(|e| StoreError::Database(e.to_string()))?
1165 }
1166}
1167
1168#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1169#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1170impl SignalStore for SqliteStore {
1171 async fn put_identity(&self, address: &str, key: [u8; 32]) -> Result<()> {
1172 self.put_identity_for_device(address, key, self.device_id)
1173 .await
1174 }
1175
1176 async fn load_identity(&self, address: &str) -> Result<Option<Vec<u8>>> {
1177 self.load_identity_for_device(address, self.device_id).await
1178 }
1179
1180 async fn delete_identity(&self, address: &str) -> Result<()> {
1181 self.delete_identity_for_device(address, self.device_id)
1182 .await
1183 }
1184
1185 async fn get_session(&self, address: &str) -> Result<Option<Vec<u8>>> {
1186 self.get_session_for_device(address, self.device_id).await
1187 }
1188
1189 async fn put_session(&self, address: &str, session: &[u8]) -> Result<()> {
1190 self.put_session_for_device(address, session, self.device_id)
1191 .await
1192 }
1193
1194 async fn delete_session(&self, address: &str) -> Result<()> {
1195 self.delete_session_for_device(address, self.device_id)
1196 .await
1197 }
1198
1199 async fn store_prekey(&self, id: u32, record: &[u8], uploaded: bool) -> Result<()> {
1200 let pool = self.pool.clone();
1201 let db_semaphore = self.db_semaphore.clone();
1202 let device_id = self.device_id;
1203 let record = record.to_vec();
1204
1205 const MAX_RETRIES: u32 = 5;
1206
1207 for attempt in 0..=MAX_RETRIES {
1208 let permit =
1209 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1210 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1211 })?;
1212
1213 let pool_clone = pool.clone();
1214 let record_clone = record.clone();
1215
1216 let result =
1217 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1218 let mut conn = pool_clone
1219 .get()
1220 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1221 diesel::insert_into(prekeys::table)
1222 .values((
1223 prekeys::id.eq(id as i32),
1224 prekeys::key.eq(&record_clone),
1225 prekeys::uploaded.eq(uploaded),
1226 prekeys::device_id.eq(device_id),
1227 ))
1228 .on_conflict((prekeys::id, prekeys::device_id))
1229 .do_update()
1230 .set((
1231 prekeys::key.eq(&record_clone),
1232 prekeys::uploaded.eq(uploaded),
1233 ))
1234 .execute(&mut conn)
1235 .map_err(DieselOrStore::Diesel)?;
1236 Ok(())
1237 })
1238 .await;
1239
1240 drop(permit);
1241
1242 match result {
1243 Ok(Ok(())) => return Ok(()),
1244 Ok(Err(DieselOrStore::Diesel(ref e)))
1245 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1246 {
1247 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1248 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1249 }
1250 Ok(Err(e)) => return Err(e.into()),
1251 Err(e) => return Err(StoreError::Database(e.to_string())),
1252 }
1253 }
1254
1255 Err(StoreError::Database(
1256 "store_prekey exhausted retries".to_string(),
1257 ))
1258 }
1259
1260 async fn store_prekeys_batch(&self, keys: &[(u32, Vec<u8>)], uploaded: bool) -> Result<()> {
1261 if keys.is_empty() {
1262 return Ok(());
1263 }
1264
1265 let pool = self.pool.clone();
1266 let db_semaphore = self.db_semaphore.clone();
1267 let device_id = self.device_id;
1268 let keys = keys.to_vec();
1269
1270 const MAX_RETRIES: u32 = 5;
1271
1272 for attempt in 0..=MAX_RETRIES {
1273 let permit =
1274 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1275 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1276 })?;
1277
1278 let pool_clone = pool.clone();
1279 let keys_clone = keys.clone();
1280
1281 let result =
1282 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1283 let mut conn = pool_clone
1284 .get()
1285 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1286
1287 conn.transaction(|conn| {
1288 for (id, record) in &keys_clone {
1289 diesel::insert_into(prekeys::table)
1290 .values((
1291 prekeys::id.eq(*id as i32),
1292 prekeys::key.eq(record),
1293 prekeys::uploaded.eq(uploaded),
1294 prekeys::device_id.eq(device_id),
1295 ))
1296 .on_conflict((prekeys::id, prekeys::device_id))
1297 .do_update()
1298 .set((prekeys::key.eq(record), prekeys::uploaded.eq(uploaded)))
1299 .execute(conn)?;
1300 }
1301 Ok::<(), diesel::result::Error>(())
1302 })
1303 .map_err(DieselOrStore::Diesel)
1304 })
1305 .await;
1306
1307 drop(permit);
1308
1309 match result {
1310 Ok(Ok(())) => return Ok(()),
1311 Ok(Err(DieselOrStore::Diesel(ref e)))
1312 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1313 {
1314 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1315 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1316 }
1317 Ok(Err(e)) => return Err(e.into()),
1318 Err(e) => return Err(StoreError::Database(e.to_string())),
1319 }
1320 }
1321
1322 Err(StoreError::Database(
1323 "store_prekeys_batch exhausted retries".to_string(),
1324 ))
1325 }
1326
1327 async fn load_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1328 let pool = self.pool.clone();
1329 let device_id = self.device_id;
1330 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1331 let mut conn = pool
1332 .get()
1333 .map_err(|e| StoreError::Connection(e.to_string()))?;
1334 let res: Option<Vec<u8>> = prekeys::table
1335 .select(prekeys::key)
1336 .filter(prekeys::id.eq(id as i32))
1337 .filter(prekeys::device_id.eq(device_id))
1338 .first(&mut conn)
1339 .optional()
1340 .map_err(|e| StoreError::Database(e.to_string()))?;
1341 Ok(res)
1342 })
1343 .await
1344 .map_err(|e| StoreError::Database(e.to_string()))?
1345 }
1346
1347 async fn remove_prekey(&self, id: u32) -> Result<()> {
1348 let pool = self.pool.clone();
1349 let db_semaphore = self.db_semaphore.clone();
1350 let device_id = self.device_id;
1351
1352 const MAX_RETRIES: u32 = 5;
1353
1354 for attempt in 0..=MAX_RETRIES {
1355 let permit =
1356 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1357 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1358 })?;
1359
1360 let pool_clone = pool.clone();
1361
1362 let result =
1363 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1364 let mut conn = pool_clone
1365 .get()
1366 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1367 diesel::delete(
1368 prekeys::table
1369 .filter(prekeys::id.eq(id as i32))
1370 .filter(prekeys::device_id.eq(device_id)),
1371 )
1372 .execute(&mut conn)
1373 .map_err(DieselOrStore::Diesel)?;
1374 Ok(())
1375 })
1376 .await;
1377
1378 drop(permit);
1379
1380 match result {
1381 Ok(Ok(())) => return Ok(()),
1382 Ok(Err(DieselOrStore::Diesel(ref e)))
1383 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1384 {
1385 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1386 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1387 }
1388 Ok(Err(e)) => return Err(e.into()),
1389 Err(e) => return Err(StoreError::Database(e.to_string())),
1390 }
1391 }
1392
1393 Err(StoreError::Database(
1394 "remove_prekey exhausted retries".to_string(),
1395 ))
1396 }
1397
1398 async fn get_max_prekey_id(&self) -> Result<u32> {
1399 let pool = self.pool.clone();
1400 let device_id = self.device_id;
1401 let db_semaphore = self.db_semaphore.clone();
1402 let _permit = db_semaphore
1403 .acquire()
1404 .await
1405 .map_err(|e| StoreError::Database(format!("Failed to acquire semaphore: {}", e)))?;
1406
1407 tokio::task::spawn_blocking(move || -> Result<u32> {
1408 let mut conn = pool
1409 .get()
1410 .map_err(|e| StoreError::Connection(e.to_string()))?;
1411 use diesel::dsl::max;
1412 let result: Option<i32> = prekeys::table
1413 .filter(prekeys::device_id.eq(device_id))
1414 .select(max(prekeys::id))
1415 .first(&mut conn)
1416 .map_err(|e| StoreError::Database(e.to_string()))?;
1417 Ok(result.unwrap_or(0) as u32)
1418 })
1419 .await
1420 .map_err(|e| StoreError::Database(e.to_string()))?
1421 }
1422
1423 async fn store_signed_prekey(&self, id: u32, record: &[u8]) -> Result<()> {
1424 let pool = self.pool.clone();
1425 let db_semaphore = self.db_semaphore.clone();
1426 let device_id = self.device_id;
1427 let record = record.to_vec();
1428
1429 const MAX_RETRIES: u32 = 5;
1430
1431 for attempt in 0..=MAX_RETRIES {
1432 let permit =
1433 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1434 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1435 })?;
1436
1437 let pool_clone = pool.clone();
1438 let record_clone = record.clone();
1439
1440 let result =
1441 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1442 let mut conn = pool_clone
1443 .get()
1444 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1445 diesel::insert_into(signed_prekeys::table)
1446 .values((
1447 signed_prekeys::id.eq(id as i32),
1448 signed_prekeys::record.eq(&record_clone),
1449 signed_prekeys::device_id.eq(device_id),
1450 ))
1451 .on_conflict((signed_prekeys::id, signed_prekeys::device_id))
1452 .do_update()
1453 .set(signed_prekeys::record.eq(&record_clone))
1454 .execute(&mut conn)
1455 .map_err(DieselOrStore::Diesel)?;
1456 Ok(())
1457 })
1458 .await;
1459
1460 drop(permit);
1461
1462 match result {
1463 Ok(Ok(())) => return Ok(()),
1464 Ok(Err(DieselOrStore::Diesel(ref e)))
1465 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1466 {
1467 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1468 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1469 }
1470 Ok(Err(e)) => return Err(e.into()),
1471 Err(e) => return Err(StoreError::Database(e.to_string())),
1472 }
1473 }
1474
1475 Err(StoreError::Database(
1476 "store_signed_prekey exhausted retries".to_string(),
1477 ))
1478 }
1479
1480 async fn load_signed_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1481 let pool = self.pool.clone();
1482 let device_id = self.device_id;
1483 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1484 let mut conn = pool
1485 .get()
1486 .map_err(|e| StoreError::Connection(e.to_string()))?;
1487 let res: Option<Vec<u8>> = signed_prekeys::table
1488 .select(signed_prekeys::record)
1489 .filter(signed_prekeys::id.eq(id as i32))
1490 .filter(signed_prekeys::device_id.eq(device_id))
1491 .first(&mut conn)
1492 .optional()
1493 .map_err(|e| StoreError::Database(e.to_string()))?;
1494 Ok(res)
1495 })
1496 .await
1497 .map_err(|e| StoreError::Database(e.to_string()))?
1498 }
1499
1500 async fn load_all_signed_prekeys(&self) -> Result<Vec<(u32, Vec<u8>)>> {
1501 let pool = self.pool.clone();
1502 let device_id = self.device_id;
1503 tokio::task::spawn_blocking(move || -> Result<Vec<(u32, Vec<u8>)>> {
1504 let mut conn = pool
1505 .get()
1506 .map_err(|e| StoreError::Connection(e.to_string()))?;
1507 let results: Vec<(i32, Vec<u8>)> = signed_prekeys::table
1508 .select((signed_prekeys::id, signed_prekeys::record))
1509 .filter(signed_prekeys::device_id.eq(device_id))
1510 .load(&mut conn)
1511 .map_err(|e| StoreError::Database(e.to_string()))?;
1512 Ok(results
1513 .into_iter()
1514 .map(|(id, record)| (id as u32, record))
1515 .collect())
1516 })
1517 .await
1518 .map_err(|e| StoreError::Database(e.to_string()))?
1519 }
1520
1521 async fn remove_signed_prekey(&self, id: u32) -> Result<()> {
1522 let pool = self.pool.clone();
1523 let db_semaphore = self.db_semaphore.clone();
1524 let device_id = self.device_id;
1525
1526 const MAX_RETRIES: u32 = 5;
1527
1528 for attempt in 0..=MAX_RETRIES {
1529 let permit =
1530 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1531 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1532 })?;
1533
1534 let pool_clone = pool.clone();
1535
1536 let result =
1537 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1538 let mut conn = pool_clone
1539 .get()
1540 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1541 diesel::delete(
1542 signed_prekeys::table
1543 .filter(signed_prekeys::id.eq(id as i32))
1544 .filter(signed_prekeys::device_id.eq(device_id)),
1545 )
1546 .execute(&mut conn)
1547 .map_err(DieselOrStore::Diesel)?;
1548 Ok(())
1549 })
1550 .await;
1551
1552 drop(permit);
1553
1554 match result {
1555 Ok(Ok(())) => return Ok(()),
1556 Ok(Err(DieselOrStore::Diesel(ref e)))
1557 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1558 {
1559 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1560 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1561 }
1562 Ok(Err(e)) => return Err(e.into()),
1563 Err(e) => return Err(StoreError::Database(e.to_string())),
1564 }
1565 }
1566
1567 Err(StoreError::Database(
1568 "remove_signed_prekey exhausted retries".to_string(),
1569 ))
1570 }
1571
1572 async fn put_sender_key(&self, address: &str, record: &[u8]) -> Result<()> {
1573 self.put_sender_key_for_device(address, record, self.device_id)
1574 .await
1575 }
1576
1577 async fn get_sender_key(&self, address: &str) -> Result<Option<Vec<u8>>> {
1578 self.get_sender_key_for_device(address, self.device_id)
1579 .await
1580 }
1581
1582 async fn delete_sender_key(&self, address: &str) -> Result<()> {
1583 self.delete_sender_key_for_device(address, self.device_id)
1584 .await
1585 }
1586}
1587
1588#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1589#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1590impl AppSyncStore for SqliteStore {
1591 async fn get_sync_key(&self, key_id: &[u8]) -> Result<Option<AppStateSyncKey>> {
1592 self.get_app_state_sync_key_for_device(key_id, self.device_id)
1593 .await
1594 }
1595
1596 async fn set_sync_key(&self, key_id: &[u8], key: AppStateSyncKey) -> Result<()> {
1597 self.set_app_state_sync_key_for_device(key_id, key, self.device_id)
1598 .await
1599 }
1600
1601 async fn get_version(&self, name: &str) -> Result<HashState> {
1602 self.get_app_state_version_for_device(name, self.device_id)
1603 .await
1604 }
1605
1606 async fn set_version(&self, name: &str, state: HashState) -> Result<()> {
1607 self.set_app_state_version_for_device(name, state, self.device_id)
1608 .await
1609 }
1610
1611 async fn put_mutation_macs(
1612 &self,
1613 name: &str,
1614 version: u64,
1615 mutations: &[AppStateMutationMAC],
1616 ) -> Result<()> {
1617 self.put_app_state_mutation_macs_for_device(name, version, mutations, self.device_id)
1618 .await
1619 }
1620
1621 async fn get_mutation_mac(&self, name: &str, index_mac: &[u8]) -> Result<Option<Vec<u8>>> {
1622 self.get_app_state_mutation_mac_for_device(name, index_mac, self.device_id)
1623 .await
1624 }
1625
1626 async fn delete_mutation_macs(&self, name: &str, index_macs: &[Vec<u8>]) -> Result<()> {
1627 self.delete_app_state_mutation_macs_for_device(name, index_macs, self.device_id)
1628 .await
1629 }
1630
1631 async fn get_latest_sync_key_id(&self) -> Result<Option<Vec<u8>>> {
1632 self.get_latest_app_state_sync_key_id_for_device(self.device_id)
1633 .await
1634 }
1635}
1636
1637#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1638#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1639impl ProtocolStore for SqliteStore {
1640 async fn get_skdm_recipients(&self, group_jid: &str) -> Result<Vec<Jid>> {
1641 let pool = self.pool.clone();
1642 let device_id = self.device_id;
1643 let group_jid = group_jid.to_string();
1644 tokio::task::spawn_blocking(move || -> Result<Vec<Jid>> {
1645 let mut conn = pool
1646 .get()
1647 .map_err(|e| StoreError::Connection(e.to_string()))?;
1648 let recipients: Vec<String> = skdm_recipients::table
1649 .select(skdm_recipients::device_jid)
1650 .filter(skdm_recipients::group_jid.eq(&group_jid))
1651 .filter(skdm_recipients::device_id.eq(device_id))
1652 .load(&mut conn)
1653 .map_err(|e| StoreError::Database(e.to_string()))?;
1654 let jids: Vec<Jid> = recipients
1655 .iter()
1656 .filter_map(|s| match s.parse::<Jid>() {
1657 Ok(jid) => Some(jid),
1658 Err(e) => {
1659 warn!("Failed to parse SKDM recipient '{}': {}", s, e);
1660 None
1661 }
1662 })
1663 .collect();
1664 Ok(jids)
1665 })
1666 .await
1667 .map_err(|e| StoreError::Database(e.to_string()))?
1668 }
1669
1670 async fn add_skdm_recipients(&self, group_jid: &str, device_jids: &[Jid]) -> Result<()> {
1671 if device_jids.is_empty() {
1672 return Ok(());
1673 }
1674 let pool = self.pool.clone();
1675 let device_id = self.device_id;
1676 let group_jid = group_jid.to_string();
1677 let device_jid_strs: Vec<String> = device_jids.iter().map(|j| j.to_string()).collect();
1678 let now = std::time::SystemTime::now()
1679 .duration_since(std::time::UNIX_EPOCH)
1680 .unwrap_or_default()
1681 .as_secs() as i32;
1682 tokio::task::spawn_blocking(move || -> Result<()> {
1683 let mut conn = pool
1684 .get()
1685 .map_err(|e| StoreError::Connection(e.to_string()))?;
1686
1687 let values: Vec<_> = device_jid_strs
1688 .iter()
1689 .map(|device_jid| {
1690 (
1691 skdm_recipients::group_jid.eq(&group_jid),
1692 skdm_recipients::device_jid.eq(device_jid),
1693 skdm_recipients::device_id.eq(device_id),
1694 skdm_recipients::created_at.eq(now),
1695 )
1696 })
1697 .collect();
1698
1699 const CHUNK_SIZE: usize = 200; for chunk in values.chunks(CHUNK_SIZE) {
1702 diesel::insert_into(skdm_recipients::table)
1703 .values(chunk)
1704 .on_conflict((
1705 skdm_recipients::group_jid,
1706 skdm_recipients::device_jid,
1707 skdm_recipients::device_id,
1708 ))
1709 .do_nothing()
1710 .execute(&mut conn)
1711 .map_err(|e| StoreError::Database(e.to_string()))?;
1712 }
1713 Ok(())
1714 })
1715 .await
1716 .map_err(|e| StoreError::Database(e.to_string()))??;
1717 Ok(())
1718 }
1719
1720 async fn clear_skdm_recipients(&self, group_jid: &str) -> Result<()> {
1721 let pool = self.pool.clone();
1722 let device_id = self.device_id;
1723 let group_jid = group_jid.to_string();
1724 tokio::task::spawn_blocking(move || -> Result<()> {
1725 let mut conn = pool
1726 .get()
1727 .map_err(|e| StoreError::Connection(e.to_string()))?;
1728 diesel::delete(
1729 skdm_recipients::table
1730 .filter(skdm_recipients::group_jid.eq(&group_jid))
1731 .filter(skdm_recipients::device_id.eq(device_id)),
1732 )
1733 .execute(&mut conn)
1734 .map_err(|e| StoreError::Database(e.to_string()))?;
1735 Ok(())
1736 })
1737 .await
1738 .map_err(|e| StoreError::Database(e.to_string()))??;
1739 Ok(())
1740 }
1741
1742 async fn get_lid_mapping(&self, lid: &str) -> Result<Option<LidPnMappingEntry>> {
1743 let pool = self.pool.clone();
1744 let device_id = self.device_id;
1745 let lid = lid.to_string();
1746 tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1747 let mut conn = pool
1748 .get()
1749 .map_err(|e| StoreError::Connection(e.to_string()))?;
1750 let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1751 .select((
1752 lid_pn_mapping::lid,
1753 lid_pn_mapping::phone_number,
1754 lid_pn_mapping::created_at,
1755 lid_pn_mapping::learning_source,
1756 lid_pn_mapping::updated_at,
1757 ))
1758 .filter(lid_pn_mapping::lid.eq(&lid))
1759 .filter(lid_pn_mapping::device_id.eq(device_id))
1760 .first(&mut conn)
1761 .optional()
1762 .map_err(|e| StoreError::Database(e.to_string()))?;
1763 Ok(row.map(
1764 |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1765 lid,
1766 phone_number,
1767 created_at,
1768 updated_at,
1769 learning_source,
1770 },
1771 ))
1772 })
1773 .await
1774 .map_err(|e| StoreError::Database(e.to_string()))?
1775 }
1776
1777 async fn get_pn_mapping(&self, phone: &str) -> Result<Option<LidPnMappingEntry>> {
1778 let pool = self.pool.clone();
1779 let device_id = self.device_id;
1780 let phone = phone.to_string();
1781 tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1782 let mut conn = pool
1783 .get()
1784 .map_err(|e| StoreError::Connection(e.to_string()))?;
1785 let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1786 .select((
1787 lid_pn_mapping::lid,
1788 lid_pn_mapping::phone_number,
1789 lid_pn_mapping::created_at,
1790 lid_pn_mapping::learning_source,
1791 lid_pn_mapping::updated_at,
1792 ))
1793 .filter(lid_pn_mapping::phone_number.eq(&phone))
1794 .filter(lid_pn_mapping::device_id.eq(device_id))
1795 .order(lid_pn_mapping::updated_at.desc())
1796 .first(&mut conn)
1797 .optional()
1798 .map_err(|e| StoreError::Database(e.to_string()))?;
1799 Ok(row.map(
1800 |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1801 lid,
1802 phone_number,
1803 created_at,
1804 updated_at,
1805 learning_source,
1806 },
1807 ))
1808 })
1809 .await
1810 .map_err(|e| StoreError::Database(e.to_string()))?
1811 }
1812
1813 async fn put_lid_mapping(&self, entry: &LidPnMappingEntry) -> Result<()> {
1814 let pool = self.pool.clone();
1815 let device_id = self.device_id;
1816 let entry = entry.clone();
1817 tokio::task::spawn_blocking(move || -> Result<()> {
1818 let mut conn = pool
1819 .get()
1820 .map_err(|e| StoreError::Connection(e.to_string()))?;
1821 diesel::insert_into(lid_pn_mapping::table)
1822 .values((
1823 lid_pn_mapping::lid.eq(&entry.lid),
1824 lid_pn_mapping::phone_number.eq(&entry.phone_number),
1825 lid_pn_mapping::created_at.eq(entry.created_at),
1826 lid_pn_mapping::learning_source.eq(&entry.learning_source),
1827 lid_pn_mapping::updated_at.eq(entry.updated_at),
1828 lid_pn_mapping::device_id.eq(device_id),
1829 ))
1830 .on_conflict((lid_pn_mapping::lid, lid_pn_mapping::device_id))
1831 .do_update()
1832 .set((
1833 lid_pn_mapping::phone_number.eq(&entry.phone_number),
1834 lid_pn_mapping::learning_source.eq(&entry.learning_source),
1835 lid_pn_mapping::updated_at.eq(entry.updated_at),
1836 ))
1837 .execute(&mut conn)
1838 .map_err(|e| StoreError::Database(e.to_string()))?;
1839 Ok(())
1840 })
1841 .await
1842 .map_err(|e| StoreError::Database(e.to_string()))??;
1843 Ok(())
1844 }
1845
1846 async fn get_all_lid_mappings(&self) -> Result<Vec<LidPnMappingEntry>> {
1847 let pool = self.pool.clone();
1848 let device_id = self.device_id;
1849 tokio::task::spawn_blocking(move || -> Result<Vec<LidPnMappingEntry>> {
1850 let mut conn = pool
1851 .get()
1852 .map_err(|e| StoreError::Connection(e.to_string()))?;
1853 let rows: Vec<(String, String, i64, String, i64)> = lid_pn_mapping::table
1854 .select((
1855 lid_pn_mapping::lid,
1856 lid_pn_mapping::phone_number,
1857 lid_pn_mapping::created_at,
1858 lid_pn_mapping::learning_source,
1859 lid_pn_mapping::updated_at,
1860 ))
1861 .filter(lid_pn_mapping::device_id.eq(device_id))
1862 .load(&mut conn)
1863 .map_err(|e| StoreError::Database(e.to_string()))?;
1864 Ok(rows
1865 .into_iter()
1866 .map(
1867 |(lid, phone_number, created_at, learning_source, updated_at)| {
1868 LidPnMappingEntry {
1869 lid,
1870 phone_number,
1871 created_at,
1872 updated_at,
1873 learning_source,
1874 }
1875 },
1876 )
1877 .collect())
1878 })
1879 .await
1880 .map_err(|e| StoreError::Database(e.to_string()))?
1881 }
1882
1883 async fn save_base_key(&self, address: &str, message_id: &str, base_key: &[u8]) -> Result<()> {
1884 let pool = self.pool.clone();
1885 let device_id = self.device_id;
1886 let address = address.to_string();
1887 let message_id = message_id.to_string();
1888 let base_key = base_key.to_vec();
1889 let now = std::time::SystemTime::now()
1890 .duration_since(std::time::UNIX_EPOCH)
1891 .unwrap_or_default()
1892 .as_secs() as i32;
1893 tokio::task::spawn_blocking(move || -> Result<()> {
1894 let mut conn = pool
1895 .get()
1896 .map_err(|e| StoreError::Connection(e.to_string()))?;
1897 diesel::insert_into(base_keys::table)
1898 .values((
1899 base_keys::address.eq(&address),
1900 base_keys::message_id.eq(&message_id),
1901 base_keys::base_key.eq(&base_key),
1902 base_keys::device_id.eq(device_id),
1903 base_keys::created_at.eq(now),
1904 ))
1905 .on_conflict((
1906 base_keys::address,
1907 base_keys::message_id,
1908 base_keys::device_id,
1909 ))
1910 .do_update()
1911 .set(base_keys::base_key.eq(&base_key))
1912 .execute(&mut conn)
1913 .map_err(|e| StoreError::Database(e.to_string()))?;
1914 Ok(())
1915 })
1916 .await
1917 .map_err(|e| StoreError::Database(e.to_string()))??;
1918 Ok(())
1919 }
1920
1921 async fn has_same_base_key(
1922 &self,
1923 address: &str,
1924 message_id: &str,
1925 current_base_key: &[u8],
1926 ) -> Result<bool> {
1927 let pool = self.pool.clone();
1928 let device_id = self.device_id;
1929 let address = address.to_string();
1930 let message_id = message_id.to_string();
1931 let current_base_key = current_base_key.to_vec();
1932 tokio::task::spawn_blocking(move || -> Result<bool> {
1933 let mut conn = pool
1934 .get()
1935 .map_err(|e| StoreError::Connection(e.to_string()))?;
1936 let stored_key: Option<Vec<u8>> = base_keys::table
1937 .select(base_keys::base_key)
1938 .filter(base_keys::address.eq(&address))
1939 .filter(base_keys::message_id.eq(&message_id))
1940 .filter(base_keys::device_id.eq(device_id))
1941 .first(&mut conn)
1942 .optional()
1943 .map_err(|e| StoreError::Database(e.to_string()))?;
1944 Ok(stored_key.as_ref() == Some(¤t_base_key))
1945 })
1946 .await
1947 .map_err(|e| StoreError::Database(e.to_string()))?
1948 }
1949
1950 async fn delete_base_key(&self, address: &str, message_id: &str) -> Result<()> {
1951 let pool = self.pool.clone();
1952 let device_id = self.device_id;
1953 let address = address.to_string();
1954 let message_id = message_id.to_string();
1955 tokio::task::spawn_blocking(move || -> Result<()> {
1956 let mut conn = pool
1957 .get()
1958 .map_err(|e| StoreError::Connection(e.to_string()))?;
1959 diesel::delete(
1960 base_keys::table
1961 .filter(base_keys::address.eq(&address))
1962 .filter(base_keys::message_id.eq(&message_id))
1963 .filter(base_keys::device_id.eq(device_id)),
1964 )
1965 .execute(&mut conn)
1966 .map_err(|e| StoreError::Database(e.to_string()))?;
1967 Ok(())
1968 })
1969 .await
1970 .map_err(|e| StoreError::Database(e.to_string()))??;
1971 Ok(())
1972 }
1973
1974 async fn update_device_list(&self, record: DeviceListRecord) -> Result<()> {
1975 let pool = self.pool.clone();
1976 let device_id = self.device_id;
1977 let devices_json = serde_json::to_string(&record.devices)
1978 .map_err(|e| StoreError::Serialization(e.to_string()))?;
1979 let now = std::time::SystemTime::now()
1980 .duration_since(std::time::UNIX_EPOCH)
1981 .unwrap_or_default()
1982 .as_secs() as i32;
1983 tokio::task::spawn_blocking(move || -> Result<()> {
1984 let mut conn = pool
1985 .get()
1986 .map_err(|e| StoreError::Connection(e.to_string()))?;
1987 diesel::insert_into(device_registry::table)
1988 .values((
1989 device_registry::user_id.eq(&record.user),
1990 device_registry::devices_json.eq(&devices_json),
1991 device_registry::timestamp.eq(record.timestamp as i32),
1992 device_registry::phash.eq(&record.phash),
1993 device_registry::device_id.eq(device_id),
1994 device_registry::updated_at.eq(now),
1995 ))
1996 .on_conflict((device_registry::user_id, device_registry::device_id))
1997 .do_update()
1998 .set((
1999 device_registry::devices_json.eq(&devices_json),
2000 device_registry::timestamp.eq(record.timestamp as i32),
2001 device_registry::phash.eq(&record.phash),
2002 device_registry::updated_at.eq(now),
2003 ))
2004 .execute(&mut conn)
2005 .map_err(|e| StoreError::Database(e.to_string()))?;
2006 Ok(())
2007 })
2008 .await
2009 .map_err(|e| StoreError::Database(e.to_string()))??;
2010 Ok(())
2011 }
2012
2013 async fn get_devices(&self, user: &str) -> Result<Option<DeviceListRecord>> {
2014 let pool = self.pool.clone();
2015 let device_id = self.device_id;
2016 let user = user.to_string();
2017 tokio::task::spawn_blocking(move || -> Result<Option<DeviceListRecord>> {
2018 let mut conn = pool
2019 .get()
2020 .map_err(|e| StoreError::Connection(e.to_string()))?;
2021 let row: Option<(String, String, i32, Option<String>)> = device_registry::table
2022 .select((
2023 device_registry::user_id,
2024 device_registry::devices_json,
2025 device_registry::timestamp,
2026 device_registry::phash,
2027 ))
2028 .filter(device_registry::user_id.eq(&user))
2029 .filter(device_registry::device_id.eq(device_id))
2030 .first(&mut conn)
2031 .optional()
2032 .map_err(|e| StoreError::Database(e.to_string()))?;
2033 match row {
2034 Some((user, devices_json, timestamp, phash)) => {
2035 let devices: Vec<DeviceInfo> = serde_json::from_str(&devices_json)
2036 .map_err(|e| StoreError::Serialization(e.to_string()))?;
2037 Ok(Some(DeviceListRecord {
2038 user,
2039 devices,
2040 timestamp: timestamp as i64,
2041 phash,
2042 }))
2043 }
2044 None => Ok(None),
2045 }
2046 })
2047 .await
2048 .map_err(|e| StoreError::Database(e.to_string()))?
2049 }
2050
2051 async fn mark_forget_sender_key(&self, group_jid: &str, participant: &str) -> Result<()> {
2052 let pool = self.pool.clone();
2053 let device_id = self.device_id;
2054 let group_jid = group_jid.to_string();
2055 let participant = participant.to_string();
2056 let now = std::time::SystemTime::now()
2057 .duration_since(std::time::UNIX_EPOCH)
2058 .unwrap_or_default()
2059 .as_secs() as i32;
2060 tokio::task::spawn_blocking(move || -> Result<()> {
2061 let mut conn = pool
2062 .get()
2063 .map_err(|e| StoreError::Connection(e.to_string()))?;
2064 diesel::insert_into(sender_key_status::table)
2065 .values((
2066 sender_key_status::group_jid.eq(&group_jid),
2067 sender_key_status::participant.eq(&participant),
2068 sender_key_status::device_id.eq(device_id),
2069 sender_key_status::marked_at.eq(now),
2070 ))
2071 .on_conflict((
2072 sender_key_status::group_jid,
2073 sender_key_status::participant,
2074 sender_key_status::device_id,
2075 ))
2076 .do_update()
2077 .set(sender_key_status::marked_at.eq(now))
2078 .execute(&mut conn)
2079 .map_err(|e| StoreError::Database(e.to_string()))?;
2080 Ok(())
2081 })
2082 .await
2083 .map_err(|e| StoreError::Database(e.to_string()))??;
2084 Ok(())
2085 }
2086
2087 async fn consume_forget_marks(&self, group_jid: &str) -> Result<Vec<String>> {
2088 let pool = self.pool.clone();
2089 let device_id = self.device_id;
2090 let group_jid = group_jid.to_string();
2091 tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
2092 let mut conn = pool
2093 .get()
2094 .map_err(|e| StoreError::Connection(e.to_string()))?;
2095 let participants: Vec<String> = sender_key_status::table
2096 .select(sender_key_status::participant)
2097 .filter(sender_key_status::group_jid.eq(&group_jid))
2098 .filter(sender_key_status::device_id.eq(device_id))
2099 .load(&mut conn)
2100 .map_err(|e| StoreError::Database(e.to_string()))?;
2101 diesel::delete(
2102 sender_key_status::table
2103 .filter(sender_key_status::group_jid.eq(&group_jid))
2104 .filter(sender_key_status::device_id.eq(device_id)),
2105 )
2106 .execute(&mut conn)
2107 .map_err(|e| StoreError::Database(e.to_string()))?;
2108 Ok(participants)
2109 })
2110 .await
2111 .map_err(|e| StoreError::Database(e.to_string()))?
2112 }
2113
2114 async fn get_tc_token(&self, jid: &str) -> Result<Option<TcTokenEntry>> {
2115 let pool = self.pool.clone();
2116 let device_id = self.device_id;
2117 let jid = jid.to_string();
2118 tokio::task::spawn_blocking(move || -> Result<Option<TcTokenEntry>> {
2119 let mut conn = pool
2120 .get()
2121 .map_err(|e| StoreError::Connection(e.to_string()))?;
2122 let row: Option<(Vec<u8>, i64, Option<i64>)> = tc_tokens::table
2123 .select((
2124 tc_tokens::token,
2125 tc_tokens::token_timestamp,
2126 tc_tokens::sender_timestamp,
2127 ))
2128 .filter(tc_tokens::jid.eq(&jid))
2129 .filter(tc_tokens::device_id.eq(device_id))
2130 .first(&mut conn)
2131 .optional()
2132 .map_err(|e| StoreError::Database(e.to_string()))?;
2133 Ok(
2134 row.map(|(token, token_timestamp, sender_timestamp)| TcTokenEntry {
2135 token,
2136 token_timestamp,
2137 sender_timestamp,
2138 }),
2139 )
2140 })
2141 .await
2142 .map_err(|e| StoreError::Database(e.to_string()))?
2143 }
2144
2145 async fn put_tc_token(&self, jid: &str, entry: &TcTokenEntry) -> Result<()> {
2146 let pool = self.pool.clone();
2147 let device_id = self.device_id;
2148 let jid = jid.to_string();
2149 let entry = entry.clone();
2150 let now = std::time::SystemTime::now()
2151 .duration_since(std::time::UNIX_EPOCH)
2152 .unwrap_or_default()
2153 .as_secs() as i64;
2154 tokio::task::spawn_blocking(move || -> Result<()> {
2155 let mut conn = pool
2156 .get()
2157 .map_err(|e| StoreError::Connection(e.to_string()))?;
2158 diesel::insert_into(tc_tokens::table)
2159 .values((
2160 tc_tokens::jid.eq(&jid),
2161 tc_tokens::token.eq(&entry.token),
2162 tc_tokens::token_timestamp.eq(entry.token_timestamp),
2163 tc_tokens::sender_timestamp.eq(entry.sender_timestamp),
2164 tc_tokens::device_id.eq(device_id),
2165 tc_tokens::updated_at.eq(now),
2166 ))
2167 .on_conflict((tc_tokens::jid, tc_tokens::device_id))
2168 .do_update()
2169 .set((
2170 tc_tokens::token.eq(&entry.token),
2171 tc_tokens::token_timestamp.eq(entry.token_timestamp),
2172 tc_tokens::sender_timestamp.eq(entry.sender_timestamp),
2173 tc_tokens::updated_at.eq(now),
2174 ))
2175 .execute(&mut conn)
2176 .map_err(|e| StoreError::Database(e.to_string()))?;
2177 Ok(())
2178 })
2179 .await
2180 .map_err(|e| StoreError::Database(e.to_string()))??;
2181 Ok(())
2182 }
2183
2184 async fn delete_tc_token(&self, jid: &str) -> Result<()> {
2185 let pool = self.pool.clone();
2186 let device_id = self.device_id;
2187 let jid = jid.to_string();
2188 tokio::task::spawn_blocking(move || -> Result<()> {
2189 let mut conn = pool
2190 .get()
2191 .map_err(|e| StoreError::Connection(e.to_string()))?;
2192 diesel::delete(
2193 tc_tokens::table
2194 .filter(tc_tokens::jid.eq(&jid))
2195 .filter(tc_tokens::device_id.eq(device_id)),
2196 )
2197 .execute(&mut conn)
2198 .map_err(|e| StoreError::Database(e.to_string()))?;
2199 Ok(())
2200 })
2201 .await
2202 .map_err(|e| StoreError::Database(e.to_string()))??;
2203 Ok(())
2204 }
2205
2206 async fn get_all_tc_token_jids(&self) -> Result<Vec<String>> {
2207 let pool = self.pool.clone();
2208 let device_id = self.device_id;
2209 tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
2210 let mut conn = pool
2211 .get()
2212 .map_err(|e| StoreError::Connection(e.to_string()))?;
2213 let jids: Vec<String> = tc_tokens::table
2214 .select(tc_tokens::jid)
2215 .filter(tc_tokens::device_id.eq(device_id))
2216 .load(&mut conn)
2217 .map_err(|e| StoreError::Database(e.to_string()))?;
2218 Ok(jids)
2219 })
2220 .await
2221 .map_err(|e| StoreError::Database(e.to_string()))?
2222 }
2223
2224 async fn delete_expired_tc_tokens(&self, cutoff_timestamp: i64) -> Result<u32> {
2225 let pool = self.pool.clone();
2226 let device_id = self.device_id;
2227 tokio::task::spawn_blocking(move || -> Result<u32> {
2228 let mut conn = pool
2229 .get()
2230 .map_err(|e| StoreError::Connection(e.to_string()))?;
2231 let deleted = diesel::delete(
2232 tc_tokens::table
2233 .filter(tc_tokens::token_timestamp.lt(cutoff_timestamp))
2234 .filter(tc_tokens::device_id.eq(device_id)),
2235 )
2236 .execute(&mut conn)
2237 .map_err(|e| StoreError::Database(e.to_string()))?;
2238 Ok(deleted as u32)
2239 })
2240 .await
2241 .map_err(|e| StoreError::Database(e.to_string()))?
2242 }
2243
2244 async fn store_sent_message(
2245 &self,
2246 chat_jid: &str,
2247 message_id: &str,
2248 payload: &[u8],
2249 ) -> Result<()> {
2250 let chat_jid = chat_jid.to_string();
2251 let message_id = message_id.to_string();
2252 let payload: Arc<Vec<u8>> = Arc::new(payload.to_vec());
2254 let device_id = self.device_id;
2255 self.with_retry("store_sent_message", || {
2256 let chat_jid = chat_jid.clone();
2257 let message_id = message_id.clone();
2258 let payload = Arc::clone(&payload);
2259 Box::new(move |conn: &mut SqliteConnection| {
2260 diesel::replace_into(sent_messages::table)
2261 .values((
2262 sent_messages::chat_jid.eq(&chat_jid),
2263 sent_messages::message_id.eq(&message_id),
2264 sent_messages::payload.eq(payload.as_slice()),
2265 sent_messages::device_id.eq(device_id),
2266 ))
2267 .execute(conn)?;
2268 Ok(())
2269 })
2270 })
2271 .await
2272 }
2273
2274 async fn take_sent_message(&self, chat_jid: &str, message_id: &str) -> Result<Option<Vec<u8>>> {
2275 let chat_jid = chat_jid.to_string();
2276 let message_id = message_id.to_string();
2277 let device_id = self.device_id;
2278 self.with_retry("take_sent_message", || {
2280 let chat_jid = chat_jid.clone();
2281 let message_id = message_id.clone();
2282 Box::new(move |conn: &mut SqliteConnection| {
2283 conn.immediate_transaction(|conn| {
2284 let row: Option<Vec<u8>> = sent_messages::table
2285 .select(sent_messages::payload)
2286 .filter(sent_messages::chat_jid.eq(&chat_jid))
2287 .filter(sent_messages::message_id.eq(&message_id))
2288 .filter(sent_messages::device_id.eq(device_id))
2289 .first(conn)
2290 .optional()?;
2291 if row.is_some() {
2292 diesel::delete(
2293 sent_messages::table
2294 .filter(sent_messages::chat_jid.eq(&chat_jid))
2295 .filter(sent_messages::message_id.eq(&message_id))
2296 .filter(sent_messages::device_id.eq(device_id)),
2297 )
2298 .execute(conn)?;
2299 }
2300 Ok(row)
2301 })
2302 })
2303 })
2304 .await
2305 }
2306
2307 async fn delete_expired_sent_messages(&self, cutoff_timestamp: i64) -> Result<u32> {
2308 let pool = self.pool.clone();
2309 let device_id = self.device_id;
2310 tokio::task::spawn_blocking(move || -> Result<u32> {
2311 let mut conn = pool
2312 .get()
2313 .map_err(|e| StoreError::Connection(e.to_string()))?;
2314 let deleted = diesel::delete(
2315 sent_messages::table
2316 .filter(sent_messages::created_at.lt(cutoff_timestamp))
2317 .filter(sent_messages::device_id.eq(device_id)),
2318 )
2319 .execute(&mut conn)
2320 .map_err(|e| StoreError::Database(e.to_string()))?;
2321 Ok(deleted as u32)
2322 })
2323 .await
2324 .map_err(|e| StoreError::Database(e.to_string()))?
2325 }
2326}
2327
2328#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
2329#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
2330impl DeviceStore for SqliteStore {
2331 async fn save(&self, device: &CoreDevice) -> Result<()> {
2332 SqliteStore::save_device_data_for_device(self, self.device_id, device).await
2333 }
2334
2335 async fn load(&self) -> Result<Option<CoreDevice>> {
2336 SqliteStore::load_device_data_for_device(self, self.device_id).await
2337 }
2338
2339 async fn exists(&self) -> Result<bool> {
2340 SqliteStore::device_exists(self, self.device_id).await
2341 }
2342
2343 async fn create(&self) -> Result<i32> {
2344 SqliteStore::create_new_device(self).await
2345 }
2346
2347 async fn snapshot_db(&self, name: &str, extra_content: Option<&[u8]>) -> Result<()> {
2348 fn sanitize_snapshot_name(name: &str) -> Result<String> {
2349 const MAX_LENGTH: usize = 100;
2350
2351 let sanitized: String = name
2352 .chars()
2353 .map(|c| {
2354 if c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.' {
2355 c
2356 } else {
2357 '_'
2358 }
2359 })
2360 .collect();
2361
2362 let sanitized = sanitized
2363 .split('.')
2364 .filter(|part| !part.is_empty() && *part != "..")
2365 .collect::<Vec<_>>()
2366 .join(".");
2367
2368 let sanitized = sanitized.trim_matches(['/', '\\', '.']);
2369
2370 if sanitized.is_empty() {
2371 return Err(StoreError::Database(
2372 "Snapshot name cannot be empty after sanitization".to_string(),
2373 ));
2374 }
2375
2376 if sanitized.len() > MAX_LENGTH {
2377 return Err(StoreError::Database(format!(
2378 "Snapshot name exceeds maximum length of {} characters",
2379 MAX_LENGTH
2380 )));
2381 }
2382
2383 Ok(sanitized.to_string())
2384 }
2385
2386 let sanitized_name = sanitize_snapshot_name(name)?;
2387
2388 let pool = self.pool.clone();
2389 let db_path = self.database_path.clone();
2390 let extra_data = extra_content.map(|b| b.to_vec());
2391
2392 tokio::task::spawn_blocking(move || -> Result<()> {
2393 let mut conn = pool
2394 .get()
2395 .map_err(|e| StoreError::Connection(e.to_string()))?;
2396
2397 let timestamp = std::time::SystemTime::now()
2398 .duration_since(std::time::UNIX_EPOCH)
2399 .unwrap_or_default()
2400 .as_secs();
2401
2402 let target_path = format!("{}.snapshot-{}-{}", db_path, timestamp, sanitized_name);
2404
2405 let query = format!("VACUUM INTO '{}'", target_path.replace("'", "''"));
2408
2409 diesel::sql_query(query)
2410 .execute(&mut conn)
2411 .map_err(|e| StoreError::Database(e.to_string()))?;
2412
2413 if let Some(data) = extra_data {
2415 let extra_path = format!("{}.json", target_path);
2416 std::fs::write(&extra_path, data).map_err(|e| {
2417 StoreError::Database(format!("Failed to write snapshot extra content: {}", e))
2418 })?;
2419 }
2420
2421 Ok(())
2422 })
2423 .await
2424 .map_err(|e| StoreError::Database(e.to_string()))??;
2425
2426 Ok(())
2427 }
2428}
2429
#[cfg(test)]
mod tests {
    use super::*;

    // Each call opens a uniquely named shared-cache in-memory database (process id
    // plus a per-process counter), so tests never observe each other's rows.
    async fn create_test_store() -> SqliteStore {
        use std::sync::atomic::{AtomicU64, Ordering};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let id = COUNTER.fetch_add(1, Ordering::Relaxed);
        let db_name = format!(
            "file:memdb_test_{}_{}?mode=memory&cache=shared",
            std::process::id(),
            id
        );
        SqliteStore::new(&db_name)
            .await
            .expect("Failed to create test store")
    }

    #[test]
    fn test_parse_database_path_regular_path() {
        let path = "/var/lib/whatsapp/database.db";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "/var/lib/whatsapp/database.db");
    }

    #[test]
    fn test_parse_database_path_with_sqlite_prefix() {
        let path = "sqlite:///var/lib/whatsapp/database.db";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "/var/lib/whatsapp/database.db");
    }

    #[test]
    fn test_parse_database_path_with_query_params() {
        let path = "file:database.db?mode=memory&cache=shared";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "file:database.db");
    }

    #[test]
    fn test_parse_database_path_with_fragment() {
        let path = "file:database.db#fragment";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "file:database.db");
    }

    #[test]
    fn test_parse_database_path_with_both_query_and_fragment() {
        let path = "sqlite:///var/lib/database.db?mode=ro#backup";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "/var/lib/database.db");
    }

    #[test]
    fn test_parse_database_path_in_memory_rejected() {
        let result = parse_database_path(":memory:");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not supported"));
    }

    #[test]
    fn test_parse_database_path_in_memory_with_query_rejected() {
        let result = parse_database_path(":memory:?cache=shared");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not supported"));
    }

    #[tokio::test]
    async fn test_device_registry_save_and_get() {
        let store = create_test_store().await;

        let record = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![
                DeviceInfo {
                    device_id: 0,
                    key_index: None,
                },
                DeviceInfo {
                    device_id: 1,
                    key_index: Some(42),
                },
            ],
            timestamp: 1234567890,
            phash: Some("2:abcdef".to_string()),
        };

        store.update_device_list(record).await.expect("save failed");
        let loaded = store
            .get_devices("1234567890")
            .await
            .expect("get failed")
            .expect("record should exist");

        assert_eq!(loaded.user, "1234567890");
        assert_eq!(loaded.devices.len(), 2);
        assert_eq!(loaded.devices[0].device_id, 0);
        assert_eq!(loaded.devices[1].device_id, 1);
        assert_eq!(loaded.devices[1].key_index, Some(42));
        assert_eq!(loaded.phash, Some("2:abcdef".to_string()));
    }

    #[tokio::test]
    async fn test_device_registry_update_existing() {
        let store = create_test_store().await;

        let record1 = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![DeviceInfo {
                device_id: 0,
                key_index: None,
            }],
            timestamp: 1000,
            phash: Some("2:old".to_string()),
        };
        store
            .update_device_list(record1)
            .await
            .expect("save1 failed");

        let record2 = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![
                DeviceInfo {
                    device_id: 0,
                    key_index: None,
                },
                DeviceInfo {
                    device_id: 2,
                    key_index: None,
                },
            ],
            timestamp: 2000,
            phash: Some("2:new".to_string()),
        };
        store
            .update_device_list(record2)
            .await
            .expect("save2 failed");

        let loaded = store
            .get_devices("1234567890")
            .await
            .expect("get failed")
            .expect("record should exist");

        assert_eq!(loaded.devices.len(), 2);
        assert_eq!(loaded.phash, Some("2:new".to_string()));
    }

    #[tokio::test]
    async fn test_device_registry_get_nonexistent() {
        let store = create_test_store().await;
        let result = store.get_devices("nonexistent").await.expect("get failed");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_sender_key_status_mark_and_consume() {
        let store = create_test_store().await;

        let group = "group123@g.us";
        let participant = "user1@s.whatsapp.net";

        store
            .mark_forget_sender_key(group, participant)
            .await
            .expect("mark failed");

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed.len(), 1);
        assert!(consumed.contains(&participant.to_string()));

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(consumed.is_empty());
    }

    #[tokio::test]
    async fn test_sender_key_status_consume_multiple() {
        let store = create_test_store().await;

        let group = "group123@g.us";

        store
            .mark_forget_sender_key(group, "user1@s.whatsapp.net")
            .await
            .expect("mark failed");
        store
            .mark_forget_sender_key(group, "user2@s.whatsapp.net")
            .await
            .expect("mark failed");

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed.len(), 2);
        assert!(consumed.contains(&"user1@s.whatsapp.net".to_string()));
        assert!(consumed.contains(&"user2@s.whatsapp.net".to_string()));

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(consumed.is_empty());
    }

    #[tokio::test]
    async fn test_sender_key_status_mark_twice_is_deduplicated() {
        let store = create_test_store().await;

        let group = "group123@g.us";
        let participant = "user1@s.whatsapp.net";

        // The second mark hits the ON CONFLICT upsert and must not add a row.
        store
            .mark_forget_sender_key(group, participant)
            .await
            .expect("first mark failed");
        store
            .mark_forget_sender_key(group, participant)
            .await
            .expect("second mark failed");

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed, vec![participant.to_string()]);
    }

    #[tokio::test]
    async fn test_tc_token_put_and_get() {
        let store = create_test_store().await;

        let entry = TcTokenEntry {
            token: vec![1, 2, 3, 4, 5],
            token_timestamp: 1707000000,
            sender_timestamp: Some(1707000100),
        };

        store
            .put_tc_token("user@lid", &entry)
            .await
            .expect("put failed");

        let loaded = store
            .get_tc_token("user@lid")
            .await
            .expect("get failed")
            .expect("should exist");

        assert_eq!(loaded.token, vec![1, 2, 3, 4, 5]);
        assert_eq!(loaded.token_timestamp, 1707000000);
        assert_eq!(loaded.sender_timestamp, Some(1707000100));
    }

    #[tokio::test]
    async fn test_tc_token_upsert() {
        let store = create_test_store().await;

        let entry1 = TcTokenEntry {
            token: vec![1, 2, 3],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        store.put_tc_token("user@lid", &entry1).await.unwrap();

        let entry2 = TcTokenEntry {
            token: vec![4, 5, 6],
            token_timestamp: 2000,
            sender_timestamp: Some(1500),
        };
        store.put_tc_token("user@lid", &entry2).await.unwrap();

        let loaded = store.get_tc_token("user@lid").await.unwrap().unwrap();
        assert_eq!(loaded.token, vec![4, 5, 6]);
        assert_eq!(loaded.token_timestamp, 2000);
        assert_eq!(loaded.sender_timestamp, Some(1500));
    }

    #[tokio::test]
    async fn test_tc_token_delete() {
        let store = create_test_store().await;

        let entry = TcTokenEntry {
            token: vec![1, 2, 3],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        store.put_tc_token("user@lid", &entry).await.unwrap();
        store.delete_tc_token("user@lid").await.unwrap();

        let result = store.get_tc_token("user@lid").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_tc_token_get_all_jids() {
        let store = create_test_store().await;

        let entry = TcTokenEntry {
            token: vec![1],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        store.put_tc_token("user1@lid", &entry).await.unwrap();
        store.put_tc_token("user2@lid", &entry).await.unwrap();
        store.put_tc_token("user3@lid", &entry).await.unwrap();

        let mut jids = store.get_all_tc_token_jids().await.unwrap();
        jids.sort();
        assert_eq!(jids, vec!["user1@lid", "user2@lid", "user3@lid"]);
    }

    #[tokio::test]
    async fn test_tc_token_delete_expired() {
        let store = create_test_store().await;

        let old = TcTokenEntry {
            token: vec![1],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        let recent = TcTokenEntry {
            token: vec![2],
            token_timestamp: 5000,
            sender_timestamp: None,
        };
        store.put_tc_token("old@lid", &old).await.unwrap();
        store.put_tc_token("recent@lid", &recent).await.unwrap();

        let deleted = store.delete_expired_tc_tokens(3000).await.unwrap();
        assert_eq!(deleted, 1);

        assert!(store.get_tc_token("old@lid").await.unwrap().is_none());
        assert!(store.get_tc_token("recent@lid").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_tc_token_get_nonexistent() {
        let store = create_test_store().await;
        let result = store.get_tc_token("nonexistent@lid").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_sender_key_status_different_groups() {
        let store = create_test_store().await;

        let group1 = "group1@g.us";
        let group2 = "group2@g.us";
        let participant = "user@s.whatsapp.net";

        store
            .mark_forget_sender_key(group1, participant)
            .await
            .expect("mark failed");

        let consumed = store.consume_forget_marks(group1).await.unwrap();
        assert_eq!(consumed.len(), 1);

        let consumed = store.consume_forget_marks(group2).await.unwrap();
        assert!(consumed.is_empty());
    }

    #[tokio::test]
    async fn test_sent_message_store_and_take() {
        let store = create_test_store().await;

        store
            .store_sent_message("chat@s.whatsapp.net", "MSG1", &[9u8, 8, 7])
            .await
            .expect("store failed");

        let taken = store
            .take_sent_message("chat@s.whatsapp.net", "MSG1")
            .await
            .expect("take failed");
        assert_eq!(taken, Some(vec![9u8, 8, 7]));

        // `take_sent_message` deletes the row it returns, so a second take is empty.
        let again = store
            .take_sent_message("chat@s.whatsapp.net", "MSG1")
            .await
            .expect("second take failed");
        assert!(again.is_none());
    }

    #[tokio::test]
    async fn test_sent_message_take_nonexistent() {
        let store = create_test_store().await;
        let result = store
            .take_sent_message("chat@s.whatsapp.net", "missing")
            .await
            .expect("take failed");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_sent_message_scoped_by_chat() {
        let store = create_test_store().await;

        store
            .store_sent_message("chat1@s.whatsapp.net", "MSG1", &[1u8])
            .await
            .expect("store failed");

        // Same message id in a different chat must not match (or delete) the row.
        let other_chat = store
            .take_sent_message("chat2@s.whatsapp.net", "MSG1")
            .await
            .expect("take failed");
        assert!(other_chat.is_none());

        let own_chat = store
            .take_sent_message("chat1@s.whatsapp.net", "MSG1")
            .await
            .expect("take failed");
        assert_eq!(own_chat, Some(vec![1u8]));
    }
}