Skip to main content

shirabe_core/
adapter.rs

1use crate::{
2    bot::Bot,
3    context::Context,
4    error::FrameworkResult,
5    message::MessageElement,
6    types::{Channel, Guild, GuildMember, GuildRole, Login, LoginStatus, Message, User},
7};
8use async_trait::async_trait;
9use futures_util::StreamExt;
10use serde::Deserialize;
11use std::sync::Arc;
12use tokio::net::TcpStream;
13use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
14
15use url::Url;
16
17/// 适配器 trait
18/// 适配器负责与具体的聊天平台进行通信。
19#[async_trait]
20pub trait Adapter: Send + Sync + std::fmt::Debug {
21    /// 获取适配器的名称
22    fn get_name(&self) -> String;
23
24    /// 连接到聊天平台并开始接收事件
25    /// 此方法会接收一个 Arc<Bot> 的引用,以便适配器可以将事件传递给 Bot,
26    /// 或者通过 Bot 调用其他服务。
27    async fn connect(&self, bot: Arc<Bot>);
28
29    /// 断开与聊天平台的连接
30    async fn disconnect(&self, bot: Arc<Bot>);
31
32    /// 向特定消息添加某个特定表态
33    async fn create_reaction(
34        &self,
35        message_id: &str,
36        channel_id: &str,
37        emoji: &str,
38    ) -> FrameworkResult<()>;
39
40    /// 从特定消息删除某个用户添加的特定表态
41    async fn delete_reaction(
42        &self,
43        message_id: &str,
44        channel_id: &str,
45        emoji: &str,
46        user_id: &str,
47    ) -> FrameworkResult<()>;
48
49    /// 从特定消息清除某个特定表态
50    async fn clear_reaction(
51        &self,
52        message_id: &str,
53        channel_id: &str,
54        emoji: &str,
55    ) -> FrameworkResult<()>;
56
57    /// 获取添加特定消息的特定表态的用户列表
58    async fn get_reaction_list(
59        &self,
60        message_id: &str,
61        channel_id: &str,
62        emoji: &str,
63        next: Option<&str>,
64    ) -> FrameworkResult<Vec<User>>;
65
66    /// 获取频道信息
67    async fn get_channel(&self, channel_id: &str) -> FrameworkResult<Channel>;
68
69    /// 获取某个群组的频道列表
70    async fn get_channel_list(
71        &self,
72        guild_id: &str,
73        next: Option<&str>,
74    ) -> FrameworkResult<Vec<Channel>>;
75
76    /// 创建群组频道
77    async fn create_channel(&self, guild_id: &str, data: Channel) -> FrameworkResult<Channel>;
78
79    /// 修改群组频道
80    async fn update_channel(&self, channel_id: &str, data: Channel) -> FrameworkResult<()>;
81
82    /// 删除群组频道
83    async fn delete_channel(&self, channel_id: &str) -> FrameworkResult<()>;
84
85    /// 创建私聊频道
86    async fn create_direct_channel(&self, user_id: &str) -> FrameworkResult<Channel>;
87
88    /// 设置群组内用户的角色
89    async fn set_guild_member_role(
90        &self,
91        guild_id: &str,
92        user_id: &str,
93        role_id: &str,
94    ) -> FrameworkResult<()>;
95
96    /// 取消群组内用户的角色
97    async fn unset_guild_member_role(
98        &self,
99        guild_id: &str,
100        user_id: &str,
101        role_id: &str,
102    ) -> FrameworkResult<()>;
103
104    /// 获取群组内用户的角色列表
105    async fn get_guild_member_role_list(
106        &self,
107        guild_id: &str,
108        next: Option<&str>,
109    ) -> FrameworkResult<Vec<GuildRole>>;
110
111    /// 创建群组角色
112    async fn create_guild_role(
113        &self,
114        guild_id: &str,
115        role_name: &str,
116    ) -> FrameworkResult<GuildRole>;
117
118    /// 修改群组角色
119    async fn update_guild_role(
120        &self,
121        guild_id: &str,
122        role_id: &str,
123        role: GuildRole,
124    ) -> FrameworkResult<()>;
125
126    /// 删除群组角色
127    async fn delete_guild_role(&self, guild_id: &str, role_id: &str) -> FrameworkResult<()>;
128
129    /// 向特定频道发送消息
130    /// 返回消息ID列表
131    async fn send_message(
132        &self,
133        channel_id: &str,
134        elements: &[MessageElement],
135    ) -> FrameworkResult<Vec<String>>;
136
137    /// 向特定用户发送私信
138    async fn send_private_message(
139        &self,
140        user_id: &str,
141        guild_id: &str,
142        elements: &[MessageElement],
143    ) -> FrameworkResult<Vec<String>>;
144
145    /// 获取特定消息
146    async fn get_message(&self, channel_id: &str, message_id: &str) -> FrameworkResult<Message>;
147
148    /// 撤回特定消息
149    async fn delete_message(&self, channel_id: &str, message_id: &str) -> FrameworkResult<()>;
150
151    /// 修改特定消息
152    async fn update_message(
153        &self,
154        channel_id: &str,
155        message_id: &str,
156        elements: &[MessageElement],
157    ) -> FrameworkResult<()>;
158
159    /// 获取频道消息列表
160    async fn get_message_list(
161        &self,
162        channel_id: &str,
163        next: Option<&str>,
164        directory: Option<&str>,
165    ) -> FrameworkResult<Vec<Message>>;
166
167    /// 获取用户信息
168    async fn get_user(&self, user_id: &str) -> FrameworkResult<User>;
169
170    /// 获取机器人的好友列表
171    async fn get_friends(&self, next: Option<&str>) -> FrameworkResult<Vec<User>>;
172
173    /// 处理好友请求
174    async fn handle_friend_request(
175        &self,
176        message_id: &str,
177        accept: bool,
178        comment: Option<&str>,
179    ) -> FrameworkResult<()>;
180
181    /// 获取群组信息
182    async fn get_guild(&self, guild_id: &str) -> FrameworkResult<Guild>;
183
184    /// 获取机器人加入的群组列表
185    async fn get_guilds(&self, next: Option<&str>) -> FrameworkResult<Vec<Guild>>;
186
187    /// 处理来自群组的邀请
188    async fn handle_guild_invite(
189        &self,
190        message_id: &str,
191        accept: bool,
192        comment: Option<&str>,
193    ) -> FrameworkResult<()>;
194
195    /// 获取群成员信息
196    async fn get_guild_member(&self, guild_id: &str, user_id: &str)
197    -> FrameworkResult<GuildMember>;
198
199    /// 获取群成员列表
200    async fn get_guild_members(
201        &self,
202        guild_id: &str,
203        next: Option<&str>,
204    ) -> FrameworkResult<Vec<GuildMember>>;
205
206    /// 将某个用户踢出群组
207    async fn kick_guild_member(
208        &self,
209        guild_id: &str,
210        user_id: &str,
211        permanent: Option<bool>,
212    ) -> FrameworkResult<()>;
213
214    /// 禁言某个用户
215    async fn mute_guild_member(
216        &self,
217        guild_id: &str,
218        user_id: &str,
219        duration: Option<u64>,
220        reason: &str,
221    ) -> FrameworkResult<()>;
222
223    /// 处理加群请求
224    async fn handle_guild_request(
225        &self,
226        message_id: &str,
227        accept: bool,
228        comment: Option<&str>,
229    ) -> FrameworkResult<()>;
230
231    /// 获取登陆状态
232    async fn get_login(&self) -> FrameworkResult<Login>;
233}
234
235#[derive(Debug, Clone, Deserialize)]
236pub struct WSClientConfig<C> {
237    retry_lazy: u64,
238    retry_times: u64,
239    retry_interval: u64,
240    _extend: Option<C>,
241}
242
243#[async_trait]
244pub trait WSClient<C>: Adapter
245where
246    C: for<'de> Deserialize<'de> + Send,
247{
248    /// 获取适配器的上下文
249    fn ctx(&self) -> Context;
250
251    /// 获取适配器下的Bot实例
252    fn bot(&self) -> Arc<Bot>;
253
254    /// 获取适配器的WebSocket实例
255    fn socket(&self) -> Option<WebSocketStream<MaybeTlsStream<TcpStream>>>;
256
257    /// 获取适配器的配置
258    fn config(&self) -> WSClientConfig<C>;
259
260    /// 根据Bot实例生成一个WebSocket对象
261    async fn prepare(&self) -> FrameworkResult<(WebSocketStream<MaybeTlsStream<TcpStream>>, Url)>;
262
263    /// WebSocket连接成功后建立的回调函数
264    async fn accept(&self);
265
266    /// 设置status
267    fn set_status(&self, status: LoginStatus);
268
269    /// 获取适配器的状态
270    fn get_active(&self) -> bool;
271
272    async fn start(&self) {
273        let mut retry_count = 0;
274        let ws_config = self.config();
275
276        loop {
277            if !self.get_active() {
278                tracing::debug!(
279                    "Adapter {} is not active, stopping connection attempts.",
280                    self.get_name()
281                );
282                self.set_status(LoginStatus::Offline);
283                return;
284            }
285
286            tracing::debug!(
287                "Adapter {} (attempt {}): Trying to connect...",
288                self.get_name(),
289                retry_count + 1
290            );
291
292            let mut socket_stream = match self.prepare().await {
293                Ok((stream, _url)) => {
294                    self.set_status(LoginStatus::Online);
295                    tracing::info!("Adapter {} connected successfully.", self.get_name());
296                    if retry_count > 0 {
297                        retry_count = 0;
298                    }
299                    self.accept().await;
300                    stream
301                }
302                Err(e) => {
303                    tracing::warn!(
304                        "Adapter {} failed to prepare connection: {}",
305                        self.get_name(),
306                        e
307                    );
308                    let timeout = if retry_count >= ws_config.retry_times {
309                        if ws_config.retry_lazy == 0 {
310                            tracing::error!(
311                                "Adapter {} reached max retry attempts ({}) and no lazy retry configured. Stopping.",
312                                self.get_name(),
313                                ws_config.retry_times
314                            );
315                            self.set_status(LoginStatus::Offline);
316                            return;
317                        }
318                        if retry_count == ws_config.retry_times {
319                            tracing::warn!(
320                                "Adapter {} reached max retry attempts. Falling back to lazy retry ({}ms).",
321                                self.get_name(),
322                                ws_config.retry_lazy
323                            );
324                        }
325                        ws_config.retry_lazy
326                    } else {
327                        ws_config.retry_interval
328                    };
329
330                    retry_count += 1;
331                    self.set_status(LoginStatus::Reconnect);
332                    tracing::info!(
333                        "Adapter {} will retry connection in {}ms (attempt {}).",
334                        self.get_name(),
335                        timeout,
336                        retry_count
337                    );
338                    tokio::time::sleep(tokio::time::Duration::from_millis(timeout)).await;
339                    continue;
340                }
341            };
342
343            tracing::debug!("Adapter {} listening for messages.", self.get_name());
344            while let Some(message_result) = socket_stream.next().await {
345                if !self.get_active() {
346                    tracing::info!(
347                        "Adapter {} became inactive while listening. Closing connection.",
348                        self.get_name()
349                    );
350                    let _ = socket_stream.close(None).await; // Attempt to close gracefully
351                    self.set_status(LoginStatus::Offline);
352                    return;
353                }
354
355                match message_result {
356                    Ok(msg) => {
357                        if msg.is_close() {
358                            tracing::info!(
359                                "Adapter {} received WebSocket Close frame. Connection closed by peer.",
360                                self.get_name(),
361                            );
362                            break;
363                        }
364                    }
365                    Err(e) => {
366                        tracing::error!(
367                            "Adapter {} error while receiving message: {}. Attempting to reconnect.",
368                            self.get_name(),
369                            e
370                        );
371                        break;
372                    }
373                }
374            }
375
376            if !self.get_active() {
377                tracing::info!(
378                    "Adapter {} became inactive after message loop. Not reconnecting.",
379                    self.get_name()
380                );
381                self.set_status(LoginStatus::Offline);
382                return;
383            }
384
385            tracing::warn!(
386                "Adapter {} disconnected or encountered an error in message loop. Preparing to reconnect.",
387                self.get_name()
388            );
389
390            let timeout = if retry_count >= ws_config.retry_times {
391                if ws_config.retry_lazy == 0 {
392                    tracing::error!(
393                        "Adapter {} reached max retry attempts ({}) for disconnection and no lazy retry. Stopping.",
394                        self.get_name(),
395                        ws_config.retry_times
396                    );
397                    self.set_status(LoginStatus::Offline);
398                    return;
399                }
400                if retry_count == ws_config.retry_times {
401                    tracing::warn!(
402                        "Adapter {} reached max retry attempts for disconnection. Falling back to lazy retry ({}ms).",
403                        self.get_name(),
404                        ws_config.retry_lazy
405                    );
406                }
407                ws_config.retry_lazy
408            } else {
409                ws_config.retry_interval
410            };
411
412            retry_count += 1;
413            self.set_status(LoginStatus::Reconnect);
414            tracing::info!(
415                "Adapter {} will retry connection in {}ms (attempt {}).",
416                self.get_name(),
417                timeout,
418                retry_count
419            );
420            tokio::time::sleep(tokio::time::Duration::from_millis(timeout)).await;
421        }
422    }
423
424    async fn stop(&self) -> FrameworkResult<()> {
425        if let Some(mut socket) = self.socket() {
426            socket.close(None).await?;
427        }
428        self.set_status(LoginStatus::Offline);
429        tracing::info!("适配器 {} 已停止。", self.get_name());
430        Ok(())
431    }
432}