Skip to main content

wae_database/orm/
repository.rs

//! Repository implementation module.
//!
//! Provides wrappers around CRUD operations.
4
5use super::{
6    builder::{DeleteBuilder, InsertBuilder, QueryBuilder, SelectBuilder, UpdateBuilder},
7    condition::Condition,
8    entity::{Entity, FromRow, ToRow},
9};
10use crate::{DatabaseResult, connection::DatabaseConnection};
11use std::marker::PhantomData;
12
13#[cfg(feature = "turso")]
14use crate::{DatabaseError, connection::DatabaseRows};
15#[cfg(feature = "turso")]
16use turso::Value as TursoValue;
17
18/// 仓储 trait - 提供 CRUD 操作
19#[allow(async_fn_in_trait)]
20pub trait Repository<E: Entity + FromRow>: Send + Sync {
21    /// 根据主键查找
22    async fn find_by_id(&self, id: E::Id) -> DatabaseResult<Option<E>>;
23
24    /// 查找所有
25    async fn find_all(&self) -> DatabaseResult<Vec<E>>;
26
27    /// 根据条件查找
28    async fn find_by(&self, condition: Condition) -> DatabaseResult<Vec<E>>;
29
30    /// 根据条件查找一条
31    async fn find_one(&self, condition: Condition) -> DatabaseResult<Option<E>>;
32
33    /// 插入实体
34    async fn insert(&self, entity: &E) -> DatabaseResult<i64>
35    where
36        E: ToRow;
37
38    /// 更新实体
39    async fn update(&self, entity: &E) -> DatabaseResult<u64>
40    where
41        E: ToRow;
42
43    /// 根据主键删除
44    async fn delete_by_id(&self, id: E::Id) -> DatabaseResult<u64>;
45
46    /// 根据条件删除
47    async fn delete_by(&self, condition: Condition) -> DatabaseResult<u64>;
48
49    /// 统计数量
50    async fn count(&self) -> DatabaseResult<u64>;
51
52    /// 根据条件统计
53    async fn count_by(&self, condition: Condition) -> DatabaseResult<u64>;
54
55    /// 检查是否存在
56    async fn exists(&self, id: E::Id) -> DatabaseResult<bool>;
57}
58
/// Repository implementation backed by a database connection.
///
/// Only compiled when the `turso` feature is enabled.
#[cfg(feature = "turso")]
pub struct DbRepository<E>
where
    E: Entity + FromRow,
{
    /// Boxed connection that every statement is sent through.
    conn: Box<dyn DatabaseConnection>,
    /// Ties the repository to entity type `E` without storing a value of it.
    _marker: PhantomData<E>,
}
65
#[cfg(feature = "turso")]
impl<E> DbRepository<E>
where
    E: Entity + FromRow,
{
    /// Create a repository that runs its statements on `conn`.
    pub fn new(conn: Box<dyn DatabaseConnection>) -> Self {
        let _marker = PhantomData;
        Self { conn, _marker }
    }

    /// Get a SELECT builder for `E`.
    pub fn select(&self) -> SelectBuilder<E> {
        QueryBuilder::select()
    }

    /// Get an INSERT builder for `E`.
    pub fn insert_builder(&self) -> InsertBuilder<E> {
        QueryBuilder::insert()
    }

    /// Get an UPDATE builder for `E`.
    pub fn update_builder(&self) -> UpdateBuilder<E> {
        QueryBuilder::update()
    }

    /// Get a DELETE builder for `E`.
    pub fn delete_builder(&self) -> DeleteBuilder<E> {
        QueryBuilder::delete()
    }

    /// Run `sql` with `params` on the connection and parse every returned row into `E`.
    async fn query_and_parse(&self, sql: &str, params: Vec<TursoValue>) -> DatabaseResult<Vec<E>> {
        let fetched = self.conn.query_with_turso(sql, params).await?;
        self.parse_rows(fetched).await
    }

    /// Drain `rows`, converting each row with `E::from_row`.
    ///
    /// Stops at the first error from either fetching or converting a row.
    async fn parse_rows(&self, mut rows: DatabaseRows) -> DatabaseResult<Vec<E>> {
        let mut parsed = Vec::new();
        loop {
            match rows.next().await? {
                Some(row) => parsed.push(E::from_row(&row)?),
                None => break Ok(parsed),
            }
        }
    }
}
108
#[cfg(feature = "turso")]
impl<E> Repository<E> for DbRepository<E>
where
    E: Entity + FromRow,
{
    /// Look up the row whose id column equals `id`; delegates to `find_one`.
    async fn find_by_id(&self, id: E::Id) -> DatabaseResult<Option<E>> {
        self.find_one(Condition::eq(E::id_column(), id)).await
    }

    /// Run an unfiltered SELECT built by `QueryBuilder` and parse every row.
    async fn find_all(&self) -> DatabaseResult<Vec<E>> {
        let builder = QueryBuilder::select::<E>();
        let (sql, params) = builder.build_turso();
        self.query_and_parse(&sql, params).await
    }

    /// Run a SELECT filtered by `condition` and parse every row.
    async fn find_by(&self, condition: Condition) -> DatabaseResult<Vec<E>> {
        let builder = QueryBuilder::select::<E>().where_(condition);
        let (sql, params) = builder.build_turso();
        self.query_and_parse(&sql, params).await
    }

    /// Run a SELECT filtered by `condition` with a limit of 1.
    ///
    /// Returns `None` when the query produced no rows.
    async fn find_one(&self, condition: Condition) -> DatabaseResult<Option<E>> {
        let builder = QueryBuilder::select::<E>().where_(condition).limit(1);
        let (sql, params) = builder.build_turso();
        self.query_and_parse(&sql, params).await.map(|mut found| found.pop())
    }

    /// Insert `entity`, then read back `last_insert_rowid()` on the same connection object.
    ///
    /// Fails with an internal error if the follow-up query yields no row.
    async fn insert(&self, entity: &E) -> DatabaseResult<i64>
    where
        E: ToRow,
    {
        let (sql, params) = InsertBuilder::<E>::from_entity(entity).build_turso();
        self.conn.execute_with_turso(&sql, params).await?;

        let mut id_rows = self.conn.query("SELECT last_insert_rowid()").await?;
        match id_rows.next().await? {
            Some(row) => row.get::<i64>(0),
            None => Err(DatabaseError::internal("Failed to get last insert id")),
        }
    }

    /// Update the row identified by `entity.id()` with the values of `entity`.
    async fn update(&self, entity: &E) -> DatabaseResult<u64>
    where
        E: ToRow,
    {
        let builder = UpdateBuilder::<E>::from_entity(entity).where_id(entity.id());
        let (sql, params) = builder.build_turso();
        self.conn.execute_with_turso(&sql, params).await
    }

    /// Delete the row whose id equals `id`.
    async fn delete_by_id(&self, id: E::Id) -> DatabaseResult<u64> {
        let builder = QueryBuilder::delete::<E>().where_id(id);
        let (sql, params) = builder.build_turso();
        self.conn.execute_with_turso(&sql, params).await
    }

    /// Delete the rows matching `condition`.
    async fn delete_by(&self, condition: Condition) -> DatabaseResult<u64> {
        let builder = QueryBuilder::delete::<E>().where_(condition);
        let (sql, params) = builder.build_turso();
        self.conn.execute_with_turso(&sql, params).await
    }

    /// Count every row of `E`'s table; yields 0 if the query returns no row.
    async fn count(&self) -> DatabaseResult<u64> {
        let sql = format!("SELECT COUNT(*) FROM {}", E::table_name());
        let mut count_rows = self.conn.query(&sql).await?;
        match count_rows.next().await? {
            // The count is read as `i64` and cast to the trait's `u64`.
            Some(row) => row.get::<i64>(0).map(|n| n as u64),
            None => Ok(0),
        }
    }

    /// Count the rows matching `condition`; yields 0 if the query returns no row.
    async fn count_by(&self, condition: Condition) -> DatabaseResult<u64> {
        let (cond_sql, params) = condition.build_turso();
        let sql = format!("SELECT COUNT(*) FROM {} WHERE {}", E::table_name(), cond_sql);
        let mut count_rows = self.conn.query_with_turso(&sql, params).await?;
        match count_rows.next().await? {
            Some(row) => row.get::<i64>(0).map(|n| n as u64),
            None => Ok(0),
        }
    }

    /// Report whether at least one row has the given id, via `count_by`.
    async fn exists(&self, id: E::Id) -> DatabaseResult<bool> {
        let by_id = Condition::eq(E::id_column(), id);
        self.count_by(by_id).await.map(|count| count > 0)
    }
}
187
/// Repository implementation backed by a MySQL database connection.
///
/// Only compiled when the `mysql` feature is enabled.
#[cfg(feature = "mysql")]
pub struct MySqlDbRepository<E>
where
    E: Entity + FromRow,
{
    /// Boxed connection that every statement is sent through.
    conn: Box<dyn DatabaseConnection>,
    /// Ties the repository to entity type `E` without storing a value of it.
    _marker: PhantomData<E>,
}
194
#[cfg(feature = "mysql")]
impl<E: Entity + FromRow> MySqlDbRepository<E> {
    /// Create a repository that runs its statements on `conn`.
    pub fn new(conn: Box<dyn DatabaseConnection>) -> Self {
        Self { conn, _marker: PhantomData }
    }

