1use std::collections::{BTreeMap, HashMap};
18use std::future::Future;
19use std::pin::Pin;
20use std::sync::atomic::{AtomicBool, Ordering};
21use std::sync::Arc;
22
23use serde_json::Value;
24use tokio::sync::{mpsc, Notify, RwLock};
25use tracing::{debug, error, info, warn};
26
27use rust_tg_bot_raw::types::update::Update;
28
29use crate::context::CallbackContext;
30use crate::context_types::{ContextTypes, DefaultData};
31use crate::ext_bot::ExtBot;
32#[cfg(feature = "job-queue")]
33use crate::job_queue::JobQueue;
34#[cfg(feature = "persistence")]
35use crate::persistence::base::{BasePersistence, PersistenceInput, PersistenceResult};
36use crate::update_processor::BaseUpdateProcessor;
37#[cfg(feature = "persistence")]
38use crate::utils::types::JsonMap;
39
/// Shareable, type-erased async callback run for an update a handler accepted.
///
/// The callback receives the shared [`Update`] and a [`CallbackContext`] and
/// returns a boxed `Send` future that resolves to `Ok(())` or a [`HandlerError`].
pub type HandlerCallback = Arc<
    dyn Fn(
            Arc<Update>,
            CallbackContext,
        ) -> Pin<Box<dyn Future<Output = Result<(), HandlerError>> + Send>>
        + Send
        + Sync,
>;
53
/// Shareable, type-erased async callback invoked when a handler (or job) fails.
///
/// The first argument is the update being processed, if any. When a blocking
/// error handler resolves to `true`, `Application::process_error` returns
/// `true` immediately and no further error handlers are consulted.
pub type ErrorHandlerCallback = Arc<
    dyn Fn(Option<Arc<Update>>, CallbackContext) -> Pin<Box<dyn Future<Output = bool> + Send>>
        + Send
        + Sync,
>;
62
/// Async hook that receives the application; used for the `post_init`,
/// `post_stop` and `post_shutdown` slots of [`Application`].
pub type LifecycleHook =
    Arc<dyn Fn(Arc<Application>) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
66
/// Per-id data map (keyed by the `i64` user or chat id) exchanged with the persistence backend.
#[cfg(feature = "persistence")]
type PersistenceDataMap = HashMap<i64, JsonMap>;
/// Boxed `Send` future returned by every async method of [`DynPersistence`].
#[cfg(feature = "persistence")]
type PersistenceFuture<'a, T> = Pin<Box<dyn Future<Output = PersistenceResult<T>> + Send + 'a>>;
71
/// Errors a handler callback can return to the dispatcher.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum HandlerError {
    /// Requests that dispatching of the current update stops: when a blocking
    /// handler returns this, `Application::process_update` returns `Ok(())`
    /// without visiting the remaining handler groups.
    #[error("ApplicationHandlerStop")]
    HandlerStop {
        /// Optional JSON payload attached to the stop request.
        // NOTE(review): nothing in this file reads `state`; presumably it is
        // consumed elsewhere — confirm.
        state: Option<Value>,
    },

    /// Any other failure; blocking handlers have it forwarded to
    /// `Application::process_error`.
    #[error("{0}")]
    Other(Box<dyn std::error::Error + Send + Sync>),
}
91
92impl From<rust_tg_bot_raw::error::TelegramError> for HandlerError {
93 fn from(e: rust_tg_bot_raw::error::TelegramError) -> Self {
94 HandlerError::Other(Box::new(e))
95 }
96}
97
/// Errors reported by the [`Application`] lifecycle and dispatch methods.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ApplicationError {
    /// An operation that requires `initialize()` was called before it.
    #[error("This Application was not initialized via `Application::initialize`")]
    NotInitialized,

    /// `start()` was called while the application was already running.
    #[error("This Application is already running")]
    AlreadyRunning,

    /// `stop()` was called while the application was not running.
    #[error("This Application is not running")]
    NotRunning,

    /// `shutdown()` was called while the application was still running.
    #[error("This Application is still running")]
    StillRunning,

    /// An error bubbled up from the underlying bot.
    #[error("{0}")]
    Bot(#[from] rust_tg_bot_raw::error::TelegramError),

    /// An error bubbled up from the update processor.
    #[error("{0}")]
    UpdateProcessor(#[from] crate::update_processor::UpdateProcessorError),

    /// A webhook-related failure described by the contained message
    /// (for example a TLS configuration error).
    #[error("webhook error: {0}")]
    Webhook(String),
}
130
/// Facade over [`BasePersistence`] whose async methods return boxed futures,
/// which is what allows [`Application`] to store it as `Box<dyn DynPersistence>`.
#[cfg(feature = "persistence")]
pub trait DynPersistence: Send + Sync + std::fmt::Debug {
    /// Loads the stored per-user data.
    fn get_user_data(&self) -> PersistenceFuture<'_, PersistenceDataMap>;
    /// Loads the stored per-chat data.
    fn get_chat_data(&self) -> PersistenceFuture<'_, PersistenceDataMap>;
    /// Loads the stored bot-wide data.
    fn get_bot_data(&self) -> PersistenceFuture<'_, JsonMap>;
    /// Stores `data` for the user `user_id`.
    fn update_user_data(&self, user_id: i64, data: JsonMap) -> PersistenceFuture<'_, ()>;
    /// Stores `data` for the chat `chat_id`.
    fn update_chat_data(&self, chat_id: i64, data: JsonMap) -> PersistenceFuture<'_, ()>;
    /// Stores the bot-wide `data`.
    fn update_bot_data(&self, data: JsonMap) -> PersistenceFuture<'_, ()>;
    /// Flushes the backend; called by `Application::shutdown`.
    fn flush(&self) -> PersistenceFuture<'_, ()>;
    /// Interval, in seconds, between periodic `Application::update_persistence` runs.
    fn update_interval(&self) -> f64;
    /// Flags selecting which of user/chat/bot data are loaded and stored.
    fn store_data(&self) -> PersistenceInput;
}
160
/// Blanket adapter: every debuggable [`BasePersistence`] is a [`DynPersistence`].
///
/// Each method forwards to the `BasePersistence` method of the same name and
/// boxes the resulting future. The `update_*` methods take the map by value and
/// lend it to the backend for the duration of the call.
#[cfg(feature = "persistence")]
impl<T> DynPersistence for T
where
    T: BasePersistence + std::fmt::Debug,
{
    fn get_user_data(&self) -> PersistenceFuture<'_, PersistenceDataMap> {
        Box::pin(<T as BasePersistence>::get_user_data(self))
    }

    fn get_chat_data(&self) -> PersistenceFuture<'_, PersistenceDataMap> {
        Box::pin(<T as BasePersistence>::get_chat_data(self))
    }

    fn get_bot_data(&self) -> PersistenceFuture<'_, JsonMap> {
        Box::pin(<T as BasePersistence>::get_bot_data(self))
    }

    fn update_user_data(&self, user_id: i64, data: JsonMap) -> PersistenceFuture<'_, ()> {
        Box::pin(async move {
            let payload = data;
            <T as BasePersistence>::update_user_data(self, user_id, &payload).await
        })
    }

    fn update_chat_data(&self, chat_id: i64, data: JsonMap) -> PersistenceFuture<'_, ()> {
        Box::pin(async move {
            let payload = data;
            <T as BasePersistence>::update_chat_data(self, chat_id, &payload).await
        })
    }

    fn update_bot_data(&self, data: JsonMap) -> PersistenceFuture<'_, ()> {
        Box::pin(async move {
            let payload = data;
            <T as BasePersistence>::update_bot_data(self, &payload).await
        })
    }

    fn flush(&self) -> PersistenceFuture<'_, ()> {
        Box::pin(<T as BasePersistence>::flush(self))
    }

    fn update_interval(&self) -> f64 {
        <T as BasePersistence>::update_interval(self)
    }

    fn store_data(&self) -> PersistenceInput {
        <T as BasePersistence>::store_data(self)
    }
}
191
/// Boxes a concrete persistence backend as a `Box<dyn DynPersistence>`, the
/// form stored by [`Application`].
#[cfg(feature = "persistence")]
pub fn boxed_persistence<T>(p: T) -> Box<dyn DynPersistence>
where
    T: BasePersistence + std::fmt::Debug + 'static,
{
    let erased: Box<dyn DynPersistence> = Box::new(p);
    erased
}
201
/// Low-level handler record stored per group by [`Application`].
pub struct Handler {
    /// Predicate deciding whether this handler accepts the update.
    pub check_update: Arc<dyn Fn(&Update) -> bool + Send + Sync>,
    /// Callback run for accepted updates.
    pub callback: HandlerCallback,
    /// `true`: the dispatcher awaits the callback; `false`: the callback is
    /// spawned as a background task.
    pub block: bool,
}
215
216impl std::fmt::Debug for Handler {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 f.debug_struct("Handler")
219 .field("block", &self.block)
220 .finish()
221 }
222}
223
/// Central dispatcher: owns the bot, the handler registry, the shared data
/// stores and the update channel, and drives the initialize/start/stop/shutdown
/// lifecycle.
pub struct Application {
    /// Bot used for API calls and handed to every callback context.
    bot: Arc<ExtBot>,
    /// Context type configuration; only read in `new` for the initial bot data.
    #[allow(dead_code)]
    context_types: ContextTypes,
    /// Processor through which fetched updates are run.
    update_processor: Arc<BaseUpdateProcessor>,

