1use crate::client::Client;
2use crate::store::persistence_manager::PersistenceManager;
3use crate::store::traits::Backend;
4use crate::types::enc_handler::EncHandler;
5use crate::types::events::{Event, EventHandler};
6use crate::types::message::MessageInfo;
7use anyhow::Result;
8use http::Uri;
9use log::{debug, info, warn};
10use std::collections::HashMap;
11use std::future::Future;
12use std::pin::Pin;
13use std::sync::Arc;
14use tokio::net::TcpStream;
15use tokio::sync::mpsc;
16use tokio::task;
17use waproto::whatsapp as wa;
18
19pub struct MessageContext {
20 pub message: Box<wa::Message>,
21 pub info: MessageInfo,
22 pub client: Arc<Client>,
23}
24
25impl MessageContext {
26 pub async fn send_message(&self, message: wa::Message) -> Result<String, anyhow::Error> {
27 self.client
28 .send_message(self.info.source.chat.clone(), message)
29 .await
30 }
31
32 pub async fn edit_message(
33 &self,
34 original_message_id: String,
35 new_message: wa::Message,
36 ) -> Result<String, anyhow::Error> {
37 self.client
38 .edit_message(
39 self.info.source.chat.clone(),
40 original_message_id,
41 new_message,
42 )
43 .await
44 }
45}
46
47type EventHandlerCallback =
48 Arc<dyn Fn(Event, Arc<Client>) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
49
50struct BotEventHandler {
51 client: Arc<Client>,
52 event_handler: Option<EventHandlerCallback>,
53}
54
55impl EventHandler for BotEventHandler {
56 fn handle_event(&self, event: &Event) {
57 if let Some(handler) = &self.event_handler {
58 let handler_clone = handler.clone();
59 let event_clone = event.clone();
60 let client_clone = self.client.clone();
61
62 tokio::spawn(async move {
63 handler_clone(event_clone, client_clone).await;
64 });
65 }
66 }
67}
68
69pub struct Bot {
70 client: Arc<Client>,
71 sync_task_receiver: Option<mpsc::Receiver<crate::sync_task::MajorSyncTask>>,
72 event_handler: Option<EventHandlerCallback>,
73}
74
75impl std::fmt::Debug for Bot {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("Bot")
78 .field("client", &"<Client>")
79 .field("sync_task_receiver", &self.sync_task_receiver.is_some())
80 .field("event_handler", &self.event_handler.is_some())
81 .finish()
82 }
83}
84
85impl Bot {
86 pub fn builder() -> BotBuilder {
87 BotBuilder::new()
88 }
89
90 pub fn client(&self) -> Arc<Client> {
91 self.client.clone()
92 }
93
94 pub async fn run(&mut self) -> Result<task::JoinHandle<()>> {
95 if let Some(mut receiver) = self.sync_task_receiver.take() {
96 let worker_client = self.client.clone();
97 tokio::spawn(async move {
98 while let Some(task) = receiver.recv().await {
99 match task {
100 crate::sync_task::MajorSyncTask::HistorySync {
101 message_id,
102 notification,
103 } => {
104 worker_client
105 .process_history_sync_task(message_id, *notification)
106 .await;
107 }
108 crate::sync_task::MajorSyncTask::AppStateSync { name, full_sync } => {
109 if let Err(e) = worker_client
110 .process_app_state_sync_task(name, full_sync)
111 .await
112 {
113 warn!("App state sync task for {:?} failed: {}", name, e);
114 }
115 }
116 }
117 }
118 info!("Sync worker shutting down.");
119 });
120 }
121
122 let handler = Arc::new(BotEventHandler {
123 client: self.client.clone(),
124 event_handler: self.event_handler.take(),
125 });
126 self.client.core.event_bus.add_handler(handler);
127
128 let client_for_run = self.client.clone();
129 let client_handle = tokio::spawn(async move {
130 client_for_run.run().await;
131 });
132
133 Ok(client_handle)
134 }
135}
136
137#[derive(Default)]
138pub struct BotBuilder {
139 event_handler: Option<EventHandlerCallback>,
140 custom_enc_handlers: HashMap<String, Arc<dyn EncHandler>>,
141 device_id: Option<i32>,
142 backend: Option<Arc<dyn Backend>>,
144 override_version: Option<(u32, u32, u32)>,
145 os_info: Option<(Option<String>, Option<wa::device_props::AppVersion>)>,
146}
147
148impl BotBuilder {
149 fn new() -> Self {
150 Self::default()
151 }
152
153 pub fn on_event<F, Fut>(mut self, handler: F) -> Self
154 where
155 F: Fn(Event, Arc<Client>) -> Fut + Send + Sync + 'static,
156 Fut: Future<Output = ()> + Send + 'static,
157 {
158 self.event_handler = Some(Arc::new(move |event, client| {
159 Box::pin(handler(event, client))
160 }));
161 self
162 }
163
164 pub fn with_enc_handler<H>(mut self, enc_type: impl Into<String>, handler: H) -> Self
173 where
174 H: EncHandler + 'static,
175 {
176 self.custom_enc_handlers
177 .insert(enc_type.into(), Arc::new(handler));
178 self
179 }
180
181 pub fn for_device(mut self, device_id: i32) -> Self {
184 self.device_id = Some(device_id);
185 self
186 }
187
188 pub fn with_backend(mut self, backend: Arc<dyn Backend>) -> Self {
203 self.backend = Some(backend);
204 self
205 }
206
207 pub fn with_version(mut self, version: (u32, u32, u32)) -> Self {
224 self.override_version = Some(version);
225 self
226 }
227
228 pub fn with_os_info(
261 mut self,
262 os_name: Option<String>,
263 version: Option<wa::device_props::AppVersion>,
264 ) -> Self {
265 self.os_info = Some((os_name, version));
266 self
267 }
268
269 pub async fn build(self) -> Result<Bot> {
270 let backend = self.backend.ok_or_else(|| {
271 anyhow::anyhow!(
272 "Backend is required. Use with_backend() to set a storage implementation."
273 )
274 })?;
275
276 let persistence_manager = if let Some(device_id) = self.device_id {
277 info!("Creating PersistenceManager for device ID: {}", device_id);
278 Arc::new(
279 PersistenceManager::new_for_device(device_id, backend)
280 .await
281 .map_err(|e| {
282 anyhow::anyhow!(
283 "Failed to create persistence manager for device {}: {}",
284 device_id,
285 e
286 )
287 })?,
288 )
289 } else {
290 info!("Creating PersistenceManager for single device mode");
291 Arc::new(
292 PersistenceManager::new(backend)
293 .await
294 .map_err(|e| anyhow::anyhow!("Failed to create persistence manager: {}", e))?,
295 )
296 };
297
298 persistence_manager
299 .clone()
300 .run_background_saver(std::time::Duration::from_secs(30));
301
302 spawn_preconnect_task().await;
303
304 crate::version::resolve_and_update_version(&persistence_manager, self.override_version)
305 .await;
306
307 if let Some((os_name, version)) = self.os_info {
309 info!("Applying OS info override: {:?} {:?}", os_name, version);
310 persistence_manager
311 .modify_device(|device| {
312 device.set_device_props(os_name, version);
313 })
314 .await;
315 }
316
317 info!("Creating client...");
318 let (client, sync_task_receiver) = Client::new(persistence_manager.clone()).await;
319
320 for (enc_type, handler) in self.custom_enc_handlers {
322 client.custom_enc_handlers.insert(enc_type, handler);
323 }
324
325 Ok(Bot {
326 client,
327 sync_task_receiver: Some(sync_task_receiver),
328 event_handler: self.event_handler,
329 })
330 }
331}
332
333async fn spawn_preconnect_task() {
334 if let Ok(uri) = crate::socket::consts::URL.parse::<Uri>() {
335 if let Some(host) = uri.host() {
336 let port = uri.port_u16().unwrap_or(443);
337 let address = format!("{}:{}", host, port);
338
339 debug!(target: "Client/Preconnect", "Starting pre-connect to {}", address);
340 if let Err(e) = TcpStream::connect(&address).await {
341 warn!(target: "Client/Preconnect", "Pre-connection to {} failed: {}", address, e);
342 } else {
343 debug!(target: "Client/Preconnect", "Pre-connection to {} successful.", address);
344 }
345 }
346 } else {
347 warn!(target: "Client/Preconnect", "Could not parse WA_URL for pre-connect task.");
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::sqlite_store::SqliteStore;

    /// Creates a shared-cache in-memory SQLite backend; the random UUID gives every
    /// call a distinct database name.
    async fn create_test_sqlite_backend() -> Arc<dyn Backend> {
        let db_uri = format!(
            "file:memdb_bot_{}?mode=memory&cache=shared",
            uuid::Uuid::new_v4()
        );
        let store = SqliteStore::new(&db_uri)
            .await
            .expect("Failed to create test SqliteStore");
        Arc::new(store)
    }

    /// Without `for_device`, the builder yields device ID 1 in single-account mode.
    #[tokio::test]
    async fn test_bot_builder_single_device() {
        let storage = create_test_sqlite_backend().await;

        let bot = Bot::builder()
            .with_backend(storage)
            .build()
            .await
            .expect("Failed to build bot");

        let client = bot.client();
        let pm = client.persistence_manager();

        assert_eq!(pm.device_id(), 1);
        assert!(!pm.is_multi_account());
    }

    /// With `for_device(42)` and pre-saved device data, the builder is in multi-account mode.
    #[tokio::test]
    async fn test_bot_builder_multi_device() {
        let storage = create_test_sqlite_backend().await;

        // Device data is saved before building so device 42 already exists in the backend.
        let mut seeded = wacore::store::Device::new();
        seeded.push_name = "Test Device".to_string();
        storage
            .save_device_data_for_device(42, &seeded)
            .await
            .expect("Failed to save device data");

        let bot = Bot::builder()
            .with_backend(storage)
            .for_device(42)
            .build()
            .await
            .expect("Failed to build bot");

        let client = bot.client();
        let pm = client.persistence_manager();

        assert_eq!(pm.device_id(), 42);
        assert!(pm.is_multi_account());
    }

    /// A backend passed through `with_backend` is accepted and defaults to device ID 1.
    #[tokio::test]
    async fn test_bot_builder_with_custom_backend() {
        let storage = create_test_sqlite_backend().await;
        let bot = Bot::builder()
            .with_backend(storage)
            .build()
            .await
            .expect("Failed to build bot with custom backend");

        let client = bot.client();
        let pm = client.persistence_manager();

        assert_eq!(pm.device_id(), 1);
    }

    /// A custom backend combined with `for_device(100)` reports device ID 100.
    #[tokio::test]
    async fn test_bot_builder_with_custom_backend_specific_device() {
        let storage = create_test_sqlite_backend().await;

        let mut seeded = wacore::store::Device::new();
        seeded.push_name = "Test Device".to_string();
        storage
            .save_device_data_for_device(100, &seeded)
            .await
            .expect("Failed to save device data");

        let bot = Bot::builder()
            .with_backend(storage)
            .for_device(100)
            .build()
            .await
            .expect("Failed to build bot with custom backend for specific device");

        let client = bot.client();
        let pm = client.persistence_manager();

        assert_eq!(pm.device_id(), 100);
    }

    /// Building without a backend fails with the "Backend is required" message.
    #[tokio::test]
    async fn test_bot_builder_missing_backend() {
        let outcome = Bot::builder().build().await;

        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(err.to_string().contains("Backend is required"));
    }

    /// `with_version` ends up in the device snapshot's three app-version fields.
    #[tokio::test]
    async fn test_bot_builder_with_version_override() {
        let storage = create_test_sqlite_backend().await;

        let bot = Bot::builder()
            .with_backend(storage)
            .with_version((2, 3000, 123456789))
            .build()
            .await
            .expect("Failed to build bot with version override");

        let client = bot.client();
        let pm = client.persistence_manager();

        let snapshot = pm.get_device_snapshot().await;
        assert_eq!(snapshot.app_version_primary, 2);
        assert_eq!(snapshot.app_version_secondary, 3000);
        assert_eq!(snapshot.app_version_tertiary, 123456789);
    }

    /// Overriding both OS name and version is reflected in `device_props`.
    #[tokio::test]
    async fn test_bot_builder_with_os_info_override() {
        let storage = create_test_sqlite_backend().await;

        let custom_os = "CustomOS".to_string();
        let custom_version = wa::device_props::AppVersion {
            primary: Some(99),
            secondary: Some(88),
            tertiary: Some(77),
            ..Default::default()
        };

        let bot = Bot::builder()
            .with_backend(storage)
            .with_os_info(Some(custom_os.clone()), Some(custom_version))
            .build()
            .await
            .expect("Failed to build bot with OS info override");

        let client = bot.client();
        let pm = client.persistence_manager();
        let snapshot = pm.get_device_snapshot().await;

        assert_eq!(snapshot.device_props.os, Some(custom_os));
        assert_eq!(snapshot.device_props.version, Some(custom_version));
    }

    /// Overriding only the OS name leaves the version at the default.
    #[tokio::test]
    async fn test_bot_builder_with_os_only_override() {
        let storage = create_test_sqlite_backend().await;

        let custom_os = "CustomOS".to_string();

        let bot = Bot::builder()
            .with_backend(storage)
            .with_os_info(Some(custom_os.clone()), None)
            .build()
            .await
            .expect("Failed to build bot with OS only override");

        let client = bot.client();
        let pm = client.persistence_manager();
        let snapshot = pm.get_device_snapshot().await;

        assert_eq!(snapshot.device_props.os, Some(custom_os));
        assert_eq!(
            snapshot.device_props.version,
            Some(wacore::store::Device::default_device_props_version())
        );
    }

    /// Overriding only the version leaves the OS name at the default.
    #[tokio::test]
    async fn test_bot_builder_with_version_only_override() {
        let storage = create_test_sqlite_backend().await;

        let custom_version = wa::device_props::AppVersion {
            primary: Some(99),
            secondary: Some(88),
            tertiary: Some(77),
            ..Default::default()
        };

        let bot = Bot::builder()
            .with_backend(storage)
            .with_os_info(None, Some(custom_version))
            .build()
            .await
            .expect("Failed to build bot with version only override");

        let client = bot.client();
        let pm = client.persistence_manager();
        let snapshot = pm.get_device_snapshot().await;

        assert_eq!(snapshot.device_props.version, Some(custom_version));
        assert_eq!(
            snapshot.device_props.os,
            Some(wacore::store::Device::default_os().to_string())
        );
    }
}