    /// Get a SELECT builder for `E`.
    pub fn select(&self) -> SelectBuilder<E> {
        QueryBuilder::select()
    }

    /// Get an INSERT builder for `E`.
    pub fn insert_builder(&self) -> InsertBuilder<E> {
        QueryBuilder::insert()
    }

    /// Get an UPDATE builder for `E`.
    pub fn update_builder(&self) -> UpdateBuilder<E> {
        QueryBuilder::update()
    }

    /// Get a DELETE builder for `E`.
    pub fn delete_builder(&self) -> DeleteBuilder<E> {
        QueryBuilder::delete()
    }

    /// Drain `rows`, converting each row with `E::from_row`.
    ///
    /// Stops at the first error from either fetching or converting a row.
    // `DatabaseRows` is spelled with its full path because the file-level
    // `use` that imports it is gated on the `turso` feature; relying on it
    // would leave the name unresolved in a `mysql`-only build.
    async fn parse_rows(&self, mut rows: crate::connection::DatabaseRows) -> DatabaseResult<Vec<E>> {
        let mut entities = Vec::new();
        while let Some(row) = rows.next().await? {
            entities.push(E::from_row(&row)?);
        }
        Ok(entities)
    }
}
231
// Throughout this impl the builder-produced parameter list is consumed with
// `into_iter()` when converting to `wae_types::Value`: the list is not used
// again afterwards, so cloning each value first is unnecessary.
#[cfg(feature = "mysql")]
impl<E: Entity + FromRow> Repository<E> for MySqlDbRepository<E> {
    /// Look up the row whose id column equals `id`; delegates to `find_one`.
    async fn find_by_id(&self, id: E::Id) -> DatabaseResult<Option<E>> {
        let condition = Condition::eq(E::id_column(), id);
        self.find_one(condition).await
    }