    /// Handlers per group id; the `BTreeMap` yields groups in ascending order.
    handlers: RwLock<BTreeMap<i32, Vec<Handler>>>,
    /// Error handlers paired with their `block` flag.
    error_handlers: RwLock<Vec<(ErrorHandlerCallback, bool)>>,

    /// Per-user data, keyed by user id.
    user_data: Arc<RwLock<HashMap<i64, DefaultData>>>,
    /// Per-chat data, keyed by chat id.
    chat_data: Arc<RwLock<HashMap<i64, DefaultData>>>,
    /// Bot-wide data.
    bot_data: Arc<RwLock<DefaultData>>,

    /// Optional persistence backend.
    #[cfg(feature = "persistence")]
    persistence: Option<Box<dyn DynPersistence>>,
    /// Optional job queue.
    #[cfg(feature = "job-queue")]
    job_queue: Option<Arc<JobQueue>>,
    /// Join handles of background tasks awaited (with a timeout) by `stop`.
    pending_tasks: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,

    /// Set by `initialize`, cleared by `shutdown`.
    initialized: AtomicBool,
    /// Set by `start`, cleared by `stop`.
    running: AtomicBool,

    /// Sending half of the update channel (cloned out by `update_sender`).
    update_tx: mpsc::Sender<Update>,
    /// Receiving half; taken by the first `start` call.
    update_rx: RwLock<Option<mpsc::Receiver<Update>>>,
    /// Broadcast signal used to stop the background loops.
    stop_notify: Arc<Notify>,

