sh_layer4/channel_gateway/adapter/
websocket.rs1use async_trait::async_trait;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::collections::VecDeque;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::Mutex as AsyncMutex;
12
13use crate::channel_gateway::{Channel, ChannelType, InboundMessage, OutboundMessage};
14use crate::types::Layer4Result;
15
16use futures::{SinkExt, StreamExt};
17use tokio::net::TcpStream;
18use tokio_tungstenite::{
19 connect_async, tungstenite::Message as WsMessage, MaybeTlsStream, WebSocketStream,
20};
21
/// Connection settings for a `WebSocketChannel`.
///
/// Every `*_ms` field is a duration in milliseconds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocketChannelConfig {
    /// WebSocket endpoint to connect to (for example `ws://host:port/path`).
    pub url: String,
    /// Maximum number of connection attempts made by `connect_with_retry`,
    /// counting the first one. A value of `0` still results in one attempt.
    pub reconnect_attempts: u32,
    /// Delay between consecutive attempts in `connect_with_retry`.
    pub reconnect_interval_ms: u64,
    /// Intended keep-alive ping interval.
    ///
    /// NOTE(review): not read anywhere in this file, so no pings are sent yet.
    pub ping_interval_ms: u64,
    /// Upper bound on a single connection handshake performed by `connect`.
    pub connect_timeout_ms: u64,
}
31
32impl Default for WebSocketChannelConfig {
33 fn default() -> Self {
34 Self {
35 url: "ws://localhost:8080/ws".to_string(),
36 reconnect_attempts: 3,
37 reconnect_interval_ms: 1000,
38 ping_interval_ms: 30000,
39 connect_timeout_ms: 10000,
40 }
41 }
42}
43
44type WsConnection = WebSocketStream<MaybeTlsStream<TcpStream>>;
46
47pub struct WebSocketChannel {
49 channel_id: String,
50 config: WebSocketChannelConfig,
51 connected: RwLock<bool>,
52 message_queue: RwLock<VecDeque<InboundMessage>>,
53 sessions: RwLock<HashMap<String, String>>, ws_sender: Arc<AsyncMutex<Option<futures::stream::SplitSink<WsConnection, WsMessage>>>>,
56}
57
58impl WebSocketChannel {
59 pub fn new(channel_id: impl Into<String>, config: WebSocketChannelConfig) -> Self {
61 Self {
62 channel_id: channel_id.into(),
63 config,
64 connected: RwLock::new(false),
65 message_queue: RwLock::new(VecDeque::new()),
66 sessions: RwLock::new(HashMap::new()),
67 ws_sender: Arc::new(AsyncMutex::new(None)),
68 }
69 }
70
71 pub fn default_channel() -> Self {
73 Self::new("ws-default", WebSocketChannelConfig::default())
74 }
75
76 pub async fn connect(&self) -> Layer4Result<()> {
78 let url = self.config.url.clone();
79 let timeout = Duration::from_millis(self.config.connect_timeout_ms);
80
81 let connect_future = async { connect_async(&url).await };
82
83 let result = tokio::time::timeout(timeout, connect_future).await;
84
85 match result {
86 Ok(Ok((stream, _))) => {
87 let (sink, _stream) = stream.split();
89 *self.ws_sender.lock().await = Some(sink);
90 *self.connected.write() = true;
91 tracing::info!("WebSocket connected to {}", url);
92 Ok(())
93 }
94 Ok(Err(e)) => {
95 tracing::error!("WebSocket connection failed: {}", e);
96 Err(anyhow::anyhow!("WebSocket connection failed: {}", e))
97 }
98 Err(_) => {
99 tracing::error!("WebSocket connection timeout");
100 Err(anyhow::anyhow!("WebSocket connection timeout"))
101 }
102 }
103 }
104
105 pub async fn connect_with_retry(&self) -> Layer4Result<()> {
107 let mut attempts = 0;
108 let max_attempts = self.config.reconnect_attempts;
109 let interval = Duration::from_millis(self.config.reconnect_interval_ms);
110
111 loop {
112 match self.connect().await {
113 Ok(_) => return Ok(()),
114 Err(e) => {
115 attempts += 1;
116 if attempts >= max_attempts {
117 return Err(e);
118 }
119 tracing::warn!(
120 "WebSocket connection attempt {}/{} failed, retrying...",
121 attempts,
122 max_attempts
123 );
124 tokio::time::sleep(interval).await;
125 }
126 }
127 }
128 }
129
130 pub async fn send_raw(&self, message: WsMessage) -> Layer4Result<()> {
132 let mut sender = self.ws_sender.lock().await;
133 if let Some(ref mut sink) = *sender {
134 sink.send(message).await?;
135 Ok(())
136 } else {
137 Err(anyhow::anyhow!("WebSocket not connected"))
138 }
139 }
140
141 pub async fn send_text(&self, text: &str) -> Layer4Result<()> {
143 self.send_raw(WsMessage::Text(text.into())).await
144 }
145
146 pub async fn send_binary(&self, data: Vec<u8>) -> Layer4Result<()> {
148 self.send_raw(WsMessage::Binary(data.into())).await
149 }
150
151 pub fn register_session(&self, session_id: &str, user_id: &str) {
153 self.sessions
154 .write()
155 .insert(session_id.to_string(), user_id.to_string());
156 }
157
158 pub fn unregister_session(&self, session_id: &str) {
160 self.sessions.write().remove(session_id);
161 }
162
163 pub fn receive_message(&self, session_id: &str, content: &str) {
165 let user_id = self
166 .sessions
167 .read()
168 .get(session_id)
169 .cloned()
170 .unwrap_or_default();
171 let message = InboundMessage::new(&self.channel_id, &user_id, content)
172 .with_session(session_id)
173 .with_metadata(serde_json::json!({
174 "source": "websocket",
175 "session_id": session_id
176 }));
177 self.message_queue.write().push_back(message);
178 }
179
180 pub fn active_sessions(&self) -> usize {
182 self.sessions.read().len()
183 }
184}
185
186#[async_trait]
187impl Channel for WebSocketChannel {
188 fn id(&self) -> &str {
189 &self.channel_id
190 }
191
192 fn channel_type(&self) -> ChannelType {
193 ChannelType::WebSocket
194 }
195
196 async fn send(&self, message: &OutboundMessage) -> Layer4Result<()> {
197 if !*self.connected.read() {
198 return Err(anyhow::anyhow!("Channel not connected"));
199 }
200
201 let payload = serde_json::json!({
203 "message_id": message.message_id,
204 "content": message.content,
205 "message_type": message.message_type,
206 "target": message.target,
207 "metadata": message.metadata,
208 "timestamp": message.timestamp.to_rfc3339(),
209 });
210
211 self.send_text(&payload.to_string()).await?;
213
214 tracing::debug!("WebSocket channel sent message {}", message.message_id);
215 Ok(())
216 }
217
218 async fn try_receive(&self) -> Layer4Result<Option<InboundMessage>> {
219 if !*self.connected.read() {
220 return Err(anyhow::anyhow!("Channel not connected"));
221 }
222
223 Ok(self.message_queue.write().pop_front())
224 }
225
226 fn is_connected(&self) -> bool {
227 *self.connected.read()
228 }
229
230 async fn close(&self) -> Layer4Result<()> {
231 let mut sender = self.ws_sender.lock().await;
233 if let Some(ref mut sink) = *sender {
234 sink.close().await?;
235 }
236 *sender = None;
237
238 *self.connected.write() = false;
239 self.message_queue.write().clear();
240 self.sessions.write().clear();
241 tracing::info!("WebSocket channel closed");
242 Ok(())
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_websocket_channel_creation() {
        let channel = WebSocketChannel::default_channel();
        assert_eq!(channel.id(), "ws-default");
        assert!(!channel.is_connected());
    }

    #[test]
    fn test_websocket_config_default() {
        let config = WebSocketChannelConfig::default();
        assert_eq!(config.url, "ws://localhost:8080/ws");
        assert_eq!(config.reconnect_attempts, 3);
        assert_eq!(config.reconnect_interval_ms, 1000);
        assert_eq!(config.ping_interval_ms, 30000);
        assert_eq!(config.connect_timeout_ms, 10000);
    }

    #[test]
    fn test_websocket_session_management() {
        let channel = WebSocketChannel::default_channel();
        channel.register_session("session-1", "user-1");

        assert_eq!(channel.active_sessions(), 1);

        channel.unregister_session("session-1");
        assert_eq!(channel.active_sessions(), 0);
    }

    #[test]
    fn test_websocket_receive_message() {
        let channel = WebSocketChannel::default_channel();
        *channel.connected.write() = true;
        channel.register_session("session-1", "user-1");
        channel.receive_message("session-1", "Hello");

        let count = channel.message_queue.read().len();
        assert_eq!(count, 1);
    }

    #[test]
    fn test_receive_message_for_unknown_session_is_still_queued() {
        let channel = WebSocketChannel::default_channel();
        channel.receive_message("missing-session", "Hello");
        assert_eq!(channel.message_queue.read().len(), 1);
    }

    #[tokio::test]
    async fn test_try_receive_without_connection() {
        let channel = WebSocketChannel::default_channel();
        assert!(channel.try_receive().await.is_err());
    }

    #[tokio::test]
    async fn test_try_receive_drains_queue() {
        let channel = WebSocketChannel::default_channel();
        *channel.connected.write() = true;
        channel.register_session("session-1", "user-1");
        channel.receive_message("session-1", "first");
        channel.receive_message("session-1", "second");

        assert!(channel.try_receive().await.unwrap().is_some());
        assert!(channel.try_receive().await.unwrap().is_some());
        assert!(channel.try_receive().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_websocket_channel_close() {
        let channel = WebSocketChannel::default_channel();
        *channel.connected.write() = true;
        channel.register_session("session-1", "user-1");
        channel.close().await.unwrap();

        assert!(!channel.is_connected());
        assert_eq!(channel.active_sessions(), 0);
    }

    #[tokio::test]
    async fn test_close_clears_pending_messages_and_is_repeatable() {
        let channel = WebSocketChannel::default_channel();
        *channel.connected.write() = true;
        channel.receive_message("session-1", "pending");
        channel.close().await.unwrap();

        assert_eq!(channel.message_queue.read().len(), 0);
        // With no socket stored, a second close is a successful no-op.
        assert!(channel.close().await.is_ok());
    }

    #[tokio::test]
    async fn test_send_without_connection() {
        let channel = WebSocketChannel::default_channel();
        let msg = OutboundMessage::to_user("test-user", "hello");
        let result = channel.send(&msg).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_send_text_without_socket() {
        let channel = WebSocketChannel::default_channel();
        assert!(channel.send_text("hello").await.is_err());
    }
}