1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
4use systemprompt_traits::{DynFileUploadProvider, FileUploadInput, RepositoryError};
5
6use crate::models::a2a::Part;
7
/// Borrowed inputs for the optional file upload attempted while persisting a
/// file part.
///
/// Every field is a reference (or an optional reference), so cloning this
/// struct only copies those references.
#[derive(Clone)]
pub struct FileUploadContext<'a> {
    /// Provider that is checked with `is_enabled()` and, if enabled, asked to
    /// `upload_file`.
    pub upload_provider: &'a DynFileUploadProvider,
    /// Context id cloned into the `FileUploadInput` built for the upload.
    pub context_id: &'a ContextId,
    /// Attached to the upload input when present.
    pub user_id: Option<&'a UserId>,
    /// Attached to the upload input when present.
    pub session_id: Option<&'a SessionId>,
    /// Attached to the upload input when present.
    pub trace_id: Option<&'a TraceId>,
}
16
17impl std::fmt::Debug for FileUploadContext<'_> {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("FileUploadContext")
20 .field("upload_provider", &"<DynFileUploadProvider>")
21 .field("context_id", &self.context_id)
22 .field("user_id", &self.user_id)
23 .field("session_id", &self.session_id)
24 .field("trace_id", &self.trace_id)
25 .finish()
26 }
27}
28
29pub async fn get_message_parts(
30 pool: &Arc<PgPool>,
31 message_id: &MessageId,
32) -> Result<Vec<Part>, RepositoryError> {
33 let part_rows: Vec<crate::models::MessagePart> = sqlx::query_as!(
34 crate::models::MessagePart,
35 r#"SELECT
36 id as "id!",
37 message_id as "message_id!: MessageId",
38 task_id as "task_id!: TaskId",
39 part_kind as "part_kind!",
40 sequence_number as "sequence_number!",
41 text_content,
42 file_name,
43 file_mime_type,
44 file_uri,
45 file_bytes,
46 data_content,
47 metadata
48 FROM message_parts WHERE message_id = $1 ORDER BY sequence_number ASC"#,
49 message_id.as_str()
50 )
51 .fetch_all(pool.as_ref())
52 .await
53 .map_err(|e| RepositoryError::database(e))?;
54
55 let mut parts = Vec::new();
56
57 for row in part_rows {
58 let part = match row.part_kind.as_str() {
59 "text" => {
60 let text = row
61 .text_content
62 .ok_or_else(|| RepositoryError::InvalidData("Missing text_content".into()))?;
63 Part::Text(crate::models::a2a::TextPart { text })
64 },
65 "file" => {
66 let bytes = row
67 .file_bytes
68 .ok_or_else(|| RepositoryError::InvalidData("Missing file_bytes".into()))?;
69 Part::File(crate::models::a2a::FilePart {
70 file: crate::models::a2a::FileWithBytes {
71 name: row.file_name,
72 mime_type: row.file_mime_type,
73 bytes,
74 },
75 })
76 },
77 "data" => {
78 let data_value = row
79 .data_content
80 .ok_or_else(|| RepositoryError::InvalidData("Missing data_content".into()))?;
81 let data = if let serde_json::Value::Object(map) = data_value {
82 map
83 } else {
84 return Err(RepositoryError::InvalidData(
85 "Data content must be a JSON object".into(),
86 ));
87 };
88 Part::Data(crate::models::a2a::DataPart { data })
89 },
90 _ => {
91 return Err(RepositoryError::InvalidData(format!(
92 "Unknown part kind: {}",
93 row.part_kind
94 )));
95 },
96 };
97
98 parts.push(part);
99 }
100
101 Ok(parts)
102}
103
104pub async fn persist_part_sqlx(
105 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
106 part: &Part,
107 message_id: &MessageId,
108 task_id: &TaskId,
109 sequence_number: i32,
110 upload_ctx: Option<&FileUploadContext<'_>>,
111) -> Result<(), RepositoryError> {
112 match part {
113 Part::Text(text_part) => {
114 sqlx::query!(
115 r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, text_content)
116 VALUES ($1, $2, 'text', $3, $4)"#,
117 message_id.as_str(),
118 task_id.as_str(),
119 sequence_number,
120 text_part.text
121 )
122 .execute(&mut **tx)
123 .await
124 .map_err(|e| RepositoryError::database(e))?;
125 },
126 Part::File(file_part) => {
127 let upload_result = try_upload_file(file_part, upload_ctx).await;
128
129 let (file_id, file_uri) = match upload_result {
130 Some((id, uri)) => (Some(id), Some(uri)),
131 None => (None, None),
132 };
133
134 sqlx::query!(
135 r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, file_name, file_mime_type, file_uri, file_bytes, file_id)
136 VALUES ($1, $2, 'file', $3, $4, $5, $6, $7, $8)"#,
137 message_id.as_str(),
138 task_id.as_str(),
139 sequence_number,
140 file_part.file.name,
141 file_part.file.mime_type,
142 file_uri,
143 file_part.file.bytes,
144 file_id
145 )
146 .execute(&mut **tx)
147 .await
148 .map_err(|e| RepositoryError::database(e))?;
149 },
150 Part::Data(data_part) => {
151 let data_json = serde_json::to_value(&data_part.data)
152 .map_err(|e| RepositoryError::Serialization(e))?;
153 sqlx::query!(
154 r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, data_content)
155 VALUES ($1, $2, 'data', $3, $4)"#,
156 message_id.as_str(),
157 task_id.as_str(),
158 sequence_number,
159 data_json
160 )
161 .execute(&mut **tx)
162 .await
163 .map_err(|e| RepositoryError::database(e))?;
164 },
165 }
166
167 Ok(())
168}
169
170async fn try_upload_file(
171 file_part: &crate::models::a2a::FilePart,
172 upload_ctx: Option<&FileUploadContext<'_>>,
173) -> Option<(uuid::Uuid, String)> {
174 let ctx = upload_ctx?;
175
176 if !ctx.upload_provider.is_enabled() {
177 return None;
178 }
179
180 let mime_type = file_part
181 .file
182 .mime_type
183 .as_deref()
184 .unwrap_or("application/octet-stream");
185
186 let mut input = FileUploadInput::new(mime_type, &file_part.file.bytes, ctx.context_id.clone());
187
188 if let Some(name) = &file_part.file.name {
189 input = input.with_name(name);
190 }
191
192 if let Some(user_id) = ctx.user_id {
193 input = input.with_user_id(user_id.clone());
194 }
195
196 if let Some(session_id) = ctx.session_id {
197 input = input.with_session_id(session_id.clone());
198 }
199
200 if let Some(trace_id) = ctx.trace_id {
201 input = input.with_trace_id(trace_id.clone());
202 }
203
204 match ctx.upload_provider.upload_file(input).await {
205 Ok(uploaded) => {
206 let file_uuid = uuid::Uuid::parse_str(uploaded.file_id.as_str())
207 .map_err(|e| {
208 tracing::warn!(file_id = %uploaded.file_id, error = %e, "Invalid UUID from file service");
209 e
210 })
211 .ok()?;
212 Some((file_uuid, uploaded.public_url))
213 },
214 Err(e) => {
215 tracing::warn!(error = %e, "File upload failed, continuing with base64 only");
216 None
217 },
218 }
219}
220
221pub async fn persist_part_with_tx(
222 tx: &mut dyn systemprompt_database::DatabaseTransaction,
223 part: &Part,
224 message_id: &MessageId,
225 task_id: &TaskId,
226 sequence_number: i32,
227) -> Result<(), RepositoryError> {
228 let message_id_str = message_id.as_str();
229 let task_id_str = task_id.as_str();
230 match part {
231 Part::Text(text_part) => {
232 let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
233 sequence_number, text_content) VALUES ($1, $2, 'text', $3, $4)";
234 tx.execute(
235 &query,
236 &[
237 &message_id_str,
238 &task_id_str,
239 &sequence_number,
240 &text_part.text,
241 ],
242 )
243 .await?;
244 },
245 Part::File(file_part) => {
246 let uri_opt: Option<&str> = None;
247 let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
248 sequence_number, file_name, file_mime_type, file_uri, file_bytes) \
249 VALUES ($1, $2, 'file', $3, $4, $5, $6, $7)";
250 tx.execute(
251 &query,
252 &[
253 &message_id_str,
254 &task_id_str,
255 &sequence_number,
256 &file_part.file.name,
257 &file_part.file.mime_type,
258 &uri_opt,
259 &file_part.file.bytes,
260 ],
261 )
262 .await?;
263 },
264 Part::Data(data_part) => {
265 let data_json = serde_json::to_string(&data_part.data)?;
266 let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
267 sequence_number, data_content) VALUES ($1, $2, 'data', $3, $4)";
268 tx.execute(
269 &query,
270 &[&message_id_str, &task_id_str, &sequence_number, &data_json],
271 )
272 .await?;
273 },
274 }
275
276 Ok(())
277}