1use super::{live, Stream};
2use crate::api::conn::Command;
3use crate::api::err::Error;
4use crate::api::method::BoxFuture;
5use crate::api::opt;
6use crate::api::Connection;
7use crate::api::ExtraFeatures;
8use crate::api::Result;
9use crate::engine::any::Any;
10use crate::method::OnceLockExt;
11use crate::method::Stats;
12use crate::method::WithStats;
13use crate::value::Notification;
14use crate::{Surreal, Value};
15use futures::future::Either;
16use futures::stream::SelectAll;
17use futures::StreamExt;
18use indexmap::IndexMap;
19use serde::de::DeserializeOwned;
20use serde::Serialize;
21use std::borrow::Cow;
22use std::collections::HashMap;
23use std::future::IntoFuture;
24use std::pin::Pin;
25use std::task::Context;
26use std::task::Poll;
27use surrealdb_core::sql::{
28 self, to_value as to_core_value, Object as CoreObject, Statement, Value as CoreValue,
29};
30
31#[derive(Debug)]
33#[must_use = "futures do nothing unless you `.await` or poll them"]
34pub struct Query<'r, C: Connection> {
35 pub(crate) inner: Result<ValidQuery<'r, C>>,
36}
37
38#[derive(Debug)]
39pub(crate) struct ValidQuery<'r, C: Connection> {
40 pub client: Cow<'r, Surreal<C>>,
41 pub query: Vec<Statement>,
42 pub bindings: CoreObject,
43 pub register_live_queries: bool,
44}
45
46impl<'r, C> Query<'r, C>
47where
48 C: Connection,
49{
50 pub(crate) fn new(
51 client: Cow<'r, Surreal<C>>,
52 query: Vec<Statement>,
53 bindings: CoreObject,
54 register_live_queries: bool,
55 ) -> Self {
56 Query {
57 inner: Ok(ValidQuery {
58 client,
59 query,
60 bindings,
61 register_live_queries,
62 }),
63 }
64 }
65
66 pub(crate) fn map_valid<F>(self, f: F) -> Self
67 where
68 F: FnOnce(ValidQuery<'r, C>) -> Result<ValidQuery<'r, C>>,
69 {
70 match self.inner {
71 Ok(x) => Query {
72 inner: f(x),
73 },
74 x => Query {
75 inner: x,
76 },
77 }
78 }
79
80 pub fn into_owned(self) -> Query<'static, C> {
82 let inner = match self.inner {
83 Ok(ValidQuery {
84 client,
85 query,
86 bindings,
87 register_live_queries,
88 }) => Ok(ValidQuery::<'static, C> {
89 client: Cow::Owned(client.into_owned()),
90 query,
91 bindings,
92 register_live_queries,
93 }),
94 Err(e) => Err(e),
95 };
96
97 Query {
98 inner,
99 }
100 }
101}
102
103impl<'r, Client> IntoFuture for Query<'r, Client>
104where
105 Client: Connection,
106{
107 type Output = Result<Response>;
108 type IntoFuture = BoxFuture<'r, Self::Output>;
109
110 fn into_future(self) -> Self::IntoFuture {
111 let ValidQuery {
112 client,
113 query,
114 bindings,
115 register_live_queries,
116 } = match self.inner {
117 Ok(x) => x,
118 Err(error) => return Box::pin(async move { Err(error) }),
119 };
120
121 let query_statements = query;
122
123 Box::pin(async move {
124 let router = client.router.extract()?;
126
127 let query_indicies = if register_live_queries {
129 query_statements
130 .iter()
131 .filter(|x| {
133 !matches!(
134 x,
135 Statement::Begin(_) | Statement::Commit(_) | Statement::Cancel(_)
136 )
137 })
138 .enumerate()
139 .filter(|(_, x)| matches!(x, Statement::Live(_)))
140 .map(|(i, _)| i)
141 .collect()
142 } else {
143 Vec::new()
144 };
145
146 if !query_indicies.is_empty() && !router.features.contains(&ExtraFeatures::LiveQueries)
148 {
149 return Err(Error::LiveQueriesNotSupported.into());
150 }
151
152 let mut query = sql::Query::default();
153 query.0 .0 = query_statements;
154
155 let mut response = router
156 .execute_query(Command::Query {
157 query,
158 variables: bindings,
159 })
160 .await?;
161
162 for idx in query_indicies {
163 let Some((_, result)) = response.results.get(&idx) else {
164 continue;
165 };
166
167 let res = match result {
170 Ok(id) => {
171 let CoreValue::Uuid(uuid) = id else {
172 return Err(Error::InternalError(
173 "successfull live query did not return a uuid".to_string(),
174 )
175 .into());
176 };
177 live::register(router, uuid.0).await.map(|rx| {
178 Stream::new(
179 Surreal::new_from_router_waiter(
180 client.router.clone(),
181 client.waiter.clone(),
182 ),
183 uuid.0,
184 Some(rx),
185 )
186 })
187 }
188 Err(_) => Err(crate::Error::from(Error::NotLiveQuery(idx))),
189 };
190 response.live_queries.insert(idx, res);
191 }
192
193 response.client =
194 Surreal::new_from_router_waiter(client.router.clone(), client.waiter.clone());
195 Ok(response)
196 })
197 }
198}
199
200impl<'r, Client> IntoFuture for WithStats<Query<'r, Client>>
201where
202 Client: Connection,
203{
204 type Output = Result<WithStats<Response>>;
205 type IntoFuture = BoxFuture<'r, Self::Output>;
206
207 fn into_future(self) -> Self::IntoFuture {
208 Box::pin(async move {
209 let response = self.0.await?;
210 Ok(WithStats(response))
211 })
212 }
213}
214
215impl<'r, C> Query<'r, C>
216where
217 C: Connection,
218{
219 pub fn query(self, query: impl opt::IntoQuery) -> Self {
221 self.map_valid(move |mut valid| {
222 let new_query = query.into_query()?;
223 valid.query.extend(new_query);
224 Ok(valid)
225 })
226 }
227
228 pub const fn with_stats(self) -> WithStats<Self> {
230 WithStats(self)
231 }
232
233 pub fn bind(self, bindings: impl Serialize + 'static) -> Self {
275 self.map_valid(move |mut valid| {
276 let bindings = to_core_value(bindings)?;
277 match bindings {
278 CoreValue::Object(mut map) => valid.bindings.append(&mut map.0),
279 CoreValue::Array(array) => {
280 if array.len() != 2 || !matches!(array[0], CoreValue::Strand(_)) {
281 let bindings = CoreValue::Array(array);
282 let bindings = Value::from_inner(bindings);
283 return Err(Error::InvalidBindings(bindings).into());
284 }
285
286 let mut iter = array.into_iter();
287 let Some(CoreValue::Strand(key)) = iter.next() else {
288 unreachable!()
289 };
290 let Some(value) = iter.next() else {
291 unreachable!()
292 };
293
294 valid.bindings.0.insert(key.0, value);
295 }
296 _ => {
297 let bindings = Value::from_inner(bindings);
298 return Err(Error::InvalidBindings(bindings).into());
299 }
300 }
301
302 Ok(valid)
303 })
304 }
305}
306
307pub(crate) type QueryResult = Result<CoreValue>;
308
309#[derive(Debug)]
311pub struct Response {
312 pub(crate) client: Surreal<Any>,
313 pub(crate) results: IndexMap<usize, (Stats, QueryResult)>,
314 pub(crate) live_queries: IndexMap<usize, Result<Stream<Value>>>,
315}
316
317#[derive(Debug)]
319#[must_use = "streams do nothing unless you poll them"]
320pub struct QueryStream<R>(pub(crate) Either<Stream<R>, SelectAll<Stream<R>>>);
321
322impl futures::Stream for QueryStream<Value> {
323 type Item = Notification<Value>;
324
325 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
326 self.as_mut().0.poll_next_unpin(cx)
327 }
328}
329
330impl<R> futures::Stream for QueryStream<Notification<R>>
331where
332 R: DeserializeOwned + Unpin,
333{
334 type Item = Result<Notification<R>>;
335
336 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
337 self.as_mut().0.poll_next_unpin(cx)
338 }
339}
340
341impl Response {
342 pub(crate) fn new() -> Self {
343 Self {
344 client: Surreal::init(),
345 results: Default::default(),
346 live_queries: Default::default(),
347 }
348 }
349
350 pub fn take<R>(&mut self, index: impl opt::QueryResult<R>) -> Result<R>
412 where
413 R: DeserializeOwned,
414 {
415 index.query_result(self)
416 }
417
418 pub fn stream<R>(&mut self, index: impl opt::QueryStream<R>) -> Result<QueryStream<R>> {
462 index.query_stream(self)
463 }
464
465 pub fn take_errors(&mut self) -> HashMap<usize, crate::Error> {
483 let mut keys = Vec::new();
484 for (key, result) in &self.results {
485 if result.1.is_err() {
486 keys.push(*key);
487 }
488 }
489 let mut errors = HashMap::with_capacity(keys.len());
490 for key in keys {
491 if let Some((_, Err(error))) = self.results.swap_remove(&key) {
492 errors.insert(key, error);
493 }
494 }
495 errors
496 }
497
498 pub fn check(mut self) -> Result<Self> {
513 let mut first_error = None;
514 for (key, result) in &self.results {
515 if result.1.is_err() {
516 first_error = Some(*key);
517 break;
518 }
519 }
520 if let Some(key) = first_error {
521 if let Some((_, Err(error))) = self.results.swap_remove(&key) {
522 return Err(error);
523 }
524 }
525 Ok(self)
526 }
527
528 pub fn num_statements(&self) -> usize {
544 self.results.len()
545 }
546}
547
548impl WithStats<Response> {
549 pub fn take<R>(&mut self, index: impl opt::QueryResult<R>) -> Option<(Stats, Result<R>)>
612 where
613 R: DeserializeOwned,
614 {
615 let stats = index.stats(&self.0)?;
616 let result = index.query_result(&mut self.0);
617 Some((stats, result))
618 }
619
620 pub fn take_errors(&mut self) -> HashMap<usize, (Stats, crate::Error)> {
638 let mut keys = Vec::new();
639 for (key, result) in &self.0.results {
640 if result.1.is_err() {
641 keys.push(*key);
642 }
643 }
644 let mut errors = HashMap::with_capacity(keys.len());
645 for key in keys {
646 if let Some((stats, Err(error))) = self.0.results.swap_remove(&key) {
647 errors.insert(key, (stats, error));
648 }
649 }
650 errors
651 }
652
653 pub fn check(self) -> Result<Self> {
668 let response = self.0.check()?;
669 Ok(Self(response))
670 }
671
672 pub fn num_statements(&self) -> usize {
688 self.0.num_statements()
689 }
690
691 pub fn into_inner(self) -> Response {
693 self.0
694 }
695}
696
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{value::to_value, Error::Api};
    use serde::Deserialize;
    use surrealdb_core::sql::Value as CoreValue;

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Summary {
        title: String,
    }

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Article {
        title: String,
        body: String,
    }

    /// Keys each result by its position and pairs it with default statistics.
    fn to_map(vec: Vec<QueryResult>) -> IndexMap<usize, (Stats, QueryResult)> {
        let mut map = IndexMap::with_capacity(vec.len());
        for (index, result) in vec.into_iter().enumerate() {
            let stats = Stats {
                execution_time: Default::default(),
            };
            map.insert(index, (stats, result));
        }
        map
    }

    /// Builds a response holding exactly the given statement results.
    fn response_with(results: Vec<QueryResult>) -> Response {
        Response {
            results: to_map(results),
            ..Response::new()
        }
    }

    #[test]
    fn take_from_an_empty_response() {
        // Every target type has to cope with a response without any results.
        let mut response = Response::new();
        let value: Value = response.take(0).unwrap();
        assert!(value.into_inner().is_none());

        let mut response = Response::new();
        let option: Option<String> = response.take(0).unwrap();
        assert!(option.is_none());

        let mut response = Response::new();
        let vec: Vec<String> = response.take(0).unwrap();
        assert!(vec.is_empty());
    }

    #[test]
    fn take_from_an_errored_query() {
        let mut response = response_with(vec![Err(Error::ConnectionUninitialised.into())]);
        response.take::<Option<()>>(0).unwrap_err();
    }

    #[test]
    fn take_from_empty_records() {
        let mut response = response_with(Vec::new());
        let value: Value = response.take(0).unwrap();
        assert_eq!(value, Default::default());

        let mut response = response_with(Vec::new());
        let option: Option<String> = response.take(0).unwrap();
        assert!(option.is_none());

        let mut response = response_with(Vec::new());
        let vec: Vec<String> = response.take(0).unwrap();
        assert!(vec.is_empty());
    }

    #[test]
    fn take_from_a_scalar_response() {
        // An integer scalar.
        let scalar = 265;

        let mut response = response_with(vec![Ok(scalar.into())]);
        let value: Value = response.take(0).unwrap();
        assert_eq!(value.into_inner(), CoreValue::from(scalar));

        let mut response = response_with(vec![Ok(scalar.into())]);
        let option: Option<_> = response.take(0).unwrap();
        assert_eq!(option, Some(scalar));

        let mut response = response_with(vec![Ok(scalar.into())]);
        let vec: Vec<i64> = response.take(0).unwrap();
        assert_eq!(vec, vec![scalar]);

        // A boolean scalar.
        let scalar = true;

        let mut response = response_with(vec![Ok(scalar.into())]);
        let value: Value = response.take(0).unwrap();
        assert_eq!(value.into_inner(), CoreValue::from(scalar));

        let mut response = response_with(vec![Ok(scalar.into())]);
        let option: Option<_> = response.take(0).unwrap();
        assert_eq!(option, Some(scalar));

        let mut response = response_with(vec![Ok(scalar.into())]);
        let vec: Vec<bool> = response.take(0).unwrap();
        assert_eq!(vec, vec![scalar]);
    }

    #[test]
    fn take_preserves_order() {
        // Eight statements whose value equals their index.
        let mut response = response_with((0..8).map(|n: i32| Ok(n.into())).collect());

        let four: Option<i32> = response.take(4).unwrap();
        let Some(four) = four else {
            panic!("query not found");
        };
        assert_eq!(four, 4);

        let six: Option<i32> = response.take(6).unwrap();
        let Some(six) = six else {
            panic!("query not found");
        };
        assert_eq!(six, 6);

        let zero: Option<i32> = response.take(0).unwrap();
        let Some(zero) = zero else {
            panic!("query not found");
        };
        assert_eq!(zero, 0);

        let one: Value = response.take(1).unwrap();
        assert_eq!(one.into_inner(), CoreValue::from(1));
    }

    #[test]
    fn take_key() {
        // A record with a single field.
        let summary = Summary {
            title: "Lorem Ipsum".to_owned(),
        };
        let value = to_value(summary.clone()).unwrap();

        let mut response = response_with(vec![Ok(value.clone().into_inner())]);
        let title: Value = response.take("title").unwrap();
        assert_eq!(title.into_inner(), CoreValue::from(summary.title.as_str()));

        let mut response = response_with(vec![Ok(value.clone().into_inner())]);
        let title: Option<String> = response.take("title").unwrap();
        let Some(title) = title else {
            panic!("title not found");
        };
        assert_eq!(title, summary.title);

        let mut response = response_with(vec![Ok(value.into_inner())]);
        let vec: Vec<String> = response.take("title").unwrap();
        assert_eq!(vec, vec![summary.title]);

        // A record with two fields, both taken from the same response.
        let article = Article {
            title: "Lorem Ipsum".to_owned(),
            body: "Lorem Ipsum Lorem Ipsum".to_owned(),
        };
        let value = to_value(article.clone()).unwrap();

        let mut response = response_with(vec![Ok(value.clone().into_inner())]);
        let title: Option<String> = response.take("title").unwrap();
        let Some(title) = title else {
            panic!("title not found");
        };
        assert_eq!(title, article.title);
        let body: Option<String> = response.take("body").unwrap();
        let Some(body) = body else {
            panic!("body not found");
        };
        assert_eq!(body, article.body);

        let mut response = response_with(vec![Ok(value.clone().into_inner())]);
        let vec: Vec<String> = response.take("title").unwrap();
        assert_eq!(vec, vec![article.title.clone()]);

        let mut response = response_with(vec![Ok(value.into_inner())]);
        let value: Value = response.take("title").unwrap();
        assert_eq!(value.into_inner(), CoreValue::from(article.title));
    }

    #[test]
    fn take_key_multi() {
        let article = Article {
            title: "Lorem Ipsum".to_owned(),
            body: "Lorem Ipsum Lorem Ipsum".to_owned(),
        };
        let value = to_value(article.clone()).unwrap();

        let mut response = response_with(vec![Ok(value.clone().into_inner())]);
        let title: Vec<String> = response.take("title").unwrap();
        assert_eq!(title, vec![article.title.clone()]);
        let body: Vec<String> = response.take("body").unwrap();
        assert_eq!(body, vec![article.body]);

        let mut response = response_with(vec![Ok(value.clone().into_inner())]);
        let vec: Vec<String> = response.take("title").unwrap();
        assert_eq!(vec, vec![article.title]);
    }

    #[test]
    fn take_partial_records() {
        let mut response = response_with(vec![Ok(vec![true, false].into())]);
        let value: Value = response.take(0).unwrap();
        assert_eq!(value.into_inner(), vec![CoreValue::from(true), CoreValue::from(false)].into());

        let mut response = response_with(vec![Ok(vec![true, false].into())]);
        let vec: Vec<bool> = response.take(0).unwrap();
        assert_eq!(vec, vec![true, false]);

        // Two records cannot be squeezed into an `Option`; the error must hand
        // the untouched records back.
        let mut response = response_with(vec![Ok(vec![true, false].into())]);
        let lossy: Result<Option<bool>> = response.take(0);
        let Err(Api(Error::LossyTake(Response {
            results: mut map,
            ..
        }))) = lossy
        else {
            panic!("silently dropping records not allowed");
        };
        let records = map.swap_remove(&0).unwrap().1.unwrap();
        assert_eq!(records, vec![true, false].into());
    }

    #[test]
    fn check_returns_the_first_error() {
        let results = vec![
            Ok(0.into()),
            Ok(1.into()),
            Ok(2.into()),
            Err(Error::ConnectionUninitialised.into()),
            Ok(3.into()),
            Ok(4.into()),
            Ok(5.into()),
            Err(Error::BackupsNotSupported.into()),
            Ok(6.into()),
            Ok(7.into()),
            Err(Error::DuplicateRequestId(0).into()),
        ];
        let response = response_with(results);
        assert!(
            matches!(
                response.check().unwrap_err(),
                crate::Error::Api(Error::ConnectionUninitialised)
            ),
            "check did not return the first error"
        );
    }

    #[test]
    fn take_errors() {
        let results = vec![
            Ok(0.into()),
            Ok(1.into()),
            Ok(2.into()),
            Err(Error::ConnectionUninitialised.into()),
            Ok(3.into()),
            Ok(4.into()),
            Ok(5.into()),
            Err(Error::BackupsNotSupported.into()),
            Ok(6.into()),
            Ok(7.into()),
            Err(Error::DuplicateRequestId(0).into()),
        ];
        let mut response = response_with(results);

        // The three failures leave the response; the eight successes stay.
        let errors = response.take_errors();
        assert_eq!(response.num_statements(), 8);
        assert_eq!(errors.len(), 3);
        assert!(
            matches!(errors.get(&10).unwrap(), crate::Error::Api(Error::DuplicateRequestId(0))),
            "index `10` is not `DuplicateRequestId`"
        );
        assert!(
            matches!(errors.get(&7).unwrap(), crate::Error::Api(Error::BackupsNotSupported)),
            "index `7` is not `BackupsNotSupported`"
        );
        assert!(
            matches!(errors.get(&3).unwrap(), crate::Error::Api(Error::ConnectionUninitialised)),
            "index `3` is not `ConnectionUninitialised`"
        );

        // The surviving results are still addressed by their original index.
        let value: Option<i32> = response.take(2).unwrap();
        let Some(value) = value else {
            panic!("statement not found");
        };
        assert_eq!(value, 2);
        let value: Value = response.take(4).unwrap();
        assert_eq!(value.into_inner(), CoreValue::from(3));
    }
}