1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
4use systemprompt_traits::{DynFileUploadProvider, FileUploadInput, RepositoryError};
5
6use crate::models::a2a::Part;
7
/// Borrowed context used when a file part is pushed to the file-upload provider.
///
/// `try_upload_file` reads this struct: the provider is asked `is_enabled()` and
/// then `upload_file(..)`, `context_id` is always attached to the upload input,
/// and the optional identifiers are attached only when they are `Some`.
#[derive(Clone)]
pub struct FileUploadContext<'a> {
    /// Provider that performs the upload; skipped entirely when it reports
    /// itself as not enabled.
    pub upload_provider: &'a DynFileUploadProvider,
    /// Context the uploaded file is associated with (cloned into the upload input).
    pub context_id: &'a ContextId,
    /// Optional user identifier forwarded to the upload input.
    pub user_id: Option<&'a UserId>,
    /// Optional session identifier forwarded to the upload input.
    pub session_id: Option<&'a SessionId>,
    /// Optional trace identifier forwarded to the upload input.
    pub trace_id: Option<&'a TraceId>,
}
16
17impl std::fmt::Debug for FileUploadContext<'_> {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("FileUploadContext")
20 .field("upload_provider", &"<DynFileUploadProvider>")
21 .field("context_id", &self.context_id)
22 .field("user_id", &self.user_id)
23 .field("session_id", &self.session_id)
24 .field("trace_id", &self.trace_id)
25 .finish()
26 }
27}
28
29pub async fn get_message_parts(
30 pool: &Arc<PgPool>,
31 message_id: &MessageId,
32) -> Result<Vec<Part>, RepositoryError> {
33 let part_rows: Vec<crate::models::MessagePart> = sqlx::query_as!(
34 crate::models::MessagePart,
35 r#"SELECT
36 id as "id!",
37 message_id as "message_id!: MessageId",
38 task_id as "task_id!: TaskId",
39 part_kind as "part_kind!",
40 sequence_number as "sequence_number!",
41 text_content,
42 file_name,
43 file_mime_type,
44 file_uri,
45 file_bytes,
46 data_content,
47 metadata
48 FROM message_parts WHERE message_id = $1 ORDER BY sequence_number ASC"#,
49 message_id.as_str()
50 )
51 .fetch_all(pool.as_ref())
52 .await
53 .map_err(RepositoryError::database)?;
54
55 let mut parts = Vec::new();
56
57 for row in part_rows {
58 let part = match row.part_kind.as_str() {
59 "text" => {
60 let text = row
61 .text_content
62 .ok_or_else(|| RepositoryError::InvalidData("Missing text_content".into()))?;
63 Part::Text(crate::models::a2a::TextPart { text })
64 },
65 "file" => {
66 let bytes = row
67 .file_bytes
68 .ok_or_else(|| RepositoryError::InvalidData("Missing file_bytes".into()))?;
69 Part::File(crate::models::a2a::FilePart {
70 file: crate::models::a2a::FileWithBytes {
71 name: row.file_name,
72 mime_type: row.file_mime_type,
73 bytes,
74 },
75 })
76 },
77 "data" => {
78 let data_value = row
79 .data_content
80 .ok_or_else(|| RepositoryError::InvalidData("Missing data_content".into()))?;
81 let serde_json::Value::Object(data) = data_value else {
82 return Err(RepositoryError::InvalidData(
83 "Data content must be a JSON object".into(),
84 ));
85 };
86 Part::Data(crate::models::a2a::DataPart { data })
87 },
88 _ => {
89 return Err(RepositoryError::InvalidData(format!(
90 "Unknown part kind: {}",
91 row.part_kind
92 )));
93 },
94 };
95
96 parts.push(part);
97 }
98
99 Ok(parts)
100}
101
/// Argument bundle for [`persist_part_sqlx`].
// NOTE(review): the reason for opting out of `Debug` is not recorded here —
// confirm which field prevents deriving it and document that.
#[allow(missing_debug_implementations)]
pub struct PersistPartSqlxParams<'a> {
    /// Open Postgres transaction on which the insert is executed.
    pub tx: &'a mut sqlx::Transaction<'static, sqlx::Postgres>,
    /// The part to store; its variant selects which columns are written.
    pub part: &'a Part,
    /// Written to the `message_id` column.
    pub message_id: &'a MessageId,
    /// Written to the `task_id` column.
    pub task_id: &'a TaskId,
    /// Written to the `sequence_number` column.
    pub sequence_number: i32,
    /// Upload context for file parts; with `None` no upload is attempted and
    /// the row is stored without `file_uri` / `file_id`.
    pub upload_ctx: Option<&'a FileUploadContext<'a>>,
}
111
112pub async fn persist_part_sqlx(params: PersistPartSqlxParams<'_>) -> Result<(), RepositoryError> {
113 let PersistPartSqlxParams {
114 tx,
115 part,
116 message_id,
117 task_id,
118 sequence_number,
119 upload_ctx,
120 } = params;
121 match part {
122 Part::Text(text_part) => {
123 sqlx::query!(
124 r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, text_content)
125 VALUES ($1, $2, 'text', $3, $4)"#,
126 message_id.as_str(),
127 task_id.as_str(),
128 sequence_number,
129 text_part.text
130 )
131 .execute(&mut **tx)
132 .await
133 .map_err(RepositoryError::database)?;
134 },
135 Part::File(file_part) => {
136 let upload_result = try_upload_file(file_part, upload_ctx).await;
137
138 let (file_id, file_uri) = match upload_result {
139 Some((id, uri)) => (Some(id), Some(uri)),
140 None => (None, None),
141 };
142
143 sqlx::query!(
144 r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, file_name, file_mime_type, file_uri, file_bytes, file_id)
145 VALUES ($1, $2, 'file', $3, $4, $5, $6, $7, $8)"#,
146 message_id.as_str(),
147 task_id.as_str(),
148 sequence_number,
149 file_part.file.name,
150 file_part.file.mime_type,
151 file_uri,
152 file_part.file.bytes,
153 file_id
154 )
155 .execute(&mut **tx)
156 .await
157 .map_err(RepositoryError::database)?;
158 },
159 Part::Data(data_part) => {
160 let data_json =
161 serde_json::to_value(&data_part.data).map_err(RepositoryError::Serialization)?;
162 sqlx::query!(
163 r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, data_content)
164 VALUES ($1, $2, 'data', $3, $4)"#,
165 message_id.as_str(),
166 task_id.as_str(),
167 sequence_number,
168 data_json
169 )
170 .execute(&mut **tx)
171 .await
172 .map_err(RepositoryError::database)?;
173 },
174 }
175
176 Ok(())
177}
178
179async fn try_upload_file(
180 file_part: &crate::models::a2a::FilePart,
181 upload_ctx: Option<&FileUploadContext<'_>>,
182) -> Option<(uuid::Uuid, String)> {
183 let ctx = upload_ctx?;
184
185 if !ctx.upload_provider.is_enabled() {
186 return None;
187 }
188
189 let mime_type = file_part
190 .file
191 .mime_type
192 .as_deref()
193 .unwrap_or("application/octet-stream");
194
195 let mut input = FileUploadInput::new(mime_type, &file_part.file.bytes, ctx.context_id.clone());
196
197 if let Some(name) = &file_part.file.name {
198 input = input.with_name(name);
199 }
200
201 if let Some(user_id) = ctx.user_id {
202 input = input.with_user_id(user_id.clone());
203 }
204
205 if let Some(session_id) = ctx.session_id {
206 input = input.with_session_id(session_id.clone());
207 }
208
209 if let Some(trace_id) = ctx.trace_id {
210 input = input.with_trace_id(trace_id.clone());
211 }
212
213 match ctx.upload_provider.upload_file(input).await {
214 Ok(uploaded) => {
215 let file_uuid = uuid::Uuid::parse_str(uploaded.file_id.as_str())
216 .map_err(|e| {
217 tracing::warn!(file_id = %uploaded.file_id, error = %e, "Invalid UUID from file service");
218 e
219 })
220 .ok()?;
221 Some((file_uuid, uploaded.public_url))
222 },
223 Err(e) => {
224 tracing::warn!(error = %e, "File upload failed, continuing with base64 only");
225 None
226 },
227 }
228}
229
230pub async fn persist_part_with_tx(
231 tx: &mut dyn systemprompt_database::DatabaseTransaction,
232 part: &Part,
233 message_id: &MessageId,
234 task_id: &TaskId,
235 sequence_number: i32,
236) -> Result<(), RepositoryError> {
237 let message_id_str = message_id.as_str();
238 let task_id_str = task_id.as_str();
239 match part {
240 Part::Text(text_part) => {
241 let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
242 sequence_number, text_content) VALUES ($1, $2, 'text', $3, $4)";
243 tx.execute(
244 &query,
245 &[
246 &message_id_str,
247 &task_id_str,
248 &sequence_number,
249 &text_part.text,
250 ],
251 )
252 .await?;
253 },
254 Part::File(file_part) => {
255 let uri_opt: Option<&str> = None;
256 let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
257 sequence_number, file_name, file_mime_type, file_uri, file_bytes) \
258 VALUES ($1, $2, 'file', $3, $4, $5, $6, $7)";
259 tx.execute(
260 &query,
261 &[
262 &message_id_str,
263 &task_id_str,
264 &sequence_number,
265 &file_part.file.name,
266 &file_part.file.mime_type,
267 &uri_opt,
268 &file_part.file.bytes,
269 ],
270 )
271 .await?;
272 },
273 Part::Data(data_part) => {
274 let data_json = serde_json::to_string(&data_part.data)?;
275 let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
276 sequence_number, data_content) VALUES ($1, $2, 'data', $3, $4)";
277 tx.execute(
278 &query,
279 &[&message_id_str, &task_id_str, &sequence_number, &data_json],
280 )
281 .await?;
282 },
283 }
284
285 Ok(())
286}