    /// Run `SELECT *` on `E`'s table and parse every row.
    async fn find_all(&self) -> DatabaseResult<Vec<E>> {
        let sql = format!("SELECT * FROM {}", E::table_name());
        let rows = self.conn.query(&sql).await?;
        self.parse_rows(rows).await
    }

    /// Run `SELECT *` filtered by `condition` and parse every row.
    async fn find_by(&self, condition: Condition) -> DatabaseResult<Vec<E>> {
        let (cond_sql, params) = condition.build_mysql();
        let sql = format!("SELECT * FROM {} WHERE {}", E::table_name(), cond_sql);
        let wae_params: Vec<wae_types::Value> = params.into_iter().map(crate::types::mysql_value_to_wae).collect();
        let rows = self.conn.query_with(&sql, wae_params).await?;
        self.parse_rows(rows).await
    }

    /// Run `SELECT * … LIMIT 1` filtered by `condition`.
    ///
    /// Returns `None` when the query produced no rows.
    async fn find_one(&self, condition: Condition) -> DatabaseResult<Option<E>> {
        let (cond_sql, params) = condition.build_mysql();
        let sql = format!("SELECT * FROM {} WHERE {} LIMIT 1", E::table_name(), cond_sql);
        let wae_params: Vec<wae_types::Value> = params.into_iter().map(crate::types::mysql_value_to_wae).collect();
        let rows = self.conn.query_with(&sql, wae_params).await?;
        let mut entities = self.parse_rows(rows).await?;
        Ok(entities.pop())
    }

