Skip to main content

xrpl_mithril_client/
websocket.rs

//! WebSocket client for connecting to XRPL nodes.
//!
//! Supports both request-response and subscription patterns over a persistent
//! WebSocket connection.
//!
//! # Architecture
//!
//! The client spawns a background `tokio` task that manages the WebSocket
//! connection. Requests are sent via an internal channel, and responses are
//! dispatched back to the caller via `oneshot` channels. Subscription messages
//! (those without a numeric `id` field) are routed to subscription streams.
//!
//! # Examples
//!
//! ```no_run
//! use xrpl_mithril_client::{WebSocketClient, Client};
//! use xrpl_mithril_models::requests::server::ServerInfoRequest;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let client = WebSocketClient::connect("wss://s1.ripple.com:443").await?;
//! let resp = client.request(ServerInfoRequest {}).await?;
//! println!("Server: {:?}", resp.info.build_version);
//! # Ok(())
//! # }
//! ```
26
27use std::collections::HashMap;
28use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
29use std::sync::Arc;
30
31use futures::{SinkExt, StreamExt};
32use tokio::sync::{mpsc, oneshot};
33use tokio_tungstenite::tungstenite::Message;
34use xrpl_mithril_models::requests::XrplRequest;
35
36use crate::client::Client;
37use crate::error::ClientError;
38use crate::subscription::SubscriptionStream;
39
40/// A command sent from the client API to the background WebSocket task.
41enum WsCommand {
42    /// Send a request and expect a response.
43    Request {
44        payload: serde_json::Value,
45        response_tx: oneshot::Sender<Result<serde_json::Value, ClientError>>,
46    },
47    /// Register a subscription stream.
48    Subscribe {
49        stream_tx: mpsc::UnboundedSender<serde_json::Value>,
50    },
51}
52
53/// WebSocket client for the XRP Ledger.
54///
55/// Maintains a persistent WebSocket connection with a background task that
56/// handles message routing. Supports both request-response (via `id` tracking)
57/// and subscription streams.
58pub struct WebSocketClient {
59    command_tx: mpsc::UnboundedSender<WsCommand>,
60    next_id: AtomicU64,
61    connected: Arc<AtomicBool>,
62    _task: tokio::task::JoinHandle<()>,
63}
64
65impl std::fmt::Debug for WebSocketClient {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        f.debug_struct("WebSocketClient")
68            .field("connected", &self.connected.load(Ordering::Relaxed))
69            .finish()
70    }
71}
72
73impl WebSocketClient {
74    /// Connect to an XRPL WebSocket endpoint.
75    ///
76    /// Spawns a background task that manages the connection and routes
77    /// messages between the API and the WebSocket.
78    ///
79    /// # Errors
80    ///
81    /// Returns [`ClientError::WebSocket`] if the connection fails.
82    pub async fn connect(url: &str) -> Result<Self, ClientError> {
83        let (ws_stream, _response) = tokio_tungstenite::connect_async(url)
84            .await
85            .map_err(|e| ClientError::WebSocket(e.to_string()))?;
86
87        let (command_tx, command_rx) = mpsc::unbounded_channel();
88        let connected = Arc::new(AtomicBool::new(true));
89        let connected_clone = Arc::clone(&connected);
90
91        let task = tokio::spawn(Self::run_loop(ws_stream, command_rx, connected_clone));
92
93        Ok(Self {
94            command_tx,
95            next_id: AtomicU64::new(1),
96            connected,
97            _task: task,
98        })
99    }
100
101    /// Subscribe to receive raw subscription messages.
102    ///
103    /// Returns a [`SubscriptionStream`] that yields `serde_json::Value`
104    /// messages pushed by the server (ledger closes, transactions, etc.).
105    ///
106    /// You must also send a [`SubscribeRequest`](xrpl_mithril_models::requests::subscription::SubscribeRequest)
107    /// via [`Client::request`] to tell the server what to subscribe to.
108    ///
109    /// # Errors
110    ///
111    /// Returns [`ClientError::ConnectionClosed`] if the WebSocket is disconnected.
112    pub fn subscribe_stream(&self) -> Result<SubscriptionStream, ClientError> {
113        let (stream_tx, stream_rx) = mpsc::unbounded_channel();
114        self.command_tx
115            .send(WsCommand::Subscribe { stream_tx })
116            .map_err(|_| ClientError::ConnectionClosed {
117                reason: "background task ended".into(),
118            })?;
119        Ok(SubscriptionStream::new(stream_rx))
120    }
121
122    /// Check if the WebSocket connection is still alive.
123    #[must_use]
124    pub fn is_connected(&self) -> bool {
125        self.connected.load(Ordering::Relaxed)
126    }
127
128    /// The background event loop that manages the WebSocket connection.
129    async fn run_loop(
130        ws_stream: tokio_tungstenite::WebSocketStream<
131            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
132        >,
133        mut command_rx: mpsc::UnboundedReceiver<WsCommand>,
134        connected: Arc<AtomicBool>,
135    ) {
136        let (mut ws_sink, mut ws_source) = ws_stream.split();
137
138        // Pending request-response pairs, keyed by request ID
139        let mut pending: HashMap<u64, oneshot::Sender<Result<serde_json::Value, ClientError>>> =
140            HashMap::new();
141
142        // Subscription stream senders
143        let mut subscribers: Vec<mpsc::UnboundedSender<serde_json::Value>> = Vec::new();
144
145        loop {
146            tokio::select! {
147                // Handle commands from the client API
148                Some(cmd) = command_rx.recv() => {
149                    match cmd {
150                        WsCommand::Request { payload, response_tx } => {
151                            let id = payload.get("id")
152                                .and_then(|v| v.as_u64())
153                                .unwrap_or(0);
154                            pending.insert(id, response_tx);
155
156                            let msg = Message::Text(payload.to_string().into());
157                            if let Err(e) = ws_sink.send(msg).await {
158                                if let Some(tx) = pending.remove(&id) {
159                                    let _ = tx.send(Err(ClientError::WebSocket(e.to_string())));
160                                }
161                            }
162                        }
163                        WsCommand::Subscribe { stream_tx } => {
164                            subscribers.push(stream_tx);
165                        }
166                    }
167                }
168
169                // Handle messages from the WebSocket
170                Some(msg_result) = ws_source.next() => {
171                    match msg_result {
172                        Ok(Message::Text(text)) => {
173                            if let Ok(value) = serde_json::from_str::<serde_json::Value>(&text) {
174                                // Check if this is a response to a pending request
175                                if let Some(id) = value.get("id").and_then(|v| v.as_u64()) {
176                                    if let Some(tx) = pending.remove(&id) {
177                                        // Extract the result, checking for errors
178                                        let result = extract_result(&value);
179                                        let _ = tx.send(result);
180                                    }
181                                } else {
182                                    // No id — this is a subscription message
183                                    // Remove closed subscribers
184                                    subscribers.retain(|tx| {
185                                        tx.send(value.clone()).is_ok()
186                                    });
187                                }
188                            }
189                        }
190                        Ok(Message::Close(_)) => {
191                            connected.store(false, Ordering::Relaxed);
192                            break;
193                        }
194                        Ok(Message::Ping(data)) => {
195                            let _ = ws_sink.send(Message::Pong(data)).await;
196                        }
197                        Err(e) => {
198                            tracing::error!(error = %e, "WebSocket error");
199                            connected.store(false, Ordering::Relaxed);
200                            break;
201                        }
202                        _ => {}
203                    }
204                }
205
206                else => break,
207            }
208        }
209
210        // Clean up: notify all pending requests that the connection closed
211        for (_id, tx) in pending {
212            let _ = tx.send(Err(ClientError::ConnectionClosed {
213                reason: "WebSocket connection closed".into(),
214            }));
215        }
216        connected.store(false, Ordering::Relaxed);
217    }
218}
219
220/// Extract the result from a WebSocket response, checking for errors.
221fn extract_result(value: &serde_json::Value) -> Result<serde_json::Value, ClientError> {
222    // Check for error status in the result
223    if let Some(result) = value.get("result") {
224        if let Some(status) = result.get("status").and_then(|v| v.as_str()) {
225            if status == "error" {
226                let error = result.get("error").and_then(|v| v.as_str()).map(String::from);
227                let code = result
228                    .get("error_code")
229                    .and_then(|v| v.as_i64())
230                    .map(|c| c as i32);
231                let message = result
232                    .get("error_message")
233                    .and_then(|v| v.as_str())
234                    .unwrap_or_else(|| error.as_deref().unwrap_or("unknown error"))
235                    .to_string();
236                return Err(ClientError::RpcError {
237                    code,
238                    message,
239                    error,
240                });
241            }
242        }
243        return Ok(result.clone());
244    }
245
246    // Some responses don't have a "result" wrapper
247    Ok(value.clone())
248}
249
250impl Client for WebSocketClient {
251    async fn request<R: XrplRequest + Send + Sync>(
252        &self,
253        request: R,
254    ) -> Result<R::Response, ClientError> {
255        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
256
257        // Serialize the request with an id field
258        let mut params = serde_json::to_value(&request)?;
259        if let Some(map) = params.as_object_mut() {
260            map.insert("id".into(), serde_json::Value::Number(id.into()));
261            map.insert(
262                "command".into(),
263                serde_json::Value::String(request.method().into()),
264            );
265        }
266
267        let (response_tx, response_rx) = oneshot::channel();
268
269        self.command_tx
270            .send(WsCommand::Request {
271                payload: params,
272                response_tx,
273            })
274            .map_err(|_| ClientError::ConnectionClosed {
275                reason: "background task ended".into(),
276            })?;
277
278        let result = response_rx.await.map_err(|_| ClientError::ConnectionClosed {
279            reason: "response channel dropped".into(),
280        })??;
281
282        let response: R::Response = serde_json::from_value(result)?;
283        Ok(response)
284    }
285}