1use crate::schema::*;
2use async_trait::async_trait;
3use diesel::prelude::*;
4use diesel::r2d2::{ConnectionManager, Pool};
5use diesel::sql_query;
6use diesel::sqlite::SqliteConnection;
7use diesel::upsert::excluded;
8use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
9use log::warn;
10use prost::Message;
11use std::sync::Arc;
12use wacore::appstate::hash::HashState;
13use wacore::appstate::processor::AppStateMutationMAC;
14use wacore::libsignal::protocol::{KeyPair, PrivateKey, PublicKey};
15use wacore::store::Device as CoreDevice;
16use wacore::store::error::{Result, StoreError};
17use wacore::store::traits::*;
18use wacore_binary::jid::Jid;
19use waproto::whatsapp as wa;
20
21const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
22
23type SqlitePool = Pool<ConnectionManager<SqliteConnection>>;
/// Tuple shape of one row read from the `device` table.
///
/// The element order matches the destructuring performed in
/// `load_device_data_for_device`; the trailing comments name the binding each
/// element is given there.
type DeviceRow = (
    i32,             // device row id
    String,          // lid (an empty string is loaded as `None`)
    String,          // pn (an empty string is loaded as `None`)
    i32,             // registration_id
    Vec<u8>,         // noise_key (64-byte serialized key pair)
    Vec<u8>,         // identity_key (64-byte serialized key pair)
    Vec<u8>,         // signed_pre_key (64-byte serialized key pair)
    i32,             // signed_pre_key_id
    Vec<u8>,         // signed_pre_key_signature (must be 64 bytes on load)
    Vec<u8>,         // adv_secret_key (must be 32 bytes on load)
    Option<Vec<u8>>, // account (protobuf-encoded, optional)
    String,          // push_name
    i32,             // app_version_primary
    i32,             // app_version_secondary
    i64,             // app_version_tertiary
    i64,             // app_version_last_fetched_ms
    Option<Vec<u8>>, // edge_routing_info
    Option<String>,  // props_hash
);
44
45#[derive(Clone)]
46pub struct SqliteStore {
47 pub(crate) pool: SqlitePool,
48 pub(crate) db_semaphore: Arc<tokio::sync::Semaphore>,
49 pub(crate) database_path: String,
50 device_id: i32,
51}
52
53#[derive(Debug, Clone, Copy)]
54struct ConnectionOptions;
55
56impl diesel::r2d2::CustomizeConnection<SqliteConnection, diesel::r2d2::Error>
57 for ConnectionOptions
58{
59 fn on_acquire(
60 &self,
61 conn: &mut SqliteConnection,
62 ) -> std::result::Result<(), diesel::r2d2::Error> {
63 diesel::sql_query("PRAGMA busy_timeout = 30000;")
64 .execute(conn)
65 .map_err(diesel::r2d2::Error::QueryError)?;
66 diesel::sql_query("PRAGMA synchronous = NORMAL;")
67 .execute(conn)
68 .map_err(diesel::r2d2::Error::QueryError)?;
69 diesel::sql_query("PRAGMA cache_size = 512;")
70 .execute(conn)
71 .map_err(diesel::r2d2::Error::QueryError)?;
72 diesel::sql_query("PRAGMA temp_store = memory;")
73 .execute(conn)
74 .map_err(diesel::r2d2::Error::QueryError)?;
75 diesel::sql_query("PRAGMA foreign_keys = ON;")
76 .execute(conn)
77 .map_err(diesel::r2d2::Error::QueryError)?;
78 Ok(())
79 }
80}
81
82fn parse_database_path(database_url: &str) -> Result<String> {
83 if database_url == ":memory:" {
85 return Err(StoreError::Database(
86 "Snapshot not supported for in-memory databases".to_string(),
87 ));
88 }
89
90 let path = database_url
92 .split(['?', '#'])
93 .next()
94 .unwrap_or(database_url);
95
96 let path = path.trim_start_matches("sqlite://");
98
99 if path == ":memory:" || path.starts_with(":memory:?") {
101 return Err(StoreError::Database(
102 "Snapshot not supported for in-memory databases".to_string(),
103 ));
104 }
105
106 Ok(path.to_string())
107}
108
109impl SqliteStore {
110 pub async fn new(database_url: &str) -> std::result::Result<Self, StoreError> {
111 let manager = ConnectionManager::<SqliteConnection>::new(database_url);
112
113 let pool_size = 2;
114
115 let pool = Pool::builder()
116 .max_size(pool_size)
117 .connection_customizer(Box::new(ConnectionOptions))
118 .build(manager)
119 .map_err(|e| StoreError::Connection(e.to_string()))?;
120
121 let pool_clone = pool.clone();
122 tokio::task::spawn_blocking(move || -> std::result::Result<(), StoreError> {
123 let mut conn = pool_clone
124 .get()
125 .map_err(|e| StoreError::Connection(e.to_string()))?;
126
127 diesel::sql_query("PRAGMA journal_mode = WAL;")
128 .execute(&mut conn)
129 .map_err(|e| StoreError::Database(e.to_string()))?;
130
131 conn.run_pending_migrations(MIGRATIONS)
132 .map_err(|e| StoreError::Migration(e.to_string()))?;
133
134 Ok(())
135 })
136 .await
137 .map_err(|e| StoreError::Database(e.to_string()))??;
138
139 let database_path = parse_database_path(database_url)?;
140
141 Ok(Self {
142 pool,
143 db_semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
144 database_path,
145 device_id: 1,
146 })
147 }
148
149 pub async fn new_for_device(
150 database_url: &str,
151 device_id: i32,
152 ) -> std::result::Result<Self, StoreError> {
153 let mut store = Self::new(database_url).await?;
154 store.device_id = device_id;
155 Ok(store)
156 }
157
158 pub fn device_id(&self) -> i32 {
159 self.device_id
160 }
161
162 async fn with_semaphore<F, T>(&self, f: F) -> Result<T>
163 where
164 F: FnOnce() -> Result<T> + Send + 'static,
165 T: Send + 'static,
166 {
167 let permit = self
168 .db_semaphore
169 .clone()
170 .acquire_owned()
171 .await
172 .map_err(|e| StoreError::Database(format!("Semaphore error: {}", e)))?;
173 let result = tokio::task::spawn_blocking(move || {
174 let res = f();
175 drop(permit);
176 res
177 })
178 .await
179 .map_err(|e| StoreError::Database(e.to_string()))??;
180 Ok(result)
181 }
182
183 fn serialize_keypair(&self, key_pair: &KeyPair) -> Result<Vec<u8>> {
184 let mut bytes = Vec::with_capacity(64);
185 bytes.extend_from_slice(key_pair.private_key.serialize());
186 bytes.extend_from_slice(key_pair.public_key.public_key_bytes());
187 Ok(bytes)
188 }
189
190 fn deserialize_keypair(&self, bytes: &[u8]) -> Result<KeyPair> {
191 if bytes.len() != 64 {
192 return Err(StoreError::Serialization(format!(
193 "Invalid KeyPair length: {}",
194 bytes.len()
195 )));
196 }
197
198 let private_key = PrivateKey::deserialize(&bytes[0..32])
199 .map_err(|e| StoreError::Serialization(e.to_string()))?;
200 let public_key = PublicKey::from_djb_public_key_bytes(&bytes[32..64])
201 .map_err(|e| StoreError::Serialization(e.to_string()))?;
202
203 Ok(KeyPair::new(public_key, private_key))
204 }
205
206 pub async fn save_device_data_for_device(
207 &self,
208 device_id: i32,
209 device_data: &CoreDevice,
210 ) -> Result<()> {
211 let pool = self.pool.clone();
212 let noise_key_data = self.serialize_keypair(&device_data.noise_key)?;
213 let identity_key_data = self.serialize_keypair(&device_data.identity_key)?;
214 let signed_pre_key_data = self.serialize_keypair(&device_data.signed_pre_key)?;
215 let account_data = device_data
216 .account
217 .as_ref()
218 .map(|account| account.encode_to_vec());
219 let registration_id = device_data.registration_id as i32;
220 let signed_pre_key_id = device_data.signed_pre_key_id as i32;
221 let signed_pre_key_signature: Vec<u8> = device_data.signed_pre_key_signature.to_vec();
222 let adv_secret_key: Vec<u8> = device_data.adv_secret_key.to_vec();
223 let push_name = device_data.push_name.clone();
224 let app_version_primary = device_data.app_version_primary as i32;
225 let app_version_secondary = device_data.app_version_secondary as i32;
226 let app_version_tertiary = device_data.app_version_tertiary as i64;
227 let app_version_last_fetched_ms = device_data.app_version_last_fetched_ms;
228 let edge_routing_info = device_data.edge_routing_info.clone();
229 let props_hash = device_data.props_hash.clone();
230 let new_lid = device_data
231 .lid
232 .as_ref()
233 .map(|j| j.to_string())
234 .unwrap_or_default();
235 let new_pn = device_data
236 .pn
237 .as_ref()
238 .map(|j| j.to_string())
239 .unwrap_or_default();
240
241 tokio::task::spawn_blocking(move || -> Result<()> {
242 let mut conn = pool
243 .get()
244 .map_err(|e| StoreError::Connection(e.to_string()))?;
245
246 diesel::insert_into(device::table)
247 .values((
248 device::id.eq(device_id),
249 device::lid.eq(&new_lid),
250 device::pn.eq(&new_pn),
251 device::registration_id.eq(registration_id),
252 device::noise_key.eq(&noise_key_data),
253 device::identity_key.eq(&identity_key_data),
254 device::signed_pre_key.eq(&signed_pre_key_data),
255 device::signed_pre_key_id.eq(signed_pre_key_id),
256 device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
257 device::adv_secret_key.eq(&adv_secret_key[..]),
258 device::account.eq(account_data.clone()),
259 device::push_name.eq(&push_name),
260 device::app_version_primary.eq(app_version_primary),
261 device::app_version_secondary.eq(app_version_secondary),
262 device::app_version_tertiary.eq(app_version_tertiary),
263 device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
264 device::edge_routing_info.eq(edge_routing_info.clone()),
265 device::props_hash.eq(props_hash.clone()),
266 ))
267 .on_conflict(device::id)
268 .do_update()
269 .set((
270 device::lid.eq(&new_lid),
271 device::pn.eq(&new_pn),
272 device::registration_id.eq(registration_id),
273 device::noise_key.eq(&noise_key_data),
274 device::identity_key.eq(&identity_key_data),
275 device::signed_pre_key.eq(&signed_pre_key_data),
276 device::signed_pre_key_id.eq(signed_pre_key_id),
277 device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
278 device::adv_secret_key.eq(&adv_secret_key[..]),
279 device::account.eq(account_data.clone()),
280 device::push_name.eq(&push_name),
281 device::app_version_primary.eq(app_version_primary),
282 device::app_version_secondary.eq(app_version_secondary),
283 device::app_version_tertiary.eq(app_version_tertiary),
284 device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
285 device::edge_routing_info.eq(edge_routing_info),
286 device::props_hash.eq(props_hash),
287 ))
288 .execute(&mut conn)
289 .map_err(|e| StoreError::Database(e.to_string()))?;
290
291 Ok(())
292 })
293 .await
294 .map_err(|e| StoreError::Database(e.to_string()))??;
295
296 Ok(())
297 }
298
299 pub async fn create_new_device(&self) -> Result<i32> {
300 use crate::schema::device;
301
302 let pool = self.pool.clone();
303 tokio::task::spawn_blocking(move || -> Result<i32> {
304 let mut conn = pool
305 .get()
306 .map_err(|e| StoreError::Connection(e.to_string()))?;
307
308 let new_device = wacore::store::Device::new();
309
310 let noise_key_data = {
311 let mut bytes = Vec::with_capacity(64);
312 bytes.extend_from_slice(new_device.noise_key.private_key.serialize());
313 bytes.extend_from_slice(new_device.noise_key.public_key.public_key_bytes());
314 bytes
315 };
316 let identity_key_data = {
317 let mut bytes = Vec::with_capacity(64);
318 bytes.extend_from_slice(new_device.identity_key.private_key.serialize());
319 bytes.extend_from_slice(new_device.identity_key.public_key.public_key_bytes());
320 bytes
321 };
322 let signed_pre_key_data = {
323 let mut bytes = Vec::with_capacity(64);
324 bytes.extend_from_slice(new_device.signed_pre_key.private_key.serialize());
325 bytes.extend_from_slice(new_device.signed_pre_key.public_key.public_key_bytes());
326 bytes
327 };
328
329 diesel::insert_into(device::table)
330 .values((
331 device::lid.eq(""),
332 device::pn.eq(""),
333 device::registration_id.eq(new_device.registration_id as i32),
334 device::noise_key.eq(&noise_key_data),
335 device::identity_key.eq(&identity_key_data),
336 device::signed_pre_key.eq(&signed_pre_key_data),
337 device::signed_pre_key_id.eq(new_device.signed_pre_key_id as i32),
338 device::signed_pre_key_signature.eq(&new_device.signed_pre_key_signature[..]),
339 device::adv_secret_key.eq(&new_device.adv_secret_key[..]),
340 device::account.eq(None::<Vec<u8>>),
341 device::push_name.eq(&new_device.push_name),
342 device::app_version_primary.eq(new_device.app_version_primary as i32),
343 device::app_version_secondary.eq(new_device.app_version_secondary as i32),
344 device::app_version_tertiary.eq(new_device.app_version_tertiary as i64),
345 device::app_version_last_fetched_ms.eq(new_device.app_version_last_fetched_ms),
346 device::edge_routing_info.eq(None::<Vec<u8>>),
347 device::props_hash.eq(None::<String>),
348 ))
349 .execute(&mut conn)
350 .map_err(|e| StoreError::Database(e.to_string()))?;
351
352 use diesel::sql_types::Integer;
353
354 #[derive(QueryableByName)]
355 struct LastInsertedId {
356 #[diesel(sql_type = Integer)]
357 last_insert_rowid: i32,
358 }
359
360 let device_id: i32 = sql_query("SELECT last_insert_rowid() as last_insert_rowid")
361 .get_result::<LastInsertedId>(&mut conn)
362 .map_err(|e| StoreError::Database(e.to_string()))?
363 .last_insert_rowid;
364
365 Ok(device_id)
366 })
367 .await
368 .map_err(|e| StoreError::Database(e.to_string()))?
369 }
370
371 pub async fn device_exists(&self, device_id: i32) -> Result<bool> {
372 use crate::schema::device;
373
374 let pool = self.pool.clone();
375 tokio::task::spawn_blocking(move || -> Result<bool> {
376 let mut conn = pool
377 .get()
378 .map_err(|e| StoreError::Connection(e.to_string()))?;
379
380 let count: i64 = device::table
381 .filter(device::id.eq(device_id))
382 .count()
383 .get_result(&mut conn)
384 .map_err(|e| StoreError::Database(e.to_string()))?;
385
386 Ok(count > 0)
387 })
388 .await
389 .map_err(|e| StoreError::Database(e.to_string()))?
390 }
391
392 pub async fn load_device_data_for_device(&self, device_id: i32) -> Result<Option<CoreDevice>> {
393 use crate::schema::device;
394
395 let pool = self.pool.clone();
396 let row = tokio::task::spawn_blocking(move || -> Result<Option<DeviceRow>> {
397 let mut conn = pool
398 .get()
399 .map_err(|e| StoreError::Connection(e.to_string()))?;
400 let result = device::table
401 .filter(device::id.eq(device_id))
402 .first::<DeviceRow>(&mut conn)
403 .optional()
404 .map_err(|e| StoreError::Database(e.to_string()))?;
405 Ok(result)
406 })
407 .await
408 .map_err(|e| StoreError::Database(e.to_string()))??;
409
410 if let Some((
411 _device_id,
412 lid_str,
413 pn_str,
414 registration_id,
415 noise_key_data,
416 identity_key_data,
417 signed_pre_key_data,
418 signed_pre_key_id,
419 signed_pre_key_signature_data,
420 adv_secret_key_data,
421 account_data,
422 push_name,
423 app_version_primary,
424 app_version_secondary,
425 app_version_tertiary,
426 app_version_last_fetched_ms,
427 edge_routing_info,
428 props_hash,
429 )) = row
430 {
431 let id = if !pn_str.is_empty() {
432 pn_str.parse().ok()
433 } else {
434 None
435 };
436 let lid = if !lid_str.is_empty() {
437 lid_str.parse().ok()
438 } else {
439 None
440 };
441
442 let noise_key = self.deserialize_keypair(&noise_key_data)?;
443 let identity_key = self.deserialize_keypair(&identity_key_data)?;
444 let signed_pre_key = self.deserialize_keypair(&signed_pre_key_data)?;
445
446 let signed_pre_key_signature: [u8; 64] =
447 signed_pre_key_signature_data.try_into().map_err(|_| {
448 StoreError::Serialization("Invalid signed_pre_key_signature length".to_string())
449 })?;
450
451 let adv_secret_key: [u8; 32] = adv_secret_key_data.try_into().map_err(|_| {
452 StoreError::Serialization("Invalid adv_secret_key length".to_string())
453 })?;
454
455 let account = account_data
456 .map(|data| {
457 wa::AdvSignedDeviceIdentity::decode(&data[..])
458 .map_err(|e| StoreError::Serialization(e.to_string()))
459 })
460 .transpose()?;
461
462 Ok(Some(CoreDevice {
463 pn: id,
464 lid,
465 registration_id: registration_id as u32,
466 noise_key,
467 identity_key,
468 signed_pre_key,
469 signed_pre_key_id: signed_pre_key_id as u32,
470 signed_pre_key_signature,
471 adv_secret_key,
472 account,
473 push_name,
474 app_version_primary: app_version_primary as u32,
475 app_version_secondary: app_version_secondary as u32,
476 app_version_tertiary: app_version_tertiary.try_into().unwrap_or(0u32),
477 app_version_last_fetched_ms,
478 device_props: {
479 use wacore::store::device::DEVICE_PROPS;
480 DEVICE_PROPS.clone()
481 },
482 edge_routing_info,
483 props_hash,
484 }))
485 } else {
486 Ok(None)
487 }
488 }
489
490 pub async fn put_identity_for_device(
491 &self,
492 address: &str,
493 key: [u8; 32],
494 device_id: i32,
495 ) -> Result<()> {
496 let pool = self.pool.clone();
497 let db_semaphore = self.db_semaphore.clone();
498 let address_owned = address.to_string();
499 let key_vec = key.to_vec();
500
501 const MAX_RETRIES: u32 = 5;
502
503 for attempt in 0..=MAX_RETRIES {
504 let permit =
505 db_semaphore.clone().acquire_owned().await.map_err(|e| {
506 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
507 })?;
508
509 let pool_clone = pool.clone();
510 let address_clone = address_owned.clone();
511 let key_clone = key_vec.clone();
512
513 let result = tokio::task::spawn_blocking(move || -> Result<()> {
514 let mut conn = pool_clone
515 .get()
516 .map_err(|e| StoreError::Connection(e.to_string()))?;
517 diesel::insert_into(identities::table)
518 .values((
519 identities::address.eq(address_clone),
520 identities::key.eq(&key_clone[..]),
521 identities::device_id.eq(device_id),
522 ))
523 .on_conflict((identities::address, identities::device_id))
524 .do_update()
525 .set(identities::key.eq(&key_clone[..]))
526 .execute(&mut conn)
527 .map_err(|e| StoreError::Database(e.to_string()))?;
528 Ok(())
529 })
530 .await;
531
532 drop(permit);
533
534 match result {
535 Ok(Ok(())) => return Ok(()),
536 Ok(Err(e)) => {
537 let error_msg = e.to_string();
538 if (error_msg.contains("locked") || error_msg.contains("busy"))
539 && attempt < MAX_RETRIES
540 {
541 let delay_ms = 10 * 2u64.pow(attempt);
542 warn!(
543 "Identity write failed (attempt {}/{}): {}. Retrying in {}ms...",
544 attempt + 1,
545 MAX_RETRIES + 1,
546 error_msg,
547 delay_ms
548 );
549 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
550 continue;
551 }
552 return Err(e);
553 }
554 Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
555 }
556 }
557
558 Err(StoreError::Database(format!(
559 "Identity write failed after {} attempts",
560 MAX_RETRIES + 1
561 )))
562 }
563
564 pub async fn delete_identity_for_device(&self, address: &str, device_id: i32) -> Result<()> {
565 let pool = self.pool.clone();
566 let address_owned = address.to_string();
567
568 tokio::task::spawn_blocking(move || -> Result<()> {
569 let mut conn = pool
570 .get()
571 .map_err(|e| StoreError::Connection(e.to_string()))?;
572 diesel::delete(
573 identities::table
574 .filter(identities::address.eq(address_owned))
575 .filter(identities::device_id.eq(device_id)),
576 )
577 .execute(&mut conn)
578 .map_err(|e| StoreError::Database(e.to_string()))?;
579 Ok(())
580 })
581 .await
582 .map_err(|e| StoreError::Database(e.to_string()))??;
583
584 Ok(())
585 }
586
587 pub async fn load_identity_for_device(
588 &self,
589 address: &str,
590 device_id: i32,
591 ) -> Result<Option<Vec<u8>>> {
592 let pool = self.pool.clone();
593 let address = address.to_string();
594 let result = self
595 .with_semaphore(move || -> Result<Option<Vec<u8>>> {
596 let mut conn = pool
597 .get()
598 .map_err(|e| StoreError::Connection(e.to_string()))?;
599 let res: Option<Vec<u8>> = identities::table
600 .select(identities::key)
601 .filter(identities::address.eq(address))
602 .filter(identities::device_id.eq(device_id))
603 .first(&mut conn)
604 .optional()
605 .map_err(|e| StoreError::Database(e.to_string()))?;
606 Ok(res)
607 })
608 .await?;
609
610 Ok(result)
611 }
612
613 pub async fn get_session_for_device(
614 &self,
615 address: &str,
616 device_id: i32,
617 ) -> Result<Option<Vec<u8>>> {
618 let pool = self.pool.clone();
619 let address_for_query = address.to_string();
620 let result = self
621 .with_semaphore(move || -> Result<Option<Vec<u8>>> {
622 let mut conn = pool
623 .get()
624 .map_err(|e| StoreError::Connection(e.to_string()))?;
625 let res: Option<Vec<u8>> = sessions::table
626 .select(sessions::record)
627 .filter(sessions::address.eq(address_for_query.clone()))
628 .filter(sessions::device_id.eq(device_id))
629 .first(&mut conn)
630 .optional()
631 .map_err(|e| StoreError::Database(e.to_string()))?;
632
633 Ok(res)
634 })
635 .await?;
636
637 Ok(result)
638 }
639
640 pub async fn put_session_for_device(
641 &self,
642 address: &str,
643 session: &[u8],
644 device_id: i32,
645 ) -> Result<()> {
646 let pool = self.pool.clone();
647 let db_semaphore = self.db_semaphore.clone();
648 let address_owned = address.to_string();
649 let session_vec = session.to_vec();
650
651 const MAX_RETRIES: u32 = 5;
652
653 for attempt in 0..=MAX_RETRIES {
654 let permit =
655 db_semaphore.clone().acquire_owned().await.map_err(|e| {
656 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
657 })?;
658
659 let pool_clone = pool.clone();
660 let address_clone = address_owned.clone();
661 let session_clone = session_vec.clone();
662
663 let result = tokio::task::spawn_blocking(move || -> Result<()> {
664 let mut conn = pool_clone
665 .get()
666 .map_err(|e| StoreError::Connection(e.to_string()))?;
667 diesel::insert_into(sessions::table)
668 .values((
669 sessions::address.eq(address_clone),
670 sessions::record.eq(&session_clone),
671 sessions::device_id.eq(device_id),
672 ))
673 .on_conflict((sessions::address, sessions::device_id))
674 .do_update()
675 .set(sessions::record.eq(&session_clone))
676 .execute(&mut conn)
677 .map_err(|e| StoreError::Database(e.to_string()))?;
678 Ok(())
679 })
680 .await;
681
682 drop(permit);
683
684 match result {
685 Ok(Ok(())) => {
686 return Ok(());
687 }
688 Ok(Err(e)) => {
689 let error_msg = e.to_string();
690 if (error_msg.contains("locked") || error_msg.contains("busy"))
691 && attempt < MAX_RETRIES
692 {
693 let delay_ms = 10 * 2u64.pow(attempt);
694 warn!(
695 "Session write failed (attempt {}/{}): {}. Retrying in {}ms...",
696 attempt + 1,
697 MAX_RETRIES + 1,
698 error_msg,
699 delay_ms
700 );
701 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
702 continue;
703 }
704 return Err(e);
705 }
706 Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
707 }
708 }
709
710 Err(StoreError::Database(format!(
711 "Session write failed after {} attempts",
712 MAX_RETRIES + 1
713 )))
714 }
715
716 pub async fn delete_session_for_device(&self, address: &str, device_id: i32) -> Result<()> {
717 let pool = self.pool.clone();
718 let address_owned = address.to_string();
719
720 tokio::task::spawn_blocking(move || -> Result<()> {
721 let mut conn = pool
722 .get()
723 .map_err(|e| StoreError::Connection(e.to_string()))?;
724 diesel::delete(
725 sessions::table
726 .filter(sessions::address.eq(address_owned))
727 .filter(sessions::device_id.eq(device_id)),
728 )
729 .execute(&mut conn)
730 .map_err(|e| StoreError::Database(e.to_string()))?;
731 Ok(())
732 })
733 .await
734 .map_err(|e| StoreError::Database(e.to_string()))??;
735
736 Ok(())
737 }
738
739 pub async fn put_sender_key_for_device(
740 &self,
741 address: &str,
742 record: &[u8],
743 device_id: i32,
744 ) -> Result<()> {
745 let pool = self.pool.clone();
746 let address = address.to_string();
747 let record_vec = record.to_vec();
748 tokio::task::spawn_blocking(move || -> Result<()> {
749 let mut conn = pool
750 .get()
751 .map_err(|e| StoreError::Connection(e.to_string()))?;
752 diesel::insert_into(sender_keys::table)
753 .values((
754 sender_keys::address.eq(address),
755 sender_keys::record.eq(&record_vec),
756 sender_keys::device_id.eq(device_id),
757 ))
758 .on_conflict((sender_keys::address, sender_keys::device_id))
759 .do_update()
760 .set(sender_keys::record.eq(&record_vec))
761 .execute(&mut conn)
762 .map_err(|e| StoreError::Database(e.to_string()))?;
763 Ok(())
764 })
765 .await
766 .map_err(|e| StoreError::Database(e.to_string()))??;
767 Ok(())
768 }
769
770 pub async fn get_sender_key_for_device(
771 &self,
772 address: &str,
773 device_id: i32,
774 ) -> Result<Option<Vec<u8>>> {
775 let pool = self.pool.clone();
776 let address = address.to_string();
777 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
778 let mut conn = pool
779 .get()
780 .map_err(|e| StoreError::Connection(e.to_string()))?;
781 let res: Option<Vec<u8>> = sender_keys::table
782 .select(sender_keys::record)
783 .filter(sender_keys::address.eq(address))
784 .filter(sender_keys::device_id.eq(device_id))
785 .first(&mut conn)
786 .optional()
787 .map_err(|e| StoreError::Database(e.to_string()))?;
788 Ok(res)
789 })
790 .await
791 .map_err(|e| StoreError::Database(e.to_string()))?
792 }
793
794 pub async fn delete_sender_key_for_device(&self, address: &str, device_id: i32) -> Result<()> {
795 let pool = self.pool.clone();
796 let address = address.to_string();
797 tokio::task::spawn_blocking(move || -> Result<()> {
798 let mut conn = pool
799 .get()
800 .map_err(|e| StoreError::Connection(e.to_string()))?;
801 diesel::delete(
802 sender_keys::table
803 .filter(sender_keys::address.eq(address))
804 .filter(sender_keys::device_id.eq(device_id)),
805 )
806 .execute(&mut conn)
807 .map_err(|e| StoreError::Database(e.to_string()))?;
808 Ok(())
809 })
810 .await
811 .map_err(|e| StoreError::Database(e.to_string()))??;
812 Ok(())
813 }
814
815 pub async fn get_app_state_sync_key_for_device(
816 &self,
817 key_id: &[u8],
818 device_id: i32,
819 ) -> Result<Option<AppStateSyncKey>> {
820 let pool = self.pool.clone();
821 let key_id = key_id.to_vec();
822 let res: Option<Vec<u8>> =
823 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
824 let mut conn = pool
825 .get()
826 .map_err(|e| StoreError::Connection(e.to_string()))?;
827 let res: Option<Vec<u8>> = app_state_keys::table
828 .select(app_state_keys::key_data)
829 .filter(app_state_keys::key_id.eq(&key_id))
830 .filter(app_state_keys::device_id.eq(device_id))
831 .first(&mut conn)
832 .optional()
833 .map_err(|e| StoreError::Database(e.to_string()))?;
834 Ok(res)
835 })
836 .await
837 .map_err(|e| StoreError::Database(e.to_string()))??;
838
839 if let Some(data) = res {
840 let (key, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
841 .map_err(|e| StoreError::Serialization(e.to_string()))?;
842 Ok(Some(key))
843 } else {
844 Ok(None)
845 }
846 }
847
848 pub async fn set_app_state_sync_key_for_device(
849 &self,
850 key_id: &[u8],
851 key: AppStateSyncKey,
852 device_id: i32,
853 ) -> Result<()> {
854 let pool = self.pool.clone();
855 let key_id = key_id.to_vec();
856 let data = bincode::serde::encode_to_vec(&key, bincode::config::standard())
857 .map_err(|e| StoreError::Serialization(e.to_string()))?;
858 tokio::task::spawn_blocking(move || -> Result<()> {
859 let mut conn = pool
860 .get()
861 .map_err(|e| StoreError::Connection(e.to_string()))?;
862 diesel::insert_into(app_state_keys::table)
863 .values((
864 app_state_keys::key_id.eq(&key_id),
865 app_state_keys::key_data.eq(&data),
866 app_state_keys::device_id.eq(device_id),
867 ))
868 .on_conflict((app_state_keys::key_id, app_state_keys::device_id))
869 .do_update()
870 .set(app_state_keys::key_data.eq(&data))
871 .execute(&mut conn)
872 .map_err(|e| StoreError::Database(e.to_string()))?;
873 Ok(())
874 })
875 .await
876 .map_err(|e| StoreError::Database(e.to_string()))??;
877 Ok(())
878 }
879
880 pub async fn get_app_state_version_for_device(
881 &self,
882 name: &str,
883 device_id: i32,
884 ) -> Result<HashState> {
885 let pool = self.pool.clone();
886 let name = name.to_string();
887 let res: Option<Vec<u8>> =
888 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
889 let mut conn = pool
890 .get()
891 .map_err(|e| StoreError::Connection(e.to_string()))?;
892 let res: Option<Vec<u8>> = app_state_versions::table
893 .select(app_state_versions::state_data)
894 .filter(app_state_versions::name.eq(name))
895 .filter(app_state_versions::device_id.eq(device_id))
896 .first(&mut conn)
897 .optional()
898 .map_err(|e| StoreError::Database(e.to_string()))?;
899 Ok(res)
900 })
901 .await
902 .map_err(|e| StoreError::Database(e.to_string()))??;
903
904 if let Some(data) = res {
905 let (state, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
906 .map_err(|e| StoreError::Serialization(e.to_string()))?;
907 Ok(state)
908 } else {
909 Ok(HashState::default())
910 }
911 }
912
913 pub async fn set_app_state_version_for_device(
914 &self,
915 name: &str,
916 state: HashState,
917 device_id: i32,
918 ) -> Result<()> {
919 let pool = self.pool.clone();
920 let name = name.to_string();
921 let data = bincode::serde::encode_to_vec(&state, bincode::config::standard())
922 .map_err(|e| StoreError::Serialization(e.to_string()))?;
923 tokio::task::spawn_blocking(move || -> Result<()> {
924 let mut conn = pool
925 .get()
926 .map_err(|e| StoreError::Connection(e.to_string()))?;
927 diesel::insert_into(app_state_versions::table)
928 .values((
929 app_state_versions::name.eq(&name),
930 app_state_versions::state_data.eq(&data),
931 app_state_versions::device_id.eq(device_id),
932 ))
933 .on_conflict((app_state_versions::name, app_state_versions::device_id))
934 .do_update()
935 .set(app_state_versions::state_data.eq(&data))
936 .execute(&mut conn)
937 .map_err(|e| StoreError::Database(e.to_string()))?;
938 Ok(())
939 })
940 .await
941 .map_err(|e| StoreError::Database(e.to_string()))??;
942 Ok(())
943 }
944
945 pub async fn put_app_state_mutation_macs_for_device(
946 &self,
947 name: &str,
948 version: u64,
949 mutations: &[AppStateMutationMAC],
950 device_id: i32,
951 ) -> Result<()> {
952 if mutations.is_empty() {
953 return Ok(());
954 }
955 let pool = self.pool.clone();
956 let name = name.to_string();
957 let mutations: Vec<AppStateMutationMAC> = mutations.to_vec();
958 tokio::task::spawn_blocking(move || -> Result<()> {
959 let mut conn = pool
960 .get()
961 .map_err(|e| StoreError::Connection(e.to_string()))?;
962
963 let records: Vec<_> = mutations
964 .iter()
965 .map(|m| {
966 (
967 app_state_mutation_macs::name.eq(&name),
968 app_state_mutation_macs::version.eq(version as i64),
969 app_state_mutation_macs::index_mac.eq(&m.index_mac),
970 app_state_mutation_macs::value_mac.eq(&m.value_mac),
971 app_state_mutation_macs::device_id.eq(device_id),
972 )
973 })
974 .collect();
975
976 const CHUNK_SIZE: usize = 100;
979
980 for chunk in records.chunks(CHUNK_SIZE) {
981 diesel::insert_into(app_state_mutation_macs::table)
982 .values(chunk)
983 .on_conflict((
984 app_state_mutation_macs::name,
985 app_state_mutation_macs::index_mac,
986 app_state_mutation_macs::device_id,
987 ))
988 .do_update()
989 .set((
990 app_state_mutation_macs::version
991 .eq(excluded(app_state_mutation_macs::version)),
992 app_state_mutation_macs::value_mac
993 .eq(excluded(app_state_mutation_macs::value_mac)),
994 ))
995 .execute(&mut conn)
996 .map_err(|e| StoreError::Database(e.to_string()))?;
997 }
998 Ok(())
999 })
1000 .await
1001 .map_err(|e| StoreError::Database(e.to_string()))??;
1002 Ok(())
1003 }
1004
1005 pub async fn delete_app_state_mutation_macs_for_device(
1006 &self,
1007 name: &str,
1008 index_macs: &[Vec<u8>],
1009 device_id: i32,
1010 ) -> Result<()> {
1011 if index_macs.is_empty() {
1012 return Ok(());
1013 }
1014 let pool = self.pool.clone();
1015 let name = name.to_string();
1016 let index_macs: Vec<Vec<u8>> = index_macs.to_vec();
1017 tokio::task::spawn_blocking(move || -> Result<()> {
1018 let mut conn = pool
1019 .get()
1020 .map_err(|e| StoreError::Connection(e.to_string()))?;
1021
1022 const CHUNK_SIZE: usize = 500;
1025
1026 for chunk in index_macs.chunks(CHUNK_SIZE) {
1027 diesel::delete(
1028 app_state_mutation_macs::table.filter(
1029 app_state_mutation_macs::name
1030 .eq(&name)
1031 .and(app_state_mutation_macs::index_mac.eq_any(chunk))
1032 .and(app_state_mutation_macs::device_id.eq(device_id)),
1033 ),
1034 )
1035 .execute(&mut conn)
1036 .map_err(|e| StoreError::Database(e.to_string()))?;
1037 }
1038 Ok(())
1039 })
1040 .await
1041 .map_err(|e| StoreError::Database(e.to_string()))??;
1042 Ok(())
1043 }
1044
1045 pub async fn get_app_state_mutation_mac_for_device(
1046 &self,
1047 name: &str,
1048 index_mac: &[u8],
1049 device_id: i32,
1050 ) -> Result<Option<Vec<u8>>> {
1051 let pool = self.pool.clone();
1052 let name = name.to_string();
1053 let index_mac = index_mac.to_vec();
1054 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1055 let mut conn = pool
1056 .get()
1057 .map_err(|e| StoreError::Connection(e.to_string()))?;
1058 let res: Option<Vec<u8>> = app_state_mutation_macs::table
1059 .select(app_state_mutation_macs::value_mac)
1060 .filter(app_state_mutation_macs::name.eq(&name))
1061 .filter(app_state_mutation_macs::index_mac.eq(&index_mac))
1062 .filter(app_state_mutation_macs::device_id.eq(device_id))
1063 .first(&mut conn)
1064 .optional()
1065 .map_err(|e| StoreError::Database(e.to_string()))?;
1066 Ok(res)
1067 })
1068 .await
1069 .map_err(|e| StoreError::Database(e.to_string()))?
1070 }
1071}
1072
1073#[async_trait]
1074impl SignalStore for SqliteStore {
1075 async fn put_identity(&self, address: &str, key: [u8; 32]) -> Result<()> {
1076 self.put_identity_for_device(address, key, self.device_id)
1077 .await
1078 }
1079
1080 async fn load_identity(&self, address: &str) -> Result<Option<Vec<u8>>> {
1081 self.load_identity_for_device(address, self.device_id).await
1082 }
1083
1084 async fn delete_identity(&self, address: &str) -> Result<()> {
1085 self.delete_identity_for_device(address, self.device_id)
1086 .await
1087 }
1088
1089 async fn get_session(&self, address: &str) -> Result<Option<Vec<u8>>> {
1090 self.get_session_for_device(address, self.device_id).await
1091 }
1092
1093 async fn put_session(&self, address: &str, session: &[u8]) -> Result<()> {
1094 self.put_session_for_device(address, session, self.device_id)
1095 .await
1096 }
1097
1098 async fn delete_session(&self, address: &str) -> Result<()> {
1099 self.delete_session_for_device(address, self.device_id)
1100 .await
1101 }
1102
1103 async fn store_prekey(&self, id: u32, record: &[u8], uploaded: bool) -> Result<()> {
1104 let pool = self.pool.clone();
1105 let device_id = self.device_id;
1106 let record = record.to_vec();
1107 tokio::task::spawn_blocking(move || -> Result<()> {
1108 let mut conn = pool
1109 .get()
1110 .map_err(|e| StoreError::Connection(e.to_string()))?;
1111 diesel::insert_into(prekeys::table)
1112 .values((
1113 prekeys::id.eq(id as i32),
1114 prekeys::key.eq(&record),
1115 prekeys::uploaded.eq(uploaded),
1116 prekeys::device_id.eq(device_id),
1117 ))
1118 .on_conflict((prekeys::id, prekeys::device_id))
1119 .do_update()
1120 .set((prekeys::key.eq(&record), prekeys::uploaded.eq(uploaded)))
1121 .execute(&mut conn)
1122 .map_err(|e| StoreError::Database(e.to_string()))?;
1123 Ok(())
1124 })
1125 .await
1126 .map_err(|e| StoreError::Database(e.to_string()))??;
1127 Ok(())
1128 }
1129
1130 async fn load_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1131 let pool = self.pool.clone();
1132 let device_id = self.device_id;
1133 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1134 let mut conn = pool
1135 .get()
1136 .map_err(|e| StoreError::Connection(e.to_string()))?;
1137 let res: Option<Vec<u8>> = prekeys::table
1138 .select(prekeys::key)
1139 .filter(prekeys::id.eq(id as i32))
1140 .filter(prekeys::device_id.eq(device_id))
1141 .first(&mut conn)
1142 .optional()
1143 .map_err(|e| StoreError::Database(e.to_string()))?;
1144 Ok(res)
1145 })
1146 .await
1147 .map_err(|e| StoreError::Database(e.to_string()))?
1148 }
1149
1150 async fn remove_prekey(&self, id: u32) -> Result<()> {
1151 let pool = self.pool.clone();
1152 let device_id = self.device_id;
1153 tokio::task::spawn_blocking(move || -> Result<()> {
1154 let mut conn = pool
1155 .get()
1156 .map_err(|e| StoreError::Connection(e.to_string()))?;
1157 diesel::delete(
1158 prekeys::table
1159 .filter(prekeys::id.eq(id as i32))
1160 .filter(prekeys::device_id.eq(device_id)),
1161 )
1162 .execute(&mut conn)
1163 .map_err(|e| StoreError::Database(e.to_string()))?;
1164 Ok(())
1165 })
1166 .await
1167 .map_err(|e| StoreError::Database(e.to_string()))??;
1168 Ok(())
1169 }
1170
1171 async fn store_signed_prekey(&self, id: u32, record: &[u8]) -> Result<()> {
1172 let pool = self.pool.clone();
1173 let device_id = self.device_id;
1174 let record = record.to_vec();
1175 tokio::task::spawn_blocking(move || -> Result<()> {
1176 let mut conn = pool
1177 .get()
1178 .map_err(|e| StoreError::Connection(e.to_string()))?;
1179 diesel::insert_into(signed_prekeys::table)
1180 .values((
1181 signed_prekeys::id.eq(id as i32),
1182 signed_prekeys::record.eq(&record),
1183 signed_prekeys::device_id.eq(device_id),
1184 ))
1185 .on_conflict((signed_prekeys::id, signed_prekeys::device_id))
1186 .do_update()
1187 .set(signed_prekeys::record.eq(&record))
1188 .execute(&mut conn)
1189 .map_err(|e| StoreError::Database(e.to_string()))?;
1190 Ok(())
1191 })
1192 .await
1193 .map_err(|e| StoreError::Database(e.to_string()))??;
1194 Ok(())
1195 }
1196
1197 async fn load_signed_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1198 let pool = self.pool.clone();
1199 let device_id = self.device_id;
1200 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1201 let mut conn = pool
1202 .get()
1203 .map_err(|e| StoreError::Connection(e.to_string()))?;
1204 let res: Option<Vec<u8>> = signed_prekeys::table
1205 .select(signed_prekeys::record)
1206 .filter(signed_prekeys::id.eq(id as i32))
1207 .filter(signed_prekeys::device_id.eq(device_id))
1208 .first(&mut conn)
1209 .optional()
1210 .map_err(|e| StoreError::Database(e.to_string()))?;
1211 Ok(res)
1212 })
1213 .await
1214 .map_err(|e| StoreError::Database(e.to_string()))?
1215 }
1216
1217 async fn load_all_signed_prekeys(&self) -> Result<Vec<(u32, Vec<u8>)>> {
1218 let pool = self.pool.clone();
1219 let device_id = self.device_id;
1220 tokio::task::spawn_blocking(move || -> Result<Vec<(u32, Vec<u8>)>> {
1221 let mut conn = pool
1222 .get()
1223 .map_err(|e| StoreError::Connection(e.to_string()))?;
1224 let results: Vec<(i32, Vec<u8>)> = signed_prekeys::table
1225 .select((signed_prekeys::id, signed_prekeys::record))
1226 .filter(signed_prekeys::device_id.eq(device_id))
1227 .load(&mut conn)
1228 .map_err(|e| StoreError::Database(e.to_string()))?;
1229 Ok(results
1230 .into_iter()
1231 .map(|(id, record)| (id as u32, record))
1232 .collect())
1233 })
1234 .await
1235 .map_err(|e| StoreError::Database(e.to_string()))?
1236 }
1237
1238 async fn remove_signed_prekey(&self, id: u32) -> Result<()> {
1239 let pool = self.pool.clone();
1240 let device_id = self.device_id;
1241 tokio::task::spawn_blocking(move || -> Result<()> {
1242 let mut conn = pool
1243 .get()
1244 .map_err(|e| StoreError::Connection(e.to_string()))?;
1245 diesel::delete(
1246 signed_prekeys::table
1247 .filter(signed_prekeys::id.eq(id as i32))
1248 .filter(signed_prekeys::device_id.eq(device_id)),
1249 )
1250 .execute(&mut conn)
1251 .map_err(|e| StoreError::Database(e.to_string()))?;
1252 Ok(())
1253 })
1254 .await
1255 .map_err(|e| StoreError::Database(e.to_string()))??;
1256 Ok(())
1257 }
1258
1259 async fn put_sender_key(&self, address: &str, record: &[u8]) -> Result<()> {
1260 self.put_sender_key_for_device(address, record, self.device_id)
1261 .await
1262 }
1263
1264 async fn get_sender_key(&self, address: &str) -> Result<Option<Vec<u8>>> {
1265 self.get_sender_key_for_device(address, self.device_id)
1266 .await
1267 }
1268
1269 async fn delete_sender_key(&self, address: &str) -> Result<()> {
1270 self.delete_sender_key_for_device(address, self.device_id)
1271 .await
1272 }
1273}
1274
1275#[async_trait]
1276impl AppSyncStore for SqliteStore {
1277 async fn get_sync_key(&self, key_id: &[u8]) -> Result<Option<AppStateSyncKey>> {
1278 self.get_app_state_sync_key_for_device(key_id, self.device_id)
1279 .await
1280 }
1281
1282 async fn set_sync_key(&self, key_id: &[u8], key: AppStateSyncKey) -> Result<()> {
1283 self.set_app_state_sync_key_for_device(key_id, key, self.device_id)
1284 .await
1285 }
1286
1287 async fn get_version(&self, name: &str) -> Result<HashState> {
1288 self.get_app_state_version_for_device(name, self.device_id)
1289 .await
1290 }
1291
1292 async fn set_version(&self, name: &str, state: HashState) -> Result<()> {
1293 self.set_app_state_version_for_device(name, state, self.device_id)
1294 .await
1295 }
1296
1297 async fn put_mutation_macs(
1298 &self,
1299 name: &str,
1300 version: u64,
1301 mutations: &[AppStateMutationMAC],
1302 ) -> Result<()> {
1303 self.put_app_state_mutation_macs_for_device(name, version, mutations, self.device_id)
1304 .await
1305 }
1306
1307 async fn get_mutation_mac(&self, name: &str, index_mac: &[u8]) -> Result<Option<Vec<u8>>> {
1308 self.get_app_state_mutation_mac_for_device(name, index_mac, self.device_id)
1309 .await
1310 }
1311
1312 async fn delete_mutation_macs(&self, name: &str, index_macs: &[Vec<u8>]) -> Result<()> {
1313 self.delete_app_state_mutation_macs_for_device(name, index_macs, self.device_id)
1314 .await
1315 }
1316}
1317
1318#[async_trait]
1319impl ProtocolStore for SqliteStore {
1320 async fn get_skdm_recipients(&self, group_jid: &str) -> Result<Vec<Jid>> {
1321 let pool = self.pool.clone();
1322 let device_id = self.device_id;
1323 let group_jid = group_jid.to_string();
1324 tokio::task::spawn_blocking(move || -> Result<Vec<Jid>> {
1325 let mut conn = pool
1326 .get()
1327 .map_err(|e| StoreError::Connection(e.to_string()))?;
1328 let recipients: Vec<String> = skdm_recipients::table
1329 .select(skdm_recipients::device_jid)
1330 .filter(skdm_recipients::group_jid.eq(&group_jid))
1331 .filter(skdm_recipients::device_id.eq(device_id))
1332 .load(&mut conn)
1333 .map_err(|e| StoreError::Database(e.to_string()))?;
1334 let jids: Vec<Jid> = recipients
1335 .iter()
1336 .filter_map(|s| match s.parse::<Jid>() {
1337 Ok(jid) => Some(jid),
1338 Err(e) => {
1339 warn!("Failed to parse SKDM recipient '{}': {}", s, e);
1340 None
1341 }
1342 })
1343 .collect();
1344 Ok(jids)
1345 })
1346 .await
1347 .map_err(|e| StoreError::Database(e.to_string()))?
1348 }
1349
1350 async fn add_skdm_recipients(&self, group_jid: &str, device_jids: &[Jid]) -> Result<()> {
1351 if device_jids.is_empty() {
1352 return Ok(());
1353 }
1354 let pool = self.pool.clone();
1355 let device_id = self.device_id;
1356 let group_jid = group_jid.to_string();
1357 let device_jid_strs: Vec<String> = device_jids.iter().map(|j| j.to_string()).collect();
1358 let now = std::time::SystemTime::now()
1359 .duration_since(std::time::UNIX_EPOCH)
1360 .unwrap_or_default()
1361 .as_secs() as i32;
1362 tokio::task::spawn_blocking(move || -> Result<()> {
1363 let mut conn = pool
1364 .get()
1365 .map_err(|e| StoreError::Connection(e.to_string()))?;
1366
1367 let values: Vec<_> = device_jid_strs
1368 .iter()
1369 .map(|device_jid| {
1370 (
1371 skdm_recipients::group_jid.eq(&group_jid),
1372 skdm_recipients::device_jid.eq(device_jid),
1373 skdm_recipients::device_id.eq(device_id),
1374 skdm_recipients::created_at.eq(now),
1375 )
1376 })
1377 .collect();
1378
1379 const CHUNK_SIZE: usize = 200; for chunk in values.chunks(CHUNK_SIZE) {
1382 diesel::insert_into(skdm_recipients::table)
1383 .values(chunk)
1384 .on_conflict((
1385 skdm_recipients::group_jid,
1386 skdm_recipients::device_jid,
1387 skdm_recipients::device_id,
1388 ))
1389 .do_nothing()
1390 .execute(&mut conn)
1391 .map_err(|e| StoreError::Database(e.to_string()))?;
1392 }
1393 Ok(())
1394 })
1395 .await
1396 .map_err(|e| StoreError::Database(e.to_string()))??;
1397 Ok(())
1398 }
1399
1400 async fn clear_skdm_recipients(&self, group_jid: &str) -> Result<()> {
1401 let pool = self.pool.clone();
1402 let device_id = self.device_id;
1403 let group_jid = group_jid.to_string();
1404 tokio::task::spawn_blocking(move || -> Result<()> {
1405 let mut conn = pool
1406 .get()
1407 .map_err(|e| StoreError::Connection(e.to_string()))?;
1408 diesel::delete(
1409 skdm_recipients::table
1410 .filter(skdm_recipients::group_jid.eq(&group_jid))
1411 .filter(skdm_recipients::device_id.eq(device_id)),
1412 )
1413 .execute(&mut conn)
1414 .map_err(|e| StoreError::Database(e.to_string()))?;
1415 Ok(())
1416 })
1417 .await
1418 .map_err(|e| StoreError::Database(e.to_string()))??;
1419 Ok(())
1420 }
1421
1422 async fn get_lid_mapping(&self, lid: &str) -> Result<Option<LidPnMappingEntry>> {
1423 let pool = self.pool.clone();
1424 let device_id = self.device_id;
1425 let lid = lid.to_string();
1426 tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1427 let mut conn = pool
1428 .get()
1429 .map_err(|e| StoreError::Connection(e.to_string()))?;
1430 let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1431 .select((
1432 lid_pn_mapping::lid,
1433 lid_pn_mapping::phone_number,
1434 lid_pn_mapping::created_at,
1435 lid_pn_mapping::learning_source,
1436 lid_pn_mapping::updated_at,
1437 ))
1438 .filter(lid_pn_mapping::lid.eq(&lid))
1439 .filter(lid_pn_mapping::device_id.eq(device_id))
1440 .first(&mut conn)
1441 .optional()
1442 .map_err(|e| StoreError::Database(e.to_string()))?;
1443 Ok(row.map(
1444 |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1445 lid,
1446 phone_number,
1447 created_at,
1448 updated_at,
1449 learning_source,
1450 },
1451 ))
1452 })
1453 .await
1454 .map_err(|e| StoreError::Database(e.to_string()))?
1455 }
1456
1457 async fn get_pn_mapping(&self, phone: &str) -> Result<Option<LidPnMappingEntry>> {
1458 let pool = self.pool.clone();
1459 let device_id = self.device_id;
1460 let phone = phone.to_string();
1461 tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1462 let mut conn = pool
1463 .get()
1464 .map_err(|e| StoreError::Connection(e.to_string()))?;
1465 let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1466 .select((
1467 lid_pn_mapping::lid,
1468 lid_pn_mapping::phone_number,
1469 lid_pn_mapping::created_at,
1470 lid_pn_mapping::learning_source,
1471 lid_pn_mapping::updated_at,
1472 ))
1473 .filter(lid_pn_mapping::phone_number.eq(&phone))
1474 .filter(lid_pn_mapping::device_id.eq(device_id))
1475 .order(lid_pn_mapping::updated_at.desc())
1476 .first(&mut conn)
1477 .optional()
1478 .map_err(|e| StoreError::Database(e.to_string()))?;
1479 Ok(row.map(
1480 |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1481 lid,
1482 phone_number,
1483 created_at,
1484 updated_at,
1485 learning_source,
1486 },
1487 ))
1488 })
1489 .await
1490 .map_err(|e| StoreError::Database(e.to_string()))?
1491 }
1492
1493 async fn put_lid_mapping(&self, entry: &LidPnMappingEntry) -> Result<()> {
1494 let pool = self.pool.clone();
1495 let device_id = self.device_id;
1496 let entry = entry.clone();
1497 tokio::task::spawn_blocking(move || -> Result<()> {
1498 let mut conn = pool
1499 .get()
1500 .map_err(|e| StoreError::Connection(e.to_string()))?;
1501 diesel::insert_into(lid_pn_mapping::table)
1502 .values((
1503 lid_pn_mapping::lid.eq(&entry.lid),
1504 lid_pn_mapping::phone_number.eq(&entry.phone_number),
1505 lid_pn_mapping::created_at.eq(entry.created_at),
1506 lid_pn_mapping::learning_source.eq(&entry.learning_source),
1507 lid_pn_mapping::updated_at.eq(entry.updated_at),
1508 lid_pn_mapping::device_id.eq(device_id),
1509 ))
1510 .on_conflict((lid_pn_mapping::lid, lid_pn_mapping::device_id))
1511 .do_update()
1512 .set((
1513 lid_pn_mapping::phone_number.eq(&entry.phone_number),
1514 lid_pn_mapping::learning_source.eq(&entry.learning_source),
1515 lid_pn_mapping::updated_at.eq(entry.updated_at),
1516 ))
1517 .execute(&mut conn)
1518 .map_err(|e| StoreError::Database(e.to_string()))?;
1519 Ok(())
1520 })
1521 .await
1522 .map_err(|e| StoreError::Database(e.to_string()))??;
1523 Ok(())
1524 }
1525
1526 async fn get_all_lid_mappings(&self) -> Result<Vec<LidPnMappingEntry>> {
1527 let pool = self.pool.clone();
1528 let device_id = self.device_id;
1529 tokio::task::spawn_blocking(move || -> Result<Vec<LidPnMappingEntry>> {
1530 let mut conn = pool
1531 .get()
1532 .map_err(|e| StoreError::Connection(e.to_string()))?;
1533 let rows: Vec<(String, String, i64, String, i64)> = lid_pn_mapping::table
1534 .select((
1535 lid_pn_mapping::lid,
1536 lid_pn_mapping::phone_number,
1537 lid_pn_mapping::created_at,
1538 lid_pn_mapping::learning_source,
1539 lid_pn_mapping::updated_at,
1540 ))
1541 .filter(lid_pn_mapping::device_id.eq(device_id))
1542 .load(&mut conn)
1543 .map_err(|e| StoreError::Database(e.to_string()))?;
1544 Ok(rows
1545 .into_iter()
1546 .map(
1547 |(lid, phone_number, created_at, learning_source, updated_at)| {
1548 LidPnMappingEntry {
1549 lid,
1550 phone_number,
1551 created_at,
1552 updated_at,
1553 learning_source,
1554 }
1555 },
1556 )
1557 .collect())
1558 })
1559 .await
1560 .map_err(|e| StoreError::Database(e.to_string()))?
1561 }
1562
1563 async fn save_base_key(&self, address: &str, message_id: &str, base_key: &[u8]) -> Result<()> {
1564 let pool = self.pool.clone();
1565 let device_id = self.device_id;
1566 let address = address.to_string();
1567 let message_id = message_id.to_string();
1568 let base_key = base_key.to_vec();
1569 let now = std::time::SystemTime::now()
1570 .duration_since(std::time::UNIX_EPOCH)
1571 .unwrap_or_default()
1572 .as_secs() as i32;
1573 tokio::task::spawn_blocking(move || -> Result<()> {
1574 let mut conn = pool
1575 .get()
1576 .map_err(|e| StoreError::Connection(e.to_string()))?;
1577 diesel::insert_into(base_keys::table)
1578 .values((
1579 base_keys::address.eq(&address),
1580 base_keys::message_id.eq(&message_id),
1581 base_keys::base_key.eq(&base_key),
1582 base_keys::device_id.eq(device_id),
1583 base_keys::created_at.eq(now),
1584 ))
1585 .on_conflict((
1586 base_keys::address,
1587 base_keys::message_id,
1588 base_keys::device_id,
1589 ))
1590 .do_update()
1591 .set(base_keys::base_key.eq(&base_key))
1592 .execute(&mut conn)
1593 .map_err(|e| StoreError::Database(e.to_string()))?;
1594 Ok(())
1595 })
1596 .await
1597 .map_err(|e| StoreError::Database(e.to_string()))??;
1598 Ok(())
1599 }
1600
1601 async fn has_same_base_key(
1602 &self,
1603 address: &str,
1604 message_id: &str,
1605 current_base_key: &[u8],
1606 ) -> Result<bool> {
1607 let pool = self.pool.clone();
1608 let device_id = self.device_id;
1609 let address = address.to_string();
1610 let message_id = message_id.to_string();
1611 let current_base_key = current_base_key.to_vec();
1612 tokio::task::spawn_blocking(move || -> Result<bool> {
1613 let mut conn = pool
1614 .get()
1615 .map_err(|e| StoreError::Connection(e.to_string()))?;
1616 let stored_key: Option<Vec<u8>> = base_keys::table
1617 .select(base_keys::base_key)
1618 .filter(base_keys::address.eq(&address))
1619 .filter(base_keys::message_id.eq(&message_id))
1620 .filter(base_keys::device_id.eq(device_id))
1621 .first(&mut conn)
1622 .optional()
1623 .map_err(|e| StoreError::Database(e.to_string()))?;
1624 Ok(stored_key.as_ref() == Some(¤t_base_key))
1625 })
1626 .await
1627 .map_err(|e| StoreError::Database(e.to_string()))?
1628 }
1629
1630 async fn delete_base_key(&self, address: &str, message_id: &str) -> Result<()> {
1631 let pool = self.pool.clone();
1632 let device_id = self.device_id;
1633 let address = address.to_string();
1634 let message_id = message_id.to_string();
1635 tokio::task::spawn_blocking(move || -> Result<()> {
1636 let mut conn = pool
1637 .get()
1638 .map_err(|e| StoreError::Connection(e.to_string()))?;
1639 diesel::delete(
1640 base_keys::table
1641 .filter(base_keys::address.eq(&address))
1642 .filter(base_keys::message_id.eq(&message_id))
1643 .filter(base_keys::device_id.eq(device_id)),
1644 )
1645 .execute(&mut conn)
1646 .map_err(|e| StoreError::Database(e.to_string()))?;
1647 Ok(())
1648 })
1649 .await
1650 .map_err(|e| StoreError::Database(e.to_string()))??;
1651 Ok(())
1652 }
1653
1654 async fn update_device_list(&self, record: DeviceListRecord) -> Result<()> {
1655 let pool = self.pool.clone();
1656 let device_id = self.device_id;
1657 let devices_json = serde_json::to_string(&record.devices)
1658 .map_err(|e| StoreError::Serialization(e.to_string()))?;
1659 let now = std::time::SystemTime::now()
1660 .duration_since(std::time::UNIX_EPOCH)
1661 .unwrap_or_default()
1662 .as_secs() as i32;
1663 tokio::task::spawn_blocking(move || -> Result<()> {
1664 let mut conn = pool
1665 .get()
1666 .map_err(|e| StoreError::Connection(e.to_string()))?;
1667 diesel::insert_into(device_registry::table)
1668 .values((
1669 device_registry::user_id.eq(&record.user),
1670 device_registry::devices_json.eq(&devices_json),
1671 device_registry::timestamp.eq(record.timestamp as i32),
1672 device_registry::phash.eq(&record.phash),
1673 device_registry::device_id.eq(device_id),
1674 device_registry::updated_at.eq(now),
1675 ))
1676 .on_conflict((device_registry::user_id, device_registry::device_id))
1677 .do_update()
1678 .set((
1679 device_registry::devices_json.eq(&devices_json),
1680 device_registry::timestamp.eq(record.timestamp as i32),
1681 device_registry::phash.eq(&record.phash),
1682 device_registry::updated_at.eq(now),
1683 ))
1684 .execute(&mut conn)
1685 .map_err(|e| StoreError::Database(e.to_string()))?;
1686 Ok(())
1687 })
1688 .await
1689 .map_err(|e| StoreError::Database(e.to_string()))??;
1690 Ok(())
1691 }
1692
1693 async fn get_devices(&self, user: &str) -> Result<Option<DeviceListRecord>> {
1694 let pool = self.pool.clone();
1695 let device_id = self.device_id;
1696 let user = user.to_string();
1697 tokio::task::spawn_blocking(move || -> Result<Option<DeviceListRecord>> {
1698 let mut conn = pool
1699 .get()
1700 .map_err(|e| StoreError::Connection(e.to_string()))?;
1701 let row: Option<(String, String, i32, Option<String>)> = device_registry::table
1702 .select((
1703 device_registry::user_id,
1704 device_registry::devices_json,
1705 device_registry::timestamp,
1706 device_registry::phash,
1707 ))
1708 .filter(device_registry::user_id.eq(&user))
1709 .filter(device_registry::device_id.eq(device_id))
1710 .first(&mut conn)
1711 .optional()
1712 .map_err(|e| StoreError::Database(e.to_string()))?;
1713 match row {
1714 Some((user, devices_json, timestamp, phash)) => {
1715 let devices: Vec<DeviceInfo> = serde_json::from_str(&devices_json)
1716 .map_err(|e| StoreError::Serialization(e.to_string()))?;
1717 Ok(Some(DeviceListRecord {
1718 user,
1719 devices,
1720 timestamp: timestamp as i64,
1721 phash,
1722 }))
1723 }
1724 None => Ok(None),
1725 }
1726 })
1727 .await
1728 .map_err(|e| StoreError::Database(e.to_string()))?
1729 }
1730
1731 async fn mark_forget_sender_key(&self, group_jid: &str, participant: &str) -> Result<()> {
1732 let pool = self.pool.clone();
1733 let device_id = self.device_id;
1734 let group_jid = group_jid.to_string();
1735 let participant = participant.to_string();
1736 let now = std::time::SystemTime::now()
1737 .duration_since(std::time::UNIX_EPOCH)
1738 .unwrap_or_default()
1739 .as_secs() as i32;
1740 tokio::task::spawn_blocking(move || -> Result<()> {
1741 let mut conn = pool
1742 .get()
1743 .map_err(|e| StoreError::Connection(e.to_string()))?;
1744 diesel::insert_into(sender_key_status::table)
1745 .values((
1746 sender_key_status::group_jid.eq(&group_jid),
1747 sender_key_status::participant.eq(&participant),
1748 sender_key_status::device_id.eq(device_id),
1749 sender_key_status::marked_at.eq(now),
1750 ))
1751 .on_conflict((
1752 sender_key_status::group_jid,
1753 sender_key_status::participant,
1754 sender_key_status::device_id,
1755 ))
1756 .do_update()
1757 .set(sender_key_status::marked_at.eq(now))
1758 .execute(&mut conn)
1759 .map_err(|e| StoreError::Database(e.to_string()))?;
1760 Ok(())
1761 })
1762 .await
1763 .map_err(|e| StoreError::Database(e.to_string()))??;
1764 Ok(())
1765 }
1766
1767 async fn consume_forget_marks(&self, group_jid: &str) -> Result<Vec<String>> {
1768 let pool = self.pool.clone();
1769 let device_id = self.device_id;
1770 let group_jid = group_jid.to_string();
1771 tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
1772 let mut conn = pool
1773 .get()
1774 .map_err(|e| StoreError::Connection(e.to_string()))?;
1775 let participants: Vec<String> = sender_key_status::table
1776 .select(sender_key_status::participant)
1777 .filter(sender_key_status::group_jid.eq(&group_jid))
1778 .filter(sender_key_status::device_id.eq(device_id))
1779 .load(&mut conn)
1780 .map_err(|e| StoreError::Database(e.to_string()))?;
1781 diesel::delete(
1782 sender_key_status::table
1783 .filter(sender_key_status::group_jid.eq(&group_jid))
1784 .filter(sender_key_status::device_id.eq(device_id)),
1785 )
1786 .execute(&mut conn)
1787 .map_err(|e| StoreError::Database(e.to_string()))?;
1788 Ok(participants)
1789 })
1790 .await
1791 .map_err(|e| StoreError::Database(e.to_string()))?
1792 }
1793
1794 async fn get_tc_token(&self, jid: &str) -> Result<Option<TcTokenEntry>> {
1795 let pool = self.pool.clone();
1796 let device_id = self.device_id;
1797 let jid = jid.to_string();
1798 tokio::task::spawn_blocking(move || -> Result<Option<TcTokenEntry>> {
1799 let mut conn = pool
1800 .get()
1801 .map_err(|e| StoreError::Connection(e.to_string()))?;
1802 let row: Option<(Vec<u8>, i64, Option<i64>)> = tc_tokens::table
1803 .select((
1804 tc_tokens::token,
1805 tc_tokens::token_timestamp,
1806 tc_tokens::sender_timestamp,
1807 ))
1808 .filter(tc_tokens::jid.eq(&jid))
1809 .filter(tc_tokens::device_id.eq(device_id))
1810 .first(&mut conn)
1811 .optional()
1812 .map_err(|e| StoreError::Database(e.to_string()))?;
1813 Ok(
1814 row.map(|(token, token_timestamp, sender_timestamp)| TcTokenEntry {
1815 token,
1816 token_timestamp,
1817 sender_timestamp,
1818 }),
1819 )
1820 })
1821 .await
1822 .map_err(|e| StoreError::Database(e.to_string()))?
1823 }
1824
1825 async fn put_tc_token(&self, jid: &str, entry: &TcTokenEntry) -> Result<()> {
1826 let pool = self.pool.clone();
1827 let device_id = self.device_id;
1828 let jid = jid.to_string();
1829 let entry = entry.clone();
1830 let now = std::time::SystemTime::now()
1831 .duration_since(std::time::UNIX_EPOCH)
1832 .unwrap_or_default()
1833 .as_secs() as i64;
1834 tokio::task::spawn_blocking(move || -> Result<()> {
1835 let mut conn = pool
1836 .get()
1837 .map_err(|e| StoreError::Connection(e.to_string()))?;
1838 diesel::insert_into(tc_tokens::table)
1839 .values((
1840 tc_tokens::jid.eq(&jid),
1841 tc_tokens::token.eq(&entry.token),
1842 tc_tokens::token_timestamp.eq(entry.token_timestamp),
1843 tc_tokens::sender_timestamp.eq(entry.sender_timestamp),
1844 tc_tokens::device_id.eq(device_id),
1845 tc_tokens::updated_at.eq(now),
1846 ))
1847 .on_conflict((tc_tokens::jid, tc_tokens::device_id))
1848 .do_update()
1849 .set((
1850 tc_tokens::token.eq(&entry.token),
1851 tc_tokens::token_timestamp.eq(entry.token_timestamp),
1852 tc_tokens::sender_timestamp.eq(entry.sender_timestamp),
1853 tc_tokens::updated_at.eq(now),
1854 ))
1855 .execute(&mut conn)
1856 .map_err(|e| StoreError::Database(e.to_string()))?;
1857 Ok(())
1858 })
1859 .await
1860 .map_err(|e| StoreError::Database(e.to_string()))??;
1861 Ok(())
1862 }
1863
1864 async fn delete_tc_token(&self, jid: &str) -> Result<()> {
1865 let pool = self.pool.clone();
1866 let device_id = self.device_id;
1867 let jid = jid.to_string();
1868 tokio::task::spawn_blocking(move || -> Result<()> {
1869 let mut conn = pool
1870 .get()
1871 .map_err(|e| StoreError::Connection(e.to_string()))?;
1872 diesel::delete(
1873 tc_tokens::table
1874 .filter(tc_tokens::jid.eq(&jid))
1875 .filter(tc_tokens::device_id.eq(device_id)),
1876 )
1877 .execute(&mut conn)
1878 .map_err(|e| StoreError::Database(e.to_string()))?;
1879 Ok(())
1880 })
1881 .await
1882 .map_err(|e| StoreError::Database(e.to_string()))??;
1883 Ok(())
1884 }
1885
1886 async fn get_all_tc_token_jids(&self) -> Result<Vec<String>> {
1887 let pool = self.pool.clone();
1888 let device_id = self.device_id;
1889 tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
1890 let mut conn = pool
1891 .get()
1892 .map_err(|e| StoreError::Connection(e.to_string()))?;
1893 let jids: Vec<String> = tc_tokens::table
1894 .select(tc_tokens::jid)
1895 .filter(tc_tokens::device_id.eq(device_id))
1896 .load(&mut conn)
1897 .map_err(|e| StoreError::Database(e.to_string()))?;
1898 Ok(jids)
1899 })
1900 .await
1901 .map_err(|e| StoreError::Database(e.to_string()))?
1902 }
1903
1904 async fn delete_expired_tc_tokens(&self, cutoff_timestamp: i64) -> Result<u32> {
1905 let pool = self.pool.clone();
1906 let device_id = self.device_id;
1907 tokio::task::spawn_blocking(move || -> Result<u32> {
1908 let mut conn = pool
1909 .get()
1910 .map_err(|e| StoreError::Connection(e.to_string()))?;
1911 let deleted = diesel::delete(
1912 tc_tokens::table
1913 .filter(tc_tokens::token_timestamp.lt(cutoff_timestamp))
1914 .filter(tc_tokens::device_id.eq(device_id)),
1915 )
1916 .execute(&mut conn)
1917 .map_err(|e| StoreError::Database(e.to_string()))?;
1918 Ok(deleted as u32)
1919 })
1920 .await
1921 .map_err(|e| StoreError::Database(e.to_string()))?
1922 }
1923}
1924
1925#[async_trait]
1926impl DeviceStore for SqliteStore {
1927 async fn save(&self, device: &CoreDevice) -> Result<()> {
1928 SqliteStore::save_device_data_for_device(self, self.device_id, device).await
1929 }
1930
1931 async fn load(&self) -> Result<Option<CoreDevice>> {
1932 SqliteStore::load_device_data_for_device(self, self.device_id).await
1933 }
1934
1935 async fn exists(&self) -> Result<bool> {
1936 SqliteStore::device_exists(self, self.device_id).await
1937 }
1938
1939 async fn create(&self) -> Result<i32> {
1940 SqliteStore::create_new_device(self).await
1941 }
1942
1943 async fn snapshot_db(&self, name: &str, extra_content: Option<&[u8]>) -> Result<()> {
1944 fn sanitize_snapshot_name(name: &str) -> Result<String> {
1945 const MAX_LENGTH: usize = 100;
1946
1947 let sanitized: String = name
1948 .chars()
1949 .map(|c| {
1950 if c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.' {
1951 c
1952 } else {
1953 '_'
1954 }
1955 })
1956 .collect();
1957
1958 let sanitized = sanitized
1959 .split('.')
1960 .filter(|part| !part.is_empty() && *part != "..")
1961 .collect::<Vec<_>>()
1962 .join(".");
1963
1964 let sanitized = sanitized.trim_matches(['/', '\\', '.']);
1965
1966 if sanitized.is_empty() {
1967 return Err(StoreError::Database(
1968 "Snapshot name cannot be empty after sanitization".to_string(),
1969 ));
1970 }
1971
1972 if sanitized.len() > MAX_LENGTH {
1973 return Err(StoreError::Database(format!(
1974 "Snapshot name exceeds maximum length of {} characters",
1975 MAX_LENGTH
1976 )));
1977 }
1978
1979 Ok(sanitized.to_string())
1980 }
1981
1982 let sanitized_name = sanitize_snapshot_name(name)?;
1983
1984 let pool = self.pool.clone();
1985 let db_path = self.database_path.clone();
1986 let extra_data = extra_content.map(|b| b.to_vec());
1987
1988 tokio::task::spawn_blocking(move || -> Result<()> {
1989 let mut conn = pool
1990 .get()
1991 .map_err(|e| StoreError::Connection(e.to_string()))?;
1992
1993 let timestamp = std::time::SystemTime::now()
1994 .duration_since(std::time::UNIX_EPOCH)
1995 .unwrap_or_default()
1996 .as_secs();
1997
1998 let target_path = format!("{}.snapshot-{}-{}", db_path, timestamp, sanitized_name);
2000
2001 let query = format!("VACUUM INTO '{}'", target_path.replace("'", "''"));
2004
2005 diesel::sql_query(query)
2006 .execute(&mut conn)
2007 .map_err(|e| StoreError::Database(e.to_string()))?;
2008
2009 if let Some(data) = extra_data {
2011 let extra_path = format!("{}.json", target_path);
2012 std::fs::write(&extra_path, data).map_err(|e| {
2013 StoreError::Database(format!("Failed to write snapshot extra content: {}", e))
2014 })?;
2015 }
2016
2017 Ok(())
2018 })
2019 .await
2020 .map_err(|e| StoreError::Database(e.to_string()))??;
2021
2022 Ok(())
2023 }
2024}
2025
2026#[cfg(test)]
2027mod tests {
2028 use super::*;
2029
2030 async fn create_test_store() -> SqliteStore {
2031 use std::time::{SystemTime, UNIX_EPOCH};
2032 let timestamp = SystemTime::now()
2033 .duration_since(UNIX_EPOCH)
2034 .unwrap_or_default()
2035 .as_nanos();
2036 let db_name = format!("file:memdb_test_{}?mode=memory&cache=shared", timestamp);
2037 SqliteStore::new(&db_name)
2038 .await
2039 .expect("Failed to create test store")
2040 }
2041
2042 #[test]
2043 fn test_parse_database_path_regular_path() {
2044 let path = "/var/lib/whatsapp/database.db";
2045 let result = parse_database_path(path).unwrap();
2046 assert_eq!(result, "/var/lib/whatsapp/database.db");
2047 }
2048
2049 #[test]
2050 fn test_parse_database_path_with_sqlite_prefix() {
2051 let path = "sqlite:///var/lib/whatsapp/database.db";
2052 let result = parse_database_path(path).unwrap();
2053 assert_eq!(result, "/var/lib/whatsapp/database.db");
2054 }
2055
/// Everything from the `?` onwards is expected to be dropped from the path.
#[test]
fn test_parse_database_path_with_query_params() {
    let parsed = parse_database_path("file:database.db?mode=memory&cache=shared").unwrap();
    assert_eq!(parsed, "file:database.db");
}
2062
/// A trailing `#fragment` is expected to be dropped from the path.
#[test]
fn test_parse_database_path_with_fragment() {
    let parsed = parse_database_path("file:database.db#fragment").unwrap();
    assert_eq!(parsed, "file:database.db");
}
2069
/// With a `sqlite://` prefix, a query string, and a fragment all present,
/// only the filesystem path is expected to remain.
#[test]
fn test_parse_database_path_with_both_query_and_fragment() {
    let input = "sqlite:///var/lib/database.db?mode=ro#backup";
    let expected = "/var/lib/database.db";

    let parsed = parse_database_path(input).unwrap();
    assert_eq!(parsed, expected);
}
2076
/// The bare `:memory:` identifier must be refused, with an error message
/// that mentions "not supported".
#[test]
fn test_parse_database_path_in_memory_rejected() {
    let outcome = parse_database_path(":memory:");
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("not supported"));
}
2083
/// `:memory:` followed by query parameters must be refused as well, with an
/// error message that mentions "not supported".
#[test]
fn test_parse_database_path_in_memory_with_query_rejected() {
    let outcome = parse_database_path(":memory:?cache=shared");
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("not supported"));
}
2090
2091 #[tokio::test]
2092 async fn test_device_registry_save_and_get() {
2093 let store = create_test_store().await;
2094
2095 let record = DeviceListRecord {
2096 user: "1234567890".to_string(),
2097 devices: vec![
2098 DeviceInfo {
2099 device_id: 0,
2100 key_index: None,
2101 },
2102 DeviceInfo {
2103 device_id: 1,
2104 key_index: Some(42),
2105 },
2106 ],
2107 timestamp: 1234567890,
2108 phash: Some("2:abcdef".to_string()),
2109 };
2110
2111 store.update_device_list(record).await.expect("save failed");
2112 let loaded = store
2113 .get_devices("1234567890")
2114 .await
2115 .expect("get failed")
2116 .expect("record should exist");
2117
2118 assert_eq!(loaded.user, "1234567890");
2119 assert_eq!(loaded.devices.len(), 2);
2120 assert_eq!(loaded.devices[0].device_id, 0);
2121 assert_eq!(loaded.devices[1].device_id, 1);
2122 assert_eq!(loaded.devices[1].key_index, Some(42));
2123 assert_eq!(loaded.phash, Some("2:abcdef".to_string()));
2124 }
2125
2126 #[tokio::test]
2127 async fn test_device_registry_update_existing() {
2128 let store = create_test_store().await;
2129
2130 let record1 = DeviceListRecord {
2131 user: "1234567890".to_string(),
2132 devices: vec![DeviceInfo {
2133 device_id: 0,
2134 key_index: None,
2135 }],
2136 timestamp: 1000,
2137 phash: Some("2:old".to_string()),
2138 };
2139 store
2140 .update_device_list(record1)
2141 .await
2142 .expect("save1 failed");
2143
2144 let record2 = DeviceListRecord {
2145 user: "1234567890".to_string(),
2146 devices: vec![
2147 DeviceInfo {
2148 device_id: 0,
2149 key_index: None,
2150 },
2151 DeviceInfo {
2152 device_id: 2,
2153 key_index: None,
2154 },
2155 ],
2156 timestamp: 2000,
2157 phash: Some("2:new".to_string()),
2158 };
2159 store
2160 .update_device_list(record2)
2161 .await
2162 .expect("save2 failed");
2163
2164 let loaded = store
2165 .get_devices("1234567890")
2166 .await
2167 .expect("get failed")
2168 .expect("record should exist");
2169
2170 assert_eq!(loaded.devices.len(), 2);
2171 assert_eq!(loaded.phash, Some("2:new".to_string()));
2172 }
2173
2174 #[tokio::test]
2175 async fn test_device_registry_get_nonexistent() {
2176 let store = create_test_store().await;
2177 let result = store.get_devices("nonexistent").await.expect("get failed");
2178 assert!(result.is_none());
2179 }
2180
2181 #[tokio::test]
2182 async fn test_sender_key_status_mark_and_consume() {
2183 let store = create_test_store().await;
2184
2185 let group = "group123@g.us";
2186 let participant = "user1@s.whatsapp.net";
2187
2188 store
2189 .mark_forget_sender_key(group, participant)
2190 .await
2191 .expect("mark failed");
2192
2193 let consumed = store
2194 .consume_forget_marks(group)
2195 .await
2196 .expect("consume failed");
2197 assert_eq!(consumed.len(), 1);
2198 assert!(consumed.contains(&participant.to_string()));
2199
2200 let consumed = store
2201 .consume_forget_marks(group)
2202 .await
2203 .expect("consume failed");
2204 assert!(consumed.is_empty());
2205 }
2206
2207 #[tokio::test]
2208 async fn test_sender_key_status_consume_multiple() {
2209 let store = create_test_store().await;
2210
2211 let group = "group123@g.us";
2212
2213 store
2214 .mark_forget_sender_key(group, "user1@s.whatsapp.net")
2215 .await
2216 .expect("mark failed");
2217 store
2218 .mark_forget_sender_key(group, "user2@s.whatsapp.net")
2219 .await
2220 .expect("mark failed");
2221
2222 let consumed = store
2223 .consume_forget_marks(group)
2224 .await
2225 .expect("consume failed");
2226 assert_eq!(consumed.len(), 2);
2227 assert!(consumed.contains(&"user1@s.whatsapp.net".to_string()));
2228 assert!(consumed.contains(&"user2@s.whatsapp.net".to_string()));
2229
2230 let consumed = store
2231 .consume_forget_marks(group)
2232 .await
2233 .expect("consume failed");
2234 assert!(consumed.is_empty());
2235 }
2236
2237 #[tokio::test]
2238 async fn test_tc_token_put_and_get() {
2239 let store = create_test_store().await;
2240
2241 let entry = TcTokenEntry {
2242 token: vec![1, 2, 3, 4, 5],
2243 token_timestamp: 1707000000,
2244 sender_timestamp: Some(1707000100),
2245 };
2246
2247 store
2248 .put_tc_token("user@lid", &entry)
2249 .await
2250 .expect("put failed");
2251
2252 let loaded = store
2253 .get_tc_token("user@lid")
2254 .await
2255 .expect("get failed")
2256 .expect("should exist");
2257
2258 assert_eq!(loaded.token, vec![1, 2, 3, 4, 5]);
2259 assert_eq!(loaded.token_timestamp, 1707000000);
2260 assert_eq!(loaded.sender_timestamp, Some(1707000100));
2261 }
2262
2263 #[tokio::test]
2264 async fn test_tc_token_upsert() {
2265 let store = create_test_store().await;
2266
2267 let entry1 = TcTokenEntry {
2268 token: vec![1, 2, 3],
2269 token_timestamp: 1000,
2270 sender_timestamp: None,
2271 };
2272 store.put_tc_token("user@lid", &entry1).await.unwrap();
2273
2274 let entry2 = TcTokenEntry {
2275 token: vec![4, 5, 6],
2276 token_timestamp: 2000,
2277 sender_timestamp: Some(1500),
2278 };
2279 store.put_tc_token("user@lid", &entry2).await.unwrap();
2280
2281 let loaded = store.get_tc_token("user@lid").await.unwrap().unwrap();
2282 assert_eq!(loaded.token, vec![4, 5, 6]);
2283 assert_eq!(loaded.token_timestamp, 2000);
2284 assert_eq!(loaded.sender_timestamp, Some(1500));
2285 }
2286
2287 #[tokio::test]
2288 async fn test_tc_token_delete() {
2289 let store = create_test_store().await;
2290
2291 let entry = TcTokenEntry {
2292 token: vec![1, 2, 3],
2293 token_timestamp: 1000,
2294 sender_timestamp: None,
2295 };
2296 store.put_tc_token("user@lid", &entry).await.unwrap();
2297 store.delete_tc_token("user@lid").await.unwrap();
2298
2299 let result = store.get_tc_token("user@lid").await.unwrap();
2300 assert!(result.is_none());
2301 }
2302
2303 #[tokio::test]
2304 async fn test_tc_token_get_all_jids() {
2305 let store = create_test_store().await;
2306
2307 let entry = TcTokenEntry {
2308 token: vec![1],
2309 token_timestamp: 1000,
2310 sender_timestamp: None,
2311 };
2312 store.put_tc_token("user1@lid", &entry).await.unwrap();
2313 store.put_tc_token("user2@lid", &entry).await.unwrap();
2314 store.put_tc_token("user3@lid", &entry).await.unwrap();
2315
2316 let mut jids = store.get_all_tc_token_jids().await.unwrap();
2317 jids.sort();
2318 assert_eq!(jids, vec!["user1@lid", "user2@lid", "user3@lid"]);
2319 }
2320
2321 #[tokio::test]
2322 async fn test_tc_token_delete_expired() {
2323 let store = create_test_store().await;
2324
2325 let old = TcTokenEntry {
2326 token: vec![1],
2327 token_timestamp: 1000,
2328 sender_timestamp: None,
2329 };
2330 let recent = TcTokenEntry {
2331 token: vec![2],
2332 token_timestamp: 5000,
2333 sender_timestamp: None,
2334 };
2335 store.put_tc_token("old@lid", &old).await.unwrap();
2336 store.put_tc_token("recent@lid", &recent).await.unwrap();
2337
2338 let deleted = store.delete_expired_tc_tokens(3000).await.unwrap();
2339 assert_eq!(deleted, 1);
2340
2341 assert!(store.get_tc_token("old@lid").await.unwrap().is_none());
2342 assert!(store.get_tc_token("recent@lid").await.unwrap().is_some());
2343 }
2344
2345 #[tokio::test]
2346 async fn test_tc_token_get_nonexistent() {
2347 let store = create_test_store().await;
2348 let result = store.get_tc_token("nonexistent@lid").await.unwrap();
2349 assert!(result.is_none());
2350 }
2351
2352 #[tokio::test]
2353 async fn test_sender_key_status_different_groups() {
2354 let store = create_test_store().await;
2355
2356 let group1 = "group1@g.us";
2357 let group2 = "group2@g.us";
2358 let participant = "user@s.whatsapp.net";
2359
2360 store
2361 .mark_forget_sender_key(group1, participant)
2362 .await
2363 .expect("mark failed");
2364
2365 let consumed = store.consume_forget_marks(group1).await.unwrap();
2366 assert_eq!(consumed.len(), 1);
2367
2368 let consumed = store.consume_forget_marks(group2).await.unwrap();
2369 assert!(consumed.is_empty());
2370 }
2371}