Skip to main content

weixin_agent/monitor/
poll_loop.rs

1//! Long-poll `getUpdates` loop with error handling, backoff, and session guard.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio_util::sync::CancellationToken;
7
8use crate::api::client::HttpApiClient;
9use crate::api::config_cache::ConfigCache;
10use crate::api::session_guard::SessionGuard;
11use crate::error::Result;
12use crate::messaging::inbound::{self, ContextTokenStore, MessageSender};
13use crate::types::{
14    BACKOFF_DELAY_MS, GetUpdatesRequest, MAX_CONSECUTIVE_FAILURES, RETRY_DELAY_MS,
15    SESSION_EXPIRED_ERRCODE, build_base_info,
16};
17
18/// The handler trait users implement to receive messages.
19#[async_trait::async_trait]
20pub trait MessageHandler: Send + Sync {
21    /// Called for each inbound user message.
22    async fn on_message(&self, ctx: &inbound::MessageContext) -> Result<()>;
23
24    /// Called when `get_updates_buf` changes — persist it here.
25    async fn on_sync_buf_updated(&self, _sync_buf: &str) -> Result<()> {
26        Ok(())
27    }
28
29    /// Lifecycle hook: called before the poll loop starts.
30    async fn on_start(&self) -> Result<()> {
31        Ok(())
32    }
33
34    /// Lifecycle hook: called after the poll loop ends.
35    async fn on_shutdown(&self) -> Result<()> {
36        Ok(())
37    }
38}
39
40/// Run the long-poll monitor loop. Blocks until cancellation.
41#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
42pub(crate) async fn run_monitor(
43    api: Arc<HttpApiClient>,
44    cdn_base_url: String,
45    handler: Arc<dyn MessageHandler>,
46    session_guard: Arc<SessionGuard>,
47    config_cache: Arc<ConfigCache>,
48    context_tokens: Arc<ContextTokenStore>,
49    initial_sync_buf: Option<String>,
50    initial_timeout: Duration,
51    cancel: CancellationToken,
52) -> Result<()> {
53    handler.on_start().await?;
54
55    let mut get_updates_buf = initial_sync_buf.unwrap_or_default();
56    let mut next_timeout = initial_timeout;
57    let mut consecutive_failures: u32 = 0;
58
59    let sender = Arc::new(MessageSender {
60        api: Arc::clone(&api),
61        cdn_base_url: cdn_base_url.clone(),
62        config_cache: Arc::clone(&config_cache),
63    });
64
65    loop {
66        if cancel.is_cancelled() {
67            break;
68        }
69
70        // Check session guard
71        if session_guard.is_paused() {
72            let remaining = session_guard.remaining_ms();
73            tracing::info!(remaining_ms = remaining, "session paused, sleeping");
74            tokio::select! {
75                () = cancel.cancelled() => break,
76                () = tokio::time::sleep(Duration::from_millis(remaining)) => continue,
77            }
78        }
79
80        let req = GetUpdatesRequest {
81            get_updates_buf: get_updates_buf.clone(),
82            base_info: build_base_info(),
83        };
84
85        let resp = tokio::select! {
86            () = cancel.cancelled() => break,
87            result = api.get_updates(&req, next_timeout) => {
88                match result {
89                    Ok(r) => r,
90                    Err(e) => {
91                        consecutive_failures += 1;
92                        tracing::error!(
93                            error = %e,
94                            failures = consecutive_failures,
95                            "getUpdates error"
96                        );
97                        if consecutive_failures >= MAX_CONSECUTIVE_FAILURES {
98                            consecutive_failures = 0;
99                            sleep_or_cancel(Duration::from_millis(BACKOFF_DELAY_MS), &cancel).await;
100                        } else {
101                            sleep_or_cancel(Duration::from_millis(RETRY_DELAY_MS), &cancel).await;
102                        }
103                        continue;
104                    }
105                }
106            }
107        };
108
109        // Update dynamic timeout
110        if let Some(t) = resp.longpolling_timeout_ms {
111            if t > 0 {
112                next_timeout = Duration::from_millis(t);
113            }
114        }
115
116        // Check API-level errors
117        let is_error = resp.ret.unwrap_or(0) != 0 || resp.errcode.unwrap_or(0) != 0;
118        if is_error {
119            let errcode = resp.errcode.or(resp.ret).unwrap_or(0);
120            if errcode == SESSION_EXPIRED_ERRCODE {
121                session_guard.pause();
122                consecutive_failures = 0;
123                let remaining = session_guard.remaining_ms();
124                tracing::error!(
125                    errcode,
126                    remaining_min = remaining / 60_000,
127                    "session expired, pausing"
128                );
129                sleep_or_cancel(Duration::from_millis(remaining), &cancel).await;
130                continue;
131            }
132
133            consecutive_failures += 1;
134            tracing::error!(
135                ret = resp.ret,
136                errcode = resp.errcode,
137                errmsg = resp.errmsg.as_deref().unwrap_or(""),
138                failures = consecutive_failures,
139                "getUpdates API error"
140            );
141            if consecutive_failures >= MAX_CONSECUTIVE_FAILURES {
142                consecutive_failures = 0;
143                sleep_or_cancel(Duration::from_millis(BACKOFF_DELAY_MS), &cancel).await;
144            } else {
145                sleep_or_cancel(Duration::from_millis(RETRY_DELAY_MS), &cancel).await;
146            }
147            continue;
148        }
149
150        // Success
151        consecutive_failures = 0;
152
153        // Update sync buf (prefer get_updates_buf, fall back to deprecated sync_buf)
154        let new_buf = resp
155            .get_updates_buf
156            .as_ref()
157            .or(resp.sync_buf.as_ref())
158            .filter(|b| !b.is_empty());
159        if let Some(new_buf) = new_buf {
160            get_updates_buf.clone_from(new_buf);
161            if let Err(e) = handler.on_sync_buf_updated(new_buf).await {
162                tracing::error!(error = %e, "on_sync_buf_updated failed");
163            }
164        }
165
166        // Process messages
167        let msgs = resp.msgs.unwrap_or_default();
168        for msg in &msgs {
169            if !inbound::should_process(msg) {
170                continue;
171            }
172
173            // Update context token store
174            if let (Some(from), Some(token)) = (&msg.from_user_id, &msg.context_token) {
175                context_tokens.set(from, token);
176            }
177
178            let ctx = inbound::parse_inbound_message(msg, Arc::clone(&sender));
179            if let Err(e) = handler.on_message(&ctx).await {
180                tracing::error!(
181                    error = %e,
182                    from = %ctx.from,
183                    message_id = %ctx.message_id,
184                    "on_message handler error"
185                );
186            }
187        }
188    }
189
190    handler.on_shutdown().await?;
191    tracing::info!("monitor loop ended");
192    Ok(())
193}
194
195async fn sleep_or_cancel(duration: Duration, cancel: &CancellationToken) {
196    tokio::select! {
197        () = cancel.cancelled() => {},
198        () = tokio::time::sleep(duration) => {},
199    }
200}