Skip to main content

systemprompt_database/repository/
entity.rs

1use async_trait::async_trait;
2use sqlx::postgres::PgRow;
3use sqlx::{FromRow, PgPool};
4use std::sync::Arc;
5
/// An identifier that can be exposed as text and rebuilt from an owned string.
///
/// [`GenericRepository`] binds identifiers to queries through
/// [`EntityId::as_str`], so ids always reach the database as string parameters.
pub trait EntityId: Send + Sync + Clone + 'static {
    /// Borrows the identifier's textual form.
    fn as_str(&self) -> &str;

    /// Builds an identifier from an owned string.
    fn from_string(s: String) -> Self;
}

/// A `String` already is its own textual form, so both conversions are identities.
impl EntityId for String {
    fn as_str(&self) -> &str {
        String::as_str(self)
    }

    fn from_string(s: String) -> Self {
        s
    }
}
21
22pub trait Entity: for<'r> FromRow<'r, PgRow> + Send + Sync + Unpin + 'static {
23    type Id: EntityId;
24
25    const TABLE: &'static str;
26
27    const COLUMNS: &'static str;
28
29    const ID_COLUMN: &'static str;
30
31    fn id(&self) -> &Self::Id;
32}
33
34#[derive(Clone)]
35pub struct GenericRepository<E: Entity> {
36    pool: Arc<PgPool>,
37    _phantom: std::marker::PhantomData<E>,
38}
39
40impl<E: Entity> std::fmt::Debug for GenericRepository<E> {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        f.debug_struct("GenericRepository")
43            .field("table", &E::TABLE)
44            .finish()
45    }
46}
47
48impl<E: Entity> GenericRepository<E> {
49    #[must_use]
50    pub const fn new(pool: Arc<PgPool>) -> Self {
51        Self {
52            pool,
53            _phantom: std::marker::PhantomData,
54        }
55    }
56
57    #[must_use]
58    pub fn pool(&self) -> &PgPool {
59        &self.pool
60    }
61
62    pub async fn get(&self, id: &E::Id) -> Result<Option<E>, sqlx::Error> {
63        let query = format!(
64            "SELECT {} FROM {} WHERE {} = $1",
65            E::COLUMNS,
66            E::TABLE,
67            E::ID_COLUMN
68        );
69        sqlx::query_as::<_, E>(&query)
70            .bind(id.as_str())
71            .fetch_optional(&*self.pool)
72            .await
73    }
74
75    pub async fn list(&self, limit: i64, offset: i64) -> Result<Vec<E>, sqlx::Error> {
76        let query = format!(
77            "SELECT {} FROM {} ORDER BY created_at DESC LIMIT $1 OFFSET $2",
78            E::COLUMNS,
79            E::TABLE
80        );
81        sqlx::query_as::<_, E>(&query)
82            .bind(limit)
83            .bind(offset)
84            .fetch_all(&*self.pool)
85            .await
86    }
87
88    pub async fn list_all(&self) -> Result<Vec<E>, sqlx::Error> {
89        let query = format!(
90            "SELECT {} FROM {} ORDER BY created_at DESC",
91            E::COLUMNS,
92            E::TABLE
93        );
94        sqlx::query_as::<_, E>(&query).fetch_all(&*self.pool).await
95    }
96
97    pub async fn delete(&self, id: &E::Id) -> Result<bool, sqlx::Error> {
98        let query = format!("DELETE FROM {} WHERE {} = $1", E::TABLE, E::ID_COLUMN);
99        let result = sqlx::query(&query)
100            .bind(id.as_str())
101            .execute(&*self.pool)
102            .await?;
103        Ok(result.rows_affected() > 0)
104    }
105
106    pub async fn exists(&self, id: &E::Id) -> Result<bool, sqlx::Error> {
107        let query = format!("SELECT 1 FROM {} WHERE {} = $1", E::TABLE, E::ID_COLUMN);
108        let result: Option<(i32,)> = sqlx::query_as(&query)
109            .bind(id.as_str())
110            .fetch_optional(&*self.pool)
111            .await?;
112        Ok(result.is_some())
113    }
114
115    pub async fn count(&self) -> Result<i64, sqlx::Error> {
116        let query = format!("SELECT COUNT(*) FROM {}", E::TABLE);
117        let result: (i64,) = sqlx::query_as(&query).fetch_one(&*self.pool).await?;
118        Ok(result.0)
119    }
120}
121
122#[async_trait]
123pub trait RepositoryExt<E: Entity>: Sized {
124    fn pool(&self) -> &PgPool;
125
126    async fn find_by<T: ToString + Send + Sync>(
127        &self,
128        column: &str,
129        value: T,
130    ) -> Result<Option<E>, sqlx::Error> {
131        if !column.chars().all(|c| c.is_alphanumeric() || c == '_') {
132            return Err(sqlx::Error::Protocol(format!(
133                "Invalid column name: {column}"
134            )));
135        }
136        let query = format!(
137            "SELECT {} FROM {} WHERE {} = $1",
138            E::COLUMNS,
139            E::TABLE,
140            column
141        );
142        sqlx::query_as::<_, E>(&query)
143            .bind(value.to_string())
144            .fetch_optional(self.pool())
145            .await
146    }
147
148    async fn find_all_by<T: ToString + Send + Sync>(
149        &self,
150        column: &str,
151        value: T,
152    ) -> Result<Vec<E>, sqlx::Error> {
153        if !column.chars().all(|c| c.is_alphanumeric() || c == '_') {
154            return Err(sqlx::Error::Protocol(format!(
155                "Invalid column name: {column}"
156            )));
157        }
158        let query = format!(
159            "SELECT {} FROM {} WHERE {} = $1 ORDER BY created_at DESC",
160            E::COLUMNS,
161            E::TABLE,
162            column
163        );
164        sqlx::query_as::<_, E>(&query)
165            .bind(value.to_string())
166            .fetch_all(self.pool())
167            .await
168    }
169}
170
171impl<E: Entity> RepositoryExt<E> for GenericRepository<E> {
172    fn pool(&self) -> &PgPool {
173        &self.pool
174    }
175}