1use super::args::Args;
2use crate::cmd::options::{Durability, ReadMode};
3use crate::proto::{Command, Payload};
4use crate::{err, Connection, Result, Session};
5use async_stream::try_stream;
6use async_trait::async_trait;
7use futures::io::{AsyncReadExt, AsyncWriteExt};
8use futures::stream::{Stream, StreamExt};
9use ql2::query::QueryType;
10use ql2::response::{ErrorType, ResponseType};
11use serde::de::DeserializeOwned;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::borrow::Cow;
15use std::str;
16use std::sync::atomic::Ordering;
17use tracing::trace;
18use unreql_macros::OptionsBuilder;
19
/// Width of the little-endian `u64` query token that opens every frame.
const TOKEN_SIZE: usize = 8;
/// Width of the little-endian `u32` payload-length field that follows the token.
const DATA_SIZE: usize = 4;
/// Total frame header width: token first, then payload length.
const HEADER_SIZE: usize = TOKEN_SIZE + DATA_SIZE;
23
/// A decoded server response body.
#[derive(Deserialize, Debug)]
// NOTE(review): `b`, `p` and `n` are deserialized but not read by the request path in
// this module; the allow silences dead-code warnings for them.
#[allow(dead_code)]
pub(crate) struct Response {
    /// Response type, as the raw integer converted with `ResponseType::from_i32`.
    t: i32,
    /// Error type of an error response, as the raw integer converted with
    /// `ErrorType::from_i32`.
    e: Option<i32>,
    /// Result payload: the values to yield, or the error-message fragments.
    pub(crate) r: Value,
    /// NOTE(review): presumably the wire protocol's backtrace field — confirm.
    b: Option<Value>,
    /// NOTE(review): presumably profiling output requested via `Options::profile` — confirm.
    p: Option<Value>,
    /// NOTE(review): presumably the response notes field (e.g. changefeed notes) — confirm.
    n: Option<Value>,
}
34
35impl Response {
36 fn new() -> Self {
37 Self {
38 t: ResponseType::SuccessAtom as i32,
39 e: None,
40 r: Value::Array(Vec::new()),
41 b: None,
42 p: None,
43 n: None,
44 }
45 }
46}
47
/// Options sent alongside a query's `START` request.
///
/// Every field is optional and is left out of the serialized payload when `None`.
/// Field names are serialized unchanged (no `rename_all`), presumably matching the
/// RethinkDB global optargs of the same names — confirm against the server docs.
#[derive(
    Debug, Clone, OptionsBuilder, Serialize, Default, Eq, PartialEq, Ord, PartialOrd, Hash,
)]
#[non_exhaustive]
pub struct Options {
    /// Read mode for the query.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub read_mode: Option<ReadMode>,
    /// Format in which time values are returned.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub time_format: Option<Format>,
    /// Whether to request profiling information.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub profile: Option<bool>,
    /// Write durability for the query.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub durability: Option<Durability>,
    /// Format in which grouped data is returned.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub group_format: Option<Format>,
    /// When `Some(true)`, the driver writes the query without reading a reply and
    /// the resulting stream yields nothing.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub noreply: Option<bool>,
    /// Database to run against. When unset, it is filled from the session's current
    /// database unless that database is `DEFAULT_DB`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub db: Option<Db>,
}
68
/// Output format for time and grouped values; serialized as `"native"` or `"raw"`.
#[derive(Debug, Clone, Copy, Serialize, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[non_exhaustive]
#[serde(rename_all = "lowercase")]
pub enum Format {
    /// Serialized as `"native"`.
    Native,
    /// Serialized as `"raw"`.
    Raw,
}
76
/// Name of the database a query runs against (the `db` run option).
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Db(pub Cow<'static, str>);
79
/// Database name the driver treats as the default. When the session's database equals
/// this, no explicit `db` option is added to queries.
/// NOTE(review): presumably mirrors the server's own default database — confirm.
pub(crate) const DEFAULT_DB: &str = "test";
81
82impl Options {
83 async fn default_db(self, session: &Session) -> Options {
84 let session_db = session.inner.db.lock().await;
85 if self.db.is_none() && *session_db != DEFAULT_DB {
86 return Self {
87 db: Some(Db(session_db.clone())),
88 ..self
89 };
90 }
91 self
92 }
93}
94
/// A target a query can be run against. It resolves to the [`Connection`] to use and
/// the [`Options`] to send with the query.
#[async_trait]
pub trait Arg {
    /// Resolves `self` into a connection and run options.
    ///
    /// The query runner passes `Command::change_feed()` as `for_changes`. None of the
    /// implementations in this module act on it.
    async fn into_run_opts(self, for_changes: bool) -> Result<(Connection, Options)>;
}
99
100#[async_trait]
101impl Arg for &Session {
102 async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, Options)> {
103 let conn = self.connection()?;
104 Ok((conn, Default::default()))
105 }
106}
107
108#[async_trait]
109impl Arg for Connection {
110 async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, Options)> {
111 Ok((self, Default::default()))
112 }
113}
114
115#[async_trait]
116impl Arg for Args<(&Session, Options)> {
117 async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, Options)> {
118 let Args((session, options)) = self;
119 let conn = session.connection()?;
120 Ok((conn, options))
121 }
122}
123
124#[async_trait]
125impl Arg for Args<(Connection, Options)> {
126 async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, Options)> {
127 let Args(arg) = self;
128 Ok(arg)
129 }
130}
131
132#[async_trait]
133impl Arg for &mut Session {
134 async fn into_run_opts(self, for_changes: bool) -> Result<(Connection, Options)> {
135 self.connection()?.into_run_opts(for_changes).await
136 }
137}
138
139pub(crate) fn new<A, T>(query: Command, arg: A) -> impl Stream<Item = Result<T>>
140where
141 A: Arg,
142 T: Unpin + DeserializeOwned,
143{
144 try_stream! {
145 let (mut conn, mut opts) = arg.into_run_opts(query.change_feed()).await?;
146 opts = opts.default_db(&conn.session).await;
147 let change_feed = query.change_feed();
148 if change_feed {
149 conn.session.inner.mark_change_feed();
150 }
151 let noreply = opts.noreply.unwrap_or_default();
152 let mut payload = Payload(QueryType::Start, Some(&query), opts);
153 loop {
154 let (response_type, resp) = conn.request(&payload, noreply).await?;
155 trace!("yielding response; token: {}", conn.token);
156 match response_type {
157 ResponseType::SuccessAtom => {
158 let atom_val = if let Value::Array(arr) = resp.r {
161 if arr.is_empty() {
162 Value::Array(arr)
163 } else {
164 match &arr[0] {
165 Value::Array(inner_arr) => Value::Array(inner_arr.clone()),
166 _ => Value::Array(arr),
167 }
168 }
169 } else {
170 resp.r
171 };
172 for val in serde_json::from_value::<Vec<T>>(atom_val)? {
173 yield val;
174 }
175 break;
176 },
177 ResponseType::SuccessSequence | ResponseType::ServerInfo => {
178 for val in serde_json::from_value::<Vec<T>>(resp.r)? {
179 yield val;
180 }
181 break;
182 }
183 ResponseType::SuccessPartial => {
184 if conn.closed() {
185 conn.set_closed(false);
187 trace!("connection closed; token: {}", conn.token);
188 break;
189 }
190 payload = Payload(QueryType::Continue, None, Default::default());
191 for val in serde_json::from_value::<Vec<T>>(resp.r)? {
192 yield val;
193 }
194 continue;
195 }
196 ResponseType::WaitComplete => { break; }
197 typ => {
198 let msg = error_message(resp.r)?;
199 match typ {
200 ResponseType::ClientError if change_feed && msg.contains("not in stream cache") => { break; }
202 _ => Err(response_error(typ, resp.e, msg))?,
203 }
204 }
205 }
206 }
207 }
208}
209
210impl Payload<'_> {
211 fn encode(&self, token: u64) -> Result<Vec<u8>> {
212 let bytes = self.to_bytes()?;
213 let data_len = bytes.len();
214 let mut buf = Vec::with_capacity(HEADER_SIZE + data_len);
215 buf.extend_from_slice(&token.to_le_bytes());
216 buf.extend_from_slice(&(data_len as u32).to_le_bytes());
217 buf.extend_from_slice(&bytes);
218 Ok(buf)
219 }
220}
221
222impl Connection {
223 fn send_response(&self, db_token: u64, resp: Result<(ResponseType, Response)>) {
224 if let Some(tx) = self.session.inner.channels.get(&db_token) {
225 if let Err(error) = tx.unbounded_send(resp) {
226 if error.is_disconnected() {
227 self.session.inner.channels.remove(&db_token);
228 }
229 }
230 }
231 }
232
233 pub(crate) async fn request<'a>(
234 &mut self,
235 query: &'a Payload<'a>,
236 noreply: bool,
237 ) -> Result<(ResponseType, Response)> {
238 self.submit(query, noreply).await;
239 match self.rx.lock().await.next().await {
240 Some(resp) => resp,
241 None => Ok((ResponseType::SuccessAtom, Response::new())),
242 }
243 }
244
245 async fn submit<'a>(&self, query: &'a Payload<'a>, noreply: bool) {
246 let mut db_token = self.token;
247 let result = self.exec(query, noreply, &mut db_token).await;
248 self.send_response(db_token, result);
249 }
250
251 async fn exec<'a>(
252 &self,
253 query: &'a Payload<'a>,
254 noreply: bool,
255 db_token: &mut u64,
256 ) -> Result<(ResponseType, Response)> {
257 let buf = query.encode(self.token)?;
258
259 let guard = self.session.inner.stream.lock().await;
260 let mut stream = guard.clone();
261
262 trace!("sending query; token: {}, payload: {}", self.token, query);
263 stream.write_all(&buf).await?;
264 trace!("query sent; token: {}", self.token);
265
266 if noreply {
267 return Ok((ResponseType::SuccessAtom, Response::new()));
268 }
269
270 trace!("reading header; token: {}", self.token);
271 let mut header = [0u8; HEADER_SIZE];
272 stream.read_exact(&mut header).await?;
273
274 let mut buf = [0u8; TOKEN_SIZE];
275 buf.copy_from_slice(&header[..TOKEN_SIZE]);
276 *db_token = {
277 let token = u64::from_le_bytes(buf);
278 trace!("db_token: {}", token);
279 if token > self.session.inner.token.load(Ordering::SeqCst) {
280 self.session.inner.mark_broken();
281 return Err(err::Driver::ConnectionBroken.into());
282 }
283 token
284 };
285
286 let mut buf = [0u8; DATA_SIZE];
287 buf.copy_from_slice(&header[TOKEN_SIZE..]);
288 let len = u32::from_le_bytes(buf) as usize;
289 trace!(
290 "header read; token: {}, db_token: {}, response_len: {}",
291 self.token,
292 db_token,
293 len
294 );
295
296 trace!("reading body; token: {}", self.token);
297 let mut buf = vec![0u8; len];
298 stream.read_exact(&mut buf).await?;
299
300 trace!(
301 "body read; token: {}, db_token: {}, body: {}",
302 self.token,
303 db_token,
304 crate::tools::bytes_to_string(&buf),
305 );
306
307 let resp = serde_json::from_slice::<Response>(&buf)?;
308 trace!("response successfully parsed; token: {}", self.token,);
309
310 let response_type = ResponseType::from_i32(resp.t)
311 .ok_or_else(|| err::Driver::Other(format!("unknown response type `{}`", resp.t)))?;
312
313 if let Some(error_type) = resp.e {
314 let msg = error_message(resp.r)?;
315 return Err(response_error(response_type, Some(error_type), msg));
316 }
317
318 Ok((response_type, resp))
319 }
320}
321
322fn error_message(response: Value) -> Result<String> {
323 let messages = serde_json::from_value::<Vec<String>>(response)?;
324 Ok(messages.join(" "))
325}
326
327fn response_error(response_type: ResponseType, error_type: Option<i32>, msg: String) -> err::Error {
328 match response_type {
329 ResponseType::ClientError => err::Driver::Other(msg).into(),
330 ResponseType::CompileError => err::Error::Compile(msg),
331 ResponseType::RuntimeError => match error_type
332 .map(ErrorType::from_i32)
333 .ok_or_else(|| err::Driver::Other(format!("unexpected runtime error: {}", msg)))
334 {
335 Ok(Some(ErrorType::Internal)) => err::Runtime::Internal(msg).into(),
336 Ok(Some(ErrorType::ResourceLimit)) => err::Runtime::ResourceLimit(msg).into(),
337 Ok(Some(ErrorType::QueryLogic)) => err::Runtime::QueryLogic(msg).into(),
338 Ok(Some(ErrorType::NonExistence)) => err::Runtime::NonExistence(msg).into(),
339 Ok(Some(ErrorType::OpFailed)) => err::Availability::OpFailed(msg).into(),
340 Ok(Some(ErrorType::OpIndeterminate)) => err::Availability::OpIndeterminate(msg).into(),
341 Ok(Some(ErrorType::User)) => err::Runtime::User(msg).into(),
342 Ok(Some(ErrorType::PermissionError)) => err::Runtime::Permission(msg).into(),
343 Err(error) => error.into(),
344 _ => err::Driver::Other(format!("unexpected runtime error: {}", msg)).into(),
345 },
346 _ => err::Driver::Other(format!("unexpected response: {}", msg)).into(),
347 }
348}