Skip to main content

systemprompt_agent/services/a2a_server/processing/
conversation_service.rs

1use anyhow::{anyhow, Result};
2use base64::Engine;
3use systemprompt_database::DbPool;
4use systemprompt_identifiers::ContextId;
5use systemprompt_models::{
6    is_supported_audio, is_supported_image, is_supported_text, is_supported_video, AiContentPart,
7    AiMessage, MessageRole,
8};
9
10use crate::models::a2a::{FilePart, Part};
11use crate::models::{Artifact, Message};
12use crate::repository::task::TaskRepository;
13
14#[derive(Debug)]
15pub struct ConversationService {
16    db_pool: DbPool,
17}
18
19impl ConversationService {
20    pub const fn new(db_pool: DbPool) -> Self {
21        Self { db_pool }
22    }
23
24    pub async fn load_conversation_history(
25        &self,
26        context_id: &ContextId,
27    ) -> Result<Vec<AiMessage>> {
28        let task_repo = TaskRepository::new(self.db_pool.clone());
29        let tasks = task_repo
30            .list_tasks_by_context(context_id)
31            .await
32            .map_err(|e| anyhow!("Failed to load conversation history: {}", e))?;
33
34        let mut history_messages = Vec::new();
35
36        for task in tasks {
37            if let Some(task_history) = task.history {
38                for msg in task_history {
39                    let (text, parts) = match self.extract_message_content(&msg) {
40                        Ok((t, p)) if !t.is_empty() || !p.is_empty() => (t, p),
41                        Ok(_) => continue,
42                        Err(e) => {
43                            tracing::warn!(error = %e, "Failed to extract message content");
44                            continue;
45                        },
46                    };
47
48                    let role = match msg.role.as_str() {
49                        "user" => MessageRole::User,
50                        "agent" => MessageRole::Assistant,
51                        _ => continue,
52                    };
53
54                    history_messages.push(AiMessage {
55                        role,
56                        content: text,
57                        parts,
58                    });
59                }
60            }
61
62            if let Some(artifacts) = task.artifacts {
63                for artifact in artifacts {
64                    if let Ok(artifact_content) = self.serialize_artifact_for_context(&artifact) {
65                        history_messages.push(AiMessage {
66                            role: MessageRole::Assistant,
67                            content: artifact_content,
68                            parts: Vec::new(),
69                        });
70                    }
71                }
72            }
73        }
74
75        Ok(history_messages)
76    }
77
78    fn extract_message_content(&self, message: &Message) -> Result<(String, Vec<AiContentPart>)> {
79        let mut text_content = String::new();
80        let mut content_parts = Vec::new();
81
82        for part in &message.parts {
83            match part {
84                Part::Text(text_part) => {
85                    if text_content.is_empty() {
86                        text_content.clone_from(&text_part.text);
87                    }
88                    content_parts.push(AiContentPart::text(&text_part.text));
89                },
90                Part::File(file_part) => {
91                    if let Some(content_part) = self.file_to_content_part(file_part) {
92                        content_parts.push(content_part);
93                    }
94                },
95                Part::Data(_) => {},
96            }
97        }
98
99        Ok((text_content, content_parts))
100    }
101
102    fn file_to_content_part(&self, file_part: &FilePart) -> Option<AiContentPart> {
103        let mime_type = file_part.file.mime_type.as_deref()?;
104        let file_name = file_part.file.name.as_deref().unwrap_or("unnamed");
105
106        if is_supported_image(mime_type) {
107            return Some(AiContentPart::image(mime_type, &file_part.file.bytes));
108        }
109
110        if is_supported_audio(mime_type) {
111            return Some(AiContentPart::audio(mime_type, &file_part.file.bytes));
112        }
113
114        if is_supported_video(mime_type) {
115            return Some(AiContentPart::video(mime_type, &file_part.file.bytes));
116        }
117
118        if is_supported_text(mime_type) {
119            return self.decode_text_file(file_part, file_name, mime_type);
120        }
121
122        tracing::warn!(
123            file_name = %file_name,
124            mime_type = %mime_type,
125            "Unsupported file type - file will not be sent to AI"
126        );
127        None
128    }
129
130    fn decode_text_file(
131        &self,
132        file_part: &FilePart,
133        file_name: &str,
134        mime_type: &str,
135    ) -> Option<AiContentPart> {
136        let decoded = base64::engine::general_purpose::STANDARD
137            .decode(&file_part.file.bytes)
138            .map_err(|e| {
139                tracing::warn!(
140                    file_name = %file_name,
141                    mime_type = %mime_type,
142                    error = %e,
143                    "Failed to decode base64 text file"
144                );
145                e
146            })
147            .ok()?;
148
149        let text_content = String::from_utf8(decoded)
150            .map_err(|e| {
151                tracing::warn!(
152                    file_name = %file_name,
153                    mime_type = %mime_type,
154                    error = %e,
155                    "Failed to decode text file as UTF-8"
156                );
157                e
158            })
159            .ok()?;
160
161        let formatted = format!("[File: {file_name} ({mime_type})]\n{text_content}");
162        Some(AiContentPart::text(formatted))
163    }
164
165    fn serialize_artifact_for_context(&self, artifact: &Artifact) -> Result<String> {
166        let artifact_name = artifact
167            .name
168            .clone()
169            .unwrap_or_else(|| "unnamed".to_string());
170
171        let mut content = format!(
172            "[Artifact: {} (type: {})]\n",
173            artifact_name, artifact.metadata.artifact_type
174        );
175
176        for part in &artifact.parts {
177            match part {
178                Part::Text(text_part) => {
179                    content.push_str(&text_part.text);
180                    content.push('\n');
181                },
182                Part::Data(data_part) => {
183                    let json_str = serde_json::to_string_pretty(&data_part.data)
184                        .unwrap_or_else(|_| "{}".to_string());
185                    content.push_str(&json_str);
186                    content.push('\n');
187                },
188                Part::File(file_part) => {
189                    if let Some(name) = &file_part.file.name {
190                        content.push_str(&format!("[File: {}]\n", name));
191                    }
192                },
193            }
194        }
195
196        Ok(content)
197    }
198}