whatsapp_rust/store/
persistence_manager.rs1use super::error::{StoreError, db_err};
2use crate::store::Device;
3use crate::store::traits::Backend;
4use log::{debug, error};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use tokio::sync::{Notify, RwLock};
8use tokio::time::{Duration, sleep};
9
10pub struct PersistenceManager {
11 device: Arc<RwLock<Device>>,
12 backend: Arc<dyn Backend>,
13 dirty: Arc<AtomicBool>,
14 save_notify: Arc<Notify>,
15}
16
17impl PersistenceManager {
18 pub async fn new(backend: Arc<dyn Backend>) -> Result<Self, StoreError> {
23 debug!("PersistenceManager: Ensuring device row exists.");
24 let exists = backend.exists().await.map_err(db_err)?;
26 if !exists {
27 debug!("PersistenceManager: No device row found. Creating new device row.");
28 let id = backend.create().await.map_err(db_err)?;
29 debug!("PersistenceManager: Created device row with id={id}.");
30 }
31
32 debug!("PersistenceManager: Attempting to load device data via Backend.");
33 let device_data_opt = backend.load().await.map_err(db_err)?;
34
35 let device = if let Some(serializable_device) = device_data_opt {
36 debug!(
37 "PersistenceManager: Loaded existing device data (PushName: '{}'). Initializing Device.",
38 serializable_device.push_name
39 );
40 let mut dev = Device::new(backend.clone());
41 dev.load_from_serializable(serializable_device);
42 dev
43 } else {
44 debug!("PersistenceManager: No data yet; initializing default Device in memory.");
45 Device::new(backend.clone())
46 };
47
48 Ok(Self {
49 device: Arc::new(RwLock::new(device)),
50 backend,
51 dirty: Arc::new(AtomicBool::new(false)),
52 save_notify: Arc::new(Notify::new()),
53 })
54 }
55
56 pub async fn get_device_arc(&self) -> Arc<RwLock<Device>> {
57 self.device.clone()
58 }
59
60 pub async fn get_device_snapshot(&self) -> Device {
61 self.device.read().await.clone()
62 }
63
64 pub fn backend(&self) -> Arc<dyn Backend> {
65 self.backend.clone()
66 }
67
68 pub async fn modify_device<F, R>(&self, modifier: F) -> R
69 where
70 F: FnOnce(&mut Device) -> R,
71 {
72 let mut device_guard = self.device.write().await;
73 let result = modifier(&mut device_guard);
74
75 self.dirty.store(true, Ordering::Relaxed);
76 self.save_notify.notify_one();
77
78 result
79 }
80
81 async fn save_to_disk(&self) -> Result<(), StoreError> {
82 if self.dirty.swap(false, Ordering::AcqRel) {
83 debug!("Device state is dirty, saving to disk.");
84 let device_guard = self.device.read().await;
85 let serializable_device = device_guard.to_serializable();
86 drop(device_guard);
87
88 self.backend
89 .save(&serializable_device)
90 .await
91 .map_err(db_err)?;
92 debug!("Device state saved successfully.");
93 }
94 Ok(())
95 }
96
97 pub fn run_background_saver(self: Arc<Self>, interval: Duration) {
98 tokio::spawn(async move {
99 loop {
100 tokio::select! {
101 _ = self.save_notify.notified() => {
102 debug!("Save notification received.");
103 }
104 _ = sleep(interval) => {}
105 }
106
107 if let Err(e) = self.save_to_disk().await {
108 error!("Error saving device state in background: {e}");
109 }
110 }
111 });
112 debug!("Background saver task started with interval {interval:?}");
113 }
114}
115
116use super::commands::{DeviceCommand, apply_command_to_device};
117
118impl PersistenceManager {
119 pub async fn process_command(&self, command: DeviceCommand) {
120 self.modify_device(|device| {
121 apply_command_to_device(device, command);
122 })
123 .await;
124 }
125}
126
127impl PersistenceManager {
129 pub async fn get_skdm_recipients(&self, group_jid: &str) -> Result<Vec<String>, StoreError> {
131 self.backend
132 .get_skdm_recipients(group_jid)
133 .await
134 .map_err(db_err)
135 }
136
137 pub async fn add_skdm_recipients(
139 &self,
140 group_jid: &str,
141 device_jids: &[String],
142 ) -> Result<(), StoreError> {
143 self.backend
144 .add_skdm_recipients(group_jid, device_jids)
145 .await
146 .map_err(db_err)
147 }
148
149 pub async fn clear_skdm_recipients(&self, group_jid: &str) -> Result<(), StoreError> {
151 self.backend
152 .clear_skdm_recipients(group_jid)
153 .await
154 .map_err(db_err)
155 }
156}