1use sqlx::sqlite::{SqlitePool, SqlitePoolOptions, SqliteRow};
10use sqlx::Row as _;
11
12use crate::error::Error;
13
14#[derive(Clone)]
15pub struct Db {
16 pool: SqlitePool,
17}
18
19impl Db {
20 pub async fn connect(url: &str) -> Result<Self, Error> {
21 let pool = SqlitePool::connect(url).await?;
22 Ok(Self { pool })
23 }
24
25 pub async fn memory() -> Result<Self, Error> {
26 let pool = SqlitePoolOptions::new()
27 .max_connections(1)
28 .connect("sqlite::memory:")
29 .await?;
30 Ok(Self { pool })
31 }
32
33 pub async fn execute(&self, sql: &str) -> Result<(), Error> {
34 sqlx::query(sql).execute(&self.pool).await?;
35 Ok(())
36 }
37
38 pub(crate) fn pool(&self) -> &SqlitePool {
39 &self.pool
40 }
41}
42
43pub struct Row<'a> {
44 inner: &'a SqliteRow,
45}
46
47impl<'a> Row<'a> {
48 pub(crate) fn new(inner: &'a SqliteRow) -> Self {
49 Self { inner }
50 }
51
52 pub fn get_i32(&self, name: &str) -> Result<i32, Error> {
53 self.inner.try_get(name).map_err(Error::from)
54 }
55
56 pub fn get_i64(&self, name: &str) -> Result<i64, Error> {
57 self.inner.try_get(name).map_err(Error::from)
58 }
59
60 pub fn get_string(&self, name: &str) -> Result<String, Error> {
61 self.inner.try_get(name).map_err(Error::from)
62 }
63
64 pub fn get_bool(&self, name: &str) -> Result<bool, Error> {
65 self.inner.try_get(name).map_err(Error::from)
66 }
67}
68
/// A dynamically typed SQL parameter, produced by `Model::insert_values` and
/// bound to a statement one placeholder at a time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Value {
    /// 32-bit signed integer.
    I32(i32),
    /// 64-bit signed integer.
    I64(i64),
    /// Owned text.
    String(String),
    /// Boolean flag.
    Bool(bool),
    /// SQL `NULL`.
    Null,
}

