Skip to main content

rust_d1_orm/
table.rs

1use crate::{error::OrmError, model::D1Model, query::Query, set::Set};
2use serde::Deserialize;
3use std::marker::PhantomData;
4use worker::D1Database;
5
6pub struct Table<'db, M: D1Model> {
7    db: &'db D1Database,
8    _m: PhantomData<M>,
9}
10
11impl<'db, M: D1Model> Table<'db, M> {
12    pub fn new(db: &'db D1Database) -> Self {
13        Self { db, _m: PhantomData }
14    }
15
16    pub async fn insert(&self, model: &M) -> Result<M, OrmError> {
17        let cols = M::COLUMNS.join(", ");
18        let placeholders = (1..=M::COLUMNS.len()).map(|i| format!("?{}", i)).collect::<Vec<_>>().join(", ");
19        let sql = format!("INSERT INTO {} ({}) VALUES ({}) RETURNING *", M::TABLE, cols, placeholders);
20        self.db.prepare(&sql)
21            .bind(&model.values())
22            .map_err(|_| OrmError::Bind)?
23            .first::<M>(None).await
24            .map_err(|_| OrmError::Execute)?
25            .ok_or(OrmError::Execute)
26    }
27
28    pub async fn insert_batch(&self, models: &[M]) -> Result<Vec<M>, OrmError> {
29        if models.is_empty() { return Ok(vec![]); }
30        let cols = M::COLUMNS.join(", ");
31        let placeholders = (1..=M::COLUMNS.len()).map(|i| format!("?{}", i)).collect::<Vec<_>>().join(", ");
32        let sql = format!("INSERT INTO {} ({}) VALUES ({}) RETURNING *", M::TABLE, cols, placeholders);
33        let stmts = models.iter().map(|m| {
34            self.db.prepare(&sql).bind(&m.values()).map_err(|_| OrmError::Bind)
35        }).collect::<Result<Vec<_>, _>>()?;
36        let results = self.db.batch(stmts).await.map_err(|_| OrmError::Execute)?;
37        results.into_iter().map(|r| {
38            r.results::<M>().map_err(|_| OrmError::Deserialize)?
39                .into_iter().next().ok_or(OrmError::Execute)
40        }).collect()
41    }
42
43    pub async fn find_one(&self, query: Query) -> Result<Option<M>, OrmError> {
44        let cols = M::COLUMNS.join(", ");
45        let (where_parts, values) = query.build_conditions(1);
46        let where_sql = if where_parts.is_empty() { String::new() } else { format!("WHERE {}", where_parts) };
47        let sql = format!("SELECT {} FROM {} {}", cols, M::TABLE, where_sql);
48        self.db.prepare(&sql)
49            .bind(&values)
50            .map_err(|_| OrmError::Bind)?
51            .first::<M>(None).await
52            .map_err(|_| OrmError::Execute)
53    }
54
55    pub async fn find_all(&self, query: Query) -> Result<Vec<M>, OrmError> {
56        let cols = M::COLUMNS.join(", ");
57        let (where_parts, values) = query.build_conditions(1);
58        let where_sql = if where_parts.is_empty() { String::new() } else { format!("WHERE {}", where_parts) };
59        let tail = query.build_tail();
60        let sql = format!("SELECT {} FROM {} {}{}", cols, M::TABLE, where_sql, tail);
61        let result = self.db.prepare(&sql)
62            .bind(&values)
63            .map_err(|_| OrmError::Bind)?
64            .all().await
65            .map_err(|_| OrmError::Execute)?;
66        result.results::<M>().map_err(|_| OrmError::Deserialize)
67    }
68
69    pub async fn update(&self, set: Set, query: Query) -> Result<Option<M>, OrmError> {
70        if set.is_empty() { return Ok(None); }
71        let (set_sql, mut values, next_n) = set.build(1);
72        let (where_parts, where_vals) = query.build_conditions(next_n);
73        values.extend(where_vals);
74        let where_sql = if where_parts.is_empty() { String::new() } else { format!("WHERE {}", where_parts) };
75        let sql = format!("UPDATE {} SET {} {} RETURNING *", M::TABLE, set_sql, where_sql);
76        self.db.prepare(&sql)
77            .bind(&values)
78            .map_err(|_| OrmError::Bind)?
79            .first::<M>(None).await
80            .map_err(|_| OrmError::Execute)
81    }
82
83    pub async fn delete(&self, query: Query) -> Result<(), OrmError> {
84        let (where_parts, values) = query.build_conditions(1);
85        let where_sql = if where_parts.is_empty() { String::new() } else { format!("WHERE {}", where_parts) };
86        let sql = format!("DELETE FROM {} {}", M::TABLE, where_sql);
87        self.db.prepare(&sql)
88            .bind(&values)
89            .map_err(|_| OrmError::Bind)?
90            .run().await
91            .map_err(|_| OrmError::Execute)?;
92        Ok(())
93    }
94
95    pub async fn count(&self, query: Query) -> Result<u64, OrmError> {
96        #[derive(Deserialize)]
97        struct CountRow { count: u64 }
98        let (where_parts, values) = query.build_conditions(1);
99        let where_sql = if where_parts.is_empty() { String::new() } else { format!("WHERE {}", where_parts) };
100        let sql = format!("SELECT COUNT(*) as count FROM {} {}", M::TABLE, where_sql);
101        let row = self.db.prepare(&sql)
102            .bind(&values)
103            .map_err(|_| OrmError::Bind)?
104            .first::<CountRow>(None).await
105            .map_err(|_| OrmError::Execute)?;
106        Ok(row.map(|r| r.count).unwrap_or(0))
107    }
108}