    /// Hook run after `initialize` by the `run_*` entry points.
    post_init: Option<LifecycleHook>,
    /// Hook run after `stop` by the `run_*` entry points.
    post_stop: Option<LifecycleHook>,
    /// Hook run after `shutdown` by the `run_*` entry points.
    post_shutdown: Option<LifecycleHook>,
}
259
260impl std::fmt::Debug for Application {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 f.debug_struct("Application")
263 .field("bot_token", &self.bot.token())
264 .finish()
265 }
266}
/// Group id (`0`) to pass to the `add_*handler` methods when no specific
/// group ordering is needed.
pub const DEFAULT_GROUP: i32 = 0;
269
/// Everything needed to construct an [`Application`]; consumed by `Application::new`.
pub(crate) struct ApplicationConfig {
    /// Bot the application will drive.
    pub(crate) bot: Arc<ExtBot>,
    /// Context type configuration (source of the initial bot data).
    pub(crate) context_types: ContextTypes,
    /// Processor used to run fetched updates.
    pub(crate) update_processor: Arc<BaseUpdateProcessor>,
    /// Optional hook run after initialization.
    pub(crate) post_init: Option<LifecycleHook>,
    /// Optional hook run after stop.
    pub(crate) post_stop: Option<LifecycleHook>,
    /// Optional hook run after shutdown.
    pub(crate) post_shutdown: Option<LifecycleHook>,
    /// Optional persistence backend.
    #[cfg(feature = "persistence")]
    pub(crate) persistence: Option<Box<dyn DynPersistence>>,
    /// Optional job queue.
    #[cfg(feature = "job-queue")]
    pub(crate) job_queue: Option<Arc<JobQueue>>,
}
282
283impl ApplicationConfig {
284 pub(crate) fn new(
285 bot: Arc<ExtBot>,
286 context_types: ContextTypes,
287 update_processor: Arc<BaseUpdateProcessor>,
288 ) -> Self {
289 Self {
290 bot,
291 context_types,
292 update_processor,
293 post_init: None,
294 post_stop: None,
295 post_shutdown: None,
296 #[cfg(feature = "persistence")]
297 persistence: None,
298 #[cfg(feature = "job-queue")]
299 job_queue: None,
300 }
301 }
302}
303
304impl Application {
305 #[must_use]
310 pub(crate) fn new(config: ApplicationConfig) -> Arc<Self> {
311 let ApplicationConfig {
312 bot,
313 context_types,
314 update_processor,
315 post_init,
316 post_stop,
317 post_shutdown,
318 #[cfg(feature = "persistence")]
319 persistence,
320 #[cfg(feature = "job-queue")]
321 job_queue,
322 } = config;
323 let (tx, rx) = mpsc::channel(64);
324 let bot_data_initial = context_types.bot_data();
325 Arc::new(Self {
326 bot,
327 context_types,
328 update_processor,
329 handlers: RwLock::new(BTreeMap::new()),
330 error_handlers: RwLock::new(Vec::new()),
331 user_data: Arc::new(RwLock::new(HashMap::new())),
332 chat_data: Arc::new(RwLock::new(HashMap::new())),
333 bot_data: Arc::new(RwLock::new(bot_data_initial)),
334 #[cfg(feature = "persistence")]
335 persistence,
336 #[cfg(feature = "job-queue")]
337 job_queue,
338 pending_tasks: Arc::new(RwLock::new(Vec::new())),
339 initialized: AtomicBool::new(false),
340 running: AtomicBool::new(false),
341 update_tx: tx,
342 update_rx: RwLock::new(Some(rx)),
343 stop_notify: Arc::new(Notify::new()),
344 post_init,
345 post_stop,
346 post_shutdown,
347 })
348 }
349
350 #[must_use]
352 pub fn bot(&self) -> &Arc<ExtBot> {
354 &self.bot
355 }
356 pub fn is_initialized(&self) -> bool {
358 self.initialized.load(Ordering::Acquire)
359 }
360 pub fn is_running(&self) -> bool {
362 self.running.load(Ordering::Acquire)
363 }
364 #[must_use]
365 pub fn concurrent_updates(&self) -> usize {
367 self.update_processor.max_concurrent_updates()
368 }
369 #[must_use]
370 pub fn user_data(&self) -> &Arc<RwLock<HashMap<i64, DefaultData>>> {
372 &self.user_data
373 }
374 #[must_use]
375 pub fn chat_data(&self) -> &Arc<RwLock<HashMap<i64, DefaultData>>> {
377 &self.chat_data
378 }
379 #[must_use]
380 pub fn bot_data(&self) -> &Arc<RwLock<DefaultData>> {
382 &self.bot_data
383 }
384 #[must_use]
385 pub fn update_sender(&self) -> mpsc::Sender<Update> {
387 self.update_tx.clone()
388 }
389 #[must_use]
390 #[cfg(feature = "job-queue")]
394 pub fn job_queue(&self) -> Option<&Arc<JobQueue>> {
395 self.job_queue.as_ref()
396 }
397
398 pub async fn initialize(&self) -> Result<(), ApplicationError> {
402 if self.initialized.load(Ordering::Acquire) {
403 debug!("This Application is already initialized.");
404 return Ok(());
405 }
406
407 self.bot.initialize().await?;
408 self.update_processor.initialize().await;
409
410 #[cfg(feature = "persistence")]
412 if let Some(ref persistence) = self.persistence {
413 let sd = persistence.store_data();
414 if sd.user_data {
415 if let Ok(data) = persistence.get_user_data().await {
416 *self.user_data.write().await = data;
417 }
418 }
419 if sd.chat_data {
420 if let Ok(data) = persistence.get_chat_data().await {
421 *self.chat_data.write().await = data;
422 }
423 }
424 if sd.bot_data {
425 if let Ok(data) = persistence.get_bot_data().await {
426 *self.bot_data.write().await = data;
427 }
428 }
429 }
430
431 #[cfg(feature = "job-queue")]
433 if let Some(ref jq) = self.job_queue {
434 jq.start().await;
435 }
436
437 self.initialized.store(true, Ordering::Release);
438 Ok(())
439 }
440
441 pub async fn shutdown(&self) -> Result<(), ApplicationError> {
444 if self.running.load(Ordering::Acquire) {
445 return Err(ApplicationError::StillRunning);
446 }
447 if !self.initialized.load(Ordering::Acquire) {
448 debug!("This Application is already shut down.");
449 return Ok(());
450 }
451
452 #[cfg(feature = "persistence")]
454 if let Some(ref persistence) = self.persistence {
455 if let Err(e) = persistence.flush().await {
456 error!("Failed to flush persistence: {e}");
457 }
458 }
459
460 self.bot.shutdown().await?;
461 self.update_processor.shutdown().await;
462 self.initialized.store(false, Ordering::Release);
463 Ok(())
464 }
465
466 pub async fn start(self: &Arc<Self>) -> Result<(), ApplicationError> {
469 if self.running.load(Ordering::Acquire) {
470 return Err(ApplicationError::AlreadyRunning);
471 }
472 self.check_initialized()?;
473 self.running.store(true, Ordering::Release);
474
475 #[cfg(feature = "job-queue")]
479 if let Some(ref jq) = self.job_queue {
480 let app_weak: std::sync::Weak<Application> = Arc::downgrade(self);
481
482 let weak_complete = app_weak.clone();
484 jq.set_on_job_complete(Arc::new(move || {
485 let weak = weak_complete.clone();
486 Box::pin(async move {
487 if let Some(app) = weak.upgrade() {
488 app.update_persistence().await;
489 }
490 })
491 }))
492 .await;
493
494 let weak_error = app_weak;
496 jq.set_on_job_error(Arc::new(
497 move |err: Box<dyn std::error::Error + Send + Sync>| {
498 let weak = weak_error.clone();
499 Box::pin(async move {
500 if let Some(app) = weak.upgrade() {
501 app.process_error(None, err).await;
502 }
503 })
504 },
505 ))
506 .await;
507 }
508
509 let rx = { self.update_rx.write().await.take() };
510 if let Some(mut rx) = rx {
511 let app = Arc::clone(self);
512 tokio::spawn(async move {
513 loop {
514 tokio::select! {
515 Some(update) = rx.recv() => {
516 let update = Arc::new(update);
517 debug!("Processing update");
518 let app2 = Arc::clone(&app);
519 let up = app.update_processor.clone();
520 let update_clone = Arc::clone(&update);
521 let fut: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(async move {
522 if let Err(e) = app2.process_update(update_clone).await { error!("Error processing update: {e}"); }
523 });
524 if app.update_processor.max_concurrent_updates() > 1 {
525 tokio::spawn(async move { up.process_update(update, fut).await; });
526 } else {
527 up.process_update(update, fut).await;
528 }
529 }
530 _ = app.stop_notify.notified() => { debug!("Update fetcher received stop signal"); break; }
531 }
532 }
533 info!("Update fetcher stopped");
534 });
535 }
536 info!("Application started");
537 Ok(())
538 }
539
540 pub async fn stop(&self) -> Result<(), ApplicationError> {
542 if !self.running.load(Ordering::Acquire) {
543 return Err(ApplicationError::NotRunning);
544 }
545 info!("Application is stopping. This might take a moment.");
546 self.stop_notify.notify_waiters();
547
548 #[cfg(feature = "job-queue")]
549 if let Some(ref jq) = self.job_queue {
550 jq.stop().await;
551 }
552
553 {
555 let mut tasks = self.pending_tasks.write().await;
556 let handles: Vec<_> = tasks.drain(..).collect();
557 drop(tasks);
558 if !handles.is_empty() {
559 debug!("Waiting for {} pending tasks", handles.len());
560 let _ = tokio::time::timeout(
561 std::time::Duration::from_secs(5),
562 futures_join_all(handles),
563 )
564 .await;
565 }
566 }
567
568 self.running.store(false, Ordering::Release);
569 info!("Application.stop() complete");
570 Ok(())
571 }
572 pub fn stop_running(&self) {
574 self.stop_notify.notify_waiters();
575 }
576
577 pub async fn create_task(&self, future: impl Future<Output = ()> + Send + 'static) {
585 let handle = tokio::spawn(future);
586 self.pending_tasks.write().await.push(handle);
587 }
588
589 pub async fn update_persistence(&self) {
592 #[cfg(feature = "persistence")]
593 {
594 let persistence = match self.persistence.as_ref() {
595 Some(p) => p,
596 None => return,
597 };
598 let sd = persistence.store_data();
599 if sd.user_data {
600 for (uid, data) in self.user_data.read().await.iter() {
601 let _ = persistence.update_user_data(*uid, data.clone()).await;
602 }
603 }
604 if sd.chat_data {
605 for (cid, data) in self.chat_data.read().await.iter() {
606 let _ = persistence.update_chat_data(*cid, data.clone()).await;
607 }
608 }
609 if sd.bot_data {
610 let _ = persistence
611 .update_bot_data(self.bot_data.read().await.clone())
612 .await;
613 }
614 }
615 }
616
617 pub async fn run_polling(self: Arc<Self>) -> Result<(), ApplicationError> {
623 self.run_polling_configured(
624 std::time::Duration::ZERO,
625 std::time::Duration::from_secs(10),
626 None,
627 false,
628 )
629 .await
630 }
631
632 #[must_use]
644 pub fn polling(self: &Arc<Self>) -> PollingBuilder {
645 PollingBuilder::new(Arc::clone(self))
646 }
647
648 pub(crate) async fn run_polling_configured(
652 self: Arc<Self>,
653 poll_interval: std::time::Duration,
654 timeout: std::time::Duration,
655 allowed_updates: Option<Vec<String>>,
656 drop_pending_updates: bool,
657 ) -> Result<(), ApplicationError> {
658 self.initialize().await?;
659 if let Some(ref hook) = self.post_init {
660 hook(Arc::clone(&self)).await;
661 }
662 self.start().await?;
663
664 #[cfg(feature = "persistence")]
666 let persistence_handle = if let Some(persistence) = self.persistence.as_ref() {
667 let secs = persistence.update_interval();
668 let app = Arc::clone(&self);
669 let stop = Arc::clone(&self.stop_notify);
670 Some(tokio::spawn(async move {
671 let mut iv = tokio::time::interval(std::time::Duration::from_secs_f64(secs));
672 iv.tick().await;
673 loop {
674 tokio::select! { _ = iv.tick() => { app.update_persistence().await; } _ = stop.notified() => { break; } }
675 }
676 }))
677 } else {
678 None
679 };
680
681 let bot = Arc::clone(&self.bot);
682 let tx = self.update_tx.clone();
683 let stop = Arc::clone(&self.stop_notify);
684 let allowed = allowed_updates;
685
686 let poll_handle = tokio::spawn(async move {
687 let mut offset: Option<i64> = None;
688 if drop_pending_updates {
689 if let Ok(updates) = bot
690 .inner()
691 .get_updates_raw(Some(-1), Some(1), Some(0), None)
692 .await
693 {
694 if let Some(last) = updates.last() {
695 offset = Some(last.update_id + 1);
696 }
697 }
698 }
699 let timeout_secs = timeout.as_secs().max(1) as i32;
700 loop {
701 tokio::select! {
702 result = bot.inner().get_updates_raw(offset, Some(100), Some(timeout_secs), allowed.clone()) => {
703 match result {
704 Ok(updates) => {
705 for update in updates {
706 offset = Some(update.update_id + 1);
707 let _ = tx.send(update).await;
708 }
709 }
710 Err(e) => { error!("Error fetching updates: {e}"); tokio::time::sleep(std::time::Duration::from_secs(1)).await; }
711 }
712 }
713 _ = stop.notified() => { return; }
714 }
715 if !poll_interval.is_zero() {
716 tokio::time::sleep(poll_interval).await;
717 }
718 }
719 });
720
721 info!("Application is running. Press Ctrl+C to stop.");
722 tokio::select! {
723 _ = tokio::signal::ctrl_c() => { info!("Received Ctrl+C, shutting down..."); }
724 _ = self.stop_notify.notified() => { info!("Received stop signal"); }
725 }
726
727 self.stop_notify.notify_waiters();
728 let _ = poll_handle.await;
729 #[cfg(feature = "persistence")]
730 if let Some(ph) = persistence_handle {
731 let _ = ph.await;
732 }
733 if self.running.load(Ordering::Acquire) {
734 self.stop().await?;
735 }
736 if let Some(ref hook) = self.post_stop {
737 hook(Arc::clone(&self)).await;
738 }
739 self.shutdown().await?;
740 if let Some(ref hook) = self.post_shutdown {
741 hook(Arc::clone(&self)).await;
742 }
743 Ok(())
744 }
745
746 #[cfg(feature = "webhooks")]
748 pub async fn run_webhook(
749 self: Arc<Self>,
750 config: crate::updater::WebhookConfig,
751 ) -> Result<(), ApplicationError> {
752 use crate::utils::webhook_handler::WebhookServer;
753
754 self.initialize().await?;
755 if let Some(ref hook) = self.post_init {
756 hook(Arc::clone(&self)).await;
757 }
758 self.start().await?;
759
760 #[cfg(feature = "persistence")]
761 let persistence_handle = if self.persistence.is_some() {
762 let secs = self.persistence.as_ref().unwrap().update_interval();
763 let app = Arc::clone(&self);
764 let stop = Arc::clone(&self.stop_notify);
765 Some(tokio::spawn(async move {
766 let mut iv = tokio::time::interval(std::time::Duration::from_secs_f64(secs));
767 iv.tick().await;
768 loop {
769 tokio::select! { _ = iv.tick() => { app.update_persistence().await; } _ = stop.notified() => { break; } }
770 }
771 }))
772 } else {
773 None
774 };
775
776 #[cfg(feature = "webhooks-tls")]
778 let tls_config = if config.has_tls() {
779 let cert_path = config
780 .cert_path
781 .as_deref()
782 .expect("cert_path checked by has_tls");
783 let key_path = config
784 .key_path
785 .as_deref()
786 .expect("key_path checked by has_tls");
787 match crate::utils::webhook_handler::TlsConfig::from_pem_files(cert_path, key_path)
788 .await
789 {
790 Ok(tls) => Some(tls),
791 Err(e) => {
792 return Err(ApplicationError::Webhook(format!(
793 "TLS configuration failed: {e}"
794 )));
795 }
796 }
797 } else {
798 None
799 };
800
801 #[cfg(not(feature = "webhooks-tls"))]
803 if config.has_tls() {
804 warn!(
805 "TLS cert_path/key_path are set but the webhooks-tls feature is not enabled. \
806 The server will start without TLS. Enable the webhooks-tls feature to use HTTPS."
807 );
808 }
809
810 if let Some(ref url) = config.webhook_url {
812 let mut builder = self.bot.set_webhook(url);
813 if let Some(ref token) = config.secret_token {
814 builder = builder.secret_token(token.clone());
815 }
816 if config.drop_pending_updates {
817 builder = builder.drop_pending_updates(true);
818 }
819 if let Some(ref allowed) = config.allowed_updates {
820 builder = builder.allowed_updates(allowed.clone());
821 }
822 #[cfg(feature = "webhooks-tls")]
825 if let Some(ref cert) = config.cert_path {
826 use rust_tg_bot_raw::types::files::input_file::InputFile;
827 builder = builder.certificate(InputFile::path(cert));
828 }
829 if let Err(e) = builder.await {
830 error!("Failed to set webhook: {e}");
831 return Err(ApplicationError::NotInitialized);
832 }
833 info!("Webhook set to {url}");
834 }
835
836 let server = Arc::new(WebhookServer::new(
837 &config.listen,
838 config.port,
839 &config.url_path,
840 self.update_tx.clone(),
841 config.secret_token.clone(),
842 #[cfg(feature = "webhooks-tls")]
843 tls_config,
844 ));
845 let ready = Arc::new(Notify::new());
846 let rc = ready.clone();
847 let srv = server.clone();
848 let wh = tokio::spawn(async move {
849 if let Err(e) = srv.serve_forever(Some(rc)).await {
850 error!("Webhook server error: {e}");
851 }
852 });
853 ready.notified().await;
854 info!(
855 "Webhook server started on {}:{}",
856 config.listen, config.port
857 );
858
859 info!("Application is running via webhook. Press Ctrl+C to stop.");
860 tokio::select! {
861 _ = tokio::signal::ctrl_c() => { info!("Received Ctrl+C"); }
862 _ = self.stop_notify.notified() => { info!("Received stop signal"); }
863 }
864
865 self.stop_notify.notify_waiters();
866 server.shutdown();
867 let _ = wh.await;
868 #[cfg(feature = "persistence")]
869 if let Some(ph) = persistence_handle {
870 let _ = ph.await;
871 }
872 if self.running.load(Ordering::Acquire) {
873 self.stop().await?;
874 }
875 if let Some(ref hook) = self.post_stop {
876 hook(Arc::clone(&self)).await;
877 }
878 self.shutdown().await?;
879 if let Some(ref hook) = self.post_shutdown {
880 hook(Arc::clone(&self)).await;
881 }
882 Ok(())
883 }
884
885 pub async fn add_raw_handler(&self, handler: Handler, group: i32) {
891 self.handlers
892 .write()
893 .await
894 .entry(group)
895 .or_default()
896 .push(handler);
897 }
898 pub async fn add_raw_handlers(&self, new_handlers: Vec<Handler>, group: i32) {
900 self.handlers
901 .write()
902 .await
903 .entry(group)
904 .or_default()
905 .extend(new_handlers);
906 }
907 pub async fn remove_handler(&self, group: i32, index: usize) -> Option<Handler> {
909 let mut handlers = self.handlers.write().await;
910 if let Some(gh) = handlers.get_mut(&group) {
911 if index < gh.len() {
912 let removed = gh.remove(index);
913 if gh.is_empty() {
914 handlers.remove(&group);
915 }
916 return Some(removed);
917 }
918 }
919 None
920 }
921 pub async fn add_error_handler(&self, callback: ErrorHandlerCallback, block: bool) {
923 self.error_handlers.write().await.push((callback, block));
924 }
925
926 pub async fn add_handler(
947 &self,
948 handler: impl crate::handlers::base::Handler + 'static,
949 group: i32,
950 ) {
951 use crate::handlers::base::HandlerResult as TraitHandlerResult;
952
953 let handler = Arc::new(handler);
954
955 let check_handler = Arc::clone(&handler);
956 let callback_handler = Arc::clone(&handler);
957 let bot = Arc::clone(&self.bot);
958 let user_data = Arc::clone(&self.user_data);
959 let chat_data = Arc::clone(&self.chat_data);
960 let bot_data_ref = Arc::clone(&self.bot_data);
961 #[cfg(feature = "job-queue")]
962 let job_queue = self.job_queue.clone();
963
964 let legacy = Handler {
965 check_update: Arc::new(move |update: &Update| {
966 check_handler.check_update(update).is_some()
967 }),
968 callback: Arc::new(move |update: Arc<Update>, _ctx: CallbackContext| {
969 let h = Arc::clone(&callback_handler);
970 let bot = Arc::clone(&bot);
971 let ud = Arc::clone(&user_data);
972 let cd = Arc::clone(&chat_data);
973 let bd = Arc::clone(&bot_data_ref);
974 #[cfg(feature = "job-queue")]
975 let jq = job_queue.clone();
976 Box::pin(async move {
977 let match_result = h
978 .check_update(&update)
979 .unwrap_or(crate::handlers::base::MatchResult::Empty);
980
981 let mut ctx = CallbackContext::from_update(&update, bot, ud, cd, bd);
983 #[cfg(feature = "job-queue")]
984 if let Some(jq) = jq {
985 ctx = ctx.with_job_queue(jq);
986 }
987
988 h.collect_additional_context(&mut ctx, &match_result);
990
991 match h
993 .handle_update_with_context(update, match_result, ctx)
994 .await
995 {
996 TraitHandlerResult::Continue => Ok(()),
997 TraitHandlerResult::Stop => Err(HandlerError::HandlerStop { state: None }),
998 TraitHandlerResult::Error(e) => Err(HandlerError::Other(e)),
999 }
1000 }) as Pin<Box<dyn Future<Output = Result<(), HandlerError>> + Send>>
1001 }),
1002 block: handler.block(),
1003 };
1004
1005 self.add_raw_handler(legacy, group).await;
1006 }
1007
1008 pub async fn process_update(&self, update: Arc<Update>) -> Result<(), ApplicationError> {
1011 self.check_initialized()?;
1012 let mut context: Option<CallbackContext> = None;
1013 let groups: Vec<(i32, Vec<usize>)> = {
1014 let h = self.handlers.read().await;
1015 h.iter()
1016 .map(|(g, hs)| (*g, (0..hs.len()).collect()))
1017 .collect()
1018 };
1019 for (gid, indices) in &groups {
1020 let guard = self.handlers.read().await;
1021 let group = match guard.get(gid) {
1022 Some(g) => g,
1023 None => continue,
1024 };
1025 for &idx in indices {
1026 let handler = match group.get(idx) {
1027 Some(h) => h,
1028 None => continue,
1029 };
1030 if !(handler.check_update)(&update) {
1031 continue;
1032 }
1033 if context.is_none() {
1034 #[allow(unused_mut)]
1035 let mut ctx = CallbackContext::from_update(
1036 &update,
1037 Arc::clone(&self.bot),
1038 Arc::clone(&self.user_data),
1039 Arc::clone(&self.chat_data),
1040 Arc::clone(&self.bot_data),
1041 );
1042 #[cfg(feature = "job-queue")]
1043 if let Some(ref jq) = self.job_queue {
1044 ctx = ctx.with_job_queue(Arc::clone(jq));
1045 }
1046 context = Some(ctx);
1047 }
1048 let ctx = context.clone().unwrap();
1049 let cb = Arc::clone(&handler.callback);
1050 let uc = Arc::clone(&update);
1051 if handler.block {
1052 match cb(uc, ctx).await {
1053 Ok(()) => {}
1054 Err(HandlerError::HandlerStop { .. }) => {
1055 return Ok(());
1056 }
1057 Err(HandlerError::Other(e)) => {
1058 if self.process_error(Some(Arc::clone(&update)), e).await {
1059 return Ok(());
1060 }
1061 }
1062 }
1063 } else {
1064 let tasks = Arc::clone(&self.pending_tasks);
1065 let handle = tokio::spawn(async move {
1066 if let Err(e) = cb(uc, ctx).await {
1067 warn!("Non-blocking handler error: {e}");
1068 }
1069 });
1070 tasks.write().await.push(handle);
1071 }
1072 break;
1073 }
1074 drop(guard);
1075 }
1076 Ok(())
1077 }
1078
1079 pub async fn process_error(
1081 &self,
1082 update: Option<Arc<Update>>,
1083 error: Box<dyn std::error::Error + Send + Sync>,
1084 ) -> bool {
1085 let handlers = self.error_handlers.read().await;
1086 if handlers.is_empty() {
1087 error!("No error handlers registered: {error}");
1088 return false;
1089 }
1090 let error_arc: Arc<dyn std::error::Error + Send + Sync> = Arc::from(error);
1091 for (callback, block) in handlers.iter() {
1092 let ctx = CallbackContext::from_error(
1093 update.as_deref(),
1094 Arc::clone(&error_arc),
1095 Arc::clone(&self.bot),
1096 Arc::clone(&self.user_data),
1097 Arc::clone(&self.chat_data),
1098 Arc::clone(&self.bot_data),
1099 );
1100 #[cfg(feature = "job-queue")]
1101 if let Some(ref jq) = self.job_queue {
1102 ctx = ctx.with_job_queue(Arc::clone(jq));
1103 }
1104 if *block {
1105 if callback(update.clone(), ctx).await {
1106 return true;
1107 }
1108 } else {
1109 let cb = Arc::clone(callback);
1110 let upd = update.clone();
1111 tokio::spawn(async move {
1112 cb(upd, ctx).await;
1113 });
1114 }
1115 }
1116 false
1117 }
1118
1119 pub async fn drop_chat_data(&self, chat_id: i64) {
1122 self.chat_data.write().await.remove(&chat_id);
1123 }
1124
1125 pub async fn drop_user_data(&self, user_id: i64) {
1127 self.user_data.write().await.remove(&user_id);
1128 }
1129
1130 pub async fn migrate_chat_data(&self, old: i64, new: i64) {
1132 let mut s = self.chat_data.write().await;
1133 if let Some(d) = s.remove(&old) {
1134 s.insert(new, d);
1135 }
1136 }
1137
1138 fn check_initialized(&self) -> Result<(), ApplicationError> {
1139 if !self.initialized.load(Ordering::Acquire) {
1140 return Err(ApplicationError::NotInitialized);
1141 }
1142 Ok(())
1143 }
1144}
1145
/// Builder for a customized polling run; obtained from `Application::polling`
/// and consumed by [`PollingBuilder::start`].
#[derive(Debug)]
pub struct PollingBuilder {
    /// Application to run.
    app: Arc<Application>,
    /// Pause between two polling requests.
    poll_interval: std::time::Duration,
    /// Long-polling timeout.
    timeout: std::time::Duration,
    /// Optional update-type filter.
    allowed_updates: Option<Vec<String>>,
    /// Whether updates already pending on the server are skipped at start.
    drop_pending_updates: bool,
}
1180
1181impl PollingBuilder {
1182 fn new(app: Arc<Application>) -> Self {
1183 Self {
1184 app,
1185 poll_interval: std::time::Duration::ZERO,
1186 timeout: std::time::Duration::from_secs(10),
1187 allowed_updates: None,
1188 drop_pending_updates: false,
1189 }
1190 }
1191
1192 #[must_use]
1194 pub fn poll_interval(mut self, interval: std::time::Duration) -> Self {
1195 self.poll_interval = interval;
1196 self
1197 }
1198
1199 #[must_use]
1201 pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
1202 self.timeout = timeout;
1203 self
1204 }
1205
1206 #[must_use]
1208 pub fn allowed_updates(mut self, updates: Vec<String>) -> Self {
1209 self.allowed_updates = Some(updates);
1210 self
1211 }
1212
1213 #[must_use]
1215 pub fn drop_pending(mut self, drop: bool) -> Self {
1216 self.drop_pending_updates = drop;
1217 self
1218 }
1219
1220 pub async fn start(self) -> Result<(), ApplicationError> {
1222 self.app
1223 .run_polling_configured(
1224 self.poll_interval,
1225 self.timeout,
1226 self.allowed_updates,
1227 self.drop_pending_updates,
1228 )
1229 .await
1230 }
1231}
1232
1233async fn futures_join_all(handles: Vec<tokio::task::JoinHandle<()>>) {
1234 for h in handles {
1235 let _ = h.await;
1236 }
1237}
1238
1239#[cfg(test)]
1240mod tests {
1241 use super::*;
1242 use crate::ext_bot::test_support::mock_request;
1243 use rust_tg_bot_raw::bot::Bot;
1244
1245 fn make_app() -> Arc<Application> {
1246 let bot = Bot::new("test_token", mock_request());
1247 let ext_bot = Arc::new(ExtBot::from_bot(bot));
1248 let processor = Arc::new(crate::update_processor::simple_processor(1).unwrap());
1249 Application::new(ApplicationConfig::new(
1250 ext_bot,
1251 ContextTypes::default(),
1252 processor,
1253 ))
1254 }
1255
1256 fn make_update(json_val: serde_json::Value) -> Update {
1257 serde_json::from_value(json_val).unwrap()
1258 }
1259
1260 #[tokio::test]
1261 async fn initialize_and_shutdown() {
1262 let app = make_app();
1263 assert!(!app.is_initialized());
1264 app.initialize().await.unwrap();
1265 assert!(app.is_initialized());
1266 app.initialize().await.unwrap();
1267 app.shutdown().await.unwrap();
1268 assert!(!app.is_initialized());
1269 }
1270
1271 #[tokio::test]
1272 async fn shutdown_while_running_errors() {
1273 let app = make_app();
1274 app.initialize().await.unwrap();
1275 app.start().await.unwrap();
1276 assert!(app.shutdown().await.is_err());
1277 app.stop().await.unwrap();
1278 app.shutdown().await.unwrap();
1279 }
1280
1281 #[tokio::test]
1282 async fn add_and_process_handler() {
1283 let app = make_app();
1284 app.initialize().await.unwrap();
1285 let called = Arc::new(std::sync::atomic::AtomicBool::new(false));
1286 let c2 = called.clone();
1287 app.add_raw_handler(
1288 Handler {
1289 check_update: Arc::new(|u| u.message().is_some()),
1290 callback: Arc::new(move |_, _| {
1291 let c = c2.clone();
1292 Box::pin(async move {
1293 c.store(true, std::sync::atomic::Ordering::SeqCst);
1294 Ok(())
1295 })
1296 }),
1297 block: true,
1298 },
1299 DEFAULT_GROUP,
1300 )
1301 .await;
1302 app.process_update(Arc::new(make_update(serde_json::json!({"update_id":1,"message":{"message_id":1,"date":0,"chat":{"id":1,"type":"private"},"text":"hello"}})))).await.unwrap();
1303 assert!(called.load(std::sync::atomic::Ordering::SeqCst));
1304 }
1305
1306 #[tokio::test]
1307 async fn handler_groups_priority() {
1308 let app = make_app();
1309 app.initialize().await.unwrap();
1310 let order = Arc::new(RwLock::new(Vec::new()));
1311 let o1 = order.clone();
1312 app.add_raw_handler(
1313 Handler {
1314 check_update: Arc::new(|_| true),
1315 callback: Arc::new(move |_, _| {
1316 let o = o1.clone();
1317 Box::pin(async move {
1318 o.write().await.push(1);
1319 Ok(())
1320 })
1321 }),
1322 block: true,
1323 },
1324 1,
1325 )
1326 .await;
1327 let o0 = order.clone();
1328 app.add_raw_handler(
1329 Handler {
1330 check_update: Arc::new(|_| true),
1331 callback: Arc::new(move |_, _| {
1332 let o = o0.clone();
1333 Box::pin(async move {
1334 o.write().await.push(0);
1335 Ok(())
1336 })
1337 }),
1338 block: true,
1339 },
1340 0,
1341 )
1342 .await;
1343 app.process_update(Arc::new(make_update(serde_json::json!({"update_id":1,"message":{"message_id":1,"date":0,"chat":{"id":1,"type":"private"}}})))).await.unwrap();
1344 assert_eq!(*order.read().await, vec![0, 1]);
1345 }
1346
1347 #[tokio::test]
1348 async fn handler_stop_prevents_further_groups() {
1349 let app = make_app();
1350 app.initialize().await.unwrap();
1351 let reached = Arc::new(std::sync::atomic::AtomicBool::new(false));
1352 app.add_raw_handler(
1353 Handler {
1354 check_update: Arc::new(|_| true),
1355 callback: Arc::new(|_, _| {
1356 Box::pin(async { Err(HandlerError::HandlerStop { state: None }) })
1357 }),
1358 block: true,
1359 },
1360 0,
1361 )
1362 .await;
1363 let r = reached.clone();
1364 app.add_raw_handler(
1365 Handler {
1366 check_update: Arc::new(|_| true),
1367 callback: Arc::new(move |_, _| {
1368 let r = r.clone();
1369 Box::pin(async move {
1370 r.store(true, std::sync::atomic::Ordering::SeqCst);
1371 Ok(())
1372 })
1373 }),
1374 block: true,
1375 },
1376 1,
1377 )
1378 .await;
1379 app.process_update(Arc::new(make_update(serde_json::json!({"update_id":1}))))
1380 .await
1381 .unwrap();
1382 assert!(!reached.load(std::sync::atomic::Ordering::SeqCst));
1383 }
1384
1385 #[tokio::test]
1386 async fn only_first_matching_handler_per_group() {
1387 let app = make_app();
1388 app.initialize().await.unwrap();
1389 let first = Arc::new(std::sync::atomic::AtomicBool::new(false));
1390 let second = Arc::new(std::sync::atomic::AtomicBool::new(false));
1391 let f = first.clone();
1392 app.add_raw_handler(
1393 Handler {
1394 check_update: Arc::new(|_| true),
1395 callback: Arc::new(move |_, _| {
1396 let f = f.clone();
1397 Box::pin(async move {
1398 f.store(true, std::sync::atomic::Ordering::SeqCst);
1399 Ok(())
1400 })
1401 }),
1402 block: true,
1403 },
1404 0,
1405 )
1406 .await;
1407 let s = second.clone();
1408 app.add_raw_handler(
1409 Handler {
1410 check_update: Arc::new(|_| true),
1411 callback: Arc::new(move |_, _| {
1412 let s = s.clone();
1413 Box::pin(async move {
1414 s.store(true, std::sync::atomic::Ordering::SeqCst);
1415 Ok(())
1416 })
1417 }),
1418 block: true,
1419 },
1420 0,
1421 )
1422 .await;
1423 app.process_update(Arc::new(make_update(serde_json::json!({"update_id":1}))))
1424 .await
1425 .unwrap();
1426 assert!(first.load(std::sync::atomic::Ordering::SeqCst));
1427 assert!(!second.load(std::sync::atomic::Ordering::SeqCst));
1428 }
1429
1430 #[tokio::test]
1431 async fn error_handler_called_on_failure() {
1432 let app = make_app();
1433 app.initialize().await.unwrap();
1434 app.add_raw_handler(
1435 Handler {
1436 check_update: Arc::new(|_| true),
1437 callback: Arc::new(|_, _| {
1438 Box::pin(async {
1439 Err(HandlerError::Other(Box::new(std::io::Error::new(
1440 std::io::ErrorKind::Other,
1441 "test",
1442 ))))
1443 })
1444 }),
1445 block: true,
1446 },
1447 0,
1448 )
1449 .await;
1450 let seen = Arc::new(std::sync::atomic::AtomicBool::new(false));
1451 let s = seen.clone();
1452 let eh: ErrorHandlerCallback = Arc::new(move |_, ctx| {
1453 let s = s.clone();
1454 Box::pin(async move {
1455 s.store(true, std::sync::atomic::Ordering::SeqCst);
1456 assert!(ctx.error.is_some());
1457 false
1458 })
1459 });
1460 app.add_error_handler(eh, true).await;
1461 app.process_update(Arc::new(make_update(serde_json::json!({"update_id":1}))))
1462 .await
1463 .unwrap();
1464 assert!(seen.load(std::sync::atomic::Ordering::SeqCst));
1465 }
1466
1467 #[tokio::test]
1468 async fn error_handler_can_signal_stop() {
1469 let app = make_app();
1470 app.initialize().await.unwrap();
1471 app.add_raw_handler(
1472 Handler {
1473 check_update: Arc::new(|_| true),
1474 callback: Arc::new(|_, _| {
1475 Box::pin(async {
1476 Err(HandlerError::Other(Box::new(std::io::Error::new(
1477 std::io::ErrorKind::Other,
1478 "e",
1479 ))))
1480 })
1481 }),
1482 block: true,
1483 },
1484 0,
1485 )
1486 .await;
1487 let eh: ErrorHandlerCallback = Arc::new(|_, _| Box::pin(async { true }));
1488 let reached = Arc::new(std::sync::atomic::AtomicBool::new(false));
1489 let r = reached.clone();
1490 app.add_raw_handler(
1491 Handler {
1492 check_update: Arc::new(|_| true),
1493 callback: Arc::new(move |_, _| {
1494 let r = r.clone();
1495 Box::pin(async move {
1496 r.store(true, std::sync::atomic::Ordering::SeqCst);
1497 Ok(())
1498 })
1499 }),
1500 block: true,
1501 },
1502 1,
1503 )
1504 .await;
1505 app.add_error_handler(eh, true).await;
1506 app.process_update(Arc::new(make_update(serde_json::json!({"update_id":1}))))
1507 .await
1508 .unwrap();
1509 assert!(!reached.load(std::sync::atomic::Ordering::SeqCst));
1510 }
1511
1512 #[tokio::test]
1513 async fn process_update_before_initialize_fails() {
1514 let app = make_app();
1515 assert!(app
1516 .process_update(Arc::new(make_update(serde_json::json!({"update_id": 0}))))
1517 .await
1518 .is_err());
1519 }
1520
1521 #[tokio::test]
1522 async fn drop_chat_and_user_data() {
1523 let app = make_app();
1524 {
1525 app.chat_data.write().await.insert(42, HashMap::new());
1526 }
1527 {
1528 app.user_data.write().await.insert(7, HashMap::new());
1529 }
1530 app.drop_chat_data(42).await;
1531 app.drop_user_data(7).await;
1532 assert!(app.chat_data.read().await.get(&42).is_none());
1533 assert!(app.user_data.read().await.get(&7).is_none());
1534 }
1535
1536 #[tokio::test]
1537 async fn migrate_chat_data() {
1538 let app = make_app();
1539 {
1540 let mut s = app.chat_data.write().await;
1541 let mut d = HashMap::new();
1542 d.insert("key".into(), Value::String("val".into()));
1543 s.insert(100, d);
1544 }
1545 app.migrate_chat_data(100, 200).await;
1546 let s = app.chat_data.read().await;
1547 assert!(s.get(&100).is_none());
1548 assert_eq!(
1549 s.get(&200).unwrap().get("key"),
1550 Some(&Value::String("val".into()))
1551 );
1552 }
1553
1554 #[tokio::test]
1555 async fn update_sender_works() {
1556 let app = make_app();
1557 assert!(app
1558 .update_sender()
1559 .send(make_update(serde_json::json!({"update_id":1})))
1560 .await
1561 .is_ok());
1562 }
1563
/// An application built by `make_app` must report no job queue.
#[cfg(feature = "job-queue")]
#[tokio::test]
async fn job_queue_accessor() {
    let app = make_app();
    let queue = app.job_queue();
    assert!(queue.is_none());
}
1570
1571 #[tokio::test]
1572 async fn create_task_tracks_handle() {
1573 let app = make_app();
1574 let flag = Arc::new(std::sync::atomic::AtomicBool::new(false));
1575 let f = flag.clone();
1576 app.create_task(async move {
1577 f.store(true, std::sync::atomic::Ordering::SeqCst);
1578 })
1579 .await;
1580 tokio::task::yield_now().await;
1582 assert!(flag.load(std::sync::atomic::Ordering::SeqCst));
1583 assert_eq!(app.pending_tasks.read().await.len(), 1);
1584 }
1585}