sqlx_scylladb_core/connection/
executor.rs1use std::{borrow::Cow, ops::ControlFlow, pin::pin};
2
3use bytes::Bytes;
4use futures_core::{Stream, future::BoxFuture, stream::BoxStream};
5use futures_util::TryStreamExt;
6use scylla::{
7 deserialize::row::ColumnIterator,
8 response::{
9 PagingState, PagingStateResponse,
10 query_result::{ColumnSpecs, QueryResult},
11 },
12 statement::Statement,
13};
14use sqlx::{Connection, Describe, Either, Error, Executor, Row};
15use sqlx_core::{ext::ustr::UStr, try_stream};
16
17use crate::{
18 ScyllaDB, ScyllaDBArguments, ScyllaDBColumn, ScyllaDBConnection, ScyllaDBError,
19 ScyllaDBQueryResult, ScyllaDBRow, ScyllaDBStatement, ScyllaDBTypeInfo,
20 statement::ScyllaDBStatementMetadata,
21};
22
/// Name of the boolean result column used to detect lightweight-transaction (LWT)
/// results: `ColumnSpecsExt::is_lwt` checks whether the first result column has this
/// name, and `ScyllaDBConnection::run` reads it to count applied rows.
const APPLIED_COLUMN: &str = "[applied]";

