1use std::collections::HashMap;
2
3use futures::stream::SplitSink;
4use futures::Future;
5use futures::SinkExt;
6use futures::StreamExt;
7use serde::de::DeserializeOwned;
8use serde_json::json;
9use serde_json::Value;
10use tokio::net::TcpStream;
11use tokio::sync::mpsc;
12use tokio::sync::oneshot;
13use tokio_tungstenite::tungstenite::Message;
14use tokio_tungstenite::MaybeTlsStream;
15use tokio_tungstenite::WebSocketStream;
16
17use crate::rpc::RpcResult;
18use crate::SurrealMessage;
19use crate::SurrealResponseData;
20
21type SurrealResponseSender = oneshot::Sender<SurrealResponseData>;
22
23#[derive(Debug)]
24pub struct SurrealResponse {
25 receiver: oneshot::Receiver<SurrealResponseData>,
26}
27impl Future for SurrealResponse {
28 type Output = <oneshot::Receiver<SurrealResponseData> as Future>::Output;
29
30 fn poll(
31 self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>,
32 ) -> std::task::Poll<Self::Output> {
33 unsafe {
36 self
37 .map_unchecked_mut(|response| &mut response.receiver)
38 .poll(cx)
39 }
40 }
41}
42
43pub struct SurrealClient {
44 socket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
45 resp_sink: mpsc::UnboundedSender<(String, SurrealResponseSender)>,
46}
47
48impl SurrealClient {
49 pub async fn new(url: &str) -> RpcResult<Self> {
50 let (socket, _) = tokio_tungstenite::connect_async(url).await?;
51 let (socket_sink, mut socket_stream) = socket.split();
52 let (resp_sink, resp_stream) = tokio::sync::mpsc::unbounded_channel();
53 let mut recv_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(resp_stream);
54
55 tokio::spawn(async move {
56 let mut requests: HashMap<String, SurrealResponseSender> = HashMap::new();
57
58 loop {
59 tokio::select! {
60 receiver = recv_stream.next() => {
61 if let Some((id, sender)) = receiver {
62 requests.insert(id, sender);
63 }
64 },
65
66 res = socket_stream.next() => {
67 if let Some(Ok(Message::Text(json_message))) = res {
68 match serde_json::from_str::<SurrealResponseData>(&json_message) {
69 Ok(response) => if let Some(sender) = requests.remove(&response.id) {
70 if let Err(_) = sender.send(response) {
71 }
75 },
76 Err(_) => {
77 },
82 };
83 }
84 },
85 }
86 }
87 });
88
89 Ok(Self {
90 socket_sink,
91 resp_sink,
92 })
93 }
94
95 pub async fn signin<T: AsRef<str>>(&mut self, user: T, pass: T) -> RpcResult<()>
96 where
97 String: From<T>,
98 {
99 self
100 .send_message(
101 "signin",
102 json!([{
103 "user": String::from(user),
104 "pass": String::from(pass)
105 }]),
106 )
107 .await?
108 .await?;
109
110 Ok(())
111 }
112
113 pub async fn use_namespace<T: AsRef<str>>(&mut self, namespace: T, database: T) -> RpcResult<()>
114 where
115 String: From<T>,
116 {
117 self
118 .send_message(
119 "use",
120 json!([String::from(namespace), String::from(database)]),
121 )
122 .await?
123 .await?;
124
125 Ok(())
126 }
127
128 pub async fn send_message(
129 &mut self, method: &'static str, params: Value,
130 ) -> RpcResult<SurrealResponse> {
131 const ALPHABET: [char; 36] = [
132 '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
133 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
134 ];
135
136 let message = SurrealMessage {
137 id: nanoid::nanoid!(10, &ALPHABET),
138 method: method.to_owned(),
139 params,
140 };
141
142 let (tx, rx) = oneshot::channel();
143
144 self.resp_sink.send((message.id.clone(), tx)).unwrap();
145 self
146 .socket_sink
147 .send(Message::Text(serde_json::to_string(&message).unwrap()))
148 .await?;
149
150 Ok(SurrealResponse { receiver: rx })
151 }
152
153 pub async fn send_query(&mut self, query: String, params: Value) -> RpcResult<SurrealResponse> {
155 Ok(self.send_message("query", json!([query, params])).await?)
156 }
157
158 async fn find_one_value(&mut self, query: String, params: Value) -> RpcResult<Option<Value>> {
163 let response = self.send_query(query, params).await?.await?;
164
165 Ok(
166 response
167 .get_nth_query_result(0)
168 .and_then(|query_results| query_results.results().first().cloned()),
169 )
170 }
171
172 pub async fn find_one<T: DeserializeOwned>(
175 &mut self, query: String, params: Value,
176 ) -> RpcResult<Option<T>> {
177 let value = self.find_one_value(query, params).await?;
178
179 match value {
180 None => Ok(None),
181 Some(inner) => {
182 let deser_result = serde_json::from_value::<T>(inner)?;
183
184 Ok(Some(deser_result))
185 }
186 }
187 }
188
189 pub async fn find_one_key<T: DeserializeOwned>(
192 &mut self, key: &str, query: String, params: Value,
193 ) -> RpcResult<Option<T>> {
194 let response = self.send_query(query, params).await?.await?;
195
196 let value = response
197 .get_nth_query_result(0)
198 .and_then(|query_results| query_results.results_key(key).first().cloned().cloned());
199
200 match value {
201 None => Ok(None),
202 Some(inner) => {
203 let deser_result = serde_json::from_value::<T>(inner)?;
204
205 Ok(Some(deser_result))
206 }
207 }
208 }
209
210 async fn find_many_values(&mut self, query: String, params: Value) -> RpcResult<Vec<Value>> {
215 let response = self.send_query(query, params).await?.await?;
216
217 Ok(
218 response
219 .get_nth_query_result(0)
220 .and_then(|query_results| Some(query_results.results().clone()))
221 .unwrap_or_default(),
222 )
223 }
224
225 pub async fn find_many<T: DeserializeOwned>(
228 &mut self, query: String, params: Value,
229 ) -> RpcResult<Vec<T>> {
230 let values = self.find_many_values(query, params).await?;
231 let deser_result: Vec<T> = serde_json::from_value(Value::Array(values))?;
232
233 Ok(deser_result)
234 }
235
236 pub async fn find_many_key<T: DeserializeOwned>(
240 &mut self, key: &str, query: String, params: Value,
241 ) -> RpcResult<Vec<T>> {
242 let response = self.send_query(query, params).await?.await?;
243
244 let values = response
245 .get_nth_query_result(0)
246 .and_then(|query_results| Some(query_results.results_key(key)))
247 .unwrap_or_default();
248
249 let deser_result: Vec<T> =
250 serde_json::from_value(Value::Array(values.into_iter().cloned().collect()))?;
251
252 Ok(deser_result)
253 }
254}