1use super::PATH;
2use super::{deserialize, serialize};
3use crate::api::conn::Connection;
4use crate::api::conn::DbResponse;
5use crate::api::conn::Method;
6use crate::api::conn::Param;
7use crate::api::conn::Route;
8use crate::api::conn::Router;
9use crate::api::engine::remote::ws::Client;
10use crate::api::engine::remote::ws::Response;
11use crate::api::engine::remote::ws::PING_INTERVAL;
12use crate::api::engine::remote::ws::PING_METHOD;
13use crate::api::err::Error;
14use crate::api::opt::Endpoint;
15#[cfg(any(feature = "native-tls", feature = "rustls"))]
16use crate::api::opt::Tls;
17use crate::api::ExtraFeatures;
18use crate::api::OnceLockExt;
19use crate::api::Result;
20use crate::api::Surreal;
21use crate::engine::remote::ws::Data;
22use crate::engine::IntervalStream;
23use crate::opt::WaitFor;
24use crate::sql::Value;
25use flume::Receiver;
26use futures::stream::SplitSink;
27use futures::SinkExt;
28use futures::StreamExt;
29use futures_concurrency::stream::Merge as _;
30use indexmap::IndexMap;
31use revision::revisioned;
32use serde::Deserialize;
33use std::collections::hash_map::Entry;
34use std::collections::BTreeMap;
35use std::collections::HashMap;
36use std::collections::HashSet;
37use std::future::Future;
38use std::mem;
39use std::pin::Pin;
40use std::sync::atomic::AtomicI64;
41use std::sync::Arc;
42use std::sync::OnceLock;
43use tokio::net::TcpStream;
44use tokio::sync::watch;
45use tokio::time;
46use tokio::time::MissedTickBehavior;
47use tokio_tungstenite::tungstenite::client::IntoClientRequest;
48use tokio_tungstenite::tungstenite::error::Error as WsError;
49use tokio_tungstenite::tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL;
50use tokio_tungstenite::tungstenite::http::HeaderValue;
51use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
52use tokio_tungstenite::tungstenite::Message;
53use tokio_tungstenite::Connector;
54use tokio_tungstenite::MaybeTlsStream;
55use tokio_tungstenite::WebSocketStream;
56use trice::Instant;
57
58type WsResult<T> = std::result::Result<T, WsError>;
59
/// Largest WebSocket message accepted from the server (64 MiB).
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;

/// Largest single WebSocket frame accepted from the server (16 MiB).
pub(crate) const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;

/// Baseline write-buffer size in bytes, used to size the write-buffer ceiling below.
pub(crate) const WRITE_BUFFER_SIZE: usize = 128_000;

/// Ceiling for the outgoing write buffer: the baseline buffer plus room for one
/// maximum-size message.
pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = MAX_MESSAGE_SIZE + WRITE_BUFFER_SIZE;

