// surreal_simple_client/surreal_client.rs

use std::collections::HashMap;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;

use futures::stream::SplitSink;
use futures::Future;
use futures::SinkExt;
use futures::StreamExt;
use serde::de::DeserializeOwned;
use serde_json::json;
use serde_json::Value;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;

use crate::rpc::RpcResult;
use crate::SurrealMessage;
use crate::SurrealResponseData;

21type SurrealResponseSender = oneshot::Sender<SurrealResponseData>;
22
23#[derive(Debug)]
24pub struct SurrealResponse {
25  receiver: oneshot::Receiver<SurrealResponseData>,
26}
27impl Future for SurrealResponse {
28  type Output = <oneshot::Receiver<SurrealResponseData> as Future>::Output;
29
30  fn poll(
31    self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>,
32  ) -> std::task::Poll<Self::Output> {
33    // SAFETY
34    // As long as nothing ever hands out an `&(mut) Receiver` this is safe.
35    unsafe {
36      self
37        .map_unchecked_mut(|response| &mut response.receiver)
38        .poll(cx)
39    }
40  }
41}
42
43pub struct SurrealClient {
44  socket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
45  resp_sink: mpsc::UnboundedSender<(String, SurrealResponseSender)>,
46}
47
48impl SurrealClient {
49  pub async fn new(url: &str) -> RpcResult<Self> {
50    let (socket, _) = tokio_tungstenite::connect_async(url).await?;
51    let (socket_sink, mut socket_stream) = socket.split();
52    let (resp_sink, resp_stream) = tokio::sync::mpsc::unbounded_channel();
53    let mut recv_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(resp_stream);
54
55    tokio::spawn(async move {
56      let mut requests: HashMap<String, SurrealResponseSender> = HashMap::new();
57
58      loop {
59        tokio::select! {
60            receiver = recv_stream.next() => {
61              if let Some((id, sender)) = receiver {
62                  requests.insert(id, sender);
63              }
64            },
65
66            res = socket_stream.next() => {
67              if let Some(Ok(Message::Text(json_message))) = res {
68                match serde_json::from_str::<SurrealResponseData>(&json_message) {
69                  Ok(response) => if let Some(sender) = requests.remove(&response.id) {
70                    if let Err(_) = sender.send(response) {
71                      // do nothing at the moment, an error from a .send() call
72                      // means the receiver is no longer listening. Which is a
73                      // possible & valid state.
74                    }
75                  },
76                  Err(_) => {
77                    // TODO: this error should be handled, probably by sending
78                    // it through the `sender`. But that would require the
79                    // `SurrealResponseSender` to accept an enum as it only accepts
80                    // valid data at the moment.
81                  },
82                };
83              }
84            },
85        }
86      }
87    });
88
89    Ok(Self {
90      socket_sink,
91      resp_sink,
92    })
93  }
94
95  pub async fn signin<T: AsRef<str>>(&mut self, user: T, pass: T) -> RpcResult<()>
96  where
97    String: From<T>,
98  {
99    self
100      .send_message(
101        "signin",
102        json!([{
103            "user": String::from(user),
104            "pass": String::from(pass)
105        }]),
106      )
107      .await?
108      .await?;
109
110    Ok(())
111  }
112
113  pub async fn use_namespace<T: AsRef<str>>(&mut self, namespace: T, database: T) -> RpcResult<()>
114  where
115    String: From<T>,
116  {
117    self
118      .send_message(
119        "use",
120        json!([String::from(namespace), String::from(database)]),
121      )
122      .await?
123      .await?;
124
125    Ok(())
126  }
127
128  pub async fn send_message(
129    &mut self, method: &'static str, params: Value,
130  ) -> RpcResult<SurrealResponse> {
131    const ALPHABET: [char; 36] = [
132      '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
133      'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
134    ];
135
136    let message = SurrealMessage {
137      id: nanoid::nanoid!(10, &ALPHABET),
138      method: method.to_owned(),
139      params,
140    };
141
142    let (tx, rx) = oneshot::channel();
143
144    self.resp_sink.send((message.id.clone(), tx)).unwrap();
145    self
146      .socket_sink
147      .send(Message::Text(serde_json::to_string(&message).unwrap()))
148      .await?;
149
150    Ok(SurrealResponse { receiver: rx })
151  }
152
153  /// Send a query using the current socket connection then return the raw [SurrealResponse]
154  pub async fn send_query(&mut self, query: String, params: Value) -> RpcResult<SurrealResponse> {
155    Ok(self.send_message("query", json!([query, params])).await?)
156  }
157
158  /// Send a query using the current socket connection then return the **first** [Value]
159  /// from the received [SurrealResponse]
160  ///
161  /// Use [`Self::find_one()`] instead to get a typed return value.
162  async fn find_one_value(&mut self, query: String, params: Value) -> RpcResult<Option<Value>> {
163    let response = self.send_query(query, params).await?.await?;
164
165    Ok(
166      response
167        .get_nth_query_result(0)
168        .and_then(|query_results| query_results.results().first().cloned()),
169    )
170  }
171
172  /// Send a query using the current socket connection then return the **first** [T]
173  /// from the received [SurrealResponse].
174  pub async fn find_one<T: DeserializeOwned>(
175    &mut self, query: String, params: Value,
176  ) -> RpcResult<Option<T>> {
177    let value = self.find_one_value(query, params).await?;
178
179    match value {
180      None => Ok(None),
181      Some(inner) => {
182        let deser_result = serde_json::from_value::<T>(inner)?;
183
184        Ok(Some(deser_result))
185      }
186    }
187  }
188
189  /// Fetch the value for the given `key` out of the first row that is returned by
190  /// the supplied `query`. If the key is missing then [None] is returned.
191  pub async fn find_one_key<T: DeserializeOwned>(
192    &mut self, key: &str, query: String, params: Value,
193  ) -> RpcResult<Option<T>> {
194    let response = self.send_query(query, params).await?.await?;
195
196    let value = response
197      .get_nth_query_result(0)
198      .and_then(|query_results| query_results.results_key(key).first().cloned().cloned());
199
200    match value {
201      None => Ok(None),
202      Some(inner) => {
203        let deser_result = serde_json::from_value::<T>(inner)?;
204
205        Ok(Some(deser_result))
206      }
207    }
208  }
209
210  /// Send a query using the current socket connection then return the [Value]s
211  /// from the received [SurrealResponse]
212  ///
213  /// Use [`Self::find_many()`] instead to get a typed return value.
214  async fn find_many_values(&mut self, query: String, params: Value) -> RpcResult<Vec<Value>> {
215    let response = self.send_query(query, params).await?.await?;
216
217    Ok(
218      response
219        .get_nth_query_result(0)
220        .and_then(|query_results| Some(query_results.results().clone()))
221        .unwrap_or_default(),
222    )
223  }
224
225  /// Send a query using the current socket connection then return the many [`<T>`]
226  /// from the received [SurrealResponse].
227  pub async fn find_many<T: DeserializeOwned>(
228    &mut self, query: String, params: Value,
229  ) -> RpcResult<Vec<T>> {
230    let values = self.find_many_values(query, params).await?;
231    let deser_result: Vec<T> = serde_json::from_value(Value::Array(values))?;
232
233    Ok(deser_result)
234  }
235
236  /// Get the value for every row that were returned by the supplied `query` and
237  /// where `key` exists. If the `key` is missing from a row then the row will
238  /// be filtered out of the returned [Vec].
239  pub async fn find_many_key<T: DeserializeOwned>(
240    &mut self, key: &str, query: String, params: Value,
241  ) -> RpcResult<Vec<T>> {
242    let response = self.send_query(query, params).await?.await?;
243
244    let values = response
245      .get_nth_query_result(0)
246      .and_then(|query_results| Some(query_results.results_key(key)))
247      .unwrap_or_default();
248
249    let deser_result: Vec<T> =
250      serde_json::from_value(Value::Array(values.into_iter().cloned().collect()))?;
251
252    Ok(deser_result)
253  }
254}