sh_layer4/channel_gateway/
mod.rs1pub mod adapter;
6
7use async_trait::async_trait;
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use sh_layer3::generate_short_id;
11use std::collections::HashMap;
12use std::sync::Arc;
13
14use crate::types::Layer4Result;
15
16#[async_trait]
18pub trait Channel: Send + Sync {
19 fn id(&self) -> &str;
21
22 fn channel_type(&self) -> ChannelType;
24
25 async fn send(&self, message: &OutboundMessage) -> Layer4Result<()>;
27
28 async fn try_receive(&self) -> Layer4Result<Option<InboundMessage>>;
30
31 fn is_connected(&self) -> bool;
33
34 async fn close(&self) -> Layer4Result<()>;
36}
37
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
40pub enum ChannelType {
41 Cli,
42 Http,
43 WebSocket,
44 Mqtt,
45 Custom,
46}
47
48impl std::fmt::Display for ChannelType {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 match self {
51 Self::Cli => write!(f, "cli"),
52 Self::Http => write!(f, "http"),
53 Self::WebSocket => write!(f, "websocket"),
54 Self::Mqtt => write!(f, "mqtt"),
55 Self::Custom => write!(f, "custom"),
56 }
57 }
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct InboundMessage {
63 pub message_id: String,
65 pub channel_id: String,
67 pub user_id: String,
69 pub session_id: Option<String>,
71 pub content: String,
73 pub message_type: MessageType,
75 pub metadata: serde_json::Value,
77 pub timestamp: chrono::DateTime<chrono::Utc>,
79}
80
81impl InboundMessage {
82 pub fn new(
83 channel_id: impl Into<String>,
84 user_id: impl Into<String>,
85 content: impl Into<String>,
86 ) -> Self {
87 Self {
88 message_id: generate_short_id(),
89 channel_id: channel_id.into(),
90 user_id: user_id.into(),
91 session_id: None,
92 content: content.into(),
93 message_type: MessageType::Text,
94 metadata: serde_json::Value::Null,
95 timestamp: chrono::Utc::now(),
96 }
97 }
98
99 pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
100 self.session_id = Some(session_id.into());
101 self
102 }
103
104 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
105 self.metadata = metadata;
106 self
107 }
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct OutboundMessage {
113 pub message_id: String,
115 pub content: String,
117 pub message_type: MessageType,
119 pub target: MessageTarget,
121 pub metadata: serde_json::Value,
123 pub timestamp: chrono::DateTime<chrono::Utc>,
125}
126
127impl OutboundMessage {
128 pub fn new(content: impl Into<String>, target: MessageTarget) -> Self {
129 Self {
130 message_id: generate_short_id(),
131 content: content.into(),
132 message_type: MessageType::Text,
133 target,
134 metadata: serde_json::Value::Null,
135 timestamp: chrono::Utc::now(),
136 }
137 }
138
139 pub fn broadcast(content: impl Into<String>) -> Self {
140 Self::new(content, MessageTarget::All)
141 }
142
143 pub fn to_channel(channel_id: impl Into<String>, content: impl Into<String>) -> Self {
144 Self::new(content, MessageTarget::Channel(channel_id.into()))
145 }
146
147 pub fn to_user(user_id: impl Into<String>, content: impl Into<String>) -> Self {
148 Self::new(content, MessageTarget::User(user_id.into()))
149 }
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
154pub enum MessageTarget {
155 All,
157 Channel(String),
159 User(String),
161 Session(String),
163}
164
165#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
167pub enum MessageType {
168 Text,
169 Json,
170 Binary,
171 Command,
172 Event,
173 Error,
174}
175
176pub struct ChannelGateway {
178 channels: RwLock<HashMap<String, Box<dyn Channel>>>,
179 router: MessageRouter,
180 message_queue: RwLock<Vec<InboundMessage>>,
181}
182
183impl ChannelGateway {
184 pub fn new() -> Self {
186 Self {
187 channels: RwLock::new(HashMap::new()),
188 router: MessageRouter::new(),
189 message_queue: RwLock::new(Vec::new()),
190 }
191 }
192
193 pub async fn register_channel(&self, channel: Box<dyn Channel>) -> Layer4Result<()> {
195 let id = channel.id().to_string();
196 let channel_type = channel.channel_type();
197
198 self.channels.write().insert(id.clone(), channel);
199 self.router.register_channel(&id, channel_type);
200
201 tracing::info!("Registered channel: {} ({})", id, channel_type);
202 Ok(())
203 }
204
205 pub async fn unregister_channel(&self, channel_id: &str) -> Layer4Result<bool> {
207 let channel = self.channels.write().remove(channel_id);
208 if let Some(channel) = channel {
209 channel.close().await?;
210 self.router.unregister_channel(channel_id);
211 tracing::info!("Unregistered channel: {}", channel_id);
212 Ok(true)
213 } else {
214 Ok(false)
215 }
216 }
217
218 pub fn get_channel(&self, _channel_id: &str) -> Option<Arc<dyn Channel>> {
220 None
223 }
224
225 pub fn list_channels(&self) -> Vec<(String, ChannelType)> {
227 self.channels
228 .read()
229 .iter()
230 .map(|(id, ch)| (id.clone(), ch.channel_type()))
231 .collect()
232 }
233
234 #[allow(clippy::await_holding_lock)]
236 pub async fn broadcast(&self, message: &OutboundMessage) -> Layer4Result<()> {
237 let channels = self.channels.read();
238 for (id, channel) in channels.iter() {
239 if let Err(e) = channel.send(message).await {
240 tracing::error!("Failed to send to channel {}: {}", id, e);
241 }
242 }
243 Ok(())
244 }
245
246 #[allow(clippy::await_holding_lock)]
248 pub async fn send_to(
249 &self,
250 target: &MessageTarget,
251 message: &OutboundMessage,
252 ) -> Layer4Result<()> {
253 match target {
254 MessageTarget::All => self.broadcast(message).await,
255 MessageTarget::Channel(channel_id) => {
256 let channels = self.channels.read();
257 if let Some(channel) = channels.get(channel_id) {
258 channel.send(message).await?;
259 }
260 Ok(())
261 }
262 MessageTarget::User(user_id) => {
263 let channel_id = self.router.find_user_channel(user_id);
265 if let Some(cid) = channel_id {
266 let channels = self.channels.read();
267 if let Some(channel) = channels.get(&cid) {
268 channel.send(message).await?;
269 }
270 }
271 Ok(())
272 }
273 MessageTarget::Session(session_id) => {
274 let channel_id = self.router.find_session_channel(session_id);
275 if let Some(cid) = channel_id {
276 let channels = self.channels.read();
277 if let Some(channel) = channels.get(&cid) {
278 channel.send(message).await?;
279 }
280 }
281 Ok(())
282 }
283 }
284 }
285
286 #[allow(clippy::await_holding_lock)]
288 pub async fn receive(&self) -> Layer4Result<Option<InboundMessage>> {
289 if let Some(msg) = self.message_queue.write().pop() {
291 return Ok(Some(msg));
292 }
293
294 let channels = self.channels.read();
296 for (_, channel) in channels.iter() {
297 if let Some(msg) = channel.try_receive().await? {
298 self.router
300 .update_user_channel(&msg.user_id, &msg.channel_id);
301 if let Some(ref session_id) = msg.session_id {
302 self.router
303 .update_session_channel(session_id, &msg.channel_id);
304 }
305 return Ok(Some(msg));
306 }
307 }
308
309 Ok(None)
310 }
311
312 #[allow(clippy::await_holding_lock)]
314 pub async fn receive_all(&self) -> Layer4Result<Vec<InboundMessage>> {
315 let mut messages = Vec::new();
316
317 messages.append(&mut self.message_queue.write());
319
320 let channels = self.channels.read();
322 for (_, channel) in channels.iter() {
323 while let Some(msg) = channel.try_receive().await? {
324 messages.push(msg);
325 }
326 }
327
328 Ok(messages)
329 }
330
331 pub fn channel_count(&self) -> usize {
333 self.channels.read().len()
334 }
335
336 #[allow(clippy::await_holding_lock)]
338 pub async fn close_all(&self) -> Layer4Result<()> {
339 let mut channels = self.channels.write();
340 for (id, channel) in channels.drain() {
341 if let Err(e) = channel.close().await {
342 tracing::error!("Failed to close channel {}: {}", id, e);
343 }
344 }
345 Ok(())
346 }
347}
348
349impl Default for ChannelGateway {
350 fn default() -> Self {
351 Self::new()
352 }
353}
354
355pub struct MessageRouter {
357 user_channels: RwLock<HashMap<String, String>>,
358 session_channels: RwLock<HashMap<String, String>>,
359 channel_registry: RwLock<HashMap<String, ChannelType>>,
360}
361
362impl MessageRouter {
363 pub fn new() -> Self {
364 Self {
365 user_channels: RwLock::new(HashMap::new()),
366 session_channels: RwLock::new(HashMap::new()),
367 channel_registry: RwLock::new(HashMap::new()),
368 }
369 }
370
371 pub fn register_channel(&self, channel_id: &str, channel_type: ChannelType) {
372 self.channel_registry
373 .write()
374 .insert(channel_id.to_string(), channel_type);
375 }
376
377 pub fn unregister_channel(&self, channel_id: &str) {
378 self.channel_registry.write().remove(channel_id);
379
380 self.user_channels.write().retain(|_, v| v != channel_id);
382 self.session_channels.write().retain(|_, v| v != channel_id);
383 }
384
385 pub fn update_user_channel(&self, user_id: &str, channel_id: &str) {
386 self.user_channels
387 .write()
388 .insert(user_id.to_string(), channel_id.to_string());
389 }
390
391 pub fn update_session_channel(&self, session_id: &str, channel_id: &str) {
392 self.session_channels
393 .write()
394 .insert(session_id.to_string(), channel_id.to_string());
395 }
396
397 pub fn find_user_channel(&self, user_id: &str) -> Option<String> {
398 self.user_channels.read().get(user_id).cloned()
399 }
400
401 pub fn find_session_channel(&self, session_id: &str) -> Option<String> {
402 self.session_channels.read().get(session_id).cloned()
403 }
404}
405
406impl Default for MessageRouter {
407 fn default() -> Self {
408 Self::new()
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inbound_message_creation() {
        let msg = InboundMessage::new("cli-1", "user-1", "Hello");
        assert_eq!(msg.channel_id, "cli-1");
        assert_eq!(msg.user_id, "user-1");
        assert_eq!(msg.content, "Hello");
        assert_eq!(msg.message_type, MessageType::Text);
        assert!(msg.session_id.is_none());
        assert!(msg.metadata.is_null());
    }

    #[test]
    fn test_inbound_message_builders() {
        let msg = InboundMessage::new("cli-1", "user-1", "Hello")
            .with_session("session-1")
            .with_metadata(serde_json::json!({ "k": 1 }));
        assert_eq!(msg.session_id.as_deref(), Some("session-1"));
        assert_eq!(msg.metadata, serde_json::json!({ "k": 1 }));
    }

    #[test]
    fn test_outbound_message_broadcast() {
        let msg = OutboundMessage::broadcast("Hello all");
        assert!(matches!(msg.target, MessageTarget::All));
    }

    #[test]
    fn test_outbound_message_targets() {
        let to_channel = OutboundMessage::to_channel("cli-1", "hi");
        assert!(matches!(&to_channel.target, MessageTarget::Channel(id) if id == "cli-1"));
        assert_eq!(to_channel.content, "hi");
        assert_eq!(to_channel.message_type, MessageType::Text);

        let to_user = OutboundMessage::to_user("user-1", "hi");
        assert!(matches!(&to_user.target, MessageTarget::User(id) if id == "user-1"));
    }

    #[test]
    fn test_channel_type_display() {
        assert_eq!(ChannelType::Cli.to_string(), "cli");
        assert_eq!(ChannelType::Http.to_string(), "http");
        assert_eq!(ChannelType::WebSocket.to_string(), "websocket");
        assert_eq!(ChannelType::Mqtt.to_string(), "mqtt");
        assert_eq!(ChannelType::Custom.to_string(), "custom");
    }

    #[test]
    fn test_channel_gateway_creation() {
        let gateway = ChannelGateway::new();
        assert_eq!(gateway.channel_count(), 0);
        assert!(gateway.list_channels().is_empty());
        assert!(gateway.get_channel("missing").is_none());
    }

    #[test]
    fn test_message_router() {
        let router = MessageRouter::new();
        router.update_user_channel("user-1", "cli-1");

        let channel = router.find_user_channel("user-1");
        assert_eq!(channel, Some("cli-1".to_string()));
    }

    #[test]
    fn test_router_update_overwrites() {
        let router = MessageRouter::new();
        router.update_user_channel("user-1", "cli-1");
        router.update_user_channel("user-1", "http-1");
        assert_eq!(router.find_user_channel("user-1").as_deref(), Some("http-1"));
    }

    #[test]
    fn test_router_unregister_clears_routes() {
        let router = MessageRouter::new();
        router.register_channel("cli-1", ChannelType::Cli);
        router.register_channel("http-1", ChannelType::Http);
        router.update_user_channel("user-1", "cli-1");
        router.update_user_channel("user-2", "http-1");
        router.update_session_channel("session-1", "cli-1");

        router.unregister_channel("cli-1");

        assert_eq!(router.find_user_channel("user-1"), None);
        assert_eq!(router.find_session_channel("session-1"), None);
        assert_eq!(
            router.find_user_channel("user-2"),
            Some("http-1".to_string())
        );
    }
}
444}