/// Forwarded as tokio-tungstenite's `disable_nagle` connect argument. `false`
/// leaves Nagle's algorithm enabled on the TCP socket.
pub(crate) const NAGLE_ALG: bool = false;
65
66pub(crate) enum Either {
67 Request(Option<Route>),
68 Response(WsResult<Message>),
69 Ping,
70}
71
#[cfg(any(feature = "native-tls", feature = "rustls"))]
impl From<Tls> for Connector {
    /// Maps the SDK's TLS configuration onto the matching tokio-tungstenite
    /// connector. A rustls client config is moved into an `Arc`, which is what
    /// `Connector::Rustls` holds.
    fn from(tls: Tls) -> Self {
        match tls {
            #[cfg(feature = "native-tls")]
            Tls::Native(native) => Connector::NativeTls(native),
            #[cfg(feature = "rustls")]
            Tls::Rust(rustls) => Connector::Rustls(Arc::new(rustls)),
        }
    }
}
83
84pub(crate) async fn connect(
85 endpoint: &Endpoint,
86 config: Option<WebSocketConfig>,
87 #[allow(unused_variables)] maybe_connector: Option<Connector>,
88) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
89 let mut request = (&endpoint.url).into_client_request()?;
90
91 if endpoint.supports_revision {
92 request
93 .headers_mut()
94 .insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static(super::REVISION_HEADER));
95 }
96
97 #[cfg(any(feature = "native-tls", feature = "rustls"))]
98 let (socket, _) = tokio_tungstenite::connect_async_tls_with_config(
99 request,
100 config,
101 NAGLE_ALG,
102 maybe_connector,
103 )
104 .await?;
105
106 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
107 let (socket, _) = tokio_tungstenite::connect_async_with_config(request, config, NAGLE_ALG).await?;
108
109 Ok(socket)
110}
111
112impl crate::api::Connection for Client {}
113
114impl Connection for Client {
115 fn new(method: Method) -> Self {
116 Self {
117 id: 0,
118 method,
119 }
120 }
121
122 fn connect(
123 mut address: Endpoint,
124 capacity: usize,
125 ) -> Pin<Box<dyn Future<Output = Result<Surreal<Self>>> + Send + Sync + 'static>> {
126 Box::pin(async move {
127 address.url = address.url.join(PATH)?;
128 #[cfg(any(feature = "native-tls", feature = "rustls"))]
129 let maybe_connector = address.config.tls_config.clone().map(Connector::from);
130 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
131 let maybe_connector = None;
132
133 let config = WebSocketConfig {
134 max_message_size: Some(MAX_MESSAGE_SIZE),
135 max_frame_size: Some(MAX_FRAME_SIZE),
136 max_write_buffer_size: MAX_WRITE_BUFFER_SIZE,
137 ..Default::default()
138 };
139
140 let socket = connect(&address, Some(config), maybe_connector.clone()).await?;
141
142 let (route_tx, route_rx) = match capacity {
143 0 => flume::unbounded(),
144 capacity => flume::bounded(capacity),
145 };
146
147 router(address, maybe_connector, capacity, config, socket, route_rx);
148
149 let mut features = HashSet::new();
150 features.insert(ExtraFeatures::LiveQueries);
151
152 Ok(Surreal::new_from_router_waiter(
153 Arc::new(OnceLock::with_value(Router {
154 features,
155 sender: route_tx,
156 last_id: AtomicI64::new(0),
157 })),
158 Arc::new(watch::channel(Some(WaitFor::Connection))),
159 ))
160 })
161 }
162
163 fn send<'r>(
164 &'r mut self,
165 router: &'r Router,
166 param: Param,
167 ) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
168 Box::pin(async move {
169 self.id = router.next_id();
170 let (sender, receiver) = flume::bounded(1);
171 let route = Route {
172 request: (self.id, self.method, param),
173 response: sender,
174 };
175 router.sender.send_async(Some(route)).await?;
176 Ok(receiver)
177 })
178 }
179}
180
181#[allow(clippy::too_many_lines)]
182pub(crate) fn router(
183 endpoint: Endpoint,
184 maybe_connector: Option<Connector>,
185 capacity: usize,
186 config: WebSocketConfig,
187 mut socket: WebSocketStream<MaybeTlsStream<TcpStream>>,
188 route_rx: Receiver<Option<Route>>,
189) {
190 tokio::spawn(async move {
191 let ping = {
192 let mut request = BTreeMap::new();
193 request.insert("method".to_owned(), PING_METHOD.into());
194 let value = Value::from(request);
195 let value = serialize(&value, endpoint.supports_revision).unwrap();
196 Message::Binary(value)
197 };
198
199 let mut var_stash = IndexMap::new();
200 let mut vars = IndexMap::new();
201 let mut replay = IndexMap::new();
202
203 'router: loop {
204 let (socket_sink, socket_stream) = socket.split();
205 let mut socket_sink = Socket(Some(socket_sink));
206
207 if let Socket(Some(socket_sink)) = &mut socket_sink {
208 let mut routes = match capacity {
209 0 => HashMap::new(),
210 capacity => HashMap::with_capacity(capacity),
211 };
212 let mut live_queries = HashMap::new();
213
214 let mut interval = time::interval(PING_INTERVAL);
215 interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
217
218 let pinger = IntervalStream::new(interval);
219
220 let streams = (
221 socket_stream.map(Either::Response),
222 route_rx.stream().map(Either::Request),
223 pinger.map(|_| Either::Ping),
224 );
225
226 let mut merged = streams.merge();
227 let mut last_activity = Instant::now();
228
229 while let Some(either) = merged.next().await {
230 match either {
231 Either::Request(Some(Route {
232 request,
233 response,
234 })) => {
235 let (id, method, param) = request;
236 let params = match param.query {
237 Some((query, bindings)) => {
238 vec![query.into(), bindings.into()]
239 }
240 None => param.other,
241 };
242 match method {
243 Method::Set => {
244 if let [Value::Strand(key), value] = ¶ms[..2] {
245 var_stash.insert(id, (key.0.clone(), value.clone()));
246 }
247 }
248 Method::Unset => {
249 if let [Value::Strand(key)] = ¶ms[..1] {
250 vars.swap_remove(&key.0);
251 }
252 }
253 Method::Live => {
254 if let Some(sender) = param.notification_sender {
255 if let [Value::Uuid(id)] = ¶ms[..1] {
256 live_queries.insert(*id, sender);
257 }
258 }
259 if response
260 .into_send_async(Ok(DbResponse::Other(Value::None)))
261 .await
262 .is_err()
263 {
264 trace!("Receiver dropped");
265 }
266 continue;
268 }
269 Method::Kill => {
270 if let [Value::Uuid(id)] = ¶ms[..1] {
271 live_queries.remove(id);
272 }
273 }
274 _ => {}
275 }
276 let method_str = match method {
277 Method::Health => PING_METHOD,
278 _ => method.as_str(),
279 };
280 let message = {
281 let mut request = BTreeMap::new();
282 request.insert("id".to_owned(), Value::from(id));
283 request.insert("method".to_owned(), method_str.into());
284 if !params.is_empty() {
285 request.insert("params".to_owned(), params.into());
286 }
287 let payload = Value::from(request);
288 trace!("Request {payload}");
289 let payload =
290 serialize(&payload, endpoint.supports_revision).unwrap();
291 Message::Binary(payload)
292 };
293 if let Method::Authenticate
294 | Method::Invalidate
295 | Method::Signin
296 | Method::Signup
297 | Method::Use = method
298 {
299 replay.insert(method, message.clone());
300 }
301 match socket_sink.send(message).await {
302 Ok(..) => {
303 last_activity = Instant::now();
304 match routes.entry(id) {
305 Entry::Vacant(entry) => {
306 entry.insert((method, response));
308 }
309 Entry::Occupied(..) => {
310 let error = Error::DuplicateRequestId(id);
311 if response
312 .into_send_async(Err(error.into()))
313 .await
314 .is_err()
315 {
316 trace!("Receiver dropped");
317 }
318 }
319 }
320 }
321 Err(error) => {
322 let error = Error::Ws(error.to_string());
323 if response.into_send_async(Err(error.into())).await.is_err() {
324 trace!("Receiver dropped");
325 }
326 break;
327 }
328 }
329 }
330 Either::Response(result) => {
331 last_activity = Instant::now();
332 match result {
333 Ok(message) => {
334 match Response::try_from(&message, endpoint.supports_revision) {
335 Ok(option) => {
336 if let Some(response) = option {
338 trace!("{response:?}");
339 match response.id {
340 Some(id) => {
342 if let Ok(id) = id.coerce_to_i64() {
343 if let Some((method, sender)) =
345 routes.remove(&id)
346 {
347 if matches!(method, Method::Set) {
348 if let Some((key, value)) =
349 var_stash.swap_remove(&id)
350 {
351 vars.insert(key, value);
352 }
353 }
354 let mut response = response.result;
356 if matches!(method, Method::Insert)
357 {
358 if let Ok(Data::Other(
360 Value::Array(value),
361 )) = &mut response
362 {
363 if let [value] =
364 &mut value.0[..]
365 {
366 response =
367 Ok(Data::Other(
368 mem::take(
369 value,
370 ),
371 ));
372 }
373 }
374 }
375 let _res = sender
376 .into_send_async(
377 DbResponse::from(response),
378 )
379 .await;
380 }
381 }
382 }
383 None => match response.result {
385 Ok(Data::Live(notification)) => {
386 let live_query_id = notification.id;
387 if let Some(sender) =
389 live_queries.get(&live_query_id)
390 {
391 if sender
393 .send(notification)
394 .await
395 .is_err()
396 {
397 live_queries
398 .remove(&live_query_id);
399 let kill = {
400 let mut request =
401 BTreeMap::new();
402 request.insert(
403 "method".to_owned(),
404 Method::Kill
405 .as_str()
406 .into(),
407 );
408 request.insert(
409 "params".to_owned(),
410 vec![Value::from(
411 live_query_id,
412 )]
413 .into(),
414 );
415 let value =
416 Value::from(request);
417 let value = serialize(
418 &value,
419 endpoint
420 .supports_revision,
421 )
422 .unwrap();
423 Message::Binary(value)
424 };
425 if let Err(error) =
426 socket_sink.send(kill).await
427 {
428 trace!("failed to send kill query to the server; {error:?}");
429 break;
430 }
431 }
432 }
433 }
434 Ok(..) => { }
436 Err(error) => error!("{error:?}"),
437 },
438 }
439 }
440 }
441 Err(error) => {
442 #[revisioned(revision = 1)]
443 #[derive(Deserialize)]
444 struct Response {
445 id: Option<Value>,
446 }
447
448 if let Message::Binary(binary) = message {
450 if let Ok(Response {
451 id,
452 }) = deserialize(
453 &mut &binary[..],
454 endpoint.supports_revision,
455 ) {
456 if let Some(Ok(id)) =
458 id.map(Value::coerce_to_i64)
459 {
460 if let Some((_method, sender)) =
461 routes.remove(&id)
462 {
463 let _res = sender
464 .into_send_async(Err(error))
465 .await;
466 }
467 }
468 } else {
469 warn!(
471 "Failed to deserialise message; {error:?}"
472 );
473 }
474 }
475 }
476 }
477 }
478 Err(error) => {
479 match error {
480 WsError::ConnectionClosed => {
481 trace!("Connection successfully closed on the server");
482 }
483 error => {
484 trace!("{error}");
485 }
486 }
487 break;
488 }
489 }
490 }
491 Either::Ping => {
492 if last_activity.elapsed() >= PING_INTERVAL {
494 trace!("Pinging the server");
495 if let Err(error) = socket_sink.send(ping.clone()).await {
496 trace!("failed to ping the server; {error:?}");
497 break;
498 }
499 }
500 }
501 Either::Request(None) => {
503 match socket_sink.send(Message::Close(None)).await {
504 Ok(..) => trace!("Connection closed successfully"),
505 Err(error) => {
506 warn!("Failed to close database connection; {error}")
507 }
508 }
509 break 'router;
510 }
511 }
512 }
513 }
514
515 'reconnect: loop {
516 trace!("Reconnecting...");
517 match connect(&endpoint, Some(config), maybe_connector.clone()).await {
518 Ok(s) => {
519 socket = s;
520 for (_, message) in &replay {
521 if let Err(error) = socket.send(message.clone()).await {
522 trace!("{error}");
523 time::sleep(time::Duration::from_secs(1)).await;
524 continue 'reconnect;
525 }
526 }
527 for (key, value) in &vars {
528 let mut request = BTreeMap::new();
529 request.insert("method".to_owned(), Method::Set.as_str().into());
530 request.insert(
531 "params".to_owned(),
532 vec![key.as_str().into(), value.clone()].into(),
533 );
534 let payload = Value::from(request);
535 trace!("Request {payload}");
536 if let Err(error) = socket.send(Message::Binary(payload.into())).await {
537 trace!("{error}");
538 time::sleep(time::Duration::from_secs(1)).await;
539 continue 'reconnect;
540 }
541 }
542 trace!("Reconnected successfully");
543 break;
544 }
545 Err(error) => {
546 trace!("Failed to reconnect; {error}");
547 time::sleep(time::Duration::from_secs(1)).await;
548 }
549 }
550 }
551 }
552 });
553}
554
555impl Response {
556 fn try_from(message: &Message, supports_revision: bool) -> Result<Option<Self>> {
557 match message {
558 Message::Text(text) => {
559 trace!("Received an unexpected text message; {text}");
560 Ok(None)
561 }
562 Message::Binary(binary) => {
563 deserialize(&mut &binary[..], supports_revision).map(Some).map_err(|error| {
564 Error::ResponseFromBinary {
565 binary: binary.clone(),
566 error: bincode::ErrorKind::Custom(error.to_string()).into(),
567 }
568 .into()
569 })
570 }
571 Message::Ping(..) => {
572 trace!("Received a ping from the server");
573 Ok(None)
574 }
575 Message::Pong(..) => {
576 trace!("Received a pong from the server");
577 Ok(None)
578 }
579 Message::Frame(..) => {
580 trace!("Received an unexpected raw frame");
581 Ok(None)
582 }
583 Message::Close(..) => {
584 trace!("Received an unexpected close message");
585 Ok(None)
586 }
587 }
588 }
589}
590
591pub struct Socket(Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>);
592
#[cfg(test)]
mod tests {
    use super::serialize;
    use bincode::Options;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use rand::{thread_rng, Rng};
    use std::io::Write;
    use std::time::SystemTime;
    use surrealdb_core::rpc::format::cbor::Cbor;
    use surrealdb_core::sql::{Array, Value};

