unreql/cmd/run.rs

1use super::args::Args;
2use crate::cmd::options::{Durability, ReadMode};
3use crate::proto::{Command, Payload};
4use crate::{err, Connection, Result, Session};
5use async_stream::try_stream;
6use async_trait::async_trait;
7use futures::io::{AsyncReadExt, AsyncWriteExt};
8use futures::stream::{Stream, StreamExt};
9use ql2::query::QueryType;
10use ql2::response::{ErrorType, ResponseType};
11use serde::de::DeserializeOwned;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::borrow::Cow;
15use std::str;
16use std::sync::atomic::Ordering;
17use tracing::trace;
18use unreql_macros::OptionsBuilder;
19
// Frame header layout, in wire order: query token first, then body length.
/// Width in bytes of the little-endian query token at the start of a frame.
const TOKEN_SIZE: usize = 8;
/// Width in bytes of the little-endian body-length field that follows the token.
const DATA_SIZE: usize = 4;
/// Total header width: token followed by body length.
const HEADER_SIZE: usize = TOKEN_SIZE + DATA_SIZE;
23
24#[derive(Deserialize, Debug)]
25#[allow(dead_code)]
26pub(crate) struct Response {
27    t: i32,
28    e: Option<i32>,
29    pub(crate) r: Value,
30    b: Option<Value>,
31    p: Option<Value>,
32    n: Option<Value>,
33}
34
35impl Response {
36    fn new() -> Self {
37        Self {
38            t: ResponseType::SuccessAtom as i32,
39            e: None,
40            r: Value::Array(Vec::new()),
41            b: None,
42            p: None,
43            n: None,
44        }
45    }
46}
47
48#[derive(
49    Debug, Clone, OptionsBuilder, Serialize, Default, Eq, PartialEq, Ord, PartialOrd, Hash,
50)]
51#[non_exhaustive]
52pub struct Options {
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub read_mode: Option<ReadMode>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub time_format: Option<Format>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub profile: Option<bool>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub durability: Option<Durability>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub group_format: Option<Format>,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub noreply: Option<bool>,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub db: Option<Db>,
67}
68
69#[derive(Debug, Clone, Copy, Serialize, Eq, PartialEq, Ord, PartialOrd, Hash)]
70#[non_exhaustive]
71#[serde(rename_all = "lowercase")]
72pub enum Format {
73    Native,
74    Raw,
75}
76
/// Name of a database to run a query against.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Db(pub Cow<'static, str>);

