1use super::live;
2use super::Stream;
3
4use crate::api::conn::Method;
5use crate::api::conn::Param;
6use crate::api::err::Error;
7use crate::api::opt;
8use crate::api::Connection;
9use crate::api::ExtraFeatures;
10use crate::api::Result;
11use crate::engine::any::Any;
12use crate::method::OnceLockExt;
13use crate::method::Stats;
14use crate::method::WithStats;
15use crate::sql;
16use crate::sql::to_value;
17use crate::sql::Statement;
18use crate::sql::Value;
19use crate::Notification;
20use crate::Surreal;
21use futures::future::Either;
22use futures::stream::SelectAll;
23use futures::StreamExt;
24use indexmap::IndexMap;
25use serde::de::DeserializeOwned;
26use serde::Serialize;
27use std::borrow::Cow;
28use std::collections::BTreeMap;
29use std::collections::HashMap;
30use std::future::Future;
31use std::future::IntoFuture;
32use std::mem;
33use std::pin::Pin;
34use std::task::Context;
35use std::task::Poll;
36
/// A query being built against a client; awaiting it executes the statements.
///
/// Builder methods such as [`Query::query`] and [`Query::bind`] go through
/// `map_valid`, which only runs on a still-valid query. The first error hit
/// while building is therefore kept in `inner` and returned when the query is awaited.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Query<'r, C: Connection> {
    /// The query under construction, or the first error encountered while building it.
    pub(crate) inner: Result<ValidQuery<'r, C>>,
}
43
/// The parts of a [`Query`] that has not failed during construction.
#[derive(Debug)]
pub(crate) struct ValidQuery<'r, C: Connection> {
    /// Client the query is executed through (borrowed or owned).
    pub client: Cow<'r, Surreal<C>>,
    /// Statements to execute, in order.
    pub query: Vec<Statement>,
    /// Parameters added via [`Query::bind`], keyed by parameter name.
    pub bindings: BTreeMap<String, Value>,
    /// Whether `LIVE` statements should get notification streams registered
    /// when the query is awaited.
    pub register_live_queries: bool,
}
51
52impl<'r, C> Query<'r, C>
53where
54 C: Connection,
55{
56 pub(crate) fn new(
57 client: Cow<'r, Surreal<C>>,
58 query: Vec<Statement>,
59 bindings: BTreeMap<String, Value>,
60 register_live_queries: bool,
61 ) -> Self {
62 Query {
63 inner: Ok(ValidQuery {
64 client,
65 query,
66 bindings,
67 register_live_queries,
68 }),
69 }
70 }
71
72 pub(crate) fn map_valid<F>(self, f: F) -> Self
73 where
74 F: FnOnce(ValidQuery<'r, C>) -> Result<ValidQuery<'r, C>>,
75 {
76 match self.inner {
77 Ok(x) => Query {
78 inner: f(x),
79 },
80 x => Query {
81 inner: x,
82 },
83 }
84 }
85
86 pub fn into_owned(self) -> Query<'static, C> {
88 let inner = match self.inner {
89 Ok(ValidQuery {
90 client,
91 query,
92 bindings,
93 register_live_queries,
94 }) => Ok(ValidQuery::<'static, C> {
95 client: Cow::Owned(client.into_owned()),
96 query,
97 bindings,
98 register_live_queries,
99 }),
100 Err(e) => Err(e),
101 };
102
103 Query {
104 inner,
105 }
106 }
107}
108
109impl<'r, Client> IntoFuture for Query<'r, Client>
110where
111 Client: Connection,
112{
113 type Output = Result<Response>;
114 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + Sync + 'r>>;
115
116 fn into_future(self) -> Self::IntoFuture {
117 let ValidQuery {
118 client,
119 query,
120 bindings,
121 register_live_queries,
122 } = match self.inner {
123 Ok(x) => x,
124 Err(error) => return Box::pin(async move { Err(error) }),
125 };
126
127 let query_statements = query;
128
129 Box::pin(async move {
130 let router = client.router.extract()?;
132
133 let query_indicies = if register_live_queries {
135 query_statements
136 .iter()
137 .filter(|x| {
139 !matches!(
140 x,
141 Statement::Begin(_) | Statement::Commit(_) | Statement::Cancel(_)
142 )
143 })
144 .enumerate()
145 .filter(|(_, x)| matches!(x, Statement::Live(_)))
146 .map(|(i, _)| i)
147 .collect()
148 } else {
149 Vec::new()
150 };
151
152 if !query_indicies.is_empty() && !router.features.contains(&ExtraFeatures::LiveQueries)
154 {
155 return Err(Error::LiveQueriesNotSupported.into());
156 }
157
158 let mut query = sql::Query::default();
159 query.0 .0 = query_statements;
160
161 let param = Param::query(query, bindings);
162 let mut conn = Client::new(Method::Query);
163 let mut response = conn.execute_query(router, param).await?;
164
165 for idx in query_indicies {
166 let Some((_, result)) = response.results.get(&idx) else {
167 continue;
168 };
169
170 let res = match result {
173 Ok(id) => live::register::<Client>(router, id.clone()).await.map(|rx| {
174 Stream::new(
175 Surreal::new_from_router_waiter(
176 client.router.clone(),
177 client.waiter.clone(),
178 ),
179 id.clone(),
180 Some(rx),
181 )
182 }),
183 Err(_) => Err(crate::Error::from(Error::NotLiveQuery(idx))),
184 };
185
186 dbg!(&response);
187 response.live_queries.insert(idx, res);
188 }
189
190 response.client =
191 Surreal::new_from_router_waiter(client.router.clone(), client.waiter.clone());
192 Ok(response)
193 })
194 }
195}
196
197impl<'r, Client> IntoFuture for WithStats<Query<'r, Client>>
198where
199 Client: Connection,
200{
201 type Output = Result<WithStats<Response>>;
202 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + Sync + 'r>>;
203
204 fn into_future(self) -> Self::IntoFuture {
205 Box::pin(async move {
206 let response = self.0.await?;
207 Ok(WithStats(response))
208 })
209 }
210}
211
212impl<'r, C> Query<'r, C>
213where
214 C: Connection,
215{
216 pub fn query(self, query: impl opt::IntoQuery) -> Self {
218 self.map_valid(move |mut valid| {
219 let new_query = query.into_query()?;
220 valid.query.extend(new_query);
221 Ok(valid)
222 })
223 }
224
225 pub const fn with_stats(self) -> WithStats<Self> {
227 WithStats(self)
228 }
229
230 pub fn bind(self, bindings: impl Serialize) -> Self {
272 self.map_valid(move |mut valid| {
273 let mut bindings = to_value(bindings)?;
274 if let Value::Array(array) = &mut bindings {
275 if let [Value::Strand(key), value] = &mut array.0[..] {
276 let mut map = BTreeMap::new();
277 map.insert(mem::take(&mut key.0), mem::take(value));
278 bindings = map.into();
279 }
280 }
281 match &mut bindings {
282 Value::Object(map) => valid.bindings.append(&mut map.0),
283 _ => {
284 return Err(Error::InvalidBindings(bindings).into());
285 }
286 }
287
288 Ok(valid)
289 })
290 }
291}
292
/// The outcome of a single statement: its value, or the error it produced.
pub(crate) type QueryResult = Result<Value>;
294
/// The response to an executed query, holding each statement's result.
#[derive(Debug)]
pub struct Response {
    /// Client handle attached to the response.
    pub(crate) client: Surreal<Any>,
    /// Per-statement statistics and results, keyed by statement index.
    pub(crate) results: IndexMap<usize, (Stats, QueryResult)>,
    /// Notification streams (or registration errors) for `LIVE` statements,
    /// keyed by statement index.
    pub(crate) live_queries: IndexMap<usize, Result<Stream<'static, Any, Value>>>,
}
302
/// A stream of live-query notifications returned by [`Response::stream`].
///
/// Wraps either a single notification stream or several streams merged with
/// [`SelectAll`].
#[derive(Debug)]
#[must_use = "streams do nothing unless you poll them"]
pub struct QueryStream<R>(
    pub(crate) Either<Stream<'static, Any, R>, SelectAll<Stream<'static, Any, R>>>,
);
309
310impl futures::Stream for QueryStream<Value> {
311 type Item = Notification<Value>;
312
313 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
314 self.as_mut().0.poll_next_unpin(cx)
315 }
316}
317
318impl<R> futures::Stream for QueryStream<Notification<R>>
319where
320 R: DeserializeOwned + Unpin,
321{
322 type Item = Result<Notification<R>>;
323
324 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
325 self.as_mut().0.poll_next_unpin(cx)
326 }
327}
328
329impl Response {
330 pub(crate) fn new() -> Self {
331 Self {
332 client: Surreal::init(),
333 results: Default::default(),
334 live_queries: Default::default(),
335 }
336 }
337
338 pub fn take<R>(&mut self, index: impl opt::QueryResult<R>) -> Result<R>
400 where
401 R: DeserializeOwned,
402 {
403 index.query_result(self)
404 }
405
406 pub fn stream<R>(&mut self, index: impl opt::QueryStream<R>) -> Result<QueryStream<R>> {
450 index.query_stream(self)
451 }
452
453 pub fn take_errors(&mut self) -> HashMap<usize, crate::Error> {
472 let mut keys = Vec::new();
473 for (key, result) in &self.results {
474 if result.1.is_err() {
475 keys.push(*key);
476 }
477 }
478 let mut errors = HashMap::with_capacity(keys.len());
479 for key in keys {
480 if let Some((_, Err(error))) = self.results.swap_remove(&key) {
481 errors.insert(key, error);
482 }
483 }
484 errors
485 }
486
487 pub fn check(mut self) -> Result<Self> {
503 let mut first_error = None;
504 for (key, result) in &self.results {
505 if result.1.is_err() {
506 first_error = Some(*key);
507 break;
508 }
509 }
510 if let Some(key) = first_error {
511 if let Some((_, Err(error))) = self.results.swap_remove(&key) {
512 return Err(error);
513 }
514 }
515 Ok(self)
516 }
517
518 pub fn num_statements(&self) -> usize {
535 self.results.len()
536 }
537}
538
539impl WithStats<Response> {
540 pub fn take<R>(&mut self, index: impl opt::QueryResult<R>) -> Option<(Stats, Result<R>)>
604 where
605 R: DeserializeOwned,
606 {
607 let stats = index.stats(&self.0)?;
608 let result = index.query_result(&mut self.0);
609 Some((stats, result))
610 }
611
612 pub fn take_errors(&mut self) -> HashMap<usize, (Stats, crate::Error)> {
631 let mut keys = Vec::new();
632 for (key, result) in &self.0.results {
633 if result.1.is_err() {
634 keys.push(*key);
635 }
636 }
637 let mut errors = HashMap::with_capacity(keys.len());
638 for key in keys {
639 if let Some((stats, Err(error))) = self.0.results.swap_remove(&key) {
640 errors.insert(key, (stats, error));
641 }
642 }
643 errors
644 }
645
646 pub fn check(self) -> Result<Self> {
662 let response = self.0.check()?;
663 Ok(Self(response))
664 }
665
666 pub fn num_statements(&self) -> usize {
683 self.0.num_statements()
684 }
685
686 pub fn into_inner(self) -> Response {
688 self.0
689 }
690}
691
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Error::Api;
    use serde::Deserialize;

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Summary {
        title: String,
    }

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Article {
        title: String,
        body: String,
    }

    /// Keys each result by its position and pairs it with zeroed statistics.
    fn to_map(vec: Vec<QueryResult>) -> IndexMap<usize, (Stats, QueryResult)> {
        let mut map = IndexMap::with_capacity(vec.len());
        for (index, result) in vec.into_iter().enumerate() {
            let stats = Stats {
                execution_time: Default::default(),
            };
            map.insert(index, (stats, result));
        }
        map
    }

    /// Builds a response holding `results`, with every other field defaulted.
    fn response_with(results: Vec<QueryResult>) -> Response {
        Response {
            results: to_map(results),
            ..Response::new()
        }
    }

    /// Eleven results with errors at statement indices 3, 7 and 10.
    fn mixed_results() -> Vec<QueryResult> {
        vec![
            Ok(0.into()),
            Ok(1.into()),
            Ok(2.into()),
            Err(Error::ConnectionUninitialised.into()),
            Ok(3.into()),
            Ok(4.into()),
            Ok(5.into()),
            Err(Error::BackupsNotSupported.into()),
            Ok(6.into()),
            Ok(7.into()),
            Err(Error::DuplicateRequestId(0).into()),
        ]
    }

    #[test]
    fn take_from_an_empty_response() {
        let value: Value = Response::new().take(0).unwrap();
        assert!(value.is_none());

        let option: Option<String> = Response::new().take(0).unwrap();
        assert!(option.is_none());

        let vec: Vec<String> = Response::new().take(0).unwrap();
        assert!(vec.is_empty());
    }

    #[test]
    fn take_from_an_errored_query() {
        let mut response = response_with(vec![Err(Error::ConnectionUninitialised.into())]);
        response.take::<Option<()>>(0).unwrap_err();
    }

    #[test]
    fn take_from_empty_records() {
        let value: Value = response_with(vec![]).take(0).unwrap();
        assert_eq!(value, Default::default());

        let option: Option<String> = response_with(vec![]).take(0).unwrap();
        assert!(option.is_none());

        let vec: Vec<String> = response_with(vec![]).take(0).unwrap();
        assert!(vec.is_empty());
    }

    #[test]
    fn take_from_a_scalar_response() {
        let scalar = 265;

        let value: Value = response_with(vec![Ok(scalar.into())]).take(0).unwrap();
        assert_eq!(value, Value::from(scalar));

        let option: Option<_> = response_with(vec![Ok(scalar.into())]).take(0).unwrap();
        assert_eq!(option, Some(scalar));

        let vec: Vec<usize> = response_with(vec![Ok(scalar.into())]).take(0).unwrap();
        assert_eq!(vec, vec![scalar]);

        let scalar = true;

        let value: Value = response_with(vec![Ok(scalar.into())]).take(0).unwrap();
        assert_eq!(value, Value::from(scalar));

        let option: Option<_> = response_with(vec![Ok(scalar.into())]).take(0).unwrap();
        assert_eq!(option, Some(scalar));

        let vec: Vec<bool> = response_with(vec![Ok(scalar.into())]).take(0).unwrap();
        assert_eq!(vec, vec![scalar]);
    }

    #[test]
    fn take_preserves_order() {
        let mut response = response_with((0..8).map(|n| Ok(n.into())).collect());

        assert_eq!(response.take::<Option<i32>>(4).unwrap(), Some(4), "query not found");
        assert_eq!(response.take::<Option<i32>>(6).unwrap(), Some(6), "query not found");
        assert_eq!(response.take::<Option<i32>>(0).unwrap(), Some(0), "query not found");

        let one: Value = response.take(1).unwrap();
        assert_eq!(one, Value::from(1));
    }

    #[test]
    fn take_key() {
        let summary = Summary {
            title: "Lorem Ipsum".to_owned(),
        };
        let value = to_value(summary.clone()).unwrap();

        let title: Value = response_with(vec![Ok(value.clone())]).take("title").unwrap();
        assert_eq!(title, Value::from(summary.title.as_str()));

        let title: Option<String> = response_with(vec![Ok(value.clone())]).take("title").unwrap();
        assert_eq!(title, Some(summary.title.clone()), "title not found");

        let titles: Vec<String> = response_with(vec![Ok(value)]).take("title").unwrap();
        assert_eq!(titles, vec![summary.title]);

        let article = Article {
            title: "Lorem Ipsum".to_owned(),
            body: "Lorem Ipsum Lorem Ipsum".to_owned(),
        };
        let value = to_value(article.clone()).unwrap();

        // Both fields are taken from the same response.
        let mut response = response_with(vec![Ok(value.clone())]);
        let title: Option<String> = response.take("title").unwrap();
        assert_eq!(title.as_deref(), Some(article.title.as_str()), "title not found");
        let body: Option<String> = response.take("body").unwrap();
        assert_eq!(body.as_deref(), Some(article.body.as_str()), "body not found");

        let titles: Vec<String> = response_with(vec![Ok(value.clone())]).take("title").unwrap();
        assert_eq!(titles, vec![article.title.clone()]);

        let title: Value = response_with(vec![Ok(value)]).take("title").unwrap();
        assert_eq!(title, Value::from(article.title));
    }

    #[test]
    fn take_partial_records() {
        let partial = || -> QueryResult { Ok(vec![true, false].into()) };

        let value: Value = response_with(vec![partial()]).take(0).unwrap();
        assert_eq!(value, vec![Value::from(true), Value::from(false)].into());

        let flags: Vec<bool> = response_with(vec![partial()]).take(0).unwrap();
        assert_eq!(flags, vec![true, false]);

        let Err(Api(Error::LossyTake(Response {
            results: mut map,
            ..
        }))): Result<Option<bool>> = response_with(vec![partial()]).take(0)
        else {
            panic!("silently dropping records not allowed");
        };
        let records = map.swap_remove(&0).unwrap().1.unwrap();
        assert_eq!(records, vec![true, false].into());
    }

    #[test]
    fn check_returns_the_first_error() {
        let outcome = response_with(mixed_results()).check();
        assert!(
            matches!(outcome, Err(Api(Error::ConnectionUninitialised))),
            "check did not return the first error"
        );
    }

    #[test]
    fn take_errors() {
        let mut response = response_with(mixed_results());
        let errors = response.take_errors();

        assert_eq!(response.num_statements(), 8);
        assert_eq!(errors.len(), 3);
        assert!(
            matches!(errors.get(&10), Some(Api(Error::DuplicateRequestId(0)))),
            "index `10` is not `DuplicateRequestId`"
        );
        assert!(
            matches!(errors.get(&7), Some(Api(Error::BackupsNotSupported))),
            "index `7` is not `BackupsNotSupported`"
        );
        assert!(
            matches!(errors.get(&3), Some(Api(Error::ConnectionUninitialised))),
            "index `3` is not `ConnectionUninitialised`"
        );

        let two: Option<i32> = response.take(2).unwrap();
        assert_eq!(two, Some(2), "statement not found");
        let four: Value = response.take(4).unwrap();
        assert_eq!(four, Value::from(3));
    }
}
1007}