surrealdb/api/engine/remote/ws/
native.rs

1use super::PATH;
2use super::{deserialize, serialize};
3use crate::api::conn::Connection;
4use crate::api::conn::DbResponse;
5use crate::api::conn::Method;
6use crate::api::conn::Param;
7use crate::api::conn::Route;
8use crate::api::conn::Router;
9use crate::api::engine::remote::ws::Client;
10use crate::api::engine::remote::ws::Response;
11use crate::api::engine::remote::ws::PING_INTERVAL;
12use crate::api::engine::remote::ws::PING_METHOD;
13use crate::api::err::Error;
14use crate::api::opt::Endpoint;
15#[cfg(any(feature = "native-tls", feature = "rustls"))]
16use crate::api::opt::Tls;
17use crate::api::ExtraFeatures;
18use crate::api::OnceLockExt;
19use crate::api::Result;
20use crate::api::Surreal;
21use crate::engine::remote::ws::Data;
22use crate::engine::IntervalStream;
23use crate::opt::WaitFor;
24use crate::sql::Value;
25use flume::Receiver;
26use futures::stream::SplitSink;
27use futures::SinkExt;
28use futures::StreamExt;
29use futures_concurrency::stream::Merge as _;
30use indexmap::IndexMap;
31use revision::revisioned;
32use serde::Deserialize;
33use std::collections::hash_map::Entry;
34use std::collections::BTreeMap;
35use std::collections::HashMap;
36use std::collections::HashSet;
37use std::future::Future;
38use std::mem;
39use std::pin::Pin;
40use std::sync::atomic::AtomicI64;
41use std::sync::Arc;
42use std::sync::OnceLock;
43use tokio::net::TcpStream;
44use tokio::sync::watch;
45use tokio::time;
46use tokio::time::MissedTickBehavior;
47use tokio_tungstenite::tungstenite::client::IntoClientRequest;
48use tokio_tungstenite::tungstenite::error::Error as WsError;
49use tokio_tungstenite::tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL;
50use tokio_tungstenite::tungstenite::http::HeaderValue;
51use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
52use tokio_tungstenite::tungstenite::Message;
53use tokio_tungstenite::Connector;
54use tokio_tungstenite::MaybeTlsStream;
55use tokio_tungstenite::WebSocketStream;
56use trice::Instant;
57
/// Shorthand for results whose error type is tungstenite's `Error`
/// (imported in this file as `WsError`).
type WsResult<T> = std::result::Result<T, WsError>;
59
/// Value used for `WebSocketConfig::max_message_size` when connecting.
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20; // 64 MiB
/// Value used for `WebSocketConfig::max_frame_size` when connecting.
pub(crate) const MAX_FRAME_SIZE: usize = 16 << 20; // 16 MiB
/// Size of tungstenite's write buffer.
///
/// This mirrors the library default of 128 KiB (`128 * 1024` bytes), which is what
/// the connection actually uses because `write_buffer_size` is left at its default.
pub(crate) const WRITE_BUFFER_SIZE: usize = 128 * 1024; // tungstenite default (128 KiB)
/// Value used for `WebSocketConfig::max_write_buffer_size`: the write buffer plus
/// room for one maximum-sized message.
pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = WRITE_BUFFER_SIZE + MAX_MESSAGE_SIZE; // Recommended max according to tungstenite docs
/// Flag forwarded as the third argument of the `tokio_tungstenite` connect functions.
// NOTE(review): that argument appears to be `disable_nagle`, so `false` would leave
// Nagle's algorithm enabled — confirm this is the intended meaning.
pub(crate) const NAGLE_ALG: bool = false;
65
/// Event type produced by the merged stream that the router loop consumes.
pub(crate) enum Either {
	/// An item from the route channel; `None` is handled as a request to close the connection.
	Request(Option<Route>),
	/// An item read from the WebSocket stream: either a message or a transport error.
	Response(WsResult<Message>),
	/// A tick of the ping interval.
	Ping,
}
71
/// Turns the SDK's TLS settings into the connector type taken by `tokio_tungstenite`.
#[cfg(any(feature = "native-tls", feature = "rustls"))]
impl From<Tls> for Connector {
	fn from(settings: Tls) -> Self {
		match settings {
			// A native-tls configuration is passed through unchanged
			#[cfg(feature = "native-tls")]
			Tls::Native(native_config) => Connector::NativeTls(native_config),
			// The rustls variant stores its configuration behind an `Arc`
			#[cfg(feature = "rustls")]
			Tls::Rust(client_config) => Connector::Rustls(Arc::new(client_config)),
		}
	}
}
83
84pub(crate) async fn connect(
85	endpoint: &Endpoint,
86	config: Option<WebSocketConfig>,
87	#[allow(unused_variables)] maybe_connector: Option<Connector>,
88) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
89	let mut request = (&endpoint.url).into_client_request()?;
90
91	if endpoint.supports_revision {
92		request
93			.headers_mut()
94			.insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static(super::REVISION_HEADER));
95	}
96
97	#[cfg(any(feature = "native-tls", feature = "rustls"))]
98	let (socket, _) = tokio_tungstenite::connect_async_tls_with_config(
99		request,
100		config,
101		NAGLE_ALG,
102		maybe_connector,
103	)
104	.await?;
105
106	#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
107	let (socket, _) = tokio_tungstenite::connect_async_with_config(request, config, NAGLE_ALG).await?;
108
109	Ok(socket)
110}
111
112impl crate::api::Connection for Client {}
113
114impl Connection for Client {
115	fn new(method: Method) -> Self {
116		Self {
117			id: 0,
118			method,
119		}
120	}
121
122	fn connect(
123		mut address: Endpoint,
124		capacity: usize,
125	) -> Pin<Box<dyn Future<Output = Result<Surreal<Self>>> + Send + Sync + 'static>> {
126		Box::pin(async move {
127			address.url = address.url.join(PATH)?;
128			#[cfg(any(feature = "native-tls", feature = "rustls"))]
129			let maybe_connector = address.config.tls_config.clone().map(Connector::from);
130			#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
131			let maybe_connector = None;
132
133			let config = WebSocketConfig {
134				max_message_size: Some(MAX_MESSAGE_SIZE),
135				max_frame_size: Some(MAX_FRAME_SIZE),
136				max_write_buffer_size: MAX_WRITE_BUFFER_SIZE,
137				..Default::default()
138			};
139
140			let socket = connect(&address, Some(config), maybe_connector.clone()).await?;
141
142			let (route_tx, route_rx) = match capacity {
143				0 => flume::unbounded(),
144				capacity => flume::bounded(capacity),
145			};
146
147			router(address, maybe_connector, capacity, config, socket, route_rx);
148
149			let mut features = HashSet::new();
150			features.insert(ExtraFeatures::LiveQueries);
151
152			Ok(Surreal::new_from_router_waiter(
153				Arc::new(OnceLock::with_value(Router {
154					features,
155					sender: route_tx,
156					last_id: AtomicI64::new(0),
157				})),
158				Arc::new(watch::channel(Some(WaitFor::Connection))),
159			))
160		})
161	}
162
163	fn send<'r>(
164		&'r mut self,
165		router: &'r Router,
166		param: Param,
167	) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
168		Box::pin(async move {
169			self.id = router.next_id();
170			let (sender, receiver) = flume::bounded(1);
171			let route = Route {
172				request: (self.id, self.method, param),
173				response: sender,
174			};
175			router.sender.send_async(Some(route)).await?;
176			Ok(receiver)
177		})
178	}
179}
180
181#[allow(clippy::too_many_lines)]
182pub(crate) fn router(
183	endpoint: Endpoint,
184	maybe_connector: Option<Connector>,
185	capacity: usize,
186	config: WebSocketConfig,
187	mut socket: WebSocketStream<MaybeTlsStream<TcpStream>>,
188	route_rx: Receiver<Option<Route>>,
189) {
190	tokio::spawn(async move {
191		let ping = {
192			let mut request = BTreeMap::new();
193			request.insert("method".to_owned(), PING_METHOD.into());
194			let value = Value::from(request);
195			let value = serialize(&value, endpoint.supports_revision).unwrap();
196			Message::Binary(value)
197		};
198
199		let mut var_stash = IndexMap::new();
200		let mut vars = IndexMap::new();
201		let mut replay = IndexMap::new();
202
203		'router: loop {
204			let (socket_sink, socket_stream) = socket.split();
205			let mut socket_sink = Socket(Some(socket_sink));
206
207			if let Socket(Some(socket_sink)) = &mut socket_sink {
208				let mut routes = match capacity {
209					0 => HashMap::new(),
210					capacity => HashMap::with_capacity(capacity),
211				};
212				let mut live_queries = HashMap::new();
213
214				let mut interval = time::interval(PING_INTERVAL);
215				// don't bombard the server with pings if we miss some ticks
216				interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
217
218				let pinger = IntervalStream::new(interval);
219
220				let streams = (
221					socket_stream.map(Either::Response),
222					route_rx.stream().map(Either::Request),
223					pinger.map(|_| Either::Ping),
224				);
225
226				let mut merged = streams.merge();
227				let mut last_activity = Instant::now();
228
229				while let Some(either) = merged.next().await {
230					match either {
231						Either::Request(Some(Route {
232							request,
233							response,
234						})) => {
235							let (id, method, param) = request;
236							let params = match param.query {
237								Some((query, bindings)) => {
238									vec![query.into(), bindings.into()]
239								}
240								None => param.other,
241							};
242							match method {
243								Method::Set => {
244									if let [Value::Strand(key), value] = &params[..2] {
245										var_stash.insert(id, (key.0.clone(), value.clone()));
246									}
247								}
248								Method::Unset => {
249									if let [Value::Strand(key)] = &params[..1] {
250										vars.swap_remove(&key.0);
251									}
252								}
253								Method::Live => {
254									if let Some(sender) = param.notification_sender {
255										if let [Value::Uuid(id)] = &params[..1] {
256											live_queries.insert(*id, sender);
257										}
258									}
259									if response
260										.into_send_async(Ok(DbResponse::Other(Value::None)))
261										.await
262										.is_err()
263									{
264										trace!("Receiver dropped");
265									}
266									// There is nothing to send to the server here
267									continue;
268								}
269								Method::Kill => {
270									if let [Value::Uuid(id)] = &params[..1] {
271										live_queries.remove(id);
272									}
273								}
274								_ => {}
275							}
276							let method_str = match method {
277								Method::Health => PING_METHOD,
278								_ => method.as_str(),
279							};
280							let message = {
281								let mut request = BTreeMap::new();
282								request.insert("id".to_owned(), Value::from(id));
283								request.insert("method".to_owned(), method_str.into());
284								if !params.is_empty() {
285									request.insert("params".to_owned(), params.into());
286								}
287								let payload = Value::from(request);
288								trace!("Request {payload}");
289								let payload =
290									serialize(&payload, endpoint.supports_revision).unwrap();
291								Message::Binary(payload)
292							};
293							if let Method::Authenticate
294							| Method::Invalidate
295							| Method::Signin
296							| Method::Signup
297							| Method::Use = method
298							{
299								replay.insert(method, message.clone());
300							}
301							match socket_sink.send(message).await {
302								Ok(..) => {
303									last_activity = Instant::now();
304									match routes.entry(id) {
305										Entry::Vacant(entry) => {
306											// Register query route
307											entry.insert((method, response));
308										}
309										Entry::Occupied(..) => {
310											let error = Error::DuplicateRequestId(id);
311											if response
312												.into_send_async(Err(error.into()))
313												.await
314												.is_err()
315											{
316												trace!("Receiver dropped");
317											}
318										}
319									}
320								}
321								Err(error) => {
322									let error = Error::Ws(error.to_string());
323									if response.into_send_async(Err(error.into())).await.is_err() {
324										trace!("Receiver dropped");
325									}
326									break;
327								}
328							}
329						}
330						Either::Response(result) => {
331							last_activity = Instant::now();
332							match result {
333								Ok(message) => {
334									match Response::try_from(&message, endpoint.supports_revision) {
335										Ok(option) => {
336											// We are only interested in responses that are not empty
337											if let Some(response) = option {
338												trace!("{response:?}");
339												match response.id {
340													// If `id` is set this is a normal response
341													Some(id) => {
342														if let Ok(id) = id.coerce_to_i64() {
343															// We can only route responses with IDs
344															if let Some((method, sender)) =
345																routes.remove(&id)
346															{
347																if matches!(method, Method::Set) {
348																	if let Some((key, value)) =
349																		var_stash.swap_remove(&id)
350																	{
351																		vars.insert(key, value);
352																	}
353																}
354																// Send the response back to the caller
355																let mut response = response.result;
356																if matches!(method, Method::Insert)
357																{
358																	// For insert, we need to flatten single responses in an array
359																	if let Ok(Data::Other(
360																		Value::Array(value),
361																	)) = &mut response
362																	{
363																		if let [value] =
364																			&mut value.0[..]
365																		{
366																			response =
367																				Ok(Data::Other(
368																					mem::take(
369																						value,
370																					),
371																				));
372																		}
373																	}
374																}
375																let _res = sender
376																	.into_send_async(
377																		DbResponse::from(response),
378																	)
379																	.await;
380															}
381														}
382													}
383													// If `id` is not set, this may be a live query notification
384													None => match response.result {
385														Ok(Data::Live(notification)) => {
386															let live_query_id = notification.id;
387															// Check if this live query is registered
388															if let Some(sender) =
389																live_queries.get(&live_query_id)
390															{
391																// Send the notification back to the caller or kill live query if the receiver is already dropped
392																if sender
393																	.send(notification)
394																	.await
395																	.is_err()
396																{
397																	live_queries
398																		.remove(&live_query_id);
399																	let kill = {
400																		let mut request =
401																			BTreeMap::new();
402																		request.insert(
403																			"method".to_owned(),
404																			Method::Kill
405																				.as_str()
406																				.into(),
407																		);
408																		request.insert(
409																			"params".to_owned(),
410																			vec![Value::from(
411																				live_query_id,
412																			)]
413																			.into(),
414																		);
415																		let value =
416																			Value::from(request);
417																		let value = serialize(
418																			&value,
419																			endpoint
420																				.supports_revision,
421																		)
422																		.unwrap();
423																		Message::Binary(value)
424																	};
425																	if let Err(error) =
426																		socket_sink.send(kill).await
427																	{
428																		trace!("failed to send kill query to the server; {error:?}");
429																		break;
430																	}
431																}
432															}
433														}
434														Ok(..) => { /* Ignored responses like pings */
435														}
436														Err(error) => error!("{error:?}"),
437													},
438												}
439											}
440										}
441										Err(error) => {
442											#[revisioned(revision = 1)]
443											#[derive(Deserialize)]
444											struct Response {
445												id: Option<Value>,
446											}
447
448											// Let's try to find out the ID of the response that failed to deserialise
449											if let Message::Binary(binary) = message {
450												if let Ok(Response {
451													id,
452												}) = deserialize(
453													&mut &binary[..],
454													endpoint.supports_revision,
455												) {
456													// Return an error if an ID was returned
457													if let Some(Ok(id)) =
458														id.map(Value::coerce_to_i64)
459													{
460														if let Some((_method, sender)) =
461															routes.remove(&id)
462														{
463															let _res = sender
464																.into_send_async(Err(error))
465																.await;
466														}
467													}
468												} else {
469													// Unfortunately, we don't know which response failed to deserialize
470													warn!(
471														"Failed to deserialise message; {error:?}"
472													);
473												}
474											}
475										}
476									}
477								}
478								Err(error) => {
479									match error {
480										WsError::ConnectionClosed => {
481											trace!("Connection successfully closed on the server");
482										}
483										error => {
484											trace!("{error}");
485										}
486									}
487									break;
488								}
489							}
490						}
491						Either::Ping => {
492							// only ping if we haven't talked to the server recently
493							if last_activity.elapsed() >= PING_INTERVAL {
494								trace!("Pinging the server");
495								if let Err(error) = socket_sink.send(ping.clone()).await {
496									trace!("failed to ping the server; {error:?}");
497									break;
498								}
499							}
500						}
501						// Close connection request received
502						Either::Request(None) => {
503							match socket_sink.send(Message::Close(None)).await {
504								Ok(..) => trace!("Connection closed successfully"),
505								Err(error) => {
506									warn!("Failed to close database connection; {error}")
507								}
508							}
509							break 'router;
510						}
511					}
512				}
513			}
514
515			'reconnect: loop {
516				trace!("Reconnecting...");
517				match connect(&endpoint, Some(config), maybe_connector.clone()).await {
518					Ok(s) => {
519						socket = s;
520						for (_, message) in &replay {
521							if let Err(error) = socket.send(message.clone()).await {
522								trace!("{error}");
523								time::sleep(time::Duration::from_secs(1)).await;
524								continue 'reconnect;
525							}
526						}
527						for (key, value) in &vars {
528							let mut request = BTreeMap::new();
529							request.insert("method".to_owned(), Method::Set.as_str().into());
530							request.insert(
531								"params".to_owned(),
532								vec![key.as_str().into(), value.clone()].into(),
533							);
534							let payload = Value::from(request);
535							trace!("Request {payload}");
536							if let Err(error) = socket.send(Message::Binary(payload.into())).await {
537								trace!("{error}");
538								time::sleep(time::Duration::from_secs(1)).await;
539								continue 'reconnect;
540							}
541						}
542						trace!("Reconnected successfully");
543						break;
544					}
545					Err(error) => {
546						trace!("Failed to reconnect; {error}");
547						time::sleep(time::Duration::from_secs(1)).await;
548					}
549				}
550			}
551		}
552	});
553}
554
555impl Response {
556	fn try_from(message: &Message, supports_revision: bool) -> Result<Option<Self>> {
557		match message {
558			Message::Text(text) => {
559				trace!("Received an unexpected text message; {text}");
560				Ok(None)
561			}
562			Message::Binary(binary) => {
563				deserialize(&mut &binary[..], supports_revision).map(Some).map_err(|error| {
564					Error::ResponseFromBinary {
565						binary: binary.clone(),
566						error: bincode::ErrorKind::Custom(error.to_string()).into(),
567					}
568					.into()
569				})
570			}
571			Message::Ping(..) => {
572				trace!("Received a ping from the server");
573				Ok(None)
574			}
575			Message::Pong(..) => {
576				trace!("Received a pong from the server");
577				Ok(None)
578			}
579			Message::Frame(..) => {
580				trace!("Received an unexpected raw frame");
581				Ok(None)
582			}
583			Message::Close(..) => {
584				trace!("Received an unexpected close message");
585				Ok(None)
586			}
587		}
588	}
589}
590
/// Wrapper around the optional sink half (`SplitSink`) of a split WebSocket stream.
pub struct Socket(Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>);
592
#[cfg(test)]
mod tests {
	use super::serialize;
	use bincode::Options;
	use flate2::write::GzEncoder;
	use flate2::Compression;
	use rand::{thread_rng, Rng};
	use std::io::Write;
	use std::time::Instant;
	use surrealdb_core::rpc::format::cbor::Cbor;
	use surrealdb_core::sql::{Array, Value};

