Skip to main content

systemprompt_agent/repository/content/
push_notification.rs

1use anyhow::Result;
2use chrono::Utc;
3use sqlx::PgPool;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_identifiers::{ConfigId, TaskId};
7
8use crate::models::a2a::protocol::PushNotificationConfig;
9use crate::models::database_rows::PushNotificationConfigRow;
10
11pub struct PushNotificationConfigRepository {
12    pool: Arc<PgPool>,
13    write_pool: Arc<PgPool>,
14}
15
16impl std::fmt::Debug for PushNotificationConfigRepository {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        f.debug_struct("PushNotificationConfigRepository")
19            .field("pool", &"<PgPool>")
20            .finish()
21    }
22}
23
24impl PushNotificationConfigRepository {
25    pub fn new(db: &DbPool) -> Result<Self> {
26        let pool = db.pool_arc()?;
27        let write_pool = db.write_pool_arc()?;
28        Ok(Self { pool, write_pool })
29    }
30
31    pub async fn add_config(
32        &self,
33        task_id: &TaskId,
34        config: &PushNotificationConfig,
35    ) -> Result<String> {
36        let config_id = uuid::Uuid::new_v4().to_string();
37        let headers_json = config
38            .headers
39            .as_ref()
40            .map(serde_json::to_value)
41            .transpose()?;
42        let auth_json = config
43            .authentication
44            .as_ref()
45            .map(serde_json::to_value)
46            .transpose()?;
47        let now = Utc::now();
48        let task_id_str = task_id.as_str();
49
50        sqlx::query!(
51            r#"INSERT INTO task_push_notification_configs
52                (id, task_id, url, endpoint, token, headers, authentication, created_at, updated_at)
53            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"#,
54            config_id,
55            task_id_str,
56            config.url,
57            config.endpoint,
58            config.token,
59            headers_json,
60            auth_json,
61            now,
62            now
63        )
64        .execute(&*self.write_pool)
65        .await?;
66
67        Ok(config_id)
68    }
69
70    pub async fn get_config(
71        &self,
72        task_id: &TaskId,
73        config_id: &ConfigId,
74    ) -> Result<Option<PushNotificationConfig>> {
75        let task_id_str = task_id.as_str();
76        let config_id_str = config_id.as_str();
77        let row = sqlx::query_as!(
78            PushNotificationConfigRow,
79            r#"SELECT
80                id,
81                task_id,
82                url,
83                endpoint,
84                token,
85                headers,
86                authentication,
87                created_at,
88                updated_at
89            FROM task_push_notification_configs
90            WHERE task_id = $1 AND id = $2"#,
91            task_id_str,
92            config_id_str
93        )
94        .fetch_optional(&*self.pool)
95        .await?;
96
97        row.map(|r| Self::row_to_config(&r)).transpose()
98    }
99
100    pub async fn list_configs(&self, task_id: &TaskId) -> Result<Vec<PushNotificationConfig>> {
101        let task_id_str = task_id.as_str();
102        let rows: Vec<PushNotificationConfigRow> = sqlx::query_as!(
103            PushNotificationConfigRow,
104            r#"SELECT
105                id,
106                task_id,
107                url,
108                endpoint,
109                token,
110                headers,
111                authentication,
112                created_at,
113                updated_at
114            FROM task_push_notification_configs
115            WHERE task_id = $1"#,
116            task_id_str
117        )
118        .fetch_all(&*self.pool)
119        .await?;
120
121        rows.iter()
122            .map(|r| Self::row_to_config(r))
123            .collect::<Result<Vec<_>>>()
124    }
125
126    pub async fn delete_config(&self, task_id: &TaskId, config_id: &ConfigId) -> Result<bool> {
127        let task_id_str = task_id.as_str();
128        let config_id_str = config_id.as_str();
129        let result = sqlx::query!(
130            "DELETE FROM task_push_notification_configs WHERE task_id = $1 AND id = $2",
131            task_id_str,
132            config_id_str
133        )
134        .execute(&*self.write_pool)
135        .await?;
136
137        Ok(result.rows_affected() > 0)
138    }
139
140    pub async fn delete_all_for_task(&self, task_id: &TaskId) -> Result<u64> {
141        let task_id_str = task_id.as_str();
142        let result = sqlx::query!(
143            "DELETE FROM task_push_notification_configs WHERE task_id = $1",
144            task_id_str
145        )
146        .execute(&*self.write_pool)
147        .await?;
148
149        Ok(result.rows_affected())
150    }
151
152    fn row_to_config(row: &PushNotificationConfigRow) -> Result<PushNotificationConfig> {
153        let headers = row
154            .headers
155            .as_ref()
156            .map(|v| serde_json::from_value(v.clone()))
157            .transpose()?;
158        let authentication = row
159            .authentication
160            .as_ref()
161            .map(|v| serde_json::from_value(v.clone()))
162            .transpose()?;
163
164        Ok(PushNotificationConfig {
165            url: row.url.clone(),
166            endpoint: row.endpoint.clone(),
167            token: row.token.clone(),
168            headers,
169            authentication,
170        })
171    }
172}