/// Lets nullable fields use `.into()`: `None` becomes [`Value::Null`] and
/// `Some(v)` converts exactly like `v` itself.
impl<T> From<Option<T>> for Value
where
    T: Into<Value>,
{
    fn from(v: Option<T>) -> Self {
        v.map_or(Value::Null, Into::into)
    }
}
77
78impl From<i32> for Value {
79 fn from(v: i32) -> Self {
80 Value::I32(v)
81 }
82}
83
84impl From<i64> for Value {
85 fn from(v: i64) -> Self {
86 Value::I64(v)
87 }
88}
89
90impl From<String> for Value {
91 fn from(v: String) -> Self {
92 Value::String(v)
93 }
94}
95
96impl From<&str> for Value {
97 fn from(v: &str) -> Self {
98 Value::String(v.to_owned())
99 }
100}
101
102impl From<bool> for Value {
103 fn from(v: bool) -> Self {
104 Value::Bool(v)
105 }
106}
107
108pub trait Model: Sized + Send + Sync + Unpin + 'static {
109 const TABLE: &'static str;
110 const COLUMNS: &'static [&'static str];
111 const INSERT_COLUMNS: &'static [&'static str];
112
113 fn id(&self) -> i64;
114 fn from_row(row: Row<'_>) -> Result<Self, Error>;
115 fn insert_values(&self) -> Vec<Value>;
116
117 fn find(
118 db: &Db,
119 id: i64,
120 ) -> impl std::future::Future<Output = Result<Option<Self>, Error>> + Send
121 where
122 Self: Send,
123 {
124 async move {
125 let sql = format!(
126 "SELECT {} FROM {} WHERE id = ?",
127 Self::COLUMNS.join(", "),
128 Self::TABLE,
129 );
130 let row = sqlx::query(&sql).bind(id).fetch_optional(db.pool()).await?;
131 match row {
132 Some(r) => Ok(Some(Self::from_row(Row::new(&r))?)),
133 None => Ok(None),
134 }
135 }
136 }
137
138 fn all(db: &Db) -> impl std::future::Future<Output = Result<Vec<Self>, Error>> + Send {
139 async move {
140 let sql = format!("SELECT {} FROM {}", Self::COLUMNS.join(", "), Self::TABLE);
141 let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
142 rows.iter().map(|r| Self::from_row(Row::new(r))).collect()
143 }
144 }
145
146 fn create<'a>(
147 &'a self,
148 db: &'a Db,
149 ) -> impl std::future::Future<Output = Result<i64, Error>> + Send + 'a {
150 async move {
151 let placeholders = vec!["?"; Self::INSERT_COLUMNS.len()].join(", ");
152 let sql = format!(
153 "INSERT INTO {} ({}) VALUES ({})",
154 Self::TABLE,
155 Self::INSERT_COLUMNS.join(", "),
156 placeholders,
157 );
158 let mut query = sqlx::query(&sql);
159 for v in self.insert_values() {
160 query = bind_value(query, v);
161 }
162 let result = query.execute(db.pool()).await?;
163 Ok(result.last_insert_rowid())
164 }
165 }
166
167 fn update<'a>(
168 &'a self,
169 db: &'a Db,
170 ) -> impl std::future::Future<Output = Result<(), Error>> + Send + 'a {
171 async move {
172 let assignments: Vec<String> = Self::INSERT_COLUMNS
173 .iter()
174 .map(|c| format!("{c} = ?"))
175 .collect();
176 let sql = format!(
177 "UPDATE {} SET {} WHERE id = ?",
178 Self::TABLE,
179 assignments.join(", "),
180 );
181 let mut query = sqlx::query(&sql);
182 for v in self.insert_values() {
183 query = bind_value(query, v);
184 }
185 query = query.bind(self.id());
186 query.execute(db.pool()).await?;
187 Ok(())
188 }
189 }
190
191 fn delete(db: &Db, id: i64) -> impl std::future::Future<Output = Result<(), Error>> + Send {
192 async move {
193 let sql = format!("DELETE FROM {} WHERE id = ?", Self::TABLE);
194 sqlx::query(&sql).bind(id).execute(db.pool()).await?;
195 Ok(())
196 }
197 }
198}
199
200fn bind_value<'q>(
201 query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
202 value: Value,
203) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
204 match value {
205 Value::I32(v) => query.bind(v),
206 Value::I64(v) => query.bind(v),
207 Value::String(v) => query.bind(v),
208 Value::Bool(v) => query.bind(v),
209 Value::Null => query.bind(Option::<i64>::None),
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal model used to exercise the provided methods of [`Model`].
    #[derive(Debug, PartialEq)]
    struct User {
        id: i64,
        name: String,
        is_admin: bool,
    }

    impl Model for User {
        const TABLE: &'static str = "users";
        const COLUMNS: &'static [&'static str] = &["id", "name", "is_admin"];
        const INSERT_COLUMNS: &'static [&'static str] = &["name", "is_admin"];

        fn id(&self) -> i64 {
            self.id
        }

        fn from_row(row: Row<'_>) -> Result<Self, Error> {
            let id = row.get_i64("id")?;
            let name = row.get_string("name")?;
            let is_admin = row.get_bool("is_admin")?;
            Ok(Self { id, name, is_admin })
        }

        fn insert_values(&self) -> Vec<Value> {
            vec![
                Value::from(self.name.clone()),
                Value::from(self.is_admin),
            ]
        }
    }

    /// Creates a fresh in-memory database containing an empty `users` table.
    async fn setup() -> Db {
        let db = Db::memory().await.unwrap();
        db.execute(
            "CREATE TABLE users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                is_admin INTEGER NOT NULL
            )",
        )
        .await
        .unwrap();
        db
    }

    /// Builds a user that has not been stored yet; its `id` is a placeholder
    /// because `id` is not one of the inserted columns.
    fn unsaved(name: &str, is_admin: bool) -> User {
        User {
            id: 0,
            name: name.into(),
            is_admin,
        }
    }

    /// Stores a new user and returns the id assigned by the database.
    async fn insert(db: &Db, name: &str, is_admin: bool) -> i64 {
        let user = unsaved(name, is_admin);
        user.create(db).await.unwrap()
    }

    #[tokio::test]
    async fn create_assigns_new_id_and_find_reads_it_back() {
        let db = setup().await;
        let id = insert(&db, "Alice", false).await;
        assert!(id >= 1);
        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.name, "Alice");
        assert!(!back.is_admin);
        assert_eq!(back.id, id);
    }

    #[tokio::test]
    async fn find_missing_returns_none() {
        let db = setup().await;
        let missing = User::find(&db, 42).await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn all_returns_every_row() {
        let db = setup().await;
        for (name, is_admin) in [("a", false), ("b", true), ("c", false)] {
            insert(&db, name, is_admin).await;
        }
        let rows = User::all(&db).await.unwrap();
        assert_eq!(rows.len(), 3);
        let names: Vec<&str> = rows.iter().map(|u| u.name.as_str()).collect();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn update_modifies_row_in_place() {
        let db = setup().await;
        let id = insert(&db, "old", false).await;
        let updated = User {
            id,
            name: "new".into(),
            is_admin: true,
        };
        updated.update(&db).await.unwrap();
        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.name, "new");
        assert!(back.is_admin);
    }

    #[tokio::test]
    async fn delete_removes_row() {
        let db = setup().await;
        let id = insert(&db, "x", false).await;
        assert!(User::find(&db, id).await.unwrap().is_some());
        User::delete(&db, id).await.unwrap();
        assert!(User::find(&db, id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn row_getters_handle_wrong_type_gracefully() {
        let db = setup().await;
        insert(&db, "a", false).await;
        let row = sqlx::query("SELECT id, name, is_admin FROM users LIMIT 1")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let wrapped = Row::new(&row);
        assert!(wrapped.get_i64("id").is_ok());
        assert!(wrapped.get_string("nonexistent_column").is_err());
    }
}