Skip to main content

sh_layer4/channel_gateway/
mod.rs

1//! # Channel Gateway
2//!
3//! 多渠道消息网关,支持 CLI、HTTP、WebSocket 等多种渠道接入。
4
5pub mod adapter;
6
7use async_trait::async_trait;
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use sh_layer3::generate_short_id;
11use std::collections::HashMap;
12use std::sync::Arc;
13
14use crate::types::Layer4Result;
15
16/// 渠道接口
17#[async_trait]
18pub trait Channel: Send + Sync {
19    /// 获取渠道 ID
20    fn id(&self) -> &str;
21
22    /// 获取渠道类型
23    fn channel_type(&self) -> ChannelType;
24
25    /// 发送消息
26    async fn send(&self, message: &OutboundMessage) -> Layer4Result<()>;
27
28    /// 接收消息(非阻塞)
29    async fn try_receive(&self) -> Layer4Result<Option<InboundMessage>>;
30
31    /// 检查是否连接
32    fn is_connected(&self) -> bool;
33
34    /// 关闭渠道
35    async fn close(&self) -> Layer4Result<()>;
36}
37
38/// 渠道类型
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
40pub enum ChannelType {
41    Cli,
42    Http,
43    WebSocket,
44    Mqtt,
45    Custom,
46}
47
48impl std::fmt::Display for ChannelType {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        match self {
51            Self::Cli => write!(f, "cli"),
52            Self::Http => write!(f, "http"),
53            Self::WebSocket => write!(f, "websocket"),
54            Self::Mqtt => write!(f, "mqtt"),
55            Self::Custom => write!(f, "custom"),
56        }
57    }
58}
59
60/// 入站消息
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct InboundMessage {
63    /// 消息 ID
64    pub message_id: String,
65    /// 渠道 ID
66    pub channel_id: String,
67    /// 用户 ID
68    pub user_id: String,
69    /// 会话 ID(可选)
70    pub session_id: Option<String>,
71    /// 消息内容
72    pub content: String,
73    /// 消息类型
74    pub message_type: MessageType,
75    /// 元数据
76    pub metadata: serde_json::Value,
77    /// 时间戳
78    pub timestamp: chrono::DateTime<chrono::Utc>,
79}
80
81impl InboundMessage {
82    pub fn new(
83        channel_id: impl Into<String>,
84        user_id: impl Into<String>,
85        content: impl Into<String>,
86    ) -> Self {
87        Self {
88            message_id: generate_short_id(),
89            channel_id: channel_id.into(),
90            user_id: user_id.into(),
91            session_id: None,
92            content: content.into(),
93            message_type: MessageType::Text,
94            metadata: serde_json::Value::Null,
95            timestamp: chrono::Utc::now(),
96        }
97    }
98
99    pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
100        self.session_id = Some(session_id.into());
101        self
102    }
103
104    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
105        self.metadata = metadata;
106        self
107    }
108}
109
110/// 出站消息
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct OutboundMessage {
113    /// 消息 ID
114    pub message_id: String,
115    /// 消息内容
116    pub content: String,
117    /// 消息类型
118    pub message_type: MessageType,
119    /// 目标
120    pub target: MessageTarget,
121    /// 元数据
122    pub metadata: serde_json::Value,
123    /// 时间戳
124    pub timestamp: chrono::DateTime<chrono::Utc>,
125}
126
127impl OutboundMessage {
128    pub fn new(content: impl Into<String>, target: MessageTarget) -> Self {
129        Self {
130            message_id: generate_short_id(),
131            content: content.into(),
132            message_type: MessageType::Text,
133            target,
134            metadata: serde_json::Value::Null,
135            timestamp: chrono::Utc::now(),
136        }
137    }
138
139    pub fn broadcast(content: impl Into<String>) -> Self {
140        Self::new(content, MessageTarget::All)
141    }
142
143    pub fn to_channel(channel_id: impl Into<String>, content: impl Into<String>) -> Self {
144        Self::new(content, MessageTarget::Channel(channel_id.into()))
145    }
146
147    pub fn to_user(user_id: impl Into<String>, content: impl Into<String>) -> Self {
148        Self::new(content, MessageTarget::User(user_id.into()))
149    }
150}
151
152/// 消息目标
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub enum MessageTarget {
155    /// 广播到所有渠道
156    All,
157    /// 发送到指定渠道
158    Channel(String),
159    /// 发送到指定用户
160    User(String),
161    /// 发送到指定会话
162    Session(String),
163}
164
165/// 消息类型
166#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
167pub enum MessageType {
168    Text,
169    Json,
170    Binary,
171    Command,
172    Event,
173    Error,
174}
175
176/// 渠道网关
177pub struct ChannelGateway {
178    channels: RwLock<HashMap<String, Box<dyn Channel>>>,
179    router: MessageRouter,
180    message_queue: RwLock<Vec<InboundMessage>>,
181}
182
183impl ChannelGateway {
184    /// 创建新的渠道网关
185    pub fn new() -> Self {
186        Self {
187            channels: RwLock::new(HashMap::new()),
188            router: MessageRouter::new(),
189            message_queue: RwLock::new(Vec::new()),
190        }
191    }
192
193    /// 注册渠道
194    pub async fn register_channel(&self, channel: Box<dyn Channel>) -> Layer4Result<()> {
195        let id = channel.id().to_string();
196        let channel_type = channel.channel_type();
197
198        self.channels.write().insert(id.clone(), channel);
199        self.router.register_channel(&id, channel_type);
200
201        tracing::info!("Registered channel: {} ({})", id, channel_type);
202        Ok(())
203    }
204
205    /// 注销渠道
206    pub async fn unregister_channel(&self, channel_id: &str) -> Layer4Result<bool> {
207        let channel = self.channels.write().remove(channel_id);
208        if let Some(channel) = channel {
209            channel.close().await?;
210            self.router.unregister_channel(channel_id);
211            tracing::info!("Unregistered channel: {}", channel_id);
212            Ok(true)
213        } else {
214            Ok(false)
215        }
216    }
217
218    /// 获取渠道
219    pub fn get_channel(&self, _channel_id: &str) -> Option<Arc<dyn Channel>> {
220        // 由于 Box<dyn Channel> 不能直接克隆,我们返回 Option
221        // 实际使用时需要重新设计
222        None
223    }
224
225    /// 列出所有渠道
226    pub fn list_channels(&self) -> Vec<(String, ChannelType)> {
227        self.channels
228            .read()
229            .iter()
230            .map(|(id, ch)| (id.clone(), ch.channel_type()))
231            .collect()
232    }
233
234    /// 广播消息到所有渠道
235    #[allow(clippy::await_holding_lock)]
236    pub async fn broadcast(&self, message: &OutboundMessage) -> Layer4Result<()> {
237        let channels = self.channels.read();
238        for (id, channel) in channels.iter() {
239            if let Err(e) = channel.send(message).await {
240                tracing::error!("Failed to send to channel {}: {}", id, e);
241            }
242        }
243        Ok(())
244    }
245
246    /// 发送消息到指定目标
247    #[allow(clippy::await_holding_lock)]
248    pub async fn send_to(
249        &self,
250        target: &MessageTarget,
251        message: &OutboundMessage,
252    ) -> Layer4Result<()> {
253        match target {
254            MessageTarget::All => self.broadcast(message).await,
255            MessageTarget::Channel(channel_id) => {
256                let channels = self.channels.read();
257                if let Some(channel) = channels.get(channel_id) {
258                    channel.send(message).await?;
259                }
260                Ok(())
261            }
262            MessageTarget::User(user_id) => {
263                // 路由到用户所在的渠道
264                let channel_id = self.router.find_user_channel(user_id);
265                if let Some(cid) = channel_id {
266                    let channels = self.channels.read();
267                    if let Some(channel) = channels.get(&cid) {
268                        channel.send(message).await?;
269                    }
270                }
271                Ok(())
272            }
273            MessageTarget::Session(session_id) => {
274                let channel_id = self.router.find_session_channel(session_id);
275                if let Some(cid) = channel_id {
276                    let channels = self.channels.read();
277                    if let Some(channel) = channels.get(&cid) {
278                        channel.send(message).await?;
279                    }
280                }
281                Ok(())
282            }
283        }
284    }
285
286    /// 接收消息(轮询所有渠道)
287    #[allow(clippy::await_holding_lock)]
288    pub async fn receive(&self) -> Layer4Result<Option<InboundMessage>> {
289        // 先检查队列
290        if let Some(msg) = self.message_queue.write().pop() {
291            return Ok(Some(msg));
292        }
293
294        // 轮询所有渠道
295        let channels = self.channels.read();
296        for (_, channel) in channels.iter() {
297            if let Some(msg) = channel.try_receive().await? {
298                // 更新路由信息
299                self.router
300                    .update_user_channel(&msg.user_id, &msg.channel_id);
301                if let Some(ref session_id) = msg.session_id {
302                    self.router
303                        .update_session_channel(session_id, &msg.channel_id);
304                }
305                return Ok(Some(msg));
306            }
307        }
308
309        Ok(None)
310    }
311
312    /// 接收所有待处理消息
313    #[allow(clippy::await_holding_lock)]
314    pub async fn receive_all(&self) -> Layer4Result<Vec<InboundMessage>> {
315        let mut messages = Vec::new();
316
317        // 先处理队列
318        messages.append(&mut self.message_queue.write());
319
320        // 轮询所有渠道
321        let channels = self.channels.read();
322        for (_, channel) in channels.iter() {
323            while let Some(msg) = channel.try_receive().await? {
324                messages.push(msg);
325            }
326        }
327
328        Ok(messages)
329    }
330
331    /// 渠道数量
332    pub fn channel_count(&self) -> usize {
333        self.channels.read().len()
334    }
335
336    /// 关闭所有渠道
337    #[allow(clippy::await_holding_lock)]
338    pub async fn close_all(&self) -> Layer4Result<()> {
339        let mut channels = self.channels.write();
340        for (id, channel) in channels.drain() {
341            if let Err(e) = channel.close().await {
342                tracing::error!("Failed to close channel {}: {}", id, e);
343            }
344        }
345        Ok(())
346    }
347}
348
349impl Default for ChannelGateway {
350    fn default() -> Self {
351        Self::new()
352    }
353}
354
355/// 消息路由器
356pub struct MessageRouter {
357    user_channels: RwLock<HashMap<String, String>>,
358    session_channels: RwLock<HashMap<String, String>>,
359    channel_registry: RwLock<HashMap<String, ChannelType>>,
360}
361
362impl MessageRouter {
363    pub fn new() -> Self {
364        Self {
365            user_channels: RwLock::new(HashMap::new()),
366            session_channels: RwLock::new(HashMap::new()),
367            channel_registry: RwLock::new(HashMap::new()),
368        }
369    }
370
371    pub fn register_channel(&self, channel_id: &str, channel_type: ChannelType) {
372        self.channel_registry
373            .write()
374            .insert(channel_id.to_string(), channel_type);
375    }
376
377    pub fn unregister_channel(&self, channel_id: &str) {
378        self.channel_registry.write().remove(channel_id);
379
380        // 清理用户和会话映射
381        self.user_channels.write().retain(|_, v| v != channel_id);
382        self.session_channels.write().retain(|_, v| v != channel_id);
383    }
384
385    pub fn update_user_channel(&self, user_id: &str, channel_id: &str) {
386        self.user_channels
387            .write()
388            .insert(user_id.to_string(), channel_id.to_string());
389    }
390
391    pub fn update_session_channel(&self, session_id: &str, channel_id: &str) {
392        self.session_channels
393            .write()
394            .insert(session_id.to_string(), channel_id.to_string());
395    }
396
397    pub fn find_user_channel(&self, user_id: &str) -> Option<String> {
398        self.user_channels.read().get(user_id).cloned()
399    }
400
401    pub fn find_session_channel(&self, session_id: &str) -> Option<String> {
402        self.session_channels.read().get(session_id).cloned()
403    }
404}
405
406impl Default for MessageRouter {
407    fn default() -> Self {
408        Self::new()
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inbound_message_creation() {
        let msg = InboundMessage::new("cli-1", "user-1", "Hello");
        assert_eq!(msg.channel_id, "cli-1");
        assert_eq!(msg.user_id, "user-1");
        assert_eq!(msg.content, "Hello");
        assert_eq!(msg.message_type, MessageType::Text);
        assert!(msg.session_id.is_none());
        assert!(msg.metadata.is_null());
    }

    #[test]
    fn test_inbound_message_builders() {
        let msg = InboundMessage::new("cli-1", "user-1", "Hello")
            .with_session("session-1")
            .with_metadata(serde_json::Value::Bool(true));
        assert_eq!(msg.session_id.as_deref(), Some("session-1"));
        assert_eq!(msg.metadata, serde_json::Value::Bool(true));
        assert_eq!(msg.content, "Hello");
    }

    #[test]
    fn test_outbound_message_broadcast() {
        let msg = OutboundMessage::broadcast("Hello all");
        assert!(matches!(msg.target, MessageTarget::All));
    }

    #[test]
    fn test_outbound_message_targets() {
        let msg = OutboundMessage::to_channel("http-1", "hi");
        assert!(matches!(&msg.target, MessageTarget::Channel(id) if id == "http-1"));
        assert_eq!(msg.content, "hi");

        let msg = OutboundMessage::to_user("user-1", "hey");
        assert!(matches!(&msg.target, MessageTarget::User(id) if id == "user-1"));
        assert_eq!(msg.message_type, MessageType::Text);
    }

    #[test]
    fn test_channel_type_display() {
        assert_eq!(ChannelType::Cli.to_string(), "cli");
        assert_eq!(ChannelType::WebSocket.to_string(), "websocket");
        assert_eq!(ChannelType::Custom.to_string(), "custom");
    }

    #[test]
    fn test_channel_gateway_creation() {
        let gateway = ChannelGateway::new();
        assert_eq!(gateway.channel_count(), 0);
        assert!(gateway.list_channels().is_empty());
    }

    #[test]
    fn test_message_router() {
        let router = MessageRouter::new();
        router.update_user_channel("user-1", "cli-1");

        let channel = router.find_user_channel("user-1");
        assert_eq!(channel, Some("cli-1".to_string()));
    }

    #[test]
    fn test_message_router_sessions_and_unregister() {
        let router = MessageRouter::new();
        router.register_channel("cli-1", ChannelType::Cli);
        router.register_channel("http-1", ChannelType::Http);
        router.update_user_channel("user-1", "cli-1");
        router.update_user_channel("user-2", "http-1");
        router.update_session_channel("session-1", "cli-1");

        assert_eq!(
            router.find_session_channel("session-1"),
            Some("cli-1".to_string())
        );

        // Unregistering a channel must drop every route that points at it,
        // but leave routes to other channels intact.
        router.unregister_channel("cli-1");
        assert_eq!(router.find_user_channel("user-1"), None);
        assert_eq!(router.find_session_channel("session-1"), None);
        assert_eq!(
            router.find_user_channel("user-2"),
            Some("http-1".to_string())
        );
    }
}