1use futures::{SinkExt, StreamExt};
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::net::TcpListener;
5use tokio::sync::{Mutex, mpsc};
6use tokio_tungstenite::accept_async;
7
8use super::{GatewayMessage, GatewayResponse, GatewayTransport};
9
10pub struct WebSocketApi {
13 bind_addr: String,
14 clients: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
15}
16
17impl WebSocketApi {
18 pub fn new(bind_addr: impl Into<String>) -> Self {
19 Self {
20 bind_addr: bind_addr.into(),
21 clients: Arc::new(Mutex::new(HashMap::new())),
22 }
23 }
24}
25
26#[async_trait::async_trait]
27impl GatewayTransport for WebSocketApi {
28 fn name(&self) -> &str {
29 "ws-api"
30 }
31
32 async fn start(&self, tx: mpsc::UnboundedSender<GatewayMessage>) -> anyhow::Result<()> {
33 let listener = TcpListener::bind(&self.bind_addr).await?;
34 let clients = self.clients.clone();
35 tracing::info!("WebSocket API listening on {}", self.bind_addr);
36
37 tokio::spawn(async move {
38 loop {
39 match listener.accept().await {
40 Ok((stream, addr)) => {
41 tracing::debug!("WS connection from {}", addr);
42 let tx = tx.clone();
43 let clients = clients.clone();
44
45 tokio::spawn(async move {
46 match accept_async(stream).await {
47 Ok(ws_stream) => {
48 let (mut write, mut read) = ws_stream.split();
49 let chat_id = addr.to_string();
50 let (out_tx, mut out_rx) = mpsc::unbounded_channel::<String>();
51 clients.lock().await.insert(chat_id.clone(), out_tx);
52
53 loop {
54 tokio::select! {
55 Some(outbound) = out_rx.recv() => {
56 if write
57 .send(tokio_tungstenite::tungstenite::Message::Text(outbound.into()))
58 .await
59 .is_err()
60 {
61 break;
62 }
63 }
64 incoming = read.next() => {
65 match incoming {
66 Some(Ok(msg)) => {
67 if let tokio_tungstenite::tungstenite::Message::Text(text) = msg {
68 let _ = tx.send(GatewayMessage {
69 surface: "ws-api".into(),
70 user_id: "ws-user".into(),
71 chat_id: chat_id.clone(),
72 text: text.to_string(),
73 message_id: None,
74 });
75
76 let ack = serde_json::json!({"ack": "received"}).to_string();
77 let _ = write
78 .send(tokio_tungstenite::tungstenite::Message::Text(ack.into()))
79 .await;
80 }
81 }
82 Some(Err(e)) => {
83 tracing::error!("WS error: {}", e);
84 break;
85 }
86 None => break,
87 }
88 }
89 }
90 }
91 clients.lock().await.remove(&chat_id);
92 }
93 Err(e) => {
94 tracing::error!("WS handshake error: {}", e);
95 }
96 }
97 });
98 }
99 Err(e) => {
100 tracing::error!("Accept error: {}", e);
101 }
102 }
103 }
104 });
105
106 Ok(())
107 }
108
109 async fn send(&self, response: GatewayResponse) -> anyhow::Result<()> {
110 let payload = serde_json::json!({
111 "type": "message",
112 "text": response.text,
113 "reply_to": response.reply_to,
114 "buttons": response.buttons,
115 })
116 .to_string();
117
118 if let Some(client) = self.clients.lock().await.get(&response.chat_id).cloned() {
119 client
120 .send(payload)
121 .map_err(|_| anyhow::anyhow!("WebSocket client is no longer connected"))?;
122 Ok(())
123 } else {
124 anyhow::bail!("WebSocket client not connected: {}", response.chat_id)
125 }
126 }
127
128 async fn stop(&self) -> anyhow::Result<()> {
129 tracing::info!("WebSocket API stopped");
130 Ok(())
131 }
132}