1use std::{ops::Deref, sync::Arc};
2
3use async_trait::async_trait;
4use rusqlite::{Connection, Row, Rows, Statement, types::Value};
5use thiserror::Error;
6use tokio::sync::Mutex;
7
8use crate::{
9 Database, DatabaseError, DatabaseValue, DeleteStatement, InsertStatement, SelectQuery,
10 UpdateStatement, UpsertMultiStatement, UpsertStatement,
11 query::{BooleanExpression, Expression, ExpressionType, Join, Sort, SortDirection},
12};
13
14#[allow(clippy::module_name_repetitions)]
15#[derive(Debug)]
16pub struct RusqliteDatabase {
17 connection: Arc<Mutex<Connection>>,
18}
19
20impl RusqliteDatabase {
21 pub const fn new(connection: Arc<Mutex<Connection>>) -> Self {
22 Self { connection }
23 }
24}
25
26trait ToSql {
27 fn to_sql(&self) -> String;
28}
29
30impl<T: Expression + ?Sized> ToSql for T {
31 #[allow(clippy::too_many_lines)]
32 fn to_sql(&self) -> String {
33 match self.expression_type() {
34 ExpressionType::Eq(value) => {
35 if value.right.is_null() {
36 format!("({} IS {})", value.left.to_sql(), value.right.to_sql())
37 } else {
38 format!("({} = {})", value.left.to_sql(), value.right.to_sql())
39 }
40 }
41 ExpressionType::Gt(value) => {
42 if value.right.is_null() {
43 panic!("Invalid > comparison with NULL");
44 } else {
45 format!("({} > {})", value.left.to_sql(), value.right.to_sql())
46 }
47 }
48 ExpressionType::In(value) => {
49 format!("{} IN ({})", value.left.to_sql(), value.values.to_sql())
50 }
51 ExpressionType::NotIn(value) => {
52 format!("{} NOT IN ({})", value.left.to_sql(), value.values.to_sql())
53 }
54 ExpressionType::Lt(value) => {
55 if value.right.is_null() {
56 panic!("Invalid < comparison with NULL");
57 } else {
58 format!("({} < {})", value.left.to_sql(), value.right.to_sql())
59 }
60 }
61 ExpressionType::Or(value) => format!(
62 "({})",
63 value
64 .conditions
65 .iter()
66 .map(|x| x.to_sql())
67 .collect::<Vec<_>>()
68 .join(" OR ")
69 ),
70 ExpressionType::And(value) => format!(
71 "({})",
72 value
73 .conditions
74 .iter()
75 .map(|x| x.to_sql())
76 .collect::<Vec<_>>()
77 .join(" AND ")
78 ),
79 ExpressionType::Gte(value) => {
80 if value.right.is_null() {
81 panic!("Invalid >= comparison with NULL");
82 } else {
83 format!("({} >= {})", value.left.to_sql(), value.right.to_sql())
84 }
85 }
86 ExpressionType::Lte(value) => {
87 if value.right.is_null() {
88 panic!("Invalid <= comparison with NULL");
89 } else {
90 format!("({} <= {})", value.left.to_sql(), value.right.to_sql())
91 }
92 }
93 ExpressionType::Join(value) => format!(
94 "{} JOIN {} ON {}",
95 if value.left { "LEFT" } else { "" },
96 value.table_name,
97 value.on
98 ),
99 ExpressionType::Sort(value) => format!(
100 "({}) {}",
101 value.expression.to_sql(),
102 match value.direction {
103 SortDirection::Asc => "ASC",
104 SortDirection::Desc => "DESC",
105 }
106 ),
107 ExpressionType::NotEq(value) => {
108 if value.right.is_null() {
109 format!("({} IS NOT {})", value.left.to_sql(), value.right.to_sql())
110 } else {
111 format!("({} != {})", value.left.to_sql(), value.right.to_sql())
112 }
113 }
114 ExpressionType::InList(value) => value
115 .values
116 .iter()
117 .map(|value| value.to_sql())
118 .collect::<Vec<_>>()
119 .join(","),
120 ExpressionType::Coalesce(value) => format!(
121 "IFNULL({})",
122 value
123 .values
124 .iter()
125 .map(|value| value.to_sql())
126 .collect::<Vec<_>>()
127 .join(",")
128 ),
129 ExpressionType::Literal(value) => value.value.to_string(),
130 ExpressionType::Identifier(value) => value.value.clone(),
131 ExpressionType::SelectQuery(value) => {
132 let joins = value.joins.as_ref().map_or_else(String::new, |joins| {
133 joins.iter().map(Join::to_sql).collect::<Vec<_>>().join(" ")
134 });
135
136 let where_clause = value.filters.as_ref().map_or_else(String::new, |filters| {
137 if filters.is_empty() {
138 String::new()
139 } else {
140 format!(
141 "WHERE {}",
142 filters
143 .iter()
144 .map(|x| format!("({})", x.to_sql()))
145 .collect::<Vec<_>>()
146 .join(" AND ")
147 )
148 }
149 });
150
151 let sort_clause = value.sorts.as_ref().map_or_else(String::new, |sorts| {
152 if sorts.is_empty() {
153 String::new()
154 } else {
155 format!(
156 "ORDER BY {}",
157 sorts
158 .iter()
159 .map(Sort::to_sql)
160 .collect::<Vec<_>>()
161 .join(", ")
162 )
163 }
164 });
165
166 let limit = value
167 .limit
168 .map_or_else(String::new, |limit| format!("LIMIT {limit}"));
169
170 format!(
171 "SELECT {} {} FROM {} {} {} {} {}",
172 if value.distinct { "DISTINCT" } else { "" },
173 value.columns.join(", "),
174 value.table_name,
175 joins,
176 where_clause,
177 sort_clause,
178 limit
179 )
180 }
181 ExpressionType::DatabaseValue(value) => match value {
182 DatabaseValue::Null
183 | DatabaseValue::BoolOpt(None)
184 | DatabaseValue::StringOpt(None)
185 | DatabaseValue::NumberOpt(None)
186 | DatabaseValue::UNumberOpt(None)
187 | DatabaseValue::RealOpt(None) => "NULL".to_string(),
188 DatabaseValue::Now => "strftime('%Y-%m-%dT%H:%M:%f', 'now')".to_string(),
189 DatabaseValue::NowAdd(add) => {
190 format!("strftime('%Y-%m-%dT%H:%M:%f', DateTime('now', 'LocalTime', {add}))")
191 }
192 _ => "?".to_string(),
193 },
194 }
195 }
196}
197
198#[allow(clippy::module_name_repetitions)]
199#[derive(Debug, Error)]
200pub enum RusqliteDatabaseError {
201 #[error(transparent)]
202 Rusqlite(#[from] rusqlite::Error),
203 #[error("No ID")]
204 NoId,
205 #[error("No row")]
206 NoRow,
207 #[error("Invalid request")]
208 InvalidRequest,
209 #[error("Missing unique")]
210 MissingUnique,
211}
212
213impl From<RusqliteDatabaseError> for DatabaseError {
214 fn from(value: RusqliteDatabaseError) -> Self {
215 Self::Rusqlite(value)
216 }
217}
218
219#[async_trait]
220impl Database for RusqliteDatabase {
221 async fn query(&self, query: &SelectQuery<'_>) -> Result<Vec<crate::Row>, DatabaseError> {
222 Ok(select(
223 &*self.connection.lock().await,
224 query.table_name,
225 query.distinct,
226 query.columns,
227 query.filters.as_deref(),
228 query.joins.as_deref(),
229 query.sorts.as_deref(),
230 query.limit,
231 )?)
232 }
233
234 async fn query_first(
235 &self,
236 query: &SelectQuery<'_>,
237 ) -> Result<Option<crate::Row>, DatabaseError> {
238 Ok(find_row(
239 &*self.connection.lock().await,
240 query.table_name,
241 query.distinct,
242 query.columns,
243 query.filters.as_deref(),
244 query.joins.as_deref(),
245 query.sorts.as_deref(),
246 )?)
247 }
248
249 async fn exec_delete(
250 &self,
251 statement: &DeleteStatement<'_>,
252 ) -> Result<Vec<crate::Row>, DatabaseError> {
253 Ok(delete(
254 &*self.connection.lock().await,
255 statement.table_name,
256 statement.filters.as_deref(),
257 statement.limit,
258 )?)
259 }
260
261 async fn exec_delete_first(
262 &self,
263 statement: &DeleteStatement<'_>,
264 ) -> Result<Option<crate::Row>, DatabaseError> {
265 Ok(delete(
266 &*self.connection.lock().await,
267 statement.table_name,
268 statement.filters.as_deref(),
269 Some(1),
270 )?
271 .into_iter()
272 .next())
273 }
274
275 async fn exec_insert(
276 &self,
277 statement: &InsertStatement<'_>,
278 ) -> Result<crate::Row, DatabaseError> {
279 Ok(insert_and_get_row(
280 &*self.connection.lock().await,
281 statement.table_name,
282 &statement.values,
283 )?)
284 }
285
286 async fn exec_update(
287 &self,
288 statement: &UpdateStatement<'_>,
289 ) -> Result<Vec<crate::Row>, DatabaseError> {
290 Ok(update_and_get_rows(
291 &*self.connection.lock().await,
292 statement.table_name,
293 &statement.values,
294 statement.filters.as_deref(),
295 statement.limit,
296 )?)
297 }
298
299 async fn exec_update_first(
300 &self,
301 statement: &UpdateStatement<'_>,
302 ) -> Result<Option<crate::Row>, DatabaseError> {
303 Ok(update_and_get_row(
304 &*self.connection.lock().await,
305 statement.table_name,
306 &statement.values,
307 statement.filters.as_deref(),
308 statement.limit,
309 )?)
310 }
311
312 async fn exec_upsert(
313 &self,
314 statement: &UpsertStatement<'_>,
315 ) -> Result<Vec<crate::Row>, DatabaseError> {
316 Ok(upsert(
317 &*self.connection.lock().await,
318 statement.table_name,
319 &statement.values,
320 statement.filters.as_deref(),
321 statement.limit,
322 )?)
323 }
324
325 async fn exec_upsert_first(
326 &self,
327 statement: &UpsertStatement<'_>,
328 ) -> Result<crate::Row, DatabaseError> {
329 Ok(upsert_and_get_row(
330 &*self.connection.lock().await,
331 statement.table_name,
332 &statement.values,
333 statement.filters.as_deref(),
334 statement.limit,
335 )?)
336 }
337
338 async fn exec_upsert_multi(
339 &self,
340 statement: &UpsertMultiStatement<'_>,
341 ) -> Result<Vec<crate::Row>, DatabaseError> {
342 Ok(upsert_multi(
343 &*self.connection.lock().await,
344 statement.table_name,
345 statement
346 .unique
347 .as_ref()
348 .ok_or(RusqliteDatabaseError::MissingUnique)?,
349 &statement.values,
350 )?)
351 }
352
353 async fn exec_raw(&self, statement: &str) -> Result<(), DatabaseError> {
354 log::trace!("exec_raw: query:\n{statement}");
355
356 self.connection
357 .lock()
358 .await
359 .execute_batch(statement)
360 .map_err(RusqliteDatabaseError::Rusqlite)?;
361 Ok(())
362 }
363
364 #[cfg(feature = "schema")]
365 #[allow(clippy::too_many_lines)]
366 async fn exec_create_table(
367 &self,
368 statement: &crate::schema::CreateTableStatement<'_>,
369 ) -> Result<(), DatabaseError> {
370 let mut query = "CREATE TABLE ".to_string();
371
372 if statement.if_not_exists {
373 query.push_str("IF NOT EXISTS ");
374 }
375
376 query.push_str(statement.table_name);
377 query.push('(');
378
379 let mut first = true;
380
381 for column in &statement.columns {
382 if first {
383 first = false;
384 } else {
385 query.push(',');
386 }
387
388 if column.auto_increment && statement.primary_key.is_none_or(|x| x != column.name) {
389 return Err(DatabaseError::InvalidSchema(format!(
390 "Column '{}' must be the primary key to enable auto increment",
391 &column.name
392 )));
393 }
394
395 query.push_str(&column.name);
396 query.push(' ');
397
398 match column.data_type {
399 crate::schema::DataType::VarChar(size) => {
400 query.push_str("VARCHAR(");
401 query.push_str(&size.to_string());
402 query.push(')');
403 }
404 crate::schema::DataType::Text => query.push_str("TEXT"),
405 crate::schema::DataType::Bool
406 | crate::schema::DataType::SmallInt
407 | crate::schema::DataType::Int
408 | crate::schema::DataType::BigInt => {
409 query.push_str("INTEGER");
410 }
411 crate::schema::DataType::Double
412 | crate::schema::DataType::Decimal(..)
413 | crate::schema::DataType::Real => query.push_str("REAL"),
414 crate::schema::DataType::DateTime => query.push_str("VARCHAR(23)"),
415 }
416
417 if !column.nullable {
418 query.push_str(" NOT NULL");
419 }
420
421 if let Some(default) = &column.default {
422 query.push_str(" DEFAULT ");
423
424 match default {
425 DatabaseValue::Null
426 | DatabaseValue::StringOpt(None)
427 | DatabaseValue::BoolOpt(None)
428 | DatabaseValue::NumberOpt(None)
429 | DatabaseValue::UNumberOpt(None)
430 | DatabaseValue::RealOpt(None) => {
431 query.push_str("NULL");
432 }
433 DatabaseValue::StringOpt(Some(x)) | DatabaseValue::String(x) => {
434 query.push('\'');
435 query.push_str(x);
436 query.push('\'');
437 }
438 DatabaseValue::BoolOpt(Some(x)) | DatabaseValue::Bool(x) => {
439 query.push_str(if *x { "1" } else { "0" });
440 }
441 DatabaseValue::NumberOpt(Some(x)) | DatabaseValue::Number(x) => {
442 query.push_str(&x.to_string());
443 }
444 DatabaseValue::UNumberOpt(Some(x)) | DatabaseValue::UNumber(x) => {
445 query.push_str(&x.to_string());
446 }
447 DatabaseValue::RealOpt(Some(x)) | DatabaseValue::Real(x) => {
448 query.push_str(&x.to_string());
449 }
450 DatabaseValue::NowAdd(x) => {
451 query.push_str(
452 "(strftime('%Y-%m-%dT%H:%M:%f', DateTime('now', 'LocalTime', ",
453 );
454 query.push_str(x);
455 query.push_str(")))");
456 }
457 DatabaseValue::Now => {
458 query.push_str("(strftime('%Y-%m-%dT%H:%M:%f', 'now'))");
459 }
460 DatabaseValue::DateTime(x) => {
461 query.push('\'');
462 query.push_str(&x.and_utc().to_rfc3339());
463 query.push('\'');
464 }
465 }
466 }
467 }
468
469 moosicbox_assert::assert!(!first);
470
471 if let Some(primary_key) = &statement.primary_key {
472 query.push_str(", PRIMARY KEY (");
473 query.push_str(primary_key);
474 query.push(')');
475 }
476
477 for (source, target) in &statement.foreign_keys {
478 query.push_str(", FOREIGN KEY (");
479 query.push_str(source);
480 query.push_str(") REFERENCES (");
481 query.push_str(target);
482 query.push(')');
483 }
484
485 query.push(')');
486
487 self.exec_raw(&query).await?;
488
489 Ok(())
490 }
491}
492
493impl From<Value> for DatabaseValue {
494 fn from(value: Value) -> Self {
495 match value {
496 Value::Null => Self::Null,
497 Value::Integer(value) => Self::Number(value),
498 Value::Real(value) => Self::Real(value),
499 Value::Text(value) => Self::String(value),
500 Value::Blob(_value) => unimplemented!("Blob types are not supported yet"),
501 }
502 }
503}
504
505fn from_row(column_names: &[String], row: &Row<'_>) -> Result<crate::Row, RusqliteDatabaseError> {
506 let mut columns = vec![];
507
508 for column in column_names {
509 columns.push((
510 column.to_string(),
511 row.get::<_, Value>(column.as_str())?.into(),
512 ));
513 }
514
515 Ok(crate::Row { columns })
516}
517
518fn update_and_get_row(
519 connection: &Connection,
520 table_name: &str,
521 values: &[(&str, Box<dyn Expression>)],
522 filters: Option<&[Box<dyn BooleanExpression>]>,
523 limit: Option<usize>,
524) -> Result<Option<crate::Row>, RusqliteDatabaseError> {
525 let select_query = limit.map(|_| {
526 format!(
527 "SELECT rowid FROM {table_name} {}",
528 build_where_clause(filters),
529 )
530 });
531
532 let query = format!(
533 "UPDATE {table_name} {} {} RETURNING *",
534 build_set_clause(values),
535 build_update_where_clause(filters, limit, select_query.as_deref()),
536 );
537
538 let all_values = values
539 .iter()
540 .flat_map(|(_, value)| value.params().unwrap_or(vec![]).into_iter().cloned())
541 .map(std::convert::Into::into)
542 .collect::<Vec<_>>();
543 let mut all_filter_values = filters
544 .map(|filters| {
545 filters
546 .iter()
547 .flat_map(|value| value.params().unwrap_or_default().into_iter().cloned())
548 .map(std::convert::Into::into)
549 .collect::<Vec<_>>()
550 })
551 .unwrap_or_default();
552
553 if limit.is_some() {
554 all_filter_values.extend(all_filter_values.clone());
555 }
556
557 let all_values = [all_values, all_filter_values].concat();
558
559 log::trace!("Running update query: {query} with params: {all_values:?}");
560
561 let mut statement = connection.prepare_cached(&query)?;
562
563 bind_values(&mut statement, Some(&all_values), false, 0)?;
564
565 let column_names = statement
566 .column_names()
567 .iter()
568 .map(std::string::ToString::to_string)
569 .collect::<Vec<_>>();
570
571 let mut query = statement.raw_query();
572
573 query
574 .next()?
575 .map(|row| from_row(&column_names, row))
576 .transpose()
577}
578
579fn update_and_get_rows(
580 connection: &Connection,
581 table_name: &str,
582 values: &[(&str, Box<dyn Expression>)],
583 filters: Option<&[Box<dyn BooleanExpression>]>,
584 limit: Option<usize>,
585) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
586 let select_query = limit.map(|_| {
587 format!(
588 "SELECT rowid FROM {table_name} {}",
589 build_where_clause(filters),
590 )
591 });
592
593 let query = format!(
594 "UPDATE {table_name} {} {} RETURNING *",
595 build_set_clause(values),
596 build_update_where_clause(filters, limit, select_query.as_deref()),
597 );
598
599 let all_values = values
600 .iter()
601 .flat_map(|(_, value)| value.params().unwrap_or(vec![]).into_iter().cloned())
602 .map(std::convert::Into::into)
603 .collect::<Vec<_>>();
604 let mut all_filter_values = filters
605 .map(|filters| {
606 filters
607 .iter()
608 .flat_map(|value| value.params().unwrap_or_default().into_iter().cloned())
609 .map(std::convert::Into::into)
610 .collect::<Vec<_>>()
611 })
612 .unwrap_or_default();
613
614 if limit.is_some() {
615 all_filter_values.extend(all_filter_values.clone());
616 }
617
618 let all_values = [all_values, all_filter_values].concat();
619
620 log::trace!("Running update query: {query} with params: {all_values:?}");
621
622 let mut statement = connection.prepare_cached(&query)?;
623 bind_values(&mut statement, Some(&all_values), false, 0)?;
624 let column_names = statement
625 .column_names()
626 .iter()
627 .map(std::string::ToString::to_string)
628 .collect::<Vec<_>>();
629
630 to_rows(&column_names, statement.raw_query())
631}
632
633fn build_join_clauses(joins: Option<&[Join]>) -> String {
634 joins.map_or_else(String::new, |joins| {
635 joins
636 .iter()
637 .map(|join| {
638 format!(
639 "{}JOIN {} ON {}",
640 if join.left { "LEFT " } else { "" },
641 join.table_name,
642 join.on
643 )
644 })
645 .collect::<Vec<_>>()
646 .join(" ")
647 })
648}
649
650fn build_where_clause(filters: Option<&[Box<dyn BooleanExpression>]>) -> String {
651 filters.map_or_else(String::new, |filters| {
652 if filters.is_empty() {
653 String::new()
654 } else {
655 format!("WHERE {}", build_where_props(filters).join(" AND "))
656 }
657 })
658}
659
660fn build_where_props(filters: &[Box<dyn BooleanExpression>]) -> Vec<String> {
661 filters
662 .iter()
663 .map(|filter| filter.deref().to_sql())
664 .collect()
665}
666
667fn build_sort_clause(sorts: Option<&[Sort]>) -> String {
668 sorts.map_or_else(String::new, |sorts| {
669 if sorts.is_empty() {
670 String::new()
671 } else {
672 format!("ORDER BY {}", build_sort_props(sorts).join(", "))
673 }
674 })
675}
676
677fn build_sort_props(sorts: &[Sort]) -> Vec<String> {
678 sorts.iter().map(Sort::to_sql).collect()
679}
680
681fn build_update_where_clause(
682 filters: Option<&[Box<dyn BooleanExpression>]>,
683 limit: Option<usize>,
684 query: Option<&str>,
685) -> String {
686 let clause = build_where_clause(filters);
687 let limit_clause = build_update_limit_clause(limit, query);
688
689 let clause = if limit_clause.is_empty() {
690 clause
691 } else if clause.is_empty() {
692 "WHERE".into()
693 } else {
694 clause + " AND"
695 };
696
697 format!("{clause} {limit_clause}").trim().to_string()
698}
699
/// Builds the `rowid IN (<query> LIMIT <limit>)` restriction that the
/// UPDATE/DELETE builders in this file use to cap the number of affected rows.
///
/// Returns an empty string unless both `limit` and `query` are present.
fn build_update_limit_clause(limit: Option<usize>, query: Option<&str>) -> String {
    match (limit, query) {
        (Some(limit), Some(query)) => format!("rowid IN ({query} LIMIT {limit})"),
        _ => String::new(),
    }
}
707
708fn build_set_clause(values: &[(&str, Box<dyn Expression>)]) -> String {
709 if values.is_empty() {
710 String::new()
711 } else {
712 format!("SET {}", build_set_props(values).join(", "))
713 }
714}
715
716fn build_set_props(values: &[(&str, Box<dyn Expression>)]) -> Vec<String> {
717 values
718 .iter()
719 .map(|(name, value)| format!("{name}=({})", value.deref().to_sql()))
720 .collect()
721}
722
723fn build_values_clause(values: &[(&str, Box<dyn Expression>)]) -> String {
724 if values.is_empty() {
725 "DEFAULT VALUES".to_string()
726 } else {
727 format!("VALUES({})", build_values_props(values).join(", "))
728 }
729}
730
731fn build_values_props(values: &[(&str, Box<dyn Expression>)]) -> Vec<String> {
732 values
733 .iter()
734 .map(|(_, value)| value.deref().to_sql())
735 .collect()
736}
737
738fn bind_values(
739 statement: &mut Statement<'_>,
740 values: Option<&[RusqliteDatabaseValue]>,
741 constant_inc: bool,
742 offset: usize,
743) -> Result<usize, RusqliteDatabaseError> {
744 if let Some(values) = values {
745 let mut i = 1 + offset;
746 for value in values {
747 match &**value {
748 DatabaseValue::String(value) => {
749 statement.raw_bind_parameter(i, value)?;
750 if !constant_inc {
751 i += 1;
752 }
753 }
754 DatabaseValue::StringOpt(Some(value)) => {
755 statement.raw_bind_parameter(i, value)?;
756 if !constant_inc {
757 i += 1;
758 }
759 }
760 DatabaseValue::Null
761 | DatabaseValue::StringOpt(None)
762 | DatabaseValue::BoolOpt(None)
763 | DatabaseValue::NumberOpt(None)
764 | DatabaseValue::UNumberOpt(None)
765 | DatabaseValue::RealOpt(None)
766 | DatabaseValue::Now => (),
767 DatabaseValue::NowAdd(_add) => (),
768 DatabaseValue::Bool(value) => {
769 statement.raw_bind_parameter(i, i32::from(*value))?;
770 if !constant_inc {
771 i += 1;
772 }
773 }
774 DatabaseValue::BoolOpt(Some(value)) => {
775 statement.raw_bind_parameter(i, value)?;
776 if !constant_inc {
777 i += 1;
778 }
779 }
780 DatabaseValue::Number(value) => {
781 statement.raw_bind_parameter(i, *value)?;
782 if !constant_inc {
783 i += 1;
784 }
785 }
786 DatabaseValue::NumberOpt(Some(value)) => {
787 statement.raw_bind_parameter(i, *value)?;
788 if !constant_inc {
789 i += 1;
790 }
791 }
792 DatabaseValue::UNumber(value) => {
793 statement.raw_bind_parameter(i, *value)?;
794 if !constant_inc {
795 i += 1;
796 }
797 }
798 DatabaseValue::UNumberOpt(Some(value)) => {
799 statement.raw_bind_parameter(i, *value)?;
800 if !constant_inc {
801 i += 1;
802 }
803 }
804 DatabaseValue::Real(value) => {
805 statement.raw_bind_parameter(i, *value)?;
806 if !constant_inc {
807 i += 1;
808 }
809 }
810 DatabaseValue::RealOpt(Some(value)) => {
811 statement.raw_bind_parameter(i, *value)?;
812 if !constant_inc {
813 i += 1;
814 }
815 }
816 DatabaseValue::DateTime(value) => {
817 statement.raw_bind_parameter(i, value.to_string())?;
819 if !constant_inc {
820 i += 1;
821 }
822 }
823 }
824 if constant_inc {
825 i += 1;
826 }
827 }
828 Ok(i - 1)
829 } else {
830 Ok(0)
831 }
832}
833
834fn to_rows(
835 column_names: &[String],
836 mut rows: Rows<'_>,
837) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
838 let mut results = vec![];
839
840 while let Some(row) = rows.next()? {
841 results.push(from_row(column_names, row)?);
842 }
843
844 log::trace!(
845 "Got {} row{}",
846 results.len(),
847 if results.len() == 1 { "" } else { "s" }
848 );
849
850 Ok(results)
851}
852
853fn to_values(values: &[(&str, DatabaseValue)]) -> Vec<RusqliteDatabaseValue> {
854 values
855 .iter()
856 .map(|(_key, value)| value.clone())
857 .map(std::convert::Into::into)
858 .collect::<Vec<_>>()
859}
860
861fn exprs_to_values(values: &[(&str, Box<dyn Expression>)]) -> Vec<RusqliteDatabaseValue> {
862 values
863 .iter()
864 .flat_map(|value| value.1.values().into_iter())
865 .flatten()
866 .cloned()
867 .map(std::convert::Into::into)
868 .collect::<Vec<_>>()
869}
870
871fn bexprs_to_values(values: &[Box<dyn BooleanExpression>]) -> Vec<RusqliteDatabaseValue> {
872 values
873 .iter()
874 .flat_map(|value| value.values().into_iter())
875 .flatten()
876 .cloned()
877 .map(std::convert::Into::into)
878 .collect::<Vec<_>>()
879}
880
881#[allow(unused)]
882fn to_values_opt(values: Option<&[(&str, DatabaseValue)]>) -> Option<Vec<RusqliteDatabaseValue>> {
883 values.map(to_values)
884}
885
886#[allow(unused)]
887fn exprs_to_values_opt(
888 values: Option<&[(&str, Box<dyn Expression>)]>,
889) -> Option<Vec<RusqliteDatabaseValue>> {
890 values.map(exprs_to_values)
891}
892
893fn bexprs_to_values_opt(
894 values: Option<&[Box<dyn BooleanExpression>]>,
895) -> Option<Vec<RusqliteDatabaseValue>> {
896 values.map(bexprs_to_values)
897}
898
899#[allow(clippy::too_many_arguments)]
900fn select(
901 connection: &Connection,
902 table_name: &str,
903 distinct: bool,
904 columns: &[&str],
905 filters: Option<&[Box<dyn BooleanExpression>]>,
906 joins: Option<&[Join<'_>]>,
907 sort: Option<&[Sort]>,
908 limit: Option<usize>,
909) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
910 let query = format!(
911 "SELECT {} {} FROM {table_name} {} {} {} {}",
912 if distinct { "DISTINCT" } else { "" },
913 columns.join(", "),
914 build_join_clauses(joins),
915 build_where_clause(filters),
916 build_sort_clause(sort),
917 limit.map_or_else(String::new, |limit| format!("LIMIT {limit}"))
918 );
919
920 log::trace!(
921 "Running select query: {query} with params: {:?}",
922 filters.map(|f| f.iter().filter_map(|x| x.params()).collect::<Vec<_>>())
923 );
924
925 let mut statement = connection.prepare_cached(&query)?;
926 let column_names = statement
927 .column_names()
928 .iter()
929 .map(std::string::ToString::to_string)
930 .collect::<Vec<_>>();
931
932 bind_values(
933 &mut statement,
934 bexprs_to_values_opt(filters).as_deref(),
935 false,
936 0,
937 )?;
938
939 to_rows(&column_names, statement.raw_query())
940}
941
942fn delete(
943 connection: &Connection,
944 table_name: &str,
945 filters: Option<&[Box<dyn BooleanExpression>]>,
946 limit: Option<usize>,
947) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
948 let where_clause = build_where_clause(filters);
949
950 let select_query = limit.map(|_| format!("SELECT rowid FROM {table_name} {where_clause}",));
951
952 let query = format!(
953 "DELETE FROM {table_name} {} RETURNING *",
954 build_update_where_clause(filters, limit, select_query.as_deref()),
955 );
956
957 let mut all_filter_values: Vec<RusqliteDatabaseValue> = filters
958 .map(|filters| {
959 filters
960 .iter()
961 .flat_map(|value| value.params().unwrap_or_default().into_iter().cloned())
962 .map(std::convert::Into::into)
963 .collect::<Vec<_>>()
964 })
965 .unwrap_or_default();
966
967 if limit.is_some() {
968 all_filter_values.extend(all_filter_values.clone());
969 }
970
971 log::trace!(
972 "Running delete query: {query} with params: {:?}",
973 all_filter_values
974 .iter()
975 .filter_map(super::query::Expression::params)
976 .collect::<Vec<_>>()
977 );
978
979 let mut statement = connection.prepare_cached(&query)?;
980 let column_names = statement
981 .column_names()
982 .iter()
983 .map(std::string::ToString::to_string)
984 .collect::<Vec<_>>();
985
986 bind_values(&mut statement, Some(&all_filter_values), false, 0)?;
987
988 to_rows(&column_names, statement.raw_query())
989}
990
991fn find_row(
992 connection: &Connection,
993 table_name: &str,
994 distinct: bool,
995 columns: &[&str],
996 filters: Option<&[Box<dyn BooleanExpression>]>,
997 joins: Option<&[Join]>,
998 sort: Option<&[Sort]>,
999) -> Result<Option<crate::Row>, RusqliteDatabaseError> {
1000 let query = format!(
1001 "SELECT {} {} FROM {table_name} {} {} {} LIMIT 1",
1002 if distinct { "DISTINCT" } else { "" },
1003 columns.join(", "),
1004 build_join_clauses(joins),
1005 build_where_clause(filters),
1006 build_sort_clause(sort),
1007 );
1008
1009 let mut statement = connection.prepare_cached(&query)?;
1010 let column_names = statement
1011 .column_names()
1012 .iter()
1013 .map(std::string::ToString::to_string)
1014 .collect::<Vec<_>>();
1015
1016 bind_values(
1017 &mut statement,
1018 bexprs_to_values_opt(filters).as_deref(),
1019 false,
1020 0,
1021 )?;
1022
1023 log::trace!(
1024 "Running find_row query: {query} with params: {:?}",
1025 filters.map(|f| f.iter().filter_map(|x| x.params()).collect::<Vec<_>>())
1026 );
1027
1028 let mut query = statement.raw_query();
1029
1030 query
1031 .next()?
1032 .map(|row| from_row(&column_names, row))
1033 .transpose()
1034}
1035
1036fn insert_and_get_row(
1037 connection: &Connection,
1038 table_name: &str,
1039 values: &[(&str, Box<dyn Expression>)],
1040) -> Result<crate::Row, RusqliteDatabaseError> {
1041 let column_names = values
1042 .iter()
1043 .map(|(key, _v)| format!("`{key}`"))
1044 .collect::<Vec<_>>()
1045 .join(", ");
1046
1047 let insert_columns = if values.is_empty() {
1048 String::new()
1049 } else {
1050 format!("({column_names})")
1051 };
1052 let query = format!(
1053 "INSERT INTO {table_name} {insert_columns} {} RETURNING *",
1054 build_values_clause(values),
1055 );
1056
1057 let mut statement = connection.prepare_cached(&query)?;
1058 let column_names = statement
1059 .column_names()
1060 .iter()
1061 .map(std::string::ToString::to_string)
1062 .collect::<Vec<_>>();
1063
1064 bind_values(&mut statement, Some(&exprs_to_values(values)), false, 0)?;
1065
1066 log::trace!(
1067 "Running insert_and_get_row query: {query} with params: {:?}",
1068 values
1069 .iter()
1070 .filter_map(|(_, x)| x.params())
1071 .collect::<Vec<_>>()
1072 );
1073
1074 let mut query = statement.raw_query();
1075
1076 query
1077 .next()?
1078 .map(|row| from_row(&column_names, row))
1079 .ok_or(RusqliteDatabaseError::NoRow)?
1080}
1081
/// Updates `values` in chunks and returns the rows each chunk reports via
/// `RETURNING *`.
///
/// Rows are grouped so that the cumulative number of `(column, value)` pairs
/// in a chunk stays below `i16::MAX - 1`. Each group is passed to
/// `update_chunk` with the same `filters`.
///
/// NOTE(review): `count` is the number of columns in the current row, yet it is
/// compared with and subtracted from `limit`. Confirm what `limit` is meant to
/// count.
///
/// NOTE(review): when `count >= limit` this returns immediately without running
/// `update_chunk` on `values[last_i..i]` (the pending rows, including the
/// current one), so those rows are never updated.
///
/// NOTE(review): if the very first row alone reaches the threshold,
/// `update_chunk` is called with an empty slice and panics on `values[0]`.
///
/// # Errors
///
/// Propagates any error returned by `update_chunk`.
pub fn update_multi(
    connection: &Connection,
    table_name: &str,
    values: &[Vec<(&str, Box<dyn Expression>)>],
    filters: Option<&[Box<dyn BooleanExpression>]>,
    mut limit: Option<usize>,
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
    let mut results = vec![];

    if values.is_empty() {
        return Ok(results);
    }

    // `pos`: values accumulated in the pending chunk `values[last_i..i]`.
    let mut pos = 0;
    let mut i = 0;
    let mut last_i = i;

    for value in values {
        let count = value.len();
        // Flush the pending chunk (rows before the current one) before it
        // would reach the threshold.
        if pos + count >= (i16::MAX - 1) as usize {
            results.append(&mut update_chunk(
                connection,
                table_name,
                &values[last_i..i],
                filters,
                limit,
            )?);
            last_i = i;
            pos = 0;
        }
        i += 1;
        pos += count;

        if let Some(value) = limit {
            if count >= value {
                return Ok(results);
            }

            limit.replace(value - count);
        }
    }

    // Flush whatever is still pending.
    if i > last_i {
        results.append(&mut update_chunk(
            connection,
            table_name,
            &values[last_i..],
            filters,
            limit,
        )?);
    }

    Ok(results)
}
1139
/// Builds and runs a single statement for a chunk of rows that must all share
/// the first row's column names, in the same order.
///
/// NOTE(review): the generated SQL has the shape
/// `UPDATE <table> (<cols>) <where> SET col = EXCLUDED.col, ... RETURNING *`.
/// SQLite's UPDATE grammar requires `SET` before `WHERE` and allows no column
/// list after the table name, and `EXCLUDED` only exists inside an upsert's
/// `ON CONFLICT ... DO UPDATE`. The statement also contains no `?` placeholders
/// for the row values, even though those values are bound below. It appears to
/// mirror `upsert_chunk`; verify before relying on it.
///
/// # Errors
///
/// `RusqliteDatabaseError::InvalidRequest` when a row's columns differ from
/// the first row's. Otherwise, `rusqlite` errors are propagated.
///
/// # Panics
///
/// Indexes `values[0]`, so panics when `values` is empty.
fn update_chunk(
    connection: &Connection,
    table_name: &str,
    values: &[Vec<(&str, Box<dyn Expression>)>],
    filters: Option<&[Box<dyn BooleanExpression>]>,
    limit: Option<usize>,
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
    let first = values[0].as_slice();
    let expected_value_size = first.len();

    // Every row must have the same columns, in the same order, as the first one.
    if let Some(bad_row) = values.iter().skip(1).find(|v| {
        v.len() != expected_value_size || v.iter().enumerate().any(|(i, c)| c.0 != first[i].0)
    }) {
        log::error!("Bad row: {bad_row:?}. Expected to match schema of first row: {first:?}");
        return Err(RusqliteDatabaseError::InvalidRequest);
    }

    let set_clause = values[0]
        .iter()
        .map(|(name, _value)| format!("`{name}` = EXCLUDED.`{name}`"))
        .collect::<Vec<_>>()
        .join(", ");

    let column_names = values[0]
        .iter()
        .map(|(key, _v)| format!("`{key}`"))
        .collect::<Vec<_>>()
        .join(", ");

    // Sub-select of matching rowids, used to apply `limit`.
    let select_query = limit.map(|_| {
        format!(
            "SELECT rowid FROM {table_name} {}",
            build_where_clause(filters),
        )
    });

    let query = format!(
        "
            UPDATE {table_name} ({column_names})
            {}
            SET {set_clause}
            RETURNING *",
        build_update_where_clause(filters, limit, select_query.as_deref()),
    );

    let all_values = values
        .iter()
        .flat_map(std::iter::IntoIterator::into_iter)
        .flat_map(|(_, value)| value.params().unwrap_or(vec![]).into_iter().cloned())
        .map(std::convert::Into::into)
        .collect::<Vec<_>>();
    let mut all_filter_values = filters
        .as_ref()
        .map(|filters| {
            filters
                .iter()
                .flat_map(|value| {
                    value
                        .params()
                        .unwrap_or_default()
                        .into_iter()
                        .cloned()
                        .map(std::convert::Into::into)
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>()
        })
        .unwrap_or_default();

    // The filters are rendered twice when a limit sub-select is present.
    if limit.is_some() {
        all_filter_values.extend(all_filter_values.clone());
    }

    let all_values = [all_values, all_filter_values].concat();

    log::trace!("Running update chunk query: {query} with params: {all_values:?}");

    let mut statement = connection.prepare_cached(&query)?;
    let column_names = statement
        .column_names()
        .iter()
        .map(std::string::ToString::to_string)
        .collect::<Vec<_>>();

    // `true`: advance the parameter index for every value, including values
    // that `bind_values` skips (NULL/Now/NowAdd).
    bind_values(&mut statement, Some(&all_values), true, 0)?;

    to_rows(&column_names, statement.raw_query())
}
1228
1229pub fn upsert_multi(
1233 connection: &Connection,
1234 table_name: &str,
1235 unique: &[Box<dyn Expression>],
1236 values: &[Vec<(&str, Box<dyn Expression>)>],
1237) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
1238 let mut results = vec![];
1239
1240 if values.is_empty() {
1241 return Ok(results);
1242 }
1243
1244 let mut pos = 0;
1245 let mut i = 0;
1246 let mut last_i = i;
1247
1248 for value in values {
1249 let count = value.len();
1250 if pos + count >= (i16::MAX - 1) as usize {
1251 results.append(&mut upsert_chunk(
1252 connection,
1253 table_name,
1254 unique,
1255 &values[last_i..i],
1256 )?);
1257 last_i = i;
1258 pos = 0;
1259 }
1260 i += 1;
1261 pos += count;
1262 }
1263
1264 if i > last_i {
1265 results.append(&mut upsert_chunk(
1266 connection,
1267 table_name,
1268 unique,
1269 &values[last_i..],
1270 )?);
1271 }
1272
1273 Ok(results)
1274}
1275
1276fn upsert_chunk(
1277 connection: &Connection,
1278 table_name: &str,
1279 unique: &[Box<dyn Expression>],
1280 values: &[Vec<(&str, Box<dyn Expression>)>],
1281) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
1282 let first = values[0].as_slice();
1283 let expected_value_size = first.len();
1284
1285 if let Some(bad_row) = values.iter().skip(1).find(|v| {
1286 v.len() != expected_value_size || v.iter().enumerate().any(|(i, c)| c.0 != first[i].0)
1287 }) {
1288 log::error!("Bad row: {bad_row:?}. Expected to match schema of first row: {first:?}");
1289 return Err(RusqliteDatabaseError::InvalidRequest);
1290 }
1291
1292 let set_clause = values[0]
1293 .iter()
1294 .map(|(name, _value)| format!("`{name}` = EXCLUDED.`{name}`"))
1295 .collect::<Vec<_>>()
1296 .join(", ");
1297
1298 let column_names = values[0]
1299 .iter()
1300 .map(|(key, _v)| format!("`{key}`"))
1301 .collect::<Vec<_>>()
1302 .join(", ");
1303
1304 let values_str_list = values
1305 .iter()
1306 .map(|v| format!("({})", build_values_props(v).join(", ")))
1307 .collect::<Vec<_>>();
1308
1309 let values_str = values_str_list.join(", ");
1310 let values_str = if values_str.is_empty() {
1311 "DEFAULT VALUES".to_string()
1312 } else {
1313 format!("VALUES {values_str}")
1314 };
1315
1316 let unique_conflict = unique
1317 .iter()
1318 .map(|x| x.to_sql())
1319 .collect::<Vec<_>>()
1320 .join(", ");
1321
1322 let insert_columns = if values.is_empty() {
1323 String::new()
1324 } else {
1325 format!("({column_names})")
1326 };
1327 let query = format!(
1328 "
1329 INSERT INTO {table_name} {insert_columns} {values_str}
1330 ON CONFLICT({unique_conflict}) DO UPDATE
1331 SET {set_clause}
1332 RETURNING *"
1333 );
1334
1335 let all_values = &values
1336 .iter()
1337 .flat_map(std::iter::IntoIterator::into_iter)
1338 .flat_map(|(_, value)| value.params().unwrap_or(vec![]).into_iter().cloned())
1339 .map(std::convert::Into::into)
1340 .collect::<Vec<_>>();
1341
1342 log::trace!("Running upsert chunk query: {query} with params: {all_values:?}");
1343
1344 let mut statement = connection.prepare_cached(&query)?;
1345 let column_names = statement
1346 .column_names()
1347 .iter()
1348 .map(std::string::ToString::to_string)
1349 .collect::<Vec<_>>();
1350
1351 bind_values(&mut statement, Some(all_values), true, 0)?;
1352
1353 to_rows(&column_names, statement.raw_query())
1354}
1355
1356fn upsert(
1357 connection: &Connection,
1358 table_name: &str,
1359 values: &[(&str, Box<dyn Expression>)],
1360 filters: Option<&[Box<dyn BooleanExpression>]>,
1361 limit: Option<usize>,
1362) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
1363 let rows = update_and_get_rows(connection, table_name, values, filters, limit)?;
1364
1365 Ok(if rows.is_empty() {
1366 vec![insert_and_get_row(connection, table_name, values)?]
1367 } else {
1368 rows
1369 })
1370}
1371
1372#[allow(unused)]
1373fn upsert_and_get_row(
1374 connection: &Connection,
1375 table_name: &str,
1376 values: &[(&str, Box<dyn Expression>)],
1377 filters: Option<&[Box<dyn BooleanExpression>]>,
1378 limit: Option<usize>,
1379) -> Result<crate::Row, RusqliteDatabaseError> {
1380 match find_row(connection, table_name, false, &["*"], filters, None, None)? {
1381 Some(row) => {
1382 let updated =
1383 update_and_get_row(connection, table_name, values, filters, limit)?.unwrap();
1384
1385 let str1 = format!("{row:?}");
1386 let str2 = format!("{updated:?}");
1387
1388 if str1 == str2 {
1389 log::trace!("No updates to {table_name}");
1390 } else {
1391 log::debug!("Changed {table_name} from {str1} to {str2}");
1392 }
1393
1394 Ok(updated)
1395 }
1396 None => Ok(insert_and_get_row(connection, table_name, values)?),
1397 }
1398}
1399
/// Newtype wrapper around [`DatabaseValue`] used by the rusqlite backend.
///
/// It derefs to the wrapped [`DatabaseValue`] and implements [`Expression`], so a plain value
/// can be used directly as a node in a query expression.
// The type name presumably repeats the enclosing module's name; allowed deliberately.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone)]
pub struct RusqliteDatabaseValue(DatabaseValue);
1403
1404impl From<DatabaseValue> for RusqliteDatabaseValue {
1405 fn from(value: DatabaseValue) -> Self {
1406 Self(value)
1407 }
1408}
1409
1410impl Deref for RusqliteDatabaseValue {
1411 type Target = DatabaseValue;
1412
1413 fn deref(&self) -> &Self::Target {
1414 &self.0
1415 }
1416}
1417
1418impl Expression for RusqliteDatabaseValue {
1419 fn values(&self) -> Option<Vec<&DatabaseValue>> {
1420 Some(vec![self])
1421 }
1422
1423 fn is_null(&self) -> bool {
1424 matches!(
1425 self.0,
1426 DatabaseValue::Null
1427 | DatabaseValue::BoolOpt(None)
1428 | DatabaseValue::RealOpt(None)
1429 | DatabaseValue::StringOpt(None)
1430 | DatabaseValue::NumberOpt(None)
1431 | DatabaseValue::UNumberOpt(None)
1432 )
1433 }
1434
1435 fn expression_type(&self) -> ExpressionType {
1436 ExpressionType::DatabaseValue(self)
1437 }
1438}