1use std::sync::Arc;
12
13use futures_core::future::BoxFuture;
14use futures_core::stream::BoxStream;
15use sqlx_core::connection::Connection;
16use sqlx_core::error::Error;
17use sqlx_core::executor::Executor;
18use sqlx_core::transaction::Transaction;
19use sqlx_core::HashMap;
20
21use spg_embedded::QueryResult as EngineQueryResult;
22use spg_embedded_tokio::AsyncDatabase;
23
24use crate::column::SpgColumn;
25use crate::database::Spg;
26use crate::error::engine_to_sqlx;
27use crate::options::SpgConnectOptions;
28use crate::query_result::SpgQueryResult;
29use crate::row::SpgRow;
30use crate::type_info::SpgTypeInfo;
31
/// An sqlx connection backed by the embedded engine's `AsyncDatabase` handle.
///
/// NOTE(review): `Clone` copies `tx_depth` / `pending_rollback` by value, while
/// the clones presumably share the same underlying engine handle — confirm
/// that cloning a connection mid-transaction is intended.
#[derive(Debug, Clone)]
pub struct SpgConnection {
    /// Handle to the embedded engine; `ping` and every `Executor` method in
    /// this file issue their work through it.
    pub(crate) inner: AsyncDatabase,
    /// Initialized to 0 by `new`; presumably the current transaction nesting
    /// depth maintained by the transaction manager — TODO confirm.
    pub(crate) tx_depth: usize,
    /// Initialized to `false` by `new`; presumably set when a rollback still
    /// has to be issued before the connection is reused — TODO confirm.
    pub(crate) pending_rollback: bool,
}
39
40impl SpgConnection {
41 pub fn new(inner: AsyncDatabase) -> Self {
45 Self {
46 inner,
47 tx_depth: 0,
48 pending_rollback: false,
49 }
50 }
51
52 #[must_use]
56 pub const fn engine(&self) -> &AsyncDatabase {
57 &self.inner
58 }
59}
60
61impl Connection for SpgConnection {
62 type Database = Spg;
63 type Options = SpgConnectOptions;
64
65 fn close(self) -> BoxFuture<'static, Result<(), Error>> {
66 Box::pin(async move { Ok(()) })
69 }
70
71 fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
72 Box::pin(async move { Ok(()) })
73 }
74
75 fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
76 Box::pin(async move {
79 self.inner
80 .execute("SELECT 1")
81 .await
82 .map_err(engine_to_sqlx)?;
83 Ok(())
84 })
85 }
86
87 fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
88 where
89 Self: Sized,
90 {
91 Transaction::begin(self, None)
92 }
93
94 fn shrink_buffers(&mut self) {
95 }
97
98 fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
99 Box::pin(async move { Ok(()) })
100 }
101
102 fn should_flush(&self) -> bool {
103 false
104 }
105}
106
107impl<'c> Executor<'c> for &'c mut SpgConnection {
113 type Database = Spg;
114
115 fn fetch_many<'e, 'q: 'e, E>(
116 self,
117 mut query: E,
118 ) -> BoxStream<
119 'e,
120 Result<
121 either::Either<<Self::Database as sqlx_core::database::Database>::QueryResult, crate::SpgRow>,
122 Error,
123 >,
124 >
125 where
126 'c: 'e,
127 E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
128 {
129 use futures_util::stream::{self, StreamExt};
130 let sql = query.sql().to_string();
131 let arguments = match query.take_arguments() {
132 Ok(args) => args,
133 Err(e) => {
134 return Box::pin(stream::iter(std::iter::once(Err(Error::Encode(e)))));
135 }
136 };
137 let inner = self.inner.clone();
138 let outcome_fut = async move { run_one(&inner, &sql, arguments).await };
139 Box::pin(stream::once(outcome_fut).flat_map(|outcome| {
140 let items: Vec<Result<either::Either<SpgQueryResult, SpgRow>, Error>> = match outcome {
141 Ok(Outcome::Affected(qr)) => vec![Ok(either::Either::Left(qr))],
142 Ok(Outcome::Rows(rows)) => rows
143 .into_iter()
144 .map(|r| Ok(either::Either::Right(r)))
145 .collect(),
146 Err(e) => vec![Err(e)],
147 };
148 stream::iter(items)
149 }))
150 }
151
152 fn fetch_optional<'e, 'q: 'e, E>(
153 self,
154 mut query: E,
155 ) -> BoxFuture<'e, Result<Option<crate::SpgRow>, Error>>
156 where
157 'c: 'e,
158 E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
159 {
160 let sql = query.sql().to_string();
161 let arguments = query.take_arguments();
162 let inner = self.inner.clone();
163 Box::pin(async move {
164 let args = arguments.map_err(Error::Encode)?;
165 match run_one(&inner, &sql, args).await? {
166 Outcome::Rows(mut rows) => Ok(rows.drain(..).next()),
167 Outcome::Affected(_) => Ok(None),
168 }
169 })
170 }
171
172 fn prepare_with<'e, 'q: 'e>(
173 self,
174 sql: &'q str,
175 _parameters: &'e [<Self::Database as sqlx_core::database::Database>::TypeInfo],
176 ) -> BoxFuture<'e, Result<<Self::Database as sqlx_core::database::Database>::Statement<'q>, Error>>
177 where
178 'c: 'e,
179 {
180 let inner = self.inner.clone();
181 let sql_str = sql.to_string();
182 Box::pin(async move {
183 let stmt = inner.prepare(&sql_str).await.map_err(engine_to_sqlx)?;
184 let inner_stmt = spg_embedded_tokio::async_statement_inner(&stmt);
190 Ok(crate::SpgStatement {
191 sql: std::borrow::Cow::Owned(sql_str),
192 inner: Some(inner_stmt),
193 columns: std::sync::Arc::new(Vec::new()),
194 by_name: std::sync::Arc::new(sqlx_core::HashMap::new()),
195 })
196 })
197 }
198
199 fn describe<'e, 'q: 'e>(
200 self,
201 _sql: &'q str,
202 ) -> BoxFuture<'e, Result<sqlx_core::describe::Describe<Self::Database>, Error>>
203 where
204 'c: 'e,
205 {
206 Box::pin(async move {
207 Err(Error::Protocol(
208 "describe is v7.17 — compile-time sqlx::query!() macros need offline mode in the meantime".into(),
209 ))
210 })
211 }
212}
213
214enum Outcome {
218 Affected(SpgQueryResult),
220 Rows(Vec<SpgRow>),
223}
224
225async fn run_one(
226 db: &AsyncDatabase,
227 sql: &str,
228 arguments: Option<crate::SpgArguments<'_>>,
229) -> Result<Outcome, Error> {
230 let result: EngineQueryResult = if let Some(args) = arguments {
237 let stmt = db.prepare(sql).await.map_err(engine_to_sqlx)?;
238 db.execute_prepared(&stmt, args.into_engine_values())
239 .await
240 .map_err(engine_to_sqlx)?
241 } else {
242 db.execute(sql).await.map_err(engine_to_sqlx)?
243 };
244 match result {
245 EngineQueryResult::Rows { columns, rows } => {
246 let row_values: Vec<Vec<spg_embedded::Value>> =
247 rows.into_iter().map(|r| r.values).collect();
248 Ok(Outcome::Rows(build_rows(&columns, row_values)))
249 }
250 EngineQueryResult::CommandOk { affected, .. } => Ok(Outcome::Affected(
251 SpgQueryResult::new(u64::try_from(affected).unwrap_or(0)),
252 )),
253 _ => Ok(Outcome::Affected(SpgQueryResult::default())),
254 }
255}
256
257fn affected_from(qr: &EngineQueryResult) -> u64 {
258 match qr {
259 EngineQueryResult::CommandOk { affected, .. } => u64::try_from(*affected).unwrap_or(0),
260 EngineQueryResult::Rows { rows, .. } => u64::try_from(rows.len()).unwrap_or(0),
261 _ => 0,
262 }
263}
264
265fn build_rows(
266 cols: &[spg_embedded::ColumnSchema],
267 rows: Vec<Vec<spg_embedded::Value>>,
268) -> Vec<SpgRow> {
269 let columns: Arc<Vec<SpgColumn>> = Arc::new(
270 cols.iter()
271 .enumerate()
272 .map(|(i, c)| {
273 SpgColumn::new(i, c.name.clone(), SpgTypeInfo::from_data_type(c.ty))
274 })
275 .collect(),
276 );
277 let mut by_name: HashMap<String, usize> = HashMap::new();
278 for (i, c) in cols.iter().enumerate() {
279 by_name.insert(c.name.clone(), i);
280 }
281 let by_name = Arc::new(by_name);
282 rows.into_iter()
283 .map(|values| SpgRow::new(Arc::clone(&columns), Arc::clone(&by_name), values))
284 .collect()
285}