1use sqlx::Row as _;
2use sqlx::sqlite::{SqlitePool, SqlitePoolOptions, SqliteRow};
3
4use crate::error::Error;
5
6#[derive(Clone)]
7pub struct Db {
8 pool: SqlitePool,
9}
10
11impl Db {
12 pub async fn connect(url: &str) -> Result<Self, Error> {
13 let pool = SqlitePool::connect(url).await?;
14 Ok(Self { pool })
15 }
16
17 pub async fn memory() -> Result<Self, Error> {
18 let pool = SqlitePoolOptions::new()
19 .max_connections(1)
20 .connect("sqlite::memory:")
21 .await?;
22 Ok(Self { pool })
23 }
24
25 pub async fn execute(&self, sql: &str) -> Result<(), Error> {
26 sqlx::query(sql).execute(&self.pool).await?;
27 Ok(())
28 }
29
30 pub(crate) fn pool(&self) -> &SqlitePool {
31 &self.pool
32 }
33}
34
35pub struct Row<'a> {
36 inner: &'a SqliteRow,
37}
38
39impl<'a> Row<'a> {
40 pub(crate) fn new(inner: &'a SqliteRow) -> Self {
41 Self { inner }
42 }
43
44 pub fn get_i32(&self, name: &str) -> Result<i32, Error> {
45 self.inner.try_get(name).map_err(Error::from)
46 }
47
48 pub fn get_i64(&self, name: &str) -> Result<i64, Error> {
49 self.inner.try_get(name).map_err(Error::from)
50 }
51
52 pub fn get_string(&self, name: &str) -> Result<String, Error> {
53 self.inner.try_get(name).map_err(Error::from)
54 }
55
56 pub fn get_bool(&self, name: &str) -> Result<bool, Error> {
57 self.inner.try_get(name).map_err(Error::from)
58 }
59}
60
/// A dynamically typed SQL parameter.
///
/// [`Model::insert_values`] returns these, and they are bound positionally into the
/// generated `INSERT` / `UPDATE` statements.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Value {
    /// 32-bit signed integer.
    I32(i32),
    /// 64-bit signed integer.
    I64(i64),
    /// Owned text.
    String(String),
    /// Boolean.
    Bool(bool),
    /// SQL `NULL`.
    Null,
}

impl From<i32> for Value {
    fn from(v: i32) -> Self {
        Value::I32(v)
    }
}

impl From<i64> for Value {
    fn from(v: i64) -> Self {
        Value::I64(v)
    }
}

impl From<String> for Value {
    fn from(v: String) -> Self {
        Value::String(v)
    }
}

impl From<&str> for Value {
    fn from(v: &str) -> Self {
        Value::String(v.to_owned())
    }
}

impl From<bool> for Value {
    fn from(v: bool) -> Self {
        Value::Bool(v)
    }
}