    /// One benchmark row: payload size in bytes, label, elapsed time, and the
    /// size ratio against the matching `Vec<i32>` bincode reference.
    type Row = (usize, &'static str, std::time::Duration, f32);

    /// Runs `encode` once and returns its wall-clock duration along with its output.
    fn timed(encode: &dyn Fn() -> Vec<u8>) -> (std::time::Duration, Vec<u8>) {
        let started = SystemTime::now();
        let output = encode();
        (started.elapsed().unwrap(), output)
    }

    /// Gzip-compresses `bytes` at the default compression level.
    fn gzip(bytes: &[u8]) -> Vec<u8> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(bytes).unwrap();
        encoder.finish().unwrap()
    }

    /// Times `encode`, then times gzip on its output, and appends two rows: the
    /// raw payload (`labels.0`), then the compressed payload (`labels.1`).
    /// The compressed row's duration includes the encoding time.
    ///
    /// `reference` holds the (raw, compressed) sizes to divide by. `None` marks
    /// the reference run itself, whose factors are recorded as `1.0`. Returns this
    /// run's (raw, compressed) sizes as `f32`.
    fn measure(
        rows: &mut Vec<Row>,
        labels: (&'static str, &'static str),
        reference: Option<(f32, f32)>,
        encode: &dyn Fn() -> Vec<u8>,
    ) -> (f32, f32) {
        let (encode_time, raw) = timed(encode);
        let raw_size = raw.len() as f32;
        let (gzip_time, packed) = timed(&|| gzip(&raw));
        let packed_size = packed.len() as f32;
        let (raw_factor, packed_factor) = match reference {
            Some((raw_ref, packed_ref)) => (raw_size / raw_ref, packed_size / packed_ref),
            None => (1.0, 1.0),
        };
        rows.push((raw.len(), labels.0, encode_time, raw_factor));
        rows.push((packed.len(), labels.1, encode_time + gzip_time, packed_factor));
        (raw_size, packed_size)
    }

