Skip to main content

surreal_client/engines/
ws_cbor.rs

1//! WebSocket CBOR Engine for SurrealDB
2//!
3//! Connects with the `cbor` WebSocket subprotocol and exchanges CBOR-encoded
4//! binary frames. Preserves native types (datetime, duration, recordid, bytes)
5//! that JSON cannot carry. Accepts `ws://`, `wss://`, or `cbor://` URLs;
6//! `cbor://` is treated as `ws://`.
7
8use async_trait::async_trait;
9use ciborium::Value as CborValue;
10use futures_util::stream::{SplitSink, SplitStream};
11use futures_util::{SinkExt, StreamExt};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicU64, Ordering::SeqCst};
15use tokio::net::TcpStream;
16use tokio::sync::{Mutex, oneshot};
17use tokio_tungstenite::MaybeTlsStream;
18use tokio_tungstenite::tungstenite::client::IntoClientRequest;
19use tokio_tungstenite::tungstenite::http::HeaderValue;
20use tokio_tungstenite::tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL;
21use tokio_tungstenite::{WebSocketStream, connect_async, tungstenite::Message};
22use tracing::{Instrument as _, warn};
23
24use crate::SurrealConnection;
25use crate::{
26    engine::Engine,
27    error::{Result, SurrealError},
28};
29
30/// Request structure for CBOR WebSocket protocol
31#[derive(Debug, Clone)]
32struct RouterRequest {
33    id: String,
34    method: String,
35    params: Option<CborValue>,
36}
37
38type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
39
40/// WebSocket engine using native CBOR for SurrealDB.
41///
42/// Connects with `Sec-WebSocket-Protocol: cbor` and sends/receives binary
43/// CBOR frames. Authentication and `use ns/db` are performed during
44/// `from_connection` so the returned engine is ready to issue queries.
45pub struct WsCborEngine {
46    sink: Arc<Mutex<SplitSink<WsStream, Message>>>,
47    stream: Arc<Mutex<SplitStream<WsStream>>>,
48    msg_id: AtomicU64,
49    pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<CborValue>>>>,
50    task_handle: Option<tokio::task::JoinHandle<()>>,
51}
52
53impl WsCborEngine {
54    pub async fn from_connection(connect: &SurrealConnection) -> Result<Self> {
55        let base_url = connect
56            .url
57            .as_ref()
58            .ok_or_else(|| SurrealError::Connection("URL is required to connect".to_string()))?;
59
60        let mut ws_url = if let Some(rest) = base_url.strip_prefix("cbor://") {
61            format!("ws://{}", rest)
62        } else {
63            base_url.clone()
64        };
65        if !ws_url.ends_with("/rpc") {
66            if ws_url.ends_with('/') {
67                ws_url.push_str("rpc");
68            } else {
69                ws_url.push_str("/rpc");
70            }
71        }
72
73        let mut request = ws_url
74            .as_str()
75            .into_client_request()
76            .map_err(|e| SurrealError::Connection(format!("Invalid WebSocket URL: {}", e)))?;
77
78        request
79            .headers_mut()
80            .insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("cbor"));
81
82        let (stream, _response) = connect_async(request).await.map_err(|e| {
83            SurrealError::Connection(format!("Failed to connect to WebSocket: {}", e))
84        })?;
85
86        let (sink, stream) = stream.split();
87
88        let mut engine = Self {
89            sink: Arc::new(Mutex::new(sink)),
90            stream: Arc::new(Mutex::new(stream)),
91            msg_id: AtomicU64::new(0),
92            pending_requests: Arc::new(Mutex::new(HashMap::new())),
93            task_handle: None,
94        };
95
96        let task_handle = engine.handle_messages();
97        engine.task_handle = Some(task_handle);
98
99        connect.init_engine(&mut engine).await?;
100
101        Ok(engine)
102    }
103
104    fn handle_messages(&self) -> tokio::task::JoinHandle<()> {
105        let stream = Arc::clone(&self.stream);
106        let pending_requests = Arc::clone(&self.pending_requests);
107
108        tokio::spawn(
109            async move {
110                loop {
111                    let msg = {
112                        let mut stream_guard = stream.lock().await;
113                        stream_guard.next().await
114                    };
115
116                    let msg = match msg {
117                        None => break,
118                        Some(Err(e)) => {
119                            warn!(error = %e, "CBOR ws receive error");
120                            break;
121                        }
122                        Some(Ok(msg)) => msg,
123                    };
124
125                    match msg {
126                        Message::Text(_text) => {
127                            // Ignore text messages - we only use CBOR binary
128                        }
129                        Message::Binary(binary) => {
130                            // Parse CBOR response: {id, result} or {id, error}
131                            match ciborium::from_reader(binary.as_ref()) {
132                                Ok(cbor_response) => {
133                                    if let CborValue::Map(map) = cbor_response {
134                                        let mut id_str = None;
135                                        let mut result = None;
136                                        let mut error = None;
137
138                                        for (key, value) in &map {
139                                            if let CborValue::Text(k) = key {
140                                                match k.as_str() {
141                                                    "id" => {
142                                                        if let CborValue::Text(id) = value {
143                                                            id_str = Some(id.clone());
144                                                        }
145                                                    }
146                                                    "result" => result = Some(value.clone()),
147                                                    "error" => error = Some(value.clone()),
148                                                    _ => {}
149                                                }
150                                            }
151                                        }
152
153                                        if let Some(id) = id_str {
154                                            let tx = {
155                                                let mut pending = pending_requests.lock().await;
156                                                pending.remove(&id)
157                                            };
158
159                                            if let Some(tx) = tx {
160                                                if let Some(err) = error {
161                                                    let _ = tx.send(CborValue::Map(vec![(
162                                                        CborValue::Text("error".to_string()),
163                                                        err,
164                                                    )]));
165                                                } else if let Some(res) = result {
166                                                    let _ = tx.send(res);
167                                                } else {
168                                                    let _ = tx.send(CborValue::Null);
169                                                }
170                                            }
171                                        }
172                                    }
173                                }
174                                Err(e) => {
175                                    warn!(error = %e, bytes = binary.len(), "CBOR parse failed");
176                                }
177                            }
178                        }
179                        Message::Ping(_) => {}
180                        Message::Pong(_) => {}
181                        Message::Close(_) => break,
182                        _ => {}
183                    }
184                }
185            }
186            .in_current_span(),
187        )
188    }
189}
190
191#[async_trait]
192impl Engine for WsCborEngine {
193    async fn send_message_cbor(&mut self, method: &str, params: CborValue) -> Result<CborValue> {
194        let (tx, rx) = oneshot::channel();
195        let id = self.msg_id.fetch_add(1, SeqCst).to_string();
196
197        {
198            let mut pending = self.pending_requests.lock().await;
199            pending.insert(id.clone(), tx);
200        }
201
202        let request = RouterRequest {
203            id: id.clone(),
204            method: method.to_string(),
205            params: Some(params),
206        };
207
208        let mut request_map = vec![
209            (
210                CborValue::Text("id".to_string()),
211                CborValue::Text(request.id),
212            ),
213            (
214                CborValue::Text("method".to_string()),
215                CborValue::Text(request.method),
216            ),
217        ];
218
219        if let Some(params) = request.params {
220            request_map.push((CborValue::Text("params".to_string()), params));
221        }
222
223        let rpc_message = CborValue::Map(request_map);
224
225        let mut payload = Vec::new();
226        ciborium::into_writer(&rpc_message, &mut payload)
227            .map_err(|e| SurrealError::Protocol(format!("CBOR encoding failed: {}", e)))?;
228
229        {
230            let mut sink = self.sink.lock().await;
231            sink.send(Message::Binary(payload.into()))
232                .await
233                .map_err(|e| SurrealError::Connection(format!("WS send failed: {}", e)))?;
234        }
235
236        let response = rx
237            .await
238            .map_err(|_| SurrealError::Protocol("Response channel closed".to_string()))?;
239
240        if let CborValue::Map(map) = &response {
241            for (key, value) in map {
242                if let CborValue::Text(k) = key
243                    && k == "error"
244                {
245                    if let CborValue::Map(error_map) = value {
246                        let mut code = -1;
247                        let mut message = String::new();
248
249                        for (error_key, error_value) in error_map {
250                            if let CborValue::Text(error_k) = error_key {
251                                match error_k.as_str() {
252                                    "code" => {
253                                        if let CborValue::Integer(c) = error_value {
254                                            code = (*c).try_into().unwrap_or(-1);
255                                        }
256                                    }
257                                    "message" => {
258                                        if let CborValue::Text(m) = error_value {
259                                            message = m.clone();
260                                        }
261                                    }
262                                    _ => {}
263                                }
264                            }
265                        }
266
267                        if !message.is_empty() {
268                            return Err(SurrealError::ServerError { code, message });
269                        }
270                    }
271
272                    return Err(SurrealError::Protocol(format!("Server error: {:?}", value)));
273                }
274            }
275        }
276
277        Ok(response)
278    }
279}
280
281impl Drop for WsCborEngine {
282    fn drop(&mut self) {
283        if let Some(handle) = self.task_handle.take() {
284            handle.abort();
285        }
286    }
287}