surrealdb/api/engine/remote/ws/mod.rs

1//! WebSocket engine
2
3#[cfg(not(target_arch = "wasm32"))]
4pub(crate) mod native;
5#[cfg(target_arch = "wasm32")]
6pub(crate) mod wasm;
7
8use crate::api;
9use crate::api::conn::Command;
10use crate::api::conn::DbResponse;
11use crate::api::engine::remote::duration_from_str;
12use crate::api::err::Error;
13use crate::api::method::query::QueryResult;
14use crate::api::Connect;
15use crate::api::Result;
16use crate::api::Surreal;
17use crate::dbs::Notification;
18use crate::dbs::QueryMethodResponse;
19use crate::dbs::Status;
20use crate::method::Stats;
21use crate::opt::IntoEndpoint;
22use crate::sql::Value;
23use channel::Sender;
24use indexmap::IndexMap;
25use revision::revisioned;
26use revision::Revisioned;
27use serde::de::DeserializeOwned;
28use serde::Deserialize;
29use std::collections::HashMap;
30use std::io::Read;
31use std::marker::PhantomData;
32use std::time::Duration;
33use surrealdb_core::dbs::Notification as CoreNotification;
34use trice::Instant;
35use uuid::Uuid;
36
/// Path component of the RPC endpoint on the server.
pub(crate) const PATH: &str = "rpc";
/// How often a ping is sent to keep the connection alive.
const PING_INTERVAL: Duration = Duration::from_millis(5_000);
/// Name of the header carrying the revision information.
const REVISION_HEADER: &str = "revision";
40
41enum RequestEffect {
42	/// Completing this request sets a variable to a give value.
43	Set {
44		key: String,
45		value: Value,
46	},
47	/// Completing this request sets a variable to a give value.
48	Clear {
49		key: String,
50	},
51	/// Insert requests repsonses need to be flattened in an array.
52	Insert,
53	/// No effect
54	None,
55}
56
/// The kinds of commands that are remembered so they can be replayed after a reconnect.
///
/// Used as the key of the router's replay map, so at most one command of each kind is kept.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum ReplayMethod {
	/// A `use` command.
	Use,
	/// A `signup` command.
	Signup,
	/// A `signin` command.
	Signin,
	/// An `invalidate` command.
	Invalidate,
	/// An `authenticate` command.
	Authenticate,
}
65
66struct PendingRequest {
67	// Does resolving this request has some effects.
68	effect: RequestEffect,
69	// The channel to send the result of the request into.
70	response_channel: Sender<Result<DbResponse>>,
71}
72
73struct RouterState<Sink, Stream> {
74	/// Vars currently set by the set method,
75	vars: IndexMap<String, Value>,
76	/// Messages which aught to be replayed on a reconnect.
77	replay: IndexMap<ReplayMethod, Command>,
78	/// Pending live queries
79	live_queries: HashMap<Uuid, channel::Sender<CoreNotification>>,
80	/// Send requests which are still awaiting an awnser.
81	pending_requests: HashMap<i64, PendingRequest>,
82	/// The last time a message was recieved from the server.
83	last_activity: Instant,
84	/// The sink into which messages are send to surrealdb
85	sink: Sink,
86	/// The stream from which messages are recieved from surrealdb
87	stream: Stream,
88}
89
90impl<Sink, Stream> RouterState<Sink, Stream> {
91	pub fn new(sink: Sink, stream: Stream) -> Self {
92		RouterState {
93			vars: IndexMap::new(),
94			replay: IndexMap::new(),
95			live_queries: HashMap::new(),
96			pending_requests: HashMap::new(),
97			last_activity: Instant::now(),
98			sink,
99			stream,
100		}
101	}
102}
103
/// Result of a router handling step, telling the caller whether to keep going or reconnect.
enum HandleResult {
	/// The socket disconnected; the router should continue by reconnecting.
	Disconnected,
	/// Nothing went wrong; continue as normal.
	Ok,
}
110
/// Scheme marker for plain, unencrypted `ws://` endpoints.
///
/// Supplied as the scheme type parameter when connecting, e.g. `DB.connect::<Ws>(...)`.
#[derive(Debug)]
pub struct Ws;

/// Scheme marker for TLS-secured `wss://` endpoints.
#[derive(Debug)]
pub struct Wss;

