Skip to main content

systemprompt_database/repository/
entity.rs

1use async_trait::async_trait;
2use sqlx::postgres::PgRow;
3use sqlx::{FromRow, PgPool};
4use std::sync::Arc;
5
/// A string-backed identifier that can key an entity.
///
/// Implementors expose their textual form through [`EntityId::as_str`] and
/// can be rebuilt from an owned string with [`EntityId::from_string`].
pub trait EntityId: Clone + Send + Sync + 'static {
    /// Borrows the textual form of this identifier.
    fn as_str(&self) -> &str;

    /// Builds an identifier from an owned string.
    fn from_string(s: String) -> Self;
}
11
12impl EntityId for String {
13    fn as_str(&self) -> &str {
14        self
15    }
16
17    fn from_string(s: String) -> Self {
18        s
19    }
20}
21
22pub trait Entity: for<'r> FromRow<'r, PgRow> + Send + Sync + Unpin + 'static {
23    type Id: EntityId;
24
25    const TABLE: &'static str;
26
27    const COLUMNS: &'static str;
28
29    const ID_COLUMN: &'static str;
30
31    fn id(&self) -> &Self::Id;
32}
33
34#[derive(Clone)]
35pub struct GenericRepository<E: Entity> {
36    pool: Arc<PgPool>,
37    write_pool: Arc<PgPool>,
38    _phantom: std::marker::PhantomData<E>,
39}
40
41impl<E: Entity> std::fmt::Debug for GenericRepository<E> {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.debug_struct("GenericRepository")
44            .field("table", &E::TABLE)
45            .finish()
46    }
47}
48
49impl<E: Entity> GenericRepository<E> {
50    #[must_use]
51    pub fn new(pool: Arc<PgPool>) -> Self {
52        let write_pool = Arc::clone(&pool);
53        Self {
54            pool,
55            write_pool,
56            _phantom: std::marker::PhantomData,
57        }
58    }
59
60    #[must_use]
61    pub const fn new_with_write_pool(pool: Arc<PgPool>, write_pool: Arc<PgPool>) -> Self {
62        Self {
63            pool,
64            write_pool,
65            _phantom: std::marker::PhantomData,
66        }
67    }
68
69    #[must_use]
70    pub fn pool(&self) -> &PgPool {
71        &self.pool
72    }
73
74    #[must_use]
75    pub fn write_pool(&self) -> &PgPool {
76        &self.write_pool
77    }
78
79    pub async fn get(&self, id: &E::Id) -> Result<Option<E>, sqlx::Error> {
80        let query = format!(
81            "SELECT {} FROM {} WHERE {} = $1",
82            E::COLUMNS,
83            E::TABLE,
84            E::ID_COLUMN
85        );
86        sqlx::query_as::<_, E>(&query)
87            .bind(id.as_str())
88            .fetch_optional(&*self.pool)
89            .await
90    }
91
92    pub async fn list(&self, limit: i64, offset: i64) -> Result<Vec<E>, sqlx::Error> {
93        let query = format!(
94            "SELECT {} FROM {} ORDER BY created_at DESC LIMIT $1 OFFSET $2",
95            E::COLUMNS,
96            E::TABLE
97        );
98        sqlx::query_as::<_, E>(&query)
99            .bind(limit)
100            .bind(offset)
101            .fetch_all(&*self.pool)
102            .await
103    }
104
105    pub async fn list_all(&self) -> Result<Vec<E>, sqlx::Error> {
106        let query = format!(
107            "SELECT {} FROM {} ORDER BY created_at DESC",
108            E::COLUMNS,
109            E::TABLE
110        );
111        sqlx::query_as::<_, E>(&query).fetch_all(&*self.pool).await
112    }
113
114    pub async fn delete(&self, id: &E::Id) -> Result<bool, sqlx::Error> {
115        let query = format!("DELETE FROM {} WHERE {} = $1", E::TABLE, E::ID_COLUMN);
116        let result = sqlx::query(&query)
117            .bind(id.as_str())
118            .execute(&*self.write_pool)
119            .await?;
120        Ok(result.rows_affected() > 0)
121    }
122
123    pub async fn exists(&self, id: &E::Id) -> Result<bool, sqlx::Error> {
124        let query = format!("SELECT 1 FROM {} WHERE {} = $1", E::TABLE, E::ID_COLUMN);
125        let result: Option<(i32,)> = sqlx::query_as(&query)
126            .bind(id.as_str())
127            .fetch_optional(&*self.pool)
128            .await?;
129        Ok(result.is_some())
130    }
131
132    pub async fn count(&self) -> Result<i64, sqlx::Error> {
133        let query = format!("SELECT COUNT(*) FROM {}", E::TABLE);
134        let result: (i64,) = sqlx::query_as(&query).fetch_one(&*self.pool).await?;
135        Ok(result.0)
136    }
137}
138
139#[async_trait]
140pub trait RepositoryExt<E: Entity>: Sized {
141    fn pool(&self) -> &PgPool;
142
143    async fn find_by<T: ToString + Send + Sync>(
144        &self,
145        column: &str,
146        value: T,
147    ) -> Result<Option<E>, sqlx::Error> {
148        if !column.chars().all(|c| c.is_alphanumeric() || c == '_') {
149            return Err(sqlx::Error::Protocol(format!(
150                "Invalid column name: {column}"
151            )));
152        }
153        let query = format!(
154            "SELECT {} FROM {} WHERE {} = $1",
155            E::COLUMNS,
156            E::TABLE,
157            column
158        );
159        sqlx::query_as::<_, E>(&query)
160            .bind(value.to_string())
161            .fetch_optional(self.pool())
162            .await
163    }
164
165    async fn find_all_by<T: ToString + Send + Sync>(
166        &self,
167        column: &str,
168        value: T,
169    ) -> Result<Vec<E>, sqlx::Error> {
170        if !column.chars().all(|c| c.is_alphanumeric() || c == '_') {
171            return Err(sqlx::Error::Protocol(format!(
172                "Invalid column name: {column}"
173            )));
174        }
175        let query = format!(
176            "SELECT {} FROM {} WHERE {} = $1 ORDER BY created_at DESC",
177            E::COLUMNS,
178            E::TABLE,
179            column
180        );
181        sqlx::query_as::<_, E>(&query)
182            .bind(value.to_string())
183            .fetch_all(self.pool())
184            .await
185    }
186}
187
188impl<E: Entity> RepositoryExt<E> for GenericRepository<E> {
189    fn pool(&self) -> &PgPool {
190        &self.pool
191    }
192}