systemprompt_database/repository/
entity.rs1use async_trait::async_trait;
2use sqlx::postgres::PgRow;
3use sqlx::{FromRow, PgPool};
4use std::sync::Arc;
5
/// Identifier type for an [`Entity`].
///
/// Repositories bind ids to queries through their string form, so an id must be
/// constructible from an owned `String` and viewable as a `&str`.
pub trait EntityId: Send + Sync + Clone + 'static {
    /// Builds an id from its owned string representation.
    fn from_string(s: String) -> Self;

    /// Borrows the id as a string slice (used when binding it as a query parameter).
    fn as_str(&self) -> &str;
}
11
12impl EntityId for String {
13 fn as_str(&self) -> &str {
14 self
15 }
16
17 fn from_string(s: String) -> Self {
18 s
19 }
20}
21
22pub trait Entity: for<'r> FromRow<'r, PgRow> + Send + Sync + Unpin + 'static {
23 type Id: EntityId;
24
25 const TABLE: &'static str;
26
27 const COLUMNS: &'static str;
28
29 const ID_COLUMN: &'static str;
30
31 fn id(&self) -> &Self::Id;
32}
33
34#[derive(Clone)]
35pub struct GenericRepository<E: Entity> {
36 pool: Arc<PgPool>,
37 _phantom: std::marker::PhantomData<E>,
38}
39
40impl<E: Entity> std::fmt::Debug for GenericRepository<E> {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 f.debug_struct("GenericRepository")
43 .field("table", &E::TABLE)
44 .finish()
45 }
46}
47
48impl<E: Entity> GenericRepository<E> {
49 #[must_use]
50 pub const fn new(pool: Arc<PgPool>) -> Self {
51 Self {
52 pool,
53 _phantom: std::marker::PhantomData,
54 }
55 }
56
57 #[must_use]
58 pub fn pool(&self) -> &PgPool {
59 &self.pool
60 }
61
62 pub async fn get(&self, id: &E::Id) -> Result<Option<E>, sqlx::Error> {
63 let query = format!(
64 "SELECT {} FROM {} WHERE {} = $1",
65 E::COLUMNS,
66 E::TABLE,
67 E::ID_COLUMN
68 );
69 sqlx::query_as::<_, E>(&query)
70 .bind(id.as_str())
71 .fetch_optional(&*self.pool)
72 .await
73 }
74
75 pub async fn list(&self, limit: i64, offset: i64) -> Result<Vec<E>, sqlx::Error> {
76 let query = format!(
77 "SELECT {} FROM {} ORDER BY created_at DESC LIMIT $1 OFFSET $2",
78 E::COLUMNS,
79 E::TABLE
80 );
81 sqlx::query_as::<_, E>(&query)
82 .bind(limit)
83 .bind(offset)
84 .fetch_all(&*self.pool)
85 .await
86 }
87
88 pub async fn list_all(&self) -> Result<Vec<E>, sqlx::Error> {
89 let query = format!(
90 "SELECT {} FROM {} ORDER BY created_at DESC",
91 E::COLUMNS,
92 E::TABLE
93 );
94 sqlx::query_as::<_, E>(&query).fetch_all(&*self.pool).await
95 }
96
97 pub async fn delete(&self, id: &E::Id) -> Result<bool, sqlx::Error> {
98 let query = format!("DELETE FROM {} WHERE {} = $1", E::TABLE, E::ID_COLUMN);
99 let result = sqlx::query(&query)
100 .bind(id.as_str())
101 .execute(&*self.pool)
102 .await?;
103 Ok(result.rows_affected() > 0)
104 }
105
106 pub async fn exists(&self, id: &E::Id) -> Result<bool, sqlx::Error> {
107 let query = format!("SELECT 1 FROM {} WHERE {} = $1", E::TABLE, E::ID_COLUMN);
108 let result: Option<(i32,)> = sqlx::query_as(&query)
109 .bind(id.as_str())
110 .fetch_optional(&*self.pool)
111 .await?;
112 Ok(result.is_some())
113 }
114
115 pub async fn count(&self) -> Result<i64, sqlx::Error> {
116 let query = format!("SELECT COUNT(*) FROM {}", E::TABLE);
117 let result: (i64,) = sqlx::query_as(&query).fetch_one(&*self.pool).await?;
118 Ok(result.0)
119 }
120}
121
122#[async_trait]
123pub trait RepositoryExt<E: Entity>: Sized {
124 fn pool(&self) -> &PgPool;
125
126 async fn find_by<T: ToString + Send + Sync>(
127 &self,
128 column: &str,
129 value: T,
130 ) -> Result<Option<E>, sqlx::Error> {
131 if !column.chars().all(|c| c.is_alphanumeric() || c == '_') {
132 return Err(sqlx::Error::Protocol(format!(
133 "Invalid column name: {column}"
134 )));
135 }
136 let query = format!(
137 "SELECT {} FROM {} WHERE {} = $1",
138 E::COLUMNS,
139 E::TABLE,
140 column
141 );
142 sqlx::query_as::<_, E>(&query)
143 .bind(value.to_string())
144 .fetch_optional(self.pool())
145 .await
146 }
147
148 async fn find_all_by<T: ToString + Send + Sync>(
149 &self,
150 column: &str,
151 value: T,
152 ) -> Result<Vec<E>, sqlx::Error> {
153 if !column.chars().all(|c| c.is_alphanumeric() || c == '_') {
154 return Err(sqlx::Error::Protocol(format!(
155 "Invalid column name: {column}"
156 )));
157 }
158 let query = format!(
159 "SELECT {} FROM {} WHERE {} = $1 ORDER BY created_at DESC",
160 E::COLUMNS,
161 E::TABLE,
162 column
163 );
164 sqlx::query_as::<_, E>(&query)
165 .bind(value.to_string())
166 .fetch_all(self.pool())
167 .await
168 }
169}
170
171impl<E: Entity> RepositoryExt<E> for GenericRepository<E> {
172 fn pool(&self) -> &PgPool {
173 &self.pool
174 }
175}