/// Engine handle for talking to a SurrealDB server over a WebSocket connection.
///
/// The private unit field keeps code outside this module from constructing it directly.
#[derive(Debug, Clone)]
pub struct Client(());
122
123impl Surreal<Client> {
124	/// Connects to a specific database endpoint, saving the connection on the static client
125	///
126	/// # Examples
127	///
128	/// ```no_run
129	/// use once_cell::sync::Lazy;
130	/// use surrealdb::Surreal;
131	/// use surrealdb::engine::remote::ws::Client;
132	/// use surrealdb::engine::remote::ws::Ws;
133	///
134	/// static DB: Lazy<Surreal<Client>> = Lazy::new(Surreal::init);
135	///
136	/// # #[tokio::main]
137	/// # async fn main() -> surrealdb::Result<()> {
138	/// DB.connect::<Ws>("localhost:8000").await?;
139	/// # Ok(())
140	/// # }
141	/// ```
142	pub fn connect<P>(
143		&self,
144		address: impl IntoEndpoint<P, Client = Client>,
145	) -> Connect<Client, ()> {
146		Connect {
147			router: self.router.clone(),
148			engine: PhantomData,
149			address: address.into_endpoint(),
150			capacity: 0,
151			waiter: self.waiter.clone(),
152			response_type: PhantomData,
153		}
154	}
155}
156
157#[revisioned(revision = 1)]
158#[derive(Clone, Debug, Deserialize)]
159pub(crate) struct Failure {
160	pub(crate) code: i64,
161	pub(crate) message: String,
162}
163
164#[revisioned(revision = 1)]
165#[derive(Debug, Deserialize)]
166pub(crate) enum Data {
167	Other(Value),
168	Query(Vec<QueryMethodResponse>),
169	Live(Notification),
170}
171
172type ServerResult = std::result::Result<Data, Failure>;
173
174impl From<Failure> for Error {
175	fn from(failure: Failure) -> Self {
176		match failure.code {
177			-32600 => Self::InvalidRequest(failure.message),
178			-32602 => Self::InvalidParams(failure.message),
179			-32603 => Self::InternalError(failure.message),
180			-32700 => Self::ParseError(failure.message),
181			_ => Self::Query(failure.message),
182		}
183	}
184}
185
186impl DbResponse {
187	fn from(result: ServerResult) -> Result<Self> {
188		match result.map_err(Error::from)? {
189			Data::Other(value) => Ok(DbResponse::Other(value)),
190			Data::Query(responses) => {
191				let mut map =
192					IndexMap::<usize, (Stats, QueryResult)>::with_capacity(responses.len());
193
194				for (index, response) in responses.into_iter().enumerate() {
195					let stats = Stats {
196						execution_time: duration_from_str(&response.time),
197					};
198					match response.status {
199						Status::Ok => {
200							map.insert(index, (stats, Ok(response.result)));
201						}
202						Status::Err => {
203							map.insert(
204								index,
205								(stats, Err(Error::Query(response.result.as_raw_string()).into())),
206							);
207						}
208						_ => unreachable!(),
209					}
210				}
211
212				Ok(DbResponse::Query(api::Response {
213					results: map,
214					..api::Response::new()
215				}))
216			}
217			// Live notifications don't call this method
218			Data::Live(..) => unreachable!(),
219		}
220	}
221}
222
223#[revisioned(revision = 1)]
224#[derive(Debug, Deserialize)]
225pub(crate) struct Response {
226	id: Option<Value>,
227	pub(crate) result: ServerResult,
228}
229
230fn serialize<V>(value: &V, revisioned: bool) -> Result<Vec<u8>>
231where
232	V: serde::Serialize + Revisioned,
233{
234	if revisioned {
235		let mut buf = Vec::new();
236		value.serialize_revisioned(&mut buf).map_err(|error| crate::Error::Db(error.into()))?;
237		return Ok(buf);
238	}
239	crate::sql::serde::serialize(value).map_err(|error| crate::Error::Db(error.into()))
240}
241
242fn deserialize<A, T>(bytes: &mut A, revisioned: bool) -> Result<T>
243where
244	A: Read,
245	T: Revisioned + DeserializeOwned,
246{
247	if revisioned {
248		return T::deserialize_revisioned(bytes).map_err(|x| crate::Error::Db(x.into()));
249	}
250	let mut buf = Vec::new();
251	bytes.read_to_end(&mut buf).map_err(crate::err::Error::Io)?;
252	crate::sql::serde::deserialize(&buf).map_err(|error| crate::Error::Db(error.into()))
253}