surrealdb/api/mod.rs

1//! Functionality for connecting to local and remote databases
2
3use method::BoxFuture;
4use semver::BuildMetadata;
5use semver::Version;
6use semver::VersionReq;
7use std::fmt;
8use std::fmt::Debug;
9use std::future::IntoFuture;
10use std::marker::PhantomData;
11use std::sync::Arc;
12use std::sync::OnceLock;
13use tokio::sync::watch;
14
/// Declares a `#[repr(transparent)]` newtype around `$inner`.
///
/// The generated type gets:
/// * hidden conversions between the wrapper and the wrapped type, by value, by
///   shared reference and by mutable reference;
/// * `Display` and `Debug` implementations that forward to the *same* trait of
///   the inner value.
macro_rules! transparent_wrapper{
	(
		$(#[$m:meta])*
		$vis:vis struct $name:ident($field_vis:vis $inner:ty)
	) => {
		$(#[$m])*
		#[repr(transparent)]
		$vis struct $name($field_vis $inner);

		// Every wrapper gets the full set of conversions, even if a given wrapper only
		// uses some of them, hence the `allow(dead_code)` on each one.
		impl $name{
			#[doc(hidden)]
			#[allow(dead_code)]
			pub fn from_inner(inner: $inner) -> Self{
				$name(inner)
			}

			#[doc(hidden)]
			#[allow(dead_code)]
			pub fn from_inner_ref(inner: &$inner) -> &Self{
				// SAFETY: `$name` is `#[repr(transparent)]` over `$inner`, so both types
				// have identical size, alignment and layout. The returned reference
				// reborrows `inner` and therefore cannot outlive it.
				unsafe { &*(inner as *const $inner as *const $name) }
			}

			#[doc(hidden)]
			#[allow(dead_code)]
			pub fn from_inner_mut(inner: &mut $inner) -> &mut Self{
				// SAFETY: same layout argument as `from_inner_ref`; the exclusive borrow
				// of `inner` is carried over to the returned reference.
				unsafe { &mut *(inner as *mut $inner as *mut $name) }
			}

			#[doc(hidden)]
			#[allow(dead_code)]
			pub fn into_inner(self) -> $inner{
				self.0
			}

			#[doc(hidden)]
			#[allow(dead_code)]
			pub fn into_inner_ref(&self) -> &$inner{
				&self.0
			}

			#[doc(hidden)]
			#[allow(dead_code)]
			pub fn into_inner_mut(&mut self) -> &mut $inner{
				&mut self.0
			}
		}

		// The formatting impls call the inner trait method by its full path. Plain
		// `self.0.fmt(..)` would resolve against whichever of `Display`/`Debug` the
		// invoking module happens to import: the wrong one, or an ambiguity error.
		impl std::fmt::Display for $name{
			fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result{
				std::fmt::Display::fmt(&self.0, fmt)
			}
		}
		impl std::fmt::Debug for $name{
			fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result{
				std::fmt::Debug::fmt(&self.0, fmt)
			}
		}
	};
}
78
/// Implements `revision::Revisioned`, `serde::Serialize` and `serde::Deserialize`
/// for a wrapper produced by `transparent_wrapper!` by delegating to the wrapped
/// value (`self.0`) and re-wrapping with `from_inner` on the way back.
///
/// All trait methods are called by their full path, so the invoking module does not
/// need `Serialize` or `Revisioned` imported.
///
/// NOTE(review): `revision()` always reports `CoreValue::revision()`, so `CoreValue`
/// must be in scope where this macro is invoked. That is presumably only correct for
/// wrappers whose inner value is `CoreValue`; confirm before using it for other types.
macro_rules! impl_serialize_wrapper {
	($ty:ty) => {
		impl ::revision::Revisioned for $ty {
			fn revision() -> u16 {
				CoreValue::revision()
			}

			fn serialize_revisioned<W: std::io::Write>(
				&self,
				w: &mut W,
			) -> Result<(), ::revision::Error> {
				::revision::Revisioned::serialize_revisioned(&self.0, w)
			}

			fn deserialize_revisioned<R: std::io::Read>(
				r: &mut R,
			) -> Result<Self, ::revision::Error>
			where
				Self: Sized,
			{
				::revision::Revisioned::deserialize_revisioned(r).map(Self::from_inner)
			}
		}

		impl ::serde::Serialize for $ty {
			fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
			where
				S: ::serde::ser::Serializer,
			{
				::serde::Serialize::serialize(&self.0, serializer)
			}
		}

		impl<'de> ::serde::de::Deserialize<'de> for $ty {
			fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
			where
				D: ::serde::de::Deserializer<'de>,
			{
				Ok(Self::from_inner(::serde::de::Deserialize::deserialize(deserializer)?))
			}
		}
	};
}
120
121pub mod engine;
122pub mod err;
123#[cfg(feature = "protocol-http")]
124pub mod headers;
125pub mod method;
126pub mod opt;
127pub mod value;
128
129mod conn;
130
131use self::conn::Router;
132use self::err::Error;
133use self::opt::Endpoint;
134use self::opt::EndpointKind;
135use self::opt::WaitFor;
136
137pub use method::query::Response;
138
139/// A specialized `Result` type
140pub type Result<T> = std::result::Result<T, crate::Error>;
141
142// Channel for waiters
143type Waiter = (watch::Sender<Option<WaitFor>>, watch::Receiver<Option<WaitFor>>);
144
145const SUPPORTED_VERSIONS: (&str, &str) = (">=1.0.0, <3.0.0", "20230701.55918b7c");
146const REVISION_SUPPORTED_SERVER_VERSION: Version = Version::new(1, 2, 0);
147
148/// Connection trait implemented by supported engines
149pub trait Connection: conn::Connection {}
150
151/// The future returned when creating a new SurrealDB instance
152#[derive(Debug)]
153#[must_use = "futures do nothing unless you `.await` or poll them"]
154pub struct Connect<C: Connection, Response> {
155	router: Arc<OnceLock<Router>>,
156	engine: PhantomData<C>,
157	address: Result<Endpoint>,
158	capacity: usize,
159	waiter: Arc<Waiter>,
160	response_type: PhantomData<Response>,
161}
162
163impl<C, R> Connect<C, R>
164where
165	C: Connection,
166{
167	/// Sets the maximum capacity of the connection
168	///
169	/// This is used to set bounds of the channels used internally
170	/// as well set the capacity of the `HashMap` used for routing
171	/// responses in case of the WebSocket client.
172	///
173	/// Setting this capacity to `0` (the default) means that
174	/// unbounded channels will be used. If your queries per second
175	/// are so high that the client is running out of memory,
176	/// it might be helpful to set this to a number that works best
177	/// for you.
178	///
179	/// # Examples
180	///
181	/// ```no_run
182	/// # #[tokio::main]
183	/// # async fn main() -> surrealdb::Result<()> {
184	/// use surrealdb::engine::remote::ws::Ws;
185	/// use surrealdb::Surreal;
186	///
187	/// let db = Surreal::new::<Ws>("localhost:8000")
188	///     .with_capacity(100_000)
189	///     .await?;
190	/// # Ok(())
191	/// # }
192	/// ```
193	pub const fn with_capacity(mut self, capacity: usize) -> Self {
194		self.capacity = capacity;
195		self
196	}
197}
198
199impl<Client> IntoFuture for Connect<Client, Surreal<Client>>
200where
201	Client: Connection,
202{
203	type Output = Result<Surreal<Client>>;
204	type IntoFuture = BoxFuture<'static, Self::Output>;
205
206	fn into_future(self) -> Self::IntoFuture {
207		Box::pin(async move {
208			let mut endpoint = self.address?;
209			let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
210			let mut client = Client::connect(endpoint.clone(), self.capacity).await?;
211			if endpoint_kind.is_remote() {
212				let mut version = client.version().await?;
213				// we would like to be able to connect to pre-releases too
214				version.pre = Default::default();
215				client.check_server_version(&version).await?;
216				if version >= REVISION_SUPPORTED_SERVER_VERSION && endpoint_kind.is_ws() {
217					// Switch to revision based serialisation
218					endpoint.supports_revision = true;
219					client = Client::connect(endpoint, self.capacity).await?;
220				}
221			}
222			// Both ends of the channel are still alive at this point
223			client.waiter.0.send(Some(WaitFor::Connection)).ok();
224			Ok(client)
225		})
226	}
227}
228
229impl<Client> IntoFuture for Connect<Client, ()>
230where
231	Client: Connection,
232{
233	type Output = Result<()>;
234	type IntoFuture = BoxFuture<'static, Self::Output>;
235
236	fn into_future(self) -> Self::IntoFuture {
237		Box::pin(async move {
238			// Avoid establishing another connection if already connected
239			if self.router.get().is_some() {
240				return Err(Error::AlreadyConnected.into());
241			}
242			let mut endpoint = self.address?;
243			let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
244			let mut client = Client::connect(endpoint.clone(), self.capacity).await?;
245			if endpoint_kind.is_remote() {
246				let mut version = client.version().await?;
247				// we would like to be able to connect to pre-releases too
248				version.pre = Default::default();
249				client.check_server_version(&version).await?;
250				if version >= REVISION_SUPPORTED_SERVER_VERSION && endpoint_kind.is_ws() {
251					// Switch to revision based serialisation
252					endpoint.supports_revision = true;
253					client = Client::connect(endpoint, self.capacity).await?;
254				}
255			}
256			let cell =
257				Arc::into_inner(client.router).expect("new connection to have no references");
258			let router = cell.into_inner().expect("router to be set");
259			self.router.set(router).map_err(|_| Error::AlreadyConnected)?;
260			// Both ends of the channel are still alive at this point
261			self.waiter.0.send(Some(WaitFor::Connection)).ok();
262			Ok(())
263		})
264	}
265}
266
/// Optional features, presumably advertised by an engine; confirm against the
/// engine implementations.
///
/// Variant order is significant for the derived `Ord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) enum ExtraFeatures {
	Backup,
	LiveQueries,
}
272
273/// A database client instance for embedded or remote databases
274pub struct Surreal<C: Connection> {
275	router: Arc<OnceLock<Router>>,
276	waiter: Arc<Waiter>,
277	engine: PhantomData<C>,
278}
279
280impl<C> Surreal<C>
281where
282	C: Connection,
283{
284	pub(crate) fn new_from_router_waiter(
285		router: Arc<OnceLock<Router>>,
286		waiter: Arc<Waiter>,
287	) -> Self {
288		Surreal {
289			router,
290			waiter,
291			engine: PhantomData,
292		}
293	}
294
295	async fn check_server_version(&self, version: &Version) -> Result<()> {
296		let (versions, build_meta) = SUPPORTED_VERSIONS;
297		// invalid version requirements should be caught during development
298		let req = VersionReq::parse(versions).expect("valid supported versions");
299		let build_meta = BuildMetadata::new(build_meta).expect("valid supported build metadata");
300		let server_build = &version.build;
301		if !req.matches(version) {
302			return Err(Error::VersionMismatch {
303				server_version: version.clone(),
304				supported_versions: versions.to_owned(),
305			}
306			.into());
307		} else if !server_build.is_empty() && server_build < &build_meta {
308			return Err(Error::BuildMetadataMismatch {
309				server_metadata: server_build.clone(),
310				supported_metadata: build_meta,
311			}
312			.into());
313		}
314		Ok(())
315	}
316}
317
318impl<C> Clone for Surreal<C>
319where
320	C: Connection,
321{
322	fn clone(&self) -> Self {
323		Self {
324			router: self.router.clone(),
325			waiter: self.waiter.clone(),
326			engine: self.engine,
327		}
328	}
329}
330
331impl<C> Debug for Surreal<C>
332where
333	C: Connection,
334{
335	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
336		f.debug_struct("Surreal")
337			.field("router", &self.router)
338			.field("engine", &self.engine)
339			.finish()
340	}
341}
342
343trait OnceLockExt {
344	fn with_value(value: Router) -> OnceLock<Router> {
345		let cell = OnceLock::new();
346		match cell.set(value) {
347			Ok(()) => cell,
348			Err(_) => unreachable!("don't have exclusive access to `cell`"),
349		}
350	}
351
352	fn extract(&self) -> Result<&Router>;
353}
354
355impl OnceLockExt for OnceLock<Router> {
356	fn extract(&self) -> Result<&Router> {
357		let router = self.get().ok_or(Error::ConnectionUninitialised)?;
358		Ok(router)
359	}
360}