Skip to main content

systemprompt_agent/repository/context/message/
parts.rs

1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
4use systemprompt_traits::{DynFileUploadProvider, FileUploadInput, RepositoryError};
5
6use crate::models::a2a::Part;
7
8#[derive(Clone)]
9pub struct FileUploadContext<'a> {
10    pub upload_provider: &'a DynFileUploadProvider,
11    pub context_id: &'a ContextId,
12    pub user_id: Option<&'a UserId>,
13    pub session_id: Option<&'a SessionId>,
14    pub trace_id: Option<&'a TraceId>,
15}
16
17impl std::fmt::Debug for FileUploadContext<'_> {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("FileUploadContext")
20            .field("upload_provider", &"<DynFileUploadProvider>")
21            .field("context_id", &self.context_id)
22            .field("user_id", &self.user_id)
23            .field("session_id", &self.session_id)
24            .field("trace_id", &self.trace_id)
25            .finish()
26    }
27}
28
29pub async fn get_message_parts(
30    pool: &Arc<PgPool>,
31    message_id: &MessageId,
32) -> Result<Vec<Part>, RepositoryError> {
33    let part_rows: Vec<crate::models::MessagePart> = sqlx::query_as!(
34        crate::models::MessagePart,
35        r#"SELECT
36            id as "id!",
37            message_id as "message_id!: MessageId",
38            task_id as "task_id!: TaskId",
39            part_kind as "part_kind!",
40            sequence_number as "sequence_number!",
41            text_content,
42            file_name,
43            file_mime_type,
44            file_uri,
45            file_bytes,
46            data_content,
47            metadata
48        FROM message_parts WHERE message_id = $1 ORDER BY sequence_number ASC"#,
49        message_id.as_str()
50    )
51    .fetch_all(pool.as_ref())
52    .await
53    .map_err(RepositoryError::database)?;
54
55    let mut parts = Vec::new();
56
57    for row in part_rows {
58        let part = match row.part_kind.as_str() {
59            "text" => {
60                let text = row
61                    .text_content
62                    .ok_or_else(|| RepositoryError::InvalidData("Missing text_content".into()))?;
63                Part::Text(crate::models::a2a::TextPart { text })
64            },
65            "file" => {
66                let bytes = row
67                    .file_bytes
68                    .ok_or_else(|| RepositoryError::InvalidData("Missing file_bytes".into()))?;
69                Part::File(crate::models::a2a::FilePart {
70                    file: crate::models::a2a::FileWithBytes {
71                        name: row.file_name,
72                        mime_type: row.file_mime_type,
73                        bytes,
74                    },
75                })
76            },
77            "data" => {
78                let data_value = row
79                    .data_content
80                    .ok_or_else(|| RepositoryError::InvalidData("Missing data_content".into()))?;
81                let serde_json::Value::Object(data) = data_value else {
82                    return Err(RepositoryError::InvalidData(
83                        "Data content must be a JSON object".into(),
84                    ));
85                };
86                Part::Data(crate::models::a2a::DataPart { data })
87            },
88            _ => {
89                return Err(RepositoryError::InvalidData(format!(
90                    "Unknown part kind: {}",
91                    row.part_kind
92                )));
93            },
94        };
95
96        parts.push(part);
97    }
98
99    Ok(parts)
100}
101
102#[allow(missing_debug_implementations)]
103pub struct PersistPartSqlxParams<'a> {
104    pub tx: &'a mut sqlx::Transaction<'static, sqlx::Postgres>,
105    pub part: &'a Part,
106    pub message_id: &'a MessageId,
107    pub task_id: &'a TaskId,
108    pub sequence_number: i32,
109    pub upload_ctx: Option<&'a FileUploadContext<'a>>,
110}
111
112pub async fn persist_part_sqlx(params: PersistPartSqlxParams<'_>) -> Result<(), RepositoryError> {
113    let PersistPartSqlxParams {
114        tx,
115        part,
116        message_id,
117        task_id,
118        sequence_number,
119        upload_ctx,
120    } = params;
121    match part {
122        Part::Text(text_part) => {
123            sqlx::query!(
124                r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, text_content)
125                VALUES ($1, $2, 'text', $3, $4)"#,
126                message_id.as_str(),
127                task_id.as_str(),
128                sequence_number,
129                text_part.text
130            )
131            .execute(&mut **tx)
132            .await
133            .map_err(RepositoryError::database)?;
134        },
135        Part::File(file_part) => {
136            let upload_result = try_upload_file(file_part, upload_ctx).await;
137
138            let (file_id, file_uri) = match upload_result {
139                Some((id, uri)) => (Some(id), Some(uri)),
140                None => (None, None),
141            };
142
143            sqlx::query!(
144                r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, file_name, file_mime_type, file_uri, file_bytes, file_id)
145                VALUES ($1, $2, 'file', $3, $4, $5, $6, $7, $8)"#,
146                message_id.as_str(),
147                task_id.as_str(),
148                sequence_number,
149                file_part.file.name,
150                file_part.file.mime_type,
151                file_uri,
152                file_part.file.bytes,
153                file_id
154            )
155            .execute(&mut **tx)
156            .await
157            .map_err(RepositoryError::database)?;
158        },
159        Part::Data(data_part) => {
160            let data_json =
161                serde_json::to_value(&data_part.data).map_err(RepositoryError::Serialization)?;
162            sqlx::query!(
163                r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, data_content)
164                VALUES ($1, $2, 'data', $3, $4)"#,
165                message_id.as_str(),
166                task_id.as_str(),
167                sequence_number,
168                data_json
169            )
170            .execute(&mut **tx)
171            .await
172            .map_err(RepositoryError::database)?;
173        },
174    }
175
176    Ok(())
177}
178
179async fn try_upload_file(
180    file_part: &crate::models::a2a::FilePart,
181    upload_ctx: Option<&FileUploadContext<'_>>,
182) -> Option<(uuid::Uuid, String)> {
183    let ctx = upload_ctx?;
184
185    if !ctx.upload_provider.is_enabled() {
186        return None;
187    }
188
189    let mime_type = file_part
190        .file
191        .mime_type
192        .as_deref()
193        .unwrap_or("application/octet-stream");
194
195    let mut input = FileUploadInput::new(mime_type, &file_part.file.bytes, ctx.context_id.clone());
196
197    if let Some(name) = &file_part.file.name {
198        input = input.with_name(name);
199    }
200
201    if let Some(user_id) = ctx.user_id {
202        input = input.with_user_id(user_id.clone());
203    }
204
205    if let Some(session_id) = ctx.session_id {
206        input = input.with_session_id(session_id.clone());
207    }
208
209    if let Some(trace_id) = ctx.trace_id {
210        input = input.with_trace_id(trace_id.clone());
211    }
212
213    match ctx.upload_provider.upload_file(input).await {
214        Ok(uploaded) => {
215            let file_uuid = uuid::Uuid::parse_str(uploaded.file_id.as_str())
216                .map_err(|e| {
217                    tracing::warn!(file_id = %uploaded.file_id, error = %e, "Invalid UUID from file service");
218                    e
219                })
220                .ok()?;
221            Some((file_uuid, uploaded.public_url))
222        },
223        Err(e) => {
224            tracing::warn!(error = %e, "File upload failed, continuing with base64 only");
225            None
226        },
227    }
228}
229
230pub async fn persist_part_with_tx(
231    tx: &mut dyn systemprompt_database::DatabaseTransaction,
232    part: &Part,
233    message_id: &MessageId,
234    task_id: &TaskId,
235    sequence_number: i32,
236) -> Result<(), RepositoryError> {
237    let message_id_str = message_id.as_str();
238    let task_id_str = task_id.as_str();
239    match part {
240        Part::Text(text_part) => {
241            let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
242                               sequence_number, text_content) VALUES ($1, $2, 'text', $3, $4)";
243            tx.execute(
244                &query,
245                &[
246                    &message_id_str,
247                    &task_id_str,
248                    &sequence_number,
249                    &text_part.text,
250                ],
251            )
252            .await?;
253        },
254        Part::File(file_part) => {
255            let uri_opt: Option<&str> = None;
256            let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
257                               sequence_number, file_name, file_mime_type, file_uri, file_bytes) \
258                               VALUES ($1, $2, 'file', $3, $4, $5, $6, $7)";
259            tx.execute(
260                &query,
261                &[
262                    &message_id_str,
263                    &task_id_str,
264                    &sequence_number,
265                    &file_part.file.name,
266                    &file_part.file.mime_type,
267                    &uri_opt,
268                    &file_part.file.bytes,
269                ],
270            )
271            .await?;
272        },
273        Part::Data(data_part) => {
274            let data_json = serde_json::to_string(&data_part.data)?;
275            let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
276                               sequence_number, data_content) VALUES ($1, $2, 'data', $3, $4)";
277            tx.execute(
278                &query,
279                &[&message_id_str, &task_id_str, &sequence_number, &data_json],
280            )
281            .await?;
282        },
283    }
284
285    Ok(())
286}