1use crate::schema::*;
2use async_trait::async_trait;
3use diesel::prelude::*;
4use diesel::r2d2::{ConnectionManager, Pool};
5use diesel::result::{DatabaseErrorKind, Error as DieselError};
6use diesel::sql_query;
7use diesel::sqlite::SqliteConnection;
8use diesel::upsert::excluded;
9use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
10use log::warn;
11use prost::Message;
12use std::sync::Arc;
13use wacore_ng::appstate::hash::HashState;
14use wacore_ng::appstate::processor::AppStateMutationMAC;
15use wacore_ng::libsignal::protocol::{KeyPair, PrivateKey, PublicKey};
16use wacore_ng::store::Device as CoreDevice;
17use wacore_ng::store::error::{Result, StoreError};
18use wacore_ng::store::traits::*;
19use wacore_binary_ng::jid::Jid;
20use waproto_ng::whatsapp as wa;
21
22enum DieselOrStore {
26 Diesel(DieselError),
27 Store(StoreError),
28}
29
30impl From<DieselOrStore> for StoreError {
31 fn from(e: DieselOrStore) -> Self {
32 match e {
33 DieselOrStore::Diesel(e) => StoreError::Database(e.to_string()),
34 DieselOrStore::Store(e) => e,
35 }
36 }
37}
38
39fn is_retriable_sqlite_error(error: &DieselError) -> bool {
45 match error {
46 DieselError::DatabaseError(DatabaseErrorKind::Unknown, info) => {
47 let msg = info.message();
48 msg.contains("locked") || msg.contains("busy")
49 }
50 _ => false,
51 }
52}
53
54const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
55
56type SqlitePool = Pool<ConnectionManager<SqliteConnection>>;
57type DeviceRow = (
58 i32,
59 String,
60 String,
61 i32,
62 Vec<u8>,
63 Vec<u8>,
64 Vec<u8>,
65 i32,
66 Vec<u8>,
67 Vec<u8>,
68 Option<Vec<u8>>,
69 String,
70 i32,
71 i32,
72 i64,
73 i64,
74 Option<Vec<u8>>,
75 Option<String>,
76 i32,
77);
78
79#[derive(Clone)]
80pub struct SqliteStore {
81 pub(crate) pool: SqlitePool,
82 pub(crate) db_semaphore: Arc<tokio::sync::Semaphore>,
83 pub(crate) database_path: String,
84 device_id: i32,
85}
86
87#[derive(Debug, Clone, Copy)]
88struct ConnectionOptions;
89
90impl diesel::r2d2::CustomizeConnection<SqliteConnection, diesel::r2d2::Error>
91 for ConnectionOptions
92{
93 fn on_acquire(
94 &self,
95 conn: &mut SqliteConnection,
96 ) -> std::result::Result<(), diesel::r2d2::Error> {
97 diesel::sql_query("PRAGMA busy_timeout = 30000;")
98 .execute(conn)
99 .map_err(diesel::r2d2::Error::QueryError)?;
100 diesel::sql_query("PRAGMA synchronous = NORMAL;")
101 .execute(conn)
102 .map_err(diesel::r2d2::Error::QueryError)?;
103 diesel::sql_query("PRAGMA cache_size = 512;")
104 .execute(conn)
105 .map_err(diesel::r2d2::Error::QueryError)?;
106 diesel::sql_query("PRAGMA temp_store = memory;")
107 .execute(conn)
108 .map_err(diesel::r2d2::Error::QueryError)?;
109 diesel::sql_query("PRAGMA foreign_keys = ON;")
110 .execute(conn)
111 .map_err(diesel::r2d2::Error::QueryError)?;
112 Ok(())
113 }
114}
115
116fn parse_database_path(database_url: &str) -> Result<String> {
117 if database_url == ":memory:" {
119 return Err(StoreError::Database(
120 "Snapshot not supported for in-memory databases".to_string(),
121 ));
122 }
123
124 let path = database_url
126 .split(['?', '#'])
127 .next()
128 .unwrap_or(database_url);
129
130 let path = path.trim_start_matches("sqlite://");
132
133 if path == ":memory:" || path.starts_with(":memory:?") {
135 return Err(StoreError::Database(
136 "Snapshot not supported for in-memory databases".to_string(),
137 ));
138 }
139
140 Ok(path.to_string())
141}
142
143impl SqliteStore {
144 pub async fn new(database_url: &str) -> std::result::Result<Self, StoreError> {
145 let manager = ConnectionManager::<SqliteConnection>::new(database_url);
146
147 let pool_size = 2;
148
149 let pool = Pool::builder()
150 .max_size(pool_size)
151 .connection_customizer(Box::new(ConnectionOptions))
152 .build(manager)
153 .map_err(|e| StoreError::Connection(e.to_string()))?;
154
155 let pool_clone = pool.clone();
156 tokio::task::spawn_blocking(move || -> std::result::Result<(), StoreError> {
157 let mut conn = pool_clone
158 .get()
159 .map_err(|e| StoreError::Connection(e.to_string()))?;
160
161 diesel::sql_query("PRAGMA journal_mode = WAL;")
162 .execute(&mut conn)
163 .map_err(|e| StoreError::Database(e.to_string()))?;
164
165 conn.run_pending_migrations(MIGRATIONS)
166 .map_err(|e| StoreError::Migration(e.to_string()))?;
167
168 Ok(())
169 })
170 .await
171 .map_err(|e| StoreError::Database(e.to_string()))??;
172
173 let database_path = parse_database_path(database_url)?;
174
175 Ok(Self {
176 pool,
177 db_semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
178 database_path,
179 device_id: 1,
180 })
181 }
182
183 pub async fn new_for_device(
184 database_url: &str,
185 device_id: i32,
186 ) -> std::result::Result<Self, StoreError> {
187 let mut store = Self::new(database_url).await?;
188 store.device_id = device_id;
189 Ok(store)
190 }
191
192 pub fn device_id(&self) -> i32 {
193 self.device_id
194 }
195
196 async fn with_semaphore<F, T>(&self, f: F) -> Result<T>
197 where
198 F: FnOnce() -> Result<T> + Send + 'static,
199 T: Send + 'static,
200 {
201 let permit = self
202 .db_semaphore
203 .clone()
204 .acquire_owned()
205 .await
206 .map_err(|e| StoreError::Database(format!("Semaphore error: {}", e)))?;
207 let result = tokio::task::spawn_blocking(move || {
208 let res = f();
209 drop(permit);
210 res
211 })
212 .await
213 .map_err(|e| StoreError::Database(e.to_string()))??;
214 Ok(result)
215 }
216
217 async fn with_retry<F, T>(&self, op_name: &str, make_op: F) -> Result<T>
221 where
222 F: Fn() -> Box<
223 dyn FnOnce(&mut SqliteConnection) -> std::result::Result<T, DieselError> + Send,
224 >,
225 T: Send + 'static,
226 {
227 const MAX_RETRIES: u32 = 5;
228
229 for attempt in 0..=MAX_RETRIES {
230 let permit = self
231 .db_semaphore
232 .clone()
233 .acquire_owned()
234 .await
235 .map_err(|e| StoreError::Database(format!("Semaphore error: {}", e)))?;
236
237 let pool = self.pool.clone();
238 let op = make_op();
239
240 let result =
241 tokio::task::spawn_blocking(move || -> std::result::Result<T, DieselOrStore> {
242 let _permit = permit;
243 let mut conn = pool
244 .get()
245 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
246 op(&mut conn).map_err(DieselOrStore::Diesel)
247 })
248 .await;
249
250 match result {
251 Ok(Ok(val)) => return Ok(val),
252 Ok(Err(DieselOrStore::Diesel(ref e)))
253 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
254 {
255 let delay_ms = 10u64 * (1u64 << attempt.min(4));
256 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
257 }
258 Ok(Err(e)) => return Err(e.into()),
259 Err(e) => return Err(StoreError::Database(e.to_string())),
260 }
261 }
262
263 Err(StoreError::Database(format!(
264 "{} exhausted retries",
265 op_name
266 )))
267 }
268
269 fn serialize_keypair(&self, key_pair: &KeyPair) -> Result<Vec<u8>> {
270 let mut bytes = Vec::with_capacity(64);
271 bytes.extend_from_slice(key_pair.private_key.serialize());
272 bytes.extend_from_slice(key_pair.public_key.public_key_bytes());
273 Ok(bytes)
274 }
275
276 fn deserialize_keypair(&self, bytes: &[u8]) -> Result<KeyPair> {
277 if bytes.len() != 64 {
278 return Err(StoreError::Serialization(format!(
279 "Invalid KeyPair length: {}",
280 bytes.len()
281 )));
282 }
283
284 let private_key = PrivateKey::deserialize(&bytes[0..32])
285 .map_err(|e| StoreError::Serialization(e.to_string()))?;
286 let public_key = PublicKey::from_djb_public_key_bytes(&bytes[32..64])
287 .map_err(|e| StoreError::Serialization(e.to_string()))?;
288
289 Ok(KeyPair::new(public_key, private_key))
290 }
291
292 pub async fn save_device_data_for_device(
293 &self,
294 device_id: i32,
295 device_data: &CoreDevice,
296 ) -> Result<()> {
297 let pool = self.pool.clone();
298 let noise_key_data = self.serialize_keypair(&device_data.noise_key)?;
299 let identity_key_data = self.serialize_keypair(&device_data.identity_key)?;
300 let signed_pre_key_data = self.serialize_keypair(&device_data.signed_pre_key)?;
301 let account_data = device_data
302 .account
303 .as_ref()
304 .map(|account| account.encode_to_vec());
305 let registration_id = device_data.registration_id as i32;
306 let signed_pre_key_id = device_data.signed_pre_key_id as i32;
307 let signed_pre_key_signature: Vec<u8> = device_data.signed_pre_key_signature.to_vec();
308 let adv_secret_key: Vec<u8> = device_data.adv_secret_key.to_vec();
309 let push_name = device_data.push_name.clone();
310 let app_version_primary = device_data.app_version_primary as i32;
311 let app_version_secondary = device_data.app_version_secondary as i32;
312 let app_version_tertiary = device_data.app_version_tertiary as i64;
313 let app_version_last_fetched_ms = device_data.app_version_last_fetched_ms;
314 let edge_routing_info = device_data.edge_routing_info.clone();
315 let props_hash = device_data.props_hash.clone();
316 let next_pre_key_id = device_data.next_pre_key_id as i32;
317 let new_lid = device_data
318 .lid
319 .as_ref()
320 .map(|j| j.to_string())
321 .unwrap_or_default();
322 let new_pn = device_data
323 .pn
324 .as_ref()
325 .map(|j| j.to_string())
326 .unwrap_or_default();
327
328 tokio::task::spawn_blocking(move || -> Result<()> {
329 let mut conn = pool
330 .get()
331 .map_err(|e| StoreError::Connection(e.to_string()))?;
332
333 diesel::insert_into(device::table)
334 .values((
335 device::id.eq(device_id),
336 device::lid.eq(&new_lid),
337 device::pn.eq(&new_pn),
338 device::registration_id.eq(registration_id),
339 device::noise_key.eq(&noise_key_data),
340 device::identity_key.eq(&identity_key_data),
341 device::signed_pre_key.eq(&signed_pre_key_data),
342 device::signed_pre_key_id.eq(signed_pre_key_id),
343 device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
344 device::adv_secret_key.eq(&adv_secret_key[..]),
345 device::account.eq(account_data.clone()),
346 device::push_name.eq(&push_name),
347 device::app_version_primary.eq(app_version_primary),
348 device::app_version_secondary.eq(app_version_secondary),
349 device::app_version_tertiary.eq(app_version_tertiary),
350 device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
351 device::edge_routing_info.eq(edge_routing_info.clone()),
352 device::props_hash.eq(props_hash.clone()),
353 device::next_pre_key_id.eq(next_pre_key_id),
354 ))
355 .on_conflict(device::id)
356 .do_update()
357 .set((
358 device::lid.eq(&new_lid),
359 device::pn.eq(&new_pn),
360 device::registration_id.eq(registration_id),
361 device::noise_key.eq(&noise_key_data),
362 device::identity_key.eq(&identity_key_data),
363 device::signed_pre_key.eq(&signed_pre_key_data),
364 device::signed_pre_key_id.eq(signed_pre_key_id),
365 device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
366 device::adv_secret_key.eq(&adv_secret_key[..]),
367 device::account.eq(account_data.clone()),
368 device::push_name.eq(&push_name),
369 device::app_version_primary.eq(app_version_primary),
370 device::app_version_secondary.eq(app_version_secondary),
371 device::app_version_tertiary.eq(app_version_tertiary),
372 device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
373 device::edge_routing_info.eq(edge_routing_info),
374 device::props_hash.eq(props_hash),
375 device::next_pre_key_id.eq(next_pre_key_id),
376 ))
377 .execute(&mut conn)
378 .map_err(|e| StoreError::Database(e.to_string()))?;
379
380 Ok(())
381 })
382 .await
383 .map_err(|e| StoreError::Database(e.to_string()))??;
384
385 Ok(())
386 }
387
388 pub async fn create_new_device(&self) -> Result<i32> {
389 use crate::schema::device;
390
391 let pool = self.pool.clone();
392 tokio::task::spawn_blocking(move || -> Result<i32> {
393 let mut conn = pool
394 .get()
395 .map_err(|e| StoreError::Connection(e.to_string()))?;
396
397 let new_device = wacore_ng::store::Device::new();
398
399 let noise_key_data = {
400 let mut bytes = Vec::with_capacity(64);
401 bytes.extend_from_slice(new_device.noise_key.private_key.serialize());
402 bytes.extend_from_slice(new_device.noise_key.public_key.public_key_bytes());
403 bytes
404 };
405 let identity_key_data = {
406 let mut bytes = Vec::with_capacity(64);
407 bytes.extend_from_slice(new_device.identity_key.private_key.serialize());
408 bytes.extend_from_slice(new_device.identity_key.public_key.public_key_bytes());
409 bytes
410 };
411 let signed_pre_key_data = {
412 let mut bytes = Vec::with_capacity(64);
413 bytes.extend_from_slice(new_device.signed_pre_key.private_key.serialize());
414 bytes.extend_from_slice(new_device.signed_pre_key.public_key.public_key_bytes());
415 bytes
416 };
417
418 diesel::insert_into(device::table)
419 .values((
420 device::lid.eq(""),
421 device::pn.eq(""),
422 device::registration_id.eq(new_device.registration_id as i32),
423 device::noise_key.eq(&noise_key_data),
424 device::identity_key.eq(&identity_key_data),
425 device::signed_pre_key.eq(&signed_pre_key_data),
426 device::signed_pre_key_id.eq(new_device.signed_pre_key_id as i32),
427 device::signed_pre_key_signature.eq(&new_device.signed_pre_key_signature[..]),
428 device::adv_secret_key.eq(&new_device.adv_secret_key[..]),
429 device::account.eq(None::<Vec<u8>>),
430 device::push_name.eq(&new_device.push_name),
431 device::app_version_primary.eq(new_device.app_version_primary as i32),
432 device::app_version_secondary.eq(new_device.app_version_secondary as i32),
433 device::app_version_tertiary.eq(new_device.app_version_tertiary as i64),
434 device::app_version_last_fetched_ms.eq(new_device.app_version_last_fetched_ms),
435 device::edge_routing_info.eq(None::<Vec<u8>>),
436 device::props_hash.eq(None::<String>),
437 device::next_pre_key_id.eq(new_device.next_pre_key_id as i32),
438 ))
439 .execute(&mut conn)
440 .map_err(|e| StoreError::Database(e.to_string()))?;
441
442 use diesel::sql_types::Integer;
443
444 #[derive(QueryableByName)]
445 struct LastInsertedId {
446 #[diesel(sql_type = Integer)]
447 last_insert_rowid: i32,
448 }
449
450 let device_id: i32 = sql_query("SELECT last_insert_rowid() as last_insert_rowid")
451 .get_result::<LastInsertedId>(&mut conn)
452 .map_err(|e| StoreError::Database(e.to_string()))?
453 .last_insert_rowid;
454
455 Ok(device_id)
456 })
457 .await
458 .map_err(|e| StoreError::Database(e.to_string()))?
459 }
460
461 pub async fn device_exists(&self, device_id: i32) -> Result<bool> {
462 use crate::schema::device;
463
464 let pool = self.pool.clone();
465 tokio::task::spawn_blocking(move || -> Result<bool> {
466 let mut conn = pool
467 .get()
468 .map_err(|e| StoreError::Connection(e.to_string()))?;
469
470 let count: i64 = device::table
471 .filter(device::id.eq(device_id))
472 .count()
473 .get_result(&mut conn)
474 .map_err(|e| StoreError::Database(e.to_string()))?;
475
476 Ok(count > 0)
477 })
478 .await
479 .map_err(|e| StoreError::Database(e.to_string()))?
480 }
481
482 pub async fn load_device_data_for_device(&self, device_id: i32) -> Result<Option<CoreDevice>> {
483 use crate::schema::device;
484
485 let pool = self.pool.clone();
486 let row = tokio::task::spawn_blocking(move || -> Result<Option<DeviceRow>> {
487 let mut conn = pool
488 .get()
489 .map_err(|e| StoreError::Connection(e.to_string()))?;
490 let result = device::table
491 .filter(device::id.eq(device_id))
492 .first::<DeviceRow>(&mut conn)
493 .optional()
494 .map_err(|e| StoreError::Database(e.to_string()))?;
495 Ok(result)
496 })
497 .await
498 .map_err(|e| StoreError::Database(e.to_string()))??;
499
500 if let Some((
501 _device_id,
502 lid_str,
503 pn_str,
504 registration_id,
505 noise_key_data,
506 identity_key_data,
507 signed_pre_key_data,
508 signed_pre_key_id,
509 signed_pre_key_signature_data,
510 adv_secret_key_data,
511 account_data,
512 push_name,
513 app_version_primary,
514 app_version_secondary,
515 app_version_tertiary,
516 app_version_last_fetched_ms,
517 edge_routing_info,
518 props_hash,
519 next_pre_key_id,
520 )) = row
521 {
522 let id = if !pn_str.is_empty() {
523 pn_str.parse().ok()
524 } else {
525 None
526 };
527 let lid = if !lid_str.is_empty() {
528 lid_str.parse().ok()
529 } else {
530 None
531 };
532
533 let noise_key = self.deserialize_keypair(&noise_key_data)?;
534 let identity_key = self.deserialize_keypair(&identity_key_data)?;
535 let signed_pre_key = self.deserialize_keypair(&signed_pre_key_data)?;
536
537 let signed_pre_key_signature: [u8; 64] =
538 signed_pre_key_signature_data.try_into().map_err(|_| {
539 StoreError::Serialization("Invalid signed_pre_key_signature length".to_string())
540 })?;
541
542 let adv_secret_key: [u8; 32] = adv_secret_key_data.try_into().map_err(|_| {
543 StoreError::Serialization("Invalid adv_secret_key length".to_string())
544 })?;
545
546 let account = account_data
547 .map(|data| {
548 wa::AdvSignedDeviceIdentity::decode(&data[..])
549 .map_err(|e| StoreError::Serialization(e.to_string()))
550 })
551 .transpose()?;
552
553 Ok(Some(CoreDevice {
554 pn: id,
555 lid,
556 registration_id: registration_id as u32,
557 noise_key,
558 identity_key,
559 signed_pre_key,
560 signed_pre_key_id: signed_pre_key_id as u32,
561 signed_pre_key_signature,
562 adv_secret_key,
563 account,
564 push_name,
565 app_version_primary: app_version_primary as u32,
566 app_version_secondary: app_version_secondary as u32,
567 app_version_tertiary: app_version_tertiary.try_into().unwrap_or(0u32),
568 app_version_last_fetched_ms,
569 device_props: {
570 use wacore_ng::store::device::DEVICE_PROPS;
571 DEVICE_PROPS.clone()
572 },
573 edge_routing_info,
574 props_hash,
575 next_pre_key_id: next_pre_key_id as u32,
576 }))
577 } else {
578 Ok(None)
579 }
580 }
581
582 pub async fn put_identity_for_device(
583 &self,
584 address: &str,
585 key: [u8; 32],
586 device_id: i32,
587 ) -> Result<()> {
588 let pool = self.pool.clone();
589 let db_semaphore = self.db_semaphore.clone();
590 let address_owned = address.to_string();
591 let key_vec = key.to_vec();
592
593 const MAX_RETRIES: u32 = 5;
594
595 for attempt in 0..=MAX_RETRIES {
596 let permit =
597 db_semaphore.clone().acquire_owned().await.map_err(|e| {
598 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
599 })?;
600
601 let pool_clone = pool.clone();
602 let address_clone = address_owned.clone();
603 let key_clone = key_vec.clone();
604
605 let result =
606 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
607 let mut conn = pool_clone
608 .get()
609 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
610 diesel::insert_into(identities::table)
611 .values((
612 identities::address.eq(address_clone),
613 identities::key.eq(&key_clone[..]),
614 identities::device_id.eq(device_id),
615 ))
616 .on_conflict((identities::address, identities::device_id))
617 .do_update()
618 .set(identities::key.eq(&key_clone[..]))
619 .execute(&mut conn)
620 .map_err(DieselOrStore::Diesel)?;
621 Ok(())
622 })
623 .await;
624
625 drop(permit);
626
627 match result {
628 Ok(Ok(())) => return Ok(()),
629 Ok(Err(DieselOrStore::Diesel(ref e)))
630 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
631 {
632 let delay_ms = 10 * 2u64.pow(attempt);
633 warn!(
634 "Identity write failed (attempt {}/{}): {e}. Retrying in {delay_ms}ms...",
635 attempt + 1,
636 MAX_RETRIES + 1,
637 );
638 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
639 continue;
640 }
641 Ok(Err(e)) => return Err(e.into()),
642 Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
643 }
644 }
645
646 Err(StoreError::Database(format!(
647 "Identity write failed after {} attempts",
648 MAX_RETRIES + 1
649 )))
650 }
651
652 pub async fn delete_identity_for_device(&self, address: &str, device_id: i32) -> Result<()> {
653 let pool = self.pool.clone();
654 let address_owned = address.to_string();
655
656 tokio::task::spawn_blocking(move || -> Result<()> {
657 let mut conn = pool
658 .get()
659 .map_err(|e| StoreError::Connection(e.to_string()))?;
660 diesel::delete(
661 identities::table
662 .filter(identities::address.eq(address_owned))
663 .filter(identities::device_id.eq(device_id)),
664 )
665 .execute(&mut conn)
666 .map_err(|e| StoreError::Database(e.to_string()))?;
667 Ok(())
668 })
669 .await
670 .map_err(|e| StoreError::Database(e.to_string()))??;
671
672 Ok(())
673 }
674
675 pub async fn load_identity_for_device(
676 &self,
677 address: &str,
678 device_id: i32,
679 ) -> Result<Option<Vec<u8>>> {
680 let pool = self.pool.clone();
681 let address = address.to_string();
682 let result = self
683 .with_semaphore(move || -> Result<Option<Vec<u8>>> {
684 let mut conn = pool
685 .get()
686 .map_err(|e| StoreError::Connection(e.to_string()))?;
687 let res: Option<Vec<u8>> = identities::table
688 .select(identities::key)
689 .filter(identities::address.eq(address))
690 .filter(identities::device_id.eq(device_id))
691 .first(&mut conn)
692 .optional()
693 .map_err(|e| StoreError::Database(e.to_string()))?;
694 Ok(res)
695 })
696 .await?;
697
698 Ok(result)
699 }
700
701 pub async fn get_session_for_device(
702 &self,
703 address: &str,
704 device_id: i32,
705 ) -> Result<Option<Vec<u8>>> {
706 let pool = self.pool.clone();
707 let address_for_query = address.to_string();
708 let result = self
709 .with_semaphore(move || -> Result<Option<Vec<u8>>> {
710 let mut conn = pool
711 .get()
712 .map_err(|e| StoreError::Connection(e.to_string()))?;
713 let res: Option<Vec<u8>> = sessions::table
714 .select(sessions::record)
715 .filter(sessions::address.eq(address_for_query.clone()))
716 .filter(sessions::device_id.eq(device_id))
717 .first(&mut conn)
718 .optional()
719 .map_err(|e| StoreError::Database(e.to_string()))?;
720
721 Ok(res)
722 })
723 .await?;
724
725 Ok(result)
726 }
727
728 pub async fn put_session_for_device(
729 &self,
730 address: &str,
731 session: &[u8],
732 device_id: i32,
733 ) -> Result<()> {
734 let pool = self.pool.clone();
735 let db_semaphore = self.db_semaphore.clone();
736 let address_owned = address.to_string();
737 let session_vec = session.to_vec();
738
739 const MAX_RETRIES: u32 = 5;
740
741 for attempt in 0..=MAX_RETRIES {
742 let permit =
743 db_semaphore.clone().acquire_owned().await.map_err(|e| {
744 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
745 })?;
746
747 let pool_clone = pool.clone();
748 let address_clone = address_owned.clone();
749 let session_clone = session_vec.clone();
750
751 let result =
752 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
753 let mut conn = pool_clone
754 .get()
755 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
756 diesel::insert_into(sessions::table)
757 .values((
758 sessions::address.eq(address_clone),
759 sessions::record.eq(&session_clone),
760 sessions::device_id.eq(device_id),
761 ))
762 .on_conflict((sessions::address, sessions::device_id))
763 .do_update()
764 .set(sessions::record.eq(&session_clone))
765 .execute(&mut conn)
766 .map_err(DieselOrStore::Diesel)?;
767 Ok(())
768 })
769 .await;
770
771 drop(permit);
772
773 match result {
774 Ok(Ok(())) => return Ok(()),
775 Ok(Err(DieselOrStore::Diesel(ref e)))
776 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
777 {
778 let delay_ms = 10 * 2u64.pow(attempt);
779 warn!(
780 "Session write failed (attempt {}/{}): {e}. Retrying in {delay_ms}ms...",
781 attempt + 1,
782 MAX_RETRIES + 1,
783 );
784 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
785 continue;
786 }
787 Ok(Err(e)) => return Err(e.into()),
788 Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
789 }
790 }
791
792 Err(StoreError::Database(format!(
793 "Session write failed after {} attempts",
794 MAX_RETRIES + 1
795 )))
796 }
797
798 pub async fn delete_session_for_device(&self, address: &str, device_id: i32) -> Result<()> {
799 let pool = self.pool.clone();
800 let address_owned = address.to_string();
801
802 tokio::task::spawn_blocking(move || -> Result<()> {
803 let mut conn = pool
804 .get()
805 .map_err(|e| StoreError::Connection(e.to_string()))?;
806 diesel::delete(
807 sessions::table
808 .filter(sessions::address.eq(address_owned))
809 .filter(sessions::device_id.eq(device_id)),
810 )
811 .execute(&mut conn)
812 .map_err(|e| StoreError::Database(e.to_string()))?;
813 Ok(())
814 })
815 .await
816 .map_err(|e| StoreError::Database(e.to_string()))??;
817
818 Ok(())
819 }
820
821 pub async fn put_sender_key_for_device(
822 &self,
823 address: &str,
824 record: &[u8],
825 device_id: i32,
826 ) -> Result<()> {
827 let pool = self.pool.clone();
828 let address = address.to_string();
829 let record_vec = record.to_vec();
830 tokio::task::spawn_blocking(move || -> Result<()> {
831 let mut conn = pool
832 .get()
833 .map_err(|e| StoreError::Connection(e.to_string()))?;
834 diesel::insert_into(sender_keys::table)
835 .values((
836 sender_keys::address.eq(address),
837 sender_keys::record.eq(&record_vec),
838 sender_keys::device_id.eq(device_id),
839 ))
840 .on_conflict((sender_keys::address, sender_keys::device_id))
841 .do_update()
842 .set(sender_keys::record.eq(&record_vec))
843 .execute(&mut conn)
844 .map_err(|e| StoreError::Database(e.to_string()))?;
845 Ok(())
846 })
847 .await
848 .map_err(|e| StoreError::Database(e.to_string()))??;
849 Ok(())
850 }
851
852 pub async fn get_sender_key_for_device(
853 &self,
854 address: &str,
855 device_id: i32,
856 ) -> Result<Option<Vec<u8>>> {
857 let pool = self.pool.clone();
858 let address = address.to_string();
859 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
860 let mut conn = pool
861 .get()
862 .map_err(|e| StoreError::Connection(e.to_string()))?;
863 let res: Option<Vec<u8>> = sender_keys::table
864 .select(sender_keys::record)
865 .filter(sender_keys::address.eq(address))
866 .filter(sender_keys::device_id.eq(device_id))
867 .first(&mut conn)
868 .optional()
869 .map_err(|e| StoreError::Database(e.to_string()))?;
870 Ok(res)
871 })
872 .await
873 .map_err(|e| StoreError::Database(e.to_string()))?
874 }
875
876 pub async fn delete_sender_key_for_device(&self, address: &str, device_id: i32) -> Result<()> {
877 let pool = self.pool.clone();
878 let address = address.to_string();
879 tokio::task::spawn_blocking(move || -> Result<()> {
880 let mut conn = pool
881 .get()
882 .map_err(|e| StoreError::Connection(e.to_string()))?;
883 diesel::delete(
884 sender_keys::table
885 .filter(sender_keys::address.eq(address))
886 .filter(sender_keys::device_id.eq(device_id)),
887 )
888 .execute(&mut conn)
889 .map_err(|e| StoreError::Database(e.to_string()))?;
890 Ok(())
891 })
892 .await
893 .map_err(|e| StoreError::Database(e.to_string()))??;
894 Ok(())
895 }
896
897 pub async fn get_app_state_sync_key_for_device(
898 &self,
899 key_id: &[u8],
900 device_id: i32,
901 ) -> Result<Option<AppStateSyncKey>> {
902 let pool = self.pool.clone();
903 let key_id = key_id.to_vec();
904 let res: Option<Vec<u8>> =
905 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
906 let mut conn = pool
907 .get()
908 .map_err(|e| StoreError::Connection(e.to_string()))?;
909 let res: Option<Vec<u8>> = app_state_keys::table
910 .select(app_state_keys::key_data)
911 .filter(app_state_keys::key_id.eq(&key_id))
912 .filter(app_state_keys::device_id.eq(device_id))
913 .first(&mut conn)
914 .optional()
915 .map_err(|e| StoreError::Database(e.to_string()))?;
916 Ok(res)
917 })
918 .await
919 .map_err(|e| StoreError::Database(e.to_string()))??;
920
921 if let Some(data) = res {
922 let (key, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
923 .map_err(|e| StoreError::Serialization(e.to_string()))?;
924 Ok(Some(key))
925 } else {
926 Ok(None)
927 }
928 }
929
930 pub async fn set_app_state_sync_key_for_device(
931 &self,
932 key_id: &[u8],
933 key: AppStateSyncKey,
934 device_id: i32,
935 ) -> Result<()> {
936 let pool = self.pool.clone();
937 let key_id = key_id.to_vec();
938 let data = bincode::serde::encode_to_vec(&key, bincode::config::standard())
939 .map_err(|e| StoreError::Serialization(e.to_string()))?;
940 tokio::task::spawn_blocking(move || -> Result<()> {
941 let mut conn = pool
942 .get()
943 .map_err(|e| StoreError::Connection(e.to_string()))?;
944 diesel::insert_into(app_state_keys::table)
945 .values((
946 app_state_keys::key_id.eq(&key_id),
947 app_state_keys::key_data.eq(&data),
948 app_state_keys::device_id.eq(device_id),
949 ))
950 .on_conflict((app_state_keys::key_id, app_state_keys::device_id))
951 .do_update()
952 .set(app_state_keys::key_data.eq(&data))
953 .execute(&mut conn)
954 .map_err(|e| StoreError::Database(e.to_string()))?;
955 Ok(())
956 })
957 .await
958 .map_err(|e| StoreError::Database(e.to_string()))??;
959 Ok(())
960 }
961
962 pub async fn get_latest_app_state_sync_key_id_for_device(
963 &self,
964 device_id: i32,
965 ) -> Result<Option<Vec<u8>>> {
966 let pool = self.pool.clone();
967 let res: Option<Vec<u8>> =
968 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
969 let mut conn = pool
970 .get()
971 .map_err(|e| StoreError::Connection(e.to_string()))?;
972 let res: Option<Vec<u8>> = app_state_keys::table
973 .select(app_state_keys::key_id)
974 .filter(app_state_keys::device_id.eq(device_id))
975 .order(app_state_keys::key_id.desc())
976 .first(&mut conn)
977 .optional()
978 .map_err(|e| StoreError::Database(e.to_string()))?;
979 Ok(res)
980 })
981 .await
982 .map_err(|e| StoreError::Database(e.to_string()))??;
983 Ok(res)
984 }
985
986 pub async fn get_app_state_version_for_device(
987 &self,
988 name: &str,
989 device_id: i32,
990 ) -> Result<HashState> {
991 let pool = self.pool.clone();
992 let name = name.to_string();
993 let res: Option<Vec<u8>> =
994 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
995 let mut conn = pool
996 .get()
997 .map_err(|e| StoreError::Connection(e.to_string()))?;
998 let res: Option<Vec<u8>> = app_state_versions::table
999 .select(app_state_versions::state_data)
1000 .filter(app_state_versions::name.eq(name))
1001 .filter(app_state_versions::device_id.eq(device_id))
1002 .first(&mut conn)
1003 .optional()
1004 .map_err(|e| StoreError::Database(e.to_string()))?;
1005 Ok(res)
1006 })
1007 .await
1008 .map_err(|e| StoreError::Database(e.to_string()))??;
1009
1010 if let Some(data) = res {
1011 let (state, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
1012 .map_err(|e| StoreError::Serialization(e.to_string()))?;
1013 Ok(state)
1014 } else {
1015 Ok(HashState::default())
1016 }
1017 }
1018
1019 pub async fn set_app_state_version_for_device(
1020 &self,
1021 name: &str,
1022 state: HashState,
1023 device_id: i32,
1024 ) -> Result<()> {
1025 let name = name.to_string();
1026 let data = bincode::serde::encode_to_vec(&state, bincode::config::standard())
1027 .map_err(|e| StoreError::Serialization(e.to_string()))?;
1028 self.with_retry("set_app_state_version", || {
1029 let name = name.clone();
1030 let data = data.clone();
1031 Box::new(move |conn: &mut SqliteConnection| {
1032 diesel::insert_into(app_state_versions::table)
1033 .values((
1034 app_state_versions::name.eq(&name),
1035 app_state_versions::state_data.eq(&data),
1036 app_state_versions::device_id.eq(device_id),
1037 ))
1038 .on_conflict((app_state_versions::name, app_state_versions::device_id))
1039 .do_update()
1040 .set(app_state_versions::state_data.eq(&data))
1041 .execute(conn)?;
1042 Ok(())
1043 })
1044 })
1045 .await
1046 }
1047
1048 pub async fn put_app_state_mutation_macs_for_device(
1049 &self,
1050 name: &str,
1051 version: u64,
1052 mutations: &[AppStateMutationMAC],
1053 device_id: i32,
1054 ) -> Result<()> {
1055 if mutations.is_empty() {
1056 return Ok(());
1057 }
1058 let name = name.to_string();
1059 let mutations: Vec<AppStateMutationMAC> = mutations.to_vec();
1060 self.with_retry("put_app_state_mutation_macs", || {
1061 let name = name.clone();
1062 let mutations = mutations.clone();
1063 Box::new(move |conn: &mut SqliteConnection| {
1064 let records: Vec<_> = mutations
1065 .iter()
1066 .map(|m| {
1067 (
1068 app_state_mutation_macs::name.eq(&name),
1069 app_state_mutation_macs::version.eq(version as i64),
1070 app_state_mutation_macs::index_mac.eq(&m.index_mac),
1071 app_state_mutation_macs::value_mac.eq(&m.value_mac),
1072 app_state_mutation_macs::device_id.eq(device_id),
1073 )
1074 })
1075 .collect();
1076
1077 const CHUNK_SIZE: usize = 100;
1080
1081 for chunk in records.chunks(CHUNK_SIZE) {
1082 diesel::insert_into(app_state_mutation_macs::table)
1083 .values(chunk)
1084 .on_conflict((
1085 app_state_mutation_macs::name,
1086 app_state_mutation_macs::index_mac,
1087 app_state_mutation_macs::device_id,
1088 ))
1089 .do_update()
1090 .set((
1091 app_state_mutation_macs::version
1092 .eq(excluded(app_state_mutation_macs::version)),
1093 app_state_mutation_macs::value_mac
1094 .eq(excluded(app_state_mutation_macs::value_mac)),
1095 ))
1096 .execute(conn)?;
1097 }
1098 Ok(())
1099 })
1100 })
1101 .await
1102 }
1103
1104 pub async fn delete_app_state_mutation_macs_for_device(
1105 &self,
1106 name: &str,
1107 index_macs: &[Vec<u8>],
1108 device_id: i32,
1109 ) -> Result<()> {
1110 if index_macs.is_empty() {
1111 return Ok(());
1112 }
1113 let name = name.to_string();
1114 let index_macs: Vec<Vec<u8>> = index_macs.to_vec();
1115 self.with_retry("delete_app_state_mutation_macs", || {
1116 let name = name.clone();
1117 let index_macs = index_macs.clone();
1118 Box::new(move |conn: &mut SqliteConnection| {
1119 const CHUNK_SIZE: usize = 500;
1122
1123 for chunk in index_macs.chunks(CHUNK_SIZE) {
1124 diesel::delete(
1125 app_state_mutation_macs::table.filter(
1126 app_state_mutation_macs::name
1127 .eq(&name)
1128 .and(app_state_mutation_macs::index_mac.eq_any(chunk))
1129 .and(app_state_mutation_macs::device_id.eq(device_id)),
1130 ),
1131 )
1132 .execute(conn)?;
1133 }
1134 Ok(())
1135 })
1136 })
1137 .await
1138 }
1139
1140 pub async fn get_app_state_mutation_mac_for_device(
1141 &self,
1142 name: &str,
1143 index_mac: &[u8],
1144 device_id: i32,
1145 ) -> Result<Option<Vec<u8>>> {
1146 let pool = self.pool.clone();
1147 let name = name.to_string();
1148 let index_mac = index_mac.to_vec();
1149 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1150 let mut conn = pool
1151 .get()
1152 .map_err(|e| StoreError::Connection(e.to_string()))?;
1153 let res: Option<Vec<u8>> = app_state_mutation_macs::table
1154 .select(app_state_mutation_macs::value_mac)
1155 .filter(app_state_mutation_macs::name.eq(&name))
1156 .filter(app_state_mutation_macs::index_mac.eq(&index_mac))
1157 .filter(app_state_mutation_macs::device_id.eq(device_id))
1158 .first(&mut conn)
1159 .optional()
1160 .map_err(|e| StoreError::Database(e.to_string()))?;
1161 Ok(res)
1162 })
1163 .await
1164 .map_err(|e| StoreError::Database(e.to_string()))?
1165 }
1166}
1167
1168#[async_trait]
1169impl SignalStore for SqliteStore {
1170 async fn put_identity(&self, address: &str, key: [u8; 32]) -> Result<()> {
1171 self.put_identity_for_device(address, key, self.device_id)
1172 .await
1173 }
1174
1175 async fn load_identity(&self, address: &str) -> Result<Option<Vec<u8>>> {
1176 self.load_identity_for_device(address, self.device_id).await
1177 }
1178
1179 async fn delete_identity(&self, address: &str) -> Result<()> {
1180 self.delete_identity_for_device(address, self.device_id)
1181 .await
1182 }
1183
1184 async fn get_session(&self, address: &str) -> Result<Option<Vec<u8>>> {
1185 self.get_session_for_device(address, self.device_id).await
1186 }
1187
1188 async fn put_session(&self, address: &str, session: &[u8]) -> Result<()> {
1189 self.put_session_for_device(address, session, self.device_id)
1190 .await
1191 }
1192
1193 async fn delete_session(&self, address: &str) -> Result<()> {
1194 self.delete_session_for_device(address, self.device_id)
1195 .await
1196 }
1197
1198 async fn store_prekey(&self, id: u32, record: &[u8], uploaded: bool) -> Result<()> {
1199 let pool = self.pool.clone();
1200 let db_semaphore = self.db_semaphore.clone();
1201 let device_id = self.device_id;
1202 let record = record.to_vec();
1203
1204 const MAX_RETRIES: u32 = 5;
1205
1206 for attempt in 0..=MAX_RETRIES {
1207 let permit =
1208 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1209 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1210 })?;
1211
1212 let pool_clone = pool.clone();
1213 let record_clone = record.clone();
1214
1215 let result =
1216 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1217 let mut conn = pool_clone
1218 .get()
1219 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1220 diesel::insert_into(prekeys::table)
1221 .values((
1222 prekeys::id.eq(id as i32),
1223 prekeys::key.eq(&record_clone),
1224 prekeys::uploaded.eq(uploaded),
1225 prekeys::device_id.eq(device_id),
1226 ))
1227 .on_conflict((prekeys::id, prekeys::device_id))
1228 .do_update()
1229 .set((
1230 prekeys::key.eq(&record_clone),
1231 prekeys::uploaded.eq(uploaded),
1232 ))
1233 .execute(&mut conn)
1234 .map_err(DieselOrStore::Diesel)?;
1235 Ok(())
1236 })
1237 .await;
1238
1239 drop(permit);
1240
1241 match result {
1242 Ok(Ok(())) => return Ok(()),
1243 Ok(Err(DieselOrStore::Diesel(ref e)))
1244 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1245 {
1246 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1247 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1248 }
1249 Ok(Err(e)) => return Err(e.into()),
1250 Err(e) => return Err(StoreError::Database(e.to_string())),
1251 }
1252 }
1253
1254 Err(StoreError::Database(
1255 "store_prekey exhausted retries".to_string(),
1256 ))
1257 }
1258
1259 async fn store_prekeys_batch(&self, keys: &[(u32, Vec<u8>)], uploaded: bool) -> Result<()> {
1260 if keys.is_empty() {
1261 return Ok(());
1262 }
1263
1264 let pool = self.pool.clone();
1265 let db_semaphore = self.db_semaphore.clone();
1266 let device_id = self.device_id;
1267 let keys = keys.to_vec();
1268
1269 const MAX_RETRIES: u32 = 5;
1270
1271 for attempt in 0..=MAX_RETRIES {
1272 let permit =
1273 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1274 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1275 })?;
1276
1277 let pool_clone = pool.clone();
1278 let keys_clone = keys.clone();
1279
1280 let result =
1281 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1282 let mut conn = pool_clone
1283 .get()
1284 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1285
1286 conn.transaction(|conn| {
1287 for (id, record) in &keys_clone {
1288 diesel::insert_into(prekeys::table)
1289 .values((
1290 prekeys::id.eq(*id as i32),
1291 prekeys::key.eq(record),
1292 prekeys::uploaded.eq(uploaded),
1293 prekeys::device_id.eq(device_id),
1294 ))
1295 .on_conflict((prekeys::id, prekeys::device_id))
1296 .do_update()
1297 .set((prekeys::key.eq(record), prekeys::uploaded.eq(uploaded)))
1298 .execute(conn)?;
1299 }
1300 Ok::<(), diesel::result::Error>(())
1301 })
1302 .map_err(DieselOrStore::Diesel)
1303 })
1304 .await;
1305
1306 drop(permit);
1307
1308 match result {
1309 Ok(Ok(())) => return Ok(()),
1310 Ok(Err(DieselOrStore::Diesel(ref e)))
1311 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1312 {
1313 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1314 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1315 }
1316 Ok(Err(e)) => return Err(e.into()),
1317 Err(e) => return Err(StoreError::Database(e.to_string())),
1318 }
1319 }
1320
1321 Err(StoreError::Database(
1322 "store_prekeys_batch exhausted retries".to_string(),
1323 ))
1324 }
1325
1326 async fn load_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1327 let pool = self.pool.clone();
1328 let device_id = self.device_id;
1329 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1330 let mut conn = pool
1331 .get()
1332 .map_err(|e| StoreError::Connection(e.to_string()))?;
1333 let res: Option<Vec<u8>> = prekeys::table
1334 .select(prekeys::key)
1335 .filter(prekeys::id.eq(id as i32))
1336 .filter(prekeys::device_id.eq(device_id))
1337 .first(&mut conn)
1338 .optional()
1339 .map_err(|e| StoreError::Database(e.to_string()))?;
1340 Ok(res)
1341 })
1342 .await
1343 .map_err(|e| StoreError::Database(e.to_string()))?
1344 }
1345
1346 async fn remove_prekey(&self, id: u32) -> Result<()> {
1347 let pool = self.pool.clone();
1348 let db_semaphore = self.db_semaphore.clone();
1349 let device_id = self.device_id;
1350
1351 const MAX_RETRIES: u32 = 5;
1352
1353 for attempt in 0..=MAX_RETRIES {
1354 let permit =
1355 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1356 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1357 })?;
1358
1359 let pool_clone = pool.clone();
1360
1361 let result =
1362 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1363 let mut conn = pool_clone
1364 .get()
1365 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1366 diesel::delete(
1367 prekeys::table
1368 .filter(prekeys::id.eq(id as i32))
1369 .filter(prekeys::device_id.eq(device_id)),
1370 )
1371 .execute(&mut conn)
1372 .map_err(DieselOrStore::Diesel)?;
1373 Ok(())
1374 })
1375 .await;
1376
1377 drop(permit);
1378
1379 match result {
1380 Ok(Ok(())) => return Ok(()),
1381 Ok(Err(DieselOrStore::Diesel(ref e)))
1382 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1383 {
1384 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1385 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1386 }
1387 Ok(Err(e)) => return Err(e.into()),
1388 Err(e) => return Err(StoreError::Database(e.to_string())),
1389 }
1390 }
1391
1392 Err(StoreError::Database(
1393 "remove_prekey exhausted retries".to_string(),
1394 ))
1395 }
1396
1397 async fn get_max_prekey_id(&self) -> Result<u32> {
1398 let pool = self.pool.clone();
1399 let device_id = self.device_id;
1400 let db_semaphore = self.db_semaphore.clone();
1401 let _permit = db_semaphore
1402 .acquire()
1403 .await
1404 .map_err(|e| StoreError::Database(format!("Failed to acquire semaphore: {}", e)))?;
1405
1406 tokio::task::spawn_blocking(move || -> Result<u32> {
1407 let mut conn = pool
1408 .get()
1409 .map_err(|e| StoreError::Connection(e.to_string()))?;
1410 use diesel::dsl::max;
1411 let result: Option<i32> = prekeys::table
1412 .filter(prekeys::device_id.eq(device_id))
1413 .select(max(prekeys::id))
1414 .first(&mut conn)
1415 .map_err(|e| StoreError::Database(e.to_string()))?;
1416 Ok(result.unwrap_or(0) as u32)
1417 })
1418 .await
1419 .map_err(|e| StoreError::Database(e.to_string()))?
1420 }
1421
1422 async fn store_signed_prekey(&self, id: u32, record: &[u8]) -> Result<()> {
1423 let pool = self.pool.clone();
1424 let db_semaphore = self.db_semaphore.clone();
1425 let device_id = self.device_id;
1426 let record = record.to_vec();
1427
1428 const MAX_RETRIES: u32 = 5;
1429
1430 for attempt in 0..=MAX_RETRIES {
1431 let permit =
1432 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1433 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1434 })?;
1435
1436 let pool_clone = pool.clone();
1437 let record_clone = record.clone();
1438
1439 let result =
1440 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1441 let mut conn = pool_clone
1442 .get()
1443 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1444 diesel::insert_into(signed_prekeys::table)
1445 .values((
1446 signed_prekeys::id.eq(id as i32),
1447 signed_prekeys::record.eq(&record_clone),
1448 signed_prekeys::device_id.eq(device_id),
1449 ))
1450 .on_conflict((signed_prekeys::id, signed_prekeys::device_id))
1451 .do_update()
1452 .set(signed_prekeys::record.eq(&record_clone))
1453 .execute(&mut conn)
1454 .map_err(DieselOrStore::Diesel)?;
1455 Ok(())
1456 })
1457 .await;
1458
1459 drop(permit);
1460
1461 match result {
1462 Ok(Ok(())) => return Ok(()),
1463 Ok(Err(DieselOrStore::Diesel(ref e)))
1464 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1465 {
1466 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1467 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1468 }
1469 Ok(Err(e)) => return Err(e.into()),
1470 Err(e) => return Err(StoreError::Database(e.to_string())),
1471 }
1472 }
1473
1474 Err(StoreError::Database(
1475 "store_signed_prekey exhausted retries".to_string(),
1476 ))
1477 }
1478
1479 async fn load_signed_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1480 let pool = self.pool.clone();
1481 let device_id = self.device_id;
1482 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1483 let mut conn = pool
1484 .get()
1485 .map_err(|e| StoreError::Connection(e.to_string()))?;
1486 let res: Option<Vec<u8>> = signed_prekeys::table
1487 .select(signed_prekeys::record)
1488 .filter(signed_prekeys::id.eq(id as i32))
1489 .filter(signed_prekeys::device_id.eq(device_id))
1490 .first(&mut conn)
1491 .optional()
1492 .map_err(|e| StoreError::Database(e.to_string()))?;
1493 Ok(res)
1494 })
1495 .await
1496 .map_err(|e| StoreError::Database(e.to_string()))?
1497 }
1498
1499 async fn load_all_signed_prekeys(&self) -> Result<Vec<(u32, Vec<u8>)>> {
1500 let pool = self.pool.clone();
1501 let device_id = self.device_id;
1502 tokio::task::spawn_blocking(move || -> Result<Vec<(u32, Vec<u8>)>> {
1503 let mut conn = pool
1504 .get()
1505 .map_err(|e| StoreError::Connection(e.to_string()))?;
1506 let results: Vec<(i32, Vec<u8>)> = signed_prekeys::table
1507 .select((signed_prekeys::id, signed_prekeys::record))
1508 .filter(signed_prekeys::device_id.eq(device_id))
1509 .load(&mut conn)
1510 .map_err(|e| StoreError::Database(e.to_string()))?;
1511 Ok(results
1512 .into_iter()
1513 .map(|(id, record)| (id as u32, record))
1514 .collect())
1515 })
1516 .await
1517 .map_err(|e| StoreError::Database(e.to_string()))?
1518 }
1519
1520 async fn remove_signed_prekey(&self, id: u32) -> Result<()> {
1521 let pool = self.pool.clone();
1522 let db_semaphore = self.db_semaphore.clone();
1523 let device_id = self.device_id;
1524
1525 const MAX_RETRIES: u32 = 5;
1526
1527 for attempt in 0..=MAX_RETRIES {
1528 let permit =
1529 db_semaphore.clone().acquire_owned().await.map_err(|e| {
1530 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
1531 })?;
1532
1533 let pool_clone = pool.clone();
1534
1535 let result =
1536 tokio::task::spawn_blocking(move || -> std::result::Result<(), DieselOrStore> {
1537 let mut conn = pool_clone
1538 .get()
1539 .map_err(|e| DieselOrStore::Store(StoreError::Connection(e.to_string())))?;
1540 diesel::delete(
1541 signed_prekeys::table
1542 .filter(signed_prekeys::id.eq(id as i32))
1543 .filter(signed_prekeys::device_id.eq(device_id)),
1544 )
1545 .execute(&mut conn)
1546 .map_err(DieselOrStore::Diesel)?;
1547 Ok(())
1548 })
1549 .await;
1550
1551 drop(permit);
1552
1553 match result {
1554 Ok(Ok(())) => return Ok(()),
1555 Ok(Err(DieselOrStore::Diesel(ref e)))
1556 if is_retriable_sqlite_error(e) && attempt < MAX_RETRIES =>
1557 {
1558 let delay_ms = 10u64 * (1u64 << attempt.min(4));
1559 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
1560 }
1561 Ok(Err(e)) => return Err(e.into()),
1562 Err(e) => return Err(StoreError::Database(e.to_string())),
1563 }
1564 }
1565
1566 Err(StoreError::Database(
1567 "remove_signed_prekey exhausted retries".to_string(),
1568 ))
1569 }
1570
1571 async fn put_sender_key(&self, address: &str, record: &[u8]) -> Result<()> {
1572 self.put_sender_key_for_device(address, record, self.device_id)
1573 .await
1574 }
1575
1576 async fn get_sender_key(&self, address: &str) -> Result<Option<Vec<u8>>> {
1577 self.get_sender_key_for_device(address, self.device_id)
1578 .await
1579 }
1580
1581 async fn delete_sender_key(&self, address: &str) -> Result<()> {
1582 self.delete_sender_key_for_device(address, self.device_id)
1583 .await
1584 }
1585}
1586
1587#[async_trait]
1588impl AppSyncStore for SqliteStore {
1589 async fn get_sync_key(&self, key_id: &[u8]) -> Result<Option<AppStateSyncKey>> {
1590 self.get_app_state_sync_key_for_device(key_id, self.device_id)
1591 .await
1592 }
1593
1594 async fn set_sync_key(&self, key_id: &[u8], key: AppStateSyncKey) -> Result<()> {
1595 self.set_app_state_sync_key_for_device(key_id, key, self.device_id)
1596 .await
1597 }
1598
1599 async fn get_version(&self, name: &str) -> Result<HashState> {
1600 self.get_app_state_version_for_device(name, self.device_id)
1601 .await
1602 }
1603
1604 async fn set_version(&self, name: &str, state: HashState) -> Result<()> {
1605 self.set_app_state_version_for_device(name, state, self.device_id)
1606 .await
1607 }
1608
1609 async fn put_mutation_macs(
1610 &self,
1611 name: &str,
1612 version: u64,
1613 mutations: &[AppStateMutationMAC],
1614 ) -> Result<()> {
1615 self.put_app_state_mutation_macs_for_device(name, version, mutations, self.device_id)
1616 .await
1617 }
1618
1619 async fn get_mutation_mac(&self, name: &str, index_mac: &[u8]) -> Result<Option<Vec<u8>>> {
1620 self.get_app_state_mutation_mac_for_device(name, index_mac, self.device_id)
1621 .await
1622 }
1623
1624 async fn delete_mutation_macs(&self, name: &str, index_macs: &[Vec<u8>]) -> Result<()> {
1625 self.delete_app_state_mutation_macs_for_device(name, index_macs, self.device_id)
1626 .await
1627 }
1628
1629 async fn get_latest_sync_key_id(&self) -> Result<Option<Vec<u8>>> {
1630 self.get_latest_app_state_sync_key_id_for_device(self.device_id)
1631 .await
1632 }
1633}
1634
1635#[async_trait]
1636impl ProtocolStore for SqliteStore {
1637 async fn get_skdm_recipients(&self, group_jid: &str) -> Result<Vec<Jid>> {
1638 let pool = self.pool.clone();
1639 let device_id = self.device_id;
1640 let group_jid = group_jid.to_string();
1641 tokio::task::spawn_blocking(move || -> Result<Vec<Jid>> {
1642 let mut conn = pool
1643 .get()
1644 .map_err(|e| StoreError::Connection(e.to_string()))?;
1645 let recipients: Vec<String> = skdm_recipients::table
1646 .select(skdm_recipients::device_jid)
1647 .filter(skdm_recipients::group_jid.eq(&group_jid))
1648 .filter(skdm_recipients::device_id.eq(device_id))
1649 .load(&mut conn)
1650 .map_err(|e| StoreError::Database(e.to_string()))?;
1651 let jids: Vec<Jid> = recipients
1652 .iter()
1653 .filter_map(|s| match s.parse::<Jid>() {
1654 Ok(jid) => Some(jid),
1655 Err(e) => {
1656 warn!("Failed to parse SKDM recipient '{}': {}", s, e);
1657 None
1658 }
1659 })
1660 .collect();
1661 Ok(jids)
1662 })
1663 .await
1664 .map_err(|e| StoreError::Database(e.to_string()))?
1665 }
1666
1667 async fn add_skdm_recipients(&self, group_jid: &str, device_jids: &[Jid]) -> Result<()> {
1668 if device_jids.is_empty() {
1669 return Ok(());
1670 }
1671 let pool = self.pool.clone();
1672 let device_id = self.device_id;
1673 let group_jid = group_jid.to_string();
1674 let device_jid_strs: Vec<String> = device_jids.iter().map(|j| j.to_string()).collect();
1675 let now = std::time::SystemTime::now()
1676 .duration_since(std::time::UNIX_EPOCH)
1677 .unwrap_or_default()
1678 .as_secs() as i32;
1679 tokio::task::spawn_blocking(move || -> Result<()> {
1680 let mut conn = pool
1681 .get()
1682 .map_err(|e| StoreError::Connection(e.to_string()))?;
1683
1684 let values: Vec<_> = device_jid_strs
1685 .iter()
1686 .map(|device_jid| {
1687 (
1688 skdm_recipients::group_jid.eq(&group_jid),
1689 skdm_recipients::device_jid.eq(device_jid),
1690 skdm_recipients::device_id.eq(device_id),
1691 skdm_recipients::created_at.eq(now),
1692 )
1693 })
1694 .collect();
1695
1696 const CHUNK_SIZE: usize = 200; for chunk in values.chunks(CHUNK_SIZE) {
1699 diesel::insert_into(skdm_recipients::table)
1700 .values(chunk)
1701 .on_conflict((
1702 skdm_recipients::group_jid,
1703 skdm_recipients::device_jid,
1704 skdm_recipients::device_id,
1705 ))
1706 .do_nothing()
1707 .execute(&mut conn)
1708 .map_err(|e| StoreError::Database(e.to_string()))?;
1709 }
1710 Ok(())
1711 })
1712 .await
1713 .map_err(|e| StoreError::Database(e.to_string()))??;
1714 Ok(())
1715 }
1716
1717 async fn clear_skdm_recipients(&self, group_jid: &str) -> Result<()> {
1718 let pool = self.pool.clone();
1719 let device_id = self.device_id;
1720 let group_jid = group_jid.to_string();
1721 tokio::task::spawn_blocking(move || -> Result<()> {
1722 let mut conn = pool
1723 .get()
1724 .map_err(|e| StoreError::Connection(e.to_string()))?;
1725 diesel::delete(
1726 skdm_recipients::table
1727 .filter(skdm_recipients::group_jid.eq(&group_jid))
1728 .filter(skdm_recipients::device_id.eq(device_id)),
1729 )
1730 .execute(&mut conn)
1731 .map_err(|e| StoreError::Database(e.to_string()))?;
1732 Ok(())
1733 })
1734 .await
1735 .map_err(|e| StoreError::Database(e.to_string()))??;
1736 Ok(())
1737 }
1738
1739 async fn get_lid_mapping(&self, lid: &str) -> Result<Option<LidPnMappingEntry>> {
1740 let pool = self.pool.clone();
1741 let device_id = self.device_id;
1742 let lid = lid.to_string();
1743 tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1744 let mut conn = pool
1745 .get()
1746 .map_err(|e| StoreError::Connection(e.to_string()))?;
1747 let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1748 .select((
1749 lid_pn_mapping::lid,
1750 lid_pn_mapping::phone_number,
1751 lid_pn_mapping::created_at,
1752 lid_pn_mapping::learning_source,
1753 lid_pn_mapping::updated_at,
1754 ))
1755 .filter(lid_pn_mapping::lid.eq(&lid))
1756 .filter(lid_pn_mapping::device_id.eq(device_id))
1757 .first(&mut conn)
1758 .optional()
1759 .map_err(|e| StoreError::Database(e.to_string()))?;
1760 Ok(row.map(
1761 |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1762 lid,
1763 phone_number,
1764 created_at,
1765 updated_at,
1766 learning_source,
1767 },
1768 ))
1769 })
1770 .await
1771 .map_err(|e| StoreError::Database(e.to_string()))?
1772 }
1773
1774 async fn get_pn_mapping(&self, phone: &str) -> Result<Option<LidPnMappingEntry>> {
1775 let pool = self.pool.clone();
1776 let device_id = self.device_id;
1777 let phone = phone.to_string();
1778 tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1779 let mut conn = pool
1780 .get()
1781 .map_err(|e| StoreError::Connection(e.to_string()))?;
1782 let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1783 .select((
1784 lid_pn_mapping::lid,
1785 lid_pn_mapping::phone_number,
1786 lid_pn_mapping::created_at,
1787 lid_pn_mapping::learning_source,
1788 lid_pn_mapping::updated_at,
1789 ))
1790 .filter(lid_pn_mapping::phone_number.eq(&phone))
1791 .filter(lid_pn_mapping::device_id.eq(device_id))
1792 .order(lid_pn_mapping::updated_at.desc())
1793 .first(&mut conn)
1794 .optional()
1795 .map_err(|e| StoreError::Database(e.to_string()))?;
1796 Ok(row.map(
1797 |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1798 lid,
1799 phone_number,
1800 created_at,
1801 updated_at,
1802 learning_source,
1803 },
1804 ))
1805 })
1806 .await
1807 .map_err(|e| StoreError::Database(e.to_string()))?
1808 }
1809
1810 async fn put_lid_mapping(&self, entry: &LidPnMappingEntry) -> Result<()> {
1811 let pool = self.pool.clone();
1812 let device_id = self.device_id;
1813 let entry = entry.clone();
1814 tokio::task::spawn_blocking(move || -> Result<()> {
1815 let mut conn = pool
1816 .get()
1817 .map_err(|e| StoreError::Connection(e.to_string()))?;
1818 diesel::insert_into(lid_pn_mapping::table)
1819 .values((
1820 lid_pn_mapping::lid.eq(&entry.lid),
1821 lid_pn_mapping::phone_number.eq(&entry.phone_number),
1822 lid_pn_mapping::created_at.eq(entry.created_at),
1823 lid_pn_mapping::learning_source.eq(&entry.learning_source),
1824 lid_pn_mapping::updated_at.eq(entry.updated_at),
1825 lid_pn_mapping::device_id.eq(device_id),
1826 ))
1827 .on_conflict((lid_pn_mapping::lid, lid_pn_mapping::device_id))
1828 .do_update()
1829 .set((
1830 lid_pn_mapping::phone_number.eq(&entry.phone_number),
1831 lid_pn_mapping::learning_source.eq(&entry.learning_source),
1832 lid_pn_mapping::updated_at.eq(entry.updated_at),
1833 ))
1834 .execute(&mut conn)
1835 .map_err(|e| StoreError::Database(e.to_string()))?;
1836 Ok(())
1837 })
1838 .await
1839 .map_err(|e| StoreError::Database(e.to_string()))??;
1840 Ok(())
1841 }
1842
1843 async fn get_all_lid_mappings(&self) -> Result<Vec<LidPnMappingEntry>> {
1844 let pool = self.pool.clone();
1845 let device_id = self.device_id;
1846 tokio::task::spawn_blocking(move || -> Result<Vec<LidPnMappingEntry>> {
1847 let mut conn = pool
1848 .get()
1849 .map_err(|e| StoreError::Connection(e.to_string()))?;
1850 let rows: Vec<(String, String, i64, String, i64)> = lid_pn_mapping::table
1851 .select((
1852 lid_pn_mapping::lid,
1853 lid_pn_mapping::phone_number,
1854 lid_pn_mapping::created_at,
1855 lid_pn_mapping::learning_source,
1856 lid_pn_mapping::updated_at,
1857 ))
1858 .filter(lid_pn_mapping::device_id.eq(device_id))
1859 .load(&mut conn)
1860 .map_err(|e| StoreError::Database(e.to_string()))?;
1861 Ok(rows
1862 .into_iter()
1863 .map(
1864 |(lid, phone_number, created_at, learning_source, updated_at)| {
1865 LidPnMappingEntry {
1866 lid,
1867 phone_number,
1868 created_at,
1869 updated_at,
1870 learning_source,
1871 }
1872 },
1873 )
1874 .collect())
1875 })
1876 .await
1877 .map_err(|e| StoreError::Database(e.to_string()))?
1878 }
1879
1880 async fn save_base_key(&self, address: &str, message_id: &str, base_key: &[u8]) -> Result<()> {
1881 let pool = self.pool.clone();
1882 let device_id = self.device_id;
1883 let address = address.to_string();
1884 let message_id = message_id.to_string();
1885 let base_key = base_key.to_vec();
1886 let now = std::time::SystemTime::now()
1887 .duration_since(std::time::UNIX_EPOCH)
1888 .unwrap_or_default()
1889 .as_secs() as i32;
1890 tokio::task::spawn_blocking(move || -> Result<()> {
1891 let mut conn = pool
1892 .get()
1893 .map_err(|e| StoreError::Connection(e.to_string()))?;
1894 diesel::insert_into(base_keys::table)
1895 .values((
1896 base_keys::address.eq(&address),
1897 base_keys::message_id.eq(&message_id),
1898 base_keys::base_key.eq(&base_key),
1899 base_keys::device_id.eq(device_id),
1900 base_keys::created_at.eq(now),
1901 ))
1902 .on_conflict((
1903 base_keys::address,
1904 base_keys::message_id,
1905 base_keys::device_id,
1906 ))
1907 .do_update()
1908 .set(base_keys::base_key.eq(&base_key))
1909 .execute(&mut conn)
1910 .map_err(|e| StoreError::Database(e.to_string()))?;
1911 Ok(())
1912 })
1913 .await
1914 .map_err(|e| StoreError::Database(e.to_string()))??;
1915 Ok(())
1916 }
1917
1918 async fn has_same_base_key(
1919 &self,
1920 address: &str,
1921 message_id: &str,
1922 current_base_key: &[u8],
1923 ) -> Result<bool> {
1924 let pool = self.pool.clone();
1925 let device_id = self.device_id;
1926 let address = address.to_string();
1927 let message_id = message_id.to_string();
1928 let current_base_key = current_base_key.to_vec();
1929 tokio::task::spawn_blocking(move || -> Result<bool> {
1930 let mut conn = pool
1931 .get()
1932 .map_err(|e| StoreError::Connection(e.to_string()))?;
1933 let stored_key: Option<Vec<u8>> = base_keys::table
1934 .select(base_keys::base_key)
1935 .filter(base_keys::address.eq(&address))
1936 .filter(base_keys::message_id.eq(&message_id))
1937 .filter(base_keys::device_id.eq(device_id))
1938 .first(&mut conn)
1939 .optional()
1940 .map_err(|e| StoreError::Database(e.to_string()))?;
1941 Ok(stored_key.as_ref() == Some(¤t_base_key))
1942 })
1943 .await
1944 .map_err(|e| StoreError::Database(e.to_string()))?
1945 }
1946
1947 async fn delete_base_key(&self, address: &str, message_id: &str) -> Result<()> {
1948 let pool = self.pool.clone();
1949 let device_id = self.device_id;
1950 let address = address.to_string();
1951 let message_id = message_id.to_string();
1952 tokio::task::spawn_blocking(move || -> Result<()> {
1953 let mut conn = pool
1954 .get()
1955 .map_err(|e| StoreError::Connection(e.to_string()))?;
1956 diesel::delete(
1957 base_keys::table
1958 .filter(base_keys::address.eq(&address))
1959 .filter(base_keys::message_id.eq(&message_id))
1960 .filter(base_keys::device_id.eq(device_id)),
1961 )
1962 .execute(&mut conn)
1963 .map_err(|e| StoreError::Database(e.to_string()))?;
1964 Ok(())
1965 })
1966 .await
1967 .map_err(|e| StoreError::Database(e.to_string()))??;
1968 Ok(())
1969 }
1970
1971 async fn update_device_list(&self, record: DeviceListRecord) -> Result<()> {
1972 let pool = self.pool.clone();
1973 let device_id = self.device_id;
1974 let devices_json = serde_json::to_string(&record.devices)
1975 .map_err(|e| StoreError::Serialization(e.to_string()))?;
1976 let now = std::time::SystemTime::now()
1977 .duration_since(std::time::UNIX_EPOCH)
1978 .unwrap_or_default()
1979 .as_secs() as i32;
1980 tokio::task::spawn_blocking(move || -> Result<()> {
1981 let mut conn = pool
1982 .get()
1983 .map_err(|e| StoreError::Connection(e.to_string()))?;
1984 diesel::insert_into(device_registry::table)
1985 .values((
1986 device_registry::user_id.eq(&record.user),
1987 device_registry::devices_json.eq(&devices_json),
1988 device_registry::timestamp.eq(record.timestamp as i32),
1989 device_registry::phash.eq(&record.phash),
1990 device_registry::device_id.eq(device_id),
1991 device_registry::updated_at.eq(now),
1992 ))
1993 .on_conflict((device_registry::user_id, device_registry::device_id))
1994 .do_update()
1995 .set((
1996 device_registry::devices_json.eq(&devices_json),
1997 device_registry::timestamp.eq(record.timestamp as i32),
1998 device_registry::phash.eq(&record.phash),
1999 device_registry::updated_at.eq(now),
2000 ))
2001 .execute(&mut conn)
2002 .map_err(|e| StoreError::Database(e.to_string()))?;
2003 Ok(())
2004 })
2005 .await
2006 .map_err(|e| StoreError::Database(e.to_string()))??;
2007 Ok(())
2008 }
2009
2010 async fn get_devices(&self, user: &str) -> Result<Option<DeviceListRecord>> {
2011 let pool = self.pool.clone();
2012 let device_id = self.device_id;
2013 let user = user.to_string();
2014 tokio::task::spawn_blocking(move || -> Result<Option<DeviceListRecord>> {
2015 let mut conn = pool
2016 .get()
2017 .map_err(|e| StoreError::Connection(e.to_string()))?;
2018 let row: Option<(String, String, i32, Option<String>)> = device_registry::table
2019 .select((
2020 device_registry::user_id,
2021 device_registry::devices_json,
2022 device_registry::timestamp,
2023 device_registry::phash,
2024 ))
2025 .filter(device_registry::user_id.eq(&user))
2026 .filter(device_registry::device_id.eq(device_id))
2027 .first(&mut conn)
2028 .optional()
2029 .map_err(|e| StoreError::Database(e.to_string()))?;
2030 match row {
2031 Some((user, devices_json, timestamp, phash)) => {
2032 let devices: Vec<DeviceInfo> = serde_json::from_str(&devices_json)
2033 .map_err(|e| StoreError::Serialization(e.to_string()))?;
2034 Ok(Some(DeviceListRecord {
2035 user,
2036 devices,
2037 timestamp: timestamp as i64,
2038 phash,
2039 }))
2040 }
2041 None => Ok(None),
2042 }
2043 })
2044 .await
2045 .map_err(|e| StoreError::Database(e.to_string()))?
2046 }
2047
2048 async fn mark_forget_sender_key(&self, group_jid: &str, participant: &str) -> Result<()> {
2049 let pool = self.pool.clone();
2050 let device_id = self.device_id;
2051 let group_jid = group_jid.to_string();
2052 let participant = participant.to_string();
2053 let now = std::time::SystemTime::now()
2054 .duration_since(std::time::UNIX_EPOCH)
2055 .unwrap_or_default()
2056 .as_secs() as i32;
2057 tokio::task::spawn_blocking(move || -> Result<()> {
2058 let mut conn = pool
2059 .get()
2060 .map_err(|e| StoreError::Connection(e.to_string()))?;
2061 diesel::insert_into(sender_key_status::table)
2062 .values((
2063 sender_key_status::group_jid.eq(&group_jid),
2064 sender_key_status::participant.eq(&participant),
2065 sender_key_status::device_id.eq(device_id),
2066 sender_key_status::marked_at.eq(now),
2067 ))
2068 .on_conflict((
2069 sender_key_status::group_jid,
2070 sender_key_status::participant,
2071 sender_key_status::device_id,
2072 ))
2073 .do_update()
2074 .set(sender_key_status::marked_at.eq(now))
2075 .execute(&mut conn)
2076 .map_err(|e| StoreError::Database(e.to_string()))?;
2077 Ok(())
2078 })
2079 .await
2080 .map_err(|e| StoreError::Database(e.to_string()))??;
2081 Ok(())
2082 }
2083
2084 async fn consume_forget_marks(&self, group_jid: &str) -> Result<Vec<String>> {
2085 let pool = self.pool.clone();
2086 let device_id = self.device_id;
2087 let group_jid = group_jid.to_string();
2088 tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
2089 let mut conn = pool
2090 .get()
2091 .map_err(|e| StoreError::Connection(e.to_string()))?;
2092 let participants: Vec<String> = sender_key_status::table
2093 .select(sender_key_status::participant)
2094 .filter(sender_key_status::group_jid.eq(&group_jid))
2095 .filter(sender_key_status::device_id.eq(device_id))
2096 .load(&mut conn)
2097 .map_err(|e| StoreError::Database(e.to_string()))?;
2098 diesel::delete(
2099 sender_key_status::table
2100 .filter(sender_key_status::group_jid.eq(&group_jid))
2101 .filter(sender_key_status::device_id.eq(device_id)),
2102 )
2103 .execute(&mut conn)
2104 .map_err(|e| StoreError::Database(e.to_string()))?;
2105 Ok(participants)
2106 })
2107 .await
2108 .map_err(|e| StoreError::Database(e.to_string()))?
2109 }
2110
2111 async fn get_tc_token(&self, jid: &str) -> Result<Option<TcTokenEntry>> {
2112 let pool = self.pool.clone();
2113 let device_id = self.device_id;
2114 let jid = jid.to_string();
2115 tokio::task::spawn_blocking(move || -> Result<Option<TcTokenEntry>> {
2116 let mut conn = pool
2117 .get()
2118 .map_err(|e| StoreError::Connection(e.to_string()))?;
2119 let row: Option<(Vec<u8>, i64, Option<i64>)> = tc_tokens::table
2120 .select((
2121 tc_tokens::token,
2122 tc_tokens::token_timestamp,
2123 tc_tokens::sender_timestamp,
2124 ))
2125 .filter(tc_tokens::jid.eq(&jid))
2126 .filter(tc_tokens::device_id.eq(device_id))
2127 .first(&mut conn)
2128 .optional()
2129 .map_err(|e| StoreError::Database(e.to_string()))?;
2130 Ok(
2131 row.map(|(token, token_timestamp, sender_timestamp)| TcTokenEntry {
2132 token,
2133 token_timestamp,
2134 sender_timestamp,
2135 }),
2136 )
2137 })
2138 .await
2139 .map_err(|e| StoreError::Database(e.to_string()))?
2140 }
2141
2142 async fn put_tc_token(&self, jid: &str, entry: &TcTokenEntry) -> Result<()> {
2143 let pool = self.pool.clone();
2144 let device_id = self.device_id;
2145 let jid = jid.to_string();
2146 let entry = entry.clone();
2147 let now = std::time::SystemTime::now()
2148 .duration_since(std::time::UNIX_EPOCH)
2149 .unwrap_or_default()
2150 .as_secs() as i64;
2151 tokio::task::spawn_blocking(move || -> Result<()> {
2152 let mut conn = pool
2153 .get()
2154 .map_err(|e| StoreError::Connection(e.to_string()))?;
2155 diesel::insert_into(tc_tokens::table)
2156 .values((
2157 tc_tokens::jid.eq(&jid),
2158 tc_tokens::token.eq(&entry.token),
2159 tc_tokens::token_timestamp.eq(entry.token_timestamp),
2160 tc_tokens::sender_timestamp.eq(entry.sender_timestamp),
2161 tc_tokens::device_id.eq(device_id),
2162 tc_tokens::updated_at.eq(now),
2163 ))
2164 .on_conflict((tc_tokens::jid, tc_tokens::device_id))
2165 .do_update()
2166 .set((
2167 tc_tokens::token.eq(&entry.token),
2168 tc_tokens::token_timestamp.eq(entry.token_timestamp),
2169 tc_tokens::sender_timestamp.eq(entry.sender_timestamp),
2170 tc_tokens::updated_at.eq(now),
2171 ))
2172 .execute(&mut conn)
2173 .map_err(|e| StoreError::Database(e.to_string()))?;
2174 Ok(())
2175 })
2176 .await
2177 .map_err(|e| StoreError::Database(e.to_string()))??;
2178 Ok(())
2179 }
2180
2181 async fn delete_tc_token(&self, jid: &str) -> Result<()> {
2182 let pool = self.pool.clone();
2183 let device_id = self.device_id;
2184 let jid = jid.to_string();
2185 tokio::task::spawn_blocking(move || -> Result<()> {
2186 let mut conn = pool
2187 .get()
2188 .map_err(|e| StoreError::Connection(e.to_string()))?;
2189 diesel::delete(
2190 tc_tokens::table
2191 .filter(tc_tokens::jid.eq(&jid))
2192 .filter(tc_tokens::device_id.eq(device_id)),
2193 )
2194 .execute(&mut conn)
2195 .map_err(|e| StoreError::Database(e.to_string()))?;
2196 Ok(())
2197 })
2198 .await
2199 .map_err(|e| StoreError::Database(e.to_string()))??;
2200 Ok(())
2201 }
2202
2203 async fn get_all_tc_token_jids(&self) -> Result<Vec<String>> {
2204 let pool = self.pool.clone();
2205 let device_id = self.device_id;
2206 tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
2207 let mut conn = pool
2208 .get()
2209 .map_err(|e| StoreError::Connection(e.to_string()))?;
2210 let jids: Vec<String> = tc_tokens::table
2211 .select(tc_tokens::jid)
2212 .filter(tc_tokens::device_id.eq(device_id))
2213 .load(&mut conn)
2214 .map_err(|e| StoreError::Database(e.to_string()))?;
2215 Ok(jids)
2216 })
2217 .await
2218 .map_err(|e| StoreError::Database(e.to_string()))?
2219 }
2220
2221 async fn delete_expired_tc_tokens(&self, cutoff_timestamp: i64) -> Result<u32> {
2222 let pool = self.pool.clone();
2223 let device_id = self.device_id;
2224 tokio::task::spawn_blocking(move || -> Result<u32> {
2225 let mut conn = pool
2226 .get()
2227 .map_err(|e| StoreError::Connection(e.to_string()))?;
2228 let deleted = diesel::delete(
2229 tc_tokens::table
2230 .filter(tc_tokens::token_timestamp.lt(cutoff_timestamp))
2231 .filter(tc_tokens::device_id.eq(device_id)),
2232 )
2233 .execute(&mut conn)
2234 .map_err(|e| StoreError::Database(e.to_string()))?;
2235 Ok(deleted as u32)
2236 })
2237 .await
2238 .map_err(|e| StoreError::Database(e.to_string()))?
2239 }
2240}
2241
2242#[async_trait]
2243impl DeviceStore for SqliteStore {
2244 async fn save(&self, device: &CoreDevice) -> Result<()> {
2245 SqliteStore::save_device_data_for_device(self, self.device_id, device).await
2246 }
2247
2248 async fn load(&self) -> Result<Option<CoreDevice>> {
2249 SqliteStore::load_device_data_for_device(self, self.device_id).await
2250 }
2251
2252 async fn exists(&self) -> Result<bool> {
2253 SqliteStore::device_exists(self, self.device_id).await
2254 }
2255
2256 async fn create(&self) -> Result<i32> {
2257 SqliteStore::create_new_device(self).await
2258 }
2259
2260 async fn snapshot_db(&self, name: &str, extra_content: Option<&[u8]>) -> Result<()> {
2261 fn sanitize_snapshot_name(name: &str) -> Result<String> {
2262 const MAX_LENGTH: usize = 100;
2263
2264 let sanitized: String = name
2265 .chars()
2266 .map(|c| {
2267 if c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.' {
2268 c
2269 } else {
2270 '_'
2271 }
2272 })
2273 .collect();
2274
2275 let sanitized = sanitized
2276 .split('.')
2277 .filter(|part| !part.is_empty() && *part != "..")
2278 .collect::<Vec<_>>()
2279 .join(".");
2280
2281 let sanitized = sanitized.trim_matches(['/', '\\', '.']);
2282
2283 if sanitized.is_empty() {
2284 return Err(StoreError::Database(
2285 "Snapshot name cannot be empty after sanitization".to_string(),
2286 ));
2287 }
2288
2289 if sanitized.len() > MAX_LENGTH {
2290 return Err(StoreError::Database(format!(
2291 "Snapshot name exceeds maximum length of {} characters",
2292 MAX_LENGTH
2293 )));
2294 }
2295
2296 Ok(sanitized.to_string())
2297 }
2298
2299 let sanitized_name = sanitize_snapshot_name(name)?;
2300
2301 let pool = self.pool.clone();
2302 let db_path = self.database_path.clone();
2303 let extra_data = extra_content.map(|b| b.to_vec());
2304
2305 tokio::task::spawn_blocking(move || -> Result<()> {
2306 let mut conn = pool
2307 .get()
2308 .map_err(|e| StoreError::Connection(e.to_string()))?;
2309
2310 let timestamp = std::time::SystemTime::now()
2311 .duration_since(std::time::UNIX_EPOCH)
2312 .unwrap_or_default()
2313 .as_secs();
2314
2315 let target_path = format!("{}.snapshot-{}-{}", db_path, timestamp, sanitized_name);
2317
2318 let query = format!("VACUUM INTO '{}'", target_path.replace("'", "''"));
2321
2322 diesel::sql_query(query)
2323 .execute(&mut conn)
2324 .map_err(|e| StoreError::Database(e.to_string()))?;
2325
2326 if let Some(data) = extra_data {
2328 let extra_path = format!("{}.json", target_path);
2329 std::fs::write(&extra_path, data).map_err(|e| {
2330 StoreError::Database(format!("Failed to write snapshot extra content: {}", e))
2331 })?;
2332 }
2333
2334 Ok(())
2335 })
2336 .await
2337 .map_err(|e| StoreError::Database(e.to_string()))??;
2338
2339 Ok(())
2340 }
2341}
2342
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store backed by a fresh shared-cache in-memory database.
    ///
    /// The process id plus a per-call counter make every database name unique,
    /// so tests never observe each other's rows.
    async fn create_test_store() -> SqliteStore {
        use std::sync::atomic::{AtomicU64, Ordering};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let id = COUNTER.fetch_add(1, Ordering::Relaxed);
        let db_name = format!(
            "file:memdb_test_{}_{}?mode=memory&cache=shared",
            std::process::id(),
            id
        );
        SqliteStore::new(&db_name)
            .await
            .expect("Failed to create test store")
    }

    #[test]
    fn test_parse_database_path_regular_path() {
        let path = "/var/lib/whatsapp/database.db";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "/var/lib/whatsapp/database.db");
    }

    #[test]
    fn test_parse_database_path_with_sqlite_prefix() {
        let path = "sqlite:///var/lib/whatsapp/database.db";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "/var/lib/whatsapp/database.db");
    }

    #[test]
    fn test_parse_database_path_with_query_params() {
        let path = "file:database.db?mode=memory&cache=shared";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "file:database.db");
    }

    #[test]
    fn test_parse_database_path_with_fragment() {
        let path = "file:database.db#fragment";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "file:database.db");
    }

    #[test]
    fn test_parse_database_path_with_both_query_and_fragment() {
        let path = "sqlite:///var/lib/database.db?mode=ro#backup";
        let result = parse_database_path(path).unwrap();
        assert_eq!(result, "/var/lib/database.db");
    }

    #[test]
    fn test_parse_database_path_in_memory_rejected() {
        let result = parse_database_path(":memory:");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not supported"));
    }

    #[test]
    fn test_parse_database_path_in_memory_with_query_rejected() {
        let result = parse_database_path(":memory:?cache=shared");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not supported"));
    }

    #[tokio::test]
    async fn test_device_registry_save_and_get() {
        let store = create_test_store().await;

        let record = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![
                DeviceInfo {
                    device_id: 0,
                    key_index: None,
                },
                DeviceInfo {
                    device_id: 1,
                    key_index: Some(42),
                },
            ],
            timestamp: 1234567890,
            phash: Some("2:abcdef".to_string()),
        };

        store.update_device_list(record).await.expect("save failed");
        let loaded = store
            .get_devices("1234567890")
            .await
            .expect("get failed")
            .expect("record should exist");

        assert_eq!(loaded.user, "1234567890");
        assert_eq!(loaded.devices.len(), 2);
        assert_eq!(loaded.devices[0].device_id, 0);
        assert_eq!(loaded.devices[1].device_id, 1);
        assert_eq!(loaded.devices[1].key_index, Some(42));
        assert_eq!(loaded.phash, Some("2:abcdef".to_string()));
    }

    #[tokio::test]
    async fn test_device_registry_update_existing() {
        let store = create_test_store().await;

        let record1 = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![DeviceInfo {
                device_id: 0,
                key_index: None,
            }],
            timestamp: 1000,
            phash: Some("2:old".to_string()),
        };
        store
            .update_device_list(record1)
            .await
            .expect("save1 failed");

        let record2 = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![
                DeviceInfo {
                    device_id: 0,
                    key_index: None,
                },
                DeviceInfo {
                    device_id: 2,
                    key_index: None,
                },
            ],
            timestamp: 2000,
            phash: Some("2:new".to_string()),
        };
        store
            .update_device_list(record2)
            .await
            .expect("save2 failed");

        let loaded = store
            .get_devices("1234567890")
            .await
            .expect("get failed")
            .expect("record should exist");

        assert_eq!(loaded.devices.len(), 2);
        assert_eq!(loaded.phash, Some("2:new".to_string()));
    }

    #[tokio::test]
    async fn test_device_registry_get_nonexistent() {
        let store = create_test_store().await;
        let result = store.get_devices("nonexistent").await.expect("get failed");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_sender_key_status_mark_and_consume() {
        let store = create_test_store().await;

        let group = "group123@g.us";
        let participant = "user1@s.whatsapp.net";

        store
            .mark_forget_sender_key(group, participant)
            .await
            .expect("mark failed");

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed.len(), 1);
        assert!(consumed.contains(&participant.to_string()));

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(consumed.is_empty());
    }

    #[tokio::test]
    async fn test_sender_key_status_consume_multiple() {
        let store = create_test_store().await;

        let group = "group123@g.us";

        store
            .mark_forget_sender_key(group, "user1@s.whatsapp.net")
            .await
            .expect("mark failed");
        store
            .mark_forget_sender_key(group, "user2@s.whatsapp.net")
            .await
            .expect("mark failed");

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed.len(), 2);
        assert!(consumed.contains(&"user1@s.whatsapp.net".to_string()));
        assert!(consumed.contains(&"user2@s.whatsapp.net".to_string()));

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(consumed.is_empty());
    }

    /// Marking the same participant twice must hit the ON CONFLICT upsert path
    /// and still yield exactly one consumable mark.
    #[tokio::test]
    async fn test_sender_key_status_mark_twice_yields_single_entry() {
        let store = create_test_store().await;

        let group = "group123@g.us";
        let participant = "user1@s.whatsapp.net";

        store
            .mark_forget_sender_key(group, participant)
            .await
            .expect("first mark failed");
        store
            .mark_forget_sender_key(group, participant)
            .await
            .expect("second mark failed");

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed, vec![participant.to_string()]);
    }

    #[tokio::test]
    async fn test_tc_token_put_and_get() {
        let store = create_test_store().await;

        let entry = TcTokenEntry {
            token: vec![1, 2, 3, 4, 5],
            token_timestamp: 1707000000,
            sender_timestamp: Some(1707000100),
        };

        store
            .put_tc_token("user@lid", &entry)
            .await
            .expect("put failed");

        let loaded = store
            .get_tc_token("user@lid")
            .await
            .expect("get failed")
            .expect("should exist");

        assert_eq!(loaded.token, vec![1, 2, 3, 4, 5]);
        assert_eq!(loaded.token_timestamp, 1707000000);
        assert_eq!(loaded.sender_timestamp, Some(1707000100));
    }

    #[tokio::test]
    async fn test_tc_token_upsert() {
        let store = create_test_store().await;

        let entry1 = TcTokenEntry {
            token: vec![1, 2, 3],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        store.put_tc_token("user@lid", &entry1).await.unwrap();

        let entry2 = TcTokenEntry {
            token: vec![4, 5, 6],
            token_timestamp: 2000,
            sender_timestamp: Some(1500),
        };
        store.put_tc_token("user@lid", &entry2).await.unwrap();

        let loaded = store.get_tc_token("user@lid").await.unwrap().unwrap();
        assert_eq!(loaded.token, vec![4, 5, 6]);
        assert_eq!(loaded.token_timestamp, 2000);
        assert_eq!(loaded.sender_timestamp, Some(1500));
    }

    #[tokio::test]
    async fn test_tc_token_delete() {
        let store = create_test_store().await;

        let entry = TcTokenEntry {
            token: vec![1, 2, 3],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        store.put_tc_token("user@lid", &entry).await.unwrap();
        store.delete_tc_token("user@lid").await.unwrap();

        let result = store.get_tc_token("user@lid").await.unwrap();
        assert!(result.is_none());
    }

    /// Deleting a token that was never stored is a no-op, not an error.
    #[tokio::test]
    async fn test_tc_token_delete_nonexistent_is_ok() {
        let store = create_test_store().await;
        store
            .delete_tc_token("missing@lid")
            .await
            .expect("deleting a missing token should succeed");
        assert!(store.get_tc_token("missing@lid").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_tc_token_get_all_jids() {
        let store = create_test_store().await;

        let entry = TcTokenEntry {
            token: vec![1],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        store.put_tc_token("user1@lid", &entry).await.unwrap();
        store.put_tc_token("user2@lid", &entry).await.unwrap();
        store.put_tc_token("user3@lid", &entry).await.unwrap();

        let mut jids = store.get_all_tc_token_jids().await.unwrap();
        jids.sort();
        assert_eq!(jids, vec!["user1@lid", "user2@lid", "user3@lid"]);
    }

    #[tokio::test]
    async fn test_tc_token_delete_expired() {
        let store = create_test_store().await;

        let old = TcTokenEntry {
            token: vec![1],
            token_timestamp: 1000,
            sender_timestamp: None,
        };
        let recent = TcTokenEntry {
            token: vec![2],
            token_timestamp: 5000,
            sender_timestamp: None,
        };
        store.put_tc_token("old@lid", &old).await.unwrap();
        store.put_tc_token("recent@lid", &recent).await.unwrap();

        let deleted = store.delete_expired_tc_tokens(3000).await.unwrap();
        assert_eq!(deleted, 1);

        assert!(store.get_tc_token("old@lid").await.unwrap().is_none());
        assert!(store.get_tc_token("recent@lid").await.unwrap().is_some());
    }

    /// The expiry filter is strict (`<`): a token stamped exactly at the
    /// cutoff must survive.
    #[tokio::test]
    async fn test_tc_token_delete_expired_keeps_cutoff_boundary() {
        let store = create_test_store().await;

        let at_cutoff = TcTokenEntry {
            token: vec![7],
            token_timestamp: 3000,
            sender_timestamp: None,
        };
        store.put_tc_token("edge@lid", &at_cutoff).await.unwrap();

        let deleted = store.delete_expired_tc_tokens(3000).await.unwrap();
        assert_eq!(deleted, 0);
        assert!(store.get_tc_token("edge@lid").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_tc_token_get_nonexistent() {
        let store = create_test_store().await;
        let result = store.get_tc_token("nonexistent@lid").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_sender_key_status_different_groups() {
        let store = create_test_store().await;

        let group1 = "group1@g.us";
        let group2 = "group2@g.us";
        let participant = "user@s.whatsapp.net";

        store
            .mark_forget_sender_key(group1, participant)
            .await
            .expect("mark failed");

        let consumed = store.consume_forget_marks(group1).await.unwrap();
        assert_eq!(consumed.len(), 1);

        let consumed = store.consume_forget_marks(group2).await.unwrap();
        assert!(consumed.is_empty());
    }
}