Skip to main content

systemprompt_agent/repository/context/message/
parts.rs

1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
4use systemprompt_traits::{DynFileUploadProvider, FileUploadInput, RepositoryError};
5
6use crate::models::a2a::Part;
7
8#[derive(Clone)]
9pub struct FileUploadContext<'a> {
10    pub upload_provider: &'a DynFileUploadProvider,
11    pub context_id: &'a ContextId,
12    pub user_id: Option<&'a UserId>,
13    pub session_id: Option<&'a SessionId>,
14    pub trace_id: Option<&'a TraceId>,
15}
16
17impl std::fmt::Debug for FileUploadContext<'_> {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("FileUploadContext")
20            .field("upload_provider", &"<DynFileUploadProvider>")
21            .field("context_id", &self.context_id)
22            .field("user_id", &self.user_id)
23            .field("session_id", &self.session_id)
24            .field("trace_id", &self.trace_id)
25            .finish()
26    }
27}
28
29pub async fn get_message_parts(
30    pool: &Arc<PgPool>,
31    message_id: &MessageId,
32) -> Result<Vec<Part>, RepositoryError> {
33    let part_rows: Vec<crate::models::MessagePart> = sqlx::query_as!(
34        crate::models::MessagePart,
35        r#"SELECT
36            id as "id!",
37            message_id as "message_id!: MessageId",
38            task_id as "task_id!: TaskId",
39            part_kind as "part_kind!",
40            sequence_number as "sequence_number!",
41            text_content,
42            file_name,
43            file_mime_type,
44            file_uri,
45            file_bytes,
46            data_content,
47            metadata
48        FROM message_parts WHERE message_id = $1 ORDER BY sequence_number ASC"#,
49        message_id.as_str()
50    )
51    .fetch_all(pool.as_ref())
52    .await
53    .map_err(RepositoryError::database)?;
54
55    let mut parts = Vec::new();
56
57    for row in part_rows {
58        let part = match row.part_kind.as_str() {
59            "text" => {
60                let text = row
61                    .text_content
62                    .ok_or_else(|| RepositoryError::InvalidData("Missing text_content".into()))?;
63                Part::Text(crate::models::a2a::TextPart { text })
64            },
65            "file" => Part::File(crate::models::a2a::FilePart {
66                file: crate::models::a2a::FileContent {
67                    name: row.file_name,
68                    mime_type: row.file_mime_type,
69                    bytes: row.file_bytes,
70                    url: row.file_uri,
71                },
72            }),
73            "data" => {
74                let data_value = row
75                    .data_content
76                    .ok_or_else(|| RepositoryError::InvalidData("Missing data_content".into()))?;
77                let serde_json::Value::Object(data) = data_value else {
78                    return Err(RepositoryError::InvalidData(
79                        "Data content must be a JSON object".into(),
80                    ));
81                };
82                Part::Data(crate::models::a2a::DataPart { data })
83            },
84            _ => {
85                return Err(RepositoryError::InvalidData(format!(
86                    "Unknown part kind: {}",
87                    row.part_kind
88                )));
89            },
90        };
91
92        parts.push(part);
93    }
94
95    Ok(parts)
96}
97
98#[allow(missing_debug_implementations)]
99pub struct PersistPartSqlxParams<'a> {
100    pub tx: &'a mut sqlx::Transaction<'static, sqlx::Postgres>,
101    pub part: &'a Part,
102    pub message_id: &'a MessageId,
103    pub task_id: &'a TaskId,
104    pub sequence_number: i32,
105    pub upload_ctx: Option<&'a FileUploadContext<'a>>,
106}
107
108pub async fn persist_part_sqlx(params: PersistPartSqlxParams<'_>) -> Result<(), RepositoryError> {
109    let PersistPartSqlxParams {
110        tx,
111        part,
112        message_id,
113        task_id,
114        sequence_number,
115        upload_ctx,
116    } = params;
117    match part {
118        Part::Text(text_part) => {
119            sqlx::query!(
120                r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, text_content)
121                VALUES ($1, $2, 'text', $3, $4)"#,
122                message_id.as_str(),
123                task_id.as_str(),
124                sequence_number,
125                text_part.text
126            )
127            .execute(&mut **tx)
128            .await
129            .map_err(RepositoryError::database)?;
130        },
131        Part::File(file_part) => {
132            let upload_result = try_upload_file(file_part, upload_ctx).await;
133
134            let (file_id, file_uri) = match upload_result {
135                Some((id, uri)) => (Some(id), Some(uri)),
136                None => (None, None),
137            };
138
139            sqlx::query!(
140                r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, file_name, file_mime_type, file_uri, file_bytes, file_id)
141                VALUES ($1, $2, 'file', $3, $4, $5, $6, $7, $8)"#,
142                message_id.as_str(),
143                task_id.as_str(),
144                sequence_number,
145                file_part.file.name,
146                file_part.file.mime_type,
147                file_uri,
148                file_part.file.bytes.as_deref(),
149                file_id
150            )
151            .execute(&mut **tx)
152            .await
153            .map_err(RepositoryError::database)?;
154        },
155        Part::Data(data_part) => {
156            let data_json =
157                serde_json::to_value(&data_part.data).map_err(RepositoryError::Serialization)?;
158            sqlx::query!(
159                r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, data_content)
160                VALUES ($1, $2, 'data', $3, $4)"#,
161                message_id.as_str(),
162                task_id.as_str(),
163                sequence_number,
164                data_json
165            )
166            .execute(&mut **tx)
167            .await
168            .map_err(RepositoryError::database)?;
169        },
170    }
171
172    Ok(())
173}
174
175async fn try_upload_file(
176    file_part: &crate::models::a2a::FilePart,
177    upload_ctx: Option<&FileUploadContext<'_>>,
178) -> Option<(uuid::Uuid, String)> {
179    let ctx = upload_ctx?;
180
181    if !ctx.upload_provider.is_enabled() {
182        return None;
183    }
184
185    let mime_type = file_part
186        .file
187        .mime_type
188        .as_deref()
189        .unwrap_or("application/octet-stream");
190
191    let bytes = file_part.file.bytes.as_deref()?;
192    let mut input = FileUploadInput::new(mime_type, bytes, ctx.context_id.clone());
193
194    if let Some(name) = &file_part.file.name {
195        input = input.with_name(name);
196    }
197
198    if let Some(user_id) = ctx.user_id {
199        input = input.with_user_id(user_id.clone());
200    }
201
202    if let Some(session_id) = ctx.session_id {
203        input = input.with_session_id(session_id.clone());
204    }
205
206    if let Some(trace_id) = ctx.trace_id {
207        input = input.with_trace_id(trace_id.clone());
208    }
209
210    match ctx.upload_provider.upload_file(input).await {
211        Ok(uploaded) => {
212            let file_uuid = uuid::Uuid::parse_str(uploaded.file_id.as_str())
213                .map_err(|e| {
214                    tracing::warn!(file_id = %uploaded.file_id, error = %e, "Invalid UUID from file service");
215                    e
216                })
217                .ok()?;
218            Some((file_uuid, uploaded.public_url))
219        },
220        Err(e) => {
221            tracing::warn!(error = %e, "File upload failed, continuing with base64 only");
222            None
223        },
224    }
225}
226
227pub async fn persist_part_with_tx(
228    tx: &mut dyn systemprompt_database::DatabaseTransaction,
229    part: &Part,
230    message_id: &MessageId,
231    task_id: &TaskId,
232    sequence_number: i32,
233) -> Result<(), RepositoryError> {
234    let message_id_str = message_id.as_str();
235    let task_id_str = task_id.as_str();
236    match part {
237        Part::Text(text_part) => {
238            let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
239                               sequence_number, text_content) VALUES ($1, $2, 'text', $3, $4)";
240            tx.execute(
241                &query,
242                &[
243                    &message_id_str,
244                    &task_id_str,
245                    &sequence_number,
246                    &text_part.text,
247                ],
248            )
249            .await?;
250        },
251        Part::File(file_part) => {
252            let uri_opt: Option<&str> = None;
253            let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
254                               sequence_number, file_name, file_mime_type, file_uri, file_bytes) \
255                               VALUES ($1, $2, 'file', $3, $4, $5, $6, $7)";
256            tx.execute(
257                &query,
258                &[
259                    &message_id_str,
260                    &task_id_str,
261                    &sequence_number,
262                    &file_part.file.name,
263                    &file_part.file.mime_type,
264                    &uri_opt,
265                    &file_part.file.bytes.as_deref(),
266                ],
267            )
268            .await?;
269        },
270        Part::Data(data_part) => {
271            let data_json = serde_json::to_string(&data_part.data)?;
272            let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
273                               sequence_number, data_content) VALUES ($1, $2, 'data', $3, $4)";
274            tx.execute(
275                &query,
276                &[&message_id_str, &task_id_str, &sequence_number, &data_json],
277            )
278            .await?;
279        },
280    }
281
282    Ok(())
283}