Skip to main content

systemprompt_agent/repository/context/message/
parts.rs

1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
4use systemprompt_traits::{DynFileUploadProvider, FileUploadInput, RepositoryError};
5
6use crate::models::a2a::Part;
7
8#[derive(Clone)]
9pub struct FileUploadContext<'a> {
10    pub upload_provider: &'a DynFileUploadProvider,
11    pub context_id: &'a ContextId,
12    pub user_id: Option<&'a UserId>,
13    pub session_id: Option<&'a SessionId>,
14    pub trace_id: Option<&'a TraceId>,
15}
16
17impl std::fmt::Debug for FileUploadContext<'_> {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("FileUploadContext")
20            .field("upload_provider", &"<DynFileUploadProvider>")
21            .field("context_id", &self.context_id)
22            .field("user_id", &self.user_id)
23            .field("session_id", &self.session_id)
24            .field("trace_id", &self.trace_id)
25            .finish()
26    }
27}
28
29pub async fn get_message_parts(
30    pool: &Arc<PgPool>,
31    message_id: &MessageId,
32) -> Result<Vec<Part>, RepositoryError> {
33    let part_rows: Vec<crate::models::MessagePart> = sqlx::query_as!(
34        crate::models::MessagePart,
35        r#"SELECT
36            id as "id!",
37            message_id as "message_id!: MessageId",
38            task_id as "task_id!: TaskId",
39            part_kind as "part_kind!",
40            sequence_number as "sequence_number!",
41            text_content,
42            file_name,
43            file_mime_type,
44            file_uri,
45            file_bytes,
46            data_content,
47            metadata
48        FROM message_parts WHERE message_id = $1 ORDER BY sequence_number ASC"#,
49        message_id.as_str()
50    )
51    .fetch_all(pool.as_ref())
52    .await
53    .map_err(|e| RepositoryError::database(e))?;
54
55    let mut parts = Vec::new();
56
57    for row in part_rows {
58        let part = match row.part_kind.as_str() {
59            "text" => {
60                let text = row
61                    .text_content
62                    .ok_or_else(|| RepositoryError::InvalidData("Missing text_content".into()))?;
63                Part::Text(crate::models::a2a::TextPart { text })
64            },
65            "file" => {
66                let bytes = row
67                    .file_bytes
68                    .ok_or_else(|| RepositoryError::InvalidData("Missing file_bytes".into()))?;
69                Part::File(crate::models::a2a::FilePart {
70                    file: crate::models::a2a::FileWithBytes {
71                        name: row.file_name,
72                        mime_type: row.file_mime_type,
73                        bytes,
74                    },
75                })
76            },
77            "data" => {
78                let data_value = row
79                    .data_content
80                    .ok_or_else(|| RepositoryError::InvalidData("Missing data_content".into()))?;
81                let data = if let serde_json::Value::Object(map) = data_value {
82                    map
83                } else {
84                    return Err(RepositoryError::InvalidData(
85                        "Data content must be a JSON object".into(),
86                    ));
87                };
88                Part::Data(crate::models::a2a::DataPart { data })
89            },
90            _ => {
91                return Err(RepositoryError::InvalidData(format!(
92                    "Unknown part kind: {}",
93                    row.part_kind
94                )));
95            },
96        };
97
98        parts.push(part);
99    }
100
101    Ok(parts)
102}
103
104pub async fn persist_part_sqlx(
105    tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
106    part: &Part,
107    message_id: &MessageId,
108    task_id: &TaskId,
109    sequence_number: i32,
110    upload_ctx: Option<&FileUploadContext<'_>>,
111) -> Result<(), RepositoryError> {
112    match part {
113        Part::Text(text_part) => {
114            sqlx::query!(
115                r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, text_content)
116                VALUES ($1, $2, 'text', $3, $4)"#,
117                message_id.as_str(),
118                task_id.as_str(),
119                sequence_number,
120                text_part.text
121            )
122            .execute(&mut **tx)
123            .await
124            .map_err(|e| RepositoryError::database(e))?;
125        },
126        Part::File(file_part) => {
127            let upload_result = try_upload_file(file_part, upload_ctx).await;
128
129            let (file_id, file_uri) = match upload_result {
130                Some((id, uri)) => (Some(id), Some(uri)),
131                None => (None, None),
132            };
133
134            sqlx::query!(
135                r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, file_name, file_mime_type, file_uri, file_bytes, file_id)
136                VALUES ($1, $2, 'file', $3, $4, $5, $6, $7, $8)"#,
137                message_id.as_str(),
138                task_id.as_str(),
139                sequence_number,
140                file_part.file.name,
141                file_part.file.mime_type,
142                file_uri,
143                file_part.file.bytes,
144                file_id
145            )
146            .execute(&mut **tx)
147            .await
148            .map_err(|e| RepositoryError::database(e))?;
149        },
150        Part::Data(data_part) => {
151            let data_json = serde_json::to_value(&data_part.data)
152                .map_err(|e| RepositoryError::Serialization(e))?;
153            sqlx::query!(
154                r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, data_content)
155                VALUES ($1, $2, 'data', $3, $4)"#,
156                message_id.as_str(),
157                task_id.as_str(),
158                sequence_number,
159                data_json
160            )
161            .execute(&mut **tx)
162            .await
163            .map_err(|e| RepositoryError::database(e))?;
164        },
165    }
166
167    Ok(())
168}
169
170async fn try_upload_file(
171    file_part: &crate::models::a2a::FilePart,
172    upload_ctx: Option<&FileUploadContext<'_>>,
173) -> Option<(uuid::Uuid, String)> {
174    let ctx = upload_ctx?;
175
176    if !ctx.upload_provider.is_enabled() {
177        return None;
178    }
179
180    let mime_type = file_part
181        .file
182        .mime_type
183        .as_deref()
184        .unwrap_or("application/octet-stream");
185
186    let mut input = FileUploadInput::new(mime_type, &file_part.file.bytes, ctx.context_id.clone());
187
188    if let Some(name) = &file_part.file.name {
189        input = input.with_name(name);
190    }
191
192    if let Some(user_id) = ctx.user_id {
193        input = input.with_user_id(user_id.clone());
194    }
195
196    if let Some(session_id) = ctx.session_id {
197        input = input.with_session_id(session_id.clone());
198    }
199
200    if let Some(trace_id) = ctx.trace_id {
201        input = input.with_trace_id(trace_id.clone());
202    }
203
204    match ctx.upload_provider.upload_file(input).await {
205        Ok(uploaded) => {
206            let file_uuid = uuid::Uuid::parse_str(uploaded.file_id.as_str())
207                .map_err(|e| {
208                    tracing::warn!(file_id = %uploaded.file_id, error = %e, "Invalid UUID from file service");
209                    e
210                })
211                .ok()?;
212            Some((file_uuid, uploaded.public_url))
213        },
214        Err(e) => {
215            tracing::warn!(error = %e, "File upload failed, continuing with base64 only");
216            None
217        },
218    }
219}
220
221pub async fn persist_part_with_tx(
222    tx: &mut dyn systemprompt_database::DatabaseTransaction,
223    part: &Part,
224    message_id: &MessageId,
225    task_id: &TaskId,
226    sequence_number: i32,
227) -> Result<(), RepositoryError> {
228    let message_id_str = message_id.as_str();
229    let task_id_str = task_id.as_str();
230    match part {
231        Part::Text(text_part) => {
232            let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
233                               sequence_number, text_content) VALUES ($1, $2, 'text', $3, $4)";
234            tx.execute(
235                &query,
236                &[
237                    &message_id_str,
238                    &task_id_str,
239                    &sequence_number,
240                    &text_part.text,
241                ],
242            )
243            .await?;
244        },
245        Part::File(file_part) => {
246            let uri_opt: Option<&str> = None;
247            let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
248                               sequence_number, file_name, file_mime_type, file_uri, file_bytes) \
249                               VALUES ($1, $2, 'file', $3, $4, $5, $6, $7)";
250            tx.execute(
251                &query,
252                &[
253                    &message_id_str,
254                    &task_id_str,
255                    &sequence_number,
256                    &file_part.file.name,
257                    &file_part.file.mime_type,
258                    &uri_opt,
259                    &file_part.file.bytes,
260                ],
261            )
262            .await?;
263        },
264        Part::Data(data_part) => {
265            let data_json = serde_json::to_string(&data_part.data)?;
266            let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
267                               sequence_number, data_content) VALUES ($1, $2, 'data', $3, $4)";
268            tx.execute(
269                &query,
270                &[&message_id_str, &task_id_str, &sequence_number, &data_json],
271            )
272            .await?;
273        },
274    }
275
276    Ok(())
277}