/// Maps `None` to [`Value::Null`] and `Some(x)` to `x.into()`, so nullable fields
/// (`Option<String>`, `Option<i64>`, ...) can be returned from `insert_values` directly.
impl<T: Into<Value>> From<Option<T>> for Value {
    fn from(v: Option<T>) -> Self {
        v.map_or(Value::Null, Into::into)
    }
}
99
100pub trait Model: Sized + Send + Sync + Unpin + 'static {
101 const TABLE: &'static str;
102 const COLUMNS: &'static [&'static str];
103 const INSERT_COLUMNS: &'static [&'static str];
104
105 fn id(&self) -> i64;
106 fn from_row(row: Row<'_>) -> Result<Self, Error>;
107 fn insert_values(&self) -> Vec<Value>;
108
109 fn find(db: &Db, id: i64) -> impl std::future::Future<Output = Result<Option<Self>, Error>> + Send
110 where
111 Self: Send,
112 {
113 async move {
114 let sql = format!(
115 "SELECT {} FROM {} WHERE id = ?",
116 Self::COLUMNS.join(", "),
117 Self::TABLE,
118 );
119 let row = sqlx::query(&sql)
120 .bind(id)
121 .fetch_optional(db.pool())
122 .await?;
123 match row {
124 Some(r) => Ok(Some(Self::from_row(Row::new(&r))?)),
125 None => Ok(None),
126 }
127 }
128 }
129
130 fn all(db: &Db) -> impl std::future::Future<Output = Result<Vec<Self>, Error>> + Send {
131 async move {
132 let sql = format!("SELECT {} FROM {}", Self::COLUMNS.join(", "), Self::TABLE);
133 let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
134 rows.iter()
135 .map(|r| Self::from_row(Row::new(r)))
136 .collect()
137 }
138 }
139
140 fn create<'a>(
141 &'a self,
142 db: &'a Db,
143 ) -> impl std::future::Future<Output = Result<i64, Error>> + Send + 'a {
144 async move {
145 let placeholders = vec!["?"; Self::INSERT_COLUMNS.len()].join(", ");
146 let sql = format!(
147 "INSERT INTO {} ({}) VALUES ({})",
148 Self::TABLE,
149 Self::INSERT_COLUMNS.join(", "),
150 placeholders,
151 );
152 let mut query = sqlx::query(&sql);
153 for v in self.insert_values() {
154 query = bind_value(query, v);
155 }
156 let result = query.execute(db.pool()).await?;
157 Ok(result.last_insert_rowid())
158 }
159 }
160
161 fn update<'a>(
162 &'a self,
163 db: &'a Db,
164 ) -> impl std::future::Future<Output = Result<(), Error>> + Send + 'a {
165 async move {
166 let assignments: Vec<String> = Self::INSERT_COLUMNS
167 .iter()
168 .map(|c| format!("{c} = ?"))
169 .collect();
170 let sql = format!(
171 "UPDATE {} SET {} WHERE id = ?",
172 Self::TABLE,
173 assignments.join(", "),
174 );
175 let mut query = sqlx::query(&sql);
176 for v in self.insert_values() {
177 query = bind_value(query, v);
178 }
179 query = query.bind(self.id());
180 query.execute(db.pool()).await?;
181 Ok(())
182 }
183 }
184
185 fn delete(db: &Db, id: i64) -> impl std::future::Future<Output = Result<(), Error>> + Send {
186 async move {
187 let sql = format!("DELETE FROM {} WHERE id = ?", Self::TABLE);
188 sqlx::query(&sql).bind(id).execute(db.pool()).await?;
189 Ok(())
190 }
191 }
192}
193
194fn bind_value<'q>(
195 query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
196 value: Value,
197) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
198 match value {
199 Value::I32(v) => query.bind(v),
200 Value::I64(v) => query.bind(v),
201 Value::String(v) => query.bind(v),
202 Value::Bool(v) => query.bind(v),
203 Value::Null => query.bind(Option::<i64>::None),
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, PartialEq)]
    struct User {
        id: i64,
        name: String,
        is_admin: bool,
    }

    impl Model for User {
        const TABLE: &'static str = "users";
        const COLUMNS: &'static [&'static str] = &["id", "name", "is_admin"];
        const INSERT_COLUMNS: &'static [&'static str] = &["name", "is_admin"];

        fn id(&self) -> i64 {
            self.id
        }

        fn from_row(row: Row<'_>) -> Result<Self, Error> {
            Ok(Self {
                id: row.get_i64("id")?,
                name: row.get_string("name")?,
                is_admin: row.get_bool("is_admin")?,
            })
        }

        fn insert_values(&self) -> Vec<Value> {
            vec![self.name.clone().into(), self.is_admin.into()]
        }
    }

    /// Fresh in-memory database containing an empty `users` table.
    async fn setup() -> Db {
        let db = Db::memory().await.unwrap();
        db.execute(
            "CREATE TABLE users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                is_admin INTEGER NOT NULL
            )",
        )
        .await
        .unwrap();
        db
    }

    /// Inserts a user through `Model::create` and returns the id it was assigned.
    async fn insert_user(db: &Db, name: &str, is_admin: bool) -> i64 {
        User { id: 0, name: name.into(), is_admin }
            .create(db)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn create_assigns_new_id_and_find_reads_it_back() {
        let db = setup().await;
        let id = insert_user(&db, "Alice", false).await;
        assert!(id >= 1);
        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back, User { id, name: "Alice".into(), is_admin: false });
    }

    #[tokio::test]
    async fn find_missing_returns_none() {
        let db = setup().await;
        assert!(User::find(&db, 42).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn all_returns_every_row() {
        let db = setup().await;
        for (name, is_admin) in [("a", false), ("b", true), ("c", false)] {
            insert_user(&db, name, is_admin).await;
        }
        let rows = User::all(&db).await.unwrap();
        let mut names: Vec<&str> = rows.iter().map(|u| u.name.as_str()).collect();
        // Compare as a set so the assertion does not depend on the order rows come back in.
        names.sort_unstable();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn update_modifies_row_in_place() {
        let db = setup().await;
        let id = insert_user(&db, "old", false).await;
        let updated = User { id, name: "new".into(), is_admin: true };
        updated.update(&db).await.unwrap();
        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back, updated);
    }

    #[tokio::test]
    async fn update_of_missing_id_changes_nothing() {
        let db = setup().await;
        let id = insert_user(&db, "keep", false).await;
        let ghost = User { id: id + 1000, name: "ghost".into(), is_admin: true };
        // No row matches, so this succeeds without touching the table.
        ghost.update(&db).await.unwrap();
        let rows = User::all(&db).await.unwrap();
        assert_eq!(rows, vec![User { id, name: "keep".into(), is_admin: false }]);
    }

    #[tokio::test]
    async fn delete_removes_row() {
        let db = setup().await;
        let id = insert_user(&db, "x", false).await;
        assert!(User::find(&db, id).await.unwrap().is_some());
        User::delete(&db, id).await.unwrap();
        assert!(User::find(&db, id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn row_getters_handle_wrong_type_gracefully() {
        let db = setup().await;
        insert_user(&db, "a", false).await;
        let row = sqlx::query("SELECT id, name, is_admin FROM users LIMIT 1")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let wrapped = Row::new(&row);
        assert!(wrapped.get_i64("id").is_ok());
        assert_eq!(wrapped.get_string("name").unwrap(), "a");
        // A TEXT value requested as an integer is reported as an error, not a panic.
        assert!(wrapped.get_i64("name").is_err());
        // So is a column that is not part of the result set.
        assert!(wrapped.get_string("nonexistent_column").is_err());
    }
}