rxqlite_sqlx_common/
lib.rs

#![deny(unused_extern_crates, warnings)]

// Backend selection: with the `sqlite` feature enabled, the SQLite driver
// exported by `sqlx_sqlite_cipher` is the database type used throughout this crate.
#[cfg(feature = "sqlite")]
use sqlx_sqlite_cipher::Sqlite;

/// The sqlx database type this crate is compiled against (only defined with the
/// `sqlite` feature).
#[cfg(feature = "sqlite")]
pub type SqlxDb = Sqlite;
8
9use sqlx::prelude::*;
10use sqlx::{database::HasArguments, Column, Database, Pool, TypeInfo};
11use sqlx_core::types::chrono::{DateTime, Utc};
12
13use rxqlite_common::{
14  Message, 
15  MessageResponse, 
16  Value,
17  QueryResult,
18  Col,
19};
20
21pub trait FromSqlxType {
22  type DB: Database;
23  fn from_sqlx_type(sqlx_type: &<Self::DB as Database>::TypeInfo)->Self
24    where <Self as FromSqlxType>::DB: sqlx::Database;
25  
26}
27
28impl FromSqlxType for rxqlite_common::TypeInfo {
29  type DB = SqlxDb;
30  fn from_sqlx_type(sqlx_type: &<Self::DB as Database>::TypeInfo)->Self 
31    where <Self as FromSqlxType>::DB: sqlx::Database
32  {
33    if sqlx_type.is_null() {
34      rxqlite_common::TypeInfo::Null
35    } else {
36      match sqlx_type.name() {
37        "BOOL" | "BOOLEAN" => {
38            rxqlite_common::TypeInfo::Bool
39        }
40        "INT" | "INTEGER" => {
41            rxqlite_common::TypeInfo::Int64
42        }
43        "TEXT" => {
44            rxqlite_common::TypeInfo::Text
45        }
46        "VARCHAR" => {
47            rxqlite_common::TypeInfo::Text
48        }
49        "FLOAT" => {
50            rxqlite_common::TypeInfo::Float
51        }
52        "REAL" => {
53            rxqlite_common::TypeInfo::Float
54        }
55        "DATE" => {
56            rxqlite_common::TypeInfo::Date
57        }
58        "TIME" => {
59            rxqlite_common::TypeInfo::Time
60        }
61        "DATETIME" => {
62            rxqlite_common::TypeInfo::DateTime
63        }
64        _ => rxqlite_common::TypeInfo::Text,
65      }
66    }
67  }
68}
69
70use sqlparser::ast::{Query, Statement};
71use sqlparser::dialect::SQLiteDialect;
72use sqlparser::parser::Parser;
73use futures_util::StreamExt;
74
75fn prepare_query<'q, DB: Database>(
76    sql: &'q str,
77    params: Vec<Value>,
78) -> Result<sqlx::query::Query<'q, DB, <DB as HasArguments<'q>>::Arguments>, String>
79where
80    i64: Encode<'q, DB> + Type<DB>,
81    &'q str: Encode<'q, DB> + Type<DB>,
82    bool: Encode<'q, DB> + Type<DB>,
83    String: Encode<'q, DB> + Type<DB>,
84    f32: Encode<'q, DB> + Type<DB>,
85    f64: Encode<'q, DB> + Type<DB>,
86    DateTime<Utc>: Encode<'q, DB> + Type<DB>,
87    Vec<u8>: Encode<'q, DB> + Type<DB>,
88    Option<String>: sqlx::Encode<'q, DB>
89{
90    let mut query = sqlx::query(sql);
91    for param in params {
92        match param {
93            Value::Null => {
94                query = query.bind(Option::<String>::None);
95            }
96            Value::Bool(b) => {
97                query = query.bind(b);
98            }
99            Value::Int(i) => {
100                query = query.bind(i);
101            }
102            Value::F32(f) => {
103                query = query.bind(f);
104            }
105            Value::F64(f) => {
106                query = query.bind(f);
107            }
108            Value::String(s) => {
109                query = query.bind(s);
110            }
111            Value::DateTime(dt) => {
112                query = query.bind(dt);
113            }
114            Value::Blob(blob) => {
115                query = query.bind(blob);
116            }
117        }
118    }
119    Ok(query)
120}
121
122pub async fn do_sql(pool: &Pool<SqlxDb>, message: Message) -> MessageResponse {
123    match message {
124        Message::FetchMany(sql, params) => {
125          let query = prepare_query(&sql, params);
126            if let Err(err) = &query {
127                let response_message = MessageResponse::Error(format!("{}", err));
128                return response_message;
129            }
130            let query = query.unwrap();
131            #[allow(deprecated)]
132            let stream = query.fetch_many(pool);
133            let res = stream.map(|item| {
134              match item {
135                Ok(item)=>{
136                  match item {
137                    sqlx::Either::Left(res)=>{
138                        Ok(rxqlite_common::Either::Left(QueryResult {
139                          last_insert_rowid: res.last_insert_rowid(),
140                          changes: res.rows_affected(),
141                        }))
142                    }
143                    sqlx::Either::Right(row)=>{
144                      Ok(rxqlite_common::Either::Right(
145                      {
146                        let mut resulting_row: Vec<rxqlite_common::Col> = vec![];
147                        let sqlite_columns=row.columns();
148                        let columns = sqlite_columns.iter().map(|col|
149                          rxqlite_common::Column {
150                            ordinal: col.ordinal(),
151                            name: col.name().into(),
152                            type_info: rxqlite_common::TypeInfo::from_sqlx_type(&col.type_info()),
153                          }
154                        ).collect::<Vec<_>>();
155                        
156                        let cols = row.len();
157                        for i in 0..cols {
158                            let col = row.column(i);
159                            
160                            let type_info = col.type_info();
161                            
162                            let ordinal = col.ordinal();
163                            
164                            
165                            if type_info.is_null() {
166                                resulting_row.push(Col::new(Value::Null,ordinal as _));
167                            } else {
168                                //println!("TYPE: {}",type_info.name());
169                                match type_info.name() {
170                                    "BOOL" | "BOOLEAN" => {
171                                        let col: bool = row.get(i);
172                                        resulting_row.push(Col::new(col.into(),ordinal as _));
173                                        
174                                    }
175                                    "INT" | "INTEGER" => {
176                                        let col: i64 = row.get(i);
177                                        resulting_row.push(Col::new(col.into(),ordinal as _));
178                                    }
179                                    "TEXT" => {
180                                        let col: String = row.get(i);
181                                        resulting_row.push(Col::new(col.into(),ordinal as _));
182                                    }
183                                    "VARCHAR" => {
184                                        let col: String = row.get(i);
185                                        resulting_row.push(Col::new(col.into(),ordinal as _));
186                                    }
187                                    "FLOAT" => {
188                                        let col: f32 = row.get(i);
189                                        resulting_row.push(Col::new(col.into(),ordinal as _));
190                                    }
191                                    "REAL" => {
192                                        let col: f64 = row.get(i);
193                                        resulting_row.push(Col::new(col.into(),ordinal as _));
194                                    }
195                                    "DATETIME" => {
196                                        let col: DateTime<Utc> = row.get(i);
197                                        resulting_row.push(Col::new(col.into(),ordinal as _));
198                                    }
199                                    other => {
200                                      let response_message = format!("unsupported column type: {}", other);
201                                      return Err(response_message);
202                                    }
203                                }
204                            }
205                        }
206                        (columns,resulting_row).into()
207                      }))
208                    }
209                  }
210                }
211                Err(err)=>Err(format!("{}",err)),
212              }
213            })
214            .collect::<Vec<_>>()
215            .await;
216            let response_message = MessageResponse::QueryResultsAndRows(res);
217            response_message        
218        }
219        Message::Execute(sql, params) => {
220            let query = prepare_query(&sql, params);
221            if let Err(err) = &query {
222                let response_message = MessageResponse::Error(format!("{}", err));
223                return response_message;
224            }
225            let query = query.unwrap();
226            let res = query.execute(pool).await;
227            match res {
228                Ok(res) => {
229                    let response_message = MessageResponse::QueryResult(QueryResult {
230                      last_insert_rowid: res.last_insert_rowid(),
231                      changes: res.rows_affected(),
232                    });
233                    response_message
234                }
235                Err(err) => {
236                    tracing::error!("{}",err);
237                    let response_message = MessageResponse::Error(format!("{}", err));
238                    response_message
239                }
240            }
241        }
242        Message::Fetch(sql, params) => {
243            let query = prepare_query(&sql, params);
244            if let Err(err) = &query {
245                let response_message = MessageResponse::Error(format!("{}", err));
246                return response_message;
247            }
248            let query = query.unwrap();
249            let res = query.fetch_all(pool).await;
250            let mut resulting_rows: Vec<rxqlite_common::Row> = vec![];
251            match res {
252                Ok(rows) => {
253                    for row in rows.iter() {
254                        let mut resulting_row: Vec<rxqlite_common::Col> = vec![];
255                        let sqlite_columns=row.columns();
256                        let columns = sqlite_columns.iter().map(|col|
257                          rxqlite_common::Column {
258                            ordinal: col.ordinal(),
259                            name: col.name().into(),
260                            type_info: rxqlite_common::TypeInfo::from_sqlx_type(&col.type_info()),
261                          }
262                        ).collect::<Vec<_>>();
263                        let cols = row.len();
264                        for i in 0..cols {
265                            let col = row.column(i);
266                            let type_info = col.type_info();
267                            let ordinal = col.ordinal();
268                            
269                            if type_info.is_null() {
270                                resulting_row.push(Col::new(Value::Null,ordinal as _));
271                            } else {
272                                //println!("TYPE: {}",type_info.name());
273                                match type_info.name() {
274                                    "BOOL" | "BOOLEAN" => {
275                                        let col: bool = row.get(i);
276                                        
277                                        resulting_row.push(Col::new(col.into(),ordinal as _));
278                                    }
279                                    "INT" | "INTEGER" => {
280                                        let col: i64 = row.get(i);
281                                        resulting_row.push(Col::new(col.into(),ordinal as _));
282                                    }
283                                    "TEXT" => {
284                                        let col: String = row.get(i);
285                                        resulting_row.push(Col::new(col.into(),ordinal as _));
286                                    }
287                                    "VARCHAR" => {
288                                        let col: String = row.get(i);
289                                        resulting_row.push(Col::new(col.into(),ordinal as _));
290                                    }
291                                    "FLOAT" => {
292                                        let col: f32 = row.get(i);
293                                        resulting_row.push(Col::new(col.into(),ordinal as _));
294                                    }
295                                    "REAL" => {
296                                        let col: f64 = row.get(i);
297                                        resulting_row.push(Col::new(col.into(),ordinal as _));
298                                    }
299                                    "DATETIME" => {
300                                        let col: DateTime<Utc> = row.get(i);
301                                        resulting_row.push(Col::new(col.into(),ordinal as _));
302                                    }
303                                    other => {
304                                      let response_message = MessageResponse::Error(format!("unsupported column type: {}", other));
305                                      return response_message;
306                                    }
307                                }
308                            }
309                        }
310                        resulting_rows.push((columns,resulting_row).into());
311                    }
312                    
313                    
314                    
315                    let response_message = MessageResponse::Rows(resulting_rows);
316                    response_message
317                }
318                Err(err) => {
319                    tracing::error!("{}",err);
320                    let response_message = MessageResponse::Error(format!("{}", err));
321                    response_message
322                }
323            }
324        }
325        Message::FetchOne(sql, params) => {
326            let query = prepare_query(&sql, params);
327            if let Err(err) = &query {
328                let response_message = MessageResponse::Error(format!("{}", err));
329                return response_message;
330            }
331            let query = query.unwrap();
332            let res = query.fetch_one(pool).await;
333            match res {
334                Ok(row) => {
335                    let mut resulting_row: Vec<rxqlite_common::Col> = vec![];
336                    let sqlite_columns=row.columns();
337                    let columns = sqlite_columns.iter().map(|col|
338                      rxqlite_common::Column {
339                        ordinal: col.ordinal(),
340                        name: col.name().into(),
341                        type_info: rxqlite_common::TypeInfo::from_sqlx_type(&col.type_info()),
342                      }
343                    ).collect::<Vec<_>>();
344                    let cols = row.len();
345                    for i in 0..cols {
346                        let col = row.column(i);
347                        let type_info = col.type_info();
348                        let ordinal = col.ordinal();
349                        
350                        if type_info.is_null() {
351                            resulting_row.push(Col::new(Value::Null,ordinal as _));
352                        } else {
353                            //println!("TYPE: {}",type_info.name());
354                            match type_info.name() {
355                                "BOOL" | "BOOLEAN" => {
356                                    let col: bool = row.get(i);
357                                    resulting_row.push(Col::new(col.into(),ordinal as _));
358                                }
359                                "INT" | "INTEGER" => {
360                                    let col: i64 = row.get(i);
361                                    resulting_row.push(Col::new(col.into(),ordinal as _));
362                                }
363                                "TEXT" => {
364                                    let col: String = row.get(i);
365                                    resulting_row.push(Col::new(col.into(),ordinal as _));
366                                }
367                                "VARCHAR" => {
368                                    let col: String = row.get(i);
369                                    resulting_row.push(Col::new(col.into(),ordinal as _));
370                                }
371                                "FLOAT" => {
372                                    let col: f32 = row.get(i);
373                                    resulting_row.push(Col::new(col.into(),ordinal as _));
374                                }
375                                "REAL" => {
376                                    let col: f64 = row.get(i);
377                                    resulting_row.push(Col::new(col.into(),ordinal as _));
378                                }
379                                "DATETIME" => {
380                                    let col: DateTime<Utc> = row.get(i);
381                                    resulting_row.push(Col::new(col.into(),ordinal as _));
382                                }
383                                other => {
384                                    let response_message = MessageResponse::Error(format!("unsupported column type: {}", other));
385                                    return response_message;
386                                }
387                            }
388                        }
389                    }
390                    let response_message = MessageResponse::Rows(
391                    vec![(columns,resulting_row).into()]);
392                    response_message
393                }
394                Err(err) => {
395                    tracing::error!("{}",err);
396                    let response_message = MessageResponse::Error(format!("{}", err));
397                    response_message
398                }
399            }
400        }
401        Message::FetchOptional(sql, params) => {
402            let query = prepare_query(&sql, params);
403            if let Err(err) = &query {
404                let response_message = MessageResponse::Error(format!("{}", err));
405                return response_message;
406            }
407            let query = query.unwrap();
408            let res = query.fetch_optional(pool).await;
409            match res {
410                Ok(row) => {
411                    if let Some(row) = row {
412                        let mut resulting_row: Vec<rxqlite_common::Col> = vec![];
413                        let sqlite_columns=row.columns();
414                        let columns = sqlite_columns.iter().map(|col|
415                          rxqlite_common::Column {
416                            ordinal: col.ordinal(),
417                            name: col.name().into(),
418                            type_info: rxqlite_common::TypeInfo::from_sqlx_type(&col.type_info()),
419                          }
420                        ).collect::<Vec<_>>();
421                        let cols = row.len();
422                        for i in 0..cols {
423                            let col = row.column(i);
424                            let type_info = col.type_info();
425                            let ordinal = col.ordinal();
426                            
427                            if type_info.is_null() {
428                                resulting_row.push(Col::new(Value::Null,ordinal as _));
429                            } else {
430                                //println!("TYPE: {}",type_info.name());
431                                match type_info.name() {
432                                    "BOOL" | "BOOLEAN" => {
433                                        let col: bool = row.get(i);
434                                        resulting_row.push(Col::new(col.into(),ordinal as _));
435                                    }
436                                    "INT" | "INTEGER" => {
437                                        let col: i64 = row.get(i);
438                                        resulting_row.push(Col::new(col.into(),ordinal as _));
439                                    }
440                                    "TEXT" => {
441                                        let col: String = row.get(i);
442                                        resulting_row.push(Col::new(col.into(),ordinal as _));
443                                    }
444                                    "VARCHAR" => {
445                                        let col: String = row.get(i);
446                                        resulting_row.push(Col::new(col.into(),ordinal as _));
447                                    }
448                                    "FLOAT" => {
449                                        let col: f32 = row.get(i);
450                                        resulting_row.push(Col::new(col.into(),ordinal as _));
451                                    }
452                                    "REAL" => {
453                                        let col: f64 = row.get(i);
454                                        resulting_row.push(Col::new(col.into(),ordinal as _));
455                                    }
456                                    "DATETIME" => {
457                                        let col: DateTime<Utc> = row.get(i);
458                                        resulting_row.push(Col::new(col.into(),ordinal as _));
459                                    }
460                                    other => {
461                                      let response_message = MessageResponse::Error(format!("unsupported column type: {}", other));
462                                      return response_message;
463                                    }
464                                }
465                            }
466                        }
467                        let response_message = MessageResponse::Rows(vec![(columns,resulting_row).into()]);
468                        response_message
469                    } else {
470                        let response_message = MessageResponse::Rows(Default::default());
471                        response_message
472                    }
473                }
474                Err(err) => {
475                    tracing::error!("{}",err);
476                    let response_message = MessageResponse::Error(format!("{}", err));
477                    response_message
478                }
479            }
480        }
481    }
482}
483
484fn is_for_update_or_share(query: &Box<Query>) -> bool {
485    !query.locks.is_empty()
486}
487
488pub fn is_query_write(sql: &str) -> anyhow::Result<bool> {
489    let ast = Parser::parse_sql(&SQLiteDialect {}, sql)?;
490    for stmt in ast {
491        match stmt {
492            Statement::StartTransaction { .. } 
493              | Statement::SetTransaction { .. } 
494              | Statement::Commit { .. }
495              | Statement::Rollback { .. } => {
496              return Err(anyhow::anyhow!("Transactions are not supported in rxqlite"));
497            }
498             
499            Statement::Query(query) => {
500                if is_for_update_or_share(&query) {
501                    return Ok(true);
502                } else {
503                }
504            }
505            Statement::Insert { .. } | Statement::Update { .. } | Statement::Delete { .. } => {
506                return Ok(true);
507            }
508            _ => {
509                return Ok(true);
510            }
511        }
512    }
513    Ok(false)
514}