Skip to main content

wx_bot_sdk/
multi_bot.rs

1use std::sync::{Arc, Mutex};
2
3use tokio::task::JoinHandle;
4
5use crate::{
6    auth::accounts::{CDN_BASE_URL, DEFAULT_BASE_URL},
7    bot::{StartOptions, WeixinBot, WeixinBotOptions},
8    messaging::process_message::MessageHandler,
9};
10
/// Per-account settings from which [`MultiWeixinBot::new`] builds one bot.
///
/// `Debug` is implemented by hand instead of derived so that the `token`
/// credential never ends up in logs or panic messages.
#[derive(Clone)]
pub struct BotAccountOptions {
    /// Token handed to the underlying bot. Redacted in `Debug` output.
    pub token: String,
    /// Explicit account identifier; forwarded to the bot unchanged.
    pub account_id: Option<String>,
    /// API base URL; `MultiWeixinBot::new` substitutes `DEFAULT_BASE_URL` when `None`.
    pub base_url: Option<String>,
    /// CDN base URL; `MultiWeixinBot::new` substitutes `CDN_BASE_URL` when `None`.
    pub cdn_base_url: Option<String>,
}

impl std::fmt::Debug for BotAccountOptions {
    /// Formats like the derived implementation, except that the token value is
    /// replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BotAccountOptions")
            .field("token", &"<redacted>")
            .field("account_id", &self.account_id)
            .field("base_url", &self.base_url)
            .field("cdn_base_url", &self.cdn_base_url)
            .finish()
    }
}
18
19#[derive(Clone, Debug)]
20pub struct MultiWeixinBotOptions {
21    pub accounts: Vec<BotAccountOptions>,
22    pub state_dir: Option<String>,
23}
24
25pub struct MultiStartOptions {
26    pub on_message: MessageHandler,
27    pub long_poll_timeout_ms: Option<u64>,
28}
29
30#[derive(Clone)]
31pub struct MultiWeixinBot {
32    bots: Vec<WeixinBot>,
33    handles: Arc<Mutex<Vec<JoinHandle<crate::Result<()>>>>>,
34}
35
36impl MultiWeixinBot {
37    pub fn new(opts: MultiWeixinBotOptions) -> Self {
38        let bots = opts
39            .accounts
40            .into_iter()
41            .map(|account| {
42                WeixinBot::new(WeixinBotOptions {
43                    token: account.token,
44                    base_url: account
45                        .base_url
46                        .or_else(|| Some(DEFAULT_BASE_URL.to_string())),
47                    cdn_base_url: account
48                        .cdn_base_url
49                        .or_else(|| Some(CDN_BASE_URL.to_string())),
50                    state_dir: opts.state_dir.clone(),
51                    account_id: account.account_id,
52                    user_id: None,
53                })
54            })
55            .collect();
56        Self {
57            bots,
58            handles: Arc::new(Mutex::new(Vec::new())),
59        }
60    }
61
62    pub async fn start(&self, opts: MultiStartOptions) -> crate::Result<()> {
63        let mut handles = self.handles.lock().expect("multi bot handles poisoned");
64        if !handles.is_empty() {
65            return Ok(());
66        }
67
68        for bot in &self.bots {
69            let bot = bot.clone();
70            let on_message = opts.on_message.clone();
71            let long_poll_timeout_ms = opts.long_poll_timeout_ms;
72            handles.push(tokio::spawn(async move {
73                bot.start(StartOptions {
74                    on_message,
75                    long_poll_timeout_ms,
76                })
77                .await
78            }));
79        }
80        Ok(())
81    }
82
83    pub async fn stop(&self) -> crate::Result<()> {
84        for bot in &self.bots {
85            bot.stop().await?;
86        }
87        Ok(())
88    }
89
90    pub async fn join(&self) -> crate::Result<()> {
91        let handles = {
92            let mut locked = self.handles.lock().expect("multi bot handles poisoned");
93            locked.drain(..).collect::<Vec<_>>()
94        };
95
96        for handle in handles {
97            handle.await??;
98        }
99        Ok(())
100    }
101
102    pub fn bots(&self) -> &[WeixinBot] {
103        &self.bots
104    }
105
106    pub fn account_ids(&self) -> Vec<String> {
107        self.bots
108            .iter()
109            .map(|bot| bot.account_id().to_string())
110            .collect()
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Account options for `token` with no URL overrides.
    fn account(token: &str, account_id: Option<&str>) -> BotAccountOptions {
        BotAccountOptions {
            token: token.into(),
            account_id: account_id.map(str::to_owned),
            base_url: None,
            cdn_base_url: None,
        }
    }

    /// Multi-bot over `accounts` without a state directory.
    fn multi_bot(accounts: Vec<BotAccountOptions>) -> MultiWeixinBot {
        MultiWeixinBot::new(MultiWeixinBotOptions {
            accounts,
            state_dir: None,
        })
    }

    #[test]
    fn creates_one_bot_per_token() {
        let multi = multi_bot(vec![account("token-a", None), account("token-b", None)]);

        assert_eq!(multi.bots().len(), 2);

        // No explicit ids were given, yet the two bots must report distinct ones.
        let ids = multi.account_ids();
        assert_eq!(ids.len(), 2);
        assert_ne!(ids[0], ids[1]);
    }

    #[test]
    fn keeps_explicit_account_ids() {
        let multi = multi_bot(vec![account("token-a", Some("account-a"))]);

        assert_eq!(multi.account_ids(), vec!["account-a".to_string()]);
    }
}
156}