// whatsapp_rust/store/persistence_manager.rs

1use super::error::{StoreError, db_err};
2use crate::store::Device;
3use crate::store::traits::Backend;
4use log::{debug, error};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use tokio::sync::{Notify, RwLock};
8use tokio::time::{Duration, sleep};
9
/// Owns the in-memory [`Device`] state and persists it through a [`Backend`].
///
/// Mutations go through [`PersistenceManager::modify_device`], which marks the
/// state dirty and wakes the background saver started by
/// [`PersistenceManager::run_background_saver`].
pub struct PersistenceManager {
    /// Shared, lock-protected in-memory device state.
    device: Arc<RwLock<Device>>,
    /// Storage backend used to load and save the device and SKDM data.
    backend: Arc<dyn Backend>,
    /// Set when `device` has been modified since the last successful save.
    dirty: Arc<AtomicBool>,
    /// Wakes the background saver when a modification happens.
    save_notify: Arc<Notify>,
}
16
17impl PersistenceManager {
18    /// Create a PersistenceManager with a backend implementation.
19    ///
20    /// Note: The backend should already be configured with the correct device_id
21    /// (via SqliteStore::new_for_device for multi-account scenarios).
22    pub async fn new(backend: Arc<dyn Backend>) -> Result<Self, StoreError> {
23        debug!("PersistenceManager: Ensuring device row exists.");
24        // Ensure a device row exists for this backend's device_id; create it if not.
25        let exists = backend.exists().await.map_err(db_err)?;
26        if !exists {
27            debug!("PersistenceManager: No device row found. Creating new device row.");
28            let id = backend.create().await.map_err(db_err)?;
29            debug!("PersistenceManager: Created device row with id={id}.");
30        }
31
32        debug!("PersistenceManager: Attempting to load device data via Backend.");
33        let device_data_opt = backend.load().await.map_err(db_err)?;
34
35        let device = if let Some(serializable_device) = device_data_opt {
36            debug!(
37                "PersistenceManager: Loaded existing device data (PushName: '{}'). Initializing Device.",
38                serializable_device.push_name
39            );
40            let mut dev = Device::new(backend.clone());
41            dev.load_from_serializable(serializable_device);
42            dev
43        } else {
44            debug!("PersistenceManager: No data yet; initializing default Device in memory.");
45            Device::new(backend.clone())
46        };
47
48        Ok(Self {
49            device: Arc::new(RwLock::new(device)),
50            backend,
51            dirty: Arc::new(AtomicBool::new(false)),
52            save_notify: Arc::new(Notify::new()),
53        })
54    }
55
56    pub async fn get_device_arc(&self) -> Arc<RwLock<Device>> {
57        self.device.clone()
58    }
59
60    pub async fn get_device_snapshot(&self) -> Device {
61        self.device.read().await.clone()
62    }
63
64    pub fn backend(&self) -> Arc<dyn Backend> {
65        self.backend.clone()
66    }
67
68    pub async fn modify_device<F, R>(&self, modifier: F) -> R
69    where
70        F: FnOnce(&mut Device) -> R,
71    {
72        let mut device_guard = self.device.write().await;
73        let result = modifier(&mut device_guard);
74
75        self.dirty.store(true, Ordering::Relaxed);
76        self.save_notify.notify_one();
77
78        result
79    }
80
81    async fn save_to_disk(&self) -> Result<(), StoreError> {
82        if self.dirty.swap(false, Ordering::AcqRel) {
83            debug!("Device state is dirty, saving to disk.");
84            let device_guard = self.device.read().await;
85            let serializable_device = device_guard.to_serializable();
86            drop(device_guard);
87
88            self.backend
89                .save(&serializable_device)
90                .await
91                .map_err(db_err)?;
92            debug!("Device state saved successfully.");
93        }
94        Ok(())
95    }
96
97    pub fn run_background_saver(self: Arc<Self>, interval: Duration) {
98        tokio::spawn(async move {
99            loop {
100                tokio::select! {
101                    _ = self.save_notify.notified() => {
102                        debug!("Save notification received.");
103                    }
104                    _ = sleep(interval) => {}
105                }
106
107                if let Err(e) = self.save_to_disk().await {
108                    error!("Error saving device state in background: {e}");
109                }
110            }
111        });
112        debug!("Background saver task started with interval {interval:?}");
113    }
114}
115
116use super::commands::{DeviceCommand, apply_command_to_device};
117
118impl PersistenceManager {
119    pub async fn process_command(&self, command: DeviceCommand) {
120        self.modify_device(|device| {
121            apply_command_to_device(device, command);
122        })
123        .await;
124    }
125}
126
127// SKDM recipient tracking methods
128impl PersistenceManager {
129    /// Get the list of device JIDs that have already received SKDM for a group
130    pub async fn get_skdm_recipients(&self, group_jid: &str) -> Result<Vec<String>, StoreError> {
131        self.backend
132            .get_skdm_recipients(group_jid)
133            .await
134            .map_err(db_err)
135    }
136
137    /// Mark devices as having received SKDM for a group
138    pub async fn add_skdm_recipients(
139        &self,
140        group_jid: &str,
141        device_jids: &[String],
142    ) -> Result<(), StoreError> {
143        self.backend
144            .add_skdm_recipients(group_jid, device_jids)
145            .await
146            .map_err(db_err)
147    }
148
149    /// Clear all SKDM recipients for a group (used when sender key is rotated)
150    pub async fn clear_skdm_recipients(&self, group_jid: &str) -> Result<(), StoreError> {
151        self.backend
152            .clear_skdm_recipients(group_jid)
153            .await
154            .map_err(db_err)
155    }
156}