surreal_client/engines/
ws_cbor.rs1use async_trait::async_trait;
9use ciborium::Value as CborValue;
10use futures_util::stream::{SplitSink, SplitStream};
11use futures_util::{SinkExt, StreamExt};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicU64, Ordering::SeqCst};
15use tokio::net::TcpStream;
16use tokio::sync::{Mutex, oneshot};
17use tokio_tungstenite::MaybeTlsStream;
18use tokio_tungstenite::tungstenite::client::IntoClientRequest;
19use tokio_tungstenite::tungstenite::http::HeaderValue;
20use tokio_tungstenite::tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL;
21use tokio_tungstenite::{WebSocketStream, connect_async, tungstenite::Message};
22use tracing::{Instrument as _, warn};
23
24use crate::SurrealConnection;
25use crate::{
26 engine::Engine,
27 error::{Result, SurrealError},
28};
29
30#[derive(Debug, Clone)]
32struct RouterRequest {
33 id: String,
34 method: String,
35 params: Option<CborValue>,
36}
37
38type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
39
40pub struct WsCborEngine {
46 sink: Arc<Mutex<SplitSink<WsStream, Message>>>,
47 stream: Arc<Mutex<SplitStream<WsStream>>>,
48 msg_id: AtomicU64,
49 pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<CborValue>>>>,
50 task_handle: Option<tokio::task::JoinHandle<()>>,
51}
52
53impl WsCborEngine {
54 pub async fn from_connection(connect: &SurrealConnection) -> Result<Self> {
55 let base_url = connect
56 .url
57 .as_ref()
58 .ok_or_else(|| SurrealError::Connection("URL is required to connect".to_string()))?;
59
60 let mut ws_url = if let Some(rest) = base_url.strip_prefix("cbor://") {
61 format!("ws://{}", rest)
62 } else {
63 base_url.clone()
64 };
65 if !ws_url.ends_with("/rpc") {
66 if ws_url.ends_with('/') {
67 ws_url.push_str("rpc");
68 } else {
69 ws_url.push_str("/rpc");
70 }
71 }
72
73 let mut request = ws_url
74 .as_str()
75 .into_client_request()
76 .map_err(|e| SurrealError::Connection(format!("Invalid WebSocket URL: {}", e)))?;
77
78 request
79 .headers_mut()
80 .insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("cbor"));
81
82 let (stream, _response) = connect_async(request).await.map_err(|e| {
83 SurrealError::Connection(format!("Failed to connect to WebSocket: {}", e))
84 })?;
85
86 let (sink, stream) = stream.split();
87
88 let mut engine = Self {
89 sink: Arc::new(Mutex::new(sink)),
90 stream: Arc::new(Mutex::new(stream)),
91 msg_id: AtomicU64::new(0),
92 pending_requests: Arc::new(Mutex::new(HashMap::new())),
93 task_handle: None,
94 };
95
96 let task_handle = engine.handle_messages();
97 engine.task_handle = Some(task_handle);
98
99 connect.init_engine(&mut engine).await?;
100
101 Ok(engine)
102 }
103
104 fn handle_messages(&self) -> tokio::task::JoinHandle<()> {
105 let stream = Arc::clone(&self.stream);
106 let pending_requests = Arc::clone(&self.pending_requests);
107
108 tokio::spawn(
109 async move {
110 loop {
111 let msg = {
112 let mut stream_guard = stream.lock().await;
113 stream_guard.next().await
114 };
115
116 let msg = match msg {
117 None => break,
118 Some(Err(e)) => {
119 warn!(error = %e, "CBOR ws receive error");
120 break;
121 }
122 Some(Ok(msg)) => msg,
123 };
124
125 match msg {
126 Message::Text(_text) => {
127 }
129 Message::Binary(binary) => {
130 match ciborium::from_reader(binary.as_ref()) {
132 Ok(cbor_response) => {
133 if let CborValue::Map(map) = cbor_response {
134 let mut id_str = None;
135 let mut result = None;
136 let mut error = None;
137
138 for (key, value) in &map {
139 if let CborValue::Text(k) = key {
140 match k.as_str() {
141 "id" => {
142 if let CborValue::Text(id) = value {
143 id_str = Some(id.clone());
144 }
145 }
146 "result" => result = Some(value.clone()),
147 "error" => error = Some(value.clone()),
148 _ => {}
149 }
150 }
151 }
152
153 if let Some(id) = id_str {
154 let tx = {
155 let mut pending = pending_requests.lock().await;
156 pending.remove(&id)
157 };
158
159 if let Some(tx) = tx {
160 if let Some(err) = error {
161 let _ = tx.send(CborValue::Map(vec![(
162 CborValue::Text("error".to_string()),
163 err,
164 )]));
165 } else if let Some(res) = result {
166 let _ = tx.send(res);
167 } else {
168 let _ = tx.send(CborValue::Null);
169 }
170 }
171 }
172 }
173 }
174 Err(e) => {
175 warn!(error = %e, bytes = binary.len(), "CBOR parse failed");
176 }
177 }
178 }
179 Message::Ping(_) => {}
180 Message::Pong(_) => {}
181 Message::Close(_) => break,
182 _ => {}
183 }
184 }
185 }
186 .in_current_span(),
187 )
188 }
189}
190
191#[async_trait]
192impl Engine for WsCborEngine {
193 async fn send_message_cbor(&mut self, method: &str, params: CborValue) -> Result<CborValue> {
194 let (tx, rx) = oneshot::channel();
195 let id = self.msg_id.fetch_add(1, SeqCst).to_string();
196
197 {
198 let mut pending = self.pending_requests.lock().await;
199 pending.insert(id.clone(), tx);
200 }
201
202 let request = RouterRequest {
203 id: id.clone(),
204 method: method.to_string(),
205 params: Some(params),
206 };
207
208 let mut request_map = vec![
209 (
210 CborValue::Text("id".to_string()),
211 CborValue::Text(request.id),
212 ),
213 (
214 CborValue::Text("method".to_string()),
215 CborValue::Text(request.method),
216 ),
217 ];
218
219 if let Some(params) = request.params {
220 request_map.push((CborValue::Text("params".to_string()), params));
221 }
222
223 let rpc_message = CborValue::Map(request_map);
224
225 let mut payload = Vec::new();
226 ciborium::into_writer(&rpc_message, &mut payload)
227 .map_err(|e| SurrealError::Protocol(format!("CBOR encoding failed: {}", e)))?;
228
229 {
230 let mut sink = self.sink.lock().await;
231 sink.send(Message::Binary(payload.into()))
232 .await
233 .map_err(|e| SurrealError::Connection(format!("WS send failed: {}", e)))?;
234 }
235
236 let response = rx
237 .await
238 .map_err(|_| SurrealError::Protocol("Response channel closed".to_string()))?;
239
240 if let CborValue::Map(map) = &response {
241 for (key, value) in map {
242 if let CborValue::Text(k) = key
243 && k == "error"
244 {
245 if let CborValue::Map(error_map) = value {
246 let mut code = -1;
247 let mut message = String::new();
248
249 for (error_key, error_value) in error_map {
250 if let CborValue::Text(error_k) = error_key {
251 match error_k.as_str() {
252 "code" => {
253 if let CborValue::Integer(c) = error_value {
254 code = (*c).try_into().unwrap_or(-1);
255 }
256 }
257 "message" => {
258 if let CborValue::Text(m) = error_value {
259 message = m.clone();
260 }
261 }
262 _ => {}
263 }
264 }
265 }
266
267 if !message.is_empty() {
268 return Err(SurrealError::ServerError { code, message });
269 }
270 }
271
272 return Err(SurrealError::Protocol(format!("Server error: {:?}", value)));
273 }
274 }
275 }
276
277 Ok(response)
278 }
279}
280
281impl Drop for WsCborEngine {
282 fn drop(&mut self) {
283 if let Some(handle) = self.task_handle.take() {
284 handle.abort();
285 }
286 }
287}