1use std::sync::Arc;
12
13use futures_core::future::BoxFuture;
14use futures_core::stream::BoxStream;
15use sqlx_core::HashMap;
16use sqlx_core::connection::Connection;
17use sqlx_core::error::Error;
18use sqlx_core::executor::Executor;
19use sqlx_core::transaction::Transaction;
20
21use spg_embedded::QueryResult as EngineQueryResult;
22use spg_embedded_tokio::AsyncDatabase;
23
24use crate::column::SpgColumn;
25use crate::database::Spg;
26use crate::error::engine_to_sqlx;
27use crate::options::SpgConnectOptions;
28use crate::query_result::SpgQueryResult;
29use crate::row::SpgRow;
30use crate::type_info::SpgTypeInfo;
31
/// A sqlx connection backed by the embedded `spg` engine.
///
/// All query work is routed through `inner`; `tx_depth` and `pending_rollback` hold
/// transaction bookkeeping and start out as `0` / `false` (see [`SpgConnection::new`]).
///
/// NOTE(review): the derived `Clone` copies `tx_depth` and `pending_rollback` independently for
/// each clone, while whether clones share one underlying engine depends on `AsyncDatabase::clone`
/// — check this before relying on cloned connections while a transaction is open.
#[derive(Debug, Clone)]
pub struct SpgConnection {
    /// Engine handle; the executor clones it into every query future/stream it returns.
    pub(crate) inner: AsyncDatabase,
    /// NOTE(review): presumably the current transaction nesting depth (0 = no transaction) —
    /// confirm against the transaction manager that maintains it.
    pub(crate) tx_depth: usize,
    /// NOTE(review): presumably set when a transaction still needs to be rolled back (e.g. it
    /// was dropped without commit/rollback) — confirm against the transaction manager.
    pub(crate) pending_rollback: bool,
}
39
40impl SpgConnection {
41 pub fn new(inner: AsyncDatabase) -> Self {
45 Self {
46 inner,
47 tx_depth: 0,
48 pending_rollback: false,
49 }
50 }
51
52 #[must_use]
56 pub const fn engine(&self) -> &AsyncDatabase {
57 &self.inner
58 }
59}
60
61impl Connection for SpgConnection {
62 type Database = Spg;
63 type Options = SpgConnectOptions;
64
65 fn close(self) -> BoxFuture<'static, Result<(), Error>> {
66 Box::pin(async move { Ok(()) })
69 }
70
71 fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
72 Box::pin(async move { Ok(()) })
73 }
74
75 fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
76 Box::pin(async move {
79 self.inner
80 .execute("SELECT 1")
81 .await
82 .map_err(engine_to_sqlx)?;
83 Ok(())
84 })
85 }
86
87 fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
88 where
89 Self: Sized,
90 {
91 Transaction::begin(self, None)
92 }
93
94 fn shrink_buffers(&mut self) {
95 }
97
98 fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
99 Box::pin(async move { Ok(()) })
100 }
101
102 fn should_flush(&self) -> bool {
103 false
104 }
105}
106
107impl<'c> Executor<'c> for &'c mut SpgConnection {
113 type Database = Spg;
114
115 fn fetch_many<'e, 'q: 'e, E>(
116 self,
117 mut query: E,
118 ) -> BoxStream<
119 'e,
120 Result<
121 either::Either<
122 <Self::Database as sqlx_core::database::Database>::QueryResult,
123 crate::SpgRow,
124 >,
125 Error,
126 >,
127 >
128 where
129 'c: 'e,
130 E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
131 {
132 use futures_util::stream::{self, StreamExt};
133 let sql = query.sql().to_string();
134 let arguments = match query.take_arguments() {
135 Ok(args) => args,
136 Err(e) => {
137 return Box::pin(stream::iter(std::iter::once(Err(Error::Encode(e)))));
138 }
139 };
140 let inner = self.inner.clone();
141 let outcome_fut = async move { run_one(&inner, &sql, arguments).await };
142 Box::pin(stream::once(outcome_fut).flat_map(|outcome| {
143 let items: Vec<Result<either::Either<SpgQueryResult, SpgRow>, Error>> = match outcome {
144 Ok(Outcome::Affected(qr)) => vec![Ok(either::Either::Left(qr))],
145 Ok(Outcome::Rows(rows)) => rows
146 .into_iter()
147 .map(|r| Ok(either::Either::Right(r)))
148 .collect(),
149 Err(e) => vec![Err(e)],
150 };
151 stream::iter(items)
152 }))
153 }
154
155 fn fetch_optional<'e, 'q: 'e, E>(
156 self,
157 mut query: E,
158 ) -> BoxFuture<'e, Result<Option<crate::SpgRow>, Error>>
159 where
160 'c: 'e,
161 E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
162 {
163 let sql = query.sql().to_string();
164 let arguments = query.take_arguments();
165 let inner = self.inner.clone();
166 Box::pin(async move {
167 let args = arguments.map_err(Error::Encode)?;
168 match run_one(&inner, &sql, args).await? {
169 Outcome::Rows(mut rows) => Ok(rows.drain(..).next()),
170 Outcome::Affected(_) => Ok(None),
171 }
172 })
173 }
174
175 fn prepare_with<'e, 'q: 'e>(
176 self,
177 sql: &'q str,
178 _parameters: &'e [<Self::Database as sqlx_core::database::Database>::TypeInfo],
179 ) -> BoxFuture<
180 'e,
181 Result<<Self::Database as sqlx_core::database::Database>::Statement<'q>, Error>,
182 >
183 where
184 'c: 'e,
185 {
186 let inner = self.inner.clone();
187 let sql_str = sql.to_string();
188 Box::pin(async move {
189 let stmt = inner.prepare(&sql_str).await.map_err(engine_to_sqlx)?;
190 let inner_stmt = spg_embedded_tokio::async_statement_inner(&stmt);
196 Ok(crate::SpgStatement {
197 sql: std::borrow::Cow::Owned(sql_str),
198 inner: Some(inner_stmt),
199 columns: std::sync::Arc::new(Vec::new()),
200 by_name: std::sync::Arc::new(sqlx_core::HashMap::new()),
201 })
202 })
203 }
204
205 fn describe<'e, 'q: 'e>(
206 self,
207 sql: &'q str,
208 ) -> BoxFuture<'e, Result<sqlx_core::describe::Describe<Self::Database>, Error>>
209 where
210 'c: 'e,
211 {
212 let inner = self.inner.clone();
222 let sql_str = sql.to_string();
223 Box::pin(async move {
224 let (params, cols) = inner.describe(&sql_str).await.map_err(engine_to_sqlx)?;
225 let nullable: Vec<Option<bool>> = cols.iter().map(|c| Some(c.nullable)).collect();
226 let columns: Vec<SpgColumn> = cols
227 .iter()
228 .enumerate()
229 .map(|(i, c)| {
230 let ti = SpgTypeInfo::from_data_type(c.ty);
231 SpgColumn::new(i, c.name.clone(), ti)
232 })
233 .collect();
234 let parameters = if params.is_empty() {
235 None
236 } else {
237 Some(either::Either::Right(params.len()))
238 };
239 Ok(sqlx_core::describe::Describe {
240 columns,
241 parameters,
242 nullable,
243 })
244 })
245 }
246}
247
248enum Outcome {
252 Affected(SpgQueryResult),
254 Rows(Vec<SpgRow>),
257}
258
259async fn run_one(
260 db: &AsyncDatabase,
261 sql: &str,
262 arguments: Option<crate::SpgArguments<'_>>,
263) -> Result<Outcome, Error> {
264 let result: EngineQueryResult = if let Some(args) = arguments {
271 let stmt = db.prepare(sql).await.map_err(engine_to_sqlx)?;
272 db.execute_prepared(&stmt, args.into_engine_values())
273 .await
274 .map_err(engine_to_sqlx)?
275 } else {
276 db.execute(sql).await.map_err(engine_to_sqlx)?
277 };
278 match result {
279 EngineQueryResult::Rows { columns, rows } => {
280 let row_values: Vec<Vec<spg_embedded::Value>> =
281 rows.into_iter().map(|r| r.values).collect();
282 Ok(Outcome::Rows(build_rows(&columns, row_values)))
283 }
284 EngineQueryResult::CommandOk { affected, .. } => Ok(Outcome::Affected(
285 SpgQueryResult::new(u64::try_from(affected).unwrap_or(0)),
286 )),
287 _ => Ok(Outcome::Affected(SpgQueryResult::default())),
288 }
289}
290
291#[allow(dead_code)]
292fn affected_from(qr: &EngineQueryResult) -> u64 {
293 match qr {
294 EngineQueryResult::CommandOk { affected, .. } => u64::try_from(*affected).unwrap_or(0),
295 EngineQueryResult::Rows { rows, .. } => u64::try_from(rows.len()).unwrap_or(0),
296 _ => 0,
297 }
298}
299
300fn build_rows(
301 cols: &[spg_embedded::ColumnSchema],
302 rows: Vec<Vec<spg_embedded::Value>>,
303) -> Vec<SpgRow> {
304 let columns: Arc<Vec<SpgColumn>> = Arc::new(
305 cols.iter()
306 .enumerate()
307 .map(|(i, c)| SpgColumn::new(i, c.name.clone(), SpgTypeInfo::from_data_type(c.ty)))
308 .collect(),
309 );
310 let mut by_name: HashMap<String, usize> = HashMap::new();
311 for (i, c) in cols.iter().enumerate() {
312 by_name.insert(c.name.clone(), i);
313 }
314 let by_name = Arc::new(by_name);
315 rows.into_iter()
316 .map(|values| SpgRow::new(Arc::clone(&columns), Arc::clone(&by_name), values))
317 .collect()
318}