systemprompt_database/repository/
entity.rs1use async_trait::async_trait;
2use sqlx::postgres::PgRow;
3use sqlx::{FromRow, PgPool};
4use std::sync::Arc;
5
/// An entity identifier that the repository can bind as a text parameter
/// and rebuild from an owned `String`.
pub trait EntityId: Send + Sync + Clone + 'static {
    /// Builds an identifier from its owned string form.
    fn from_string(raw: String) -> Self;

    /// Borrows the identifier's string form (used when binding query parameters).
    fn as_str(&self) -> &str;
}
11
12impl EntityId for String {
13 fn as_str(&self) -> &str {
14 self
15 }
16
17 fn from_string(s: String) -> Self {
18 s
19 }
20}
21
22pub trait Entity: for<'r> FromRow<'r, PgRow> + Send + Sync + Unpin + 'static {
23 type Id: EntityId;
24
25 const TABLE: &'static str;
26
27 const COLUMNS: &'static str;
28
29 const ID_COLUMN: &'static str;
30
31 fn id(&self) -> &Self::Id;
32}
33
34#[derive(Clone)]
35pub struct GenericRepository<E: Entity> {
36 pool: Arc<PgPool>,
37 write_pool: Arc<PgPool>,
38 _phantom: std::marker::PhantomData<E>,
39}
40
41impl<E: Entity> std::fmt::Debug for GenericRepository<E> {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.debug_struct("GenericRepository")
44 .field("table", &E::TABLE)
45 .finish()
46 }
47}
48
49impl<E: Entity> GenericRepository<E> {
50 #[must_use]
51 pub fn new(pool: Arc<PgPool>) -> Self {
52 let write_pool = Arc::clone(&pool);
53 Self {
54 pool,
55 write_pool,
56 _phantom: std::marker::PhantomData,
57 }
58 }
59
60 #[must_use]
61 pub const fn new_with_write_pool(pool: Arc<PgPool>, write_pool: Arc<PgPool>) -> Self {
62 Self {
63 pool,
64 write_pool,
65 _phantom: std::marker::PhantomData,
66 }
67 }
68
69 #[must_use]
70 pub fn pool(&self) -> &PgPool {
71 &self.pool
72 }
73
74 #[must_use]
75 pub fn write_pool(&self) -> &PgPool {
76 &self.write_pool
77 }
78
79 pub async fn get(&self, id: &E::Id) -> Result<Option<E>, sqlx::Error> {
80 let query = format!(
81 "SELECT {} FROM {} WHERE {} = $1",
82 E::COLUMNS,
83 E::TABLE,
84 E::ID_COLUMN
85 );
86 sqlx::query_as::<_, E>(&query)
87 .bind(id.as_str())
88 .fetch_optional(&*self.pool)
89 .await
90 }
91
92 pub async fn list(&self, limit: i64, offset: i64) -> Result<Vec<E>, sqlx::Error> {
93 let query = format!(
94 "SELECT {} FROM {} ORDER BY created_at DESC LIMIT $1 OFFSET $2",
95 E::COLUMNS,
96 E::TABLE
97 );
98 sqlx::query_as::<_, E>(&query)
99 .bind(limit)
100 .bind(offset)
101 .fetch_all(&*self.pool)
102 .await
103 }
104
105 pub async fn list_all(&self) -> Result<Vec<E>, sqlx::Error> {
106 let query = format!(
107 "SELECT {} FROM {} ORDER BY created_at DESC",
108 E::COLUMNS,
109 E::TABLE
110 );
111 sqlx::query_as::<_, E>(&query).fetch_all(&*self.pool).await
112 }
113
114 pub async fn delete(&self, id: &E::Id) -> Result<bool, sqlx::Error> {
115 let query = format!("DELETE FROM {} WHERE {} = $1", E::TABLE, E::ID_COLUMN);
116 let result = sqlx::query(&query)
117 .bind(id.as_str())
118 .execute(&*self.write_pool)
119 .await?;
120 Ok(result.rows_affected() > 0)
121 }
122
123 pub async fn exists(&self, id: &E::Id) -> Result<bool, sqlx::Error> {
124 let query = format!("SELECT 1 FROM {} WHERE {} = $1", E::TABLE, E::ID_COLUMN);
125 let result: Option<(i32,)> = sqlx::query_as(&query)
126 .bind(id.as_str())
127 .fetch_optional(&*self.pool)
128 .await?;
129 Ok(result.is_some())
130 }
131
132 pub async fn count(&self) -> Result<i64, sqlx::Error> {
133 let query = format!("SELECT COUNT(*) FROM {}", E::TABLE);
134 let result: (i64,) = sqlx::query_as(&query).fetch_one(&*self.pool).await?;
135 Ok(result.0)
136 }
137}
138
139#[async_trait]
140pub trait RepositoryExt<E: Entity>: Sized {
141 fn pool(&self) -> &PgPool;
142
143 async fn find_by<T: ToString + Send + Sync>(
144 &self,
145 column: &str,
146 value: T,
147 ) -> Result<Option<E>, sqlx::Error> {
148 if !column.chars().all(|c| c.is_alphanumeric() || c == '_') {
149 return Err(sqlx::Error::Protocol(format!(
150 "Invalid column name: {column}"
151 )));
152 }
153 let query = format!(
154 "SELECT {} FROM {} WHERE {} = $1",
155 E::COLUMNS,
156 E::TABLE,
157 column
158 );
159 sqlx::query_as::<_, E>(&query)
160 .bind(value.to_string())
161 .fetch_optional(self.pool())
162 .await
163 }
164
165 async fn find_all_by<T: ToString + Send + Sync>(
166 &self,
167 column: &str,
168 value: T,
169 ) -> Result<Vec<E>, sqlx::Error> {
170 if !column.chars().all(|c| c.is_alphanumeric() || c == '_') {
171 return Err(sqlx::Error::Protocol(format!(
172 "Invalid column name: {column}"
173 )));
174 }
175 let query = format!(
176 "SELECT {} FROM {} WHERE {} = $1 ORDER BY created_at DESC",
177 E::COLUMNS,
178 E::TABLE,
179 column
180 );
181 sqlx::query_as::<_, E>(&query)
182 .bind(value.to_string())
183 .fetch_all(self.pool())
184 .await
185 }
186}
187
188impl<E: Entity> RepositoryExt<E> for GenericRepository<E> {
189 fn pool(&self) -> &PgPool {
190 &self.pool
191 }
192}