// sparrow/gateway/ws.rs

use std::collections::HashMap;
use std::sync::Arc;

use futures::{SinkExt, StreamExt};
use tokio::net::TcpListener;
use tokio::sync::{Mutex, mpsc};
use tokio_tungstenite::accept_async;

use super::{GatewayMessage, GatewayResponse, GatewayTransport};

// ─── WebSocket API Server ───────────────────────────────────────────────────────

12pub struct WebSocketApi {
13    bind_addr: String,
14    clients: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
15}
16
17impl WebSocketApi {
18    pub fn new(bind_addr: impl Into<String>) -> Self {
19        Self {
20            bind_addr: bind_addr.into(),
21            clients: Arc::new(Mutex::new(HashMap::new())),
22        }
23    }
24}
25
26#[async_trait::async_trait]
27impl GatewayTransport for WebSocketApi {
28    fn name(&self) -> &str {
29        "ws-api"
30    }
31
32    async fn start(&self, tx: mpsc::UnboundedSender<GatewayMessage>) -> anyhow::Result<()> {
33        let listener = TcpListener::bind(&self.bind_addr).await?;
34        let clients = self.clients.clone();
35        tracing::info!("WebSocket API listening on {}", self.bind_addr);
36
37        tokio::spawn(async move {
38            loop {
39                match listener.accept().await {
40                    Ok((stream, addr)) => {
41                        tracing::debug!("WS connection from {}", addr);
42                        let tx = tx.clone();
43                        let clients = clients.clone();
44
45                        tokio::spawn(async move {
46                            match accept_async(stream).await {
47                                Ok(ws_stream) => {
48                                    let (mut write, mut read) = ws_stream.split();
49                                    let chat_id = addr.to_string();
50                                    let (out_tx, mut out_rx) = mpsc::unbounded_channel::<String>();
51                                    clients.lock().await.insert(chat_id.clone(), out_tx);
52
53                                    loop {
54                                        tokio::select! {
55                                            Some(outbound) = out_rx.recv() => {
56                                                if write
57                                                    .send(tokio_tungstenite::tungstenite::Message::Text(outbound.into()))
58                                                    .await
59                                                    .is_err()
60                                                {
61                                                    break;
62                                                }
63                                            }
64                                            incoming = read.next() => {
65                                                match incoming {
66                                                    Some(Ok(msg)) => {
67                                                        if let tokio_tungstenite::tungstenite::Message::Text(text) = msg {
68                                                            let _ = tx.send(GatewayMessage {
69                                                                surface: "ws-api".into(),
70                                                                user_id: "ws-user".into(),
71                                                                chat_id: chat_id.clone(),
72                                                                text: text.to_string(),
73                                                                message_id: None,
74                                                            });
75
76                                                            let ack = serde_json::json!({"ack": "received"}).to_string();
77                                                            let _ = write
78                                                                .send(tokio_tungstenite::tungstenite::Message::Text(ack.into()))
79                                                                .await;
80                                                        }
81                                                    }
82                                                    Some(Err(e)) => {
83                                                        tracing::error!("WS error: {}", e);
84                                                        break;
85                                                    }
86                                                    None => break,
87                                                }
88                                            }
89                                        }
90                                    }
91                                    clients.lock().await.remove(&chat_id);
92                                }
93                                Err(e) => {
94                                    tracing::error!("WS handshake error: {}", e);
95                                }
96                            }
97                        });
98                    }
99                    Err(e) => {
100                        tracing::error!("Accept error: {}", e);
101                    }
102                }
103            }
104        });
105
106        Ok(())
107    }
108
109    async fn send(&self, response: GatewayResponse) -> anyhow::Result<()> {
110        let payload = serde_json::json!({
111            "type": "message",
112            "text": response.text,
113            "reply_to": response.reply_to,
114            "buttons": response.buttons,
115        })
116        .to_string();
117
118        if let Some(client) = self.clients.lock().await.get(&response.chat_id).cloned() {
119            client
120                .send(payload)
121                .map_err(|_| anyhow::anyhow!("WebSocket client is no longer connected"))?;
122            Ok(())
123        } else {
124            anyhow::bail!("WebSocket client not connected: {}", response.chat_id)
125        }
126    }
127
128    async fn stop(&self) -> anyhow::Result<()> {
129        tracing::info!("WebSocket API stopped");
130        Ok(())
131    }
132}