Skip to main content

systemprompt_agent/repository/content/push_notification.rs

use anyhow::{Context, Result};
use chrono::Utc;
use sqlx::PgPool;
use std::sync::Arc;
use systemprompt_database::DbPool;
use systemprompt_identifiers::{ConfigId, TaskId};

use crate::models::a2a::protocol::PushNotificationConfig;
use crate::models::database_rows::PushNotificationConfigRow;
10
11pub struct PushNotificationConfigRepository {
12    pool: Arc<PgPool>,
13}
14
15impl std::fmt::Debug for PushNotificationConfigRepository {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        f.debug_struct("PushNotificationConfigRepository")
18            .field("pool", &"<PgPool>")
19            .finish()
20    }
21}
22
23impl PushNotificationConfigRepository {
24    pub fn new(db: &DbPool) -> Result<Self> {
25        let pool = db.pool_arc()?;
26        Ok(Self { pool })
27    }
28
29    pub async fn add_config(
30        &self,
31        task_id: &TaskId,
32        config: &PushNotificationConfig,
33    ) -> Result<String> {
34        let config_id = uuid::Uuid::new_v4().to_string();
35        let headers_json = config
36            .headers
37            .as_ref()
38            .map(serde_json::to_value)
39            .transpose()?;
40        let auth_json = config
41            .authentication
42            .as_ref()
43            .map(serde_json::to_value)
44            .transpose()?;
45        let now = Utc::now();
46        let task_id_str = task_id.as_str();
47
48        sqlx::query!(
49            r#"INSERT INTO task_push_notification_configs
50                (id, task_id, url, endpoint, token, headers, authentication, created_at, updated_at)
51            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"#,
52            config_id,
53            task_id_str,
54            config.url,
55            config.endpoint,
56            config.token,
57            headers_json,
58            auth_json,
59            now,
60            now
61        )
62        .execute(&*self.pool)
63        .await?;
64
65        Ok(config_id)
66    }
67
68    pub async fn get_config(
69        &self,
70        task_id: &TaskId,
71        config_id: &ConfigId,
72    ) -> Result<Option<PushNotificationConfig>> {
73        let task_id_str = task_id.as_str();
74        let config_id_str = config_id.as_str();
75        let row = sqlx::query_as!(
76            PushNotificationConfigRow,
77            r#"SELECT
78                id,
79                task_id,
80                url,
81                endpoint,
82                token,
83                headers,
84                authentication,
85                created_at,
86                updated_at
87            FROM task_push_notification_configs
88            WHERE task_id = $1 AND id = $2"#,
89            task_id_str,
90            config_id_str
91        )
92        .fetch_optional(&*self.pool)
93        .await?;
94
95        row.map(|r| Self::row_to_config(&r)).transpose()
96    }
97
98    pub async fn list_configs(&self, task_id: &TaskId) -> Result<Vec<PushNotificationConfig>> {
99        let task_id_str = task_id.as_str();
100        let rows: Vec<PushNotificationConfigRow> = sqlx::query_as!(
101            PushNotificationConfigRow,
102            r#"SELECT
103                id,
104                task_id,
105                url,
106                endpoint,
107                token,
108                headers,
109                authentication,
110                created_at,
111                updated_at
112            FROM task_push_notification_configs
113            WHERE task_id = $1"#,
114            task_id_str
115        )
116        .fetch_all(&*self.pool)
117        .await?;
118
119        rows.iter()
120            .map(|r| Self::row_to_config(r))
121            .collect::<Result<Vec<_>>>()
122    }
123
124    pub async fn delete_config(&self, task_id: &TaskId, config_id: &ConfigId) -> Result<bool> {
125        let task_id_str = task_id.as_str();
126        let config_id_str = config_id.as_str();
127        let result = sqlx::query!(
128            "DELETE FROM task_push_notification_configs WHERE task_id = $1 AND id = $2",
129            task_id_str,
130            config_id_str
131        )
132        .execute(&*self.pool)
133        .await?;
134
135        Ok(result.rows_affected() > 0)
136    }
137
138    pub async fn delete_all_for_task(&self, task_id: &TaskId) -> Result<u64> {
139        let task_id_str = task_id.as_str();
140        let result = sqlx::query!(
141            "DELETE FROM task_push_notification_configs WHERE task_id = $1",
142            task_id_str
143        )
144        .execute(&*self.pool)
145        .await?;
146
147        Ok(result.rows_affected())
148    }
149
150    fn row_to_config(row: &PushNotificationConfigRow) -> Result<PushNotificationConfig> {
151        let headers = row
152            .headers
153            .as_ref()
154            .map(|v| serde_json::from_value(v.clone()))
155            .transpose()?;
156        let authentication = row
157            .authentication
158            .as_ref()
159            .map(|v| serde_json::from_value(v.clone()))
160            .transpose()?;
161
162        Ok(PushNotificationConfig {
163            url: row.url.clone(),
164            endpoint: row.endpoint.clone(),
165            token: row.token.clone(),
166            headers,
167            authentication,
168        })
169    }
170}