	/// Serialises a large random vector with several formats, with and without
	/// gzip compression, logs size and duration of each, and checks the formats
	/// rank in the expected order when sorted by payload size.
	#[test_log::test]
	fn large_vector_serialisation_bench() {
		// Runs `func` and returns how long it took together with its output.
		// A monotonic clock is used so that wall-clock adjustments during the
		// test cannot make the measurement fail.
		let timed = |func: &dyn Fn() -> Vec<u8>| {
			let start = Instant::now();
			let r = func();
			(start.elapsed(), r)
		};
		// Gzip-compresses a payload with the default compression level
		let compress = |v: &Vec<u8>| {
			let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
			encoder.write_all(v).unwrap();
			encoder.finish().unwrap()
		};
		// Generate a random vector
		let vector_size = if cfg!(debug_assertions) {
			200_000 // Debug is slow
		} else {
			2_000_000 // Release is fast
		};
		let mut vector: Vec<i32> = Vec::new();
		let mut rng = thread_rng();
		for _ in 0..vector_size {
			vector.push(rng.gen());
		}
		//	Store the results as (size, name, duration, factor relative to the reference)
		let mut results = vec![];
		// Calculate the reference
		let ref_payload;
		let ref_compressed;
		//
		const BINCODE_REF: &str = "Bincode Vec<i32>";
		const COMPRESSED_BINCODE_REF: &str = "Compressed Bincode Vec<i32>";
		{
			// Bincode Vec<i32>
			let (duration, payload) = timed(&|| {
				let mut payload = Vec::new();
				bincode::options()
					.with_fixint_encoding()
					.serialize_into(&mut payload, &vector)
					.unwrap();
				payload
			});
			ref_payload = payload.len() as f32;
			results.push((payload.len(), BINCODE_REF, duration, 1.0));

			// Compressed bincode
			let (compression_duration, payload) = timed(&|| compress(&payload));
			let duration = duration + compression_duration;
			ref_compressed = payload.len() as f32;
			results.push((payload.len(), COMPRESSED_BINCODE_REF, duration, 1.0));
		}
		// Build the Value
		let vector = Value::Array(Array::from(vector));
		//
		const BINCODE: &str = "Bincode Vec<Value>";
		const COMPRESSED_BINCODE: &str = "Compressed Bincode Vec<Value>";
		{
			// Bincode Vec<Value>
			let (duration, payload) = timed(&|| {
				let mut payload = Vec::new();
				bincode::options()
					.with_varint_encoding()
					.serialize_into(&mut payload, &vector)
					.unwrap();
				payload
			});
			results.push((payload.len(), BINCODE, duration, payload.len() as f32 / ref_payload));

			// Compressed bincode
			let (compression_duration, payload) = timed(&|| compress(&payload));
			let duration = duration + compression_duration;
			results.push((
				payload.len(),
				COMPRESSED_BINCODE,
				duration,
				payload.len() as f32 / ref_compressed,
			));
		}
		const UNVERSIONED: &str = "Unversioned Vec<Value>";
		const COMPRESSED_UNVERSIONED: &str = "Compressed Unversioned Vec<Value>";
		{
			// Unversioned
			let (duration, payload) = timed(&|| serialize(&vector, false).unwrap());
			results.push((
				payload.len(),
				UNVERSIONED,
				duration,
				payload.len() as f32 / ref_payload,
			));

			// Compressed Unversioned
			let (compression_duration, payload) = timed(&|| compress(&payload));
			let duration = duration + compression_duration;
			results.push((
				payload.len(),
				COMPRESSED_UNVERSIONED,
				duration,
				payload.len() as f32 / ref_compressed,
			));
		}
		//
		const VERSIONED: &str = "Versioned Vec<Value>";
		const COMPRESSED_VERSIONED: &str = "Compressed Versioned Vec<Value>";
		{
			// Versioned
			let (duration, payload) = timed(&|| serialize(&vector, true).unwrap());
			results.push((payload.len(), VERSIONED, duration, payload.len() as f32 / ref_payload));

			// Compressed Versioned
			let (compression_duration, payload) = timed(&|| compress(&payload));
			let duration = duration + compression_duration;
			results.push((
				payload.len(),
				COMPRESSED_VERSIONED,
				duration,
				payload.len() as f32 / ref_compressed,
			));
		}
		//
		const CBOR: &str = "CBor Vec<Value>";
		const COMPRESSED_CBOR: &str = "Compressed CBor Vec<Value>";
		{
			// CBor
			let (duration, payload) = timed(&|| {
				let cbor: Cbor = vector.clone().try_into().unwrap();
				let mut res = Vec::new();
				ciborium::into_writer(&cbor.0, &mut res).unwrap();
				res
			});
			results.push((payload.len(), CBOR, duration, payload.len() as f32 / ref_payload));

			// Compressed Cbor
			let (compression_duration, payload) = timed(&|| compress(&payload));
			let duration = duration + compression_duration;
			results.push((
				payload.len(),
				COMPRESSED_CBOR,
				duration,
				payload.len() as f32 / ref_compressed,
			));
		}
		// Sort the results by ascending size
		results.sort_by(|(a, _, _, _), (b, _, _, _)| a.cmp(b));
		for (size, name, duration, factor) in &results {
			info!("{name} - Size: {size} - Duration: {duration:?} - Factor: {factor}");
		}
		// Check the expected sorted results
		let results: Vec<&str> = results.into_iter().map(|(_, name, _, _)| name).collect();
		assert_eq!(
			results,
			vec![
				BINCODE_REF,
				COMPRESSED_BINCODE_REF,
				COMPRESSED_CBOR,
				COMPRESSED_BINCODE,
				COMPRESSED_UNVERSIONED,
				CBOR,
				COMPRESSED_VERSIONED,
				BINCODE,
				UNVERSIONED,
				VERSIONED,
			]
		)
	}
}