systemprompt_agent/repository/content/
push_notification.rs1use anyhow::Result;
2use chrono::Utc;
3use sqlx::PgPool;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_identifiers::{ConfigId, TaskId};
7
8use crate::models::a2a::protocol::PushNotificationConfig;
9use crate::models::database_rows::PushNotificationConfigRow;
10
11pub struct PushNotificationConfigRepository {
12 pool: Arc<PgPool>,
13 write_pool: Arc<PgPool>,
14}
15
16impl std::fmt::Debug for PushNotificationConfigRepository {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 f.debug_struct("PushNotificationConfigRepository")
19 .field("pool", &"<PgPool>")
20 .finish()
21 }
22}
23
24impl PushNotificationConfigRepository {
25 pub fn new(db: &DbPool) -> Result<Self> {
26 let pool = db.pool_arc()?;
27 let write_pool = db.write_pool_arc()?;
28 Ok(Self { pool, write_pool })
29 }
30
31 pub async fn add_config(
32 &self,
33 task_id: &TaskId,
34 config: &PushNotificationConfig,
35 ) -> Result<String> {
36 let config_id = uuid::Uuid::new_v4().to_string();
37 let headers_json = config
38 .headers
39 .as_ref()
40 .map(serde_json::to_value)
41 .transpose()?;
42 let auth_json = config
43 .authentication
44 .as_ref()
45 .map(serde_json::to_value)
46 .transpose()?;
47 let now = Utc::now();
48 let task_id_str = task_id.as_str();
49
50 sqlx::query!(
51 r#"INSERT INTO task_push_notification_configs
52 (id, task_id, url, endpoint, token, headers, authentication, created_at, updated_at)
53 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"#,
54 config_id,
55 task_id_str,
56 config.url,
57 config.endpoint,
58 config.token,
59 headers_json,
60 auth_json,
61 now,
62 now
63 )
64 .execute(&*self.write_pool)
65 .await?;
66
67 Ok(config_id)
68 }
69
70 pub async fn get_config(
71 &self,
72 task_id: &TaskId,
73 config_id: &ConfigId,
74 ) -> Result<Option<PushNotificationConfig>> {
75 let task_id_str = task_id.as_str();
76 let config_id_str = config_id.as_str();
77 let row = sqlx::query_as!(
78 PushNotificationConfigRow,
79 r#"SELECT
80 id,
81 task_id,
82 url,
83 endpoint,
84 token,
85 headers,
86 authentication,
87 created_at,
88 updated_at
89 FROM task_push_notification_configs
90 WHERE task_id = $1 AND id = $2"#,
91 task_id_str,
92 config_id_str
93 )
94 .fetch_optional(&*self.pool)
95 .await?;
96
97 row.map(|r| Self::row_to_config(&r)).transpose()
98 }
99
100 pub async fn list_configs(&self, task_id: &TaskId) -> Result<Vec<PushNotificationConfig>> {
101 let task_id_str = task_id.as_str();
102 let rows: Vec<PushNotificationConfigRow> = sqlx::query_as!(
103 PushNotificationConfigRow,
104 r#"SELECT
105 id,
106 task_id,
107 url,
108 endpoint,
109 token,
110 headers,
111 authentication,
112 created_at,
113 updated_at
114 FROM task_push_notification_configs
115 WHERE task_id = $1"#,
116 task_id_str
117 )
118 .fetch_all(&*self.pool)
119 .await?;
120
121 rows.iter()
122 .map(|r| Self::row_to_config(r))
123 .collect::<Result<Vec<_>>>()
124 }
125
126 pub async fn delete_config(&self, task_id: &TaskId, config_id: &ConfigId) -> Result<bool> {
127 let task_id_str = task_id.as_str();
128 let config_id_str = config_id.as_str();
129 let result = sqlx::query!(
130 "DELETE FROM task_push_notification_configs WHERE task_id = $1 AND id = $2",
131 task_id_str,
132 config_id_str
133 )
134 .execute(&*self.write_pool)
135 .await?;
136
137 Ok(result.rows_affected() > 0)
138 }
139
140 pub async fn delete_all_for_task(&self, task_id: &TaskId) -> Result<u64> {
141 let task_id_str = task_id.as_str();
142 let result = sqlx::query!(
143 "DELETE FROM task_push_notification_configs WHERE task_id = $1",
144 task_id_str
145 )
146 .execute(&*self.write_pool)
147 .await?;
148
149 Ok(result.rows_affected())
150 }
151
152 fn row_to_config(row: &PushNotificationConfigRow) -> Result<PushNotificationConfig> {
153 let headers = row
154 .headers
155 .as_ref()
156 .map(|v| serde_json::from_value(v.clone()))
157 .transpose()?;
158 let authentication = row
159 .authentication
160 .as_ref()
161 .map(|v| serde_json::from_value(v.clone()))
162 .transpose()?;
163
164 Ok(PushNotificationConfig {
165 url: row.url.clone(),
166 endpoint: row.endpoint.clone(),
167 token: row.token.clone(),
168 headers,
169 authentication,
170 })
171 }
172}