1use method::BoxFuture;
4use semver::BuildMetadata;
5use semver::Version;
6use semver::VersionReq;
7use std::fmt;
8use std::fmt::Debug;
9use std::future::IntoFuture;
10use std::marker::PhantomData;
11use std::sync::Arc;
12use std::sync::OnceLock;
13use tokio::sync::watch;
14
/// Declares a `#[repr(transparent)]` single-field tuple struct together with
/// conversion helpers to and from the wrapped type, plus `Display` and `Debug`
/// implementations that forward to the wrapped value.
///
/// Any attributes written before the `struct` are copied onto the generated type.
/// The wrapped type must implement both `std::fmt::Display` and `std::fmt::Debug`.
macro_rules! transparent_wrapper{
    (
        $(#[$m:meta])*
        $vis:vis struct $name:ident($field_vis:vis $inner:ty)
    ) => {
        $(#[$m])*
        #[repr(transparent)]
        $vis struct $name($field_vis $inner);

        impl $name {
            /// Wraps an owned inner value.
            #[doc(hidden)]
            // Not every wrapper uses every helper generated here.
            #[allow(dead_code)]
            pub fn from_inner(inner: $inner) -> Self {
                $name(inner)
            }

            /// Reinterprets a shared reference to the inner type as a reference to the wrapper.
            #[doc(hidden)]
            // Not every wrapper uses every helper generated here.
            #[allow(dead_code)]
            pub fn from_inner_ref(inner: &$inner) -> &Self {
                // SAFETY: the wrapper is declared `#[repr(transparent)]` over its single
                // field, so both types share the same layout and the reference keeps its
                // original lifetime.
                unsafe {
                    std::mem::transmute::<&$inner, &$name>(inner)
                }
            }

            /// Reinterprets a mutable reference to the inner type as a reference to the wrapper.
            #[doc(hidden)]
            // Not every wrapper uses every helper generated here.
            #[allow(dead_code)]
            pub fn from_inner_mut(inner: &mut $inner) -> &mut Self {
                // SAFETY: the wrapper is declared `#[repr(transparent)]` over its single
                // field, so both types share the same layout; exclusivity of the borrow is
                // carried over unchanged.
                unsafe {
                    std::mem::transmute::<&mut $inner, &mut $name>(inner)
                }
            }

            /// Unwraps into the owned inner value.
            #[doc(hidden)]
            // Not every wrapper uses every helper generated here.
            #[allow(dead_code)]
            pub fn into_inner(self) -> $inner {
                self.0
            }

            /// Borrows the inner value.
            #[doc(hidden)]
            // Not every wrapper uses every helper generated here.
            #[allow(dead_code)]
            pub fn into_inner_ref(&self) -> &$inner {
                &self.0
            }

            /// Mutably borrows the inner value.
            #[doc(hidden)]
            // Not every wrapper uses every helper generated here.
            #[allow(dead_code)]
            pub fn into_inner_mut(&mut self) -> &mut $inner {
                &mut self.0
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                // Fully qualified: `macro_rules!` is not hygienic for trait methods, so a
                // bare `self.0.fmt(..)` would depend on which `fmt` traits the expansion
                // site happens to import.
                std::fmt::Display::fmt(&self.0, fmt)
            }
        }

        impl std::fmt::Debug for $name {
            fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                // Fully qualified for the same reason as the `Display` impl.
                std::fmt::Debug::fmt(&self.0, fmt)
            }
        }
    };
}
78
/// Implements `revision::Revisioned`, `serde::Serialize` and `serde::Deserialize`
/// for a wrapper type by delegating to its field `.0` and rebuilding the wrapper
/// with `Self::from_inner`.
///
/// The reported revision is that of `CoreValue`, so `CoreValue` must be nameable
/// where the macro is invoked.
macro_rules! impl_serialize_wrapper {
    ($ty:ty) => {
        impl ::revision::Revisioned for $ty {
            fn revision() -> u16 {
                // Fully qualified so the `Revisioned` trait need not be imported at the
                // expansion site.
                <CoreValue as ::revision::Revisioned>::revision()
            }

            // `::std::result::Result` is spelled out because modules commonly bring a
            // one-parameter `Result<T>` alias into scope, which would not accept two
            // type arguments.
            fn serialize_revisioned<W: ::std::io::Write>(
                &self,
                w: &mut W,
            ) -> ::std::result::Result<(), ::revision::Error> {
                ::revision::Revisioned::serialize_revisioned(&self.0, w)
            }

            fn deserialize_revisioned<R: ::std::io::Read>(
                r: &mut R,
            ) -> ::std::result::Result<Self, ::revision::Error>
            where
                Self: Sized,
            {
                ::revision::Revisioned::deserialize_revisioned(r).map(Self::from_inner)
            }
        }

        impl ::serde::Serialize for $ty {
            fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
            where
                S: ::serde::ser::Serializer,
            {
                // Fully qualified so the `Serialize` trait need not be imported at the
                // expansion site.
                ::serde::Serialize::serialize(&self.0, serializer)
            }
        }

        impl<'de> ::serde::de::Deserialize<'de> for $ty {
            fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
            where
                D: ::serde::de::Deserializer<'de>,
            {
                Ok(Self::from_inner(::serde::de::Deserialize::deserialize(deserializer)?))
            }
        }
    };
}
120
121pub mod engine;
122pub mod err;
123#[cfg(feature = "protocol-http")]
124pub mod headers;
125pub mod method;
126pub mod opt;
127pub mod value;
128
129mod conn;
130
131use self::conn::Router;
132use self::err::Error;
133use self::opt::Endpoint;
134use self::opt::EndpointKind;
135use self::opt::WaitFor;
136
137pub use method::query::Response;
138
/// A specialised `Result` whose error type is always [`crate::Error`].
pub type Result<T> = std::result::Result<T, crate::Error>;
141
/// Sender and receiver halves of a `tokio::sync::watch` channel carrying an
/// optional [`WaitFor`] value.
type Waiter = (watch::Sender<Option<WaitFor>>, watch::Receiver<Option<WaitFor>>);
144
/// Server compatibility bounds as `(semver requirement, build metadata)`.
///
/// `check_server_version` parses the first element as a `VersionReq` and the
/// second as `BuildMetadata`.
const SUPPORTED_VERSIONS: (&str, &str) = (">=1.0.0, <3.0.0", "20230701.55918b7c");
/// Version threshold used while connecting: when the reported server version is
/// at or above it and the endpoint is a WebSocket one, the client reconnects with
/// `supports_revision` enabled on the endpoint.
const REVISION_SUPPORTED_SERVER_VERSION: Version = Version::new(1, 2, 0);
147
/// Marker trait with no items of its own; every implementor must also implement
/// `conn::Connection`.
pub trait Connection: conn::Connection {}
150
151#[derive(Debug)]
153#[must_use = "futures do nothing unless you `.await` or poll them"]
154pub struct Connect<C: Connection, Response> {
155 router: Arc<OnceLock<Router>>,
156 engine: PhantomData<C>,
157 address: Result<Endpoint>,
158 capacity: usize,
159 waiter: Arc<Waiter>,
160 response_type: PhantomData<Response>,
161}
162
impl<C, R> Connect<C, R>
where
    C: Connection,
{
    /// Replaces the stored capacity and returns the updated builder.
    ///
    /// The value is later passed unchanged as the second argument of
    /// `Connection::connect` when the builder is awaited.
    // NOTE(review): presumably this bounds the engine's internal channels — confirm
    // against the engine implementations.
    pub const fn with_capacity(mut self, capacity: usize) -> Self {
        self.capacity = capacity;
        self
    }
}
198
199impl<Client> IntoFuture for Connect<Client, Surreal<Client>>
200where
201 Client: Connection,
202{
203 type Output = Result<Surreal<Client>>;
204 type IntoFuture = BoxFuture<'static, Self::Output>;
205
206 fn into_future(self) -> Self::IntoFuture {
207 Box::pin(async move {
208 let mut endpoint = self.address?;
209 let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
210 let mut client = Client::connect(endpoint.clone(), self.capacity).await?;
211 if endpoint_kind.is_remote() {
212 let mut version = client.version().await?;
213 version.pre = Default::default();
215 client.check_server_version(&version).await?;
216 if version >= REVISION_SUPPORTED_SERVER_VERSION && endpoint_kind.is_ws() {
217 endpoint.supports_revision = true;
219 client = Client::connect(endpoint, self.capacity).await?;
220 }
221 }
222 client.waiter.0.send(Some(WaitFor::Connection)).ok();
224 Ok(client)
225 })
226 }
227}
228
229impl<Client> IntoFuture for Connect<Client, ()>
230where
231 Client: Connection,
232{
233 type Output = Result<()>;
234 type IntoFuture = BoxFuture<'static, Self::Output>;
235
236 fn into_future(self) -> Self::IntoFuture {
237 Box::pin(async move {
238 if self.router.get().is_some() {
240 return Err(Error::AlreadyConnected.into());
241 }
242 let mut endpoint = self.address?;
243 let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
244 let mut client = Client::connect(endpoint.clone(), self.capacity).await?;
245 if endpoint_kind.is_remote() {
246 let mut version = client.version().await?;
247 version.pre = Default::default();
249 client.check_server_version(&version).await?;
250 if version >= REVISION_SUPPORTED_SERVER_VERSION && endpoint_kind.is_ws() {
251 endpoint.supports_revision = true;
253 client = Client::connect(endpoint, self.capacity).await?;
254 }
255 }
256 let cell =
257 Arc::into_inner(client.router).expect("new connection to have no references");
258 let router = cell.into_inner().expect("router to be set");
259 self.router.set(router).map_err(|_| Error::AlreadyConnected)?;
260 self.waiter.0.send(Some(WaitFor::Connection)).ok();
262 Ok(())
263 })
264 }
265}
266
/// Crate-internal set of named features.
// NOTE(review): presumably the optional capabilities an engine may or may not
// support — confirm at the use sites.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub(crate) enum ExtraFeatures {
    Backup,
    LiveQueries,
}
272
273pub struct Surreal<C: Connection> {
275 router: Arc<OnceLock<Router>>,
276 waiter: Arc<Waiter>,
277 engine: PhantomData<C>,
278}
279
280impl<C> Surreal<C>
281where
282 C: Connection,
283{
284 pub(crate) fn new_from_router_waiter(
285 router: Arc<OnceLock<Router>>,
286 waiter: Arc<Waiter>,
287 ) -> Self {
288 Surreal {
289 router,
290 waiter,
291 engine: PhantomData,
292 }
293 }
294
295 async fn check_server_version(&self, version: &Version) -> Result<()> {
296 let (versions, build_meta) = SUPPORTED_VERSIONS;
297 let req = VersionReq::parse(versions).expect("valid supported versions");
299 let build_meta = BuildMetadata::new(build_meta).expect("valid supported build metadata");
300 let server_build = &version.build;
301 if !req.matches(version) {
302 return Err(Error::VersionMismatch {
303 server_version: version.clone(),
304 supported_versions: versions.to_owned(),
305 }
306 .into());
307 } else if !server_build.is_empty() && server_build < &build_meta {
308 return Err(Error::BuildMetadataMismatch {
309 server_metadata: server_build.clone(),
310 supported_metadata: build_meta,
311 }
312 .into());
313 }
314 Ok(())
315 }
316}
317
318impl<C> Clone for Surreal<C>
319where
320 C: Connection,
321{
322 fn clone(&self) -> Self {
323 Self {
324 router: self.router.clone(),
325 waiter: self.waiter.clone(),
326 engine: self.engine,
327 }
328 }
329}
330
331impl<C> Debug for Surreal<C>
332where
333 C: Connection,
334{
335 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
336 f.debug_struct("Surreal")
337 .field("router", &self.router)
338 .field("engine", &self.engine)
339 .finish()
340 }
341}
342
343trait OnceLockExt {
344 fn with_value(value: Router) -> OnceLock<Router> {
345 let cell = OnceLock::new();
346 match cell.set(value) {
347 Ok(()) => cell,
348 Err(_) => unreachable!("don't have exclusive access to `cell`"),
349 }
350 }
351
352 fn extract(&self) -> Result<&Router>;
353}
354
355impl OnceLockExt for OnceLock<Router> {
356 fn extract(&self) -> Result<&Router> {
357 let router = self.get().ok_or(Error::ConnectionUninitialised)?;
358 Ok(router)
359 }
360}