// sqlx_scylladb_core/connection/executor.rs

1use std::{borrow::Cow, ops::ControlFlow, pin::pin};
2
3use bytes::Bytes;
4use futures_core::{Stream, future::BoxFuture, stream::BoxStream};
5use futures_util::TryStreamExt;
6use scylla::{
7    deserialize::row::ColumnIterator,
8    response::{
9        PagingState, PagingStateResponse,
10        query_result::{ColumnSpecs, QueryResult},
11    },
12    statement::Statement,
13};
14use sqlx::{Connection, Describe, Either, Error, Executor, Row};
15use sqlx_core::{ext::ustr::UStr, try_stream};
16
17use crate::{
18    ScyllaDB, ScyllaDBArguments, ScyllaDBColumn, ScyllaDBConnection, ScyllaDBError,
19    ScyllaDBQueryResult, ScyllaDBRow, ScyllaDBStatement, ScyllaDBTypeInfo,
20    statement::ScyllaDBStatementMetadata,
21};
22
/// Name of the boolean result column that lightweight (conditional) transactions
/// return; a row whose `[applied]` value is `true` is counted as affected.
const APPLIED_COLUMN: &str = "[applied]";
24
25impl ScyllaDBConnection {
26    async fn execute_single_page<'e, 'c: 'e, 'q: 'e>(
27        &'c mut self,
28        sql: &str,
29        arguments: &Option<ScyllaDBArguments>,
30        persistent: bool,
31        paging_state: PagingState,
32    ) -> Result<(QueryResult, PagingStateResponse), ScyllaDBError> {
33        if persistent {
34            let (query_result, paging_state_response) = if let Some(arguments) = arguments {
35                self.caching_session
36                    .execute_single_page(sql, arguments, paging_state)
37                    .await?
38            } else {
39                self.caching_session
40                    .execute_single_page(sql, (), paging_state)
41                    .await?
42            };
43
44            Ok((query_result, paging_state_response))
45        } else {
46            let session = self.caching_session.get_session();
47
48            let (query_result, paging_state_response) = if let Some(arguments) = arguments {
49                session
50                    .query_single_page(sql, arguments, paging_state)
51                    .await?
52            } else {
53                session.query_single_page(sql, (), paging_state).await?
54            };
55
56            Ok((query_result, paging_state_response))
57        }
58    }
59
60    pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
61        &'c mut self,
62        sql: &'q str,
63        arguments: Option<ScyllaDBArguments>,
64        persistent: bool,
65    ) -> Result<
66        impl Stream<Item = Result<Either<ScyllaDBQueryResult, ScyllaDBRow>, Error>> + 'e,
67        Error,
68    > {
69        Ok(try_stream! {
70            let statement = self.prepare(sql).await?;
71
72            // INSERT, UPDATE, and DELETE queries during transactions are processed in batches.
73            let in_batch = self.is_in_transaction() && statement.is_affect_statement;
74
75            if !in_batch {
76                let mut paging_state = PagingState::start();
77
78                loop {
79                    let (query_result, paging_state_response) = self.execute_single_page(&statement.sql, &arguments, persistent, paging_state.clone()).await?;
80
81                    if !query_result.is_rows() {
82                        break;
83                    }
84
85                    let rows_result = query_result.into_rows_result().map_err(ScyllaDBError::IntoRowsResultError)?;
86                    let column_specs = rows_result.column_specs();
87                    let metadata = ScyllaDBStatementMetadata::from_column_specs(column_specs)?;
88
89                    let is_lwt = column_specs.is_lwt();
90                    let rows_num = rows_result.rows_num() as u64;
91                    let mut rows_affected = 0;
92
93                    let rows = rows_result.rows::<ColumnIterator<'_,'_>>().map_err(ScyllaDBError::RowsError)?;
94                    for row in rows {
95                        let row = row.map_err(ScyllaDBError::DeserializationError)?;
96
97                        let mut columns: Vec<Option<Bytes>> = Vec::with_capacity(row.columns_remaining());
98                        for column in row {
99                            let column = column.map_err(ScyllaDBError::DeserializationError)?;
100                            let column = match column.slice {
101                                Some(slice) => {
102                                    Some(slice.to_bytes())
103                                },
104                                None => None,
105                            };
106                            columns.push(column)
107                        }
108
109                        let scylladb_row = ScyllaDBRow::new(columns, metadata.clone());
110
111                        if is_lwt {
112                            let applied: bool = scylladb_row.try_get(APPLIED_COLUMN).unwrap_or(false);
113                            if applied {
114                                rows_affected += 1;
115                            }
116                        }
117
118                        r#yield!(Either::Right(scylladb_row))
119                    }
120
121                    r#yield!(Either::Left(ScyllaDBQueryResult { rows_num, rows_affected }));
122
123                    match paging_state_response.into_paging_control_flow() {
124                        ControlFlow::Break(()) => {
125                            break;
126                        }
127                        ControlFlow::Continue(new_paging_state) => {
128                            paging_state = new_paging_state
129                        }
130                    }
131                }
132            } else {
133                self.append_to_transaction(sql, arguments).await?;
134            }
135
136            Ok(())
137        })
138    }
139}
140
141impl<'c> Executor<'c> for &'c mut ScyllaDBConnection {
142    type Database = ScyllaDB;
143
144    fn fetch_many<'e, 'q: 'e, E>(
145        self,
146        query: E,
147    ) -> BoxStream<'e, Result<Either<ScyllaDBQueryResult, ScyllaDBRow>, sqlx::Error>>
148    where
149        'c: 'e,
150        E: 'q + sqlx::Execute<'q, ScyllaDB>,
151    {
152        let sql = query.sql();
153        let mut query = query;
154        let arguments = query.take_arguments().map_err(Error::Encode);
155        let persistent = query.persistent();
156
157        Box::pin(try_stream! {
158            let arguments = arguments?;
159            let mut s = pin!(self.run(sql, arguments, persistent).await?);
160
161            while let Some(v) = s.try_next().await? {
162                r#yield!(v);
163            }
164
165            Ok(())
166        })
167    }
168
169    fn fetch_optional<'e, 'q: 'e, E>(
170        self,
171        query: E,
172    ) -> BoxFuture<'e, Result<Option<ScyllaDBRow>, sqlx::Error>>
173    where
174        'c: 'e,
175        E: 'q + sqlx::Execute<'q, Self::Database>,
176    {
177        let mut s = self.fetch_many(query);
178
179        Box::pin(async move {
180            while let Some(v) = s.try_next().await? {
181                if let Either::Right(r) = v {
182                    return Ok(Some(r));
183                }
184            }
185
186            Ok(None)
187        })
188    }
189
190    fn prepare_with<'e, 'q: 'e>(
191        self,
192        sql: &'q str,
193        _parameters: &'e [ScyllaDBTypeInfo],
194    ) -> BoxFuture<'e, Result<ScyllaDBStatement<'q>, sqlx::Error>>
195    where
196        'c: 'e,
197    {
198        Box::pin(async move {
199            let statement = Statement::new(sql).with_page_size(self.page_size);
200            let prepared_statement = self
201                .caching_session
202                .add_prepared_statement(&statement)
203                .await
204                .map_err(ScyllaDBError::PrepareError)?;
205
206            let column_specs = prepared_statement.get_result_set_col_specs();
207            let metadata = ScyllaDBStatementMetadata::from_column_specs(column_specs)?;
208
209            let is_affect_statement = column_specs.is_affect_statement();
210
211            Ok(ScyllaDBStatement {
212                sql: Cow::Borrowed(sql),
213                prepared_statement,
214                metadata,
215                is_affect_statement,
216            })
217        })
218    }
219
220    fn describe<'e, 'q: 'e>(
221        self,
222        sql: &'q str,
223    ) -> BoxFuture<'e, Result<Describe<Self::Database>, sqlx::Error>>
224    where
225        'c: 'e,
226    {
227        Box::pin(async move {
228            let statement = Statement::new(sql);
229            let prepared_statement = self
230                .caching_session
231                .add_prepared_statement(&statement)
232                .await
233                .map_err(ScyllaDBError::PrepareError)?;
234            let column_specs = prepared_statement.get_result_set_col_specs();
235
236            let capacity = column_specs.len();
237            let mut columns = Vec::with_capacity(capacity);
238            let mut parameters = Vec::with_capacity(capacity);
239            let mut nullable = Vec::with_capacity(capacity);
240            for (i, column_spec) in column_specs.iter().enumerate() {
241                let name = UStr::new(column_spec.name());
242                let column_type = column_spec.typ();
243                let type_info = ScyllaDBTypeInfo::from_column_type(column_type)?;
244
245                columns.push(ScyllaDBColumn {
246                    ordinal: i,
247                    name,
248                    type_info: type_info.clone(),
249                    column_type: column_type.clone().into_owned(),
250                });
251                parameters.push(type_info);
252                nullable.push(Some(true));
253            }
254
255            let describe = Describe::<ScyllaDB> {
256                columns,
257                parameters: Some(Either::Left(parameters)),
258                nullable,
259            };
260
261            Ok(describe)
262        })
263    }
264}
265
/// Classification helpers over the result-set column specs of a statement.
trait ColumnSpecsExt {
    /// `true` when the result set is the `[applied]` report of a lightweight
    /// (conditional) transaction.
    fn is_lwt(&self) -> bool;

    /// `true` when the statement writes data rather than reading it.
    fn is_affect_statement(&self) -> bool;
}
270
271impl ColumnSpecsExt for ColumnSpecs<'_, '_> {
272    #[inline]
273    fn is_affect_statement(&self) -> bool {
274        // Returns 0 for queries other than SELECT queries.
275        if self.len() == 0 {
276            return true;
277        }
278
279        if self.is_lwt() {
280            return true;
281        }
282
283        return false;
284    }
285
286    #[inline]
287    fn is_lwt(&self) -> bool {
288        // In the case of a lightweight transaction, [applied] is returned.
289        if let Some(column_spec) = self.get_by_index(0) {
290            if column_spec.name() == APPLIED_COLUMN {
291                return true;
292            }
293        }
294
295        false
296    }
297}