    #[test_log::test]
    fn large_vector_serialisation_bench() {
        const BINCODE_REF: &str = "Bincode Vec<i32>";
        const COMPRESSED_BINCODE_REF: &str = "Compressed Bincode Vec<i32>";
        const BINCODE: &str = "Bincode Vec<Value>";
        const COMPRESSED_BINCODE: &str = "Compressed Bincode Vec<Value>";
        const UNVERSIONED: &str = "Unversioned Vec<Value>";
        const COMPRESSED_UNVERSIONED: &str = "Compressed Unversioned Vec<Value>";
        const VERSIONED: &str = "Versioned Vec<Value>";
        const COMPRESSED_VERSIONED: &str = "Compressed Versioned Vec<Value>";
        const CBOR: &str = "CBor Vec<Value>";
        const COMPRESSED_CBOR: &str = "Compressed CBor Vec<Value>";

        // A smaller input keeps debug builds tolerably fast.
        let vector_size = if cfg!(debug_assertions) {
            200_000
        } else {
            2_000_000
        };
        let mut rng = thread_rng();
        let vector: Vec<i32> = (0..vector_size).map(|_| rng.gen()).collect();

        let mut rows: Vec<Row> = Vec::new();

        // Reference: plain fixed-int bincode of the raw integers.
        let reference = Some(measure(
            &mut rows,
            (BINCODE_REF, COMPRESSED_BINCODE_REF),
            None,
            &|| {
                let mut payload = Vec::new();
                bincode::options()
                    .with_fixint_encoding()
                    .serialize_into(&mut payload, &vector)
                    .unwrap();
                payload
            },
        ));

        let vector = Value::Array(Array::from(vector));

        measure(&mut rows, (BINCODE, COMPRESSED_BINCODE), reference, &|| {
            let mut payload = Vec::new();
            bincode::options()
                .with_varint_encoding()
                .serialize_into(&mut payload, &vector)
                .unwrap();
            payload
        });
        measure(&mut rows, (UNVERSIONED, COMPRESSED_UNVERSIONED), reference, &|| {
            serialize(&vector, false).unwrap()
        });
        measure(&mut rows, (VERSIONED, COMPRESSED_VERSIONED), reference, &|| {
            serialize(&vector, true).unwrap()
        });
        measure(&mut rows, (CBOR, COMPRESSED_CBOR), reference, &|| {
            let cbor: Cbor = vector.clone().try_into().unwrap();
            let mut encoded = Vec::new();
            ciborium::into_writer(&cbor.0, &mut encoded).unwrap();
            encoded
        });

        // Stable sort: rows of equal size keep their insertion order.
        rows.sort_by_key(|row| row.0);
        for (size, name, duration, factor) in &rows {
            info!("{name} - Size: {size} - Duration: {duration:?} - Factor: {factor}");
        }
        let order: Vec<&str> = rows.into_iter().map(|row| row.1).collect();
        assert_eq!(
            order,
            vec![
                BINCODE_REF,
                COMPRESSED_BINCODE_REF,
                COMPRESSED_CBOR,
                COMPRESSED_BINCODE,
                COMPRESSED_UNVERSIONED,
                CBOR,
                COMPRESSED_VERSIONED,
                BINCODE,
                UNVERSIONED,
                VERSIONED,
            ]
        )
    }
}