    /// Insert `entity`, then read back `LAST_INSERT_ID()` through the same connection object.
    ///
    /// Fails with an internal error if the follow-up query yields no row.
    async fn insert(&self, entity: &E) -> DatabaseResult<i64>
    where
        E: ToRow,
    {
        let (sql, params) = InsertBuilder::<E>::from_entity(entity).build_mysql();
        let wae_params: Vec<wae_types::Value> = params.into_iter().map(crate::types::mysql_value_to_wae).collect();
        self.conn.execute_with(&sql, wae_params).await?;
        let mut rows = self.conn.query("SELECT LAST_INSERT_ID()").await?;
        match rows.next().await? {
            Some(row) => row.get::<i64>(0),
            None => Err(crate::DatabaseError::internal("Failed to get last insert id")),
        }
    }

    /// Update the row identified by `entity.id()` with the values of `entity`.
    async fn update(&self, entity: &E) -> DatabaseResult<u64>
    where
        E: ToRow,
    {
        let id = entity.id();
        let (sql, params) = UpdateBuilder::<E>::from_entity(entity).where_id(id).build_mysql();
        let wae_params: Vec<wae_types::Value> = params.into_iter().map(crate::types::mysql_value_to_wae).collect();
        self.conn.execute_with(&sql, wae_params).await
    }

    /// Delete the row whose id equals `id`.
    async fn delete_by_id(&self, id: E::Id) -> DatabaseResult<u64> {
        let (sql, params) = QueryBuilder::delete::<E>().where_id(id).build_mysql();
        let wae_params: Vec<wae_types::Value> = params.into_iter().map(crate::types::mysql_value_to_wae).collect();
        self.conn.execute_with(&sql, wae_params).await
    }

    /// Delete the rows matching `condition`.
    async fn delete_by(&self, condition: Condition) -> DatabaseResult<u64> {
        let (sql, params) = QueryBuilder::delete::<E>().where_(condition).build_mysql();
        let wae_params: Vec<wae_types::Value> = params.into_iter().map(crate::types::mysql_value_to_wae).collect();
        self.conn.execute_with(&sql, wae_params).await
    }

    /// Count every row of `E`'s table; yields 0 if the query returns no row.
    async fn count(&self) -> DatabaseResult<u64> {
        let sql = format!("SELECT COUNT(*) FROM {}", E::table_name());
        let mut rows = self.conn.query(&sql).await?;
        match rows.next().await? {
            // The count is read as `i64` and cast to the trait's `u64`.
            Some(row) => row.get::<i64>(0).map(|n| n as u64),
            None => Ok(0),
        }
    }

    /// Count the rows matching `condition`; yields 0 if the query returns no row.
    async fn count_by(&self, condition: Condition) -> DatabaseResult<u64> {
        let (cond_sql, params) = condition.build_mysql();
        let sql = format!("SELECT COUNT(*) FROM {} WHERE {}", E::table_name(), cond_sql);
        let wae_params: Vec<wae_types::Value> = params.into_iter().map(crate::types::mysql_value_to_wae).collect();
        let mut rows = self.conn.query_with(&sql, wae_params).await?;
        match rows.next().await? {
            Some(row) => row.get::<i64>(0).map(|n| n as u64),
            None => Ok(0),
        }
    }

    /// Report whether at least one row has the given id, via `count_by`.
    async fn exists(&self, id: E::Id) -> DatabaseResult<bool> {
        let count = self.count_by(Condition::eq(E::id_column(), id)).await?;
        Ok(count > 0)
    }
}