1use crate::schema::*;
2use async_trait::async_trait;
3use diesel::prelude::*;
4use diesel::r2d2::{ConnectionManager, Pool};
5use diesel::sql_query;
6use diesel::sqlite::SqliteConnection;
7use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
8use log::warn;
9use prost::Message;
10use std::sync::Arc;
11use wacore::appstate::hash::HashState;
12use wacore::appstate::processor::AppStateMutationMAC;
13use wacore::libsignal::protocol::{KeyPair, PrivateKey, PublicKey};
14use wacore::store::Device as CoreDevice;
15use wacore::store::error::{Result, StoreError};
16use wacore::store::traits::*;
17use waproto::whatsapp as wa;
18
19const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
20
21type SqlitePool = Pool<ConnectionManager<SqliteConnection>>;
22type DeviceRow = (
23 i32,
24 String,
25 String,
26 i32,
27 Vec<u8>,
28 Vec<u8>,
29 Vec<u8>,
30 i32,
31 Vec<u8>,
32 Vec<u8>,
33 Option<Vec<u8>>,
34 String,
35 i32,
36 i32,
37 i64,
38 i64,
39 Option<Vec<u8>>,
40);
41
42#[derive(Clone)]
43pub struct SqliteStore {
44 pub(crate) pool: SqlitePool,
45 pub(crate) db_semaphore: Arc<tokio::sync::Semaphore>,
46 device_id: i32,
47}
48
49#[derive(Debug, Clone, Copy)]
50struct ConnectionOptions;
51
52impl diesel::r2d2::CustomizeConnection<SqliteConnection, diesel::r2d2::Error>
53 for ConnectionOptions
54{
55 fn on_acquire(
56 &self,
57 conn: &mut SqliteConnection,
58 ) -> std::result::Result<(), diesel::r2d2::Error> {
59 diesel::sql_query("PRAGMA busy_timeout = 30000;")
60 .execute(conn)
61 .map_err(diesel::r2d2::Error::QueryError)?;
62 diesel::sql_query("PRAGMA synchronous = NORMAL;")
63 .execute(conn)
64 .map_err(diesel::r2d2::Error::QueryError)?;
65 diesel::sql_query("PRAGMA cache_size = 512;")
66 .execute(conn)
67 .map_err(diesel::r2d2::Error::QueryError)?;
68 diesel::sql_query("PRAGMA temp_store = memory;")
69 .execute(conn)
70 .map_err(diesel::r2d2::Error::QueryError)?;
71 diesel::sql_query("PRAGMA foreign_keys = ON;")
72 .execute(conn)
73 .map_err(diesel::r2d2::Error::QueryError)?;
74 Ok(())
75 }
76}
77
78impl SqliteStore {
79 pub async fn new(database_url: &str) -> std::result::Result<Self, StoreError> {
80 let manager = ConnectionManager::<SqliteConnection>::new(database_url);
81
82 let pool_size = 2;
83
84 let pool = Pool::builder()
85 .max_size(pool_size)
86 .connection_customizer(Box::new(ConnectionOptions))
87 .build(manager)
88 .map_err(|e| StoreError::Connection(e.to_string()))?;
89
90 let pool_clone = pool.clone();
91 tokio::task::spawn_blocking(move || -> std::result::Result<(), StoreError> {
92 let mut conn = pool_clone
93 .get()
94 .map_err(|e| StoreError::Connection(e.to_string()))?;
95
96 diesel::sql_query("PRAGMA journal_mode = WAL;")
97 .execute(&mut conn)
98 .map_err(|e| StoreError::Database(e.to_string()))?;
99
100 conn.run_pending_migrations(MIGRATIONS)
101 .map_err(|e| StoreError::Migration(e.to_string()))?;
102
103 Ok(())
104 })
105 .await
106 .map_err(|e| StoreError::Database(e.to_string()))??;
107
108 Ok(Self {
109 pool,
110 db_semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
111 device_id: 1,
112 })
113 }
114
115 pub async fn new_for_device(
116 database_url: &str,
117 device_id: i32,
118 ) -> std::result::Result<Self, StoreError> {
119 let mut store = Self::new(database_url).await?;
120 store.device_id = device_id;
121 Ok(store)
122 }
123
124 pub fn device_id(&self) -> i32 {
125 self.device_id
126 }
127
128 async fn with_semaphore<F, T>(&self, f: F) -> Result<T>
129 where
130 F: FnOnce() -> Result<T> + Send + 'static,
131 T: Send + 'static,
132 {
133 let permit = self
134 .db_semaphore
135 .clone()
136 .acquire_owned()
137 .await
138 .map_err(|e| StoreError::Database(format!("Semaphore error: {}", e)))?;
139 let result = tokio::task::spawn_blocking(move || {
140 let res = f();
141 drop(permit);
142 res
143 })
144 .await
145 .map_err(|e| StoreError::Database(e.to_string()))??;
146 Ok(result)
147 }
148
149 fn serialize_keypair(&self, key_pair: &KeyPair) -> Result<Vec<u8>> {
150 let mut bytes = Vec::with_capacity(64);
151 bytes.extend_from_slice(&key_pair.private_key.serialize());
152 bytes.extend_from_slice(key_pair.public_key.public_key_bytes());
153 Ok(bytes)
154 }
155
156 fn deserialize_keypair(&self, bytes: &[u8]) -> Result<KeyPair> {
157 if bytes.len() != 64 {
158 return Err(StoreError::Serialization(format!(
159 "Invalid KeyPair length: {}",
160 bytes.len()
161 )));
162 }
163
164 let private_key = PrivateKey::deserialize(&bytes[0..32])
165 .map_err(|e| StoreError::Serialization(e.to_string()))?;
166 let public_key = PublicKey::from_djb_public_key_bytes(&bytes[32..64])
167 .map_err(|e| StoreError::Serialization(e.to_string()))?;
168
169 Ok(KeyPair::new(public_key, private_key))
170 }
171
172 pub async fn save_device_data_for_device(
173 &self,
174 device_id: i32,
175 device_data: &CoreDevice,
176 ) -> Result<()> {
177 let pool = self.pool.clone();
178 let noise_key_data = self.serialize_keypair(&device_data.noise_key)?;
179 let identity_key_data = self.serialize_keypair(&device_data.identity_key)?;
180 let signed_pre_key_data = self.serialize_keypair(&device_data.signed_pre_key)?;
181 let account_data = device_data
182 .account
183 .as_ref()
184 .map(|account| account.encode_to_vec());
185 let registration_id = device_data.registration_id as i32;
186 let signed_pre_key_id = device_data.signed_pre_key_id as i32;
187 let signed_pre_key_signature: Vec<u8> = device_data.signed_pre_key_signature.to_vec();
188 let adv_secret_key: Vec<u8> = device_data.adv_secret_key.to_vec();
189 let push_name = device_data.push_name.clone();
190 let app_version_primary = device_data.app_version_primary as i32;
191 let app_version_secondary = device_data.app_version_secondary as i32;
192 let app_version_tertiary = device_data.app_version_tertiary as i64;
193 let app_version_last_fetched_ms = device_data.app_version_last_fetched_ms;
194 let edge_routing_info = device_data.edge_routing_info.clone();
195 let new_lid = device_data
196 .lid
197 .as_ref()
198 .map(|j| j.to_string())
199 .unwrap_or_default();
200 let new_pn = device_data
201 .pn
202 .as_ref()
203 .map(|j| j.to_string())
204 .unwrap_or_default();
205
206 tokio::task::spawn_blocking(move || -> Result<()> {
207 let mut conn = pool
208 .get()
209 .map_err(|e| StoreError::Connection(e.to_string()))?;
210
211 diesel::insert_into(device::table)
212 .values((
213 device::id.eq(device_id),
214 device::lid.eq(&new_lid),
215 device::pn.eq(&new_pn),
216 device::registration_id.eq(registration_id),
217 device::noise_key.eq(&noise_key_data),
218 device::identity_key.eq(&identity_key_data),
219 device::signed_pre_key.eq(&signed_pre_key_data),
220 device::signed_pre_key_id.eq(signed_pre_key_id),
221 device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
222 device::adv_secret_key.eq(&adv_secret_key[..]),
223 device::account.eq(account_data.clone()),
224 device::push_name.eq(&push_name),
225 device::app_version_primary.eq(app_version_primary),
226 device::app_version_secondary.eq(app_version_secondary),
227 device::app_version_tertiary.eq(app_version_tertiary),
228 device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
229 device::edge_routing_info.eq(edge_routing_info.clone()),
230 ))
231 .on_conflict(device::id)
232 .do_update()
233 .set((
234 device::lid.eq(&new_lid),
235 device::pn.eq(&new_pn),
236 device::registration_id.eq(registration_id),
237 device::noise_key.eq(&noise_key_data),
238 device::identity_key.eq(&identity_key_data),
239 device::signed_pre_key.eq(&signed_pre_key_data),
240 device::signed_pre_key_id.eq(signed_pre_key_id),
241 device::signed_pre_key_signature.eq(&signed_pre_key_signature[..]),
242 device::adv_secret_key.eq(&adv_secret_key[..]),
243 device::account.eq(account_data.clone()),
244 device::push_name.eq(&push_name),
245 device::app_version_primary.eq(app_version_primary),
246 device::app_version_secondary.eq(app_version_secondary),
247 device::app_version_tertiary.eq(app_version_tertiary),
248 device::app_version_last_fetched_ms.eq(app_version_last_fetched_ms),
249 device::edge_routing_info.eq(edge_routing_info),
250 ))
251 .execute(&mut conn)
252 .map_err(|e| StoreError::Database(e.to_string()))?;
253
254 Ok(())
255 })
256 .await
257 .map_err(|e| StoreError::Database(e.to_string()))??;
258
259 Ok(())
260 }
261
262 pub async fn create_new_device(&self) -> Result<i32> {
263 use crate::schema::device;
264
265 let pool = self.pool.clone();
266 tokio::task::spawn_blocking(move || -> Result<i32> {
267 let mut conn = pool
268 .get()
269 .map_err(|e| StoreError::Connection(e.to_string()))?;
270
271 let new_device = wacore::store::Device::new();
272
273 let noise_key_data = {
274 let mut bytes = Vec::with_capacity(64);
275 bytes.extend_from_slice(&new_device.noise_key.private_key.serialize());
276 bytes.extend_from_slice(new_device.noise_key.public_key.public_key_bytes());
277 bytes
278 };
279 let identity_key_data = {
280 let mut bytes = Vec::with_capacity(64);
281 bytes.extend_from_slice(&new_device.identity_key.private_key.serialize());
282 bytes.extend_from_slice(new_device.identity_key.public_key.public_key_bytes());
283 bytes
284 };
285 let signed_pre_key_data = {
286 let mut bytes = Vec::with_capacity(64);
287 bytes.extend_from_slice(&new_device.signed_pre_key.private_key.serialize());
288 bytes.extend_from_slice(new_device.signed_pre_key.public_key.public_key_bytes());
289 bytes
290 };
291
292 diesel::insert_into(device::table)
293 .values((
294 device::lid.eq(""),
295 device::pn.eq(""),
296 device::registration_id.eq(new_device.registration_id as i32),
297 device::noise_key.eq(&noise_key_data),
298 device::identity_key.eq(&identity_key_data),
299 device::signed_pre_key.eq(&signed_pre_key_data),
300 device::signed_pre_key_id.eq(new_device.signed_pre_key_id as i32),
301 device::signed_pre_key_signature.eq(&new_device.signed_pre_key_signature[..]),
302 device::adv_secret_key.eq(&new_device.adv_secret_key[..]),
303 device::account.eq(None::<Vec<u8>>),
304 device::push_name.eq(&new_device.push_name),
305 device::app_version_primary.eq(new_device.app_version_primary as i32),
306 device::app_version_secondary.eq(new_device.app_version_secondary as i32),
307 device::app_version_tertiary.eq(new_device.app_version_tertiary as i64),
308 device::app_version_last_fetched_ms.eq(new_device.app_version_last_fetched_ms),
309 device::edge_routing_info.eq(None::<Vec<u8>>),
310 ))
311 .execute(&mut conn)
312 .map_err(|e| StoreError::Database(e.to_string()))?;
313
314 use diesel::sql_types::Integer;
315
316 #[derive(QueryableByName)]
317 struct LastInsertedId {
318 #[diesel(sql_type = Integer)]
319 last_insert_rowid: i32,
320 }
321
322 let device_id: i32 = sql_query("SELECT last_insert_rowid() as last_insert_rowid")
323 .get_result::<LastInsertedId>(&mut conn)
324 .map_err(|e| StoreError::Database(e.to_string()))?
325 .last_insert_rowid;
326
327 Ok(device_id)
328 })
329 .await
330 .map_err(|e| StoreError::Database(e.to_string()))?
331 }
332
333 pub async fn device_exists(&self, device_id: i32) -> Result<bool> {
334 use crate::schema::device;
335
336 let pool = self.pool.clone();
337 tokio::task::spawn_blocking(move || -> Result<bool> {
338 let mut conn = pool
339 .get()
340 .map_err(|e| StoreError::Connection(e.to_string()))?;
341
342 let count: i64 = device::table
343 .filter(device::id.eq(device_id))
344 .count()
345 .get_result(&mut conn)
346 .map_err(|e| StoreError::Database(e.to_string()))?;
347
348 Ok(count > 0)
349 })
350 .await
351 .map_err(|e| StoreError::Database(e.to_string()))?
352 }
353
354 pub async fn load_device_data_for_device(&self, device_id: i32) -> Result<Option<CoreDevice>> {
355 use crate::schema::device;
356
357 let pool = self.pool.clone();
358 let row = tokio::task::spawn_blocking(move || -> Result<Option<DeviceRow>> {
359 let mut conn = pool
360 .get()
361 .map_err(|e| StoreError::Connection(e.to_string()))?;
362 let result = device::table
363 .filter(device::id.eq(device_id))
364 .first::<DeviceRow>(&mut conn)
365 .optional()
366 .map_err(|e| StoreError::Database(e.to_string()))?;
367 Ok(result)
368 })
369 .await
370 .map_err(|e| StoreError::Database(e.to_string()))??;
371
372 if let Some((
373 _device_id,
374 lid_str,
375 pn_str,
376 registration_id,
377 noise_key_data,
378 identity_key_data,
379 signed_pre_key_data,
380 signed_pre_key_id,
381 signed_pre_key_signature_data,
382 adv_secret_key_data,
383 account_data,
384 push_name,
385 app_version_primary,
386 app_version_secondary,
387 app_version_tertiary,
388 app_version_last_fetched_ms,
389 edge_routing_info,
390 )) = row
391 {
392 let id = if !pn_str.is_empty() {
393 pn_str.parse().ok()
394 } else {
395 None
396 };
397 let lid = if !lid_str.is_empty() {
398 lid_str.parse().ok()
399 } else {
400 None
401 };
402
403 let noise_key = self.deserialize_keypair(&noise_key_data)?;
404 let identity_key = self.deserialize_keypair(&identity_key_data)?;
405 let signed_pre_key = self.deserialize_keypair(&signed_pre_key_data)?;
406
407 let signed_pre_key_signature: [u8; 64] =
408 signed_pre_key_signature_data.try_into().map_err(|_| {
409 StoreError::Serialization("Invalid signed_pre_key_signature length".to_string())
410 })?;
411
412 let adv_secret_key: [u8; 32] = adv_secret_key_data.try_into().map_err(|_| {
413 StoreError::Serialization("Invalid adv_secret_key length".to_string())
414 })?;
415
416 let account = account_data
417 .map(|data| {
418 wa::AdvSignedDeviceIdentity::decode(&data[..])
419 .map_err(|e| StoreError::Serialization(e.to_string()))
420 })
421 .transpose()?;
422
423 Ok(Some(CoreDevice {
424 pn: id,
425 lid,
426 registration_id: registration_id as u32,
427 noise_key,
428 identity_key,
429 signed_pre_key,
430 signed_pre_key_id: signed_pre_key_id as u32,
431 signed_pre_key_signature,
432 adv_secret_key,
433 account,
434 push_name,
435 app_version_primary: app_version_primary as u32,
436 app_version_secondary: app_version_secondary as u32,
437 app_version_tertiary: app_version_tertiary.try_into().unwrap_or(0u32),
438 app_version_last_fetched_ms,
439 device_props: {
440 use wacore::store::device::DEVICE_PROPS;
441 DEVICE_PROPS.clone()
442 },
443 edge_routing_info,
444 }))
445 } else {
446 Ok(None)
447 }
448 }
449
450 pub async fn put_identity_for_device(
451 &self,
452 address: &str,
453 key: [u8; 32],
454 device_id: i32,
455 ) -> Result<()> {
456 let pool = self.pool.clone();
457 let db_semaphore = self.db_semaphore.clone();
458 let address_owned = address.to_string();
459 let key_vec = key.to_vec();
460
461 const MAX_RETRIES: u32 = 5;
462
463 for attempt in 0..=MAX_RETRIES {
464 let permit =
465 db_semaphore.clone().acquire_owned().await.map_err(|e| {
466 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
467 })?;
468
469 let pool_clone = pool.clone();
470 let address_clone = address_owned.clone();
471 let key_clone = key_vec.clone();
472
473 let result = tokio::task::spawn_blocking(move || -> Result<()> {
474 let mut conn = pool_clone
475 .get()
476 .map_err(|e| StoreError::Connection(e.to_string()))?;
477 diesel::insert_into(identities::table)
478 .values((
479 identities::address.eq(address_clone),
480 identities::key.eq(&key_clone[..]),
481 identities::device_id.eq(device_id),
482 ))
483 .on_conflict((identities::address, identities::device_id))
484 .do_update()
485 .set(identities::key.eq(&key_clone[..]))
486 .execute(&mut conn)
487 .map_err(|e| StoreError::Database(e.to_string()))?;
488 Ok(())
489 })
490 .await;
491
492 drop(permit);
493
494 match result {
495 Ok(Ok(())) => return Ok(()),
496 Ok(Err(e)) => {
497 let error_msg = e.to_string();
498 if (error_msg.contains("locked") || error_msg.contains("busy"))
499 && attempt < MAX_RETRIES
500 {
501 let delay_ms = 10 * 2u64.pow(attempt);
502 warn!(
503 "Identity write failed (attempt {}/{}): {}. Retrying in {}ms...",
504 attempt + 1,
505 MAX_RETRIES + 1,
506 error_msg,
507 delay_ms
508 );
509 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
510 continue;
511 }
512 return Err(e);
513 }
514 Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
515 }
516 }
517
518 Err(StoreError::Database(format!(
519 "Identity write failed after {} attempts",
520 MAX_RETRIES + 1
521 )))
522 }
523
524 pub async fn delete_identity_for_device(&self, address: &str, device_id: i32) -> Result<()> {
525 let pool = self.pool.clone();
526 let address_owned = address.to_string();
527
528 tokio::task::spawn_blocking(move || -> Result<()> {
529 let mut conn = pool
530 .get()
531 .map_err(|e| StoreError::Connection(e.to_string()))?;
532 diesel::delete(
533 identities::table
534 .filter(identities::address.eq(address_owned))
535 .filter(identities::device_id.eq(device_id)),
536 )
537 .execute(&mut conn)
538 .map_err(|e| StoreError::Database(e.to_string()))?;
539 Ok(())
540 })
541 .await
542 .map_err(|e| StoreError::Database(e.to_string()))??;
543
544 Ok(())
545 }
546
547 pub async fn load_identity_for_device(
548 &self,
549 address: &str,
550 device_id: i32,
551 ) -> Result<Option<Vec<u8>>> {
552 let pool = self.pool.clone();
553 let address = address.to_string();
554 let result = self
555 .with_semaphore(move || -> Result<Option<Vec<u8>>> {
556 let mut conn = pool
557 .get()
558 .map_err(|e| StoreError::Connection(e.to_string()))?;
559 let res: Option<Vec<u8>> = identities::table
560 .select(identities::key)
561 .filter(identities::address.eq(address))
562 .filter(identities::device_id.eq(device_id))
563 .first(&mut conn)
564 .optional()
565 .map_err(|e| StoreError::Database(e.to_string()))?;
566 Ok(res)
567 })
568 .await?;
569
570 Ok(result)
571 }
572
573 pub async fn get_session_for_device(
574 &self,
575 address: &str,
576 device_id: i32,
577 ) -> Result<Option<Vec<u8>>> {
578 let pool = self.pool.clone();
579 let address_for_query = address.to_string();
580 let result = self
581 .with_semaphore(move || -> Result<Option<Vec<u8>>> {
582 let mut conn = pool
583 .get()
584 .map_err(|e| StoreError::Connection(e.to_string()))?;
585 let res: Option<Vec<u8>> = sessions::table
586 .select(sessions::record)
587 .filter(sessions::address.eq(address_for_query.clone()))
588 .filter(sessions::device_id.eq(device_id))
589 .first(&mut conn)
590 .optional()
591 .map_err(|e| StoreError::Database(e.to_string()))?;
592
593 Ok(res)
594 })
595 .await?;
596
597 Ok(result)
598 }
599
600 pub async fn put_session_for_device(
601 &self,
602 address: &str,
603 session: &[u8],
604 device_id: i32,
605 ) -> Result<()> {
606 let pool = self.pool.clone();
607 let db_semaphore = self.db_semaphore.clone();
608 let address_owned = address.to_string();
609 let session_vec = session.to_vec();
610
611 const MAX_RETRIES: u32 = 5;
612
613 for attempt in 0..=MAX_RETRIES {
614 let permit =
615 db_semaphore.clone().acquire_owned().await.map_err(|e| {
616 StoreError::Database(format!("Failed to acquire semaphore: {}", e))
617 })?;
618
619 let pool_clone = pool.clone();
620 let address_clone = address_owned.clone();
621 let session_clone = session_vec.clone();
622
623 let result = tokio::task::spawn_blocking(move || -> Result<()> {
624 let mut conn = pool_clone
625 .get()
626 .map_err(|e| StoreError::Connection(e.to_string()))?;
627 diesel::insert_into(sessions::table)
628 .values((
629 sessions::address.eq(address_clone),
630 sessions::record.eq(&session_clone),
631 sessions::device_id.eq(device_id),
632 ))
633 .on_conflict((sessions::address, sessions::device_id))
634 .do_update()
635 .set(sessions::record.eq(&session_clone))
636 .execute(&mut conn)
637 .map_err(|e| StoreError::Database(e.to_string()))?;
638 Ok(())
639 })
640 .await;
641
642 drop(permit);
643
644 match result {
645 Ok(Ok(())) => {
646 return Ok(());
647 }
648 Ok(Err(e)) => {
649 let error_msg = e.to_string();
650 if (error_msg.contains("locked") || error_msg.contains("busy"))
651 && attempt < MAX_RETRIES
652 {
653 let delay_ms = 10 * 2u64.pow(attempt);
654 warn!(
655 "Session write failed (attempt {}/{}): {}. Retrying in {}ms...",
656 attempt + 1,
657 MAX_RETRIES + 1,
658 error_msg,
659 delay_ms
660 );
661 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
662 continue;
663 }
664 return Err(e);
665 }
666 Err(e) => return Err(StoreError::Database(format!("Task join error: {}", e))),
667 }
668 }
669
670 Err(StoreError::Database(format!(
671 "Session write failed after {} attempts",
672 MAX_RETRIES + 1
673 )))
674 }
675
676 pub async fn delete_session_for_device(&self, address: &str, device_id: i32) -> Result<()> {
677 let pool = self.pool.clone();
678 let address_owned = address.to_string();
679
680 tokio::task::spawn_blocking(move || -> Result<()> {
681 let mut conn = pool
682 .get()
683 .map_err(|e| StoreError::Connection(e.to_string()))?;
684 diesel::delete(
685 sessions::table
686 .filter(sessions::address.eq(address_owned))
687 .filter(sessions::device_id.eq(device_id)),
688 )
689 .execute(&mut conn)
690 .map_err(|e| StoreError::Database(e.to_string()))?;
691 Ok(())
692 })
693 .await
694 .map_err(|e| StoreError::Database(e.to_string()))??;
695
696 Ok(())
697 }
698
699 pub async fn put_sender_key_for_device(
700 &self,
701 address: &str,
702 record: &[u8],
703 device_id: i32,
704 ) -> Result<()> {
705 let pool = self.pool.clone();
706 let address = address.to_string();
707 let record_vec = record.to_vec();
708 tokio::task::spawn_blocking(move || -> Result<()> {
709 let mut conn = pool
710 .get()
711 .map_err(|e| StoreError::Connection(e.to_string()))?;
712 diesel::insert_into(sender_keys::table)
713 .values((
714 sender_keys::address.eq(address),
715 sender_keys::record.eq(&record_vec),
716 sender_keys::device_id.eq(device_id),
717 ))
718 .on_conflict((sender_keys::address, sender_keys::device_id))
719 .do_update()
720 .set(sender_keys::record.eq(&record_vec))
721 .execute(&mut conn)
722 .map_err(|e| StoreError::Database(e.to_string()))?;
723 Ok(())
724 })
725 .await
726 .map_err(|e| StoreError::Database(e.to_string()))??;
727 Ok(())
728 }
729
730 pub async fn get_sender_key_for_device(
731 &self,
732 address: &str,
733 device_id: i32,
734 ) -> Result<Option<Vec<u8>>> {
735 let pool = self.pool.clone();
736 let address = address.to_string();
737 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
738 let mut conn = pool
739 .get()
740 .map_err(|e| StoreError::Connection(e.to_string()))?;
741 let res: Option<Vec<u8>> = sender_keys::table
742 .select(sender_keys::record)
743 .filter(sender_keys::address.eq(address))
744 .filter(sender_keys::device_id.eq(device_id))
745 .first(&mut conn)
746 .optional()
747 .map_err(|e| StoreError::Database(e.to_string()))?;
748 Ok(res)
749 })
750 .await
751 .map_err(|e| StoreError::Database(e.to_string()))?
752 }
753
754 pub async fn delete_sender_key_for_device(&self, address: &str, device_id: i32) -> Result<()> {
755 let pool = self.pool.clone();
756 let address = address.to_string();
757 tokio::task::spawn_blocking(move || -> Result<()> {
758 let mut conn = pool
759 .get()
760 .map_err(|e| StoreError::Connection(e.to_string()))?;
761 diesel::delete(
762 sender_keys::table
763 .filter(sender_keys::address.eq(address))
764 .filter(sender_keys::device_id.eq(device_id)),
765 )
766 .execute(&mut conn)
767 .map_err(|e| StoreError::Database(e.to_string()))?;
768 Ok(())
769 })
770 .await
771 .map_err(|e| StoreError::Database(e.to_string()))??;
772 Ok(())
773 }
774
775 pub async fn get_app_state_sync_key_for_device(
776 &self,
777 key_id: &[u8],
778 device_id: i32,
779 ) -> Result<Option<AppStateSyncKey>> {
780 let pool = self.pool.clone();
781 let key_id = key_id.to_vec();
782 let res: Option<Vec<u8>> =
783 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
784 let mut conn = pool
785 .get()
786 .map_err(|e| StoreError::Connection(e.to_string()))?;
787 let res: Option<Vec<u8>> = app_state_keys::table
788 .select(app_state_keys::key_data)
789 .filter(app_state_keys::key_id.eq(&key_id))
790 .filter(app_state_keys::device_id.eq(device_id))
791 .first(&mut conn)
792 .optional()
793 .map_err(|e| StoreError::Database(e.to_string()))?;
794 Ok(res)
795 })
796 .await
797 .map_err(|e| StoreError::Database(e.to_string()))??;
798
799 if let Some(data) = res {
800 let (key, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
801 .map_err(|e| StoreError::Serialization(e.to_string()))?;
802 Ok(Some(key))
803 } else {
804 Ok(None)
805 }
806 }
807
808 pub async fn set_app_state_sync_key_for_device(
809 &self,
810 key_id: &[u8],
811 key: AppStateSyncKey,
812 device_id: i32,
813 ) -> Result<()> {
814 let pool = self.pool.clone();
815 let key_id = key_id.to_vec();
816 let data = bincode::serde::encode_to_vec(&key, bincode::config::standard())
817 .map_err(|e| StoreError::Serialization(e.to_string()))?;
818 tokio::task::spawn_blocking(move || -> Result<()> {
819 let mut conn = pool
820 .get()
821 .map_err(|e| StoreError::Connection(e.to_string()))?;
822 diesel::insert_into(app_state_keys::table)
823 .values((
824 app_state_keys::key_id.eq(&key_id),
825 app_state_keys::key_data.eq(&data),
826 app_state_keys::device_id.eq(device_id),
827 ))
828 .on_conflict((app_state_keys::key_id, app_state_keys::device_id))
829 .do_update()
830 .set(app_state_keys::key_data.eq(&data))
831 .execute(&mut conn)
832 .map_err(|e| StoreError::Database(e.to_string()))?;
833 Ok(())
834 })
835 .await
836 .map_err(|e| StoreError::Database(e.to_string()))??;
837 Ok(())
838 }
839
840 pub async fn get_app_state_version_for_device(
841 &self,
842 name: &str,
843 device_id: i32,
844 ) -> Result<HashState> {
845 let pool = self.pool.clone();
846 let name = name.to_string();
847 let res: Option<Vec<u8>> =
848 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
849 let mut conn = pool
850 .get()
851 .map_err(|e| StoreError::Connection(e.to_string()))?;
852 let res: Option<Vec<u8>> = app_state_versions::table
853 .select(app_state_versions::state_data)
854 .filter(app_state_versions::name.eq(name))
855 .filter(app_state_versions::device_id.eq(device_id))
856 .first(&mut conn)
857 .optional()
858 .map_err(|e| StoreError::Database(e.to_string()))?;
859 Ok(res)
860 })
861 .await
862 .map_err(|e| StoreError::Database(e.to_string()))??;
863
864 if let Some(data) = res {
865 let (state, _) = bincode::serde::decode_from_slice(&data, bincode::config::standard())
866 .map_err(|e| StoreError::Serialization(e.to_string()))?;
867 Ok(state)
868 } else {
869 Ok(HashState::default())
870 }
871 }
872
873 pub async fn set_app_state_version_for_device(
874 &self,
875 name: &str,
876 state: HashState,
877 device_id: i32,
878 ) -> Result<()> {
879 let pool = self.pool.clone();
880 let name = name.to_string();
881 let data = bincode::serde::encode_to_vec(&state, bincode::config::standard())
882 .map_err(|e| StoreError::Serialization(e.to_string()))?;
883 tokio::task::spawn_blocking(move || -> Result<()> {
884 let mut conn = pool
885 .get()
886 .map_err(|e| StoreError::Connection(e.to_string()))?;
887 diesel::insert_into(app_state_versions::table)
888 .values((
889 app_state_versions::name.eq(&name),
890 app_state_versions::state_data.eq(&data),
891 app_state_versions::device_id.eq(device_id),
892 ))
893 .on_conflict((app_state_versions::name, app_state_versions::device_id))
894 .do_update()
895 .set(app_state_versions::state_data.eq(&data))
896 .execute(&mut conn)
897 .map_err(|e| StoreError::Database(e.to_string()))?;
898 Ok(())
899 })
900 .await
901 .map_err(|e| StoreError::Database(e.to_string()))??;
902 Ok(())
903 }
904
905 pub async fn put_app_state_mutation_macs_for_device(
906 &self,
907 name: &str,
908 version: u64,
909 mutations: &[AppStateMutationMAC],
910 device_id: i32,
911 ) -> Result<()> {
912 if mutations.is_empty() {
913 return Ok(());
914 }
915 let pool = self.pool.clone();
916 let name = name.to_string();
917 let mutations: Vec<AppStateMutationMAC> = mutations.to_vec();
918 tokio::task::spawn_blocking(move || -> Result<()> {
919 let mut conn = pool
920 .get()
921 .map_err(|e| StoreError::Connection(e.to_string()))?;
922 for m in mutations {
923 diesel::insert_into(app_state_mutation_macs::table)
924 .values((
925 app_state_mutation_macs::name.eq(&name),
926 app_state_mutation_macs::version.eq(version as i64),
927 app_state_mutation_macs::index_mac.eq(&m.index_mac),
928 app_state_mutation_macs::value_mac.eq(&m.value_mac),
929 app_state_mutation_macs::device_id.eq(device_id),
930 ))
931 .on_conflict((
932 app_state_mutation_macs::name,
933 app_state_mutation_macs::index_mac,
934 app_state_mutation_macs::device_id,
935 ))
936 .do_update()
937 .set((
938 app_state_mutation_macs::version.eq(version as i64),
939 app_state_mutation_macs::value_mac.eq(&m.value_mac),
940 ))
941 .execute(&mut conn)
942 .map_err(|e| StoreError::Database(e.to_string()))?;
943 }
944 Ok(())
945 })
946 .await
947 .map_err(|e| StoreError::Database(e.to_string()))??;
948 Ok(())
949 }
950
951 pub async fn delete_app_state_mutation_macs_for_device(
952 &self,
953 name: &str,
954 index_macs: &[Vec<u8>],
955 device_id: i32,
956 ) -> Result<()> {
957 if index_macs.is_empty() {
958 return Ok(());
959 }
960 let pool = self.pool.clone();
961 let name = name.to_string();
962 let index_macs: Vec<Vec<u8>> = index_macs.to_vec();
963 tokio::task::spawn_blocking(move || -> Result<()> {
964 let mut conn = pool
965 .get()
966 .map_err(|e| StoreError::Connection(e.to_string()))?;
967 for idx in index_macs {
968 diesel::delete(
969 app_state_mutation_macs::table.filter(
970 app_state_mutation_macs::name
971 .eq(&name)
972 .and(app_state_mutation_macs::index_mac.eq(&idx))
973 .and(app_state_mutation_macs::device_id.eq(device_id)),
974 ),
975 )
976 .execute(&mut conn)
977 .map_err(|e| StoreError::Database(e.to_string()))?;
978 }
979 Ok(())
980 })
981 .await
982 .map_err(|e| StoreError::Database(e.to_string()))??;
983 Ok(())
984 }
985
986 pub async fn get_app_state_mutation_mac_for_device(
987 &self,
988 name: &str,
989 index_mac: &[u8],
990 device_id: i32,
991 ) -> Result<Option<Vec<u8>>> {
992 let pool = self.pool.clone();
993 let name = name.to_string();
994 let index_mac = index_mac.to_vec();
995 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
996 let mut conn = pool
997 .get()
998 .map_err(|e| StoreError::Connection(e.to_string()))?;
999 let res: Option<Vec<u8>> = app_state_mutation_macs::table
1000 .select(app_state_mutation_macs::value_mac)
1001 .filter(app_state_mutation_macs::name.eq(&name))
1002 .filter(app_state_mutation_macs::index_mac.eq(&index_mac))
1003 .filter(app_state_mutation_macs::device_id.eq(device_id))
1004 .first(&mut conn)
1005 .optional()
1006 .map_err(|e| StoreError::Database(e.to_string()))?;
1007 Ok(res)
1008 })
1009 .await
1010 .map_err(|e| StoreError::Database(e.to_string()))?
1011 }
1012}
1013
1014#[async_trait]
1015impl SignalStore for SqliteStore {
1016 async fn put_identity(&self, address: &str, key: [u8; 32]) -> Result<()> {
1017 self.put_identity_for_device(address, key, self.device_id)
1018 .await
1019 }
1020
1021 async fn load_identity(&self, address: &str) -> Result<Option<Vec<u8>>> {
1022 self.load_identity_for_device(address, self.device_id).await
1023 }
1024
1025 async fn delete_identity(&self, address: &str) -> Result<()> {
1026 self.delete_identity_for_device(address, self.device_id)
1027 .await
1028 }
1029
1030 async fn get_session(&self, address: &str) -> Result<Option<Vec<u8>>> {
1031 self.get_session_for_device(address, self.device_id).await
1032 }
1033
1034 async fn put_session(&self, address: &str, session: &[u8]) -> Result<()> {
1035 self.put_session_for_device(address, session, self.device_id)
1036 .await
1037 }
1038
1039 async fn delete_session(&self, address: &str) -> Result<()> {
1040 self.delete_session_for_device(address, self.device_id)
1041 .await
1042 }
1043
1044 async fn store_prekey(&self, id: u32, record: &[u8], uploaded: bool) -> Result<()> {
1045 let pool = self.pool.clone();
1046 let device_id = self.device_id;
1047 let record = record.to_vec();
1048 tokio::task::spawn_blocking(move || -> Result<()> {
1049 let mut conn = pool
1050 .get()
1051 .map_err(|e| StoreError::Connection(e.to_string()))?;
1052 diesel::insert_into(prekeys::table)
1053 .values((
1054 prekeys::id.eq(id as i32),
1055 prekeys::key.eq(&record),
1056 prekeys::uploaded.eq(uploaded),
1057 prekeys::device_id.eq(device_id),
1058 ))
1059 .on_conflict((prekeys::id, prekeys::device_id))
1060 .do_update()
1061 .set((prekeys::key.eq(&record), prekeys::uploaded.eq(uploaded)))
1062 .execute(&mut conn)
1063 .map_err(|e| StoreError::Database(e.to_string()))?;
1064 Ok(())
1065 })
1066 .await
1067 .map_err(|e| StoreError::Database(e.to_string()))??;
1068 Ok(())
1069 }
1070
1071 async fn load_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1072 let pool = self.pool.clone();
1073 let device_id = self.device_id;
1074 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1075 let mut conn = pool
1076 .get()
1077 .map_err(|e| StoreError::Connection(e.to_string()))?;
1078 let res: Option<Vec<u8>> = prekeys::table
1079 .select(prekeys::key)
1080 .filter(prekeys::id.eq(id as i32))
1081 .filter(prekeys::device_id.eq(device_id))
1082 .first(&mut conn)
1083 .optional()
1084 .map_err(|e| StoreError::Database(e.to_string()))?;
1085 Ok(res)
1086 })
1087 .await
1088 .map_err(|e| StoreError::Database(e.to_string()))?
1089 }
1090
1091 async fn remove_prekey(&self, id: u32) -> Result<()> {
1092 let pool = self.pool.clone();
1093 let device_id = self.device_id;
1094 tokio::task::spawn_blocking(move || -> Result<()> {
1095 let mut conn = pool
1096 .get()
1097 .map_err(|e| StoreError::Connection(e.to_string()))?;
1098 diesel::delete(
1099 prekeys::table
1100 .filter(prekeys::id.eq(id as i32))
1101 .filter(prekeys::device_id.eq(device_id)),
1102 )
1103 .execute(&mut conn)
1104 .map_err(|e| StoreError::Database(e.to_string()))?;
1105 Ok(())
1106 })
1107 .await
1108 .map_err(|e| StoreError::Database(e.to_string()))??;
1109 Ok(())
1110 }
1111
1112 async fn store_signed_prekey(&self, id: u32, record: &[u8]) -> Result<()> {
1113 let pool = self.pool.clone();
1114 let device_id = self.device_id;
1115 let record = record.to_vec();
1116 tokio::task::spawn_blocking(move || -> Result<()> {
1117 let mut conn = pool
1118 .get()
1119 .map_err(|e| StoreError::Connection(e.to_string()))?;
1120 diesel::insert_into(signed_prekeys::table)
1121 .values((
1122 signed_prekeys::id.eq(id as i32),
1123 signed_prekeys::record.eq(&record),
1124 signed_prekeys::device_id.eq(device_id),
1125 ))
1126 .on_conflict((signed_prekeys::id, signed_prekeys::device_id))
1127 .do_update()
1128 .set(signed_prekeys::record.eq(&record))
1129 .execute(&mut conn)
1130 .map_err(|e| StoreError::Database(e.to_string()))?;
1131 Ok(())
1132 })
1133 .await
1134 .map_err(|e| StoreError::Database(e.to_string()))??;
1135 Ok(())
1136 }
1137
1138 async fn load_signed_prekey(&self, id: u32) -> Result<Option<Vec<u8>>> {
1139 let pool = self.pool.clone();
1140 let device_id = self.device_id;
1141 tokio::task::spawn_blocking(move || -> Result<Option<Vec<u8>>> {
1142 let mut conn = pool
1143 .get()
1144 .map_err(|e| StoreError::Connection(e.to_string()))?;
1145 let res: Option<Vec<u8>> = signed_prekeys::table
1146 .select(signed_prekeys::record)
1147 .filter(signed_prekeys::id.eq(id as i32))
1148 .filter(signed_prekeys::device_id.eq(device_id))
1149 .first(&mut conn)
1150 .optional()
1151 .map_err(|e| StoreError::Database(e.to_string()))?;
1152 Ok(res)
1153 })
1154 .await
1155 .map_err(|e| StoreError::Database(e.to_string()))?
1156 }
1157
1158 async fn load_all_signed_prekeys(&self) -> Result<Vec<(u32, Vec<u8>)>> {
1159 let pool = self.pool.clone();
1160 let device_id = self.device_id;
1161 tokio::task::spawn_blocking(move || -> Result<Vec<(u32, Vec<u8>)>> {
1162 let mut conn = pool
1163 .get()
1164 .map_err(|e| StoreError::Connection(e.to_string()))?;
1165 let results: Vec<(i32, Vec<u8>)> = signed_prekeys::table
1166 .select((signed_prekeys::id, signed_prekeys::record))
1167 .filter(signed_prekeys::device_id.eq(device_id))
1168 .load(&mut conn)
1169 .map_err(|e| StoreError::Database(e.to_string()))?;
1170 Ok(results
1171 .into_iter()
1172 .map(|(id, record)| (id as u32, record))
1173 .collect())
1174 })
1175 .await
1176 .map_err(|e| StoreError::Database(e.to_string()))?
1177 }
1178
1179 async fn remove_signed_prekey(&self, id: u32) -> Result<()> {
1180 let pool = self.pool.clone();
1181 let device_id = self.device_id;
1182 tokio::task::spawn_blocking(move || -> Result<()> {
1183 let mut conn = pool
1184 .get()
1185 .map_err(|e| StoreError::Connection(e.to_string()))?;
1186 diesel::delete(
1187 signed_prekeys::table
1188 .filter(signed_prekeys::id.eq(id as i32))
1189 .filter(signed_prekeys::device_id.eq(device_id)),
1190 )
1191 .execute(&mut conn)
1192 .map_err(|e| StoreError::Database(e.to_string()))?;
1193 Ok(())
1194 })
1195 .await
1196 .map_err(|e| StoreError::Database(e.to_string()))??;
1197 Ok(())
1198 }
1199
1200 async fn put_sender_key(&self, address: &str, record: &[u8]) -> Result<()> {
1201 self.put_sender_key_for_device(address, record, self.device_id)
1202 .await
1203 }
1204
1205 async fn get_sender_key(&self, address: &str) -> Result<Option<Vec<u8>>> {
1206 self.get_sender_key_for_device(address, self.device_id)
1207 .await
1208 }
1209
1210 async fn delete_sender_key(&self, address: &str) -> Result<()> {
1211 self.delete_sender_key_for_device(address, self.device_id)
1212 .await
1213 }
1214}
1215
1216#[async_trait]
1217impl AppSyncStore for SqliteStore {
1218 async fn get_sync_key(&self, key_id: &[u8]) -> Result<Option<AppStateSyncKey>> {
1219 self.get_app_state_sync_key_for_device(key_id, self.device_id)
1220 .await
1221 }
1222
1223 async fn set_sync_key(&self, key_id: &[u8], key: AppStateSyncKey) -> Result<()> {
1224 self.set_app_state_sync_key_for_device(key_id, key, self.device_id)
1225 .await
1226 }
1227
1228 async fn get_version(&self, name: &str) -> Result<HashState> {
1229 self.get_app_state_version_for_device(name, self.device_id)
1230 .await
1231 }
1232
1233 async fn set_version(&self, name: &str, state: HashState) -> Result<()> {
1234 self.set_app_state_version_for_device(name, state, self.device_id)
1235 .await
1236 }
1237
1238 async fn put_mutation_macs(
1239 &self,
1240 name: &str,
1241 version: u64,
1242 mutations: &[AppStateMutationMAC],
1243 ) -> Result<()> {
1244 self.put_app_state_mutation_macs_for_device(name, version, mutations, self.device_id)
1245 .await
1246 }
1247
1248 async fn get_mutation_mac(&self, name: &str, index_mac: &[u8]) -> Result<Option<Vec<u8>>> {
1249 self.get_app_state_mutation_mac_for_device(name, index_mac, self.device_id)
1250 .await
1251 }
1252
1253 async fn delete_mutation_macs(&self, name: &str, index_macs: &[Vec<u8>]) -> Result<()> {
1254 self.delete_app_state_mutation_macs_for_device(name, index_macs, self.device_id)
1255 .await
1256 }
1257}
1258
1259#[async_trait]
1260impl ProtocolStore for SqliteStore {
1261 async fn get_skdm_recipients(&self, group_jid: &str) -> Result<Vec<String>> {
1262 let pool = self.pool.clone();
1263 let device_id = self.device_id;
1264 let group_jid = group_jid.to_string();
1265 tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
1266 let mut conn = pool
1267 .get()
1268 .map_err(|e| StoreError::Connection(e.to_string()))?;
1269 let recipients: Vec<String> = skdm_recipients::table
1270 .select(skdm_recipients::device_jid)
1271 .filter(skdm_recipients::group_jid.eq(&group_jid))
1272 .filter(skdm_recipients::device_id.eq(device_id))
1273 .load(&mut conn)
1274 .map_err(|e| StoreError::Database(e.to_string()))?;
1275 Ok(recipients)
1276 })
1277 .await
1278 .map_err(|e| StoreError::Database(e.to_string()))?
1279 }
1280
1281 async fn add_skdm_recipients(&self, group_jid: &str, device_jids: &[String]) -> Result<()> {
1282 if device_jids.is_empty() {
1283 return Ok(());
1284 }
1285 let pool = self.pool.clone();
1286 let device_id = self.device_id;
1287 let group_jid = group_jid.to_string();
1288 let device_jids: Vec<String> = device_jids.to_vec();
1289 let now = std::time::SystemTime::now()
1290 .duration_since(std::time::UNIX_EPOCH)
1291 .unwrap_or_default()
1292 .as_secs() as i32;
1293 tokio::task::spawn_blocking(move || -> Result<()> {
1294 let mut conn = pool
1295 .get()
1296 .map_err(|e| StoreError::Connection(e.to_string()))?;
1297 for device_jid in device_jids {
1298 diesel::insert_into(skdm_recipients::table)
1299 .values((
1300 skdm_recipients::group_jid.eq(&group_jid),
1301 skdm_recipients::device_jid.eq(&device_jid),
1302 skdm_recipients::device_id.eq(device_id),
1303 skdm_recipients::created_at.eq(now),
1304 ))
1305 .on_conflict((
1306 skdm_recipients::group_jid,
1307 skdm_recipients::device_jid,
1308 skdm_recipients::device_id,
1309 ))
1310 .do_nothing()
1311 .execute(&mut conn)
1312 .map_err(|e| StoreError::Database(e.to_string()))?;
1313 }
1314 Ok(())
1315 })
1316 .await
1317 .map_err(|e| StoreError::Database(e.to_string()))??;
1318 Ok(())
1319 }
1320
1321 async fn clear_skdm_recipients(&self, group_jid: &str) -> Result<()> {
1322 let pool = self.pool.clone();
1323 let device_id = self.device_id;
1324 let group_jid = group_jid.to_string();
1325 tokio::task::spawn_blocking(move || -> Result<()> {
1326 let mut conn = pool
1327 .get()
1328 .map_err(|e| StoreError::Connection(e.to_string()))?;
1329 diesel::delete(
1330 skdm_recipients::table
1331 .filter(skdm_recipients::group_jid.eq(&group_jid))
1332 .filter(skdm_recipients::device_id.eq(device_id)),
1333 )
1334 .execute(&mut conn)
1335 .map_err(|e| StoreError::Database(e.to_string()))?;
1336 Ok(())
1337 })
1338 .await
1339 .map_err(|e| StoreError::Database(e.to_string()))??;
1340 Ok(())
1341 }
1342
1343 async fn get_lid_mapping(&self, lid: &str) -> Result<Option<LidPnMappingEntry>> {
1344 let pool = self.pool.clone();
1345 let device_id = self.device_id;
1346 let lid = lid.to_string();
1347 tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1348 let mut conn = pool
1349 .get()
1350 .map_err(|e| StoreError::Connection(e.to_string()))?;
1351 let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1352 .select((
1353 lid_pn_mapping::lid,
1354 lid_pn_mapping::phone_number,
1355 lid_pn_mapping::created_at,
1356 lid_pn_mapping::learning_source,
1357 lid_pn_mapping::updated_at,
1358 ))
1359 .filter(lid_pn_mapping::lid.eq(&lid))
1360 .filter(lid_pn_mapping::device_id.eq(device_id))
1361 .first(&mut conn)
1362 .optional()
1363 .map_err(|e| StoreError::Database(e.to_string()))?;
1364 Ok(row.map(
1365 |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1366 lid,
1367 phone_number,
1368 created_at,
1369 updated_at,
1370 learning_source,
1371 },
1372 ))
1373 })
1374 .await
1375 .map_err(|e| StoreError::Database(e.to_string()))?
1376 }
1377
1378 async fn get_pn_mapping(&self, phone: &str) -> Result<Option<LidPnMappingEntry>> {
1379 let pool = self.pool.clone();
1380 let device_id = self.device_id;
1381 let phone = phone.to_string();
1382 tokio::task::spawn_blocking(move || -> Result<Option<LidPnMappingEntry>> {
1383 let mut conn = pool
1384 .get()
1385 .map_err(|e| StoreError::Connection(e.to_string()))?;
1386 let row: Option<(String, String, i64, String, i64)> = lid_pn_mapping::table
1387 .select((
1388 lid_pn_mapping::lid,
1389 lid_pn_mapping::phone_number,
1390 lid_pn_mapping::created_at,
1391 lid_pn_mapping::learning_source,
1392 lid_pn_mapping::updated_at,
1393 ))
1394 .filter(lid_pn_mapping::phone_number.eq(&phone))
1395 .filter(lid_pn_mapping::device_id.eq(device_id))
1396 .order(lid_pn_mapping::updated_at.desc())
1397 .first(&mut conn)
1398 .optional()
1399 .map_err(|e| StoreError::Database(e.to_string()))?;
1400 Ok(row.map(
1401 |(lid, phone_number, created_at, learning_source, updated_at)| LidPnMappingEntry {
1402 lid,
1403 phone_number,
1404 created_at,
1405 updated_at,
1406 learning_source,
1407 },
1408 ))
1409 })
1410 .await
1411 .map_err(|e| StoreError::Database(e.to_string()))?
1412 }
1413
1414 async fn put_lid_mapping(&self, entry: &LidPnMappingEntry) -> Result<()> {
1415 let pool = self.pool.clone();
1416 let device_id = self.device_id;
1417 let entry = entry.clone();
1418 tokio::task::spawn_blocking(move || -> Result<()> {
1419 let mut conn = pool
1420 .get()
1421 .map_err(|e| StoreError::Connection(e.to_string()))?;
1422 diesel::insert_into(lid_pn_mapping::table)
1423 .values((
1424 lid_pn_mapping::lid.eq(&entry.lid),
1425 lid_pn_mapping::phone_number.eq(&entry.phone_number),
1426 lid_pn_mapping::created_at.eq(entry.created_at),
1427 lid_pn_mapping::learning_source.eq(&entry.learning_source),
1428 lid_pn_mapping::updated_at.eq(entry.updated_at),
1429 lid_pn_mapping::device_id.eq(device_id),
1430 ))
1431 .on_conflict((lid_pn_mapping::lid, lid_pn_mapping::device_id))
1432 .do_update()
1433 .set((
1434 lid_pn_mapping::phone_number.eq(&entry.phone_number),
1435 lid_pn_mapping::learning_source.eq(&entry.learning_source),
1436 lid_pn_mapping::updated_at.eq(entry.updated_at),
1437 ))
1438 .execute(&mut conn)
1439 .map_err(|e| StoreError::Database(e.to_string()))?;
1440 Ok(())
1441 })
1442 .await
1443 .map_err(|e| StoreError::Database(e.to_string()))??;
1444 Ok(())
1445 }
1446
1447 async fn get_all_lid_mappings(&self) -> Result<Vec<LidPnMappingEntry>> {
1448 let pool = self.pool.clone();
1449 let device_id = self.device_id;
1450 tokio::task::spawn_blocking(move || -> Result<Vec<LidPnMappingEntry>> {
1451 let mut conn = pool
1452 .get()
1453 .map_err(|e| StoreError::Connection(e.to_string()))?;
1454 let rows: Vec<(String, String, i64, String, i64)> = lid_pn_mapping::table
1455 .select((
1456 lid_pn_mapping::lid,
1457 lid_pn_mapping::phone_number,
1458 lid_pn_mapping::created_at,
1459 lid_pn_mapping::learning_source,
1460 lid_pn_mapping::updated_at,
1461 ))
1462 .filter(lid_pn_mapping::device_id.eq(device_id))
1463 .load(&mut conn)
1464 .map_err(|e| StoreError::Database(e.to_string()))?;
1465 Ok(rows
1466 .into_iter()
1467 .map(
1468 |(lid, phone_number, created_at, learning_source, updated_at)| {
1469 LidPnMappingEntry {
1470 lid,
1471 phone_number,
1472 created_at,
1473 updated_at,
1474 learning_source,
1475 }
1476 },
1477 )
1478 .collect())
1479 })
1480 .await
1481 .map_err(|e| StoreError::Database(e.to_string()))?
1482 }
1483
1484 async fn save_base_key(&self, address: &str, message_id: &str, base_key: &[u8]) -> Result<()> {
1485 let pool = self.pool.clone();
1486 let device_id = self.device_id;
1487 let address = address.to_string();
1488 let message_id = message_id.to_string();
1489 let base_key = base_key.to_vec();
1490 let now = std::time::SystemTime::now()
1491 .duration_since(std::time::UNIX_EPOCH)
1492 .unwrap_or_default()
1493 .as_secs() as i32;
1494 tokio::task::spawn_blocking(move || -> Result<()> {
1495 let mut conn = pool
1496 .get()
1497 .map_err(|e| StoreError::Connection(e.to_string()))?;
1498 diesel::insert_into(base_keys::table)
1499 .values((
1500 base_keys::address.eq(&address),
1501 base_keys::message_id.eq(&message_id),
1502 base_keys::base_key.eq(&base_key),
1503 base_keys::device_id.eq(device_id),
1504 base_keys::created_at.eq(now),
1505 ))
1506 .on_conflict((
1507 base_keys::address,
1508 base_keys::message_id,
1509 base_keys::device_id,
1510 ))
1511 .do_update()
1512 .set(base_keys::base_key.eq(&base_key))
1513 .execute(&mut conn)
1514 .map_err(|e| StoreError::Database(e.to_string()))?;
1515 Ok(())
1516 })
1517 .await
1518 .map_err(|e| StoreError::Database(e.to_string()))??;
1519 Ok(())
1520 }
1521
1522 async fn has_same_base_key(
1523 &self,
1524 address: &str,
1525 message_id: &str,
1526 current_base_key: &[u8],
1527 ) -> Result<bool> {
1528 let pool = self.pool.clone();
1529 let device_id = self.device_id;
1530 let address = address.to_string();
1531 let message_id = message_id.to_string();
1532 let current_base_key = current_base_key.to_vec();
1533 tokio::task::spawn_blocking(move || -> Result<bool> {
1534 let mut conn = pool
1535 .get()
1536 .map_err(|e| StoreError::Connection(e.to_string()))?;
1537 let stored_key: Option<Vec<u8>> = base_keys::table
1538 .select(base_keys::base_key)
1539 .filter(base_keys::address.eq(&address))
1540 .filter(base_keys::message_id.eq(&message_id))
1541 .filter(base_keys::device_id.eq(device_id))
1542 .first(&mut conn)
1543 .optional()
1544 .map_err(|e| StoreError::Database(e.to_string()))?;
1545 Ok(stored_key.as_ref() == Some(¤t_base_key))
1546 })
1547 .await
1548 .map_err(|e| StoreError::Database(e.to_string()))?
1549 }
1550
1551 async fn delete_base_key(&self, address: &str, message_id: &str) -> Result<()> {
1552 let pool = self.pool.clone();
1553 let device_id = self.device_id;
1554 let address = address.to_string();
1555 let message_id = message_id.to_string();
1556 tokio::task::spawn_blocking(move || -> Result<()> {
1557 let mut conn = pool
1558 .get()
1559 .map_err(|e| StoreError::Connection(e.to_string()))?;
1560 diesel::delete(
1561 base_keys::table
1562 .filter(base_keys::address.eq(&address))
1563 .filter(base_keys::message_id.eq(&message_id))
1564 .filter(base_keys::device_id.eq(device_id)),
1565 )
1566 .execute(&mut conn)
1567 .map_err(|e| StoreError::Database(e.to_string()))?;
1568 Ok(())
1569 })
1570 .await
1571 .map_err(|e| StoreError::Database(e.to_string()))??;
1572 Ok(())
1573 }
1574
1575 async fn update_device_list(&self, record: DeviceListRecord) -> Result<()> {
1576 let pool = self.pool.clone();
1577 let device_id = self.device_id;
1578 let devices_json = serde_json::to_string(&record.devices)
1579 .map_err(|e| StoreError::Serialization(e.to_string()))?;
1580 let now = std::time::SystemTime::now()
1581 .duration_since(std::time::UNIX_EPOCH)
1582 .unwrap_or_default()
1583 .as_secs() as i32;
1584 tokio::task::spawn_blocking(move || -> Result<()> {
1585 let mut conn = pool
1586 .get()
1587 .map_err(|e| StoreError::Connection(e.to_string()))?;
1588 diesel::insert_into(device_registry::table)
1589 .values((
1590 device_registry::user_id.eq(&record.user),
1591 device_registry::devices_json.eq(&devices_json),
1592 device_registry::timestamp.eq(record.timestamp as i32),
1593 device_registry::phash.eq(&record.phash),
1594 device_registry::device_id.eq(device_id),
1595 device_registry::updated_at.eq(now),
1596 ))
1597 .on_conflict((device_registry::user_id, device_registry::device_id))
1598 .do_update()
1599 .set((
1600 device_registry::devices_json.eq(&devices_json),
1601 device_registry::timestamp.eq(record.timestamp as i32),
1602 device_registry::phash.eq(&record.phash),
1603 device_registry::updated_at.eq(now),
1604 ))
1605 .execute(&mut conn)
1606 .map_err(|e| StoreError::Database(e.to_string()))?;
1607 Ok(())
1608 })
1609 .await
1610 .map_err(|e| StoreError::Database(e.to_string()))??;
1611 Ok(())
1612 }
1613
1614 async fn get_devices(&self, user: &str) -> Result<Option<DeviceListRecord>> {
1615 let pool = self.pool.clone();
1616 let device_id = self.device_id;
1617 let user = user.to_string();
1618 tokio::task::spawn_blocking(move || -> Result<Option<DeviceListRecord>> {
1619 let mut conn = pool
1620 .get()
1621 .map_err(|e| StoreError::Connection(e.to_string()))?;
1622 let row: Option<(String, String, i32, Option<String>)> = device_registry::table
1623 .select((
1624 device_registry::user_id,
1625 device_registry::devices_json,
1626 device_registry::timestamp,
1627 device_registry::phash,
1628 ))
1629 .filter(device_registry::user_id.eq(&user))
1630 .filter(device_registry::device_id.eq(device_id))
1631 .first(&mut conn)
1632 .optional()
1633 .map_err(|e| StoreError::Database(e.to_string()))?;
1634 match row {
1635 Some((user, devices_json, timestamp, phash)) => {
1636 let devices: Vec<DeviceInfo> = serde_json::from_str(&devices_json)
1637 .map_err(|e| StoreError::Serialization(e.to_string()))?;
1638 Ok(Some(DeviceListRecord {
1639 user,
1640 devices,
1641 timestamp: timestamp as i64,
1642 phash,
1643 }))
1644 }
1645 None => Ok(None),
1646 }
1647 })
1648 .await
1649 .map_err(|e| StoreError::Database(e.to_string()))?
1650 }
1651
1652 async fn mark_forget_sender_key(&self, group_jid: &str, participant: &str) -> Result<()> {
1653 let pool = self.pool.clone();
1654 let device_id = self.device_id;
1655 let group_jid = group_jid.to_string();
1656 let participant = participant.to_string();
1657 let now = std::time::SystemTime::now()
1658 .duration_since(std::time::UNIX_EPOCH)
1659 .unwrap_or_default()
1660 .as_secs() as i32;
1661 tokio::task::spawn_blocking(move || -> Result<()> {
1662 let mut conn = pool
1663 .get()
1664 .map_err(|e| StoreError::Connection(e.to_string()))?;
1665 diesel::insert_into(sender_key_status::table)
1666 .values((
1667 sender_key_status::group_jid.eq(&group_jid),
1668 sender_key_status::participant.eq(&participant),
1669 sender_key_status::device_id.eq(device_id),
1670 sender_key_status::marked_at.eq(now),
1671 ))
1672 .on_conflict((
1673 sender_key_status::group_jid,
1674 sender_key_status::participant,
1675 sender_key_status::device_id,
1676 ))
1677 .do_update()
1678 .set(sender_key_status::marked_at.eq(now))
1679 .execute(&mut conn)
1680 .map_err(|e| StoreError::Database(e.to_string()))?;
1681 Ok(())
1682 })
1683 .await
1684 .map_err(|e| StoreError::Database(e.to_string()))??;
1685 Ok(())
1686 }
1687
1688 async fn consume_forget_marks(&self, group_jid: &str) -> Result<Vec<String>> {
1689 let pool = self.pool.clone();
1690 let device_id = self.device_id;
1691 let group_jid = group_jid.to_string();
1692 tokio::task::spawn_blocking(move || -> Result<Vec<String>> {
1693 let mut conn = pool
1694 .get()
1695 .map_err(|e| StoreError::Connection(e.to_string()))?;
1696 let participants: Vec<String> = sender_key_status::table
1697 .select(sender_key_status::participant)
1698 .filter(sender_key_status::group_jid.eq(&group_jid))
1699 .filter(sender_key_status::device_id.eq(device_id))
1700 .load(&mut conn)
1701 .map_err(|e| StoreError::Database(e.to_string()))?;
1702 diesel::delete(
1703 sender_key_status::table
1704 .filter(sender_key_status::group_jid.eq(&group_jid))
1705 .filter(sender_key_status::device_id.eq(device_id)),
1706 )
1707 .execute(&mut conn)
1708 .map_err(|e| StoreError::Database(e.to_string()))?;
1709 Ok(participants)
1710 })
1711 .await
1712 .map_err(|e| StoreError::Database(e.to_string()))?
1713 }
1714}
1715
1716#[async_trait]
1717impl DeviceStore for SqliteStore {
1718 async fn save(&self, device: &CoreDevice) -> Result<()> {
1719 SqliteStore::save_device_data_for_device(self, self.device_id, device).await
1720 }
1721
1722 async fn load(&self) -> Result<Option<CoreDevice>> {
1723 SqliteStore::load_device_data_for_device(self, self.device_id).await
1724 }
1725
1726 async fn exists(&self) -> Result<bool> {
1727 SqliteStore::device_exists(self, self.device_id).await
1728 }
1729
1730 async fn create(&self) -> Result<i32> {
1731 SqliteStore::create_new_device(self).await
1732 }
1733}
1734
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh store backed by an in-memory SQLite database.
    async fn create_test_store() -> SqliteStore {
        SqliteStore::new(":memory:")
            .await
            .expect("Failed to create test store")
    }

    #[tokio::test]
    async fn test_device_registry_save_and_get() {
        let store = create_test_store().await;

        let record = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![
                DeviceInfo {
                    device_id: 0,
                    key_index: None,
                },
                DeviceInfo {
                    device_id: 1,
                    key_index: Some(42),
                },
            ],
            timestamp: 1234567890,
            phash: Some("2:abcdef".to_string()),
        };

        store.update_device_list(record).await.expect("save failed");
        let loaded = store
            .get_devices("1234567890")
            .await
            .expect("get failed")
            .expect("record should exist");

        assert_eq!(loaded.user, "1234567890");
        assert_eq!(loaded.devices.len(), 2);
        assert_eq!(loaded.devices[0].device_id, 0);
        assert_eq!(loaded.devices[1].device_id, 1);
        assert_eq!(loaded.devices[1].key_index, Some(42));
        assert_eq!(loaded.timestamp, 1234567890);
        assert_eq!(loaded.phash, Some("2:abcdef".to_string()));
    }

    #[tokio::test]
    async fn test_device_registry_update_existing() {
        let store = create_test_store().await;

        let record1 = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![DeviceInfo {
                device_id: 0,
                key_index: None,
            }],
            timestamp: 1000,
            phash: Some("2:old".to_string()),
        };
        store
            .update_device_list(record1)
            .await
            .expect("save1 failed");

        let record2 = DeviceListRecord {
            user: "1234567890".to_string(),
            devices: vec![
                DeviceInfo {
                    device_id: 0,
                    key_index: None,
                },
                DeviceInfo {
                    device_id: 2,
                    key_index: None,
                },
            ],
            timestamp: 2000,
            phash: Some("2:new".to_string()),
        };
        store
            .update_device_list(record2)
            .await
            .expect("save2 failed");

        let loaded = store
            .get_devices("1234567890")
            .await
            .expect("get failed")
            .expect("record should exist");

        // Every mutable field must reflect the second write, not the first.
        let ids: Vec<_> = loaded.devices.iter().map(|d| d.device_id).collect();
        assert_eq!(ids, vec![0, 2]);
        assert_eq!(loaded.timestamp, 2000);
        assert_eq!(loaded.phash, Some("2:new".to_string()));
    }

    #[tokio::test]
    async fn test_device_registry_get_nonexistent() {
        let store = create_test_store().await;
        let result = store.get_devices("nonexistent").await.expect("get failed");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_sender_key_status_mark_and_consume() {
        let store = create_test_store().await;

        let group = "group123@g.us";
        let participant = "user1@s.whatsapp.net";

        store
            .mark_forget_sender_key(group, participant)
            .await
            .expect("mark failed");

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed.len(), 1);
        assert!(consumed.contains(&participant.to_string()));

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(consumed.is_empty());
    }

    #[tokio::test]
    async fn test_sender_key_status_mark_twice_is_single_entry() {
        let store = create_test_store().await;
        let group = "group123@g.us";
        let participant = "user1@s.whatsapp.net";

        // Re-marking the same participant upserts instead of duplicating.
        for _ in 0..2 {
            store
                .mark_forget_sender_key(group, participant)
                .await
                .expect("mark failed");
        }

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed, vec![participant.to_string()]);
    }

    #[tokio::test]
    async fn test_sender_key_status_consume_multiple() {
        let store = create_test_store().await;

        let group = "group123@g.us";

        store
            .mark_forget_sender_key(group, "user1@s.whatsapp.net")
            .await
            .expect("mark failed");
        store
            .mark_forget_sender_key(group, "user2@s.whatsapp.net")
            .await
            .expect("mark failed");

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert_eq!(consumed.len(), 2);
        assert!(consumed.contains(&"user1@s.whatsapp.net".to_string()));
        assert!(consumed.contains(&"user2@s.whatsapp.net".to_string()));

        let consumed = store
            .consume_forget_marks(group)
            .await
            .expect("consume failed");
        assert!(consumed.is_empty());
    }

    #[tokio::test]
    async fn test_sender_key_status_different_groups() {
        let store = create_test_store().await;

        let group1 = "group1@g.us";
        let group2 = "group2@g.us";
        let participant = "user@s.whatsapp.net";

        store
            .mark_forget_sender_key(group1, participant)
            .await
            .expect("mark failed");

        let consumed = store.consume_forget_marks(group1).await.unwrap();
        assert_eq!(consumed.len(), 1);

        let consumed = store.consume_forget_marks(group2).await.unwrap();
        assert!(consumed.is_empty());
    }

    #[tokio::test]
    async fn test_skdm_recipients_add_dedup_and_clear() {
        let store = create_test_store().await;
        let group = "group123@g.us";
        let a = "111@s.whatsapp.net".to_string();
        let b = "222@s.whatsapp.net".to_string();

        // Duplicates within a batch and across batches are ignored.
        store
            .add_skdm_recipients(group, &[a.clone(), b.clone(), a.clone()])
            .await
            .expect("add failed");
        store
            .add_skdm_recipients(group, &[b.clone()])
            .await
            .expect("re-add failed");
        store
            .add_skdm_recipients(group, &[])
            .await
            .expect("empty add failed");

        let mut recipients = store
            .get_skdm_recipients(group)
            .await
            .expect("get failed");
        recipients.sort();
        assert_eq!(recipients, vec![a, b]);

        store
            .clear_skdm_recipients(group)
            .await
            .expect("clear failed");
        assert!(store
            .get_skdm_recipients(group)
            .await
            .expect("get failed")
            .is_empty());
    }

    #[tokio::test]
    async fn test_base_key_save_compare_delete() {
        let store = create_test_store().await;
        let addr = "123.0:1";
        let msg = "MSG1";

        assert!(!store.has_same_base_key(addr, msg, &[1, 2, 3]).await.unwrap());

        store
            .save_base_key(addr, msg, &[1, 2, 3])
            .await
            .expect("save failed");
        assert!(store.has_same_base_key(addr, msg, &[1, 2, 3]).await.unwrap());
        assert!(!store.has_same_base_key(addr, msg, &[9]).await.unwrap());

        // Saving again for the same (address, message) overwrites the key.
        store
            .save_base_key(addr, msg, &[9])
            .await
            .expect("overwrite failed");
        assert!(store.has_same_base_key(addr, msg, &[9]).await.unwrap());

        store
            .delete_base_key(addr, msg)
            .await
            .expect("delete failed");
        assert!(!store.has_same_base_key(addr, msg, &[9]).await.unwrap());
    }

    #[tokio::test]
    async fn test_lid_mapping_put_update_and_lookup() {
        let store = create_test_store().await;
        let entry = LidPnMappingEntry {
            lid: "100000000000001".to_string(),
            phone_number: "15551234567".to_string(),
            created_at: 1000,
            updated_at: 1000,
            learning_source: "test".to_string(),
        };
        store.put_lid_mapping(&entry).await.expect("put failed");

        let by_lid = store
            .get_lid_mapping(&entry.lid)
            .await
            .expect("get failed")
            .expect("mapping should exist");
        assert_eq!(by_lid.phone_number, entry.phone_number);

        let updated = LidPnMappingEntry {
            phone_number: "15550000000".to_string(),
            updated_at: 2000,
            ..entry.clone()
        };
        store
            .put_lid_mapping(&updated)
            .await
            .expect("update failed");

        let by_lid = store
            .get_lid_mapping(&entry.lid)
            .await
            .expect("get failed")
            .expect("mapping should exist");
        assert_eq!(by_lid.phone_number, "15550000000");
        assert_eq!(by_lid.updated_at, 2000);
        // The upsert keeps the original creation time.
        assert_eq!(by_lid.created_at, 1000);

        let by_pn = store
            .get_pn_mapping("15550000000")
            .await
            .expect("get failed")
            .expect("mapping should exist");
        assert_eq!(by_pn.lid, entry.lid);
        assert!(store
            .get_pn_mapping("15551234567")
            .await
            .expect("get failed")
            .is_none());

        assert_eq!(
            store
                .get_all_lid_mappings()
                .await
                .expect("get all failed")
                .len(),
            1
        );
    }
}