25impl ScyllaDBConnection {
26 async fn execute_single_page<'e, 'c: 'e, 'q: 'e>(
27 &'c mut self,
28 sql: &str,
29 arguments: &Option<ScyllaDBArguments>,
30 persistent: bool,
31 paging_state: PagingState,
32 ) -> Result<(QueryResult, PagingStateResponse), ScyllaDBError> {
33 if persistent {
34 let (query_result, paging_state_response) = if let Some(arguments) = arguments {
35 self.caching_session
36 .execute_single_page(sql, arguments, paging_state)
37 .await?
38 } else {
39 self.caching_session
40 .execute_single_page(sql, (), paging_state)
41 .await?
42 };
43
44 Ok((query_result, paging_state_response))
45 } else {
46 let session = self.caching_session.get_session();
47
48 let (query_result, paging_state_response) = if let Some(arguments) = arguments {
49 session
50 .query_single_page(sql, arguments, paging_state)
51 .await?
52 } else {
53 session.query_single_page(sql, (), paging_state).await?
54 };
55
56 Ok((query_result, paging_state_response))
57 }
58 }
59
60 pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
61 &'c mut self,
62 sql: &'q str,
63 arguments: Option<ScyllaDBArguments>,
64 persistent: bool,
65 ) -> Result<
66 impl Stream<Item = Result<Either<ScyllaDBQueryResult, ScyllaDBRow>, Error>> + 'e,
67 Error,
68 > {
69 Ok(try_stream! {
70 let statement = self.prepare(sql).await?;
71
72 let in_batch = self.is_in_transaction() && statement.is_affect_statement;
74
75 if !in_batch {
76 let mut paging_state = PagingState::start();
77
78 loop {
79 let (query_result, paging_state_response) = self.execute_single_page(&statement.sql, &arguments, persistent, paging_state.clone()).await?;
80
81 if !query_result.is_rows() {
82 break;
83 }
84
85 let rows_result = query_result.into_rows_result().map_err(ScyllaDBError::IntoRowsResultError)?;
86 let column_specs = rows_result.column_specs();
87 let metadata = ScyllaDBStatementMetadata::from_column_specs(column_specs)?;
88
89 let is_lwt = column_specs.is_lwt();
90 let rows_num = rows_result.rows_num() as u64;
91 let mut rows_affected = 0;
92
93 let rows = rows_result.rows::<ColumnIterator<'_,'_>>().map_err(ScyllaDBError::RowsError)?;
94 for row in rows {
95 let row = row.map_err(ScyllaDBError::DeserializationError)?;
96
97 let mut columns: Vec<Option<Bytes>> = Vec::with_capacity(row.columns_remaining());
98 for column in row {
99 let column = column.map_err(ScyllaDBError::DeserializationError)?;
100 let column = match column.slice {
101 Some(slice) => {
102 Some(slice.to_bytes())
103 },
104 None => None,
105 };
106 columns.push(column)
107 }
108
109 let scylladb_row = ScyllaDBRow::new(columns, metadata.clone());
110
111 if is_lwt {
112 let applied: bool = scylladb_row.try_get(APPLIED_COLUMN).unwrap_or(false);
113 if applied {
114 rows_affected += 1;
115 }
116 }
117
118 r#yield!(Either::Right(scylladb_row))
119 }
120
121 r#yield!(Either::Left(ScyllaDBQueryResult { rows_num, rows_affected }));
122
123 match paging_state_response.into_paging_control_flow() {
124 ControlFlow::Break(()) => {
125 break;
126 }
127 ControlFlow::Continue(new_paging_state) => {
128 paging_state = new_paging_state
129 }
130 }
131 }
132 } else {
133 self.append_to_transaction(sql, arguments).await?;
134 }
135
136 Ok(())
137 })
138 }
139}
140
141impl<'c> Executor<'c> for &'c mut ScyllaDBConnection {
142 type Database = ScyllaDB;
143
144 fn fetch_many<'e, 'q: 'e, E>(
145 self,
146 query: E,
147 ) -> BoxStream<'e, Result<Either<ScyllaDBQueryResult, ScyllaDBRow>, sqlx::Error>>
148 where
149 'c: 'e,
150 E: 'q + sqlx::Execute<'q, ScyllaDB>,
151 {
152 let sql = query.sql();
153 let mut query = query;
154 let arguments = query.take_arguments().map_err(Error::Encode);
155 let persistent = query.persistent();
156
157 Box::pin(try_stream! {
158 let arguments = arguments?;
159 let mut s = pin!(self.run(sql, arguments, persistent).await?);
160
161 while let Some(v) = s.try_next().await? {
162 r#yield!(v);
163 }
164
165 Ok(())
166 })
167 }
168
169 fn fetch_optional<'e, 'q: 'e, E>(
170 self,
171 query: E,
172 ) -> BoxFuture<'e, Result<Option<ScyllaDBRow>, sqlx::Error>>
173 where
174 'c: 'e,
175 E: 'q + sqlx::Execute<'q, Self::Database>,
176 {
177 let mut s = self.fetch_many(query);
178
179 Box::pin(async move {
180 while let Some(v) = s.try_next().await? {
181 if let Either::Right(r) = v {
182 return Ok(Some(r));
183 }
184 }
185
186 Ok(None)
187 })
188 }
189
190 fn prepare_with<'e, 'q: 'e>(
191 self,
192 sql: &'q str,
193 _parameters: &'e [ScyllaDBTypeInfo],
194 ) -> BoxFuture<'e, Result<ScyllaDBStatement<'q>, sqlx::Error>>
195 where
196 'c: 'e,
197 {
198 Box::pin(async move {
199 let statement = Statement::new(sql).with_page_size(self.page_size);
200 let prepared_statement = self
201 .caching_session
202 .add_prepared_statement(&statement)
203 .await
204 .map_err(ScyllaDBError::PrepareError)?;
205
206 let column_specs = prepared_statement.get_result_set_col_specs();
207 let metadata = ScyllaDBStatementMetadata::from_column_specs(column_specs)?;
208
209 let is_affect_statement = column_specs.is_affect_statement();
210
211 Ok(ScyllaDBStatement {
212 sql: Cow::Borrowed(sql),
213 prepared_statement,
214 metadata,
215 is_affect_statement,
216 })
217 })
218 }
219
220 fn describe<'e, 'q: 'e>(
221 self,
222 sql: &'q str,
223 ) -> BoxFuture<'e, Result<Describe<Self::Database>, sqlx::Error>>
224 where
225 'c: 'e,
226 {
227 Box::pin(async move {
228 let statement = Statement::new(sql);
229 let prepared_statement = self
230 .caching_session
231 .add_prepared_statement(&statement)
232 .await
233 .map_err(ScyllaDBError::PrepareError)?;
234 let column_specs = prepared_statement.get_result_set_col_specs();
235
236 let capacity = column_specs.len();
237 let mut columns = Vec::with_capacity(capacity);
238 let mut parameters = Vec::with_capacity(capacity);
239 let mut nullable = Vec::with_capacity(capacity);
240 for (i, column_spec) in column_specs.iter().enumerate() {
241 let name = UStr::new(column_spec.name());
242 let column_type = column_spec.typ();
243 let type_info = ScyllaDBTypeInfo::from_column_type(column_type)?;
244
245 columns.push(ScyllaDBColumn {
246 ordinal: i,
247 name,
248 type_info: type_info.clone(),
249 column_type: column_type.clone().into_owned(),
250 });
251 parameters.push(type_info);
252 nullable.push(Some(true));
253 }
254
255 let describe = Describe::<ScyllaDB> {
256 columns,
257 parameters: Some(Either::Left(parameters)),
258 nullable,
259 };
260
261 Ok(describe)
262 })
263 }
264}
265
/// Classification helpers for the column specifications of a statement's result set.
trait ColumnSpecsExt {
    /// Returns `true` when the column specs look like the result of a lightweight
    /// transaction (LWT).
    fn is_lwt(&self) -> bool;

    /// Returns `true` when the statement should be treated as one that modifies data
    /// (used to decide whether it is queued on an open transaction).
    fn is_affect_statement(&self) -> bool;
}

271impl ColumnSpecsExt for ColumnSpecs<'_, '_> {
272 #[inline]
273 fn is_affect_statement(&self) -> bool {
274 if self.len() == 0 {
276 return true;
277 }
278
279 if self.is_lwt() {
280 return true;
281 }
282
283 return false;
284 }
285
286 #[inline]
287 fn is_lwt(&self) -> bool {
288 if let Some(column_spec) = self.get_by_index(0) {
290 if column_spec.name() == APPLIED_COLUMN {
291 return true;
292 }
293 }
294
295 false
296 }
297}