1use std::str::FromStr;
10
11use chrono::{DateTime, Utc};
12use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions, SqliteRow};
13use sqlx::Row as _;
14
15use crate::error::Error;
16
17#[derive(Clone)]
18pub struct Db {
19 pool: SqlitePool,
20}
21
22impl Db {
23 pub async fn connect(url: &str) -> Result<Self, Error> {
31 let opts = SqliteConnectOptions::from_str(url)
32 .map_err(|e| Error::Internal(format!("invalid database URL: {e}")))?
33 .foreign_keys(true);
34 let pool = SqlitePoolOptions::new().connect_with(opts).await?;
35 Ok(Self { pool })
36 }
37
38 pub async fn memory() -> Result<Self, Error> {
43 let opts = SqliteConnectOptions::from_str("sqlite::memory:")
44 .map_err(|e| Error::Internal(format!("invalid database URL: {e}")))?
45 .foreign_keys(true);
46 let pool = SqlitePoolOptions::new()
47 .max_connections(1)
48 .connect_with(opts)
49 .await?;
50 Ok(Self { pool })
51 }
52
53 pub async fn execute(&self, sql: &str) -> Result<(), Error> {
54 sqlx::query(sql).execute(&self.pool).await?;
55 Ok(())
56 }
57
58 pub(crate) fn pool(&self) -> &SqlitePool {
59 &self.pool
60 }
61}
62
63pub struct Row<'a> {
64 inner: &'a SqliteRow,
65}
66
67impl<'a> Row<'a> {
68 pub(crate) fn new(inner: &'a SqliteRow) -> Self {
69 Self { inner }
70 }
71
72 pub fn get_i32(&self, name: &str) -> Result<i32, Error> {
73 self.inner.try_get(name).map_err(Error::from)
74 }
75
76 pub fn get_i64(&self, name: &str) -> Result<i64, Error> {
77 self.inner.try_get(name).map_err(Error::from)
78 }
79
80 pub fn get_string(&self, name: &str) -> Result<String, Error> {
81 self.inner.try_get(name).map_err(Error::from)
82 }
83
84 pub fn get_bool(&self, name: &str) -> Result<bool, Error> {
85 self.inner.try_get(name).map_err(Error::from)
86 }
87
88 pub fn get_datetime(&self, name: &str) -> Result<DateTime<Utc>, Error> {
89 self.inner.try_get(name).map_err(Error::from)
90 }
91
92 pub fn get_optional_i32(&self, name: &str) -> Result<Option<i32>, Error> {
96 self.inner.try_get(name).map_err(Error::from)
97 }
98
99 pub fn get_optional_i64(&self, name: &str) -> Result<Option<i64>, Error> {
100 self.inner.try_get(name).map_err(Error::from)
101 }
102
103 pub fn get_optional_string(&self, name: &str) -> Result<Option<String>, Error> {
104 self.inner.try_get(name).map_err(Error::from)
105 }
106
107 pub fn get_optional_bool(&self, name: &str) -> Result<Option<bool>, Error> {
108 self.inner.try_get(name).map_err(Error::from)
109 }
110
111 pub fn get_optional_datetime(&self, name: &str) -> Result<Option<DateTime<Utc>>, Error> {
112 self.inner.try_get(name).map_err(Error::from)
113 }
114}
115
116#[non_exhaustive]
122#[derive(Debug)]
123pub enum Value {
124 I32(i32),
125 I64(i64),
126 String(String),
127 Bool(bool),
128 DateTime(DateTime<Utc>),
129 Null,
131}
132
133impl From<i32> for Value {
134 fn from(v: i32) -> Self {
135 Value::I32(v)
136 }
137}
138
139impl From<i64> for Value {
140 fn from(v: i64) -> Self {
141 Value::I64(v)
142 }
143}
144
145impl From<String> for Value {
146 fn from(v: String) -> Self {
147 Value::String(v)
148 }
149}
150
151impl From<&str> for Value {
152 fn from(v: &str) -> Self {
153 Value::String(v.to_owned())
154 }
155}
156
157impl From<bool> for Value {
158 fn from(v: bool) -> Self {
159 Value::Bool(v)
160 }
161}
162
163impl From<DateTime<Utc>> for Value {
164 fn from(v: DateTime<Utc>) -> Self {
165 Value::DateTime(v)
166 }
167}
168
169impl<T> From<Option<T>> for Value
173where
174 T: Into<Value>,
175{
176 fn from(v: Option<T>) -> Self {
177 match v {
178 Some(inner) => inner.into(),
179 None => Value::Null,
180 }
181 }
182}
183
184pub trait Model: Sized + Send + Sync + Unpin + 'static {
185 const TABLE: &'static str;
186 const COLUMNS: &'static [&'static str];
187 const INSERT_COLUMNS: &'static [&'static str];
188
189 fn id(&self) -> i64;
190 fn from_row(row: Row<'_>) -> Result<Self, Error>;
191 fn insert_values(&self) -> Vec<Value>;
192
193 fn find(
194 db: &Db,
195 id: i64,
196 ) -> impl std::future::Future<Output = Result<Option<Self>, Error>> + Send
197 where
198 Self: Send,
199 {
200 async move {
201 let sql = format!(
202 "SELECT {} FROM {} WHERE id = ?",
203 Self::COLUMNS.join(", "),
204 Self::TABLE,
205 );
206 let row = sqlx::query(&sql).bind(id).fetch_optional(db.pool()).await?;
207 match row {
208 Some(r) => Ok(Some(Self::from_row(Row::new(&r))?)),
209 None => Ok(None),
210 }
211 }
212 }
213
214 fn all(db: &Db) -> impl std::future::Future<Output = Result<Vec<Self>, Error>> + Send {
215 async move {
216 let sql = format!("SELECT {} FROM {}", Self::COLUMNS.join(", "), Self::TABLE);
217 let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
218 rows.iter().map(|r| Self::from_row(Row::new(r))).collect()
219 }
220 }
221
222 fn create<'a>(
223 &'a self,
224 db: &'a Db,
225 ) -> impl std::future::Future<Output = Result<i64, Error>> + Send + 'a {
226 async move {
227 let placeholders = vec!["?"; Self::INSERT_COLUMNS.len()].join(", ");
228 let sql = format!(
229 "INSERT INTO {} ({}) VALUES ({})",
230 Self::TABLE,
231 Self::INSERT_COLUMNS.join(", "),
232 placeholders,
233 );
234 let mut query = sqlx::query(&sql);
235 for v in self.insert_values() {
236 query = bind_value(query, v);
237 }
238 let result = query.execute(db.pool()).await?;
239 Ok(result.last_insert_rowid())
240 }
241 }
242
243 fn update<'a>(
244 &'a self,
245 db: &'a Db,
246 ) -> impl std::future::Future<Output = Result<(), Error>> + Send + 'a {
247 async move {
248 let assignments: Vec<String> = Self::INSERT_COLUMNS
249 .iter()
250 .map(|c| format!("{c} = ?"))
251 .collect();
252 let sql = format!(
253 "UPDATE {} SET {} WHERE id = ?",
254 Self::TABLE,
255 assignments.join(", "),
256 );
257 let mut query = sqlx::query(&sql);
258 for v in self.insert_values() {
259 query = bind_value(query, v);
260 }
261 query = query.bind(self.id());
262 query.execute(db.pool()).await?;
263 Ok(())
264 }
265 }
266
267 fn delete(db: &Db, id: i64) -> impl std::future::Future<Output = Result<(), Error>> + Send {
268 async move {
269 let sql = format!("DELETE FROM {} WHERE id = ?", Self::TABLE);
270 sqlx::query(&sql).bind(id).execute(db.pool()).await?;
271 Ok(())
272 }
273 }
274}
275
276fn bind_value<'q>(
277 query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
278 value: Value,
279) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
280 match value {
281 Value::I32(v) => query.bind(v),
282 Value::I64(v) => query.bind(v),
283 Value::String(v) => query.bind(v),
284 Value::Bool(v) => query.bind(v),
285 Value::DateTime(v) => query.bind(v),
286 Value::Null => query.bind(Option::<i64>::None),
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone as _;

    /// Model with only non-nullable columns, used by the CRUD tests.
    #[derive(Debug, PartialEq)]
    struct User {
        id: i64,
        name: String,
        is_admin: bool,
    }

    impl Model for User {
        const TABLE: &'static str = "users";
        const COLUMNS: &'static [&'static str] = &["id", "name", "is_admin"];
        const INSERT_COLUMNS: &'static [&'static str] = &["name", "is_admin"];

        fn id(&self) -> i64 {
            self.id
        }

        fn from_row(row: Row<'_>) -> Result<Self, Error> {
            Ok(Self {
                id: row.get_i64("id")?,
                name: row.get_string("name")?,
                is_admin: row.get_bool("is_admin")?,
            })
        }

        fn insert_values(&self) -> Vec<Value> {
            vec![self.name.clone().into(), self.is_admin.into()]
        }
    }

    /// Builds an unsaved `User`; `id` is a placeholder because it is not one
    /// of the insert columns.
    fn user(name: &str, is_admin: bool) -> User {
        User {
            id: 0,
            name: name.into(),
            is_admin,
        }
    }

    /// Fresh in-memory database containing an empty `users` table.
    async fn setup() -> Db {
        let db = Db::memory().await.unwrap();
        db.execute(
            "CREATE TABLE users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                is_admin INTEGER NOT NULL
            )",
        )
        .await
        .unwrap();
        db
    }

    #[tokio::test]
    async fn create_assigns_new_id_and_find_reads_it_back() {
        let db = setup().await;
        let id = user("Alice", false).create(&db).await.unwrap();
        assert!(id >= 1);

        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.name, "Alice");
        assert!(!back.is_admin);
        assert_eq!(back.id, id);
    }

    #[tokio::test]
    async fn find_missing_returns_none() {
        let db = setup().await;
        assert!(User::find(&db, 42).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn all_returns_every_row() {
        let db = setup().await;
        for (name, is_admin) in [("a", false), ("b", true), ("c", false)] {
            user(name, is_admin).create(&db).await.unwrap();
        }

        let rows = User::all(&db).await.unwrap();
        assert_eq!(rows.len(), 3);

        // `all` issues no ORDER BY, so compare independently of row order.
        let mut names: Vec<&str> = rows.iter().map(|u| u.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn update_modifies_row_in_place() {
        let db = setup().await;
        let id = user("old", false).create(&db).await.unwrap();

        let updated = User {
            id,
            name: "new".into(),
            is_admin: true,
        };
        updated.update(&db).await.unwrap();

        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.name, "new");
        assert!(back.is_admin);
    }

    #[tokio::test]
    async fn delete_removes_row() {
        let db = setup().await;
        let id = user("x", false).create(&db).await.unwrap();
        assert!(User::find(&db, id).await.unwrap().is_some());

        User::delete(&db, id).await.unwrap();
        assert!(User::find(&db, id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn row_getters_handle_wrong_type_gracefully() {
        let db = setup().await;
        user("a", false).create(&db).await.unwrap();

        let row = sqlx::query("SELECT id, name, is_admin FROM users LIMIT 1")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let wrapped = Row::new(&row);

        assert!(wrapped.get_i64("id").is_ok());
        // Unknown column: reported as an error rather than a panic.
        assert!(wrapped.get_string("nonexistent_column").is_err());
        // Wrong type: `name` holds text, so reading it as an integer must fail.
        assert!(wrapped.get_i64("name").is_err());
    }

    /// Model with nullable columns, used by the `Option` round-trip tests.
    #[derive(Debug, PartialEq)]
    struct Event {
        id: i64,
        title: String,
        note: Option<String>,
        priority: Option<i32>,
        starts_at: Option<DateTime<Utc>>,
    }

    impl Model for Event {
        const TABLE: &'static str = "events";
        const COLUMNS: &'static [&'static str] = &["id", "title", "note", "priority", "starts_at"];
        const INSERT_COLUMNS: &'static [&'static str] = &["title", "note", "priority", "starts_at"];

        fn id(&self) -> i64 {
            self.id
        }

        fn from_row(row: Row<'_>) -> Result<Self, Error> {
            Ok(Self {
                id: row.get_i64("id")?,
                title: row.get_string("title")?,
                note: row.get_optional_string("note")?,
                priority: row.get_optional_i32("priority")?,
                starts_at: row.get_optional_datetime("starts_at")?,
            })
        }

        fn insert_values(&self) -> Vec<Value> {
            vec![
                self.title.clone().into(),
                self.note.clone().into(),
                self.priority.into(),
                self.starts_at.into(),
            ]
        }
    }

    /// Builds an `Event` whose optional fields are all `None`.
    fn blank_event(id: i64, title: &str) -> Event {
        Event {
            id,
            title: title.into(),
            note: None,
            priority: None,
            starts_at: None,
        }
    }

    /// Fresh in-memory database containing an empty `events` table.
    async fn setup_events() -> Db {
        let db = Db::memory().await.unwrap();
        db.execute(
            "CREATE TABLE events (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                title TEXT NOT NULL,
                note TEXT NULL,
                priority INTEGER NULL,
                starts_at TEXT NULL
            )",
        )
        .await
        .unwrap();
        db
    }

    #[tokio::test]
    async fn option_none_round_trips_as_null() {
        let db = setup_events().await;
        let id = blank_event(0, "empty").create(&db).await.unwrap();

        let back = Event::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.note, None);
        assert_eq!(back.priority, None);
        assert_eq!(back.starts_at, None);

        // Ask SQLite directly so the check does not depend on the decoders.
        let row = sqlx::query(
            "SELECT note IS NULL AS note_is_null,
                    priority IS NULL AS priority_is_null,
                    starts_at IS NULL AS starts_is_null
             FROM events WHERE id = ?",
        )
        .bind(id)
        .fetch_one(db.pool())
        .await
        .unwrap();
        assert_eq!(row.get::<i64, _>(0), 1);
        assert_eq!(row.get::<i64, _>(1), 1);
        assert_eq!(row.get::<i64, _>(2), 1);
    }

    #[tokio::test]
    async fn option_some_round_trips_without_data_loss() {
        let db = setup_events().await;
        let when = Utc.with_ymd_and_hms(2026, 4, 18, 10, 12, 33).unwrap();
        let id = Event {
            id: 0,
            title: "full".into(),
            note: Some("hello".into()),
            priority: Some(7),
            starts_at: Some(when),
        }
        .create(&db)
        .await
        .unwrap();

        let back = Event::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.note.as_deref(), Some("hello"));
        assert_eq!(back.priority, Some(7));
        assert_eq!(back.starts_at, Some(when));
    }

    #[tokio::test]
    async fn option_update_flips_null_to_some_and_back() {
        let db = setup_events().await;
        let id = blank_event(0, "t").create(&db).await.unwrap();

        // NULL -> Some for two of the three optional columns.
        Event {
            id,
            title: "t".into(),
            note: Some("filled".into()),
            priority: Some(1),
            starts_at: None,
        }
        .update(&db)
        .await
        .unwrap();
        let mid = Event::find(&db, id).await.unwrap().unwrap();
        assert_eq!(mid.note.as_deref(), Some("filled"));
        assert_eq!(mid.priority, Some(1));
        assert_eq!(mid.starts_at, None);

        // Some -> NULL again.
        blank_event(id, "t").update(&db).await.unwrap();
        let after = Event::find(&db, id).await.unwrap().unwrap();
        assert_eq!(after.note, None);
        assert_eq!(after.priority, None);
        assert_eq!(after.starts_at, None);
    }
}