1#![deny(unused_extern_crates)]
2#![deny(warnings)]
3
4#[cfg(feature = "sqlite")]
5use sqlx_sqlite_cipher::Sqlite;
/// Concrete sqlx database backend used throughout this module: the `Sqlite`
/// type from `sqlx_sqlite_cipher`. Only defined when the `sqlite` feature is enabled.
#[cfg(feature = "sqlite")]
pub type SqlxDb = Sqlite;
8
9use sqlx::prelude::*;
10use sqlx::{database::HasArguments, Column, Database, Pool, TypeInfo};
11use sqlx_core::types::chrono::{DateTime, Utc};
12
13use rxqlite_common::{
14 Message,
15 MessageResponse,
16 Value,
17 QueryResult,
18 Col,
19};
20
21pub trait FromSqlxType {
22 type DB: Database;
23 fn from_sqlx_type(sqlx_type: &<Self::DB as Database>::TypeInfo)->Self
24 where <Self as FromSqlxType>::DB: sqlx::Database;
25
26}
27
28impl FromSqlxType for rxqlite_common::TypeInfo {
29 type DB = SqlxDb;
30 fn from_sqlx_type(sqlx_type: &<Self::DB as Database>::TypeInfo)->Self
31 where <Self as FromSqlxType>::DB: sqlx::Database
32 {
33 if sqlx_type.is_null() {
34 rxqlite_common::TypeInfo::Null
35 } else {
36 match sqlx_type.name() {
37 "BOOL" | "BOOLEAN" => {
38 rxqlite_common::TypeInfo::Bool
39 }
40 "INT" | "INTEGER" => {
41 rxqlite_common::TypeInfo::Int64
42 }
43 "TEXT" => {
44 rxqlite_common::TypeInfo::Text
45 }
46 "VARCHAR" => {
47 rxqlite_common::TypeInfo::Text
48 }
49 "FLOAT" => {
50 rxqlite_common::TypeInfo::Float
51 }
52 "REAL" => {
53 rxqlite_common::TypeInfo::Float
54 }
55 "DATE" => {
56 rxqlite_common::TypeInfo::Date
57 }
58 "TIME" => {
59 rxqlite_common::TypeInfo::Time
60 }
61 "DATETIME" => {
62 rxqlite_common::TypeInfo::DateTime
63 }
64 _ => rxqlite_common::TypeInfo::Text,
65 }
66 }
67 }
68}
69
70use sqlparser::ast::{Query, Statement};
71use sqlparser::dialect::SQLiteDialect;
72use sqlparser::parser::Parser;
73use futures_util::StreamExt;
74
75fn prepare_query<'q, DB: Database>(
76 sql: &'q str,
77 params: Vec<Value>,
78) -> Result<sqlx::query::Query<'q, DB, <DB as HasArguments<'q>>::Arguments>, String>
79where
80 i64: Encode<'q, DB> + Type<DB>,
81 &'q str: Encode<'q, DB> + Type<DB>,
82 bool: Encode<'q, DB> + Type<DB>,
83 String: Encode<'q, DB> + Type<DB>,
84 f32: Encode<'q, DB> + Type<DB>,
85 f64: Encode<'q, DB> + Type<DB>,
86 DateTime<Utc>: Encode<'q, DB> + Type<DB>,
87 Vec<u8>: Encode<'q, DB> + Type<DB>,
88 Option<String>: sqlx::Encode<'q, DB>
89{
90 let mut query = sqlx::query(sql);
91 for param in params {
92 match param {
93 Value::Null => {
94 query = query.bind(Option::<String>::None);
95 }
96 Value::Bool(b) => {
97 query = query.bind(b);
98 }
99 Value::Int(i) => {
100 query = query.bind(i);
101 }
102 Value::F32(f) => {
103 query = query.bind(f);
104 }
105 Value::F64(f) => {
106 query = query.bind(f);
107 }
108 Value::String(s) => {
109 query = query.bind(s);
110 }
111 Value::DateTime(dt) => {
112 query = query.bind(dt);
113 }
114 Value::Blob(blob) => {
115 query = query.bind(blob);
116 }
117 }
118 }
119 Ok(query)
120}
121
122pub async fn do_sql(pool: &Pool<SqlxDb>, message: Message) -> MessageResponse {
123 match message {
124 Message::FetchMany(sql, params) => {
125 let query = prepare_query(&sql, params);
126 if let Err(err) = &query {
127 let response_message = MessageResponse::Error(format!("{}", err));
128 return response_message;
129 }
130 let query = query.unwrap();
131 #[allow(deprecated)]
132 let stream = query.fetch_many(pool);
133 let res = stream.map(|item| {
134 match item {
135 Ok(item)=>{
136 match item {
137 sqlx::Either::Left(res)=>{
138 Ok(rxqlite_common::Either::Left(QueryResult {
139 last_insert_rowid: res.last_insert_rowid(),
140 changes: res.rows_affected(),
141 }))
142 }
143 sqlx::Either::Right(row)=>{
144 Ok(rxqlite_common::Either::Right(
145 {
146 let mut resulting_row: Vec<rxqlite_common::Col> = vec![];
147 let sqlite_columns=row.columns();
148 let columns = sqlite_columns.iter().map(|col|
149 rxqlite_common::Column {
150 ordinal: col.ordinal(),
151 name: col.name().into(),
152 type_info: rxqlite_common::TypeInfo::from_sqlx_type(&col.type_info()),
153 }
154 ).collect::<Vec<_>>();
155
156 let cols = row.len();
157 for i in 0..cols {
158 let col = row.column(i);
159
160 let type_info = col.type_info();
161
162 let ordinal = col.ordinal();
163
164
165 if type_info.is_null() {
166 resulting_row.push(Col::new(Value::Null,ordinal as _));
167 } else {
168 match type_info.name() {
170 "BOOL" | "BOOLEAN" => {
171 let col: bool = row.get(i);
172 resulting_row.push(Col::new(col.into(),ordinal as _));
173
174 }
175 "INT" | "INTEGER" => {
176 let col: i64 = row.get(i);
177 resulting_row.push(Col::new(col.into(),ordinal as _));
178 }
179 "TEXT" => {
180 let col: String = row.get(i);
181 resulting_row.push(Col::new(col.into(),ordinal as _));
182 }
183 "VARCHAR" => {
184 let col: String = row.get(i);
185 resulting_row.push(Col::new(col.into(),ordinal as _));
186 }
187 "FLOAT" => {
188 let col: f32 = row.get(i);
189 resulting_row.push(Col::new(col.into(),ordinal as _));
190 }
191 "REAL" => {
192 let col: f64 = row.get(i);
193 resulting_row.push(Col::new(col.into(),ordinal as _));
194 }
195 "DATETIME" => {
196 let col: DateTime<Utc> = row.get(i);
197 resulting_row.push(Col::new(col.into(),ordinal as _));
198 }
199 other => {
200 let response_message = format!("unsupported column type: {}", other);
201 return Err(response_message);
202 }
203 }
204 }
205 }
206 (columns,resulting_row).into()
207 }))
208 }
209 }
210 }
211 Err(err)=>Err(format!("{}",err)),
212 }
213 })
214 .collect::<Vec<_>>()
215 .await;
216 let response_message = MessageResponse::QueryResultsAndRows(res);
217 response_message
218 }
219 Message::Execute(sql, params) => {
220 let query = prepare_query(&sql, params);
221 if let Err(err) = &query {
222 let response_message = MessageResponse::Error(format!("{}", err));
223 return response_message;
224 }
225 let query = query.unwrap();
226 let res = query.execute(pool).await;
227 match res {
228 Ok(res) => {
229 let response_message = MessageResponse::QueryResult(QueryResult {
230 last_insert_rowid: res.last_insert_rowid(),
231 changes: res.rows_affected(),
232 });
233 response_message
234 }
235 Err(err) => {
236 tracing::error!("{}",err);
237 let response_message = MessageResponse::Error(format!("{}", err));
238 response_message
239 }
240 }
241 }
242 Message::Fetch(sql, params) => {
243 let query = prepare_query(&sql, params);
244 if let Err(err) = &query {
245 let response_message = MessageResponse::Error(format!("{}", err));
246 return response_message;
247 }
248 let query = query.unwrap();
249 let res = query.fetch_all(pool).await;
250 let mut resulting_rows: Vec<rxqlite_common::Row> = vec![];
251 match res {
252 Ok(rows) => {
253 for row in rows.iter() {
254 let mut resulting_row: Vec<rxqlite_common::Col> = vec![];
255 let sqlite_columns=row.columns();
256 let columns = sqlite_columns.iter().map(|col|
257 rxqlite_common::Column {
258 ordinal: col.ordinal(),
259 name: col.name().into(),
260 type_info: rxqlite_common::TypeInfo::from_sqlx_type(&col.type_info()),
261 }
262 ).collect::<Vec<_>>();
263 let cols = row.len();
264 for i in 0..cols {
265 let col = row.column(i);
266 let type_info = col.type_info();
267 let ordinal = col.ordinal();
268
269 if type_info.is_null() {
270 resulting_row.push(Col::new(Value::Null,ordinal as _));
271 } else {
272 match type_info.name() {
274 "BOOL" | "BOOLEAN" => {
275 let col: bool = row.get(i);
276
277 resulting_row.push(Col::new(col.into(),ordinal as _));
278 }
279 "INT" | "INTEGER" => {
280 let col: i64 = row.get(i);
281 resulting_row.push(Col::new(col.into(),ordinal as _));
282 }
283 "TEXT" => {
284 let col: String = row.get(i);
285 resulting_row.push(Col::new(col.into(),ordinal as _));
286 }
287 "VARCHAR" => {
288 let col: String = row.get(i);
289 resulting_row.push(Col::new(col.into(),ordinal as _));
290 }
291 "FLOAT" => {
292 let col: f32 = row.get(i);
293 resulting_row.push(Col::new(col.into(),ordinal as _));
294 }
295 "REAL" => {
296 let col: f64 = row.get(i);
297 resulting_row.push(Col::new(col.into(),ordinal as _));
298 }
299 "DATETIME" => {
300 let col: DateTime<Utc> = row.get(i);
301 resulting_row.push(Col::new(col.into(),ordinal as _));
302 }
303 other => {
304 let response_message = MessageResponse::Error(format!("unsupported column type: {}", other));
305 return response_message;
306 }
307 }
308 }
309 }
310 resulting_rows.push((columns,resulting_row).into());
311 }
312
313
314
315 let response_message = MessageResponse::Rows(resulting_rows);
316 response_message
317 }
318 Err(err) => {
319 tracing::error!("{}",err);
320 let response_message = MessageResponse::Error(format!("{}", err));
321 response_message
322 }
323 }
324 }
325 Message::FetchOne(sql, params) => {
326 let query = prepare_query(&sql, params);
327 if let Err(err) = &query {
328 let response_message = MessageResponse::Error(format!("{}", err));
329 return response_message;
330 }
331 let query = query.unwrap();
332 let res = query.fetch_one(pool).await;
333 match res {
334 Ok(row) => {
335 let mut resulting_row: Vec<rxqlite_common::Col> = vec![];
336 let sqlite_columns=row.columns();
337 let columns = sqlite_columns.iter().map(|col|
338 rxqlite_common::Column {
339 ordinal: col.ordinal(),
340 name: col.name().into(),
341 type_info: rxqlite_common::TypeInfo::from_sqlx_type(&col.type_info()),
342 }
343 ).collect::<Vec<_>>();
344 let cols = row.len();
345 for i in 0..cols {
346 let col = row.column(i);
347 let type_info = col.type_info();
348 let ordinal = col.ordinal();
349
350 if type_info.is_null() {
351 resulting_row.push(Col::new(Value::Null,ordinal as _));
352 } else {
353 match type_info.name() {
355 "BOOL" | "BOOLEAN" => {
356 let col: bool = row.get(i);
357 resulting_row.push(Col::new(col.into(),ordinal as _));
358 }
359 "INT" | "INTEGER" => {
360 let col: i64 = row.get(i);
361 resulting_row.push(Col::new(col.into(),ordinal as _));
362 }
363 "TEXT" => {
364 let col: String = row.get(i);
365 resulting_row.push(Col::new(col.into(),ordinal as _));
366 }
367 "VARCHAR" => {
368 let col: String = row.get(i);
369 resulting_row.push(Col::new(col.into(),ordinal as _));
370 }
371 "FLOAT" => {
372 let col: f32 = row.get(i);
373 resulting_row.push(Col::new(col.into(),ordinal as _));
374 }
375 "REAL" => {
376 let col: f64 = row.get(i);
377 resulting_row.push(Col::new(col.into(),ordinal as _));
378 }
379 "DATETIME" => {
380 let col: DateTime<Utc> = row.get(i);
381 resulting_row.push(Col::new(col.into(),ordinal as _));
382 }
383 other => {
384 let response_message = MessageResponse::Error(format!("unsupported column type: {}", other));
385 return response_message;
386 }
387 }
388 }
389 }
390 let response_message = MessageResponse::Rows(
391 vec![(columns,resulting_row).into()]);
392 response_message
393 }
394 Err(err) => {
395 tracing::error!("{}",err);
396 let response_message = MessageResponse::Error(format!("{}", err));
397 response_message
398 }
399 }
400 }
401 Message::FetchOptional(sql, params) => {
402 let query = prepare_query(&sql, params);
403 if let Err(err) = &query {
404 let response_message = MessageResponse::Error(format!("{}", err));
405 return response_message;
406 }
407 let query = query.unwrap();
408 let res = query.fetch_optional(pool).await;
409 match res {
410 Ok(row) => {
411 if let Some(row) = row {
412 let mut resulting_row: Vec<rxqlite_common::Col> = vec![];
413 let sqlite_columns=row.columns();
414 let columns = sqlite_columns.iter().map(|col|
415 rxqlite_common::Column {
416 ordinal: col.ordinal(),
417 name: col.name().into(),
418 type_info: rxqlite_common::TypeInfo::from_sqlx_type(&col.type_info()),
419 }
420 ).collect::<Vec<_>>();
421 let cols = row.len();
422 for i in 0..cols {
423 let col = row.column(i);
424 let type_info = col.type_info();
425 let ordinal = col.ordinal();
426
427 if type_info.is_null() {
428 resulting_row.push(Col::new(Value::Null,ordinal as _));
429 } else {
430 match type_info.name() {
432 "BOOL" | "BOOLEAN" => {
433 let col: bool = row.get(i);
434 resulting_row.push(Col::new(col.into(),ordinal as _));
435 }
436 "INT" | "INTEGER" => {
437 let col: i64 = row.get(i);
438 resulting_row.push(Col::new(col.into(),ordinal as _));
439 }
440 "TEXT" => {
441 let col: String = row.get(i);
442 resulting_row.push(Col::new(col.into(),ordinal as _));
443 }
444 "VARCHAR" => {
445 let col: String = row.get(i);
446 resulting_row.push(Col::new(col.into(),ordinal as _));
447 }
448 "FLOAT" => {
449 let col: f32 = row.get(i);
450 resulting_row.push(Col::new(col.into(),ordinal as _));
451 }
452 "REAL" => {
453 let col: f64 = row.get(i);
454 resulting_row.push(Col::new(col.into(),ordinal as _));
455 }
456 "DATETIME" => {
457 let col: DateTime<Utc> = row.get(i);
458 resulting_row.push(Col::new(col.into(),ordinal as _));
459 }
460 other => {
461 let response_message = MessageResponse::Error(format!("unsupported column type: {}", other));
462 return response_message;
463 }
464 }
465 }
466 }
467 let response_message = MessageResponse::Rows(vec![(columns,resulting_row).into()]);
468 response_message
469 } else {
470 let response_message = MessageResponse::Rows(Default::default());
471 response_message
472 }
473 }
474 Err(err) => {
475 tracing::error!("{}",err);
476 let response_message = MessageResponse::Error(format!("{}", err));
477 response_message
478 }
479 }
480 }
481 }
482}
483
484fn is_for_update_or_share(query: &Box<Query>) -> bool {
485 !query.locks.is_empty()
486}
487
488pub fn is_query_write(sql: &str) -> anyhow::Result<bool> {
489 let ast = Parser::parse_sql(&SQLiteDialect {}, sql)?;
490 for stmt in ast {
491 match stmt {
492 Statement::StartTransaction { .. }
493 | Statement::SetTransaction { .. }
494 | Statement::Commit { .. }
495 | Statement::Rollback { .. } => {
496 return Err(anyhow::anyhow!("Transactions are not supported in rxqlite"));
497 }
498
499 Statement::Query(query) => {
500 if is_for_update_or_share(&query) {
501 return Ok(true);
502 } else {
503 }
504 }
505 Statement::Insert { .. } | Statement::Update { .. } | Statement::Delete { .. } => {
506 return Ok(true);
507 }
508 _ => {
509 return Ok(true);
510 }
511 }
512 }
513 Ok(false)
514}