switchy_database/rusqlite/
mod.rs

1use std::{ops::Deref, sync::Arc};
2
3use async_trait::async_trait;
4use rusqlite::{Connection, Row, Rows, Statement, types::Value};
5use thiserror::Error;
6use tokio::sync::Mutex;
7
8use crate::{
9    Database, DatabaseError, DatabaseValue, DeleteStatement, InsertStatement, SelectQuery,
10    UpdateStatement, UpsertMultiStatement, UpsertStatement,
11    query::{BooleanExpression, Expression, ExpressionType, Join, Sort, SortDirection},
12};
13
14#[allow(clippy::module_name_repetitions)]
15#[derive(Debug)]
16pub struct RusqliteDatabase {
17    connection: Arc<Mutex<Connection>>,
18}
19
20impl RusqliteDatabase {
21    pub const fn new(connection: Arc<Mutex<Connection>>) -> Self {
22        Self { connection }
23    }
24}
25
26trait ToSql {
27    fn to_sql(&self) -> String;
28}
29
30impl<T: Expression + ?Sized> ToSql for T {
31    #[allow(clippy::too_many_lines)]
32    fn to_sql(&self) -> String {
33        match self.expression_type() {
34            ExpressionType::Eq(value) => {
35                if value.right.is_null() {
36                    format!("({} IS {})", value.left.to_sql(), value.right.to_sql())
37                } else {
38                    format!("({} = {})", value.left.to_sql(), value.right.to_sql())
39                }
40            }
41            ExpressionType::Gt(value) => {
42                if value.right.is_null() {
43                    panic!("Invalid > comparison with NULL");
44                } else {
45                    format!("({} > {})", value.left.to_sql(), value.right.to_sql())
46                }
47            }
48            ExpressionType::In(value) => {
49                format!("{} IN ({})", value.left.to_sql(), value.values.to_sql())
50            }
51            ExpressionType::NotIn(value) => {
52                format!("{} NOT IN ({})", value.left.to_sql(), value.values.to_sql())
53            }
54            ExpressionType::Lt(value) => {
55                if value.right.is_null() {
56                    panic!("Invalid < comparison with NULL");
57                } else {
58                    format!("({} < {})", value.left.to_sql(), value.right.to_sql())
59                }
60            }
61            ExpressionType::Or(value) => format!(
62                "({})",
63                value
64                    .conditions
65                    .iter()
66                    .map(|x| x.to_sql())
67                    .collect::<Vec<_>>()
68                    .join(" OR ")
69            ),
70            ExpressionType::And(value) => format!(
71                "({})",
72                value
73                    .conditions
74                    .iter()
75                    .map(|x| x.to_sql())
76                    .collect::<Vec<_>>()
77                    .join(" AND ")
78            ),
79            ExpressionType::Gte(value) => {
80                if value.right.is_null() {
81                    panic!("Invalid >= comparison with NULL");
82                } else {
83                    format!("({} >= {})", value.left.to_sql(), value.right.to_sql())
84                }
85            }
86            ExpressionType::Lte(value) => {
87                if value.right.is_null() {
88                    panic!("Invalid <= comparison with NULL");
89                } else {
90                    format!("({} <= {})", value.left.to_sql(), value.right.to_sql())
91                }
92            }
93            ExpressionType::Join(value) => format!(
94                "{} JOIN {} ON {}",
95                if value.left { "LEFT" } else { "" },
96                value.table_name,
97                value.on
98            ),
99            ExpressionType::Sort(value) => format!(
100                "({}) {}",
101                value.expression.to_sql(),
102                match value.direction {
103                    SortDirection::Asc => "ASC",
104                    SortDirection::Desc => "DESC",
105                }
106            ),
107            ExpressionType::NotEq(value) => {
108                if value.right.is_null() {
109                    format!("({} IS NOT {})", value.left.to_sql(), value.right.to_sql())
110                } else {
111                    format!("({} != {})", value.left.to_sql(), value.right.to_sql())
112                }
113            }
114            ExpressionType::InList(value) => value
115                .values
116                .iter()
117                .map(|value| value.to_sql())
118                .collect::<Vec<_>>()
119                .join(","),
120            ExpressionType::Coalesce(value) => format!(
121                "IFNULL({})",
122                value
123                    .values
124                    .iter()
125                    .map(|value| value.to_sql())
126                    .collect::<Vec<_>>()
127                    .join(",")
128            ),
129            ExpressionType::Literal(value) => value.value.to_string(),
130            ExpressionType::Identifier(value) => value.value.clone(),
131            ExpressionType::SelectQuery(value) => {
132                let joins = value.joins.as_ref().map_or_else(String::new, |joins| {
133                    joins.iter().map(Join::to_sql).collect::<Vec<_>>().join(" ")
134                });
135
136                let where_clause = value.filters.as_ref().map_or_else(String::new, |filters| {
137                    if filters.is_empty() {
138                        String::new()
139                    } else {
140                        format!(
141                            "WHERE {}",
142                            filters
143                                .iter()
144                                .map(|x| format!("({})", x.to_sql()))
145                                .collect::<Vec<_>>()
146                                .join(" AND ")
147                        )
148                    }
149                });
150
151                let sort_clause = value.sorts.as_ref().map_or_else(String::new, |sorts| {
152                    if sorts.is_empty() {
153                        String::new()
154                    } else {
155                        format!(
156                            "ORDER BY {}",
157                            sorts
158                                .iter()
159                                .map(Sort::to_sql)
160                                .collect::<Vec<_>>()
161                                .join(", ")
162                        )
163                    }
164                });
165
166                let limit = value
167                    .limit
168                    .map_or_else(String::new, |limit| format!("LIMIT {limit}"));
169
170                format!(
171                    "SELECT {} {} FROM {} {} {} {} {}",
172                    if value.distinct { "DISTINCT" } else { "" },
173                    value.columns.join(", "),
174                    value.table_name,
175                    joins,
176                    where_clause,
177                    sort_clause,
178                    limit
179                )
180            }
181            ExpressionType::DatabaseValue(value) => match value {
182                DatabaseValue::Null
183                | DatabaseValue::BoolOpt(None)
184                | DatabaseValue::StringOpt(None)
185                | DatabaseValue::NumberOpt(None)
186                | DatabaseValue::UNumberOpt(None)
187                | DatabaseValue::RealOpt(None) => "NULL".to_string(),
188                DatabaseValue::Now => "strftime('%Y-%m-%dT%H:%M:%f', 'now')".to_string(),
189                DatabaseValue::NowAdd(add) => {
190                    format!("strftime('%Y-%m-%dT%H:%M:%f', DateTime('now', 'LocalTime', {add}))")
191                }
192                _ => "?".to_string(),
193            },
194        }
195    }
196}
197
198#[allow(clippy::module_name_repetitions)]
199#[derive(Debug, Error)]
200pub enum RusqliteDatabaseError {
201    #[error(transparent)]
202    Rusqlite(#[from] rusqlite::Error),
203    #[error("No ID")]
204    NoId,
205    #[error("No row")]
206    NoRow,
207    #[error("Invalid request")]
208    InvalidRequest,
209    #[error("Missing unique")]
210    MissingUnique,
211}
212
213impl From<RusqliteDatabaseError> for DatabaseError {
214    fn from(value: RusqliteDatabaseError) -> Self {
215        Self::Rusqlite(value)
216    }
217}
218
219#[async_trait]
220impl Database for RusqliteDatabase {
221    async fn query(&self, query: &SelectQuery<'_>) -> Result<Vec<crate::Row>, DatabaseError> {
222        Ok(select(
223            &*self.connection.lock().await,
224            query.table_name,
225            query.distinct,
226            query.columns,
227            query.filters.as_deref(),
228            query.joins.as_deref(),
229            query.sorts.as_deref(),
230            query.limit,
231        )?)
232    }
233
234    async fn query_first(
235        &self,
236        query: &SelectQuery<'_>,
237    ) -> Result<Option<crate::Row>, DatabaseError> {
238        Ok(find_row(
239            &*self.connection.lock().await,
240            query.table_name,
241            query.distinct,
242            query.columns,
243            query.filters.as_deref(),
244            query.joins.as_deref(),
245            query.sorts.as_deref(),
246        )?)
247    }
248
249    async fn exec_delete(
250        &self,
251        statement: &DeleteStatement<'_>,
252    ) -> Result<Vec<crate::Row>, DatabaseError> {
253        Ok(delete(
254            &*self.connection.lock().await,
255            statement.table_name,
256            statement.filters.as_deref(),
257            statement.limit,
258        )?)
259    }
260
261    async fn exec_delete_first(
262        &self,
263        statement: &DeleteStatement<'_>,
264    ) -> Result<Option<crate::Row>, DatabaseError> {
265        Ok(delete(
266            &*self.connection.lock().await,
267            statement.table_name,
268            statement.filters.as_deref(),
269            Some(1),
270        )?
271        .into_iter()
272        .next())
273    }
274
275    async fn exec_insert(
276        &self,
277        statement: &InsertStatement<'_>,
278    ) -> Result<crate::Row, DatabaseError> {
279        Ok(insert_and_get_row(
280            &*self.connection.lock().await,
281            statement.table_name,
282            &statement.values,
283        )?)
284    }
285
286    async fn exec_update(
287        &self,
288        statement: &UpdateStatement<'_>,
289    ) -> Result<Vec<crate::Row>, DatabaseError> {
290        Ok(update_and_get_rows(
291            &*self.connection.lock().await,
292            statement.table_name,
293            &statement.values,
294            statement.filters.as_deref(),
295            statement.limit,
296        )?)
297    }
298
299    async fn exec_update_first(
300        &self,
301        statement: &UpdateStatement<'_>,
302    ) -> Result<Option<crate::Row>, DatabaseError> {
303        Ok(update_and_get_row(
304            &*self.connection.lock().await,
305            statement.table_name,
306            &statement.values,
307            statement.filters.as_deref(),
308            statement.limit,
309        )?)
310    }
311
312    async fn exec_upsert(
313        &self,
314        statement: &UpsertStatement<'_>,
315    ) -> Result<Vec<crate::Row>, DatabaseError> {
316        Ok(upsert(
317            &*self.connection.lock().await,
318            statement.table_name,
319            &statement.values,
320            statement.filters.as_deref(),
321            statement.limit,
322        )?)
323    }
324
325    async fn exec_upsert_first(
326        &self,
327        statement: &UpsertStatement<'_>,
328    ) -> Result<crate::Row, DatabaseError> {
329        Ok(upsert_and_get_row(
330            &*self.connection.lock().await,
331            statement.table_name,
332            &statement.values,
333            statement.filters.as_deref(),
334            statement.limit,
335        )?)
336    }
337
338    async fn exec_upsert_multi(
339        &self,
340        statement: &UpsertMultiStatement<'_>,
341    ) -> Result<Vec<crate::Row>, DatabaseError> {
342        Ok(upsert_multi(
343            &*self.connection.lock().await,
344            statement.table_name,
345            statement
346                .unique
347                .as_ref()
348                .ok_or(RusqliteDatabaseError::MissingUnique)?,
349            &statement.values,
350        )?)
351    }
352
353    async fn exec_raw(&self, statement: &str) -> Result<(), DatabaseError> {
354        log::trace!("exec_raw: query:\n{statement}");
355
356        self.connection
357            .lock()
358            .await
359            .execute_batch(statement)
360            .map_err(RusqliteDatabaseError::Rusqlite)?;
361        Ok(())
362    }
363
364    #[cfg(feature = "schema")]
365    #[allow(clippy::too_many_lines)]
366    async fn exec_create_table(
367        &self,
368        statement: &crate::schema::CreateTableStatement<'_>,
369    ) -> Result<(), DatabaseError> {
370        let mut query = "CREATE TABLE ".to_string();
371
372        if statement.if_not_exists {
373            query.push_str("IF NOT EXISTS ");
374        }
375
376        query.push_str(statement.table_name);
377        query.push('(');
378
379        let mut first = true;
380
381        for column in &statement.columns {
382            if first {
383                first = false;
384            } else {
385                query.push(',');
386            }
387
388            if column.auto_increment && statement.primary_key.is_none_or(|x| x != column.name) {
389                return Err(DatabaseError::InvalidSchema(format!(
390                    "Column '{}' must be the primary key to enable auto increment",
391                    &column.name
392                )));
393            }
394
395            query.push_str(&column.name);
396            query.push(' ');
397
398            match column.data_type {
399                crate::schema::DataType::VarChar(size) => {
400                    query.push_str("VARCHAR(");
401                    query.push_str(&size.to_string());
402                    query.push(')');
403                }
404                crate::schema::DataType::Text => query.push_str("TEXT"),
405                crate::schema::DataType::Bool
406                | crate::schema::DataType::SmallInt
407                | crate::schema::DataType::Int
408                | crate::schema::DataType::BigInt => {
409                    query.push_str("INTEGER");
410                }
411                crate::schema::DataType::Double
412                | crate::schema::DataType::Decimal(..)
413                | crate::schema::DataType::Real => query.push_str("REAL"),
414                crate::schema::DataType::DateTime => query.push_str("VARCHAR(23)"),
415            }
416
417            if !column.nullable {
418                query.push_str(" NOT NULL");
419            }
420
421            if let Some(default) = &column.default {
422                query.push_str(" DEFAULT ");
423
424                match default {
425                    DatabaseValue::Null
426                    | DatabaseValue::StringOpt(None)
427                    | DatabaseValue::BoolOpt(None)
428                    | DatabaseValue::NumberOpt(None)
429                    | DatabaseValue::UNumberOpt(None)
430                    | DatabaseValue::RealOpt(None) => {
431                        query.push_str("NULL");
432                    }
433                    DatabaseValue::StringOpt(Some(x)) | DatabaseValue::String(x) => {
434                        query.push('\'');
435                        query.push_str(x);
436                        query.push('\'');
437                    }
438                    DatabaseValue::BoolOpt(Some(x)) | DatabaseValue::Bool(x) => {
439                        query.push_str(if *x { "1" } else { "0" });
440                    }
441                    DatabaseValue::NumberOpt(Some(x)) | DatabaseValue::Number(x) => {
442                        query.push_str(&x.to_string());
443                    }
444                    DatabaseValue::UNumberOpt(Some(x)) | DatabaseValue::UNumber(x) => {
445                        query.push_str(&x.to_string());
446                    }
447                    DatabaseValue::RealOpt(Some(x)) | DatabaseValue::Real(x) => {
448                        query.push_str(&x.to_string());
449                    }
450                    DatabaseValue::NowAdd(x) => {
451                        query.push_str(
452                            "(strftime('%Y-%m-%dT%H:%M:%f', DateTime('now', 'LocalTime', ",
453                        );
454                        query.push_str(x);
455                        query.push_str(")))");
456                    }
457                    DatabaseValue::Now => {
458                        query.push_str("(strftime('%Y-%m-%dT%H:%M:%f', 'now'))");
459                    }
460                    DatabaseValue::DateTime(x) => {
461                        query.push('\'');
462                        query.push_str(&x.and_utc().to_rfc3339());
463                        query.push('\'');
464                    }
465                }
466            }
467        }
468
469        moosicbox_assert::assert!(!first);
470
471        if let Some(primary_key) = &statement.primary_key {
472            query.push_str(", PRIMARY KEY (");
473            query.push_str(primary_key);
474            query.push(')');
475        }
476
477        for (source, target) in &statement.foreign_keys {
478            query.push_str(", FOREIGN KEY (");
479            query.push_str(source);
480            query.push_str(") REFERENCES (");
481            query.push_str(target);
482            query.push(')');
483        }
484
485        query.push(')');
486
487        self.exec_raw(&query).await?;
488
489        Ok(())
490    }
491}
492
493impl From<Value> for DatabaseValue {
494    fn from(value: Value) -> Self {
495        match value {
496            Value::Null => Self::Null,
497            Value::Integer(value) => Self::Number(value),
498            Value::Real(value) => Self::Real(value),
499            Value::Text(value) => Self::String(value),
500            Value::Blob(_value) => unimplemented!("Blob types are not supported yet"),
501        }
502    }
503}
504
505fn from_row(column_names: &[String], row: &Row<'_>) -> Result<crate::Row, RusqliteDatabaseError> {
506    let mut columns = vec![];
507
508    for column in column_names {
509        columns.push((
510            column.to_string(),
511            row.get::<_, Value>(column.as_str())?.into(),
512        ));
513    }
514
515    Ok(crate::Row { columns })
516}
517
518fn update_and_get_row(
519    connection: &Connection,
520    table_name: &str,
521    values: &[(&str, Box<dyn Expression>)],
522    filters: Option<&[Box<dyn BooleanExpression>]>,
523    limit: Option<usize>,
524) -> Result<Option<crate::Row>, RusqliteDatabaseError> {
525    let select_query = limit.map(|_| {
526        format!(
527            "SELECT rowid FROM {table_name} {}",
528            build_where_clause(filters),
529        )
530    });
531
532    let query = format!(
533        "UPDATE {table_name} {} {} RETURNING *",
534        build_set_clause(values),
535        build_update_where_clause(filters, limit, select_query.as_deref()),
536    );
537
538    let all_values = values
539        .iter()
540        .flat_map(|(_, value)| value.params().unwrap_or(vec![]).into_iter().cloned())
541        .map(std::convert::Into::into)
542        .collect::<Vec<_>>();
543    let mut all_filter_values = filters
544        .map(|filters| {
545            filters
546                .iter()
547                .flat_map(|value| value.params().unwrap_or_default().into_iter().cloned())
548                .map(std::convert::Into::into)
549                .collect::<Vec<_>>()
550        })
551        .unwrap_or_default();
552
553    if limit.is_some() {
554        all_filter_values.extend(all_filter_values.clone());
555    }
556
557    let all_values = [all_values, all_filter_values].concat();
558
559    log::trace!("Running update query: {query} with params: {all_values:?}");
560
561    let mut statement = connection.prepare_cached(&query)?;
562
563    bind_values(&mut statement, Some(&all_values), false, 0)?;
564
565    let column_names = statement
566        .column_names()
567        .iter()
568        .map(std::string::ToString::to_string)
569        .collect::<Vec<_>>();
570
571    let mut query = statement.raw_query();
572
573    query
574        .next()?
575        .map(|row| from_row(&column_names, row))
576        .transpose()
577}
578
579fn update_and_get_rows(
580    connection: &Connection,
581    table_name: &str,
582    values: &[(&str, Box<dyn Expression>)],
583    filters: Option<&[Box<dyn BooleanExpression>]>,
584    limit: Option<usize>,
585) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
586    let select_query = limit.map(|_| {
587        format!(
588            "SELECT rowid FROM {table_name} {}",
589            build_where_clause(filters),
590        )
591    });
592
593    let query = format!(
594        "UPDATE {table_name} {} {} RETURNING *",
595        build_set_clause(values),
596        build_update_where_clause(filters, limit, select_query.as_deref()),
597    );
598
599    let all_values = values
600        .iter()
601        .flat_map(|(_, value)| value.params().unwrap_or(vec![]).into_iter().cloned())
602        .map(std::convert::Into::into)
603        .collect::<Vec<_>>();
604    let mut all_filter_values = filters
605        .map(|filters| {
606            filters
607                .iter()
608                .flat_map(|value| value.params().unwrap_or_default().into_iter().cloned())
609                .map(std::convert::Into::into)
610                .collect::<Vec<_>>()
611        })
612        .unwrap_or_default();
613
614    if limit.is_some() {
615        all_filter_values.extend(all_filter_values.clone());
616    }
617
618    let all_values = [all_values, all_filter_values].concat();
619
620    log::trace!("Running update query: {query} with params: {all_values:?}");
621
622    let mut statement = connection.prepare_cached(&query)?;
623    bind_values(&mut statement, Some(&all_values), false, 0)?;
624    let column_names = statement
625        .column_names()
626        .iter()
627        .map(std::string::ToString::to_string)
628        .collect::<Vec<_>>();
629
630    to_rows(&column_names, statement.raw_query())
631}
632
633fn build_join_clauses(joins: Option<&[Join]>) -> String {
634    joins.map_or_else(String::new, |joins| {
635        joins
636            .iter()
637            .map(|join| {
638                format!(
639                    "{}JOIN {} ON {}",
640                    if join.left { "LEFT " } else { "" },
641                    join.table_name,
642                    join.on
643                )
644            })
645            .collect::<Vec<_>>()
646            .join(" ")
647    })
648}
649
650fn build_where_clause(filters: Option<&[Box<dyn BooleanExpression>]>) -> String {
651    filters.map_or_else(String::new, |filters| {
652        if filters.is_empty() {
653            String::new()
654        } else {
655            format!("WHERE {}", build_where_props(filters).join(" AND "))
656        }
657    })
658}
659
660fn build_where_props(filters: &[Box<dyn BooleanExpression>]) -> Vec<String> {
661    filters
662        .iter()
663        .map(|filter| filter.deref().to_sql())
664        .collect()
665}
666
667fn build_sort_clause(sorts: Option<&[Sort]>) -> String {
668    sorts.map_or_else(String::new, |sorts| {
669        if sorts.is_empty() {
670            String::new()
671        } else {
672            format!("ORDER BY {}", build_sort_props(sorts).join(", "))
673        }
674    })
675}
676
677fn build_sort_props(sorts: &[Sort]) -> Vec<String> {
678    sorts.iter().map(Sort::to_sql).collect()
679}
680
681fn build_update_where_clause(
682    filters: Option<&[Box<dyn BooleanExpression>]>,
683    limit: Option<usize>,
684    query: Option<&str>,
685) -> String {
686    let clause = build_where_clause(filters);
687    let limit_clause = build_update_limit_clause(limit, query);
688
689    let clause = if limit_clause.is_empty() {
690        clause
691    } else if clause.is_empty() {
692        "WHERE".into()
693    } else {
694        clause + " AND"
695    };
696
697    format!("{clause} {limit_clause}").trim().to_string()
698}
699
/// Builds `rowid IN (<query> LIMIT <limit>)` when both a limit and a rowid
/// sub-select are supplied; otherwise returns an empty string.
fn build_update_limit_clause(limit: Option<usize>, query: Option<&str>) -> String {
    match (limit, query) {
        (Some(limit), Some(query)) => format!("rowid IN ({query} LIMIT {limit})"),
        _ => String::new(),
    }
}
707
708fn build_set_clause(values: &[(&str, Box<dyn Expression>)]) -> String {
709    if values.is_empty() {
710        String::new()
711    } else {
712        format!("SET {}", build_set_props(values).join(", "))
713    }
714}
715
716fn build_set_props(values: &[(&str, Box<dyn Expression>)]) -> Vec<String> {
717    values
718        .iter()
719        .map(|(name, value)| format!("{name}=({})", value.deref().to_sql()))
720        .collect()
721}
722
723fn build_values_clause(values: &[(&str, Box<dyn Expression>)]) -> String {
724    if values.is_empty() {
725        "DEFAULT VALUES".to_string()
726    } else {
727        format!("VALUES({})", build_values_props(values).join(", "))
728    }
729}
730
731fn build_values_props(values: &[(&str, Box<dyn Expression>)]) -> Vec<String> {
732    values
733        .iter()
734        .map(|(_, value)| value.deref().to_sql())
735        .collect()
736}
737
738fn bind_values(
739    statement: &mut Statement<'_>,
740    values: Option<&[RusqliteDatabaseValue]>,
741    constant_inc: bool,
742    offset: usize,
743) -> Result<usize, RusqliteDatabaseError> {
744    if let Some(values) = values {
745        let mut i = 1 + offset;
746        for value in values {
747            match &**value {
748                DatabaseValue::String(value) => {
749                    statement.raw_bind_parameter(i, value)?;
750                    if !constant_inc {
751                        i += 1;
752                    }
753                }
754                DatabaseValue::StringOpt(Some(value)) => {
755                    statement.raw_bind_parameter(i, value)?;
756                    if !constant_inc {
757                        i += 1;
758                    }
759                }
760                DatabaseValue::Null
761                | DatabaseValue::StringOpt(None)
762                | DatabaseValue::BoolOpt(None)
763                | DatabaseValue::NumberOpt(None)
764                | DatabaseValue::UNumberOpt(None)
765                | DatabaseValue::RealOpt(None)
766                | DatabaseValue::Now => (),
767                DatabaseValue::NowAdd(_add) => (),
768                DatabaseValue::Bool(value) => {
769                    statement.raw_bind_parameter(i, i32::from(*value))?;
770                    if !constant_inc {
771                        i += 1;
772                    }
773                }
774                DatabaseValue::BoolOpt(Some(value)) => {
775                    statement.raw_bind_parameter(i, value)?;
776                    if !constant_inc {
777                        i += 1;
778                    }
779                }
780                DatabaseValue::Number(value) => {
781                    statement.raw_bind_parameter(i, *value)?;
782                    if !constant_inc {
783                        i += 1;
784                    }
785                }
786                DatabaseValue::NumberOpt(Some(value)) => {
787                    statement.raw_bind_parameter(i, *value)?;
788                    if !constant_inc {
789                        i += 1;
790                    }
791                }
792                DatabaseValue::UNumber(value) => {
793                    statement.raw_bind_parameter(i, *value)?;
794                    if !constant_inc {
795                        i += 1;
796                    }
797                }
798                DatabaseValue::UNumberOpt(Some(value)) => {
799                    statement.raw_bind_parameter(i, *value)?;
800                    if !constant_inc {
801                        i += 1;
802                    }
803                }
804                DatabaseValue::Real(value) => {
805                    statement.raw_bind_parameter(i, *value)?;
806                    if !constant_inc {
807                        i += 1;
808                    }
809                }
810                DatabaseValue::RealOpt(Some(value)) => {
811                    statement.raw_bind_parameter(i, *value)?;
812                    if !constant_inc {
813                        i += 1;
814                    }
815                }
816                DatabaseValue::DateTime(value) => {
817                    // FIXME: Actually format the date
818                    statement.raw_bind_parameter(i, value.to_string())?;
819                    if !constant_inc {
820                        i += 1;
821                    }
822                }
823            }
824            if constant_inc {
825                i += 1;
826            }
827        }
828        Ok(i - 1)
829    } else {
830        Ok(0)
831    }
832}
833
834fn to_rows(
835    column_names: &[String],
836    mut rows: Rows<'_>,
837) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
838    let mut results = vec![];
839
840    while let Some(row) = rows.next()? {
841        results.push(from_row(column_names, row)?);
842    }
843
844    log::trace!(
845        "Got {} row{}",
846        results.len(),
847        if results.len() == 1 { "" } else { "s" }
848    );
849
850    Ok(results)
851}
852
853fn to_values(values: &[(&str, DatabaseValue)]) -> Vec<RusqliteDatabaseValue> {
854    values
855        .iter()
856        .map(|(_key, value)| value.clone())
857        .map(std::convert::Into::into)
858        .collect::<Vec<_>>()
859}
860
861fn exprs_to_values(values: &[(&str, Box<dyn Expression>)]) -> Vec<RusqliteDatabaseValue> {
862    values
863        .iter()
864        .flat_map(|value| value.1.values().into_iter())
865        .flatten()
866        .cloned()
867        .map(std::convert::Into::into)
868        .collect::<Vec<_>>()
869}
870
871fn bexprs_to_values(values: &[Box<dyn BooleanExpression>]) -> Vec<RusqliteDatabaseValue> {
872    values
873        .iter()
874        .flat_map(|value| value.values().into_iter())
875        .flatten()
876        .cloned()
877        .map(std::convert::Into::into)
878        .collect::<Vec<_>>()
879}
880
881#[allow(unused)]
882fn to_values_opt(values: Option<&[(&str, DatabaseValue)]>) -> Option<Vec<RusqliteDatabaseValue>> {
883    values.map(to_values)
884}
885
886#[allow(unused)]
887fn exprs_to_values_opt(
888    values: Option<&[(&str, Box<dyn Expression>)]>,
889) -> Option<Vec<RusqliteDatabaseValue>> {
890    values.map(exprs_to_values)
891}
892
893fn bexprs_to_values_opt(
894    values: Option<&[Box<dyn BooleanExpression>]>,
895) -> Option<Vec<RusqliteDatabaseValue>> {
896    values.map(bexprs_to_values)
897}
898
899#[allow(clippy::too_many_arguments)]
900fn select(
901    connection: &Connection,
902    table_name: &str,
903    distinct: bool,
904    columns: &[&str],
905    filters: Option<&[Box<dyn BooleanExpression>]>,
906    joins: Option<&[Join<'_>]>,
907    sort: Option<&[Sort]>,
908    limit: Option<usize>,
909) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
910    let query = format!(
911        "SELECT {} {} FROM {table_name} {} {} {} {}",
912        if distinct { "DISTINCT" } else { "" },
913        columns.join(", "),
914        build_join_clauses(joins),
915        build_where_clause(filters),
916        build_sort_clause(sort),
917        limit.map_or_else(String::new, |limit| format!("LIMIT {limit}"))
918    );
919
920    log::trace!(
921        "Running select query: {query} with params: {:?}",
922        filters.map(|f| f.iter().filter_map(|x| x.params()).collect::<Vec<_>>())
923    );
924
925    let mut statement = connection.prepare_cached(&query)?;
926    let column_names = statement
927        .column_names()
928        .iter()
929        .map(std::string::ToString::to_string)
930        .collect::<Vec<_>>();
931
932    bind_values(
933        &mut statement,
934        bexprs_to_values_opt(filters).as_deref(),
935        false,
936        0,
937    )?;
938
939    to_rows(&column_names, statement.raw_query())
940}
941
942fn delete(
943    connection: &Connection,
944    table_name: &str,
945    filters: Option<&[Box<dyn BooleanExpression>]>,
946    limit: Option<usize>,
947) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
948    let where_clause = build_where_clause(filters);
949
950    let select_query = limit.map(|_| format!("SELECT rowid FROM {table_name} {where_clause}",));
951
952    let query = format!(
953        "DELETE FROM {table_name} {} RETURNING *",
954        build_update_where_clause(filters, limit, select_query.as_deref()),
955    );
956
957    let mut all_filter_values: Vec<RusqliteDatabaseValue> = filters
958        .map(|filters| {
959            filters
960                .iter()
961                .flat_map(|value| value.params().unwrap_or_default().into_iter().cloned())
962                .map(std::convert::Into::into)
963                .collect::<Vec<_>>()
964        })
965        .unwrap_or_default();
966
967    if limit.is_some() {
968        all_filter_values.extend(all_filter_values.clone());
969    }
970
971    log::trace!(
972        "Running delete query: {query} with params: {:?}",
973        all_filter_values
974            .iter()
975            .filter_map(super::query::Expression::params)
976            .collect::<Vec<_>>()
977    );
978
979    let mut statement = connection.prepare_cached(&query)?;
980    let column_names = statement
981        .column_names()
982        .iter()
983        .map(std::string::ToString::to_string)
984        .collect::<Vec<_>>();
985
986    bind_values(&mut statement, Some(&all_filter_values), false, 0)?;
987
988    to_rows(&column_names, statement.raw_query())
989}
990
991fn find_row(
992    connection: &Connection,
993    table_name: &str,
994    distinct: bool,
995    columns: &[&str],
996    filters: Option<&[Box<dyn BooleanExpression>]>,
997    joins: Option<&[Join]>,
998    sort: Option<&[Sort]>,
999) -> Result<Option<crate::Row>, RusqliteDatabaseError> {
1000    let query = format!(
1001        "SELECT {} {} FROM {table_name} {} {} {} LIMIT 1",
1002        if distinct { "DISTINCT" } else { "" },
1003        columns.join(", "),
1004        build_join_clauses(joins),
1005        build_where_clause(filters),
1006        build_sort_clause(sort),
1007    );
1008
1009    let mut statement = connection.prepare_cached(&query)?;
1010    let column_names = statement
1011        .column_names()
1012        .iter()
1013        .map(std::string::ToString::to_string)
1014        .collect::<Vec<_>>();
1015
1016    bind_values(
1017        &mut statement,
1018        bexprs_to_values_opt(filters).as_deref(),
1019        false,
1020        0,
1021    )?;
1022
1023    log::trace!(
1024        "Running find_row query: {query} with params: {:?}",
1025        filters.map(|f| f.iter().filter_map(|x| x.params()).collect::<Vec<_>>())
1026    );
1027
1028    let mut query = statement.raw_query();
1029
1030    query
1031        .next()?
1032        .map(|row| from_row(&column_names, row))
1033        .transpose()
1034}
1035
1036fn insert_and_get_row(
1037    connection: &Connection,
1038    table_name: &str,
1039    values: &[(&str, Box<dyn Expression>)],
1040) -> Result<crate::Row, RusqliteDatabaseError> {
1041    let column_names = values
1042        .iter()
1043        .map(|(key, _v)| format!("`{key}`"))
1044        .collect::<Vec<_>>()
1045        .join(", ");
1046
1047    let insert_columns = if values.is_empty() {
1048        String::new()
1049    } else {
1050        format!("({column_names})")
1051    };
1052    let query = format!(
1053        "INSERT INTO {table_name} {insert_columns} {} RETURNING *",
1054        build_values_clause(values),
1055    );
1056
1057    let mut statement = connection.prepare_cached(&query)?;
1058    let column_names = statement
1059        .column_names()
1060        .iter()
1061        .map(std::string::ToString::to_string)
1062        .collect::<Vec<_>>();
1063
1064    bind_values(&mut statement, Some(&exprs_to_values(values)), false, 0)?;
1065
1066    log::trace!(
1067        "Running insert_and_get_row query: {query} with params: {:?}",
1068        values
1069            .iter()
1070            .filter_map(|(_, x)| x.params())
1071            .collect::<Vec<_>>()
1072    );
1073
1074    let mut query = statement.raw_query();
1075
1076    query
1077        .next()?
1078        .map(|row| from_row(&column_names, row))
1079        .ok_or(RusqliteDatabaseError::NoRow)?
1080}
1081
/// Applies several column-assignment sets to `table_name`. `values` is split
/// into consecutive chunks, each sent to `update_chunk` with the same
/// `filters`, and the rows returned by every chunk are concatenated.
///
/// Returns an empty vector without touching the database when `values` is
/// empty.
///
/// NOTE(review): the chunk budget counts columns (`value.len()`), not the
/// number of parameters each expression actually binds.
///
/// # Errors
///
/// Will return `Err` if the update multi execution failed.
pub fn update_multi(
    connection: &Connection,
    table_name: &str,
    values: &[Vec<(&str, Box<dyn Expression>)>],
    filters: Option<&[Box<dyn BooleanExpression>]>,
    mut limit: Option<usize>,
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
    let mut results = vec![];

    if values.is_empty() {
        return Ok(results);
    }

    // `pos`: columns accumulated in the pending chunk; `last_i..i`: the
    // pending chunk's row range within `values`.
    let mut pos = 0;
    let mut i = 0;
    let mut last_i = i;

    for value in values {
        let count = value.len();
        // Flush the pending rows before this one would reach `i16::MAX - 1`
        // (32766), presumably SQLite's default SQLITE_MAX_VARIABLE_NUMBER.
        // NOTE(review): if the very first row alone reaches the budget,
        // `values[last_i..i]` is empty and `update_chunk` panics on
        // `values[0]`.
        if pos + count >= (i16::MAX - 1) as usize {
            results.append(&mut update_chunk(
                connection,
                table_name,
                &values[last_i..i],
                filters,
                limit,
            )?);
            last_i = i;
            pos = 0;
        }
        i += 1;
        pos += count;

        // NOTE(review): `limit` is reduced by this row's column count rather
        // than by a row count. When it runs out, this returns immediately
        // without flushing the pending `values[last_i..i]` chunk, so those
        // updates are silently skipped. The `limit` passed to each chunk is
        // also the value left *after* that chunk's own rows were subtracted.
        // Confirm the intended semantics before relying on `limit` here.
        if let Some(value) = limit {
            if count >= value {
                return Ok(results);
            }

            limit.replace(value - count);
        }
    }

    // Flush whatever rows are still pending after the loop.
    if i > last_i {
        results.append(&mut update_chunk(
            connection,
            table_name,
            &values[last_i..],
            filters,
            limit,
        )?);
    }

    Ok(results)
}
1139
/// Builds and runs a single `UPDATE ... RETURNING *` statement for a chunk of
/// column-assignment rows. Every row must share the first row's column layout.
///
/// Bound parameters are the `params()` of every row value in order, followed
/// by the filter params. The filter params are duplicated when `limit` is set,
/// as in `delete`.
///
/// NOTE(review): the generated statement does not match SQLite's `UPDATE`
/// syntax:
/// - it puts a column list after the table name;
/// - it places the `build_update_where_clause` output before `SET`;
/// - it assigns from `EXCLUDED.*`, which SQLite only defines inside an
///   upsert's `DO UPDATE`.
///
/// The row values are also bound without any placeholder for them in the
/// query text. Expect `prepare_cached`/`bind_values` to fail; verify before
/// relying on this path.
///
/// # Errors
///
/// Returns `RusqliteDatabaseError::InvalidRequest` when a row's column names
/// differ from the first row's. Propagates prepare/bind/row-read errors.
///
/// # Panics
///
/// Panics if `values` is empty, because it indexes `values[0]`.
fn update_chunk(
    connection: &Connection,
    table_name: &str,
    values: &[Vec<(&str, Box<dyn Expression>)>],
    filters: Option<&[Box<dyn BooleanExpression>]>,
    limit: Option<usize>,
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
    let first = values[0].as_slice();
    let expected_value_size = first.len();

    // Every row must have the same column names, in the same order, as the
    // first row (the length check short-circuits before indexing `first`).
    if let Some(bad_row) = values.iter().skip(1).find(|v| {
        v.len() != expected_value_size || v.iter().enumerate().any(|(i, c)| c.0 != first[i].0)
    }) {
        log::error!("Bad row: {bad_row:?}. Expected to match schema of first row: {first:?}");
        return Err(RusqliteDatabaseError::InvalidRequest);
    }

    let set_clause = values[0]
        .iter()
        .map(|(name, _value)| format!("`{name}` = EXCLUDED.`{name}`"))
        .collect::<Vec<_>>()
        .join(", ");

    let column_names = values[0]
        .iter()
        .map(|(key, _v)| format!("`{key}`"))
        .collect::<Vec<_>>()
        .join(", ");

    // With a limit, a rowid subquery over the same filters is passed to
    // `build_update_where_clause`.
    let select_query = limit.map(|_| {
        format!(
            "SELECT rowid FROM {table_name} {}",
            build_where_clause(filters),
        )
    });

    let query = format!(
        "
        UPDATE {table_name} ({column_names})
        {}
        SET {set_clause}
        RETURNING *",
        build_update_where_clause(filters, limit, select_query.as_deref()),
    );

    // Row values first, then filter values, matching the order concatenated
    // below.
    let all_values = values
        .iter()
        .flat_map(std::iter::IntoIterator::into_iter)
        .flat_map(|(_, value)| value.params().unwrap_or(vec![]).into_iter().cloned())
        .map(std::convert::Into::into)
        .collect::<Vec<_>>();
    let mut all_filter_values = filters
        .as_ref()
        .map(|filters| {
            filters
                .iter()
                .flat_map(|value| {
                    value
                        .params()
                        .unwrap_or_default()
                        .into_iter()
                        .cloned()
                        .map(std::convert::Into::into)
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>()
        })
        .unwrap_or_default();

    // The limited form repeats the filter conditions, so repeat their values.
    if limit.is_some() {
        all_filter_values.extend(all_filter_values.clone());
    }

    let all_values = [all_values, all_filter_values].concat();

    log::trace!("Running update chunk query: {query} with params: {all_values:?}");

    let mut statement = connection.prepare_cached(&query)?;
    let column_names = statement
        .column_names()
        .iter()
        .map(std::string::ToString::to_string)
        .collect::<Vec<_>>();

    bind_values(&mut statement, Some(&all_values), true, 0)?;

    to_rows(&column_names, statement.raw_query())
}
1228
1229/// # Errors
1230///
1231/// Will return `Err` if the upsert multi execution failed.
1232pub fn upsert_multi(
1233    connection: &Connection,
1234    table_name: &str,
1235    unique: &[Box<dyn Expression>],
1236    values: &[Vec<(&str, Box<dyn Expression>)>],
1237) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
1238    let mut results = vec![];
1239
1240    if values.is_empty() {
1241        return Ok(results);
1242    }
1243
1244    let mut pos = 0;
1245    let mut i = 0;
1246    let mut last_i = i;
1247
1248    for value in values {
1249        let count = value.len();
1250        if pos + count >= (i16::MAX - 1) as usize {
1251            results.append(&mut upsert_chunk(
1252                connection,
1253                table_name,
1254                unique,
1255                &values[last_i..i],
1256            )?);
1257            last_i = i;
1258            pos = 0;
1259        }
1260        i += 1;
1261        pos += count;
1262    }
1263
1264    if i > last_i {
1265        results.append(&mut upsert_chunk(
1266            connection,
1267            table_name,
1268            unique,
1269            &values[last_i..],
1270        )?);
1271    }
1272
1273    Ok(results)
1274}
1275
/// Builds and runs one `INSERT ... ON CONFLICT(...) DO UPDATE ... RETURNING *`
/// statement for a chunk of rows, and returns the rows SQLite reports.
///
/// Every row must use the same column names, in the same order, as the first
/// row. The conflict target is rendered from `unique` via `to_sql`, and every
/// column is overwritten from `EXCLUDED` on conflict. The bound parameters are
/// the `params()` of every row value, in row order.
///
/// # Errors
///
/// Returns `RusqliteDatabaseError::InvalidRequest` when a row's column names
/// differ from the first row's. Propagates prepare/bind/row-read errors.
///
/// # Panics
///
/// Panics if `values` is empty, because it indexes `values[0]`. Callers must
/// pass a non-empty chunk.
fn upsert_chunk(
    connection: &Connection,
    table_name: &str,
    unique: &[Box<dyn Expression>],
    values: &[Vec<(&str, Box<dyn Expression>)>],
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
    let first = values[0].as_slice();
    let expected_value_size = first.len();

    // Every row must have the same column names, in the same order, as the
    // first row (the length check short-circuits before indexing `first`).
    if let Some(bad_row) = values.iter().skip(1).find(|v| {
        v.len() != expected_value_size || v.iter().enumerate().any(|(i, c)| c.0 != first[i].0)
    }) {
        log::error!("Bad row: {bad_row:?}. Expected to match schema of first row: {first:?}");
        return Err(RusqliteDatabaseError::InvalidRequest);
    }

    let set_clause = values[0]
        .iter()
        .map(|(name, _value)| format!("`{name}` = EXCLUDED.`{name}`"))
        .collect::<Vec<_>>()
        .join(", ");

    let column_names = values[0]
        .iter()
        .map(|(key, _v)| format!("`{key}`"))
        .collect::<Vec<_>>()
        .join(", ");

    // One parenthesised tuple per row.
    let values_str_list = values
        .iter()
        .map(|v| format!("({})", build_values_props(v).join(", ")))
        .collect::<Vec<_>>();

    // NOTE(review): each row contributes at least "()", so for a non-empty
    // `values` this string is never empty and the `DEFAULT VALUES` branch is
    // unreachable.
    let values_str = values_str_list.join(", ");
    let values_str = if values_str.is_empty() {
        "DEFAULT VALUES".to_string()
    } else {
        format!("VALUES {values_str}")
    };

    let unique_conflict = unique
        .iter()
        .map(|x| x.to_sql())
        .collect::<Vec<_>>()
        .join(", ");

    // NOTE(review): `values` cannot be empty here (`values[0]` was indexed
    // above), so this always takes the `else` branch. A zero-column check
    // would need `first.is_empty()` instead.
    let insert_columns = if values.is_empty() {
        String::new()
    } else {
        format!("({column_names})")
    };
    let query = format!(
        "
        INSERT INTO {table_name} {insert_columns} {values_str}
        ON CONFLICT({unique_conflict}) DO UPDATE
            SET {set_clause}
        RETURNING *"
    );

    let all_values = &values
        .iter()
        .flat_map(std::iter::IntoIterator::into_iter)
        .flat_map(|(_, value)| value.params().unwrap_or(vec![]).into_iter().cloned())
        .map(std::convert::Into::into)
        .collect::<Vec<_>>();

    log::trace!("Running upsert chunk query: {query} with params: {all_values:?}");

    let mut statement = connection.prepare_cached(&query)?;
    let column_names = statement
        .column_names()
        .iter()
        .map(std::string::ToString::to_string)
        .collect::<Vec<_>>();

    bind_values(&mut statement, Some(all_values), true, 0)?;

    to_rows(&column_names, statement.raw_query())
}
1355
1356fn upsert(
1357    connection: &Connection,
1358    table_name: &str,
1359    values: &[(&str, Box<dyn Expression>)],
1360    filters: Option<&[Box<dyn BooleanExpression>]>,
1361    limit: Option<usize>,
1362) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
1363    let rows = update_and_get_rows(connection, table_name, values, filters, limit)?;
1364
1365    Ok(if rows.is_empty() {
1366        vec![insert_and_get_row(connection, table_name, values)?]
1367    } else {
1368        rows
1369    })
1370}
1371
1372#[allow(unused)]
1373fn upsert_and_get_row(
1374    connection: &Connection,
1375    table_name: &str,
1376    values: &[(&str, Box<dyn Expression>)],
1377    filters: Option<&[Box<dyn BooleanExpression>]>,
1378    limit: Option<usize>,
1379) -> Result<crate::Row, RusqliteDatabaseError> {
1380    match find_row(connection, table_name, false, &["*"], filters, None, None)? {
1381        Some(row) => {
1382            let updated =
1383                update_and_get_row(connection, table_name, values, filters, limit)?.unwrap();
1384
1385            let str1 = format!("{row:?}");
1386            let str2 = format!("{updated:?}");
1387
1388            if str1 == str2 {
1389                log::trace!("No updates to {table_name}");
1390            } else {
1391                log::debug!("Changed {table_name} from {str1} to {str2}");
1392            }
1393
1394            Ok(updated)
1395        }
1396        None => Ok(insert_and_get_row(connection, table_name, values)?),
1397    }
1398}
1399
1400#[allow(clippy::module_name_repetitions)]
1401#[derive(Debug, Clone)]
1402pub struct RusqliteDatabaseValue(DatabaseValue);
1403
1404impl From<DatabaseValue> for RusqliteDatabaseValue {
1405    fn from(value: DatabaseValue) -> Self {
1406        Self(value)
1407    }
1408}
1409
1410impl Deref for RusqliteDatabaseValue {
1411    type Target = DatabaseValue;
1412
1413    fn deref(&self) -> &Self::Target {
1414        &self.0
1415    }
1416}
1417
1418impl Expression for RusqliteDatabaseValue {
1419    fn values(&self) -> Option<Vec<&DatabaseValue>> {
1420        Some(vec![self])
1421    }
1422
1423    fn is_null(&self) -> bool {
1424        matches!(
1425            self.0,
1426            DatabaseValue::Null
1427                | DatabaseValue::BoolOpt(None)
1428                | DatabaseValue::RealOpt(None)
1429                | DatabaseValue::StringOpt(None)
1430                | DatabaseValue::NumberOpt(None)
1431                | DatabaseValue::UNumberOpt(None)
1432        )
1433    }
1434
1435    fn expression_type(&self) -> ExpressionType {
1436        ExpressionType::DatabaseValue(self)
1437    }
1438}