/// Name of the database that needs no explicit `db` run option.
pub(crate) const DEFAULT_DB: &str = "test";
81
82impl Options {
83    async fn default_db(self, session: &Session) -> Options {
84        let session_db = session.inner.db.lock().await;
85        if self.db.is_none() && *session_db != DEFAULT_DB {
86            return Self {
87                db: Some(Db(session_db.clone())),
88                ..self
89            };
90        }
91        self
92    }
93}
94
95#[async_trait]
96pub trait Arg {
97    async fn into_run_opts(self, for_changes: bool) -> Result<(Connection, Options)>;
98}
99
100#[async_trait]
101impl Arg for &Session {
102    async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, Options)> {
103        let conn = self.connection()?;
104        Ok((conn, Default::default()))
105    }
106}
107
108#[async_trait]
109impl Arg for Connection {
110    async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, Options)> {
111        Ok((self, Default::default()))
112    }
113}
114
115#[async_trait]
116impl Arg for Args<(&Session, Options)> {
117    async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, Options)> {
118        let Args((session, options)) = self;
119        let conn = session.connection()?;
120        Ok((conn, options))
121    }
122}
123
124#[async_trait]
125impl Arg for Args<(Connection, Options)> {
126    async fn into_run_opts(self, _for_changes: bool) -> Result<(Connection, Options)> {
127        let Args(arg) = self;
128        Ok(arg)
129    }
130}
131
132#[async_trait]
133impl Arg for &mut Session {
134    async fn into_run_opts(self, for_changes: bool) -> Result<(Connection, Options)> {
135        self.connection()?.into_run_opts(for_changes).await
136    }
137}
138
139pub(crate) fn new<A, T>(query: Command, arg: A) -> impl Stream<Item = Result<T>>
140where
141    A: Arg,
142    T: Unpin + DeserializeOwned,
143{
144    try_stream! {
145        let (mut conn, mut opts) = arg.into_run_opts(query.change_feed()).await?;
146        opts = opts.default_db(&conn.session).await;
147        let change_feed = query.change_feed();
148        if change_feed {
149            conn.session.inner.mark_change_feed();
150        }
151        let noreply = opts.noreply.unwrap_or_default();
152        let mut payload = Payload(QueryType::Start, Some(&query), opts);
153        loop {
154            let (response_type, resp) = conn.request(&payload, noreply).await?;
155            trace!("yielding response; token: {}", conn.token);
156            match response_type {
157                ResponseType::SuccessAtom => {
158                    // If response is array then will try to flat it
159                    // [[1, 2, 3]] => [1, 2, 3]
160                    let atom_val = if let Value::Array(arr) = resp.r {
161                        if arr.is_empty() {
162                            Value::Array(arr)
163                        } else {
164                            match &arr[0] {
165                                Value::Array(inner_arr) => Value::Array(inner_arr.clone()),
166                                _ => Value::Array(arr),
167                            }
168                        }
169                    } else {
170                        resp.r
171                    };
172                    for val in serde_json::from_value::<Vec<T>>(atom_val)? {
173                        yield val;
174                    }
175                    break;
176                },
177                ResponseType::SuccessSequence | ResponseType::ServerInfo => {
178                    for val in serde_json::from_value::<Vec<T>>(resp.r)? {
179                        yield val;
180                    }
181                    break;
182                }
183                ResponseType::SuccessPartial => {
184                    if conn.closed() {
185                        // reopen so we can use the connection in future
186                        conn.set_closed(false);
187                        trace!("connection closed; token: {}", conn.token);
188                        break;
189                    }
190                    payload = Payload(QueryType::Continue, None, Default::default());
191                    for val in serde_json::from_value::<Vec<T>>(resp.r)? {
192                        yield val;
193                    }
194                    continue;
195                }
196                ResponseType::WaitComplete => { break; }
197                typ => {
198                    let msg = error_message(resp.r)?;
199                    match typ {
200                        // This feed has been closed by conn.close().
201                        ResponseType::ClientError if change_feed && msg.contains("not in stream cache") => { break; }
202                        _ => Err(response_error(typ, resp.e, msg))?,
203                    }
204                }
205            }
206        }
207    }
208}
209
210impl Payload<'_> {
211    fn encode(&self, token: u64) -> Result<Vec<u8>> {
212        let bytes = self.to_bytes()?;
213        let data_len = bytes.len();
214        let mut buf = Vec::with_capacity(HEADER_SIZE + data_len);
215        buf.extend_from_slice(&token.to_le_bytes());
216        buf.extend_from_slice(&(data_len as u32).to_le_bytes());
217        buf.extend_from_slice(&bytes);
218        Ok(buf)
219    }
220}
221
222impl Connection {
223    fn send_response(&self, db_token: u64, resp: Result<(ResponseType, Response)>) {
224        if let Some(tx) = self.session.inner.channels.get(&db_token) {
225            if let Err(error) = tx.unbounded_send(resp) {
226                if error.is_disconnected() {
227                    self.session.inner.channels.remove(&db_token);
228                }
229            }
230        }
231    }
232
233    pub(crate) async fn request<'a>(
234        &mut self,
235        query: &'a Payload<'a>,
236        noreply: bool,
237    ) -> Result<(ResponseType, Response)> {
238        self.submit(query, noreply).await;
239        match self.rx.lock().await.next().await {
240            Some(resp) => resp,
241            None => Ok((ResponseType::SuccessAtom, Response::new())),
242        }
243    }
244
245    async fn submit<'a>(&self, query: &'a Payload<'a>, noreply: bool) {
246        let mut db_token = self.token;
247        let result = self.exec(query, noreply, &mut db_token).await;
248        self.send_response(db_token, result);
249    }
250
251    async fn exec<'a>(
252        &self,
253        query: &'a Payload<'a>,
254        noreply: bool,
255        db_token: &mut u64,
256    ) -> Result<(ResponseType, Response)> {
257        let buf = query.encode(self.token)?;
258
259        let guard = self.session.inner.stream.lock().await;
260        let mut stream = guard.clone();
261
262        trace!("sending query; token: {}, payload: {}", self.token, query);
263        stream.write_all(&buf).await?;
264        trace!("query sent; token: {}", self.token);
265
266        if noreply {
267            return Ok((ResponseType::SuccessAtom, Response::new()));
268        }
269
270        trace!("reading header; token: {}", self.token);
271        let mut header = [0u8; HEADER_SIZE];
272        stream.read_exact(&mut header).await?;
273
274        let mut buf = [0u8; TOKEN_SIZE];
275        buf.copy_from_slice(&header[..TOKEN_SIZE]);
276        *db_token = {
277            let token = u64::from_le_bytes(buf);
278            trace!("db_token: {}", token);
279            if token > self.session.inner.token.load(Ordering::SeqCst) {
280                self.session.inner.mark_broken();
281                return Err(err::Driver::ConnectionBroken.into());
282            }
283            token
284        };
285
286        let mut buf = [0u8; DATA_SIZE];
287        buf.copy_from_slice(&header[TOKEN_SIZE..]);
288        let len = u32::from_le_bytes(buf) as usize;
289        trace!(
290            "header read; token: {}, db_token: {}, response_len: {}",
291            self.token,
292            db_token,
293            len
294        );
295
296        trace!("reading body; token: {}", self.token);
297        let mut buf = vec![0u8; len];
298        stream.read_exact(&mut buf).await?;
299
300        trace!(
301            "body read; token: {}, db_token: {}, body: {}",
302            self.token,
303            db_token,
304            crate::tools::bytes_to_string(&buf),
305        );
306
307        let resp = serde_json::from_slice::<Response>(&buf)?;
308        trace!("response successfully parsed; token: {}", self.token,);
309
310        let response_type = ResponseType::from_i32(resp.t)
311            .ok_or_else(|| err::Driver::Other(format!("unknown response type `{}`", resp.t)))?;
312
313        if let Some(error_type) = resp.e {
314            let msg = error_message(resp.r)?;
315            return Err(response_error(response_type, Some(error_type), msg));
316        }
317
318        Ok((response_type, resp))
319    }
320}
321
322fn error_message(response: Value) -> Result<String> {
323    let messages = serde_json::from_value::<Vec<String>>(response)?;
324    Ok(messages.join(" "))
325}
326
327fn response_error(response_type: ResponseType, error_type: Option<i32>, msg: String) -> err::Error {
328    match response_type {
329        ResponseType::ClientError => err::Driver::Other(msg).into(),
330        ResponseType::CompileError => err::Error::Compile(msg),
331        ResponseType::RuntimeError => match error_type
332            .map(ErrorType::from_i32)
333            .ok_or_else(|| err::Driver::Other(format!("unexpected runtime error: {}", msg)))
334        {
335            Ok(Some(ErrorType::Internal)) => err::Runtime::Internal(msg).into(),
336            Ok(Some(ErrorType::ResourceLimit)) => err::Runtime::ResourceLimit(msg).into(),
337            Ok(Some(ErrorType::QueryLogic)) => err::Runtime::QueryLogic(msg).into(),
338            Ok(Some(ErrorType::NonExistence)) => err::Runtime::NonExistence(msg).into(),
339            Ok(Some(ErrorType::OpFailed)) => err::Availability::OpFailed(msg).into(),
340            Ok(Some(ErrorType::OpIndeterminate)) => err::Availability::OpIndeterminate(msg).into(),
341            Ok(Some(ErrorType::User)) => err::Runtime::User(msg).into(),
342            Ok(Some(ErrorType::PermissionError)) => err::Runtime::Permission(msg).into(),
343            Err(error) => error.into(),
344            _ => err::Driver::Other(format!("unexpected runtime error: {}", msg)).into(),
345        },
346        _ => err::Driver::Other(format!("unexpected response: {}", msg)).into(),
347    }
348}