1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
4use systemprompt_traits::{DynFileUploadProvider, FileUploadInput, RepositoryError};
5
6use crate::models::a2a::Part;
7
8#[derive(Clone)]
9pub struct FileUploadContext<'a> {
10 pub upload_provider: &'a DynFileUploadProvider,
11 pub context_id: &'a ContextId,
12 pub user_id: Option<&'a UserId>,
13 pub session_id: Option<&'a SessionId>,
14 pub trace_id: Option<&'a TraceId>,
15}
16
17impl std::fmt::Debug for FileUploadContext<'_> {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("FileUploadContext")
20 .field("upload_provider", &"<DynFileUploadProvider>")
21 .field("context_id", &self.context_id)
22 .field("user_id", &self.user_id)
23 .field("session_id", &self.session_id)
24 .field("trace_id", &self.trace_id)
25 .finish()
26 }
27}
28
29pub async fn get_message_parts(
30 pool: &Arc<PgPool>,
31 message_id: &MessageId,
32) -> Result<Vec<Part>, RepositoryError> {
33 let part_rows: Vec<crate::models::MessagePart> = sqlx::query_as!(
34 crate::models::MessagePart,
35 r#"SELECT
36 id as "id!",
37 message_id as "message_id!: MessageId",
38 task_id as "task_id!: TaskId",
39 part_kind as "part_kind!",
40 sequence_number as "sequence_number!",
41 text_content,
42 file_name,
43 file_mime_type,
44 file_uri,
45 file_bytes,
46 data_content,
47 metadata
48 FROM message_parts WHERE message_id = $1 ORDER BY sequence_number ASC"#,
49 message_id.as_str()
50 )
51 .fetch_all(pool.as_ref())
52 .await
53 .map_err(RepositoryError::database)?;
54
55 let mut parts = Vec::new();
56
57 for row in part_rows {
58 let part = match row.part_kind.as_str() {
59 "text" => {
60 let text = row
61 .text_content
62 .ok_or_else(|| RepositoryError::InvalidData("Missing text_content".into()))?;
63 Part::Text(crate::models::a2a::TextPart { text })
64 },
65 "file" => Part::File(crate::models::a2a::FilePart {
66 file: crate::models::a2a::FileContent {
67 name: row.file_name,
68 mime_type: row.file_mime_type,
69 bytes: row.file_bytes,
70 url: row.file_uri,
71 },
72 }),
73 "data" => {
74 let data_value = row
75 .data_content
76 .ok_or_else(|| RepositoryError::InvalidData("Missing data_content".into()))?;
77 let serde_json::Value::Object(data) = data_value else {
78 return Err(RepositoryError::InvalidData(
79 "Data content must be a JSON object".into(),
80 ));
81 };
82 Part::Data(crate::models::a2a::DataPart { data })
83 },
84 _ => {
85 return Err(RepositoryError::InvalidData(format!(
86 "Unknown part kind: {}",
87 row.part_kind
88 )));
89 },
90 };
91
92 parts.push(part);
93 }
94
95 Ok(parts)
96}
97
/// Arguments for [`persist_part_sqlx`], bundled into one struct instead of a
/// long positional parameter list.
// NOTE(review): `Debug` is opted out below, presumably because a field type
// (e.g. the borrowed sqlx transaction) does not implement `Debug` — confirm and
// keep the reason recorded next to the attribute.
#[allow(missing_debug_implementations)]
pub struct PersistPartSqlxParams<'a> {
    /// Transaction the INSERT executes on. It is only borrowed mutably here, so
    /// committing or rolling back is left to whoever owns it.
    pub tx: &'a mut sqlx::Transaction<'static, sqlx::Postgres>,
    /// Part to store; its variant decides which `message_parts` columns are written.
    pub part: &'a Part,
    /// Value written to the `message_id` column.
    pub message_id: &'a MessageId,
    /// Value written to the `task_id` column.
    pub task_id: &'a TaskId,
    /// Position of the part within its message (`get_message_parts` orders by it).
    pub sequence_number: i32,
    /// Upload settings for file parts; `None` skips any upload attempt.
    pub upload_ctx: Option<&'a FileUploadContext<'a>>,
}
107
108pub async fn persist_part_sqlx(params: PersistPartSqlxParams<'_>) -> Result<(), RepositoryError> {
109 let PersistPartSqlxParams {
110 tx,
111 part,
112 message_id,
113 task_id,
114 sequence_number,
115 upload_ctx,
116 } = params;
117 match part {
118 Part::Text(text_part) => {
119 sqlx::query!(
120 r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, text_content)
121 VALUES ($1, $2, 'text', $3, $4)"#,
122 message_id.as_str(),
123 task_id.as_str(),
124 sequence_number,
125 text_part.text
126 )
127 .execute(&mut **tx)
128 .await
129 .map_err(RepositoryError::database)?;
130 },
131 Part::File(file_part) => {
132 let upload_result = try_upload_file(file_part, upload_ctx).await;
133
134 let (file_id, file_uri) = match upload_result {
135 Some((id, uri)) => (Some(id), Some(uri)),
136 None => (None, None),
137 };
138
139 sqlx::query!(
140 r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, file_name, file_mime_type, file_uri, file_bytes, file_id)
141 VALUES ($1, $2, 'file', $3, $4, $5, $6, $7, $8)"#,
142 message_id.as_str(),
143 task_id.as_str(),
144 sequence_number,
145 file_part.file.name,
146 file_part.file.mime_type,
147 file_uri,
148 file_part.file.bytes.as_deref(),
149 file_id
150 )
151 .execute(&mut **tx)
152 .await
153 .map_err(RepositoryError::database)?;
154 },
155 Part::Data(data_part) => {
156 let data_json =
157 serde_json::to_value(&data_part.data).map_err(RepositoryError::Serialization)?;
158 sqlx::query!(
159 r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, data_content)
160 VALUES ($1, $2, 'data', $3, $4)"#,
161 message_id.as_str(),
162 task_id.as_str(),
163 sequence_number,
164 data_json
165 )
166 .execute(&mut **tx)
167 .await
168 .map_err(RepositoryError::database)?;
169 },
170 }
171
172 Ok(())
173}
174
175async fn try_upload_file(
176 file_part: &crate::models::a2a::FilePart,
177 upload_ctx: Option<&FileUploadContext<'_>>,
178) -> Option<(uuid::Uuid, String)> {
179 let ctx = upload_ctx?;
180
181 if !ctx.upload_provider.is_enabled() {
182 return None;
183 }
184
185 let mime_type = file_part
186 .file
187 .mime_type
188 .as_deref()
189 .unwrap_or("application/octet-stream");
190
191 let bytes = file_part.file.bytes.as_deref()?;
192 let mut input = FileUploadInput::new(mime_type, bytes, ctx.context_id.clone());
193
194 if let Some(name) = &file_part.file.name {
195 input = input.with_name(name);
196 }
197
198 if let Some(user_id) = ctx.user_id {
199 input = input.with_user_id(user_id.clone());
200 }
201
202 if let Some(session_id) = ctx.session_id {
203 input = input.with_session_id(session_id.clone());
204 }
205
206 if let Some(trace_id) = ctx.trace_id {
207 input = input.with_trace_id(trace_id.clone());
208 }
209
210 match ctx.upload_provider.upload_file(input).await {
211 Ok(uploaded) => {
212 let file_uuid = uuid::Uuid::parse_str(uploaded.file_id.as_str())
213 .map_err(|e| {
214 tracing::warn!(file_id = %uploaded.file_id, error = %e, "Invalid UUID from file service");
215 e
216 })
217 .ok()?;
218 Some((file_uuid, uploaded.public_url))
219 },
220 Err(e) => {
221 tracing::warn!(error = %e, "File upload failed, continuing with base64 only");
222 None
223 },
224 }
225}
226
227pub async fn persist_part_with_tx(
228 tx: &mut dyn systemprompt_database::DatabaseTransaction,
229 part: &Part,
230 message_id: &MessageId,
231 task_id: &TaskId,
232 sequence_number: i32,
233) -> Result<(), RepositoryError> {
234 let message_id_str = message_id.as_str();
235 let task_id_str = task_id.as_str();
236 match part {
237 Part::Text(text_part) => {
238 let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
239 sequence_number, text_content) VALUES ($1, $2, 'text', $3, $4)";
240 tx.execute(
241 &query,
242 &[
243 &message_id_str,
244 &task_id_str,
245 &sequence_number,
246 &text_part.text,
247 ],
248 )
249 .await?;
250 },
251 Part::File(file_part) => {
252 let uri_opt: Option<&str> = None;
253 let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
254 sequence_number, file_name, file_mime_type, file_uri, file_bytes) \
255 VALUES ($1, $2, 'file', $3, $4, $5, $6, $7)";
256 tx.execute(
257 &query,
258 &[
259 &message_id_str,
260 &task_id_str,
261 &sequence_number,
262 &file_part.file.name,
263 &file_part.file.mime_type,
264 &uri_opt,
265 &file_part.file.bytes.as_deref(),
266 ],
267 )
268 .await?;
269 },
270 Part::Data(data_part) => {
271 let data_json = serde_json::to_string(&data_part.data)?;
272 let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
273 sequence_number, data_content) VALUES ($1, $2, 'data', $3, $4)";
274 tx.execute(
275 &query,
276 &[&message_id_str, &task_id_str, &sequence_number, &data_json],
277 )
278 .await?;
279 },
280 }
281
282 Ok(())
283}