whatsapp_rust/
bot.rs

1use crate::client::Client;
2use crate::store::persistence_manager::PersistenceManager;
3use crate::store::traits::Backend;
4use crate::types::enc_handler::EncHandler;
5use crate::types::events::{Event, EventHandler};
6use crate::types::message::MessageInfo;
7use anyhow::Result;
8use http::Uri;
9use log::{debug, info, warn};
10use std::collections::HashMap;
11use std::future::Future;
12use std::pin::Pin;
13use std::sync::Arc;
14use tokio::net::TcpStream;
15use tokio::sync::mpsc;
16use tokio::task;
17use waproto::whatsapp as wa;
18
19pub struct MessageContext {
20    pub message: Box<wa::Message>,
21    pub info: MessageInfo,
22    pub client: Arc<Client>,
23}
24
25impl MessageContext {
26    pub async fn send_message(&self, message: wa::Message) -> Result<String, anyhow::Error> {
27        self.client
28            .send_message(self.info.source.chat.clone(), message)
29            .await
30    }
31
32    pub async fn edit_message(
33        &self,
34        original_message_id: String,
35        new_message: wa::Message,
36    ) -> Result<String, anyhow::Error> {
37        self.client
38            .edit_message(
39                self.info.source.chat.clone(),
40                original_message_id,
41                new_message,
42            )
43            .await
44    }
45}
46
47type EventHandlerCallback =
48    Arc<dyn Fn(Event, Arc<Client>) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
49
50struct BotEventHandler {
51    client: Arc<Client>,
52    event_handler: Option<EventHandlerCallback>,
53}
54
55impl EventHandler for BotEventHandler {
56    fn handle_event(&self, event: &Event) {
57        if let Some(handler) = &self.event_handler {
58            let handler_clone = handler.clone();
59            let event_clone = event.clone();
60            let client_clone = self.client.clone();
61
62            tokio::spawn(async move {
63                handler_clone(event_clone, client_clone).await;
64            });
65        }
66    }
67}
68
69pub struct Bot {
70    client: Arc<Client>,
71    sync_task_receiver: Option<mpsc::Receiver<crate::sync_task::MajorSyncTask>>,
72    event_handler: Option<EventHandlerCallback>,
73}
74
75impl std::fmt::Debug for Bot {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.debug_struct("Bot")
78            .field("client", &"<Client>")
79            .field("sync_task_receiver", &self.sync_task_receiver.is_some())
80            .field("event_handler", &self.event_handler.is_some())
81            .finish()
82    }
83}
84
85impl Bot {
86    pub fn builder() -> BotBuilder {
87        BotBuilder::new()
88    }
89
90    pub fn client(&self) -> Arc<Client> {
91        self.client.clone()
92    }
93
94    pub async fn run(&mut self) -> Result<task::JoinHandle<()>> {
95        if let Some(mut receiver) = self.sync_task_receiver.take() {
96            let worker_client = self.client.clone();
97            tokio::spawn(async move {
98                while let Some(task) = receiver.recv().await {
99                    match task {
100                        crate::sync_task::MajorSyncTask::HistorySync {
101                            message_id,
102                            notification,
103                        } => {
104                            worker_client
105                                .process_history_sync_task(message_id, *notification)
106                                .await;
107                        }
108                        crate::sync_task::MajorSyncTask::AppStateSync { name, full_sync } => {
109                            if let Err(e) = worker_client
110                                .process_app_state_sync_task(name, full_sync)
111                                .await
112                            {
113                                warn!("App state sync task for {:?} failed: {}", name, e);
114                            }
115                        }
116                    }
117                }
118                info!("Sync worker shutting down.");
119            });
120        }
121
122        let handler = Arc::new(BotEventHandler {
123            client: self.client.clone(),
124            event_handler: self.event_handler.take(),
125        });
126        self.client.core.event_bus.add_handler(handler);
127
128        let client_for_run = self.client.clone();
129        let client_handle = tokio::spawn(async move {
130            client_for_run.run().await;
131        });
132
133        Ok(client_handle)
134    }
135}
136
137#[derive(Default)]
138pub struct BotBuilder {
139    event_handler: Option<EventHandlerCallback>,
140    custom_enc_handlers: HashMap<String, Arc<dyn EncHandler>>,
141    device_id: Option<i32>,
142    // The only way to configure storage
143    backend: Option<Arc<dyn Backend>>,
144    override_version: Option<(u32, u32, u32)>,
145    os_info: Option<(Option<String>, Option<wa::device_props::AppVersion>)>,
146}
147
148impl BotBuilder {
149    fn new() -> Self {
150        Self::default()
151    }
152
153    pub fn on_event<F, Fut>(mut self, handler: F) -> Self
154    where
155        F: Fn(Event, Arc<Client>) -> Fut + Send + Sync + 'static,
156        Fut: Future<Output = ()> + Send + 'static,
157    {
158        self.event_handler = Some(Arc::new(move |event, client| {
159            Box::pin(handler(event, client))
160        }));
161        self
162    }
163
164    /// Register a custom handler for a specific encrypted message type
165    ///
166    /// # Arguments
167    /// * `enc_type` - The encrypted message type (e.g., "frskmsg")
168    /// * `handler` - The handler implementation for this type
169    ///
170    /// # Returns
171    /// The updated BotBuilder
172    pub fn with_enc_handler<H>(mut self, enc_type: impl Into<String>, handler: H) -> Self
173    where
174        H: EncHandler + 'static,
175    {
176        self.custom_enc_handlers
177            .insert(enc_type.into(), Arc::new(handler));
178        self
179    }
180
181    /// Specify which device ID to use for multi-account scenarios.
182    /// If not specified, single device mode will be used.
183    pub fn for_device(mut self, device_id: i32) -> Self {
184        self.device_id = Some(device_id);
185        self
186    }
187
188    /// Use a backend implementation for storage.
189    /// This is the only way to configure storage - there are no defaults.
190    ///
191    /// # Arguments
192    /// * `backend` - The backend implementation that provides all storage operations
193    ///
194    /// # Example
195    /// ```rust,ignore
196    /// let backend = Arc::new(SqliteStore::new("whatsapp.db").await?);
197    /// let bot = Bot::builder()
198    ///     .with_backend(backend)
199    ///     .build()
200    ///     .await?;
201    /// ```
202    pub fn with_backend(mut self, backend: Arc<dyn Backend>) -> Self {
203        self.backend = Some(backend);
204        self
205    }
206
207    /// Override the WhatsApp version used by the client.
208    ///
209    /// By default, the client will automatically fetch the latest version from WhatsApp's servers.
210    /// Use this method to force a specific version instead.
211    ///
212    /// # Arguments
213    /// * `version` - A tuple of (primary, secondary, tertiary) version numbers
214    ///
215    /// # Example
216    /// ```rust,ignore
217    /// let bot = Bot::builder()
218    ///     .with_backend(backend)
219    ///     .with_version((2, 3000, 1027868167))
220    ///     .build()
221    ///     .await?;
222    /// ```
223    pub fn with_version(mut self, version: (u32, u32, u32)) -> Self {
224        self.override_version = Some(version);
225        self
226    }
227
228    /// Override the OS information sent to WhatsApp servers.
229    /// This allows customizing the device properties that WhatsApp sees.
230    ///
231    /// # Arguments
232    /// * `os_name` - Optional OS name (e.g., "Android", "iOS", "Windows")
233    /// * `version` - Optional OS version as AppVersion struct
234    ///
235    /// You can pass `None` for either parameter to keep the default value.
236    ///
237    /// # Example
238    /// ```rust,ignore
239    /// use waproto::whatsapp::device_props;
240    ///
241    /// // Set only OS name, keep default version
242    /// let bot = Bot::builder()
243    ///     .with_backend(backend)
244    ///     .with_os_info(Some("Android".to_string()), None)
245    ///     .build()
246    ///     .await?;
247    ///
248    /// // Set only version, keep default OS
249    /// let bot = Bot::builder()
250    ///     .with_backend(backend)
251    ///     .with_os_info(None, Some(device_props::AppVersion {
252    ///         primary: Some(10),
253    ///         secondary: Some(0),
254    ///         tertiary: Some(0),
255    ///         ..Default::default()
256    ///     }))
257    ///     .build()
258    ///     .await?;
259    /// ```
260    pub fn with_os_info(
261        mut self,
262        os_name: Option<String>,
263        version: Option<wa::device_props::AppVersion>,
264    ) -> Self {
265        self.os_info = Some((os_name, version));
266        self
267    }
268
269    pub async fn build(self) -> Result<Bot> {
270        let backend = self.backend.ok_or_else(|| {
271            anyhow::anyhow!(
272                "Backend is required. Use with_backend() to set a storage implementation."
273            )
274        })?;
275
276        let persistence_manager = if let Some(device_id) = self.device_id {
277            info!("Creating PersistenceManager for device ID: {}", device_id);
278            Arc::new(
279                PersistenceManager::new_for_device(device_id, backend)
280                    .await
281                    .map_err(|e| {
282                        anyhow::anyhow!(
283                            "Failed to create persistence manager for device {}: {}",
284                            device_id,
285                            e
286                        )
287                    })?,
288            )
289        } else {
290            info!("Creating PersistenceManager for single device mode");
291            Arc::new(
292                PersistenceManager::new(backend)
293                    .await
294                    .map_err(|e| anyhow::anyhow!("Failed to create persistence manager: {}", e))?,
295            )
296        };
297
298        persistence_manager
299            .clone()
300            .run_background_saver(std::time::Duration::from_secs(30));
301
302        spawn_preconnect_task().await;
303
304        crate::version::resolve_and_update_version(&persistence_manager, self.override_version)
305            .await;
306
307        // Apply OS info override if specified
308        if let Some((os_name, version)) = self.os_info {
309            info!("Applying OS info override: {:?} {:?}", os_name, version);
310            persistence_manager
311                .modify_device(|device| {
312                    device.set_device_props(os_name, version);
313                })
314                .await;
315        }
316
317        info!("Creating client...");
318        let (client, sync_task_receiver) = Client::new(persistence_manager.clone()).await;
319
320        // Register custom enc handlers
321        for (enc_type, handler) in self.custom_enc_handlers {
322            client.custom_enc_handlers.insert(enc_type, handler);
323        }
324
325        Ok(Bot {
326            client,
327            sync_task_receiver: Some(sync_task_receiver),
328            event_handler: self.event_handler,
329        })
330    }
331}
332
333async fn spawn_preconnect_task() {
334    if let Ok(uri) = crate::socket::consts::URL.parse::<Uri>() {
335        if let Some(host) = uri.host() {
336            let port = uri.port_u16().unwrap_or(443);
337            let address = format!("{}:{}", host, port);
338
339            debug!(target: "Client/Preconnect", "Starting pre-connect to {}", address);
340            if let Err(e) = TcpStream::connect(&address).await {
341                warn!(target: "Client/Preconnect", "Pre-connection to {} failed: {}", address, e);
342            } else {
343                debug!(target: "Client/Preconnect", "Pre-connection to {} successful.", address);
344            }
345        }
346    } else {
347        warn!(target: "Client/Preconnect", "Could not parse WA_URL for pre-connect task.");
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::sqlite_store::SqliteStore;

    /// Opens a uniquely named, shared-cache in-memory SQLite database as a [`Backend`].
    ///
    /// The random UUID in the URI gives each test its own database.
    async fn create_test_sqlite_backend() -> Arc<dyn Backend> {
        let db_uri = format!(
            "file:memdb_bot_{}?mode=memory&cache=shared",
            uuid::Uuid::new_v4()
        );
        let store = SqliteStore::new(&db_uri)
            .await
            .expect("Failed to create test SqliteStore");
        Arc::new(store)
    }

    /// Device record with a recognisable push name, used to pre-seed multi-account tests.
    fn named_test_device() -> wacore::store::Device {
        let mut device = wacore::store::Device::new();
        device.push_name = "Test Device".to_string();
        device
    }

    /// Builds `builder`, panicking with `context` if construction fails.
    async fn build_or_panic(builder: BotBuilder, context: &str) -> Bot {
        builder.build().await.expect(context)
    }

    /// Custom app version shared by the OS-info override tests.
    fn custom_app_version() -> wa::device_props::AppVersion {
        wa::device_props::AppVersion {
            primary: Some(99),
            secondary: Some(88),
            tertiary: Some(77),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_bot_builder_single_device() {
        let backend = create_test_sqlite_backend().await;
        let bot = build_or_panic(Bot::builder().with_backend(backend), "Failed to build bot").await;

        let client = bot.client();
        let pm = client.persistence_manager();

        // Without `for_device`, the bot runs in single-device mode on device ID 1.
        assert_eq!(pm.device_id(), 1);
        assert!(!pm.is_multi_account());
    }

    #[tokio::test]
    async fn test_bot_builder_multi_device() {
        let backend = create_test_sqlite_backend().await;

        // The device must already exist in storage before a bot can be built for it.
        backend
            .save_device_data_for_device(42, &named_test_device())
            .await
            .expect("Failed to save device data");

        let bot = build_or_panic(
            Bot::builder().with_backend(backend).for_device(42),
            "Failed to build bot",
        )
        .await;

        let client = bot.client();
        let pm = client.persistence_manager();

        assert_eq!(pm.device_id(), 42);
        assert!(pm.is_multi_account());
    }

    #[tokio::test]
    async fn test_bot_builder_with_custom_backend() {
        let backend = create_test_sqlite_backend().await;
        let bot = build_or_panic(
            Bot::builder().with_backend(backend),
            "Failed to build bot with custom backend",
        )
        .await;

        let client = bot.client();
        let pm = client.persistence_manager();

        // Single-device mode is the default.
        assert_eq!(pm.device_id(), 1);
    }

    #[tokio::test]
    async fn test_bot_builder_with_custom_backend_specific_device() {
        let backend = create_test_sqlite_backend().await;

        backend
            .save_device_data_for_device(100, &named_test_device())
            .await
            .expect("Failed to save device data");

        let bot = build_or_panic(
            Bot::builder().with_backend(backend).for_device(100),
            "Failed to build bot with custom backend for specific device",
        )
        .await;

        let client = bot.client();
        let pm = client.persistence_manager();

        assert_eq!(pm.device_id(), 100);
    }

    #[tokio::test]
    async fn test_bot_builder_missing_backend() {
        let outcome = Bot::builder().build().await;

        // Building without a backend must be rejected with a descriptive error.
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Backend is required"));
    }

    #[tokio::test]
    async fn test_bot_builder_with_version_override() {
        let backend = create_test_sqlite_backend().await;
        let bot = build_or_panic(
            Bot::builder()
                .with_backend(backend)
                .with_version((2, 3000, 123456789)),
            "Failed to build bot with version override",
        )
        .await;

        let client = bot.client();
        let pm = client.persistence_manager();
        let snapshot = pm.get_device_snapshot().await;

        let stored_version = (
            snapshot.app_version_primary,
            snapshot.app_version_secondary,
            snapshot.app_version_tertiary,
        );
        assert_eq!(stored_version, (2, 3000, 123456789));
    }

    #[tokio::test]
    async fn test_bot_builder_with_os_info_override() {
        let backend = create_test_sqlite_backend().await;
        let custom_os = "CustomOS".to_string();
        let custom_version = custom_app_version();

        let bot = build_or_panic(
            Bot::builder()
                .with_backend(backend)
                .with_os_info(Some(custom_os.clone()), Some(custom_version)),
            "Failed to build bot with OS info override",
        )
        .await;

        let client = bot.client();
        let pm = client.persistence_manager();
        let device = pm.get_device_snapshot().await;

        // Both halves of the override must be reflected in the device properties.
        assert_eq!(device.device_props.os, Some(custom_os));
        assert_eq!(device.device_props.version, Some(custom_version));
    }

    #[tokio::test]
    async fn test_bot_builder_with_os_only_override() {
        let backend = create_test_sqlite_backend().await;
        let custom_os = "CustomOS".to_string();

        let bot = build_or_panic(
            Bot::builder()
                .with_backend(backend)
                .with_os_info(Some(custom_os.clone()), None),
            "Failed to build bot with OS only override",
        )
        .await;

        let client = bot.client();
        let pm = client.persistence_manager();
        let device = pm.get_device_snapshot().await;

        // The OS name is overridden; the version keeps its default.
        assert_eq!(device.device_props.os, Some(custom_os));
        assert_eq!(
            device.device_props.version,
            Some(wacore::store::Device::default_device_props_version())
        );
    }

    #[tokio::test]
    async fn test_bot_builder_with_version_only_override() {
        let backend = create_test_sqlite_backend().await;
        let custom_version = custom_app_version();

        let bot = build_or_panic(
            Bot::builder()
                .with_backend(backend)
                .with_os_info(None, Some(custom_version)),
            "Failed to build bot with version only override",
        )
        .await;

        let client = bot.client();
        let pm = client.persistence_manager();
        let device = pm.get_device_snapshot().await;

        // The version is overridden; the OS name keeps its default.
        assert_eq!(device.device_props.version, Some(custom_version));
        assert_eq!(
            device.device_props.os,
            Some(wacore::store::Device::default_os().to_string())
        );
    }
}
580}